use core::time::Duration;
use std::io;
use common::encode::{Decode, DecodeError, Encode, EncodeError};
use common::{
ClientMessage,
Either,
RateLimitedStream,
RateLimiter,
ServerMessage,
};
use futures::{AsyncRead, AsyncWrite, SinkExt, StreamExt};
use crate::ConnectionError;
/// Reading half of a connection: yields decoded `ClientMessage`s from a raw
/// byte reader.
///
/// The pipeline assembled in `StreamReader::new` is, from the wire upwards:
/// framing (`framed_read`), rate limiting (`RateLimitedStream`), then
/// decoding (`decode_stream`).
pub(crate) struct StreamReader<R> {
    inner: DecodeStream<RateLimitedStream<FramedRead<R>>, ClientMessage>,
}

impl<R: AsyncRead + Unpin> StreamReader<R> {
    /// Wraps `reader` in the framing -> rate-limiting -> decoding pipeline.
    ///
    /// The limiter is `RateLimiter::new(Duration::from_millis(100), 10)`.
    /// NOTE(review): presumably "at most 10 frames per 100 ms" — confirm the
    /// argument meaning against `common::RateLimiter`.
    #[inline]
    pub(crate) fn new(reader: R) -> Self {
        let framed = framed_read(reader);
        let limited = RateLimitedStream::new(
            framed,
            RateLimiter::new(Duration::from_millis(100), 10),
        );
        Self { inner: decode_stream(limited) }
    }

    /// Waits for the next `ClientMessage`.
    ///
    /// # Errors
    ///
    /// * `ConnectionError::ClientDisconnected` when the stream has ended.
    /// * Otherwise the underlying I/O or decode error, converted by
    ///   `handle_read_error`.
    #[inline]
    pub(crate) async fn read(
        &mut self,
    ) -> Result<ClientMessage, ConnectionError> {
        let next = self.inner.next().await;
        match next {
            Some(decoded) => decoded.map_err(handle_read_error),
            None => Err(ConnectionError::ClientDisconnected),
        }
    }

    /// Decodes the next frame as an arbitrary `T` instead of a
    /// `ClientMessage`.
    ///
    /// Under the `__tests` feature this goes through `inner_mut()`, i.e. the
    /// decode stream underneath the `CallbackStream` wrapper — presumably so
    /// the `ClientMessage`-typed test hook is not invoked (confirm against
    /// `common::CallbackStream`).
    ///
    /// # Errors
    ///
    /// Same mapping as `read`: end of stream yields `ClientDisconnected`,
    /// and other failures go through `handle_read_error`.
    #[inline]
    pub(crate) async fn read_other<T>(&mut self) -> Result<T, ConnectionError>
    where
        T: Decode<Error = DecodeError>,
    {
        let stream = &mut self.inner;
        #[cfg(feature = "__tests")]
        let stream = stream.inner_mut();
        let next = stream.with_type::<T>().next().await;
        match next {
            Some(decoded) => decoded.map_err(handle_read_error),
            None => Err(ConnectionError::ClientDisconnected),
        }
    }
}
/// Writing half of a connection: encodes `ServerMessage`s onto a raw byte
/// writer through `framed_write` and then `encode_sink`.
pub(crate) struct StreamWriter<W> {
    inner: EncodeSink<FramedWrite<W>, ServerMessage>,
}

impl<W: AsyncWrite + Unpin> StreamWriter<W> {
    /// Wraps `writer` in framing, then encoding.
    #[inline]
    pub(crate) fn new(writer: W) -> Self {
        let framed = framed_write(writer);
        Self { inner: encode_sink(framed) }
    }

    /// Encodes and sends `msg`.
    ///
    /// Uses `SinkExt::send`, which per the `futures` documentation completes
    /// only after the item has been flushed through the sink.
    ///
    /// # Errors
    ///
    /// I/O and encode failures, converted by `handle_write_error`.
    #[inline]
    pub(crate) async fn write(
        &mut self,
        msg: ServerMessage,
    ) -> Result<(), ConnectionError> {
        let sent = self.inner.send(msg).await;
        sent.map_err(handle_write_error)
    }

    /// Encodes and sends an arbitrary `T` instead of a `ServerMessage`.
    ///
    /// Under the `__tests` feature this goes through `inner_mut()`, the
    /// encode sink underneath the `CallbackSink` wrapper — presumably so the
    /// `ServerMessage`-typed test hook is skipped (confirm against
    /// `common::CallbackSink`).
    ///
    /// # Errors
    ///
    /// Same mapping as `write`.
    #[inline]
    pub(crate) async fn write_other<T>(
        &mut self,
        msg: T,
    ) -> Result<(), ConnectionError>
    where
        T: Encode<Error = EncodeError>,
    {
        let sink = &mut self.inner;
        #[cfg(feature = "__tests")]
        let sink = sink.inner_mut();
        let sent = sink.with_type::<T>().send(msg).await;
        sent.map_err(handle_write_error)
    }
}
/// Converts a read-side failure into a `ConnectionError`.
///
/// * I/O errors accepted by `ConnectionError::is_client_disconnected` become
///   `ConnectionError::ClientDisconnected`.
/// * Any other I/O error becomes `ConnectionError::StreamRead`.
/// * Decode errors become `ConnectionError::Decode`.
#[inline]
fn handle_read_error(err: Either<io::Error, DecodeError>) -> ConnectionError {
    match err {
        Either::Left(io_failure)
            if ConnectionError::is_client_disconnected(&io_failure) =>
        {
            ConnectionError::ClientDisconnected
        },
        Either::Left(io_failure) => ConnectionError::StreamRead(io_failure),
        Either::Right(decode_failure) => ConnectionError::Decode(decode_failure),
    }
}
/// Converts a write-side failure into a `ConnectionError`.
///
/// * I/O errors accepted by `ConnectionError::is_client_disconnected` become
///   `ConnectionError::ClientDisconnected`.
/// * Any other I/O error becomes `ConnectionError::StreamWrite`.
/// * Encode errors become `ConnectionError::Encode`.
#[inline]
fn handle_write_error(err: Either<io::Error, EncodeError>) -> ConnectionError {
    match err {
        Either::Left(io_failure)
            if ConnectionError::is_client_disconnected(&io_failure) =>
        {
            ConnectionError::ClientDisconnected
        },
        Either::Left(io_failure) => ConnectionError::StreamWrite(io_failure),
        Either::Right(encode_failure) => ConnectionError::Encode(encode_failure),
    }
}
use adapters::*;
/// Feature-dependent adapter types and constructors for the reader/writer
/// pipelines above.
///
/// Without the `__tests` feature each alias is the plain `common` type. With
/// `__tests`, each layer is additionally wrapped in a `CallbackSink` or
/// `CallbackStream` paired with a hook function from `crate::tests`. Either
/// way the parent module names a single set of types (`EncodeSink`,
/// `DecodeStream`, `FramedRead`, `FramedWrite`) in its struct fields.
mod adapters {
// In non-test builds the cfg-gated rebinding inside each constructor is
// compiled out, leaving `let x = ...; x`, which `clippy::let_and_return`
// would otherwise flag.
#![allow(clippy::let_and_return)]
use common::{encode, ClientMessage, ServerMessage};
#[cfg(feature = "__tests")]
use common::{CallbackSink, CallbackStream};
use futures::{AsyncRead, AsyncWrite};
// Provides the hook functions handed to the callback wrappers below.
#[cfg(feature = "__tests")]
use crate::tests;
/// Encoding sink over `S` for items of type `T` (non-test build).
#[cfg(not(feature = "__tests"))]
pub(super) type EncodeSink<S, T> = encode::EncodeSink<S, T>;
/// Encoding sink over `S` for items of type `T`, wrapped in a `CallbackSink`
/// (test build).
#[cfg(feature = "__tests")]
pub(super) type EncodeSink<S, T> =
CallbackSink<encode::EncodeSink<S, T>, T>;
/// Builds the `ServerMessage` encoding sink on top of `sink`.
///
/// Under `__tests` the result is additionally wrapped in a `CallbackSink`
/// with `tests::before_send_msg` (presumably invoked before each message is
/// sent — confirm against `common::CallbackSink`).
#[inline(always)]
pub(super) fn encode_sink<S>(sink: S) -> EncodeSink<S, ServerMessage> {
let sink = encode::EncodeSink::new(sink);
#[cfg(feature = "__tests")]
let sink = CallbackSink::new(sink, tests::before_send_msg);
sink
}
/// Decoding stream over `S` yielding items of type `T` (non-test build).
#[cfg(not(feature = "__tests"))]
pub(super) type DecodeStream<S, T> = encode::DecodeStream<S, T>;
/// Decoding stream over `S` yielding items of type `T`, wrapped in a
/// `CallbackStream` (test build).
#[cfg(feature = "__tests")]
pub(super) type DecodeStream<S, T> =
CallbackStream<encode::DecodeStream<S, T>, T>;
/// Builds the `ClientMessage` decoding stream on top of `stream`.
///
/// Under `__tests` the result is additionally wrapped in a `CallbackStream`
/// with `tests::after_receive_msg` (presumably invoked after each message is
/// decoded — confirm against `common::CallbackStream`).
#[inline(always)]
pub(super) fn decode_stream<S>(
stream: S,
) -> DecodeStream<S, ClientMessage> {
let stream = encode::DecodeStream::new(stream);
#[cfg(feature = "__tests")]
let stream = CallbackStream::new(stream, tests::after_receive_msg);
stream
}
/// Framing reader over `R` (non-test build).
#[cfg(not(feature = "__tests"))]
pub(super) type FramedRead<R> = common::FramedRead<R>;
/// Framing reader over `R`, wrapped in a `CallbackStream` whose item type is
/// `[u8]` (test build). Presumably the hook sees raw frame bytes.
#[cfg(feature = "__tests")]
pub(super) type FramedRead<R> =
CallbackStream<common::FramedRead<R>, [u8]>;
/// Builds the framing reader on top of `reader`.
///
/// Under `__tests` the result is additionally wrapped in a `CallbackStream`
/// with `tests::after_receive_bytes`.
#[inline(always)]
pub(super) fn framed_read<R: AsyncRead>(reader: R) -> FramedRead<R> {
let reader = common::FramedRead::new(reader);
#[cfg(feature = "__tests")]
let reader = CallbackStream::new(reader, tests::after_receive_bytes);
reader
}
/// Framing writer over `W` (non-test build).
#[cfg(not(feature = "__tests"))]
pub(super) type FramedWrite<W> = common::FramedWrite<W>;
/// Framing writer over `W`, wrapped in a `CallbackSink` whose item type is
/// `[u8]` (test build). Presumably the hook sees raw frame bytes.
#[cfg(feature = "__tests")]
pub(super) type FramedWrite<W> =
CallbackSink<common::FramedWrite<W>, [u8]>;
/// Builds the framing writer on top of `writer`.
///
/// Under `__tests` the result is additionally wrapped in a `CallbackSink`
/// with `tests::before_send_bytes`.
#[inline(always)]
pub(super) fn framed_write<W: AsyncWrite>(writer: W) -> FramedWrite<W> {
let writer = common::FramedWrite::new(writer);
#[cfg(feature = "__tests")]
let writer = CallbackSink::new(writer, tests::before_send_bytes);
writer
}
}