use std::collections::{BTreeMap, HashMap};
use std::time::{Duration, Instant};
use crate::fragment_frame::FragmentFrame;
/// Timeout used by [`Reassembler::new`] (and therefore `Default`): ten seconds.
/// Incomplete packets whose buffer is at least this old count as expired.
pub const DEFAULT_REASSEMBLY_TIMEOUT: Duration = Duration::from_millis(10_000);
/// Accumulates the fragments of a single packet until every byte is present.
struct ReassemblyBuffer {
    /// Moment this buffer was created; the expiry clock runs from here.
    created_at: Instant,
    /// Declared size of the complete packet, in bytes.
    total_length: u16,
    /// Received payloads keyed by byte offset. The ordered map lets gaps be
    /// detected with a single ascending sweep.
    chunks: BTreeMap<u16, bytes::Bytes>,
}
impl ReassemblyBuffer {
    /// Creates an empty buffer for a packet whose reassembled size is
    /// `total_length` bytes. The expiry clock starts now.
    fn new(total_length: u16) -> Self {
        Self {
            total_length,
            chunks: BTreeMap::new(),
            created_at: Instant::now(),
        }
    }

    /// Stores `payload` at byte `offset`.
    ///
    /// Returns `true` if the chunk was stored and `false` if it was rejected.
    /// A chunk is rejected when:
    /// - a chunk is already recorded at `offset` (the first one wins), or
    /// - the chunk would extend past `total_length`.
    ///
    /// Rejecting out-of-range chunks here matters. Otherwise a single
    /// malformed fragment would make the packet permanently un-reassemblable,
    /// and it would occupy its offset so a correct fragment could not replace
    /// it.
    fn insert(&mut self, offset: u16, payload: bytes::Bytes) -> bool {
        use std::collections::btree_map::Entry;

        // Cannot overflow: offset <= u16::MAX and a buffer's length is at
        // most isize::MAX.
        let end = offset as usize + payload.len();
        if end > self.total_length as usize {
            return false;
        }
        match self.chunks.entry(offset) {
            Entry::Vacant(e) => {
                e.insert(payload);
                true
            }
            Entry::Occupied(_) => false,
        }
    }

    /// Returns the reassembled packet if the stored chunks cover every byte of
    /// `0..total_length` with no gaps. Otherwise returns `None`.
    ///
    /// Overlapping chunks are accepted. Chunks are copied in ascending offset
    /// order, so where two overlap, the bytes of the higher-offset chunk win.
    fn try_reassemble(&self) -> Option<Vec<u8>> {
        // Nothing has been received (every chunk was rejected), so the packet
        // is not complete, even when `total_length` is 0.
        if self.chunks.is_empty() {
            return None;
        }
        let total = self.total_length as usize;

        // BTreeMap iterates in ascending key order, so tracking the furthest
        // byte covered so far is enough to detect any hole.
        let mut covered_up_to = 0usize;
        for (&offset, chunk) in &self.chunks {
            let start = offset as usize;
            if start > covered_up_to {
                return None; // gap before this chunk
            }
            let end = start + chunk.len();
            if end > total {
                return None; // defensive: `insert` already rejects these
            }
            covered_up_to = covered_up_to.max(end);
        }
        if covered_up_to != total {
            return None; // tail of the packet is still missing
        }

        let mut buf = vec![0u8; total];
        for (&offset, chunk) in &self.chunks {
            let start = offset as usize;
            // In bounds: every chunk ends at or before `total` (checked above).
            buf[start..start + chunk.len()].copy_from_slice(chunk);
        }
        Some(buf)
    }

    /// Returns `true` once at least `timeout` has elapsed since this buffer
    /// was created.
    fn is_expired(&self, timeout: Duration) -> bool {
        self.created_at.elapsed() >= timeout
    }
}
/// Reassembles packets from fragments, tracking one buffer per fragment id.
pub struct Reassembler {
    /// In-progress packets, keyed by the fragments' `fragment_id`.
    buffers: HashMap<u32, ReassemblyBuffer>,
    /// Age at which an incomplete packet's buffer counts as expired.
    timeout: Duration,
}
impl Reassembler {
pub fn new() -> Self {
Self {
timeout: DEFAULT_REASSEMBLY_TIMEOUT,
buffers: HashMap::new(),
}
}
pub fn with_timeout(timeout: Duration) -> Self {
Self {
timeout,
buffers: HashMap::new(),
}
}
pub fn insert(&mut self, frame: FragmentFrame) -> Option<Vec<u8>> {
let buf = self
.buffers
.entry(frame.fragment_id)
.or_insert_with(|| ReassemblyBuffer::new(frame.total_length));
if buf.total_length != frame.total_length {
return None;
}
buf.insert(frame.fragment_offset, frame.payload);
if let Some(packet) = buf.try_reassemble() {
self.buffers.remove(&frame.fragment_id);
return Some(packet);
}
None
}
pub fn expire_stale(&mut self) {
let timeout = self.timeout;
self.buffers.retain(|_, buf| !buf.is_expired(timeout));
}
pub fn buffer_count(&self) -> usize {
self.buffers.len()
}
}
impl Default for Reassembler {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests;