// tiny_ping/ping.rs

1use std::{
2    mem::MaybeUninit,
3    net::{IpAddr, SocketAddr},
4    sync::atomic::{AtomicU16, Ordering},
5    time::{Duration, Instant},
6};
7
8use tokio::time::timeout;
9
10use crate::error::{Error, Result};
11use crate::icmp::{EchoReply, EchoRequest};
12use crate::socket::AsyncSocket;
13
14pub use crate::socket::SocketType;
15
/// Echo payload size, in bytes, used until `Pinger::size` overrides it.
const DEFAULT_PAYLOAD_SIZE: usize = 56;
/// How long `Pinger::ping` waits for a matching reply unless
/// `Pinger::timeout` changes it.
const DEFAULT_TIMEOUT: Duration = Duration::from_millis(2_000);
/// Length of the marker written at the start of each request payload:
/// `"tp"`, identifier, sequence number and size, two bytes each.
const TOKEN_SIZE: usize = 8;

/// Counter added to the process id to give each new pinger a different
/// default identifier.
static NEXT_IDENT: AtomicU16 = AtomicU16::new(1);
21
22#[derive(Clone, Debug, Eq, PartialEq)]
23#[non_exhaustive]
24pub struct PingResult {
25    pub reply: EchoReply,
26    pub rtt: Duration,
27    pub socket_type: SocketType,
28}
29
30/// A Ping struct represents the state of one particular ping instance.
31#[derive(Debug, Clone)]
32pub struct Pinger {
33    host: IpAddr,
34    ident: u16,
35    size: usize,
36    timeout: Duration,
37    ttl: Option<u32>,
38    socket: AsyncSocket,
39}
40
41impl Pinger {
42    /// Creates a new raw-socket ping instance from `IpAddr`.
43    pub fn new(host: IpAddr) -> Result<Pinger> {
44        Self::with_socket_type(host, SocketType::Raw)
45    }
46
47    /// Creates a new ping instance using a specific socket type.
48    pub fn with_socket_type(host: IpAddr, socket_type: SocketType) -> Result<Pinger> {
49        Ok(Pinger {
50            host,
51            ident: default_ident(),
52            size: DEFAULT_PAYLOAD_SIZE,
53            timeout: DEFAULT_TIMEOUT,
54            ttl: None,
55            socket: AsyncSocket::new(host, socket_type)?,
56        })
57    }
58
59    /// Changes the socket type and recreates the underlying socket.
60    pub fn socket_type(&mut self, socket_type: SocketType) -> Result<&mut Pinger> {
61        let socket = AsyncSocket::new(self.host, socket_type)?;
62        if let Some(ttl) = self.ttl {
63            socket.set_ttl(self.host, ttl)?;
64        }
65        self.socket = socket;
66        Ok(self)
67    }
68
69    /// Returns the active socket type.
70    pub fn active_socket_type(&self) -> SocketType {
71        self.socket.socket_type()
72    }
73
74    /// Sets the value for the `SO_BINDTODEVICE` option on this socket.
75    ///
76    /// If a socket is bound to an interface, only packets received from that
77    /// particular interface are processed by the socket. Note that this only
78    /// works for some socket types, particularly `AF_INET` sockets.
79    ///
80    /// If `interface` is `None` or an empty string it removes the binding.
81    ///
82    /// This function is only available on Fuchsia and Linux.
83    #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
84    pub fn bind_device(&mut self, interface: Option<&[u8]>) -> Result<&mut Pinger> {
85        self.socket.bind_device(interface)?;
86        Ok(self)
87    }
88
89    /// Set the identification of ICMP.
90    pub fn ident(&mut self, val: u16) -> &mut Pinger {
91        self.ident = val;
92        self
93    }
94
95    /// Set the packet payload size in bytes. (default: 56)
96    pub fn size(&mut self, size: usize) -> &mut Pinger {
97        self.size = size;
98        self
99    }
100
101    /// Set the timeout of each ping. (default: 2s)
102    pub fn timeout(&mut self, timeout: Duration) -> &mut Pinger {
103        self.timeout = timeout;
104        self
105    }
106
107    /// Set the outgoing IPv4 TTL or IPv6 unicast hop limit.
108    pub fn ttl(&mut self, ttl: u32) -> Result<&mut Pinger> {
109        self.socket.set_ttl(self.host, ttl)?;
110        self.ttl = Some(ttl);
111        Ok(self)
112    }
113
114    async fn recv_reply(&self, seq_cnt: u16, payload: &[u8]) -> Result<EchoReply> {
115        let mut buffer = [MaybeUninit::new(0); 2048];
116        loop {
117            let (size, source) = self.socket.recv_from(&mut buffer).await?;
118            let buf = unsafe { assume_init(&buffer[..size]) };
119            let source = source.map(|addr| addr.ip()).unwrap_or(self.host);
120            let decoded = match self.socket.socket_type() {
121                SocketType::Raw if self.host.is_ipv6() => EchoReply::decode_raw(source, buf),
122                SocketType::Raw => EchoReply::decode_raw(self.host, buf),
123                SocketType::Dgram => EchoReply::decode_dgram(source, buf),
124            };
125
126            match decoded {
127                Ok(reply) if self.reply_matches(&reply, seq_cnt, payload) => return Ok(reply),
128                Ok(_) => continue,
129                Err(Error::InvalidPacket)
130                | Err(Error::NotEchoReply)
131                | Err(Error::NotV6EchoReply)
132                | Err(Error::OtherICMP)
133                | Err(Error::UnknownProtocol) => continue,
134                Err(e) => return Err(e),
135            }
136        }
137    }
138
139    fn reply_matches(&self, reply: &EchoReply, seq_cnt: u16, payload: &[u8]) -> bool {
140        if reply.sequence != seq_cnt {
141            return false;
142        }
143
144        if self.socket.socket_type() == SocketType::Raw && reply.identifier != self.ident {
145            return false;
146        }
147
148        payload.is_empty() || reply.payload == payload
149    }
150
151    /// Send a ping request with sequence number.
152    pub async fn ping(&self, seq_cnt: u16) -> Result<PingResult> {
153        let payload = request_payload(self.ident, seq_cnt, self.size);
154        let packet =
155            EchoRequest::new(self.host, self.ident, seq_cnt).encode_with_payload(&payload)?;
156        let sock_addr = SocketAddr::new(self.host, 0);
157
158        let sent = Instant::now();
159        let size = self.socket.send_to(&packet, &sock_addr.into()).await?;
160        if size != packet.len() {
161            return Err(Error::InvalidSize);
162        }
163
164        let reply = timeout(self.timeout, self.recv_reply(seq_cnt, &payload))
165            .await
166            .map_err(|_| Error::Timeout)??;
167
168        Ok(PingResult {
169            reply,
170            rtt: sent.elapsed(),
171            socket_type: self.socket.socket_type(),
172        })
173    }
174}
175
176fn default_ident() -> u16 {
177    let pid = std::process::id() as u16;
178    let next = NEXT_IDENT.fetch_add(1, Ordering::Relaxed);
179    pid.wrapping_add(next)
180}
181
182fn request_payload(ident: u16, seq_cnt: u16, size: usize) -> Vec<u8> {
183    let mut payload = vec![0; size];
184    let token = [
185        b't',
186        b'p',
187        (ident >> 8) as u8,
188        ident as u8,
189        (seq_cnt >> 8) as u8,
190        seq_cnt as u8,
191        (size >> 8) as u8,
192        size as u8,
193    ];
194    let len = payload.len().min(TOKEN_SIZE);
195    payload[..len].copy_from_slice(&token[..len]);
196    payload
197}
198
/// Reinterprets a slice of `MaybeUninit<u8>` as a slice of initialised bytes.
///
/// # Safety
///
/// Every element of `buf` must be initialised. In this module the receive
/// buffer is filled with zeros before `recv_from` writes into it, so any
/// prefix of it satisfies this.
unsafe fn assume_init(buf: &[MaybeUninit<u8>]) -> &[u8] {
    let start = buf.as_ptr().cast::<u8>();
    // SAFETY: `MaybeUninit<u8>` has the same size and alignment as `u8`; the
    // pointer and length come from a live slice borrowed for the returned
    // lifetime; and the caller guarantees every element is initialised.
    unsafe { std::slice::from_raw_parts(start, buf.len()) }
}
207
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn request_payload_respects_size() {
        assert_eq!(request_payload(1, 2, 0), Vec::<u8>::new());
        assert_eq!(request_payload(1, 2, 4), vec![b't', b'p', 0, 1]);
        assert_eq!(request_payload(1, 2, 8), vec![b't', b'p', 0, 1, 0, 2, 0, 8]);
    }

    /// Bytes past the marker must be zero so replies can be compared exactly.
    #[test]
    fn request_payload_zero_pads_after_token() {
        let payload = request_payload(0x0102, 0x0304, 12);
        assert_eq!(payload.len(), 12);
        assert_eq!(payload[..TOKEN_SIZE], [b't', b'p', 1, 2, 3, 4, 0, 12]);
        assert!(payload[TOKEN_SIZE..].iter().all(|&b| b == 0));
    }

    /// Only the low 16 bits of the size fit in the marker.
    #[test]
    fn request_payload_token_keeps_low_16_bits_of_size() {
        let payload = request_payload(0, 0, 0x1_0203);
        assert_eq!(payload.len(), 0x1_0203);
        assert_eq!(payload[6..TOKEN_SIZE], [0x02, 0x03]);
    }

    #[test]
    fn default_ident_changes_between_calls() {
        assert_ne!(default_ident(), default_ident());
    }
}