use crossbeam_utils::CachePadded;
use std::cell::UnsafeCell;
use std::mem::MaybeUninit;
#[cfg(any(test, debug_assertions))]
use std::sync::atomic::AtomicBool;
use std::sync::atomic::{AtomicU64, Ordering};
/// Error returned by `RingBuffer::try_push` when the buffer has no free slot.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) struct BufferFullError;

impl std::fmt::Display for BufferFullError {
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        formatter.write_str("ring buffer is full")
    }
}

/// Plain marker impl: no underlying `source`, all default `Error` methods.
impl std::error::Error for BufferFullError {}
/// Fixed-capacity single-producer / single-consumer ring buffer.
///
/// `head` and `tail` are free-running `u64` cursors; a cursor is mapped to a
/// slot index with `cursor & mask`. The queued elements occupy cursors
/// `[tail, head)`. One slot is always left unused, so the buffer is full at
/// `capacity - 1` elements.
///
/// Every operation takes `&self`; correctness relies on callers keeping at
/// most one producer-side and one consumer-side call in flight at a time.
/// Test and debug-assertion builds detect violations through the
/// `*_in_progress` flags.
pub(crate) struct RingBuffer<T> {
// Slot storage; only slots whose cursor lies in `[tail, head)` are initialized.
buffer: Box<[UnsafeCell<MaybeUninit<T>>]>,
// Number of slots: a power of two and at least 2 (asserted in `new`).
capacity: usize,
// `capacity - 1`; `cursor & mask` is the slot index for a cursor.
mask: usize,
// Next cursor to write; advanced by `try_push`. Wrapped in crossbeam's
// `CachePadded`, presumably so it does not share a cache line with `tail`.
head: CachePadded<AtomicU64>,
// Next cursor to read; advanced by the methods that move elements out
// (e.g. `try_pop`, `pop_batch`, `evict_oldest`).
tail: CachePadded<AtomicU64>,
// Tripwire: true while a producer-side call (e.g. `try_push`) is in flight.
#[cfg(any(test, debug_assertions))]
producer_in_progress: AtomicBool,
// Tripwire: true while a consumer-side call (e.g. `try_pop`) is in flight.
#[cfg(any(test, debug_assertions))]
consumer_in_progress: AtomicBool,
}
/// RAII tripwire for the SPSC contract (test / debug-assertion builds only).
///
/// Entering marks one side of a `RingBuffer` as busy; dropping the guard
/// clears the mark again, including on early return and during unwinding.
#[cfg(any(test, debug_assertions))]
struct InProgressGuard<'a> {
    /// The "call in flight" flag of the side this guard protects.
    flag: &'a AtomicBool,
}

#[cfg(any(test, debug_assertions))]
impl<'a> InProgressGuard<'a> {
    /// Marks `flag` busy and returns a guard that clears it when dropped.
    ///
    /// # Panics
    ///
    /// Panics if `flag` was already set, meaning another call on the same side
    /// is still in flight. `label` names the calling method in the message.
    #[inline]
    fn enter(flag: &'a AtomicBool, label: &'static str) -> Self {
        // Acquire pairs with the Release store in `drop`, so this call is
        // ordered after the previous holder's exit.
        let already_busy = flag.swap(true, Ordering::Acquire);
        if already_busy {
            panic!(
                "SPSC violation: {label} called concurrently with another \
                 {label} on the same RingBuffer (the SPSC contract requires \
                 at most one in-flight call per side at a time — typically \
                 upheld by the outer Shard mutex)"
            );
        }
        InProgressGuard { flag }
    }
}

#[cfg(any(test, debug_assertions))]
impl Drop for InProgressGuard<'_> {
    /// Releases the side so the next call may enter.
    #[inline]
    fn drop(&mut self) {
        self.flag.store(false, Ordering::Release);
    }
}
// SAFETY: the buffer owns its `T` values and hands them between threads by
// value (written in `try_push`, moved out by the popping methods), so
// `T: Send` is required for both impls. Sharing `&RingBuffer<T>` across threads
// is sound only while callers keep at most one producer-side and one
// consumer-side call in flight, with slot handoff ordered by the Release
// stores / Acquire loads of `head` and `tail`. The methods are safe and take
// `&self`, so the type system does not enforce this; test/debug builds only
// detect violations via the `*_in_progress` flags.
// NOTE(review): this is a crate-internal (`pub(crate)`) invariant — confirm
// every caller serialises each side (the tripwire message points at an outer
// Shard mutex).
unsafe impl<T: Send> Send for RingBuffer<T> {}
unsafe impl<T: Send> Sync for RingBuffer<T> {}
impl<T> RingBuffer<T> {
    /// Creates an empty ring buffer with `capacity` slots.
    ///
    /// One slot is always kept free to tell "full" from "empty", so at most
    /// `capacity - 1` elements can be queued at once.
    ///
    /// # Panics
    ///
    /// Panics if `capacity` is not a power of two or is smaller than 2.
    pub fn new(capacity: usize) -> Self {
        assert!(capacity.is_power_of_two(), "capacity must be a power of 2");
        assert!(capacity >= 2, "capacity must be at least 2");
        let buffer: Box<[UnsafeCell<MaybeUninit<T>>]> = (0..capacity)
            .map(|_| UnsafeCell::new(MaybeUninit::uninit()))
            .collect();
        Self {
            buffer,
            capacity,
            // Power-of-two capacity lets `cursor & mask` stand in for `cursor % capacity`.
            mask: capacity - 1,
            head: CachePadded::new(AtomicU64::new(0)),
            tail: CachePadded::new(AtomicU64::new(0)),
            #[cfg(any(test, debug_assertions))]
            producer_in_progress: AtomicBool::new(false),
            #[cfg(any(test, debug_assertions))]
            consumer_in_progress: AtomicBool::new(false),
        }
    }

    /// Appends `value` at the head of the queue (producer side).
    ///
    /// # Errors
    ///
    /// Returns [`BufferFullError`] when `capacity - 1` elements are already
    /// queued; `value` is dropped in that case.
    #[inline]
    pub fn try_push(&self, value: T) -> Result<(), BufferFullError> {
        #[cfg(any(test, debug_assertions))]
        let _spsc_guard = InProgressGuard::enter(&self.producer_in_progress, "try_push");
        // Only the producer writes `head`, so reading our own cursor needs no ordering.
        let head = self.head.load(Ordering::Relaxed);
        // Acquire pairs with the consumer-side Release store of `tail`: any slot
        // it reports as vacated has really been read before we overwrite it.
        let tail = self.tail.load(Ordering::Acquire);
        if head.wrapping_sub(tail) >= (self.capacity as u64) - 1 {
            return Err(BufferFullError);
        }
        let index = (head & self.mask as u64) as usize;
        // SAFETY: fewer than `capacity - 1` elements occupy `[tail, head)`, so
        // slot `head` holds no live element. The SPSC contract gives this call
        // exclusive producer access, so nobody else writes it concurrently.
        unsafe {
            (*self.buffer[index].get()).write(value);
        }
        // Release publishes the slot write to readers that Acquire-load `head`.
        self.head.store(head.wrapping_add(1), Ordering::Release);
        Ok(())
    }

    /// Removes and returns the oldest queued element from the producer side,
    /// or `None` if the buffer is empty.
    ///
    /// Although it is a producer-side call, this advances `tail` exactly like
    /// `try_pop`, so it must never overlap a consumer-side call. It also needs
    /// a happens-before edge with the last consumer call (e.g. a mutex held
    /// around both sides). The debug tripwire therefore claims both sides.
    #[inline]
    pub(crate) fn evict_oldest(&self) -> Option<T> {
        #[cfg(any(test, debug_assertions))]
        let _spsc_guard = InProgressGuard::enter(&self.producer_in_progress, "evict_oldest");
        // Overlapping with try_pop / pop_batch* would read the same slot twice
        // (a double drop), so claim the consumer side as well.
        #[cfg(any(test, debug_assertions))]
        let _consumer_guard = InProgressGuard::enter(&self.consumer_in_progress, "evict_oldest");
        self.pop_front_unguarded()
    }

    /// Removes and returns the oldest queued element (consumer side), or
    /// `None` if the buffer is empty.
    #[inline]
    pub fn try_pop(&self) -> Option<T> {
        #[cfg(any(test, debug_assertions))]
        let _spsc_guard = InProgressGuard::enter(&self.consumer_in_progress, "try_pop");
        self.pop_front_unguarded()
    }

    /// Removes up to `max` of the oldest elements (consumer side) and returns
    /// them in FIFO order. Returns an empty, unallocated `Vec` when nothing
    /// is queued or `max == 0`.
    #[inline]
    pub fn pop_batch(&self, max: usize) -> Vec<T> {
        #[cfg(any(test, debug_assertions))]
        let _spsc_guard = InProgressGuard::enter(&self.consumer_in_progress, "pop_batch");
        let mut batch = Vec::new();
        self.drain_into_unguarded(&mut batch, max);
        batch
    }

    /// Moves up to `max` of the oldest elements (consumer side) onto the end
    /// of `dst` in FIFO order, keeping `dst`'s existing contents. Returns how
    /// many elements were moved.
    #[inline]
    pub fn pop_batch_into(&self, dst: &mut Vec<T>, max: usize) -> usize {
        #[cfg(any(test, debug_assertions))]
        let _spsc_guard = InProgressGuard::enter(&self.consumer_in_progress, "pop_batch_into");
        self.drain_into_unguarded(dst, max)
    }

    /// Returns a snapshot of the number of queued elements.
    ///
    /// The value may already be out of date when it is used if the other
    /// side is running concurrently.
    #[inline]
    pub fn len(&self) -> usize {
        // `head` is loaded before `tail`. The producer published `head` only
        // after observing a `tail` that left room. Coherence makes our later
        // `tail` load at least that value, so `head - tail <= capacity - 1`.
        // A thread that is neither producer nor consumer can, however, see a
        // `tail` the consumer advanced past our `head` snapshot. That
        // "negative" length would wrap to a huge u64, so it is reported as 0.
        let head = self.head.load(Ordering::Acquire);
        let tail = self.tail.load(Ordering::Acquire);
        let diff = head.wrapping_sub(tail);
        if (diff as i64) < 0 {
            0
        } else {
            diff.min(self.mask as u64) as usize
        }
    }

    /// Returns `true` if no element is queued (snapshot; see [`Self::len`]).
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Returns `true` if `capacity - 1` elements are queued, i.e. the next
    /// `try_push` would fail (snapshot; see [`Self::len`]).
    #[inline]
    pub fn is_full(&self) -> bool {
        self.len() >= self.capacity - 1
    }

    /// Total slot count, including the one slot that is always left free.
    #[cfg(test)]
    #[inline]
    fn capacity(&self) -> usize {
        self.capacity
    }

    /// Number of additional elements `try_push` would currently accept.
    #[cfg(test)]
    #[inline]
    fn free_slots(&self) -> usize {
        self.capacity - 1 - self.len()
    }

    /// Moves the element at `tail` out and advances `tail` by one.
    ///
    /// Shared core of `try_pop` and `evict_oldest`. It performs no tripwire
    /// check: the caller must hold exclusive consumer-side access.
    #[inline]
    fn pop_front_unguarded(&self) -> Option<T> {
        let tail = self.tail.load(Ordering::Relaxed);
        // Acquire pairs with the Release store in `try_push`, making the slot
        // contents written before that store visible here.
        let head = self.head.load(Ordering::Acquire);
        if tail == head {
            return None;
        }
        let index = (tail & self.mask as u64) as usize;
        // SAFETY: `tail != head`, so slot `tail` holds an initialized element
        // that has not been moved out yet. The caller's exclusive consumer
        // access guarantees nobody else reads it before `tail` moves past it.
        let value = unsafe { (*self.buffer[index].get()).assume_init_read() };
        // Release lets the producer's Acquire load of `tail` see that the
        // slot has been vacated before it is reused.
        self.tail.store(tail.wrapping_add(1), Ordering::Release);
        Some(value)
    }

    /// Moves up to `max` elements starting at `tail` onto `dst`, then
    /// advances `tail` past them. Returns the number moved.
    ///
    /// Shared core of `pop_batch` and `pop_batch_into`. It performs no
    /// tripwire check: the caller must hold exclusive consumer-side access.
    #[inline]
    fn drain_into_unguarded(&self, dst: &mut Vec<T>, max: usize) -> usize {
        let tail = self.tail.load(Ordering::Relaxed);
        let head = self.head.load(Ordering::Acquire);
        let count = head.wrapping_sub(tail).min(max as u64) as usize;
        if count == 0 {
            return 0;
        }
        // Reserve before reading any slot. Afterwards `push` cannot
        // reallocate, and so cannot panic, between moving elements out and
        // publishing the new `tail`. A panic in that window would let the
        // same elements be read, and dropped, a second time.
        dst.reserve(count);
        let mask = self.mask as u64;
        for offset in 0..count as u64 {
            let index = (tail.wrapping_add(offset) & mask) as usize;
            // SAFETY: `count <= head - tail`, so each slot in this range holds
            // an initialized element, and each is moved out exactly once
            // before `tail` is advanced below.
            let value = unsafe { (*self.buffer[index].get()).assume_init_read() };
            dst.push(value);
        }
        self.tail
            .store(tail.wrapping_add(count as u64), Ordering::Release);
        count
    }
}
impl<T> Drop for RingBuffer<T> {
fn drop(&mut self) {
#[cfg(any(test, debug_assertions))]
{
self.producer_in_progress.store(false, Ordering::Release);
self.consumer_in_progress.store(false, Ordering::Release);
}
while self.try_pop().is_some() {}
}
}
// Unit tests for `RingBuffer` and `BufferFullError`: FIFO behaviour, masked
// cursor wraparound, batch draining, dropping of leftover elements, a
// two-thread SPSC smoke test, and the test/debug-only SPSC tripwire
// (`producer_in_progress` / `consumer_in_progress`).
#[cfg(test)]
mod tests {
use super::*;
// A capacity-4 buffer accepts 3 elements (one slot stays free) and pops them FIFO.
#[test]
fn test_push_pop() {
let buf = RingBuffer::new(4);
assert!(buf.is_empty());
assert_eq!(buf.len(), 0);
buf.try_push(1).unwrap();
buf.try_push(2).unwrap();
buf.try_push(3).unwrap();
assert_eq!(buf.len(), 3);
assert!(buf.is_full());
assert!(buf.try_push(4).is_err());
assert_eq!(buf.try_pop(), Some(1));
assert_eq!(buf.try_pop(), Some(2));
assert_eq!(buf.try_pop(), Some(3));
assert_eq!(buf.try_pop(), None);
assert!(buf.is_empty());
}
// `pop_batch` returns at most `max` elements, oldest first.
#[test]
fn test_pop_batch() {
let buf = RingBuffer::new(8);
for i in 0..5 {
buf.try_push(i).unwrap();
}
let batch = buf.pop_batch(3);
assert_eq!(batch, vec![0, 1, 2]);
let batch = buf.pop_batch(10); assert_eq!(batch, vec![3, 4]);
assert!(buf.is_empty());
}
// `pop_batch_into` appends to `dst` (keeping existing contents) and returns the count moved.
#[test]
fn test_pop_batch_into() {
let buf = RingBuffer::new(8);
for i in 0..5 {
buf.try_push(i).unwrap();
}
let mut dst = vec![999u32];
let drained = buf.pop_batch_into(&mut dst, 3);
assert_eq!(drained, 3);
assert_eq!(dst, vec![999, 0, 1, 2]);
dst.clear();
let drained = buf.pop_batch_into(&mut dst, 10);
assert_eq!(drained, 2);
assert_eq!(dst, vec![3, 4]);
assert!(buf.is_empty());
dst.clear();
let drained = buf.pop_batch_into(&mut dst, 100);
assert_eq!(drained, 0);
assert!(dst.is_empty());
}
// Reusing one scratch Vec across repeated wraparounds of a capacity-4 buffer
// preserves global FIFO order (`append` empties `scratch` each round).
#[test]
fn test_pop_batch_into_scratch_reuse_across_wraparound() {
let buf = RingBuffer::new(4);
let mut scratch: Vec<u32> = Vec::with_capacity(2);
let mut seen: Vec<u32> = Vec::new();
for round in 0..10u32 {
for i in 0..3 {
buf.try_push(round * 3 + i).unwrap();
}
let drained = buf.pop_batch_into(&mut scratch, 3);
assert_eq!(drained, 3);
seen.append(&mut scratch); }
let expected: Vec<u32> = (0..30).collect();
assert_eq!(seen, expected);
}
// Cursors keep growing while slot indices wrap through the mask.
#[test]
fn test_wraparound() {
let buf = RingBuffer::new(4);
for round in 0..10 {
for i in 0..3 {
buf.try_push(round * 3 + i).unwrap();
}
for i in 0..3 {
assert_eq!(buf.try_pop(), Some(round * 3 + i));
}
}
}
// One producer thread and one consumer thread exchange `count` values; the
// consumer must receive them all, in push order.
#[test]
fn test_concurrent_spsc() {
use std::sync::Arc;
use std::thread;
let buf = Arc::new(RingBuffer::new(1024));
let buf_producer = buf.clone();
let buf_consumer = buf.clone();
let count = 100_000;
let producer = thread::spawn(move || {
for i in 0..count {
while buf_producer.try_push(i).is_err() {
std::hint::spin_loop();
}
}
});
let consumer = thread::spawn(move || {
let mut received = Vec::with_capacity(count);
while received.len() < count {
if let Some(val) = buf_consumer.try_pop() {
received.push(val);
} else {
std::hint::spin_loop();
}
}
received
});
producer.join().unwrap();
let received = consumer.join().unwrap();
assert_eq!(received.len(), count);
for (i, &val) in received.iter().enumerate() {
assert_eq!(val, i, "mismatch at index {}", i);
}
}
#[test]
#[should_panic(expected = "power of 2")]
fn test_non_power_of_two_capacity() {
let _ = RingBuffer::<i32>::new(5);
}
// Elements still queued when the buffer goes out of scope are dropped exactly once.
#[test]
fn test_drop() {
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
let drop_count = Arc::new(AtomicUsize::new(0));
struct DropCounter(Arc<AtomicUsize>);
impl Drop for DropCounter {
fn drop(&mut self) {
self.0.fetch_add(1, Ordering::SeqCst);
}
}
{
let buf = RingBuffer::new(8);
for _ in 0..5 {
buf.try_push(DropCounter(drop_count.clone())).unwrap();
}
}
assert_eq!(drop_count.load(Ordering::SeqCst), 5);
}
#[test]
fn test_buffer_full_error_display() {
let err = BufferFullError;
assert_eq!(format!("{}", err), "ring buffer is full");
}
#[test]
fn test_buffer_full_error_debug() {
let err = BufferFullError;
assert!(format!("{:?}", err).contains("BufferFullError"));
}
#[test]
fn test_buffer_full_error_is_error() {
let err: &dyn std::error::Error = &BufferFullError;
assert!(err.to_string().contains("full"));
}
// `free_slots` is `capacity - 1 - len`: the reserved empty slot is never counted as free.
#[test]
fn test_capacity_and_free_slots() {
let buf = RingBuffer::new(8);
assert_eq!(buf.capacity(), 8);
assert_eq!(buf.free_slots(), 7);
buf.try_push(1).unwrap();
assert_eq!(buf.free_slots(), 6);
buf.try_push(2).unwrap();
buf.try_push(3).unwrap();
assert_eq!(buf.free_slots(), 4);
}
// A capacity-4 buffer reports full only once it holds 3 elements.
#[test]
fn test_is_full() {
let buf = RingBuffer::new(4);
assert!(!buf.is_full());
buf.try_push(1).unwrap();
buf.try_push(2).unwrap();
assert!(!buf.is_full());
buf.try_push(3).unwrap();
assert!(buf.is_full());
}
#[test]
fn test_pop_batch_empty() {
let buf: RingBuffer<i32> = RingBuffer::new(8);
let batch = buf.pop_batch(10);
assert!(batch.is_empty());
}
// 1 is a power of two, so it passes the first assert and trips the minimum-size one.
#[test]
#[should_panic(expected = "at least 2")]
fn test_capacity_too_small() {
let _ = RingBuffer::<i32>::new(1);
}
// Popping one element from a full buffer frees exactly one slot for the next push.
#[test]
fn test_push_pop_at_exact_capacity() {
let buf = RingBuffer::new(4);
buf.try_push(1).unwrap();
buf.try_push(2).unwrap();
buf.try_push(3).unwrap();
assert!(buf.is_full());
assert!(buf.try_push(4).is_err());
assert_eq!(buf.try_pop(), Some(1));
buf.try_push(4).unwrap();
assert!(buf.is_full());
assert_eq!(buf.try_pop(), Some(2));
assert_eq!(buf.try_pop(), Some(3));
assert_eq!(buf.try_pop(), Some(4));
assert!(buf.is_empty());
}
// 100 fill-to-full / drain-to-empty rounds exercise the full boundary across many wraps.
#[test]
fn test_push_pop_boundary_stress() {
let buf = RingBuffer::new(4);
for round in 0..100 {
for i in 0..3 {
buf.try_push(round * 3 + i)
.unwrap_or_else(|_| panic!("push failed at round {} item {}", round, i));
}
assert!(buf.is_full());
assert!(buf.try_push(999).is_err());
for i in 0..3 {
assert_eq!(buf.try_pop(), Some(round * 3 + i));
}
assert!(buf.is_empty());
}
}
// The explicit `u64` annotations only compile if the cursors are `AtomicU64`,
// independent of the target's pointer width.
#[test]
fn ring_buffer_cursors_are_u64_on_every_target() {
let buf: RingBuffer<u32> = RingBuffer::new(4);
let head_val: u64 = buf.head.load(Ordering::Relaxed);
let tail_val: u64 = buf.tail.load(Ordering::Relaxed);
assert_eq!(head_val, 0);
assert_eq!(tail_val, 0);
let len: u64 = head_val.wrapping_sub(tail_val);
assert_eq!(len, 0);
}
// The next three tests: the tripwire detects overlapping calls, not a change of thread,
// so sequential calls from different threads must not panic.
#[test]
fn sequential_cross_thread_push_is_allowed() {
use std::sync::Arc;
use std::thread;
let buf = Arc::new(RingBuffer::new(1024));
buf.try_push(1).unwrap();
let buf2 = buf.clone();
let result = thread::spawn(move || buf2.try_push(2).unwrap()).join();
assert!(
result.is_ok(),
"sequential cross-thread push must be allowed (the SPSC \
contract is about non-concurrency, not thread identity — \
tokio task migration must not trip the tripwire)",
);
}
#[test]
fn sequential_cross_thread_pop_is_allowed() {
use std::sync::Arc;
use std::thread;
let buf = Arc::new(RingBuffer::new(1024));
buf.try_push(1).unwrap();
buf.try_push(2).unwrap();
let _ = buf.try_pop();
let buf2 = buf.clone();
let result = thread::spawn(move || buf2.try_pop()).join();
assert!(
result.is_ok(),
"sequential cross-thread pop must be allowed",
);
}
#[test]
fn sequential_cross_thread_evict_oldest_is_allowed() {
use std::sync::Arc;
use std::thread;
let buf = Arc::new(RingBuffer::new(4));
buf.try_push(1).unwrap();
let buf2 = buf.clone();
let result = thread::spawn(move || buf2.evict_oldest()).join();
assert!(
result.is_ok(),
"sequential cross-thread evict_oldest must be allowed",
);
}
// The next three tests simulate an in-flight call by setting a tripwire flag directly,
// check that the guarded method panics, then clear the flag again.
#[test]
fn concurrent_producer_panics_via_simulated_in_progress_flag() {
let buf = RingBuffer::<i32>::new(8);
buf.producer_in_progress.store(true, Ordering::Release);
let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
buf.try_push(1).unwrap();
}));
assert!(
result.is_err(),
"try_push must panic when a producer is already in-progress (real SPSC violation)",
);
buf.producer_in_progress.store(false, Ordering::Release);
}
#[test]
fn concurrent_consumer_panics_via_simulated_in_progress_flag() {
let buf = RingBuffer::<i32>::new(8);
buf.try_push(1).unwrap();
buf.consumer_in_progress.store(true, Ordering::Release);
let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
let _ = buf.try_pop();
}));
assert!(
result.is_err(),
"try_pop must panic when a consumer is already in-progress",
);
buf.consumer_in_progress.store(false, Ordering::Release);
}
#[test]
fn concurrent_evict_oldest_panics_via_simulated_in_progress_flag() {
let buf = RingBuffer::<i32>::new(4);
buf.try_push(1).unwrap();
buf.producer_in_progress.store(true, Ordering::Release);
let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
let _ = buf.evict_oldest();
}));
assert!(
result.is_err(),
"evict_oldest must panic when a producer is already in-progress",
);
buf.producer_in_progress.store(false, Ordering::Release);
}
// The `count == 0` early-return path of `pop_batch_into` must still clear the
// consumer flag, so a later call on the same side succeeds.
#[test]
fn guard_releases_flag_on_early_return() {
let buf = RingBuffer::<i32>::new(4);
let mut scratch = Vec::new();
let popped = buf.pop_batch_into(&mut scratch, 8);
assert_eq!(popped, 0);
assert!(
!buf.consumer_in_progress.load(Ordering::Acquire),
"in-progress flag must be cleared on early return",
);
buf.try_push(42).unwrap();
assert_eq!(buf.pop_batch_into(&mut scratch, 8), 1);
}
}