1use std::alloc::{Layout, alloc_zeroed, dealloc, handle_alloc_error};
44use std::ops::{Deref, DerefMut};
45use std::ptr;
46use std::sync::Arc;
47use std::sync::atomic::{AtomicUsize, Ordering, fence};
48
49use crossbeam_utils::CachePadded;
50
51use crate::{LEN_MASK, SKIP_BIT, TryClaimError, align8};
52
/// Size in bytes of the length word that starts every record. The word is
/// always read and written through an [`AtomicUsize`], so its size is taken
/// from that type (identical to `usize`).
const HEADER_SIZE: usize = std::mem::size_of::<AtomicUsize>();
57
58pub fn new(capacity: usize) -> (Producer, Consumer) {
66 assert!(capacity >= 16, "capacity must be at least 16 bytes");
67
68 let capacity = capacity.next_power_of_two();
69 let mask = capacity - 1;
70
71 let layout = Layout::from_size_align(capacity, 8).unwrap();
73 let buffer_ptr = unsafe { alloc_zeroed(layout) };
74 if buffer_ptr.is_null() {
75 handle_alloc_error(layout);
76 }
77
78 let shared = Arc::new(Shared {
79 head: CachePadded::new(AtomicUsize::new(0)),
80 buffer: buffer_ptr,
81 capacity,
82 mask,
83 });
84
85 (
86 Producer {
87 tail: 0,
88 cached_head: 0,
89 shared: Arc::clone(&shared),
90 },
91 Consumer { head: 0, shared },
92 )
93}
94
/// State shared by the [`Producer`] and [`Consumer`] halves.
///
/// The buffer holds a sequence of records. Each record starts with a
/// `HEADER_SIZE`-byte length word that is accessed atomically; a zero word
/// means "nothing published here yet", and a word with `SKIP_BIT` set marks
/// bytes the consumer should discard.
struct Shared {
    /// Read position published by the consumer (a wrapping byte counter).
    /// Written by the consumer after it releases a record, read by the
    /// producer to compute free space.
    head: CachePadded<AtomicUsize>,
    /// Start of the zero-initialized ring allocation (`capacity` bytes,
    /// 8-byte aligned), freed in `Drop`.
    buffer: *mut u8,
    /// Size of `buffer` in bytes; a power of two, at least 16.
    capacity: usize,
    /// `capacity - 1`, used to turn a wrapping position into a buffer offset.
    mask: usize,
}
105
// SAFETY: `Shared` owns the allocation behind `buffer` (released in its
// `Drop` impl); nothing about that allocation is tied to the creating thread.
unsafe impl Send for Shared {}
// SAFETY: after construction, `capacity`, `mask` and the `buffer` pointer
// itself are only read. The bytes behind `buffer` are accessed solely by the
// `Producer`/`Consumer` halves, which coordinate through the atomic record
// headers and the atomic `head`.
unsafe impl Sync for Shared {}
110
111impl Drop for Shared {
112 fn drop(&mut self) {
113 let layout = Layout::from_size_align(self.capacity, 8).unwrap();
115 unsafe { dealloc(self.buffer, layout) };
116 }
117}
118
/// Writing half of the ring buffer, created by [`new`].
///
/// Reserve space with [`Producer::try_claim`], fill the returned
/// [`WriteClaim`], then commit it.
pub struct Producer {
    /// Producer's write position (wrapping byte counter); owned by this half.
    tail: usize,
    /// Last `Shared::head` value observed; reloaded only when the cached
    /// value does not prove there is enough free space.
    cached_head: usize,
    shared: Arc<Shared>,
}
134
// NOTE(review): all of `Producer`'s fields (`usize`, `usize`, `Arc<Shared>`)
// are already `Send` given the `Send + Sync` impls on `Shared`, so this impl
// appears redundant with the auto-trait; it documents intent explicitly.
unsafe impl Send for Producer {}
137
138impl Producer {
139 #[inline]
153 pub fn try_claim(&mut self, len: usize) -> Result<WriteClaim<'_>, TryClaimError> {
154 debug_assert!(len <= LEN_MASK, "payload too large");
155 if len == 0 {
156 return Err(TryClaimError::ZeroLength);
157 }
158
159 let record_size = align8(HEADER_SIZE + len);
160
161 let tail = self.tail;
163 let available = self.shared.capacity - (tail.wrapping_sub(self.cached_head));
164
165 if available < record_size {
166 self.cached_head = self.shared.head.load(Ordering::Relaxed);
168 fence(Ordering::Acquire);
169
170 let available = self.shared.capacity - (tail.wrapping_sub(self.cached_head));
171 if available < record_size {
172 return Err(TryClaimError::Full);
173 }
174 }
175
176 let offset = tail & self.shared.mask;
178 let space_to_end = self.shared.capacity - offset;
179
180 if space_to_end < record_size {
181 let total_needed = space_to_end + record_size;
183 let available = self.shared.capacity - (tail.wrapping_sub(self.cached_head));
184
185 if available < total_needed {
186 self.cached_head = self.shared.head.load(Ordering::Relaxed);
188 fence(Ordering::Acquire);
189
190 let available = self.shared.capacity - (tail.wrapping_sub(self.cached_head));
191 if available < total_needed {
192 return Err(TryClaimError::Full);
193 }
194 }
195
196 let buffer = self.shared.buffer;
198 let skip_len = space_to_end | SKIP_BIT;
199 fence(Ordering::Release);
200 let len_ptr = unsafe { buffer.add(offset) }.cast::<AtomicUsize>();
201 unsafe { &*len_ptr }.store(skip_len, Ordering::Relaxed);
202
203 self.tail = tail.wrapping_add(space_to_end);
205 let new_offset = 0;
206
207 Ok(WriteClaim {
208 producer: self,
209 offset: new_offset,
210 len,
211 record_size,
212 committed: false,
213 })
214 } else {
215 Ok(WriteClaim {
217 producer: self,
218 offset,
219 len,
220 record_size,
221 committed: false,
222 })
223 }
224 }
225
226 #[inline]
228 pub fn capacity(&self) -> usize {
229 self.shared.capacity
230 }
231
232 #[inline]
234 pub fn is_disconnected(&self) -> bool {
235 Arc::strong_count(&self.shared) == 1
236 }
237}
238
239impl std::fmt::Debug for Producer {
240 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
241 f.debug_struct("Producer")
242 .field("capacity", &self.capacity())
243 .finish_non_exhaustive()
244 }
245}
246
/// A write reservation returned by [`Producer::try_claim`].
///
/// Dereferences to the `len`-byte payload. [`WriteClaim::commit`] publishes
/// the record; dropping the claim uncommitted publishes a skip record instead.
pub struct WriteClaim<'a> {
    /// Exclusive borrow of the producer; its `tail` advances on commit or drop.
    producer: &'a mut Producer,
    /// Byte offset of the record header inside the buffer.
    offset: usize,
    /// Payload length requested by the caller.
    len: usize,
    /// Total bytes reserved: header plus payload, rounded up by `align8`.
    record_size: usize,
    /// Set by `commit` so that `Drop` does not also publish a skip record.
    committed: bool,
}
263
264impl WriteClaim<'_> {
265 #[inline]
267 pub fn commit(mut self) {
268 self.do_commit();
269 self.committed = true;
270 }
271
272 #[inline]
273 fn do_commit(&mut self) {
274 let buffer = self.producer.shared.buffer;
275 let len_ptr = unsafe { buffer.add(self.offset) }.cast::<AtomicUsize>();
276
277 fence(Ordering::Release);
279 unsafe { &*len_ptr }.store(self.len, Ordering::Relaxed);
280
281 self.producer.tail = self.producer.tail.wrapping_add(self.record_size);
283 }
284
285 #[inline]
287 pub fn len(&self) -> usize {
288 self.len
289 }
290
291 #[inline]
293 pub fn is_empty(&self) -> bool {
294 false
295 }
296}
297
298impl Deref for WriteClaim<'_> {
299 type Target = [u8];
300
301 #[inline]
302 fn deref(&self) -> &Self::Target {
303 let buffer = self.producer.shared.buffer;
304 let payload_ptr = unsafe { buffer.add(self.offset + HEADER_SIZE) };
305 unsafe { std::slice::from_raw_parts(payload_ptr, self.len) }
306 }
307}
308
309impl DerefMut for WriteClaim<'_> {
310 #[inline]
311 fn deref_mut(&mut self) -> &mut Self::Target {
312 let buffer = self.producer.shared.buffer;
313 let payload_ptr = unsafe { buffer.add(self.offset + HEADER_SIZE) };
314 unsafe { std::slice::from_raw_parts_mut(payload_ptr, self.len) }
315 }
316}
317
318impl Drop for WriteClaim<'_> {
319 fn drop(&mut self) {
320 if !self.committed {
321 let buffer = self.producer.shared.buffer;
323 let len_ptr = unsafe { buffer.add(self.offset) }.cast::<AtomicUsize>();
324 let skip_len = self.record_size | SKIP_BIT;
325
326 fence(Ordering::Release);
327 unsafe { &*len_ptr }.store(skip_len, Ordering::Relaxed);
328
329 self.producer.tail = self.producer.tail.wrapping_add(self.record_size);
331 }
332 }
333}
334
/// Reading half of the ring buffer, created by [`new`].
///
/// Read records with [`Consumer::try_claim`]; dropping the returned
/// [`ReadClaim`] frees the record's space for the producer.
pub struct Consumer {
    /// Consumer's read position (wrapping byte counter). Mirrored into
    /// `Shared::head` each time a record is released.
    head: usize,
    shared: Arc<Shared>,
}
348
// NOTE(review): `Consumer`'s fields (`usize`, `Arc<Shared>`) are already
// `Send` given the `Send + Sync` impls on `Shared`, so this impl appears
// redundant with the auto-trait; it documents intent explicitly.
unsafe impl Send for Consumer {}
351
352impl Consumer {
353 #[inline]
361 pub fn try_claim(&mut self) -> Option<ReadClaim<'_>> {
362 let buffer = self.shared.buffer;
363
364 loop {
365 let offset = self.head & self.shared.mask;
366 let len_ptr = unsafe { buffer.add(offset) }.cast::<AtomicUsize>();
367
368 let len_raw = unsafe { &*len_ptr }.load(Ordering::Relaxed);
370 fence(Ordering::Acquire);
371
372 if len_raw == 0 {
373 return None;
375 }
376
377 if len_raw & SKIP_BIT != 0 {
378 let skip_size = len_raw & LEN_MASK;
380 if skip_size > HEADER_SIZE {
382 unsafe {
383 ptr::write_bytes(buffer.add(offset + HEADER_SIZE), 0, skip_size - HEADER_SIZE);
384 }
385 }
386 fence(Ordering::Release);
388 unsafe { &*len_ptr }.store(0, Ordering::Relaxed);
389
390 self.head = self.head.wrapping_add(skip_size);
391
392 fence(Ordering::Release);
394 self.shared.head.store(self.head, Ordering::Relaxed);
395
396 continue;
398 }
399
400 let len = len_raw;
402 let record_size = align8(HEADER_SIZE + len);
403
404 return Some(ReadClaim {
405 consumer: self,
406 offset,
407 len,
408 record_size,
409 });
410 }
411 }
412
413 #[inline]
415 pub fn capacity(&self) -> usize {
416 self.shared.capacity
417 }
418
419 #[inline]
421 pub fn is_disconnected(&self) -> bool {
422 Arc::strong_count(&self.shared) == 1
423 }
424}
425
426impl std::fmt::Debug for Consumer {
427 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
428 f.debug_struct("Consumer")
429 .field("capacity", &self.capacity())
430 .finish_non_exhaustive()
431 }
432}
433
/// A published record borrowed from the ring by [`Consumer::try_claim`].
///
/// Dereferences to the payload; dropping it zeroes the record and hands its
/// space back to the producer.
pub struct ReadClaim<'a> {
    /// Exclusive borrow of the consumer; its `head` advances on drop.
    consumer: &'a mut Consumer,
    /// Byte offset of the record header inside the buffer.
    offset: usize,
    /// Payload length read from the header.
    len: usize,
    /// Total bytes the record occupies: header plus payload, rounded up by `align8`.
    record_size: usize,
}
448
449impl ReadClaim<'_> {
450 #[inline]
452 pub fn len(&self) -> usize {
453 self.len
454 }
455
456 #[inline]
458 pub fn is_empty(&self) -> bool {
459 self.len == 0
460 }
461}
462
463impl Deref for ReadClaim<'_> {
464 type Target = [u8];
465
466 #[inline]
467 fn deref(&self) -> &Self::Target {
468 let buffer = self.consumer.shared.buffer;
469 let payload_ptr = unsafe { buffer.add(self.offset + HEADER_SIZE) };
470 unsafe { std::slice::from_raw_parts(payload_ptr, self.len) }
471 }
472}
473
474impl Drop for ReadClaim<'_> {
475 fn drop(&mut self) {
476 let buffer = self.consumer.shared.buffer;
477
478 if self.record_size > HEADER_SIZE {
480 unsafe {
481 ptr::write_bytes(buffer.add(self.offset + HEADER_SIZE), 0, self.record_size - HEADER_SIZE);
482 }
483 }
484 fence(Ordering::Release);
486 let len_ptr = unsafe { buffer.add(self.offset) }.cast::<AtomicUsize>();
487 unsafe { &*len_ptr }.store(0, Ordering::Relaxed);
488
489 self.consumer.head = self.consumer.head.wrapping_add(self.record_size);
491
492 fence(Ordering::Release);
494 self.consumer
495 .shared
496 .head
497 .store(self.consumer.head, Ordering::Relaxed);
498 }
499}
500
// Unit and stress tests for the SPSC byte ring: single-threaded round-trips,
// abort/skip handling, wrap-around, backpressure, and cross-thread ordering.
#[cfg(test)]
mod tests {
    use super::*;

    // One record written and read back unchanged.
    #[test]
    fn basic_write_read() {
        let (mut prod, mut cons) = new(1024);

        let payload = b"hello world";
        let mut claim = prod.try_claim(payload.len()).unwrap();
        claim.copy_from_slice(payload);
        claim.commit();

        let record = cons.try_claim().unwrap();
        assert_eq!(&*record, payload);
    }

    // A freshly created (all-zero) ring has nothing to read.
    #[test]
    fn empty_returns_none() {
        let (_, mut cons) = new(1024);
        assert!(cons.try_claim().is_none());
    }

    // Records come out in the order they were committed.
    #[test]
    fn multiple_records() {
        let (mut prod, mut cons) = new(1024);

        for i in 0..10 {
            let payload = format!("message {}", i);
            let mut claim = prod.try_claim(payload.len()).unwrap();
            claim.copy_from_slice(payload.as_bytes());
            claim.commit();
        }

        for i in 0..10 {
            let record = cons.try_claim().unwrap();
            let expected = format!("message {}", i);
            assert_eq!(&*record, expected.as_bytes());
        }

        assert!(cons.try_claim().is_none());
    }

    // Dropping a claim without committing leaves a skip record that the
    // consumer steps over to reach the next committed record.
    #[test]
    fn aborted_claim_creates_skip() {
        let (mut prod, mut cons) = new(1024);

        {
            let mut claim = prod.try_claim(10).unwrap();
            claim.copy_from_slice(b"0123456789");
        }

        {
            let mut claim = prod.try_claim(5).unwrap();
            claim.copy_from_slice(b"hello");
            claim.commit();
        }

        let record = cons.try_claim().unwrap();
        assert_eq!(&*record, b"hello");
    }

    // Small ring forces the producer past the end of the buffer repeatedly;
    // on `Full` the consumer drains everything before retrying.
    #[test]
    fn wrap_around() {
        let (mut prod, mut cons) = new(64);

        for i in 0..20 {
            let payload = format!("msg{:02}", i);
            loop {
                match prod.try_claim(payload.len()) {
                    Ok(mut claim) => {
                        claim.copy_from_slice(payload.as_bytes());
                        claim.commit();
                        break;
                    }
                    Err(_) => {
                        while cons.try_claim().is_some() {}
                    }
                }
            }
        }
    }

    // With no reads, the producer eventually reports an error and keeps doing so.
    #[test]
    fn full_returns_error() {
        let (mut prod, _cons) = new(64);

        let mut count = 0;
        loop {
            match prod.try_claim(8) {
                Ok(mut claim) => {
                    claim.copy_from_slice(b"12345678");
                    claim.commit();
                    count += 1;
                }
                Err(_) => break,
            }
        }

        assert!(count > 0);
        assert!(prod.try_claim(8).is_err());
    }

    // Producer and consumer on separate threads; values must arrive in order.
    #[test]
    fn cross_thread() {
        use std::thread;

        let (mut prod, mut cons) = new(4096);

        let producer = thread::spawn(move || {
            for i in 0..10_000u64 {
                let payload = i.to_le_bytes();
                loop {
                    match prod.try_claim(payload.len()) {
                        Ok(mut claim) => {
                            claim.copy_from_slice(&payload);
                            claim.commit();
                            break;
                        }
                        Err(_) => std::hint::spin_loop(),
                    }
                }
            }
        });

        let consumer = thread::spawn(move || {
            let mut received = 0u64;
            while received < 10_000 {
                if let Some(record) = cons.try_claim() {
                    let value = u64::from_le_bytes((*record).try_into().unwrap());
                    assert_eq!(value, received);
                    received += 1;
                } else {
                    std::hint::spin_loop();
                }
            }
        });

        producer.join().unwrap();
        consumer.join().unwrap();
    }

    // `is_disconnected` flips once the other half is dropped.
    #[test]
    fn disconnection_detection() {
        let (prod, cons) = new(1024);

        assert!(!prod.is_disconnected());
        assert!(!cons.is_disconnected());

        drop(cons);
        assert!(prod.is_disconnected());
    }

    #[test]
    #[should_panic(expected = "capacity must be at least 16")]
    fn tiny_capacity_panics() {
        let _ = new(8);
    }

    #[test]
    fn zero_len_returns_error() {
        let (mut prod, _) = new(1024);
        assert!(matches!(prod.try_claim(0), Err(TryClaimError::ZeroLength)));
    }

    #[test]
    fn capacity_rounds_to_power_of_two() {
        let (prod, _) = new(100);
        assert_eq!(prod.capacity(), 128);

        let (prod, _) = new(1000);
        assert_eq!(prod.capacity(), 1024);
    }

    // Mixed payload lengths exercise the 8-byte record rounding.
    #[test]
    fn variable_length_records() {
        let (mut prod, mut cons) = new(4096);

        let messages = [
            "a",
            "hello",
            "this is a longer message",
            "x",
            "medium length",
        ];

        for msg in &messages {
            let mut claim = prod.try_claim(msg.len()).unwrap();
            claim.copy_from_slice(msg.as_bytes());
            claim.commit();
        }

        for msg in &messages {
            let record = cons.try_claim().unwrap();
            assert_eq!(&*record, msg.as_bytes());
        }
    }

    // One million records of cycling lengths (8..=64 bytes) across threads;
    // checks both the sequence number and the payload length of each record.
    #[test]
    fn stress_high_volume() {
        use std::thread;

        const COUNT: u64 = 1_000_000;
        const BUFFER_SIZE: usize = 64 * 1024;
        let (mut prod, mut cons) = new(BUFFER_SIZE);

        let producer = thread::spawn(move || {
            for i in 0..COUNT {
                let len = 8 + ((i % 8) * 8) as usize;
                let mut payload = vec![0u8; len];
                payload[..8].copy_from_slice(&i.to_le_bytes());

                loop {
                    match prod.try_claim(len) {
                        Ok(mut claim) => {
                            claim.copy_from_slice(&payload);
                            claim.commit();
                            break;
                        }
                        Err(_) => std::hint::spin_loop(),
                    }
                }
            }
        });

        let consumer = thread::spawn(move || {
            let mut received = 0u64;
            while received < COUNT {
                if let Some(record) = cons.try_claim() {
                    let seq = u64::from_le_bytes(record[..8].try_into().unwrap());
                    assert_eq!(seq, received, "sequence mismatch at {}", received);

                    let expected_len = 8 + ((received % 8) * 8) as usize;
                    assert_eq!(
                        record.len(),
                        expected_len,
                        "length mismatch at {}",
                        received
                    );

                    received += 1;
                } else {
                    std::hint::spin_loop();
                }
            }
            received
        });

        producer.join().unwrap();
        let received = consumer.join().unwrap();
        assert_eq!(received, COUNT);
    }

    // Tiny 256-byte ring so the two threads constantly hit full/empty and wrap.
    #[test]
    fn stress_high_contention() {
        use std::thread;

        const COUNT: u64 = 100_000;
        const BUFFER_SIZE: usize = 256;
        let (mut prod, mut cons) = new(BUFFER_SIZE);

        let producer = thread::spawn(move || {
            for i in 0..COUNT {
                let payload = i.to_le_bytes();
                loop {
                    match prod.try_claim(payload.len()) {
                        Ok(mut claim) => {
                            claim.copy_from_slice(&payload);
                            claim.commit();
                            break;
                        }
                        Err(_) => std::hint::spin_loop(),
                    }
                }
            }
        });

        let consumer = thread::spawn(move || {
            let mut received = 0u64;
            let mut sum = 0u64;
            while received < COUNT {
                if let Some(record) = cons.try_claim() {
                    let value = u64::from_le_bytes((*record).try_into().unwrap());
                    assert_eq!(value, received);
                    sum = sum.wrapping_add(value);
                    received += 1;
                } else {
                    std::hint::spin_loop();
                }
            }
            sum
        });

        producer.join().unwrap();
        let sum = consumer.join().unwrap();
        let expected = COUNT * (COUNT - 1) / 2;
        assert_eq!(sum, expected);
    }

    // Payload pointers on both sides are aligned to `usize`.
    #[test]
    fn payload_is_word_aligned() {
        let (mut prod, mut cons) = new(1024);

        for len in [1, 3, 7, 8, 13, 64, 255] {
            let mut claim = prod.try_claim(len).unwrap();
            let ptr = claim.as_mut_ptr();
            assert_eq!(
                ptr as usize % std::mem::align_of::<usize>(),
                0,
                "WriteClaim payload not word-aligned for len={len}"
            );
            claim.commit();

            let record = cons.try_claim().unwrap();
            let ptr = record.as_ptr();
            assert_eq!(
                ptr as usize % std::mem::align_of::<usize>(),
                0,
                "ReadClaim payload not word-aligned for len={len}"
            );
        }
    }
}