#![doc = include_str!("../README.md")]
#![deny(missing_docs)]
#![allow(
clippy::doc_markdown,
clippy::missing_errors_doc,
clippy::missing_panics_doc,
clippy::must_use_candidate,
clippy::cast_possible_truncation,
clippy::cast_possible_wrap,
clippy::cast_sign_loss
)]
pub mod addr;
mod net;
pub(crate) mod time;
use std::{
mem::MaybeUninit,
net::{Ipv4Addr, Ipv6Addr, SocketAddrV6},
sync::{
atomic::{AtomicU16, Ordering},
LazyLock,
},
time::Duration,
};
pub use net::IcmpSocket;
use socket2::{MaybeUninitSlice, SockAddr};
use tokio::time::timeout;
use crate::{addr::ToIpAddr, net::MsgHdrMut};
/// Length of an IPv4 header without options: five 32-bit words.
const IP_HEADER_SIZE: usize = 5 * 4;
/// Length of the fixed ICMP echo header: type, code, checksum, identifier, sequence.
const ICMP_HEADER_SIZE: usize = 1 + 1 + 2 + 2 + 2;
/// ICMPv4 echo request message type.
const ICMP_ECHO_REQUEST: u8 = 0x08;
/// ICMPv4 echo reply message type.
const ICMP_ECHO_REPLY: u8 = 0x00;
/// ICMPv6 echo request message type.
const ICMP6_ECHO_REQUEST: u8 = 0x80;
/// ICMPv6 echo reply message type.
const ICMP6_ECHO_REPLY: u8 = 0x81;
/// Process-wide counter that hands out ICMP echo identifiers.
///
/// The starting value mixes the process id with the sub-second part of the
/// wall clock (zero if the clock reads before the Unix epoch), so separate
/// processes are unlikely to begin at the same identifier.
static REQ_ID: LazyLock<AtomicU16> = LazyLock::new(|| {
    let clock_nanos = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .map(|since_epoch| u64::from(since_epoch.subsec_nanos()))
        .unwrap_or(0);
    let pid = u64::from(std::process::id());
    // Fold the upper nanosecond bits down so they influence the 16-bit seed.
    let mixed = pid ^ clock_nanos ^ (clock_nanos >> 16);
    // Masking to 16 bits first makes the narrowing cast lossless.
    AtomicU16::new((mixed & 0xffff) as u16)
});
/// Summary statistics for a run of ICMP echo requests, as returned by [`ping`].
///
/// The RTT fields summarise only the requests that received a matching reply;
/// they are all [`Duration::ZERO`] when no reply was received.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct PingStats {
    /// Number of echo requests attempted.
    pub packets_tx: u32,
    /// Number of matching echo replies received.
    pub packets_rx: u32,
    /// Shortest observed round-trip time.
    pub rtt_min: Duration,
    /// Mean round-trip time, truncated to whole nanoseconds.
    pub rtt_avg: Duration,
    /// Longest observed round-trip time.
    pub rtt_max: Duration,
    /// Population standard deviation of the round-trip times, truncated to
    /// whole nanoseconds.
    pub rtt_std_dev: Duration,
}
/// A matched ICMPv4 echo reply, as returned by [`send_icmp_echo_v4`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct IcmpEchoReply {
    /// Source address taken from the reply's IPv4 header.
    pub src_addr: Ipv4Addr,
    /// Length in bytes of the ICMP message (the datagram minus its IPv4 header).
    pub len: usize,
    /// Sequence number echoed back in the reply.
    pub seq: u16,
    /// Time-to-live field of the reply's IPv4 header.
    pub ttl: u8,
    /// Round-trip time, measured against the timestamp echoed in the reply.
    pub rtt: Duration,
}
pub async fn ping<A: ToIpAddr>(
src: A,
dest: A,
count: u32,
interval: Duration,
size: u16,
) -> std::io::Result<PingStats> {
use std::net::IpAddr;
let dest = dest.to_ip_addr().await?;
let ts_len = time::Timestamp::len();
if (size as usize) <= ts_len {
return Err(std::io::Error::new(
std::io::ErrorKind::InvalidInput,
format!("size must be greater than {ts_len} (timestamp bytes)"),
));
}
let payload = generate_payload(size as usize - ts_len);
let tout = Duration::from_secs(5);
let socket = IcmpSocket::bind(src).await?;
socket.connect(dest).await?;
let mut packets_rx: u32 = 0;
let mut rtts: Vec<Duration> = Vec::with_capacity(count as usize);
for seq in 1..=count {
let result = match dest {
IpAddr::V4(_) => send_icmp_echo_v4(&socket, &payload, seq as u16, tout)
.await
.map(|r| r.rtt),
IpAddr::V6(_) => send_icmp_echo_v6(&socket, &payload, seq as u16, tout)
.await
.map(|r| r.rtt),
};
if let Ok(rtt) = result {
packets_rx += 1;
rtts.push(rtt);
}
if seq < count {
tokio::time::sleep(interval).await;
}
}
let packets_tx = count;
let mut stats = compute_rtt_stats(&rtts);
stats.packets_tx = packets_tx;
stats.packets_rx = packets_rx;
Ok(stats)
}
/// Summarises a set of round-trip times.
///
/// Returns a [`PingStats`] whose `rtt_*` fields hold the minimum, mean,
/// maximum and population standard deviation of `rtts`, or all zeros when
/// `rtts` is empty. The mean and standard deviation are computed in whole
/// nanoseconds using truncating integer division and an integer square root.
/// `packets_tx` and `packets_rx` are left at zero for the caller to fill in.
fn compute_rtt_stats(rtts: &[Duration]) -> PingStats {
    /// Builds a `Duration` from a nanosecond count without first narrowing it
    /// to `u64`.
    fn from_nanos_u128(nanos: u128) -> Duration {
        const NANOS_PER_SEC: u128 = 1_000_000_000;
        let secs = u64::try_from(nanos / NANOS_PER_SEC).unwrap_or(u64::MAX);
        // The remainder is below one second, so it always fits in a u32.
        let subsec = (nanos % NANOS_PER_SEC) as u32;
        Duration::new(secs, subsec)
    }

    let (Some(&rtt_min), Some(&rtt_max)) = (rtts.iter().min(), rtts.iter().max()) else {
        return PingStats {
            packets_tx: 0,
            packets_rx: 0,
            rtt_min: Duration::ZERO,
            rtt_avg: Duration::ZERO,
            rtt_max: Duration::ZERO,
            rtt_std_dev: Duration::ZERO,
        };
    };

    let n = rtts.len() as u128;
    // Work in u128: a deviation above ~3.04 s already squares past i64::MAX,
    // and RTTs of several seconds are possible.
    let mean_nanos = rtts.iter().map(Duration::as_nanos).sum::<u128>() / n;
    let variance = rtts
        .iter()
        .map(|d| {
            let dev = d.as_nanos().abs_diff(mean_nanos);
            dev * dev
        })
        .sum::<u128>()
        / n;

    PingStats {
        packets_tx: 0,
        packets_rx: 0,
        rtt_min,
        rtt_avg: from_nanos_u128(mean_nanos),
        rtt_max,
        rtt_std_dev: from_nanos_u128(variance.isqrt()),
    }
}
/// A matched ICMPv6 echo reply, as returned by [`send_icmp_echo_v6`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct IcmpV6EchoReply {
    /// Source address reported by `recvmsg` for the reply.
    pub src_addr: Ipv6Addr,
    /// Length in bytes of the received ICMPv6 message.
    pub len: usize,
    /// Sequence number echoed back in the reply.
    pub seq: u16,
    /// Hop limit of the reply, taken from the `IPV6_HOPLIMIT` ancillary data.
    pub hlim: u8,
    /// Round-trip time, measured against the timestamp echoed in the reply.
    pub rtt: Duration,
}
/// Sends one ICMPv4 echo request on `socket` and waits for its matching reply.
///
/// The request carries `seq`, a fresh identifier from the process-wide
/// counter, the current timestamp and then `payload`, and its checksum is
/// filled in here. Each received datagram is parsed as an IPv4 header (its
/// length taken from the IHL field) followed by the ICMP message. Anything
/// that is not an echo reply with the same identifier, sequence number and
/// echoed timestamp is skipped.
///
/// # Errors
///
/// Propagates I/O errors from sending or receiving. Returns an error of kind
/// [`std::io::ErrorKind::TimedOut`] if no matching reply arrives within `tout`.
pub async fn send_icmp_echo_v4(
    socket: &IcmpSocket,
    payload: &[u8],
    seq: u16,
    tout: Duration,
) -> std::io::Result<IcmpEchoReply> {
    // IHL is a 4-bit count of 32-bit words, so an IPv4 header is at most 60
    // bytes: 20 fixed bytes plus up to 40 bytes of options.
    const MAX_IP_HEADER_SIZE: usize = 15 * 4;

    let ts_len = time::Timestamp::len();
    // Sized so that a reply carrying a maximal IP header still fits untruncated.
    let mut buf: Vec<u8> =
        Vec::with_capacity(MAX_IP_HEADER_SIZE + ICMP_HEADER_SIZE + ts_len + payload.len());
    let req_id = REQ_ID.fetch_add(1, Ordering::Relaxed);
    add_icmp_header(&mut buf, ICMP_ECHO_REQUEST, req_id, seq);
    let sent_ts_bytes = time::Timestamp::now().as_bytes();
    buf.extend_from_slice(&sent_ts_bytes);
    buf.extend_from_slice(payload);
    let checksum = calculate_checksum(&buf);
    buf[2..4].copy_from_slice(&checksum.to_be_bytes());
    socket.send(&buf).await?;

    let overall = timeout(tout, async {
        loop {
            buf.clear();
            let received = socket.recv(buf.spare_capacity_mut()).await?;
            // SAFETY: assumes `IcmpSocket::recv` initialises the first
            // `received` bytes of the spare capacity it is given and never
            // reports more than that capacity. TODO confirm in `net`.
            unsafe { buf.set_len(received) };
            if received < IP_HEADER_SIZE {
                continue;
            }
            // The low nibble of the first byte (IHL) gives the real header
            // length, which exceeds 20 bytes when IP options are present.
            let ip_hdr_len = usize::from(buf[0] & 0x0f) * 4;
            let ts_start = ip_hdr_len + ICMP_HEADER_SIZE;
            let ts_end = ts_start + ts_len;
            if ip_hdr_len < IP_HEADER_SIZE || received < ts_end {
                continue;
            }
            let icmp = &buf[ip_hdr_len..];
            if icmp[0] != ICMP_ECHO_REPLY {
                continue;
            }
            let reply_id = u16::from_be_bytes([icmp[4], icmp[5]]);
            if reply_id != req_id {
                continue;
            }
            let reply_seq = u16::from_be_bytes([icmp[6], icmp[7]]);
            if reply_seq != seq {
                continue;
            }
            if buf[ts_start..ts_end] != sent_ts_bytes {
                continue;
            }
            let now = time::Timestamp::now();
            // Fixed IPv4 header fields: TTL at byte 8, source address at 12..16.
            let src_addr = Ipv4Addr::new(buf[12], buf[13], buf[14], buf[15]);
            let reply_ttl = buf[8];
            let reply_ts =
                time::Timestamp::from(<[u8; 8]>::try_from(&buf[ts_start..ts_end]).unwrap());
            let rtt = now - reply_ts;
            return Ok(IcmpEchoReply {
                src_addr,
                len: received - ip_hdr_len,
                seq: reply_seq,
                ttl: reply_ttl,
                rtt,
            });
        }
    });
    match overall.await {
        Ok(result) => result,
        Err(_) => Err(std::io::Error::new(
            std::io::ErrorKind::TimedOut,
            "timed out",
        )),
    }
}
/// Sends one ICMPv6 echo request on `socket` and waits for its matching reply.
///
/// The request carries `seq`, a fresh identifier from the process-wide
/// counter, the current timestamp and then `payload`. Replies are read with
/// `recvmsg` so that the source address and the `IPV6_HOPLIMIT` ancillary data
/// can be captured. Received data is parsed as a bare ICMPv6 message with no
/// IP header in front. Anything that is not an echo reply with the same
/// identifier, sequence number and echoed timestamp is skipped.
///
/// # Errors
///
/// - Propagates I/O errors from sending or receiving.
/// - Returns an error if the control buffer was truncated (`MSG_CTRUNC`).
/// - Returns an error of kind [`std::io::ErrorKind::InvalidData`] if a matching
///   reply has no IPv6 source address or no hop-limit control message.
/// - Returns an error of kind [`std::io::ErrorKind::TimedOut`] if no matching
///   reply arrives within `tout`.
pub async fn send_icmp_echo_v6(
    socket: &IcmpSocket,
    payload: &[u8],
    seq: u16,
    tout: Duration,
) -> std::io::Result<IcmpV6EchoReply> {
    // Capacity matches the request size, which is also the size of the reply
    // being waited for.
    let mut buf: Vec<u8> =
        Vec::with_capacity(ICMP_HEADER_SIZE + time::Timestamp::len() + payload.len());
    let req_id = REQ_ID.fetch_add(1, Ordering::Relaxed);
    add_icmp_header(&mut buf, ICMP6_ECHO_REQUEST, req_id, seq);
    let sent_ts_bytes = time::Timestamp::now().as_bytes();
    buf.extend_from_slice(&sent_ts_bytes);
    buf.extend_from_slice(payload);
    // The checksum field is left as written by `add_icmp_header`.
    // NOTE(review): presumably the kernel fills in the ICMPv6 checksum for
    // this socket type; confirm against `IcmpSocket` setup.
    socket.send(&buf).await?;
    // Overwritten by each `recvmsg` call with the sender's address.
    let mut from: SockAddr = SocketAddrV6::new(Ipv6Addr::UNSPECIFIED, 0u16, 0, 0).into();
    // Ancillary-data buffer (64 bytes). Declared as u64 elements, presumably so
    // that it is aligned enough for `cmsghdr`; it is viewed as bytes below.
    let mut control_storage: [MaybeUninit<u64>; 8] = [MaybeUninit::uninit(); 8];
    let overall = timeout(tout, async {
        loop {
            buf.clear();
            // The inner scope ends every borrow held by `msg` (on `buf`, `from`
            // and `control_storage`) before `buf` is touched again below.
            let (received, flags, reply_hlim_opt) = {
                let bufs = &mut [MaybeUninitSlice::new(buf.spare_capacity_mut())];
                // SAFETY: reinterprets the `[MaybeUninit<u64>; 8]` storage as
                // the same number of bytes of `MaybeUninit<u8>`. The pointer
                // is valid for that length, u8 has no alignment requirement,
                // and the slice is the only live access to the storage in this
                // scope.
                let control_bytes: &mut [MaybeUninit<u8>] = unsafe {
                    std::slice::from_raw_parts_mut(
                        control_storage.as_mut_ptr().cast::<MaybeUninit<u8>>(),
                        std::mem::size_of_val(&control_storage),
                    )
                };
                let mut msg = MsgHdrMut::new()
                    .with_addr(&mut from)
                    .with_control(control_bytes)
                    .with_buffers(bufs);
                let received = socket.recvmsg(&mut msg).await?;
                let flags = msg.flags();
                let hlim = decode_hlim(&msg);
                (received, flags, hlim)
            };
            // SAFETY: assumes `recvmsg` initialised `received` bytes of the
            // spare capacity passed in `bufs` and never reports more than that
            // capacity. TODO confirm in `net`.
            unsafe { buf.set_len(received) };
            // A truncated control buffer may have dropped the hop-limit entry,
            // so this aborts the wait instead of skipping the datagram.
            if flags & libc::MSG_CTRUNC != 0 {
                return Err(std::io::Error::other(
                    "recvmsg control buffer truncated (MSG_CTRUNC)",
                ));
            }
            if received < ICMP_HEADER_SIZE + time::Timestamp::len() {
                continue;
            }
            let msg_type = buf[0];
            if msg_type != ICMP6_ECHO_REPLY {
                continue;
            }
            let reply_id = u16::from_be_bytes([buf[4], buf[5]]);
            if req_id != reply_id {
                continue;
            }
            let reply_seq = u16::from_be_bytes([buf[6], buf[7]]);
            if reply_seq != seq {
                continue;
            }
            // The echoed timestamp must be the one this call sent.
            let ts_end = ICMP_HEADER_SIZE + time::Timestamp::len();
            if buf[ICMP_HEADER_SIZE..ts_end] != sent_ts_bytes {
                continue;
            }
            let now = time::Timestamp::now();
            let src_addr = from.as_socket_ipv6().map(|s| *s.ip()).ok_or_else(|| {
                std::io::Error::new(
                    std::io::ErrorKind::InvalidData,
                    "recvmsg returned no source address",
                )
            })?;
            let reply_hlim = reply_hlim_opt.ok_or_else(|| {
                std::io::Error::new(
                    std::io::ErrorKind::InvalidData,
                    "reply missing IPV6_HOPLIMIT control message",
                )
            })?;
            let reply_ts =
                time::Timestamp::from(<[u8; 8]>::try_from(&buf[ICMP_HEADER_SIZE..ts_end]).unwrap());
            let rtt = now - reply_ts;
            return Ok(IcmpV6EchoReply {
                src_addr,
                len: received,
                seq: reply_seq,
                hlim: reply_hlim,
                rtt,
            });
        }
    });
    match overall.await {
        Ok(result) => result,
        Err(_) => Err(std::io::Error::new(
            std::io::ErrorKind::TimedOut,
            "timed out",
        )),
    }
}
/// Builds a filler payload of `size` bytes.
///
/// The bytes count upward from 0 and wrap after 255, so byte `i` equals
/// `i % 256`.
pub fn generate_payload(size: usize) -> Vec<u8> {
    (0..=u8::MAX).cycle().take(size).collect()
}
/// Appends an 8-byte ICMP echo header to `buf`.
///
/// Layout: type (`typ`), code (0), checksum (0, for the caller to fill in if
/// needed), identifier (`id`) and sequence number (`seq`). The last two are
/// written in network (big-endian) byte order on every host.
fn add_icmp_header(buf: &mut Vec<u8>, typ: u8, id: u16, seq: u16) {
    // Type, code and a zeroed checksum placeholder.
    buf.extend_from_slice(&[typ, 0, 0, 0]);
    // `to_be_bytes` operates on the numeric value, so the wire order is the
    // same on little- and big-endian hosts; no per-endian branching is needed.
    buf.extend_from_slice(&id.to_be_bytes());
    buf.extend_from_slice(&seq.to_be_bytes());
}
/// Computes the Internet checksum (RFC 1071) of `data`.
///
/// `data` is summed as big-endian 16-bit words. An odd trailing byte is
/// treated as the high byte of a final word padded with zero. Carries are
/// folded back into the low 16 bits and the one's complement of the result is
/// returned. Empty input yields `0xffff`.
fn calculate_checksum(data: &[u8]) -> u16 {
    let words = data.chunks_exact(2);
    let trailing = words.remainder();
    // A u64 accumulator leaves ample headroom for any packet-sized input.
    let mut sum: u64 = words
        .map(|w| u64::from(u16::from_be_bytes([w[0], w[1]])))
        .sum();
    if let [last] = trailing {
        sum += u64::from(*last) << 8;
    }
    while sum >> 16 != 0 {
        sum = (sum & 0xffff) + (sum >> 16);
    }
    // After folding, `sum` fits in 16 bits, so the cast is lossless.
    !(sum as u16)
}
/// Extracts the hop limit from an `IPV6_HOPLIMIT` control message in `msg`.
///
/// Walks the ancillary-data headers of `msg` and reads the first
/// `IPPROTO_IPV6` / `IPV6_HOPLIMIT` entry whose length is large enough to hold
/// a `c_int`. Returns `None` if there is no such entry, or if its value does
/// not fit in a `u8`.
fn decode_hlim(msg: &MsgHdrMut<'_, '_, '_>) -> Option<u8> {
    // NOTE(review): assumes `as_msghdr` exposes the msghdr as updated by
    // `recvmsg` (control pointer and length); confirm in `net`.
    let hdr = msg.as_msghdr();
    // Minimum `cmsg_len` for a header followed by one `c_int` of data.
    // SAFETY: CMSG_LEN only performs size arithmetic on its argument.
    let want_len = unsafe { libc::CMSG_LEN(size_of::<libc::c_int>() as u32) } as usize;
    // SAFETY: `hdr` describes the control buffer, presumably still borrowed
    // through `msg` for the duration of this call. CMSG_FIRSTHDR returns null
    // when there is no complete header.
    let mut p = unsafe { libc::CMSG_FIRSTHDR(hdr) };
    while !p.is_null() {
        // SAFETY: non-null results of CMSG_FIRSTHDR/CMSG_NXTHDR point at a
        // complete `cmsghdr` inside the control buffer.
        let h = unsafe { &*p };
        if h.cmsg_level == libc::IPPROTO_IPV6
            && h.cmsg_type == libc::IPV6_HOPLIMIT
            && h.cmsg_len as usize >= want_len
        {
            let mut value = MaybeUninit::<libc::c_int>::uninit();
            // SAFETY: `cmsg_len >= want_len` guarantees a full `c_int` of data
            // follows the header. The byte-wise copy makes no assumption about
            // the alignment of CMSG_DATA, and every byte of `value` is written
            // before `assume_init`.
            let hlim = unsafe {
                std::ptr::copy_nonoverlapping(
                    libc::CMSG_DATA(p),
                    value.as_mut_ptr().cast::<u8>(),
                    size_of::<libc::c_int>(),
                );
                value.assume_init()
            };
            // A value outside 0..=255 is reported as absent rather than truncated.
            return u8::try_from(hlim).ok();
        }
        // SAFETY: `p` is a valid header within the buffer described by `hdr`.
        p = unsafe { libc::CMSG_NXTHDR(hdr, p) };
    }
    None
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns a new buffer holding only the header written by `add_icmp_header`.
    fn header(typ: u8, id: u16, seq: u16) -> Vec<u8> {
        let mut out = Vec::new();
        add_icmp_header(&mut out, typ, id, seq);
        out
    }

    /// Shorthand for a millisecond `Duration`.
    fn ms(millis: u64) -> Duration {
        Duration::from_millis(millis)
    }

    /// Asserts all four RTT statistics of `stats` at once.
    fn assert_rtts(
        stats: &PingStats,
        min: Duration,
        avg: Duration,
        max: Duration,
        std_dev: Duration,
    ) {
        assert_eq!(stats.rtt_min, min);
        assert_eq!(stats.rtt_avg, avg);
        assert_eq!(stats.rtt_max, max);
        assert_eq!(stats.rtt_std_dev, std_dev);
    }

    #[test]
    fn add_icmp_header_writes_8_bytes() {
        assert_eq!(header(8, 0, 0).len(), 8);
    }

    #[test]
    fn add_icmp_header_type_field() {
        assert_eq!(header(0x08, 0, 0)[0], 0x08);
    }

    #[test]
    fn add_icmp_header_code_is_zero() {
        assert_eq!(header(8, 0xffff, 0xffff)[1], 0);
    }

    #[test]
    fn add_icmp_header_id_big_endian() {
        let buf = header(8, 0x1234, 0);
        assert_eq!(&buf[4..6], &[0x12, 0x34]);
    }

    #[test]
    fn add_icmp_header_seq_big_endian() {
        let buf = header(8, 0, 1);
        assert_eq!(&buf[6..8], &[0, 1]);
    }

    #[test]
    fn add_icmp_header_appends_to_existing_content() {
        let mut buf = vec![0xde, 0xad];
        add_icmp_header(&mut buf, 8, 0, 0);
        assert_eq!(buf.len(), 10);
        assert!(buf.starts_with(&[0xde, 0xad]));
    }

    #[test]
    fn test_checksum() {
        let cases: [(&[u8], u16); 2] = [
            (&[0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00], 0xf7ff),
            (&[0x00, 0x01, 0x02], 0xfdfe),
        ];
        for (data, expected) in cases {
            assert_eq!(calculate_checksum(data), expected);
        }
    }

    #[test]
    fn test_compute_rtt_stats_empty() {
        let stats = compute_rtt_stats(&[]);
        assert_rtts(
            &stats,
            Duration::ZERO,
            Duration::ZERO,
            Duration::ZERO,
            Duration::ZERO,
        );
    }

    #[test]
    fn test_compute_rtt_stats_single() {
        let stats = compute_rtt_stats(&[ms(10)]);
        assert_rtts(&stats, ms(10), ms(10), ms(10), Duration::ZERO);
    }

    #[test]
    fn test_compute_rtt_stats_multiple() {
        let stats = compute_rtt_stats(&[ms(10), ms(20), ms(30)]);
        let mean: i64 = 20_000_000;
        let variance = [10_000_000_i64, 20_000_000, 30_000_000]
            .iter()
            .map(|&sample| ((sample - mean) * (sample - mean)) as u64)
            .sum::<u64>()
            / 3;
        assert_rtts(
            &stats,
            ms(10),
            ms(20),
            ms(30),
            Duration::from_nanos(variance.isqrt()),
        );
    }

    #[test]
    fn test_compute_rtt_stats_identical() {
        let stats = compute_rtt_stats(&[ms(5); 3]);
        assert_rtts(&stats, ms(5), ms(5), ms(5), Duration::ZERO);
    }

    #[tokio::test]
    async fn test_send_icmp_echo_v4() {
        let sock = IcmpSocket::bind(Ipv4Addr::UNSPECIFIED).await.unwrap();
        sock.connect("127.0.0.1").await.unwrap();
        let reply = send_icmp_echo_v4(&sock, &generate_payload(48), 1, Duration::from_secs(5))
            .await
            .unwrap();
        assert_eq!(reply.src_addr, Ipv4Addr::LOCALHOST);
        assert_eq!(reply.len, 64);
        assert_eq!(reply.seq, 1);
        assert!(reply.ttl > 0);
        assert!(reply.rtt > Duration::ZERO);
    }

    /// Sends one empty-payload echo request to `::1` on a fresh socket.
    async fn echo_v6_localhost() -> IcmpV6EchoReply {
        let sock = IcmpSocket::bind(Ipv6Addr::UNSPECIFIED).await.unwrap();
        sock.connect("::1").await.unwrap();
        send_icmp_echo_v6(&sock, &[], 1, Duration::from_secs(5))
            .await
            .unwrap()
    }

    /// Checks the fields expected from an empty-payload echo to `::1`.
    fn assert_v6_localhost_reply(reply: &IcmpV6EchoReply) {
        assert_eq!(reply.src_addr, Ipv6Addr::LOCALHOST);
        assert_eq!(reply.len, 16);
        assert_eq!(reply.seq, 1);
        assert!(reply.hlim > 0);
        assert!(reply.rtt > Duration::ZERO);
    }

    #[tokio::test]
    async fn test_send_icmp_echo_v6() {
        assert_v6_localhost_reply(&echo_v6_localhost().await);
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 1)]
    async fn test_send_icmp_echo_v6_send() {
        // `spawn` only accepts `Send` futures, so this also checks that the
        // echo future can move between threads.
        let reply = tokio::task::spawn(echo_v6_localhost()).await.unwrap();
        assert_v6_localhost_reply(&reply);
    }
}