#![no_std]
use core::sync::atomic::{AtomicUsize, Ordering::Relaxed};
/// Fixed-capacity single-producer/single-consumer ring buffer holding `N` elements.
///
/// `head` and `tail` are free-running indices: they are never reduced modulo
/// `N`, and the physical slot of an index is `index & (N - 1)`. The elements in
/// `[head, tail)` have been written by the producer and not yet released by the
/// consumer; everything else is free for the producer.
pub struct Buffer<T: Sized, const N: usize> {
    /// Backing storage, handed out to the two sides as disjoint raw slices.
    data: core::cell::UnsafeCell<[T; N]>,
    /// Next index the consumer will read (advanced only by the consumer side).
    head: AtomicUsize,
    /// Next index the producer will write (advanced only by the producer side).
    tail: AtomicUsize,
}
/// Write half of a [`Buffer`]: hands out regions of free slots starting at `tail`.
pub struct Producer<'a, T, const N: usize>
where
    T: Sized,
{
    buffer: &'a Buffer<T, N>,
}
/// Read half of a [`Buffer`]: hands out regions of filled slots starting at `head`.
pub struct Consumer<'a, T, const N: usize>
where
    T: Sized,
{
    buffer: &'a Buffer<T, N>,
}
/// A contiguous run of ring-buffer slots handed out by [`Producer::write`] or
/// [`Consumer::read`].
///
/// While it is alive the region mutably borrows its owner `O`, so each side can
/// have at most one region outstanding. Committing elements advances
/// `index_to_increment` (the producer's `tail` or the consumer's `head`).
pub struct Region<'b, O, T: Sized> {
    /// The slots still covered by this region.
    region: &'b mut [T],
    /// The shared index advanced when elements are committed.
    index_to_increment: &'b AtomicUsize,
    /// Exclusive borrow of the producer/consumer; prevents overlapping regions.
    _owner: &'b mut O,
}
impl<T: Sized, const N: usize> Buffer<T, N> {
    /// Creates an empty buffer (`head == tail == 0`).
    ///
    /// The storage is initialised with `core::mem::zeroed`, so the all-zero bit
    /// pattern must be a valid `T` (integers, floats, arrays of those, ...).
    /// NOTE(review): this is not enforced by the type system; element types
    /// such as references, `NonZero*` or `Box` would be undefined behaviour, and
    /// since slots are overwritten by assignment, element types with drop glue
    /// are presumably not intended either.
    ///
    /// Fails to compile when instantiated with an `N` that is not a non-zero
    /// power of two.
    pub const fn new() -> Self {
        // A power-of-two size lets indices run freely: `i & (N - 1)` maps any
        // index to its slot, and wrapping at `usize::MAX` keeps that mapping
        // consistent because `N` divides 2^usize::BITS.
        const {
            assert!(
                (N != 0) && ((N - 1) & N == 0),
                "buffer size must be a power of 2"
            )
        };
        Buffer {
            // SAFETY: requires all-zero bytes to be a valid `[T; N]`; see the
            // doc comment above.
            data: core::cell::UnsafeCell::new(unsafe { core::mem::zeroed() }),
            head: AtomicUsize::new(0),
            tail: AtomicUsize::new(0),
        }
    }

    /// Splits the buffer into its [`Producer`] and [`Consumer`] halves.
    ///
    /// Taking `&mut self` statically guarantees that no other borrow of the
    /// buffer (and hence no other producer or consumer) exists while the
    /// returned pair is alive.
    pub fn split(&mut self) -> (Producer<'_, T, N>, Consumer<'_, T, N>) {
        (Producer { buffer: self }, Consumer { buffer: self })
    }

    /// Creates a [`Producer`] from a shared reference, for when `&mut self`
    /// is not available.
    ///
    /// # Safety
    ///
    /// At most one producer for this buffer may exist at any time, including
    /// one obtained through [`Buffer::split`]. Two producers would read the same
    /// `tail` and hand out overlapping `&mut [T]` regions, which is undefined
    /// behaviour.
    pub unsafe fn producer(&self) -> Producer<'_, T, N> {
        Producer { buffer: self }
    }

    /// Creates a [`Consumer`] from a shared reference, for when `&mut self`
    /// is not available.
    ///
    /// # Safety
    ///
    /// At most one consumer for this buffer may exist at any time, including
    /// one obtained through [`Buffer::split`]. Two consumers would read the same
    /// `head` and hand out overlapping `&mut [T]` regions, which is undefined
    /// behaviour.
    pub unsafe fn consumer(&self) -> Consumer<'_, T, N> {
        Consumer { buffer: self }
    }

    /// Translates a free-running window `[start, end)` into raw-slice parts.
    ///
    /// Returns `(ptr, wrap_len, len)`:
    /// - `ptr`: pointer to `start`'s physical slot;
    /// - `wrap_len`: number of slots from there to the physical end of the array;
    /// - `len`: `min(target_len, end - start)` (wrapping subtraction).
    #[inline(always)]
    fn calc_pointers(&self, indices: [usize; 2], target_len: usize) -> (*mut T, usize, usize) {
        let [start, end] = indices;
        (
            // SAFETY: `start & (N - 1) < N`, so the offset stays within the array.
            unsafe { self.data.get().cast::<T>().add(start & (N - 1)) },
            N - (start & (N - 1)),
            core::cmp::min(target_len, end.wrapping_sub(start)),
        )
    }

    /// Returns the contiguous part of `[start, end)` that begins at `start`.
    ///
    /// The result is at most `target_len` long and never crosses the physical
    /// end of the array, so it may be shorter than the window.
    ///
    /// # Safety
    ///
    /// The caller must have exclusive access to the returned slots for the
    /// whole lifetime of the returned slice.
    #[inline(always)]
    // Returning `&mut` from `&self` is intentional: the storage sits in an
    // `UnsafeCell`, and exclusivity is the caller's obligation (see `# Safety`).
    #[allow(clippy::mut_from_ref)]
    unsafe fn slice(&self, indices: [usize; 2], target_len: usize) -> &mut [T] {
        let (start_ptr, wrap_len, len) = self.calc_pointers(indices, target_len);
        // SAFETY: the length is capped at `wrap_len`, so the slice ends at or
        // before the last array slot; exclusivity is guaranteed by the caller.
        unsafe { core::slice::from_raw_parts_mut(start_ptr, core::cmp::min(len, wrap_len)) }
    }
}
impl<T: Sized, const N: usize> Default for Buffer<T, N> {
fn default() -> Self {
Self::new()
}
}
// SAFETY: the buffer owns its `[T; N]` storage, so sending it to another thread
// sends the contained `T`s along with it; this is sound exactly when `T: Send`.
unsafe impl<T: Send, const N: usize> Send for Buffer<T, N> {}
// SAFETY: through a shared `&Buffer` the producer and consumer, possibly on
// different threads, each obtain `&mut [T]` to the same slots at different
// times (written on one side, read on the other), so `T` values are handed
// across threads. That needs `T: Send`, as for `Mutex<T>`. Exclusive access
// to each slot at a given moment is provided by the head/tail protocol.
unsafe impl<T: Send, const N: usize> Sync for Buffer<T, N> {}
impl<'a, T: Sized, const N: usize> Producer<'a, T, N> {
    /// Returns the free window `[tail, head + N)` as free-running indices.
    ///
    /// `head` is written by the consumer, so it is loaded with `Acquire`. This
    /// pairs with the consumer's `Release` when it frees slots, guaranteeing
    /// the consumer has finished with them before they are handed out for
    /// overwriting. `tail` is only ever advanced by this side, so `Relaxed`
    /// is enough for it.
    fn indices(&self) -> [usize; 2] {
        [
            self.buffer.tail.load(Relaxed),
            self.buffer
                .head
                .load(core::sync::atomic::Ordering::Acquire)
                .wrapping_add(N),
        ]
    }

    /// Borrows up to `target_len` free slots as a writable [`Region`].
    ///
    /// The region can be shorter than requested. It is limited by the free
    /// space and never wraps past the physical end of the storage, so a second
    /// call may be needed to reach the remaining free slots.
    ///
    /// Dropping the region commits every slot still in it. Use
    /// [`Region::consume`] or [`Region::partial_drop`] to commit fewer.
    pub fn write<'b>(&'b mut self, target_len: usize) -> Region<'b, Self, T> {
        Region {
            // SAFETY: `[tail, head + N)` belongs to the (single) producer, and
            // the region's `&'b mut self` borrow rules out a second overlapping
            // region on this side.
            region: unsafe { self.buffer.slice(self.indices(), target_len) },
            index_to_increment: &self.buffer.tail,
            _owner: self,
        }
    }

    /// Number of free slots, possibly split across the physical end of the storage.
    pub fn empty_size(&self) -> usize {
        let [start, end] = self.indices();
        end.wrapping_sub(start)
    }
}

impl<'a, T: Sized, const N: usize> Consumer<'a, T, N> {
    /// Returns the filled window `[head, tail)` as free-running indices.
    ///
    /// `tail` is written by the producer, so it is loaded with `Acquire`. This
    /// pairs with the producer's `Release` when it commits slots, making the
    /// element writes visible before they are read. `head` is only ever
    /// advanced by this side, so `Relaxed` is enough for it.
    fn indices(&self) -> [usize; 2] {
        [
            self.buffer.head.load(Relaxed),
            self.buffer.tail.load(core::sync::atomic::Ordering::Acquire),
        ]
    }

    /// Borrows up to `target_len` filled slots as a [`Region`].
    ///
    /// The region can be shorter than requested. It is limited by the buffered
    /// data and never wraps past the physical end of the storage.
    ///
    /// Dropping the region frees every slot still in it. Use
    /// [`Region::consume`] or [`Region::partial_drop`] to free fewer.
    pub fn read<'b>(&'b mut self, target_len: usize) -> Region<'b, Self, T> {
        Region {
            // SAFETY: `[head, tail)` belongs to the (single) consumer, and the
            // region's `&'b mut self` borrow rules out a second overlapping
            // region on this side.
            region: unsafe { self.buffer.slice(self.indices(), target_len) },
            index_to_increment: &self.buffer.head,
            _owner: self,
        }
    }

    /// Number of buffered elements, possibly split across the physical end of the storage.
    pub fn data_size(&self) -> usize {
        let [start, end] = self.indices();
        end.wrapping_sub(start)
    }

    /// Discards all currently buffered data by moving `head` up to `tail`.
    pub fn flush(&mut self) {
        let tail = self.buffer.tail.load(core::sync::atomic::Ordering::Acquire);
        // `Release` pairs with the producer's `Acquire` load of `head`.
        self.buffer
            .head
            .store(tail, core::sync::atomic::Ordering::Release);
    }
}

impl<'b, O, T: Sized> Region<'b, O, T> {
    /// Commits the first `num` elements and shrinks the region to the rest.
    ///
    /// The committed elements are handed to the other side immediately.
    ///
    /// # Panics
    ///
    /// Panics if `num` exceeds the region's current length.
    pub fn consume(&mut self, num: usize) {
        assert!(num <= self.region.len());
        // Give up our `&mut` to the committed slots *before* publishing them,
        // so this side never holds a mutable reference to memory the other
        // side may already be using.
        let region = core::mem::take(&mut self.region);
        self.region = &mut region[num..];
        // `Release` makes every access to the committed slots visible to the
        // other side's `Acquire` load of this index.
        self.index_to_increment
            .fetch_add(num, core::sync::atomic::Ordering::Release);
    }

    /// Commits only the first `num` elements and ends the region. The
    /// remaining elements are left uncommitted.
    ///
    /// # Panics
    ///
    /// Panics if `num` exceeds the region's current length.
    pub fn partial_drop(mut self, num: usize) {
        assert!(num <= self.region.len());
        // Shrink to the committed prefix; dropping `self` at the end of this
        // function then publishes exactly `num` elements.
        let region = core::mem::take(&mut self.region);
        self.region = &mut region[..num];
    }
}

impl<'b, O, T: Sized> Drop for Region<'b, O, T> {
    /// Commits every element still covered by the region.
    fn drop(&mut self) {
        // Release our `&mut` to the slots before publishing them, so the other
        // side never sees slots this side still holds a mutable reference to.
        let len = core::mem::take(&mut self.region).len();
        // `Release` pairs with the other side's `Acquire` load of this index.
        self.index_to_increment
            .fetch_add(len, core::sync::atomic::Ordering::Release);
    }
}
/// Shared access to the elements currently covered by the region.
impl<O, T: Sized> core::ops::Deref for Region<'_, O, T> {
    type Target = [T];

    fn deref(&self) -> &[T] {
        &*self.region
    }
}
/// Mutable access to the elements currently covered by the region, e.g. to
/// fill a producer region before it is committed.
impl<O, T: Sized> core::ops::DerefMut for Region<'_, O, T> {
    fn deref_mut(&mut self) -> &mut [T] {
        &mut *self.region
    }
}
/// Drives both indices across the `usize` overflow boundary and checks that the
/// reported sizes, region lengths and final index values stay consistent.
#[test]
fn index_wraparound() {
    let mut buf = Buffer::<u8, 64>::new();
    // Start 128 below zero so four rounds of 64 writes/reads cross the
    // `usize::MAX -> 0` boundary and end at 128.
    buf.head.fetch_sub(128, Relaxed);
    buf.tail.fetch_sub(128, Relaxed);
    let (mut producer, mut consumer) = buf.split();
    for _round in 0..4 {
        // Fill the whole buffer in two halves; a third write gets nothing.
        assert_eq!(producer.empty_size(), 64);
        assert_eq!(producer.write(32).len(), 32);
        assert_eq!(producer.empty_size(), 32);
        assert_eq!(producer.write(usize::MAX).len(), 32);
        assert_eq!(producer.empty_size(), 0);
        assert_eq!(producer.write(usize::MAX).len(), 0);
        // Drain it again in two halves; a third read gets nothing.
        assert_eq!(consumer.data_size(), 64);
        assert_eq!(consumer.read(32).len(), 32);
        assert_eq!(consumer.data_size(), 32);
        assert_eq!(consumer.read(usize::MAX).len(), 32);
        assert_eq!(consumer.data_size(), 0);
        assert_eq!(consumer.read(usize::MAX).len(), 0);
    }
    assert_eq!(buf.head.load(Relaxed), 128);
    assert_eq!(buf.tail.load(Relaxed), 128);
}