use std::cell::{Cell, UnsafeCell};
use std::fmt;
use std::mem::MaybeUninit;
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use crossbeam_utils::CachePadded;
use crate::Full;
/// Creates a ring buffer; deprecated alias of [`ring_buffer`].
///
/// Kept so existing callers keep compiling. `capacity` is forwarded unchanged, so
/// the power-of-two rounding and the panics documented on [`ring_buffer`] apply
/// here as well.
#[deprecated(since = "1.3.0", note = "renamed to ring_buffer()")]
#[inline]
pub fn bounded<T>(capacity: usize) -> (Producer<T>, Consumer<T>) {
    ring_buffer::<T>(capacity)
}
pub fn ring_buffer<T>(capacity: usize) -> (Producer<T>, Consumer<T>) {
assert!(capacity > 0, "capacity must be non-zero");
let capacity = capacity
.checked_next_power_of_two()
.expect("capacity too large (must be <= usize::MAX / 2)");
let mask = capacity - 1;
let slots: Vec<Slot<T>> = (0..capacity)
.map(|_| Slot {
turn: AtomicUsize::new(0),
data: UnsafeCell::new(MaybeUninit::uninit()),
})
.collect();
let slots = Box::into_raw(slots.into_boxed_slice()) as *mut Slot<T>;
let shift = capacity.trailing_zeros();
let shared = Arc::new(Shared {
tail: CachePadded::new(AtomicUsize::new(0)),
head: CachePadded::new(AtomicUsize::new(0)),
slots,
capacity,
shift,
mask,
});
(
Producer {
cached_head: Cell::new(0),
slots,
mask,
capacity,
shift,
shared: Arc::clone(&shared),
},
Consumer {
local_head: Cell::new(0),
slots,
mask,
shift,
shared,
},
)
}
/// One cell of the ring: storage for a value plus the stamp that coordinates access.
///
/// For a position `p` with lap `l = p >> shift`, the stamp `turn` moves through:
/// - `2 * l`: empty; the producer that claims `p` may write it;
/// - `2 * l + 1`: holds the value for `p`, ready for the consumer;
/// - `2 * (l + 1)`: consumed; empty for the next lap.
struct Slot<T> {
    /// Storage for one value. It is read or dropped only while `turn` is odd.
    data: UnsafeCell<MaybeUninit<T>>,
    /// The stamp described above.
    turn: AtomicUsize,
}
/// State shared by every handle of one ring buffer, owned through an `Arc`.
///
/// `tail` and `head` are each wrapped in crossbeam's `CachePadded`, which pads and
/// aligns a value to a cache line, so the producer-side and consumer-side counters
/// do not share one.
#[repr(C)]
struct Shared<T> {
/// Next position a producer will claim; advanced by CAS in `Producer::push`.
tail: CachePadded<AtomicUsize>,
/// Next position the consumer will read; stored by `Consumer::pop` after each pop.
head: CachePadded<AtomicUsize>,
/// Start of the `capacity`-long slot array leaked in `ring_buffer`; freed in `Drop`.
slots: *mut Slot<T>,
/// Number of slots; always a power of two.
capacity: usize,
/// `capacity.trailing_zeros()`; `pos >> shift` is the lap ("turn") of position `pos`.
shift: u32,
/// `capacity - 1`; `pos & mask` is the slot index of position `pos`.
mask: usize,
}
// SAFETY: the raw `slots` pointer suppresses the automatic impls. Each slot's `data`
// is touched by one thread at a time and handed over through the slot's `turn` stamp
// with Release/Acquire. A producer writes only after winning the `tail` CAS for that
// position and seeing the "empty" stamp; the consumer reads only after seeing the
// "published" stamp. Values of `T` are moved between threads, never shared by
// reference, so `T: Send` is the required bound for both impls.
unsafe impl<T: Send> Send for Shared<T> {}
unsafe impl<T: Send> Sync for Shared<T> {}
impl<T> Drop for Shared<T> {
    /// Drops every value that was published but never popped, then frees the slot array.
    fn drop(&mut self) {
        // This runs only once the last `Arc<Shared<T>>` is released. `Arc` synchronizes
        // that release, so relaxed loads observe the final counter values.
        let start = self.head.load(Ordering::Relaxed);
        let end = self.tail.load(Ordering::Relaxed);
        let outstanding = end.wrapping_sub(start);
        for offset in 0..outstanding {
            let pos = start.wrapping_add(offset);
            let published = (pos >> self.shift) * 2 + 1;
            // SAFETY: `pos & mask < capacity`, so the pointer is inside the slot array.
            let slot = unsafe { &*self.slots.add(pos & self.mask) };
            if slot.turn.load(Ordering::Relaxed) != published {
                continue;
            }
            // SAFETY: the odd stamp for this lap means the slot holds an initialized
            // value that the consumer never read, and no handle remains to touch it.
            unsafe { (*slot.data.get()).assume_init_drop() };
        }
        // SAFETY: `slots` and `capacity` describe the boxed slice leaked with
        // `Box::into_raw` in `ring_buffer`; it is reclaimed exactly once, here.
        let all = std::ptr::slice_from_raw_parts_mut(self.slots, self.capacity);
        drop(unsafe { Box::from_raw(all) });
    }
}
/// Sending half of a ring buffer created by [`ring_buffer`].
///
/// Cloning yields another producer for the same buffer; each clone keeps its own
/// `cached_head`. NOTE(review): the reason for `#[repr(C)]` is not visible here.
#[repr(C)]
pub struct Producer<T> {
/// Most recent `head` value this handle has observed. It may lag the shared head,
/// which can only make the buffer look fuller than it is; `push` reloads it then.
cached_head: Cell<usize>,
/// Copy of `shared.slots`, presumably kept locally to skip a hop through the `Arc`.
slots: *mut Slot<T>,
/// Copy of `shared.mask` (`capacity - 1`).
mask: usize,
/// Copy of `shared.capacity` (a power of two).
capacity: usize,
/// Copy of `shared.shift` (`capacity.trailing_zeros()`).
shift: u32,
/// Keeps the shared counters and slot array alive while this handle exists.
shared: Arc<Shared<T>>,
}
impl<T> Clone for Producer<T> {
    /// Creates another producer handle for the same buffer.
    ///
    /// The clone shares the slot array and counters through the `Arc`, but gets its
    /// own `cached_head`, seeded from the current shared head.
    fn clone(&self) -> Self {
        // Acquire (not Relaxed): this pairs with the Release store of `head` in
        // `Consumer::pop`, so every `tail` the clone loads afterwards is at least this
        // head. With a Relaxed load, the clone could cache a head newer than the tail
        // it reads next. `tail - cached_head` would then wrap and `push` could report
        // a spurious `Full`.
        let head = self.shared.head.load(Ordering::Acquire);
        Producer {
            cached_head: Cell::new(head),
            slots: self.slots,
            mask: self.mask,
            capacity: self.capacity,
            shift: self.shift,
            shared: Arc::clone(&self.shared),
        }
    }
}
// SAFETY: the raw `slots` pointer suppresses the automatic `Send` impl. Moving a
// producer to another thread is sound for `T: Send`: it only moves `T` values into
// slots it has exclusively claimed through the `tail` CAS, and it publishes them with
// a Release store. No `Sync` impl is provided (and `cached_head` is a `Cell`), so one
// handle is never used from two threads at once; extra producers come from `clone`.
unsafe impl<T> Send for Producer<T> where T: Send {}
impl<T> Producer<T> {
    /// Tries to enqueue `value` at the next free position.
    ///
    /// Producers compete for positions with a compare-and-swap on the shared `tail`.
    /// The winner writes `value` into the claimed slot, then publishes it by storing
    /// the slot's odd turn stamp with `Release`; that stamp is what
    /// [`Consumer::pop`] checks. If the slot is not yet free, or the CAS loses a race,
    /// the producer spins with exponential backoff (at most 64 `spin_loop` hints per
    /// retry) and tries again.
    ///
    /// # Errors
    ///
    /// Returns `Err(Full(value))`, handing `value` back, when either:
    /// - `capacity()` positions are claimed and not yet popped; or
    /// - after at least five unsuccessful retries, this handle is the last reference
    ///   to the buffer (see [`Producer::is_disconnected`]).
    #[inline]
    #[must_use = "push returns Err if full, which should be handled"]
    pub fn push(&self, value: T) -> Result<(), Full<T>> {
        let mut spin_count = 0u32;
        loop {
            let tail = self.shared.tail.load(Ordering::Relaxed);
            if tail.wrapping_sub(self.cached_head.get()) >= self.capacity {
                // We look full according to a possibly stale head, so refresh it.
                // Acquire pairs with the Release store of `head` in `Consumer::pop`,
                // so stamps of slots freed before that head are visible.
                let head = self.shared.head.load(Ordering::Acquire);
                self.cached_head.set(head);
                let used = tail.wrapping_sub(head);
                if used > self.capacity {
                    // `tail` was read before `head`. Other producers and the consumer
                    // moved past it in between, so the subtraction wrapped; this says
                    // nothing about fullness. Re-read `tail`: after the Acquire load
                    // above, it cannot be older than `head`.
                    continue;
                }
                if used == self.capacity {
                    return Err(Full(value));
                }
            }
            // SAFETY: `tail & mask < capacity`, so this stays inside the slot array,
            // which `shared` keeps alive.
            let slot = unsafe { &*self.slots.add(tail & self.mask) };
            let turn = tail >> self.shift;
            let expected_stamp = turn * 2;
            // Acquire pairs with the consumer's Release store that freed this slot.
            let stamp = slot.turn.load(Ordering::Acquire);
            if stamp == expected_stamp
                && self
                    .shared
                    .tail
                    .compare_exchange_weak(
                        tail,
                        tail.wrapping_add(1),
                        Ordering::Relaxed,
                        Ordering::Relaxed,
                    )
                    .is_ok()
            {
                // SAFETY: winning the CAS gives this producer exclusive write access to
                // the slot until it publishes the odd stamp below.
                unsafe { (*slot.data.get()).write(value) };
                slot.turn.store(turn * 2 + 1, Ordering::Release);
                return Ok(());
            }
            let spins = 1 << spin_count.min(6);
            for _ in 0..spins {
                std::hint::spin_loop();
            }
            // Saturate so that an arbitrarily long spin can never overflow the counter.
            spin_count = spin_count.saturating_add(1);
            if spin_count >= 5 && self.is_disconnected() {
                return Err(Full(value));
            }
        }
    }

    /// Number of slots: the requested capacity rounded up to a power of two.
    #[inline]
    pub fn capacity(&self) -> usize {
        1 << self.shift
    }

    /// Returns `true` when this handle holds the only remaining reference to the
    /// shared buffer. That means the consumer *and* every other producer clone have
    /// been dropped.
    ///
    /// While another producer clone is alive, this returns `false` even if the
    /// consumer is gone. The result is a snapshot and may be stale immediately.
    #[inline]
    pub fn is_disconnected(&self) -> bool {
        Arc::strong_count(&self.shared) == 1
    }
}
impl<T> fmt::Debug for Producer<T> {
    /// Prints only the capacity; queued values are not shown (there is no `T: Debug` bound).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("Producer");
        out.field("capacity", &self.capacity());
        out.finish_non_exhaustive()
    }
}
/// Receiving half of a ring buffer created by [`ring_buffer`].
///
/// No `Clone` impl is provided alongside it (unlike [`Producer`]), and `pop`
/// advances the head without any CAS, so it relies on being the only reader.
#[repr(C)]
pub struct Consumer<T> {
/// Next position to read. Only this handle advances the head, so it keeps its own
/// copy and never needs to load `shared.head`.
local_head: Cell<usize>,
/// Copy of `shared.slots`.
slots: *mut Slot<T>,
/// Copy of `shared.mask` (`capacity - 1`).
mask: usize,
/// Copy of `shared.shift` (`capacity.trailing_zeros()`).
shift: u32,
/// Keeps the shared counters and slot array alive while this handle exists.
shared: Arc<Shared<T>>,
}
// SAFETY: the raw `slots` pointer suppresses the automatic `Send` impl. Moving the
// consumer to another thread is sound for `T: Send`: it only takes values out of slots
// whose publication it observed through an Acquire load of the slot stamp. It remains
// `!Sync` (the `Cell` field), so `pop` is never run concurrently on one handle.
unsafe impl<T: Send> Send for Consumer<T> {}
impl<T> Consumer<T> {
    /// Removes and returns the value at the consumer's current position.
    ///
    /// Returns `None` when the slot at the head does not carry the "published"
    /// stamp yet: either nothing has been pushed there, or a producer has claimed
    /// that position but not finished writing it. Only the head slot is checked, so
    /// `None` can be returned even if later positions are already written.
    #[inline]
    pub fn pop(&self) -> Option<T> {
        let pos = self.local_head.get();
        let lap = pos >> self.shift;
        // SAFETY: `pos & mask < capacity`, and `shared` keeps the slot array alive.
        let slot = unsafe { &*self.slots.add(pos & self.mask) };
        // Acquire pairs with the producer's Release store made after writing `data`.
        if slot.turn.load(Ordering::Acquire) == lap * 2 + 1 {
            // SAFETY: the odd stamp for this lap means `data` holds an initialized
            // value that nobody else will read; the stamp is advanced right below.
            let value = unsafe { (*slot.data.get()).assume_init_read() };
            // Free the slot for the next lap *before* publishing the new head, so a
            // producer that observes this head also observes the freed stamp.
            slot.turn.store((lap + 1) * 2, Ordering::Release);
            let next = pos.wrapping_add(1);
            self.local_head.set(next);
            self.shared.head.store(next, Ordering::Release);
            Some(value)
        } else {
            None
        }
    }

    /// Number of slots: the requested capacity rounded up to a power of two.
    #[inline]
    pub fn capacity(&self) -> usize {
        1usize << self.shift
    }

    /// Returns `true` once this consumer holds the only remaining reference to the
    /// shared buffer, i.e. every [`Producer`] handle has been dropped.
    #[inline]
    pub fn is_disconnected(&self) -> bool {
        let holders = Arc::strong_count(&self.shared);
        holders == 1
    }
}
impl<T> fmt::Debug for Consumer<T> {
    /// Prints only the capacity; queued values are not shown (there is no `T: Debug` bound).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("Consumer");
        out.field("capacity", &self.capacity());
        out.finish_non_exhaustive()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn basic_push_pop() {
        let (tx, rx) = ring_buffer::<u64>(4);
        assert!(tx.push(1).is_ok());
        assert!(tx.push(2).is_ok());
        assert!(tx.push(3).is_ok());
        assert_eq!(rx.pop(), Some(1));
        assert_eq!(rx.pop(), Some(2));
        assert_eq!(rx.pop(), Some(3));
        assert_eq!(rx.pop(), None);
    }

    #[test]
    fn empty_pop_returns_none() {
        let (_, rx) = ring_buffer::<u64>(4);
        assert_eq!(rx.pop(), None);
        assert_eq!(rx.pop(), None);
    }

    #[test]
    fn fill_then_drain() {
        let (tx, rx) = ring_buffer::<u64>(4);
        for i in 0..4 {
            assert!(tx.push(i).is_ok());
        }
        for i in 0..4 {
            assert_eq!(rx.pop(), Some(i));
        }
        assert_eq!(rx.pop(), None);
    }

    #[test]
    fn push_returns_error_when_full() {
        let (tx, _rx) = ring_buffer::<u64>(4);
        assert!(tx.push(1).is_ok());
        assert!(tx.push(2).is_ok());
        assert!(tx.push(3).is_ok());
        assert!(tx.push(4).is_ok());
        let err = tx.push(5).unwrap_err();
        assert_eq!(err.into_inner(), 5);
    }

    #[test]
    fn cloned_producers_share_capacity() {
        // Each clone has its own cached head; both must see the shared fullness.
        let (tx1, rx) = ring_buffer::<u64>(2);
        let tx2 = tx1.clone();
        assert!(tx1.push(1).is_ok());
        assert!(tx2.push(2).is_ok());
        assert_eq!(tx1.push(3).unwrap_err().into_inner(), 3);
        assert_eq!(tx2.push(4).unwrap_err().into_inner(), 4);
        assert_eq!(rx.pop(), Some(1));
        assert!(tx2.push(5).is_ok());
        assert_eq!(rx.pop(), Some(2));
        assert_eq!(rx.pop(), Some(5));
        assert_eq!(rx.pop(), None);
    }

    #[test]
    fn interleaved_single_producer() {
        let (tx, rx) = ring_buffer::<u64>(8);
        for i in 0..1000 {
            assert!(tx.push(i).is_ok());
            assert_eq!(rx.pop(), Some(i));
        }
    }

    #[test]
    fn partial_fill_drain_cycles() {
        let (tx, rx) = ring_buffer::<u64>(8);
        for round in 0..100 {
            for i in 0..4 {
                assert!(tx.push(round * 4 + i).is_ok());
            }
            for i in 0..4 {
                assert_eq!(rx.pop(), Some(round * 4 + i));
            }
        }
    }

    #[test]
    fn two_producers_single_consumer() {
        use std::thread;
        let (tx, rx) = ring_buffer::<u64>(64);
        let tx2 = tx.clone();
        let h1 = thread::spawn(move || {
            for i in 0..1000 {
                while tx.push(i).is_err() {
                    std::hint::spin_loop();
                }
            }
        });
        let h2 = thread::spawn(move || {
            for i in 1000..2000 {
                while tx2.push(i).is_err() {
                    std::hint::spin_loop();
                }
            }
        });
        let mut received = Vec::with_capacity(2000);
        while received.len() < 2000 {
            match rx.pop() {
                Some(val) => received.push(val),
                None => std::hint::spin_loop(),
            }
        }
        h1.join().unwrap();
        h2.join().unwrap();
        received.sort_unstable();
        assert_eq!(received, (0..2000).collect::<Vec<_>>());
    }

    #[test]
    fn four_producers_single_consumer() {
        use std::thread;
        let (tx, rx) = ring_buffer::<u64>(256);
        let handles: Vec<_> = (0..4)
            .map(|p| {
                let tx = tx.clone();
                thread::spawn(move || {
                    for i in 0..1000 {
                        let val = p * 1000 + i;
                        while tx.push(val).is_err() {
                            std::hint::spin_loop();
                        }
                    }
                })
            })
            .collect();
        drop(tx);
        let mut received = Vec::with_capacity(4000);
        while received.len() < 4000 {
            match rx.pop() {
                Some(val) => received.push(val),
                None => std::hint::spin_loop(),
            }
        }
        for h in handles {
            h.join().unwrap();
        }
        received.sort_unstable();
        // Producer `p` sends `p * 1000 .. p * 1000 + 1000`, so together they cover 0..4000.
        let expected: Vec<u64> = (0..4000).collect();
        assert_eq!(received, expected);
    }

    #[test]
    fn single_slot_bounded() {
        let (tx, rx) = ring_buffer::<u64>(1);
        assert!(tx.push(1).is_ok());
        assert!(tx.push(2).is_err());
        assert_eq!(rx.pop(), Some(1));
        assert!(tx.push(2).is_ok());
    }

    #[test]
    fn producer_disconnected() {
        let (tx, rx) = ring_buffer::<u64>(4);
        assert!(!rx.is_disconnected());
        drop(tx);
        assert!(rx.is_disconnected());
    }

    #[test]
    fn consumer_disconnected() {
        let (tx, rx) = ring_buffer::<u64>(4);
        assert!(!tx.is_disconnected());
        drop(rx);
        assert!(tx.is_disconnected());
    }

    #[test]
    fn multiple_producers_one_disconnects() {
        let (tx1, rx) = ring_buffer::<u64>(4);
        let tx2 = tx1.clone();
        assert!(!rx.is_disconnected());
        drop(tx1);
        assert!(!rx.is_disconnected());
        drop(tx2);
        assert!(rx.is_disconnected());
    }

    #[test]
    fn drop_cleans_up_remaining() {
        use std::sync::atomic::AtomicUsize;
        static DROP_COUNT: AtomicUsize = AtomicUsize::new(0);
        struct DropCounter;
        impl Drop for DropCounter {
            fn drop(&mut self) {
                DROP_COUNT.fetch_add(1, Ordering::SeqCst);
            }
        }
        DROP_COUNT.store(0, Ordering::SeqCst);
        let (tx, rx) = ring_buffer::<DropCounter>(4);
        let _ = tx.push(DropCounter);
        let _ = tx.push(DropCounter);
        let _ = tx.push(DropCounter);
        assert_eq!(DROP_COUNT.load(Ordering::SeqCst), 0);
        drop(tx);
        drop(rx);
        assert_eq!(DROP_COUNT.load(Ordering::SeqCst), 3);
    }

    #[test]
    fn drop_cleans_up_after_partial_drain() {
        use std::sync::atomic::AtomicUsize;
        static DROP_COUNT: AtomicUsize = AtomicUsize::new(0);
        struct DropCounter;
        impl Drop for DropCounter {
            fn drop(&mut self) {
                DROP_COUNT.fetch_add(1, Ordering::SeqCst);
            }
        }
        let (tx, rx) = ring_buffer::<DropCounter>(4);
        for _ in 0..3 {
            assert!(tx.push(DropCounter).is_ok());
        }
        // The popped value is dropped here; the remaining two must be dropped exactly
        // once by `Shared::drop`, and the popped one must not be dropped again.
        drop(rx.pop());
        assert_eq!(DROP_COUNT.load(Ordering::SeqCst), 1);
        drop(tx);
        drop(rx);
        assert_eq!(DROP_COUNT.load(Ordering::SeqCst), 3);
    }

    #[test]
    fn zero_sized_type() {
        let (tx, rx) = ring_buffer::<()>(8);
        let _ = tx.push(());
        let _ = tx.push(());
        assert_eq!(rx.pop(), Some(()));
        assert_eq!(rx.pop(), Some(()));
        assert_eq!(rx.pop(), None);
    }

    #[test]
    fn string_type() {
        let (tx, rx) = ring_buffer::<String>(4);
        let _ = tx.push("hello".to_string());
        let _ = tx.push("world".to_string());
        assert_eq!(rx.pop(), Some("hello".to_string()));
        assert_eq!(rx.pop(), Some("world".to_string()));
    }

    #[test]
    #[should_panic(expected = "capacity must be non-zero")]
    fn zero_capacity_panics() {
        let _ = ring_buffer::<u64>(0);
    }

    #[test]
    fn large_message_type() {
        #[repr(C, align(64))]
        struct LargeMessage {
            data: [u8; 256],
        }
        let (tx, rx) = ring_buffer::<LargeMessage>(8);
        let msg = LargeMessage { data: [42u8; 256] };
        assert!(tx.push(msg).is_ok());
        let received = rx.pop().unwrap();
        assert_eq!(received.data[0], 42);
        assert_eq!(received.data[255], 42);
    }

    #[test]
    fn multiple_laps() {
        let (tx, rx) = ring_buffer::<u64>(4);
        for i in 0..40 {
            assert!(tx.push(i).is_ok());
            assert_eq!(rx.pop(), Some(i));
        }
    }

    #[test]
    fn capacity_rounds_to_power_of_two() {
        let (tx, _) = ring_buffer::<u64>(100);
        assert_eq!(tx.capacity(), 128);
        let (tx, _) = ring_buffer::<u64>(1000);
        assert_eq!(tx.capacity(), 1024);
    }

    #[test]
    fn stress_single_producer() {
        use std::thread;
        const COUNT: u64 = 100_000;
        let (tx, rx) = ring_buffer::<u64>(1024);
        let producer = thread::spawn(move || {
            for i in 0..COUNT {
                while tx.push(i).is_err() {
                    std::hint::spin_loop();
                }
            }
        });
        let consumer = thread::spawn(move || {
            let mut sum = 0u64;
            let mut received = 0u64;
            while received < COUNT {
                if let Some(val) = rx.pop() {
                    sum = sum.wrapping_add(val);
                    received += 1;
                } else {
                    std::hint::spin_loop();
                }
            }
            sum
        });
        producer.join().unwrap();
        let sum = consumer.join().unwrap();
        assert_eq!(sum, COUNT * (COUNT - 1) / 2);
    }

    #[test]
    fn stress_multiple_producers() {
        use std::thread;
        const PRODUCERS: u64 = 4;
        const PER_PRODUCER: u64 = 25_000;
        const TOTAL: u64 = PRODUCERS * PER_PRODUCER;
        let (tx, rx) = ring_buffer::<u64>(1024);
        let handles: Vec<_> = (0..PRODUCERS)
            .map(|_| {
                let tx = tx.clone();
                thread::spawn(move || {
                    for i in 0..PER_PRODUCER {
                        while tx.push(i).is_err() {
                            std::hint::spin_loop();
                        }
                    }
                })
            })
            .collect();
        drop(tx);
        let mut received = 0u64;
        while received < TOTAL {
            if rx.pop().is_some() {
                received += 1;
            } else {
                std::hint::spin_loop();
            }
        }
        for h in handles {
            h.join().unwrap();
        }
        assert_eq!(received, TOTAL);
    }
}