netlink_request/
lib.rs

1#[cfg(target_os = "linux")]
2mod linux {
3    use netlink_packet_core::{
4        NetlinkDeserializable, NetlinkMessage, NetlinkPayload, NetlinkSerializable,
5        NETLINK_HEADER_LEN, NLM_F_ACK, NLM_F_CREATE, NLM_F_EXCL, NLM_F_REQUEST,
6    };
7    use netlink_packet_generic::{
8        constants::GENL_HDRLEN,
9        ctrl::{nlas::GenlCtrlAttrs, GenlCtrl, GenlCtrlCmd},
10        GenlFamily, GenlHeader, GenlMessage,
11    };
12    use netlink_packet_route::RouteNetlinkMessage;
13    use netlink_packet_utils::{Emitable, ParseableParametrized};
14    use netlink_sys::{constants::NETLINK_GENERIC, protocols::NETLINK_ROUTE, Socket};
15    use nix::unistd::{sysconf, SysconfVar};
16    use once_cell::sync::OnceCell;
17    use std::{fmt::Debug, io};
18
/// Finds the first element of `$nlas` that is the `$e::$v` variant and
/// evaluates to `Some(&payload)` for it, or to `None` when no element matches.
///
/// `$nlas` must be an expression offering `iter()` over `$e` values (for
/// example a `Vec<$e>`), and `$v` must be a tuple variant with one field.
macro_rules! get_nla_value {
    ($nlas:expr, $e:ident, $v:ident) => {
        $nlas.iter().find_map(|nla| {
            if let $e::$v(inner) = nla {
                Some(inner)
            } else {
                None
            }
        })
    };
}
27
28    pub fn max_netlink_buffer_length() -> usize {
29        static LENGTH: OnceCell<usize> = OnceCell::new();
30        *LENGTH.get_or_init(|| {
31            // https://www.kernel.org/doc/html/v6.2/userspace-api/netlink/intro.html#buffer-sizing
32            // "Netlink expects that the user buffer will be at least 8kB or a page
33            // size of the CPU architecture, whichever is bigger."
34            const MIN_NELINK_BUFFER_LENGTH: usize = 8 * 1024;
35            // Note that sysconf only returns Err / Ok(None) when the parameter is
36            // invalid, unsupported on the current OS, or an unset limit. PAGE_SIZE
37            // is *required* to be supported and is not considered a limit, so this
38            // should never fail unless something has gone massively wrong.
39            let page_size = sysconf(SysconfVar::PAGE_SIZE).unwrap().unwrap() as usize;
40            std::cmp::max(MIN_NELINK_BUFFER_LENGTH, page_size)
41        })
42    }
43
44    pub fn max_genl_payload_length() -> usize {
45        max_netlink_buffer_length() - NETLINK_HEADER_LEN - GENL_HDRLEN
46    }
47
48    pub fn netlink_request_genl<F>(
49        mut message: GenlMessage<F>,
50        flags: Option<u16>,
51    ) -> Result<Vec<NetlinkMessage<GenlMessage<F>>>, io::Error>
52    where
53        F: GenlFamily + Clone + Debug + Eq + Emitable + ParseableParametrized<[u8], GenlHeader>,
54        GenlMessage<F>: Clone + Debug + Eq + NetlinkSerializable + NetlinkDeserializable,
55    {
56        if message.family_id() == 0 {
57            let genlmsg: GenlMessage<GenlCtrl> = GenlMessage::from_payload(GenlCtrl {
58                cmd: GenlCtrlCmd::GetFamily,
59                nlas: vec![GenlCtrlAttrs::FamilyName(F::family_name().to_string())],
60            });
61            let responses =
62                netlink_request_genl::<GenlCtrl>(genlmsg, Some(NLM_F_REQUEST | NLM_F_ACK))?;
63
64            match responses.first() {
65                Some(NetlinkMessage {
66                    payload:
67                        NetlinkPayload::InnerMessage(GenlMessage {
68                            payload: GenlCtrl { nlas, .. },
69                            ..
70                        }),
71                    ..
72                }) => {
73                    let family_id = get_nla_value!(nlas, GenlCtrlAttrs, FamilyId)
74                        .ok_or_else(|| io::ErrorKind::NotFound)?;
75                    message.set_resolved_family_id(*family_id);
76                },
77                _ => {
78                    return Err(io::Error::new(
79                        io::ErrorKind::InvalidData,
80                        "Unexpected netlink payload",
81                    ))
82                },
83            };
84        }
85        netlink_request(message, flags, NETLINK_GENERIC)
86    }
87
88    pub fn netlink_request_rtnl(
89        message: RouteNetlinkMessage,
90        flags: Option<u16>,
91    ) -> Result<Vec<NetlinkMessage<RouteNetlinkMessage>>, io::Error> {
92        netlink_request(message, flags, NETLINK_ROUTE)
93    }
94
95    pub fn netlink_request<I>(
96        message: I,
97        flags: Option<u16>,
98        socket: isize,
99    ) -> Result<Vec<NetlinkMessage<I>>, io::Error>
100    where
101        NetlinkPayload<I>: From<I>,
102        I: Clone + Debug + Eq + Emitable + NetlinkSerializable + NetlinkDeserializable,
103    {
104        let mut req = NetlinkMessage::from(message);
105
106        let max_buffer_len = max_netlink_buffer_length();
107        if req.buffer_len() > max_buffer_len {
108            return Err(io::Error::new(
109                io::ErrorKind::InvalidInput,
110                format!(
111                    "Serialized netlink packet ({} bytes) larger than maximum size {}: {:?}",
112                    req.buffer_len(),
113                    max_buffer_len,
114                    req
115                ),
116            ));
117        }
118
119        req.header.flags = flags.unwrap_or(NLM_F_REQUEST | NLM_F_ACK | NLM_F_EXCL | NLM_F_CREATE);
120        req.finalize();
121        let mut buf = vec![0; max_buffer_len];
122        req.serialize(&mut buf);
123        let len = req.buffer_len();
124
125        let socket = Socket::new(socket)?;
126        let kernel_addr = netlink_sys::SocketAddr::new(0, 0);
127        socket.connect(&kernel_addr)?;
128        let n_sent = socket.send(&buf[..len], 0)?;
129        if n_sent != len {
130            return Err(io::Error::new(
131                io::ErrorKind::UnexpectedEof,
132                "failed to send netlink request",
133            ));
134        }
135
136        let mut responses = vec![];
137        loop {
138            let n_received = socket.recv(&mut &mut buf[..], 0)?;
139            let mut offset = 0;
140            loop {
141                let bytes = &buf[offset..];
142                let response = NetlinkMessage::<I>::deserialize(bytes)
143                    .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
144                match response.payload {
145                    // We've parsed all parts of the response and can leave the loop.
146                    NetlinkPayload::Error(e) if e.code.is_some() => return Err(e.into()),
147                    NetlinkPayload::Done(_) | NetlinkPayload::Error(_) => return Ok(responses),
148                    _ => {},
149                }
150                responses.push(response.clone());
151                offset += response.header.length as usize;
152                if offset == n_received || response.header.length == 0 {
153                    // We've fully parsed the datagram, but there may be further datagrams
154                    // with additional netlink response parts.
155                    break;
156                }
157            }
158        }
159    }
160}
161
162#[cfg(target_os = "linux")]
163pub use linux::{
164    max_genl_payload_length, max_netlink_buffer_length, netlink_request, netlink_request_genl,
165    netlink_request_rtnl,
166};