netlink_socket2/
lib.rs

1#![allow(clippy::doc_lazy_continuation)]
2#![doc = include_str!("../README.md")]
3
4use std::{
5    collections::{hash_map::Entry, HashMap},
6    io::{self, ErrorKind, IoSlice},
7    marker::PhantomData,
8    os::fd::{AsRawFd, FromRawFd, OwnedFd},
9    sync::Arc,
10};
11
12#[cfg(not(feature = "async"))]
13use std::{
14    io::{Read, Write},
15    net::TcpStream as Socket,
16};
17
18#[cfg(feature = "tokio")]
19use tokio::net::TcpStream as Socket;
20
21#[cfg(feature = "smol")]
22use smol::{
23    io::{AsyncReadExt, AsyncWriteExt},
24    net::TcpStream as Socket,
25};
26
27use netlink_bindings::{
28    builtin::PushNlmsghdr,
29    nlctrl,
30    traits::{NetlinkRequest, Protocol},
31    utils,
32};
33
34mod chained;
35mod error;
36
37pub use chained::NetlinkReplyChained;
38pub use error::ReplyError;
39
/// Size in bytes of the buffer each netlink datagram is read into.
///
/// Netlink documentation recommends max(8192, page_size); this crate always uses a
/// fixed 8 KiB, independent of the system page size.
pub const RECV_BUF_SIZE: usize = 8 * 1024;
42
43pub struct NetlinkSocket {
44    buf: Arc<[u8; RECV_BUF_SIZE]>,
45    cache: HashMap<&'static [u8], u16>,
46    sock: HashMap<u16, Socket>,
47    seq: u32,
48}
49
50impl NetlinkSocket {
51    #[allow(clippy::new_without_default)]
52    pub fn new() -> Self {
53        Self {
54            buf: Arc::new([0u8; RECV_BUF_SIZE]),
55            cache: HashMap::default(),
56            sock: HashMap::new(),
57            seq: 1,
58        }
59    }
60
61    fn get_socket_cached(
62        cache: &mut HashMap<u16, Socket>,
63        protonum: u16,
64    ) -> io::Result<&mut Socket> {
65        match cache.entry(protonum) {
66            Entry::Occupied(sock) => Ok(sock.into_mut()),
67            Entry::Vacant(ent) => {
68                let sock = Self::get_socket_new(protonum)?;
69                Ok(ent.insert(sock))
70            }
71        }
72    }
73
74    fn get_socket_new(family: u16) -> io::Result<Socket> {
75        let fd = unsafe {
76            libc::socket(
77                libc::AF_NETLINK,
78                libc::SOCK_RAW | libc::SOCK_CLOEXEC,
79                family as i32,
80            )
81        };
82        if fd < 0 {
83            return Err(io::Error::from_raw_os_error(-fd));
84        }
85        let fd = unsafe { OwnedFd::from_raw_fd(fd) };
86
87        // Enable extended attributes in libc::NLMSG_ERROR and libc::NLMSG_DONE
88        let res = unsafe {
89            libc::setsockopt(
90                fd.as_raw_fd(),
91                libc::SOL_NETLINK,
92                libc::NETLINK_EXT_ACK,
93                (&1u32) as *const u32 as *const libc::c_void,
94                4,
95            )
96        };
97        if res < 0 {
98            return Err(io::Error::from_raw_os_error(-res));
99        }
100
101        let sock: std::net::TcpStream = fd.into();
102
103        #[cfg(feature = "async")]
104        {
105            sock.set_nonblocking(true)?;
106            Socket::try_from(sock)
107        }
108
109        #[cfg(not(feature = "async"))]
110        Ok(sock)
111    }
112
113    /// Reserve a sequential chunk of `seq` values, so chained messages don't
114    /// get confused. A random `seq` number might be used just as well.
115    pub fn reserve_seq(&mut self, len: u32) -> u32 {
116        let seq = self.seq;
117        self.seq = self.seq.wrapping_add(len);
118        seq
119    }
120
121    #[cfg_attr(not(feature = "async"), maybe_async::maybe_async)]
122    pub async fn request<'sock, Request>(
123        &'sock mut self,
124        request: &Request,
125    ) -> io::Result<NetlinkReply<'sock, Request>>
126    where
127        Request: NetlinkRequest,
128    {
129        let (protonum, request_type) = match request.protocol() {
130            Protocol::Raw {
131                protonum,
132                request_type,
133            } => (protonum, request_type),
134            Protocol::Generic(name) => (libc::GENL_ID_CTRL as u16, self.resolve(name).await?),
135        };
136
137        self.request_raw(request, protonum, request_type).await
138    }
139
140    #[cfg_attr(not(feature = "async"), maybe_async::maybe_async)]
141    async fn resolve(&mut self, family_name: &'static [u8]) -> io::Result<u16> {
142        if let Some(id) = self.cache.get(family_name) {
143            return Ok(*id);
144        }
145
146        let mut request = nlctrl::Request::new().op_getfamily_do_request();
147        request.encode().push_family_name_bytes(family_name);
148
149        let Protocol::Raw {
150            protonum,
151            request_type,
152        } = request.protocol()
153        else {
154            unreachable!()
155        };
156        assert_eq!(protonum, libc::NETLINK_GENERIC as u16);
157        assert_eq!(request_type, libc::GENL_ID_CTRL as u16);
158
159        let mut iter = self.request_raw(&request, protonum, request_type).await?;
160        if let Some(reply) = iter.recv().await {
161            let Ok(id) = reply?.get_family_id() else {
162                return Err(ErrorKind::Unsupported.into());
163            };
164            self.cache.insert(family_name, id);
165            return Ok(id);
166        }
167
168        Err(ErrorKind::UnexpectedEof.into())
169    }
170
171    #[cfg_attr(not(feature = "async"), maybe_async::maybe_async)]
172    async fn request_raw<'sock, Request>(
173        &'sock mut self,
174        request: &Request,
175        protonum: u16,
176        request_type: u16,
177    ) -> io::Result<NetlinkReply<'sock, Request>>
178    where
179        Request: NetlinkRequest,
180    {
181        let seq = self.reserve_seq(1);
182        let sock = Self::get_socket_cached(&mut self.sock, protonum)?;
183
184        let mut header = PushNlmsghdr::new();
185        header.set_len(header.as_slice().len() as u32 + request.payload().len() as u32);
186        header.set_type(request_type);
187        header.set_flags(request.flags() | libc::NLM_F_REQUEST as u16 | libc::NLM_F_ACK as u16);
188        header.set_seq(seq);
189
190        Self::write_buf(
191            sock,
192            &[
193                IoSlice::new(header.as_slice()),
194                IoSlice::new(request.payload()),
195            ],
196        )
197        .await?;
198
199        Ok(NetlinkReply {
200            sock,
201            buf: &mut self.buf,
202            inner: NetlinkReplyInner {
203                buf_offset: 0,
204                buf_read: 0,
205            },
206            seq: header.seq(),
207            done: false,
208            phantom: PhantomData,
209        })
210    }
211
212    #[cfg_attr(not(feature = "async"), maybe_async::maybe_async)]
213    async fn write_buf(sock: &mut Socket, payload: &[IoSlice<'_>]) -> io::Result<()> {
214        loop {
215            #[cfg(not(feature = "tokio"))]
216            let res = sock.write_vectored(payload).await;
217
218            #[cfg(feature = "tokio")]
219            let res = loop {
220                // Some subsystems don't correctly implement io notifications, which tokio runtime
221                // expects to receive before doing any actual io, hence we instead always attempt an io
222                // operation first.
223                let res = sock.try_write_vectored(payload);
224                if matches!(&res, Err(err) if err.kind() == ErrorKind::WouldBlock) {
225                    sock.writable().await?;
226                    continue;
227                }
228                break res;
229            };
230
231            match res {
232                Ok(sent) if sent != payload.iter().map(|s| s.len()).sum() => {
233                    return Err(io::Error::other("Couldn't send the whole message"));
234                }
235                Ok(_) => return Ok(()),
236                Err(err) if err.kind() == ErrorKind::Interrupted => continue,
237                Err(err) => return Err(err),
238            }
239        }
240    }
241}
242
/// Parsing cursor over the datagram currently held in the receive buffer.
struct NetlinkReplyInner {
    /// Number of valid bytes placed in the buffer by the most recent read.
    buf_read: usize,
    /// Offset of the next unparsed message within the buffer.
    buf_offset: usize,
}
247
248impl NetlinkReplyInner {
249    #[cfg_attr(not(feature = "async"), maybe_async::maybe_async)]
250    async fn read_buf(sock: &mut Socket, buf: &mut [u8]) -> io::Result<usize> {
251        loop {
252            #[cfg(not(feature = "tokio"))]
253            let res = sock.read(&mut buf[..]).await;
254
255            #[cfg(feature = "tokio")]
256            let res = {
257                // Some subsystems don't correctly implement io notifications, which tokio
258                // runtime expects to receive before doing any actual io, hence we instead
259                // always attempt an io operation first.
260                let res = sock.try_read(&mut buf[..]);
261                if matches!(&res, Err(err) if err.kind() == ErrorKind::WouldBlock) {
262                    sock.readable().await?;
263                    continue;
264                }
265                res
266            };
267
268            match res {
269                Ok(read) => return Ok(read),
270                Err(err) if err.kind() == ErrorKind::Interrupted => continue,
271                Err(err) => return Err(err),
272            }
273        }
274    }
275
276    #[allow(clippy::type_complexity)]
277    #[cfg_attr(not(feature = "async"), maybe_async::maybe_async)]
278    pub async fn recv(
279        &mut self,
280        sock: &mut Socket,
281        buf: &mut [u8; RECV_BUF_SIZE],
282    ) -> io::Result<(u32, Result<(usize, usize), ReplyError>)> {
283        if self.buf_offset == self.buf_read {
284            self.buf_read = Self::read_buf(sock, &mut buf[..]).await?;
285            self.buf_offset = 0;
286        }
287
288        let packet = &buf[self.buf_offset..self.buf_read];
289
290        let too_short_err = || io::Error::other("Received packet is too short");
291
292        let Some(header) = packet.get(..PushNlmsghdr::len()) else {
293            return Err(too_short_err());
294        };
295        let header = PushNlmsghdr::new_from_slice(header).unwrap();
296
297        let payload_start = self.buf_offset + PushNlmsghdr::len();
298        self.buf_offset += header.get_len() as usize;
299
300        match header.get_type() as i32 {
301            libc::NLMSG_DONE | libc::NLMSG_ERROR => {
302                let Some(code) = packet.get(16..20) else {
303                    return Err(too_short_err());
304                };
305                let code = utils::parse_i32(code).unwrap();
306
307                let (echo_start, echo_end) =
308                    if code == 0 || header.get_type() == libc::NLMSG_DONE as u16 {
309                        (20, 20)
310                    } else {
311                        let Some(echo_header) = packet.get(20..(20 + PushNlmsghdr::len())) else {
312                            return Err(too_short_err());
313                        };
314                        let echo_header = PushNlmsghdr::new_from_slice(echo_header).unwrap();
315
316                        if echo_header.flags() & libc::NLM_F_CAPPED as u16 == 0 {
317                            let start = echo_header.get_len();
318                            if packet.len() < start as usize + 20 {
319                                return Err(too_short_err());
320                            }
321
322                            (20 + 16, 20 + start as usize)
323                        } else {
324                            let ext_ack_start = 20 + PushNlmsghdr::len();
325                            (ext_ack_start, ext_ack_start)
326                        }
327                    };
328
329                Ok((
330                    header.seq(),
331                    Err(ReplyError {
332                        code: io::Error::from_raw_os_error(-code),
333                        request_bounds: (echo_start as u32, echo_end as u32),
334                        ext_ack_bounds: (echo_end as u32, self.buf_offset as u32),
335                        reply_buf: None,
336                        chained_name: None,
337                        lookup: |_, _, _| Default::default(),
338                    }),
339                ))
340            }
341            libc::NLMSG_NOOP => Ok((
342                header.seq(),
343                Err(io::Error::other("Received NLMSG_NOOP").into()),
344            )),
345            libc::NLMSG_OVERRUN => Ok((
346                header.seq(),
347                Err(io::Error::other("Received NLMSG_OVERRUN").into()),
348            )),
349            _ => Ok((header.seq(), Ok((payload_start, self.buf_offset)))),
350        }
351    }
352}
353
354pub struct NetlinkReply<'sock, Request: NetlinkRequest> {
355    inner: NetlinkReplyInner,
356    sock: &'sock mut Socket,
357    buf: &'sock mut Arc<[u8; RECV_BUF_SIZE]>,
358    seq: u32,
359    done: bool,
360    phantom: PhantomData<Request>,
361}
362
363impl<Request: NetlinkRequest> NetlinkReply<'_, Request> {
364    #[cfg_attr(not(feature = "async"), maybe_async::maybe_async)]
365    pub async fn recv_one(&mut self) -> Result<Request::ReplyType<'_>, ReplyError> {
366        if let Some(res) = self.recv().await {
367            return res;
368        }
369        Err(io::Error::other("Reply didn't contain data").into())
370    }
371
372    #[cfg_attr(not(feature = "async"), maybe_async::maybe_async)]
373    pub async fn recv_ack(&mut self) -> Result<(), ReplyError> {
374        if let Some(res) = self.recv().await {
375            res?;
376            return Err(io::Error::other("Reply isn't just an ack").into());
377        }
378        Ok(())
379    }
380
381    #[cfg_attr(not(feature = "async"), maybe_async::maybe_async)]
382    pub async fn recv(&mut self) -> Option<Result<Request::ReplyType<'_>, ReplyError>> {
383        if self.done {
384            return None;
385        }
386
387        let buf = Arc::make_mut(self.buf);
388
389        loop {
390            match self.inner.recv(self.sock, buf).await {
391                Err(io_err) => {
392                    self.done = true;
393                    return Some(Err(io_err.into()));
394                }
395                Ok((seq, res)) => {
396                    if seq != self.seq {
397                        continue;
398                    }
399                    return match res {
400                        Ok((l, r)) => Some(Ok(Request::decode_reply(&self.buf[l..r]))),
401                        Err(mut err) => {
402                            self.done = true;
403                            if err.code.raw_os_error().unwrap() == 0 {
404                                None
405                            } else {
406                                if err.has_context() {
407                                    err.lookup = Request::lookup;
408                                    err.reply_buf = Some(self.buf.clone());
409                                }
410                                Some(Err(err))
411                            }
412                        }
413                    };
414                }
415            };
416        }
417    }
418}