use crate::{
message::{echo::EchoId, EncodeIcmpMessage},
platform, IcmpVersion,
};
use std::{io, io::Read as _, marker, net, ops, os::fd};
use tokio::io::unix;
use winnow::{binary, combinator, Parser as _};
mod pair;
pub use pair::SocketPair;
#[cfg(test)]
mod tests;
/// Asynchronous ICMP socket, generic over the ICMP/IP version marker `V`.
///
/// The underlying OS socket is wrapped in a Tokio `AsyncFd` so that reads and
/// writes can await readiness instead of blocking.
#[derive(Debug)]
pub struct IcmpSocket<V> {
/// Readiness-tracked wrapper around the underlying socket.
fd: unix::AsyncFd<IcmpSocketInner<V>>,
/// Local port of the socket, queried once at construction time and cached.
local_port: u16,
}
impl<V: IcmpVersion> IcmpSocket<V> {
pub fn new(config: SocketConfig<V>) -> io::Result<Self> {
let fd = unix::AsyncFd::new(IcmpSocketInner::new(config)?)?;
let local_port = fd
.get_ref()
.socket
.local_addr()?
.as_socket()
.map(|sa| sa.port())
.ok_or_else(|| {
io::Error::new(io::ErrorKind::Other, "Socket is not AF_INET or AF_INET6?")
})?;
Ok(Self { fd, local_port })
}
pub async fn recv<'a>(&self, buf: &'a mut [u8]) -> io::Result<(&'a [u8], ops::Range<usize>)> {
self.fd
.async_io(tokio::io::Interest::READABLE, |inner| {
(&inner.socket).read(buf)
})
.await
.and_then(|len| V::extract_icmp_from_recv_packet(&buf[..len]))
}
pub async fn send_to(
&self,
msg: &mut impl EncodeIcmpMessage<V>,
addr: V::Address,
) -> io::Result<()> {
self.fd
.async_io(tokio::io::Interest::WRITABLE, |inner| {
let buffer = msg.encode();
if V::checksum_required() {
buffer.calculate_icmpv4_checksum();
}
let socket_addr = net::SocketAddr::new(addr.into(), 0);
inner.socket.send_to(buffer.as_slice(), &socket_addr.into())
})
.await
.map(|_| ())
}
pub fn local_port(&self) -> u16 {
self.local_port
}
pub fn platform_echo_id(&self) -> Option<EchoId> {
if platform::icmp_send_overwrite_echo_id_with_local_port() {
Some(EchoId::from_be(self.local_port()))
} else {
None
}
}
}
/// Options used when creating an ICMP socket.
#[derive(Debug, Clone)]
pub struct SocketConfig<V: IcmpVersion> {
/// Address to bind the socket to.
///
/// When `None`, the socket is bound to `V::DEFAULT_BIND` only if
/// `platform::socket_bind_sets_nonzero_local_port()` returns `true`;
/// otherwise it is not explicitly bound.
pub bind_to: Option<V::SocketAddr>,
}
impl<V: IcmpVersion> Default for SocketConfig<V> {
    /// Produces a configuration with no explicit bind address.
    fn default() -> Self {
        let bind_to = None;
        Self { bind_to }
    }
}
/// The raw socket that backs an ICMP socket, tagged with the version marker
/// `V`.
///
/// This is the type handed to Tokio's `AsyncFd`, via its `AsRawFd` impl.
#[derive(Debug)]
struct IcmpSocketInner<V> {
/// The underlying OS socket.
socket: socket2::Socket,
/// Carries the `V` type parameter without storing a value of it.
marker: marker::PhantomData<V>,
}
impl<V: IcmpVersion> IcmpSocketInner<V> {
    /// Opens a non-blocking datagram socket for `V::DOMAIN` / `V::PROTOCOL`
    /// and binds it as requested by `config`.
    ///
    /// Binding rules:
    /// - an explicit `config.bind_to` address is always used;
    /// - without one, the socket is bound to `V::DEFAULT_BIND` only when
    ///   `platform::socket_bind_sets_nonzero_local_port()` is `true`;
    /// - otherwise no bind is performed.
    ///
    /// # Errors
    ///
    /// Propagates any I/O error from creating the socket, switching it to
    /// non-blocking mode, or binding it.
    fn new(config: SocketConfig<V>) -> io::Result<Self> {
        let sock = socket2::Socket::new(V::DOMAIN, socket2::Type::DGRAM, Some(V::PROTOCOL))?;
        sock.set_nonblocking(true)?;

        if let Some(requested) = config.bind_to {
            sock.bind(&requested.into().into())?;
        } else if platform::socket_bind_sets_nonzero_local_port() {
            sock.bind(&V::DEFAULT_BIND.into().into())?;
        }

        Ok(Self {
            socket: sock,
            marker: marker::PhantomData,
        })
    }
}
impl<V> fd::AsRawFd for IcmpSocketInner<V> {
    /// Exposes the raw file descriptor of the wrapped socket.
    fn as_raw_fd(&self) -> fd::RawFd {
        let Self { socket, .. } = self;
        fd::AsRawFd::as_raw_fd(socket)
    }
}
/// Error produced by running a `winnow` parser to completion over a
/// `Located<&[u8]>` input, carrying `ContextError<C>` as the inner error.
pub(crate) type WinnowError<'a, C> =
winnow::error::ParseError<winnow::Located<&'a [u8]>, winnow::error::ContextError<C>>;
/// Skips the IPv4 header at the start of `input` and returns what follows it.
///
/// The first 4 bits must equal 4 (the IP version). The next 4 bits are read
/// as the header length in 32-bit words, and that many words are discarded.
///
/// # Returns
///
/// The bytes remaining after the header, paired with the span those bytes
/// occupy within `input`.
///
/// # Errors
///
/// Returns a parse error if the version bits are not 4, if the header length
/// field is 0, or if `input` is shorter than the declared header length.
///
/// NOTE(review): header length values 1-4 pass the check below, although
/// RFC 791 specifies a minimum of 5 - confirm whether that is intended.
pub(crate) fn strip_ipv4_header(
input: &[u8],
) -> Result<(&[u8], ops::Range<usize>), WinnowError<&'static str>> {
// `preceded` runs the header parser, discards its output, and yields only
// the result of the second parser.
combinator::preceded(
// Parse the header at bit granularity.
binary::bits::bits(
// Take as many bits as the inner length parser computes.
binary::length_take(
combinator::preceded(
// High nibble of the first byte: must be the value 4.
binary::bits::pattern::<_, _, _, winnow::error::ContextError<&'static str>>(
0x04_u8, 4_usize,
)
.context("Invalid version"),
// Low nibble: header length in 32-bit words.
binary::bits::take(4_usize)
.verify_map(|len: usize| {
// Convert words to bits, then subtract the 8 bits (version +
// length nibbles) already consumed. A length of 0 underflows
// and is rejected.
len.checked_mul(32).and_then(|prod| prod.checked_sub(8))
}),
),
),
),
// Everything after the header, together with its span in the input.
combinator::rest::<_, winnow::error::ContextError<_>>.with_span(),
)
.parse(winnow::Located::new(input))
}