#[cfg(feature = "async")]
mod r#async;
#[cfg(feature = "async")]
pub use r#async::{AsyncConsumer, AsyncProducer, async_spsc};
use std::cell::UnsafeCell;
use std::mem::MaybeUninit;
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
/// Wrapper that forces its contents onto a 64-byte alignment boundary.
///
/// `Ring` wraps each of its two shared counters in this type so that each one
/// starts on its own 64-byte boundary (presumably to keep the producer-written
/// and consumer-written counters on separate cache lines — confirm the target
/// cache-line size if that matters).
#[repr(align(64))]
pub(crate) struct Padded<T>(pub(crate) T);

/// Storage shared between one [`Producer`] and one [`Consumer`].
///
/// Both counters are free-running: they only ever grow and are reduced to a
/// slot index with `& mask`. Slots in the half-open logical range
/// `head..flush` hold initialized values that have been published but not yet
/// popped.
pub(crate) struct Ring<T> {
    /// Slot storage; its length is always a power of two.
    pub(crate) buf: Box<[UnsafeCell<MaybeUninit<T>>]>,
    /// `buf.len() - 1`, used to turn a free-running counter into a slot index.
    pub(crate) mask: usize,
    /// Count of items popped so far; stored by `Consumer::pop`.
    pub(crate) head: Padded<AtomicUsize>,
    /// Count of items published so far; stored by `Producer::flush`.
    pub(crate) flush: Padded<AtomicUsize>,
}

// SAFETY: the ring owns values of type `T` and hands them from one thread to
// another, so `T: Send` is required. Access to the `UnsafeCell` slots is
// coordinated by the `head`/`flush` counters as used by `Producer` and
// `Consumer` in this file; soundness relies on there being exactly one of
// each per ring.
unsafe impl<T: Send> Send for Ring<T> {}
unsafe impl<T: Send> Sync for Ring<T> {}

impl<T> Ring<T> {
    /// Creates an empty ring with room for at least `capacity` items.
    ///
    /// The requested capacity is rounded up to the next power of two.
    ///
    /// # Panics
    ///
    /// Panics if `capacity` is zero, or if rounding it up to a power of two
    /// would overflow `usize`.
    pub(crate) fn new(capacity: usize) -> Self {
        assert!(capacity > 0, "capacity must be > 0");
        // `next_power_of_two` returns 0 on overflow in release builds, which
        // would make the `cap - 1` mask below meaningless; fail loudly instead.
        let cap = capacity
            .checked_next_power_of_two()
            .expect("capacity too large to round up to a power of two");
        let buf: Vec<UnsafeCell<MaybeUninit<T>>> = (0..cap)
            .map(|_| UnsafeCell::new(MaybeUninit::uninit()))
            .collect();
        Self {
            buf: buf.into_boxed_slice(),
            mask: cap - 1,
            head: Padded(AtomicUsize::new(0)),
            flush: Padded(AtomicUsize::new(0)),
        }
    }

    /// Number of slots in the ring (always a power of two).
    pub(crate) fn capacity(&self) -> usize {
        self.mask + 1
    }
}

impl<T> Drop for Ring<T> {
    /// Drops every item that was published but never popped.
    ///
    /// Items that a producer wrote but never published through `flush` are
    /// not tracked by the ring and are therefore not dropped here.
    fn drop(&mut self) {
        let head = *self.head.0.get_mut();
        let flush = *self.flush.0.get_mut();
        let mask = self.mask;
        // The counters are free-running and may have wrapped past
        // `usize::MAX`, so measure the distance with wrapping arithmetic
        // rather than iterating the numeric range `head..flush` (which would
        // be empty after a wrap and leak every pending item).
        let pending = flush.wrapping_sub(head);
        debug_assert!(pending <= mask + 1, "more pending items than slots");
        for offset in 0..pending {
            let slot = head.wrapping_add(offset) & mask;
            // SAFETY: every logical index in `head..flush` was written by the
            // producer before `flush` was advanced past it and has not been
            // read out by the consumer (which would have advanced `head`).
            // `&mut self` guarantees exclusive access.
            unsafe {
                self.buf[slot].get_mut().assume_init_drop();
            }
        }
    }
}
/// Outcome of a call to `Producer::flush`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FlushResult {
    /// At least one item was published.
    ///
    /// * `count` — number of items made visible by this flush (the distance
    ///   between the producer's tail and the previously published index).
    /// * `was_empty` — `true` when the previously published index equalled
    ///   the producer's view of the consumer's head, i.e. the consumer had
    ///   nothing left to read before this flush as far as the producer knew.
    Flushed { count: usize, was_empty: bool },
    /// Nothing had been pushed since the previous flush.
    NothingToFlush,
}
pub struct Producer<T> {
ring: Arc<Ring<T>>,
tail: usize,
cached_head: usize,
}
impl<T> Producer<T> {
pub fn ring_addr(&self) -> usize {
Arc::as_ptr(&self.ring) as usize
}
}
unsafe impl<T: Send> Send for Producer<T> {}
pub struct Consumer<T> {
ring: Arc<Ring<T>>,
head: usize,
cached_flush: usize,
}
impl<T> Consumer<T> {
pub fn ring_addr(&self) -> usize {
Arc::as_ptr(&self.ring) as usize
}
}
unsafe impl<T: Send> Send for Consumer<T> {}
/// Creates a connected producer/consumer pair backed by one shared ring.
///
/// `capacity` is rounded up to the next power of two by `Ring::new`, which
/// also panics when `capacity` is zero. Both halves start with all indices
/// at zero, i.e. with an empty queue.
pub fn spsc<T>(capacity: usize) -> (Producer<T>, Consumer<T>) {
    let shared = Arc::new(Ring::new(capacity));

    let producer = Producer {
        ring: Arc::clone(&shared),
        tail: 0,
        cached_head: 0,
    };
    let consumer = Consumer {
        ring: shared,
        head: 0,
        cached_flush: 0,
    };

    (producer, consumer)
}
impl<T> Producer<T> {
    /// Stages `val` in the next free slot without publishing it.
    ///
    /// The item stays invisible to the consumer until [`Producer::flush`] is
    /// called.
    ///
    /// # Errors
    ///
    /// Returns `Err(val)`, handing the value back, when the ring is still
    /// full after re-reading the consumer's head.
    #[inline]
    pub fn push(&mut self, val: T) -> Result<(), T> {
        let capacity = self.ring.capacity();
        // Counters are free-running, so distances use wrapping arithmetic
        // (consistent with `len`).
        if self.tail.wrapping_sub(self.cached_head) >= capacity {
            // Only touch the shared head counter when the cached snapshot
            // says the ring is full.
            self.cached_head = self.ring.head.0.load(Ordering::Acquire);
            if self.tail.wrapping_sub(self.cached_head) >= capacity {
                return Err(val);
            }
        }
        let slot = self.tail & self.ring.mask;
        // SAFETY: `tail - head < capacity`, so the consumer has already
        // finished with this slot (its head store is `Release`, our load is
        // `Acquire`), and it will not read the slot again until `flush`
        // publishes it. `&mut self` makes this the only writer, assuming a
        // single `Producer` per ring as created by `spsc`.
        unsafe {
            (*self.ring.buf[slot].get()).write(val);
        }
        self.tail = self.tail.wrapping_add(1);
        Ok(())
    }

    /// Publishes every item pushed since the previous flush.
    ///
    /// Returns [`FlushResult::NothingToFlush`] when nothing is pending.
    /// Otherwise returns [`FlushResult::Flushed`] with the number of newly
    /// published items and whether the consumer had already popped everything
    /// that was published before this call.
    ///
    /// `was_empty` is computed from a fresh read of the consumer's head. It
    /// is still only a snapshot: the consumer may pop further items
    /// concurrently.
    #[inline]
    pub fn flush(&mut self) -> FlushResult {
        // Only this producer stores `flush`, so a relaxed load is enough to
        // read back our own previous store.
        let prev_flush = self.ring.flush.0.load(Ordering::Relaxed);
        if self.tail == prev_flush {
            return FlushResult::NothingToFlush;
        }
        let count = self.tail.wrapping_sub(prev_flush);
        // Refresh the head snapshot: the cached value is otherwise only
        // updated when the ring looks full, which would make `was_empty`
        // report `false` even after the consumer drained the queue.
        self.cached_head = self.ring.head.0.load(Ordering::Acquire);
        let was_empty = prev_flush == self.cached_head;
        // Release pairs with the consumer's Acquire load in `prefetch`, making
        // the slot writes above visible before the new index is.
        self.ring.flush.0.store(self.tail, Ordering::Release);
        FlushResult::Flushed { count, was_empty }
    }

    /// Pushes `val` and immediately publishes it.
    ///
    /// # Errors
    ///
    /// Returns `Err(val)` without flushing when the ring is full.
    #[inline]
    pub fn push_and_flush(&mut self, val: T) -> Result<FlushResult, T> {
        self.push(val)?;
        Ok(self.flush())
    }

    /// Reports whether a `push` would currently fail.
    ///
    /// Takes `&mut self` because it may refresh the cached head snapshot.
    #[inline]
    pub fn is_full(&mut self) -> bool {
        let capacity = self.ring.capacity();
        if self.tail.wrapping_sub(self.cached_head) >= capacity {
            self.cached_head = self.ring.head.0.load(Ordering::Acquire);
            self.tail.wrapping_sub(self.cached_head) >= capacity
        } else {
            false
        }
    }

    /// Number of slots in the ring (the requested capacity rounded up to a
    /// power of two).
    #[inline]
    pub fn capacity(&self) -> usize {
        self.ring.capacity()
    }

    /// Number of items pushed but not yet popped, including items that have
    /// not been flushed yet.
    #[inline]
    pub fn len(&self) -> usize {
        self.tail
            .wrapping_sub(self.ring.head.0.load(Ordering::Acquire))
    }

    /// `true` when every pushed item (flushed or not) has been popped.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.tail == self.ring.head.0.load(Ordering::Acquire)
    }
}
impl<T> Consumer<T> {
    /// Pops the next item from the already-prefetched range.
    ///
    /// Returns `None` when the consumer has used up everything it saw at the
    /// last [`Consumer::prefetch`]; it does not look at the shared flush
    /// counter itself.
    #[inline]
    pub fn pop(&mut self) -> Option<T> {
        if self.head == self.cached_flush {
            return None;
        }
        let slot = self.head & self.ring.mask;
        // SAFETY: `head != cached_flush`, and `cached_flush` came from an
        // Acquire load of the producer's Release-stored flush index, so this
        // slot was fully written and published. The producer will not reuse
        // it until it observes the head store below.
        let val = unsafe { (*self.ring.buf[slot].get()).assume_init_read() };
        // Free-running counter: wrap instead of panicking on overflow,
        // consistent with the wrapping distance computations elsewhere.
        self.head = self.head.wrapping_add(1);
        // Release hands the slot back to the producer only after the read
        // above has completed.
        self.ring.head.0.store(self.head, Ordering::Release);
        Some(val)
    }

    /// Refreshes the consumer's view of what the producer has published.
    ///
    /// Returns the number of items that became visible since the previous
    /// refresh (not the total number available).
    #[inline]
    pub fn prefetch(&mut self) -> usize {
        let new_flush = self.ring.flush.0.load(Ordering::Acquire);
        let count = new_flush.wrapping_sub(self.cached_flush);
        self.cached_flush = new_flush;
        count
    }

    /// Pops the next item, refreshing the published range first if the
    /// locally known range is exhausted.
    #[inline]
    pub fn prefetch_and_pop(&mut self) -> Option<T> {
        if self.head == self.cached_flush {
            self.prefetch();
        }
        self.pop()
    }

    /// `true` when there is nothing to pop, both in the locally known range
    /// and according to the shared flush counter.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.head == self.cached_flush && self.ring.flush.0.load(Ordering::Acquire) == self.head
    }

    /// Number of slots in the ring (the requested capacity rounded up to a
    /// power of two).
    #[inline]
    pub fn capacity(&self) -> usize {
        self.ring.capacity()
    }

    /// Number of published items that have not been popped yet.
    #[inline]
    pub fn len(&self) -> usize {
        self.ring
            .flush
            .0
            .load(Ordering::Acquire)
            .wrapping_sub(self.head)
    }
}
/// Debug output shows only the capacity, so `T` does not have to implement
/// `Debug`; the remaining fields are elided.
impl<T> std::fmt::Debug for Producer<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let capacity = self.capacity();
        let mut builder = f.debug_struct("Producer");
        builder.field("capacity", &capacity);
        builder.finish_non_exhaustive()
    }
}
/// Debug output shows only the capacity, so `T` does not have to implement
/// `Debug`; the remaining fields are elided.
impl<T> std::fmt::Debug for Consumer<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let capacity = self.capacity();
        let mut builder = f.debug_struct("Consumer");
        builder.field("capacity", &capacity);
        builder.finish_non_exhaustive()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Pushed items stay invisible to the consumer until they are flushed.
    #[test]
    fn push_pop_basic() {
        let (mut tx, mut rx) = spsc::<u32>(4);
        assert_eq!(rx.prefetch_and_pop(), None);

        tx.push(1).unwrap();
        tx.push(2).unwrap();
        // Not flushed yet, so nothing is visible.
        assert_eq!(rx.prefetch_and_pop(), None);

        tx.flush();
        assert_eq!(rx.prefetch_and_pop(), Some(1));
        assert_eq!(rx.prefetch_and_pop(), Some(2));
        assert_eq!(rx.prefetch_and_pop(), None);
    }

    /// `push_and_flush` publishes the item in one call.
    #[test]
    fn push_and_flush() {
        let (mut tx, mut rx) = spsc::<u32>(4);
        tx.push_and_flush(42).unwrap();
        assert_eq!(rx.prefetch_and_pop(), Some(42));
    }

    /// `prefetch` reports how many items became visible, and `pop` then
    /// drains exactly those.
    #[test]
    fn batch_prefetch() {
        let (mut tx, mut rx) = spsc::<u32>(8);
        (0..5).for_each(|value| tx.push(value).unwrap());

        assert_eq!(rx.prefetch(), 0);
        tx.flush();
        assert_eq!(rx.prefetch(), 5);

        for expected in 0..5 {
            assert_eq!(rx.pop(), Some(expected));
        }
        assert_eq!(rx.pop(), None);
    }

    /// `flush` reports the number of published items and the emptiness flag.
    #[test]
    fn flush_reports_was_empty() {
        let (mut tx, mut rx) = spsc::<u32>(4);

        tx.push(1).unwrap();
        let first = tx.flush();
        assert_eq!(
            first,
            FlushResult::Flushed {
                count: 1,
                was_empty: true
            }
        );

        tx.push(2).unwrap();
        let second = tx.flush();
        assert!(matches!(second, FlushResult::Flushed { count: 1, .. }));

        // Drain both items, then push again.
        rx.prefetch_and_pop();
        rx.prefetch_and_pop();
        tx.push(3).unwrap();
        tx.push(4).unwrap();
        let _ = tx.push(5);
        let _ = tx.push(6);
        let third = tx.flush();
        assert!(matches!(third, FlushResult::Flushed { .. }));
    }

    /// A full ring rejects pushes until the consumer frees a slot.
    #[test]
    fn full_ring() {
        let (mut tx, mut rx) = spsc::<u32>(4);
        for value in 0..4 {
            tx.push(value).unwrap();
        }
        assert!(tx.push(99).is_err());

        tx.flush();
        assert_eq!(rx.prefetch_and_pop(), Some(0));

        tx.push(99).unwrap();
        tx.flush();
        for expected in [1, 2, 3, 99] {
            assert_eq!(rx.prefetch_and_pop(), Some(expected));
        }
    }

    /// Slot indices wrap correctly when the ring is reused many times.
    #[test]
    fn wraps_around() {
        let (mut tx, mut rx) = spsc::<u32>(2);
        for round in 0..100 {
            let even = round * 2;
            let odd = round * 2 + 1;
            tx.push(even).unwrap();
            tx.push(odd).unwrap();
            tx.flush();
            assert_eq!(rx.prefetch_and_pop(), Some(even));
            assert_eq!(rx.prefetch_and_pop(), Some(odd));
        }
    }

    /// Requested capacities are rounded up to the next power of two.
    #[test]
    fn capacity_rounds_up() {
        for (requested, expected) in [(3, 4), (5, 8), (1, 1)] {
            let (tx, _rx) = spsc::<u8>(requested);
            assert_eq!(tx.capacity(), expected);
        }
    }

    /// Flushed-but-unpopped items are dropped together with the ring.
    #[test]
    fn drop_remaining() {
        use std::sync::atomic::AtomicUsize;
        static DROPS: AtomicUsize = AtomicUsize::new(0);

        #[derive(Debug)]
        struct Counted;
        impl Drop for Counted {
            fn drop(&mut self) {
                DROPS.fetch_add(1, Ordering::Relaxed);
            }
        }

        DROPS.store(0, Ordering::Relaxed);
        let (mut tx, rx) = spsc::<Counted>(4);
        for _ in 0..3 {
            tx.push(Counted).unwrap();
        }
        tx.flush();
        drop(tx);
        drop(rx);
        assert_eq!(DROPS.load(Ordering::Relaxed), 3);
    }

    /// Items sent from another thread arrive complete and in order.
    #[test]
    fn cross_thread() {
        let (mut tx, mut rx) = spsc::<u64>(1024);
        let total = 100_000u64;

        let producer_thread = std::thread::spawn(move || {
            for value in 0..total {
                // Spin until there is room, publishing what we have so the
                // consumer can make progress.
                while tx.push(value).is_err() {
                    tx.flush();
                    std::thread::yield_now();
                }
                tx.flush();
            }
        });

        let mut next = 0u64;
        while next < total {
            if rx.prefetch() == 0 {
                std::thread::yield_now();
                continue;
            }
            while let Some(value) = rx.pop() {
                assert_eq!(value, next);
                next += 1;
            }
        }
        producer_thread.join().unwrap();
    }
}