netlink_proto/
connection.rs

1// SPDX-License-Identifier: MIT
2
3use std::{
4    fmt::Debug,
5    io,
6    pin::Pin,
7    task::{Context, Poll},
8};
9
10use futures::{
11    channel::mpsc::{UnboundedReceiver, UnboundedSender},
12    Future, Sink, Stream,
13};
14use log::{error, warn};
15use netlink_packet_core::{
16    NetlinkDeserializable, NetlinkMessage, NetlinkPayload, NetlinkSerializable,
17};
18
19use crate::{
20    codecs::{NetlinkCodec, NetlinkMessageCodec},
21    framed::NetlinkFramed,
22    sys::{AsyncSocket, SocketAddr},
23    Protocol, Request, Response,
24};
25
/// Socket type used by [`Connection`] when its `S` parameter is omitted: the
/// tokio-based netlink socket when the `tokio_socket` feature is enabled.
#[cfg(feature = "tokio_socket")]
type DefaultSocket = netlink_sys::TokioSocket;

/// Without the `tokio_socket` feature there is no built-in socket, so the
/// default is the unit type.
// NOTE(review): presumably `()` does not implement `AsyncSocket`, so callers
// must name a concrete `S` in that configuration — confirm.
#[cfg(not(feature = "tokio_socket"))]
type DefaultSocket = ();
30
31/// Connection to a Netlink socket, running in the background.
32///
33/// [`ConnectionHandle`](struct.ConnectionHandle.html) are used to pass new
34/// requests to the `Connection`, that in turn, sends them through the netlink
35/// socket.
36pub struct Connection<T, S = DefaultSocket, C = NetlinkCodec>
37where
38    T: Debug + NetlinkSerializable + NetlinkDeserializable,
39{
40    socket: NetlinkFramed<T, S, C>,
41
42    protocol: Protocol<T, UnboundedSender<NetlinkMessage<T>>>,
43
44    /// Channel used by the user to pass requests to the connection.
45    requests_rx: Option<UnboundedReceiver<Request<T>>>,
46
47    /// Channel used to transmit to the ConnectionHandle the unsolicited
48    /// messages received from the socket (multicast messages for instance).
49    unsolicited_messages_tx:
50        Option<UnboundedSender<(NetlinkMessage<T>, SocketAddr)>>,
51
52    socket_closed: bool,
53
54    forward_noop: bool,
55    forward_done: bool,
56    forward_ack: bool,
57}
58
59impl<T, S, C> Connection<T, S, C>
60where
61    T: Debug + NetlinkSerializable + NetlinkDeserializable + Unpin,
62    S: AsyncSocket,
63    C: NetlinkMessageCodec,
64{
65    pub(crate) fn new(
66        requests_rx: UnboundedReceiver<Request<T>>,
67        unsolicited_messages_tx: UnboundedSender<(
68            NetlinkMessage<T>,
69            SocketAddr,
70        )>,
71        protocol: isize,
72    ) -> io::Result<Self> {
73        let socket = S::new(protocol)?;
74        Ok(Connection {
75            socket: NetlinkFramed::new(socket),
76            protocol: Protocol::new(),
77            requests_rx: Some(requests_rx),
78            unsolicited_messages_tx: Some(unsolicited_messages_tx),
79            socket_closed: false,
80            forward_noop: false,
81            forward_done: false,
82            forward_ack: false,
83        })
84    }
85
86    /// Whether [NetlinkPayload::Noop] should forwared to handler
87    pub fn set_forward_noop(&mut self, value: bool) {
88        self.forward_noop = value;
89    }
90
91    /// Whether [NetlinkPayload::Done] should forwared to handler
92    pub fn set_forward_done(&mut self, value: bool) {
93        self.forward_done = value;
94    }
95
96    /// Whether [NetlinkPayload::Ack] should forwared to handler
97    pub fn set_forward_ack(&mut self, value: bool) {
98        self.forward_ack = value;
99    }
100
101    pub fn socket_mut(&mut self) -> &mut S {
102        self.socket.get_mut()
103    }
104
105    pub fn poll_send_messages(&mut self, cx: &mut Context) {
106        trace!("poll_send_messages called");
107        let Connection {
108            ref mut socket,
109            ref mut protocol,
110            ..
111        } = self;
112        let mut socket = Pin::new(socket);
113
114        if !protocol.outgoing_messages.is_empty() {
115            trace!(
116                "found outgoing message to send checking if socket is ready"
117            );
118            match Pin::as_mut(&mut socket).poll_ready(cx) {
119                Poll::Ready(Err(e)) => {
120                    // Sink errors are usually not recoverable. The socket
121                    // probably shut down.
122                    warn!("netlink socket shut down: {:?}", e);
123                    self.socket_closed = true;
124                    return;
125                }
126                Poll::Pending => {
127                    trace!("poll is not ready, returning");
128                    return;
129                }
130                Poll::Ready(Ok(_)) => {}
131            }
132
133            let (mut message, addr) =
134                protocol.outgoing_messages.pop_front().unwrap();
135            message.finalize();
136
137            trace!("sending outgoing message");
138            if let Err(e) = Pin::as_mut(&mut socket).start_send((message, addr))
139            {
140                error!("failed to send message: {:?}", e);
141                self.socket_closed = true;
142                return;
143            }
144        }
145
146        trace!("poll_send_messages done");
147        self.poll_flush(cx)
148    }
149
150    pub fn poll_flush(&mut self, cx: &mut Context) {
151        trace!("poll_flush called");
152        if let Poll::Ready(Err(e)) = Pin::new(&mut self.socket).poll_flush(cx) {
153            warn!("error flushing netlink socket: {:?}", e);
154            self.socket_closed = true;
155        }
156    }
157
158    pub fn poll_read_messages(&mut self, cx: &mut Context) {
159        trace!("poll_read_messages called");
160        let mut socket = Pin::new(&mut self.socket);
161
162        loop {
163            trace!("polling socket");
164            match socket.as_mut().poll_next(cx) {
165                Poll::Ready(Some((message, addr))) => {
166                    trace!("read datagram from socket");
167                    self.protocol.handle_message(message, addr);
168                }
169                Poll::Ready(None) => {
170                    warn!("netlink socket stream shut down");
171                    self.socket_closed = true;
172                    return;
173                }
174                Poll::Pending => {
175                    trace!("no datagram read from socket");
176                    return;
177                }
178            }
179        }
180    }
181
182    pub fn poll_requests(&mut self, cx: &mut Context) {
183        trace!("poll_requests called");
184        if let Some(mut stream) = self.requests_rx.as_mut() {
185            loop {
186                match Pin::new(&mut stream).poll_next(cx) {
187                    Poll::Ready(Some(request)) => {
188                        self.protocol.request(request)
189                    }
190                    Poll::Ready(None) => break,
191                    Poll::Pending => return,
192                }
193            }
194            let _ = self.requests_rx.take();
195            trace!("no new requests to handle poll_requests done");
196        }
197    }
198
199    pub fn forward_unsolicited_messages(&mut self) {
200        if self.unsolicited_messages_tx.is_none() {
201            while let Some((message, source)) =
202                self.protocol.incoming_requests.pop_front()
203            {
204                warn!(
205                    "ignoring unsolicited message {:?} from {:?}",
206                    message, source
207                );
208            }
209            return;
210        }
211
212        trace!("forward_unsolicited_messages called");
213        let mut ready = false;
214
215        let Connection {
216            ref mut protocol,
217            ref mut unsolicited_messages_tx,
218            ..
219        } = self;
220
221        while let Some((message, source)) =
222            protocol.incoming_requests.pop_front()
223        {
224            if unsolicited_messages_tx
225                .as_mut()
226                .unwrap()
227                .unbounded_send((message, source))
228                .is_err()
229            {
230                // The channel is unbounded so the only error that can
231                // occur is that the channel is closed because the
232                // receiver was dropped
233                warn!("failed to forward message to connection handle: channel closed");
234                ready = true;
235                break;
236            }
237        }
238
239        // Rust 1.82 has Option::is_none_or() which can simplify
240        // below checks but that version is released on 2024 Oct. 17 which
241        // is not available on old OS like Ubuntu 24.04 LTS, RHEL 9 yet.
242        if ready
243            || self.unsolicited_messages_tx.as_ref().is_none()
244            || self.unsolicited_messages_tx.as_ref().map(|x| x.is_closed())
245                == Some(true)
246        {
247            // The channel is closed so we can drop the sender.
248            let _ = self.unsolicited_messages_tx.take();
249            // purge `protocol.incoming_requests`
250            self.forward_unsolicited_messages();
251        }
252
253        trace!("forward_unsolicited_messages done");
254    }
255
256    pub fn forward_responses(&mut self) {
257        trace!("forward_responses called");
258        let protocol = &mut self.protocol;
259
260        while let Some(response) = protocol.incoming_responses.pop_front() {
261            let Response {
262                message,
263                done,
264                metadata: tx,
265            } = response;
266            if done {
267                use NetlinkPayload::*;
268                match &message.payload {
269                    Noop => {
270                        if !self.forward_noop {
271                            trace!("Not forwarding Noop message to the handle");
272                            continue;
273                        }
274                    }
275                    // Since `self.protocol` set the `done` flag here,
276                    // we know it has already dropped the request and
277                    // its associated metadata, ie the UnboundedSender
278                    // used to forward messages back to the
279                    // ConnectionHandle. By just continuing we're
280                    // dropping the last instance of that sender,
281                    // hence closing the channel and signaling the
282                    // handle that no more messages are expected.
283                    Done(_) => {
284                        if !self.forward_done {
285                            trace!("Not forwarding Done message to the handle");
286                            continue;
287                        }
288                    }
289                    // I'm not sure how we should handle overrun messages
290                    Overrun(_) => unimplemented!("overrun is not handled yet"),
291                    // We need to forward error messages and messages
292                    // that are part of the netlink subprotocol,
293                    // because only the user knows how they want to
294                    // handle them.
295                    Error(err_msg) => {
296                        if err_msg.code.is_none() && !self.forward_ack {
297                            trace!("Not forwarding Ack message to the handle");
298                            continue;
299                        }
300                    }
301                    InnerMessage(_) => {}
302                    _ => {}
303                }
304            }
305
306            trace!("forwarding response to the handle");
307            if tx.unbounded_send(message).is_err() {
308                // With an unboundedsender, an error can
309                // only happen if the receiver is closed.
310                warn!("failed to forward response back to the handle");
311            }
312        }
313        trace!("forward_responses done");
314    }
315
316    pub fn should_shut_down(&self) -> bool {
317        self.socket_closed
318            || (self.unsolicited_messages_tx.is_none()
319                && self.requests_rx.is_none())
320    }
321}
322
323impl<T, S, C> Connection<T, S, C>
324where
325    T: Debug + NetlinkSerializable + NetlinkDeserializable + Unpin,
326    S: AsyncSocket,
327    C: NetlinkMessageCodec,
328{
329    pub(crate) fn from_socket(
330        requests_rx: UnboundedReceiver<Request<T>>,
331        unsolicited_messages_tx: UnboundedSender<(
332            NetlinkMessage<T>,
333            SocketAddr,
334        )>,
335        socket: S,
336    ) -> Self {
337        Connection {
338            socket: NetlinkFramed::new(socket),
339            protocol: Protocol::new(),
340            requests_rx: Some(requests_rx),
341            unsolicited_messages_tx: Some(unsolicited_messages_tx),
342            socket_closed: false,
343            forward_noop: false,
344            forward_done: false,
345            forward_ack: false,
346        }
347    }
348}
349
350impl<T, S, C> Future for Connection<T, S, C>
351where
352    T: Debug + NetlinkSerializable + NetlinkDeserializable + Unpin,
353    S: AsyncSocket,
354    C: NetlinkMessageCodec,
355{
356    type Output = ();
357
358    fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
359        trace!("polling Connection");
360        let pinned = self.get_mut();
361
362        trace!("reading incoming messages");
363        pinned.poll_read_messages(cx);
364
365        trace!("forwarding unsolicited messages to the connection handle");
366        pinned.forward_unsolicited_messages();
367
368        trace!(
369            "forwarding responses to previous requests to the connection handle"
370        );
371        pinned.forward_responses();
372
373        trace!("handling requests");
374        pinned.poll_requests(cx);
375
376        trace!("sending messages");
377        pinned.poll_send_messages(cx);
378
379        trace!("done polling Connection");
380
381        if pinned.should_shut_down() {
382            Poll::Ready(())
383        } else {
384            Poll::Pending
385        }
386    }
387}
388
#[cfg(all(test, feature = "tokio_socket"))]
mod tests {
    use crate::new_connection;
    use crate::sys::protocols::NETLINK_AUDIT;
    use netlink_packet_audit::AuditMessage;
    use tokio::time;

    /// Once both the handle and the unsolicited-message receiver are
    /// dropped, the background connection future must complete on its own.
    #[tokio::test]
    async fn connection_is_closed() {
        // The handle and the unsolicited messages receiver are matched by
        // `_`, so they are dropped at the end of this statement.
        let (conn, _, _) =
            new_connection::<AuditMessage>(NETLINK_AUDIT).unwrap();
        let join_handle = tokio::spawn(conn);
        // Await completion with a generous upper bound rather than checking
        // `is_finished()` after a fixed sleep, which is flaky on loaded
        // machines and always costs the full delay.
        time::timeout(time::Duration::from_secs(5), join_handle)
            .await
            .expect("connection did not shut down after its handles were dropped")
            .expect("connection task panicked");
    }
}