netlink_proto/
framed.rs

1// SPDX-License-Identifier: MIT
2
3use bytes::BytesMut;
4use std::{
5    fmt::Debug,
6    io,
7    marker::PhantomData,
8    mem::size_of,
9    pin::Pin,
10    task::{Context, Poll},
11};
12
13use futures::{Sink, Stream};
14use log::error;
15
16use crate::{
17    codecs::NetlinkMessageCodec,
18    sys::{AsyncSocket, SocketAddr},
19};
20use netlink_packet_core::{
21    NetlinkDeserializable, NetlinkHeader, NetlinkMessage, NetlinkPayload,
22    NetlinkSerializable, NLMSG_OVERRUN,
23};
24
// `ENOBUFS` ("No buffer space available"): the errno returned by a netlink
// `recv` once the socket's receive buffer has overflowed.
//
// Linux errno numbers are architecture-specific. Most architectures use the
// generic value 105, but MIPS uses 132 and SPARC uses 55, so a single
// hard-coded 105 would make overrun detection silently fail there.
#[cfg(any(
    target_arch = "mips",
    target_arch = "mips64",
    target_arch = "mips32r6",
    target_arch = "mips64r6"
))]
const ENOBUFS: i32 = 132;
#[cfg(any(target_arch = "sparc", target_arch = "sparc64"))]
const ENOBUFS: i32 = 55;
#[cfg(not(any(
    target_arch = "mips",
    target_arch = "mips64",
    target_arch = "mips32r6",
    target_arch = "mips64r6",
    target_arch = "sparc",
    target_arch = "sparc64"
)))]
const ENOBUFS: i32 = 105;
27
28pub struct NetlinkFramed<T, S, C> {
29    socket: S,
30    // see https://doc.rust-lang.org/nomicon/phantom-data.html
31    // "invariant" seems like the safe choice; using `fn(T) -> T`
32    // should make it invariant but still Send+Sync.
33    msg_type: PhantomData<fn(T) -> T>, // invariant
34    codec: PhantomData<fn(C) -> C>,    // invariant
35    reader: BytesMut,
36    writer: BytesMut,
37    in_addr: SocketAddr,
38    out_addr: SocketAddr,
39    flushed: bool,
40}
41
42impl<T, S, C> Stream for NetlinkFramed<T, S, C>
43where
44    T: NetlinkDeserializable + Debug,
45    S: AsyncSocket,
46    C: NetlinkMessageCodec,
47{
48    type Item = (NetlinkMessage<T>, SocketAddr);
49
50    fn poll_next(
51        self: Pin<&mut Self>,
52        cx: &mut Context<'_>,
53    ) -> Poll<Option<Self::Item>> {
54        let Self {
55            ref mut socket,
56            ref mut in_addr,
57            ref mut reader,
58            ..
59        } = Pin::get_mut(self);
60
61        loop {
62            match C::decode::<T>(reader) {
63                Ok(Some(item)) => return Poll::Ready(Some((item, *in_addr))),
64                Ok(None) => {}
65                Err(e) => {
66                    error!("unrecoverable error in decoder: {:?}", e);
67                    return Poll::Ready(None);
68                }
69            }
70
71            reader.clear();
72            reader.reserve(INITIAL_READER_CAPACITY);
73
74            *in_addr = match ready!(socket.poll_recv_from(cx, reader)) {
75                Ok(addr) => addr,
76                // When receiving messages in multicast mode (i.e. we subscribed
77                // to notifications), the kernel will not wait
78                // for us to read datagrams before sending more.
79                // The receive buffer has a finite size, so once it is full (no
80                // more message can fit in), new messages will be dropped and
81                // recv calls will return `ENOBUFS`.
82                // This needs to be handled for applications to resynchronize
83                // with the contents of the kernel if necessary.
84                // We don't need to do anything special:
85                // - contents of the reader is still valid because we won't have
86                //   partial messages in there anyways (large enough buffer)
87                // - contents of the socket's internal buffer is still valid
88                //   because the kernel won't put partial data in it
89                Err(e) if e.raw_os_error() == Some(ENOBUFS) => {
90                    warn!("netlink socket buffer full");
91                    let mut hdr = NetlinkHeader::default();
92                    hdr.length = size_of::<NetlinkHeader>() as u32;
93                    hdr.message_type = NLMSG_OVERRUN;
94                    let msg = NetlinkMessage::new(
95                        hdr,
96                        NetlinkPayload::Overrun(Vec::new()),
97                    );
98                    return Poll::Ready(Some((msg, SocketAddr::new(0, 0))));
99                }
100                Err(e) => {
101                    error!("failed to read from netlink socket: {:?}", e);
102                    return Poll::Ready(None);
103                }
104            };
105        }
106    }
107}
108
109impl<T, S, C> Sink<(NetlinkMessage<T>, SocketAddr)> for NetlinkFramed<T, S, C>
110where
111    T: NetlinkSerializable + Debug,
112    S: AsyncSocket,
113    C: NetlinkMessageCodec,
114{
115    type Error = io::Error;
116
117    fn poll_ready(
118        self: Pin<&mut Self>,
119        cx: &mut Context<'_>,
120    ) -> Poll<Result<(), Self::Error>> {
121        if !self.flushed {
122            match self.poll_flush(cx)? {
123                Poll::Ready(()) => {}
124                Poll::Pending => return Poll::Pending,
125            }
126        }
127
128        Poll::Ready(Ok(()))
129    }
130
131    fn start_send(
132        self: Pin<&mut Self>,
133        item: (NetlinkMessage<T>, SocketAddr),
134    ) -> Result<(), Self::Error> {
135        trace!("sending frame");
136        let (frame, out_addr) = item;
137        let pin = self.get_mut();
138        C::encode(frame, &mut pin.writer)?;
139        pin.out_addr = out_addr;
140        pin.flushed = false;
141        trace!("frame encoded; length={}", pin.writer.len());
142        Ok(())
143    }
144
145    fn poll_flush(
146        mut self: Pin<&mut Self>,
147        cx: &mut Context<'_>,
148    ) -> Poll<Result<(), Self::Error>> {
149        if self.flushed {
150            return Poll::Ready(Ok(()));
151        }
152
153        trace!("flushing frame; length={}", self.writer.len());
154        let Self {
155            ref mut socket,
156            ref mut out_addr,
157            ref mut writer,
158            ..
159        } = *self;
160
161        let n = ready!(socket.poll_send_to(cx, writer, out_addr))?;
162        trace!("written {}", n);
163
164        let wrote_all = n == self.writer.len();
165        self.writer.clear();
166        self.flushed = true;
167
168        let res = if wrote_all {
169            Ok(())
170        } else {
171            Err(io::Error::new(
172                io::ErrorKind::Other,
173                "failed to write entire datagram to socket",
174            ))
175        };
176
177        Poll::Ready(res)
178    }
179
180    fn poll_close(
181        self: Pin<&mut Self>,
182        cx: &mut Context<'_>,
183    ) -> Poll<Result<(), Self::Error>> {
184        ready!(self.poll_flush(cx))?;
185        Poll::Ready(Ok(()))
186    }
187}
188
/// Initial capacity of the receive buffer. The same amount is reserved again
/// before every receive.
///
/// The theoretical maximum netlink packet size is 32 KiB since Linux 4.9
/// (16 KiB before), see
/// <https://git.kernel.org/pub/scm/linux/kernel/git/davem/net-next.git/commit/?id=d35c99ff77ecb2eb239731b799386f3b3637a31e>.
/// 64 KiB is twice that limit, which leaves headroom.
const INITIAL_READER_CAPACITY: usize = 64 * 1024;

/// Initial capacity of the send buffer.
const INITIAL_WRITER_CAPACITY: usize = 8 * 1024;
194
195impl<T, S, C> NetlinkFramed<T, S, C> {
196    /// Create a new `NetlinkFramed` backed by the given socket and codec.
197    ///
198    /// See struct level documentation for more details.
199    pub fn new(socket: S) -> Self {
200        Self {
201            socket,
202            msg_type: PhantomData,
203            codec: PhantomData,
204            out_addr: SocketAddr::new(0, 0),
205            in_addr: SocketAddr::new(0, 0),
206            reader: BytesMut::with_capacity(INITIAL_READER_CAPACITY),
207            writer: BytesMut::with_capacity(INITIAL_WRITER_CAPACITY),
208            flushed: true,
209        }
210    }
211
212    /// Returns a reference to the underlying I/O stream wrapped by `Framed`.
213    ///
214    /// # Note
215    ///
216    /// Care should be taken to not tamper with the underlying stream of data
217    /// coming in as it may corrupt the stream of frames otherwise being worked
218    /// with.
219    pub fn get_ref(&self) -> &S {
220        &self.socket
221    }
222
223    /// Returns a mutable reference to the underlying I/O stream wrapped by
224    /// `Framed`.
225    ///
226    /// # Note
227    ///
228    /// Care should be taken to not tamper with the underlying stream of data
229    /// coming in as it may corrupt the stream of frames otherwise being worked
230    /// with.
231    pub fn get_mut(&mut self) -> &mut S {
232        &mut self.socket
233    }
234
235    /// Consumes the `Framed`, returning its underlying I/O stream.
236    pub fn into_inner(self) -> S {
237        self.socket
238    }
239}