1use std::sync::Arc;
30use std::sync::atomic::{AtomicBool, Ordering};
31use std::task::Poll;
32
33use std::ops::{Deref, DerefMut};
34
35use crate::cross_wake::{FallbackWaker, TaskWakerSlot, TxWakerSlot};
36
37struct Inner {
42 rx_slot: TaskWakerSlot,
43 rx_fallback: FallbackWaker,
44 tx_waker: TxWakerSlot,
45 _cross_wake_owner: Arc<crate::cross_wake::CrossWakeContext>,
46 tx_alive: AtomicBool,
47 rx_closed: AtomicBool,
48}
49
50unsafe impl Send for Inner {}
51unsafe impl Sync for Inner {}
52
53impl Inner {
54 fn wake_rx(&self) {
55 if !self.rx_slot.wake() {
56 self.rx_fallback.wake();
57 }
58 }
59
60 fn has_rx_waker(&self) -> bool {
61 self.rx_slot.has_waker() || self.rx_fallback.has_waker()
62 }
63}
64
65pub struct ReadClaim<'a> {
82 inner: nexus_logbuf::queue::spsc::ReadClaim<'a>,
83 notify: &'a Inner,
84}
85
86impl ReadClaim<'_> {
87 pub fn len(&self) -> usize {
89 self.inner.len()
90 }
91
92 pub fn is_empty(&self) -> bool {
94 self.inner.is_empty()
95 }
96}
97
98impl Deref for ReadClaim<'_> {
99 type Target = [u8];
100 fn deref(&self) -> &[u8] {
101 &self.inner
102 }
103}
104
105impl Drop for ReadClaim<'_> {
106 fn drop(&mut self) {
107 if self.notify.tx_waker.has_waker() {
117 self.notify.tx_waker.wake();
118 }
119 }
120}
121
122pub struct WriteClaim<'a> {
131 inner: nexus_logbuf::queue::spsc::WriteClaim<'a>,
132 notify: &'a Inner,
133}
134
135impl WriteClaim<'_> {
136 pub fn commit(self) {
139 let notify = self.notify;
140 self.inner.commit();
141 if notify.has_rx_waker() {
142 notify.wake_rx();
143 }
144 }
145
146 pub fn len(&self) -> usize {
148 self.inner.len()
149 }
150
151 pub fn is_empty(&self) -> bool {
153 self.inner.is_empty()
154 }
155}
156
157impl Deref for WriteClaim<'_> {
158 type Target = [u8];
159 fn deref(&self) -> &[u8] {
160 &self.inner
161 }
162}
163
164impl DerefMut for WriteClaim<'_> {
165 fn deref_mut(&mut self) -> &mut [u8] {
166 &mut self.inner
167 }
168}
169
/// Error returned when an asynchronous claim cannot be satisfied.
///
/// The type is comparable and cloneable so callers and tests can match on it
/// with `==` / `assert_eq!`.
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum ClaimError {
    /// The receiving half of the channel has been dropped.
    Closed,
    /// The requested length is larger than the channel's buffer capacity.
    TooLarge,
}

impl std::fmt::Display for ClaimError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            Self::Closed => "byte channel closed",
            Self::TooLarge => "message exceeds buffer capacity",
        };
        f.write_str(message)
    }
}

impl std::error::Error for ClaimError {}
197
/// Error returned by a receive when the sender is gone and no record is left
/// to read.
///
/// The type is comparable and cloneable so callers and tests can match on it
/// with `==` / `assert_eq!`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RecvError;

impl std::fmt::Display for RecvError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("byte channel closed")
    }
}

impl std::error::Error for RecvError {}
209
210pub fn channel(capacity: usize) -> (Sender, Receiver) {
222 crate::context::assert_in_runtime("spsc_bytes::channel() called outside Runtime::block_on");
223
224 let cross_ctx = crate::cross_wake::cross_wake_context()
225 .expect("spsc_bytes::channel() requires runtime context");
226
227 let (producer, consumer) = nexus_logbuf::queue::spsc::new(capacity);
228 let rx_slot = TaskWakerSlot::new(Arc::as_ptr(&cross_ctx));
229
230 let inner = Arc::new(Inner {
231 rx_slot,
232 rx_fallback: FallbackWaker::new(),
233 tx_waker: TxWakerSlot::new(),
234 _cross_wake_owner: cross_ctx,
235 tx_alive: AtomicBool::new(true),
236 rx_closed: AtomicBool::new(false),
237 });
238
239 (
240 Sender {
241 producer,
242 inner: inner.clone(),
243 },
244 Receiver { consumer, inner },
245 )
246}
247
248pub struct Sender {
256 producer: nexus_logbuf::queue::spsc::Producer,
257 inner: Arc<Inner>,
258}
259
260impl Sender {
261 pub fn claim(&mut self, len: usize) -> ClaimFut<'_> {
278 ClaimFut { sender: self, len }
279 }
280
281 pub fn try_claim(&mut self, len: usize) -> Result<WriteClaim<'_>, nexus_logbuf::BufferFull> {
288 let inner_claim = self.producer.try_claim(len)?;
289 Ok(WriteClaim {
290 inner: inner_claim,
291 notify: &self.inner,
292 })
293 }
294}
295
296pub struct ClaimFut<'a> {
298 sender: &'a mut Sender,
299 len: usize,
300}
301
302impl<'a> Future for ClaimFut<'a> {
303 type Output = Result<WriteClaim<'a>, ClaimError>;
304
305 fn poll(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
306 let this = unsafe { &mut *std::pin::Pin::into_inner_unchecked(self) };
307 let sender: &'a mut Sender = unsafe { &mut *(this.sender as *mut Sender) };
308
309 assert!(this.len > 0, "payload length must be non-zero");
313
314 if sender.inner.rx_closed.load(Ordering::Acquire) {
315 return Poll::Ready(Err(ClaimError::Closed));
316 }
317
318 if this.len > sender.producer.capacity() {
319 return Poll::Ready(Err(ClaimError::TooLarge));
320 }
321
322 if let Ok(inner_claim) = sender.producer.try_claim(this.len) {
323 return Poll::Ready(Ok(WriteClaim {
324 inner: inner_claim,
325 notify: &sender.inner,
326 }));
327 }
328 sender.inner.tx_waker.register(cx.waker());
330 Poll::Pending
331 }
332}
333
334unsafe impl Send for ClaimFut<'_> {}
335
336impl Drop for Sender {
337 fn drop(&mut self) {
338 self.inner.tx_alive.store(false, Ordering::Release);
339 self.inner.wake_rx();
340 }
341}
342
343unsafe impl Send for Sender {}
344
345pub struct Receiver {
353 consumer: nexus_logbuf::queue::spsc::Consumer,
354 inner: Arc<Inner>,
355}
356
357impl Receiver {
358 pub fn recv(&mut self) -> RecvFut<'_> {
363 RecvFut { receiver: self }
364 }
365
366 pub fn try_recv(&mut self) -> Option<ReadClaim<'_>> {
368 let inner_claim = self.consumer.try_claim()?;
369 Some(ReadClaim {
370 inner: inner_claim,
371 notify: &self.inner,
372 })
373 }
374}
375
376pub struct RecvFut<'a> {
378 receiver: &'a mut Receiver,
379}
380
381impl Drop for RecvFut<'_> {
382 fn drop(&mut self) {
383 self.receiver.inner.rx_slot.clear();
384 }
385}
386
387impl<'a> Future for RecvFut<'a> {
388 type Output = Result<ReadClaim<'a>, RecvError>;
389
390 fn poll(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
391 let this = unsafe { &mut *std::pin::Pin::into_inner_unchecked(self) };
395
396 let receiver: &'a mut Receiver = unsafe { &mut *(this.receiver as *mut Receiver) };
401
402 if let Some(inner_claim) = receiver.consumer.try_claim() {
404 return Poll::Ready(Ok(ReadClaim {
405 inner: inner_claim,
406 notify: &receiver.inner,
407 }));
408 }
409
410 if !receiver.inner.tx_alive.load(Ordering::Acquire) {
412 return Poll::Ready(Err(RecvError));
413 }
414
415 if !receiver.inner.rx_slot.try_register_local(cx.waker()) {
417 receiver.inner.rx_fallback.register(cx.waker());
418 }
419
420 Poll::Pending
421 }
422}
423
424unsafe impl Send for RecvFut<'_> {}
425
426impl Drop for Receiver {
427 fn drop(&mut self) {
428 self.inner.rx_closed.store(true, Ordering::Release);
429 self.inner.tx_waker.wake();
430 }
431}
432
433unsafe impl Send for Receiver {}
434
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a channel without a running runtime by assembling the
    /// cross-wake context by hand instead of fetching it the way `channel()`
    /// does.
    fn test_channel(capacity: usize) -> (Sender, Receiver) {
        let poll = mio::Poll::new().unwrap();
        let mio_waker = Arc::new(mio::Waker::new(poll.registry(), mio::Token(usize::MAX)).unwrap());
        let cross_ctx = Arc::new(crate::cross_wake::CrossWakeContext {
            queue: crate::cross_wake::CrossWakeQueue::new(),
            mio_waker,
            parked: AtomicBool::new(false),
        });

        let (producer, consumer) = nexus_logbuf::queue::spsc::new(capacity);
        let rx_slot = TaskWakerSlot::new(Arc::as_ptr(&cross_ctx));

        let inner = Arc::new(Inner {
            rx_slot,
            rx_fallback: FallbackWaker::new(),
            tx_waker: TxWakerSlot::new(),
            _cross_wake_owner: cross_ctx,
            tx_alive: AtomicBool::new(true),
            rx_closed: AtomicBool::new(false),
        });

        (
            Sender {
                producer,
                inner: inner.clone(),
            },
            Receiver { consumer, inner },
        )
    }

    /// Claims exactly `data.len()` bytes, copies `data` in and commits.
    /// Panics if the claim fails.
    fn try_send(tx: &mut Sender, data: &[u8]) {
        let mut claim = tx.try_claim(data.len()).unwrap();
        claim.copy_from_slice(data);
        claim.commit();
    }

    #[test]
    fn claim_commit_recv() {
        let (mut tx, mut rx) = test_channel(4096);
        try_send(&mut tx, b"hello");
        try_send(&mut tx, b"world");

        let msg = rx.try_recv().unwrap();
        assert_eq!(&*msg, b"hello");
        drop(msg);

        let msg = rx.try_recv().unwrap();
        assert_eq!(&*msg, b"world");
        drop(msg);

        assert!(rx.try_recv().is_none());
    }

    #[test]
    fn fifo_ordering() {
        let (mut tx, mut rx) = test_channel(4096);
        for i in 0u32..10 {
            try_send(&mut tx, &i.to_le_bytes());
        }
        for i in 0u32..10 {
            let msg = rx.try_recv().unwrap();
            assert_eq!(&*msg, &i.to_le_bytes());
        }
    }

    #[test]
    fn sender_drop_signals_closed() {
        let (mut tx, mut rx) = test_channel(4096);
        try_send(&mut tx, b"last");
        drop(tx);

        // The closure signal itself: dropping the sender clears `tx_alive`.
        assert!(!rx.inner.tx_alive.load(Ordering::Acquire));

        // A record committed before the drop is still delivered.
        let msg = rx.try_recv().unwrap();
        assert_eq!(&*msg, b"last");
        drop(msg);

        assert!(rx.try_recv().is_none());
    }

    #[test]
    fn variable_length_messages() {
        let (mut tx, mut rx) = test_channel(8192);

        try_send(&mut tx, b"hi");
        try_send(&mut tx, &vec![0xABu8; 100]);
        try_send(&mut tx, &vec![0xCDu8; 1000]);

        let msg = rx.try_recv().unwrap();
        assert_eq!(msg.len(), 2);
        drop(msg);

        let msg = rx.try_recv().unwrap();
        assert_eq!(msg.len(), 100);
        drop(msg);

        let msg = rx.try_recv().unwrap();
        assert_eq!(msg.len(), 1000);
    }

    #[test]
    fn cross_thread_claim_send() {
        let (mut tx, mut rx) = test_channel(64 * 1024);

        let handle = std::thread::spawn(move || {
            for i in 0u64..100 {
                try_send(&mut tx, &i.to_le_bytes());
            }
        });

        // Join first so every record is committed before reading.
        handle.join().unwrap();

        for i in 0u64..100 {
            let msg = rx.try_recv().unwrap();
            assert_eq!(&*msg, &i.to_le_bytes());
        }
    }

    #[test]
    fn stress_sequential() {
        let (mut tx, mut rx) = test_channel(4096);
        let data = [0xFFu8; 32];

        // Keep the iteration count small under Miri.
        let n = if cfg!(miri) { 100 } else { 10_000 };
        for _ in 0..n {
            try_send(&mut tx, &data);
            let msg = rx.try_recv().unwrap();
            assert_eq!(msg.len(), 32);
        }
    }

    #[test]
    fn receiver_drop_signals_sender() {
        let (tx, rx) = test_channel(4096);
        drop(rx);
        assert!(tx.inner.rx_closed.load(Ordering::Acquire));
    }

    #[test]
    fn claim_without_commit_aborts() {
        let (mut tx, mut rx) = test_channel(4096);

        // Dropping a claim without committing must not publish a record.
        let claim = tx.try_claim(10).unwrap();
        drop(claim);

        try_send(&mut tx, b"after_abort");

        let msg = rx.try_recv().unwrap();
        assert_eq!(&*msg, b"after_abort");
    }
}
594
/// Runs the shared waker-slot scenarios from `crate::cross_wake` as tests of
/// this module.
#[cfg(test)]
mod uaf_tests {
    use crate::cross_wake::uaf_scenarios;

    #[test]
    fn waker_slot_uaf_when_task_freed_mid_dispatch() {
        uaf_scenarios::waker_slot_uaf_when_task_freed_mid_dispatch();
    }

    #[test]
    fn slot_drop_releases_ref_when_still_registered() {
        uaf_scenarios::slot_drop_releases_ref_when_still_registered();
    }

    #[test]
    fn register_during_wake_does_not_leak_ref() {
        uaf_scenarios::register_during_wake_does_not_leak_ref();
    }
}