async_raw/
lib.rs

use std::{
    ffi::c_int,
    io,
    os::fd::{AsRawFd, FromRawFd, OwnedFd, RawFd},
};
use tokio::io::unix::AsyncFd;
3
/// Options for [`RawSock::new`]: which ethernet protocol to handle and which interface to bind.
///
/// The fields are public so callers outside this crate can build the options directly, e.g.
/// `SockOpts { protocol: libc::ETH_P_ALL, intf: "lo" }`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct SockOpts<'opt> {
    /// The ethernet protocol type to bind this socket to, in host byte order. [`libc::ETH_P_ALL`]
    /// for example would allow reading and writing all arbitrary packet types.
    ///
    /// Only the low 16 bits are used.
    pub protocol: c_int,
    /// The name of the interface to bind this raw socket to. Must be shorter than
    /// [`libc::IFNAMSIZ`] bytes.
    pub intf: &'opt str,
}
11
12pub struct RawSock {
13    fd: AsyncFd<RawFd>,
14}
15
16impl RawSock {
17    pub fn new(opts: SockOpts) -> Result<Self, io::Error> {
18        unsafe {
19            if opts.intf.len() >= libc::IFNAMSIZ {
20                return Err(io::Error::other("invalid interface name - exceeds length"));
21            }
22
23            let sock_fd = libc::socket(
24                libc::AF_PACKET,
25                libc::SOCK_RAW | libc::SOCK_NONBLOCK,
26                opts.protocol
27            );
28
29            if sock_fd < 0 {
30                return Err(io::Error::last_os_error())
31            }
32
33            let mut ifreq = libc::ifreq {
34                ifr_name: [0;libc::IFNAMSIZ],
35                ifr_ifru: std::mem::zeroed(),
36            };
37
38            let intf_c = &*(opts.intf.as_bytes() as *const _ as *const [i8]);
39            ifreq.ifr_name[..intf_c.len()].copy_from_slice(intf_c);
40            
41            if libc::ioctl(
42                sock_fd,
43                libc::SIOCGIFINDEX,
44                &ifreq as *const _,
45            ) < 0 {
46                return Err(io::Error::last_os_error())
47            }
48        
49            let addr = libc::sockaddr_ll {
50                sll_family: libc::AF_PACKET as u16,
51                sll_protocol: u16::to_be(opts.protocol as u16),
52                sll_ifindex: ifreq.ifr_ifru.ifru_ifindex,
53                sll_hatype: 0,
54                sll_pkttype: 0,
55                sll_halen: 0,
56                sll_addr: [0; 8],
57            };
58            
59            if libc::bind(sock_fd, &addr as *const _ as *const libc::sockaddr, std::mem::size_of::<libc::sockaddr_ll>() as u32) < 0 {
60                return Err(io::Error::last_os_error())
61            }
62
63            Ok(Self {
64                fd: AsyncFd::new(sock_fd).unwrap(),
65            })
66        }
67    }
68
69    pub async fn read(&self, buf: &mut [u8]) -> io::Result<usize> {
70        loop {
71            let guard = self.fd.readable().await?;
72
73            unsafe {
74                let res = libc::recv(
75                    guard.get_ref().as_raw_fd(),
76                    buf as *mut _ as *mut libc::c_void,
77                    buf.len(), 
78                    0
79                );
80
81                if res < 0 {
82                    let err = io::Error::last_os_error();
83
84                    match err.kind() {
85                        io::ErrorKind::WouldBlock => continue,
86                        _ => return Err(err)
87                    }
88                } else { 
89                    return Ok(res as usize)
90                }
91            }
92        }
93    }
94
95    pub async fn write(&self, buf: &[u8]) -> io::Result<usize> {
96        loop {
97            let guard = self.fd.writable().await?;
98
99            unsafe {
100                let res = libc::send(
101                    guard.get_ref().as_raw_fd(),
102                    buf as *const _ as *const libc::c_void,
103                    buf.len(),
104                    0,
105                );
106
107                if res < 0 {
108                    let err = io::Error::last_os_error();
109
110                    match err.kind() {
111                        io::ErrorKind::WouldBlock => continue,
112                        _ => return Err(err)
113                    }
114                } else { 
115                    return Ok(res as usize)
116                }
117            }
118        }
119    }
120}
121
#[cfg(test)]
mod tests {
    use super::*;

    /// Round-trips one ICMPv6 echo-request frame through the loopback interface.
    ///
    /// Requires permission to open `AF_PACKET` sockets (`CAP_NET_RAW`, e.g. run as root).
    #[tokio::test(flavor = "multi_thread", worker_threads = 1)]
    async fn test_creation() {
        // Other traffic on `lo` (and the kernel's own echo reply) may be captured as well, so
        // scan a bounded number of frames for ours instead of assuming it arrives first.
        const MAX_FRAMES: usize = 64;

        let my_sock = RawSock::new(SockOpts { protocol: libc::ETH_P_ALL, intf: "lo" }).unwrap();

        let mut my_buf = [0u8; 128];

        // ICMPv6 echo request ::1 -> ::1, including an all-zero Ethernet header.
        #[rustfmt::skip]
        let packet: &[u8] = &[
            // Ethernet: destination MAC, source MAC, EtherType 0x86dd (IPv6).
            0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
            0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
            0x86, 0xdd,
            // IPv6: version/traffic class/flow label, payload length 64, next header 58
            // (ICMPv6), hop limit 64.
            0x60, 0x04, 0x90, 0x15, 0x00, 0x40, 0x3a, 0x40,
            // Source address ::1.
            0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
            0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01,
            // Destination address ::1.
            0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
            0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01,
            // ICMPv6: type 128 (echo request), code 0, checksum, identifier, sequence number.
            0x80, 0x00, 0xd0, 0x40, 0x00, 0x0a, 0x00, 0x01,
            // Echo payload.
            0xb9, 0xb1, 0x09, 0x68, 0x00, 0x00, 0x00, 0x00,
            0x27, 0x4b, 0x06, 0x00, 0x00, 0x00, 0x00, 0x00,
            0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17,
            0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f,
            0x20, 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27,
            0x28, 0x29, 0x2a, 0x2b, 0x2c, 0x2d, 0x2e, 0x2f,
            0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37,
        ];

        let written = my_sock.write(packet).await.unwrap();
        assert_eq!(written, packet.len());

        let mut seen = false;
        for _ in 0..MAX_FRAMES {
            let read_size = my_sock.read(&mut my_buf).await.unwrap();
            // Slice equality also checks that the length matches.
            if &my_buf[..read_size] == packet {
                seen = true;
                break;
            }
        }
        assert!(seen, "sent frame was not received back within {MAX_FRAMES} frames");
    }
}