// surge_ping/client.rs

1#[cfg(unix)]
2use std::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd, RawFd};
3#[cfg(windows)]
4use std::os::windows::io::{AsRawSocket, FromRawSocket, IntoRawSocket, RawSocket};
5
6use std::{
7    collections::HashMap,
8    io,
9    net::{IpAddr, SocketAddr},
10    sync::Arc,
11    time::Instant,
12};
13
14use parking_lot::Mutex;
15use socket2::{Domain, Protocol, Socket, Type as SockType};
16use tokio::{
17    net::UdpSocket,
18    sync::oneshot,
19    task::{self, JoinHandle},
20};
21use tracing::debug;
22
23use crate::{
24    config::Config,
25    icmp::{icmpv4::Icmpv4Packet, icmpv6::Icmpv6Packet},
26    IcmpPacket, PingIdentifier, PingSequence, Pinger, SurgeError, ICMP,
27};
28
/// Tells whether a socket of type `$sock_type` should be treated as a
/// Linux/Android-style ICMP `DGRAM` socket.
///
/// Evaluates to `false` for `socket2::Type::RAW` on every platform and for
/// `socket2::Type::DGRAM` on targets other than Linux and Android; it
/// evaluates to `true` for every other combination (in particular `DGRAM` on
/// Linux/Android). In this crate a `true` result means incoming replies are
/// matched without looking at the ICMP identifier.
#[macro_export]
macro_rules! is_linux_icmp_socket {
    ($sock_type:expr) => {
        ($sock_type != socket2::Type::DGRAM
            || cfg!(any(target_os = "linux", target_os = "android")))
            && $sock_type != socket2::Type::RAW
    };
}
43
44#[derive(Clone)]
45pub struct AsyncSocket {
46    inner: Arc<UdpSocket>,
47    sock_type: SockType,
48}
49
50impl AsyncSocket {
51    pub fn new(config: &Config) -> io::Result<Self> {
52        let (sock_type, socket) = Self::create_socket(config)?;
53
54        socket.set_nonblocking(true)?;
55        if let Some(sock_addr) = &config.bind {
56            socket.bind(sock_addr)?;
57        }
58        #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
59        if let Some(interface) = &config.interface {
60            socket.bind_device(Some(interface.as_bytes()))?;
61        }
62        #[cfg(any(
63            target_os = "ios",
64            target_os = "visionos",
65            target_os = "macos",
66            target_os = "tvos",
67            target_os = "watchos",
68            target_os = "illumos",
69            target_os = "solaris",
70            target_os = "linux",
71            target_os = "android",
72        ))]
73        {
74            if config.interface_index.is_some() {
75                match config.kind {
76                    ICMP::V4 => socket.bind_device_by_index_v4(config.interface_index)?,
77                    ICMP::V6 => socket.bind_device_by_index_v6(config.interface_index)?,
78                }
79            }
80        }
81        if let Some(ttl) = config.ttl {
82            match config.kind {
83                ICMP::V4 => socket.set_ttl_v4(ttl)?,
84                ICMP::V6 => socket.set_unicast_hops_v6(ttl)?,
85            }
86        }
87        #[cfg(target_os = "freebsd")]
88        if let Some(fib) = config.fib {
89            socket.set_fib(fib)?;
90        }
91        #[cfg(windows)]
92        let socket = UdpSocket::from_std(unsafe {
93            std::net::UdpSocket::from_raw_socket(socket.into_raw_socket())
94        })?;
95        #[cfg(unix)]
96        let socket =
97            UdpSocket::from_std(unsafe { std::net::UdpSocket::from_raw_fd(socket.into_raw_fd()) })?;
98        Ok(Self {
99            inner: Arc::new(socket),
100            sock_type,
101        })
102    }
103
104    fn create_socket(config: &Config) -> io::Result<(SockType, Socket)> {
105        let (domain, proto) = match config.kind {
106            ICMP::V4 => (Domain::IPV4, Some(Protocol::ICMPV4)),
107            ICMP::V6 => (Domain::IPV6, Some(Protocol::ICMPV6)),
108        };
109
110        match Socket::new(domain, config.sock_type_hint, proto) {
111            Ok(sock) => Ok((config.sock_type_hint, sock)),
112            Err(err) => {
113                let new_type = if config.sock_type_hint == SockType::DGRAM {
114                    SockType::RAW
115                } else {
116                    SockType::DGRAM
117                };
118
119                debug!(
120                    "error opening {:?} type socket, trying {:?}: {:?}",
121                    config.sock_type_hint, new_type, err
122                );
123
124                Ok((new_type, Socket::new(domain, new_type, proto)?))
125            }
126        }
127    }
128
129    pub async fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
130        self.inner.recv_from(buf).await
131    }
132
133    pub async fn send_to(&self, buf: &mut [u8], target: &SocketAddr) -> io::Result<usize> {
134        self.inner.send_to(buf, target).await
135    }
136
137    pub fn local_addr(&self) -> io::Result<SocketAddr> {
138        self.inner.local_addr()
139    }
140
141    pub fn get_type(&self) -> SockType {
142        self.sock_type
143    }
144
145    #[cfg(unix)]
146    pub fn get_native_sock(&self) -> RawFd {
147        self.inner.as_raw_fd()
148    }
149
150    #[cfg(windows)]
151    pub fn get_native_sock(&self) -> RawSocket {
152        self.inner.as_raw_socket()
153    }
154}
155
156#[derive(PartialEq, Eq, Hash)]
157struct ReplyToken(IpAddr, Option<PingIdentifier>, PingSequence);
158
159pub(crate) struct Reply {
160    pub timestamp: Instant,
161    pub packet: IcmpPacket,
162}
163
164#[derive(Clone, Default)]
165pub(crate) struct ReplyMap(Arc<Mutex<HashMap<ReplyToken, oneshot::Sender<Reply>>>>);
166
167impl ReplyMap {
168    /// Register to wait for a reply from host with ident and sequence number.
169    /// If there is already someone waiting for this specific reply then an
170    /// error is returned.
171    pub fn new_waiter(
172        &self,
173        host: IpAddr,
174        ident: Option<PingIdentifier>,
175        seq: PingSequence,
176    ) -> Result<oneshot::Receiver<Reply>, SurgeError> {
177        let (tx, rx) = oneshot::channel();
178        if self
179            .0
180            .lock()
181            .insert(ReplyToken(host, ident, seq), tx)
182            .is_some()
183        {
184            return Err(SurgeError::IdenticalRequests { host, ident, seq });
185        }
186        Ok(rx)
187    }
188
189    /// Remove a waiter.
190    pub(crate) fn remove(
191        &self,
192        host: IpAddr,
193        ident: Option<PingIdentifier>,
194        seq: PingSequence,
195    ) -> Option<oneshot::Sender<Reply>> {
196        self.0.lock().remove(&ReplyToken(host, ident, seq))
197    }
198}
199
200///
201/// If you want to pass the `Client` in the task, please wrap it with `Arc`: `Arc<Client>`.
202/// and can realize the simultaneous ping of multiple addresses when only one `socket` is created.
203///
204#[derive(Clone)]
205pub struct Client {
206    socket: AsyncSocket,
207    reply_map: ReplyMap,
208    recv: Arc<JoinHandle<()>>,
209}
210
211impl Drop for Client {
212    fn drop(&mut self) {
213        // The client may pass through multiple tasks, so need to judge whether the number of references is 1.
214        if Arc::strong_count(&self.recv) <= 1 {
215            self.recv.abort();
216        }
217    }
218}
219
220impl Client {
221    /// A client is generated according to the configuration. In fact, a `AsyncSocket` is wrapped inside,
222    /// and you can clone to any `task` at will.
223    pub fn new(config: &Config) -> io::Result<Self> {
224        let socket = AsyncSocket::new(config)?;
225        let reply_map = ReplyMap::default();
226        let recv = task::spawn(recv_task(socket.clone(), reply_map.clone()));
227        Ok(Self {
228            socket,
229            reply_map,
230            recv: Arc::new(recv),
231        })
232    }
233
234    /// Create a `Pinger` instance, you can make special configuration for this instance.
235    pub async fn pinger(&self, host: IpAddr, ident: PingIdentifier) -> Pinger {
236        Pinger::new(host, ident, self.socket.clone(), self.reply_map.clone())
237    }
238
239    /// Expose the underlying socket, if user wants to modify any options on it
240    pub fn get_socket(&self) -> AsyncSocket {
241        self.socket.clone()
242    }
243}
244
245async fn recv_task(socket: AsyncSocket, reply_map: ReplyMap) {
246    let mut buf = [0; 2048];
247    loop {
248        if let Ok((sz, addr)) = socket.recv_from(&mut buf).await {
249            let timestamp = Instant::now();
250            let message = &buf[..sz];
251            let local_addr = socket.local_addr().unwrap().ip();
252            let packet = {
253                let result = match addr.ip() {
254                    IpAddr::V4(src_addr) => {
255                        let local_addr_ip4 = match local_addr {
256                            IpAddr::V4(local_addr_ip4) => local_addr_ip4,
257                            _ => continue,
258                        };
259
260                        Icmpv4Packet::decode(message, socket.sock_type, src_addr, local_addr_ip4)
261                            .map(IcmpPacket::V4)
262                    }
263                    IpAddr::V6(src_addr) => {
264                        Icmpv6Packet::decode(message, src_addr).map(IcmpPacket::V6)
265                    }
266                };
267                match result {
268                    Ok(packet) => packet,
269                    Err(err) => {
270                        debug!("error decoding ICMP packet: {:?}", err);
271                        continue;
272                    }
273                }
274            };
275
276            let ident = if is_linux_icmp_socket!(socket.get_type()) {
277                None
278            } else {
279                Some(packet.get_identifier())
280            };
281
282            if let Some(waiter) = reply_map.remove(addr.ip(), ident, packet.get_sequence()) {
283                // If send fails the receiving end has closed. Nothing to do.
284                let _ = waiter.send(Reply { timestamp, packet });
285            } else {
286                debug!("no one is waiting for ICMP packet ({:?})", packet);
287            }
288        }
289    }
290}