use std::{cell::Cell, cmp, error::Error as StdError, fmt, io::Cursor};
use ntex_bytes::{Buf, BufMut, Bytes, BytesMut};
use ntex_codec::{Decoder, Encoder};
/// Framing configuration shared by every `LengthDelimitedCodec` it creates
/// (see `Builder::new_codec`).
#[derive(Debug, Clone, Copy)]
pub(super) struct Builder {
    /// Largest frame accepted: compared against the raw length-field value when
    /// decoding and against the payload length when encoding.
    max_frame_len: usize,
    /// Width of the length prefix in bytes (the setter restricts it to 1..=8).
    length_field_len: usize,
    /// Added to the length-field value when decoding; subtracted from the payload
    /// length when encoding.
    length_adjustment: isize,
    /// Number of header bytes dropped before the payload; `None` means "drop just
    /// the length field".
    num_skip: Option<usize>,
}
/// Errors produced while encoding or decoding length-delimited frames.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub(super) enum LengthDelimitedCodecError {
    /// The frame length exceeds the configured maximum frame length.
    MaxSize,
    /// Applying the configured length adjustment overflowed or underflowed.
    Adjusted,
}
/// Codec that splits a byte stream into frames prefixed by a length field, and
/// writes outgoing frames with the same kind of prefix.
#[derive(Debug, Clone)]
pub(super) struct LengthDelimitedCodec {
    /// Framing configuration (field width, max size, adjustment, skip).
    builder: Builder,
    /// Progress through the frame currently being decoded. Kept in a `Cell`
    /// because `Decoder::decode` only receives `&self`.
    state: Cell<DecodeState>,
}
/// Where the decoder currently is within a frame.
#[derive(Debug, Clone, Copy, Default)]
enum DecodeState {
    /// Waiting for a complete length header.
    #[default]
    Head,
    /// Header already consumed; waiting for this many payload bytes.
    Data(usize),
}
impl LengthDelimitedCodec {
    /// Creates a codec with the default `Builder` configuration
    /// (4-byte length field, 8 MiB max frame, no adjustment).
    pub(super) fn new() -> Self {
        Self {
            builder: Builder::new(),
            state: Cell::new(DecodeState::Head),
        }
    }

    /// Returns the configured maximum frame length.
    pub(super) fn max_frame_length(&self) -> usize {
        self.builder.max_frame_len
    }

    /// Changes the maximum frame length used by subsequent encode/decode calls.
    pub(super) fn set_max_frame_length(&mut self, val: usize) {
        self.builder.max_frame_length(val);
    }

    /// Reads the frame header from the front of `src`.
    ///
    /// Returns `Ok(None)` without touching `src` while fewer than
    /// `num_head_bytes()` bytes are buffered. Otherwise it drops `get_num_skip()`
    /// header bytes from `src` and returns the adjusted payload length.
    ///
    /// # Errors
    /// - `MaxSize` if the raw length-field value exceeds `max_frame_len`.
    /// - `Adjusted` if applying `length_adjustment` under- or overflows.
    ///
    /// On error `src` is left unconsumed.
    fn decode_head(&self, src: &mut BytesMut) -> Result<Option<usize>, LengthDelimitedCodecError> {
        let head_len = self.builder.num_head_bytes();
        let field_len = self.builder.length_field_len;

        if src.len() < head_len {
            return Ok(None);
        }

        let n = {
            // Read the length field through a cursor so `src` itself is not advanced yet.
            let mut src = Cursor::new(&mut *src);
            let n = src.get_uint(field_len);
            if n > self.builder.max_frame_len as u64 {
                return Err(LengthDelimitedCodecError::MaxSize);
            }
            // Cannot truncate: `n` is bounded by `max_frame_len`, which is a `usize`.
            let n = n as usize;

            // `unsigned_abs` avoids the overflow that negating `isize::MIN` would cause.
            let adjustment = self.builder.length_adjustment;
            let n = if adjustment < 0 {
                n.checked_sub(adjustment.unsigned_abs())
            } else {
                n.checked_add(adjustment.unsigned_abs())
            };
            n.ok_or(LengthDelimitedCodecError::Adjusted)?
        };

        let num_skip = self.builder.get_num_skip();
        if num_skip > 0 {
            src.advance_to(num_skip);
        }
        Ok(Some(n))
    }

    /// Splits the first `n` bytes off `src` as one frame.
    ///
    /// Returns `None`, leaving `src` untouched, if fewer than `n` bytes are buffered.
    fn decode_data(n: usize, src: &mut BytesMut) -> Option<Bytes> {
        if src.len() < n {
            return None;
        }
        Some(src.split_to(n))
    }
}
impl Decoder for LengthDelimitedCodec {
    type Item = Bytes;
    type Error = LengthDelimitedCodecError;

    /// Yields the next complete frame from `src`, or `Ok(None)` if more bytes are needed.
    ///
    /// Decoding is resumable across calls. Once a header has been consumed, its
    /// payload length is remembered in `self.state`. Later calls then wait only for
    /// the payload instead of re-reading the header.
    fn decode(&self, src: &mut BytesMut) -> Result<Option<Bytes>, LengthDelimitedCodecError> {
        let frame_len = match self.state.get() {
            DecodeState::Data(pending) => pending,
            DecodeState::Head => {
                if let Some(parsed) = self.decode_head(src)? {
                    self.state.set(DecodeState::Data(parsed));
                    parsed
                } else {
                    return Ok(None);
                }
            }
        };

        let frame = Self::decode_data(frame_len, src);
        if frame.is_some() {
            // Frame finished: the next call starts by reading a fresh header.
            self.state.set(DecodeState::Head);
        }
        Ok(frame)
    }
}
impl Encoder for LengthDelimitedCodec {
    type Item = Bytes;
    type Error = LengthDelimitedCodecError;

    /// Appends `data` to `dst`, preceded by a `length_field_len`-byte length header.
    ///
    /// The header carries the payload length with the configured adjustment removed
    /// (the inverse of what the decoder applies).
    ///
    /// # Errors
    /// - `MaxSize` if `data` is longer than `max_frame_len`, or if the header value
    ///   does not fit in `length_field_len` bytes.
    /// - `Adjusted` if removing `length_adjustment` under- or overflows.
    ///
    /// On error nothing is written to `dst`.
    fn encode(&self, data: Bytes, dst: &mut BytesMut) -> Result<(), Self::Error> {
        let len = data.len();
        if len > self.builder.max_frame_len {
            return Err(LengthDelimitedCodecError::MaxSize);
        }

        // `unsigned_abs` avoids the overflow that negating `isize::MIN` would cause.
        let adjustment = self.builder.length_adjustment;
        let n = if adjustment < 0 {
            len.checked_add(adjustment.unsigned_abs())
        } else {
            len.checked_sub(adjustment.unsigned_abs())
        };
        let n = n.ok_or(LengthDelimitedCodecError::Adjusted)?;

        // A `field_len`-byte header cannot represent wider values. Reject them here
        // instead of emitting a truncated (stream-corrupting) length prefix.
        let field_len = self.builder.length_field_len;
        if field_len < 8 && (n as u64) >> (8 * field_len) != 0 {
            return Err(LengthDelimitedCodecError::MaxSize);
        }

        // Reserve for exactly what is written: the header plus the payload bytes.
        dst.reserve(field_len + len);
        dst.put_uint(n as u64, field_len);
        dst.extend_from_slice(&data[..]);
        Ok(())
    }
}
impl Default for LengthDelimitedCodec {
fn default() -> Self {
Self::new()
}
}
impl Builder {
    /// Returns the default configuration:
    /// - 4-byte length field
    /// - 8 MiB maximum frame
    /// - no length adjustment
    /// - only the length field skipped before the payload
    pub(super) fn new() -> Builder {
        // 8 << 20 == 8 * 1024 * 1024 bytes.
        let max_frame_len = 8 << 20;
        Builder {
            length_field_len: 4,
            num_skip: None,
            length_adjustment: 0,
            max_frame_len,
        }
    }

    /// Sets the largest frame length the codec accepts.
    pub(super) fn max_frame_length(&mut self, val: usize) -> &mut Self {
        self.max_frame_len = val;
        self
    }

    /// Sets the width of the length prefix, in bytes.
    ///
    /// # Panics
    /// Panics unless `1 <= val <= 8`.
    pub(super) fn length_field_length(&mut self, val: usize) -> &mut Self {
        assert!((1..=8).contains(&val), "invalid length field length");
        self.length_field_len = val;
        self
    }

    /// Sets the signed adjustment between the length-field value and the payload length.
    pub(super) fn length_adjustment(&mut self, val: isize) -> &mut Self {
        self.length_adjustment = val;
        self
    }

    /// Sets how many header bytes are dropped before the payload.
    pub(super) fn num_skip(&mut self, val: usize) -> &mut Self {
        self.num_skip = Some(val);
        self
    }

    /// Builds a codec that uses a copy of this configuration and starts in the header state.
    pub(super) fn new_codec(&self) -> LengthDelimitedCodec {
        let builder = *self;
        LengthDelimitedCodec {
            state: Cell::new(DecodeState::Head),
            builder,
        }
    }

    /// Bytes that must be buffered before a header can be processed.
    fn num_head_bytes(&self) -> usize {
        let skip = self.num_skip.unwrap_or(0);
        cmp::max(self.length_field_len, skip)
    }

    /// Header bytes to drop before the payload; defaults to the length field width.
    fn get_num_skip(&self) -> usize {
        let default_skip = self.length_field_len;
        self.num_skip.unwrap_or(default_skip)
    }
}
impl Default for Builder {
fn default() -> Self {
Self::new()
}
}
impl fmt::Display for LengthDelimitedCodecError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_str("frame size too big")
}
}
/// Marks the codec error as a standard error so it can be boxed or propagated with `?`.
impl StdError for LengthDelimitedCodecError {
    /// Codec errors are leaf errors with no underlying cause.
    fn source(&self) -> Option<&(dyn StdError + 'static)> {
        None
    }
}