async_ping/
lib.rs

1pub use icmp_client;
2pub use icmp_packet;
3
4use core::time::Duration;
5use std::{
6    collections::HashMap,
7    io::{Error as IoError, ErrorKind as IoErrorKind},
8    net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6},
9    sync::Arc,
10    time::Instant,
11};
12
13use icmp_client::{AsyncClient, AsyncClientWithConfigError, Config as ClientConfig};
14use icmp_packet::{
15    icmpv4::ParseError as Icmpv4ParseError, icmpv6::ParseError as Icmpv6ParseError, Icmp, Icmpv4,
16    Icmpv6, PayloadLengthDelimitedEchoRequest,
17};
18use tokio::sync::{
19    mpsc::{self, Sender},
20    Mutex,
21};
22use tracing::{event, Level};
23
24//
25type V4RecvFromMap =
26    Arc<Mutex<HashMap<SocketAddr, Sender<(Result<Icmpv4, Icmpv4ParseError>, Instant)>>>>;
27type V6RecvFromMap =
28    Arc<Mutex<HashMap<SocketAddr, Sender<(Result<Icmpv6, Icmpv6ParseError>, Instant)>>>>;
29
30//
31pub struct PingClient<C>
32where
33    C: AsyncClient,
34{
35    v4_client: Option<Arc<C>>,
36    v6_client: Option<Arc<C>>,
37    v4_recv_from_map: V4RecvFromMap,
38    v6_recv_from_map: V6RecvFromMap,
39}
40
41impl<C> core::fmt::Debug for PingClient<C>
42where
43    C: AsyncClient,
44{
45    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
46        f.debug_struct("PingClient").finish()
47    }
48}
49
50impl<C> Clone for PingClient<C>
51where
52    C: AsyncClient,
53{
54    fn clone(&self) -> Self {
55        Self {
56            v4_client: self.v4_client.clone(),
57            v6_client: self.v6_client.clone(),
58            v4_recv_from_map: self.v4_recv_from_map.clone(),
59            v6_recv_from_map: self.v6_recv_from_map.clone(),
60        }
61    }
62}
63
64impl<C> PingClient<C>
65where
66    C: AsyncClient,
67{
68    pub fn new(
69        v4_client_config: Option<ClientConfig>,
70        v6_client_config: Option<ClientConfig>,
71    ) -> Result<Self, AsyncClientWithConfigError> {
72        let v4_client = if let Some(mut v4_client_config) = v4_client_config {
73            if v4_client_config.is_ipv6() {
74                return Err(IoError::new(IoErrorKind::Other, "v4_client_config invalid").into());
75            }
76            if v4_client_config.bind.is_none() {
77                v4_client_config.bind =
78                    Some(SocketAddrV4::new(Ipv4Addr::new(0, 0, 0, 0), 0).into());
79            }
80
81            Some(Arc::new(C::with_config(&v4_client_config)?))
82        } else {
83            None
84        };
85
86        let v6_client = if let Some(mut v6_client_config) = v6_client_config {
87            if !v6_client_config.is_ipv6() {
88                return Err(IoError::new(IoErrorKind::Other, "v4_client_config invalid").into());
89            }
90            if v6_client_config.bind.is_none() {
91                v6_client_config.bind =
92                    Some(SocketAddrV6::new(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0), 0, 0, 0).into());
93            }
94
95            Some(Arc::new(C::with_config(&v6_client_config)?))
96        } else {
97            None
98        };
99
100        let v4_recv_from_map = Arc::new(Mutex::new(HashMap::new()));
101        let v6_recv_from_map = Arc::new(Mutex::new(HashMap::new()));
102
103        Ok(Self {
104            v4_client,
105            v6_client,
106            v4_recv_from_map,
107            v6_recv_from_map,
108        })
109    }
110
111    // TODO, Support with spawn and without spawn.
112    pub async fn handle_v4_recv_from(&self) {
113        let v4_client = match self.v4_client.as_ref() {
114            Some(x) => x,
115            None => return,
116        };
117
118        let mut buf = [0; 2048];
119        let bytes_present_map: Arc<Mutex<HashMap<SocketAddr, Vec<u8>>>> =
120            Arc::new(Mutex::new(HashMap::new()));
121
122        loop {
123            match v4_client.recv_from(&mut buf).await {
124                Ok((n, addr)) => {
125                    let instant_end = Instant::now();
126                    let bytes_read = buf[..n].to_owned();
127
128                    let v4_recv_from_map = self.v4_recv_from_map.clone();
129                    let bytes_present_map = bytes_present_map.clone();
130
131                    tokio::spawn(async move {
132                        let bytes = if let Some(mut bytes_present) =
133                            bytes_present_map.lock().await.remove(&addr)
134                        {
135                            bytes_present.extend_from_slice(&bytes_read);
136                            bytes_present
137                        } else {
138                            bytes_read
139                        };
140
141                        match Icmpv4::parse_from_packet_bytes(&bytes) {
142                            Ok(Some(icmpv4)) => {
143                                if let Some(tx) = v4_recv_from_map.lock().await.remove(&addr) {
144                                    if let Err(err) = tx.try_send((Ok(icmpv4), instant_end)) {
145                                        event!(
146                                            Level::ERROR,
147                                            "tx.send failed, err:{err} addr:{addr}"
148                                        );
149                                    }
150                                } else {
151                                    event!(
152                                        Level::WARN,
153                                        "v4_recv_from_map.remove None, addr:{addr}"
154                                    );
155                                }
156                            }
157                            Ok(None) => {
158                                bytes_present_map.lock().await.insert(addr, bytes);
159                            }
160                            Err(err) => {
161                                if let Some(tx) = v4_recv_from_map.lock().await.remove(&addr) {
162                                    if let Err(err) = tx.try_send((Err(err), instant_end)) {
163                                        event!(
164                                            Level::ERROR,
165                                            "tx.send failed, err:{err} addr:{addr}"
166                                        );
167                                    }
168                                } else {
169                                    event!(
170                                        Level::WARN,
171                                        "v4_recv_from_map.remove None, addr:{addr}"
172                                    );
173                                }
174                            }
175                        }
176                    });
177                }
178                Err(err) => {
179                    event!(Level::ERROR, "v4_client.recv_from failed, err:{err}");
180                }
181            }
182        }
183    }
184
185    pub async fn handle_v6_recv_from(&self) {
186        let v6_client = match self.v6_client.as_ref() {
187            Some(x) => x,
188            None => return,
189        };
190
191        let mut buf = [0; 2048];
192        let bytes_present_map: Arc<Mutex<HashMap<SocketAddr, Vec<u8>>>> =
193            Arc::new(Mutex::new(HashMap::new()));
194
195        loop {
196            match v6_client.recv_from(&mut buf).await {
197                Ok((n, addr)) => {
198                    let instant_end = Instant::now();
199                    let bytes_read = buf[..n].to_owned();
200
201                    let v6_recv_from_map = self.v6_recv_from_map.clone();
202                    let bytes_present_map = bytes_present_map.clone();
203
204                    tokio::spawn(async move {
205                        let bytes = if let Some(mut bytes_present) =
206                            bytes_present_map.lock().await.remove(&addr)
207                        {
208                            bytes_present.extend_from_slice(&bytes_read);
209                            bytes_present
210                        } else {
211                            bytes_read
212                        };
213
214                        match Icmpv6::parse_from_packet_bytes(&bytes) {
215                            Ok(Some(icmpv6)) => {
216                                if let Some(tx) = v6_recv_from_map.lock().await.remove(&addr) {
217                                    if let Err(err) = tx.try_send((Ok(icmpv6), instant_end)) {
218                                        event!(
219                                            Level::ERROR,
220                                            "tx.send failed, err:{err} addr:{addr}"
221                                        );
222                                    }
223                                } else {
224                                    event!(
225                                        Level::WARN,
226                                        "v6_recv_from_map.remove None, addr:{addr}"
227                                    );
228                                }
229                            }
230                            Ok(None) => {
231                                bytes_present_map.lock().await.insert(addr, bytes);
232                            }
233                            Err(err) => {
234                                if let Some(tx) = v6_recv_from_map.lock().await.remove(&addr) {
235                                    if let Err(err) = tx.try_send((Err(err), instant_end)) {
236                                        event!(
237                                            Level::ERROR,
238                                            "tx.send failed, err:{err} addr:{addr}"
239                                        );
240                                    }
241                                } else {
242                                    event!(
243                                        Level::WARN,
244                                        "v6_recv_from_map.remove None, addr:{addr}"
245                                    );
246                                }
247                            }
248                        }
249                    });
250                }
251                Err(err) => {
252                    event!(Level::ERROR, "v6_client.recv_from failed, err:{err}");
253                }
254            }
255        }
256    }
257
258    pub async fn ping(
259        &self,
260        ip: IpAddr,
261        identifier: Option<u16>,
262        sequence_number: Option<u16>,
263        payload: impl AsRef<[u8]>,
264        timeout_dur: Duration,
265    ) -> Result<(Icmp, Duration), PingError> {
266        //
267        let echo_request = PayloadLengthDelimitedEchoRequest::new(
268            identifier.map(Into::into),
269            sequence_number.map(Into::into),
270            payload,
271        );
272        let echo_request_bytes = match ip {
273            IpAddr::V4(_) => echo_request.render_v4_packet_bytes(),
274            IpAddr::V6(_) => echo_request.render_v6_packet_bytes(),
275        };
276
277        //
278        let rx = match ip {
279            IpAddr::V4(_) => {
280                let (tx, rx) = mpsc::channel(1);
281
282                self.v4_recv_from_map
283                    .lock()
284                    .await
285                    .insert((ip, 0).into(), tx);
286
287                Ok(rx)
288            }
289            IpAddr::V6(_) => {
290                let (tx, rx) = mpsc::channel(1);
291
292                self.v6_recv_from_map
293                    .lock()
294                    .await
295                    .insert((ip, 0).into(), tx);
296
297                Err(rx)
298            }
299        };
300
301        //
302        let client = match ip {
303            IpAddr::V4(_) => self.v4_client.as_ref().ok_or(PingError::NoV4Client)?,
304            IpAddr::V6(_) => self.v6_client.as_ref().ok_or(PingError::NoV6Client)?,
305        };
306
307        let instant_begin = Instant::now();
308
309        {
310            let mut n_write = 0;
311            while !echo_request_bytes[n_write..].is_empty() {
312                let n = client
313                    .send_to(&echo_request_bytes[n_write..], (ip, 0))
314                    .await
315                    .map_err(PingError::Send)?;
316                n_write += n;
317
318                if n == 0 {
319                    return Err(PingError::Send(IoErrorKind::WriteZero.into()));
320                }
321            }
322        }
323
324        //
325        match rx {
326            Ok(mut rx) => {
327                match tokio::time::timeout(
328                    tokio::time::Duration::from_millis(timeout_dur.as_millis() as u64),
329                    rx.recv(),
330                )
331                .await
332                {
333                    Ok(Some((Ok(icmpv4), instant_end))) => Ok((
334                        Icmp::V4(icmpv4),
335                        instant_end
336                            .checked_duration_since(instant_begin)
337                            .unwrap_or(instant_begin.elapsed()),
338                    )),
339                    Ok(Some((Err(err), _))) => Err(PingError::Icmpv4ParseError(err)),
340                    Ok(None) => Err(PingError::Unknown("rx.recv None".to_string())),
341                    Err(_) => Err(PingError::RecvTimedOut),
342                }
343            }
344            Err(mut rx) => {
345                match tokio::time::timeout(
346                    tokio::time::Duration::from_millis(timeout_dur.as_millis() as u64),
347                    rx.recv(),
348                )
349                .await
350                {
351                    Ok(Some((Ok(icmpv6), instant_end))) => Ok((
352                        Icmp::V6(icmpv6),
353                        instant_end
354                            .checked_duration_since(instant_begin)
355                            .unwrap_or(instant_begin.elapsed()),
356                    )),
357                    Ok(Some((Err(err), _))) => Err(PingError::Icmpv6ParseError(err)),
358                    Ok(None) => Err(PingError::Unknown("rx.recv None".to_string())),
359                    Err(_) => Err(PingError::RecvTimedOut),
360                }
361            }
362        }
363    }
364
365    pub async fn ping_v4(
366        &self,
367        ip: Ipv4Addr,
368        identifier: Option<u16>,
369        sequence_number: Option<u16>,
370        payload: impl AsRef<[u8]>,
371        timeout_dur: Duration,
372    ) -> Result<(Icmpv4, Duration), PingError> {
373        let (icmp, dur) = self
374            .ping(ip.into(), identifier, sequence_number, payload, timeout_dur)
375            .await?;
376        match icmp {
377            Icmp::V4(icmp) => Ok((icmp, dur)),
378            Icmp::V6(_) => Err(PingError::Unknown("unreachable".to_string())),
379        }
380    }
381
382    pub async fn ping_v6(
383        &self,
384        ip: Ipv6Addr,
385        identifier: Option<u16>,
386        sequence_number: Option<u16>,
387        payload: impl AsRef<[u8]>,
388        timeout_dur: Duration,
389    ) -> Result<(Icmpv6, Duration), PingError> {
390        let (icmp, dur) = self
391            .ping(ip.into(), identifier, sequence_number, payload, timeout_dur)
392            .await?;
393        match icmp {
394            Icmp::V4(_) => Err(PingError::Unknown("unreachable".to_string())),
395            Icmp::V6(icmp) => Ok((icmp, dur)),
396        }
397    }
398}
399
400//
401#[derive(Debug)]
402pub enum PingError {
403    NoV4Client,
404    NoV6Client,
405    Send(IoError),
406    Icmpv4ParseError(Icmpv4ParseError),
407    Icmpv6ParseError(Icmpv6ParseError),
408    RecvTimedOut,
409    Unknown(String),
410}
411impl core::fmt::Display for PingError {
412    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
413        write!(f, "{self:?}")
414    }
415}
416impl std::error::Error for PingError {}
417
418#[cfg(test)]
419mod tests {
420    use super::*;
421
422    #[tokio::test]
423    async fn test_ping_with_ipv4() -> Result<(), Box<dyn std::error::Error>> {
424        let client =
425            PingClient::<icmp_client::impl_tokio::Client>::new(Some(ClientConfig::new()), None)?;
426
427        {
428            let client = client.clone();
429            tokio::spawn(async move {
430                client.handle_v4_recv_from().await;
431            });
432        }
433
434        {
435            match client
436                .ping(
437                    "127.0.0.1".parse().expect("Never"),
438                    None,
439                    None,
440                    vec![0; 32],
441                    Duration::from_secs(2),
442                )
443                .await
444            {
445                Ok((icmp, dur)) => {
446                    println!("{dur:?} {icmp:?}");
447                }
448                Err(err) => panic!("{err}"),
449            }
450        }
451
452        Ok(())
453    }
454
455    #[tokio::test]
456    async fn test_ping_with_ipv6() -> Result<(), Box<dyn std::error::Error>> {
457        let client = match PingClient::<icmp_client::impl_tokio::Client>::new(
458            None,
459            Some(ClientConfig::with_ipv6()),
460        ) {
461            Ok(x) => x,
462            Err(err) => {
463                if matches!(
464                    err,
465                    AsyncClientWithConfigError::IcmpV6ProtocolNotSupported(_)
466                ) {
467                    let info = os_info::get();
468                    if info.os_type() == os_info::Type::CentOS
469                        && matches!(info.version(), os_info::Version::Semantic(7, 0, 0))
470                    {
471                        eprintln!("CentOS 7 doesn't support IcmpV6");
472                        return Ok(());
473                    } else {
474                        panic!("{err:?}")
475                    }
476                } else {
477                    panic!("{err:?}")
478                }
479            }
480        };
481
482        {
483            let client = client.clone();
484            tokio::spawn(async move {
485                client.handle_v6_recv_from().await;
486            });
487        }
488
489        {
490            match client
491                .ping(
492                    "::1".parse().expect("Never"),
493                    None,
494                    None,
495                    vec![0; 32],
496                    Duration::from_secs(2),
497                )
498                .await
499            {
500                Ok((icmp, dur)) => {
501                    println!("{dur:?} {icmp:?}");
502                }
503                Err(err) => panic!("{err}"),
504            }
505        }
506
507        Ok(())
508    }
509}