#![forbid(unsafe_code)]
use bytes::{Buf, BufMut, BytesMut};
#[cfg(test)]
mod test_utils;
#[cfg(test)]
mod test;
/// Sentinel byte used by the codec types when no other `SENTINEL` is given.
const DEFAULT_SENTINEL: u8 = 0x00;
/// Default `MAX_LEN`; the value 0 switches the frame length checks off.
const DEFAULT_MAX_LEN: usize = 0;
/// Capacity that `decode_buffer_cap` falls back to when `MAX_LEN` is 0.
const DEFAULT_DECODE_BUFFER_CAPACITY: usize = 8 * 1024;
/// Largest number of data bytes that a single code byte can announce.
const MAX_RUN: usize = 254;
/// Returns the number of bytes `encode` reserves for an input of `input_len`
/// bytes.
///
/// The total is the payload itself, one code byte per started group of 254
/// payload bytes (at least one, even for an empty payload), and one byte for
/// the trailing sentinel.
const fn max_encoded_len(input_len: usize) -> usize {
    let code_bytes = match input_len {
        // An empty frame still carries a single code byte.
        0 => 1,
        // Ceiling division of the payload length by 254.
        n => (n + 253) / 254,
    };
    input_len + code_bytes + 1
}
/// Picks the capacity for a decoder output buffer.
///
/// A `max_len` of 0 selects `DEFAULT_DECODE_BUFFER_CAPACITY`; any other value
/// is returned unchanged.
const fn decode_buffer_cap(max_len: usize) -> usize {
    match max_len {
        0 => DEFAULT_DECODE_BUFFER_CAPACITY,
        limit => limit,
    }
}
/// Maps a chunk length in `0..=MAX_RUN` to the code byte that announces it.
///
/// Lengths below `SENTINEL` are emitted unchanged; lengths at or above it are
/// shifted up by one, so the returned byte is never equal to `SENTINEL`.
#[inline(always)]
fn encode_len<const SENTINEL: u8>(len: usize) -> u8 {
    debug_assert!(len <= MAX_RUN);
    // Decide whether the length has to hop over the sentinel value.
    let shift_past_sentinel = match SENTINEL {
        // Every length is >= 0, so all of them are shifted.
        0 => true,
        // No valid length reaches 255, so none of them is shifted.
        255 => {
            assert!(SENTINEL as usize > MAX_RUN);
            false
        }
        sentinel => len >= usize::from(sentinel),
    };
    if shift_past_sentinel {
        len.wrapping_add(1) as u8
    } else {
        len as u8
    }
}
/// Inverse of `encode_len`: turns a code byte back into a chunk length.
///
/// Returns `None` when `code` is the sentinel itself. Codes below the sentinel
/// are the length as-is; codes above it are the length plus one.
#[inline(always)]
fn decode_len<const SENTINEL: u8>(code: u8) -> Option<usize> {
    if code == SENTINEL {
        return None;
    }
    let raw = usize::from(code);
    // `code > SENTINEL` implies `code >= 1`, so the subtraction cannot wrap.
    let len = if code > SENTINEL { raw - 1 } else { raw };
    debug_assert!(len <= MAX_RUN);
    Some(len)
}
/// Appends the encoded form of `input`, followed by `SENTINEL`, to `output`.
///
/// The input is split at every occurrence of `SENTINEL`; each run is written
/// as a code byte followed by the run's bytes. Runs longer than `MAX_RUN` are
/// cut into chunks of at most `MAX_RUN` bytes, each with its own code byte.
///
/// A non-zero `MAX_LEN` is only checked with `debug_assert!` here.
#[inline(always)]
fn encode<const SENTINEL: u8, const MAX_LEN: usize>(
    input: &[u8],
    output: &mut BytesMut,
) {
    output.reserve(max_encoded_len(input.len()));
    if MAX_LEN != 0 {
        debug_assert!(input.len() <= MAX_LEN);
    }
    let runs = input.split(|&byte| byte == SENTINEL);
    if MAX_LEN != 0 && MAX_LEN <= MAX_RUN {
        // The whole input fits into one chunk, so no run needs to be cut.
        for run in runs {
            output.put_u8(encode_len::<SENTINEL>(run.len()));
            output.put_slice(run);
        }
    } else {
        // Whether the most recently written chunk had exactly MAX_RUN bytes.
        let mut last_chunk_full = false;
        for run in runs {
            if last_chunk_full {
                // A full chunk is read as "more data follows without a
                // sentinel"; an empty chunk ends that before the next run.
                output.put_u8(encode_len::<SENTINEL>(0));
            }
            if run.is_empty() {
                // `chunks` yields nothing for an empty run, but the run still
                // needs its code byte.
                output.put_u8(encode_len::<SENTINEL>(0));
                last_chunk_full = false;
                continue;
            }
            for chunk in run.chunks(MAX_RUN) {
                output.put_u8(encode_len::<SENTINEL>(chunk.len()));
                output.put_slice(chunk);
                last_chunk_full = chunk.len() == MAX_RUN;
            }
        }
    }
    output.put_u8(SENTINEL);
}
/// Stateless frame encoder.
///
/// `SENTINEL` is the byte that terminates every encoded frame and never
/// appears inside one. `MAX_LEN` is the largest accepted input length in
/// bytes; 0 means no limit is checked.
#[derive(Default, Debug)]
pub struct Encoder<
const SENTINEL: u8 = DEFAULT_SENTINEL,
const MAX_LEN: usize = DEFAULT_MAX_LEN,
>;
impl<const SENTINEL: u8, const MAX_LEN: usize> Encoder<SENTINEL, MAX_LEN> {
pub fn new() -> Self {
Self
}
}
// Encodes any item that can be viewed as a byte slice into one frame.
impl<const SENTINEL: u8, const MAX_LEN: usize, T: AsRef<[u8]>>
tokio_util::codec::Encoder<T> for Encoder<SENTINEL, MAX_LEN>
{
type Error = std::io::Error;
/// Appends the encoded form of `item`, terminated by `SENTINEL`, to `dst`.
///
/// This implementation never returns `Err`.
///
/// # Panics
///
/// Panics if `MAX_LEN` is non-zero and `item` is longer than `MAX_LEN`
/// bytes, or if `dst` does not end with `SENTINEL` after encoding.
#[inline(always)]
fn encode(
&mut self,
item: T,
dst: &mut BytesMut,
) -> Result<(), Self::Error> {
let bytes = item.as_ref();
// Precondition: oversized input is rejected with a panic, not an error.
assert!(MAX_LEN == 0 || bytes.len() <= MAX_LEN);
encode::<SENTINEL, MAX_LEN>(bytes, dst);
// Postcondition: every encoded frame ends with the sentinel byte.
assert_eq!(dst.last(), Some(&SENTINEL));
Ok(())
}
}
/// Outcome of one `DecoderReadingState::read` call.
#[derive(Debug)]
enum DecoderReadResult {
/// The input ran out before the frame was complete.
NeedMoreData,
/// A sentinel ended the frame; carries the decoded bytes.
Frame(BytesMut),
/// The sentinel byte appeared inside the data bytes of a chunk.
UnexpectedSentinel,
/// The decoded frame would grow beyond `MAX_LEN` bytes.
FrameOverflow,
}
/// Progress of the frame that is currently being decoded.
#[derive(Debug)]
struct DecoderReadingState {
/// Data bytes of the current chunk that still have to be copied from the
/// input.
next_chunk_offset: usize,
/// Decoded bytes of the frame collected so far.
output: BytesMut,
/// Set when the current chunk was announced with length `MAX_RUN`; in that
/// case no sentinel byte is written to `output` before the next chunk.
chunk_overflow: bool,
}
impl DecoderReadingState {
    /// Starts decoding a frame whose first code byte decoded to `offset`.
    ///
    /// The output buffer is allocated with `decode_buffer_cap(MAX_LEN)` bytes
    /// of capacity.
    #[inline(always)]
    fn new<const MAX_LEN: usize>(offset: usize) -> Self {
        let mut this = Self {
            next_chunk_offset: 0,
            output: BytesMut::with_capacity(decode_buffer_cap(MAX_LEN)),
            chunk_overflow: false,
        };
        this.update(offset);
        this
    }

    /// Records the chunk length announced by the most recent code byte.
    #[inline(always)]
    fn update(&mut self, offset: usize) {
        self.next_chunk_offset = offset;
        // A chunk of exactly MAX_RUN bytes is not followed by an implied
        // sentinel in the decoded output.
        self.chunk_overflow = offset == MAX_RUN;
    }

    /// Consumes bytes from `src` until the frame ends, an error is detected or
    /// the input is exhausted.
    ///
    /// Returns
    /// * `NeedMoreData` when `src` is empty before the frame is complete,
    /// * `Frame` with the decoded bytes when the terminating sentinel is read
    ///   (the internal output buffer is replaced by a fresh one),
    /// * `FrameOverflow` when a non-zero `MAX_LEN` would be exceeded,
    /// * `UnexpectedSentinel` when the sentinel shows up inside chunk data.
    ///   In that case only the bytes up to and including that sentinel are
    ///   removed from `src`, so whatever follows it is still available and
    ///   decoding can restart exactly at the next frame boundary.
    #[inline(always)]
    fn read<const SENTINEL: u8, const MAX_LEN: usize>(
        &mut self,
        src: &mut BytesMut,
    ) -> DecoderReadResult {
        loop {
            if src.is_empty() {
                return DecoderReadResult::NeedMoreData;
            }
            if self.next_chunk_offset > 0 {
                // Copy as much of the current chunk as the input provides.
                let len = usize::min(self.next_chunk_offset, src.len());
                if MAX_LEN != 0 && self.output.len() + len > MAX_LEN {
                    return DecoderReadResult::FrameOverflow;
                }
                // Look for a stray sentinel BEFORE consuming the chunk: bytes
                // behind it belong to the next frame and must stay in `src`.
                let sentinel_at =
                    src[..len].iter().position(|&byte| byte == SENTINEL);
                if let Some(index) = sentinel_at {
                    let _ = src.split_to(index + 1);
                    return DecoderReadResult::UnexpectedSentinel;
                }
                self.next_chunk_offset -= len;
                let chunk = src.split_to(len);
                self.output.put(chunk);
                if src.is_empty() {
                    return DecoderReadResult::NeedMoreData;
                }
            }
            debug_assert!(self.next_chunk_offset == 0);
            debug_assert!(!src.is_empty());
            if let Some(offset) = decode_len::<SENTINEL>(src.get_u8()) {
                // Another chunk follows. Unless the previous chunk was a
                // maximal one, the chunk boundary stands for a sentinel byte
                // in the original data.
                if !self.chunk_overflow {
                    if MAX_LEN != 0 && self.output.len() == MAX_LEN {
                        return DecoderReadResult::FrameOverflow;
                    }
                    self.output.put_u8(SENTINEL);
                }
                self.update(offset);
            } else {
                // The sentinel terminates the frame: hand out the collected
                // bytes and leave a fresh buffer behind.
                let capacity = decode_buffer_cap(MAX_LEN);
                let new_output = BytesMut::with_capacity(capacity);
                let frame = std::mem::replace(&mut self.output, new_output);
                return DecoderReadResult::Frame(frame);
            }
        }
    }
}
/// Top-level state of a `Decoder`.
#[derive(Debug)]
enum DecoderState {
/// Waiting for the first code byte of a frame.
Initial,
/// In the middle of a frame.
Reading(DecoderReadingState),
/// Entered after a frame overflow: input is discarded up to and including
/// the next sentinel byte.
Lost,
}
/// Incremental frame decoder.
///
/// `SENTINEL` is the byte that terminates each frame. `MAX_LEN` is the largest
/// accepted decoded frame length in bytes; 0 means no limit is checked.
#[derive(Debug)]
pub struct Decoder<
const SENTINEL: u8 = DEFAULT_SENTINEL,
const MAX_LEN: usize = DEFAULT_MAX_LEN,
> {
// Where the decoder currently is within the byte stream.
state: DecoderState,
}
impl<const SENTINEL: u8, const MAX_LEN: usize> Decoder<SENTINEL, MAX_LEN> {
    /// Creates a decoder that expects the start of a frame.
    pub fn new() -> Self {
        let state = DecoderState::Initial;
        Self { state }
    }
}
impl<const SENTINEL: u8, const MAX_LEN: usize> Default
for Decoder<SENTINEL, MAX_LEN>
{
fn default() -> Self {
Self::new()
}
}
/// Errors reported while decoding frames.
#[derive(thiserror::Error, Debug)]
pub enum DecodeError {
/// An I/O error, converted via `From<std::io::Error>`.
#[error(transparent)]
Io(#[from] std::io::Error),
/// A sentinel byte was found where the first code byte of a frame was
/// expected.
#[error("missing frame")]
MissingFrame,
/// The sentinel byte appeared inside the data bytes of a chunk.
#[error("unexpected sentinel")]
UnexpectedSentinel,
/// The decoded frame would be longer than `MAX_LEN` bytes.
#[error("frame overflow")]
FrameOverflow,
}
impl<const SENTINEL: u8, const MAX_LEN: usize> tokio_util::codec::Decoder
    for Decoder<SENTINEL, MAX_LEN>
{
    type Item = BytesMut;
    type Error = DecodeError;

    /// Tries to decode one frame from `src`.
    ///
    /// Returns `Ok(Some(frame))` once a terminating sentinel has been read and
    /// `Ok(None)` when more input is needed.
    ///
    /// # Errors
    ///
    /// * `MissingFrame` if a frame starts with the sentinel byte,
    /// * `UnexpectedSentinel` if the sentinel appears inside chunk data; the
    ///   decoder then waits for the start of a new frame,
    /// * `FrameOverflow` if a non-zero `MAX_LEN` would be exceeded; following
    ///   calls discard input through the next sentinel before decoding again.
    fn decode(
        &mut self,
        src: &mut BytesMut,
    ) -> Result<Option<BytesMut>, Self::Error> {
        loop {
            if let DecoderState::Initial = self.state {
                src.reserve(max_encoded_len(decode_buffer_cap(MAX_LEN)));
                if src.is_empty() {
                    return Ok(None);
                }
                // The first byte of a frame must be a code byte.
                match decode_len::<SENTINEL>(src.get_u8()) {
                    Some(offset) => {
                        self.state = DecoderState::Reading(
                            DecoderReadingState::new::<MAX_LEN>(offset),
                        );
                    }
                    None => return Err(DecodeError::MissingFrame),
                }
            }
            match &mut self.state {
                DecoderState::Initial => unreachable!(),
                DecoderState::Reading(reader) => {
                    let outcome = reader.read::<SENTINEL, MAX_LEN>(src);
                    return match outcome {
                        DecoderReadResult::NeedMoreData => Ok(None),
                        DecoderReadResult::Frame(frame) => {
                            self.state = DecoderState::Initial;
                            Ok(Some(frame))
                        }
                        DecoderReadResult::UnexpectedSentinel => {
                            self.state = DecoderState::Initial;
                            Err(DecodeError::UnexpectedSentinel)
                        }
                        DecoderReadResult::FrameOverflow => {
                            self.state = DecoderState::Lost;
                            Err(DecodeError::FrameOverflow)
                        }
                    };
                }
                DecoderState::Lost => {
                    match src.iter().position(|byte| *byte == SENTINEL) {
                        Some(index) => {
                            // Drop everything through the sentinel and start
                            // over with the next frame.
                            let _ = src.split_to(index + 1);
                            let total_capacity =
                                max_encoded_len(decode_buffer_cap(MAX_LEN));
                            src.reserve(
                                total_capacity.saturating_sub(src.len()),
                            );
                            self.state = DecoderState::Initial;
                        }
                        None => {
                            // No frame boundary yet: nothing here is usable.
                            src.clear();
                            return Ok(None);
                        }
                    }
                }
            }
        }
    }
}
/// Pairs an `Encoder` and a `Decoder` in a single value.
///
/// The sentinel byte and the maximum length are configured separately for the
/// encoding and the decoding direction.
#[derive(Debug)]
pub struct Codec<
const SENTINEL_ENCODE: u8 = DEFAULT_SENTINEL,
const SENTINEL_DECODE: u8 = DEFAULT_SENTINEL,
const MAX_LEN_ENCODE: usize = DEFAULT_MAX_LEN,
const MAX_LEN_DECODE: usize = DEFAULT_MAX_LEN,
> {
// Handles outgoing frames.
encoder: Encoder<SENTINEL_ENCODE, MAX_LEN_ENCODE>,
// Handles incoming frames.
decoder: Decoder<SENTINEL_DECODE, MAX_LEN_DECODE>,
}
impl<
        const SENTINEL_ENCODE: u8,
        const SENTINEL_DECODE: u8,
        const MAX_LEN_ENCODE: usize,
        const MAX_LEN_DECODE: usize,
    > Codec<SENTINEL_ENCODE, SENTINEL_DECODE, MAX_LEN_ENCODE, MAX_LEN_DECODE>
{
    /// Creates a codec from a new encoder and a new decoder.
    pub fn new() -> Self {
        let encoder = Encoder::new();
        let decoder = Decoder::new();
        Self { encoder, decoder }
    }
}
impl<
const SENTINEL_ENCODE: u8,
const SENTINEL_DECODE: u8,
const MAX_LEN_ENCODE: usize,
const MAX_LEN_DECODE: usize,
> Default
for Codec<SENTINEL_ENCODE, SENTINEL_DECODE, MAX_LEN_ENCODE, MAX_LEN_DECODE>
{
fn default() -> Self {
Self::new()
}
}
impl<
        const SENTINEL_ENCODE: u8,
        const SENTINEL_DECODE: u8,
        const MAX_LEN_ENCODE: usize,
        const MAX_LEN_DECODE: usize,
        T: AsRef<[u8]>,
    > tokio_util::codec::Encoder<T>
    for Codec<SENTINEL_ENCODE, SENTINEL_DECODE, MAX_LEN_ENCODE, MAX_LEN_DECODE>
{
    type Error = std::io::Error;

    /// Encodes `item` into `dst` by delegating to the wrapped encoder.
    fn encode(
        &mut self,
        item: T,
        dst: &mut BytesMut,
    ) -> Result<(), Self::Error> {
        // Fully-qualified call: names the trait method explicitly instead of
        // relying on method lookup.
        tokio_util::codec::Encoder::encode(&mut self.encoder, item, dst)
    }
}
impl<
        const SENTINEL_ENCODE: u8,
        const SENTINEL_DECODE: u8,
        const MAX_LEN_ENCODE: usize,
        const MAX_LEN_DECODE: usize,
    > tokio_util::codec::Decoder
    for Codec<SENTINEL_ENCODE, SENTINEL_DECODE, MAX_LEN_ENCODE, MAX_LEN_DECODE>
{
    type Item = BytesMut;
    type Error = DecodeError;

    /// Decodes the next frame from `src` by delegating to the wrapped decoder.
    fn decode(
        &mut self,
        src: &mut BytesMut,
    ) -> Result<Option<BytesMut>, Self::Error> {
        // Fully-qualified call: names the trait method explicitly instead of
        // relying on method lookup.
        tokio_util::codec::Decoder::decode(&mut self.decoder, src)
    }
}