Skip to main content

tempest_rt/sync/
mpsc.rs

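//! Bounded multi-producer, single-consumer channel.
//!
//! Create a channel with [`bounded`]. The [`BoundedSender`] can be cloned for multiple producers;
//! there is exactly one [`Receiver`]. When the channel is full, [`BoundedSender::send`] waits
//! (asynchronously) until the receiver frees a slot, providing natural backpressure, while
//! [`BoundedSender::try_send`] returns immediately instead.
//!
//! Both halves share their state through `Rc<RefCell<_>>`, so they are `!Send` and must stay on
//! the thread that created them.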
6
7use std::{
8    cell::RefCell,
9    collections::VecDeque,
10    future::poll_fn,
11    pin::Pin,
12    rc::Rc,
13    task::{Context, Poll, Waker},
14};
15
16use derive_more::{Display, Error};
17use futures::Stream;
18
19struct Inner<T> {
20    queue: VecDeque<T>,
21    /// Is the consumer still alive?
22    rx_alive: bool,
23    /// Stores the waker of the waiting consumer that is waiting for produced values.
24    rx_waker: Option<Waker>,
25    /// Stores any wakers of consumers that are waiting on the channel to make space.
26    tx_wakers: Vec<Waker>,
27}
28
29/// Sending half of a bounded channel. Cloneable for multiple producers.
30pub struct BoundedSender<T> {
31    inner: Rc<RefCell<Inner<T>>>,
32}
33
34impl<T> Clone for BoundedSender<T> {
35    fn clone(&self) -> Self {
36        Self {
37            inner: self.inner.clone(),
38        }
39    }
40}
41
42/// Receiving half of a bounded channel.
43pub struct Receiver<T> {
44    inner: Rc<RefCell<Inner<T>>>,
45}
46
47/// Creates a bounded channel with the given capacity.
48///
49/// Panics if `cap` is 0.
50pub fn bounded<T>(cap: usize) -> (BoundedSender<T>, Receiver<T>) {
51    assert_ne!(cap, 0, "a bounded channel with capacity 0 does not work");
52    let mut queue = VecDeque::new();
53    queue.reserve_exact(cap);
54    let inner = Rc::new(RefCell::new(Inner {
55        queue,
56        rx_alive: true,
57        rx_waker: None,
58        tx_wakers: Vec::new(),
59    }));
60    let tx = BoundedSender {
61        inner: inner.clone(),
62    };
63    let rx = Receiver { inner };
64    (tx, rx)
65}
66
67/// Error returned by [`BoundedSender::send`] when the receiver has been dropped.
68#[derive(Debug, Display, Error)]
69#[display("receiver has been dropped")]
70pub struct SendError<T>(pub T);
71
72/// Error returned by [`BoundedSender::try_send`].
73#[derive(Debug, Display, Error)]
74pub enum TrySendError<T> {
75    /// The channel is at capacity.
76    #[display("channel is full")]
77    Full(T),
78    /// The receiver has been dropped.
79    #[display("receiver has been dropped")]
80    Closed(T),
81}
82
83impl<T> BoundedSender<T> {
84    /// Sends `val`, parking until space is available.
85    ///
86    /// Returns `Err` if the receiver has been dropped.
87    pub async fn send(&mut self, val: T) -> Result<(), SendError<T>> {
88        let mut val = Some(val);
89        poll_fn(|cx| {
90            if let Some(waker) = self.inner.borrow_mut().rx_waker.take() {
91                waker.wake();
92            }
93            match self.try_send(val.take().unwrap()) {
94                Ok(()) => Poll::Ready(Ok(())),
95                Err(TrySendError::Full(v)) => {
96                    val = Some(v);
97                    self.inner.borrow_mut().tx_wakers.push(cx.waker().clone());
98                    Poll::Pending
99                }
100                Err(TrySendError::Closed(v)) => Poll::Ready(Err(SendError(v))),
101            }
102        })
103        .await
104    }
105
106    /// Sends `val` without waiting.
107    ///
108    /// Returns `Err(Full)` if the channel is at capacity, or `Err(Closed)` if the receiver
109    /// has been dropped.
110    pub fn try_send(&mut self, val: T) -> Result<(), TrySendError<T>> {
111        let mut inner = self.inner.borrow_mut();
112        if !inner.rx_alive {
113            Err(TrySendError::Closed(val))
114        } else if inner.queue.len() == inner.queue.capacity() {
115            Err(TrySendError::Full(val))
116        } else {
117            inner.queue.push_back(val);
118            Ok(())
119        }
120    }
121}
122
123/// Error returned by [`Receiver::recv`] when all senders have been dropped.
124#[derive(Debug, Display, Error, PartialEq, Eq)]
125#[display("all senders have been dropped")]
126pub struct RecvError;
127
128/// Error returned by [`Receiver::try_recv`].
129#[derive(Debug, Display, Error, PartialEq, Eq)]
130pub enum TryRecvError {
131    /// The queue is empty but senders are still alive.
132    #[display("channel is empty")]
133    Empty,
134    /// All senders have been dropped and the queue is empty.
135    #[display("all senders have been dropped")]
136    Closed,
137}
138
139impl<T> Receiver<T> {
140    pub(crate) fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Result<T, RecvError>> {
141        match self.try_recv() {
142            Ok(val) => {
143                for waker in self.inner.borrow_mut().tx_wakers.drain(..) {
144                    waker.wake();
145                }
146                Poll::Ready(Ok(val))
147            }
148            Err(TryRecvError::Empty) => {
149                self.inner.borrow_mut().rx_waker = Some(cx.waker().clone());
150                Poll::Pending
151            }
152            Err(TryRecvError::Closed) => Poll::Ready(Err(RecvError)),
153        }
154    }
155
156    /// Receives the next value, parking until one arrives.
157    ///
158    /// Returns `Err` if all senders have been dropped and the queue is empty.
159    pub async fn recv(&mut self) -> Result<T, RecvError> {
160        poll_fn(|cx| self.poll_recv(cx)).await
161    }
162
163    /// Receives without waiting.
164    ///
165    /// Returns `Err(Empty)` if the queue is empty, or `Err(Closed)` if all senders have been
166    /// dropped and the queue is empty.
167    pub fn try_recv(&mut self) -> Result<T, TryRecvError> {
168        if let Some(val) = self.inner.borrow_mut().queue.pop_front() {
169            return Ok(val);
170        }
171        if Rc::strong_count(&self.inner) == 1 {
172            Err(TryRecvError::Closed)
173        } else {
174            Err(TryRecvError::Empty)
175        }
176    }
177}
178
179impl<T> Stream for Receiver<T> {
180    type Item = T;
181
182    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
183        self.poll_recv(cx).map(|r| r.ok())
184    }
185
186    fn size_hint(&self) -> (usize, Option<usize>) {
187        (self.inner.borrow().queue.len(), None)
188    }
189}
190
191impl<T> Drop for BoundedSender<T> {
192    fn drop(&mut self) {
193        // When the last sender is dropped (strong_count == 2: this sender + receiver),
194        // wake the receiver so it can observe the Closed state.
195        if Rc::strong_count(&self.inner) == 2 {
196            if let Some(waker) = self.inner.borrow_mut().rx_waker.take() {
197                waker.wake();
198            }
199        }
200    }
201}
202
203impl<T> Drop for Receiver<T> {
204    fn drop(&mut self) {
205        self.inner.borrow_mut().rx_alive = false;
206    }
207}
208
#[cfg(test)]
mod tests {
    // Unit tests for the bounded channel. Every test runs inside `crate::block_on` with a
    // `VirtualIo` backend (presumably a simulated I/O driver; confirm in `tempest_io`). Tests
    // that call `spawn` rely on the executor's task scheduling order.
    use std::{
        pin::{Pin, pin},
        task::Poll,
    };

    use futures::StreamExt;
    use tempest_io::VirtualIo;

    use crate::{block_on, spawn};

    use super::*;

    // -- try_send --

    #[test]
    fn test_try_send_one() {
        block_on(VirtualIo::default(), async {
            let (mut tx, _rx) = bounded(1);
            assert!(tx.try_send(42).is_ok());
        });
    }

    #[test]
    fn test_try_send_exactly_full() {
        block_on(VirtualIo::default(), async {
            let (mut tx, _rx) = bounded(2);
            assert!(tx.try_send(1).is_ok());
            assert!(tx.try_send(2).is_ok());
        });
    }

    #[test]
    fn test_try_send_over_full() {
        block_on(VirtualIo::default(), async {
            let (mut tx, _rx) = bounded(1);
            tx.try_send(1).unwrap();
            // The rejected value must be handed back to the caller.
            match tx.try_send(99) {
                Err(TrySendError::Full(v)) => assert_eq!(v, 99),
                _ => panic!("expected Full"),
            }
        });
    }

    #[test]
    fn test_try_send_closed() {
        block_on(VirtualIo::default(), async {
            let (mut tx, rx) = bounded::<i32>(1);
            drop(rx);
            match tx.try_send(99) {
                Err(TrySendError::Closed(v)) => assert_eq!(v, 99),
                _ => panic!("expected Closed"),
            }
        });
    }

    // -- send --

    #[test]
    fn test_send_one() {
        block_on(VirtualIo::default(), async {
            let (mut tx, _rx) = bounded(1);
            tx.send(42).await.unwrap();
        });
    }

    #[test]
    fn test_send_exactly_full() {
        block_on(VirtualIo::default(), async {
            let (mut tx, _rx) = bounded(2);
            tx.send(1).await.unwrap();
            tx.send(2).await.unwrap();
        });
    }

    #[test]
    fn test_send_pending_when_full() {
        block_on(VirtualIo::default(), async {
            let (mut tx, _rx) = bounded(1);
            tx.try_send(1).unwrap();

            // Poll the future by hand with a no-op waker: on a full channel `send` must report
            // `Pending` rather than resolving.
            let waker = std::task::Waker::noop();
            let mut cx = std::task::Context::from_waker(&waker);
            let mut fut = pin!(tx.send(2));
            assert!(matches!(fut.as_mut().poll(&mut cx), Poll::Pending));
        });
    }

    #[test]
    fn test_send_when_full_eventually_resolves() {
        block_on(VirtualIo::default(), async {
            let (mut tx, mut rx) = bounded(1);
            tx.try_send(1).unwrap();
            // this will block, until the rx frees up space
            let mut handle = spawn(async move { tx.send(2).await.unwrap() });

            // the task handle should be pending, until we make space in the rx
            let waker = std::task::Waker::noop();
            let mut cx = std::task::Context::from_waker(&waker);
            assert!(matches!(Pin::new(&mut handle).poll(&mut cx), Poll::Pending));

            assert_eq!(rx.recv().await, Ok(1));
            assert_eq!(rx.recv().await, Ok(2));
            // Channel now empty (the receiver itself is still alive here). `try_recv` must fail
            // with either `Empty` or `Closed`, depending on whether the spawned task has dropped
            // its sender yet.
            assert!(rx.try_recv().is_err())
        });
    }

    #[test]
    fn test_recv_when_empty_eventually_resolves() {
        block_on(VirtualIo::default(), async {
            let (mut tx, mut rx) = bounded(1);
            spawn(async move {
                assert_eq!(rx.recv().await, Ok(1));
                assert_eq!(rx.recv().await, Ok(2));
            });

            tx.send(1).await.unwrap();
            // Relies on the spawned receiver not having run yet, so the single slot is still
            // occupied by `1`.
            assert!(matches!(tx.try_send(2), Err(TrySendError::Full(2))));
            tx.send(2).await.unwrap();
        });
    }

    #[test]
    fn test_send_closed() {
        block_on(VirtualIo::default(), async {
            let (mut tx, rx) = bounded::<i32>(1);
            drop(rx);
            // The unsent value must come back inside the error.
            match tx.send(99).await {
                Err(SendError(v)) => assert_eq!(v, 99),
                Ok(()) => panic!("expected Err"),
            }
        });
    }

    // -- try_recv --

    #[test]
    fn test_try_recv_one() {
        block_on(VirtualIo::default(), async {
            let (mut tx, mut rx) = bounded(1);
            tx.try_send(42).unwrap();
            assert_eq!(rx.try_recv().unwrap(), 42);
        });
    }

    #[test]
    fn test_try_recv_in_order() {
        block_on(VirtualIo::default(), async {
            let (mut tx, mut rx) = bounded(3);
            tx.try_send(1).unwrap();
            tx.try_send(2).unwrap();
            tx.try_send(3).unwrap();
            assert_eq!(rx.try_recv().unwrap(), 1);
            assert_eq!(rx.try_recv().unwrap(), 2);
            assert_eq!(rx.try_recv().unwrap(), 3);
        });
    }

    #[test]
    fn test_try_recv_empty() {
        block_on(VirtualIo::default(), async {
            // `_tx` keeps a sender alive, so an empty queue is `Empty`, not `Closed`.
            let (_tx, mut rx) = bounded::<i32>(1);
            assert_eq!(rx.try_recv(), Err(TryRecvError::Empty));
        });
    }

    #[test]
    fn test_try_recv_closed() {
        block_on(VirtualIo::default(), async {
            let (tx, mut rx) = bounded::<i32>(1);
            drop(tx);
            assert_eq!(rx.try_recv(), Err(TryRecvError::Closed));
        });
    }

    // -- recv --

    #[test]
    fn test_recv_one() {
        block_on(VirtualIo::default(), async {
            let (mut tx, mut rx) = bounded(1);
            tx.send(42).await.unwrap();
            assert_eq!(rx.recv().await.unwrap(), 42);
        });
    }

    #[test]
    fn test_recv_in_order() {
        block_on(VirtualIo::default(), async {
            let (mut tx, mut rx) = bounded(3);
            tx.send(1).await.unwrap();
            tx.send(2).await.unwrap();
            tx.send(3).await.unwrap();
            assert_eq!(rx.recv().await.unwrap(), 1);
            assert_eq!(rx.recv().await.unwrap(), 2);
            assert_eq!(rx.recv().await.unwrap(), 3);
        });
    }

    #[test]
    fn test_recv_pending_when_empty() {
        block_on(VirtualIo::default(), async {
            let (_tx, mut rx) = bounded::<i32>(1);

            // Poll by hand with a no-op waker: an empty channel with a live sender must report
            // `Pending`.
            let waker = std::task::Waker::noop();
            let mut cx = std::task::Context::from_waker(&waker);
            let mut fut = pin!(rx.recv());
            assert!(matches!(fut.as_mut().poll(&mut cx), Poll::Pending));
        });
    }

    #[test]
    fn test_recv_closed() {
        block_on(VirtualIo::default(), async {
            let (tx, mut rx) = bounded::<i32>(1);
            drop(tx);
            assert_eq!(rx.recv().await, Err(RecvError));
        });
    }

    #[test]
    fn test_recv_woken_when_last_sender_dropped() {
        // Receiver is blocked waiting on an empty channel. Dropping the last sender
        // must wake the receiver so it can observe Closed without deadlocking.
        block_on(VirtualIo::default(), async {
            let (tx, mut rx) = bounded::<i32>(1);
            spawn(async move {
                drop(tx);
            });
            assert_eq!(rx.recv().await, Err(RecvError));
        });
    }

    #[test]
    fn test_recv_woken_when_last_of_multiple_senders_dropped() {
        // Same scenario as the previous test, but with two senders: only dropping the second
        // (final) one closes the channel and must wake the receiver.
        block_on(VirtualIo::default(), async {
            let (tx, mut rx) = bounded::<i32>(1);
            let tx2 = tx.clone();
            spawn(async move {
                drop(tx);
                drop(tx2);
            });
            assert_eq!(rx.recv().await, Err(RecvError));
        });
    }

    // -- stream implementation of recv --

    #[test]
    fn test_stream_recv() {
        const ITEMS: &[i32; 3] = &[1, 2, 3];
        block_on(VirtualIo::default(), async {
            let (mut tx, rx) = bounded::<i32>(1);
            spawn(async move {
                for &item in ITEMS {
                    tx.send(item).await.unwrap();
                }
            });

            // With capacity 1 the producer can get at most one value ahead. `collect` finishes
            // once the spawned task's sender is dropped and the queue is drained.
            let result: Vec<_> = rx.collect().await;
            assert_eq!(result, ITEMS);
        })
    }
}