1#[cfg(unix)]
2use std::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd, RawFd};
3#[cfg(windows)]
4use std::os::windows::io::{AsRawSocket, FromRawSocket, IntoRawSocket, RawSocket};
5
6use std::{
7 collections::HashMap,
8 io,
9 net::{IpAddr, SocketAddr},
10 sync::Arc,
11 time::Instant,
12};
13
14use parking_lot::Mutex;
15use socket2::{Domain, Protocol, Socket, Type as SockType};
16use tokio::{
17 net::UdpSocket,
18 sync::oneshot,
19 task::{self, JoinHandle},
20};
21use tracing::debug;
22
23use crate::{
24 config::Config,
25 icmp::{icmpv4::Icmpv4Packet, icmpv6::Icmpv6Packet},
26 IcmpPacket, PingIdentifier, PingSequence, Pinger, SurgeError, ICMP,
27};
28
/// Classifies a `socket2::Type` for reply matching.
///
/// Evaluates to `false` for `RAW` sockets, and for `DGRAM` sockets on targets other
/// than Linux/Android; evaluates to `true` for every other case (including `DGRAM`
/// on Linux/Android). The receive loop uses a `true` result to match replies without
/// the echo identifier. NOTE(review): presumably because Linux "ping" DGRAM sockets
/// let the kernel assign the identifier — confirm.
#[macro_export]
macro_rules! is_linux_icmp_socket {
    ($sock_type:expr) => {
        !(($sock_type == socket2::Type::DGRAM
            && cfg!(not(any(target_os = "linux", target_os = "android"))))
            || $sock_type == socket2::Type::RAW)
    };
}
43
44#[derive(Clone)]
45pub struct AsyncSocket {
46 inner: Arc<UdpSocket>,
47 sock_type: SockType,
48}
49
50impl AsyncSocket {
51 pub fn new(config: &Config) -> io::Result<Self> {
52 let (sock_type, socket) = Self::create_socket(config)?;
53
54 socket.set_nonblocking(true)?;
55 if let Some(sock_addr) = &config.bind {
56 socket.bind(sock_addr)?;
57 }
58 #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
59 if let Some(interface) = &config.interface {
60 socket.bind_device(Some(interface.as_bytes()))?;
61 }
62 if let Some(ttl) = config.ttl {
63 socket.set_ttl(ttl)?;
64 }
65 #[cfg(target_os = "freebsd")]
66 if let Some(fib) = config.fib {
67 socket.set_fib(fib)?;
68 }
69 #[cfg(windows)]
70 let socket = UdpSocket::from_std(unsafe {
71 std::net::UdpSocket::from_raw_socket(socket.into_raw_socket())
72 })?;
73 #[cfg(unix)]
74 let socket =
75 UdpSocket::from_std(unsafe { std::net::UdpSocket::from_raw_fd(socket.into_raw_fd()) })?;
76 Ok(Self {
77 inner: Arc::new(socket),
78 sock_type,
79 })
80 }
81
82 fn create_socket(config: &Config) -> io::Result<(SockType, Socket)> {
83 let (domain, proto) = match config.kind {
84 ICMP::V4 => (Domain::IPV4, Some(Protocol::ICMPV4)),
85 ICMP::V6 => (Domain::IPV6, Some(Protocol::ICMPV6)),
86 };
87
88 match Socket::new(domain, config.sock_type_hint, proto) {
89 Ok(sock) => Ok((config.sock_type_hint, sock)),
90 Err(err) => {
91 let new_type = if config.sock_type_hint == SockType::DGRAM {
92 SockType::RAW
93 } else {
94 SockType::DGRAM
95 };
96
97 debug!(
98 "error opening {:?} type socket, trying {:?}: {:?}",
99 config.sock_type_hint, new_type, err
100 );
101
102 Ok((new_type, Socket::new(domain, new_type, proto)?))
103 }
104 }
105 }
106
107 pub async fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
108 self.inner.recv_from(buf).await
109 }
110
111 pub async fn send_to(&self, buf: &mut [u8], target: &SocketAddr) -> io::Result<usize> {
112 self.inner.send_to(buf, target).await
113 }
114
115 pub fn local_addr(&self) -> io::Result<SocketAddr> {
116 self.inner.local_addr()
117 }
118
119 pub fn get_type(&self) -> SockType {
120 self.sock_type
121 }
122
123 #[cfg(unix)]
124 pub fn get_native_sock(&self) -> RawFd {
125 self.inner.as_raw_fd()
126 }
127
128 #[cfg(windows)]
129 pub fn get_native_sock(&self) -> RawSocket {
130 self.inner.as_raw_socket()
131 }
132}
133
134#[derive(PartialEq, Eq, Hash)]
135struct ReplyToken(IpAddr, Option<PingIdentifier>, PingSequence);
136
137pub(crate) struct Reply {
138 pub timestamp: Instant,
139 pub packet: IcmpPacket,
140}
141
142#[derive(Clone, Default)]
143pub(crate) struct ReplyMap(Arc<Mutex<HashMap<ReplyToken, oneshot::Sender<Reply>>>>);
144
145impl ReplyMap {
146 pub fn new_waiter(
150 &self,
151 host: IpAddr,
152 ident: Option<PingIdentifier>,
153 seq: PingSequence,
154 ) -> Result<oneshot::Receiver<Reply>, SurgeError> {
155 let (tx, rx) = oneshot::channel();
156 if self
157 .0
158 .lock()
159 .insert(ReplyToken(host, ident, seq), tx)
160 .is_some()
161 {
162 return Err(SurgeError::IdenticalRequests { host, ident, seq });
163 }
164 Ok(rx)
165 }
166
167 pub(crate) fn remove(
169 &self,
170 host: IpAddr,
171 ident: Option<PingIdentifier>,
172 seq: PingSequence,
173 ) -> Option<oneshot::Sender<Reply>> {
174 self.0.lock().remove(&ReplyToken(host, ident, seq))
175 }
176}
177
178#[derive(Clone)]
183pub struct Client {
184 socket: AsyncSocket,
185 reply_map: ReplyMap,
186 recv: Arc<JoinHandle<()>>,
187}
188
189impl Drop for Client {
190 fn drop(&mut self) {
191 if Arc::strong_count(&self.recv) <= 1 {
193 self.recv.abort();
194 }
195 }
196}
197
198impl Client {
199 pub fn new(config: &Config) -> io::Result<Self> {
202 let socket = AsyncSocket::new(config)?;
203 let reply_map = ReplyMap::default();
204 let recv = task::spawn(recv_task(socket.clone(), reply_map.clone()));
205 Ok(Self {
206 socket,
207 reply_map,
208 recv: Arc::new(recv),
209 })
210 }
211
212 pub async fn pinger(&self, host: IpAddr, ident: PingIdentifier) -> Pinger {
214 Pinger::new(host, ident, self.socket.clone(), self.reply_map.clone())
215 }
216
217 pub fn get_socket(&self) -> AsyncSocket {
219 self.socket.clone()
220 }
221}
222
223async fn recv_task(socket: AsyncSocket, reply_map: ReplyMap) {
224 let mut buf = [0; 2048];
225 loop {
226 if let Ok((sz, addr)) = socket.recv_from(&mut buf).await {
227 let timestamp = Instant::now();
228 let message = &buf[..sz];
229 let local_addr = socket.local_addr().unwrap().ip();
230 let packet = {
231 let result = match addr.ip() {
232 IpAddr::V4(src_addr) => {
233 let local_addr_ip4 = match local_addr {
234 IpAddr::V4(local_addr_ip4) => local_addr_ip4,
235 _ => continue,
236 };
237
238 Icmpv4Packet::decode(message, socket.sock_type, src_addr, local_addr_ip4)
239 .map(IcmpPacket::V4)
240 }
241 IpAddr::V6(src_addr) => {
242 Icmpv6Packet::decode(message, src_addr).map(IcmpPacket::V6)
243 }
244 };
245 match result {
246 Ok(packet) => packet,
247 Err(err) => {
248 debug!("error decoding ICMP packet: {:?}", err);
249 continue;
250 }
251 }
252 };
253
254 let ident = if is_linux_icmp_socket!(socket.get_type()) {
255 None
256 } else {
257 Some(packet.get_identifier())
258 };
259
260 if let Some(waiter) = reply_map.remove(addr.ip(), ident, packet.get_sequence()) {
261 let _ = waiter.send(Reply { timestamp, packet });
263 } else {
264 debug!("no one is waiting for ICMP packet ({:?})", packet);
265 }
266 }
267 }
268}