Skip to main content

nexus_queue/
mpsc.rs

1//! Multi-producer single-consumer bounded queue.
2//!
3//! A lock-free ring buffer optimized for multiple producer threads sending to
4//! one consumer thread. Uses CAS-based slot claiming with Vyukov-style turn
5//! counters for synchronization.
6//!
7//! # Design
8//!
9//! ```text
10//! ┌─────────────────────────────────────────────────────────────┐
11//! │ Shared (Arc):                                               │
12//! │   tail: CachePadded<AtomicUsize>   ← Producers CAS here     │
13//! │   head: CachePadded<AtomicUsize>   ← Consumer writes        │
14//! │   slots: *mut Slot<T>              ← Per-slot turn counters │
15//! └─────────────────────────────────────────────────────────────┘
16//!
17//! ┌─────────────────────┐     ┌─────────────────────┐
18//! │ Producer (Clone):   │     │ Consumer (!Clone):  │
19//! │   cached_head       │     │   local_head        │
20//! │   shared: Arc       │     │   shared: Arc       │
21//! └─────────────────────┘     └─────────────────────┘
22//! ```
23//!
24//! Producers compete via CAS on the tail index. After claiming a slot, the
25//! producer waits for the slot's turn counter to indicate it's writable, writes
26//! the data, then advances the turn to signal readiness.
27//!
28//! The consumer checks the turn counter to know when data is ready, reads it,
29//! then advances the turn for the next producer lap.
30//!
31//! # Turn Counter Protocol
32//!
33//! For slot at index `i` on lap `turn`:
34//! - `turn * 2`: Slot is ready for producer to write
35//! - `turn * 2 + 1`: Slot contains data, ready for consumer
36//!
37//! # Example
38//!
39//! ```
40//! use nexus_queue::mpsc;
41//! use std::thread;
42//!
43//! let (tx, rx) = mpsc::bounded::<u64>(1024);
44//!
45//! let tx2 = tx.clone();
46//! let h1 = thread::spawn(move || {
47//!     for i in 0..100 {
48//!         while tx.push(i).is_err() { std::hint::spin_loop(); }
49//!     }
50//! });
51//! let h2 = thread::spawn(move || {
52//!     for i in 100..200 {
53//!         while tx2.push(i).is_err() { std::hint::spin_loop(); }
54//!     }
55//! });
56//!
57//! let mut received = 0;
58//! while received < 200 {
59//!     if rx.pop().is_some() { received += 1; }
60//! }
61//!
62//! h1.join().unwrap();
63//! h2.join().unwrap();
64//! ```
65
66use std::cell::{Cell, UnsafeCell};
67use std::fmt;
68use std::mem::MaybeUninit;
69use std::sync::Arc;
70use std::sync::atomic::{AtomicUsize, Ordering};
71
72use crossbeam_utils::CachePadded;
73
74use crate::Full;
75
76/// Creates a bounded MPSC queue with the given capacity.
77///
78/// Capacity is rounded up to the next power of two.
79///
80/// # Panics
81///
82/// Panics if `capacity` is zero.
83pub fn bounded<T>(capacity: usize) -> (Producer<T>, Consumer<T>) {
84    assert!(capacity > 0, "capacity must be non-zero");
85
86    let capacity = capacity.next_power_of_two();
87    let mask = capacity - 1;
88
89    // Allocate slots with turn counters initialized to 0 (ready for turn 0 producers)
90    let slots: Vec<Slot<T>> = (0..capacity)
91        .map(|_| Slot {
92            turn: AtomicUsize::new(0),
93            data: UnsafeCell::new(MaybeUninit::uninit()),
94        })
95        .collect();
96    let slots = Box::into_raw(slots.into_boxed_slice()) as *mut Slot<T>;
97
98    let shift = capacity.trailing_zeros();
99
100    let shared = Arc::new(Shared {
101        tail: CachePadded::new(AtomicUsize::new(0)),
102        head: CachePadded::new(AtomicUsize::new(0)),
103        slots,
104        capacity,
105        shift,
106        mask,
107    });
108
109    (
110        Producer {
111            cached_head: Cell::new(0),
112            slots,
113            mask,
114            capacity,
115            shift,
116            shared: Arc::clone(&shared),
117        },
118        Consumer {
119            local_head: Cell::new(0),
120            slots,
121            mask,
122            shift,
123            shared,
124        },
125    )
126}
127
/// A slot in the ring buffer with turn-based synchronization.
///
/// Each slot carries its own stamp so producers and the consumer coordinate
/// per-slot instead of contending on a single shared flag.
struct Slot<T> {
    /// Turn counter for Vyukov-style synchronization. For lap `turn`:
    /// - `turn * 2`: ready for producer
    /// - `turn * 2 + 1`: ready for consumer
    turn: AtomicUsize,
    /// The data stored in this slot.
    /// Only initialized while the turn counter is odd (consumer-ready).
    data: UnsafeCell<MaybeUninit<T>>,
}
137
/// Shared state between producers and the consumer.
///
/// Held via `Arc` by every `Producer` and the `Consumer`; the slot array
/// (and any unconsumed elements) is freed when the last handle drops.
// repr(C): Guarantees field order for cache line layout.
#[repr(C)]
struct Shared<T> {
    /// Tail index - producers CAS on this to claim slots.
    tail: CachePadded<AtomicUsize>,
    /// Head index - consumer publishes progress here.
    head: CachePadded<AtomicUsize>,
    /// Pointer to the slot array.
    /// Allocated in `bounded` via `Box::into_raw`; released in `Drop`.
    slots: *mut Slot<T>,
    /// Actual capacity (power of two).
    capacity: usize,
    /// Shift for fast division by capacity (log2(capacity)).
    shift: u32,
    /// Mask for fast modulo (capacity - 1).
    mask: usize,
}
155
// SAFETY: Shared contains atomics and raw pointers. Access is synchronized via
// the turn counters. T: Send ensures data can move between threads. Sync is
// needed so the one allocation can be referenced concurrently from producer
// and consumer threads through the Arc.
unsafe impl<T: Send> Send for Shared<T> {}
unsafe impl<T: Send> Sync for Shared<T> {}
160
impl<T> Drop for Shared<T> {
    /// Drops any elements still in flight, then frees the slot array.
    ///
    /// Relaxed loads are sufficient throughout: `drop` runs with `&mut self`,
    /// which means the last `Arc` handle is gone, and `Arc`'s refcount
    /// decrement already synchronizes all prior pushes/pops with this thread.
    fn drop(&mut self) {
        let head = self.head.load(Ordering::Relaxed);
        let tail = self.tail.load(Ordering::Relaxed);

        // Drop any remaining elements in the queue
        let mut i = head;
        while i != tail {
            let slot = unsafe { &*self.slots.add(i & self.mask) };
            let turn = i >> self.shift;

            // Only drop if the slot was actually written (turn is odd = consumer-ready)
            // This is defensive: with all handles dropped, every claimed slot
            // should have completed its write, but we verify rather than assume.
            if slot.turn.load(Ordering::Relaxed) == turn * 2 + 1 {
                // SAFETY: Slot contains initialized data at this turn.
                unsafe { (*slot.data.get()).assume_init_drop() };
            }
            // wrapping_add: indices are monotonically increasing and wrap at
            // usize::MAX by design.
            i = i.wrapping_add(1);
        }

        // SAFETY: slots was allocated via Box::into_raw from a Vec.
        unsafe {
            let _ = Box::from_raw(std::ptr::slice_from_raw_parts_mut(
                self.slots,
                self.capacity,
            ));
        }
    }
}
189
/// The producer endpoint of an MPSC queue.
///
/// This endpoint can be cloned to create additional producers. Each clone
/// maintains its own cached state for performance.
// repr(C): Hot fields at struct base share cache line with struct pointer.
#[repr(C)]
pub struct Producer<T> {
    /// Cached head for fast full-check. Only refreshed when cache indicates full.
    /// May lag the real head (making the full-check conservative) but is never
    /// ahead of it.
    cached_head: Cell<usize>,
    /// Cached slots pointer (avoids Arc deref on hot path).
    /// Valid for as long as `shared` below is alive, which this struct guarantees.
    slots: *mut Slot<T>,
    /// Cached mask (avoids Arc deref on hot path).
    mask: usize,
    /// Cached capacity (avoids Arc deref on hot path).
    capacity: usize,
    /// Cached shift for fast division (log2(capacity)).
    shift: u32,
    /// Keeps the shared state (and thus `slots`) alive.
    shared: Arc<Shared<T>>,
}
209
210impl<T> Clone for Producer<T> {
211    fn clone(&self) -> Self {
212        Producer {
213            // Fresh cache - will be populated on first push
214            cached_head: Cell::new(self.shared.head.load(Ordering::Relaxed)),
215            slots: self.slots,
216            mask: self.mask,
217            capacity: self.capacity,
218            shift: self.shift,
219            shared: Arc::clone(&self.shared),
220        }
221    }
222}
223
// SAFETY: Producer can be sent to another thread. Each Producer instance is
// used by one thread (not Sync - use clone() for multiple threads). The raw
// `slots` pointer stays valid because `shared` keeps the allocation alive.
unsafe impl<T: Send> Send for Producer<T> {}
227
impl<T> Producer<T> {
    /// Pushes a value into the queue.
    ///
    /// Returns `Err(Full(value))` if the queue is full, returning ownership
    /// of the value to the caller for backpressure handling.
    ///
    /// This method spins internally on CAS contention but returns immediately
    /// when the queue is actually full.
    #[inline]
    #[must_use = "push returns Err if full, which should be handled"]
    pub fn push(&self, value: T) -> Result<(), Full<T>> {
        let mut spin_count = 0u32;

        loop {
            // Relaxed is enough: the CAS below re-validates this value, and
            // data visibility is carried by the slot's turn counter.
            let tail = self.shared.tail.load(Ordering::Relaxed);

            // Check against cached head (avoids atomic load most of the time)
            if tail.wrapping_sub(self.cached_head.get()) >= self.capacity {
                // Cache miss: refresh from shared head.
                // Acquire pairs with the consumer's Release store of `head`,
                // which also makes its earlier turn-counter store visible.
                self.cached_head.set(self.shared.head.load(Ordering::Acquire));

                // Re-check with fresh head - if still full, return error
                if tail.wrapping_sub(self.cached_head.get()) >= self.capacity {
                    return Err(Full(value));
                }
            }

            // SAFETY: slots pointer is valid for the lifetime of shared.
            let slot = unsafe { &*self.slots.add(tail & self.mask) };
            let turn = tail >> self.shift;
            let expected_stamp = turn * 2;

            // Check if slot is ready BEFORE attempting CAS (Vyukov optimization)
            let stamp = slot.turn.load(Ordering::Acquire);

            if stamp == expected_stamp {
                // Slot is ready - try to claim it. Relaxed CAS is sound:
                // ownership of the slot's data is published via the turn
                // counter below, not via `tail` itself.
                if self
                    .shared
                    .tail
                    .compare_exchange_weak(
                        tail,
                        tail.wrapping_add(1),
                        Ordering::Relaxed,
                        Ordering::Relaxed,
                    )
                    .is_ok()
                {
                    // SAFETY: We own this slot via successful CAS.
                    unsafe { (*slot.data.get()).write(value) };

                    // Signal ready for consumer: turn * 2 + 1
                    // Release publishes the data write above to the consumer's
                    // Acquire load of the turn counter.
                    slot.turn.store(turn * 2 + 1, Ordering::Release);

                    return Ok(());
                }
            }

            // CAS failed or slot not ready - exponential backoff
            // Cap at 6 to avoid excessive spinning (1, 2, 4, 8, 16, 32, 64 iterations)
            let spins = 1 << spin_count.min(6);
            for _ in 0..spins {
                std::hint::spin_loop();
            }
            spin_count += 1;
        }
    }

    /// Returns the capacity of the queue.
    #[inline]
    pub fn capacity(&self) -> usize {
        // `shift` is log2(capacity), so this reconstructs the power of two.
        1 << self.shift
    }

    /// Returns `true` if the consumer has been dropped.
    ///
    /// With multiple producers, this returns `true` only when this is the
    /// last handle (all other producers and the consumer are dropped).
    #[inline]
    pub fn is_disconnected(&self) -> bool {
        // NOTE(review): strong_count == 1 cannot distinguish "consumer gone"
        // from "all other handles gone"; the doc caveat above reflects that.
        Arc::strong_count(&self.shared) == 1
    }
}
311
312impl<T> fmt::Debug for Producer<T> {
313    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
314        f.debug_struct("Producer")
315            .field("capacity", &self.capacity())
316            .finish_non_exhaustive()
317    }
318}
319
/// The consumer endpoint of an MPSC queue.
///
/// This endpoint cannot be cloned - only one consumer thread is allowed.
// repr(C): Hot fields at struct base share cache line with struct pointer.
#[repr(C)]
pub struct Consumer<T> {
    /// Local head index - only this thread reads/writes.
    /// Mirrored into `shared.head` after each pop so producers can see it.
    local_head: Cell<usize>,
    /// Cached slots pointer (avoids Arc deref on hot path).
    /// Valid for as long as `shared` below is alive, which this struct guarantees.
    slots: *mut Slot<T>,
    /// Cached mask (avoids Arc deref on hot path).
    mask: usize,
    /// Cached shift for fast division (log2(capacity)).
    shift: u32,
    /// Keeps the shared state (and thus `slots`) alive.
    shared: Arc<Shared<T>>,
}
336
// SAFETY: Consumer can be sent to another thread. It has exclusive read access
// to slots (via turn protocol) and maintains the head index. It is not Clone,
// so at most one thread ever pops.
unsafe impl<T: Send> Send for Consumer<T> {}
340
impl<T> Consumer<T> {
    /// Pops a value from the queue.
    ///
    /// Returns `None` if the queue is empty.
    #[inline]
    pub fn pop(&self) -> Option<T> {
        let head = self.local_head.get();
        // SAFETY: slots pointer is valid for the lifetime of shared.
        let slot = unsafe { &*self.slots.add(head & self.mask) };
        let turn = head >> self.shift;

        // Check if slot is ready (turn * 2 + 1 means producer has written).
        // Acquire pairs with the producer's Release store of the stamp,
        // making the data write visible before we read it below.
        if slot.turn.load(Ordering::Acquire) != turn * 2 + 1 {
            return None;
        }

        // SAFETY: Turn counter confirms producer has written to this slot.
        let value = unsafe { (*slot.data.get()).assume_init_read() };

        // Signal slot is free for next lap: (turn + 1) * 2
        // Must happen before the head store below: a producer that observes
        // the advanced head may immediately try to reuse this slot.
        slot.turn.store((turn + 1) * 2, Ordering::Release);

        // Advance head and publish for producers' capacity check
        let new_head = head.wrapping_add(1);
        self.local_head.set(new_head);
        self.shared.head.store(new_head, Ordering::Release);

        Some(value)
    }

    /// Returns the capacity of the queue.
    #[inline]
    pub fn capacity(&self) -> usize {
        // `shift` is log2(capacity), so this reconstructs the power of two.
        1 << self.shift
    }

    /// Returns `true` if all producers have been dropped.
    #[inline]
    pub fn is_disconnected(&self) -> bool {
        Arc::strong_count(&self.shared) == 1
    }
}
383
384impl<T> fmt::Debug for Consumer<T> {
385    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
386        f.debug_struct("Consumer")
387            .field("capacity", &self.capacity())
388            .finish_non_exhaustive()
389    }
390}
391
#[cfg(test)]
mod tests {
    use super::*;

    // ============================================================================
    // Basic Operations
    // ============================================================================

    #[test]
    fn basic_push_pop() {
        let (tx, rx) = bounded::<u64>(4);

        assert!(tx.push(1).is_ok());
        assert!(tx.push(2).is_ok());
        assert!(tx.push(3).is_ok());

        assert_eq!(rx.pop(), Some(1));
        assert_eq!(rx.pop(), Some(2));
        assert_eq!(rx.pop(), Some(3));
        assert_eq!(rx.pop(), None);
    }

    #[test]
    fn empty_pop_returns_none() {
        let (_, rx) = bounded::<u64>(4);
        assert_eq!(rx.pop(), None);
        assert_eq!(rx.pop(), None);
    }

    #[test]
    fn fill_then_drain() {
        let (tx, rx) = bounded::<u64>(4);

        for i in 0..4 {
            assert!(tx.push(i).is_ok());
        }

        for i in 0..4 {
            assert_eq!(rx.pop(), Some(i));
        }

        assert_eq!(rx.pop(), None);
    }

    #[test]
    fn push_returns_error_when_full() {
        let (tx, _rx) = bounded::<u64>(4);

        assert!(tx.push(1).is_ok());
        assert!(tx.push(2).is_ok());
        assert!(tx.push(3).is_ok());
        assert!(tx.push(4).is_ok());

        // A full queue must hand the rejected value back for backpressure.
        let err = tx.push(5).unwrap_err();
        assert_eq!(err.into_inner(), 5);
    }

    // ============================================================================
    // Interleaved Operations
    // ============================================================================

    #[test]
    fn interleaved_single_producer() {
        let (tx, rx) = bounded::<u64>(8);

        for i in 0..1000 {
            assert!(tx.push(i).is_ok());
            assert_eq!(rx.pop(), Some(i));
        }
    }

    #[test]
    fn partial_fill_drain_cycles() {
        let (tx, rx) = bounded::<u64>(8);

        for round in 0..100 {
            for i in 0..4 {
                assert!(tx.push(round * 4 + i).is_ok());
            }

            for i in 0..4 {
                assert_eq!(rx.pop(), Some(round * 4 + i));
            }
        }
    }

    // ============================================================================
    // Multiple Producers
    // ============================================================================

    #[test]
    fn two_producers_single_consumer() {
        use std::thread;

        let (tx, rx) = bounded::<u64>(64);
        let tx2 = tx.clone();

        let h1 = thread::spawn(move || {
            for i in 0..1000 {
                while tx.push(i).is_err() {
                    std::hint::spin_loop();
                }
            }
        });

        let h2 = thread::spawn(move || {
            for i in 1000..2000 {
                while tx2.push(i).is_err() {
                    std::hint::spin_loop();
                }
            }
        });

        let mut received = Vec::new();
        while received.len() < 2000 {
            if let Some(val) = rx.pop() {
                received.push(val);
            } else {
                std::hint::spin_loop();
            }
        }

        h1.join().unwrap();
        h2.join().unwrap();

        // All values received (order not guaranteed across producers)
        received.sort_unstable();
        assert_eq!(received, (0..2000).collect::<Vec<_>>());
    }

    #[test]
    fn four_producers_single_consumer() {
        use std::thread;

        let (tx, rx) = bounded::<u64>(256);

        let handles: Vec<_> = (0..4)
            .map(|p| {
                let tx = tx.clone();
                thread::spawn(move || {
                    for i in 0..1000 {
                        let val = p * 1000 + i;
                        while tx.push(val).is_err() {
                            std::hint::spin_loop();
                        }
                    }
                })
            })
            .collect();

        drop(tx); // Drop original producer

        // Fixed: the previous version had a dead `else if rx.is_disconnected()`
        // branch that duplicated the plain `else` arm; a single spin path is
        // sufficient since we loop until all 4000 values arrive.
        let mut received = Vec::new();
        while received.len() < 4000 {
            if let Some(val) = rx.pop() {
                received.push(val);
            } else {
                std::hint::spin_loop();
            }
        }

        for h in handles {
            h.join().unwrap();
        }

        // Producer p sends p*1000..(p+1)*1000, so together they cover 0..4000.
        received.sort_unstable();
        assert_eq!(received, (0..4000).collect::<Vec<u64>>());
    }

    // ============================================================================
    // Single Slot
    // ============================================================================

    #[test]
    fn single_slot_bounded() {
        let (tx, rx) = bounded::<u64>(1);

        assert!(tx.push(1).is_ok());
        assert!(tx.push(2).is_err());

        assert_eq!(rx.pop(), Some(1));
        assert!(tx.push(2).is_ok());
    }

    // ============================================================================
    // Disconnection
    // ============================================================================

    #[test]
    fn producer_disconnected() {
        let (tx, rx) = bounded::<u64>(4);

        assert!(!rx.is_disconnected());
        drop(tx);
        assert!(rx.is_disconnected());
    }

    #[test]
    fn consumer_disconnected() {
        let (tx, rx) = bounded::<u64>(4);

        assert!(!tx.is_disconnected());
        drop(rx);
        assert!(tx.is_disconnected());
    }

    #[test]
    fn multiple_producers_one_disconnects() {
        let (tx1, rx) = bounded::<u64>(4);
        let tx2 = tx1.clone();

        assert!(!rx.is_disconnected());
        drop(tx1);
        assert!(!rx.is_disconnected()); // tx2 still alive
        drop(tx2);
        assert!(rx.is_disconnected());
    }

    // ============================================================================
    // Drop Behavior
    // ============================================================================

    #[test]
    fn drop_cleans_up_remaining() {
        use std::sync::atomic::AtomicUsize;

        static DROP_COUNT: AtomicUsize = AtomicUsize::new(0);

        struct DropCounter;
        impl Drop for DropCounter {
            fn drop(&mut self) {
                DROP_COUNT.fetch_add(1, Ordering::SeqCst);
            }
        }

        DROP_COUNT.store(0, Ordering::SeqCst);

        let (tx, rx) = bounded::<DropCounter>(4);

        // Assert success so a failed push (which would drop the counter
        // immediately) cannot silently skew the count below.
        assert!(tx.push(DropCounter).is_ok());
        assert!(tx.push(DropCounter).is_ok());
        assert!(tx.push(DropCounter).is_ok());

        assert_eq!(DROP_COUNT.load(Ordering::SeqCst), 0);

        drop(tx);
        drop(rx);

        assert_eq!(DROP_COUNT.load(Ordering::SeqCst), 3);
    }

    // ============================================================================
    // Special Types
    // ============================================================================

    #[test]
    fn zero_sized_type() {
        let (tx, rx) = bounded::<()>(8);

        assert!(tx.push(()).is_ok());
        assert!(tx.push(()).is_ok());

        assert_eq!(rx.pop(), Some(()));
        assert_eq!(rx.pop(), Some(()));
        assert_eq!(rx.pop(), None);
    }

    #[test]
    fn string_type() {
        let (tx, rx) = bounded::<String>(4);

        assert!(tx.push("hello".to_string()).is_ok());
        assert!(tx.push("world".to_string()).is_ok());

        assert_eq!(rx.pop(), Some("hello".to_string()));
        assert_eq!(rx.pop(), Some("world".to_string()));
    }

    #[test]
    #[should_panic(expected = "capacity must be non-zero")]
    fn zero_capacity_panics() {
        let _ = bounded::<u64>(0);
    }

    #[test]
    fn large_message_type() {
        #[repr(C, align(64))]
        struct LargeMessage {
            data: [u8; 256],
        }

        let (tx, rx) = bounded::<LargeMessage>(8);

        let msg = LargeMessage { data: [42u8; 256] };
        assert!(tx.push(msg).is_ok());

        let received = rx.pop().unwrap();
        assert_eq!(received.data[0], 42);
        assert_eq!(received.data[255], 42);
    }

    #[test]
    fn multiple_laps() {
        let (tx, rx) = bounded::<u64>(4);

        // 10 full laps through 4-slot buffer exercises turn-counter wraparound
        // between laps.
        for i in 0..40 {
            assert!(tx.push(i).is_ok());
            assert_eq!(rx.pop(), Some(i));
        }
    }

    #[test]
    fn capacity_rounds_to_power_of_two() {
        let (tx, _) = bounded::<u64>(100);
        assert_eq!(tx.capacity(), 128);

        let (tx, _) = bounded::<u64>(1000);
        assert_eq!(tx.capacity(), 1024);
    }

    // ============================================================================
    // Stress Tests
    // ============================================================================

    #[test]
    fn stress_single_producer() {
        use std::thread;

        const COUNT: u64 = 100_000;

        let (tx, rx) = bounded::<u64>(1024);

        let producer = thread::spawn(move || {
            for i in 0..COUNT {
                while tx.push(i).is_err() {
                    std::hint::spin_loop();
                }
            }
        });

        let consumer = thread::spawn(move || {
            let mut sum = 0u64;
            let mut received = 0u64;
            while received < COUNT {
                if let Some(val) = rx.pop() {
                    sum = sum.wrapping_add(val);
                    received += 1;
                } else {
                    std::hint::spin_loop();
                }
            }
            sum
        });

        producer.join().unwrap();
        let sum = consumer.join().unwrap();
        assert_eq!(sum, COUNT * (COUNT - 1) / 2);
    }

    #[test]
    fn stress_multiple_producers() {
        use std::thread;

        const PRODUCERS: u64 = 4;
        const PER_PRODUCER: u64 = 25_000;
        const TOTAL: u64 = PRODUCERS * PER_PRODUCER;

        let (tx, rx) = bounded::<u64>(1024);

        let handles: Vec<_> = (0..PRODUCERS)
            .map(|_| {
                let tx = tx.clone();
                thread::spawn(move || {
                    for i in 0..PER_PRODUCER {
                        while tx.push(i).is_err() {
                            std::hint::spin_loop();
                        }
                    }
                })
            })
            .collect();

        drop(tx);

        let mut received = 0u64;
        while received < TOTAL {
            if rx.pop().is_some() {
                received += 1;
            } else {
                std::hint::spin_loop();
            }
        }

        for h in handles {
            h.join().unwrap();
        }

        assert_eq!(received, TOTAL);
    }
}