1pub use icmp_client;
2pub use icmp_packet;
3
4use core::time::Duration;
5use std::{
6 collections::HashMap,
7 io::{Error as IoError, ErrorKind as IoErrorKind},
8 net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6},
9 sync::Arc,
10 time::Instant,
11};
12
13use icmp_client::{AsyncClient, AsyncClientWithConfigError, Config as ClientConfig};
14use icmp_packet::{
15 icmpv4::ParseError as Icmpv4ParseError, icmpv6::ParseError as Icmpv6ParseError, Icmp, Icmpv4,
16 Icmpv6, PayloadLengthDelimitedEchoRequest,
17};
18use tokio::sync::{
19 mpsc::{self, Sender},
20 Mutex,
21};
22use tracing::{event, Level};
23
24type V4RecvFromMap =
26 Arc<Mutex<HashMap<SocketAddr, Sender<(Result<Icmpv4, Icmpv4ParseError>, Instant)>>>>;
27type V6RecvFromMap =
28 Arc<Mutex<HashMap<SocketAddr, Sender<(Result<Icmpv6, Icmpv6ParseError>, Instant)>>>>;
29
30pub struct PingClient<C>
32where
33 C: AsyncClient,
34{
35 v4_client: Option<Arc<C>>,
36 v6_client: Option<Arc<C>>,
37 v4_recv_from_map: V4RecvFromMap,
38 v6_recv_from_map: V6RecvFromMap,
39}
40
41impl<C> core::fmt::Debug for PingClient<C>
42where
43 C: AsyncClient,
44{
45 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
46 f.debug_struct("PingClient").finish()
47 }
48}
49
50impl<C> Clone for PingClient<C>
51where
52 C: AsyncClient,
53{
54 fn clone(&self) -> Self {
55 Self {
56 v4_client: self.v4_client.clone(),
57 v6_client: self.v6_client.clone(),
58 v4_recv_from_map: self.v4_recv_from_map.clone(),
59 v6_recv_from_map: self.v6_recv_from_map.clone(),
60 }
61 }
62}
63
64impl<C> PingClient<C>
65where
66 C: AsyncClient,
67{
68 pub fn new(
69 v4_client_config: Option<ClientConfig>,
70 v6_client_config: Option<ClientConfig>,
71 ) -> Result<Self, AsyncClientWithConfigError> {
72 let v4_client = if let Some(mut v4_client_config) = v4_client_config {
73 if v4_client_config.is_ipv6() {
74 return Err(IoError::new(IoErrorKind::Other, "v4_client_config invalid").into());
75 }
76 if v4_client_config.bind.is_none() {
77 v4_client_config.bind =
78 Some(SocketAddrV4::new(Ipv4Addr::new(0, 0, 0, 0), 0).into());
79 }
80
81 Some(Arc::new(C::with_config(&v4_client_config)?))
82 } else {
83 None
84 };
85
86 let v6_client = if let Some(mut v6_client_config) = v6_client_config {
87 if !v6_client_config.is_ipv6() {
88 return Err(IoError::new(IoErrorKind::Other, "v4_client_config invalid").into());
89 }
90 if v6_client_config.bind.is_none() {
91 v6_client_config.bind =
92 Some(SocketAddrV6::new(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0), 0, 0, 0).into());
93 }
94
95 Some(Arc::new(C::with_config(&v6_client_config)?))
96 } else {
97 None
98 };
99
100 let v4_recv_from_map = Arc::new(Mutex::new(HashMap::new()));
101 let v6_recv_from_map = Arc::new(Mutex::new(HashMap::new()));
102
103 Ok(Self {
104 v4_client,
105 v6_client,
106 v4_recv_from_map,
107 v6_recv_from_map,
108 })
109 }
110
111 pub async fn handle_v4_recv_from(&self) {
113 let v4_client = match self.v4_client.as_ref() {
114 Some(x) => x,
115 None => return,
116 };
117
118 let mut buf = [0; 2048];
119 let bytes_present_map: Arc<Mutex<HashMap<SocketAddr, Vec<u8>>>> =
120 Arc::new(Mutex::new(HashMap::new()));
121
122 loop {
123 match v4_client.recv_from(&mut buf).await {
124 Ok((n, addr)) => {
125 let instant_end = Instant::now();
126 let bytes_read = buf[..n].to_owned();
127
128 let v4_recv_from_map = self.v4_recv_from_map.clone();
129 let bytes_present_map = bytes_present_map.clone();
130
131 tokio::spawn(async move {
132 let bytes = if let Some(mut bytes_present) =
133 bytes_present_map.lock().await.remove(&addr)
134 {
135 bytes_present.extend_from_slice(&bytes_read);
136 bytes_present
137 } else {
138 bytes_read
139 };
140
141 match Icmpv4::parse_from_packet_bytes(&bytes) {
142 Ok(Some(icmpv4)) => {
143 if let Some(tx) = v4_recv_from_map.lock().await.remove(&addr) {
144 if let Err(err) = tx.try_send((Ok(icmpv4), instant_end)) {
145 event!(
146 Level::ERROR,
147 "tx.send failed, err:{err} addr:{addr}"
148 );
149 }
150 } else {
151 event!(
152 Level::WARN,
153 "v4_recv_from_map.remove None, addr:{addr}"
154 );
155 }
156 }
157 Ok(None) => {
158 bytes_present_map.lock().await.insert(addr, bytes);
159 }
160 Err(err) => {
161 if let Some(tx) = v4_recv_from_map.lock().await.remove(&addr) {
162 if let Err(err) = tx.try_send((Err(err), instant_end)) {
163 event!(
164 Level::ERROR,
165 "tx.send failed, err:{err} addr:{addr}"
166 );
167 }
168 } else {
169 event!(
170 Level::WARN,
171 "v4_recv_from_map.remove None, addr:{addr}"
172 );
173 }
174 }
175 }
176 });
177 }
178 Err(err) => {
179 event!(Level::ERROR, "v4_client.recv_from failed, err:{err}");
180 }
181 }
182 }
183 }
184
185 pub async fn handle_v6_recv_from(&self) {
186 let v6_client = match self.v6_client.as_ref() {
187 Some(x) => x,
188 None => return,
189 };
190
191 let mut buf = [0; 2048];
192 let bytes_present_map: Arc<Mutex<HashMap<SocketAddr, Vec<u8>>>> =
193 Arc::new(Mutex::new(HashMap::new()));
194
195 loop {
196 match v6_client.recv_from(&mut buf).await {
197 Ok((n, addr)) => {
198 let instant_end = Instant::now();
199 let bytes_read = buf[..n].to_owned();
200
201 let v6_recv_from_map = self.v6_recv_from_map.clone();
202 let bytes_present_map = bytes_present_map.clone();
203
204 tokio::spawn(async move {
205 let bytes = if let Some(mut bytes_present) =
206 bytes_present_map.lock().await.remove(&addr)
207 {
208 bytes_present.extend_from_slice(&bytes_read);
209 bytes_present
210 } else {
211 bytes_read
212 };
213
214 match Icmpv6::parse_from_packet_bytes(&bytes) {
215 Ok(Some(icmpv6)) => {
216 if let Some(tx) = v6_recv_from_map.lock().await.remove(&addr) {
217 if let Err(err) = tx.try_send((Ok(icmpv6), instant_end)) {
218 event!(
219 Level::ERROR,
220 "tx.send failed, err:{err} addr:{addr}"
221 );
222 }
223 } else {
224 event!(
225 Level::WARN,
226 "v6_recv_from_map.remove None, addr:{addr}"
227 );
228 }
229 }
230 Ok(None) => {
231 bytes_present_map.lock().await.insert(addr, bytes);
232 }
233 Err(err) => {
234 if let Some(tx) = v6_recv_from_map.lock().await.remove(&addr) {
235 if let Err(err) = tx.try_send((Err(err), instant_end)) {
236 event!(
237 Level::ERROR,
238 "tx.send failed, err:{err} addr:{addr}"
239 );
240 }
241 } else {
242 event!(
243 Level::WARN,
244 "v6_recv_from_map.remove None, addr:{addr}"
245 );
246 }
247 }
248 }
249 });
250 }
251 Err(err) => {
252 event!(Level::ERROR, "v6_client.recv_from failed, err:{err}");
253 }
254 }
255 }
256 }
257
258 pub async fn ping(
259 &self,
260 ip: IpAddr,
261 identifier: Option<u16>,
262 sequence_number: Option<u16>,
263 payload: impl AsRef<[u8]>,
264 timeout_dur: Duration,
265 ) -> Result<(Icmp, Duration), PingError> {
266 let echo_request = PayloadLengthDelimitedEchoRequest::new(
268 identifier.map(Into::into),
269 sequence_number.map(Into::into),
270 payload,
271 );
272 let echo_request_bytes = match ip {
273 IpAddr::V4(_) => echo_request.render_v4_packet_bytes(),
274 IpAddr::V6(_) => echo_request.render_v6_packet_bytes(),
275 };
276
277 let rx = match ip {
279 IpAddr::V4(_) => {
280 let (tx, rx) = mpsc::channel(1);
281
282 self.v4_recv_from_map
283 .lock()
284 .await
285 .insert((ip, 0).into(), tx);
286
287 Ok(rx)
288 }
289 IpAddr::V6(_) => {
290 let (tx, rx) = mpsc::channel(1);
291
292 self.v6_recv_from_map
293 .lock()
294 .await
295 .insert((ip, 0).into(), tx);
296
297 Err(rx)
298 }
299 };
300
301 let client = match ip {
303 IpAddr::V4(_) => self.v4_client.as_ref().ok_or(PingError::NoV4Client)?,
304 IpAddr::V6(_) => self.v6_client.as_ref().ok_or(PingError::NoV6Client)?,
305 };
306
307 let instant_begin = Instant::now();
308
309 {
310 let mut n_write = 0;
311 while !echo_request_bytes[n_write..].is_empty() {
312 let n = client
313 .send_to(&echo_request_bytes[n_write..], (ip, 0))
314 .await
315 .map_err(PingError::Send)?;
316 n_write += n;
317
318 if n == 0 {
319 return Err(PingError::Send(IoErrorKind::WriteZero.into()));
320 }
321 }
322 }
323
324 match rx {
326 Ok(mut rx) => {
327 match tokio::time::timeout(
328 tokio::time::Duration::from_millis(timeout_dur.as_millis() as u64),
329 rx.recv(),
330 )
331 .await
332 {
333 Ok(Some((Ok(icmpv4), instant_end))) => Ok((
334 Icmp::V4(icmpv4),
335 instant_end
336 .checked_duration_since(instant_begin)
337 .unwrap_or(instant_begin.elapsed()),
338 )),
339 Ok(Some((Err(err), _))) => Err(PingError::Icmpv4ParseError(err)),
340 Ok(None) => Err(PingError::Unknown("rx.recv None".to_string())),
341 Err(_) => Err(PingError::RecvTimedOut),
342 }
343 }
344 Err(mut rx) => {
345 match tokio::time::timeout(
346 tokio::time::Duration::from_millis(timeout_dur.as_millis() as u64),
347 rx.recv(),
348 )
349 .await
350 {
351 Ok(Some((Ok(icmpv6), instant_end))) => Ok((
352 Icmp::V6(icmpv6),
353 instant_end
354 .checked_duration_since(instant_begin)
355 .unwrap_or(instant_begin.elapsed()),
356 )),
357 Ok(Some((Err(err), _))) => Err(PingError::Icmpv6ParseError(err)),
358 Ok(None) => Err(PingError::Unknown("rx.recv None".to_string())),
359 Err(_) => Err(PingError::RecvTimedOut),
360 }
361 }
362 }
363 }
364
365 pub async fn ping_v4(
366 &self,
367 ip: Ipv4Addr,
368 identifier: Option<u16>,
369 sequence_number: Option<u16>,
370 payload: impl AsRef<[u8]>,
371 timeout_dur: Duration,
372 ) -> Result<(Icmpv4, Duration), PingError> {
373 let (icmp, dur) = self
374 .ping(ip.into(), identifier, sequence_number, payload, timeout_dur)
375 .await?;
376 match icmp {
377 Icmp::V4(icmp) => Ok((icmp, dur)),
378 Icmp::V6(_) => Err(PingError::Unknown("unreachable".to_string())),
379 }
380 }
381
382 pub async fn ping_v6(
383 &self,
384 ip: Ipv6Addr,
385 identifier: Option<u16>,
386 sequence_number: Option<u16>,
387 payload: impl AsRef<[u8]>,
388 timeout_dur: Duration,
389 ) -> Result<(Icmpv6, Duration), PingError> {
390 let (icmp, dur) = self
391 .ping(ip.into(), identifier, sequence_number, payload, timeout_dur)
392 .await?;
393 match icmp {
394 Icmp::V4(_) => Err(PingError::Unknown("unreachable".to_string())),
395 Icmp::V6(icmp) => Ok((icmp, dur)),
396 }
397 }
398}
399
400#[derive(Debug)]
402pub enum PingError {
403 NoV4Client,
404 NoV6Client,
405 Send(IoError),
406 Icmpv4ParseError(Icmpv4ParseError),
407 Icmpv6ParseError(Icmpv6ParseError),
408 RecvTimedOut,
409 Unknown(String),
410}
411impl core::fmt::Display for PingError {
412 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
413 write!(f, "{self:?}")
414 }
415}
416impl std::error::Error for PingError {}
417
#[cfg(test)]
mod tests {
    use super::*;

    type TokioPingClient = PingClient<icmp_client::impl_tokio::Client>;

    /// Sends one 32-byte echo request to `ip` with a 2 s timeout, printing the
    /// result and panicking on error.
    async fn ping_once(client: &TokioPingClient, ip: IpAddr) {
        let outcome = client
            .ping(ip, None, None, vec![0; 32], Duration::from_secs(2))
            .await;
        match outcome {
            Ok((icmp, dur)) => println!("{dur:?} {icmp:?}"),
            Err(err) => panic!("{err}"),
        }
    }

    /// True on CentOS 7.0.0, the only host on which the IPv6 test tolerates
    /// `IcmpV6ProtocolNotSupported`.
    fn is_centos_7() -> bool {
        let info = os_info::get();
        info.os_type() == os_info::Type::CentOS
            && matches!(info.version(), os_info::Version::Semantic(7, 0, 0))
    }

    #[tokio::test]
    async fn test_ping_with_ipv4() -> Result<(), Box<dyn std::error::Error>> {
        let client = TokioPingClient::new(Some(ClientConfig::new()), None)?;

        let receiver = client.clone();
        tokio::spawn(async move {
            receiver.handle_v4_recv_from().await;
        });

        ping_once(&client, "127.0.0.1".parse().expect("Never")).await;

        Ok(())
    }

    #[tokio::test]
    async fn test_ping_with_ipv6() -> Result<(), Box<dyn std::error::Error>> {
        let client = match TokioPingClient::new(None, Some(ClientConfig::with_ipv6())) {
            Ok(client) => client,
            Err(AsyncClientWithConfigError::IcmpV6ProtocolNotSupported(_)) if is_centos_7() => {
                eprintln!("CentOS 7 doesn't support IcmpV6");
                return Ok(());
            }
            Err(err) => panic!("{err:?}"),
        };

        let receiver = client.clone();
        tokio::spawn(async move {
            receiver.handle_v6_recv_from().await;
        });

        ping_once(&client, "::1".parse().expect("Never")).await;

        Ok(())
    }
}
509}