1use std::cell::UnsafeCell;
92use std::fmt;
93use std::mem::MaybeUninit;
94use std::sync::Arc;
95use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
96
97use crossbeam_utils::CachePadded;
98
99use crate::Full;
100
101pub fn bounded<T>(capacity: usize) -> (Producer<T>, Consumer<T>) {
109 assert!(capacity > 0, "capacity must be non-zero");
110
111 let capacity = capacity.next_power_of_two();
112 let mask = capacity - 1;
113
114 let slots: Vec<Slot<T>> = (0..capacity)
116 .map(|_| Slot {
117 turn: AtomicUsize::new(0),
118 data: UnsafeCell::new(MaybeUninit::uninit()),
119 })
120 .collect();
121 let slots = Box::into_raw(slots.into_boxed_slice()) as *mut Slot<T>;
122
123 let shift = capacity.trailing_zeros();
124
125 let shared = Arc::new(Shared {
126 head: CachePadded::new(AtomicUsize::new(0)),
127 tail: CachePadded::new(AtomicUsize::new(0)),
128 producer_alive: AtomicBool::new(true),
129 slots,
130 capacity,
131 shift,
132 mask,
133 });
134
135 (
136 Producer {
137 local_tail: 0,
138 slots,
139 mask,
140 shift,
141 shared: Arc::clone(&shared),
142 },
143 Consumer {
144 slots,
145 mask,
146 shift,
147 shared,
148 },
149 )
150}
151
/// A single ring-buffer cell.
struct Slot<T> {
    // Sequence stamp encoding both the lap number and full/empty state:
    // `lap * 2` means empty and writable on lap `lap`;
    // `lap * 2 + 1` means a value written on lap `lap` is ready to read.
    turn: AtomicUsize,
    // The payload; only initialized while `turn` holds an odd stamp.
    data: UnsafeCell<MaybeUninit<T>>,
}
161
/// State shared between the producer and all consumers.
///
/// `#[repr(C)]` pins the declared field order so the cache-padded
/// `head`/`tail` counters occupy their own cache lines ahead of the
/// read-mostly metadata.
#[repr(C)]
struct Shared<T> {
    // Next index consumers will pop from (contended among consumers).
    head: CachePadded<AtomicUsize>,
    // Final tail published by `Producer::drop`; read only by `Shared::drop`
    // to locate un-popped values. It is not updated during normal pushes.
    tail: CachePadded<AtomicUsize>,
    // Cleared in `Producer::drop`; polled by `Consumer::is_disconnected`.
    producer_alive: AtomicBool,
    // Raw pointer to the slot array (owned; freed in `Shared::drop`).
    slots: *mut Slot<T>,
    // Number of slots; always a power of two.
    capacity: usize,
    // log2(capacity); `index >> shift` yields the lap number.
    shift: u32,
    // capacity - 1; `index & mask` yields the slot position.
    mask: usize,
}
181
// SAFETY: `Shared` owns the slot array via the raw `slots` pointer; values
// of type `T` move through the queue between threads, so `T: Send` is both
// required and sufficient. All slot access is coordinated through the atomic
// `turn` stamps and the atomic `head` counter, so sharing `&Shared` across
// threads is sound under the same bound.
unsafe impl<T: Send> Send for Shared<T> {}
unsafe impl<T: Send> Sync for Shared<T> {}
186
impl<T> Drop for Shared<T> {
    /// Runs when the last `Producer`/`Consumer` handle goes away: drops any
    /// values that were pushed but never popped, then frees the slot array.
    fn drop(&mut self) {
        // `&mut self` means we are the sole owner here, so Relaxed loads are
        // sufficient (the `Arc` refcount decrement already synchronized).
        let head = self.head.load(Ordering::Relaxed);
        let tail = self.tail.load(Ordering::Relaxed);

        // Walk the un-popped range [head, tail). `tail` is the producer's
        // final position, published in `Producer::drop`.
        let mut i = head;
        while i != tail {
            let slot = unsafe { &*self.slots.add(i & self.mask) };
            let turn = i >> self.shift;

            // An odd stamp for this lap means the slot still holds a value.
            if slot.turn.load(Ordering::Relaxed) == turn * 2 + 1 {
                // SAFETY: the stamp says this slot was written on lap `turn`
                // and never consumed, so the payload is initialized.
                unsafe { (*slot.data.get()).assume_init_drop() };
            }
            i = i.wrapping_add(1);
        }

        // SAFETY: `slots`/`capacity` came from `Box::into_raw` on a boxed
        // slice of exactly `capacity` slots in `bounded`.
        unsafe {
            let _ = Box::from_raw(std::ptr::slice_from_raw_parts_mut(
                self.slots,
                self.capacity,
            ));
        }
    }
}
215
/// Sending half of the queue. Single-producer: not `Clone`, and `push`
/// takes `&mut self`.
///
/// `#[repr(C)]` keeps the hot fields (`local_tail`, cached pointers) at the
/// front of the struct in declaration order.
#[repr(C)]
pub struct Producer<T> {
    // Producer-private tail index; the next position to write.
    local_tail: usize,
    // Cached copy of `Shared::slots` to avoid chasing the `Arc` on push.
    slots: *mut Slot<T>,
    // Cached `capacity - 1`.
    mask: usize,
    // Cached `log2(capacity)`.
    shift: u32,
    // Keeps the shared state (and thus the slot array) alive.
    shared: Arc<Shared<T>>,
}
233
// SAFETY: a `Producer` may move to another thread when `T: Send` — the raw
// `slots` pointer it carries is kept alive by the `Arc<Shared<T>>` it also
// holds, and slot handoff is mediated by the atomic `turn` stamps. It is
// deliberately not `Sync`: `push` assumes a single producer.
unsafe impl<T: Send> Send for Producer<T> {}
237
impl<T> Producer<T> {
    /// Attempts to push `value`, returning `Err(Full(value))` — handing the
    /// value back — if the target slot has not been freed by a consumer yet.
    #[inline]
    #[must_use = "push returns Err if full, which should be handled"]
    pub fn push(&mut self, value: T) -> Result<(), Full<T>> {
        let slot = unsafe { &*self.slots.add(self.local_tail & self.mask) };
        let turn = self.local_tail >> self.shift;

        // A stamp of exactly `turn * 2` means the slot is empty for this
        // lap. Acquire pairs with the consumer's Release store in `pop`, so
        // the previous occupant's read is ordered before our overwrite.
        if slot.turn.load(Ordering::Acquire) != turn * 2 {
            return Err(Full(value));
        }

        // SAFETY: the stamp check above proves this slot is ours to write
        // for this lap; no consumer touches it until we publish below.
        unsafe { (*slot.data.get()).write(value) };

        // Release publishes the value to whichever consumer claims the slot.
        slot.turn.store(turn * 2 + 1, Ordering::Release);

        self.local_tail = self.local_tail.wrapping_add(1);

        Ok(())
    }

    /// Returns the queue capacity (the requested capacity rounded up to a
    /// power of two).
    #[inline]
    pub fn capacity(&self) -> usize {
        1 << self.shift
    }

    /// Returns `true` once every `Consumer` handle has been dropped.
    #[inline]
    pub fn is_disconnected(&self) -> bool {
        // Only the producer and consumers hold strong references, so a
        // count of 1 means ours is the last handle alive.
        Arc::strong_count(&self.shared) == 1
    }
}
280
impl<T> Drop for Producer<T> {
    fn drop(&mut self) {
        // Publish the final tail so `Shared::drop` can find and drop values
        // that were pushed but never popped. Relaxed suffices: the only
        // reader is `Shared::drop`, which is ordered after this drop by the
        // `Arc` refcount decrement.
        self.shared.tail.store(self.local_tail, Ordering::Relaxed);
        // Release pairs with the Acquire in `Consumer::is_disconnected`.
        self.shared.producer_alive.store(false, Ordering::Release);
    }
}
288
289impl<T> fmt::Debug for Producer<T> {
290 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
291 f.debug_struct("Producer")
292 .field("capacity", &self.capacity())
293 .finish_non_exhaustive()
294 }
295}
296
/// Receiving half of the queue. Cloneable: multiple consumers may compete
/// for values, each from its own thread (`pop` takes `&mut self`).
///
/// `#[repr(C)]` keeps the hot cached fields at the front in declared order.
#[repr(C)]
pub struct Consumer<T> {
    // Cached copy of `Shared::slots` to avoid chasing the `Arc` on pop.
    slots: *mut Slot<T>,
    // Cached `capacity - 1`.
    mask: usize,
    // Cached `log2(capacity)`.
    shift: u32,
    // Keeps the shared state (and thus the slot array) alive.
    shared: Arc<Shared<T>>,
}
312
313impl<T> Clone for Consumer<T> {
314 fn clone(&self) -> Self {
315 Consumer {
316 slots: self.slots,
317 mask: self.mask,
318 shift: self.shift,
319 shared: Arc::clone(&self.shared),
320 }
321 }
322}
323
// SAFETY: a `Consumer` may move to another thread when `T: Send` — its raw
// `slots` pointer is kept alive by the `Arc<Shared<T>>` it holds, and
// cross-consumer coordination happens through the atomic `head` counter and
// slot stamps. Not `Sync`: each clone is driven from one thread via
// `&mut self`.
unsafe impl<T: Send> Send for Consumer<T> {}
327
impl<T> Consumer<T> {
    /// Attempts to pop a value, returning `None` when the slot at the
    /// current head holds no published value.
    ///
    /// NOTE(review): under contention this can return `None` even though
    /// later slots are full — another consumer may claim the head slot
    /// between our stamp load and the CAS — so callers are expected to
    /// poll/retry (as the tests do). Confirm this is the intended contract.
    #[inline]
    pub fn pop(&mut self) -> Option<T> {
        let mut spin_count = 0u32;

        loop {
            let head = self.shared.head.load(Ordering::Relaxed);

            let slot = unsafe { &*self.slots.add(head & self.mask) };
            let turn = head >> self.shift;

            // Acquire pairs with the producer's Release store, so seeing a
            // "full" stamp also makes the written value visible.
            let stamp = slot.turn.load(Ordering::Acquire);

            if stamp == turn * 2 + 1 {
                // Slot is full for this lap: race the other consumers to
                // claim it by advancing `head`.
                if self
                    .shared
                    .head
                    .compare_exchange_weak(
                        head,
                        head.wrapping_add(1),
                        Ordering::Relaxed,
                        Ordering::Relaxed,
                    )
                    .is_ok()
                {
                    // SAFETY: winning the CAS makes us the unique reader of
                    // this slot for this lap, and the Acquire stamp load
                    // above ordered the producer's write before this read.
                    let value = unsafe { (*slot.data.get()).assume_init_read() };

                    // Mark the slot empty for the *next* lap; Release pairs
                    // with the producer's Acquire stamp check in `push`.
                    slot.turn.store((turn + 1) * 2, Ordering::Release);

                    return Some(value);
                }

                // Lost the CAS to another consumer: back off with
                // exponentially growing spins (capped at 2^6) and retry.
                let spins = 1 << spin_count.min(6);
                for _ in 0..spins {
                    std::hint::spin_loop();
                }
                spin_count += 1;
            } else {
                // Stamp is not "full for this lap": either the producer has
                // not written it yet, or another consumer already took it.
                return None;
            }
        }
    }

    /// Returns the queue capacity (the requested capacity rounded up to a
    /// power of two).
    #[inline]
    pub fn capacity(&self) -> usize {
        1 << self.shift
    }

    /// Returns `true` once the `Producer` has been dropped. The Acquire
    /// load pairs with the Release store in `Producer::drop`, so after this
    /// returns `true` every pushed value is observable via `pop`.
    #[inline]
    pub fn is_disconnected(&self) -> bool {
        !self.shared.producer_alive.load(Ordering::Acquire)
    }
}
395
396impl<T> fmt::Debug for Consumer<T> {
397 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
398 f.debug_struct("Consumer")
399 .field("capacity", &self.capacity())
400 .finish_non_exhaustive()
401 }
402}
403
#[cfg(test)]
mod tests {
    use super::*;

    // --- basic single-threaded behavior ---

    #[test]
    fn basic_push_pop() {
        let (mut tx, mut rx) = bounded::<u64>(4);

        assert!(tx.push(1).is_ok());
        assert!(tx.push(2).is_ok());
        assert!(tx.push(3).is_ok());

        assert_eq!(rx.pop(), Some(1));
        assert_eq!(rx.pop(), Some(2));
        assert_eq!(rx.pop(), Some(3));
        assert_eq!(rx.pop(), None);
    }

    #[test]
    fn empty_pop_returns_none() {
        let (_, mut rx) = bounded::<u64>(4);
        assert_eq!(rx.pop(), None);
        assert_eq!(rx.pop(), None);
    }

    #[test]
    fn fill_then_drain() {
        let (mut tx, mut rx) = bounded::<u64>(4);

        for i in 0..4 {
            assert!(tx.push(i).is_ok());
        }

        for i in 0..4 {
            assert_eq!(rx.pop(), Some(i));
        }

        assert_eq!(rx.pop(), None);
    }

    #[test]
    fn push_returns_error_when_full() {
        let (mut tx, _rx) = bounded::<u64>(4);

        assert!(tx.push(1).is_ok());
        assert!(tx.push(2).is_ok());
        assert!(tx.push(3).is_ok());
        assert!(tx.push(4).is_ok());

        // The rejected value is handed back inside the error.
        let err = tx.push(5).unwrap_err();
        assert_eq!(err.into_inner(), 5);
    }

    // Exercises wrap-around: an 8-slot queue sees 1000 values.
    #[test]
    fn interleaved_single_consumer() {
        let (mut tx, mut rx) = bounded::<u64>(8);

        for i in 0..1000 {
            assert!(tx.push(i).is_ok());
            assert_eq!(rx.pop(), Some(i));
        }
    }

    #[test]
    fn partial_fill_drain_cycles() {
        let (mut tx, mut rx) = bounded::<u64>(8);

        for round in 0..100 {
            for i in 0..4 {
                assert!(tx.push(round * 4 + i).is_ok());
            }

            for i in 0..4 {
                assert_eq!(rx.pop(), Some(round * 4 + i));
            }
        }
    }

    // --- multi-consumer behavior ---

    #[test]
    fn two_consumers_single_producer() {
        use std::thread;

        let (mut tx, rx) = bounded::<u64>(64);
        let mut rx2 = rx.clone();
        let mut rx1 = rx;

        let h1 = thread::spawn(move || {
            let mut received = Vec::new();
            loop {
                if let Some(val) = rx1.pop() {
                    received.push(val);
                } else if rx1.is_disconnected() {
                    // Producer is gone: drain whatever is left, then stop.
                    while let Some(val) = rx1.pop() {
                        received.push(val);
                    }
                    break;
                } else {
                    std::hint::spin_loop();
                }
            }
            received
        });

        let h2 = thread::spawn(move || {
            let mut received = Vec::new();
            loop {
                if let Some(val) = rx2.pop() {
                    received.push(val);
                } else if rx2.is_disconnected() {
                    while let Some(val) = rx2.pop() {
                        received.push(val);
                    }
                    break;
                } else {
                    std::hint::spin_loop();
                }
            }
            received
        });

        for i in 0..2000 {
            while tx.push(i).is_err() {
                std::hint::spin_loop();
            }
        }
        drop(tx);

        let mut received = h1.join().unwrap();
        received.extend(h2.join().unwrap());

        // Every value must arrive exactly once across the two consumers.
        received.sort_unstable();
        assert_eq!(received, (0..2000).collect::<Vec<_>>());
    }

    #[test]
    fn four_consumers_single_producer() {
        use std::thread;

        let (mut tx, rx) = bounded::<u64>(256);

        let handles: Vec<_> = (0..4)
            .map(|_| {
                let mut rx = rx.clone();
                thread::spawn(move || {
                    let mut received = Vec::new();
                    loop {
                        if let Some(val) = rx.pop() {
                            received.push(val);
                        } else if rx.is_disconnected() {
                            while let Some(val) = rx.pop() {
                                received.push(val);
                            }
                            break;
                        } else {
                            std::hint::spin_loop();
                        }
                    }
                    received
                })
            })
            .collect();

        // Drop the original handle so only the four worker clones remain.
        drop(rx);

        for i in 0..4000u64 {
            while tx.push(i).is_err() {
                std::hint::spin_loop();
            }
        }
        drop(tx);

        let mut received = Vec::new();
        for h in handles {
            received.extend(h.join().unwrap());
        }

        received.sort_unstable();
        assert_eq!(received, (0..4000).collect::<Vec<_>>());
    }

    // --- capacity and disconnection edge cases ---

    #[test]
    fn single_slot_bounded() {
        let (mut tx, mut rx) = bounded::<u64>(1);

        assert!(tx.push(1).is_ok());
        assert!(tx.push(2).is_err());

        assert_eq!(rx.pop(), Some(1));
        assert!(tx.push(2).is_ok());
    }

    #[test]
    fn consumer_detects_producer_drop() {
        let (tx, rx) = bounded::<u64>(4);

        assert!(!rx.is_disconnected());
        drop(tx);
        assert!(rx.is_disconnected());
    }

    #[test]
    fn producer_detects_all_consumers_drop() {
        let (tx, rx) = bounded::<u64>(4);

        assert!(!tx.is_disconnected());
        drop(rx);
        assert!(tx.is_disconnected());
    }

    #[test]
    fn one_consumer_drops_others_alive() {
        let (tx, rx) = bounded::<u64>(4);
        let rx2 = rx.clone();

        assert!(!tx.is_disconnected());
        drop(rx);
        // A single surviving clone keeps the channel connected.
        assert!(!tx.is_disconnected());
        assert!(!rx2.is_disconnected());
        drop(rx2);
        assert!(tx.is_disconnected());
    }

    // Un-popped values must be dropped when the channel is torn down.
    #[test]
    fn drop_cleans_up_remaining() {
        use std::sync::atomic::AtomicUsize;

        static DROP_COUNT: AtomicUsize = AtomicUsize::new(0);

        struct DropCounter;
        impl Drop for DropCounter {
            fn drop(&mut self) {
                DROP_COUNT.fetch_add(1, Ordering::SeqCst);
            }
        }

        DROP_COUNT.store(0, Ordering::SeqCst);

        let (mut tx, rx) = bounded::<DropCounter>(4);

        let _ = tx.push(DropCounter);
        let _ = tx.push(DropCounter);
        let _ = tx.push(DropCounter);

        assert_eq!(DROP_COUNT.load(Ordering::SeqCst), 0);

        drop(tx);
        drop(rx);

        assert_eq!(DROP_COUNT.load(Ordering::SeqCst), 3);
    }

    #[test]
    fn zero_sized_type() {
        let (mut tx, mut rx) = bounded::<()>(8);

        let _ = tx.push(());
        let _ = tx.push(());

        assert_eq!(rx.pop(), Some(()));
        assert_eq!(rx.pop(), Some(()));
        assert_eq!(rx.pop(), None);
    }

    #[test]
    fn string_type() {
        let (mut tx, mut rx) = bounded::<String>(4);

        let _ = tx.push("hello".to_string());
        let _ = tx.push("world".to_string());

        assert_eq!(rx.pop(), Some("hello".to_string()));
        assert_eq!(rx.pop(), Some("world".to_string()));
    }

    #[test]
    #[should_panic(expected = "capacity must be non-zero")]
    fn zero_capacity_panics() {
        let _ = bounded::<u64>(0);
    }

    #[test]
    fn large_message_type() {
        #[repr(C, align(64))]
        struct LargeMessage {
            data: [u8; 256],
        }

        let (mut tx, mut rx) = bounded::<LargeMessage>(8);

        let msg = LargeMessage { data: [42u8; 256] };
        assert!(tx.push(msg).is_ok());

        let received = rx.pop().unwrap();
        assert_eq!(received.data[0], 42);
        assert_eq!(received.data[255], 42);
    }

    // 40 items through a 4-slot queue: ten full laps of the turn counter.
    #[test]
    fn multiple_laps() {
        let (mut tx, mut rx) = bounded::<u64>(4);

        for i in 0..40 {
            assert!(tx.push(i).is_ok());
            assert_eq!(rx.pop(), Some(i));
        }
    }

    #[test]
    fn capacity_rounds_to_power_of_two() {
        let (tx, _) = bounded::<u64>(100);
        assert_eq!(tx.capacity(), 128);

        let (tx, _) = bounded::<u64>(1000);
        assert_eq!(tx.capacity(), 1024);
    }

    // --- stress tests ---

    #[test]
    fn stress_single_consumer() {
        use std::thread;

        const COUNT: u64 = 100_000;

        let (mut tx, mut rx) = bounded::<u64>(1024);

        let producer = thread::spawn(move || {
            for i in 0..COUNT {
                while tx.push(i).is_err() {
                    std::hint::spin_loop();
                }
            }
        });

        let consumer = thread::spawn(move || {
            let mut sum = 0u64;
            let mut received = 0u64;
            while received < COUNT {
                if let Some(val) = rx.pop() {
                    sum = sum.wrapping_add(val);
                    received += 1;
                } else {
                    std::hint::spin_loop();
                }
            }
            sum
        });

        producer.join().unwrap();
        let sum = consumer.join().unwrap();
        // Sum of 0..COUNT.
        assert_eq!(sum, COUNT * (COUNT - 1) / 2);
    }

    #[test]
    fn stress_multiple_consumers() {
        use std::thread;

        const CONSUMERS: usize = 4;
        const TOTAL: u64 = 100_000;

        let (mut tx, rx) = bounded::<u64>(1024);

        let handles: Vec<_> = (0..CONSUMERS)
            .map(|_| {
                let mut rx = rx.clone();
                thread::spawn(move || {
                    let mut received = Vec::new();
                    loop {
                        if let Some(val) = rx.pop() {
                            received.push(val);
                        } else if rx.is_disconnected() {
                            while let Some(val) = rx.pop() {
                                received.push(val);
                            }
                            break;
                        } else {
                            std::hint::spin_loop();
                        }
                    }
                    received
                })
            })
            .collect();

        drop(rx);

        let producer = thread::spawn(move || {
            for i in 0..TOTAL {
                while tx.push(i).is_err() {
                    std::hint::spin_loop();
                }
            }
        });

        producer.join().unwrap();

        let mut all_received = Vec::new();
        for h in handles {
            all_received.extend(h.join().unwrap());
        }

        all_received.sort_unstable();
        assert_eq!(all_received, (0..TOTAL).collect::<Vec<_>>());
    }
}