async_icmp/socket/
mod.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
//! ICMP sockets.

use crate::message::EncodeIcmpMessage;
use crate::IpVersion;
use std::io::Read;
use std::{io, net, ops, os::fd};
use tokio::io::unix;
use winnow::{binary, combinator, Parser as _};

// Re-exported since it's in our public API.
pub use socket2;

#[cfg(test)]
mod tests;

/// An ICMP socket.
///
/// Commonly this would be wrapped in an `Arc` so that it may be used from multiple tasks (e.g.
/// one for sending, one for receiving).
#[derive(Debug)]
pub struct IcmpSocket {
    /// IP version fixed at construction. It decides whether received datagrams have an IPv4
    /// header stripped and which version outgoing messages are encoded for.
    ip_version: IpVersion,
    /// The underlying non-blocking socket, wrapped in tokio's `AsyncFd` for readiness-driven I/O.
    fd: unix::AsyncFd<IcmpSocketInner>,
}

impl IcmpSocket {
    /// Create a new socket with the specified IP version.
    ///
    /// The IP version is fixed for the lifetime of the socket. It is either IPv4 (socket type
    /// `AF_INET`) or IPv6 (`AF_INET6`), and it governs which type of IP address is valid to use
    /// with [`IcmpSocket::send_to`].
    ///
    /// # Errors
    ///
    /// Returns any error from opening or configuring the OS socket, or from registering it with
    /// tokio.
    pub fn new(ip_version: IpVersion) -> io::Result<Self> {
        let inner = IcmpSocketInner::new(ip_version)?;
        let fd = unix::AsyncFd::new(inner)?;
        Ok(Self { ip_version, fd })
    }

    /// Receive one ICMP message into `buf`.
    ///
    /// Returns the ICMP message (a sub-slice of `buf`) together with the range of indices in
    /// `buf` that holds it.
    ///
    /// Bytes outside the returned `range` may have been written to, and skipped during subsequent
    /// parsing (for example, the IPv4 header on an IPv4 socket).
    ///
    /// See [`crate::message::decode::DecodedIcmpMsg`] to extract basic ICMP message structure.
    ///
    /// # Errors
    ///
    /// Errors from the underlying receive are returned unchanged. On an IPv4 socket, a datagram
    /// whose IPv4 header cannot be parsed yields an [`io::ErrorKind::InvalidData`] error.
    ///
    /// # Platform differences
    ///
    /// Platforms differ on what messages will be exposed to userspace this way. On Linux, only
    /// ICMP Echo Reply messages where the id = the local port will be visible. On macOS, the
    /// kernel is less restrictive, and other ICMP packets will also be visible, so additional
    /// filtering may be needed depending on the use case.
    pub async fn recv<'a>(&self, buf: &'a mut [u8]) -> io::Result<(&'a [u8], ops::Range<usize>)> {
        // `Read for &Socket` is a simple wrapper around recv(2).
        let len = self
            .fd
            .async_io(tokio::io::Interest::READABLE, |inner| (&inner.socket).read(buf))
            .await?;
        let datagram = &buf[..len];

        match self.ip_version {
            IpVersion::V4 => strip_ipv4_header(datagram).map_err(|_| {
                io::Error::new(io::ErrorKind::InvalidData, "Could not strip IPv4 header")
            }),
            // IPv6 datagrams arrive without the IP header, so the whole datagram is the message.
            IpVersion::V6 => Ok((datagram, 0..len)),
        }
    }

    /// Send `msg` to `addr`.
    ///
    /// `msg` is encoded for the socket's IP version. If `msg` doesn't support that version, an
    /// error will be returned.
    pub async fn send_to(
        &self,
        addr: net::IpAddr,
        msg: &mut impl EncodeIcmpMessage,
    ) -> io::Result<()> {
        // The port is not used by ICMP; 0 only serves to build a socket address.
        let destination = socket2::SockAddr::from(net::SocketAddr::new(addr, 0));

        self.fd
            .async_io(tokio::io::Interest::WRITABLE, |inner| {
                let encoded = msg
                    .encode_for_version(self.ip_version)
                    .map_err(io::Error::from)?;
                inner.socket.send_to(encoded, &destination)
            })
            .await?;

        Ok(())
    }

    /// Access the underlying socket for any needed customization.
    ///
    /// Don't use this unless you know what you're doing. The socket was put into non-blocking
    /// mode when it was created, and the async methods on this type rely on that.
    pub fn as_mut_socket(&mut self) -> &mut socket2::Socket {
        let inner = self.fd.get_mut();
        &mut inner.socket
    }

    /// Returns the local port of the socket.
    ///
    /// This is useful on Linux since the local port is used as the ICMP Echo ID.
    ///
    /// # Errors
    ///
    /// Returns the OS error if the local address can't be queried. Returns an
    /// [`io::ErrorKind::Other`] error if the address is not an IPv4/IPv6 socket address.
    pub fn local_port(&self) -> io::Result<u16> {
        match self.fd.get_ref().socket.local_addr()?.as_socket() {
            Some(socket_addr) => Ok(socket_addr.port()),
            None => Err(io::Error::new(
                io::ErrorKind::Other,
                "Socket is not AF_INET or AF_INET6?",
            )),
        }
    }

    /// The IP version this socket was created with.
    pub fn ip_version(&self) -> IpVersion {
        self.ip_version
    }
}

/// A non-public type for the necessary impls to make AsyncFd work
///
/// The wrapped socket is handed to [`unix::AsyncFd`] through this type's `AsRawFd` impl.
#[derive(Debug)]
struct IcmpSocketInner {
    /// The ICMP socket itself; `IcmpSocketInner::new` opens it in non-blocking mode.
    socket: socket2::Socket,
}

impl IcmpSocketInner {
    /// Open a non-blocking `SOCK_DGRAM` ICMP socket for `ip_version`.
    ///
    /// IPv4 uses `AF_INET` with ICMPv4, and IPv6 uses `AF_INET6` with ICMPv6.
    ///
    /// # Errors
    ///
    /// Returns the OS error if the socket can't be created or switched to non-blocking mode.
    fn new(ip_version: IpVersion) -> io::Result<Self> {
        let (domain, protocol) = match ip_version {
            IpVersion::V4 => (socket2::Domain::IPV4, socket2::Protocol::ICMPV4),
            IpVersion::V6 => (socket2::Domain::IPV6, socket2::Protocol::ICMPV6),
        };

        let socket = socket2::Socket::new(domain, socket2::Type::DGRAM, Some(protocol))?;
        // tokio's AsyncFd expects a non-blocking descriptor.
        socket.set_nonblocking(true)?;

        Ok(Self { socket })
    }
}

/// Required by [unix::AsyncFd]: exposes the descriptor of the wrapped socket.
impl fd::AsRawFd for IcmpSocketInner {
    fn as_raw_fd(&self) -> fd::RawFd {
        // Delegate to socket2's own impl; the wrapper owns no other descriptor.
        fd::AsRawFd::as_raw_fd(&self.socket)
    }
}

/// Error type of [`strip_ipv4_header`].
///
/// It is a winnow parse error over the position-tracking (`Located`) input bytes, carrying
/// `&'static str` context labels such as "Invalid version".
type WinnowError<'a> =
    winnow::error::ParseError<winnow::Located<&'a [u8]>, winnow::error::ContextError<&'static str>>;

/// Returns a result with a tuple of `(data after the ipv4 header, index range of the data)`.
///
/// The header is skipped in full, including any IPv4 options, as declared by its Internet Header
/// Length (IHL) field. The index range is relative to `input`. It is useful if the caller wants to
/// treat the data as a `&mut [u8]`.
///
/// # Errors
///
/// Fails if any of the following hold:
/// - the version nibble is not 4;
/// - the IHL is below the RFC 791 minimum of 5 words (20 bytes);
/// - `input` is shorter than the header length it declares.
fn strip_ipv4_header(input: &[u8]) -> Result<(&[u8], ops::Range<usize>), WinnowError> {
    // RFC 791: IHL counts 32-bit words, and "the minimum value for a correct header is 5".
    const MIN_IHL: usize = 5;

    // discard complete ip header
    combinator::preceded(
        binary::bits::bits(
            // get and take ipv4 header len
            binary::length_take(
                // verify and discard ip version, yielding the number of header bits left to skip
                combinator::preceded(
                    // 4 bit version
                    binary::bits::pattern::<_, _, _, winnow::error::ContextError<&'static str>>(
                        0x04_u8, 4_usize,
                    )
                    .context("Invalid version"),
                    // 4 bit length in 32-bit words. The length includes the 8 bits of
                    // version + IHL just parsed, so subtract them. IHL fits in 4 bits (<= 15) and
                    // the guard ensures IHL >= 5, so the arithmetic can neither overflow nor
                    // underflow. `then` (not `then_some`) keeps it from running when the guard
                    // fails.
                    binary::bits::take(4_usize)
                        .verify_map(|ihl: usize| (ihl >= MIN_IHL).then(|| ihl * 32 - 8))
                        .context("Invalid header length"),
                ),
            ),
        ),
        combinator::rest::<_, winnow::error::ContextError<_>>.with_span(),
    )
    .parse(winnow::Located::new(input))
}