1use std::cell::UnsafeCell;
33use std::sync::Arc;
34use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
35use std::task::{Poll, Waker};
36
37use std::ops::{Deref, DerefMut};
38
39use crate::cross_wake::{FallbackWaker, TaskWakerSlot};
40
/// Per-`Sender` node that is linked into [`SenderWaitList`] while the sender
/// waits for buffer space.
struct SenderWakerNode {
    /// Waker to call when space frees up. Written by the sender before it
    /// queues the node; taken by the wait list when the node is woken.
    waker: UnsafeCell<Option<Waker>>,
    /// Intrusive link to the next queued node (null when unlinked).
    next: std::sync::atomic::AtomicPtr<SenderWakerNode>,
    /// `true` while the node sits on the wait list.
    queued: AtomicBool,
    /// Set when the owning `Sender` is dropped; the wait list then discards
    /// the node instead of waking it.
    cancelled: AtomicBool,
}

// SAFETY: every field except `waker` is atomic. Soundness of the `waker` cell
// relies on the `queued` handshake: the sender only touches it while the node
// is not queued, and the wait list only while it is.
unsafe impl Send for SenderWakerNode {}
unsafe impl Sync for SenderWakerNode {}

impl SenderWakerNode {
    /// Creates an idle node: no waker, unlinked, not queued, not cancelled.
    fn new() -> Self {
        let unlinked = std::ptr::null_mut::<SenderWakerNode>();
        Self {
            waker: UnsafeCell::new(None),
            next: std::sync::atomic::AtomicPtr::new(unlinked),
            queued: AtomicBool::new(false),
            cancelled: AtomicBool::new(false),
        }
    }
}
67
68struct SenderWaitList {
75 head: std::sync::atomic::AtomicPtr<SenderWakerNode>,
76}
77
78impl SenderWaitList {
79 fn new() -> Self {
80 Self {
81 head: std::sync::atomic::AtomicPtr::new(std::ptr::null_mut()),
82 }
83 }
84
85 fn push(&self, node: &Arc<SenderWakerNode>) {
90 let ptr = Arc::as_ptr(node).cast_mut();
91 std::mem::forget(Arc::clone(node));
93
94 unsafe { (*ptr).queued.store(true, Ordering::Relaxed) };
95 loop {
96 let head = self.head.load(Ordering::Acquire);
97 unsafe { (*ptr).next.store(head, Ordering::Relaxed) };
98 if self
99 .head
100 .compare_exchange_weak(head, ptr, Ordering::AcqRel, Ordering::Relaxed)
101 .is_ok()
102 {
103 break;
104 }
105 }
106 }
107
108 fn wake_one(&self) -> bool {
112 let head = self.head.swap(std::ptr::null_mut(), Ordering::AcqRel);
113 if head.is_null() {
114 return false;
115 }
116
117 let mut cursor = head;
118 let mut woken = false;
119 while !cursor.is_null() {
120 let next = unsafe { (*cursor).next.load(Ordering::Acquire) };
121 let cancelled = unsafe { (*cursor).cancelled.load(Ordering::Acquire) };
122
123 unsafe {
124 (*cursor).queued.store(false, Ordering::Release);
125 (*cursor)
126 .next
127 .store(std::ptr::null_mut(), Ordering::Relaxed);
128 }
129
130 if !cancelled && !woken {
131 let waker = unsafe { (*cursor).waker.get().read() };
132 unsafe { (*cursor).waker.get().write(None) };
133 unsafe { Arc::decrement_strong_count(cursor) };
135 if let Some(w) = waker {
136 w.wake();
137 woken = true;
138 }
139 } else if !cancelled {
140 loop {
143 let cur_head = self.head.load(Ordering::Acquire);
144 unsafe { (*cursor).next.store(cur_head, Ordering::Relaxed) };
145 unsafe { (*cursor).queued.store(true, Ordering::Relaxed) };
146 if self
147 .head
148 .compare_exchange_weak(
149 cur_head,
150 cursor,
151 Ordering::AcqRel,
152 Ordering::Relaxed,
153 )
154 .is_ok()
155 {
156 break;
157 }
158 }
159 } else {
160 unsafe { Arc::decrement_strong_count(cursor) };
162 }
163
164 cursor = next;
165 }
166
167 woken
168 }
169
170 fn has_waiters(&self) -> bool {
171 !self.head.load(Ordering::Acquire).is_null()
172 }
173
174 fn wake_all(&self) {
176 let mut node = self.head.swap(std::ptr::null_mut(), Ordering::AcqRel);
177 while !node.is_null() {
178 let next = unsafe { (*node).next.load(Ordering::Acquire) };
179 let cancelled = unsafe { (*node).cancelled.load(Ordering::Acquire) };
180 unsafe {
181 (*node).next.store(std::ptr::null_mut(), Ordering::Relaxed);
182 (*node).queued.store(false, Ordering::Release);
183 }
184 if !cancelled {
185 let waker = unsafe { (*node).waker.get().read() };
186 unsafe { (*node).waker.get().write(None) };
187 if let Some(w) = waker {
188 w.wake();
189 }
190 }
191 unsafe { Arc::decrement_strong_count(node) };
193 node = next;
194 }
195 }
196}
197
198struct Inner {
203 rx_slot: TaskWakerSlot,
204 rx_fallback: FallbackWaker,
205 tx_waiters: SenderWaitList,
206 _cross_wake_owner: Arc<crate::cross_wake::CrossWakeContext>,
207 sender_count: AtomicUsize,
208 rx_closed: AtomicBool,
209}
210
211unsafe impl Send for Inner {}
212unsafe impl Sync for Inner {}
213
214impl Inner {
215 fn wake_rx(&self) {
216 if !self.rx_slot.wake() {
217 self.rx_fallback.wake();
218 }
219 }
220
221 fn has_rx_waker(&self) -> bool {
222 self.rx_slot.has_waker() || self.rx_fallback.has_waker()
223 }
224}
225
226pub struct WriteClaim<'a> {
235 inner: nexus_logbuf::queue::mpsc::WriteClaim<'a>,
236 notify: &'a Inner,
237}
238
239impl WriteClaim<'_> {
240 pub fn commit(self) {
243 let notify = self.notify;
244 self.inner.commit();
245 if notify.has_rx_waker() {
246 notify.wake_rx();
247 }
248 }
249
250 pub fn len(&self) -> usize {
252 self.inner.len()
253 }
254
255 pub fn is_empty(&self) -> bool {
257 self.inner.is_empty()
258 }
259}
260
261impl Deref for WriteClaim<'_> {
262 type Target = [u8];
263 fn deref(&self) -> &[u8] {
264 &self.inner
265 }
266}
267
268impl DerefMut for WriteClaim<'_> {
269 fn deref_mut(&mut self) -> &mut [u8] {
270 &mut self.inner
271 }
272}
273
274pub struct ReadClaim<'a> {
283 inner: nexus_logbuf::queue::mpsc::ReadClaim<'a>,
284 notify: &'a Inner,
285}
286
287impl ReadClaim<'_> {
288 pub fn len(&self) -> usize {
290 self.inner.len()
291 }
292
293 pub fn is_empty(&self) -> bool {
295 self.inner.is_empty()
296 }
297}
298
299impl Deref for ReadClaim<'_> {
300 type Target = [u8];
301 fn deref(&self) -> &[u8] {
302 &self.inner
303 }
304}
305
306impl Drop for ReadClaim<'_> {
307 fn drop(&mut self) {
308 if self.notify.tx_waiters.has_waiters() {
315 self.notify.tx_waiters.wake_one();
316 }
317 }
318}
319
/// Why a [`ClaimFut`] resolved without producing a claim.
#[derive(Debug)]
#[non_exhaustive]
pub enum ClaimError {
    /// The receiver has been dropped.
    Closed,
    /// The requested length exceeds the channel's total capacity.
    TooLarge,
}

impl std::fmt::Display for ClaimError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let text = match self {
            Self::Closed => "byte channel closed",
            Self::TooLarge => "message exceeds buffer capacity",
        };
        f.write_str(text)
    }
}

impl std::error::Error for ClaimError {}

/// Returned by a [`RecvFut`] when every sender is gone and no message was available.
#[derive(Debug)]
pub struct RecvError;

impl std::fmt::Display for RecvError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let text = "byte channel closed";
        f.write_str(text)
    }
}

impl std::error::Error for RecvError {}
359
360pub fn channel(capacity: usize) -> (Sender, Receiver) {
375 crate::context::assert_in_runtime("mpsc_bytes::channel() called outside Runtime::block_on");
376
377 let cross_ctx = crate::cross_wake::cross_wake_context()
378 .expect("mpsc_bytes::channel() requires runtime context");
379
380 let (producer, consumer) = nexus_logbuf::queue::mpsc::new(capacity);
381 let rx_slot = TaskWakerSlot::new(Arc::as_ptr(&cross_ctx));
382
383 let inner = Arc::new(Inner {
384 rx_slot,
385 rx_fallback: FallbackWaker::new(),
386 tx_waiters: SenderWaitList::new(),
387 _cross_wake_owner: cross_ctx,
388 sender_count: AtomicUsize::new(1),
389 rx_closed: AtomicBool::new(false),
390 });
391
392 (
393 Sender {
394 producer,
395 inner: inner.clone(),
396 wake_node: Arc::new(SenderWakerNode::new()),
397 },
398 Receiver { consumer, inner },
399 )
400}
401
402pub struct Sender {
410 producer: nexus_logbuf::queue::mpsc::Producer,
411 inner: Arc<Inner>,
412 wake_node: Arc<SenderWakerNode>,
415}
416
417impl Sender {
418 pub fn claim(&mut self, len: usize) -> ClaimFut<'_> {
432 ClaimFut { sender: self, len }
433 }
434
435 pub fn try_claim(&mut self, len: usize) -> Result<WriteClaim<'_>, nexus_logbuf::BufferFull> {
442 let inner_claim = self.producer.try_claim(len)?;
443 Ok(WriteClaim {
444 inner: inner_claim,
445 notify: &self.inner,
446 })
447 }
448}
449
450impl Clone for Sender {
451 fn clone(&self) -> Self {
452 self.inner.sender_count.fetch_add(1, Ordering::Relaxed);
453 Self {
454 producer: self.producer.clone(),
455 inner: self.inner.clone(),
456 wake_node: Arc::new(SenderWakerNode::new()),
457 }
458 }
459}
460
461impl Drop for Sender {
462 fn drop(&mut self) {
463 self.wake_node.cancelled.store(true, Ordering::Release);
469
470 if self.inner.sender_count.fetch_sub(1, Ordering::AcqRel) == 1 {
471 self.inner.wake_rx();
473 }
474 }
475}
476
477unsafe impl Send for Sender {}
479
480pub struct ClaimFut<'a> {
486 sender: &'a mut Sender,
487 len: usize,
488}
489
490impl<'a> Future for ClaimFut<'a> {
491 type Output = Result<WriteClaim<'a>, ClaimError>;
492
493 fn poll(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
494 let this = unsafe { &mut *std::pin::Pin::into_inner_unchecked(self) };
495 let sender: &'a mut Sender = unsafe { &mut *(this.sender as *mut Sender) };
500
501 assert!(this.len > 0, "payload length must be non-zero");
505
506 if sender.inner.rx_closed.load(Ordering::Acquire) {
507 return Poll::Ready(Err(ClaimError::Closed));
508 }
509
510 if this.len > sender.producer.capacity() {
511 return Poll::Ready(Err(ClaimError::TooLarge));
512 }
513
514 if let Ok(inner_claim) = sender.producer.try_claim(this.len) {
515 return Poll::Ready(Ok(WriteClaim {
516 inner: inner_claim,
517 notify: &sender.inner,
518 }));
519 }
520 let node = &sender.wake_node;
522 if !node.queued.load(Ordering::Acquire) {
523 unsafe { *node.waker.get() = Some(cx.waker().clone()) };
526 sender.inner.tx_waiters.push(node);
527 }
528 Poll::Pending
529 }
530}
531
532unsafe impl Send for ClaimFut<'_> {}
533
534pub struct Receiver {
542 consumer: nexus_logbuf::queue::mpsc::Consumer,
543 inner: Arc<Inner>,
544}
545
546impl Receiver {
547 pub fn recv(&mut self) -> RecvFut<'_> {
552 RecvFut { receiver: self }
553 }
554
555 pub fn try_recv(&mut self) -> Option<ReadClaim<'_>> {
557 let inner_claim = self.consumer.try_claim()?;
558 Some(ReadClaim {
559 inner: inner_claim,
560 notify: &self.inner,
561 })
562 }
563}
564
565pub struct RecvFut<'a> {
567 receiver: &'a mut Receiver,
568}
569
570impl Drop for RecvFut<'_> {
571 fn drop(&mut self) {
572 self.receiver.inner.rx_slot.clear();
573 }
574}
575
576impl<'a> Future for RecvFut<'a> {
577 type Output = Result<ReadClaim<'a>, RecvError>;
578
579 fn poll(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
580 let this = unsafe { &mut *std::pin::Pin::into_inner_unchecked(self) };
584
585 let receiver: &'a mut Receiver = unsafe { &mut *(this.receiver as *mut Receiver) };
590
591 if let Some(inner_claim) = receiver.consumer.try_claim() {
593 return Poll::Ready(Ok(ReadClaim {
594 inner: inner_claim,
595 notify: &receiver.inner,
596 }));
597 }
598
599 if receiver.inner.sender_count.load(Ordering::Acquire) == 0 {
601 return Poll::Ready(Err(RecvError));
602 }
603
604 if !receiver.inner.rx_slot.try_register_local(cx.waker()) {
606 receiver.inner.rx_fallback.register(cx.waker());
607 }
608
609 Poll::Pending
610 }
611}
612
613unsafe impl Send for RecvFut<'_> {}
614
615impl Drop for Receiver {
616 fn drop(&mut self) {
617 self.inner.rx_closed.store(true, Ordering::Release);
618 self.inner.tx_waiters.wake_all();
619 }
620}
621
622unsafe impl Send for Receiver {}
623
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a channel around a private `CrossWakeContext`, mirroring the
    /// wiring in `channel()` but without requiring a running runtime.
    fn test_channel(capacity: usize) -> (Sender, Receiver) {
        let mio_poll = mio::Poll::new().unwrap();
        let token = mio::Token(usize::MAX);
        let mio_waker = Arc::new(mio::Waker::new(mio_poll.registry(), token).unwrap());
        let cross_ctx = Arc::new(crate::cross_wake::CrossWakeContext {
            queue: crate::cross_wake::CrossWakeQueue::new(),
            mio_waker,
            parked: AtomicBool::new(false),
        });

        let (producer, consumer) = nexus_logbuf::queue::mpsc::new(capacity);
        let shared = Arc::new(Inner {
            rx_slot: TaskWakerSlot::new(Arc::as_ptr(&cross_ctx)),
            rx_fallback: FallbackWaker::new(),
            tx_waiters: SenderWaitList::new(),
            _cross_wake_owner: cross_ctx,
            sender_count: AtomicUsize::new(1),
            rx_closed: AtomicBool::new(false),
        });

        let sender = Sender {
            producer,
            inner: Arc::clone(&shared),
            wake_node: Arc::new(SenderWakerNode::new()),
        };
        (
            sender,
            Receiver {
                consumer,
                inner: shared,
            },
        )
    }

    /// Claims exactly `data.len()` bytes, fills them with `data`, and commits.
    fn try_send(tx: &mut Sender, data: &[u8]) {
        let mut claim = tx.try_claim(data.len()).unwrap();
        let buf: &mut [u8] = &mut claim;
        buf.copy_from_slice(data);
        claim.commit();
    }

    #[test]
    fn claim_commit_recv() {
        let (mut tx, mut rx) = test_channel(4096);
        let payloads: [&[u8]; 2] = [b"hello", b"world"];
        for payload in payloads {
            try_send(&mut tx, payload);
        }
        for expected in payloads {
            let msg = rx.try_recv().unwrap();
            assert_eq!(&*msg, expected);
        }
        assert!(rx.try_recv().is_none());
    }

    #[test]
    fn fifo_ordering() {
        let (mut tx, mut rx) = test_channel(4096);
        (0u32..10).for_each(|i| try_send(&mut tx, &i.to_le_bytes()));
        for expected in 0u32..10 {
            let msg = rx.try_recv().unwrap();
            assert_eq!(&*msg, &expected.to_le_bytes());
        }
    }

    #[test]
    fn sender_drop_signals_closed() {
        let (mut tx, mut rx) = test_channel(4096);
        try_send(&mut tx, b"last");
        drop(tx);

        {
            let msg = rx.try_recv().unwrap();
            assert_eq!(&*msg, b"last");
        }

        assert!(rx.try_recv().is_none());
    }

    #[test]
    fn receiver_drop_signals_sender() {
        let (tx, rx) = test_channel(4096);
        drop(rx);
        assert!(tx.inner.rx_closed.load(Ordering::Acquire));
    }

    #[test]
    fn variable_length_messages() {
        let (mut tx, mut rx) = test_channel(8192);

        try_send(&mut tx, b"hi");
        try_send(&mut tx, &[0xABu8; 100]);
        try_send(&mut tx, &[0xCDu8; 1000]);

        for expected_len in [2usize, 100, 1000] {
            let msg = rx.try_recv().unwrap();
            assert_eq!(msg.len(), expected_len);
        }
    }

    #[test]
    fn cross_thread_claim_send() {
        let (mut tx, mut rx) = test_channel(64 * 1024);

        let producer = std::thread::spawn(move || {
            for i in 0u64..100 {
                try_send(&mut tx, &i.to_le_bytes());
            }
        });
        producer.join().unwrap();

        for expected in 0u64..100 {
            let msg = rx.try_recv().unwrap();
            assert_eq!(&*msg, &expected.to_le_bytes());
        }
    }

    #[test]
    fn stress_sequential() {
        let (mut tx, mut rx) = test_channel(4096);
        let payload = [0xFFu8; 32];

        let rounds = if cfg!(miri) { 100 } else { 10_000 };
        for _ in 0..rounds {
            try_send(&mut tx, &payload);
            assert_eq!(rx.try_recv().unwrap().len(), 32);
        }
    }

    #[test]
    fn claim_without_commit_aborts() {
        let (mut tx, mut rx) = test_channel(4096);

        // An uncommitted claim that is simply dropped must not block later sends.
        drop(tx.try_claim(10).unwrap());

        try_send(&mut tx, b"after_abort");

        let msg = rx.try_recv().unwrap();
        assert_eq!(&*msg, b"after_abort");
    }

    #[test]
    fn multiple_senders() {
        let (mut tx1, mut rx) = test_channel(64 * 1024);
        let mut tx2 = tx1.clone();

        try_send(&mut tx1, b"from_tx1");
        try_send(&mut tx2, b"from_tx2");
        try_send(&mut tx1, b"tx1_again");

        let expected: [&[u8]; 3] = [b"from_tx1", b"from_tx2", b"tx1_again"];
        for want in expected {
            let msg = rx.try_recv().unwrap();
            assert_eq!(&*msg, want);
        }

        assert!(rx.try_recv().is_none());
    }

    #[test]
    fn sender_drop_while_queued() {
        let (mut tx1, mut rx) = test_channel(4096);
        let tx2 = tx1.clone();

        try_send(&mut tx1, b"data");
        drop(tx2);

        {
            let msg = rx.try_recv().unwrap();
            assert_eq!(&*msg, b"data");
        }

        try_send(&mut tx1, b"more");
        let msg = rx.try_recv().unwrap();
        assert_eq!(&*msg, b"more");
    }
}
833
/// Runs the shared waker-slot use-after-free scenarios from `cross_wake`
/// as part of this module's test suite.
#[cfg(test)]
mod uaf_tests {
    use crate::cross_wake::uaf_scenarios as scenario;

    #[test]
    fn waker_slot_uaf_when_task_freed_mid_dispatch() {
        scenario::waker_slot_uaf_when_task_freed_mid_dispatch();
    }

    #[test]
    fn slot_drop_releases_ref_when_still_registered() {
        scenario::slot_drop_releases_ref_when_still_registered();
    }

    #[test]
    fn register_during_wake_does_not_leak_ref() {
        scenario::register_during_wake_does_not_leak_ref();
    }
}
861}