1#[cfg(unix)]
2use std::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd, RawFd};
3#[cfg(windows)]
4use std::os::windows::io::{AsRawSocket, FromRawSocket, IntoRawSocket, RawSocket};
5
6use std::{
7 collections::HashMap,
8 io,
9 net::{IpAddr, SocketAddr},
10 sync::Arc,
11 time::Instant,
12};
13
14use parking_lot::Mutex;
15use socket2::{Domain, Protocol, Socket, Type as SockType};
16use tokio::{
17 net::UdpSocket,
18 sync::oneshot,
19 task::{self, JoinHandle},
20};
21use tracing::debug;
22
23use crate::{
24 config::Config,
25 icmp::{icmpv4::Icmpv4Packet, icmpv6::Icmpv6Packet},
26 IcmpPacket, PingIdentifier, PingSequence, Pinger, SurgeError, ICMP,
27};
28
/// Evaluates to whether `$sock_type` is treated as a Linux-style ICMP socket.
///
/// The result is `false` for `RAW` sockets, and for `DGRAM` sockets on targets
/// other than Linux/Android. Every other input (including `DGRAM` on
/// Linux/Android) yields `true`.
#[macro_export]
macro_rules! is_linux_icmp_socket {
    ($sock_type:expr) => {
        !(($sock_type == socket2::Type::DGRAM
            && cfg!(not(any(target_os = "linux", target_os = "android"))))
            || $sock_type == socket2::Type::RAW)
    };
}
43
44#[derive(Clone)]
45pub struct AsyncSocket {
46 inner: Arc<UdpSocket>,
47 sock_type: SockType,
48}
49
50impl AsyncSocket {
51 pub fn new(config: &Config) -> io::Result<Self> {
52 let (sock_type, socket) = Self::create_socket(config)?;
53
54 socket.set_nonblocking(true)?;
55 if let Some(sock_addr) = &config.bind {
56 socket.bind(sock_addr)?;
57 }
58 #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
59 if let Some(interface) = &config.interface {
60 socket.bind_device(Some(interface.as_bytes()))?;
61 }
62 #[cfg(any(
63 target_os = "ios",
64 target_os = "visionos",
65 target_os = "macos",
66 target_os = "tvos",
67 target_os = "watchos",
68 target_os = "illumos",
69 target_os = "solaris",
70 target_os = "linux",
71 target_os = "android",
72 ))]
73 {
74 if config.interface_index.is_some() {
75 match config.kind {
76 ICMP::V4 => socket.bind_device_by_index_v4(config.interface_index)?,
77 ICMP::V6 => socket.bind_device_by_index_v6(config.interface_index)?,
78 }
79 }
80 }
81 if let Some(ttl) = config.ttl {
82 match config.kind {
83 ICMP::V4 => socket.set_ttl_v4(ttl)?,
84 ICMP::V6 => socket.set_unicast_hops_v6(ttl)?,
85 }
86 }
87 #[cfg(target_os = "freebsd")]
88 if let Some(fib) = config.fib {
89 socket.set_fib(fib)?;
90 }
91 #[cfg(windows)]
92 let socket = UdpSocket::from_std(unsafe {
93 std::net::UdpSocket::from_raw_socket(socket.into_raw_socket())
94 })?;
95 #[cfg(unix)]
96 let socket =
97 UdpSocket::from_std(unsafe { std::net::UdpSocket::from_raw_fd(socket.into_raw_fd()) })?;
98 Ok(Self {
99 inner: Arc::new(socket),
100 sock_type,
101 })
102 }
103
104 fn create_socket(config: &Config) -> io::Result<(SockType, Socket)> {
105 let (domain, proto) = match config.kind {
106 ICMP::V4 => (Domain::IPV4, Some(Protocol::ICMPV4)),
107 ICMP::V6 => (Domain::IPV6, Some(Protocol::ICMPV6)),
108 };
109
110 match Socket::new(domain, config.sock_type_hint, proto) {
111 Ok(sock) => Ok((config.sock_type_hint, sock)),
112 Err(err) => {
113 let new_type = if config.sock_type_hint == SockType::DGRAM {
114 SockType::RAW
115 } else {
116 SockType::DGRAM
117 };
118
119 debug!(
120 "error opening {:?} type socket, trying {:?}: {:?}",
121 config.sock_type_hint, new_type, err
122 );
123
124 Ok((new_type, Socket::new(domain, new_type, proto)?))
125 }
126 }
127 }
128
129 pub async fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
130 self.inner.recv_from(buf).await
131 }
132
133 pub async fn send_to(&self, buf: &mut [u8], target: &SocketAddr) -> io::Result<usize> {
134 self.inner.send_to(buf, target).await
135 }
136
137 pub fn local_addr(&self) -> io::Result<SocketAddr> {
138 self.inner.local_addr()
139 }
140
141 pub fn get_type(&self) -> SockType {
142 self.sock_type
143 }
144
145 #[cfg(unix)]
146 pub fn get_native_sock(&self) -> RawFd {
147 self.inner.as_raw_fd()
148 }
149
150 #[cfg(windows)]
151 pub fn get_native_sock(&self) -> RawSocket {
152 self.inner.as_raw_socket()
153 }
154}
155
156#[derive(PartialEq, Eq, Hash)]
157struct ReplyToken(IpAddr, Option<PingIdentifier>, PingSequence);
158
159pub(crate) struct Reply {
160 pub timestamp: Instant,
161 pub packet: IcmpPacket,
162}
163
164#[derive(Clone, Default)]
165pub(crate) struct ReplyMap(Arc<Mutex<HashMap<ReplyToken, oneshot::Sender<Reply>>>>);
166
167impl ReplyMap {
168 pub fn new_waiter(
172 &self,
173 host: IpAddr,
174 ident: Option<PingIdentifier>,
175 seq: PingSequence,
176 ) -> Result<oneshot::Receiver<Reply>, SurgeError> {
177 let (tx, rx) = oneshot::channel();
178 if self
179 .0
180 .lock()
181 .insert(ReplyToken(host, ident, seq), tx)
182 .is_some()
183 {
184 return Err(SurgeError::IdenticalRequests { host, ident, seq });
185 }
186 Ok(rx)
187 }
188
189 pub(crate) fn remove(
191 &self,
192 host: IpAddr,
193 ident: Option<PingIdentifier>,
194 seq: PingSequence,
195 ) -> Option<oneshot::Sender<Reply>> {
196 self.0.lock().remove(&ReplyToken(host, ident, seq))
197 }
198}
199
200#[derive(Clone)]
205pub struct Client {
206 socket: AsyncSocket,
207 reply_map: ReplyMap,
208 recv: Arc<JoinHandle<()>>,
209}
210
211impl Drop for Client {
212 fn drop(&mut self) {
213 if Arc::strong_count(&self.recv) <= 1 {
215 self.recv.abort();
216 }
217 }
218}
219
220impl Client {
221 pub fn new(config: &Config) -> io::Result<Self> {
224 let socket = AsyncSocket::new(config)?;
225 let reply_map = ReplyMap::default();
226 let recv = task::spawn(recv_task(socket.clone(), reply_map.clone()));
227 Ok(Self {
228 socket,
229 reply_map,
230 recv: Arc::new(recv),
231 })
232 }
233
234 pub async fn pinger(&self, host: IpAddr, ident: PingIdentifier) -> Pinger {
236 Pinger::new(host, ident, self.socket.clone(), self.reply_map.clone())
237 }
238
239 pub fn get_socket(&self) -> AsyncSocket {
241 self.socket.clone()
242 }
243}
244
245async fn recv_task(socket: AsyncSocket, reply_map: ReplyMap) {
246 let mut buf = [0; 2048];
247 loop {
248 if let Ok((sz, addr)) = socket.recv_from(&mut buf).await {
249 let timestamp = Instant::now();
250 let message = &buf[..sz];
251 let local_addr = socket.local_addr().unwrap().ip();
252 let packet = {
253 let result = match addr.ip() {
254 IpAddr::V4(src_addr) => {
255 let local_addr_ip4 = match local_addr {
256 IpAddr::V4(local_addr_ip4) => local_addr_ip4,
257 _ => continue,
258 };
259
260 Icmpv4Packet::decode(message, socket.sock_type, src_addr, local_addr_ip4)
261 .map(IcmpPacket::V4)
262 }
263 IpAddr::V6(src_addr) => {
264 Icmpv6Packet::decode(message, src_addr).map(IcmpPacket::V6)
265 }
266 };
267 match result {
268 Ok(packet) => packet,
269 Err(err) => {
270 debug!("error decoding ICMP packet: {:?}", err);
271 continue;
272 }
273 }
274 };
275
276 let ident = if is_linux_icmp_socket!(socket.get_type()) {
277 None
278 } else {
279 Some(packet.get_identifier())
280 };
281
282 if let Some(waiter) = reply_map.remove(addr.ip(), ident, packet.get_sequence()) {
283 let _ = waiter.send(Reply { timestamp, packet });
285 } else {
286 debug!("no one is waiting for ICMP packet ({:?})", packet);
287 }
288 }
289 }
290}