use std::ops::{Deref, DerefMut};
use std::ptr;
use std::sync::Arc;
use std::sync::atomic::{AtomicU32, AtomicUsize, Ordering, fence};

use crossbeam_utils::CachePadded;

use crate::{LEN_MASK, SKIP_BIT, TryClaimError, align8};
49
/// Size in bytes of the `u32` length header that precedes every record's payload.
const HEADER_SIZE: usize = std::mem::size_of::<u32>();
52
53pub fn new(capacity: usize) -> (Producer, Consumer) {
61 assert!(capacity >= 16, "capacity must be at least 16 bytes");
62
63 let capacity = capacity.next_power_of_two();
64 let mask = capacity - 1;
65
66 let buffer = vec![0u8; capacity].into_boxed_slice();
68 let buffer_ptr = Box::into_raw(buffer) as *mut u8;
69
70 let shared = Arc::new(Shared {
71 head: CachePadded::new(AtomicUsize::new(0)),
72 buffer: buffer_ptr,
73 capacity,
74 mask,
75 });
76
77 (
78 Producer {
79 tail: 0,
80 cached_head: 0,
81 shared: Arc::clone(&shared),
82 },
83 Consumer { head: 0, shared },
84 )
85}
86
87struct Shared {
88 head: CachePadded<AtomicUsize>,
90 buffer: *mut u8,
92 capacity: usize,
94 mask: usize,
96}
97
98unsafe impl Send for Shared {}
101unsafe impl Sync for Shared {}
102
103impl Drop for Shared {
104 fn drop(&mut self) {
105 unsafe {
106 let _ = Box::from_raw(ptr::slice_from_raw_parts_mut(self.buffer, self.capacity));
107 }
108 }
109}
110
111pub struct Producer {
119 tail: usize,
121 cached_head: usize,
123 shared: Arc<Shared>,
125}
126
127unsafe impl Send for Producer {}
129
130impl Producer {
131 #[inline]
145 pub fn try_claim(&mut self, len: usize) -> Result<WriteClaim<'_>, TryClaimError> {
146 debug_assert!(len <= LEN_MASK as usize, "payload too large");
147 if len == 0 {
148 return Err(TryClaimError::ZeroLength);
149 }
150
151 let record_size = align8(HEADER_SIZE + len);
152
153 let tail = self.tail;
155 let available = self.shared.capacity - (tail.wrapping_sub(self.cached_head));
156
157 if available < record_size {
158 self.cached_head = self.shared.head.load(Ordering::Relaxed);
160 fence(Ordering::Acquire);
161
162 let available = self.shared.capacity - (tail.wrapping_sub(self.cached_head));
163 if available < record_size {
164 return Err(TryClaimError::Full);
165 }
166 }
167
168 let offset = tail & self.shared.mask;
170 let space_to_end = self.shared.capacity - offset;
171
172 if space_to_end < record_size {
173 let total_needed = space_to_end + record_size;
175 let available = self.shared.capacity - (tail.wrapping_sub(self.cached_head));
176
177 if available < total_needed {
178 self.cached_head = self.shared.head.load(Ordering::Relaxed);
180 fence(Ordering::Acquire);
181
182 let available = self.shared.capacity - (tail.wrapping_sub(self.cached_head));
183 if available < total_needed {
184 return Err(TryClaimError::Full);
185 }
186 }
187
188 let buffer = self.shared.buffer;
190 let padding_ptr = unsafe { buffer.add(offset) };
191 let skip_len = space_to_end as u32 | SKIP_BIT;
192 unsafe {
193 ptr::write(padding_ptr as *mut u32, skip_len);
194 }
195
196 self.tail = tail.wrapping_add(space_to_end);
198 let new_offset = 0;
199
200 Ok(WriteClaim {
201 producer: self,
202 offset: new_offset,
203 len,
204 record_size,
205 committed: false,
206 })
207 } else {
208 Ok(WriteClaim {
210 producer: self,
211 offset,
212 len,
213 record_size,
214 committed: false,
215 })
216 }
217 }
218
219 #[inline]
221 pub fn capacity(&self) -> usize {
222 self.shared.capacity
223 }
224
225 #[inline]
227 pub fn is_disconnected(&self) -> bool {
228 Arc::strong_count(&self.shared) == 1
229 }
230}
231
232impl std::fmt::Debug for Producer {
233 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
234 f.debug_struct("Producer")
235 .field("capacity", &self.capacity())
236 .finish_non_exhaustive()
237 }
238}
239
240pub struct WriteClaim<'a> {
250 producer: &'a mut Producer,
251 offset: usize,
252 len: usize,
253 record_size: usize,
254 committed: bool,
255}
256
257impl WriteClaim<'_> {
258 #[inline]
260 pub fn commit(mut self) {
261 self.do_commit();
262 self.committed = true;
263 }
264
265 #[inline]
266 fn do_commit(&mut self) {
267 let buffer = self.producer.shared.buffer;
268 let len_ptr = unsafe { buffer.add(self.offset) } as *mut u32;
269
270 fence(Ordering::Release);
272 unsafe {
273 ptr::write(len_ptr, self.len as u32);
274 }
275
276 self.producer.tail = self.producer.tail.wrapping_add(self.record_size);
278 }
279
280 #[inline]
282 pub fn len(&self) -> usize {
283 self.len
284 }
285
286 #[inline]
288 pub fn is_empty(&self) -> bool {
289 false
290 }
291}
292
293impl Deref for WriteClaim<'_> {
294 type Target = [u8];
295
296 #[inline]
297 fn deref(&self) -> &Self::Target {
298 let buffer = self.producer.shared.buffer;
299 let payload_ptr = unsafe { buffer.add(self.offset + HEADER_SIZE) };
300 unsafe { std::slice::from_raw_parts(payload_ptr, self.len) }
301 }
302}
303
304impl DerefMut for WriteClaim<'_> {
305 #[inline]
306 fn deref_mut(&mut self) -> &mut Self::Target {
307 let buffer = self.producer.shared.buffer;
308 let payload_ptr = unsafe { buffer.add(self.offset + HEADER_SIZE) };
309 unsafe { std::slice::from_raw_parts_mut(payload_ptr, self.len) }
310 }
311}
312
313impl Drop for WriteClaim<'_> {
314 fn drop(&mut self) {
315 if !self.committed {
316 let buffer = self.producer.shared.buffer;
318 let len_ptr = unsafe { buffer.add(self.offset) } as *mut u32;
319 let skip_len = self.record_size as u32 | SKIP_BIT;
320
321 fence(Ordering::Release);
322 unsafe {
323 ptr::write(len_ptr, skip_len);
324 }
325
326 self.producer.tail = self.producer.tail.wrapping_add(self.record_size);
328 }
329 }
330}
331
332pub struct Consumer {
340 head: usize,
342 shared: Arc<Shared>,
344}
345
346unsafe impl Send for Consumer {}
348
349impl Consumer {
350 #[inline]
358 pub fn try_claim(&mut self) -> Option<ReadClaim<'_>> {
359 let buffer = self.shared.buffer;
360
361 loop {
362 let offset = self.head & self.shared.mask;
363 let len_ptr = unsafe { buffer.add(offset) } as *const u32;
364
365 let len_raw = unsafe { ptr::read(len_ptr) };
367 fence(Ordering::Acquire);
368
369 if len_raw == 0 {
370 return None;
372 }
373
374 if len_raw & SKIP_BIT != 0 {
375 let skip_size = (len_raw & LEN_MASK) as usize;
377 unsafe {
378 ptr::write_bytes(buffer.add(offset), 0, skip_size);
379 }
380
381 self.head = self.head.wrapping_add(skip_size);
382
383 fence(Ordering::Release);
385 self.shared.head.store(self.head, Ordering::Relaxed);
386
387 continue;
389 }
390
391 let len = len_raw as usize;
393 let record_size = align8(HEADER_SIZE + len);
394
395 return Some(ReadClaim {
396 consumer: self,
397 offset,
398 len,
399 record_size,
400 });
401 }
402 }
403
404 #[inline]
406 pub fn capacity(&self) -> usize {
407 self.shared.capacity
408 }
409
410 #[inline]
412 pub fn is_disconnected(&self) -> bool {
413 Arc::strong_count(&self.shared) == 1
414 }
415}
416
417impl std::fmt::Debug for Consumer {
418 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
419 f.debug_struct("Consumer")
420 .field("capacity", &self.capacity())
421 .finish_non_exhaustive()
422 }
423}
424
425pub struct ReadClaim<'a> {
434 consumer: &'a mut Consumer,
435 offset: usize,
436 len: usize,
437 record_size: usize,
438}
439
440impl ReadClaim<'_> {
441 #[inline]
443 pub fn len(&self) -> usize {
444 self.len
445 }
446
447 #[inline]
449 pub fn is_empty(&self) -> bool {
450 self.len == 0
451 }
452}
453
454impl Deref for ReadClaim<'_> {
455 type Target = [u8];
456
457 #[inline]
458 fn deref(&self) -> &Self::Target {
459 let buffer = self.consumer.shared.buffer;
460 let payload_ptr = unsafe { buffer.add(self.offset + HEADER_SIZE) };
461 unsafe { std::slice::from_raw_parts(payload_ptr, self.len) }
462 }
463}
464
465impl Drop for ReadClaim<'_> {
466 fn drop(&mut self) {
467 let buffer = self.consumer.shared.buffer;
468
469 unsafe {
471 ptr::write_bytes(buffer.add(self.offset), 0, self.record_size);
472 }
473
474 self.consumer.head = self.consumer.head.wrapping_add(self.record_size);
476
477 fence(Ordering::Release);
479 self.consumer
480 .shared
481 .head
482 .store(self.consumer.head, Ordering::Relaxed);
483 }
484}
485
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn basic_write_read() {
        let (mut prod, mut cons) = new(1024);

        let payload = b"hello world";
        let mut claim = prod.try_claim(payload.len()).unwrap();
        claim.copy_from_slice(payload);
        claim.commit();

        let record = cons.try_claim().unwrap();
        assert_eq!(&*record, payload);
    }

    #[test]
    fn empty_returns_none() {
        let (_, mut cons) = new(1024);
        assert!(cons.try_claim().is_none());
    }

    #[test]
    fn multiple_records() {
        let (mut prod, mut cons) = new(1024);

        for i in 0..10 {
            let payload = format!("message {}", i);
            let mut claim = prod.try_claim(payload.len()).unwrap();
            claim.copy_from_slice(payload.as_bytes());
            claim.commit();
        }

        for i in 0..10 {
            let record = cons.try_claim().unwrap();
            let expected = format!("message {}", i);
            assert_eq!(&*record, expected.as_bytes());
        }

        assert!(cons.try_claim().is_none());
    }

    #[test]
    fn aborted_claim_creates_skip() {
        let (mut prod, mut cons) = new(1024);

        {
            // Dropped without commit: becomes a skip record.
            let mut claim = prod.try_claim(10).unwrap();
            claim.copy_from_slice(b"0123456789");
        }

        {
            let mut claim = prod.try_claim(5).unwrap();
            claim.copy_from_slice(b"hello");
            claim.commit();
        }

        let record = cons.try_claim().unwrap();
        assert_eq!(&*record, b"hello");
        drop(record);
        assert!(cons.try_claim().is_none());
    }

    #[test]
    fn wrap_around() {
        let (mut prod, mut cons) = new(64);
        let mut received = 0;

        for i in 0..20 {
            let payload = format!("msg{:02}", i);
            loop {
                match prod.try_claim(payload.len()) {
                    Ok(mut claim) => {
                        claim.copy_from_slice(payload.as_bytes());
                        claim.commit();
                        break;
                    }
                    Err(_) => {
                        // Buffer full: drain it, checking that records arrive in order
                        // across the wrap point.
                        while let Some(record) = cons.try_claim() {
                            assert_eq!(&*record, format!("msg{:02}", received).as_bytes());
                            received += 1;
                        }
                    }
                }
            }
        }

        while let Some(record) = cons.try_claim() {
            assert_eq!(&*record, format!("msg{:02}", received).as_bytes());
            received += 1;
        }
        assert_eq!(received, 20);
    }

    #[test]
    fn wrap_padding_and_aborted_claims_are_skipped() {
        let (mut prod, mut cons) = new(64);

        // Two 24-byte records (4-byte header + 20-byte payload) leave 16 bytes before the end
        // of the buffer. That is too few for a third record, which needs wrap padding.
        for byte in [b'a', b'b'] {
            let mut claim = prod.try_claim(20).unwrap();
            claim.fill(byte);
            claim.commit();
        }
        assert!(matches!(prod.try_claim(20), Err(TryClaimError::Full)));

        for byte in [b'a', b'b'] {
            let record = cons.try_claim().unwrap();
            assert!(record.iter().all(|&b| b == byte));
        }

        // With the buffer drained the record fits at offset 0 behind a padding skip.
        let mut claim = prod.try_claim(20).unwrap();
        claim.fill(b'c');
        claim.commit();
        {
            let record = cons.try_claim().unwrap();
            assert_eq!(record.len(), 20);
            assert!(record.iter().all(|&b| b == b'c'));
        }

        // An aborted claim, then another record that needs wrap padding.
        drop(prod.try_claim(20).unwrap());
        let mut claim = prod.try_claim(20).unwrap();
        claim.fill(b'd');
        claim.commit();

        let record = cons.try_claim().unwrap();
        assert_eq!(record.len(), 20);
        assert!(record.iter().all(|&b| b == b'd'));
        drop(record);
        assert!(cons.try_claim().is_none());
    }

    #[test]
    fn full_returns_error() {
        let (mut prod, mut cons) = new(64);

        let mut count = 0;
        while let Ok(mut claim) = prod.try_claim(8) {
            claim.copy_from_slice(b"12345678");
            claim.commit();
            count += 1;
        }

        assert!(count > 0);
        assert!(matches!(prod.try_claim(8), Err(TryClaimError::Full)));

        // Consuming one record frees enough room for another.
        drop(cons.try_claim().unwrap());
        assert!(prod.try_claim(8).is_ok());
    }

    #[test]
    fn cross_thread() {
        use std::thread;

        let (mut prod, mut cons) = new(4096);

        let producer = thread::spawn(move || {
            for i in 0..10_000u64 {
                let payload = i.to_le_bytes();
                loop {
                    match prod.try_claim(payload.len()) {
                        Ok(mut claim) => {
                            claim.copy_from_slice(&payload);
                            claim.commit();
                            break;
                        }
                        Err(_) => std::hint::spin_loop(),
                    }
                }
            }
        });

        let consumer = thread::spawn(move || {
            let mut received = 0u64;
            while received < 10_000 {
                if let Some(record) = cons.try_claim() {
                    let value = u64::from_le_bytes((*record).try_into().unwrap());
                    assert_eq!(value, received);
                    received += 1;
                } else {
                    std::hint::spin_loop();
                }
            }
        });

        producer.join().unwrap();
        consumer.join().unwrap();
    }

    #[test]
    fn disconnection_detection() {
        let (prod, cons) = new(1024);

        assert!(!prod.is_disconnected());
        assert!(!cons.is_disconnected());

        drop(cons);
        assert!(prod.is_disconnected());
    }

    #[test]
    #[should_panic(expected = "capacity must be at least 16")]
    fn tiny_capacity_panics() {
        let _ = new(8);
    }

    #[test]
    fn zero_len_returns_error() {
        let (mut prod, _) = new(1024);
        assert!(matches!(prod.try_claim(0), Err(TryClaimError::ZeroLength)));
    }

    #[test]
    fn capacity_rounds_to_power_of_two() {
        let (prod, _) = new(100);
        assert_eq!(prod.capacity(), 128);

        let (prod, _) = new(1000);
        assert_eq!(prod.capacity(), 1024);
    }

    #[test]
    fn variable_length_records() {
        let (mut prod, mut cons) = new(4096);

        let messages = [
            "a",
            "hello",
            "this is a longer message",
            "x",
            "medium length",
        ];

        for msg in &messages {
            let mut claim = prod.try_claim(msg.len()).unwrap();
            claim.copy_from_slice(msg.as_bytes());
            claim.commit();
        }

        for msg in &messages {
            let record = cons.try_claim().unwrap();
            assert_eq!(&*record, msg.as_bytes());
        }
    }

    #[test]
    fn stress_high_volume() {
        use std::thread;

        const COUNT: u64 = 1_000_000;
        const BUFFER_SIZE: usize = 64 * 1024;
        let (mut prod, mut cons) = new(BUFFER_SIZE);

        let producer = thread::spawn(move || {
            for i in 0..COUNT {
                let len = 8 + ((i % 8) * 8) as usize;
                let mut payload = vec![0u8; len];
                payload[..8].copy_from_slice(&i.to_le_bytes());

                loop {
                    match prod.try_claim(len) {
                        Ok(mut claim) => {
                            claim.copy_from_slice(&payload);
                            claim.commit();
                            break;
                        }
                        Err(_) => std::hint::spin_loop(),
                    }
                }
            }
        });

        let consumer = thread::spawn(move || {
            let mut received = 0u64;
            while received < COUNT {
                if let Some(record) = cons.try_claim() {
                    let seq = u64::from_le_bytes(record[..8].try_into().unwrap());
                    assert_eq!(seq, received, "sequence mismatch at {}", received);

                    let expected_len = 8 + ((received % 8) * 8) as usize;
                    assert_eq!(
                        record.len(),
                        expected_len,
                        "length mismatch at {}",
                        received
                    );

                    received += 1;
                } else {
                    std::hint::spin_loop();
                }
            }
            received
        });

        producer.join().unwrap();
        let received = consumer.join().unwrap();
        assert_eq!(received, COUNT);
    }

    #[test]
    fn stress_high_contention() {
        use std::thread;

        const COUNT: u64 = 100_000;
        const BUFFER_SIZE: usize = 256;
        let (mut prod, mut cons) = new(BUFFER_SIZE);

        let producer = thread::spawn(move || {
            for i in 0..COUNT {
                let payload = i.to_le_bytes();
                loop {
                    match prod.try_claim(payload.len()) {
                        Ok(mut claim) => {
                            claim.copy_from_slice(&payload);
                            claim.commit();
                            break;
                        }
                        Err(_) => std::hint::spin_loop(),
                    }
                }
            }
        });

        let consumer = thread::spawn(move || {
            let mut received = 0u64;
            let mut sum = 0u64;
            while received < COUNT {
                if let Some(record) = cons.try_claim() {
                    let value = u64::from_le_bytes((*record).try_into().unwrap());
                    assert_eq!(value, received);
                    sum = sum.wrapping_add(value);
                    received += 1;
                } else {
                    std::hint::spin_loop();
                }
            }
            sum
        });

        producer.join().unwrap();
        let sum = consumer.join().unwrap();
        let expected = COUNT * (COUNT - 1) / 2;
        assert_eq!(sum, expected);
    }
}