//! async-raw 0.1.0
//!
//! Demo project showing how to use raw readable and writable sockets with tokio.
//!
//! Documentation
use std::{
    ffi::c_int,
    io,
    os::fd::{AsRawFd, FromRawFd, OwnedFd, RawFd},
};
use tokio::io::unix::AsyncFd;

/// Options describing the raw socket that [`RawSock::new`] should create.
///
/// The fields are public so that code outside this crate can build the value
/// that the public constructor requires.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct SockOpts<'opt> {
    /// The ethernet protocol type to bind this socket to. [`libc::ETH_P_ALL`] for example
    /// would allow reading and writing all arbitrary packet types
    pub protocol: c_int,
    /// The name of the interface to bind this raw socket to
    pub intf: &'opt str,
}

pub struct RawSock {
    fd: AsyncFd<RawFd>,
}

impl RawSock {
    pub fn new(opts: SockOpts) -> Result<Self, io::Error> {
        unsafe {
            if opts.intf.len() >= libc::IFNAMSIZ {
                return Err(io::Error::other("invalid interface name - exceeds length"));
            }

            let sock_fd = libc::socket(
                libc::AF_PACKET,
                libc::SOCK_RAW | libc::SOCK_NONBLOCK,
                opts.protocol
            );

            if sock_fd < 0 {
                return Err(io::Error::last_os_error())
            }

            let mut ifreq = libc::ifreq {
                ifr_name: [0;libc::IFNAMSIZ],
                ifr_ifru: std::mem::zeroed(),
            };

            let intf_c = &*(opts.intf.as_bytes() as *const _ as *const [i8]);
            ifreq.ifr_name[..intf_c.len()].copy_from_slice(intf_c);
            
            if libc::ioctl(
                sock_fd,
                libc::SIOCGIFINDEX,
                &ifreq as *const _,
            ) < 0 {
                return Err(io::Error::last_os_error())
            }
        
            let addr = libc::sockaddr_ll {
                sll_family: libc::AF_PACKET as u16,
                sll_protocol: u16::to_be(opts.protocol as u16),
                sll_ifindex: ifreq.ifr_ifru.ifru_ifindex,
                sll_hatype: 0,
                sll_pkttype: 0,
                sll_halen: 0,
                sll_addr: [0; 8],
            };
            
            if libc::bind(sock_fd, &addr as *const _ as *const libc::sockaddr, std::mem::size_of::<libc::sockaddr_ll>() as u32) < 0 {
                return Err(io::Error::last_os_error())
            }

            Ok(Self {
                fd: AsyncFd::new(sock_fd).unwrap(),
            })
        }
    }

    pub async fn read(&self, buf: &mut [u8]) -> io::Result<usize> {
        loop {
            let guard = self.fd.readable().await?;

            unsafe {
                let res = libc::recv(
                    guard.get_ref().as_raw_fd(),
                    buf as *mut _ as *mut libc::c_void,
                    buf.len(), 
                    0
                );

                if res < 0 {
                    let err = io::Error::last_os_error();

                    match err.kind() {
                        io::ErrorKind::WouldBlock => continue,
                        _ => return Err(err)
                    }
                } else { 
                    return Ok(res as usize)
                }
            }
        }
    }

    pub async fn write(&self, buf: &[u8]) -> io::Result<usize> {
        loop {
            let guard = self.fd.writable().await?;

            unsafe {
                let res = libc::send(
                    guard.get_ref().as_raw_fd(),
                    buf as *const _ as *const libc::c_void,
                    buf.len(),
                    0,
                );

                if res < 0 {
                    let err = io::Error::last_os_error();

                    match err.kind() {
                        io::ErrorKind::WouldBlock => continue,
                        _ => return Err(err)
                    }
                } else { 
                    return Ok(res as usize)
                }
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Writes one frame to a raw socket bound to `lo` and expects the next
    /// read on that socket to return exactly the same bytes.
    #[tokio::test(flavor = "multi_thread", worker_threads = 1)]
    async fn test_creation() {
        // ICMP localhost -> localhost
        // (by its bytes this looks like an ICMPv6 echo request from ::1 to ::1)
        let frame: &[u8] = &[
            // Ethernet header: two all-zero MAC addresses, EtherType 0x86dd
            0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
            0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
            0x86, 0xdd,
            // IPv6 fixed header
            0x60, 0x04, 0x90, 0x15, 0x00, 0x40, 0x3a, 0x40,
            // Source address
            0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
            0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01,
            // Destination address
            0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
            0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01,
            // ICMPv6 header
            0x80, 0x00, 0xd0, 0x40, 0x00, 0x0a, 0x00, 0x01,
            // Payload
            0xb9, 0xb1, 0x09, 0x68, 0x00, 0x00, 0x00, 0x00,
            0x27, 0x4b, 0x06, 0x00, 0x00, 0x00, 0x00, 0x00,
            0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17,
            0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f,
            0x20, 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27,
            0x28, 0x29, 0x2a, 0x2b, 0x2c, 0x2d, 0x2e, 0x2f,
            0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37,
        ];

        let opts = SockOpts {
            protocol: libc::ETH_P_ALL,
            intf: "lo",
        };
        let sock = RawSock::new(opts).unwrap();

        sock.write(frame).await.unwrap();

        let mut received = [0u8; 128];
        let len = sock.read(&mut received).await.unwrap();

        assert_eq!(len, frame.len());
        assert_eq!(&received[..len], frame);
    }
}