1use std::cell::{Cell, UnsafeCell};
67use std::fmt;
68use std::mem::MaybeUninit;
69use std::sync::Arc;
70use std::sync::atomic::{AtomicUsize, Ordering};
71
72use crossbeam_utils::CachePadded;
73
74use crate::Full;
75
76pub fn bounded<T>(capacity: usize) -> (Producer<T>, Consumer<T>) {
84 assert!(capacity > 0, "capacity must be non-zero");
85
86 let capacity = capacity.next_power_of_two();
87 let mask = capacity - 1;
88
89 let slots: Vec<Slot<T>> = (0..capacity)
91 .map(|_| Slot {
92 turn: AtomicUsize::new(0),
93 data: UnsafeCell::new(MaybeUninit::uninit()),
94 })
95 .collect();
96 let slots = Box::into_raw(slots.into_boxed_slice()) as *mut Slot<T>;
97
98 let shift = capacity.trailing_zeros();
99
100 let shared = Arc::new(Shared {
101 tail: CachePadded::new(AtomicUsize::new(0)),
102 head: CachePadded::new(AtomicUsize::new(0)),
103 slots,
104 capacity,
105 shift,
106 mask,
107 });
108
109 (
110 Producer {
111 cached_head: Cell::new(0),
112 slots,
113 mask,
114 capacity,
115 shift,
116 shared: Arc::clone(&shared),
117 },
118 Consumer {
119 local_head: Cell::new(0),
120 slots,
121 mask,
122 shift,
123 shared,
124 },
125 )
126}
127
/// One queue cell: a sequence stamp plus (possibly uninitialized) payload.
struct Slot<T> {
    /// Sequence stamp encoding lap and state: `2 * lap` means "empty, ready
    /// for a producer on `lap`"; `2 * lap + 1` means "written on `lap`,
    /// ready for the consumer".
    turn: AtomicUsize,
    /// Payload storage; holds an initialized `T` only while `turn` carries
    /// an odd ("written") stamp.
    data: UnsafeCell<MaybeUninit<T>>,
}
137
/// State shared by all producer handles and the consumer, owned via `Arc`.
#[repr(C)]
struct Shared<T> {
    /// Next index a producer will claim. Cache-padded so producer and
    /// consumer counters sit on separate cache lines.
    tail: CachePadded<AtomicUsize>,
    /// Next index the consumer will read.
    head: CachePadded<AtomicUsize>,
    /// Raw pointer to the slot array of length `capacity`; allocated in
    /// `bounded` and freed exactly once in `Drop`.
    slots: *mut Slot<T>,
    /// Number of slots; always a power of two.
    capacity: usize,
    /// `log2(capacity)`; an index shifted right by this yields its lap.
    shift: u32,
    /// `capacity - 1`; an index ANDed with this yields its slot position.
    mask: usize,
}
155
// SAFETY: `Shared<T>` owns its slot allocation outright and values only move
// in and out of it, so the whole structure may migrate threads when `T: Send`.
unsafe impl<T: Send> Send for Shared<T> {}
// SAFETY: all cross-thread access to the slot array is mediated by the
// per-slot `turn` stamps and the head/tail atomics, so sharing `&Shared<T>`
// between producer and consumer threads is sound when `T: Send`.
unsafe impl<T: Send> Sync for Shared<T> {}
160
impl<T> Drop for Shared<T> {
    /// Drops any values still queued, then frees the slot array.
    ///
    /// This runs only after the last `Arc` handle is gone, so no concurrent
    /// pushes or pops are possible and `Relaxed` loads suffice (the final
    /// reference-count decrement already synchronized).
    fn drop(&mut self) {
        let head = self.head.load(Ordering::Relaxed);
        let tail = self.tail.load(Ordering::Relaxed);

        // Walk every index still logically inside the queue.
        let mut i = head;
        while i != tail {
            let slot = unsafe { &*self.slots.add(i & self.mask) };
            let turn = i >> self.shift;

            // An odd stamp for this lap means the slot holds an initialized,
            // never-consumed value that must be dropped.
            if slot.turn.load(Ordering::Relaxed) == turn * 2 + 1 {
                // SAFETY: the stamp proves the payload was fully written and
                // never read out, so it is initialized here.
                unsafe { (*slot.data.get()).assume_init_drop() };
            }
            i = i.wrapping_add(1);
        }

        // SAFETY: `slots` came from `Box::into_raw` on a boxed slice of
        // exactly `capacity` elements in `bounded`; reconstituting the Box
        // frees that allocation exactly once.
        unsafe {
            let _ = Box::from_raw(std::ptr::slice_from_raw_parts_mut(
                self.slots,
                self.capacity,
            ));
        }
    }
}
189
/// Sending half of the queue; cloneable, so several producers may push
/// concurrently (MPSC).
#[repr(C)]
pub struct Producer<T> {
    // Last `head` value this handle observed; refreshed only when the queue
    // looks full, so the common push avoids touching the consumer's counter.
    cached_head: Cell<usize>,
    // Duplicates of `Shared` fields, set from the same values in `bounded` —
    // presumably cached inline to skip a pointer dereference on the hot
    // path; TODO confirm.
    slots: *mut Slot<T>,
    mask: usize,
    capacity: usize,
    shift: u32,
    // Keeps the shared state (and the slot allocation) alive.
    shared: Arc<Shared<T>>,
}
209
210impl<T> Clone for Producer<T> {
211 fn clone(&self) -> Self {
212 Producer {
213 cached_head: Cell::new(self.shared.head.load(Ordering::Relaxed)),
215 slots: self.slots,
216 mask: self.mask,
217 capacity: self.capacity,
218 shift: self.shift,
219 shared: Arc::clone(&self.shared),
220 }
221 }
222}
223
224unsafe impl<T: Send> Send for Producer<T> {}
227
impl<T> Producer<T> {
    /// Attempts to push `value`, returning `Err(Full(value))` when no slot
    /// is free.
    ///
    /// Lock-free protocol: claim an index via CAS on `tail`, write the
    /// payload, then publish it by storing the odd stamp in the slot's
    /// `turn`. On contention the producer retries with exponential spin
    /// backoff.
    #[inline]
    #[must_use = "push returns Err if full, which should be handled"]
    pub fn push(&self, value: T) -> Result<(), Full<T>> {
        let mut spin_count = 0u32;

        loop {
            let tail = self.shared.tail.load(Ordering::Relaxed);

            // Fast full check against the locally cached head; only read the
            // consumer's head counter when the queue looks full.
            if tail.wrapping_sub(self.cached_head.get()) >= self.capacity {
                self.cached_head.set(self.shared.head.load(Ordering::Acquire));

                // Still full with a fresh head: report it to the caller.
                if tail.wrapping_sub(self.cached_head.get()) >= self.capacity {
                    return Err(Full(value));
                }
            }

            let slot = unsafe { &*self.slots.add(tail & self.mask) };
            let turn = tail >> self.shift;
            // Even stamp `2 * turn` marks the slot empty for this lap.
            let expected_stamp = turn * 2;

            // Acquire pairs with the consumer's Release stamp store, making
            // the slot safe to overwrite once the stamp matches.
            let stamp = slot.turn.load(Ordering::Acquire);

            if stamp == expected_stamp {
                if self
                    .shared
                    .tail
                    .compare_exchange_weak(
                        tail,
                        tail.wrapping_add(1),
                        Ordering::Relaxed,
                        Ordering::Relaxed,
                    )
                    .is_ok()
                {
                    // SAFETY: the successful CAS grants this producer
                    // exclusive ownership of the slot for this lap.
                    unsafe { (*slot.data.get()).write(value) };

                    // Publish: the odd stamp hands the slot to the consumer.
                    slot.turn.store(turn * 2 + 1, Ordering::Release);

                    return Ok(());
                }
            }

            // Lost the race or the slot is not recycled yet: back off with
            // an exponentially growing spin, capped at 2^6 iterations.
            let spins = 1 << spin_count.min(6);
            for _ in 0..spins {
                std::hint::spin_loop();
            }
            spin_count += 1;
        }
    }

    /// Returns the queue capacity — a power of two, possibly rounded up
    /// from the value passed to `bounded`.
    #[inline]
    pub fn capacity(&self) -> usize {
        1 << self.shift
    }

    /// Returns `true` when this handle holds the last reference to the
    /// shared state.
    ///
    /// NOTE(review): with several producer clones alive this stays `false`
    /// even after the consumer is dropped — it detects "everyone else is
    /// gone", not specifically the consumer; confirm callers expect that.
    #[inline]
    pub fn is_disconnected(&self) -> bool {
        Arc::strong_count(&self.shared) == 1
    }
}
311
312impl<T> fmt::Debug for Producer<T> {
313 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
314 f.debug_struct("Producer")
315 .field("capacity", &self.capacity())
316 .finish_non_exhaustive()
317 }
318}
319
/// Receiving half of the queue. Exactly one consumer exists per channel —
/// this type is not `Clone`.
#[repr(C)]
pub struct Consumer<T> {
    // The consumer's authoritative head index; mirrored into `shared.head`
    // after each pop so producers can observe progress.
    local_head: Cell<usize>,
    // Duplicates of `Shared` fields, set from the same values in `bounded` —
    // presumably cached inline for the hot path; TODO confirm.
    slots: *mut Slot<T>,
    mask: usize,
    shift: u32,
    // Keeps the shared state (and the slot allocation) alive.
    shared: Arc<Shared<T>>,
}
336
337unsafe impl<T: Send> Send for Consumer<T> {}
340
impl<T> Consumer<T> {
    /// Pops the next value, or returns `None` if the queue is currently
    /// empty (or the next slot has been claimed but not yet published).
    #[inline]
    pub fn pop(&self) -> Option<T> {
        let head = self.local_head.get();
        let slot = unsafe { &*self.slots.add(head & self.mask) };
        let turn = head >> self.shift;

        // Odd stamp `2 * turn + 1` means a producer finished writing this
        // slot on the current lap; Acquire pairs with that Release store.
        if slot.turn.load(Ordering::Acquire) != turn * 2 + 1 {
            return None;
        }

        // SAFETY: the stamp check above proves the payload is initialized,
        // and no producer touches the slot again until the stamp advances.
        let value = unsafe { (*slot.data.get()).assume_init_read() };

        // Recycle the slot for the next lap before advancing head, so any
        // producer that observes the new head also observes the even stamp.
        slot.turn.store((turn + 1) * 2, Ordering::Release);

        let new_head = head.wrapping_add(1);
        self.local_head.set(new_head);
        self.shared.head.store(new_head, Ordering::Release);

        Some(value)
    }

    /// Returns the queue capacity — a power of two, possibly rounded up
    /// from the value passed to `bounded`.
    #[inline]
    pub fn capacity(&self) -> usize {
        1 << self.shift
    }

    /// Returns `true` once every producer handle has been dropped (the
    /// consumer itself holds the one remaining `Arc` reference).
    #[inline]
    pub fn is_disconnected(&self) -> bool {
        Arc::strong_count(&self.shared) == 1
    }
}
383
384impl<T> fmt::Debug for Consumer<T> {
385 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
386 f.debug_struct("Consumer")
387 .field("capacity", &self.capacity())
388 .finish_non_exhaustive()
389 }
390}
391
#[cfg(test)]
mod tests {
    use super::*;

    /// Values come out in FIFO order; popping an empty queue yields `None`.
    #[test]
    fn basic_push_pop() {
        let (tx, rx) = bounded::<u64>(4);

        assert!(tx.push(1).is_ok());
        assert!(tx.push(2).is_ok());
        assert!(tx.push(3).is_ok());

        assert_eq!(rx.pop(), Some(1));
        assert_eq!(rx.pop(), Some(2));
        assert_eq!(rx.pop(), Some(3));
        assert_eq!(rx.pop(), None);
    }

    #[test]
    fn empty_pop_returns_none() {
        let (_, rx) = bounded::<u64>(4);
        assert_eq!(rx.pop(), None);
        assert_eq!(rx.pop(), None);
    }

    #[test]
    fn fill_then_drain() {
        let (tx, rx) = bounded::<u64>(4);

        for i in 0..4 {
            assert!(tx.push(i).is_ok());
        }

        for i in 0..4 {
            assert_eq!(rx.pop(), Some(i));
        }

        assert_eq!(rx.pop(), None);
    }

    #[test]
    fn push_returns_error_when_full() {
        let (tx, _rx) = bounded::<u64>(4);

        assert!(tx.push(1).is_ok());
        assert!(tx.push(2).is_ok());
        assert!(tx.push(3).is_ok());
        assert!(tx.push(4).is_ok());

        // The rejected value is handed back inside the error.
        let err = tx.push(5).unwrap_err();
        assert_eq!(err.into_inner(), 5);
    }

    /// Push/pop pairs cycle through the ring many times without desync.
    #[test]
    fn interleaved_single_producer() {
        let (tx, rx) = bounded::<u64>(8);

        for i in 0..1000 {
            assert!(tx.push(i).is_ok());
            assert_eq!(rx.pop(), Some(i));
        }
    }

    #[test]
    fn partial_fill_drain_cycles() {
        let (tx, rx) = bounded::<u64>(8);

        for round in 0..100 {
            for i in 0..4 {
                assert!(tx.push(round * 4 + i).is_ok());
            }

            for i in 0..4 {
                assert_eq!(rx.pop(), Some(round * 4 + i));
            }
        }
    }

    /// Two producer threads; the consumer must receive every value exactly
    /// once (order across producers is unspecified).
    #[test]
    fn two_producers_single_consumer() {
        use std::thread;

        let (tx, rx) = bounded::<u64>(64);
        let tx2 = tx.clone();

        let h1 = thread::spawn(move || {
            for i in 0..1000 {
                while tx.push(i).is_err() {
                    std::hint::spin_loop();
                }
            }
        });

        let h2 = thread::spawn(move || {
            for i in 1000..2000 {
                while tx2.push(i).is_err() {
                    std::hint::spin_loop();
                }
            }
        });

        let mut received = Vec::new();
        while received.len() < 2000 {
            if let Some(val) = rx.pop() {
                received.push(val);
            } else {
                std::hint::spin_loop();
            }
        }

        h1.join().unwrap();
        h2.join().unwrap();

        received.sort_unstable();
        assert_eq!(received, (0..2000).collect::<Vec<_>>());
    }

    #[test]
    fn four_producers_single_consumer() {
        use std::thread;

        let (tx, rx) = bounded::<u64>(256);

        let handles: Vec<_> = (0..4)
            .map(|p| {
                let tx = tx.clone();
                thread::spawn(move || {
                    for i in 0..1000 {
                        let val = p * 1000 + i;
                        while tx.push(val).is_err() {
                            std::hint::spin_loop();
                        }
                    }
                })
            })
            .collect();

        // Drop the original handle so only the spawned clones remain.
        drop(tx);

        let mut received = Vec::new();
        while received.len() < 4000 {
            if let Some(val) = rx.pop() {
                received.push(val);
            } else {
                std::hint::spin_loop();
            }
        }

        for h in handles {
            h.join().unwrap();
        }

        received.sort_unstable();
        let mut expected: Vec<u64> = (0..4)
            .flat_map(|p| (0..1000).map(move |i| p * 1000 + i))
            .collect();
        expected.sort_unstable();
        assert_eq!(received, expected);
    }

    /// Capacity 1 alternates strictly between full and empty.
    #[test]
    fn single_slot_bounded() {
        let (tx, rx) = bounded::<u64>(1);

        assert!(tx.push(1).is_ok());
        assert!(tx.push(2).is_err());

        assert_eq!(rx.pop(), Some(1));
        assert!(tx.push(2).is_ok());
    }

    #[test]
    fn producer_disconnected() {
        let (tx, rx) = bounded::<u64>(4);

        assert!(!rx.is_disconnected());
        drop(tx);
        assert!(rx.is_disconnected());
    }

    #[test]
    fn consumer_disconnected() {
        let (tx, rx) = bounded::<u64>(4);

        assert!(!tx.is_disconnected());
        drop(rx);
        assert!(tx.is_disconnected());
    }

    #[test]
    fn multiple_producers_one_disconnects() {
        let (tx1, rx) = bounded::<u64>(4);
        let tx2 = tx1.clone();

        assert!(!rx.is_disconnected());
        drop(tx1);
        assert!(!rx.is_disconnected());
        drop(tx2);
        assert!(rx.is_disconnected());
    }

    /// Values still sitting in the queue are dropped when the channel dies.
    #[test]
    fn drop_cleans_up_remaining() {
        use std::sync::atomic::AtomicUsize;

        static DROP_COUNT: AtomicUsize = AtomicUsize::new(0);

        struct DropCounter;
        impl Drop for DropCounter {
            fn drop(&mut self) {
                DROP_COUNT.fetch_add(1, Ordering::SeqCst);
            }
        }

        DROP_COUNT.store(0, Ordering::SeqCst);

        let (tx, rx) = bounded::<DropCounter>(4);

        let _ = tx.push(DropCounter);
        let _ = tx.push(DropCounter);
        let _ = tx.push(DropCounter);

        assert_eq!(DROP_COUNT.load(Ordering::SeqCst), 0);

        drop(tx);
        drop(rx);

        assert_eq!(DROP_COUNT.load(Ordering::SeqCst), 3);
    }

    #[test]
    fn zero_sized_type() {
        let (tx, rx) = bounded::<()>(8);

        let _ = tx.push(());
        let _ = tx.push(());

        assert_eq!(rx.pop(), Some(()));
        assert_eq!(rx.pop(), Some(()));
        assert_eq!(rx.pop(), None);
    }

    #[test]
    fn string_type() {
        let (tx, rx) = bounded::<String>(4);

        let _ = tx.push("hello".to_string());
        let _ = tx.push("world".to_string());

        assert_eq!(rx.pop(), Some("hello".to_string()));
        assert_eq!(rx.pop(), Some("world".to_string()));
    }

    #[test]
    #[should_panic(expected = "capacity must be non-zero")]
    fn zero_capacity_panics() {
        let _ = bounded::<u64>(0);
    }

    #[test]
    fn large_message_type() {
        #[repr(C, align(64))]
        struct LargeMessage {
            data: [u8; 256],
        }

        let (tx, rx) = bounded::<LargeMessage>(8);

        let msg = LargeMessage { data: [42u8; 256] };
        assert!(tx.push(msg).is_ok());

        let received = rx.pop().unwrap();
        assert_eq!(received.data[0], 42);
        assert_eq!(received.data[255], 42);
    }

    /// 40 items through a 4-slot ring exercises ten full laps of the
    /// stamp sequence.
    #[test]
    fn multiple_laps() {
        let (tx, rx) = bounded::<u64>(4);

        for i in 0..40 {
            assert!(tx.push(i).is_ok());
            assert_eq!(rx.pop(), Some(i));
        }
    }

    #[test]
    fn capacity_rounds_to_power_of_two() {
        let (tx, _) = bounded::<u64>(100);
        assert_eq!(tx.capacity(), 128);

        let (tx, _) = bounded::<u64>(1000);
        assert_eq!(tx.capacity(), 1024);
    }

    /// High-volume SPSC run; the checksum pins delivery of every value.
    #[test]
    fn stress_single_producer() {
        use std::thread;

        const COUNT: u64 = 100_000;

        let (tx, rx) = bounded::<u64>(1024);

        let producer = thread::spawn(move || {
            for i in 0..COUNT {
                while tx.push(i).is_err() {
                    std::hint::spin_loop();
                }
            }
        });

        let consumer = thread::spawn(move || {
            let mut sum = 0u64;
            let mut received = 0u64;
            while received < COUNT {
                if let Some(val) = rx.pop() {
                    sum = sum.wrapping_add(val);
                    received += 1;
                } else {
                    std::hint::spin_loop();
                }
            }
            sum
        });

        producer.join().unwrap();
        let sum = consumer.join().unwrap();
        assert_eq!(sum, COUNT * (COUNT - 1) / 2);
    }

    /// High-volume MPSC run; checks the total delivered count.
    #[test]
    fn stress_multiple_producers() {
        use std::thread;

        const PRODUCERS: u64 = 4;
        const PER_PRODUCER: u64 = 25_000;
        const TOTAL: u64 = PRODUCERS * PER_PRODUCER;

        let (tx, rx) = bounded::<u64>(1024);

        let handles: Vec<_> = (0..PRODUCERS)
            .map(|_| {
                let tx = tx.clone();
                thread::spawn(move || {
                    for i in 0..PER_PRODUCER {
                        while tx.push(i).is_err() {
                            std::hint::spin_loop();
                        }
                    }
                })
            })
            .collect();

        drop(tx);

        let mut received = 0u64;
        while received < TOTAL {
            if rx.pop().is_some() {
                received += 1;
            } else {
                std::hint::spin_loop();
            }
        }

        for h in handles {
            h.join().unwrap();
        }

        assert_eq!(received, TOTAL);
    }
}