#[cfg(unix)]
use std::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd, RawFd};
#[cfg(windows)]
use std::os::windows::io::{AsRawSocket, FromRawSocket, IntoRawSocket, RawSocket};
use std::{
collections::HashMap,
io,
net::{IpAddr, SocketAddr},
sync::Arc,
time::Instant,
};
use parking_lot::Mutex;
use socket2::{Domain, Protocol, Socket, Type as SockType};
use tokio::{
net::UdpSocket,
sync::oneshot,
task::{self, JoinHandle},
};
use tracing::debug;
use crate::{
config::Config,
icmp::{icmpv4::Icmpv4Packet, icmpv6::Icmpv6Packet},
IcmpPacket, PingIdentifier, PingSequence, Pinger, SurgeError, ICMP,
};
/// Evaluates to `true` when a socket of type `$sock_type` should be treated as a Linux-style
/// unprivileged ICMP socket.
///
/// The result is `false` for `RAW` sockets, and for `DGRAM` sockets on targets other than
/// Linux/Android; it is `true` otherwise. The receive loop in this module uses it to decide
/// whether replies are matched with (`false`) or without (`true`) the ICMP identifier —
/// presumably because the kernel assigns the identifier on such sockets (TODO confirm).
///
/// `$sock_type` is evaluated exactly once.
#[macro_export]
macro_rules! is_linux_icmp_socket {
    ($sock_type:expr) => {{
        // Bind once so the argument expression is not evaluated a variable number of times.
        let sock_type = $sock_type;
        let dgram_off_linux = sock_type == socket2::Type::DGRAM
            && cfg!(not(any(target_os = "linux", target_os = "android")));
        !(dgram_off_linux || sock_type == socket2::Type::RAW)
    }};
}
/// Cheaply clonable async ICMP socket: clones share the same underlying tokio `UdpSocket`.
#[derive(Clone)]
pub struct AsyncSocket {
    inner: Arc<UdpSocket>,
    sock_type: SockType,
}
impl AsyncSocket {
    /// Opens and configures an ICMP socket according to `config`.
    ///
    /// The socket type is first tried as `config.sock_type_hint`; if that fails, the other of
    /// `DGRAM`/`RAW` is tried (see `create_socket`). The optional bind address, interface
    /// (Linux/Android/Fuchsia only), TTL and FIB (FreeBSD only) are then applied before the
    /// socket is handed to tokio.
    ///
    /// # Errors
    /// Returns any I/O error from socket creation, configuration, or registration with tokio.
    pub fn new(config: &Config) -> io::Result<Self> {
        let (sock_type, socket) = Self::create_socket(config)?;
        // tokio's `UdpSocket::from_std` does not switch the socket to non-blocking mode itself.
        socket.set_nonblocking(true)?;
        if let Some(sock_addr) = &config.bind {
            socket.bind(sock_addr)?;
        }
        #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
        if let Some(interface) = &config.interface {
            socket.bind_device(Some(interface.as_bytes()))?;
        }
        if let Some(ttl) = config.ttl {
            socket.set_ttl(ttl)?;
        }
        #[cfg(target_os = "freebsd")]
        if let Some(fib) = config.fib {
            socket.set_fib(fib)?;
        }
        #[cfg(windows)]
        // SAFETY: `into_raw_socket` gives up `socket`'s ownership of a valid, open handle, so
        // the std `UdpSocket` becomes its sole owner and it is closed exactly once.
        let socket = UdpSocket::from_std(unsafe {
            std::net::UdpSocket::from_raw_socket(socket.into_raw_socket())
        })?;
        #[cfg(unix)]
        // SAFETY: `into_raw_fd` gives up `socket`'s ownership of a valid, open descriptor, so
        // the std `UdpSocket` becomes its sole owner and it is closed exactly once.
        let socket =
            UdpSocket::from_std(unsafe { std::net::UdpSocket::from_raw_fd(socket.into_raw_fd()) })?;
        Ok(Self {
            inner: Arc::new(socket),
            sock_type,
        })
    }

    /// Creates the raw `socket2` socket, returning the type that was actually used.
    ///
    /// On failure with the hinted type, retries once with the alternative type (`RAW` if the
    /// hint was `DGRAM`, otherwise `DGRAM`); the first error is only logged at debug level.
    fn create_socket(config: &Config) -> io::Result<(SockType, Socket)> {
        let (domain, proto) = match config.kind {
            ICMP::V4 => (Domain::IPV4, Some(Protocol::ICMPV4)),
            ICMP::V6 => (Domain::IPV6, Some(Protocol::ICMPV6)),
        };
        match Socket::new(domain, config.sock_type_hint, proto) {
            Ok(sock) => Ok((config.sock_type_hint, sock)),
            Err(err) => {
                let new_type = if config.sock_type_hint == SockType::DGRAM {
                    SockType::RAW
                } else {
                    SockType::DGRAM
                };
                debug!(
                    "error opening {:?} type socket, trying {:?}: {:?}",
                    config.sock_type_hint, new_type, err
                );
                Ok((new_type, Socket::new(domain, new_type, proto)?))
            }
        }
    }

    /// Receives one datagram into `buf`, returning its length and source address.
    pub async fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
        self.inner.recv_from(buf).await
    }

    /// Sends `buf` to `target`, returning the number of bytes written.
    ///
    /// The buffer is only read; callers passing `&mut` buffers still compile via coercion.
    pub async fn send_to(&self, buf: &[u8], target: &SocketAddr) -> io::Result<usize> {
        self.inner.send_to(buf, target).await
    }

    /// Returns the socket's local address.
    pub fn local_addr(&self) -> io::Result<SocketAddr> {
        self.inner.local_addr()
    }

    /// Returns the socket type actually opened (which may differ from the configured hint).
    pub fn get_type(&self) -> SockType {
        self.sock_type
    }

    /// Returns the underlying file descriptor (still owned by this socket).
    #[cfg(unix)]
    pub fn get_native_sock(&self) -> RawFd {
        self.inner.as_raw_fd()
    }

    /// Returns the underlying socket handle (still owned by this socket).
    #[cfg(windows)]
    pub fn get_native_sock(&self) -> RawSocket {
        self.inner.as_raw_socket()
    }
}
/// Lookup key for one outstanding request.
///
/// The fields are the peer address, the ICMP identifier (`None` when replies are not matched
/// on it), and the ICMP sequence number.
#[derive(Hash, Eq, PartialEq)]
struct ReplyToken(
    IpAddr,
    Option<PingIdentifier>,
    PingSequence,
);

/// A decoded ICMP reply paired with the instant the datagram was received.
pub(crate) struct Reply {
    /// Taken right after the datagram is read, before decoding.
    pub timestamp: Instant,
    /// The decoded ICMPv4/ICMPv6 packet.
    pub packet: IcmpPacket,
}
/// Shared registry of in-flight requests, each mapped to the oneshot sender on which its reply
/// will be delivered. Clones share the same map.
#[derive(Clone, Default)]
pub(crate) struct ReplyMap(Arc<Mutex<HashMap<ReplyToken, oneshot::Sender<Reply>>>>);
impl ReplyMap {
    /// Registers a waiter for the reply identified by `(host, ident, seq)`.
    ///
    /// # Errors
    /// Returns [`SurgeError::IdenticalRequests`] if a waiter with the same key is already
    /// registered. The existing waiter is left untouched, so the earlier request can still
    /// receive its reply.
    pub fn new_waiter(
        &self,
        host: IpAddr,
        ident: Option<PingIdentifier>,
        seq: PingSequence,
    ) -> Result<oneshot::Receiver<Reply>, SurgeError> {
        let mut waiters = self.0.lock();
        // Use the entry API instead of `insert`: `insert` would replace (and drop) the
        // in-flight waiter's sender before we could detect the duplicate.
        match waiters.entry(ReplyToken(host, ident, seq)) {
            std::collections::hash_map::Entry::Occupied(_) => {
                Err(SurgeError::IdenticalRequests { host, ident, seq })
            }
            std::collections::hash_map::Entry::Vacant(slot) => {
                let (tx, rx) = oneshot::channel();
                slot.insert(tx);
                Ok(rx)
            }
        }
    }

    /// Removes and returns the waiter registered for `(host, ident, seq)`, if any.
    pub(crate) fn remove(
        &self,
        host: IpAddr,
        ident: Option<PingIdentifier>,
        seq: PingSequence,
    ) -> Option<oneshot::Sender<Reply>> {
        self.0.lock().remove(&ReplyToken(host, ident, seq))
    }
}
/// Owns the background receive task's handle and aborts the task when dropped.
///
/// Held behind an `Arc` in [`Client`], so the abort runs exactly once: when the last clone
/// releases its reference. This replaces a `strong_count` check in `Client::drop`, which
/// could miss the abort when two clones were dropped concurrently. Dropping a bare
/// `JoinHandle` only detaches the task.
struct RecvTaskGuard(JoinHandle<()>);
impl Drop for RecvTaskGuard {
    fn drop(&mut self) {
        self.0.abort();
    }
}

/// ICMP client: one shared socket, a shared reply registry, and a background task that
/// dispatches incoming replies to waiting [`Pinger`]s.
///
/// Cloning is cheap; all clones share the same socket, registry and receive task. The receive
/// task is aborted once the last clone is dropped.
#[derive(Clone)]
pub struct Client {
    socket: AsyncSocket,
    reply_map: ReplyMap,
    recv: Arc<RecvTaskGuard>,
}
impl Client {
    /// Opens the socket described by `config` and spawns the receive task on the current
    /// Tokio runtime.
    ///
    /// # Errors
    /// Returns any I/O error from [`AsyncSocket::new`].
    ///
    /// # Panics
    /// Like `tokio::task::spawn`, panics if called outside a Tokio runtime.
    pub fn new(config: &Config) -> io::Result<Self> {
        let socket = AsyncSocket::new(config)?;
        let reply_map = ReplyMap::default();
        let recv = task::spawn(recv_task(socket.clone(), reply_map.clone()));
        Ok(Self {
            socket,
            reply_map,
            recv: Arc::new(RecvTaskGuard(recv)),
        })
    }

    /// Creates a [`Pinger`] for `host` using identifier `ident`, sharing this client's socket
    /// and reply registry.
    ///
    /// The body does not await anything; the function is `async` only in its signature.
    pub async fn pinger(&self, host: IpAddr, ident: PingIdentifier) -> Pinger {
        Pinger::new(host, ident, self.socket.clone(), self.reply_map.clone())
    }

    /// Returns a handle to the shared socket.
    pub fn get_socket(&self) -> AsyncSocket {
        self.socket.clone()
    }
}
async fn recv_task(socket: AsyncSocket, reply_map: ReplyMap) {
let mut buf = [0; 2048];
loop {
if let Ok((sz, addr)) = socket.recv_from(&mut buf).await {
let timestamp = Instant::now();
let message = &buf[..sz];
let local_addr = socket.local_addr().unwrap().ip();
let packet = {
let result = match addr.ip() {
IpAddr::V4(src_addr) => {
let local_addr_ip4 = match local_addr {
IpAddr::V4(local_addr_ip4) => local_addr_ip4,
_ => continue,
};
Icmpv4Packet::decode(message, socket.sock_type, src_addr, local_addr_ip4)
.map(IcmpPacket::V4)
}
IpAddr::V6(src_addr) => {
Icmpv6Packet::decode(message, src_addr).map(IcmpPacket::V6)
}
};
match result {
Ok(packet) => packet,
Err(err) => {
debug!("error decoding ICMP packet: {:?}", err);
continue;
}
}
};
let ident = if is_linux_icmp_socket!(socket.get_type()) {
None
} else {
Some(packet.get_identifier())
};
if let Some(waiter) = reply_map.remove(addr.ip(), ident, packet.get_sequence()) {
let _ = waiter.send(Reply { timestamp, packet });
} else {
debug!("no one is waiting for ICMP packet ({:?})", packet);
}
}
}
}