#![allow(missing_docs)]
use crate::{MAX_UDP_PAYLOAD, MtuDiscoveryConfig, frame::Crypto, packet::SpaceId};
use std::cmp;
use tracing::{debug, trace};
/// Handshake bytes attributed to ML-KEM-768: 1184 + 1088.
///
/// These match the ML-KEM-768 encapsulation-key and ciphertext sizes.
pub const ML_KEM_768_HANDSHAKE_OVERHEAD: u16 = 1184 + 1088;

/// Handshake bytes attributed to ML-DSA-65: 1952 + 3309.
///
/// These match the ML-DSA-65 public-key and signature sizes.
pub const ML_DSA_65_HANDSHAKE_OVERHEAD: u16 = 1952 + 3309;

/// Combined ML-KEM-768 and ML-DSA-65 overhead plus a fixed 256-byte allowance.
///
/// NOTE(review): the 256-byte term is presumably headroom for the classical
/// half of the hybrid exchange and/or framing - confirm.
pub const HYBRID_HANDSHAKE_OVERHEAD: u16 =
    ML_KEM_768_HANDSHAKE_OVERHEAD + ML_DSA_65_HANDSHAKE_OVERHEAD + 256;

/// Minimum packet size reported for the Initial space once PQC is detected.
pub const PQC_MIN_MTU: u16 = 2048;

/// MTU discovery upper bound requested for PQC handshakes (before being
/// capped by the crate-wide maximum UDP payload).
pub const PQC_RECOMMENDED_MTU: u16 = 4096;

/// Largest CRYPTO frame payload produced when a PQC handshake is detected.
pub const MAX_CRYPTO_FRAME_SIZE: u16 = 1200;
/// Per-connection bookkeeping for handshakes that carry post-quantum
/// cryptography (PQC) material.
///
/// Tracks whether a PQC handshake has been seen, how large the handshake is
/// expected to be, and whether MTU discovery has already been requested for it.
#[derive(Debug, Clone)]
pub struct PqcPacketHandler {
    /// Set once handshake data has been classified as PQC.
    pqc_detected: bool,
    /// Expected total handshake size in bytes; zero until PQC is detected.
    estimated_handshake_size: u32,
    /// Ensures the MTU discovery trigger is reported only once.
    mtu_discovery_triggered: bool,
}
impl PqcPacketHandler {
    /// Creates a handler with no PQC handshake detected, a zero handshake-size
    /// estimate, and MTU discovery not yet triggered.
    pub fn new() -> Self {
        Self {
            pqc_detected: false,
            estimated_handshake_size: 0,
            mtu_discovery_triggered: false,
        }
    }

    /// Inspects CRYPTO data and records whether it looks like a PQC handshake.
    ///
    /// Only data from the Initial or Handshake space that is at least 4 bytes
    /// long is examined; anything else leaves the state untouched. When the
    /// first byte is 1 or 2 and the size heuristic matches, the handler is
    /// marked as PQC and the handshake-size estimate is set.
    ///
    /// Returns the detection state after processing. Once set, it stays `true`
    /// until [`reset`](Self::reset) is called.
    pub fn detect_pqc_handshake(&mut self, crypto_data: &[u8], space: SpaceId) -> bool {
        if !matches!(space, SpaceId::Initial | SpaceId::Handshake) {
            return self.pqc_detected;
        }
        // A single length check covers both the empty and the too-short case.
        if crypto_data.len() < 4 {
            return self.pqc_detected;
        }

        // NOTE(review): 1 and 2 are presumably the TLS ClientHello / ServerHello
        // handshake message types - confirm against the TLS layer in use.
        let is_hello = matches!(crypto_data[0], 1 | 2);
        if is_hello && self.detect_pqc_in_extensions(crypto_data) {
            debug!("Detected PQC handshake");
            self.pqc_detected = true;
            self.estimated_handshake_size = Self::pqc_handshake_size();
        }
        self.pqc_detected
    }

    /// Size-based stand-in for extension inspection: any message longer than
    /// 100 bytes is classified as PQC.
    ///
    /// NOTE(review): this does not look at extension contents, so a large
    /// classical hello is classified as PQC as well.
    fn detect_pqc_in_extensions(&self, data: &[u8]) -> bool {
        data.len() > 100
    }

    /// Fixed estimate, in bytes, of the total size of a PQC handshake.
    fn pqc_handshake_size() -> u32 {
        16384
    }

    /// Reports whether MTU discovery should be started for a PQC handshake.
    ///
    /// Returns `true` exactly once: on the first call made after PQC has been
    /// detected. Every other call returns `false`.
    pub fn should_trigger_mtu_discovery(&mut self) -> bool {
        let fire = self.pqc_detected && !self.mtu_discovery_triggered;
        if fire {
            self.mtu_discovery_triggered = true;
        }
        fire
    }

    /// Builds the MTU discovery configuration to use for this connection.
    ///
    /// Without PQC this is `MtuDiscoveryConfig::default()`. With PQC the upper
    /// bound is set to `PQC_RECOMMENDED_MTU` capped by `MAX_UDP_PAYLOAD`, the
    /// minimum change to 128, and the interval to 100 ms.
    pub fn get_pqc_mtu_config(&self) -> MtuDiscoveryConfig {
        let mut config = MtuDiscoveryConfig::default();
        if self.pqc_detected {
            config.upper_bound(PQC_RECOMMENDED_MTU.min(MAX_UDP_PAYLOAD));
            config.minimum_change = 128;
            config.interval = std::time::Duration::from_millis(100);
        }
        config
    }

    /// Returns the payload size for the next CRYPTO frame.
    ///
    /// The result is the smallest of `available_space`, `remaining_data`, and
    /// a per-frame cap: `MAX_CRYPTO_FRAME_SIZE` when PQC is detected, 600
    /// bytes otherwise.
    pub fn calculate_crypto_frame_size(
        &self,
        available_space: usize,
        remaining_data: usize,
    ) -> usize {
        let frame_cap = if self.pqc_detected {
            usize::from(MAX_CRYPTO_FRAME_SIZE)
        } else {
            600
        };
        cmp::min(available_space.min(frame_cap), remaining_data)
    }

    /// Returns `true` when PQC is detected, `space` is Initial, and
    /// `current_size` exceeds 600 bytes; `false` in every other case.
    ///
    /// NOTE(review): how the caller changes coalescing on `true` is not
    /// visible in this file.
    pub fn adjust_coalescing_for_pqc(&self, current_size: usize, space: SpaceId) -> bool {
        self.pqc_detected && matches!(space, SpaceId::Initial) && current_size > 600
    }

    /// Returns the minimum packet size for `space`.
    ///
    /// Without PQC this is always 1200. With PQC it is `PQC_MIN_MTU` for
    /// Initial, 1500 for Handshake, and 1200 for any other space.
    pub fn get_min_packet_size(&self, space: SpaceId) -> u16 {
        if !self.pqc_detected {
            return 1200;
        }
        match space {
            SpaceId::Initial => PQC_MIN_MTU,
            SpaceId::Handshake => 1500,
            _ => 1200,
        }
    }

    /// Reports whether `bytes_sent` has reached the estimated PQC handshake
    /// size.
    ///
    /// Always returns `false` when no PQC handshake has been detected.
    pub fn is_handshake_complete(&self, bytes_sent: u64) -> bool {
        self.pqc_detected && bytes_sent >= u64::from(self.estimated_handshake_size)
    }

    /// Splits `data` into consecutive CRYPTO frames starting at stream
    /// `offset`.
    ///
    /// Each frame carries at most
    /// `calculate_crypto_frame_size(max_packet_size - 16, ..)` bytes. Frames
    /// are returned in order with contiguous offsets.
    ///
    /// If `max_packet_size` is 16 or less there is no room for any payload.
    /// In that case no frames are produced and the returned vector is empty.
    /// A caller passing non-empty `data` can detect this from the empty
    /// result. Previously this case never terminated.
    pub fn fragment_crypto_data(
        &self,
        data: &[u8],
        offset: u64,
        max_packet_size: usize,
    ) -> Vec<Crypto> {
        // NOTE(review): the 16 bytes are presumably reserved for the frame
        // header (type, offset, length) - confirm.
        let available_space = max_packet_size.saturating_sub(16);
        // The per-frame capacity does not change between frames, so compute it
        // once; passing usize::MAX leaves only the space and cap limits.
        let frame_capacity = self.calculate_crypto_frame_size(available_space, usize::MAX);

        let mut frames = Vec::new();
        if frame_capacity == 0 {
            // A zero-sized frame would never consume any input, so stop here
            // instead of emitting empty frames forever.
            if !data.is_empty() {
                debug!(
                    "Cannot fragment {} bytes of CRYPTO data: max packet size {} leaves no room for frame payload",
                    data.len(),
                    max_packet_size
                );
            }
        } else {
            let chunks = data.chunks(frame_capacity);
            frames.reserve(chunks.len());
            let mut next_offset = offset;
            for chunk in chunks {
                frames.push(Crypto {
                    offset: next_offset,
                    data: chunk.to_vec().into(),
                });
                next_offset += chunk.len() as u64;
            }
        }

        trace!(
            "Fragmented {} bytes into {} CRYPTO frames",
            data.len(),
            frames.len()
        );
        frames
    }

    /// Emits a trace line for packets sent in the Initial or Handshake space
    /// while a PQC handshake is detected; does nothing otherwise.
    pub fn on_packet_sent(&mut self, space: SpaceId, size: u16) {
        if self.pqc_detected && matches!(space, SpaceId::Initial | SpaceId::Handshake) {
            trace!("PQC packet sent in {:?}: {} bytes", space, size);
        }
    }

    /// Returns the handler to its freshly constructed state.
    pub fn reset(&mut self) {
        self.pqc_detected = false;
        self.estimated_handshake_size = 0;
        self.mtu_discovery_triggered = false;
    }
}
impl Default for PqcPacketHandler {
fn default() -> Self {
Self::new()
}
}
/// Hooks for types that own a [`PqcPacketHandler`].
///
/// NOTE(review): no implementation is visible in this file; the per-method
/// descriptions below are inferred from the signatures and names - confirm
/// against the implementing type.
pub trait PqcPacketHandling {
    /// Gives mutable access to the owned [`PqcPacketHandler`].
    fn pqc_packet_handler(&mut self) -> &mut PqcPacketHandler;

    /// Presumably forwards CRYPTO data from `space` to PQC detection.
    fn handle_pqc_detection(&mut self, crypto_data: &[u8], space: SpaceId);

    /// Presumably applies PQC-specific MTU settings to the implementor.
    fn adjust_mtu_for_pqc(&mut self);

    /// Presumably returns the packet size to aim for in `space`.
    fn get_pqc_optimal_packet_size(&self, space: SpaceId) -> u16;
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a handler that behaves as if a PQC handshake was already detected.
    fn pqc_handler() -> PqcPacketHandler {
        let mut handler = PqcPacketHandler::new();
        handler.pqc_detected = true;
        handler
    }

    #[test]
    fn test_pqc_packet_handler_creation() {
        let fresh = PqcPacketHandler::new();
        assert!(!fresh.pqc_detected);
        assert_eq!(fresh.estimated_handshake_size, 0);
        assert!(!fresh.mtu_discovery_triggered);
    }

    #[test]
    fn test_mtu_discovery_trigger() {
        let mut handler = PqcPacketHandler::new();
        // Nothing to trigger before detection.
        assert!(!handler.should_trigger_mtu_discovery());

        handler.pqc_detected = true;
        // Fires once, then stays quiet.
        assert!(handler.should_trigger_mtu_discovery());
        assert!(!handler.should_trigger_mtu_discovery());
    }

    #[test]
    fn test_crypto_frame_size_calculation() {
        // (available_space, remaining_data, expected_frame_size)
        let classical_cases: [(usize, usize, usize); 3] =
            [(1000, 2000, 600), (500, 2000, 500), (1000, 400, 400)];
        let classical = PqcPacketHandler::new();
        for &(available, remaining, expected) in &classical_cases {
            assert_eq!(
                classical.calculate_crypto_frame_size(available, remaining),
                expected
            );
        }

        let pqc_cases: [(usize, usize, usize); 2] = [(1500, 2000, 1200), (500, 2000, 500)];
        let pqc = pqc_handler();
        for &(available, remaining, expected) in &pqc_cases {
            assert_eq!(pqc.calculate_crypto_frame_size(available, remaining), expected);
        }
    }

    #[test]
    fn test_min_packet_size() {
        let classical = PqcPacketHandler::new();
        assert_eq!(classical.get_min_packet_size(SpaceId::Initial), 1200);
        assert_eq!(classical.get_min_packet_size(SpaceId::Handshake), 1200);
        assert_eq!(classical.get_min_packet_size(SpaceId::Data), 1200);

        let pqc = pqc_handler();
        assert_eq!(pqc.get_min_packet_size(SpaceId::Initial), PQC_MIN_MTU);
        assert_eq!(pqc.get_min_packet_size(SpaceId::Handshake), 1500);
        assert_eq!(pqc.get_min_packet_size(SpaceId::Data), 1200);
    }

    #[test]
    fn test_crypto_data_fragmentation() {
        let handler = PqcPacketHandler::new();

        // Small payload: fits in a single frame.
        let small = vec![0u8; 500];
        let single = handler.fragment_crypto_data(&small, 1000, 1200);
        assert_eq!(single.len(), 1);
        assert_eq!(single[0].offset, 1000);
        assert_eq!(single[0].data.len(), 500);

        // Large payload: split into five 600-byte frames.
        let large = vec![0u8; 3000];
        let split = handler.fragment_crypto_data(&large, 0, 700);
        assert_eq!(split.len(), 5);
        assert_eq!(split[0].offset, 0);
        assert_eq!(split[0].data.len(), 600);
        assert_eq!(split[4].offset, 2400);
        assert_eq!(split[4].data.len(), 600);
    }

    #[test]
    fn test_pqc_handshake_size() {
        assert_eq!(PqcPacketHandler::pqc_handshake_size(), 16384);
    }

    #[test]
    fn test_coalescing_adjustment() {
        let classical = PqcPacketHandler::new();
        assert!(!classical.adjust_coalescing_for_pqc(800, SpaceId::Initial));
        assert!(!classical.adjust_coalescing_for_pqc(500, SpaceId::Initial));

        let pqc = pqc_handler();
        assert!(pqc.adjust_coalescing_for_pqc(800, SpaceId::Initial));
        assert!(!pqc.adjust_coalescing_for_pqc(500, SpaceId::Initial));
        assert!(!pqc.adjust_coalescing_for_pqc(800, SpaceId::Handshake));
    }

    #[test]
    fn test_handshake_completion_check() {
        let mut handler = PqcPacketHandler::new();
        assert!(!handler.is_handshake_complete(10000));

        handler.pqc_detected = true;
        handler.estimated_handshake_size = 16384;
        assert!(!handler.is_handshake_complete(8000));
        assert!(handler.is_handshake_complete(16384));
        assert!(handler.is_handshake_complete(20000));
    }

    #[test]
    fn test_handler_reset() {
        let mut handler = pqc_handler();
        handler.estimated_handshake_size = 16384;
        handler.mtu_discovery_triggered = true;

        handler.reset();

        assert!(!handler.pqc_detected);
        assert_eq!(handler.estimated_handshake_size, 0);
        assert!(!handler.mtu_discovery_triggered);
    }

    #[test]
    fn test_pqc_mtu_config() {
        let classical_config = PqcPacketHandler::new().get_pqc_mtu_config();
        assert_eq!(classical_config.upper_bound, 1452);

        let pqc_config = pqc_handler().get_pqc_mtu_config();
        assert_eq!(
            pqc_config.upper_bound,
            PQC_RECOMMENDED_MTU.min(crate::MAX_UDP_PAYLOAD)
        );
        assert_eq!(pqc_config.minimum_change, 128);
    }

    #[test]
    fn test_pqc_constants() {
        assert_eq!(ML_KEM_768_HANDSHAKE_OVERHEAD, 2272);
        assert_eq!(ML_DSA_65_HANDSHAKE_OVERHEAD, 5261);
        assert_eq!(HYBRID_HANDSHAKE_OVERHEAD, 7789);
        assert_eq!(PQC_MIN_MTU, 2048);
        assert_eq!(PQC_RECOMMENDED_MTU, 4096);
        assert_eq!(MAX_CRYPTO_FRAME_SIZE, 1200);
    }
}