use std::cell::UnsafeCell;
use std::mem::MaybeUninit;
use std::sync::atomic::{AtomicUsize, Ordering};
/// A bounded, lock-free multi-producer / single-consumer FIFO queue.
///
/// Every slot carries a sequence number that tells producers and the consumer
/// whether the slot is free for the current lap of the ring or holds a
/// published value (the scheme of Dmitry Vyukov's bounded MPMC queue).
/// Producers claim positions by advancing `tail` with a CAS; the single
/// consumer advances `head`.
///
/// Any number of threads may call [`try_enqueue`](Self::try_enqueue) through a
/// shared reference. Exactly one consumer may dequeue at a time: either through
/// `&mut self` with [`try_dequeue`](Self::try_dequeue), or, while producers
/// still hold shared references (e.g. through an `Arc`), through the `unsafe`
/// [`try_dequeue_unchecked`](Self::try_dequeue_unchecked).
pub struct BoundedMpscQueue<T> {
    /// `capacity - 1`. The capacity is a power of two, so `pos & mask` is the
    /// slot index of position `pos`.
    mask: usize,
    slots: Box<[Slot<T>]>,
    /// Next position the consumer reads. Only the consumer stores to it; it is
    /// atomic so that `len` may read it while the consumer runs on another thread.
    head: AtomicUsize,
    /// Next position a producer will claim.
    tail: AtomicUsize,
}

/// One ring cell: a sequence number plus possibly-uninitialized storage.
///
/// For a position `pos` that maps to this slot, `seq == pos` means the slot is
/// free for the producer claiming `pos`, and `seq == pos + 1` means it holds the
/// value written at `pos`.
struct Slot<T> {
    seq: AtomicUsize,
    value: UnsafeCell<MaybeUninit<T>>,
}

// SAFETY: values of `T` move from producer threads to the consumer thread, so
// `T: Send` is required and sufficient. Each `value` cell is accessed under the
// slot's `seq` protocol. A producer writes only after winning the CAS on `tail`
// for that position. The consumer reads only after observing the producer's
// Release store of `seq`, and consumers are exclusive (via `&mut self` or the
// safety contract of `try_dequeue_unchecked`).
unsafe impl<T: Send> Sync for BoundedMpscQueue<T> {}
// SAFETY: the queue owns its `T` values; moving the queue moves them.
unsafe impl<T: Send> Send for BoundedMpscQueue<T> {}

impl<T> BoundedMpscQueue<T> {
    /// Creates an empty queue that can hold at least `capacity` items.
    ///
    /// The capacity is rounded up to the next power of two, and to at least 2.
    /// With a single slot, the sequence value meaning "full at position `p`"
    /// would equal the one meaning "free for position `p + 1`".
    ///
    /// # Panics
    ///
    /// Panics if the rounded-up capacity does not fit in a `usize`.
    pub fn new(capacity: usize) -> Self {
        // `next_power_of_two` wraps to 0 on overflow in release builds, which
        // `.max(2)` would silently turn into a 2-slot queue. Fail loudly instead.
        let cap = capacity
            .checked_next_power_of_two()
            .expect("BoundedMpscQueue capacity overflows usize")
            .max(2);
        let slots: Box<[Slot<T>]> = (0..cap)
            .map(|i| Slot {
                // Slot `i` starts out free for position `i` (first lap).
                seq: AtomicUsize::new(i),
                value: UnsafeCell::new(MaybeUninit::uninit()),
            })
            .collect();
        Self {
            mask: cap - 1,
            slots,
            head: AtomicUsize::new(0),
            tail: AtomicUsize::new(0),
        }
    }

    /// Returns the number of slots in the ring (a power of two, at least 2).
    pub fn capacity(&self) -> usize {
        self.mask + 1
    }

    /// Attempts to push `value` onto the back of the queue.
    ///
    /// May be called concurrently from any number of threads.
    ///
    /// # Errors
    ///
    /// Returns `Err(value)`, handing the value back, if the queue is full.
    pub fn try_enqueue(&self, value: T) -> Result<(), T> {
        let mut tail = self.tail.load(Ordering::Relaxed);
        loop {
            let slot = &self.slots[tail & self.mask];
            // Acquire pairs with the consumer's Release store that freed this
            // slot, so our write below cannot race with its read of the old value.
            let seq = slot.seq.load(Ordering::Acquire);
            // Signed view of the wrapping difference, so counter wrap-around works.
            let diff = seq.wrapping_sub(tail) as isize;
            if diff == 0 {
                // The slot is free for this lap; try to claim position `tail`.
                match self.tail.compare_exchange_weak(
                    tail,
                    tail.wrapping_add(1),
                    Ordering::Relaxed,
                    Ordering::Relaxed,
                ) {
                    Ok(_) => {
                        // SAFETY: winning the CAS gives this thread exclusive use of
                        // the slot until `seq` is published below. Until then the
                        // consumer sees `seq != tail + 1` and leaves it alone.
                        unsafe { (*slot.value.get()).write(value) };
                        // Release makes the write visible to the consumer's Acquire.
                        slot.seq.store(tail.wrapping_add(1), Ordering::Release);
                        return Ok(());
                    }
                    Err(current) => tail = current,
                }
            } else if diff < 0 {
                // The slot still holds the previous lap's value: the queue is full.
                return Err(value);
            } else {
                // Another producer already claimed `tail`; reload and retry.
                tail = self.tail.load(Ordering::Relaxed);
            }
        }
    }

    /// Attempts to pop the item at the front of the queue.
    ///
    /// Returns `None` if the queue is empty, or if the producer of the next
    /// item has claimed its slot but not yet finished writing it.
    pub fn try_dequeue(&mut self) -> Option<T> {
        // SAFETY: `&mut self` guarantees this is the only consumer, and any
        // earlier dequeue happens-before this one.
        unsafe { self.try_dequeue_unchecked() }
    }

    /// Like [`try_dequeue`](Self::try_dequeue), but callable through a shared
    /// reference, so one consumer thread can run while producers hold `&self`.
    ///
    /// # Safety
    ///
    /// At most one thread may dequeue at a time. Every call to this method or to
    /// `try_dequeue` must happen-before the next such call. For example, all
    /// calls come from one designated consumer thread, or the consumer role is
    /// handed off through a mutex, channel, or `join`.
    pub unsafe fn try_dequeue_unchecked(&self) -> Option<T> {
        // Only the consumer stores `head`, and the safety contract orders all
        // consumer calls, so a Relaxed load observes the latest value.
        let head = self.head.load(Ordering::Relaxed);
        let slot = &self.slots[head & self.mask];
        // Acquire pairs with the producer's Release store of `seq`.
        let seq = slot.seq.load(Ordering::Acquire);
        if seq != head.wrapping_add(1) {
            return None;
        }
        // SAFETY: `seq == head + 1` means a producer fully wrote this slot, and
        // the caller guarantees no other consumer reads it concurrently.
        let value = unsafe { (*slot.value.get()).assume_init_read() };
        // Free the slot for the producer that will claim `head + capacity`.
        slot.seq
            .store(head.wrapping_add(self.mask + 1), Ordering::Release);
        // Release: a `len` that observes this `head` also observes the `tail`
        // reservation of the item just consumed.
        self.head.store(head.wrapping_add(1), Ordering::Release);
        Some(value)
    }

    /// Returns the number of positions currently claimed by producers and not
    /// yet consumed.
    ///
    /// This count includes items whose producer has not finished writing them.
    /// Under concurrent use it is only a snapshot that may be stale by the time
    /// it is returned. It never exceeds [`capacity`](Self::capacity).
    pub fn len(&self) -> usize {
        // Load `head` before `tail`. The Acquire pairs with the consumer's Release
        // store, so the `tail` loaded afterwards is at least `head` (no underflow).
        let head = self.head.load(Ordering::Acquire);
        let tail = self.tail.load(Ordering::Acquire);
        // A stale `head` could otherwise push the difference past the ring size.
        tail.wrapping_sub(head).min(self.capacity())
    }

    /// Returns `true` if [`len`](Self::len) is zero. Under concurrent use the
    /// same snapshot caveats apply.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
impl<T> Drop for BoundedMpscQueue<T> {
    fn drop(&mut self) {
        // Slots are `MaybeUninit`, so their contents are never dropped
        // automatically. Move each remaining item out and drop it explicitly.
        // Holding `&mut self` rules out a producer mid-write, so this empties
        // every occupied slot.
        while let Some(item) = self.try_dequeue() {
            drop(item);
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::thread;

    #[test]
    fn enqueue_dequeue_single_item() {
        let mut q: BoundedMpscQueue<u32> = BoundedMpscQueue::new(4);
        assert!(q.try_enqueue(42).is_ok());
        assert_eq!(q.try_dequeue(), Some(42));
        assert_eq!(q.try_dequeue(), None);
    }

    #[test]
    fn capacity_is_power_of_two() {
        let q: BoundedMpscQueue<u32> = BoundedMpscQueue::new(5);
        assert_eq!(q.capacity(), 8);
        let q2: BoundedMpscQueue<u32> = BoundedMpscQueue::new(1);
        assert_eq!(q2.capacity(), 2);
    }

    #[test]
    fn enqueue_full_returns_value_back() {
        let q: BoundedMpscQueue<u32> = BoundedMpscQueue::new(4);
        for i in 0..4 {
            assert!(q.try_enqueue(i).is_ok());
        }
        assert_eq!(q.try_enqueue(99), Err(99));
    }

    #[test]
    fn fifo_order_single_producer() {
        let mut q: BoundedMpscQueue<u32> = BoundedMpscQueue::new(16);
        for i in 0..16 {
            assert!(q.try_enqueue(i).is_ok());
        }
        for i in 0..16 {
            assert_eq!(q.try_dequeue(), Some(i));
        }
    }

    #[test]
    fn drain_then_refill_wraps_ring() {
        let mut q: BoundedMpscQueue<u32> = BoundedMpscQueue::new(4);
        for round in 0..10 {
            for i in 0..4 {
                assert!(q.try_enqueue(round * 4 + i).is_ok());
            }
            for i in 0..4 {
                assert_eq!(q.try_dequeue(), Some(round * 4 + i));
            }
        }
    }

    /// Several producers enqueue concurrently. Every item must be dequeued
    /// exactly once, and each producer's items must keep their relative order.
    ///
    /// `try_dequeue` takes `&mut self`, which cannot coexist with the producers'
    /// shared borrows. The queue is therefore sized to hold every item and is
    /// drained only after all producers have been joined. Do not cast
    /// `Arc::as_ptr` to `*mut` to consume concurrently: a `&mut` that aliases
    /// live `&` references is undefined behavior.
    #[test]
    fn multi_producer_no_lost_items() {
        let producers = 4usize;
        let per_producer = 5_000u64;
        let total = producers * per_producer as usize;
        let mut shared = Arc::new(BoundedMpscQueue::<u64>::new(total));
        let handles: Vec<_> = (0..producers as u64)
            .map(|tid| {
                let q = Arc::clone(&shared);
                thread::spawn(move || {
                    for i in 0..per_producer {
                        // Capacity covers every item, so enqueue never fails.
                        assert!(q.try_enqueue((tid << 32) | i).is_ok());
                    }
                })
            })
            .collect();
        for h in handles {
            h.join().unwrap();
        }
        // Each producer's `Arc` clone was dropped when its thread finished.
        let consumer = Arc::get_mut(&mut shared).expect("producers released their handles");
        let mut next = vec![0u64; producers];
        while let Some(v) = consumer.try_dequeue() {
            let tid = (v >> 32) as usize;
            assert_eq!(v & 0xFFFF_FFFF, next[tid], "producer {tid} out of order");
            next[tid] += 1;
        }
        assert!(next.iter().all(|&n| n == per_producer));
    }

    #[test]
    fn drops_pending_items_on_destruction() {
        struct DropCounted(Arc<AtomicUsize>);
        impl Drop for DropCounted {
            fn drop(&mut self) {
                self.0.fetch_add(1, Ordering::Relaxed);
            }
        }
        let counter = Arc::new(AtomicUsize::new(0));
        {
            let q: BoundedMpscQueue<DropCounted> = BoundedMpscQueue::new(4);
            for _ in 0..3 {
                assert!(q.try_enqueue(DropCounted(Arc::clone(&counter))).is_ok());
            }
        }
        assert_eq!(counter.load(Ordering::Relaxed), 3);
    }

    #[test]
    fn len_tracks_outstanding_items() {
        let mut q: BoundedMpscQueue<u32> = BoundedMpscQueue::new(8);
        assert!(q.is_empty());
        q.try_enqueue(1).unwrap();
        q.try_enqueue(2).unwrap();
        assert_eq!(q.len(), 2);
        q.try_dequeue().unwrap();
        assert_eq!(q.len(), 1);
    }
}