netlink_socket2/
chained.rs

1use std::{
2    fmt,
3    io::{self, IoSlice},
4    sync::Arc,
5};
6
7use netlink_bindings::traits::NetlinkChained;
8
9use crate::{NetlinkReplyInner, NetlinkSocket, ReplyError, Socket, RECV_BUF_SIZE};
10
11impl NetlinkSocket {
12    /// Execute a chained request (experimental)
13    ///
14    /// Some subsystems have special requirements for related requests,
15    /// expecting certain types of messages to be sent within a single write
16    /// operation. For example transactions in nftables subsystem.
17    ///
18    /// Chained requests currently don't support replies carrying data.
19    #[cfg_attr(not(feature = "async"), maybe_async::maybe_async)]
20    pub async fn request_chained<'a, Chained>(
21        &'a mut self,
22        request: &'a Chained,
23    ) -> io::Result<NetlinkReplyChained<'a>>
24    where
25        Chained: NetlinkChained,
26    {
27        let sock = Self::get_socket_cached(&mut self.sock, request.protonum())?;
28
29        Self::write_buf(sock, &[IoSlice::new(request.payload())]).await?;
30
31        Ok(NetlinkReplyChained {
32            sock,
33            buf: &mut self.buf,
34            request,
35            inner: NetlinkReplyInner {
36                buf_offset: 0,
37                buf_read: 0,
38            },
39            done: Bits::with_len(request.chain_len()),
40        })
41    }
42}
43
44pub struct NetlinkReplyChained<'sock> {
45    inner: NetlinkReplyInner,
46    request: &'sock dyn NetlinkChained,
47    sock: &'sock mut Socket,
48    buf: &'sock mut Arc<[u8; RECV_BUF_SIZE]>,
49    done: Bits,
50}
51
52impl NetlinkReplyChained<'_> {
53    #[cfg_attr(not(feature = "async"), maybe_async::maybe_async)]
54    pub async fn recv_all(&mut self) -> Result<(), ReplyError> {
55        while let Some(res) = self.recv().await {
56            res?;
57        }
58        Ok(())
59    }
60
61    #[cfg_attr(not(feature = "async"), maybe_async::maybe_async)]
62    pub async fn recv(&mut self) -> Option<Result<(), ReplyError>> {
63        if self.done.is_all() {
64            return None;
65        }
66
67        let buf = Arc::make_mut(self.buf);
68
69        loop {
70            match self.inner.recv(self.sock, buf).await {
71                Err(io_err) => {
72                    self.done.set_all();
73                    return Some(Err(io_err.into()));
74                }
75                Ok((seq, res)) => {
76                    let Some(index) = self.request.get_index(seq) else {
77                        continue;
78                    };
79                    match res {
80                        Ok(_) => return Some(Ok(())),
81                        Err(mut err) => {
82                            if err.code.raw_os_error().unwrap() == 0 {
83                                self.done.set(index);
84                                return Some(Ok(()));
85                            } else {
86                                self.done.set_all();
87                                err.chained_name = Some(self.request.name(index));
88                                if err.has_context() {
89                                    err.lookup = self.request.lookup(index);
90                                    err.reply_buf = Some(self.buf.clone());
91                                }
92                                return Some(Err(err));
93                            };
94                        }
95                    }
96                }
97            };
98        }
99    }
100}
101
/// Bitmap tracking which entries of a chained request have completed.
///
/// Each chain entry has one bit, and a set bit means the entry is done.
/// Bits beyond the chain length are padding, and `with_len` pre-sets them.
/// As a result, `is_all` becomes true exactly when every real entry is set.
#[derive(Clone)]
enum Bits {
    /// Used for chains shorter than 64 entries: a single word.
    Inline(u64),
    /// Used for chains of 64 or more entries: `len.div_ceil(64)` words.
    Vec(Vec<u64>),
}

impl fmt::Debug for Bits {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Unset bits are entries still waiting for a reply.
        let n = self.count_zeros();
        write!(f, "{n} replies pending")
    }
}

impl Bits {
    /// Creates a bitmap for `len` entries, all of them pending.
    fn with_len(len: usize) -> Self {
        if len < 64 {
            // Bits `len..64` are padding. For `len == 0` this yields an
            // all-set word, meaning there is nothing to wait for.
            return Self::Inline(u64::MAX << len);
        }

        let mut words = vec![0u64; len.div_ceil(64)];
        // Only a partially used last word has padding bits. When `len` is a
        // multiple of 64, `u64::MAX << 0` would wrongly mark every entry in
        // the last word as already done, so that case must be skipped.
        let tail = len % 64;
        if tail != 0 {
            if let Some(last) = words.last_mut() {
                *last |= u64::MAX << tail;
            }
        }
        Self::Vec(words)
    }

    /// Marks entry `index` as done.
    ///
    /// Callers are expected to pass an index below the chain length. With
    /// the `Vec` representation, an index past the last word panics.
    fn set(&mut self, index: usize) {
        match self {
            Self::Inline(w) => *w |= 1u64 << index,
            Self::Vec(bits) => bits[index / 64] |= 1u64 << (index % 64),
        }
    }

    /// Returns `true` once every entry (and all padding) is set.
    fn is_all(&self) -> bool {
        match self {
            Self::Inline(w) => *w == u64::MAX,
            Self::Vec(bits) => bits.iter().all(|w| *w == u64::MAX),
        }
    }

    /// Marks every entry as done.
    fn set_all(&mut self) {
        match self {
            Self::Inline(w) => *w = u64::MAX,
            Self::Vec(bits) => bits.iter_mut().for_each(|w| *w = u64::MAX),
        }
    }

    /// Counts the entries that are still pending.
    fn count_zeros(&self) -> usize {
        match self {
            Self::Inline(w) => w.count_zeros() as usize,
            Self::Vec(bits) => bits.iter().map(|s| s.count_zeros() as usize).sum(),
        }
    }
}