use anyhow::{Result, anyhow};
use socket2::{Socket, Domain, Type, Protocol, SockAddr};
use std::net::{IpAddr, Ipv4Addr};
use std::time::{Duration, Instant};
use std::io::{Read, Write};
use std::mem::MaybeUninit;
use crate::common::parse_hostname;
/// Discovers the path MTU towards `target` and prints the outcome.
///
/// `target` is resolved with `parse_hostname`. ICMP-based discovery is then attempted
/// (IPv4 targets only), starting from `start_mtu` and sending at most `max_probes`
/// probes. If the ICMP attempt fails for any reason, the error is printed and the
/// TCP-based estimation runs with the same parameters; its result is returned.
///
/// # Errors
/// Returns an error if hostname resolution fails, if the target is IPv6, or if the
/// TCP fallback fails.
pub fn run_mtu_discover(target: &str, start_mtu: u16, max_probes: u32) -> Result<()> {
    println!("ICMP Path MTU Discovery to {} (starting from {} bytes)", target, start_mtu);

    let target_ip = parse_hostname(target)?;
    let target_ipv4 = if let IpAddr::V4(v4) = target_ip {
        v4
    } else {
        return Err(anyhow!("IPv6 not yet supported for ICMP MTU discovery"));
    };

    println!("Attempting true ICMP-based Path MTU Discovery...");
    println!("This method uses ICMP Echo packets with DF (Don't Fragment) flag");
    println!("and listens for 'Fragmentation Needed' messages for accuracy.\n");

    run_icmp_mtu_discovery(target_ipv4, start_mtu, max_probes)
        .map(|mtu| {
            println!("\n✓ True ICMP Path MTU Discovery successful!");
            println!(" Path MTU to {}: {} bytes", target, mtu);
            println!(" This is the accurate maximum packet size that can reach the target");
        })
        .or_else(|err| {
            // Any ICMP failure (no raw-socket permission, timeouts, ...) falls back to TCP.
            println!("\n⚠ ICMP-based MTU discovery failed: {}", err);
            println!("Falling back to TCP-based estimation...\n");
            run_tcp_fallback_mtu_discovery(&target_ip, start_mtu, max_probes)
        })
}
/// Binary-searches the IPv4 path MTU towards `target` using ICMP Echo probes.
///
/// Candidate datagram sizes range from 68 bytes (the IPv4 minimum MTU) up to
/// `min(start_mtu, 1500)`. At most `max_probes` probes are sent through
/// `test_icmp_mtu_packet`. A successful probe raises the lower bound; any failed probe
/// (fragmentation needed, timeout, socket error) lowers the upper bound.
///
/// Opening raw sockets usually requires elevated privileges; such failures surface
/// as an error from `Socket::new`.
///
/// NOTE(review): this function sets no socket option that requests the Don't Fragment
/// bit, so whether oversized probes are rejected or fragmented depends on OS defaults.
/// Confirm on the target platforms.
///
/// # Errors
/// Returns an error if a socket cannot be created or configured, or if no probe
/// size succeeded.
fn run_icmp_mtu_discovery(target: Ipv4Addr, start_mtu: u16, max_probes: u32) -> Result<u16> {
    const MIN_IPV4_MTU: u16 = 68;
    const MAX_PROBE_MTU: u16 = 1500;
    const PROBE_INTERVAL: Duration = Duration::from_millis(500);

    let send_socket = Socket::new(Domain::IPV4, Type::RAW, Some(Protocol::ICMPV4))?;
    send_socket.set_write_timeout(Some(Duration::from_secs(3)))?;
    let recv_socket = Socket::new(Domain::IPV4, Type::RAW, Some(Protocol::ICMPV4))?;
    recv_socket.set_read_timeout(Some(Duration::from_secs(3)))?;
    recv_socket.bind(&SockAddr::from(std::net::SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0)))?;
    let target_addr = SockAddr::from(std::net::SocketAddrV4::new(target, 0));

    let mut lower_bound = MIN_IPV4_MTU;
    let mut upper_bound = start_mtu.min(MAX_PROBE_MTU);
    // Largest size that succeeded so far. `None` until a probe succeeds, so a
    // successful 68-byte probe is still reported instead of being mistaken for failure.
    let mut best_mtu: Option<u16> = None;
    let mut probe_count: u32 = 0;

    println!("Starting ICMP MTU discovery using binary search...");
    while lower_bound <= upper_bound && probe_count < max_probes {
        let mid_mtu = lower_bound + (upper_bound - lower_bound) / 2;
        probe_count += 1;
        println!("Testing MTU: {} bytes (range: {}-{})", mid_mtu, lower_bound, upper_bound);

        match test_icmp_mtu_packet(&send_socket, &recv_socket, &target_addr, mid_mtu) {
            Ok(()) => {
                println!(" ✓ Success - {} bytes reached target", mid_mtu);
                best_mtu = Some(mid_mtu);
                lower_bound = mid_mtu + 1;
            }
            Err(e) => {
                let reason = e.to_string();
                if reason.contains("fragmentation") || reason.contains("Fragmentation") {
                    println!(" ✗ Fragmentation needed - MTU too large");
                } else {
                    println!(" ✗ Failed: {}", e);
                }
                // Every kind of failure shrinks the search range the same way.
                // mid_mtu >= MIN_IPV4_MTU, so this cannot underflow.
                upper_bound = mid_mtu - 1;
            }
        }

        // Pace the probes, but don't wait after the final one.
        if lower_bound <= upper_bound && probe_count < max_probes {
            std::thread::sleep(PROBE_INTERVAL);
        }
    }

    best_mtu.ok_or_else(|| anyhow!("ICMP MTU discovery failed - all probes failed"))
}
fn test_icmp_mtu_packet(
send_socket: &Socket,
recv_socket: &Socket,
target_addr: &SockAddr,
mtu: u16,
) -> Result<()> {
let mut packet = create_icmp_packet(mtu)?;
let start_time = Instant::now();
match send_socket.send_to(&packet, target_addr) {
Ok(_) => {
let mut response_buffer = [MaybeUninit::uninit(); 1500];
match recv_socket.recv(&mut response_buffer) {
Ok(_) => {
let elapsed = start_time.elapsed();
if elapsed < Duration::from_secs(3) {
return Ok(());
} else {
return Err(anyhow!("Timeout waiting for ICMP response"));
}
}
Err(e) => {
return Err(anyhow!("Failed to receive ICMP response: {}", e));
}
}
}
Err(e) => {
return Err(anyhow!("Failed to send ICMP packet: {}", e));
}
}
}
/// Builds an ICMP Echo Request (type 8, code 0) that is `mtu` bytes long.
///
/// Layout:
/// - checksum, computed by `calculate_icmp_checksum`;
/// - identifier: the low 16 bits of the process ID;
/// - sequence number: `mtu`;
/// - payload: `mtu - 8` bytes counting 0, 1, ..., 255, 0, ...
///
/// When `mtu` is 8 or less, only the 8-byte header is produced.
///
/// # Errors
/// Never fails in the current implementation; the `Result` is part of the interface.
fn create_icmp_packet(mtu: u16) -> Result<Vec<u8>> {
    const ICMP_HEADER_LEN: u16 = 8;
    const ECHO_REQUEST: u8 = 8;

    let payload_len = usize::from(mtu.saturating_sub(ICMP_HEADER_LEN));
    // Truncating the PID to 16 bits is intentional: the identifier field is 16 bits wide.
    let ident = (std::process::id() as u16).to_be_bytes();
    let seq = mtu.to_be_bytes();

    let mut icmp = Vec::with_capacity(usize::from(ICMP_HEADER_LEN) + payload_len);
    // Checksum bytes (2..4) start as zero and are filled in once the packet is complete.
    icmp.extend_from_slice(&[ECHO_REQUEST, 0, 0, 0, ident[0], ident[1], seq[0], seq[1]]);
    icmp.extend((0..payload_len).map(|i| i as u8));

    let [hi, lo] = calculate_icmp_checksum(&icmp).to_be_bytes();
    icmp[2] = hi;
    icmp[3] = lo;
    Ok(icmp)
}
/// Computes the Internet checksum of `packet`.
///
/// Bytes are summed as big-endian 16-bit words. A trailing odd byte counts as the
/// high byte of a word whose low byte is zero. Carries are folded back into the low
/// 16 bits, and the one's complement of the folded sum is returned.
fn calculate_icmp_checksum(packet: &[u8]) -> u16 {
    let mut total: u32 = packet
        .chunks(2)
        .map(|word| match word {
            [hi, lo] => u32::from(u16::from_be_bytes([*hi, *lo])),
            [hi] => u32::from(*hi) << 8,
            _ => unreachable!("chunks(2) yields one or two bytes"),
        })
        .sum();

    // End-around carry: keep folding until the sum fits in 16 bits.
    while total > 0xFFFF {
        total = (total >> 16) + (total & 0xFFFF);
    }
    !(total as u16)
}
/// Estimates the path MTU over TCP, used when ICMP-based discovery is unavailable.
///
/// First, well-known service ports are tried in order (2 s connect timeout each)
/// until one accepts a connection. Sizes between 68 bytes and `min(start_mtu, 1500)`
/// are then binary-searched on that port with `test_tcp_mtu_discovery`, running at
/// most `max_probes` probes. The estimate and the probe count are printed.
///
/// NOTE(review): TCP is a byte stream that the OS segments on its own, so a
/// successful write of N bytes does not show that an N-byte packet crossed the path
/// unfragmented. Treat the result as a rough estimate at best.
///
/// # Errors
/// Returns an error if no candidate port accepts a connection, or if no probe succeeds.
fn run_tcp_fallback_mtu_discovery(target_ip: &IpAddr, start_mtu: u16, max_probes: u32) -> Result<()> {
    const MIN_IPV4_MTU: u16 = 68;
    const MAX_PROBE_MTU: u16 = 1500;
    const CANDIDATE_PORTS: [u16; 24] = [
        20, 21, 22, 23, 25, 53, 80, 110, 143, 443, 465, 587, 636, 993, 995,
        1433, 3306, 3389, 5432, 6379, 8080, 8443, 9000, 9090,
    ];

    println!("Running TCP-based MTU discovery as fallback...");
    println!("Testing connectivity on common ports...");

    // Use the first candidate port that accepts a connection.
    let mut working_port = None;
    for &port in &CANDIDATE_PORTS {
        match test_tcp_connectivity(target_ip, port, Duration::from_secs(2)) {
            Ok(()) => {
                println!(" ✓ Port {} is accessible", port);
                working_port = Some(port);
                break;
            }
            Err(_) => println!(" ✗ Port {} not accessible", port),
        }
    }
    let working_port = working_port
        .ok_or_else(|| anyhow!("No accessible TCP ports found for fallback MTU discovery"))?;

    let mut lower_bound = MIN_IPV4_MTU;
    let mut upper_bound = start_mtu.min(MAX_PROBE_MTU);
    // Largest size that succeeded so far. `None` until a probe succeeds, so a
    // successful 68-byte probe is still reported instead of being mistaken for failure.
    let mut best_mtu: Option<u16> = None;
    let mut probe_count: u32 = 0;

    println!("\nStarting TCP-based MTU estimation on port {}...", working_port);
    while lower_bound <= upper_bound && probe_count < max_probes {
        let mid_mtu = lower_bound + (upper_bound - lower_bound) / 2;
        probe_count += 1;
        println!("Testing packet size: {} bytes (range: {}-{})", mid_mtu, lower_bound, upper_bound);

        match test_tcp_mtu_discovery(target_ip, working_port, mid_mtu) {
            Ok(()) => {
                println!(" ✓ Success - {} bytes transmitted successfully", mid_mtu);
                best_mtu = Some(mid_mtu);
                lower_bound = mid_mtu + 1;
            }
            Err(e) => {
                println!(" ✗ Failed: {}", e);
                // mid_mtu >= MIN_IPV4_MTU, so this cannot underflow.
                upper_bound = mid_mtu - 1;
            }
        }

        // Pace the probes, but don't wait after the final one.
        if lower_bound <= upper_bound && probe_count < max_probes {
            std::thread::sleep(Duration::from_millis(200));
        }
    }

    let mtu = best_mtu.ok_or_else(|| anyhow!("TCP-based MTU discovery failed"))?;
    println!("\nTCP-based MTU estimation completed:");
    println!(" Estimated Path MTU: {} bytes", mtu);
    println!(" Discovery completed in {} probes", probe_count);
    println!("\nNote: This is an estimation using TCP connectivity testing.");
    println!("For more accurate results, true ICMP-based MTU discovery is recommended.");
    Ok(())
}
/// Checks whether a TCP connection to `target_ip:port` can be opened within `timeout`.
///
/// On success the connection is dropped (closed) immediately.
///
/// # Errors
/// Returns "TCP connection failed: ..." wrapping the underlying connect error.
fn test_tcp_connectivity(target_ip: &IpAddr, port: u16, timeout: Duration) -> Result<()> {
    let addr = std::net::SocketAddr::from((*target_ip, port));
    std::net::TcpStream::connect_timeout(&addr, timeout)
        .map(drop)
        .map_err(|err| anyhow!("TCP connection failed: {}", err))
}
/// Connects to `target_ip:port`, writes `min(packet_size, 4096)` zero bytes, then
/// waits briefly for any reply.
///
/// The connect timeout is 3 s. Read and write timeouts of 3 s are also set, on a
/// best-effort basis (failures to set them are ignored). The reply read is also
/// best-effort: data, EOF, timeout or error all count as success. Only the connect
/// and the write decide the result.
///
/// # Errors
/// Returns "Connection failed for N bytes: ..." if connecting fails, or
/// "Failed to send N bytes: ..." if the write fails.
fn test_tcp_mtu_discovery(target_ip: &IpAddr, port: u16, packet_size: u16) -> Result<()> {
    let io_timeout = Some(Duration::from_secs(3));
    let addr = std::net::SocketAddr::from((*target_ip, port));

    let mut stream = std::net::TcpStream::connect_timeout(&addr, Duration::from_secs(3))
        .map_err(|err| anyhow!("Connection failed for {} bytes: {}", packet_size, err))?;
    let _ = stream.set_read_timeout(io_timeout);
    let _ = stream.set_write_timeout(io_timeout);

    let payload_len = usize::from(packet_size.min(4096));
    let payload = vec![0u8; payload_len];
    stream
        .write_all(&payload)
        .map_err(|err| anyhow!("Failed to send {} bytes: {}", payload_len, err))?;

    // Deliberately ignored: the peer may legitimately stay silent or close the connection.
    let mut reply = [0u8; 1024];
    let _ = stream.read(&mut reply);
    Ok(())
}