async_icmp/socket/
mod.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
//! ICMP socket support.
//!
//! Sockets have an ICMP version as a type parameter, allowing precise types for
//! IP address, etc.
//!
//! If the use case demands selecting the IP version at runtime (à la [`net::IpAddr`]), see
//! [`SocketPair`].

use crate::{
    message::{echo::EchoId, EncodeIcmpMessage},
    platform, IcmpVersion,
};
use std::{io, io::Read as _, marker, net, ops, os::fd};
use tokio::io::unix;
use winnow::{binary, combinator, Parser as _};

mod pair;

pub use pair::SocketPair;

#[cfg(test)]
mod tests;

/// An ICMP socket.
///
/// Commonly this would be wrapped in an `Arc` so that it may be used from multiple tasks (e.g.
/// one for sending, one for receiving).
///
/// # Platform differences
///
/// On Linux, ICMP Echo Request messages are rewritten to use the local port as the id, and only
/// ICMP Echo Reply messages where the id = the local port will be returned from `recv()`.
/// [`IcmpSocket::local_port`] and [`IcmpSocket::platform_echo_id`] exist for such use cases.
/// See [`platform::icmp_send_overwrite_echo_id_with_local_port`].
///
/// On macOS, the kernel is less restrictive, so you can set whatever id you like. In addition,
/// other ICMP packets will also be returned from `recv()`, so additional filtering may be needed
/// depending on the use case.
#[derive(Debug)]
pub struct IcmpSocket<V> {
    // The underlying socket wrapped in tokio's `AsyncFd`, so that sends and receives can await
    // read/write readiness.
    fd: unix::AsyncFd<IcmpSocketInner<V>>,
    // The socket's local port, read once when the socket is constructed.
    local_port: u16,
}

impl<V: IcmpVersion> IcmpSocket<V> {
    /// Create a new socket for IP version `V`.
    ///
    /// When a socket is created, it's either IPv4 (socket type `AF_INET`) or IPv6 (`AF_INET6`), which
    /// governs which type of IP address is valid to use with [`IcmpSocket::send_to`]
    /// ([`net::Ipv4Addr`] or [`net::Ipv6Addr`]).
    pub fn new(config: SocketConfig<V>) -> io::Result<Self> {
        let fd = unix::AsyncFd::new(IcmpSocketInner::new(config)?)?;
        let local_port = fd
            .get_ref()
            .socket
            .local_addr()?
            .as_socket()
            .map(|sa| sa.port())
            .ok_or_else(|| {
                io::Error::new(io::ErrorKind::Other, "Socket is not AF_INET or AF_INET6?")
            })?;
        Ok(Self { fd, local_port })
    }

    /// Write the contents of a received ICMP message into `buf`, returning a tuple containing the
    /// ICMP message and the range of indices in `buf` holding the message, in case mutable access
    /// to the slice is desired.
    ///
    /// Bytes outside the returned `range` may have been written to, and skipped during subsequent
    /// parsing.
    ///
    /// See [`crate::message::decode::DecodedIcmpMsg`] to extract basic ICMP message structure.
    pub async fn recv<'a>(&self, buf: &'a mut [u8]) -> io::Result<(&'a [u8], ops::Range<usize>)> {
        self.fd
            .async_io(tokio::io::Interest::READABLE, |inner| {
                // the read() impl is a simple wrapper around recv(2)
                (&inner.socket).read(buf)
            })
            .await
            .and_then(|len| V::extract_icmp_from_recv_packet(&buf[..len]))
    }

    /// Send `msg` to `addr`.
    ///
    /// If `msg` doesn't support the socket's IP version, an error will be returned.
    pub async fn send_to(
        &self,
        msg: &mut impl EncodeIcmpMessage<V>,
        addr: V::Address,
    ) -> io::Result<()> {
        self.fd
            .async_io(tokio::io::Interest::WRITABLE, |inner| {
                let buffer = msg.encode();

                if V::checksum_required() {
                    buffer.calculate_icmpv4_checksum();
                }

                // port is not used
                let socket_addr = net::SocketAddr::new(addr.into(), 0);
                inner.socket.send_to(buffer.as_slice(), &socket_addr.into())
            })
            .await
            .map(|_| ())
    }

    /// Returns the local port of the socket.
    ///
    /// This is useful on Linux since the local port is used as the ICMP Echo ID regardless of what
    /// is set in userspace.
    ///
    /// On macOS, the local port is always zero, but ICMP Echo ids are not tied to the local port,
    /// so it's not an issue in practice.
    ///
    /// See [`platform::icmp_send_overwrite_echo_id_with_local_port`].
    pub fn local_port(&self) -> u16 {
        self.local_port
    }

    /// Returns the local port of the socket as the `id` to be used in an ICMP Echo Request message,
    /// if the current platform is one that forces the id to match the local port.
    ///
    /// See [`platform::icmp_send_overwrite_echo_id_with_local_port`].
    ///
    /// # Examples
    ///
    /// Use the platform echo id, otherwise a random id.
    /// ```
    /// use async_icmp::{IcmpVersion, message::echo::EchoId, socket::IcmpSocket};
    /// use std::io;
    ///
    /// fn echo_id<V: IcmpVersion>(socket: &IcmpSocket<V>) -> EchoId {
    ///     socket.platform_echo_id().unwrap_or_else(rand::random)
    /// }
    /// ```
    pub fn platform_echo_id(&self) -> Option<EchoId> {
        if platform::icmp_send_overwrite_echo_id_with_local_port() {
            Some(EchoId::from_be(self.local_port()))
        } else {
            None
        }
    }
}

/// Config for creating sockets.
///
/// Most use cases can use `SocketConfig::default()`.
///
/// To avoid compatibility concerns when more fields are added, use the `..` struct update syntax
/// so that any new fields will be conveniently defaulted in existing invocations:
///
/// ```
/// use std::net;
/// use async_icmp::{Icmpv4, socket::SocketConfig};
///
/// let config: SocketConfig<Icmpv4> = SocketConfig {
///     bind_to: Some(net::SocketAddrV4::new(net::Ipv4Addr::LOCALHOST, 1234)),
///     ..SocketConfig::default()
/// };
/// ```
#[derive(Debug, Clone)]
pub struct SocketConfig<V: IcmpVersion> {
    /// The sockaddr to bind the socket to. If specified with `Some`, the socket is always bound to
    /// that address.
    ///
    /// If `None`, the behavior depends on the platform. On all supported platforms, a fresh
    /// socket's local address is initially the unspecified address with port 0 (`0.0.0.0:0` or
    /// `[::]:0`).
    ///
    /// On Linux, explicitly binding that address causes the kernel to select a local port. This is
    /// useful for ICMP Echo messages, since Linux forces the echo id to be the local port.
    ///
    /// On macOS, binding that address makes no difference: an ICMP socket always has zero local
    /// port, so the bind is not performed.
    pub bind_to: Option<V::SocketAddr>,
}

impl<V: IcmpVersion> Default for SocketConfig<V> {
    fn default() -> Self {
        Self { bind_to: None }
    }
}

/// A non-public wrapper around the raw socket.
///
/// It exists so that the `AsRawFd` impl required by [`unix::AsyncFd`] can be provided. The
/// socket is tagged with the IP version `V` at the type level.
#[derive(Debug)]
struct IcmpSocketInner<V> {
    // The OS socket; created as a non-blocking ICMP datagram socket in `IcmpSocketInner::new`.
    socket: socket2::Socket,
    // Zero-sized marker tying this socket to IP version `V` without storing a `V` value.
    marker: marker::PhantomData<V>,
}

impl<V: IcmpVersion> IcmpSocketInner<V> {
    /// Create a non-blocking ICMP datagram socket for IP version `V`, bound according to `config`.
    ///
    /// When `config.bind_to` is `Some`, the socket is bound to that address. When it is `None`,
    /// the socket is bound to `V::DEFAULT_BIND`, but only on platforms where
    /// `platform::socket_bind_sets_nonzero_local_port()` is true. Otherwise it is left unbound.
    ///
    /// # Errors
    ///
    /// Propagates any error from creating, configuring, or binding the socket.
    fn new(config: SocketConfig<V>) -> io::Result<Self> {
        let socket = socket2::Socket::new(V::DOMAIN, socket2::Type::DGRAM, Some(V::PROTOCOL))?;
        socket.set_nonblocking(true)?;

        // A fresh socket reports the unspecified address with port 0 as its local address.
        // Binding that same address explicitly makes the kernel assign a local port right away,
        // instead of one appearing implicitly later.
        if let Some(sockaddr) = config.bind_to {
            socket.bind(&sockaddr.into().into())?;
        } else if platform::socket_bind_sets_nonzero_local_port() {
            socket.bind(&V::DEFAULT_BIND.into().into())?;
        }

        Ok(Self {
            marker: marker::PhantomData,
            socket,
        })
    }
}

/// Required by [unix::AsyncFd]
impl<V> fd::AsRawFd for IcmpSocketInner<V> {
    fn as_raw_fd(&self) -> fd::RawFd {
        self.socket.as_raw_fd()
    }
}

// only used on macOS
/// Error type returned by the winnow-based byte parsers in this module (e.g. the return type of
/// `strip_ipv4_header`).
///
/// It pairs the location-tracking input with a [`winnow::error::ContextError`] whose context
/// values are of type `C`.
pub(crate) type WinnowError<'a, C> =
    winnow::error::ParseError<winnow::Located<&'a [u8]>, winnow::error::ContextError<C>>;

/// Returns a result with a tuple of `(data after the ipv4 header, index range of the data)`.
///
/// The index range is useful if the caller wants to treat the data as a `&mut [u8]`.
///
/// `input` must start with an IPv4 header. Its first byte holds the 4-bit version (must be 4)
/// followed by the 4-bit IHL, the header length in 32-bit words. The whole header, including any
/// options, is skipped and everything after it is returned.
///
/// # Errors
///
/// Returns a parse error in any of these cases:
/// - the version is not 4;
/// - the IHL is below the IPv4 minimum of 5 (a 20-byte header);
/// - `input` is shorter than the header length claims.
// only used on macOS
pub(crate) fn strip_ipv4_header(
    input: &[u8],
) -> Result<(&[u8], ops::Range<usize>), WinnowError<&'static str>> {
    // discard complete ip header
    combinator::preceded(
        binary::bits::bits(
            // skip the rest of the ipv4 header, whose length comes from the IHL
            binary::length_take(
                // verify and discard ip version, yielding the number of header bits left to skip
                combinator::preceded(
                    // 4 bit version
                    binary::bits::pattern::<_, _, _, winnow::error::ContextError<&'static str>>(
                        0x04_u8, 4_usize,
                    )
                    .context("Invalid version"),
                    // 4 bit IHL: header length in 32-bit words
                    binary::bits::take(4_usize)
                        .verify_map(|ihl: usize| {
                            // RFC 791: IHL must be at least 5 (20 bytes). Smaller values would
                            // cause part of the header to be returned as payload.
                            if ihl < 5 {
                                None
                            } else {
                                // Convert to bits, minus the 8 bits (version + IHL) already
                                // consumed. ihl <= 15, so this cannot overflow.
                                Some(ihl * 32 - 8)
                            }
                        })
                        .context("Invalid header length"),
                ),
            ),
        ),
        combinator::rest::<_, winnow::error::ContextError<_>>.with_span(),
    )
    .parse(winnow::Located::new(input))
}