ntex_util/channel/
mpsc.rs1use std::collections::VecDeque;
3use std::future::poll_fn;
4use std::{fmt, panic::UnwindSafe, pin::Pin, task::Context, task::Poll};
5
6use futures_core::{FusedStream, Stream};
7
8use super::cell::Cell;
9use crate::task::LocalWaker;
10
11pub fn channel<T>() -> (Sender<T>, Receiver<T>) {
13 let shared = Cell::new(Shared {
14 has_receiver: true,
15 buffer: VecDeque::new(),
16 blocked_recv: LocalWaker::new(),
17 });
18 let sender = Sender {
19 shared: shared.clone(),
20 };
21 let receiver = Receiver { shared };
22 (sender, receiver)
23}
24
/// State shared by every [`Sender`] and the [`Receiver`] of one channel.
#[derive(Debug)]
struct Shared<T> {
    /// Values sent but not yet received; pushed at the back, popped from the front.
    buffer: VecDeque<T>,
    /// Waker registered by the receiver when it finds the buffer empty; woken on
    /// `send`, on `Sender::close`, and when the last sender is dropped.
    blocked_recv: LocalWaker,
    /// Cleared when the receiver is dropped or when either side calls `close`;
    /// once `false`, `send` returns `SendError`.
    has_receiver: bool,
}
31
/// Sending half of a channel created by [`channel`].
///
/// Senders can be cloned; every clone pushes into the same buffer.
#[derive(Debug)]
pub struct Sender<T> {
    shared: Cell<Shared<T>>,
}

// A `Sender` is only a handle to the shared state and is never pin-projected,
// so it is `Unpin` regardless of `T`.
impl<T> Unpin for Sender<T> {}
41
42impl<T> Sender<T> {
43 pub fn send(&self, item: T) -> Result<(), SendError<T>> {
45 let shared = self.shared.get_mut();
46 if !shared.has_receiver {
47 return Err(SendError(item)); };
49 shared.buffer.push_back(item);
50 shared.blocked_recv.wake();
51 Ok(())
52 }
53
54 pub fn close(&self) {
59 let shared = self.shared.get_mut();
60 shared.has_receiver = false;
61 shared.blocked_recv.wake();
62 }
63
64 pub fn is_closed(&self) -> bool {
66 self.shared.strong_count() == 1 || !self.shared.get_ref().has_receiver
67 }
68}
69
70impl<T> Clone for Sender<T> {
71 fn clone(&self) -> Self {
72 Sender {
73 shared: self.shared.clone(),
74 }
75 }
76}
77
78impl<T> Drop for Sender<T> {
79 fn drop(&mut self) {
80 let count = self.shared.strong_count();
81 let shared = self.shared.get_mut();
82
83 if shared.has_receiver && count == 2 {
85 shared.blocked_recv.wake();
87 }
88 }
89}
90
/// Receiving half of a channel created by [`channel`].
///
/// Values can be received with `recv`/`poll_recv` or through the [`Stream`]
/// implementation. New senders can be created with [`Receiver::sender`].
#[derive(Debug)]
pub struct Receiver<T> {
    shared: Cell<Shared<T>>,
}
98
99impl<T> Receiver<T> {
100 pub fn sender(&self) -> Sender<T> {
102 Sender {
103 shared: self.shared.clone(),
104 }
105 }
106
107 pub fn close(&self) {
112 self.shared.get_mut().has_receiver = false;
113 }
114
115 pub fn is_closed(&self) -> bool {
117 self.shared.strong_count() == 1 || !self.shared.get_ref().has_receiver
118 }
119
120 pub async fn recv(&self) -> Option<T> {
124 poll_fn(|cx| self.poll_recv(cx)).await
125 }
126
127 pub fn poll_recv(&self, cx: &mut Context<'_>) -> Poll<Option<T>> {
131 let shared = self.shared.get_mut();
132
133 if let Some(msg) = shared.buffer.pop_front() {
134 Poll::Ready(Some(msg))
135 } else if shared.has_receiver {
136 shared.blocked_recv.register(cx.waker());
137 if self.shared.strong_count() == 1 {
138 Poll::Ready(None)
141 } else {
142 Poll::Pending
143 }
144 } else {
145 Poll::Ready(None)
146 }
147 }
148}
149
// A `Receiver` is only a handle to the shared state and is never pin-projected
// (`poll_next` just derefs the pin to call `poll_recv(&self)`), so it is
// `Unpin` regardless of `T`.
impl<T> Unpin for Receiver<T> {}
151
152impl<T> Stream for Receiver<T> {
153 type Item = T;
154
155 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
156 self.poll_recv(cx)
157 }
158}
159
160impl<T> FusedStream for Receiver<T> {
161 fn is_terminated(&self) -> bool {
162 self.is_closed()
163 }
164}
165
// NOTE(review): asserted for every `T` with no bounds. This presumably relies
// on the shared channel state staying consistent if a panic unwinds through a
// receiver operation; confirm this is intended.
impl<T> UnwindSafe for Receiver<T> {}
167
168impl<T> Drop for Receiver<T> {
169 fn drop(&mut self) {
170 let shared = self.shared.get_mut();
171 shared.buffer.clear();
172 shared.has_receiver = false;
173 }
174}
175
/// Error returned by `Sender::send` when the receiving side is closed or gone.
///
/// The rejected value is kept and can be recovered with [`SendError::into_inner`].
pub struct SendError<T>(T);

impl<T> SendError<T> {
    /// Consumes the error and returns the value that could not be sent.
    pub fn into_inner(self) -> T {
        let SendError(value) = self;
        value
    }
}

impl<T> fmt::Debug for SendError<T> {
    // The payload is printed as "..." so `T` does not need to implement `Debug`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_tuple("SendError").field(&"...").finish()
    }
}

impl<T> fmt::Display for SendError<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("send failed because receiver is gone")
    }
}

impl<T> std::error::Error for SendError<T> {}
200
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{future::lazy, future::stream_recv};

    // Exercises sending, receiving, cloning, closing and the `SendError` API.
    #[ntex_macros::rt_test2]
    async fn test_mpsc() {
        let (tx, mut rx) = channel();
        assert!(format!("{tx:?}").contains("Sender"));
        assert!(format!("{rx:?}").contains("Receiver"));

        tx.send("test").unwrap();
        assert_eq!(stream_recv(&mut rx).await.unwrap(), "test");

        let tx2 = tx.clone();
        tx2.send("test2").unwrap();
        assert_eq!(stream_recv(&mut rx).await.unwrap(), "test2");

        // Empty buffer with live senders: the stream must be pending.
        assert_eq!(
            lazy(|cx| Pin::new(&mut rx).poll_next(cx)).await,
            Poll::Pending
        );
        // Dropping one of two senders must not end the stream.
        drop(tx2);
        assert_eq!(
            lazy(|cx| Pin::new(&mut rx).poll_next(cx)).await,
            Poll::Pending
        );
        drop(tx);

        // Closing from the sending side ends an empty stream.
        let (tx, mut rx) = channel::<String>();
        tx.close();
        assert_eq!(stream_recv(&mut rx).await, None);

        // Sending fails once the receiver has been dropped.
        let (tx, rx) = channel();
        tx.send("test").unwrap();
        drop(rx);
        assert!(tx.send("test").is_err());

        // NOTE(review): the `_` pattern drops the receiver at the end of this
        // statement, so both sends below would fail even without `tx.close()`.
        // Bind it (e.g. `_rx`) if the intent is to test that `close` affects clones.
        let (tx, _) = channel();
        let tx2 = tx.clone();
        tx.close();
        assert!(tx.send("test").is_err());
        assert!(tx2.send("test").is_err());

        let err = SendError("test");
        assert!(format!("{err:?}").contains("SendError"));
        assert!(format!("{err}").contains("send failed because receiver is gone"));
        assert_eq!(err.into_inner(), "test");
    }

    // Checks `is_closed` / `is_terminated` across sender close, receiver close,
    // dropping the last sender, and re-opening via `Receiver::sender`.
    #[ntex_macros::rt_test2]
    async fn test_close() {
        let (tx, rx) = channel::<()>();
        assert!(!tx.is_closed());
        assert!(!rx.is_closed());
        assert!(!rx.is_terminated());

        tx.close();
        assert!(tx.is_closed());
        assert!(rx.is_closed());
        assert!(rx.is_terminated());

        let (tx, rx) = channel::<()>();
        rx.close();
        assert!(tx.is_closed());

        let (tx, rx) = channel::<()>();
        drop(tx);
        assert!(rx.is_closed());
        assert!(rx.is_terminated());
        // Creating a new sender makes the channel live again.
        let _tx = rx.sender();
        assert!(!rx.is_closed());
        assert!(!rx.is_terminated());
    }
}