use crate::error::{ReceiveError, RequestError, RespondError, SendError};
use tokio::sync::{mpsc, oneshot};
use tokio::time::{timeout, Duration};
use futures::Stream;
use std::pin::Pin;
use std::task::{Context, Poll};
/// A queued request paired with the [`Responder`] the receiving side uses to answer it.
pub type Payload<Req, Res> = (Req, Responder<Res>);
/// Sending half of a request/response channel.
///
/// Every request sent through it travels together with a fresh [`Responder`], and the caller
/// gets back a [`ResponseReceiver`] on which the reply arrives.
#[derive(Debug)]
pub struct RequestSender<Req, Res> {
    // Underlying queue of `(request, responder)` pairs.
    request_sender: mpsc::Sender<Payload<Req, Res>>,
    // Timeout copied into every `ResponseReceiver` this sender creates; `None` means no limit.
    timeout_duration: Option<Duration>,
}
/// Receiving half of a request/response channel.
///
/// Yields each incoming request together with the [`Responder`] used to reply to it.
#[derive(Debug)]
pub struct RequestReceiver<Req, Res> {
    // Underlying queue of `(request, responder)` pairs.
    request_receiver: mpsc::Receiver<Payload<Req, Res>>,
}
/// Single-use handle for answering one request.
#[derive(Debug)]
pub struct Responder<Res> {
    // Oneshot sender whose paired receiver lives in the requester's `ResponseReceiver`.
    response_sender: oneshot::Sender<Res>,
}
/// Handle on which the requester awaits the response to one request.
#[derive(Debug)]
pub struct ResponseReceiver<Res> {
    // Oneshot receiver for the reply. Becomes `None` once `recv` has used it up; later
    // `recv` calls then return `ReceiveError::RecvError`.
    pub(crate) response_receiver: Option<oneshot::Receiver<Res>>,
    // Optional upper bound on how long `recv` waits; `None` waits without a limit.
    pub(crate) timeout_duration: Option<Duration>,
}
impl<Req, Res> RequestSender<Req, Res> {
fn new(
request_sender: mpsc::Sender<Payload<Req, Res>>,
timeout_duration: Option<Duration>,
) -> Self {
RequestSender {
request_sender,
timeout_duration,
}
}
pub async fn send(&self, request: Req) -> Result<ResponseReceiver<Res>, SendError<Req>> {
let (response_sender, response_receiver) = oneshot::channel::<Res>();
let responder = Responder::new(response_sender);
let payload = (request, responder);
self.request_sender
.send(payload)
.await
.map_err(|payload| SendError(payload.0 .0))?;
let receiver = ResponseReceiver::new(response_receiver, self.timeout_duration);
Ok(receiver)
}
pub async fn send_receive(&self, request: Req) -> Result<Res, RequestError<Req>> {
let mut receiver = self.send(request).await?;
receiver.recv().await.map_err(|err| err.into())
}
pub fn is_closed(&self) -> bool {
self.request_sender.is_closed()
}
}
impl<Req, Res> Clone for RequestSender<Req, Res> {
fn clone(&self) -> Self {
RequestSender {
request_sender: self.request_sender.clone(),
timeout_duration: self.timeout_duration,
}
}
}
impl<Req, Res> RequestReceiver<Req, Res> {
fn new(receiver: mpsc::Receiver<Payload<Req, Res>>) -> Self {
RequestReceiver {
request_receiver: receiver,
}
}
pub async fn recv(&mut self) -> Result<Payload<Req, Res>, RequestError<Req>> {
match self.request_receiver.recv().await {
Some(payload) => Ok(payload),
None => Err(RequestError::RecvError),
}
}
pub fn close(&mut self) {
self.request_receiver.close()
}
pub fn into_stream(self) -> impl Stream<Item = Payload<Req, Res>> {
let stream: RequestReceiverStream<Req, Res> = self.into();
stream
}
}
impl<Res> ResponseReceiver<Res> {
    /// Wraps a oneshot receiver together with the optional timeout that
    /// [`recv`](Self::recv) applies.
    pub(crate) fn new(
        response_receiver: oneshot::Receiver<Res>,
        timeout_duration: Option<Duration>,
    ) -> Self {
        Self {
            response_receiver: Some(response_receiver),
            timeout_duration,
        }
    }

    /// Waits for the response to the request this receiver belongs to.
    ///
    /// If a timeout was configured, the wait is bounded by it, measured from the start of
    /// this call.
    ///
    /// The underlying receiver is used up only when this call *finishes*: with a value, an
    /// error, or a timeout. If the returned future is dropped before that (for example, it lost
    /// a `tokio::select!` race), the receiver is kept. In that case `recv` may be called again
    /// and a late response is not lost.
    ///
    /// # Errors
    ///
    /// - `ReceiveError::TimeoutError` if the configured timeout elapses first.
    /// - `ReceiveError::RecvError` if a previous `recv` call already finished.
    /// - The `From`-converted oneshot receive error if the responding side dropped its
    ///   [`Responder`] without replying.
    pub async fn recv(&mut self) -> Result<Res, ReceiveError> {
        // Borrow rather than `take()` so that cancelling this future leaves the receiver intact.
        let receiver = self
            .response_receiver
            .as_mut()
            .ok_or(ReceiveError::RecvError)?;
        let outcome = match self.timeout_duration {
            Some(duration) => match timeout(duration, receiver).await {
                Ok(result) => result.map_err(ReceiveError::from),
                Err(_elapsed) => Err(ReceiveError::TimeoutError),
            },
            None => receiver.await.map_err(ReceiveError::from),
        };
        // The wait has finished, so the receiver is spent. A completed oneshot receiver must
        // not be polled again, and after a timeout further calls fail instead of waiting again.
        self.response_receiver = None;
        outcome
    }
}
impl<Res> Responder<Res> {
pub(crate) fn new(response_sender: oneshot::Sender<Res>) -> Self {
Self { response_sender }
}
pub fn respond(self, response: Res) -> Result<(), RespondError<Res>> {
self.response_sender.send(response).map_err(RespondError)
}
pub fn is_closed(&self) -> bool {
self.response_sender.is_closed()
}
}
/// Creates a request/response channel whose responses are awaited without a time limit.
///
/// `buffer` is the capacity of the underlying request queue.
///
/// # Panics
///
/// Panics if `buffer` is 0, as `tokio::sync::mpsc::channel` does.
pub fn channel<Req, Res>(buffer: usize) -> (RequestSender<Req, Res>, RequestReceiver<Req, Res>) {
    let (tx, rx) = mpsc::channel::<Payload<Req, Res>>(buffer);
    (RequestSender::new(tx, None), RequestReceiver::new(rx))
}
/// Creates a request/response channel whose responses are awaited for at most
/// `timeout_duration`.
///
/// `buffer` is the capacity of the underlying request queue.
///
/// # Panics
///
/// Panics if `buffer` is 0, as `tokio::sync::mpsc::channel` does.
pub fn channel_with_timeout<Req, Res>(
    buffer: usize,
    timeout_duration: Duration,
) -> (RequestSender<Req, Res>, RequestReceiver<Req, Res>) {
    let limit = Some(timeout_duration);
    let (tx, rx) = mpsc::channel::<Payload<Req, Res>>(buffer);
    (RequestSender::new(tx, limit), RequestReceiver::new(rx))
}
/// [`Stream`] adapter over a [`RequestReceiver`], yielding each incoming [`Payload`].
#[derive(Debug)]
pub struct RequestReceiverStream<Req, Res> {
    // The wrapped receiver; polled directly by the `Stream` implementation.
    inner: RequestReceiver<Req, Res>,
}
impl<Req, Res> RequestReceiverStream<Req, Res> {
pub fn new(recv: RequestReceiver<Req, Res>) -> Self {
Self { inner: recv }
}
#[cfg(not(tarpaulin_include))]
pub fn into_inner(self) -> RequestReceiver<Req, Res> {
self.inner
}
#[cfg(not(tarpaulin_include))]
pub fn close(&mut self) {
self.inner.close()
}
}
impl<Req, Res> Stream for RequestReceiverStream<Req, Res> {
    type Item = Payload<Req, Res>;
    /// Polls the wrapped mpsc receiver for the next payload; `None` ends the stream.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // Unwrapping the pin requires `Self: Unpin`; the stream only holds the wrapped receiver.
        let this = Pin::into_inner(self);
        this.inner.request_receiver.poll_recv(cx)
    }
}
impl<Req, Res> AsRef<RequestReceiver<Req, Res>> for RequestReceiverStream<Req, Res> {
    /// Borrows the wrapped [`RequestReceiver`].
    #[cfg(not(tarpaulin_include))]
    fn as_ref(&self) -> &RequestReceiver<Req, Res> {
        let Self { inner } = self;
        inner
    }
}
impl<Req, Res> AsMut<RequestReceiver<Req, Res>> for RequestReceiverStream<Req, Res> {
    /// Mutably borrows the wrapped [`RequestReceiver`].
    #[cfg(not(tarpaulin_include))]
    fn as_mut(&mut self) -> &mut RequestReceiver<Req, Res> {
        let Self { inner } = self;
        inner
    }
}
impl<Req, Res> From<RequestReceiver<Req, Res>> for RequestReceiverStream<Req, Res> {
fn from(receiver: RequestReceiver<Req, Res>) -> Self {
RequestReceiverStream::new(receiver)
}
}