// nexus_logbuf/queue/mpsc.rs

//! Multi-producer single-consumer byte ring buffer.
//!
//! # Design
//!
//! ```text
//! ┌─────────────────────────────────────────────────────────────────────────┐
//! │ Shared:                                                                 │
//! │   head: CachePadded<AtomicUsize>  ← Consumer writes, producers read     │
//! │   tail: CachePadded<AtomicUsize>  ← Producers CAS to claim space        │
//! │   buffer: *mut u8                                                       │
//! │   capacity: usize                 (power of 2)                          │
//! │   mask: usize                     (capacity - 1)                        │
//! └─────────────────────────────────────────────────────────────────────────┘
//!
//! ┌─────────────────────────────────┐   ┌─────────────────────────────────┐
//! │ Producer (cloneable):           │   │ Consumer:                       │
//! │   cached_head: usize (local)    │   │   head: usize        (local)    │
//! │   shared: Arc<Shared>           │   │                                 │
//! └─────────────────────────────────┘   └─────────────────────────────────┘
//! ```
//!
//! # Differences from SPSC
//!
//! - Tail is atomic in shared state (not local to producer)
//! - Producers use a CAS loop to claim space
//! - Producer is `Clone` - multiple producers allowed
//! - Synchronization: Relaxed CAS on tail, Release on len commit, Acquire on len read
//!
//! # Record Layout
//!
//! Same as SPSC - see [`crate::spsc`] for details.
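//!
//! # Example
//!
//! A minimal single-producer round trip; the module path below is assumed
//! from this file's location and may differ in the actual crate layout:
//!
//! ```ignore
//! use nexus_logbuf::queue::mpsc;
//!
//! let (mut producer, mut consumer) = mpsc::new(1024);
//!
//! // Claim space for the payload, write it, then commit to publish.
//! let mut claim = producer.try_claim(5).unwrap();
//! claim.copy_from_slice(b"hello");
//! claim.commit();
//!
//! // The consumer only sees the record once it has been committed.
//! let record = consumer.try_claim().unwrap();
//! assert_eq!(&*record, b"hello");
//! ```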

use std::ops::{Deref, DerefMut};
use std::ptr;
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering, fence};

use crossbeam_utils::CachePadded;

use crate::{LEN_MASK, SKIP_BIT, TryClaimError, align8};

/// Header size in bytes.
const HEADER_SIZE: usize = 4;
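// Record size illustration (assuming `align8` rounds up to a multiple of 8):
// a 5-byte payload occupies align8(HEADER_SIZE + 5) = align8(9) = 16 bytes.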

/// Creates a bounded MPSC byte ring buffer.
///
/// Capacity is rounded up to the next power of two.
///
/// # Panics
///
/// Panics if `capacity` is less than 16 bytes.
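///
/// A sketch of the rounding behavior:
///
/// ```ignore
/// let (producer, _consumer) = mpsc::new(100);
/// assert_eq!(producer.capacity(), 128); // rounded up to the next power of two
/// ```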
pub fn new(capacity: usize) -> (Producer, Consumer) {
    assert!(capacity >= 16, "capacity must be at least 16 bytes");

    let capacity = capacity.next_power_of_two();
    let mask = capacity - 1;

    // Allocate buffer, zero-initialized
    let buffer = vec![0u8; capacity].into_boxed_slice();
    let buffer_ptr = Box::into_raw(buffer) as *mut u8;

    let shared = Arc::new(Shared {
        head: CachePadded::new(AtomicUsize::new(0)),
        tail: CachePadded::new(AtomicUsize::new(0)),
        buffer: buffer_ptr,
        capacity,
        mask,
    });

    (
        Producer {
            cached_head: 0,
            shared: Arc::clone(&shared),
        },
        Consumer { head: 0, shared },
    )
}

struct Shared {
    /// Consumer's read position. Updated by consumer, read by producers.
    head: CachePadded<AtomicUsize>,
    /// Producers' write position. CAS'd by producers.
    tail: CachePadded<AtomicUsize>,
    /// Buffer pointer.
    buffer: *mut u8,
    /// Buffer capacity (power of 2).
    capacity: usize,
    /// Mask for wrapping (capacity - 1).
    mask: usize,
}

// Safety: Buffer access is synchronized through atomic head/tail.
// Multiple producers coordinate via CAS on tail.
// Single consumer is enforced by API (Consumer is not Clone).
unsafe impl Send for Shared {}
unsafe impl Sync for Shared {}

impl Drop for Shared {
    fn drop(&mut self) {
        unsafe {
            let _ = Box::from_raw(ptr::slice_from_raw_parts_mut(self.buffer, self.capacity));
        }
    }
}

// ============================================================================
// Producer
// ============================================================================

/// Producer endpoint of the MPSC ring buffer.
///
/// This type is `Clone` - multiple producers can write concurrently.
/// Use [`try_claim`](Producer::try_claim) to claim space for writing.
#[derive(Clone)]
pub struct Producer {
    /// Cached head position (Rigtorp-style optimization, per-producer).
    cached_head: usize,
    /// Shared state.
    shared: Arc<Shared>,
}

// Safety: Producer coordinates with other producers via atomic CAS.
unsafe impl Send for Producer {}

impl Producer {
    /// Attempts to claim space for a record with the given payload length.
    ///
    /// Returns a [`WriteClaim`] that can be written to and then committed.
    ///
    /// # Errors
    ///
    /// - [`TryClaimError::ZeroLength`] if `len` is zero
    /// - [`TryClaimError::Full`] if the buffer is full
    ///
    /// # Safety Contract
    ///
    /// `len` must not exceed `0x7FFF_FFFF` (2^31 - 1 bytes). This is checked
    /// with `debug_assert!` only.
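    ///
    /// # Example
    ///
    /// A sketch of the claim-write-commit pattern, spinning while the buffer
    /// is full (the same pattern the tests in this module use):
    ///
    /// ```ignore
    /// let payload = b"hello";
    /// loop {
    ///     match producer.try_claim(payload.len()) {
    ///         Ok(mut claim) => {
    ///             claim.copy_from_slice(payload); // claim derefs to &mut [u8]
    ///             claim.commit();                 // publish to the consumer
    ///             break;
    ///         }
    ///         Err(_) => std::hint::spin_loop(),   // full: wait for the consumer
    ///     }
    /// }
    /// ```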
    #[inline]
    pub fn try_claim(&mut self, len: usize) -> Result<WriteClaim<'_>, TryClaimError> {
        debug_assert!(len <= LEN_MASK as usize, "payload too large");
        if len == 0 {
            return Err(TryClaimError::ZeroLength);
        }

        let record_size = align8(HEADER_SIZE + len);

        // CAS loop to claim space
        loop {
            let tail = self.shared.tail.load(Ordering::Relaxed);

            // Calculate used space from the cached head. A stale cached_head
            // can make `used` exceed capacity; saturating_sub then reports
            // zero available space, forcing a reload of the real head below.
            let used = tail.wrapping_sub(self.cached_head);
            let available = self.shared.capacity.saturating_sub(used);

            if available < record_size {
                // Reload head from shared state
                self.cached_head = self.shared.head.load(Ordering::Relaxed);
                fence(Ordering::Acquire);

                let used = tail.wrapping_sub(self.cached_head);
                if used > self.shared.capacity || self.shared.capacity - used < record_size {
                    return Err(TryClaimError::Full);
                }
            }

            // Check if record fits before buffer end, or needs wrap
            let offset = tail & self.shared.mask;
            let space_to_end = self.shared.capacity - offset;

            if space_to_end < record_size {
                // Need to wrap. Check if we have space for padding + record at start.
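                // Worked example (illustrative numbers): capacity = 64,
                // offset = 56, record_size = 16 -> space_to_end = 8, so we
                // claim 8 bytes of skip padding at the end plus the 16-byte
                // record at the start: 24 bytes total.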
                let total_needed = space_to_end + record_size;

                let used = tail.wrapping_sub(self.cached_head);
                let available = self.shared.capacity.saturating_sub(used);

                if available < total_needed {
                    // Reload and recheck
                    self.cached_head = self.shared.head.load(Ordering::Relaxed);
                    fence(Ordering::Acquire);

                    let used = tail.wrapping_sub(self.cached_head);
                    if used > self.shared.capacity || self.shared.capacity - used < total_needed {
                        return Err(TryClaimError::Full);
                    }
                }

                // Try to claim the padding + record space
                let new_tail = tail.wrapping_add(total_needed);
                if self
                    .shared
                    .tail
                    .compare_exchange_weak(tail, new_tail, Ordering::Relaxed, Ordering::Relaxed)
                    .is_ok()
                {
                    // We claimed the space. Write padding skip marker.
                    let buffer = self.shared.buffer;
                    let padding_ptr = unsafe { buffer.add(offset) };
                    let skip_len = space_to_end as u32 | SKIP_BIT;

                    // Release fence before writing skip marker
                    fence(Ordering::Release);
                    unsafe {
                        ptr::write(padding_ptr as *mut u32, skip_len);
                    }

                    return Ok(WriteClaim {
                        shared: &self.shared,
                        offset: 0, // Record starts at beginning after wrap
                        len,
                        record_size,
                        committed: false,
                    });
                }
                // CAS failed, retry
                continue;
            }

            // Fits without wrapping
            let new_tail = tail.wrapping_add(record_size);
            if self
                .shared
                .tail
                .compare_exchange_weak(tail, new_tail, Ordering::Relaxed, Ordering::Relaxed)
                .is_ok()
            {
                return Ok(WriteClaim {
                    shared: &self.shared,
                    offset,
                    len,
                    record_size,
                    committed: false,
                });
            }
            // CAS failed, retry
        }
    }

    /// Returns the capacity of the buffer.
    #[inline]
    pub fn capacity(&self) -> usize {
        self.shared.capacity
    }

    /// Returns `true` if the consumer has been dropped.
    ///
    /// This is a conservative approximation: it only reports `true` once this
    /// producer holds the last `Arc`, i.e. the consumer *and* every other
    /// producer are gone. An exact check would need a dedicated flag.
    #[inline]
    pub fn is_disconnected(&self) -> bool {
        Arc::strong_count(&self.shared) == 1
    }
}

impl std::fmt::Debug for Producer {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Producer")
            .field("capacity", &self.capacity())
            .finish_non_exhaustive()
    }
}

// ============================================================================
// WriteClaim
// ============================================================================

/// A claimed region for writing a record.
///
/// Dereferences to `&mut [u8]` for the payload region. Call [`commit`](WriteClaim::commit)
/// when done writing to publish the record. If dropped without committing, a skip
/// marker is written so the consumer can advance past the dead region.
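///
/// A sketch of the abort path (mirroring the `aborted_claim_creates_skip` test):
///
/// ```ignore
/// {
///     let mut claim = producer.try_claim(10).unwrap();
///     claim.copy_from_slice(b"0123456789");
///     // Dropped here without commit: a skip marker is written and the
///     // consumer silently advances past this region.
/// }
/// ```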
pub struct WriteClaim<'a> {
    shared: &'a Shared,
    offset: usize,
    len: usize,
    record_size: usize,
    committed: bool,
}

impl WriteClaim<'_> {
    /// Commits the record, making it visible to the consumer.
    #[inline]
    pub fn commit(mut self) {
        self.do_commit();
        self.committed = true;
    }

    #[inline]
    fn do_commit(&mut self) {
        let buffer = self.shared.buffer;
        let len_ptr = unsafe { buffer.add(self.offset) } as *mut u32;

        // Release fence: ensures payload writes are visible before len store
        fence(Ordering::Release);
        unsafe {
            ptr::write(len_ptr, self.len as u32);
        }
    }

    /// Returns the length of the payload region.
    #[inline]
    pub fn len(&self) -> usize {
        self.len
    }

    /// Always returns `false`: claims are only created for non-zero `len`.
    #[inline]
    pub fn is_empty(&self) -> bool {
        false
    }
}

impl Deref for WriteClaim<'_> {
    type Target = [u8];

    #[inline]
    fn deref(&self) -> &Self::Target {
        let buffer = self.shared.buffer;
        let payload_ptr = unsafe { buffer.add(self.offset + HEADER_SIZE) };
        unsafe { std::slice::from_raw_parts(payload_ptr, self.len) }
    }
}

impl DerefMut for WriteClaim<'_> {
    #[inline]
    fn deref_mut(&mut self) -> &mut Self::Target {
        let buffer = self.shared.buffer;
        let payload_ptr = unsafe { buffer.add(self.offset + HEADER_SIZE) };
        unsafe { std::slice::from_raw_parts_mut(payload_ptr, self.len) }
    }
}

impl Drop for WriteClaim<'_> {
    fn drop(&mut self) {
        if !self.committed {
            // Write skip marker so consumer can advance past this region
            let buffer = self.shared.buffer;
            let len_ptr = unsafe { buffer.add(self.offset) } as *mut u32;
            let skip_len = self.record_size as u32 | SKIP_BIT;

            fence(Ordering::Release);
            unsafe {
                ptr::write(len_ptr, skip_len);
            }
        }
    }
}

// ============================================================================
// Consumer
// ============================================================================

/// Consumer endpoint of the MPSC ring buffer.
///
/// Use [`try_claim`](Consumer::try_claim) to claim the next record for reading.
/// This type is NOT `Clone` - only one consumer is allowed.
pub struct Consumer {
    /// Local head position (free-running).
    head: usize,
    /// Shared state.
    shared: Arc<Shared>,
}

// Safety: only one Consumer ever exists (it is not Clone), and buffer access
// is synchronized through the atomic head/tail, so it may move between threads.
unsafe impl Send for Consumer {}

impl Consumer {
    /// Attempts to claim the next record for reading.
    ///
    /// Returns a [`ReadClaim`] if a record is available. The claim dereferences
    /// to `&[u8]` for the payload. When dropped, the record region is zeroed
    /// and the head is advanced.
    ///
    /// Returns `None` if no committed record is available.
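    ///
    /// # Example
    ///
    /// A sketch of a drain loop (`process` is a stand-in for caller code):
    ///
    /// ```ignore
    /// while let Some(record) = consumer.try_claim() {
    ///     process(&*record); // record derefs to &[u8]
    /// } // dropping each claim zeroes its region and advances the head
    /// ```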
    #[inline]
    pub fn try_claim(&mut self) -> Option<ReadClaim<'_>> {
        let buffer = self.shared.buffer;

        loop {
            let offset = self.head & self.shared.mask;
            let len_ptr = unsafe { buffer.add(offset) } as *const u32;

            // Plain read of the len word, then an Acquire fence so the
            // payload reads below are ordered after it.
            let len_raw = unsafe { ptr::read(len_ptr) };
            fence(Ordering::Acquire);

            if len_raw == 0 {
                // Not committed yet
                return None;
            }

            if len_raw & SKIP_BIT != 0 {
                // Skip marker: zero the region and advance
                let skip_size = (len_raw & LEN_MASK) as usize;
                unsafe {
                    ptr::write_bytes(buffer.add(offset), 0, skip_size);
                }

                self.head = self.head.wrapping_add(skip_size);

                // Release fence before updating shared head
                fence(Ordering::Release);
                self.shared.head.store(self.head, Ordering::Relaxed);

                // Continue to check next position
                continue;
            }

            // Valid record
            let len = len_raw as usize;
            let record_size = align8(HEADER_SIZE + len);

            return Some(ReadClaim {
                consumer: self,
                offset,
                len,
                record_size,
            });
        }
    }

    /// Returns the capacity of the buffer.
    #[inline]
    pub fn capacity(&self) -> usize {
        self.shared.capacity
    }

    /// Returns `true` if all producers have been dropped.
    #[inline]
    pub fn is_disconnected(&self) -> bool {
        Arc::strong_count(&self.shared) == 1
    }
}

impl std::fmt::Debug for Consumer {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Consumer")
            .field("capacity", &self.capacity())
            .finish_non_exhaustive()
    }
}

// ============================================================================
// ReadClaim
// ============================================================================

/// A claimed record for reading.
///
/// Dereferences to `&[u8]` for the payload. When dropped, the record region
/// is zeroed and the head is advanced, freeing space for producers.
pub struct ReadClaim<'a> {
    consumer: &'a mut Consumer,
    offset: usize,
    len: usize,
    record_size: usize,
}

impl ReadClaim<'_> {
    /// Returns the length of the payload.
    #[inline]
    pub fn len(&self) -> usize {
        self.len
    }

    /// Returns `true` if the payload is empty.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.len == 0
    }
}

impl Deref for ReadClaim<'_> {
    type Target = [u8];

    #[inline]
    fn deref(&self) -> &Self::Target {
        let buffer = self.consumer.shared.buffer;
        let payload_ptr = unsafe { buffer.add(self.offset + HEADER_SIZE) };
        unsafe { std::slice::from_raw_parts(payload_ptr, self.len) }
    }
}

impl Drop for ReadClaim<'_> {
    fn drop(&mut self) {
        let buffer = self.consumer.shared.buffer;

        // Zero the entire record region
        unsafe {
            ptr::write_bytes(buffer.add(self.offset), 0, self.record_size);
        }

        // Advance head
        self.consumer.head = self.consumer.head.wrapping_add(self.record_size);

        // Release fence before updating shared head
        fence(Ordering::Release);
        self.consumer
            .shared
            .head
            .store(self.consumer.head, Ordering::Relaxed);
    }
}

// ============================================================================
// Tests
// ============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn basic_write_read() {
        let (mut prod, mut cons) = new(1024);

        let payload = b"hello world";
        let mut claim = prod.try_claim(payload.len()).unwrap();
        claim.copy_from_slice(payload);
        claim.commit();

        let record = cons.try_claim().unwrap();
        assert_eq!(&*record, payload);
    }

    #[test]
    fn empty_returns_none() {
        let (_, mut cons) = new(1024);
        assert!(cons.try_claim().is_none());
    }

    #[test]
    fn multiple_records() {
        let (mut prod, mut cons) = new(1024);

        for i in 0..10 {
            let payload = format!("message {}", i);
            let mut claim = prod.try_claim(payload.len()).unwrap();
            claim.copy_from_slice(payload.as_bytes());
            claim.commit();
        }

        for i in 0..10 {
            let record = cons.try_claim().unwrap();
            let expected = format!("message {}", i);
            assert_eq!(&*record, expected.as_bytes());
        }

        assert!(cons.try_claim().is_none());
    }

    #[test]
    fn producer_is_clone() {
        let (prod, _cons) = new(1024);
        let _prod2 = prod.clone();
    }

    #[test]
    fn multiple_producers_single_consumer() {
        use std::thread;

        const PRODUCERS: usize = 4;
        const MESSAGES_PER_PRODUCER: u64 = 10_000;
        const TOTAL: u64 = PRODUCERS as u64 * MESSAGES_PER_PRODUCER;

        let (prod, mut cons) = new(64 * 1024);

        let handles: Vec<_> = (0..PRODUCERS)
            .map(|producer_id| {
                let mut prod = prod.clone();
                thread::spawn(move || {
                    for i in 0..MESSAGES_PER_PRODUCER {
                        // Encode producer_id and sequence in payload
                        let mut payload = [0u8; 16];
                        payload[..8].copy_from_slice(&(producer_id as u64).to_le_bytes());
                        payload[8..].copy_from_slice(&i.to_le_bytes());

                        loop {
                            match prod.try_claim(16) {
                                Ok(mut claim) => {
                                    claim.copy_from_slice(&payload);
                                    claim.commit();
                                    break;
                                }
                                Err(_) => std::hint::spin_loop(),
                            }
                        }
                    }
                })
            })
            .collect();

        // Drop original producer
        drop(prod);

        // Consumer: track per-producer sequence
        let consumer = thread::spawn(move || {
            let mut received = 0u64;
            let mut per_producer = vec![0u64; PRODUCERS];

            while received < TOTAL {
                if let Some(record) = cons.try_claim() {
                    let producer_id = u64::from_le_bytes(record[..8].try_into().unwrap()) as usize;
                    let seq = u64::from_le_bytes(record[8..].try_into().unwrap());

                    // Each producer's messages should arrive in order
                    assert_eq!(
                        seq, per_producer[producer_id],
                        "producer {} out of order",
                        producer_id
                    );
                    per_producer[producer_id] += 1;
                    received += 1;
                } else {
                    std::hint::spin_loop();
                }
            }

            per_producer
        });

        for h in handles {
            h.join().unwrap();
        }

        let per_producer = consumer.join().unwrap();
        for (i, &count) in per_producer.iter().enumerate() {
            assert_eq!(count, MESSAGES_PER_PRODUCER, "producer {} count", i);
        }
    }

    #[test]
    fn aborted_claim_creates_skip() {
        let (mut prod, mut cons) = new(1024);

        // Claim and drop without committing
        {
            let mut claim = prod.try_claim(10).unwrap();
            claim.copy_from_slice(b"0123456789");
            // drop without commit
        }

        // Write another record
        {
            let mut claim = prod.try_claim(5).unwrap();
            claim.copy_from_slice(b"hello");
            claim.commit();
        }

        // Consumer should skip the aborted record and read the committed one
        let record = cons.try_claim().unwrap();
        assert_eq!(&*record, b"hello");
    }

    #[test]
    fn wrap_around() {
        let (mut prod, mut cons) = new(64);

        // Fill with messages that will cause wrap-around
        for i in 0..20 {
            let payload = format!("msg{:02}", i);
            loop {
                match prod.try_claim(payload.len()) {
                    Ok(mut claim) => {
                        claim.copy_from_slice(payload.as_bytes());
                        claim.commit();
                        break;
                    }
                    Err(_) => {
                        // Drain some
                        while cons.try_claim().is_some() {}
                    }
                }
            }
        }
    }

    #[test]
    fn full_returns_error() {
        let (mut prod, _cons) = new(64);

        // Fill the buffer
        let mut count = 0;
        loop {
            match prod.try_claim(8) {
                Ok(mut claim) => {
                    claim.copy_from_slice(b"12345678");
                    claim.commit();
                    count += 1;
                }
                Err(_) => break,
            }
        }

        assert!(count > 0);
        assert!(prod.try_claim(8).is_err());
    }

    #[test]
    fn disconnection_detection() {
        let (prod, cons) = new(1024);

        assert!(!prod.is_disconnected());
        assert!(!cons.is_disconnected());

        drop(cons);
        assert!(prod.is_disconnected());
    }

    #[test]
    #[should_panic(expected = "capacity must be at least 16")]
    fn tiny_capacity_panics() {
        let _ = new(8);
    }

    #[test]
    fn zero_len_returns_error() {
        let (mut prod, _) = new(1024);
        assert!(matches!(prod.try_claim(0), Err(TryClaimError::ZeroLength)));
    }

    #[test]
    fn capacity_rounds_to_power_of_two() {
        let (prod, _) = new(100);
        assert_eq!(prod.capacity(), 128);

        let (prod, _) = new(1000);
        assert_eq!(prod.capacity(), 1024);
    }

    /// High-volume stress test with multiple producers.
    #[test]
    fn stress_multiple_producers() {
        use std::thread;

        const PRODUCERS: usize = 4;
        const COUNT_PER_PRODUCER: u64 = 100_000;
        const TOTAL: u64 = PRODUCERS as u64 * COUNT_PER_PRODUCER;
        const BUFFER_SIZE: usize = 64 * 1024;

        let (prod, mut cons) = new(BUFFER_SIZE);

        let handles: Vec<_> = (0..PRODUCERS)
            .map(|_| {
                let mut prod = prod.clone();
                thread::spawn(move || {
                    for i in 0..COUNT_PER_PRODUCER {
                        let payload = i.to_le_bytes();
                        loop {
                            match prod.try_claim(payload.len()) {
                                Ok(mut claim) => {
                                    claim.copy_from_slice(&payload);
                                    claim.commit();
                                    break;
                                }
                                Err(_) => std::hint::spin_loop(),
                            }
                        }
                    }
                })
            })
            .collect();

        drop(prod);

        let consumer = thread::spawn(move || {
            let mut received = 0u64;
            let mut sum = 0u64;
            while received < TOTAL {
                if let Some(record) = cons.try_claim() {
                    let value = u64::from_le_bytes((*record).try_into().unwrap());
                    sum = sum.wrapping_add(value);
                    received += 1;
                } else {
                    std::hint::spin_loop();
                }
            }
            (received, sum)
        });

        for h in handles {
            h.join().unwrap();
        }

        let (received, sum) = consumer.join().unwrap();
        assert_eq!(received, TOTAL);

        // Each producer sends 0..COUNT_PER_PRODUCER
        // Sum per producer = COUNT_PER_PRODUCER * (COUNT_PER_PRODUCER - 1) / 2
        let expected_sum = PRODUCERS as u64 * COUNT_PER_PRODUCER * (COUNT_PER_PRODUCER - 1) / 2;
        assert_eq!(sum, expected_sum);
    }
}