wiretun 0.5.0

WireGuard Library
Documentation
use std::io;
use std::mem::{size_of, size_of_val};
use std::os::fd::{AsRawFd, FromRawFd, OwnedFd};
use std::sync::Arc;

use async_trait::async_trait;
use bytes::{Buf, BytesMut};
use regex::Regex;
use tokio::io::unix::AsyncFd;
use tracing::debug;

use super::sys;
use crate::tun::{Error, Tun};

#[inline]
fn parse_name(name: &str) -> Result<u32, Error> {
    if name == "utun" {
        return Ok(0);
    }
    let re = Regex::new(r"^utun([1-9]\d*|0)?$").unwrap();
    if !re.is_match(name) {
        return Err(Error::InvalidName);
    }
    name[4..]
        .parse()
        .map(|i: u32| i + 1)
        .map_err(|_| Error::InvalidName)
}

/// A macOS `utun` TUN device backed by a kernel-control socket.
///
/// Clones share the same underlying descriptor through an [`Arc`], so cloning
/// does not open a new device.
#[derive(Clone, Debug)]
pub struct NativeTun {
    /// Kernel-control socket (switched to non-blocking mode in `new`) wrapped
    /// in tokio's `AsyncFd` for readiness-based async I/O.
    fd: Arc<AsyncFd<OwnedFd>>,
    /// Interface name as returned by `sys::get_iface_name` in `new`.
    name: String,
}

impl NativeTun {
    /// Creates a utun interface by connecting a kernel-control socket.
    ///
    /// `name` is either `"utun"` or `"utunN"` (see `parse_name`). The name the
    /// interface actually received is read back from the socket after connecting
    /// and is what [`Tun::name`] reports.
    ///
    /// # Errors
    ///
    /// Returns [`Error::InvalidName`] if `name` is rejected by `parse_name`, or an
    /// error converted from the OS error of the first failing step: `socket`,
    /// the `CTLIOCGINFO` ioctl, `connect`, switching to non-blocking mode,
    /// reading the interface name, or registering the fd with tokio.
    pub fn new(name: &str) -> Result<Self, Error> {
        let unit = parse_name(name)?;

        // SAFETY: socket(2) has no memory-safety preconditions; the result is checked.
        let raw_fd =
            unsafe { libc::socket(libc::PF_SYSTEM, libc::SOCK_DGRAM, libc::SYSPROTO_CONTROL) };
        if raw_fd == -1 {
            return Err(io::Error::last_os_error().into());
        }
        // SAFETY: `raw_fd` was just returned by socket(2) and nothing else owns it.
        // Wrapping it immediately guarantees it is closed on every early return below.
        let fd = unsafe { OwnedFd::from_raw_fd(raw_fd) };

        // CTLIOCGINFO looks the control up by `ctl_name` and writes its id into
        // `ctl_id`, so `info` must be passed as a *mutable* pointer. Writing through
        // a pointer derived from a shared `&info` is undefined behaviour and lets
        // the compiler assume `ctl_id` is still 0 when building `addr` below.
        let mut info = libc::ctl_info {
            ctl_id: 0,
            ctl_name: sys::CTRL_NAME,
        };
        // SAFETY: `info` is a live, exclusively borrowed `ctl_info` for the whole call.
        let rc = unsafe {
            libc::ioctl(fd.as_raw_fd(), libc::CTLIOCGINFO, &mut info as *mut libc::ctl_info)
        };
        if rc < 0 {
            return Err(io::Error::last_os_error().into());
        }

        // `sc_unit` selects the utun instance: `parse_name` maps "utun" to 0 and
        // "utunN" to N + 1 (0 presumably lets the kernel pick a free unit).
        let addr = libc::sockaddr_ctl {
            sc_len: size_of::<libc::sockaddr_ctl>() as _,
            sc_family: libc::AF_SYSTEM as _,
            ss_sysaddr: libc::AF_SYS_CONTROL as _,
            sc_id: info.ctl_id,
            sc_unit: unit,
            sc_reserved: Default::default(),
        };
        // SAFETY: `addr` is a fully initialised `sockaddr_ctl` and the length passed
        // is exactly its size.
        let rc = unsafe {
            libc::connect(
                fd.as_raw_fd(),
                &addr as *const libc::sockaddr_ctl as _,
                size_of_val(&addr) as _,
            )
        };
        if rc < 0 {
            return Err(io::Error::last_os_error().into());
        }

        sys::set_nonblocking(fd.as_raw_fd())?;

        // SAFETY: assumes `get_iface_name` only requires a valid, connected utun
        // socket — TODO confirm against `sys`.
        let name = unsafe { sys::get_iface_name(fd.as_raw_fd()) }?;
        let fd = Arc::new(AsyncFd::new(fd)?);

        Ok(Self { fd, name })
    }
}

/// Length of the header utun puts in front of every packet: the packet's
/// address family (`AF_INET` / `AF_INET6`) as a big-endian `u32`.
const UTUN_HEADER_LEN: usize = 4;

#[async_trait]
impl Tun for NativeTun {
    /// Returns the interface name resolved in `NativeTun::new`.
    fn name(&self) -> &str {
        &self.name
    }

    /// Sets the interface MTU via `sys::ioctl_set_mtu`.
    fn set_mtu(&self, mtu: u16) -> Result<(), Error> {
        let mut req = sys::ifreq::new(&self.name);
        req.ifru.mtu = mtu as _;
        // SAFETY: assumes `ioctl_set_mtu` only needs a valid fd and a request
        // initialised for this interface — TODO confirm against `sys`.
        unsafe { sys::ioctl_set_mtu(self.fd.as_raw_fd(), &req) }?;

        Ok(())
    }

    /// Reads the current interface MTU via `sys::ioctl_get_mtu`.
    fn mtu(&self) -> Result<u16, Error> {
        let mut req = sys::ifreq::new(&self.name);

        // SAFETY: same assumptions as in `set_mtu`; the call fills in `req`.
        unsafe { sys::ioctl_get_mtu(self.fd.as_raw_fd(), &mut req) }?;

        // SAFETY: presumably the get-MTU ioctl populates the `mtu` member of the
        // union — TODO confirm. Note the `as` cast truncates to `u16`.
        Ok(unsafe { req.ifru.mtu as _ })
    }

    /// Waits for the next packet, strips the utun header, and returns the IP packet.
    ///
    /// Frames shorter than the header are skipped.
    ///
    /// # Errors
    ///
    /// Returns an error if waiting for readiness or the `read` itself fails.
    async fn recv(&self) -> Result<Vec<u8>, Error> {
        // The buffer must hold the header *plus* the largest packet. An MTU is a
        // `u16` (see `set_mtu`), so `u16::MAX` bytes of payload always fit. A
        // datagram larger than the buffer would be silently truncated by `read`.
        const RECV_BUF_LEN: usize = UTUN_HEADER_LEN + u16::MAX as usize;
        let mut buf = BytesMut::zeroed(RECV_BUF_LEN);

        loop {
            let mut guard = self.fd.readable().await?;
            let ret = guard.try_io(|inner| {
                // SAFETY: `buf` is a live, exclusively borrowed allocation of
                // `buf.len()` initialised bytes, so the kernel may write that many.
                let ret =
                    unsafe { libc::read(inner.as_raw_fd(), buf.as_mut_ptr().cast(), buf.len()) };
                if ret < 0 {
                    Err::<usize, io::Error>(io::Error::last_os_error())
                } else {
                    Ok(ret as usize)
                }
            });

            match ret {
                Ok(Ok(n)) if n >= UTUN_HEADER_LEN => {
                    debug!("TUN read {} bytes", n);
                    buf.truncate(n);
                    buf.advance(UTUN_HEADER_LEN);
                    return Ok(buf.freeze().to_vec());
                }
                Ok(Err(e)) => return Err(e.into()),
                // Runt frame shorter than the header, or spurious readiness
                // (`try_io` saw `WouldBlock` and cleared it): wait again.
                _ => continue,
            }
        }
    }

    /// Prepends the utun header to an IPv4/IPv6 packet and writes it to the device.
    ///
    /// # Errors
    ///
    /// Returns [`Error::InvalidIpPacket`] if `buf` is empty or its version nibble
    /// is neither 4 nor 6, and an I/O-derived error if the write fails.
    async fn send(&self, buf: &[u8]) -> Result<(), Error> {
        // `first()` turns an empty slice into an error instead of an index panic.
        let family = match buf.first().map(|&b| b >> 4) {
            Some(4) => libc::AF_INET,
            Some(6) => libc::AF_INET6,
            _ => return Err(Error::InvalidIpPacket),
        };
        let mut frame = Vec::with_capacity(UTUN_HEADER_LEN + buf.len());
        frame.extend_from_slice(&(family as u32).to_be_bytes());
        frame.extend_from_slice(buf);

        // Retry on spurious readiness; returning after a single `WouldBlock`
        // would silently drop the packet while still reporting success.
        loop {
            let mut guard = self.fd.writable().await?;
            let ret = guard.try_io(|inner| {
                // SAFETY: `frame` is a live allocation of `frame.len()` initialised bytes.
                let ret =
                    unsafe { libc::write(inner.as_raw_fd(), frame.as_ptr().cast(), frame.len()) };
                if ret < 0 {
                    Err::<usize, io::Error>(io::Error::last_os_error())
                } else {
                    Ok(ret as usize)
                }
            });

            match ret {
                Ok(Ok(_)) => return Ok(()),
                Ok(Err(e)) => return Err(e.into()),
                Err(_would_block) => {}
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_name() {
        // Accepted names, each paired with the unit number it should map to.
        for (name, unit) in [("utun", 0), ("utun0", 1), ("utun42", 43)] {
            assert_eq!(parse_name(name).ok(), Some(unit));
        }

        // Leading zeros, trailing junk, missing digits and a wrong prefix are rejected.
        for name in ["utun04", "utun007", "utun42foo", "utunfoo", "futun"] {
            assert!(parse_name(name).is_err());
        }
    }
}