routex 0.1.6

Cross-platform route implementation.
Documentation
#![allow(clippy::macro_metavars_in_unsafe)]

use std::{ffi::CString, io, net::IpAddr};

#[cfg(target_os = "linux")]
mod linux;
#[cfg(target_os = "macos")]
mod macos;
#[cfg(target_os = "windows")]
mod windows;

#[cfg(target_os = "linux")]
pub use linux::RouteSock;
#[cfg(target_os = "macos")]
pub use macos::RouteSock;
#[cfg(target_os = "windows")]
pub use windows::RouteSock;

/// Calls `libc::$fn` with the given arguments and maps a negative return value to
/// [`std::io::Error::last_os_error`]; any non-negative value is returned as `Ok`.
///
/// The arguments are evaluated inside the same `unsafe` block as the call, so every expression
/// passed in runs in an unsafe context.
///
/// A trailing comma after the last argument is accepted, so invocations that rustfmt wraps
/// across several lines (adding a trailing comma) keep compiling.
#[macro_export]
// Deliberate: the wrapped call must be `unsafe`, and the arguments are expanded into that block.
#[allow(clippy::macro_metavars_in_unsafe)]
macro_rules! syscall {
    ($fn:ident ( $($arg:expr),* $(,)? )) => {{
        // Silences the warning when the expansion already sits inside an `unsafe` block or the
        // wrapped function is not itself `unsafe`.
        #[allow(unused_unsafe)]
        let res = unsafe { libc::$fn($($arg),*) };
        if res < 0 {
            Err(std::io::Error::last_os_error())
        } else {
            Ok(res)
        }
    }};
}

/// A single routing-table entry: `destination/prefix`, optionally with a next-hop gateway
/// and/or an interface index.
///
/// Build one with [`Route::new`] and the builder methods (`gateway`, `ifindex`, ...).
/// Routes are plain values: they can be compared, hashed and deduplicated.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Route {
    /// Destination network address.
    pub destination: IpAddr,
    /// Prefix length in bits; at most 32 for IPv4 and 128 for IPv6 (checked by `validate`).
    pub prefix: u8,
    /// Next-hop gateway, if any; must be the same address family as `destination`
    /// (checked by `validate`).
    pub gateway: Option<IpAddr>,
    /// Interface index, if any (e.g. as resolved by `if_nametoindex`).
    pub ifindex: Option<u32>,
}

impl Default for Route {
    fn default() -> Self {
        use std::net::Ipv4Addr;

        Route {
            destination: IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)),
            prefix: 0,
            gateway: None,
            ifindex: None,
        }
    }
}

impl Route {
    /// Creates a route to `destination/prefix` with no gateway and no interface index.
    ///
    /// The prefix is not range-checked here; see `validate`.
    pub fn new(destination: IpAddr, prefix: u8) -> Route {
        Route {
            destination,
            prefix,
            gateway: None,
            ifindex: None,
        }
    }

    /// Returns the netmask for `self.prefix` in the destination's address family.
    ///
    /// A prefix of 0 yields the all-zero mask. A prefix longer than the family's width
    /// (32 for IPv4, 128 for IPv6) is clamped to the full-length mask. Without the clamp, the
    /// shift amount would underflow: a debug-build panic, or a silent /0 mask in release builds.
    #[cfg(target_os = "macos")]
    pub(crate) fn mask(&self) -> IpAddr {
        use std::net::{Ipv4Addr, Ipv6Addr};

        let prefix = u32::from(self.prefix);
        match self.destination {
            // `checked_shl` returns `None` for a full-width shift (prefix 0): the all-zero mask.
            IpAddr::V4(_) => IpAddr::V4(Ipv4Addr::from(
                u32::MAX.checked_shl(32u32.saturating_sub(prefix)).unwrap_or(0),
            )),
            IpAddr::V6(_) => IpAddr::V6(Ipv6Addr::from(
                u128::MAX.checked_shl(128u32.saturating_sub(prefix)).unwrap_or(0),
            )),
        }
    }

    /// Sets `self.prefix` from `netmask` by counting its leading one bits.
    ///
    /// Only the leading run of ones is counted, so a non-contiguous mask such as
    /// `255.0.255.0` yields a prefix of 8.
    #[cfg(target_os = "macos")]
    pub(crate) fn cidr(&mut self, netmask: IpAddr) {
        // `leading_ones` is at most 128, so the narrowing cast cannot truncate.
        self.prefix = match netmask {
            IpAddr::V4(netmask) => u32::from(netmask).leading_ones() as u8,
            IpAddr::V6(netmask) => u128::from(netmask).leading_ones() as u8,
        }
    }

    /// Checks that the route is internally consistent.
    ///
    /// # Errors
    ///
    /// Returns [`io::ErrorKind::InvalidInput`] when the prefix is longer than the destination's
    /// address family allows (32 for IPv4, 128 for IPv6), or when the gateway's address family
    /// differs from the destination's.
    pub(crate) fn validate(&self) -> io::Result<()> {
        let max_prefix = match self.destination {
            IpAddr::V4(_) => 32,
            IpAddr::V6(_) => 128,
        };

        if self.prefix > max_prefix {
            return Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                format!("invalid prefix {} for {}", self.prefix, self.destination),
            ));
        }

        if let Some(gateway) = self.gateway {
            if gateway.is_ipv4() != self.destination.is_ipv4() {
                return Err(io::Error::new(
                    io::ErrorKind::InvalidInput,
                    format!(
                        "gateway {gateway} does not match destination {}",
                        self.destination
                    ),
                ));
            }
        }

        Ok(())
    }

    /// Linux-only alias of [`Route::gateway`].
    #[cfg(target_os = "linux")]
    pub fn via(self, gateway: IpAddr) -> Route {
        self.gateway(gateway)
    }

    /// Sets the next-hop gateway.
    pub fn gateway(mut self, gateway: IpAddr) -> Route {
        self.gateway = Some(gateway);
        self
    }

    /// Sets the interface index directly.
    pub fn ifindex(mut self, ifindex: u32) -> Route {
        self.ifindex = Some(ifindex);
        self
    }

    /// Resolves `interface` with [`if_nametoindex`] and stores the resulting index.
    ///
    /// Available on every platform. If the name cannot be resolved, `ifindex` is reset to
    /// `None`, discarding any index set earlier.
    pub fn interface(mut self, interface: &str) -> Route {
        self.ifindex = if_nametoindex(interface);
        self
    }

    /// Linux-only alias of [`Route::interface`].
    #[cfg(target_os = "linux")]
    pub fn dev(self, interface: &str) -> Route {
        self.interface(interface)
    }
}

/// Kind of routing-table change returned by `RouteAction::monitor`.
///
/// Built from a raw `u8` code via `From<u8>`: codes 1, 2 and 3 map to `ADD`, `DELETE` and
/// `CHANGE`, and every other code is preserved verbatim in `OTHER`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum RouteChange {
    /// Raw code 1.
    ADD,
    /// Raw code 2.
    DELETE,
    /// Raw code 3.
    CHANGE,
    /// Any other raw code, kept as-is.
    OTHER(u8),
}

impl From<u8> for RouteChange {
    fn from(value: u8) -> Self {
        match value {
            1 => RouteChange::ADD,
            2 => RouteChange::DELETE,
            3 => RouteChange::CHANGE,
            _ => RouteChange::OTHER(value),
        }
    }
}

/// Routing-table operations on a [`Route`].
///
/// NOTE(review): presumably implemented by each platform's `RouteSock` (re-exported above from
/// the `linux`/`macos`/`windows` modules) — confirm in those modules.
pub trait RouteAction {
    /// Adds `route`; the exact OS call is platform-specific.
    fn add(&mut self, route: &Route) -> io::Result<()>;
    /// Deletes `route`; the exact OS call is platform-specific.
    fn delete(&mut self, route: &Route) -> io::Result<()>;
    /// Looks up `route` and returns a [`Route`] describing the result.
    /// NOTE(review): matching rules (exact vs. longest-prefix) are platform-specific — confirm.
    fn get(&mut self, route: &Route) -> io::Result<Route>;
    /// Returns one change notification as its kind plus the affected [`Route`].
    /// `buf` is caller-provided storage, presumably used to receive the raw message — confirm
    /// the required size and whether the call blocks in each implementation.
    fn monitor(&mut self, buf: &mut [u8]) -> io::Result<(RouteChange, Route)>;
}

/// Resolves a network interface name to its OS interface index.
///
/// Returns `None` when `name` contains an interior NUL byte (it cannot be passed to the OS
/// as a C string) or when the OS call returns 0, meaning no such interface.
pub fn if_nametoindex(name: &str) -> Option<u32> {
    let c_name = match CString::new(name) {
        Ok(c_name) => c_name,
        Err(_) => return None,
    };

    // SAFETY: `c_name` is a valid NUL-terminated string that outlives the call.
    #[cfg(target_os = "windows")]
    let index = unsafe {
        windows_sys::Win32::NetworkManagement::IpHelper::if_nametoindex(c_name.as_ptr().cast())
    };

    // SAFETY: `c_name` is a valid NUL-terminated string that outlives the call.
    #[cfg(not(target_os = "windows"))]
    let index = unsafe { libc::if_nametoindex(c_name.as_ptr()) };

    if index == 0 {
        None
    } else {
        Some(index)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::net::{Ipv4Addr, Ipv6Addr};

    #[test]
    fn validate_rejects_prefix_outside_address_family() {
        let route = Route::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 33);

        assert_eq!(
            route.validate().unwrap_err().kind(),
            io::ErrorKind::InvalidInput
        );
    }

    #[test]
    fn validate_rejects_ipv6_prefix_past_128() {
        let route = Route::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 129);

        assert_eq!(
            route.validate().unwrap_err().kind(),
            io::ErrorKind::InvalidInput
        );
    }

    #[test]
    fn validate_accepts_full_length_prefixes() {
        // The bound is inclusive: /32 and /128 are host routes and must be accepted.
        assert!(Route::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 32)
            .validate()
            .is_ok());
        assert!(Route::new(IpAddr::V6(Ipv6Addr::LOCALHOST), 128)
            .validate()
            .is_ok());
    }

    #[test]
    fn validate_rejects_mismatched_gateway_family() {
        let v4_via_v6 = Route::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0)
            .gateway(IpAddr::V6(Ipv6Addr::LOCALHOST));
        let v6_via_v4 = Route::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0)
            .gateway(IpAddr::V4(Ipv4Addr::LOCALHOST));

        for route in [v4_via_v6, v6_via_v4] {
            assert_eq!(
                route.validate().unwrap_err().kind(),
                io::ErrorKind::InvalidInput
            );
        }
    }

    #[test]
    fn validate_accepts_matching_gateway_family() {
        let route = Route::new(IpAddr::V4(Ipv4Addr::new(10, 0, 0, 0)), 8)
            .gateway(IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)));

        assert!(route.validate().is_ok());
    }

    #[test]
    fn default_route_is_ipv4_catch_all() {
        let route = Route::default();

        assert_eq!(route.destination, IpAddr::V4(Ipv4Addr::UNSPECIFIED));
        assert_eq!(route.prefix, 0);
        assert!(route.gateway.is_none());
        assert!(route.ifindex.is_none());
        assert!(route.validate().is_ok());
    }
}