#![cfg_attr(not(feature = "std"), no_std)]
#![warn(rust_2018_idioms)]
#![deny(missing_docs, missing_debug_implementations)]
#![deny(unsafe_op_in_unsafe_fn)]
#![warn(clippy::undocumented_unsafe_blocks, clippy::unnecessary_safety_comment)]
extern crate alloc;
use alloc::sync::Arc;
use alloc::vec::Vec;
use core::cell::Cell;
use core::fmt;
use core::marker::PhantomData;
use core::mem::{ManuallyDrop, MaybeUninit};
use core::sync::atomic::{AtomicUsize, Ordering};
#[allow(dead_code, clippy::undocumented_unsafe_blocks)]
mod cache_padded;
use cache_padded::CachePadded;
pub mod chunks;
#[allow(unused_imports)]
use chunks::WriteChunkUninit;
/// A bounded single-producer, single-consumer (SPSC) queue.
///
/// Instances are created (wrapped in an [`Arc`]) by [`RingBuffer::new()`],
/// which hands out a [`Producer`] and a [`Consumer`] sharing one buffer.
#[derive(Debug)]
pub struct RingBuffer<T> {
    /// The read position: advanced by the consumer in `pop()` (with a
    /// `Release` store); the producer only loads it.
    head: CachePadded<AtomicUsize>,
    /// The write position: advanced by the producer in `push()` (with a
    /// `Release` store); the consumer only loads it.
    tail: CachePadded<AtomicUsize>,
    /// Pointer to an allocation of `capacity` elements, obtained from a
    /// `Vec` in `new()` and released again in the `Drop` impl.
    ///
    /// Positions stored in `head`/`tail` range over `0..2 * capacity` so
    /// that a full buffer can be distinguished from an empty one.
    data_ptr: *mut T,
    /// Number of element slots in the allocation.
    capacity: usize,
    /// Marks the type as conceptually owning values of `T`
    /// (they are dropped in the `Drop` impl).
    _marker: PhantomData<T>,
}
impl<T> RingBuffer<T> {
    /// Creates a ring buffer with the given `capacity` and returns the
    /// [`Producer`] and [`Consumer`] endpoints sharing it.
    #[allow(clippy::new_ret_no_self)]
    #[must_use]
    pub fn new(capacity: usize) -> (Producer<T>, Consumer<T>) {
        let buffer = Arc::new(RingBuffer {
            head: CachePadded::new(AtomicUsize::new(0)),
            tail: CachePadded::new(AtomicUsize::new(0)),
            // Allocate space for `capacity` elements via `Vec`, then take
            // over the raw allocation: `ManuallyDrop` stops the temporary
            // `Vec` from freeing it. The `Drop` impl below reconstructs the
            // `Vec` to release the memory.
            data_ptr: ManuallyDrop::new(Vec::with_capacity(capacity)).as_mut_ptr(),
            capacity,
            _marker: PhantomData,
        });
        let p = Producer {
            buffer: buffer.clone(),
            cached_head: Cell::new(0),
            cached_tail: Cell::new(0),
        };
        let c = Consumer {
            buffer,
            cached_head: Cell::new(0),
            cached_tail: Cell::new(0),
        };
        (p, c)
    }

    /// Returns the number of element slots in the buffer.
    pub fn capacity(&self) -> usize {
        self.capacity
    }

    /// Maps a position from the `0..2 * capacity` range to the actual
    /// `0..capacity` slot index.
    fn collapse_position(&self, pos: usize) -> usize {
        debug_assert!(pos == 0 || pos < 2 * self.capacity);
        if pos < self.capacity {
            pos
        } else {
            pos - self.capacity
        }
    }

    /// Returns a pointer to the slot at position `pos`.
    ///
    /// # Safety
    ///
    /// `pos` must be `0` or within `0..2 * capacity`.
    /// NOTE(review): when `capacity == 0` nothing was allocated, so the
    /// returned pointer must not be dereferenced.
    unsafe fn slot_ptr(&self, pos: usize) -> *mut T {
        debug_assert!(pos == 0 || pos < 2 * self.capacity);
        let pos = self.collapse_position(pos);
        // SAFETY: Given the precondition, `collapse_position()` yields a
        // value in `0..capacity` (or `0`), so the offset stays within the
        // allocation (or is a zero offset).
        unsafe { self.data_ptr.add(pos) }
    }

    /// Increments a position by `n`, wrapping at `2 * capacity`.
    fn increment(&self, pos: usize, n: usize) -> usize {
        debug_assert!(pos == 0 || pos < 2 * self.capacity);
        debug_assert!(n <= self.capacity);
        let threshold = 2 * self.capacity - n;
        if pos < threshold {
            pos + n
        } else {
            // Wrap around; equivalent to `pos + n - 2 * capacity` but
            // cannot overflow.
            pos - threshold
        }
    }

    /// Increments a position by `1`, wrapping at `2 * capacity`.
    ///
    /// Must not be called when the capacity is zero.
    fn increment1(&self, pos: usize) -> usize {
        debug_assert_ne!(self.capacity, 0);
        debug_assert!(pos < 2 * self.capacity);
        if pos < 2 * self.capacity - 1 {
            pos + 1
        } else {
            0
        }
    }

    /// Returns the number of positions from `a` to `b`, taking wrapping at
    /// `2 * capacity` into account.
    fn distance(&self, a: usize, b: usize) -> usize {
        debug_assert!(a == 0 || a < 2 * self.capacity);
        debug_assert!(b == 0 || b < 2 * self.capacity);
        if a <= b {
            b - a
        } else {
            2 * self.capacity - a + b
        }
    }
}
impl<T> Drop for RingBuffer<T> {
    fn drop(&mut self) {
        // This runs only after both endpoints released their `Arc`, so no
        // other thread can access the buffer anymore and `Relaxed` loads
        // are sufficient.
        let mut head = self.head.load(Ordering::Relaxed);
        let tail = self.tail.load(Ordering::Relaxed);
        // Drop all elements still queued between `head` and `tail` ...
        while head != tail {
            // SAFETY: Every slot between `head` and `tail` holds an
            // initialized element, and we have exclusive access.
            unsafe { self.slot_ptr(head).drop_in_place() };
            head = self.increment1(head);
        }
        // ... then free the allocation by re-creating the `Vec` it came
        // from in `new()`, with length 0 so no element is dropped twice.
        // SAFETY: `data_ptr` and `capacity` originate from
        // `Vec::with_capacity(capacity)` in `new()`.
        unsafe { Vec::from_raw_parts(self.data_ptr, 0, self.capacity) };
    }
}
impl<T> PartialEq for RingBuffer<T> {
    /// Two ring buffers are considered equal only if they are the very same
    /// instance, so equality is decided by comparing addresses, not contents.
    fn eq(&self, other: &Self) -> bool {
        let lhs: *const Self = self;
        let rhs: *const Self = other;
        lhs == rhs
    }
}

// Identity comparison is trivially reflexive, symmetric and transitive.
impl<T> Eq for RingBuffer<T> {}
/// The write end of the queue; created by [`RingBuffer::new()`].
///
/// Pushing requires `&mut self` and the type is not `Clone`, so there is
/// only ever a single producer.
#[derive(Debug, PartialEq, Eq)]
pub struct Producer<T> {
    /// The ring buffer shared with the consumer.
    buffer: Arc<RingBuffer<T>>,
    /// Last-seen copy of `buffer.head`; possibly stale, refreshed from the
    /// atomic only when the queue looks full (see `next_tail()`).
    cached_head: Cell<usize>,
    /// The producer's own write position; kept in sync with `buffer.tail`
    /// (the producer is the only one storing to it).
    cached_tail: Cell<usize>,
}

// SAFETY: Moving a `Producer<T>` to another thread effectively sends values
// of `T` (they are written into the shared buffer and read on the consumer
// side), so `T: Send` is required.
unsafe impl<T: Send> Send for Producer<T> {}
impl<T> Producer<T> {
    /// Attempts to push `value` into the queue.
    ///
    /// # Errors
    ///
    /// If the queue is full, the value is given back in [`PushError::Full`].
    pub fn push(&mut self, value: T) -> Result<(), PushError<T>> {
        if let Some(tail) = self.next_tail() {
            // SAFETY: `next_tail()` only returns a position whose slot is
            // currently free, and only the producer writes at the tail.
            unsafe { self.buffer.slot_ptr(tail).write(value) };
            let tail = self.buffer.increment1(tail);
            // `Release` makes the slot write above visible to the consumer
            // before it observes the new tail.
            self.buffer.tail.store(tail, Ordering::Release);
            self.cached_tail.set(tail);
            Ok(())
        } else {
            Err(PushError::Full(value))
        }
    }

    /// Returns the number of currently free slots.
    ///
    /// Loads the head position from the shared atomic, so the result is an
    /// up-to-date lower bound (the consumer may free more slots at any time).
    pub fn slots(&self) -> usize {
        let head = self.buffer.head.load(Ordering::Acquire);
        self.cached_head.set(head);
        self.buffer.capacity - self.buffer.distance(head, self.cached_tail.get())
    }

    /// Like [`Producer::slots()`], but uses only the cached positions and
    /// performs no atomic load; the result may underestimate the free slots.
    pub fn cached_slots(&self) -> usize {
        let head = self.cached_head.get();
        let tail = self.cached_tail.get();
        self.buffer.capacity - self.buffer.distance(head, tail)
    }

    /// Returns `true` if there is currently no free slot.
    pub fn is_full(&self) -> bool {
        self.next_tail().is_none()
    }

    /// Returns `true` if the corresponding [`Consumer`] has been dropped:
    /// only the two endpoints hold strong references to the buffer, so a
    /// count below 2 means the other side is gone.
    pub fn is_abandoned(&self) -> bool {
        Arc::strong_count(&self.buffer) < 2
    }

    /// Returns a reference to the underlying ring buffer.
    pub fn buffer(&self) -> &RingBuffer<T> {
        &self.buffer
    }

    /// Returns the current tail position if a slot is free for writing,
    /// or `None` if the queue is full.
    ///
    /// Checks the cached head first and only falls back to an atomic load
    /// of `buffer.head` when the queue *looks* full.
    fn next_tail(&self) -> Option<usize> {
        let tail = self.cached_tail.get();
        // Check if the queue is *possibly* full based on cached data.
        if self.buffer.distance(self.cached_head.get(), tail) == self.buffer.capacity {
            // Refresh the cached head and re-check.
            let head = self.buffer.head.load(Ordering::Acquire);
            self.cached_head.set(head);
            if self.buffer.distance(head, tail) == self.buffer.capacity {
                return None;
            }
        }
        Some(tail)
    }
}
/// The read end of the queue; created by [`RingBuffer::new()`].
///
/// Popping requires `&mut self` and the type is not `Clone`, so there is
/// only ever a single consumer.
#[derive(Debug, PartialEq, Eq)]
pub struct Consumer<T> {
    /// The ring buffer shared with the producer.
    buffer: Arc<RingBuffer<T>>,
    /// The consumer's own read position; kept in sync with `buffer.head`
    /// (the consumer is the only one storing to it).
    cached_head: Cell<usize>,
    /// Last-seen copy of `buffer.tail`; possibly stale, refreshed from the
    /// atomic only when the queue looks empty (see `next_head()`).
    cached_tail: Cell<usize>,
}

// SAFETY: Moving a `Consumer<T>` to another thread effectively receives
// values of `T` that were written on the producer side, so `T: Send` is
// required.
unsafe impl<T: Send> Send for Consumer<T> {}
impl<T> Consumer<T> {
    /// Attempts to pop the next element from the queue.
    ///
    /// # Errors
    ///
    /// Returns [`PopError::Empty`] if the queue holds no elements.
    pub fn pop(&mut self) -> Result<T, PopError> {
        if let Some(head) = self.next_head() {
            // SAFETY: `next_head()` only returns positions whose slot was
            // published by the producer, so it holds an initialized element;
            // `read()` moves it out.
            let value = unsafe { self.buffer.slot_ptr(head).read() };
            let head = self.buffer.increment1(head);
            // `Release` publishes the now-free slot to the producer.
            self.buffer.head.store(head, Ordering::Release);
            self.cached_head.set(head);
            Ok(value)
        } else {
            Err(PopError::Empty)
        }
    }

    /// Returns a reference to the element that would be popped next,
    /// without removing it.
    ///
    /// # Errors
    ///
    /// Returns [`PeekError::Empty`] if the queue holds no elements.
    pub fn peek(&self) -> Result<&T, PeekError> {
        if let Some(head) = self.next_head() {
            // SAFETY: See `pop()`; the slot is initialized and the producer
            // does not overwrite it until `head` has been advanced.
            Ok(unsafe { &*self.buffer.slot_ptr(head) })
        } else {
            Err(PeekError::Empty)
        }
    }

    /// Returns the number of elements currently available for popping.
    ///
    /// Loads the tail position from the shared atomic, so the result is an
    /// up-to-date lower bound (the producer may add more at any time).
    pub fn slots(&self) -> usize {
        let tail = self.buffer.tail.load(Ordering::Acquire);
        self.cached_tail.set(tail);
        self.buffer.distance(self.cached_head.get(), tail)
    }

    /// Like [`Consumer::slots()`], but uses only the cached positions and
    /// performs no atomic load; the result may underestimate the available
    /// elements.
    pub fn cached_slots(&self) -> usize {
        let head = self.cached_head.get();
        let tail = self.cached_tail.get();
        self.buffer.distance(head, tail)
    }

    /// Returns `true` if there is currently no element to pop.
    pub fn is_empty(&self) -> bool {
        self.next_head().is_none()
    }

    /// Returns `true` if the corresponding [`Producer`] has been dropped:
    /// only the two endpoints hold strong references to the buffer, so a
    /// count below 2 means the other side is gone.
    pub fn is_abandoned(&self) -> bool {
        Arc::strong_count(&self.buffer) < 2
    }

    /// Returns a reference to the underlying ring buffer.
    pub fn buffer(&self) -> &RingBuffer<T> {
        &self.buffer
    }

    /// Returns the current head position if an element is available,
    /// or `None` if the queue is empty.
    ///
    /// Checks the cached tail first and only falls back to an atomic load
    /// of `buffer.tail` when the queue *looks* empty.
    fn next_head(&self) -> Option<usize> {
        let head = self.cached_head.get();
        // Check if the queue is *possibly* empty based on cached data.
        if head == self.cached_tail.get() {
            // Refresh the cached tail and re-check.
            let tail = self.buffer.tail.load(Ordering::Acquire);
            self.cached_tail.set(tail);
            if head == tail {
                return None;
            }
        }
        Some(head)
    }
}
/// Extension trait for copying a slice of `Copy` elements into
/// possibly-uninitialized memory.
pub trait CopyToUninit<T: Copy> {
    /// Copies all elements of `self` into `dst` and returns the
    /// now-initialized destination as a regular slice.
    ///
    /// # Panics
    ///
    /// Panics if the two slices have different lengths.
    fn copy_to_uninit<'a>(&self, dst: &'a mut [MaybeUninit<T>]) -> &'a mut [T];
}

impl<T: Copy> CopyToUninit<T> for [T] {
    fn copy_to_uninit<'a>(&self, dst: &'a mut [MaybeUninit<T>]) -> &'a mut [T] {
        assert_eq!(
            self.len(),
            dst.len(),
            "source slice length does not match destination slice length"
        );
        let len = self.len();
        let dst_ptr: *mut T = dst.as_mut_ptr().cast();
        // SAFETY: `MaybeUninit<T>` has the same layout as `T`, `dst_ptr`
        // refers to `len` writable elements (length checked above), the two
        // slices are separate borrows so the ranges cannot overlap, and
        // after the copy all `len` destination elements are initialized.
        unsafe {
            core::ptr::copy_nonoverlapping(self.as_ptr(), dst_ptr, len);
            core::slice::from_raw_parts_mut(dst_ptr, len)
        }
    }
}
/// Error type returned by [`Consumer::pop()`].
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum PopError {
    /// The queue contained no elements.
    Empty,
}

#[cfg(feature = "std")]
impl std::error::Error for PopError {}

impl fmt::Display for PopError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Single variant, so an irrefutable `let` replaces the `match`.
        let PopError::Empty = self;
        // `pad` honors width/alignment flags, just like `str`'s `Display`.
        f.pad("empty ring buffer")
    }
}
/// Error type returned by [`Consumer::peek()`].
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum PeekError {
    /// The queue contained no elements.
    Empty,
}

#[cfg(feature = "std")]
impl std::error::Error for PeekError {}

impl fmt::Display for PeekError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Single variant, so an irrefutable `let` replaces the `match`.
        let PeekError::Empty = self;
        // `pad` honors width/alignment flags, just like `str`'s `Display`.
        f.pad("empty ring buffer")
    }
}
#[derive(Copy, Clone, PartialEq, Eq)]
pub enum PushError<T> {
Full(T),
}
#[cfg(feature = "std")]
impl<T> std::error::Error for PushError<T> {}
impl<T> fmt::Debug for PushError<T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
PushError::Full(_) => f.pad("Full(_)"),
}
}
}
impl<T> fmt::Display for PushError<T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
PushError::Full(_) => "full ring buffer".fmt(f),
}
}
}