use std::{
error,
fmt::{self, Debug},
result,
};
use futures_lite::{Future, StreamExt};
use futures_util::{FutureExt, SinkExt};
use crate::{
message::{InteractionPattern, Msg},
server::{race2, RpcChannel, RpcServerError},
transport::{ConnectionErrors, StreamTypes},
Connector, RpcClient, Service,
};
/// Marker type for the request/response interaction pattern: a single request
/// message is answered by a single response message.
#[derive(Clone, Copy, Debug)]
pub struct Rpc;

impl InteractionPattern for Rpc {}
/// A message that follows the [`Rpc`] interaction pattern: one request, one response.
///
/// The associated `Response` type must convert into the service's response type
/// `S::Res` (so a server can send it). It must also be recoverable from `S::Res` with
/// `TryFrom` (so a client can turn what it receives back into the concrete response).
pub trait RpcMsg<S>: Msg<S, Pattern = Rpc>
where
    S: Service,
{
    /// The response type produced for this request.
    type Response: Send + 'static + Into<S::Res> + TryFrom<S::Res>;
}
/// Every [`RpcMsg`] is automatically a [`Msg`] whose interaction pattern is [`Rpc`].
impl<S, T> Msg<S> for T
where
    S: Service,
    T: RpcMsg<S>,
{
    type Pattern = Rpc;
}
/// Errors that can occur while performing a call with [`RpcClient::rpc`].
#[derive(Debug)]
pub enum Error<C: ConnectionErrors> {
    /// Opening a new send/receive stream pair failed.
    Open(C::OpenError),
    /// Sending the request failed.
    Send(C::SendError),
    /// The receive stream ended before a response arrived.
    EarlyClose,
    /// Receiving the response failed.
    RecvError(C::RecvError),
    /// The received response could not be converted into the expected response type.
    DowncastError,
}

impl<C: ConnectionErrors> fmt::Display for Error<C> {
    // Human-readable messages. Transport error payloads are rendered with `Debug`,
    // the same formatting the derived `Debug` impl above already requires of them.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Open(e) => write!(f, "failed to open rpc stream: {e:?}"),
            Self::Send(e) => write!(f, "failed to send rpc request: {e:?}"),
            Self::EarlyClose => f.write_str("rpc stream closed before a response was received"),
            Self::RecvError(e) => write!(f, "failed to receive rpc response: {e:?}"),
            Self::DowncastError => f.write_str("rpc response has an unexpected type"),
        }
    }
}

impl<C: ConnectionErrors> error::Error for Error<C> {}
impl<S, C> RpcClient<S, C>
where
S: Service,
C: Connector<S>,
{
pub async fn rpc<M>(&self, msg: M) -> result::Result<M::Response, Error<C>>
where
M: RpcMsg<S>,
{
let msg = msg.into();
let (mut send, mut recv) = self.source.open().await.map_err(Error::Open)?;
send.send(msg).await.map_err(Error::<C>::Send)?;
let res = recv
.next()
.await
.ok_or(Error::<C>::EarlyClose)?
.map_err(Error::<C>::RecvError)?;
drop(send);
M::Response::try_from(res).map_err(|_| Error::DowncastError)
}
}
impl<S, C> RpcChannel<S, C>
where
    S: Service,
    C: StreamTypes<In = S::Req, Out = S::Res>,
{
    /// Serves a single rpc request on this channel, consuming it.
    ///
    /// `f(target, req)` computes the response. The response is converted into the
    /// service response type and sent on the channel's send half.
    ///
    /// While the handler runs, the receive half is polled for its next item. *Anything* it
    /// produces (another message, an error, or end of stream) maps to
    /// [`RpcServerError::UnexpectedUpdateMessage`]. The two futures are combined with
    /// `race2` (presumably resolving with whichever completes first — confirm there).
    ///
    /// # Errors
    ///
    /// * [`RpcServerError::UnexpectedUpdateMessage`] if the receive half yields anything.
    /// * [`RpcServerError::SendError`] if sending the response fails.
    pub async fn rpc<M, F, Fut, T>(
        self,
        req: M,
        target: T,
        f: F,
    ) -> result::Result<(), RpcServerError<C>>
    where
        M: RpcMsg<S>,
        F: FnOnce(T, M) -> Fut,
        Fut: Future<Output = M::Response>,
        T: Send + 'static,
    {
        let Self {
            send: mut sink,
            recv: mut stream,
            ..
        } = self;
        // Whatever the receive half yields next is treated as a protocol violation.
        let unexpected = stream
            .next()
            .map(|_| Err(RpcServerError::UnexpectedUpdateMessage::<C>));
        let respond = async move {
            let response = f(target, req).await.into();
            sink.send(response).await.map_err(RpcServerError::SendError)
        };
        race2(unexpected, respond).await
    }

    /// Like [`RpcChannel::rpc`], for handlers whose error type differs from the one
    /// carried in the response.
    ///
    /// `M`'s response type is `Result<R, E2>`, while `f` returns `Result<R, E1>`. An `Err`
    /// from `f` is converted with `E2::from` before the response is sent.
    ///
    /// # Errors
    ///
    /// Same as [`RpcChannel::rpc`].
    pub async fn rpc_map_err<M, F, Fut, T, R, E1, E2>(
        self,
        req: M,
        target: T,
        f: F,
    ) -> result::Result<(), RpcServerError<C>>
    where
        M: RpcMsg<S, Response = result::Result<R, E2>>,
        F: FnOnce(T, M) -> Fut,
        Fut: Future<Output = result::Result<R, E1>>,
        E2: From<E1>,
        T: Send + 'static,
    {
        self.rpc(req, target, move |t: T, m: M| async move {
            let outcome: result::Result<R, E1> = f(t, m).await;
            outcome.map_err(E2::from)
        })
        .await
    }
}