use std::time::{Duration, Instant};
use tracing::{debug, info, warn};
/// Per-packet bytes reserved for the VCL protocol header (value taken as given; not derived here).
pub const VCL_HEADER_OVERHEAD: usize = 149;
/// Standard Ethernet MTU in bytes; the default start and ceiling for discovery.
pub const ETHERNET_MTU: usize = 1_500;
/// Size of an IPv4 header without options, in bytes.
pub const IPV4_HEADER: usize = 20;
/// Size of the fixed IPv6 header, in bytes.
pub const IPV6_HEADER: usize = 40;
/// Size of a UDP header, in bytes.
pub const UDP_HEADER: usize = 8;
/// Default lower bound for MTU discovery.
pub const MIN_MTU: usize = 576;
/// Absolute ceiling accepted when the MTU is set manually.
pub const MAX_MTU: usize = 9_000;
/// Tunables for path-MTU discovery.
#[derive(Clone, Debug)]
pub struct MtuConfig {
    /// MTU in use before discovery produces a result.
    pub start_mtu: usize,
    /// Lower bound of the search; also the floor for manual and fallback MTUs.
    pub min_mtu: usize,
    /// Upper bound of the search.
    pub max_mtu: usize,
    /// Search resolution in bytes: discovery stops once the window is at most this wide.
    pub step: usize,
    /// How long a probe may stay outstanding before it is reported as timed out.
    pub probe_timeout: Duration,
    /// Budget for an IPv6 header instead of an IPv4 one.
    pub ipv6: bool,
    /// Additional per-packet bytes subtracted from the payload budget.
    pub extra_overhead: usize,
}
impl Default for MtuConfig {
    /// IPv4/UDP over a plain Ethernet path: search `[MIN_MTU, ETHERNET_MTU]`
    /// starting from `ETHERNET_MTU`, in 8-byte steps, with a 2 s probe timeout
    /// and no extra overhead.
    fn default() -> Self {
        Self {
            ipv6: false,
            extra_overhead: 0,
            step: 8,
            probe_timeout: Duration::from_secs(2),
            min_mtu: MIN_MTU,
            start_mtu: ETHERNET_MTU,
            max_mtu: ETHERNET_MTU,
        }
    }
}
impl MtuConfig {
pub fn ipv4_udp() -> Self {
MtuConfig::default()
}
pub fn ipv6_udp() -> Self {
MtuConfig {
ipv6: true,
..Default::default()
}
}
pub fn inside_wireguard() -> Self {
MtuConfig {
max_mtu: 1420,
start_mtu: 1420,
extra_overhead: 60,
..Default::default()
}
}
}
/// Snapshot of a path-MTU result, stamped with the moment it was taken.
#[derive(Clone, Debug)]
pub struct PathMtu {
    /// Path MTU in bytes.
    pub mtu: usize,
    /// Payload fragment size derived from `mtu`.
    pub fragment_size: usize,
    /// When the snapshot was created.
    pub measured_at: Instant,
    /// Whether the value is backed by at least one successful probe, as
    /// reported by whoever built the snapshot.
    pub is_probed: bool,
}
impl PathMtu {
pub fn new(mtu: usize, fragment_size: usize, is_probed: bool) -> Self {
PathMtu {
mtu,
fragment_size,
measured_at: Instant::now(),
is_probed,
}
}
pub fn is_stale(&self, max_age: Duration) -> bool {
self.measured_at.elapsed() > max_age
}
}
/// Lifecycle of an MTU negotiator.
#[derive(Clone, Debug, PartialEq)]
pub enum MtuState {
    /// Created; discovery has not started.
    Initial,
    /// Binary search in progress.
    Probing {
        /// Largest size that has passed a probe, or the configured floor.
        low: usize,
        /// Upper bound of the remaining search window.
        high: usize,
        /// Size of the probe currently outstanding.
        current: usize,
    },
    /// A final MTU has been chosen, by discovery or manually.
    Confirmed(usize),
    /// Discovery was abandoned and the configured minimum is in use.
    FallbackToMin,
}
/// Drives path-MTU discovery by binary search and tracks the working MTU.
///
/// The negotiator does not send packets itself: callers send the probe sizes
/// it hands out and report each outcome back.
#[derive(Debug)]
pub struct MtuNegotiator {
    /// Discovery parameters.
    config: MtuConfig,
    /// Current lifecycle state.
    state: MtuState,
    /// MTU currently in use.
    current_mtu: usize,
    /// Size and send time of the probe awaiting a result, if any.
    pending_probe: Option<(usize, Instant)>,
    /// Every reported probe result as `(size, success)`, in report order.
    probe_history: Vec<(usize, bool)>,
    /// Number of probes handed out to the caller.
    total_probes: u64,
    /// Number of probe results reported as successful.
    successful_probes: u64,
}
impl MtuNegotiator {
    /// Creates a negotiator in [`MtuState::Initial`].
    ///
    /// The working MTU starts at `config.start_mtu`, limited to
    /// `config.max_mtu` and raised to at least `config.min_mtu`. This keeps it
    /// inside the range the rest of the negotiator enforces.
    pub fn new(config: MtuConfig) -> Self {
        // `min` then `max` instead of `clamp`: `clamp` panics on a misconfigured
        // `min_mtu > max_mtu`, whereas here the floor simply wins.
        let current_mtu = config.start_mtu.min(config.max_mtu).max(config.min_mtu);
        MtuNegotiator {
            config,
            state: MtuState::Initial,
            current_mtu,
            pending_probe: None,
            probe_history: Vec::new(),
            total_probes: 0,
            successful_probes: 0,
        }
    }

    /// Starts a binary search over `[config.min_mtu, config.max_mtu]`.
    ///
    /// Returns the first size to probe (the window midpoint). The probe is
    /// marked pending for [`Self::check_probe_timeout`] and counted in
    /// [`Self::total_probes`].
    pub fn start_discovery(&mut self) -> usize {
        let low = self.config.min_mtu;
        let high = self.config.max_mtu;
        let current = (low + high) / 2;
        self.state = MtuState::Probing { low, high, current };
        self.pending_probe = Some((current, Instant::now()));
        self.total_probes += 1;
        info!(low, high, probe_size = current, "MTU discovery started");
        current
    }

    /// Records the outcome of a probe and advances the search.
    ///
    /// While probing, the outcome is applied to the outstanding probe size
    /// held in the state. `size` is only logged and stored in the history, so
    /// it is assumed to match the size returned by the previous call.
    ///
    /// Returns the next size to probe. Returns `None` once the window is no
    /// wider than `config.step`, at which point the state becomes `Confirmed`,
    /// and also when no search is in progress.
    ///
    /// Outside a search, a successful probe larger than the current MTU and
    /// not above `config.max_mtu` raises the working MTU. If the state is
    /// `Confirmed`, its value is updated to match; other states are left
    /// unchanged.
    pub fn record_probe(&mut self, size: usize, success: bool) -> Option<usize> {
        self.probe_history.push((size, success));
        self.pending_probe = None;
        if success {
            self.successful_probes += 1;
            debug!(size, "MTU probe succeeded");
        } else {
            warn!(size, "MTU probe failed (packet dropped)");
        }
        match self.state {
            MtuState::Probing { low, high, current } => {
                // A zero step could never shrink a 1-byte window; treat it as 1.
                let step = self.config.step.max(1);
                let (new_low, new_high) = if success {
                    self.current_mtu = current;
                    (current, high)
                } else {
                    // Saturating so a step larger than the probe cannot underflow.
                    (low, current.saturating_sub(step))
                };
                if new_high <= new_low || new_high - new_low <= step {
                    // `new_low` is the largest size that passed a probe, or the
                    // floor if none did. `self.current_mtu` must not be used here:
                    // when every probe failed, it still holds the unverified start MTU.
                    let final_mtu = new_low.max(self.config.min_mtu);
                    self.current_mtu = final_mtu;
                    self.state = MtuState::Confirmed(final_mtu);
                    info!(mtu = final_mtu, "MTU discovery complete");
                    return None;
                }
                let next = (new_low + new_high) / 2;
                self.state = MtuState::Probing {
                    low: new_low,
                    high: new_high,
                    current: next,
                };
                self.pending_probe = Some((next, Instant::now()));
                self.total_probes += 1;
                debug!(next_probe = next, low = new_low, high = new_high, "Next MTU probe");
                Some(next)
            }
            _ => {
                if success && size > self.current_mtu && size <= self.config.max_mtu {
                    self.current_mtu = size;
                    // Keep the reported state consistent with the working MTU.
                    if matches!(self.state, MtuState::Confirmed(_)) {
                        self.state = MtuState::Confirmed(size);
                    }
                }
                None
            }
        }
    }

    /// Returns the size of the outstanding probe if it has been pending for
    /// longer than `config.probe_timeout`, else `None`.
    ///
    /// This only inspects state. Reporting the loss, e.g. with
    /// `record_probe(size, false)`, is left to the caller.
    pub fn check_probe_timeout(&self) -> Option<usize> {
        if let Some((size, sent_at)) = self.pending_probe {
            if sent_at.elapsed() > self.config.probe_timeout {
                return Some(size);
            }
        }
        None
    }

    /// MTU currently in use.
    pub fn current_mtu(&self) -> usize {
        self.current_mtu
    }

    /// Current lifecycle state.
    pub fn state(&self) -> &MtuState {
        &self.state
    }

    /// `true` once an MTU is settled, whether `Confirmed` or `FallbackToMin`.
    pub fn is_complete(&self) -> bool {
        matches!(self.state, MtuState::Confirmed(_) | MtuState::FallbackToMin)
    }

    /// Payload bytes per fragment that fit in the current MTU.
    ///
    /// The IP (v4 or v6), UDP, VCL and `config.extra_overhead` header bytes
    /// are subtracted, and the result is rounded down to a multiple of 8. The
    /// result never goes below a 64-byte floor, so it is never 0.
    pub fn recommended_fragment_size(&self) -> usize {
        const MIN_FRAGMENT_SIZE: usize = 64;
        let ip_header = if self.config.ipv6 { IPV6_HEADER } else { IPV4_HEADER };
        let overhead = ip_header
            + UDP_HEADER
            + VCL_HEADER_OVERHEAD
            + self.config.extra_overhead;
        if self.current_mtu <= overhead {
            warn!(
                mtu = self.current_mtu,
                overhead,
                "MTU smaller than overhead — using minimum fragment size"
            );
            return MIN_FRAGMENT_SIZE;
        }
        // An MTU only a few bytes above the overhead would otherwise round to 0.
        (((self.current_mtu - overhead) / 8) * 8).max(MIN_FRAGMENT_SIZE)
    }

    /// Overrides the MTU and marks it `Confirmed`, bypassing discovery.
    ///
    /// The value is limited to `[config.min_mtu, MAX_MTU]`; `config.max_mtu`
    /// does not apply to manual overrides. Any outstanding probe is abandoned
    /// so [`Self::check_probe_timeout`] stops reporting it.
    pub fn set_mtu(&mut self, mtu: usize) {
        // `max` then `min` instead of `clamp`, which panics if `min_mtu > MAX_MTU`.
        let clamped = mtu.max(self.config.min_mtu).min(MAX_MTU);
        info!(mtu = clamped, "MTU manually set");
        self.current_mtu = clamped;
        self.state = MtuState::Confirmed(clamped);
        self.pending_probe = None;
    }

    /// Abandons discovery and uses `config.min_mtu`.
    ///
    /// Any outstanding probe is dropped so it no longer times out.
    pub fn fallback_to_min(&mut self) {
        warn!(min = self.config.min_mtu, "MTU falling back to minimum");
        self.current_mtu = self.config.min_mtu;
        self.state = MtuState::FallbackToMin;
        self.pending_probe = None;
    }

    /// Snapshot of the current MTU and fragment size.
    ///
    /// `is_probed` is set when any probe has ever succeeded, even if the MTU
    /// was later set manually.
    pub fn path_mtu(&self) -> PathMtu {
        PathMtu::new(
            self.current_mtu,
            self.recommended_fragment_size(),
            self.successful_probes > 0,
        )
    }

    /// Number of probes handed out so far.
    pub fn total_probes(&self) -> u64 {
        self.total_probes
    }

    /// Number of probe results reported as successful.
    pub fn successful_probes(&self) -> u64 {
        self.successful_probes
    }

    /// All reported probe results as `(size, success)`, in report order.
    pub fn probe_history(&self) -> &[(usize, bool)] {
        &self.probe_history
    }
}
/// Payload bytes per fragment that fit in `mtu`.
///
/// The IP header (IPv6 if `ipv6`, else IPv4), UDP header, VCL header and
/// `extra_overhead` are subtracted, and the result is rounded down to a
/// multiple of 8. The result never goes below a 64-byte floor. This covers an
/// MTU at or below the overhead, and one only a few bytes above it, which
/// would otherwise round down to an unusable 0.
pub fn fragment_size_for_mtu(mtu: usize, ipv6: bool, extra_overhead: usize) -> usize {
    const MIN_FRAGMENT_SIZE: usize = 64;
    let ip_header = if ipv6 { IPV6_HEADER } else { IPV4_HEADER };
    let overhead = ip_header + UDP_HEADER + VCL_HEADER_OVERHEAD + extra_overhead;
    let usable = mtu.saturating_sub(overhead);
    ((usable / 8) * 8).max(MIN_FRAGMENT_SIZE)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_default_config() {
        let c = MtuConfig::default();
        assert_eq!(c.start_mtu, 1500);
        assert_eq!(c.min_mtu, 576);
        assert!(!c.ipv6);
    }

    #[test]
    fn test_ipv6_config() {
        let c = MtuConfig::ipv6_udp();
        assert!(c.ipv6);
    }

    #[test]
    fn test_wireguard_config() {
        let c = MtuConfig::inside_wireguard();
        assert_eq!(c.max_mtu, 1420);
        assert_eq!(c.extra_overhead, 60);
    }

    #[test]
    fn test_negotiator_new() {
        let n = MtuNegotiator::new(MtuConfig::default());
        assert_eq!(n.state(), &MtuState::Initial);
        assert_eq!(n.current_mtu(), 1500);
        assert!(!n.is_complete());
    }

    #[test]
    fn test_start_discovery() {
        let mut n = MtuNegotiator::new(MtuConfig::default());
        let probe = n.start_discovery();
        assert!(probe > 576 && probe < 1500);
        // Midpoint of [576, 1500].
        assert_eq!(probe, 1038);
        assert!(matches!(n.state(), MtuState::Probing { .. }));
        assert_eq!(n.total_probes(), 1);
    }

    #[test]
    fn test_record_probe_success() {
        let mut n = MtuNegotiator::new(MtuConfig::default());
        n.start_discovery();
        let next = n.record_probe(1038, true);
        // Window becomes [1038, 1500]; next probe is its midpoint.
        assert_eq!(next, Some(1269));
        assert_eq!(n.current_mtu(), 1038);
    }

    #[test]
    fn test_record_probe_failure() {
        let mut n = MtuNegotiator::new(MtuConfig::default());
        n.start_discovery();
        let next = n.record_probe(1038, false);
        // Window becomes [576, 1038 - step]; next probe is its midpoint.
        assert_eq!(next, Some(803));
        assert_eq!(
            n.state(),
            &MtuState::Probing { low: 576, high: 1030, current: 803 }
        );
        assert!(n.current_mtu() <= 1500);
    }

    #[test]
    fn test_full_discovery_converges() {
        let mut n = MtuNegotiator::new(MtuConfig::default());
        let mut probe = n.start_discovery();
        for _ in 0..20 {
            let success = probe <= 1400;
            match n.record_probe(probe, success) {
                Some(next) => probe = next,
                None => break,
            }
        }
        assert!(n.is_complete());
        assert!(n.current_mtu() <= 1400);
        assert!(n.current_mtu() >= 576);
        assert_eq!(n.state(), &MtuState::Confirmed(1396));
    }

    #[test]
    fn test_recommended_fragment_size() {
        let mut n = MtuNegotiator::new(MtuConfig::default());
        n.set_mtu(1500);
        let fs = n.recommended_fragment_size();
        assert!(fs > 0);
        assert!(fs < 1500);
        assert_eq!(fs % 8, 0);
        // 1500 - (20 + 8 + 149) = 1323, rounded down to a multiple of 8.
        assert_eq!(fs, 1320);
    }

    #[test]
    fn test_fragment_size_for_mtu_fn() {
        let fs = fragment_size_for_mtu(1500, false, 0);
        assert!(fs > 0 && fs < 1500);
        assert_eq!(fs % 8, 0);
        assert_eq!(fs, 1320);
        let fs_v6 = fragment_size_for_mtu(1500, true, 0);
        assert!(fs_v6 < fs);
        assert_eq!(fs_v6, 1296);
    }

    #[test]
    fn test_set_mtu() {
        let mut n = MtuNegotiator::new(MtuConfig::default());
        n.set_mtu(1280);
        assert_eq!(n.current_mtu(), 1280);
        assert!(n.is_complete());
        assert!(matches!(n.state(), MtuState::Confirmed(1280)));
    }

    #[test]
    fn test_set_mtu_clamped() {
        let mut n = MtuNegotiator::new(MtuConfig::default());
        n.set_mtu(100);
        assert_eq!(n.current_mtu(), 576);
    }

    #[test]
    fn test_fallback_to_min() {
        let mut n = MtuNegotiator::new(MtuConfig::default());
        n.fallback_to_min();
        assert_eq!(n.current_mtu(), 576);
        assert_eq!(n.state(), &MtuState::FallbackToMin);
        assert!(n.is_complete());
    }

    #[test]
    fn test_path_mtu() {
        let mut n = MtuNegotiator::new(MtuConfig::default());
        n.set_mtu(1400);
        let pm = n.path_mtu();
        assert_eq!(pm.mtu, 1400);
        assert!(pm.fragment_size < 1400);
        assert_eq!(pm.fragment_size, 1216);
        // No probe has succeeded, so the snapshot is not probe-backed.
        assert!(!pm.is_probed);
    }

    #[test]
    fn test_probe_history() {
        let mut n = MtuNegotiator::new(MtuConfig::default());
        n.start_discovery();
        n.record_probe(1038, true);
        assert_eq!(n.probe_history().len(), 1);
        assert_eq!(n.probe_history()[0], (1038, true));
    }

    #[test]
    fn test_check_probe_timeout_no_pending() {
        let n = MtuNegotiator::new(MtuConfig::default());
        assert!(n.check_probe_timeout().is_none());
    }

    #[test]
    fn test_check_probe_timeout_not_yet() {
        let mut n = MtuNegotiator::new(MtuConfig::default());
        n.start_discovery();
        assert!(n.check_probe_timeout().is_none());
    }

    #[test]
    fn test_mtu_smaller_than_overhead() {
        let config = MtuConfig {
            start_mtu: 100,
            min_mtu: 64,
            max_mtu: 100,
            ..Default::default()
        };
        let mut n = MtuNegotiator::new(config);
        n.set_mtu(100);
        assert_eq!(n.recommended_fragment_size(), 64);
    }

    #[test]
    fn test_extra_overhead() {
        let fs1 = fragment_size_for_mtu(1500, false, 0);
        let fs2 = fragment_size_for_mtu(1500, false, 60);
        assert!(fs2 < fs1);
    }

    #[test]
    fn test_total_probes_counted() {
        let mut n = MtuNegotiator::new(MtuConfig::default());
        n.start_discovery();
        assert_eq!(n.total_probes(), 1);
        n.record_probe(1038, true);
        // Discovery has not converged, so a second probe was handed out.
        assert_eq!(n.total_probes(), 2);
    }
}