ping_async/platform/
socket.rs

1// platform/socket.rs
2
3use std::collections::HashMap;
4use std::io;
5use std::net::{IpAddr, SocketAddr};
6use std::sync::{
7    atomic::{AtomicU16, Ordering},
8    Arc, Mutex, OnceLock,
9};
10use std::time::{Duration, SystemTime, UNIX_EPOCH};
11
12use futures::channel::oneshot;
13use socket2::{Domain, Protocol, Socket, Type};
14use tokio::{net::UdpSocket, time};
15
16use crate::{
17    icmp::IcmpPacket, IcmpEchoReply, IcmpEchoStatus, PING_DEFAULT_TIMEOUT, PING_DEFAULT_TTL,
18};
19
/// Pending echo requests, keyed by ICMP sequence number.
///
/// Each entry holds the one-shot sender that delivers the matching reply to the
/// `IcmpEchoRequestor::send` call waiting for it. The same mutex-protected map is
/// shared between the requestor and its background reply router task.
type RequestRegistry = Arc<Mutex<HashMap<u16, oneshot::Sender<IcmpEchoReply>>>>;
21
/// Values handed to the background reply router task when it is spawned.
///
/// `socket` and `registry` point at the same allocations as the corresponding
/// fields of `RequestorInner`, so the router and the requestor share them.
struct RouterContext {
    /// Address being pinged; passed to the reply parser and stamped on every reply.
    target_addr: IpAddr,
    /// ICMP socket the router receives replies on.
    socket: Arc<UdpSocket>,
    /// Pending requests the router resolves as matching replies arrive.
    registry: RequestRegistry,
    /// Written by the router when it stops on a fatal receive error; checked by `send()`.
    failed: Arc<Mutex<Option<io::Error>>>,
}
28
/// Requestor for sending ICMP Echo Requests (ping) and receiving replies on Unix systems.
///
/// This implementation uses unprivileged ICMP datagram sockets driven by Tokio. It
/// requires unprivileged ICMP socket support, which is available on macOS by default
/// and on Linux when the `net.ipv4.ping_group_range` sysctl allows the calling
/// process's group.
///
/// Incoming replies are handled by a background task that is spawned lazily on the
/// first call to [`send`](Self::send) and aborted once the last clone of the requestor
/// is dropped. The requestor wraps its state in an `Arc`, so cloning is cheap and
/// clones can be used across multiple threads and async tasks.
///
/// # Platform Requirements
///
/// - **macOS**: Works with unprivileged sockets out of the box
/// - **Linux**: Requires `net.ipv4.ping_group_range` sysctl to allow unprivileged ICMP sockets
///
/// # Examples
///
/// ```rust,no_run
/// use ping_async::IcmpEchoRequestor;
/// use std::net::IpAddr;
///
/// #[tokio::main]
/// async fn main() -> std::io::Result<()> {
///     let target = "8.8.8.8".parse::<IpAddr>().unwrap();
///     let pinger = IcmpEchoRequestor::new(target, None, None, None)?;
///
///     let reply = pinger.send().await?;
///     println!("Reply: {:?}", reply);
///
///     Ok(())
/// }
/// ```
#[derive(Clone)]
pub struct IcmpEchoRequestor {
    // Shared state; every clone points at the same socket, registry and router.
    inner: Arc<RequestorInner>,
}
64
/// Shared state behind every clone of an [`IcmpEchoRequestor`].
///
/// Field order is significant for drop order; see the `Drop` impl, which aborts the
/// reply router before the fields are dropped.
struct RequestorInner {
    /// Socket used to send echo requests (the router receives on the same socket).
    socket: Arc<UdpSocket>,
    /// Address every request is sent to.
    target_addr: IpAddr,
    /// How long `send()` waits for a reply before reporting `TimedOut`.
    timeout: Duration,
    /// ICMP identifier written into requests and compared against replies.
    /// On Linux it is the kernel-assigned local port of the socket; on other
    /// platforms it is chosen at random (see `create_socket`).
    identifier: u16,
    /// Next sequence number (wraps on overflow); also the registry key for a reply.
    sequence: AtomicU16,
    /// Pending requests awaiting replies (same map as `router_context.registry`).
    registry: RequestRegistry,
    /// Abort handle of the reply router; set on the first `send()` and used by `Drop`.
    router_abort: OnceLock<tokio::task::AbortHandle>,
    /// Values cloned into the router task when it is spawned, plus its failure slot.
    router_context: RouterContext,
}
75
76impl IcmpEchoRequestor {
77    /// Creates a new ICMP echo requestor for the specified target address.
78    ///
79    /// # Arguments
80    ///
81    /// * `target_addr` - The IP address to ping (IPv4 or IPv6)
82    /// * `source_addr` - Optional source IP address to bind to. Must match the IP version of `target_addr`
83    /// * `ttl` - Optional Time-To-Live value. Defaults to [`PING_DEFAULT_TTL`](crate::PING_DEFAULT_TTL)
84    /// * `timeout` - Optional timeout duration. Defaults to [`PING_DEFAULT_TIMEOUT`](crate::PING_DEFAULT_TIMEOUT)
85    ///
86    /// # Errors
87    ///
88    /// Returns an error if:
89    /// - The source address type doesn't match the target address type (IPv4 vs IPv6)
90    /// - ICMP socket creation fails (typically due to insufficient permissions)
91    /// - Socket configuration fails
92    ///
93    /// # Platform Requirements
94    ///
95    /// - **Linux**: Requires `net.ipv4.ping_group_range` sysctl parameter to allow unprivileged ICMP sockets.
96    ///   Check with: `sysctl net.ipv4.ping_group_range`
97    /// - **macOS**: Works with unprivileged sockets by default
98    ///
99    /// # Examples
100    ///
101    /// ```rust,no_run
102    /// use ping_async::IcmpEchoRequestor;
103    /// use std::net::IpAddr;
104    /// use std::time::Duration;
105    ///
106    /// // Basic usage with defaults
107    /// let pinger = IcmpEchoRequestor::new(
108    ///     "8.8.8.8".parse().unwrap(),
109    ///     None,
110    ///     None,
111    ///     None
112    /// )?;
113    ///
114    /// // With custom source address and timeout
115    /// let pinger = IcmpEchoRequestor::new(
116    ///     "2001:4860:4860::8888".parse().unwrap(),
117    ///     Some("::1".parse().unwrap()),
118    ///     Some(64),
119    ///     Some(Duration::from_millis(500))
120    /// )?;
121    /// # Ok::<(), std::io::Error>(())
122    /// ```
123    pub fn new(
124        target_addr: IpAddr,
125        source_addr: Option<IpAddr>,
126        ttl: Option<u8>,
127        timeout: Option<Duration>,
128    ) -> io::Result<Self> {
129        // Check if the target address matches the source address type
130        match (target_addr, source_addr) {
131            (IpAddr::V4(_), Some(IpAddr::V6(_))) | (IpAddr::V6(_), Some(IpAddr::V4(_))) => {
132                return Err(io::Error::new(
133                    io::ErrorKind::InvalidInput,
134                    "Source address type does not match target address type",
135                ));
136            }
137            _ => {}
138        }
139
140        let timeout = timeout.unwrap_or(PING_DEFAULT_TIMEOUT);
141        let sequence = AtomicU16::new(0);
142
143        let (socket, identifier) = create_socket(target_addr, source_addr, ttl)?;
144        let socket = Arc::new(socket);
145        let registry = Arc::new(Mutex::new(HashMap::new()));
146
147        // Create a context for the router task
148        let router_context = RouterContext {
149            target_addr,
150            socket: Arc::clone(&socket),
151            registry: Arc::clone(&registry),
152            failed: Arc::new(Mutex::new(None)),
153        };
154
155        Ok(IcmpEchoRequestor {
156            inner: Arc::new(RequestorInner {
157                socket,
158                target_addr,
159                timeout,
160                identifier,
161                sequence,
162                registry,
163                router_abort: OnceLock::new(),
164                router_context,
165            }),
166        })
167    }
168
169    /// Sends an ICMP echo request and waits for a reply.
170    ///
171    /// This method is async and will complete when either:
172    /// - An echo reply is received
173    /// - The configured timeout expires
174    /// - An error occurs
175    ///
176    /// The requestor uses lazy initialization - the background reply router task
177    /// is only spawned on the first call to `send()`. The requestor can be used
178    /// multiple times and is safe to use concurrently from multiple async tasks.
179    ///
180    /// # Returns
181    ///
182    /// Returns an [`IcmpEchoReply`](crate::IcmpEchoReply) containing:
183    /// - The destination IP address
184    /// - The status of the ping operation
185    /// - The measured round-trip time
186    ///
187    /// # Errors
188    ///
189    /// Returns an error if:
190    /// - The socket send operation fails immediately
191    /// - The background router task has failed (typically due to permission loss)
192    /// - Internal communication channels fail unexpectedly
193    ///
194    /// Note that timeout and unreachable conditions are returned as successful
195    /// `IcmpEchoReply` with appropriate status values, not as errors.
196    ///
197    /// # Examples
198    ///
199    /// ```rust,no_run
200    /// use ping_async::{IcmpEchoRequestor, IcmpEchoStatus};
201    ///
202    /// #[tokio::main]
203    /// async fn main() -> std::io::Result<()> {
204    ///     let pinger = IcmpEchoRequestor::new(
205    ///         "8.8.8.8".parse().unwrap(),
206    ///         None, None, None
207    ///     )?;
208    ///
209    ///     // Send multiple pings using the same requestor
210    ///     for i in 0..3 {
211    ///         let reply = pinger.send().await?;
212    ///
213    ///         match reply.status() {
214    ///             IcmpEchoStatus::Success => {
215    ///                 println!("Ping {}: {:?}", i, reply.round_trip_time());
216    ///             }
217    ///             IcmpEchoStatus::TimedOut => {
218    ///                 println!("Ping {} timed out", i);
219    ///             }
220    ///             _ => {
221    ///                 println!("Ping {} failed: {:?}", i, reply.status());
222    ///             }
223    ///         }
224    ///     }
225    ///
226    ///     Ok(())
227    /// }
228    /// ```
229    pub async fn send(&self) -> io::Result<IcmpEchoReply> {
230        // Check if router failed already
231        if let Some(failed) = self
232            .inner
233            .router_context
234            .failed
235            .lock()
236            .unwrap_or_else(|poisoned| poisoned.into_inner())
237            .take()
238        {
239            return Err(failed);
240        }
241
242        // lazy spawning
243        self.ensure_router_running();
244
245        let sequence = self.inner.sequence.fetch_add(1, Ordering::SeqCst);
246        let key = sequence;
247
248        // Use timestamp as our payload
249        let timestamp = SystemTime::now()
250            .duration_since(UNIX_EPOCH)
251            .map_err(|e| io::Error::other(format!("timestamp error: {e}")))?
252            .as_nanos() as u64;
253        let payload = timestamp.to_be_bytes();
254
255        let packet = IcmpPacket::new_echo_request(
256            self.inner.target_addr,
257            self.inner.identifier,
258            sequence,
259            &payload,
260        );
261
262        let target = SocketAddr::new(self.inner.target_addr, 0);
263        let reply_rx = match self.inner.socket.send_to(packet.as_bytes(), target).await {
264            Ok(_) => {
265                let (tx, rx) = oneshot::channel();
266
267                self.inner
268                    .registry
269                    .lock()
270                    .unwrap_or_else(|poisoned| poisoned.into_inner())
271                    .insert(key, tx);
272
273                rx
274            }
275            Err(e) => match e.kind() {
276                io::ErrorKind::NetworkUnreachable
277                | io::ErrorKind::NetworkDown
278                | io::ErrorKind::HostUnreachable => {
279                    return Ok(IcmpEchoReply::new(
280                        self.inner.target_addr,
281                        IcmpEchoStatus::Unreachable,
282                        Duration::ZERO,
283                    ));
284                }
285                _ => return Err(e),
286            },
287        };
288
289        let timeout = self.inner.timeout;
290        let target_addr = self.inner.target_addr;
291
292        tokio::select! {
293            result = reply_rx => {
294                match result {
295                    Ok(reply) => Ok(reply),
296                    Err(_) => {
297                        // Channel closed - router probably failed
298                        Err(io::Error::other("reply channel closed"))
299                    }
300                }
301            }
302            _ = time::sleep(timeout) => {
303                // Remove from registry on timeout
304                self.inner.registry.lock().unwrap_or_else(|poisoned| poisoned.into_inner()).remove(&key);
305
306                // Calculate RTT for timed out request
307                let now = SystemTime::now()
308                    .duration_since(UNIX_EPOCH)
309                    .map_err(|e| io::Error::other(format!("timestamp error: {e}")))?
310                    .as_nanos() as u64;
311                let rtt = Duration::from_nanos(now - timestamp);
312
313                Ok(IcmpEchoReply::new(
314                    target_addr,
315                    IcmpEchoStatus::TimedOut,
316                    rtt,
317                ))
318            }
319        }
320    }
321
322    fn ensure_router_running(&self) {
323        let target_addr = self.inner.router_context.target_addr;
324        let identifier = self.inner.identifier;
325        let socket = Arc::clone(&self.inner.router_context.socket);
326        let registry = Arc::clone(&self.inner.router_context.registry);
327        let failed = Arc::clone(&self.inner.router_context.failed);
328
329        self.inner.router_abort.get_or_init(|| {
330            let handle = tokio::spawn(reply_router_loop(
331                target_addr,
332                identifier,
333                socket,
334                registry,
335                failed,
336            ));
337            handle.abort_handle()
338        });
339    }
340}
341
342impl Drop for RequestorInner {
343    fn drop(&mut self) {
344        if let Some(abort_handle) = self.router_abort.get() {
345            abort_handle.abort();
346        }
347    }
348}
349
350async fn reply_router_loop(
351    target_addr: IpAddr,
352    identifier: u16,
353    socket: Arc<UdpSocket>,
354    registry: RequestRegistry,
355    failed: Arc<Mutex<Option<io::Error>>>,
356) {
357    loop {
358        let mut buf = vec![0u8; 1024];
359
360        match socket.recv(&mut buf).await {
361            Ok(size) => {
362                buf.truncate(size);
363
364                if let Some(reply_packet) = IcmpPacket::parse_reply(&buf, target_addr) {
365                    // Check if this is a reply to our request by comparing identifier, ignoring if not
366                    if reply_packet.identifier() != identifier {
367                        continue;
368                    }
369
370                    // Use sequence number to find the waiting sender
371                    let key = reply_packet.sequence();
372                    let sender = registry
373                        .lock()
374                        .unwrap_or_else(|poisoned| poisoned.into_inner())
375                        .remove(&key);
376
377                    if let Some(sender) = sender {
378                        // Extract timestamp from payload to calculate RTT
379                        let payload = reply_packet.payload();
380
381                        let reply = if payload.len() >= 8 {
382                            let sent_timestamp = u64::from_be_bytes([
383                                payload[0], payload[1], payload[2], payload[3], payload[4],
384                                payload[5], payload[6], payload[7],
385                            ]);
386
387                            let now = SystemTime::now()
388                                .duration_since(UNIX_EPOCH)
389                                .unwrap_or_default()
390                                .as_nanos() as u64;
391                            let rtt = Duration::from_nanos(now.saturating_sub(sent_timestamp));
392
393                            IcmpEchoReply::new(target_addr, IcmpEchoStatus::Success, rtt)
394                        } else {
395                            // Report Unknown error if payload is too short
396                            IcmpEchoReply::new(target_addr, IcmpEchoStatus::Unknown, Duration::ZERO)
397                        };
398
399                        // Send reply to waiting thread
400                        let _ = sender.send(reply);
401                    }
402                }
403            }
404            Err(e) => {
405                match e.kind() {
406                    // Fatal errors - router cannot continue
407                    io::ErrorKind::PermissionDenied |        // Lost privileges
408                    io::ErrorKind::AddrNotAvailable |        // Address no longer available
409                    io::ErrorKind::ConnectionAborted |       // Socket forcibly closed
410                    io::ErrorKind::NotConnected => {         // Socket disconnected
411                        // Clear pending requests so they don't hang
412                        registry.lock().unwrap_or_else(|poisoned| poisoned.into_inner()).clear();
413
414                        // Mark the failed flag
415                        let mut failed_lock = failed.lock().unwrap_or_else(|poisoned| poisoned.into_inner());
416                        *failed_lock = Some(e);
417
418                        return;
419                    }
420
421                    // Continue with temporary network issues, etc.
422                    _ => continue,
423                }
424            }
425        }
426    }
427}
428
429fn create_socket(
430    target_addr: IpAddr,
431    source_addr: Option<IpAddr>,
432    ttl: Option<u8>,
433) -> io::Result<(UdpSocket, u16)> {
434    let socket = match target_addr {
435        IpAddr::V4(_) => Socket::new(Domain::IPV4, Type::DGRAM, Some(Protocol::ICMPV4))?,
436        IpAddr::V6(_) => Socket::new(Domain::IPV6, Type::DGRAM, Some(Protocol::ICMPV6))?,
437    };
438    socket.set_nonblocking(true)?;
439
440    let ttl = ttl.unwrap_or(PING_DEFAULT_TTL);
441    if target_addr.is_ipv4() {
442        socket.set_ttl_v4(ttl as u32)?;
443    } else {
444        socket.set_unicast_hops_v6(ttl as u32)?;
445    }
446
447    // Platform-specific ICMP identifier handling
448    //
449    // macOS/BSD systems preserve the ICMP identifier field throughout the ping process.
450    // When we send an ICMP ECHO request with a specific identifier (e.g., 6789),
451    // the reply will contain the same identifier value. This allows us to use
452    // random identifiers for distinguishing between different ping sessions.
453    #[cfg(not(target_os = "linux"))]
454    let identifier = {
455        // On macOS, use random identifier and bind to source address if provided
456        if let Some(source_addr) = source_addr {
457            socket.bind(&SocketAddr::new(source_addr, 0).into())?;
458        }
459        rand::random()
460    };
461
462    // Linux systems behave differently with unprivileged ICMP sockets (SOCK_DGRAM).
463    // The Linux kernel automatically replaces the ICMP identifier field with the
464    // socket's local port number. This means:
465    // 1. Any identifier we set will be ignored and replaced by the kernel
466    // 2. ICMP replies are routed back based on the socket port, not the identifier
467    // 3. We must bind the socket to get a port assignment from the kernel
468    // 4. The port number becomes our effective identifier for matching replies
469    //
470    // This behavior ensures proper delivery of ICMP replies to the correct socket
471    // in a multi-process environment, since the kernel handles routing internally.
472    #[cfg(target_os = "linux")]
473    let identifier = {
474        // Bind with port 0 to let kernel assign a unique port number.
475        // This port will be used as the ICMP identifier by the kernel.
476        let bind_addr = source_addr.unwrap_or(match target_addr {
477            IpAddr::V4(_) => IpAddr::V4(std::net::Ipv4Addr::UNSPECIFIED),
478            IpAddr::V6(_) => IpAddr::V6(std::net::Ipv6Addr::UNSPECIFIED),
479        });
480        socket.bind(&SocketAddr::new(bind_addr, 0).into())?;
481
482        // Extract the kernel-assigned port number, which will be used as the ICMP identifier
483        let local_addr = socket.local_addr()?;
484        local_addr
485            .as_socket()
486            .ok_or(io::Error::other(
487                "Failed to get kernel-assigned ICMP identifier",
488            ))?
489            .port()
490    };
491
492    let udp_socket = UdpSocket::from_std(socket.into())?;
493    Ok((udp_socket, identifier))
494}
495
#[cfg(test)]
mod tests {
    use super::*;

    /// Loopback target: replies return quickly without any external network.
    const LOOPBACK: &str = "127.0.0.1";

    /// Whether the lazily spawned reply router task has been started.
    fn router_started(pinger: &IcmpEchoRequestor) -> bool {
        pinger.inner.router_abort.get().is_some()
    }

    #[tokio::test]
    async fn test_lazy_router_spawning() -> io::Result<()> {
        let target: IpAddr = LOOPBACK.parse().unwrap();
        let pinger = IcmpEchoRequestor::new(target, None, None, None)?;

        // Construction alone must not start the router.
        assert!(
            !router_started(&pinger),
            "Router should not be spawned after new()"
        );

        // The first send triggers the lazy spawn.
        let first = pinger.send().await?;
        assert_eq!(first.destination(), target);
        assert!(
            router_started(&pinger),
            "Router should be spawned after first send()"
        );

        // Later sends reuse the router that is already running.
        let second = pinger.send().await?;
        assert_eq!(second.destination(), target);
        assert!(
            router_started(&pinger),
            "Router should remain spawned after subsequent sends"
        );

        Ok(())
    }
}