collab-common 0.0.7

Code shared by collab's client and server
Documentation
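//! Encoding and decoding of values to and from bytes, with optional
//! compression, plus [`Stream`] and [`Sink`] adapters that apply them.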

use core::marker::PhantomData;
use core::mem;
use core::pin::Pin;
use core::task::Context;
use std::error::Error as StdError;

use futures::task::Poll;
use futures::{Sink, Stream};
use pin_project_lite::pin_project;
use serde::de::DeserializeOwned;
use serde::ser::Serialize;

use crate::compress::{
    compress_from,
    decompress_into,
    CompressError,
    DecompressError,
};
use crate::serde::{
    deserialize,
    serialize_into,
    DeserializeError,
    SerializeError,
};
use crate::Either;

/// Lets a value decide whether it should be compressed when it is encoded.
///
/// The blanket [`Encode`] implementation consults this to choose between the
/// compressed and the uncompressed wire format.
pub trait ShouldCompress {
    /// Returns `true` if this value should be compressed when encoded.
    fn should_compress(&self) -> bool;
}

/// The unit type never asks to be compressed.
impl ShouldCompress for () {
    #[inline(always)]
    fn should_compress(&self) -> bool {
        // Unconditionally opts out of compression.
        false
    }
}

/// A `Result` defers the decision to whichever value it currently holds, be it
/// the `Ok` payload or the `Err` payload.
impl<T, E> ShouldCompress for Result<T, E>
where
    T: ShouldCompress,
    E: ShouldCompress,
{
    #[inline]
    fn should_compress(&self) -> bool {
        // Borrow the payload and ask the populated variant.
        self.as_ref().map_or_else(
            |failure| failure.should_compress(),
            |success| success.should_compress(),
        )
    }
}

/// Types that can be encoded into a byte buffer.
///
/// Every `Encode` type must also implement [`ShouldCompress`].
pub trait Encode: ShouldCompress {
    /// The error returned when encoding fails.
    type Error: StdError;

    /// Encodes `self` into `buf`.
    ///
    /// `other_buf` is a second buffer made available to the implementation;
    /// the blanket implementation in this module uses it as scratch space for
    /// the uncompressed serialization when the value is compressed.
    ///
    /// # Errors
    ///
    /// Returns [`Encode::Error`] if the value could not be encoded.
    fn encode(
        &self,
        buf: &mut Vec<u8>,
        other_buf: &mut Vec<u8>,
    ) -> Result<(), Self::Error>;
}

/// Types that can be decoded from a slice of bytes.
///
/// Every `Decode` type must also implement [`Encode`].
pub trait Decode: Encode + Sized {
    /// The error returned when decoding fails.
    type Error: StdError;

    /// Decodes a value from `buf`.
    ///
    /// `other_buf` is a second buffer made available to the implementation;
    /// the blanket implementation in this module uses it to hold the
    /// decompressed bytes when the input is marked as compressed.
    ///
    /// # Errors
    ///
    /// Returns [`Decode::Error`] if `buf` could not be decoded.
    fn decode(
        buf: &[u8],
        other_buf: &mut Vec<u8>,
    ) -> Result<Self, <Self as Decode>::Error>;
}

/// Blanket encoder for every serializable type that knows whether it wants to
/// be compressed.
///
/// The output starts with a one-byte tag pushed onto `buf`: `0` when the value
/// is serialized as-is, `1` when the serialized bytes are compressed.
impl<T: Serialize + ShouldCompress> Encode for T {
    type Error = EncodeError;

    #[inline]
    fn encode(
        &self,
        buf: &mut Vec<u8>,
        aux: &mut Vec<u8>,
    ) -> Result<(), Self::Error> {
        // Fast path: no compression, serialize straight after the tag byte.
        if !self.should_compress() {
            buf.push(0);
            serialize_into(self, buf).map_err(EncodeError::serialize)?;
            return Ok(());
        }

        // Compressed path: serialize into the (cleared) scratch buffer first,
        // then compress the scratch buffer into the output.
        buf.push(1);
        aux.clear();
        serialize_into(self, aux).map_err(EncodeError::serialize)?;
        compress_from(aux, buf).map_err(EncodeError::compress)?;
        Ok(())
    }
}

/// Blanket decoder for every encodable type that can be deserialized into an
/// owned value.
///
/// The first byte of the input is a tag: `0` means the remaining bytes are
/// deserialized directly, `1` means they are decompressed into `aux` first.
/// An empty input or any other tag value is an error.
impl<T: Encode + DeserializeOwned> Decode for T {
    type Error = DecodeError;

    #[inline]
    fn decode(
        buf: &[u8],
        aux: &mut Vec<u8>,
    ) -> Result<Self, <Self as Decode>::Error> {
        // Peel the tag byte off the front of the input.
        let (tag, payload) = match buf.split_first() {
            Some((&tag, rest)) => (tag, rest),
            None => return Err(DecodeError::empty()),
        };

        match tag {
            // Uncompressed: the payload is the serialized value itself.
            0 => deserialize(payload).map_err(DecodeError::deserialize),
            // Compressed: inflate into the cleared scratch buffer, then
            // deserialize from there.
            1 => {
                aux.clear();
                decompress_into(payload, aux)
                    .map_err(DecodeError::decompress)?;
                deserialize(aux).map_err(DecodeError::deserialize)
            },
            unknown => Err(DecodeError::invalid_byte(unknown)),
        }
    }
}

pin_project! {
    /// A [`Stream`] adapter that deserializes items from a stream of bytes.
    ///
    /// `T` is the type of the decoded items. It is only a marker: no value of
    /// type `T` is stored in the adapter.
    // `repr(C)` fixes the field layout so that it cannot vary with the
    // zero-sized `PhantomData<T>` marker; `with_type` depends on that.
    #[repr(C)]
    pub struct DecodeStream<S, T> {
        // Second buffer handed to `Decode::decode` on every item.
        buf: Vec<u8>,
        #[pin]
        inner: S,
        _phantom: PhantomData<T>,
    }
}

impl<S, T> DecodeStream<S, T> {
    /// Wraps `stream` in a new adapter whose internal buffer starts out empty.
    #[inline]
    pub fn new(stream: S) -> Self {
        Self { buf: Vec::new(), inner: stream, _phantom: PhantomData }
    }

    /// Reinterprets this adapter as one that decodes items of type `U`
    /// instead of `T`.
    ///
    /// The inner stream and the internal buffer are left untouched; only the
    /// marker type changes.
    #[inline]
    pub fn with_type<U>(&mut self) -> &mut DecodeStream<S, U> {
        // SAFETY: `DecodeStream` is `repr(C)` and `DecodeStream<S, T>` and
        // `DecodeStream<S, U>` differ only in a `PhantomData` field, which is
        // zero-sized with alignment 1 for any type parameter. Both types
        // therefore have the same size, alignment and field offsets. The
        // returned reference takes over the exclusive borrow of `self` for
        // the same lifetime, so no aliasing is introduced.
        unsafe {
            mem::transmute::<&mut Self, &mut DecodeStream<S, U>>(self)
        }
    }
}

/// Yields one decoded `T` for every chunk of bytes produced by the inner
/// stream.
///
/// Errors coming from the inner stream are reported as [`Either::Left`],
/// decoding failures as [`Either::Right`].
impl<S, T, Bytes, InnerError> Stream for DecodeStream<S, T>
where
    S: Stream<Item = Result<Bytes, InnerError>>,
    T: Decode,
    Bytes: AsRef<[u8]>,
{
    type Item = Result<T, Either<InnerError, <T as Decode>::Error>>;

    #[inline(always)]
    fn poll_next(
        self: Pin<&mut Self>,
        ctx: &mut Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        let this = self.project();

        let item = match this.inner.poll_next(ctx) {
            // Nothing available yet; the inner stream has registered the waker.
            Poll::Pending => return Poll::Pending,
            // The inner stream is exhausted, and so is this one.
            Poll::Ready(None) => None,
            // Transport-level failure: forward it unchanged.
            Poll::Ready(Some(Err(err))) => Some(Err(Either::Left(err))),
            // A chunk of bytes arrived: decode it with the shared buffer.
            Poll::Ready(Some(Ok(bytes))) => {
                let decoded = T::decode(bytes.as_ref(), this.buf);
                Some(decoded.map_err(Either::Right))
            },
        };

        Poll::Ready(item)
    }
}

pin_project! {
    /// A [`Sink`] adapter that serializes items before sending them to the
    /// underlying sink.
    ///
    /// `T` is the type of the items being sent. It is only a marker: no value
    /// of type `T` is stored in the adapter.
    // `repr(C)` fixes the field layout so that it cannot vary with the
    // zero-sized `PhantomData<T>` marker; `with_type` depends on that.
    #[repr(C)]
    pub struct EncodeSink<S, T> {
        // Second buffer handed to `Encode::encode` on every item.
        buf: Vec<u8>,
        #[pin]
        inner: S,
        _phantom: PhantomData<T>,
    }
}

impl<S, T> EncodeSink<S, T> {
    /// Wraps `sink` in a new adapter whose internal buffer starts out empty.
    #[inline]
    pub fn new(sink: S) -> Self {
        Self { buf: Vec::new(), inner: sink, _phantom: PhantomData }
    }

    /// Reinterprets this adapter as one that encodes items of type `U`
    /// instead of `T`.
    ///
    /// The inner sink and the internal buffer are left untouched; only the
    /// marker type changes.
    #[inline]
    pub fn with_type<U>(&mut self) -> &mut EncodeSink<S, U> {
        // SAFETY: `EncodeSink` is `repr(C)` and `EncodeSink<S, T>` and
        // `EncodeSink<S, U>` differ only in a `PhantomData` field, which is
        // zero-sized with alignment 1 for any type parameter. Both types
        // therefore have the same size, alignment and field offsets. The
        // returned reference takes over the exclusive borrow of `self` for
        // the same lifetime, so no aliasing is introduced.
        unsafe { mem::transmute::<&mut Self, &mut EncodeSink<S, U>>(self) }
    }
}

/// Encodes every item into a fresh `Vec<u8>` and hands it to the inner sink.
///
/// Errors coming from the inner sink are reported as [`Either::Left`],
/// encoding failures as [`Either::Right`].
impl<S, T> Sink<T> for EncodeSink<S, T>
where
    T: Encode,
    S: Sink<Vec<u8>>,
{
    type Error = Either<S::Error, T::Error>;

    #[inline(always)]
    fn poll_ready(
        self: Pin<&mut Self>,
        ctx: &mut Context<'_>,
    ) -> Poll<Result<(), Self::Error>> {
        // Readiness is entirely up to the inner sink.
        let sink = self.project().inner;
        sink.poll_ready(ctx).map_err(Either::Left)
    }

    #[inline(always)]
    fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> {
        let this = self.project();

        // The inner sink takes ownership of the bytes, so each item gets its
        // own output vector; only the second buffer is reused across calls.
        let mut encoded = Vec::new();
        match item.encode(&mut encoded, this.buf) {
            Ok(()) => this.inner.start_send(encoded).map_err(Either::Left),
            Err(err) => Err(Either::Right(err)),
        }
    }

    #[inline(always)]
    fn poll_flush(
        self: Pin<&mut Self>,
        ctx: &mut Context<'_>,
    ) -> Poll<Result<(), Self::Error>> {
        // Nothing is buffered at this layer; flush the inner sink.
        let sink = self.project().inner;
        sink.poll_flush(ctx).map_err(Either::Left)
    }

    #[inline(always)]
    fn poll_close(
        self: Pin<&mut Self>,
        ctx: &mut Context<'_>,
    ) -> Poll<Result<(), Self::Error>> {
        // Closing is delegated to the inner sink.
        let sink = self.project().inner;
        sink.poll_close(ctx).map_err(Either::Left)
    }
}

/// The error returned by the blanket [`Encode`] implementation.
///
/// It is caused either by a serialization failure or by a compression
/// failure; the concrete cause is kept private and is only exposed through
/// the `Display` and `Debug` output.
#[derive(Debug)]
pub struct EncodeError {
    // Which step of the encoding failed, together with the underlying error.
    kind: EncodeErrorKind,
}

impl EncodeError {
    /// Builds an error for a failure in the compression step.
    #[inline]
    pub(crate) fn compress(err: CompressError) -> Self {
        let kind = EncodeErrorKind::Compress(err);
        Self { kind }
    }

    /// Builds an error for a failure in the serialization step.
    #[inline]
    pub(crate) fn serialize(err: SerializeError) -> Self {
        let kind = EncodeErrorKind::Serialize(err);
        Self { kind }
    }
}

/// Formats as `Encoding failed: ` followed by the underlying error's own
/// `Display` output.
impl core::fmt::Display for EncodeError {
    #[inline]
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // Erase the concrete error type so one format string serves all
        // variants.
        let cause: &dyn core::fmt::Display = match self.kind {
            EncodeErrorKind::Compress(ref inner) => inner,
            EncodeErrorKind::Serialize(ref inner) => inner,
        };

        f.write_fmt(format_args!("Encoding failed: {cause}"))
    }
}

// Uses the default `StdError` methods; the underlying cause is reported only
// through the `Display` output.
impl StdError for EncodeError {}

/// The step of the encoding process that failed.
#[derive(Debug)]
enum EncodeErrorKind {
    /// Compressing the serialized bytes failed.
    Compress(CompressError),
    /// Serializing the value failed.
    Serialize(SerializeError),
}

/// The error returned by the blanket [`Decode`] implementation.
///
/// It is caused by an empty input, an unrecognized leading byte, a
/// decompression failure or a deserialization failure; the concrete cause is
/// kept private and is only exposed through the `Display` and `Debug` output.
#[derive(Debug)]
pub struct DecodeError {
    // What went wrong, together with the underlying error where there is one.
    kind: DecodeErrorKind,
}

impl DecodeError {
    /// Builds an error for an input that contains no bytes.
    #[inline]
    pub(crate) fn empty() -> Self {
        let kind = DecodeErrorKind::EmptyBuffer;
        Self { kind }
    }

    /// Builds an error for a failure in the decompression step.
    #[inline]
    pub(crate) fn decompress(err: DecompressError) -> Self {
        let kind = DecodeErrorKind::Decompress(err);
        Self { kind }
    }

    /// Builds an error for a failure in the deserialization step.
    #[inline]
    pub(crate) fn deserialize(err: DeserializeError) -> Self {
        let kind = DecodeErrorKind::Deserialize(err);
        Self { kind }
    }

    /// Builds an error for an input whose first byte is neither `0` nor `1`.
    #[inline]
    pub(crate) fn invalid_byte(byte: u8) -> Self {
        let offending = InvalidFirstByte(byte);
        let kind = DecodeErrorKind::InvalidFirstByte(offending);
        Self { kind }
    }
}

/// The reason a decoding attempt failed.
#[derive(Debug)]
enum DecodeErrorKind {
    /// Decompressing the payload failed.
    Decompress(DecompressError),
    /// Deserializing the payload failed.
    Deserialize(DeserializeError),
    /// The input contained no bytes.
    EmptyBuffer,
    /// The first byte of the input was neither `0` nor `1`.
    InvalidFirstByte(InvalidFirstByte),
}

/// Formats as `Decoding failed: ` followed by a description of the cause.
impl core::fmt::Display for DecodeError {
    #[inline]
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // Erase the concrete cause type so one format string serves all
        // variants; the empty-buffer case has no payload, so it uses a fixed
        // message.
        let cause: &dyn core::fmt::Display = match self.kind {
            DecodeErrorKind::EmptyBuffer => &"buffer is empty",
            DecodeErrorKind::InvalidFirstByte(ref inner) => inner,
            DecodeErrorKind::Decompress(ref inner) => inner,
            DecodeErrorKind::Deserialize(ref inner) => inner,
        };

        f.write_fmt(format_args!("Decoding failed: {cause}"))
    }
}

/// The offending leading byte of an input whose first byte is neither `0` nor
/// `1`.
#[derive(Debug)]
struct InvalidFirstByte(u8);

impl core::fmt::Display for InvalidFirstByte {
    #[inline]
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        write!(f, "invalid first byte {}, expected 0 or 1", self.0)
    }
}

// Uses the default `StdError` methods; the underlying cause is reported only
// through the `Display` output.
impl StdError for DecodeError {}