1use std::fmt;
2use std::marker::PhantomData;
3use std::pin::Pin;
4use std::sync::atomic::{AtomicUsize, Ordering};
5use std::sync::Arc;
6use std::task::{Context, Poll};
7
8use futures::Sink;
9use pin_project::{pin_project, pinned_drop};
10
11use super::{Action, Reply, SendError, TrySendError, Waker};
12
13#[pin_project(PinnedDrop)]
14pub struct QueueSender<S: Waker, Item, F, R> {
15 #[pin]
16 s: S,
17 #[pin]
18 f: F,
19 num_senders: Arc<AtomicUsize>,
20 _item: PhantomData<Item>,
21 _r: PhantomData<R>,
22}
23
24unsafe impl<S: Waker, Item, F, R> Sync for QueueSender<S, Item, F, R> {}
25
26unsafe impl<S: Waker, Item, F, R> Send for QueueSender<S, Item, F, R> {}
27
28impl<S, Item, F, R> Clone for QueueSender<S, Item, F, R>
29where
30 S: Clone + Waker,
31 F: Clone,
32{
33 #[inline]
34 fn clone(&self) -> Self {
35 self.num_senders.fetch_add(1, Ordering::SeqCst);
37 Self {
39 s: self.s.clone(),
40 f: self.f.clone(),
41 num_senders: self.num_senders.clone(),
42 _item: PhantomData,
43 _r: PhantomData,
44 }
45 }
46}
47
48#[pinned_drop]
49impl<S: Waker, Item, F, R> PinnedDrop for QueueSender<S, Item, F, R> {
50 fn drop(self: Pin<&mut Self>) {
51 self.set_closed();
52 }
53}
54
55impl<S, Item, F, R> fmt::Debug for QueueSender<S, Item, F, R>
56where
57 S: fmt::Debug + Waker,
58{
59 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
60 f.debug_struct("QueueSender")
61 .field("stream", &self.s)
62 .finish()
63 }
64}
65
66impl<S, Item, F, R> QueueSender<S, Item, F, R>
67where
68 S: Waker,
69{
70 #[inline]
71 fn set_closed(&self) -> usize {
72 let prev = self.num_senders.fetch_sub(1, Ordering::SeqCst);
73 if prev == 1 {
74 self.s.close_channel();
75 }
76 prev
77 }
78}
79
80impl<S, Item, F, R> QueueSender<S, Item, F, R>
81where
82 S: Waker,
83 F: Fn(&mut S, Action<Item>) -> Reply<R>,
84{
85 #[inline]
86 pub(super) fn new(s: S, f: F) -> Self {
87 Self {
88 s,
89 f,
90 num_senders: Arc::new(AtomicUsize::new(1)),
91 _item: PhantomData,
92 _r: PhantomData,
93 }
94 }
95
96 #[inline]
97 pub fn try_send(&mut self, item: Item) -> Result<R, TrySendError<Item>> {
98 if self.s.is_closed() {
99 return Err(SendError::disconnected(Some(item)));
100 }
101 if self.is_full() {
102 return Err(TrySendError::full(item));
103 }
104 let res = (self.f)(&mut self.s, Action::Send(item));
105 self.s.rx_wake();
106 if let Reply::Send(r) = res {
107 Ok(r)
108 } else {
109 unreachable!()
110 }
111 }
112
113 #[inline]
114 pub fn is_full(&mut self) -> bool {
115 match (self.f)(&mut self.s, Action::IsFull) {
116 Reply::IsFull(reply) => reply,
117 _ => unreachable!(),
118 }
119 }
120
121 #[inline]
122 pub fn is_empty(&mut self) -> bool {
123 match (self.f)(&mut self.s, Action::IsEmpty) {
124 Reply::IsEmpty(reply) => reply,
125 _ => unreachable!(),
126 }
127 }
128
129 #[inline]
130 pub fn len(&mut self) -> usize {
131 match (self.f)(&mut self.s, Action::Len) {
132 Reply::Len(reply) => reply,
133 _ => unreachable!(),
134 }
135 }
136}
137
138impl<S, Item, F, R> Sink<Item> for QueueSender<S, Item, F, R>
139where
140 S: Waker + Unpin,
141 F: Fn(&mut S, Action<Item>) -> Reply<R>,
142{
143 type Error = SendError<Item>;
144
145 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
146 if self.s.is_closed() {
147 return Poll::Ready(Err(SendError::disconnected(None)));
148 }
149 let mut this = self.project();
150 match (this.f)(&mut this.s, Action::IsFull) {
151 Reply::IsFull(true) => {
152 this.s.tx_park(cx.waker().clone());
153 Poll::Pending
154 }
155 Reply::IsFull(false) => Poll::Ready(Ok(())),
156 _ => unreachable!(),
157 }
158 }
159
160 fn start_send(self: Pin<&mut Self>, item: Item) -> Result<(), Self::Error> {
161 if self.s.is_closed() {
162 return Err(SendError::disconnected(Some(item)));
163 }
164 let mut this = self.project();
165 let _ = (this.f)(&mut this.s, Action::Send(item));
166 this.s.rx_wake();
167 Ok(())
168 }
169
170 fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
171 if self.s.is_closed() {
172 return Poll::Ready(Err(SendError::disconnected(None)));
173 }
174 Poll::Ready(Ok(()))
175 }
176
177 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
178 if self.s.is_closed() {
179 return Poll::Ready(Err(SendError::disconnected(None)));
180 }
181 if self.set_closed() > 1 {
182 return Poll::Ready(Ok(()));
183 }
184
185 let mut this = self.project();
186 match (this.f)(&mut this.s, Action::IsEmpty) {
187 Reply::IsEmpty(true) => Poll::Ready(Ok(())),
188 Reply::IsEmpty(false) => {
189 this.s.tx_park(cx.waker().clone());
190 Poll::Pending
191 }
192 _ => unreachable!(),
193 }
194 }
195}
196
197impl<S: Unpin + Waker, Item, F, R> std::convert::AsMut<S> for QueueSender<S, Item, F, R> {
198 #[inline]
199 fn as_mut(&mut self) -> &mut S {
200 &mut self.s
201 }
202}
203
204impl<S: Waker, Item, F, R> std::convert::AsRef<S> for QueueSender<S, Item, F, R> {
205 #[inline]
206 fn as_ref(&self) -> &S {
207 &self.s
208 }
209}
210
211#[cfg(test)]
212use futures::task::noop_waker;
213#[cfg(test)]
214use std::collections::VecDeque;
215#[cfg(test)]
216use std::sync::atomic::AtomicBool;
217#[cfg(test)]
218use std::sync::Mutex;
219
220#[cfg(test)]
222#[derive(Clone)]
223struct TestStream {
224 queue: Arc<Mutex<VecDeque<i32>>>,
225 closed: Arc<AtomicBool>,
226}
227
228#[cfg(test)]
229impl TestStream {
230 fn new() -> Self {
231 TestStream {
232 queue: Arc::new(Mutex::new(VecDeque::new())),
233 closed: Arc::new(AtomicBool::new(false)),
234 }
235 }
236}
237
238#[cfg(test)]
239impl Waker for TestStream {
240 fn rx_wake(&self) {}
241 fn tx_park(&self, _w: std::task::Waker) {}
242 fn close_channel(&self) {
243 self.closed.store(true, Ordering::SeqCst);
244 }
245 fn is_closed(&self) -> bool {
246 self.closed.load(Ordering::SeqCst)
247 }
248}
249
250#[cfg(test)]
252fn bounded_handler(s: &mut TestStream, action: Action<i32>) -> Reply<i32> {
253 match action {
254 Action::Send(item) => {
255 s.queue.lock().unwrap().push_back(item);
256 Reply::Send(item)
257 }
258 Action::IsFull => Reply::IsFull(s.queue.lock().unwrap().len() >= 3),
259 Action::IsEmpty => Reply::IsEmpty(s.queue.lock().unwrap().is_empty()),
260 Action::Len => Reply::Len(s.queue.lock().unwrap().len()),
261 }
262}
263
264#[cfg(test)]
266fn unbounded_handler(s: &mut TestStream, action: Action<i32>) -> Reply<i32> {
267 match action {
268 Action::Send(item) => {
269 s.queue.lock().unwrap().push_back(item);
270 Reply::Send(item)
271 }
272 Action::IsFull => Reply::IsFull(false),
273 Action::IsEmpty => Reply::IsEmpty(s.queue.lock().unwrap().is_empty()),
274 Action::Len => Reply::Len(s.queue.lock().unwrap().len()),
275 }
276}
277
278#[test]
283fn try_send_ok() {
284 let mut sender = QueueSender::new(TestStream::new(), bounded_handler);
285 let r = sender.try_send(42);
286 assert!(r.is_ok());
287 assert_eq!(r.unwrap(), 42);
288}
289
290#[test]
291fn try_send_err_full() {
292 let mut sender = QueueSender::new(TestStream::new(), bounded_handler);
293 sender.try_send(1).unwrap();
294 sender.try_send(2).unwrap();
295 sender.try_send(3).unwrap();
296
297 let err = sender.try_send(4).unwrap_err();
298 assert!(err.is_full());
299 assert!(!err.is_disconnected());
300 assert_eq!(err.into_inner(), Some(4));
301}
302
303#[test]
304fn try_send_err_disconnected() {
305 let s = TestStream::new();
306 s.closed.store(true, Ordering::SeqCst);
307 let mut sender = QueueSender::new(s, unbounded_handler);
308 let err = sender.try_send(42).unwrap_err();
309 assert!(err.is_disconnected());
310 assert!(!err.is_full());
311 assert_eq!(err.into_inner(), Some(42));
312}
313
314#[test]
319fn state_methods() {
320 let mut sender = QueueSender::new(TestStream::new(), bounded_handler);
321
322 assert!(sender.is_empty());
323 assert!(!sender.is_full());
324 assert_eq!(sender.len(), 0);
325
326 sender.try_send(1).unwrap();
327 assert!(!sender.is_empty());
328 assert!(!sender.is_full());
329 assert_eq!(sender.len(), 1);
330
331 sender.try_send(2).unwrap();
332 sender.try_send(3).unwrap();
333 assert!(!sender.is_empty());
334 assert!(sender.is_full());
335 assert_eq!(sender.len(), 3);
336}
337
338#[test]
343fn send_error_full_ctor() {
344 let err = SendError::full(42i32);
345 assert!(err.is_full());
346 assert!(!err.is_disconnected());
347 assert_eq!(err.into_inner(), Some(42));
348}
349
350#[test]
351fn send_error_disconnected_ctor() {
352 let err = SendError::<i32>::disconnected(Some(99));
353 assert!(err.is_disconnected());
354 assert!(!err.is_full());
355 assert_eq!(err.into_inner(), Some(99));
356}
357
358#[test]
359fn send_error_disconnected_none() {
360 let err = SendError::<i32>::disconnected(None);
361 assert!(err.is_disconnected());
362 assert_eq!(err.into_inner(), None);
363}
364
365#[test]
366fn send_error_debug() {
367 let err = SendError::full(42i32);
368 let s = format!("{:?}", err);
369 assert!(s.contains("SendError"));
370}
371
372#[test]
373fn send_error_display_full() {
374 let err = SendError::full(42i32);
375 assert_eq!(format!("{}", err), "send failed because mpsc is full");
376}
377
378#[test]
379fn send_error_display_disconnected() {
380 let err = SendError::<i32>::disconnected(None);
381 assert_eq!(format!("{}", err), "send failed because receiver is gone");
382}
383
384#[test]
385fn send_error_clone_eq() {
386 let err1 = SendError::full(42i32);
387 let err2 = err1.clone();
388 assert_eq!(err1, err2);
389 assert!(err1.is_full());
390 assert_eq!(err2.into_inner(), Some(42));
391}
392
393#[test]
398fn sink_poll_ready_ok() {
399 let mut sender = QueueSender::new(TestStream::new(), bounded_handler);
400 let waker = noop_waker();
401 let mut cx = Context::from_waker(&waker);
402 assert_eq!(
403 Pin::new(&mut sender).poll_ready(&mut cx),
404 Poll::Ready(Ok(()))
405 );
406}
407
408#[test]
409fn sink_poll_ready_closed() {
410 let s = TestStream::new();
411 s.closed.store(true, Ordering::SeqCst);
412 let mut sender = QueueSender::new(s, unbounded_handler);
413 let waker = noop_waker();
414 let mut cx = Context::from_waker(&waker);
415 let r = Pin::new(&mut sender).poll_ready(&mut cx);
416 assert!(matches!(r, Poll::Ready(Err(ref e)) if e.is_disconnected()));
417}
418
419#[test]
420fn sink_start_send_ok() {
421 let mut sender = QueueSender::new(TestStream::new(), bounded_handler);
422 assert!(Pin::new(&mut sender).start_send(42).is_ok());
423}
424
425#[test]
426fn sink_start_send_closed() {
427 let s = TestStream::new();
428 s.closed.store(true, Ordering::SeqCst);
429 let mut sender = QueueSender::new(s, unbounded_handler);
430 let r = Pin::new(&mut sender).start_send(42);
431 assert!(r.is_err());
432 assert!(r.unwrap_err().is_disconnected());
433}
434
435#[test]
436fn sink_poll_flush_ok() {
437 let mut sender = QueueSender::new(TestStream::new(), bounded_handler);
438 let waker = noop_waker();
439 let mut cx = Context::from_waker(&waker);
440 assert_eq!(
441 Pin::new(&mut sender).poll_flush(&mut cx),
442 Poll::Ready(Ok(()))
443 );
444}
445
446#[test]
447fn sink_poll_flush_closed() {
448 let s = TestStream::new();
449 s.closed.store(true, Ordering::SeqCst);
450 let mut sender = QueueSender::new(s, unbounded_handler);
451 let waker = noop_waker();
452 let mut cx = Context::from_waker(&waker);
453 let r = Pin::new(&mut sender).poll_flush(&mut cx);
454 assert!(matches!(r, Poll::Ready(Err(ref e)) if e.is_disconnected()));
455}
456
457#[test]
458fn sink_poll_closes_single_sender() {
459 let mut sender = QueueSender::new(TestStream::new(), unbounded_handler);
460 let waker = noop_waker();
461 let mut cx = Context::from_waker(&waker);
462 assert_eq!(
464 Pin::new(&mut sender).poll_close(&mut cx),
465 Poll::Ready(Ok(()))
466 );
467}
468
469#[test]
474fn drop_last_sender_closes_channel() {
475 let s = TestStream::new();
476 let closed = s.closed.clone();
477 let sender = QueueSender::new(s, unbounded_handler);
478 drop(sender);
479 assert!(closed.load(Ordering::SeqCst));
480}
481
482#[test]
483fn drop_clone_does_not_close_immediately() {
484 let s = TestStream::new();
485 let closed = s.closed.clone();
486 let sender1 = QueueSender::new(s, unbounded_handler);
487 let sender2 = sender1.clone();
488
489 drop(sender2);
490 assert!(!closed.load(Ordering::SeqCst));
492
493 drop(sender1);
494 assert!(closed.load(Ordering::SeqCst));
496}
497
498#[test]
503fn sender_is_send_sync() {
504 fn assert_send<T: Send>(_t: &T) {}
505 fn assert_sync<T: Sync>(_t: &T) {}
506
507 let s = TestStream::new();
508 let sender = QueueSender::new(s, unbounded_handler);
509 assert_send(&sender);
510 assert_sync(&sender);
511}