// async-icmp 0.2.1 — Async ICMP library (crate documentation header)
//! High level "ping" (ICMP Echo / Echo Reply) API, built on top of the rest of this library.
//!
//! [`PingMultiplexer`] supports pinging multiple hosts concurrently, both IPv4 and IPv6, with flexible Echo contents.

use log::debug;
use std::{fmt, io, net, sync, time};
use tokio::sync::{mpsc as tmpsc, oneshot};

mod multiplexer_task;
#[cfg(test)]
mod tests;

use crate::{
    message::echo::{EchoId, EchoSeq, IcmpEchoRequest},
    ping::multiplexer_task::{MultiplexTask, MultiplexerCommand, SendSessionState},
    platform,
    socket::{SocketConfig, SocketPair},
    Icmpv4, Icmpv6, IpVersion,
};
pub use multiplexer_task::{
    AddSessionError, LifecycleError, ReplyTimestamp, SendPingError, SessionHandle,
};

/// A high-level ping (ICMP echo / echo reply) API.
///
/// This addresses the common case of handling one or more ping "sessions": for a given
/// IP address, ICMP Echo ID, and ICMP Echo Data, you send one or more pings (ICMP Echo).
///
/// Each session has its own channel for responses (ICMP Echo Reply).
///
/// Because this is designed to be used from many tasks simultaneously, it is safe to `clone()`
/// as it uses `Arc` internally.
///
/// # Examples
///
/// Create two ping sessions on one socket, and send one ping with each.
///
/// ```
/// use std::net;
/// use async_icmp::{
///     message::echo::EchoSeq,
///     IpVersion,
///     ping::PingMultiplexer,
///     socket::SocketConfig
/// };
///
/// async fn ping_demo() -> anyhow::Result<()> {
///     let multiplexer = PingMultiplexer::new(SocketConfig::default(), SocketConfig::default())?;
///
///     // Two sessions with distinct `data` (and possibly id, platform permitting)
///     let (handle1, mut rx1) = multiplexer
///        .add_session(
///            net::Ipv4Addr::LOCALHOST.into(),
///            multiplexer.platform_echo_id(IpVersion::V4).unwrap_or_else(rand::random),
///            rand::random::<[u8; 32]>().to_vec(),
///        )
///        .await?;
///
///     let (handle2, mut rx2) = multiplexer
///        .add_session(
///            net::Ipv4Addr::LOCALHOST.into(),
///            multiplexer.platform_echo_id(IpVersion::V4).unwrap_or_else(rand::random),
///            rand::random::<[u8; 32]>().to_vec(),
///        )
///        .await?;
///
///     // Using distinct `seq` just to show that sessions are disambiguated.
///     // Typically you would start at 0 and increment for multiple pings.
///     let seq1 = EchoSeq::from_be(3);
///     let seq2 = EchoSeq::from_be(7);
///
///     // Receiver task waiting in the background
///     let receiver = tokio::spawn(async move {
///         assert_eq!(seq1, rx1.recv().await.unwrap().seq);
///         assert_eq!(seq2, rx2.recv().await.unwrap().seq);
///     });
///
///     multiplexer.send_ping(handle1, seq1).await?;
///     multiplexer.send_ping(handle2, seq2).await?;
///
///     // Make sure receiver got expected results
///     receiver.await?;
///
///     Ok(())
/// }
///
/// # tokio_test::block_on(ping_demo()).unwrap();
/// ```
#[derive(Clone)]
pub struct PingMultiplexer {
    /// Shared client-side state; the `Arc` is what makes `clone()` cheap and lets many
    /// tasks use the same multiplexer concurrently.
    state: sync::Arc<MultiplexerClientState>,
}

impl PingMultiplexer {
    /// Create a new multiplexer with the provided socket configs.
    ///
    /// Opens the ICMPv4/ICMPv6 socket pair and spawns the background task that receives
    /// replies and routes them to per-session channels.
    pub fn new(
        icmpv4_config: SocketConfig<Icmpv4>,
        icmpv6_config: SocketConfig<Icmpv6>,
    ) -> io::Result<Self> {
        let (mut inner, ipv4_local_port, ipv6_local_port, sockets, tx, send_state) =
            MultiplexTask::new(icmpv4_config, icmpv6_config)?;

        // The background task does all receiving; clients talk to it via the `commands`
        // channel and share `sockets` / `send_sessions` clones for the send path.
        let handle = tokio::spawn(async move {
            inner.run().await;
        });

        Ok(Self {
            state: sync::Arc::new(MultiplexerClientState {
                commands: tx,
                ipv4_local_port,
                ipv6_local_port,
                sockets,
                task_handle: Some(handle).into(),
                send_sessions: send_state,
                // Small pool so steady-state pinging doesn't allocate a fresh request each time
                req_pool: opool::Pool::new(4, ReqAllocator),
            }),
        })
    }

    /// Add a session to the multiplexer.
    ///
    /// Echo Reply messages with the provided `id` and `data` will cause a [`ReplyTimestamp`] to be
    /// sent to the returned channel receiver. Timestamps will be sent as they are received from the
    /// socket, including duplicates, etc.
    ///
    /// If the receiver is not drained fast enough, timestamps will be dropped.
    ///
    /// If the receiver is dropped, the session may be closed at some point in the future, but to
    /// reliably release resources, call [`Self::close_session`].
    ///
    /// On some platforms (for which [`platform::icmp_send_overwrite_echo_id_with_local_port`] returns true),
    /// the provided `id` will be overwritten in the kernel by the local port. On such platforms,
    /// `id` must be the local port for the relevant socket. If a different id is used, echo reply
    /// messages won't be matched to this session.
    ///
    /// See [`Self::platform_echo_id`], [`Self::ipv4_local_port`], and [`Self::ipv6_local_port`].
    pub async fn add_session(
        &self,
        ip: net::IpAddr,
        id: EchoId,
        data: Vec<u8>,
    ) -> Result<(SessionHandle, tmpsc::Receiver<ReplyTimestamp>), AddSessionError> {
        // The background task replies on `tx` with either the session (handle, receiver)
        // pair or an AddSessionError -- hence the nested Result and trailing `?`.
        let (tx, rx) = oneshot::channel();
        self.send_cmd(
            MultiplexerCommand::AddSession {
                ip,
                id,
                data,
                reply: tx,
            },
            rx,
        )
        .await?
    }

    /// Send a ping to the IP address specified when the session was created.
    ///
    /// Returns the timestamp at which the ICMP message was passed to the socket
    /// (captured immediately after the send completes).
    pub async fn send_ping(
        &self,
        session_handle: SessionHandle,
        seq: EchoSeq,
    ) -> Result<time::Instant, SendPingError> {
        // Scope the `RwLock` read guard so it is dropped before the `.await` below:
        // a `std::sync` lock must not be held across an await point. The pooled request
        // is filled with copies of the session's id/data while the guard is live.
        let (mut req, ip) = {
            let sessions = self.state.send_sessions.read().unwrap();
            let session_send_state = sessions
                .get(&session_handle)
                .ok_or(SendPingError::InvalidSessionHandle)?;

            let mut req = self.state.req_pool.get();
            req.set_id(session_send_state.echo_data.id);
            req.set_seq(seq);
            req.set_data(&session_send_state.echo_data.data);
            (req, session_send_state.ip)
        };

        self.state.sockets.send_to_either(&mut *req, ip).await?;
        debug!("Sent {session_handle:?} seq {seq:?}");
        Ok(time::Instant::now())
    }

    /// Close a session, releasing any resources associated with it.
    ///
    /// If the session is open, it will be closed. It is not an error to attempt to close an already
    /// closed session.
    pub async fn close_session(&self, session_handle: SessionHandle) -> Result<(), LifecycleError> {
        let (tx, rx) = oneshot::channel();
        self.send_cmd(
            MultiplexerCommand::CloseSession {
                session_handle,
                reply: tx,
            },
            rx,
        )
        .await
    }

    /// Shutdown the multiplexer.
    ///
    /// While dropping this value will eventually shut down the background task, if you
    /// need to wait until shutdown is complete, this method provides that.
    ///
    /// It is not an error to call this multiple times.
    ///
    /// Attempts to use the multiplexer (send, etc) after shutdown will result in an error.
    pub async fn shutdown(&self) {
        let (tx, rx) = oneshot::channel();
        if let Err(e) = self.send_cmd(MultiplexerCommand::Shutdown(tx), rx).await {
            // Deliberately exhaustive (no `_` arm) so adding a LifecycleError variant
            // forces this site to be revisited.
            match e {
                LifecycleError::Shutdown => {
                    // we're already shutting down, so this is fine
                }
            }
        }

        // holding the lock only momentarily, not across an await point
        let handle = match self.state.task_handle.lock().unwrap().take() {
            Some(h) => h,
            // Another caller already awaited the task handle; nothing left to do.
            None => return,
        };

        if let Err(e) = handle.await {
            debug!("Inner task exited with error: {}", e);
        };
    }

    /// Returns the local port used by the IPv4 listen socket.
    ///
    /// See [`Self::ipv6_local_port`] for the IPv6 equivalent, and [`Self::platform_echo_id`]
    /// for use as an echo id.
    pub fn ipv4_local_port(&self) -> u16 {
        self.state.ipv4_local_port
    }

    /// Returns the local port used by the IPv6 listen socket.
    ///
    /// See [`Self::ipv4_local_port`] for the IPv4 equivalent, and [`Self::platform_echo_id`]
    /// for use as an echo id.
    pub fn ipv6_local_port(&self) -> u16 {
        self.state.ipv6_local_port
    }

    /// Returns the local port used by the IPv4 or IPv6 listen socket (per `ip_version` ) as an
    /// [`EchoId`], if it required to be used as the `id` in ICMP Echo Request messages by the local
    /// platform -- that is, when [`platform::icmp_send_overwrite_echo_id_with_local_port`]
    /// returns true.
    ///
    /// This is just a possibly more convenient way to turn [`Self::ipv4_local_port`] or
    /// [`Self::ipv6_local_port`] into an [`EchoId`].
    ///
    /// # Examples
    ///
    /// Get the IPv4 local port as an EchoId, if required by the platform, otherwise use a random
    /// one.
    ///
    /// ```
    /// use async_icmp::{message::echo::EchoId, ping::PingMultiplexer, platform, IpVersion};
    ///
    /// fn get_ipv4_echo_id(multiplexer: &PingMultiplexer) -> EchoId {
    ///     multiplexer.platform_echo_id(IpVersion::V4).unwrap_or_else(rand::random)
    /// }
    /// ```
    pub fn platform_echo_id(&self, ip_version: IpVersion) -> Option<EchoId> {
        if platform::icmp_send_overwrite_echo_id_with_local_port() {
            let port = match ip_version {
                IpVersion::V4 => self.ipv4_local_port(),
                IpVersion::V6 => self.ipv6_local_port(),
            };

            Some(EchoId::from_be(port))
        } else {
            None
        }
    }

    /// Send `cmd` to the background task and await the reply on `rx`.
    ///
    /// The command channel and the per-command oneshot reply channel are paired by the
    /// caller: `cmd` carries the sender half whose receiver half is `rx`.
    async fn send_cmd<T>(
        &self,
        cmd: MultiplexerCommand,
        rx: oneshot::Receiver<T>,
    ) -> Result<T, LifecycleError> {
        // If the cmd receiver or reply sender are closed or dropped, treat as shutdown
        self.state
            .commands
            .send(cmd)
            .await
            .map_err(|_| LifecycleError::Shutdown)?;
        rx.await.map_err(|_| LifecycleError::Shutdown)
    }
}

impl fmt::Debug for PingMultiplexer {
    // The internal state is not useful to print, so debug output is just the type name.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("PingMultiplexer")
    }
}

/// State for being the client end of working with [`MultiplexTask`].
///
/// Shared (via `Arc`) by every clone of [`PingMultiplexer`].
struct MultiplexerClientState {
    /// Tx side to communicate with recv task
    commands: tmpsc::Sender<MultiplexerCommand>,
    /// Used for sending only.
    ///
    /// Recv task does all the receiving, with a clone of the same Arc.
    sockets: sync::Arc<SocketPair>,
    /// Local port of the IPv4 listen socket.
    ipv4_local_port: u16,
    /// Local port of the IPv6 listen socket.
    ipv6_local_port: u16,
    /// Handle for the recv task.
    ///
    /// `Some` if not shut down yet.
    ///
    /// Behind a lock to allow shutdown to only take &self.
    task_handle: sync::Mutex<Option<tokio::task::JoinHandle<()>>>,
    /// Contents maintained by the background task.
    ///
    /// We only use read locks to get session data when sending.
    send_sessions: sync::Arc<sync::RwLock<hashbrown::HashMap<SessionHandle, SendSessionState>>>,
    /// Object pool for ICMP requests to lower steady state allocation
    req_pool: opool::Pool<ReqAllocator, IcmpEchoRequest>,
}

struct ReqAllocator;

impl opool::PoolAllocator<IcmpEchoRequest> for ReqAllocator {
    /// Allocate a fresh request buffer; invoked by the pool when it has no spare object.
    fn allocate(&self) -> IcmpEchoRequest {
        IcmpEchoRequest::new()
    }
}