1use std::cell::UnsafeCell;
67use std::fmt;
68use std::mem::MaybeUninit;
69use std::sync::Arc;
70use std::sync::atomic::{AtomicUsize, Ordering};
71
72use crossbeam_utils::CachePadded;
73
74use crate::Full;
75
76pub fn bounded<T>(capacity: usize) -> (Producer<T>, Consumer<T>) {
84 assert!(capacity > 0, "capacity must be non-zero");
85
86 let capacity = capacity.next_power_of_two();
87 let mask = capacity - 1;
88
89 let slots: Vec<Slot<T>> = (0..capacity)
91 .map(|_| Slot {
92 turn: AtomicUsize::new(0),
93 data: UnsafeCell::new(MaybeUninit::uninit()),
94 })
95 .collect();
96 let slots = Box::into_raw(slots.into_boxed_slice()) as *mut Slot<T>;
97
98 let shift = capacity.trailing_zeros();
99
100 let shared = Arc::new(Shared {
101 tail: CachePadded::new(AtomicUsize::new(0)),
102 head: CachePadded::new(AtomicUsize::new(0)),
103 slots,
104 capacity,
105 shift,
106 mask,
107 });
108
109 (
110 Producer {
111 cached_head: 0,
112 slots,
113 mask,
114 capacity,
115 shift,
116 shared: Arc::clone(&shared),
117 },
118 Consumer {
119 local_head: 0,
120 slots,
121 mask,
122 shift,
123 shared,
124 },
125 )
126}
127
/// One ring cell plus its sequence stamp.
///
/// Stamp protocol for the cell at index `i` on lap `turn = i >> shift`
/// (as used by `push`, `pop`, and `Shared::drop`):
///   - `turn * 2`       — empty, ready for a producer to claim
///   - `turn * 2 + 1`   — written, ready for the consumer
///   - `(turn + 1) * 2` — consumed, i.e. empty for the next lap
struct Slot<T> {
    turn: AtomicUsize,
    // Value storage; initialization state is tracked by `turn`, hence
    // the MaybeUninit wrapper.
    data: UnsafeCell<MaybeUninit<T>>,
}

/// State shared by all handles; freed when the last `Arc` drops.
#[repr(C)]
struct Shared<T> {
    // Next index producers will claim; monotonically increasing with
    // usize wraparound. Cache-padded against `head` (false sharing).
    tail: CachePadded<AtomicUsize>,
    // Next index the consumer will read; published by `pop`.
    head: CachePadded<AtomicUsize>,
    // Heap array of `capacity` slots (boxed slice leaked in `bounded`).
    slots: *mut Slot<T>,
    capacity: usize,
    // log2(capacity): `index >> shift` is the lap number.
    shift: u32,
    // capacity - 1: `index & mask` is the ring position.
    mask: usize,
}

// SAFETY: `Shared` owns the slot allocation and any `T`s queued in it,
// so transferring it across threads is sound when `T: Send`.
unsafe impl<T: Send> Send for Shared<T> {}
// SAFETY: all concurrent access goes through the atomics (`head`,
// `tail`, each slot's `turn`); slot data is only touched by the handle
// that the stamp/CAS protocol grants exclusive access, so sharing
// `&Shared` across threads is sound when `T: Send`.
unsafe impl<T: Send> Sync for Shared<T> {}
160
impl<T> Drop for Shared<T> {
    /// Drops any values still queued, then frees the slot array.
    ///
    /// Runs only once the last handle is gone, so `&mut self` implies
    /// exclusive access; Relaxed loads suffice because the `Arc`
    /// refcount drop already provides the needed synchronization.
    fn drop(&mut self) {
        let head = self.head.load(Ordering::Relaxed);
        let tail = self.tail.load(Ordering::Relaxed);

        // Walk the unconsumed range [head, tail).
        let mut i = head;
        while i != tail {
            let slot = unsafe { &*self.slots.add(i & self.mask) };
            let turn = i >> self.shift;

            // Only slots stamped "written" (turn * 2 + 1) hold a live
            // value that still needs dropping.
            if slot.turn.load(Ordering::Relaxed) == turn * 2 + 1 {
                // SAFETY: the stamp proves a producer completed its
                // write and the consumer never read this index.
                unsafe { (*slot.data.get()).assume_init_drop() };
            }
            i = i.wrapping_add(1);
        }

        // SAFETY: reconstructs the boxed slice leaked in `bounded`
        // with the same pointer and length, exactly once.
        unsafe {
            let _ = Box::from_raw(std::ptr::slice_from_raw_parts_mut(
                self.slots,
                self.capacity,
            ));
        }
    }
}
189
/// Producing handle of the queue. Cloneable, so multiple producers may
/// push concurrently.
#[repr(C)]
pub struct Producer<T> {
    // Last observed consumer head; refreshed from `shared.head` only
    // when the queue looks full, keeping the push fast path local.
    cached_head: usize,
    // Raw copies of the shared fields so the hot path avoids
    // dereferencing the Arc.
    slots: *mut Slot<T>,
    mask: usize,
    capacity: usize,
    shift: u32,
    // Keeps the shared state (and the slot allocation) alive.
    shared: Arc<Shared<T>>,
}
209
210impl<T> Clone for Producer<T> {
211 fn clone(&self) -> Self {
212 Producer {
213 cached_head: self.shared.head.load(Ordering::Relaxed),
215 slots: self.slots,
216 mask: self.mask,
217 capacity: self.capacity,
218 shift: self.shift,
219 shared: Arc::clone(&self.shared),
220 }
221 }
222}
223
// SAFETY: a Producer touches shared state only through atomics and
// through slots it has exclusively claimed via CAS on `tail`, so it
// may move to another thread when `T: Send`.
unsafe impl<T: Send> Send for Producer<T> {}
227
228impl<T> Producer<T> {
229 #[inline]
237 #[must_use = "push returns Err if full, which should be handled"]
238 pub fn push(&mut self, value: T) -> Result<(), Full<T>> {
239 let mut spin_count = 0u32;
240
241 loop {
242 let tail = self.shared.tail.load(Ordering::Relaxed);
243
244 if tail.wrapping_sub(self.cached_head) >= self.capacity {
246 self.cached_head = self.shared.head.load(Ordering::Acquire);
248
249 if tail.wrapping_sub(self.cached_head) >= self.capacity {
251 return Err(Full(value));
252 }
253 }
254
255 let slot = unsafe { &*self.slots.add(tail & self.mask) };
257 let turn = tail >> self.shift;
258 let expected_stamp = turn * 2;
259
260 let stamp = slot.turn.load(Ordering::Acquire);
262
263 if stamp == expected_stamp {
264 if self
266 .shared
267 .tail
268 .compare_exchange_weak(
269 tail,
270 tail.wrapping_add(1),
271 Ordering::Relaxed,
272 Ordering::Relaxed,
273 )
274 .is_ok()
275 {
276 unsafe { (*slot.data.get()).write(value) };
278
279 slot.turn.store(turn * 2 + 1, Ordering::Release);
281
282 return Ok(());
283 }
284 }
285
286 let spins = 1 << spin_count.min(6);
289 for _ in 0..spins {
290 std::hint::spin_loop();
291 }
292 spin_count += 1;
293 }
294 }
295
296 #[inline]
298 pub fn capacity(&self) -> usize {
299 1 << self.shift
300 }
301
302 #[inline]
307 pub fn is_disconnected(&self) -> bool {
308 Arc::strong_count(&self.shared) == 1
309 }
310}
311
312impl<T> fmt::Debug for Producer<T> {
313 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
314 f.debug_struct("Producer")
315 .field("capacity", &self.capacity())
316 .finish_non_exhaustive()
317 }
318}
319
/// Consuming handle of the queue. There is exactly one: no `Clone`
/// impl is provided here, and `pop` relies on exclusive ownership of
/// `local_head`.
#[repr(C)]
pub struct Consumer<T> {
    // Next index to read; mirrored into `shared.head` after each pop
    // so producers can refresh their cached view.
    local_head: usize,
    // Raw copies of the shared fields to keep the hot path Arc-free.
    slots: *mut Slot<T>,
    mask: usize,
    shift: u32,
    // Keeps the shared state (and the slot allocation) alive.
    shared: Arc<Shared<T>>,
}
336
// SAFETY: the Consumer reads a slot only after its stamp proves the
// producer's write completed (Acquire/Release pairing in pop/push), so
// moving it to another thread is sound when `T: Send`.
unsafe impl<T: Send> Send for Consumer<T> {}
340
impl<T> Consumer<T> {
    /// Pops the oldest value, or returns `None` if the queue is empty.
    #[inline]
    pub fn pop(&mut self) -> Option<T> {
        let head = self.local_head;
        let slot = unsafe { &*self.slots.add(head & self.mask) };
        let turn = head >> self.shift;

        // A written slot on this lap carries stamp `turn * 2 + 1`;
        // Acquire pairs with the producer's Release store, making the
        // written value visible. Any other stamp means "not ready".
        if slot.turn.load(Ordering::Acquire) != turn * 2 + 1 {
            return None;
        }

        // SAFETY: the stamp check above proves a producer completed
        // its write, and the single consumer owns the read.
        let value = unsafe { (*slot.data.get()).assume_init_read() };

        // Recycle the slot for the next lap; Release pairs with the
        // producer's Acquire load of the stamp before overwriting.
        slot.turn.store((turn + 1) * 2, Ordering::Release);

        self.local_head = head.wrapping_add(1);
        // Publish the new head so producers can refresh their cached
        // fullness check.
        self.shared.head.store(self.local_head, Ordering::Release);

        Some(value)
    }

    /// Returns the queue capacity (always a power of two).
    #[inline]
    pub fn capacity(&self) -> usize {
        1 << self.shift
    }

    /// Returns `true` when every producer has been dropped, leaving
    /// this consumer as the sole owner of the shared state.
    #[inline]
    pub fn is_disconnected(&self) -> bool {
        Arc::strong_count(&self.shared) == 1
    }
}
382
383impl<T> fmt::Debug for Consumer<T> {
384 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
385 f.debug_struct("Consumer")
386 .field("capacity", &self.capacity())
387 .finish_non_exhaustive()
388 }
389}
390
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn basic_push_pop() {
        let (mut tx, mut rx) = bounded::<u64>(4);

        assert!(tx.push(1).is_ok());
        assert!(tx.push(2).is_ok());
        assert!(tx.push(3).is_ok());

        assert_eq!(rx.pop(), Some(1));
        assert_eq!(rx.pop(), Some(2));
        assert_eq!(rx.pop(), Some(3));
        assert_eq!(rx.pop(), None);
    }

    #[test]
    fn empty_pop_returns_none() {
        let (_, mut rx) = bounded::<u64>(4);
        assert_eq!(rx.pop(), None);
        assert_eq!(rx.pop(), None);
    }

    #[test]
    fn fill_then_drain() {
        let (mut tx, mut rx) = bounded::<u64>(4);

        for i in 0..4 {
            assert!(tx.push(i).is_ok());
        }

        for i in 0..4 {
            assert_eq!(rx.pop(), Some(i));
        }

        assert_eq!(rx.pop(), None);
    }

    #[test]
    fn push_returns_error_when_full() {
        let (mut tx, _rx) = bounded::<u64>(4);

        assert!(tx.push(1).is_ok());
        assert!(tx.push(2).is_ok());
        assert!(tx.push(3).is_ok());
        assert!(tx.push(4).is_ok());

        // The rejected value is handed back inside the error.
        let err = tx.push(5).unwrap_err();
        assert_eq!(err.into_inner(), 5);
    }

    #[test]
    fn interleaved_single_producer() {
        let (mut tx, mut rx) = bounded::<u64>(8);

        for i in 0..1000 {
            assert!(tx.push(i).is_ok());
            assert_eq!(rx.pop(), Some(i));
        }
    }

    #[test]
    fn partial_fill_drain_cycles() {
        let (mut tx, mut rx) = bounded::<u64>(8);

        for round in 0..100 {
            for i in 0..4 {
                assert!(tx.push(round * 4 + i).is_ok());
            }

            for i in 0..4 {
                assert_eq!(rx.pop(), Some(round * 4 + i));
            }
        }
    }

    #[test]
    fn two_producers_single_consumer() {
        use std::thread;

        let (mut tx, mut rx) = bounded::<u64>(64);
        let mut tx2 = tx.clone();

        let h1 = thread::spawn(move || {
            for i in 0..1000 {
                while tx.push(i).is_err() {
                    std::hint::spin_loop();
                }
            }
        });

        let h2 = thread::spawn(move || {
            for i in 1000..2000 {
                while tx2.push(i).is_err() {
                    std::hint::spin_loop();
                }
            }
        });

        let mut received = Vec::new();
        while received.len() < 2000 {
            if let Some(val) = rx.pop() {
                received.push(val);
            } else {
                std::hint::spin_loop();
            }
        }

        h1.join().unwrap();
        h2.join().unwrap();

        // Arrival order across producers is unspecified; sort first.
        received.sort();
        assert_eq!(received, (0..2000).collect::<Vec<_>>());
    }

    #[test]
    fn four_producers_single_consumer() {
        use std::thread;

        let (tx, mut rx) = bounded::<u64>(256);

        let handles: Vec<_> = (0..4)
            .map(|p| {
                let mut tx = tx.clone();
                thread::spawn(move || {
                    for i in 0..1000 {
                        let val = p * 1000 + i;
                        while tx.push(val).is_err() {
                            std::hint::spin_loop();
                        }
                    }
                })
            })
            .collect();

        // Drop the original handle so only the spawned producers hold
        // the queue open.
        drop(tx);

        let mut received = Vec::new();
        while received.len() < 4000 {
            if let Some(val) = rx.pop() {
                received.push(val);
            } else {
                std::hint::spin_loop();
            }
        }

        for h in handles {
            h.join().unwrap();
        }

        // Producers 0..4 each send p * 1000 + (0..1000), which covers
        // 0..4000 exactly once.
        received.sort();
        assert_eq!(received, (0..4000).collect::<Vec<u64>>());
    }

    #[test]
    fn single_slot_bounded() {
        let (mut tx, mut rx) = bounded::<u64>(1);

        assert!(tx.push(1).is_ok());
        assert!(tx.push(2).is_err());

        assert_eq!(rx.pop(), Some(1));
        assert!(tx.push(2).is_ok());
    }

    #[test]
    fn producer_disconnected() {
        let (tx, rx) = bounded::<u64>(4);

        assert!(!rx.is_disconnected());
        drop(tx);
        assert!(rx.is_disconnected());
    }

    #[test]
    fn consumer_disconnected() {
        let (tx, rx) = bounded::<u64>(4);

        assert!(!tx.is_disconnected());
        drop(rx);
        assert!(tx.is_disconnected());
    }

    #[test]
    fn multiple_producers_one_disconnects() {
        let (tx1, rx) = bounded::<u64>(4);
        let tx2 = tx1.clone();

        assert!(!rx.is_disconnected());
        drop(tx1);
        // One producer remains, so the consumer is still connected.
        assert!(!rx.is_disconnected());
        drop(tx2);
        assert!(rx.is_disconnected());
    }

    #[test]
    fn drop_cleans_up_remaining() {
        use std::sync::atomic::AtomicUsize;

        static DROP_COUNT: AtomicUsize = AtomicUsize::new(0);

        struct DropCounter;
        impl Drop for DropCounter {
            fn drop(&mut self) {
                DROP_COUNT.fetch_add(1, Ordering::SeqCst);
            }
        }

        DROP_COUNT.store(0, Ordering::SeqCst);

        let (mut tx, rx) = bounded::<DropCounter>(4);

        let _ = tx.push(DropCounter);
        let _ = tx.push(DropCounter);
        let _ = tx.push(DropCounter);

        // Nothing dropped while the queue still owns the values.
        assert_eq!(DROP_COUNT.load(Ordering::SeqCst), 0);

        drop(tx);
        drop(rx);

        // Dropping the last handle frees the three queued values.
        assert_eq!(DROP_COUNT.load(Ordering::SeqCst), 3);
    }

    #[test]
    fn zero_sized_type() {
        let (mut tx, mut rx) = bounded::<()>(8);

        let _ = tx.push(());
        let _ = tx.push(());

        assert_eq!(rx.pop(), Some(()));
        assert_eq!(rx.pop(), Some(()));
        assert_eq!(rx.pop(), None);
    }

    #[test]
    fn string_type() {
        let (mut tx, mut rx) = bounded::<String>(4);

        let _ = tx.push("hello".to_string());
        let _ = tx.push("world".to_string());

        assert_eq!(rx.pop(), Some("hello".to_string()));
        assert_eq!(rx.pop(), Some("world".to_string()));
    }

    #[test]
    #[should_panic(expected = "capacity must be non-zero")]
    fn zero_capacity_panics() {
        let _ = bounded::<u64>(0);
    }

    #[test]
    fn large_message_type() {
        #[repr(C, align(64))]
        struct LargeMessage {
            data: [u8; 256],
        }

        let (mut tx, mut rx) = bounded::<LargeMessage>(8);

        let msg = LargeMessage { data: [42u8; 256] };
        assert!(tx.push(msg).is_ok());

        let received = rx.pop().unwrap();
        assert_eq!(received.data[0], 42);
        assert_eq!(received.data[255], 42);
    }

    #[test]
    fn multiple_laps() {
        let (mut tx, mut rx) = bounded::<u64>(4);

        // 40 items through a 4-slot queue exercises 10 full laps of
        // the turn-stamp cycle.
        for i in 0..40 {
            assert!(tx.push(i).is_ok());
            assert_eq!(rx.pop(), Some(i));
        }
    }

    #[test]
    fn capacity_rounds_to_power_of_two() {
        let (tx, _) = bounded::<u64>(100);
        assert_eq!(tx.capacity(), 128);

        let (tx, _) = bounded::<u64>(1000);
        assert_eq!(tx.capacity(), 1024);
    }

    #[test]
    fn stress_single_producer() {
        use std::thread;

        const COUNT: u64 = 100_000;

        let (mut tx, mut rx) = bounded::<u64>(1024);

        let producer = thread::spawn(move || {
            for i in 0..COUNT {
                while tx.push(i).is_err() {
                    std::hint::spin_loop();
                }
            }
        });

        let consumer = thread::spawn(move || {
            let mut sum = 0u64;
            let mut received = 0u64;
            while received < COUNT {
                if let Some(val) = rx.pop() {
                    sum = sum.wrapping_add(val);
                    received += 1;
                } else {
                    std::hint::spin_loop();
                }
            }
            sum
        });

        producer.join().unwrap();
        let sum = consumer.join().unwrap();
        // Sum of 0..COUNT.
        assert_eq!(sum, COUNT * (COUNT - 1) / 2);
    }

    #[test]
    fn stress_multiple_producers() {
        use std::thread;

        const PRODUCERS: u64 = 4;
        const PER_PRODUCER: u64 = 25_000;
        const TOTAL: u64 = PRODUCERS * PER_PRODUCER;

        let (tx, mut rx) = bounded::<u64>(1024);

        let handles: Vec<_> = (0..PRODUCERS)
            .map(|_| {
                let mut tx = tx.clone();
                thread::spawn(move || {
                    for i in 0..PER_PRODUCER {
                        while tx.push(i).is_err() {
                            std::hint::spin_loop();
                        }
                    }
                })
            })
            .collect();

        drop(tx);

        let mut received = 0u64;
        while received < TOTAL {
            if rx.pop().is_some() {
                received += 1;
            } else {
                std::hint::spin_loop();
            }
        }

        for h in handles {
            h.join().unwrap();
        }

        assert_eq!(received, TOTAL);
    }
}