#![allow(dead_code)]
use core::cell::UnsafeCell;
use core::marker::PhantomData;
use core::mem::{align_of, size_of, MaybeUninit};
use core::sync::atomic::{AtomicUsize, Ordering};
/// Bounded multi-producer/multi-consumer ring buffer with `S` slots, synchronised
/// purely through atomic counters (no mutex).
///
/// One slot is always kept free, so at most `S - 1` elements are stored at once
/// (`capacity()`). `new` rejects any `S` that is not a power of two of at least 2,
/// which lets a counter be mapped to its slot with `counter & (S - 1)`.
///
/// The four counters run freely and wrap on overflow (`wrapping_add`); only their
/// differences and low bits are meaningful.
///
/// `#[repr(C)]` is load-bearing: `MPMCRef` locates `ring` through the hand-computed
/// `ring_offset()`, which assumes exactly this field order.
#[repr(C)]
pub struct MPMC<const S: usize, T> {
// Producer reservation counter: the next position a `push` claims via CAS.
prod_head: AtomicUsize,
// Producer publication counter: positions below it are written and visible to `pop`.
prod_tail: AtomicUsize,
// Consumer reservation counter: the next position a `pop` claims via CAS.
cons_head: AtomicUsize,
// Consumer release counter: positions below it have been read and may be reused by `push`.
cons_tail: AtomicUsize,
// Number of slots (always `S`); stored at runtime so the size-erased `MPMCRef` can read it.
size: usize,
// Slot storage; a slot holds a live `T` only between its publication and its consumption.
ring: [UnsafeCell<MaybeUninit<T>>; S],
}
/// Borrowed, size-erased handle to an `MPMC<S, T>` of any slot count `S`.
///
/// The pointer is stored re-typed as `MPMC<0, T>`; the real slot count is read from the
/// queue's `size` field at runtime, so one set of methods serves every `S`.
/// `'a` bounds how long the handle may be used: `MPMC::as_ptr` ties it to a borrow of
/// the queue, while `from_ptr` leaves upholding it to the caller.
pub struct MPMCRef<'a, T>(*const MPMC<0, T>, PhantomData<&'a T>);
impl<const S: usize, T> MPMC<S, T> {
    /// Creates an empty queue with `S` slots, `S - 1` of which are usable.
    ///
    /// # Panics
    ///
    /// Panics unless `S` is a power of two that is at least 2. When called from a
    /// `const`/`static` initializer this surfaces as a compile-time error.
    pub const fn new() -> Self {
        let size_ok = S >= 2 && S.is_power_of_two();
        if !size_ok {
            panic!("size must be a power of 2 (and be greater or equal 2)");
        }
        // SAFETY: every element is `UnsafeCell<MaybeUninit<T>>`, for which an
        // uninitialised bit pattern is a valid value.
        #[allow(clippy::uninit_assumed_init)]
        let ring: [UnsafeCell<MaybeUninit<T>>; S] = unsafe { MaybeUninit::uninit().assume_init() };
        Self {
            prod_head: AtomicUsize::new(0),
            prod_tail: AtomicUsize::new(0),
            cons_head: AtomicUsize::new(0),
            cons_tail: AtomicUsize::new(0),
            size: S,
            ring,
        }
    }

    /// Borrows the queue as a size-erased [`MPMCRef`] tied to this borrow.
    #[inline]
    pub fn as_ptr(&self) -> MPMCRef<'_, T> {
        let this: *const Self = self;
        // SAFETY: `this` comes from a live reference to a queue built by `new`, and the
        // returned handle's lifetime is bound to that same borrow.
        unsafe { MPMCRef::from_ptr(this) }
    }

    /// Maximum number of elements held at once: one less than the slot count `S`.
    #[inline]
    pub fn capacity(&self) -> usize {
        S - 1
    }

    /// Snapshot of the number of published, not-yet-consumed elements.
    #[inline]
    pub fn len(&self) -> usize {
        let view = self.as_ptr();
        view.len()
    }

    /// `true` when the snapshot length is zero.
    #[inline]
    pub fn is_empty(&self) -> bool {
        let view = self.as_ptr();
        view.is_empty()
    }

    /// `true` when the snapshot length equals the capacity.
    #[inline]
    pub fn is_full(&self) -> bool {
        let view = self.as_ptr();
        view.is_full()
    }

    /// Enqueues `val`; hands it back as `Err(val)` when the queue is full.
    #[inline]
    pub fn push(&self, val: T) -> Result<&UnsafeCell<T>, T> {
        let view = self.as_ptr();
        view.push(val)
    }

    /// Dequeues the oldest published element, if any.
    #[inline]
    pub fn pop(&self) -> Option<T> {
        let view = self.as_ptr();
        view.pop()
    }

    /// Returns the cell of the oldest unreleased element without removing it.
    #[inline]
    pub fn peek(&self) -> Option<&UnsafeCell<T>> {
        let view = self.as_ptr();
        view.peek()
    }

    /// Byte offset of the `ring` field inside the `#[repr(C)]` struct.
    ///
    /// The header is four `AtomicUsize` counters followed by `size: usize`; the ring then
    /// starts at the next multiple of the element alignment. An array's alignment equals
    /// its element's, so the result does not depend on `S`.
    const fn ring_offset() -> usize {
        let header = 4 * size_of::<AtomicUsize>() + size_of::<usize>();
        let mask = align_of::<[UnsafeCell<MaybeUninit<T>>; 2]>() - 1;
        (header + mask) & !mask
    }
}
impl<'a, T> MPMCRef<'a, T> {
pub unsafe fn from_ptr<const S: usize>(ptr: *const MPMC<S, T>) -> MPMCRef<'a, T> {
MPMCRef(ptr as *const _ as *const MPMC<0, T>, PhantomData)
}
#[inline]
fn ring(&self) -> &[UnsafeCell<MaybeUninit<T>>] {
let ring_ptr = unsafe { self.0.cast::<u8>().add(MPMC::<0, T>::ring_offset()) }
as *const UnsafeCell<MaybeUninit<T>>;
unsafe { core::slice::from_raw_parts(ring_ptr, self.size) }
}
#[inline]
pub fn capacity(&self) -> usize {
self.size - 1
}
#[inline]
pub fn len(&self) -> usize {
let prod_tail = self.prod_tail.load(Ordering::Acquire);
let cons_tail = self.cons_tail.load(Ordering::Acquire);
(prod_tail - cons_tail) & self.capacity()
}
#[inline]
pub fn is_empty(&self) -> bool {
self.len() == 0
}
#[inline]
pub fn is_full(&self) -> bool {
self.len() == self.capacity()
}
pub fn push(&self, val: T) -> Result<&'a UnsafeCell<T>, T> {
let mut head = self.prod_head.load(Ordering::Acquire);
loop {
let tail = self.cons_tail.load(Ordering::Acquire);
if self.capacity().wrapping_add(tail.wrapping_sub(head)) == 0 {
return Err(val);
}
match self.prod_head.compare_exchange_weak(
head,
head.wrapping_add(1),
Ordering::AcqRel,
Ordering::Acquire,
) {
Ok(_) => break,
Err(new_head) => head = new_head,
}
}
let slot = &self.ring()[head & (self.capacity())];
unsafe { (*slot.get()).write(val) };
loop {
match self.prod_tail.compare_exchange_weak(
head,
head.wrapping_add(1),
Ordering::AcqRel,
Ordering::Acquire,
) {
Ok(_) => break,
Err(_) => {
core::hint::spin_loop();
}
}
}
Ok(unsafe { core::intrinsics::transmute(slot) })
}
pub fn pop(&self) -> Option<T> {
let mut data: MaybeUninit<T> = MaybeUninit::uninit();
let mut head = self.cons_head.load(Ordering::Acquire);
loop {
let tail = self.prod_tail.load(Ordering::Acquire);
if head == tail {
return None;
}
match self.cons_head.compare_exchange_weak(
head,
head.wrapping_add(1),
Ordering::AcqRel,
Ordering::Acquire,
) {
Ok(_) => break,
Err(new_head) => head = new_head,
}
}
let slot = &self.ring()[head & self.capacity()];
data.write(unsafe { (*slot.get()).assume_init_read() });
loop {
match self.cons_tail.compare_exchange_weak(
head,
head.wrapping_add(1),
Ordering::AcqRel,
Ordering::Acquire,
) {
Ok(_) => break,
Err(_) => {
core::hint::spin_loop();
}
}
}
Some(unsafe { data.assume_init() })
}
#[inline]
pub fn peek(&self) -> Option<&'a UnsafeCell<T>> {
let prod_tail = self.prod_tail.load(Ordering::Acquire);
let cons_tail = self.cons_tail.load(Ordering::Acquire);
if prod_tail == cons_tail {
return None;
}
let slot = &self.ring()[cons_tail & self.capacity()];
Some(unsafe { core::intrinsics::transmute(slot) })
}
}
impl<const S: usize, T> Default for MPMC<S, T> {
fn default() -> Self {
Self::new()
}
}
impl<const S: usize, T> core::fmt::Debug for MPMC<S, T> {
    /// Prints capacity and a snapshot of the length. Elements are not shown, since `T`
    /// is not required to be `Debug`.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let (capacity, len) = (self.capacity(), self.len());
        write!(f, "MPMC {{ capacity: {capacity}, len: {len} }}")
    }
}
// SAFETY: the queue owns its elements and moves them between threads through
// `push`/`pop`. Sharing (`Sync`) or sending (`Send`) it therefore transfers `T`
// values across threads, which is only sound for `T: Send`. All cross-thread
// coordination of the slots goes through the atomic counters.
unsafe impl<const S: usize, T: Send> Sync for MPMC<S, T> {}
unsafe impl<const S: usize, T: Send> Send for MPMC<S, T> {}

/// Drops every element still stored in the queue.
impl<const S: usize, T> Drop for MPMC<S, T> {
    fn drop(&mut self) {
        // Slots are `MaybeUninit`, so without this the remaining elements would leak.
        // `&mut self` excludes concurrent access, so draining via `pop` cannot race.
        if core::mem::needs_drop::<T>() {
            while self.pop().is_some() {}
        }
    }
}
impl<'a, T> core::ops::Deref for MPMCRef<'a, T> {
type Target = MPMC<0, T>;
fn deref(&self) -> &Self::Target {
unsafe { &*self.0 }
}
}
impl<'a, T> Clone for MPMCRef<'a, T> {
fn clone(&self) -> Self {
MPMCRef(self.0, PhantomData)
}
}
#[cfg(test)]
mod tests {
    use super::MPMC;

    /// The hand-computed `ring_offset` must match the compiler's actual layout, since
    /// `MPMCRef` relies on it to locate the ring.
    #[test]
    fn validate_ring_field_offset() {
        // Word-sized integer element.
        let computed = MPMC::<0, usize>::ring_offset();
        let actual = core::mem::offset_of!(MPMC<0, usize>, ring);
        assert_eq!(computed, actual);

        // Trait-object reference element.
        let computed = MPMC::<0, &dyn core::any::Any>::ring_offset();
        let actual = core::mem::offset_of!(MPMC<0, &dyn core::any::Any>, ring);
        assert_eq!(computed, actual);
    }
}