use core::cell::UnsafeCell;
use core::mem::MaybeUninit;
/// Returns `true` when `n` is a non-zero power of two.
///
/// Ring capacities must satisfy this so that `capacity - 1` is an all-ones bit
/// mask. Delegates to the standard library's `const` implementation, which
/// already reports `false` for zero.
#[inline]
const fn is_power_of_two(n: usize) -> bool {
    n.is_power_of_two()
}
/// Fixed-size ring of slots addressed by `i64` sequence numbers.
///
/// Each slot is an `UnsafeCell<MaybeUninit<T>>`, so the buffer itself tracks
/// neither which slots are initialized nor who is accessing them; the
/// `unsafe` accessors place those obligations on the caller.
pub struct RingBuffer<T> {
    /// Always `capacity - 1`; AND-ing a sequence with it yields the slot index.
    mask: usize,
    /// Number of slots: a non-zero power of two, equal to `slots.len()`.
    capacity: usize,
    /// Backing storage, one cell per slot.
    slots: Box<[UnsafeCell<MaybeUninit<T>>]>,
}
impl<T> RingBuffer<T> {
    /// Creates a buffer of `capacity` slots, all left uninitialized.
    ///
    /// Every slot must be filled with [`write`](Self::write) before it is read
    /// through [`get`](Self::get), [`get_mut`](Self::get_mut) or
    /// [`read`](Self::read).
    ///
    /// # Panics
    ///
    /// Panics if `capacity` is zero or not a power of two.
    pub fn new(capacity: usize) -> Self {
        Self::build(capacity, MaybeUninit::uninit)
    }

    /// Creates a buffer of `capacity` slots, each pre-filled with a value
    /// returned by `factory`.
    ///
    /// `factory` is called exactly `capacity` times, in slot order, after the
    /// capacity has been validated. It may carry mutable state (`FnMut`).
    ///
    /// NOTE(review): the buffer never drops slot contents when it is itself
    /// dropped, so these values are leaked at that point — confirm intended.
    ///
    /// # Panics
    ///
    /// Panics if `capacity` is zero or not a power of two.
    pub fn with_factory<F>(capacity: usize, mut factory: F) -> Self
    where
        F: FnMut() -> T,
    {
        Self::build(capacity, || MaybeUninit::new(factory()))
    }

    /// Shared constructor: validates `capacity`, then fills each slot with the
    /// output of `make_slot` (called once per slot, in order).
    fn build(capacity: usize, mut make_slot: impl FnMut() -> MaybeUninit<T>) -> Self {
        assert!(
            is_power_of_two(capacity),
            "Ring buffer capacity must be a power of 2, got {}",
            capacity
        );
        let slots = (0..capacity)
            .map(|_| UnsafeCell::new(make_slot()))
            .collect();
        Self {
            slots,
            // `capacity` is a power of two, so this is an all-ones mask and
            // `sequence & mask` equals `sequence mod capacity`.
            mask: capacity - 1,
            capacity,
        }
    }

    /// Number of slots in the buffer (always a non-zero power of two).
    #[inline]
    pub fn capacity(&self) -> usize {
        self.capacity
    }

    /// Bit mask (`capacity - 1`) used to map sequences onto slot indices.
    #[inline]
    pub fn mask(&self) -> usize {
        self.mask
    }

    /// Maps `sequence` onto its slot index, i.e. `sequence mod capacity`.
    ///
    /// Negative sequences go through the same two's-complement cast, so the
    /// mapping stays consistent (e.g. `-1` maps to the last slot).
    #[inline(always)]
    pub fn index(&self, sequence: i64) -> usize {
        // The truncating / sign-reinterpreting cast is intentional: only the
        // low bits selected by `mask` matter.
        (sequence as usize) & self.mask
    }

    /// Returns a raw mutable pointer to the slot for `sequence`.
    ///
    /// The pointer is derived directly from the slot's `UnsafeCell` without
    /// creating an intermediate reference, so merely obtaining it never
    /// asserts exclusive access.
    ///
    /// # Safety
    ///
    /// The returned pointer is always in bounds and aligned. Any dereference
    /// must follow the raw-pointer rules: writes require that no other thread
    /// accesses the slot concurrently, and reads require the slot to be
    /// initialized.
    #[inline(always)]
    pub unsafe fn get_ptr_mut(&self, sequence: i64) -> *mut T {
        let idx = self.index(sequence);
        // SAFETY: `idx <= mask == slots.len() - 1` by construction, so the
        // index is in bounds.
        let slot = unsafe { self.slots.get_unchecked(idx) };
        // `MaybeUninit<T>` has the same layout as `T`, so this cast is valid.
        slot.get().cast::<T>()
    }

    /// Returns a raw const pointer to the slot for `sequence`.
    ///
    /// Like [`get_ptr_mut`](Self::get_ptr_mut), no intermediate reference is
    /// created.
    ///
    /// # Safety
    ///
    /// The returned pointer is always in bounds and aligned. Reading through
    /// it requires the slot to be initialized and not concurrently written.
    #[inline(always)]
    pub unsafe fn get_ptr(&self, sequence: i64) -> *const T {
        let idx = self.index(sequence);
        // SAFETY: `idx <= mask == slots.len() - 1` by construction, so the
        // index is in bounds.
        let slot = unsafe { self.slots.get_unchecked(idx) };
        // `MaybeUninit<T>` has the same layout as `T`; `*mut T` coerces to
        // `*const T` on return.
        slot.get().cast::<T>()
    }

    /// Moves `event` into the slot for `sequence`.
    ///
    /// Any value already in the slot is overwritten without being dropped.
    ///
    /// # Safety
    ///
    /// No other thread may access the slot while it is being written, and no
    /// reference obtained from [`get`](Self::get) or
    /// [`get_mut`](Self::get_mut) for this slot may be alive.
    #[inline(always)]
    pub unsafe fn write(&self, sequence: i64, event: T) {
        // SAFETY: the pointer is in bounds; the caller guarantees exclusivity.
        let ptr = unsafe { self.get_ptr_mut(sequence) };
        unsafe { ptr.write(event) };
    }

    /// Drops the value currently in the slot for `sequence`, then moves
    /// `event` in.
    ///
    /// # Safety
    ///
    /// The slot must hold an initialized value (it is dropped in place). The
    /// exclusivity requirements of [`write`](Self::write) also apply.
    #[inline(always)]
    pub unsafe fn write_and_drop(&self, sequence: i64, event: T) {
        // SAFETY: the pointer is in bounds; the caller guarantees the slot is
        // initialized and exclusively accessed.
        let ptr = unsafe { self.get_ptr_mut(sequence) };
        unsafe { core::ptr::drop_in_place(ptr) };
        unsafe { ptr.write(event) };
    }

    /// Returns a shared reference to the value in the slot for `sequence`.
    ///
    /// # Safety
    ///
    /// The slot must be initialized, and it must not be written to or moved
    /// out of (via `write`, `write_and_drop`, `get_mut` or `read`) while the
    /// returned reference is alive.
    #[inline(always)]
    pub unsafe fn get(&self, sequence: i64) -> &T {
        // SAFETY: upheld by the caller as documented above.
        unsafe { &*self.get_ptr(sequence) }
    }

    /// Returns a mutable reference to the value in the slot for `sequence`.
    ///
    /// # Safety
    ///
    /// The slot must be initialized, and the returned reference must be the
    /// only live reference to that slot, across all threads, for its whole
    /// lifetime.
    #[inline(always)]
    pub unsafe fn get_mut(&self, sequence: i64) -> &mut T {
        // SAFETY: upheld by the caller as documented above.
        unsafe { &mut *self.get_ptr_mut(sequence) }
    }

    /// Moves the value out of the slot for `sequence` with a bitwise copy,
    /// transferring ownership to the caller.
    ///
    /// # Safety
    ///
    /// The slot must be initialized and not concurrently written. Afterwards
    /// the slot must be treated as uninitialized: reading it again (or calling
    /// `write_and_drop` on it) before a fresh `write` would duplicate the
    /// value.
    #[inline(always)]
    pub unsafe fn read(&self, sequence: i64) -> T {
        // SAFETY: upheld by the caller as documented above.
        unsafe { core::ptr::read(self.get_ptr(sequence)) }
    }

    /// Computes `capacity - (producer_seq - consumer_seq)`, clamped to
    /// `0..=capacity`.
    ///
    /// Presumably this is the number of slots the producer may still claim.
    /// A consumer level with or ahead of the producer yields `capacity`. A
    /// producer more than `capacity` ahead yields `0` rather than underflowing.
    /// Extreme sequence values saturate instead of overflowing `i64`.
    #[inline]
    pub fn free_slots(&self, producer_seq: i64, consumer_seq: i64) -> usize {
        let pending = producer_seq.saturating_sub(consumer_seq);
        if pending <= 0 {
            return self.capacity;
        }
        // `pending` is positive and `usize` is at most 64 bits wide, so both
        // widening casts are lossless; the result is at most `capacity`, so
        // narrowing it back to `usize` is lossless too.
        (self.capacity as u64).saturating_sub(pending as u64) as usize
    }
}
// SAFETY: Values handed to `write`/`with_factory` are moved into the buffer's
// slots, so sending the buffer to another thread sends those values with it;
// `T: Send` makes that sound. The slots sit behind `UnsafeCell`, and every
// method that touches slot contents is an `unsafe fn` whose caller must rule
// out conflicting concurrent access to a slot.
// NOTE(review): through a shared `&RingBuffer`, `get` can hand out `&T` for
// the same slot to several threads, which normally requires `T: Sync`. These
// impls only require `T: Send` — confirm callers never read one slot from
// multiple threads at once, or tighten the `Sync` bound.
unsafe impl<T: Send> Send for RingBuffer<T> {}
unsafe impl<T: Send> Sync for RingBuffer<T> {}
/// Releases the buffer without running any slot's destructor.
///
/// Slots are `MaybeUninit<T>`, which never runs `T`'s drop code, and the
/// buffer does not record which slots are initialized, so it cannot drop them
/// safely. Values still stored when the buffer is dropped (for example those
/// created by `with_factory` or placed with `write`) are leaked. The boxed slot
/// storage itself is still freed by the field's drop glue after this runs.
// NOTE(review): confirm leaking is intended. Callers that need destructors to
// run must move each initialized value out (e.g. with `read`) before dropping
// the buffer.
impl<T> Drop for RingBuffer<T> {
    fn drop(&mut self) {
        // Intentionally empty: slot contents are not dropped (see above).
    }
}
/// Formats only the buffer's geometry (`capacity` and `mask`).
///
/// Slot contents are never printed because they may be uninitialized, so no
/// `T: Debug` bound is required. The output ends with `..` to signal omitted
/// state.
impl<T> core::fmt::Debug for RingBuffer<T> {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.debug_struct("RingBuffer")
            .field("capacity", &self.capacity)
            .field("mask", &self.mask)
            .finish_non_exhaustive()
    }
}
// Unit tests for `RingBuffer`: construction, index masking, slot access,
// ownership transfer and `free_slots` arithmetic.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_new() {
        let buffer: RingBuffer<u64> = RingBuffer::new(1024);
        assert_eq!(buffer.capacity(), 1024);
        // mask is always capacity - 1.
        assert_eq!(buffer.mask(), 1023);
    }

    #[test]
    #[should_panic(expected = "power of 2")]
    fn test_new_non_power_of_2() {
        let _: RingBuffer<u64> = RingBuffer::new(100);
    }

    #[test]
    #[should_panic(expected = "power of 2")]
    fn test_new_zero() {
        // Zero is rejected by the power-of-two check before `capacity - 1`
        // could underflow.
        let _: RingBuffer<u64> = RingBuffer::new(0);
    }

    #[test]
    fn test_with_factory() {
        // NOTE(review): the buffer does not drop slot contents on drop, so the
        // factory-built Strings here are leaked (leak checkers such as Miri
        // may report this).
        let buffer = RingBuffer::with_factory(8, || String::with_capacity(100));
        assert_eq!(buffer.capacity(), 8);
        unsafe {
            let s = buffer.get_mut(0);
            assert!(s.capacity() >= 100);
        }
    }

    #[test]
    fn test_index() {
        let buffer: RingBuffer<u64> = RingBuffer::new(8);
        assert_eq!(buffer.index(0), 0);
        assert_eq!(buffer.index(1), 1);
        assert_eq!(buffer.index(7), 7);
        // Sequences wrap modulo the capacity.
        assert_eq!(buffer.index(8), 0);
        assert_eq!(buffer.index(9), 1);
        assert_eq!(buffer.index(15), 7);
        assert_eq!(buffer.index(16), 0);
    }

    #[test]
    fn test_write_and_read() {
        let buffer: RingBuffer<u64> = RingBuffer::new(8);
        unsafe {
            buffer.write(0, 100);
            buffer.write(1, 200);
            buffer.write(7, 700);
            assert_eq!(*buffer.get(0), 100);
            assert_eq!(*buffer.get(1), 200);
            assert_eq!(*buffer.get(7), 700);
        }
    }

    #[test]
    fn test_wrap_around() {
        let buffer: RingBuffer<u64> = RingBuffer::new(4);
        unsafe {
            for i in 0..4 {
                buffer.write(i, i as u64 * 10);
            }
            for i in 0..4 {
                assert_eq!(*buffer.get(i), i as u64 * 10);
            }
            // Sequence 4 maps to slot 0 in a 4-slot buffer, so it overwrites
            // the value written for sequence 0.
            buffer.write(4, 40);
            assert_eq!(*buffer.get(4), 40);
            assert_eq!(*buffer.get(0), 40);
        }
    }

    #[test]
    fn test_get_mut_zero_copy() {
        // Mutating through `get_mut` edits the stored value in place.
        let buffer = RingBuffer::with_factory(8, String::new);
        unsafe {
            let s = buffer.get_mut(0);
            s.push_str("hello");
            s.push_str(" world");
            assert_eq!(buffer.get(0), "hello world");
        }
    }

    #[test]
    fn test_read_moves_ownership() {
        let buffer: RingBuffer<String> = RingBuffer::new(4);
        unsafe {
            buffer.write(0, String::from("test"));
            // After `read`, `owned` is the sole owner; the slot must be treated
            // as uninitialized and is not touched again in this test.
            let owned: String = buffer.read(0);
            assert_eq!(owned, "test");
        }
    }

    #[test]
    fn test_free_slots() {
        // Sequences start at -1 here, i.e. "nothing published/consumed yet".
        let buffer: RingBuffer<u64> = RingBuffer::new(8);
        assert_eq!(buffer.free_slots(-1, -1), 8);
        assert_eq!(buffer.free_slots(3, -1), 4);
        assert_eq!(buffer.free_slots(7, -1), 0);
        assert_eq!(buffer.free_slots(7, 3), 4);
        assert_eq!(buffer.free_slots(7, 6), 7);
    }

    #[test]
    fn test_large_sequences() {
        let buffer: RingBuffer<u64> = RingBuffer::new(8);
        let large_seq: i64 = 1_000_000_000;
        unsafe {
            buffer.write(large_seq, 42);
            assert_eq!(*buffer.get(large_seq), 42);
        }
        let expected_idx = (large_seq as usize) & 7;
        assert_eq!(buffer.index(large_seq), expected_idx);
    }

    #[test]
    fn test_debug() {
        let buffer: RingBuffer<u64> = RingBuffer::new(8);
        let debug = format!("{:?}", buffer);
        assert!(debug.contains("RingBuffer"));
        assert!(debug.contains("capacity: 8"));
    }
}