netlink_socket2/
chained.rs1use std::{
2 fmt,
3 io::{self, IoSlice},
4 sync::Arc,
5};
6
7use netlink_bindings::traits::NetlinkChained;
8
9use crate::{NetlinkReplyInner, NetlinkSocket, ReplyError, Socket, RECV_BUF_SIZE};
10
11impl NetlinkSocket {
12 #[cfg_attr(not(feature = "async"), maybe_async::maybe_async)]
20 pub async fn request_chained<'a, Chained>(
21 &'a mut self,
22 request: &'a Chained,
23 ) -> io::Result<NetlinkReplyChained<'a>>
24 where
25 Chained: NetlinkChained,
26 {
27 let sock = Self::get_socket_cached(&mut self.sock, request.protonum())?;
28
29 Self::write_buf(sock, &[IoSlice::new(request.payload())]).await?;
30
31 Ok(NetlinkReplyChained {
32 sock,
33 buf: &mut self.buf,
34 request,
35 inner: NetlinkReplyInner {
36 buf_offset: 0,
37 buf_read: 0,
38 },
39 done: Bits::with_len(request.chain_len()),
40 })
41 }
42}
43
44pub struct NetlinkReplyChained<'sock> {
45 inner: NetlinkReplyInner,
46 request: &'sock dyn NetlinkChained,
47 sock: &'sock mut Socket,
48 buf: &'sock mut Arc<[u8; RECV_BUF_SIZE]>,
49 done: Bits,
50}
51
52impl NetlinkReplyChained<'_> {
53 #[cfg_attr(not(feature = "async"), maybe_async::maybe_async)]
54 pub async fn recv_all(&mut self) -> Result<(), ReplyError> {
55 while let Some(res) = self.recv().await {
56 res?;
57 }
58 Ok(())
59 }
60
61 #[cfg_attr(not(feature = "async"), maybe_async::maybe_async)]
62 pub async fn recv(&mut self) -> Option<Result<(), ReplyError>> {
63 if self.done.is_all() {
64 return None;
65 }
66
67 let buf = Arc::make_mut(self.buf);
68
69 loop {
70 match self.inner.recv(self.sock, buf).await {
71 Err(io_err) => {
72 self.done.set_all();
73 return Some(Err(io_err.into()));
74 }
75 Ok((seq, res)) => {
76 let Some(index) = self.request.get_index(seq) else {
77 continue;
78 };
79 match res {
80 Ok(_) => return Some(Ok(())),
81 Err(mut err) => {
82 if err.code.raw_os_error().unwrap() == 0 {
83 self.done.set(index);
84 return Some(Ok(()));
85 } else {
86 self.done.set_all();
87 err.chained_name = Some(self.request.name(index));
88 if err.has_context() {
89 err.lookup = self.request.lookup(index);
90 err.reply_buf = Some(self.buf.clone());
91 }
92 return Some(Err(err));
93 };
94 }
95 }
96 }
97 };
98 }
99 }
100}
101
/// Fixed-length bit set recording which entries of a chained request are done.
///
/// A set bit means "done". Bits past the tracked length are kept set, so
/// [`Bits::is_all`] and [`Bits::count_zeros`] only reflect real entries.
/// Lengths below 64 use a single inline word; longer chains use a `Vec`.
#[derive(Clone)]
enum Bits {
    Inline(u64),
    Vec(Vec<u64>),
}

impl fmt::Debug for Bits {
    /// Formats as the number of entries that are still pending.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let n = self.count_zeros();
        write!(f, "{n} replies pending")
    }
}

impl Bits {
    /// Creates a set tracking `len` entries, all initially pending.
    ///
    /// `len == 0` yields a set that is already complete.
    fn with_len(len: usize) -> Self {
        if len < 64 {
            // `len` is in 0..64, so the shift is in range; bits >= len are padding.
            Self::Inline(u64::MAX << len)
        } else {
            let mut words = vec![0u64; len.div_ceil(64)];
            let used = len % 64;
            // Pad only a partially used last word. When `len` is an exact
            // multiple of 64, `u64::MAX << 0` would mark the final 64 real
            // entries as already done.
            if used != 0 {
                if let Some(last) = words.last_mut() {
                    *last |= u64::MAX << used;
                }
            }
            Self::Vec(words)
        }
    }

    /// Marks entry `index` as done.
    fn set(&mut self, index: usize) {
        match self {
            Self::Inline(w) => *w |= 1u64 << index,
            Self::Vec(bits) => bits[index / 64] |= 1u64 << (index % 64),
        }
    }

    /// Returns `true` when no entries are pending.
    fn is_all(&self) -> bool {
        match self {
            Self::Inline(w) => *w == u64::MAX,
            Self::Vec(bits) => bits.iter().all(|w| *w == u64::MAX),
        }
    }

    /// Marks every entry as done.
    fn set_all(&mut self) {
        match self {
            Self::Inline(w) => *w = u64::MAX,
            Self::Vec(bits) => bits.iter_mut().for_each(|w| *w = u64::MAX),
        }
    }

    /// Number of entries still pending.
    fn count_zeros(&self) -> usize {
        match self {
            Self::Inline(w) => w.count_zeros() as usize,
            Self::Vec(bits) => bits.iter().map(|s| s.count_zeros() as usize).sum(),
        }
    }
}