Skip to main content

mnl/
callback.rs

1use crate::NlMessages;
2use mnl_sys::{self, libc};
3
4use std::{io, ptr};
5
/// The result of processing a batch of netlink responses.
///
/// Produced by `cb_run` and `cb_run2` from libmnl's non-error status codes. It derives the
/// standard traits so callers can log it with `{:?}`, compare it with `==` and copy it freely.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CbResult {
    /// Everything went fine and this batch is finished processing
    /// (libmnl reported `MNL_CB_STOP`).
    Stop,
    /// Everything went fine, but we expect more messages to come back from the kernel for this
    /// batch.
    Ok,
}
14
15/// Callback function signature.
16/// TODO: Write abstraction for `nlmsghdr` that can reach all fields and payload.
17pub type Callback<T> = fn(msg: &libc::nlmsghdr, data: &mut T) -> libc::c_int;
18
19/// Callback runqueue for netlink messages. Checks that all netlink messages in `buffer` are OK.
20/// `buffer` must be aligned to `align_of::<nlmsghdr>()`, or this fails.
21pub fn cb_run(buffer: &[u8], seq: u32, portid: u32) -> io::Result<CbResult> {
22    // NOTE: See comment on [`validate_messages`] for why we need to validate messages here.
23    validate_messages(buffer)?;
24
25    let len = buffer.len();
26    let buf = buffer.as_ptr() as *const libc::c_void;
27
28    log::debug!("Processing {} byte netlink message without a callback", len);
29    match unsafe { mnl_sys::mnl_cb_run(buf, len, seq, portid, None, ptr::null_mut()) } {
30        i if i <= mnl_sys::MNL_CB_ERROR => Err(io::Error::last_os_error()),
31        mnl_sys::MNL_CB_STOP => Ok(CbResult::Stop),
32        _ => Ok(CbResult::Ok),
33    }
34}
35
36/// Callback runqueue for netlink messages. Checks that all netlink messages in `buffer` are OK.
37/// Calls the given `callback` if needed.
38/// `buffer` must be aligned to `align_of::<nlmsghdr>()`, or this fails.
39pub fn cb_run2<T>(
40    buffer: &[u8],
41    seq: u32,
42    portid: u32,
43    callback: Callback<T>,
44    data: &mut T,
45) -> io::Result<CbResult> {
46    // NOTE: See comment on [`validate_messages`] for why we need to validate messages here.
47    validate_messages(buffer)?;
48
49    let len = buffer.len();
50    let buf = buffer.as_ptr() as *const libc::c_void;
51    let mut callback_context = CallbackContext { callback, data };
52    log::debug!("Processing {} byte netlink message with callback", len);
53    match unsafe {
54        mnl_sys::mnl_cb_run(
55            buf,
56            len,
57            seq,
58            portid,
59            Some(callback_wrapper::<T>),
60            &mut callback_context as *mut _ as *mut libc::c_void,
61        )
62    } {
63        i if i <= mnl_sys::MNL_CB_ERROR => Err(io::Error::last_os_error()),
64        mnl_sys::MNL_CB_STOP => Ok(CbResult::Stop),
65        _ => Ok(CbResult::Ok),
66    }
67}
68
69/// libmnl contains a bug in `mnl_nlmsg_ok` where it casts `nlh->nlmsg_len` to an `int`,
70/// i.e. `(int)nlh->nlmsg_len`. This becomes negative if `nlmsg_len` is greater than `INT_MAX`,
71/// causing the validation to succeed even if the buffer is too small. `mnl_nlmsg_ok` is
72/// used by `mnl_cb_run` and `mnl_cb_run2`.
73///
74/// This was fixed on 2023-11-04 in commit `754c9de5ea1bea821495523cf01989299552e524`,
75/// but the latest version of libmnl 1.0.5 was released on 2022-04-05, so as of writing
76/// there is no released version of libmnl that contains the fix.
77///
78/// Thus we need our own validation.
79///
80/// See the libmnl git repo and that commit for details: git://git.netfilter.org/libmnl
81///
82/// This addresses [RUSTSEC-2025-0142](https://rustsec.org/advisories/RUSTSEC-2025-0142.html).
83fn validate_messages(buffer: &[u8]) -> io::Result<()> {
84    NlMessages::new(buffer).try_for_each(|msg| {
85        msg?;
86        Ok(())
87    })
88}
89
90/// Internal struct for helping to convert the unsafe FFI callback to the safe `Callback`.
91struct CallbackContext<'a, T> {
92    pub callback: Callback<T>,
93    pub data: &'a mut T,
94}
95
96/// Internal FFI callback converting the callback from libmnl into a `Callback<T>` callback.
97extern "C" fn callback_wrapper<T>(
98    nlh: *const libc::nlmsghdr,
99    data: *mut libc::c_void,
100) -> libc::c_int {
101    let context: &mut CallbackContext<'_, T> =
102        unsafe { &mut *(data as *mut CallbackContext<'_, T>) };
103    (context.callback)(unsafe { &*nlh }, context.data)
104}