1use std::any::Any;
3use std::collections::VecDeque;
4use std::error::Error;
5use std::fmt;
6use std::pin::Pin;
7use std::task::{Context, Poll};
8
9use futures::{Sink, Stream};
10
11use crate::cell::Cell;
12use crate::task::LocalWaker;
13
14pub fn channel<T>() -> (Sender<T>, Receiver<T>) {
16 let shared = Cell::new(Shared {
17 has_receiver: true,
18 buffer: VecDeque::new(),
19 blocked_recv: LocalWaker::new(),
20 });
21 let sender = Sender {
22 shared: shared.clone(),
23 };
24 let receiver = Receiver { shared };
25 (sender, receiver)
26}
27
/// State shared by every `Sender` and the single `Receiver` of one channel.
#[derive(Debug)]
struct Shared<T> {
    /// Messages sent but not yet received. `send` pushes to the back and `poll_next` pops from
    /// the front, so delivery is FIFO.
    buffer: VecDeque<T>,
    /// Waker registered by `Receiver::poll_next` when the buffer is empty. It is woken by `send`
    /// and from `Sender`'s `Drop`.
    blocked_recv: LocalWaker,
    /// Starts `true`. It is cleared by the receiver's `Drop` and by `Sender::close`; once it is
    /// `false`, `send` returns `SendError`.
    has_receiver: bool,
}
34
/// The sending half of a channel created by [`channel`].
///
/// Cloning yields another handle to the same shared buffer. Sending never waits: messages
/// are queued until the [`Receiver`] pulls them.
#[derive(Debug)]
pub struct Sender<T> {
    shared: Cell<Shared<T>>,
}

// Declared `Unpin` for every `T` (no `T: Unpin` bound). The `Sink` impl only forwards to the
// `&self` method `send`, so it never needs a pinned address.
impl<T> Unpin for Sender<T> {}
44
45impl<T> Sender<T> {
46 pub fn send(&self, item: T) -> Result<(), SendError<T>> {
48 let shared = unsafe { self.shared.get_mut_unsafe() };
49 if !shared.has_receiver {
50 return Err(SendError(item)); };
52 shared.buffer.push_back(item);
53 shared.blocked_recv.wake();
54 Ok(())
55 }
56
57 pub fn close(&mut self) {
62 self.shared.get_mut().has_receiver = false;
63 }
64}
65
66impl<T> Clone for Sender<T> {
67 fn clone(&self) -> Self {
68 Sender {
69 shared: self.shared.clone(),
70 }
71 }
72}
73
74impl<T> Sink<T> for Sender<T> {
75 type Error = SendError<T>;
76
77 fn poll_ready(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
78 Poll::Ready(Ok(()))
79 }
80
81 fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), SendError<T>> {
82 self.send(item)
83 }
84
85 fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), SendError<T>>> {
86 Poll::Ready(Ok(()))
87 }
88
89 fn poll_close(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
90 Poll::Ready(Ok(()))
91 }
92}
93
94impl<T> Drop for Sender<T> {
95 fn drop(&mut self) {
96 let count = self.shared.strong_count();
97 let shared = self.shared.get_mut();
98
99 if shared.has_receiver && count == 2 {
101 shared.blocked_recv.wake();
103 }
104 }
105}
106
/// The receiving half of a channel created by [`channel`].
///
/// Implements [`Stream`]. It yields buffered messages in FIFO order and ends (`None`) once
/// every [`Sender`] has been dropped and the buffer is drained. Dropping it discards any
/// undelivered messages.
#[derive(Debug)]
pub struct Receiver<T> {
    shared: Cell<Shared<T>>,
}
114
115impl<T> Receiver<T> {
116 pub fn sender(&self) -> Sender<T> {
118 Sender {
119 shared: self.shared.clone(),
120 }
121 }
122}
123
124impl<T> Unpin for Receiver<T> {}
125
126impl<T> Stream for Receiver<T> {
127 type Item = T;
128
129 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
130 if self.shared.strong_count() == 1 {
131 Poll::Ready(self.shared.get_mut().buffer.pop_front())
134 } else if let Some(msg) = self.shared.get_mut().buffer.pop_front() {
135 Poll::Ready(Some(msg))
136 } else {
137 self.shared.get_mut().blocked_recv.register(cx.waker());
138 Poll::Pending
139 }
140 }
141}
142
143impl<T> Drop for Receiver<T> {
144 fn drop(&mut self) {
145 let shared = self.shared.get_mut();
146 shared.buffer.clear();
147 shared.has_receiver = false;
148 }
149}
150
/// Error returned by [`Sender::send`] when the message cannot be delivered.
///
/// This happens when the receiver was dropped or the channel was closed. The rejected message
/// is kept and can be recovered with [`SendError::into_inner`].
pub struct SendError<T>(T);

// Written by hand so that `SendError<T>: Debug` holds for every `T`. The payload is elided.
impl<T> fmt::Debug for SendError<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut tuple = f.debug_tuple("SendError");
        tuple.field(&"...");
        tuple.finish()
    }
}

impl<T> fmt::Display for SendError<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("send failed because receiver is gone")
    }
}

impl<T> Error for SendError<T>
where
    T: Any,
{
    fn description(&self) -> &str {
        "send failed because receiver is gone"
    }
}

impl<T> SendError<T> {
    /// Consumes the error and returns the message that failed to send.
    pub fn into_inner(self) -> T {
        let SendError(item) = self;
        item
    }
}
179
#[cfg(test)]
mod tests {
    use super::*;
    use futures::future::lazy;
    use futures::{Stream, StreamExt};

    /// Walks one channel through send/receive, sender clones, and end-of-stream. It then
    /// checks that `send` fails after the receiver is dropped and after `close`.
    #[scrappy_rt::test]
    async fn test_mpsc() {
        // Round trip through the original sender.
        let (tx, mut rx) = channel();
        tx.send("test").unwrap();
        assert_eq!(rx.next().await.unwrap(), "test");

        // A cloned sender feeds the same receiver.
        let tx2 = tx.clone();
        tx2.send("test2").unwrap();
        assert_eq!(rx.next().await.unwrap(), "test2");

        // Empty buffer with live senders: the stream is pending, not finished.
        assert_eq!(
            lazy(|cx| Pin::new(&mut rx).poll_next(cx)).await,
            Poll::Pending
        );
        // Dropping one of two senders must not end the stream.
        drop(tx2);
        assert_eq!(
            lazy(|cx| Pin::new(&mut rx).poll_next(cx)).await,
            Poll::Pending
        );
        // Once the last sender is gone, the stream terminates.
        drop(tx);
        assert_eq!(rx.next().await, None);

        // Sending after the receiver has been dropped fails.
        let (tx, rx) = channel();
        tx.send("test").unwrap();
        drop(rx);
        assert!(tx.send("test").is_err());

        // `close` is meant to make every clone fail to send.
        // NOTE(review): the `_` pattern drops the receiver at the end of this statement, so both
        // sends below would fail even without `close`. This does not exercise `close` with a
        // live receiver.
        let (mut tx, _) = channel();
        let tx2 = tx.clone();
        tx.close();
        assert!(tx.send("test").is_err());
        assert!(tx2.send("test").is_err());
    }
}