use std::collections::HashMap;
use std::io;
use std::net::{IpAddr, SocketAddr};
use std::sync::{
atomic::{AtomicU16, Ordering},
Arc, Mutex, OnceLock,
};
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use futures::channel::oneshot;
use socket2::{Domain, Protocol, Socket, Type};
use tokio::{net::UdpSocket, time};
use crate::{
icmp::IcmpPacket, IcmpEchoReply, IcmpEchoStatus, PING_DEFAULT_TIMEOUT, PING_DEFAULT_TTL,
};
/// One-shot sender through which the reply router hands a reply back to the
/// `send` call that is waiting for it.
type ReplySender = oneshot::Sender<IcmpEchoReply>;

/// Pending echo requests shared between requestor and router, keyed by the
/// ICMP sequence number of the outstanding request.
type RequestRegistry = Arc<Mutex<HashMap<u16, ReplySender>>>;
/// Owned snapshot of the `io::Error` that stopped the reply router.
///
/// `io::Error` does not implement `Clone`, so only its kind and display text
/// are kept; a fresh error is rebuilt for every caller that has to report it.
struct RouterError {
    kind: io::ErrorKind,
    message: String,
}

impl RouterError {
    /// Captures the kind and the `Display` text of `e`.
    fn from_io_error(e: &io::Error) -> Self {
        let kind = e.kind();
        let message = e.to_string();
        Self { kind, message }
    }

    /// Builds a new `io::Error` carrying the stored kind and message.
    fn to_io_error(&self) -> io::Error {
        io::Error::new(self.kind, self.message.as_str())
    }
}
/// Everything the background reply router is given when it is spawned.
struct RouterContext {
    /// Address being pinged; forwarded to the reply/error packet parsers.
    target_addr: IpAddr,
    /// Socket shared with the requestor; the router only receives on it.
    socket: Arc<UdpSocket>,
    /// Pending requests, completed by the router as replies arrive.
    registry: RequestRegistry,
    /// Filled in by the router when it stops on a fatal receive error.
    failed: Arc<Mutex<Option<RouterError>>>,
}
/// Sends ICMP echo requests to a single target address.
///
/// Cloning is cheap: every clone shares the same socket, sequence counter,
/// pending-request registry and background reply router.
#[derive(Clone)]
pub struct IcmpEchoRequestor {
    /// Shared state; the router is aborted when the last clone goes away.
    inner: Arc<RequestorInner>,
}
/// State shared by all clones of an `IcmpEchoRequestor`.
///
/// Dropping the last clone drops this value, whose `Drop` impl aborts the
/// reply router task if it was ever spawned.
struct RequestorInner {
    /// Socket echo requests are sent on (the router context holds a clone of
    /// the same `Arc`).
    socket: Arc<UdpSocket>,
    /// Destination of every echo request.
    target_addr: IpAddr,
    /// How long `send` waits for a reply before reporting a timeout.
    timeout: Duration,
    /// ICMP echo identifier written into requests; the router ignores replies
    /// carrying any other identifier.
    identifier: u16,
    /// Next sequence number to use; `fetch_add` wraps it at `u16::MAX`.
    sequence: AtomicU16,
    /// Pending requests (the same `Arc` as `router_context.registry`).
    registry: RequestRegistry,
    /// Set exactly once, when the first `send` spawns the router.
    router_abort: OnceLock<tokio::task::AbortHandle>,
    /// What the router is handed when it is spawned.
    router_context: RouterContext,
}
impl IcmpEchoRequestor {
pub fn new(
target_addr: IpAddr,
source_addr: Option<IpAddr>,
ttl: Option<u8>,
timeout: Option<Duration>,
) -> io::Result<Self> {
match (target_addr, source_addr) {
(IpAddr::V4(_), Some(IpAddr::V6(_))) | (IpAddr::V6(_), Some(IpAddr::V4(_))) => {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"Source address type does not match target address type",
));
}
_ => {}
}
let timeout = timeout.unwrap_or(PING_DEFAULT_TIMEOUT);
let sequence = AtomicU16::new(0);
let (socket, identifier) = create_socket(target_addr, source_addr, ttl)?;
let socket = Arc::new(socket);
let registry = Arc::new(Mutex::new(HashMap::new()));
let router_context = RouterContext {
target_addr,
socket: Arc::clone(&socket),
registry: Arc::clone(®istry),
failed: Arc::new(Mutex::new(None::<RouterError>)),
};
Ok(IcmpEchoRequestor {
inner: Arc::new(RequestorInner {
socket,
target_addr,
timeout,
identifier,
sequence,
registry,
router_abort: OnceLock::new(),
router_context,
}),
})
}
pub async fn send(&self) -> io::Result<IcmpEchoReply> {
if let Some(ref router_error) = *self
.inner
.router_context
.failed
.lock()
.unwrap_or_else(|poisoned| poisoned.into_inner())
{
return Err(router_error.to_io_error());
}
self.ensure_router_running();
let sequence = self.inner.sequence.fetch_add(1, Ordering::SeqCst);
let key = sequence;
let timestamp = SystemTime::now()
.duration_since(UNIX_EPOCH)
.map_err(|e| io::Error::other(format!("timestamp error: {e}")))?
.as_nanos() as u64;
let payload = timestamp.to_be_bytes();
let packet = IcmpPacket::new_echo_request(
self.inner.target_addr,
self.inner.identifier,
sequence,
&payload,
);
let (tx, reply_rx) = oneshot::channel();
self.inner
.registry
.lock()
.unwrap_or_else(|poisoned| poisoned.into_inner())
.insert(key, tx);
let target = SocketAddr::new(self.inner.target_addr, 0);
if let Err(e) = self.inner.socket.send_to(packet.as_bytes(), target).await {
self.inner
.registry
.lock()
.unwrap_or_else(|poisoned| poisoned.into_inner())
.remove(&key);
return match e.kind() {
io::ErrorKind::NetworkUnreachable
| io::ErrorKind::NetworkDown
| io::ErrorKind::HostUnreachable => Ok(IcmpEchoReply::new(
self.inner.target_addr,
IcmpEchoStatus::Unreachable,
Duration::ZERO,
)),
_ => Err(e),
};
}
let timeout = self.inner.timeout;
let target_addr = self.inner.target_addr;
tokio::select! {
result = reply_rx => {
match result {
Ok(reply) => Ok(reply),
Err(_) => {
if let Some(ref router_error) = *self
.inner
.router_context
.failed
.lock()
.unwrap_or_else(|poisoned| poisoned.into_inner())
{
Err(router_error.to_io_error())
} else {
Err(io::Error::other("reply channel closed"))
}
}
}
}
_ = time::sleep(timeout) => {
self.inner.registry.lock().unwrap_or_else(|poisoned| poisoned.into_inner()).remove(&key);
let now = SystemTime::now()
.duration_since(UNIX_EPOCH)
.map_err(|e| io::Error::other(format!("timestamp error: {e}")))?
.as_nanos() as u64;
let rtt = Duration::from_nanos(now.saturating_sub(timestamp));
Ok(IcmpEchoReply::new(
target_addr,
IcmpEchoStatus::TimedOut,
rtt,
))
}
}
}
fn ensure_router_running(&self) {
let target_addr = self.inner.router_context.target_addr;
let identifier = self.inner.identifier;
let socket = Arc::clone(&self.inner.router_context.socket);
let registry = Arc::clone(&self.inner.router_context.registry);
let failed = Arc::clone(&self.inner.router_context.failed);
self.inner.router_abort.get_or_init(|| {
let handle = tokio::spawn(reply_router_loop(
target_addr,
identifier,
socket,
registry,
failed,
));
handle.abort_handle()
});
}
}
impl Drop for RequestorInner {
    /// Aborts the background reply router, if one was ever spawned, once the
    /// last `IcmpEchoRequestor` clone sharing this state is gone.
    fn drop(&mut self) {
        let Some(router) = self.router_abort.get() else {
            return;
        };
        router.abort();
    }
}
async fn reply_router_loop(
target_addr: IpAddr,
identifier: u16,
socket: Arc<UdpSocket>,
registry: RequestRegistry,
failed: Arc<Mutex<Option<RouterError>>>,
) {
loop {
let mut buf = vec![0u8; 1024];
match socket.recv(&mut buf).await {
Ok(size) => {
buf.truncate(size);
if let Some(reply_packet) = IcmpPacket::parse_reply(&buf, target_addr) {
if reply_packet.identifier() != identifier {
continue;
}
let key = reply_packet.sequence();
let sender = registry
.lock()
.unwrap_or_else(|poisoned| poisoned.into_inner())
.remove(&key);
if let Some(sender) = sender {
let payload = reply_packet.payload();
let reply = if payload.len() >= 8 {
let sent_timestamp = u64::from_be_bytes([
payload[0], payload[1], payload[2], payload[3], payload[4],
payload[5], payload[6], payload[7],
]);
let now = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or_default()
.as_nanos() as u64;
let rtt = Duration::from_nanos(now.saturating_sub(sent_timestamp));
IcmpEchoReply::new(target_addr, IcmpEchoStatus::Success, rtt)
} else {
IcmpEchoReply::new(target_addr, IcmpEchoStatus::Unknown, Duration::ZERO)
};
let _ = sender.send(reply);
}
} else if let Some(error_info) = IcmpPacket::parse_error_reply(&buf, target_addr) {
if error_info.identifier != identifier {
continue;
}
let sender = registry
.lock()
.unwrap_or_else(|poisoned| poisoned.into_inner())
.remove(&error_info.sequence);
if let Some(sender) = sender {
let reply =
IcmpEchoReply::new(target_addr, error_info.status, Duration::ZERO);
let _ = sender.send(reply);
}
}
}
Err(e) => {
match e.kind() {
io::ErrorKind::PermissionDenied | io::ErrorKind::AddrNotAvailable | io::ErrorKind::ConnectionAborted | io::ErrorKind::NotConnected => { let mut failed_lock = failed.lock().unwrap_or_else(|poisoned| poisoned.into_inner());
*failed_lock = Some(RouterError::from_io_error(&e));
registry.lock().unwrap_or_else(|poisoned| poisoned.into_inner()).clear();
return;
}
_ => continue,
}
}
}
}
}
/// Opens a non-blocking datagram ICMP socket (ICMPv4 or ICMPv6, matching
/// `target_addr`). Returns it as a Tokio `UdpSocket` together with the echo
/// identifier to use.
///
/// * `ttl` sets the IPv4 TTL or the IPv6 unicast hop limit (defaults to
///   `PING_DEFAULT_TTL`).
/// * On Linux the socket is always bound, either to `source_addr` or to the
///   unspecified address of the target's family. The kernel-assigned local
///   port is then used as the identifier.
/// * On other platforms the socket is bound only when `source_addr` is given,
///   and the identifier is random.
///
/// # Errors
///
/// Any error from creating, configuring or binding the socket, from reading
/// back its local address (Linux), or from registering it with Tokio.
///
/// NOTE(review): Tokio's `UdpSocket::from_std` requires a Tokio runtime
/// context, so this must be called from inside one.
fn create_socket(
    target_addr: IpAddr,
    source_addr: Option<IpAddr>,
    ttl: Option<u8>,
) -> io::Result<(UdpSocket, u16)> {
    let (domain, protocol) = match target_addr {
        IpAddr::V4(_) => (Domain::IPV4, Protocol::ICMPV4),
        IpAddr::V6(_) => (Domain::IPV6, Protocol::ICMPV6),
    };
    let socket = Socket::new(domain, Type::DGRAM, Some(protocol))?;
    socket.set_nonblocking(true)?;

    // Lossless widening; the socket options take a `u32`.
    let ttl = u32::from(ttl.unwrap_or(PING_DEFAULT_TTL));
    if target_addr.is_ipv4() {
        socket.set_ttl_v4(ttl)?;
    } else {
        socket.set_unicast_hops_v6(ttl)?;
    }

    #[cfg(not(target_os = "linux"))]
    let identifier = {
        if let Some(source_addr) = source_addr {
            socket.bind(&SocketAddr::new(source_addr, 0).into())?;
        }
        rand::random()
    };

    #[cfg(target_os = "linux")]
    let identifier = {
        let bind_addr = source_addr.unwrap_or(match target_addr {
            IpAddr::V4(_) => IpAddr::V4(std::net::Ipv4Addr::UNSPECIFIED),
            IpAddr::V6(_) => IpAddr::V6(std::net::Ipv6Addr::UNSPECIFIED),
        });
        socket.bind(&SocketAddr::new(bind_addr, 0).into())?;
        // The error is only built when the address is not an IP socket address.
        socket
            .local_addr()?
            .as_socket()
            .ok_or_else(|| io::Error::other("Failed to get kernel-assigned ICMP identifier"))?
            .port()
    };

    let udp_socket = UdpSocket::from_std(socket.into())?;
    Ok((udp_socket, identifier))
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::io;

    /// Whether the lazily spawned reply router has been started for `pinger`.
    fn is_router_spawned(pinger: &IcmpEchoRequestor) -> bool {
        pinger.inner.router_abort.get().is_some()
    }

    /// `new()` must not spawn the router. The first `send()` spawns it, and
    /// it stays spawned for later sends.
    ///
    /// NOTE(review): this opens a real datagram ICMP socket to 127.0.0.1, so
    /// the test environment must permit such sockets.
    #[tokio::test]
    async fn test_lazy_router_spawning() -> io::Result<()> {
        let localhost: IpAddr = "127.0.0.1".parse().unwrap();
        let pinger = IcmpEchoRequestor::new(localhost, None, None, None)?;
        assert!(
            !is_router_spawned(&pinger),
            "Router should not be spawned after new()"
        );

        let reply = pinger.send().await?;
        assert_eq!(reply.destination(), localhost);
        assert!(
            is_router_spawned(&pinger),
            "Router should be spawned after first send()"
        );

        let reply2 = pinger.send().await?;
        assert_eq!(reply2.destination(), localhost);
        assert!(
            is_router_spawned(&pinger),
            "Router should remain spawned after subsequent sends"
        );
        Ok(())
    }
}