use std::ffi::c_void;
use std::io;
use std::mem::size_of;
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddrV6};
use std::sync::Arc;
use std::time::Duration;
use futures::channel::oneshot;
use static_assertions::const_assert;
use windows::Win32::Foundation::{
CloseHandle, GetLastError, ERROR_HOST_UNREACHABLE, ERROR_IO_PENDING, ERROR_NETWORK_UNREACHABLE,
ERROR_PORT_UNREACHABLE, ERROR_PROTOCOL_UNREACHABLE, HANDLE,
};
use windows::Win32::NetworkManagement::IpHelper::{
Icmp6CreateFile, Icmp6ParseReplies, Icmp6SendEcho2, IcmpCloseHandle, IcmpCreateFile,
IcmpParseReplies, IcmpSendEcho2Ex, ICMPV6_ECHO_REPLY_LH as ICMPV6_ECHO_REPLY,
IP_DEST_HOST_UNREACHABLE, IP_DEST_NET_UNREACHABLE, IP_DEST_PORT_UNREACHABLE,
IP_DEST_PROT_UNREACHABLE, IP_DEST_UNREACHABLE, IP_REQ_TIMED_OUT, IP_SUCCESS, IP_TIME_EXCEEDED,
IP_TTL_EXPIRED_REASSEM, IP_TTL_EXPIRED_TRANSIT,
};
use windows::Win32::Networking::WinSock::{IN6_ADDR, SOCKADDR_IN6};
use windows::Win32::System::Threading::{
CreateEventW, RegisterWaitForSingleObject, UnregisterWaitEx, INFINITE, WT_EXECUTEINWAITTHREAD,
WT_EXECUTEONLYONCE,
};
use windows::Win32::System::IO::IO_STATUS_BLOCK;
#[cfg(target_pointer_width = "32")]
use windows::Win32::NetworkManagement::IpHelper::ICMP_ECHO_REPLY;
#[cfg(target_pointer_width = "64")]
use windows::Win32::NetworkManagement::IpHelper::ICMP_ECHO_REPLY32 as ICMP_ECHO_REPLY;
#[cfg(target_pointer_width = "32")]
use windows::Win32::NetworkManagement::IpHelper::IP_OPTION_INFORMATION;
#[cfg(target_pointer_width = "64")]
use windows::Win32::NetworkManagement::IpHelper::IP_OPTION_INFORMATION32 as IP_OPTION_INFORMATION;
use crate::{
IcmpEchoReply, IcmpEchoStatus, PING_DEFAULT_REQUEST_DATA_LENGTH, PING_DEFAULT_TIMEOUT,
PING_DEFAULT_TTL,
};
/// Size in bytes of the reply buffer allocated for every echo request.
const REPLY_BUFFER_SIZE: usize = 100;
// Compile-time check that the buffer can hold one IPv4 echo reply structure, the echoed request
// payload, 8 extra bytes and an `IO_STATUS_BLOCK`.
// NOTE(review): the `+ 8` and `IO_STATUS_BLOCK` terms appear to follow the reply-buffer sizing
// guidance in the IcmpSendEcho2 documentation — confirm against the current docs.
const_assert!(
size_of::<ICMP_ECHO_REPLY>()
+ PING_DEFAULT_REQUEST_DATA_LENGTH
+ 8
+ size_of::<IO_STATUS_BLOCK>()
<= REPLY_BUFFER_SIZE
);
// Same compile-time check for the IPv6 echo reply structure.
const_assert!(
size_of::<ICMPV6_ECHO_REPLY>()
+ PING_DEFAULT_REQUEST_DATA_LENGTH
+ 8
+ size_of::<IO_STATUS_BLOCK>()
<= REPLY_BUFFER_SIZE
);
/// Per-request state shared between the sending code and `wait_callback`.
///
/// It is heap-allocated and handed around as a raw pointer; `wait_callback` reclaims ownership
/// with `Box::from_raw` once the request's event is signaled.
struct RequestContext {
// Wait handle filled in by `RegisterWaitForSingleObject`; released via `UnregisterWaitEx`.
wait_object: HANDLE,
// Event passed to the ICMP send call and waited on for completion; closed when done.
event: HANDLE,
// Reply buffer handed to the ICMP send call and later to the reply-parsing call.
buffer: Box<[u8]>,
// Address being pinged; selects IPv4 vs. IPv6 reply parsing and is reported when parsing fails.
target_addr: IpAddr,
// Request timeout; reported as the elapsed time of a timed-out reply.
timeout: Duration,
// One-shot channel through which the final reply is delivered to `send`.
sender: oneshot::Sender<IcmpEchoReply>,
}
impl RequestContext {
    /// Creates the state for one in-flight request.
    ///
    /// The reply buffer is zero-filled and `REPLY_BUFFER_SIZE` bytes long; the wait handle starts
    /// out as the default `HANDLE` and is filled in when the wait is registered.
    fn new(
        event: HANDLE,
        target_addr: IpAddr,
        timeout: Duration,
        sender: oneshot::Sender<IcmpEchoReply>,
    ) -> Self {
        let wait_object = HANDLE::default();
        let buffer = vec![0u8; REPLY_BUFFER_SIZE].into_boxed_slice();
        Self {
            wait_object,
            event,
            buffer,
            target_addr,
            timeout,
            sender,
        }
    }

    /// Returns a mutable raw pointer to the start of the reply buffer.
    fn buffer_ptr(&mut self) -> *mut u8 {
        let storage: &mut [u8] = &mut self.buffer;
        storage.as_mut_ptr()
    }

    /// Returns the reply buffer length in bytes.
    fn buffer_size(&self) -> usize {
        let storage: &[u8] = &self.buffer;
        storage.len()
    }
}
/// Sends ICMP echo requests to a single target address.
///
/// Cloning is cheap: every clone shares the same `RequestorInner` (and therefore the same ICMP
/// handle) through an `Arc`.
#[derive(Clone)]
pub struct IcmpEchoRequestor {
// Shared configuration and ICMP handle.
inner: Arc<RequestorInner>,
}
/// Shared state behind an `IcmpEchoRequestor`: the ICMP handle plus the request parameters.
struct RequestorInner {
// Handle from `IcmpCreateFile` / `Icmp6CreateFile`; closed in `Drop`.
icmp_handle: HANDLE,
// Address the echo requests are sent to.
target_addr: IpAddr,
// Source address for requests; the unspecified address when the caller gave none.
source_addr: IpAddr,
// TTL placed in the IP options of each request.
ttl: u8,
// Per-request timeout passed to the ICMP send call.
timeout: Duration,
}
// NOTE(review): these manual impls assert that sharing the raw ICMP handle across threads is
// sound. SOURCE does not show why that holds — presumably the handle is only passed to
// IP Helper calls that may be issued from any thread; confirm.
unsafe impl Send for RequestorInner {}
unsafe impl Sync for RequestorInner {}
impl IcmpEchoRequestor {
    /// Creates a requestor that pings `target_addr`.
    ///
    /// * `source_addr` – optional source address; defaults to the unspecified address of the
    ///   target's family.
    /// * `ttl` – optional TTL; defaults to `PING_DEFAULT_TTL`.
    /// * `timeout` – optional per-request timeout; defaults to `PING_DEFAULT_TIMEOUT`.
    ///
    /// # Errors
    ///
    /// Returns `InvalidInput` when `source_addr` and `target_addr` belong to different address
    /// families, or the OS error when the ICMP handle cannot be created.
    pub fn new(
        target_addr: IpAddr,
        source_addr: Option<IpAddr>,
        ttl: Option<u8>,
        timeout: Option<Duration>,
    ) -> io::Result<Self> {
        // Validate before opening a handle so that nothing has to be cleaned up on rejection.
        if let Some(source) = source_addr {
            if source.is_ipv4() != target_addr.is_ipv4() {
                return Err(io::Error::new(
                    io::ErrorKind::InvalidInput,
                    "Source address type does not match target address type",
                ));
            }
        }

        // SAFETY: both functions take no arguments; failure is reported through the `Result`.
        let icmp_handle = match target_addr {
            IpAddr::V4(_) => unsafe { IcmpCreateFile()? },
            IpAddr::V6(_) => unsafe { Icmp6CreateFile()? },
        };
        debug_assert!(!icmp_handle.is_invalid());

        let source_addr = source_addr.unwrap_or(match target_addr {
            IpAddr::V4(_) => IpAddr::V4(Ipv4Addr::UNSPECIFIED),
            IpAddr::V6(_) => IpAddr::V6(Ipv6Addr::UNSPECIFIED),
        });

        Ok(IcmpEchoRequestor {
            inner: Arc::new(RequestorInner {
                icmp_handle,
                target_addr,
                source_addr,
                ttl: ttl.unwrap_or(PING_DEFAULT_TTL),
                timeout: timeout.unwrap_or(PING_DEFAULT_TIMEOUT),
            }),
        })
    }

    /// Sends one echo request and resolves to its reply.
    ///
    /// A request that the OS rejects immediately still resolves to `Ok` with a reply whose status
    /// is derived from the error code.
    ///
    /// # Errors
    ///
    /// Returns an error when the completion event or wait registration cannot be set up, or when
    /// the reply channel is dropped before a reply is delivered.
    pub async fn send(&self) -> io::Result<IcmpEchoReply> {
        let (reply_tx, reply_rx) = oneshot::channel();
        self.handle_send(reply_tx)?;
        reply_rx
            .await
            .map_err(|_| io::Error::other("reply channel closed unexpectedly"))
    }

    /// Starts an asynchronous echo request whose reply will be delivered through `reply_tx`.
    ///
    /// On success, ownership of the heap-allocated `RequestContext` passes to `wait_callback`,
    /// which runs once the completion event is signaled.
    fn handle_send(&self, reply_tx: oneshot::Sender<IcmpEchoReply>) -> io::Result<()> {
        // SAFETY: all arguments are plain values; failure is reported through the `Result`.
        let event = unsafe { CreateEventW(None, false, false, None)? };
        let context_raw = Box::into_raw(Box::new(RequestContext::new(
            event,
            self.inner.target_addr,
            self.inner.timeout,
            reply_tx,
        )));

        if let Err(code) = self.do_send(context_raw, event) {
            // The request was never queued, so report the failure as a reply right away.
            let status = ip_error_to_icmp_status(code);
            let reply = IcmpEchoReply::new(self.inner.target_addr, status, Duration::ZERO);
            // SAFETY: `context_raw` came from `Box::into_raw` above and, because the send
            // failed, no pending request or wait refers to it.
            let ctx = unsafe { Box::from_raw(context_raw) };
            let _ = ctx.sender.send(reply);
            if !ctx.event.is_invalid() {
                // SAFETY: `ctx.event` is the event created above and is closed exactly once.
                let _ = unsafe { CloseHandle(ctx.event) };
            }
            return Ok(());
        }

        // SAFETY: `context_raw` points to a live `RequestContext`; the callback takes ownership
        // of it when the event is signaled.
        let registered = unsafe {
            RegisterWaitForSingleObject(
                &mut (*context_raw).wait_object,
                event,
                Some(wait_callback),
                Some(context_raw as *const _),
                INFINITE,
                WT_EXECUTEINWAITTHREAD | WT_EXECUTEONLYONCE,
            )
        };

        // If registration fails, the echo request is still pending and will write its reply into
        // the context's buffer and may signal the event. Freeing either here would be a
        // use-after-free, so both are intentionally leaked on this (rare) path. Dropping
        // `reply_rx` in the caller is enough to abandon the request.
        registered.map_err(io::Error::from)
    }

    /// Returns the configured timeout in milliseconds, saturating at `u32::MAX`.
    ///
    /// A plain `as u32` cast would wrap, turning a very long timeout into a very short one.
    fn timeout_millis(&self) -> u32 {
        u32::try_from(self.inner.timeout.as_millis()).unwrap_or(u32::MAX)
    }

    /// Issues the asynchronous ICMP send call for this requestor's address family.
    ///
    /// `context` must point to a live `RequestContext`; its buffer receives the reply. `event` is
    /// signaled by the OS when the request completes.
    ///
    /// Returns `Ok(())` when the request is pending and `Err(code)` with the `GetLastError` value
    /// otherwise.
    fn do_send(&self, context: *mut RequestContext, event: HANDLE) -> Result<(), u32> {
        let ip_option = IP_OPTION_INFORMATION {
            Ttl: self.inner.ttl,
            ..Default::default()
        };
        let req_data = [0u8; PING_DEFAULT_REQUEST_DATA_LENGTH];
        let timeout_ms = self.timeout_millis();

        let ret = match self.inner.target_addr {
            IpAddr::V4(taddr) => {
                // `new` guarantees that source and target share an address family.
                let IpAddr::V4(saddr) = self.inner.source_addr else {
                    unreachable!("source address must be IPv4 for IPv4 target");
                };
                // SAFETY: the caller passes a pointer to a live `RequestContext`; every other
                // pointer refers to a local that outlives the call.
                unsafe {
                    let ctx = context.as_mut().unwrap();
                    IcmpSendEcho2Ex(
                        self.inner.icmp_handle,
                        Some(event),
                        None,
                        None,
                        // `to_be` makes the in-memory bytes network-ordered.
                        u32::from(saddr).to_be(),
                        u32::from(taddr).to_be(),
                        req_data.as_ptr() as *const _,
                        req_data.len() as u16,
                        Some(&ip_option as *const _ as *const _),
                        ctx.buffer_ptr() as *mut _,
                        ctx.buffer_size() as u32,
                        timeout_ms,
                    )
                }
            }
            IpAddr::V6(taddr) => {
                let IpAddr::V6(saddr) = self.inner.source_addr else {
                    unreachable!("source address must be IPv6 for IPv6 target");
                };
                // SAFETY: same argument as for the IPv4 branch.
                unsafe {
                    let ctx = context.as_mut().unwrap();
                    let src_saddr: SOCKADDR_IN6 = SocketAddrV6::new(saddr, 0, 0, 0).into();
                    let dst_saddr: SOCKADDR_IN6 = SocketAddrV6::new(taddr, 0, 0, 0).into();
                    Icmp6SendEcho2(
                        self.inner.icmp_handle,
                        Some(event),
                        None,
                        None,
                        &src_saddr,
                        &dst_saddr,
                        req_data.as_ptr() as *const _,
                        req_data.len() as u16,
                        Some(&ip_option as *const _ as *const _),
                        ctx.buffer_ptr() as *mut _,
                        ctx.buffer_size() as u32,
                        timeout_ms,
                    )
                }
            }
        };

        // "Pending" is accepted either as the return value or as the thread's last error.
        if ret == ERROR_IO_PENDING.0 {
            return Ok(());
        }
        // SAFETY: `GetLastError` only reads the calling thread's last-error value.
        let code = unsafe { GetLastError() };
        if code == ERROR_IO_PENDING {
            Ok(())
        } else {
            Err(code.0)
        }
    }
}
impl Drop for RequestorInner {
    /// Closes the ICMP handle, ignoring any error, unless the handle is invalid.
    fn drop(&mut self) {
        if self.icmp_handle.is_invalid() {
            return;
        }
        // SAFETY: the handle was checked to be valid and is closed only here, on drop.
        unsafe {
            let _ = IcmpCloseHandle(self.icmp_handle);
        }
    }
}
/// Maps an IP Helper status code or a Win32 error code to an `IcmpEchoStatus`.
///
/// IP status codes are checked first; Win32 "unreachable" errors are checked afterwards, and
/// anything unrecognized becomes `IcmpEchoStatus::Unknown`.
fn ip_error_to_icmp_status(code: u32) -> IcmpEchoStatus {
    /// Win32 error codes that are reported as "unreachable".
    const WIN32_UNREACHABLE: [u32; 4] = [
        ERROR_NETWORK_UNREACHABLE.0,
        ERROR_HOST_UNREACHABLE.0,
        ERROR_PROTOCOL_UNREACHABLE.0,
        ERROR_PORT_UNREACHABLE.0,
    ];

    if code == IP_SUCCESS {
        return IcmpEchoStatus::Success;
    }

    // Request timeouts and the TTL/time-exceeded statuses are all reported as timeouts.
    let timed_out = matches!(
        code,
        IP_REQ_TIMED_OUT | IP_TIME_EXCEEDED | IP_TTL_EXPIRED_REASSEM | IP_TTL_EXPIRED_TRANSIT
    );
    if timed_out {
        return IcmpEchoStatus::TimedOut;
    }

    let ip_unreachable = matches!(
        code,
        IP_DEST_HOST_UNREACHABLE
            | IP_DEST_NET_UNREACHABLE
            | IP_DEST_PORT_UNREACHABLE
            | IP_DEST_PROT_UNREACHABLE
            | IP_DEST_UNREACHABLE
    );
    if ip_unreachable || WIN32_UNREACHABLE.contains(&code) {
        IcmpEchoStatus::Unreachable
    } else {
        IcmpEchoStatus::Unknown
    }
}
unsafe extern "system" fn wait_callback(ptr: *mut c_void, timer_fired: bool) {
debug_assert!(!timer_fired, "Timer should not be fired here");
let context = Box::from_raw(ptr as *mut RequestContext);
let reply = match context.target_addr {
IpAddr::V4(_) => {
let ret = unsafe {
IcmpParseReplies(
context.buffer.as_ptr() as *mut _,
context.buffer.len() as u32,
)
};
if ret == 0 {
let error = unsafe { GetLastError() };
if error.0 == IP_REQ_TIMED_OUT {
IcmpEchoReply::new(
context.target_addr,
IcmpEchoStatus::TimedOut,
context.timeout,
)
} else {
IcmpEchoReply::new(context.target_addr, IcmpEchoStatus::Unknown, Duration::ZERO)
}
} else {
debug_assert_eq!(ret, 1);
let resp = (context.buffer.as_ptr() as *const ICMP_ECHO_REPLY)
.as_ref()
.unwrap();
let addr = IpAddr::V4(u32::from_be(resp.Address).into());
IcmpEchoReply::new(
addr,
ip_error_to_icmp_status(resp.Status),
Duration::from_millis(resp.RoundTripTime.into()),
)
}
}
IpAddr::V6(_) => {
let ret = unsafe {
Icmp6ParseReplies(
context.buffer.as_ptr() as *mut _,
context.buffer.len() as u32,
)
};
if ret == 0 {
let error = unsafe { GetLastError() };
if error.0 == IP_REQ_TIMED_OUT {
IcmpEchoReply::new(
context.target_addr,
IcmpEchoStatus::TimedOut,
context.timeout,
)
} else {
IcmpEchoReply::new(context.target_addr, IcmpEchoStatus::Unknown, Duration::ZERO)
}
} else {
debug_assert_eq!(ret, 1);
let resp = (context.buffer.as_ptr() as *const ICMPV6_ECHO_REPLY)
.as_ref()
.unwrap();
let mut addr_raw = IN6_ADDR::default();
addr_raw.u.Word = resp.Address.sin6_addr;
let addr = IpAddr::V6(addr_raw.into());
IcmpEchoReply::new(
addr,
ip_error_to_icmp_status(resp.Status),
Duration::from_millis(resp.RoundTripTime.into()),
)
}
}
};
let _ = context.sender.send(reply);
if !context.wait_object.is_invalid() {
let _ = UnregisterWaitEx(context.wait_object, None);
}
if !context.event.is_invalid() {
let _ = CloseHandle(context.event);
}
}