use super::core::RingBufCore;
use crate::shim::atomic::Ordering;
use alloc::vec::Vec;
use core::fmt;
/// Pause hint used by every spin-wait loop in this module (loom build).
///
/// Yields to the loom scheduler instead of spinning, presumably so the model
/// checker can run the thread being waited on.
#[cfg(feature = "loom")]
fn backoff() {
    loom::thread::yield_now()
}

/// Pause hint used by every spin-wait loop in this module: signals the CPU
/// that the caller is busy-waiting.
#[cfg(not(feature = "loom"))]
fn backoff() {
    core::hint::spin_loop()
}
/// An atomic cell type that can occupy a slot of an [`AtomicRingBuf`].
///
/// `Primitive` is the plain value the atomic wraps (for example `u64` for
/// `AtomicU64`). The three operations mirror the inherent methods of the same
/// name on the atomic types this module implements the trait for.
pub trait AtomicElement: Send + Sync {
    /// Plain, copyable value type stored inside the atomic.
    type Primitive: Copy;

    /// Atomically reads the current value with the given memory ordering.
    fn load(&self, order: Ordering) -> Self::Primitive;

    /// Atomically overwrites the current value with `val`.
    fn store(&self, val: Self::Primitive, order: Ordering);

    /// Atomically replaces the current value with `val`, returning the old one.
    fn swap(&self, val: Self::Primitive, order: Ordering) -> Self::Primitive;
}
use crate::shim::atomic::{
AtomicBool, AtomicI8, AtomicI16, AtomicI32, AtomicI64, AtomicIsize, AtomicU8, AtomicU16,
AtomicU32, AtomicU64, AtomicUsize,
};
/// Implements [`AtomicElement`] for an atomic type by forwarding to that
/// type's inherent `load` / `store` / `swap`.
///
/// The calls use the `<$atomic>::method(self, ..)` form so it is explicit that
/// they resolve to the inherent methods (which take priority over trait
/// methods) rather than recursing into the trait impl being defined.
macro_rules! impl_atomic_element {
    ($atomic:ty, $primitive:ty) => {
        impl AtomicElement for $atomic {
            type Primitive = $primitive;

            #[inline(always)]
            fn load(&self, order: Ordering) -> Self::Primitive {
                <$atomic>::load(self, order)
            }

            #[inline(always)]
            fn store(&self, val: Self::Primitive, order: Ordering) {
                <$atomic>::store(self, val, order)
            }

            #[inline(always)]
            fn swap(&self, val: Self::Primitive, order: Ordering) -> Self::Primitive {
                <$atomic>::swap(self, val, order)
            }
        }
    };
}

// Every atomic type re-exported by the shim, paired with its primitive.
impl_atomic_element!(AtomicU8, u8);
impl_atomic_element!(AtomicU16, u16);
impl_atomic_element!(AtomicU32, u32);
impl_atomic_element!(AtomicU64, u64);
impl_atomic_element!(AtomicUsize, usize);
impl_atomic_element!(AtomicI8, i8);
impl_atomic_element!(AtomicI16, i16);
impl_atomic_element!(AtomicI32, i32);
impl_atomic_element!(AtomicI64, i64);
impl_atomic_element!(AtomicIsize, isize);
impl_atomic_element!(AtomicBool, bool);
/// Chooses the behaviour of [`AtomicRingBuf::push`] at compile time from the
/// buffer's `OVERWRITE` parameter.
///
/// `PushMarker<true>` overwrites the oldest element when the buffer is full and
/// returns `Option<T::Primitive>`; `PushMarker<false>` rejects the value when
/// full and returns `Result<(), T::Primitive>`.
pub trait PushDispatch<T: AtomicElement, const N: usize, const OVERWRITE: bool> {
    /// What `push` returns in this mode.
    type PushOutput;

    /// Pushes `value` into `ringbuf`; `order` is the memory ordering used for
    /// the slot write itself.
    fn push_impl(
        ringbuf: &AtomicRingBuf<T, N, OVERWRITE>,
        value: T::Primitive,
        order: Ordering,
    ) -> Self::PushOutput;
}

/// Zero-sized, type-level tag carrying the `OVERWRITE` flag (see [`PushDispatch`]).
pub struct PushMarker<const OVERWRITE: bool>;
impl<T: AtomicElement, const N: usize> PushDispatch<T, N, true> for PushMarker<true> {
    type PushOutput = Option<T::Primitive>;

    /// Overwrite-mode push; never rejects a value.
    ///
    /// 1. Reserve a sequence number with `fetch_add` on the write index.
    /// 2. If the buffer looked full, try once to move the read index past the
    ///    oldest element (a failed CAS is ignored), and `swap` the slot so the
    ///    evicted value can be returned. Otherwise just `store` into the slot.
    /// 3. Spin until every earlier reservation is committed, then publish this
    ///    one by advancing `write_commit`.
    ///
    /// Returns `Some(evicted)` when an element was overwritten, else `None`.
    #[inline]
    fn push_impl(
        ringbuf: &AtomicRingBuf<T, N, true>,
        value: T::Primitive,
        order: Ordering,
    ) -> Self::PushOutput {
        let inner = &ringbuf.core;
        let seq = inner.write_idx().fetch_add(1, Ordering::Relaxed);
        let head = inner.read_idx().load(Ordering::Acquire);
        let slot_index = seq & inner.mask();

        let evicted = if seq.wrapping_sub(head) >= inner.capacity() {
            // Best effort: if the CAS loses, someone else (a consumer or another
            // producer) already moved the read index.
            let _ = inner.read_idx().compare_exchange(
                head,
                head.wrapping_add(1),
                Ordering::Release,
                Ordering::Relaxed,
            );
            // SAFETY: `slot_index` has been reduced with the core's mask.
            // NOTE(review): relies on `peek_at`'s precondition (presumably an
            // in-range, initialised slot).
            Some(unsafe { inner.peek_at(slot_index) }.swap(value, order))
        } else {
            // SAFETY: as above — `slot_index` is masked into the slot range.
            let slot = unsafe { inner.peek_at(slot_index) };
            slot.store(value, order);
            None
        };

        // Commits are published strictly in reservation order.
        while ringbuf.write_commit.load(Ordering::Acquire) != seq {
            backoff();
        }
        ringbuf
            .write_commit
            .store(seq.wrapping_add(1), Ordering::Release);
        evicted
    }
}
impl<T: AtomicElement, const N: usize> PushDispatch<T, N, false> for PushMarker<false> {
    type PushOutput = Result<(), T::Primitive>;

    /// Reject-when-full push.
    ///
    /// Claims a position by CAS on the write index, retrying (with backoff)
    /// when another writer wins the race. If the buffer is full at the time of
    /// the check, the value is handed back as `Err(value)`. After writing the
    /// slot it spins until all earlier reservations are committed, then
    /// publishes its own by advancing `write_commit`.
    #[inline]
    fn push_impl(
        ringbuf: &AtomicRingBuf<T, N, false>,
        value: T::Primitive,
        order: Ordering,
    ) -> Self::PushOutput {
        let inner = &ringbuf.core;

        let seq = loop {
            let head = inner.read_idx().load(Ordering::Acquire);
            let tail = inner.write_idx().load(Ordering::Relaxed);
            if tail.wrapping_sub(head) >= inner.capacity() {
                return Err(value);
            }
            let claimed = inner.write_idx().compare_exchange(
                tail,
                tail.wrapping_add(1),
                Ordering::Relaxed,
                Ordering::Relaxed,
            );
            if claimed.is_ok() {
                break tail;
            }
            backoff();
        };

        // SAFETY: `seq & mask` is masked into the slot range, and the fullness
        // check above means the slot's previous occupant was already consumed.
        // NOTE(review): relies on `peek_at`'s precondition (presumably an
        // in-range, initialised slot).
        let slot = unsafe { inner.peek_at(seq & inner.mask()) };
        slot.store(value, order);

        // Commits are published strictly in reservation order.
        while ringbuf.write_commit.load(Ordering::Acquire) != seq {
            backoff();
        }
        ringbuf
            .write_commit
            .store(seq.wrapping_add(1), Ordering::Release);
        Ok(())
    }
}
/// Fixed-capacity ring buffer whose slots are atomic cells.
///
/// Writers reserve a position through the core's write index, fill the slot,
/// and then publish positions in reservation order by advancing
/// `write_commit`. Readers only consider positions below `write_commit`.
///
/// `OVERWRITE` (default `true`) selects what `push` does when the buffer is
/// full: overwrite the oldest element (`true`) or reject the value (`false`).
/// `N` is forwarded to `RingBufCore`. NOTE(review): its meaning is defined
/// there; it is not the capacity (the tests use `N = 32` with capacities 2–8).
pub struct AtomicRingBuf<T, const N: usize, const OVERWRITE: bool = true>
where
    T: AtomicElement,
{
    /// Slot storage together with the read and write indices.
    core: RingBufCore<T, N>,
    /// One past the last position whose slot write is visible to readers.
    write_commit: AtomicUsize,
}
impl<T: AtomicElement, const N: usize, const OVERWRITE: bool> AtomicRingBuf<T, N, OVERWRITE> {
#[inline]
pub fn new(capacity: usize) -> Self
where
T: Default,
{
let uninit = Self::new_uninit(capacity);
unsafe {
for i in 0..uninit.core.capacity() {
uninit.core.write_at(i, T::default());
}
}
uninit
}
#[inline]
pub fn new_uninit(capacity: usize) -> Self {
Self {
core: RingBufCore::new(capacity),
write_commit: AtomicUsize::new(0),
}
}
#[inline]
pub fn capacity(&self) -> usize {
self.core.capacity()
}
#[inline]
pub fn len(&self) -> usize {
let commit = self.write_commit.load(Ordering::Acquire);
let read = self.core.read_idx().load(Ordering::Acquire);
commit.wrapping_sub(read).min(self.core.capacity())
}
#[inline]
pub fn is_empty(&self) -> bool {
let commit = self.write_commit.load(Ordering::Acquire);
let read = self.core.read_idx().load(Ordering::Acquire);
commit == read
}
#[inline]
pub fn is_full(&self) -> bool {
self.core.is_full()
}
#[inline(always)]
pub fn push(
&self,
value: T::Primitive,
order: Ordering,
) -> <PushMarker<OVERWRITE> as PushDispatch<T, N, OVERWRITE>>::PushOutput
where
PushMarker<OVERWRITE>: PushDispatch<T, N, OVERWRITE>,
{
PushMarker::<OVERWRITE>::push_impl(self, value, order)
}
#[inline]
pub fn pop(&self, order: Ordering) -> Option<T::Primitive> {
let read = self.core.read_idx().load(Ordering::Relaxed);
let commit = self.write_commit.load(Ordering::Acquire);
if read == commit {
return None;
}
let index = read & self.core.mask();
let value = unsafe {
let slot = self.core.peek_at(index);
slot.load(order)
};
self.core
.read_idx()
.store(read.wrapping_add(1), Ordering::Release);
Some(value)
}
#[inline]
pub fn peek(&self, order: Ordering) -> Option<T::Primitive> {
let read = self.core.read_idx().load(Ordering::Acquire);
let commit = self.write_commit.load(Ordering::Acquire);
if read == commit {
return None;
}
let index = read & self.core.mask();
unsafe {
let slot = self.core.peek_at(index);
Some(slot.load(order))
}
}
#[inline]
pub unsafe fn get_unchecked(&self, offset: usize) -> &T {
let read = self.core.read_idx().load(Ordering::Acquire);
let index = read.wrapping_add(offset) & self.core.mask();
unsafe { self.core.peek_at(index) }
}
#[inline]
pub fn clear(&self) {
let commit = self.write_commit.load(Ordering::Acquire);
self.core.read_idx().store(commit, Ordering::Release);
}
#[inline]
pub fn read_all(&self, order: Ordering) -> Vec<T::Primitive> {
let read = self.core.read_idx().load(Ordering::Acquire);
let commit = self.write_commit.load(Ordering::Acquire);
let len = commit.wrapping_sub(read).min(self.core.capacity());
let mut values = Vec::with_capacity(len);
for i in 0..len {
let index = read.wrapping_add(i) & self.core.mask();
let value = unsafe {
let slot = self.core.peek_at(index);
slot.load(order)
};
values.push(value);
}
values
}
#[inline]
pub fn iter(&self) -> AtomicIter<'_, T, N, OVERWRITE> {
let read = self.core.read_idx().load(Ordering::Acquire);
let commit = self.write_commit.load(Ordering::Acquire);
let len = commit.wrapping_sub(read).min(self.core.capacity());
AtomicIter {
ringbuf: self,
start: read,
remaining: len,
}
}
}
/// Iterator over references to the slots of an [`AtomicRingBuf`], created by
/// [`AtomicRingBuf::iter`].
///
/// The span of positions is captured at creation. Later pushes and pops do not
/// change how many items are yielded.
pub struct AtomicIter<'a, T: AtomicElement, const N: usize, const OVERWRITE: bool> {
    ringbuf: &'a AtomicRingBuf<T, N, OVERWRITE>,
    /// Absolute (unmasked) position of the next slot to yield.
    start: usize,
    /// Number of slots still to yield.
    remaining: usize,
}

impl<'a, T: AtomicElement, const N: usize, const OVERWRITE: bool> Iterator
    for AtomicIter<'a, T, N, OVERWRITE>
{
    type Item = &'a T;

    #[inline]
    fn next(&mut self) -> Option<Self::Item> {
        if self.remaining == 0 {
            return None;
        }
        let index = self.start & self.ringbuf.core.mask();
        // SAFETY: `index` is masked into the slot range, and the position was
        // inside the committed span when the iterator was created.
        let element = unsafe { self.ringbuf.core.peek_at(index) };
        self.start = self.start.wrapping_add(1);
        self.remaining -= 1;
        Some(element)
    }

    #[inline]
    fn size_hint(&self) -> (usize, Option<usize>) {
        (self.remaining, Some(self.remaining))
    }
}

impl<T: AtomicElement, const N: usize, const OVERWRITE: bool> ExactSizeIterator
    for AtomicIter<'_, T, N, OVERWRITE>
{
    #[inline]
    fn len(&self) -> usize {
        self.remaining
    }
}

// Once `remaining` reaches zero, `next` keeps returning `None`, so the iterator
// is fused and `.fuse()` costs nothing.
impl<T: AtomicElement, const N: usize, const OVERWRITE: bool> core::iter::FusedIterator
    for AtomicIter<'_, T, N, OVERWRITE>
{
}
/// Atomic element types that also support arithmetic read-modify-write
/// operations.
pub trait AtomicNumeric: AtomicElement {
    /// Atomically adds `val`, returning the value held before the addition.
    fn fetch_add(&self, val: Self::Primitive, order: Ordering) -> Self::Primitive;

    /// Atomically subtracts `val`, returning the value held before the
    /// subtraction.
    fn fetch_sub(&self, val: Self::Primitive, order: Ordering) -> Self::Primitive;
}
macro_rules! impl_atomic_numeric {
($atomic:ty, $primitive:ty) => {
impl AtomicNumeric for $atomic {
#[inline]
fn fetch_add(&self, val: Self::Primitive, order: Ordering) -> Self::Primitive {
self.fetch_add(val, order)
}
#[inline]
fn fetch_sub(&self, val: Self::Primitive, order: Ordering) -> Self::Primitive {
self.fetch_sub(val, order)
}
}
};
}
impl_atomic_numeric!(AtomicU8, u8);
impl_atomic_numeric!(AtomicU16, u16);
impl_atomic_numeric!(AtomicU32, u32);
impl_atomic_numeric!(AtomicU64, u64);
impl_atomic_numeric!(AtomicUsize, usize);
impl_atomic_numeric!(AtomicI8, i8);
impl_atomic_numeric!(AtomicI16, i16);
impl_atomic_numeric!(AtomicI32, i32);
impl_atomic_numeric!(AtomicI64, i64);
impl_atomic_numeric!(AtomicIsize, isize);
impl<T: AtomicNumeric, const N: usize, const OVERWRITE: bool> AtomicRingBuf<T, N, OVERWRITE> {
    /// Atomically adds `val` to the element `offset` positions after the read
    /// index and returns that element's previous value.
    ///
    /// # Safety
    ///
    /// Same contract as [`get_unchecked`](Self::get_unchecked). `offset` is not
    /// checked against [`len`](Self::len), so an offset `>= len()` modifies a
    /// slot that does not hold a live element. The addressed slot must satisfy
    /// `RingBufCore::peek_at`'s precondition, at minimum that it is initialised.
    #[inline]
    pub unsafe fn fetch_add_at(
        &self,
        offset: usize,
        val: T::Primitive,
        order: Ordering,
    ) -> T::Primitive {
        // SAFETY: forwarded to the caller by this function's contract.
        let element = unsafe { self.get_unchecked(offset) };
        element.fetch_add(val, order)
    }

    /// Atomically subtracts `val` from the element `offset` positions after the
    /// read index and returns that element's previous value.
    ///
    /// # Safety
    ///
    /// Same contract as [`get_unchecked`](Self::get_unchecked). `offset` is not
    /// checked against [`len`](Self::len), so an offset `>= len()` modifies a
    /// slot that does not hold a live element. The addressed slot must satisfy
    /// `RingBufCore::peek_at`'s precondition, at minimum that it is initialised.
    #[inline]
    pub unsafe fn fetch_sub_at(
        &self,
        offset: usize,
        val: T::Primitive,
        order: Ordering,
    ) -> T::Primitive {
        // SAFETY: forwarded to the caller by this function's contract.
        let element = unsafe { self.get_unchecked(offset) };
        element.fetch_sub(val, order)
    }
}
impl<T: AtomicElement, const N: usize, const OVERWRITE: bool> fmt::Debug
    for AtomicRingBuf<T, N, OVERWRITE>
{
    /// Summarises occupancy (not contents).
    ///
    /// `len` and `is_empty` come from the buffer's own methods, which count
    /// committed elements via `write_commit`, so the output matches what
    /// callers see through the public API. The values are loaded separately,
    /// so under concurrent use they form only a snapshot.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("AtomicRingBuf")
            .field("capacity", &self.capacity())
            .field("len", &self.len())
            .field("is_empty", &self.is_empty())
            .field("is_full", &self.is_full())
            .field("overwrite_mode", &OVERWRITE)
            .finish()
    }
}
// SAFETY: `T: AtomicElement` requires `Send + Sync`, the commit counter is an
// `AtomicUsize`, and this module touches the read/write indices only through
// atomic operations. NOTE(review): soundness also depends on `RingBufCore`
// (not visible here) exposing its slots only as shared `&T` references after
// construction. Confirm before relying on these impls.
unsafe impl<T, const N: usize, const OVERWRITE: bool> Send for AtomicRingBuf<T, N, OVERWRITE>
where
    T: AtomicElement,
{
}

// SAFETY: see the `Send` impl above; the same reasoning applies to shared
// access from multiple threads.
unsafe impl<T, const N: usize, const OVERWRITE: bool> Sync for AtomicRingBuf<T, N, OVERWRITE>
where
    T: AtomicElement,
{
}
#[cfg(all(test, not(feature = "loom")))]
mod tests {
    // Single-threaded behavioural tests; concurrency is exercised separately
    // under the `loom` feature.
    use super::*;
    use std::sync::atomic::{AtomicBool, AtomicU32, AtomicU64};
    #[test]
    fn test_basic_push_pop() {
        let buf: AtomicRingBuf<AtomicU64, 32> = AtomicRingBuf::new(4);
        buf.push(1, Ordering::Relaxed);
        buf.push(2, Ordering::Relaxed);
        buf.push(3, Ordering::Relaxed);
        assert_eq!(buf.len(), 3);
        assert_eq!(buf.pop(Ordering::Relaxed), Some(1));
        assert_eq!(buf.pop(Ordering::Relaxed), Some(2));
        assert_eq!(buf.pop(Ordering::Relaxed), Some(3));
        assert_eq!(buf.pop(Ordering::Relaxed), None);
    }
    #[test]
    fn test_basic_overwrite_mode() {
        let buf: AtomicRingBuf<AtomicU64, 32, true> = AtomicRingBuf::new(2);
        assert_eq!(buf.push(1, Ordering::Relaxed), None);
        assert_eq!(buf.push(2, Ordering::Relaxed), None);
        assert_eq!(buf.push(3, Ordering::Relaxed), Some(1));
        assert_eq!(buf.len(), 2);
        assert_eq!(buf.pop(Ordering::Relaxed), Some(2));
        assert_eq!(buf.pop(Ordering::Relaxed), Some(3));
    }
    #[test]
    fn test_basic_non_overwrite_mode() {
        let buf: AtomicRingBuf<AtomicU64, 32, false> = AtomicRingBuf::new(2);
        assert_eq!(buf.push(1, Ordering::Relaxed), Ok(()));
        assert_eq!(buf.push(2, Ordering::Relaxed), Ok(()));
        assert_eq!(buf.push(3, Ordering::Relaxed), Err(3));
        assert_eq!(buf.len(), 2);
        assert!(buf.is_full());
    }
    #[test]
    fn test_basic_peek() {
        let buf: AtomicRingBuf<AtomicU64, 32> = AtomicRingBuf::new(4);
        assert_eq!(buf.peek(Ordering::Relaxed), None);
        buf.push(42, Ordering::Relaxed);
        assert_eq!(buf.peek(Ordering::Relaxed), Some(42));
        assert_eq!(buf.len(), 1);
        buf.push(99, Ordering::Relaxed);
        assert_eq!(buf.peek(Ordering::Relaxed), Some(42));
        buf.pop(Ordering::Relaxed);
        assert_eq!(buf.peek(Ordering::Relaxed), Some(99));
    }
    #[test]
    fn test_basic_clear() {
        let buf: AtomicRingBuf<AtomicU64, 32> = AtomicRingBuf::new(8);
        buf.push(1, Ordering::Relaxed);
        buf.push(2, Ordering::Relaxed);
        buf.push(3, Ordering::Relaxed);
        assert_eq!(buf.len(), 3);
        buf.clear();
        assert_eq!(buf.len(), 0);
        assert!(buf.is_empty());
        buf.push(10, Ordering::Relaxed);
        assert_eq!(buf.pop(Ordering::Relaxed), Some(10));
    }
    #[test]
    fn test_basic_capacity() {
        let buf: AtomicRingBuf<AtomicU64, 32> = AtomicRingBuf::new(8);
        assert_eq!(buf.capacity(), 8);
        let buf: AtomicRingBuf<AtomicU64, 32> = AtomicRingBuf::new(5);
        assert_eq!(buf.capacity(), 8);
    }
    #[test]
    fn test_basic_is_empty_full() {
        let buf: AtomicRingBuf<AtomicU64, 32> = AtomicRingBuf::new(2);
        assert!(buf.is_empty());
        assert!(!buf.is_full());
        buf.push(1, Ordering::Relaxed);
        assert!(!buf.is_empty());
        assert!(!buf.is_full());
        buf.push(2, Ordering::Relaxed);
        assert!(!buf.is_empty());
        assert!(buf.is_full());
        buf.pop(Ordering::Relaxed);
        assert!(!buf.is_empty());
        assert!(!buf.is_full());
        buf.pop(Ordering::Relaxed);
        assert!(buf.is_empty());
        assert!(!buf.is_full());
    }
    #[test]
    fn test_basic_read_all() {
        let buf: AtomicRingBuf<AtomicU64, 32> = AtomicRingBuf::new(8);
        buf.push(1, Ordering::Relaxed);
        buf.push(2, Ordering::Relaxed);
        buf.push(3, Ordering::Relaxed);
        let values = buf.read_all(Ordering::Acquire);
        assert_eq!(values, vec![1, 2, 3]);
        assert_eq!(buf.len(), 3);
    }
    #[test]
    fn test_basic_iter() {
        let buf: AtomicRingBuf<AtomicU64, 32> = AtomicRingBuf::new(8);
        buf.push(10, Ordering::Relaxed);
        buf.push(20, Ordering::Relaxed);
        buf.push(30, Ordering::Relaxed);
        let values: Vec<u64> = buf
            .iter()
            .map(|atom| atom.load(Ordering::Acquire))
            .collect();
        assert_eq!(values, vec![10, 20, 30]);
    }
    #[test]
    fn test_basic_atomic_u32() {
        let buf: AtomicRingBuf<AtomicU32, 32> = AtomicRingBuf::new(4);
        buf.push(100u32, Ordering::Relaxed);
        buf.push(200u32, Ordering::Relaxed);
        assert_eq!(buf.pop(Ordering::Relaxed), Some(100u32));
        assert_eq!(buf.pop(Ordering::Relaxed), Some(200u32));
    }
    #[test]
    fn test_basic_atomic_bool() {
        let buf: AtomicRingBuf<AtomicBool, 32> = AtomicRingBuf::new(4);
        buf.push(true, Ordering::Relaxed);
        buf.push(false, Ordering::Relaxed);
        buf.push(true, Ordering::Relaxed);
        assert_eq!(buf.pop(Ordering::Relaxed), Some(true));
        assert_eq!(buf.pop(Ordering::Relaxed), Some(false));
        assert_eq!(buf.pop(Ordering::Relaxed), Some(true));
    }
    #[test]
    fn test_basic_fetch_add_at() {
        let buf: AtomicRingBuf<AtomicU64, 32> = AtomicRingBuf::new(8);
        buf.push(10, Ordering::Relaxed);
        buf.push(20, Ordering::Relaxed);
        buf.push(30, Ordering::Relaxed);
        let old = unsafe { buf.fetch_add_at(0, 5, Ordering::Relaxed) };
        assert_eq!(old, 10);
        assert_eq!(buf.peek(Ordering::Acquire).unwrap(), 15);
        let old = unsafe { buf.fetch_add_at(1, 100, Ordering::Relaxed) };
        assert_eq!(old, 20);
    }
    #[test]
    fn test_basic_fetch_sub_at() {
        let buf: AtomicRingBuf<AtomicU64, 32> = AtomicRingBuf::new(8);
        buf.push(100, Ordering::Relaxed);
        buf.push(200, Ordering::Relaxed);
        let old = unsafe { buf.fetch_sub_at(0, 10, Ordering::Relaxed) };
        assert_eq!(old, 100);
        assert_eq!(buf.peek(Ordering::Acquire).unwrap(), 90);
    }
    // Indices keep growing past the capacity; slots must be reused via the mask.
    #[test]
    fn test_non_overwrite_wraparound() {
        let buf: AtomicRingBuf<AtomicU64, 32, false> = AtomicRingBuf::new(2);
        assert_eq!(buf.push(1, Ordering::Relaxed), Ok(()));
        assert_eq!(buf.push(2, Ordering::Relaxed), Ok(()));
        assert_eq!(buf.pop(Ordering::Relaxed), Some(1));
        assert_eq!(buf.push(3, Ordering::Relaxed), Ok(()));
        assert_eq!(buf.read_all(Ordering::Relaxed), vec![2, 3]);
        assert_eq!(buf.pop(Ordering::Relaxed), Some(2));
        assert_eq!(buf.pop(Ordering::Relaxed), Some(3));
        assert_eq!(buf.pop(Ordering::Relaxed), None);
    }
    // Each push into a full overwrite buffer evicts exactly the oldest element.
    #[test]
    fn test_overwrite_repeated_eviction() {
        let buf: AtomicRingBuf<AtomicU64, 32, true> = AtomicRingBuf::new(2);
        assert_eq!(buf.push(1, Ordering::Relaxed), None);
        assert_eq!(buf.push(2, Ordering::Relaxed), None);
        assert_eq!(buf.push(3, Ordering::Relaxed), Some(1));
        assert_eq!(buf.push(4, Ordering::Relaxed), Some(2));
        assert_eq!(buf.push(5, Ordering::Relaxed), Some(3));
        assert_eq!(buf.len(), 2);
        assert_eq!(buf.read_all(Ordering::Relaxed), vec![4, 5]);
    }
    #[test]
    fn test_iter_exact_size() {
        let buf: AtomicRingBuf<AtomicU64, 32> = AtomicRingBuf::new(4);
        assert_eq!(buf.iter().len(), 0);
        buf.push(7, Ordering::Relaxed);
        buf.push(8, Ordering::Relaxed);
        let mut it = buf.iter();
        assert_eq!(it.len(), 2);
        assert!(it.next().is_some());
        assert_eq!(it.len(), 1);
        assert!(it.next().is_some());
        assert_eq!(it.len(), 0);
        assert!(it.next().is_none());
    }
}