Skip to main content

nexus_async_rt/channel/
spsc_bytes.rs

//! Bounded cross-thread SPSC byte channel.
//!
//! Variable-length messages over `nexus_logbuf::queue::spsc`. Each message
//! is a `&[u8]` written into a claimed region and then committed. The
//! consumer receives `ReadClaim` values that dereference to `&[u8]`;
//! dropping one frees its space in the ring.
//!
//! Zero allocation on the send/recv hot path. Must be created inside
//! [`Runtime::block_on`](crate::Runtime::block_on).
//!
//! ```ignore
//! use nexus_async_rt::channel::spsc_bytes;
//!
//! let (mut tx, mut rx) = spsc_bytes::channel(64 * 1024);
//!
//! // Claim, write, commit (zero-copy)
//! let mut claim = tx.claim(5).await?;
//! claim.copy_from_slice(b"hello");
//! claim.commit();
//!
//! // Or convenience: claim + copy + commit
//! tx.send(b"world").await?;
//!
//! // Receive
//! let msg = rx.recv().await?;
//! assert_eq!(&*msg, b"hello");
//! drop(msg);  // advances consumer head
//! ```
28
29use std::sync::Arc;
30use std::sync::atomic::{AtomicBool, Ordering};
31use std::task::Poll;
32
33use std::ops::{Deref, DerefMut};
34
35use crate::cross_wake::{FallbackWaker, TaskWakerSlot, TxWakerSlot};
36
37// =============================================================================
38// Shared state
39// =============================================================================
40
41struct Inner {
42    rx_slot: TaskWakerSlot,
43    rx_fallback: FallbackWaker,
44    tx_waker: TxWakerSlot,
45    _cross_wake_owner: Arc<crate::cross_wake::CrossWakeContext>,
46    tx_alive: AtomicBool,
47    rx_closed: AtomicBool,
48}
49
50unsafe impl Send for Inner {}
51unsafe impl Sync for Inner {}
52
53impl Inner {
54    fn wake_rx(&self) {
55        if !self.rx_slot.wake() {
56            self.rx_fallback.wake();
57        }
58    }
59
60    fn has_rx_waker(&self) -> bool {
61        self.rx_slot.has_waker() || self.rx_fallback.has_waker()
62    }
63}
64
65// =============================================================================
66// Error types
67// =============================================================================
68
69// =============================================================================
70// WriteClaim wrapper — auto-notifies receiver on commit
71// =============================================================================
72
73// =============================================================================
74// ReadClaim wrapper — auto-wakes sender on drop (frees space)
75// =============================================================================
76
77/// A received message from the byte channel. Dereferences to `&[u8]`.
78///
79/// When dropped, the record region is freed (consumer head advances)
80/// and the sender is woken if it was parked on a full buffer.
81pub struct ReadClaim<'a> {
82    inner: nexus_logbuf::queue::spsc::ReadClaim<'a>,
83    notify: &'a Inner,
84}
85
86impl ReadClaim<'_> {
87    /// Payload length in bytes.
88    pub fn len(&self) -> usize {
89        self.inner.len()
90    }
91
92    /// Always false.
93    pub fn is_empty(&self) -> bool {
94        self.inner.is_empty()
95    }
96}
97
98impl Deref for ReadClaim<'_> {
99    type Target = [u8];
100    fn deref(&self) -> &[u8] {
101        &self.inner
102    }
103}
104
105impl Drop for ReadClaim<'_> {
106    fn drop(&mut self) {
107        // The inner ReadClaim drops after this impl runs (field drop order),
108        // which advances the consumer head and frees space. We wake the
109        // sender BEFORE inner drops — the sender will re-try and see space
110        // once inner's drop completes. This ordering is acceptable because
111        // the sender's try_claim will simply fail and re-park if the space
112        // isn't freed yet. On the next poll it succeeds.
113        //
114        // Alternatively we could manually drop inner first, but the
115        // timing difference is one poll cycle at worst.
116        if self.notify.tx_waker.has_waker() {
117            self.notify.tx_waker.wake();
118        }
119    }
120}
121
122// =============================================================================
123// WriteClaim wrapper — auto-notifies receiver on commit
124// =============================================================================
125
126/// A claimed write region in the byte channel. Dereferences to `&mut [u8]`.
127///
128/// Call [`.commit()`](WriteClaim::commit) to publish the record and
129/// wake the receiver. Dropping without commit writes a skip marker (abort).
130pub struct WriteClaim<'a> {
131    inner: nexus_logbuf::queue::spsc::WriteClaim<'a>,
132    notify: &'a Inner,
133}
134
135impl WriteClaim<'_> {
136    /// Commit the record, making it visible to the receiver.
137    /// Automatically wakes the receiver if it's parked.
138    pub fn commit(self) {
139        let notify = self.notify;
140        self.inner.commit();
141        if notify.has_rx_waker() {
142            notify.wake_rx();
143        }
144    }
145
146    /// Payload length in bytes.
147    pub fn len(&self) -> usize {
148        self.inner.len()
149    }
150
151    /// Always false (claims must have len > 0).
152    pub fn is_empty(&self) -> bool {
153        self.inner.is_empty()
154    }
155}
156
157impl Deref for WriteClaim<'_> {
158    type Target = [u8];
159    fn deref(&self) -> &[u8] {
160        &self.inner
161    }
162}
163
164impl DerefMut for WriteClaim<'_> {
165    fn deref_mut(&mut self) -> &mut [u8] {
166        &mut self.inner
167    }
168}
169
170// =============================================================================
171// Error types
172// =============================================================================
173
174/// Claim failed.
175///
176/// `len == 0` is not a runtime error — it's a precondition violation and
177/// panics in [`nexus_logbuf::queue::spsc::Producer::try_claim`].
178#[derive(Debug)]
179#[non_exhaustive]
180pub enum ClaimError {
181    /// Receiver was dropped.
182    Closed,
183    /// Requested length exceeds buffer capacity (can never succeed).
184    TooLarge,
185}
186
187impl std::fmt::Display for ClaimError {
188    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
189        match self {
190            Self::Closed => f.write_str("byte channel closed"),
191            Self::TooLarge => f.write_str("message exceeds buffer capacity"),
192        }
193    }
194}
195
196impl std::error::Error for ClaimError {}
197
198/// Receive failed — sender dropped and buffer empty.
199#[derive(Debug)]
200pub struct RecvError;
201
202impl std::fmt::Display for RecvError {
203    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
204        f.write_str("byte channel closed")
205    }
206}
207
208impl std::error::Error for RecvError {}
209
210// =============================================================================
211// channel()
212// =============================================================================
213
214/// Create a bounded cross-thread SPSC byte channel.
215///
216/// `capacity` is the ring buffer size in bytes.
217///
218/// # Panics
219///
220/// - Panics if called outside [`Runtime::block_on`](crate::Runtime::block_on).
221pub fn channel(capacity: usize) -> (Sender, Receiver) {
222    crate::context::assert_in_runtime("spsc_bytes::channel() called outside Runtime::block_on");
223
224    let cross_ctx = crate::cross_wake::cross_wake_context()
225        .expect("spsc_bytes::channel() requires runtime context");
226
227    let (producer, consumer) = nexus_logbuf::queue::spsc::new(capacity);
228    let rx_slot = TaskWakerSlot::new(Arc::as_ptr(&cross_ctx));
229
230    let inner = Arc::new(Inner {
231        rx_slot,
232        rx_fallback: FallbackWaker::new(),
233        tx_waker: TxWakerSlot::new(),
234        _cross_wake_owner: cross_ctx,
235        tx_alive: AtomicBool::new(true),
236        rx_closed: AtomicBool::new(false),
237    });
238
239    (
240        Sender {
241            producer,
242            inner: inner.clone(),
243        },
244        Receiver { consumer, inner },
245    )
246}
247
248// =============================================================================
249// Sender
250// =============================================================================
251
252/// Sending half of a bounded SPSC byte channel.
253///
254/// `Send` but not `Clone` — single producer.
255pub struct Sender {
256    producer: nexus_logbuf::queue::spsc::Producer,
257    inner: Arc<Inner>,
258}
259
260impl Sender {
261    /// Send a complete byte message. Claims space, copies, commits.
262    ///
263    /// Waits if the buffer is full. Returns `Err` if receiver dropped.
264    /// Claim `len` bytes for zero-copy writing.
265    ///
266    /// Waits if the buffer is full. Write into the returned `WriteClaim`,
267    /// then call `.commit()` to publish. Drop without commit writes a
268    /// skip marker (abort).
269    ///
270    /// Returns `Err(ClaimError::TooLarge)` immediately if `len` exceeds
271    /// the buffer capacity (can never succeed).
272    ///
273    /// # Panics
274    ///
275    /// Polling the returned future with `len == 0` panics (see
276    /// [`nexus_logbuf::queue::spsc::Producer::try_claim`]).
277    pub fn claim(&mut self, len: usize) -> ClaimFut<'_> {
278        ClaimFut { sender: self, len }
279    }
280
281    /// Try to claim without waiting.
282    ///
283    /// # Panics
284    ///
285    /// Panics if `len == 0` (see
286    /// [`nexus_logbuf::queue::spsc::Producer::try_claim`]).
287    pub fn try_claim(&mut self, len: usize) -> Result<WriteClaim<'_>, nexus_logbuf::BufferFull> {
288        let inner_claim = self.producer.try_claim(len)?;
289        Ok(WriteClaim {
290            inner: inner_claim,
291            notify: &self.inner,
292        })
293    }
294}
295
296/// Future returned by [`Sender::claim`].
297pub struct ClaimFut<'a> {
298    sender: &'a mut Sender,
299    len: usize,
300}
301
302impl<'a> Future for ClaimFut<'a> {
303    type Output = Result<WriteClaim<'a>, ClaimError>;
304
305    fn poll(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
306        let this = unsafe { &mut *std::pin::Pin::into_inner_unchecked(self) };
307        let sender: &'a mut Sender = unsafe { &mut *(this.sender as *mut Sender) };
308
309        // Precondition check before any state inspection — `len == 0` is a
310        // contract violation regardless of channel state, and the doc
311        // contract is honest only if it panics unconditionally.
312        assert!(this.len > 0, "payload length must be non-zero");
313
314        if sender.inner.rx_closed.load(Ordering::Acquire) {
315            return Poll::Ready(Err(ClaimError::Closed));
316        }
317
318        if this.len > sender.producer.capacity() {
319            return Poll::Ready(Err(ClaimError::TooLarge));
320        }
321
322        if let Ok(inner_claim) = sender.producer.try_claim(this.len) {
323            return Poll::Ready(Ok(WriteClaim {
324                inner: inner_claim,
325                notify: &sender.inner,
326            }));
327        }
328        // BufferFull — park and wake on receiver progress.
329        sender.inner.tx_waker.register(cx.waker());
330        Poll::Pending
331    }
332}
333
334unsafe impl Send for ClaimFut<'_> {}
335
336impl Drop for Sender {
337    fn drop(&mut self) {
338        self.inner.tx_alive.store(false, Ordering::Release);
339        self.inner.wake_rx();
340    }
341}
342
343unsafe impl Send for Sender {}
344
345// =============================================================================
346// Receiver
347// =============================================================================
348
349/// Receiving half of a bounded SPSC byte channel.
350///
351/// `Send` but not `Clone` — single consumer.
352pub struct Receiver {
353    consumer: nexus_logbuf::queue::spsc::Consumer,
354    inner: Arc<Inner>,
355}
356
357impl Receiver {
358    /// Receive the next message. Returns a `ReadClaim` that derefs to `&[u8]`.
359    ///
360    /// Dropping the claim advances the consumer head and wakes the sender
361    /// if it was blocked on a full buffer.
362    pub fn recv(&mut self) -> RecvFut<'_> {
363        RecvFut { receiver: self }
364    }
365
366    /// Try to receive without waiting.
367    pub fn try_recv(&mut self) -> Option<ReadClaim<'_>> {
368        let inner_claim = self.consumer.try_claim()?;
369        Some(ReadClaim {
370            inner: inner_claim,
371            notify: &self.inner,
372        })
373    }
374}
375
376/// Future returned by [`Receiver::recv`].
377pub struct RecvFut<'a> {
378    receiver: &'a mut Receiver,
379}
380
381impl Drop for RecvFut<'_> {
382    fn drop(&mut self) {
383        self.receiver.inner.rx_slot.clear();
384    }
385}
386
387impl<'a> Future for RecvFut<'a> {
388    type Output = Result<ReadClaim<'a>, RecvError>;
389
390    fn poll(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
391        // SAFETY: RecvFut is not Unpin-sensitive. We need &mut access to
392        // receiver.consumer for try_claim, and the returned ReadClaim must
393        // have lifetime 'a (tied to the Receiver, not this poll call).
394        let this = unsafe { &mut *std::pin::Pin::into_inner_unchecked(self) };
395
396        // SAFETY: Extend the reborrow lifetime to 'a. This is sound because:
397        // - RecvFut holds &'a mut Receiver, so the Receiver lives for 'a
398        // - ReadClaim borrows &mut Consumer from that Receiver
399        // - The future won't be polled again after returning Ready
400        let receiver: &'a mut Receiver = unsafe { &mut *(this.receiver as *mut Receiver) };
401
402        // Try to claim.
403        if let Some(inner_claim) = receiver.consumer.try_claim() {
404            return Poll::Ready(Ok(ReadClaim {
405                inner: inner_claim,
406                notify: &receiver.inner,
407            }));
408        }
409
410        // Empty + sender dropped → closed.
411        if !receiver.inner.tx_alive.load(Ordering::Acquire) {
412            return Poll::Ready(Err(RecvError));
413        }
414
415        // Park.
416        if !receiver.inner.rx_slot.try_register_local(cx.waker()) {
417            receiver.inner.rx_fallback.register(cx.waker());
418        }
419
420        Poll::Pending
421    }
422}
423
424unsafe impl Send for RecvFut<'_> {}
425
426impl Drop for Receiver {
427    fn drop(&mut self) {
428        self.inner.rx_closed.store(true, Ordering::Release);
429        self.inner.tx_waker.wake();
430    }
431}
432
433unsafe impl Send for Receiver {}
434
// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;
    use std::pin::Pin;
    use std::task::{Context, Waker};

    /// Builds a channel without a running runtime by hand-assembling the
    /// cross-wake context that `channel()` would otherwise fetch.
    fn test_channel(capacity: usize) -> (Sender, Receiver) {
        let poll = mio::Poll::new().unwrap();
        let mio_waker = Arc::new(mio::Waker::new(poll.registry(), mio::Token(usize::MAX)).unwrap());
        let cross_ctx = Arc::new(crate::cross_wake::CrossWakeContext {
            queue: crate::cross_wake::CrossWakeQueue::new(),
            mio_waker,
            parked: AtomicBool::new(false),
        });

        let (producer, consumer) = nexus_logbuf::queue::spsc::new(capacity);
        let rx_slot = TaskWakerSlot::new(Arc::as_ptr(&cross_ctx));

        let inner = Arc::new(Inner {
            rx_slot,
            rx_fallback: FallbackWaker::new(),
            tx_waker: TxWakerSlot::new(),
            _cross_wake_owner: cross_ctx,
            tx_alive: AtomicBool::new(true),
            rx_closed: AtomicBool::new(false),
        });

        (
            Sender {
                producer,
                inner: Arc::clone(&inner),
            },
            Receiver { consumer, inner },
        )
    }

    /// Claims exactly `data.len()` bytes without waiting, copies, commits.
    fn try_send(tx: &mut Sender, data: &[u8]) {
        let mut claim = tx.try_claim(data.len()).unwrap();
        claim.copy_from_slice(data);
        claim.commit(); // auto-notifies receiver
    }

    /// Polls `fut` once with a no-op waker.
    ///
    /// Only used on paths that complete without parking, so the waker is
    /// never registered with the channel.
    fn poll_once<F: Future + Unpin>(fut: &mut F) -> Poll<F::Output> {
        let mut cx = Context::from_waker(Waker::noop());
        Pin::new(fut).poll(&mut cx)
    }

    #[test]
    fn claim_commit_recv() {
        let (mut tx, mut rx) = test_channel(4096);
        try_send(&mut tx, b"hello");
        try_send(&mut tx, b"world");

        let msg = rx.try_recv().unwrap();
        assert_eq!(&*msg, b"hello");
        drop(msg);

        let msg = rx.try_recv().unwrap();
        assert_eq!(&*msg, b"world");
        drop(msg);

        assert!(rx.try_recv().is_none());
    }

    #[test]
    fn fifo_ordering() {
        let (mut tx, mut rx) = test_channel(4096);
        for i in 0u32..10 {
            try_send(&mut tx, &i.to_le_bytes());
        }
        for i in 0u32..10 {
            let msg = rx.try_recv().unwrap();
            assert_eq!(&*msg, &i.to_le_bytes());
        }
    }

    #[test]
    fn sender_drop_signals_closed() {
        let (mut tx, mut rx) = test_channel(4096);
        try_send(&mut tx, b"last");
        drop(tx);

        // Buffered data is still delivered after the sender is gone.
        let msg = rx.try_recv().unwrap();
        assert_eq!(&*msg, b"last");
        drop(msg);

        assert!(rx.try_recv().is_none());
        // Once drained, `recv` reports closure instead of parking.
        assert!(matches!(
            poll_once(&mut rx.recv()),
            Poll::Ready(Err(RecvError))
        ));
    }

    #[test]
    fn recv_future_ready_when_data_buffered() {
        let (mut tx, mut rx) = test_channel(4096);
        try_send(&mut tx, b"ready");

        let mut fut = rx.recv();
        match poll_once(&mut fut) {
            Poll::Ready(Ok(msg)) => assert_eq!(&*msg, b"ready"),
            _ => panic!("expected the buffered message"),
        }
    }

    #[test]
    fn variable_length_messages() {
        let (mut tx, mut rx) = test_channel(8192);

        try_send(&mut tx, b"hi");
        try_send(&mut tx, &[0xABu8; 100]);
        try_send(&mut tx, &[0xCDu8; 1000]);

        let msg = rx.try_recv().unwrap();
        assert_eq!(msg.len(), 2);
        drop(msg);

        let msg = rx.try_recv().unwrap();
        assert_eq!(msg.len(), 100);
        drop(msg);

        let msg = rx.try_recv().unwrap();
        assert_eq!(msg.len(), 1000);
    }

    #[test]
    fn cross_thread_claim_send() {
        let (mut tx, mut rx) = test_channel(64 * 1024);

        let handle = std::thread::spawn(move || {
            for i in 0u64..100 {
                try_send(&mut tx, &i.to_le_bytes());
            }
        });

        handle.join().unwrap();

        for i in 0u64..100 {
            let msg = rx.try_recv().unwrap();
            assert_eq!(&*msg, &i.to_le_bytes());
        }
    }

    #[test]
    fn stress_sequential() {
        let (mut tx, mut rx) = test_channel(4096);
        let data = [0xFFu8; 32];

        let n = if cfg!(miri) { 100 } else { 10_000 };
        for _ in 0..n {
            try_send(&mut tx, &data);
            let msg = rx.try_recv().unwrap();
            assert_eq!(msg.len(), 32);
        }
    }

    #[test]
    fn receiver_drop_signals_sender() {
        let (tx, rx) = test_channel(4096);
        drop(rx);
        assert!(tx.inner.rx_closed.load(Ordering::Acquire));
    }

    #[test]
    fn claim_future_reports_closed_after_receiver_drop() {
        let (mut tx, rx) = test_channel(4096);
        drop(rx);
        assert!(matches!(
            poll_once(&mut tx.claim(8)),
            Poll::Ready(Err(ClaimError::Closed))
        ));
    }

    #[test]
    fn claim_future_rejects_oversized_request() {
        let (mut tx, _rx) = test_channel(4096);
        let too_big = tx.producer.capacity() + 1;
        assert!(matches!(
            poll_once(&mut tx.claim(too_big)),
            Poll::Ready(Err(ClaimError::TooLarge))
        ));
    }

    #[test]
    fn claim_future_ready_when_space_available() {
        let (mut tx, mut rx) = test_channel(4096);
        {
            let mut fut = tx.claim(3);
            let Poll::Ready(Ok(mut claim)) = poll_once(&mut fut) else {
                panic!("expected free space in an empty buffer");
            };
            claim.copy_from_slice(b"abc");
            claim.commit();
        }
        assert_eq!(&*rx.try_recv().unwrap(), b"abc");
    }

    #[test]
    #[should_panic(expected = "payload length must be non-zero")]
    fn claim_future_zero_len_panics() {
        let (mut tx, _rx) = test_channel(4096);
        let _ = poll_once(&mut tx.claim(0));
    }

    #[test]
    fn claim_without_commit_aborts() {
        let (mut tx, mut rx) = test_channel(4096);

        // Claim and drop without commit — skip marker.
        let claim = tx.try_claim(10).unwrap();
        drop(claim);

        // Next claim + commit should work.
        try_send(&mut tx, b"after_abort");

        let msg = rx.try_recv().unwrap();
        assert_eq!(&*msg, b"after_abort");
    }
}
594
// =============================================================================
// BUG-2 (#168) — cross-thread wake-path UAF regression tests
// =============================================================================
//
// The scenario bodies live in `crate::cross_wake::uaf_scenarios`, with one
// canonical implementation shared by all four channel modules. These thin
// `#[test]` wrappers make the scenarios show up under
// `cargo test spsc_bytes::uaf_tests`. They also confirm that the shared
// `TaskWakerSlot` behaves the same way from this channel module.
#[cfg(test)]
mod uaf_tests {
    use crate::cross_wake::uaf_scenarios;

    #[test]
    fn waker_slot_uaf_when_task_freed_mid_dispatch() {
        uaf_scenarios::waker_slot_uaf_when_task_freed_mid_dispatch();
    }

    #[test]
    fn slot_drop_releases_ref_when_still_registered() {
        uaf_scenarios::slot_drop_releases_ref_when_still_registered();
    }

    #[test]
    fn register_during_wake_does_not_leak_ref() {
        uaf_scenarios::register_during_wake_does_not_leak_ref();
    }
}