#![warn(missing_docs, missing_debug_implementations, unreachable_pub)]
mod loom_exports;
mod queue;
use std::error;
use std::fmt;
use std::future::Future;
use std::pin::Pin;
use std::sync::atomic::{self, AtomicUsize, Ordering};
use std::sync::Arc;
use std::task::Context;
use std::task::Poll;
use async_event::Event;
use diatomic_waker::primitives::DiatomicWaker;
use futures_core::Stream;
use pin_project_lite::pin_project;
use crate::queue::{PopError, PushError, Queue};
/// State shared by every [`Sender`] and the [`Receiver`] of one channel.
struct Inner<T> {
    /// Queue holding the messages that have been sent but not yet received.
    queue: Queue<T>,
    /// Used to wake the receiver after a push or when the channel is closed.
    receiver_signal: DiatomicWaker,
    /// Used to wake senders waiting for a free slot or for the closure.
    sender_signal: Event,
    /// Number of live `Sender` handles; the last one dropped closes the channel.
    sender_count: AtomicUsize,
}
impl<T> Inner<T> {
    /// Builds the shared state, forwarding `capacity` to `Queue::new` and
    /// starting the live-sender count at `sender_count`.
    fn new(capacity: usize, sender_count: usize) -> Self {
        let queue = Queue::new(capacity);
        let receiver_signal = DiatomicWaker::new();
        let sender_signal = Event::new();
        let sender_count = AtomicUsize::new(sender_count);

        Inner {
            queue,
            receiver_signal,
            sender_signal,
            sender_count,
        }
    }
}
/// The sending half of a channel, created by [`channel`].
///
/// Senders can be cloned to obtain several producers for the same channel.
/// When the last sender is dropped, the channel is closed (if it was not
/// already) and the receiver is woken.
pub struct Sender<T> {
    inner: Arc<Inner<T>>,
}
impl<T> Sender<T> {
    /// Attempts to send a message without waiting for capacity.
    ///
    /// On success the receiver is notified that a message is available.
    ///
    /// # Errors
    ///
    /// Returns [`TrySendError::Full`] if the channel has no free slot and
    /// [`TrySendError::Closed`] if the channel is closed. In both cases the
    /// rejected message is handed back inside the error.
    pub fn try_send(&self, message: T) -> Result<(), TrySendError<T>> {
        match self.inner.queue.push(message) {
            Ok(()) => {
                self.inner.receiver_signal.notify();
                Ok(())
            }
            Err(PushError::Full(v)) => Err(TrySendError::Full(v)),
            Err(PushError::Closed(v)) => Err(TrySendError::Closed(v)),
        }
    }

    /// Sends a message, waiting asynchronously while the channel is full.
    ///
    /// If the channel is full, the call waits until the receiver frees a slot
    /// or the channel is closed, and then retries the push. On success the
    /// receiver is notified that a message is available.
    ///
    /// # Errors
    ///
    /// Returns [`SendError`], carrying the message back, if the channel is
    /// closed before the message could be pushed.
    pub async fn send(&self, message: T) -> Result<(), SendError<T>> {
        let mut message = Some(message);

        self.inner
            .sender_signal
            .wait_until(|| self.push_or_recycle(&mut message))
            .await;

        match message {
            // The message was put back: the channel is closed.
            Some(m) => Err(SendError(m)),
            None => {
                self.inner.receiver_signal.notify();
                Ok(())
            }
        }
    }

    /// Sends a message, waiting asynchronously while the channel is full, but
    /// giving up once `deadline` completes.
    ///
    /// On success the receiver is notified that a message is available.
    ///
    /// # Errors
    ///
    /// Returns [`SendTimeoutError::Closed`] if the channel is closed before the
    /// message could be pushed, or [`SendTimeoutError::Timeout`] if `deadline`
    /// completes first. The message is handed back in both cases.
    pub async fn send_timeout<'a, D>(
        &'a self,
        message: T,
        deadline: D,
    ) -> Result<(), SendTimeoutError<T>>
    where
        D: Future<Output = ()> + 'a,
    {
        let mut message = Some(message);

        // `Some(())` when the wait ended because the predicate returned
        // `Some`, `None` when it ended because `deadline` completed.
        let outcome = self
            .inner
            .sender_signal
            .wait_until_or_timeout(|| self.push_or_recycle(&mut message), deadline)
            .await;

        match (message, outcome) {
            (Some(m), Some(())) => Err(SendTimeoutError::Closed(m)),
            (Some(m), None) => Err(SendTimeoutError::Timeout(m)),
            (None, _) => {
                self.inner.receiver_signal.notify();
                Ok(())
            }
        }
    }

    /// Closes the channel.
    ///
    /// The receiver and every sender waiting for capacity are woken so that
    /// they can observe the closure.
    pub fn close(&self) {
        self.inner.queue.close();
        self.inner.receiver_signal.notify();
        self.inner.sender_signal.notify_all();
    }

    /// Returns `true` if the channel has been closed.
    pub fn is_closed(&self) -> bool {
        self.inner.queue.is_closed()
    }

    /// Wait predicate shared by [`Sender::send`] and [`Sender::send_timeout`].
    ///
    /// Takes the message out of `slot` and tries to push it. The return value
    /// and the state of `slot` depend on the outcome:
    /// - pushed: returns `Some(())` (stop waiting) and leaves `slot` empty;
    /// - queue full: returns `None` (keep waiting) and puts the message back
    ///   into `slot` for the next attempt;
    /// - channel closed: returns `Some(())` (stop waiting) and puts the message
    ///   back so the caller can return it in an error.
    fn push_or_recycle(&self, slot: &mut Option<T>) -> Option<()> {
        let message = slot
            .take()
            .expect("wait predicate must not run again after the message was pushed");

        match self.inner.queue.push(message) {
            Ok(()) => Some(()),
            Err(PushError::Full(m)) => {
                *slot = Some(m);
                None
            }
            Err(PushError::Closed(m)) => {
                *slot = Some(m);
                Some(())
            }
        }
    }
}
impl<T> Clone for Sender<T> {
fn clone(&self) -> Self {
self.inner.sender_count.fetch_add(1, Ordering::Relaxed);
Self {
inner: self.inner.clone(),
}
}
}
impl<T> Drop for Sender<T> {
    /// Releases this sender's share of the live-sender count. If it was the
    /// last sender and the channel is still open, this closes the channel and
    /// wakes the receiver so it can observe the closure.
    fn drop(&mut self) {
        // `fetch_sub` returns the previous count, so `1` means this was the
        // last sender. Release ordering (the same pattern as `Arc`'s drop)
        // publishes this sender's earlier operations to the thread that ends up
        // closing the channel. If the channel is already closed, closing and
        // notifying are skipped.
        if self.inner.sender_count.fetch_sub(1, Ordering::Release) == 1
            && !self.inner.queue.is_closed()
        {
            // Pairs with the Release decrements of the other senders so that
            // their prior operations happen-before the close below.
            atomic::fence(Ordering::Acquire);
            self.inner.queue.close();
            // Wake the receiver, which waits on `receiver_signal`.
            self.inner.receiver_signal.notify();
        }
    }
}
impl<T> fmt::Debug for Sender<T> {
    /// Prints `Sender { .. }` without exposing internal state, so `T` does not
    /// need to implement `Debug`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut repr = f.debug_struct("Sender");
        repr.finish_non_exhaustive()
    }
}
/// The receiving half of a channel, created by [`channel`].
///
/// A channel has a single receiver: this type is not cloneable. Dropping it
/// closes the channel and wakes any sender waiting for capacity.
pub struct Receiver<T> {
    inner: Arc<Inner<T>>,
}
impl<T> Receiver<T> {
    /// Attempts to receive a message without waiting.
    ///
    /// When a message is taken, one sender waiting for capacity is notified
    /// that a slot was freed.
    ///
    /// # Errors
    ///
    /// Returns [`TryRecvError::Empty`] if no message is currently available and
    /// [`TryRecvError::Closed`] if the queue reports the channel as closed.
    pub fn try_recv(&mut self) -> Result<T, TryRecvError> {
        // SAFETY: `Queue::pop` is `unsafe`, presumably because it must not be
        // called by several consumers at once (TODO confirm against the `queue`
        // module). `Receiver` does not implement `Clone` and this method takes
        // `&mut self`, so no other `pop` can run through this receiver
        // concurrently.
        match unsafe { self.inner.queue.pop() } {
            Ok(message) => {
                self.inner.sender_signal.notify_one();
                Ok(message)
            }
            Err(PopError::Empty) => Err(TryRecvError::Empty),
            Err(PopError::Closed) => Err(TryRecvError::Closed),
        }
    }

    /// Receives a message, waiting asynchronously until one is available.
    ///
    /// # Errors
    ///
    /// Returns [`RecvError`] once the queue reports the channel as closed.
    pub async fn recv(&mut self) -> Result<T, RecvError> {
        RecvFuture { receiver: self }.await
    }

    /// Receives a message, waiting until one is available or until `deadline`
    /// completes, whichever happens first.
    ///
    /// The receiver is always polled before the deadline. A message that is
    /// already available is therefore returned even if `deadline` is already
    /// complete.
    ///
    /// # Errors
    ///
    /// Returns [`RecvTimeoutError::Closed`] once the queue reports the channel
    /// as closed, or [`RecvTimeoutError::Timeout`] if `deadline` completes
    /// first.
    pub async fn recv_timeout<D>(&mut self, deadline: D) -> Result<T, RecvTimeoutError>
    where
        D: Future<Output = ()>,
    {
        RecvTimeoutFuture {
            receiver: self,
            deadline,
        }
        .await
    }

    /// Closes the channel from the receiving side.
    ///
    /// If the channel was still open, every sender waiting for capacity is
    /// woken so that it can observe the closure. Calling this on an already
    /// closed channel does nothing.
    pub fn close(&self) {
        if !self.inner.queue.is_closed() {
            self.inner.queue.close();
            self.inner.sender_signal.notify_all();
        }
    }
}
impl<T> Drop for Receiver<T> {
    /// Closes the channel and wakes every sender waiting on it, so that no
    /// sender stays blocked once the receiver is gone.
    fn drop(&mut self) {
        let shared = &*self.inner;
        shared.queue.close();
        shared.sender_signal.notify_all();
    }
}
impl<T> fmt::Debug for Receiver<T> {
    /// Prints `Receiver { .. }` without exposing internal state, so `T` does
    /// not need to implement `Debug`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut repr = f.debug_struct("Receiver");
        repr.finish_non_exhaustive()
    }
}
impl<T> Stream for Receiver<T> {
    type Item = T;

    /// Polls for the next message, yielding `None` once the queue reports the
    /// channel as closed.
    ///
    /// A fast path first tries to pop without touching the waker. The waker
    /// is registered only if the queue is empty, and the queue is then checked
    /// again. A message pushed between the first pop and the registration is
    /// therefore still picked up.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // SAFETY (review note): `Queue::pop` and the waker registration calls
        // below are `unsafe`, presumably because they must not be invoked
        // concurrently from several threads. TODO confirm against the `queue`
        // module and the `diatomic_waker` documentation. `Receiver` does not
        // implement `Clone` and this method requires `Pin<&mut Self>`, so only
        // one caller at a time can reach them through this receiver.
        unsafe {
            // Fast path: try to pop without registering the waker.
            match self.inner.queue.pop() {
                Ok(message) => {
                    // A slot was freed: let one waiting sender retry.
                    self.inner.sender_signal.notify_one();
                    return Poll::Ready(Some(message));
                }
                Err(PopError::Closed) => {
                    return Poll::Ready(None);
                }
                Err(PopError::Empty) => {}
            }
            // Slow path: register for a wake-up, then try once more.
            self.inner.receiver_signal.register(cx.waker());
            match self.inner.queue.pop() {
                Ok(message) => {
                    // A message arrived after all: cancel the wake-up request.
                    self.inner.receiver_signal.unregister();
                    self.inner.sender_signal.notify_one();
                    Poll::Ready(Some(message))
                }
                Err(PopError::Closed) => {
                    self.inner.receiver_signal.unregister();
                    Poll::Ready(None)
                }
                // The waker stays registered. A sender's later
                // `receiver_signal.notify()` is expected to wake this task.
                Err(PopError::Empty) => Poll::Pending,
            }
        }
    }
}
/// Future behind [`Receiver::recv`]. It resolves to the next message, or to
/// [`RecvError`] once the receiver's stream is exhausted.
struct RecvFuture<'a, T> {
    receiver: &'a mut Receiver<T>,
}
impl<T> Future for RecvFuture<'_, T> {
    type Output = Result<T, RecvError>;

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // `Receiver` only holds an `Arc`, so it is `Unpin` and can be pinned
        // directly through the exclusive borrow.
        let receiver = Pin::new(&mut *self.receiver);
        receiver.poll_next(cx).map(|item| item.ok_or(RecvError))
    }
}
// Future behind `Receiver::recv_timeout`. It resolves to the next message, to
// `RecvTimeoutError::Closed` once the receiver's stream is exhausted, or to
// `RecvTimeoutError::Timeout` if `deadline` completes first. `deadline` is
// structurally pinned because the generic `D` may be `!Unpin`.
pin_project! {
    struct RecvTimeoutFuture<'a, T, D> where D: Future<Output=()> {
        receiver: &'a mut Receiver<T>,
        #[pin]
        deadline: D,
    }
}
impl<T, D> Future for RecvTimeoutFuture<'_, T, D>
where
    D: Future<Output = ()>,
{
    type Output = Result<T, RecvTimeoutError>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let projected = self.project();

        // A ready message (or closure) takes priority; the deadline is only
        // polled while the receiver is still pending.
        if let Poll::Ready(item) = Pin::new(&mut **projected.receiver).poll_next(cx) {
            return Poll::Ready(item.ok_or(RecvTimeoutError::Closed));
        }

        projected
            .deadline
            .poll(cx)
            .map(|()| Err(RecvTimeoutError::Timeout))
    }
}
/// Creates a bounded channel and returns its sending and receiving halves.
///
/// `capacity` is forwarded unchanged to the internal queue constructor, which
/// determines the accepted values. More senders can be obtained by cloning the
/// returned [`Sender`].
///
/// The channel closes in any of these cases:
/// - all senders are dropped;
/// - the [`Receiver`] is dropped;
/// - either side calls `close`.
pub fn channel<T>(capacity: usize) -> (Sender<T>, Receiver<T>) {
    // The live-sender count starts at 1 for the single sender returned here.
    let inner = Arc::new(Inner::new(capacity, 1));
    let sender = Sender {
        inner: Arc::clone(&inner),
    };
    let receiver = Receiver { inner };

    (sender, receiver)
}
/// Error returned by [`Sender::try_send`].
///
/// Both variants hand back the message that could not be sent.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TrySendError<T> {
    /// The channel has no free slot at the moment.
    Full(T),
    /// The channel is closed.
    Closed(T),
}
impl<T: fmt::Debug> error::Error for TrySendError<T> {}
impl<T> fmt::Display for TrySendError<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            TrySendError::Full(_) => "sending into a full channel".fmt(f),
            TrySendError::Closed(_) => "sending into a closed channel".fmt(f),
        }
    }
}
/// Error returned by [`Sender::send`] when the channel is closed.
///
/// The unsent message is available in the public field. `Debug` and
/// [`std::error::Error`] are implemented for every `T`, so the message type
/// does not need to implement `Debug`.
#[derive(Clone, Copy, Eq, PartialEq)]
pub struct SendError<T>(pub T);
impl<T> error::Error for SendError<T> {}
impl<T> fmt::Debug for SendError<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("SendError").finish_non_exhaustive()
    }
}
impl<T> fmt::Display for SendError<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        "sending into a closed channel".fmt(f)
    }
}
/// Error returned by [`Sender::send_timeout`].
///
/// Both variants hand back the message that could not be sent.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SendTimeoutError<T> {
    /// The deadline completed before the message could be sent.
    Timeout(T),
    /// The channel is closed.
    Closed(T),
}
impl<T: fmt::Debug> error::Error for SendTimeoutError<T> {}
impl<T> fmt::Display for SendTimeoutError<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            SendTimeoutError::Timeout(_) => "the deadline for sending has elapsed".fmt(f),
            SendTimeoutError::Closed(_) => "sending into a closed channel".fmt(f),
        }
    }
}
/// Error returned by [`Receiver::try_recv`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TryRecvError {
    /// No message is available at the moment.
    Empty,
    /// The channel is closed.
    Closed,
}
impl error::Error for TryRecvError {}
impl fmt::Display for TryRecvError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            TryRecvError::Empty => "receiving from an empty channel".fmt(f),
            TryRecvError::Closed => "receiving from a closed channel".fmt(f),
        }
    }
}
/// Error returned by [`Receiver::recv`] when the channel is closed.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct RecvError;
impl error::Error for RecvError {}
impl fmt::Display for RecvError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        "receiving from a closed channel".fmt(f)
    }
}
/// Error returned by [`Receiver::recv_timeout`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum RecvTimeoutError {
    /// The deadline completed before a message was received.
    Timeout,
    /// The channel is closed.
    Closed,
}
impl error::Error for RecvTimeoutError {}
impl fmt::Display for RecvTimeoutError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            RecvTimeoutError::Timeout => "the deadline for receiving has elapsed".fmt(f),
            RecvTimeoutError::Closed => "receiving from a closed channel".fmt(f),
        }
    }
}