use alloc::collections::VecDeque;
use bytes::{Buf, Bytes};
use core::{
convert::{TryFrom, TryInto},
fmt,
};
use s2n_codec::{Encoder, EncoderValue};
use s2n_quic_core::{frame::FitError, interval_set::Interval, varint::VarInt};
#[derive(Debug, Default)]
pub struct Buffer {
chunks: VecDeque<Chunk>,
head: VarInt,
pending_len: VarInt,
}
impl Buffer {
pub fn push(&mut self, data: Bytes) -> Interval<VarInt> {
debug_assert!(
self.capacity().as_u64() >= data.len() as u64,
"capacity should be checked before pushing"
);
let start = self.total_len();
let len = VarInt::try_from(data.len()).expect("cannot send more than VarInt::MAX");
self.pending_len += len;
self.chunks.push_back(Chunk { data });
self.check_integrity();
let end = start + (len - 1);
(start..=end).into()
}
#[inline]
pub fn capacity(&self) -> VarInt {
VarInt::MAX - self.total_len()
}
pub fn clear(&mut self) {
self.chunks.clear();
self.head = VarInt::from_u8(0);
self.pending_len = VarInt::from_u8(0);
self.check_integrity();
}
#[inline]
pub fn total_len(&self) -> VarInt {
self.head + self.pending_len
}
#[inline]
pub fn head(&self) -> VarInt {
self.head
}
#[inline]
pub fn enqueued_len(&self) -> VarInt {
self.pending_len
}
#[inline]
pub fn is_empty(&self) -> bool {
self.pending_len == VarInt::from_u8(0)
}
#[cfg(test)]
pub fn set_offset(&mut self, head: VarInt) {
self.head = head;
}
pub fn release(&mut self, up_to: VarInt) {
if up_to <= self.head {
return;
}
debug_assert!(
self.total_len() >= up_to,
"cannot release more than the total len"
);
while let Some(mut chunk) = self.chunks.pop_front() {
let len = VarInt::try_from(chunk.len()).unwrap();
let start = self.head;
let end = start + len;
if end <= up_to {
self.pending_len -= len;
self.head = end;
continue;
}
self.head = up_to;
let consumed = self.head - start;
self.pending_len -= consumed;
chunk.data.advance(consumed.try_into().unwrap());
self.chunks.push_front(chunk);
break;
}
self.check_integrity();
}
pub fn release_all(&mut self) {
self.chunks.clear();
self.head = self.total_len();
self.pending_len = VarInt::from_u8(0);
self.check_integrity();
}
#[inline]
pub fn viewer(&self) -> Viewer {
Viewer {
buffer: self,
offset: *self.head,
chunk_index: 0,
}
}
#[inline]
fn check_integrity(&self) {
if cfg!(debug_assertions) {
let actual: VarInt = self
.chunks
.iter()
.map(|chunk| chunk.len())
.sum::<usize>()
.try_into()
.unwrap();
assert_eq!(
actual, self.pending_len,
"actual buffer lengths should equal `pending_len`"
);
}
}
}
#[derive(Default)]
struct Chunk {
data: Bytes,
}
impl fmt::Debug for Chunk {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
f.debug_struct("Chunk")
.field("len", &self.data.len())
.finish()
}
}
impl core::ops::Deref for Chunk {
type Target = Bytes;
#[inline]
fn deref(&self) -> &Self::Target {
&self.data
}
}
#[derive(Clone, Copy, Debug)]
pub struct Viewer<'a> {
buffer: &'a Buffer,
offset: u64,
chunk_index: usize,
}
impl<'a> Viewer<'a> {
#[inline]
pub fn next_view(&mut self, range: Interval<VarInt>, has_fin: bool) -> View<'a> {
View::new(
self.buffer,
range,
has_fin,
&mut self.offset,
&mut self.chunk_index,
)
}
}
#[derive(Clone, Copy, Debug)]
pub struct View<'a> {
buffer: &'a Buffer,
chunk_index: usize,
offset: usize,
len: usize,
is_fin: bool,
}
impl<'a> View<'a> {
#[inline]
fn new(
buffer: &'a Buffer,
range: Interval<VarInt>,
has_fin: bool,
stream_offset: &mut u64,
chunk_index: &mut usize,
) -> Self {
debug_assert!(
buffer.head <= range.start_inclusive(),
"range ({:?}) is referring to a chunk that has already been released: {:?}",
range,
buffer.head..buffer.total_len()
);
debug_assert!(
*stream_offset <= range.start_inclusive().as_u64(),
"viewer is trying to go backwards from offset {:?} to {:?}",
stream_offset,
range.start_inclusive()
);
debug_assert!(range.end_exclusive() <= buffer.total_len());
let mut offset = 0;
for chunk in buffer.chunks.iter().skip(*chunk_index) {
let len = chunk.len() as u64;
let start = *stream_offset;
let end = start + len;
if (start..end).contains(&range.start_inclusive()) {
offset = (range.start_inclusive().as_u64() - start) as usize;
break;
}
*stream_offset += len;
*chunk_index += 1;
}
debug_assert!(*chunk_index < buffer.chunks.len());
Self {
buffer,
chunk_index: *chunk_index,
offset,
len: range.len(),
is_fin: has_fin && range.end_inclusive() == (buffer.total_len() - 1),
}
}
#[inline]
pub fn trim_off(&mut self, amount: usize) -> Result<(), FitError> {
self.len = self.len.checked_sub(amount).ok_or(FitError)?;
self.is_fin &= amount == 0;
Ok(())
}
#[inline]
pub fn len(&self) -> VarInt {
VarInt::try_from(self.len).expect("len should always fit in a VarInt")
}
#[inline]
pub fn is_fin(&self) -> bool {
self.is_fin
}
#[inline]
pub fn iter<'iter, S: Slice<'iter>>(&'iter self) -> ViewIter<'iter, S> {
ViewIter {
view: *self,
slice: Default::default(),
}
}
}
pub struct ViewIter<'a, S: Slice<'a>> {
view: View<'a>,
slice: core::marker::PhantomData<S>,
}
impl<'a, S: Slice<'a>> Iterator for ViewIter<'a, S> {
type Item = S;
#[inline]
fn next(&mut self) -> Option<Self::Item> {
let view = &mut self.view;
if view.len == 0 {
return None;
}
let chunk = &view.buffer.chunks[view.chunk_index];
let start = view.offset;
view.offset = 0;
let len = chunk.len() - start;
let len = view.len.min(len);
let end = start + len;
view.len -= len;
view.chunk_index += 1;
debug_assert_eq!(chunk[start..end].len(), len);
Some(S::from_chunk(chunk, start, end))
}
}
pub trait Slice<'a> {
fn from_chunk(chunk: &'a Bytes, start: usize, end: usize) -> Self;
}
impl<'a> Slice<'a> for &'a [u8] {
#[inline]
fn from_chunk(chunk: &'a Bytes, start: usize, end: usize) -> Self {
unsafe {
debug_assert!(chunk.len() >= end);
chunk.get_unchecked(start..end)
}
}
}
impl<'a> Slice<'a> for Bytes {
#[inline]
fn from_chunk(chunk: &'a Bytes, start: usize, end: usize) -> Self {
chunk.slice(start..end)
}
}
impl<'a> EncoderValue for &mut View<'a> {
#[inline]
fn encode<E: Encoder>(&self, encoder: &mut E) {
if E::SPECIALIZES_BYTES {
for chunk in self.iter::<Bytes>() {
encoder.write_bytes(chunk);
}
return;
}
encoder.write_sized(self.len, |slice| {
let mut offset = 0;
for chunk in self.iter::<&[u8]>() {
let len = chunk.len();
let end = offset + len;
unsafe {
debug_assert!(slice.len() >= end);
core::ptr::copy_nonoverlapping(
chunk.as_ptr(),
slice.get_unchecked_mut(offset),
len,
);
}
offset += len;
}
});
}
#[inline]
fn encoding_size_for_encoder<E: Encoder>(&self, _encoder: &E) -> usize {
self.len
}
}
#[cfg(test)]
mod tests {
use super::*;
fn almost_full_buffer() -> Buffer {
Buffer {
head: VarInt::MAX - 1,
..Default::default()
}
}
#[test]
fn partial_release_test() {
let mut buffer = Buffer::default();
buffer.push(Bytes::from_static(&[0, 1, 2]));
assert_eq!(buffer.total_len(), VarInt::from_u8(3));
assert_eq!(buffer.enqueued_len(), VarInt::from_u8(3));
buffer.release(VarInt::from_u8(1));
assert_eq!(buffer.total_len(), VarInt::from_u8(3));
assert_eq!(buffer.enqueued_len(), VarInt::from_u8(2));
assert_eq!(buffer.chunks.len(), 1);
assert_eq!(buffer.chunks[0][..], [1, 2]);
buffer.release(VarInt::from_u8(1));
assert_eq!(buffer.total_len(), VarInt::from_u8(3));
assert_eq!(buffer.enqueued_len(), VarInt::from_u8(2));
assert_eq!(buffer.chunks.len(), 1);
assert_eq!(buffer.chunks[0][..], [1, 2]);
}
#[test]
fn full_release_test() {
let mut buffer = Buffer::default();
buffer.push(Bytes::from_static(&[0, 1, 2]));
assert_eq!(buffer.total_len(), VarInt::from_u8(3));
assert_eq!(buffer.enqueued_len(), VarInt::from_u8(3));
buffer.release(VarInt::from_u8(3));
assert_eq!(buffer.total_len(), VarInt::from_u8(3));
assert_eq!(buffer.enqueued_len(), VarInt::from_u8(0));
assert!(buffer.chunks.is_empty());
buffer.release(VarInt::from_u8(3));
assert_eq!(buffer.total_len(), VarInt::from_u8(3));
assert_eq!(buffer.enqueued_len(), VarInt::from_u8(0));
assert!(buffer.chunks.is_empty());
}
#[test]
fn varint_max_test() {
let mut buffer = almost_full_buffer();
buffer.push(Bytes::from_static(&[0]));
buffer.release(VarInt::MAX);
assert_eq!(buffer.total_len(), VarInt::MAX);
assert_eq!(buffer.enqueued_len(), VarInt::from_u8(0));
assert!(buffer.chunks.is_empty());
}
#[test]
#[should_panic]
fn varint_overflow_test() {
let mut buffer = almost_full_buffer();
buffer.push(Bytes::from_static(&[0, 1]));
}
fn check_view(buffer: &Buffer, interval: Interval<u64>, expected: &[u8]) {
let interval = (VarInt::new(interval.start_inclusive()).unwrap()
..=VarInt::new(interval.end_inclusive()).unwrap())
.into();
let actual: Vec<u8> = View::new(buffer, interval, false, &mut buffer.head.as_u64(), &mut 0)
.iter::<&[u8]>()
.flatten()
.copied()
.collect();
assert_eq!(actual, expected);
}
fn check_viewer(viewer: &mut Viewer, interval: Interval<u64>, expected: &[u8]) {
let interval = (VarInt::new(interval.start_inclusive()).unwrap()
..=VarInt::new(interval.end_inclusive()).unwrap())
.into();
let actual: Vec<u8> = viewer
.next_view(interval, false)
.iter::<&[u8]>()
.flatten()
.copied()
.collect();
assert_eq!(actual, expected);
}
#[test]
fn view_test() {
let mut buffer = Buffer::default();
buffer.push(Bytes::from_static(&[0, 1, 2]));
buffer.push(Bytes::from_static(&[3, 4, 5]));
check_view(&buffer, (0..1).into(), &[0]);
check_view(&buffer, (0..2).into(), &[0, 1]);
check_view(&buffer, (0..3).into(), &[0, 1, 2]);
check_view(&buffer, (0..4).into(), &[0, 1, 2, 3]);
check_view(&buffer, (0..5).into(), &[0, 1, 2, 3, 4]);
check_view(&buffer, (0..6).into(), &[0, 1, 2, 3, 4, 5]);
check_view(&buffer, (1..6).into(), &[1, 2, 3, 4, 5]);
check_view(&buffer, (2..6).into(), &[2, 3, 4, 5]);
check_view(&buffer, (3..6).into(), &[3, 4, 5]);
check_view(&buffer, (4..6).into(), &[4, 5]);
check_view(&buffer, (5..6).into(), &[5]);
}
#[test]
fn viewer_test() {
let mut buffer = Buffer::default();
buffer.push(Bytes::from_static(&[0, 1, 2]));
buffer.push(Bytes::from_static(&[3, 4, 5]));
let mut viewer = buffer.viewer();
check_viewer(&mut viewer, (0..1).into(), &[0]);
check_viewer(&mut viewer, (2..4).into(), &[2, 3]);
check_viewer(&mut viewer, (5..6).into(), &[5]);
}
}