1use std::{
2 mem::MaybeUninit,
3 net::{IpAddr, SocketAddr},
4 sync::atomic::{AtomicU16, Ordering},
5 time::{Duration, Instant},
6};
7
8use tokio::time::timeout;
9
10use crate::error::{Error, Result};
11use crate::icmp::{EchoReply, EchoRequest};
12use crate::socket::AsyncSocket;
13
14pub use crate::socket::SocketType;
15
/// Number of payload bytes carried by each echo request unless `Pinger::size` overrides it.
const DEFAULT_PAYLOAD_SIZE: usize = 56;
/// How long `Pinger::ping` waits for a matching reply unless `Pinger::timeout` overrides it.
const DEFAULT_TIMEOUT: Duration = Duration::from_millis(2_000);
/// Length of the identifying token written at the start of every request payload.
const TOKEN_SIZE: usize = 8;

/// Process-wide counter added to the process id to derive each new pinger's default
/// identifier; it increments on every call to `default_ident`.
static NEXT_IDENT: AtomicU16 = AtomicU16::new(1);
21
22#[derive(Clone, Debug, Eq, PartialEq)]
23#[non_exhaustive]
24pub struct PingResult {
25 pub reply: EchoReply,
26 pub rtt: Duration,
27 pub socket_type: SocketType,
28}
29
30#[derive(Debug, Clone)]
32pub struct Pinger {
33 host: IpAddr,
34 ident: u16,
35 size: usize,
36 timeout: Duration,
37 ttl: Option<u32>,
38 socket: AsyncSocket,
39}
40
41impl Pinger {
42 pub fn new(host: IpAddr) -> Result<Pinger> {
44 Self::with_socket_type(host, SocketType::Raw)
45 }
46
47 pub fn with_socket_type(host: IpAddr, socket_type: SocketType) -> Result<Pinger> {
49 Ok(Pinger {
50 host,
51 ident: default_ident(),
52 size: DEFAULT_PAYLOAD_SIZE,
53 timeout: DEFAULT_TIMEOUT,
54 ttl: None,
55 socket: AsyncSocket::new(host, socket_type)?,
56 })
57 }
58
59 pub fn socket_type(&mut self, socket_type: SocketType) -> Result<&mut Pinger> {
61 let socket = AsyncSocket::new(self.host, socket_type)?;
62 if let Some(ttl) = self.ttl {
63 socket.set_ttl(self.host, ttl)?;
64 }
65 self.socket = socket;
66 Ok(self)
67 }
68
69 pub fn active_socket_type(&self) -> SocketType {
71 self.socket.socket_type()
72 }
73
74 #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
84 pub fn bind_device(&mut self, interface: Option<&[u8]>) -> Result<&mut Pinger> {
85 self.socket.bind_device(interface)?;
86 Ok(self)
87 }
88
89 pub fn ident(&mut self, val: u16) -> &mut Pinger {
91 self.ident = val;
92 self
93 }
94
95 pub fn size(&mut self, size: usize) -> &mut Pinger {
97 self.size = size;
98 self
99 }
100
101 pub fn timeout(&mut self, timeout: Duration) -> &mut Pinger {
103 self.timeout = timeout;
104 self
105 }
106
107 pub fn ttl(&mut self, ttl: u32) -> Result<&mut Pinger> {
109 self.socket.set_ttl(self.host, ttl)?;
110 self.ttl = Some(ttl);
111 Ok(self)
112 }
113
114 async fn recv_reply(&self, seq_cnt: u16, payload: &[u8]) -> Result<EchoReply> {
115 let mut buffer = [MaybeUninit::new(0); 2048];
116 loop {
117 let (size, source) = self.socket.recv_from(&mut buffer).await?;
118 let buf = unsafe { assume_init(&buffer[..size]) };
119 let source = source.map(|addr| addr.ip()).unwrap_or(self.host);
120 let decoded = match self.socket.socket_type() {
121 SocketType::Raw if self.host.is_ipv6() => EchoReply::decode_raw(source, buf),
122 SocketType::Raw => EchoReply::decode_raw(self.host, buf),
123 SocketType::Dgram => EchoReply::decode_dgram(source, buf),
124 };
125
126 match decoded {
127 Ok(reply) if self.reply_matches(&reply, seq_cnt, payload) => return Ok(reply),
128 Ok(_) => continue,
129 Err(Error::InvalidPacket)
130 | Err(Error::NotEchoReply)
131 | Err(Error::NotV6EchoReply)
132 | Err(Error::OtherICMP)
133 | Err(Error::UnknownProtocol) => continue,
134 Err(e) => return Err(e),
135 }
136 }
137 }
138
139 fn reply_matches(&self, reply: &EchoReply, seq_cnt: u16, payload: &[u8]) -> bool {
140 if reply.sequence != seq_cnt {
141 return false;
142 }
143
144 if self.socket.socket_type() == SocketType::Raw && reply.identifier != self.ident {
145 return false;
146 }
147
148 payload.is_empty() || reply.payload == payload
149 }
150
151 pub async fn ping(&self, seq_cnt: u16) -> Result<PingResult> {
153 let payload = request_payload(self.ident, seq_cnt, self.size);
154 let packet =
155 EchoRequest::new(self.host, self.ident, seq_cnt).encode_with_payload(&payload)?;
156 let sock_addr = SocketAddr::new(self.host, 0);
157
158 let sent = Instant::now();
159 let size = self.socket.send_to(&packet, &sock_addr.into()).await?;
160 if size != packet.len() {
161 return Err(Error::InvalidSize);
162 }
163
164 let reply = timeout(self.timeout, self.recv_reply(seq_cnt, &payload))
165 .await
166 .map_err(|_| Error::Timeout)??;
167
168 Ok(PingResult {
169 reply,
170 rtt: sent.elapsed(),
171 socket_type: self.socket.socket_type(),
172 })
173 }
174}
175
176fn default_ident() -> u16 {
177 let pid = std::process::id() as u16;
178 let next = NEXT_IDENT.fetch_add(1, Ordering::Relaxed);
179 pid.wrapping_add(next)
180}
181
182fn request_payload(ident: u16, seq_cnt: u16, size: usize) -> Vec<u8> {
183 let mut payload = vec![0; size];
184 let token = [
185 b't',
186 b'p',
187 (ident >> 8) as u8,
188 ident as u8,
189 (seq_cnt >> 8) as u8,
190 seq_cnt as u8,
191 (size >> 8) as u8,
192 size as u8,
193 ];
194 let len = payload.len().min(TOKEN_SIZE);
195 payload[..len].copy_from_slice(&token[..len]);
196 payload
197}
198
/// Reinterprets a slice of `MaybeUninit<u8>` as a slice of initialised bytes.
///
/// # Safety
///
/// Every element of `buf` must be initialised. `MaybeUninit<u8>` has the same
/// layout as `u8`, so only initialisation must be guaranteed by the caller.
unsafe fn assume_init(buf: &[MaybeUninit<u8>]) -> &[u8] {
    let first = buf.as_ptr().cast::<u8>();
    // SAFETY: the pointer and length come from a valid slice whose element type is
    // layout-compatible with `u8`; the caller guarantees all elements are initialised.
    unsafe { std::slice::from_raw_parts(first, buf.len()) }
}
207
#[cfg(test)]
mod tests {
    use super::*;

    /// The payload is exactly `size` bytes and carries as much of the token as fits.
    #[test]
    fn request_payload_respects_size() {
        let cases: [(usize, &[u8]); 3] = [
            (0, &[]),
            (4, &[b't', b'p', 0, 1]),
            (8, &[b't', b'p', 0, 1, 0, 2, 0, 8]),
        ];
        for (size, expected) in cases {
            assert_eq!(request_payload(1, 2, size), expected, "size = {}", size);
        }
    }
}