1#![forbid(unsafe_code)]
39#![warn(missing_docs, missing_debug_implementations, rust_2018_idioms)]
40
41use async_std::{channel, prelude::FutureExt, sync::Mutex};
42use std::any::Any;
43
44#[doc(hidden)]
45#[cfg(feature = "derive")]
46pub use async_reply_derive::*;
47
48pub fn endpoints() -> (Requester, Replyer) {
51 let (sndr, recv) = channel::bounded(10);
52 (
53 Requester { inner: sndr },
54 Replyer {
55 buffer: Mutex::default(),
56 inner: recv,
57 },
58 )
59}
60
61#[derive(Debug, Clone)]
63pub struct Requester {
64 inner: channel::Sender<Box<dyn Any + Send>>,
65}
66
67#[derive(Debug)]
69pub struct Replyer {
70 buffer: Mutex<Vec<Box<dyn Any + Send>>>,
71 inner: channel::Receiver<Box<dyn Any + Send>>,
72}
73
74#[must_use = "ReplyHandle should be used to respond to the received message"]
76#[derive(Debug)]
77pub struct ReplyHandle<T>(channel::Sender<T>);
78
79struct MessageHandle<M: Message> {
80 msg: M,
81 sndr: ReplyHandle<M::Response>,
82}
83
/// A request type that can be sent through a [`Requester`].
///
/// Implementors must be `Send + 'static` because requests are boxed as
/// `dyn Any + Send` before they enter the channel.
pub trait Message: Send + 'static {
    /// The type of the answer produced for this request.
    type Response: Send;
}
89
90impl Requester {
91 pub async fn send<M>(&self, msg: M) -> Result<M::Response, Error>
93 where
94 M: Message,
95 {
96 let (sndr, recv) = channel::bounded::<M::Response>(1);
97 let sndr = ReplyHandle(sndr);
98
99 self.inner
100 .send(Box::new(MessageHandle { msg, sndr }))
101 .await?;
102
103 recv.recv().await.map_err(Error::ReplayError)
104 }
105}
106
107impl Replyer {
108 pub async fn recv<M>(&self) -> Result<(M, ReplyHandle<M::Response>), Error>
110 where
111 M: Message,
112 {
113 let is_message_type = |any: &Box<dyn Any + Send>| any.is::<MessageHandle<M>>();
114
115 loop {
116 let buffer_search_fut = async {
117 loop {
118 let mut buffer = self.buffer.lock().await;
119 let msg_index = buffer
120 .iter()
121 .enumerate()
122 .find(|(_, elem)| is_message_type(elem))
123 .map(|(index, _)| index);
124 if let Some(index) = msg_index {
125 return Ok(buffer.remove(index));
128 }
129 async_std::task::yield_now().await;
130 }
131 };
132 let channel_search_fut = async { self.inner.recv().await.map_err(Error::ReceivError) };
133
134 let msg = buffer_search_fut.race(channel_search_fut).await?;
135 if is_message_type(&msg) {
136 return Ok(msg.downcast::<MessageHandle<M>>().unwrap().into_tuple());
137 }
138 self.buffer.lock().await.push(msg);
139 }
140 }
141}
142
143impl<T> ReplyHandle<T> {
144 pub async fn respond(&self, r: T) -> Result<(), Error> {
146 Ok(self.0.send(r).await?)
147 }
148}
149
150impl<M: Message> MessageHandle<M> {
151 fn into_tuple(self) -> (M, ReplyHandle<M::Response>) {
152 (self.msg, self.sndr)
153 }
154}
155
156#[derive(Debug, derive_more::Display, derive_more::Error, derive_more::From)]
159pub enum Error {
160 SendError,
162
163 #[from(ignore)]
165 ReplayError(channel::RecvError),
166
167 ReceivError(channel::RecvError),
169}
170
171impl<T> From<channel::SendError<T>> for Error {
172 fn from(_: channel::SendError<T>) -> Self {
173 Error::SendError
178 }
179}