1use std::{
4 fmt::Debug,
5 io,
6 pin::Pin,
7 task::{Context, Poll},
8};
9
10use futures::{
11 channel::mpsc::{UnboundedReceiver, UnboundedSender},
12 Future, Sink, Stream,
13};
14use log::{error, warn};
15use netlink_packet_core::{
16 NetlinkDeserializable, NetlinkMessage, NetlinkPayload, NetlinkSerializable,
17};
18
19use crate::{
20 codecs::{NetlinkCodec, NetlinkMessageCodec},
21 framed::NetlinkFramed,
22 sys::{AsyncSocket, SocketAddr},
23 Protocol, Request, Response,
24};
25
// Socket type used when `Connection` is named without an explicit `S`
// parameter: the tokio-backed socket when the `tokio_socket` feature is on.
#[cfg(feature = "tokio_socket")]
use netlink_sys::TokioSocket as DefaultSocket;
// Without the `tokio_socket` feature, `()` is only a placeholder so that the
// struct's default type parameter still resolves.
// NOTE(review): presumably `()` does not implement `AsyncSocket`, so callers
// must name a concrete socket type explicitly in this configuration — confirm.
#[cfg(not(feature = "tokio_socket"))]
type DefaultSocket = ();
30
31pub struct Connection<T, S = DefaultSocket, C = NetlinkCodec>
37where
38 T: Debug + NetlinkSerializable + NetlinkDeserializable,
39{
40 socket: NetlinkFramed<T, S, C>,
41
42 protocol: Protocol<T, UnboundedSender<NetlinkMessage<T>>>,
43
44 requests_rx: Option<UnboundedReceiver<Request<T>>>,
46
47 unsolicited_messages_tx:
50 Option<UnboundedSender<(NetlinkMessage<T>, SocketAddr)>>,
51
52 socket_closed: bool,
53
54 forward_noop: bool,
55 forward_done: bool,
56 forward_ack: bool,
57}
58
59impl<T, S, C> Connection<T, S, C>
60where
61 T: Debug + NetlinkSerializable + NetlinkDeserializable + Unpin,
62 S: AsyncSocket,
63 C: NetlinkMessageCodec,
64{
65 pub(crate) fn new(
66 requests_rx: UnboundedReceiver<Request<T>>,
67 unsolicited_messages_tx: UnboundedSender<(
68 NetlinkMessage<T>,
69 SocketAddr,
70 )>,
71 protocol: isize,
72 ) -> io::Result<Self> {
73 let socket = S::new(protocol)?;
74 Ok(Connection {
75 socket: NetlinkFramed::new(socket),
76 protocol: Protocol::new(),
77 requests_rx: Some(requests_rx),
78 unsolicited_messages_tx: Some(unsolicited_messages_tx),
79 socket_closed: false,
80 forward_noop: false,
81 forward_done: false,
82 forward_ack: false,
83 })
84 }
85
86 pub fn set_forward_noop(&mut self, value: bool) {
88 self.forward_noop = value;
89 }
90
91 pub fn set_forward_done(&mut self, value: bool) {
93 self.forward_done = value;
94 }
95
96 pub fn set_forward_ack(&mut self, value: bool) {
98 self.forward_ack = value;
99 }
100
101 pub fn socket_mut(&mut self) -> &mut S {
102 self.socket.get_mut()
103 }
104
105 pub fn poll_send_messages(&mut self, cx: &mut Context) {
106 trace!("poll_send_messages called");
107 let Connection {
108 ref mut socket,
109 ref mut protocol,
110 ..
111 } = self;
112 let mut socket = Pin::new(socket);
113
114 if !protocol.outgoing_messages.is_empty() {
115 trace!(
116 "found outgoing message to send checking if socket is ready"
117 );
118 match Pin::as_mut(&mut socket).poll_ready(cx) {
119 Poll::Ready(Err(e)) => {
120 warn!("netlink socket shut down: {:?}", e);
123 self.socket_closed = true;
124 return;
125 }
126 Poll::Pending => {
127 trace!("poll is not ready, returning");
128 return;
129 }
130 Poll::Ready(Ok(_)) => {}
131 }
132
133 let (mut message, addr) =
134 protocol.outgoing_messages.pop_front().unwrap();
135 message.finalize();
136
137 trace!("sending outgoing message");
138 if let Err(e) = Pin::as_mut(&mut socket).start_send((message, addr))
139 {
140 error!("failed to send message: {:?}", e);
141 self.socket_closed = true;
142 return;
143 }
144 }
145
146 trace!("poll_send_messages done");
147 self.poll_flush(cx)
148 }
149
150 pub fn poll_flush(&mut self, cx: &mut Context) {
151 trace!("poll_flush called");
152 if let Poll::Ready(Err(e)) = Pin::new(&mut self.socket).poll_flush(cx) {
153 warn!("error flushing netlink socket: {:?}", e);
154 self.socket_closed = true;
155 }
156 }
157
158 pub fn poll_read_messages(&mut self, cx: &mut Context) {
159 trace!("poll_read_messages called");
160 let mut socket = Pin::new(&mut self.socket);
161
162 loop {
163 trace!("polling socket");
164 match socket.as_mut().poll_next(cx) {
165 Poll::Ready(Some((message, addr))) => {
166 trace!("read datagram from socket");
167 self.protocol.handle_message(message, addr);
168 }
169 Poll::Ready(None) => {
170 warn!("netlink socket stream shut down");
171 self.socket_closed = true;
172 return;
173 }
174 Poll::Pending => {
175 trace!("no datagram read from socket");
176 return;
177 }
178 }
179 }
180 }
181
182 pub fn poll_requests(&mut self, cx: &mut Context) {
183 trace!("poll_requests called");
184 if let Some(mut stream) = self.requests_rx.as_mut() {
185 loop {
186 match Pin::new(&mut stream).poll_next(cx) {
187 Poll::Ready(Some(request)) => {
188 self.protocol.request(request)
189 }
190 Poll::Ready(None) => break,
191 Poll::Pending => return,
192 }
193 }
194 let _ = self.requests_rx.take();
195 trace!("no new requests to handle poll_requests done");
196 }
197 }
198
199 pub fn forward_unsolicited_messages(&mut self) {
200 if self.unsolicited_messages_tx.is_none() {
201 while let Some((message, source)) =
202 self.protocol.incoming_requests.pop_front()
203 {
204 warn!(
205 "ignoring unsolicited message {:?} from {:?}",
206 message, source
207 );
208 }
209 return;
210 }
211
212 trace!("forward_unsolicited_messages called");
213 let mut ready = false;
214
215 let Connection {
216 ref mut protocol,
217 ref mut unsolicited_messages_tx,
218 ..
219 } = self;
220
221 while let Some((message, source)) =
222 protocol.incoming_requests.pop_front()
223 {
224 if unsolicited_messages_tx
225 .as_mut()
226 .unwrap()
227 .unbounded_send((message, source))
228 .is_err()
229 {
230 warn!("failed to forward message to connection handle: channel closed");
234 ready = true;
235 break;
236 }
237 }
238
239 if ready
243 || self.unsolicited_messages_tx.as_ref().is_none()
244 || self.unsolicited_messages_tx.as_ref().map(|x| x.is_closed())
245 == Some(true)
246 {
247 let _ = self.unsolicited_messages_tx.take();
249 self.forward_unsolicited_messages();
251 }
252
253 trace!("forward_unsolicited_messages done");
254 }
255
256 pub fn forward_responses(&mut self) {
257 trace!("forward_responses called");
258 let protocol = &mut self.protocol;
259
260 while let Some(response) = protocol.incoming_responses.pop_front() {
261 let Response {
262 message,
263 done,
264 metadata: tx,
265 } = response;
266 if done {
267 use NetlinkPayload::*;
268 match &message.payload {
269 Noop => {
270 if !self.forward_noop {
271 trace!("Not forwarding Noop message to the handle");
272 continue;
273 }
274 }
275 Done(_) => {
284 if !self.forward_done {
285 trace!("Not forwarding Done message to the handle");
286 continue;
287 }
288 }
289 Overrun(_) => unimplemented!("overrun is not handled yet"),
291 Error(err_msg) => {
296 if err_msg.code.is_none() && !self.forward_ack {
297 trace!("Not forwarding Ack message to the handle");
298 continue;
299 }
300 }
301 InnerMessage(_) => {}
302 _ => {}
303 }
304 }
305
306 trace!("forwarding response to the handle");
307 if tx.unbounded_send(message).is_err() {
308 warn!("failed to forward response back to the handle");
311 }
312 }
313 trace!("forward_responses done");
314 }
315
316 pub fn should_shut_down(&self) -> bool {
317 self.socket_closed
318 || (self.unsolicited_messages_tx.is_none()
319 && self.requests_rx.is_none())
320 }
321}
322
323impl<T, S, C> Connection<T, S, C>
324where
325 T: Debug + NetlinkSerializable + NetlinkDeserializable + Unpin,
326 S: AsyncSocket,
327 C: NetlinkMessageCodec,
328{
329 pub(crate) fn from_socket(
330 requests_rx: UnboundedReceiver<Request<T>>,
331 unsolicited_messages_tx: UnboundedSender<(
332 NetlinkMessage<T>,
333 SocketAddr,
334 )>,
335 socket: S,
336 ) -> Self {
337 Connection {
338 socket: NetlinkFramed::new(socket),
339 protocol: Protocol::new(),
340 requests_rx: Some(requests_rx),
341 unsolicited_messages_tx: Some(unsolicited_messages_tx),
342 socket_closed: false,
343 forward_noop: false,
344 forward_done: false,
345 forward_ack: false,
346 }
347 }
348}
349
350impl<T, S, C> Future for Connection<T, S, C>
351where
352 T: Debug + NetlinkSerializable + NetlinkDeserializable + Unpin,
353 S: AsyncSocket,
354 C: NetlinkMessageCodec,
355{
356 type Output = ();
357
358 fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
359 trace!("polling Connection");
360 let pinned = self.get_mut();
361
362 trace!("reading incoming messages");
363 pinned.poll_read_messages(cx);
364
365 trace!("forwarding unsolicited messages to the connection handle");
366 pinned.forward_unsolicited_messages();
367
368 trace!(
369 "forwarding responses to previous requests to the connection handle"
370 );
371 pinned.forward_responses();
372
373 trace!("handling requests");
374 pinned.poll_requests(cx);
375
376 trace!("sending messages");
377 pinned.poll_send_messages(cx);
378
379 trace!("done polling Connection");
380
381 if pinned.should_shut_down() {
382 Poll::Ready(())
383 } else {
384 Poll::Pending
385 }
386 }
387}
388
#[cfg(all(test, feature = "tokio_socket"))]
mod tests {
    use crate::new_connection;
    use crate::sys::protocols::NETLINK_AUDIT;
    use netlink_packet_audit::AuditMessage;
    use tokio::time;

    /// A connection whose handle and unsolicited-message receiver are both
    /// dropped must finish on its own.
    #[tokio::test]
    async fn connection_is_closed() {
        // The `_` patterns drop the handle and the unsolicited-message
        // receiver right away, leaving nobody attached to the connection.
        let (conn, _, _) =
            new_connection::<AuditMessage>(NETLINK_AUDIT).unwrap();
        let join_handle = tokio::spawn(conn);
        // Await completion within the same 200 ms budget instead of sleeping
        // unconditionally. This returns as soon as the task ends, and it also
        // fails if the task panicked, which `is_finished()` would accept.
        time::timeout(time::Duration::from_millis(200), join_handle)
            .await
            .expect("connection did not shut down within 200ms")
            .expect("connection task panicked");
    }
}