1use std::alloc::{Layout, alloc_zeroed, dealloc, handle_alloc_error};
34use std::cell::Cell;
35use std::ops::{Deref, DerefMut};
36use std::ptr;
37use std::sync::Arc;
38use std::sync::atomic::{AtomicUsize, Ordering, fence};
39
40use crossbeam_utils::CachePadded;
41
42use crate::{LEN_MASK, SKIP_BIT, TryClaimError, align8};
43
/// Size in bytes of the per-record header word, which is read and written as an
/// `AtomicUsize` (same size as `usize`) and holds the payload length or a skip marker.
const HEADER_SIZE: usize = std::mem::size_of::<AtomicUsize>();
48
49pub fn new(capacity: usize) -> (Producer, Consumer) {
57 assert!(capacity >= 16, "capacity must be at least 16 bytes");
58
59 let capacity = capacity.next_power_of_two();
60 let mask = capacity - 1;
61
62 let layout = Layout::from_size_align(capacity, 8).unwrap();
64 let buffer_ptr = unsafe { alloc_zeroed(layout) };
65 if buffer_ptr.is_null() {
66 handle_alloc_error(layout);
67 }
68
69 let shared = Arc::new(Shared {
70 head: CachePadded::new(AtomicUsize::new(0)),
71 tail: CachePadded::new(AtomicUsize::new(0)),
72 buffer: buffer_ptr,
73 capacity,
74 mask,
75 });
76
77 (
78 Producer {
79 cached_head: Cell::new(0),
80 shared: Arc::clone(&shared),
81 },
82 Consumer {
83 head: Cell::new(0),
84 shared,
85 },
86 )
87}
88
/// Ring state shared by the consumer and every producer handle.
///
/// `head` and `tail` are free-running byte counters: they are masked with
/// `mask` to obtain buffer offsets and compared with wrapping arithmetic, so
/// `tail - head` is the number of bytes reserved by producers and not yet
/// released by the consumer.
struct Shared {
    /// Consumer read position. Written only by the consumer, after it has
    /// zeroed the bytes it releases; read by producers to compute free space.
    /// Wrapped in `CachePadded`, presumably to keep it off the cache line of `tail`.
    head: CachePadded<AtomicUsize>,
    /// Producer reservation position, advanced by producers with a CAS.
    tail: CachePadded<AtomicUsize>,
    /// Start of the zero-initialised ring storage (`capacity` bytes, 8-byte
    /// aligned). Allocated in `new`, freed in `Drop`.
    buffer: *mut u8,
    /// Ring size in bytes; always a power of two.
    capacity: usize,
    /// `capacity - 1`; turns a counter into a buffer offset.
    mask: usize,
}
101
// SAFETY: `buffer` is a heap allocation owned by this struct and freed only in
// `Drop`. Concurrent access to its bytes is coordinated by the head/tail claim
// protocol in `Producer` and `Consumer`: a byte range is reserved by a CAS on
// `tail`, published through a header store after a release fence, and handed
// back through `head` after a release fence.
// NOTE(review): soundness rests entirely on that protocol being upheld.
unsafe impl Send for Shared {}
// SAFETY: all shared mutation goes through atomics or through byte ranges
// handed off by the claim protocol described above `impl Send`.
unsafe impl Sync for Shared {}
107
108impl Drop for Shared {
109 fn drop(&mut self) {
110 let layout = Layout::from_size_align(self.capacity, 8).unwrap();
112 unsafe { dealloc(self.buffer, layout) };
113 }
114}
115
/// Writing half of the ring.
///
/// Cloneable: each clone gets its own `cached_head` and shares the ring, so
/// several threads can reserve records concurrently. Because of the `Cell`,
/// a `Producer` is not `Sync`; give each thread its own clone.
#[derive(Clone)]
pub struct Producer {
    /// Last consumer position this handle observed. Refreshed from
    /// `Shared::head` only when the cached value suggests there is not enough room.
    cached_head: Cell<usize>,
    /// Ring state shared with the consumer and any other producer clones.
    shared: Arc<Shared>,
}

// SAFETY: `Cell<usize>` is `Send`, and `Arc<Shared>` is `Send` because `Shared`
// is `Send + Sync`.
// NOTE(review): the type is therefore already `Send` automatically; this impl
// appears redundant.
unsafe impl Send for Producer {}
134
135impl Producer {
136 #[inline]
150 pub fn try_claim(&mut self, len: usize) -> Result<WriteClaim<'_>, TryClaimError> {
151 debug_assert!(len <= LEN_MASK, "payload too large");
152 if len == 0 {
153 return Err(TryClaimError::ZeroLength);
154 }
155
156 let record_size = align8(HEADER_SIZE + len);
157
158 loop {
160 let tail = self.shared.tail.load(Ordering::Relaxed);
161
162 let used = tail.wrapping_sub(self.cached_head.get());
165 let available = self.shared.capacity.saturating_sub(used);
166
167 if available < record_size {
168 self.cached_head.set(self.shared.head.load(Ordering::Relaxed));
170 fence(Ordering::Acquire);
171
172 let used = tail.wrapping_sub(self.cached_head.get());
173 if used > self.shared.capacity || self.shared.capacity - used < record_size {
174 return Err(TryClaimError::Full);
175 }
176 }
177
178 let offset = tail & self.shared.mask;
180 let space_to_end = self.shared.capacity - offset;
181
182 if space_to_end < record_size {
183 let total_needed = space_to_end + record_size;
185
186 let used = tail.wrapping_sub(self.cached_head.get());
187 let available = self.shared.capacity.saturating_sub(used);
188
189 if available < total_needed {
190 self.cached_head.set(self.shared.head.load(Ordering::Relaxed));
192 fence(Ordering::Acquire);
193
194 let used = tail.wrapping_sub(self.cached_head.get());
195 if used > self.shared.capacity || self.shared.capacity - used < total_needed {
196 return Err(TryClaimError::Full);
197 }
198 }
199
200 let new_tail = tail.wrapping_add(total_needed);
202 if self
203 .shared
204 .tail
205 .compare_exchange_weak(tail, new_tail, Ordering::Relaxed, Ordering::Relaxed)
206 .is_ok()
207 {
208 let buffer = self.shared.buffer;
210 let skip_len = space_to_end | SKIP_BIT;
211
212 fence(Ordering::Release);
214 let len_ptr = unsafe { buffer.add(offset) }.cast::<AtomicUsize>();
215 unsafe { &*len_ptr }.store(skip_len, Ordering::Relaxed);
216
217 return Ok(WriteClaim {
218 shared: &self.shared,
219 offset: 0, len,
221 record_size,
222 committed: false,
223 });
224 }
225 continue;
227 }
228
229 let new_tail = tail.wrapping_add(record_size);
231 if self
232 .shared
233 .tail
234 .compare_exchange_weak(tail, new_tail, Ordering::Relaxed, Ordering::Relaxed)
235 .is_ok()
236 {
237 return Ok(WriteClaim {
238 shared: &self.shared,
239 offset,
240 len,
241 record_size,
242 committed: false,
243 });
244 }
245 }
247 }
248
249 #[inline]
251 pub fn capacity(&self) -> usize {
252 self.shared.capacity
253 }
254
255 #[inline]
257 pub fn is_disconnected(&self) -> bool {
258 Arc::strong_count(&self.shared) == 1
263 }
264}
265
266impl std::fmt::Debug for Producer {
267 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
268 f.debug_struct("Producer")
269 .field("capacity", &self.capacity())
270 .finish_non_exhaustive()
271 }
272}
273
/// Space reserved by [`Producer::try_claim`] for one record.
///
/// Derefs (mutably) to the payload bytes. Call [`WriteClaim::commit`] to
/// publish the record. Dropping the claim uncommitted writes a skip header
/// instead, so the consumer steps over the reserved space.
pub struct WriteClaim<'a> {
    /// Ring the space was reserved in.
    shared: &'a Shared,
    /// Byte offset of the record header within the buffer.
    offset: usize,
    /// Requested payload length in bytes (never zero).
    len: usize,
    /// Header plus payload, rounded by `align8`; this is also the size of the
    /// skip record written if the claim is abandoned.
    record_size: usize,
    /// Set by `commit` so `Drop` leaves the committed header alone.
    committed: bool,
}
290
291impl WriteClaim<'_> {
292 #[inline]
294 pub fn commit(mut self) {
295 self.do_commit();
296 self.committed = true;
297 }
298
299 #[inline]
300 fn do_commit(&mut self) {
301 let buffer = self.shared.buffer;
302 let len_ptr = unsafe { buffer.add(self.offset) }.cast::<AtomicUsize>();
303
304 fence(Ordering::Release);
306 unsafe { &*len_ptr }.store(self.len, Ordering::Relaxed);
307 }
308
309 #[inline]
311 pub fn len(&self) -> usize {
312 self.len
313 }
314
315 #[inline]
317 pub fn is_empty(&self) -> bool {
318 false
319 }
320}
321
322impl Deref for WriteClaim<'_> {
323 type Target = [u8];
324
325 #[inline]
326 fn deref(&self) -> &Self::Target {
327 let buffer = self.shared.buffer;
328 let payload_ptr = unsafe { buffer.add(self.offset + HEADER_SIZE) };
329 unsafe { std::slice::from_raw_parts(payload_ptr, self.len) }
330 }
331}
332
333impl DerefMut for WriteClaim<'_> {
334 #[inline]
335 fn deref_mut(&mut self) -> &mut Self::Target {
336 let buffer = self.shared.buffer;
337 let payload_ptr = unsafe { buffer.add(self.offset + HEADER_SIZE) };
338 unsafe { std::slice::from_raw_parts_mut(payload_ptr, self.len) }
339 }
340}
341
342impl Drop for WriteClaim<'_> {
343 fn drop(&mut self) {
344 if !self.committed {
345 let buffer = self.shared.buffer;
347 let len_ptr = unsafe { buffer.add(self.offset) }.cast::<AtomicUsize>();
348 let skip_len = self.record_size | SKIP_BIT;
349
350 fence(Ordering::Release);
351 unsafe { &*len_ptr }.store(skip_len, Ordering::Relaxed);
352 }
353 }
354}
355
/// Reading half of the ring.
///
/// `new` creates exactly one, it is not `Clone`, and reading requires
/// `&mut self`, so records are consumed by a single reader.
pub struct Consumer {
    /// The consumer's own copy of the read position. It is mirrored into
    /// `Shared::head` each time space is released back to producers.
    head: Cell<usize>,
    /// Ring state shared with the producers.
    shared: Arc<Shared>,
}

// SAFETY: `Cell<usize>` is `Send`, and `Arc<Shared>` is `Send` because `Shared`
// is `Send + Sync`.
// NOTE(review): the type is therefore already `Send` automatically; this impl
// appears redundant.
unsafe impl Send for Consumer {}
373
374impl Consumer {
375 #[inline]
383 pub fn try_claim(&mut self) -> Option<ReadClaim<'_>> {
384 let buffer = self.shared.buffer;
385
386 loop {
387 let offset = self.head.get() & self.shared.mask;
388 let len_ptr = unsafe { buffer.add(offset) }.cast::<AtomicUsize>();
389
390 let len_raw = unsafe { &*len_ptr }.load(Ordering::Relaxed);
392 fence(Ordering::Acquire);
393
394 if len_raw == 0 {
395 return None;
397 }
398
399 if len_raw & SKIP_BIT != 0 {
400 let skip_size = len_raw & LEN_MASK;
402 if skip_size > HEADER_SIZE {
404 unsafe {
405 ptr::write_bytes(
406 buffer.add(offset + HEADER_SIZE),
407 0,
408 skip_size - HEADER_SIZE,
409 );
410 }
411 }
412 fence(Ordering::Release);
414 unsafe { &*len_ptr }.store(0, Ordering::Relaxed);
415
416 self.head.set(self.head.get().wrapping_add(skip_size));
417
418 fence(Ordering::Release);
420 self.shared.head.store(self.head.get(), Ordering::Relaxed);
421
422 continue;
424 }
425
426 let len = len_raw;
428 let record_size = align8(HEADER_SIZE + len);
429
430 return Some(ReadClaim {
431 consumer: self,
432 offset,
433 len,
434 record_size,
435 });
436 }
437 }
438
439 #[inline]
441 pub fn capacity(&self) -> usize {
442 self.shared.capacity
443 }
444
445 #[inline]
447 pub fn is_disconnected(&self) -> bool {
448 Arc::strong_count(&self.shared) == 1
449 }
450}
451
452impl std::fmt::Debug for Consumer {
453 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
454 f.debug_struct("Consumer")
455 .field("capacity", &self.capacity())
456 .finish_non_exhaustive()
457 }
458}
459
/// A committed record borrowed from the ring by [`Consumer::try_claim`].
///
/// Derefs to the payload bytes. Dropping it zeroes the record, clears its
/// header and advances the consumer's read position, which releases the space
/// to producers.
pub struct ReadClaim<'a> {
    /// Consumer this claim came from. It is borrowed exclusively, so only one
    /// record can be held at a time.
    consumer: &'a mut Consumer,
    /// Byte offset of the record header within the buffer.
    offset: usize,
    /// Payload length in bytes (the value read from the header).
    len: usize,
    /// Header plus payload, rounded by `align8`; how far `head` advances on drop.
    record_size: usize,
}
474
475impl ReadClaim<'_> {
476 #[inline]
478 pub fn len(&self) -> usize {
479 self.len
480 }
481
482 #[inline]
484 pub fn is_empty(&self) -> bool {
485 self.len == 0
486 }
487}
488
489impl Deref for ReadClaim<'_> {
490 type Target = [u8];
491
492 #[inline]
493 fn deref(&self) -> &Self::Target {
494 let buffer = self.consumer.shared.buffer;
495 let payload_ptr = unsafe { buffer.add(self.offset + HEADER_SIZE) };
496 unsafe { std::slice::from_raw_parts(payload_ptr, self.len) }
497 }
498}
499
500impl Drop for ReadClaim<'_> {
501 fn drop(&mut self) {
502 let buffer = self.consumer.shared.buffer;
503
504 if self.record_size > HEADER_SIZE {
506 unsafe {
507 ptr::write_bytes(
508 buffer.add(self.offset + HEADER_SIZE),
509 0,
510 self.record_size - HEADER_SIZE,
511 );
512 }
513 }
514 fence(Ordering::Release);
516 let len_ptr = unsafe { buffer.add(self.offset) }.cast::<AtomicUsize>();
517 unsafe { &*len_ptr }.store(0, Ordering::Relaxed);
518
519 let new_head = self.consumer.head.get().wrapping_add(self.record_size);
521 self.consumer.head.set(new_head);
522
523 fence(Ordering::Release);
525 self.consumer.shared.head.store(new_head, Ordering::Relaxed);
526 }
527}
528
#[cfg(test)]
mod tests {
    use super::*;

    /// A single committed record round-trips unchanged.
    #[test]
    fn basic_write_read() {
        let (mut prod, mut cons) = new(1024);

        let payload = b"hello world";
        let mut claim = prod.try_claim(payload.len()).unwrap();
        claim.copy_from_slice(payload);
        claim.commit();

        let record = cons.try_claim().unwrap();
        assert_eq!(&*record, payload);
    }

    /// A freshly created (all-zero) ring has nothing to read.
    #[test]
    fn empty_returns_none() {
        let (_, mut cons) = new(1024);
        assert!(cons.try_claim().is_none());
    }

    /// Records from a single producer are read back in write order, and the
    /// ring is empty afterwards.
    #[test]
    fn multiple_records() {
        let (mut prod, mut cons) = new(1024);

        for i in 0..10 {
            let payload = format!("message {}", i);
            let mut claim = prod.try_claim(payload.len()).unwrap();
            claim.copy_from_slice(payload.as_bytes());
            claim.commit();
        }

        for i in 0..10 {
            let record = cons.try_claim().unwrap();
            let expected = format!("message {}", i);
            assert_eq!(&*record, expected.as_bytes());
        }

        assert!(cons.try_claim().is_none());
    }

    /// Compile-level check that `Producer` implements `Clone`.
    #[test]
    // The clone is the point of the test, so clippy's redundant-clone lint is silenced.
    #[allow(clippy::redundant_clone)]
    fn producer_is_clone() {
        let (prod, _cons) = new(1024);
        let _prod2 = prod.clone();
    }

    /// Several producer threads write 16-byte records (producer id, sequence
    /// number). The consumer checks per-producer FIFO order and total counts.
    /// NOTE(review): the consumer spins until it has seen every message, so a
    /// lost record would hang this test rather than fail it.
    #[test]
    fn multiple_producers_single_consumer() {
        use std::thread;

        const PRODUCERS: usize = 4;
        const MESSAGES_PER_PRODUCER: u64 = 10_000;
        const TOTAL: u64 = PRODUCERS as u64 * MESSAGES_PER_PRODUCER;

        let (prod, mut cons) = new(64 * 1024);

        let handles: Vec<_> = (0..PRODUCERS)
            .map(|producer_id| {
                let mut prod = prod.clone();
                thread::spawn(move || {
                    for i in 0..MESSAGES_PER_PRODUCER {
                        let mut payload = [0u8; 16];
                        // Bytes 0..8: producer id; bytes 8..16: sequence number.
                        payload[..8].copy_from_slice(&(producer_id as u64).to_le_bytes());
                        payload[8..].copy_from_slice(&i.to_le_bytes());

                        // Spin until space is available.
                        loop {
                            match prod.try_claim(16) {
                                Ok(mut claim) => {
                                    claim.copy_from_slice(&payload);
                                    claim.commit();
                                    break;
                                }
                                Err(_) => std::hint::spin_loop(),
                            }
                        }
                    }
                })
            })
            .collect();

        // Drop the original handle so only the thread-owned clones remain.
        drop(prod);

        let consumer = thread::spawn(move || {
            let mut received = 0u64;
            let mut per_producer = vec![0u64; PRODUCERS];

            while received < TOTAL {
                if let Some(record) = cons.try_claim() {
                    let producer_id = u64::from_le_bytes(record[..8].try_into().unwrap()) as usize;
                    let seq = u64::from_le_bytes(record[8..].try_into().unwrap());

                    // Each producer's records must arrive in the order it wrote them.
                    assert_eq!(
                        seq, per_producer[producer_id],
                        "producer {} out of order",
                        producer_id
                    );
                    per_producer[producer_id] += 1;
                    received += 1;
                } else {
                    std::hint::spin_loop();
                }
            }

            per_producer
        });

        for h in handles {
            h.join().unwrap();
        }

        let per_producer = consumer.join().unwrap();
        for (i, &count) in per_producer.iter().enumerate() {
            assert_eq!(count, MESSAGES_PER_PRODUCER, "producer {} count", i);
        }
    }

    /// A claim dropped without `commit` becomes a skip record. The consumer
    /// steps over it and only sees the committed record behind it.
    #[test]
    fn aborted_claim_creates_skip() {
        let (mut prod, mut cons) = new(1024);

        {
            // Written but never committed.
            let mut claim = prod.try_claim(10).unwrap();
            claim.copy_from_slice(b"0123456789");
        }

        {
            let mut claim = prod.try_claim(5).unwrap();
            claim.copy_from_slice(b"hello");
            claim.commit();
        }

        let record = cons.try_claim().unwrap();
        assert_eq!(&*record, b"hello");
    }

    /// Pushes 20 small records through a 64-byte ring, forcing the write
    /// position to wrap. On `Err`, everything currently readable is drained
    /// and the claim is retried.
    /// NOTE(review): drained records are not compared against what was
    /// written; this only checks that producing eventually succeeds.
    #[test]
    fn wrap_around() {
        let (mut prod, mut cons) = new(64);

        for i in 0..20 {
            let payload = format!("msg{:02}", i);
            loop {
                match prod.try_claim(payload.len()) {
                    Ok(mut claim) => {
                        claim.copy_from_slice(payload.as_bytes());
                        claim.commit();
                        break;
                    }
                    Err(_) => {
                        while cons.try_claim().is_some() {}
                    }
                }
            }
        }
    }

    /// With no reads, claims eventually fail once the 64-byte ring fills up.
    #[test]
    fn full_returns_error() {
        let (mut prod, _cons) = new(64);

        let mut count = 0;
        while let Ok(mut claim) = prod.try_claim(8) {
            claim.copy_from_slice(b"12345678");
            claim.commit();
            count += 1;
        }

        assert!(count > 0);
        assert!(prod.try_claim(8).is_err());
    }

    /// With a single producer, dropping the consumer is observable from the
    /// producer side.
    #[test]
    fn disconnection_detection() {
        let (prod, cons) = new(1024);

        assert!(!prod.is_disconnected());
        assert!(!cons.is_disconnected());

        drop(cons);
        assert!(prod.is_disconnected());
    }

    #[test]
    #[should_panic(expected = "capacity must be at least 16")]
    fn tiny_capacity_panics() {
        let _ = new(8);
    }

    #[test]
    fn zero_len_returns_error() {
        let (mut prod, _) = new(1024);
        assert!(matches!(prod.try_claim(0), Err(TryClaimError::ZeroLength)));
    }

    #[test]
    fn capacity_rounds_to_power_of_two() {
        let (prod, _) = new(100);
        assert_eq!(prod.capacity(), 128);

        let (prod, _) = new(1000);
        assert_eq!(prod.capacity(), 1024);
    }

    /// Four producers each send the values 0..COUNT_PER_PRODUCER. The consumer
    /// checks the total count and the sum, which is
    /// PRODUCERS * COUNT * (COUNT - 1) / 2.
    #[test]
    fn stress_multiple_producers() {
        use std::thread;

        const PRODUCERS: usize = 4;
        const COUNT_PER_PRODUCER: u64 = 100_000;
        const TOTAL: u64 = PRODUCERS as u64 * COUNT_PER_PRODUCER;
        const BUFFER_SIZE: usize = 64 * 1024;

        let (prod, mut cons) = new(BUFFER_SIZE);

        let handles: Vec<_> = (0..PRODUCERS)
            .map(|_| {
                let mut prod = prod.clone();
                thread::spawn(move || {
                    for i in 0..COUNT_PER_PRODUCER {
                        let payload = i.to_le_bytes();
                        loop {
                            match prod.try_claim(payload.len()) {
                                Ok(mut claim) => {
                                    claim.copy_from_slice(&payload);
                                    claim.commit();
                                    break;
                                }
                                Err(_) => std::hint::spin_loop(),
                            }
                        }
                    }
                })
            })
            .collect();

        drop(prod);

        let consumer = thread::spawn(move || {
            let mut received = 0u64;
            let mut sum = 0u64;
            while received < TOTAL {
                if let Some(record) = cons.try_claim() {
                    let value = u64::from_le_bytes((*record).try_into().unwrap());
                    sum = sum.wrapping_add(value);
                    received += 1;
                } else {
                    std::hint::spin_loop();
                }
            }
            (received, sum)
        });

        for h in handles {
            h.join().unwrap();
        }

        let (received, sum) = consumer.join().unwrap();
        assert_eq!(received, TOTAL);

        let expected_sum = PRODUCERS as u64 * COUNT_PER_PRODUCER * (COUNT_PER_PRODUCER - 1) / 2;
        assert_eq!(sum, expected_sum);
    }
}