1use std::cell::{Cell, UnsafeCell};
92use std::fmt;
93use std::mem::MaybeUninit;
94use std::sync::Arc;
95use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
96
97use crossbeam_utils::CachePadded;
98
99use crate::Full;
100
101pub fn bounded<T>(capacity: usize) -> (Producer<T>, Consumer<T>) {
109 assert!(capacity > 0, "capacity must be non-zero");
110
111 let capacity = capacity.next_power_of_two();
112 let mask = capacity - 1;
113
114 let slots: Vec<Slot<T>> = (0..capacity)
116 .map(|_| Slot {
117 turn: AtomicUsize::new(0),
118 data: UnsafeCell::new(MaybeUninit::uninit()),
119 })
120 .collect();
121 let slots = Box::into_raw(slots.into_boxed_slice()) as *mut Slot<T>;
122
123 let shift = capacity.trailing_zeros();
124
125 let shared = Arc::new(Shared {
126 head: CachePadded::new(AtomicUsize::new(0)),
127 tail: CachePadded::new(AtomicUsize::new(0)),
128 producer_alive: AtomicBool::new(true),
129 slots,
130 capacity,
131 shift,
132 mask,
133 });
134
135 (
136 Producer {
137 local_tail: Cell::new(0),
138 slots,
139 mask,
140 shift,
141 shared: Arc::clone(&shared),
142 },
143 Consumer {
144 slots,
145 mask,
146 shift,
147 shared,
148 },
149 )
150}
151
/// One ring-buffer cell: a sequence stamp plus (possibly uninitialized) storage.
struct Slot<T> {
    // Sequence stamp. For the slot's current lap `turn` (see push/pop):
    //   turn * 2       => empty, writable by the producer this lap
    //   turn * 2 + 1   => full, readable by exactly one consumer
    //   (turn + 1) * 2 => consumed; becomes writable again on the next lap
    turn: AtomicUsize,
    // Payload storage; only initialized while the stamp reads turn * 2 + 1.
    data: UnsafeCell<MaybeUninit<T>>,
}
161
/// State shared by the producer and every consumer handle.
#[repr(C)]
struct Shared<T> {
    // Next index consumers will pop from. Cache-padded so head and tail
    // traffic land on different cache lines.
    head: CachePadded<AtomicUsize>,
    // Producer's tail, published only in `Producer::drop`; read by
    // `Shared::drop` to locate elements that were pushed but never popped.
    tail: CachePadded<AtomicUsize>,
    // Cleared (Release) in `Producer::drop`; polled by `Consumer::is_disconnected`.
    producer_alive: AtomicBool,
    // Heap-allocated ring of `capacity` slots; freed in `Shared::drop`.
    slots: *mut Slot<T>,
    // Number of slots; always a power of two (see `bounded`).
    capacity: usize,
    // log2(capacity): `index >> shift` yields the lap number.
    shift: u32,
    // capacity - 1: `index & mask` yields the slot index.
    mask: usize,
}
181
// SAFETY: `Shared<T>` owns the slot allocation behind the raw pointer, and all
// access to slot payloads is mediated by the atomic stamp protocol in
// push/pop, so the handle may move between threads whenever `T: Send`.
unsafe impl<T: Send> Send for Shared<T> {}
// SAFETY: concurrent access from several threads is coordinated through the
// atomic `head`, `tail`, `producer_alive`, and per-slot `turn` stamps.
unsafe impl<T: Send> Sync for Shared<T> {}
186
impl<T> Drop for Shared<T> {
    /// Drops any elements still in the queue and frees the slot allocation.
    fn drop(&mut self) {
        // The last handle is gone, so no concurrent access remains and
        // Relaxed loads are sufficient here.
        let head = self.head.load(Ordering::Relaxed);
        let tail = self.tail.load(Ordering::Relaxed);

        // Walk the unconsumed range head..tail (indices wrap via the mask).
        let mut i = head;
        while i != tail {
            let slot = unsafe { &*self.slots.add(i & self.mask) };
            let turn = i >> self.shift;

            // Stamp turn*2 + 1 marks a slot that was written but never popped.
            if slot.turn.load(Ordering::Relaxed) == turn * 2 + 1 {
                // SAFETY: the stamp proves the payload is initialized and unread.
                unsafe { (*slot.data.get()).assume_init_drop() };
            }
            i = i.wrapping_add(1);
        }

        // SAFETY: `slots` was produced by Box::into_raw of a boxed slice of
        // exactly `capacity` slots in `bounded`; rebuilding the Box frees it.
        unsafe {
            let _ = Box::from_raw(std::ptr::slice_from_raw_parts_mut(
                self.slots,
                self.capacity,
            ));
        }
    }
}
215
/// Sending half of the queue: a single, non-cloneable producer handle.
#[repr(C)]
pub struct Producer<T> {
    // Producer-private tail index; only published to `shared.tail` on drop.
    // A Cell suffices because there is exactly one producer.
    local_tail: Cell<usize>,
    // Cached copy of `shared.slots` (same pointer, avoids an indirection).
    slots: *mut Slot<T>,
    // Cached copy of `shared.mask` (capacity - 1).
    mask: usize,
    // Cached copy of `shared.shift` (log2 of capacity).
    shift: u32,
    // Keeps the slot storage alive and carries the shared atomics.
    shared: Arc<Shared<T>>,
}
233
234unsafe impl<T: Send> Send for Producer<T> {}
237
238impl<T> Producer<T> {
239 #[inline]
246 #[must_use = "push returns Err if full, which should be handled"]
247 pub fn push(&self, value: T) -> Result<(), Full<T>> {
248 let tail = self.local_tail.get();
249 let slot = unsafe { &*self.slots.add(tail & self.mask) };
251 let turn = tail >> self.shift;
252
253 if slot.turn.load(Ordering::Acquire) != turn * 2 {
255 return Err(Full(value));
256 }
257
258 unsafe { (*slot.data.get()).write(value) };
260
261 slot.turn.store(turn * 2 + 1, Ordering::Release);
263
264 self.local_tail.set(tail.wrapping_add(1));
265
266 Ok(())
267 }
268
269 #[inline]
271 pub fn capacity(&self) -> usize {
272 1 << self.shift
273 }
274
275 #[inline]
277 pub fn is_disconnected(&self) -> bool {
278 Arc::strong_count(&self.shared) == 1
279 }
280}
281
impl<T> Drop for Producer<T> {
    fn drop(&mut self) {
        // Publish the producer-private tail so `Shared::drop` knows which
        // range head..tail may still hold undelivered elements.
        self.shared.tail.store(self.local_tail.get(), Ordering::Relaxed);
        // Release so consumers that observe `producer_alive == false` via an
        // Acquire load are properly ordered after the tail publication.
        self.shared.producer_alive.store(false, Ordering::Release);
    }
}
289
290impl<T> fmt::Debug for Producer<T> {
291 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
292 f.debug_struct("Producer")
293 .field("capacity", &self.capacity())
294 .finish_non_exhaustive()
295 }
296}
297
/// Receiving half of the queue; cloneable, so several consumers may compete
/// for elements (each element is delivered to exactly one consumer).
#[repr(C)]
pub struct Consumer<T> {
    // Cached copy of `shared.slots` (same pointer, avoids an indirection).
    slots: *mut Slot<T>,
    // Cached copy of `shared.mask` (capacity - 1).
    mask: usize,
    // Cached copy of `shared.shift` (log2 of capacity).
    shift: u32,
    // Keeps the slot storage alive and carries the shared atomics.
    shared: Arc<Shared<T>>,
}
313
314impl<T> Clone for Consumer<T> {
315 fn clone(&self) -> Self {
316 Consumer {
317 slots: self.slots,
318 mask: self.mask,
319 shift: self.shift,
320 shared: Arc::clone(&self.shared),
321 }
322 }
323}
324
325unsafe impl<T: Send> Send for Consumer<T> {}
328
329impl<T> Consumer<T> {
330 #[inline]
337 pub fn pop(&self) -> Option<T> {
338 let mut spin_count = 0u32;
339
340 loop {
341 let head = self.shared.head.load(Ordering::Relaxed);
342
343 let slot = unsafe { &*self.slots.add(head & self.mask) };
345 let turn = head >> self.shift;
346
347 let stamp = slot.turn.load(Ordering::Acquire);
348
349 if stamp == turn * 2 + 1 {
350 if self
352 .shared
353 .head
354 .compare_exchange_weak(
355 head,
356 head.wrapping_add(1),
357 Ordering::Relaxed,
358 Ordering::Relaxed,
359 )
360 .is_ok()
361 {
362 let value = unsafe { (*slot.data.get()).assume_init_read() };
364
365 slot.turn.store((turn + 1) * 2, Ordering::Release);
367
368 return Some(value);
369 }
370
371 let spins = 1 << spin_count.min(6);
373 for _ in 0..spins {
374 std::hint::spin_loop();
375 }
376 spin_count += 1;
377 } else {
378 return None;
380 }
381 }
382 }
383
384 #[inline]
386 pub fn capacity(&self) -> usize {
387 1 << self.shift
388 }
389
390 #[inline]
392 pub fn is_disconnected(&self) -> bool {
393 !self.shared.producer_alive.load(Ordering::Acquire)
394 }
395}
396
397impl<T> fmt::Debug for Consumer<T> {
398 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
399 f.debug_struct("Consumer")
400 .field("capacity", &self.capacity())
401 .finish_non_exhaustive()
402 }
403}
404
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn basic_push_pop() {
        let (tx, rx) = bounded::<u64>(4);

        assert!(tx.push(1).is_ok());
        assert!(tx.push(2).is_ok());
        assert!(tx.push(3).is_ok());

        assert_eq!(rx.pop(), Some(1));
        assert_eq!(rx.pop(), Some(2));
        assert_eq!(rx.pop(), Some(3));
        assert_eq!(rx.pop(), None);
    }

    #[test]
    fn empty_pop_returns_none() {
        let (_, rx) = bounded::<u64>(4);
        assert_eq!(rx.pop(), None);
        assert_eq!(rx.pop(), None);
    }

    #[test]
    fn fill_then_drain() {
        let (tx, rx) = bounded::<u64>(4);

        for i in 0..4 {
            assert!(tx.push(i).is_ok());
        }

        for i in 0..4 {
            assert_eq!(rx.pop(), Some(i));
        }

        assert_eq!(rx.pop(), None);
    }

    #[test]
    fn push_returns_error_when_full() {
        let (tx, _rx) = bounded::<u64>(4);

        assert!(tx.push(1).is_ok());
        assert!(tx.push(2).is_ok());
        assert!(tx.push(3).is_ok());
        assert!(tx.push(4).is_ok());

        // The rejected value comes back inside the error.
        let err = tx.push(5).unwrap_err();
        assert_eq!(err.into_inner(), 5);
    }

    #[test]
    fn interleaved_single_consumer() {
        let (tx, rx) = bounded::<u64>(8);

        for i in 0..1000 {
            assert!(tx.push(i).is_ok());
            assert_eq!(rx.pop(), Some(i));
        }
    }

    #[test]
    fn partial_fill_drain_cycles() {
        let (tx, rx) = bounded::<u64>(8);

        for round in 0..100 {
            for i in 0..4 {
                assert!(tx.push(round * 4 + i).is_ok());
            }

            for i in 0..4 {
                assert_eq!(rx.pop(), Some(round * 4 + i));
            }
        }
    }

    #[test]
    fn two_consumers_single_producer() {
        use std::thread;

        let (tx, rx) = bounded::<u64>(64);
        let rx2 = rx.clone();

        let rx1 = rx;
        let h1 = thread::spawn(move || {
            let mut received = Vec::new();
            loop {
                if let Some(val) = rx1.pop() {
                    received.push(val);
                } else if rx1.is_disconnected() {
                    // Producer is gone: drain whatever is left, then stop.
                    while let Some(val) = rx1.pop() {
                        received.push(val);
                    }
                    break;
                } else {
                    std::hint::spin_loop();
                }
            }
            received
        });

        let h2 = thread::spawn(move || {
            let mut received = Vec::new();
            loop {
                if let Some(val) = rx2.pop() {
                    received.push(val);
                } else if rx2.is_disconnected() {
                    while let Some(val) = rx2.pop() {
                        received.push(val);
                    }
                    break;
                } else {
                    std::hint::spin_loop();
                }
            }
            received
        });

        for i in 0..2000 {
            while tx.push(i).is_err() {
                std::hint::spin_loop();
            }
        }
        drop(tx);

        let mut received = h1.join().unwrap();
        received.extend(h2.join().unwrap());

        // Every element is delivered exactly once across both consumers.
        received.sort_unstable();
        assert_eq!(received, (0..2000).collect::<Vec<_>>());
    }

    #[test]
    fn four_consumers_single_producer() {
        use std::thread;

        let (tx, rx) = bounded::<u64>(256);

        let handles: Vec<_> = (0..4)
            .map(|_| {
                let rx = rx.clone();
                thread::spawn(move || {
                    let mut received = Vec::new();
                    loop {
                        if let Some(val) = rx.pop() {
                            received.push(val);
                        } else if rx.is_disconnected() {
                            while let Some(val) = rx.pop() {
                                received.push(val);
                            }
                            break;
                        } else {
                            std::hint::spin_loop();
                        }
                    }
                    received
                })
            })
            .collect();

        // Drop the original handle so only the spawned clones remain.
        drop(rx);

        for i in 0..4000u64 {
            while tx.push(i).is_err() {
                std::hint::spin_loop();
            }
        }
        drop(tx);

        let mut received = Vec::new();
        for h in handles {
            received.extend(h.join().unwrap());
        }

        received.sort_unstable();
        assert_eq!(received, (0..4000).collect::<Vec<_>>());
    }

    #[test]
    fn single_slot_bounded() {
        let (tx, rx) = bounded::<u64>(1);

        assert!(tx.push(1).is_ok());
        assert!(tx.push(2).is_err());

        assert_eq!(rx.pop(), Some(1));
        assert!(tx.push(2).is_ok());
    }

    #[test]
    fn consumer_detects_producer_drop() {
        let (tx, rx) = bounded::<u64>(4);

        assert!(!rx.is_disconnected());
        drop(tx);
        assert!(rx.is_disconnected());
    }

    #[test]
    fn producer_detects_all_consumers_drop() {
        let (tx, rx) = bounded::<u64>(4);

        assert!(!tx.is_disconnected());
        drop(rx);
        assert!(tx.is_disconnected());
    }

    #[test]
    fn one_consumer_drops_others_alive() {
        let (tx, rx) = bounded::<u64>(4);
        let rx2 = rx.clone();

        assert!(!tx.is_disconnected());
        drop(rx);
        // One clone remains, so neither side counts as disconnected yet.
        assert!(!tx.is_disconnected());
        assert!(!rx2.is_disconnected());

        drop(rx2);
        assert!(tx.is_disconnected());
    }

    #[test]
    fn drop_cleans_up_remaining() {
        use std::sync::atomic::AtomicUsize;

        static DROP_COUNT: AtomicUsize = AtomicUsize::new(0);

        struct DropCounter;
        impl Drop for DropCounter {
            fn drop(&mut self) {
                DROP_COUNT.fetch_add(1, Ordering::SeqCst);
            }
        }

        DROP_COUNT.store(0, Ordering::SeqCst);

        let (tx, rx) = bounded::<DropCounter>(4);

        let _ = tx.push(DropCounter);
        let _ = tx.push(DropCounter);
        let _ = tx.push(DropCounter);

        // Nothing dropped while the elements sit in the queue.
        assert_eq!(DROP_COUNT.load(Ordering::SeqCst), 0);

        drop(tx);
        drop(rx);

        // `Shared::drop` must release the three undelivered elements.
        assert_eq!(DROP_COUNT.load(Ordering::SeqCst), 3);
    }

    #[test]
    fn zero_sized_type() {
        let (tx, rx) = bounded::<()>(8);

        let _ = tx.push(());
        let _ = tx.push(());

        assert_eq!(rx.pop(), Some(()));
        assert_eq!(rx.pop(), Some(()));
        assert_eq!(rx.pop(), None);
    }

    #[test]
    fn string_type() {
        let (tx, rx) = bounded::<String>(4);

        let _ = tx.push("hello".to_string());
        let _ = tx.push("world".to_string());

        assert_eq!(rx.pop(), Some("hello".to_string()));
        assert_eq!(rx.pop(), Some("world".to_string()));
    }

    #[test]
    #[should_panic(expected = "capacity must be non-zero")]
    fn zero_capacity_panics() {
        let _ = bounded::<u64>(0);
    }

    #[test]
    fn large_message_type() {
        #[repr(C, align(64))]
        struct LargeMessage {
            data: [u8; 256],
        }

        let (tx, rx) = bounded::<LargeMessage>(8);

        let msg = LargeMessage { data: [42u8; 256] };
        assert!(tx.push(msg).is_ok());

        let received = rx.pop().unwrap();
        assert_eq!(received.data[0], 42);
        assert_eq!(received.data[255], 42);
    }

    #[test]
    fn multiple_laps() {
        // 40 items through a 4-slot ring exercises the stamp lap arithmetic.
        let (tx, rx) = bounded::<u64>(4);

        for i in 0..40 {
            assert!(tx.push(i).is_ok());
            assert_eq!(rx.pop(), Some(i));
        }
    }

    #[test]
    fn capacity_rounds_to_power_of_two() {
        let (tx, _) = bounded::<u64>(100);
        assert_eq!(tx.capacity(), 128);

        let (tx, _) = bounded::<u64>(1000);
        assert_eq!(tx.capacity(), 1024);
    }

    #[test]
    fn stress_single_consumer() {
        use std::thread;

        const COUNT: u64 = 100_000;

        let (tx, rx) = bounded::<u64>(1024);

        let producer = thread::spawn(move || {
            for i in 0..COUNT {
                while tx.push(i).is_err() {
                    std::hint::spin_loop();
                }
            }
        });

        let consumer = thread::spawn(move || {
            let mut sum = 0u64;
            let mut received = 0u64;
            while received < COUNT {
                if let Some(val) = rx.pop() {
                    sum = sum.wrapping_add(val);
                    received += 1;
                } else {
                    std::hint::spin_loop();
                }
            }
            sum
        });

        producer.join().unwrap();
        let sum = consumer.join().unwrap();
        assert_eq!(sum, COUNT * (COUNT - 1) / 2);
    }

    #[test]
    fn stress_multiple_consumers() {
        use std::thread;

        const CONSUMERS: usize = 4;
        const TOTAL: u64 = 100_000;

        let (tx, rx) = bounded::<u64>(1024);

        let handles: Vec<_> = (0..CONSUMERS)
            .map(|_| {
                let rx = rx.clone();
                thread::spawn(move || {
                    let mut received = Vec::new();
                    loop {
                        if let Some(val) = rx.pop() {
                            received.push(val);
                        } else if rx.is_disconnected() {
                            while let Some(val) = rx.pop() {
                                received.push(val);
                            }
                            break;
                        } else {
                            std::hint::spin_loop();
                        }
                    }
                    received
                })
            })
            .collect();

        drop(rx);

        let producer = thread::spawn(move || {
            for i in 0..TOTAL {
                while tx.push(i).is_err() {
                    std::hint::spin_loop();
                }
            }
        });

        producer.join().unwrap();

        let mut all_received = Vec::new();
        for h in handles {
            all_received.extend(h.join().unwrap());
        }

        all_received.sort_unstable();
        assert_eq!(all_received, (0..TOTAL).collect::<Vec<_>>());
    }
}