use std::cell::UnsafeCell;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicU8, AtomicUsize, Ordering};
use std::task::{Poll, Waker};
use std::ops::{Deref, DerefMut};
/// Waker-slot state: nothing is registered.
const EMPTY: u8 = 0;
/// Waker-slot state: a registration is present and may be claimed by a waking thread.
const STORED: u8 = 1;
/// Waker-slot state: the receiver is in the middle of writing a new registration.
const REGISTERING: u8 = 2;
/// Receiver wake slot for wakers that belong to this runtime.
///
/// Instead of cloning a `Waker`, it stores the raw task pointer extracted by
/// `crate::waker::task_ptr_from_local_waker` and wakes it via
/// `crate::cross_wake::wake_task_cross_thread`.
struct RxWakerSlot {
    /// Registered task, or null.
    task_ptr: std::sync::atomic::AtomicPtr<u8>,
    /// Context handed to the cross-thread wake path. In `channel()` this points into the
    /// `Arc` that `Inner::_cross_wake_owner` keeps alive next to this slot.
    cross_ctx: *const crate::cross_wake::CrossWakeContext,
    /// One of `EMPTY`, `STORED`, `REGISTERING`.
    state: AtomicU8,
}

// SAFETY: `cross_ctx` is only ever read through a shared reference.
// NOTE(review): this relies on `CrossWakeContext` being safe to share across threads — confirm.
unsafe impl Send for RxWakerSlot {}
unsafe impl Sync for RxWakerSlot {}
impl RxWakerSlot {
    /// Creates an empty slot that will wake tasks through `cross_ctx`.
    fn new(cross_ctx: *const crate::cross_wake::CrossWakeContext) -> Self {
        Self {
            state: AtomicU8::new(EMPTY),
            task_ptr: std::sync::atomic::AtomicPtr::default(),
            cross_ctx,
        }
    }

    /// Registers the task behind `waker` if it is a runtime-local waker.
    ///
    /// Returns `false`, leaving the slot untouched, when
    /// `crate::waker::task_ptr_from_local_waker` does not recognise `waker`.
    fn try_register_local(&self, waker: &Waker) -> bool {
        let Some(task) = crate::waker::task_ptr_from_local_waker(waker) else {
            return false;
        };
        let previous = self.state.swap(REGISTERING, Ordering::Acquire);
        debug_assert_ne!(previous, REGISTERING);
        // Relaxed suffices: the Release store of STORED below publishes the pointer.
        self.task_ptr.store(task, Ordering::Relaxed);
        self.state.store(STORED, Ordering::Release);
        true
    }

    /// Claims the registered task, if any, and wakes it via the cross-thread wake path.
    ///
    /// Returns `true` only when a non-null task pointer was actually woken.
    fn wake(&self) -> bool {
        let claimed = self
            .state
            .compare_exchange(STORED, EMPTY, Ordering::AcqRel, Ordering::Relaxed)
            .is_ok();
        if !claimed {
            return false;
        }
        let task = self.task_ptr.swap(std::ptr::null_mut(), Ordering::Acquire);
        if task.is_null() {
            return false;
        }
        // SAFETY: `cross_ctx` was supplied by the owner of this slot (in `channel()`, from the
        // `Arc` kept in `Inner::_cross_wake_owner`). NOTE(review): the validity requirements
        // for `task` are defined by `wake_task_cross_thread` — confirm there.
        unsafe {
            let ctx = &*self.cross_ctx;
            crate::cross_wake::wake_task_cross_thread(task, ctx);
        }
        true
    }

    /// Whether a task registration is currently present.
    fn has_waker(&self) -> bool {
        matches!(self.state.load(Ordering::Acquire), STORED)
    }
}
/// Receiver wake slot for wakers that `RxWakerSlot` cannot handle; keeps a cloned `Waker`.
struct FallbackWaker {
    /// Slot state machine; guards access to `waker`.
    state: AtomicU8,
    /// The registered waker, if any.
    waker: UnsafeCell<Option<Waker>>,
}

// SAFETY: access to the `UnsafeCell` is meant to be serialized through `state` (see the
// inherent impl); `Waker` itself is `Send + Sync`.
unsafe impl Send for FallbackWaker {}
unsafe impl Sync for FallbackWaker {}
impl FallbackWaker {
    /// Transient state held by a waking thread while it moves the waker out of the cell.
    ///
    /// Publishing `EMPTY` before the `take` would let a concurrent `register` write the cell
    /// while `wake` is still reading it; this state closes that window.
    const WAKING: u8 = 3;

    /// Creates an empty slot.
    fn new() -> Self {
        Self {
            state: AtomicU8::new(EMPTY),
            waker: UnsafeCell::new(None),
        }
    }

    /// Stores a clone of `waker`, replacing any previously registered one.
    ///
    /// Registrations must not race each other (the only caller, `RecvFut::poll`, holds
    /// `&mut Receiver`); they may race [`wake`](Self::wake) running on any thread.
    fn register(&self, waker: &Waker) {
        let mut current = self.state.load(Ordering::Acquire);
        loop {
            debug_assert_ne!(current, REGISTERING);
            if current == Self::WAKING {
                // Another thread is taking the old waker out of the cell, so we must not
                // write it. Wake the new waker instead; the task re-polls and registers again.
                waker.wake_by_ref();
                return;
            }
            // Acquire pairs with the Release store of EMPTY in `wake`, so the cell write
            // below happens after any earlier `take`.
            match self.state.compare_exchange_weak(
                current,
                REGISTERING,
                Ordering::Acquire,
                Ordering::Acquire,
            ) {
                Ok(_) => break,
                Err(actual) => current = actual,
            }
        }
        // SAFETY: `state` is REGISTERING, which `wake` never claims, so this thread has
        // exclusive access to the cell until the Release store below.
        unsafe { *self.waker.get() = Some(waker.clone()) };
        self.state.store(STORED, Ordering::Release);
    }

    /// Takes the registered waker, if any, and wakes it.
    ///
    /// Returns `true` only if a waker was present and woken. May be called concurrently from
    /// many threads; only the one that wins the STORED -> WAKING transition touches the cell.
    fn wake(&self) -> bool {
        if self
            .state
            .compare_exchange(STORED, Self::WAKING, Ordering::AcqRel, Ordering::Relaxed)
            .is_err()
        {
            return false;
        }
        // SAFETY: we own the WAKING state; `register` backs off and other wakers fail the CAS.
        let waker = unsafe { (*self.waker.get()).take() };
        // Hand the slot back before running the waker so a re-registration is not deflected.
        self.state.store(EMPTY, Ordering::Release);
        match waker {
            Some(w) => {
                w.wake();
                true
            }
            None => false,
        }
    }

    /// Whether a waker is currently registered.
    fn has_waker(&self) -> bool {
        self.state.load(Ordering::Acquire) == STORED
    }
}
impl Drop for FallbackWaker {
    /// Releases any waker still registered when the slot is destroyed.
    fn drop(&mut self) {
        let _ = self.waker.get_mut().take();
    }
}
/// Per-sender entry in the sender wait list.
struct SenderWakerNode {
    /// Waker to notify when space may be available. The owning `ClaimFut` writes it only
    /// while `queued` is false; list traversals move it out.
    waker: UnsafeCell<Option<Waker>>,
    /// Next node in the intrusive list.
    next: std::sync::atomic::AtomicPtr<SenderWakerNode>,
    /// Whether the node is currently linked into (or held by a traversal of) the list.
    queued: AtomicBool,
    /// Set when the owning `Sender` is dropped; cancelled nodes are unlinked without waking.
    cancelled: AtomicBool,
}

// SAFETY: `waker` is handed back and forth via the `queued` flag; all other fields are atomic.
unsafe impl Send for SenderWakerNode {}
unsafe impl Sync for SenderWakerNode {}
impl SenderWakerNode {
    /// Creates an unlinked, uncancelled node with no waker.
    fn new() -> Self {
        Self {
            next: std::sync::atomic::AtomicPtr::default(),
            waker: UnsafeCell::new(None),
            cancelled: AtomicBool::default(),
            queued: AtomicBool::default(),
        }
    }
}
/// Lock-free intrusive stack of senders waiting for buffer space.
///
/// Each linked node carries one strong `Arc` count owned by the list.
struct SenderWaitList {
    /// Most recently pushed node, or null when nobody is waiting.
    head: std::sync::atomic::AtomicPtr<SenderWakerNode>,
}
impl SenderWaitList {
fn new() -> Self {
Self {
head: std::sync::atomic::AtomicPtr::new(std::ptr::null_mut()),
}
}
fn push(&self, node: &Arc<SenderWakerNode>) {
let ptr = Arc::as_ptr(node).cast_mut();
std::mem::forget(Arc::clone(node));
unsafe { (*ptr).queued.store(true, Ordering::Relaxed) };
loop {
let head = self.head.load(Ordering::Acquire);
unsafe { (*ptr).next.store(head, Ordering::Relaxed) };
if self
.head
.compare_exchange_weak(head, ptr, Ordering::AcqRel, Ordering::Relaxed)
.is_ok()
{
break;
}
}
}
fn wake_one(&self) -> bool {
let head = self.head.swap(std::ptr::null_mut(), Ordering::AcqRel);
if head.is_null() {
return false;
}
let mut cursor = head;
let mut woken = false;
while !cursor.is_null() {
let next = unsafe { (*cursor).next.load(Ordering::Acquire) };
let cancelled = unsafe { (*cursor).cancelled.load(Ordering::Acquire) };
unsafe {
(*cursor).queued.store(false, Ordering::Release);
(*cursor)
.next
.store(std::ptr::null_mut(), Ordering::Relaxed);
}
if !cancelled && !woken {
let waker = unsafe { (*cursor).waker.get().read() };
unsafe { (*cursor).waker.get().write(None) };
unsafe { Arc::decrement_strong_count(cursor) };
if let Some(w) = waker {
w.wake();
woken = true;
}
} else if !cancelled {
loop {
let cur_head = self.head.load(Ordering::Acquire);
unsafe { (*cursor).next.store(cur_head, Ordering::Relaxed) };
unsafe { (*cursor).queued.store(true, Ordering::Relaxed) };
if self
.head
.compare_exchange_weak(
cur_head,
cursor,
Ordering::AcqRel,
Ordering::Relaxed,
)
.is_ok()
{
break;
}
}
} else {
unsafe { Arc::decrement_strong_count(cursor) };
}
cursor = next;
}
woken
}
fn has_waiters(&self) -> bool {
!self.head.load(Ordering::Acquire).is_null()
}
fn wake_all(&self) {
let mut node = self.head.swap(std::ptr::null_mut(), Ordering::AcqRel);
while !node.is_null() {
let next = unsafe { (*node).next.load(Ordering::Acquire) };
let cancelled = unsafe { (*node).cancelled.load(Ordering::Acquire) };
unsafe {
(*node).next.store(std::ptr::null_mut(), Ordering::Relaxed);
(*node).queued.store(false, Ordering::Release);
}
if !cancelled {
let waker = unsafe { (*node).waker.get().read() };
unsafe { (*node).waker.get().write(None) };
if let Some(w) = waker {
w.wake();
}
}
unsafe { Arc::decrement_strong_count(node) };
node = next;
}
}
}
/// State shared by every sender and the receiver of one channel.
struct Inner {
    /// Receiver registration for runtime-local wakers.
    rx_slot: RxWakerSlot,
    /// Receiver registration for any other waker.
    rx_fallback: FallbackWaker,
    /// Senders waiting for buffer space.
    tx_waiters: SenderWaitList,
    /// Keeps alive the context that `rx_slot` holds a raw pointer to.
    _cross_wake_owner: Arc<crate::cross_wake::CrossWakeContext>,
    /// Number of live `Sender` handles.
    sender_count: AtomicUsize,
    /// Set once the receiver has been dropped.
    rx_closed: AtomicBool,
}

// SAFETY: shared state is accessed only through atomics and the wake-slot / wait-list
// protocols. NOTE(review): also assumes `CrossWakeContext` is thread-safe — confirm.
unsafe impl Send for Inner {}
unsafe impl Sync for Inner {}
impl Inner {
    /// Wakes the receiver task, whichever slot it registered in.
    ///
    /// Both slots are woken. A `RecvFut` polled once with a runtime-local waker and later with
    /// a foreign one (or the other way round) leaves a stale entry in the slot it stopped
    /// using. Stopping at the first successful wake could then wake only that stale entry and
    /// skip the live registration. Waking a stale entry is just a spurious wake-up, which
    /// futures must tolerate.
    fn wake_rx(&self) {
        let _ = self.rx_slot.wake();
        let _ = self.rx_fallback.wake();
    }

    /// Whether either receiver slot currently holds a registration.
    fn has_rx_waker(&self) -> bool {
        self.rx_slot.has_waker() || self.rx_fallback.has_waker()
    }
}
pub struct WriteClaim<'a> {
inner: nexus_logbuf::queue::mpsc::WriteClaim<'a>,
notify: &'a Inner,
}
impl WriteClaim<'_> {
pub fn commit(self) {
let notify = self.notify;
self.inner.commit();
if notify.has_rx_waker() {
notify.wake_rx();
}
}
pub fn len(&self) -> usize {
self.inner.len()
}
pub fn is_empty(&self) -> bool {
self.inner.is_empty()
}
}
impl Deref for WriteClaim<'_> {
type Target = [u8];
fn deref(&self) -> &[u8] {
&self.inner
}
}
impl DerefMut for WriteClaim<'_> {
fn deref_mut(&mut self) -> &mut [u8] {
&mut self.inner
}
}
pub struct ReadClaim<'a> {
inner: nexus_logbuf::queue::mpsc::ReadClaim<'a>,
notify: &'a Inner,
}
impl ReadClaim<'_> {
pub fn len(&self) -> usize {
self.inner.len()
}
pub fn is_empty(&self) -> bool {
self.inner.is_empty()
}
}
impl Deref for ReadClaim<'_> {
type Target = [u8];
fn deref(&self) -> &[u8] {
&self.inner
}
}
impl Drop for ReadClaim<'_> {
fn drop(&mut self) {
if self.notify.tx_waiters.has_waiters() {
self.notify.tx_waiters.wake_one();
}
}
}
#[derive(Debug)]
pub enum ClaimError {
Closed,
TooLarge,
}
impl std::fmt::Display for ClaimError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Closed => f.write_str("byte channel closed"),
Self::TooLarge => f.write_str("message exceeds buffer capacity"),
}
}
}
impl std::error::Error for ClaimError {}
#[derive(Debug)]
pub struct RecvError;
impl std::fmt::Display for RecvError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str("byte channel closed")
}
}
impl std::error::Error for RecvError {}
pub fn channel(capacity: usize) -> (Sender, Receiver) {
crate::context::assert_in_runtime("mpsc_bytes::channel() called outside Runtime::block_on");
let cross_ctx = crate::cross_wake::cross_wake_context()
.expect("mpsc_bytes::channel() requires runtime context");
let (producer, consumer) = nexus_logbuf::queue::mpsc::new(capacity);
let rx_slot = RxWakerSlot::new(Arc::as_ptr(&cross_ctx));
let inner = Arc::new(Inner {
rx_slot,
rx_fallback: FallbackWaker::new(),
tx_waiters: SenderWaitList::new(),
_cross_wake_owner: cross_ctx,
sender_count: AtomicUsize::new(1),
rx_closed: AtomicBool::new(false),
});
(
Sender {
producer,
inner: inner.clone(),
wake_node: Arc::new(SenderWakerNode::new()),
},
Receiver { consumer, inner },
)
}
pub struct Sender {
producer: nexus_logbuf::queue::mpsc::Producer,
inner: Arc<Inner>,
wake_node: Arc<SenderWakerNode>,
}
impl Sender {
pub fn claim(&mut self, len: usize) -> ClaimFut<'_> {
ClaimFut { sender: self, len }
}
pub fn try_claim(&mut self, len: usize) -> Result<WriteClaim<'_>, nexus_logbuf::TryClaimError> {
let inner_claim = self.producer.try_claim(len)?;
Ok(WriteClaim {
inner: inner_claim,
notify: &self.inner,
})
}
}
impl Clone for Sender {
fn clone(&self) -> Self {
self.inner.sender_count.fetch_add(1, Ordering::Relaxed);
Self {
producer: self.producer.clone(),
inner: self.inner.clone(),
wake_node: Arc::new(SenderWakerNode::new()),
}
}
}
impl Drop for Sender {
fn drop(&mut self) {
self.wake_node.cancelled.store(true, Ordering::Release);
if self.inner.sender_count.fetch_sub(1, Ordering::AcqRel) == 1 {
self.inner.wake_rx();
}
}
}
unsafe impl Send for Sender {}
pub struct ClaimFut<'a> {
sender: &'a mut Sender,
len: usize,
}
impl<'a> Future for ClaimFut<'a> {
type Output = Result<WriteClaim<'a>, ClaimError>;
fn poll(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
let this = unsafe { &mut *std::pin::Pin::into_inner_unchecked(self) };
let sender: &'a mut Sender = unsafe { &mut *(this.sender as *mut Sender) };
if sender.inner.rx_closed.load(Ordering::Acquire) {
return Poll::Ready(Err(ClaimError::Closed));
}
if this.len > sender.producer.capacity() {
return Poll::Ready(Err(ClaimError::TooLarge));
}
match sender.producer.try_claim(this.len) {
Ok(inner_claim) => Poll::Ready(Ok(WriteClaim {
inner: inner_claim,
notify: &sender.inner,
})),
Err(nexus_logbuf::TryClaimError::Full | nexus_logbuf::TryClaimError::ZeroLength) => {
let node = &sender.wake_node;
if !node.queued.load(Ordering::Acquire) {
unsafe { *node.waker.get() = Some(cx.waker().clone()) };
sender.inner.tx_waiters.push(node);
}
Poll::Pending
}
}
}
}
unsafe impl Send for ClaimFut<'_> {}
pub struct Receiver {
consumer: nexus_logbuf::queue::mpsc::Consumer,
inner: Arc<Inner>,
}
impl Receiver {
pub fn recv(&mut self) -> RecvFut<'_> {
RecvFut { receiver: self }
}
pub fn try_recv(&mut self) -> Option<ReadClaim<'_>> {
let inner_claim = self.consumer.try_claim()?;
Some(ReadClaim {
inner: inner_claim,
notify: &self.inner,
})
}
}
pub struct RecvFut<'a> {
receiver: &'a mut Receiver,
}
impl<'a> Future for RecvFut<'a> {
type Output = Result<ReadClaim<'a>, RecvError>;
fn poll(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
let this = unsafe { &mut *std::pin::Pin::into_inner_unchecked(self) };
let receiver: &'a mut Receiver = unsafe { &mut *(this.receiver as *mut Receiver) };
if let Some(inner_claim) = receiver.consumer.try_claim() {
return Poll::Ready(Ok(ReadClaim {
inner: inner_claim,
notify: &receiver.inner,
}));
}
if receiver.inner.sender_count.load(Ordering::Acquire) == 0 {
return Poll::Ready(Err(RecvError));
}
if !receiver.inner.rx_slot.try_register_local(cx.waker()) {
receiver.inner.rx_fallback.register(cx.waker());
}
Poll::Pending
}
}
unsafe impl Send for RecvFut<'_> {}
impl Drop for Receiver {
    /// Marks the channel closed for senders, then wakes every sender waiting for space so it
    /// can observe the closure.
    fn drop(&mut self) {
        let shared = &*self.inner;
        shared.rx_closed.store(true, Ordering::Release);
        shared.tx_waiters.wake_all();
    }
}

// SAFETY: NOTE(review): asserts that nexus_logbuf's `Consumer` may move across threads — confirm.
unsafe impl Send for Receiver {}
#[cfg(test)]
mod tests {
    use super::*;

    // Builds a channel with a private cross-wake context so tests can run outside a runtime.
    fn test_channel(capacity: usize) -> (Sender, Receiver) {
        let poll = mio::Poll::new().unwrap();
        let mio_waker = Arc::new(mio::Waker::new(poll.registry(), mio::Token(usize::MAX)).unwrap());
        let cross_ctx = Arc::new(crate::cross_wake::CrossWakeContext {
            queue: crate::cross_wake::CrossWakeQueue::new(),
            mio_waker,
            parked: AtomicBool::new(false),
        });
        let (producer, consumer) = nexus_logbuf::queue::mpsc::new(capacity);
        let rx_slot = RxWakerSlot::new(Arc::as_ptr(&cross_ctx));
        let inner = Arc::new(Inner {
            rx_slot,
            rx_fallback: FallbackWaker::new(),
            tx_waiters: SenderWaitList::new(),
            _cross_wake_owner: cross_ctx,
            sender_count: AtomicUsize::new(1),
            rx_closed: AtomicBool::new(false),
        });
        (
            Sender {
                producer,
                inner: inner.clone(),
                wake_node: Arc::new(SenderWakerNode::new()),
            },
            Receiver { consumer, inner },
        )
    }

    // Claims, fills and commits one message; panics if the claim fails.
    fn try_send(tx: &mut Sender, data: &[u8]) {
        let mut claim = tx.try_claim(data.len()).unwrap();
        claim.copy_from_slice(data);
        claim.commit();
    }

    // Waker that counts how many times it was woken.
    struct CountingWaker(AtomicUsize);

    impl std::task::Wake for CountingWaker {
        fn wake(self: Arc<Self>) {
            self.0.fetch_add(1, Ordering::SeqCst);
        }
    }

    // Pushes a fresh wait-list node holding a counting waker onto `list`.
    fn queue_waiter(list: &SenderWaitList) -> (Arc<SenderWakerNode>, Arc<CountingWaker>) {
        let counter = Arc::new(CountingWaker(AtomicUsize::new(0)));
        let node = Arc::new(SenderWakerNode::new());
        unsafe { *node.waker.get() = Some(Waker::from(Arc::clone(&counter))) };
        list.push(&node);
        (node, counter)
    }

    fn wakes(counter: &CountingWaker) -> usize {
        counter.0.load(Ordering::SeqCst)
    }

    #[test]
    fn claim_commit_recv() {
        let (mut tx, mut rx) = test_channel(4096);
        try_send(&mut tx, b"hello");
        try_send(&mut tx, b"world");
        let msg = rx.try_recv().unwrap();
        assert_eq!(&*msg, b"hello");
        drop(msg);
        let msg = rx.try_recv().unwrap();
        assert_eq!(&*msg, b"world");
        drop(msg);
        assert!(rx.try_recv().is_none());
    }

    #[test]
    fn fifo_ordering() {
        let (mut tx, mut rx) = test_channel(4096);
        for i in 0u32..10 {
            try_send(&mut tx, &i.to_le_bytes());
        }
        for i in 0u32..10 {
            let msg = rx.try_recv().unwrap();
            assert_eq!(&*msg, &i.to_le_bytes());
        }
    }

    #[test]
    fn sender_drop_signals_closed() {
        let (mut tx, mut rx) = test_channel(4096);
        try_send(&mut tx, b"last");
        drop(tx);
        let msg = rx.try_recv().unwrap();
        assert_eq!(&*msg, b"last");
        drop(msg);
        assert!(rx.try_recv().is_none());
    }

    #[test]
    fn receiver_drop_signals_sender() {
        let (_tx, rx) = test_channel(4096);
        drop(rx);
        assert!(_tx.inner.rx_closed.load(Ordering::Acquire));
    }

    #[test]
    fn variable_length_messages() {
        let (mut tx, mut rx) = test_channel(8192);
        try_send(&mut tx, b"hi");
        try_send(&mut tx, &vec![0xABu8; 100]);
        try_send(&mut tx, &vec![0xCDu8; 1000]);
        let msg = rx.try_recv().unwrap();
        assert_eq!(msg.len(), 2);
        drop(msg);
        let msg = rx.try_recv().unwrap();
        assert_eq!(msg.len(), 100);
        drop(msg);
        let msg = rx.try_recv().unwrap();
        assert_eq!(msg.len(), 1000);
    }

    #[test]
    fn cross_thread_claim_send() {
        let (mut tx, mut rx) = test_channel(64 * 1024);
        let handle = std::thread::spawn(move || {
            for i in 0u64..100 {
                try_send(&mut tx, &i.to_le_bytes());
            }
        });
        handle.join().unwrap();
        for i in 0u64..100 {
            let msg = rx.try_recv().unwrap();
            assert_eq!(&*msg, &i.to_le_bytes());
        }
    }

    #[test]
    fn stress_sequential() {
        let (mut tx, mut rx) = test_channel(4096);
        let data = [0xFFu8; 32];
        for _ in 0..10_000 {
            try_send(&mut tx, &data);
            let msg = rx.try_recv().unwrap();
            assert_eq!(msg.len(), 32);
        }
    }

    #[test]
    fn claim_without_commit_aborts() {
        let (mut tx, mut rx) = test_channel(4096);
        let claim = tx.try_claim(10).unwrap();
        drop(claim);
        try_send(&mut tx, b"after_abort");
        let msg = rx.try_recv().unwrap();
        assert_eq!(&*msg, b"after_abort");
    }

    #[test]
    fn multiple_senders() {
        let (mut tx1, mut rx) = test_channel(64 * 1024);
        let mut tx2 = tx1.clone();
        try_send(&mut tx1, b"from_tx1");
        try_send(&mut tx2, b"from_tx2");
        try_send(&mut tx1, b"tx1_again");
        let msg = rx.try_recv().unwrap();
        assert_eq!(&*msg, b"from_tx1");
        drop(msg);
        let msg = rx.try_recv().unwrap();
        assert_eq!(&*msg, b"from_tx2");
        drop(msg);
        let msg = rx.try_recv().unwrap();
        assert_eq!(&*msg, b"tx1_again");
        drop(msg);
        assert!(rx.try_recv().is_none());
    }

    #[test]
    fn sender_drop_while_queued() {
        let (mut tx1, mut rx) = test_channel(4096);
        let tx2 = tx1.clone();
        try_send(&mut tx1, b"data");
        drop(tx2);
        let msg = rx.try_recv().unwrap();
        assert_eq!(&*msg, b"data");
        drop(msg);
        try_send(&mut tx1, b"more");
        let msg = rx.try_recv().unwrap();
        assert_eq!(&*msg, b"more");
    }

    #[test]
    fn wait_list_wake_one_wakes_a_single_waiter() {
        let list = SenderWaitList::new();
        let (node_a, a) = queue_waiter(&list);
        let (node_b, b) = queue_waiter(&list);
        assert!(list.wake_one());
        assert_eq!(wakes(&a) + wakes(&b), 1);
        assert!(list.has_waiters());
        assert!(list.wake_one());
        assert_eq!((wakes(&a), wakes(&b)), (1, 1));
        assert!(!list.has_waiters());
        assert!(!list.wake_one());
        for node in [&node_a, &node_b] {
            assert!(!node.queued.load(Ordering::Acquire));
            assert_eq!(Arc::strong_count(node), 1);
        }
    }

    #[test]
    fn wait_list_skips_cancelled_waiters() {
        let list = SenderWaitList::new();
        let (node, counter) = queue_waiter(&list);
        node.cancelled.store(true, Ordering::Release);
        assert!(!list.wake_one());
        assert_eq!(wakes(&counter), 0);
        assert!(!list.has_waiters());
        assert_eq!(Arc::strong_count(&node), 1);
    }

    #[test]
    fn wait_list_wake_all_wakes_every_live_waiter() {
        let list = SenderWaitList::new();
        let (node_a, a) = queue_waiter(&list);
        let (node_b, b) = queue_waiter(&list);
        let (node_c, c) = queue_waiter(&list);
        node_c.cancelled.store(true, Ordering::Release);
        list.wake_all();
        assert_eq!((wakes(&a), wakes(&b), wakes(&c)), (1, 1, 0));
        assert!(!list.has_waiters());
        for node in [&node_a, &node_b, &node_c] {
            assert_eq!(Arc::strong_count(node), 1);
        }
    }
}