1use std::alloc::{Layout, alloc_zeroed, dealloc, handle_alloc_error};
44use std::cell::Cell;
45use std::ops::{Deref, DerefMut};
46use std::ptr;
47use std::sync::Arc;
48use std::sync::atomic::{AtomicUsize, Ordering, fence};
49
50use crossbeam_utils::CachePadded;
51
52use crate::{LEN_MASK, SKIP_BIT, TryClaimError, align8};
53
/// Bytes occupied by each record's length header.
///
/// Headers are read and written through `AtomicUsize`, so the header is exactly one atomic
/// word wide (the same size as a `usize` on every target).
const HEADER_SIZE: usize = std::mem::size_of::<AtomicUsize>();
58
59pub fn new(capacity: usize) -> (Producer, Consumer) {
67 assert!(capacity >= 16, "capacity must be at least 16 bytes");
68
69 let capacity = capacity.next_power_of_two();
70 let mask = capacity - 1;
71
72 let layout = Layout::from_size_align(capacity, 8).unwrap();
74 let buffer_ptr = unsafe { alloc_zeroed(layout) };
75 if buffer_ptr.is_null() {
76 handle_alloc_error(layout);
77 }
78
79 let shared = Arc::new(Shared {
80 head: CachePadded::new(AtomicUsize::new(0)),
81 buffer: buffer_ptr,
82 capacity,
83 mask,
84 });
85
86 (
87 Producer {
88 tail: Cell::new(0),
89 cached_head: Cell::new(0),
90 shared: Arc::clone(&shared),
91 },
92 Consumer {
93 head: Cell::new(0),
94 shared,
95 },
96 )
97}
98
/// State shared between the [`Producer`] and [`Consumer`] halves through an `Arc`.
struct Shared {
    /// The consumer's published read position: a byte counter advanced with wrapping
    /// arithmetic, never masked. Only the consumer stores to it; the producer loads it to
    /// work out how much space is free. Wrapped in `CachePadded`, presumably to keep it off
    /// cache lines used by the other fields (TODO confirm intent).
    head: CachePadded<AtomicUsize>,
    /// Start of the zero-initialised, 8-byte-aligned allocation of `capacity` bytes that
    /// holds the records. Allocated in `new`, freed in `Drop`.
    buffer: *mut u8,
    /// Size of `buffer` in bytes; always a power of two.
    capacity: usize,
    /// `capacity - 1`; `position & mask` maps a wrapping position to a buffer offset.
    mask: usize,
}
109
// SAFETY: `Shared` is the unique owner of the allocation behind `buffer` (allocated in `new`,
// released in its `Drop`), so moving it to another thread cannot create aliasing.
unsafe impl Send for Shared {}
// SAFETY: access through `&Shared` from the two halves is coordinated by the ring protocol.
// `head` is atomic, and record headers are only touched through `AtomicUsize`. Payload bytes
// are accessed by at most one side at a time: the producer writes only into space the
// consumer has released, and the consumer reads/zeroes only records the producer has
// published. Release/acquire fence pairs order each hand-off.
unsafe impl Sync for Shared {}
114
115impl Drop for Shared {
116 fn drop(&mut self) {
117 let layout = Layout::from_size_align(self.capacity, 8).unwrap();
119 unsafe { dealloc(self.buffer, layout) };
120 }
121}
122
/// Writing half of the ring buffer, created by [`new`].
///
/// Space is reserved with [`Producer::try_claim`], which returns a [`WriteClaim`].
pub struct Producer {
    /// The producer's write position: a byte counter advanced with wrapping arithmetic.
    /// Only the producer reads or writes it.
    tail: Cell<usize>,
    /// The last value of `Shared::head` this producer loaded. It is refreshed only when the
    /// cached value suggests there is not enough free space.
    cached_head: Cell<usize>,
    /// State shared with the consumer.
    shared: Arc<Shared>,
}
138
// SAFETY: the `Cell` fields are only ever accessed by the thread that currently owns the
// `Producer`, and `Cell<usize>` is itself `Send`. `Arc<Shared>` is `Send` because `Shared`
// is `Send + Sync`, so moving the producer to another thread is sound.
// NOTE(review): given those field types, this impl looks like it would be derived
// automatically. `Producer` stays `!Sync` because of the `Cell`s, which is what keeps it
// single-threaded.
unsafe impl Send for Producer {}
141
142impl Producer {
143 #[inline]
157 pub fn try_claim(&mut self, len: usize) -> Result<WriteClaim<'_>, TryClaimError> {
158 debug_assert!(len <= LEN_MASK, "payload too large");
159 if len == 0 {
160 return Err(TryClaimError::ZeroLength);
161 }
162
163 let record_size = align8(HEADER_SIZE + len);
164
165 let tail = self.tail.get();
167 let available = self.shared.capacity - (tail.wrapping_sub(self.cached_head.get()));
168
169 if available < record_size {
170 self.cached_head.set(self.shared.head.load(Ordering::Relaxed));
172 fence(Ordering::Acquire);
173
174 let available = self.shared.capacity - (tail.wrapping_sub(self.cached_head.get()));
175 if available < record_size {
176 return Err(TryClaimError::Full);
177 }
178 }
179
180 let offset = tail & self.shared.mask;
182 let space_to_end = self.shared.capacity - offset;
183
184 if space_to_end < record_size {
185 let total_needed = space_to_end + record_size;
187 let available = self.shared.capacity - (tail.wrapping_sub(self.cached_head.get()));
188
189 if available < total_needed {
190 self.cached_head.set(self.shared.head.load(Ordering::Relaxed));
192 fence(Ordering::Acquire);
193
194 let available = self.shared.capacity - (tail.wrapping_sub(self.cached_head.get()));
195 if available < total_needed {
196 return Err(TryClaimError::Full);
197 }
198 }
199
200 let buffer = self.shared.buffer;
202 let skip_len = space_to_end | SKIP_BIT;
203 fence(Ordering::Release);
204 let len_ptr = unsafe { buffer.add(offset) }.cast::<AtomicUsize>();
205 unsafe { &*len_ptr }.store(skip_len, Ordering::Relaxed);
206
207 self.tail.set(tail.wrapping_add(space_to_end));
209 let new_offset = 0;
210
211 Ok(WriteClaim {
212 producer: self,
213 offset: new_offset,
214 len,
215 record_size,
216 committed: false,
217 })
218 } else {
219 Ok(WriteClaim {
221 producer: self,
222 offset,
223 len,
224 record_size,
225 committed: false,
226 })
227 }
228 }
229
230 #[inline]
232 pub fn capacity(&self) -> usize {
233 self.shared.capacity
234 }
235
236 #[inline]
238 pub fn is_disconnected(&self) -> bool {
239 Arc::strong_count(&self.shared) == 1
240 }
241}
242
243impl std::fmt::Debug for Producer {
244 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
245 f.debug_struct("Producer")
246 .field("capacity", &self.capacity())
247 .finish_non_exhaustive()
248 }
249}
250
/// An exclusive reservation for one record, obtained from [`Producer::try_claim`].
///
/// Dereferences to the record's payload bytes. Call [`WriteClaim::commit`] to publish the
/// record. Dropping the claim without committing turns the slot into a skip record, which the
/// consumer discards.
pub struct WriteClaim<'a> {
    /// Held mutably so at most one claim exists per producer at a time.
    producer: &'a mut Producer,
    /// Byte offset of the record's header word within the buffer.
    offset: usize,
    /// Payload length in bytes, as requested by the caller.
    len: usize,
    /// Header plus payload, rounded up by `align8`: how far the tail advances.
    record_size: usize,
    /// Set by `commit` so that `Drop` does not also write a skip record.
    committed: bool,
}
267
268impl WriteClaim<'_> {
269 #[inline]
271 pub fn commit(mut self) {
272 self.do_commit();
273 self.committed = true;
274 }
275
276 #[inline]
277 fn do_commit(&mut self) {
278 let buffer = self.producer.shared.buffer;
279 let len_ptr = unsafe { buffer.add(self.offset) }.cast::<AtomicUsize>();
280
281 fence(Ordering::Release);
283 unsafe { &*len_ptr }.store(self.len, Ordering::Relaxed);
284
285 self.producer
287 .tail
288 .set(self.producer.tail.get().wrapping_add(self.record_size));
289 }
290
291 #[inline]
293 pub fn len(&self) -> usize {
294 self.len
295 }
296
297 #[inline]
299 pub fn is_empty(&self) -> bool {
300 false
301 }
302}
303
304impl Deref for WriteClaim<'_> {
305 type Target = [u8];
306
307 #[inline]
308 fn deref(&self) -> &Self::Target {
309 let buffer = self.producer.shared.buffer;
310 let payload_ptr = unsafe { buffer.add(self.offset + HEADER_SIZE) };
311 unsafe { std::slice::from_raw_parts(payload_ptr, self.len) }
312 }
313}
314
315impl DerefMut for WriteClaim<'_> {
316 #[inline]
317 fn deref_mut(&mut self) -> &mut Self::Target {
318 let buffer = self.producer.shared.buffer;
319 let payload_ptr = unsafe { buffer.add(self.offset + HEADER_SIZE) };
320 unsafe { std::slice::from_raw_parts_mut(payload_ptr, self.len) }
321 }
322}
323
324impl Drop for WriteClaim<'_> {
325 fn drop(&mut self) {
326 if !self.committed {
327 let buffer = self.producer.shared.buffer;
329 let len_ptr = unsafe { buffer.add(self.offset) }.cast::<AtomicUsize>();
330 let skip_len = self.record_size | SKIP_BIT;
331
332 fence(Ordering::Release);
333 unsafe { &*len_ptr }.store(skip_len, Ordering::Relaxed);
334
335 self.producer
337 .tail
338 .set(self.producer.tail.get().wrapping_add(self.record_size));
339 }
340 }
341}
342
/// Reading half of the ring buffer, created by [`new`].
///
/// Records are read with [`Consumer::try_claim`], which returns a [`ReadClaim`].
pub struct Consumer {
    /// The consumer's read position: a byte counter advanced with wrapping arithmetic and
    /// copied into `Shared::head` each time a record or skip span is released.
    head: Cell<usize>,
    /// State shared with the producer.
    shared: Arc<Shared>,
}
356
// SAFETY: `head: Cell<usize>` is only accessed by the thread that currently owns the
// `Consumer`, and `Arc<Shared>` is `Send` because `Shared` is `Send + Sync`.
// NOTE(review): given those field types, this impl looks like it would be derived
// automatically. `Consumer` stays `!Sync` because of the `Cell`.
unsafe impl Send for Consumer {}
359
360impl Consumer {
361 #[inline]
369 pub fn try_claim(&mut self) -> Option<ReadClaim<'_>> {
370 let buffer = self.shared.buffer;
371
372 loop {
373 let offset = self.head.get() & self.shared.mask;
374 let len_ptr = unsafe { buffer.add(offset) }.cast::<AtomicUsize>();
375
376 let len_raw = unsafe { &*len_ptr }.load(Ordering::Relaxed);
378 fence(Ordering::Acquire);
379
380 if len_raw == 0 {
381 return None;
383 }
384
385 if len_raw & SKIP_BIT != 0 {
386 let skip_size = len_raw & LEN_MASK;
388 if skip_size > HEADER_SIZE {
390 unsafe {
391 ptr::write_bytes(
392 buffer.add(offset + HEADER_SIZE),
393 0,
394 skip_size - HEADER_SIZE,
395 );
396 }
397 }
398 fence(Ordering::Release);
400 unsafe { &*len_ptr }.store(0, Ordering::Relaxed);
401
402 self.head.set(self.head.get().wrapping_add(skip_size));
403
404 fence(Ordering::Release);
406 self.shared.head.store(self.head.get(), Ordering::Relaxed);
407
408 continue;
410 }
411
412 let len = len_raw;
414 let record_size = align8(HEADER_SIZE + len);
415
416 return Some(ReadClaim {
417 consumer: self,
418 offset,
419 len,
420 record_size,
421 });
422 }
423 }
424
425 #[inline]
427 pub fn capacity(&self) -> usize {
428 self.shared.capacity
429 }
430
431 #[inline]
433 pub fn is_disconnected(&self) -> bool {
434 Arc::strong_count(&self.shared) == 1
435 }
436}
437
438impl std::fmt::Debug for Consumer {
439 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
440 f.debug_struct("Consumer")
441 .field("capacity", &self.capacity())
442 .finish_non_exhaustive()
443 }
444}
445
/// A view of one committed record, obtained from [`Consumer::try_claim`].
///
/// Dereferences to the payload bytes. On drop, the record's bytes are zeroed and its space is
/// released back to the producer.
pub struct ReadClaim<'a> {
    /// Held mutably so at most one record is claimed per consumer at a time.
    consumer: &'a mut Consumer,
    /// Byte offset of the record's header word within the buffer.
    offset: usize,
    /// Payload length, as read from the header.
    len: usize,
    /// Header plus payload, rounded up by `align8`: how far `head` advances on release.
    record_size: usize,
}
460
461impl ReadClaim<'_> {
462 #[inline]
464 pub fn len(&self) -> usize {
465 self.len
466 }
467
468 #[inline]
470 pub fn is_empty(&self) -> bool {
471 self.len == 0
472 }
473}
474
475impl Deref for ReadClaim<'_> {
476 type Target = [u8];
477
478 #[inline]
479 fn deref(&self) -> &Self::Target {
480 let buffer = self.consumer.shared.buffer;
481 let payload_ptr = unsafe { buffer.add(self.offset + HEADER_SIZE) };
482 unsafe { std::slice::from_raw_parts(payload_ptr, self.len) }
483 }
484}
485
486impl Drop for ReadClaim<'_> {
487 fn drop(&mut self) {
488 let buffer = self.consumer.shared.buffer;
489
490 if self.record_size > HEADER_SIZE {
492 unsafe {
493 ptr::write_bytes(
494 buffer.add(self.offset + HEADER_SIZE),
495 0,
496 self.record_size - HEADER_SIZE,
497 );
498 }
499 }
500 fence(Ordering::Release);
502 let len_ptr = unsafe { buffer.add(self.offset) }.cast::<AtomicUsize>();
503 unsafe { &*len_ptr }.store(0, Ordering::Relaxed);
504
505 let new_head = self.consumer.head.get().wrapping_add(self.record_size);
507 self.consumer.head.set(new_head);
508
509 fence(Ordering::Release);
511 self.consumer.shared.head.store(new_head, Ordering::Relaxed);
512 }
513}
514
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn basic_write_read() {
        let (mut prod, mut cons) = new(1024);

        let payload = b"hello world";
        let mut claim = prod.try_claim(payload.len()).unwrap();
        claim.copy_from_slice(payload);
        claim.commit();

        let record = cons.try_claim().unwrap();
        assert_eq!(&*record, payload);
    }

    #[test]
    fn empty_returns_none() {
        let (_, mut cons) = new(1024);
        assert!(cons.try_claim().is_none());
    }

    #[test]
    fn multiple_records() {
        let (mut prod, mut cons) = new(1024);

        for i in 0..10 {
            let payload = format!("message {}", i);
            let mut claim = prod.try_claim(payload.len()).unwrap();
            claim.copy_from_slice(payload.as_bytes());
            claim.commit();
        }

        for i in 0..10 {
            let record = cons.try_claim().unwrap();
            let expected = format!("message {}", i);
            assert_eq!(&*record, expected.as_bytes());
        }

        assert!(cons.try_claim().is_none());
    }

    #[test]
    fn aborted_claim_creates_skip() {
        let (mut prod, mut cons) = new(1024);

        {
            // Dropped without commit: becomes a skip record.
            let mut claim = prod.try_claim(10).unwrap();
            claim.copy_from_slice(b"0123456789");
        }

        {
            let mut claim = prod.try_claim(5).unwrap();
            claim.copy_from_slice(b"hello");
            claim.commit();
        }

        // The consumer steps over the skip and sees only the committed record.
        let record = cons.try_claim().unwrap();
        assert_eq!(&*record, b"hello");
    }

    #[test]
    fn wrap_around() {
        let (mut prod, mut cons) = new(64);
        let mut next_expected = 0;

        for i in 0..20 {
            // 5-byte payloads -> 16-byte records, so only 4 fit at a time.
            let payload = format!("msg{:02}", i);
            loop {
                match prod.try_claim(payload.len()) {
                    Ok(mut claim) => {
                        claim.copy_from_slice(payload.as_bytes());
                        claim.commit();
                        break;
                    }
                    Err(_) => {
                        // Full: drain everything, checking order, so the producer can wrap.
                        while let Some(record) = cons.try_claim() {
                            let expected = format!("msg{:02}", next_expected);
                            assert_eq!(&*record, expected.as_bytes());
                            next_expected += 1;
                        }
                    }
                }
            }
        }

        while let Some(record) = cons.try_claim() {
            let expected = format!("msg{:02}", next_expected);
            assert_eq!(&*record, expected.as_bytes());
            next_expected += 1;
        }
        assert_eq!(next_expected, 20);
    }

    #[test]
    fn wrap_around_with_padding() {
        // 16-byte payloads -> 24-byte records, which do not divide 64. Every third record
        // forces the producer to pad the buffer tail with a skip record and restart at 0.
        let (mut prod, mut cons) = new(64);
        let mut next_expected = 0;

        for i in 0..12 {
            let payload = format!("record-number-{:02}", i);
            assert_eq!(payload.len(), 16);
            loop {
                match prod.try_claim(payload.len()) {
                    Ok(mut claim) => {
                        claim.copy_from_slice(payload.as_bytes());
                        claim.commit();
                        break;
                    }
                    Err(_) => {
                        while let Some(record) = cons.try_claim() {
                            let expected = format!("record-number-{:02}", next_expected);
                            assert_eq!(&*record, expected.as_bytes());
                            next_expected += 1;
                        }
                    }
                }
            }
        }

        while let Some(record) = cons.try_claim() {
            let expected = format!("record-number-{:02}", next_expected);
            assert_eq!(&*record, expected.as_bytes());
            next_expected += 1;
        }
        assert_eq!(next_expected, 12);
    }

    #[test]
    fn full_returns_error() {
        let (mut prod, _cons) = new(64);

        let mut count = 0;
        // Fill until the producer reports an error (nothing is consumed).
        while let Ok(mut claim) = prod.try_claim(8) {
            claim.copy_from_slice(b"12345678");
            claim.commit();
            count += 1;
        }

        assert!(count > 0);
        assert!(prod.try_claim(8).is_err());
    }

    #[test]
    fn cross_thread() {
        use std::thread;

        let (mut prod, mut cons) = new(4096);

        let producer = thread::spawn(move || {
            for i in 0..10_000u64 {
                let payload = i.to_le_bytes();
                loop {
                    match prod.try_claim(payload.len()) {
                        Ok(mut claim) => {
                            claim.copy_from_slice(&payload);
                            claim.commit();
                            break;
                        }
                        Err(_) => std::hint::spin_loop(),
                    }
                }
            }
        });

        let consumer = thread::spawn(move || {
            let mut received = 0u64;
            while received < 10_000 {
                if let Some(record) = cons.try_claim() {
                    let value = u64::from_le_bytes((*record).try_into().unwrap());
                    assert_eq!(value, received);
                    received += 1;
                } else {
                    std::hint::spin_loop();
                }
            }
        });

        producer.join().unwrap();
        consumer.join().unwrap();
    }

    #[test]
    fn disconnection_detection() {
        let (prod, cons) = new(1024);

        assert!(!prod.is_disconnected());
        assert!(!cons.is_disconnected());

        drop(cons);
        assert!(prod.is_disconnected());
    }

    #[test]
    #[should_panic(expected = "capacity must be at least 16")]
    fn tiny_capacity_panics() {
        let _ = new(8);
    }

    #[test]
    fn zero_len_returns_error() {
        let (mut prod, _) = new(1024);
        assert!(matches!(prod.try_claim(0), Err(TryClaimError::ZeroLength)));
    }

    #[test]
    fn capacity_rounds_to_power_of_two() {
        let (prod, _) = new(100);
        assert_eq!(prod.capacity(), 128);

        let (prod, _) = new(1000);
        assert_eq!(prod.capacity(), 1024);
    }

    #[test]
    fn variable_length_records() {
        let (mut prod, mut cons) = new(4096);

        let messages = [
            "a",
            "hello",
            "this is a longer message",
            "x",
            "medium length",
        ];

        for msg in &messages {
            let mut claim = prod.try_claim(msg.len()).unwrap();
            claim.copy_from_slice(msg.as_bytes());
            claim.commit();
        }

        for msg in &messages {
            let record = cons.try_claim().unwrap();
            assert_eq!(&*record, msg.as_bytes());
        }
    }

    #[test]
    fn stress_high_volume() {
        use std::thread;

        const COUNT: u64 = 1_000_000;
        const BUFFER_SIZE: usize = 64 * 1024;
        let (mut prod, mut cons) = new(BUFFER_SIZE);

        let producer = thread::spawn(move || {
            for i in 0..COUNT {
                // Lengths cycle through 8, 16, ..., 64 bytes; the first 8 carry the sequence.
                let len = 8 + ((i % 8) * 8) as usize;
                let mut payload = vec![0u8; len];
                payload[..8].copy_from_slice(&i.to_le_bytes());

                loop {
                    match prod.try_claim(len) {
                        Ok(mut claim) => {
                            claim.copy_from_slice(&payload);
                            claim.commit();
                            break;
                        }
                        Err(_) => std::hint::spin_loop(),
                    }
                }
            }
        });

        let consumer = thread::spawn(move || {
            let mut received = 0u64;
            while received < COUNT {
                if let Some(record) = cons.try_claim() {
                    let seq = u64::from_le_bytes(record[..8].try_into().unwrap());
                    assert_eq!(seq, received, "sequence mismatch at {}", received);

                    let expected_len = 8 + ((received % 8) * 8) as usize;
                    assert_eq!(
                        record.len(),
                        expected_len,
                        "length mismatch at {}",
                        received
                    );

                    received += 1;
                } else {
                    std::hint::spin_loop();
                }
            }
            received
        });

        producer.join().unwrap();
        let received = consumer.join().unwrap();
        assert_eq!(received, COUNT);
    }

    #[test]
    fn stress_high_contention() {
        use std::thread;

        const COUNT: u64 = 100_000;
        // A tiny buffer keeps the producer and consumer constantly bumping into each other.
        const BUFFER_SIZE: usize = 256;
        let (mut prod, mut cons) = new(BUFFER_SIZE);

        let producer = thread::spawn(move || {
            for i in 0..COUNT {
                let payload = i.to_le_bytes();
                loop {
                    match prod.try_claim(payload.len()) {
                        Ok(mut claim) => {
                            claim.copy_from_slice(&payload);
                            claim.commit();
                            break;
                        }
                        Err(_) => std::hint::spin_loop(),
                    }
                }
            }
        });

        let consumer = thread::spawn(move || {
            let mut received = 0u64;
            let mut sum = 0u64;
            while received < COUNT {
                if let Some(record) = cons.try_claim() {
                    let value = u64::from_le_bytes((*record).try_into().unwrap());
                    assert_eq!(value, received);
                    sum = sum.wrapping_add(value);
                    received += 1;
                } else {
                    std::hint::spin_loop();
                }
            }
            sum
        });

        producer.join().unwrap();
        let sum = consumer.join().unwrap();
        let expected = COUNT * (COUNT - 1) / 2;
        assert_eq!(sum, expected);
    }

    #[test]
    fn payload_is_word_aligned() {
        let (mut prod, mut cons) = new(1024);

        for len in [1, 3, 7, 8, 13, 64, 255] {
            let mut claim = prod.try_claim(len).unwrap();
            let ptr = claim.as_mut_ptr();
            assert_eq!(
                ptr as usize % std::mem::align_of::<usize>(),
                0,
                "WriteClaim payload not word-aligned for len={len}"
            );
            claim.commit();

            let record = cons.try_claim().unwrap();
            let ptr = record.as_ptr();
            assert_eq!(
                ptr as usize % std::mem::align_of::<usize>(),
                0,
                "ReadClaim payload not word-aligned for len={len}"
            );
        }
    }
}