aya/sys/
netlink.rs

1use std::{
2    collections::HashMap,
3    ffi::CStr,
4    io, mem,
5    os::fd::{AsRawFd as _, BorrowedFd, FromRawFd as _},
6    ptr, slice,
7};
8
9use libc::{
10    getsockname, nlattr, nlmsgerr, nlmsghdr, recv, send, setsockopt, sockaddr_nl, socket,
11    AF_NETLINK, AF_UNSPEC, ETH_P_ALL, IFF_UP, IFLA_XDP, NETLINK_EXT_ACK, NETLINK_ROUTE,
12    NLA_ALIGNTO, NLA_F_NESTED, NLA_TYPE_MASK, NLMSG_DONE, NLMSG_ERROR, NLM_F_ACK, NLM_F_CREATE,
13    NLM_F_DUMP, NLM_F_ECHO, NLM_F_EXCL, NLM_F_MULTI, NLM_F_REQUEST, RTM_DELTFILTER, RTM_GETTFILTER,
14    RTM_NEWQDISC, RTM_NEWTFILTER, RTM_SETLINK, SOCK_RAW, SOL_NETLINK,
15};
16use thiserror::Error;
17
18use crate::{
19    generated::{
20        ifinfomsg, tcmsg, IFLA_XDP_EXPECTED_FD, IFLA_XDP_FD, IFLA_XDP_FLAGS, NLMSG_ALIGNTO,
21        TCA_BPF_FD, TCA_BPF_FLAGS, TCA_BPF_FLAG_ACT_DIRECT, TCA_BPF_NAME, TCA_KIND, TCA_OPTIONS,
22        TC_H_CLSACT, TC_H_INGRESS, TC_H_MAJ_MASK, TC_H_UNSPEC, XDP_FLAGS_REPLACE,
23    },
24    programs::TcAttachType,
25    util::tc_handler_make,
26};
27
28const NLA_HDR_LEN: usize = align_to(mem::size_of::<nlattr>(), NLA_ALIGNTO as usize);
29
30// Safety: marking this as unsafe overall because of all the pointer math required to comply with
31// netlink alignments
32pub(crate) unsafe fn netlink_set_xdp_fd(
33    if_index: i32,
34    fd: Option<BorrowedFd<'_>>,
35    old_fd: Option<BorrowedFd<'_>>,
36    flags: u32,
37) -> Result<(), io::Error> {
38    let sock = NetlinkSocket::open()?;
39
40    // Safety: Request is POD so this is safe
41    let mut req = mem::zeroed::<Request>();
42
43    let nlmsg_len = mem::size_of::<nlmsghdr>() + mem::size_of::<ifinfomsg>();
44    req.header = nlmsghdr {
45        nlmsg_len: nlmsg_len as u32,
46        nlmsg_flags: (NLM_F_REQUEST | NLM_F_ACK) as u16,
47        nlmsg_type: RTM_SETLINK,
48        nlmsg_pid: 0,
49        nlmsg_seq: 1,
50    };
51    req.if_info.ifi_family = AF_UNSPEC as u8;
52    req.if_info.ifi_index = if_index;
53
54    // write the attrs
55    let attrs_buf = request_attributes(&mut req, nlmsg_len);
56    let mut attrs = NestedAttrs::new(attrs_buf, IFLA_XDP);
57    attrs.write_attr(
58        IFLA_XDP_FD as u16,
59        fd.map(|fd| fd.as_raw_fd()).unwrap_or(-1),
60    )?;
61
62    if flags > 0 {
63        attrs.write_attr(IFLA_XDP_FLAGS as u16, flags)?;
64    }
65
66    if flags & XDP_FLAGS_REPLACE != 0 {
67        attrs.write_attr(
68            IFLA_XDP_EXPECTED_FD as u16,
69            old_fd.map(|fd| fd.as_raw_fd()).unwrap(),
70        )?;
71    }
72
73    let nla_len = attrs.finish()?;
74    req.header.nlmsg_len += align_to(nla_len, NLA_ALIGNTO as usize) as u32;
75
76    sock.send(&bytes_of(&req)[..req.header.nlmsg_len as usize])?;
77
78    sock.recv()?;
79
80    Ok(())
81}
82
83pub(crate) unsafe fn netlink_qdisc_add_clsact(if_index: i32) -> Result<(), io::Error> {
84    let sock = NetlinkSocket::open()?;
85
86    let mut req = mem::zeroed::<TcRequest>();
87
88    let nlmsg_len = mem::size_of::<nlmsghdr>() + mem::size_of::<tcmsg>();
89    req.header = nlmsghdr {
90        nlmsg_len: nlmsg_len as u32,
91        nlmsg_flags: (NLM_F_REQUEST | NLM_F_ACK | NLM_F_EXCL | NLM_F_CREATE) as u16,
92        nlmsg_type: RTM_NEWQDISC,
93        nlmsg_pid: 0,
94        nlmsg_seq: 1,
95    };
96    req.tc_info.tcm_family = AF_UNSPEC as u8;
97    req.tc_info.tcm_ifindex = if_index;
98    req.tc_info.tcm_handle = tc_handler_make(TC_H_CLSACT, TC_H_UNSPEC);
99    req.tc_info.tcm_parent = tc_handler_make(TC_H_CLSACT, TC_H_INGRESS);
100    req.tc_info.tcm_info = 0;
101
102    // add the TCA_KIND attribute
103    let attrs_buf = request_attributes(&mut req, nlmsg_len);
104    let attr_len = write_attr_bytes(attrs_buf, 0, TCA_KIND as u16, b"clsact\0")?;
105    req.header.nlmsg_len += align_to(attr_len, NLA_ALIGNTO as usize) as u32;
106
107    sock.send(&bytes_of(&req)[..req.header.nlmsg_len as usize])?;
108    sock.recv()?;
109
110    Ok(())
111}
112
113pub(crate) unsafe fn netlink_qdisc_attach(
114    if_index: i32,
115    attach_type: &TcAttachType,
116    prog_fd: BorrowedFd<'_>,
117    prog_name: &CStr,
118    priority: u16,
119    handle: u32,
120    create: bool,
121) -> Result<(u16, u32), io::Error> {
122    let sock = NetlinkSocket::open()?;
123    let mut req = mem::zeroed::<TcRequest>();
124
125    let nlmsg_len = mem::size_of::<nlmsghdr>() + mem::size_of::<tcmsg>();
126    // When create=true, we're creating a new attachment so we must set NLM_F_CREATE. Then we also
127    // set NLM_F_EXCL so that attaching fails if there's already a program attached to the given
128    // handle.
129    //
130    // When create=false we're replacing an existing attachment so we must not set either flags.
131    //
132    // See https://github.com/torvalds/linux/blob/3a87498/net/sched/cls_api.c#L2304
133    let request_flags = if create {
134        NLM_F_CREATE | NLM_F_EXCL
135    } else {
136        // NLM_F_REPLACE exists, but seems unused by cls_bpf
137        0
138    };
139    req.header = nlmsghdr {
140        nlmsg_len: nlmsg_len as u32,
141        nlmsg_flags: (NLM_F_REQUEST | NLM_F_ACK | NLM_F_ECHO | request_flags) as u16,
142        nlmsg_type: RTM_NEWTFILTER,
143        nlmsg_pid: 0,
144        nlmsg_seq: 1,
145    };
146    req.tc_info.tcm_family = AF_UNSPEC as u8;
147    req.tc_info.tcm_handle = handle; // auto-assigned, if zero
148    req.tc_info.tcm_ifindex = if_index;
149    req.tc_info.tcm_parent = attach_type.tc_parent();
150    req.tc_info.tcm_info = tc_handler_make((priority as u32) << 16, htons(ETH_P_ALL as u16) as u32);
151
152    let attrs_buf = request_attributes(&mut req, nlmsg_len);
153
154    // add TCA_KIND
155    let kind_len = write_attr_bytes(attrs_buf, 0, TCA_KIND as u16, b"bpf\0")?;
156
157    // add TCA_OPTIONS which includes TCA_BPF_FD, TCA_BPF_NAME and TCA_BPF_FLAGS
158    let mut options = NestedAttrs::new(&mut attrs_buf[kind_len..], TCA_OPTIONS as u16);
159    options.write_attr(TCA_BPF_FD as u16, prog_fd)?;
160    options.write_attr_bytes(TCA_BPF_NAME as u16, prog_name.to_bytes_with_nul())?;
161    let flags: u32 = TCA_BPF_FLAG_ACT_DIRECT;
162    options.write_attr(TCA_BPF_FLAGS as u16, flags)?;
163    let options_len = options.finish()?;
164
165    req.header.nlmsg_len += align_to(kind_len + options_len, NLA_ALIGNTO as usize) as u32;
166    sock.send(&bytes_of(&req)[..req.header.nlmsg_len as usize])?;
167
168    // find the RTM_NEWTFILTER reply and read the tcm_info and tcm_handle fields
169    // which we'll need to detach
170    let tc_msg = match sock
171        .recv()?
172        .iter()
173        .find(|reply| reply.header.nlmsg_type == RTM_NEWTFILTER)
174    {
175        Some(reply) => ptr::read_unaligned(reply.data.as_ptr() as *const tcmsg),
176        None => {
177            // if sock.recv() succeeds we should never get here unless there's a
178            // bug in the kernel
179            return Err(io::Error::new(
180                io::ErrorKind::Other,
181                "no RTM_NEWTFILTER reply received, this is a bug.",
182            ));
183        }
184    };
185
186    let priority = ((tc_msg.tcm_info & TC_H_MAJ_MASK) >> 16) as u16;
187    Ok((priority, tc_msg.tcm_handle))
188}
189
190pub(crate) unsafe fn netlink_qdisc_detach(
191    if_index: i32,
192    attach_type: &TcAttachType,
193    priority: u16,
194    handle: u32,
195) -> Result<(), io::Error> {
196    let sock = NetlinkSocket::open()?;
197    let mut req = mem::zeroed::<TcRequest>();
198
199    req.header = nlmsghdr {
200        nlmsg_len: (mem::size_of::<nlmsghdr>() + mem::size_of::<tcmsg>()) as u32,
201        nlmsg_flags: (NLM_F_REQUEST | NLM_F_ACK) as u16,
202        nlmsg_type: RTM_DELTFILTER,
203        nlmsg_pid: 0,
204        nlmsg_seq: 1,
205    };
206
207    req.tc_info.tcm_family = AF_UNSPEC as u8;
208    req.tc_info.tcm_handle = handle; // auto-assigned, if zero
209    req.tc_info.tcm_info = tc_handler_make((priority as u32) << 16, htons(ETH_P_ALL as u16) as u32);
210    req.tc_info.tcm_parent = attach_type.tc_parent();
211    req.tc_info.tcm_ifindex = if_index;
212
213    sock.send(&bytes_of(&req)[..req.header.nlmsg_len as usize])?;
214
215    sock.recv()?;
216
217    Ok(())
218}
219
220// Returns a vector of tuple (priority, handle) for filters matching the provided parameters
221pub(crate) unsafe fn netlink_find_filter_with_name(
222    if_index: i32,
223    attach_type: TcAttachType,
224    name: &CStr,
225) -> Result<Vec<(u16, u32)>, io::Error> {
226    let mut req = mem::zeroed::<TcRequest>();
227
228    let nlmsg_len = mem::size_of::<nlmsghdr>() + mem::size_of::<tcmsg>();
229    req.header = nlmsghdr {
230        nlmsg_len: nlmsg_len as u32,
231        nlmsg_type: RTM_GETTFILTER,
232        nlmsg_flags: (NLM_F_REQUEST | NLM_F_DUMP) as u16,
233        nlmsg_pid: 0,
234        nlmsg_seq: 1,
235    };
236    req.tc_info.tcm_family = AF_UNSPEC as u8;
237    req.tc_info.tcm_handle = 0; // auto-assigned, if zero
238    req.tc_info.tcm_ifindex = if_index;
239    req.tc_info.tcm_parent = attach_type.tc_parent();
240
241    let sock = NetlinkSocket::open()?;
242    sock.send(&bytes_of(&req)[..req.header.nlmsg_len as usize])?;
243
244    let mut filter_info = Vec::new();
245    for msg in sock.recv()? {
246        if msg.header.nlmsg_type != RTM_NEWTFILTER {
247            continue;
248        }
249
250        let tc_msg = ptr::read_unaligned(msg.data.as_ptr() as *const tcmsg);
251        let priority = (tc_msg.tcm_info >> 16) as u16;
252        let attrs = parse_attrs(&msg.data[mem::size_of::<tcmsg>()..])?;
253
254        if let Some(opts) = attrs.get(&(TCA_OPTIONS as u16)) {
255            let opts = parse_attrs(opts.data)?;
256            if let Some(f_name) = opts.get(&(TCA_BPF_NAME as u16)) {
257                if let Ok(f_name) = CStr::from_bytes_with_nul(f_name.data) {
258                    if name == f_name {
259                        filter_info.push((priority, tc_msg.tcm_handle));
260                    }
261                }
262            }
263        }
264    }
265
266    Ok(filter_info)
267}
268
269#[doc(hidden)]
270pub unsafe fn netlink_set_link_up(if_index: i32) -> Result<(), io::Error> {
271    let sock = NetlinkSocket::open()?;
272
273    // Safety: Request is POD so this is safe
274    let mut req = mem::zeroed::<Request>();
275
276    let nlmsg_len = mem::size_of::<nlmsghdr>() + mem::size_of::<ifinfomsg>();
277    req.header = nlmsghdr {
278        nlmsg_len: nlmsg_len as u32,
279        nlmsg_flags: (NLM_F_REQUEST | NLM_F_ACK) as u16,
280        nlmsg_type: RTM_SETLINK,
281        nlmsg_pid: 0,
282        nlmsg_seq: 1,
283    };
284    req.if_info.ifi_family = AF_UNSPEC as u8;
285    req.if_info.ifi_index = if_index;
286    req.if_info.ifi_flags = IFF_UP as u32;
287    req.if_info.ifi_change = IFF_UP as u32;
288
289    sock.send(&bytes_of(&req)[..req.header.nlmsg_len as usize])?;
290    sock.recv()?;
291
292    Ok(())
293}
294
295#[repr(C)]
296struct Request {
297    header: nlmsghdr,
298    if_info: ifinfomsg,
299    attrs: [u8; 64],
300}
301
302#[repr(C)]
303struct TcRequest {
304    header: nlmsghdr,
305    tc_info: tcmsg,
306    attrs: [u8; 64],
307}
308
309struct NetlinkSocket {
310    sock: crate::MockableFd,
311    _nl_pid: u32,
312}
313
314impl NetlinkSocket {
315    fn open() -> Result<Self, io::Error> {
316        // Safety: libc wrapper
317        let sock = unsafe { socket(AF_NETLINK, SOCK_RAW, NETLINK_ROUTE) };
318        if sock < 0 {
319            return Err(io::Error::last_os_error());
320        }
321        // SAFETY: `socket` returns a file descriptor.
322        let sock = unsafe { crate::MockableFd::from_raw_fd(sock) };
323
324        let enable = 1i32;
325        // Safety: libc wrapper
326        unsafe {
327            setsockopt(
328                sock.as_raw_fd(),
329                SOL_NETLINK,
330                NETLINK_EXT_ACK,
331                &enable as *const _ as *const _,
332                mem::size_of::<i32>() as u32,
333            )
334        };
335
336        // Safety: sockaddr_nl is POD so this is safe
337        let mut addr = unsafe { mem::zeroed::<sockaddr_nl>() };
338        addr.nl_family = AF_NETLINK as u16;
339        let mut addr_len = mem::size_of::<sockaddr_nl>() as u32;
340        // Safety: libc wrapper
341        if unsafe {
342            getsockname(
343                sock.as_raw_fd(),
344                &mut addr as *mut _ as *mut _,
345                &mut addr_len as *mut _,
346            )
347        } < 0
348        {
349            return Err(io::Error::last_os_error());
350        }
351
352        Ok(Self {
353            sock,
354            _nl_pid: addr.nl_pid,
355        })
356    }
357
358    fn send(&self, msg: &[u8]) -> Result<(), io::Error> {
359        if unsafe {
360            send(
361                self.sock.as_raw_fd(),
362                msg.as_ptr() as *const _,
363                msg.len(),
364                0,
365            )
366        } < 0
367        {
368            return Err(io::Error::last_os_error());
369        }
370        Ok(())
371    }
372
373    fn recv(&self) -> Result<Vec<NetlinkMessage>, io::Error> {
374        let mut buf = [0u8; 4096];
375        let mut messages = Vec::new();
376        let mut multipart = true;
377        'out: while multipart {
378            multipart = false;
379            // Safety: libc wrapper
380            let len = unsafe {
381                recv(
382                    self.sock.as_raw_fd(),
383                    buf.as_mut_ptr() as *mut _,
384                    buf.len(),
385                    0,
386                )
387            };
388            if len < 0 {
389                return Err(io::Error::last_os_error());
390            }
391            if len == 0 {
392                break;
393            }
394
395            let len = len as usize;
396            let mut offset = 0;
397            while offset < len {
398                let message = NetlinkMessage::read(&buf[offset..])?;
399                offset += align_to(message.header.nlmsg_len as usize, NLMSG_ALIGNTO as usize);
400                multipart = message.header.nlmsg_flags & NLM_F_MULTI as u16 != 0;
401                match message.header.nlmsg_type as i32 {
402                    NLMSG_ERROR => {
403                        let err = message.error.unwrap();
404                        if err.error == 0 {
405                            // this is an ACK
406                            continue;
407                        }
408                        return Err(io::Error::from_raw_os_error(-err.error));
409                    }
410                    NLMSG_DONE => break 'out,
411                    _ => messages.push(message),
412                }
413            }
414        }
415
416        Ok(messages)
417    }
418}
419
420struct NetlinkMessage {
421    header: nlmsghdr,
422    data: Vec<u8>,
423    error: Option<nlmsgerr>,
424}
425
426impl NetlinkMessage {
427    fn read(buf: &[u8]) -> Result<Self, io::Error> {
428        if mem::size_of::<nlmsghdr>() > buf.len() {
429            return Err(io::Error::new(
430                io::ErrorKind::Other,
431                "buffer smaller than nlmsghdr",
432            ));
433        }
434
435        // Safety: nlmsghdr is POD so read is safe
436        let header = unsafe { ptr::read_unaligned(buf.as_ptr() as *const nlmsghdr) };
437        let msg_len = header.nlmsg_len as usize;
438        if msg_len < mem::size_of::<nlmsghdr>() || msg_len > buf.len() {
439            return Err(io::Error::new(io::ErrorKind::Other, "invalid nlmsg_len"));
440        }
441
442        let data_offset = align_to(mem::size_of::<nlmsghdr>(), NLMSG_ALIGNTO as usize);
443        if data_offset >= buf.len() {
444            return Err(io::Error::new(io::ErrorKind::Other, "need more data"));
445        }
446
447        let (data, error) = if header.nlmsg_type == NLMSG_ERROR as u16 {
448            if data_offset + mem::size_of::<nlmsgerr>() > buf.len() {
449                return Err(io::Error::new(
450                    io::ErrorKind::Other,
451                    "NLMSG_ERROR but not enough space for nlmsgerr",
452                ));
453            }
454            (
455                Vec::new(),
456                // Safety: nlmsgerr is POD so read is safe
457                Some(unsafe {
458                    ptr::read_unaligned(buf[data_offset..].as_ptr() as *const nlmsgerr)
459                }),
460            )
461        } else {
462            (buf[data_offset..msg_len].to_vec(), None)
463        };
464
465        Ok(Self {
466            header,
467            data,
468            error,
469        })
470    }
471}
472
/// Rounds `v` up to the next multiple of `align`.
///
/// `align` must be a power of two, because the rounding is done with a bitmask.
const fn align_to(v: usize, align: usize) -> usize {
    let mask = align - 1;
    (v + mask) & !mask
}
476
/// Converts a `u16` from host byte order to network (big-endian) byte order.
fn htons(u: u16) -> u16 {
    u16::from_ne_bytes(u.to_be_bytes())
}
480
481struct NestedAttrs<'a> {
482    buf: &'a mut [u8],
483    top_attr_type: u16,
484    offset: usize,
485}
486
487impl<'a> NestedAttrs<'a> {
488    fn new(buf: &'a mut [u8], top_attr_type: u16) -> Self {
489        Self {
490            buf,
491            top_attr_type,
492            offset: NLA_HDR_LEN,
493        }
494    }
495
496    fn write_attr<T>(&mut self, attr_type: u16, value: T) -> Result<usize, io::Error> {
497        let size = write_attr(self.buf, self.offset, attr_type, value)?;
498        self.offset += size;
499        Ok(size)
500    }
501
502    fn write_attr_bytes(&mut self, attr_type: u16, value: &[u8]) -> Result<usize, io::Error> {
503        let size = write_attr_bytes(self.buf, self.offset, attr_type, value)?;
504        self.offset += size;
505        Ok(size)
506    }
507
508    fn finish(self) -> Result<usize, io::Error> {
509        let nla_len = self.offset;
510        let attr = nlattr {
511            nla_type: NLA_F_NESTED as u16 | self.top_attr_type,
512            nla_len: nla_len as u16,
513        };
514
515        write_attr_header(self.buf, 0, attr)?;
516        Ok(nla_len)
517    }
518}
519
520fn write_attr<T>(
521    buf: &mut [u8],
522    offset: usize,
523    attr_type: u16,
524    value: T,
525) -> Result<usize, io::Error> {
526    let value =
527        unsafe { slice::from_raw_parts(&value as *const _ as *const _, mem::size_of::<T>()) };
528    write_attr_bytes(buf, offset, attr_type, value)
529}
530
531fn write_attr_bytes(
532    buf: &mut [u8],
533    offset: usize,
534    attr_type: u16,
535    value: &[u8],
536) -> Result<usize, io::Error> {
537    let attr = nlattr {
538        nla_type: attr_type,
539        nla_len: ((NLA_HDR_LEN + value.len()) as u16),
540    };
541
542    write_attr_header(buf, offset, attr)?;
543    let value_len = write_bytes(buf, offset + NLA_HDR_LEN, value)?;
544
545    Ok(NLA_HDR_LEN + value_len)
546}
547
548fn write_attr_header(buf: &mut [u8], offset: usize, attr: nlattr) -> Result<usize, io::Error> {
549    let attr =
550        unsafe { slice::from_raw_parts(&attr as *const _ as *const _, mem::size_of::<nlattr>()) };
551
552    write_bytes(buf, offset, attr)?;
553    Ok(NLA_HDR_LEN)
554}
555
556fn write_bytes(buf: &mut [u8], offset: usize, value: &[u8]) -> Result<usize, io::Error> {
557    let align_len = align_to(value.len(), NLA_ALIGNTO as usize);
558    if offset + align_len > buf.len() {
559        return Err(io::Error::new(io::ErrorKind::Other, "no space left"));
560    }
561
562    buf[offset..offset + value.len()].copy_from_slice(value);
563
564    Ok(align_len)
565}
566
567struct NlAttrsIterator<'a> {
568    attrs: &'a [u8],
569    offset: usize,
570}
571
572impl<'a> NlAttrsIterator<'a> {
573    fn new(attrs: &'a [u8]) -> Self {
574        Self { attrs, offset: 0 }
575    }
576}
577
578impl<'a> Iterator for NlAttrsIterator<'a> {
579    type Item = Result<NlAttr<'a>, NlAttrError>;
580
581    fn next(&mut self) -> Option<Self::Item> {
582        let buf = &self.attrs[self.offset..];
583        if buf.is_empty() {
584            return None;
585        }
586
587        if NLA_HDR_LEN > buf.len() {
588            self.offset = buf.len();
589            return Some(Err(NlAttrError::InvalidBufferLength {
590                size: buf.len(),
591                expected: NLA_HDR_LEN,
592            }));
593        }
594
595        let attr = unsafe { ptr::read_unaligned(buf.as_ptr() as *const nlattr) };
596        let len = attr.nla_len as usize;
597        let align_len = align_to(len, NLA_ALIGNTO as usize);
598        if len < NLA_HDR_LEN {
599            return Some(Err(NlAttrError::InvalidHeaderLength(len)));
600        }
601        if align_len > buf.len() {
602            return Some(Err(NlAttrError::InvalidBufferLength {
603                size: buf.len(),
604                expected: align_len,
605            }));
606        }
607
608        let data = &buf[NLA_HDR_LEN..len];
609
610        self.offset += align_len;
611        Some(Ok(NlAttr { header: attr, data }))
612    }
613}
614
615fn parse_attrs(buf: &[u8]) -> Result<HashMap<u16, NlAttr<'_>>, NlAttrError> {
616    let mut attrs = HashMap::new();
617    for attr in NlAttrsIterator::new(buf) {
618        let attr = attr?;
619        attrs.insert(attr.header.nla_type & NLA_TYPE_MASK as u16, attr);
620    }
621    Ok(attrs)
622}
623
624#[derive(Clone)]
625struct NlAttr<'a> {
626    header: nlattr,
627    data: &'a [u8],
628}
629
630#[derive(Debug, Error, PartialEq, Eq)]
631enum NlAttrError {
632    #[error("invalid buffer size `{size}`, expected `{expected}`")]
633    InvalidBufferLength { size: usize, expected: usize },
634
635    #[error("invalid nlattr header length `{0}`")]
636    InvalidHeaderLength(usize),
637}
638
639impl From<NlAttrError> for io::Error {
640    fn from(e: NlAttrError) -> Self {
641        Self::new(io::ErrorKind::Other, e)
642    }
643}
644
645unsafe fn request_attributes<T>(req: &mut T, msg_len: usize) -> &mut [u8] {
646    let attrs_addr = align_to(req as *mut _ as usize + msg_len, NLMSG_ALIGNTO as usize);
647    let attrs_end = req as *mut _ as usize + mem::size_of::<T>();
648    slice::from_raw_parts_mut(attrs_addr as *mut u8, attrs_end - attrs_addr)
649}
650
/// Views `val` as its raw in-memory bytes.
///
/// NOTE(review): this is only sound for a `T` without padding bytes. Padding
/// would be read as uninitialised memory. The `#[repr(C)]` request structs
/// here are presumably padding-free — verify if new ones are added.
fn bytes_of<T>(val: &T) -> &[u8] {
    let len = mem::size_of::<T>();
    let start = (val as *const T).cast::<u8>();
    // SAFETY: `start` points to `*val`, which is valid for `len` bytes for
    // the lifetime of the returned borrow.
    unsafe { slice::from_raw_parts(start, len) }
}
655
#[cfg(test)]
mod tests {
    use std::ffi::CString;

    use super::*;

    /// Size of a `u32` payload, the value type most tests below use.
    const U32_LEN: usize = mem::size_of::<u32>();

    /// Reads an `nlattr` header stored (possibly unaligned) at `offset` in `buf`.
    fn header_at(buf: &[u8], offset: usize) -> nlattr {
        unsafe { ptr::read_unaligned(buf[offset..].as_ptr() as *const nlattr) }
    }

    /// Reads a native-endian `u32` stored (possibly unaligned) at `offset` in `buf`.
    fn u32_at(buf: &[u8], offset: usize) -> u32 {
        unsafe { ptr::read_unaligned(buf[offset..].as_ptr() as *const u32) }
    }

    /// Decodes a 4-byte attribute payload as a native-endian `u32`.
    fn payload_u32(attr: &NlAttr<'_>) -> u32 {
        u32::from_ne_bytes(attr.data.try_into().unwrap())
    }

    /// Attribute type with the flag bits outside `NLA_TYPE_MASK` stripped.
    fn attr_kind(attr: &NlAttr<'_>) -> u16 {
        attr.header.nla_type & NLA_TYPE_MASK as u16
    }

    #[test]
    fn test_nested_attrs() {
        let mut buf = [0; 64];

        // IFLA_XDP containing IFLA_XDP_FD and IFLA_XDP_EXPECTED_FD.
        let mut attrs = NestedAttrs::new(&mut buf, IFLA_XDP);
        attrs.write_attr(IFLA_XDP_FD as u16, 42u32).unwrap();
        attrs
            .write_attr(IFLA_XDP_EXPECTED_FD as u16, 24u32)
            .unwrap();
        let len = attrs.finish().unwrap() as u16;

        // Outer header + two inner headers + two u32 payloads.
        let nested_len = (NLA_HDR_LEN * 3 + U32_LEN * 2) as u16;
        let inner_len = (NLA_HDR_LEN + U32_LEN) as u16;
        assert_eq!(len, nested_len);

        let outer = header_at(&buf, 0);
        assert_eq!(outer.nla_type, NLA_F_NESTED as u16 | IFLA_XDP);
        assert_eq!(outer.nla_len, nested_len);

        let fd_attr_offset = NLA_HDR_LEN;
        let fd_attr = header_at(&buf, fd_attr_offset);
        assert_eq!(fd_attr.nla_type, IFLA_XDP_FD as u16);
        assert_eq!(fd_attr.nla_len, inner_len);
        assert_eq!(u32_at(&buf, fd_attr_offset + NLA_HDR_LEN), 42);

        let expected_attr_offset = fd_attr_offset + NLA_HDR_LEN + U32_LEN;
        let expected_attr = header_at(&buf, expected_attr_offset);
        assert_eq!(expected_attr.nla_type, IFLA_XDP_EXPECTED_FD as u16);
        assert_eq!(expected_attr.nla_len, inner_len);
        assert_eq!(u32_at(&buf, expected_attr_offset + NLA_HDR_LEN), 24);
    }

    #[test]
    fn test_nlattr_iterator_empty() {
        let mut iter = NlAttrsIterator::new(&[]);
        assert!(iter.next().is_none());
    }

    #[test]
    fn test_nlattr_iterator_one() {
        let mut buf = [0; NLA_HDR_LEN + U32_LEN];
        write_attr(&mut buf, 0, IFLA_XDP_FD as u16, 42u32).unwrap();

        let mut iter = NlAttrsIterator::new(&buf);
        let attr = iter.next().unwrap().unwrap();
        assert_eq!(attr.header.nla_type, IFLA_XDP_FD as u16);
        assert_eq!(attr.data.len(), U32_LEN);
        assert_eq!(payload_u32(&attr), 42);

        assert!(iter.next().is_none());
    }

    #[test]
    fn test_nlattr_iterator_many() {
        const STRIDE: usize = NLA_HDR_LEN + U32_LEN;
        let expected = [
            (IFLA_XDP_FD as u16, 42u32),
            (IFLA_XDP_EXPECTED_FD as u16, 12u32),
        ];

        let mut buf = [0; STRIDE * 2];
        for (i, &(attr_type, value)) in expected.iter().enumerate() {
            write_attr(&mut buf, i * STRIDE, attr_type, value).unwrap();
        }

        let mut iter = NlAttrsIterator::new(&buf);
        for &(attr_type, value) in &expected {
            let attr = iter.next().unwrap().unwrap();
            assert_eq!(attr.header.nla_type, attr_type);
            assert_eq!(attr.data.len(), U32_LEN);
            assert_eq!(payload_u32(&attr), value);
        }

        assert!(iter.next().is_none());
    }

    #[test]
    fn test_nlattr_iterator_nested() {
        let mut buf = [0; 1024];
        let name = CString::new("foo").unwrap();

        let mut options = NestedAttrs::new(&mut buf, TCA_OPTIONS as u16);
        options.write_attr(TCA_BPF_FD as u16, 42).unwrap();
        options
            .write_attr_bytes(TCA_BPF_NAME as u16, name.to_bytes_with_nul())
            .unwrap();
        options.finish().unwrap();

        let mut outer_iter = NlAttrsIterator::new(&buf);
        let outer = outer_iter.next().unwrap().unwrap();
        assert_eq!(attr_kind(&outer), TCA_OPTIONS as u16);

        let mut inner_iter = NlAttrsIterator::new(outer.data);
        let fd_attr = inner_iter.next().unwrap().unwrap();
        assert_eq!(attr_kind(&fd_attr), TCA_BPF_FD as u16);

        let name_attr = inner_iter.next().unwrap().unwrap();
        assert_eq!(attr_kind(&name_attr), TCA_BPF_NAME as u16);
        let parsed_name = CStr::from_bytes_with_nul(name_attr.data).unwrap();
        assert_eq!(parsed_name.to_str().unwrap(), "foo");
    }
}