use borsh::{BorshDeserialize, BorshSerialize};
use std::collections::HashMap;
use std::time::{Duration, Instant};
/// Largest payload, in bytes, carried by a single `CryptoFrame`: `fragment_payload` splits at
/// this size and `FragmentAssembler::process_chunk` rejects any fragment longer than this.
const MAX_UDP_PAYLOAD: usize = 1200;
/// Nominal upper bound, in bytes, on a reassembled packet. It is enforced only through
/// `MAX_TOTAL_CHUNKS`, which admits up to `MAX_TOTAL_CHUNKS * MAX_UDP_PAYLOAD` bytes
/// (slightly more than this value).
pub const MAX_REASSEMBLED_LEN: usize = 256 * 1024;
/// Largest `total_chunks` value the assembler accepts for a packet.
pub const MAX_TOTAL_CHUNKS: u16 = (MAX_REASSEMBLED_LEN / MAX_UDP_PAYLOAD + 1) as u16;
// The `as u16` above would silently truncate if the constants were ever retuned past
// `u16::MAX` chunks; fail the build instead of shipping a wrapped limit.
const _: () = assert!(MAX_REASSEMBLED_LEN / MAX_UDP_PAYLOAD + 1 <= u16::MAX as usize);
/// Maximum number of packets reassembled at once; opening another one evicts the assembly
/// that was updated least recently.
pub const MAX_CONCURRENT_ASSEMBLIES: usize = 256;
/// One fragment of a packet, as produced by `fragment_payload` and consumed by
/// `FragmentAssembler::process_chunk`.
///
/// Borsh serializes the fields in declaration order, so reordering them changes the wire format.
#[derive(Debug, Clone, BorshSerialize, BorshDeserialize)]
pub struct CryptoFrame {
    /// Session the packet belongs to; together with `packet_id` it keys reassembly.
    pub session_id: [u8; 16],
    /// Identifier of the packet this fragment is part of.
    pub packet_id: u32,
    /// Zero-based position of this fragment; the assembler requires it to be `< total_chunks`.
    pub chunk_index: u16,
    /// Number of fragments the packet was split into.
    pub total_chunks: u16,
    /// Fragment bytes; the assembler rejects payloads longer than `MAX_UDP_PAYLOAD`.
    pub payload: Vec<u8>,
}
/// Reassembles packets split by `fragment_payload` back into contiguous byte buffers.
///
/// In-flight packets are keyed by `(session_id, packet_id)`.
#[derive(Debug)]
pub struct FragmentAssembler {
    assemblies: HashMap<([u8; 16], u32), AssemblyState>,
}

/// A partially received packet.
#[derive(Debug)]
struct AssemblyState {
    /// Fragment payloads received so far, keyed by `chunk_index`.
    chunks: HashMap<u16, Vec<u8>>,
    /// Chunk count taken from the first fragment seen for this packet.
    total_chunks: u16,
    /// When a fragment was last accepted for this packet; drives eviction and NACK timing.
    last_update: Instant,
}
impl Default for FragmentAssembler {
fn default() -> Self {
Self::new()
}
}
impl FragmentAssembler {
    /// Idle time after which an incomplete assembly is reported by `get_nacks_and_evict`.
    const NACK_AFTER: Duration = Duration::from_millis(50);
    /// Idle time after which an incomplete assembly is discarded by `get_nacks_and_evict`.
    const EVICT_AFTER: Duration = Duration::from_millis(5000);

    /// Creates an assembler with no packets in flight.
    pub fn new() -> Self {
        Self {
            assemblies: HashMap::new(),
        }
    }

    /// Feeds one fragment into the assembler.
    ///
    /// Returns `Some(packet)` once every chunk of the frame's `(session_id, packet_id)` packet
    /// has arrived. The packet is the chunk payloads concatenated in index order, and the
    /// assembly is removed. Otherwise returns `None`.
    ///
    /// A frame is dropped (returning `None`, with no state change) in these cases:
    /// - `total_chunks` is `0` or exceeds `MAX_TOTAL_CHUNKS`;
    /// - `chunk_index >= total_chunks`;
    /// - the payload is longer than `MAX_UDP_PAYLOAD`;
    /// - `total_chunks` disagrees with the count recorded from the first fragment of the same
    ///   packet.
    ///
    /// Opening a new assembly while `MAX_CONCURRENT_ASSEMBLIES` are in flight first evicts the
    /// least recently updated one. A repeated `chunk_index` replaces the earlier payload.
    pub fn process_chunk(&mut self, frame: CryptoFrame) -> Option<Vec<u8>> {
        if frame.total_chunks == 0
            || frame.total_chunks > MAX_TOTAL_CHUNKS
            || frame.chunk_index >= frame.total_chunks
            || frame.payload.len() > MAX_UDP_PAYLOAD
        {
            return None;
        }
        let key = (frame.session_id, frame.packet_id);
        let known_total = self.assemblies.get(&key).map(|state| state.total_chunks);
        match known_total {
            // The chunk count is fixed by the first fragment seen. Accepting a frame that
            // disagrees could store an index >= the recorded count, letting the completion
            // check below fire with a gap and the whole assembly be discarded.
            Some(expected) if expected != frame.total_chunks => return None,
            Some(_) => {}
            None if self.assemblies.len() >= MAX_CONCURRENT_ASSEMBLIES => self.evict_stalest(),
            None => {}
        }

        let now = Instant::now();
        let state = self.assemblies.entry(key).or_insert_with(|| AssemblyState {
            chunks: HashMap::new(),
            total_chunks: frame.total_chunks,
            last_update: now,
        });
        state.last_update = now;
        state.chunks.insert(frame.chunk_index, frame.payload);
        // Every stored index is `< total_chunks` and map keys are unique, so holding
        // `total_chunks` entries means no index is missing.
        if state.chunks.len() < usize::from(state.total_chunks) {
            return None;
        }

        let AssemblyState {
            chunks,
            total_chunks,
            ..
        } = self.assemblies.remove(&key)?;
        let total_len: usize = chunks.values().map(Vec::len).sum();
        let mut packet = Vec::with_capacity(total_len);
        for index in 0..total_chunks {
            // Cannot fail given the invariant above; `?` keeps this path panic-free regardless.
            packet.extend_from_slice(chunks.get(&index)?);
        }
        Some(packet)
    }

    /// Removes the assembly with the oldest `last_update`, if there is any.
    fn evict_stalest(&mut self) {
        let stalest = self
            .assemblies
            .iter()
            .min_by_key(|(_, state)| state.last_update)
            .map(|(&key, _)| key);
        if let Some(key) = stalest {
            self.assemblies.remove(&key);
        }
    }

    /// Number of packets currently being reassembled.
    pub fn len(&self) -> usize {
        self.assemblies.len()
    }

    /// Returns `true` when no packet is being reassembled.
    pub fn is_empty(&self) -> bool {
        self.assemblies.is_empty()
    }

    /// Reports missing chunks for stalled assemblies and discards abandoned ones.
    ///
    /// Assemblies that have not accepted a fragment for longer than `EVICT_AFTER` (5 s) are
    /// removed without being reported. For every other assembly idle longer than `NACK_AFTER`
    /// (50 ms), the result contains `(session_id, packet_id, missing_chunk_indices)`, with the
    /// indices in ascending order. Entries appear in `HashMap` iteration order, which is
    /// unspecified.
    pub fn get_nacks_and_evict(&mut self) -> Vec<([u8; 16], u32, Vec<u16>)> {
        let now = Instant::now();
        let mut nacks = Vec::new();
        self.assemblies
            .retain(|&(session_id, packet_id), state| {
                let idle = now.duration_since(state.last_update);
                if idle > Self::EVICT_AFTER {
                    return false;
                }
                if idle > Self::NACK_AFTER {
                    let missing: Vec<u16> = (0..state.total_chunks)
                        .filter(|index| !state.chunks.contains_key(index))
                        .collect();
                    if !missing.is_empty() {
                        nacks.push((session_id, packet_id, missing));
                    }
                }
                true
            });
        nacks
    }
}
pub fn fragment_payload(session_id: [u8; 16], packet_id: u32, payload: &[u8]) -> Vec<CryptoFrame> {
let mut frames = Vec::new();
let chunks = payload.chunks(MAX_UDP_PAYLOAD);
let total_chunks = chunks.len() as u16;
for (i, chunk) in chunks.enumerate() {
frames.push(CryptoFrame {
session_id,
packet_id,
chunk_index: i as u16,
total_chunks,
payload: chunk.to_vec(),
});
}
frames
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a frame for the all-zero session whose payload is `len` bytes of `0xAB`.
    fn make_frame(packet_id: u32, chunk_index: u16, total_chunks: u16, len: usize) -> CryptoFrame {
        CryptoFrame {
            session_id: [0; 16],
            packet_id,
            chunk_index,
            total_chunks,
            payload: vec![0xAB; len],
        }
    }

    #[test]
    fn fragment_reassemble_round_trip() {
        let original: Vec<u8> = (0..3000u32).map(|n| n as u8).collect();
        let frames = fragment_payload([1; 16], 42, &original);
        assert!(frames.len() > 1, "3000 bytes must fragment");

        let mut assembler = FragmentAssembler::new();
        // Feed every frame in order and keep the last completed packet, if any.
        let reassembled = frames
            .into_iter()
            .filter_map(|f| assembler.process_chunk(f))
            .last();
        assert_eq!(reassembled.as_deref(), Some(original.as_slice()));
        assert!(assembler.is_empty(), "completed assembly is removed");
    }

    #[test]
    fn rejects_zero_total_chunks() {
        let mut assembler = FragmentAssembler::new();
        let outcome = assembler.process_chunk(make_frame(1, 0, 0, 10));
        assert!(outcome.is_none());
        assert!(assembler.is_empty(), "malformed frame must not open an assembly");
    }

    #[test]
    fn rejects_out_of_range_chunk_index() {
        let mut assembler = FragmentAssembler::new();
        let outcome = assembler.process_chunk(make_frame(1, 2, 2, 10));
        assert!(outcome.is_none());
        assert!(assembler.is_empty());
    }

    #[test]
    fn rejects_excessive_total_chunks() {
        let mut assembler = FragmentAssembler::new();
        let too_many = MAX_TOTAL_CHUNKS.saturating_add(1);
        let outcome = assembler.process_chunk(make_frame(1, 0, too_many, 10));
        assert!(outcome.is_none());
        assert!(assembler.is_empty());
    }

    #[test]
    fn rejects_oversized_fragment_payload() {
        let mut assembler = FragmentAssembler::new();
        let oversized = make_frame(1, 0, 4, MAX_UDP_PAYLOAD + 1);
        assert!(assembler.process_chunk(oversized).is_none());
        assert!(assembler.is_empty());
    }

    #[test]
    fn caps_concurrent_assemblies() {
        let mut assembler = FragmentAssembler::new();
        let attempts = MAX_CONCURRENT_ASSEMBLIES as u32 * 4;
        for packet_id in 0..attempts {
            let outcome = assembler.process_chunk(make_frame(packet_id, 0, 4, 10));
            assert!(outcome.is_none());
            let live = assembler.len();
            assert!(
                live <= MAX_CONCURRENT_ASSEMBLIES,
                "assembly table exceeded its cap: {}",
                live
            );
        }
        assert_eq!(assembler.len(), MAX_CONCURRENT_ASSEMBLIES);
    }
}