request_channel/
lib.rs

1#![forbid(unsafe_code)]
2
3#[cfg(test)]
4mod tests;
5
6use futures::{
7    channel::mpsc::{unbounded, UnboundedReceiver as Receiver, UnboundedSender as Sender},
8    lock::Mutex as AsyncMutex,
9    StreamExt,
10};
11use std::{
12    collections::hash_map::HashMap,
13    mem,
14    sync::{
15        atomic::{AtomicU32, Ordering},
16        Mutex,
17    },
18};
19
/// Identifier that pairs a request with its response.
type Id = u32;
/// Atomic counterpart of [`Id`], used to hand out fresh identifiers.
type AtomicId = AtomicU32;

/// Message travelling from requester to responder: the request id and its payload.
struct Tx<T>(Id, T);
/// Message travelling from responder to requester: the request id and the optional reply.
struct Rx<R>(Id, Option<R>);
25
26pub struct Requester<T, R> {
27    sender: Sender<Tx<T>>,
28    receiver: AsyncMutex<Receiver<Rx<R>>>,
29    /// Buffer contains ids of all `Request`s waiting for response.
30    /// Possible values and their meaning:
31    /// + `None` - response may arrive in future.
32    /// + `Some(None)` - response will never arrive.
33    /// + `Some(Some(message))` - response arrived but hasn't been extracted by corresponding `Request`.
34    buffer: Mutex<HashMap<Id, Option<Option<R>>>>,
35    counter: AtomicId,
36}
37
38pub struct Responder<T, R> {
39    receiver: Receiver<Tx<T>>,
40    sender: Sender<Rx<R>>,
41}
42
43pub fn channel<T, R>() -> (Requester<T, R>, Responder<T, R>) {
44    let (tx_sender, tx_receiver) = unbounded::<Tx<T>>();
45    let (rx_sender, rx_receiver) = unbounded::<Rx<R>>();
46    (
47        Requester {
48            sender: tx_sender,
49            receiver: AsyncMutex::new(rx_receiver),
50            buffer: Mutex::new(HashMap::new()),
51            counter: AtomicId::new(0),
52        },
53        Responder {
54            receiver: tx_receiver,
55            sender: rx_sender,
56        },
57    )
58}
59
60/// Request handle. Used as a promise for response.
61pub struct Request<'a, R> {
62    id: Id,
63    receiver: &'a AsyncMutex<Receiver<Rx<R>>>,
64    buffer: &'a Mutex<HashMap<Id, Option<Option<R>>>>,
65}
66
67impl<T, R> Requester<T, R> {
68    /// Make request.
69    ///
70    /// This function returns:
71    /// + `Ok(request)` - request made where `request` is and object used to get response whent it's ready.
72    /// + `Err(message)` - [`Responder`] is closed, `message` is returned back.
73    pub fn request(&self, message: T) -> Result<Request<'_, R>, T> {
74        let id = self.counter.fetch_add(1, Ordering::SeqCst);
75        let mut buffer = self.buffer.lock().unwrap();
76        debug_assert!(!buffer.contains_key(&id));
77        match self.sender.unbounded_send(Tx(id, message)) {
78            Ok(()) => assert!(buffer.insert(id, None).is_none()),
79            Err(err) => return Err(err.into_inner().1),
80        }
81        Ok(Request {
82            id,
83            receiver: &self.receiver,
84            buffer: &self.buffer,
85        })
86    }
87}
88
89impl<'a, R> Request<'a, R> {
90    fn take_from_buffer(&self) -> Option<Option<R>> {
91        self.buffer
92            .lock()
93            .unwrap()
94            .get_mut(&self.id)
95            .unwrap()
96            .take()
97    }
98
99    fn put_in_buffer(&self, id: Id, message: Option<R>) {
100        if let Some(value) = self.buffer.lock().unwrap().get_mut(&id) {
101            assert!(value.replace(message).is_none());
102        }
103    }
104
105    /// Try get response without waiting.
106    ///
107    /// This function returns:
108    /// + `None` - no response yet but it may arrive in future.
109    /// + `Some(response)` - response arrived or it will never arrive (see [`Self::get_response`]).
110    pub fn try_get_response(self) -> Option<Option<R>> {
111        if let Some(value) = self.take_from_buffer() {
112            return Some(value);
113        }
114
115        let mut guard = self.receiver.try_lock()?;
116
117        // Check the buffer once more to detect insertion right before guard but after previous check.
118        if let Some(value) = self.take_from_buffer() {
119            return Some(value);
120        }
121
122        loop {
123            match guard.try_next().ok()? {
124                Some(Rx(id, message)) => {
125                    if id == self.id {
126                        return Some(message);
127                    }
128                    self.put_in_buffer(id, message);
129                }
130                None => return Some(None),
131            }
132        }
133    }
134
135    /// Wait for response and return it.
136    ///
137    /// This function returns:
138    /// + `None` - no response (due to [`Responder`] being closed or corresponding [`Response`] being ignored).
139    /// + `Some(message)` - response arrived.
140    pub async fn get_response(self) -> Option<R> {
141        if let Some(value) = self.take_from_buffer() {
142            return value;
143        }
144
145        let mut guard = self.receiver.lock().await;
146
147        // Check the buffer once more to detect insertion right before guard but after previous check.
148        if let Some(value) = self.take_from_buffer() {
149            return value;
150        }
151
152        while let Some(Rx(id, message)) = guard.next().await {
153            if id == self.id {
154                return message;
155            }
156            self.put_in_buffer(id, message);
157        }
158
159        None
160    }
161}
162
163impl<'a, R> Drop for Request<'a, R> {
164    fn drop(&mut self) {
165        self.buffer.lock().unwrap().remove(&self.id).unwrap();
166    }
167}
168
169/// Handle for responding to request.
170///
171/// When dropped the corresponding [`Request`] will be notified about request absense.
172pub struct Response<'a, R> {
173    id: Id,
174    sender: &'a mut Sender<Rx<R>>,
175}
176
177impl<T, R> Responder<T, R> {
178    /// Wait for next request.
179    ///
180    /// This function returns:
181    /// + `Some(message, response)` - request received. `message` is data being sent, `response` is an object used to respond to request.
182    /// + `None` - [`Requester`] is closed.
183    ///
184    /// *This is inherent method rather than [`Stream`](`futures::Stream`) impl because for now there is no way to put lifetime in its [`Output`](`futures::Stream::Item`).*
185    pub async fn next(&mut self) -> Option<(T, Response<'_, R>)> {
186        let Tx(id, message) = self.receiver.next().await?;
187        Some((
188            message,
189            Response {
190                id,
191                sender: &mut self.sender,
192            },
193        ))
194    }
195
196    /// Try get next request.
197    ///
198    /// This function returns:
199    /// + `Some(Some(message, response))` - request received.
200    /// + `Some(None)` - channel is closed.
201    /// + `None` - channel is not closed but no request arrived yet.
202    pub fn try_next(&mut self) -> Option<Option<(T, Response<'_, R>)>> {
203        self.receiver.try_next().ok().map(|r| {
204            r.map(|Tx(id, message)| {
205                (
206                    message,
207                    Response {
208                        id,
209                        sender: &mut self.sender,
210                    },
211                )
212            })
213        })
214    }
215}
216
217impl<'a, R> Response<'a, R> {
218    /// Send response to request.
219    pub fn respond(self, message: R) {
220        let _ = self.sender.unbounded_send(Rx(self.id, Some(message)));
221        // Suppress calling `drop`.
222        mem::forget(self);
223    }
224}
225
226impl<'a, R> Drop for Response<'a, R> {
227    fn drop(&mut self) {
228        let _ = self.sender.unbounded_send(Rx(self.id, None));
229    }
230}