collab-server 0.0.7

Nomad's collab server
Documentation
use core::time::Duration;
use std::io;

use common::encode::{Decode, DecodeError, Encode, EncodeError};
use common::{
    ClientMessage,
    Either,
    RateLimitedStream,
    RateLimiter,
    ServerMessage,
};
use futures::{AsyncRead, AsyncWrite, SinkExt, StreamExt};

use crate::ConnectionError;

/// Reader half of a client connection.
///
/// Wraps a reader `R` in a stack of adapters: a framed reader, then a
/// [`RateLimitedStream`], then a decode stream whose default message type is
/// [`ClientMessage`].
pub(crate) struct StreamReader<R> {
    // Outermost layer of the adapter stack; the stream polled for messages.
    inner: DecodeStream<RateLimitedStream<FramedRead<R>>, ClientMessage>,
}

impl<R: AsyncRead + Unpin> StreamReader<R> {
    /// Builds a `StreamReader` on top of `reader`.
    ///
    /// The reader is framed, wrapped in a [`RateLimitedStream`] driven by a
    /// [`RateLimiter`] constructed from `Duration::from_millis(100)` and
    /// `10`, and finally wrapped in a decode stream.
    #[inline]
    pub(crate) fn new(reader: R) -> Self {
        // Arguments are evaluated left to right, so the framed reader is
        // still created before the rate limiter.
        let limited = RateLimitedStream::new(
            framed_read(reader),
            RateLimiter::new(Duration::from_millis(100), 10),
        );

        Self { inner: decode_stream(limited) }
    }

    /// Waits for the next [`ClientMessage`] from the stream.
    ///
    /// # Errors
    ///
    /// Returns [`ConnectionError::ClientDisconnected`] if the stream has
    /// ended. If the stream yields an error, it is converted with
    /// `handle_read_error`.
    #[inline]
    pub(crate) async fn read(
        &mut self,
    ) -> Result<ClientMessage, ConnectionError> {
        let next = self.inner.next().await;

        match next {
            Some(outcome) => outcome.map_err(handle_read_error),
            None => Err(ConnectionError::ClientDisconnected),
        }
    }

    /// Waits for the next message from the stream, decoding it as a `T`
    /// rather than as a [`ClientMessage`].
    ///
    /// # Errors
    ///
    /// Same as [`read`](Self::read): [`ConnectionError::ClientDisconnected`]
    /// if the stream has ended, otherwise whatever `handle_read_error`
    /// produces from the stream's error.
    #[inline]
    pub(crate) async fn read_other<T>(&mut self) -> Result<T, ConnectionError>
    where
        T: Decode<Error = DecodeError>,
    {
        let stream = &mut self.inner;

        // With the `__tests` feature the decode stream is wrapped in a
        // callback stream, so step into the wrapped stream before retyping
        // it. NOTE(review): presumably this skips the test callback for
        // non-`ClientMessage` reads; confirm against `CallbackStream`.
        #[cfg(feature = "__tests")]
        let stream = stream.inner_mut();

        let next = stream.with_type::<T>().next().await;

        match next {
            Some(outcome) => outcome.map_err(handle_read_error),
            None => Err(ConnectionError::ClientDisconnected),
        }
    }
}

/// Writer half of a client connection.
///
/// Wraps a writer `W` in a framed writer, which is in turn wrapped in an
/// encode sink whose default message type is [`ServerMessage`].
pub(crate) struct StreamWriter<W> {
    // Outermost layer of the adapter stack; the sink messages are sent into.
    inner: EncodeSink<FramedWrite<W>, ServerMessage>,
}

impl<W: AsyncWrite + Unpin> StreamWriter<W> {
    #[inline]
    pub(crate) fn new(writer: W) -> Self {
        Self { inner: encode_sink(framed_write(writer)) }
    }

    #[inline]
    pub(crate) async fn write(
        &mut self,
        msg: ServerMessage,
    ) -> Result<(), ConnectionError> {
        self.inner.send(msg).await.map_err(handle_write_error)
    }

    #[inline]
    pub(crate) async fn write_other<T>(
        &mut self,
        msg: T,
    ) -> Result<(), ConnectionError>
    where
        T: Encode<Error = EncodeError>,
    {
        let inner = &mut self.inner;

        #[cfg(feature = "__tests")]
        let inner = inner.inner_mut();

        inner.with_type::<T>().send(msg).await.map_err(handle_write_error)
    }
}

/// Converts an error produced while reading from the stream into a
/// [`ConnectionError`].
///
/// Decode errors become [`ConnectionError::Decode`]. I/O errors become
/// [`ConnectionError::ClientDisconnected`] when
/// `ConnectionError::is_client_disconnected` says so, and
/// [`ConnectionError::StreamRead`] otherwise.
#[inline]
fn handle_read_error(err: Either<io::Error, DecodeError>) -> ConnectionError {
    // Deal with the decode side first so only the I/O error is left.
    let io_err = match err {
        Either::Left(io_err) => io_err,
        Either::Right(decode_err) => {
            return ConnectionError::Decode(decode_err);
        },
    };

    if ConnectionError::is_client_disconnected(&io_err) {
        ConnectionError::ClientDisconnected
    } else {
        ConnectionError::StreamRead(io_err)
    }
}

/// Converts an error produced while writing to the stream into a
/// [`ConnectionError`].
///
/// Encode errors become [`ConnectionError::Encode`]. I/O errors become
/// [`ConnectionError::ClientDisconnected`] when
/// `ConnectionError::is_client_disconnected` says so, and
/// [`ConnectionError::StreamWrite`] otherwise.
#[inline]
fn handle_write_error(err: Either<io::Error, EncodeError>) -> ConnectionError {
    // Deal with the encode side first so only the I/O error is left.
    let io_err = match err {
        Either::Left(io_err) => io_err,
        Either::Right(encode_err) => {
            return ConnectionError::Encode(encode_err);
        },
    };

    if ConnectionError::is_client_disconnected(&io_err) {
        ConnectionError::ClientDisconnected
    } else {
        ConnectionError::StreamWrite(io_err)
    }
}

use adapters::*;

/// Type aliases and constructors for the framing and encoding layers used by
/// the stream reader and writer in the parent module.
///
/// Every alias has two definitions selected by the `__tests` feature: without
/// it the alias is the plain `common` type; with it that type is wrapped in a
/// `CallbackStream`/`CallbackSink` built with a function from `crate::tests`.
mod adapters {
    // The constructors below bind a value, conditionally shadow it under
    // `#[cfg(feature = "__tests")]`, and then return it. Without the feature
    // that is a plain let-and-return, hence the module-wide allow.
    #![allow(clippy::let_and_return)]

    use common::{encode, ClientMessage, ServerMessage};
    #[cfg(feature = "__tests")]
    use common::{CallbackSink, CallbackStream};
    use futures::{AsyncRead, AsyncWrite};

    #[cfg(feature = "__tests")]
    use crate::tests;

    /// Sink that encodes values of type `T` into the sink `S`.
    #[cfg(not(feature = "__tests"))]
    pub(super) type EncodeSink<S, T> = encode::EncodeSink<S, T>;

    /// Sink that encodes values of type `T` into the sink `S`, wrapped in a
    /// `CallbackSink` when the `__tests` feature is enabled.
    #[cfg(feature = "__tests")]
    pub(super) type EncodeSink<S, T> =
        CallbackSink<encode::EncodeSink<S, T>, T>;

    /// Wraps `sink` in an `encode::EncodeSink` for `ServerMessage`s.
    ///
    /// With the `__tests` feature the result is additionally wrapped in a
    /// `CallbackSink` constructed with `tests::before_send_msg`.
    #[inline(always)]
    pub(super) fn encode_sink<S>(sink: S) -> EncodeSink<S, ServerMessage> {
        let sink = encode::EncodeSink::new(sink);

        // Test builds only: shadow the sink with the callback wrapper.
        #[cfg(feature = "__tests")]
        let sink = CallbackSink::new(sink, tests::before_send_msg);

        sink
    }

    /// Stream that decodes values of type `T` from the stream `S`.
    #[cfg(not(feature = "__tests"))]
    pub(super) type DecodeStream<S, T> = encode::DecodeStream<S, T>;

    /// Stream that decodes values of type `T` from the stream `S`, wrapped
    /// in a `CallbackStream` when the `__tests` feature is enabled.
    #[cfg(feature = "__tests")]
    pub(super) type DecodeStream<S, T> =
        CallbackStream<encode::DecodeStream<S, T>, T>;

    /// Wraps `stream` in an `encode::DecodeStream` for `ClientMessage`s.
    ///
    /// With the `__tests` feature the result is additionally wrapped in a
    /// `CallbackStream` constructed with `tests::after_receive_msg`.
    #[inline(always)]
    pub(super) fn decode_stream<S>(
        stream: S,
    ) -> DecodeStream<S, ClientMessage> {
        let stream = encode::DecodeStream::new(stream);

        // Test builds only: shadow the stream with the callback wrapper.
        #[cfg(feature = "__tests")]
        let stream = CallbackStream::new(stream, tests::after_receive_msg);

        stream
    }

    /// Framed reader over `R`.
    #[cfg(not(feature = "__tests"))]
    pub(super) type FramedRead<R> = common::FramedRead<R>;

    /// Framed reader over `R`, wrapped in a `CallbackStream` over `[u8]`
    /// when the `__tests` feature is enabled.
    #[cfg(feature = "__tests")]
    pub(super) type FramedRead<R> =
        CallbackStream<common::FramedRead<R>, [u8]>;

    /// Wraps `reader` in a `common::FramedRead`.
    ///
    /// With the `__tests` feature the result is additionally wrapped in a
    /// `CallbackStream` constructed with `tests::after_receive_bytes`.
    #[inline(always)]
    pub(super) fn framed_read<R: AsyncRead>(reader: R) -> FramedRead<R> {
        let reader = common::FramedRead::new(reader);

        // Test builds only: shadow the reader with the callback wrapper.
        #[cfg(feature = "__tests")]
        let reader = CallbackStream::new(reader, tests::after_receive_bytes);

        reader
    }

    /// Framed writer over `W`.
    #[cfg(not(feature = "__tests"))]
    pub(super) type FramedWrite<W> = common::FramedWrite<W>;

    /// Framed writer over `W`, wrapped in a `CallbackSink` over `[u8]` when
    /// the `__tests` feature is enabled.
    #[cfg(feature = "__tests")]
    pub(super) type FramedWrite<W> =
        CallbackSink<common::FramedWrite<W>, [u8]>;

    /// Wraps `writer` in a `common::FramedWrite`.
    ///
    /// With the `__tests` feature the result is additionally wrapped in a
    /// `CallbackSink` constructed with `tests::before_send_bytes`.
    #[inline(always)]
    pub(super) fn framed_write<W: AsyncWrite>(writer: W) -> FramedWrite<W> {
        let writer = common::FramedWrite::new(writer);

        // Test builds only: shadow the writer with the callback wrapper.
        #[cfg(feature = "__tests")]
        let writer = CallbackSink::new(writer, tests::before_send_bytes);

        writer
    }
}