1use std::alloc::{Layout, alloc_zeroed, dealloc, handle_alloc_error};
34use std::cell::Cell;
35use std::ops::{Deref, DerefMut};
36use std::ptr;
37use std::sync::Arc;
38use std::sync::atomic::{AtomicUsize, Ordering, fence};
39
40use crossbeam_utils::CachePadded;
41
42use crate::{LEN_MASK, SKIP_BIT, TryClaimError, align8};
43
/// Size in bytes of the per-record header: a single `usize` stored atomically at the
/// start of every record. It holds the committed payload length, or a skip size tagged
/// with `SKIP_BIT`; a header of zero means nothing has been published at that slot yet.
const HEADER_SIZE: usize = std::mem::size_of::<usize>();
48
49pub fn new(capacity: usize) -> (Producer, Consumer) {
57 assert!(capacity >= 16, "capacity must be at least 16 bytes");
58
59 let capacity = capacity.next_power_of_two();
60 let mask = capacity - 1;
61
62 let layout = Layout::from_size_align(capacity, 8).unwrap();
64 let buffer_ptr = unsafe { alloc_zeroed(layout) };
65 if buffer_ptr.is_null() {
66 handle_alloc_error(layout);
67 }
68
69 let shared = Arc::new(Shared {
70 head: CachePadded::new(AtomicUsize::new(0)),
71 tail: CachePadded::new(AtomicUsize::new(0)),
72 buffer: buffer_ptr,
73 capacity,
74 mask,
75 });
76
77 (
78 Producer {
79 cached_head: Cell::new(0),
80 shared: Arc::clone(&shared),
81 },
82 Consumer {
83 head: Cell::new(0),
84 shared,
85 },
86 )
87}
88
89struct Shared {
90 head: CachePadded<AtomicUsize>,
92 tail: CachePadded<AtomicUsize>,
94 buffer: *mut u8,
96 capacity: usize,
98 mask: usize,
100}
101
102unsafe impl Send for Shared {}
106unsafe impl Sync for Shared {}
107
108impl Drop for Shared {
109 fn drop(&mut self) {
110 let layout = Layout::from_size_align(self.capacity, 8).unwrap();
112 unsafe { dealloc(self.buffer, layout) };
113 }
114}
115
/// Writing half of the ring buffer.
///
/// Cloning a `Producer` yields another handle to the same buffer; clones reserve
/// space concurrently through a compare-and-swap on the shared tail.
#[derive(Clone)]
pub struct Producer {
    /// Last observed value of the consumer's `head`. It is refreshed from
    /// `shared.head` only when the cached value indicates there is not enough room.
    cached_head: Cell<usize>,
    /// Ring state shared with the consumer and other producer clones.
    shared: Arc<Shared>,
}

// SAFETY: both fields are already `Send` (`Cell<usize>`, and `Arc<Shared>` because
// `Shared` implements `Send + Sync`), so this only restates the auto trait. The `Cell`
// keeps `Producer` `!Sync`, so each handle is used from one thread at a time.
unsafe impl Send for Producer {}
134
135impl Producer {
136 #[inline]
150 pub fn try_claim(&mut self, len: usize) -> Result<WriteClaim<'_>, TryClaimError> {
151 debug_assert!(len <= LEN_MASK, "payload too large");
152 if len == 0 {
153 return Err(TryClaimError::ZeroLength);
154 }
155
156 let record_size = align8(HEADER_SIZE + len);
157
158 loop {
160 let tail = self.shared.tail.load(Ordering::Relaxed);
161
162 let used = tail.wrapping_sub(self.cached_head.get());
165 let available = self.shared.capacity.saturating_sub(used);
166
167 if available < record_size {
168 self.cached_head
170 .set(self.shared.head.load(Ordering::Relaxed));
171 fence(Ordering::Acquire);
172
173 let used = tail.wrapping_sub(self.cached_head.get());
174 if used > self.shared.capacity || self.shared.capacity - used < record_size {
175 return Err(TryClaimError::Full);
176 }
177 }
178
179 let offset = tail & self.shared.mask;
181 let space_to_end = self.shared.capacity - offset;
182
183 if space_to_end < record_size {
184 let total_needed = space_to_end + record_size;
186
187 let used = tail.wrapping_sub(self.cached_head.get());
188 let available = self.shared.capacity.saturating_sub(used);
189
190 if available < total_needed {
191 self.cached_head
193 .set(self.shared.head.load(Ordering::Relaxed));
194 fence(Ordering::Acquire);
195
196 let used = tail.wrapping_sub(self.cached_head.get());
197 if used > self.shared.capacity || self.shared.capacity - used < total_needed {
198 return Err(TryClaimError::Full);
199 }
200 }
201
202 let new_tail = tail.wrapping_add(total_needed);
204 if self
205 .shared
206 .tail
207 .compare_exchange_weak(tail, new_tail, Ordering::Relaxed, Ordering::Relaxed)
208 .is_ok()
209 {
210 let buffer = self.shared.buffer;
212 let skip_len = space_to_end | SKIP_BIT;
213
214 fence(Ordering::Release);
216 let len_ptr = unsafe { buffer.add(offset) }.cast::<AtomicUsize>();
217 unsafe { &*len_ptr }.store(skip_len, Ordering::Relaxed);
218
219 return Ok(WriteClaim {
220 shared: &self.shared,
221 offset: 0, len,
223 record_size,
224 committed: false,
225 });
226 }
227 continue;
229 }
230
231 let new_tail = tail.wrapping_add(record_size);
233 if self
234 .shared
235 .tail
236 .compare_exchange_weak(tail, new_tail, Ordering::Relaxed, Ordering::Relaxed)
237 .is_ok()
238 {
239 return Ok(WriteClaim {
240 shared: &self.shared,
241 offset,
242 len,
243 record_size,
244 committed: false,
245 });
246 }
247 }
249 }
250
251 #[inline]
253 pub fn capacity(&self) -> usize {
254 self.shared.capacity
255 }
256
257 #[inline]
262 pub fn is_disconnected(&self) -> bool {
263 Arc::strong_count(&self.shared) == 1
264 }
265}
266
267impl std::fmt::Debug for Producer {
268 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
269 f.debug_struct("Producer")
270 .field("capacity", &self.capacity())
271 .finish_non_exhaustive()
272 }
273}
274
/// Exclusive access to a reserved, not-yet-published record in the ring.
///
/// Dereferences to the `len`-byte payload. [`WriteClaim::commit`] publishes the
/// record; dropping the claim without committing overwrites its header with a skip
/// marker so the consumer steps over the reserved space.
pub struct WriteClaim<'a> {
    /// Ring state borrowed from the producer that made the claim.
    shared: &'a Shared,
    /// Buffer offset of the record header.
    offset: usize,
    /// Payload length requested from `Producer::try_claim`.
    len: usize,
    /// Header plus payload, rounded by `align8`: the space reserved at `offset`.
    record_size: usize,
    /// Set by `commit` so `Drop` does not replace the published header.
    committed: bool,
}
291
292impl WriteClaim<'_> {
293 #[inline]
295 pub fn commit(mut self) {
296 self.do_commit();
297 self.committed = true;
298 }
299
300 #[inline]
301 fn do_commit(&mut self) {
302 let buffer = self.shared.buffer;
303 let len_ptr = unsafe { buffer.add(self.offset) }.cast::<AtomicUsize>();
304
305 fence(Ordering::Release);
307 unsafe { &*len_ptr }.store(self.len, Ordering::Relaxed);
308 }
309
310 #[inline]
312 pub fn len(&self) -> usize {
313 self.len
314 }
315
316 #[inline]
318 pub fn is_empty(&self) -> bool {
319 false
320 }
321}
322
323impl Deref for WriteClaim<'_> {
324 type Target = [u8];
325
326 #[inline]
327 fn deref(&self) -> &Self::Target {
328 let buffer = self.shared.buffer;
329 let payload_ptr = unsafe { buffer.add(self.offset + HEADER_SIZE) };
330 unsafe { std::slice::from_raw_parts(payload_ptr, self.len) }
331 }
332}
333
334impl DerefMut for WriteClaim<'_> {
335 #[inline]
336 fn deref_mut(&mut self) -> &mut Self::Target {
337 let buffer = self.shared.buffer;
338 let payload_ptr = unsafe { buffer.add(self.offset + HEADER_SIZE) };
339 unsafe { std::slice::from_raw_parts_mut(payload_ptr, self.len) }
340 }
341}
342
343impl Drop for WriteClaim<'_> {
344 fn drop(&mut self) {
345 if !self.committed {
346 let buffer = self.shared.buffer;
348 let len_ptr = unsafe { buffer.add(self.offset) }.cast::<AtomicUsize>();
349 let skip_len = self.record_size | SKIP_BIT;
350
351 fence(Ordering::Release);
352 unsafe { &*len_ptr }.store(skip_len, Ordering::Relaxed);
353 }
354 }
355}
356
/// Reading half of the ring buffer. `new` returns exactly one, and it is not `Clone`.
pub struct Consumer {
    /// Local read position (unmasked). It is advanced past each released record and
    /// then published to `shared.head`.
    head: Cell<usize>,
    /// Ring state shared with the producers.
    shared: Arc<Shared>,
}

// SAFETY: both fields are already `Send` (`Cell<usize>`, and `Arc<Shared>` because
// `Shared` implements `Send + Sync`), so this only restates the auto trait. The `Cell`
// keeps `Consumer` `!Sync`.
unsafe impl Send for Consumer {}
374
375impl Consumer {
376 #[inline]
384 pub fn try_claim(&mut self) -> Option<ReadClaim<'_>> {
385 let buffer = self.shared.buffer;
386
387 loop {
388 let offset = self.head.get() & self.shared.mask;
389 let len_ptr = unsafe { buffer.add(offset) }.cast::<AtomicUsize>();
390
391 let len_raw = unsafe { &*len_ptr }.load(Ordering::Relaxed);
393 fence(Ordering::Acquire);
394
395 if len_raw == 0 {
396 return None;
398 }
399
400 if len_raw & SKIP_BIT != 0 {
401 let skip_size = len_raw & LEN_MASK;
403 if skip_size > HEADER_SIZE {
405 unsafe {
406 ptr::write_bytes(
407 buffer.add(offset + HEADER_SIZE),
408 0,
409 skip_size - HEADER_SIZE,
410 );
411 }
412 }
413 fence(Ordering::Release);
415 unsafe { &*len_ptr }.store(0, Ordering::Relaxed);
416
417 self.head.set(self.head.get().wrapping_add(skip_size));
418
419 fence(Ordering::Release);
421 self.shared.head.store(self.head.get(), Ordering::Relaxed);
422
423 continue;
425 }
426
427 let len = len_raw;
429 let record_size = align8(HEADER_SIZE + len);
430
431 return Some(ReadClaim {
432 consumer: self,
433 offset,
434 len,
435 record_size,
436 });
437 }
438 }
439
440 #[inline]
442 pub fn capacity(&self) -> usize {
443 self.shared.capacity
444 }
445
446 #[inline]
450 pub fn is_disconnected(&self) -> bool {
451 Arc::strong_count(&self.shared) == 1
452 }
453}
454
455impl std::fmt::Debug for Consumer {
456 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
457 f.debug_struct("Consumer")
458 .field("capacity", &self.capacity())
459 .finish_non_exhaustive()
460 }
461}
462
/// A committed record borrowed from the [`Consumer`].
///
/// Dereferences to the payload bytes. Dropping the claim zeroes the record, releases
/// its space to the producers and advances the read position.
pub struct ReadClaim<'a> {
    /// The consumer, borrowed mutably so only one record is read at a time.
    consumer: &'a mut Consumer,
    /// Buffer offset of the record header.
    offset: usize,
    /// Payload length read from the header.
    len: usize,
    /// Header plus payload, rounded by `align8`: the space released on drop.
    record_size: usize,
}
477
478impl ReadClaim<'_> {
479 #[inline]
481 pub fn len(&self) -> usize {
482 self.len
483 }
484
485 #[inline]
487 pub fn is_empty(&self) -> bool {
488 self.len == 0
489 }
490}
491
492impl Deref for ReadClaim<'_> {
493 type Target = [u8];
494
495 #[inline]
496 fn deref(&self) -> &Self::Target {
497 let buffer = self.consumer.shared.buffer;
498 let payload_ptr = unsafe { buffer.add(self.offset + HEADER_SIZE) };
499 unsafe { std::slice::from_raw_parts(payload_ptr, self.len) }
500 }
501}
502
503impl Drop for ReadClaim<'_> {
504 fn drop(&mut self) {
505 let buffer = self.consumer.shared.buffer;
506
507 if self.record_size > HEADER_SIZE {
509 unsafe {
510 ptr::write_bytes(
511 buffer.add(self.offset + HEADER_SIZE),
512 0,
513 self.record_size - HEADER_SIZE,
514 );
515 }
516 }
517 fence(Ordering::Release);
519 let len_ptr = unsafe { buffer.add(self.offset) }.cast::<AtomicUsize>();
520 unsafe { &*len_ptr }.store(0, Ordering::Relaxed);
521
522 let new_head = self.consumer.head.get().wrapping_add(self.record_size);
524 self.consumer.head.set(new_head);
525
526 fence(Ordering::Release);
528 self.consumer.shared.head.store(new_head, Ordering::Relaxed);
529 }
530}
531
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;

    /// Retries `try_claim` until it succeeds, then copies `payload` in and commits it.
    fn push_spinning(prod: &mut Producer, payload: &[u8]) {
        loop {
            if let Ok(mut claim) = prod.try_claim(payload.len()) {
                claim.copy_from_slice(payload);
                claim.commit();
                return;
            }
            std::hint::spin_loop();
        }
    }

    #[test]
    fn basic_write_read() {
        let (mut prod, mut cons) = new(1024);
        let payload: &[u8] = b"hello world";

        {
            let mut claim = prod.try_claim(payload.len()).unwrap();
            claim.copy_from_slice(payload);
            claim.commit();
        }

        let record = cons.try_claim().unwrap();
        assert_eq!(&*record, payload);
    }

    #[test]
    fn empty_returns_none() {
        let (_, mut cons) = new(1024);
        let next = cons.try_claim();
        assert!(next.is_none());
    }

    #[test]
    fn multiple_records() {
        let (mut prod, mut cons) = new(1024);
        let messages: Vec<String> = (0..10).map(|i| format!("message {}", i)).collect();

        for msg in &messages {
            let mut claim = prod.try_claim(msg.len()).unwrap();
            claim.copy_from_slice(msg.as_bytes());
            claim.commit();
        }

        // Records come back in the order they were committed.
        for expected in &messages {
            let record = cons.try_claim().unwrap();
            assert_eq!(&*record, expected.as_bytes());
        }

        assert!(cons.try_claim().is_none());
    }

    #[test]
    // The clone is the whole point of this test.
    #[allow(clippy::redundant_clone)]
    fn producer_is_clone() {
        let (prod, _cons) = new(1024);
        let _prod2 = Producer::clone(&prod);
    }

    #[test]
    fn multiple_producers_single_consumer() {
        const PRODUCERS: usize = 4;
        const MESSAGES_PER_PRODUCER: u64 = 10_000;
        const TOTAL: u64 = PRODUCERS as u64 * MESSAGES_PER_PRODUCER;

        let (prod, mut cons) = new(64 * 1024);

        let mut writers = Vec::with_capacity(PRODUCERS);
        for producer_id in 0..PRODUCERS {
            let mut handle = prod.clone();
            writers.push(thread::spawn(move || {
                for seq in 0..MESSAGES_PER_PRODUCER {
                    // Payload: producer id (LE u64) followed by sequence number (LE u64).
                    let mut payload = [0u8; 16];
                    let (id_bytes, seq_bytes) = payload.split_at_mut(8);
                    id_bytes.copy_from_slice(&(producer_id as u64).to_le_bytes());
                    seq_bytes.copy_from_slice(&seq.to_le_bytes());
                    push_spinning(&mut handle, &payload);
                }
            }));
        }

        drop(prod);

        let reader = thread::spawn(move || {
            let mut per_producer = vec![0u64; PRODUCERS];
            let mut received = 0u64;

            while received < TOTAL {
                match cons.try_claim() {
                    Some(record) => {
                        let (id_bytes, seq_bytes) = record.split_at(8);
                        let producer_id = u64::from_le_bytes(id_bytes.try_into().unwrap()) as usize;
                        let seq = u64::from_le_bytes(seq_bytes.try_into().unwrap());

                        // Each producer's messages must arrive in its own send order.
                        assert_eq!(
                            seq, per_producer[producer_id],
                            "producer {} out of order",
                            producer_id
                        );
                        per_producer[producer_id] += 1;
                        received += 1;
                    }
                    None => std::hint::spin_loop(),
                }
            }

            per_producer
        });

        for writer in writers {
            writer.join().unwrap();
        }

        let per_producer = reader.join().unwrap();
        for (i, &count) in per_producer.iter().enumerate() {
            assert_eq!(count, MESSAGES_PER_PRODUCER, "producer {} count", i);
        }
    }

    #[test]
    fn aborted_claim_creates_skip() {
        let (mut prod, mut cons) = new(1024);

        // Fill a claim but drop it without committing: it must become a skip record.
        let mut abandoned = prod.try_claim(10).unwrap();
        abandoned.copy_from_slice(b"0123456789");
        drop(abandoned);

        let mut kept = prod.try_claim(5).unwrap();
        kept.copy_from_slice(b"hello");
        kept.commit();

        // The consumer steps over the skip record and yields only the committed one.
        let record = cons.try_claim().unwrap();
        assert_eq!(&*record, b"hello");
    }

    #[test]
    fn wrap_around() {
        let (mut prod, mut cons) = new(64);

        for i in 0..20 {
            let payload = format!("msg{:02}", i);
            loop {
                if let Ok(mut claim) = prod.try_claim(payload.len()) {
                    claim.copy_from_slice(payload.as_bytes());
                    claim.commit();
                    break;
                }
                // Full: drain everything so the producer can wrap around.
                while cons.try_claim().is_some() {}
            }
        }
    }

    #[test]
    fn full_returns_error() {
        let (mut prod, _cons) = new(64);

        // Publish 8-byte records until the producer reports an error.
        let count = std::iter::from_fn(|| {
            let mut claim = prod.try_claim(8).ok()?;
            claim.copy_from_slice(b"12345678");
            claim.commit();
            Some(())
        })
        .count();

        assert!(count > 0);
        assert!(prod.try_claim(8).is_err());
    }

    #[test]
    fn disconnection_detection() {
        let (prod, cons) = new(1024);

        assert!(!prod.is_disconnected());
        assert!(!cons.is_disconnected());

        drop(cons);
        assert!(prod.is_disconnected());
    }

    #[test]
    #[should_panic(expected = "capacity must be at least 16")]
    fn tiny_capacity_panics() {
        drop(new(8));
    }

    #[test]
    fn zero_len_returns_error() {
        let (mut prod, _) = new(1024);
        let outcome = prod.try_claim(0);
        assert!(matches!(outcome, Err(TryClaimError::ZeroLength)));
    }

    #[test]
    fn capacity_rounds_to_power_of_two() {
        for (requested, expected) in [(100, 128), (1000, 1024)] {
            let (prod, _) = new(requested);
            assert_eq!(prod.capacity(), expected);
        }
    }

    #[test]
    fn stress_multiple_producers() {
        const PRODUCERS: usize = 4;
        const COUNT_PER_PRODUCER: u64 = 100_000;
        const TOTAL: u64 = PRODUCERS as u64 * COUNT_PER_PRODUCER;
        const BUFFER_SIZE: usize = 64 * 1024;

        let (prod, mut cons) = new(BUFFER_SIZE);

        let writers: Vec<_> = (0..PRODUCERS)
            .map(|_| {
                let mut handle = prod.clone();
                thread::spawn(move || {
                    for value in 0..COUNT_PER_PRODUCER {
                        push_spinning(&mut handle, &value.to_le_bytes());
                    }
                })
            })
            .collect();

        drop(prod);

        let reader = thread::spawn(move || {
            let mut received = 0u64;
            let mut sum = 0u64;
            while received < TOTAL {
                match cons.try_claim() {
                    Some(record) => {
                        let bytes: [u8; 8] = (*record).try_into().unwrap();
                        sum = sum.wrapping_add(u64::from_le_bytes(bytes));
                        received += 1;
                    }
                    None => std::hint::spin_loop(),
                }
            }
            (received, sum)
        });

        writers.into_iter().for_each(|w| w.join().unwrap());

        let (received, sum) = reader.join().unwrap();
        assert_eq!(received, TOTAL);

        // Every producer sends 0..COUNT_PER_PRODUCER, whose sum is n(n-1)/2.
        let expected_sum = PRODUCERS as u64 * COUNT_PER_PRODUCER * (COUNT_PER_PRODUCER - 1) / 2;
        assert_eq!(sum, expected_sum);
    }
}