mod receiver;
mod sender;
pub use self::{
receiver::{Canceled, Receiver},
sender::Sender,
};
use crate::sync::spsc::SpscInner;
use alloc::sync::Arc;
use core::{
cell::UnsafeCell,
mem::MaybeUninit,
sync::atomic::{AtomicU8, Ordering},
task::Waker,
};
#[allow(clippy::identity_op)]
const TX_WAKER_STORED: u8 = 1 << 0;
const RX_WAKER_STORED: u8 = 1 << 1;
const COMPLETE: u8 = 1 << 2;
struct Inner<T> {
state: AtomicU8,
data: UnsafeCell<Option<T>>,
rx_waker: UnsafeCell<MaybeUninit<Waker>>,
tx_waker: UnsafeCell<MaybeUninit<Waker>>,
}
#[inline]
pub fn channel<T>() -> (Sender<T>, Receiver<T>) {
let inner = Arc::new(Inner::new());
let sender = Sender::new(Arc::clone(&inner));
let receiver = Receiver::new(inner);
(sender, receiver)
}
unsafe impl<T: Send> Send for Inner<T> {}
unsafe impl<T: Send> Sync for Inner<T> {}
impl<T> Inner<T> {
#[inline]
fn new() -> Self {
Self {
state: AtomicU8::new(0),
data: UnsafeCell::new(None),
rx_waker: UnsafeCell::new(MaybeUninit::zeroed()),
tx_waker: UnsafeCell::new(MaybeUninit::zeroed()),
}
}
}
impl<T> SpscInner<AtomicU8, u8> for Inner<T> {
const COMPLETE: u8 = COMPLETE;
const RX_WAKER_STORED: u8 = RX_WAKER_STORED;
const TX_WAKER_STORED: u8 = TX_WAKER_STORED;
const ZERO: u8 = 0;
#[inline]
fn state_load(&self, order: Ordering) -> u8 {
self.state.load(order)
}
#[inline]
fn compare_exchange_weak(
&self,
current: u8,
new: u8,
success: Ordering,
failure: Ordering,
) -> Result<u8, u8> {
self.state.compare_exchange_weak(current, new, success, failure)
}
#[inline]
unsafe fn rx_waker_mut(&self) -> &mut MaybeUninit<Waker> {
unsafe { &mut *self.rx_waker.get() }
}
#[inline]
unsafe fn tx_waker_mut(&self) -> &mut MaybeUninit<Waker> {
unsafe { &mut *self.tx_waker.get() }
}
}
#[cfg(test)]
mod tests {
use super::*;
use core::{
future::Future,
pin::Pin,
sync::atomic::AtomicUsize,
task::{Context, Poll, RawWaker, RawWakerVTable, Waker},
};
struct Counter(AtomicUsize);
impl Counter {
fn to_waker(&'static self) -> Waker {
unsafe fn clone(counter: *const ()) -> RawWaker {
RawWaker::new(counter, &VTABLE)
}
unsafe fn wake(counter: *const ()) {
unsafe { (*(counter as *const Counter)).0.fetch_add(1, Ordering::SeqCst) };
}
static VTABLE: RawWakerVTable = RawWakerVTable::new(clone, wake, wake, drop);
unsafe { Waker::from_raw(RawWaker::new(self as *const _ as *const (), &VTABLE)) }
}
}
#[test]
fn send_sync() {
static COUNTER: Counter = Counter(AtomicUsize::new(0));
let (tx, mut rx) = channel::<usize>();
assert_eq!(tx.send(314), Ok(()));
let waker = COUNTER.to_waker();
let mut cx = Context::from_waker(&waker);
COUNTER.0.store(0, Ordering::SeqCst);
assert_eq!(Pin::new(&mut rx).poll(&mut cx), Poll::Ready(Ok(314)));
assert_eq!(COUNTER.0.load(Ordering::SeqCst), 0);
}
#[test]
fn send_async() {
static COUNTER: Counter = Counter(AtomicUsize::new(0));
let (tx, mut rx) = channel::<usize>();
let waker = COUNTER.to_waker();
let mut cx = Context::from_waker(&waker);
COUNTER.0.store(0, Ordering::SeqCst);
assert_eq!(Pin::new(&mut rx).poll(&mut cx), Poll::Pending);
assert_eq!(tx.send(314), Ok(()));
assert_eq!(Pin::new(&mut rx).poll(&mut cx), Poll::Ready(Ok(314)));
assert_eq!(COUNTER.0.load(Ordering::SeqCst), 1);
}
}