use std::io;
use bytes::{Bytes, BytesMut};
use tokio_util::codec::{Decoder, Encoder, LengthDelimitedCodec};
use crate::byte_order::read_network_u32;
pub mod error;
pub mod recovery;
pub use error::{CodecError, EofError, FramingError, ProtocolError};
pub use recovery::{
CodecErrorContext,
DefaultRecoveryPolicy,
RecoveryConfig,
RecoveryPolicy,
RecoveryPolicyHook,
};
/// Smallest frame-length limit that [`clamp_frame_length`] will return.
pub const MIN_FRAME_LENGTH: usize = 64;

/// Largest frame-length limit that [`clamp_frame_length`] will return (16 MiB).
pub const MAX_FRAME_LENGTH: usize = 16 * 1024 * 1024;

/// Restricts `value` to the inclusive range
/// [`MIN_FRAME_LENGTH`]..=[`MAX_FRAME_LENGTH`].
///
/// Values below the minimum are raised to it, values above the maximum are
/// lowered to it, and anything in between is returned unchanged.
pub(crate) fn clamp_frame_length(value: usize) -> usize {
    usize::clamp(value, MIN_FRAME_LENGTH, MAX_FRAME_LENGTH)
}
#[doc(hidden)]
pub mod examples;
/// Describes a framing scheme: which decoder and encoder turn a byte stream
/// into frames and back, and how a frame's payload is accessed.
///
/// Implementors must be `Send + Sync + Clone + 'static`.
pub trait FrameCodec: Send + Sync + Clone + 'static {
/// Frame type yielded by `Self::Decoder` and accepted by `Self::Encoder`.
type Frame: Send + Sync + 'static;
/// Decoder returned by [`FrameCodec::decoder`]; yields `Self::Frame` and
/// reports failures as `io::Error`.
type Decoder: Decoder<Item = Self::Frame, Error = io::Error> + Send;
/// Encoder returned by [`FrameCodec::encoder`]; accepts `Self::Frame` and
/// reports failures as `io::Error`.
type Encoder: Encoder<Self::Frame, Error = io::Error> + Send;
/// Returns a decoder for this codec.
fn decoder(&self) -> Self::Decoder;
/// Returns an encoder for this codec.
fn encoder(&self) -> Self::Encoder;
/// Borrows the payload bytes carried by `frame`.
fn frame_payload(frame: &Self::Frame) -> &[u8];
/// Returns the payload of `frame` as owned `Bytes`.
///
/// The default implementation copies the slice returned by
/// [`FrameCodec::frame_payload`]; implementors may override it when a
/// cheaper conversion is available.
fn frame_payload_bytes(frame: &Self::Frame) -> Bytes {
Bytes::copy_from_slice(Self::frame_payload(frame))
}
/// Builds a frame that carries `payload`.
fn wrap_payload(&self, payload: Bytes) -> Self::Frame;
/// Returns a correlation identifier for `frame`, if it has one.
///
/// The default implementation always returns `None`.
fn correlation_id(_frame: &Self::Frame) -> Option<u64> { None }
/// Returns the maximum frame length configured for this codec.
fn max_frame_length(&self) -> usize;
}
/// [`FrameCodec`] implementation backed by tokio-util's
/// `LengthDelimitedCodec`, using `Bytes` as its frame type.
#[derive(Clone, Debug)]
pub struct LengthDelimitedFrameCodec {
/// Frame-length limit passed to the inner codec builder and enforced by the
/// encoder before writing.
max_frame_length: usize,
}
impl LengthDelimitedFrameCodec {
    /// Creates a codec whose frame-length limit is `max_frame_length` after
    /// it has been passed through [`clamp_frame_length`].
    #[must_use]
    pub fn new(max_frame_length: usize) -> Self {
        let max_frame_length = clamp_frame_length(max_frame_length);
        Self { max_frame_length }
    }

    /// Returns the frame-length limit stored in this codec.
    #[must_use]
    pub fn max_frame_length(&self) -> usize {
        self.max_frame_length
    }

    /// Builds a fresh tokio-util codec configured with this codec's
    /// frame-length limit; every other builder option keeps its default.
    fn new_inner_codec(&self) -> LengthDelimitedCodec {
        let limit = self.max_frame_length;
        let mut builder = LengthDelimitedCodec::builder();
        builder.max_frame_length(limit).new_codec()
    }
}
impl Default for LengthDelimitedFrameCodec {
    /// Builds a codec with a 1024-byte frame-length limit.
    ///
    /// The limit is stored directly rather than going through
    /// `LengthDelimitedFrameCodec::new`.
    fn default() -> Self {
        const DEFAULT_MAX_FRAME_LENGTH: usize = 1024;

        Self {
            max_frame_length: DEFAULT_MAX_FRAME_LENGTH,
        }
    }
}
/// Number of leading buffer bytes interpreted as the length prefix when
/// building EOF diagnostics; also reported as `header_size` in
/// `EofError::MidHeader`.
// NOTE(review): the inner codec is built without an explicit length-field
// width, so this presumably mirrors the tokio-util builder default — confirm.
pub const LENGTH_HEADER_SIZE: usize = 4;
/// Decoder half of [`LengthDelimitedFrameCodec`]: wraps a tokio-util
/// `LengthDelimitedCodec` and yields frozen `Bytes` frames.
#[doc(hidden)]
pub struct LengthDelimitedDecoder {
/// Underlying codec that performs the actual length-prefixed decoding.
inner: LengthDelimitedCodec,
}
impl Decoder for LengthDelimitedDecoder {
    type Item = Bytes;
    type Error = io::Error;

    /// Delegates to the inner codec and freezes any decoded frame into
    /// immutable `Bytes`.
    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
        Ok(self.inner.decode(src)?.map(BytesMut::freeze))
    }

    /// Decodes whatever remains in `src` once the stream has ended.
    ///
    /// An empty buffer yields `Ok(None)`. A complete frame is returned as
    /// usual. Leftover bytes that do not form a frame produce a structured
    /// EOF error built from a snapshot of the buffer, except that
    /// `InvalidData` errors from the inner codec are passed through as-is.
    ///
    /// NOTE(review): the snapshot is parsed from whatever `src` currently
    /// holds. If an earlier `decode` call already consumed the length header,
    /// the leading bytes here would be payload rather than header — confirm
    /// against the inner codec's buffering behaviour.
    fn decode_eof(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
        if src.is_empty() {
            return Ok(None);
        }

        // Capture the buffer state before the inner codec gets to mutate it.
        let snapshot = EofContext::from_buffer(src);

        let err = match self.inner.decode_eof(src) {
            Ok(Some(frame)) => return Ok(Some(frame.freeze())),
            Ok(None) => return Err(build_eof_error(snapshot)),
            Err(err) => err,
        };

        // `InvalidData` is forwarded unchanged; any other inner error is
        // replaced by the structured EOF error.
        if err.kind() == io::ErrorKind::InvalidData {
            return Err(err);
        }

        tracing::debug!(
            inner_error = %err,
            "inner decoder error at EOF, converting to structured EOF error"
        );
        Err(build_eof_error(snapshot))
    }
}
/// Snapshot of the receive buffer taken at EOF, used to build a descriptive
/// EOF error.
#[derive(Clone, Copy, Debug)]
struct EofContext {
/// Length of the buffer at the time the snapshot was taken.
bytes_received: usize,
/// Value parsed from the first `LENGTH_HEADER_SIZE` bytes of the buffer, or
/// `None` when the buffer was shorter than that (or the value did not fit in
/// `usize`).
expected: Option<usize>,
}
impl EofContext {
    /// Records the current length of `src` together with the length value
    /// parsed from its first `LENGTH_HEADER_SIZE` bytes, if that many bytes
    /// are present.
    fn from_buffer(src: &BytesMut) -> Self {
        Self {
            bytes_received: src.len(),
            expected: Self::declared_length(src),
        }
    }

    /// Parses the leading `LENGTH_HEADER_SIZE` bytes of `buffer` with
    /// `read_network_u32`, returning `None` when the buffer is too short or
    /// the value does not fit in `usize`.
    fn declared_length(buffer: &[u8]) -> Option<usize> {
        let prefix = buffer.get(..LENGTH_HEADER_SIZE)?;
        let header = <[u8; LENGTH_HEADER_SIZE]>::try_from(prefix).ok()?;
        usize::try_from(read_network_u32(header)).ok()
    }
}
fn build_eof_error(context: EofContext) -> io::Error {
match context.expected {
Some(expected) => {
CodecError::Eof(EofError::MidFrame {
bytes_received: context.bytes_received.saturating_sub(LENGTH_HEADER_SIZE),
expected,
})
.into()
}
None => {
CodecError::Eof(EofError::MidHeader {
bytes_received: context.bytes_received,
header_size: LENGTH_HEADER_SIZE,
})
.into()
}
}
}
/// Encoder half of [`LengthDelimitedFrameCodec`]: wraps a tokio-util
/// `LengthDelimitedCodec` and checks payload size before delegating to it.
#[doc(hidden)]
pub struct LengthDelimitedEncoder {
/// Underlying codec that writes the length-prefixed frame.
inner: LengthDelimitedCodec,
/// Largest payload length accepted by `encode`.
max_frame_length: usize,
}
impl Encoder<Bytes> for LengthDelimitedEncoder {
    type Error = io::Error;

    /// Writes `item` into `dst` through the inner codec.
    ///
    /// # Errors
    ///
    /// Returns `FramingError::OversizedFrame` (converted into `io::Error`)
    /// when `item` is longer than the configured maximum, and otherwise
    /// whatever error the inner codec reports.
    fn encode(&mut self, item: Bytes, dst: &mut BytesMut) -> Result<(), Self::Error> {
        let size = item.len();
        let max = self.max_frame_length;

        if size <= max {
            return self.inner.encode(item, dst);
        }

        Err(CodecError::Framing(FramingError::OversizedFrame { size, max }).into())
    }
}
impl FrameCodec for LengthDelimitedFrameCodec {
type Frame = Bytes;
type Decoder = LengthDelimitedDecoder;
type Encoder = LengthDelimitedEncoder;
fn decoder(&self) -> Self::Decoder {
LengthDelimitedDecoder {
inner: self.new_inner_codec(),
}
}
fn encoder(&self) -> Self::Encoder {
LengthDelimitedEncoder {
inner: self.new_inner_codec(),
max_frame_length: self.max_frame_length,
}
}
fn frame_payload(frame: &Self::Frame) -> &[u8] { frame.as_ref() }
fn frame_payload_bytes(frame: &Self::Frame) -> Bytes { frame.clone() }
fn wrap_payload(&self, payload: Bytes) -> Self::Frame { payload }
fn max_frame_length(&self) -> usize { self.max_frame_length }
}
#[cfg(test)]
mod tests;