ntex_util/channel/
mpsc.rs

1//! A multi-producer, single-consumer, futures-aware, FIFO queue.
2use std::collections::VecDeque;
3use std::future::poll_fn;
4use std::{fmt, panic::UnwindSafe, pin::Pin, task::Context, task::Poll};
5
6use futures_core::{FusedStream, Stream};
7
8use super::cell::Cell;
9use crate::task::LocalWaker;
10
11/// Creates a unbounded in-memory channel with buffered storage.
12pub fn channel<T>() -> (Sender<T>, Receiver<T>) {
13    let shared = Cell::new(Shared {
14        has_receiver: true,
15        buffer: VecDeque::new(),
16        blocked_recv: LocalWaker::new(),
17    });
18    let sender = Sender {
19        shared: shared.clone(),
20    };
21    let receiver = Receiver { shared };
22    (sender, receiver)
23}
24
/// State shared by every [`Sender`] and the [`Receiver`] of a single channel.
#[derive(Debug)]
struct Shared<T> {
    /// Messages that were sent but not yet received; senders push to the back,
    /// the receiver pops from the front (FIFO).
    buffer: VecDeque<T>,
    /// Waker registered by the receiver while it waits for a message.
    blocked_recv: LocalWaker,
    /// Becomes `false` once either half calls `close` or the receiver is
    /// dropped; `send` fails from then on.
    has_receiver: bool,
}
31
/// The transmission end of a channel.
///
/// This is created by the `channel` function. Additional senders can be made by
/// cloning or with `Receiver::sender`; all of them feed the same buffer.
#[derive(Debug)]
pub struct Sender<T> {
    shared: Cell<Shared<T>>,
}

// Declared `Unpin` for every `T`: nothing in this module depends on a pinned
// address of a `Sender`.
impl<T> Unpin for Sender<T> {}
41
42impl<T> Sender<T> {
43    /// Sends the provided message along this channel.
44    pub fn send(&self, item: T) -> Result<(), SendError<T>> {
45        let shared = self.shared.get_mut();
46        if !shared.has_receiver {
47            return Err(SendError(item)); // receiver was dropped
48        };
49        shared.buffer.push_back(item);
50        shared.blocked_recv.wake();
51        Ok(())
52    }
53
54    /// Closes the sender half
55    ///
56    /// This prevents any further messages from being sent on the channel while
57    /// still enabling the receiver to drain messages that are buffered.
58    pub fn close(&self) {
59        let shared = self.shared.get_mut();
60        shared.has_receiver = false;
61        shared.blocked_recv.wake();
62    }
63
64    /// Returns whether this channel is closed without needing a context.
65    pub fn is_closed(&self) -> bool {
66        self.shared.strong_count() == 1 || !self.shared.get_ref().has_receiver
67    }
68}
69
70impl<T> Clone for Sender<T> {
71    fn clone(&self) -> Self {
72        Sender {
73            shared: self.shared.clone(),
74        }
75    }
76}
77
78impl<T> Drop for Sender<T> {
79    fn drop(&mut self) {
80        let count = self.shared.strong_count();
81        let shared = self.shared.get_mut();
82
83        // check is last sender is about to drop
84        if shared.has_receiver && count == 2 {
85            // Wake up receiver as its stream has ended
86            shared.blocked_recv.wake();
87        }
88    }
89}
90
/// The receiving end of a channel which implements the `Stream` trait.
///
/// This is created by the `channel` function. There is exactly one receiver per
/// channel; dropping it closes the channel and discards any buffered messages.
#[derive(Debug)]
pub struct Receiver<T> {
    shared: Cell<Shared<T>>,
}
98
99impl<T> Receiver<T> {
100    /// Create a Sender
101    pub fn sender(&self) -> Sender<T> {
102        Sender {
103            shared: self.shared.clone(),
104        }
105    }
106
107    /// Closes the receiving half of a channel, without dropping it.
108    ///
109    /// This prevents any further messages from being sent on the channel
110    /// while still enabling the receiver to drain messages that are buffered.
111    pub fn close(&self) {
112        self.shared.get_mut().has_receiver = false;
113    }
114
115    /// Returns whether this channel is closed without needing a context.
116    pub fn is_closed(&self) -> bool {
117        self.shared.strong_count() == 1 || !self.shared.get_ref().has_receiver
118    }
119
120    /// Attempt to pull out the next value of this receiver, registering
121    /// the current task for wakeup if the value is not yet available,
122    /// and returning None if the stream is exhausted.
123    pub async fn recv(&self) -> Option<T> {
124        poll_fn(|cx| self.poll_recv(cx)).await
125    }
126
127    /// Attempt to pull out the next value of this receiver, registering
128    /// the current task for wakeup if the value is not yet available,
129    /// and returning None if the stream is exhausted.
130    pub fn poll_recv(&self, cx: &mut Context<'_>) -> Poll<Option<T>> {
131        let shared = self.shared.get_mut();
132
133        if let Some(msg) = shared.buffer.pop_front() {
134            Poll::Ready(Some(msg))
135        } else if shared.has_receiver {
136            shared.blocked_recv.register(cx.waker());
137            if self.shared.strong_count() == 1 {
138                // All senders have been dropped, so drain the buffer and end the
139                // stream.
140                Poll::Ready(None)
141            } else {
142                Poll::Pending
143            }
144        } else {
145            Poll::Ready(None)
146        }
147    }
148}
149
150impl<T> Unpin for Receiver<T> {}
151
152impl<T> Stream for Receiver<T> {
153    type Item = T;
154
155    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
156        self.poll_recv(cx)
157    }
158}
159
160impl<T> FusedStream for Receiver<T> {
161    fn is_terminated(&self) -> bool {
162        self.is_closed()
163    }
164}
165
// Manually asserts that a `Receiver` may cross `catch_unwind` boundaries.
// NOTE(review): the shared state is mutated through `&self` (`get_mut`), so this
// relies on no panic leaving `Shared` half-updated — confirm this holds.
impl<T> UnwindSafe for Receiver<T> {}
167
168impl<T> Drop for Receiver<T> {
169    fn drop(&mut self) {
170        let shared = self.shared.get_mut();
171        shared.buffer.clear();
172        shared.has_receiver = false;
173    }
174}
175
/// Error returned by [`Sender::send`] when a message cannot be delivered
/// because the channel was closed or its receiver was dropped.
///
/// The rejected message is retained and can be recovered with
/// [`SendError::into_inner`].
pub struct SendError<T>(T);

impl<T> std::error::Error for SendError<T> {}

impl<T> fmt::Debug for SendError<T> {
    // The payload is masked so that `T` needs no `Debug` bound.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut tuple = f.debug_tuple("SendError");
        tuple.field(&"...");
        tuple.finish()
    }
}

impl<T> fmt::Display for SendError<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("send failed because receiver is gone")
    }
}

impl<T> SendError<T> {
    /// Consumes the error, yielding the message that could not be sent.
    pub fn into_inner(self) -> T {
        let SendError(message) = self;
        message
    }
}
200
#[cfg(test)]
mod tests {
    use super::*;
    use crate::future::{lazy, stream_recv};

    #[ntex_macros::rt_test2]
    async fn test_mpsc() {
        // Messages flow from any sender clone to the receiver; the stream stays
        // pending while at least one sender is alive and the buffer is empty.
        {
            let (tx, mut rx) = channel();
            assert!(format!("{tx:?}").contains("Sender"));
            assert!(format!("{rx:?}").contains("Receiver"));

            tx.send("test").unwrap();
            assert_eq!(stream_recv(&mut rx).await.unwrap(), "test");

            let tx2 = tx.clone();
            tx2.send("test2").unwrap();
            assert_eq!(stream_recv(&mut rx).await.unwrap(), "test2");

            assert_eq!(
                lazy(|cx| Pin::new(&mut rx).poll_next(cx)).await,
                Poll::Pending
            );
            drop(tx2);
            assert_eq!(
                lazy(|cx| Pin::new(&mut rx).poll_next(cx)).await,
                Poll::Pending
            );
            drop(tx);
        }

        // Closing from the sender side ends the stream.
        {
            let (tx, mut rx) = channel::<String>();
            tx.close();
            assert_eq!(stream_recv(&mut rx).await, None);
        }

        // Once the receiver is dropped, sends are rejected.
        {
            let (tx, rx) = channel();
            tx.send("test").unwrap();
            drop(rx);
            assert!(tx.send("test").is_err());
        }

        // `close` on one sender rejects sends on every clone. The receiver is
        // kept alive (`_rx`, not `_`) so that `close` is what is under test;
        // a bare `_` would drop the receiver immediately and make this vacuous.
        {
            let (tx, _rx) = channel();
            let tx2 = tx.clone();
            tx.close();
            assert!(tx.send("test").is_err());
            assert!(tx2.send("test").is_err());
        }
    }

    #[test]
    fn test_send_error() {
        let err = SendError("test");
        assert!(format!("{err:?}").contains("SendError"));
        assert!(format!("{err}").contains("send failed because receiver is gone"));
        assert_eq!(err.into_inner(), "test");
    }

    #[ntex_macros::rt_test2]
    async fn test_close() {
        let (tx, rx) = channel::<()>();
        assert!(!tx.is_closed());
        assert!(!rx.is_closed());
        assert!(!rx.is_terminated());

        tx.close();
        assert!(tx.is_closed());
        assert!(rx.is_closed());
        assert!(rx.is_terminated());

        let (tx, rx) = channel::<()>();
        rx.close();
        assert!(tx.is_closed());
        assert!(rx.is_closed());

        let (tx, rx) = channel::<()>();
        drop(tx);
        assert!(rx.is_closed());
        assert!(rx.is_terminated());
        let _tx = rx.sender();
        assert!(!rx.is_closed());
        assert!(!rx.is_terminated());
    }
}