1#![forbid(unsafe_code)]
2
3#[cfg(test)]
4mod tests;
5
6use futures::{
7 channel::mpsc::{unbounded, UnboundedReceiver as Receiver, UnboundedSender as Sender},
8 lock::Mutex as AsyncMutex,
9 StreamExt,
10};
11use std::{
12 collections::hash_map::HashMap,
13 mem,
14 sync::{
15 atomic::{AtomicU32, Ordering},
16 Mutex,
17 },
18};
19
/// Identifier that ties a response back to the request it answers.
type Id = u32;

/// Atomic counterpart of [`Id`]; the requester draws fresh request ids from a counter of this
/// type.
type AtomicId = AtomicU32;
22
23struct Tx<T>(Id, T);
24struct Rx<R>(Id, Option<R>);
25
26pub struct Requester<T, R> {
27 sender: Sender<Tx<T>>,
28 receiver: AsyncMutex<Receiver<Rx<R>>>,
29 buffer: Mutex<HashMap<Id, Option<Option<R>>>>,
35 counter: AtomicId,
36}
37
38pub struct Responder<T, R> {
39 receiver: Receiver<Tx<T>>,
40 sender: Sender<Rx<R>>,
41}
42
43pub fn channel<T, R>() -> (Requester<T, R>, Responder<T, R>) {
44 let (tx_sender, tx_receiver) = unbounded::<Tx<T>>();
45 let (rx_sender, rx_receiver) = unbounded::<Rx<R>>();
46 (
47 Requester {
48 sender: tx_sender,
49 receiver: AsyncMutex::new(rx_receiver),
50 buffer: Mutex::new(HashMap::new()),
51 counter: AtomicId::new(0),
52 },
53 Responder {
54 receiver: tx_receiver,
55 sender: rx_sender,
56 },
57 )
58}
59
60pub struct Request<'a, R> {
62 id: Id,
63 receiver: &'a AsyncMutex<Receiver<Rx<R>>>,
64 buffer: &'a Mutex<HashMap<Id, Option<Option<R>>>>,
65}
66
67impl<T, R> Requester<T, R> {
68 pub fn request(&self, message: T) -> Result<Request<'_, R>, T> {
74 let id = self.counter.fetch_add(1, Ordering::SeqCst);
75 let mut buffer = self.buffer.lock().unwrap();
76 debug_assert!(!buffer.contains_key(&id));
77 match self.sender.unbounded_send(Tx(id, message)) {
78 Ok(()) => assert!(buffer.insert(id, None).is_none()),
79 Err(err) => return Err(err.into_inner().1),
80 }
81 Ok(Request {
82 id,
83 receiver: &self.receiver,
84 buffer: &self.buffer,
85 })
86 }
87}
88
89impl<'a, R> Request<'a, R> {
90 fn take_from_buffer(&self) -> Option<Option<R>> {
91 self.buffer
92 .lock()
93 .unwrap()
94 .get_mut(&self.id)
95 .unwrap()
96 .take()
97 }
98
99 fn put_in_buffer(&self, id: Id, message: Option<R>) {
100 if let Some(value) = self.buffer.lock().unwrap().get_mut(&id) {
101 assert!(value.replace(message).is_none());
102 }
103 }
104
105 pub fn try_get_response(self) -> Option<Option<R>> {
111 if let Some(value) = self.take_from_buffer() {
112 return Some(value);
113 }
114
115 let mut guard = self.receiver.try_lock()?;
116
117 if let Some(value) = self.take_from_buffer() {
119 return Some(value);
120 }
121
122 loop {
123 match guard.try_next().ok()? {
124 Some(Rx(id, message)) => {
125 if id == self.id {
126 return Some(message);
127 }
128 self.put_in_buffer(id, message);
129 }
130 None => return Some(None),
131 }
132 }
133 }
134
135 pub async fn get_response(self) -> Option<R> {
141 if let Some(value) = self.take_from_buffer() {
142 return value;
143 }
144
145 let mut guard = self.receiver.lock().await;
146
147 if let Some(value) = self.take_from_buffer() {
149 return value;
150 }
151
152 while let Some(Rx(id, message)) = guard.next().await {
153 if id == self.id {
154 return message;
155 }
156 self.put_in_buffer(id, message);
157 }
158
159 None
160 }
161}
162
163impl<'a, R> Drop for Request<'a, R> {
164 fn drop(&mut self) {
165 self.buffer.lock().unwrap().remove(&self.id).unwrap();
166 }
167}
168
169pub struct Response<'a, R> {
173 id: Id,
174 sender: &'a mut Sender<Rx<R>>,
175}
176
177impl<T, R> Responder<T, R> {
178 pub async fn next(&mut self) -> Option<(T, Response<'_, R>)> {
186 let Tx(id, message) = self.receiver.next().await?;
187 Some((
188 message,
189 Response {
190 id,
191 sender: &mut self.sender,
192 },
193 ))
194 }
195
196 pub fn try_next(&mut self) -> Option<Option<(T, Response<'_, R>)>> {
203 self.receiver.try_next().ok().map(|r| {
204 r.map(|Tx(id, message)| {
205 (
206 message,
207 Response {
208 id,
209 sender: &mut self.sender,
210 },
211 )
212 })
213 })
214 }
215}
216
217impl<'a, R> Response<'a, R> {
218 pub fn respond(self, message: R) {
220 let _ = self.sender.unbounded_send(Rx(self.id, Some(message)));
221 mem::forget(self);
223 }
224}
225
226impl<'a, R> Drop for Response<'a, R> {
227 fn drop(&mut self) {
228 let _ = self.sender.unbounded_send(Rx(self.id, None));
229 }
230}