use std::collections::{HashMap, hash_map};
use std::time::{Duration, Instant};
use crate::events::FragmentedEvent;
use crate::types::ClientId;
/// A partially received message: one slot per fragment index, filled as fragments arrive.
#[derive(Debug, Clone)]
struct FragmentBuffer {
    /// When this buffer was created; `is_stale` measures age from here.
    start_time: Instant,
    /// Number of fragments declared for this message (also the number of slots).
    total_fragments: u16,
    /// Slot `i` holds fragment `i`'s payload once it has arrived, `None` before that.
    fragments: Vec<Option<Vec<u8>>>,
    /// Number of distinct slots filled so far.
    count: u16,
}

impl FragmentBuffer {
    /// Creates an empty buffer expecting `total_fragments` fragments.
    ///
    /// Returns `None` when `total_fragments` is zero or exceeds `crate::MAX_TOTAL_FRAGMENTS`.
    fn new(total_fragments: u16) -> Option<Self> {
        if total_fragments == 0 || total_fragments > crate::MAX_TOTAL_FRAGMENTS {
            return None;
        }
        Some(Self {
            start_time: Instant::now(),
            total_fragments,
            fragments: vec![None; usize::from(total_fragments)],
            count: 0,
        })
    }

    /// Stores `payload` as fragment `index` and returns the whole message once complete.
    ///
    /// An `index` outside the slot range is ignored (returns `None`). A repeated
    /// index keeps the payload that arrived first. When the final missing
    /// fragment arrives, the fragments are concatenated in index order and
    /// returned; the slots are drained in the process, so the buffer should be
    /// discarded afterwards.
    fn add(&mut self, index: u16, payload: Vec<u8>) -> Option<Vec<u8>> {
        let slot = self.fragments.get_mut(usize::from(index))?;
        if slot.is_none() {
            *slot = Some(payload);
            self.count += 1;
        }
        if self.count != self.total_fragments {
            return None;
        }
        // Size the output once instead of growing it fragment by fragment.
        let total_len: usize = self.fragments.iter().flatten().map(Vec::len).sum();
        let mut full_payload = Vec::with_capacity(total_len);
        for frag in self.fragments.drain(..) {
            // `count` only grows when an empty slot is filled and there are exactly
            // `total_fragments` slots, so reaching the total means every slot is `Some`.
            let frag = frag.expect("all fragment slots are filled once count == total_fragments");
            full_payload.extend(frag);
        }
        Some(full_payload)
    }

    /// Returns `true` once more than `timeout` has elapsed since this buffer was created.
    fn is_stale(&self, timeout: Duration) -> bool {
        self.start_time.elapsed() > timeout
    }
}
/// Reassembles fragmented messages, tracking in-progress messages per client and message id.
///
/// Incomplete messages are held until all their fragments arrive or until
/// [`Reassembler::cleanup`] evicts them after `timeout` has elapsed.
#[derive(Debug, Clone)]
pub struct Reassembler {
    /// In-progress messages keyed by `(client, message_id)`.
    buffers: HashMap<(ClientId, u32), FragmentBuffer>,
    /// Age after which an incomplete message is considered stale by `cleanup`.
    timeout: Duration,
}

impl Default for Reassembler {
    /// Same configuration as [`Reassembler::new`]: no buffers and a 5 second timeout.
    ///
    /// A derived `Default` would set the timeout to `Duration::ZERO`, making every
    /// partial message stale almost immediately so `cleanup` would discard it.
    fn default() -> Self {
        Self {
            buffers: HashMap::new(),
            timeout: Duration::from_secs(5),
        }
    }
}
impl Reassembler {
    /// Creates an empty reassembler whose incomplete messages expire after 5 seconds.
    #[must_use]
    pub fn new() -> Self {
        let timeout = Duration::from_secs(5);
        Self {
            timeout,
            buffers: HashMap::new(),
        }
    }

    /// Returns this reassembler with its staleness timeout replaced by `timeout`.
    #[must_use]
    pub fn with_timeout(self, timeout: Duration) -> Self {
        Self { timeout, ..self }
    }

    /// Feeds one fragment sent by `client_id` into the reassembler.
    ///
    /// Returns the concatenated payload once every fragment of that message has
    /// arrived, and forgets the message. Otherwise returns `None`. A fragment is
    /// dropped with a warning when its `total_fragments` is zero, exceeds
    /// `crate::MAX_TOTAL_FRAGMENTS`, or disagrees with the total recorded from the
    /// first fragment of the same `(client_id, message_id)`.
    pub fn add(&mut self, client_id: ClientId, event: FragmentedEvent) -> Option<Vec<u8>> {
        let total = event.total_fragments;
        let total_is_valid = total != 0 && total <= crate::MAX_TOTAL_FRAGMENTS;
        if !total_is_valid {
            tracing::warn!("Rejecting fragment with invalid total_fragments: {}", total);
            return None;
        }

        let key = (client_id, event.message_id);
        let buffer = match self.buffers.entry(key) {
            hash_map::Entry::Occupied(slot) => slot.into_mut(),
            hash_map::Entry::Vacant(slot) => slot.insert(FragmentBuffer::new(total)?),
        };

        if buffer.total_fragments != total {
            tracing::warn!(
                "Fragment mismatch for message_id {}: expected {}, got {}",
                event.message_id,
                buffer.total_fragments,
                total
            );
            return None;
        }

        // Only a completed message gets past this `?`; its buffer is then released.
        let assembled = buffer.add(event.fragment_index, event.payload)?;
        self.buffers.remove(&key);
        Some(assembled)
    }

    /// Drops every incomplete message whose buffer was created more than `timeout` ago.
    pub fn cleanup(&mut self) {
        let timeout = self.timeout;
        self.buffers.retain(|_, buffer| !buffer.is_stale(timeout));
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for building a `FragmentedEvent` in tests.
    fn fragment(
        message_id: u32,
        fragment_index: u16,
        total_fragments: u16,
        payload: Vec<u8>,
    ) -> FragmentedEvent {
        FragmentedEvent {
            message_id,
            fragment_index,
            total_fragments,
            payload,
        }
    }

    #[test]
    fn test_reassembly_ordered() {
        let mut reassembler = Reassembler::new();
        let cid = ClientId(1);
        assert!(reassembler.add(cid, fragment(100, 0, 2, vec![1, 2])).is_none());
        let result = reassembler.add(cid, fragment(100, 1, 2, vec![3, 4])).unwrap();
        assert_eq!(result, vec![1, 2, 3, 4]);
        // A completed message must not leave a buffer behind.
        assert!(reassembler.buffers.is_empty());
    }

    #[test]
    fn test_reassembly_out_of_order() {
        let mut reassembler = Reassembler::new();
        let cid = ClientId(1);
        assert!(reassembler.add(cid, fragment(101, 2, 3, vec![3])).is_none());
        assert!(reassembler.add(cid, fragment(101, 0, 3, vec![1])).is_none());
        let result = reassembler.add(cid, fragment(101, 1, 3, vec![2])).unwrap();
        assert_eq!(result, vec![1, 2, 3]);
        assert!(reassembler.buffers.is_empty());
    }

    #[test]
    fn test_duplicate_fragment_keeps_first_payload() {
        let mut reassembler = Reassembler::new();
        let cid = ClientId(1);
        assert!(reassembler.add(cid, fragment(103, 0, 2, vec![1])).is_none());
        assert!(reassembler.add(cid, fragment(103, 0, 2, vec![9])).is_none());
        let result = reassembler.add(cid, fragment(103, 1, 2, vec![2])).unwrap();
        assert_eq!(result, vec![1, 2]);
    }

    #[test]
    fn test_mismatched_total_is_rejected() {
        let mut reassembler = Reassembler::new();
        let cid = ClientId(1);
        assert!(reassembler.add(cid, fragment(104, 0, 2, vec![1])).is_none());
        // Disagreeing total: dropped, and the existing partial message is kept.
        assert!(reassembler.add(cid, fragment(104, 1, 3, vec![2])).is_none());
        assert_eq!(reassembler.buffers.len(), 1);
        let result = reassembler.add(cid, fragment(104, 1, 2, vec![2])).unwrap();
        assert_eq!(result, vec![1, 2]);
    }

    #[test]
    fn test_zero_total_is_rejected() {
        let mut reassembler = Reassembler::new();
        assert!(reassembler.add(ClientId(1), fragment(105, 0, 0, vec![1])).is_none());
        assert!(reassembler.buffers.is_empty());
    }

    #[test]
    fn test_clients_are_isolated() {
        let mut reassembler = Reassembler::new();
        let (first, second) = (ClientId(1), ClientId(2));
        // Same message id from two clients must not be merged into one message.
        assert!(reassembler.add(first, fragment(106, 0, 2, vec![1])).is_none());
        assert!(reassembler.add(second, fragment(106, 1, 2, vec![2])).is_none());
        assert_eq!(reassembler.buffers.len(), 2);
    }

    #[test]
    fn test_cleanup() {
        let mut reassembler = Reassembler::new().with_timeout(Duration::from_millis(10));
        let cid = ClientId(1);
        assert!(reassembler.add(cid, fragment(102, 0, 2, vec![1])).is_none());
        // Guard against a vacuous pass: the partial message must be buffered first.
        assert_eq!(reassembler.buffers.len(), 1);
        std::thread::sleep(Duration::from_millis(20));
        reassembler.cleanup();
        assert!(reassembler.buffers.is_empty());
    }
}