use std::cell::UnsafeCell;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicPtr, AtomicU8, AtomicUsize, Ordering};
use std::task::{Context, Poll, Waker};
use super::{RecvError, SendError, TryRecvError, TrySendError};
/// Receiver wake slot that stores a raw task pointer extracted from a
/// runtime-local waker (`crate::waker::task_ptr_from_local_waker`), so that a
/// sender on any thread can wake the receiving task through
/// `crate::cross_wake::wake_task_cross_thread`.
struct RxWakerSlot {
/// Task to wake; only meaningful while `state == STORED`.
task_ptr: AtomicPtr<u8>,
/// Borrowed pointer to the runtime's cross-thread wake context. The pointee
/// must outlive this slot; `channel()` keeps it alive by also storing the
/// owning `Arc` in `Inner::_cross_wake_owner`.
cross_ctx: *const crate::cross_wake::CrossWakeContext,
/// One of `EMPTY`, `STORED`, `REGISTERING`.
state: AtomicU8,
}
/// Slot state: nothing registered. (Shared with `FallbackWaker`.)
const EMPTY: u8 = 0;
/// Slot state: a waker has been fully published and may be claimed by `wake`.
const STORED: u8 = 1;
/// Slot state: the receiver is mid-registration; `wake` fails while in this state.
const REGISTERING: u8 = 2;
// SAFETY (review): `task_ptr` and `state` are atomics, and `cross_ctx` is only
// dereferenced as a shared reference in `wake`. This presumes
// `CrossWakeContext` is itself safe to share across threads — TODO confirm.
unsafe impl Send for RxWakerSlot {}
unsafe impl Sync for RxWakerSlot {}
impl RxWakerSlot {
/// Creates an empty slot that wakes tasks through `cross_ctx`.
fn new(cross_ctx: *const crate::cross_wake::CrossWakeContext) -> Self {
Self {
task_ptr: AtomicPtr::new(std::ptr::null_mut()),
cross_ctx,
state: AtomicU8::new(EMPTY),
}
}
/// Publishes `task_ptr` for a later `wake`, replacing any previous pointer
/// without waking it. Only one registering caller may be active at a time
/// (checked by a debug assertion).
fn register(&self, task_ptr: *mut u8) {
let prev = self.state.swap(REGISTERING, Ordering::Acquire);
debug_assert_ne!(prev, REGISTERING, "concurrent register on RxWakerSlot");
self.task_ptr.store(task_ptr, Ordering::Relaxed);
// Release pairs with the AcqRel CAS in `wake`, making `task_ptr` visible.
self.state.store(STORED, Ordering::Release);
}
/// Registers the task behind `waker` if it is a runtime-local waker.
/// Returns `false`, registering nothing, when it is not.
fn try_register_local(&self, waker: &Waker) -> bool {
crate::waker::task_ptr_from_local_waker(waker).is_some_and(|task_ptr| {
self.register(task_ptr);
true
})
}
/// Claims the stored task (`STORED -> EMPTY`) and wakes it cross-thread.
///
/// Returns `true` only if a non-null task pointer was woken. Returns `false`
/// if nothing is stored or a registration is in progress.
///
/// NOTE(review): a `register` that runs between the CAS and the pointer swap
/// has its pointer consumed here, leaving `STORED` with a null pointer. A
/// later `wake` then returns `false` and callers fall back.
///
/// NOTE(review): `RecvFut` does not clear this slot when it completes, so the
/// pointer may refer to a task that has already finished. Whether
/// `wake_task_cross_thread` tolerates that cannot be seen here — confirm.
fn wake(&self) -> bool {
if self
.state
.compare_exchange(STORED, EMPTY, Ordering::AcqRel, Ordering::Relaxed)
.is_ok()
{
let task_ptr = self.task_ptr.swap(std::ptr::null_mut(), Ordering::Acquire);
if !task_ptr.is_null() {
let ctx = unsafe { &*self.cross_ctx };
unsafe { crate::cross_wake::wake_task_cross_thread(task_ptr, ctx) };
return true;
}
}
false
}
/// Whether a fully published registration is currently present.
fn has_waker(&self) -> bool {
self.state.load(Ordering::Acquire) == STORED
}
}
/// Wait-list entry owned by one `Sender`; its `SendFut` parks on it while the
/// queue is full.
///
/// Ownership of `waker` is handed back and forth through `queued`:
/// - while `queued` is `false`, the node is not linked and only the owning
///   `SendFut` may write `waker`;
/// - while `queued` is `true`, the node is (or is about to be) linked and only
///   the `SenderWaitList` may take `waker`.
struct SenderWakerNode {
    waker: UnsafeCell<Option<Waker>>,
    next: AtomicPtr<SenderWakerNode>,
    queued: AtomicBool,
    // Set when the owning `Sender` is dropped; the list then releases the node
    // without waking it.
    cancelled: AtomicBool,
}
// SAFETY: every field except `waker` is atomic, and `waker` is only accessed by
// whichever side currently owns it under the `queued` handshake above.
unsafe impl Send for SenderWakerNode {}
unsafe impl Sync for SenderWakerNode {}
impl SenderWakerNode {
    /// Creates an unlinked, uncancelled node with no waker.
    fn new() -> Self {
        Self {
            waker: UnsafeCell::new(None),
            next: AtomicPtr::new(std::ptr::null_mut()),
            queued: AtomicBool::new(false),
            cancelled: AtomicBool::new(false),
        }
    }
}

/// Lock-free intrusive stack (Treiber stack) of senders waiting for capacity.
///
/// Every node reachable from `head` owns one strong count of its
/// `Arc<SenderWakerNode>`. The count is taken in `push` and given back when the
/// node is unlinked by `wake_one`, `wake_all`, or `Drop`.
struct SenderWaitList {
    head: AtomicPtr<SenderWakerNode>,
}
impl SenderWaitList {
    /// Creates an empty list.
    fn new() -> Self {
        Self {
            head: AtomicPtr::new(std::ptr::null_mut()),
        }
    }

    /// Links `node` at the head of the list and marks it `queued`.
    ///
    /// The caller must already have stored the node's waker and must ensure the
    /// node is not currently linked; otherwise it would appear twice.
    fn push(&self, node: &Arc<SenderWakerNode>) {
        node.queued.store(true, Ordering::Relaxed);
        // The list keeps this strong count until the node is unlinked.
        let ptr = Arc::into_raw(Arc::clone(node)).cast_mut();
        self.link(ptr);
    }

    /// CAS-loops a node pointer, whose strong count the list already owns,
    /// onto the head of the stack.
    fn link(&self, ptr: *mut SenderWakerNode) {
        let mut head = self.head.load(Ordering::Acquire);
        loop {
            // SAFETY: `ptr` carries a strong count owned by this list, so it is live.
            unsafe { (*ptr).next.store(head, Ordering::Relaxed) };
            match self
                .head
                .compare_exchange_weak(head, ptr, Ordering::AcqRel, Ordering::Acquire)
            {
                Ok(_) => return,
                Err(current) => head = current,
            }
        }
    }

    /// Finishes unlinking a node that was detached from `head`.
    ///
    /// Takes its waker (unless cancelled), hands the node back to its owner by
    /// clearing `queued`, and releases the list's strong count.
    ///
    /// # Safety
    /// `ptr` must have been detached from this list and must not be linked
    /// again by the caller; the list's strong count for it is consumed here.
    unsafe fn release(ptr: *mut SenderWakerNode) -> Option<Waker> {
        let waker = {
            // SAFETY: the list's strong count keeps the node alive in this scope.
            let node = unsafe { &*ptr };
            node.next.store(std::ptr::null_mut(), Ordering::Relaxed);
            let waker = if node.cancelled.load(Ordering::Acquire) {
                None
            } else {
                // SAFETY: `queued` is still `true`, so the owning `SendFut` is
                // not touching `waker`.
                unsafe { (*node.waker.get()).take() }
            };
            // Only after the waker is out may the owner write a new one and
            // push again; Release publishes the take to its Acquire load.
            node.queued.store(false, Ordering::Release);
            waker
        };
        // SAFETY: gives back the count taken in `push`; the node is not used after this.
        unsafe { Arc::decrement_strong_count(ptr) };
        waker
    }

    /// Wakes at most one non-cancelled waiter, most recently pushed first.
    ///
    /// Cancelled nodes met along the way are released. The remaining waiters
    /// are re-linked with `queued` left `true` throughout, so their owners can
    /// neither re-push them nor touch their wakers meanwhile. Returns `true` if
    /// a waker was invoked.
    fn wake_one(&self) -> bool {
        let mut cursor = self.head.swap(std::ptr::null_mut(), Ordering::AcqRel);
        let mut woken = false;
        while !cursor.is_null() {
            // SAFETY: the detached chain owns a strong count for every node in it.
            let next = unsafe { (*cursor).next.load(Ordering::Acquire) };
            let cancelled = unsafe { (*cursor).cancelled.load(Ordering::Acquire) };
            if woken && !cancelled {
                self.link(cursor);
            } else if let Some(waker) = unsafe { Self::release(cursor) } {
                waker.wake();
                woken = true;
            }
            cursor = next;
        }
        woken
    }

    /// Whether any node is currently linked (cancelled nodes included).
    fn has_waiters(&self) -> bool {
        !self.head.load(Ordering::Acquire).is_null()
    }

    /// Unlinks every node, waking each non-cancelled one.
    fn wake_all(&self) {
        let mut cursor = self.head.swap(std::ptr::null_mut(), Ordering::AcqRel);
        while !cursor.is_null() {
            // SAFETY: the detached chain owns a strong count for every node in it.
            let next = unsafe { (*cursor).next.load(Ordering::Acquire) };
            if let Some(waker) = unsafe { Self::release(cursor) } {
                waker.wake();
            }
            cursor = next;
        }
    }
}
impl Drop for SenderWaitList {
    /// Releases nodes that are still linked, for example ones pushed after the
    /// receiver's final `wake_all`, so their `Arc`s and wakers do not leak.
    fn drop(&mut self) {
        let mut cursor = std::mem::replace(self.head.get_mut(), std::ptr::null_mut());
        while !cursor.is_null() {
            // SAFETY: `&mut self` rules out concurrent access, and each linked
            // node owns a strong count.
            let next = unsafe { (*cursor).next.load(Ordering::Relaxed) };
            drop(unsafe { Self::release(cursor) });
            cursor = next;
        }
    }
}
/// Thread-safe slot holding a cloned `Waker`, used when the receiver is polled
/// with a waker that is not a runtime-local one.
///
/// `state` doubles as a tiny lock around `waker`:
/// - `EMPTY` / `STORED`: nobody is accessing `waker`;
/// - `REGISTERING`: the receiver is writing `waker`;
/// - `WAKING`: a sender is taking `waker` out.
///
/// Only the thread that moved `state` into `REGISTERING` or `WAKING` may access
/// `waker`.
struct FallbackWaker {
    state: AtomicU8,
    waker: UnsafeCell<Option<Waker>>,
}
// SAFETY: `waker` is only accessed by the thread holding the `state` lock
// described above; `state` itself is atomic.
unsafe impl Send for FallbackWaker {}
unsafe impl Sync for FallbackWaker {}
impl FallbackWaker {
    /// Lock state held by `wake` while it moves the stored waker out.
    const WAKING: u8 = 3;

    /// Creates an empty slot.
    fn new() -> Self {
        Self {
            state: AtomicU8::new(EMPTY),
            waker: UnsafeCell::new(None),
        }
    }

    /// Stores a clone of `waker`, replacing any previous one.
    ///
    /// If a sender is mid-`wake`, this waits for it to finish taking the old
    /// waker instead of racing on the cell. Only one registering caller may be
    /// active at a time (checked by a debug assertion).
    fn register(&self, waker: &Waker) {
        // Clone outside the critical section so waker code never runs under the lock.
        let fresh = waker.clone();
        let mut current = self.state.load(Ordering::Acquire);
        loop {
            debug_assert_ne!(current, REGISTERING);
            if current == Self::WAKING {
                // A sender holds the slot only for the duration of a `take()`.
                std::hint::spin_loop();
                current = self.state.load(Ordering::Acquire);
                continue;
            }
            match self.state.compare_exchange_weak(
                current,
                REGISTERING,
                Ordering::Acquire,
                Ordering::Acquire,
            ) {
                Ok(_) => break,
                Err(actual) => current = actual,
            }
        }
        // SAFETY: `state == REGISTERING` gives this thread exclusive access.
        let previous = unsafe { (*self.waker.get()).replace(fresh) };
        self.state.store(STORED, Ordering::Release);
        // Drop the replaced waker after releasing the slot.
        drop(previous);
    }

    /// Takes and wakes the stored waker, if any.
    ///
    /// Returns `true` if a waker was invoked. Returns `false` when nothing is
    /// stored, or when a registration or another wake currently holds the slot.
    fn wake(&self) -> bool {
        if self
            .state
            .compare_exchange(STORED, Self::WAKING, Ordering::AcqRel, Ordering::Relaxed)
            .is_err()
        {
            return false;
        }
        // SAFETY: `state == WAKING` gives this thread exclusive access.
        let waker = unsafe { (*self.waker.get()).take() };
        self.state.store(EMPTY, Ordering::Release);
        match waker {
            Some(w) => {
                w.wake();
                true
            }
            None => false,
        }
    }

    /// Whether a fully published waker is currently stored.
    fn has_waker(&self) -> bool {
        self.state.load(Ordering::Acquire) == STORED
    }
}
impl Drop for FallbackWaker {
    fn drop(&mut self) {
        // `&mut self` rules out concurrent register/wake; drop any parked waker.
        *self.waker.get_mut() = None;
    }
}
/// State shared by every `Sender` and the single `Receiver` of one channel.
struct Inner<T> {
    producer: nexus_queue::mpsc::Producer<T>,
    consumer: nexus_queue::mpsc::Consumer<T>,
    // Receiver wake slot for runtime-local wakers.
    rx_slot: RxWakerSlot,
    // Receiver wake slot for any other waker.
    rx_fallback: FallbackWaker,
    // Senders parked on a full queue.
    tx_waiters: SenderWaitList,
    // Keeps the context referenced by `rx_slot`'s raw pointer alive.
    _cross_wake_owner: Arc<crate::cross_wake::CrossWakeContext>,
    // Live `Sender` handles; 0 means the channel is closed for the receiver.
    sender_count: AtomicUsize,
    // Set once the `Receiver` has been dropped.
    rx_closed: AtomicBool,
}
unsafe impl<T: Send> Send for Inner<T> {}
unsafe impl<T: Send> Sync for Inner<T> {}
impl<T> Inner<T> {
    /// Wakes the receiver, trying the task-pointer slot before the fallback waker.
    fn wake_rx(&self) {
        if self.rx_slot.wake() {
            return;
        }
        self.rx_fallback.wake();
    }

    /// Whether either receiver slot currently holds a registration.
    fn has_rx_waker(&self) -> bool {
        let fast_path = self.rx_slot.has_waker();
        fast_path || self.rx_fallback.has_waker()
    }
}
/// Creates a bounded multi-producer, single-consumer channel holding up to
/// `capacity` values.
///
/// # Panics
/// Panics when called outside the runtime, when `capacity` is 0, or when no
/// cross-thread wake context is available.
pub fn channel<T: Send>(capacity: usize) -> (Sender<T>, Receiver<T>) {
    crate::context::assert_in_runtime("mpsc::channel() called outside Runtime::block_on");
    assert!(capacity > 0, "channel capacity must be > 0");
    let cross_ctx = crate::cross_wake::cross_wake_context()
        .expect("mpsc::channel() requires runtime context for cross-thread wake");
    let (producer, consumer) = nexus_queue::mpsc::ring_buffer(capacity);
    // `rx_slot` borrows the context by raw pointer; `_cross_wake_owner` keeps it alive.
    let shared = Arc::new(Inner {
        producer,
        consumer,
        rx_slot: RxWakerSlot::new(Arc::as_ptr(&cross_ctx)),
        rx_fallback: FallbackWaker::new(),
        tx_waiters: SenderWaitList::new(),
        _cross_wake_owner: cross_ctx,
        sender_count: AtomicUsize::new(1),
        rx_closed: AtomicBool::new(false),
    });
    let receiver = Receiver {
        inner: Arc::clone(&shared),
    };
    let sender = Sender {
        inner: shared,
        wake_node: Arc::new(SenderWakerNode::new()),
    };
    (sender, receiver)
}
/// Sending half of the channel. Each clone gets its own wait-list node.
pub struct Sender<T> {
    inner: Arc<Inner<T>>,
    // This handle's entry in `Inner::tx_waiters`, reused across `send` calls.
    wake_node: Arc<SenderWakerNode>,
}
impl<T: Send> Sender<T> {
    /// Returns a future that enqueues `value`, waiting for capacity if needed.
    pub fn send(&self, value: T) -> SendFut<'_, T> {
        SendFut {
            value: Some(value),
            sender: self,
        }
    }

    /// Enqueues `value` without waiting.
    ///
    /// # Errors
    /// Returns `TrySendError::Closed(value)` if the receiver is gone, or
    /// `TrySendError::Full(value)` if the queue has no free slot.
    pub fn try_send(&self, value: T) -> Result<(), TrySendError<T>> {
        let inner = &*self.inner;
        if inner.rx_closed.load(Ordering::Acquire) {
            return Err(TrySendError::Closed(value));
        }
        if let Err(nexus_queue::Full(rejected)) = inner.producer.push(value) {
            return Err(TrySendError::Full(rejected));
        }
        if inner.has_rx_waker() {
            inner.wake_rx();
        }
        Ok(())
    }
}
impl<T> Clone for Sender<T> {
    /// Registers another live sender and gives it a fresh wait-list node.
    fn clone(&self) -> Self {
        self.inner.sender_count.fetch_add(1, Ordering::Relaxed);
        let inner = Arc::clone(&self.inner);
        Self {
            wake_node: Arc::new(SenderWakerNode::new()),
            inner,
        }
    }
}
impl<T> Drop for Sender<T> {
    /// Cancels this handle's wait-list node and, if it was the last sender,
    /// wakes the receiver so it can observe the closed channel.
    fn drop(&mut self) {
        self.wake_node.cancelled.store(true, Ordering::Release);
        let was_last = self.inner.sender_count.fetch_sub(1, Ordering::AcqRel) == 1;
        if was_last {
            self.inner.wake_rx();
        }
    }
}
unsafe impl<T: Send> Send for Sender<T> {}
unsafe impl<T: Send> Sync for Sender<T> {}
/// Future returned by `Sender::send`.
///
/// Resolves to `Ok(())` once the value is enqueued, or to
/// `Err(SendError(value))` if the receiver has been dropped.
pub struct SendFut<'a, T> {
    sender: &'a Sender<T>,
    // `None` only after the future has completed.
    value: Option<T>,
}
impl<T: Send> Future for SendFut<'_, T> {
    type Output = Result<(), SendError<T>>;
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // SAFETY: nothing in `SendFut` is structurally pinned; `value` is only
        // moved in and out by value, never exposed as `Pin<&mut T>`.
        let this = unsafe { self.get_unchecked_mut() };
        let inner = &this.sender.inner;
        let notify_receiver = || {
            if inner.has_rx_waker() {
                inner.wake_rx();
            }
        };
        let value = this.value.take().expect("polled after completion");
        if inner.rx_closed.load(Ordering::Acquire) {
            return Poll::Ready(Err(SendError(value)));
        }
        let value = match inner.producer.push(value) {
            Ok(()) => {
                notify_receiver();
                return Poll::Ready(Ok(()));
            }
            Err(nexus_queue::Full(value)) => value,
        };

        // Queue is full, so park this sender's node on the wait list. The CAS
        // makes us the sole writer of `waker`, even if two `SendFut`s from the
        // same `Sender` race here. If the node is already queued, its earlier
        // waker stays in place.
        let node = &this.sender.wake_node;
        if node
            .queued
            .compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire)
            .is_ok()
        {
            // SAFETY: winning the CAS gives exclusive access to `waker`; the
            // node is not linked yet, so the receiver cannot be reading it.
            unsafe { *node.waker.get() = Some(cx.waker().clone()) };
            inner.tx_waiters.push(node);
        }

        // Re-check after registering. Between the failed push and the
        // registration, the receiver may have freed a slot (finding no waiter to
        // wake) or been dropped (after draining the wait list). Without this
        // retry the future could stay pending forever.
        // NOTE(review): the receiver side pops and then checks `has_waiters`
        // with acquire/release ordering only; a SeqCst fence on both sides may
        // be needed to fully rule out store-load reordering — confirm.
        if inner.rx_closed.load(Ordering::Acquire) {
            return Poll::Ready(Err(SendError(value)));
        }
        match inner.producer.push(value) {
            Ok(()) => {
                // The node may stay linked; the next `wake_one` then wakes this
                // task spuriously.
                notify_receiver();
                Poll::Ready(Ok(()))
            }
            Err(nexus_queue::Full(value)) => {
                this.value = Some(value);
                Poll::Pending
            }
        }
    }
}
unsafe impl<T: Send> Send for SendFut<'_, T> {}
/// Receiving half of the channel. There is exactly one per channel.
pub struct Receiver<T> {
    inner: Arc<Inner<T>>,
}
impl<T: Send> Receiver<T> {
    /// Returns a future that resolves to the next value, or to `RecvError`
    /// once every sender is gone and the queue is drained.
    pub fn recv(&self) -> RecvFut<'_, T> {
        RecvFut { receiver: self }
    }

    /// Dequeues a value without waiting, releasing one blocked sender if any.
    ///
    /// # Errors
    /// Returns `TryRecvError::Closed` when the queue is empty and no senders
    /// remain, and `TryRecvError::Empty` when it is merely empty.
    pub fn try_recv(&self) -> Result<T, TryRecvError> {
        let inner = &*self.inner;
        let Some(value) = inner.consumer.pop() else {
            let no_senders = inner.sender_count.load(Ordering::Acquire) == 0;
            return Err(if no_senders {
                TryRecvError::Closed
            } else {
                TryRecvError::Empty
            });
        };
        if inner.tx_waiters.has_waiters() {
            inner.tx_waiters.wake_one();
        }
        Ok(value)
    }
}
impl<T> Drop for Receiver<T> {
    /// Marks the channel closed for senders and wakes every parked sender.
    fn drop(&mut self) {
        let inner = &*self.inner;
        inner.rx_closed.store(true, Ordering::Release);
        inner.tx_waiters.wake_all();
    }
}
unsafe impl<T: Send> Send for Receiver<T> {}
/// Future returned by `Receiver::recv`.
pub struct RecvFut<'a, T> {
    receiver: &'a Receiver<T>,
}
impl<T: Send> Future for RecvFut<'_, T> {
    type Output = Result<T, RecvError>;
    /// Checks the queue, registers the waker, then checks again, so a value
    /// pushed during registration is not missed.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let inner = &self.receiver.inner;
        // Pops one value and, if a sender is parked, hands it the freed slot.
        let pop_and_release = || {
            let value = inner.consumer.pop()?;
            if inner.tx_waiters.has_waiters() {
                inner.tx_waiters.wake_one();
            }
            Some(value)
        };
        let all_senders_gone = || inner.sender_count.load(Ordering::Acquire) == 0;

        if let Some(value) = pop_and_release() {
            return Poll::Ready(Ok(value));
        }
        if all_senders_gone() {
            return Poll::Ready(Err(RecvError));
        }

        let waker = cx.waker();
        if !inner.rx_slot.try_register_local(waker) {
            inner.rx_fallback.register(waker);
        }

        match pop_and_release() {
            Some(value) => Poll::Ready(Ok(value)),
            None if all_senders_gone() => Poll::Ready(Err(RecvError)),
            None => Poll::Pending,
        }
    }
}
unsafe impl<T: Send> Send for RecvFut<'_, T> {}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a channel wired to a private `CrossWakeContext`, so tests can run
    /// without a `Runtime`.
    fn test_channel<T: Send>(capacity: usize) -> (Sender<T>, Receiver<T>) {
        let poll = mio::Poll::new().unwrap();
        let mio_waker = Arc::new(mio::Waker::new(poll.registry(), mio::Token(usize::MAX)).unwrap());
        let cross_ctx = Arc::new(crate::cross_wake::CrossWakeContext {
            queue: crate::cross_wake::CrossWakeQueue::new(),
            mio_waker,
            parked: AtomicBool::new(false),
        });
        let (producer, consumer) = nexus_queue::mpsc::ring_buffer(capacity);
        let rx_slot = RxWakerSlot::new(Arc::as_ptr(&cross_ctx));
        let inner = Arc::new(Inner {
            producer,
            consumer,
            rx_slot,
            rx_fallback: FallbackWaker::new(),
            tx_waiters: SenderWaitList::new(),
            _cross_wake_owner: cross_ctx,
            sender_count: AtomicUsize::new(1),
            rx_closed: AtomicBool::new(false),
        });
        (
            Sender {
                inner: inner.clone(),
                wake_node: Arc::new(SenderWakerNode::new()),
            },
            Receiver { inner },
        )
    }

    /// Waker that counts how many times it has been woken.
    struct CountingWaker(AtomicUsize);

    impl std::task::Wake for CountingWaker {
        fn wake(self: Arc<Self>) {
            self.0.fetch_add(1, Ordering::SeqCst);
        }
    }

    /// Returns a counting waker together with a handle to read its count.
    fn counting_waker() -> (Arc<CountingWaker>, Waker) {
        let counter = Arc::new(CountingWaker(AtomicUsize::new(0)));
        let waker = Waker::from(Arc::clone(&counter));
        (counter, waker)
    }

    #[test]
    fn send_recv_single() {
        let (tx, rx) = test_channel::<u32>(4);
        tx.try_send(1).unwrap();
        tx.try_send(2).unwrap();
        tx.try_send(3).unwrap();
        assert_eq!(rx.try_recv().unwrap(), 1);
        assert_eq!(rx.try_recv().unwrap(), 2);
        assert_eq!(rx.try_recv().unwrap(), 3);
        assert_eq!(rx.try_recv(), Err(TryRecvError::Empty));
    }

    #[test]
    fn fifo_ordering() {
        let (tx, rx) = test_channel(8);
        for i in 0..8u32 {
            tx.try_send(i).unwrap();
        }
        for i in 0..8u32 {
            assert_eq!(rx.try_recv().unwrap(), i);
        }
    }

    #[test]
    fn try_send_full() {
        let (tx, rx) = test_channel(2);
        tx.try_send(1u32).unwrap();
        tx.try_send(2).unwrap();
        let err = tx.try_send(3).unwrap_err();
        assert!(err.is_full());
        assert_eq!(err.into_inner(), 3);
        assert_eq!(rx.try_recv().unwrap(), 1);
        tx.try_send(3).unwrap();
    }

    #[test]
    fn try_recv_empty() {
        let (tx, rx) = test_channel::<u32>(4);
        assert_eq!(rx.try_recv(), Err(TryRecvError::Empty));
        tx.try_send(1).unwrap();
        assert_eq!(rx.try_recv().unwrap(), 1);
        assert_eq!(rx.try_recv(), Err(TryRecvError::Empty));
    }

    #[test]
    fn sender_drop_signals_closed() {
        let (tx, rx) = test_channel::<u32>(4);
        tx.try_send(42).unwrap();
        drop(tx);
        assert_eq!(rx.try_recv().unwrap(), 42);
        assert_eq!(rx.try_recv(), Err(TryRecvError::Closed));
    }

    #[test]
    fn receiver_drop_signals_closed() {
        let (tx, rx) = test_channel::<u32>(4);
        drop(rx);
        let err = tx.try_send(1).unwrap_err();
        assert!(err.is_closed());
    }

    #[test]
    fn multiple_senders() {
        let (tx1, rx) = test_channel(8);
        let tx2 = tx1.clone();
        tx1.try_send(1u32).unwrap();
        tx2.try_send(2).unwrap();
        tx1.try_send(3).unwrap();
        assert_eq!(rx.try_recv().unwrap(), 1);
        assert_eq!(rx.try_recv().unwrap(), 2);
        assert_eq!(rx.try_recv().unwrap(), 3);
    }

    #[test]
    fn sender_is_send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<Sender<u64>>();
    }

    #[test]
    fn receiver_is_send() {
        fn assert_send<T: Send>() {}
        assert_send::<Receiver<u64>>();
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn cross_thread_try_send() {
        let (tx, rx) = test_channel::<u64>(128);
        let handle = std::thread::spawn(move || {
            for i in 0..100 {
                tx.try_send(i).unwrap();
            }
        });
        handle.join().unwrap();
        for i in 0..100u64 {
            assert_eq!(rx.try_recv().unwrap(), i);
        }
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn cross_thread_multiple_producers() {
        let (tx, rx) = test_channel::<u64>(512);
        let handles: Vec<_> = (0..4u64)
            .map(|id| {
                let tx = tx.clone();
                std::thread::spawn(move || {
                    for i in 0..100 {
                        tx.try_send(id * 1000 + i).unwrap();
                    }
                })
            })
            .collect();
        drop(tx);
        for h in handles {
            h.join().unwrap();
        }
        let mut received = Vec::new();
        while let Ok(v) = rx.try_recv() {
            received.push(v);
        }
        assert_eq!(received.len(), 400);
    }

    #[test]
    fn stress_sequential() {
        let (tx, rx) = test_channel(64);
        for i in 0..100_000u64 {
            tx.try_send(i).unwrap();
            assert_eq!(rx.try_recv().unwrap(), i);
        }
    }

    /// A sender parks on a full queue, then is dropped while its node is still
    /// linked. The receiver must release the cancelled node without waking it.
    #[test]
    fn sender_drop_while_queued_in_waiter_list() {
        let (tx1, rx) = test_channel::<u32>(2);
        let tx2 = tx1.clone();
        tx1.try_send(1).unwrap();
        tx1.try_send(2).unwrap();
        let (counter, waker) = counting_waker();
        let mut cx = Context::from_waker(&waker);
        {
            let mut fut = Box::pin(tx2.send(99));
            assert!(fut.as_mut().poll(&mut cx).is_pending());
        }
        assert!(rx.inner.tx_waiters.has_waiters());
        drop(tx2);
        assert_eq!(rx.try_recv().unwrap(), 1);
        assert_eq!(counter.0.load(Ordering::SeqCst), 0);
        assert!(!rx.inner.tx_waiters.has_waiters());
        tx1.try_send(3).unwrap();
        assert_eq!(rx.try_recv().unwrap(), 2);
        assert_eq!(rx.try_recv().unwrap(), 3);
    }

    /// A sender parked on a full queue is woken by `try_recv` and then succeeds.
    #[test]
    fn pending_send_woken_by_recv() {
        let (tx, rx) = test_channel::<u32>(2);
        tx.try_send(1).unwrap();
        tx.try_send(2).unwrap();
        let (counter, waker) = counting_waker();
        let mut cx = Context::from_waker(&waker);
        let mut fut = Box::pin(tx.send(3));
        assert!(fut.as_mut().poll(&mut cx).is_pending());
        assert_eq!(rx.try_recv().unwrap(), 1);
        assert_eq!(counter.0.load(Ordering::SeqCst), 1);
        assert!(matches!(fut.as_mut().poll(&mut cx), Poll::Ready(Ok(()))));
        drop(fut);
        assert_eq!(rx.try_recv().unwrap(), 2);
        assert_eq!(rx.try_recv().unwrap(), 3);
    }

    /// Dropping the receiver wakes a parked sender, whose next poll hands the
    /// value back in `SendError`.
    #[test]
    fn pending_send_fails_after_receiver_drop() {
        let (tx, rx) = test_channel::<u32>(2);
        tx.try_send(1).unwrap();
        tx.try_send(2).unwrap();
        let (counter, waker) = counting_waker();
        let mut cx = Context::from_waker(&waker);
        let mut fut = Box::pin(tx.send(3));
        assert!(fut.as_mut().poll(&mut cx).is_pending());
        drop(rx);
        assert_eq!(counter.0.load(Ordering::SeqCst), 1);
        assert!(matches!(
            fut.as_mut().poll(&mut cx),
            Poll::Ready(Err(SendError(3)))
        ));
    }

    #[test]
    fn multiple_senders_dropped_then_receiver_dropped() {
        let (tx1, rx) = test_channel::<u32>(1);
        let tx2 = tx1.clone();
        let tx3 = tx1.clone();
        tx1.try_send(1).unwrap();
        drop(tx1);
        drop(tx2);
        drop(tx3);
        assert_eq!(rx.try_recv().unwrap(), 1);
        assert_eq!(rx.try_recv(), Err(TryRecvError::Closed));
        drop(rx);
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn cross_thread_sender_drop() {
        let (tx, rx) = test_channel::<u64>(128);
        let handles: Vec<_> = (0..8)
            .map(|_| {
                let tx = tx.clone();
                std::thread::spawn(move || {
                    for i in 0..100 {
                        let _ = tx.try_send(i);
                    }
                })
            })
            .collect();
        drop(tx);
        for h in handles {
            h.join().unwrap();
        }
        while rx.try_recv().is_ok() {}
        assert_eq!(rx.try_recv(), Err(TryRecvError::Closed));
    }
}