netlink-socket 0.0.1

netlink sockets, with optional integration with mio and tokio
Documentation
//! Netlink socket related functions
use libc;
use std::io::{Error, Result};
use std::mem;
use std::os::unix::io::{AsRawFd, RawFd};

use super::Protocol;

/// A netlink socket, represented by its raw file descriptor.
///
/// The descriptor is closed when the value is dropped.
pub struct Socket(RawFd);

impl AsRawFd for Socket {
    fn as_raw_fd(&self) -> RawFd {
        self.0
    }
}

impl Drop for Socket {
    /// Closes the wrapped file descriptor. The return value of `close` is ignored: there is no
    /// way to report an error from `drop`.
    fn drop(&mut self) {
        let fd = self.0;
        unsafe {
            libc::close(fd);
        }
    }
}

/// A netlink socket address: a thin wrapper around `libc::sockaddr_nl`.
#[derive(Copy, Clone)]
pub struct SocketAddr(libc::sockaddr_nl);

impl SocketAddr {
    /// Builds a netlink address.
    ///
    /// `port_number` is stored in the `nl_pid` field and `multicast_groups` in the `nl_groups`
    /// field of the underlying `sockaddr_nl`; every other byte of the structure is zero, except
    /// for the address family which is set to netlink.
    pub fn new(port_number: u32, multicast_groups: u32) -> Self {
        let mut raw = unsafe { mem::zeroed::<libc::sockaddr_nl>() };
        raw.nl_groups = multicast_groups;
        raw.nl_pid = port_number;
        raw.nl_family = libc::PF_NETLINK as libc::sa_family_t;
        SocketAddr(raw)
    }

    /// Returns the port number (`nl_pid`) of this address.
    pub fn port_number(&self) -> u32 {
        self.0.nl_pid
    }

    /// Returns the multicast groups (`nl_groups`) of this address.
    pub fn multicast_groups(&self) -> u32 {
        self.0.nl_groups
    }

    /// Size of a `sockaddr_nl`, in the type the socket functions expect for address lengths.
    fn raw_len() -> libc::socklen_t {
        mem::size_of::<libc::sockaddr_nl>() as libc::socklen_t
    }

    /// Returns a read-only generic `sockaddr` pointer to the wrapped address, with its length.
    ///
    /// The C socket functions take a `*const sockaddr` whatever the address family is, so the
    /// pointer to our `sockaddr_nl` has to be cast to that generic type. The pointer is only
    /// valid for as long as `self` is.
    fn as_raw(&self) -> (*const libc::sockaddr, libc::socklen_t) {
        // First a typed raw pointer to the inner `sockaddr_nl`...
        let nl_ptr: *const libc::sockaddr_nl = &self.0;
        // ...then the cast to the generic `sockaddr` the C API wants.
        (nl_ptr as *const libc::sockaddr, Self::raw_len())
    }

    /// Mutable counterpart of `as_raw`, for calls that fill the address in.
    fn as_raw_mut(&mut self) -> (*mut libc::sockaddr, libc::socklen_t) {
        let nl_ptr: *mut libc::sockaddr_nl = &mut self.0;
        (nl_ptr as *mut libc::sockaddr, Self::raw_len())
    }
}

impl Socket {
    pub fn new(protocol: Protocol) -> Result<Self> {
        let res =
            unsafe { libc::socket(libc::PF_NETLINK, libc::SOCK_DGRAM, protocol as libc::c_int) };
        if res < 0 {
            return Err(Error::last_os_error());
        }
        Ok(Socket(res))
    }

    pub fn bind(&mut self, addr: &SocketAddr) -> Result<()> {
        let (addr_ptr, addr_len) = addr.as_raw();
        let res = unsafe { libc::bind(self.0, addr_ptr, addr_len) };
        if res < 0 {
            return Err(Error::last_os_error());
        }
        Ok(())
    }

    pub fn bind_auto(&mut self) -> Result<SocketAddr> {
        let mut addr = SocketAddr::new(0, 0);
        self.bind(&addr)?;
        self.get_address(&mut addr)?;
        Ok(addr)
    }

    pub fn get_address(&self, addr: &mut SocketAddr) -> Result<()> {
        let (addr_ptr, mut addr_len) = addr.as_raw_mut();
        let addr_len_copy = addr_len;
        let addr_len_ptr = &mut addr_len as *mut libc::socklen_t;
        let res = unsafe { libc::getsockname(self.0, addr_ptr, addr_len_ptr) };
        if res < 0 {
            return Err(Error::last_os_error());
        }
        assert_eq!(addr_len, addr_len_copy);
        Ok(())
    }

    pub fn set_non_blocking(&self, non_blocking: bool) -> Result<()> {
        let mut non_blocking = non_blocking as libc::c_int;
        let res = unsafe { libc::ioctl(self.0, libc::FIONBIO, &mut non_blocking) };
        if res < 0 {
            return Err(Error::last_os_error());
        }
        Ok(())
    }

    pub fn connect(&self, remote_addr: &SocketAddr) -> Result<()> {
        // Event though for SOCK_DGRAM sockets there's no IO, since our socket is non-blocking,
        // connect() might return EINPROGRESS. In theory, the right way to treat EINPROGRESS would
        // be to ignore the error, and let the user poll the socket to check when it becomes
        // writable, indicating that the connection succeeded. The code already exists in mio for
        // TcpStream:
        //
        // > pub fn connect(stream: net::TcpStream, addr: &SocketAddr) -> io::Result<TcpStream> {
        // >     set_non_block(stream.as_raw_fd())?;
        // >     match stream.connect(addr) {
        // >         Ok(..) => {}
        // >         Err(ref e) if e.raw_os_error() == Some(libc::EINPROGRESS) => {}
        // >         Err(e) => return Err(e),
        // >     }
        // >     Ok(TcpStream {  inner: stream })
        // > }
        //
        // The polling to wait for the connection is available in the tokio-tcp crate. See:
        // https://github.com/tokio-rs/tokio/blob/363b207f2b6c25857c70d76b303356db87212f59/tokio-tcp/src/stream.rs#L706
        //
        // In practice, since the connection does not require any IO for SOCK_DGRAM sockets, it
        // almost never returns EINPROGRESS and so for now, we just return whatever libc::connect
        // returns. If it returns EINPROGRESS, the caller will have to handle the error themself
        //
        // Refs:
        //
        // - https://stackoverflow.com/a/14046386/1836144
        // - https://lists.isc.org/pipermail/bind-users/2009-August/077527.html
        let (addr, addr_len) = remote_addr.as_raw();
        let res = unsafe { libc::connect(self.0, addr, addr_len) };
        if res < 0 {
            return Err(Error::last_os_error());
        }
        Ok(())
    }

    // Most of the comments in this method come from a discussion on rust users forum.
    // [thread]: https://users.rust-lang.org/t/help-understanding-libc-call/17308/9
    pub fn recv_from(&self, buf: &mut [u8], flags: libc::c_int) -> Result<(usize, SocketAddr)> {
        // Create an empty storage for the address. Note that Rust standard library create a
        // sockaddr_storage so that it works for any address family, but here, we already know that
        // we'll have a Netlink address, so we can create the appropriate storage.
        let mut addr = unsafe { mem::zeroed::<libc::sockaddr_nl>() };

        // recvfrom takes a *sockaddr as parameter so that it can accept any kind of address
        // storage, so we need to create such a pointer for the sockaddr_nl we just initialized.
        //
        //                     Create a raw pointer to        Cast our raw pointer to a
        //                     our storage. We cannot         generic pointer to *sockaddr
        //                     pass it to recvfrom yet.       that recvfrom can use
        //                                 ^                              ^
        //                                 |                              |
        //                  +--------------+---------------+    +---------+--------+
        //                 /                                \  /                    \
        let addr_ptr = &mut addr as *mut libc::sockaddr_nl as *mut libc::sockaddr;

        // Why do we need to pass the address length? We're passing a generic *sockaddr to
        // recvfrom. Somehow recvfrom needs to make sure that the address of the received packet
        // would fit into the actual type that is behind *sockaddr: it could be a sockaddr_nl but
        // also a sockaddr_in, a sockaddr_in6, or even the generic sockaddr_storage that can store
        // any address.
        let mut addrlen = mem::size_of_val(&addr);
        // recvfrom does not take the address length by value (see [thread]), so we need to create
        // a pointer to it.
        let addrlen_ptr = &mut addrlen as *mut usize as *mut libc::socklen_t;

        //                      Cast the *mut u8 into *mut void.
        //               This is equivalent to casting a *char into *void
        //                                 See [thread]
        //                                       ^
        //           Create a *mut u8            |
        //                   ^                   |
        //                   |                   |
        //             +-----+-----+    +--------+-------+
        //            /             \  /                  \
        let buf_ptr = buf.as_mut_ptr() as *mut libc::c_void;
        let buf_len = buf.len() as libc::size_t;

        let res = unsafe { libc::recvfrom(self.0, buf_ptr, buf_len, flags, addr_ptr, addrlen_ptr) };
        if res < 0 {
            return Err(Error::last_os_error());
        }
        Ok((res as usize, SocketAddr(addr)))
    }

    pub fn recv(&self, buf: &mut [u8], flags: libc::c_int) -> Result<usize> {
        let buf_ptr = buf.as_mut_ptr() as *mut libc::c_void;
        let buf_len = buf.len() as libc::size_t;

        let res = unsafe { libc::recv(self.0, buf_ptr, buf_len, flags) };
        if res < 0 {
            return Err(Error::last_os_error());
        }
        Ok(res as usize)
    }

    pub fn send_to(&self, buf: &[u8], addr: &SocketAddr, flags: libc::c_int) -> Result<usize> {
        let (addr_ptr, addr_len) = addr.as_raw();
        let buf_ptr = buf.as_ptr() as *const libc::c_void;
        let buf_len = buf.len() as libc::size_t;

        let res = unsafe { libc::sendto(self.0, buf_ptr, buf_len, flags, addr_ptr, addr_len) };
        if res < 0 {
            return Err(Error::last_os_error());
        }
        Ok(res as usize)
    }

    pub fn send(&self, buf: &[u8], flags: libc::c_int) -> Result<usize> {
        let buf_ptr = buf.as_ptr() as *const libc::c_void;
        let buf_len = buf.len() as libc::size_t;

        let res = unsafe { libc::send(self.0, buf_ptr, buf_len, flags) };
        if res < 0 {
            return Err(Error::last_os_error());
        }
        Ok(res as usize)
    }

    pub fn set_pktinfo(&mut self, set: bool) -> Result<()> {
        setsockopt(self.0, libc::SOL_NETLINK, libc::NETLINK_PKTINFO, set)
    }

    pub fn get_pktinfo(&self) -> Result<bool> {
        getsockopt(self.0, libc::SOL_NETLINK, libc::NETLINK_PKTINFO)
    }

    pub fn add_membership(&mut self, group: u32) -> Result<()> {
        setsockopt(
            self.0,
            libc::SOL_NETLINK,
            libc::NETLINK_ADD_MEMBERSHIP,
            group,
        )
    }

    pub fn drop_membership(&mut self, group: u32) -> Result<()> {
        setsockopt(
            self.0,
            libc::SOL_NETLINK,
            libc::NETLINK_DROP_MEMBERSHIP,
            group,
        )
    }

    pub fn list_membership(&self) -> Vec<u32> {
        unimplemented!();
        // getsockopt won't be enough here, because we may need to perform 2 calls, and because the
        // length of the list returned by libc::getsockopt is returned by mutating the length
        // argument, which our implementation of getsockopt forbids.
    }

    pub fn set_broadcast_error(&mut self, set: bool) -> Result<()> {
        setsockopt(
            self.0,
            libc::SOL_NETLINK,
            libc::NETLINK_BROADCAST_ERROR,
            set,
        )
    }

    pub fn get_broadcast_error(&self) -> Result<bool> {
        getsockopt(self.0, libc::SOL_NETLINK, libc::NETLINK_BROADCAST_ERROR)
    }

    pub fn set_no_enobufs(&mut self, set: bool) -> Result<()> {
        setsockopt(self.0, libc::SOL_NETLINK, libc::NETLINK_NO_ENOBUFS, set)
    }

    pub fn get_no_enobufs(&self) -> Result<bool> {
        getsockopt(self.0, libc::SOL_NETLINK, libc::NETLINK_NO_ENOBUFS)
    }

    pub fn set_listen_all_namespaces(&mut self, set: bool) -> Result<()> {
        setsockopt(
            self.0,
            libc::SOL_NETLINK,
            libc::NETLINK_LISTEN_ALL_NSID,
            set,
        )
    }

    pub fn get_listen_all_namespaces(&self) -> Result<bool> {
        getsockopt(self.0, libc::SOL_NETLINK, libc::NETLINK_LISTEN_ALL_NSID)
    }

    pub fn set_cap_ack(&mut self, set: bool) -> Result<()> {
        setsockopt(self.0, libc::SOL_NETLINK, libc::NETLINK_CAP_ACK, set)
    }

    pub fn get_cap_ack(&self) -> Result<bool> {
        getsockopt(self.0, libc::SOL_NETLINK, libc::NETLINK_CAP_ACK)
    }
}

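// FIXME: setsockopt and getsockopt don't work... We always get EINVAL, which the manpage
// describes as:
//
//  > The specified option is invalid at the specified socket level or the socket has been shut
//  > down.
//
// NOTE(review): netlink(7) states that, unless otherwise noted, `optval` points to an `int`, so
// a payload of any other size (such as a 1-byte `bool`) is a likely cause -- to be confirmed.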

// adapted from rust standard library
/// Reads a fixed-size socket option of type `T`.
///
/// `opt` is forwarded as the level argument of `libc::getsockopt` and `val` as the option
/// name.
///
/// # Panics
///
/// Panics if the length written back by `libc::getsockopt` differs from the size of `T`.
/// Options whose length is set by getsockopt (NETLINK_LIST_MEMBERSHIP for instance, whose
/// value is a list of arbitrary length) are therefore not supported and have to be dealt with
/// individually.
fn getsockopt<T: Copy>(fd: RawFd, opt: libc::c_int, val: libc::c_int) -> Result<T> {
    // Zero-initialised storage that getsockopt will fill with the option value.
    let mut value: T = unsafe { mem::zeroed() };

    // The length is passed by pointer: it tells getsockopt how big the storage is on the way
    // in, and carries the size of the option value on the way out.
    let mut value_len = mem::size_of::<T>() as libc::socklen_t;

    let status = unsafe {
        libc::getsockopt(
            fd,
            opt,
            val,
            &mut value as *mut T as *mut libc::c_void,
            &mut value_len as *mut libc::socklen_t,
        )
    };
    if status < 0 {
        return Err(Error::last_os_error());
    }

    assert_eq!(value_len as usize, mem::size_of::<T>());

    Ok(value)
}

// adapted from rust standard library
/// Writes `payload` as the value of a socket option.
///
/// `opt` is forwarded as the level argument of `libc::setsockopt` and `val` as the option
/// name. The option length handed over is the size of `T`.
fn setsockopt<T>(fd: RawFd, opt: libc::c_int, val: libc::c_int, payload: T) -> Result<()> {
    let payload_ptr = &payload as *const T as *const libc::c_void;
    let payload_len = mem::size_of::<T>() as libc::socklen_t;

    let status = unsafe { libc::setsockopt(fd, opt, val, payload_ptr, payload_len) };
    if status < 0 {
        Err(Error::last_os_error())
    } else {
        Ok(())
    }
}

#[cfg(test)]
mod test {
    use super::*;

    /// Opens the route-protocol socket the tests below work on.
    fn route_socket() -> Socket {
        Socket::new(Protocol::Route).unwrap()
    }

    /// A socket can be created.
    #[test]
    fn new() {
        route_socket();
    }

    /// A socket can be connected to the address with port 0 and no group.
    #[test]
    fn connect() {
        let socket = route_socket();
        let remote = SocketAddr::new(0, 0);
        socket.connect(&remote).unwrap();
    }

    /// A socket can be bound to an explicit port number.
    #[test]
    fn bind() {
        let mut socket = route_socket();
        let local = SocketAddr::new(4321, 0);
        socket.bind(&local).unwrap();
    }

    /// After an automatic bind, the returned address carries a port number.
    #[test]
    fn bind_auto() {
        let mut socket = route_socket();
        let addr = socket.bind_auto().unwrap();
        // make sure that the address we got from the kernel is there
        assert!(addr.port_number() != 0);
    }

    /// Non-blocking mode can be switched on and back off.
    #[test]
    fn set_non_blocking() {
        let socket = route_socket();
        for &mode in [true, false].iter() {
            socket.set_non_blocking(mode).unwrap();
        }
    }

    // FIXME!
    // #[test]
    // fn options() {
    //     let mut sock = Socket::new(Protocol::Route).unwrap();

    //     sock.set_no_enobufs(true).unwrap();
    //     assert!(sock.get_no_enobufs().unwrap());
    //     sock.set_no_enobufs(false).unwrap();
    //     assert!(!sock.get_no_enobufs().unwrap());

    //     sock.set_broadcast_error(true).unwrap();
    //     assert!(sock.get_broadcast_error().unwrap());
    //     sock.set_broadcast_error(false).unwrap();
    //     assert!(!sock.get_broadcast_error().unwrap());

    //     sock.set_cap_ack(true).unwrap();
    //     assert!(sock.get_cap_ack().unwrap());
    //     sock.set_cap_ack(false).unwrap();
    //     assert!(!sock.get_cap_ack().unwrap());

    //     sock.set_listen_all_namespaces(true).unwrap();
    //     assert!(sock.get_listen_all_namespaces().unwrap());
    //     sock.set_listen_all_namespaces(false).unwrap();
    //     assert!(!sock.get_listen_all_namespaces().unwrap());
    // }

    /// The accessors and the raw pointers of a `SocketAddr` all see the same `sockaddr_nl`.
    #[test]
    fn address() {
        let mut addr = SocketAddr::new(42, 1234);
        assert_eq!(addr.port_number(), 42);
        assert_eq!(addr.multicast_groups(), 1234);

        // Read the fields back through the read-only raw pointer.
        let (const_ptr, _) = addr.as_raw();
        let snapshot = unsafe { *(const_ptr as *const libc::sockaddr_nl) };
        assert_eq!(snapshot.nl_pid, 42);
        assert_eq!(snapshot.nl_groups, 1234);

        // Write through the mutable raw pointer; the accessors must observe the new values.
        let (mut_ptr, _) = addr.as_raw_mut();
        let nl_ptr = mut_ptr as *mut libc::sockaddr_nl;
        assert!(!nl_ptr.is_null());
        unsafe {
            (*nl_ptr).nl_pid = 24;
            (*nl_ptr).nl_groups = 4321;
        }
        assert_eq!(addr.port_number(), 24);
        assert_eq!(addr.multicast_groups(), 4321);
    }
}