use tokio::sync::{mpsc, oneshot};
use crate::error::{CallError, TransportResult};
use crate::transport::{ClientTransport, ServerTransport};
/// Owner of an in-process request/response channel pair.
///
/// Requests travel over a bounded `mpsc` channel as `(request, reply sender)`
/// tuples; each request carries its own `oneshot` sender on which the single
/// response is delivered.
pub struct TokioService<Req, Resp> {
// Sending half of the request queue; cloned into every client handle.
tx: mpsc::Sender<(Req, oneshot::Sender<Resp>)>,
// Receiving half of the request queue; `Some` until it is moved out into the
// server handle, `None` afterwards.
rx: Option<mpsc::Receiver<(Req, oneshot::Sender<Resp>)>>,
}
impl<Req, Resp> TokioService<Req, Resp> {
    /// Builds a service around a fresh bounded request queue.
    ///
    /// `channel_depth` is handed directly to `mpsc::channel` as the queue
    /// capacity.
    ///
    /// # Panics
    ///
    /// Tokio documents that `mpsc::channel` panics when the capacity is `0`.
    pub fn new(channel_depth: usize) -> Self {
        let (sender, receiver) = mpsc::channel(channel_depth);
        Self {
            tx: sender,
            rx: Some(receiver),
        }
    }

    /// Returns a new client handle that submits requests to this service.
    ///
    /// May be called any number of times; every handle shares the same queue.
    pub fn client(&self) -> TokioClient<Req, Resp> {
        let tx = self.tx.clone();
        TokioClient { tx }
    }

    /// Hands out the single server handle that drains the request queue.
    ///
    /// # Panics
    ///
    /// Panics if called more than once, because the receiver has already been
    /// moved into the first server handle.
    pub fn server(&mut self) -> TokioServer<Req, Resp> {
        let rx = self.rx.take().expect("server() called more than once");
        TokioServer { rx }
    }
}
/// Client handle that submits requests to the service's request queue.
///
/// Handles are cheap to clone: a clone only duplicates the channel sender and
/// never touches a request or response value.
pub struct TokioClient<Req, Resp> {
    // Sending half of the request queue shared with the owning service.
    tx: mpsc::Sender<(Req, oneshot::Sender<Resp>)>,
}

// Implemented by hand rather than derived: `#[derive(Clone)]` would add
// `Req: Clone` and `Resp: Clone` bounds, although cloning the handle only
// clones the sender. Without those bounds, clients for non-`Clone` message
// types can be cloned too.
impl<Req, Resp> Clone for TokioClient<Req, Resp> {
    fn clone(&self) -> Self {
        Self {
            tx: self.tx.clone(),
        }
    }
}
/// Server handle that receives queued requests together with the reply
/// sender for each one.
pub struct TokioServer<Req, Resp> {
// Receiving half of the request queue; each item pairs a request with the
// `oneshot` sender on which its response is delivered.
rx: mpsc::Receiver<(Req, oneshot::Sender<Resp>)>,
}
/// Error reported by the in-process Tokio transport.
///
/// The type is a plain, field-less enum, so it is `Copy` and comparable,
/// which keeps it easy to assert on in tests. It implements
/// [`std::error::Error`] so it can be boxed as `dyn Error` or used as the
/// source of a higher-level error.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TokioLocalError {
    /// A channel needed to complete the exchange was closed.
    ChannelClosed,
}

impl core::fmt::Display for TokioLocalError {
    /// Writes the short human-readable description of the error.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // Exhaustive match on purpose: adding a variant must force an update here.
        match self {
            TokioLocalError::ChannelClosed => f.write_str("channel closed"),
        }
    }
}

impl std::error::Error for TokioLocalError {}
/// Converts results whose error is a [`TokioLocalError`] into results whose
/// error is a [`CallError`] wrapping that transport error.
impl<T> TransportResult<T> for TokioLocalError {
    /// Same success type; the error is wrapped in `CallError`.
    type Output = Result<T, CallError<Self>>;

    /// Passes `Ok` values through untouched and wraps any error in
    /// `CallError::Transport`.
    fn into_output(result: Result<T, Self>) -> Self::Output {
        match result {
            Ok(value) => Ok(value),
            Err(transport_err) => Err(CallError::Transport(transport_err)),
        }
    }
}
/// Client side of the transport: one queued request, one awaited response.
impl<Req, Resp> ClientTransport<Req, Resp> for TokioClient<Req, Resp>
where
    Req: Send + 'static,
    Resp: Send + 'static,
{
    type Error = TokioLocalError;

    /// Sends `req` to the server and waits for its response.
    ///
    /// A fresh `oneshot` channel is created per call; its sender travels with
    /// the request so the server can answer this specific call.
    ///
    /// # Errors
    ///
    /// Returns [`TokioLocalError::ChannelClosed`] when the request cannot be
    /// queued, or when waiting on the reply channel fails (Tokio reports this
    /// when the reply sender is dropped without sending).
    async fn call(&self, req: Req) -> Result<Resp, Self::Error> {
        let (reply_tx, reply_rx) = oneshot::channel();

        // Queue the request together with the sender for its reply.
        if self.tx.send((req, reply_tx)).await.is_err() {
            return Err(TokioLocalError::ChannelClosed);
        }

        match reply_rx.await {
            Ok(resp) => Ok(resp),
            Err(_) => Err(TokioLocalError::ChannelClosed),
        }
    }
}
/// Server side of the transport: pulls requests off the queue and answers
/// each one through the reply sender that arrived with it.
impl<Req, Resp> ServerTransport<Req, Resp> for TokioServer<Req, Resp>
where
    Req: Send + 'static,
    Resp: Send + 'static,
{
    type Error = TokioLocalError;

    /// The per-request `oneshot` sender; consumed when the reply is sent.
    type ReplyToken = oneshot::Sender<Resp>;

    /// Waits for the next request and returns it with its reply token.
    ///
    /// # Errors
    ///
    /// Returns [`TokioLocalError::ChannelClosed`] when the queue yields no
    /// further item (`recv` returned `None`).
    async fn recv(&mut self) -> Result<(Req, Self::ReplyToken), Self::Error> {
        match self.rx.recv().await {
            Some(incoming) => Ok(incoming),
            None => Err(TokioLocalError::ChannelClosed),
        }
    }

    /// Delivers `resp` to the caller identified by `token`.
    ///
    /// # Errors
    ///
    /// Returns [`TokioLocalError::ChannelClosed`] when the response cannot be
    /// handed over; the undelivered response is dropped.
    async fn reply(&self, token: Self::ReplyToken, resp: Resp) -> Result<(), Self::Error> {
        match token.send(resp) {
            Ok(()) => Ok(()),
            Err(_undelivered) => Err(TokioLocalError::ChannelClosed),
        }
    }
}