use std::collections::BTreeMap;
use std::collections::btree_map::Entry;
/// Upper bound, in bytes, on what a single `ReassemblyBuffer` may hold at once
/// (undrained in-order bytes plus buffered out-of-order segments). An insert
/// that would push past this limit marks the buffer as capped. Equals 1 MiB.
const MAX_STREAM_BUFFER: usize = 1 << 20;
/// Outcome reported by [`ReassemblyBuffer::insert_segment`] for one segment.
#[derive(PartialEq, Eq, Debug)]
pub enum ReassemblyResult {
    /// New in-order bytes were appended to the output returned by `try_drain`.
    Contiguous,
    /// The segment starts after the next expected byte and was stored out of order.
    Gap,
    /// Currently returned by `insert_segment` only for an empty payload; nothing is stored.
    Buffered,
    /// The segment contributed no new bytes (already delivered or already buffered).
    Duplicate,
    /// The buffer's size limit was reached; the segment was not stored.
    Capped,
}
/// Reassembles byte segments, which may arrive out of order or repeated, into
/// one contiguous byte sequence addressed by 32-bit wrapping sequence numbers.
#[derive(Default, Debug)]
pub struct ReassemblyBuffer {
    /// In-order bytes that have not yet been handed out by `try_drain`.
    pending: Vec<u8>,
    /// Out-of-order segments received ahead of `next_seq`, keyed by their
    /// starting sequence number.
    segments: BTreeMap<u32, Vec<u8>>,
    /// Sequence number of the first non-empty segment; `None` until one arrives.
    base_seq: Option<u32>,
    /// Sequence number of the next byte expected in order.
    next_seq: u32,
    /// Bytes currently held in `pending` and `segments` combined; compared
    /// against `MAX_STREAM_BUFFER`.
    total_buffered: usize,
    /// Set by `mark_fin`.
    fin: bool,
    /// Set once an insert would have exceeded `MAX_STREAM_BUFFER`; never cleared.
    capped: bool,
}
/// Returns `true` when sequence number `a` comes before `b` modulo 2^32.
///
/// `a` precedes `b` exactly when `b` is between 1 and 2^31 positions ahead of
/// `a` in wrapping arithmetic, so comparisons keep working across the
/// `u32::MAX` -> `0` boundary. Equal values are not "less than" each other.
fn seq_lt(a: u32, b: u32) -> bool {
    let signed_gap = a.wrapping_sub(b) as i32;
    signed_gap.is_negative()
}
impl ReassemblyBuffer {
pub fn new() -> Self {
Self::default()
}
pub fn insert_segment(&mut self, seq: u32, data: &[u8]) -> ReassemblyResult {
if data.is_empty() {
return ReassemblyResult::Buffered;
}
if self.base_seq.is_none() {
self.base_seq = Some(seq);
self.next_seq = seq;
}
let mut seq = seq;
let mut data = data;
let end = seq.wrapping_add(data.len() as u32);
if !seq_lt(self.next_seq, end) {
return ReassemblyResult::Duplicate;
}
if seq_lt(seq, self.next_seq) {
let skip = self.next_seq.wrapping_sub(seq) as usize;
if skip >= data.len() {
return ReassemblyResult::Duplicate;
}
data = &data[skip..];
seq = self.next_seq;
}
if self.capped {
return ReassemblyResult::Capped;
}
if self.total_buffered + data.len() > MAX_STREAM_BUFFER {
self.capped = true;
return ReassemblyResult::Capped;
}
if seq == self.next_seq {
self.pending.extend_from_slice(data);
self.next_seq = self.next_seq.wrapping_add(data.len() as u32);
self.total_buffered += data.len();
while let Some(seg) = self.segments.remove(&self.next_seq) {
self.next_seq = self.next_seq.wrapping_add(seg.len() as u32);
self.pending.extend_from_slice(&seg);
}
ReassemblyResult::Contiguous
} else {
match self.segments.entry(seq) {
Entry::Vacant(v) => {
v.insert(data.to_vec());
self.total_buffered += data.len();
}
Entry::Occupied(_) => return ReassemblyResult::Duplicate,
}
ReassemblyResult::Gap
}
}
pub fn try_drain(&mut self) -> Option<Vec<u8>> {
if self.pending.is_empty() {
return None;
}
self.total_buffered -= self.pending.len();
Some(std::mem::take(&mut self.pending))
}
pub fn mark_fin(&mut self) {
self.fin = true;
}
pub fn is_complete(&self) -> bool {
self.fin && self.segments.is_empty()
}
pub fn buffered_len(&self) -> usize {
self.total_buffered
}
pub fn capped(&self) -> bool {
self.capped
}
}
/// One byte buffer per stream direction.
///
/// NOTE(review): nothing in this chunk populates it. Presumably it collects the
/// bytes drained from a `ReassemblyBuffer` for each direction; confirm with
/// callers.
#[derive(Clone, Default, Debug)]
pub struct StreamData {
    /// Bytes flowing from the client side to the server side.
    pub client_to_server: Vec<u8>,
    /// Bytes flowing from the server side to the client side.
    pub server_to_client: Vec<u8>,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Segments that arrive in order are concatenated and drained together.
    #[test]
    fn test_in_order() {
        let mut buf = ReassemblyBuffer::new();
        let first = buf.insert_segment(1000, b"hello");
        assert_eq!(first, ReassemblyResult::Contiguous);
        let second = buf.insert_segment(1005, b"world");
        assert_eq!(second, ReassemblyResult::Contiguous);
        assert_eq!(buf.try_drain().unwrap(), b"helloworld");
        assert!(buf.try_drain().is_none());
    }

    /// A segment past a hole is held back until the hole is filled.
    #[test]
    fn test_out_of_order() {
        let mut buf = ReassemblyBuffer::new();
        let head = buf.insert_segment(1000, b"AAA");
        assert_eq!(head, ReassemblyResult::Contiguous);
        let ahead = buf.insert_segment(1006, b"CCC");
        assert_eq!(ahead, ReassemblyResult::Gap);
        assert_eq!(buf.try_drain().unwrap(), b"AAA");
        assert!(buf.try_drain().is_none());
        let filler = buf.insert_segment(1003, b"BBB");
        assert_eq!(filler, ReassemblyResult::Contiguous);
        assert_eq!(buf.try_drain().unwrap(), b"BBBCCC");
    }

    /// Exact retransmissions are rejected; partial overlaps yield only new bytes.
    #[test]
    fn test_duplicate_and_overlap() {
        let mut buf = ReassemblyBuffer::new();
        buf.insert_segment(1000, b"hello");
        assert_eq!(buf.try_drain().unwrap(), b"hello");
        let resent = buf.insert_segment(1000, b"hello");
        assert_eq!(resent, ReassemblyResult::Duplicate);
        // Bytes 1003..1005 were already delivered; only the trailing "p" is new.
        let overlapping = buf.insert_segment(1003, b"lop");
        assert_eq!(overlapping, ReassemblyResult::Contiguous);
        assert_eq!(buf.try_drain().unwrap(), b"p");
    }

    /// Sequence numbers that wrap past u32::MAX still join up in order.
    #[test]
    fn test_seq_wrap() {
        let mut buf = ReassemblyBuffer::new();
        let start: u32 = 0xFFFF_FFFE;
        let before_wrap = buf.insert_segment(start, b"AB");
        assert_eq!(before_wrap, ReassemblyResult::Contiguous);
        let after_wrap = buf.insert_segment(start.wrapping_add(2), b"CD");
        assert_eq!(after_wrap, ReassemblyResult::Contiguous);
        assert_eq!(buf.try_drain().unwrap(), b"ABCD");
    }

    /// A payload larger than the limit trips the cap.
    #[test]
    fn test_cap() {
        let mut buf = ReassemblyBuffer::new();
        buf.insert_segment(0, b"x");
        buf.try_drain();
        let oversized = vec![0u8; MAX_STREAM_BUFFER + 1];
        let outcome = buf.insert_segment(1, &oversized);
        assert_eq!(outcome, ReassemblyResult::Capped);
        assert!(buf.capped());
    }

    /// Completion requires the FIN mark and no buffered out-of-order segments.
    #[test]
    fn test_fin_completion() {
        let mut buf = ReassemblyBuffer::new();
        buf.insert_segment(1, b"hi");
        buf.try_drain();
        assert!(!buf.is_complete());
        buf.mark_fin();
        assert!(buf.is_complete());
    }
}