network_toolset 0.1.0

A comprehensive network diagnostic toolset implemented in Rust
Documentation
use anyhow::{Result, anyhow};
use socket2::{Socket, Domain, Type, Protocol, SockAddr};
use std::net::{IpAddr, Ipv4Addr};
use std::time::{Duration, Instant};
use std::io::{Read, Write};
use std::mem::MaybeUninit;

use crate::common::parse_hostname;

/// Discovers the path MTU to `target`, trying ICMP probing first and falling
/// back to a TCP-based estimate when the ICMP routine returns an error.
///
/// `target` is resolved with `parse_hostname`; only IPv4 results are accepted.
/// `start_mtu` seeds the upper bound of the search and `max_probes` caps how
/// many probes each routine may send (both are forwarded unchanged).
///
/// # Errors
///
/// Returns an error if resolution fails, if the target resolves to an IPv6
/// address, or if the ICMP attempt fails and the TCP fallback fails as well.
pub fn run_mtu_discover(target: &str, start_mtu: u16, max_probes: u32) -> Result<()> {
    println!("ICMP Path MTU Discovery to {} (starting from {} bytes)", target, start_mtu);

    let resolved = parse_hostname(target)?;
    let ipv4 = match resolved {
        IpAddr::V4(addr) => addr,
        IpAddr::V6(_) => return Err(anyhow!("IPv6 not yet supported for ICMP MTU discovery")),
    };

    // NOTE(review): this banner advertises DF-based probing; confirm that the
    // ICMP prober actually sets the Don't-Fragment bit on its packets.
    println!("Attempting true ICMP-based Path MTU Discovery...");
    println!("This method uses ICMP Echo packets with DF (Don't Fragment) flag");
    println!("and listens for 'Fragmentation Needed' messages for accuracy.\n");

    let path_mtu = match run_icmp_mtu_discovery(ipv4, start_mtu, max_probes) {
        Ok(mtu) => mtu,
        Err(reason) => {
            println!("\n⚠ ICMP-based MTU discovery failed: {}", reason);
            println!("Falling back to TCP-based estimation...\n");
            // The TCP routine prints its own report, so its result is final.
            return run_tcp_fallback_mtu_discovery(&resolved, start_mtu, max_probes);
        }
    };

    println!("\n✓ True ICMP Path MTU Discovery successful!");
    println!("  Path MTU to {}: {} bytes", target, path_mtu);
    println!("  This is the accurate maximum packet size that can reach the target");
    Ok(())
}

/// Binary-searches the largest probe size that reaches `target`, using ICMP
/// Echo probes sent and received over two raw IPv4 ICMP sockets.
///
/// The search range is `[68, min(start_mtu, 1500)]`. 68 bytes is the minimum
/// datagram every IPv4 link must carry (RFC 791), and 1500 is the standard
/// Ethernet MTU. Each probe is delegated to `test_icmp_mtu_packet`. A
/// successful probe raises the lower bound and any failure lowers the upper
/// bound. At most `max_probes` probes are sent, with a 500 ms pause between
/// consecutive probes.
///
/// Returns the largest size whose probe succeeded.
///
/// # Errors
///
/// Fails if the raw sockets cannot be created, configured, or bound. Raw
/// sockets typically require elevated privileges; TODO confirm per platform.
/// Also fails if no probe succeeded.
///
/// NOTE(review): nothing here sets the IPv4 Don't-Fragment bit (for example
/// Linux `IP_MTU_DISCOVER = IP_PMTUDISC_DO`). Without DF, routers may fragment
/// oversized probes instead of answering "Fragmentation Needed". The search
/// can then report a size above the real path MTU.
fn run_icmp_mtu_discovery(target: Ipv4Addr, start_mtu: u16, max_probes: u32) -> Result<u16> {
    // Smallest datagram every IPv4 link must support (RFC 791).
    const MIN_IPV4_MTU: u16 = 68;
    // Search ceiling: the standard Ethernet MTU.
    const MAX_SEARCH_MTU: u16 = 1500;
    const PROBE_INTERVAL: Duration = Duration::from_millis(500);

    // Raw ICMP socket used to send Echo Requests.
    let send_socket = Socket::new(Domain::IPV4, Type::RAW, Some(Protocol::ICMPV4))?;
    send_socket.set_write_timeout(Some(Duration::from_secs(3)))?;

    // Separate raw ICMP socket, bound to all interfaces, used to read replies.
    let recv_socket = Socket::new(Domain::IPV4, Type::RAW, Some(Protocol::ICMPV4))?;
    recv_socket.set_read_timeout(Some(Duration::from_secs(3)))?;
    recv_socket.bind(&SockAddr::from(std::net::SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0)))?;

    let target_addr = SockAddr::from(std::net::SocketAddrV4::new(target, 0));

    let mut lower_bound = MIN_IPV4_MTU;
    let mut upper_bound = start_mtu.min(MAX_SEARCH_MTU);
    // `None` until a probe succeeds, so a working 68-byte MTU is still reported.
    let mut last_successful_mtu: Option<u16> = None;
    let mut probe_count = 0u32;

    println!("Starting ICMP MTU discovery using binary search...");

    while lower_bound <= upper_bound && probe_count < max_probes {
        let mid_mtu = lower_bound + (upper_bound - lower_bound) / 2;
        probe_count += 1;

        println!("Testing MTU: {} bytes (range: {}-{})", mid_mtu, lower_bound, upper_bound);

        match test_icmp_mtu_packet(&send_socket, &recv_socket, &target_addr, mid_mtu) {
            Ok(()) => {
                println!("  ✓ Success - {} bytes reached target", mid_mtu);
                last_successful_mtu = Some(mid_mtu);
                lower_bound = mid_mtu + 1;
            }
            Err(e) => {
                let message = e.to_string();
                if message.contains("fragmentation") || message.contains("Fragmentation") {
                    println!("  ✗ Fragmentation needed - MTU too large");
                } else {
                    println!("  ✗ Failed: {}", e);
                }
                // Any failure means this size is too large; mid_mtu >= 68, so no underflow.
                upper_bound = mid_mtu - 1;
            }
        }

        // Pause only when another probe will actually follow.
        if lower_bound <= upper_bound && probe_count < max_probes {
            std::thread::sleep(PROBE_INTERVAL);
        }
    }

    last_successful_mtu.ok_or_else(|| anyhow!("ICMP MTU discovery failed - all probes failed"))
}

fn test_icmp_mtu_packet(
    send_socket: &Socket,
    recv_socket: &Socket,
    target_addr: &SockAddr,
    mtu: u16,
) -> Result<()> {
    // Create ICMP Echo Request packet
    let mut packet = create_icmp_packet(mtu)?;

    let start_time = Instant::now();

    // Send the packet
    match send_socket.send_to(&packet, target_addr) {
        Ok(_) => {
            // Wait for response (Echo Reply or Fragmentation Needed)
            let mut response_buffer = [MaybeUninit::uninit(); 1500];

            match recv_socket.recv(&mut response_buffer) {
                Ok(_) => {
                    let elapsed = start_time.elapsed();
                    if elapsed < Duration::from_secs(3) {
                        // Got some response - packet reached target
                        return Ok(());
                    } else {
                        return Err(anyhow!("Timeout waiting for ICMP response"));
                    }
                }
                Err(e) => {
                    return Err(anyhow!("Failed to receive ICMP response: {}", e));
                }
            }
        }
        Err(e) => {
            return Err(anyhow!("Failed to send ICMP packet: {}", e));
        }
    }
}

/// Builds an ICMP Echo Request (type 8, code 0) whose total length, header
/// included, is `mtu` bytes. When `mtu` is 8 or less, only the 8-byte header
/// is produced.
///
/// Header fields:
/// - identifier: the low 16 bits of the current process id;
/// - sequence number: `mtu` itself.
///
/// The payload is the repeating byte pattern 0, 1, ..., 255, 0, 1, .... The
/// checksum is computed over the finished packet with its checksum field zeroed.
///
/// This never returns an error; the `Result` return type is kept for callers.
fn create_icmp_packet(mtu: u16) -> Result<Vec<u8>> {
    const ICMP_HEADER_LEN: u16 = 8;

    // Only the low 16 bits of the PID fit into the identifier field.
    let identifier = std::process::id() as u16;
    let payload_len = usize::from(mtu.saturating_sub(ICMP_HEADER_LEN));

    let mut packet = Vec::with_capacity(usize::from(ICMP_HEADER_LEN) + payload_len);
    // Type 8 (Echo Request), code 0, then a zeroed checksum placeholder.
    packet.extend_from_slice(&[8, 0, 0, 0]);
    packet.extend_from_slice(&identifier.to_be_bytes());
    packet.extend_from_slice(&mtu.to_be_bytes());
    // `as u8` wraps modulo 256, producing the repeating 0..=255 pattern.
    packet.extend((0..payload_len).map(|offset| offset as u8));

    let checksum = calculate_icmp_checksum(&packet);
    packet[2..4].copy_from_slice(&checksum.to_be_bytes());

    Ok(packet)
}

/// Computes the Internet checksum (RFC 1071) of `packet`. This is the one's
/// complement of the one's-complement sum of its big-endian 16-bit words.
///
/// A trailing odd byte is treated as the high byte of a final word padded with
/// zero. To produce a checksum for insertion, the checksum field inside
/// `packet` must be zero. Running this over a packet that already carries a
/// correct checksum yields 0.
///
/// The accumulator is 64-bit, so long slices cannot overflow it. A 32-bit
/// accumulator overflows after roughly 128 KiB of 0xFF bytes.
fn calculate_icmp_checksum(packet: &[u8]) -> u16 {
    let words = packet.chunks_exact(2);
    let trailing = words.remainder();

    let mut sum: u64 = words
        .map(|word| u64::from(u16::from_be_bytes([word[0], word[1]])))
        .sum();

    // Odd length: pad the final byte with a zero low byte.
    if let [last] = trailing {
        sum += u64::from(*last) << 8;
    }

    // Fold carries back into the low 16 bits (end-around carry).
    while sum >> 16 != 0 {
        sum = (sum & 0xFFFF) + (sum >> 16);
    }

    // `sum` now fits in 16 bits, so the cast is exact.
    !(sum as u16)
}

/// Fallback "MTU estimation" used when ICMP probing fails.
///
/// It first looks for a TCP port on `target_ip` that accepts connections,
/// trying a fixed list of common service ports with a 2 s timeout each. It
/// then binary-searches sizes in `[68, min(start_mtu, 1500)]` with
/// `test_tcp_mtu_discovery`, sending at most `max_probes` probes with a
/// 200 ms pause between consecutive probes. The result is printed.
///
/// NOTE(review): TCP splits writes into segments by itself, so a successful
/// write says little about the path MTU. The visible probe helper succeeds
/// whenever connect and write succeed. The printed value is therefore likely
/// to be close to the upper bound; treat it as a rough estimate at best.
///
/// # Errors
///
/// Fails if none of the candidate ports accepts a connection, or if no probe
/// succeeds.
fn run_tcp_fallback_mtu_discovery(target_ip: &IpAddr, start_mtu: u16, max_probes: u32) -> Result<()> {
    // Common service ports; the first one that accepts a connection is used.
    const CANDIDATE_PORTS: &[u16] = &[
        20, 21, 22, 23, 25, 53, 80, 110, 143, 443, 465, 587, 636, 993, 995,
        1433, 3306, 3389, 5432, 6379, 8080, 8443, 9000, 9090,
    ];
    // Smallest datagram every IPv4 link must support (RFC 791).
    const MIN_IPV4_MTU: u16 = 68;
    // Search ceiling: the standard Ethernet MTU.
    const MAX_SEARCH_MTU: u16 = 1500;
    const PROBE_INTERVAL: Duration = Duration::from_millis(200);

    println!("Running TCP-based MTU discovery as fallback...");
    println!("Testing connectivity on common ports...");

    let mut working_port = None;
    for &port in CANDIDATE_PORTS {
        if test_tcp_connectivity(target_ip, port, Duration::from_secs(2)).is_ok() {
            println!("  ✓ Port {} is accessible", port);
            working_port = Some(port);
            break;
        }
        println!("  ✗ Port {} not accessible", port);
    }
    let working_port = working_port
        .ok_or_else(|| anyhow!("No accessible TCP ports found for fallback MTU discovery"))?;

    let mut lower_bound = MIN_IPV4_MTU;
    let mut upper_bound = start_mtu.min(MAX_SEARCH_MTU);
    // `None` until a probe succeeds, so a success at exactly 68 bytes still counts.
    let mut last_successful_mtu: Option<u16> = None;
    let mut probe_count = 0u32;

    println!("\nStarting TCP-based MTU estimation on port {}...", working_port);

    while lower_bound <= upper_bound && probe_count < max_probes {
        let mid_mtu = lower_bound + (upper_bound - lower_bound) / 2;
        probe_count += 1;

        println!("Testing packet size: {} bytes (range: {}-{})", mid_mtu, lower_bound, upper_bound);

        match test_tcp_mtu_discovery(target_ip, working_port, mid_mtu) {
            Ok(()) => {
                println!("  ✓ Success - {} bytes transmitted successfully", mid_mtu);
                last_successful_mtu = Some(mid_mtu);
                lower_bound = mid_mtu + 1;
            }
            Err(e) => {
                println!("  ✗ Failed: {}", e);
                // mid_mtu >= 68, so this cannot underflow.
                upper_bound = mid_mtu - 1;
            }
        }

        // Pause only when another probe will actually follow.
        if lower_bound <= upper_bound && probe_count < max_probes {
            std::thread::sleep(PROBE_INTERVAL);
        }
    }

    let estimated_mtu =
        last_successful_mtu.ok_or_else(|| anyhow!("TCP-based MTU discovery failed"))?;

    println!("\nTCP-based MTU estimation completed:");
    println!("  Estimated Path MTU: {} bytes", estimated_mtu);
    println!("  Discovery completed in {} probes", probe_count);
    println!("\nNote: This is an estimation using TCP connectivity testing.");
    println!("For more accurate results, true ICMP-based MTU discovery is recommended.");

    Ok(())
}

/// Checks whether a TCP connection to `target_ip:port` can be established
/// within `timeout`. The connection is dropped (closed) straight away.
///
/// # Errors
///
/// Returns an error wrapping the I/O error from the connect attempt.
fn test_tcp_connectivity(target_ip: &IpAddr, port: u16, timeout: Duration) -> Result<()> {
    let endpoint = std::net::SocketAddr::new(*target_ip, port);
    std::net::TcpStream::connect_timeout(&endpoint, timeout)
        .map(|_stream| ())
        .map_err(|err| anyhow!("TCP connection failed: {}", err))
}

/// Opens a fresh TCP connection to `target_ip:port` (3 s connect timeout) and
/// writes `min(packet_size, 4096)` zero bytes. It then makes one attempt to
/// read up to 1024 bytes of response.
///
/// Succeeds whenever the connection and the write succeed. The outcome of the
/// read (data, EOF, timeout, or error) is ignored. The 3 s read and write
/// timeouts are applied on a best-effort basis; failures to set them are ignored.
///
/// NOTE(review): this sends protocol-invalid zero bytes to whatever service
/// listens on `port`. Because TCP segments writes itself, success here does
/// not show that a `packet_size`-byte packet crossed the path.
///
/// # Errors
///
/// Fails if the connection cannot be established or the write fails.
fn test_tcp_mtu_discovery(target_ip: &IpAddr, port: u16, packet_size: u16) -> Result<()> {
    const IO_TIMEOUT: Duration = Duration::from_secs(3);

    let endpoint = std::net::SocketAddr::new(*target_ip, port);

    // A new connection per probe sidesteps socket-reuse problems.
    let mut stream = std::net::TcpStream::connect_timeout(&endpoint, IO_TIMEOUT)
        .map_err(|err| anyhow!("Connection failed for {} bytes: {}", packet_size, err))?;

    // Best effort: a failure to set either timeout is deliberately ignored.
    let _ = stream.set_read_timeout(Some(IO_TIMEOUT));
    let _ = stream.set_write_timeout(Some(IO_TIMEOUT));

    // Capped at 4 KiB to keep the buffer small.
    let payload_len = usize::from(packet_size.min(4096));
    let payload = vec![0u8; payload_len];
    stream
        .write_all(&payload)
        .map_err(|err| anyhow!("Failed to send {} bytes: {}", payload_len, err))?;

    // Some services answer; whether they do, or the read fails, does not matter.
    let mut scratch = [0u8; 1024];
    let _ = stream.read(&mut scratch);

    Ok(())
}