#![no_std]
extern crate alloc;
use alloc::sync::Arc;
use core::array;
use core::cell::UnsafeCell;
use core::mem::MaybeUninit;
use core::sync::atomic::{AtomicUsize, Ordering};
/// Fixed-capacity single-producer/single-consumer ring buffer holding up to `N` items.
///
/// A `RingBuffer` is created empty with [`RingBuffer::new`] and is turned into
/// one [`Producer`] and one [`Consumer`] handle with [`RingBuffer::split`];
/// items are only pushed and popped through those handles.
#[derive(Debug)]
pub struct RingBuffer<T, const N: usize> {
    /// Item storage. A slot holds an initialized `T` only while its position lies
    /// in the live range from `tail` (inclusive) up to `head` (exclusive).
    buffer: [UnsafeCell<MaybeUninit<T>>; N],
    /// Position one past the most recently pushed item; only the producer stores it.
    head: AtomicUsize,
    /// Position of the oldest unread item; only the consumer stores it.
    tail: AtomicUsize,
}
/// Per-handle view of a shared [`RingBuffer`].
///
/// It caches the last-seen `head` and `tail` positions so that `push`/`pop`
/// only reload the shared atomics when the cached view says the buffer looks
/// full (producer side) or empty (consumer side).
#[derive(Debug)]
struct Observer<T, const N: usize> {
    /// Shared buffer; each of the two handles holds one reference.
    rb: Arc<RingBuffer<T, N>>,
    /// Last `head` value this handle loaded or stored.
    cached_head: usize,
    /// Last `tail` value this handle loaded or stored.
    cached_tail: usize,
}
/// Writing half of a split [`RingBuffer`], obtained from [`RingBuffer::split`].
#[derive(Debug)]
pub struct Producer<T, const N: usize> {
    /// Shared buffer plus this side's cached positions.
    obs: Observer<T, N>,
}
/// Reading half of a split [`RingBuffer`], obtained from [`RingBuffer::split`].
#[derive(Debug)]
pub struct Consumer<T, const N: usize> {
    /// Shared buffer plus this side's cached positions.
    obs: Observer<T, N>,
}
impl<T: core::fmt::Display, const N: usize> core::fmt::Display for RingBuffer<T, N> {
    /// Formats the stored items oldest-first as a list such as `[1, 2, 3]`,
    /// rendering each item with its own `Display` implementation.
    ///
    /// NOTE(review): with the public API in this file a `&RingBuffer` is only
    /// reachable before `split`, when the buffer is always empty (prints `[]`).
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let head = self.head.load(Ordering::Acquire);
        let tail = self.tail.load(Ordering::Acquire);
        let mut list = f.debug_list();
        let mut pos = tail;
        while pos != head {
            // SAFETY: positions from `tail` up to `head` were written by `push`
            // (and release-published via `head`, acquire-loaded above) and not yet
            // consumed by `pop`, so this slot holds an initialized `T`.
            let item = unsafe { (*self.buffer[self.mask(pos)].get()).assume_init_ref() };
            // Format the item itself; the old code formatted the `MaybeUninit`
            // wrapper with `{:?}`, which prints only its type name.
            list.entry(&format_args!("{}", item));
            pos = self.next_index(pos);
        }
        list.finish()
    }
}
// SAFETY: the buffer owns its `T` values outright (stored in `MaybeUninit`
// slots), so moving it to another thread is sound whenever `T` itself may move.
unsafe impl<T, const N: usize> Send for RingBuffer<T, N> where T: Send {}
// SAFETY: after `split`, the two handles share the buffer but touch disjoint
// slots, ordered by the release/acquire `head`/`tail` protocol; each `T` is
// moved from the producer's thread to the consumer's, so `T: Send` is required.
unsafe impl<T, const N: usize> Sync for RingBuffer<T, N> where T: Send {}
impl<T, const N: usize> RingBuffer<T, N> {
    /// Size of the position space: `head` and `tail` cycle through `0..2N`.
    ///
    /// Having twice as many positions as slots keeps a full buffer (`head`
    /// exactly `N` ahead of `tail`) distinguishable from an empty one
    /// (`head == tail`) without sacrificing a slot. Because positions wrap back
    /// to `0` instead of growing forever, they never overflow `usize`. Previously,
    /// an overflowing counter combined with `% N` mapped two consecutive positions
    /// to the same slot whenever `N` was not a power of two.
    ///
    /// Evaluating `2 * N` is a compile-time error if it overflows. In practice
    /// that is only reachable with a zero-sized `T` and `N > usize::MAX / 2`.
    const INDEX_SPAN: usize = 2 * N;

    /// Creates an empty buffer with room for `N` items.
    ///
    /// Items can only be pushed and popped after calling [`RingBuffer::split`].
    pub fn new() -> Self {
        Self {
            buffer: array::from_fn(|_| UnsafeCell::new(MaybeUninit::uninit())),
            head: AtomicUsize::new(0),
            tail: AtomicUsize::new(0),
        }
    }

    /// Moves the buffer into shared ownership and returns its single producer
    /// and single consumer handle.
    pub fn split(self) -> (Producer<T, N>, Consumer<T, N>) {
        let shared = Arc::new(self);
        let producer = Producer::new(Observer::new(Arc::clone(&shared)));
        let consumer = Consumer::new(Observer::new(shared));
        (producer, consumer)
    }

    /// Number of items currently stored, computed from the shared `head` and `tail`.
    #[inline]
    pub fn size(&self) -> usize {
        self.distance(self.load_head(), self.load_tail())
    }

    /// Maximum number of items the buffer can hold (`N`).
    #[inline]
    pub const fn capacity(&self) -> usize {
        N
    }

    /// Maps a position in `0..2N` to its storage slot in `0..N`.
    #[inline]
    const fn mask(&self, i: usize) -> usize {
        if i >= N {
            i - N
        } else {
            i
        }
    }

    /// Returns the position after `i`, wrapping from `2N - 1` back to `0`.
    #[inline]
    const fn next_index(&self, i: usize) -> usize {
        let next = i + 1; // cannot overflow: `i < 2N <= usize::MAX`
        if next == Self::INDEX_SPAN {
            0
        } else {
            next
        }
    }

    /// Number of positions `head` is ahead of `tail`, modulo `2N`; this is the
    /// item count whenever both are valid positions of this buffer.
    #[inline]
    const fn distance(&self, head: usize, tail: usize) -> usize {
        if head >= tail {
            head - tail
        } else {
            Self::INDEX_SPAN - tail + head
        }
    }

    /// Acquire-loads `head`, making the producer's slot writes visible.
    #[inline]
    fn load_head(&self) -> usize {
        self.head.load(Ordering::Acquire)
    }

    /// Acquire-loads `tail`, making the consumer's slot reads happen-before reuse.
    #[inline]
    fn load_tail(&self) -> usize {
        self.tail.load(Ordering::Acquire)
    }

    /// Release-stores `head`, publishing the slot just written.
    #[inline]
    fn store_head(&self, value: usize) {
        self.head.store(value, Ordering::Release);
    }

    /// Release-stores `tail`, handing the slot just read back to the producer.
    #[inline]
    fn store_tail(&self, value: usize) {
        self.tail.store(value, Ordering::Release);
    }
}

impl<T, const N: usize> Default for RingBuffer<T, N> {
    /// Equivalent to [`RingBuffer::new`].
    fn default() -> Self {
        Self::new()
    }
}

/// Drops every item still stored when the buffer itself is dropped. For a split
/// buffer, that happens once both the `Producer` and the `Consumer` are gone.
/// Without this, items pushed but never popped would leak, because
/// `MaybeUninit` never runs `T`'s destructor.
impl<T, const N: usize> Drop for RingBuffer<T, N> {
    fn drop(&mut self) {
        // `&mut self` guarantees no handle is still running, so non-atomic reads suffice.
        let head = *self.head.get_mut();
        let mut pos = *self.tail.get_mut();
        while pos != head {
            let slot = self.mask(pos);
            // SAFETY: positions from `tail` up to `head` were written by `push` and
            // not yet read out by `pop`, so this slot holds an initialized `T`,
            // and each such position is visited exactly once here.
            unsafe { self.buffer[slot].get_mut().assume_init_drop() };
            pos = self.next_index(pos);
        }
    }
}

impl<T, const N: usize> Observer<T, N> {
    /// Wraps a shared buffer, seeding the cached positions from its current state.
    fn new(rb: Arc<RingBuffer<T, N>>) -> Self {
        let head = rb.load_head();
        let tail = rb.load_tail();
        Self {
            rb,
            cached_head: head,
            cached_tail: tail,
        }
    }

    /// Refreshes the cached `head` from the shared buffer.
    #[inline]
    fn sync_head(&mut self) {
        self.cached_head = self.rb.load_head();
    }

    /// Refreshes the cached `tail` from the shared buffer.
    #[inline]
    fn sync_tail(&mut self) {
        self.cached_tail = self.rb.load_tail();
    }

    /// Item count according to the cached positions, which may lag the shared state.
    #[inline]
    fn len(&self) -> usize {
        self.rb.distance(self.cached_head, self.cached_tail)
    }

    /// `true` if the cached view shows no items.
    #[inline]
    fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// `true` if the cached view shows all `N` slots occupied.
    #[inline]
    fn is_full(&self) -> bool {
        self.len() == N
    }
}
impl<T, const N: usize> Producer<T, N> {
    fn new(obs: Observer<T, N>) -> Self {
        Self { obs }
    }

    /// Appends `value` to the buffer.
    ///
    /// Returns `true` if the value was enqueued. Returns `false` if the buffer
    /// was full, in which case `value` is dropped; use [`Producer::try_push`]
    /// to get it back instead.
    pub fn push(&mut self, value: T) -> bool {
        self.try_push(value).is_ok()
    }

    /// Appends `value` to the buffer, or returns it as `Err(value)` if the
    /// buffer is full so the caller can retry without losing it.
    pub fn try_push(&mut self, value: T) -> Result<(), T> {
        // The cached tail can only lag the real one (the consumer advances it),
        // so only pay for an atomic reload when the cached view looks full.
        if self.obs.is_full() {
            self.obs.sync_tail();
            if self.obs.is_full() {
                return Err(value);
            }
        }
        let head = self.obs.cached_head;
        let slot = self.obs.rb.mask(head);
        // SAFETY: the buffer is not full, so position `head` is outside the live
        // range `tail..head`. The consumer only touches live slots, and it has
        // finished with this one before release-storing the `tail` we
        // acquire-loaded. `&mut self` rules out a concurrent push. Assigning a
        // `MaybeUninit` never drops the previous bits, so stale data from an
        // already-popped item is simply overwritten.
        unsafe {
            *self.obs.rb.buffer[slot].get() = MaybeUninit::new(value);
        }
        let next = self.obs.rb.next_index(head);
        // Release-publish the slot write so the consumer's acquire load of `head` sees it.
        self.obs.rb.store_head(next);
        self.obs.cached_head = next;
        Ok(())
    }
}
impl<T, const N: usize> Consumer<T, N> {
fn new(obs: Observer<T, N>) -> Self {
Self { obs }
}
pub fn pop(&mut self) -> Option<T> {
if self.obs.is_empty() {
self.obs.sync_head();
if self.obs.is_empty() {
return None;
}
}
let tail = self.obs.cached_tail;
let next = self.obs.rb.next_index(tail);
unsafe {
let item = (*self.obs.rb.buffer[self.obs.rb.mask(tail)].get()).assume_init_read();
*self.obs.rb.buffer[self.obs.rb.mask(tail)].get() = MaybeUninit::uninit();
self.obs.rb.store_tail(next);
self.obs.cached_tail = next;
Some(item)
}
}
}
#[cfg(test)]
mod tests {
    // The crate is `#![no_std]`, so `std` (threads, timing, `println!`) must be
    // brought into scope explicitly for the test build.
    extern crate std;

    use super::*;
    use ringbuf::StaticRb;
    use ringbuf::traits::{Consumer, Producer, Split};
    use std::println;
    use std::thread;
    use std::time::Instant;

    #[test]
    fn test_split() {
        let rb: RingBuffer<i32, 3> = RingBuffer::new();
        let (mut tx, mut rx) = rb.split();
        assert!(tx.push(1));
        assert!(tx.push(2));
        assert!(tx.push(3));
        assert_eq!(rx.pop(), Some(1));
        assert_eq!(rx.pop(), Some(2));
        assert_eq!(rx.pop(), Some(3));
        assert_eq!(rx.pop(), None);
    }

    #[test]
    fn push_fails_when_full() {
        let (mut tx, mut rx) = RingBuffer::<u8, 2>::new().split();
        assert!(tx.push(1));
        assert!(tx.push(2));
        assert!(!tx.push(3));
        assert_eq!(rx.pop(), Some(1));
        assert!(tx.push(3));
        assert_eq!(rx.pop(), Some(2));
        assert_eq!(rx.pop(), Some(3));
        assert_eq!(rx.pop(), None);
    }

    #[test]
    fn wraps_around_many_times() {
        // A non-power-of-two capacity exercises the position-to-slot mapping
        // across many laps of the buffer.
        let (mut tx, mut rx) = RingBuffer::<u32, 3>::new().split();
        for round in 0..50u32 {
            assert!(tx.push(2 * round));
            assert!(tx.push(2 * round + 1));
            assert_eq!(rx.pop(), Some(2 * round));
            assert_eq!(rx.pop(), Some(2 * round + 1));
            assert_eq!(rx.pop(), None);
        }
    }

    #[test]
    fn spsc_threaded() {
        let n = 20;
        let (mut tx, mut rx) = RingBuffer::<u64, 20>::new().split();
        let t1 = thread::spawn(move || {
            for i in 0..n {
                while !tx.push(i) {}
            }
        });
        let t2 = thread::spawn(move || {
            let mut expected = 0;
            while expected < n {
                if let Some(v) = rx.pop() {
                    assert_eq!(v, expected);
                    expected += 1;
                }
            }
        });
        t1.join().unwrap();
        t2.join().unwrap();
    }

    #[test]
    fn time_1_million_writes() {
        const ITERS: u64 = 1_000_000;
        const CAP: usize = 1_000;
        let (mut tx, mut rx) = RingBuffer::<u64, CAP>::new().split();
        let t1 = thread::spawn(move || {
            for i in 0..ITERS {
                while !tx.push(i) {}
            }
        });
        let t2 = thread::spawn(move || {
            let mut expected = 0;
            while expected < ITERS {
                if let Some(v) = rx.pop() {
                    if v != expected {
                        panic!("bad value: {v} != {expected}");
                    }
                    expected += 1;
                }
            }
        });
        t1.join().unwrap();
        t2.join().unwrap();
    }

    // Timings are only meaningful in release mode and without other tests
    // competing for cores, so this does not run by default.
    #[test]
    #[ignore = "benchmark: run with `cargo test --release -- --ignored`"]
    fn benchmark_against_ringbuf() {
        const ITERS: u64 = 5_000_000;
        const CAP: usize = 1024;

        let (mut tx, mut rx) = RingBuffer::<u64, CAP>::new().split();
        let start = Instant::now();
        let t1 = thread::spawn(move || {
            for i in 0..ITERS {
                while !tx.push(i) {}
            }
        });
        let t2 = thread::spawn(move || {
            let mut expected = 0;
            while expected < ITERS {
                if let Some(v) = rx.pop() {
                    if v != expected {
                        panic!("bad value: {v} != {expected}");
                    }
                    expected += 1;
                }
            }
        });
        t1.join().unwrap();
        t2.join().unwrap();
        let ours = start.elapsed();

        let rb = StaticRb::<u64, CAP>::default();
        let (mut tx, mut rx) = rb.split();
        let start = Instant::now();
        let t1 = thread::spawn(move || {
            for i in 0..ITERS {
                while tx.try_push(i).is_err() {}
            }
        });
        let t2 = thread::spawn(move || {
            let mut expected = 0;
            while expected < ITERS {
                if let Some(v) = rx.try_pop() {
                    if v != expected {
                        panic!("bad value: {v} != {expected}");
                    }
                    expected += 1;
                }
            }
        });
        t1.join().unwrap();
        t2.join().unwrap();
        let ringbuf = start.elapsed();

        println!();
        println!("================= SPSC BENCH =================");
        println!("iters: {ITERS}");
        println!("capacity: {CAP}");
        println!("---------------------------------------------");
        println!("Your RingBuffer : {:?}", ours);
        println!("ringbuf StaticRb: {:?}", ringbuf);
        let ratio = ringbuf.as_secs_f64() / ours.as_secs_f64();
        println!("Speed ratio (ringbuf / yours): {:.2}x", ratio);
        println!("=============================================");
    }

    #[test]
    fn test_custom_type() {
        use std::fmt::{self, Display};

        struct OptionContract {
            strike: u32,
            maturity: u32,
        }

        impl Display for OptionContract {
            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
                write!(
                    f,
                    "OptionContract(strike: {}, maturity: {})",
                    self.strike, self.maturity
                )
            }
        }

        impl OptionContract {
            fn new(strike: u32, maturity: u32) -> Self {
                Self { strike, maturity }
            }
        }

        let n_contracts = 3;
        let (mut tx, mut rx) = RingBuffer::<OptionContract, 5>::new().split();
        for i in 0..n_contracts {
            assert!(tx.push(OptionContract::new(100 + i, 2025 + i)));
        }
        for i in 0..n_contracts {
            let contract = rx.pop().unwrap();
            assert_eq!(contract.strike, 100 + i);
            assert_eq!(contract.maturity, 2025 + i);
        }
    }

    #[test]
    fn time_reads_and_writes() {
        const OPERATIONS: u32 = 100_000;
        const CAPACITY: usize = 100_000;
        let (mut tx, mut rx) = RingBuffer::<u32, CAPACITY>::new().split();

        let start = Instant::now();
        let t1 = thread::spawn(move || {
            for i in 0..OPERATIONS {
                while !tx.push(i) {}
            }
        });
        t1.join().unwrap();
        println!("Time for {OPERATIONS} writes: {:?}", start.elapsed());

        let start = Instant::now();
        let t2 = thread::spawn(move || {
            for expected in 0..OPERATIONS {
                loop {
                    if let Some(v) = rx.pop() {
                        assert_eq!(v, expected);
                        break;
                    }
                }
            }
        });
        t2.join().unwrap();
        println!("Time for {OPERATIONS} reads: {:?}", start.elapsed());
    }
}