netsock 0.7.0

Cross-platform library for network socket information
Documentation
use crate::error::*;
use crate::socket::{ProtocolSocketInfo, SocketInfo, TcpSocketInfo, UdpSocketInfo};
use crate::state::TcpState;
use log::warn;

use netlink_packet_core::{
    NLM_F_DUMP, NLM_F_REQUEST, NetlinkHeader, NetlinkMessage, NetlinkPayload,
};
use netlink_packet_sock_diag::inet::InetResponse;
use netlink_packet_sock_diag::{
    SockDiagMessage,
    constants::*,
    inet::{ExtensionFlags, InetRequest, SocketId, StateFlags},
};
use netlink_sys::{Socket, SocketAddr, protocols::NETLINK_SOCK_DIAG};

/// Size, in bytes, of the buffer every netlink datagram is received into (8 KiB).
const SOCKET_BUFFER_SIZE: usize = 8 * 1024;

/// Lazily walks the kernel's reply to a single `NETLINK_SOCK_DIAG` dump request.
///
/// The same receive buffer is reused for every datagram read from the socket.
pub struct NetlinkIterator {
    /// Netlink socket the dump request was sent on and replies are read from.
    socket: Socket,
    /// Scratch space holding the most recently received datagram.
    recv_buf: [u8; SOCKET_BUFFER_SIZE],
    /// Number of valid bytes currently held in `recv_buf`.
    size: usize,
    /// Position of the next unparsed message inside `recv_buf`.
    offset: usize,
    /// Transport protocol that was requested (e.g. `IPPROTO_TCP`).
    protocol: u8,
    /// Set once the dump has finished or failed, so later reads short-circuit.
    is_done: bool,
}

impl NetlinkIterator {
    /// Opens a `NETLINK_SOCK_DIAG` socket and sends a dump request for every
    /// socket of the given address `family` and transport `protocol`.
    ///
    /// The reply is not read here; it is consumed lazily through the
    /// [`Iterator`] implementation.
    ///
    /// # Errors
    ///
    /// Returns an error if the netlink socket cannot be created, bound or
    /// connected, or if sending the request fails.
    pub fn new(family: u8, protocol: u8) -> Result<Self, Error> {
        let mut socket = Socket::new(NETLINK_SOCK_DIAG)?;
        // Let the kernel assign our port id; it is never needed afterwards.
        socket.bind_auto()?;
        // Port id 0 addresses the kernel (see netlink(7)).
        socket.connect(&SocketAddr::new(0, 0))?;

        let mut nl_hdr = NetlinkHeader::default();
        nl_hdr.flags = NLM_F_REQUEST | NLM_F_DUMP;
        let mut packet = NetlinkMessage::new(
            nl_hdr,
            SockDiagMessage::InetRequest(InetRequest {
                family,
                protocol,
                extensions: ExtensionFlags::empty(),
                // Request sockets in every state.
                states: StateFlags::all(),
                socket_id: SocketId::new_v4(),
            })
            .into(),
        );

        packet.finalize();

        let mut buf = vec![0; packet.buffer_len()];
        packet.serialize(&mut buf[..]);
        socket.send(&buf[..], 0)?;

        Ok(NetlinkIterator {
            protocol,
            socket,
            recv_buf: [0u8; SOCKET_BUFFER_SIZE],
            offset: 0,
            size: 0,
            is_done: false,
        })
    }

    /// Reads the next socket entry from the netlink dump.
    ///
    /// Returns `Ok(None)` once the kernel signals the end of the dump
    /// (`NetlinkPayload::Done`). From then on every call returns `Ok(None)`
    /// without touching the socket.
    ///
    /// Messages that cannot be deserialized, and payloads that are not inet
    /// socket entries, are logged and skipped rather than ending the dump
    /// early.
    ///
    /// # Errors
    ///
    /// - Propagates `recv` failures.
    /// - Returns `Error::NetLinkPacketError` when the kernel answers with an
    ///   error message, and marks the iterator as finished.
    /// - Forwards errors from `parse_diag_msg`.
    fn try_read_next_packet(&mut self) -> Result<Option<SocketInfo>, Error> {
        if self.is_done {
            return Ok(None);
        }

        loop {
            // The current datagram is exhausted: block until the next one arrives.
            if self.offset >= self.size {
                self.size = self.socket.recv(&mut &mut self.recv_buf[..], 0)?;
                self.offset = 0;
            }

            let bytes = &self.recv_buf[self.offset..self.size];

            let rx_packet: NetlinkMessage<SockDiagMessage> = match NetlinkMessage::deserialize(
                bytes,
            ) {
                Ok(rx_packet) => rx_packet,
                Err(e) => {
                    warn!(
                        "Failed to deserialize netlink message for protocol {} ({} bytes remaining): {e}",
                        self.protocol,
                        bytes.len()
                    );
                    // The rest of this datagram cannot be framed reliably, so drop it.
                    self.offset = self.size;
                    continue;
                }
            };

            // Messages in a datagram start on 4-byte boundaries (NLMSG_ALIGN), so the
            // next one begins after any padding, not directly at `header.length`.
            // Overshooting `size` is harmless: the check above then reads a new datagram.
            self.offset += Self::nlmsg_align(rx_packet.header.length as usize);

            match rx_packet.payload {
                NetlinkPayload::Noop => {}
                NetlinkPayload::InnerMessage(SockDiagMessage::InetResponse(response)) => {
                    return Ok(Some(parse_diag_msg(&response, self.protocol)?));
                }
                NetlinkPayload::Done(_) => {
                    self.is_done = true;
                    return Ok(None);
                }
                NetlinkPayload::Error(err) => {
                    self.is_done = true;
                    return Err(Error::NetLinkPacketError(err));
                }
                // Any other payload is not a socket entry. Skip it rather than returning
                // `None` (which would silently truncate the results without marking the
                // iterator done). The dump still ends on `Done` or `Error`.
                _ => {
                    warn!(
                        "Skipping unexpected netlink message for protocol {}",
                        self.protocol
                    );
                }
            }
        }
    }

    /// Rounds a netlink message length up to the 4-byte `NLMSG_ALIGNTO` boundary,
    /// giving the distance to the start of the next message in the datagram.
    fn nlmsg_align(len: usize) -> usize {
        const NLMSG_ALIGNTO: usize = 4;
        (len + NLMSG_ALIGNTO - 1) & !(NLMSG_ALIGNTO - 1)
    }
}

impl Iterator for NetlinkIterator {
    type Item = Result<SocketInfo, Error>;

    /// Yields the next socket from the dump, any error hit while reading it,
    /// or `None` when `try_read_next_packet` has nothing more to report.
    fn next(&mut self) -> Option<Self::Item> {
        match self.try_read_next_packet() {
            Ok(Some(info)) => Some(Ok(info)),
            Ok(None) => None,
            Err(err) => Some(Err(err)),
        }
    }
}

fn parse_diag_msg(diag_msg: &InetResponse, protocol: u8) -> Result<SocketInfo, Error> {
    let src_port = diag_msg.header.socket_id.source_port;
    let dst_port = diag_msg.header.socket_id.destination_port;
    let src_ip = diag_msg.header.socket_id.source_address;
    let dst_ip = diag_msg.header.socket_id.destination_address;

    let sock_info = match protocol {
        IPPROTO_TCP => SocketInfo {
            protocol_socket_info: ProtocolSocketInfo::Tcp(TcpSocketInfo {
                local_addr: src_ip,
                local_port: src_port,
                remote_addr: dst_ip,
                remote_port: dst_port,
                state: TcpState::from(diag_msg.header.state),
            }),
            processes: Vec::new(),
            inode: diag_msg.header.inode,
            uid: diag_msg.header.uid,
        },
        IPPROTO_UDP => SocketInfo {
            protocol_socket_info: ProtocolSocketInfo::Udp(UdpSocketInfo {
                local_addr: src_ip,
                local_port: src_port,
            }),
            processes: Vec::new(),
            inode: diag_msg.header.inode,
            uid: diag_msg.header.uid,
        },
        _ => return Err(Error::UnknownProtocol(protocol)),
    };

    Ok(sock_info)
}