1use std::alloc::{Layout, alloc_zeroed, dealloc, handle_alloc_error};
44use std::cell::Cell;
45use std::ops::{Deref, DerefMut};
46use std::ptr;
47use std::sync::Arc;
48use std::sync::atomic::{AtomicUsize, Ordering, fence};
49
50use crossbeam_utils::CachePadded;
51
52use crate::{LEN_MASK, SKIP_BIT, TryClaimError, align8};
53
/// Size in bytes of the per-record header word.
///
/// The header is read and written as an `AtomicUsize`. It holds either the
/// payload length of a committed record or a `SKIP_BIT`-tagged padding size;
/// zero means nothing has been published at that position yet.
const HEADER_SIZE: usize = std::mem::size_of::<AtomicUsize>();
58
59pub fn new(capacity: usize) -> (Producer, Consumer) {
67 assert!(capacity >= 16, "capacity must be at least 16 bytes");
68
69 let capacity = capacity.next_power_of_two();
70 let mask = capacity - 1;
71
72 let layout = Layout::from_size_align(capacity, 8).unwrap();
74 let buffer_ptr = unsafe { alloc_zeroed(layout) };
75 if buffer_ptr.is_null() {
76 handle_alloc_error(layout);
77 }
78
79 let shared = Arc::new(Shared {
80 head: CachePadded::new(AtomicUsize::new(0)),
81 buffer: buffer_ptr,
82 capacity,
83 mask,
84 });
85
86 (
87 Producer {
88 tail: Cell::new(0),
89 cached_head: Cell::new(0),
90 shared: Arc::clone(&shared),
91 },
92 Consumer {
93 head: Cell::new(0),
94 shared,
95 },
96 )
97}
98
99struct Shared {
100 head: CachePadded<AtomicUsize>,
102 buffer: *mut u8,
104 capacity: usize,
106 mask: usize,
108}
109
110unsafe impl Send for Shared {}
113unsafe impl Sync for Shared {}
114
115impl Drop for Shared {
116 fn drop(&mut self) {
117 let layout = Layout::from_size_align(self.capacity, 8).unwrap();
119 unsafe { dealloc(self.buffer, layout) };
120 }
121}
122
123pub struct Producer {
131 tail: Cell<usize>,
133 cached_head: Cell<usize>,
135 shared: Arc<Shared>,
137}
138
139unsafe impl Send for Producer {}
141
142impl Producer {
143 #[inline]
157 pub fn try_claim(&mut self, len: usize) -> Result<WriteClaim<'_>, TryClaimError> {
158 debug_assert!(len <= LEN_MASK, "payload too large");
159 if len == 0 {
160 return Err(TryClaimError::ZeroLength);
161 }
162
163 let record_size = align8(HEADER_SIZE + len);
164
165 let tail = self.tail.get();
167 let available = self.shared.capacity - (tail.wrapping_sub(self.cached_head.get()));
168
169 if available < record_size {
170 self.cached_head
172 .set(self.shared.head.load(Ordering::Relaxed));
173 fence(Ordering::Acquire);
174
175 let available = self.shared.capacity - (tail.wrapping_sub(self.cached_head.get()));
176 if available < record_size {
177 return Err(TryClaimError::Full);
178 }
179 }
180
181 let offset = tail & self.shared.mask;
183 let space_to_end = self.shared.capacity - offset;
184
185 if space_to_end < record_size {
186 let total_needed = space_to_end + record_size;
188 let available = self.shared.capacity - (tail.wrapping_sub(self.cached_head.get()));
189
190 if available < total_needed {
191 self.cached_head
193 .set(self.shared.head.load(Ordering::Relaxed));
194 fence(Ordering::Acquire);
195
196 let available = self.shared.capacity - (tail.wrapping_sub(self.cached_head.get()));
197 if available < total_needed {
198 return Err(TryClaimError::Full);
199 }
200 }
201
202 let buffer = self.shared.buffer;
204 let skip_len = space_to_end | SKIP_BIT;
205 fence(Ordering::Release);
206 let len_ptr = unsafe { buffer.add(offset) }.cast::<AtomicUsize>();
207 unsafe { &*len_ptr }.store(skip_len, Ordering::Relaxed);
208
209 self.tail.set(tail.wrapping_add(space_to_end));
211 let new_offset = 0;
212
213 Ok(WriteClaim {
214 producer: self,
215 offset: new_offset,
216 len,
217 record_size,
218 committed: false,
219 })
220 } else {
221 Ok(WriteClaim {
223 producer: self,
224 offset,
225 len,
226 record_size,
227 committed: false,
228 })
229 }
230 }
231
232 #[inline]
234 pub fn capacity(&self) -> usize {
235 self.shared.capacity
236 }
237
238 #[inline]
248 pub fn is_disconnected(&self) -> bool {
249 Arc::strong_count(&self.shared) == 1
250 }
251}
252
253impl std::fmt::Debug for Producer {
254 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
255 f.debug_struct("Producer")
256 .field("capacity", &self.capacity())
257 .finish_non_exhaustive()
258 }
259}
260
261pub struct WriteClaim<'a> {
271 producer: &'a mut Producer,
272 offset: usize,
273 len: usize,
274 record_size: usize,
275 committed: bool,
276}
277
278impl WriteClaim<'_> {
279 #[inline]
281 pub fn commit(mut self) {
282 self.do_commit();
283 self.committed = true;
284 }
285
286 #[inline]
287 fn do_commit(&mut self) {
288 let buffer = self.producer.shared.buffer;
289 let len_ptr = unsafe { buffer.add(self.offset) }.cast::<AtomicUsize>();
290
291 fence(Ordering::Release);
293 unsafe { &*len_ptr }.store(self.len, Ordering::Relaxed);
294
295 self.producer
297 .tail
298 .set(self.producer.tail.get().wrapping_add(self.record_size));
299 }
300
301 #[inline]
303 pub fn len(&self) -> usize {
304 self.len
305 }
306
307 #[inline]
309 pub fn is_empty(&self) -> bool {
310 false
311 }
312}
313
314impl Deref for WriteClaim<'_> {
315 type Target = [u8];
316
317 #[inline]
318 fn deref(&self) -> &Self::Target {
319 let buffer = self.producer.shared.buffer;
320 let payload_ptr = unsafe { buffer.add(self.offset + HEADER_SIZE) };
321 unsafe { std::slice::from_raw_parts(payload_ptr, self.len) }
322 }
323}
324
325impl DerefMut for WriteClaim<'_> {
326 #[inline]
327 fn deref_mut(&mut self) -> &mut Self::Target {
328 let buffer = self.producer.shared.buffer;
329 let payload_ptr = unsafe { buffer.add(self.offset + HEADER_SIZE) };
330 unsafe { std::slice::from_raw_parts_mut(payload_ptr, self.len) }
331 }
332}
333
334impl Drop for WriteClaim<'_> {
335 fn drop(&mut self) {
336 if !self.committed {
337 let buffer = self.producer.shared.buffer;
339 let len_ptr = unsafe { buffer.add(self.offset) }.cast::<AtomicUsize>();
340 let skip_len = self.record_size | SKIP_BIT;
341
342 fence(Ordering::Release);
343 unsafe { &*len_ptr }.store(skip_len, Ordering::Relaxed);
344
345 self.producer
347 .tail
348 .set(self.producer.tail.get().wrapping_add(self.record_size));
349 }
350 }
351}
352
353pub struct Consumer {
361 head: Cell<usize>,
363 shared: Arc<Shared>,
365}
366
367unsafe impl Send for Consumer {}
369
370impl Consumer {
371 #[inline]
379 pub fn try_claim(&mut self) -> Option<ReadClaim<'_>> {
380 let buffer = self.shared.buffer;
381
382 loop {
383 let offset = self.head.get() & self.shared.mask;
384 let len_ptr = unsafe { buffer.add(offset) }.cast::<AtomicUsize>();
385
386 let len_raw = unsafe { &*len_ptr }.load(Ordering::Relaxed);
388 fence(Ordering::Acquire);
389
390 if len_raw == 0 {
391 return None;
393 }
394
395 if len_raw & SKIP_BIT != 0 {
396 let skip_size = len_raw & LEN_MASK;
398 if skip_size > HEADER_SIZE {
400 unsafe {
401 ptr::write_bytes(
402 buffer.add(offset + HEADER_SIZE),
403 0,
404 skip_size - HEADER_SIZE,
405 );
406 }
407 }
408 fence(Ordering::Release);
410 unsafe { &*len_ptr }.store(0, Ordering::Relaxed);
411
412 self.head.set(self.head.get().wrapping_add(skip_size));
413
414 fence(Ordering::Release);
416 self.shared.head.store(self.head.get(), Ordering::Relaxed);
417
418 continue;
420 }
421
422 let len = len_raw;
424 let record_size = align8(HEADER_SIZE + len);
425
426 return Some(ReadClaim {
427 consumer: self,
428 offset,
429 len,
430 record_size,
431 });
432 }
433 }
434
435 #[inline]
437 pub fn capacity(&self) -> usize {
438 self.shared.capacity
439 }
440
441 #[inline]
445 pub fn is_disconnected(&self) -> bool {
446 Arc::strong_count(&self.shared) == 1
447 }
448}
449
450impl std::fmt::Debug for Consumer {
451 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
452 f.debug_struct("Consumer")
453 .field("capacity", &self.capacity())
454 .finish_non_exhaustive()
455 }
456}
457
458pub struct ReadClaim<'a> {
467 consumer: &'a mut Consumer,
468 offset: usize,
469 len: usize,
470 record_size: usize,
471}
472
473impl ReadClaim<'_> {
474 #[inline]
476 pub fn len(&self) -> usize {
477 self.len
478 }
479
480 #[inline]
482 pub fn is_empty(&self) -> bool {
483 self.len == 0
484 }
485}
486
487impl Deref for ReadClaim<'_> {
488 type Target = [u8];
489
490 #[inline]
491 fn deref(&self) -> &Self::Target {
492 let buffer = self.consumer.shared.buffer;
493 let payload_ptr = unsafe { buffer.add(self.offset + HEADER_SIZE) };
494 unsafe { std::slice::from_raw_parts(payload_ptr, self.len) }
495 }
496}
497
498impl Drop for ReadClaim<'_> {
499 fn drop(&mut self) {
500 let buffer = self.consumer.shared.buffer;
501
502 if self.record_size > HEADER_SIZE {
504 unsafe {
505 ptr::write_bytes(
506 buffer.add(self.offset + HEADER_SIZE),
507 0,
508 self.record_size - HEADER_SIZE,
509 );
510 }
511 }
512 fence(Ordering::Release);
514 let len_ptr = unsafe { buffer.add(self.offset) }.cast::<AtomicUsize>();
515 unsafe { &*len_ptr }.store(0, Ordering::Relaxed);
516
517 let new_head = self.consumer.head.get().wrapping_add(self.record_size);
519 self.consumer.head.set(new_head);
520
521 fence(Ordering::Release);
523 self.consumer.shared.head.store(new_head, Ordering::Relaxed);
524 }
525}
526
#[cfg(test)]
mod tests {
    use super::*;

    /// Reads every currently available record, asserting that they arrive as
    /// `msg{next:02}`, `msg{next+1:02}`, ...; returns the next expected index.
    fn drain_in_order(cons: &mut Consumer, mut next: usize) -> usize {
        while let Some(record) = cons.try_claim() {
            assert_eq!(&*record, format!("msg{next:02}").as_bytes());
            next += 1;
        }
        next
    }

    #[test]
    fn basic_write_read() {
        let (mut prod, mut cons) = new(1024);

        let payload = b"hello world";
        let mut claim = prod.try_claim(payload.len()).unwrap();
        claim.copy_from_slice(payload);
        claim.commit();

        let record = cons.try_claim().unwrap();
        assert_eq!(&*record, payload);
    }

    #[test]
    fn empty_returns_none() {
        let (_, mut cons) = new(1024);
        assert!(cons.try_claim().is_none());
    }

    #[test]
    fn multiple_records() {
        let (mut prod, mut cons) = new(1024);

        for i in 0..10 {
            let payload = format!("message {}", i);
            let mut claim = prod.try_claim(payload.len()).unwrap();
            claim.copy_from_slice(payload.as_bytes());
            claim.commit();
        }

        for i in 0..10 {
            let record = cons.try_claim().unwrap();
            let expected = format!("message {}", i);
            assert_eq!(&*record, expected.as_bytes());
        }

        assert!(cons.try_claim().is_none());
    }

    #[test]
    fn aborted_claim_creates_skip() {
        let (mut prod, mut cons) = new(1024);

        {
            // Dropped without `commit`: becomes a skip marker.
            let mut claim = prod.try_claim(10).unwrap();
            claim.copy_from_slice(b"0123456789");
        }

        {
            let mut claim = prod.try_claim(5).unwrap();
            claim.copy_from_slice(b"hello");
            claim.commit();
        }

        // The aborted record is skipped; only the committed one is visible.
        let record = cons.try_claim().unwrap();
        assert_eq!(&*record, b"hello");
        drop(record);
        assert!(cons.try_claim().is_none());
    }

    #[test]
    fn wrap_around() {
        // 5-byte payloads occupy 16 bytes, so a 64-byte buffer wraps every
        // four records.
        let (mut prod, mut cons) = new(64);
        let mut next = 0;

        for i in 0..20 {
            let payload = format!("msg{:02}", i);
            loop {
                match prod.try_claim(payload.len()) {
                    Ok(mut claim) => {
                        claim.copy_from_slice(payload.as_bytes());
                        claim.commit();
                        break;
                    }
                    Err(_) => next = drain_in_order(&mut cons, next),
                }
            }
        }

        assert_eq!(drain_in_order(&mut cons, next), 20);
    }

    #[test]
    fn wrap_around_with_padding() {
        // 12-byte payloads occupy 24 bytes. 24 does not divide 64, so wrapping
        // forces the producer to emit skip markers the consumer must honour.
        let (mut prod, mut cons) = new(64);
        let mut next = 0u8;

        for i in 0..30u8 {
            let payload = [i; 12];
            loop {
                match prod.try_claim(payload.len()) {
                    Ok(mut claim) => {
                        claim.copy_from_slice(&payload);
                        claim.commit();
                        break;
                    }
                    Err(_) => {
                        while let Some(record) = cons.try_claim() {
                            assert_eq!(&*record, &[next; 12]);
                            next += 1;
                        }
                    }
                }
            }
        }

        while let Some(record) = cons.try_claim() {
            assert_eq!(&*record, &[next; 12]);
            next += 1;
        }
        assert_eq!(next, 30);
    }

    #[test]
    fn full_returns_error() {
        let (mut prod, _cons) = new(64);

        let mut count = 0;
        while let Ok(mut claim) = prod.try_claim(8) {
            claim.copy_from_slice(b"12345678");
            claim.commit();
            count += 1;
        }

        assert!(count > 0);
        assert!(prod.try_claim(8).is_err());
    }

    #[test]
    fn cross_thread() {
        use std::thread;

        let (mut prod, mut cons) = new(4096);

        let producer = thread::spawn(move || {
            for i in 0..10_000u64 {
                let payload = i.to_le_bytes();
                loop {
                    match prod.try_claim(payload.len()) {
                        Ok(mut claim) => {
                            claim.copy_from_slice(&payload);
                            claim.commit();
                            break;
                        }
                        Err(_) => std::hint::spin_loop(),
                    }
                }
            }
        });

        let consumer = thread::spawn(move || {
            let mut received = 0u64;
            while received < 10_000 {
                if let Some(record) = cons.try_claim() {
                    let value = u64::from_le_bytes((*record).try_into().unwrap());
                    assert_eq!(value, received);
                    received += 1;
                } else {
                    std::hint::spin_loop();
                }
            }
        });

        producer.join().unwrap();
        consumer.join().unwrap();
    }

    #[test]
    fn disconnection_detection() {
        let (prod, cons) = new(1024);

        assert!(!prod.is_disconnected());
        assert!(!cons.is_disconnected());

        drop(cons);
        assert!(prod.is_disconnected());
    }

    #[test]
    #[should_panic(expected = "capacity must be at least 16")]
    fn tiny_capacity_panics() {
        let _ = new(8);
    }

    #[test]
    fn zero_len_returns_error() {
        let (mut prod, _) = new(1024);
        assert!(matches!(prod.try_claim(0), Err(TryClaimError::ZeroLength)));
    }

    #[test]
    fn capacity_rounds_to_power_of_two() {
        let (prod, _) = new(100);
        assert_eq!(prod.capacity(), 128);

        let (prod, _) = new(1000);
        assert_eq!(prod.capacity(), 1024);
    }

    #[test]
    fn variable_length_records() {
        let (mut prod, mut cons) = new(4096);

        let messages = [
            "a",
            "hello",
            "this is a longer message",
            "x",
            "medium length",
        ];

        for msg in &messages {
            let mut claim = prod.try_claim(msg.len()).unwrap();
            claim.copy_from_slice(msg.as_bytes());
            claim.commit();
        }

        for msg in &messages {
            let record = cons.try_claim().unwrap();
            assert_eq!(&*record, msg.as_bytes());
        }
    }

    #[test]
    fn stress_high_volume() {
        use std::thread;

        const COUNT: u64 = 1_000_000;
        const BUFFER_SIZE: usize = 64 * 1024;
        let (mut prod, mut cons) = new(BUFFER_SIZE);

        let producer = thread::spawn(move || {
            for i in 0..COUNT {
                // Payload lengths cycle through 8, 16, ..., 64 bytes.
                let len = 8 + ((i % 8) * 8) as usize;
                let mut payload = vec![0u8; len];
                payload[..8].copy_from_slice(&i.to_le_bytes());

                loop {
                    match prod.try_claim(len) {
                        Ok(mut claim) => {
                            claim.copy_from_slice(&payload);
                            claim.commit();
                            break;
                        }
                        Err(_) => std::hint::spin_loop(),
                    }
                }
            }
        });

        let consumer = thread::spawn(move || {
            let mut received = 0u64;
            while received < COUNT {
                if let Some(record) = cons.try_claim() {
                    let seq = u64::from_le_bytes(record[..8].try_into().unwrap());
                    assert_eq!(seq, received, "sequence mismatch at {}", received);

                    let expected_len = 8 + ((received % 8) * 8) as usize;
                    assert_eq!(
                        record.len(),
                        expected_len,
                        "length mismatch at {}",
                        received
                    );

                    received += 1;
                } else {
                    std::hint::spin_loop();
                }
            }
            received
        });

        producer.join().unwrap();
        let received = consumer.join().unwrap();
        assert_eq!(received, COUNT);
    }

    #[test]
    fn stress_high_contention() {
        use std::thread;

        const COUNT: u64 = 100_000;
        const BUFFER_SIZE: usize = 256;
        let (mut prod, mut cons) = new(BUFFER_SIZE);

        let producer = thread::spawn(move || {
            for i in 0..COUNT {
                let payload = i.to_le_bytes();
                loop {
                    match prod.try_claim(payload.len()) {
                        Ok(mut claim) => {
                            claim.copy_from_slice(&payload);
                            claim.commit();
                            break;
                        }
                        Err(_) => std::hint::spin_loop(),
                    }
                }
            }
        });

        let consumer = thread::spawn(move || {
            let mut received = 0u64;
            let mut sum = 0u64;
            while received < COUNT {
                if let Some(record) = cons.try_claim() {
                    let value = u64::from_le_bytes((*record).try_into().unwrap());
                    assert_eq!(value, received);
                    sum = sum.wrapping_add(value);
                    received += 1;
                } else {
                    std::hint::spin_loop();
                }
            }
            sum
        });

        producer.join().unwrap();
        let sum = consumer.join().unwrap();
        let expected = COUNT * (COUNT - 1) / 2;
        assert_eq!(sum, expected);
    }

    #[test]
    fn payload_is_word_aligned() {
        let (mut prod, mut cons) = new(1024);

        for len in [1, 3, 7, 8, 13, 64, 255] {
            let mut claim = prod.try_claim(len).unwrap();
            let ptr = claim.as_mut_ptr();
            assert_eq!(
                ptr as usize % std::mem::align_of::<usize>(),
                0,
                "WriteClaim payload not word-aligned for len={len}"
            );
            claim.commit();

            let record = cons.try_claim().unwrap();
            let ptr = record.as_ptr();
            assert_eq!(
                ptr as usize % std::mem::align_of::<usize>(),
                0,
                "ReadClaim payload not word-aligned for len={len}"
            );
        }
    }
}