use crate::{
message::{BidiStreamingMsg, ClientStreamingMsg, RpcMsg, ServerStreamingMsg},
transport::ConnectionErrors,
Service, ServiceConnection,
};
use futures::{
future::BoxFuture, stream::BoxStream, FutureExt, Sink, SinkExt, Stream, StreamExt, TryFutureExt,
};
use pin_project::pin_project;
use std::{
error,
fmt::{self, Debug},
marker::PhantomData,
pin::Pin,
result,
task::{Context, Poll},
};
/// A client for a specific service.
///
/// Wraps a connection of type `C` and tags it with the service type `S`.
/// The request methods on this type open a new channel pair on the
/// connection (via `open_bi`) for every call.
#[derive(Debug)]
pub struct RpcClient<S, C> {
// The underlying connection that calls are issued on.
source: C,
// Zero-sized marker tying this client to the service type `S`.
p: PhantomData<S>,
}
/// Manual `Clone` so that only the connection type `C` has to be `Clone`;
/// the service marker `S` carries no bound here.
impl<S, C: Clone> Clone for RpcClient<S, C> {
    fn clone(&self) -> Self {
        let source = C::clone(&self.source);
        Self {
            source,
            p: PhantomData,
        }
    }
}
/// Sink wrapper around a connection's send half that accepts items of type `T`.
///
/// Each item is converted into the service request type `S::Req` before it is
/// handed to the wrapped `C::SendSink`.
#[pin_project]
#[derive(Debug)]
pub struct UpdateSink<S: Service, C: ServiceConnection<S>, T: Into<S::Req>>(
// The wrapped send half; marked `#[pin]` so it can be reached through pin projection.
#[pin] C::SendSink,
// Zero-sized marker for the accepted item type `T`.
PhantomData<T>,
);
/// Forwards every `Sink` operation to the wrapped send sink. Items are
/// converted into `S::Req` on the way in; errors are the connection's
/// `SendError`.
impl<S: Service, C: ServiceConnection<S>, T: Into<S::Req>> Sink<T> for UpdateSink<S, C, T> {
    type Error = C::SendError;

    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let mut inner = self.project().0;
        inner.poll_ready_unpin(cx)
    }

    fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> {
        let mut inner = self.project().0;
        // Convert explicitly so the wrapped sink always sees the request type.
        inner.start_send_unpin(Into::<S::Req>::into(item))
    }

    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let mut inner = self.project().0;
        inner.poll_flush_unpin(cx)
    }

    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let mut inner = self.project().0;
        inner.poll_close_unpin(cx)
    }
}
impl<S: Service, C: ServiceConnection<S>> RpcClient<S, C> {
    /// Create a client for service `S` on top of the connection `source`.
    pub fn new(source: C) -> Self {
        let p = PhantomData;
        Self { source, p }
    }

    /// Consume the client and hand back the wrapped connection.
    pub fn into_inner(self) -> C {
        let Self { source, .. } = self;
        source
    }

    /// Perform a call with a single request and a single response.
    ///
    /// Opens a new channel pair on the connection, sends `msg` and waits for
    /// the first item on the receive side.
    ///
    /// # Errors
    ///
    /// * [`RpcClientError::Open`] if the channel pair cannot be opened.
    /// * [`RpcClientError::Send`] if sending the request fails.
    /// * [`RpcClientError::EarlyClose`] if the receive side ends without an item.
    /// * [`RpcClientError::RecvError`] if receiving yields an error.
    /// * [`RpcClientError::DowncastError`] if the received message cannot be
    ///   converted into `M::Response`.
    pub async fn rpc<M>(&self, msg: M) -> result::Result<M::Response, RpcClientError<C>>
    where
        M: RpcMsg<S>,
    {
        let request = msg.into();
        let (mut send, mut recv) = self
            .source
            .open_bi()
            .await
            .map_err(RpcClientError::<C>::Open)?;
        send.send(request)
            .map_err(RpcClientError::<C>::Send)
            .await?;
        let reply = match recv.next().await {
            Some(Ok(reply)) => reply,
            Some(Err(cause)) => return Err(RpcClientError::RecvError(cause)),
            None => return Err(RpcClientError::EarlyClose),
        };
        // The send half is released only after the reply has arrived.
        drop(send);
        M::Response::try_from(reply).map_err(|_| RpcClientError::DowncastError)
    }

    /// Perform a call with a single request and a stream of responses.
    ///
    /// On success the returned stream yields one converted response (or an
    /// item error) per message received. The send half of the channel is
    /// stored inside the returned stream and dropped together with it.
    ///
    /// # Errors
    ///
    /// * [`StreamingResponseError::Open`] if the channel pair cannot be opened.
    /// * [`StreamingResponseError::Send`] if sending the request fails.
    pub async fn server_streaming<M>(
        &self,
        msg: M,
    ) -> result::Result<
        BoxStream<'static, result::Result<M::Response, StreamingResponseItemError<C>>>,
        StreamingResponseError<C>,
    >
    where
        M: ServerStreamingMsg<S>,
    {
        let request = msg.into();
        let (mut send, recv) = self
            .source
            .open_bi()
            .await
            .map_err(StreamingResponseError::<C>::Open)?;
        send.send(request)
            .await
            .map_err(StreamingResponseError::<C>::Send)?;
        let responses = recv.map(move |incoming| match incoming {
            Err(cause) => Err(StreamingResponseItemError::RecvError(cause)),
            Ok(raw) => {
                M::Response::try_from(raw).map_err(|_| StreamingResponseItemError::DowncastError)
            }
        });
        // Bundle the send half with the stream so both are dropped together.
        Ok(DeferDrop(responses, send).boxed())
    }

    /// Perform a call with an initial request, a stream of updates and a
    /// single response.
    ///
    /// On success returns a sink for further `M::Update` messages and a
    /// future that resolves to the first item received, converted into
    /// `M::Response`.
    ///
    /// # Errors
    ///
    /// * [`ClientStreamingError::Open`] if the channel pair cannot be opened.
    /// * [`ClientStreamingError::Send`] if sending the initial request fails.
    pub async fn client_streaming<M>(
        &self,
        msg: M,
    ) -> result::Result<
        (
            UpdateSink<S, C, M::Update>,
            BoxFuture<'static, result::Result<M::Response, ClientStreamingItemError<C>>>,
        ),
        ClientStreamingError<C>,
    >
    where
        M: ClientStreamingMsg<S>,
    {
        let request = msg.into();
        let (mut send, mut recv) = self
            .source
            .open_bi()
            .await
            .map_err(ClientStreamingError::<C>::Open)?;
        send.send(request)
            .await
            .map_err(ClientStreamingError::<C>::Send)?;
        let updates = UpdateSink::<S, C, M::Update>(send, PhantomData);
        let response = async move {
            match recv.next().await {
                None => Err(ClientStreamingItemError::<C>::EarlyClose),
                Some(Err(cause)) => Err(ClientStreamingItemError::RecvError(cause)),
                Some(Ok(raw)) => {
                    M::Response::try_from(raw).map_err(|_| ClientStreamingItemError::DowncastError)
                }
            }
        }
        .boxed();
        Ok((updates, response))
    }

    /// Perform a call with an initial request, a stream of updates and a
    /// stream of responses.
    ///
    /// On success returns a sink for further `M::Update` messages and a
    /// stream that yields one converted response (or an item error) per
    /// message received.
    ///
    /// # Errors
    ///
    /// * [`BidiError::Open`] if the channel pair cannot be opened.
    /// * [`BidiError::Send`] if sending the initial request fails.
    pub async fn bidi<M>(
        &self,
        msg: M,
    ) -> result::Result<
        (
            UpdateSink<S, C, M::Update>,
            BoxStream<'static, result::Result<M::Response, BidiItemError<C>>>,
        ),
        BidiError<C>,
    >
    where
        M: BidiStreamingMsg<S>,
    {
        let request = msg.into();
        let (mut send, recv) = self
            .source
            .open_bi()
            .await
            .map_err(BidiError::<C>::Open)?;
        send.send(request).map_err(BidiError::<C>::Send).await?;
        let updates = UpdateSink::<S, C, M::Update>(send, PhantomData);
        let responses = recv
            .map(|incoming| match incoming {
                Err(cause) => Err(BidiItemError::RecvError(cause)),
                Ok(raw) => M::Response::try_from(raw).map_err(|_| BidiItemError::DowncastError),
            })
            .boxed();
        Ok((updates, responses))
    }
}
/// Borrow the wrapped connection without consuming the client.
impl<S: Service, C: ServiceConnection<S>> AsRef<C> for RpcClient<S, C> {
    fn as_ref(&self) -> &C {
        let Self { source, .. } = self;
        source
    }
}
/// Error returned by [`RpcClient::rpc`].
#[derive(Debug)]
pub enum RpcClientError<C: ConnectionErrors> {
/// Opening the channel pair on the connection failed.
Open(C::OpenError),
/// Sending the request failed.
Send(C::SendError),
/// The receive side ended before any response arrived.
EarlyClose,
/// Receiving the response yielded an error.
RecvError(C::RecvError),
/// The received message could not be converted into the expected response type.
DowncastError,
}
/// `Display` reuses the derived `Debug` representation of the error.
impl<C: ConnectionErrors> fmt::Display for RpcClientError<C> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Pass the caller's formatter straight through so its flags still apply.
        <Self as fmt::Debug>::fmt(self, f)
    }
}
// Marker impl: makes the error usable as a `std::error::Error`; the empty body keeps the trait's default methods.
impl<C: ConnectionErrors> error::Error for RpcClientError<C> {}
/// Error returned when setting up a call with [`RpcClient::bidi`] fails.
#[derive(Debug)]
pub enum BidiError<C: ConnectionErrors> {
/// Opening the channel pair on the connection failed.
Open(C::OpenError),
/// Sending the initial request failed.
Send(C::SendError),
}
/// `Display` reuses the derived `Debug` representation of the error.
impl<C: ConnectionErrors> fmt::Display for BidiError<C> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Pass the caller's formatter straight through so its flags still apply.
        <Self as fmt::Debug>::fmt(self, f)
    }
}
// Marker impl: makes the error usable as a `std::error::Error`; the empty body keeps the trait's default methods.
impl<C: ConnectionErrors> error::Error for BidiError<C> {}
/// Error yielded for a single item of the response stream returned by [`RpcClient::bidi`].
#[derive(Debug)]
pub enum BidiItemError<C: ConnectionErrors> {
/// Receiving the item yielded an error.
RecvError(C::RecvError),
/// The received message could not be converted into the expected response type.
DowncastError,
}
/// `Display` reuses the derived `Debug` representation of the error.
impl<C: ConnectionErrors> fmt::Display for BidiItemError<C> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Pass the caller's formatter straight through so its flags still apply.
        <Self as fmt::Debug>::fmt(self, f)
    }
}
// Marker impl: makes the error usable as a `std::error::Error`; the empty body keeps the trait's default methods.
impl<C: ConnectionErrors> error::Error for BidiItemError<C> {}
/// Error returned when setting up a call with [`RpcClient::client_streaming`] fails.
#[derive(Debug)]
pub enum ClientStreamingError<C: ConnectionErrors> {
/// Opening the channel pair on the connection failed.
Open(C::OpenError),
/// Sending the initial request failed.
Send(C::SendError),
}
/// `Display` reuses the derived `Debug` representation of the error.
impl<C: ConnectionErrors> fmt::Display for ClientStreamingError<C> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Pass the caller's formatter straight through so its flags still apply.
        <Self as fmt::Debug>::fmt(self, f)
    }
}
// Marker impl: makes the error usable as a `std::error::Error`; the empty body keeps the trait's default methods.
impl<C: ConnectionErrors> error::Error for ClientStreamingError<C> {}
/// Error produced by the response future returned from [`RpcClient::client_streaming`].
#[derive(Debug)]
pub enum ClientStreamingItemError<C: ConnectionErrors> {
/// The receive side ended before any response arrived.
EarlyClose,
/// Receiving the response yielded an error.
RecvError(C::RecvError),
/// The received message could not be converted into the expected response type.
DowncastError,
}
/// `Display` reuses the derived `Debug` representation of the error.
impl<C: ConnectionErrors> fmt::Display for ClientStreamingItemError<C> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Pass the caller's formatter straight through so its flags still apply.
        <Self as fmt::Debug>::fmt(self, f)
    }
}
// Marker impl: makes the error usable as a `std::error::Error`; the empty body keeps the trait's default methods.
impl<C: ConnectionErrors> error::Error for ClientStreamingItemError<C> {}
/// Error returned when setting up a call with [`RpcClient::server_streaming`] fails.
#[derive(Debug)]
pub enum StreamingResponseError<C: ConnectionErrors> {
/// Opening the channel pair on the connection failed.
Open(C::OpenError),
/// Sending the request failed.
Send(C::SendError),
}
/// `Display` reuses the derived `Debug` representation of the error.
impl<C: ConnectionErrors> fmt::Display for StreamingResponseError<C> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Pass the caller's formatter straight through so its flags still apply.
        <Self as fmt::Debug>::fmt(self, f)
    }
}
impl<S: ConnectionErrors> error::Error for StreamingResponseError<S> {}
/// Error yielded for a single item of the response stream returned by
/// [`RpcClient::server_streaming`].
#[derive(Debug)]
pub enum StreamingResponseItemError<C: ConnectionErrors> {
    /// Receiving the item yielded an error.
    RecvError(C::RecvError),
    /// The received message could not be converted into the expected response type.
    DowncastError,
}
/// `Display` reuses the derived `Debug` representation of the error.
impl<C: ConnectionErrors> fmt::Display for StreamingResponseItemError<C> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Pass the caller's formatter straight through so its flags still apply.
        <Self as fmt::Debug>::fmt(self, f)
    }
}
impl<S: ConnectionErrors> error::Error for StreamingResponseItemError<S> {}
/// Stream wrapper that carries an extra value of type `X` next to the stream `S`.
///
/// Only the first field is ever polled; the second field is merely stored, so it
/// is dropped when the wrapper is dropped. `server_streaming` uses this to bundle
/// the send half of a channel with the response stream.
#[pin_project]
struct DeferDrop<S: Stream, X>(#[pin] S, X);
/// Yields exactly the items of the wrapped stream; the extra value is never touched.
impl<S: Stream, X> Stream for DeferDrop<S, X> {
    type Item = S::Item;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let inner: Pin<&mut S> = self.project().0;
        inner.poll_next(cx)
    }
}