Skip to main content

ping_async/platform/
socket.rs

1// platform/socket.rs
2
3use std::collections::HashMap;
4use std::io;
5use std::net::{IpAddr, SocketAddr};
6use std::sync::{
7    atomic::{AtomicU16, Ordering},
8    Arc, Mutex, OnceLock,
9};
10use std::time::{Duration, SystemTime, UNIX_EPOCH};
11
12use futures::channel::oneshot;
13use socket2::{Domain, Protocol, Socket, Type};
14use tokio::{net::UdpSocket, time};
15
16use crate::{
17    icmp::IcmpPacket, IcmpEchoReply, IcmpEchoStatus, PING_DEFAULT_TIMEOUT, PING_DEFAULT_TTL,
18};
19
20type RequestRegistry = Arc<Mutex<HashMap<u16, oneshot::Sender<IcmpEchoReply>>>>;
21
/// Persistent record of a fatal router error, stored so that all subsequent
/// `send()` calls fail fast with the same error information.
///
/// `io::Error` is not `Clone`, so the pieces needed to rebuild an equivalent
/// error are captured instead. When the original error came from the OS, its
/// raw error code is kept as well, so every rebuilt error still reports it via
/// [`io::Error::raw_os_error`] (with the same kind and message, which the OS
/// code determines).
#[derive(Debug)]
struct RouterError {
    /// Kind of the original error; used when there is no OS error code.
    kind: io::ErrorKind,
    /// `Display` text of the original error; used when there is no OS error code.
    message: String,
    /// Raw OS error code of the original error, if it had one.
    raw_os_error: Option<i32>,
}

impl RouterError {
    /// Captures everything needed to reproduce `e` later.
    fn from_io_error(e: &io::Error) -> Self {
        RouterError {
            kind: e.kind(),
            message: e.to_string(),
            raw_os_error: e.raw_os_error(),
        }
    }

    /// Builds a fresh `io::Error` equivalent to the one that was recorded.
    fn to_io_error(&self) -> io::Error {
        match self.raw_os_error {
            // Rebuilding from the code restores kind, message and `raw_os_error()`.
            Some(code) => io::Error::from_raw_os_error(code),
            None => io::Error::new(self.kind, self.message.clone()),
        }
    }
}
41
42struct RouterContext {
43    target_addr: IpAddr,
44    socket: Arc<UdpSocket>,
45    registry: RequestRegistry,
46    failed: Arc<Mutex<Option<RouterError>>>,
47}
48
49/// Requestor for sending ICMP Echo Requests (ping) and receiving replies on Unix systems.
50///
51/// This implementation uses ICMP sockets with Tokio for async operations. It requires
52/// unprivileged ICMP socket support, which is available on macOS by default and on
53/// Linux when the `net.ipv4.ping_group_range` sysctl parameter is properly configured.
54///
55/// The requestor spawns a background task to handle incoming replies and is safe to
56/// clone and use across multiple threads and async tasks.
57///
58/// # Platform Requirements
59///
60/// - **macOS**: Works with unprivileged sockets out of the box
61/// - **Linux**: Requires `net.ipv4.ping_group_range` sysctl to allow unprivileged ICMP sockets
62///
63/// # Examples
64///
65/// ```rust,no_run
66/// use ping_async::IcmpEchoRequestor;
67/// use std::net::IpAddr;
68///
69/// #[tokio::main]
70/// async fn main() -> std::io::Result<()> {
71///     let target = "8.8.8.8".parse::<IpAddr>().unwrap();
72///     let pinger = IcmpEchoRequestor::new(target, None, None, None)?;
73///
74///     let reply = pinger.send().await?;
75///     println!("Reply: {:?}", reply);
76///
77///     Ok(())
78/// }
79/// ```
80#[derive(Clone)]
81pub struct IcmpEchoRequestor {
82    inner: Arc<RequestorInner>,
83}
84
85struct RequestorInner {
86    socket: Arc<UdpSocket>,
87    target_addr: IpAddr,
88    timeout: Duration,
89    identifier: u16,
90    sequence: AtomicU16,
91    registry: RequestRegistry,
92    router_abort: OnceLock<tokio::task::AbortHandle>,
93    router_context: RouterContext,
94}
95
96impl IcmpEchoRequestor {
97    /// Creates a new ICMP echo requestor for the specified target address.
98    ///
99    /// # Arguments
100    ///
101    /// * `target_addr` - The IP address to ping (IPv4 or IPv6)
102    /// * `source_addr` - Optional source IP address to bind to. Must match the IP version of `target_addr`
103    /// * `ttl` - Optional Time-To-Live value. Defaults to [`PING_DEFAULT_TTL`](crate::PING_DEFAULT_TTL)
104    /// * `timeout` - Optional timeout duration. Defaults to [`PING_DEFAULT_TIMEOUT`](crate::PING_DEFAULT_TIMEOUT)
105    ///
106    /// # Errors
107    ///
108    /// Returns an error if:
109    /// - The source address type doesn't match the target address type (IPv4 vs IPv6)
110    /// - ICMP socket creation fails (typically due to insufficient permissions)
111    /// - Socket configuration fails
112    ///
113    /// # Platform Requirements
114    ///
115    /// - **Linux**: Requires `net.ipv4.ping_group_range` sysctl parameter to allow unprivileged ICMP sockets.
116    ///   Check with: `sysctl net.ipv4.ping_group_range`
117    /// - **macOS**: Works with unprivileged sockets by default
118    ///
119    /// # Examples
120    ///
121    /// ```rust,no_run
122    /// use ping_async::IcmpEchoRequestor;
123    /// use std::net::IpAddr;
124    /// use std::time::Duration;
125    ///
126    /// // Basic usage with defaults
127    /// let pinger = IcmpEchoRequestor::new(
128    ///     "8.8.8.8".parse().unwrap(),
129    ///     None,
130    ///     None,
131    ///     None
132    /// )?;
133    ///
134    /// // With custom source address and timeout
135    /// let pinger = IcmpEchoRequestor::new(
136    ///     "2001:4860:4860::8888".parse().unwrap(),
137    ///     Some("::1".parse().unwrap()),
138    ///     Some(64),
139    ///     Some(Duration::from_millis(500))
140    /// )?;
141    /// # Ok::<(), std::io::Error>(())
142    /// ```
143    pub fn new(
144        target_addr: IpAddr,
145        source_addr: Option<IpAddr>,
146        ttl: Option<u8>,
147        timeout: Option<Duration>,
148    ) -> io::Result<Self> {
149        // Check if the target address matches the source address type
150        match (target_addr, source_addr) {
151            (IpAddr::V4(_), Some(IpAddr::V6(_))) | (IpAddr::V6(_), Some(IpAddr::V4(_))) => {
152                return Err(io::Error::new(
153                    io::ErrorKind::InvalidInput,
154                    "Source address type does not match target address type",
155                ));
156            }
157            _ => {}
158        }
159
160        let timeout = timeout.unwrap_or(PING_DEFAULT_TIMEOUT);
161        let sequence = AtomicU16::new(0);
162
163        let (socket, identifier) = create_socket(target_addr, source_addr, ttl)?;
164        let socket = Arc::new(socket);
165        let registry = Arc::new(Mutex::new(HashMap::new()));
166
167        // Create a context for the router task
168        let router_context = RouterContext {
169            target_addr,
170            socket: Arc::clone(&socket),
171            registry: Arc::clone(&registry),
172            failed: Arc::new(Mutex::new(None::<RouterError>)),
173        };
174
175        Ok(IcmpEchoRequestor {
176            inner: Arc::new(RequestorInner {
177                socket,
178                target_addr,
179                timeout,
180                identifier,
181                sequence,
182                registry,
183                router_abort: OnceLock::new(),
184                router_context,
185            }),
186        })
187    }
188
189    /// Sends an ICMP echo request and waits for a reply.
190    ///
191    /// This method is async and will complete when either:
192    /// - An echo reply is received
193    /// - The configured timeout expires
194    /// - An error occurs
195    ///
196    /// The requestor uses lazy initialization - the background reply router task
197    /// is only spawned on the first call to `send()`. The requestor can be used
198    /// multiple times and is safe to use concurrently from multiple async tasks.
199    ///
200    /// # Returns
201    ///
202    /// Returns an [`IcmpEchoReply`](crate::IcmpEchoReply) containing:
203    /// - The destination IP address
204    /// - The status of the ping operation
205    /// - The measured round-trip time
206    ///
207    /// # Errors
208    ///
209    /// Returns an error if:
210    /// - The socket send operation fails immediately
211    /// - The background router task has failed (typically due to permission loss)
212    /// - Internal communication channels fail unexpectedly
213    ///
214    /// Note that timeout and unreachable conditions are returned as successful
215    /// `IcmpEchoReply` with appropriate status values, not as errors.
216    ///
217    /// # Examples
218    ///
219    /// ```rust,no_run
220    /// use ping_async::{IcmpEchoRequestor, IcmpEchoStatus};
221    ///
222    /// #[tokio::main]
223    /// async fn main() -> std::io::Result<()> {
224    ///     let pinger = IcmpEchoRequestor::new(
225    ///         "8.8.8.8".parse().unwrap(),
226    ///         None, None, None
227    ///     )?;
228    ///
229    ///     // Send multiple pings using the same requestor
230    ///     for i in 0..3 {
231    ///         let reply = pinger.send().await?;
232    ///
233    ///         match reply.status() {
234    ///             IcmpEchoStatus::Success => {
235    ///                 println!("Ping {}: {:?}", i, reply.round_trip_time());
236    ///             }
237    ///             IcmpEchoStatus::TimedOut => {
238    ///                 println!("Ping {} timed out", i);
239    ///             }
240    ///             _ => {
241    ///                 println!("Ping {} failed: {:?}", i, reply.status());
242    ///             }
243    ///         }
244    ///     }
245    ///
246    ///     Ok(())
247    /// }
248    /// ```
249    pub async fn send(&self) -> io::Result<IcmpEchoReply> {
250        // Check if router failed already — error is persistent, not consumed
251        if let Some(ref router_error) = *self
252            .inner
253            .router_context
254            .failed
255            .lock()
256            .unwrap_or_else(|poisoned| poisoned.into_inner())
257        {
258            return Err(router_error.to_io_error());
259        }
260
261        // lazy spawning
262        self.ensure_router_running();
263
264        let sequence = self.inner.sequence.fetch_add(1, Ordering::SeqCst);
265        let key = sequence;
266
267        // Use timestamp as our payload
268        let timestamp = SystemTime::now()
269            .duration_since(UNIX_EPOCH)
270            .map_err(|e| io::Error::other(format!("timestamp error: {e}")))?
271            .as_nanos() as u64;
272        let payload = timestamp.to_be_bytes();
273
274        let packet = IcmpPacket::new_echo_request(
275            self.inner.target_addr,
276            self.inner.identifier,
277            sequence,
278            &payload,
279        );
280
281        // Register in the registry BEFORE sending so fast replies (e.g. loopback)
282        // are not dropped by the router due to a missing entry.
283        let (tx, reply_rx) = oneshot::channel();
284        self.inner
285            .registry
286            .lock()
287            .unwrap_or_else(|poisoned| poisoned.into_inner())
288            .insert(key, tx);
289
290        let target = SocketAddr::new(self.inner.target_addr, 0);
291        if let Err(e) = self.inner.socket.send_to(packet.as_bytes(), target).await {
292            // Send failed — remove the registry entry we just inserted
293            self.inner
294                .registry
295                .lock()
296                .unwrap_or_else(|poisoned| poisoned.into_inner())
297                .remove(&key);
298
299            return match e.kind() {
300                io::ErrorKind::NetworkUnreachable
301                | io::ErrorKind::NetworkDown
302                | io::ErrorKind::HostUnreachable => Ok(IcmpEchoReply::new(
303                    self.inner.target_addr,
304                    IcmpEchoStatus::Unreachable,
305                    Duration::ZERO,
306                )),
307                _ => Err(e),
308            };
309        }
310
311        let timeout = self.inner.timeout;
312        let target_addr = self.inner.target_addr;
313
314        tokio::select! {
315            result = reply_rx => {
316                match result {
317                    Ok(reply) => Ok(reply),
318                    Err(_) => {
319                        // Channel closed — check if router failed for a consistent error
320                        if let Some(ref router_error) = *self
321                            .inner
322                            .router_context
323                            .failed
324                            .lock()
325                            .unwrap_or_else(|poisoned| poisoned.into_inner())
326                        {
327                            Err(router_error.to_io_error())
328                        } else {
329                            Err(io::Error::other("reply channel closed"))
330                        }
331                    }
332                }
333            }
334            _ = time::sleep(timeout) => {
335                // Remove from registry on timeout
336                self.inner.registry.lock().unwrap_or_else(|poisoned| poisoned.into_inner()).remove(&key);
337
338                // Calculate RTT for timed out request
339                let now = SystemTime::now()
340                    .duration_since(UNIX_EPOCH)
341                    .map_err(|e| io::Error::other(format!("timestamp error: {e}")))?
342                    .as_nanos() as u64;
343                let rtt = Duration::from_nanos(now.saturating_sub(timestamp));
344
345                Ok(IcmpEchoReply::new(
346                    target_addr,
347                    IcmpEchoStatus::TimedOut,
348                    rtt,
349                ))
350            }
351        }
352    }
353
354    fn ensure_router_running(&self) {
355        let target_addr = self.inner.router_context.target_addr;
356        let identifier = self.inner.identifier;
357        let socket = Arc::clone(&self.inner.router_context.socket);
358        let registry = Arc::clone(&self.inner.router_context.registry);
359        let failed = Arc::clone(&self.inner.router_context.failed);
360
361        self.inner.router_abort.get_or_init(|| {
362            let handle = tokio::spawn(reply_router_loop(
363                target_addr,
364                identifier,
365                socket,
366                registry,
367                failed,
368            ));
369            handle.abort_handle()
370        });
371    }
372}
373
374impl Drop for RequestorInner {
375    fn drop(&mut self) {
376        if let Some(abort_handle) = self.router_abort.get() {
377            abort_handle.abort();
378        }
379    }
380}
381
382async fn reply_router_loop(
383    target_addr: IpAddr,
384    identifier: u16,
385    socket: Arc<UdpSocket>,
386    registry: RequestRegistry,
387    failed: Arc<Mutex<Option<RouterError>>>,
388) {
389    loop {
390        let mut buf = vec![0u8; 1024];
391
392        match socket.recv(&mut buf).await {
393            Ok(size) => {
394                buf.truncate(size);
395
396                if let Some(reply_packet) = IcmpPacket::parse_reply(&buf, target_addr) {
397                    // Check if this is a reply to our request by comparing identifier, ignoring if not
398                    if reply_packet.identifier() != identifier {
399                        continue;
400                    }
401
402                    // Use sequence number to find the waiting sender
403                    let key = reply_packet.sequence();
404                    let sender = registry
405                        .lock()
406                        .unwrap_or_else(|poisoned| poisoned.into_inner())
407                        .remove(&key);
408
409                    if let Some(sender) = sender {
410                        // Extract timestamp from payload to calculate RTT
411                        let payload = reply_packet.payload();
412
413                        let reply = if payload.len() >= 8 {
414                            let sent_timestamp = u64::from_be_bytes([
415                                payload[0], payload[1], payload[2], payload[3], payload[4],
416                                payload[5], payload[6], payload[7],
417                            ]);
418
419                            let now = SystemTime::now()
420                                .duration_since(UNIX_EPOCH)
421                                .unwrap_or_default()
422                                .as_nanos() as u64;
423                            let rtt = Duration::from_nanos(now.saturating_sub(sent_timestamp));
424
425                            IcmpEchoReply::new(target_addr, IcmpEchoStatus::Success, rtt)
426                        } else {
427                            // Report Unknown error if payload is too short
428                            IcmpEchoReply::new(target_addr, IcmpEchoStatus::Unknown, Duration::ZERO)
429                        };
430
431                        // Send reply to waiting thread
432                        let _ = sender.send(reply);
433                    }
434                } else if let Some(error_info) = IcmpPacket::parse_error_reply(&buf, target_addr) {
435                    // ICMP error message (Dest Unreachable, Time Exceeded) with
436                    // embedded echo request — match by identifier and sequence.
437                    if error_info.identifier != identifier {
438                        continue;
439                    }
440
441                    let sender = registry
442                        .lock()
443                        .unwrap_or_else(|poisoned| poisoned.into_inner())
444                        .remove(&error_info.sequence);
445
446                    if let Some(sender) = sender {
447                        let reply =
448                            IcmpEchoReply::new(target_addr, error_info.status, Duration::ZERO);
449                        let _ = sender.send(reply);
450                    }
451                }
452            }
453            Err(e) => {
454                match e.kind() {
455                    // Fatal errors - router cannot continue
456                    io::ErrorKind::PermissionDenied |        // Lost privileges
457                    io::ErrorKind::AddrNotAvailable |        // Address no longer available
458                    io::ErrorKind::ConnectionAborted |       // Socket forcibly closed
459                    io::ErrorKind::NotConnected => {         // Socket disconnected
460                        // Store the error persistently so all future send() calls fail fast
461                        let mut failed_lock = failed.lock().unwrap_or_else(|poisoned| poisoned.into_inner());
462                        *failed_lock = Some(RouterError::from_io_error(&e));
463
464                        // Drain the registry — dropping senders closes channels, which
465                        // causes in-flight send() calls to see Canceled and check `failed`
466                        // for a consistent error. Entries already removed by timeout are
467                        // simply absent.
468                        registry.lock().unwrap_or_else(|poisoned| poisoned.into_inner()).clear();
469
470                        return;
471                    }
472
473                    // Continue with temporary network issues, etc.
474                    _ => continue,
475                }
476            }
477        }
478    }
479}
480
481fn create_socket(
482    target_addr: IpAddr,
483    source_addr: Option<IpAddr>,
484    ttl: Option<u8>,
485) -> io::Result<(UdpSocket, u16)> {
486    let socket = match target_addr {
487        IpAddr::V4(_) => Socket::new(Domain::IPV4, Type::DGRAM, Some(Protocol::ICMPV4))?,
488        IpAddr::V6(_) => Socket::new(Domain::IPV6, Type::DGRAM, Some(Protocol::ICMPV6))?,
489    };
490    socket.set_nonblocking(true)?;
491
492    let ttl = ttl.unwrap_or(PING_DEFAULT_TTL);
493    if target_addr.is_ipv4() {
494        socket.set_ttl_v4(ttl as u32)?;
495    } else {
496        socket.set_unicast_hops_v6(ttl as u32)?;
497    }
498
499    // Platform-specific ICMP identifier handling
500    //
501    // macOS/BSD systems preserve the ICMP identifier field throughout the ping process.
502    // When we send an ICMP ECHO request with a specific identifier (e.g., 6789),
503    // the reply will contain the same identifier value. This allows us to use
504    // random identifiers for distinguishing between different ping sessions.
505    #[cfg(not(target_os = "linux"))]
506    let identifier = {
507        // On macOS, use random identifier and bind to source address if provided
508        if let Some(source_addr) = source_addr {
509            socket.bind(&SocketAddr::new(source_addr, 0).into())?;
510        }
511        rand::random()
512    };
513
514    // Linux systems behave differently with unprivileged ICMP sockets (SOCK_DGRAM).
515    // The Linux kernel automatically replaces the ICMP identifier field with the
516    // socket's local port number. This means:
517    // 1. Any identifier we set will be ignored and replaced by the kernel
518    // 2. ICMP replies are routed back based on the socket port, not the identifier
519    // 3. We must bind the socket to get a port assignment from the kernel
520    // 4. The port number becomes our effective identifier for matching replies
521    //
522    // This behavior ensures proper delivery of ICMP replies to the correct socket
523    // in a multi-process environment, since the kernel handles routing internally.
524    #[cfg(target_os = "linux")]
525    let identifier = {
526        // Bind with port 0 to let kernel assign a unique port number.
527        // This port will be used as the ICMP identifier by the kernel.
528        let bind_addr = source_addr.unwrap_or(match target_addr {
529            IpAddr::V4(_) => IpAddr::V4(std::net::Ipv4Addr::UNSPECIFIED),
530            IpAddr::V6(_) => IpAddr::V6(std::net::Ipv6Addr::UNSPECIFIED),
531        });
532        socket.bind(&SocketAddr::new(bind_addr, 0).into())?;
533
534        // Extract the kernel-assigned port number, which will be used as the ICMP identifier
535        let local_addr = socket.local_addr()?;
536        local_addr
537            .as_socket()
538            .ok_or(io::Error::other(
539                "Failed to get kernel-assigned ICMP identifier",
540            ))?
541            .port()
542    };
543
544    let udp_socket = UdpSocket::from_std(socket.into())?;
545    Ok((udp_socket, identifier))
546}
547
#[cfg(test)]
mod tests {
    use super::*;
    use std::io;

    /// Reports whether the requestor has already started its background reply router.
    fn router_started(requestor: &IcmpEchoRequestor) -> bool {
        requestor.inner.router_abort.get().is_some()
    }

    #[tokio::test]
    async fn test_lazy_router_spawning() -> io::Result<()> {
        let loopback: IpAddr = "127.0.0.1".parse().unwrap();
        let pinger = IcmpEchoRequestor::new(loopback, None, None, None)?;

        // Constructing the requestor alone must not start the router.
        assert!(
            !router_started(&pinger),
            "Router should not be spawned after new()"
        );

        // The first send() is what starts it.
        let first = pinger.send().await?;
        assert_eq!(first.destination(), loopback);
        assert!(
            router_started(&pinger),
            "Router should be spawned after first send()"
        );

        // Later sends reuse the router that is already running.
        let second = pinger.send().await?;
        assert_eq!(second.destination(), loopback);
        assert!(
            router_started(&pinger),
            "Router should remain spawned after subsequent sends"
        );

        Ok(())
    }
}