use core::time::Duration;
use crate::{CrafterError, Result};
/// Default value of `max_datagrams` installed by [`IpDefragConfig::new`].
pub const IP_DEFRAG_DEFAULT_MAX_DATAGRAMS: usize = 1024;
/// Default value of `max_bytes_per_datagram` installed by [`IpDefragConfig::new`].
///
/// 65_535 is `u16::MAX`; presumably chosen to match the largest size the IPv4 Total
/// Length field can express — confirm.
pub const IP_DEFRAG_DEFAULT_MAX_BYTES_PER_DATAGRAM: usize = 65_535;
/// Default value of `max_age` installed by [`IpDefragConfig::new`] (60 seconds).
pub const IP_DEFRAG_DEFAULT_MAX_AGE: Duration = Duration::from_secs(60);
/// Smallest MTU accepted by the validating `IpFragmentConfig` constructors/builders.
///
/// 28 = 20-byte minimum IPv4 header + one 8-byte fragment unit, as stated by the
/// validation error message for `ip.fragment.mtu`.
pub const IP_FRAGMENT_MIN_MTU: usize = 28;
/// How the IP defragmenter treats overlapping fragments that conflict.
///
/// Selected through `IpDefragConfig::overlap_policy`; the default is `RejectConflicting`.
/// NOTE(review): the per-variant handling lives in the defragmenter, which is not visible
/// here — the variant notes below are inferred from the names; confirm.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub enum IpDefragOverlapPolicy {
    /// Default. Presumably reports a conflicting overlap as an error.
    #[default]
    RejectConflicting,
    /// Presumably discards the conflicting data/datagram without an error.
    DropConflicting,
    /// Presumably forwards the conflicting fragments unchanged.
    PassThroughConflicting,
}
/// How the IP defragmenter treats IPv6 "atomic fragments".
///
/// Presumably in the RFC 6946 sense (a Fragment header with offset 0 and the M flag
/// clear) — confirm. Selected through `IpDefragConfig::ipv6_atomic_fragments`; the
/// default is `Normalize`. Variant notes are inferred from names; confirm.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub enum Ipv6AtomicFragmentPolicy {
    /// Presumably forwards atomic fragments unchanged.
    PassThrough,
    /// Default. Presumably rewrites them into an unfragmented packet.
    #[default]
    Normalize,
    /// Presumably discards atomic fragments.
    Drop,
}
/// How the IP fragmenter treats IPv4 packets that carry the Don't Fragment flag.
///
/// `IpFragmentConfig::honor_dont_fragment(true)` selects `Error` and `(false)` selects
/// `FragmentAnyway`; `IpFragmentConfig::honors_dont_fragment` reports `false` only for
/// `FragmentAnyway`. The default is `Error`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub enum Ipv4DontFragmentPolicy {
    /// Default. Presumably returns an error when a DF packet would need fragmenting.
    #[default]
    Error,
    /// Honors DF; presumably emits the packet unfragmented instead of erroring.
    PassThrough,
    /// Ignores DF and fragments the packet anyway.
    FragmentAnyway,
}
/// How the 16-bit IPv4 Identification value of emitted fragments is chosen.
///
/// Selected through `IpFragmentConfig::ipv4_identification_policy`; the default is
/// `PreserveOrGenerate`. Variant notes are inferred from names; confirm.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub enum Ipv4FragmentIdentificationPolicy {
    /// Default. Presumably keeps the packet's Identification, generating one if absent/zero.
    #[default]
    PreserveOrGenerate,
    /// Presumably always keeps the packet's existing Identification.
    PreserveOnly,
    /// Uses the given Identification value.
    Fixed(u16),
}
/// How the 32-bit Identification in emitted IPv6 Fragment headers is chosen.
///
/// Selected through `IpFragmentConfig::ipv6_identification_policy` (or the
/// `ipv6_identification` shorthand, which selects `Fixed`); the default is `Generate`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub enum Ipv6FragmentIdentificationPolicy {
    /// Default. Presumably derives a value, optionally from
    /// `IpFragmentConfig::ipv6_identification_seed` — confirm.
    #[default]
    Generate,
    /// Uses the given Identification value.
    Fixed(u32),
}
/// Tuning knobs for IP reassembly (defragmentation).
///
/// Created with `IpDefragConfig::new` (or `Default`) and adjusted through by-value
/// builder methods. Bounds set through the unchecked builders can be checked afterwards
/// with `IpDefragConfig::validate`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct IpDefragConfig {
    /// Reported by `emits_non_fragments`; presumably whether packets that are not
    /// fragments are forwarded — confirm in the defragmenter.
    pass_non_fragments: bool,
    /// Must be non-zero (checked by `validate`); presumably caps how many datagrams are
    /// under reassembly at once — confirm.
    max_datagrams: usize,
    /// Must be non-zero (checked by `validate`); presumably caps the reassembled size of
    /// one datagram in bytes — confirm.
    max_bytes_per_datagram: usize,
    /// Must be non-zero (checked by `validate`); presumably how long an incomplete
    /// datagram is kept before eviction — confirm.
    max_age: Duration,
    /// Policy for conflicting overlapping fragments.
    overlap_policy: IpDefragOverlapPolicy,
    /// Policy for IPv6 atomic fragments.
    ipv6_atomic_fragment_policy: Ipv6AtomicFragmentPolicy,
    /// Reported by `traces_passthrough`; the tests in this file show it adds an
    /// `ip-defrag` transform trace noted `passthrough` to unchanged records.
    trace_passthrough: bool,
    /// Reported by `traces_evictions`; presumably traces evicted datagrams — confirm.
    trace_evictions: bool,
}
impl IpDefragConfig {
    /// Creates a configuration populated with the crate defaults.
    ///
    /// - Non-fragments are emitted.
    /// - The bounds are [`IP_DEFRAG_DEFAULT_MAX_DATAGRAMS`],
    ///   [`IP_DEFRAG_DEFAULT_MAX_BYTES_PER_DATAGRAM`] and [`IP_DEFRAG_DEFAULT_MAX_AGE`].
    /// - Overlaps use [`IpDefragOverlapPolicy::RejectConflicting`].
    /// - IPv6 atomic fragments use [`Ipv6AtomicFragmentPolicy::Normalize`].
    /// - Both trace flags are off.
    #[must_use]
    pub const fn new() -> Self {
        Self {
            pass_non_fragments: true,
            max_datagrams: IP_DEFRAG_DEFAULT_MAX_DATAGRAMS,
            max_bytes_per_datagram: IP_DEFRAG_DEFAULT_MAX_BYTES_PER_DATAGRAM,
            max_age: IP_DEFRAG_DEFAULT_MAX_AGE,
            overlap_policy: IpDefragOverlapPolicy::RejectConflicting,
            ipv6_atomic_fragment_policy: Ipv6AtomicFragmentPolicy::Normalize,
            trace_passthrough: false,
            trace_evictions: false,
        }
    }

    /// Returns a copy with the non-fragment flag replaced.
    ///
    /// The flag is read back by [`Self::emits_non_fragments`].
    #[must_use]
    pub const fn pass_non_fragments(mut self, pass_non_fragments: bool) -> Self {
        self.pass_non_fragments = pass_non_fragments;
        self
    }

    /// Returns the flag set by [`Self::pass_non_fragments`] (default `true`).
    pub const fn emits_non_fragments(&self) -> bool {
        self.pass_non_fragments
    }

    /// Returns a copy with `max_datagrams` replaced.
    ///
    /// Not validated here. A zero value is only rejected by
    /// [`Self::try_max_datagrams`] or [`Self::validate`].
    #[must_use]
    pub const fn max_datagrams(mut self, max_datagrams: usize) -> Self {
        self.max_datagrams = max_datagrams;
        self
    }

    /// Like [`Self::max_datagrams`], but rejects zero.
    ///
    /// # Errors
    ///
    /// Returns an invalid-field-value error for `ip.defrag.max_datagrams` when
    /// `max_datagrams` is zero.
    pub fn try_max_datagrams(self, max_datagrams: usize) -> Result<Self> {
        validate_nonzero("ip.defrag.max_datagrams", max_datagrams)?;
        Ok(self.max_datagrams(max_datagrams))
    }

    /// Returns the configured `max_datagrams` value.
    pub const fn max_datagrams_limit(&self) -> usize {
        self.max_datagrams
    }

    /// Returns a copy with `max_bytes_per_datagram` replaced.
    ///
    /// Not validated here. A zero value is only rejected by
    /// [`Self::try_max_bytes_per_datagram`] or [`Self::validate`].
    #[must_use]
    pub const fn max_bytes_per_datagram(mut self, max_bytes_per_datagram: usize) -> Self {
        self.max_bytes_per_datagram = max_bytes_per_datagram;
        self
    }

    /// Like [`Self::max_bytes_per_datagram`], but rejects zero.
    ///
    /// # Errors
    ///
    /// Returns an invalid-field-value error for `ip.defrag.max_bytes_per_datagram` when
    /// `max_bytes_per_datagram` is zero.
    pub fn try_max_bytes_per_datagram(self, max_bytes_per_datagram: usize) -> Result<Self> {
        validate_nonzero("ip.defrag.max_bytes_per_datagram", max_bytes_per_datagram)?;
        Ok(self.max_bytes_per_datagram(max_bytes_per_datagram))
    }

    /// Returns the configured `max_bytes_per_datagram` value.
    pub const fn max_bytes_per_datagram_limit(&self) -> usize {
        self.max_bytes_per_datagram
    }

    /// Returns a copy with `max_age` replaced.
    ///
    /// Not validated here. A zero duration is only rejected by
    /// [`Self::try_max_age`] or [`Self::validate`].
    #[must_use]
    pub const fn max_age(mut self, max_age: Duration) -> Self {
        self.max_age = max_age;
        self
    }

    /// Like [`Self::max_age`], but rejects a zero duration.
    ///
    /// # Errors
    ///
    /// Returns an invalid-field-value error for `ip.defrag.max_age` when `max_age` is
    /// zero.
    pub fn try_max_age(self, max_age: Duration) -> Result<Self> {
        validate_nonzero_duration("ip.defrag.max_age", max_age)?;
        Ok(self.max_age(max_age))
    }

    /// Returns the configured `max_age` value.
    pub const fn max_age_limit(&self) -> Duration {
        self.max_age
    }

    /// Returns a copy with the overlap policy replaced.
    #[must_use]
    pub const fn overlap_policy(mut self, overlap_policy: IpDefragOverlapPolicy) -> Self {
        self.overlap_policy = overlap_policy;
        self
    }

    /// Returns the configured overlap policy.
    pub const fn configured_overlap_policy(&self) -> IpDefragOverlapPolicy {
        self.overlap_policy
    }

    /// Returns a copy with the IPv6 atomic-fragment policy replaced.
    #[must_use]
    pub const fn ipv6_atomic_fragments(
        mut self,
        ipv6_atomic_fragment_policy: Ipv6AtomicFragmentPolicy,
    ) -> Self {
        self.ipv6_atomic_fragment_policy = ipv6_atomic_fragment_policy;
        self
    }

    /// Returns the configured IPv6 atomic-fragment policy.
    pub const fn ipv6_atomic_fragment_policy(&self) -> Ipv6AtomicFragmentPolicy {
        self.ipv6_atomic_fragment_policy
    }

    /// Returns a copy with the pass-through tracing flag replaced.
    #[must_use]
    pub const fn trace_passthrough(mut self, trace_passthrough: bool) -> Self {
        self.trace_passthrough = trace_passthrough;
        self
    }

    /// Returns whether pass-through tracing is enabled (default `false`).
    pub const fn traces_passthrough(&self) -> bool {
        self.trace_passthrough
    }

    /// Returns a copy with the eviction tracing flag replaced.
    #[must_use]
    pub const fn trace_evictions(mut self, trace_evictions: bool) -> Self {
        self.trace_evictions = trace_evictions;
        self
    }

    /// Returns whether eviction tracing is enabled (default `false`).
    pub const fn traces_evictions(&self) -> bool {
        self.trace_evictions
    }

    /// Checks every bound that the unchecked builders may have left at zero.
    ///
    /// # Errors
    ///
    /// Returns the first invalid-field-value error, checking in this order:
    /// `ip.defrag.max_datagrams`, `ip.defrag.max_bytes_per_datagram`,
    /// `ip.defrag.max_age`.
    pub fn validate(&self) -> Result<()> {
        validate_nonzero("ip.defrag.max_datagrams", self.max_datagrams)?;
        validate_nonzero(
            "ip.defrag.max_bytes_per_datagram",
            self.max_bytes_per_datagram,
        )?;
        validate_nonzero_duration("ip.defrag.max_age", self.max_age)?;
        Ok(())
    }
}
impl Default for IpDefragConfig {
fn default() -> Self {
Self::new()
}
}
/// Settings for splitting IP packets into fragments.
///
/// There is no default MTU. Construct with `IpFragmentConfig::new` (unchecked) or
/// `IpFragmentConfig::try_new`, which rejects values below `IP_FRAGMENT_MIN_MTU`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct IpFragmentConfig {
    /// Target MTU; `validate` requires at least `IP_FRAGMENT_MIN_MTU`.
    /// NOTE(review): presumably bytes including the IP header — confirm.
    mtu: usize,
    /// Handling of IPv4 packets with the Don't Fragment flag.
    dont_fragment_policy: Ipv4DontFragmentPolicy,
    /// How IPv4 fragment Identification values are chosen.
    ipv4_identification_policy: Ipv4FragmentIdentificationPolicy,
    /// How IPv6 Fragment-header Identification values are chosen.
    ipv6_identification_policy: Ipv6FragmentIdentificationPolicy,
    /// Optional seed; presumably feeds generated IPv6 identifications — confirm.
    ipv6_identification_seed: Option<u64>,
    /// Reported by `traces_passthrough`; the tests in this file show it adds an
    /// `ip-fragment` transform trace noted `passthrough` to unchanged records.
    trace_passthrough: bool,
}
impl IpFragmentConfig {
    /// Creates a configuration for `mtu` without validating it.
    ///
    /// Use [`Self::try_new`] to reject values below [`IP_FRAGMENT_MIN_MTU`]. Defaults:
    /// - DF policy [`Ipv4DontFragmentPolicy::Error`].
    /// - IPv4 identification [`Ipv4FragmentIdentificationPolicy::PreserveOrGenerate`].
    /// - IPv6 identification [`Ipv6FragmentIdentificationPolicy::Generate`].
    /// - No IPv6 identification seed.
    /// - Pass-through tracing off.
    #[must_use]
    pub const fn new(mtu: usize) -> Self {
        Self {
            mtu,
            dont_fragment_policy: Ipv4DontFragmentPolicy::Error,
            ipv4_identification_policy: Ipv4FragmentIdentificationPolicy::PreserveOrGenerate,
            ipv6_identification_policy: Ipv6FragmentIdentificationPolicy::Generate,
            ipv6_identification_seed: None,
            trace_passthrough: false,
        }
    }

    /// Like [`Self::new`], but rejects an MTU below [`IP_FRAGMENT_MIN_MTU`].
    ///
    /// # Errors
    ///
    /// Returns an invalid-field-value error for `ip.fragment.mtu` when `mtu` is too small.
    pub fn try_new(mtu: usize) -> Result<Self> {
        validate_mtu(mtu)?;
        Ok(Self::new(mtu))
    }

    /// Returns the configured MTU.
    pub const fn mtu(&self) -> usize {
        self.mtu
    }

    /// Returns a copy with the MTU replaced, without validation.
    #[must_use]
    pub const fn with_mtu(mut self, mtu: usize) -> Self {
        self.mtu = mtu;
        self
    }

    /// Like [`Self::with_mtu`], but rejects an MTU below [`IP_FRAGMENT_MIN_MTU`].
    ///
    /// # Errors
    ///
    /// Returns an invalid-field-value error for `ip.fragment.mtu` when `mtu` is too small.
    pub fn try_mtu(self, mtu: usize) -> Result<Self> {
        validate_mtu(mtu)?;
        Ok(self.with_mtu(mtu))
    }

    /// Legacy boolean form of [`Self::dont_fragment_policy`].
    ///
    /// `true` selects [`Ipv4DontFragmentPolicy::Error`]; `false` selects
    /// [`Ipv4DontFragmentPolicy::FragmentAnyway`].
    #[must_use]
    pub const fn honor_dont_fragment(mut self, honor_dont_fragment: bool) -> Self {
        self.dont_fragment_policy = if honor_dont_fragment {
            Ipv4DontFragmentPolicy::Error
        } else {
            Ipv4DontFragmentPolicy::FragmentAnyway
        };
        self
    }

    /// Returns `false` only for [`Ipv4DontFragmentPolicy::FragmentAnyway`].
    ///
    /// Both `Error` and `PassThrough` count as honoring the DF flag.
    pub const fn honors_dont_fragment(&self) -> bool {
        !matches!(
            self.dont_fragment_policy,
            Ipv4DontFragmentPolicy::FragmentAnyway
        )
    }

    /// Returns a copy with the IPv4 Don't Fragment policy replaced.
    #[must_use]
    pub const fn dont_fragment_policy(
        mut self,
        dont_fragment_policy: Ipv4DontFragmentPolicy,
    ) -> Self {
        self.dont_fragment_policy = dont_fragment_policy;
        self
    }

    /// Returns the configured IPv4 Don't Fragment policy.
    pub const fn configured_dont_fragment_policy(&self) -> Ipv4DontFragmentPolicy {
        self.dont_fragment_policy
    }

    /// Returns a copy with the IPv4 identification policy replaced.
    #[must_use]
    pub const fn ipv4_identification_policy(
        mut self,
        ipv4_identification_policy: Ipv4FragmentIdentificationPolicy,
    ) -> Self {
        self.ipv4_identification_policy = ipv4_identification_policy;
        self
    }

    /// Returns the configured IPv4 identification policy.
    pub const fn configured_ipv4_identification_policy(&self) -> Ipv4FragmentIdentificationPolicy {
        self.ipv4_identification_policy
    }

    /// Returns a copy with the IPv6 identification policy replaced.
    #[must_use]
    pub const fn ipv6_identification_policy(
        mut self,
        ipv6_identification_policy: Ipv6FragmentIdentificationPolicy,
    ) -> Self {
        self.ipv6_identification_policy = ipv6_identification_policy;
        self
    }

    /// Shorthand for `ipv6_identification_policy(Ipv6FragmentIdentificationPolicy::Fixed(..))`.
    #[must_use]
    pub const fn ipv6_identification(mut self, identification: u32) -> Self {
        self.ipv6_identification_policy = Ipv6FragmentIdentificationPolicy::Fixed(identification);
        self
    }

    /// Returns the configured IPv6 identification policy.
    pub const fn configured_ipv6_identification_policy(&self) -> Ipv6FragmentIdentificationPolicy {
        self.ipv6_identification_policy
    }

    /// Returns a copy with the IPv6 identification seed set to `Some(seed)`.
    #[must_use]
    pub const fn ipv6_identification_seed(mut self, seed: u64) -> Self {
        self.ipv6_identification_seed = Some(seed);
        self
    }

    /// Returns a copy with the IPv6 identification seed cleared (`None`).
    #[must_use]
    pub const fn clear_ipv6_identification_seed(mut self) -> Self {
        self.ipv6_identification_seed = None;
        self
    }

    /// Returns the configured IPv6 identification seed, if any.
    pub const fn configured_ipv6_identification_seed(&self) -> Option<u64> {
        self.ipv6_identification_seed
    }

    /// Returns a copy with the pass-through tracing flag replaced.
    #[must_use]
    pub const fn trace_passthrough(mut self, trace_passthrough: bool) -> Self {
        self.trace_passthrough = trace_passthrough;
        self
    }

    /// Returns whether pass-through tracing is enabled (default `false`).
    pub const fn traces_passthrough(&self) -> bool {
        self.trace_passthrough
    }

    /// Checks the MTU, which [`Self::new`] and [`Self::with_mtu`] accept unchecked.
    ///
    /// # Errors
    ///
    /// Returns an invalid-field-value error for `ip.fragment.mtu` when the MTU is below
    /// [`IP_FRAGMENT_MIN_MTU`].
    pub fn validate(&self) -> Result<()> {
        validate_mtu(self.mtu)
    }
}
/// Rejects a zero `value` for the configuration field named `field`.
///
/// # Errors
///
/// Returns `CrafterError::invalid_field_value(field, "must be greater than zero")` when
/// `value` is zero.
fn validate_nonzero(field: &'static str, value: usize) -> Result<()> {
    match value {
        0 => Err(CrafterError::invalid_field_value(
            field,
            "must be greater than zero",
        )),
        _ => Ok(()),
    }
}
/// Rejects a zero-length `value` for the configuration field named `field`.
///
/// # Errors
///
/// Returns `CrafterError::invalid_field_value(field, "must be greater than zero")` when
/// `value` is zero.
fn validate_nonzero_duration(field: &'static str, value: Duration) -> Result<()> {
    if !value.is_zero() {
        return Ok(());
    }
    Err(CrafterError::invalid_field_value(
        field,
        "must be greater than zero",
    ))
}
/// Accepts `mtu` only when it is at least [`IP_FRAGMENT_MIN_MTU`].
///
/// # Errors
///
/// Returns an invalid-field-value error for `ip.fragment.mtu` when `mtu` is smaller.
/// NOTE(review): the bound is derived from the IPv4 header size. Presumably any larger
/// per-packet requirement for IPv6 is enforced by the fragmenter itself — confirm.
fn validate_mtu(mtu: usize) -> Result<()> {
    if mtu >= IP_FRAGMENT_MIN_MTU {
        Ok(())
    } else {
        Err(CrafterError::invalid_field_value(
            "ip.fragment.mtu",
            "must fit the minimum IPv4 header and one 8-byte fragment unit",
        ))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::wire::ip::{IpDefrag, IpFragment};
    use crate::wire::record::{BackendKind, PacketOrigin, PacketRecord};
    use crate::Raw;

    /// Asserts that `error` is an `InvalidFieldValue` naming `expected_field`.
    fn assert_invalid_field(error: CrafterError, expected_field: &'static str) {
        match error {
            CrafterError::InvalidFieldValue { field, .. } => assert_eq!(field, expected_field),
            other => panic!("expected InvalidFieldValue, got {other:?}"),
        }
    }

    /// Builds a generated, in-memory record whose payload is `payload` as raw bytes.
    fn record(payload: &'static str) -> PacketRecord {
        PacketRecord::new(Raw::from(payload))
            .with_origin(PacketOrigin::Generated)
            .with_backend(BackendKind::Memory)
    }

    #[test]
    fn ip_defrag_config_new_uses_documented_defaults() {
        let config = IpDefragConfig::new();
        assert_eq!(config, IpDefragConfig::default());
        assert!(config.emits_non_fragments());
        assert_eq!(config.max_datagrams_limit(), IP_DEFRAG_DEFAULT_MAX_DATAGRAMS);
        assert_eq!(
            config.max_bytes_per_datagram_limit(),
            IP_DEFRAG_DEFAULT_MAX_BYTES_PER_DATAGRAM
        );
        assert_eq!(config.max_age_limit(), IP_DEFRAG_DEFAULT_MAX_AGE);
        assert_eq!(
            config.configured_overlap_policy(),
            IpDefragOverlapPolicy::RejectConflicting
        );
        assert_eq!(
            config.ipv6_atomic_fragment_policy(),
            Ipv6AtomicFragmentPolicy::Normalize
        );
        assert!(!config.traces_passthrough());
        assert!(!config.traces_evictions());
        config.validate().unwrap();
    }

    #[test]
    fn ip_defrag_config_builders_expose_bounds_and_policies() {
        // Every builder receives a non-default value so each getter assertion proves the
        // corresponding setter actually took effect.
        let config = IpDefragConfig::new()
            .pass_non_fragments(false)
            .max_datagrams(64)
            .max_bytes_per_datagram(8192)
            .max_age(Duration::from_secs(30))
            .overlap_policy(IpDefragOverlapPolicy::DropConflicting)
            .ipv6_atomic_fragments(Ipv6AtomicFragmentPolicy::Drop)
            .trace_passthrough(true)
            .trace_evictions(true);
        assert!(!config.emits_non_fragments());
        assert_eq!(config.max_datagrams_limit(), 64);
        assert_eq!(config.max_bytes_per_datagram_limit(), 8192);
        assert_eq!(config.max_age_limit(), Duration::from_secs(30));
        assert_eq!(
            config.configured_overlap_policy(),
            IpDefragOverlapPolicy::DropConflicting
        );
        assert_eq!(
            config.ipv6_atomic_fragment_policy(),
            Ipv6AtomicFragmentPolicy::Drop
        );
        assert!(config.traces_passthrough());
        assert!(config.traces_evictions());
        config.validate().unwrap();
    }

    #[test]
    fn ip_defrag_config_try_builders_reject_unbounded_settings() {
        let max_datagrams_error = IpDefragConfig::new().try_max_datagrams(0).unwrap_err();
        assert_invalid_field(max_datagrams_error, "ip.defrag.max_datagrams");
        let max_bytes_error = IpDefragConfig::new()
            .try_max_bytes_per_datagram(0)
            .unwrap_err();
        assert_invalid_field(max_bytes_error, "ip.defrag.max_bytes_per_datagram");
        let max_age_error = IpDefragConfig::new()
            .try_max_age(Duration::from_secs(0))
            .unwrap_err();
        assert_invalid_field(max_age_error, "ip.defrag.max_age");
    }

    #[test]
    fn ip_defrag_config_validate_catches_unchecked_builders() {
        let error = IpDefragConfig::new()
            .max_datagrams(0)
            .validate()
            .unwrap_err();
        assert_invalid_field(error, "ip.defrag.max_datagrams");
        let error = IpDefragConfig::new()
            .max_bytes_per_datagram(0)
            .validate()
            .unwrap_err();
        assert_invalid_field(error, "ip.defrag.max_bytes_per_datagram");
        let error = IpDefragConfig::new()
            .max_age(Duration::from_secs(0))
            .validate()
            .unwrap_err();
        assert_invalid_field(error, "ip.defrag.max_age");
    }

    #[test]
    fn ip_fragment_config_new_uses_documented_defaults() {
        let config = IpFragmentConfig::new(1500);
        assert_eq!(config.mtu(), 1500);
        assert!(config.honors_dont_fragment());
        assert_eq!(
            config.configured_dont_fragment_policy(),
            Ipv4DontFragmentPolicy::Error
        );
        assert_eq!(
            config.configured_ipv4_identification_policy(),
            Ipv4FragmentIdentificationPolicy::PreserveOrGenerate
        );
        assert_eq!(
            config.configured_ipv6_identification_policy(),
            Ipv6FragmentIdentificationPolicy::Generate
        );
        assert_eq!(config.configured_ipv6_identification_seed(), None);
        assert!(!config.traces_passthrough());
        config.validate().unwrap();
    }

    #[test]
    fn ip_fragment_config_builders_expose_mtu_df_ids_and_trace() {
        let config = IpFragmentConfig::new(1500)
            .with_mtu(1280)
            .dont_fragment_policy(Ipv4DontFragmentPolicy::PassThrough)
            .ipv4_identification_policy(Ipv4FragmentIdentificationPolicy::Fixed(0x1234))
            .ipv6_identification(0xfeed_beef)
            .ipv6_identification_seed(0x3100_0000_0000_0000)
            .trace_passthrough(true);
        assert_eq!(config.mtu(), 1280);
        assert!(config.honors_dont_fragment());
        assert_eq!(
            config.configured_dont_fragment_policy(),
            Ipv4DontFragmentPolicy::PassThrough
        );
        assert_eq!(
            config.configured_ipv4_identification_policy(),
            Ipv4FragmentIdentificationPolicy::Fixed(0x1234)
        );
        assert_eq!(
            config.configured_ipv6_identification_policy(),
            Ipv6FragmentIdentificationPolicy::Fixed(0xfeed_beef)
        );
        assert_eq!(
            config.configured_ipv6_identification_seed(),
            Some(0x3100_0000_0000_0000)
        );
        assert!(config.traces_passthrough());
        config.validate().unwrap();
    }

    #[test]
    fn ip_fragment_config_can_clear_ipv6_identification_seed() {
        let config = IpFragmentConfig::new(1280)
            .ipv6_identification_seed(31)
            .clear_ipv6_identification_seed();
        assert_eq!(config.configured_ipv6_identification_seed(), None);
    }

    #[test]
    fn ip_fragment_config_keeps_legacy_df_bool_builder() {
        let ignore_df = IpFragmentConfig::new(1500).honor_dont_fragment(false);
        assert!(!ignore_df.honors_dont_fragment());
        assert_eq!(
            ignore_df.configured_dont_fragment_policy(),
            Ipv4DontFragmentPolicy::FragmentAnyway
        );
        let honor_df = ignore_df.honor_dont_fragment(true);
        assert!(honor_df.honors_dont_fragment());
        assert_eq!(
            honor_df.configured_dont_fragment_policy(),
            Ipv4DontFragmentPolicy::Error
        );
    }

    #[test]
    fn ip_fragment_config_try_builders_reject_too_small_mtu() {
        let new_error = IpFragmentConfig::try_new(IP_FRAGMENT_MIN_MTU - 1).unwrap_err();
        assert_invalid_field(new_error, "ip.fragment.mtu");
        let mtu_error = IpFragmentConfig::new(1500)
            .try_mtu(IP_FRAGMENT_MIN_MTU - 1)
            .unwrap_err();
        assert_invalid_field(mtu_error, "ip.fragment.mtu");
    }

    #[test]
    fn ip_fragment_config_accepts_exact_minimum_mtu() {
        let config = IpFragmentConfig::try_new(IP_FRAGMENT_MIN_MTU).unwrap();
        assert_eq!(config.mtu(), IP_FRAGMENT_MIN_MTU);
        let config = IpFragmentConfig::new(1500)
            .try_mtu(IP_FRAGMENT_MIN_MTU)
            .unwrap();
        assert_eq!(config.mtu(), IP_FRAGMENT_MIN_MTU);
        config.validate().unwrap();
    }

    #[test]
    fn ip_fragment_config_validate_catches_unchecked_small_mtu() {
        let error = IpFragmentConfig::new(IP_FRAGMENT_MIN_MTU - 1)
            .validate()
            .unwrap_err();
        assert_invalid_field(error, "ip.fragment.mtu");
        let error = IpFragmentConfig::new(1500)
            .with_mtu(0)
            .validate()
            .unwrap_err();
        assert_invalid_field(error, "ip.fragment.mtu");
    }

    #[test]
    fn trace_passthrough_adds_transform_trace_to_unchanged_records() {
        let mut defrag = IpDefrag::new().with_config(IpDefragConfig::new().trace_passthrough(true));
        let defrag_output = defrag.defrag_record(record("defrag")).unwrap();
        let defrag_traces = defrag_output.records()[0].metadata().transforms();
        assert_eq!(defrag_traces.len(), 1);
        assert_eq!(defrag_traces[0].name(), "ip-defrag");
        assert_eq!(defrag_traces[0].note(), Some("passthrough"));
        let mut fragment =
            IpFragment::with_config(IpFragmentConfig::new(1280).trace_passthrough(true));
        let fragment_output = fragment.fragment_record(record("fragment")).unwrap();
        let fragment_traces = fragment_output.records()[0].metadata().transforms();
        assert_eq!(fragment_traces.len(), 1);
        assert_eq!(fragment_traces[0].name(), "ip-fragment");
        assert_eq!(fragment_traces[0].note(), Some("passthrough"));
    }
}