use super::codec::{FragmentInfo, Frame};
use bytes::Bytes;
use std::collections::HashMap;
/// Reassembles fragmented [`Frame`]s back into whole payloads, grouping fragments by
/// their `FragmentInfo::sequence` number.
pub struct FragmentAssembler {
// Packets with at least one fragment received but not yet complete, keyed by sequence.
pending: HashMap<u64, PendingPacket>,
}
/// Fragments received so far for a single sequence number.
struct PendingPacket {
// Fragment count taken from the fragment that opened this sequence; later fragments
// must agree with it.
total: u16,
// Received fragments keyed by their index; every key is `< total`.
fragments: HashMap<u16, Frame>,
}
impl FragmentAssembler {
    /// Creates an assembler with no packets in flight.
    pub fn new() -> Self {
        Self {
            pending: HashMap::new(),
        }
    }

    /// Feeds one frame into the assembler.
    ///
    /// Returns `Some(payload)` once the last missing fragment of a sequence arrives. The
    /// payload is the concatenation of all fragment payloads in index order. Otherwise
    /// returns `None`.
    ///
    /// The frame is ignored (and `None` returned) when:
    /// - it carries no fragment info;
    /// - `total` is zero or `index` is not below `total`;
    /// - `total` disagrees with the count recorded for the sequence. In that case the
    ///   whole partial reassembly for the sequence is discarded too, so a stale entry
    ///   cannot block a later packet that reuses the sequence number.
    ///
    /// If the same index arrives twice, the later copy replaces the earlier one.
    ///
    /// NOTE(review): sequences that never complete stay in memory; nothing here
    /// evicts them.
    pub fn submit(&mut self, frame: Frame) -> Option<Vec<u8>> {
        let FragmentInfo {
            sequence,
            index,
            total,
        } = frame.fragment?;
        if total == 0 || index >= total {
            return None;
        }

        let entry = self
            .pending
            .entry(sequence)
            .or_insert_with(|| PendingPacket {
                total,
                fragments: HashMap::new(),
            });
        if entry.total != total {
            tracing::debug!(
                sequence,
                expected_total = entry.total,
                got_total = total,
                "fragment total mismatch for sequence; dropping reassembly"
            );
            // Actually drop the partial packet, as the log line states. Otherwise the
            // sequence would stay wedged on a total that new fragments can never match.
            self.pending.remove(&sequence);
            return None;
        }

        entry.fragments.insert(index, frame);
        if entry.fragments.len() != usize::from(total) {
            return None;
        }

        let packet = self.pending.remove(&sequence)?;
        // Keys are unique and all `< total`, so a map holding `total` entries contains
        // exactly `0..total`. Walk that range instead of collecting and sorting keys.
        let size: usize = packet.fragments.values().map(|f| f.payload.len()).sum();
        let mut out = Vec::with_capacity(size);
        for i in 0..packet.total {
            let frag = packet.fragments.get(&i)?;
            out.extend_from_slice(frag.payload.as_ref());
        }
        Some(out)
    }

    /// Releases spare capacity held by the pending-packet map.
    ///
    /// This does not drop incomplete packets; it only shrinks the map's allocation.
    pub fn gc(&mut self) {
        self.pending.shrink_to_fit();
    }
}
impl Default for FragmentAssembler {
fn default() -> Self {
Self::new()
}
}
/// Builds a [`Frame`] whose `fragment` field is `Some(fragment)`.
///
/// Every other argument is moved into the frame field of the same name.
pub fn frame_with_fragment(
    frame_id: u64,
    routing_mask: u32,
    payload: Bytes,
    mac: [u8; 16],
    fragment: FragmentInfo,
) -> Frame {
    let fragment = Some(fragment);
    Frame {
        fragment,
        mac,
        payload,
        routing_mask,
        frame_id,
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::frame::codec::WIRE_MAC_LEN;

    /// Builds a fragment frame with a zero routing mask and an all-zero MAC.
    fn frag(frame_id: u64, payload: &'static [u8], sequence: u64, index: u16, total: u16) -> Frame {
        frame_with_fragment(
            frame_id,
            0,
            Bytes::from_static(payload),
            [0u8; WIRE_MAC_LEN],
            FragmentInfo {
                sequence,
                index,
                total,
            },
        )
    }

    #[test]
    fn reassembles_out_of_order() {
        let mut asm = FragmentAssembler::new();
        let seq = 7u64;
        assert!(asm.submit(frag(2, b"lo", seq, 1, 2)).is_none());
        let merged = asm.submit(frag(1, b"hel", seq, 0, 2)).expect("complete");
        assert_eq!(merged, b"hello");
    }

    #[test]
    fn rejects_total_mismatch() {
        let mut asm = FragmentAssembler::new();
        assert!(asm.submit(frag(1, b"a", 1, 0, 2)).is_none());
        assert!(asm.submit(frag(2, b"b", 1, 1, 3)).is_none());
    }

    #[test]
    fn ignores_frames_without_fragment_info() {
        let mut asm = FragmentAssembler::new();
        let frame = Frame {
            frame_id: 1,
            routing_mask: 0,
            payload: Bytes::from_static(b"x"),
            mac: [0u8; WIRE_MAC_LEN],
            fragment: None,
        };
        assert!(asm.submit(frame).is_none());
    }

    #[test]
    fn rejects_invalid_index_or_total_without_leaving_state() {
        let mut asm = FragmentAssembler::new();
        assert!(asm.submit(frag(1, b"z", 9, 0, 0)).is_none());
        assert!(asm.submit(frag(2, b"z", 9, 2, 2)).is_none());
        // Neither rejected frame may have opened an entry for sequence 9.
        let merged = asm.submit(frag(3, b"ok", 9, 0, 1)).expect("single fragment");
        assert_eq!(merged, b"ok");
    }

    #[test]
    fn duplicate_index_replaces_earlier_copy() {
        let mut asm = FragmentAssembler::new();
        assert!(asm.submit(frag(1, b"XX", 3, 0, 2)).is_none());
        assert!(asm.submit(frag(2, b"he", 3, 0, 2)).is_none());
        let merged = asm.submit(frag(3, b"y", 3, 1, 2)).expect("complete");
        assert_eq!(merged, b"hey");
    }

    #[test]
    fn completed_sequence_can_be_reused() {
        let mut asm = FragmentAssembler::new();
        assert_eq!(asm.submit(frag(1, b"a", 5, 0, 1)).expect("first"), b"a");
        assert!(asm.submit(frag(2, b"b", 5, 0, 2)).is_none());
        assert_eq!(asm.submit(frag(3, b"c", 5, 1, 2)).expect("second"), b"bc");
    }
}