nexus_queue/mpsc.rs
//! Multi-producer single-consumer bounded queue.
//!
//! A lock-free ring buffer optimized for multiple producer threads sending to
//! one consumer thread. Uses CAS-based slot claiming with Vyukov-style turn
//! counters for synchronization.
//!
//! # Design
//!
//! ```text
//! ┌─────────────────────────────────────────────────────────────┐
//! │ Shared (Arc):                                               │
//! │   tail: CachePadded<AtomicUsize>   ← Producers CAS here     │
//! │   head: CachePadded<AtomicUsize>   ← Consumer writes        │
//! │   slots: *mut Slot<T>              ← Per-slot turn counters │
//! └─────────────────────────────────────────────────────────────┘
//!
//! ┌─────────────────────┐     ┌─────────────────────┐
//! │ Producer (Clone):   │     │ Consumer (!Clone):  │
//! │   cached_head       │     │   local_head        │
//! │   shared: Arc       │     │   shared: Arc       │
//! └─────────────────────┘     └─────────────────────┘
//! ```
//!
//! Producers compete via CAS on the tail index. After claiming a slot, the
//! producer waits for the slot's turn counter to indicate it's writable, writes
//! the data, then advances the turn to signal readiness.
//!
//! The consumer checks the turn counter to know when data is ready, reads it,
//! then advances the turn for the next producer lap.
//!
//! # Turn Counter Protocol
//!
//! For the slot at index `i` on lap `turn` (where `turn = i / capacity`):
//! - `turn * 2`: slot is ready for a producer to write
//! - `turn * 2 + 1`: slot contains data, ready for the consumer
//!
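//! For example, with capacity 4, index 6 maps to slot `6 & 3 = 2` on lap
//! `6 / 4 = 1`: the producer waits for the slot's counter to read `2`, writes,
//! and stores `3`; the consumer waits for `3`, reads, and stores `4`, opening
//! the slot for lap 2.
//!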
//! # Example
//!
//! ```
//! use nexus_queue::mpsc;
//! use std::thread;
//!
//! let (mut tx, mut rx) = mpsc::bounded::<u64>(1024);
//!
//! let mut tx2 = tx.clone();
//! let h1 = thread::spawn(move || {
//!     for i in 0..100 {
//!         while tx.push(i).is_err() { std::hint::spin_loop(); }
//!     }
//! });
//! let h2 = thread::spawn(move || {
//!     for i in 100..200 {
//!         while tx2.push(i).is_err() { std::hint::spin_loop(); }
//!     }
//! });
//!
//! let mut received = 0;
//! while received < 200 {
//!     if rx.pop().is_some() { received += 1; }
//! }
//!
//! h1.join().unwrap();
//! h2.join().unwrap();
//! ```

use std::cell::UnsafeCell;
use std::fmt;
use std::mem::MaybeUninit;
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};

use crossbeam_utils::CachePadded;

use crate::Full;

/// Creates a bounded MPSC queue with the given capacity.
///
/// Capacity is rounded up to the next power of two.
///
/// # Panics
///
/// Panics if `capacity` is zero.
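///
/// # Example
///
/// ```
/// let (tx, rx) = nexus_queue::mpsc::bounded::<u32>(100);
/// // The requested capacity is rounded up to the next power of two.
/// assert_eq!(tx.capacity(), 128);
/// assert_eq!(rx.capacity(), 128);
/// ```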
pub fn bounded<T>(capacity: usize) -> (Producer<T>, Consumer<T>) {
    assert!(capacity > 0, "capacity must be non-zero");

    let capacity = capacity.next_power_of_two();
    let mask = capacity - 1;

    // Allocate slots with turn counters initialized to 0 (ready for turn-0 producers).
    let slots: Vec<Slot<T>> = (0..capacity)
        .map(|_| Slot {
            turn: AtomicUsize::new(0),
            data: UnsafeCell::new(MaybeUninit::uninit()),
        })
        .collect();
    let slots = Box::into_raw(slots.into_boxed_slice()) as *mut Slot<T>;

    let shared = Arc::new(Shared {
        tail: CachePadded::new(AtomicUsize::new(0)),
        head: CachePadded::new(AtomicUsize::new(0)),
        slots,
        capacity,
        mask,
    });

    (
        Producer {
            cached_head: 0,
            slots,
            mask,
            capacity,
            shared: Arc::clone(&shared),
        },
        Consumer {
            local_head: 0,
            shared,
        },
    )
}

/// A slot in the ring buffer with turn-based synchronization.
struct Slot<T> {
    /// Turn counter for Vyukov-style synchronization.
    /// - `turn * 2`: ready for producer
    /// - `turn * 2 + 1`: ready for consumer
    turn: AtomicUsize,
    /// The data stored in this slot.
    data: UnsafeCell<MaybeUninit<T>>,
}

/// Shared state between producers and the consumer.
// repr(C): guarantees field order for cache-line layout.
#[repr(C)]
struct Shared<T> {
    /// Tail index - producers CAS on this to claim slots.
    tail: CachePadded<AtomicUsize>,
    /// Head index - consumer publishes progress here.
    head: CachePadded<AtomicUsize>,
    /// Pointer to the slot array.
    slots: *mut Slot<T>,
    /// Actual capacity (power of two).
    capacity: usize,
    /// Mask for fast modulo (capacity - 1).
    mask: usize,
}

// SAFETY: Shared contains atomics and raw pointers. Access is synchronized via
// the turn counters. `T: Send` ensures data can move between threads.
unsafe impl<T: Send> Send for Shared<T> {}
unsafe impl<T: Send> Sync for Shared<T> {}

impl<T> Drop for Shared<T> {
    fn drop(&mut self) {
        let head = self.head.load(Ordering::Relaxed);
        let tail = self.tail.load(Ordering::Relaxed);

        // Drop any remaining elements in the queue.
        let mut i = head;
        while i != tail {
            let slot = unsafe { &*self.slots.add(i & self.mask) };
            let turn = i / self.capacity;

            // Only drop if the slot was actually written (odd turn = consumer-ready).
            if slot.turn.load(Ordering::Relaxed) == turn * 2 + 1 {
                // SAFETY: the slot contains initialized data at this turn.
                unsafe { (*slot.data.get()).assume_init_drop() };
            }
            i = i.wrapping_add(1);
        }

        // SAFETY: `slots` was allocated via Box::into_raw from a Vec.
        unsafe {
            let _ = Box::from_raw(std::ptr::slice_from_raw_parts_mut(
                self.slots,
                self.capacity,
            ));
        }
    }
}

/// The producer endpoint of an MPSC queue.
///
/// This endpoint can be cloned to create additional producers. Each clone
/// maintains its own cached state for performance.
// repr(C): hot fields at the struct base share a cache line with the struct pointer.
#[repr(C)]
pub struct Producer<T> {
    /// Cached head for the fast full-check. Only refreshed when the cache indicates full.
    cached_head: usize,
    /// Cached slots pointer (avoids Arc deref on the hot path).
    slots: *mut Slot<T>,
    /// Cached mask (avoids Arc deref on the hot path).
    mask: usize,
    /// Cached capacity (avoids Arc deref on the hot path).
    capacity: usize,
    shared: Arc<Shared<T>>,
}

impl<T> Clone for Producer<T> {
    fn clone(&self) -> Self {
        Producer {
            // Seed the cache from the current shared head; push refreshes it on demand.
            cached_head: self.shared.head.load(Ordering::Relaxed),
            slots: self.slots,
            mask: self.mask,
            capacity: self.capacity,
            shared: Arc::clone(&self.shared),
        }
    }
}

// SAFETY: Producer can be sent to another thread. Each Producer instance is
// used by one thread (not Sync - use `clone()` for multiple threads).
unsafe impl<T: Send> Send for Producer<T> {}

impl<T> Producer<T> {
    /// Pushes a value into the queue.
    ///
    /// Returns `Err(Full(value))` if the queue is full, handing ownership of
    /// the value back to the caller for backpressure handling.
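    ///
    /// # Example
    ///
    /// ```
    /// let (mut tx, _rx) = nexus_queue::mpsc::bounded::<u32>(1);
    /// assert!(tx.push(1).is_ok());
    /// // The queue is full; `Err(Full(2))` hands the value back for a retry.
    /// assert!(tx.push(2).is_err());
    /// ```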
    #[inline]
    #[must_use = "push returns Err if full, which should be handled"]
    pub fn push(&mut self, value: T) -> Result<(), Full<T>> {
        let mut backoff = Backoff::new();

        loop {
            let tail = self.shared.tail.load(Ordering::Relaxed);

            // Fast path: check against the cached head (avoids an atomic load most of the time).
            if tail.wrapping_sub(self.cached_head) >= self.capacity {
                // Cache miss: refresh from the shared head.
                self.cached_head = self.shared.head.load(Ordering::Acquire);

                // Re-check with the fresh head.
                if tail.wrapping_sub(self.cached_head) >= self.capacity {
                    return Err(Full(value));
                }
            }

            // Try to claim this slot via CAS.
            if self
                .shared
                .tail
                .compare_exchange_weak(
                    tail,
                    tail.wrapping_add(1),
                    Ordering::Relaxed,
                    Ordering::Relaxed,
                )
                .is_ok()
            {
                // SAFETY: the slots pointer is valid for the lifetime of `shared`.
                let slot = unsafe { &*self.slots.add(tail & self.mask) };
                let turn = tail / self.capacity;

                // Wait for the slot to be ready (should be immediate if the capacity
                // check passed). This spin is rare - it only happens if the consumer
                // hasn't caught up.
                while slot.turn.load(Ordering::Acquire) != turn * 2 {
                    std::hint::spin_loop();
                }

                // SAFETY: we own this slot via the successful CAS and turn check.
                unsafe { (*slot.data.get()).write(value) };

                // Signal ready for the consumer: turn * 2 + 1.
                slot.turn.store(turn * 2 + 1, Ordering::Release);

                return Ok(());
            }

            // CAS failed, another producer won - back off and retry.
            backoff.spin();
        }
    }

    /// Returns the capacity of the queue.
    #[inline]
    pub fn capacity(&self) -> usize {
        self.capacity
    }

    /// Returns `true` if this is the last remaining handle, which implies the
    /// consumer (and any other producers) have been dropped.
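    ///
    /// # Example
    ///
    /// ```
    /// let (tx, rx) = nexus_queue::mpsc::bounded::<u8>(4);
    /// assert!(!tx.is_disconnected());
    /// drop(rx);
    /// // tx is the only remaining handle, so the consumer must be gone.
    /// assert!(tx.is_disconnected());
    /// ```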
    #[inline]
    pub fn is_disconnected(&self) -> bool {
        // Both Producer and Consumer hold an Arc<Shared>, so a plain strong
        // count cannot distinguish the consumer from sibling producers. A
        // count of 1 means every other handle - consumer included - is gone.
        // This is conservative: with multiple producers alive it stays false
        // even after the consumer drops.
        Arc::strong_count(&self.shared) == 1
    }
}

impl<T> fmt::Debug for Producer<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("Producer")
            .field("capacity", &self.capacity())
            .finish_non_exhaustive()
    }
}

/// The consumer endpoint of an MPSC queue.
///
/// This endpoint cannot be cloned - only one consumer thread is allowed.
pub struct Consumer<T> {
    /// Local head index - only this thread reads/writes it.
    local_head: usize,
    shared: Arc<Shared<T>>,
}

// SAFETY: Consumer can be sent to another thread. It has exclusive read access
// to slots (via the turn protocol) and maintains the head index.
unsafe impl<T: Send> Send for Consumer<T> {}

impl<T> Consumer<T> {
    /// Pops a value from the queue.
    ///
    /// Returns `None` if the queue is empty.
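    ///
    /// # Example
    ///
    /// ```
    /// let (mut tx, mut rx) = nexus_queue::mpsc::bounded::<&str>(2);
    /// assert_eq!(rx.pop(), None); // empty
    /// assert!(tx.push("hi").is_ok());
    /// assert_eq!(rx.pop(), Some("hi"));
    /// ```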
    #[inline]
    pub fn pop(&mut self) -> Option<T> {
        let head = self.local_head;
        // SAFETY: the slots pointer is valid for the lifetime of `shared`.
        let slot = unsafe { &*self.shared.slots.add(head & self.shared.mask) };
        let turn = head / self.shared.capacity;

        // Check if the slot is ready (turn * 2 + 1 means a producer has written).
        if slot.turn.load(Ordering::Acquire) != turn * 2 + 1 {
            return None;
        }

        // SAFETY: the turn counter confirms a producer has written this slot.
        let value = unsafe { (*slot.data.get()).assume_init_read() };

        // Signal the slot is free for the next lap: (turn + 1) * 2.
        slot.turn.store((turn + 1) * 2, Ordering::Release);

        // Advance the head and publish it for the producers' capacity check.
        self.local_head = head.wrapping_add(1);
        self.shared.head.store(self.local_head, Ordering::Release);

        Some(value)
    }

    /// Returns the capacity of the queue.
    #[inline]
    pub fn capacity(&self) -> usize {
        self.shared.capacity
    }

    /// Returns `true` if all producers have been dropped.
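    ///
    /// # Example
    ///
    /// ```
    /// let (tx, rx) = nexus_queue::mpsc::bounded::<u8>(4);
    /// assert!(!rx.is_disconnected());
    /// drop(tx);
    /// assert!(rx.is_disconnected());
    /// ```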
    #[inline]
    pub fn is_disconnected(&self) -> bool {
        Arc::strong_count(&self.shared) == 1
    }
}

impl<T> fmt::Debug for Consumer<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("Consumer")
            .field("capacity", &self.capacity())
            .finish_non_exhaustive()
    }
}

/// Simple exponential backoff for CAS retry loops.
struct Backoff {
    step: u32,
}

impl Backoff {
    #[inline]
    fn new() -> Self {
        Self { step: 0 }
    }

    #[inline]
    fn spin(&mut self) {
        // Double the spin count each call, capped at 1 << 6 = 64 iterations.
        for _ in 0..(1 << self.step.min(6)) {
            std::hint::spin_loop();
        }
        self.step += 1;
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    // ============================================================================
    // Basic Operations
    // ============================================================================

    #[test]
    fn basic_push_pop() {
        let (mut tx, mut rx) = bounded::<u64>(4);

        assert!(tx.push(1).is_ok());
        assert!(tx.push(2).is_ok());
        assert!(tx.push(3).is_ok());

        assert_eq!(rx.pop(), Some(1));
        assert_eq!(rx.pop(), Some(2));
        assert_eq!(rx.pop(), Some(3));
        assert_eq!(rx.pop(), None);
    }

    #[test]
    fn empty_pop_returns_none() {
        let (_, mut rx) = bounded::<u64>(4);
        assert_eq!(rx.pop(), None);
        assert_eq!(rx.pop(), None);
    }

    #[test]
    fn fill_then_drain() {
        let (mut tx, mut rx) = bounded::<u64>(4);

        for i in 0..4 {
            assert!(tx.push(i).is_ok());
        }

        for i in 0..4 {
            assert_eq!(rx.pop(), Some(i));
        }

        assert_eq!(rx.pop(), None);
    }

    #[test]
    fn push_returns_error_when_full() {
        let (mut tx, _rx) = bounded::<u64>(4);

        assert!(tx.push(1).is_ok());
        assert!(tx.push(2).is_ok());
        assert!(tx.push(3).is_ok());
        assert!(tx.push(4).is_ok());

        let err = tx.push(5).unwrap_err();
        assert_eq!(err.into_inner(), 5);
    }

    // ============================================================================
    // Interleaved Operations
    // ============================================================================

    #[test]
    fn interleaved_single_producer() {
        let (mut tx, mut rx) = bounded::<u64>(8);

        for i in 0..1000 {
            assert!(tx.push(i).is_ok());
            assert_eq!(rx.pop(), Some(i));
        }
    }

    #[test]
    fn partial_fill_drain_cycles() {
        let (mut tx, mut rx) = bounded::<u64>(8);

        for round in 0..100 {
            for i in 0..4 {
                assert!(tx.push(round * 4 + i).is_ok());
            }

            for i in 0..4 {
                assert_eq!(rx.pop(), Some(round * 4 + i));
            }
        }
    }

    // ============================================================================
    // Multiple Producers
    // ============================================================================

    #[test]
    fn two_producers_single_consumer() {
        use std::thread;

        let (mut tx, mut rx) = bounded::<u64>(64);
        let mut tx2 = tx.clone();

        let h1 = thread::spawn(move || {
            for i in 0..1000 {
                while tx.push(i).is_err() {
                    std::hint::spin_loop();
                }
            }
        });

        let h2 = thread::spawn(move || {
            for i in 1000..2000 {
                while tx2.push(i).is_err() {
                    std::hint::spin_loop();
                }
            }
        });

        let mut received = Vec::new();
        while received.len() < 2000 {
            if let Some(val) = rx.pop() {
                received.push(val);
            } else {
                std::hint::spin_loop();
            }
        }

        h1.join().unwrap();
        h2.join().unwrap();

        // All values received (order not guaranteed across producers).
        received.sort();
        assert_eq!(received, (0..2000).collect::<Vec<_>>());
    }

    #[test]
    fn four_producers_single_consumer() {
        use std::thread;

        let (tx, mut rx) = bounded::<u64>(256);

        let handles: Vec<_> = (0..4)
            .map(|p| {
                let mut tx = tx.clone();
                thread::spawn(move || {
                    for i in 0..1000 {
                        let val = p * 1000 + i;
                        while tx.push(val).is_err() {
                            std::hint::spin_loop();
                        }
                    }
                })
            })
            .collect();

        drop(tx); // Drop the original producer.

        let mut received = Vec::new();
        while received.len() < 4000 {
            if let Some(val) = rx.pop() {
                received.push(val);
            } else {
                std::hint::spin_loop();
            }
        }

        for h in handles {
            h.join().unwrap();
        }

        // Producers 0..4 each send p * 1000 + i for i in 0..1000, covering 0..4000.
        received.sort();
        assert_eq!(received, (0..4000).collect::<Vec<u64>>());
    }

    // ============================================================================
    // Single Slot
    // ============================================================================

    #[test]
    fn single_slot_bounded() {
        let (mut tx, mut rx) = bounded::<u64>(1);

        assert!(tx.push(1).is_ok());
        assert!(tx.push(2).is_err());

        assert_eq!(rx.pop(), Some(1));
        assert!(tx.push(2).is_ok());
    }

    // ============================================================================
    // Disconnection
    // ============================================================================

    #[test]
    fn producer_disconnected() {
        let (tx, rx) = bounded::<u64>(4);

        assert!(!rx.is_disconnected());
        drop(tx);
        assert!(rx.is_disconnected());
    }

    #[test]
    fn consumer_disconnected() {
        let (tx, rx) = bounded::<u64>(4);

        assert!(!tx.is_disconnected());
        drop(rx);
        assert!(tx.is_disconnected());
    }

    #[test]
    fn multiple_producers_one_disconnects() {
        let (tx1, rx) = bounded::<u64>(4);
        let tx2 = tx1.clone();

        assert!(!rx.is_disconnected());
        drop(tx1);
        assert!(!rx.is_disconnected()); // tx2 still alive
        drop(tx2);
        assert!(rx.is_disconnected());
    }

    // ============================================================================
    // Drop Behavior
    // ============================================================================

    #[test]
    fn drop_cleans_up_remaining() {
        use std::sync::atomic::AtomicUsize;

        static DROP_COUNT: AtomicUsize = AtomicUsize::new(0);

        struct DropCounter;
        impl Drop for DropCounter {
            fn drop(&mut self) {
                DROP_COUNT.fetch_add(1, Ordering::SeqCst);
            }
        }

        DROP_COUNT.store(0, Ordering::SeqCst);

        let (mut tx, rx) = bounded::<DropCounter>(4);

        let _ = tx.push(DropCounter);
        let _ = tx.push(DropCounter);
        let _ = tx.push(DropCounter);

        assert_eq!(DROP_COUNT.load(Ordering::SeqCst), 0);

        drop(tx);
        drop(rx);

        assert_eq!(DROP_COUNT.load(Ordering::SeqCst), 3);
    }

    // ============================================================================
    // Special Types
    // ============================================================================

    #[test]
    fn zero_sized_type() {
        let (mut tx, mut rx) = bounded::<()>(8);

        let _ = tx.push(());
        let _ = tx.push(());

        assert_eq!(rx.pop(), Some(()));
        assert_eq!(rx.pop(), Some(()));
        assert_eq!(rx.pop(), None);
    }

    #[test]
    fn string_type() {
        let (mut tx, mut rx) = bounded::<String>(4);

        let _ = tx.push("hello".to_string());
        let _ = tx.push("world".to_string());

        assert_eq!(rx.pop(), Some("hello".to_string()));
        assert_eq!(rx.pop(), Some("world".to_string()));
    }

    #[test]
    #[should_panic(expected = "capacity must be non-zero")]
    fn zero_capacity_panics() {
        let _ = bounded::<u64>(0);
    }

    #[test]
    fn large_message_type() {
        #[repr(C, align(64))]
        struct LargeMessage {
            data: [u8; 256],
        }

        let (mut tx, mut rx) = bounded::<LargeMessage>(8);

        let msg = LargeMessage { data: [42u8; 256] };
        assert!(tx.push(msg).is_ok());

        let received = rx.pop().unwrap();
        assert_eq!(received.data[0], 42);
        assert_eq!(received.data[255], 42);
    }

    #[test]
    fn multiple_laps() {
        let (mut tx, mut rx) = bounded::<u64>(4);

        // 10 full laps through the 4-slot buffer.
        for i in 0..40 {
            assert!(tx.push(i).is_ok());
            assert_eq!(rx.pop(), Some(i));
        }
    }

    #[test]
    fn capacity_rounds_to_power_of_two() {
        let (tx, _) = bounded::<u64>(100);
        assert_eq!(tx.capacity(), 128);

        let (tx, _) = bounded::<u64>(1000);
        assert_eq!(tx.capacity(), 1024);
    }

    // ============================================================================
    // Stress Tests
    // ============================================================================

    #[test]
    fn stress_single_producer() {
        use std::thread;

        const COUNT: u64 = 100_000;

        let (mut tx, mut rx) = bounded::<u64>(1024);

        let producer = thread::spawn(move || {
            for i in 0..COUNT {
                while tx.push(i).is_err() {
                    std::hint::spin_loop();
                }
            }
        });

        let consumer = thread::spawn(move || {
            let mut sum = 0u64;
            let mut received = 0u64;
            while received < COUNT {
                if let Some(val) = rx.pop() {
                    sum = sum.wrapping_add(val);
                    received += 1;
                } else {
                    std::hint::spin_loop();
                }
            }
            sum
        });

        producer.join().unwrap();
        let sum = consumer.join().unwrap();
        // Sum of 0..COUNT is COUNT * (COUNT - 1) / 2.
        assert_eq!(sum, COUNT * (COUNT - 1) / 2);
    }

    #[test]
    fn stress_multiple_producers() {
        use std::thread;

        const PRODUCERS: u64 = 4;
        const PER_PRODUCER: u64 = 25_000;
        const TOTAL: u64 = PRODUCERS * PER_PRODUCER;

        let (tx, mut rx) = bounded::<u64>(1024);

        let handles: Vec<_> = (0..PRODUCERS)
            .map(|_| {
                let mut tx = tx.clone();
                thread::spawn(move || {
                    for i in 0..PER_PRODUCER {
                        while tx.push(i).is_err() {
                            std::hint::spin_loop();
                        }
                    }
                })
            })
            .collect();

        drop(tx);

        let mut received = 0u64;
        while received < TOTAL {
            if rx.pop().is_some() {
                received += 1;
            } else {
                std::hint::spin_loop();
            }
        }

        for h in handles {
            h.join().unwrap();
        }

        assert_eq!(received, TOTAL);
    }
}