1use std::alloc::{Layout, alloc_zeroed, dealloc, handle_alloc_error};
34use std::cell::Cell;
35use std::ops::{Deref, DerefMut};
36use std::ptr;
37use std::sync::Arc;
38use std::sync::atomic::{AtomicUsize, Ordering, fence};
39
40use crossbeam_utils::CachePadded;
41
42use crate::{LEN_MASK, SKIP_BIT, TryClaimError, align8};
43
/// Size in bytes of the header word that prefixes every record in the ring.
///
/// The header is read and written as an [`AtomicUsize`]. It holds one of three things:
/// - zero, meaning "nothing published here yet";
/// - the committed payload length;
/// - a skip marker (`SKIP_BIT` combined with the number of bytes to skip).
///
/// `AtomicUsize` is guaranteed to have the same size as `usize`.
const HEADER_SIZE: usize = std::mem::size_of::<AtomicUsize>();
48
49pub fn new(capacity: usize) -> (Producer, Consumer) {
57 assert!(capacity >= 16, "capacity must be at least 16 bytes");
58
59 let capacity = capacity.next_power_of_two();
60 let mask = capacity - 1;
61
62 let layout = Layout::from_size_align(capacity, 8)
64 .expect("valid layout: capacity is a power of two >= 16, align is 8");
65 let buffer_ptr = unsafe { alloc_zeroed(layout) };
66 if buffer_ptr.is_null() {
67 handle_alloc_error(layout);
68 }
69
70 let shared = Arc::new(Shared {
71 head: CachePadded::new(AtomicUsize::new(0)),
72 tail: CachePadded::new(AtomicUsize::new(0)),
73 buffer: buffer_ptr,
74 capacity,
75 mask,
76 });
77
78 (
79 Producer {
80 cached_head: Cell::new(0),
81 shared: Arc::clone(&shared),
82 },
83 Consumer {
84 head: Cell::new(0),
85 shared,
86 },
87 )
88}
89
90struct Shared {
91 head: CachePadded<AtomicUsize>,
93 tail: CachePadded<AtomicUsize>,
95 buffer: *mut u8,
97 capacity: usize,
99 mask: usize,
101}
102
103unsafe impl Send for Shared {}
107unsafe impl Sync for Shared {}
108
109impl Drop for Shared {
110 fn drop(&mut self) {
111 let layout = Layout::from_size_align(self.capacity, 8)
113 .expect("valid layout: capacity was validated at construction");
114 unsafe { dealloc(self.buffer, layout) };
115 }
116}
117
118#[derive(Clone)]
127pub struct Producer {
128 cached_head: Cell<usize>,
130 shared: Arc<Shared>,
132}
133
134unsafe impl Send for Producer {}
136
137impl Producer {
138 #[inline]
152 pub fn try_claim(&mut self, len: usize) -> Result<WriteClaim<'_>, TryClaimError> {
153 debug_assert!(len <= LEN_MASK, "payload too large");
154 if len == 0 {
155 return Err(TryClaimError::ZeroLength);
156 }
157
158 let record_size = align8(HEADER_SIZE + len);
159
160 loop {
162 let tail = self.shared.tail.load(Ordering::Relaxed);
163
164 let used = tail.wrapping_sub(self.cached_head.get());
167 let available = self.shared.capacity.saturating_sub(used);
168
169 if available < record_size {
170 self.cached_head
172 .set(self.shared.head.load(Ordering::Relaxed));
173 fence(Ordering::Acquire);
174
175 let used = tail.wrapping_sub(self.cached_head.get());
176 if used > self.shared.capacity || self.shared.capacity - used < record_size {
177 return Err(TryClaimError::Full);
178 }
179 }
180
181 let offset = tail & self.shared.mask;
183 let space_to_end = self.shared.capacity - offset;
184
185 if space_to_end < record_size {
186 let total_needed = space_to_end + record_size;
188
189 let used = tail.wrapping_sub(self.cached_head.get());
190 let available = self.shared.capacity.saturating_sub(used);
191
192 if available < total_needed {
193 self.cached_head
195 .set(self.shared.head.load(Ordering::Relaxed));
196 fence(Ordering::Acquire);
197
198 let used = tail.wrapping_sub(self.cached_head.get());
199 if used > self.shared.capacity || self.shared.capacity - used < total_needed {
200 return Err(TryClaimError::Full);
201 }
202 }
203
204 let new_tail = tail.wrapping_add(total_needed);
206 if self
207 .shared
208 .tail
209 .compare_exchange_weak(tail, new_tail, Ordering::Relaxed, Ordering::Relaxed)
210 .is_ok()
211 {
212 let buffer = self.shared.buffer;
214 let skip_len = space_to_end | SKIP_BIT;
215
216 fence(Ordering::Release);
218 let len_ptr = unsafe { buffer.add(offset) }.cast::<AtomicUsize>();
219 unsafe { &*len_ptr }.store(skip_len, Ordering::Relaxed);
220
221 return Ok(WriteClaim {
222 shared: &self.shared,
223 offset: 0, len,
225 record_size,
226 committed: false,
227 });
228 }
229 continue;
231 }
232
233 let new_tail = tail.wrapping_add(record_size);
235 if self
236 .shared
237 .tail
238 .compare_exchange_weak(tail, new_tail, Ordering::Relaxed, Ordering::Relaxed)
239 .is_ok()
240 {
241 return Ok(WriteClaim {
242 shared: &self.shared,
243 offset,
244 len,
245 record_size,
246 committed: false,
247 });
248 }
249 }
251 }
252
253 #[inline]
255 pub fn capacity(&self) -> usize {
256 self.shared.capacity
257 }
258
259 #[inline]
264 pub fn is_disconnected(&self) -> bool {
265 Arc::strong_count(&self.shared) == 1
266 }
267}
268
269impl std::fmt::Debug for Producer {
270 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
271 f.debug_struct("Producer")
272 .field("capacity", &self.capacity())
273 .finish_non_exhaustive()
274 }
275}
276
277pub struct WriteClaim<'a> {
287 shared: &'a Shared,
288 offset: usize,
289 len: usize,
290 record_size: usize,
291 committed: bool,
292}
293
294impl WriteClaim<'_> {
295 #[inline]
297 pub fn commit(mut self) {
298 self.do_commit();
299 self.committed = true;
300 }
301
302 #[inline]
303 fn do_commit(&mut self) {
304 let buffer = self.shared.buffer;
305 let len_ptr = unsafe { buffer.add(self.offset) }.cast::<AtomicUsize>();
306
307 fence(Ordering::Release);
309 unsafe { &*len_ptr }.store(self.len, Ordering::Relaxed);
310 }
311
312 #[inline]
314 pub fn len(&self) -> usize {
315 self.len
316 }
317
318 #[inline]
320 pub fn is_empty(&self) -> bool {
321 false
322 }
323}
324
325impl Deref for WriteClaim<'_> {
326 type Target = [u8];
327
328 #[inline]
329 fn deref(&self) -> &Self::Target {
330 let buffer = self.shared.buffer;
331 let payload_ptr = unsafe { buffer.add(self.offset + HEADER_SIZE) };
332 unsafe { std::slice::from_raw_parts(payload_ptr, self.len) }
333 }
334}
335
336impl DerefMut for WriteClaim<'_> {
337 #[inline]
338 fn deref_mut(&mut self) -> &mut Self::Target {
339 let buffer = self.shared.buffer;
340 let payload_ptr = unsafe { buffer.add(self.offset + HEADER_SIZE) };
341 unsafe { std::slice::from_raw_parts_mut(payload_ptr, self.len) }
342 }
343}
344
345impl Drop for WriteClaim<'_> {
346 fn drop(&mut self) {
347 if !self.committed {
348 let buffer = self.shared.buffer;
350 let len_ptr = unsafe { buffer.add(self.offset) }.cast::<AtomicUsize>();
351 let skip_len = self.record_size | SKIP_BIT;
352
353 fence(Ordering::Release);
354 unsafe { &*len_ptr }.store(skip_len, Ordering::Relaxed);
355 }
356 }
357}
358
359pub struct Consumer {
368 head: Cell<usize>,
370 shared: Arc<Shared>,
372}
373
374unsafe impl Send for Consumer {}
376
377impl Consumer {
378 #[inline]
386 pub fn try_claim(&mut self) -> Option<ReadClaim<'_>> {
387 let buffer = self.shared.buffer;
388
389 loop {
390 let offset = self.head.get() & self.shared.mask;
391 let len_ptr = unsafe { buffer.add(offset) }.cast::<AtomicUsize>();
392
393 let len_raw = unsafe { &*len_ptr }.load(Ordering::Relaxed);
395 fence(Ordering::Acquire);
396
397 if len_raw == 0 {
398 return None;
400 }
401
402 if len_raw & SKIP_BIT != 0 {
403 let skip_size = len_raw & LEN_MASK;
405 if skip_size > HEADER_SIZE {
407 unsafe {
408 ptr::write_bytes(
409 buffer.add(offset + HEADER_SIZE),
410 0,
411 skip_size - HEADER_SIZE,
412 );
413 }
414 }
415 fence(Ordering::Release);
417 unsafe { &*len_ptr }.store(0, Ordering::Relaxed);
418
419 self.head.set(self.head.get().wrapping_add(skip_size));
420
421 fence(Ordering::Release);
423 self.shared.head.store(self.head.get(), Ordering::Relaxed);
424
425 continue;
427 }
428
429 let len = len_raw;
431 let record_size = align8(HEADER_SIZE + len);
432
433 return Some(ReadClaim {
434 consumer: self,
435 offset,
436 len,
437 record_size,
438 });
439 }
440 }
441
442 #[inline]
444 pub fn capacity(&self) -> usize {
445 self.shared.capacity
446 }
447
448 #[inline]
452 pub fn is_disconnected(&self) -> bool {
453 Arc::strong_count(&self.shared) == 1
454 }
455}
456
457impl std::fmt::Debug for Consumer {
458 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
459 f.debug_struct("Consumer")
460 .field("capacity", &self.capacity())
461 .finish_non_exhaustive()
462 }
463}
464
465pub struct ReadClaim<'a> {
474 consumer: &'a mut Consumer,
475 offset: usize,
476 len: usize,
477 record_size: usize,
478}
479
480impl ReadClaim<'_> {
481 #[inline]
483 pub fn len(&self) -> usize {
484 self.len
485 }
486
487 #[inline]
489 pub fn is_empty(&self) -> bool {
490 self.len == 0
491 }
492}
493
494impl Deref for ReadClaim<'_> {
495 type Target = [u8];
496
497 #[inline]
498 fn deref(&self) -> &Self::Target {
499 let buffer = self.consumer.shared.buffer;
500 let payload_ptr = unsafe { buffer.add(self.offset + HEADER_SIZE) };
501 unsafe { std::slice::from_raw_parts(payload_ptr, self.len) }
502 }
503}
504
505impl Drop for ReadClaim<'_> {
506 fn drop(&mut self) {
507 let buffer = self.consumer.shared.buffer;
508
509 if self.record_size > HEADER_SIZE {
511 unsafe {
512 ptr::write_bytes(
513 buffer.add(self.offset + HEADER_SIZE),
514 0,
515 self.record_size - HEADER_SIZE,
516 );
517 }
518 }
519 fence(Ordering::Release);
521 let len_ptr = unsafe { buffer.add(self.offset) }.cast::<AtomicUsize>();
522 unsafe { &*len_ptr }.store(0, Ordering::Relaxed);
523
524 let new_head = self.consumer.head.get().wrapping_add(self.record_size);
526 self.consumer.head.set(new_head);
527
528 fence(Ordering::Release);
530 self.consumer.shared.head.store(new_head, Ordering::Relaxed);
531 }
532}
533
#[cfg(test)]
mod tests {
    use super::*;

    /// Reads every currently available record and checks each against `expected(*next)`
    /// in order, advancing `*next`. Returns how many records were read.
    fn drain_in_order(
        cons: &mut Consumer,
        next: &mut usize,
        expected: impl Fn(usize) -> String,
    ) -> usize {
        let mut drained = 0;
        while let Some(record) = cons.try_claim() {
            assert_eq!(&*record, expected(*next).as_bytes(), "record {} corrupted", *next);
            *next += 1;
            drained += 1;
        }
        drained
    }

    #[test]
    fn basic_write_read() {
        let (mut prod, mut cons) = new(1024);

        let payload = b"hello world";
        let mut claim = prod.try_claim(payload.len()).unwrap();
        claim.copy_from_slice(payload);
        claim.commit();

        let record = cons.try_claim().unwrap();
        assert_eq!(&*record, payload);
    }

    #[test]
    fn empty_returns_none() {
        let (_, mut cons) = new(1024);
        assert!(cons.try_claim().is_none());
    }

    #[test]
    fn multiple_records() {
        let (mut prod, mut cons) = new(1024);

        for i in 0..10 {
            let payload = format!("message {}", i);
            let mut claim = prod.try_claim(payload.len()).unwrap();
            claim.copy_from_slice(payload.as_bytes());
            claim.commit();
        }

        for i in 0..10 {
            let record = cons.try_claim().unwrap();
            let expected = format!("message {}", i);
            assert_eq!(&*record, expected.as_bytes());
        }

        assert!(cons.try_claim().is_none());
    }

    #[test]
    // The clone is the behaviour under test, so clippy's redundant-clone lint does not apply.
    #[allow(clippy::redundant_clone)]
    fn producer_is_clone() {
        let (prod, _cons) = new(1024);
        let _prod2 = prod.clone();
    }

    #[test]
    fn multiple_producers_single_consumer() {
        use std::thread;

        const PRODUCERS: usize = 4;
        const MESSAGES_PER_PRODUCER: u64 = 10_000;
        const TOTAL: u64 = PRODUCERS as u64 * MESSAGES_PER_PRODUCER;

        let (prod, mut cons) = new(64 * 1024);

        let handles: Vec<_> = (0..PRODUCERS)
            .map(|producer_id| {
                let mut prod = prod.clone();
                thread::spawn(move || {
                    for i in 0..MESSAGES_PER_PRODUCER {
                        // Payload: producer id (8 bytes) followed by its sequence number.
                        let mut payload = [0u8; 16];
                        payload[..8].copy_from_slice(&(producer_id as u64).to_le_bytes());
                        payload[8..].copy_from_slice(&i.to_le_bytes());

                        loop {
                            match prod.try_claim(16) {
                                Ok(mut claim) => {
                                    claim.copy_from_slice(&payload);
                                    claim.commit();
                                    break;
                                }
                                Err(_) => std::hint::spin_loop(),
                            }
                        }
                    }
                })
            })
            .collect();

        drop(prod);

        let consumer = thread::spawn(move || {
            let mut received = 0u64;
            let mut per_producer = vec![0u64; PRODUCERS];

            while received < TOTAL {
                if let Some(record) = cons.try_claim() {
                    let producer_id =
                        u64::from_le_bytes(record[..8].try_into().unwrap()) as usize;
                    let seq = u64::from_le_bytes(record[8..].try_into().unwrap());

                    // Each producer's records must arrive in the order it wrote them.
                    assert_eq!(
                        seq, per_producer[producer_id],
                        "producer {} out of order",
                        producer_id
                    );
                    per_producer[producer_id] += 1;
                    received += 1;
                } else {
                    std::hint::spin_loop();
                }
            }

            per_producer
        });

        for h in handles {
            h.join().unwrap();
        }

        let per_producer = consumer.join().unwrap();
        for (i, &count) in per_producer.iter().enumerate() {
            assert_eq!(count, MESSAGES_PER_PRODUCER, "producer {} count", i);
        }
    }

    #[test]
    fn aborted_claim_creates_skip() {
        let (mut prod, mut cons) = new(1024);

        {
            // Dropped without `commit`, so it must become a skip record.
            let mut claim = prod.try_claim(10).unwrap();
            claim.copy_from_slice(b"0123456789");
        }

        {
            let mut claim = prod.try_claim(5).unwrap();
            claim.copy_from_slice(b"hello");
            claim.commit();
        }

        // The aborted claim is skipped transparently.
        let record = cons.try_claim().unwrap();
        assert_eq!(&*record, b"hello");
        drop(record);
        assert!(cons.try_claim().is_none());
    }

    #[test]
    fn wrap_around() {
        let (mut prod, mut cons) = new(64);
        let expected = |i: usize| format!("msg{:02}", i);
        let mut next = 0;

        for i in 0..20 {
            let payload = expected(i);
            loop {
                match prod.try_claim(payload.len()) {
                    Ok(mut claim) => {
                        claim.copy_from_slice(payload.as_bytes());
                        claim.commit();
                        break;
                    }
                    Err(_) => {
                        // The ring is full: everything written so far must read back
                        // intact and in order.
                        let drained = drain_in_order(&mut cons, &mut next, expected);
                        assert!(drained > 0, "ring reported full with nothing readable");
                    }
                }
            }
        }

        drain_in_order(&mut cons, &mut next, expected);
        assert_eq!(next, 20);
    }

    #[test]
    fn wrap_around_with_skip_record() {
        // 16-byte payloads make 24-byte records, which do not divide a 64-byte ring. The
        // producer must therefore pad the end of the buffer with a skip record on each lap.
        let (mut prod, mut cons) = new(64);
        let expected = |i: usize| format!("record{:010}", i);
        let mut next = 0;

        for i in 0..20 {
            let payload = expected(i);
            assert_eq!(payload.len(), 16);
            loop {
                match prod.try_claim(payload.len()) {
                    Ok(mut claim) => {
                        claim.copy_from_slice(payload.as_bytes());
                        claim.commit();
                        break;
                    }
                    Err(_) => {
                        let drained = drain_in_order(&mut cons, &mut next, expected);
                        assert!(drained > 0, "ring reported full with nothing readable");
                    }
                }
            }
        }

        drain_in_order(&mut cons, &mut next, expected);
        assert_eq!(next, 20);
    }

    #[test]
    fn full_returns_error() {
        let (mut prod, _cons) = new(64);

        let mut count = 0;
        // Nothing is consumed, so claims succeed until the ring is exhausted.
        while let Ok(mut claim) = prod.try_claim(8) {
            claim.copy_from_slice(b"12345678");
            claim.commit();
            count += 1;
        }

        // Header + 8-byte payload rounds to 16 bytes, so exactly four fit in 64 bytes.
        assert_eq!(count, 4);
        assert!(matches!(prod.try_claim(8), Err(TryClaimError::Full)));
    }

    #[test]
    fn oversized_payload_returns_full() {
        let (mut prod, mut cons) = new(64);

        // Header plus payload exactly fills an empty 64-byte ring.
        let largest = 64 - HEADER_SIZE;
        let mut claim = prod.try_claim(largest).unwrap();
        claim.fill(0xAB);
        claim.commit();
        {
            let record = cons.try_claim().unwrap();
            assert_eq!(record.len(), largest);
            assert!(record.iter().all(|&b| b == 0xAB));
        }

        // A payload larger than the whole ring can never fit.
        assert!(matches!(prod.try_claim(1024), Err(TryClaimError::Full)));
    }

    #[test]
    fn disconnection_detection() {
        let (prod, cons) = new(1024);

        assert!(!prod.is_disconnected());
        assert!(!cons.is_disconnected());

        drop(cons);
        assert!(prod.is_disconnected());

        let (prod, cons) = new(1024);
        drop(prod);
        assert!(cons.is_disconnected());
    }

    #[test]
    #[should_panic(expected = "capacity must be at least 16")]
    fn tiny_capacity_panics() {
        let _ = new(8);
    }

    #[test]
    fn zero_len_returns_error() {
        let (mut prod, _) = new(1024);
        assert!(matches!(prod.try_claim(0), Err(TryClaimError::ZeroLength)));
    }

    #[test]
    fn capacity_rounds_to_power_of_two() {
        let (prod, _) = new(100);
        assert_eq!(prod.capacity(), 128);

        let (prod, _) = new(1000);
        assert_eq!(prod.capacity(), 1024);
    }

    #[test]
    fn stress_multiple_producers() {
        use std::thread;

        const PRODUCERS: usize = 4;
        const COUNT_PER_PRODUCER: u64 = 100_000;
        const TOTAL: u64 = PRODUCERS as u64 * COUNT_PER_PRODUCER;
        const BUFFER_SIZE: usize = 64 * 1024;

        let (prod, mut cons) = new(BUFFER_SIZE);

        let handles: Vec<_> = (0..PRODUCERS)
            .map(|_| {
                let mut prod = prod.clone();
                thread::spawn(move || {
                    for i in 0..COUNT_PER_PRODUCER {
                        let payload = i.to_le_bytes();
                        loop {
                            match prod.try_claim(payload.len()) {
                                Ok(mut claim) => {
                                    claim.copy_from_slice(&payload);
                                    claim.commit();
                                    break;
                                }
                                Err(_) => std::hint::spin_loop(),
                            }
                        }
                    }
                })
            })
            .collect();

        drop(prod);

        let consumer = thread::spawn(move || {
            let mut received = 0u64;
            let mut sum = 0u64;
            while received < TOTAL {
                if let Some(record) = cons.try_claim() {
                    let value = u64::from_le_bytes((*record).try_into().unwrap());
                    sum = sum.wrapping_add(value);
                    received += 1;
                } else {
                    std::hint::spin_loop();
                }
            }
            (received, sum)
        });

        for h in handles {
            h.join().unwrap();
        }

        let (received, sum) = consumer.join().unwrap();
        assert_eq!(received, TOTAL);

        // Each producer sends 0..COUNT_PER_PRODUCER, whose sum is n * (n - 1) / 2.
        let expected_sum = PRODUCERS as u64 * COUNT_PER_PRODUCER * (COUNT_PER_PRODUCER - 1) / 2;
        assert_eq!(sum, expected_sum);
    }
}