use std::cell::UnsafeCell;
use std::mem::MaybeUninit;
use std::sync::atomic::{AtomicUsize, Ordering};
use super::{CachePadded, DEFAULT_RING_CAPACITY, next_power_of_two};
/// Fixed-capacity ring buffer built for single-producer/single-consumer use.
///
/// `head` counts items written by the producer and `tail` counts items read by
/// the consumer. Both are wrapping counters that are mapped to a slot with
/// `counter & mask`. The slots in positions `tail..head` hold initialized items;
/// all other slots are logically uninitialized.
pub struct LockFreeRing<T> {
/// Slot storage; `MaybeUninit` because only positions `tail..head` are initialized.
buffer: Box<[UnsafeCell<MaybeUninit<T>>]>,
/// Number of slots, rounded by `next_power_of_two` in `new`.
capacity: usize,
/// `capacity - 1`; turns a wrapping counter into a slot index.
mask: usize,
/// Producer index: total items enqueued so far (wrapping). Stored with `Release`.
head: CachePadded<AtomicUsize>,
/// Consumer index: total items dequeued so far (wrapping). Stored with `Release`.
tail: CachePadded<AtomicUsize>,
}
// SAFETY: items cross threads only by value (moved in by `enqueue`, moved out
// by `dequeue`), which is exactly what `T: Send` permits; a shared `&T` is only
// produced by the `unsafe fn peek`.
// NOTE(review): soundness additionally relies on callers using one producer and
// one consumer at a time; the safe `enqueue`/`dequeue` signatures cannot enforce it.
unsafe impl<T> Send for LockFreeRing<T> where T: Send {}
unsafe impl<T> Sync for LockFreeRing<T> where T: Send {}
impl<T> LockFreeRing<T> {
    /// Creates an empty ring with room for at least `capacity` items.
    ///
    /// `capacity` is rounded up with `next_power_of_two` (e.g. 3 -> 4, 5 -> 8)
    /// so that a wrapping counter can be mapped to a slot with a bit mask.
    ///
    /// # Panics
    ///
    /// Panics if `capacity` is 0.
    #[must_use]
    pub fn new(capacity: usize) -> Self {
        assert!(capacity > 0, "capacity must be > 0");
        let capacity = next_power_of_two(capacity);
        let mask = capacity - 1;
        let buffer: Vec<UnsafeCell<MaybeUninit<T>>> = (0..capacity)
            .map(|_| UnsafeCell::new(MaybeUninit::uninit()))
            .collect();
        Self {
            buffer: buffer.into_boxed_slice(),
            capacity,
            mask,
            head: CachePadded::new(AtomicUsize::new(0)),
            tail: CachePadded::new(AtomicUsize::new(0)),
        }
    }

    /// Creates an empty ring sized from `DEFAULT_RING_CAPACITY`.
    #[must_use]
    pub fn with_default_capacity() -> Self {
        Self::new(DEFAULT_RING_CAPACITY)
    }

    /// Returns the number of slots (after rounding in [`Self::new`]).
    #[inline]
    #[must_use]
    pub const fn capacity(&self) -> usize {
        self.capacity
    }

    /// Returns the number of queued items.
    ///
    /// Exact when called from the producer or consumer thread. From any other
    /// thread it is a snapshot that may be stale, but it is always within
    /// `0..=capacity()`.
    #[inline]
    #[must_use]
    pub fn len(&self) -> usize {
        // Load `tail` first. The consumer only publishes a `tail` after it has
        // observed a `head` at least that large, and our Acquire load
        // synchronizes with that Release store, so the `head` we read next is
        // not older than it. The two loads are still not one atomic snapshot
        // (both indices may move in between), so clamp the difference. Without
        // the clamp, `free_slots` could underflow.
        let tail = self.tail.0.load(Ordering::Acquire);
        let head = self.head.0.load(Ordering::Acquire);
        let queued = head.wrapping_sub(tail);
        if queued > usize::MAX / 2 {
            // `tail` appeared ahead of `head`: report empty, not a wrapped length.
            0
        } else {
            queued.min(self.capacity)
        }
    }

    /// Returns `true` if no items are queued (same snapshot caveats as [`Self::len`]).
    #[inline]
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Returns `true` if every slot is occupied (same snapshot caveats as [`Self::len`]).
    #[inline]
    #[must_use]
    pub fn is_full(&self) -> bool {
        self.len() >= self.capacity
    }

    /// Returns the number of unoccupied slots (same snapshot caveats as [`Self::len`]).
    #[inline]
    #[must_use]
    pub fn free_slots(&self) -> usize {
        // `len()` never exceeds `capacity`, so this cannot underflow.
        self.capacity - self.len()
    }

    /// Appends `item` at the head of the ring (producer side).
    ///
    /// `head` is advanced with a plain load/store rather than a compare-exchange,
    /// so at most one thread may call `enqueue`/`enqueue_batch` at a time.
    ///
    /// # Errors
    ///
    /// Returns `Err(item)`, handing the item back, if the ring is full.
    #[inline]
    pub fn enqueue(&self, item: T) -> Result<(), T> {
        let head = self.head.0.load(Ordering::Relaxed);
        // Acquire pairs with the consumer's Release store of `tail`, so any slot
        // it has released is fully read before we overwrite it.
        let tail = self.tail.0.load(Ordering::Acquire);
        if head.wrapping_sub(tail) >= self.capacity {
            return Err(item);
        }
        let idx = head & self.mask;
        // SAFETY: position `head` lies outside `tail..head`, so the consumer does
        // not touch this slot, and with a single producer no other thread writes
        // it. Overwriting without dropping is correct: any previous value was
        // already moved out by `dequeue`.
        unsafe {
            (*self.buffer[idx].get()).write(item);
        }
        // Release publishes the initialized slot to the consumer's Acquire load.
        self.head.0.store(head.wrapping_add(1), Ordering::Release);
        Ok(())
    }

    /// Removes and returns the oldest item, or `None` if the ring is empty
    /// (consumer side).
    ///
    /// `tail` is advanced with a plain load/store, so at most one thread may call
    /// `dequeue`/`dequeue_batch` at a time.
    #[inline]
    pub fn dequeue(&self) -> Option<T> {
        let tail = self.tail.0.load(Ordering::Relaxed);
        // Acquire pairs with the producer's Release store of `head`, making the
        // item written into slot `tail` visible here.
        let head = self.head.0.load(Ordering::Acquire);
        if tail == head {
            return None;
        }
        let idx = tail & self.mask;
        // SAFETY: position `tail` lies in `tail..head`, so the producer
        // initialized it before publishing `head`. It will not overwrite the slot
        // until we advance `tail` below. The value is moved out exactly once.
        let item = unsafe { (*self.buffer[idx].get()).assume_init_read() };
        // Release hands the now-empty slot back to the producer.
        self.tail.0.store(tail.wrapping_add(1), Ordering::Release);
        Some(item)
    }

    /// Copies as many items from the front of `items` as currently fit
    /// (producer side, same single-producer rule as [`Self::enqueue`]).
    ///
    /// All copied items are published together by a single store of `head`.
    /// Returns the number of items enqueued: 0 if the ring is full or `items`
    /// is empty.
    pub fn enqueue_batch(&self, items: &[T]) -> usize
    where
        T: Copy,
    {
        let head = self.head.0.load(Ordering::Relaxed);
        let tail = self.tail.0.load(Ordering::Acquire);
        let free = self.capacity - head.wrapping_sub(tail);
        let count = items.len().min(free);
        if count == 0 {
            return 0;
        }
        for (offset, item) in items[..count].iter().enumerate() {
            // `head` is a wrapping counter; a plain `+` would panic in debug
            // builds once it approaches `usize::MAX`.
            let idx = head.wrapping_add(offset) & self.mask;
            // SAFETY: positions `head..head + count` are free (checked above) and
            // only this single producer writes them until `head` is published.
            unsafe {
                (*self.buffer[idx].get()).write(*item);
            }
        }
        self.head
            .0
            .store(head.wrapping_add(count), Ordering::Release);
        count
    }

    /// Copies up to `out.len()` of the oldest items into the front of `out`
    /// (consumer side, same single-consumer rule as [`Self::dequeue`]).
    ///
    /// Returns the number of items written to `out`; the rest of `out` is untouched.
    pub fn dequeue_batch(&self, out: &mut [T]) -> usize
    where
        T: Copy,
    {
        let tail = self.tail.0.load(Ordering::Relaxed);
        let head = self.head.0.load(Ordering::Acquire);
        let available = head.wrapping_sub(tail);
        let count = out.len().min(available);
        if count == 0 {
            return 0;
        }
        for (offset, slot) in out[..count].iter_mut().enumerate() {
            // Wrapping for the same reason as in `enqueue_batch`.
            let idx = tail.wrapping_add(offset) & self.mask;
            // SAFETY: positions `tail..tail + count` lie in `tail..head`, all
            // initialized by the producer before the `head` store we observed.
            *slot = unsafe { (*self.buffer[idx].get()).assume_init_read() };
        }
        self.tail
            .0
            .store(tail.wrapping_add(count), Ordering::Release);
        count
    }

    /// Returns a reference to the oldest queued item without removing it,
    /// or `None` if the ring is empty.
    ///
    /// # Safety
    ///
    /// The caller must be the ring's only consumer. While the returned reference
    /// is alive, nobody may call `dequeue`, `dequeue_batch` or `clear`: advancing
    /// `tail` moves the item out and lets the producer overwrite the slot.
    #[inline]
    pub unsafe fn peek(&self) -> Option<&T> {
        let tail = self.tail.0.load(Ordering::Relaxed);
        let head = self.head.0.load(Ordering::Acquire);
        if tail == head {
            return None;
        }
        let idx = tail & self.mask;
        // SAFETY: the slot at `tail` is initialized (it lies in `tail..head`), and
        // the caller guarantees `tail` does not advance while the borrow lives.
        Some(unsafe { (*self.buffer[idx].get()).assume_init_ref() })
    }

    /// Removes and drops every queued item.
    ///
    /// # Safety
    ///
    /// The caller must be the ring's only consumer, and no reference returned by
    /// [`Self::peek`] may still be alive, because the items it points to are dropped.
    pub unsafe fn clear(&self) {
        while self.dequeue().is_some() {}
    }
}
impl<T> Drop for LockFreeRing<T> {
    /// Drops every item still queued; slots outside `tail..head` hold no value.
    fn drop(&mut self) {
        // Each `dequeue` moves one initialized item out, which is dropped at once.
        std::iter::from_fn(|| self.dequeue()).for_each(drop);
    }
}
// `buffer` is left out because its slots may be uninitialized, and `mask` is
// left out because it is always `capacity - 1`.
#[allow(clippy::missing_fields_in_debug)]
impl<T> std::fmt::Debug for LockFreeRing<T> {
    /// Formats capacity, current length and the raw head/tail counters.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("LockFreeRing");
        out.field("capacity", &self.capacity);
        out.field("len", &self.len());
        out.field("head", &self.head.0.load(Ordering::Relaxed));
        out.field("tail", &self.tail.0.load(Ordering::Relaxed));
        out.finish()
    }
}
/// One slot of [`MpmcRing`]: a value cell guarded by a sequence number.
struct MpmcSlot<T> {
/// Starts at the slot's index. An enqueuer claiming position `p` may write when
/// `seq == p` and then stores `p + 1`. A dequeuer claiming position `p` may read
/// when `seq == p + 1` and then stores `p + capacity`, freeing the slot for the
/// next lap.
seq: AtomicUsize,
/// The stored value; initialized only between the enqueuer's and dequeuer's `seq` stores.
data: UnsafeCell<MaybeUninit<T>>,
}
/// Bounded multi-producer/multi-consumer ring: producers and consumers claim
/// positions with compare-exchange loops, and per-slot sequence numbers order
/// access to each slot's data.
pub struct MpmcRing<T> {
/// Slot storage, `capacity` entries.
buffer: Box<[MpmcSlot<T>]>,
/// Number of slots, rounded by `next_power_of_two` in `new`.
capacity: usize,
/// `capacity - 1`; turns a position counter into a slot index.
mask: usize,
/// Next position an enqueuer will try to claim (wrapping).
head: CachePadded<AtomicUsize>,
/// Next position a dequeuer will try to claim (wrapping).
tail: CachePadded<AtomicUsize>,
}
// SAFETY: values are moved in and copied out across threads, which requires
// `T: Send`. Access to each slot's data is serialized by its `seq` counter
// (see `enqueue`/`dequeue`), and no `&T` is ever handed out.
unsafe impl<T> Send for MpmcRing<T> where T: Send {}
unsafe impl<T> Sync for MpmcRing<T> where T: Send {}
impl<T: Copy> MpmcRing<T> {
    /// Creates an empty multi-producer/multi-consumer ring with room for at
    /// least `capacity` items; `capacity` is rounded up with `next_power_of_two`.
    ///
    /// Slot `i` starts with sequence number `i`, marking it writable for the
    /// enqueuer that claims position `i`.
    ///
    /// # Panics
    ///
    /// Panics if `capacity` is 0.
    #[must_use]
    pub fn new(capacity: usize) -> Self {
        assert!(capacity > 0, "capacity must be > 0");
        let capacity = next_power_of_two(capacity);
        let mask = capacity - 1;
        let buffer: Vec<MpmcSlot<T>> = (0..capacity)
            .map(|i| MpmcSlot {
                seq: AtomicUsize::new(i),
                data: UnsafeCell::new(MaybeUninit::uninit()),
            })
            .collect();
        Self {
            buffer: buffer.into_boxed_slice(),
            capacity,
            mask,
            head: CachePadded::new(AtomicUsize::new(0)),
            tail: CachePadded::new(AtomicUsize::new(0)),
        }
    }

    /// Returns the number of slots (after rounding in [`Self::new`]).
    #[inline]
    #[must_use]
    pub const fn capacity(&self) -> usize {
        self.capacity
    }

    /// Returns an approximate number of queued items, always within `0..=capacity()`.
    ///
    /// This is `head - tail`, i.e. positions claimed by enqueuers minus positions
    /// claimed by dequeuers. It includes slots whose write or read is still in
    /// progress, and it is only a snapshot under concurrent use.
    #[inline]
    #[must_use]
    pub fn len(&self) -> usize {
        // Read `tail` before `head` so the later `head` read is normally not
        // older than it. The pair is not an atomic snapshot, so clamp: a wrapped
        // (negative) difference reads as empty and an overshoot reads as full.
        let tail = self.tail.0.load(Ordering::Acquire);
        let head = self.head.0.load(Ordering::Acquire);
        let queued = head.wrapping_sub(tail);
        if queued > usize::MAX / 2 {
            0
        } else {
            queued.min(self.capacity)
        }
    }

    /// Returns `true` if no items are queued (same snapshot caveats as [`Self::len`]).
    #[inline]
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Appends `item`; may be called concurrently from any number of threads.
    ///
    /// # Errors
    ///
    /// Returns `Err(item)`, handing the item back, if the ring is full.
    pub fn enqueue(&self, item: T) -> Result<(), T> {
        let mut head = self.head.0.load(Ordering::Relaxed);
        loop {
            let slot = &self.buffer[head & self.mask];
            let seq = slot.seq.load(Ordering::Acquire);
            #[allow(clippy::cast_possible_wrap)] // deliberate: wrapping signed distance
            let diff = (seq as isize).wrapping_sub(head as isize);
            match diff.cmp(&0) {
                // The slot is free for position `head`; try to claim it.
                std::cmp::Ordering::Equal => {
                    match self.head.0.compare_exchange_weak(
                        head,
                        head.wrapping_add(1),
                        Ordering::Relaxed,
                        Ordering::Relaxed,
                    ) {
                        Ok(_) => {
                            // SAFETY: winning the CAS gives this thread sole
                            // ownership of position `head`. `seq == head`
                            // (Acquire) means the previous lap's dequeuer has
                            // finished reading this slot. No dequeuer reads it
                            // until we store `seq = head + 1` below.
                            unsafe { (*slot.data.get()).write(item) };
                            slot.seq.store(head.wrapping_add(1), Ordering::Release);
                            return Ok(());
                        }
                        Err(current) => head = current,
                    }
                }
                // The slot still holds an item from the previous lap: ring is full.
                std::cmp::Ordering::Less => {
                    return Err(item);
                }
                // Another enqueuer already claimed this position; reload and retry.
                std::cmp::Ordering::Greater => {
                    head = self.head.0.load(Ordering::Relaxed);
                }
            }
        }
    }

    /// Removes and returns the oldest available item, or `None` if the ring is
    /// empty; may be called concurrently from any number of threads.
    pub fn dequeue(&self) -> Option<T> {
        let mut tail = self.tail.0.load(Ordering::Relaxed);
        loop {
            let slot = &self.buffer[tail & self.mask];
            let seq = slot.seq.load(Ordering::Acquire);
            #[allow(clippy::cast_possible_wrap)] // deliberate: wrapping signed distance
            let diff = (seq as isize).wrapping_sub(tail.wrapping_add(1) as isize);
            match diff.cmp(&0) {
                // The slot holds the item for position `tail`; try to claim it.
                std::cmp::Ordering::Equal => {
                    match self.tail.0.compare_exchange_weak(
                        tail,
                        tail.wrapping_add(1),
                        Ordering::Relaxed,
                        Ordering::Relaxed,
                    ) {
                        Ok(_) => {
                            // SAFETY: winning the CAS gives this thread sole
                            // ownership of position `tail`. `seq == tail + 1`
                            // (Acquire) pairs with the enqueuer's Release store
                            // made after it initialized the slot.
                            let item = unsafe { (*slot.data.get()).assume_init_read() };
                            // Hand the slot to the enqueuer of the next lap.
                            slot.seq
                                .store(tail.wrapping_add(self.capacity), Ordering::Release);
                            return Some(item);
                        }
                        Err(current) => tail = current,
                    }
                }
                // Nothing has been written at this position yet: ring is empty.
                std::cmp::Ordering::Less => {
                    return None;
                }
                // Another dequeuer already took this position; reload and retry.
                std::cmp::Ordering::Greater => {
                    tail = self.tail.0.load(Ordering::Relaxed);
                }
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::thread;

    #[test]
    fn test_spsc_basic() {
        let ring = LockFreeRing::<u32>::new(4);
        assert!(ring.is_empty());
        assert_eq!(ring.capacity(), 4);
        ring.enqueue(1).unwrap();
        ring.enqueue(2).unwrap();
        ring.enqueue(3).unwrap();
        ring.enqueue(4).unwrap();
        assert!(ring.is_full());
        assert!(ring.enqueue(5).is_err());
        assert_eq!(ring.dequeue(), Some(1));
        assert_eq!(ring.dequeue(), Some(2));
        assert_eq!(ring.dequeue(), Some(3));
        assert_eq!(ring.dequeue(), Some(4));
        assert!(ring.is_empty());
        assert_eq!(ring.dequeue(), None);
    }

    #[test]
    fn test_spsc_batch() {
        let ring = LockFreeRing::<u32>::new(8);
        let items = [1, 2, 3, 4, 5];
        let count = ring.enqueue_batch(&items);
        assert_eq!(count, 5);
        assert_eq!(ring.len(), 5);
        let mut out = [0u32; 10];
        let count = ring.dequeue_batch(&mut out);
        assert_eq!(count, 5);
        assert_eq!(&out[..5], &items);
    }

    #[test]
    fn test_spsc_batch_wraparound() {
        let ring = LockFreeRing::<u32>::new(4);
        assert_eq!(ring.free_slots(), 4);
        assert_eq!(ring.enqueue_batch(&[1, 2, 3]), 3);
        let mut out = [0u32; 2];
        assert_eq!(ring.dequeue_batch(&mut out), 2);
        assert_eq!(out, [1, 2]);
        // head = 3, tail = 2: only 3 slots are free and the batch crosses the
        // end of the buffer.
        assert_eq!(ring.free_slots(), 3);
        assert_eq!(ring.enqueue_batch(&[4, 5, 6, 7]), 3);
        assert!(ring.is_full());
        assert_eq!(ring.free_slots(), 0);
        assert_eq!(ring.enqueue_batch(&[8]), 0);
        let mut out = [0u32; 8];
        assert_eq!(ring.dequeue_batch(&mut out), 4);
        assert_eq!(&out[..4], &[3, 4, 5, 6]);
        assert!(ring.is_empty());
        assert_eq!(ring.dequeue_batch(&mut out), 0);
    }

    #[test]
    fn test_spsc_wrap() {
        let ring = LockFreeRing::<u32>::new(4);
        for round in 0..10 {
            for i in 0..4 {
                ring.enqueue(round * 4 + i).unwrap();
            }
            for i in 0..4 {
                assert_eq!(ring.dequeue(), Some(round * 4 + i));
            }
        }
    }

    #[test]
    fn test_spsc_threaded() {
        let ring = Arc::new(LockFreeRing::<u64>::new(1024));
        let ring_producer = Arc::clone(&ring);
        let ring_consumer = Arc::clone(&ring);
        let count = 100_000u64;
        let producer = thread::spawn(move || {
            for i in 0..count {
                while ring_producer.enqueue(i).is_err() {
                    std::hint::spin_loop();
                }
            }
        });
        let consumer = thread::spawn(move || {
            // The single producer pushes 0, 1, 2, ... so the consumer must see
            // exactly that sequence: no gaps, repeats or reordering.
            let mut expected = 0u64;
            while expected < count {
                if let Some(v) = ring_consumer.dequeue() {
                    assert_eq!(v, expected, "lost, duplicated or reordered item");
                    expected += 1;
                } else {
                    std::hint::spin_loop();
                }
            }
        });
        producer.join().unwrap();
        consumer.join().unwrap();
        assert!(ring.is_empty());
    }

    #[test]
    fn test_capacity_rounding() {
        let ring = LockFreeRing::<u32>::new(3);
        assert_eq!(ring.capacity(), 4);
        let ring = LockFreeRing::<u32>::new(5);
        assert_eq!(ring.capacity(), 8);
        let ring = LockFreeRing::<u32>::new(1024);
        assert_eq!(ring.capacity(), 1024);
    }

    #[test]
    fn test_peek() {
        let ring = LockFreeRing::<u32>::new(4);
        // SAFETY: this thread is the only consumer and no reference outlives
        // the statement that created it.
        unsafe {
            assert!(ring.peek().is_none());
        }
        ring.enqueue(42).unwrap();
        // SAFETY: as above; peeking twice must not consume the item.
        unsafe {
            assert_eq!(ring.peek(), Some(&42));
            assert_eq!(ring.peek(), Some(&42));
        }
        assert_eq!(ring.dequeue(), Some(42));
    }

    #[test]
    fn test_drop_releases_queued_items() {
        let tracker = Arc::new(());
        let ring = LockFreeRing::<Arc<()>>::new(4);
        ring.enqueue(Arc::clone(&tracker)).unwrap();
        ring.enqueue(Arc::clone(&tracker)).unwrap();
        assert_eq!(Arc::strong_count(&tracker), 3);
        drop(ring);
        assert_eq!(Arc::strong_count(&tracker), 1);
    }

    #[test]
    fn test_mpmc_basic() {
        let ring = MpmcRing::<u32>::new(4);
        ring.enqueue(1).unwrap();
        ring.enqueue(2).unwrap();
        assert_eq!(ring.dequeue(), Some(1));
        assert_eq!(ring.dequeue(), Some(2));
        assert_eq!(ring.dequeue(), None);
    }

    #[test]
    fn test_mpmc_full() {
        let ring = MpmcRing::<u32>::new(2);
        assert_eq!(ring.capacity(), 2);
        ring.enqueue(1).unwrap();
        ring.enqueue(2).unwrap();
        assert_eq!(ring.enqueue(3), Err(3));
        assert_eq!(ring.len(), 2);
        assert_eq!(ring.dequeue(), Some(1));
        // The freed slot is reusable on the next lap.
        ring.enqueue(3).unwrap();
        assert_eq!(ring.dequeue(), Some(2));
        assert_eq!(ring.dequeue(), Some(3));
        assert!(ring.is_empty());
        assert_eq!(ring.dequeue(), None);
    }

    #[test]
    fn test_mpmc_stress() {
        use std::sync::atomic::AtomicBool;
        const PRODUCERS: usize = 4;
        const CONSUMERS: usize = 4;
        const ITEMS_PER_PRODUCER: usize = 10_000;
        const TOTAL: usize = PRODUCERS * ITEMS_PER_PRODUCER;
        let ring = Arc::new(MpmcRing::<usize>::new(256));
        let producers_done = Arc::new(AtomicBool::new(false));
        let mut producer_handles = Vec::new();
        for p in 0..PRODUCERS {
            let ring = Arc::clone(&ring);
            producer_handles.push(thread::spawn(move || {
                let base = p * ITEMS_PER_PRODUCER;
                for i in 0..ITEMS_PER_PRODUCER {
                    while ring.enqueue(base + i).is_err() {
                        std::hint::spin_loop();
                    }
                }
            }));
        }
        let mut consumer_handles = Vec::new();
        for _ in 0..CONSUMERS {
            let ring = Arc::clone(&ring);
            let done = Arc::clone(&producers_done);
            consumer_handles.push(thread::spawn(move || {
                let mut collected = Vec::new();
                loop {
                    match ring.dequeue() {
                        Some(v) => collected.push(v),
                        None => {
                            if done.load(Ordering::Acquire) {
                                while let Some(v) = ring.dequeue() {
                                    collected.push(v);
                                }
                                break;
                            }
                            std::hint::spin_loop();
                        }
                    }
                }
                collected
            }));
        }
        for h in producer_handles {
            h.join().unwrap();
        }
        producers_done.store(true, Ordering::Release);
        let per_consumer: Vec<Vec<usize>> = consumer_handles
            .into_iter()
            .map(|h| h.join().unwrap())
            .collect();
        // A consumer claims positions in increasing order, and each producer
        // fills positions in increasing order. So within one consumer, the items
        // of any single producer must be strictly increasing.
        for collected in &per_consumer {
            let mut last_seen = [None::<usize>; PRODUCERS];
            for &v in collected {
                let producer = v / ITEMS_PER_PRODUCER;
                if let Some(prev) = last_seen[producer] {
                    assert!(v > prev, "producer {producer} reordered: {v} after {prev}");
                }
                last_seen[producer] = Some(v);
            }
        }
        let mut all: Vec<usize> = per_consumer.into_iter().flatten().collect();
        while let Some(v) = ring.dequeue() {
            all.push(v);
        }
        // Compare the exact multiset. Deduplicating before counting would hide
        // duplicated items.
        all.sort_unstable();
        assert_eq!(
            all.len(),
            TOTAL,
            "expected {TOTAL} items, got {} (duplicates or lost items)",
            all.len()
        );
        assert!(
            all.iter().copied().eq(0..TOTAL),
            "items were duplicated or lost"
        );
    }
}