kratanet/
raw_socket.rs

1use anyhow::{anyhow, Result};
2use bytes::BytesMut;
3use log::{debug, warn};
4use std::io::ErrorKind;
5use std::os::fd::{FromRawFd, IntoRawFd};
6use std::os::unix::io::{AsRawFd, RawFd};
7use std::sync::Arc;
8use std::{io, mem};
9use tokio::net::UdpSocket;
10use tokio::select;
11use tokio::sync::mpsc::{channel, Receiver, Sender};
12use tokio::task::JoinHandle;
13
/// Number of outgoing packets that may wait in the transmit channel before senders see it full.
const RAW_SOCKET_TRANSMIT_QUEUE_LEN: usize = 3_000;
/// Number of received packets buffered for the consumer before new ones are dropped.
const RAW_SOCKET_RECEIVE_QUEUE_LEN: usize = 3_000;
16
/// Kind of raw socket to open.
///
/// Each variant selects the address family, socket type and protocol number passed to
/// `socket(2)` (see the `to_socket_*` methods).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RawSocketProtocol {
    /// Raw ICMP over IPv4 (`AF_INET`).
    Icmpv4,
    /// Raw ICMPv6 over IPv6 (`AF_INET6`).
    Icmpv6,
    /// Link-layer packet socket receiving every EtherType (`AF_PACKET`).
    Ethernet,
}
23
24impl RawSocketProtocol {
25    pub fn to_socket_domain(&self) -> i32 {
26        match self {
27            RawSocketProtocol::Icmpv4 => libc::AF_INET,
28            RawSocketProtocol::Icmpv6 => libc::AF_INET6,
29            RawSocketProtocol::Ethernet => libc::AF_PACKET,
30        }
31    }
32
33    pub fn to_socket_protocol(&self) -> u16 {
34        match self {
35            RawSocketProtocol::Icmpv4 => libc::IPPROTO_ICMP as u16,
36            RawSocketProtocol::Icmpv6 => libc::IPPROTO_ICMPV6 as u16,
37            RawSocketProtocol::Ethernet => (libc::ETH_P_ALL as u16).to_be(),
38        }
39    }
40
41    pub fn to_socket_type(&self) -> i32 {
42        libc::SOCK_RAW
43    }
44}
45
46const SIOCGIFINDEX: libc::c_ulong = 0x8933;
47const SIOCGIFMTU: libc::c_ulong = 0x8921;
48
49#[derive(Debug)]
50pub struct RawSocketHandle {
51    protocol: RawSocketProtocol,
52    lower: libc::c_int,
53}
54
55impl AsRawFd for RawSocketHandle {
56    fn as_raw_fd(&self) -> RawFd {
57        self.lower
58    }
59}
60
61impl IntoRawFd for RawSocketHandle {
62    fn into_raw_fd(self) -> RawFd {
63        let fd = self.lower;
64        mem::forget(self);
65        fd
66    }
67}
68
69impl RawSocketHandle {
70    pub fn new(protocol: RawSocketProtocol) -> io::Result<RawSocketHandle> {
71        let lower = unsafe {
72            let lower = libc::socket(
73                protocol.to_socket_domain(),
74                protocol.to_socket_type() | libc::SOCK_NONBLOCK,
75                protocol.to_socket_protocol() as i32,
76            );
77            if lower == -1 {
78                return Err(io::Error::last_os_error());
79            }
80            lower
81        };
82
83        Ok(RawSocketHandle { protocol, lower })
84    }
85
86    pub fn bound_to_interface(interface: &str, protocol: RawSocketProtocol) -> Result<Self> {
87        let mut socket = RawSocketHandle::new(protocol)?;
88        socket.bind_to_interface(interface)?;
89        Ok(socket)
90    }
91
92    pub fn bind_to_interface(&mut self, interface: &str) -> io::Result<()> {
93        let mut ifreq = ifreq_for(interface);
94        let sockaddr = libc::sockaddr_ll {
95            sll_family: libc::AF_PACKET as u16,
96            sll_protocol: self.protocol.to_socket_protocol(),
97            sll_ifindex: ifreq_ioctl(self.lower, &mut ifreq, SIOCGIFINDEX)?,
98            sll_hatype: 1,
99            sll_pkttype: 0,
100            sll_halen: 6,
101            sll_addr: [0; 8],
102        };
103
104        unsafe {
105            let res = libc::bind(
106                self.lower,
107                &sockaddr as *const libc::sockaddr_ll as *const libc::sockaddr,
108                mem::size_of::<libc::sockaddr_ll>() as libc::socklen_t,
109            );
110            if res == -1 {
111                return Err(io::Error::last_os_error());
112            }
113        }
114
115        Ok(())
116    }
117
118    pub fn mtu_of_interface(&mut self, interface: &str) -> io::Result<usize> {
119        let mut ifreq = ifreq_for(interface);
120        ifreq_ioctl(self.lower, &mut ifreq, SIOCGIFMTU).map(|mtu| mtu as usize)
121    }
122
123    pub fn recv(&self, buffer: &mut [u8]) -> io::Result<usize> {
124        unsafe {
125            let len = libc::recv(
126                self.lower,
127                buffer.as_mut_ptr() as *mut libc::c_void,
128                buffer.len(),
129                0,
130            );
131            if len == -1 {
132                return Err(io::Error::last_os_error());
133            }
134            Ok(len as usize)
135        }
136    }
137
138    pub fn send(&self, buffer: &[u8]) -> io::Result<usize> {
139        unsafe {
140            let len = libc::send(
141                self.lower,
142                buffer.as_ptr() as *const libc::c_void,
143                buffer.len(),
144                0,
145            );
146            if len == -1 {
147                return Err(io::Error::last_os_error());
148            }
149            Ok(len as usize)
150        }
151    }
152}
153
154impl Drop for RawSocketHandle {
155    fn drop(&mut self) {
156        unsafe {
157            libc::close(self.lower);
158        }
159    }
160}
161
162#[repr(C)]
163#[derive(Debug)]
164struct Ifreq {
165    ifr_name: [libc::c_char; libc::IF_NAMESIZE],
166    ifr_data: libc::c_int,
167}
168
169fn ifreq_for(name: &str) -> Ifreq {
170    let mut ifreq = Ifreq {
171        ifr_name: [0; libc::IF_NAMESIZE],
172        ifr_data: 0,
173    };
174    for (i, byte) in name.as_bytes().iter().enumerate() {
175        ifreq.ifr_name[i] = *byte as libc::c_char
176    }
177    ifreq
178}
179
180fn ifreq_ioctl(
181    lower: libc::c_int,
182    ifreq: &mut Ifreq,
183    cmd: libc::c_ulong,
184) -> io::Result<libc::c_int> {
185    unsafe {
186        let res = libc::ioctl(lower, cmd as _, ifreq as *mut Ifreq);
187        if res == -1 {
188            return Err(io::Error::last_os_error());
189        }
190    }
191
192    Ok(ifreq.ifr_data)
193}
194
195pub struct AsyncRawSocketChannel {
196    pub sender: Sender<BytesMut>,
197    pub receiver: Receiver<BytesMut>,
198    _task: Arc<JoinHandle<()>>,
199}
200
201enum AsyncRawSocketChannelSelect {
202    TransmitPacket(Option<BytesMut>),
203    Readable(()),
204}
205
206impl AsyncRawSocketChannel {
207    pub fn new(mtu: usize, socket: RawSocketHandle) -> Result<AsyncRawSocketChannel> {
208        let (transmit_sender, transmit_receiver) = channel(RAW_SOCKET_TRANSMIT_QUEUE_LEN);
209        let (receive_sender, receive_receiver) = channel(RAW_SOCKET_RECEIVE_QUEUE_LEN);
210        let task = AsyncRawSocketChannel::launch(mtu, socket, transmit_receiver, receive_sender)?;
211        Ok(AsyncRawSocketChannel {
212            sender: transmit_sender,
213            receiver: receive_receiver,
214            _task: Arc::new(task),
215        })
216    }
217
218    fn launch(
219        mtu: usize,
220        socket: RawSocketHandle,
221        transmit_receiver: Receiver<BytesMut>,
222        receive_sender: Sender<BytesMut>,
223    ) -> Result<JoinHandle<()>> {
224        Ok(tokio::task::spawn(async move {
225            if let Err(error) =
226                AsyncRawSocketChannel::process(mtu, socket, transmit_receiver, receive_sender).await
227            {
228                warn!("failed to process raw socket: {}", error);
229            }
230        }))
231    }
232
233    async fn process(
234        mtu: usize,
235        socket: RawSocketHandle,
236        mut transmit_receiver: Receiver<BytesMut>,
237        receive_sender: Sender<BytesMut>,
238    ) -> Result<()> {
239        let socket = unsafe { std::net::UdpSocket::from_raw_fd(socket.into_raw_fd()) };
240        let socket = UdpSocket::from_std(socket)?;
241
242        let tear_off_size = 100 * mtu;
243        let mut buffer: BytesMut = BytesMut::with_capacity(tear_off_size);
244        loop {
245            if buffer.capacity() < mtu {
246                buffer = BytesMut::with_capacity(tear_off_size);
247            }
248
249            let selection = select! {
250                x = transmit_receiver.recv() => AsyncRawSocketChannelSelect::TransmitPacket(x),
251                x = socket.readable() => AsyncRawSocketChannelSelect::Readable(x?),
252            };
253
254            match selection {
255                AsyncRawSocketChannelSelect::Readable(_) => {
256                    buffer.resize(mtu, 0);
257                    match socket.try_recv(&mut buffer) {
258                        Ok(len) => {
259                            if len == 0 {
260                                continue;
261                            }
262                            let packet = buffer.split_to(len);
263                            if let Err(error) = receive_sender.try_send(packet) {
264                                debug!(
265                                    "failed to process received packet from raw socket: {}",
266                                    error
267                                );
268                            }
269                        }
270
271                        Err(ref error) => {
272                            if error.kind() == ErrorKind::WouldBlock {
273                                continue;
274                            }
275
276                            // device no longer exists
277                            if error.raw_os_error() == Some(6) {
278                                break;
279                            }
280
281                            return Err(anyhow!("failed to read from raw socket: {}", error));
282                        }
283                    };
284                }
285
286                AsyncRawSocketChannelSelect::TransmitPacket(Some(packet)) => {
287                    match socket.try_send(&packet) {
288                        Ok(_len) => {}
289                        Err(ref error) => {
290                            if error.kind() == ErrorKind::WouldBlock {
291                                debug!("failed to transmit: would block");
292                                continue;
293                            }
294
295                            // device no longer exists
296                            if error.raw_os_error() == Some(6) {
297                                break;
298                            }
299
300                            return Err(anyhow!(
301                                "failed to write {} bytes to raw socket: {}",
302                                packet.len(),
303                                error
304                            ));
305                        }
306                    };
307                }
308
309                AsyncRawSocketChannelSelect::TransmitPacket(None) => {
310                    break;
311                }
312            }
313        }
314
315        Ok(())
316    }
317}