use super::{ProbeResult, Prober};
use crate::error::NetPulseError;
use async_trait::async_trait;
use socket2::{Domain, Protocol, Socket, Type};
use std::net::{IpAddr, SocketAddr};
use std::time::{Duration, Instant};
use tokio::time::timeout;
/// ICMP message type for "Time Exceeded" (type 11), matched when parsing replies.
const ICMP_TIME_EXCEEDED: u8 = 11;
/// Prober that checks reachability by sending ICMP echo requests.
///
/// The only configuration it carries is how long to wait for a reply.
#[derive(Debug, Clone)]
pub struct IcmpProber {
    /// Maximum time to wait for a reply to a single probe, in milliseconds.
    timeout_ms: u64,
}
impl IcmpProber {
pub fn new(timeout_ms: u64) -> Self {
Self { timeout_ms }
}
}
/// ICMP message type of an echo request; written into byte 0 of outgoing packets.
const ICMP_ECHO_REQUEST: u8 = 8;
/// ICMP message type of an echo reply.
const ICMP_ECHO_REPLY: u8 = 0;
/// Size in bytes of the ICMP header (type, code, checksum, identifier, sequence).
const ICMP_HEADER_SIZE: usize = 8;
// PAYLOAD_SIZE: bytes following the header; the packet builder fills them with a
// big-endian `u64` microsecond timestamp.
// PACKET_SIZE: total length of an outgoing echo request (header plus payload).
const PAYLOAD_SIZE: usize = 8; const PACKET_SIZE: usize = ICMP_HEADER_SIZE + PAYLOAD_SIZE;
/// Builds a complete ICMP echo request carrying `id` and `seq`.
///
/// Layout (all multi-byte fields big-endian):
/// byte 0 type, byte 1 code (0), bytes 2-3 checksum, bytes 4-5 identifier,
/// bytes 6-7 sequence number, bytes 8.. the current wall-clock time in
/// microseconds since the Unix epoch (0 if the clock reads before the epoch).
fn build_icmp_packet(id: u16, seq: u16) -> [u8; PACKET_SIZE] {
    let sent_at_us = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .unwrap_or_default()
        .as_micros() as u64;

    // The array starts zeroed, so the code byte and the checksum field are
    // already 0, which is what the checksum computation below requires.
    let mut frame = [0u8; PACKET_SIZE];
    frame[0] = ICMP_ECHO_REQUEST;
    frame[4..6].copy_from_slice(&id.to_be_bytes());
    frame[6..8].copy_from_slice(&seq.to_be_bytes());
    frame[8..].copy_from_slice(&sent_at_us.to_be_bytes());

    // The checksum covers the whole message, so it is filled in last.
    let [hi, lo] = icmp_checksum(&frame).to_be_bytes();
    frame[2] = hi;
    frame[3] = lo;
    frame
}
/// Computes the 16-bit one's-complement Internet checksum of `data`.
///
/// Bytes are summed as big-endian 16-bit words; a trailing odd byte is
/// treated as the high byte of a final word. Carries are folded back into
/// the low 16 bits and the result is bitwise inverted.
fn icmp_checksum(data: &[u8]) -> u16 {
    let words = data.chunks_exact(2);
    let leftover = words.remainder();

    let mut acc: u32 = words
        .map(|pair| u32::from(u16::from_be_bytes([pair[0], pair[1]])))
        .sum();

    if let Some(&last) = leftover.first() {
        acc += u32::from(last) << 8;
    }

    // End-around carry: fold anything above bit 15 back into the sum.
    while acc > 0xFFFF {
        acc = (acc & 0xFFFF) + (acc >> 16);
    }
    !(acc as u16)
}
fn parse_icmp_reply(buf: &[u8], expected_id: u16, expected_seq: u16) -> Option<(u64, bool)> {
if buf.len() < 20 + PACKET_SIZE {
return None;
}
let ip_header_len = ((buf[0] & 0x0F) as usize) * 4;
let icmp_start = ip_header_len;
if buf.len() < icmp_start + 8 {
return None;
}
let icmp_type = buf[icmp_start];
if icmp_type == ICMP_ECHO_REPLY {
let icmp = &buf[icmp_start..];
if icmp.len() < PACKET_SIZE {
return None;
}
let reply_id = u16::from_be_bytes([icmp[4], icmp[5]]);
let reply_seq = u16::from_be_bytes([icmp[6], icmp[7]]);
if reply_id != expected_id || reply_seq != expected_seq {
return None;
}
let sent_us = u64::from_be_bytes(icmp[8..16].try_into().ok()?);
let now_us = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_micros() as u64;
return Some((now_us.checked_sub(sent_us)?, false));
} else if icmp_type == ICMP_TIME_EXCEEDED {
let inner_ip_start = icmp_start + 8;
if buf.len() < inner_ip_start + 20 {
return None; }
let inner_ip_header_len = ((buf[inner_ip_start] & 0x0F) as usize) * 4;
let inner_icmp_start = inner_ip_start + inner_ip_header_len;
if buf.len() < inner_icmp_start + PACKET_SIZE {
return None; }
let inner_icmp = &buf[inner_icmp_start..];
let reply_id = u16::from_be_bytes([inner_icmp[4], inner_icmp[5]]);
let reply_seq = u16::from_be_bytes([inner_icmp[6], inner_icmp[7]]);
if reply_id != expected_id || reply_seq != expected_seq {
return None;
}
let sent_us = u64::from_be_bytes(inner_icmp[8..16].try_into().ok()?);
let now_us = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_micros() as u64;
return Some((now_us.checked_sub(sent_us)?, true));
}
None
}
#[async_trait]
impl Prober for IcmpProber {
    /// Short identifier of this prober.
    fn name(&self) -> &'static str {
        "icmp"
    }

    /// Sends one ICMP echo request to `target` and waits for a matching answer.
    ///
    /// * `target` - an IPv4 address literal; host names are not resolved.
    /// * `seq` - probe sequence number; only its low 16 bits go on the wire.
    /// * `ttl` - optional IP time-to-live for the outgoing packet.
    ///
    /// Both an Echo Reply and a Time Exceeded message for this probe count as
    /// a successful result; the responder's address is reported alongside the
    /// round-trip time in microseconds. The round-trip time is measured with
    /// a monotonic clock from just before the packet is sent.
    ///
    /// # Errors
    ///
    /// * `NetPulseError::InvalidTarget` if `target` is not an IPv4 literal.
    /// * `NetPulseError::InsufficientPrivileges` if the raw socket cannot be
    ///   opened for lack of permission.
    ///
    /// Every other failure (timeout, socket error, failed blocking task) is
    /// reported as `Ok(ProbeResult::loss(..))`.
    async fn probe(
        &self,
        target: &str,
        seq: u64,
        ttl: Option<u32>,
    ) -> Result<ProbeResult, NetPulseError> {
        let addr: IpAddr = target.parse().map_err(|_| NetPulseError::InvalidTarget {
            target: target.to_string(),
            reason: "must be an IP address for ICMP probing".to_string(),
        })?;
        let IpAddr::V4(_) = addr else {
            return Err(NetPulseError::InvalidTarget {
                target: target.to_string(),
                reason: "only IPv4 is currently supported".to_string(),
            });
        };

        // The ICMP identifier and sequence fields are 16 bits wide, so both
        // values are deliberately truncated.
        let id = (std::process::id() & 0xFFFF) as u16;
        let seq_u16 = (seq & 0xFFFF) as u16;
        let packet = build_icmp_packet(id, seq_u16);
        let dest = SocketAddr::new(addr, 0);
        let timeout_ms = self.timeout_ms;

        // Raw-socket I/O is blocking, so it runs on the blocking pool. The
        // outer timeout bounds how long this future waits for that task.
        let outcome = timeout(
            Duration::from_millis(timeout_ms),
            tokio::task::spawn_blocking(move || {
                exchange_echo(dest, &packet, id, seq_u16, ttl, timeout_ms)
            }),
        )
        .await;

        match outcome {
            Ok(Ok(Ok((rtt_us, responder_ip)))) => {
                Ok(ProbeResult::success(target, seq, rtt_us, responder_ip))
            }
            // Missing privileges is a configuration problem, not packet loss.
            Ok(Ok(Err(e @ NetPulseError::InsufficientPrivileges))) => Err(e),
            // Probe error, failed blocking task, or outer timeout.
            Ok(Ok(Err(_))) | Ok(Err(_)) | Err(_) => Ok(ProbeResult::loss(target, seq)),
        }
    }
}

/// Sends `packet` to `dest` over a fresh raw ICMP socket and blocks until a
/// reply matching `id` / `seq` arrives or `timeout_ms` milliseconds have
/// passed since sending. Returns the round-trip time in microseconds and the
/// responder's IP address, if it could be determined.
fn exchange_echo(
    dest: SocketAddr,
    packet: &[u8],
    id: u16,
    seq: u16,
    ttl: Option<u32>,
    timeout_ms: u64,
) -> Result<(u64, Option<String>), NetPulseError> {
    use std::io::ErrorKind;

    let socket = Socket::new(Domain::IPV4, Type::RAW, Some(Protocol::ICMPV4)).map_err(|e| {
        if e.kind() == ErrorKind::PermissionDenied {
            NetPulseError::InsufficientPrivileges
        } else {
            NetPulseError::SocketError(e)
        }
    })?;

    // A probe sent with the wrong TTL would be attributed to the wrong hop,
    // so failing to set it is an error rather than something to ignore.
    if let Some(hops) = ttl {
        socket.set_ttl(hops).map_err(NetPulseError::SocketError)?;
    }

    let budget = Duration::from_millis(timeout_ms);
    let start = Instant::now();
    socket
        .send_to(packet, &dest.into())
        .map_err(NetPulseError::SocketError)?;

    let mut recv_buf = [0u8; 1024];
    loop {
        // A raw ICMP socket also sees unrelated traffic. Each wait is limited
        // to what is left of the budget, so such traffic cannot keep this
        // loop (and its thread) alive indefinitely.
        let remaining = budget.saturating_sub(start.elapsed());
        if remaining < Duration::from_millis(1) {
            // Also avoids passing a zero timeout, which means "block forever".
            return Err(NetPulseError::Timeout { timeout_ms });
        }
        socket
            .set_read_timeout(Some(remaining))
            .map_err(NetPulseError::SocketError)?;

        // SAFETY: `u8` and `MaybeUninit<u8>` have the same layout, the buffer
        // is fully initialised, and `recv_from` only writes initialised bytes
        // into it, so viewing it as `MaybeUninit<u8>` cannot expose
        // uninitialised memory.
        let spare = unsafe {
            &mut *(&mut recv_buf as *mut [u8; 1024] as *mut [std::mem::MaybeUninit<u8>; 1024])
        };
        let (n, from_addr) = match socket.recv_from(spare) {
            Ok(received) => received,
            Err(e) if e.kind() == ErrorKind::Interrupted => continue,
            Err(e) if matches!(e.kind(), ErrorKind::WouldBlock | ErrorKind::TimedOut) => {
                return Err(NetPulseError::Timeout { timeout_ms });
            }
            Err(e) => return Err(NetPulseError::SocketError(e)),
        };
        let elapsed = start.elapsed();

        if parse_icmp_reply(&recv_buf[..n], id, seq).is_some() {
            let rtt_us = u64::try_from(elapsed.as_micros()).unwrap_or(u64::MAX);
            let responder_ip = from_addr.as_socket().map(|s| s.ip().to_string());
            return Ok((rtt_us, responder_ip));
        }
    }
}