use std::{
cell::UnsafeCell,
mem::{ManuallyDrop, MaybeUninit},
sync::atomic::{AtomicUsize, Ordering::*},
};
/// A `usize` with every bit set, i.e. numerically equal to `usize::MAX`.
///
/// NOTE(review): not referenced in this file; the name suggests it is a
/// "nothing selected" sentinel for callers elsewhere in the crate — confirm
/// at the use sites.
pub(crate) const UNSELECTED: usize = !0;
/// One storage cell of `LockFreeBoundedRing`.
///
/// `sequence` holds a stamp (see `vacant_stamp` / `filled_stamp`) telling
/// producers and consumers which logical position the slot is ready for.
/// `value` holds a live `T` from the moment a producer publishes a filled
/// stamp until the consumer that claims that position reads it out.
struct RingSlot<T> {
    sequence: AtomicUsize,
    value: UnsafeCell<MaybeUninit<T>>,
}

// SAFETY: a slot's `value` is only touched by the one thread that won the
// `tail` (producer) or `head` (consumer) CAS for the slot's current position,
// and ownership of the `T` is handed over through the Release store / Acquire
// load pair on `sequence`. Values are moved between threads, never shared,
// so `T: Send` is sufficient for both impls.
unsafe impl<T: Send> Send for RingSlot<T> {}
unsafe impl<T: Send> Sync for RingSlot<T> {}

/// Stamp carried by a slot that is free for the producer claiming logical
/// position `pos`.
///
/// Vacant stamps are even (`2 * pos`) and filled stamps are odd
/// (`2 * pos + 1`), so "filled for `pos`" can never be confused with
/// "vacant for `pos + cap`". With a plain `pos` / `pos + 1` encoding those two
/// coincide when `cap == 1`, which would let a push into a full single-slot
/// ring overwrite the stored value.
fn vacant_stamp(pos: usize) -> usize {
    pos.wrapping_mul(2)
}

/// Stamp carried by a slot once the value for logical position `pos` has been
/// written and may be popped.
fn filled_stamp(pos: usize) -> usize {
    vacant_stamp(pos).wrapping_add(1)
}

/// Signed distance `actual - expected` between two stamps.
///
/// Uses wrapping arithmetic so the result stays correct, and never trips a
/// debug-build overflow panic, when stamps cross `isize::MAX` or wrap past
/// `usize::MAX`. This holds as long as the two stamps are less than
/// `isize::MAX` apart.
fn stamp_distance(actual: usize, expected: usize) -> isize {
    actual.wrapping_sub(expected) as isize
}

/// Fixed-capacity, lock-free multi-producer / multi-consumer FIFO queue.
///
/// Each slot carries a sequence stamp, in the style of Dmitry Vyukov's
/// bounded MPMC queue. Producers claim a position by CAS-ing `tail` and
/// consumers by CAS-ing `head`; the slot's stamp tells each side whether that
/// position is ready for it. `head` and `tail` are ever-increasing logical
/// positions, and the physical slot is `pos % cap`.
///
/// A capacity of `0` is accepted: such a ring is permanently both empty and
/// full.
///
/// NOTE(review): for a capacity that is not a power of two, `pos % cap` is
/// not continuous when a position wraps past `usize::MAX`. Reaching that takes
/// 2^64 operations (2^32 on 32-bit targets).
pub(crate) struct LockFreeBoundedRing<T> {
    slots: Box<[RingSlot<T>]>,
    cap: usize,
    /// Next logical position a producer will try to claim.
    tail: AtomicUsize,
    /// Next logical position a consumer will try to claim.
    head: AtomicUsize,
}

impl<T> LockFreeBoundedRing<T> {
    /// Creates an empty ring that holds at most `cap` items.
    pub(crate) fn new(cap: usize) -> Self {
        let slots: Box<[RingSlot<T>]> = (0..cap)
            .map(|i| RingSlot {
                // Slot `i` starts out free for the producer of position `i`.
                sequence: AtomicUsize::new(vacant_stamp(i)),
                value: UnsafeCell::new(MaybeUninit::uninit()),
            })
            .collect();
        LockFreeBoundedRing {
            slots,
            cap,
            tail: AtomicUsize::new(0),
            head: AtomicUsize::new(0),
        }
    }

    /// Appends `value` at the back of the ring without blocking.
    ///
    /// # Errors
    ///
    /// Returns `Err(value)`, handing the value back untouched, when the ring
    /// is full or has capacity `0`.
    pub(crate) fn try_push(&self, value: T) -> Result<(), T> {
        if self.cap == 0 {
            return Err(value);
        }
        // The value is moved into the slot by a bitwise copy. ManuallyDrop
        // keeps this local from also dropping it afterwards.
        let value = ManuallyDrop::new(value);
        loop {
            let pos = self.tail.load(Relaxed);
            let slot = &self.slots[pos % self.cap];
            let seq = slot.sequence.load(Acquire);
            let diff = stamp_distance(seq, vacant_stamp(pos));
            if diff == 0 {
                // The slot is free for `pos`; try to claim that position.
                if self
                    .tail
                    .compare_exchange_weak(pos, pos.wrapping_add(1), Relaxed, Relaxed)
                    .is_ok()
                {
                    // SAFETY: winning the CAS gives this thread exclusive write
                    // access to the slot until the stamp below is published,
                    // and the vacant stamp means the slot holds no live value.
                    unsafe {
                        (*slot.value.get())
                            .as_mut_ptr()
                            .copy_from_nonoverlapping(&*value as *const T, 1);
                    }
                    // Release: makes the written value visible to the consumer
                    // that observes the filled stamp with Acquire.
                    slot.sequence.store(filled_stamp(pos), Release);
                    return Ok(());
                }
            } else if diff < 0 {
                // The slot has not been recycled from the previous lap yet,
                // so the ring is full.
                return Err(ManuallyDrop::into_inner(value));
            }
            // Otherwise another producer already took `pos`, or the weak CAS
            // failed spuriously: reload `tail` and retry.
        }
    }

    /// Removes and returns the item at the front of the ring without
    /// blocking.
    ///
    /// Returns `None` if no item is ready, or if the capacity is `0`.
    pub(crate) fn try_pop(&self) -> Option<T> {
        if self.cap == 0 {
            return None;
        }
        loop {
            let pos = self.head.load(Relaxed);
            let slot = &self.slots[pos % self.cap];
            let seq = slot.sequence.load(Acquire);
            let diff = stamp_distance(seq, filled_stamp(pos));
            if diff == 0 {
                if self
                    .head
                    .compare_exchange_weak(pos, pos.wrapping_add(1), Relaxed, Relaxed)
                    .is_ok()
                {
                    // SAFETY: the filled stamp, observed with Acquire, proves
                    // the producer's write is complete and visible. Winning
                    // the CAS makes this thread the slot's only reader.
                    let value = unsafe { (*slot.value.get()).assume_init_read() };
                    // Release: hands the now-empty slot to the producer one
                    // lap ahead.
                    slot.sequence
                        .store(vacant_stamp(pos.wrapping_add(self.cap)), Release);
                    return Some(value);
                }
            } else if diff < 0 {
                // The value for `pos` has not been published yet, so the ring
                // is (momentarily) empty.
                return None;
            }
            // Otherwise another consumer already took `pos`, or the weak CAS
            // failed spuriously: reload `head` and retry.
        }
    }

    /// Returns `true` if no item is ready to be popped at the front.
    ///
    /// Under concurrent use the answer is only a snapshot.
    pub(crate) fn is_empty(&self) -> bool {
        if self.cap == 0 {
            return true;
        }
        let pos = self.head.load(Acquire);
        self.slots[pos % self.cap].sequence.load(Acquire) != filled_stamp(pos)
    }

    /// Returns `true` if a `try_push` issued now would find the ring full.
    ///
    /// This applies `try_push`'s own "full" test to the slot at `tail`,
    /// rather than comparing `tail` with `head`. Two separate loads of those
    /// counters can observe `head` ahead of `tail` under concurrency, and the
    /// wrapped difference would then look like a full ring. Under concurrent
    /// use the answer is still only a snapshot.
    pub(crate) fn is_full(&self) -> bool {
        if self.cap == 0 {
            return true;
        }
        let pos = self.tail.load(Acquire);
        let seq = self.slots[pos % self.cap].sequence.load(Acquire);
        stamp_distance(seq, vacant_stamp(pos)) < 0
    }
}

impl<T> Drop for LockFreeBoundedRing<T> {
    /// Drops every item that was pushed but never popped.
    fn drop(&mut self) {
        if self.cap == 0 || !std::mem::needs_drop::<T>() {
            return;
        }
        let head = *self.head.get_mut();
        // Wrapping subtraction keeps the count right even if positions have
        // wrapped around `usize::MAX`.
        let len = self.tail.get_mut().wrapping_sub(head);
        for offset in 0..len {
            let pos = head.wrapping_add(offset);
            let slot = &mut self.slots[pos % self.cap];
            if *slot.sequence.get_mut() == filled_stamp(pos) {
                // SAFETY: `&mut self` rules out concurrent access, and a filled
                // stamp for `pos` means the value was written and never read.
                unsafe { slot.value.get_mut().assume_init_drop() };
            }
        }
    }
}
#[cfg(test)]
mod tests {
    // Unit tests for `LockFreeBoundedRing`: single-threaded semantics,
    // drop accounting, and a small multi-producer/multi-consumer stress run.
    use super::LockFreeBoundedRing;

    #[test]
    fn ring_basic_push_pop() {
        let ring = LockFreeBoundedRing::new(4);
        assert!(ring.is_empty());
        assert!(!ring.is_full());
        ring.try_push(1u32).unwrap();
        ring.try_push(2).unwrap();
        assert!(!ring.is_empty());
        assert!(!ring.is_full());
        assert_eq!(ring.try_pop(), Some(1));
        assert_eq!(ring.try_pop(), Some(2));
        assert_eq!(ring.try_pop(), None);
        assert!(ring.is_empty());
    }

    #[test]
    fn ring_zero_capacity() {
        let ring = LockFreeBoundedRing::<i32>::new(0);
        assert!(ring.is_empty());
        assert!(ring.is_full());
        assert!(ring.try_push(1).is_err());
        assert_eq!(ring.try_pop(), None);
    }

    #[test]
    fn ring_full_rejects_push() {
        let ring = LockFreeBoundedRing::new(2);
        ring.try_push(10u32).unwrap();
        ring.try_push(20).unwrap();
        assert!(ring.is_full());
        assert!(ring.try_push(30).is_err());
        assert_eq!(ring.try_pop(), Some(10));
        assert!(!ring.is_full());
        ring.try_push(30).unwrap();
    }

    #[test]
    fn ring_fifo_ordering() {
        let ring = LockFreeBoundedRing::new(8);
        for i in 0u32..8 {
            ring.try_push(i).unwrap();
        }
        for i in 0u32..8 {
            assert_eq!(ring.try_pop(), Some(i));
        }
    }

    #[test]
    fn ring_wrap_around() {
        let ring = LockFreeBoundedRing::new(4);
        for i in 0u32..4 {
            ring.try_push(i).unwrap();
        }
        for i in 0u32..4 {
            assert_eq!(ring.try_pop(), Some(i));
        }
        for i in 4u32..8 {
            ring.try_push(i).unwrap();
        }
        for i in 4u32..8 {
            assert_eq!(ring.try_pop(), Some(i));
        }
        assert!(ring.is_empty());
    }

    #[test]
    fn ring_is_full_capacity_one() {
        let ring = LockFreeBoundedRing::new(1);
        assert!(!ring.is_full());
        ring.try_push(42u32).unwrap();
        assert!(ring.is_full());
        assert_eq!(ring.try_pop(), Some(42));
        assert!(!ring.is_full());
        ring.try_push(99).unwrap();
        assert!(ring.is_full());
    }

    #[test]
    fn ring_drop_runs_for_buffered_items() {
        use std::sync::Arc;
        use std::sync::atomic::{AtomicUsize, Ordering};
        let counter = Arc::new(AtomicUsize::new(0));
        #[derive(Debug)]
        struct Guard(Arc<AtomicUsize>);
        impl Drop for Guard {
            fn drop(&mut self) {
                self.0.fetch_add(1, Ordering::Relaxed);
            }
        }
        {
            let ring = LockFreeBoundedRing::new(4);
            ring.try_push(Guard(Arc::clone(&counter))).unwrap();
            ring.try_push(Guard(Arc::clone(&counter))).unwrap();
        }
        assert_eq!(counter.load(Ordering::Relaxed), 2);
    }

    /// Items already popped must not be dropped again when the ring itself is
    /// dropped, even when the live items straddle the physical end of the
    /// slot array.
    #[test]
    fn ring_drop_skips_popped_items_across_wrap() {
        use std::sync::Arc;
        use std::sync::atomic::{AtomicUsize, Ordering};
        #[derive(Debug)]
        struct Guard(Arc<AtomicUsize>);
        impl Drop for Guard {
            fn drop(&mut self) {
                self.0.fetch_add(1, Ordering::Relaxed);
            }
        }
        let counter = Arc::new(AtomicUsize::new(0));
        {
            let ring = LockFreeBoundedRing::new(2);
            ring.try_push(Guard(Arc::clone(&counter))).unwrap();
            ring.try_push(Guard(Arc::clone(&counter))).unwrap();
            drop(ring.try_pop());
            assert_eq!(counter.load(Ordering::Relaxed), 1);
            ring.try_push(Guard(Arc::clone(&counter))).unwrap();
        }
        assert_eq!(counter.load(Ordering::Relaxed), 3);
    }

    /// Several producers and consumers sharing a small ring must deliver
    /// every item exactly once.
    #[test]
    fn ring_mpmc_delivers_every_item_once() {
        use std::sync::atomic::{AtomicUsize, Ordering};
        use std::thread;

        const PRODUCERS: usize = 4;
        const PER_PRODUCER: usize = 1_000;
        let total = PRODUCERS * PER_PRODUCER;

        let ring = LockFreeBoundedRing::new(16);
        let received = AtomicUsize::new(0);
        let sum = AtomicUsize::new(0);

        thread::scope(|s| {
            for p in 0..PRODUCERS {
                let ring = &ring;
                s.spawn(move || {
                    for i in 0..PER_PRODUCER {
                        let mut item = p * PER_PRODUCER + i;
                        // Retry until a consumer frees up room.
                        while let Err(back) = ring.try_push(item) {
                            item = back;
                            thread::yield_now();
                        }
                    }
                });
            }
            for _ in 0..PRODUCERS {
                s.spawn(|| {
                    while received.load(Ordering::Relaxed) < total {
                        match ring.try_pop() {
                            Some(v) => {
                                sum.fetch_add(v, Ordering::Relaxed);
                                received.fetch_add(1, Ordering::Relaxed);
                            }
                            None => thread::yield_now(),
                        }
                    }
                });
            }
        });

        assert_eq!(received.load(Ordering::Relaxed), total);
        assert_eq!(sum.load(Ordering::Relaxed), total * (total - 1) / 2);
        assert!(ring.is_empty());
    }
}
// Debug-level logging hook.
//
// With the `debug-logs` feature enabled, the tokens are forwarded unchanged
// to `log::debug!`. Without it, the invocation expands to `()` and the tokens
// are dropped, so none of the arguments are evaluated.
#[cfg(not(feature = "debug-logs"))]
macro_rules! log_debug {
    [$($tokens:tt)*] => { () };
}
#[cfg(feature = "debug-logs")]
macro_rules! log_debug {
    [$($tokens:tt)*] => { log::debug!($($tokens)*) };
}
// Trace-level logging hook.
//
// With the `debug-logs` feature enabled, the tokens are forwarded unchanged
// to `log::trace!`. Without it, the invocation expands to `()` and the tokens
// are dropped, so none of the arguments are evaluated.
#[cfg(not(feature = "debug-logs"))]
macro_rules! log_trace {
    [$($tokens:tt)*] => { () };
}
#[cfg(feature = "debug-logs")]
macro_rules! log_trace {
    [$($tokens:tt)*] => { log::trace!($($tokens)*) };
}