Skip to main content

nexus_queue/
spsc.rs

1//! Single-producer single-consumer bounded queue.
2//!
3//! A lock-free ring buffer optimized for exactly one producer thread and one
4//! consumer thread. Uses cached indices to minimize atomic operations on the
5//! hot path.
6//!
7//! # Design
8//!
9//! ```text
10//! ┌─────────────────────────────────────────────────────────────┐
11//! │ Shared (Arc):                                               │
12//! │   tail: CachePadded<AtomicUsize>   ← Producer writes        │
13//! │   head: CachePadded<AtomicUsize>   ← Consumer writes        │
14//! │   buffer: *mut T                                            │
15//! └─────────────────────────────────────────────────────────────┘
16//!
17//! ┌─────────────────────┐     ┌─────────────────────┐
18//! │ Producer:           │     │ Consumer:           │
19//! │   local_tail        │     │   local_head        │
20//! │   cached_head       │     │   cached_tail       │
21//! │   buffer (cached)   │     │   buffer (cached)   │
22//! │   mask (cached)     │     │   mask (cached)     │
23//! └─────────────────────┘     └─────────────────────┘
24//! ```
25//!
26//! Producer and consumer each cache the buffer pointer and mask locally to
27//! avoid Arc dereference on every operation. They also maintain a cached copy
28//! of the other's index, only refreshing from the atomic when the cache
29//! indicates the queue is full (producer) or empty (consumer).
30//!
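//! Head and tail are each wrapped in `CachePadded` (128-byte alignment on
//! x86_64 and aarch64; the exact padding is platform-dependent) so they sit on
//! separate cache lines, avoiding false sharing between the producer and
//! consumer threads.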
33//!
34//! # Example
35//!
36//! ```
37//! use nexus_queue::spsc;
38//!
39//! let (tx, rx) = spsc::ring_buffer::<u64>(1024);
40//!
41//! tx.push(42).unwrap();
42//! assert_eq!(rx.pop(), Some(42));
43//! ```
44
45use std::cell::Cell;
46use std::fmt;
47use std::mem::ManuallyDrop;
48use std::sync::Arc;
49use std::sync::atomic::{AtomicUsize, Ordering};
50
51use crossbeam_utils::CachePadded;
52
53use crate::Full;
54
55/// Creates a bounded SPSC ring buffer with the given capacity.
56///
57/// Capacity is rounded up to the next power of two.
58///
59/// # Panics
60///
61/// Panics if `capacity` is zero.
62pub fn ring_buffer<T>(capacity: usize) -> (Producer<T>, Consumer<T>) {
63    assert!(capacity > 0, "capacity must be non-zero");
64
65    let capacity = capacity
66        .checked_next_power_of_two()
67        .expect("capacity too large (must be <= usize::MAX / 2)");
68    let mask = capacity - 1;
69
70    let mut slots = ManuallyDrop::new(Vec::<T>::with_capacity(capacity));
71    let buffer = slots.as_mut_ptr();
72
73    let shared = Arc::new(Shared {
74        tail: CachePadded::new(AtomicUsize::new(0)),
75        head: CachePadded::new(AtomicUsize::new(0)),
76        buffer,
77        mask,
78    });
79
80    (
81        Producer {
82            local_tail: Cell::new(0),
83            cached_head: Cell::new(0),
84            buffer,
85            mask,
86            shared: Arc::clone(&shared),
87        },
88        Consumer {
89            local_head: Cell::new(0),
90            cached_tail: Cell::new(0),
91            buffer,
92            mask,
93            shared,
94        },
95    )
96}
97
// repr(C): Guarantees field order. CachePadded<tail> and CachePadded<head>
// must be at known offsets for cache line isolation to work correctly.
/// State shared by the producer and consumer through an `Arc`.
///
/// `tail` is stored only by `Producer::push` and `head` only by
/// `Consumer::pop`. `buffer` and `mask` are fixed at construction in
/// `ring_buffer` and are used again by `Drop` to free the allocation.
#[repr(C)]
struct Shared<T> {
    /// Free-running (wrapping) index one past the last slot the producer has
    /// published.
    tail: CachePadded<AtomicUsize>,
    /// Free-running (wrapping) index of the next slot the consumer will read.
    head: CachePadded<AtomicUsize>,
    /// Start of the `mask + 1`-slot allocation created in `ring_buffer`.
    buffer: *mut T,
    /// `capacity - 1`. Capacity is a power of two, so `index & mask` maps a
    /// free-running index onto a slot.
    mask: usize,
}

// SAFETY: Shared holds two atomics, a raw pointer to the slot buffer, and a
// plain `usize` mask; none of these is tied to a particular thread. The slots
// are touched only through Producer (writes) and Consumer (reads), which are
// !Sync because they contain `Cell`s, and by `Drop`, which has exclusive
// access. T: Send ensures the stored values may move between threads.
unsafe impl<T: Send> Send for Shared<T> {}
unsafe impl<T: Send> Sync for Shared<T> {}
113
114impl<T> Drop for Shared<T> {
115    fn drop(&mut self) {
116        let head = self.head.load(Ordering::Relaxed);
117        let tail = self.tail.load(Ordering::Relaxed);
118
119        let mut i = head;
120        while i != tail {
121            // SAFETY: Slots in [head, tail) contain initialized values. We have
122            // exclusive access (drop requires &mut self, both endpoints dropped).
123            unsafe { self.buffer.add(i & self.mask).drop_in_place() };
124            i = i.wrapping_add(1);
125        }
126
127        // SAFETY: buffer was allocated by Vec::with_capacity(capacity) in ring_buffer().
128        // We pass len=0 because we already dropped all elements above.
129        unsafe {
130            let capacity = self.mask + 1;
131            let _ = Vec::from_raw_parts(self.buffer, 0, capacity);
132        }
133    }
134}
135
/// The producer endpoint of an SPSC queue.
///
/// This endpoint can only push values into the queue.
// repr(C): Hot fields (local_tail, cached_head) at struct base share cache line
// with struct pointer. Cold field (shared Arc) pushed to end.
#[repr(C)]
pub struct Producer<T> {
    /// The producer's own copy of the tail index; copied into `shared.tail`
    /// after every successful push.
    local_tail: Cell<usize>,
    /// Last value read from `shared.head`; refreshed only when the queue
    /// appears full.
    cached_head: Cell<usize>,
    /// Copy of `shared.buffer`, kept here to avoid going through the Arc.
    buffer: *mut T,
    /// Copy of `shared.mask` (`capacity - 1`).
    mask: usize,
    /// Keeps the buffer alive; its strong count backs `is_disconnected`.
    shared: Arc<Shared<T>>,
}

// SAFETY: Producer can be sent to another thread. It has exclusive write access
// to the buffer slots and maintains the tail index. It is deliberately not Sync
// (its `Cell` fields forbid that), so only one thread can push at a time.
// T: Send ensures the data can be transferred.
unsafe impl<T: Send> Send for Producer<T> {}
154
155impl<T> Producer<T> {
156    /// Pushes a value into the queue.
157    ///
158    /// Returns `Err(Full(value))` if the queue is full, returning ownership
159    /// of the value to the caller.
160    #[inline]
161    #[must_use = "push returns Err if full, which should be handled"]
162    pub fn push(&self, value: T) -> Result<(), Full<T>> {
163        let tail = self.local_tail.get();
164
165        if tail.wrapping_sub(self.cached_head.get()) > self.mask {
166            self.cached_head
167                .set(self.shared.head.load(Ordering::Relaxed));
168
169            std::sync::atomic::fence(Ordering::Acquire);
170            if tail.wrapping_sub(self.cached_head.get()) > self.mask {
171                return Err(Full(value));
172            }
173        }
174
175        // SAFETY: We verified tail - cached_head <= mask, so the slot is not occupied
176        // by unconsumed data. tail & mask gives a valid index within the buffer.
177        unsafe { self.buffer.add(tail & self.mask).write(value) };
178        let new_tail = tail.wrapping_add(1);
179        std::sync::atomic::fence(Ordering::Release);
180
181        self.shared.tail.store(new_tail, Ordering::Relaxed);
182        self.local_tail.set(new_tail);
183
184        Ok(())
185    }
186
187    /// Returns the capacity of the queue.
188    #[inline]
189    pub fn capacity(&self) -> usize {
190        self.mask + 1
191    }
192
193    /// Returns `true` if the consumer has been dropped.
194    #[inline]
195    pub fn is_disconnected(&self) -> bool {
196        Arc::strong_count(&self.shared) == 1
197    }
198}
199
200impl<T> fmt::Debug for Producer<T> {
201    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
202        f.debug_struct("Producer")
203            .field("capacity", &self.capacity())
204            .finish_non_exhaustive()
205    }
206}
207
/// The consumer endpoint of an SPSC queue.
///
/// This endpoint can only pop values from the queue.
// repr(C): Hot fields (local_head, cached_tail) at struct base share cache line
// with struct pointer. Cold field (shared Arc) pushed to end.
#[repr(C)]
pub struct Consumer<T> {
    /// The consumer's own copy of the head index; copied into `shared.head`
    /// after every successful pop.
    local_head: Cell<usize>,
    /// Last value read from `shared.tail`; refreshed only when the queue
    /// appears empty.
    cached_tail: Cell<usize>,
    /// Copy of `shared.buffer`, kept here to avoid going through the Arc.
    buffer: *mut T,
    /// Copy of `shared.mask` (`capacity - 1`).
    mask: usize,
    /// Keeps the buffer alive; its strong count backs `is_disconnected`.
    shared: Arc<Shared<T>>,
}

// SAFETY: Consumer can be sent to another thread. It has exclusive read access
// to buffer slots and maintains the head index. It is deliberately not Sync
// (its `Cell` fields forbid that), so only one thread can pop at a time.
// T: Send ensures the data can be transferred.
unsafe impl<T: Send> Send for Consumer<T> {}
226
227impl<T> Consumer<T> {
228    /// Pops a value from the queue.
229    ///
230    /// Returns `None` if the queue is empty.
231    #[inline]
232    pub fn pop(&self) -> Option<T> {
233        let head = self.local_head.get();
234
235        if head == self.cached_tail.get() {
236            self.cached_tail
237                .set(self.shared.tail.load(Ordering::Relaxed));
238            std::sync::atomic::fence(Ordering::Acquire);
239
240            if head == self.cached_tail.get() {
241                return None;
242            }
243        }
244
245        // SAFETY: We verified head != cached_tail, so the slot contains valid data
246        // written by the producer. head & mask gives a valid index within the buffer.
247        let value = unsafe { self.buffer.add(head & self.mask).read() };
248        let new_head = head.wrapping_add(1);
249        std::sync::atomic::fence(Ordering::Release);
250
251        self.shared.head.store(new_head, Ordering::Relaxed);
252        self.local_head.set(new_head);
253
254        Some(value)
255    }
256
257    /// Returns the capacity of the queue.
258    #[inline]
259    pub fn capacity(&self) -> usize {
260        self.mask + 1
261    }
262
263    /// Returns `true` if the producer has been dropped.
264    #[inline]
265    pub fn is_disconnected(&self) -> bool {
266        Arc::strong_count(&self.shared) == 1
267    }
268}
269
270impl<T> fmt::Debug for Consumer<T> {
271    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
272        f.debug_struct("Consumer")
273            .field("capacity", &self.capacity())
274            .finish_non_exhaustive()
275    }
276}
277
// Unit tests for the SPSC ring buffer: single-threaded behavior (ordering,
// full/empty boundaries, wrap-around, disconnection, drop cleanup) plus
// two-thread FIFO and volume checks.
#[cfg(test)]
mod tests {
    use super::*;

    // ============================================================================
    // Basic Operations
    // ============================================================================

    #[test]
    fn basic_push_pop() {
        let (prod, cons) = ring_buffer::<u64>(4);

        assert!(prod.push(1).is_ok());
        assert!(prod.push(2).is_ok());
        assert!(prod.push(3).is_ok());

        assert_eq!(cons.pop(), Some(1));
        assert_eq!(cons.pop(), Some(2));
        assert_eq!(cons.pop(), Some(3));
        assert_eq!(cons.pop(), None);
    }

    #[test]
    fn empty_pop_returns_none() {
        let (_, cons) = ring_buffer::<u64>(4);
        assert_eq!(cons.pop(), None);
        assert_eq!(cons.pop(), None);
    }

    #[test]
    fn fill_then_drain() {
        let (prod, cons) = ring_buffer::<u64>(4);

        // Exactly `capacity` pushes must all succeed.
        for i in 0..4 {
            assert!(prod.push(i).is_ok());
        }

        for i in 0..4 {
            assert_eq!(cons.pop(), Some(i));
        }

        assert_eq!(cons.pop(), None);
    }

    #[test]
    fn push_returns_error_when_full() {
        let (prod, _cons) = ring_buffer::<u64>(4);

        assert!(prod.push(1).is_ok());
        assert!(prod.push(2).is_ok());
        assert!(prod.push(3).is_ok());
        assert!(prod.push(4).is_ok());

        // The rejected value is handed back to the caller inside `Full`.
        let err = prod.push(5).unwrap_err();
        assert_eq!(err.into_inner(), 5);
    }

    // ============================================================================
    // Interleaved Operations
    // ============================================================================

    #[test]
    fn interleaved_no_overwrite() {
        let (prod, cons) = ring_buffer::<u64>(8);

        for i in 0..1000 {
            assert!(prod.push(i).is_ok());
            assert_eq!(cons.pop(), Some(i));
        }
    }

    #[test]
    fn partial_fill_drain_cycles() {
        let (prod, cons) = ring_buffer::<u64>(8);

        // Half-fill then drain repeatedly so the indices wrap the buffer many times.
        for round in 0..100 {
            for i in 0..4 {
                assert!(prod.push(round * 4 + i).is_ok());
            }

            for i in 0..4 {
                assert_eq!(cons.pop(), Some(round * 4 + i));
            }
        }
    }

    // ============================================================================
    // Single Slot
    // ============================================================================

    #[test]
    fn single_slot_bounded() {
        // Capacity 1 is already a power of two, so mask == 0.
        let (prod, cons) = ring_buffer::<u64>(1);

        assert!(prod.push(1).is_ok());
        assert!(prod.push(2).is_err());

        assert_eq!(cons.pop(), Some(1));
        assert!(prod.push(2).is_ok());
    }

    // ============================================================================
    // Disconnection
    // ============================================================================

    #[test]
    fn producer_disconnected() {
        let (prod, cons) = ring_buffer::<u64>(4);

        assert!(!cons.is_disconnected());
        drop(prod);
        assert!(cons.is_disconnected());
    }

    #[test]
    fn consumer_disconnected() {
        let (prod, cons) = ring_buffer::<u64>(4);

        assert!(!prod.is_disconnected());
        drop(cons);
        assert!(prod.is_disconnected());
    }

    // ============================================================================
    // Drop Behavior
    // ============================================================================

    #[test]
    fn drop_cleans_up_remaining() {
        use std::sync::atomic::AtomicUsize;

        static DROP_COUNT: AtomicUsize = AtomicUsize::new(0);

        struct DropCounter;
        impl Drop for DropCounter {
            fn drop(&mut self) {
                DROP_COUNT.fetch_add(1, Ordering::SeqCst);
            }
        }

        DROP_COUNT.store(0, Ordering::SeqCst);

        let (prod, cons) = ring_buffer::<DropCounter>(4);

        let _ = prod.push(DropCounter);
        let _ = prod.push(DropCounter);
        let _ = prod.push(DropCounter);

        assert_eq!(DROP_COUNT.load(Ordering::SeqCst), 0);

        // Values never popped must be dropped by `Shared::drop` once both
        // endpoints are gone.
        drop(prod);
        drop(cons);

        assert_eq!(DROP_COUNT.load(Ordering::SeqCst), 3);
    }

    // ============================================================================
    // Cross-Thread
    // ============================================================================

    #[test]
    fn cross_thread_bounded() {
        use std::thread;

        let (prod, cons) = ring_buffer::<u64>(64);

        let producer = thread::spawn(move || {
            for i in 0..10_000 {
                while prod.push(i).is_err() {
                    std::hint::spin_loop();
                }
            }
        });

        let consumer = thread::spawn(move || {
            let mut received = 0u64;
            while received < 10_000 {
                if cons.pop().is_some() {
                    received += 1;
                } else {
                    std::hint::spin_loop();
                }
            }
            received
        });

        producer.join().unwrap();
        let received = consumer.join().unwrap();
        assert_eq!(received, 10_000);
    }

    // ============================================================================
    // Special Types
    // ============================================================================

    #[test]
    fn zero_sized_type() {
        let (prod, cons) = ring_buffer::<()>(8);

        let _ = prod.push(());
        let _ = prod.push(());

        assert_eq!(cons.pop(), Some(()));
        assert_eq!(cons.pop(), Some(()));
        assert_eq!(cons.pop(), None);
    }

    #[test]
    fn string_type() {
        let (prod, cons) = ring_buffer::<String>(4);

        let _ = prod.push("hello".to_string());
        let _ = prod.push("world".to_string());

        assert_eq!(cons.pop(), Some("hello".to_string()));
        assert_eq!(cons.pop(), Some("world".to_string()));
    }

    #[test]
    #[should_panic(expected = "capacity must be non-zero")]
    fn zero_capacity_panics() {
        let _ = ring_buffer::<u64>(0);
    }

    #[test]
    fn large_message_type() {
        // Over-aligned element type, to exercise slot addressing with a stride
        // larger than a word.
        #[repr(C, align(64))]
        struct LargeMessage {
            data: [u8; 256],
        }

        let (prod, cons) = ring_buffer::<LargeMessage>(8);

        let msg = LargeMessage { data: [42u8; 256] };
        assert!(prod.push(msg).is_ok());

        let received = cons.pop().unwrap();
        assert_eq!(received.data[0], 42);
        assert_eq!(received.data[255], 42);
    }

    #[test]
    fn multiple_laps() {
        let (prod, cons) = ring_buffer::<u64>(4);

        // 10 full laps through 4-slot buffer
        for i in 0..40 {
            assert!(prod.push(i).is_ok());
            assert_eq!(cons.pop(), Some(i));
        }
    }

    #[test]
    fn fifo_order_cross_thread() {
        use std::thread;

        let (prod, cons) = ring_buffer::<u64>(64);

        let producer = thread::spawn(move || {
            for i in 0..10_000u64 {
                while prod.push(i).is_err() {
                    std::hint::spin_loop();
                }
            }
        });

        let consumer = thread::spawn(move || {
            let mut expected = 0u64;
            while expected < 10_000 {
                if let Some(val) = cons.pop() {
                    assert_eq!(val, expected, "FIFO order violated");
                    expected += 1;
                } else {
                    std::hint::spin_loop();
                }
            }
        });

        producer.join().unwrap();
        consumer.join().unwrap();
    }

    #[test]
    fn stress_high_volume() {
        use std::thread;

        const COUNT: u64 = 1_000_000;

        let (prod, cons) = ring_buffer::<u64>(1024);

        let producer = thread::spawn(move || {
            for i in 0..COUNT {
                while prod.push(i).is_err() {
                    std::hint::spin_loop();
                }
            }
        });

        let consumer = thread::spawn(move || {
            let mut sum = 0u64;
            let mut received = 0u64;
            while received < COUNT {
                if let Some(val) = cons.pop() {
                    sum = sum.wrapping_add(val);
                    received += 1;
                } else {
                    std::hint::spin_loop();
                }
            }
            sum
        });

        producer.join().unwrap();
        let sum = consumer.join().unwrap();
        // Sum of 0..COUNT; catches lost or duplicated values.
        assert_eq!(sum, COUNT * (COUNT - 1) / 2);
    }

    #[test]
    fn capacity_rounds_to_power_of_two() {
        let (prod, _) = ring_buffer::<u64>(100);
        assert_eq!(prod.capacity(), 128);

        let (prod, _) = ring_buffer::<u64>(1000);
        assert_eq!(prod.capacity(), 1024);
    }
}
599
600        let (prod, _) = ring_buffer::<u64>(1000);
601        assert_eq!(prod.capacity(), 1024);
602    }
603}