use crate::ping_error::PingError;
use log::trace;
use socket2::{Domain, Protocol, SockAddr, Socket, Type};
use std::mem::MaybeUninit;
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
use std::sync::atomic::{AtomicU16, Ordering};
use std::time::Duration;
/// Length in bytes of an option-less IPv4 header (IHL = 5 words of 4 bytes).
///
/// Despite the name, this is the IP header that a raw IPv4 socket places in
/// front of every received ICMP message, not an ICMP header.
const ICMP_V4_HEADER_LENGTH: usize = 5 * 4;
/// Computes the RFC 1071 Internet checksum of `buf`.
///
/// The result is the one's complement of the one's-complement sum of the
/// big-endian 16-bit words in `buf`. An odd trailing byte is treated as the
/// high byte of a final word padded with zero.
///
/// The accumulator is 64-bit, so no carry is ever lost before folding, even
/// for buffers far larger than any ICMP packet.
fn checksum_v4(buf: &[u8]) -> u16 {
    let words = buf.chunks_exact(2);
    let tail = words.remainder();
    let mut sum: u64 = words
        .map(|pair| u64::from(u16::from_be_bytes([pair[0], pair[1]])))
        .sum();
    if let [last] = tail {
        sum += u64::from(*last) << 8;
    }
    // End-around carry: fold everything above bit 15 back into the low word.
    while sum >> 16 != 0 {
        sum = (sum & 0xffff) + (sum >> 16);
    }
    // After folding, `sum` fits in 16 bits, so the cast cannot truncate.
    !(sum as u16)
}
/// Builds an 8-byte ICMPv4 Echo Request (type 8, code 0) with no payload.
///
/// The request carries `id` and `seq` in network byte order. The checksum is
/// computed with `checksum_v4` and stored in bytes 2..4.
fn build_icmp_v4_echo(id: u16, seq: u16) -> Vec<u8> {
    let [id_hi, id_lo] = id.to_be_bytes();
    let [seq_hi, seq_lo] = seq.to_be_bytes();
    // The checksum field (bytes 2..4) must be zero while the checksum is computed.
    let mut packet = vec![8u8, 0, 0, 0, id_hi, id_lo, seq_hi, seq_lo];
    let [sum_hi, sum_lo] = checksum_v4(&packet).to_be_bytes();
    packet[2] = sum_hi;
    packet[3] = sum_lo;
    packet
}
/// Builds an 8-byte ICMPv6 Echo Request (type 128, code 0) with no payload.
///
/// The request carries `id` and `seq` in network byte order. The checksum
/// field is left as zero.
// NOTE(review): presumably the kernel fills in the ICMPv6 checksum for raw
// ICMPv6 sockets (RFC 3542); confirm on every target platform.
fn build_icmp_v6_echo(id: u16, seq: u16) -> Vec<u8> {
    let mut packet = Vec::with_capacity(8);
    // Type, code, and a zero checksum placeholder.
    packet.extend_from_slice(&[128, 0, 0, 0]);
    packet.extend_from_slice(&id.to_be_bytes());
    packet.extend_from_slice(&seq.to_be_bytes());
    packet
}
/// ICMP echo ("ping") client.
///
/// Every echo request sent through one instance uses the same identifier and
/// takes its sequence number from an atomic counter. Because of that counter,
/// sending only needs a shared reference (`&self`).
#[derive(Debug)]
pub struct IcmpPing {
    /// ICMP identifier written into every echo request.
    id: u16,
    /// Last sequence number handed out; bumped before each request is built.
    seq: AtomicU16,
}
impl Clone for IcmpPing {
fn clone(&self) -> Self {
Self {
id: self.id,
seq: AtomicU16::new(self.seq.load(Ordering::Relaxed)),
}
}
}
impl IcmpPing {
    /// Smallest timeout ever handed to the socket.
    ///
    /// `socket2` encodes timeouts in whole microseconds and treats a zero
    /// value as "no timeout" (block forever). A tiny remaining budget must
    /// therefore never be passed through unclamped.
    const MIN_SOCKET_TIMEOUT: Duration = Duration::from_millis(1);

    /// Creates a pinger whose ICMP identifier is derived from the process id
    /// and whose sequence counter starts at zero, so the first request uses
    /// sequence number 1.
    pub fn new() -> Self {
        Self {
            // Truncation is intentional: the ICMP identifier field is 16 bits wide.
            id: std::process::id() as u16,
            seq: 0.into(),
        }
    }

    /// Sends one ICMP echo request to `dst_ip` and waits up to `timeout` for
    /// the matching echo reply.
    ///
    /// A raw ICMP socket sees every ICMP packet delivered to the host. That
    /// includes other pingers' replies, IPv6 neighbour discovery, and, on
    /// loopback, our own echo request. Packets that are not the echo reply to
    /// this request (matching type, identifier, and sequence number) are
    /// skipped, and the wait continues until the deadline.
    ///
    /// # Errors
    ///
    /// Returns an error in these cases:
    /// - the raw socket cannot be created, configured, bound, or written to;
    /// - receiving fails;
    /// - no matching reply arrives within `timeout`. This is reported either
    ///   as the OS's receive-timeout error or as an I/O error of kind
    ///   `TimedOut`.
    pub fn ping(&self, dst_ip: IpAddr, timeout: Duration) -> Result<(), PingError> {
        trace!("ping {} ....", dst_ip);
        let (domain, protocol, src_ip) = match dst_ip {
            IpAddr::V4(_) => (
                Domain::IPV4,
                Protocol::ICMPV4,
                IpAddr::V4(Ipv4Addr::UNSPECIFIED),
            ),
            IpAddr::V6(_) => (
                Domain::IPV6,
                Protocol::ICMPV6,
                IpAddr::V6(Ipv6Addr::UNSPECIFIED),
            ),
        };
        let sock = Socket::new(domain, Type::RAW, Some(protocol))?;
        sock.set_write_timeout(Some(timeout.max(Self::MIN_SOCKET_TIMEOUT)))?;
        let src_addr = SockAddr::from(SocketAddr::new(src_ip, 0));
        sock.bind(&src_addr)?;
        let dst_addr = SockAddr::from(SocketAddr::new(dst_ip, 0));

        let seq = self.seq.fetch_add(1, Ordering::Relaxed).wrapping_add(1);
        let packet = match dst_ip {
            IpAddr::V4(_) => build_icmp_v4_echo(self.id, seq),
            IpAddr::V6(_) => build_icmp_v6_echo(self.id, seq),
        };
        sock.send_to(&packet, &dst_addr)?;

        let started = std::time::Instant::now();
        let mut buf = [MaybeUninit::<u8>::uninit(); 1024];
        loop {
            let remaining = match timeout
                .checked_sub(started.elapsed())
                .filter(|left| *left > Duration::from_secs(0))
            {
                Some(left) => left,
                None => {
                    return Err(std::io::Error::new(
                        std::io::ErrorKind::TimedOut,
                        format!("no ICMP echo reply from {} within {:?}", dst_ip, timeout),
                    )
                    .into());
                }
            };
            // Re-arm the receive timeout with whatever budget is left, so that
            // unrelated packets cannot extend the total wait beyond `timeout`.
            sock.set_read_timeout(Some(remaining.max(Self::MIN_SOCKET_TIMEOUT)))?;
            let (len, _) = sock.recv_from(&mut buf)?;
            // SAFETY: `recv_from` returns the number of bytes it wrote into
            // the start of `buf`. The first `len` elements are therefore
            // initialized, and `len <= buf.len()`.
            let reply = unsafe { std::slice::from_raw_parts(buf.as_ptr() as *const u8, len) };
            if Self::is_echo_reply(dst_ip, reply, &packet) {
                break;
            }
            trace!("ping {}: ignoring unrelated ICMP packet {:?}", dst_ip, reply);
        }
        trace!("ping {} success", dst_ip);
        Ok(())
    }

    /// Returns `true` when `reply`, as read from the raw socket, is the ICMP
    /// echo reply to `request` for `dst_ip`'s address family.
    ///
    /// The check requires:
    /// - the echo-reply type with code 0;
    /// - identifier and sequence number (bytes 4..8) equal to the request's.
    fn is_echo_reply(dst_ip: IpAddr, reply: &[u8], request: &[u8]) -> bool {
        // "Echo Reply" type codes: ICMPv4 (RFC 792) and ICMPv6 (RFC 4443).
        const ICMP_V4_ECHO_REPLY: u8 = 0;
        const ICMP_V6_ECHO_REPLY: u8 = 129;

        let (icmp, expected_type) = match dst_ip {
            IpAddr::V4(_) => {
                // Raw IPv4 sockets also return the IP header. Its length is
                // IHL * 4 bytes: 20 without options, more when options are present.
                let header_len = match reply.first() {
                    Some(&version_ihl) => usize::from(version_ihl & 0x0f) * 4,
                    None => return false,
                };
                if header_len < ICMP_V4_HEADER_LENGTH {
                    return false;
                }
                match reply.get(header_len..) {
                    Some(icmp) => (icmp, ICMP_V4_ECHO_REPLY),
                    None => return false,
                }
            }
            // Raw ICMPv6 sockets deliver only the ICMPv6 message, no IP header.
            IpAddr::V6(_) => (reply, ICMP_V6_ECHO_REPLY),
        };
        icmp.len() >= 8
            && request.len() >= 8
            && icmp[0] == expected_type
            && icmp[1] == 0
            && icmp[4..8] == request[4..8]
    }
}

impl Default for IcmpPing {
    /// Equivalent to [`IcmpPing::new`].
    fn default() -> Self {
        Self::new()
    }
}