use crate::{
counter::{Counter, Saturating},
event::{self, builder::MtuUpdatedCause, IntoEvent},
frame, inet,
packet::number::PacketNumber,
path::{self, mtu},
recovery::{congestion_controller, CongestionController},
time::{timer, Timer, Timestamp},
transmission,
};
use core::{
fmt,
fmt::{Display, Formatter},
num::NonZeroU16,
time::Duration,
};
use s2n_codec::EncoderValue;
#[cfg(test)]
mod tests;
#[cfg(any(test, feature = "testing"))]
pub mod testing {
    //! Helpers for building [`Controller`]s in tests.
    use super::*;
    use crate::inet::{IpV4Address, SocketAddressV4};

    /// Builds a controller for a loopback IPv4 peer (127.0.0.1:443).
    ///
    /// The max MTU is `max_mtu`; every other MTU setting uses its default.
    ///
    /// # Panics
    /// Panics if `max_mtu` is below `MINIMUM_MTU`.
    pub fn new_controller(max_mtu: u16) -> Controller {
        let loopback = SocketAddressV4::new(IpV4Address::new([127, 0, 0, 1]), 443);
        let peer = inet::SocketAddress::IpV4(loopback);
        let config = Config {
            max_mtu: max_mtu.try_into().unwrap(),
            ..Default::default()
        };
        Controller::new(config, &peer)
    }

    /// Builds a controller with the largest possible max MTU.
    ///
    /// The current `plpmtu` and `probed_size` are then overwritten directly, which bypasses the
    /// state machine.
    pub fn test_controller(mtu: u16, probed_size: u16) -> Controller {
        Controller {
            plpmtu: mtu,
            probed_size,
            ..new_controller(u16::MAX)
        }
    }
}
/// Phase of the path MTU discovery state machine driven by `Controller`.
#[derive(Clone, Debug, PartialEq, Eq)]
enum State {
    /// The configured initial MTU is above the base MTU and is used before probing is enabled.
    ///
    /// This state ends on the first acknowledgement of a packet larger than the base PLPMTU, or
    /// on the first reported packet loss.
    EarlySearchRequested,
    /// No probing is in progress.
    Disabled,
    /// A probe of `probed_size` bytes is waiting to be transmitted.
    SearchRequested,
    /// A probe is in flight: its packet number and the time it was transmitted.
    Searching(PacketNumber, Timestamp),
    /// Probing has finished for now.
    SearchComplete,
}

impl State {
    fn is_early_search_requested(&self) -> bool {
        *self == Self::EarlySearchRequested
    }

    fn is_disabled(&self) -> bool {
        *self == Self::Disabled
    }

    fn is_search_complete(&self) -> bool {
        *self == Self::SearchComplete
    }
}
/// Probes of a single size that may be lost before that size is abandoned.
const MAX_PROBES: u8 = 3;
/// Standard Ethernet MTU, in bytes.
const ETHERNET_MTU: u16 = 1500;
/// Smallest gain over the current PLPMTU, in bytes, that makes another probe worthwhile.
const PROBE_THRESHOLD: u16 = 20;
/// Suspicious losses tolerated before a black hole is declared.
const BLACK_HOLE_THRESHOLD: u8 = 3;
/// Delay before probing resumes after a black hole has been detected.
const BLACK_HOLE_COOL_OFF_DURATION: Duration = Duration::from_secs(60);
/// Delay after a completed search before searching for a larger MTU again.
const PMTU_RAISE_TIMER_DURATION: Duration = Duration::from_secs(600);
/// Smallest maximum datagram (UDP payload) size that is ever used.
pub const MINIMUM_MAX_DATAGRAM_SIZE: u16 = 1200;
/// Length of a UDP header, in bytes.
const UDP_HEADER_LEN: u16 = 8;
/// Length of an IPv4 header without options, in bytes.
const IPV4_MIN_HEADER_LEN: u16 = 20;
/// Length of an IPv6 header without extension headers, in bytes.
const IPV6_MIN_HEADER_LEN: u16 = 40;

/// Returns the smaller of `a` and `b`. This is a `const fn`, so it can be used in the constant
/// expression below.
const fn const_min(a: u16, b: u16) -> u16 {
    if b < a {
        b
    } else {
        a
    }
}

/// Smallest permitted MTU.
///
/// This is a `MINIMUM_MAX_DATAGRAM_SIZE` payload plus the UDP header and the shorter of the two
/// minimum IP headers.
const MINIMUM_MTU: u16 =
    MINIMUM_MAX_DATAGRAM_SIZE + UDP_HEADER_LEN + const_min(IPV4_MIN_HEADER_LEN, IPV6_MIN_HEADER_LEN);
/// Defines an MTU newtype, `$name`, over `NonZeroU16` whose value is always at least
/// `MINIMUM_MTU`, along with its conversions. `$default` is the expression returned by
/// `Default::default()`.
macro_rules! impl_mtu {
    ($name:ident, $default:expr) => {
        /// An MTU in bytes, guaranteed to be at least `MINIMUM_MTU`.
        #[derive(Clone, Copy, Debug, PartialEq)]
        pub struct $name(NonZeroU16);

        impl $name {
            /// The smallest permitted value, `MINIMUM_MTU`.
            // This is checked during const evaluation instead of using `new_unchecked`. Because
            // `MINIMUM_MTU` is a non-zero constant, the `None` arm can never be reached at
            // runtime, and no `unsafe` is needed.
            pub const MIN: Self = Self(match NonZeroU16::new(MINIMUM_MTU) {
                Some(mtu) => mtu,
                None => panic!("MINIMUM_MTU must be non-zero"),
            });

            /// Returns the maximum UDP payload that fits in this MTU when sending to
            /// `peer_socket_address`.
            ///
            /// The payload is the MTU minus the UDP header and the minimum IP header for the
            /// peer's address family. It is never smaller than `MINIMUM_MAX_DATAGRAM_SIZE`.
            #[inline]
            pub fn max_datagram_size(&self, peer_socket_address: &inet::SocketAddress) -> u16 {
                let min_ip_header_len = match peer_socket_address {
                    inet::SocketAddress::IpV4(_) => IPV4_MIN_HEADER_LEN,
                    inet::SocketAddress::IpV6(_) => IPV6_MIN_HEADER_LEN,
                };
                // The value is at least MINIMUM_MTU, so subtracting the UDP header and the largest
                // minimum IP header cannot underflow.
                (u16::from(*self) - UDP_HEADER_LEN - min_ip_header_len)
                    .max(MINIMUM_MAX_DATAGRAM_SIZE)
            }
        }

        impl Default for $name {
            #[inline]
            fn default() -> Self {
                $default
            }
        }

        impl TryFrom<u16> for $name {
            type Error = MtuError;

            /// Accepts `value` only when it is at least `MINIMUM_MTU`.
            ///
            /// # Errors
            /// Returns `MtuError` when `value` is below `MINIMUM_MTU`.
            fn try_from(value: u16) -> Result<Self, Self::Error> {
                if value < MINIMUM_MTU {
                    return Err(MtuError);
                }
                // Any value that passes the check above is non-zero. `ok_or` keeps this path
                // panic-free rather than relying on that invariant through `expect`.
                NonZeroU16::new(value).map($name).ok_or(MtuError)
            }
        }

        impl From<$name> for usize {
            #[inline]
            fn from(value: $name) -> Self {
                value.0.get() as usize
            }
        }

        impl From<$name> for u16 {
            #[inline]
            fn from(value: $name) -> Self {
                value.0.get()
            }
        }
    };
}
const DEFAULT_MAX_MTU: MaxMtu = MaxMtu(NonZeroU16::new(1500).unwrap());
const DEFAULT_BASE_MTU: BaseMtu = BaseMtu(NonZeroU16::new(MINIMUM_MTU).unwrap());
const DEFAULT_INITIAL_MTU: InitialMtu = InitialMtu(NonZeroU16::new(MINIMUM_MTU).unwrap());
impl_mtu!(MaxMtu, DEFAULT_MAX_MTU);
impl_mtu!(InitialMtu, DEFAULT_INITIAL_MTU);
impl_mtu!(BaseMtu, DEFAULT_BASE_MTU);
/// Error returned when an MTU value, or a combination of MTU values, is not permitted.
#[derive(Debug, Eq, PartialEq)]
pub struct MtuError;

impl Display for MtuError {
    /// Describes the required ordering of MTU values, including each default.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let base_default = DEFAULT_BASE_MTU.0;
        let initial_default = DEFAULT_INITIAL_MTU.0;
        let max_default = DEFAULT_MAX_MTU.0;
        f.write_fmt(format_args!(
            "MTU must have {} <= base_mtu (default: {}) <= initial_mtu (default: {}) <= max_mtu (default: {})",
            MINIMUM_MTU, base_default, initial_default, max_default
        ))
    }
}

impl core::error::Error for MtuError {}
/// Information about a newly created path, passed to [`Endpoint::on_path`].
#[non_exhaustive]
pub struct PathInfo<'a> {
    /// Address of the remote peer on this path.
    pub remote_address: event::api::SocketAddress<'a>,
}

impl<'a> PathInfo<'a> {
    /// Builds the path information for `remote_address`.
    #[inline]
    #[doc(hidden)]
    pub fn new(remote_address: &'a inet::SocketAddress) -> Self {
        let remote_address: event::api::SocketAddress<'a> = remote_address.into_event();
        Self { remote_address }
    }
}
#[derive(Debug)]
pub struct Manager<E: mtu::Endpoint> {
provider: E,
endpoint_mtu_config: Config,
}
impl<E: mtu::Endpoint> Manager<E> {
pub fn new(provider: E) -> Self {
Manager {
provider,
endpoint_mtu_config: Default::default(),
}
}
pub fn config(&mut self, remote_address: &inet::SocketAddress) -> Result<Config, MtuError> {
let info = mtu::PathInfo::new(remote_address);
if let Some(conn_config) = self.provider.on_path(&info, self.endpoint_mtu_config) {
ensure!(conn_config.is_valid(), Err(MtuError));
ensure!(
u16::from(conn_config.max_mtu) <= u16::from(self.endpoint_mtu_config.max_mtu()),
Err(MtuError)
);
Ok(conn_config)
} else {
Ok(self.endpoint_mtu_config)
}
}
pub fn set_endpoint_config(&mut self, config: Config) {
self.endpoint_mtu_config = config;
}
pub fn endpoint_config(&self) -> &Config {
&self.endpoint_mtu_config
}
}
/// Supplies per-path MTU configuration.
pub trait Endpoint: 'static + Send {
    /// Called for each new path described by `info`.
    ///
    /// Returns `Some(config)` to use `config` for this path, or `None` to keep
    /// `endpoint_mtu_config`.
    fn on_path(
        &mut self,
        info: &mtu::PathInfo,
        endpoint_mtu_config: Config,
    ) -> Option<mtu::Config>;
}

/// An [`Endpoint`] that never overrides, so every path uses the endpoint-wide configuration.
#[derive(Debug, Default)]
pub struct Inherit {}

impl Endpoint for Inherit {
    fn on_path(&mut self, _: &mtu::PathInfo, _: Config) -> Option<mtu::Config> {
        None
    }
}
/// MTU settings for a path: the initial, base, and maximum MTU.
#[derive(Copy, Clone, Debug, Default)]
pub struct Config {
    initial_mtu: InitialMtu,
    base_mtu: BaseMtu,
    max_mtu: MaxMtu,
}

/// A fixed `Config` can act as an [`Endpoint`], in which case every path receives a copy of it.
impl Endpoint for Config {
    fn on_path(&mut self, _: &mtu::PathInfo, _: Config) -> Option<mtu::Config> {
        let this = *self;
        Some(this)
    }
}

impl Config {
    /// The smallest configuration: every MTU is set to its `MIN`.
    pub const MIN: Self = Self {
        initial_mtu: InitialMtu::MIN,
        base_mtu: BaseMtu::MIN,
        max_mtu: MaxMtu::MIN,
    };

    /// Returns a [`Builder`] with nothing set.
    pub fn builder() -> Builder {
        Builder::default()
    }

    /// Returns the MTU used before probing.
    pub fn initial_mtu(&self) -> InitialMtu {
        self.initial_mtu
    }

    /// Returns the lowest MTU the path falls back to.
    pub fn base_mtu(&self) -> BaseMtu {
        self.base_mtu
    }

    /// Returns the largest MTU that may be probed.
    pub fn max_mtu(&self) -> MaxMtu {
        self.max_mtu
    }

    /// Returns `true` when `base_mtu <= initial_mtu <= max_mtu`.
    #[inline]
    pub fn is_valid(&self) -> bool {
        let (base, initial, max) = (self.base_mtu.0, self.initial_mtu.0, self.max_mtu.0);
        base <= initial && initial <= max
    }
}
#[derive(Debug, Default)]
pub struct Builder {
initial_mtu: Option<InitialMtu>,
base_mtu: Option<BaseMtu>,
max_mtu: Option<MaxMtu>,
}
impl Builder {
pub fn with_initial_mtu(mut self, initial_mtu: u16) -> Result<Self, MtuError> {
if let Some(base_mtu) = self.base_mtu {
ensure!(initial_mtu >= base_mtu.0.get(), Err(MtuError));
}
if let Some(max_mtu) = self.max_mtu {
ensure!(initial_mtu <= max_mtu.0.get(), Err(MtuError));
}
self.initial_mtu = Some(initial_mtu.try_into()?);
Ok(self)
}
pub fn with_base_mtu(mut self, base_mtu: u16) -> Result<Self, MtuError> {
if let Some(initial_mtu) = self.initial_mtu {
ensure!(initial_mtu.0.get() >= base_mtu, Err(MtuError));
}
if let Some(max_mtu) = self.max_mtu {
ensure!(base_mtu <= max_mtu.0.get(), Err(MtuError));
}
self.base_mtu = Some(base_mtu.try_into()?);
Ok(self)
}
pub fn with_max_mtu(mut self, max_mtu: u16) -> Result<Self, MtuError> {
if let Some(initial_mtu) = self.initial_mtu {
ensure!(initial_mtu.0.get() <= max_mtu, Err(MtuError));
}
if let Some(base_mtu) = self.base_mtu {
ensure!(base_mtu.0.get() <= max_mtu, Err(MtuError));
}
self.max_mtu = Some(max_mtu.try_into()?);
Ok(self)
}
pub fn build(self) -> Result<Config, MtuError> {
let base_mtu = self.base_mtu.unwrap_or_default();
let max_mtu = self.max_mtu.unwrap_or_default();
let mut initial_mtu = self.initial_mtu.unwrap_or_default();
if self.initial_mtu.is_none() {
initial_mtu = initial_mtu
.0
.max(base_mtu.0)
.min(max_mtu.0)
.get()
.try_into()?
};
let config = Config {
initial_mtu,
max_mtu,
base_mtu,
};
ensure!(config.is_valid(), Err(MtuError));
Ok(config)
}
}
/// Outcome of feeding an acknowledgement or a loss into the MTU controller.
#[derive(Debug, PartialEq, Eq)]
pub enum MtuResult {
    /// The maximum datagram size is unchanged.
    NoChange,
    /// The maximum datagram size changed to the contained value.
    MtuUpdated(u16),
}
/// Per-path MTU discovery state.
///
/// All sizes below are UDP payload sizes (maximum datagram sizes) rather than link MTUs; they are
/// derived from the configured MTUs via `max_datagram_size`.
#[derive(Clone, Debug)]
pub struct Controller {
/// Current phase of the discovery state machine.
state: State,
/// Datagram size for the configured base MTU; used as the fallback after a loss in the early
/// search phase or after a black hole is detected.
base_plpmtu: u16,
/// Current maximum datagram size in use for the path.
plpmtu: u16,
/// Datagram size for the configured max MTU; the upper bound for probing.
max_udp_payload: u16,
/// Size of the next probe to send, or of the probe currently in flight.
probed_size: u16,
/// Current upper bound for the probe-size search. It is lowered to a size whose probes were
/// repeatedly lost, and reset to `max_udp_payload` when the raise timer is armed.
max_probe_size: u16,
/// Number of probes sent at the current `probed_size`.
probe_count: u8,
/// Count of suspicious losses toward black-hole detection (saturating).
black_hole_counter: Counter<u8, Saturating>,
/// Largest acknowledged packet number whose size was at least `plpmtu`.
largest_acked_mtu_sized_packet: Option<PacketNumber>,
/// Timer that, once it expires, triggers a new search for a larger MTU.
pmtu_raise_timer: Timer,
/// Set when an `MtuProbingComplete` frame should be sent to the peer.
needs_to_send_completion: bool,
/// Whether sending `MtuProbingComplete` frames has been enabled for this path.
mtu_probing_complete_support: bool,
}
impl Controller {
/// Creates a controller for a path to `peer_socket_address` using `config`.
///
/// The base, initial and max MTUs are converted to datagram sizes for the peer's address
/// family. The initial state is chosen as follows:
/// - `EarlySearchRequested` when the initial MTU is above the base MTU.
/// - `SearchComplete` when the first probe would be less than `PROBE_THRESHOLD` above the base.
/// - `Disabled` otherwise.
#[inline]
pub fn new(config: Config, peer_socket_address: &inet::SocketAddress) -> Self {
debug_assert!(config.is_valid(), "Invalid MTU configuration {config:?}");
let base_plpmtu = config.base_mtu.max_datagram_size(peer_socket_address);
let max_udp_payload = config.max_mtu.max_datagram_size(peer_socket_address);
let plpmtu = config.initial_mtu.max_datagram_size(peer_socket_address);
// If the initial MTU is already within PROBE_THRESHOLD of the Ethernet MTU, the first probe
// is the midpoint between the current size and the max. Otherwise it is the payload that
// fits the Ethernet MTU with minimal IP and UDP headers.
let initial_probed_size = if u16::from(config.initial_mtu) > ETHERNET_MTU - PROBE_THRESHOLD
{
Self::next_probe_size(plpmtu, max_udp_payload)
} else {
let min_ip_header_len = match peer_socket_address {
inet::SocketAddress::IpV4(_) => IPV4_MIN_HEADER_LEN,
inet::SocketAddress::IpV6(_) => IPV6_MIN_HEADER_LEN,
};
ETHERNET_MTU - UDP_HEADER_LEN - min_ip_header_len
}
// Never probe beyond the configured maximum.
.min(max_udp_payload);
let state = if plpmtu > base_plpmtu {
// The initial MTU is above the base. Its first use is confirmed or rejected by
// on_packet_ack / on_packet_loss before regular probing starts.
State::EarlySearchRequested
} else if initial_probed_size - base_plpmtu < PROBE_THRESHOLD {
// There is too little headroom above the base for probing to be worthwhile.
State::SearchComplete
} else {
State::Disabled
};
Self {
state,
base_plpmtu,
plpmtu,
probed_size: initial_probed_size,
max_udp_payload,
max_probe_size: max_udp_payload,
probe_count: 0,
black_hole_counter: Default::default(),
largest_acked_mtu_sized_packet: None,
pmtu_raise_timer: Timer::default(),
needs_to_send_completion: false,
mtu_probing_complete_support: false,
}
}
/// Enables sending `MtuProbingComplete` frames.
///
/// If the search has already completed, a frame is queued immediately.
#[inline]
pub fn enable_mtu_probing_complete_support(&mut self) {
self.mtu_probing_complete_support = true;
if self.state.is_search_complete() {
self.needs_to_send_completion = true;
}
}
/// Starts probing. This only has an effect in the `Disabled` or `EarlySearchRequested` state.
#[inline]
pub fn enable(&mut self) {
ensure!(self.state.is_disabled() || self.state.is_early_search_requested());
self.request_new_search(None);
}
/// Starts a new search once the PMTU raise timer has expired at `now`.
#[inline]
pub fn on_timeout(&mut self, now: Timestamp) {
ensure!(self.pmtu_raise_timer.poll_expiration(now).is_ready());
self.request_new_search(None);
}
/// Processes the acknowledgement of `packet_number`, which carried `sent_bytes`.
///
/// Returns `MtuUpdated` with the new PLPMTU when the acknowledged packet is the outstanding
/// probe. Returns `NoChange` otherwise, even when an early search is resolved here; in that case
/// an `MtuUpdated` event is still published.
#[inline]
pub fn on_packet_ack<CC: CongestionController, Pub: event::ConnectionPublisher>(
&mut self,
packet_number: PacketNumber,
sent_bytes: u16,
congestion_controller: &mut CC,
path_id: path::Id,
publisher: &mut Pub,
) -> MtuResult {
// An acknowledged packet larger than the base confirms the configured initial MTU. This check
// runs before the packet-space filter below.
if self.state.is_early_search_requested() && sent_bytes > self.base_plpmtu {
if self.is_next_probe_size_above_threshold() {
// Further probing is worthwhile; wait for `enable` to start it.
self.state = State::Disabled;
} else {
self.set_search_complete();
}
publisher.on_mtu_updated(event::builder::MtuUpdated {
path_id: path_id.into_event(),
mtu: self.plpmtu,
cause: MtuUpdatedCause::InitialMtuPacketAcknowledged,
search_complete: self.state.is_search_complete(),
});
}
ensure!(self.state != State::Disabled, MtuResult::NoChange);
ensure!(
packet_number.space().is_application_data(),
MtuResult::NoChange
);
// A newer acknowledged packet of at least the current PLPMTU shows the path still carries
// that size, so the black-hole evidence is reset.
if sent_bytes >= self.plpmtu
&& self
.largest_acked_mtu_sized_packet
.is_none_or(|pn| packet_number > pn)
{
self.black_hole_counter = Default::default();
self.largest_acked_mtu_sized_packet = Some(packet_number);
}
if let State::Searching(probe_packet_number, transmit_time) = self.state {
if packet_number == probe_packet_number {
// The probe succeeded: adopt its size and continue the search upward.
self.plpmtu = self.probed_size;
congestion_controller.on_mtu_update(
self.plpmtu,
&mut congestion_controller::PathPublisher::new(publisher, path_id),
);
self.update_probed_size();
self.request_new_search(Some(transmit_time));
publisher.on_mtu_updated(event::builder::MtuUpdated {
path_id: path_id.into_event(),
mtu: self.plpmtu,
cause: MtuUpdatedCause::ProbeAcknowledged,
search_complete: self.state.is_search_complete(),
});
return MtuResult::MtuUpdated(self.plpmtu);
}
}
MtuResult::NoChange
}
/// Processes the loss of `packet_number`, which carried `lost_bytes`.
///
/// - In `EarlySearchRequested`, any reported loss drops the PLPMTU to the base.
/// - A lost probe is retried until `MAX_PROBES` attempts; after that the search range is narrowed.
/// - Otherwise, a lost packet counts toward black-hole detection when all of these hold:
///   - its size is in `(base_plpmtu, plpmtu]`;
///   - its packet number is above the largest acknowledged MTU-sized packet;
///   - `new_loss_burst` is set.
///
///   More than `BLACK_HOLE_THRESHOLD` such losses reset the PLPMTU to the base.
#[inline]
pub fn on_packet_loss<CC: CongestionController, Pub: event::ConnectionPublisher>(
&mut self,
packet_number: PacketNumber,
lost_bytes: u16,
new_loss_burst: bool,
now: Timestamp,
congestion_controller: &mut CC,
path_id: path::Id,
publisher: &mut Pub,
) -> MtuResult {
// Losses in other packet spaces are only relevant while the early search is pending.
ensure!(
self.state.is_early_search_requested() || packet_number.space().is_application_data(),
MtuResult::NoChange
);
match &self.state {
State::Disabled => {}
State::EarlySearchRequested => {
// The configured initial MTU is not trusted after a loss; fall back to the base.
self.plpmtu = self.base_plpmtu;
congestion_controller.on_mtu_update(
self.plpmtu,
&mut congestion_controller::PathPublisher::new(publisher, path_id),
);
if self.is_next_probe_size_above_threshold() {
self.state = State::Disabled;
} else {
self.set_search_complete();
}
publisher.on_mtu_updated(event::builder::MtuUpdated {
path_id: path_id.into_event(),
mtu: self.plpmtu,
cause: MtuUpdatedCause::InitialMtuPacketLost,
search_complete: self.state.is_search_complete(),
});
return MtuResult::MtuUpdated(self.plpmtu);
}
State::Searching(probe_pn, _) if *probe_pn == packet_number => {
if self.probe_count == MAX_PROBES {
// This size has failed MAX_PROBES times. It becomes the new upper bound, and the
// search continues below it.
self.max_probe_size = self.probed_size;
self.update_probed_size();
self.request_new_search(None);
if self.is_search_completed() {
publisher.on_mtu_updated(event::builder::MtuUpdated {
path_id: path_id.into_event(),
mtu: self.plpmtu,
cause: MtuUpdatedCause::LargerProbesLost,
search_complete: true,
})
}
} else {
// Retry the same probe size.
self.state = State::SearchRequested
}
}
State::Searching(_, _) | State::SearchComplete | State::SearchRequested => {
if (self.base_plpmtu + 1..=self.plpmtu).contains(&lost_bytes)
&& self
.largest_acked_mtu_sized_packet
.is_none_or(|pn| packet_number > pn)
&& new_loss_burst
{
self.black_hole_counter += 1;
}
if self.black_hole_counter > BLACK_HOLE_THRESHOLD {
return self.on_black_hole_detected(
now,
congestion_controller,
path_id,
publisher,
);
}
}
}
MtuResult::NoChange
}
/// Returns the current PLPMTU (the maximum datagram size in use).
#[inline]
pub fn max_datagram_size(&self) -> usize {
self.plpmtu as usize
}
/// Returns the size of the next, or in-flight, probe.
#[inline]
pub fn probed_sized(&self) -> usize {
self.probed_size as usize
}
/// Returns `true` when the state machine is in `SearchComplete`.
pub fn is_search_completed(&self) -> bool {
self.state.is_search_complete()
}
/// Recomputes `probed_size` as the midpoint between `plpmtu` and `max_probe_size`.
#[inline]
fn update_probed_size(&mut self) {
self.probed_size = Self::next_probe_size(self.plpmtu, self.max_probe_size);
}
/// Returns the midpoint between `current` and `max`.
///
/// Assumes `max >= current`; otherwise the subtraction underflows.
#[inline]
fn next_probe_size(current: u16, max: u16) -> u16 {
current + ((max - current) / 2)
}
/// Returns `true` when `probed_size` is at least `PROBE_THRESHOLD` above `plpmtu`.
///
/// Assumes `probed_size >= plpmtu`.
#[inline]
fn is_next_probe_size_above_threshold(&self) -> bool {
self.probed_size - self.plpmtu >= PROBE_THRESHOLD
}
/// Enters `SearchComplete` and, if supported, queues an `MtuProbingComplete` frame.
#[inline]
fn set_search_complete(&mut self) {
self.state = State::SearchComplete;
if self.mtu_probing_complete_support {
self.needs_to_send_completion = true;
}
}
/// Requests another probe if one is worthwhile; otherwise completes the search.
///
/// When the search completes and `last_probe_time` is given, the raise timer is armed
/// `PMTU_RAISE_TIMER_DURATION` after that time.
#[inline]
fn request_new_search(&mut self, last_probe_time: Option<Timestamp>) {
if self.is_next_probe_size_above_threshold() {
self.probe_count = 0;
self.state = State::SearchRequested;
} else {
self.set_search_complete();
if let Some(last_probe_time) = last_probe_time {
self.arm_pmtu_raise_timer(last_probe_time + PMTU_RAISE_TIMER_DURATION);
}
}
}
/// Handles a detected black hole.
///
/// Resets the evidence, drops the PLPMTU to the base, completes the search, and arms the raise
/// timer `BLACK_HOLE_COOL_OFF_DURATION` after `now`.
#[inline]
fn on_black_hole_detected<CC: CongestionController, Pub: event::ConnectionPublisher>(
&mut self,
now: Timestamp,
congestion_controller: &mut CC,
path_id: path::Id,
publisher: &mut Pub,
) -> MtuResult {
self.black_hole_counter = Default::default();
self.largest_acked_mtu_sized_packet = None;
self.plpmtu = self.base_plpmtu;
congestion_controller.on_mtu_update(
self.plpmtu,
&mut congestion_controller::PathPublisher::new(publisher, path_id),
);
self.set_search_complete();
self.arm_pmtu_raise_timer(now + BLACK_HOLE_COOL_OFF_DURATION);
publisher.on_mtu_updated(event::builder::MtuUpdated {
path_id: path_id.into_event(),
mtu: self.plpmtu,
cause: MtuUpdatedCause::Blackhole,
search_complete: self.state.is_search_complete(),
});
MtuResult::MtuUpdated(self.plpmtu)
}
/// Prepares for a later search for a larger MTU.
///
/// Widens the search range back to `max_udp_payload` and recomputes `probed_size`. The timer is
/// armed at `timestamp` only if that probe would be worthwhile.
#[inline]
fn arm_pmtu_raise_timer(&mut self, timestamp: Timestamp) {
self.max_probe_size = self.max_udp_payload;
self.update_probed_size();
if self.is_next_probe_size_above_threshold() {
self.pmtu_raise_timer.set(timestamp);
}
}
}
/// Exposes the controller's only timer, the PMTU raise timer.
impl timer::Provider for Controller {
    #[inline]
    fn timers<Q: timer::Query>(&self, query: &mut Q) -> timer::Result {
        let raise_timer = &self.pmtu_raise_timer;
        raise_timer.timers(query)?;
        Ok(())
    }
}
impl Controller {
    /// Returns `true` while a probe is waiting to be transmitted.
    #[inline]
    pub fn probe_needed(&self) -> bool {
        matches!(self.state, State::SearchRequested)
    }

    /// Writes an MTU probe into `context`: a PING padded so the packet reaches `probed_size`.
    ///
    /// This does nothing unless `context` is in MTU-probing mode and a probe is requested. If the
    /// buffer cannot hold the probe, the search is marked complete instead. Once the padding frame
    /// is written, the probe's packet number and transmit time are recorded.
    #[inline]
    pub fn on_transmit_probe<W: transmission::Writer>(&mut self, context: &mut W) {
        if !context.transmission_mode().is_mtu_probing() {
            return;
        }
        if !self.probe_needed() {
            return;
        }

        let header_len = context.header_len();
        let tag_len = context.tag_len();
        let probe_payload_size = self.probed_size as usize - header_len - tag_len;

        if probe_payload_size > context.remaining_capacity() {
            // The buffer cannot fit the probe, so the current PLPMTU is kept as final.
            self.set_search_complete();
            return;
        }

        context.write_frame(&frame::Ping);
        let padding = frame::Padding {
            length: probe_payload_size - frame::Ping.encoding_size(),
        };
        let Some(probe_packet_number) = context.write_frame(&padding) else {
            return;
        };
        self.probe_count += 1;
        self.state = State::Searching(probe_packet_number, context.current_time());
    }
}
impl transmission::Provider for Controller {
    /// Writes a pending `MtuProbingComplete` frame carrying the current PLPMTU.
    ///
    /// This never runs during MTU-probing transmissions. The pending flag stays set until the
    /// frame has actually been written.
    #[inline]
    fn on_transmit<W: transmission::Writer>(&mut self, context: &mut W) {
        if context.transmission_mode().is_mtu_probing() || !self.needs_to_send_completion {
            return;
        }
        let completion = frame::MtuProbingComplete::new(self.plpmtu);
        let written = context.write_frame(&completion).is_some();
        self.needs_to_send_completion = !written;
    }
}
impl transmission::interest::Provider for Controller {
    /// Reports new-data interest while an `MtuProbingComplete` frame is still pending.
    #[inline]
    fn transmission_interest<Q: transmission::interest::Query>(
        &self,
        query: &mut Q,
    ) -> transmission::interest::Result {
        if !self.needs_to_send_completion {
            return Ok(());
        }
        query.on_new_data()?;
        Ok(())
    }
}