1use std::alloc::{Layout, alloc_zeroed, dealloc, handle_alloc_error};
34use std::ops::{Deref, DerefMut};
35use std::ptr;
36use std::sync::Arc;
37use std::sync::atomic::{AtomicUsize, Ordering, fence};
38
39use crossbeam_utils::CachePadded;
40
41use crate::{LEN_MASK, SKIP_BIT, TryClaimError, align8};
42
/// Size in bytes of the length header stored in front of every record's payload.
const HEADER_SIZE: usize = core::mem::size_of::<usize>();
47
48pub fn new(capacity: usize) -> (Producer, Consumer) {
56 assert!(capacity >= 16, "capacity must be at least 16 bytes");
57
58 let capacity = capacity.next_power_of_two();
59 let mask = capacity - 1;
60
61 let layout = Layout::from_size_align(capacity, 8).unwrap();
63 let buffer_ptr = unsafe { alloc_zeroed(layout) };
64 if buffer_ptr.is_null() {
65 handle_alloc_error(layout);
66 }
67
68 let shared = Arc::new(Shared {
69 head: CachePadded::new(AtomicUsize::new(0)),
70 tail: CachePadded::new(AtomicUsize::new(0)),
71 buffer: buffer_ptr,
72 capacity,
73 mask,
74 });
75
76 (
77 Producer {
78 cached_head: 0,
79 shared: Arc::clone(&shared),
80 },
81 Consumer { head: 0, shared },
82 )
83}
84
/// State shared by every producer handle and the consumer.
struct Shared {
    /// Consumer position. Stored by the consumer after it releases space and
    /// loaded by producers to learn how much of the buffer is free.
    head: CachePadded<AtomicUsize>,
    /// Producer position. Producers advance it with a compare-exchange to
    /// reserve space for a record.
    tail: CachePadded<AtomicUsize>,
    /// Start of the zero-initialised backing allocation of `capacity` bytes.
    buffer: *mut u8,
    /// Size of the backing allocation in bytes (a power of two).
    capacity: usize,
    /// `capacity - 1`; masks a position down to an offset into `buffer`.
    mask: usize,
}

// NOTE(review): the raw `buffer` pointer is what blocks the auto traits. These
// impls rely on the claim protocol in this file keeping concurrently accessed
// payload regions disjoint and on headers only being touched through atomics —
// confirm when changing that protocol.
unsafe impl Send for Shared {}
unsafe impl Sync for Shared {}
103
104impl Drop for Shared {
105 fn drop(&mut self) {
106 let layout = Layout::from_size_align(self.capacity, 8).unwrap();
108 unsafe { dealloc(self.buffer, layout) };
109 }
110}
111
/// Writing handle for the ring buffer.
///
/// Cloning produces another handle onto the same shared buffer; the clone
/// starts with a copy of this handle's `cached_head`.
#[derive(Clone)]
pub struct Producer {
    /// Most recent value of the shared head seen by this handle. It is only
    /// refreshed when it does not show enough free space for a claim.
    cached_head: usize,
    /// Buffer state shared with the consumer and any other producer handles.
    shared: Arc<Shared>,
}

// NOTE(review): `Arc<Shared>` is already `Send` given the `Send + Sync` impls
// on `Shared`, so this impl looks redundant — confirm before removing.
unsafe impl Send for Producer {}
130
131impl Producer {
132 #[inline]
146 pub fn try_claim(&mut self, len: usize) -> Result<WriteClaim<'_>, TryClaimError> {
147 debug_assert!(len <= LEN_MASK, "payload too large");
148 if len == 0 {
149 return Err(TryClaimError::ZeroLength);
150 }
151
152 let record_size = align8(HEADER_SIZE + len);
153
154 loop {
156 let tail = self.shared.tail.load(Ordering::Relaxed);
157
158 let used = tail.wrapping_sub(self.cached_head);
161 let available = self.shared.capacity.saturating_sub(used);
162
163 if available < record_size {
164 self.cached_head = self.shared.head.load(Ordering::Relaxed);
166 fence(Ordering::Acquire);
167
168 let used = tail.wrapping_sub(self.cached_head);
169 if used > self.shared.capacity || self.shared.capacity - used < record_size {
170 return Err(TryClaimError::Full);
171 }
172 }
173
174 let offset = tail & self.shared.mask;
176 let space_to_end = self.shared.capacity - offset;
177
178 if space_to_end < record_size {
179 let total_needed = space_to_end + record_size;
181
182 let used = tail.wrapping_sub(self.cached_head);
183 let available = self.shared.capacity.saturating_sub(used);
184
185 if available < total_needed {
186 self.cached_head = self.shared.head.load(Ordering::Relaxed);
188 fence(Ordering::Acquire);
189
190 let used = tail.wrapping_sub(self.cached_head);
191 if used > self.shared.capacity || self.shared.capacity - used < total_needed {
192 return Err(TryClaimError::Full);
193 }
194 }
195
196 let new_tail = tail.wrapping_add(total_needed);
198 if self
199 .shared
200 .tail
201 .compare_exchange_weak(tail, new_tail, Ordering::Relaxed, Ordering::Relaxed)
202 .is_ok()
203 {
204 let buffer = self.shared.buffer;
206 let skip_len = space_to_end | SKIP_BIT;
207
208 fence(Ordering::Release);
210 let len_ptr = unsafe { buffer.add(offset) }.cast::<AtomicUsize>();
211 unsafe { &*len_ptr }.store(skip_len, Ordering::Relaxed);
212
213 return Ok(WriteClaim {
214 shared: &self.shared,
215 offset: 0, len,
217 record_size,
218 committed: false,
219 });
220 }
221 continue;
223 }
224
225 let new_tail = tail.wrapping_add(record_size);
227 if self
228 .shared
229 .tail
230 .compare_exchange_weak(tail, new_tail, Ordering::Relaxed, Ordering::Relaxed)
231 .is_ok()
232 {
233 return Ok(WriteClaim {
234 shared: &self.shared,
235 offset,
236 len,
237 record_size,
238 committed: false,
239 });
240 }
241 }
243 }
244
245 #[inline]
247 pub fn capacity(&self) -> usize {
248 self.shared.capacity
249 }
250
251 #[inline]
253 pub fn is_disconnected(&self) -> bool {
254 Arc::strong_count(&self.shared) == 1
259 }
260}
261
262impl std::fmt::Debug for Producer {
263 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
264 f.debug_struct("Producer")
265 .field("capacity", &self.capacity())
266 .finish_non_exhaustive()
267 }
268}
269
/// Space reserved in the buffer for a record that has not been published yet.
///
/// Dereferences to the payload bytes. Call `commit` to publish the record;
/// dropping the claim without committing writes a skip marker so the consumer
/// steps over the reserved space.
pub struct WriteClaim<'a> {
    /// Shared buffer state the claim writes into.
    shared: &'a Shared,
    /// Byte offset of the record header within the buffer.
    offset: usize,
    /// Payload length in bytes.
    len: usize,
    /// Total bytes reserved for the record: `align8(HEADER_SIZE + len)`.
    record_size: usize,
    /// Set by `commit` so that `Drop` leaves the published header alone.
    committed: bool,
}
286
287impl WriteClaim<'_> {
288 #[inline]
290 pub fn commit(mut self) {
291 self.do_commit();
292 self.committed = true;
293 }
294
295 #[inline]
296 fn do_commit(&mut self) {
297 let buffer = self.shared.buffer;
298 let len_ptr = unsafe { buffer.add(self.offset) }.cast::<AtomicUsize>();
299
300 fence(Ordering::Release);
302 unsafe { &*len_ptr }.store(self.len, Ordering::Relaxed);
303 }
304
305 #[inline]
307 pub fn len(&self) -> usize {
308 self.len
309 }
310
311 #[inline]
313 pub fn is_empty(&self) -> bool {
314 false
315 }
316}
317
318impl Deref for WriteClaim<'_> {
319 type Target = [u8];
320
321 #[inline]
322 fn deref(&self) -> &Self::Target {
323 let buffer = self.shared.buffer;
324 let payload_ptr = unsafe { buffer.add(self.offset + HEADER_SIZE) };
325 unsafe { std::slice::from_raw_parts(payload_ptr, self.len) }
326 }
327}
328
329impl DerefMut for WriteClaim<'_> {
330 #[inline]
331 fn deref_mut(&mut self) -> &mut Self::Target {
332 let buffer = self.shared.buffer;
333 let payload_ptr = unsafe { buffer.add(self.offset + HEADER_SIZE) };
334 unsafe { std::slice::from_raw_parts_mut(payload_ptr, self.len) }
335 }
336}
337
338impl Drop for WriteClaim<'_> {
339 fn drop(&mut self) {
340 if !self.committed {
341 let buffer = self.shared.buffer;
343 let len_ptr = unsafe { buffer.add(self.offset) }.cast::<AtomicUsize>();
344 let skip_len = self.record_size | SKIP_BIT;
345
346 fence(Ordering::Release);
347 unsafe { &*len_ptr }.store(skip_len, Ordering::Relaxed);
348 }
349 }
350}
351
/// Reading handle for the ring buffer.
pub struct Consumer {
    /// The consumer's own position. It is copied into the shared `head` each
    /// time a record or skip marker has been released.
    head: usize,
    /// Buffer state shared with the producer handles.
    shared: Arc<Shared>,
}

// NOTE(review): `Arc<Shared>` is already `Send` given the `Send + Sync` impls
// on `Shared`, so this impl looks redundant — confirm before removing.
unsafe impl Send for Consumer {}
369
370impl Consumer {
371 #[inline]
379 pub fn try_claim(&mut self) -> Option<ReadClaim<'_>> {
380 let buffer = self.shared.buffer;
381
382 loop {
383 let offset = self.head & self.shared.mask;
384 let len_ptr = unsafe { buffer.add(offset) }.cast::<AtomicUsize>();
385
386 let len_raw = unsafe { &*len_ptr }.load(Ordering::Relaxed);
388 fence(Ordering::Acquire);
389
390 if len_raw == 0 {
391 return None;
393 }
394
395 if len_raw & SKIP_BIT != 0 {
396 let skip_size = len_raw & LEN_MASK;
398 if skip_size > HEADER_SIZE {
400 unsafe {
401 ptr::write_bytes(buffer.add(offset + HEADER_SIZE), 0, skip_size - HEADER_SIZE);
402 }
403 }
404 fence(Ordering::Release);
406 unsafe { &*len_ptr }.store(0, Ordering::Relaxed);
407
408 self.head = self.head.wrapping_add(skip_size);
409
410 fence(Ordering::Release);
412 self.shared.head.store(self.head, Ordering::Relaxed);
413
414 continue;
416 }
417
418 let len = len_raw;
420 let record_size = align8(HEADER_SIZE + len);
421
422 return Some(ReadClaim {
423 consumer: self,
424 offset,
425 len,
426 record_size,
427 });
428 }
429 }
430
431 #[inline]
433 pub fn capacity(&self) -> usize {
434 self.shared.capacity
435 }
436
437 #[inline]
439 pub fn is_disconnected(&self) -> bool {
440 Arc::strong_count(&self.shared) == 1
441 }
442}
443
444impl std::fmt::Debug for Consumer {
445 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
446 f.debug_struct("Consumer")
447 .field("capacity", &self.capacity())
448 .finish_non_exhaustive()
449 }
450}
451
/// A published record borrowed from the buffer.
///
/// Dereferences to the payload bytes. Dropping the claim zeroes the record and
/// advances the consumer's head past it.
pub struct ReadClaim<'a> {
    /// Consumer that produced this claim; its head is advanced on drop.
    consumer: &'a mut Consumer,
    /// Byte offset of the record header within the buffer.
    offset: usize,
    /// Payload length in bytes, as read from the record header.
    len: usize,
    /// Total bytes occupied by the record: `align8(HEADER_SIZE + len)`.
    record_size: usize,
}
466
467impl ReadClaim<'_> {
468 #[inline]
470 pub fn len(&self) -> usize {
471 self.len
472 }
473
474 #[inline]
476 pub fn is_empty(&self) -> bool {
477 self.len == 0
478 }
479}
480
481impl Deref for ReadClaim<'_> {
482 type Target = [u8];
483
484 #[inline]
485 fn deref(&self) -> &Self::Target {
486 let buffer = self.consumer.shared.buffer;
487 let payload_ptr = unsafe { buffer.add(self.offset + HEADER_SIZE) };
488 unsafe { std::slice::from_raw_parts(payload_ptr, self.len) }
489 }
490}
491
492impl Drop for ReadClaim<'_> {
493 fn drop(&mut self) {
494 let buffer = self.consumer.shared.buffer;
495
496 if self.record_size > HEADER_SIZE {
498 unsafe {
499 ptr::write_bytes(buffer.add(self.offset + HEADER_SIZE), 0, self.record_size - HEADER_SIZE);
500 }
501 }
502 fence(Ordering::Release);
504 let len_ptr = unsafe { buffer.add(self.offset) }.cast::<AtomicUsize>();
505 unsafe { &*len_ptr }.store(0, Ordering::Relaxed);
506
507 self.consumer.head = self.consumer.head.wrapping_add(self.record_size);
509
510 fence(Ordering::Release);
512 self.consumer
513 .shared
514 .head
515 .store(self.consumer.head, Ordering::Relaxed);
516 }
517}
518
#[cfg(test)]
mod tests {
    use super::*;

    /// A single committed record is read back unchanged.
    #[test]
    fn basic_write_read() {
        let (mut prod, mut cons) = new(1024);

        let payload = b"hello world";
        let mut claim = prod.try_claim(payload.len()).unwrap();
        claim.copy_from_slice(payload);
        claim.commit();

        let record = cons.try_claim().unwrap();
        assert_eq!(&*record, payload);
    }

    /// A fresh buffer has nothing to read.
    #[test]
    fn empty_returns_none() {
        let (_, mut cons) = new(1024);
        assert!(cons.try_claim().is_none());
    }

    /// Records come out in the order they were committed.
    #[test]
    fn multiple_records() {
        let (mut prod, mut cons) = new(1024);

        for i in 0..10 {
            let payload = format!("message {}", i);
            let mut claim = prod.try_claim(payload.len()).unwrap();
            claim.copy_from_slice(payload.as_bytes());
            claim.commit();
        }

        for i in 0..10 {
            let record = cons.try_claim().unwrap();
            let expected = format!("message {}", i);
            assert_eq!(&*record, expected.as_bytes());
        }

        assert!(cons.try_claim().is_none());
    }

    /// `Producer` can be cloned.
    #[test]
    fn producer_is_clone() {
        let (prod, _cons) = new(1024);
        let _prod2 = prod.clone();
    }

    /// Several producer threads feed one consumer; each producer's records
    /// must arrive in that producer's own order.
    #[test]
    fn multiple_producers_single_consumer() {
        use std::thread;

        const PRODUCERS: usize = 4;
        const MESSAGES_PER_PRODUCER: u64 = 10_000;
        const TOTAL: u64 = PRODUCERS as u64 * MESSAGES_PER_PRODUCER;

        let (prod, mut cons) = new(64 * 1024);

        let handles: Vec<_> = (0..PRODUCERS)
            .map(|producer_id| {
                let mut prod = prod.clone();
                thread::spawn(move || {
                    for i in 0..MESSAGES_PER_PRODUCER {
                        // Payload layout: producer id (8 bytes) then sequence number (8 bytes).
                        let mut payload = [0u8; 16];
                        payload[..8].copy_from_slice(&(producer_id as u64).to_le_bytes());
                        payload[8..].copy_from_slice(&i.to_le_bytes());

                        loop {
                            match prod.try_claim(16) {
                                Ok(mut claim) => {
                                    claim.copy_from_slice(&payload);
                                    claim.commit();
                                    break;
                                }
                                Err(_) => std::hint::spin_loop(),
                            }
                        }
                    }
                })
            })
            .collect();

        // Only the per-thread clones are needed from here on.
        drop(prod);

        let consumer = thread::spawn(move || {
            let mut received = 0u64;
            let mut per_producer = vec![0u64; PRODUCERS];

            while received < TOTAL {
                if let Some(record) = cons.try_claim() {
                    let producer_id = u64::from_le_bytes(record[..8].try_into().unwrap()) as usize;
                    let seq = u64::from_le_bytes(record[8..].try_into().unwrap());

                    assert_eq!(
                        seq, per_producer[producer_id],
                        "producer {} out of order",
                        producer_id
                    );
                    per_producer[producer_id] += 1;
                    received += 1;
                } else {
                    std::hint::spin_loop();
                }
            }

            per_producer
        });

        for h in handles {
            h.join().unwrap();
        }

        let per_producer = consumer.join().unwrap();
        for (i, &count) in per_producer.iter().enumerate() {
            assert_eq!(count, MESSAGES_PER_PRODUCER, "producer {} count", i);
        }
    }

    /// A claim dropped without `commit` is skipped by the consumer.
    #[test]
    fn aborted_claim_creates_skip() {
        let (mut prod, mut cons) = new(1024);

        {
            // Written but never committed: dropped at the end of this scope.
            let mut claim = prod.try_claim(10).unwrap();
            claim.copy_from_slice(b"0123456789");
        }

        {
            let mut claim = prod.try_claim(5).unwrap();
            claim.copy_from_slice(b"hello");
            claim.commit();
        }

        let record = cons.try_claim().unwrap();
        // The aborted record must not be visible.
        assert_eq!(&*record, b"hello");
    }

    /// Writes more data than the buffer holds and checks that every record is
    /// read back intact and in order as positions wrap around.
    #[test]
    fn wrap_around() {
        const MESSAGES: usize = 20;

        // Small buffer so the positions wrap several times.
        let (mut prod, mut cons) = new(64);
        let mut next_read = 0usize;

        for i in 0..MESSAGES {
            let payload = format!("msg{:02}", i);
            loop {
                match prod.try_claim(payload.len()) {
                    Ok(mut claim) => {
                        claim.copy_from_slice(payload.as_bytes());
                        claim.commit();
                        break;
                    }
                    Err(_) => {
                        // Buffer is full: drain it, verifying each record.
                        let before = next_read;
                        while let Some(record) = cons.try_claim() {
                            let expected = format!("msg{:02}", next_read);
                            assert_eq!(&*record, expected.as_bytes());
                            next_read += 1;
                        }
                        assert!(
                            next_read > before,
                            "buffer reported full but nothing could be drained"
                        );
                    }
                }
            }
        }

        // Read whatever is still in the buffer after the last write.
        while let Some(record) = cons.try_claim() {
            let expected = format!("msg{:02}", next_read);
            assert_eq!(&*record, expected.as_bytes());
            next_read += 1;
        }
        assert_eq!(next_read, MESSAGES);
    }

    /// Claims fail once the buffer has been filled and nothing is consumed.
    #[test]
    fn full_returns_error() {
        let (mut prod, _cons) = new(64);

        let mut count = 0;
        // Fill the buffer until a claim is refused.
        while let Ok(mut claim) = prod.try_claim(8) {
            claim.copy_from_slice(b"12345678");
            claim.commit();
            count += 1;
        }

        assert!(count > 0);
        assert!(prod.try_claim(8).is_err());
    }

    /// The producer notices when the consumer has been dropped.
    #[test]
    fn disconnection_detection() {
        let (prod, cons) = new(1024);

        assert!(!prod.is_disconnected());
        assert!(!cons.is_disconnected());

        drop(cons);
        assert!(prod.is_disconnected());
    }

    /// Capacities below the minimum are rejected.
    #[test]
    #[should_panic(expected = "capacity must be at least 16")]
    fn tiny_capacity_panics() {
        let _ = new(8);
    }

    /// Zero-length claims are rejected with a dedicated error.
    #[test]
    fn zero_len_returns_error() {
        let (mut prod, _) = new(1024);
        assert!(matches!(prod.try_claim(0), Err(TryClaimError::ZeroLength)));
    }

    /// The requested capacity is rounded up to a power of two.
    #[test]
    fn capacity_rounds_to_power_of_two() {
        let (prod, _) = new(100);
        assert_eq!(prod.capacity(), 128);

        let (prod, _) = new(1000);
        assert_eq!(prod.capacity(), 1024);
    }

    /// High-volume run with several producers; checks that no record is lost
    /// or duplicated by comparing the count and the sum of all values.
    #[test]
    fn stress_multiple_producers() {
        use std::thread;

        const PRODUCERS: usize = 4;
        const COUNT_PER_PRODUCER: u64 = 100_000;
        const TOTAL: u64 = PRODUCERS as u64 * COUNT_PER_PRODUCER;
        const BUFFER_SIZE: usize = 64 * 1024;

        let (prod, mut cons) = new(BUFFER_SIZE);

        let handles: Vec<_> = (0..PRODUCERS)
            .map(|_| {
                let mut prod = prod.clone();
                thread::spawn(move || {
                    for i in 0..COUNT_PER_PRODUCER {
                        let payload = i.to_le_bytes();
                        loop {
                            match prod.try_claim(payload.len()) {
                                Ok(mut claim) => {
                                    claim.copy_from_slice(&payload);
                                    claim.commit();
                                    break;
                                }
                                Err(_) => std::hint::spin_loop(),
                            }
                        }
                    }
                })
            })
            .collect();

        drop(prod);

        let consumer = thread::spawn(move || {
            let mut received = 0u64;
            let mut sum = 0u64;
            while received < TOTAL {
                if let Some(record) = cons.try_claim() {
                    let value = u64::from_le_bytes((*record).try_into().unwrap());
                    sum = sum.wrapping_add(value);
                    received += 1;
                } else {
                    std::hint::spin_loop();
                }
            }
            (received, sum)
        });

        for h in handles {
            h.join().unwrap();
        }

        let (received, sum) = consumer.join().unwrap();
        assert_eq!(received, TOTAL);

        // Each producer sends 0..COUNT_PER_PRODUCER, whose sum is n * (n - 1) / 2.
        let expected_sum = PRODUCERS as u64 * COUNT_PER_PRODUCER * (COUNT_PER_PRODUCER - 1) / 2;
        assert_eq!(sum, expected_sum);
    }
}