use crate::shim::sync::Arc;
use core::fmt;
use core::future::Future;
use core::pin::Pin;
use core::task::{Context, Poll};
use crate::atomic_waker::AtomicWaker;
/// Error types reported by the receiving half of a oneshot channel.
pub mod error {
    use core::fmt;

    /// Returned when the channel is closed instead of delivering a value.
    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
    pub struct RecvError;

    impl fmt::Display for RecvError {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            f.write_str("channel closed")
        }
    }

    impl core::error::Error for RecvError {}

    /// Reason a non-blocking receive attempt produced no value.
    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
    pub enum TryRecvError {
        /// No value is available at the moment.
        Empty,
        /// The channel has been closed.
        Closed,
    }

    impl fmt::Display for TryRecvError {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            let text = match self {
                Self::Empty => "channel empty",
                Self::Closed => "channel closed",
            };
            f.write_str(text)
        }
    }

    impl core::error::Error for TryRecvError {}
}
pub use self::error::RecvError;
pub use self::error::TryRecvError;
/// Outcome of one non-blocking attempt to take the channel's value out of its storage.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TakeResult<T> {
    /// A value was obtained.
    Ready(T),
    /// Nothing is available yet; a receiver keeps waiting for a wakeup.
    Pending,
    /// No value will be produced; a receiver reports `RecvError`.
    Closed,
}

impl<T> TakeResult<T> {
    /// Converts into an `Option`, keeping only the payload of `Ready`.
    #[inline]
    pub fn ok(self) -> Option<T> {
        match self {
            Self::Ready(value) => Some(value),
            Self::Pending | Self::Closed => None,
        }
    }

    /// Returns `true` only for the `Closed` variant.
    #[inline]
    pub fn is_closed(&self) -> bool {
        match self {
            Self::Closed => true,
            Self::Ready(_) | Self::Pending => false,
        }
    }
}
/// Backing store for a oneshot channel: the single value plus the sender/receiver flags.
///
/// Both halves of the channel call into the same storage through a shared `Arc`, which is
/// why the trait requires `Send + Sync`.
pub trait OneshotStorage: Send + Sync + Sized {
    /// Type of the value carried by the channel.
    type Value: Send;

    /// Creates the storage for a brand-new channel.
    fn new() -> Self;

    // --- value slot -------------------------------------------------------------------

    /// Places the value; the sender calls this right before waking the receiver.
    fn store(&self, value: Self::Value);
    /// Attempts to take the value. The receiver treats `Ready` and `Closed` as final
    /// outcomes and waits for a wakeup on `Pending`.
    fn try_take(&self) -> TakeResult<Self::Value>;

    // --- sender-side flag -------------------------------------------------------------

    /// Records that the sender was dropped without sending.
    fn mark_sender_dropped(&self);
    /// Reports whether `mark_sender_dropped` has been recorded.
    fn is_sender_dropped(&self) -> bool;

    // --- receiver-side flag -----------------------------------------------------------

    /// Records that the receiver closed the channel.
    fn mark_receiver_closed(&self);
    /// Reports whether `mark_receiver_closed` has been recorded.
    fn is_receiver_closed(&self) -> bool;
}
/// State shared through an `Arc` by a `Sender` and its `Receiver`.
pub struct Inner<S: OneshotStorage> {
    /// Waker of whoever waits on the receiving side (an async task or a parked thread).
    pub(crate) waker: AtomicWaker,
    /// Holds the value and the sender-dropped / receiver-closed flags.
    pub(crate) storage: S,
}

impl<S: OneshotStorage> Inner<S> {
    /// Allocates the shared state for a new channel.
    #[inline]
    pub fn new() -> Arc<Self> {
        let waker = AtomicWaker::new();
        let storage = S::new();
        Arc::new(Self { waker, storage })
    }

    /// Stores `value`, then wakes the registered waiter.
    #[inline]
    pub fn send(&self, value: S::Value) {
        let Self { waker, storage } = self;
        S::store(storage, value);
        waker.wake();
    }

    /// Non-blocking attempt to take the value out of the storage.
    #[inline]
    pub fn try_recv(&self) -> TakeResult<S::Value> {
        S::try_take(&self.storage)
    }

    /// Records `waker` as the one to wake when a value is sent or the sender is dropped.
    #[inline]
    pub fn register_waker(&self, waker: &core::task::Waker) {
        self.waker.register(waker);
    }

    /// Whether the storage has recorded that the sender was dropped.
    #[inline]
    pub fn is_sender_dropped(&self) -> bool {
        S::is_sender_dropped(&self.storage)
    }
}

impl<S: OneshotStorage> fmt::Debug for Inner<S> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Type name only: neither the storage nor the waker is required to be `Debug`.
        let mut builder = f.debug_struct("Inner");
        builder.finish_non_exhaustive()
    }
}
/// Sending half of a oneshot channel; consumed by the call that delivers the value.
pub struct Sender<S: OneshotStorage> {
    pub(crate) inner: Arc<Inner<S>>,
}

impl<S: OneshotStorage> fmt::Debug for Sender<S> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Opaque on purpose: the storage type is not required to implement `Debug`.
        let mut builder = f.debug_struct("Sender");
        builder.finish_non_exhaustive()
    }
}
impl<S: OneshotStorage> Sender<S> {
    /// Creates a connected `(sender, receiver)` pair sharing freshly allocated state.
    #[inline]
    pub fn new() -> (Self, Receiver<S>) {
        let inner = Inner::new();
        let sender = Sender {
            inner: Arc::clone(&inner),
        };
        let receiver = Receiver { inner };
        (sender, receiver)
    }

    /// Sends `value` unless the channel already looks closed.
    ///
    /// # Errors
    ///
    /// Returns `Err(value)`, handing the value back, when [`Sender::is_closed`] reports `true`.
    /// The check is a snapshot. The receiver may still go away between the check and the store.
    #[inline]
    pub fn send(self, value: S::Value) -> Result<(), S::Value> {
        if self.is_closed() {
            return Err(value);
        }
        self.send_unchecked(value);
        Ok(())
    }

    /// Stores `value` and wakes the receiver without checking whether it is still listening.
    ///
    /// The sender's `Drop` is skipped, so the storage is *not* marked "sender dropped". The
    /// sender's reference to the shared state is still released, so the allocation is freed once
    /// the receiver is gone as well.
    #[inline]
    pub fn send_unchecked(self, value: S::Value) {
        // Store while `self` is still an ordinary owned value. If `store` panics, the regular
        // `Drop` runs during unwinding and the receiver learns the sender is gone.
        self.inner.send(value);

        // Disarm `Drop for Sender` but still release our `Arc`. A plain `mem::forget(self)`
        // would also forget the `Arc` field and leak the shared allocation.
        let this = core::mem::ManuallyDrop::new(self);
        // SAFETY: `this` is never used again and, being `ManuallyDrop`, never drops its fields.
        // Moving the `Arc` out with a bitwise read therefore transfers its strong reference
        // exactly once, with no double release and no leak.
        let inner = unsafe { core::ptr::read(&this.inner) };
        drop(inner);
    }

    /// Returns `true` if the receiver called `close` or has been dropped.
    ///
    /// "Dropped" is detected by this sender holding the only remaining strong reference to the
    /// shared state. This assumes nothing besides the two halves holds a clone of that `Arc`.
    #[inline]
    pub fn is_closed(&self) -> bool {
        self.inner.storage.is_receiver_closed() || Arc::strong_count(&self.inner) == 1
    }
}
impl<S: OneshotStorage> Drop for Sender<S> {
    /// Flags the sender as gone in the storage, then wakes the receiver so it can observe that.
    fn drop(&mut self) {
        let shared = &*self.inner;
        shared.storage.mark_sender_dropped();
        shared.waker.wake();
    }
}
/// Receiving half of a oneshot channel; `.await` it to obtain the value.
pub struct Receiver<S: OneshotStorage> {
    pub(crate) inner: Arc<Inner<S>>,
}

impl<S: OneshotStorage> fmt::Debug for Receiver<S> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Opaque on purpose: the storage type is not required to implement `Debug`.
        let mut builder = f.debug_struct("Receiver");
        builder.finish_non_exhaustive()
    }
}

// Polling only goes through the `Arc` field and never relies on a stable address, so the
// receiver may be moved freely even after it has been pinned.
impl<S: OneshotStorage> Unpin for Receiver<S> {}
impl<S: OneshotStorage> Receiver<S> {
/// Waits asynchronously for the value; equivalent to `.await`ing the receiver directly.
///
/// # Errors
///
/// Returns `Err(RecvError)` if the channel closes without delivering a value.
#[inline]
pub async fn wait(self) -> Result<S::Value, RecvError> {
self.await
}
/// Marks the receiving side as closed in the storage.
///
/// A sender checking `Sender::is_closed` afterwards sees `true`, so `Sender::send` hands the
/// value back. The receiver itself is not consumed. What a later poll observes depends on the
/// storage's `try_take` (NOTE(review): confirm per storage implementation).
#[inline]
pub fn close(&mut self) {
self.inner.storage.mark_receiver_closed();
}
/// Blocks the current thread until the value arrives or the channel closes.
///
/// Only compiled with the `std` or `loom` feature (or in tests). A thread-parking
/// `core::task::Waker` is registered with the channel, and the thread parks between
/// attempts to take the value.
///
/// # Errors
///
/// Returns `Err(RecvError)` when the storage reports `Closed`. It also returns it when this
/// receiver holds the only remaining reference to the shared state and the storage has
/// recorded the sender as dropped.
#[inline]
#[cfg(any(feature = "std", feature = "loom", test))]
pub fn blocking_recv(self) -> Result<S::Value, RecvError> {
use crate::shim::atomic::{AtomicBool, Ordering};
use core::task::{RawWaker, RawWakerVTable, Waker};
// Fast path: skip allocating a parker when the outcome is already decided.
match self.inner.storage.try_take() {
TakeResult::Ready(value) => return Ok(value),
TakeResult::Closed => return Err(RecvError),
TakeResult::Pending => {}
}
// Wake target: the blocked thread, plus a flag recording that a wake happened so the
// loop below can skip `park()` when it already has one.
struct ThreadParker {
thread: crate::shim::thread::Thread,
notified: AtomicBool,
}
// Waker vtable whose data pointer is an `Arc<ThreadParker>` produced by `Arc::into_raw`.
// Each `RawWaker` built from it owns exactly one strong reference.
const VTABLE: RawWakerVTable = RawWakerVTable::new(
// clone — SAFETY: `ptr` comes from `Arc::into_raw` of a parker kept alive by the
// waker being cloned; bumping the count gives the new `RawWaker` its own reference.
|ptr| unsafe {
Arc::increment_strong_count(ptr as *const ThreadParker);
RawWaker::new(ptr, &VTABLE)
},
// wake (by value) — SAFETY: reclaims the reference owned by the consumed waker; it
// is released when `parker` goes out of scope at the end of the closure.
|ptr| unsafe {
let parker = Arc::from_raw(ptr as *const ThreadParker);
parker.notified.store(true, Ordering::Release);
parker.thread.unpark();
},
// wake_by_ref — SAFETY: the waker still owns its reference, so the parker is alive
// for this call; it is only borrowed, never released.
|ptr| unsafe {
let parker = &*(ptr as *const ThreadParker);
parker.notified.store(true, Ordering::Release);
parker.thread.unpark();
},
// drop — SAFETY: releases the single reference owned by the dropped waker.
|ptr| unsafe {
Arc::decrement_strong_count(ptr as *const ThreadParker);
},
);
// `parker` keeps one reference locally; the cloned one below is owned by `waker`.
let parker = Arc::new(ThreadParker {
thread: crate::shim::thread::current(),
notified: AtomicBool::new(false),
});
let raw_waker = RawWaker::new(Arc::into_raw(parker.clone()) as *const (), &VTABLE);
// SAFETY: the data pointer owns one strong reference to a live `ThreadParker`, and
// every `VTABLE` entry treats it exactly that way, as `RawWakerVTable` requires.
let waker = unsafe { Waker::from_raw(raw_waker) };
self.inner.register_waker(&waker);
loop {
// Re-check after registering, so a value stored before registration is not missed.
match self.inner.storage.try_take() {
TakeResult::Ready(value) => return Ok(value),
TakeResult::Closed => return Err(RecvError),
TakeResult::Pending => {}
}
// The sender's `Arc` is gone (only ours remains) and the storage recorded the drop.
if Arc::strong_count(&self.inner) == 1 && self.inner.is_sender_dropped() {
return Err(RecvError);
}
// Consume a pending notification if there is one; otherwise sleep until unparked.
// A spurious unpark just causes another iteration of the loop.
if !parker.notified.swap(false, Ordering::Acquire) {
crate::shim::thread::park();
}
}
}
}
impl<S: OneshotStorage> Future for Receiver<S> {
type Output = Result<S::Value, RecvError>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.get_mut();
match this.inner.try_recv() {
TakeResult::Ready(value) => return Poll::Ready(Ok(value)),
TakeResult::Closed => return Poll::Ready(Err(RecvError)),
TakeResult::Pending => {}
}
this.inner.register_waker(cx.waker());
match this.inner.try_recv() {
TakeResult::Ready(value) => return Poll::Ready(Ok(value)),
TakeResult::Closed => return Poll::Ready(Err(RecvError)),
TakeResult::Pending => {}
}
if Arc::strong_count(&this.inner) == 1 && this.inner.is_sender_dropped() {
return Poll::Ready(Err(RecvError));
}
Poll::Pending
}
}
/// Creates a oneshot channel backed by storage type `S` and returns its two halves.
///
/// Equivalent to `Sender::new()`.
#[inline]
pub fn channel<S: OneshotStorage>() -> (Sender<S>, Receiver<S>) {
    let (tx, rx) = Sender::<S>::new();
    (tx, rx)
}