use crate::{
buffer::{
reader::storage::{Chunk, Infallible as _},
Error, Reader,
},
varint::VarInt,
};
use alloc::collections::{vec_deque, VecDeque};
use bytes::BytesMut;
mod duplex;
mod probe;
mod reader;
mod request;
mod slot;
mod writer;
#[cfg(test)]
mod tests;
use request::Request;
use slot::Slot;
/// The smallest slot allocation size, in bytes.
///
/// `Reassembler::allocation_size` returns this value for low offsets and a
/// multiple of it once the offset grows past the thresholds computed there.
const MIN_BUFFER_ALLOCATION_SIZE: usize = 4096;
/// Sentinel stored in `Cursors::final_offset` while the final size is not known.
///
/// `Cursors::final_size` maps this value to `None`.
const UNKNOWN_FINAL_SIZE: u64 = u64::MAX;
/// Buffers byte ranges written at arbitrary offsets and exposes the contiguous
/// data at the read cursor for consumption (see `write_at`, `iter` and `pop`).
#[derive(Debug, PartialEq, Default)]
pub struct Reassembler {
/// Slots holding the buffered data. `invariants` asserts that every slot starts
/// at or after the allocated end of the previous one.
slots: VecDeque<Slot>,
/// Read position, highest offset observed so far and the final size (if known).
cursors: Cursors,
}
/// Offsets tracked by a `Reassembler`.
#[derive(Clone, Copy, Debug, PartialEq)]
struct Cursors {
/// Offset of the read cursor; this is the value returned by
/// `Reassembler::consumed_len`.
start_offset: u64,
/// Highest end offset observed so far, updated by `handle_reader_fin` and
/// `Reassembler::skip`.
max_recv_offset: u64,
/// Final size of the stream, or `UNKNOWN_FINAL_SIZE` while it is not known.
final_offset: u64,
}
impl Cursors {
#[inline]
fn final_size(&self) -> Option<u64> {
if self.final_offset == UNKNOWN_FINAL_SIZE {
None
} else {
Some(self.final_offset)
}
}
#[inline]
fn handle_reader_fin<R>(&mut self, reader: &mut R) -> Result<(), Error<R::Error>>
where
R: Reader + ?Sized,
{
let buffered_offset = reader
.current_offset()
.checked_add_usize(reader.buffered_len())
.ok_or(Error::OutOfRange)?
.as_u64();
match (reader.final_offset(), self.final_size()) {
(Some(actual), Some(expected)) => {
ensure!(actual == expected, Err(Error::InvalidFin));
}
(Some(final_offset), None) => {
let final_offset = final_offset.as_u64();
ensure!(self.max_recv_offset <= final_offset, Err(Error::InvalidFin));
self.final_offset = final_offset;
}
(None, Some(expected)) => {
ensure!(expected >= buffered_offset, Err(Error::InvalidFin));
}
(None, None) => {}
}
self.max_recv_offset = self.max_recv_offset.max(buffered_offset);
Ok(())
}
}
impl Default for Cursors {
#[inline]
fn default() -> Self {
Self {
start_offset: 0,
max_recv_offset: 0,
final_offset: UNKNOWN_FINAL_SIZE,
}
}
}
impl Reassembler {
/// Creates an empty `Reassembler`.
///
/// Equivalent to `Reassembler::default()`: no slots, both offsets at zero and
/// an unknown final size.
#[inline]
pub fn new() -> Reassembler {
Self::default()
}
/// Returns `true` once the final size is known and the contiguous received
/// data (`total_received_len`) reaches exactly that size.
#[inline]
pub fn is_writing_complete(&self) -> bool {
    match self.final_size() {
        Some(final_size) => self.total_received_len() == final_size,
        None => false,
    }
}
/// Returns `true` once the final size is known and the read cursor has
/// reached it.
#[inline]
pub fn is_reading_complete(&self) -> bool {
    let consumed = self.cursors.start_offset;
    matches!(self.final_size(), Some(final_size) if final_size == consumed)
}
/// Returns the final size of the stream, or `None` while it is not known.
#[inline]
pub fn final_size(&self) -> Option<u64> {
self.cursors.final_size()
}
/// Returns the number of bytes that are ready to be read, i.e. the byte count
/// reported by `report`.
#[inline]
pub fn len(&self) -> usize {
    let (readable_bytes, _chunks) = self.report();
    readable_bytes
}
/// Returns `true` when there is no slot at all, or when the front slot is not
/// occupied at the read cursor.
#[inline]
pub fn is_empty(&self) -> bool {
    match self.slots.front() {
        Some(front) => !front.is_occupied(self.cursors.start_offset),
        None => true,
    }
}
/// Returns `(bytes, chunks)`: the total length and the number of chunks
/// yielded by `iter`.
#[inline]
pub fn report(&self) -> (usize, usize) {
    self.iter().fold((0, 0), |(bytes, chunks), chunk| {
        (bytes + chunk.len(), chunks + 1)
    })
}
/// Writes `data` at `offset`.
///
/// Builds a `Request` over the slice and forwards it to `write_reader`.
/// NOTE(review): the `false` passed to `Request::new` is presumably the "fin"
/// flag (compare `write_at_fin`) — `Request` is defined in `request`, confirm there.
///
/// # Errors
///
/// Propagates any error from `Request::new` or `write_reader`.
#[inline]
pub fn write_at(&mut self, offset: VarInt, data: &[u8]) -> Result<(), Error> {
let mut request = Request::new(offset, data, false)?;
self.write_reader(&mut request)?;
Ok(())
}
/// Writes `data` at `offset`, with the request flagged as final.
///
/// Builds a `Request` over the slice and forwards it to `write_reader`.
/// NOTE(review): the `true` passed to `Request::new` is presumably the "fin"
/// flag marking the end of the stream — `Request` is defined in `request`,
/// confirm there.
///
/// # Errors
///
/// Propagates any error from `Request::new` or `write_reader`.
#[inline]
pub fn write_at_fin(&mut self, offset: VarInt, data: &[u8]) -> Result<(), Error> {
let mut request = Request::new(offset, data, true)?;
self.write_reader(&mut request)?;
Ok(())
}
#[inline]
pub fn write_reader<R>(&mut self, reader: &mut R) -> Result<(), Error<R::Error>>
where
R: Reader + ?Sized,
{
reader.skip_until(self.current_offset())?;
let snapshot = self.cursors;
self.cursors.handle_reader_fin(reader)?;
if let Err(err) = self.write_reader_impl(reader) {
use core::any::TypeId;
if TypeId::of::<R::Error>() != TypeId::of::<core::convert::Infallible>() {
self.cursors = snapshot;
}
return Err(Error::ReaderError(err));
}
self.invariants();
Ok(())
}
/// Picks the slot the reader should start writing into and hands over to
/// `write_reader_at`.
///
/// If no existing slot starts at or before the reader's current offset, a new
/// slot is allocated, written to and placed at the front of the queue first.
///
/// # Errors
///
/// Returns the reader's own error if any read fails.
#[inline(always)]
fn write_reader_impl<R>(&mut self, reader: &mut R) -> Result<(), R::Error>
where
R: Reader + ?Sized,
{
// nothing is buffered: still issue a zero-length read so that a reader error
// is surfaced, then stop
if reader.buffer_is_empty() {
let _chunk = reader.read_chunk(0)?;
return Ok(());
}
let mut selected = None;
// scan from the back for the last slot that starts at or before the
// reader's current offset
for idx in (0..self.slots.len()).rev() {
let Some(slot) = self.slots.get(idx) else {
// `idx` comes from `0..self.slots.len()`, so the lookup cannot fail
unsafe {
assume!(false);
}
};
ensure!(slot.start() <= reader.current_offset().as_u64(), continue);
selected = Some(idx);
break;
}
let idx = if let Some(idx) = selected {
idx
} else {
// every existing slot (if any) starts after the reader's offset, so the
// new slot belongs at the front of the queue
let mut idx = 0;
let mut slot = self.allocate_slot(reader);
// write before pushing so a reader error leaves `self.slots` untouched;
// the flag passed as `&mut true` is discarded
let filled = slot.try_write_reader(reader, &mut true)?;
// NOTE(review): `try_write_reader` lives in `slot`; any extra slot it
// returns is pushed first so that it ends up directly behind `slot`
if let Some(slot) = filled {
self.slots.push_front(slot);
idx += 1;
}
self.slots.push_front(slot);
// the reader has been fully consumed: nothing left to write
ensure!(!reader.buffer_is_empty(), Ok(()));
// `idx` now refers to the last slot pushed above (0 = `slot`, 1 = the extra slot)
idx
};
self.write_reader_at(reader, idx)?;
Ok(())
}
/// Writes the reader into the slots starting at index `idx`, allocating new
/// slots through `write_reader_with_alloc` where needed.
///
/// The caller must pass a reader with buffered data and an `idx` that refers
/// to an existing slot; both are only checked through `assume!`.
///
/// # Errors
///
/// Returns the reader's own error if any read fails.
#[inline(always)]
fn write_reader_at<R>(&mut self, reader: &mut R, mut idx: usize) -> Result<(), R::Error>
where
R: Reader + ?Sized,
{
// remembered so the touched range can be passed to `unsplit_range` afterwards
let initial_idx = idx;
// handed to `Slot::try_write_reader`; NOTE(review): presumably set when a slot
// becomes full — confirm in `slot`
let mut filled_slot = false;
unsafe {
assume!(
!reader.buffer_is_empty(),
"the first write should always be non-empty"
);
}
while !reader.buffer_is_empty() {
let Some(slot) = self.slots.get_mut(idx) else {
// `idx` is expected to always refer to an existing slot here
unsafe {
assume!(false);
}
};
let filled = slot.try_write_reader(reader, &mut filled_slot)?;
// step past the slot that was just written
idx += 1;
// an extra slot returned by the write goes directly behind it
if let Some(slot) = filled {
self.insert(idx, slot);
idx += 1;
}
ensure!(!reader.buffer_is_empty(), break);
// allocate slots until the slot at `idx` starts at or before the reader's
// offset, or the reader runs out of data
self.write_reader_with_alloc(reader, &mut idx, &mut filled_slot)?;
continue;
}
// merging is only attempted when the flag was set during the writes above
if filled_slot {
self.unsplit_range(initial_idx..idx);
}
Ok(())
}
/// Allocates and fills new slots at position `*idx` until either the reader
/// has no buffered data left or the slot at `*idx` starts at or before the
/// reader's current offset.
///
/// `*idx` is advanced past every slot inserted here. `filled_slot` is passed
/// through to `Slot::try_write_reader`.
///
/// # Errors
///
/// Returns the reader's own error if a read fails. The slot being written at
/// that moment has not been inserted yet.
#[inline(always)]
fn write_reader_with_alloc<R>(
&mut self,
reader: &mut R,
idx: &mut usize,
filled_slot: &mut bool,
) -> Result<(), R::Error>
where
R: Reader + ?Sized,
{
while !reader.buffer_is_empty() {
// stop as soon as the existing slot at `*idx` can take over
if let Some(next) = self.slots.get(*idx) {
ensure!(next.start() > reader.current_offset().as_u64(), break);
}
let mut slot = self.allocate_slot(reader);
// write before inserting so a reader error does not leave a new slot behind
let filled = slot.try_write_reader(reader, filled_slot)?;
self.insert(*idx, slot);
*idx += 1;
// an extra slot returned by the write goes directly behind the new one
if let Some(slot) = filled {
self.insert(*idx, slot);
*idx += 1;
}
}
Ok(())
}
/// Merges slots in `range` with their successor where possible.
///
/// A slot is merged with the next one only if it is full, the next slot starts
/// exactly where it ends, and both starts fall into the same aligned
/// allocation block (`align_offset` with `allocation_size`).
#[inline]
fn unsplit_range(&mut self, range: core::ops::Range<usize>) {
// walk backwards so removing `idx + 1` never shifts an index still to be visited
for idx in range.rev() {
let Some(slot) = self.slots.get(idx) else {
// callers are expected to pass a range that lies within `self.slots`
unsafe {
assume!(false);
}
};
// only full slots are candidates for merging
ensure!(slot.is_full(), continue);
let start = slot.start();
let end = slot.end();
let Some(next) = self.slots.get(idx + 1) else {
continue;
};
// the two slots must be directly adjacent
ensure!(next.start() == end, continue);
// ... and belong to the same aligned allocation block
let current_block = Self::align_offset(start, Self::allocation_size(start));
let next_block = Self::align_offset(next.start(), Self::allocation_size(next.start()));
ensure!(current_block == next_block, continue);
if let Some(next) = self.slots.remove(idx + 1) {
self.slots[idx].unsplit(next);
} else {
unsafe {
assume!(false, "idx + 1 was checked above");
}
}
}
}
/// Advances the read cursor by `len` bytes, discarding buffered data that falls
/// before the new position.
///
/// A `len` of zero is a no-op.
///
/// # Errors
///
/// * `Error::OutOfRange` if the new offset overflows or is not a valid `VarInt`.
/// * `Error::InvalidFin` if the new offset would lie past a known final size.
///
/// # Panics
///
/// NOTE(review): panics if `Slot::skip_until` returns an error (`unwrap` below);
/// `Slot` is defined in `slot` — confirm that call cannot fail.
#[inline]
pub fn skip(&mut self, len: VarInt) -> Result<(), Error> {
ensure!(len > VarInt::ZERO, Ok(()));
let new_start_offset = self
.cursors
.start_offset
.checked_add(len.as_u64())
.and_then(|v| VarInt::new(v).ok())
.ok_or(Error::OutOfRange)?;
// the cursor must not move past a known final size
if let Some(final_size) = self.final_size() {
ensure!(
final_size >= new_start_offset.as_u64(),
Err(Error::InvalidFin)
);
}
// the skipped range counts towards the highest offset observed
self.cursors.max_recv_offset = self.cursors.max_recv_offset.max(new_start_offset.as_u64());
self.cursors.start_offset = new_start_offset.as_u64();
while let Some(mut slot) = self.slots.pop_front() {
// slots whose allocation ends before the new cursor are dropped
if slot.end_allocated() < new_start_offset.as_u64() {
continue;
}
// the first remaining slot is advanced to the new cursor and kept unless
// the slot reports that it should be dropped
slot.skip_until(new_start_offset).unwrap();
if !slot.should_drop() {
self.slots.push_front(slot);
}
// later slots are left untouched
break;
}
self.invariants();
Ok(())
}
/// Returns an iterator over the chunks that are ready to be read.
///
/// Iteration begins at the read cursor and ends at the first slot that is not
/// occupied at the end of the previous chunk (see `Iter::next`).
#[inline]
pub fn iter(&self) -> impl Iterator<Item = &[u8]> {
Iter::new(self)
}
/// Returns an iterator that calls `pop` for every item it yields.
#[inline]
pub fn drain(&mut self) -> impl Iterator<Item = BytesMut> + '_ {
Drain { inner: self }
}
/// Pops the next chunk, if any, by calling `pop_watermarked` with
/// `usize::MAX` as the watermark.
#[inline]
pub fn pop(&mut self) -> Option<BytesMut> {
self.pop_watermarked(usize::MAX)
}
/// Reads one chunk through `infallible_read_chunk(watermark)` and returns it,
/// or `None` when the chunk is empty.
///
/// The chunk is assumed to always be the `Chunk::BytesMut` variant.
#[inline]
pub fn pop_watermarked(&mut self, watermark: usize) -> Option<BytesMut> {
    let Chunk::BytesMut(popped) = self.infallible_read_chunk(watermark) else {
        unsafe { assume!(false) }
    };

    if popped.is_empty() {
        None
    } else {
        Some(popped)
    }
}
/// Returns the offset of the read cursor (`cursors.start_offset`).
#[inline]
pub fn consumed_len(&self) -> u64 {
self.cursors.start_offset
}
#[inline]
pub fn total_received_len(&self) -> u64 {
let mut offset = self.cursors.start_offset;
for slot in &self.slots {
ensure!(slot.is_occupied(offset), offset);
offset = slot.end();
}
offset
}
/// Drops all slots and resets the cursors to their defaults (offsets at zero,
/// final size unknown).
#[inline]
pub fn reset(&mut self) {
self.slots.clear();
self.cursors = Default::default();
}
/// Inserts `slot` at position `idx`, falling back to appending it when `idx`
/// lies past the end of the queue.
///
/// In debug builds the fallback asserts that `idx` equals the length plus one.
#[inline(always)]
fn insert(&mut self, idx: usize, slot: Slot) {
    let len = self.slots.len();

    if idx <= len {
        self.slots.insert(idx, slot);
        return;
    }

    debug_assert_eq!(len + 1, idx);
    self.slots.push_back(slot);
}
/// Creates an empty slot that covers the reader's current offset.
///
/// The slot normally spans one aligned allocation block. It is shortened at
/// the front so it never starts before the read cursor, and shortened at the
/// back when the reader's buffered data ends exactly at `final_offset`.
#[inline]
fn allocate_slot<R>(&mut self, reader: &R) -> Slot
where
R: Reader + ?Sized,
{
let start = reader.current_offset().as_u64();
let mut size = Self::allocation_size(start);
// start of the aligned block that contains `start`
let mut offset = Self::align_offset(start, size);
// do not allocate space in front of the read cursor
if let Some(diff) = self.cursors.start_offset.checked_sub(offset) {
if diff > 0 {
debug_assert!(
reader.current_offset().as_u64() >= self.cursors.start_offset,
"requests should be split before allocating slots"
);
offset = self.cursors.start_offset;
size -= diff as usize;
}
}
// NOTE(review): unchecked subtraction — relies on the reader's buffered data
// not extending past `final_offset`, which `Cursors::handle_reader_fin`
// validates for a known final size
if self.cursors.final_offset
- reader.current_offset().as_u64()
- reader.buffered_len() as u64
== 0
{
// the reader ends exactly at the final offset: no space is needed past the
// end of its buffered data
let size_candidate = (start - offset) as usize + reader.buffered_len();
if size_candidate < size {
size = size_candidate;
}
}
let buffer = BytesMut::with_capacity(size);
let end = offset + size as u64;
Slot::new(offset, end, buffer)
}
/// Rounds `offset` down to the nearest multiple of `alignment`.
///
/// `alignment` must be greater than zero; this is only checked through `assume!`.
#[inline(always)]
fn align_offset(offset: u64, alignment: usize) -> u64 {
    unsafe {
        assume!(alignment > 0);
    }
    let alignment = alignment as u64;
    offset - offset % alignment
}
/// Returns the slot allocation size to use for data at `offset`.
///
/// The size grows with the offset:
///
/// | offset (at least)                   | allocation size                   |
/// |-------------------------------------|-----------------------------------|
/// | 0                                   | `MIN_BUFFER_ALLOCATION_SIZE`      |
/// | `MIN_BUFFER_ALLOCATION_SIZE * 16`   | `MIN_BUFFER_ALLOCATION_SIZE * 4`  |
/// | `MIN_BUFFER_ALLOCATION_SIZE * 64`   | `MIN_BUFFER_ALLOCATION_SIZE * 8`  |
/// | `MIN_BUFFER_ALLOCATION_SIZE * 256`  | `MIN_BUFFER_ALLOCATION_SIZE * 16` |
#[inline(always)]
fn allocation_size(offset: u64) -> usize {
    // try the largest tier first; each tier applies once the offset reaches
    // the minimum size multiplied by the square of the tier's factor
    (2..=4u32)
        .rev()
        .map(|shift| 1usize << shift)
        .find(|&factor| offset >= (MIN_BUFFER_ALLOCATION_SIZE * factor * factor) as u64)
        .map_or(MIN_BUFFER_ALLOCATION_SIZE, |factor| {
            MIN_BUFFER_ALLOCATION_SIZE * factor
        })
}
/// Checks internal consistency; the checks only run when `debug_assertions`
/// are enabled.
///
/// # Panics
///
/// Panics (debug builds only) when the length bookkeeping disagrees, when slots
/// overlap or are out of order, when a slot that should have been dropped is
/// still stored, or when two slots that `unsplit_range` would merge are still
/// separate.
#[inline(always)]
fn invariants(&self) {
if cfg!(debug_assertions) {
// contiguous received data == consumed bytes + readable bytes
assert_eq!(
self.total_received_len(),
self.consumed_len() + self.len() as u64
);
let (actual_len, chunks) = self.report();
assert_eq!(actual_len == 0, self.is_empty());
assert_eq!(self.iter().count(), chunks);
let mut prev_end = self.cursors.start_offset;
for (idx, slot) in self.slots.iter().enumerate() {
// slots are ordered and never start before the previous allocation ends
assert!(slot.start() >= prev_end, "{self:#?}");
assert!(!slot.should_drop(), "slot range should be non-empty");
prev_end = slot.end_allocated();
// a full slot must not be followed by an adjacent slot from the same
// aligned block; same conditions as in `unsplit_range`
if slot.is_full() {
let start = slot.start();
let end = slot.end();
let Some(next) = self.slots.get(idx + 1) else {
continue;
};
ensure!(next.start() == end, continue);
let current_block = Self::align_offset(start, Self::allocation_size(start));
let next_block =
Self::align_offset(next.start(), Self::allocation_size(next.start()));
ensure!(current_block == next_block, continue);
panic!("unmerged slots at {idx} and {} {self:#?}", idx + 1);
}
}
}
}
}
/// Iterator over the readable chunks of a `Reassembler`, created by
/// `Reassembler::iter`.
pub struct Iter<'a> {
/// End offset of the chunk yielded last; starts out as the read cursor.
prev_end: u64,
/// Remaining slots of the underlying buffer.
inner: vec_deque::Iter<'a, Slot>,
}
impl<'a> Iter<'a> {
#[inline]
fn new(buffer: &'a Reassembler) -> Self {
Self {
prev_end: buffer.cursors.start_offset,
inner: buffer.slots.iter(),
}
}
}
impl<'a> Iterator for Iter<'a> {
type Item = &'a [u8];
#[inline]
fn next(&mut self) -> Option<Self::Item> {
let slot = self.inner.next()?;
ensure!(slot.is_occupied(self.prev_end), None);
self.prev_end = slot.end();
Some(slot.as_slice())
}
}
/// Iterator returned by `Reassembler::drain`; every item is obtained by calling
/// `Reassembler::pop` on the borrowed buffer.
pub struct Drain<'a> {
/// The buffer being drained.
inner: &'a mut Reassembler,
}
impl Iterator for Drain<'_> {
    type Item = BytesMut;

    /// Pops the next chunk from the underlying `Reassembler`.
    #[inline]
    fn next(&mut self) -> Option<Self::Item> {
        self.inner.pop()
    }

    /// Bounds the number of chunks that are still to be yielded.
    ///
    /// The slot count is kept as the upper bound. It cannot serve as the lower
    /// bound: slots may be stored while nothing is readable yet
    /// (`Reassembler::is_empty` returns `true` when the front slot is not
    /// occupied at the read cursor), so the only lower bound that always
    /// holds is `0`.
    #[inline]
    fn size_hint(&self) -> (usize, Option<usize>) {
        (0, Some(self.inner.slots.len()))
    }
}