use std::ffi::c_void;
use std::io;
use std::mem::size_of;
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddrV6};
use std::sync::Arc;
use std::time::Duration;

use futures::channel::oneshot;
use static_assertions::const_assert;

use windows::Win32::Foundation::{
    CloseHandle, GetLastError, ERROR_HOST_UNREACHABLE, ERROR_IO_PENDING, ERROR_NETWORK_UNREACHABLE,
    ERROR_PORT_UNREACHABLE, ERROR_PROTOCOL_UNREACHABLE, HANDLE, INVALID_HANDLE_VALUE,
};
use windows::Win32::NetworkManagement::IpHelper::{
    Icmp6CreateFile, Icmp6ParseReplies, Icmp6SendEcho2, IcmpCloseHandle, IcmpCreateFile,
    IcmpParseReplies, IcmpSendEcho2Ex, ICMPV6_ECHO_REPLY_LH as ICMPV6_ECHO_REPLY,
    IP_DEST_HOST_UNREACHABLE, IP_DEST_NET_UNREACHABLE, IP_DEST_PORT_UNREACHABLE,
    IP_DEST_PROT_UNREACHABLE, IP_DEST_UNREACHABLE, IP_REQ_TIMED_OUT, IP_SUCCESS, IP_TIME_EXCEEDED,
    IP_TTL_EXPIRED_REASSEM, IP_TTL_EXPIRED_TRANSIT,
};
use windows::Win32::Networking::WinSock::{IN6_ADDR, SOCKADDR_IN6};
use windows::Win32::System::Threading::{
    CreateEventW, RegisterWaitForSingleObject, UnregisterWaitEx, INFINITE, WT_EXECUTEINWAITTHREAD,
    WT_EXECUTEONLYONCE,
};
use windows::Win32::System::IO::IO_STATUS_BLOCK;

#[cfg(target_pointer_width = "32")]
use windows::Win32::NetworkManagement::IpHelper::ICMP_ECHO_REPLY;
#[cfg(target_pointer_width = "64")]
use windows::Win32::NetworkManagement::IpHelper::ICMP_ECHO_REPLY32 as ICMP_ECHO_REPLY;
#[cfg(target_pointer_width = "32")]
use windows::Win32::NetworkManagement::IpHelper::IP_OPTION_INFORMATION;
#[cfg(target_pointer_width = "64")]
use windows::Win32::NetworkManagement::IpHelper::IP_OPTION_INFORMATION32 as IP_OPTION_INFORMATION;

use crate::{
    IcmpEchoReply, IcmpEchoStatus, PING_DEFAULT_REQUEST_DATA_LENGTH, PING_DEFAULT_TIMEOUT,
    PING_DEFAULT_TTL,
};
44
45const REPLY_BUFFER_SIZE: usize = 100;
46
47const_assert!(
49 size_of::<ICMP_ECHO_REPLY>()
50 + PING_DEFAULT_REQUEST_DATA_LENGTH
51 + 8
52 + size_of::<IO_STATUS_BLOCK>()
53 <= REPLY_BUFFER_SIZE
54);
55const_assert!(
56 size_of::<ICMPV6_ECHO_REPLY>()
57 + PING_DEFAULT_REQUEST_DATA_LENGTH
58 + 8
59 + size_of::<IO_STATUS_BLOCK>()
60 <= REPLY_BUFFER_SIZE
61);
62
63struct RequestContext {
64 wait_object: HANDLE,
65 event: HANDLE,
66 buffer: Box<[u8]>,
67 target_addr: IpAddr,
68 timeout: Duration,
69 sender: oneshot::Sender<IcmpEchoReply>,
70}
71
72impl RequestContext {
73 fn new(
74 event: HANDLE,
75 target_addr: IpAddr,
76 timeout: Duration,
77 sender: oneshot::Sender<IcmpEchoReply>,
78 ) -> Self {
79 RequestContext {
80 wait_object: HANDLE::default(),
81 event,
82 buffer: vec![0u8; REPLY_BUFFER_SIZE].into_boxed_slice(),
83 target_addr,
84 timeout,
85 sender,
86 }
87 }
88
89 fn buffer_ptr(&mut self) -> *mut u8 {
90 self.buffer.as_mut_ptr()
91 }
92
93 fn buffer_size(&self) -> usize {
94 self.buffer.len()
95 }
96}
97
98#[derive(Clone)]
122pub struct IcmpEchoRequestor {
123 inner: Arc<RequestorInner>,
124}
125
126struct RequestorInner {
127 icmp_handle: HANDLE,
128 target_addr: IpAddr,
129 source_addr: IpAddr,
130 ttl: u8,
131 timeout: Duration,
132}
133
134unsafe impl Send for RequestorInner {}
136unsafe impl Sync for RequestorInner {}
137
138impl IcmpEchoRequestor {
139 pub fn new(
184 target_addr: IpAddr,
185 source_addr: Option<IpAddr>,
186 ttl: Option<u8>,
187 timeout: Option<Duration>,
188 ) -> io::Result<Self> {
189 match (target_addr, source_addr) {
191 (IpAddr::V4(_), Some(IpAddr::V6(_))) | (IpAddr::V6(_), Some(IpAddr::V4(_))) => {
192 return Err(io::Error::new(
193 io::ErrorKind::InvalidInput,
194 "Source address type does not match target address type",
195 ));
196 }
197 _ => {}
198 }
199
200 let icmp_handle = match target_addr {
201 IpAddr::V4(_) => unsafe { IcmpCreateFile()? },
202 IpAddr::V6(_) => unsafe { Icmp6CreateFile()? },
203 };
204 debug_assert!(!icmp_handle.is_invalid());
205
206 let source_addr = source_addr.unwrap_or(match target_addr {
207 IpAddr::V4(_) => IpAddr::V4(Ipv4Addr::UNSPECIFIED),
208 IpAddr::V6(_) => IpAddr::V6(Ipv6Addr::UNSPECIFIED),
209 });
210 let ttl = ttl.unwrap_or(PING_DEFAULT_TTL);
211 let timeout = timeout.unwrap_or(PING_DEFAULT_TIMEOUT);
212
213 Ok(IcmpEchoRequestor {
214 inner: Arc::new(RequestorInner {
215 icmp_handle,
216 target_addr,
217 source_addr,
218 ttl,
219 timeout,
220 }),
221 })
222 }
223
224 pub async fn send(&self) -> io::Result<IcmpEchoReply> {
280 let (reply_tx, reply_rx) = oneshot::channel();
281
282 self.handle_send(reply_tx)?;
283
284 reply_rx
285 .await
286 .map_err(|_| io::Error::other("reply channel closed unexpectedly"))
287 }
288
289 fn handle_send(&self, reply_tx: oneshot::Sender<IcmpEchoReply>) -> io::Result<()> {
290 let event = unsafe { CreateEventW(None, false, false, None)? };
292
293 let context_raw = Box::into_raw(Box::new(RequestContext::new(
295 event,
296 self.inner.target_addr,
297 self.inner.timeout,
298 reply_tx,
299 )));
300
301 match self.do_send(context_raw, event) {
303 Ok(()) => {
304 unsafe {
306 match RegisterWaitForSingleObject(
307 &mut (*context_raw).wait_object,
308 event,
309 Some(wait_callback),
310 Some(context_raw as *const _),
311 INFINITE,
312 WT_EXECUTEINWAITTHREAD | WT_EXECUTEONLYONCE,
313 ) {
314 Ok(()) => Ok(()),
315 Err(e) => {
316 let _ = CloseHandle(event);
319 drop(Box::from_raw(context_raw));
320
321 Err(e.into())
322 }
323 }
324 }
325 }
326 Err(e) => {
327 let status = ip_error_to_icmp_status(e);
329 let reply = IcmpEchoReply::new(self.inner.target_addr, status, Duration::ZERO);
330
331 unsafe {
332 let ctx = Box::from_raw(context_raw);
334 let _ = ctx.sender.send(reply);
335
336 if !ctx.event.is_invalid() {
340 let _ = CloseHandle(ctx.event);
341 }
342
343 }
345
346 Ok(())
348 }
349 }
350 }
351
352 fn do_send(&self, context: *mut RequestContext, event: HANDLE) -> Result<(), u32> {
353 let ip_option = IP_OPTION_INFORMATION {
354 Ttl: self.inner.ttl,
355 ..Default::default()
356 };
357
358 let req_data = [0u8; PING_DEFAULT_REQUEST_DATA_LENGTH];
359
360 let error = match self.inner.target_addr {
361 IpAddr::V4(taddr) => {
362 let saddr = if let IpAddr::V4(saddr) = self.inner.source_addr {
363 saddr
364 } else {
365 unreachable!("source address must be IPv4 for IPv4 target");
366 };
367
368 unsafe {
369 let ctx = context.as_mut().unwrap();
370
371 IcmpSendEcho2Ex(
372 self.inner.icmp_handle,
373 Some(event),
374 None,
375 None,
376 u32::from(saddr).to_be(),
377 u32::from(taddr).to_be(),
378 req_data.as_ptr() as *const _,
379 req_data.len() as u16,
380 Some(&ip_option as *const _ as *const _),
381 ctx.buffer_ptr() as *mut _,
382 ctx.buffer_size() as u32,
383 self.inner.timeout.as_millis() as u32,
384 )
385 }
386 }
387 IpAddr::V6(taddr) => {
388 let saddr = if let IpAddr::V6(saddr) = self.inner.source_addr {
389 saddr
390 } else {
391 unreachable!("source address must be IPv6 for IPv6 target");
392 };
393
394 unsafe {
395 let ctx = context.as_mut().unwrap();
396
397 let src_saddr: SOCKADDR_IN6 = SocketAddrV6::new(saddr, 0, 0, 0).into();
398 let dst_saddr: SOCKADDR_IN6 = SocketAddrV6::new(taddr, 0, 0, 0).into();
399
400 Icmp6SendEcho2(
401 self.inner.icmp_handle,
402 Some(event),
403 None,
404 None,
405 &src_saddr,
406 &dst_saddr,
407 req_data.as_ptr() as *const _,
408 req_data.len() as u16,
409 Some(&ip_option as *const _ as *const _),
410 ctx.buffer_ptr() as *mut _,
411 ctx.buffer_size() as u32,
412 self.inner.timeout.as_millis() as u32,
413 )
414 }
415 }
416 };
417
418 if error == ERROR_IO_PENDING.0 {
419 Ok(())
420 } else {
421 let code = unsafe { GetLastError() };
422 if code == ERROR_IO_PENDING {
423 Ok(())
424 } else {
425 Err(code.0)
426 }
427 }
428 }
429}
430
431impl Drop for RequestorInner {
432 fn drop(&mut self) {
433 unsafe {
434 if !self.icmp_handle.is_invalid() {
435 let _ = IcmpCloseHandle(self.icmp_handle);
436 }
437 }
438 }
439}
440
441fn ip_error_to_icmp_status(code: u32) -> IcmpEchoStatus {
442 match code {
443 IP_SUCCESS => IcmpEchoStatus::Success,
444 IP_REQ_TIMED_OUT | IP_TIME_EXCEEDED | IP_TTL_EXPIRED_REASSEM | IP_TTL_EXPIRED_TRANSIT => {
445 IcmpEchoStatus::TimedOut
446 }
447 IP_DEST_HOST_UNREACHABLE
448 | IP_DEST_NET_UNREACHABLE
449 | IP_DEST_PORT_UNREACHABLE
450 | IP_DEST_PROT_UNREACHABLE
451 | IP_DEST_UNREACHABLE => IcmpEchoStatus::Unreachable,
452 code if code == ERROR_NETWORK_UNREACHABLE.0
453 || code == ERROR_HOST_UNREACHABLE.0
454 || code == ERROR_PROTOCOL_UNREACHABLE.0
455 || code == ERROR_PORT_UNREACHABLE.0 =>
456 {
457 IcmpEchoStatus::Unreachable
458 }
459 _ => IcmpEchoStatus::Unknown,
460 }
461}
462
463unsafe extern "system" fn wait_callback(ptr: *mut c_void, timer_fired: bool) {
464 debug_assert!(!timer_fired, "Timer should not be fired here");
465
466 let context = Box::from_raw(ptr as *mut RequestContext);
468
469 let reply = match context.target_addr {
470 IpAddr::V4(_) => {
471 let ret = unsafe {
472 IcmpParseReplies(
473 context.buffer.as_ptr() as *mut _,
474 context.buffer.len() as u32,
475 )
476 };
477
478 if ret == 0 {
479 let error = unsafe { GetLastError() };
481 if error.0 == IP_REQ_TIMED_OUT {
482 IcmpEchoReply::new(
484 context.target_addr,
485 IcmpEchoStatus::TimedOut,
486 context.timeout,
487 )
488 } else {
489 IcmpEchoReply::new(context.target_addr, IcmpEchoStatus::Unknown, Duration::ZERO)
491 }
492 } else {
493 debug_assert_eq!(ret, 1);
494
495 let resp = (context.buffer.as_ptr() as *const ICMP_ECHO_REPLY)
496 .as_ref()
497 .unwrap();
498 let addr = IpAddr::V4(u32::from_be(resp.Address).into());
499
500 IcmpEchoReply::new(
501 addr,
502 ip_error_to_icmp_status(resp.Status),
503 Duration::from_millis(resp.RoundTripTime.into()),
504 )
505 }
506 }
507 IpAddr::V6(_) => {
508 let ret = unsafe {
509 Icmp6ParseReplies(
510 context.buffer.as_ptr() as *mut _,
511 context.buffer.len() as u32,
512 )
513 };
514
515 if ret == 0 {
516 let error = unsafe { GetLastError() };
518 if error.0 == IP_REQ_TIMED_OUT {
519 IcmpEchoReply::new(
521 context.target_addr,
522 IcmpEchoStatus::TimedOut,
523 context.timeout,
524 )
525 } else {
526 IcmpEchoReply::new(context.target_addr, IcmpEchoStatus::Unknown, Duration::ZERO)
528 }
529 } else {
530 debug_assert_eq!(ret, 1);
531
532 let resp = (context.buffer.as_ptr() as *const ICMPV6_ECHO_REPLY)
533 .as_ref()
534 .unwrap();
535 let mut addr_raw = IN6_ADDR::default();
536 addr_raw.u.Word = resp.Address.sin6_addr;
537 let addr = IpAddr::V6(addr_raw.into());
538
539 IcmpEchoReply::new(
540 addr,
541 ip_error_to_icmp_status(resp.Status),
542 Duration::from_millis(resp.RoundTripTime.into()),
543 )
544 }
545 }
546 };
547
548 let _ = context.sender.send(reply);
549
550 if !context.wait_object.is_invalid() {
552 let _ = UnregisterWaitEx(context.wait_object, None);
554 }
555 if !context.event.is_invalid() {
556 let _ = CloseHandle(context.event);
557 }
558 }