use crate::probe::{ProbeInfo, ProbeResponse};
use crate::socket::traits::{ProbeMode, ProbeSocket};
use crate::traceroute::TracerouteError;
use crate::TimingConfig;
use std::ffi::c_void;
use std::future::Future;
use std::mem;
use std::net::{IpAddr, Ipv4Addr};
use std::pin::Pin;
use std::ptr;
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
use tokio::sync::oneshot;
use windows_sys::Win32::Foundation::{
CloseHandle, GetLastError, ERROR_IO_PENDING, HANDLE, WAIT_OBJECT_0,
};
use windows_sys::Win32::NetworkManagement::IpHelper::{
IcmpCloseHandle, IcmpCreateFile, IcmpSendEcho2, ICMP_ECHO_REPLY, IP_OPTION_INFORMATION,
IP_SUCCESS,
};
/// `IP_STATUS` code for a request that timed out (IP_STATUS_BASE 11000 + 10); treated as a probe timeout.
const IP_REQ_TIMED_OUT: u32 = 11_010;
/// `IP_STATUS` code for a general failure (IP_STATUS_BASE 11000 + 50); also treated as a probe timeout.
const IP_GENERAL_FAILURE: u32 = 11_050;
use windows_sys::Win32::System::Threading::{CreateEventW, WaitForSingleObject};
/// Traceroute probe socket backed by the Windows ICMP helper API
/// (`IcmpCreateFile` / `IcmpSendEcho2`).
pub struct WindowsAsyncIcmpSocket {
    /// Handle returned by `IcmpCreateFile`; closed in `Drop` only when no probe is pending.
    icmp_handle: HANDLE,
    /// Set to `true` once any probe gets an `IP_SUCCESS` echo reply.
    destination_reached: Arc<Mutex<bool>>,
    /// Number of in-flight probes; shared with the blocking tasks that wait on
    /// each probe's completion event.
    pending_count: Arc<Mutex<usize>>,
    /// Timing settings; `socket_read_timeout` bounds how long each probe is awaited.
    timing_config: TimingConfig,
}
impl WindowsAsyncIcmpSocket {
pub fn new_with_config(timing_config: TimingConfig) -> Result<Self, TracerouteError> {
let icmp_handle = unsafe { IcmpCreateFile() };
if icmp_handle.is_null() {
return Err(TracerouteError::SocketError(
"Failed to create ICMP handle".to_string(),
));
}
Ok(Self {
icmp_handle,
destination_reached: Arc::new(Mutex::new(false)),
pending_count: Arc::new(Mutex::new(0)),
timing_config,
})
}
fn process_response(
&self,
buffer: &[u8],
sequence: u16,
ttl: u8,
sent_at: Instant,
) -> Result<ProbeResponse, TracerouteError> {
if buffer.len() < mem::size_of::<ICMP_ECHO_REPLY>() {
return Err(TracerouteError::SocketError(
"Response buffer too small".to_string(),
));
}
let reply = unsafe { &*(buffer.as_ptr() as *const ICMP_ECHO_REPLY) };
let elapsed = sent_at.elapsed();
if reply.Status == IP_SUCCESS {
if buffer.len() >= mem::size_of::<ICMP_ECHO_REPLY>() + 4 {
let data_offset = mem::size_of::<ICMP_ECHO_REPLY>();
let data = &buffer[data_offset..];
if data.len() >= 4 {
let identifier = u16::from_be_bytes([data[0], data[1]]);
let recv_sequence = u16::from_be_bytes([data[2], data[3]]);
let expected_identifier = std::process::id() as u16;
if identifier != expected_identifier || recv_sequence != sequence {
return Err(TracerouteError::SocketError(format!(
"Response mismatch: expected id={}/seq={}, got id={}/seq={}",
expected_identifier, sequence, identifier, recv_sequence
)));
}
}
}
}
match reply.Status {
IP_REQ_TIMED_OUT | IP_GENERAL_FAILURE => {
return Ok(ProbeResponse {
from_addr: IpAddr::V4(Ipv4Addr::UNSPECIFIED), sequence,
ttl,
rtt: elapsed,
received_at: Instant::now(),
is_destination: false,
is_timeout: true,
});
}
_ => {}
}
let from_addr = IpAddr::V4(Ipv4Addr::new(
reply.Address as u8,
(reply.Address >> 8) as u8,
(reply.Address >> 16) as u8,
(reply.Address >> 24) as u8,
));
let is_destination = reply.Status == IP_SUCCESS;
if is_destination {
*self
.destination_reached
.lock()
.expect("Failed to acquire destination_reached lock") = true;
}
let rtt = if reply.RoundTripTime > 0 {
Duration::from_millis(reply.RoundTripTime as u64)
} else {
elapsed
};
Ok(ProbeResponse {
from_addr,
sequence,
ttl,
rtt,
received_at: Instant::now(),
is_destination,
is_timeout: false,
})
}
}
impl ProbeSocket for WindowsAsyncIcmpSocket {
fn mode(&self) -> ProbeMode {
ProbeMode::WindowsIcmp
}
fn send_probe_and_recv(
&self,
dest: IpAddr,
probe: ProbeInfo,
) -> Pin<Box<dyn Future<Output = Result<ProbeResponse, TracerouteError>> + Send + '_>> {
Box::pin(async move {
let dest_addr = match dest {
IpAddr::V4(addr) => addr,
_ => return Err(TracerouteError::Ipv6NotSupported),
};
{
let mut count = self
.pending_count
.lock()
.expect("Failed to acquire pending_count lock");
*count += 1;
}
let event = unsafe { CreateEventW(ptr::null(), 1, 0, ptr::null()) };
if event.is_null() {
let mut count = self
.pending_count
.lock()
.expect("Failed to acquire pending_count lock");
*count -= 1;
return Err(TracerouteError::SocketError(
"Failed to create event".to_string(),
));
}
let identifier = std::process::id() as u16;
let mut send_data = Vec::with_capacity(32);
send_data.extend_from_slice(&identifier.to_be_bytes());
send_data.extend_from_slice(&probe.sequence.to_be_bytes());
send_data.extend_from_slice(b"ftr-windows-padding");
send_data.resize(32, 0);
let reply_size = mem::size_of::<ICMP_ECHO_REPLY>() + send_data.len() + 8;
let reply_buffer = Box::pin(vec![0u8; reply_size]);
let reply_ptr = reply_buffer.as_ptr() as *mut c_void;
let sent_at = Instant::now();
let send_result = {
let mut options = IP_OPTION_INFORMATION {
Ttl: probe.ttl,
Tos: 0,
Flags: 0,
OptionsSize: 0,
OptionsData: ptr::null_mut(),
};
let result = unsafe {
IcmpSendEcho2(
self.icmp_handle,
event,
None, ptr::null(), u32::from_ne_bytes(dest_addr.octets()),
send_data.as_ptr() as *const c_void,
send_data.len() as u16,
&mut options as *mut IP_OPTION_INFORMATION,
reply_ptr,
reply_size as u32,
{
let user_timeout_ms =
self.timing_config.socket_read_timeout.as_millis() as u32;
let windows_timeout = user_timeout_ms
+ crate::config::timing::WINDOWS_ICMP_TIMEOUT_BUFFER_MS;
windows_timeout
.max(crate::config::timing::WINDOWS_ICMP_MIN_TOTAL_TIMEOUT_MS)
},
)
};
if result == 0 {
let error = unsafe { GetLastError() };
if error != ERROR_IO_PENDING {
Err(error)
} else {
Ok(())
}
} else {
Ok(())
}
};
if let Err(error) = send_result {
unsafe { CloseHandle(event) };
let mut count = self
.pending_count
.lock()
.expect("Failed to acquire pending_count lock");
*count -= 1;
return Err(TracerouteError::SocketError(format!(
"IcmpSendEcho2 failed: {}",
error
)));
}
let (tx, rx) = oneshot::channel();
let event_handle = event as usize; let pending_count = Arc::clone(&self.pending_count);
let wait_handle = tokio::task::spawn_blocking(move || {
let event = event_handle as HANDLE; let result = unsafe {
WaitForSingleObject(event, 0xFFFFFFFF) };
unsafe { CloseHandle(event) };
let mut count = pending_count
.lock()
.expect("Failed to acquire pending_count lock");
*count = count.saturating_sub(1);
if result == WAIT_OBJECT_0 {
tx.send(Ok(reply_buffer)).ok();
} else {
tx.send(Err(TracerouteError::SocketError(
"Event wait failed or timed out".to_string(),
)))
.ok();
}
});
let timeout_duration = self.timing_config.socket_read_timeout;
let verbose = std::env::var("FTR_VERBOSE")
.ok()
.and_then(|v| v.parse::<u8>().ok())
.unwrap_or(0);
if verbose >= 3 {
let windows_timeout_ms = {
let user_timeout_ms = self.timing_config.socket_read_timeout.as_millis() as u32;
let windows_timeout =
user_timeout_ms + crate::config::timing::WINDOWS_ICMP_TIMEOUT_BUFFER_MS;
windows_timeout.max(crate::config::timing::WINDOWS_ICMP_MIN_TOTAL_TIMEOUT_MS)
};
eprintln!(
"[TIMEOUT] Probe seq={} ttl={}: User timeout={}ms, Windows timeout={}ms",
probe.sequence,
probe.ttl,
timeout_duration.as_millis(),
windows_timeout_ms
);
}
match tokio::time::timeout(timeout_duration, rx).await {
Ok(Ok(Ok(reply_buffer))) => {
self.process_response(&reply_buffer, probe.sequence, probe.ttl, sent_at)
}
Ok(Ok(Err(e))) => {
Err(TracerouteError::SocketError(format!(
"Event wait error: {}",
e
)))
}
Ok(Err(_)) => {
Err(TracerouteError::SocketError(
"Event wait cancelled".to_string(),
))
}
Err(_) => {
drop(wait_handle);
Ok(ProbeResponse {
from_addr: IpAddr::V4(Ipv4Addr::UNSPECIFIED),
sequence: probe.sequence,
ttl: probe.ttl,
rtt: timeout_duration,
received_at: Instant::now(),
is_destination: false,
is_timeout: true,
})
}
}
})
}
fn destination_reached(&self) -> bool {
*self
.destination_reached
.lock()
.expect("Failed to acquire destination_reached lock")
}
fn pending_count(&self) -> usize {
*self
.pending_count
.lock()
.expect("Failed to acquire pending_count lock")
}
}
impl Drop for WindowsAsyncIcmpSocket {
    /// Closes the ICMP handle when no probe is still counted as pending.
    ///
    /// While any probe is pending, the handle is left open (leaked) rather than
    /// closed underneath an outstanding request.
    fn drop(&mut self) {
        // NULL and INVALID_HANDLE_VALUE ((HANDLE)-1) are never valid ICMP handles.
        if self.icmp_handle.is_null() || self.icmp_handle as isize == -1 {
            return;
        }
        // Never panic in Drop: a panic here while already unwinding aborts the
        // process. A poisoned mutex still holds a usable count, so recover it.
        let pending = match self.pending_count.lock() {
            Ok(count) => *count,
            Err(poisoned) => *poisoned.into_inner(),
        };
        if pending == 0 {
            // SAFETY: the handle came from IcmpCreateFile, is valid, and no
            // counted probe is using it.
            unsafe { IcmpCloseHandle(self.icmp_handle) };
        }
    }
}
// SAFETY: `HANDLE` is a raw pointer, so Send/Sync are not derived automatically.
// Within this file the ICMP handle is only compared against NULL and passed by
// value to Win32 calls (`IcmpSendEcho2`, `IcmpCloseHandle`); it is never
// dereferenced in Rust. The shared counters are `Arc<Mutex<_>>`.
// NOTE(review): `Sync` assumes concurrent `IcmpSendEcho2` calls on one ICMP
// handle from multiple threads are safe, and that `TimingConfig` is itself
// Send + Sync — confirm both.
unsafe impl Send for WindowsAsyncIcmpSocket {}
unsafe impl Sync for WindowsAsyncIcmpSocket {}