use std::{
ffi::{CStr, CString, FromBytesWithNulError},
io, iter, mem,
os::fd::{AsRawFd as _, BorrowedFd, FromRawFd as _},
ptr, slice,
};
use aya_obj::generated::{
IFLA_XDP_EXPECTED_FD, IFLA_XDP_FD, IFLA_XDP_FLAGS, NLMSG_ALIGNTO, TC_H_CLSACT, TC_H_INGRESS,
TC_H_MAJ_MASK, TC_H_UNSPEC, TCA_BPF_FD, TCA_BPF_FLAG_ACT_DIRECT, TCA_BPF_FLAGS, TCA_BPF_NAME,
TCA_KIND, TCA_OPTIONS, XDP_FLAGS_REPLACE, ifinfomsg, nlmsgerr_attrs::NLMSGERR_ATTR_MSG, tcmsg,
};
use libc::{
AF_NETLINK, AF_UNSPEC, ETH_P_ALL, IFF_UP, IFLA_XDP, NETLINK_CAP_ACK, NETLINK_EXT_ACK,
NETLINK_ROUTE, NLA_ALIGNTO, NLA_F_NESTED, NLA_TYPE_MASK, NLM_F_ACK, NLM_F_CREATE, NLM_F_DUMP,
NLM_F_ECHO, NLM_F_EXCL, NLM_F_MULTI, NLM_F_REQUEST, NLMSG_DONE, NLMSG_ERROR, RTM_DELTFILTER,
RTM_GETTFILTER, RTM_NEWQDISC, RTM_NEWTFILTER, RTM_SETLINK, SOCK_RAW, SOL_NETLINK, getsockname,
nlattr, nlmsgerr, nlmsghdr, recv, send, setsockopt, sockaddr_nl, socket,
};
use thiserror::Error;
use crate::{
Pod,
programs::TcAttachType,
util::{bytes_of, tc_handler_make},
};
// Compile-time guard backing the `as _` conversion inside `nla_align!`: the alignment constant
// must be below `u8::MAX`.
const _: () = assert!(NLA_ALIGNTO < u8::MAX as i32);

/// Rounds `$len` up to the next multiple of `NLA_ALIGNTO`.
macro_rules! nla_align {
    ($len:expr) => {{
        #[expect(clippy::as_underscore, reason = "statically known to be less than u8::MAX")]
        let aligned = $len.next_multiple_of(NLA_ALIGNTO as _);
        aligned
    }};
}
/// Size in bytes of a netlink message header (`nlmsghdr`).
const NLMSG_HDR_LEN: usize = size_of::<nlmsghdr>();
/// `NLMSG_HDR_LEN` rounded up by `nla_align!`; offset at which a message payload is read.
const NLMSG_HDR_ALIGN_LEN: usize = nla_align!(NLMSG_HDR_LEN);
/// Size in bytes of a netlink attribute header (`nlattr`).
const NLA_HDR_LEN: usize = size_of::<nlattr>();
/// `NLA_HDR_LEN` rounded up by `nla_align!`; space an attribute header occupies when written.
const NLA_HDR_ALIGN_LEN: usize = nla_align!(NLA_HDR_LEN);
/// Largest `TCA_BPF_NAME` payload, in bytes, that `TcRequest`'s attribute buffer is sized for.
const CLS_BPF_NAME_LEN: usize = 256;
/// Number of attribute bytes a `TcRequest` must be able to hold.
///
/// The total mirrors the attributes emitted by `write_tc_attach_attrs`: a `TCA_KIND` string, the
/// nested `TCA_OPTIONS` header, and inside it the program fd, the program name (at most
/// `CLS_BPF_NAME_LEN` bytes) and the flags word.
const fn tc_request_attrs_size() -> usize {
    // TCA_KIND carrying the NUL-terminated string "bpf".
    let kind = NLA_HDR_ALIGN_LEN + nla_align!(c"bpf".to_bytes_with_nul().len());
    // Header of the nested TCA_OPTIONS attribute.
    let options_header = NLA_HDR_ALIGN_LEN;
    // TCA_BPF_FD.
    let prog_fd = NLA_HDR_ALIGN_LEN + nla_align!(size_of::<i32>());
    // TCA_BPF_NAME.
    let prog_name = NLA_HDR_ALIGN_LEN + nla_align!(CLS_BPF_NAME_LEN);
    // TCA_BPF_FLAGS.
    let prog_flags = NLA_HDR_ALIGN_LEN + nla_align!(size_of::<u32>());

    kind + options_header + prog_fd + prog_name + prog_flags
}

// Pins the computed size so that an accidental change to the layout is caught at compile time.
const _: () = assert!(tc_request_attrs_size() == 288);
/// Crate-internal netlink failure; wrapped by `NetlinkError` for the public API.
#[derive(Error, Debug)]
pub(crate) enum NetlinkErrorInternal {
/// An `NLMSG_ERROR` reply carrying a non-zero error code.
#[error("netlink error: {messages:?}")]
Error {
/// Strings taken from the `NLMSGERR_ATTR_MSG` attributes of the reply.
messages: Vec<CString>,
/// OS error built from the negated error code of the reply.
#[source]
source: io::Error,
},
/// A failure reported as a plain I/O error (system calls, malformed or oversized buffers).
#[error(transparent)]
IoError(#[from] io::Error),
/// A received attribute could not be parsed.
#[error(transparent)]
NlAttrError(#[from] NlAttrError),
}
/// Public, opaque wrapper around the crate-private `NetlinkErrorInternal`.
///
/// Display and source are forwarded to the wrapped error (`#[error(transparent)]`).
#[derive(Error, Debug)]
#[error(transparent)]
#[expect(
unnameable_types,
reason = "the internal error is crate-private but transparently wrapped"
)]
pub struct NetlinkError(#[from] NetlinkErrorInternal);
impl NetlinkError {
pub fn raw_os_error(&self) -> Option<i32> {
let Self(inner) = self;
match inner {
NetlinkErrorInternal::Error { source, .. } => source.raw_os_error(),
NetlinkErrorInternal::IoError(err) => err.raw_os_error(),
NetlinkErrorInternal::NlAttrError(err) => match err {
NlAttrError::BufferLength { .. }
| NlAttrError::HeaderLength { .. }
| NlAttrError::CStrFromBytesWithNul { .. } => None,
},
}
}
}
pub(crate) unsafe fn netlink_set_xdp_fd(
if_index: i32,
fd: Option<BorrowedFd<'_>>,
old_fd: Option<BorrowedFd<'_>>,
flags: u32,
) -> Result<(), NetlinkError> {
let sock = NetlinkSocket::open()?;
let mut req = unsafe { mem::zeroed::<Request>() };
let nlmsg_len = size_of::<nlmsghdr>() + size_of::<ifinfomsg>();
req.header = nlmsghdr {
nlmsg_len: nlmsg_len as u32,
nlmsg_flags: (NLM_F_REQUEST | NLM_F_ACK) as u16,
nlmsg_type: RTM_SETLINK,
nlmsg_pid: 0,
nlmsg_seq: 1,
};
req.if_info.ifi_family = AF_UNSPEC as u8;
req.if_info.ifi_index = if_index;
let attrs_buf = unsafe { request_attributes(&mut req, nlmsg_len) };
let mut attrs = NestedAttrs::new(attrs_buf, IFLA_XDP);
attrs
.write_attr(IFLA_XDP_FD as u16, fd.map_or(-1, |fd| fd.as_raw_fd()))
.map_err(|e| NetlinkError(NetlinkErrorInternal::IoError(e)))?;
if flags > 0 {
attrs
.write_attr(IFLA_XDP_FLAGS as u16, flags)
.map_err(|e| NetlinkError(NetlinkErrorInternal::IoError(e)))?;
}
if flags & XDP_FLAGS_REPLACE != 0 {
attrs
.write_attr(
IFLA_XDP_EXPECTED_FD as u16,
old_fd.map(|fd| fd.as_raw_fd()).unwrap(),
)
.map_err(|e| NetlinkError(NetlinkErrorInternal::IoError(e)))?;
}
let nla_len = attrs
.finish()
.map_err(|e| NetlinkError(NetlinkErrorInternal::IoError(e)))?;
req.header.nlmsg_len += nla_align!(nla_len) as u32;
sock.send(&bytes_of(&req)[..req.header.nlmsg_len as usize])?;
for msg in sock.recv() {
msg?;
}
Ok(())
}
/// Sends an `RTM_NEWQDISC` request creating a `clsact` qdisc on the interface `if_index`, then
/// drains the replies.
///
/// The request is flagged `NLM_F_EXCL | NLM_F_CREATE`.
///
/// # Errors
///
/// Returns an error if the socket cannot be opened, the attribute does not fit, sending fails,
/// or any reply is an error.
///
/// # Safety
///
/// NOTE(review): nothing in the body shows which invariant callers must uphold; confirm against
/// the call sites.
pub(crate) unsafe fn netlink_qdisc_add_clsact(if_index: i32) -> Result<(), NetlinkError> {
    let sock = NetlinkSocket::open()?;

    let base_len = size_of::<nlmsghdr>() + size_of::<tcmsg>();
    // Start from an all-zero request and fill in only the fields that matter.
    let mut req = unsafe { mem::zeroed::<TcRequest>() };
    req.header.nlmsg_len = base_len as u32;
    req.header.nlmsg_type = RTM_NEWQDISC;
    req.header.nlmsg_flags = (NLM_F_REQUEST | NLM_F_ACK | NLM_F_EXCL | NLM_F_CREATE) as u16;
    req.header.nlmsg_seq = 1;
    req.header.nlmsg_pid = 0;

    let info = &mut req.tc_info;
    info.tcm_family = AF_UNSPEC as u8;
    info.tcm_ifindex = if_index;
    info.tcm_handle = tc_handler_make(TC_H_CLSACT, TC_H_UNSPEC);
    info.tcm_parent = tc_handler_make(TC_H_CLSACT, TC_H_INGRESS);
    info.tcm_info = 0;

    // The only attribute is TCA_KIND with the NUL-terminated qdisc name.
    let kind = c"clsact".to_bytes_with_nul();
    let attrs_buf = unsafe { request_attributes(&mut req, base_len) };
    let written = match write_attr_bytes(attrs_buf, TCA_KIND as u16, kind) {
        Ok((_, written)) => written,
        Err(e) => return Err(NetlinkError(NetlinkErrorInternal::IoError(e))),
    };
    req.header.nlmsg_len += nla_align!(written) as u32;

    let request_len = req.header.nlmsg_len as usize;
    sock.send(&bytes_of(&req)[..request_len])?;

    // Consume every reply; the first error aborts.
    sock.recv().try_for_each(|reply| reply.map(drop))?;
    Ok(())
}
/// Writes the attributes of a cls_bpf attach request into `req` and grows `nlmsg_len` to match.
///
/// Emits `TCA_KIND = "bpf"` followed by a nested `TCA_OPTIONS` holding `TCA_BPF_FD`,
/// `TCA_BPF_NAME` and `TCA_BPF_FLAGS` (set to `TCA_BPF_FLAG_ACT_DIRECT`). `nlmsg_len` is the
/// length of the header and body already present in `req`.
///
/// # Errors
///
/// Fails when the attributes do not fit in the request's attribute buffer.
fn write_tc_attach_attrs(
    req: &mut TcRequest,
    nlmsg_len: usize,
    prog_fd: i32,
    prog_name: &[u8],
) -> io::Result<()> {
    let buf = unsafe { request_attributes(req, nlmsg_len) };

    let (buf, kind_written) =
        write_attr_bytes(buf, TCA_KIND as u16, c"bpf".to_bytes_with_nul())?;

    let options_written = {
        let mut nested = NestedAttrs::new(buf, TCA_OPTIONS as u16);
        nested.write_attr(TCA_BPF_FD as u16, prog_fd)?;
        nested.write_attr_bytes(TCA_BPF_NAME as u16, prog_name)?;
        nested.write_attr(TCA_BPF_FLAGS as u16, TCA_BPF_FLAG_ACT_DIRECT)?;
        nested.finish()?
    };

    let attrs_len = nla_align!(kind_written + options_written);
    req.header.nlmsg_len += attrs_len as u32;
    Ok(())
}
pub(crate) unsafe fn netlink_qdisc_attach(
if_index: i32,
attach_type: &TcAttachType,
prog_fd: BorrowedFd<'_>,
prog_name: &CStr,
priority: u16,
handle: u32,
create: bool,
) -> Result<(u16, u32), NetlinkError> {
let sock = NetlinkSocket::open()?;
let mut req = unsafe { mem::zeroed::<TcRequest>() };
let nlmsg_len = size_of::<nlmsghdr>() + size_of::<tcmsg>();
let request_flags = if create {
NLM_F_CREATE | NLM_F_EXCL
} else {
0
};
req.header = nlmsghdr {
nlmsg_len: nlmsg_len as u32,
nlmsg_flags: (NLM_F_REQUEST | NLM_F_ACK | NLM_F_ECHO | request_flags) as u16,
nlmsg_type: RTM_NEWTFILTER,
nlmsg_pid: 0,
nlmsg_seq: 1,
};
req.tc_info.tcm_family = AF_UNSPEC as u8;
req.tc_info.tcm_handle = handle; req.tc_info.tcm_ifindex = if_index;
req.tc_info.tcm_parent = attach_type.tc_parent();
req.tc_info.tcm_info = tc_handler_make(
u32::from(priority) << 16,
u32::from(htons(ETH_P_ALL as u16)),
);
write_tc_attach_attrs(
&mut req,
nlmsg_len,
prog_fd.as_raw_fd(),
prog_name.to_bytes_with_nul(),
)
.map_err(|e| NetlinkError(NetlinkErrorInternal::IoError(e)))?;
sock.send(&bytes_of(&req)[..req.header.nlmsg_len as usize])?;
let mut tc_msg: Vec<tcmsg> = Vec::new();
for msg in sock.recv() {
let msg = msg?;
if msg.header.nlmsg_type == RTM_NEWTFILTER {
tc_msg.push(unsafe { ptr::read_unaligned(msg.data.as_ptr().cast()) });
}
}
match tc_msg.as_slice() {
[] => Err(NetlinkError(NetlinkErrorInternal::IoError(
io::Error::other("no RTM_NEWTFILTER reply received, this is a bug in the kernel"),
))),
[tc_msg] => {
let priority = ((tc_msg.tcm_info & TC_H_MAJ_MASK) >> 16) as u16;
Ok((priority, tc_msg.tcm_handle))
}
_tc_msg => Err(NetlinkError(NetlinkErrorInternal::IoError(
io::Error::other(
"multiple RTM_NEWTFILTER replies received, this is a bug in the kernel",
),
))),
}
}
/// Sends an `RTM_DELTFILTER` request for the filter identified by `priority` and `handle` on
/// the interface `if_index`, then drains the replies.
///
/// # Errors
///
/// Returns an error if the socket cannot be opened, sending fails, or any reply is an error.
///
/// # Safety
///
/// NOTE(review): nothing in the body shows which invariant callers must uphold; confirm against
/// the call sites.
pub(crate) unsafe fn netlink_qdisc_detach(
    if_index: i32,
    attach_type: TcAttachType,
    priority: u16,
    handle: u32,
) -> Result<(), NetlinkError> {
    let sock = NetlinkSocket::open()?;

    // Start from an all-zero request; no attributes are attached.
    let mut req = unsafe { mem::zeroed::<TcRequest>() };

    let header = &mut req.header;
    header.nlmsg_len = (size_of::<nlmsghdr>() + size_of::<tcmsg>()) as u32;
    header.nlmsg_type = RTM_DELTFILTER;
    header.nlmsg_flags = (NLM_F_REQUEST | NLM_F_ACK) as u16;
    header.nlmsg_seq = 1;
    header.nlmsg_pid = 0;

    // Priority in the upper 16 bits, protocol (network byte order) in the lower 16 bits.
    let protocol = u32::from(htons(ETH_P_ALL as u16));
    let info = &mut req.tc_info;
    info.tcm_family = AF_UNSPEC as u8;
    info.tcm_ifindex = if_index;
    info.tcm_handle = handle;
    info.tcm_parent = attach_type.tc_parent();
    info.tcm_info = tc_handler_make(u32::from(priority) << 16, protocol);

    let request_len = req.header.nlmsg_len as usize;
    sock.send(&bytes_of(&req)[..request_len])?;

    // Consume every reply; the first error aborts.
    sock.recv().try_for_each(|reply| reply.map(drop))?;
    Ok(())
}
/// Dumps the filters of `attach_type` on interface `if_index` and lazily yields the
/// `(priority, handle)` of each filter whose `TCA_BPF_NAME` equals `name`.
///
/// The `RTM_GETTFILTER` dump request is sent eagerly on `sock`; replies are received and parsed
/// only as the returned iterator is advanced.
///
/// # Errors
///
/// The outer `Result` fails if the request cannot be sent. Each iterator item fails if a reply
/// is an error, if an `RTM_NEWTFILTER` payload is shorter than a `tcmsg`, or if an attribute
/// cannot be parsed.
pub(crate) fn netlink_find_filter_with_name(
sock: &NetlinkSocket,
if_index: i32,
attach_type: TcAttachType,
name: &CStr,
) -> Result<impl Iterator<Item = Result<(u16, u32), NetlinkError>>, NetlinkError> {
let mut req = unsafe { mem::zeroed::<TcRequest>() };
let nlmsg_len = size_of::<nlmsghdr>() + size_of::<tcmsg>();
req.header = nlmsghdr {
nlmsg_len: nlmsg_len as u32,
nlmsg_type: RTM_GETTFILTER,
nlmsg_flags: (NLM_F_REQUEST | NLM_F_DUMP) as u16,
nlmsg_pid: 0,
nlmsg_seq: 1,
};
req.tc_info.tcm_family = AF_UNSPEC as u8;
req.tc_info.tcm_handle = 0; req.tc_info.tcm_ifindex = if_index;
req.tc_info.tcm_parent = attach_type.tc_parent();
sock.send(&bytes_of(&req)[..req.header.nlmsg_len as usize])?;
let mut resp = sock.recv();
Ok(iter::from_fn(move || {
// Keep pulling replies until one produces a match or an error; `?` on `resp.next()` ends
// the iteration once the replies are exhausted.
loop {
let msg = resp.next()?;
// The closure returns `Ok(None)` for replies to skip, `Ok(Some(..))` for a match and
// `Err(..)` for a failure; `transpose` turns the latter two into an iterator item.
if let Some(result) = (|| {
let msg = msg?;
if msg.header.nlmsg_type != RTM_NEWTFILTER {
return Ok(None);
}
// The payload starts with a `tcmsg`; the attributes follow it.
let (tc_msg_buf, attrs_buf) = msg
.data
.split_at_checked(size_of::<tcmsg>())
.ok_or_else(|| {
NetlinkError(NetlinkErrorInternal::IoError(io::Error::other(
"RTM_NEWTFILTER payload smaller than tcmsg",
)))
})?;
let tc_msg: tcmsg = unsafe { ptr::read_unaligned(tc_msg_buf.as_ptr().cast()) };
// The priority lives in the upper 16 bits of `tcm_info`.
let priority = (tc_msg.tcm_info >> 16) as u16;
let mut filter = None;
// Only the `TCA_OPTIONS` attribute is inspected; its payload is itself a list of
// attributes in which `TCA_BPF_NAME` is looked up.
for opt in NlAttrsIterator::new(attrs_buf) {
let opt =
opt.map_err(|e| NetlinkError(NetlinkErrorInternal::NlAttrError(e)))?;
if opt.header.nla_type & NLA_TYPE_MASK as u16 != TCA_OPTIONS as u16 {
continue;
}
for opt in NlAttrsIterator::new(opt.data) {
let opt =
opt.map_err(|e| NetlinkError(NetlinkErrorInternal::NlAttrError(e)))?;
if opt.header.nla_type & NLA_TYPE_MASK as u16 != TCA_BPF_NAME as u16 {
continue;
}
let f_name = CStr::from_bytes_with_nul(opt.data)
.map_err(NlAttrError::CStrFromBytesWithNul)
.map_err(|e| NetlinkError(NetlinkErrorInternal::NlAttrError(e)))?;
if f_name != name {
continue;
}
filter = Some((priority, tc_msg.tcm_handle));
}
}
Ok(filter)
})()
.transpose()
{
break Some(result);
}
}
}))
}
/// Sends an `RTM_SETLINK` request that sets the `IFF_UP` flag on the interface `if_index`, then
/// drains the replies.
///
/// # Errors
///
/// Returns an error if the socket cannot be opened, sending fails, or any reply is an error.
///
/// # Safety
///
/// NOTE(review): nothing in the body shows which invariant callers must uphold; confirm against
/// the call sites.
#[doc(hidden)]
pub unsafe fn netlink_set_link_up(if_index: i32) -> Result<(), NetlinkError> {
    let sock = NetlinkSocket::open()?;

    // Start from an all-zero request; no attributes are attached.
    let mut req = unsafe { mem::zeroed::<Request>() };

    let header = &mut req.header;
    header.nlmsg_len = (size_of::<nlmsghdr>() + size_of::<ifinfomsg>()) as u32;
    header.nlmsg_type = RTM_SETLINK;
    header.nlmsg_flags = (NLM_F_REQUEST | NLM_F_ACK) as u16;
    header.nlmsg_seq = 1;
    header.nlmsg_pid = 0;

    // `ifi_change` selects the flag bits to modify and `ifi_flags` supplies their new value.
    let up = IFF_UP as u32;
    let link = &mut req.if_info;
    link.ifi_family = AF_UNSPEC as u8;
    link.ifi_index = if_index;
    link.ifi_flags = up;
    link.ifi_change = up;

    let request_len = req.header.nlmsg_len as usize;
    sock.send(&bytes_of(&req)[..request_len])?;

    // Consume every reply; the first error aborts.
    sock.recv().try_for_each(|reply| reply.map(drop))?;
    Ok(())
}
/// In-memory layout of a link request: header, `ifinfomsg` body, then room for attributes.
#[derive(Copy, Clone)]
#[repr(C)]
struct Request {
/// Netlink message header; callers send only the first `nlmsg_len` bytes of the request.
header: nlmsghdr,
/// Interface-level body of the request (index, flags, change mask).
if_info: ifinfomsg,
/// Space into which attributes are written through `request_attributes`.
attrs: [u8; 64],
}
// NOTE(review): `Pod`'s contract is defined outside this file; presumably it requires that the
// value may be viewed as plain bytes via `bytes_of` — confirm.
unsafe impl Pod for Request {}
/// In-memory layout of a traffic-control request: header, `tcmsg` body, then room for
/// attributes.
#[derive(Copy, Clone)]
#[repr(C)]
struct TcRequest {
/// Netlink message header; callers send only the first `nlmsg_len` bytes of the request.
header: nlmsghdr,
/// Traffic-control body of the request (interface, handle, parent, info).
tc_info: tcmsg,
/// Space into which attributes are written; sized by `tc_request_attrs_size`.
attrs: [u8; tc_request_attrs_size()],
}
// NOTE(review): `Pod`'s contract is defined outside this file; presumably it requires that the
// value may be viewed as plain bytes via `bytes_of` — confirm.
unsafe impl Pod for TcRequest {}
/// An open `NETLINK_ROUTE` socket.
pub(crate) struct NetlinkSocket {
/// Owned file descriptor of the socket.
sock: crate::MockableFd,
/// `nl_pid` reported by `getsockname` when the socket was opened; not read in this file.
_nl_pid: u32,
}
impl NetlinkSocket {
/// Opens a raw `NETLINK_ROUTE` socket, enables `NETLINK_EXT_ACK` and `NETLINK_CAP_ACK` on it,
/// and records the address the kernel assigned to it.
///
/// # Errors
///
/// Returns the last OS error if `socket`, either `setsockopt`, or `getsockname` fails.
pub(crate) fn open() -> Result<Self, NetlinkErrorInternal> {
let sock = unsafe { socket(AF_NETLINK, SOCK_RAW, NETLINK_ROUTE) };
if sock < 0 {
return Err(NetlinkErrorInternal::IoError(io::Error::last_os_error()));
}
// From here on the descriptor is owned by `sock`, so the early returns below do not leak it
// (assuming `MockableFd` closes on drop — TODO confirm).
let sock = unsafe { crate::MockableFd::from_raw_fd(sock) };
let enable = 1i32;
unsafe {
// Extended acknowledgements: error replies may carry attributes such as
// `NLMSGERR_ATTR_MSG`, which `recv` collects.
if setsockopt(
sock.as_raw_fd(),
SOL_NETLINK,
NETLINK_EXT_ACK,
ptr::from_ref(&enable).cast(),
size_of_val(&enable) as u32,
) < 0
{
return Err(NetlinkErrorInternal::IoError(io::Error::last_os_error()));
}
// Capped acknowledgements: per netlink(7) the kernel then omits the original request
// payload from error replies.
if setsockopt(
sock.as_raw_fd(),
SOL_NETLINK,
NETLINK_CAP_ACK,
ptr::from_ref(&enable).cast(),
size_of_val(&enable) as u32,
) < 0
{
return Err(NetlinkErrorInternal::IoError(io::Error::last_os_error()));
}
}
let mut addr = unsafe { mem::zeroed::<sockaddr_nl>() };
addr.nl_family = AF_NETLINK as u16;
let mut addr_len = size_of::<sockaddr_nl>() as u32;
// Ask the kernel for the socket's address in order to learn its `nl_pid`.
if unsafe {
getsockname(
sock.as_raw_fd(),
ptr::from_mut(&mut addr).cast(),
ptr::from_mut(&mut addr_len).cast(),
)
} < 0
{
return Err(NetlinkErrorInternal::IoError(io::Error::last_os_error()));
}
Ok(Self {
sock,
_nl_pid: addr.nl_pid,
})
}
/// Sends `msg` on the socket with a single `send` call.
///
/// # Errors
///
/// Returns the last OS error if `send` reports a failure. The number of bytes actually
/// written is not inspected.
fn send(&self, msg: &[u8]) -> Result<(), NetlinkErrorInternal> {
if unsafe { send(self.sock.as_raw_fd(), msg.as_ptr().cast(), msg.len(), 0) } < 0 {
return Err(NetlinkErrorInternal::IoError(io::Error::last_os_error()));
}
Ok(())
}
/// Returns a lazy iterator over the messages received on the socket.
///
/// Nothing is read until the iterator is advanced. Messages are parsed out of a 4096-byte
/// buffer that is refilled with `recv` whenever it has been fully consumed.
///
/// * `NLMSG_ERROR` with an error code of zero is skipped (netlink uses it as an
///   acknowledgement); with a non-zero code it is yielded as `NetlinkErrorInternal::Error`.
/// * `NLMSG_DONE` ends the iteration.
/// * Any other message is yielded as `Ok`.
///
/// The iteration also ends when the buffer is drained and the last parsed message did not
/// carry `NLM_F_MULTI`, or when `recv` returns zero bytes.
fn recv(&self) -> impl Iterator<Item = Result<NetlinkMessage, NetlinkErrorInternal>> {
let mut scratch = [0u8; 4096];
// `scratch[offset..len]` is the part of the last datagram that has not been parsed yet.
let mut len = 0;
let mut offset = 0;
// Starts out true so that the first call performs a `recv`.
let mut multipart = true;
iter::from_fn(move || {
// The inner closure yields `Result<Option<_>, _>` so that `?` can be used; `transpose`
// converts it to the `Option<Result<_, _>>` expected by `from_fn`.
(|| {
loop {
while offset < len {
let message = NetlinkMessage::read(&scratch[offset..len])?;
// Advance past the message before deciding what to do with it.
offset += nla_align!(message.header.nlmsg_len as usize);
multipart = message.header.nlmsg_flags & NLM_F_MULTI as u16 != 0;
return match i32::from(message.header.nlmsg_type) {
NLMSG_ERROR => {
// `NetlinkMessage::read` always sets `error` for `NLMSG_ERROR` messages.
let error = message.error.unwrap();
if error.error == 0 {
// Skip and parse the next message in the buffer, if any.
continue;
}
// Collect the human-readable messages attached to the error.
let mut messages = Vec::new();
for attr in NlAttrsIterator::new(&message.data) {
let attr = attr?;
if attr.header.nla_type & NLA_TYPE_MASK as u16
!= NLMSGERR_ATTR_MSG as u16
{
continue;
}
let message = CStr::from_bytes_with_nul(attr.data)
.map_err(NlAttrError::CStrFromBytesWithNul)?;
messages.push(message.to_owned());
}
// The reply carries a negated errno value.
let source = io::Error::from_raw_os_error(-error.error);
Err(NetlinkErrorInternal::Error { messages, source })
}
NLMSG_DONE => Ok(None),
_ => Ok(Some(message)),
};
}
// Buffer drained: stop unless the last message announced more parts.
if !multipart {
return Ok(None);
}
let recv_len = unsafe {
recv(
self.sock.as_raw_fd(),
scratch.as_mut_ptr().cast(),
scratch.len(),
0,
)
};
// A negative return value does not fit in `usize`; report it as the last OS error.
let recv_len = usize::try_from(recv_len).map_err(
|std::num::TryFromIntError { .. }| {
NetlinkErrorInternal::IoError(io::Error::last_os_error())
},
)?;
if recv_len == 0 {
return Ok(None);
}
len = recv_len;
offset = 0;
}
})()
.transpose()
})
}
}
/// A single netlink message copied out of a receive buffer.
struct NetlinkMessage {
/// Header of the message.
header: nlmsghdr,
/// Payload following the header; for `NLMSG_ERROR` messages, the bytes after the `nlmsgerr`.
data: Vec<u8>,
/// The decoded `nlmsgerr`; `Some` exactly when the message type is `NLMSG_ERROR`.
error: Option<nlmsgerr>,
}
impl NetlinkMessage {
fn read(buf: &[u8]) -> io::Result<Self> {
let header_buf = buf
.get(..NLMSG_HDR_LEN)
.ok_or_else(|| io::Error::other("buffer smaller than nlmsghdr"))?;
let header: nlmsghdr = unsafe { ptr::read_unaligned(header_buf.as_ptr().cast()) };
let msg_len = header.nlmsg_len as usize;
if msg_len < NLMSG_HDR_LEN {
return Err(io::Error::other("invalid nlmsg_len"));
}
let msg = buf
.get(..msg_len)
.ok_or_else(|| io::Error::other("invalid nlmsg_len"))?;
let data = msg
.get(NLMSG_HDR_ALIGN_LEN..)
.ok_or_else(|| io::Error::other("need more data"))?;
let (rest, error) = if header.nlmsg_type == NLMSG_ERROR as u16 {
let (err_buf, rest) = data
.split_at_checked(size_of::<nlmsgerr>())
.ok_or_else(|| io::Error::other("NLMSG_ERROR but not enough space for nlmsgerr"))?;
let err: nlmsgerr = unsafe { ptr::read_unaligned(err_buf.as_ptr().cast()) };
(rest, Some(err))
} else {
(data, None)
};
Ok(Self {
header,
data: rest.to_vec(),
error,
})
}
}
/// Converts a 16-bit value from host byte order to network (big-endian) byte order.
const fn htons(u: u16) -> u16 {
    // Lay the value out big-endian, then reinterpret those bytes in native order.
    u16::from_ne_bytes(u.to_be_bytes())
}
/// Incremental writer for a nested attribute: child attributes are appended first and the
/// enclosing header is written last, once the total length is known.
struct NestedAttrs<'a> {
/// Space reserved at the front of the buffer for the enclosing attribute header.
header_buf: &'a mut [u8],
/// Unused space left for further child attributes.
rest: &'a mut [u8],
/// Type of the enclosing attribute; `NLA_F_NESTED` is OR-ed in by `finish`.
top_attr_type: u16,
/// Bytes accounted for so far: the aligned header plus every child written.
nla_len: usize,
}
impl<'a> NestedAttrs<'a> {
    /// Starts a nested attribute of type `top_attr_type` in `buf`.
    ///
    /// The first `NLA_HDR_ALIGN_LEN` bytes are reserved for the enclosing header, which is only
    /// written by [`Self::finish`]. If `buf` is too small to hold that header, both halves are
    /// empty and every later write, including `finish`, fails.
    const fn new(buf: &'a mut [u8], top_attr_type: u16) -> Self {
        const fn empty() -> &'static mut [u8] {
            &mut []
        }
        let (header_buf, rest) = match buf.split_at_mut_checked(NLA_HDR_ALIGN_LEN) {
            Some(parts) => parts,
            None => (empty(), empty()),
        };
        Self {
            header_buf,
            rest,
            top_attr_type,
            nla_len: NLA_HDR_ALIGN_LEN,
        }
    }

    /// Appends a child attribute whose payload is the bytes of `value`.
    ///
    /// # Errors
    ///
    /// Fails when the remaining space is too small. On failure the remaining space is left
    /// empty, so later writes fail as well.
    fn write_attr<T: Pod>(&mut self, attr_type: u16, value: T) -> io::Result<()> {
        let buf = mem::take(&mut self.rest);
        let (rest, size) = write_attr(buf, attr_type, value)?;
        self.nla_len += size;
        self.rest = rest;
        Ok(())
    }

    /// Appends a child attribute whose payload is `value`.
    ///
    /// # Errors
    ///
    /// Fails when the remaining space is too small. On failure the remaining space is left
    /// empty, so later writes fail as well.
    fn write_attr_bytes(&mut self, attr_type: u16, value: &[u8]) -> io::Result<()> {
        let buf = mem::take(&mut self.rest);
        let (rest, size) = write_attr_bytes(buf, attr_type, value)?;
        self.nla_len += size;
        self.rest = rest;
        Ok(())
    }

    /// Writes the enclosing header and returns the total length of the nested attribute
    /// (header plus children).
    ///
    /// # Errors
    ///
    /// Fails when the total length does not fit in the 16-bit `nla_len` field, or when no space
    /// was reserved for the header.
    fn finish(self) -> io::Result<usize> {
        let Self {
            header_buf,
            rest: _,
            top_attr_type,
            nla_len,
        } = self;
        // `nla_len` is a `u16` on the wire; refuse to emit a silently truncated length.
        let wire_len = u16::try_from(nla_len).map_err(|std::num::TryFromIntError { .. }| {
            io::Error::other("nested attribute too large")
        })?;
        let attr = nlattr {
            nla_type: NLA_F_NESTED as u16 | top_attr_type,
            nla_len: wire_len,
        };
        let (_, header_len) = write_attr_header(header_buf, attr)?;
        debug_assert_eq!(header_len, NLA_HDR_ALIGN_LEN);
        Ok(nla_len)
    }
}
/// Writes an attribute of type `attr_type` whose payload is the raw bytes of `value`.
///
/// Returns the unused tail of `buf` together with the number of bytes consumed.
fn write_attr<T: Pod>(buf: &mut [u8], attr_type: u16, value: T) -> io::Result<(&mut [u8], usize)> {
    write_attr_bytes(buf, attr_type, bytes_of(&value))
}
/// Writes an attribute of type `attr_type` with payload `value` at the start of `buf`.
///
/// The header records the unpadded length (`NLA_HDR_LEN + value.len()`), while both header and
/// payload occupy their aligned sizes in `buf`. Returns the unused tail of `buf` together with
/// the number of bytes consumed, padding included.
///
/// # Errors
///
/// Fails when the attribute length does not fit in the 16-bit `nla_len` field, or when `buf` is
/// too small.
fn write_attr_bytes<'a>(
    buf: &'a mut [u8],
    attr_type: u16,
    value: &[u8],
) -> io::Result<(&'a mut [u8], usize)> {
    // `nla_len` is a `u16` on the wire; refuse payloads whose length cannot be represented
    // rather than emitting a header with a silently truncated length.
    let nla_len = NLA_HDR_LEN
        .checked_add(value.len())
        .and_then(|len| u16::try_from(len).ok())
        .ok_or_else(|| io::Error::other("attribute too large"))?;
    let attr = nlattr {
        nla_type: attr_type,
        nla_len,
    };
    let (buf, header_len) = write_attr_header(buf, attr)?;
    let (buf, value_len) = write_bytes(buf, value)?;
    Ok((buf, header_len + value_len))
}
// NOTE(review): `Pod`'s contract is defined outside this file; presumably it requires that the
// value may be viewed as plain bytes via `bytes_of` — confirm.
unsafe impl Pod for nlattr {}

/// Writes the attribute header `attr` at the start of `buf`.
///
/// Returns the unused tail of `buf` together with the number of bytes consumed, which is
/// always `NLA_HDR_ALIGN_LEN` (checked in debug builds).
fn write_attr_header(buf: &mut [u8], attr: nlattr) -> io::Result<(&mut [u8], usize)> {
    let header_bytes = bytes_of(&attr);
    let written = write_bytes(buf, header_bytes)?;
    debug_assert_eq!(written.1, NLA_HDR_ALIGN_LEN);
    Ok(written)
}
/// Copies `value` to the start of `buf` and reserves padding up to the attribute alignment.
///
/// Returns the unused tail of `buf` together with the padded length consumed. The padding
/// bytes themselves are left untouched.
///
/// # Errors
///
/// Fails with "no space left" when `buf` is shorter than the padded length.
fn write_bytes<'a>(buf: &'a mut [u8], value: &[u8]) -> io::Result<(&'a mut [u8], usize)> {
    let padded_len = nla_align!(value.len());
    match buf.split_at_mut_checked(padded_len) {
        Some((slot, tail)) => {
            // `value.len() <= padded_len`, so this split cannot go out of bounds.
            let (payload, _padding) = slot.split_at_mut(value.len());
            payload.copy_from_slice(value);
            Ok((tail, padded_len))
        }
        None => Err(io::Error::other("no space left")),
    }
}
/// Iterator over the attributes packed back to back in a byte buffer.
struct NlAttrsIterator<'a> {
/// Bytes that have not been parsed yet.
buf: &'a [u8],
}
impl<'a> NlAttrsIterator<'a> {
/// Creates an iterator over the attributes stored in `buf`.
const fn new(buf: &'a [u8]) -> Self {
Self { buf }
}
}
impl<'a> Iterator for NlAttrsIterator<'a> {
    type Item = Result<NlAttr<'a>, NlAttrError>;

    /// Parses the next attribute.
    ///
    /// Yields `None` once the buffer is empty. After an error has been yielded the buffer is
    /// left empty, so the following call yields `None`.
    fn next(&mut self) -> Option<Self::Item> {
        // Take the remaining bytes up front; they are only put back on success.
        let whole = mem::take(&mut self.buf);
        if whole.is_empty() {
            return None;
        }

        let (header_bytes, after_header) = match whole.split_at_checked(NLA_HDR_LEN) {
            Some(parts) => parts,
            None => {
                return Some(Err(NlAttrError::BufferLength {
                    size: whole.len(),
                    expected: NLA_HDR_LEN,
                }));
            }
        };
        let header: nlattr = unsafe { ptr::read_unaligned(header_bytes.as_ptr().cast()) };

        // `nla_len` counts the header as well, so it can never be smaller than the header.
        let declared_len = header.nla_len as usize;
        if declared_len < NLA_HDR_LEN {
            return Some(Err(NlAttrError::HeaderLength(declared_len)));
        }
        let payload_len = declared_len - NLA_HDR_LEN;

        // The next attribute starts after the padding that follows the payload.
        let padded_payload_len = nla_align!(declared_len) - NLA_HDR_LEN;
        let (padded_payload, tail) = match after_header.split_at_checked(padded_payload_len) {
            Some(parts) => parts,
            None => {
                return Some(Err(NlAttrError::BufferLength {
                    size: after_header.len(),
                    expected: padded_payload_len,
                }));
            }
        };

        self.buf = tail;
        Some(Ok(NlAttr {
            header,
            data: &padded_payload[..payload_len],
        }))
    }
}
/// A parsed attribute borrowing its payload from the buffer it was read from.
#[derive(Clone)]
struct NlAttr<'a> {
/// Attribute header as read from the buffer.
header: nlattr,
/// Payload of the attribute, without the header and without trailing padding.
data: &'a [u8],
}
/// Failure to parse a received attribute.
#[derive(Debug, Error, PartialEq, Eq)]
pub(crate) enum NlAttrError {
/// Only `size` bytes were available where `expected` bytes were required.
#[error("invalid buffer size `{size}`, expected `{expected}`")]
BufferLength { size: usize, expected: usize },
/// The header declares a length smaller than the header itself.
#[error("invalid nlattr header length `{0}`")]
HeaderLength(usize),
/// A payload expected to be a NUL-terminated string is not one.
#[error("invalid CStr from bytes with nul: {0}")]
CStrFromBytesWithNul(#[from] FromBytesWithNulError),
}
unsafe fn request_attributes<T>(req: &mut T, msg_len: usize) -> &mut [u8] {
let req: *mut _ = req;
let req: *mut u8 = req.cast();
let attrs_addr = unsafe { req.add(msg_len) };
let align_offset = attrs_addr.align_offset(NLMSG_ALIGNTO as usize);
let attrs_addr = unsafe { attrs_addr.add(align_offset) };
let len = size_of::<T>() - msg_len - align_offset;
unsafe { slice::from_raw_parts_mut(attrs_addr, len) }
}
// Unit tests for the attribute writers, the attribute iterator and the sizing of `TcRequest`.
#[cfg(test)]
mod tests {
use assert_matches::assert_matches;
use super::*;
// Writes two children into a nested attribute and checks the exact byte layout: the nested
// header, then each child's header immediately followed by its value.
#[test]
fn test_nested_attrs() {
let mut buf = [0; 64];
let mut attrs = NestedAttrs::new(&mut buf, IFLA_XDP);
attrs.write_attr(IFLA_XDP_FD as u16, 42u32).unwrap();
attrs
.write_attr(IFLA_XDP_EXPECTED_FD as u16, 24u32)
.unwrap();
let len = attrs.finish().unwrap() as u16;
// Three headers (one nested, two children) plus two `u32` payloads.
let nla_len = (NLA_HDR_LEN * 3 + size_of::<u32>() * 2) as u16;
assert_eq!(len, nla_len);
let attr: nlattr = unsafe { ptr::read_unaligned(buf.as_ptr().cast()) };
assert_eq!(attr.nla_type, NLA_F_NESTED as u16 | IFLA_XDP);
assert_eq!(attr.nla_len, nla_len);
let attr: nlattr = unsafe { ptr::read_unaligned(buf[NLA_HDR_LEN..].as_ptr().cast()) };
assert_eq!(attr.nla_type, IFLA_XDP_FD as u16);
assert_eq!(attr.nla_len, (NLA_HDR_LEN + size_of::<u32>()) as u16);
let fd: u32 = unsafe { ptr::read_unaligned(buf[NLA_HDR_LEN * 2..].as_ptr().cast()) };
assert_eq!(fd, 42);
let attr: nlattr = unsafe {
ptr::read_unaligned(buf[NLA_HDR_LEN * 2 + size_of::<u32>()..].as_ptr().cast())
};
assert_eq!(attr.nla_type, IFLA_XDP_EXPECTED_FD as u16);
assert_eq!(attr.nla_len, (NLA_HDR_LEN + size_of::<u32>()) as u16);
let fd: u32 = unsafe {
ptr::read_unaligned(buf[NLA_HDR_LEN * 3 + size_of::<u32>()..].as_ptr().cast())
};
assert_eq!(fd, 24);
}
// An empty buffer yields no attributes.
#[test]
fn test_nlattr_iterator_empty() {
let mut iter = NlAttrsIterator::new(&[]);
assert!(iter.next().is_none());
}
// A buffer holding exactly one attribute yields it and then ends.
#[test]
fn test_nlattr_iterator_one() {
let mut buf = [0; NLA_HDR_LEN + size_of::<u32>()];
let (_rest, _written) = write_attr(&mut buf, IFLA_XDP_FD as u16, 42u32).unwrap();
let mut iter = NlAttrsIterator::new(&buf);
let attr = iter.next().unwrap().unwrap();
assert_eq!(attr.header.nla_type, IFLA_XDP_FD as u16);
assert_eq!(attr.data.len(), size_of::<u32>());
assert_eq!(u32::from_ne_bytes(attr.data.try_into().unwrap()), 42);
assert!(iter.next().is_none());
}
// Two attributes written back to back are yielded in order.
#[test]
fn test_nlattr_iterator_many() {
let mut buf = [0; (NLA_HDR_LEN + size_of::<u32>()) * 2];
let (rest, _) = write_attr(&mut buf, IFLA_XDP_FD as u16, 42u32).unwrap();
let (_rest, _written) = write_attr(rest, IFLA_XDP_EXPECTED_FD as u16, 12u32).unwrap();
let mut iter = NlAttrsIterator::new(&buf);
let attr = iter.next().unwrap().unwrap();
assert_eq!(attr.header.nla_type, IFLA_XDP_FD as u16);
assert_eq!(attr.data.len(), size_of::<u32>());
assert_eq!(u32::from_ne_bytes(attr.data.try_into().unwrap()), 42);
let attr = iter.next().unwrap().unwrap();
assert_eq!(attr.header.nla_type, IFLA_XDP_EXPECTED_FD as u16);
assert_eq!(attr.data.len(), size_of::<u32>());
assert_eq!(u32::from_ne_bytes(attr.data.try_into().unwrap()), 12);
assert!(iter.next().is_none());
}
// A nested attribute is read back by iterating the outer buffer and then the outer payload.
#[test]
fn test_nlattr_iterator_nested() {
let mut buf = [0; 1024];
let mut options = NestedAttrs::new(&mut buf, TCA_OPTIONS as u16);
options.write_attr(TCA_BPF_FD as u16, 42).unwrap();
let name = CString::new("foo").unwrap();
options
.write_attr_bytes(TCA_BPF_NAME as u16, name.to_bytes_with_nul())
.unwrap();
options.finish().unwrap();
let mut iter = NlAttrsIterator::new(&buf);
let outer = iter.next().unwrap().unwrap();
assert_eq!(
outer.header.nla_type & NLA_TYPE_MASK as u16,
TCA_OPTIONS as u16
);
let mut iter = NlAttrsIterator::new(outer.data);
let inner = iter.next().unwrap().unwrap();
assert_eq!(
inner.header.nla_type & NLA_TYPE_MASK as u16,
TCA_BPF_FD as u16
);
let inner = iter.next().unwrap().unwrap();
assert_eq!(
inner.header.nla_type & NLA_TYPE_MASK as u16,
TCA_BPF_NAME as u16
);
let name = CStr::from_bytes_with_nul(inner.data).unwrap();
assert_eq!(name.to_str().unwrap(), "foo");
}
// Helper: builds a zeroed `TcRequest` and writes the attach attributes with the given name.
fn tc_request(name: &[u8]) -> io::Result<()> {
let mut req = unsafe { mem::zeroed::<TcRequest>() };
let nlmsg_len = size_of::<nlmsghdr>() + size_of::<tcmsg>();
req.header.nlmsg_len = nlmsg_len as u32;
write_tc_attach_attrs(&mut req, nlmsg_len, 0, name)
}
// A name of exactly `CLS_BPF_NAME_LEN` bytes fits in the request.
#[test]
fn tc_request_fits_max_length_name() {
assert_matches!(tc_request(&[b'a'; CLS_BPF_NAME_LEN]), Ok(()));
}
// One byte more than `CLS_BPF_NAME_LEN` no longer fits and is reported as an error.
#[test]
fn tc_request_rejects_oversized_name() {
assert_matches!(
tc_request(&[b'a'; CLS_BPF_NAME_LEN + 1]),
Err(err) => {
assert_eq!(err.kind(), io::ErrorKind::Other);
assert_eq!(err.to_string(), "no space left");
}
);
}
}