1use bytes::BytesMut;
4use std::{
5 fmt::Debug,
6 io,
7 marker::PhantomData,
8 mem::size_of,
9 pin::Pin,
10 task::{Context, Poll},
11};
12
13use futures::{Sink, Stream};
14use log::error;
15
16use crate::{
17 codecs::NetlinkMessageCodec,
18 sys::{AsyncSocket, SocketAddr},
19};
20use netlink_packet_core::{
21 NetlinkDeserializable, NetlinkHeader, NetlinkMessage, NetlinkPayload,
22 NetlinkSerializable, NLMSG_OVERRUN,
23};
24
/// Raw OS error code that the receive path treats as "socket buffer full".
///
/// It is compared against `io::Error::raw_os_error()` when reading from the
/// socket fails.
///
/// NOTE(review): the value is hard-coded rather than taken from the platform
/// headers; 105 is presumably the Linux `ENOBUFS` value, which may differ on
/// some architectures — confirm against the supported targets.
const ENOBUFS: i32 = 105;
27
28pub struct NetlinkFramed<T, S, C> {
29 socket: S,
30 msg_type: PhantomData<fn(T) -> T>, codec: PhantomData<fn(C) -> C>, reader: BytesMut,
36 writer: BytesMut,
37 in_addr: SocketAddr,
38 out_addr: SocketAddr,
39 flushed: bool,
40}
41
42impl<T, S, C> Stream for NetlinkFramed<T, S, C>
43where
44 T: NetlinkDeserializable + Debug,
45 S: AsyncSocket,
46 C: NetlinkMessageCodec,
47{
48 type Item = (NetlinkMessage<T>, SocketAddr);
49
50 fn poll_next(
51 self: Pin<&mut Self>,
52 cx: &mut Context<'_>,
53 ) -> Poll<Option<Self::Item>> {
54 let Self {
55 ref mut socket,
56 ref mut in_addr,
57 ref mut reader,
58 ..
59 } = Pin::get_mut(self);
60
61 loop {
62 match C::decode::<T>(reader) {
63 Ok(Some(item)) => return Poll::Ready(Some((item, *in_addr))),
64 Ok(None) => {}
65 Err(e) => {
66 error!("unrecoverable error in decoder: {:?}", e);
67 return Poll::Ready(None);
68 }
69 }
70
71 reader.clear();
72 reader.reserve(INITIAL_READER_CAPACITY);
73
74 *in_addr = match ready!(socket.poll_recv_from(cx, reader)) {
75 Ok(addr) => addr,
76 Err(e) if e.raw_os_error() == Some(ENOBUFS) => {
90 warn!("netlink socket buffer full");
91 let mut hdr = NetlinkHeader::default();
92 hdr.length = size_of::<NetlinkHeader>() as u32;
93 hdr.message_type = NLMSG_OVERRUN;
94 let msg = NetlinkMessage::new(
95 hdr,
96 NetlinkPayload::Overrun(Vec::new()),
97 );
98 return Poll::Ready(Some((msg, SocketAddr::new(0, 0))));
99 }
100 Err(e) => {
101 error!("failed to read from netlink socket: {:?}", e);
102 return Poll::Ready(None);
103 }
104 };
105 }
106 }
107}
108
109impl<T, S, C> Sink<(NetlinkMessage<T>, SocketAddr)> for NetlinkFramed<T, S, C>
110where
111 T: NetlinkSerializable + Debug,
112 S: AsyncSocket,
113 C: NetlinkMessageCodec,
114{
115 type Error = io::Error;
116
117 fn poll_ready(
118 self: Pin<&mut Self>,
119 cx: &mut Context<'_>,
120 ) -> Poll<Result<(), Self::Error>> {
121 if !self.flushed {
122 match self.poll_flush(cx)? {
123 Poll::Ready(()) => {}
124 Poll::Pending => return Poll::Pending,
125 }
126 }
127
128 Poll::Ready(Ok(()))
129 }
130
131 fn start_send(
132 self: Pin<&mut Self>,
133 item: (NetlinkMessage<T>, SocketAddr),
134 ) -> Result<(), Self::Error> {
135 trace!("sending frame");
136 let (frame, out_addr) = item;
137 let pin = self.get_mut();
138 C::encode(frame, &mut pin.writer)?;
139 pin.out_addr = out_addr;
140 pin.flushed = false;
141 trace!("frame encoded; length={}", pin.writer.len());
142 Ok(())
143 }
144
145 fn poll_flush(
146 mut self: Pin<&mut Self>,
147 cx: &mut Context<'_>,
148 ) -> Poll<Result<(), Self::Error>> {
149 if self.flushed {
150 return Poll::Ready(Ok(()));
151 }
152
153 trace!("flushing frame; length={}", self.writer.len());
154 let Self {
155 ref mut socket,
156 ref mut out_addr,
157 ref mut writer,
158 ..
159 } = *self;
160
161 let n = ready!(socket.poll_send_to(cx, writer, out_addr))?;
162 trace!("written {}", n);
163
164 let wrote_all = n == self.writer.len();
165 self.writer.clear();
166 self.flushed = true;
167
168 let res = if wrote_all {
169 Ok(())
170 } else {
171 Err(io::Error::new(
172 io::ErrorKind::Other,
173 "failed to write entire datagram to socket",
174 ))
175 };
176
177 Poll::Ready(res)
178 }
179
180 fn poll_close(
181 self: Pin<&mut Self>,
182 cx: &mut Context<'_>,
183 ) -> Poll<Result<(), Self::Error>> {
184 ready!(self.poll_flush(cx))?;
185 Poll::Ready(Ok(()))
186 }
187}
188
/// Capacity, in bytes, given to the read buffer at construction and reserved
/// again before every read from the socket (64 KiB).
const INITIAL_READER_CAPACITY: usize = 64 * 1024;
/// Capacity, in bytes, given to the write buffer at construction (8 KiB).
const INITIAL_WRITER_CAPACITY: usize = 8 * 1024;
194
195impl<T, S, C> NetlinkFramed<T, S, C> {
196 pub fn new(socket: S) -> Self {
200 Self {
201 socket,
202 msg_type: PhantomData,
203 codec: PhantomData,
204 out_addr: SocketAddr::new(0, 0),
205 in_addr: SocketAddr::new(0, 0),
206 reader: BytesMut::with_capacity(INITIAL_READER_CAPACITY),
207 writer: BytesMut::with_capacity(INITIAL_WRITER_CAPACITY),
208 flushed: true,
209 }
210 }
211
212 pub fn get_ref(&self) -> &S {
220 &self.socket
221 }
222
223 pub fn get_mut(&mut self) -> &mut S {
232 &mut self.socket
233 }
234
235 pub fn into_inner(self) -> S {
237 self.socket
238 }
239}