use std::{
cmp::Ordering,
collections::{BinaryHeap, binary_heap::PeekMut},
mem,
};
use bytes::{Buf, Bytes, BytesMut};
use crate::range_set::ArrayRangeSet;
/// Reassembles received stream data, buffering out-of-order chunks in a heap
/// until they can be delivered to the application.
#[derive(Debug, Default)]
pub(super) struct Assembler {
    // Ordered vs. unordered delivery mode (plus received-range tracking when
    // unordered); see `State`.
    state: State,
    // Buffered chunks; `Buffer`'s `Ord` is reversed on offset, so this
    // max-heap behaves as a min-heap keyed on stream offset.
    data: BinaryHeap<Buffer>,
    // Total payload bytes currently held in `data`.
    buffered: usize,
    // Total size of the allocations backing the chunks in `data`
    // (>= `buffered`; the difference is tracked over-allocation).
    allocated: usize,
    // Number of stream bytes already handed to the application.
    bytes_read: u64,
    // Highest stream offset seen so far (exclusive end of received data).
    end: u64,
}
impl Assembler {
    /// Creates an empty assembler in the default (ordered) mode.
    pub(super) fn new() -> Self {
        Self::default()
    }

    /// Resets all state for reuse while keeping the heap's allocation.
    pub(super) fn reinit(&mut self) {
        // Steal the heap before overwriting `self` so its capacity survives
        // the reset; `clear()` then drops any leftover chunks.
        let old_data = mem::take(&mut self.data);
        *self = Self::default();
        self.data = old_data;
        self.data.clear();
    }

    /// Switches between ordered and unordered delivery.
    ///
    /// # Errors
    /// Returns `IllegalOrderedRead` if an ordered read is requested after the
    /// stream has already been switched to unordered delivery; that
    /// transition is not supported.
    pub(super) fn ensure_ordering(&mut self, ordered: bool) -> Result<(), IllegalOrderedRead> {
        if ordered && !self.state.is_ordered() {
            return Err(IllegalOrderedRead);
        } else if !ordered && self.state.is_ordered() {
            // Compact first so buffered chunks don't overlap and the range
            // set built below exactly describes the received data.
            if !self.data.is_empty() {
                self.defragment();
            }
            let mut recvd = ArrayRangeSet::new();
            // Everything already consumed by the application counts as
            // received.
            recvd.insert(0..self.bytes_read);
            for chunk in &self.data {
                recvd.insert(chunk.offset..chunk.offset + chunk.bytes.len() as u64);
            }
            self.state = State::Unordered { recvd };
        }
        Ok(())
    }

    /// Returns the next available chunk of at most `max_length` bytes, or
    /// `None` if nothing is currently deliverable.
    ///
    /// With `ordered == true`, only data starting exactly at `bytes_read` is
    /// returned; unordered reads hand back whichever buffered chunk has the
    /// lowest offset.
    pub(super) fn read(&mut self, max_length: usize, ordered: bool) -> Option<Chunk> {
        loop {
            let mut chunk = self.data.peek_mut()?;
            if ordered {
                if chunk.offset > self.bytes_read {
                    // Gap before the next chunk; nothing deliverable yet.
                    return None;
                } else if (chunk.offset + chunk.bytes.len() as u64) <= self.bytes_read {
                    // Chunk lies entirely before the read index: discard it
                    // and release its accounting.
                    self.buffered -= chunk.bytes.len();
                    self.allocated -= chunk.allocation_size;
                    PeekMut::pop(chunk);
                    continue;
                }
                // Trim the prefix that was already read.
                let start = (self.bytes_read - chunk.offset) as usize;
                if start > 0 {
                    chunk.bytes.advance(start);
                    chunk.offset += start as u64;
                    self.buffered -= start;
                }
            }
            return Some(if max_length < chunk.bytes.len() {
                // Split the chunk: hand out the front `max_length` bytes and
                // keep the remainder buffered (allocation stays accounted).
                self.bytes_read += max_length as u64;
                let offset = chunk.offset;
                chunk.offset += max_length as u64;
                self.buffered -= max_length;
                Chunk::new(offset, chunk.bytes.split_to(max_length))
            } else {
                // Consume the whole chunk and release its accounting.
                self.bytes_read += chunk.bytes.len() as u64;
                self.buffered -= chunk.bytes.len();
                self.allocated -= chunk.allocation_size;
                let chunk = PeekMut::pop(chunk);
                Chunk::new(chunk.offset, chunk.bytes)
            });
        }
    }

    /// Copies fragmented chunk data into contiguous, exact-sized buffers,
    /// releasing the over-sized allocations they previously pointed into.
    fn defragment(&mut self) {
        let new = BinaryHeap::with_capacity(self.data.len());
        let old = mem::replace(&mut self.data, new);
        // `Buffer`'s `Ord` is reversed, so the sorted vec is descending by
        // offset; the `.rev()` iterations below walk in ascending order.
        let mut buffers = old.into_sorted_vec();
        self.buffered = 0;
        let mut fragmented_buffered = 0;
        // First pass (ascending offset): trim duplicated prefixes, recompute
        // `buffered`, and total the data still living in fragmented buffers.
        let mut offset = self.bytes_read;
        for chunk in buffers.iter_mut().rev() {
            chunk.try_mark_defragment(offset);
            let size = chunk.bytes.len();
            offset = chunk.offset + size as u64;
            self.buffered += size;
            if !chunk.defragmented {
                fragmented_buffered += size;
            }
        }
        // After compaction every surviving buffer is exact-sized, so the
        // allocation total equals the payload total.
        self.allocated = self.buffered;
        let mut buffer = BytesMut::with_capacity(fragmented_buffered);
        let mut offset = self.bytes_read;
        // Second pass (ascending offset): coalesce adjacent fragmented
        // chunks into `buffer`, flushing whenever an offset gap appears.
        for chunk in buffers.into_iter().rev() {
            if chunk.defragmented {
                // Already compact; keep as-is (drop if emptied above).
                if !chunk.bytes.is_empty() {
                    self.data.push(chunk);
                }
                continue;
            }
            if chunk.offset != offset + (buffer.len() as u64) {
                // Gap: flush the accumulated contiguous run first.
                if !buffer.is_empty() {
                    self.data
                        .push(Buffer::new_defragmented(offset, buffer.split().freeze()));
                }
                offset = chunk.offset;
            }
            buffer.extend_from_slice(&chunk.bytes);
        }
        if !buffer.is_empty() {
            self.data
                .push(Buffer::new_defragmented(offset, buffer.split().freeze()));
        }
    }

    /// Buffers `bytes` received at stream `offset`.
    ///
    /// `allocation_size` is the size of the receive buffer backing `bytes`;
    /// it feeds the over-allocation accounting that triggers defragmentation.
    pub(super) fn insert(&mut self, mut offset: u64, mut bytes: Bytes, allocation_size: usize) {
        debug_assert!(
            bytes.len() <= allocation_size,
            "allocation_size less than bytes.len(): {:?} < {:?}",
            allocation_size,
            bytes.len()
        );
        self.end = self.end.max(offset + bytes.len() as u64);
        if let State::Unordered { ref mut recvd } = self.state {
            // Unordered mode: drop any spans that were already received.
            let range = offset..offset + bytes.len() as u64;
            for duplicate in recvd.iter_range(range.clone()) {
                if duplicate.start > offset {
                    // Buffer the novel prefix preceding this duplicate span.
                    let buffer = Buffer::new(
                        offset,
                        bytes.split_to((duplicate.start - offset) as usize),
                        allocation_size,
                    );
                    self.buffered += buffer.bytes.len();
                    self.allocated += buffer.allocation_size;
                    self.data.push(buffer);
                    offset = duplicate.start;
                }
                // Skip past the already-received span.
                bytes.advance((duplicate.end - offset) as usize);
                offset = duplicate.end;
            }
            recvd.insert(range);
        } else if offset < self.bytes_read {
            if (offset + bytes.len() as u64) <= self.bytes_read {
                // Entirely stale: everything here was already read.
                return;
            } else {
                // Trim the prefix that was already read.
                let diff = self.bytes_read - offset;
                offset += diff;
                bytes.advance(diff as usize);
            }
        }
        if bytes.is_empty() {
            return;
        }
        let buffer = Buffer::new(offset, bytes, allocation_size);
        self.buffered += buffer.bytes.len();
        self.allocated += buffer.allocation_size;
        self.data.push(buffer);
        // `buffered` can overcount when duplicates are buffered; clamp to
        // the true amount of outstanding (unread) stream data.
        let buffered = self.buffered.min((self.end - self.bytes_read) as usize);
        let over_allocation = self.allocated - buffered;
        // Defragment once wasted allocation exceeds 1.5x the buffered data,
        // with a 32 KiB floor to avoid compacting small amounts.
        let threshold = 32768.max(buffered * 3 / 2);
        if over_allocation > threshold {
            self.defragment()
        }
    }

    /// Number of stream bytes consumed by the application so far.
    pub(super) fn bytes_read(&self) -> u64 {
        self.bytes_read
    }

    /// Discards all buffered data and resets memory accounting.
    /// Note: `bytes_read` and `end` are intentionally left untouched.
    pub(super) fn clear(&mut self) {
        self.data.clear();
        self.buffered = 0;
        self.allocated = 0;
    }
}
/// A contiguous run of stream data handed back to the reader.
#[derive(Debug, PartialEq, Eq)]
pub struct Chunk {
    /// Stream offset of the first byte in `bytes`.
    pub offset: u64,
    /// The chunk's payload.
    pub bytes: Bytes,
}
impl Chunk {
fn new(offset: u64, bytes: Bytes) -> Self {
Self { offset, bytes }
}
}
/// A buffered chunk of received data held in the assembler's heap.
#[derive(Debug, Eq)]
struct Buffer {
    // Stream offset of the first byte in `bytes`.
    offset: u64,
    // The buffered payload.
    bytes: Bytes,
    // Size of the allocation backing `bytes` (>= `bytes.len()`).
    allocation_size: usize,
    // Whether this buffer is already compact (exact-sized allocation) and
    // thus skipped by defragmentation.
    defragmented: bool,
}
impl Buffer {
    /// Constructs a (potentially fragmented) buffer that records the size of
    /// the allocation backing `bytes`.
    fn new(offset: u64, bytes: Bytes, allocation_size: usize) -> Self {
        Self {
            defragmented: false,
            offset,
            bytes,
            allocation_size,
        }
    }

    /// Constructs a buffer backed by an exact-sized allocation, as produced
    /// by defragmentation.
    fn new_defragmented(offset: u64, bytes: Bytes) -> Self {
        Self {
            allocation_size: bytes.len(),
            defragmented: true,
            offset,
            bytes,
        }
    }

    /// Drops any data before `offset` and marks the buffer as defragmented
    /// when its remaining payload wastes little of its allocation.
    fn try_mark_defragment(&mut self, offset: u64) {
        let duplicate = offset.saturating_sub(self.offset) as usize;
        self.offset = self.offset.max(offset);
        if duplicate >= self.bytes.len() {
            // Fully duplicated: drop the payload and its accounting entirely.
            self.bytes = Bytes::new();
            self.allocation_size = 0;
            self.defragmented = true;
        } else {
            self.bytes.advance(duplicate);
            // Count as compact when the payload covers at least ~5/6 of the
            // backing allocation (len * 6/5 >= allocation).
            if !self.defragmented && self.bytes.len() * 6 / 5 >= self.allocation_size {
                self.defragmented = true;
            }
            if self.defragmented {
                // Compact buffers account exactly for their payload.
                self.allocation_size = self.bytes.len();
            }
        }
    }
}
impl Ord for Buffer {
    /// Reversed ordering on offset so the containing max-heap surfaces the
    /// lowest offset first; ties are broken by payload length (ascending).
    fn cmp(&self, other: &Self) -> Ordering {
        other
            .offset
            .cmp(&self.offset)
            .then_with(|| self.bytes.len().cmp(&other.bytes.len()))
    }
}
impl PartialOrd for Buffer {
    /// Defers to the total ordering defined by `Ord`.
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(Ord::cmp(self, other))
    }
}
impl PartialEq for Buffer {
    /// Buffers are equal when offset and payload length agree, mirroring the
    /// fields compared by `Ord`; payload contents are not examined.
    fn eq(&self, other: &Self) -> bool {
        self.offset == other.offset && self.bytes.len() == other.bytes.len()
    }
}
/// Delivery mode of the assembler.
#[derive(Debug, Default)]
enum State {
    /// Data is delivered in stream order; stale data is trimmed on read.
    #[default]
    Ordered,
    /// Data may be delivered out of order; `recvd` records which stream
    /// ranges have already been received so duplicates are dropped on insert.
    Unordered {
        recvd: ArrayRangeSet,
    },
}
impl State {
    /// Whether the assembler is still in ordered-delivery mode.
    fn is_ordered(&self) -> bool {
        match self {
            Self::Ordered => true,
            Self::Unordered { .. } => false,
        }
    }
}
/// Error returned when an ordered read is requested on a stream that has
/// already been switched to unordered delivery.
#[derive(Debug)]
pub(crate) struct IllegalOrderedRead;
#[cfg(test)]
mod test {
    use super::*;
    use assert_matches::assert_matches;

    // In-order inserts and reads, including partial reads via `max_length`.
    #[test]
    fn assemble_ordered() {
        let mut x = Assembler::new();
        assert_matches!(next(&mut x, 32), None);
        x.insert(0, Bytes::from_static(b"123"), 3);
        assert_matches!(next(&mut x, 1), Some(ref y) if &y[..] == b"1");
        assert_matches!(next(&mut x, 3), Some(ref y) if &y[..] == b"23");
        x.insert(3, Bytes::from_static(b"456"), 3);
        assert_matches!(next(&mut x, 32), Some(ref y) if &y[..] == b"456");
        x.insert(6, Bytes::from_static(b"789"), 3);
        x.insert(9, Bytes::from_static(b"10"), 2);
        assert_matches!(next(&mut x, 32), Some(ref y) if &y[..] == b"789");
        assert_matches!(next(&mut x, 32), Some(ref y) if &y[..] == b"10");
        assert_matches!(next(&mut x, 32), None);
    }

    // A chunk inserted past a gap becomes readable once the gap is filled.
    #[test]
    fn assemble_unordered() {
        let mut x = Assembler::new();
        x.ensure_ordering(false).unwrap();
        x.insert(3, Bytes::from_static(b"456"), 3);
        assert_matches!(next(&mut x, 32), None);
        x.insert(0, Bytes::from_static(b"123"), 3);
        assert_matches!(next(&mut x, 32), Some(ref y) if &y[..] == b"123");
        assert_matches!(next(&mut x, 32), Some(ref y) if &y[..] == b"456");
        assert_matches!(next(&mut x, 32), None);
    }

    // An exact duplicate insert is delivered only once.
    #[test]
    fn assemble_duplicate() {
        let mut x = Assembler::new();
        x.insert(0, Bytes::from_static(b"123"), 3);
        x.insert(0, Bytes::from_static(b"123"), 3);
        assert_matches!(next(&mut x, 32), Some(ref y) if &y[..] == b"123");
        assert_matches!(next(&mut x, 32), None);
    }

    // Same as above, but with an explicit defragment pass in between.
    #[test]
    fn assemble_duplicate_compact() {
        let mut x = Assembler::new();
        x.insert(0, Bytes::from_static(b"123"), 3);
        x.insert(0, Bytes::from_static(b"123"), 3);
        x.defragment();
        assert_matches!(next(&mut x, 32), Some(ref y) if &y[..] == b"123");
        assert_matches!(next(&mut x, 32), None);
    }

    // A chunk entirely contained in an earlier chunk adds nothing.
    #[test]
    fn assemble_contained() {
        let mut x = Assembler::new();
        x.insert(0, Bytes::from_static(b"12345"), 5);
        x.insert(1, Bytes::from_static(b"234"), 3);
        assert_matches!(next(&mut x, 32), Some(ref y) if &y[..] == b"12345");
        assert_matches!(next(&mut x, 32), None);
    }

    // Contained-chunk case with an explicit defragment pass.
    #[test]
    fn assemble_contained_compact() {
        let mut x = Assembler::new();
        x.insert(0, Bytes::from_static(b"12345"), 5);
        x.insert(1, Bytes::from_static(b"234"), 3);
        x.defragment();
        assert_matches!(next(&mut x, 32), Some(ref y) if &y[..] == b"12345");
        assert_matches!(next(&mut x, 32), None);
    }

    // A later chunk that fully contains an earlier one supersedes it.
    #[test]
    fn assemble_contains() {
        let mut x = Assembler::new();
        x.insert(1, Bytes::from_static(b"234"), 3);
        x.insert(0, Bytes::from_static(b"12345"), 5);
        assert_matches!(next(&mut x, 32), Some(ref y) if &y[..] == b"12345");
        assert_matches!(next(&mut x, 32), None);
    }

    // Containing-chunk case with an explicit defragment pass.
    #[test]
    fn assemble_contains_compact() {
        let mut x = Assembler::new();
        x.insert(1, Bytes::from_static(b"234"), 3);
        x.insert(0, Bytes::from_static(b"12345"), 5);
        x.defragment();
        assert_matches!(next(&mut x, 32), Some(ref y) if &y[..] == b"12345");
        assert_matches!(next(&mut x, 32), None);
    }

    // Partially overlapping chunks: overlap is trimmed at read time, so the
    // second read yields only the novel suffix.
    #[test]
    fn assemble_overlapping() {
        let mut x = Assembler::new();
        x.insert(0, Bytes::from_static(b"123"), 3);
        x.insert(1, Bytes::from_static(b"234"), 3);
        assert_matches!(next(&mut x, 32), Some(ref y) if &y[..] == b"123");
        assert_matches!(next(&mut x, 32), Some(ref y) if &y[..] == b"4");
        assert_matches!(next(&mut x, 32), None);
    }

    // After defragmentation the overlapping chunks coalesce into one read.
    #[test]
    fn assemble_overlapping_compact() {
        let mut x = Assembler::new();
        x.insert(0, Bytes::from_static(b"123"), 4);
        x.insert(1, Bytes::from_static(b"234"), 4);
        x.defragment();
        assert_matches!(next(&mut x, 32), Some(ref y) if &y[..] == b"1234");
        assert_matches!(next(&mut x, 32), None);
    }

    // Scattered single-byte chunks superseded by one covering chunk.
    #[test]
    fn assemble_complex() {
        let mut x = Assembler::new();
        x.insert(0, Bytes::from_static(b"1"), 1);
        x.insert(2, Bytes::from_static(b"3"), 1);
        x.insert(4, Bytes::from_static(b"5"), 1);
        x.insert(0, Bytes::from_static(b"123456"), 6);
        assert_matches!(next(&mut x, 32), Some(ref y) if &y[..] == b"123456");
        assert_matches!(next(&mut x, 32), None);
    }

    // Complex-overlap case with an explicit defragment pass.
    #[test]
    fn assemble_complex_compact() {
        let mut x = Assembler::new();
        x.insert(0, Bytes::from_static(b"1"), 1);
        x.insert(2, Bytes::from_static(b"3"), 1);
        x.insert(4, Bytes::from_static(b"5"), 1);
        x.insert(0, Bytes::from_static(b"123456"), 6);
        x.defragment();
        assert_matches!(next(&mut x, 32), Some(ref y) if &y[..] == b"123456");
        assert_matches!(next(&mut x, 32), None);
    }

    // Reinserting data that was already read yields nothing.
    #[test]
    fn assemble_old() {
        let mut x = Assembler::new();
        x.insert(0, Bytes::from_static(b"1234"), 4);
        assert_matches!(next(&mut x, 32), Some(ref y) if &y[..] == b"1234");
        x.insert(0, Bytes::from_static(b"1234"), 4);
        assert_matches!(next(&mut x, 32), None);
    }

    // Defragmentation merges adjacent chunks but preserves the gap at 6..9.
    #[test]
    fn compact() {
        let mut x = Assembler::new();
        x.insert(0, Bytes::from_static(b"abc"), 4);
        x.insert(3, Bytes::from_static(b"def"), 4);
        x.insert(9, Bytes::from_static(b"jkl"), 4);
        x.insert(12, Bytes::from_static(b"mno"), 4);
        x.defragment();
        assert_eq!(
            next_unordered(&mut x),
            Chunk::new(0, Bytes::from_static(b"abcdef"))
        );
        assert_eq!(
            next_unordered(&mut x),
            Chunk::new(9, Bytes::from_static(b"jklmno"))
        );
    }

    // Defragmenting with unreceived data before the first chunk keeps the
    // chunk's offset intact.
    #[test]
    fn defrag_with_missing_prefix() {
        let mut x = Assembler::new();
        x.insert(3, Bytes::from_static(b"def"), 3);
        x.defragment();
        assert_eq!(
            next_unordered(&mut x),
            Chunk::new(3, Bytes::from_static(b"def"))
        );
    }

    // Interleaves defragmentation with ordered reads and overlapping inserts.
    #[test]
    fn defrag_read_chunk() {
        let mut x = Assembler::new();
        x.insert(3, Bytes::from_static(b"def"), 4);
        x.insert(0, Bytes::from_static(b"abc"), 4);
        x.insert(7, Bytes::from_static(b"hij"), 4);
        x.insert(11, Bytes::from_static(b"lmn"), 4);
        x.defragment();
        assert_matches!(x.read(usize::MAX, true), Some(ref y) if &y.bytes[..] == b"abcdef");
        x.insert(5, Bytes::from_static(b"fghijklmn"), 9);
        assert_matches!(x.read(usize::MAX, true), Some(ref y) if &y.bytes[..] == b"ghijklmn");
        x.insert(13, Bytes::from_static(b"nopq"), 4);
        assert_matches!(x.read(usize::MAX, true), Some(ref y) if &y.bytes[..] == b"opq");
        x.insert(15, Bytes::from_static(b"pqrs"), 4);
        assert_matches!(x.read(usize::MAX, true), Some(ref y) if &y.bytes[..] == b"rs");
        assert_matches!(x.read(usize::MAX, true), None);
    }

    // Basic unordered reads: each insert is returned exactly once.
    #[test]
    fn unordered_happy_path() {
        let mut x = Assembler::new();
        x.ensure_ordering(false).unwrap();
        x.insert(0, Bytes::from_static(b"abc"), 3);
        assert_eq!(
            next_unordered(&mut x),
            Chunk::new(0, Bytes::from_static(b"abc"))
        );
        assert_eq!(x.read(usize::MAX, false), None);
        x.insert(3, Bytes::from_static(b"def"), 3);
        assert_eq!(
            next_unordered(&mut x),
            Chunk::new(3, Bytes::from_static(b"def"))
        );
        assert_eq!(x.read(usize::MAX, false), None);
    }

    // Unordered mode tracks received ranges, so re-inserted or overlapping
    // data is split down to only its novel spans.
    #[test]
    fn unordered_dedup() {
        let mut x = Assembler::new();
        x.ensure_ordering(false).unwrap();
        x.insert(3, Bytes::from_static(b"def"), 3);
        assert_eq!(
            next_unordered(&mut x),
            Chunk::new(3, Bytes::from_static(b"def"))
        );
        assert_eq!(x.read(usize::MAX, false), None);
        x.insert(0, Bytes::from_static(b"a"), 1);
        x.insert(0, Bytes::from_static(b"abcdefghi"), 9);
        x.insert(0, Bytes::from_static(b"abcd"), 4);
        assert_eq!(
            next_unordered(&mut x),
            Chunk::new(0, Bytes::from_static(b"a"))
        );
        assert_eq!(
            next_unordered(&mut x),
            Chunk::new(1, Bytes::from_static(b"bc"))
        );
        assert_eq!(
            next_unordered(&mut x),
            Chunk::new(6, Bytes::from_static(b"ghi"))
        );
        assert_eq!(x.read(usize::MAX, false), None);
        x.insert(8, Bytes::from_static(b"ijkl"), 4);
        assert_eq!(
            next_unordered(&mut x),
            Chunk::new(9, Bytes::from_static(b"jkl"))
        );
        assert_eq!(x.read(usize::MAX, false), None);
        x.insert(12, Bytes::from_static(b"mno"), 3);
        assert_eq!(
            next_unordered(&mut x),
            Chunk::new(12, Bytes::from_static(b"mno"))
        );
        assert_eq!(x.read(usize::MAX, false), None);
        // Entirely within already-received ranges: nothing new to read.
        x.insert(2, Bytes::from_static(b"cde"), 3);
        assert_eq!(x.read(usize::MAX, false), None);
    }

    // Ordered reads also dedup: stale and overlapping inserts only yield
    // their novel portions.
    #[test]
    fn chunks_dedup() {
        let mut x = Assembler::new();
        x.insert(3, Bytes::from_static(b"def"), 3);
        assert_eq!(x.read(usize::MAX, true), None);
        x.insert(0, Bytes::from_static(b"a"), 1);
        x.insert(1, Bytes::from_static(b"bcdefghi"), 9);
        x.insert(0, Bytes::from_static(b"abcd"), 4);
        assert_eq!(
            x.read(usize::MAX, true),
            Some(Chunk::new(0, Bytes::from_static(b"abcd")))
        );
        assert_eq!(
            x.read(usize::MAX, true),
            Some(Chunk::new(4, Bytes::from_static(b"efghi")))
        );
        assert_eq!(x.read(usize::MAX, true), None);
        x.insert(8, Bytes::from_static(b"ijkl"), 4);
        assert_eq!(
            x.read(usize::MAX, true),
            Some(Chunk::new(9, Bytes::from_static(b"jkl")))
        );
        assert_eq!(x.read(usize::MAX, true), None);
        x.insert(12, Bytes::from_static(b"mno"), 3);
        assert_eq!(
            x.read(usize::MAX, true),
            Some(Chunk::new(12, Bytes::from_static(b"mno")))
        );
        assert_eq!(x.read(usize::MAX, true), None);
        x.insert(2, Bytes::from_static(b"cde"), 3);
        assert_eq!(x.read(usize::MAX, true), None);
    }

    // Fully-stale inserts are discarded eagerly (heap stays empty); a
    // partially-stale insert is trimmed to its novel suffix on insert.
    // Note: `Buffer` equality compares only offset and payload length.
    #[test]
    fn ordered_eager_discard() {
        let mut x = Assembler::new();
        x.insert(0, Bytes::from_static(b"abc"), 3);
        assert_eq!(x.data.len(), 1);
        assert_eq!(
            x.read(usize::MAX, true),
            Some(Chunk::new(0, Bytes::from_static(b"abc")))
        );
        x.insert(0, Bytes::from_static(b"ab"), 2);
        assert_eq!(x.data.len(), 0);
        x.insert(2, Bytes::from_static(b"cd"), 2);
        assert_eq!(
            x.data.peek(),
            Some(&Buffer::new(3, Bytes::from_static(b"d"), 2))
        );
    }

    // Duplicates inserted while ordered are still deduped after switching to
    // unordered reads.
    #[test]
    fn ordered_insert_unordered_read() {
        let mut x = Assembler::new();
        x.insert(0, Bytes::from_static(b"abc"), 3);
        x.insert(0, Bytes::from_static(b"abc"), 3);
        x.ensure_ordering(false).unwrap();
        assert_eq!(
            x.read(3, false),
            Some(Chunk::new(0, Bytes::from_static(b"abc")))
        );
        assert_eq!(x.read(3, false), None);
    }

    // Data read before the mode switch must not be re-delivered after it.
    #[test]
    fn no_duplicate_after_mode_switch() {
        let mut x = Assembler::new();
        x.insert(0, Bytes::from_static(b"a"), 1);
        x.insert(0, Bytes::from_static(b"a"), 1);
        assert_eq!(
            x.read(1, true),
            Some(Chunk::new(0, Bytes::from_static(b"a")))
        );
        x.ensure_ordering(false).unwrap();
        assert_eq!(x.read(1, false), None);
    }

    // Reads the next chunk in unordered mode, panicking if none is available.
    fn next_unordered(x: &mut Assembler) -> Chunk {
        x.read(usize::MAX, false).unwrap()
    }

    // Reads up to `size` bytes in ordered mode, returning just the payload.
    fn next(x: &mut Assembler, size: usize) -> Option<Bytes> {
        x.read(size, true).map(|chunk| chunk.bytes)
    }
}
#[cfg(all(test, not(target_family = "wasm")))]
mod proptests {
    use proptest::prelude::*;
    use rand::RngExt;
    use test_strategy::{Arbitrary, proptest};

    use super::*;

    // Bounds for generated stream offsets and chunk lengths.
    const MAX_OFFSET: u64 = 512;
    const MAX_LEN: usize = 64;

    /// A single randomly generated operation applied to the assembler.
    /// Inserts and reads dominate; mode switches and explicit
    /// defragmentation are rare.
    #[derive(Debug, Clone, Arbitrary)]
    enum Op {
        #[weight(10)]
        Insert {
            #[strategy(0..MAX_OFFSET)]
            offset: u64,
            #[strategy(1..MAX_LEN)]
            len: usize,
        },
        #[weight(10)]
        Read {
            #[strategy(1..MAX_LEN)]
            max_len: usize,
        },
        #[weight(1)]
        EnsureOrdering { ordered: bool },
        #[weight(1)]
        Defragment,
    }

    /// Reference model: one bit per stream byte for "received" and
    /// "returned to the reader", plus the current delivery mode.
    struct RefState {
        received: Vec<bool>,
        returned: Vec<bool>,
        ordered: bool,
    }

    // Marks `len` bits starting at `start`, clamped to the bitmap's length.
    fn set_range(bits: &mut [bool], start: u64, len: usize) {
        for i in start..(start + len as u64).min(bits.len() as u64) {
            bits[i as usize] = true;
        }
    }

    impl RefState {
        fn new() -> Self {
            Self {
                received: vec![false; MAX_OFFSET as usize],
                returned: vec![false; MAX_OFFSET as usize],
                ordered: true,
            }
        }

        fn insert(&mut self, offset: u64, len: usize) {
            set_range(&mut self.received, offset, len);
        }

        // Mirrors `Assembler::ensure_ordering`: switching back to ordered
        // after going unordered must be rejected.
        fn ensure_ordering(&mut self, ordered: bool) -> bool {
            if ordered && !self.ordered {
                return false;
            }
            self.ordered = ordered;
            true
        }

        fn bytes_read(&self) -> u64 {
            self.returned.iter().filter(|&&x| x).count() as u64
        }
    }

    // Deterministic pseudo-random stream contents: chunk payloads can be
    // validated purely from their offsets.
    fn make_data() -> Vec<u8> {
        use rand::SeedableRng;
        let mut rng = rand::rngs::StdRng::seed_from_u64(0xDEADBEEF);
        let mut data = vec![0u8; MAX_OFFSET as usize];
        rng.fill(data.as_mut_slice());
        data
    }

    // Slice of the reference data at `offset`, clamped to the data's end.
    fn get_slice(data: &[u8], offset: u64, len: usize) -> Bytes {
        let start = offset as usize;
        let end = (start + len).min(data.len());
        Bytes::copy_from_slice(&data[start..end])
    }

    // A chunk is valid iff its bytes match the reference data at its offset.
    fn verify_chunk(data: &[u8], chunk: &Chunk) -> bool {
        let start = chunk.offset as usize;
        chunk.bytes[..] == data[start..start + chunk.bytes.len()]
    }

    // Drives the assembler with a random op sequence and checks every
    // observable result against the bitmap reference model.
    #[proptest]
    fn assembler_matches_reference(
        #[strategy(proptest::collection::vec(any::<Op>(), 1..100))] ops: Vec<Op>,
    ) {
        let data = make_data();
        let mut asm = Assembler::new();
        let mut reference = RefState::new();
        for op in ops {
            match op {
                Op::Insert { offset, len } => {
                    let bytes = get_slice(&data, offset, len);
                    asm.insert(offset, bytes, len);
                    reference.insert(offset, len);
                }
                Op::Read { max_len } => {
                    let ordered = reference.ordered;
                    let actual = asm.read(max_len, ordered);
                    match actual {
                        None => {
                            // `None` is only acceptable when the model also
                            // has nothing deliverable.
                            let has_available = if ordered {
                                // Ordered: deliverable iff the first
                                // not-yet-returned byte has been received.
                                reference
                                    .returned
                                    .iter()
                                    .position(|&x| !x)
                                    .is_some_and(|pos| reference.received[pos])
                            } else {
                                // Unordered: deliverable iff any received
                                // byte has not yet been returned.
                                reference
                                    .received
                                    .iter()
                                    .zip(&reference.returned)
                                    .any(|(&r, &ret)| r && !ret)
                            };
                            prop_assert!(
                                !has_available,
                                "read returned None but data was available"
                            );
                        }
                        Some(chunk) => {
                            prop_assert!(chunk.bytes.len() <= max_len, "chunk exceeds max_len");
                            prop_assert!(verify_chunk(&data, &chunk), "data corruption");
                            // Every returned byte must have been received and
                            // must not have been returned before.
                            for i in 0..chunk.bytes.len() {
                                let offset = chunk.offset as usize + i;
                                prop_assert!(
                                    reference.received[offset],
                                    "returned unreceived byte at {offset}"
                                );
                                prop_assert!(
                                    !reference.returned[offset],
                                    "duplicate byte at {offset}"
                                );
                                reference.returned[offset] = true;
                            }
                        }
                    }
                }
                Op::EnsureOrdering { ordered } => {
                    let actual = asm.ensure_ordering(ordered).is_ok();
                    let expected = reference.ensure_ordering(ordered);
                    prop_assert_eq!(actual, expected, "ensure_ordering result mismatch");
                }
                Op::Defragment => {
                    // Explicit defragmentation is only exercised in ordered
                    // mode here; unordered defragmentation is still reached
                    // indirectly through `insert`/`ensure_ordering`.
                    if asm.state.is_ordered() {
                        asm.defragment();
                    }
                }
            }
        }
        prop_assert_eq!(
            asm.bytes_read(),
            reference.bytes_read(),
            "bytes_read mismatch"
        );
    }
}