use std::ops::{Deref, DerefMut};
use std::ptr;
use std::sync::Arc;
use std::sync::atomic::{AtomicU32, AtomicUsize, Ordering, fence};

use crossbeam_utils::CachePadded;

use crate::{LEN_MASK, SKIP_BIT, TryClaimError, align8};
41
/// Size in bytes of the per-record header: one `u32` holding either the
/// payload length or, with `SKIP_BIT` set, the size of a span to skip.
const HEADER_SIZE: usize = std::mem::size_of::<u32>();
44
45pub fn new(capacity: usize) -> (Producer, Consumer) {
53 assert!(capacity >= 16, "capacity must be at least 16 bytes");
54
55 let capacity = capacity.next_power_of_two();
56 let mask = capacity - 1;
57
58 let buffer = vec![0u8; capacity].into_boxed_slice();
60 let buffer_ptr = Box::into_raw(buffer) as *mut u8;
61
62 let shared = Arc::new(Shared {
63 head: CachePadded::new(AtomicUsize::new(0)),
64 tail: CachePadded::new(AtomicUsize::new(0)),
65 buffer: buffer_ptr,
66 capacity,
67 mask,
68 });
69
70 (
71 Producer {
72 cached_head: 0,
73 shared: Arc::clone(&shared),
74 },
75 Consumer { head: 0, shared },
76 )
77}
78
79struct Shared {
80 head: CachePadded<AtomicUsize>,
82 tail: CachePadded<AtomicUsize>,
84 buffer: *mut u8,
86 capacity: usize,
88 mask: usize,
90}
91
92unsafe impl Send for Shared {}
96unsafe impl Sync for Shared {}
97
98impl Drop for Shared {
99 fn drop(&mut self) {
100 unsafe {
101 let _ = Box::from_raw(ptr::slice_from_raw_parts_mut(self.buffer, self.capacity));
102 }
103 }
104}
105
106#[derive(Clone)]
115pub struct Producer {
116 cached_head: usize,
118 shared: Arc<Shared>,
120}
121
122unsafe impl Send for Producer {}
124
125impl Producer {
126 #[inline]
140 pub fn try_claim(&mut self, len: usize) -> Result<WriteClaim<'_>, TryClaimError> {
141 debug_assert!(len <= LEN_MASK as usize, "payload too large");
142 if len == 0 {
143 return Err(TryClaimError::ZeroLength);
144 }
145
146 let record_size = align8(HEADER_SIZE + len);
147
148 loop {
150 let tail = self.shared.tail.load(Ordering::Relaxed);
151
152 let used = tail.wrapping_sub(self.cached_head);
155 let available = self.shared.capacity.saturating_sub(used);
156
157 if available < record_size {
158 self.cached_head = self.shared.head.load(Ordering::Relaxed);
160 fence(Ordering::Acquire);
161
162 let used = tail.wrapping_sub(self.cached_head);
163 if used > self.shared.capacity || self.shared.capacity - used < record_size {
164 return Err(TryClaimError::Full);
165 }
166 }
167
168 let offset = tail & self.shared.mask;
170 let space_to_end = self.shared.capacity - offset;
171
172 if space_to_end < record_size {
173 let total_needed = space_to_end + record_size;
175
176 let used = tail.wrapping_sub(self.cached_head);
177 let available = self.shared.capacity.saturating_sub(used);
178
179 if available < total_needed {
180 self.cached_head = self.shared.head.load(Ordering::Relaxed);
182 fence(Ordering::Acquire);
183
184 let used = tail.wrapping_sub(self.cached_head);
185 if used > self.shared.capacity || self.shared.capacity - used < total_needed {
186 return Err(TryClaimError::Full);
187 }
188 }
189
190 let new_tail = tail.wrapping_add(total_needed);
192 if self
193 .shared
194 .tail
195 .compare_exchange_weak(tail, new_tail, Ordering::Relaxed, Ordering::Relaxed)
196 .is_ok()
197 {
198 let buffer = self.shared.buffer;
200 let padding_ptr = unsafe { buffer.add(offset) };
201 let skip_len = space_to_end as u32 | SKIP_BIT;
202
203 fence(Ordering::Release);
205 unsafe {
206 ptr::write(padding_ptr as *mut u32, skip_len);
207 }
208
209 return Ok(WriteClaim {
210 shared: &self.shared,
211 offset: 0, len,
213 record_size,
214 committed: false,
215 });
216 }
217 continue;
219 }
220
221 let new_tail = tail.wrapping_add(record_size);
223 if self
224 .shared
225 .tail
226 .compare_exchange_weak(tail, new_tail, Ordering::Relaxed, Ordering::Relaxed)
227 .is_ok()
228 {
229 return Ok(WriteClaim {
230 shared: &self.shared,
231 offset,
232 len,
233 record_size,
234 committed: false,
235 });
236 }
237 }
239 }
240
241 #[inline]
243 pub fn capacity(&self) -> usize {
244 self.shared.capacity
245 }
246
247 #[inline]
249 pub fn is_disconnected(&self) -> bool {
250 Arc::strong_count(&self.shared) == 1
255 }
256}
257
258impl std::fmt::Debug for Producer {
259 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
260 f.debug_struct("Producer")
261 .field("capacity", &self.capacity())
262 .finish_non_exhaustive()
263 }
264}
265
266pub struct WriteClaim<'a> {
276 shared: &'a Shared,
277 offset: usize,
278 len: usize,
279 record_size: usize,
280 committed: bool,
281}
282
283impl WriteClaim<'_> {
284 #[inline]
286 pub fn commit(mut self) {
287 self.do_commit();
288 self.committed = true;
289 }
290
291 #[inline]
292 fn do_commit(&mut self) {
293 let buffer = self.shared.buffer;
294 let len_ptr = unsafe { buffer.add(self.offset) } as *mut u32;
295
296 fence(Ordering::Release);
298 unsafe {
299 ptr::write(len_ptr, self.len as u32);
300 }
301 }
302
303 #[inline]
305 pub fn len(&self) -> usize {
306 self.len
307 }
308
309 #[inline]
311 pub fn is_empty(&self) -> bool {
312 false
313 }
314}
315
316impl Deref for WriteClaim<'_> {
317 type Target = [u8];
318
319 #[inline]
320 fn deref(&self) -> &Self::Target {
321 let buffer = self.shared.buffer;
322 let payload_ptr = unsafe { buffer.add(self.offset + HEADER_SIZE) };
323 unsafe { std::slice::from_raw_parts(payload_ptr, self.len) }
324 }
325}
326
327impl DerefMut for WriteClaim<'_> {
328 #[inline]
329 fn deref_mut(&mut self) -> &mut Self::Target {
330 let buffer = self.shared.buffer;
331 let payload_ptr = unsafe { buffer.add(self.offset + HEADER_SIZE) };
332 unsafe { std::slice::from_raw_parts_mut(payload_ptr, self.len) }
333 }
334}
335
336impl Drop for WriteClaim<'_> {
337 fn drop(&mut self) {
338 if !self.committed {
339 let buffer = self.shared.buffer;
341 let len_ptr = unsafe { buffer.add(self.offset) } as *mut u32;
342 let skip_len = self.record_size as u32 | SKIP_BIT;
343
344 fence(Ordering::Release);
345 unsafe {
346 ptr::write(len_ptr, skip_len);
347 }
348 }
349 }
350}
351
352pub struct Consumer {
361 head: usize,
363 shared: Arc<Shared>,
365}
366
367unsafe impl Send for Consumer {}
369
370impl Consumer {
371 #[inline]
379 pub fn try_claim(&mut self) -> Option<ReadClaim<'_>> {
380 let buffer = self.shared.buffer;
381
382 loop {
383 let offset = self.head & self.shared.mask;
384 let len_ptr = unsafe { buffer.add(offset) } as *const u32;
385
386 let len_raw = unsafe { ptr::read(len_ptr) };
388 fence(Ordering::Acquire);
389
390 if len_raw == 0 {
391 return None;
393 }
394
395 if len_raw & SKIP_BIT != 0 {
396 let skip_size = (len_raw & LEN_MASK) as usize;
398 unsafe {
399 ptr::write_bytes(buffer.add(offset), 0, skip_size);
400 }
401
402 self.head = self.head.wrapping_add(skip_size);
403
404 fence(Ordering::Release);
406 self.shared.head.store(self.head, Ordering::Relaxed);
407
408 continue;
410 }
411
412 let len = len_raw as usize;
414 let record_size = align8(HEADER_SIZE + len);
415
416 return Some(ReadClaim {
417 consumer: self,
418 offset,
419 len,
420 record_size,
421 });
422 }
423 }
424
425 #[inline]
427 pub fn capacity(&self) -> usize {
428 self.shared.capacity
429 }
430
431 #[inline]
433 pub fn is_disconnected(&self) -> bool {
434 Arc::strong_count(&self.shared) == 1
435 }
436}
437
438impl std::fmt::Debug for Consumer {
439 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
440 f.debug_struct("Consumer")
441 .field("capacity", &self.capacity())
442 .finish_non_exhaustive()
443 }
444}
445
446pub struct ReadClaim<'a> {
455 consumer: &'a mut Consumer,
456 offset: usize,
457 len: usize,
458 record_size: usize,
459}
460
461impl ReadClaim<'_> {
462 #[inline]
464 pub fn len(&self) -> usize {
465 self.len
466 }
467
468 #[inline]
470 pub fn is_empty(&self) -> bool {
471 self.len == 0
472 }
473}
474
475impl Deref for ReadClaim<'_> {
476 type Target = [u8];
477
478 #[inline]
479 fn deref(&self) -> &Self::Target {
480 let buffer = self.consumer.shared.buffer;
481 let payload_ptr = unsafe { buffer.add(self.offset + HEADER_SIZE) };
482 unsafe { std::slice::from_raw_parts(payload_ptr, self.len) }
483 }
484}
485
486impl Drop for ReadClaim<'_> {
487 fn drop(&mut self) {
488 let buffer = self.consumer.shared.buffer;
489
490 unsafe {
492 ptr::write_bytes(buffer.add(self.offset), 0, self.record_size);
493 }
494
495 self.consumer.head = self.consumer.head.wrapping_add(self.record_size);
497
498 fence(Ordering::Release);
500 self.consumer
501 .shared
502 .head
503 .store(self.consumer.head, Ordering::Relaxed);
504 }
505}
506
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn basic_write_read() {
        let (mut prod, mut cons) = new(1024);

        let payload = b"hello world";
        let mut claim = prod.try_claim(payload.len()).unwrap();
        claim.copy_from_slice(payload);
        claim.commit();

        let record = cons.try_claim().unwrap();
        assert_eq!(&*record, payload);
    }

    #[test]
    fn empty_returns_none() {
        let (_, mut cons) = new(1024);
        assert!(cons.try_claim().is_none());
    }

    #[test]
    fn multiple_records() {
        let (mut prod, mut cons) = new(1024);

        for i in 0..10 {
            let payload = format!("message {}", i);
            let mut claim = prod.try_claim(payload.len()).unwrap();
            claim.copy_from_slice(payload.as_bytes());
            claim.commit();
        }

        for i in 0..10 {
            let record = cons.try_claim().unwrap();
            let expected = format!("message {}", i);
            assert_eq!(&*record, expected.as_bytes());
        }

        assert!(cons.try_claim().is_none());
    }

    #[test]
    fn producer_is_clone() {
        let (prod, _cons) = new(1024);
        let _prod2 = prod.clone();
    }

    #[test]
    fn multiple_producers_single_consumer() {
        use std::thread;

        const PRODUCERS: usize = 4;
        const MESSAGES_PER_PRODUCER: u64 = 10_000;
        const TOTAL: u64 = PRODUCERS as u64 * MESSAGES_PER_PRODUCER;

        let (prod, mut cons) = new(64 * 1024);

        let handles: Vec<_> = (0..PRODUCERS)
            .map(|producer_id| {
                let mut prod = prod.clone();
                thread::spawn(move || {
                    for i in 0..MESSAGES_PER_PRODUCER {
                        // Payload: producer id followed by its sequence number.
                        let mut payload = [0u8; 16];
                        payload[..8].copy_from_slice(&(producer_id as u64).to_le_bytes());
                        payload[8..].copy_from_slice(&i.to_le_bytes());

                        loop {
                            match prod.try_claim(16) {
                                Ok(mut claim) => {
                                    claim.copy_from_slice(&payload);
                                    claim.commit();
                                    break;
                                }
                                Err(_) => std::hint::spin_loop(),
                            }
                        }
                    }
                })
            })
            .collect();

        drop(prod);

        let consumer = thread::spawn(move || {
            let mut received = 0u64;
            let mut per_producer = vec![0u64; PRODUCERS];

            while received < TOTAL {
                if let Some(record) = cons.try_claim() {
                    let producer_id = u64::from_le_bytes(record[..8].try_into().unwrap()) as usize;
                    let seq = u64::from_le_bytes(record[8..].try_into().unwrap());

                    // Each producer's records must arrive in its own order.
                    assert_eq!(
                        seq, per_producer[producer_id],
                        "producer {} out of order",
                        producer_id
                    );
                    per_producer[producer_id] += 1;
                    received += 1;
                } else {
                    std::hint::spin_loop();
                }
            }

            per_producer
        });

        for h in handles {
            h.join().unwrap();
        }

        let per_producer = consumer.join().unwrap();
        for (i, &count) in per_producer.iter().enumerate() {
            assert_eq!(count, MESSAGES_PER_PRODUCER, "producer {} count", i);
        }
    }

    #[test]
    fn aborted_claim_creates_skip() {
        let (mut prod, mut cons) = new(1024);

        {
            // Dropped without `commit`: must become a skip record.
            let mut claim = prod.try_claim(10).unwrap();
            claim.copy_from_slice(b"0123456789");
        }

        {
            let mut claim = prod.try_claim(5).unwrap();
            claim.copy_from_slice(b"hello");
            claim.commit();
        }

        // The aborted record is skipped transparently.
        let record = cons.try_claim().unwrap();
        assert_eq!(&*record, b"hello");
        drop(record);

        // Nothing else was published.
        assert!(cons.try_claim().is_none());
    }

    #[test]
    fn wrap_around() {
        /// Reads every available record, checking payloads arrive in order.
        fn drain_in_order(cons: &mut Consumer, next_expected: &mut u32) {
            while let Some(record) = cons.try_claim() {
                let expected = format!("wrap-msg-{:04}", *next_expected);
                assert_eq!(&*record, expected.as_bytes());
                *next_expected += 1;
            }
        }

        // 13-byte payloads occupy 24-byte records. These do not divide the
        // 64-byte capacity, so producers must regularly pad to the end of the
        // buffer with skip records and restart at offset 0.
        let (mut prod, mut cons) = new(64);
        let mut next_expected = 0u32;

        for i in 0..20 {
            let payload = format!("wrap-msg-{i:04}");
            loop {
                match prod.try_claim(payload.len()) {
                    Ok(mut claim) => {
                        claim.copy_from_slice(payload.as_bytes());
                        claim.commit();
                        break;
                    }
                    Err(_) => drain_in_order(&mut cons, &mut next_expected),
                }
            }
        }

        drain_in_order(&mut cons, &mut next_expected);
        assert_eq!(next_expected, 20);
    }

    #[test]
    fn full_returns_error() {
        let (mut prod, _cons) = new(64);

        let mut count = 0;
        while let Ok(mut claim) = prod.try_claim(8) {
            claim.copy_from_slice(b"12345678");
            claim.commit();
            count += 1;
        }

        // An 8-byte payload plus the 4-byte header pads to 16 bytes, so
        // exactly four records fit in 64 bytes.
        assert_eq!(count, 4);
        assert!(matches!(prod.try_claim(8), Err(TryClaimError::Full)));
    }

    #[test]
    fn disconnection_detection() {
        let (prod, cons) = new(1024);

        assert!(!prod.is_disconnected());
        assert!(!cons.is_disconnected());

        drop(cons);
        assert!(prod.is_disconnected());
    }

    #[test]
    #[should_panic(expected = "capacity must be at least 16")]
    fn tiny_capacity_panics() {
        let _ = new(8);
    }

    #[test]
    fn zero_len_returns_error() {
        let (mut prod, _) = new(1024);
        assert!(matches!(prod.try_claim(0), Err(TryClaimError::ZeroLength)));
    }

    #[test]
    fn capacity_rounds_to_power_of_two() {
        let (prod, _) = new(100);
        assert_eq!(prod.capacity(), 128);

        let (prod, _) = new(1000);
        assert_eq!(prod.capacity(), 1024);
    }

    #[test]
    fn stress_multiple_producers() {
        use std::thread;

        const PRODUCERS: usize = 4;
        const COUNT_PER_PRODUCER: u64 = 100_000;
        const TOTAL: u64 = PRODUCERS as u64 * COUNT_PER_PRODUCER;
        const BUFFER_SIZE: usize = 64 * 1024;

        let (prod, mut cons) = new(BUFFER_SIZE);

        let handles: Vec<_> = (0..PRODUCERS)
            .map(|_| {
                let mut prod = prod.clone();
                thread::spawn(move || {
                    for i in 0..COUNT_PER_PRODUCER {
                        let payload = i.to_le_bytes();
                        loop {
                            match prod.try_claim(payload.len()) {
                                Ok(mut claim) => {
                                    claim.copy_from_slice(&payload);
                                    claim.commit();
                                    break;
                                }
                                Err(_) => std::hint::spin_loop(),
                            }
                        }
                    }
                })
            })
            .collect();

        drop(prod);

        let consumer = thread::spawn(move || {
            let mut received = 0u64;
            let mut sum = 0u64;
            while received < TOTAL {
                if let Some(record) = cons.try_claim() {
                    let value = u64::from_le_bytes((*record).try_into().unwrap());
                    sum = sum.wrapping_add(value);
                    received += 1;
                } else {
                    std::hint::spin_loop();
                }
            }
            (received, sum)
        });

        for h in handles {
            h.join().unwrap();
        }

        let (received, sum) = consumer.join().unwrap();
        assert_eq!(received, TOTAL);

        // Each producer sends 0..COUNT_PER_PRODUCER.
        let expected_sum = PRODUCERS as u64 * COUNT_PER_PRODUCER * (COUNT_PER_PRODUCER - 1) / 2;
        assert_eq!(sum, expected_sum);
    }
}