ping_async/platform/
socket.rs

1// platform/socket.rs
2
3use std::collections::HashMap;
4use std::io;
5use std::net::{IpAddr, SocketAddr};
6use std::sync::{
7    atomic::{AtomicU16, Ordering},
8    Arc, Mutex, OnceLock,
9};
10use std::time::{Duration, SystemTime, UNIX_EPOCH};
11
12use futures::channel::oneshot;
13use rand::random;
14use socket2::{Domain, Protocol, Socket, Type};
15use tokio::{net::UdpSocket, time};
16
17use crate::{
18    icmp::IcmpPacket, IcmpEchoReply, IcmpEchoStatus, PING_DEFAULT_TIMEOUT, PING_DEFAULT_TTL,
19};
20
/// Shared table of in-flight echo requests.
///
/// Keys are the ICMP sequence numbers of requests that are still waiting for a
/// reply; each value is the one-shot sender through which the reply router hands
/// the matching [`IcmpEchoReply`] back to the waiting `send()` call.
type RequestRegistry = Arc<Mutex<HashMap<u16, oneshot::Sender<IcmpEchoReply>>>>;
22
/// Everything the background reply router needs, kept ready so the router task
/// can be spawned lazily on the first `send()`.
struct RouterContext {
    /// Address being pinged; passed to the reply parser and reported in replies.
    target_addr: IpAddr,
    /// Handle to the same socket the requestor sends on; the router receives from it.
    socket: Arc<UdpSocket>,
    /// Handle to the table of in-flight requests shared with the requestor.
    registry: RequestRegistry,
    /// Set by the router when it stops because of a fatal receive error;
    /// inspected by `send()` before a new request is issued.
    failed: Arc<Mutex<Option<io::Error>>>,
}
29
/// Requestor for sending ICMP Echo Requests (ping) and receiving replies on Unix systems.
///
/// This implementation uses ICMP sockets with Tokio for async operations. It requires
/// unprivileged ICMP socket support, which is available on macOS by default and on
/// Linux when the `net.ipv4.ping_group_range` sysctl parameter is properly configured.
///
/// A background task that routes incoming replies to their waiting callers is
/// spawned lazily on the first call to `send()`. The requestor is cheap to clone:
/// every clone shares the same socket, sequence counter and reply router, so it
/// can be used from multiple threads and async tasks. When the last clone is
/// dropped, the reply router task is aborted.
///
/// # Platform Requirements
///
/// - **macOS**: Works with unprivileged sockets out of the box
/// - **Linux**: Requires `net.ipv4.ping_group_range` sysctl to allow unprivileged ICMP sockets
///
/// # Runtime Requirements
///
/// The socket is a Tokio socket and the reply router is a Tokio task, so the
/// requestor is expected to be created and used inside a Tokio runtime.
///
/// # Examples
///
/// ```rust,no_run
/// use ping_async::IcmpEchoRequestor;
/// use std::net::IpAddr;
///
/// #[tokio::main]
/// async fn main() -> std::io::Result<()> {
///     let target = "8.8.8.8".parse::<IpAddr>().unwrap();
///     let pinger = IcmpEchoRequestor::new(target, None, None, None)?;
///
///     let reply = pinger.send().await?;
///     println!("Reply: {:?}", reply);
///
///     Ok(())
/// }
/// ```
#[derive(Clone)]
pub struct IcmpEchoRequestor {
    // Shared state; cloning the requestor only bumps this reference count.
    inner: Arc<RequestorInner>,
}
65
/// State shared by all clones of an [`IcmpEchoRequestor`].
///
/// Dropping this value (i.e. dropping the last clone of the requestor) aborts the
/// reply router task if one was spawned.
struct RequestorInner {
    /// Socket on which echo requests are sent; the reply router holds another handle to it.
    socket: Arc<UdpSocket>,
    /// Address being pinged.
    target_addr: IpAddr,
    /// How long `send()` waits for a reply before reporting a timeout.
    timeout: Duration,
    /// ICMP identifier, chosen at random in `new()`, written into every request
    /// and compared against incoming replies by the router.
    identifier: u16,
    /// Sequence number for the next request; incremented once per `send()`.
    sequence: AtomicU16,
    /// In-flight requests awaiting a reply, keyed by sequence number.
    registry: RequestRegistry,
    /// Abort handle of the reply router task; empty until the first `send()`.
    router_abort: OnceLock<tokio::task::AbortHandle>,
    /// Inputs for the reply router task, prepared in `new()` and used when it is spawned.
    router_context: RouterContext,
}
76
77impl IcmpEchoRequestor {
78    /// Creates a new ICMP echo requestor for the specified target address.
79    ///
80    /// # Arguments
81    ///
82    /// * `target_addr` - The IP address to ping (IPv4 or IPv6)
83    /// * `source_addr` - Optional source IP address to bind to. Must match the IP version of `target_addr`
84    /// * `ttl` - Optional Time-To-Live value. Defaults to [`PING_DEFAULT_TTL`](crate::PING_DEFAULT_TTL)
85    /// * `timeout` - Optional timeout duration. Defaults to [`PING_DEFAULT_TIMEOUT`](crate::PING_DEFAULT_TIMEOUT)
86    ///
87    /// # Errors
88    ///
89    /// Returns an error if:
90    /// - The source address type doesn't match the target address type (IPv4 vs IPv6)
91    /// - ICMP socket creation fails (typically due to insufficient permissions)
92    /// - Socket configuration fails
93    ///
94    /// # Platform Requirements
95    ///
96    /// - **Linux**: Requires `net.ipv4.ping_group_range` sysctl parameter to allow unprivileged ICMP sockets.
97    ///   Check with: `sysctl net.ipv4.ping_group_range`
98    /// - **macOS**: Works with unprivileged sockets by default
99    ///
100    /// # Examples
101    ///
102    /// ```rust,no_run
103    /// use ping_async::IcmpEchoRequestor;
104    /// use std::net::IpAddr;
105    /// use std::time::Duration;
106    ///
107    /// // Basic usage with defaults
108    /// let pinger = IcmpEchoRequestor::new(
109    ///     "8.8.8.8".parse().unwrap(),
110    ///     None,
111    ///     None,
112    ///     None
113    /// )?;
114    ///
115    /// // With custom source address and timeout
116    /// let pinger = IcmpEchoRequestor::new(
117    ///     "2001:4860:4860::8888".parse().unwrap(),
118    ///     Some("::1".parse().unwrap()),
119    ///     Some(64),
120    ///     Some(Duration::from_millis(500))
121    /// )?;
122    /// # Ok::<(), std::io::Error>(())
123    /// ```
124    pub fn new(
125        target_addr: IpAddr,
126        source_addr: Option<IpAddr>,
127        ttl: Option<u8>,
128        timeout: Option<Duration>,
129    ) -> io::Result<Self> {
130        // Check if the target address matches the source address type
131        match (target_addr, source_addr) {
132            (IpAddr::V4(_), Some(IpAddr::V6(_))) | (IpAddr::V6(_), Some(IpAddr::V4(_))) => {
133                return Err(io::Error::new(
134                    io::ErrorKind::InvalidInput,
135                    "Source address type does not match target address type",
136                ));
137            }
138            _ => {}
139        }
140
141        let timeout = timeout.unwrap_or(PING_DEFAULT_TIMEOUT);
142        let identifier = random();
143        let sequence = AtomicU16::new(0);
144
145        let socket = Arc::new(create_socket(target_addr, source_addr, ttl)?);
146        let registry = Arc::new(Mutex::new(HashMap::new()));
147
148        // Create a context for the router task
149        let router_context = RouterContext {
150            target_addr,
151            socket: Arc::clone(&socket),
152            registry: Arc::clone(&registry),
153            failed: Arc::new(Mutex::new(None)),
154        };
155
156        Ok(IcmpEchoRequestor {
157            inner: Arc::new(RequestorInner {
158                socket,
159                target_addr,
160                timeout,
161                identifier,
162                sequence,
163                registry,
164                router_abort: OnceLock::new(),
165                router_context,
166            }),
167        })
168    }
169
170    /// Sends an ICMP echo request and waits for a reply.
171    ///
172    /// This method is async and will complete when either:
173    /// - An echo reply is received
174    /// - The configured timeout expires
175    /// - An error occurs
176    ///
177    /// The requestor uses lazy initialization - the background reply router task
178    /// is only spawned on the first call to `send()`. The requestor can be used
179    /// multiple times and is safe to use concurrently from multiple async tasks.
180    ///
181    /// # Returns
182    ///
183    /// Returns an [`IcmpEchoReply`](crate::IcmpEchoReply) containing:
184    /// - The destination IP address
185    /// - The status of the ping operation  
186    /// - The measured round-trip time
187    ///
188    /// # Errors
189    ///
190    /// Returns an error if:
191    /// - The socket send operation fails immediately
192    /// - The background router task has failed (typically due to permission loss)
193    /// - Internal communication channels fail unexpectedly
194    ///
195    /// Note that timeout and unreachable conditions are returned as successful
196    /// `IcmpEchoReply` with appropriate status values, not as errors.
197    ///
198    /// # Examples
199    ///
200    /// ```rust,no_run
201    /// use ping_async::{IcmpEchoRequestor, IcmpEchoStatus};
202    ///
203    /// #[tokio::main]
204    /// async fn main() -> std::io::Result<()> {
205    ///     let pinger = IcmpEchoRequestor::new(
206    ///         "8.8.8.8".parse().unwrap(),
207    ///         None, None, None
208    ///     )?;
209    ///     
210    ///     // Send multiple pings using the same requestor
211    ///     for i in 0..3 {
212    ///         let reply = pinger.send().await?;
213    ///         
214    ///         match reply.status() {
215    ///             IcmpEchoStatus::Success => {
216    ///                 println!("Ping {}: {:?}", i, reply.round_trip_time());
217    ///             }
218    ///             IcmpEchoStatus::TimedOut => {
219    ///                 println!("Ping {} timed out", i);
220    ///             }
221    ///             _ => {
222    ///                 println!("Ping {} failed: {:?}", i, reply.status());
223    ///             }
224    ///         }
225    ///     }
226    ///     
227    ///     Ok(())
228    /// }
229    /// ```
230    pub async fn send(&self) -> io::Result<IcmpEchoReply> {
231        // Check if router failed already
232        if let Some(failed) = self
233            .inner
234            .router_context
235            .failed
236            .lock()
237            .unwrap_or_else(|poisoned| poisoned.into_inner())
238            .take()
239        {
240            return Err(failed);
241        }
242
243        // lazy spawning
244        self.ensure_router_running();
245
246        let sequence = self.inner.sequence.fetch_add(1, Ordering::SeqCst);
247        let key = sequence;
248
249        // Use timestamp as our payload
250        let timestamp = SystemTime::now()
251            .duration_since(UNIX_EPOCH)
252            .map_err(|e| io::Error::other(format!("timestamp error: {e}")))?
253            .as_nanos() as u64;
254        let payload = timestamp.to_be_bytes();
255
256        let packet = IcmpPacket::new_echo_request(
257            self.inner.target_addr,
258            self.inner.identifier,
259            sequence,
260            &payload,
261        );
262
263        let target = SocketAddr::new(self.inner.target_addr, 0);
264        let reply_rx = match self.inner.socket.send_to(packet.as_bytes(), target).await {
265            Ok(_) => {
266                let (tx, rx) = oneshot::channel();
267
268                self.inner
269                    .registry
270                    .lock()
271                    .unwrap_or_else(|poisoned| poisoned.into_inner())
272                    .insert(key, tx);
273
274                rx
275            }
276            Err(e) => match e.kind() {
277                io::ErrorKind::NetworkUnreachable
278                | io::ErrorKind::NetworkDown
279                | io::ErrorKind::HostUnreachable => {
280                    return Ok(IcmpEchoReply::new(
281                        self.inner.target_addr,
282                        IcmpEchoStatus::Unreachable,
283                        Duration::ZERO,
284                    ));
285                }
286                _ => return Err(e),
287            },
288        };
289
290        let timeout = self.inner.timeout;
291        let target_addr = self.inner.target_addr;
292
293        tokio::select! {
294            result = reply_rx => {
295                match result {
296                    Ok(reply) => Ok(reply),
297                    Err(_) => {
298                        // Channel closed - router probably failed
299                        Err(io::Error::other("reply channel closed"))
300                    }
301                }
302            }
303            _ = time::sleep(timeout) => {
304                // Remove from registry on timeout
305                self.inner.registry.lock().unwrap_or_else(|poisoned| poisoned.into_inner()).remove(&key);
306
307                // Calculate RTT for timed out request
308                let now = SystemTime::now()
309                    .duration_since(UNIX_EPOCH)
310                    .map_err(|e| io::Error::other(format!("timestamp error: {e}")))?
311                    .as_nanos() as u64;
312                let rtt = Duration::from_nanos(now - timestamp);
313
314                Ok(IcmpEchoReply::new(
315                    target_addr,
316                    IcmpEchoStatus::TimedOut,
317                    rtt,
318                ))
319            }
320        }
321    }
322
323    fn ensure_router_running(&self) {
324        let target_addr = self.inner.router_context.target_addr;
325        let identifier = self.inner.identifier;
326        let socket = Arc::clone(&self.inner.router_context.socket);
327        let registry = Arc::clone(&self.inner.router_context.registry);
328        let failed = Arc::clone(&self.inner.router_context.failed);
329
330        self.inner.router_abort.get_or_init(|| {
331            let handle = tokio::spawn(reply_router_loop(
332                target_addr,
333                identifier,
334                socket,
335                registry,
336                failed,
337            ));
338            handle.abort_handle()
339        });
340    }
341}
342
343impl Drop for RequestorInner {
344    fn drop(&mut self) {
345        if let Some(abort_handle) = self.router_abort.get() {
346            abort_handle.abort();
347        }
348    }
349}
350
351async fn reply_router_loop(
352    target_addr: IpAddr,
353    identifier: u16,
354    socket: Arc<UdpSocket>,
355    registry: RequestRegistry,
356    failed: Arc<Mutex<Option<io::Error>>>,
357) {
358    loop {
359        let mut buf = vec![0u8; 1024];
360
361        match socket.recv(&mut buf).await {
362            Ok(size) => {
363                buf.truncate(size);
364
365                if let Some(reply_packet) = IcmpPacket::parse_reply(&buf, target_addr) {
366                    // Check if this is a reply to our request by comparing identifier, ignoring if not
367                    // identifier may be rewritten by some implementations (e.g., Linux on Docker Mac)
368                    if reply_packet.identifier() != identifier {
369                        continue;
370                    }
371
372                    // Use sequence number to find the waiting sender
373                    let key = reply_packet.sequence();
374                    let sender = registry
375                        .lock()
376                        .unwrap_or_else(|poisoned| poisoned.into_inner())
377                        .remove(&key);
378
379                    if let Some(sender) = sender {
380                        // Extract timestamp from payload to calculate RTT
381                        let payload = reply_packet.payload();
382
383                        let reply = if payload.len() >= 8 {
384                            let sent_timestamp = u64::from_be_bytes([
385                                payload[0], payload[1], payload[2], payload[3], payload[4],
386                                payload[5], payload[6], payload[7],
387                            ]);
388
389                            let now = SystemTime::now()
390                                .duration_since(UNIX_EPOCH)
391                                .unwrap_or_default()
392                                .as_nanos() as u64;
393                            let rtt = Duration::from_nanos(now.saturating_sub(sent_timestamp));
394
395                            IcmpEchoReply::new(target_addr, IcmpEchoStatus::Success, rtt)
396                        } else {
397                            // Report Unknown error if payload is too short
398                            IcmpEchoReply::new(target_addr, IcmpEchoStatus::Unknown, Duration::ZERO)
399                        };
400
401                        // Send reply to waiting thread
402                        let _ = sender.send(reply);
403                    }
404                }
405            }
406            Err(e) => {
407                match e.kind() {
408                    // Fatal errors - router cannot continue
409                    io::ErrorKind::PermissionDenied |        // Lost privileges
410                    io::ErrorKind::AddrNotAvailable |        // Address no longer available
411                    io::ErrorKind::ConnectionAborted |       // Socket forcibly closed
412                    io::ErrorKind::NotConnected => {         // Socket disconnected
413                        // Clear pending requests so they don't hang
414                        registry.lock().unwrap_or_else(|poisoned| poisoned.into_inner()).clear();
415
416                        // Mark the failed flag
417                        let mut failed_lock = failed.lock().unwrap_or_else(|poisoned| poisoned.into_inner());
418                        *failed_lock = Some(e);
419
420                        return;
421                    }
422
423                    // Continue with temporary network issues, etc.
424                    _ => continue,
425                }
426            }
427        }
428    }
429}
430
431fn create_socket(
432    target_addr: IpAddr,
433    source_addr: Option<IpAddr>,
434    ttl: Option<u8>,
435) -> io::Result<UdpSocket> {
436    let socket = match target_addr {
437        IpAddr::V4(_) => Socket::new(Domain::IPV4, Type::DGRAM, Some(Protocol::ICMPV4))?,
438        IpAddr::V6(_) => Socket::new(Domain::IPV6, Type::DGRAM, Some(Protocol::ICMPV6))?,
439    };
440    socket.set_nonblocking(true)?;
441
442    let ttl = ttl.unwrap_or(PING_DEFAULT_TTL);
443    if target_addr.is_ipv4() {
444        socket.set_ttl_v4(ttl as u32)?;
445    } else {
446        socket.set_unicast_hops_v6(ttl as u32)?;
447    }
448
449    // Bind the socket to the source address if provided
450    if let Some(source_addr) = source_addr {
451        socket.bind(&SocketAddr::new(source_addr, 0).into())?;
452    }
453
454    UdpSocket::from_std(socket.into())
455}
456
#[cfg(test)]
mod tests {
    use super::*;
    use std::io;

    /// Reports whether the lazily started reply router has been spawned for `pinger`.
    fn is_router_spawned(pinger: &IcmpEchoRequestor) -> bool {
        pinger.inner.router_abort.get().is_some()
    }

    /// The router must not exist after `new()`, must appear with the first
    /// `send()`, and must stay in place for later sends.
    #[tokio::test]
    async fn test_lazy_router_spawning() -> io::Result<()> {
        let loopback: IpAddr = "127.0.0.1".parse().unwrap();

        // Constructing the requestor alone must not start the router.
        let pinger = IcmpEchoRequestor::new(loopback, None, None, None)?;
        assert!(
            !is_router_spawned(&pinger),
            "Router should not be spawned after new()"
        );

        // The first send() is what triggers the lazy spawn.
        let first = pinger.send().await?;
        assert_eq!(first.destination(), loopback);
        assert!(
            is_router_spawned(&pinger),
            "Router should be spawned after first send()"
        );

        // Later sends reuse the router that is already running.
        let second = pinger.send().await?;
        assert_eq!(second.destination(), loopback);
        assert!(
            is_router_spawned(&pinger),
            "Router should remain spawned after subsequent sends"
        );

        Ok(())
    }
}