use super::{ProbeResult, Prober};
use crate::error::NetPulseError;
use async_trait::async_trait;
use socket2::{Domain, Protocol, Socket, Type};
use std::net::{IpAddr, SocketAddr};
use std::time::{Duration, Instant};
use tokio::time::timeout;
/// Prober that sends a UDP datagram to a high port and times the ICMP
/// Time Exceeded / Port Unreachable reply it provokes.
#[derive(Debug, Clone)]
pub struct UdpProber {
    /// Per-probe timeout in milliseconds.
    timeout_ms: u64,
    /// First UDP destination port; each probe adds `seq % 30000` to it.
    base_port: u16,
}
impl UdpProber {
pub fn new(timeout_ms: u64) -> Self {
Self {
timeout_ms,
base_port: 33434, }
}
}
/// Decides whether a raw ICMPv4 datagram is the reply to one of our UDP probes.
///
/// `buf` is a packet as read from a raw ICMPv4 socket: outer IPv4 header, ICMP
/// header, then the quoted IPv4 header and first 8 bytes of the datagram that
/// triggered the error.
///
/// Returns:
/// * `Some(true)` for ICMP Time Exceeded, TTL exceeded in transit (type 11, code 0);
/// * `Some(false)` for ICMP Destination Unreachable, port unreachable (type 3, code 3);
/// * `None` for any other ICMP message, for truncated or malformed packets, or when
///   the quoted datagram is not UDP sent to `expected_dest_ip:expected_dest_port`.
fn parse_icmp_reply(buf: &[u8], expected_dest_ip: &str, expected_dest_port: u16) -> Option<bool> {
    use std::net::Ipv4Addr;

    const ICMP_HEADER_LEN: usize = 8;
    const UDP_HEADER_LEN: usize = 8;
    const MIN_IPV4_HEADER_LEN: usize = 20;
    const IPPROTO_UDP: u8 = 17;

    // Length in bytes of the IPv4 header that starts `packet`. A wrong version
    // nibble or an IHL below the 5-word minimum is rejected; otherwise the fields
    // "after" the header would be read from inside the header itself.
    let ipv4_header_len = |packet: &[u8]| -> Option<usize> {
        let first = *packet.first()?;
        let len = usize::from(first & 0x0F) * 4;
        (first >> 4 == 4 && len >= MIN_IPV4_HEADER_LEN).then_some(len)
    };

    let icmp = buf.get(ipv4_header_len(buf)?..)?;
    if icmp.len() < ICMP_HEADER_LEN {
        return None;
    }
    let is_time_exceeded = match (icmp[0], icmp[1]) {
        (11, 0) => true,
        (3, 3) => false,
        _ => return None,
    };

    // The ICMP payload quotes the offending datagram: its IPv4 header followed by
    // (at least) the 8-byte UDP header.
    let orig_ip = &icmp[ICMP_HEADER_LEN..];
    let orig_ip_len = ipv4_header_len(orig_ip)?;
    let udp = orig_ip.get(orig_ip_len..orig_ip_len + UDP_HEADER_LEN)?;
    // Ignore errors quoting non-UDP traffic (e.g. TCP) to the same host.
    if orig_ip[9] != IPPROTO_UDP {
        return None;
    }

    let orig_dest_ip = Ipv4Addr::new(orig_ip[16], orig_ip[17], orig_ip[18], orig_ip[19]);
    let orig_dest_port = u16::from_be_bytes([udp[2], udp[3]]);
    let expected_ip: Ipv4Addr = expected_dest_ip.parse().ok()?;

    (orig_dest_ip == expected_ip && orig_dest_port == expected_dest_port)
        .then_some(is_time_exceeded)
}
/// Blocking core of `UdpProber::probe`. It sends one UDP datagram to
/// `addr:dest_port` and waits for the matching ICMP reply on a raw ICMPv4 socket.
///
/// The whole call, including socket setup, is bounded by `timeout_ms`. The read
/// timeout is recomputed from the remaining budget before every `recv_from`, so a
/// stream of unrelated ICMP traffic cannot keep the thread blocked past the deadline.
///
/// Returns the round-trip time in microseconds (measured from just before the send)
/// and the source address of the ICMP reply, if it can be expressed as a socket
/// address.
///
/// # Errors
/// * `NetPulseError::InsufficientPrivileges` when opening the raw socket is denied.
/// * `NetPulseError::Timeout` when no matching reply arrives within `timeout_ms`.
/// * `NetPulseError::SocketError` for any other socket failure, including a TTL that
///   cannot be applied.
fn udp_probe_blocking(
    addr: IpAddr,
    dest_port: u16,
    ttl: Option<u32>,
    timeout_ms: u64,
    target: &str,
) -> Result<(u64, Option<String>), NetPulseError> {
    use std::io::ErrorKind;

    // Measured from entry so this thread finishes at roughly the same moment the
    // caller's tokio timeout fires.
    let budget = Duration::from_millis(timeout_ms);
    let started = Instant::now();

    // The raw ICMP socket is opened before the datagram is sent.
    let icmp_sock =
        Socket::new(Domain::IPV4, Type::RAW, Some(Protocol::ICMPV4)).map_err(|e| {
            // PermissionDenied covers both EPERM and EACCES.
            if e.kind() == ErrorKind::PermissionDenied {
                NetPulseError::InsufficientPrivileges
            } else {
                NetPulseError::SocketError(e)
            }
        })?;

    let udp_sock = Socket::new(Domain::IPV4, Type::DGRAM, Some(Protocol::UDP))
        .map_err(NetPulseError::SocketError)?;
    if let Some(t) = ttl {
        // If the TTL silently failed to apply, the datagram would leave with the
        // default TTL and an intermediate hop would look like the destination.
        udp_sock.set_ttl(t).map_err(NetPulseError::SocketError)?;
    }

    let dest = SocketAddr::new(addr, dest_port);
    let sent_at = Instant::now();
    udp_sock
        .send_to(b"netpulse", &dest.into())
        .map_err(NetPulseError::SocketError)?;

    let mut recv_buf = [0u8; 1024];
    loop {
        let remaining = budget.saturating_sub(started.elapsed());
        if remaining.is_zero() {
            return Err(NetPulseError::Timeout { timeout_ms });
        }
        // socket2 encodes a zero-length timeout as "no timeout" (block forever), so
        // never let the per-call timeout round down to zero microseconds.
        icmp_sock
            .set_read_timeout(Some(remaining.max(Duration::from_micros(1))))
            .map_err(NetPulseError::SocketError)?;

        // SAFETY: `[u8; N]` and `[MaybeUninit<u8>; N]` have identical layout, and
        // `recv_from` only writes initialised bytes into the buffer, so viewing the
        // already-initialised array as `MaybeUninit` cannot de-initialise it.
        let received = icmp_sock.recv_from(unsafe {
            &mut *(&mut recv_buf as *mut [u8; 1024] as *mut [std::mem::MaybeUninit<u8>; 1024])
        });
        let (n, from_addr) = match received {
            Ok(ok) => ok,
            // A signal interrupted the wait; retry with the remaining budget.
            Err(e) if e.kind() == ErrorKind::Interrupted => continue,
            Err(e) if matches!(e.kind(), ErrorKind::WouldBlock | ErrorKind::TimedOut) => {
                return Err(NetPulseError::Timeout { timeout_ms });
            }
            Err(e) => return Err(NetPulseError::SocketError(e)),
        };

        // Other ICMP traffic also reaches this raw socket; skip anything that is
        // not a reply to this exact probe.
        if parse_icmp_reply(&recv_buf[..n], target, dest_port).is_some() {
            let rtt_us = u64::try_from(sent_at.elapsed().as_micros()).unwrap_or(u64::MAX);
            let responder_ip = from_addr.as_socket().map(|s| s.ip().to_string());
            return Ok((rtt_us, responder_ip));
        }
    }
}

#[async_trait]
impl Prober for UdpProber {
    fn name(&self) -> &'static str {
        "udp"
    }

    /// Sends one UDP probe to `target` and reports the ICMP reply it provokes.
    ///
    /// `seq` selects the destination port (`base_port + seq % 30000`). The reply is
    /// matched against this port so it can be attributed to this probe. When `ttl`
    /// is given, it is applied to the outgoing datagram. Any matching ICMP Time
    /// Exceeded or Port Unreachable reply counts as a success, with the replying
    /// host as the responder.
    ///
    /// # Errors
    /// * `NetPulseError::InvalidTarget` if `target` is not an IPv4 address literal.
    /// * `NetPulseError::InsufficientPrivileges` if a raw ICMP socket cannot be opened.
    ///
    /// Every other failure (timeout, socket error, failed blocking task) is reported
    /// as `ProbeResult::loss` rather than as an error.
    async fn probe(
        &self,
        target: &str,
        seq: u64,
        ttl: Option<u32>,
    ) -> Result<ProbeResult, NetPulseError> {
        let addr: IpAddr = target.parse().map_err(|_| NetPulseError::InvalidTarget {
            target: target.to_string(),
            reason: "must be an IP address for UDP probing".to_string(),
        })?;
        let IpAddr::V4(_) = addr else {
            return Err(NetPulseError::InvalidTarget {
                target: target.to_string(),
                reason: "only IPv4 is currently supported for UDP probing".to_string(),
            });
        };

        // 30000 < u16::MAX, so the cast is lossless.
        let dest_port = self.base_port.wrapping_add((seq % 30000) as u16);
        let timeout_ms = self.timeout_ms;
        let target_owned = target.to_string();

        // The blocking helper enforces its own deadline; this outer timeout is a
        // backstop so the caller never waits longer than `timeout_ms`.
        let outcome = timeout(
            Duration::from_millis(timeout_ms),
            tokio::task::spawn_blocking(move || {
                udp_probe_blocking(addr, dest_port, ttl, timeout_ms, &target_owned)
            }),
        )
        .await;

        match outcome {
            Ok(Ok(Ok((rtt_us, responder_ip)))) => {
                Ok(ProbeResult::success(target, seq, rtt_us, responder_ip))
            }
            Ok(Ok(Err(NetPulseError::InsufficientPrivileges))) => {
                Err(NetPulseError::InsufficientPrivileges)
            }
            // Outer timeout, a failed blocking task, a probe timeout and any other
            // socket error are all reported as a lost probe.
            _ => Ok(ProbeResult::loss(target, seq)),
        }
    }
}