use async_trait::async_trait;
use std::{error::Error, fmt};
use crate::{
chmux::ChMuxError,
connect::ConnectError,
rch::base::{RecvError, SendError},
};
#[cfg(feature = "default-codec-set")]
use crate::{connect::Connect, rch::base, RemoteSend};
#[cfg(feature = "default-codec-set")]
use futures::Future;
/// Error that can occur while providing a value via [`ConnectExt::provide`].
#[cfg_attr(docsrs, doc(cfg(feature = "rch")))]
#[derive(Debug, Clone)]
pub enum ProvideError<TransportSinkError, TransportStreamError> {
    /// The channel multiplexer driving the connection failed.
    ChMux(ChMuxError<TransportSinkError, TransportStreamError>),
    /// Establishing the connection failed.
    Connect(ConnectError<TransportSinkError, TransportStreamError>),
    /// Sending the value failed.
    ///
    /// The unsent item is not retained in this error.
    Send(SendError<()>),
}
impl<TransportSinkError, TransportStreamError> From<ChMuxError<TransportSinkError, TransportStreamError>>
for ProvideError<TransportSinkError, TransportStreamError>
{
fn from(err: ChMuxError<TransportSinkError, TransportStreamError>) -> Self {
Self::ChMux(err)
}
}
impl<TransportSinkError, TransportStreamError> From<ConnectError<TransportSinkError, TransportStreamError>>
for ProvideError<TransportSinkError, TransportStreamError>
{
fn from(err: ConnectError<TransportSinkError, TransportStreamError>) -> Self {
Self::Connect(err)
}
}
impl<T, TransportSinkError, TransportStreamError> From<SendError<T>>
for ProvideError<TransportSinkError, TransportStreamError>
{
fn from(err: SendError<T>) -> Self {
Self::Send(err.without_item())
}
}
impl<TransportSinkError, TransportStreamError> fmt::Display
for ProvideError<TransportSinkError, TransportStreamError>
where
TransportSinkError: fmt::Display,
TransportStreamError: fmt::Display,
{
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
Self::ChMux(err) => write!(f, "chmux error: {err}"),
Self::Connect(err) => write!(f, "connect error: {err}"),
Self::Send(err) => write!(f, "send error: {err}"),
}
}
}
impl<TransportSinkError, TransportStreamError> Error for ProvideError<TransportSinkError, TransportStreamError>
where
TransportSinkError: fmt::Debug + fmt::Display,
TransportStreamError: fmt::Debug + fmt::Display,
{
}
/// Error that can occur while consuming a value via [`ConnectExt::consume`].
#[cfg_attr(docsrs, doc(cfg(feature = "rch")))]
#[derive(Debug, Clone)]
pub enum ConsumeError<TransportSinkError, TransportStreamError> {
    /// The channel multiplexer driving the connection failed.
    ChMux(ChMuxError<TransportSinkError, TransportStreamError>),
    /// Establishing the connection failed.
    Connect(ConnectError<TransportSinkError, TransportStreamError>),
    /// Receiving the value failed.
    Recv(RecvError),
    /// The connection ended or the receiver yielded `None` before any value
    /// was received.
    NoValueReceived,
}
impl<TransportSinkError, TransportStreamError> From<ChMuxError<TransportSinkError, TransportStreamError>>
for ConsumeError<TransportSinkError, TransportStreamError>
{
fn from(err: ChMuxError<TransportSinkError, TransportStreamError>) -> Self {
Self::ChMux(err)
}
}
impl<TransportSinkError, TransportStreamError> From<ConnectError<TransportSinkError, TransportStreamError>>
for ConsumeError<TransportSinkError, TransportStreamError>
{
fn from(err: ConnectError<TransportSinkError, TransportStreamError>) -> Self {
Self::Connect(err)
}
}
impl<TransportSinkError, TransportStreamError> From<RecvError>
for ConsumeError<TransportSinkError, TransportStreamError>
{
fn from(err: RecvError) -> Self {
Self::Recv(err)
}
}
impl<TransportSinkError, TransportStreamError> fmt::Display
for ConsumeError<TransportSinkError, TransportStreamError>
where
TransportSinkError: fmt::Display,
TransportStreamError: fmt::Display,
{
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
Self::ChMux(err) => write!(f, "chmux error: {err}"),
Self::Connect(err) => write!(f, "connect error: {err}"),
Self::Recv(err) => write!(f, "receive error: {err}"),
Self::NoValueReceived => write!(f, "no value was received for consumption"),
}
}
}
impl<TransportSinkError, TransportStreamError> Error for ConsumeError<TransportSinkError, TransportStreamError>
where
TransportSinkError: fmt::Debug + fmt::Display,
TransportStreamError: fmt::Debug + fmt::Display,
{
}
/// Convenience methods for connect futures that exchange exactly one value.
///
/// When the `default-codec-set` feature is enabled, this is implemented for
/// futures resolving to a connection together with a base channel sender and
/// receiver.
#[cfg_attr(docsrs, doc(cfg(feature = "rch")))]
#[async_trait]
pub trait ConnectExt<T, TransportSinkError, TransportStreamError> {
    /// Establishes the connection and sends a single `value` over the base channel.
    async fn provide(
        self,
        value: T,
    ) -> Result<(), ProvideError<TransportSinkError, TransportStreamError>>;

    /// Establishes the connection and receives a single value from the base channel.
    async fn consume(self) -> Result<T, ConsumeError<TransportSinkError, TransportStreamError>>;
}
/// [`ConnectExt`] for any future that yields a connection plus a base channel
/// sender/receiver pair using the default codec.
#[async_trait]
#[cfg(feature = "default-codec-set")]
impl<TransportSinkError, TransportStreamError, T, ConnectFuture>
    ConnectExt<T, TransportSinkError, TransportStreamError> for ConnectFuture
where
    T: RemoteSend,
    TransportSinkError: Send + Error + 'static,
    TransportStreamError: Send + Error + 'static,
    ConnectFuture: Future<
            Output = Result<
                (
                    Connect<'static, TransportSinkError, TransportStreamError>,
                    base::Sender<T, crate::codec::Default>,
                    base::Receiver<T, crate::codec::Default>,
                ),
                ConnectError<TransportSinkError, TransportStreamError>,
            >,
        > + Send,
{
    /// Awaits the connection, then sends `value` while driving the connection.
    ///
    /// If the value is sent while the connection is still running, the
    /// connection is moved into a spawned task and its failure is only logged.
    ///
    /// # Errors
    /// * [`ProvideError::Connect`] if the connect future fails.
    /// * [`ProvideError::ChMux`] if the connection fails before the send completes.
    /// * [`ProvideError::Send`] if sending the value fails.
    async fn provide(self, value: T) -> Result<(), ProvideError<TransportSinkError, TransportStreamError>> {
        let (mut conn, mut tx, _) = self.await?;

        // Create the send future outside `select!` so that it is not dropped
        // (taking the value with it) if the connection branch wins.
        let send = tx.send(value);
        tokio::pin!(send);

        // `biased`: poll the connection first so its failure takes precedence.
        let conn_finished = tokio::select! {
            biased;
            res = &mut conn => {
                res?;
                true
            },
            res = &mut send => {
                res?;
                false
            },
        };

        if conn_finished {
            // The connection completed successfully before the value was sent.
            // `conn` must not be awaited again (it has already completed), so
            // finish the pending send and report its outcome instead of
            // silently dropping the value and returning `Ok`.
            send.await?;
        } else {
            // The value is out; keep the still-running connection alive in the
            // background.
            tokio::spawn(async move {
                if let Err(err) = conn.await {
                    tracing::warn!(%err, "connection failed");
                }
            });
        }

        Ok(())
    }

    /// Awaits the connection, then receives one value while driving the connection.
    ///
    /// After a value has been received, the connection is moved into a spawned
    /// task and its failure is only logged.
    ///
    /// # Errors
    /// * [`ConsumeError::Connect`] if the connect future fails.
    /// * [`ConsumeError::ChMux`] if the connection fails before a value arrives.
    /// * [`ConsumeError::Recv`] if receiving fails.
    /// * [`ConsumeError::NoValueReceived`] if the connection completes successfully
    ///   or the receiver yields `None` before a value arrives.
    async fn consume(self) -> Result<T, ConsumeError<TransportSinkError, TransportStreamError>> {
        let (mut conn, _, mut rx) = self.await?;

        // `biased`: poll the connection first so its failure takes precedence.
        // Returning from the connection branch means `conn` is never awaited
        // again after completing.
        let value = tokio::select! {
            biased;
            res = &mut conn => {
                res?;
                return Err(ConsumeError::NoValueReceived);
            },
            res = rx.recv() => {
                match res? {
                    Some(value) => value,
                    None => return Err(ConsumeError::NoValueReceived),
                }
            }
        };

        // The connection is still running here; keep it alive in the background.
        tokio::spawn(async move {
            if let Err(err) = conn.await {
                tracing::warn!(%err, "connection failed");
            }
        });

        Ok(value)
    }
}