// netpulse-cli 0.1.1
//
// A zero-config, single-binary network quality monitor with percentile stats,
// jitter, and MTR-style traceroute.
// Documentation
// src/probers/icmp.rs — Raw ICMP Echo prober
//
// This is low-level systems code: instead of going through a higher-level ping
// API, it opens a raw ICMPv4 socket via socket2 and constructs/parses ICMP
// packets at the byte level. Opening the raw socket requires CAP_NET_RAW or root.

use super::{ProbeResult, Prober};
use crate::error::NetPulseError;
use async_trait::async_trait;
use socket2::{Domain, Protocol, Socket, Type};
use std::net::{IpAddr, SocketAddr};
use std::time::{Duration, Instant};
use tokio::time::timeout;

/// ICMP message type 11, "Time Exceeded" (RFC 792). Its body quotes the IP
/// header and leading bytes of the datagram that triggered it.
const ICMP_TIME_EXCEEDED: u8 = 0x0B;

/// ICMP Echo Request prober using raw IPv4 sockets.
///
/// Every probe sends a single Echo Request and waits at most `timeout_ms`
/// milliseconds for the matching reply (see its [`Prober`] implementation).
#[derive(Debug, Clone)]
pub struct IcmpProber {
    /// Per-probe timeout in milliseconds.
    timeout_ms: u64,
}

impl IcmpProber {
    /// Create a prober whose probes give up after `timeout_ms` milliseconds.
    pub fn new(timeout_ms: u64) -> Self {
        Self { timeout_ms }
    }
}

// ─── ICMP Packet Layout ───────────────────────────────────────────────────────
//
//  0                   1                   2                   3
//  0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
// |     Type (8)  |     Code (0)  |          Checksum             |
// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
// |           Identifier          |        Sequence Number        |
// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
// |                                                               |
// +     Payload: send timestamp, µs since Unix epoch (u64 BE)     +
// |                                                               |
// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+

/// ICMP message type of an Echo Request (RFC 792).
const ICMP_ECHO_REQUEST: u8 = 8;
/// ICMP message type of an Echo Reply (RFC 792).
const ICMP_ECHO_REPLY: u8 = 0;
/// Echo header bytes: type (1) + code (1) + checksum (2) + identifier (2) + sequence (2).
const ICMP_HEADER_SIZE: usize = 1 + 1 + 2 + 2 + 2;
/// Payload bytes: a single big-endian `u64` send timestamp.
const PAYLOAD_SIZE: usize = std::mem::size_of::<u64>();
/// Total length of the Echo Request this prober sends.
const PACKET_SIZE: usize = ICMP_HEADER_SIZE + PAYLOAD_SIZE;

/// Build an ICMP Echo Request carrying identifier `id` and sequence number `seq`.
///
/// Multi-byte fields are written in network (big-endian) byte order. The
/// 8-byte payload is the wall-clock send time in microseconds since the Unix
/// epoch (0 if the system clock reads earlier than the epoch). The checksum is
/// computed with its own field zeroed and then stored in bytes 2..4.
fn build_icmp_packet(id: u16, seq: u16) -> [u8; PACKET_SIZE] {
    let sent_at_us = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .map(|since_epoch| since_epoch.as_micros() as u64)
        .unwrap_or_default();

    // Type = Echo Request; code (byte 1) and checksum (bytes 2..4) start as zero.
    let mut packet = [0u8; PACKET_SIZE];
    packet[0] = ICMP_ECHO_REQUEST;
    packet[4..6].copy_from_slice(&id.to_be_bytes());
    packet[6..8].copy_from_slice(&seq.to_be_bytes());
    packet[8..].copy_from_slice(&sent_at_us.to_be_bytes());

    let checksum = icmp_checksum(&packet);
    packet[2..4].copy_from_slice(&checksum.to_be_bytes());
    packet
}

/// Internet checksum (RFC 1071) of `data`.
///
/// `data` is read as big-endian 16-bit words; a trailing odd byte is treated
/// as the high byte of a word whose low byte is zero. The words are summed,
/// carries are folded back into the low 16 bits, and the one's complement of
/// the result is returned.
fn icmp_checksum(data: &[u8]) -> u16 {
    let mut words = data.chunks_exact(2);
    let mut acc: u32 = words
        .by_ref()
        .map(|pair| u32::from(u16::from_be_bytes([pair[0], pair[1]])))
        .sum();

    if let &[odd] = words.remainder() {
        acc += u32::from(odd) << 8;
    }

    // End-around carry: keep folding until nothing spills above bit 15.
    while acc > 0xFFFF {
        acc = (acc & 0xFFFF) + (acc >> 16);
    }

    !(acc as u16)
}

/// Validate and parse an ICMP reply (Echo Reply or Time Exceeded).
/// Returns `Option<(u64, bool)>` where bool is true if it was a Time Exceeded.
fn parse_icmp_reply(buf: &[u8], expected_id: u16, expected_seq: u16) -> Option<(u64, bool)> {
    if buf.len() < 20 + PACKET_SIZE {
        return None;
    }

    let ip_header_len = ((buf[0] & 0x0F) as usize) * 4;
    let icmp_start = ip_header_len;

    if buf.len() < icmp_start + 8 {
        return None;
    }

    let icmp_type = buf[icmp_start];

    if icmp_type == ICMP_ECHO_REPLY {
        let icmp = &buf[icmp_start..];
        if icmp.len() < PACKET_SIZE {
            return None;
        }

        let reply_id = u16::from_be_bytes([icmp[4], icmp[5]]);
        let reply_seq = u16::from_be_bytes([icmp[6], icmp[7]]);

        if reply_id != expected_id || reply_seq != expected_seq {
            return None;
        }

        let sent_us = u64::from_be_bytes(icmp[8..16].try_into().ok()?);
        let now_us = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default()
            .as_micros() as u64;

        return Some((now_us.checked_sub(sent_us)?, false));
    } else if icmp_type == ICMP_TIME_EXCEEDED {
        // ICMP Time Exceeded payload contains:
        // Original IP header (variable length, usually 20)
        // Original ICMP header (8)
        // Original ICMP Payload (first 8 bytes)
        let inner_ip_start = icmp_start + 8;
        if buf.len() < inner_ip_start + 20 {
            return None; // Not enough room for inner IP header
        }

        let inner_ip_header_len = ((buf[inner_ip_start] & 0x0F) as usize) * 4;
        let inner_icmp_start = inner_ip_start + inner_ip_header_len;

        if buf.len() < inner_icmp_start + PACKET_SIZE {
            return None; // Not enough room for inner ICMP packet
        }

        let inner_icmp = &buf[inner_icmp_start..];
        let reply_id = u16::from_be_bytes([inner_icmp[4], inner_icmp[5]]);
        let reply_seq = u16::from_be_bytes([inner_icmp[6], inner_icmp[7]]);

        if reply_id != expected_id || reply_seq != expected_seq {
            return None;
        }

        let sent_us = u64::from_be_bytes(inner_icmp[8..16].try_into().ok()?);
        let now_us = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default()
            .as_micros() as u64;

        return Some((now_us.checked_sub(sent_us)?, true));
    }

    None
}

#[async_trait]
impl Prober for IcmpProber {
    fn name(&self) -> &'static str {
        "icmp"
    }

    /// Send one ICMP Echo Request to `target` and wait for the matching reply.
    ///
    /// `target` must be a literal IPv4 address; no name resolution happens here.
    /// When `ttl` is `Some`, it is applied to the outgoing packet, and an ICMP
    /// Time Exceeded from an intermediate router is accepted as well as an Echo
    /// Reply from the target. Either way the responder's address is reported
    /// together with the round-trip time.
    ///
    /// # Errors
    ///
    /// * [`NetPulseError::InvalidTarget`] if `target` is not an IPv4 address.
    /// * [`NetPulseError::InsufficientPrivileges`] if opening the raw socket is
    ///   refused with a permission error.
    ///
    /// Every other failure (timeout, socket error, a panicked worker) is
    /// reported as `Ok(ProbeResult::loss(..))`.
    async fn probe(
        &self,
        target: &str,
        seq: u64,
        ttl: Option<u32>,
    ) -> Result<ProbeResult, NetPulseError> {
        let addr: IpAddr = target.parse().map_err(|_| NetPulseError::InvalidTarget {
            target: target.to_string(),
            reason: "must be an IP address for ICMP probing".to_string(),
        })?;

        let IpAddr::V4(_) = addr else {
            return Err(NetPulseError::InvalidTarget {
                target: target.to_string(),
                reason: "only IPv4 is currently supported".to_string(),
            });
        };

        // Replies are matched on identifier + sequence number. The identifier
        // is the low 16 bits of the process ID; the sequence is the low 16 bits of `seq`.
        let id = (std::process::id() & 0xFFFF) as u16;
        let seq_u16 = (seq & 0xFFFF) as u16;

        let packet = build_icmp_packet(id, seq_u16);
        let dest = SocketAddr::new(addr, 0);
        let timeout_ms = self.timeout_ms;

        // Raw-socket I/O blocks, so it runs on Tokio's blocking pool instead of
        // the async executor. This async timeout is only a backstop: the
        // blocking half enforces the same budget itself, so its worker thread
        // is released soon after this future gives up.
        let result = timeout(
            Duration::from_millis(timeout_ms),
            tokio::task::spawn_blocking(move || {
                send_and_await_reply(&packet, dest, ttl, id, seq_u16, timeout_ms)
            }),
        )
        .await;

        match result {
            Ok(Ok(Ok((rtt_us, responder_ip)))) => {
                Ok(ProbeResult::success(target, seq, rtt_us, responder_ip))
            }
            // Missing privileges would fail every later probe too: surface it.
            Ok(Ok(Err(NetPulseError::InsufficientPrivileges))) => {
                Err(NetPulseError::InsufficientPrivileges)
            }
            // Probe error (timeout, socket failure), panicked worker, or the
            // async timeout firing: the probe simply got no answer.
            Ok(Ok(Err(_))) | Ok(Err(_)) | Err(_) => Ok(ProbeResult::loss(target, seq)),
        }
    }
}

/// Blocking half of `IcmpProber::probe`.
///
/// Opens a raw ICMPv4 socket and sends `packet` to `dest`, with IP TTL `ttl`
/// if given. It then waits for a reply that `parse_icmp_reply` matches to
/// `id`/`seq`.
///
/// On success, returns the round-trip time in microseconds and the source
/// address of the reply. The RTT is measured on the monotonic clock from just
/// before the send.
///
/// Fails with `NetPulseError::Timeout` once `timeout_ms` has elapsed since the
/// send, however many unrelated ICMP packets arrive in the meantime.
fn send_and_await_reply(
    packet: &[u8],
    dest: SocketAddr,
    ttl: Option<u32>,
    id: u16,
    seq: u16,
    timeout_ms: u64,
) -> Result<(u64, Option<String>), NetPulseError> {
    let socket = Socket::new(Domain::IPV4, Type::RAW, Some(Protocol::ICMPV4)).map_err(|e| {
        // `PermissionDenied` covers EPERM and EACCES portably; comparing the
        // raw OS code against 1 only means EPERM on Unix.
        if e.kind() == std::io::ErrorKind::PermissionDenied {
            NetPulseError::InsufficientPrivileges
        } else {
            NetPulseError::SocketError(e)
        }
    })?;

    if let Some(ttl) = ttl {
        // Sending with the default TTL instead would make the target answer
        // for the hop being probed, so this failure must not be ignored.
        socket.set_ttl(ttl).map_err(NetPulseError::SocketError)?;
    }

    let start = Instant::now();
    let deadline = start + Duration::from_millis(timeout_ms);
    socket
        .send_to(packet, &dest.into())
        .map_err(NetPulseError::SocketError)?;

    let mut recv_buf = [0u8; 1024];
    loop {
        // The socket can deliver ICMP packets unrelated to this probe, so the
        // whole wait is bounded by one deadline rather than restarting a full
        // read timeout after every packet.
        let remaining = deadline.saturating_duration_since(Instant::now());
        if remaining.is_zero() {
            return Err(NetPulseError::Timeout { timeout_ms });
        }
        // A zero SO_RCVTIMEO means "block forever", so never pass less than 1 µs.
        let read_timeout = remaining.max(Duration::from_micros(1));
        socket
            .set_read_timeout(Some(read_timeout))
            .map_err(NetPulseError::SocketError)?;

        // SAFETY: `MaybeUninit<u8>` has the same layout as `u8`, and
        // `recv_from` only writes initialized bytes into the buffer, so
        // `recv_buf` is still fully initialized once this borrow ends.
        let uninit_buf = unsafe {
            &mut *(&mut recv_buf as *mut [u8; 1024]
                as *mut [std::mem::MaybeUninit<u8>; 1024])
        };
        let (n, from_addr) = match socket.recv_from(uninit_buf) {
            Ok(received) => received,
            Err(e) => match e.kind() {
                // A signal interrupted the wait; retry with the time left.
                std::io::ErrorKind::Interrupted => continue,
                // Platforms report an expired read timeout as WouldBlock
                // (Unix) or TimedOut (Windows).
                std::io::ErrorKind::WouldBlock | std::io::ErrorKind::TimedOut => {
                    return Err(NetPulseError::Timeout { timeout_ms });
                }
                _ => return Err(NetPulseError::SocketError(e)),
            },
        };

        // Packets that do not answer this probe are skipped. The RTT derived
        // by `parse_icmp_reply` from the payload timestamp is not used, for
        // two reasons: it follows the wall clock, and the timestamp was
        // written before this worker started.
        if parse_icmp_reply(&recv_buf[..n], id, seq).is_some() {
            let rtt_us = u64::try_from(start.elapsed().as_micros()).unwrap_or(u64::MAX);
            let responder_ip = from_addr.as_socket().map(|s| s.ip().to_string());
            return Ok((rtt_us, responder_ip));
        }
    }
}