async_icmp/ping/
multiplexer_task.rs

1use crate::message::IcmpV6MsgType;
2#[cfg(doc)]
3use crate::ping::PingMultiplexer;
4use crate::{
5    message::{
6        decode::DecodedIcmpMsg,
7        echo::{parse_echo_reply, EchoId, EchoSeq},
8        IcmpV4MsgType,
9    },
10    platform,
11    socket::{SocketConfig, SocketPair},
12    Icmpv4, Icmpv6,
13};
14use hashbrown::hash_map::Entry;
15use log::{debug, warn};
16use std::{fmt, hash, io, net, sync, time};
17use tokio::sync::{mpsc as tmpsc, mpsc::error::TrySendError, oneshot};
18
/// Multiplexer task that runs in the background.
///
/// All `recv` is done by this task so that session lookup can be lock and wait free.
///
/// It also handles session management (add and close) so that there's a single point of truth
/// for whether a session exists.
pub(crate) struct MultiplexTask {
    /// Receive buffer handed to the IPv4 socket's `recv`.
    ///
    /// Starts empty and is grown when a session that needs more room is added.
    v4_buf: Vec<u8>,
    /// Receive buffer handed to the IPv6 socket's `recv`.
    ///
    /// Starts empty and is grown when a session that needs more room is added.
    v6_buf: Vec<u8>,
    /// The IPv4 and IPv6 sockets this task receives from.
    ///
    /// Shared via `Arc`; a clone is also handed back to the caller of `new`.
    sockets: sync::Arc<SocketPair>,
    /// Session lookup for send path.
    ///
    /// `hashbrown` map so we can use `try_insert`.
    ///
    /// Uses [`SessionHandle`] as a key since the send key will be sent across the command
    /// channel with every send, so the key should be small.
    send_session_states:
        sync::Arc<sync::RwLock<hashbrown::HashMap<SessionHandle, SendSessionState>>>,
    /// Session lookup for recv path.
    ///
    /// `hashbrown` map so we can use `Equivalent` and `try_insert`.
    /// Otherwise, we end up stuck trying to implement a self-referential Borrow to appease stdlib
    /// HashMap::get
    recv_session_states: hashbrown::HashMap<RecvHashKey, RecvSessionState>,
    /// Incoming commands (shutdown, add session, close session) for this task.
    commands: tmpsc::Receiver<MultiplexerCommand>,
    /// If true, main loop should abort.
    shutdown: bool,
    /// The next handle id to attempt to use
    next_handle_id: u64,
}
49
50impl MultiplexTask {
51    /// Create a new multiplexer task state with the provided config.
52    ///
53    /// Returns the multiplexer, ipv4 port, ipv6 port, and command channel sender.
54    #[allow(clippy::type_complexity)]
55    pub(crate) fn new(
56        icmpv4_config: SocketConfig<Icmpv4>,
57        icmpv6_config: SocketConfig<Icmpv6>,
58    ) -> io::Result<(
59        Self,
60        u16,
61        u16,
62        sync::Arc<SocketPair>,
63        tmpsc::Sender<MultiplexerCommand>,
64        sync::Arc<sync::RwLock<hashbrown::HashMap<SessionHandle, SendSessionState>>>,
65    )> {
66        let (tx, rx) = tmpsc::channel(16);
67        let sockets = sync::Arc::new(SocketPair::new(icmpv4_config, icmpv6_config)?);
68        let v4_port = sockets.ipv4.local_port();
69        let v6_port = sockets.ipv6.local_port();
70        let send_session_states = sync::Arc::new(sync::RwLock::new(hashbrown::HashMap::new()));
71        Ok((
72            Self {
73                // baseline size to be expanded if needed when a session is added
74                v4_buf: Vec::new(),
75                v6_buf: Vec::new(),
76                sockets: sockets.clone(),
77                next_handle_id: 0,
78                send_session_states: send_session_states.clone(),
79                recv_session_states: hashbrown::HashMap::new(),
80                commands: rx,
81                shutdown: false,
82            },
83            v4_port,
84            v6_port,
85            sockets,
86            tx,
87            send_session_states,
88        ))
89    }
90
91    /// Run the main loop until shutdown.
92    pub(crate) async fn run(&mut self) {
93        loop {
94            if self.shutdown {
95                break;
96            }
97
98            if let Err(e) = self.recv_or_cmd().await {
99                warn!("Recv task error: {e}")
100            }
101        }
102    }
103
104    /// Return `Err` if it's worth logging.
105    async fn recv_or_cmd(&mut self) -> Result<(), RecvError> {
106        let send_states = &mut self.send_session_states;
107        let recv_states = &mut self.recv_session_states;
108        tokio::select! {
109            v4_res = self.sockets.ipv4.recv(&mut self.v4_buf) => {
110                let (msg, _range) = v4_res?;
111                handle_recv(msg, IcmpV4MsgType::EchoReply as u8, send_states, recv_states)?;
112            }
113            v6_res = self.sockets.ipv6.recv(&mut self.v6_buf) => {
114                let (msg, _range) = v6_res?;
115                handle_recv(msg, IcmpV6MsgType::EchoReply as u8, send_states, recv_states)?;
116            }
117            cmd_opt = self.commands.recv() => {
118                match cmd_opt {
119                    None => {
120                        // treat closing the cmd channel as shutdown
121                        self.handle_command(MultiplexerCommand::Shutdown(oneshot::channel().0)).await?
122                    }
123                    Some(cmd) => self.handle_command(cmd).await?
124                }
125            }
126        }
127
128        Ok(())
129    }
130
131    /// Must always send _something_ to `cmd`'s reply channel.
132    async fn handle_command(&mut self, cmd: MultiplexerCommand) -> Result<(), RecvError> {
133        match cmd {
134            MultiplexerCommand::Shutdown(reply) => {
135                self.shutdown = true;
136                // we'll be exiting the loop anyway but might as well clean up aggressively
137                self.send_session_states.write().unwrap().clear();
138                self.recv_session_states.clear();
139                self.commands.close();
140                reply_if_possible(reply, ())
141            }
142            MultiplexerCommand::AddSession {
143                ip,
144                id,
145                data,
146                reply,
147            } => reply_if_possible(reply, self.add_session(ip, id, data, 16)),
148            MultiplexerCommand::CloseSession {
149                session_handle,
150                reply,
151            } => {
152                handle_close_session(
153                    session_handle,
154                    &mut self.send_session_states,
155                    &mut self.recv_session_states,
156                );
157                reply_if_possible(reply, ())
158            }
159        }
160
161        Ok(())
162    }
163
164    /// Add a session.
165    ///
166    /// `channel_buf_size` need not be very large (i.e. 8) unless you plan on only occasionally
167    /// reading from the receivers. Normally there would be a task constantly waiting on all such
168    /// receivers, and since pings are not normally sent or received very fast, a small channel will
169    /// do.
170    ///
171    /// If the returned receiver is detected as dropped when trying to send to it,
172    /// the session will be closed.
173    fn add_session(
174        &mut self,
175        ip: net::IpAddr,
176        id: EchoId,
177        data: Vec<u8>,
178        channel_buf_size: usize,
179    ) -> Result<(SessionHandle, tmpsc::Receiver<ReplyTimestamp>), AddSessionError> {
180        // resize buf to be able to receive the new data size, if needed.
181        // 4 for ICMP header, 4 for ICMP Echo Reply header.
182        let buf_len = 4 + 4 + data.len();
183        match ip {
184            net::IpAddr::V4(_) => {
185                let prefix_len = if platform::ipv4_recv_prefix_ipv4_header() {
186                    // Normal IPv4 header is 20 bytes, but technically there could be options, so
187                    // use 60 = max IPv4 header len.
188                    60
189                } else {
190                    0
191                };
192
193                let buf_len = prefix_len + buf_len;
194                if self.v4_buf.len() < buf_len {
195                    self.v4_buf.resize(buf_len, 0);
196                }
197            }
198            net::IpAddr::V6(_) => {
199                if self.v6_buf.len() < buf_len {
200                    self.v6_buf.resize(buf_len, 0);
201                }
202            }
203        }
204
205        let echo_data = sync::Arc::new(SessionEchoData { id, data });
206        let key = RecvHashKey {
207            echo_data: echo_data.clone(),
208        };
209        let (tx, rx) = tmpsc::channel(channel_buf_size);
210
211        let recv_state = match self.recv_session_states.entry(key) {
212            Entry::Occupied(_) => {
213                return Err(AddSessionError::Duplicate);
214            }
215            Entry::Vacant(v) => {
216                v.insert(RecvSessionState {
217                    tx,
218                    // placeholder, replaced below
219                    session_handle: SessionHandle { id: u64::MAX },
220                })
221            }
222        };
223
224        // id/data are unique, so now we can populate send state with a unique handle id
225
226        let send_state = SendSessionState {
227            ip,
228            echo_data: echo_data.clone(),
229        };
230
231        // highly unlikely to ever wrap around on u64 but might as well be thorough
232        loop {
233            let handle = SessionHandle {
234                id: self.next_handle_id,
235            };
236            // prepare for either the next loop or the next add
237            self.next_handle_id = self.next_handle_id.wrapping_add(1);
238
239            match self.send_session_states.write().unwrap().entry(handle) {
240                Entry::Occupied(_) => {
241                    continue;
242                }
243                Entry::Vacant(v) => {
244                    v.insert(send_state);
245                    recv_state.session_handle = handle;
246                    debug!(
247                        "Added session: handle = {handle:?}, id = {id:?}, data = {}",
248                        hex::encode(&echo_data.data)
249                    );
250                    return Ok((handle, rx));
251                }
252            }
253        }
254    }
255}
256
257/// A top level function to avoid lifetime wrangling
258fn handle_recv(
259    msg: &[u8],
260    echo_reply_type: u8,
261    send_states: &mut sync::Arc<sync::RwLock<hashbrown::HashMap<SessionHandle, SendSessionState>>>,
262    recv_states: &mut hashbrown::HashMap<RecvHashKey, RecvSessionState>,
263) -> Result<(), RecvError> {
264    let decoded = if let Ok(decoded) = DecodedIcmpMsg::decode(msg) {
265        decoded
266    } else {
267        debug!("ICMP message parse failed");
268        return Ok(());
269    };
270
271    if decoded.msg_type() != echo_reply_type || decoded.msg_code() != 0 {
272        debug!(
273            "Skipping irrelevant ICMP message type {} code {}",
274            decoded.msg_type(),
275            decoded.msg_code()
276        );
277        return Ok(());
278    }
279
280    let (seq, key) = if let Some((id, seq, data)) = parse_echo_reply(decoded.body()) {
281        (seq, RefHashKey { id, data })
282    } else {
283        debug!("Couldn't parse body as Echo Reply");
284        return Ok(());
285    };
286
287    if let Some(recv_state) = recv_states.get(&key) {
288        debug!("Reply for {:?}: seq {:?}", recv_state.session_handle, seq,);
289
290        if let Err(e) = recv_state.tx.try_send(ReplyTimestamp {
291            seq,
292            received_at: time::Instant::now(),
293        }) {
294            match e {
295                TrySendError::Full(_) => {
296                    warn!("Session channel overflow");
297                }
298                TrySendError::Closed(_) => {
299                    debug!("Session channel closed; closing session");
300                    // rx has been dropped, can close the session
301                    handle_close_session(recv_state.session_handle, send_states, recv_states)
302                }
303            }
304        }
305    } else {
306        debug!("Couldn't find session for {key:?}");
307    }
308
309    Ok(())
310}
311
312fn handle_close_session(
313    session_handle: SessionHandle,
314    send_session_states: &mut sync::Arc<
315        sync::RwLock<hashbrown::HashMap<SessionHandle, SendSessionState>>,
316    >,
317    recv_session_states: &mut hashbrown::HashMap<RecvHashKey, RecvSessionState>,
318) {
319    // remove from send map
320    if let Some(send_state) = send_session_states.write().unwrap().remove(&session_handle) {
321        // we found the handle, so we can close the recv channel
322        recv_session_states.remove(&RecvHashKey {
323            echo_data: send_state.echo_data,
324        });
325    }
326}
327
328/// Reply if the channel is still open, logging if the send fails.
329fn reply_if_possible<T>(reply: oneshot::Sender<T>, val: T) {
330    if reply.send(val).is_err() {
331        debug!("Could not reply - channel closed");
332    }
333}
334
/// All commands should have a "reply" Sender that the recv task will use to respond.
///
/// Since each oneshot requires a heap allocation, commands should only be used for relatively low
/// frequency communication.
pub(crate) enum MultiplexerCommand {
    /// Slightly richer than simply closing the channel, as it allows us to wait until
    /// the task reports shutdown is complete.
    Shutdown(oneshot::Sender<()>),
    /// Register a new session; the task replies with the session's handle and the receiver
    /// for its reply timestamps, or with the reason the session could not be added.
    AddSession {
        /// IP the session's echo requests are sent to
        ip: net::IpAddr,
        /// Echo id that, together with `data`, identifies the session's replies
        id: EchoId,
        /// Echo payload that, together with `id`, identifies the session's replies
        data: Vec<u8>,
        /// Where the result of adding the session is sent
        reply: oneshot::Sender<
            Result<(SessionHandle, tmpsc::Receiver<ReplyTimestamp>), AddSessionError>,
        >,
    },
    /// Close the session for `session_handle`; the task replies once it has been handled.
    CloseSession {
        /// Handle of the session to close
        session_handle: SessionHandle,
        /// Notified after the close has been processed
        reply: oneshot::Sender<()>,
    },
}
356
/// A handle to a session.
///
/// Created by [`PingMultiplexer::add_session`].
///
/// Small and `Copy` so it can cheaply be used as the key of the send-path session map.
#[derive(Clone, Copy, Hash, Debug, PartialEq, Eq)]
pub struct SessionHandle {
    /// Id allocated by the multiplexer task; not in use by any other open session at the
    /// time it is assigned.
    id: u64,
}
364
/// A record of receiving an ICMP Echo Reply for `seq` at `received_at`.
#[derive(Debug, PartialEq, Eq)]
pub struct ReplyTimestamp {
    /// The reply sequence number
    pub seq: EchoSeq,
    /// The timestamp when the reply was received
    pub received_at: time::Instant,
}
373
/// Task lifecycle errors applicable for any command.
///
/// Wrapped by the command-specific error types via their `Lifecycle` variants.
#[derive(Debug, thiserror::Error)]
pub enum LifecycleError {
    /// The multiplexer has shut down, so it cannot respond to further commands
    #[error("Multiplexer has shut down")]
    Shutdown,
}
381
/// Errors that can occur when adding a session
#[derive(Debug, thiserror::Error)]
pub enum AddSessionError {
    /// The provided session metadata (id, data) is already in use by another open session.
    #[error("Duplicate session metadata")]
    Duplicate,
    /// The multiplexer task could not service the request; see [`LifecycleError`].
    #[error("Lifecycle error: {0}")]
    Lifecycle(#[from] LifecycleError),
}
392
/// Errors from one iteration of the task's main loop that are worth logging.
#[derive(Debug, thiserror::Error)]
enum RecvError {
    /// An IO error, e.g. from receiving on one of the sockets
    #[error("IO error: {0}")]
    Io(#[from] io::Error),
}
398
/// Errors that can occur when sending a ping
#[derive(Debug, thiserror::Error)]
pub enum SendPingError {
    /// Invalid session handle
    #[error("Invalid session handle")]
    InvalidSessionHandle,
    /// IO error
    #[error("IO error: {0}")]
    Io(#[from] io::Error),
    /// The multiplexer task could not service the request; see [`LifecycleError`].
    #[error("Task error: {0}")]
    Lifecycle(#[from] LifecycleError),
}
412
/// State needed when sending an echo request for a session
#[derive(Debug)]
pub(crate) struct SendSessionState {
    /// IP to send to
    pub(crate) ip: net::IpAddr,

    /// Key needed to clear the session from the recv state map.
    ///
    /// The same allocation is shared with the recv map's key via `Arc`. Arc overhead is
    /// negligible because clones are only made while adding a session and only destroyed
    /// during session close.
    pub(crate) echo_data: sync::Arc<SessionEchoData>,
}
425
/// Recv path state kept by the recv task for each open session
#[derive(Debug)]
struct RecvSessionState {
    /// Used when closing a session via detecting a closed channel
    session_handle: SessionHandle,
    /// Where timestamps for replies matching the session are sent.
    ///
    /// Written with `try_send`, so a full channel drops the timestamp rather than blocking.
    tx: tokio::sync::mpsc::Sender<ReplyTimestamp>,
}
434
/// Owned key for recv session map.
///
/// Contains the data that would be provided by parsing an ICMP Echo Reply.
///
/// `Hash` is implemented by hand so that it agrees with [`RefHashKey`], the borrowed form
/// used for lookups.
#[derive(Debug, PartialEq, Eq)]
struct RecvHashKey {
    /// Echo id and payload identifying the session; shared with the send-path state
    echo_data: sync::Arc<SessionEchoData>,
}
442
/// Data needed to look up a session with reasonable confidence, and therefore also the data needed
/// to send a request for that session.
#[derive(PartialEq, Eq)]
pub(crate) struct SessionEchoData {
    /// ICMP Echo id for the session
    pub(crate) id: EchoId,
    /// ICMP Echo payload for the session
    pub(crate) data: Vec<u8>,
}
450
451impl fmt::Debug for SessionEchoData {
452    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
453        f.debug_struct("SessionEchoData")
454            .field("id", &self.id)
455            .field("data", &hex::encode(&self.data))
456            .finish()
457    }
458}
459
460/// Must match Hash impl for [`RefHashKey`]
461impl hash::Hash for RecvHashKey {
462    fn hash<H: hash::Hasher>(&self, state: &mut H) {
463        // prefix free: id is always two bytes
464        self.echo_data.id.hash(state);
465        self.echo_data.data.hash(state);
466    }
467}
468
/// Reference form of [`RecvHashKey`] for map queries via [`hashbrown::Equivalent`]
/// without having to own a copy of `data`
#[derive(PartialEq, Eq)]
struct RefHashKey<'a> {
    /// Echo id parsed from a received Echo Reply
    id: EchoId,
    /// Echo payload borrowed from the received message
    data: &'a [u8],
}
476
477/// Must match Hash impl for [`RecvHashKey`]
478#[allow(clippy::needless_lifetimes)] // not on 1.74
479impl<'a> hash::Hash for RefHashKey<'a> {
480    fn hash<H: hash::Hasher>(&self, state: &mut H) {
481        self.id.hash(state);
482        self.data.hash(state);
483    }
484}
485
486#[allow(clippy::needless_lifetimes)] // not on 1.74
487impl<'a> hashbrown::Equivalent<RecvHashKey> for RefHashKey<'a> {
488    fn equivalent(&self, key: &RecvHashKey) -> bool {
489        self.id == key.echo_data.id && self.data == key.echo_data.data
490    }
491}
492
493impl fmt::Debug for RefHashKey<'_> {
494    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
495        f.debug_struct("RefHashKey")
496            .field("id", &self.id)
497            .field("data", &hex::encode(self.data))
498            .finish()
499    }
500}
501
#[cfg(test)]
mod tests {
    use super::*;
    use hashbrown::Equivalent;
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};

    /// Hash `value` with a freshly constructed `DefaultHasher`, so that two calls with
    /// equally-hashing inputs produce the same output.
    fn hash_of<T: Hash + ?Sized>(value: &T) -> u64 {
        let mut hasher = DefaultHasher::new();
        value.hash(&mut hasher);
        hasher.finish()
    }

    /// Build an owned key from raw parts.
    fn owned_key(id: u16, data: Vec<u8>) -> RecvHashKey {
        RecvHashKey {
            echo_data: SessionEchoData {
                id: EchoId::from_be(id),
                data,
            }
            .into(),
        }
    }

    /// Map lookups with a `RefHashKey` only work if it both hashes identically to, and
    /// compares equivalent with, the stored `RecvHashKey`; check both halves.
    #[test]
    fn hash_key_hash_equivalent_to_ref_hash_key() {
        let key = owned_key(1234, vec![5, 6, 7, 8]);

        let mut ref_key = RefHashKey {
            id: key.echo_data.id,
            data: &key.echo_data.data,
        };

        assert!(ref_key.equivalent(&key));
        // the part the hand-written Hash impls promise: same hash for the same (id, data)
        assert_eq!(hash_of(&key), hash_of(&ref_key));

        // a different id must not match
        ref_key.id = [42_u8; 2].into();
        assert!(!ref_key.equivalent(&key));

        // the same id with different data must not match either
        ref_key.id = key.echo_data.id;
        ref_key.data = &[5, 6, 7];
        assert!(!ref_key.equivalent(&key));
    }

    /// An empty payload is the boundary case for the data half of the hash.
    #[test]
    fn hash_key_hash_matches_ref_hash_key_for_empty_data() {
        let key = owned_key(1, Vec::new());

        let ref_key = RefHashKey {
            id: key.echo_data.id,
            data: &key.echo_data.data,
        };

        assert!(ref_key.equivalent(&key));
        assert_eq!(hash_of(&key), hash_of(&ref_key));
    }
}
527}