1#![forbid(unsafe_code)]
33#![warn(missing_docs, missing_debug_implementations, rust_2018_idioms)]
34
35mod awaitable_atomics;
36
37use awaitable_atomics::AwaitableAtomicCounterAndBit;
38use std::{
39 collections::BinaryHeap,
40 convert::TryInto,
41 error, fmt,
42 iter::Peekable,
43 sync::{
44 atomic::{AtomicUsize, Ordering},
45 Arc, Mutex,
46 },
47};
48
49pub fn bounded<I, P>(cap: u64) -> (Sender<I, P>, Receiver<I, P>)
68where
69 P: Ord,
70{
71 if cap == 0 {
72 panic!("cap must be positive");
73 }
74
75 let channel = Arc::new(PriorityQueueChannel {
76 heap: Mutex::new(BinaryHeap::new()),
77 len_and_closed: AwaitableAtomicCounterAndBit::new(0),
78 cap,
79 sender_count: AtomicUsize::new(1),
80 receiver_count: AtomicUsize::new(1),
81 });
82 let s = Sender {
83 channel: channel.clone(),
84 };
85 let r = Receiver { channel };
86 (s, r)
87}
88
89pub fn unbounded<I, P>() -> (Sender<I, P>, Receiver<I, P>)
106where
107 P: Ord,
108{
109 bounded(u64::MAX)
110}
111
112#[derive(Debug)]
113struct PriorityQueueChannel<I, P>
114where
115 P: Ord,
116{
117 heap: Mutex<BinaryHeap<Item<I, P>>>,
119
120 len_and_closed: AwaitableAtomicCounterAndBit,
123
124 cap: u64,
126
127 sender_count: AtomicUsize,
128 receiver_count: AtomicUsize,
129}
130
131#[derive(Debug)]
132pub struct Sender<I, P>
134where
135 P: Ord,
136{
137 channel: Arc<PriorityQueueChannel<I, P>>,
138}
139
140#[derive(Debug)]
141pub struct Receiver<I, P>
143where
144 P: Ord,
145{
146 channel: Arc<PriorityQueueChannel<I, P>>,
147}
148
149impl<I, P> Drop for Sender<I, P>
150where
151 P: Ord,
152{
153 fn drop(&mut self) {
154 if self.channel.sender_count.fetch_sub(1, Ordering::AcqRel) == 1 {
156 self.channel.close();
157 }
158 }
159}
160
161impl<I, P> Drop for Receiver<I, P>
162where
163 P: Ord,
164{
165 fn drop(&mut self) {
166 if self.channel.receiver_count.fetch_sub(1, Ordering::AcqRel) == 1 {
168 self.channel.close();
169 }
170 }
171}
172
173impl<I, P> Clone for Sender<I, P>
174where
175 P: Ord,
176{
177 fn clone(&self) -> Sender<I, P> {
178 let count = self.channel.sender_count.fetch_add(1, Ordering::Relaxed);
179
180 if count > usize::MAX / 2 {
182 panic!("bailing due to possible overflow");
183 }
184
185 Sender {
186 channel: self.channel.clone(),
187 }
188 }
189}
190
191impl<I, P> Clone for Receiver<I, P>
192where
193 P: Ord,
194{
195 fn clone(&self) -> Receiver<I, P> {
196 let count = self.channel.receiver_count.fetch_add(1, Ordering::Relaxed);
197
198 if count > usize::MAX / 2 {
200 panic!("bailing due to possible overflow");
201 }
202
203 Receiver {
204 channel: self.channel.clone(),
205 }
206 }
207}
208
209impl<I, P> PriorityQueueChannel<I, P>
210where
211 P: Ord,
212{
213 fn close(&self) -> bool {
218 let was_closed = self.len_and_closed.set_bit();
219 !was_closed
220 }
221
222 fn is_closed(&self) -> bool {
224 self.len_and_closed.load().0
225 }
226
227 fn is_empty(&self) -> bool {
229 self.len() == 0
230 }
231
232 fn is_full(&self) -> bool {
234 self.cap > 0 && self.len() == self.cap
235 }
236
237 fn len(&self) -> u64 {
239 self.len_and_closed.load().1
240 }
241
242 fn len_and_closed(&self) -> (bool, u64) {
243 self.len_and_closed.load()
244 }
245}
246
247impl<T, P> Sender<T, P>
248where
249 P: Ord,
250{
251 pub fn try_send(&self, msg: T, priority: P) -> Result<(), TrySendError<(T, P)>> {
256 self.try_sendv(std::iter::once((msg, priority)).peekable())
257 .map_err(|e| match e {
258 TrySendError::Closed(mut value) => TrySendError::Closed(value.next().expect("foo")),
259 TrySendError::Full(mut value) => TrySendError::Full(value.next().expect("foo")),
260 })
261 }
262
263 pub fn try_sendv<I>(&self, msgs: Peekable<I>) -> Result<(), TrySendError<Peekable<I>>>
271 where
272 I: Iterator<Item = (T, P)>,
273 {
274 let mut msgs = msgs;
275 let (is_closed, len) = self.channel.len_and_closed();
276 if is_closed {
277 return Err(TrySendError::Closed(msgs));
278 }
279 if len > self.channel.cap {
280 panic!("size of channel is larger than capacity. this must indicate a bug");
281 }
282
283 match len == self.channel.cap {
284 true => Err(TrySendError::Full(msgs)),
285 false => {
286 let mut heap = self
291 .channel
292 .heap
293 .lock()
294 .expect("task panicked while holding lock");
295 let mut n = 0;
296 loop {
297 if heap.len().try_into().unwrap_or(u64::MAX) < self.channel.cap {
298 if let Some((msg, priority)) = msgs.next() {
299 heap.push(Item { msg, priority });
300 n += 1;
301 } else {
302 break;
303 }
304 } else {
305 self.channel.len_and_closed.incr(n);
306 return match msgs.peek() {
307 Some(_) => Err(TrySendError::Full(msgs)),
308 None => Ok(()),
309 };
310 }
311 }
312 self.channel.len_and_closed.incr(n);
313 Ok(())
314 }
315 }
316 }
317
318 pub async fn send(&self, msg: T, priority: P) -> Result<(), SendError<(T, P)>> {
325 let mut msg2 = msg;
326 let mut priority2 = priority;
327 loop {
328 let decr_listener = self.channel.len_and_closed.listen_decr();
329 match self.try_send(msg2, priority2) {
330 Ok(_) => {
331 return Ok(());
332 }
333 Err(TrySendError::Full((msg, priority))) => {
334 msg2 = msg;
335 priority2 = priority;
336 decr_listener.await;
337 }
338 Err(TrySendError::Closed((msg, priority))) => {
339 return Err(SendError((msg, priority)));
340 }
341 }
342 }
343 }
344
345 pub async fn sendv<I>(&self, msgs: Peekable<I>) -> Result<(), SendError<Peekable<I>>>
351 where
352 I: Iterator<Item = (T, P)>,
353 {
354 let mut msgs2 = msgs;
355 loop {
356 let decr_listener = self.channel.len_and_closed.listen_decr();
357 match self.try_sendv(msgs2) {
358 Ok(_) => {
359 return Ok(());
360 }
361 Err(TrySendError::Full(msgs)) => {
362 msgs2 = msgs;
363 decr_listener.await;
364 }
365 Err(TrySendError::Closed(msgs)) => {
366 return Err(SendError(msgs));
367 }
368 }
369 }
370 }
371
372 pub fn close(&self) -> bool {
377 self.channel.close()
378 }
379
380 pub fn is_closed(&self) -> bool {
382 self.channel.is_closed()
383 }
384
385 pub fn is_empty(&self) -> bool {
387 self.channel.is_empty()
388 }
389
390 pub fn is_full(&self) -> bool {
392 self.channel.is_full()
393 }
394
395 pub fn len(&self) -> u64 {
397 self.channel.len()
398 }
399
400 pub fn capacity(&self) -> Option<u64> {
402 match self.channel.cap {
403 u64::MAX => None,
404 c => Some(c),
405 }
406 }
407
408 pub fn receiver_count(&self) -> usize {
410 self.channel.receiver_count.load(Ordering::SeqCst)
411 }
412
413 pub fn sender_count(&self) -> usize {
415 self.channel.sender_count.load(Ordering::SeqCst)
416 }
417}
418
419impl<I, P> Receiver<I, P>
420where
421 P: Ord,
422{
423 pub fn try_recv(&self) -> Result<(I, P), TryRecvError> {
428 match (self.channel.is_empty(), self.channel.is_closed()) {
429 (true, true) => Err(TryRecvError::Closed),
430 (true, false) => Err(TryRecvError::Empty),
431 (false, _) => {
432 let mut heap = self
434 .channel
435 .heap
436 .lock()
437 .expect("task panicked while holding lock");
438 let item = heap.pop();
439 match item {
440 Some(item) => {
441 self.channel.len_and_closed.decr();
442 Ok((item.msg, item.priority))
443 }
444 None => Err(TryRecvError::Empty),
445 }
446 }
447 }
448 }
449
450 pub async fn recv(&self) -> Result<(I, P), RecvError> {
457 loop {
458 let incr_listener = self.channel.len_and_closed.listen_incr();
459 match self.try_recv() {
460 Ok(item) => {
461 return Ok(item);
462 }
463 Err(TryRecvError::Closed) => {
464 return Err(RecvError);
465 }
466 Err(TryRecvError::Empty) => {
467 incr_listener.await;
468 }
469 }
470 }
471 }
472
473 pub fn close(&self) -> bool {
478 self.channel.close()
479 }
480
481 pub fn is_closed(&self) -> bool {
483 self.channel.is_closed()
484 }
485
486 pub fn is_empty(&self) -> bool {
488 self.channel.is_empty()
489 }
490
491 pub fn is_full(&self) -> bool {
493 self.channel.is_full()
494 }
495
496 pub fn len(&self) -> u64 {
498 self.channel.len()
499 }
500
501 pub fn capacity(&self) -> Option<u64> {
503 match self.channel.cap {
504 u64::MAX => None,
505 c => Some(c),
506 }
507 }
508
509 pub fn receiver_count(&self) -> usize {
511 self.channel.receiver_count.load(Ordering::SeqCst)
512 }
513
514 pub fn sender_count(&self) -> usize {
516 self.channel.sender_count.load(Ordering::SeqCst)
517 }
518}
519
/// A queued message paired with its priority.
///
/// Ordering and equality look only at `priority`; `msg` never takes part in
/// comparisons. A `BinaryHeap` of items therefore pops the greatest priority first.
#[derive(Debug)]
struct Item<I, P: Ord> {
    msg: I,
    priority: P,
}

impl<I, P: Ord> Ord for Item<I, P> {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        Ord::cmp(&self.priority, &other.priority)
    }
}

impl<I, P: Ord> PartialOrd for Item<I, P> {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        // Total order delegated to `Ord`, keeping the two impls consistent.
        Some(Ord::cmp(self, other))
    }
}

impl<I, P: Ord> PartialEq for Item<I, P> {
    fn eq(&self, other: &Self) -> bool {
        PartialEq::eq(&self.priority, &other.priority)
    }
}

impl<I, P: Ord> Eq for Item<I, P> {}
558
/// Error returned by `Sender::send` and `Sender::sendv` when the channel is closed.
///
/// The value that could not be sent is carried in the public field and can
/// also be recovered with [`SendError::into_inner`].
#[derive(PartialEq, Eq, Clone, Copy)]
pub struct SendError<T>(pub T);

impl<T> SendError<T> {
    /// Consumes the error and returns the value that was not sent.
    pub fn into_inner(self) -> T {
        let SendError(unsent) = self;
        unsent
    }
}

impl<T> error::Error for SendError<T> {}

// Written by hand so `T` need not implement `Debug`; the payload is elided.
impl<T> fmt::Debug for SendError<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("SendError(..)")
    }
}

impl<T> fmt::Display for SendError<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("sending into a closed channel")
    }
}
585
/// Error returned by `Receiver::recv` when the channel is closed and no item
/// is queued.
#[derive(PartialEq, Eq, Clone, Copy, Debug)]
pub struct RecvError;

impl error::Error for RecvError {}

impl fmt::Display for RecvError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("receiving from an empty and closed channel")
    }
}
599
/// Error returned by `Sender::try_send` and `Sender::try_sendv`.
///
/// Both variants carry back the value(s) that were not sent.
#[derive(PartialEq, Eq, Clone, Copy)]
pub enum TrySendError<T> {
    /// The channel was at capacity.
    Full(T),

    /// The channel was closed.
    Closed(T),
}

impl<T> TrySendError<T> {
    /// Consumes the error and returns the value(s) that were not sent.
    pub fn into_inner(self) -> T {
        match self {
            TrySendError::Full(unsent) | TrySendError::Closed(unsent) => unsent,
        }
    }

    /// Returns `true` if the send failed because the channel was full.
    pub fn is_full(&self) -> bool {
        matches!(self, TrySendError::Full(_))
    }

    /// Returns `true` if the send failed because the channel was closed.
    pub fn is_closed(&self) -> bool {
        matches!(self, TrySendError::Closed(_))
    }
}

impl<T> error::Error for TrySendError<T> {}

// Written by hand so `T` need not implement `Debug`; the payload is elided.
impl<T> fmt::Debug for TrySendError<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = if self.is_full() { "Full(..)" } else { "Closed(..)" };
        f.write_str(label)
    }
}

impl<T> fmt::Display for TrySendError<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(match self {
            TrySendError::Full(_) => "sending into a full channel",
            TrySendError::Closed(_) => "sending into a closed channel",
        })
    }
}
655
/// Error returned by `Receiver::try_recv`.
#[derive(PartialEq, Eq, Clone, Copy, Debug)]
pub enum TryRecvError {
    /// The channel had no queued item.
    Empty,

    /// The channel is closed and had no queued item.
    Closed,
}

impl TryRecvError {
    /// Returns `true` if this is the [`TryRecvError::Empty`] variant.
    pub fn is_empty(&self) -> bool {
        matches!(self, TryRecvError::Empty)
    }

    /// Returns `true` if this is the [`TryRecvError::Closed`] variant.
    pub fn is_closed(&self) -> bool {
        matches!(self, TryRecvError::Closed)
    }
}

impl error::Error for TryRecvError {}

impl fmt::Display for TryRecvError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let text = if self.is_empty() {
            "receiving from an empty channel"
        } else {
            "receiving from an empty and closed channel"
        };
        f.write_str(text)
    }
}