1#![allow(clippy::doc_lazy_continuation)]
2#![doc = include_str!("../README.md")]
3
4use std::{
5 collections::{hash_map::Entry, HashMap},
6 io::{self, ErrorKind, IoSlice},
7 marker::PhantomData,
8 os::fd::{AsRawFd, FromRawFd, OwnedFd},
9 sync::Arc,
10};
11
12#[cfg(not(feature = "async"))]
13use std::{
14 io::{Read, Write},
15 net::TcpStream as Socket,
16};
17
18#[cfg(feature = "tokio")]
19use tokio::net::TcpStream as Socket;
20
21#[cfg(feature = "smol")]
22use smol::{
23 io::{AsyncReadExt, AsyncWriteExt},
24 net::TcpStream as Socket,
25};
26
27use netlink_bindings::{
28 builtin::PushNlmsghdr,
29 nlctrl,
30 traits::{NetlinkRequest, Protocol},
31 utils,
32};
33
34mod chained;
35mod error;
36
37pub use chained::NetlinkReplyChained;
38pub use error::ReplyError;
39
/// Size in bytes of the receive buffer owned by a [`NetlinkSocket`].
pub const RECV_BUF_SIZE: usize = 8 * 1024;
42
43pub struct NetlinkSocket {
44 buf: Arc<[u8; RECV_BUF_SIZE]>,
45 cache: HashMap<&'static [u8], u16>,
46 sock: HashMap<u16, Socket>,
47 seq: u32,
48}
49
50impl NetlinkSocket {
51 #[allow(clippy::new_without_default)]
52 pub fn new() -> Self {
53 Self {
54 buf: Arc::new([0u8; RECV_BUF_SIZE]),
55 cache: HashMap::default(),
56 sock: HashMap::new(),
57 seq: 1,
58 }
59 }
60
61 fn get_socket_cached(
62 cache: &mut HashMap<u16, Socket>,
63 protonum: u16,
64 ) -> io::Result<&mut Socket> {
65 match cache.entry(protonum) {
66 Entry::Occupied(sock) => Ok(sock.into_mut()),
67 Entry::Vacant(ent) => {
68 let sock = Self::get_socket_new(protonum)?;
69 Ok(ent.insert(sock))
70 }
71 }
72 }
73
74 fn get_socket_new(family: u16) -> io::Result<Socket> {
75 let fd = unsafe {
76 libc::socket(
77 libc::AF_NETLINK,
78 libc::SOCK_RAW | libc::SOCK_CLOEXEC,
79 family as i32,
80 )
81 };
82 if fd < 0 {
83 return Err(io::Error::from_raw_os_error(-fd));
84 }
85 let fd = unsafe { OwnedFd::from_raw_fd(fd) };
86
87 let res = unsafe {
89 libc::setsockopt(
90 fd.as_raw_fd(),
91 libc::SOL_NETLINK,
92 libc::NETLINK_EXT_ACK,
93 (&1u32) as *const u32 as *const libc::c_void,
94 4,
95 )
96 };
97 if res < 0 {
98 return Err(io::Error::from_raw_os_error(-res));
99 }
100
101 let sock: std::net::TcpStream = fd.into();
102
103 #[cfg(feature = "async")]
104 {
105 sock.set_nonblocking(true)?;
106 Socket::try_from(sock)
107 }
108
109 #[cfg(not(feature = "async"))]
110 Ok(sock)
111 }
112
113 pub fn reserve_seq(&mut self, len: u32) -> u32 {
116 let seq = self.seq;
117 self.seq = self.seq.wrapping_add(len);
118 seq
119 }
120
121 #[cfg_attr(not(feature = "async"), maybe_async::maybe_async)]
122 pub async fn request<'sock, Request>(
123 &'sock mut self,
124 request: &Request,
125 ) -> io::Result<NetlinkReply<'sock, Request>>
126 where
127 Request: NetlinkRequest,
128 {
129 let (protonum, request_type) = match request.protocol() {
130 Protocol::Raw {
131 protonum,
132 request_type,
133 } => (protonum, request_type),
134 Protocol::Generic(name) => (libc::GENL_ID_CTRL as u16, self.resolve(name).await?),
135 };
136
137 self.request_raw(request, protonum, request_type).await
138 }
139
140 #[cfg_attr(not(feature = "async"), maybe_async::maybe_async)]
141 async fn resolve(&mut self, family_name: &'static [u8]) -> io::Result<u16> {
142 if let Some(id) = self.cache.get(family_name) {
143 return Ok(*id);
144 }
145
146 let mut request = nlctrl::Request::new().op_getfamily_do_request();
147 request.encode().push_family_name_bytes(family_name);
148
149 let Protocol::Raw {
150 protonum,
151 request_type,
152 } = request.protocol()
153 else {
154 unreachable!()
155 };
156 assert_eq!(protonum, libc::NETLINK_GENERIC as u16);
157 assert_eq!(request_type, libc::GENL_ID_CTRL as u16);
158
159 let mut iter = self.request_raw(&request, protonum, request_type).await?;
160 if let Some(reply) = iter.recv().await {
161 let Ok(id) = reply?.get_family_id() else {
162 return Err(ErrorKind::Unsupported.into());
163 };
164 self.cache.insert(family_name, id);
165 return Ok(id);
166 }
167
168 Err(ErrorKind::UnexpectedEof.into())
169 }
170
171 #[cfg_attr(not(feature = "async"), maybe_async::maybe_async)]
172 async fn request_raw<'sock, Request>(
173 &'sock mut self,
174 request: &Request,
175 protonum: u16,
176 request_type: u16,
177 ) -> io::Result<NetlinkReply<'sock, Request>>
178 where
179 Request: NetlinkRequest,
180 {
181 let seq = self.reserve_seq(1);
182 let sock = Self::get_socket_cached(&mut self.sock, protonum)?;
183
184 let mut header = PushNlmsghdr::new();
185 header.set_len(header.as_slice().len() as u32 + request.payload().len() as u32);
186 header.set_type(request_type);
187 header.set_flags(request.flags() | libc::NLM_F_REQUEST as u16 | libc::NLM_F_ACK as u16);
188 header.set_seq(seq);
189
190 Self::write_buf(
191 sock,
192 &[
193 IoSlice::new(header.as_slice()),
194 IoSlice::new(request.payload()),
195 ],
196 )
197 .await?;
198
199 Ok(NetlinkReply {
200 sock,
201 buf: &mut self.buf,
202 inner: NetlinkReplyInner {
203 buf_offset: 0,
204 buf_read: 0,
205 },
206 seq: header.seq(),
207 done: false,
208 phantom: PhantomData,
209 })
210 }
211
212 #[cfg_attr(not(feature = "async"), maybe_async::maybe_async)]
213 async fn write_buf(sock: &mut Socket, payload: &[IoSlice<'_>]) -> io::Result<()> {
214 loop {
215 #[cfg(not(feature = "tokio"))]
216 let res = sock.write_vectored(payload).await;
217
218 #[cfg(feature = "tokio")]
219 let res = loop {
220 let res = sock.try_write_vectored(payload);
224 if matches!(&res, Err(err) if err.kind() == ErrorKind::WouldBlock) {
225 sock.writable().await?;
226 continue;
227 }
228 break res;
229 };
230
231 match res {
232 Ok(sent) if sent != payload.iter().map(|s| s.len()).sum() => {
233 return Err(io::Error::other("Couldn't send the whole message"));
234 }
235 Ok(_) => return Ok(()),
236 Err(err) if err.kind() == ErrorKind::Interrupted => continue,
237 Err(err) => return Err(err),
238 }
239 }
240 }
241}
242
/// Parsing cursor over the data most recently read into the receive buffer.
struct NetlinkReplyInner {
    /// Offset in the buffer of the next message that has not been parsed yet.
    buf_offset: usize,
    /// Number of bytes the last read placed into the buffer.
    buf_read: usize,
}
247
248impl NetlinkReplyInner {
249 #[cfg_attr(not(feature = "async"), maybe_async::maybe_async)]
250 async fn read_buf(sock: &mut Socket, buf: &mut [u8]) -> io::Result<usize> {
251 loop {
252 #[cfg(not(feature = "tokio"))]
253 let res = sock.read(&mut buf[..]).await;
254
255 #[cfg(feature = "tokio")]
256 let res = {
257 let res = sock.try_read(&mut buf[..]);
261 if matches!(&res, Err(err) if err.kind() == ErrorKind::WouldBlock) {
262 sock.readable().await?;
263 continue;
264 }
265 res
266 };
267
268 match res {
269 Ok(read) => return Ok(read),
270 Err(err) if err.kind() == ErrorKind::Interrupted => continue,
271 Err(err) => return Err(err),
272 }
273 }
274 }
275
276 #[allow(clippy::type_complexity)]
277 #[cfg_attr(not(feature = "async"), maybe_async::maybe_async)]
278 pub async fn recv(
279 &mut self,
280 sock: &mut Socket,
281 buf: &mut [u8; RECV_BUF_SIZE],
282 ) -> io::Result<(u32, Result<(usize, usize), ReplyError>)> {
283 if self.buf_offset == self.buf_read {
284 self.buf_read = Self::read_buf(sock, &mut buf[..]).await?;
285 self.buf_offset = 0;
286 }
287
288 let packet = &buf[self.buf_offset..self.buf_read];
289
290 let too_short_err = || io::Error::other("Received packet is too short");
291
292 let Some(header) = packet.get(..PushNlmsghdr::len()) else {
293 return Err(too_short_err());
294 };
295 let header = PushNlmsghdr::new_from_slice(header).unwrap();
296
297 let payload_start = self.buf_offset + PushNlmsghdr::len();
298 self.buf_offset += header.get_len() as usize;
299
300 match header.get_type() as i32 {
301 libc::NLMSG_DONE | libc::NLMSG_ERROR => {
302 let Some(code) = packet.get(16..20) else {
303 return Err(too_short_err());
304 };
305 let code = utils::parse_i32(code).unwrap();
306
307 let (echo_start, echo_end) =
308 if code == 0 || header.get_type() == libc::NLMSG_DONE as u16 {
309 (20, 20)
310 } else {
311 let Some(echo_header) = packet.get(20..(20 + PushNlmsghdr::len())) else {
312 return Err(too_short_err());
313 };
314 let echo_header = PushNlmsghdr::new_from_slice(echo_header).unwrap();
315
316 if echo_header.flags() & libc::NLM_F_CAPPED as u16 == 0 {
317 let start = echo_header.get_len();
318 if packet.len() < start as usize + 20 {
319 return Err(too_short_err());
320 }
321
322 (20 + 16, 20 + start as usize)
323 } else {
324 let ext_ack_start = 20 + PushNlmsghdr::len();
325 (ext_ack_start, ext_ack_start)
326 }
327 };
328
329 Ok((
330 header.seq(),
331 Err(ReplyError {
332 code: io::Error::from_raw_os_error(-code),
333 request_bounds: (echo_start as u32, echo_end as u32),
334 ext_ack_bounds: (echo_end as u32, self.buf_offset as u32),
335 reply_buf: None,
336 chained_name: None,
337 lookup: |_, _, _| Default::default(),
338 }),
339 ))
340 }
341 libc::NLMSG_NOOP => Ok((
342 header.seq(),
343 Err(io::Error::other("Received NLMSG_NOOP").into()),
344 )),
345 libc::NLMSG_OVERRUN => Ok((
346 header.seq(),
347 Err(io::Error::other("Received NLMSG_OVERRUN").into()),
348 )),
349 _ => Ok((header.seq(), Ok((payload_start, self.buf_offset)))),
350 }
351 }
352}
353
354pub struct NetlinkReply<'sock, Request: NetlinkRequest> {
355 inner: NetlinkReplyInner,
356 sock: &'sock mut Socket,
357 buf: &'sock mut Arc<[u8; RECV_BUF_SIZE]>,
358 seq: u32,
359 done: bool,
360 phantom: PhantomData<Request>,
361}
362
363impl<Request: NetlinkRequest> NetlinkReply<'_, Request> {
364 #[cfg_attr(not(feature = "async"), maybe_async::maybe_async)]
365 pub async fn recv_one(&mut self) -> Result<Request::ReplyType<'_>, ReplyError> {
366 if let Some(res) = self.recv().await {
367 return res;
368 }
369 Err(io::Error::other("Reply didn't contain data").into())
370 }
371
372 #[cfg_attr(not(feature = "async"), maybe_async::maybe_async)]
373 pub async fn recv_ack(&mut self) -> Result<(), ReplyError> {
374 if let Some(res) = self.recv().await {
375 res?;
376 return Err(io::Error::other("Reply isn't just an ack").into());
377 }
378 Ok(())
379 }
380
381 #[cfg_attr(not(feature = "async"), maybe_async::maybe_async)]
382 pub async fn recv(&mut self) -> Option<Result<Request::ReplyType<'_>, ReplyError>> {
383 if self.done {
384 return None;
385 }
386
387 let buf = Arc::make_mut(self.buf);
388
389 loop {
390 match self.inner.recv(self.sock, buf).await {
391 Err(io_err) => {
392 self.done = true;
393 return Some(Err(io_err.into()));
394 }
395 Ok((seq, res)) => {
396 if seq != self.seq {
397 continue;
398 }
399 return match res {
400 Ok((l, r)) => Some(Ok(Request::decode_reply(&self.buf[l..r]))),
401 Err(mut err) => {
402 self.done = true;
403 if err.code.raw_os_error().unwrap() == 0 {
404 None
405 } else {
406 if err.has_context() {
407 err.lookup = Request::lookup;
408 err.reply_buf = Some(self.buf.clone());
409 }
410 Some(Err(err))
411 }
412 }
413 };
414 }
415 };
416 }
417 }
418}