ping_async/platform/
socket.rs

1// platform/socket.rs
2#![cfg(any(target_os = "macos", target_os = "linux"))]
3
4use std::io;
5use std::mem;
6use std::net::{IpAddr, SocketAddr, SocketAddrV4, SocketAddrV6};
7use std::sync::Arc;
8use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
9
10use byteorder::NetworkEndian;
11use futures::channel::mpsc::UnboundedSender;
12use ippacket::{Bytes, IcmpHeader, IcmpType4, IcmpType6};
13use rand::random;
14use socket2::{Domain, Protocol, Socket, Type};
15use tokio::{net::UdpSocket, task, time};
16
17use crate::{IcmpEchoReply, IcmpEchoStatus, PING_DEFAULT_TIMEOUT, PING_DEFAULT_TTL};
18
/// Sends ICMP echo requests ("pings") to a single target over a datagram ICMP
/// socket and reports each outcome on a channel.
///
/// Instances are created with `IcmpEchoRequestor::new`; every call to `send`
/// transmits one request and delivers exactly one `IcmpEchoReply` through
/// `reply_tx` from a spawned background task.
pub struct IcmpEchoRequestor {
    /// Non-blocking ICMP datagram socket, already connected to `target_addr`.
    /// Shared through an `Arc` so the spawned receive task can use it too.
    socket: Arc<UdpSocket>,
    /// Address being pinged; its IP version selects ICMPv4 vs ICMPv6 handling.
    target_addr: IpAddr,
    /// How long to wait for an echo reply before reporting `TimedOut`.
    timeout: Duration,
    /// Value written into the ICMP header's identifier field; chosen randomly
    /// in `new`.
    identifier: u16,
    /// Value written into the ICMP header's sequence field; initialised to 0
    /// in `new`.
    sequence: u16,
    /// Channel on which the result of every request is delivered.
    reply_tx: UnboundedSender<IcmpEchoReply>,
}
27
28impl IcmpEchoRequestor {
29    pub fn new(
30        reply_tx: UnboundedSender<IcmpEchoReply>,
31        target_addr: IpAddr,
32        source_addr: Option<IpAddr>,
33        ttl: Option<u8>,
34        timeout: Option<Duration>,
35    ) -> io::Result<Self> {
36        let socket = match target_addr {
37            IpAddr::V4(_) => Socket::new(Domain::IPV4, Type::DGRAM, Some(Protocol::ICMPV4))?,
38            IpAddr::V6(_) => Socket::new(Domain::IPV6, Type::DGRAM, Some(Protocol::ICMPV6))?,
39        };
40        socket.set_nonblocking(true)?;
41
42        let ttl = ttl.unwrap_or(PING_DEFAULT_TTL);
43        let timeout = timeout.unwrap_or(PING_DEFAULT_TIMEOUT);
44
45        if target_addr.is_ipv4() {
46            socket.set_ttl(ttl as u32)?;
47        } else {
48            socket.set_unicast_hops_v6(ttl as u32)?;
49        }
50
51        // bind the source address if provided
52        if let Some(source_addr) = source_addr {
53            match (target_addr, source_addr) {
54                (IpAddr::V4(_), IpAddr::V4(ip)) => {
55                    socket.bind(&SocketAddrV4::new(ip, 0).into())?;
56                }
57                (IpAddr::V6(_), IpAddr::V6(ip)) => {
58                    socket.bind(&SocketAddrV6::new(ip, 0, 0, 0).into())?;
59                }
60                _ => {
61                    return Err(io::Error::new(
62                        io::ErrorKind::InvalidInput,
63                        "source and target address must be the same IP version",
64                    ));
65                }
66            }
67        }
68
69        // connect to the target address
70        socket.connect(&SocketAddr::new(target_addr, 0).into())?;
71
72        Ok(IcmpEchoRequestor {
73            socket: Arc::new(UdpSocket::from_std(socket.into())?),
74            target_addr,
75            timeout,
76            identifier: random(),
77            sequence: 0,
78            reply_tx,
79        })
80    }
81
82    pub async fn send(&self) -> io::Result<()> {
83        let payload = vec![0u8; IcmpHeader::len() + mem::size_of::<u128>()];
84        let byte = Bytes::new(payload.into_boxed_slice());
85        let packet = byte.clone();
86
87        let (mut header, mut data) = IcmpHeader::with_bytes(byte)?;
88        if self.target_addr.is_ipv4() {
89            header.set_icmp_type(IcmpType4::EchoRequest.value());
90        } else {
91            header.set_icmp_type(IcmpType6::EchoRequest.value());
92        }
93        header.set_icmp_code(0);
94        header.set_id(self.identifier);
95        header.set_seq(self.sequence);
96
97        let socket_clone = Arc::clone(&self.socket);
98        let tx_clone = self.reply_tx.clone();
99        let target_clone = self.target_addr.clone();
100
101        let mut tick = time::interval(self.timeout);
102        // approximately 0ms have elapsed. The first tick above completes immediately.
103        tick.tick().await;
104
105        let now = SystemTime::now()
106            .duration_since(UNIX_EPOCH)
107            .map_err(|e| {
108                io::Error::new(
109                    io::ErrorKind::Other,
110                    format!("failed to get timestamp: {}", e),
111                )
112            })?
113            .as_nanos();
114
115        data.write_u128::<NetworkEndian>(0, now).unwrap();
116        if self.target_addr.is_ipv4() {
117            header.calculate_checksum(data.pair_iter());
118        }
119
120        let beginning = Instant::now();
121        self.socket.send(&packet.as_slice()).await?;
122
123        task::spawn(async move {
124            tokio::select! {
125                _ = tick.tick() => {
126                    let _ = tx_clone.unbounded_send(IcmpEchoReply::new(
127                        target_clone,
128                        IcmpEchoStatus::TimedOut,
129                        beginning.elapsed(),
130                    ));
131                    return;
132                }
133                header = IcmpEchoRequestor::recv_loop(socket_clone, target_clone) => {
134                    match header {
135                        Ok((_, data)) => {
136                            // we don't test identifier and sequence number here
137                            IcmpEchoRequestor::parse_icmp_data(
138                                data,
139                                tx_clone,
140                                target_clone,
141                                beginning,
142                            );
143                        }
144                        Err(e) => {
145                            log::debug!("error upon recving ICMP packet: {}", e);
146                            let _ = tx_clone.unbounded_send(IcmpEchoReply::new(
147                                target_clone,
148                                IcmpEchoStatus::Unknown,
149                                beginning.elapsed(),
150                            ));
151                        }
152                    }
153                }
154            }
155        });
156
157        Ok(())
158    }
159
160    async fn recv_loop(socket: Arc<UdpSocket>, target: IpAddr) -> io::Result<(IcmpHeader, Bytes)> {
161        loop {
162            let mut buf = vec![0u8; 1024];
163
164            let size = socket.recv(&mut buf).await?;
165            let payload: Box<[u8]>;
166            if target.is_ipv4() {
167                // skip the IP header for icmp
168                payload = Vec::from(&buf[20..size]).into_boxed_slice();
169            } else {
170                payload = Vec::from(&buf[..size]).into_boxed_slice();
171            }
172
173            let (header, data) = IcmpHeader::with_bytes(Bytes::new(payload))?;
174            match (target, header.icmp_type()) {
175                (IpAddr::V4(_), x) if x == IcmpType4::EchoReply.value() => {
176                    return Ok((header, data))
177                }
178                (IpAddr::V6(_), x) if x == IcmpType6::EchoReply.value() => {
179                    return Ok((header, data))
180                }
181                _ => continue, // ignore the ECHO_REQUEST packet when ping ::1 on macOS
182            }
183        }
184    }
185
186    fn parse_icmp_data(
187        data: Bytes,
188        tx: UnboundedSender<IcmpEchoReply>,
189        target: IpAddr,
190        beginning: Instant,
191    ) {
192        if data.len() != mem::size_of::<u128>() {
193            let _ = tx.unbounded_send(IcmpEchoReply::new(
194                target,
195                IcmpEchoStatus::Unknown,
196                beginning.elapsed(),
197            ));
198            return;
199        }
200
201        let sent = data.read_u128::<NetworkEndian>(0).unwrap();
202        match SystemTime::now().duration_since(UNIX_EPOCH) {
203            Ok(now) => {
204                let rtt = now.as_nanos() - sent;
205                let _ = tx.unbounded_send(IcmpEchoReply::new(
206                    target,
207                    IcmpEchoStatus::Success,
208                    Duration::from_nanos(rtt as u64),
209                ));
210            }
211            Err(e) => {
212                log::debug!("failed to get system time: {}", e);
213
214                let _ = tx.unbounded_send(IcmpEchoReply::new(
215                    target,
216                    IcmpEchoStatus::Unknown,
217                    beginning.elapsed(),
218                ));
219            }
220        }
221    }
222}