use core::cell::UnsafeCell;
use core::marker::PhantomData;
use core::mem::MaybeUninit;
use core::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use crossbeam_utils::CachePadded;
/// Ring storage and positions shared by one [`Producer`] and one [`Consumer`].
///
/// The consumer alone stores `head` and the producer alone stores `tail`.
/// The slots for positions in `[head, tail)` hold live values; every other
/// slot holds nothing that needs dropping.
struct Shared<T> {
    /// Slot storage, one slot per unit of capacity.
    buf: Box<[UnsafeCell<MaybeUninit<T>>]>,
    /// Maximum number of items the ring can hold (the number of slots in `buf`).
    capacity: usize,
    /// Position of the next item to pop; stored only by the consumer.
    head: CachePadded<AtomicUsize>,
    /// Position of the next slot to fill; stored only by the producer.
    tail: CachePadded<AtomicUsize>,
}

// SAFETY: `Shared` owns the queued `T`s, so moving it to another thread moves
// those values, which is sound whenever `T: Send`.
unsafe impl<T> Send for Shared<T> where T: Send {}

// SAFETY: each slot is touched by only one side of the `head`/`tail` handoff
// at a time, so a shared `&Shared` never gives two threads simultaneous access
// to the same `T`. Values merely cross threads, which needs `T: Send` but not
// `T: Sync`.
unsafe impl<T> Sync for Shared<T> where T: Send {}
/// Writing end of a ring created by [`channel`].
///
/// It may be moved to another thread (`Send` when `T: Send`) but never shared
/// by reference between threads (`!Sync`), so at most one thread pushes.
pub struct Producer<T> {
    shared: Arc<Shared<T>>,
    /// `*const ()` opts the handle out of `Sync` (and of the automatic `Send`,
    /// which is granted explicitly below for `T: Send`).
    _not_sync: PhantomData<*const ()>,
}

/// Reading end of a ring created by [`channel`].
///
/// It may be moved to another thread (`Send` when `T: Send`) but never shared
/// by reference between threads (`!Sync`), so at most one thread pops.
pub struct Consumer<T> {
    shared: Arc<Shared<T>>,
    /// `*const ()` opts the handle out of `Sync` (and of the automatic `Send`,
    /// which is granted explicitly below for `T: Send`).
    _not_sync: PhantomData<*const ()>,
}

// SAFETY: a handle is an `Arc` to the shared ring plus a zero-sized marker.
// Sending it hands the right to push (or pop) `T` values to another thread,
// which is sound when `T: Send`. The marker is only there to forbid `Sync`.
unsafe impl<T> Send for Producer<T> where T: Send {}
unsafe impl<T> Send for Consumer<T> where T: Send {}
impl<T> Shared<T> {
    // `head` and `tail` are positions in `0..2 * capacity` that wrap explicitly.
    // Free-running counters reduced with `% capacity` are unsound once they wrap
    // at `usize::MAX` (only 2^32 operations away on 32-bit targets). Unless
    // `capacity` is a power of two, the slot sequence jumps at that point, so a
    // slot is overwritten while still queued and later read twice. Ranging over
    // twice the capacity keeps `head == tail` meaning "empty" and a distance of
    // `capacity` meaning "full", without sacrificing a slot.

    /// Maps a position in `0..2 * capacity` to its slot index in `buf`.
    #[inline]
    fn slot_index(&self, pos: usize) -> usize {
        debug_assert!(pos < 2 * self.capacity);
        if pos < self.capacity {
            pos
        } else {
            pos - self.capacity
        }
    }

    /// Returns the position after `pos`, wrapping `2 * capacity - 1` back to 0.
    #[inline]
    fn next_pos(&self, pos: usize) -> usize {
        // Cannot overflow: `channel` caps `capacity` at `usize::MAX / 2`.
        let next = pos + 1;
        if next == 2 * self.capacity {
            0
        } else {
            next
        }
    }

    /// Number of queued items between `head` (inclusive) and `tail` (exclusive).
    #[inline]
    fn occupied(&self, head: usize, tail: usize) -> usize {
        if head <= tail {
            tail - head
        } else {
            // `tail` has wrapped past `2 * capacity - 1` and `head` has not yet.
            2 * self.capacity - head + tail
        }
    }
}

impl<T> Producer<T> {
    /// Appends `value` to the ring without blocking.
    ///
    /// # Errors
    ///
    /// Returns `Err(value)`, giving the value back, if the ring already holds
    /// [`capacity`](Self::capacity) items.
    #[inline]
    pub fn try_push(&self, value: T) -> Result<(), T> {
        let shared = &*self.shared;
        // Only this handle stores `tail`, so a relaxed load sees its own last store.
        let tail = shared.tail.load(Ordering::Relaxed);
        // Acquire pairs with the consumer's Release store of `head`: its read of
        // any slot it freed completes before we overwrite that slot.
        let head = shared.head.load(Ordering::Acquire);
        if shared.occupied(head, tail) == shared.capacity {
            return Err(value);
        }
        let idx = shared.slot_index(tail);
        // SAFETY: the ring is not full, so `tail`'s slot lies outside `[head, tail)`.
        // It holds no live value, and the consumer will not read it until the
        // store below publishes it. This is the only producer (`channel` makes
        // one, and it is `!Sync`), so nothing else writes this slot concurrently.
        unsafe {
            (*shared.buf[idx].get()).write(value);
        }
        // Release publishes the slot write together with the new `tail`.
        shared.tail.store(shared.next_pos(tail), Ordering::Release);
        Ok(())
    }

    /// Maximum number of items the ring can hold.
    #[inline]
    #[must_use]
    pub fn capacity(&self) -> usize {
        self.shared.capacity
    }

    /// Returns `true` if the ring currently holds `capacity` items.
    ///
    /// The consumer may pop concurrently, so a `true` result can be stale by the
    /// time it is acted on. A `false` result stays valid until this producer
    /// pushes.
    #[inline]
    #[must_use]
    pub fn is_full(&self) -> bool {
        let shared = &*self.shared;
        let tail = shared.tail.load(Ordering::Relaxed);
        let head = shared.head.load(Ordering::Acquire);
        shared.occupied(head, tail) == shared.capacity
    }
}

impl<T> Consumer<T> {
    /// Removes and returns the oldest queued item without blocking, or `None`
    /// if the ring is empty.
    #[inline]
    #[must_use]
    pub fn try_pop(&self) -> Option<T> {
        let shared = &*self.shared;
        // Only this handle stores `head`, so a relaxed load sees its own last store.
        let head = shared.head.load(Ordering::Relaxed);
        // Acquire pairs with the producer's Release store of `tail`, making the
        // slot contents written before that store visible here.
        let tail = shared.tail.load(Ordering::Acquire);
        if head == tail {
            return None;
        }
        let idx = shared.slot_index(head);
        // SAFETY: `head` lies in `[head, tail)`, so its slot was initialised and
        // published by the producer. Moving the value out is sound because
        // `head` is advanced immediately after, so this slot is not read again
        // until the producer refills it.
        let value = unsafe { (*shared.buf[idx].get()).assume_init_read() };
        // Release hands the now-empty slot back to the producer.
        shared.head.store(shared.next_pos(head), Ordering::Release);
        Some(value)
    }

    /// Maximum number of items the ring can hold.
    #[inline]
    #[must_use]
    pub fn capacity(&self) -> usize {
        self.shared.capacity
    }

    /// Returns `true` if no items are currently queued.
    ///
    /// The producer may push concurrently, so a `true` result can be stale by the
    /// time it is acted on. A `false` result stays valid until this consumer
    /// pops.
    #[inline]
    #[must_use]
    pub fn is_empty(&self) -> bool {
        let shared = &*self.shared;
        let head = shared.head.load(Ordering::Relaxed);
        let tail = shared.tail.load(Ordering::Acquire);
        head == tail
    }
}

impl<T> Drop for Shared<T> {
    /// Drops every item that was pushed but never popped.
    fn drop(&mut self) {
        // `&mut self` means both handles are gone (their `Arc`s were released),
        // so the final positions can be read without atomic ordering.
        let tail = *self.tail.get_mut();
        let mut pos = *self.head.get_mut();
        while pos != tail {
            // SAFETY: positions in `[head, tail)` hold live values that were never
            // popped, and this loop visits each of them exactly once.
            unsafe {
                (*self.buf[self.slot_index(pos)].get()).assume_init_drop();
            }
            pos = self.next_pos(pos);
        }
    }
}

/// Creates a bounded single-producer/single-consumer ring with room for
/// `capacity` items and returns its writing and reading ends.
///
/// # Panics
///
/// Panics if `capacity` is zero. Also panics if it is larger than
/// `usize::MAX / 2`, because positions range over `0..2 * capacity` and that
/// bound must fit in a `usize`.
#[must_use]
pub fn channel<T>(capacity: usize) -> (Producer<T>, Consumer<T>) {
    assert!(capacity > 0, "SPSC ring capacity must be non-zero");
    assert!(
        capacity <= usize::MAX / 2,
        "SPSC ring capacity must not exceed usize::MAX / 2"
    );
    let buf: Box<[UnsafeCell<MaybeUninit<T>>]> = (0..capacity)
        .map(|_| UnsafeCell::new(MaybeUninit::uninit()))
        .collect();
    let shared = Arc::new(Shared {
        buf,
        capacity,
        head: CachePadded::new(AtomicUsize::new(0)),
        tail: CachePadded::new(AtomicUsize::new(0)),
    });
    let producer = Producer {
        shared: Arc::clone(&shared),
        _not_sync: PhantomData,
    };
    let consumer = Consumer {
        shared,
        _not_sync: PhantomData,
    };
    (producer, consumer)
}
#[cfg(test)]
mod tests {
    use super::*;
    use static_assertions::{assert_impl_all, assert_not_impl_any};

    // Each handle may move to its own thread but must never be shared by reference.
    assert_impl_all!(Producer<u32>: Send);
    assert_impl_all!(Consumer<u32>: Send);
    assert_not_impl_any!(Producer<u32>: Sync);
    assert_not_impl_any!(Consumer<u32>: Sync);

    #[test]
    fn new_ring_is_empty_and_not_full() {
        let (p, c) = channel::<u32>(4);
        assert!(c.is_empty());
        assert!(!p.is_full());
        assert_eq!(c.try_pop(), None);
    }

    #[test]
    fn push_then_pop_returns_same_value() {
        let (p, c) = channel::<u32>(4);
        assert!(p.try_push(42).is_ok());
        assert_eq!(c.try_pop(), Some(42));
    }

    #[test]
    fn capacity_matches_constructor_argument() {
        let (p, c) = channel::<u32>(8);
        assert_eq!(p.capacity(), 8);
        assert_eq!(c.capacity(), 8);
    }

    #[test]
    fn push_when_full_returns_back_the_value() {
        let (p, _c) = channel::<u32>(2);
        assert!(p.try_push(1).is_ok());
        assert!(p.try_push(2).is_ok());
        assert!(p.is_full());
        assert_eq!(p.try_push(3), Err(3));
    }

    #[test]
    fn pop_when_empty_returns_none() {
        let (_p, c) = channel::<u32>(2);
        assert_eq!(c.try_pop(), None);
    }

    #[test]
    fn fifo_order_is_preserved() {
        let (p, c) = channel::<u32>(8);
        for i in 0..8 {
            assert!(p.try_push(i).is_ok());
        }
        for i in 0..8 {
            assert_eq!(c.try_pop(), Some(i));
        }
        assert!(c.is_empty());
    }

    #[test]
    fn ring_wraps_around_correctly() {
        let (p, c) = channel::<u32>(3);
        assert!(p.try_push(1).is_ok());
        assert!(p.try_push(2).is_ok());
        assert_eq!(c.try_pop(), Some(1));
        assert_eq!(c.try_pop(), Some(2));
        assert!(p.try_push(3).is_ok());
        assert!(p.try_push(4).is_ok());
        assert!(p.try_push(5).is_ok());
        assert!(p.is_full());
        assert_eq!(c.try_pop(), Some(3));
        assert_eq!(c.try_pop(), Some(4));
        assert_eq!(c.try_pop(), Some(5));
        assert!(c.is_empty());
    }

    /// Drives an odd-capacity ring through many laps with varying fill levels,
    /// so every slot is reused many times and the ring goes full on every lap.
    #[test]
    fn fifo_and_fullness_hold_over_many_laps() {
        let (p, c) = channel::<u32>(3);
        let mut pushed = 0u32;
        let mut popped = 0u32;
        for round in 0..60u32 {
            let burst = round % 3 + 1;
            for _ in 0..burst {
                assert!(p.try_push(pushed).is_ok());
                pushed += 1;
            }
            assert_eq!(p.is_full(), burst == 3);
            if burst == 3 {
                assert_eq!(p.try_push(999), Err(999));
            }
            for _ in 0..burst {
                assert_eq!(c.try_pop(), Some(popped));
                popped += 1;
            }
            assert!(c.is_empty());
            assert_eq!(c.try_pop(), None);
        }
    }

    #[test]
    fn interleaved_push_and_pop_works() {
        let (p, c) = channel::<u32>(2);
        for i in 0..100 {
            assert!(p.try_push(i).is_ok());
            assert_eq!(c.try_pop(), Some(i));
        }
        assert!(c.is_empty());
    }

    #[test]
    fn dropping_ring_drops_remaining_items() {
        use std::sync::atomic::{AtomicUsize, Ordering};
        static DROPS: AtomicUsize = AtomicUsize::new(0);
        struct Counted;
        impl Drop for Counted {
            fn drop(&mut self) {
                DROPS.fetch_add(1, Ordering::Relaxed);
            }
        }
        DROPS.store(0, Ordering::Relaxed);
        {
            let (p, _c) = channel::<Counted>(4);
            assert!(p.try_push(Counted).is_ok());
            assert!(p.try_push(Counted).is_ok());
            assert!(p.try_push(Counted).is_ok());
        }
        assert_eq!(DROPS.load(Ordering::Relaxed), 3);
    }

    #[test]
    fn popped_items_are_not_double_dropped() {
        use std::sync::atomic::{AtomicUsize, Ordering};
        static DROPS: AtomicUsize = AtomicUsize::new(0);
        struct Counted;
        impl Drop for Counted {
            fn drop(&mut self) {
                DROPS.fetch_add(1, Ordering::Relaxed);
            }
        }
        DROPS.store(0, Ordering::Relaxed);
        {
            let (p, c) = channel::<Counted>(4);
            assert!(p.try_push(Counted).is_ok());
            assert!(p.try_push(Counted).is_ok());
            let _popped = c.try_pop().unwrap();
        }
        assert_eq!(DROPS.load(Ordering::Relaxed), 2);
    }

    /// After head and tail have wrapped around the buffer, dropping the ring
    /// must drop exactly the items still queued — no more, no fewer.
    #[test]
    fn dropping_after_wraparound_drops_only_queued_items() {
        use std::sync::atomic::{AtomicUsize, Ordering};
        static DROPS: AtomicUsize = AtomicUsize::new(0);
        struct Counted;
        impl Drop for Counted {
            fn drop(&mut self) {
                DROPS.fetch_add(1, Ordering::Relaxed);
            }
        }
        DROPS.store(0, Ordering::Relaxed);
        {
            let (p, c) = channel::<Counted>(3);
            for _ in 0..7 {
                assert!(p.try_push(Counted).is_ok());
                drop(c.try_pop().unwrap());
            }
            assert!(p.try_push(Counted).is_ok());
            assert!(p.try_push(Counted).is_ok());
            assert_eq!(DROPS.load(Ordering::Relaxed), 7);
        }
        assert_eq!(DROPS.load(Ordering::Relaxed), 9);
    }

    #[test]
    #[should_panic(expected = "capacity must be non-zero")]
    fn zero_capacity_panics() {
        let _ = channel::<u32>(0);
    }

    #[test]
    fn cross_thread_push_pop_preserves_order() {
        use std::thread;
        let (p, c) = channel::<u32>(64);
        const N: u32 = 100_000;
        let producer = thread::spawn(move || {
            let mut next = 0;
            while next < N {
                if p.try_push(next).is_ok() {
                    next += 1;
                } else {
                    std::thread::yield_now();
                }
            }
        });
        let consumer = thread::spawn(move || {
            let mut expected = 0;
            while expected < N {
                match c.try_pop() {
                    Some(v) => {
                        assert_eq!(v, expected);
                        expected += 1;
                    }
                    None => std::thread::yield_now(),
                }
            }
        });
        producer.join().unwrap();
        consumer.join().unwrap();
    }
}