use std::cell::UnsafeCell;
use std::mem::MaybeUninit;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
/// A bounded, lock-free single-producer / single-consumer (SPSC) ring buffer.
///
/// The requested capacity is rounded up to a power of two so that a slot index
/// can be derived from a counter with a bit mask (`counter & mask`) instead of
/// a division.
///
/// `tail` counts the items ever pushed and `head` the items ever popped. Both
/// wrap on overflow, and the number of queued items is `tail - head` computed
/// with wrapping arithmetic. Because the capacity is a power of two (and so
/// divides `usize::MAX + 1`), `counter & mask` keeps mapping consecutive
/// counters to consecutive slots even when a counter wraps.
///
/// # Concurrency contract
///
/// At most one thread may be inside [`push`](Self::push) /
/// [`push_blocking`](Self::push_blocking) at a time, and at most one thread may
/// be inside [`pop`](Self::pop) at a time; the producer and the consumer may
/// run concurrently with each other. Two concurrent pushers (or two concurrent
/// poppers) would access the same slot without synchronization, which is a
/// data race. The type system does NOT enforce this when the queue itself is
/// shared between threads; [`spsc_channel`] hands out exactly one producer and
/// one consumer handle instead.
pub struct SpscQueue<T> {
    /// Ring storage. The slots addressed by counters in `head..tail` hold
    /// initialized values; every other slot is logically uninitialized.
    buffer: Box<[UnsafeCell<MaybeUninit<T>>]>,
    /// Number of slots (`buffer.len()`); always a power of two.
    capacity: usize,
    /// `capacity - 1`, used to map a counter to a slot index.
    mask: usize,
    /// Number of items popped so far (wrapping); only `pop` stores to it.
    head: AtomicUsize,
    /// Number of items pushed so far (wrapping); only `push` stores to it.
    tail: AtomicUsize,
}

// SAFETY: the queue owns `T` values that may be written on one thread and read
// or dropped on another, hence the `T: Send` bound. Slot accesses are ordered
// by the Acquire/Release protocol on `head`/`tail`, provided callers honor the
// SPSC contract documented on the type.
unsafe impl<T: Send> Send for SpscQueue<T> {}
// SAFETY: see above. NOTE(review): because `push`/`pop` are safe `&self`
// methods, this impl lets safe code break the SPSC contract by sharing the
// queue between two producers (or two consumers). Making those methods
// `unsafe` or crate-private would close the hole but changes the public API.
unsafe impl<T: Send> Sync for SpscQueue<T> {}

impl<T> SpscQueue<T> {
    /// Creates an empty queue able to hold at least `capacity` items.
    ///
    /// The capacity is rounded up to the next power of two.
    ///
    /// # Panics
    ///
    /// Panics if `capacity` is zero, or if rounding it up to a power of two
    /// would overflow `usize`.
    pub fn new(capacity: usize) -> Self {
        assert!(capacity > 0, "capacity must be greater than 0");
        // `usize::next_power_of_two` wraps to 0 on overflow in release builds,
        // which would silently produce a zero-slot queue with `mask ==
        // usize::MAX`; fail loudly instead.
        let capacity = capacity
            .checked_next_power_of_two()
            .expect("capacity overflows usize when rounded up to a power of two");
        let buffer = (0..capacity)
            .map(|_| UnsafeCell::new(MaybeUninit::uninit()))
            .collect::<Vec<_>>()
            .into_boxed_slice();
        Self {
            buffer,
            capacity,
            mask: capacity - 1,
            head: AtomicUsize::new(0),
            tail: AtomicUsize::new(0),
        }
    }

    /// Attempts to enqueue `item` without blocking.
    ///
    /// Must only be called by the single producer (see the type-level docs).
    ///
    /// # Errors
    ///
    /// Returns `Err(item)`, handing the value back, if the queue is full.
    #[inline]
    pub fn push(&self, item: T) -> Result<(), T> {
        // Only the producer stores `tail`, so reading our own value is Relaxed.
        let tail = self.tail.load(Ordering::Relaxed);
        // Acquire pairs with the Release store of `head` in `pop`: once we see
        // a slot as free, the consumer's read of it has completed.
        let head = self.head.load(Ordering::Acquire);
        if tail.wrapping_sub(head) >= self.capacity {
            return Err(item);
        }
        let slot = &self.buffer[tail & self.mask];
        // SAFETY: the slot for counter `tail` lies outside the occupied range
        // `head..tail`, so the consumer will not touch it until the new `tail`
        // is published below, and the single-producer contract rules out any
        // other writer. `write` does not drop the (logically moved-out) old
        // contents.
        unsafe {
            (*slot.get()).write(item);
        }
        // Release publishes the slot write to the consumer's Acquire of `tail`.
        self.tail.store(tail.wrapping_add(1), Ordering::Release);
        Ok(())
    }

    /// Enqueues `item`, busy-waiting (with `spin_loop` hints) until a slot is
    /// free.
    ///
    /// There is no timeout: this never returns if the queue stays full.
    /// Must only be called by the single producer.
    #[inline]
    pub fn push_blocking(&self, mut item: T) {
        while let Err(rejected) = self.push(item) {
            item = rejected;
            std::hint::spin_loop();
        }
    }

    /// Dequeues the oldest item, or returns `None` if the queue is empty.
    ///
    /// Must only be called by the single consumer (see the type-level docs).
    #[inline]
    pub fn pop(&self) -> Option<T> {
        // Only the consumer stores `head`, so reading our own value is Relaxed.
        let head = self.head.load(Ordering::Relaxed);
        // Acquire pairs with the Release store of `tail` in `push`, making the
        // producer's slot write visible before we read it.
        let tail = self.tail.load(Ordering::Acquire);
        if head == tail {
            return None;
        }
        let slot = &self.buffer[head & self.mask];
        // SAFETY: `head != tail`, so the slot for counter `head` was initialized
        // by `push` and published via `tail`. The producer will not overwrite it
        // until `head` is advanced below, and the single-consumer contract means
        // nobody else reads it, so moving the value out happens exactly once.
        let item = unsafe { (*slot.get()).assume_init_read() };
        // Release hands the now-empty slot back to the producer.
        self.head.store(head.wrapping_add(1), Ordering::Release);
        Some(item)
    }

    /// Returns `true` if the queue held no items when it was inspected.
    ///
    /// The answer may be stale as soon as it is returned if the other side is
    /// running concurrently.
    #[inline]
    pub fn is_empty(&self) -> bool {
        let head = self.head.load(Ordering::Acquire);
        let tail = self.tail.load(Ordering::Acquire);
        head == tail
    }

    /// Returns `true` if the queue was at capacity when it was inspected.
    #[inline]
    pub fn is_full(&self) -> bool {
        self.len() >= self.capacity
    }

    /// Returns the number of queued items, at most [`capacity`](Self::capacity).
    #[inline]
    pub fn len(&self) -> usize {
        let head = self.head.load(Ordering::Acquire);
        let tail = self.tail.load(Ordering::Acquire);
        // A thread that is neither the producer nor the consumer can observe a
        // stale `head` together with a newer `tail`; clamp so the reported
        // length never exceeds the number of slots.
        tail.wrapping_sub(head).min(self.capacity)
    }

    /// Returns the number of slots (the requested capacity rounded up to a
    /// power of two).
    #[inline]
    pub fn capacity(&self) -> usize {
        self.capacity
    }
}

impl<T> Drop for SpscQueue<T> {
    /// Drops every item still queued. `MaybeUninit` never drops its contents,
    /// so without this drain the remaining items would be leaked.
    fn drop(&mut self) {
        // `&mut self` guarantees exclusive access, so draining through `pop`
        // cannot race with a producer or consumer.
        while self.pop().is_some() {}
    }
}
/// Zero-sized marker that is `Send` but not `Sync`.
///
/// `Cell<()>` is `Send` and `!Sync`, so embedding a `PhantomData` of it removes
/// the automatic `Sync` impl from a handle without affecting `Send`.
type SendNotSync = std::marker::PhantomData<std::cell::Cell<()>>;

/// Producing half of a channel created by [`spsc_channel`].
///
/// The handle can be moved to another thread (when `T: Send`) but is
/// intentionally not `Sync`. Sharing `&SpscProducer` between threads would let
/// two threads push at once, and both would write the same slot of the
/// underlying [`SpscQueue`], which is a data race.
pub struct SpscProducer<T> {
    queue: Arc<SpscQueue<T>>,
    _not_sync: SendNotSync,
}

impl<T> SpscProducer<T> {
    /// Enqueues `item`, busy-waiting until the consumer frees a slot.
    ///
    /// There is no timeout: this never returns if the queue stays full.
    #[inline]
    pub fn push(&self, item: T) {
        self.queue.push_blocking(item);
    }

    /// Enqueues `item` if there is room, without waiting.
    ///
    /// # Errors
    ///
    /// Returns `Err(item)`, handing the value back, if the queue is full.
    #[inline]
    pub fn try_push(&self, item: T) -> Result<(), T> {
        self.queue.push(item)
    }

    /// Returns `true` if the queue was full when inspected.
    #[inline]
    pub fn is_full(&self) -> bool {
        self.queue.is_full()
    }

    /// Returns the number of items queued when inspected.
    #[inline]
    pub fn len(&self) -> usize {
        self.queue.len()
    }

    /// Returns the number of slots in the underlying queue.
    #[inline]
    pub fn capacity(&self) -> usize {
        self.queue.capacity()
    }
}

/// Consuming half of a channel created by [`spsc_channel`].
///
/// Like the producer, it can be moved to another thread (when `T: Send`) but
/// is intentionally not `Sync`. Two concurrent pops would both move out of the
/// same slot, and the value would be dropped twice.
pub struct SpscConsumer<T> {
    queue: Arc<SpscQueue<T>>,
    _not_sync: SendNotSync,
}

impl<T> SpscConsumer<T> {
    /// Dequeues the oldest item, or returns `None` if the queue is empty.
    #[inline]
    pub fn pop(&self) -> Option<T> {
        self.queue.pop()
    }

    /// Returns `true` if the queue was empty when inspected.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.queue.is_empty()
    }

    /// Returns the number of items queued when inspected.
    #[inline]
    pub fn len(&self) -> usize {
        self.queue.len()
    }

    /// Returns the number of slots in the underlying queue.
    #[inline]
    pub fn capacity(&self) -> usize {
        self.queue.capacity()
    }
}

/// Creates a bounded SPSC channel and returns its producer and consumer halves.
///
/// Both halves share one [`SpscQueue`]. The queue is freed, and any items still
/// in it are dropped, when the last half is dropped.
///
/// # Panics
///
/// Panics under the same conditions as [`SpscQueue::new`] (e.g. a zero
/// `capacity`).
pub fn spsc_channel<T>(capacity: usize) -> (SpscProducer<T>, SpscConsumer<T>) {
    let queue = Arc::new(SpscQueue::new(capacity));
    let producer = SpscProducer {
        queue: Arc::clone(&queue),
        _not_sync: std::marker::PhantomData,
    };
    let consumer = SpscConsumer {
        queue,
        _not_sync: std::marker::PhantomData,
    };
    (producer, consumer)
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;

    /// Single-threaded FIFO behavior, length tracking and emptiness.
    #[test]
    fn test_spsc_basic() {
        let queue = SpscQueue::<i32>::new(4);
        assert!(queue.is_empty());
        assert!(!queue.is_full());
        queue.push(1).unwrap();
        queue.push(2).unwrap();
        assert_eq!(queue.len(), 2);
        assert_eq!(queue.pop(), Some(1));
        assert_eq!(queue.pop(), Some(2));
        assert_eq!(queue.pop(), None);
        assert!(queue.is_empty());
    }

    /// A full queue rejects pushes, hands the item back, and accepts again
    /// after a pop.
    #[test]
    fn test_spsc_full() {
        let queue = SpscQueue::<i32>::new(2);
        queue.push(1).unwrap();
        queue.push(2).unwrap();
        assert!(queue.is_full());
        assert_eq!(queue.push(3), Err(3));
        assert_eq!(queue.pop(), Some(1));
        queue.push(3).unwrap();
        assert_eq!(queue.pop(), Some(2));
        assert_eq!(queue.pop(), Some(3));
    }

    /// One producer thread and one consumer thread exchange 0..1000.
    #[test]
    fn test_spsc_threaded() {
        let (producer, consumer) = spsc_channel::<i32>(16);
        let producer_thread = thread::spawn(move || {
            for i in 0..1000 {
                producer.push(i);
            }
        });
        let consumer_thread = thread::spawn(move || {
            let mut sum = 0i64;
            let mut count = 0;
            while count < 1000 {
                if let Some(v) = consumer.pop() {
                    sum += v as i64;
                    count += 1;
                } else {
                    thread::yield_now();
                }
            }
            sum
        });
        producer_thread.join().unwrap();
        let sum = consumer_thread.join().unwrap();
        assert_eq!(sum, 499500);
    }

    /// Requested capacities are rounded up to a power of two.
    #[test]
    fn test_capacity_power_of_two() {
        assert_eq!(SpscQueue::<i32>::new(1).capacity(), 1);
        assert_eq!(SpscQueue::<i32>::new(5).capacity(), 8);
        assert_eq!(SpscQueue::<i32>::new(8).capacity(), 8);
    }

    /// Repeatedly filling and draining exercises slot reuse across the ring.
    #[test]
    fn test_spsc_wrap_around() {
        let queue = SpscQueue::<i32>::new(4);
        for round in 0..10 {
            for i in 0..4 {
                queue.push(round * 4 + i).unwrap();
            }
            assert!(queue.is_full());
            for i in 0..4 {
                assert_eq!(queue.pop(), Some(round * 4 + i));
            }
            assert!(queue.is_empty());
        }
    }

    /// The channel handles expose the shared queue's state consistently.
    #[test]
    fn test_spsc_producer_consumer_handles() {
        let (producer, consumer) = spsc_channel::<String>(8);
        producer.push("hello".to_string());
        producer.push("world".to_string());
        assert_eq!(producer.len(), 2);
        assert_eq!(consumer.len(), 2);
        assert_eq!(consumer.pop(), Some("hello".to_string()));
        assert_eq!(consumer.pop(), Some("world".to_string()));
        assert!(consumer.is_empty());
        assert_eq!(consumer.pop(), None);
    }

    /// Items still queued when the queue is dropped are dropped exactly once.
    #[test]
    fn test_spsc_drop_remaining() {
        use std::sync::atomic::{AtomicUsize, Ordering};

        static DROP_COUNT: AtomicUsize = AtomicUsize::new(0);

        #[derive(Debug)]
        struct DropCounter;

        impl Drop for DropCounter {
            fn drop(&mut self) {
                DROP_COUNT.fetch_add(1, Ordering::SeqCst);
            }
        }

        DROP_COUNT.store(0, Ordering::SeqCst);
        {
            let queue = SpscQueue::<DropCounter>::new(4);
            queue.push(DropCounter).unwrap();
            queue.push(DropCounter).unwrap();
            queue.push(DropCounter).unwrap();
        }
        assert_eq!(DROP_COUNT.load(Ordering::SeqCst), 3);
    }

    /// Ordering is preserved under sustained cross-thread traffic.
    #[test]
    fn test_spsc_high_throughput() {
        let (producer, consumer) = spsc_channel::<u64>(1024);
        let items = 100_000u64;
        let producer_thread = thread::spawn(move || {
            for i in 0..items {
                producer.push(i);
            }
        });
        let consumer_thread = thread::spawn(move || {
            let mut received = 0u64;
            let mut expected = 0u64;
            while received < items {
                if let Some(v) = consumer.pop() {
                    assert_eq!(v, expected, "Items received out of order");
                    expected += 1;
                    received += 1;
                } else {
                    thread::yield_now();
                }
            }
            received
        });
        producer_thread.join().unwrap();
        let received = consumer_thread.join().unwrap();
        assert_eq!(received, items);
    }

    /// `try_push` fails without blocking and returns the rejected item.
    #[test]
    fn test_try_push() {
        let (producer, _consumer) = spsc_channel::<i32>(2);
        assert!(producer.try_push(1).is_ok());
        assert!(producer.try_push(2).is_ok());
        assert!(producer.is_full());
        assert_eq!(producer.try_push(3), Err(3));
    }

    /// A zero capacity is rejected up front.
    #[test]
    #[should_panic(expected = "capacity must be greater than 0")]
    fn test_zero_capacity_panics() {
        let _queue = SpscQueue::<i32>::new(0);
    }
}