async_priority_channel/
lib.rs

1//! An async channel where pending messages are delivered in order of priority.
2//!
3//! There are two kinds of channels:
4//!
5//! 1. [Bounded][`bounded()`] channel with limited capacity.
6//! 2. [Unbounded][`unbounded()`] channel with unlimited capacity.
7//!
8//! A channel has the [`Sender`] and [`Receiver`] side. Both sides are cloneable and can be shared
9//! among multiple threads. When [sending][`Sender::send()`], you pass in a message and its
10//! priority. When [receiving][`Receiver::recv()`], you'll get back the pending message with the
//! highest priority.
12//!
13//! When all [`Sender`]s or all [`Receiver`]s are dropped, the channel becomes closed. When a
14//! channel is closed, no more messages can be sent, but remaining messages can still be received.
15//!
16//! The channel can also be closed manually by calling [`Sender::close()`] or
//! [`Receiver::close()`]. The API and much of the documentation are based on [async_channel](https://docs.rs/async-channel/1.6.1/async_channel/).
18//!
19//! # Examples
20//!
21//! ```
22//! # futures_lite::future::block_on(async {
23//! let (s, r) = async_priority_channel::unbounded();
24//!
25//! assert_eq!(s.send("Foo", 0).await, Ok(()));
26//! assert_eq!(s.send("Bar", 2).await, Ok(()));
27//! assert_eq!(s.send("Baz", 1).await, Ok(()));
28//! assert_eq!(r.recv().await, Ok(("Bar", 2)));
29//! # });
30//! ```
31
32#![forbid(unsafe_code)]
33#![warn(missing_docs, missing_debug_implementations, rust_2018_idioms)]
34
35mod awaitable_atomics;
36
37use awaitable_atomics::AwaitableAtomicCounterAndBit;
38use std::{
39    collections::BinaryHeap,
40    convert::TryInto,
41    error, fmt,
42    iter::Peekable,
43    sync::{
44        atomic::{AtomicUsize, Ordering},
45        Arc, Mutex,
46    },
47};
48
49/// Creates a bounded channel.
50///
51/// The created channel has space to hold at most `cap` messages at a time.
52///
53/// # Panics
54///
55/// Capacity must be a positive number. If `cap` is zero, this function will panic.
56///
57/// # Examples
58///
59/// ```
60/// # futures_lite::future::block_on(async {
61/// let (s, r) = async_priority_channel::bounded(1);
62///
63/// assert_eq!(s.send("Foo", 0).await, Ok(()));
64/// assert_eq!(r.recv().await, Ok(("Foo", 0)));
65/// # });
66/// ```
67pub fn bounded<I, P>(cap: u64) -> (Sender<I, P>, Receiver<I, P>)
68where
69    P: Ord,
70{
71    if cap == 0 {
72        panic!("cap must be positive");
73    }
74
75    let channel = Arc::new(PriorityQueueChannel {
76        heap: Mutex::new(BinaryHeap::new()),
77        len_and_closed: AwaitableAtomicCounterAndBit::new(0),
78        cap,
79        sender_count: AtomicUsize::new(1),
80        receiver_count: AtomicUsize::new(1),
81    });
82    let s = Sender {
83        channel: channel.clone(),
84    };
85    let r = Receiver { channel };
86    (s, r)
87}
88
89/// Creates an unbounded channel.
90///
91/// The created channel can hold an unlimited number of messages.
92///
93/// # Examples
94///
95/// ```
96/// # futures_lite::future::block_on(async {
97/// let (s, r) = async_priority_channel::unbounded();
98///
99/// assert_eq!(s.send("Foo", 0).await, Ok(()));
100/// assert_eq!(s.send("Bar", 2).await, Ok(()));
101/// assert_eq!(s.send("Baz", 1).await, Ok(()));
102/// assert_eq!(r.recv().await, Ok(("Bar", 2)));
103/// # });
104/// ```
105pub fn unbounded<I, P>() -> (Sender<I, P>, Receiver<I, P>)
106where
107    P: Ord,
108{
109    bounded(u64::MAX)
110}
111
112#[derive(Debug)]
113struct PriorityQueueChannel<I, P>
114where
115    P: Ord,
116{
117    // the data that needs to be maintained under a mutex
118    heap: Mutex<BinaryHeap<Item<I, P>>>,
119
120    // number of items in the channel, and is the channel closed,
121    // all accessible without holding the mutex?
122    len_and_closed: AwaitableAtomicCounterAndBit,
123
124    // capacity = 0 means unbounded, otherwise the bound.
125    cap: u64,
126
127    sender_count: AtomicUsize,
128    receiver_count: AtomicUsize,
129}
130
131#[derive(Debug)]
132/// Send side of the channel. Can be cloned.
133pub struct Sender<I, P>
134where
135    P: Ord,
136{
137    channel: Arc<PriorityQueueChannel<I, P>>,
138}
139
140#[derive(Debug)]
141/// Receive side of the channel. Can be cloned.
142pub struct Receiver<I, P>
143where
144    P: Ord,
145{
146    channel: Arc<PriorityQueueChannel<I, P>>,
147}
148
149impl<I, P> Drop for Sender<I, P>
150where
151    P: Ord,
152{
153    fn drop(&mut self) {
154        // Decrement the sender count and close the channel if it drops down to zero.
155        if self.channel.sender_count.fetch_sub(1, Ordering::AcqRel) == 1 {
156            self.channel.close();
157        }
158    }
159}
160
161impl<I, P> Drop for Receiver<I, P>
162where
163    P: Ord,
164{
165    fn drop(&mut self) {
166        // Decrement the receiver count and close the channel if it drops down to zero.
167        if self.channel.receiver_count.fetch_sub(1, Ordering::AcqRel) == 1 {
168            self.channel.close();
169        }
170    }
171}
172
173impl<I, P> Clone for Sender<I, P>
174where
175    P: Ord,
176{
177    fn clone(&self) -> Sender<I, P> {
178        let count = self.channel.sender_count.fetch_add(1, Ordering::Relaxed);
179
180        // Make sure the count never overflows, even if lots of sender clones are leaked.
181        if count > usize::MAX / 2 {
182            panic!("bailing due to possible overflow");
183        }
184
185        Sender {
186            channel: self.channel.clone(),
187        }
188    }
189}
190
191impl<I, P> Clone for Receiver<I, P>
192where
193    P: Ord,
194{
195    fn clone(&self) -> Receiver<I, P> {
196        let count = self.channel.receiver_count.fetch_add(1, Ordering::Relaxed);
197
198        // Make sure the count never overflows, even if lots of sender clones are leaked.
199        if count > usize::MAX / 2 {
200            panic!("bailing due to possible overflow");
201        }
202
203        Receiver {
204            channel: self.channel.clone(),
205        }
206    }
207}
208
209impl<I, P> PriorityQueueChannel<I, P>
210where
211    P: Ord,
212{
213    /// Closes the channel and notifies all blocked operations.
214    ///
215    /// Returns `true` if this call has closed the channel and it was not closed already.
216    ///
217    fn close(&self) -> bool {
218        let was_closed = self.len_and_closed.set_bit();
219        !was_closed
220    }
221
222    // Return `true` if the channel is closed
223    fn is_closed(&self) -> bool {
224        self.len_and_closed.load().0
225    }
226
227    /// Return `true` if the channel is empty
228    fn is_empty(&self) -> bool {
229        self.len() == 0
230    }
231
232    /// Return `true` if the channel is full
233    fn is_full(&self) -> bool {
234        self.cap > 0 && self.len() == self.cap
235    }
236
237    /// Returns the number of messages in the channel.
238    fn len(&self) -> u64 {
239        self.len_and_closed.load().1
240    }
241
242    fn len_and_closed(&self) -> (bool, u64) {
243        self.len_and_closed.load()
244    }
245}
246
247impl<T, P> Sender<T, P>
248where
249    P: Ord,
250{
251    /// Attempts to send a message into the channel.
252    ///
253    /// If the channel is full or closed, this method returns an error.
254    ///
255    pub fn try_send(&self, msg: T, priority: P) -> Result<(), TrySendError<(T, P)>> {
256        self.try_sendv(std::iter::once((msg, priority)).peekable())
257            .map_err(|e| match e {
258                TrySendError::Closed(mut value) => TrySendError::Closed(value.next().expect("foo")),
259                TrySendError::Full(mut value) => TrySendError::Full(value.next().expect("foo")),
260            })
261    }
262
263    /// Attempts to send multiple messages into the channel.
264    ///
265    /// If the channel is closed, this method returns an error.
266    ///
267    /// If the channel is full or nearly full, this method inserts as many messages
268    /// as it can into the channel and then returns an error containing the
269    /// remaining unsent messages.
270    pub fn try_sendv<I>(&self, msgs: Peekable<I>) -> Result<(), TrySendError<Peekable<I>>>
271    where
272        I: Iterator<Item = (T, P)>,
273    {
274        let mut msgs = msgs;
275        let (is_closed, len) = self.channel.len_and_closed();
276        if is_closed {
277            return Err(TrySendError::Closed(msgs));
278        }
279        if len > self.channel.cap {
280            panic!("size of channel is larger than capacity. this must indicate a bug");
281        }
282
283        match len == self.channel.cap {
284            true => Err(TrySendError::Full(msgs)),
285            false => {
286                // we're below capacity according to the atomic len() field.
287                // but it's possible that two threads will get here at the same time
288                // because we haven't acquired the lock yet, so lets acquire the lock
289                // and only let one through
290                let mut heap = self
291                    .channel
292                    .heap
293                    .lock()
294                    .expect("task panicked while holding lock");
295                let mut n = 0;
296                loop {
297                    if heap.len().try_into().unwrap_or(u64::MAX) < self.channel.cap {
298                        if let Some((msg, priority)) = msgs.next() {
299                            heap.push(Item { msg, priority });
300                            n += 1;
301                        } else {
302                            break;
303                        }
304                    } else {
305                        self.channel.len_and_closed.incr(n);
306                        return match msgs.peek() {
307                            Some(_) => Err(TrySendError::Full(msgs)),
308                            None => Ok(()),
309                        };
310                    }
311                }
312                self.channel.len_and_closed.incr(n);
313                Ok(())
314            }
315        }
316    }
317
318    /// Sends a message into the channel.
319    ///
320    /// If the channel is full, this method waits until there is space for a message.
321    ///
322    /// If the channel is closed, this method returns an error.
323    ///
324    pub async fn send(&self, msg: T, priority: P) -> Result<(), SendError<(T, P)>> {
325        let mut msg2 = msg;
326        let mut priority2 = priority;
327        loop {
328            let decr_listener = self.channel.len_and_closed.listen_decr();
329            match self.try_send(msg2, priority2) {
330                Ok(_) => {
331                    return Ok(());
332                }
333                Err(TrySendError::Full((msg, priority))) => {
334                    msg2 = msg;
335                    priority2 = priority;
336                    decr_listener.await;
337                }
338                Err(TrySendError::Closed((msg, priority))) => {
339                    return Err(SendError((msg, priority)));
340                }
341            }
342        }
343    }
344
345    /// Send multiple messages into the channel
346    ///
347    /// If the channel is full, this method waits until there is space.
348    ///
349    /// If the channel is closed, this method returns an error.
350    pub async fn sendv<I>(&self, msgs: Peekable<I>) -> Result<(), SendError<Peekable<I>>>
351    where
352        I: Iterator<Item = (T, P)>,
353    {
354        let mut msgs2 = msgs;
355        loop {
356            let decr_listener = self.channel.len_and_closed.listen_decr();
357            match self.try_sendv(msgs2) {
358                Ok(_) => {
359                    return Ok(());
360                }
361                Err(TrySendError::Full(msgs)) => {
362                    msgs2 = msgs;
363                    decr_listener.await;
364                }
365                Err(TrySendError::Closed(msgs)) => {
366                    return Err(SendError(msgs));
367                }
368            }
369        }
370    }
371
372    /// Closes the channel and notifies all blocked operations.
373    ///
374    /// Returns `true` if this call has closed the channel and it was not closed already.
375    ///
376    pub fn close(&self) -> bool {
377        self.channel.close()
378    }
379
380    /// Returns `true` if the channel is closed
381    pub fn is_closed(&self) -> bool {
382        self.channel.is_closed()
383    }
384
385    /// Return `true` if the channel is empty
386    pub fn is_empty(&self) -> bool {
387        self.channel.is_empty()
388    }
389
390    /// Return `true` if the channel is full
391    pub fn is_full(&self) -> bool {
392        self.channel.is_full()
393    }
394
395    /// Returns the number of messages in the channel.
396    pub fn len(&self) -> u64 {
397        self.channel.len()
398    }
399
400    /// Returns the channel capacity if it's bounded.
401    pub fn capacity(&self) -> Option<u64> {
402        match self.channel.cap {
403            u64::MAX => None,
404            c => Some(c),
405        }
406    }
407
408    /// Returns the number of receivers for the channel.
409    pub fn receiver_count(&self) -> usize {
410        self.channel.receiver_count.load(Ordering::SeqCst)
411    }
412
413    /// Returns the number of senders for the channel.
414    pub fn sender_count(&self) -> usize {
415        self.channel.sender_count.load(Ordering::SeqCst)
416    }
417}
418
419impl<I, P> Receiver<I, P>
420where
421    P: Ord,
422{
423    /// Attempts to receive a message from the channel.
424    ///
425    /// If the channel is empty or closed, this method returns an error.
426    ///
427    pub fn try_recv(&self) -> Result<(I, P), TryRecvError> {
428        match (self.channel.is_empty(), self.channel.is_closed()) {
429            (true, true) => Err(TryRecvError::Closed),
430            (true, false) => Err(TryRecvError::Empty),
431            (false, _) => {
432                // channel contains items and is either open or closed
433                let mut heap = self
434                    .channel
435                    .heap
436                    .lock()
437                    .expect("task panicked while holding lock");
438                let item = heap.pop();
439                match item {
440                    Some(item) => {
441                        self.channel.len_and_closed.decr();
442                        Ok((item.msg, item.priority))
443                    }
444                    None => Err(TryRecvError::Empty),
445                }
446            }
447        }
448    }
449
450    /// Receives a message from the channel.
451    ///
452    /// If the channel is empty, this method waits until there is a message.
453    ///
454    /// If the channel is closed, this method receives a message or returns an error if there are
455    /// no more messages.
456    pub async fn recv(&self) -> Result<(I, P), RecvError> {
457        loop {
458            let incr_listener = self.channel.len_and_closed.listen_incr();
459            match self.try_recv() {
460                Ok(item) => {
461                    return Ok(item);
462                }
463                Err(TryRecvError::Closed) => {
464                    return Err(RecvError);
465                }
466                Err(TryRecvError::Empty) => {
467                    incr_listener.await;
468                }
469            }
470        }
471    }
472
473    /// Closes the channel and notifies all blocked operations.
474    ///
475    /// Returns `true` if this call has closed the channel and it was not closed already.
476    ///
477    pub fn close(&self) -> bool {
478        self.channel.close()
479    }
480
481    /// Returns whether the channel is closed
482    pub fn is_closed(&self) -> bool {
483        self.channel.is_closed()
484    }
485
486    /// Return `true` if the channel is empty
487    pub fn is_empty(&self) -> bool {
488        self.channel.is_empty()
489    }
490
491    /// Return `true` if the channel is full
492    pub fn is_full(&self) -> bool {
493        self.channel.is_full()
494    }
495
496    /// Returns the number of messages in the channel.
497    pub fn len(&self) -> u64 {
498        self.channel.len()
499    }
500
501    /// Returns the channel capacity if it's bounded.
502    pub fn capacity(&self) -> Option<u64> {
503        match self.channel.cap {
504            u64::MAX => None,
505            c => Some(c),
506        }
507    }
508
509    /// Returns the number of receivers for the channel.
510    pub fn receiver_count(&self) -> usize {
511        self.channel.receiver_count.load(Ordering::SeqCst)
512    }
513
514    /// Returns the number of senders for the channel.
515    pub fn sender_count(&self) -> usize {
516        self.channel.sender_count.load(Ordering::SeqCst)
517    }
518}
519
/// Private pairing of a message with its priority that compares by priority only.
///
/// The message never takes part in comparisons, so `I` needs no trait bounds.
/// `BinaryHeap` is a max-heap, so the item with the greatest priority is popped
/// first. The relative order of items with equal priority is unspecified.
#[derive(Debug)]
struct Item<I, P>
where
    P: Eq + Ord,
{
    msg: I,
    priority: P,
}

impl<I, P> Ord for Item<I, P>
where
    P: Eq + Ord,
{
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        self.priority.cmp(&other.priority)
    }
}

impl<I, P> PartialOrd for Item<I, P>
where
    P: Eq + Ord,
{
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

// Equality must agree with `Ord`, so it also ignores `msg`.
impl<I, P> PartialEq for Item<I, P>
where
    P: Eq + Ord,
{
    fn eq(&self, other: &Self) -> bool {
        self.priority == other.priority
    }
}

impl<I, P> Eq for Item<I, P> where P: Eq + Ord {}
558
/// An error returned from [`Sender::send()`] and [`Sender::sendv()`].
///
/// Returned because the channel is closed. The unsent message(s) are kept in the
/// public field and can be recovered with [`SendError::into_inner()`].
#[derive(PartialEq, Eq, Clone, Copy)]
pub struct SendError<T>(pub T);

impl<T> SendError<T> {
    /// Consumes the error, returning the message that couldn't be sent.
    pub fn into_inner(self) -> T {
        let SendError(unsent) = self;
        unsent
    }
}

impl<T> error::Error for SendError<T> {}

// The payload is deliberately left out so that `T` does not need to implement `Debug`.
impl<T> fmt::Debug for SendError<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("SendError(..)")
    }
}

impl<T> fmt::Display for SendError<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("sending into a closed channel")
    }
}
585
/// An error returned from [`Receiver::recv()`].
///
/// Returned because the channel is both empty and closed.
#[derive(PartialEq, Eq, Clone, Copy, Debug)]
pub struct RecvError;

impl error::Error for RecvError {}

impl fmt::Display for RecvError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("receiving from an empty and closed channel")
    }
}
599
/// An error returned from [`Sender::try_send()`] and [`Sender::try_sendv()`].
///
/// Either variant carries back whatever could not be sent.
#[derive(PartialEq, Eq, Clone, Copy)]
pub enum TrySendError<T> {
    /// The channel is full but not closed.
    Full(T),

    /// The channel is closed.
    Closed(T),
}

impl<T> TrySendError<T> {
    /// Unwraps the message that couldn't be sent.
    pub fn into_inner(self) -> T {
        match self {
            TrySendError::Full(unsent) | TrySendError::Closed(unsent) => unsent,
        }
    }

    /// Returns `true` if the channel is full but not closed.
    pub fn is_full(&self) -> bool {
        matches!(self, TrySendError::Full(_))
    }

    /// Returns `true` if the channel is closed.
    pub fn is_closed(&self) -> bool {
        matches!(self, TrySendError::Closed(_))
    }
}

impl<T> error::Error for TrySendError<T> {}

// The payload is deliberately left out so that `T` does not need to implement `Debug`.
impl<T> fmt::Debug for TrySendError<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let text = if self.is_full() { "Full(..)" } else { "Closed(..)" };
        f.write_str(text)
    }
}

impl<T> fmt::Display for TrySendError<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let text = if self.is_full() {
            "sending into a full channel"
        } else {
            "sending into a closed channel"
        };
        f.write_str(text)
    }
}
655
/// An error returned from [`Receiver::try_recv()`].
///
/// Tells apart "nothing yet" from "nothing ever again".
#[derive(PartialEq, Eq, Clone, Copy, Debug)]
pub enum TryRecvError {
    /// The channel is empty but not closed.
    Empty,

    /// The channel is empty and closed.
    Closed,
}

impl TryRecvError {
    /// Returns `true` if the channel is empty but not closed.
    pub fn is_empty(&self) -> bool {
        matches!(self, TryRecvError::Empty)
    }

    /// Returns `true` if the channel is empty and closed.
    pub fn is_closed(&self) -> bool {
        matches!(self, TryRecvError::Closed)
    }
}

impl error::Error for TryRecvError {}

impl fmt::Display for TryRecvError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let text = if self.is_empty() {
            "receiving from an empty channel"
        } else {
            "receiving from an empty and closed channel"
        };
        f.write_str(text)
    }
}