pinned/
mpsc.rs

//! A multi-producer, single-receiver channel.
//!
//! This is an asynchronous, `!Send` version of `std::sync::mpsc`. Currently only the unbounded
//! variant is implemented.
//!
//! [`UnboundedReceiver`] implements [`Stream`] and allows asynchronous tasks to read values out of
//! the channel. The `UnboundedReceiver` stream suspends and waits for available values while the
//! queue is empty. A shared reference to an [`UnboundedSender`] implements [`Sink`] and allows
//! messages to be sent to the corresponding `UnboundedReceiver`. The `UnboundedSender` also
//! provides a [`send_now`](UnboundedSender::send_now) method to send a value synchronously.
11
12use std::collections::VecDeque;
13use std::marker::PhantomData;
14use std::pin::Pin;
15use std::rc::Rc;
16use std::task::{Context, Poll, Waker};
17
18use futures::sink::Sink;
19use futures::stream::{FusedStream, Stream};
20use thiserror::Error;
21
22use crate::cell::UnsafeCell;
23
24/// Error returned by [`try_next`](UnboundedReceiver::try_next).
25#[derive(Error, Debug)]
26#[error("queue is empty")]
27pub struct TryRecvError {
28    _marker: PhantomData<()>,
29}
30
31/// Error returned by [`send_now`](UnboundedSender::send_now).
32#[derive(Error, Debug)]
33#[error("failed to send")]
34pub struct SendError<T> {
35    /// The send value.
36    pub inner: T,
37}
38
39/// Error returned by [`UnboundedSender`] when used as a [`Sink`](futures::sink::Sink).
40#[derive(Error, Debug)]
41#[error("failed to send")]
42pub struct TrySendError {
43    _marker: PhantomData<()>,
44}
45
/// State shared between every sender and the receiver of one channel.
#[derive(Debug)]
struct Inner<T> {
    /// Waker registered by the receiver the last time it polled an empty, still-open queue.
    rx_waker: Option<Waker>,
    /// Set once the channel is closed; sends are rejected from then on.
    closed: bool,
    /// Number of live `UnboundedSender` handles.
    sender_ctr: usize,
    /// Buffered values, oldest first.
    items: VecDeque<T>,

    // `Rc<()>` is neither `Send` nor `Sync`, so this marker keeps `Inner` on one thread.
    _marker: PhantomData<Rc<()>>,
}
56
57impl<T> Inner<T> {
58    fn close_impl(&mut self) {
59        self.closed = true;
60
61        if let Some(ref m) = self.rx_waker {
62            m.wake_by_ref();
63        }
64    }
65
66    #[inline]
67    fn try_next_impl(&mut self) -> Result<Option<T>, TryRecvError> {
68        match (self.items.pop_front(), self.closed) {
69            (Some(m), _) => Ok(Some(m)),
70            (None, true) => Ok(None),
71            (None, false) => Err(TryRecvError {
72                _marker: PhantomData,
73            }),
74        }
75    }
76
77    #[inline]
78    fn poll_next_impl(&mut self, cx: &mut Context<'_>) -> Poll<Option<T>> {
79        match (self.items.pop_front(), self.closed) {
80            (Some(m), _) => Poll::Ready(Some(m)),
81            (None, false) => {
82                self.rx_waker = Some(cx.waker().clone());
83                Poll::Pending
84            }
85            (None, true) => Poll::Ready(None),
86        }
87    }
88
89    #[inline]
90    fn is_terminated_impl(&self) -> bool {
91        self.items.is_empty() && self.closed
92    }
93
94    #[inline]
95    fn send_impl(&mut self, item: T) -> Result<(), SendError<T>> {
96        if self.closed {
97            return Err(SendError { inner: item });
98        }
99
100        self.items.push_back(item);
101
102        if let Some(ref m) = self.rx_waker {
103            m.wake_by_ref();
104        }
105
106        Ok(())
107    }
108
109    #[inline]
110    fn pre_clone_sender_impl(&mut self) {
111        self.sender_ctr += 1;
112    }
113
114    #[inline]
115    fn drop_sender_impl(&mut self) {
116        let sender_ctr = {
117            self.sender_ctr -= 1;
118            self.sender_ctr
119        };
120
121        if sender_ctr == 0 {
122            self.close_impl();
123        }
124    }
125}
126
127/// The receiver of an unbounded mpsc channel.
128///
129/// This is created by the [`unbounded`] function.
130#[derive(Debug)]
131pub struct UnboundedReceiver<T> {
132    inner: Rc<UnsafeCell<Inner<T>>>,
133}
134
135impl<T> UnboundedReceiver<T> {
136    /// Try to read the next value from the channel.
137    ///
138    /// This function will return:
139    /// - `Ok(Some(T))` if a value is ready.
140    /// - `Ok(None)` if the channel has become closed.
141    /// - `Err(TryRecvError)` if the channel is not closed and the channel is empty.
142    pub fn try_next(&self) -> Result<Option<T>, TryRecvError> {
143        // SAFETY:
144        //
145        // We can acquire a mutable reference without checking as:
146        //
147        // - This type is !Sync and !Send.
148        // - This function is not used by any other functions and hence uniquely owns the
149        // mutable reference.
150        // - The mutable reference is dropped at the end of this function.
151        unsafe { self.inner.with_mut(|inner| inner.try_next_impl()) }
152    }
153
154    /// Closes the receiver of the channel without dropping it.
155    ///
156    /// This prevents any further messages from being sent on the channel while still enabling the
157    /// receiver to drain messages in the buffer.
158    pub fn close(&self) {
159        // SAFETY:
160        //
161        // We can acquire a mutable reference without checking as:
162        //
163        // - This type is !Sync and !Send.
164        // - This function is not used by any other functions and hence uniquely owns the
165        // mutable reference.
166        // - The mutable reference is dropped at the end of this function.
167
168        unsafe { self.inner.with_mut(|inner| inner.close_impl()) }
169    }
170}
171
172impl<T> Stream for UnboundedReceiver<T> {
173    type Item = T;
174
175    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
176        // SAFETY:
177        //
178        // We can acquire a mutable reference without checking as:
179        //
180        // - This type is !Sync and !Send.
181        // - This function is not used by any other functions and hence uniquely owns the
182        // mutable reference.
183        // - The mutable reference is dropped at the end of this function.
184        unsafe { self.inner.with_mut(|inner| inner.poll_next_impl(cx)) }
185    }
186}
187
188impl<T> FusedStream for UnboundedReceiver<T> {
189    fn is_terminated(&self) -> bool {
190        // SAFETY:
191        //
192        // We can acquire a mutable reference without checking as:
193        //
194        // - This type is !Sync and !Send.
195        // - This function is not used by any other functions and hence uniquely owns the
196        // mutable reference.
197        // - The mutable reference is dropped at the end of this function.
198        unsafe { self.inner.with(|inner| inner.is_terminated_impl()) }
199    }
200}
201
202impl<T> Drop for UnboundedReceiver<T> {
203    fn drop(&mut self) {
204        // SAFETY:
205        //
206        // We can acquire a mutable reference without checking as:
207        //
208        // - This type is !Sync and !Send.
209        // - This function is not used by any other functions and hence uniquely owns the
210        // mutable reference.
211        // - The mutable reference is dropped at the end of this function.
212        unsafe { self.inner.with_mut(|inner| inner.close_impl()) }
213    }
214}
215
216/// The sender of an unbounded mpsc channel.
217///
218/// This value is created by the [`unbounded`] function.
219#[derive(Debug)]
220pub struct UnboundedSender<T> {
221    inner: Rc<UnsafeCell<Inner<T>>>,
222}
223
224impl<T> UnboundedSender<T> {
225    /// Sends a value to the unbounded receiver.
226    ///
227    /// This is an unbounded sender, so this function differs from
228    /// [`SinkExt::send`](futures::sink::SinkExt::send) by ensuring the return type reflects
229    /// that the channel is always ready to receive messages.
230    pub fn send_now(&self, item: T) -> Result<(), SendError<T>> {
231        // SAFETY:
232        //
233        // We can acquire a mutable reference without checking as:
234        //
235        // - This type is !Sync and !Send.
236        // - This function is not used by any function that have already acquired a mutable
237        // reference.
238        // - The mutable reference is dropped at the end of this function.
239
240        unsafe { self.inner.with_mut(move |inner| inner.send_impl(item)) }
241    }
242
243    /// Closes the channel.
244    ///
245    /// Every sender (dropped or not) is considered closed when this method is called.
246    pub fn close_now(&self) {
247        // SAFETY:
248        //
249        // We can acquire a mutable reference without checking as:
250        //
251        // - This type is !Sync and !Send.
252        // - This function is not used by any function that have already acquired a mutable
253        // reference.
254        // - The mutable reference is dropped at the end of this function.
255        unsafe { self.inner.with_mut(|inner| inner.close_impl()) }
256    }
257}
258
259impl<T> Clone for UnboundedSender<T> {
260    fn clone(&self) -> Self {
261        // SAFETY:
262        //
263        // We can acquire a mutable reference without checking as:
264        //
265        // - This type is !Sync and !Send.
266        // - This function is not used by any other functions and hence uniquely owns the
267        // mutable reference.
268        // - The mutable reference is dropped at the end of this function.
269        unsafe { self.inner.with_mut(|inner| inner.pre_clone_sender_impl()) }
270
271        Self {
272            inner: self.inner.clone(),
273        }
274    }
275}
276
277impl<T> Drop for UnboundedSender<T> {
278    fn drop(&mut self) {
279        // SAFETY:
280        //
281        // We can acquire a mutable reference without checking as:
282        //
283        // - This type is !Sync and !Send.
284        // - This function is not used by any other functions and hence uniquely owns the
285        // mutable reference.
286        // - The mutable reference is dropped at the end of this function.
287        unsafe { self.inner.with_mut(|inner| inner.drop_sender_impl()) }
288    }
289}
290
291impl<T> Sink<T> for &'_ UnboundedSender<T> {
292    type Error = TrySendError;
293
294    fn start_send(self: std::pin::Pin<&mut Self>, item: T) -> Result<(), Self::Error> {
295        self.send_now(item).map_err(|_| TrySendError {
296            _marker: PhantomData,
297        })
298    }
299
300    fn poll_ready(
301        self: std::pin::Pin<&mut Self>,
302        _cx: &mut std::task::Context<'_>,
303    ) -> Poll<Result<(), Self::Error>> {
304        let closed = unsafe { self.inner.with(|inner| inner.closed) };
305
306        match closed {
307            false => Poll::Ready(Ok(())),
308            true => Poll::Ready(Err(TrySendError {
309                _marker: PhantomData,
310            })),
311        }
312    }
313
314    fn poll_flush(
315        self: std::pin::Pin<&mut Self>,
316        _cx: &mut std::task::Context<'_>,
317    ) -> Poll<Result<(), Self::Error>> {
318        Poll::Ready(Ok(()))
319    }
320
321    fn poll_close(
322        self: std::pin::Pin<&mut Self>,
323        _cx: &mut std::task::Context<'_>,
324    ) -> Poll<Result<(), Self::Error>> {
325        self.close_now();
326
327        Poll::Ready(Ok(()))
328    }
329}
330
331/// Creates an unbounded channel.
332///
333/// The `send` method on Senders created by this function will always succeed and return immediately
334/// as long as the channel is open.
335///
336/// # Note
337///
338/// This channel has an infinite buffer and can run out of memory if the channel is not actively
339/// drained.
340pub fn unbounded<T>() -> (UnboundedSender<T>, UnboundedReceiver<T>) {
341    let inner = Rc::new(UnsafeCell::new(Inner {
342        rx_waker: None,
343        closed: false,
344
345        sender_ctr: 1,
346        items: VecDeque::new(),
347        _marker: PhantomData,
348    }));
349
350    (
351        UnboundedSender {
352            inner: inner.clone(),
353        },
354        UnboundedReceiver { inner },
355    )
356}
357
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use futures::sink::SinkExt;
    use futures::stream::StreamExt;
    use tokio::task::{spawn_local, LocalSet};
    use tokio::test;
    use tokio::time::sleep;

    use super::*;

    #[test]
    async fn mpsc_works() {
        let local_set = LocalSet::new();

        local_set
            .run_until(async {
                let (tx, mut rx) = unbounded::<usize>();

                spawn_local(async move {
                    for i in 0..10 {
                        (&tx).send(i).await.expect("failed to send.");
                        sleep(Duration::from_millis(1)).await;
                    }
                });

                for i in 0..10 {
                    let received = rx.next().await.expect("failed to receive");

                    assert_eq!(i, received);
                }

                assert_eq!(rx.next().await, None);
            })
            .await;
    }

    #[test]
    async fn mpsc_drops_receiver() {
        let (tx, rx) = unbounded::<usize>();
        drop(rx);

        (&tx).send(0).await.expect_err("should fail to send.");
    }

    #[test]
    async fn mpsc_multi_sender() {
        let local_set = LocalSet::new();

        local_set
            .run_until(async {
                let (tx, mut rx) = unbounded::<usize>();

                spawn_local(async move {
                    let tx2 = tx.clone();

                    for i in 0..10 {
                        if i % 2 == 0 {
                            (&tx).send(i).await.expect("failed to send.");
                        } else {
                            (&tx2).send(i).await.expect("failed to send.");
                        }

                        sleep(Duration::from_millis(1)).await;
                    }

                    drop(tx2);

                    for i in 10..20 {
                        (&tx).send(i).await.expect("failed to send.");

                        sleep(Duration::from_millis(1)).await;
                    }
                });

                for i in 0..20 {
                    let received = rx.next().await.expect("failed to receive");

                    assert_eq!(i, received);
                }

                assert_eq!(rx.next().await, None);
            })
            .await;
    }

    #[test]
    async fn mpsc_drops_sender() {
        let (tx, mut rx) = unbounded::<usize>();
        drop(tx);

        assert_eq!(rx.next().await, None);
    }

    #[test]
    async fn mpsc_try_next() {
        let (tx, rx) = unbounded::<usize>();

        // Open and empty: nothing to read yet.
        assert!(rx.try_next().is_err());

        tx.send_now(1).expect("failed to send.");
        assert_eq!(rx.try_next().expect("should have a value"), Some(1));

        // Dropping the only sender closes the channel.
        drop(tx);
        assert_eq!(rx.try_next().expect("channel should be closed"), None);
    }

    #[test]
    async fn mpsc_receiver_close_keeps_buffer() {
        let (tx, mut rx) = unbounded::<usize>();

        tx.send_now(1).expect("failed to send.");
        rx.close();

        // Closing rejects new values but keeps already-buffered ones readable.
        assert!(tx.send_now(2).is_err());
        assert_eq!(rx.next().await, Some(1));
        assert_eq!(rx.next().await, None);
        assert!(rx.is_terminated());
    }
}