mpmc_async/
lib.rs

1//! A multi-producer, multi-consumer async channel with reservations.
2//!
3//! Example usage:
4//!
5//! ```rust
6//! tokio_test::block_on(async {
7//!     let (tx1, rx1) = mpmc_async::channel(2);
8//!
9//!     let task = tokio::spawn(async move {
10//!       let rx2 = rx1.clone();
11//!       assert_eq!(rx1.recv().await.unwrap(), 2);
12//!       assert_eq!(rx2.recv().await.unwrap(), 1);
13//!     });
14//!
15//!     let tx2 = tx1.clone();
16//!     let permit = tx1.reserve().await.unwrap();
17//!     tx2.send(1).await.unwrap();
18//!     permit.send(2);
19//!
20//!     task.await.unwrap();
21//! });
22//! ```
23//!
24//! A more complex example with multiple sender and receiver tasks:
25//!
26//! ```rust
27//! use std::collections::BTreeSet;
28//! use std::ops::DerefMut;
29//! use std::sync::{Arc, Mutex};
30//!
31//! tokio_test::block_on(async {
32//!     let (tx, rx) = mpmc_async::channel(1);
33//!
34//!     let num_workers = 10;
35//!     let count = 10;
36//!     let mut tasks = Vec::with_capacity(num_workers);
37//!
38//!     for i in 0..num_workers {
39//!         let mut tx = tx.clone();
40//!         let task = tokio::spawn(async move {
41//!             for j in 0..count {
42//!                 let val = i * count + j;
43//!                 tx.reserve().await.expect("no error").send(val);
44//!             }
45//!         });
46//!         tasks.push(task);
47//!     }
48//!
49//!     let total = count * num_workers;
50//!     let values = Arc::new(Mutex::new(BTreeSet::new()));
51//!
52//!     for _ in 0..num_workers {
53//!         let values = values.clone();
54//!         let rx = rx.clone();
55//!         let task = tokio::spawn(async move {
56//!             for _ in 0..count {
57//!                 let val = rx.recv().await.expect("Failed to recv");
58//!                 values.lock().unwrap().insert(val);
59//!             }
60//!         });
61//!         tasks.push(task);
62//!     }
63//!
64//!     for task in tasks {
65//!         task.await.expect("failed to join task");
66//!     }
67//!
68//!     let exp = (0..total).collect::<Vec<_>>();
69//!     let got = std::mem::take(values.lock().unwrap().deref_mut())
70//!         .into_iter()
71//!         .collect::<Vec<_>>();
72//!     assert_eq!(exp, got);
73//! });
74//! ```
75mod linked_list;
76mod queue;
77mod state;
78
79use std::fmt::{Debug, Display};
80use std::future::Future;
81use std::ops::DerefMut;
82use std::pin::Pin;
83use std::task::{Context, Poll, Waker};
84
85use self::linked_list::NodeRef;
86use self::queue::Spot;
87use self::state::SendWaker;
88use self::state::State;
89
/// Error returned by [`Sender::send`] when all receivers have been dropped.
///
/// The value that could not be sent is handed back to the caller in field `0`.
#[derive(PartialEq, Eq, Clone, Copy, Debug)]
pub struct SendError<T>(pub T);

impl<T> Display for SendError<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("Send failed: disconnected")
    }
}

/// Makes the error usable with `?`, `Box<dyn Error>` and similar error-handling tooling.
/// `T: Debug` is required because `std::error::Error` has `Debug` as a supertrait.
impl<T: Debug> std::error::Error for SendError<T> {}
99
/// Occurs when the channel is full or all receivers have been dropped.
///
/// Both variants hand the rejected value back to the caller.
#[derive(PartialEq, Eq, Clone, Copy, Debug)]
pub enum TrySendError<T> {
    /// Channel full
    Full(T),
    /// Disconnected
    Disconnected(T),
}

impl<T> Display for TrySendError<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            TrySendError::Full(_) => f.write_str("Try send failed: full"),
            TrySendError::Disconnected(_) => f.write_str("Try send failed: disconnected"),
        }
    }
}

/// Makes the error usable with `?`, `Box<dyn Error>` and similar error-handling tooling.
/// `T: Debug` is required because `std::error::Error` has `Debug` as a supertrait.
impl<T: Debug> std::error::Error for TrySendError<T> {}
117
/// Error returned when reserving a permit fails because all receivers have been dropped.
#[derive(PartialEq, Eq, Clone, Copy, Debug)]
pub struct ReserveError;

impl Display for ReserveError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("Reserve failed: disconnected")
    }
}

/// Makes the error usable with `?`, `Box<dyn Error>` and similar error-handling tooling.
impl std::error::Error for ReserveError {}
127
/// Occurs when the channel is full, or all receivers have been dropped.
#[derive(PartialEq, Eq, Clone, Copy, Debug)]
pub enum TryReserveError {
    /// Channel full
    Full,
    /// Disconnected
    Disconnected,
}

impl Display for TryReserveError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The messages name the operation that failed (a reservation, not a send) so they can be
        // told apart from `TrySendError` in logs.
        match self {
            TryReserveError::Full => f.write_str("Try reserve failed: full"),
            TryReserveError::Disconnected => f.write_str("Try reserve failed: disconnected"),
        }
    }
}

/// Makes the error usable with `?`, `Box<dyn Error>` and similar error-handling tooling.
impl std::error::Error for TryReserveError {}
145
/// Error returned by a receive operation when all senders have been dropped.
#[derive(PartialEq, Eq, Clone, Copy, Debug)]
pub struct RecvError;

impl Display for RecvError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("Recv failed: disconnected")
    }
}

/// Makes the error usable with `?`, `Box<dyn Error>` and similar error-handling tooling.
impl std::error::Error for RecvError {}
155
/// Occurs when the channel is empty or all senders have been dropped.
#[derive(PartialEq, Eq, Clone, Copy, Debug)]
pub enum TryRecvError {
    /// Channel is empty
    Empty,
    /// Disconnected
    Disconnected,
}

impl Display for TryRecvError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            TryRecvError::Empty => f.write_str("Try recv failed: empty"),
            TryRecvError::Disconnected => f.write_str("Try recv failed: disconnected"),
        }
    }
}

/// Makes the error usable with `?`, `Box<dyn Error>` and similar error-handling tooling.
impl std::error::Error for TryRecvError {}
173
174/// Creates a new bounded channel. When `cap` is 0 it will be increased to 1.
175pub fn channel<T>(mut cap: usize) -> (Sender<T>, Receiver<T>) {
176    if cap == 0 {
177        cap = 1
178    }
179
180    let state = State::new(cap);
181
182    (state.new_sender(), state.new_receiver())
183}
184
185/// Receives messages sent by [Sender].
186pub struct Receiver<T> {
187    state: State<T>,
188}
189
190/// Cloning creates new a instance with the shared state  and increases the internal reference
191/// counter. It is guaranteed that a single message will be distributed to exacly one receiver
192/// future awaited after calling `recv()`.
193impl<T> Clone for Receiver<T> {
194    fn clone(&self) -> Self {
195        self.state.new_receiver()
196    }
197}
198
199impl<T> Receiver<T> {
200    fn new(state: State<T>) -> Self {
201        Self { state }
202    }
203
204    /// Disconnects all receivers from senders. The receivers will still receive all buffered
205    /// or reserved messages before returning an error, allowing a graceful teardown.
206    pub fn close_all(&self) {
207        self.state.close_all_receivers();
208    }
209
210    /// Waits until there's a message to be read and returns. Returns an error when there are no
211    /// more messages in the queue and all [Sender]s have been dropped.
212    pub async fn recv(&self) -> Result<T, RecvError> {
213        let recv = RecvFuture::new(self);
214        recv.await
215    }
216
217    /// Checks if there's a message to be read and returns immediately. Returns an error when the
218    /// channel is disconnected or empty.
219    pub fn try_recv(&self) -> Result<T, TryRecvError> {
220        self.state.try_recv()
221    }
222
223    pub async fn recv_many(&self, vec: &mut Vec<T>, count: usize) -> Result<usize, RecvError> {
224        let recv = RecvManyFuture::new(self, vec, count);
225        recv.await
226    }
227
228    pub fn try_recv_many(&self, vec: &mut Vec<T>, count: usize) -> Result<usize, TryRecvError> {
229        self.state.try_recv_many(vec, count)
230    }
231}
232
233/// The last reciever that's dropped will mark the channel as disconnected.
234impl<T> Drop for Receiver<T> {
235    fn drop(&mut self) {
236        self.state.drop_receiver();
237    }
238}
239
240/// Producers messages to be read by [Receiver]s.
241pub struct Sender<T> {
242    state: State<T>,
243}
244
245impl<T> Debug for Sender<T> {
246    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
247        f.write_str("Sender")
248    }
249}
250
251/// Cloning creates new a instance with the shared state  and increases the internal reference
252/// counter.
253impl<T> Clone for Sender<T> {
254    fn clone(&self) -> Self {
255        self.state.new_sender()
256    }
257}
258
259impl<T> Sender<T> {
260    fn new(state: State<T>) -> Sender<T> {
261        Self { state }
262    }
263
264    /// Waits until the value is sent or returns an when all receivers have been dropped.
265    pub async fn send(&self, value: T) -> Result<(), SendError<T>> {
266        SendFuture::new(&self.state, value).await
267    }
268
269    /// Sends without blocking or returns an when all receivers have been dropped.
270    pub fn try_send(&self, value: T) -> Result<(), TrySendError<T>> {
271        self.state.try_send(value)
272    }
273
274    /// Waits until a permit is reserved or returns an when all receivers have been dropped.
275    pub async fn reserve(&self) -> Result<Permit<'_, T>, ReserveError> {
276        let reservation = ReserveFuture::new(&self.state, 1).await?;
277        Ok(Permit::new(self, reservation))
278    }
279
280    /// Reserves a permit or returns an when all receivers have been dropped.
281    pub fn try_reserve(&self) -> Result<Permit<'_, T>, TryReserveError> {
282        let reservation = self.state.try_reserve(1)?;
283        Ok(Permit::new(self, reservation))
284    }
285
286    /// Waits until multiple Permits in the queue are reserved.
287    pub async fn reserve_many(&self, count: usize) -> Result<PermitIterator<'_, T>, ReserveError> {
288        let reservation = ReserveFuture::new(&self.state, count).await?;
289        Ok(PermitIterator::new(self, reservation, count))
290    }
291
292    /// Reserves multiple Permits in the queue, or errors out when there's no room.
293    pub fn try_reserve_many(&self, count: usize) -> Result<PermitIterator<'_, T>, TryReserveError> {
294        let reservation = self.state.try_reserve(count)?;
295        Ok(PermitIterator::new(self, reservation, count))
296    }
297
298    /// Like [Sender::reserve], but takes ownership of `Sender` until sending is done.
299    pub async fn reserve_owned(self) -> Result<OwnedPermit<T>, ReserveError> {
300        let reservation = ReserveFuture::new(&self.state, 1).await?;
301        Ok(OwnedPermit::new(self, reservation))
302    }
303
304    /// Reserves a permit or returns an when all receivers have been dropped.
305    pub fn try_reserve_owned(self) -> Result<OwnedPermit<T>, TrySendError<Self>> {
306        let reservation = match self.state.try_reserve(1) {
307            Ok(reservation) => reservation,
308            Err(TryReserveError::Disconnected) => return Err(TrySendError::Disconnected(self)),
309            Err(TryReserveError::Full) => return Err(TrySendError::Full(self)),
310        };
311        Ok(OwnedPermit::new(self, reservation))
312    }
313}
314
315/// The last sender that's dropped will mark the channel as disconnected.
316impl<T> Drop for Sender<T> {
317    fn drop(&mut self) {
318        self.state.drop_sender();
319    }
320}
321
322/// Permit holds a spot in the internal buffer so the message can be sent without awaiting.
323pub struct Permit<'a, T> {
324    sender: &'a Sender<T>,
325    reservation: Option<NodeRef<Spot<T>>>,
326}
327
328impl<'a, T> Permit<'a, T> {
329    fn new(sender: &'a Sender<T>, reservation: NodeRef<Spot<T>>) -> Self {
330        Self {
331            sender,
332            reservation: Some(reservation),
333        }
334    }
335
336    /// Writes a message to the internal buffer.
337    pub fn send(mut self, value: T) {
338        let reservation = self.reservation.take().expect("reservation");
339        self.sender.state.send_with_permit(reservation, value);
340    }
341}
342
343impl<'a, T> Drop for Permit<'a, T> {
344    fn drop(&mut self) {
345        if let Some(reservation) = self.reservation.take() {
346            self.sender.state.drop_permit(reservation, 1);
347        }
348    }
349}
350
351pub struct PermitIterator<'a, T> {
352    sender: &'a Sender<T>,
353    reservation: Option<NodeRef<Spot<T>>>,
354    count: usize,
355}
356
357impl<'a, T> PermitIterator<'a, T> {
358    pub fn new(sender: &'a Sender<T>, reservation: NodeRef<Spot<T>>, count: usize) -> Self {
359        Self {
360            sender,
361            reservation: Some(reservation),
362            count,
363        }
364    }
365}
366
367impl<'a, T> Iterator for PermitIterator<'a, T> {
368    type Item = Permit<'a, T>;
369
370    fn next(&mut self) -> Option<Self::Item> {
371        if self.count == 0 {
372            None
373        } else {
374            self.count -= 1;
375
376            let reservation = if self.count == 0 {
377                self.reservation.take().expect("reservation")
378            } else {
379                self.reservation.clone().expect("reservation")
380            };
381
382            Some(Permit::new(self.sender, reservation))
383        }
384    }
385}
386
387impl<'a, T> Drop for PermitIterator<'a, T> {
388    fn drop(&mut self) {
389        if let Some(reservation) = self.reservation.take() {
390            self.sender.state.drop_permit(reservation, self.count);
391        }
392    }
393}
394
395pub struct OwnedPermit<T> {
396    sender_and_reservation: Option<(Sender<T>, NodeRef<Spot<T>>)>,
397}
398
399impl<T> OwnedPermit<T> {
400    fn new(sender: Sender<T>, reservation: NodeRef<Spot<T>>) -> Self {
401        Self {
402            sender_and_reservation: Some((sender, reservation)),
403        }
404    }
405
406    /// Writes a message to the internal buffer.
407    pub fn send(mut self, value: T) -> Sender<T> {
408        let (sender, reservation) = self
409            .sender_and_reservation
410            .take()
411            .expect("sender and reservation");
412
413        sender.state.send_with_permit(reservation, value);
414
415        sender
416    }
417
418    pub fn release(mut self) -> Sender<T> {
419        let (sender, reservation) = self
420            .sender_and_reservation
421            .take()
422            .expect("sender and reservation");
423
424        sender.state.drop_permit(reservation, 1);
425
426        sender
427    }
428}
429
430impl<T> Drop for OwnedPermit<T> {
431    fn drop(&mut self) {
432        // if we haven't called send or release:
433        if let Some((sender, reservation)) = self.sender_and_reservation.take() {
434            sender.state.drop_permit(reservation, 1);
435        }
436    }
437}
438
439struct SendFuture<'a, T> {
440    state: &'a State<T>,
441    value: Option<T>,
442    waiting: Option<NodeRef<SendWaker>>,
443}
444
445impl<'a, T> Unpin for SendFuture<'a, T> {}
446
447impl<'a, T> SendFuture<'a, T> {
448    fn new(state: &'a State<T>, value: T) -> Self {
449        Self {
450            state,
451            value: Some(value),
452            waiting: None,
453        }
454    }
455}
456
457impl<'a, T> Future for SendFuture<'a, T> {
458    type Output = Result<(), SendError<T>>;
459
460    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
461        let this = self.deref_mut();
462
463        this.state.send(&mut this.value, cx, &mut this.waiting)
464    }
465}
466
467impl<'a, T> Drop for SendFuture<'a, T> {
468    fn drop(&mut self) {
469        self.state
470            .drop_send_future(&mut self.value, &mut self.waiting)
471    }
472}
473
474struct ReserveFuture<'a, T> {
475    state: &'a State<T>,
476    count: usize,
477    waiting: Option<NodeRef<SendWaker>>,
478}
479
480impl<'a, T> ReserveFuture<'a, T> {
481    fn new(state: &'a State<T>, count: usize) -> Self {
482        Self {
483            state,
484            count,
485            waiting: None,
486        }
487    }
488}
489
490impl<'a, T> Future for ReserveFuture<'a, T> {
491    type Output = Result<NodeRef<Spot<T>>, ReserveError>;
492
493    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
494        let this = self.deref_mut();
495
496        this.state.reserve(this.count, cx, &mut this.waiting)
497    }
498}
499
500impl<'a, T> Drop for ReserveFuture<'a, T> {
501    fn drop(&mut self) {
502        self.state.drop_reserve_future(&mut self.waiting)
503    }
504}
505
506struct RecvFuture<'a, T> {
507    receiver: &'a Receiver<T>,
508    waker_ref: Option<NodeRef<Waker>>,
509    has_received: bool,
510}
511
512impl<'a, T> RecvFuture<'a, T> {
513    fn new(receiver: &'a Receiver<T>) -> Self {
514        Self {
515            receiver,
516            waker_ref: None,
517            has_received: false,
518        }
519    }
520}
521
522impl<'a, T> Unpin for RecvFuture<'a, T> {}
523
524impl<'a, T> Future for RecvFuture<'a, T> {
525    type Output = Result<T, RecvError>;
526
527    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
528        let this = self.deref_mut();
529        this.receiver
530            .state
531            .recv(cx, &mut this.has_received, &mut this.waker_ref)
532    }
533}
534
535impl<'a, T> Drop for RecvFuture<'a, T> {
536    fn drop(&mut self) {
537        self.receiver
538            .state
539            .drop_recv_future(self.has_received, &mut self.waker_ref);
540    }
541}
542
543struct RecvManyFuture<'a, T> {
544    receiver: &'a Receiver<T>,
545    vec: &'a mut Vec<T>,
546    count: usize,
547    waker_ref: Option<NodeRef<Waker>>,
548    has_received: bool,
549}
550
551impl<'a, T> RecvManyFuture<'a, T> {
552    fn new(receiver: &'a Receiver<T>, vec: &'a mut Vec<T>, count: usize) -> Self {
553        Self {
554            receiver,
555            vec,
556            count,
557            waker_ref: None,
558            has_received: false,
559        }
560    }
561}
562
563impl<'a, T> Unpin for RecvManyFuture<'a, T> {}
564
565impl<'a, T> Future for RecvManyFuture<'a, T> {
566    type Output = Result<usize, RecvError>;
567
568    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
569        let this = self.deref_mut();
570        this.receiver.state.recv_many(
571            cx,
572            &mut this.has_received,
573            &mut this.waker_ref,
574            this.vec,
575            this.count,
576        )
577    }
578}
579
580impl<'a, T> Drop for RecvManyFuture<'a, T> {
581    fn drop(&mut self) {
582        self.receiver
583            .state
584            .drop_recv_future(self.has_received, &mut self.waker_ref);
585    }
586}
587
#[cfg(test)]
mod testing {
    use std::collections::BTreeSet;
    use std::sync::{Arc, Mutex};
    use std::time::Duration;

    use super::*;

    /// A single value sent on the channel comes out of the receiver unchanged.
    #[tokio::test]
    async fn send_receive() {
        let (sender, receiver) = channel(1);
        sender.send(1).await.expect("no error");
        assert_eq!(receiver.recv().await.expect("no error"), 1);
    }

    /// Many producer tasks, one consumer: every value arrives exactly once.
    #[tokio::test]
    async fn mpsc() {
        const NUM_WORKERS: usize = 10;
        const COUNT: usize = 10;

        let (tx, rx) = channel(1);

        let producers: Vec<_> = (0..NUM_WORKERS)
            .map(|worker| {
                let tx = tx.clone();
                tokio::spawn(async move {
                    for offset in 0..COUNT {
                        tx.send(worker * COUNT + offset)
                            .await
                            .expect("Failed to send");
                    }
                })
            })
            .collect();

        let total = COUNT * NUM_WORKERS;
        let mut received = BTreeSet::new();

        for _ in 0..total {
            received.insert(rx.recv().await.expect("no error"));
        }

        let expected = (0..total).collect::<Vec<_>>();
        let actual = received.into_iter().collect::<Vec<_>>();
        assert_eq!(expected, actual);

        for handle in producers {
            handle.await.expect("failed to join task");
        }
    }

    /// Runs ten producer and ten consumer tasks over a capacity-1 channel and checks that every
    /// value is delivered exactly once. `send` performs one send and hands the sender back.
    async fn run_tasks<F, Fut>(send: F)
    where
        F: Fn(Sender<usize>, usize) -> Fut + Copy + Send + Sync + 'static,
        Fut: Future<Output = Sender<usize>> + Send,
    {
        const NUM_WORKERS: usize = 10;
        const COUNT: usize = 10;

        let (tx, rx) = channel(1);
        let mut handles = Vec::with_capacity(NUM_WORKERS);

        for worker in 0..NUM_WORKERS {
            let mut tx = tx.clone();
            handles.push(tokio::spawn(async move {
                for offset in 0..COUNT {
                    tx = send(tx, worker * COUNT + offset).await;
                }
            }));
        }

        let seen = Arc::new(Mutex::new(BTreeSet::new()));

        for _ in 0..NUM_WORKERS {
            let seen = Arc::clone(&seen);
            let rx = rx.clone();
            handles.push(tokio::spawn(async move {
                for _ in 0..COUNT {
                    let value = rx.recv().await.expect("Failed to recv");
                    seen.lock().unwrap().insert(value);
                }
            }));
        }

        for handle in handles {
            handle.await.expect("failed to join task");
        }

        let expected = (0..COUNT * NUM_WORKERS).collect::<Vec<_>>();
        let actual = std::mem::take(&mut *seen.lock().unwrap())
            .into_iter()
            .collect::<Vec<_>>();
        assert_eq!(expected, actual);
    }

    #[tokio::test]
    async fn mpmc_multiple_tasks() {
        let send_one = |tx: Sender<usize>, value| async move {
            tx.send(value).await.expect("Failed to send");
            tx
        };
        run_tasks(send_one).await;
    }

    #[tokio::test]
    async fn mpmc_reserve() {
        let send_one = |tx: Sender<usize>, value| async move {
            let permit = tx.reserve().await.expect("Failed to send");
            permit.send(value);
            tx
        };
        run_tasks(send_one).await;
    }

    #[tokio::test]
    async fn mpmc_try_reserve() {
        let send_one = |tx: Sender<usize>, value| async move {
            // Retry until a spot could be reserved, yielding to other tasks in between.
            while let Err(_err) = tx.try_reserve().map(|permit| permit.send(value)) {
                tokio::time::sleep(Duration::ZERO).await;
            }
            tx
        };
        run_tasks(send_one).await;
    }

    /// Sends fail with the rejected value once the only receiver is gone.
    #[tokio::test]
    async fn send_errors() {
        let (tx, rx) = channel::<i32>(2);
        for value in [1, 2] {
            assert_eq!(tx.send(value).await, Ok(()));
        }
        let blocked = {
            let tx = tx.clone();
            tokio::spawn(async move { tx.send(3).await })
        };
        drop(rx);
        assert_eq!(tx.send(4).await, Err(SendError(4)));
        assert_eq!(blocked.await.expect("panic"), Err(SendError(3)));
    }

    #[test]
    fn try_send_errors() {
        let (tx, rx) = channel::<i32>(2);
        for value in [1, 2] {
            assert_eq!(tx.try_send(value), Ok(()));
        }
        for value in [3, 4] {
            assert_eq!(tx.try_send(value), Err(TrySendError::Full(value)));
        }
        drop(rx);
        assert_eq!(tx.try_send(5), Err(TrySendError::Disconnected(5)));
    }

    #[tokio::test]
    async fn reserve_errors() {
        let (tx, rx) = channel::<i32>(2);
        // NOTE(review): both permits are dropped at the end of their statement, so the channel
        // is not actually full afterwards — confirm whether they were meant to be held, as in
        // `try_reserve_errors`.
        tx.reserve().await.expect("reserved 1");
        tx.reserve().await.expect("reserved 2");
        let other = {
            let tx = tx.clone();
            tokio::spawn(async move {
                assert!(matches!(tx.reserve().await, Err(ReserveError)));
            })
        };
        drop(rx);
        assert!(matches!(tx.reserve().await, Err(ReserveError)));
        other.await.expect("no panic");
    }

    #[test]
    fn try_reserve_errors() {
        let (tx, rx) = channel::<i32>(2);
        // Hold both permits so that the channel stays full.
        let _first = tx.try_reserve().expect("reserved 1");
        let _second = tx.try_reserve().expect("reserved 2");
        for _ in 0..2 {
            assert!(matches!(tx.try_reserve(), Err(TryReserveError::Full)));
        }
        drop(rx);
        assert!(matches!(
            tx.try_reserve(),
            Err(TryReserveError::Disconnected)
        ));
    }

    #[tokio::test]
    async fn recv_future_awoken_but_unused() {
        let (tx, rx) = channel::<i32>(1);
        let mut pending_recv = Box::pin(rx.recv());
        let rx2 = rx.clone();
        // Poll the first recv future once (nothing has been sent, so it must not complete) and
        // keep it alive afterwards.
        tokio::select! {
            biased;
            _ = &mut pending_recv => {
                panic!("unexpected recv");
            }
            _ = ReadyFuture {} => {}
        }
        let second = tokio::spawn(async move { rx2.recv().await });
        // Yield the current task so that the task above gets a chance to start.
        tokio::time::sleep(Duration::ZERO).await;
        tx.try_send(1).expect("sent");
        // Without this drop the test would hang: the first recv future would be the one that is
        // awoken, yet nobody awaits it. This is the main flaw of waking only a single future at
        // a time. Waking all of them would avoid it, but would most likely degrade performance
        // due to lock contention.
        drop(pending_recv);
        let received = second.await.expect("no panic").expect("received");
        assert_eq!(received, 1);
    }

    /// Dropping an unused permit makes room for a pending `send`.
    #[tokio::test]
    async fn try_reserve_unused_permit_and_send() {
        let (tx, rx) = channel::<i32>(1);
        let unused = tx.try_reserve().expect("reserved");
        let pending = {
            let tx = tx.clone();
            tokio::spawn(async move { tx.send(1).await })
        };
        drop(unused);
        pending.await.expect("no panic").expect("sent");
        assert_eq!(rx.try_recv().expect("recv"), 1);
    }

    /// Dropping an unused permit makes room for a pending `reserve`.
    #[tokio::test]
    async fn try_reserve_unused_permit_and_other_permit() {
        let (tx, rx) = channel::<i32>(1);
        let unused = tx.try_reserve().expect("reserved");
        let pending = {
            let tx = tx.clone();
            tokio::spawn(async move { tx.reserve().await.expect("reserved").send(1) })
        };
        drop(unused);
        pending.await.expect("no panic");
        assert_eq!(rx.try_recv().expect("recv"), 1);
    }

    #[tokio::test]
    async fn receiver_close_all() {
        let (tx, rx1) = channel::<i32>(3);
        let rx2 = rx1.clone();
        let first_permit = tx.reserve().await.unwrap();
        let second_permit = tx.reserve().await.unwrap();
        tx.send(1).await.unwrap();
        rx1.close_all();
        // The two reserved spots are ahead of the buffered message, so nothing is readable yet.
        assert_no_recv(&rx1).await;
        assert_no_recv(&rx2).await;
        first_permit.send(2);
        second_permit.send(3);
        assert_eq!(rx1.recv().await.unwrap(), 2);
        assert_eq!(rx2.try_recv().unwrap(), 3);
        assert_eq!(rx1.recv().await.unwrap(), 1);
        assert_eq!(rx1.recv().await, Err(RecvError));
        assert_eq!(rx2.recv().await, Err(RecvError));
        assert!(matches!(tx.send(3).await, Err(SendError(3))));
    }

    #[tokio::test]
    async fn receiver_close_all_permit_drop() {
        let (tx, rx) = channel::<i32>(3);
        let outstanding = tx.reserve().await.unwrap();
        rx.close_all();
        assert_no_recv(&rx).await;
        drop(outstanding);
        assert_eq!(rx.recv().await, Err(RecvError));
    }

    #[tokio::test]
    async fn reserve_owned() {
        let (tx, rx) = channel::<usize>(4);
        let tx = tx.reserve_owned().await.unwrap().send(1);
        let tx = tx.reserve_owned().await.unwrap().send(2);
        let tx = tx.try_reserve_owned().unwrap().send(3);
        // A released permit frees its spot again, so the fourth message still fits.
        let tx = tx.try_reserve_owned().unwrap().release();
        let tx = tx.try_reserve_owned().unwrap().send(4);
        assert!(matches!(
            tx.clone().try_reserve_owned(),
            Err(TrySendError::Full(_))
        ));
        for expected in 1..=4 {
            assert_eq!(rx.try_recv().unwrap(), expected);
        }
        drop(rx);
        assert!(matches!(
            tx.clone().reserve_owned().await,
            Err(ReserveError)
        ));
        assert!(matches!(
            tx.try_reserve_owned(),
            Err(TrySendError::Disconnected(_))
        ));
    }

    #[tokio::test]
    async fn reserve_many() {
        let (tx, rx) = channel::<usize>(10);
        let first_half = tx.reserve_many(5).await.unwrap();
        let second_half = tx.try_reserve_many(5).unwrap();
        assert!(matches!(tx.try_send(11), Err(TrySendError::Full(11))));
        // Fill the later reservation first; messages must still come out in reservation order.
        for (permit, value) in second_half.zip(5..) {
            permit.send(value);
        }
        for (permit, value) in first_half.zip(0..) {
            permit.send(value);
        }
        for expected in 0..10 {
            assert_eq!(rx.try_recv().unwrap(), expected);
        }
    }

    #[tokio::test]
    async fn reserve_many_drop() {
        let (tx, _rx) = channel::<usize>(2);
        let permits = tx.reserve_many(2).await.unwrap();
        drop(permits);
        for value in [1, 2] {
            tx.try_send(value).unwrap();
        }
        assert!(matches!(tx.try_send(3), Err(TrySendError::Full(3))));
    }

    #[tokio::test]
    async fn reserve_many_drop_halfway() {
        let (tx, _rx) = channel::<usize>(4);
        let mut permits = tx.reserve_many(4).await.unwrap();
        for value in [1, 2] {
            permits.next().unwrap().send(value);
        }
        drop(permits);
        for value in [3, 4] {
            tx.try_send(value).unwrap();
        }
        assert!(matches!(tx.try_send(5), Err(TrySendError::Full(5))));
    }

    /// Panics if `rx.recv()` completes before a zero-length sleep does.
    async fn assert_no_recv<T>(rx: &Receiver<T>)
    where
        T: Debug,
    {
        tokio::select! {
            result = rx.recv() => {
                panic!("unexpected recv: {result:?}");
            },
            _ = tokio::time::sleep(Duration::ZERO) => {},
        }
    }

    /// A future that is ready on its first poll.
    struct ReadyFuture {}

    impl Future for ReadyFuture {
        type Output = ();

        fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Self::Output> {
            Poll::Ready(())
        }
    }
}