// nexus_queue/mpsc.rs
//! Multi-producer single-consumer bounded queue.
//!
//! A lock-free ring buffer optimized for multiple producer threads sending to
//! one consumer thread. Uses CAS-based slot claiming with Vyukov-style turn
//! counters for synchronization.
//!
//! # Design
//!
//! ```text
//! ┌─────────────────────────────────────────────────────────────┐
//! │ Shared (Arc):                                               │
//! │   tail: CachePadded<AtomicUsize>   ← Producers CAS here     │
//! │   head: CachePadded<AtomicUsize>   ← Consumer writes        │
//! │   slots: *mut Slot<T>              ← Per-slot turn counters │
//! └─────────────────────────────────────────────────────────────┘
//!
//! ┌─────────────────────┐     ┌─────────────────────┐
//! │ Producer (Clone):   │     │ Consumer (!Clone):  │
//! │   cached_head       │     │   local_head        │
//! │   shared: Arc       │     │   shared: Arc       │
//! └─────────────────────┘     └─────────────────────┘
//! ```
//!
//! A producer first checks the slot's turn counter at the current tail; once
//! the counter shows the slot is writable, the producer claims it by CAS-ing
//! the tail index forward, writes the data, then advances the turn to signal
//! readiness.
//!
//! The consumer checks the turn counter to know when data is ready, reads it,
//! then advances the turn for the next producer lap.
//!
//! # Turn Counter Protocol
//!
//! For slot at index `i` on lap `turn`:
//! - `turn * 2`: Slot is ready for producer to write
//! - `turn * 2 + 1`: Slot contains data, ready for consumer
//!
//! # Example
//!
//! ```
//! use nexus_queue::mpsc;
//! use std::thread;
//!
//! let (mut tx, mut rx) = mpsc::bounded::<u64>(1024);
//!
//! let mut tx2 = tx.clone();
//! let h1 = thread::spawn(move || {
//!     for i in 0..100 {
//!         while tx.push(i).is_err() { std::hint::spin_loop(); }
//!     }
//! });
//! let h2 = thread::spawn(move || {
//!     for i in 100..200 {
//!         while tx2.push(i).is_err() { std::hint::spin_loop(); }
//!     }
//! });
//!
//! let mut received = 0;
//! while received < 200 {
//!     if rx.pop().is_some() { received += 1; }
//! }
//!
//! h1.join().unwrap();
//! h2.join().unwrap();
//! ```
65
66use std::cell::UnsafeCell;
67use std::fmt;
68use std::mem::MaybeUninit;
69use std::sync::Arc;
70use std::sync::atomic::{AtomicUsize, Ordering};
71
72use crossbeam_utils::CachePadded;
73
74use crate::Full;
75
76/// Creates a bounded MPSC queue with the given capacity.
77///
78/// Capacity is rounded up to the next power of two.
79///
80/// # Panics
81///
82/// Panics if `capacity` is zero.
83pub fn bounded<T>(capacity: usize) -> (Producer<T>, Consumer<T>) {
84    assert!(capacity > 0, "capacity must be non-zero");
85
86    let capacity = capacity.next_power_of_two();
87    let mask = capacity - 1;
88
89    // Allocate slots with turn counters initialized to 0 (ready for turn 0 producers)
90    let slots: Vec<Slot<T>> = (0..capacity)
91        .map(|_| Slot {
92            turn: AtomicUsize::new(0),
93            data: UnsafeCell::new(MaybeUninit::uninit()),
94        })
95        .collect();
96    let slots = Box::into_raw(slots.into_boxed_slice()) as *mut Slot<T>;
97
98    let shift = capacity.trailing_zeros();
99
100    let shared = Arc::new(Shared {
101        tail: CachePadded::new(AtomicUsize::new(0)),
102        head: CachePadded::new(AtomicUsize::new(0)),
103        slots,
104        capacity,
105        shift,
106        mask,
107    });
108
109    (
110        Producer {
111            cached_head: 0,
112            slots,
113            mask,
114            capacity,
115            shift,
116            shared: Arc::clone(&shared),
117        },
118        Consumer {
119            local_head: 0,
120            slots,
121            mask,
122            shift,
123            shared,
124        },
125    )
126}
127
/// A slot in the ring buffer with turn-based synchronization.
struct Slot<T> {
    /// Turn counter for Vyukov-style synchronization. For lap `turn`:
    /// - `turn * 2`: slot is empty and ready for a producer
    /// - `turn * 2 + 1`: slot holds data and is ready for the consumer
    turn: AtomicUsize,
    /// The payload; only initialized while `turn` holds an odd (consumer-ready) stamp.
    data: UnsafeCell<MaybeUninit<T>>,
}
137
/// Shared state between producers and the consumer.
// repr(C): Guarantees field order, so the two CachePadded counters occupy
// separate cache lines ahead of the read-mostly configuration fields.
#[repr(C)]
struct Shared<T> {
    /// Tail index - producers CAS on this to claim slots.
    tail: CachePadded<AtomicUsize>,
    /// Head index - consumer publishes progress here.
    head: CachePadded<AtomicUsize>,
    /// Pointer to the slot array (owned; released in `Drop`).
    slots: *mut Slot<T>,
    /// Actual capacity (power of two).
    capacity: usize,
    /// Shift for fast division by capacity (log2(capacity)).
    shift: u32,
    /// Mask for fast modulo (capacity - 1).
    mask: usize,
}
155
// SAFETY: Shared contains atomics and raw pointers. All slot access is
// synchronized via the per-slot turn counters (Acquire/Release pairs in
// push/pop), and `T: Send` ensures the stored values may cross threads.
unsafe impl<T: Send> Send for Shared<T> {}
unsafe impl<T: Send> Sync for Shared<T> {}
160
impl<T> Drop for Shared<T> {
    /// Drops any undelivered elements, then frees the slot array.
    ///
    /// Runs only once every `Producer` and the `Consumer` are gone, so the
    /// indices and turn counters are quiescent and Relaxed loads suffice.
    fn drop(&mut self) {
        let head = self.head.load(Ordering::Relaxed);
        let tail = self.tail.load(Ordering::Relaxed);

        // Walk the logical indices still in flight: [head, tail).
        let mut i = head;
        while i != tail {
            let slot = unsafe { &*self.slots.add(i & self.mask) };
            let turn = i >> self.shift;

            // Only drop if the slot was actually written (odd stamp for this
            // lap = consumer-ready); defensive against a claimed-but-unwritten
            // slot, which would still carry the even producer stamp.
            if slot.turn.load(Ordering::Relaxed) == turn * 2 + 1 {
                // SAFETY: Slot contains initialized data at this turn.
                unsafe { (*slot.data.get()).assume_init_drop() };
            }
            i = i.wrapping_add(1);
        }

        // SAFETY: slots was allocated via Box::into_raw from a Vec of exactly
        // `capacity` elements in `bounded`; reconstituting the boxed slice
        // releases the allocation.
        unsafe {
            let _ = Box::from_raw(std::ptr::slice_from_raw_parts_mut(
                self.slots,
                self.capacity,
            ));
        }
    }
}
189
/// The producer endpoint of an MPSC queue.
///
/// This endpoint can be cloned to create additional producers. Each clone
/// maintains its own cached state for performance.
// repr(C): Hot fields at struct base share cache line with struct pointer.
#[repr(C)]
pub struct Producer<T> {
    /// Cached head for the fast full-check in `push`. Only refreshed (via an
    /// Acquire load of the shared head) when the cached value indicates full.
    cached_head: usize,
    /// Cached slots pointer (avoids Arc deref on hot path).
    slots: *mut Slot<T>,
    /// Cached mask (avoids Arc deref on hot path).
    mask: usize,
    /// Cached capacity (avoids Arc deref on hot path).
    capacity: usize,
    /// Cached shift for fast division (log2(capacity)).
    shift: u32,
    /// Keeps the shared state (and thus the slot allocation) alive.
    shared: Arc<Shared<T>>,
}
209
210impl<T> Clone for Producer<T> {
211    fn clone(&self) -> Self {
212        Producer {
213            // Fresh cache - will be populated on first push
214            cached_head: self.shared.head.load(Ordering::Relaxed),
215            slots: self.slots,
216            mask: self.mask,
217            capacity: self.capacity,
218            shift: self.shift,
219            shared: Arc::clone(&self.shared),
220        }
221    }
222}
223
// SAFETY: Producer can be sent to another thread. Each Producer instance is
// used by one thread at a time (push takes &mut self; the type is not Sync -
// use clone() for multiple threads), and T: Send lets values cross threads.
unsafe impl<T: Send> Send for Producer<T> {}
227
impl<T> Producer<T> {
    /// Pushes a value into the queue.
    ///
    /// Returns `Err(Full(value))` if the queue is full, returning ownership
    /// of the value to the caller for backpressure handling.
    ///
    /// This method spins internally on CAS contention but returns immediately
    /// when the queue is actually full.
    #[inline]
    #[must_use = "push returns Err if full, which should be handled"]
    pub fn push(&mut self, value: T) -> Result<(), Full<T>> {
        let mut spin_count = 0u32;

        loop {
            let tail = self.shared.tail.load(Ordering::Relaxed);

            // Check against cached head: avoids an atomic load of the shared
            // head on most pushes. The cache can only lag behind the real
            // head, so this check is conservative (may report full early,
            // never late).
            if tail.wrapping_sub(self.cached_head) >= self.capacity {
                // Cache miss: refresh from shared head.
                self.cached_head = self.shared.head.load(Ordering::Acquire);

                // Re-check with fresh head - if still full, return error.
                if tail.wrapping_sub(self.cached_head) >= self.capacity {
                    return Err(Full(value));
                }
            }

            // SAFETY: slots pointer is valid for the lifetime of shared.
            let slot = unsafe { &*self.slots.add(tail & self.mask) };
            let turn = tail >> self.shift;
            let expected_stamp = turn * 2;

            // Check if slot is ready BEFORE attempting CAS (Vyukov
            // optimization): never claim a slot the consumer has not freed.
            // Acquire pairs with the consumer's Release store on pop.
            let stamp = slot.turn.load(Ordering::Acquire);

            if stamp == expected_stamp {
                // Slot is ready - try to claim it. Relaxed suffices here:
                // the slot data is synchronized by the turn counter, not by
                // the tail index itself.
                if self
                    .shared
                    .tail
                    .compare_exchange_weak(
                        tail,
                        tail.wrapping_add(1),
                        Ordering::Relaxed,
                        Ordering::Relaxed,
                    )
                    .is_ok()
                {
                    // SAFETY: We own this slot via successful CAS, and the
                    // even stamp guarantees the previous occupant was read.
                    unsafe { (*slot.data.get()).write(value) };

                    // Signal ready for consumer: turn * 2 + 1. Release
                    // publishes the data write above to the consumer.
                    slot.turn.store(turn * 2 + 1, Ordering::Release);

                    return Ok(());
                }
            }

            // CAS failed or slot not ready - exponential backoff.
            // Cap at 6 to avoid excessive spinning (1, 2, 4, 8, 16, 32, 64 iterations)
            let spins = 1 << spin_count.min(6);
            for _ in 0..spins {
                std::hint::spin_loop();
            }
            spin_count += 1;
        }
    }

    /// Returns the capacity of the queue (always a power of two).
    #[inline]
    pub fn capacity(&self) -> usize {
        1 << self.shift
    }

    /// Returns `true` if this handle is the only one left alive, i.e. the
    /// consumer and every other producer clone have been dropped.
    ///
    /// NOTE(review): while other producer clones are alive this stays `false`
    /// even after the consumer is gone - it cannot detect consumer-only
    /// disconnection.
    #[inline]
    pub fn is_disconnected(&self) -> bool {
        Arc::strong_count(&self.shared) == 1
    }
}
311
312impl<T> fmt::Debug for Producer<T> {
313    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
314        f.debug_struct("Producer")
315            .field("capacity", &self.capacity())
316            .finish_non_exhaustive()
317    }
318}
319
/// The consumer endpoint of an MPSC queue.
///
/// This endpoint cannot be cloned - only one consumer thread is allowed.
// repr(C): Hot fields at struct base share cache line with struct pointer.
#[repr(C)]
pub struct Consumer<T> {
    /// Local head index - only this thread reads/writes it; published to
    /// `shared.head` after each pop so producers can observe freed capacity.
    local_head: usize,
    /// Cached slots pointer (avoids Arc deref on hot path).
    slots: *mut Slot<T>,
    /// Cached mask (avoids Arc deref on hot path).
    mask: usize,
    /// Cached shift for fast division (log2(capacity)).
    shift: u32,
    /// Keeps the shared state (and thus the slot allocation) alive.
    shared: Arc<Shared<T>>,
}
336
// SAFETY: Consumer can be sent to another thread. It has exclusive read
// access to written slots (via the turn protocol), solely owns the head
// index, and T: Send lets values cross threads.
unsafe impl<T: Send> Send for Consumer<T> {}
340
impl<T> Consumer<T> {
    /// Pops a value from the queue.
    ///
    /// Returns `None` if the queue is empty.
    #[inline]
    pub fn pop(&mut self) -> Option<T> {
        let head = self.local_head;
        // SAFETY: slots pointer is valid for the lifetime of shared.
        let slot = unsafe { &*self.slots.add(head & self.mask) };
        let turn = head >> self.shift;

        // Check if slot is ready (odd stamp for this lap means a producer
        // finished writing). Acquire pairs with the producer's Release store.
        if slot.turn.load(Ordering::Acquire) != turn * 2 + 1 {
            return None;
        }

        // SAFETY: Turn counter confirms producer has written to this slot.
        let value = unsafe { (*slot.data.get()).assume_init_read() };

        // Signal slot is free for next lap: (turn + 1) * 2. Release ensures
        // the read above completes before a producer reuses the slot.
        slot.turn.store((turn + 1) * 2, Ordering::Release);

        // Advance head and publish for producers' capacity check.
        self.local_head = head.wrapping_add(1);
        self.shared.head.store(self.local_head, Ordering::Release);

        Some(value)
    }

    /// Returns the capacity of the queue (always a power of two).
    #[inline]
    pub fn capacity(&self) -> usize {
        1 << self.shift
    }

    /// Returns `true` if all producers have been dropped.
    #[inline]
    pub fn is_disconnected(&self) -> bool {
        Arc::strong_count(&self.shared) == 1
    }
}
382
383impl<T> fmt::Debug for Consumer<T> {
384    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
385        f.debug_struct("Consumer")
386            .field("capacity", &self.capacity())
387            .finish_non_exhaustive()
388    }
389}
390
391#[cfg(test)]
392mod tests {
393    use super::*;
394
395    // ============================================================================
396    // Basic Operations
397    // ============================================================================
398
399    #[test]
400    fn basic_push_pop() {
401        let (mut tx, mut rx) = bounded::<u64>(4);
402
403        assert!(tx.push(1).is_ok());
404        assert!(tx.push(2).is_ok());
405        assert!(tx.push(3).is_ok());
406
407        assert_eq!(rx.pop(), Some(1));
408        assert_eq!(rx.pop(), Some(2));
409        assert_eq!(rx.pop(), Some(3));
410        assert_eq!(rx.pop(), None);
411    }
412
413    #[test]
414    fn empty_pop_returns_none() {
415        let (_, mut rx) = bounded::<u64>(4);
416        assert_eq!(rx.pop(), None);
417        assert_eq!(rx.pop(), None);
418    }
419
420    #[test]
421    fn fill_then_drain() {
422        let (mut tx, mut rx) = bounded::<u64>(4);
423
424        for i in 0..4 {
425            assert!(tx.push(i).is_ok());
426        }
427
428        for i in 0..4 {
429            assert_eq!(rx.pop(), Some(i));
430        }
431
432        assert_eq!(rx.pop(), None);
433    }
434
435    #[test]
436    fn push_returns_error_when_full() {
437        let (mut tx, _rx) = bounded::<u64>(4);
438
439        assert!(tx.push(1).is_ok());
440        assert!(tx.push(2).is_ok());
441        assert!(tx.push(3).is_ok());
442        assert!(tx.push(4).is_ok());
443
444        let err = tx.push(5).unwrap_err();
445        assert_eq!(err.into_inner(), 5);
446    }
447
448    // ============================================================================
449    // Interleaved Operations
450    // ============================================================================
451
452    #[test]
453    fn interleaved_single_producer() {
454        let (mut tx, mut rx) = bounded::<u64>(8);
455
456        for i in 0..1000 {
457            assert!(tx.push(i).is_ok());
458            assert_eq!(rx.pop(), Some(i));
459        }
460    }
461
462    #[test]
463    fn partial_fill_drain_cycles() {
464        let (mut tx, mut rx) = bounded::<u64>(8);
465
466        for round in 0..100 {
467            for i in 0..4 {
468                assert!(tx.push(round * 4 + i).is_ok());
469            }
470
471            for i in 0..4 {
472                assert_eq!(rx.pop(), Some(round * 4 + i));
473            }
474        }
475    }
476
477    // ============================================================================
478    // Multiple Producers
479    // ============================================================================
480
481    #[test]
482    fn two_producers_single_consumer() {
483        use std::thread;
484
485        let (mut tx, mut rx) = bounded::<u64>(64);
486        let mut tx2 = tx.clone();
487
488        let h1 = thread::spawn(move || {
489            for i in 0..1000 {
490                while tx.push(i).is_err() {
491                    std::hint::spin_loop();
492                }
493            }
494        });
495
496        let h2 = thread::spawn(move || {
497            for i in 1000..2000 {
498                while tx2.push(i).is_err() {
499                    std::hint::spin_loop();
500                }
501            }
502        });
503
504        let mut received = Vec::new();
505        while received.len() < 2000 {
506            if let Some(val) = rx.pop() {
507                received.push(val);
508            } else {
509                std::hint::spin_loop();
510            }
511        }
512
513        h1.join().unwrap();
514        h2.join().unwrap();
515
516        // All values received (order not guaranteed across producers)
517        received.sort();
518        assert_eq!(received, (0..2000).collect::<Vec<_>>());
519    }
520
521    #[test]
522    fn four_producers_single_consumer() {
523        use std::thread;
524
525        let (tx, mut rx) = bounded::<u64>(256);
526
527        let handles: Vec<_> = (0..4)
528            .map(|p| {
529                let mut tx = tx.clone();
530                thread::spawn(move || {
531                    for i in 0..1000 {
532                        let val = p * 1000 + i;
533                        while tx.push(val).is_err() {
534                            std::hint::spin_loop();
535                        }
536                    }
537                })
538            })
539            .collect();
540
541        drop(tx); // Drop original producer
542
543        let mut received = Vec::new();
544        while received.len() < 4000 {
545            if let Some(val) = rx.pop() {
546                received.push(val);
547            } else if rx.is_disconnected() && received.len() < 4000 {
548                // Keep trying if not all received
549                std::hint::spin_loop();
550            } else {
551                std::hint::spin_loop();
552            }
553        }
554
555        for h in handles {
556            h.join().unwrap();
557        }
558
559        received.sort();
560        let expected: Vec<u64> = (0..4)
561            .flat_map(|p| (0..1000).map(move |i| p * 1000 + i))
562            .collect();
563        let mut expected_sorted = expected;
564        expected_sorted.sort();
565        assert_eq!(received, expected_sorted);
566    }
567
568    // ============================================================================
569    // Single Slot
570    // ============================================================================
571
572    #[test]
573    fn single_slot_bounded() {
574        let (mut tx, mut rx) = bounded::<u64>(1);
575
576        assert!(tx.push(1).is_ok());
577        assert!(tx.push(2).is_err());
578
579        assert_eq!(rx.pop(), Some(1));
580        assert!(tx.push(2).is_ok());
581    }
582
583    // ============================================================================
584    // Disconnection
585    // ============================================================================
586
587    #[test]
588    fn producer_disconnected() {
589        let (tx, rx) = bounded::<u64>(4);
590
591        assert!(!rx.is_disconnected());
592        drop(tx);
593        assert!(rx.is_disconnected());
594    }
595
596    #[test]
597    fn consumer_disconnected() {
598        let (tx, rx) = bounded::<u64>(4);
599
600        assert!(!tx.is_disconnected());
601        drop(rx);
602        assert!(tx.is_disconnected());
603    }
604
605    #[test]
606    fn multiple_producers_one_disconnects() {
607        let (tx1, rx) = bounded::<u64>(4);
608        let tx2 = tx1.clone();
609
610        assert!(!rx.is_disconnected());
611        drop(tx1);
612        assert!(!rx.is_disconnected()); // tx2 still alive
613        drop(tx2);
614        assert!(rx.is_disconnected());
615    }
616
617    // ============================================================================
618    // Drop Behavior
619    // ============================================================================
620
621    #[test]
622    fn drop_cleans_up_remaining() {
623        use std::sync::atomic::AtomicUsize;
624
625        static DROP_COUNT: AtomicUsize = AtomicUsize::new(0);
626
627        struct DropCounter;
628        impl Drop for DropCounter {
629            fn drop(&mut self) {
630                DROP_COUNT.fetch_add(1, Ordering::SeqCst);
631            }
632        }
633
634        DROP_COUNT.store(0, Ordering::SeqCst);
635
636        let (mut tx, rx) = bounded::<DropCounter>(4);
637
638        let _ = tx.push(DropCounter);
639        let _ = tx.push(DropCounter);
640        let _ = tx.push(DropCounter);
641
642        assert_eq!(DROP_COUNT.load(Ordering::SeqCst), 0);
643
644        drop(tx);
645        drop(rx);
646
647        assert_eq!(DROP_COUNT.load(Ordering::SeqCst), 3);
648    }
649
650    // ============================================================================
651    // Special Types
652    // ============================================================================
653
654    #[test]
655    fn zero_sized_type() {
656        let (mut tx, mut rx) = bounded::<()>(8);
657
658        let _ = tx.push(());
659        let _ = tx.push(());
660
661        assert_eq!(rx.pop(), Some(()));
662        assert_eq!(rx.pop(), Some(()));
663        assert_eq!(rx.pop(), None);
664    }
665
666    #[test]
667    fn string_type() {
668        let (mut tx, mut rx) = bounded::<String>(4);
669
670        let _ = tx.push("hello".to_string());
671        let _ = tx.push("world".to_string());
672
673        assert_eq!(rx.pop(), Some("hello".to_string()));
674        assert_eq!(rx.pop(), Some("world".to_string()));
675    }
676
677    #[test]
678    #[should_panic(expected = "capacity must be non-zero")]
679    fn zero_capacity_panics() {
680        let _ = bounded::<u64>(0);
681    }
682
683    #[test]
684    fn large_message_type() {
685        #[repr(C, align(64))]
686        struct LargeMessage {
687            data: [u8; 256],
688        }
689
690        let (mut tx, mut rx) = bounded::<LargeMessage>(8);
691
692        let msg = LargeMessage { data: [42u8; 256] };
693        assert!(tx.push(msg).is_ok());
694
695        let received = rx.pop().unwrap();
696        assert_eq!(received.data[0], 42);
697        assert_eq!(received.data[255], 42);
698    }
699
700    #[test]
701    fn multiple_laps() {
702        let (mut tx, mut rx) = bounded::<u64>(4);
703
704        // 10 full laps through 4-slot buffer
705        for i in 0..40 {
706            assert!(tx.push(i).is_ok());
707            assert_eq!(rx.pop(), Some(i));
708        }
709    }
710
711    #[test]
712    fn capacity_rounds_to_power_of_two() {
713        let (tx, _) = bounded::<u64>(100);
714        assert_eq!(tx.capacity(), 128);
715
716        let (tx, _) = bounded::<u64>(1000);
717        assert_eq!(tx.capacity(), 1024);
718    }
719
720    // ============================================================================
721    // Stress Tests
722    // ============================================================================
723
724    #[test]
725    fn stress_single_producer() {
726        use std::thread;
727
728        const COUNT: u64 = 100_000;
729
730        let (mut tx, mut rx) = bounded::<u64>(1024);
731
732        let producer = thread::spawn(move || {
733            for i in 0..COUNT {
734                while tx.push(i).is_err() {
735                    std::hint::spin_loop();
736                }
737            }
738        });
739
740        let consumer = thread::spawn(move || {
741            let mut sum = 0u64;
742            let mut received = 0u64;
743            while received < COUNT {
744                if let Some(val) = rx.pop() {
745                    sum = sum.wrapping_add(val);
746                    received += 1;
747                } else {
748                    std::hint::spin_loop();
749                }
750            }
751            sum
752        });
753
754        producer.join().unwrap();
755        let sum = consumer.join().unwrap();
756        assert_eq!(sum, COUNT * (COUNT - 1) / 2);
757    }
758
759    #[test]
760    fn stress_multiple_producers() {
761        use std::thread;
762
763        const PRODUCERS: u64 = 4;
764        const PER_PRODUCER: u64 = 25_000;
765        const TOTAL: u64 = PRODUCERS * PER_PRODUCER;
766
767        let (tx, mut rx) = bounded::<u64>(1024);
768
769        let handles: Vec<_> = (0..PRODUCERS)
770            .map(|_| {
771                let mut tx = tx.clone();
772                thread::spawn(move || {
773                    for i in 0..PER_PRODUCER {
774                        while tx.push(i).is_err() {
775                            std::hint::spin_loop();
776                        }
777                    }
778                })
779            })
780            .collect();
781
782        drop(tx);
783
784        let mut received = 0u64;
785        while received < TOTAL {
786            if rx.pop().is_some() {
787                received += 1;
788            } else {
789                std::hint::spin_loop();
790            }
791        }
792
793        for h in handles {
794            h.join().unwrap();
795        }
796
797        assert_eq!(received, TOTAL);
798    }
799}