1use std::cell::UnsafeCell;
67use std::fmt;
68use std::mem::MaybeUninit;
69use std::sync::Arc;
70use std::sync::atomic::{AtomicUsize, Ordering};
71
72use crossbeam_utils::CachePadded;
73
74use crate::Full;
75
76pub fn bounded<T>(capacity: usize) -> (Producer<T>, Consumer<T>) {
84 assert!(capacity > 0, "capacity must be non-zero");
85
86 let capacity = capacity.next_power_of_two();
87 let mask = capacity - 1;
88
89 let slots: Vec<Slot<T>> = (0..capacity)
91 .map(|_| Slot {
92 turn: AtomicUsize::new(0),
93 data: UnsafeCell::new(MaybeUninit::uninit()),
94 })
95 .collect();
96 let slots = Box::into_raw(slots.into_boxed_slice()) as *mut Slot<T>;
97
98 let shared = Arc::new(Shared {
99 tail: CachePadded::new(AtomicUsize::new(0)),
100 head: CachePadded::new(AtomicUsize::new(0)),
101 slots,
102 capacity,
103 mask,
104 });
105
106 (
107 Producer {
108 cached_head: 0,
109 slots,
110 mask,
111 capacity,
112 shared: Arc::clone(&shared),
113 },
114 Consumer {
115 local_head: 0,
116 shared,
117 },
118 )
119}
120
/// One queue cell plus the sequence counter that coordinates access to it.
struct Slot<T> {
    /// Lap-tagged state: `2 * lap` means empty/writable for lap `lap`;
    /// `2 * lap + 1` means full/readable for lap `lap` (see `push`/`pop`).
    turn: AtomicUsize,
    /// The stored element; only initialized while `turn` is in a "full"
    /// (odd) state.
    data: UnsafeCell<MaybeUninit<T>>,
}
130
/// State shared between all producers and the single consumer.
///
/// `#[repr(C)]` fixes field order — presumably so the cache-padded `tail`
/// and `head` counters land on distinct cache lines ahead of the read-only
/// fields; TODO confirm intent.
#[repr(C)]
struct Shared<T> {
    /// Next index a producer will claim; advanced via CAS by producers.
    tail: CachePadded<AtomicUsize>,
    /// Next index the consumer will read; published by the consumer.
    head: CachePadded<AtomicUsize>,
    /// Raw pointer to the slot array of length `capacity`; owned by this
    /// struct and freed in `Drop`.
    slots: *mut Slot<T>,
    /// Number of slots; always a power of two.
    capacity: usize,
    /// `capacity - 1`; maps an index to a slot with a bitwise AND.
    mask: usize,
}
146
// SAFETY: `Shared` exclusively owns the slot allocation behind the raw
// pointer, and all cross-thread access to slot contents is mediated by the
// atomic `turn` counters, so it may be sent and shared whenever `T: Send`.
unsafe impl<T: Send> Send for Shared<T> {}
unsafe impl<T: Send> Sync for Shared<T> {}
151
impl<T> Drop for Shared<T> {
    /// Drops any elements still queued, then frees the slot array that
    /// `bounded` leaked via `Box::into_raw`.
    fn drop(&mut self) {
        // All handles are gone by the time this runs, so relaxed loads are
        // sufficient.
        let head = self.head.load(Ordering::Relaxed);
        let tail = self.tail.load(Ordering::Relaxed);

        // Walk the un-consumed range [head, tail) and drop written values.
        let mut i = head;
        while i != tail {
            let slot = unsafe { &*self.slots.add(i & self.mask) };
            let turn = i / self.capacity;

            // `turn * 2 + 1` marks a slot holding a fully written value.
            if slot.turn.load(Ordering::Relaxed) == turn * 2 + 1 {
                // SAFETY: the "full" turn value guarantees `data` was
                // initialized and never read out.
                unsafe { (*slot.data.get()).assume_init_drop() };
            }
            i = i.wrapping_add(1);
        }

        // SAFETY: `slots` came from `Box::into_raw` on a boxed slice of
        // exactly `capacity` elements; reconstructing the box frees it.
        unsafe {
            let _ = Box::from_raw(std::ptr::slice_from_raw_parts_mut(
                self.slots,
                self.capacity,
            ));
        }
    }
}
180
/// Sending half of the queue; `Clone`, so multiple producers may push.
///
/// `#[repr(C)]` fixes field order — presumably to keep the hot per-producer
/// fields ahead of the `Arc`; TODO confirm intent.
#[repr(C)]
pub struct Producer<T> {
    /// Local cache of `shared.head`; refreshed only when the queue looks
    /// full, avoiding a load of the consumer's cache line on every push.
    cached_head: usize,
    /// Copy of `shared.slots`, kept inline to skip the `Arc` dereference.
    slots: *mut Slot<T>,
    /// Copy of `shared.mask` (`capacity - 1`).
    mask: usize,
    /// Copy of `shared.capacity`.
    capacity: usize,
    /// Keeps the shared state (and the slot allocation) alive.
    shared: Arc<Shared<T>>,
}
198
199impl<T> Clone for Producer<T> {
200 fn clone(&self) -> Self {
201 Producer {
202 cached_head: self.shared.head.load(Ordering::Relaxed),
204 slots: self.slots,
205 mask: self.mask,
206 capacity: self.capacity,
207 shared: Arc::clone(&self.shared),
208 }
209 }
210}
211
212unsafe impl<T: Send> Send for Producer<T> {}
215
impl<T> Producer<T> {
    /// Attempts to push `value`, returning `Err(Full(value))` when the
    /// queue is full.
    ///
    /// Lock-free with respect to other producers: a CAS on `tail` claims a
    /// slot index, then the slot's `turn` counter hands the value off to
    /// the consumer.
    #[inline]
    #[must_use = "push returns Err if full, which should be handled"]
    pub fn push(&mut self, value: T) -> Result<(), Full<T>> {
        let mut backoff = Backoff::new();

        loop {
            let tail = self.shared.tail.load(Ordering::Relaxed);

            // Fast full-check against the locally cached head; only touch
            // the consumer-written `head` line when the cache says "full".
            if tail.wrapping_sub(self.cached_head) >= self.capacity {
                self.cached_head = self.shared.head.load(Ordering::Acquire);

                // Still full with a fresh head: genuinely out of space.
                if tail.wrapping_sub(self.cached_head) >= self.capacity {
                    return Err(Full(value));
                }
            }

            // Claim index `tail`. Relaxed suffices here: the slot's `turn`
            // counter (Acquire/Release below) performs the data hand-off.
            if self
                .shared
                .tail
                .compare_exchange_weak(
                    tail,
                    tail.wrapping_add(1),
                    Ordering::Relaxed,
                    Ordering::Relaxed,
                )
                .is_ok()
            {
                let slot = unsafe { &*self.slots.add(tail & self.mask) };
                let turn = tail / self.capacity;

                // Wait until the consumer is done with this slot's previous
                // lap (`turn * 2` == empty/writable for this lap).
                while slot.turn.load(Ordering::Acquire) != turn * 2 {
                    std::hint::spin_loop();
                }

                // SAFETY: the successful CAS gives exclusive write access
                // to this slot for this lap, and the turn check above says
                // the slot is empty.
                unsafe { (*slot.data.get()).write(value) };

                // Publish the value: `turn * 2 + 1` == full for this lap.
                slot.turn.store(turn * 2 + 1, Ordering::Release);

                return Ok(());
            }

            // Lost the CAS race to another producer; back off and retry.
            backoff.spin();
        }
    }

    /// Returns the queue's capacity (the requested size rounded up to a
    /// power of two).
    #[inline]
    pub fn capacity(&self) -> usize {
        self.capacity
    }

    /// Returns `true` once this is the only remaining handle on the queue.
    ///
    /// NOTE(review): while more than one producer is alive, a dropped
    /// consumer is NOT detected (strong count stays above 1) — confirm this
    /// is the intended semantics for multi-producer use.
    #[inline]
    pub fn is_disconnected(&self) -> bool {
        Arc::strong_count(&self.shared) == 1
    }
}
293
294impl<T> fmt::Debug for Producer<T> {
295 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
296 f.debug_struct("Producer")
297 .field("capacity", &self.capacity())
298 .finish_non_exhaustive()
299 }
300}
301
/// Receiving half of the queue; exactly one exists per queue (not `Clone`).
pub struct Consumer<T> {
    /// The consumer's private head index — the authoritative value,
    /// mirrored into `shared.head` after each pop so producers can observe
    /// progress.
    local_head: usize,
    /// Keeps the shared state (and the slot allocation) alive.
    shared: Arc<Shared<T>>,
}
310
311unsafe impl<T: Send> Send for Consumer<T> {}
314
impl<T> Consumer<T> {
    /// Pops the next element, or returns `None` if the queue is empty.
    #[inline]
    pub fn pop(&mut self) -> Option<T> {
        let head = self.local_head;
        let slot = unsafe { &*self.shared.slots.add(head & self.shared.mask) };
        let turn = head / self.shared.capacity;

        // `turn * 2 + 1` means "full for this lap"; anything else means no
        // producer has published a value at this index yet.
        if slot.turn.load(Ordering::Acquire) != turn * 2 + 1 {
            return None;
        }

        // SAFETY: the Acquire load above synchronizes with the producer's
        // Release store after its write, so `data` holds a fully
        // initialized value, and it is moved out exactly once.
        let value = unsafe { (*slot.data.get()).assume_init_read() };

        // Mark the slot empty for the NEXT lap: (turn + 1) * 2.
        slot.turn.store((turn + 1) * 2, Ordering::Release);

        // Publish consumer progress so producers' full-checks can refresh.
        self.local_head = head.wrapping_add(1);
        self.shared.head.store(self.local_head, Ordering::Release);

        Some(value)
    }

    /// Returns the queue's capacity (the requested size rounded up to a
    /// power of two).
    #[inline]
    pub fn capacity(&self) -> usize {
        self.shared.capacity
    }

    /// Returns `true` once every producer has been dropped.
    #[inline]
    pub fn is_disconnected(&self) -> bool {
        Arc::strong_count(&self.shared) == 1
    }
}
355
356impl<T> fmt::Debug for Consumer<T> {
357 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
358 f.debug_struct("Consumer")
359 .field("capacity", &self.capacity())
360 .finish_non_exhaustive()
361 }
362}
363
/// Truncated exponential backoff for CAS retry loops.
struct Backoff {
    /// Exponent of the next spin burst; saturates at `MAX_STEP`.
    step: u32,
}

impl Backoff {
    /// Largest exponent used: bursts are capped at 2^6 = 64 spin hints.
    const MAX_STEP: u32 = 6;

    /// Starts at the smallest burst (a single spin hint).
    #[inline]
    fn new() -> Self {
        Self { step: 0 }
    }

    /// Spins for `2^step` iterations, then escalates. Saturating at
    /// `MAX_STEP` (instead of incrementing `step` unboundedly) avoids a
    /// debug-mode overflow panic if a caller backs off for a very long
    /// time, while keeping the burst length identical to before.
    #[inline]
    fn spin(&mut self) {
        for _ in 0..(1u32 << self.step) {
            std::hint::spin_loop();
        }
        self.step = (self.step + 1).min(Self::MAX_STEP);
    }
}
383
#[cfg(test)]
mod tests {
    use super::*;

    /// FIFO order is preserved for a single producer.
    #[test]
    fn basic_push_pop() {
        let (mut tx, mut rx) = bounded::<u64>(4);

        assert!(tx.push(1).is_ok());
        assert!(tx.push(2).is_ok());
        assert!(tx.push(3).is_ok());

        assert_eq!(rx.pop(), Some(1));
        assert_eq!(rx.pop(), Some(2));
        assert_eq!(rx.pop(), Some(3));
        assert_eq!(rx.pop(), None);
    }

    #[test]
    fn empty_pop_returns_none() {
        let (_, mut rx) = bounded::<u64>(4);
        assert_eq!(rx.pop(), None);
        assert_eq!(rx.pop(), None);
    }

    #[test]
    fn fill_then_drain() {
        let (mut tx, mut rx) = bounded::<u64>(4);

        for i in 0..4 {
            assert!(tx.push(i).is_ok());
        }

        for i in 0..4 {
            assert_eq!(rx.pop(), Some(i));
        }

        assert_eq!(rx.pop(), None);
    }

    #[test]
    fn push_returns_error_when_full() {
        let (mut tx, _rx) = bounded::<u64>(4);

        assert!(tx.push(1).is_ok());
        assert!(tx.push(2).is_ok());
        assert!(tx.push(3).is_ok());
        assert!(tx.push(4).is_ok());

        // The rejected value is handed back inside the error.
        let err = tx.push(5).unwrap_err();
        assert_eq!(err.into_inner(), 5);
    }

    /// Exercises index wrap-around across many laps of a small ring.
    #[test]
    fn interleaved_single_producer() {
        let (mut tx, mut rx) = bounded::<u64>(8);

        for i in 0..1000 {
            assert!(tx.push(i).is_ok());
            assert_eq!(rx.pop(), Some(i));
        }
    }

    #[test]
    fn partial_fill_drain_cycles() {
        let (mut tx, mut rx) = bounded::<u64>(8);

        for round in 0..100 {
            for i in 0..4 {
                assert!(tx.push(round * 4 + i).is_ok());
            }

            for i in 0..4 {
                assert_eq!(rx.pop(), Some(round * 4 + i));
            }
        }
    }

    /// Two producer threads; every pushed value must arrive exactly once.
    #[test]
    fn two_producers_single_consumer() {
        use std::thread;

        let (mut tx, mut rx) = bounded::<u64>(64);
        let mut tx2 = tx.clone();

        let h1 = thread::spawn(move || {
            for i in 0..1000 {
                while tx.push(i).is_err() {
                    std::hint::spin_loop();
                }
            }
        });

        let h2 = thread::spawn(move || {
            for i in 1000..2000 {
                while tx2.push(i).is_err() {
                    std::hint::spin_loop();
                }
            }
        });

        let mut received = Vec::new();
        while received.len() < 2000 {
            if let Some(val) = rx.pop() {
                received.push(val);
            } else {
                std::hint::spin_loop();
            }
        }

        h1.join().unwrap();
        h2.join().unwrap();

        received.sort_unstable();
        assert_eq!(received, (0..2000).collect::<Vec<_>>());
    }

    #[test]
    fn four_producers_single_consumer() {
        use std::thread;

        let (tx, mut rx) = bounded::<u64>(256);

        let handles: Vec<_> = (0..4)
            .map(|p| {
                let mut tx = tx.clone();
                thread::spawn(move || {
                    for i in 0..1000 {
                        let val = p * 1000 + i;
                        while tx.push(val).is_err() {
                            std::hint::spin_loop();
                        }
                    }
                })
            })
            .collect();

        // Drop the original handle so only the spawned producers remain.
        drop(tx);

        let mut received = Vec::new();
        while received.len() < 4000 {
            if let Some(val) = rx.pop() {
                received.push(val);
            } else {
                std::hint::spin_loop();
            }
        }

        for h in handles {
            h.join().unwrap();
        }

        received.sort_unstable();
        let mut expected: Vec<u64> = (0..4)
            .flat_map(|p| (0..1000).map(move |i| p * 1000 + i))
            .collect();
        expected.sort_unstable();
        assert_eq!(received, expected);
    }

    /// Capacity 1 is the smallest legal ring; full/empty alternate.
    #[test]
    fn single_slot_bounded() {
        let (mut tx, mut rx) = bounded::<u64>(1);

        assert!(tx.push(1).is_ok());
        assert!(tx.push(2).is_err());

        assert_eq!(rx.pop(), Some(1));
        assert!(tx.push(2).is_ok());
    }

    #[test]
    fn producer_disconnected() {
        let (tx, rx) = bounded::<u64>(4);

        assert!(!rx.is_disconnected());
        drop(tx);
        assert!(rx.is_disconnected());
    }

    #[test]
    fn consumer_disconnected() {
        let (tx, rx) = bounded::<u64>(4);

        assert!(!tx.is_disconnected());
        drop(rx);
        assert!(tx.is_disconnected());
    }

    #[test]
    fn multiple_producers_one_disconnects() {
        let (tx1, rx) = bounded::<u64>(4);
        let tx2 = tx1.clone();

        assert!(!rx.is_disconnected());
        drop(tx1);
        assert!(!rx.is_disconnected());
        drop(tx2);
        assert!(rx.is_disconnected());
    }

    /// Values still in the ring when both halves drop must be dropped too.
    #[test]
    fn drop_cleans_up_remaining() {
        use std::sync::atomic::AtomicUsize;

        static DROP_COUNT: AtomicUsize = AtomicUsize::new(0);

        struct DropCounter;
        impl Drop for DropCounter {
            fn drop(&mut self) {
                DROP_COUNT.fetch_add(1, Ordering::SeqCst);
            }
        }

        DROP_COUNT.store(0, Ordering::SeqCst);

        let (mut tx, rx) = bounded::<DropCounter>(4);

        assert!(tx.push(DropCounter).is_ok());
        assert!(tx.push(DropCounter).is_ok());
        assert!(tx.push(DropCounter).is_ok());

        assert_eq!(DROP_COUNT.load(Ordering::SeqCst), 0);

        drop(tx);
        drop(rx);

        assert_eq!(DROP_COUNT.load(Ordering::SeqCst), 3);
    }

    #[test]
    fn zero_sized_type() {
        let (mut tx, mut rx) = bounded::<()>(8);

        assert!(tx.push(()).is_ok());
        assert!(tx.push(()).is_ok());

        assert_eq!(rx.pop(), Some(()));
        assert_eq!(rx.pop(), Some(()));
        assert_eq!(rx.pop(), None);
    }

    #[test]
    fn string_type() {
        let (mut tx, mut rx) = bounded::<String>(4);

        assert!(tx.push("hello".to_string()).is_ok());
        assert!(tx.push("world".to_string()).is_ok());

        assert_eq!(rx.pop(), Some("hello".to_string()));
        assert_eq!(rx.pop(), Some("world".to_string()));
    }

    #[test]
    #[should_panic(expected = "capacity must be non-zero")]
    fn zero_capacity_panics() {
        let _ = bounded::<u64>(0);
    }

    #[test]
    fn large_message_type() {
        #[repr(C, align(64))]
        struct LargeMessage {
            data: [u8; 256],
        }

        let (mut tx, mut rx) = bounded::<LargeMessage>(8);

        let msg = LargeMessage { data: [42u8; 256] };
        assert!(tx.push(msg).is_ok());

        let received = rx.pop().unwrap();
        assert_eq!(received.data[0], 42);
        assert_eq!(received.data[255], 42);
    }

    /// Ten full laps of a 4-slot ring exercise the turn counter rollover.
    #[test]
    fn multiple_laps() {
        let (mut tx, mut rx) = bounded::<u64>(4);

        for i in 0..40 {
            assert!(tx.push(i).is_ok());
            assert_eq!(rx.pop(), Some(i));
        }
    }

    #[test]
    fn capacity_rounds_to_power_of_two() {
        let (tx, _) = bounded::<u64>(100);
        assert_eq!(tx.capacity(), 128);

        let (tx, _) = bounded::<u64>(1000);
        assert_eq!(tx.capacity(), 1024);
    }

    /// High-volume single-producer run; checksum verifies no loss/dupes.
    #[test]
    fn stress_single_producer() {
        use std::thread;

        const COUNT: u64 = 100_000;

        let (mut tx, mut rx) = bounded::<u64>(1024);

        let producer = thread::spawn(move || {
            for i in 0..COUNT {
                while tx.push(i).is_err() {
                    std::hint::spin_loop();
                }
            }
        });

        let consumer = thread::spawn(move || {
            let mut sum = 0u64;
            let mut received = 0u64;
            while received < COUNT {
                if let Some(val) = rx.pop() {
                    sum = sum.wrapping_add(val);
                    received += 1;
                } else {
                    std::hint::spin_loop();
                }
            }
            sum
        });

        producer.join().unwrap();
        let sum = consumer.join().unwrap();
        assert_eq!(sum, COUNT * (COUNT - 1) / 2);
    }

    #[test]
    fn stress_multiple_producers() {
        use std::thread;

        const PRODUCERS: u64 = 4;
        const PER_PRODUCER: u64 = 25_000;
        const TOTAL: u64 = PRODUCERS * PER_PRODUCER;

        let (tx, mut rx) = bounded::<u64>(1024);

        let handles: Vec<_> = (0..PRODUCERS)
            .map(|_| {
                let mut tx = tx.clone();
                thread::spawn(move || {
                    for i in 0..PER_PRODUCER {
                        while tx.push(i).is_err() {
                            std::hint::spin_loop();
                        }
                    }
                })
            })
            .collect();

        drop(tx);

        let mut received = 0u64;
        while received < TOTAL {
            if rx.pop().is_some() {
                received += 1;
            } else {
                std::hint::spin_loop();
            }
        }

        for h in handles {
            h.join().unwrap();
        }

        assert_eq!(received, TOTAL);
    }
}