Skip to main content

tiny_ping/
ping.rs

1use std::{
2    mem::MaybeUninit,
3    net::{IpAddr, SocketAddr},
4    sync::atomic::{AtomicU16, Ordering},
5    time::{Duration, Instant},
6};
7
8use tokio::time::timeout;
9
10use crate::error::{Error, Result};
11use crate::icmp::{EchoReply, EchoRequest};
12use crate::socket::AsyncSocket;
13
14pub use crate::socket::SocketType;
15
/// Number of payload bytes carried by each echo request unless overridden.
const DEFAULT_PAYLOAD_SIZE: usize = 56;
/// How long to wait for a matching echo reply unless overridden.
const DEFAULT_TIMEOUT: Duration = Duration::from_millis(2_000);
/// Length of the identifying token written at the start of each payload:
/// the magic `b"tp"` followed by identifier, sequence and size (2 bytes each).
const TOKEN_SIZE: usize = 8;

/// Process-wide counter mixed into the default ICMP identifier of every new pinger.
static NEXT_IDENT: AtomicU16 = AtomicU16::new(1);
21
22#[derive(Clone, Debug, Eq, PartialEq)]
23#[non_exhaustive]
24pub struct PingResult {
25    pub reply: EchoReply,
26    pub rtt: Duration,
27    pub socket_type: SocketType,
28}
29
30/// A Ping struct represents the state of one particular ping instance.
31#[derive(Debug, Clone)]
32pub struct Pinger {
33    target: SocketAddr,
34    ident: u16,
35    size: usize,
36    timeout: Duration,
37    ttl: Option<u32>,
38    socket: AsyncSocket,
39}
40
41impl Pinger {
42    /// Creates a new raw-socket ping instance from `IpAddr`.
43    pub fn new(host: IpAddr) -> Result<Pinger> {
44        Self::with_socket_type(host, SocketType::Raw)
45    }
46
47    /// Creates a new ping instance using a specific socket type.
48    pub fn with_socket_type(host: IpAddr, socket_type: SocketType) -> Result<Pinger> {
49        Self::with_socket_addr(SocketAddr::new(host, 0), socket_type)
50    }
51
52    /// Creates a new ping instance using a specific socket address and socket type.
53    ///
54    /// The port is ignored. For IPv6, callers can use this to provide a
55    /// `SocketAddrV6` scope ID, for example when targeting link-local multicast.
56    pub fn with_socket_addr(target: SocketAddr, socket_type: SocketType) -> Result<Pinger> {
57        Ok(Pinger {
58            target,
59            ident: default_ident(),
60            size: DEFAULT_PAYLOAD_SIZE,
61            timeout: DEFAULT_TIMEOUT,
62            ttl: None,
63            socket: AsyncSocket::new(target.ip(), socket_type)?,
64        })
65    }
66
67    /// Changes the socket type and recreates the underlying socket.
68    pub fn socket_type(&mut self, socket_type: SocketType) -> Result<&mut Pinger> {
69        let socket = AsyncSocket::new(self.target.ip(), socket_type)?;
70        if let Some(ttl) = self.ttl {
71            socket.set_ttl(self.target.ip(), ttl)?;
72        }
73        self.socket = socket;
74        Ok(self)
75    }
76
77    /// Returns the active socket type.
78    pub fn active_socket_type(&self) -> SocketType {
79        self.socket.socket_type()
80    }
81
82    /// Sets the value for the `SO_BINDTODEVICE` option on this socket.
83    ///
84    /// If a socket is bound to an interface, only packets received from that
85    /// particular interface are processed by the socket. Note that this only
86    /// works for some socket types, particularly `AF_INET` sockets.
87    ///
88    /// If `interface` is `None` or an empty string it removes the binding.
89    ///
90    /// This function is only available on Fuchsia and Linux.
91    #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
92    pub fn bind_device(&mut self, interface: Option<&[u8]>) -> Result<&mut Pinger> {
93        self.socket.bind_device(interface)?;
94        Ok(self)
95    }
96
97    /// Set the identification of ICMP.
98    pub fn ident(&mut self, val: u16) -> &mut Pinger {
99        self.ident = val;
100        self
101    }
102
103    /// Set the packet payload size in bytes. (default: 56)
104    pub fn size(&mut self, size: usize) -> &mut Pinger {
105        self.size = size;
106        self
107    }
108
109    /// Set the timeout of each ping. (default: 2s)
110    pub fn timeout(&mut self, timeout: Duration) -> &mut Pinger {
111        self.timeout = timeout;
112        self
113    }
114
115    /// Set the outgoing IPv4 TTL or IPv6 unicast hop limit.
116    pub fn ttl(&mut self, ttl: u32) -> Result<&mut Pinger> {
117        self.socket.set_ttl(self.target.ip(), ttl)?;
118        self.ttl = Some(ttl);
119        Ok(self)
120    }
121
122    async fn recv_reply(&self, seq_cnt: u16, payload: &[u8]) -> Result<EchoReply> {
123        let mut buffer = [MaybeUninit::new(0); 2048];
124        loop {
125            let (size, source) = self.socket.recv_from(&mut buffer).await?;
126            let buf = unsafe { assume_init(&buffer[..size]) };
127            let source = source.map(|addr| addr.ip()).unwrap_or(self.target.ip());
128            let decoded = match self.socket.socket_type() {
129                SocketType::Raw if self.target.ip().is_ipv6() => EchoReply::decode_raw(source, buf),
130                SocketType::Raw => EchoReply::decode_raw(self.target.ip(), buf),
131                SocketType::Dgram => EchoReply::decode_dgram(source, buf),
132            };
133
134            match decoded {
135                Ok(reply) if self.reply_matches(&reply, seq_cnt, payload) => return Ok(reply),
136                Ok(_) => continue,
137                Err(Error::InvalidPacket)
138                | Err(Error::NotEchoReply)
139                | Err(Error::NotV6EchoReply)
140                | Err(Error::OtherICMP)
141                | Err(Error::UnknownProtocol) => continue,
142                Err(e) => return Err(e),
143            }
144        }
145    }
146
147    fn reply_matches(&self, reply: &EchoReply, seq_cnt: u16, payload: &[u8]) -> bool {
148        if reply.sequence != seq_cnt {
149            return false;
150        }
151
152        if self.socket.socket_type() == SocketType::Raw && reply.identifier != self.ident {
153            return false;
154        }
155
156        payload.is_empty() || reply.payload == payload
157    }
158
159    async fn send_request(&self, seq_cnt: u16, payload: &[u8]) -> Result<Instant> {
160        let packet =
161            EchoRequest::new(self.target.ip(), self.ident, seq_cnt).encode_with_payload(payload)?;
162
163        let sent = Instant::now();
164        let size = self.socket.send_to(&packet, &self.target.into()).await?;
165        if size != packet.len() {
166            return Err(Error::InvalidSize);
167        }
168
169        Ok(sent)
170    }
171
172    /// Send a ping request with sequence number.
173    pub async fn ping(&self, seq_cnt: u16) -> Result<PingResult> {
174        let payload = request_payload(self.ident, seq_cnt, self.size);
175        let sent = self.send_request(seq_cnt, &payload).await?;
176
177        let reply = timeout(self.timeout, self.recv_reply(seq_cnt, &payload))
178            .await
179            .map_err(|_| Error::Timeout)??;
180
181        Ok(PingResult {
182            reply,
183            rtt: sent.elapsed(),
184            socket_type: self.socket.socket_type(),
185        })
186    }
187
188    /// Send one ping request and collect all matching replies until timeout.
189    ///
190    /// This is useful for multicast targets where more than one host can reply
191    /// to the same echo request. Unlike [`Pinger::ping`], a timeout after the
192    /// request is sent is not an error; it ends collection and returns the
193    /// replies seen so far.
194    pub async fn ping_replies(&self, seq_cnt: u16) -> Result<Vec<PingResult>> {
195        let payload = request_payload(self.ident, seq_cnt, self.size);
196        let sent = self.send_request(seq_cnt, &payload).await?;
197        let deadline = sent + self.timeout;
198        let mut replies = Vec::new();
199
200        while let Some(remaining) = deadline.checked_duration_since(Instant::now()) {
201            let reply = match timeout(remaining, self.recv_reply(seq_cnt, &payload)).await {
202                Ok(reply) => reply?,
203                Err(_) => break,
204            };
205
206            replies.push(PingResult {
207                reply,
208                rtt: sent.elapsed(),
209                socket_type: self.socket.socket_type(),
210            });
211        }
212
213        Ok(replies)
214    }
215}
216
217fn default_ident() -> u16 {
218    let pid = std::process::id() as u16;
219    let next = NEXT_IDENT.fetch_add(1, Ordering::Relaxed);
220    pid.wrapping_add(next)
221}
222
223fn request_payload(ident: u16, seq_cnt: u16, size: usize) -> Vec<u8> {
224    let mut payload = vec![0; size];
225    let token = [
226        b't',
227        b'p',
228        (ident >> 8) as u8,
229        ident as u8,
230        (seq_cnt >> 8) as u8,
231        seq_cnt as u8,
232        (size >> 8) as u8,
233        size as u8,
234    ];
235    let len = payload.len().min(TOKEN_SIZE);
236    payload[..len].copy_from_slice(&token[..len]);
237    payload
238}
239
/// Views a slice of `MaybeUninit<u8>` as ordinary initialised bytes.
///
/// # Safety
///
/// Every element of `buf` must be initialised. Callers pass only the prefix
/// reported by `recv_from`; `socket2` initialises exactly that many bytes.
unsafe fn assume_init(buf: &[MaybeUninit<u8>]) -> &[u8] {
    // SAFETY: `MaybeUninit<u8>` has the same size and alignment as `u8`, and
    // the caller guarantees all `buf.len()` elements are initialised, so this
    // pointer/length pair is a valid `[u8]` borrowed for the lifetime of `buf`.
    unsafe { std::slice::from_raw_parts(buf.as_ptr().cast::<u8>(), buf.len()) }
}
248
#[cfg(test)]
mod tests {
    use super::*;

    /// The payload length always equals `size`, truncating the token if needed.
    #[test]
    fn request_payload_respects_size() {
        assert_eq!(request_payload(1, 2, 0), Vec::<u8>::new());
        assert_eq!(request_payload(1, 2, 4), vec![b't', b'p', 0, 1]);
        assert_eq!(request_payload(1, 2, 8), vec![b't', b'p', 0, 1, 0, 2, 0, 8]);
    }

    /// Bytes after the token are zero-filled.
    #[test]
    fn request_payload_zero_pads_after_token() {
        assert_eq!(
            request_payload(1, 2, 10),
            vec![b't', b'p', 0, 1, 0, 2, 0, 10, 0, 0]
        );
    }

    /// Identifier, sequence and size are encoded big-endian in the token.
    #[test]
    fn request_payload_token_is_big_endian() {
        let payload = request_payload(0x0102, 0x0304, 0x0506);
        assert_eq!(payload.len(), 0x0506);
        assert_eq!(&payload[..TOKEN_SIZE], &[b't', b'p', 1, 2, 3, 4, 5, 6]);
        assert!(payload[TOKEN_SIZE..].iter().all(|&b| b == 0));
    }
}