use core::marker::PhantomData;
use core::mem;
use core::pin::Pin;
use core::task::Context;
use std::error::Error as StdError;
use futures::task::Poll;
use futures::{Sink, Stream};
use pin_project_lite::pin_project;
use serde::de::DeserializeOwned;
use serde::ser::Serialize;
use crate::compress::{
compress_from,
decompress_into,
CompressError,
DecompressError,
};
use crate::serde::{
deserialize,
serialize_into,
DeserializeError,
SerializeError,
};
use crate::Either;
/// Tells the encoder whether a value should be compressed.
///
/// The blanket [`Encode`] implementation in this file consults this before
/// writing: `true` selects the serialize-then-compress path, `false` the
/// plain serialize path.
pub trait ShouldCompress {
/// Returns `true` if this value should be compressed when it is encoded.
fn should_compress(&self) -> bool;
}
/// The unit type never asks for compression.
impl ShouldCompress for () {
#[inline(always)]
fn should_compress(&self) -> bool {
// Unconditionally opt out of compression.
false
}
}
/// A `Result` defers the compression decision to whichever variant it holds.
impl<T, E> ShouldCompress for Result<T, E>
where
    T: ShouldCompress,
    E: ShouldCompress,
{
    /// Returns the answer given by the contained `Ok` value or `Err` value.
    #[inline]
    fn should_compress(&self) -> bool {
        self.as_ref().map_or_else(
            |err| err.should_compress(),
            |value| value.should_compress(),
        )
    }
}
/// A value that can be written into a byte buffer.
///
/// Every implementor must also implement [`ShouldCompress`].
pub trait Encode: ShouldCompress {
/// Error produced when encoding fails.
type Error: StdError;
/// Encodes `self` into `buf`.
///
/// `other_buf` is a second, caller-provided buffer. The blanket
/// implementation in this file clears it and uses it as scratch space for
/// the serialized bytes before compressing them into `buf`.
fn encode(
&self,
buf: &mut Vec<u8>,
other_buf: &mut Vec<u8>,
) -> Result<(), Self::Error>;
}
/// A value that can be reconstructed from a byte slice.
///
/// Every implementor must also implement [`Encode`].
pub trait Decode: Encode + Sized {
/// Error produced when decoding fails.
type Error: StdError;
/// Decodes a value from `buf`.
///
/// `other_buf` is a second, caller-provided buffer. The blanket
/// implementation in this file clears it and uses it to hold decompressed
/// bytes when the input is marked as compressed.
fn decode(
buf: &[u8],
other_buf: &mut Vec<u8>,
) -> Result<Self, <Self as Decode>::Error>;
}
/// Blanket [`Encode`] implementation for every serializable type that states
/// its compression preference.
///
/// Output layout: one flag byte (`1` = compressed, `0` = not compressed)
/// followed by the payload.
impl<T: Serialize + ShouldCompress> Encode for T {
    type Error = EncodeError;

    /// Appends the flag byte and the payload to `buf`.
    ///
    /// When compression is requested, `aux` is cleared and receives the
    /// serialized bytes, which are then compressed into `buf`. Otherwise the
    /// value is serialized straight into `buf` and `aux` is left untouched.
    ///
    /// # Errors
    ///
    /// Returns an [`EncodeError`] if serialization or compression fails. The
    /// flag byte has already been pushed onto `buf` at that point.
    #[inline]
    fn encode(
        &self,
        buf: &mut Vec<u8>,
        aux: &mut Vec<u8>,
    ) -> Result<(), Self::Error> {
        let compressed = self.should_compress();
        // The flag byte mirrors the decision: `true` -> 1, `false` -> 0.
        buf.push(u8::from(compressed));

        if compressed {
            aux.clear();
            serialize_into(self, aux).map_err(EncodeError::serialize)?;
            compress_from(aux, buf).map_err(EncodeError::compress)?;
        } else {
            serialize_into(self, buf).map_err(EncodeError::serialize)?;
        }
        Ok(())
    }
}
/// Blanket [`Decode`] implementation for every encodable type that can be
/// deserialized into an owned value.
///
/// Expects the layout written by the blanket [`Encode`] implementation: one
/// flag byte (`0` = not compressed, `1` = compressed) followed by the payload.
impl<T: Encode + DeserializeOwned> Decode for T {
    type Error = DecodeError;

    /// Decodes a value from `buf`.
    ///
    /// For a compressed payload, `aux` is cleared and receives the
    /// decompressed bytes before deserialization. For an uncompressed payload
    /// `aux` is left untouched.
    ///
    /// # Errors
    ///
    /// Returns a [`DecodeError`] if `buf` is empty, if the flag byte is
    /// neither `0` nor `1`, or if decompression or deserialization fails.
    #[inline]
    fn decode(
        buf: &[u8],
        aux: &mut Vec<u8>,
    ) -> Result<Self, <Self as Decode>::Error> {
        let (flag, payload) = match buf.split_first() {
            Some((&flag, payload)) => (flag, payload),
            None => return Err(DecodeError::empty()),
        };

        match flag {
            0 => deserialize(payload).map_err(DecodeError::deserialize),
            1 => {
                aux.clear();
                decompress_into(payload, aux)
                    .map_err(DecodeError::decompress)?;
                deserialize(aux).map_err(DecodeError::deserialize)
            }
            invalid => Err(DecodeError::invalid_byte(invalid)),
        }
    }
}
pin_project! {
/// Stream adapter that decodes every byte chunk produced by the wrapped
/// stream into a value of type `T`.
pub struct DecodeStream<S, T> {
// Scratch buffer handed to `T::decode` as its second argument.
buf: Vec<u8>,
// The wrapped stream of byte chunks.
#[pin]
inner: S,
// Marks the item type this stream decodes into; stores no data.
_phantom: PhantomData<T>,
}
}
impl<S, T> DecodeStream<S, T> {
    /// Wraps `stream`, starting with an empty scratch buffer.
    #[inline]
    pub fn new(stream: S) -> Self {
        let buf = Vec::new();
        Self { buf, inner: stream, _phantom: PhantomData }
    }

    /// Reinterprets this stream as one that decodes items of type `U`
    /// instead of `T`, without touching the wrapped stream or the buffer.
    #[inline]
    pub fn with_type<U>(&mut self) -> &mut DecodeStream<S, U> {
        // SAFETY: the two instantiations differ only in the zero-sized
        // `PhantomData` marker, so the cast relies on them sharing a layout.
        // NOTE(review): the struct declares no explicit `repr`, so confirm
        // that this layout equivalence is actually guaranteed.
        unsafe {
            mem::transmute::<&mut Self, &mut DecodeStream<S, U>>(self)
        }
    }
}
/// Yields one decoded `T` for every byte chunk the wrapped stream produces.
///
/// Errors from the wrapped stream are reported as `Either::Left`, decoding
/// errors as `Either::Right`.
impl<S, T, Bytes, InnerError> Stream for DecodeStream<S, T>
where
    S: Stream<Item = Result<Bytes, InnerError>>,
    T: Decode,
    Bytes: AsRef<[u8]>,
{
    type Item = Result<T, Either<InnerError, <T as Decode>::Error>>;

    /// Polls the wrapped stream and decodes the chunk it yields, if any.
    ///
    /// `Pending` and end-of-stream are forwarded unchanged.
    #[inline(always)]
    fn poll_next(
        self: Pin<&mut Self>,
        ctx: &mut Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        let this = self.project();

        // First stage: forward "not ready" and "finished" as they are.
        let item = match this.inner.poll_next(ctx) {
            Poll::Pending => return Poll::Pending,
            Poll::Ready(None) => return Poll::Ready(None),
            Poll::Ready(Some(item)) => item,
        };

        // Second stage: pass upstream errors through, decode successful chunks.
        let decoded = match item {
            Ok(bytes) => {
                T::decode(bytes.as_ref(), this.buf).map_err(Either::Right)
            }
            Err(err) => Err(Either::Left(err)),
        };
        Poll::Ready(Some(decoded))
    }
}
pin_project! {
/// Sink adapter that encodes every item of type `T` into a `Vec<u8>` and
/// forwards the bytes to the wrapped sink.
pub struct EncodeSink<S, T> {
// Scratch buffer handed to `Encode::encode` as its second argument.
buf: Vec<u8>,
// The wrapped sink that receives the encoded bytes.
#[pin]
inner: S,
// Marks the item type this sink accepts; stores no data.
_phantom: PhantomData<T>,
}
}
impl<S, T> EncodeSink<S, T> {
    /// Wraps `sink`, starting with an empty scratch buffer.
    #[inline]
    pub fn new(sink: S) -> Self {
        let buf = Vec::new();
        Self { buf, inner: sink, _phantom: PhantomData }
    }

    /// Reinterprets this sink as one that accepts items of type `U` instead
    /// of `T`, without touching the wrapped sink or the buffer.
    #[inline]
    pub fn with_type<U>(&mut self) -> &mut EncodeSink<S, U> {
        // SAFETY: the two instantiations differ only in the zero-sized
        // `PhantomData` marker, so the cast relies on them sharing a layout.
        // NOTE(review): the struct declares no explicit `repr`, so confirm
        // that this layout equivalence is actually guaranteed.
        unsafe { mem::transmute::<&mut Self, &mut EncodeSink<S, U>>(self) }
    }
}
/// Accepts items of type `T`, encodes each into a fresh `Vec<u8>` and hands
/// the bytes to the wrapped sink.
///
/// Errors from the wrapped sink are reported as `Either::Left`, encoding
/// errors as `Either::Right`.
impl<S, T> Sink<T> for EncodeSink<S, T>
where
    T: Encode,
    S: Sink<Vec<u8>>,
{
    type Error = Either<S::Error, T::Error>;

    /// Delegates readiness to the wrapped sink.
    #[inline(always)]
    fn poll_ready(
        self: Pin<&mut Self>,
        ctx: &mut Context<'_>,
    ) -> Poll<Result<(), Self::Error>> {
        let inner = self.project().inner;
        inner.poll_ready(ctx).map_err(Either::Left)
    }

    /// Encodes `item` and starts sending the resulting bytes.
    ///
    /// If encoding fails, nothing is passed to the wrapped sink.
    #[inline(always)]
    fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> {
        let this = self.project();

        // The wrapped sink takes ownership of the bytes, so each item gets
        // its own output vector; `this.buf` only serves as encoder scratch.
        let mut encoded = Vec::new();
        if let Err(err) = item.encode(&mut encoded, this.buf) {
            return Err(Either::Right(err));
        }

        match this.inner.start_send(encoded) {
            Ok(()) => Ok(()),
            Err(err) => Err(Either::Left(err)),
        }
    }

    /// Delegates flushing to the wrapped sink.
    #[inline(always)]
    fn poll_flush(
        self: Pin<&mut Self>,
        ctx: &mut Context<'_>,
    ) -> Poll<Result<(), Self::Error>> {
        let inner = self.project().inner;
        inner.poll_flush(ctx).map_err(Either::Left)
    }

    /// Delegates closing to the wrapped sink.
    #[inline(always)]
    fn poll_close(
        self: Pin<&mut Self>,
        ctx: &mut Context<'_>,
    ) -> Poll<Result<(), Self::Error>> {
        let inner = self.project().inner;
        inner.poll_close(ctx).map_err(Either::Left)
    }
}
/// Error returned when encoding a value fails.
///
/// Its `Display` output is prefixed with `Encoding failed: `.
#[derive(Debug)]
pub struct EncodeError {
// Which step failed; kept private so the variants are not part of the API.
kind: EncodeErrorKind,
}
impl EncodeError {
    /// Builds an error for a failed compression step.
    #[inline]
    pub(crate) fn compress(err: CompressError) -> Self {
        let kind = EncodeErrorKind::Compress(err);
        Self { kind }
    }

    /// Builds an error for a failed serialization step.
    #[inline]
    pub(crate) fn serialize(err: SerializeError) -> Self {
        let kind = EncodeErrorKind::Serialize(err);
        Self { kind }
    }
}
/// Renders as `Encoding failed: ` followed by the underlying error's message.
impl core::fmt::Display for EncodeError {
    #[inline]
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        match &self.kind {
            EncodeErrorKind::Compress(err) => {
                write!(f, "Encoding failed: {err}")
            }
            EncodeErrorKind::Serialize(err) => {
                write!(f, "Encoding failed: {err}")
            }
        }
    }
}
// Uses the trait's default methods only; the underlying error is surfaced
// through the `Display` text rather than an overridden `source()`.
impl StdError for EncodeError {}
/// The step of encoding that failed.
#[derive(Debug)]
enum EncodeErrorKind {
/// Compressing the serialized bytes failed.
Compress(CompressError),
/// Serializing the value failed.
Serialize(SerializeError),
}
/// Error returned when decoding a value fails.
///
/// Its `Display` output is prefixed with `Decoding failed: `.
#[derive(Debug)]
pub struct DecodeError {
// What went wrong; kept private so the variants are not part of the API.
kind: DecodeErrorKind,
}
impl DecodeError {
    /// Builds an error for an input buffer that contains no bytes.
    #[inline]
    pub(crate) fn empty() -> Self {
        let kind = DecodeErrorKind::EmptyBuffer;
        Self { kind }
    }

    /// Builds an error for a failed decompression step.
    #[inline]
    pub(crate) fn decompress(err: DecompressError) -> Self {
        let kind = DecodeErrorKind::Decompress(err);
        Self { kind }
    }

    /// Builds an error for a failed deserialization step.
    #[inline]
    pub(crate) fn deserialize(err: DeserializeError) -> Self {
        let kind = DecodeErrorKind::Deserialize(err);
        Self { kind }
    }

    /// Builds an error for a leading flag byte that is not recognized.
    #[inline]
    pub(crate) fn invalid_byte(byte: u8) -> Self {
        let kind = DecodeErrorKind::InvalidFirstByte(InvalidFirstByte(byte));
        Self { kind }
    }
}
/// The reason decoding failed.
#[derive(Debug)]
enum DecodeErrorKind {
/// Decompressing the payload failed.
Decompress(DecompressError),
/// Deserializing the payload failed.
Deserialize(DeserializeError),
/// The input buffer held no bytes, so there was no flag byte to read.
EmptyBuffer,
/// The leading flag byte was neither `0` nor `1`.
InvalidFirstByte(InvalidFirstByte),
}
/// Renders as `Decoding failed: ` followed by a description of the cause.
impl core::fmt::Display for DecodeError {
    #[inline]
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        match &self.kind {
            DecodeErrorKind::Decompress(err) => {
                write!(f, "Decoding failed: {err}")
            }
            DecodeErrorKind::Deserialize(err) => {
                write!(f, "Decoding failed: {err}")
            }
            DecodeErrorKind::EmptyBuffer => {
                // No wrapped error exists here, so a fixed description is used.
                let err = "buffer is empty";
                write!(f, "Decoding failed: {err}")
            }
            DecodeErrorKind::InvalidFirstByte(err) => {
                write!(f, "Decoding failed: {err}")
            }
        }
    }
}
/// Holds a leading flag byte that was neither `0` nor `1`.
#[derive(Debug)]
struct InvalidFirstByte(u8);

/// Reports the offending byte together with the accepted values.
impl core::fmt::Display for InvalidFirstByte {
    #[inline]
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let Self(byte) = self;
        f.write_fmt(format_args!(
            "invalid first byte {}, expected 0 or 1",
            byte
        ))
    }
}
// Uses the trait's default methods only; the underlying error is surfaced
// through the `Display` text rather than an overridden `source()`.
impl StdError for DecodeError {}