use std::cmp;
use std::time;
/// Search upper bound used for IPv4 paths when the caller passes the crate
/// default as `max_pmtu` (see `Dplpmtud::new`).
/// Presumably 1500 (Ethernet MTU) - 20 (IPv4 header) - 8 (UDP header).
const MAX_PACKET_SIZE_IPV4: usize = 1472;
/// Search upper bound used for IPv6 paths when the caller passes the crate
/// default as `max_pmtu` (see `Dplpmtud::new`).
/// Presumably 1500 (Ethernet MTU) - 40 (IPv6 header) - 8 (UDP header).
const MAX_PACKET_SIZE_IPV6: usize = 1452;
/// Number of reported losses of the same probe size after which that size is
/// recorded as failed and the search moves on to a smaller size.
const MAX_PROBE_COUNT: u8 = 3;
/// State of a path MTU search: it tracks the largest packet size confirmed so
/// far, the size currently being probed, and the upper bound of the search.
///
/// NOTE(review): the name suggests Datagram Packetization Layer PMTU Discovery
/// (RFC 8899) -- confirm against the module documentation.
#[derive(Default)]
pub(super) struct Dplpmtud {
/// Whether a probe should be sent next. Cleared when a probe is sent and
/// recomputed whenever a probe is acknowledged or reported lost.
should_probe: bool,
/// Largest packet size confirmed so far. Starts at
/// `crate::DEFAULT_SEND_UDP_PAYLOAD_SIZE` and only grows (on probe ack).
current_size: usize,
/// Size of the probe currently being tried; `None` until it is first
/// requested through `get_probe_size`.
probe_size: Option<usize>,
/// Number of losses reported so far for the current probe size.
probe_count: u8,
/// Last probe size that was given up on after `MAX_PROBE_COUNT` losses.
/// `0` means no size has failed yet; a non-zero value replaces `max_pmtu`
/// as the search ceiling.
failed_size: usize,
/// Upper bound of the search while no probe size has failed.
max_pmtu: usize,
/// Whether the path is IPv6. In the visible code it is only used by `new`
/// to pick the default `max_pmtu`.
is_ipv6: bool,
}
impl Dplpmtud {
    /// While no probe has failed, a ceiling below this value is probed directly
    /// instead of being approached by bisection. 1500 is the conventional
    /// Ethernet MTU.
    const ETHERNET_MTU: usize = 1500;

    /// Creates the search state for one path.
    ///
    /// * `enable` - initial value of the "should probe" flag.
    /// * `max_pmtu` - upper bound of the search. A value equal to
    ///   `crate::DEFAULT_SEND_UDP_PAYLOAD_SIZE` is treated as "not configured"
    ///   and replaced by the per-address-family maximum.
    /// * `is_ipv6` - selects which per-address-family maximum applies.
    ///
    /// The confirmed size always starts at `crate::DEFAULT_SEND_UDP_PAYLOAD_SIZE`.
    pub(super) fn new(enable: bool, max_pmtu: usize, is_ipv6: bool) -> Self {
        let max_pmtu = match (max_pmtu == crate::DEFAULT_SEND_UDP_PAYLOAD_SIZE, is_ipv6) {
            (true, true) => MAX_PACKET_SIZE_IPV6,
            (true, false) => MAX_PACKET_SIZE_IPV4,
            (false, _) => max_pmtu,
        };

        Self {
            should_probe: enable,
            current_size: crate::DEFAULT_SEND_UDP_PAYLOAD_SIZE,
            probe_size: None,
            max_pmtu,
            is_ipv6,
            ..Self::default()
        }
    }

    /// Returns whether a probe packet should be sent next.
    pub(super) fn should_probe(&self) -> bool {
        self.should_probe
    }

    /// Returns the size of the probe to send.
    ///
    /// The size is computed on the first call and cached. Until an ack or a
    /// final loss recomputes it, later calls return the cached value and do
    /// not consult `peer_max_udp_payload` again.
    pub(super) fn get_probe_size(&mut self, peer_max_udp_payload: usize) -> usize {
        match self.probe_size {
            Some(size) => size,
            None => {
                let size = self.cal_probe_size(peer_max_udp_payload);
                self.probe_size = Some(size);
                size
            }
        }
    }

    /// Returns the largest packet size confirmed so far.
    pub(super) fn get_current_size(&self) -> usize {
        self.current_size
    }

    /// Records that a probe was sent; no further probe is requested until the
    /// outcome of this one is reported.
    ///
    /// The packet size is accepted for symmetry with the other callbacks but
    /// is not needed here.
    pub(super) fn on_pmtu_probe_sent(&mut self, _pkt_size: usize) {
        self.should_probe = false;
    }

    /// Handles the acknowledgement of a probe of `pkt_size` bytes.
    ///
    /// The confirmed size is raised to `pkt_size` (it never shrinks), then the
    /// next probe size is chosen and probing is re-armed unless the search is
    /// finished.
    pub(super) fn on_pmtu_probe_acked(&mut self, pkt_size: usize, peer_max_udp_payload: usize) {
        self.current_size = cmp::max(self.current_size, pkt_size);
        self.start_next_round(peer_max_udp_payload);
    }

    /// Handles the loss of a probe of `pkt_size` bytes.
    ///
    /// Losses of anything other than the current probe size are ignored. The
    /// same size is retried until it has been lost `MAX_PROBE_COUNT` times;
    /// after that the size is recorded as failed, which lowers the search
    /// ceiling, and a new probe size is chosen.
    pub(super) fn on_pmtu_probe_lost(&mut self, pkt_size: usize, peer_max_udp_payload: usize) {
        if self.probe_size != Some(pkt_size) {
            return;
        }

        self.probe_count += 1;
        if self.probe_count < MAX_PROBE_COUNT {
            // Retry the same size before concluding that it does not fit.
            self.should_probe = true;
            return;
        }

        self.failed_size = pkt_size;
        self.start_next_round(peer_max_udp_payload);
    }

    /// Picks the next probe size, resets the loss counter, and re-arms probing
    /// unless the search has converged.
    fn start_next_round(&mut self, peer_max_udp_payload: usize) {
        self.probe_size = Some(self.cal_probe_size(peer_max_udp_payload));
        self.probe_count = 0;
        self.should_probe = !self.check_finish(peer_max_udp_payload);
    }

    /// Computes the next size to probe.
    fn cal_probe_size(&self, peer_max_udp_payload: usize) -> usize {
        let mtu_ceiling = self.cal_mtu_ceiling(peer_max_udp_payload);

        // Nothing has failed yet and the ceiling is below the Ethernet MTU:
        // try the ceiling straight away.
        if self.failed_size == 0 && mtu_ceiling < Self::ETHERNET_MTU {
            return mtu_ceiling;
        }

        // Otherwise bisect between the confirmed size and the ceiling.
        (self.current_size + mtu_ceiling) / 2
    }

    /// Returns the current upper bound of the search: the failed size if one
    /// is recorded, otherwise `max_pmtu`, capped by `peer_max_udp_payload`.
    fn cal_mtu_ceiling(&self, peer_max_udp_payload: usize) -> usize {
        let upper_bound = if self.failed_size > 0 {
            self.failed_size
        } else {
            self.max_pmtu
        };
        cmp::min(upper_bound, peer_max_udp_payload)
    }

    /// Returns whether the search is finished: the confirmed size has reached
    /// the ceiling or is within 1% of it.
    fn check_finish(&self, peer_max_udp_payload: usize) -> bool {
        let mtu_ceiling = self.cal_mtu_ceiling(peer_max_udp_payload);
        self.current_size >= mtu_ceiling || self.current_size as f64 / mtu_ceiling as f64 >= 0.99
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Initial state: the probing flag follows `enable`, the confirmed size is
    /// the crate default, and the first probe is capped by the peer's limit.
    #[test]
    fn dplpmtud_default() {
        let d = Dplpmtud::new(false, 1500, true);
        assert!(!d.should_probe());
        assert_eq!(d.get_current_size(), crate::DEFAULT_SEND_UDP_PAYLOAD_SIZE);

        let mut d = Dplpmtud::new(true, 1500, false);
        let peer_max_udp_payload = 1400;
        assert!(d.should_probe());
        assert_eq!(d.get_current_size(), crate::DEFAULT_SEND_UDP_PAYLOAD_SIZE);
        assert_eq!(d.get_probe_size(peer_max_udp_payload), peer_max_udp_payload);
    }

    /// The path carries the largest IPv4 size: a single acknowledged probe
    /// finishes the search.
    #[test]
    fn dplpmtud_max() {
        let mut d = Dplpmtud::new(true, 1200, false);
        let peer_max_udp_payload = 60000;
        assert!(d.should_probe());

        let probe_size = d.get_probe_size(peer_max_udp_payload);
        d.on_pmtu_probe_sent(probe_size);
        assert!(!d.should_probe());

        d.on_pmtu_probe_acked(probe_size, peer_max_udp_payload);
        assert_eq!(d.get_current_size(), 1472);
        assert!(!d.should_probe());
    }

    /// Every probe is lost: the search ends with the confirmed size unchanged.
    #[test]
    fn dplpmtud_min() {
        let mut d = Dplpmtud::new(true, 1200, true);
        let peer_max_udp_payload = 60000;
        assert!(d.should_probe());

        for _ in 0..10 {
            let probe_size = d.get_probe_size(peer_max_udp_payload);
            // A size is only given up on after MAX_PROBE_COUNT losses.
            for _ in 0..MAX_PROBE_COUNT {
                d.on_pmtu_probe_sent(probe_size);
                d.on_pmtu_probe_lost(probe_size, peer_max_udp_payload);
            }
            assert_eq!(d.failed_size, probe_size);

            if !d.should_probe() {
                break;
            }
        }

        assert_eq!(d.get_current_size(), 1200);
        assert!(!d.should_probe());
    }

    /// The path carries at most 1350 bytes: larger probes are lost, smaller
    /// ones are acknowledged, and the search converges just below the limit.
    #[test]
    fn dplpmtud_mid() {
        let mut d = Dplpmtud::new(true, 1200, true);
        let peer_max_udp_payload = 60000;
        assert!(d.should_probe());

        let pmtu = 1350;
        for _ in 0..10 {
            let probe_size = d.get_probe_size(peer_max_udp_payload);
            if probe_size > pmtu {
                for _ in 0..MAX_PROBE_COUNT {
                    d.on_pmtu_probe_sent(probe_size);
                    d.on_pmtu_probe_lost(probe_size, peer_max_udp_payload);
                }
                assert_eq!(d.failed_size, probe_size);
            } else {
                d.on_pmtu_probe_sent(probe_size);
                d.on_pmtu_probe_acked(probe_size, peer_max_udp_payload);
                assert_eq!(d.get_current_size(), probe_size);
            }

            if !d.should_probe() {
                break;
            }
        }

        assert_eq!(d.get_current_size(), 1349);
        assert!(!d.should_probe());
    }
}