abpfiff/
netlink.rs

1use alloc::vec::Vec;
2use bytemuck::Zeroable;
3use core::ops::ControlFlow;
4
5use crate::nlattr;
6use crate::sys::{
7    ArcTable, IfInfoMsg, LibBpfErrno, NlIfInfoReq, NlMsgErr, NlMsgHdr, NlTcReq, SockaddrNl,
8};
9use crate::{Errno, Netlink, OwnedFd, XdpQuery};
10
11pub struct NetlinkRecvBuffer {
12    iovec: libc::iovec,
13    mhdr: libc::msghdr,
14    /// Using `u32` due to align >= 4 requirement.
15    buf: Vec<u32>,
16    /// The expected sequence number.
17    seq: u32,
18}
19
20/// Just for reference, the command constants.
21#[allow(dead_code)]
22impl Netlink {
23    const ADD_MEMBERSHIP: libc::c_int = 1;
24    const DROP_MEMBERSHIP: libc::c_int = 2;
25    const PKTINFO: libc::c_int = 3;
26    const BROADCAST_ERROR: libc::c_int = 4;
27    const NO_ENOBUFS: libc::c_int = 5;
28    const RX_RING: libc::c_int = 6;
29    const TX_RING: libc::c_int = 7;
30    const LISTEN_ALL_NSID: libc::c_int = 8;
31    const LIST_MEMBERSHIPS: libc::c_int = 9;
32    const CAP_ACK: libc::c_int = 10;
33    const EXT_ACK: libc::c_int = 11;
34    const GET_STRICT_CHK: libc::c_int = 12;
35}
36
37impl Netlink {
38    pub fn open(sys: ArcTable) -> Result<Self, Errno> {
39        let sock = unsafe {
40            (sys.socket)(
41                libc::AF_NETLINK,
42                libc::SOCK_RAW | libc::SOCK_CLOEXEC,
43                libc::NETLINK_ROUTE,
44            )
45        };
46
47        if sock < 0 {
48            return Err(sys.errno());
49        }
50
51        let sock = OwnedFd(sock, sys.clone());
52
53        if {
54            let one: libc::c_int = 1;
55            let size = core::mem::size_of_val(&one) as libc::socklen_t;
56
57            unsafe {
58                (sys.setsockopt)(
59                    sock.0,
60                    libc::SOL_NETLINK,
61                    Self::EXT_ACK,
62                    (&one) as *const _ as *const libc::c_void,
63                    size,
64                )
65            }
66        } < 0
67        {}
68
69        let mut sockaddr_nl = SockaddrNl {
70            nl_family: libc::AF_NETLINK as libc::sa_family_t,
71            nl_pad: 0,
72            nl_pid: 0,
73            nl_groups: 0,
74        };
75
76        if {
77            unsafe {
78                (sys.bind)(
79                    sock.0,
80                    (&mut sockaddr_nl) as *mut _ as *mut libc::sockaddr,
81                    core::mem::size_of_val(&sockaddr_nl) as libc::socklen_t,
82                )
83            }
84        } < 0
85        {
86            return Err(sys.errno());
87        }
88
89        if {
90            let mut addrlen = core::mem::size_of_val(&sockaddr_nl) as libc::socklen_t;
91            unsafe {
92                (sys.getsockname)(
93                    sock.0,
94                    (&mut sockaddr_nl) as *mut _ as *mut libc::sockaddr,
95                    &mut addrlen,
96                )
97            }
98        } < 0
99        {
100            return Err(sys.errno());
101        }
102
103        let pid = sockaddr_nl.nl_pid;
104        let seq = 0u32;
105        let buf = alloc::vec::Vec::new();
106
107        Ok(Netlink {
108            sock,
109            pid,
110            seq,
111            buf,
112        })
113    }
114
115    pub fn sys(&self) -> &ArcTable {
116        &self.sock.1
117    }
118
119    pub fn xdp_query(
120        &mut self,
121        ifindex: u32,
122        buf: &mut NetlinkRecvBuffer,
123    ) -> Result<XdpQuery, Errno> {
124        let mut req = NlIfInfoReq {
125            hdr: NlMsgHdr {
126                nlmsg_type: libc::RTM_GETLINK,
127                nlmsg_flags: NlMsgHdr::NLM_F_DUMP | NlMsgHdr::NLM_F_REQUEST,
128                ..NlMsgHdr::zeroed()
129            },
130            msg: IfInfoMsg {
131                ifi_family: libc::AF_PACKET as u8,
132                ..IfInfoMsg::zeroed()
133            },
134        };
135
136        let sys = self.sys().clone();
137        let mut query = XdpQuery::default();
138        let mut parse_err = Ok(());
139
140        self.sendmsg_if_info(&mut req, buf)?;
141        self.recvmsg_multi(buf, |hdr, data| {
142            Self::link_nlmsg_parse(
143                &sys,
144                hdr,
145                data,
146                |hdr, attr| {
147                    if hdr.ifi_index as u32 != ifindex {
148                        // eprint!("Nested Data: {:?}\n", attr[nlattr::IflaType::IFLA_XDP as usize].data);
149                        return Ok(());
150                    }
151
152                    // eprint!("Nested Data: {:?}\n", nlattr::IflaType::IFLA_XDP as usize);
153                    // eprint!("Nested Data: {:?}\n", attr[nlattr::IflaType::IFLA_XDP as usize].data);
154
155                    let nested = match attr[nlattr::IflaType::IFLA_XDP as usize].data {
156                        Some(data) => data,
157                        None => return Ok(()),
158                    };
159
160                    for attr in &mut attr[..nlattr::IFLA_XDP_MAX] {
161                        *attr = nlattr::Attr::default();
162                    }
163
164                    nlattr::parse(&mut attr[..nlattr::IFLA_XDP_MAX], nested)?;
165                    // eprint!("Nested: {:?}\n", &attr[..nlattr::IFLA_XDP_MAX]);
166
167                    if !attr[nlattr::IflaXdp::IFLA_XDP_ATTACHED as usize].is_set() {
168                        return Ok(());
169                    }
170
171                    query.attach_mode =
172                        attr[nlattr::IflaXdp::IFLA_XDP_ATTACHED as usize].getattr_u8()?;
173
174                    if query.attach_mode == 0 {
175                        return Ok(());
176                    }
177
178                    if attr[nlattr::IflaXdp::IFLA_XDP_PROG_ID as usize].is_set() {
179                        query.prog_id =
180                            attr[nlattr::IflaXdp::IFLA_XDP_PROG_ID as usize].getattr_u32()?;
181                    }
182
183                    if attr[nlattr::IflaXdp::IFLA_XDP_SKB_PROG_ID as usize].is_set() {
184                        query.skb_prog_id =
185                            attr[nlattr::IflaXdp::IFLA_XDP_SKB_PROG_ID as usize].getattr_u32()?;
186                    }
187
188                    if attr[nlattr::IflaXdp::IFLA_XDP_DRV_PROG_ID as usize].is_set() {
189                        query.drv_prog_id =
190                            attr[nlattr::IflaXdp::IFLA_XDP_DRV_PROG_ID as usize].getattr_u32()?;
191                    }
192
193                    if attr[nlattr::IflaXdp::IFLA_XDP_HW_PROG_ID as usize].is_set() {
194                        query.hw_prog_id =
195                            attr[nlattr::IflaXdp::IFLA_XDP_HW_PROG_ID as usize].getattr_u32()?;
196                    }
197
198                    Ok(())
199                },
200                &mut parse_err,
201            )
202        })?;
203
204        Ok(query)
205    }
206
207    /** Low-level methods to interact directly with Netlink. */
208    pub fn sendmsg_if_info(
209        &mut self,
210        req: &mut NlIfInfoReq,
211        buf: &mut NetlinkRecvBuffer,
212    ) -> Result<(), Errno> {
213        let nlmsg_len = core::mem::size_of_val(req);
214        req.hdr.nlmsg_pid = 0;
215        req.hdr.nlmsg_seq = self.seq;
216        req.hdr.nlmsg_len = nlmsg_len as u32;
217        unsafe { self.sendmsg_after_len(req as *mut _ as *const _, nlmsg_len, buf) }
218    }
219
220    pub fn sendmsg_tc(
221        &mut self,
222        req: &mut NlTcReq,
223        buf: &mut NetlinkRecvBuffer,
224    ) -> Result<(), Errno> {
225        let nlmsg_len = core::mem::size_of_val(req);
226        req.hdr.nlmsg_pid = 0;
227        req.hdr.nlmsg_seq = self.seq;
228        req.hdr.nlmsg_len = nlmsg_len as u32;
229        unsafe { self.sendmsg_after_len(req as *mut _ as *const _, nlmsg_len, buf) }
230    }
231
232    pub(crate) unsafe fn sendmsg_after_len(
233        &mut self,
234        req: *const NlMsgHdr,
235        nlmsg_len: usize,
236        buf: &mut NetlinkRecvBuffer,
237    ) -> Result<(), Errno> {
238        buf.set_seq(self.seq);
239        self.seq += 1;
240
241        if unsafe {
242            (self.sys().send)(
243                self.sock.0,
244                req as *const _ as *const libc::c_void,
245                nlmsg_len,
246                0,
247            )
248        } < 0
249        {
250            Err(self.sys().errno())
251        } else {
252            Ok(())
253        }
254    }
255
256    pub fn recvmsg_multi(
257        &self,
258        buffer: &mut NetlinkRecvBuffer,
259        fn_: impl FnMut(&NlMsgHdr, &[u8]) -> ControlFlow<()>,
260    ) -> Result<(), Errno> {
261        buffer.recvmsg_multi(self, fn_)
262    }
263
264    /// `__dump_link_nlattr`.
265    fn link_nlmsg_parse<F>(
266        sys: &ArcTable,
267        _: &NlMsgHdr,
268        data: &[u8],
269        mut f: F,
270        err: &mut Result<(), Errno>,
271    ) -> ControlFlow<()>
272    where
273        F: FnMut(&IfInfoMsg, &mut [nlattr::Attr]) -> Result<(), LibBpfErrno>,
274    {
275        if err.is_err() {
276            return ControlFlow::Break(());
277        }
278
279        let ifohdr = match data.get(..core::mem::size_of::<IfInfoMsg>()) {
280            None => {
281                *err = Err(sys.bpf_err(LibBpfErrno::LIBBPF_ERRNO__NLPARSE));
282                return ControlFlow::Break(());
283            }
284            Some(msg) => msg,
285        };
286
287        let data = &data[core::mem::size_of::<IfInfoMsg>()..];
288        let ifohdr: &IfInfoMsg = match bytemuck::try_from_bytes(ifohdr) {
289            Err(_) => {
290                *err = Err(sys.bpf_err(LibBpfErrno::LIBBPF_ERRNO__NLPARSE));
291                return ControlFlow::Break(());
292            }
293            Ok(msg) => msg,
294        };
295
296        let mut nlattr = [nlattr::Attr::default(); nlattr::IFLA_MAX + 1];
297        match nlattr::parse(&mut nlattr, data) {
298            Err(no) => {
299                *err = Err(sys.bpf_err(no));
300                return ControlFlow::Break(());
301            }
302            Ok(len) => len,
303        }
304
305        match f(ifohdr, &mut nlattr[..]) {
306            Err(no) => {
307                *err = Err(sys.bpf_err(no));
308                return ControlFlow::Break(());
309            }
310            Ok(len) => len,
311        }
312
313        ControlFlow::Continue(())
314    }
315
316    fn get_xdp_info() -> ControlFlow<()> {
317        ControlFlow::Continue(())
318    }
319}
320
321impl NetlinkRecvBuffer {
322    pub const fn new() -> Self {
323        let iovec = libc::iovec {
324            iov_base: core::ptr::null_mut(),
325            iov_len: 0,
326        };
327
328        let mhdr = libc::msghdr {
329            msg_iov: core::ptr::null_mut(),
330            msg_iovlen: 0,
331            msg_control: core::ptr::null_mut(),
332            msg_controllen: 0,
333            msg_flags: 0,
334            msg_name: core::ptr::null_mut(),
335            msg_namelen: 0,
336        };
337
338        NetlinkRecvBuffer {
339            iovec,
340            mhdr,
341            buf: Vec::new(),
342            seq: 0,
343        }
344    }
345
346    /// Set the expected sec for `recvmsg_multi`.
347    pub fn set_seq(&mut self, seq: u32) {
348        self.seq = seq;
349    }
350
351    /// Clear the buffer, deallocating its memory in the process.
352    pub fn clear(&mut self) {
353        let _ = self.buf.split_off(0);
354    }
355
356    /// Receive one message, may be part of a multipart.
357    fn recvmsg_part(&mut self, from: &Netlink) -> Result<NlMessage<'_>, Errno> {
358        /* > Netlink expects that the user buffer will be at least 8kB or a page size of the CPU
359         * architecture, whichever is bigger. Particular Netlink families may, however, require a
360         * larger buffer. 32kB buffer is recommended for most efficient handling of dumps (larger
361         * buffer fits more dumped objects and therefore fewer recvmsg() calls are needed).
362         * > -- <https://kernel.org/doc/html/next/userspace-api/netlink/intro.html>
363         *
364         * We can peek a message as well, then resize the buffer based off the header. Let's do
365         * that, just like in libbpf. However, we can preserve that buffer.
366         *
367         * */
368        self.buf.reserve(4096usize.saturating_sub(self.buf.len()));
369
370        let len = unsafe {
371            let mhdr = self.prepare_mhdr();
372            (from.sock.1.recvmsg)(from.sock.0, mhdr, libc::MSG_PEEK | libc::MSG_TRUNC)
373        };
374
375        if len < 0 {
376            return Err(from.sock.1.errno());
377        }
378
379        self.buf
380            .reserve((len as usize).saturating_sub(self.buf.len()));
381
382        let len = unsafe {
383            let mhdr = self.prepare_mhdr();
384            (from.sock.1.recvmsg)(from.sock.0, mhdr, 0)
385        };
386
387        if len < 0 {
388            return Err(from.sock.1.errno());
389        }
390
391        unsafe { self.buf.set_len(len as usize) };
392
393        Ok(NlMessage {
394            buf: self.as_data(len as usize),
395            is_multipart_detected: false,
396        })
397    }
398
399    /// Parse message contents for the expected sequence number.
400    ///
401    /// The method can return `ControlFlow::Break` to break processing parts of one message, and
402    /// continue to the next multipart message if it exists.
403    ///
404    /// If this returns an error, then the Netlink to the kernel is likely broken or in an invalid
405    /// state. Please don't use it afterwards. Recoverable errors (i.e. ignored packets) are
406    /// handled internally or via callbacks, not via early return.
407    fn recvmsg_multi(
408        &mut self,
409        from: &Netlink,
410        mut fn_: impl FnMut(&NlMsgHdr, &[u8]) -> ControlFlow<()>,
411    ) -> Result<(), Errno> {
412        let seq = self.seq;
413
414        loop {
415            let mut msg = self.recvmsg_part(from)?;
416            'parts: while let Some((hdr, data)) = msg.next() {
417                if hdr.nlmsg_pid != from.pid {
418                    return Err(from.sys().bpf_err(LibBpfErrno::LIBBPF_ERRNO__WRNGPID));
419                }
420
421                if hdr.nlmsg_seq < seq {
422                    continue;
423                }
424
425                if hdr.nlmsg_seq > seq {
426                    return Err(from.sys().bpf_err(LibBpfErrno::LIBBPF_ERRNO__INVSEQ));
427                }
428
429                match hdr.nlmsg_type {
430                    NlMsgHdr::NLMSG_ERROR => {
431                        // Huh, this check is missing from the libbpf implementation, just reading
432                        // into that part of the message. Guess that's okay because we trust the
433                        // kernel? Eh, let's verify and fail with something useful.
434                        let err = match bytemuck::try_from_bytes::<NlMsgErr>(data) {
435                            Err(_) => {
436                                return Err(from.sys().bpf_err(LibBpfErrno::LIBBPF_ERRNO__INTERNAL))
437                            }
438                            Ok(err) => err,
439                        };
440
441                        if err.error == 0 {
442                            continue;
443                        }
444
445                        return Err(from.sys().mk_errno(err.error));
446                    }
447                    NlMsgHdr::NLMSG_DONE => {
448                        return Ok(());
449                    }
450                    _ => {}
451                }
452
453                match fn_(hdr, data) {
454                    ControlFlow::Continue(()) => {}
455                    ControlFlow::Break(()) => break 'parts,
456                }
457            }
458
459            if !msg.is_multipart_detected() {
460                return Ok(());
461            }
462        }
463    }
464
465    /// Helper method, ensuring pointers in the raw FFI structs are ready and valid on use.
466    fn prepare_mhdr(&mut self) -> &mut libc::msghdr {
467        self.iovec.iov_len = self.buf.capacity();
468        self.iovec.iov_base = self.buf.as_mut_ptr() as *mut libc::c_void;
469        self.mhdr.msg_iovlen = 1;
470        self.mhdr.msg_iov = &mut self.iovec;
471        &mut self.mhdr
472    }
473
474    fn as_data(&self, data: usize) -> &[u8] {
475        &bytemuck::cast_slice(self.buf.as_slice())[..data]
476    }
477}
478
479impl<'a> NlMessage<'a> {
480    pub fn next(&mut self) -> Option<(&'a NlMsgHdr, &'a [u8])> {
481        let hdr = self.buf.get(..core::mem::size_of::<NlMsgHdr>())?;
482        let hdr = bytemuck::try_from_bytes::<NlMsgHdr>(hdr).ok()?;
483        self.is_multipart_detected |= (hdr.nlmsg_flags & NlMsgHdr::NLM_F_MULTI) != 0;
484
485        let end = hdr.nlmsg_len as usize;
486        let data = self.buf.get(core::mem::size_of::<NlMsgHdr>()..end)?;
487        // Round up to 4 as per <linux/netlink.h>
488        let offset = (hdr.nlmsg_len + 3) & !3;
489
490        self.buf = self.buf.get(offset as usize..)?;
491
492        Some((hdr, data))
493    }
494
495    /// Return true if any of the parts had the multipart flag set.
496    pub fn is_multipart_detected(&self) -> bool {
497        self.is_multipart_detected
498    }
499}
500
/// A single received datagram, viewed as a sequence of Netlink parts.
///
/// One datagram can hold several Netlink messages and can itself be one piece of a multipart
/// reply. Whether it is only becomes certain _after_ all of its parts have been drained with
/// `next()`; ask `is_multipart_detected()` then.
struct NlMessage<'a> {
    /// Set as soon as any part seen so far carried `NLM_F_MULTI`.
    is_multipart_detected: bool,
    /// Bytes of the datagram not yet consumed by `next()`.
    buf: &'a [u8],
}