use std::cell::UnsafeCell;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicU8, Ordering};
use std::task::{Context, Poll, Waker};
use super::{RecvError, SendError, TryRecvError, TrySendError};
const EMPTY: u8 = 0;
const STORED: u8 = 1;
const REGISTERING: u8 = 2;
struct RxWakerSlot {
task_ptr: std::sync::atomic::AtomicPtr<u8>,
cross_ctx: *const crate::cross_wake::CrossWakeContext,
state: AtomicU8,
}
unsafe impl Send for RxWakerSlot {}
unsafe impl Sync for RxWakerSlot {}
impl RxWakerSlot {
fn new(cross_ctx: *const crate::cross_wake::CrossWakeContext) -> Self {
Self {
task_ptr: std::sync::atomic::AtomicPtr::new(std::ptr::null_mut()),
cross_ctx,
state: AtomicU8::new(EMPTY),
}
}
fn register(&self, task_ptr: *mut u8) {
let prev = self.state.swap(REGISTERING, Ordering::Acquire);
debug_assert_ne!(prev, REGISTERING);
self.task_ptr.store(task_ptr, Ordering::Relaxed);
self.state.store(STORED, Ordering::Release);
}
fn try_register_local(&self, waker: &Waker) -> bool {
crate::waker::task_ptr_from_local_waker(waker).is_some_and(|task_ptr| {
self.register(task_ptr);
true
})
}
fn wake(&self) -> bool {
if self
.state
.compare_exchange(STORED, EMPTY, Ordering::AcqRel, Ordering::Relaxed)
.is_ok()
{
let task_ptr = self.task_ptr.swap(std::ptr::null_mut(), Ordering::Acquire);
if !task_ptr.is_null() {
let ctx = unsafe { &*self.cross_ctx };
unsafe { crate::cross_wake::wake_task_cross_thread(task_ptr, ctx) };
return true;
}
}
false
}
fn has_waker(&self) -> bool {
self.state.load(Ordering::Acquire) == STORED
}
}
struct FallbackWaker {
state: AtomicU8,
waker: UnsafeCell<Option<Waker>>,
}
unsafe impl Send for FallbackWaker {}
unsafe impl Sync for FallbackWaker {}
impl FallbackWaker {
fn new() -> Self {
Self {
state: AtomicU8::new(EMPTY),
waker: UnsafeCell::new(None),
}
}
fn register(&self, waker: &Waker) {
let prev = self.state.swap(REGISTERING, Ordering::Acquire);
debug_assert_ne!(prev, REGISTERING);
unsafe { *self.waker.get() = Some(waker.clone()) };
self.state.store(STORED, Ordering::Release);
}
fn wake(&self) -> bool {
if self
.state
.compare_exchange(STORED, EMPTY, Ordering::AcqRel, Ordering::Relaxed)
.is_ok()
{
if let Some(w) = unsafe { (*self.waker.get()).take() } {
w.wake();
return true;
}
}
false
}
fn has_waker(&self) -> bool {
self.state.load(Ordering::Acquire) == STORED
}
}
impl Drop for FallbackWaker {
fn drop(&mut self) {
*self.waker.get_mut() = None;
}
}
struct TxWakerSlot {
state: AtomicU8,
waker: UnsafeCell<Option<Waker>>,
}
unsafe impl Send for TxWakerSlot {}
unsafe impl Sync for TxWakerSlot {}
impl TxWakerSlot {
fn new() -> Self {
Self {
state: AtomicU8::new(EMPTY),
waker: UnsafeCell::new(None),
}
}
fn register(&self, waker: &Waker) {
let prev = self.state.swap(REGISTERING, Ordering::Acquire);
debug_assert_ne!(prev, REGISTERING);
unsafe { *self.waker.get() = Some(waker.clone()) };
self.state.store(STORED, Ordering::Release);
}
fn wake(&self) -> bool {
if self
.state
.compare_exchange(STORED, EMPTY, Ordering::AcqRel, Ordering::Relaxed)
.is_ok()
{
if let Some(w) = unsafe { (*self.waker.get()).take() } {
w.wake();
return true;
}
}
false
}
fn has_waker(&self) -> bool {
self.state.load(Ordering::Acquire) == STORED
}
}
impl Drop for TxWakerSlot {
fn drop(&mut self) {
*self.waker.get_mut() = None;
}
}
struct Inner<T> {
producer: nexus_queue::spsc::Producer<T>,
consumer: nexus_queue::spsc::Consumer<T>,
rx_slot: RxWakerSlot,
rx_fallback: FallbackWaker,
tx_waker: TxWakerSlot,
_cross_wake_owner: Arc<crate::cross_wake::CrossWakeContext>,
tx_alive: AtomicBool,
rx_closed: AtomicBool,
}
unsafe impl<T: Send> Send for Inner<T> {}
unsafe impl<T: Send> Sync for Inner<T> {}
impl<T> Inner<T> {
fn wake_rx(&self) {
if !self.rx_slot.wake() {
self.rx_fallback.wake();
}
}
fn has_rx_waker(&self) -> bool {
self.rx_slot.has_waker() || self.rx_fallback.has_waker()
}
}
pub fn channel<T: Send>(capacity: usize) -> (Sender<T>, Receiver<T>) {
crate::context::assert_in_runtime("spsc::channel() called outside Runtime::block_on");
assert!(capacity > 0, "channel capacity must be > 0");
let cross_ctx = crate::cross_wake::cross_wake_context()
.expect("spsc::channel() requires runtime context for cross-thread wake");
let (producer, consumer) = nexus_queue::spsc::ring_buffer(capacity);
let rx_slot = RxWakerSlot::new(Arc::as_ptr(&cross_ctx));
let inner = Arc::new(Inner {
producer,
consumer,
rx_slot,
rx_fallback: FallbackWaker::new(),
tx_waker: TxWakerSlot::new(),
_cross_wake_owner: cross_ctx,
tx_alive: AtomicBool::new(true),
rx_closed: AtomicBool::new(false),
});
let tx = Sender {
inner: inner.clone(),
};
let rx = Receiver { inner };
(tx, rx)
}
pub struct Sender<T> {
inner: Arc<Inner<T>>,
}
impl<T: Send> Sender<T> {
pub fn send(&self, value: T) -> SendFut<'_, T> {
SendFut {
sender: self,
value: Some(value),
}
}
pub fn try_send(&self, value: T) -> Result<(), TrySendError<T>> {
if self.inner.rx_closed.load(Ordering::Acquire) {
return Err(TrySendError::Closed(value));
}
match self.inner.producer.push(value) {
Ok(()) => {
if self.inner.has_rx_waker() {
self.inner.wake_rx();
}
Ok(())
}
Err(nexus_queue::Full(value)) => Err(TrySendError::Full(value)),
}
}
}
impl<T> Drop for Sender<T> {
fn drop(&mut self) {
self.inner.tx_alive.store(false, Ordering::Release);
self.inner.wake_rx();
}
}
unsafe impl<T: Send> Send for Sender<T> {}
pub struct SendFut<'a, T> {
sender: &'a Sender<T>,
value: Option<T>,
}
impl<T: Send> Future for SendFut<'_, T> {
type Output = Result<(), SendError<T>>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = unsafe { self.get_unchecked_mut() };
let inner = &this.sender.inner;
if inner.rx_closed.load(Ordering::Acquire) {
let value = this.value.take().expect("polled after completion");
return Poll::Ready(Err(SendError(value)));
}
let value = this.value.take().expect("polled after completion");
match inner.producer.push(value) {
Ok(()) => {
if inner.has_rx_waker() {
inner.wake_rx();
}
Poll::Ready(Ok(()))
}
Err(nexus_queue::Full(value)) => {
this.value = Some(value);
inner.tx_waker.register(cx.waker());
Poll::Pending
}
}
}
}
unsafe impl<T: Send> Send for SendFut<'_, T> {}
pub struct Receiver<T> {
inner: Arc<Inner<T>>,
}
impl<T: Send> Receiver<T> {
pub fn recv(&self) -> RecvFut<'_, T> {
RecvFut { receiver: self }
}
#[allow(clippy::option_if_let_else)]
pub fn try_recv(&self) -> Result<T, TryRecvError> {
match self.inner.consumer.pop() {
Some(value) => {
if self.inner.tx_waker.has_waker() {
self.inner.tx_waker.wake();
}
Ok(value)
}
None => {
if self.inner.tx_alive.load(Ordering::Acquire) {
Err(TryRecvError::Empty)
} else {
Err(TryRecvError::Closed)
}
}
}
}
}
impl<T> Drop for Receiver<T> {
fn drop(&mut self) {
self.inner.rx_closed.store(true, Ordering::Release);
self.inner.tx_waker.wake();
}
}
unsafe impl<T: Send> Send for Receiver<T> {}
pub struct RecvFut<'a, T> {
receiver: &'a Receiver<T>,
}
impl<T: Send> Future for RecvFut<'_, T> {
type Output = Result<T, RecvError>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let inner = &self.receiver.inner;
if let Some(value) = inner.consumer.pop() {
if inner.tx_waker.has_waker() {
inner.tx_waker.wake();
}
return Poll::Ready(Ok(value));
}
if !inner.tx_alive.load(Ordering::Acquire) {
return Poll::Ready(Err(RecvError));
}
if !inner.rx_slot.try_register_local(cx.waker()) {
inner.rx_fallback.register(cx.waker());
}
if let Some(value) = inner.consumer.pop() {
if inner.tx_waker.has_waker() {
inner.tx_waker.wake();
}
return Poll::Ready(Ok(value));
}
if !inner.tx_alive.load(Ordering::Acquire) {
return Poll::Ready(Err(RecvError));
}
Poll::Pending
}
}
unsafe impl<T: Send> Send for RecvFut<'_, T> {}
#[cfg(test)]
mod tests {
use super::*;
fn test_channel<T: Send>(capacity: usize) -> (Sender<T>, Receiver<T>) {
let poll = mio::Poll::new().unwrap();
let mio_waker = Arc::new(mio::Waker::new(poll.registry(), mio::Token(usize::MAX)).unwrap());
let cross_ctx = Arc::new(crate::cross_wake::CrossWakeContext {
queue: crate::cross_wake::CrossWakeQueue::new(),
mio_waker,
parked: AtomicBool::new(false),
});
let (producer, consumer) = nexus_queue::spsc::ring_buffer(capacity);
let rx_slot = RxWakerSlot::new(Arc::as_ptr(&cross_ctx));
let inner = Arc::new(Inner {
producer,
consumer,
rx_slot,
rx_fallback: FallbackWaker::new(),
tx_waker: TxWakerSlot::new(),
_cross_wake_owner: cross_ctx,
tx_alive: AtomicBool::new(true),
rx_closed: AtomicBool::new(false),
});
(
Sender {
inner: inner.clone(),
},
Receiver { inner },
)
}
#[test]
fn send_recv_single() {
let (tx, rx) = test_channel::<u32>(4);
tx.try_send(1).unwrap();
tx.try_send(2).unwrap();
tx.try_send(3).unwrap();
assert_eq!(rx.try_recv().unwrap(), 1);
assert_eq!(rx.try_recv().unwrap(), 2);
assert_eq!(rx.try_recv().unwrap(), 3);
assert_eq!(rx.try_recv(), Err(TryRecvError::Empty));
}
#[test]
fn fifo_ordering() {
let (tx, rx) = test_channel(8);
for i in 0..8u32 {
tx.try_send(i).unwrap();
}
for i in 0..8u32 {
assert_eq!(rx.try_recv().unwrap(), i);
}
}
#[test]
fn try_send_full() {
let (tx, rx) = test_channel(2);
tx.try_send(1u32).unwrap();
tx.try_send(2).unwrap();
let err = tx.try_send(3).unwrap_err();
assert!(err.is_full());
assert_eq!(err.into_inner(), 3);
assert_eq!(rx.try_recv().unwrap(), 1);
tx.try_send(3).unwrap();
}
#[test]
fn try_recv_empty() {
let (tx, rx) = test_channel::<u32>(4);
assert_eq!(rx.try_recv(), Err(TryRecvError::Empty));
tx.try_send(1).unwrap();
assert_eq!(rx.try_recv().unwrap(), 1);
assert_eq!(rx.try_recv(), Err(TryRecvError::Empty));
}
#[test]
fn sender_drop_signals_closed() {
let (tx, rx) = test_channel::<u32>(4);
tx.try_send(42).unwrap();
drop(tx);
assert_eq!(rx.try_recv().unwrap(), 42);
assert_eq!(rx.try_recv(), Err(TryRecvError::Closed));
}
#[test]
fn receiver_drop_signals_closed() {
let (tx, rx) = test_channel::<u32>(4);
drop(rx);
let err = tx.try_send(1).unwrap_err();
assert!(err.is_closed());
}
#[test]
fn sender_is_send() {
fn assert_send<T: Send>() {}
assert_send::<Sender<u64>>();
}
#[test]
fn receiver_is_send() {
fn assert_send<T: Send>() {}
assert_send::<Receiver<u64>>();
}
#[test]
fn cross_thread_try_send() {
let (tx, rx) = test_channel::<u64>(128);
let handle = std::thread::spawn(move || {
for i in 0..100 {
tx.try_send(i).unwrap();
}
});
handle.join().unwrap();
for i in 0..100u64 {
assert_eq!(rx.try_recv().unwrap(), i);
}
}
#[test]
fn stress_sequential() {
let (tx, rx) = test_channel(64);
for i in 0..100_000u64 {
tx.try_send(i).unwrap();
assert_eq!(rx.try_recv().unwrap(), i);
}
}
#[test]
fn sender_drop_while_receiver_alive() {
let (tx, rx) = test_channel::<u32>(4);
tx.try_send(1).unwrap();
tx.try_send(2).unwrap();
drop(tx);
assert_eq!(rx.try_recv().unwrap(), 1);
assert_eq!(rx.try_recv().unwrap(), 2);
assert_eq!(rx.try_recv(), Err(TryRecvError::Closed));
drop(rx);
}
}