use alloc::vec;
use core::{
net::{IpAddr, Ipv4Addr},
time::Duration,
};
use netcore::smoltcp::{
phy::ChecksumCapabilities,
wire::{IPV4_HEADER_LEN, Icmpv4Packet, Icmpv4Repr, IpProtocol, Ipv4Packet, Ipv4Repr},
};
use crate::CreateSocket;
#[derive(Debug)]
pub enum PingError {
Timeout,
Ipv6Unsupported,
Net(netcore::Error),
}
impl core::fmt::Display for PingError {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
match self {
Self::Timeout => f.write_str("ping timed out"),
Self::Ipv6Unsupported => f.write_str("ICMPv6 ping is unsupported (IPv6 is off)"),
Self::Net(e) => write!(f, "netstack error: {e}"),
}
}
}
impl core::error::Error for PingError {
fn source(&self) -> Option<&(dyn core::error::Error + 'static)> {
match self {
Self::Net(e) => Some(e),
_ => None,
}
}
}
impl From<netcore::Error> for PingError {
fn from(e: netcore::Error) -> Self {
Self::Net(e)
}
}
/// Fixed payload carried by every echo request this module sends.
const PING_PAYLOAD: &[u8] = b"ts_netstack_smoltcp ping";

/// Wrapping counter that makes successive [`next_ident`] results distinct.
#[cfg(feature = "tokio")]
static PING_IDENT_COUNTER: core::sync::atomic::AtomicU16 = core::sync::atomic::AtomicU16::new(0);

/// Produces an ICMP echo identifier for a new ping.
///
/// The identifier is a per-process seed XOR-ed with a wrapping counter. Because
/// XOR with a fixed seed is a bijection, 65 536 consecutive calls in one
/// process never repeat a value.
#[cfg(feature = "tokio")]
fn next_ident() -> u16 {
    use core::sync::atomic::Ordering;

    // Deliberate truncation: only the high byte of the pid's low 16 bits is
    // kept as a seed, leaving the full 16 bits free for the counter to vary.
    let pid_seed = (std::process::id() as u16) & 0xFF00;
    // Relaxed is enough: only the atomicity of the increment matters, not ordering.
    let tick = PING_IDENT_COUNTER.fetch_add(1, Ordering::Relaxed);
    tick ^ pid_seed
}
/// Sends one ICMPv4 echo request from `src` to `dst` and waits for the matching
/// echo reply, returning the measured round-trip time.
///
/// The whole exchange (sending the request and receiving the reply) must
/// finish within `timeout`. Very large timeouts, such as [`Duration::MAX`], are
/// accepted and do not panic.
///
/// # Errors
///
/// - [`PingError::Ipv6Unsupported`] if `dst` is an IPv6 address.
/// - [`PingError::Timeout`] if no matching reply arrives within `timeout`.
/// - [`PingError::Net`] if opening the raw socket, sending, or receiving fails.
#[cfg(feature = "tokio")]
pub async fn ping<C: CreateSocket + Sync>(
    chan: &C,
    src: Ipv4Addr,
    dst: IpAddr,
    timeout: Duration,
) -> Result<Duration, PingError> {
    let IpAddr::V4(dst) = dst else {
        return Err(PingError::Ipv6Unsupported);
    };

    let ident = next_ident();
    let seq_no: u16 = 1;
    let sock = chan.raw_open(true, IpProtocol::Icmp).await?;
    let request = build_echo_request(src, dst, ident, seq_no, PING_PAYLOAD);

    let start = tokio::time::Instant::now();
    let exchange = async {
        sock.send(&request).await?;
        // The raw socket may hand us unrelated ICMP traffic (other pings, stale
        // replies); keep reading until our own reply shows up.
        loop {
            let bytes = sock.recv_bytes().await?;
            if matches_reply(&bytes, src, dst, ident, seq_no) {
                break Ok::<Duration, PingError>(start.elapsed());
            }
        }
    };

    // `tokio::time::timeout` saturates an overflowing deadline to "far future".
    // Computing `start + timeout` by hand would panic for huge timeouts.
    tokio::time::timeout(timeout, exchange)
        .await
        .unwrap_or_else(|_elapsed| Err(PingError::Timeout))
}
/// Serialises a complete IPv4 packet carrying an ICMPv4 echo request.
///
/// Layout: an IPv4 header (`IPV4_HEADER_LEN` bytes, TTL 64, protocol ICMP)
/// followed by the ICMP echo request (`ident`, `seq_no`, then `payload`).
/// Both headers are emitted with the default `ChecksumCapabilities`, under
/// which smoltcp fills in the checksums. The buffer can therefore be written
/// to a raw socket as-is.
fn build_echo_request(
    src: Ipv4Addr,
    dst: Ipv4Addr,
    ident: u16,
    seq_no: u16,
    payload: &[u8],
) -> vec::Vec<u8> {
    let caps = ChecksumCapabilities::default();

    let echo = Icmpv4Repr::EchoRequest {
        ident,
        seq_no,
        data: payload,
    };
    let icmp_len = echo.buffer_len();

    let header = Ipv4Repr {
        src_addr: src,
        dst_addr: dst,
        next_header: IpProtocol::Icmp,
        payload_len: icmp_len,
        hop_limit: 64,
    };

    let mut packet = vec![0u8; IPV4_HEADER_LEN + icmp_len];
    // IPv4 header goes over the whole buffer; the ICMP message starts right after it.
    header.emit(&mut Ipv4Packet::new_unchecked(&mut packet[..]), &caps);
    echo.emit(
        &mut Icmpv4Packet::new_unchecked(&mut packet[IPV4_HEADER_LEN..]),
        &caps,
    );
    packet
}
/// Returns `true` if `bytes` is the IPv4/ICMPv4 echo reply to a request that was
/// sent from `expect_src` to `expect_dst` with the given `ident` and `seq_no`.
///
/// The reply must come *from* `expect_dst` and be addressed *to* `expect_src`,
/// which are the request's addresses swapped. Anything else yields `false` rather
/// than an error, so callers can skip unrelated traffic. This includes:
/// - packets that fail to parse (truncated, or with a bad checksum under the
///   default checksum capabilities);
/// - non-ICMP protocols;
/// - other ICMP message types;
/// - mismatched identifiers.
fn matches_reply(
    bytes: &[u8],
    expect_src: Ipv4Addr,
    expect_dst: Ipv4Addr,
    ident: u16,
    seq_no: u16,
) -> bool {
    /// Parses `bytes` as an echo reply travelling `from` → `to` and yields its
    /// `(ident, seq_no)` pair, or `None` for anything else.
    fn echo_reply_ids(bytes: &[u8], from: Ipv4Addr, to: Ipv4Addr) -> Option<(u16, u16)> {
        let caps = ChecksumCapabilities::default();

        let ip = Ipv4Packet::new_checked(bytes).ok()?;
        let ip_repr = Ipv4Repr::parse(&ip, &caps).ok()?;
        let is_expected_flow = ip_repr.next_header == IpProtocol::Icmp
            && ip_repr.src_addr == from
            && ip_repr.dst_addr == to;
        if !is_expected_flow {
            return None;
        }

        let icmp = Icmpv4Packet::new_checked(ip.payload()).ok()?;
        match Icmpv4Repr::parse(&icmp, &caps).ok()? {
            Icmpv4Repr::EchoReply { ident, seq_no, .. } => Some((ident, seq_no)),
            _ => None,
        }
    }

    echo_reply_ids(bytes, expect_dst, expect_src) == Some((ident, seq_no))
}
#[cfg(test)]
mod tests {
    use super::*;

    const SRC: Ipv4Addr = Ipv4Addr::new(100, 64, 0, 1);
    const DST: Ipv4Addr = Ipv4Addr::new(100, 64, 0, 2);
    /// An address that is neither endpoint of the simulated ping.
    const STRANGER: Ipv4Addr = Ipv4Addr::new(100, 64, 0, 3);

    /// Builds a well-formed IPv4 + ICMPv4 echo reply from `from` to `to` that
    /// carries `PING_PAYLOAD`, the way a peer would answer our request.
    fn build_echo_reply(from: Ipv4Addr, to: Ipv4Addr, ident: u16, seq_no: u16) -> vec::Vec<u8> {
        let checksum_caps = ChecksumCapabilities::default();
        let icmp_repr = Icmpv4Repr::EchoReply {
            ident,
            seq_no,
            data: PING_PAYLOAD,
        };
        let ipv4_repr = Ipv4Repr {
            src_addr: from,
            dst_addr: to,
            next_header: IpProtocol::Icmp,
            payload_len: icmp_repr.buffer_len(),
            hop_limit: 64,
        };
        let mut buf = vec![0u8; IPV4_HEADER_LEN + icmp_repr.buffer_len()];
        {
            let mut p = Ipv4Packet::new_unchecked(&mut buf[..]);
            ipv4_repr.emit(&mut p, &checksum_caps);
        }
        {
            let mut p = Icmpv4Packet::new_unchecked(&mut buf[IPV4_HEADER_LEN..]);
            icmp_repr.emit(&mut p, &checksum_caps);
        }
        buf
    }

    #[test]
    fn matches_reply_accepts_matching_ident_and_seq() {
        let reply = build_echo_reply(DST, SRC, 0xABCD, 7);
        assert!(matches_reply(&reply, SRC, DST, 0xABCD, 7));
    }

    #[test]
    fn matches_reply_rejects_foreign_ident() {
        let foreign = build_echo_reply(DST, SRC, 0x1111, 7);
        assert!(!matches_reply(&foreign, SRC, DST, 0xABCD, 7));
    }

    #[test]
    fn matches_reply_rejects_foreign_seq() {
        let foreign = build_echo_reply(DST, SRC, 0xABCD, 99);
        assert!(!matches_reply(&foreign, SRC, DST, 0xABCD, 7));
    }

    #[test]
    fn matches_reply_rejects_non_echo_reply() {
        let request = build_echo_request(DST, SRC, 0xABCD, 7, PING_PAYLOAD);
        assert!(!matches_reply(&request, SRC, DST, 0xABCD, 7));
    }

    #[test]
    fn matches_reply_rejects_wrong_source_address() {
        let spoofed = build_echo_reply(STRANGER, SRC, 0xABCD, 7);
        assert!(!matches_reply(&spoofed, SRC, DST, 0xABCD, 7));
    }

    #[test]
    fn matches_reply_rejects_wrong_destination_address() {
        let misdirected = build_echo_reply(DST, STRANGER, 0xABCD, 7);
        assert!(!matches_reply(&misdirected, SRC, DST, 0xABCD, 7));
    }

    #[test]
    fn matches_reply_rejects_empty_and_truncated_input() {
        let reply = build_echo_reply(DST, SRC, 0xABCD, 7);
        assert!(!matches_reply(&[], SRC, DST, 0xABCD, 7));
        assert!(!matches_reply(&reply[..IPV4_HEADER_LEN - 1], SRC, DST, 0xABCD, 7));
        assert!(!matches_reply(&reply[..reply.len() - 1], SRC, DST, 0xABCD, 7));
    }

    #[test]
    fn matches_reply_rejects_corrupted_icmp_checksum() {
        let mut reply = build_echo_reply(DST, SRC, 0xABCD, 7);
        // Flip a payload byte: the IPv4 header stays valid, the ICMP checksum does not.
        let last = reply.len() - 1;
        reply[last] ^= 0xFF;
        assert!(!matches_reply(&reply, SRC, DST, 0xABCD, 7));
    }

    #[test]
    fn build_echo_request_round_trips_through_parser() {
        let request = build_echo_request(SRC, DST, 0x1234, 3, PING_PAYLOAD);
        let caps = ChecksumCapabilities::default();

        let ip = Ipv4Packet::new_checked(&request[..]).expect("valid IPv4 packet");
        let ip_repr = Ipv4Repr::parse(&ip, &caps).expect("valid IPv4 header");
        assert_eq!(ip_repr.src_addr, SRC);
        assert_eq!(ip_repr.dst_addr, DST);
        assert_eq!(ip_repr.next_header, IpProtocol::Icmp);
        assert_eq!(ip_repr.hop_limit, 64);
        assert_eq!(ip_repr.payload_len, request.len() - IPV4_HEADER_LEN);

        let icmp = Icmpv4Packet::new_checked(ip.payload()).expect("valid ICMP packet");
        match Icmpv4Repr::parse(&icmp, &caps).expect("valid ICMP message") {
            Icmpv4Repr::EchoRequest {
                ident,
                seq_no,
                data,
            } => {
                assert_eq!(ident, 0x1234);
                assert_eq!(seq_no, 3);
                assert_eq!(data, PING_PAYLOAD);
            }
            other => panic!("expected an echo request, got {other:?}"),
        }
    }

    #[cfg(feature = "tokio")]
    #[test]
    fn next_ident_is_unique_for_sequential_calls() {
        let a = next_ident();
        let b = next_ident();
        let c = next_ident();
        assert_ne!(a, b);
        assert_ne!(b, c);
        assert_ne!(a, c);
    }

    #[cfg(feature = "tokio")]
    #[test]
    fn next_ident_is_unique_across_threads() {
        const THREADS: usize = 4;
        const PER_THREAD: usize = 64;
        let handles: vec::Vec<_> = (0..THREADS)
            .map(|_| {
                std::thread::spawn(|| {
                    (0..PER_THREAD)
                        .map(|_| next_ident())
                        .collect::<vec::Vec<u16>>()
                })
            })
            .collect();
        let mut idents: vec::Vec<u16> = handles
            .into_iter()
            .flat_map(|h| h.join().expect("ident worker thread panicked"))
            .collect();
        idents.sort_unstable();
        idents.dedup();
        assert_eq!(idents.len(), THREADS * PER_THREAD);
    }
}