//! A lock-free, bounded MPSC (multi-producer, single-consumer) queue backed by
//! a mirrored memory buffer, which keeps reads and writes contiguous even when
//! they wrap around the end of the ring.
use crate::buffer::MirrorBuffer;
use core::marker::PhantomData;
use core::ptr;
use core::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
/// An `AtomicUsize` padded out to its own 64-byte cache line so that the
/// head, tail, and commit counters do not false-share.
#[repr(align(64))]
pub(crate) struct PaddedAtomic(pub(crate) AtomicUsize);
impl core::ops::Deref for PaddedAtomic {
    type Target = AtomicUsize;
    fn deref(&self) -> &Self::Target {
        &self.0
    }
}
/// State shared by all producer handles and the single consumer.
pub(crate) struct SharedState<T> {
    /// Mirrored backing storage for the ring.
    pub(crate) buffer: MirrorBuffer,
    /// Next slot a producer will claim (wrapped index).
    pub(crate) head: PaddedAtomic,
    /// Next slot the consumer will read (wrapped index).
    pub(crate) tail: PaddedAtomic,
    /// Slots before this index are published and visible to the consumer.
    pub(crate) commit: PaddedAtomic,
    /// Number of `T` slots in the ring.
    pub(crate) capacity: usize,
    /// `capacity - 1` if `capacity` is a power of two, otherwise 0.
    pub(crate) mask: usize,
    pub(crate) _marker: PhantomData<T>,
}
// SAFETY: cross-thread access to the buffer is coordinated through the atomic
// head/tail/commit counters, so the shared state is Send/Sync for `T: Send`.
unsafe impl<T: Send> Send for SharedState<T> {}
unsafe impl<T: Send> Sync for SharedState<T> {}
/// Builder handle; call [`split`](Self::split) to obtain the two endpoints.
pub struct PicoMPSC<T> {
    shared: Arc<SharedState<T>>,
}
/// Producer endpoint. Cloneable, so any number of threads may push.
pub struct PicoMpscProducer<T> {
    pub(crate) shared: Arc<SharedState<T>>,
}
/// Consumer endpoint; there is exactly one per queue.
pub struct PicoMpscConsumer<T> {
    pub(crate) shared: Arc<SharedState<T>>,
}
impl<T> Clone for PicoMpscProducer<T> {
    fn clone(&self) -> Self {
        Self {
            shared: self.shared.clone(),
        }
    }
}
impl<T> PicoMPSC<T> {
    /// Creates a queue sized for `capacity_count` items of `T`.
    ///
    /// The mirror buffer may round the allocation up, so the usable capacity
    /// can exceed the request.
    pub fn new(capacity_count: usize) -> Result<Self, String> {
        let item_size = core::mem::size_of::<T>();
        if item_size == 0 {
            return Err("Zero-sized types are not supported in PicoMPSC".into());
        }
        let total_bytes = capacity_count
            .checked_mul(item_size)
            .ok_or_else(|| "Requested capacity is too large (overflow)".to_string())?;
        let buffer = MirrorBuffer::new(total_bytes)?;
        let actual_capacity = buffer.size() / item_size;
        if actual_capacity == 0 {
            return Err("Requested capacity is too small to hold a single item".to_string());
        }
        // Index wrapping uses a cheap bit mask when the capacity is a power of
        // two, and falls back to modulo otherwise.
        let mask = if actual_capacity.is_power_of_two() {
            actual_capacity - 1
        } else {
            0
        };
        let shared = Arc::new(SharedState {
            buffer,
            head: PaddedAtomic(AtomicUsize::new(0)),
            tail: PaddedAtomic(AtomicUsize::new(0)),
            commit: PaddedAtomic(AtomicUsize::new(0)),
            capacity: actual_capacity,
            mask,
            _marker: PhantomData,
        });
        Ok(Self { shared })
    }
    /// Consumes the builder and returns the producer/consumer pair.
    pub fn split(self) -> (PicoMpscProducer<T>, PicoMpscConsumer<T>) {
        (
            PicoMpscProducer {
                shared: self.shared.clone(),
            },
            PicoMpscConsumer {
                shared: self.shared,
            },
        )
    }
}
impl<T> PicoMpscProducer<T> {
    /// Wraps a logical index into `[0, capacity)`.
    #[inline]
    fn wrap(&self, val: usize) -> usize {
        if self.shared.mask != 0 {
            val & self.shared.mask
        } else {
            val % self.shared.capacity
        }
    }
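    /// Pushes one item, returning `false` if the queue is full.
    ///
    /// Items become visible to the consumer in claim order, so this call may
    /// spin briefly while producers that claimed earlier slots finish writing.
    ///
    /// A minimal usage sketch (marked `ignore` so it is not compiled as a
    /// doctest; it assumes a queue created with [`create_mpsc`]):
    ///
    /// ```ignore
    /// let (producer, consumer) = create_mpsc::<u32>(1024).expect("queue creation");
    /// assert!(producer.push(42));
    /// assert_eq!(consumer.pop(), Some(42));
    /// ```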
    #[inline]
    pub fn push(&self, item: T) -> bool {
        let mut head = self.shared.head.load(Ordering::Relaxed);
        let mut next_head;
        loop {
            let tail = self.shared.tail.load(Ordering::Acquire);
            // Head and tail are wrapped indices, so recover the occupied
            // length across the wrap point.
            let current_len = if head >= tail {
                head - tail
            } else {
                self.shared.capacity - (tail - head)
            };
            // One slot is always kept free so a full queue is distinguishable
            // from an empty one.
            if current_len + 1 >= self.shared.capacity {
                return false;
            }
            next_head = self.wrap(head + 1);
            // Try to claim the slot at `head`; on contention, retry with the
            // freshly observed value.
            match self.shared.head.compare_exchange_weak(
                head,
                next_head,
                Ordering::Acquire,
                Ordering::Relaxed,
            ) {
                Ok(_) => break,
                Err(h) => head = h,
            }
        }
        // SAFETY: the CAS above gave this producer exclusive ownership of the
        // slot at `head` until it is committed.
        unsafe {
            let ptr = (self.shared.buffer.as_mut_ptr() as *mut T).add(head);
            ptr::write(ptr, item);
        }
        // Publish in claim order: wait for every earlier claim to commit, then
        // move `commit` past this slot so the consumer can read it.
        while self.shared.commit.load(Ordering::Acquire) != head {
            core::hint::spin_loop();
        }
        self.shared.commit.store(next_head, Ordering::Release);
        true
    }
}
impl<T: Copy> PicoMpscProducer<T> {
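    /// Copies the entire slice into the queue, returning `false` without
    /// writing anything if it does not currently fit.
    ///
    /// A minimal usage sketch (marked `ignore` so it is not compiled as a
    /// doctest; it assumes a queue created with [`create_mpsc`]):
    ///
    /// ```ignore
    /// let (producer, consumer) = create_mpsc::<u8>(4096).expect("queue creation");
    /// assert!(producer.push_slice(b"hello"));
    /// assert_eq!(consumer.readable_slice(), &b"hello"[..]);
    /// ```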
    pub fn push_slice(&self, data: &[T]) -> bool {
        let n = data.len();
        if n == 0 {
            return true;
        }
        let mut head = self.shared.head.load(Ordering::Relaxed);
        let mut next_head;
        loop {
            let tail = self.shared.tail.load(Ordering::Acquire);
            let current_len = if head >= tail {
                head - tail
            } else {
                self.shared.capacity - (tail - head)
            };
            // Reject the whole batch unless all `n` items fit (one slot is
            // always kept free).
            if current_len + n >= self.shared.capacity {
                return false;
            }
            // Claim `n` contiguous slots starting at `head`.
            next_head = self.wrap(head + n);
            match self.shared.head.compare_exchange_weak(
                head,
                next_head,
                Ordering::Acquire,
                Ordering::Relaxed,
            ) {
                Ok(_) => break,
                Err(h) => head = h,
            }
        }
        // SAFETY: the claimed slots are owned by this producer, and the
        // mirrored mapping keeps the destination contiguous even when the
        // batch crosses the end of the buffer.
        unsafe {
            let dest_ptr = (self.shared.buffer.as_mut_ptr() as *mut T).add(head);
            ptr::copy_nonoverlapping(data.as_ptr(), dest_ptr, n);
        }
        // Publish in claim order, exactly as in `push`.
        while self.shared.commit.load(Ordering::Acquire) != head {
            core::hint::spin_loop();
        }
        self.shared.commit.store(next_head, Ordering::Release);
        true
    }
}
impl<T> PicoMpscConsumer<T> {
    /// Wraps a logical index into `[0, capacity)`.
    #[inline]
    fn wrap(&self, val: usize) -> usize {
        if self.shared.mask != 0 {
            val & self.shared.mask
        } else {
            val % self.shared.capacity
        }
    }
    /// Pops the oldest committed item, or returns `None` if the queue is empty.
    #[inline]
    pub fn pop(&self) -> Option<T> {
        let tail = self.shared.tail.load(Ordering::Relaxed);
        let commit = self.shared.commit.load(Ordering::Acquire);
        if tail == commit {
            return None;
        }
        // SAFETY: `tail != commit`, so the slot at `tail` holds a fully
        // written item that no producer will touch until `tail` advances.
        let item = unsafe {
            let ptr = (self.shared.buffer.as_mut_ptr() as *const T).add(tail);
            ptr::read(ptr)
        };
        self.shared
            .tail
            .store(self.wrap(tail + 1), Ordering::Release);
        Some(item)
    }
    /// Returns the number of committed items waiting to be consumed.
    #[inline]
    pub fn len(&self) -> usize {
        let commit = self.shared.commit.load(Ordering::Acquire);
        let tail = self.shared.tail.load(Ordering::Relaxed);
        if commit >= tail {
            commit - tail
        } else {
            self.shared.capacity - (tail - commit)
        }
    }
    /// Returns `true` if no committed items are waiting.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
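    /// Returns a view of every committed, not-yet-consumed item.
    ///
    /// Because the backing storage is mirrored, the returned slice is
    /// contiguous even when the data wraps around the end of the buffer. Call
    /// [`advance_tail`](Self::advance_tail) to release items once processed.
    ///
    /// A minimal zero-copy consumption sketch (marked `ignore` so it is not
    /// compiled as a doctest; `handle_bytes` is a hypothetical callback):
    ///
    /// ```ignore
    /// let view = consumer.readable_slice();
    /// handle_bytes(view);
    /// consumer.advance_tail(view.len());
    /// ```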
    #[inline]
    pub fn readable_slice(&self) -> &[T] {
        let commit = self.shared.commit.load(Ordering::Acquire);
        let tail = self.shared.tail.load(Ordering::Relaxed);
        let len = if commit >= tail {
            commit - tail
        } else {
            self.shared.capacity - (tail - commit)
        };
        // SAFETY: only committed slots are exposed, and the mirrored mapping
        // makes the region starting at `tail` contiguous even when it crosses
        // the end of the buffer.
        unsafe {
            let ptr = (self.shared.buffer.as_mut_ptr() as *const T).add(tail);
            core::slice::from_raw_parts(ptr, len)
        }
    }
    /// Releases the first `n` readable items after they have been processed,
    /// e.g. following [`readable_slice`](Self::readable_slice).
    ///
    /// `n` must not exceed the current readable length; the released items are
    /// not dropped, so this is intended for `Copy` payloads.
    #[inline]
    pub fn advance_tail(&self, n: usize) {
        let tail = self.shared.tail.load(Ordering::Relaxed);
        self.shared
            .tail
            .store(self.wrap(tail + n), Ordering::Release);
    }
}
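/// Convenience constructor: builds a queue for `capacity_count` items of `T`
/// and immediately splits it into its producer and consumer endpoints.
///
/// A minimal end-to-end sketch (marked `ignore` so it is not compiled as a
/// doctest):
///
/// ```ignore
/// let (producer, consumer) = create_mpsc::<u32>(1024).expect("queue creation");
/// let worker = producer.clone(); // producers are cheaply cloneable
/// assert!(worker.push(7));
/// assert_eq!(consumer.pop(), Some(7));
/// ```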
pub fn create_mpsc<T>(
    capacity_count: usize,
) -> Result<(PicoMpscProducer<T>, PicoMpscConsumer<T>), String> {
    Ok(PicoMPSC::new(capacity_count)?.split())
}
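// A few usage sketches written as tests against the API above. They assume the
// requested sizes are acceptable to `MirrorBuffer::new`; whole pages are
// requested below on the assumption that the mapping is page-granular.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn push_pop_round_trip() {
        let (producer, consumer) = create_mpsc::<u32>(1024).expect("queue creation");
        assert!(consumer.is_empty());
        assert!(producer.push(7));
        assert!(producer.push(8));
        assert_eq!(consumer.len(), 2);
        assert_eq!(consumer.pop(), Some(7));
        assert_eq!(consumer.pop(), Some(8));
        assert_eq!(consumer.pop(), None);
    }

    #[test]
    fn slice_batch_round_trip() {
        let (producer, consumer) = create_mpsc::<u8>(4096).expect("queue creation");
        assert!(producer.push_slice(b"hello"));
        assert_eq!(consumer.readable_slice(), &b"hello"[..]);
        consumer.advance_tail(5);
        assert!(consumer.is_empty());
    }

    #[test]
    fn two_producers_one_consumer() {
        // Load is kept well below capacity so the queue never fills or wraps;
        // the test exercises concurrent slot claiming and in-order publication.
        const PER_PRODUCER: u64 = 1_000;
        let (producer, consumer) = create_mpsc::<u64>(4096).expect("queue creation");
        let mut handles = Vec::new();
        for _ in 0..2 {
            let p = producer.clone();
            handles.push(std::thread::spawn(move || {
                for i in 1..=PER_PRODUCER {
                    while !p.push(i) {
                        std::thread::yield_now();
                    }
                }
            }));
        }
        let (mut popped, mut sum) = (0u64, 0u64);
        while popped < 2 * PER_PRODUCER {
            if let Some(v) = consumer.pop() {
                sum += v;
                popped += 1;
            } else {
                std::thread::yield_now();
            }
        }
        for handle in handles {
            handle.join().unwrap();
        }
        // Each producer pushes 1..=PER_PRODUCER exactly once.
        assert_eq!(sum, 2 * (PER_PRODUCER * (PER_PRODUCER + 1) / 2));
    }
}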