async_icmp/socket/mod.rs
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176
//! ICMP sockets.
use crate::message::EncodeIcmpMessage;
use crate::IpVersion;
use std::io::Read;
use std::{io, net, ops, os::fd};
use tokio::io::unix;
use winnow::{binary, combinator, Parser as _};
// Re-exported since it's in our public API.
pub use socket2;
#[cfg(test)]
mod tests;
/// An ICMP socket.
///
/// Every I/O method takes `&self`, so the usual pattern is to put the socket in an `Arc` and
/// hand clones to several tasks (for instance, a sender task and a receiver task).
#[derive(Debug)]
pub struct IcmpSocket {
    /// IP version fixed at construction; selects the encoding used by `send_to` and how `recv`
    /// interprets received bytes.
    ip_version: IpVersion,
    /// The non-blocking socket, wrapped in tokio's `AsyncFd` for readiness-driven I/O.
    fd: unix::AsyncFd<IcmpSocketInner>,
}
impl IcmpSocket {
    /// Create a new socket with the specified IP version.
    ///
    /// When a socket is created, it's either IPv4 (socket type `AF_INET`) or IPv6 (`AF_INET6`), which
    /// governs which type of IP address is valid to use with [`IcmpSocket::send_to`].
    ///
    /// # Errors
    ///
    /// Returns any error from creating the OS socket, switching it to non-blocking mode, or
    /// registering it with tokio.
    pub fn new(ip_version: IpVersion) -> io::Result<Self> {
        Ok(Self {
            ip_version,
            fd: unix::AsyncFd::new(IcmpSocketInner::new(ip_version)?)?,
        })
    }

    /// Write the contents of a received ICMP message into `buf`, returning a tuple containing the
    /// ICMP message and the range of indices in `buf` holding the message (the message is always
    /// `&buf[range]`).
    ///
    /// Bytes outside the returned `range` may have been written to, and skipped during subsequent
    /// parsing.
    ///
    /// See [`crate::message::decode::DecodedIcmpMsg`] to extract basic ICMP message structure.
    ///
    /// # Platform differences
    ///
    /// Platforms differ on what messages will be exposed to userspace this way. On Linux, only
    /// ICMP Echo Reply messages where the id = the local port will be visible. On macOS, the
    /// kernel is less restrictive, and other ICMP packets will also be visible, so additional
    /// filtering may be needed depending on the use case.
    ///
    /// Platforms also differ in framing for IPv4. macOS prepends the IPv4 header to each received
    /// datagram, and it is stripped here. Linux "ping" sockets deliver only the ICMP message, so
    /// nothing is stripped.
    ///
    /// # Errors
    ///
    /// Returns any error from the underlying receive, or [`io::ErrorKind::InvalidData`] if an
    /// IPv4 header was expected but could not be parsed.
    pub async fn recv<'a>(&self, buf: &'a mut [u8]) -> io::Result<(&'a [u8], ops::Range<usize>)> {
        let len = self
            .fd
            .async_io(tokio::io::Interest::READABLE, |inner| {
                // the read() impl is a simple wrapper around recv(2)
                (&inner.socket).read(buf)
            })
            .await?;
        let received = &buf[..len];
        match self.ip_version {
            // Linux ping sockets (SOCK_DGRAM + IPPROTO_ICMP) hand userspace only the ICMP
            // message; the kernel has already removed the IPv4 header. Trying to strip it would
            // misread the ICMP type byte as the IP version and always fail.
            IpVersion::V4 if cfg!(any(target_os = "linux", target_os = "android")) => {
                Ok((received, 0..len))
            }
            // Elsewhere (e.g. macOS) the IPv4 header is included and must be skipped.
            IpVersion::V4 => strip_ipv4_header(received).map_err(|_e| {
                io::Error::new(io::ErrorKind::InvalidData, "Could not strip IPv4 header")
            }),
            // IPV6 doesn't include headers, so we can use the length as is
            IpVersion::V6 => Ok((received, 0..len)),
        }
    }

    /// Send `msg` to `addr`.
    ///
    /// If `msg` doesn't support the socket's IP version, an error will be returned.
    ///
    /// # Errors
    ///
    /// Returns the encoding error (converted to [`io::Error`]) or any error from the underlying
    /// `send_to`.
    pub async fn send_to(
        &self,
        addr: net::IpAddr,
        msg: &mut impl EncodeIcmpMessage,
    ) -> io::Result<()> {
        self.fd
            .async_io(tokio::io::Interest::WRITABLE, |inner| {
                // port is not used
                let socket_addr = net::SocketAddr::new(addr, 0);
                inner.socket.send_to(
                    msg.encode_for_version(self.ip_version)
                        .map_err(io::Error::from)?,
                    &socket_addr.into(),
                )
            })
            .await
            .map(|_| ())
    }

    /// Access the underlying socket for any needed customization.
    ///
    /// Don't use this unless you know what you're doing. For example, the socket must stay in
    /// non-blocking mode for the async methods on this type to work.
    pub fn as_mut_socket(&mut self) -> &mut socket2::Socket {
        &mut self.fd.get_mut().socket
    }

    /// Returns the local port of the socket.
    ///
    /// This is useful on Linux since the local port is used as the ICMP Echo ID.
    ///
    /// # Errors
    ///
    /// Returns an error if the local address can't be queried or isn't an IPv4/IPv6 address.
    pub fn local_port(&self) -> io::Result<u16> {
        let local_addr = self.fd.get_ref().socket.local_addr()?;
        local_addr.as_socket().map(|sa| sa.port()).ok_or_else(|| {
            io::Error::new(io::ErrorKind::Other, "Socket is not AF_INET or AF_INET6?")
        })
    }

    /// The IP version used by the socket
    pub fn ip_version(&self) -> IpVersion {
        self.ip_version
    }
}
/// Private holder for the OS socket. It exists only so the crate can implement the traits that
/// [unix::AsyncFd] requires of its wrapped value.
#[derive(Debug)]
struct IcmpSocketInner {
    /// The underlying socket2 socket.
    socket: socket2::Socket,
}
impl IcmpSocketInner {
    /// Open a non-blocking datagram ICMP socket for `ip_version`.
    ///
    /// IPv4 pairs the IPv4 domain with ICMPv4; IPv6 pairs the IPv6 domain with ICMPv6.
    fn new(ip_version: IpVersion) -> io::Result<Self> {
        let (domain, protocol) = match ip_version {
            IpVersion::V4 => (socket2::Domain::IPV4, socket2::Protocol::ICMPV4),
            IpVersion::V6 => (socket2::Domain::IPV6, socket2::Protocol::ICMPV6),
        };
        let socket = socket2::Socket::new(domain, socket2::Type::DGRAM, Some(protocol))?;
        // tokio's AsyncFd expects a non-blocking fd so that readiness-driven I/O never blocks.
        socket.set_nonblocking(true)?;
        Ok(Self { socket })
    }
}
/// Exposes the wrapped socket's file descriptor, which [unix::AsyncFd] needs in order to watch
/// the socket for readiness.
impl fd::AsRawFd for IcmpSocketInner {
    fn as_raw_fd(&self) -> fd::RawFd {
        fd::AsRawFd::as_raw_fd(&self.socket)
    }
}
type WinnowError<'a> =
winnow::error::ParseError<winnow::Located<&'a [u8]>, winnow::error::ContextError<&'static str>>;
/// Strip the IPv4 header from the front of `input`.
///
/// Returns a result with a tuple of `(data after the ipv4 header, index range of the data)`.
///
/// The index range is relative to the start of `input` (so the data is `&input[range]`), and is
/// useful if the caller wants to treat the data as a `&mut [u8]`.
///
/// Parsing fails if any of these hold:
/// - the version nibble isn't 4;
/// - the header length (IHL) is below the RFC 791 minimum of 5 32-bit words;
/// - `input` is shorter than the declared header length.
fn strip_ipv4_header(input: &[u8]) -> Result<(&[u8], ops::Range<usize>), WinnowError> {
    // discard complete ip header
    combinator::preceded(
        binary::bits::bits(
            // get and take ipv4 header len (measured in bits)
            binary::length_take(
                // verify and discard ip version, yielding just the header length
                combinator::preceded(
                    // 4 bit version
                    binary::bits::pattern::<_, _, _, winnow::error::ContextError<&'static str>>(
                        0x04_u8, 4_usize,
                    )
                    .context("Invalid version"),
                    // 4 bit Internet Header Length (IHL), in 32-bit words
                    binary::bits::take(4_usize).verify_map(|ihl: usize| {
                        if ihl < 5 {
                            // A real IPv4 header is at least 20 bytes; accepting a smaller IHL
                            // would hand part of the header back as if it were payload.
                            None
                        } else {
                            // length includes the 8 bits (version + IHL) we just parsed
                            Some(ihl * 32 - 8)
                        }
                    }),
                ),
            ),
        ),
        combinator::rest::<_, winnow::error::ContextError<_>>.with_span(),
    )
    .parse(winnow::Located::new(input))
}