use futures::{Future, FutureExt};
use std::{
pin::Pin,
task::{Context, Poll},
};
use tokio::sync::oneshot;
use crate::all::*;
/// Creates a connected `(Tx, Rx)` pair backed by a fresh tokio one-shot channel.
///
/// The [`Tx`] half can deliver at most one value of type `T` to the [`Rx`] half.
pub fn new_request<T>() -> (Tx<T>, Rx<T>) {
    let (sender, receiver) = oneshot::channel::<T>();
    let tx = Tx(sender);
    let rx = Rx(receiver);
    (tx, rx)
}
#[derive(Debug)]
pub struct Tx<M>(pub(super) oneshot::Sender<M>);
impl<M> Tx<M> {
pub fn send(self, msg: M) -> Result<(), TxError<M>> {
self.0.send(msg).map_err(|msg| TxError(msg))
}
pub fn is_closed(&self) -> bool {
self.0.is_closed()
}
pub async fn closed(&mut self) {
self.0.closed().await
}
}
/// Derives a message whose payload carries a one-shot receiver for `R`.
///
/// `create` pairs `msg` with the [`Rx<R>`] half (the payload) and returns the
/// [`Tx<R>`] half to the caller. `cancel` recovers `msg`; both channel halves
/// are dropped.
/// NOTE(review): the full contract of `MessageDerive` is defined elsewhere — confirm.
impl<M, R> MessageDerive<M> for Tx<R> {
    type Payload = (M, Rx<R>);
    type Returned = Tx<R>;

    fn create(msg: M) -> (Self::Payload, Self::Returned) {
        let (reply_tx, reply_rx) = new_request::<R>();
        let payload = (msg, reply_rx);
        (payload, reply_tx)
    }

    fn cancel(sent: (M, Rx<R>), _returned: Tx<R>) -> M {
        // `_` does not bind, so only the message is moved out of `sent`.
        let (msg, _) = sent;
        msg
    }
}
#[derive(Debug)]
pub struct Rx<M>(pub(super) oneshot::Receiver<M>);
impl<M> Rx<M> {
pub fn try_recv(&mut self) -> Result<M, TryRxError> {
self.0.try_recv().map_err(|e| e.into())
}
pub fn recv_blocking(self) -> Result<M, RxError> {
self.0.blocking_recv().map_err(|e| e.into())
}
pub fn close(&mut self) {
self.0.close()
}
}
/// Derives a message whose payload carries a one-shot sender for `R`.
///
/// `create` pairs `msg` with the [`Tx<R>`] half (the payload) and returns the
/// [`Rx<R>`] half to the caller. `cancel` recovers `msg`; both channel halves
/// are dropped.
/// NOTE(review): the full contract of `MessageDerive` is defined elsewhere — confirm.
impl<M, R> MessageDerive<M> for Rx<R> {
    type Payload = (M, Tx<R>);
    type Returned = Rx<R>;

    fn create(msg: M) -> (Self::Payload, Self::Returned) {
        let (reply_tx, reply_rx) = new_request::<R>();
        let payload = (msg, reply_tx);
        (payload, reply_rx)
    }

    fn cancel(sent: (M, Tx<R>), _returned: Rx<R>) -> M {
        // `_` does not bind, so only the message is moved out of `sent`.
        let (msg, _) = sent;
        msg
    }
}
// `Future::poll` for `Rx` takes `Pin<&mut Self>` and reaches into `self.0`,
// which requires `Rx<M>: Unpin`. This impl states that for every `M`. Nothing
// in this file pin-projects into the inner receiver, so declaring it is sound.
// NOTE(review): tokio's `oneshot::Receiver<T>` appears to be `Unpin` for all
// `T` already, so this impl may be redundant — confirm before removing.
impl<M> Unpin for Rx<M> {}
impl<M> Future for Rx<M> {
type Output = Result<M, RxError>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
self.0.poll_unpin(cx).map_err(|e| e.into())
}
}
#[derive(Debug, Clone, Copy, PartialEq, PartialOrd, Eq, Ord, Hash, thiserror::Error)]
#[error("Failed to receive from Rx because it is closed.")]
pub struct RxError;
impl From<oneshot::error::RecvError> for RxError {
fn from(_: oneshot::error::RecvError) -> Self {
Self
}
}
#[derive(Debug, Clone, Copy, PartialEq, PartialOrd, Eq, Ord, Hash, thiserror::Error)]
pub enum TryRxError {
#[error("Closed")]
Closed,
#[error("Empty")]
Empty,
}
impl From<oneshot::error::TryRecvError> for TryRxError {
fn from(e: oneshot::error::TryRecvError) -> Self {
match e {
oneshot::error::TryRecvError::Empty => Self::Empty,
oneshot::error::TryRecvError::Closed => Self::Closed,
}
}
}
/// Error returned by [`Tx::send`] when the value could not be delivered.
///
/// The undelivered value is handed back in field `0` so the caller can reuse it.
/// The derived traits (`Clone`, `Copy`, `Eq`, `Hash`, ...) are only available
/// when `M` itself implements them.
#[derive(Debug, Clone, Copy, PartialEq, PartialOrd, Eq, Ord, Hash, thiserror::Error)]
#[error("Failed to send to Tx because it is closed.")]
pub struct TxError<M>(pub M);