1use std::alloc::{Layout, alloc_zeroed, dealloc, handle_alloc_error};
44use std::cell::Cell;
45use std::ops::{Deref, DerefMut};
46use std::ptr;
47use std::sync::Arc;
48use std::sync::atomic::{AtomicUsize, Ordering, fence};
49
50use crossbeam_utils::CachePadded;
51
52use crate::{LEN_MASK, SKIP_BIT, TryClaimError, align8};
53
/// Bytes occupied by each record's header word.
///
/// The header is read and written as an `AtomicUsize` and holds either the payload length or a
/// skip marker (`SKIP_BIT` | size). `AtomicUsize` has the same size as `usize`.
const HEADER_SIZE: usize = std::mem::size_of::<AtomicUsize>();
58
59pub fn new(capacity: usize) -> (Producer, Consumer) {
67 assert!(capacity >= 16, "capacity must be at least 16 bytes");
68
69 let capacity = capacity.next_power_of_two();
70 let mask = capacity - 1;
71
72 let layout = Layout::from_size_align(capacity, 8)
74 .expect("valid layout: capacity is a power of two >= 16, align is 8");
75 let buffer_ptr = unsafe { alloc_zeroed(layout) };
76 if buffer_ptr.is_null() {
77 handle_alloc_error(layout);
78 }
79
80 let shared = Arc::new(Shared {
81 head: CachePadded::new(AtomicUsize::new(0)),
82 buffer: buffer_ptr,
83 capacity,
84 mask,
85 });
86
87 (
88 Producer {
89 tail: Cell::new(0),
90 cached_head: Cell::new(0),
91 shared: Arc::clone(&shared),
92 },
93 Consumer {
94 head: Cell::new(0),
95 shared,
96 },
97 )
98}
99
100struct Shared {
101 head: CachePadded<AtomicUsize>,
103 buffer: *mut u8,
105 capacity: usize,
107 mask: usize,
109}
110
111unsafe impl Send for Shared {}
114unsafe impl Sync for Shared {}
115
116impl Drop for Shared {
117 fn drop(&mut self) {
118 let layout = Layout::from_size_align(self.capacity, 8)
120 .expect("valid layout: capacity was validated at construction");
121 unsafe { dealloc(self.buffer, layout) };
122 }
123}
124
125pub struct Producer {
133 tail: Cell<usize>,
135 cached_head: Cell<usize>,
137 shared: Arc<Shared>,
139}
140
141unsafe impl Send for Producer {}
143
144impl Producer {
145 #[inline]
159 pub fn try_claim(&mut self, len: usize) -> Result<WriteClaim<'_>, TryClaimError> {
160 debug_assert!(len <= LEN_MASK, "payload too large");
161 if len == 0 {
162 return Err(TryClaimError::ZeroLength);
163 }
164
165 let record_size = align8(HEADER_SIZE + len);
166
167 let tail = self.tail.get();
169 let available = self.shared.capacity - (tail.wrapping_sub(self.cached_head.get()));
170
171 if available < record_size {
172 self.cached_head
174 .set(self.shared.head.load(Ordering::Relaxed));
175 fence(Ordering::Acquire);
176
177 let available = self.shared.capacity - (tail.wrapping_sub(self.cached_head.get()));
178 if available < record_size {
179 return Err(TryClaimError::Full);
180 }
181 }
182
183 let offset = tail & self.shared.mask;
185 let space_to_end = self.shared.capacity - offset;
186
187 if space_to_end < record_size {
188 let total_needed = space_to_end + record_size;
190 let available = self.shared.capacity - (tail.wrapping_sub(self.cached_head.get()));
191
192 if available < total_needed {
193 self.cached_head
195 .set(self.shared.head.load(Ordering::Relaxed));
196 fence(Ordering::Acquire);
197
198 let available = self.shared.capacity - (tail.wrapping_sub(self.cached_head.get()));
199 if available < total_needed {
200 return Err(TryClaimError::Full);
201 }
202 }
203
204 let buffer = self.shared.buffer;
206 let skip_len = space_to_end | SKIP_BIT;
207 fence(Ordering::Release);
208 let len_ptr = unsafe { buffer.add(offset) }.cast::<AtomicUsize>();
209 unsafe { &*len_ptr }.store(skip_len, Ordering::Relaxed);
210
211 self.tail.set(tail.wrapping_add(space_to_end));
213 let new_offset = 0;
214
215 Ok(WriteClaim {
216 producer: self,
217 offset: new_offset,
218 len,
219 record_size,
220 committed: false,
221 })
222 } else {
223 Ok(WriteClaim {
225 producer: self,
226 offset,
227 len,
228 record_size,
229 committed: false,
230 })
231 }
232 }
233
234 #[inline]
236 pub fn capacity(&self) -> usize {
237 self.shared.capacity
238 }
239
240 #[inline]
250 pub fn is_disconnected(&self) -> bool {
251 Arc::strong_count(&self.shared) == 1
252 }
253}
254
255impl std::fmt::Debug for Producer {
256 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
257 f.debug_struct("Producer")
258 .field("capacity", &self.capacity())
259 .finish_non_exhaustive()
260 }
261}
262
263pub struct WriteClaim<'a> {
273 producer: &'a mut Producer,
274 offset: usize,
275 len: usize,
276 record_size: usize,
277 committed: bool,
278}
279
280impl WriteClaim<'_> {
281 #[inline]
283 pub fn commit(mut self) {
284 self.do_commit();
285 self.committed = true;
286 }
287
288 #[inline]
289 fn do_commit(&mut self) {
290 let buffer = self.producer.shared.buffer;
291 let len_ptr = unsafe { buffer.add(self.offset) }.cast::<AtomicUsize>();
292
293 fence(Ordering::Release);
295 unsafe { &*len_ptr }.store(self.len, Ordering::Relaxed);
296
297 self.producer
299 .tail
300 .set(self.producer.tail.get().wrapping_add(self.record_size));
301 }
302
303 #[inline]
305 pub fn len(&self) -> usize {
306 self.len
307 }
308
309 #[inline]
311 pub fn is_empty(&self) -> bool {
312 false
313 }
314}
315
316impl Deref for WriteClaim<'_> {
317 type Target = [u8];
318
319 #[inline]
320 fn deref(&self) -> &Self::Target {
321 let buffer = self.producer.shared.buffer;
322 let payload_ptr = unsafe { buffer.add(self.offset + HEADER_SIZE) };
323 unsafe { std::slice::from_raw_parts(payload_ptr, self.len) }
324 }
325}
326
327impl DerefMut for WriteClaim<'_> {
328 #[inline]
329 fn deref_mut(&mut self) -> &mut Self::Target {
330 let buffer = self.producer.shared.buffer;
331 let payload_ptr = unsafe { buffer.add(self.offset + HEADER_SIZE) };
332 unsafe { std::slice::from_raw_parts_mut(payload_ptr, self.len) }
333 }
334}
335
336impl Drop for WriteClaim<'_> {
337 fn drop(&mut self) {
338 if !self.committed {
339 let buffer = self.producer.shared.buffer;
341 let len_ptr = unsafe { buffer.add(self.offset) }.cast::<AtomicUsize>();
342 let skip_len = self.record_size | SKIP_BIT;
343
344 fence(Ordering::Release);
345 unsafe { &*len_ptr }.store(skip_len, Ordering::Relaxed);
346
347 self.producer
349 .tail
350 .set(self.producer.tail.get().wrapping_add(self.record_size));
351 }
352 }
353}
354
355pub struct Consumer {
363 head: Cell<usize>,
365 shared: Arc<Shared>,
367}
368
369unsafe impl Send for Consumer {}
371
372impl Consumer {
373 #[inline]
381 pub fn try_claim(&mut self) -> Option<ReadClaim<'_>> {
382 let buffer = self.shared.buffer;
383
384 loop {
385 let offset = self.head.get() & self.shared.mask;
386 let len_ptr = unsafe { buffer.add(offset) }.cast::<AtomicUsize>();
387
388 let len_raw = unsafe { &*len_ptr }.load(Ordering::Relaxed);
390 fence(Ordering::Acquire);
391
392 if len_raw == 0 {
393 return None;
395 }
396
397 if len_raw & SKIP_BIT != 0 {
398 let skip_size = len_raw & LEN_MASK;
400 if skip_size > HEADER_SIZE {
402 unsafe {
403 ptr::write_bytes(
404 buffer.add(offset + HEADER_SIZE),
405 0,
406 skip_size - HEADER_SIZE,
407 );
408 }
409 }
410 fence(Ordering::Release);
412 unsafe { &*len_ptr }.store(0, Ordering::Relaxed);
413
414 self.head.set(self.head.get().wrapping_add(skip_size));
415
416 fence(Ordering::Release);
418 self.shared.head.store(self.head.get(), Ordering::Relaxed);
419
420 continue;
422 }
423
424 let len = len_raw;
426 let record_size = align8(HEADER_SIZE + len);
427
428 return Some(ReadClaim {
429 consumer: self,
430 offset,
431 len,
432 record_size,
433 });
434 }
435 }
436
437 #[inline]
439 pub fn capacity(&self) -> usize {
440 self.shared.capacity
441 }
442
443 #[inline]
447 pub fn is_disconnected(&self) -> bool {
448 Arc::strong_count(&self.shared) == 1
449 }
450}
451
452impl std::fmt::Debug for Consumer {
453 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
454 f.debug_struct("Consumer")
455 .field("capacity", &self.capacity())
456 .finish_non_exhaustive()
457 }
458}
459
460pub struct ReadClaim<'a> {
469 consumer: &'a mut Consumer,
470 offset: usize,
471 len: usize,
472 record_size: usize,
473}
474
475impl ReadClaim<'_> {
476 #[inline]
478 pub fn len(&self) -> usize {
479 self.len
480 }
481
482 #[inline]
484 pub fn is_empty(&self) -> bool {
485 self.len == 0
486 }
487}
488
489impl Deref for ReadClaim<'_> {
490 type Target = [u8];
491
492 #[inline]
493 fn deref(&self) -> &Self::Target {
494 let buffer = self.consumer.shared.buffer;
495 let payload_ptr = unsafe { buffer.add(self.offset + HEADER_SIZE) };
496 unsafe { std::slice::from_raw_parts(payload_ptr, self.len) }
497 }
498}
499
500impl Drop for ReadClaim<'_> {
501 fn drop(&mut self) {
502 let buffer = self.consumer.shared.buffer;
503
504 if self.record_size > HEADER_SIZE {
506 unsafe {
507 ptr::write_bytes(
508 buffer.add(self.offset + HEADER_SIZE),
509 0,
510 self.record_size - HEADER_SIZE,
511 );
512 }
513 }
514 fence(Ordering::Release);
516 let len_ptr = unsafe { buffer.add(self.offset) }.cast::<AtomicUsize>();
517 unsafe { &*len_ptr }.store(0, Ordering::Relaxed);
518
519 let new_head = self.consumer.head.get().wrapping_add(self.record_size);
521 self.consumer.head.set(new_head);
522
523 fence(Ordering::Release);
525 self.consumer.shared.head.store(new_head, Ordering::Relaxed);
526 }
527}
528
#[cfg(test)]
mod tests {
    use super::*;

    /// Reads every record currently available and asserts each is the next `msgNN` payload,
    /// starting at index `next`. Returns the index of the next message still expected.
    fn drain_expecting_sequence(cons: &mut Consumer, mut next: usize) -> usize {
        while let Some(record) = cons.try_claim() {
            assert_eq!(&*record, format!("msg{:02}", next).as_bytes());
            next += 1;
        }
        next
    }

    #[test]
    fn basic_write_read() {
        let (mut prod, mut cons) = new(1024);

        let payload = b"hello world";
        let mut claim = prod.try_claim(payload.len()).unwrap();
        claim.copy_from_slice(payload);
        claim.commit();

        let record = cons.try_claim().unwrap();
        assert_eq!(&*record, payload);
    }

    #[test]
    fn empty_returns_none() {
        let (_, mut cons) = new(1024);
        assert!(cons.try_claim().is_none());
    }

    #[test]
    fn multiple_records() {
        let (mut prod, mut cons) = new(1024);

        for i in 0..10 {
            let payload = format!("message {}", i);
            let mut claim = prod.try_claim(payload.len()).unwrap();
            claim.copy_from_slice(payload.as_bytes());
            claim.commit();
        }

        for i in 0..10 {
            let record = cons.try_claim().unwrap();
            let expected = format!("message {}", i);
            assert_eq!(&*record, expected.as_bytes());
        }

        assert!(cons.try_claim().is_none());
    }

    #[test]
    fn aborted_claim_creates_skip() {
        let (mut prod, mut cons) = new(1024);

        {
            // Dropped without `commit`: becomes a skip record the consumer steps over.
            let mut claim = prod.try_claim(10).unwrap();
            claim.copy_from_slice(b"0123456789");
        }

        {
            let mut claim = prod.try_claim(5).unwrap();
            claim.copy_from_slice(b"hello");
            claim.commit();
        }

        let record = cons.try_claim().unwrap();
        assert_eq!(&*record, b"hello");
        drop(record);

        // The aborted payload must never surface as a record.
        assert!(cons.try_claim().is_none());
    }

    #[test]
    fn wrap_around() {
        let (mut prod, mut cons) = new(64);
        let mut received = 0;

        for i in 0..20 {
            let payload = format!("msg{:02}", i);
            loop {
                match prod.try_claim(payload.len()) {
                    Ok(mut claim) => {
                        claim.copy_from_slice(payload.as_bytes());
                        claim.commit();
                        break;
                    }
                    Err(err) => {
                        assert!(matches!(err, TryClaimError::Full));
                        received = drain_expecting_sequence(&mut cons, received);
                    }
                }
            }
        }

        received = drain_expecting_sequence(&mut cons, received);
        assert_eq!(received, 20);
    }

    #[test]
    fn wrap_inserts_padding_transparently() {
        // Two 24-byte records leave 16 bytes before the end of a 64-byte buffer, so a third
        // record of the same size must wrap to offset 0 behind a padding skip record.
        let (mut prod, mut cons) = new(64);

        for byte in [1u8, 2] {
            let mut claim = prod.try_claim(16).unwrap();
            claim.fill(byte);
            claim.commit();
        }
        for byte in [1u8, 2] {
            let record = cons.try_claim().unwrap();
            assert!(record.iter().all(|&b| b == byte));
        }

        let mut claim = prod.try_claim(16).unwrap();
        claim.fill(3);
        claim.commit();

        let record = cons.try_claim().unwrap();
        assert_eq!(record.len(), 16);
        assert!(record.iter().all(|&b| b == 3));
        drop(record);

        assert!(cons.try_claim().is_none());
    }

    #[test]
    fn full_returns_error() {
        let (mut prod, _cons) = new(64);

        let mut count = 0;
        // Each record is the header plus an 8-byte payload, rounded up to 16 bytes.
        while let Ok(mut claim) = prod.try_claim(8) {
            claim.copy_from_slice(b"12345678");
            claim.commit();
            count += 1;
        }

        assert_eq!(count, 4);
        assert!(matches!(prod.try_claim(8), Err(TryClaimError::Full)));
    }

    #[test]
    fn oversized_payload_returns_full() {
        let (mut prod, mut cons) = new(64);
        let max = prod.capacity() - HEADER_SIZE;

        assert!(matches!(prod.try_claim(max + 1), Err(TryClaimError::Full)));

        let mut claim = prod.try_claim(max).unwrap();
        claim.fill(0xAB);
        claim.commit();

        let record = cons.try_claim().unwrap();
        assert_eq!(record.len(), max);
        assert!(record.iter().all(|&b| b == 0xAB));
    }

    #[test]
    fn cross_thread() {
        use std::thread;

        let (mut prod, mut cons) = new(4096);

        let producer = thread::spawn(move || {
            for i in 0..10_000u64 {
                let payload = i.to_le_bytes();
                loop {
                    match prod.try_claim(payload.len()) {
                        Ok(mut claim) => {
                            claim.copy_from_slice(&payload);
                            claim.commit();
                            break;
                        }
                        Err(_) => std::hint::spin_loop(),
                    }
                }
            }
        });

        let consumer = thread::spawn(move || {
            let mut received = 0u64;
            while received < 10_000 {
                if let Some(record) = cons.try_claim() {
                    let value = u64::from_le_bytes((*record).try_into().unwrap());
                    assert_eq!(value, received);
                    received += 1;
                } else {
                    std::hint::spin_loop();
                }
            }
        });

        producer.join().unwrap();
        consumer.join().unwrap();
    }

    #[test]
    fn disconnection_detection() {
        let (prod, cons) = new(1024);

        assert!(!prod.is_disconnected());
        assert!(!cons.is_disconnected());

        drop(cons);
        assert!(prod.is_disconnected());
    }

    #[test]
    #[should_panic(expected = "capacity must be at least 16")]
    fn tiny_capacity_panics() {
        let _ = new(8);
    }

    #[test]
    fn zero_len_returns_error() {
        let (mut prod, _) = new(1024);
        assert!(matches!(prod.try_claim(0), Err(TryClaimError::ZeroLength)));
    }

    #[test]
    fn capacity_rounds_to_power_of_two() {
        let (prod, _) = new(100);
        assert_eq!(prod.capacity(), 128);

        let (prod, _) = new(1000);
        assert_eq!(prod.capacity(), 1024);
    }

    #[test]
    fn variable_length_records() {
        let (mut prod, mut cons) = new(4096);

        let messages = [
            "a",
            "hello",
            "this is a longer message",
            "x",
            "medium length",
        ];

        for msg in &messages {
            let mut claim = prod.try_claim(msg.len()).unwrap();
            claim.copy_from_slice(msg.as_bytes());
            claim.commit();
        }

        for msg in &messages {
            let record = cons.try_claim().unwrap();
            assert_eq!(&*record, msg.as_bytes());
        }
    }

    #[test]
    fn stress_high_volume() {
        use std::thread;

        const COUNT: u64 = 1_000_000;
        const BUFFER_SIZE: usize = 64 * 1024;

        let (mut prod, mut cons) = new(BUFFER_SIZE);

        let producer = thread::spawn(move || {
            for i in 0..COUNT {
                let len = 8 + ((i % 8) * 8) as usize;
                let mut payload = vec![0u8; len];
                payload[..8].copy_from_slice(&i.to_le_bytes());

                loop {
                    match prod.try_claim(len) {
                        Ok(mut claim) => {
                            claim.copy_from_slice(&payload);
                            claim.commit();
                            break;
                        }
                        Err(_) => std::hint::spin_loop(),
                    }
                }
            }
        });

        let consumer = thread::spawn(move || {
            let mut received = 0u64;
            while received < COUNT {
                if let Some(record) = cons.try_claim() {
                    let seq = u64::from_le_bytes(record[..8].try_into().unwrap());
                    assert_eq!(seq, received, "sequence mismatch at {}", received);

                    let expected_len = 8 + ((received % 8) * 8) as usize;
                    assert_eq!(
                        record.len(),
                        expected_len,
                        "length mismatch at {}",
                        received
                    );

                    received += 1;
                } else {
                    std::hint::spin_loop();
                }
            }
            received
        });

        producer.join().unwrap();
        let received = consumer.join().unwrap();
        assert_eq!(received, COUNT);
    }

    #[test]
    fn stress_high_contention() {
        use std::thread;

        const COUNT: u64 = 100_000;
        const BUFFER_SIZE: usize = 256;

        let (mut prod, mut cons) = new(BUFFER_SIZE);

        let producer = thread::spawn(move || {
            for i in 0..COUNT {
                let payload = i.to_le_bytes();
                loop {
                    match prod.try_claim(payload.len()) {
                        Ok(mut claim) => {
                            claim.copy_from_slice(&payload);
                            claim.commit();
                            break;
                        }
                        Err(_) => std::hint::spin_loop(),
                    }
                }
            }
        });

        let consumer = thread::spawn(move || {
            let mut received = 0u64;
            let mut sum = 0u64;
            while received < COUNT {
                if let Some(record) = cons.try_claim() {
                    let value = u64::from_le_bytes((*record).try_into().unwrap());
                    assert_eq!(value, received);
                    sum = sum.wrapping_add(value);
                    received += 1;
                } else {
                    std::hint::spin_loop();
                }
            }
            sum
        });

        producer.join().unwrap();
        let sum = consumer.join().unwrap();
        let expected = COUNT * (COUNT - 1) / 2;
        assert_eq!(sum, expected);
    }

    #[test]
    fn payload_is_word_aligned() {
        let (mut prod, mut cons) = new(1024);

        for len in [1, 3, 7, 8, 13, 64, 255] {
            let mut claim = prod.try_claim(len).unwrap();
            let ptr = claim.as_mut_ptr();
            assert_eq!(
                ptr as usize % std::mem::align_of::<usize>(),
                0,
                "WriteClaim payload not word-aligned for len={len}"
            );
            claim.commit();

            let record = cons.try_claim().unwrap();
            let ptr = record.as_ptr();
            assert_eq!(
                ptr as usize % std::mem::align_of::<usize>(),
                0,
                "ReadClaim payload not word-aligned for len={len}"
            );
        }
    }
}