use std::{
fmt,
io::{self, Cursor},
};
use bytes::{Buf, BufMut, Bytes, BytesMut};
use tokio_util::codec::{Decoder, Encoder};
use crate::shared::channels::PROTOCOL_HEADER_LEN;
use super::{RELIABLE_FRAME_LENGTH_FIELD_LEN, RELIABLE_FRAME_TOTAL_HEADER_LEN};
/// Progress of the decoder through the frame currently being read.
#[derive(Clone, Copy, Debug)]
enum DecodeState {
    /// Waiting until the length field and protocol header are buffered.
    Head,
    /// Length field consumed; carries the announced length of the bytes that follow it.
    Data(usize),
}
/// Error reported by the codecs in this file when a frame is too large.
///
/// The private unit field keeps the type opaque: it can only be constructed inside this module.
pub struct QuinnetProtocolCodecError {
    _priv: (),
}

impl fmt::Debug for QuinnetProtocolCodecError {
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        let builder = formatter.debug_struct("QuinnetProtocolCodecError");
        let mut builder = builder;
        builder.finish()
    }
}

impl fmt::Display for QuinnetProtocolCodecError {
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(formatter, "frame size too big")
    }
}

/// Marker impl so the error can be boxed into `io::Error` and other `dyn Error` contexts.
impl std::error::Error for QuinnetProtocolCodecError {}
/// Encoder that prefixes each outgoing frame with a length field and a channel id.
#[derive(Debug)]
pub struct QuinnetProtocolCodecEncoder {
    /// Largest frame payload, in bytes, that the encoder accepts.
    max_frame_len: usize,
    /// Channel identifier written after the length field of every frame.
    raw_channel_id: u8,
}

impl QuinnetProtocolCodecEncoder {
    /// Builds an encoder for channel `raw_channel_id` that rejects frames longer than
    /// `max_frame_len` bytes.
    pub fn new(raw_channel_id: u8, max_frame_len: usize) -> Self {
        Self {
            max_frame_len,
            raw_channel_id,
        }
    }
}
impl Encoder<Bytes> for QuinnetProtocolCodecEncoder {
    type Error = io::Error;

    /// Appends `frame` to `dst` as `[length field][channel id][frame bytes]`.
    ///
    /// The length field is `RELIABLE_FRAME_LENGTH_FIELD_LEN` bytes wide (written with
    /// `BufMut::put_uint`, i.e. big-endian). It counts the protocol header plus the frame bytes.
    ///
    /// # Errors
    /// Returns `io::ErrorKind::InvalidInput` wrapping [`QuinnetProtocolCodecError`] in two cases:
    /// - `frame` is longer than `max_frame_len`.
    /// - The encoded length cannot be represented in the length field.
    ///
    /// Nothing is written to `dst` in either case.
    fn encode(&mut self, frame: Bytes, dst: &mut BytesMut) -> Result<(), io::Error> {
        let too_big = || {
            io::Error::new(
                io::ErrorKind::InvalidInput,
                QuinnetProtocolCodecError { _priv: () },
            )
        };

        if frame.len() > self.max_frame_len {
            return Err(too_big());
        }

        let encoded_len = PROTOCOL_HEADER_LEN as u64 + frame.len() as u64;

        // `put_uint` keeps only the low-order `RELIABLE_FRAME_LENGTH_FIELD_LEN` bytes.
        // An oversized length would therefore be silently truncated and desynchronize the
        // peer's decoder, so reject it up front. `checked_shr` returns `None` when the field
        // is 8 bytes wide, in which case every `u64` fits.
        let length_field_bits = (8 * RELIABLE_FRAME_LENGTH_FIELD_LEN) as u32;
        if matches!(encoded_len.checked_shr(length_field_bits), Some(high_bits) if high_bits != 0) {
            return Err(too_big());
        }

        dst.reserve(RELIABLE_FRAME_TOTAL_HEADER_LEN + frame.len());
        dst.put_uint(encoded_len, RELIABLE_FRAME_LENGTH_FIELD_LEN);
        dst.put_u8(self.raw_channel_id);
        dst.extend_from_slice(&frame[..]);
        Ok(())
    }
}
#[derive(Debug)]
pub struct QuinnetProtocolCodecDecoder {
state: DecodeState,
max_frame_len: usize,
}
impl QuinnetProtocolCodecDecoder {
pub fn new(max_frame_len: usize) -> Self {
Self {
max_frame_len,
state: DecodeState::Head,
}
}
fn decode_head(&mut self, src: &mut BytesMut) -> io::Result<Option<usize>> {
if src.len() < RELIABLE_FRAME_TOTAL_HEADER_LEN {
return Ok(None);
}
let payload_length = {
let mut src = Cursor::new(&mut *src);
let payload_length = src.get_uint(RELIABLE_FRAME_LENGTH_FIELD_LEN);
if payload_length > self.max_frame_len as u64 {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
QuinnetProtocolCodecError { _priv: () },
));
}
payload_length as usize
};
src.advance(RELIABLE_FRAME_LENGTH_FIELD_LEN);
src.reserve(payload_length.saturating_sub(src.len()));
Ok(Some(payload_length))
}
fn decode_data(&self, payload_length: usize, src: &mut BytesMut) -> Option<BytesMut> {
if src.len() < payload_length {
return None;
}
Some(src.split_to(payload_length))
}
}
impl Decoder for QuinnetProtocolCodecDecoder {
    type Item = BytesMut;
    type Error = io::Error;

    /// Yields the next complete frame from `src`, or `Ok(None)` if more bytes are needed.
    ///
    /// A parsed length is remembered between calls (`DecodeState::Data`). The length field is
    /// therefore read once per frame, even when the body arrives over several reads. Errors from
    /// header parsing are propagated unchanged.
    fn decode(&mut self, src: &mut BytesMut) -> io::Result<Option<BytesMut>> {
        let payload_length = match self.state {
            DecodeState::Data(pending_length) => pending_length,
            DecodeState::Head => {
                let parsed_length = match self.decode_head(src)? {
                    Some(length) => length,
                    None => return Ok(None),
                };
                self.state = DecodeState::Data(parsed_length);
                parsed_length
            }
        };

        let frame = match self.decode_data(payload_length, src) {
            Some(frame) => frame,
            // Body incomplete: stay in `Data` and wait for more bytes.
            None => return Ok(None),
        };

        // Frame complete: go back to reading a header and make room for the next one.
        self.state = DecodeState::Head;
        src.reserve(RELIABLE_FRAME_TOTAL_HEADER_LEN.saturating_sub(src.len()));
        Ok(Some(frame))
    }
}