use std::collections::hash_map::Entry;
use std::collections::HashMap;

use serde::{Deserialize, Serialize};
use thiserror::Error;
use uuid::Uuid;
/// Length in bytes of a fragment header: a 16-byte fragment id followed by two
/// `u16` fields (presumably `total` then `index`, mirroring [`FragmentShard`] —
/// NOTE(review): confirm against the wire encoder).
pub const FRAGMENT_HEADER_LEN: usize =
    std::mem::size_of::<[u8; 16]>() + 2 * std::mem::size_of::<u16>();
/// One piece of a payload that a `FragmentAssembler` reassembles, grouped by
/// `fragment_id`.
#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
pub struct FragmentShard {
    /// Identifies the payload this shard belongs to.
    pub fragment_id: Uuid,
    /// Number of shards the payload consists of; must be non-zero.
    pub total: u16,
    /// Zero-based position of this shard; must be below `total`.
    pub index: u16,
    /// This shard's bytes of the payload.
    pub payload: Vec<u8>,
}
#[derive(Debug, Error, PartialEq, Eq)]
pub enum FragmentError {
#[error("invalid fragment size")]
InvalidSize,
#[error("duplicate fragment {index}/{total}")]
Duplicate { index: u16, total: u16 },
#[error("unexpected total mismatch")]
TotalMismatch,
}
/// Buffers shards per fragment id and hands back the full payload once every
/// shard of a fragment has arrived.
#[derive(Debug, Default)]
pub struct FragmentAssembler {
    /// Partially received fragments, keyed by fragment id.
    inflight: HashMap<Uuid, Inflight>,
    /// Fragment ids in arrival order, oldest first; `enforce_retention`
    /// evicts from the front of this list.
    order: Vec<Uuid>,
}
/// Memory bound applied by `FragmentAssembler::enforce_retention`.
#[derive(Clone, Copy, Debug)]
pub struct FragmentRetention {
    /// Maximum number of fragment ids allowed to stay partially assembled;
    /// `enforce_retention` evicts the oldest ones above this count.
    pub max_inflight: usize,
}

impl Default for FragmentRetention {
    /// Permits up to 1024 in-flight fragment ids.
    fn default() -> Self {
        FragmentRetention { max_inflight: 1024 }
    }
}
/// Reassembly state for a single fragment id.
#[derive(Debug)]
struct Inflight {
    /// Shard count taken from the first shard seen for this id.
    total: u16,
    /// One slot per shard index; a slot is `Some` once that shard arrived.
    received: Vec<Option<Vec<u8>>>,
    /// Number of filled slots in `received`.
    seen: u16,
}
impl FragmentAssembler {
pub fn new() -> Self {
Self {
inflight: HashMap::new(),
order: Vec::new(),
}
}
pub fn push(&mut self, shard: FragmentShard) -> Result<Option<Vec<u8>>, FragmentError> {
if shard.total == 0 {
return Err(FragmentError::InvalidSize);
}
let entry = self
.inflight
.entry(shard.fragment_id)
.or_insert_with(|| Inflight::new(shard.total));
if entry.seen == 0 {
self.order.push(shard.fragment_id);
}
if entry.total != shard.total {
return Err(FragmentError::TotalMismatch);
}
if shard.index as usize >= entry.received.len() {
return Err(FragmentError::InvalidSize);
}
if entry.received[shard.index as usize].is_some() {
return Err(FragmentError::Duplicate {
index: shard.index,
total: shard.total,
});
}
entry.received[shard.index as usize] = Some(shard.payload);
entry.seen += 1;
if entry.seen == shard.total {
let mut buf = Vec::new();
for part in entry.received.iter_mut() {
if let Some(mut chunk) = part.take() {
buf.append(&mut chunk);
}
}
self.inflight.remove(&shard.fragment_id);
self.order.retain(|id| id != &shard.fragment_id);
Ok(Some(buf))
} else {
Ok(None)
}
}
pub fn inflight(&self) -> usize {
self.inflight.len()
}
pub fn enforce_retention(&mut self, retention: FragmentRetention) {
if self.inflight.len() <= retention.max_inflight {
return;
}
while self.inflight.len() > retention.max_inflight {
if let Some(oldest) = self.order.first().cloned() {
self.inflight.remove(&oldest);
self.order.remove(0);
} else {
break;
}
}
}
}
impl Inflight {
fn new(total: u16) -> Self {
Self {
total,
received: vec![None; total as usize],
seen: 0,
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a shard so each test reads as a short sequence of pushes.
    fn shard(fragment_id: Uuid, total: u16, index: u16, payload: &[u8]) -> FragmentShard {
        FragmentShard {
            fragment_id,
            total,
            index,
            payload: payload.to_vec(),
        }
    }

    #[test]
    fn assembles_in_order() {
        let id = Uuid::new_v4();
        let mut assembler = FragmentAssembler::new();
        assert!(assembler.push(shard(id, 3, 0, b"hel")).unwrap().is_none());
        assert!(assembler.push(shard(id, 3, 1, b"lo ")).unwrap().is_none());
        let assembled = assembler
            .push(shard(id, 3, 2, b"world"))
            .unwrap()
            .expect("should assemble");
        assert_eq!(assembled, b"hello world");
        assert_eq!(assembler.inflight(), 0);
    }

    #[test]
    fn assembles_out_of_order() {
        let id = Uuid::new_v4();
        let mut assembler = FragmentAssembler::new();
        assert!(assembler.push(shard(id, 2, 1, b"world")).unwrap().is_none());
        let assembled = assembler
            .push(shard(id, 2, 0, b"hello "))
            .unwrap()
            .expect("should assemble");
        assert_eq!(assembled, b"hello world");
    }

    #[test]
    fn rejects_duplicates() {
        let id = Uuid::new_v4();
        let mut assembler = FragmentAssembler::new();
        let first = shard(id, 2, 0, b"abc");
        assert!(assembler.push(first.clone()).unwrap().is_none());
        assert_eq!(
            assembler.push(first).unwrap_err(),
            FragmentError::Duplicate { index: 0, total: 2 }
        );
        let assembled = assembler
            .push(shard(id, 2, 1, b"xyz"))
            .unwrap()
            .expect("should assemble");
        assert_eq!(assembled, b"abcxyz");
    }

    #[test]
    fn mismatched_total_is_error() {
        let id = Uuid::new_v4();
        let mut assembler = FragmentAssembler::new();
        assert!(assembler.push(shard(id, 2, 0, &[])).is_ok());
        assert_eq!(
            assembler.push(shard(id, 3, 1, &[])).unwrap_err(),
            FragmentError::TotalMismatch
        );
    }

    #[test]
    fn retention_evicts_oldest() {
        let mut assembler = FragmentAssembler::new();
        let ids: Vec<Uuid> = (0..3).map(|_| Uuid::new_v4()).collect();
        for &id in &ids {
            let _ = assembler.push(shard(id, 2, 0, b"x")).unwrap();
        }
        assembler.enforce_retention(FragmentRetention { max_inflight: 2 });
        assert!(assembler.inflight.contains_key(&ids[1]));
        assert!(assembler.inflight.contains_key(&ids[2]));
        assert!(!assembler.inflight.contains_key(&ids[0]));
    }
}