use crate::{
frame::{FitError, Tag},
varint::VarInt,
};
use core::{convert::TryFrom, mem::size_of};
use s2n_codec::{
decoder_parameterized_value, DecoderBuffer, DecoderBufferMut, Encoder, EncoderValue,
};
// Expands to a range pattern covering every tag byte in `0x08..=0x0f`:
// the base value `0x08` combined with any setting of the three low bits
// (`0x04`, `0x02`, `0x01`).
macro_rules! stream_tag {
() => {
0x08u8..=0x0fu8
};
}
// Base tag value; `Stream::tag` ORs the flag bits below into it.
const STREAM_TAG: u8 = 0x08;
// Set in the tag when the frame carries a non-zero offset field.
const OFF_BIT: u8 = 0x04;
// Set in the tag when the data is preceded by a length prefix
// (i.e. the frame is not marked as the last frame).
const LEN_BIT: u8 = 0x02;
// Set in the tag when `is_fin` is true.
const FIN_BIT: u8 = 0x01;
/// A stream frame, generic over the representation of its payload.
#[derive(Debug, PartialEq, Eq)]
pub struct Stream<Data> {
/// Stream identifier; always encoded immediately after the tag.
pub stream_id: VarInt,
/// Offset field. It is only written when non-zero, in which case the
/// offset bit is set in the tag; when the bit is absent on decode the
/// field takes its `Default` value.
pub offset: VarInt,
/// When `true`, the data is encoded without a length prefix and, on
/// decode, consumes the rest of the buffer. When `false`, the length bit
/// is set in the tag and the data carries a `VarInt` length prefix.
pub is_last_frame: bool,
/// Mirrors the FIN bit of the tag.
pub is_fin: bool,
/// The frame payload.
pub data: Data,
}
/// A stream frame whose payload is a shared byte slice.
pub type StreamRef<'a> = Stream<&'a [u8]>;
/// A stream frame whose payload is a mutable byte slice.
pub type StreamMut<'a> = Stream<&'a mut [u8]>;
impl<Data> Stream<Data> {
    /// Computes the tag byte for this frame.
    ///
    /// The result is `STREAM_TAG` with:
    /// * `OFF_BIT` set when the offset is non-zero,
    /// * `LEN_BIT` set when this is *not* the last frame,
    /// * `FIN_BIT` set when `is_fin` is true.
    #[inline]
    pub fn tag(&self) -> u8 {
        let offset_flag = if *self.offset == 0 { 0 } else { OFF_BIT };
        let length_flag = if self.is_last_frame { 0 } else { LEN_BIT };
        let fin_flag = if self.is_fin { FIN_BIT } else { 0 };

        STREAM_TAG | offset_flag | length_flag | fin_flag
    }

    /// Converts the payload with `map`, carrying every other field over
    /// unchanged.
    #[inline]
    pub fn map_data<F: FnOnce(Data) -> Out, Out>(self, map: F) -> Stream<Out> {
        let Stream {
            stream_id,
            offset,
            is_last_frame,
            is_fin,
            data,
        } = self;

        Stream {
            stream_id,
            offset,
            is_last_frame,
            is_fin,
            data: map(data),
        }
    }
}
impl<Data: EncoderValue> Stream<Data> {
    /// Tries to fit the frame into `capacity` bytes.
    ///
    /// On success the returned value is the number of payload bytes that can
    /// be carried, and `is_last_frame` has been updated:
    ///
    /// * `true` when the payload is at least as large as the space left after
    ///   the fixed fields, so the data fills the remainder and needs no
    ///   length prefix;
    /// * `false` otherwise, in which case room for a `VarInt` length prefix is
    ///   reserved before computing the payload length.
    ///
    /// # Errors
    ///
    /// Returns `FitError` when the fixed fields (tag, stream id and, if
    /// non-zero, the offset) do not fit, when the payload length cannot be
    /// represented as a `VarInt`, or when the length prefix does not fit.
    #[inline]
    pub fn try_fit(&mut self, capacity: usize) -> Result<usize, FitError> {
        // Size of everything that precedes the (optional) length prefix.
        let mut header_len = size_of::<Tag>() + self.stream_id.encoding_size();
        if self.offset != 0u64 {
            header_len += self.offset.encoding_size();
        }

        let available = match capacity.checked_sub(header_len) {
            Some(available) => available,
            None => return Err(FitError),
        };

        let payload_len = self.data.encoding_size();

        // The payload fills (or overflows) the remaining space, so the frame
        // can run to the end of the buffer without a length prefix.
        if payload_len >= available {
            self.is_last_frame = true;
            return Ok(available);
        }

        // From here on the whole payload is smaller than the space left, so a
        // length prefix is required.
        self.is_last_frame = false;

        let prefix_len = VarInt::try_from(payload_len)
            .map_err(|_| FitError)?
            .encoding_size();

        available
            .checked_sub(prefix_len)
            .map(|room| room.min(payload_len))
            .ok_or(FitError)
    }
}
// Decoder for `Stream<Data>`, parameterized by the tag byte that has already
// been read. Only plain `//` comments are used inside the macro invocation.
decoder_parameterized_value!(
impl<'a, Data> Stream<Data> {
fn decode(tag: Tag, buffer: Buffer) -> Result<Self> {
// Unpack the three flag bits from the tag. Note that `is_last_frame` is
// the *absence* of the length bit.
let has_offset = tag & OFF_BIT == OFF_BIT;
let is_last_frame = tag & LEN_BIT != LEN_BIT;
let is_fin = tag & FIN_BIT == FIN_BIT;
let (stream_id, buffer) = buffer.decode()?;
// The offset is only present on the wire when the offset bit is set;
// otherwise it takes its default value.
let (offset, buffer) = if has_offset {
buffer.decode()?
} else {
(Default::default(), buffer)
};
let (data, buffer) = if !is_last_frame {
// Length bit set: the payload is preceded by a `VarInt` length.
let (data, buffer) = buffer.decode_with_len_prefix::<VarInt, Data>()?;
(data, buffer)
} else {
// No length prefix: the payload is everything left in the buffer, and
// `Data` must consume all of it.
let len = buffer.len();
let (data, buffer) = buffer.decode_slice(len)?;
let (data, remaining) = data.decode()?;
remaining.ensure_empty()?;
(data, buffer)
};
let frame = Stream {
stream_id,
offset,
is_last_frame,
is_fin,
data,
};
Ok((frame, buffer))
}
}
);
impl<Data: EncoderValue> EncoderValue for Stream<Data> {
    /// Writes the frame: tag, stream id, the offset (only when non-zero) and
    /// finally the payload, which is length-prefixed with a `VarInt` unless
    /// `is_last_frame` is set.
    #[inline]
    fn encode<E: Encoder>(&self, buffer: &mut E) {
        let tag = self.tag();
        buffer.encode(&tag);
        buffer.encode(&self.stream_id);

        let has_offset = *self.offset != 0;
        if has_offset {
            buffer.encode(&self.offset);
        }

        if !self.is_last_frame {
            buffer.encode_with_len_prefix::<VarInt, _>(&self.data);
        } else {
            buffer.encode(&self.data);
        }
    }

    /// Computes the encoded size of the frame without writing it.
    ///
    /// In builds with debug assertions the result is cross-checked against an
    /// `EncoderLenEstimator` run of `encode`.
    ///
    /// # Panics
    ///
    /// Panics if the frame is not the last frame and the payload length
    /// cannot be converted to a `VarInt`.
    #[inline]
    fn encoding_size_for_encoder<E: Encoder>(&self, encoder: &E) -> usize {
        let offset_len = if *self.offset == 0 {
            0
        } else {
            self.offset.encoding_size()
        };
        let header_len = size_of::<Tag>() + self.stream_id.encoding_size() + offset_len;

        let payload_len = self.data.encoding_size_for_encoder(encoder);
        let unprefixed_len = header_len + payload_len;

        // A length prefix is only written when more frames may follow.
        let total = if self.is_last_frame {
            unprefixed_len
        } else {
            unprefixed_len + VarInt::try_from(payload_len).unwrap().encoding_size()
        };

        if cfg!(debug_assertions) {
            use s2n_codec::EncoderLenEstimator;

            let mut estimator = EncoderLenEstimator::new(encoder.remaining_capacity());
            self.encode(&mut estimator);
            assert_eq!(estimator.len(), total);
        }

        total
    }
}
impl<'a> From<Stream<DecoderBuffer<'a>>> for StreamRef<'a> {
    /// Turns a frame holding a `DecoderBuffer` into one holding the
    /// underlying shared byte slice.
    #[inline]
    fn from(s: Stream<DecoderBuffer<'a>>) -> Self {
        let into_slice = |buffer: DecoderBuffer<'a>| buffer.into_less_safe_slice();
        s.map_data(into_slice)
    }
}
impl<'a> From<Stream<DecoderBufferMut<'a>>> for StreamRef<'a> {
    /// Turns a frame holding a `DecoderBufferMut` into one holding a shared
    /// view of the underlying byte slice.
    #[inline]
    fn from(s: Stream<DecoderBufferMut<'a>>) -> Self {
        // Reborrow the mutable slice as a shared one.
        let into_shared_slice = |buffer: DecoderBufferMut<'a>| &*buffer.into_less_safe_slice();
        s.map_data(into_shared_slice)
    }
}
impl<'a> From<Stream<DecoderBufferMut<'a>>> for StreamMut<'a> {
    /// Turns a frame holding a `DecoderBufferMut` into one holding the
    /// underlying mutable byte slice.
    #[inline]
    fn from(s: Stream<DecoderBufferMut<'a>>) -> Self {
        let into_mut_slice = |buffer: DecoderBufferMut<'a>| buffer.into_less_safe_slice();
        s.map_data(into_mut_slice)
    }
}
#[cfg(test)]
mod tests {
use super::*;
use crate::frame::Padding;
use bolero::check;
use core::convert::TryInto;
/// Property model for `Stream::try_fit`: builds a frame with a `Padding`
/// payload of `length` bytes, tries to fit it into `capacity`, and checks
/// the invariants relating the resulting encoding size to `capacity`.
fn model(stream_id: VarInt, offset: VarInt, length: VarInt, capacity: usize) {
// Skip inputs whose length cannot be converted to the payload length type.
let length = if let Ok(length) = VarInt::try_into(length) {
length
} else {
return;
};
let mut frame = Stream {
stream_id,
offset,
is_last_frame: false,
is_fin: false,
data: Padding { length },
};
if let Ok(new_length) = frame.try_fit(capacity) {
// Apply the fitted length before measuring the frame.
frame.data = Padding { length: new_length };
// The fitted frame must never exceed the capacity.
assert!(
frame.encoding_size() <= capacity,
"the encoding_size should not exceed capacity {frame:#?}"
);
if new_length < length {
// The payload was truncated, so the frame should fill the capacity,
// allowing slack of at most the length prefix size when one is present.
let mut min = capacity;
if !frame.is_last_frame {
min -= VarInt::try_from(new_length).unwrap().encoding_size();
}
let max = capacity;
assert!(
(min..=max).contains(&frame.encoding_size()),
"encoding_size ({}) should match capacity ({capacity}) {frame:#?}",
frame.encoding_size(),
);
}
// A frame without a length prefix must occupy the capacity exactly.
if frame.is_last_frame {
assert_eq!(
frame.encoding_size(),
capacity,
"should only be the last frame if == capacity {frame:#?}"
);
}
} else {
// A rejection is only valid if the unfitted frame really is too large.
assert!(
frame.encoding_size() > capacity,
"rejection should only occur when encoding size > capacity {frame:#?}"
);
}
}
/// Runs `model` over generated `(stream_id, offset, length, capacity)`
/// tuples; also configured as a kani proof when built with `cfg(kani)`.
#[test]
#[cfg_attr(kani, kani::proof, kani::unwind(1), kani::solver(kissat))]
fn try_fit_test() {
check!()
.with_type()
.cloned()
.for_each(|(stream_id, offset, length, capacity)| {
model(stream_id, offset, length, capacity);
});
}
}