rsln 0.1.1

Netlink library implemented in Rust that provides the netlink protocol based kernel interfaces
Documentation
use anyhow::{bail, Result};

use crate::{
    core::{message::Message, socket::Socket},
    handle::rule::RuleHandle,
};

use super::{
    addr::AddrHandle, generic::GenericHandle, link::LinkHandle, neigh::NeighHandle,
    routing::RouteHandle, sock_diag::SockDiagHandle,
};

const PID_KERNEL: u32 = 0;

const NLMSG_DONE: u16 = 3;
const NLMSG_ERROR: u16 = 2;

#[derive(Clone)]
pub struct SocketHandle {
    pub socket: Socket,
    pub seq: u32,
}

impl SocketHandle {
    pub fn new(proto: i32) -> Self {
        Self {
            socket: Socket::new(proto, 0, 0).unwrap(),
            seq: 0,
        }
    }

    pub fn next_seq(&mut self) -> u32 {
        self.seq += 1;
        self.seq
    }

    pub fn handle_link(&mut self) -> LinkHandle<'_> {
        LinkHandle::from(self)
    }

    pub fn handle_addr(&mut self) -> AddrHandle<'_> {
        AddrHandle::from(self)
    }

    pub fn handle_route(&mut self) -> RouteHandle<'_> {
        RouteHandle::from(self)
    }

    pub fn handle_rule(&mut self) -> RuleHandle<'_> {
        RuleHandle::from(self)
    }

    pub fn handle_neigh(&mut self) -> NeighHandle<'_> {
        NeighHandle::from(self)
    }

    pub fn handle_generic(&mut self) -> GenericHandle<'_> {
        GenericHandle::from(self)
    }

    pub fn handle_sock_diag(&mut self) -> SockDiagHandle<'_> {
        SockDiagHandle::from(self)
    }

    pub fn request(&mut self, msg: &mut Message, res_type: u16) -> Result<Vec<Vec<u8>>> {
        let next_seq = self.next_seq();
        msg.header.nlmsg_seq = next_seq;

        self.socket.send(&msg.serialize()?)?;

        let pid = self.socket.pid()?;
        let mut res: Vec<Vec<u8>> = Vec::new();

        'done: loop {
            let (msgs, from) = self.socket.recv()?;

            if from.nl_pid != PID_KERNEL {
                bail!(
                    "wrong sender pid: {}, expected: {}",
                    from.nl_pid,
                    PID_KERNEL
                );
            }

            for mut m in msgs {
                if m.verify_header(next_seq, pid).is_err() {
                    continue;
                }

                match m.header.nlmsg_type {
                    NLMSG_DONE | NLMSG_ERROR => {
                        let payload = m.payload.as_ref().unwrap();
                        let err_no = i32::from_ne_bytes(payload[0..4].try_into()?);

                        if err_no == 0 {
                            break 'done;
                        }

                        let err_msg = std::io::Error::from_raw_os_error(-err_no);
                        bail!("{} ({}): {:?}", err_msg, -err_no, &payload[4..]);
                    }
                    t if res_type != 0 && t != res_type => {
                        continue;
                    }
                    _ => {
                        res.push(m.payload.take().unwrap());
                    }
                }

                if m.check_last_message() {
                    break 'done;
                }
            }
        }

        Ok(res)
    }
}