#![forbid(unsafe_code)]
#![warn(missing_docs, missing_debug_implementations, rust_2018_idioms)]
mod awaitable_atomics;
use awaitable_atomics::AwaitableAtomicCounterAndBit;
use std::{
collections::BinaryHeap,
convert::TryInto,
error, fmt,
iter::Peekable,
sync::{
atomic::{AtomicUsize, Ordering},
Arc, Mutex,
},
};
/// Creates a priority channel that holds at most `cap` queued messages.
///
/// Returns the sending and receiving halves. Receivers get the queued message
/// with the greatest priority first (the queue is a `BinaryHeap` ordered by
/// priority).
///
/// # Panics
///
/// Panics if `cap` is zero.
pub fn bounded<I, P>(cap: u64) -> (Sender<I, P>, Receiver<I, P>)
where
    P: Ord,
{
    assert!(cap != 0, "cap must be positive");
    let shared = Arc::new(PriorityQueueChannel {
        heap: Mutex::new(BinaryHeap::new()),
        len_and_closed: AwaitableAtomicCounterAndBit::new(0),
        cap,
        // Exactly one sender and one receiver handle are handed out below.
        sender_count: AtomicUsize::new(1),
        receiver_count: AtomicUsize::new(1),
    });
    let sender = Sender {
        channel: Arc::clone(&shared),
    };
    (sender, Receiver { channel: shared })
}
/// Creates a priority channel with no practical capacity limit.
///
/// This is [`bounded`] with a capacity of `u64::MAX`; handles of such a channel
/// report `None` from their `capacity()` method.
pub fn unbounded<I, P>() -> (Sender<I, P>, Receiver<I, P>)
where
    P: Ord,
{
    // `u64::MAX` doubles as the "unbounded" marker elsewhere in this file.
    const UNLIMITED: u64 = u64::MAX;
    bounded::<I, P>(UNLIMITED)
}
/// State shared by every `Sender` and `Receiver` of one channel.
#[derive(Debug)]
struct PriorityQueueChannel<I, P: Ord> {
    /// Queued messages, ordered by priority; guarded by a mutex.
    heap: Mutex<BinaryHeap<Item<I, P>>>,
    /// Number of queued messages together with the "closed" flag.
    len_and_closed: AwaitableAtomicCounterAndBit,
    /// Maximum number of queued messages; `u64::MAX` is used for unbounded channels.
    cap: u64,
    /// Live `Sender` handles; the last one to drop closes the channel.
    sender_count: AtomicUsize,
    /// Live `Receiver` handles; the last one to drop closes the channel.
    receiver_count: AtomicUsize,
}
/// The sending half of a priority channel.
///
/// Cloning a `Sender` creates another handle to the same channel; when the
/// last sender is dropped the channel is closed.
#[derive(Debug)]
pub struct Sender<I, P: Ord> {
    channel: Arc<PriorityQueueChannel<I, P>>,
}
/// The receiving half of a priority channel.
///
/// Cloning a `Receiver` creates another handle to the same channel; when the
/// last receiver is dropped the channel is closed.
#[derive(Debug)]
pub struct Receiver<I, P: Ord> {
    channel: Arc<PriorityQueueChannel<I, P>>,
}
impl<I, P> Drop for Sender<I, P>
where
    P: Ord,
{
    /// Releases this handle; the last sender to be dropped closes the channel.
    fn drop(&mut self) {
        let senders_before = self.channel.sender_count.fetch_sub(1, Ordering::AcqRel);
        let was_last = senders_before == 1;
        if was_last {
            self.channel.close();
        }
    }
}
impl<I, P> Drop for Receiver<I, P>
where
    P: Ord,
{
    /// Releases this handle; the last receiver to be dropped closes the channel.
    fn drop(&mut self) {
        let receivers_before = self.channel.receiver_count.fetch_sub(1, Ordering::AcqRel);
        let was_last = receivers_before == 1;
        if was_last {
            self.channel.close();
        }
    }
}
impl<I, P> Clone for Sender<I, P>
where
P: Ord,
{
fn clone(&self) -> Sender<I, P> {
let count = self.channel.sender_count.fetch_add(1, Ordering::Relaxed);
if count > usize::MAX / 2 {
panic!("bailing due to possible overflow");
}
Sender {
channel: self.channel.clone(),
}
}
}
impl<I, P> Clone for Receiver<I, P>
where
P: Ord,
{
fn clone(&self) -> Receiver<I, P> {
let count = self.channel.receiver_count.fetch_add(1, Ordering::Relaxed);
if count > usize::MAX / 2 {
panic!("bailing due to possible overflow");
}
Receiver {
channel: self.channel.clone(),
}
}
}
impl<I, P> PriorityQueueChannel<I, P>
where
    P: Ord,
{
    /// Sets the closed flag; returns `true` if the channel was not already closed.
    fn close(&self) -> bool {
        let previously_closed = self.len_and_closed.set_bit();
        !previously_closed
    }

    /// Whether the closed flag is set.
    fn is_closed(&self) -> bool {
        let (closed, _) = self.len_and_closed.load();
        closed
    }

    /// Whether no messages are currently queued.
    fn is_empty(&self) -> bool {
        let (_, queued) = self.len_and_closed.load();
        queued == 0
    }

    /// Whether the number of queued messages has reached the capacity.
    fn is_full(&self) -> bool {
        self.cap != 0 && self.len() == self.cap
    }

    /// Number of messages currently queued.
    fn len(&self) -> u64 {
        let (_, queued) = self.len_and_closed.load();
        queued
    }

    /// Returns `(closed, len)` from a single load of the combined counter.
    fn len_and_closed(&self) -> (bool, u64) {
        self.len_and_closed.load()
    }
}
impl<T, P> Sender<T, P>
where
    P: Ord,
{
    /// Attempts to queue `msg` with `priority` without waiting.
    ///
    /// # Errors
    ///
    /// Returns [`TrySendError::Closed`] if the channel is closed, or
    /// [`TrySendError::Full`] if it is at capacity. Either way the
    /// `(msg, priority)` pair is handed back.
    pub fn try_send(&self, msg: T, priority: P) -> Result<(), TrySendError<(T, P)>> {
        // `try_sendv` reports an error only while the iterator still holds an
        // unsent item, so the single item is always recoverable here.
        self.try_sendv(std::iter::once((msg, priority)).peekable())
            .map_err(|e| match e {
                TrySendError::Closed(mut rest) => TrySendError::Closed(
                    rest.next()
                        .expect("rejected single-item send still holds its item"),
                ),
                TrySendError::Full(mut rest) => TrySendError::Full(
                    rest.next()
                        .expect("rejected single-item send still holds its item"),
                ),
            })
    }

    /// Attempts to queue as many `(msg, priority)` pairs from `msgs` as fit,
    /// without waiting.
    ///
    /// Items are pushed until the iterator is exhausted or the channel fills up.
    ///
    /// # Errors
    ///
    /// Returns [`TrySendError::Closed`] with the untouched iterator if the channel
    /// is closed. Returns [`TrySendError::Full`] with the remaining items if the
    /// channel filled before the iterator was exhausted; items already pushed
    /// stay queued.
    ///
    /// # Panics
    ///
    /// Panics if the recorded length exceeds the capacity (an internal bug), or
    /// if the heap mutex is poisoned.
    pub fn try_sendv<I>(&self, mut msgs: Peekable<I>) -> Result<(), TrySendError<Peekable<I>>>
    where
        I: Iterator<Item = (T, P)>,
    {
        let (is_closed, len) = self.channel.len_and_closed();
        if is_closed {
            return Err(TrySendError::Closed(msgs));
        }
        if len > self.channel.cap {
            panic!("size of channel is larger than capacity. this must indicate a bug");
        }
        if len == self.channel.cap {
            return Err(TrySendError::Full(msgs));
        }

        let mut heap = self
            .channel
            .heap
            .lock()
            .expect("task panicked while holding lock");
        // Number of items pushed in this call; published to the counter while the
        // lock is still held.
        let mut n = 0;
        loop {
            if heap.len().try_into().unwrap_or(u64::MAX) >= self.channel.cap {
                self.channel.len_and_closed.incr(n);
                // Full only matters if something is still waiting to be sent.
                return if msgs.peek().is_some() {
                    Err(TrySendError::Full(msgs))
                } else {
                    Ok(())
                };
            }
            match msgs.next() {
                Some((msg, priority)) => {
                    heap.push(Item { msg, priority });
                    n += 1;
                }
                None => break,
            }
        }
        self.channel.len_and_closed.incr(n);
        Ok(())
    }

    /// Queues `msg` with `priority`, waiting while the channel is full.
    ///
    /// # Errors
    ///
    /// Returns [`SendError`] holding `(msg, priority)` if the channel is closed.
    pub async fn send(&self, mut msg: T, mut priority: P) -> Result<(), SendError<(T, P)>> {
        loop {
            // Registered before the attempt so that a receive happening between
            // the attempt and the `.await` is presumably not missed.
            let decr_listener = self.channel.len_and_closed.listen_decr();
            match self.try_send(msg, priority) {
                Ok(()) => return Ok(()),
                Err(TrySendError::Full((returned_msg, returned_priority))) => {
                    msg = returned_msg;
                    priority = returned_priority;
                    decr_listener.await;
                }
                Err(TrySendError::Closed(pair)) => return Err(SendError(pair)),
            }
        }
    }

    /// Queues every `(msg, priority)` pair from `msgs`, waiting whenever the
    /// channel is full.
    ///
    /// # Errors
    ///
    /// Returns [`SendError`] holding the not-yet-sent remainder of `msgs` if the
    /// channel is closed; items sent before that stay queued.
    pub async fn sendv<I>(&self, mut msgs: Peekable<I>) -> Result<(), SendError<Peekable<I>>>
    where
        I: Iterator<Item = (T, P)>,
    {
        loop {
            // Registered before the attempt, mirroring `send`.
            let decr_listener = self.channel.len_and_closed.listen_decr();
            match self.try_sendv(msgs) {
                Ok(()) => return Ok(()),
                Err(TrySendError::Full(rest)) => {
                    msgs = rest;
                    decr_listener.await;
                }
                Err(TrySendError::Closed(rest)) => return Err(SendError(rest)),
            }
        }
    }

    /// Closes the channel. Returns `true` if it was not already closed.
    pub fn close(&self) -> bool {
        self.channel.close()
    }

    /// Returns `true` if the channel is closed.
    pub fn is_closed(&self) -> bool {
        self.channel.is_closed()
    }

    /// Returns `true` if no messages are queued.
    pub fn is_empty(&self) -> bool {
        self.channel.is_empty()
    }

    /// Returns `true` if the number of queued messages equals the capacity.
    pub fn is_full(&self) -> bool {
        self.channel.is_full()
    }

    /// Returns the number of queued messages.
    pub fn len(&self) -> u64 {
        self.channel.len()
    }

    /// Returns the capacity, or `None` for an unbounded channel.
    pub fn capacity(&self) -> Option<u64> {
        match self.channel.cap {
            u64::MAX => None,
            c => Some(c),
        }
    }

    /// Returns the number of live receivers.
    pub fn receiver_count(&self) -> usize {
        self.channel.receiver_count.load(Ordering::SeqCst)
    }

    /// Returns the number of live senders.
    pub fn sender_count(&self) -> usize {
        self.channel.sender_count.load(Ordering::SeqCst)
    }
}
impl<I, P> Receiver<I, P>
where
    P: Ord,
{
    /// Attempts to take the highest-priority queued message without waiting.
    ///
    /// # Errors
    ///
    /// Returns [`TryRecvError::Closed`] if the channel is closed and empty, and
    /// [`TryRecvError::Empty`] if it is empty but still open.
    ///
    /// # Panics
    ///
    /// Panics if the heap mutex is poisoned.
    pub fn try_recv(&self) -> Result<(I, P), TryRecvError> {
        // Read length and closed flag from ONE snapshot. Two separate loads could
        // observe "empty" and then "closed" across a concurrent send-and-close,
        // wrongly reporting `Closed` while a message is queued.
        let (closed, len) = self.channel.len_and_closed();
        if len == 0 {
            return Err(if closed {
                TryRecvError::Closed
            } else {
                TryRecvError::Empty
            });
        }

        let mut heap = self
            .channel
            .heap
            .lock()
            .expect("task panicked while holding lock");
        match heap.pop() {
            Some(item) => {
                // Decremented while the lock is held, keeping the counter in step
                // with the heap.
                self.channel.len_and_closed.decr();
                Ok((item.msg, item.priority))
            }
            // Another receiver drained the heap after our snapshot. If the channel
            // is closed by now, report the terminal state instead of `Empty`, so
            // that `recv` does not wait for an item that may never arrive.
            None if self.channel.is_closed() => Err(TryRecvError::Closed),
            None => Err(TryRecvError::Empty),
        }
    }

    /// Takes the highest-priority message, waiting while the channel is empty.
    ///
    /// # Errors
    ///
    /// Returns [`RecvError`] once the channel is closed and empty.
    pub async fn recv(&self) -> Result<(I, P), RecvError> {
        loop {
            // Registered before the attempt so that a send happening between the
            // attempt and the `.await` is presumably not missed.
            let incr_listener = self.channel.len_and_closed.listen_incr();
            match self.try_recv() {
                Ok(item) => return Ok(item),
                Err(TryRecvError::Closed) => return Err(RecvError),
                Err(TryRecvError::Empty) => incr_listener.await,
            }
        }
    }

    /// Closes the channel. Returns `true` if it was not already closed.
    pub fn close(&self) -> bool {
        self.channel.close()
    }

    /// Returns `true` if the channel is closed.
    pub fn is_closed(&self) -> bool {
        self.channel.is_closed()
    }

    /// Returns `true` if no messages are queued.
    pub fn is_empty(&self) -> bool {
        self.channel.is_empty()
    }

    /// Returns `true` if the number of queued messages equals the capacity.
    pub fn is_full(&self) -> bool {
        self.channel.is_full()
    }

    /// Returns the number of queued messages.
    pub fn len(&self) -> u64 {
        self.channel.len()
    }

    /// Returns the capacity, or `None` for an unbounded channel.
    pub fn capacity(&self) -> Option<u64> {
        match self.channel.cap {
            u64::MAX => None,
            c => Some(c),
        }
    }

    /// Returns the number of live receivers.
    pub fn receiver_count(&self) -> usize {
        self.channel.receiver_count.load(Ordering::SeqCst)
    }

    /// Returns the number of live senders.
    pub fn sender_count(&self) -> usize {
        self.channel.sender_count.load(Ordering::SeqCst)
    }
}
/// A queued message paired with its priority.
///
/// Comparison and equality look only at `priority`; `msg` never takes part.
/// Because `BinaryHeap` is a max-heap, the item with the greatest priority is
/// popped first.
#[derive(Debug)]
struct Item<I, P: Ord> {
    msg: I,
    priority: P,
}

impl<I, P: Ord> Ord for Item<I, P> {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        Ord::cmp(&self.priority, &other.priority)
    }
}

impl<I, P: Ord> PartialOrd for Item<I, P> {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        // Delegate to the total order so the two can never disagree.
        Some(Ord::cmp(self, other))
    }
}

impl<I, P: Ord> PartialEq for Item<I, P> {
    fn eq(&self, other: &Self) -> bool {
        other.priority == self.priority
    }
}

impl<I, P: Ord> Eq for Item<I, P> {}
/// Error returned by the async send methods when the channel is closed.
///
/// Carries the value that could not be delivered.
#[derive(PartialEq, Eq, Clone, Copy)]
pub struct SendError<T>(
    /// The undelivered value.
    pub T,
);

impl<T> SendError<T> {
    /// Consumes the error and returns the undelivered value.
    pub fn into_inner(self) -> T {
        let SendError(undelivered) = self;
        undelivered
    }
}

impl<T> error::Error for SendError<T> {}

impl<T> fmt::Debug for SendError<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // No `T: Debug` bound, so the payload is not printed.
        f.write_str("SendError(..)")
    }
}

impl<T> fmt::Display for SendError<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("sending into a closed channel")
    }
}
/// Error returned by the async receive method when the channel is empty and closed.
#[derive(PartialEq, Eq, Clone, Copy, Debug)]
pub struct RecvError;

impl error::Error for RecvError {}

impl fmt::Display for RecvError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("receiving from an empty and closed channel")
    }
}
/// Error returned by the non-blocking send methods.
///
/// Both variants hand the rejected value back to the caller.
#[derive(PartialEq, Eq, Clone, Copy)]
pub enum TrySendError<T> {
    /// The channel was at capacity.
    Full(T),
    /// The channel was closed.
    Closed(T),
}

impl<T> TrySendError<T> {
    /// Consumes the error and returns the rejected value.
    pub fn into_inner(self) -> T {
        match self {
            TrySendError::Full(rejected) | TrySendError::Closed(rejected) => rejected,
        }
    }

    /// Returns `true` if the send failed because the channel was full.
    pub fn is_full(&self) -> bool {
        matches!(self, TrySendError::Full(_))
    }

    /// Returns `true` if the send failed because the channel was closed.
    pub fn is_closed(&self) -> bool {
        matches!(self, TrySendError::Closed(_))
    }
}

impl<T> error::Error for TrySendError<T> {}

impl<T> fmt::Debug for TrySendError<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // No `T: Debug` bound, so the payload is elided.
        let label = match self {
            TrySendError::Full(_) => "Full(..)",
            TrySendError::Closed(_) => "Closed(..)",
        };
        f.write_str(label)
    }
}

impl<T> fmt::Display for TrySendError<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = if self.is_full() {
            "sending into a full channel"
        } else {
            "sending into a closed channel"
        };
        f.write_str(message)
    }
}
/// Error returned by the non-blocking receive method.
#[derive(PartialEq, Eq, Clone, Copy, Debug)]
pub enum TryRecvError {
    /// No message could be taken from the channel.
    Empty,
    /// The channel is empty and closed.
    Closed,
}

impl TryRecvError {
    /// Returns `true` for [`TryRecvError::Empty`].
    pub fn is_empty(&self) -> bool {
        *self == TryRecvError::Empty
    }

    /// Returns `true` for [`TryRecvError::Closed`].
    pub fn is_closed(&self) -> bool {
        *self == TryRecvError::Closed
    }
}

impl error::Error for TryRecvError {}

impl fmt::Display for TryRecvError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let description = if self.is_empty() {
            "receiving from an empty channel"
        } else {
            "receiving from an empty and closed channel"
        };
        f.write_str(description)
    }
}