use crate::util::CachePadded;
use crate::{Error, Result};
use core::sync::atomic::{AtomicPtr, AtomicU64, AtomicUsize, Ordering};
#[cfg(feature = "unstable")]
use crossbeam_epoch::{self as epoch, Atomic, Owned};
#[cfg(feature = "std")]
use std::boxed::Box;
#[cfg(feature = "std")]
use std::vec::Vec;
#[cfg(feature = "std")]
use std::time::Duration;
#[cfg(target_arch = "x86_64")]
use core::arch::x86_64::_mm_prefetch;
/// Bounded multi-producer multi-consumer queue backed by a fixed ring of
/// atomically swapped element pointers (null = empty slot).
#[derive(Debug)]
pub struct MpmcQueue<T> {
    // One atomic pointer per slot; each slot is cache-padded to reduce
    // false sharing between neighboring slots.
    buffer: CachePadded<Box<[CachePadded<AtomicPtr<T>>]>>,
    // Number of slots; always a power of two (see `with_capacity`).
    capacity: usize,
    // `capacity - 1`; turns an index into a slot position with `&`.
    mask: usize,
    // Consumer position; padded so producers and consumers do not contend
    // on the same cache line.
    head: CachePadded<AtomicUsize>,
    // Producer position.
    tail: CachePadded<AtomicUsize>,
    // Counter bumped on every successful push/pop (observability only).
    generation: CachePadded<AtomicU64>,
}
#[cfg(feature = "unstable")]
// Linked-list node for the unbounded queue; cache-line aligned to limit
// false sharing between adjacent nodes.
// NOTE(review): `value` is a plain `T`, yet the queue's sentinel node is
// created with an uninitialized value and popped nodes keep owning `value`
// after it is `ptr::read` out — a `MaybeUninit<T>`/`ManuallyDrop<T>` field
// is needed to make the epoch reclamation sound.
#[repr(align(64))] struct Node<T> {
    value: T,
    next: Atomic<Node<T>>,
}
#[cfg(feature = "unstable")]
/// Unbounded linked-list MPMC queue using crossbeam-epoch for memory
/// reclamation. Gated behind the `unstable` feature.
#[derive(Debug)]
pub struct UnboundedMpmcQueue<T> {
    // Points at the sentinel node; the first real element is `head.next`.
    head: Atomic<Node<T>>,
    // Points at (or near) the last node in the list.
    tail: Atomic<Node<T>>,
    // Approximate element count, maintained with relaxed increments.
    size: AtomicUsize,
}
impl<T> Clone for MpmcQueue<T> {
    /// Returns a new, **empty** queue with the same capacity as `self`.
    ///
    /// Queued elements are not copied — that would require `T: Clone`;
    /// only the configuration is duplicated.
    fn clone(&self) -> Self {
        // `self.capacity` is already a power of two, so `with_capacity`
        // reproduces both `capacity` and `mask` exactly; the original
        // re-assignment of those fields after construction was dead code.
        MpmcQueue::with_capacity(self.capacity)
    }
}
#[cfg(feature = "unstable")]
impl<T> Clone for UnboundedMpmcQueue<T> {
    /// Produces a fresh, empty queue. Elements are not copied, since that
    /// would require a `T: Clone` bound.
    fn clone(&self) -> Self {
        UnboundedMpmcQueue::new()
    }
}
impl<T> MpmcQueue<T> {
pub fn new(capacity: usize) -> Self {
Self::with_capacity(capacity)
}
fn with_capacity(capacity: usize) -> Self {
assert!(capacity > 0, "Queue capacity must be greater than 0");
let capacity = if capacity.is_power_of_two() {
capacity
} else {
capacity.next_power_of_two()
};
let mask = capacity - 1;
let buffer: Vec<CachePadded<AtomicPtr<T>>> = (0..capacity)
.map(|_| CachePadded::new(AtomicPtr::new(core::ptr::null_mut())))
.collect();
Self {
buffer: CachePadded::new(buffer.into_boxed_slice()),
capacity,
mask,
head: CachePadded::new(AtomicUsize::new(0)),
tail: CachePadded::new(AtomicUsize::new(0)),
generation: CachePadded::new(AtomicU64::new(0)),
}
}
#[inline]
pub fn push(&self, value: T) -> Result<()> {
let head = self.head.load(Ordering::Acquire);
let tail = self.tail.load(Ordering::Relaxed);
if (tail.wrapping_add(1) & self.mask) == head {
return Err(Error::CapacityExceeded);
}
let boxed = Box::into_raw(Box::new(value));
let index = tail & self.mask;
#[cfg(target_arch = "x86_64")]
unsafe {
_mm_prefetch((&self.buffer.inner()[index] as *const _ as *const i8), 0);
}
self.buffer.inner()[index].store(boxed, Ordering::Release);
self.tail.store(tail.wrapping_add(1), Ordering::Release);
self.generation.fetch_add(1, Ordering::Relaxed);
Ok(())
}
#[inline]
pub fn pop(&self) -> Option<T> {
let tail = self.tail.load(Ordering::Acquire);
let head = self.head.load(Ordering::Relaxed);
if head == tail {
return None;
}
let index = head & self.mask;
#[cfg(target_arch = "x86_64")]
unsafe {
_mm_prefetch((&self.buffer.inner()[index] as *const _ as *const i8), 0);
}
let ptr = self.buffer.inner()[index].load(Ordering::Acquire);
if ptr.is_null() {
return None;
}
self.head.store(head.wrapping_add(1), Ordering::Release);
self.buffer.inner()[index].store(core::ptr::null_mut(), Ordering::Relaxed);
let value = unsafe { Box::from_raw(ptr) };
self.generation.fetch_add(1, Ordering::Relaxed);
Some(*value)
}
#[inline]
pub fn len(&self) -> usize {
let head = self.head.load(Ordering::Acquire);
let tail = self.tail.load(Ordering::Acquire);
if tail >= head {
tail - head
} else {
self.capacity - (head - tail)
}
}
#[inline]
pub fn is_empty(&self) -> bool {
self.len() == 0
}
#[inline]
pub const fn capacity(&self) -> usize {
self.capacity
}
#[inline]
pub fn try_push(&self, value: T) -> Result<()> {
self.push(value)
}
#[inline]
pub fn try_pop(&self) -> Option<T> {
self.pop()
}
pub fn push_batch<I>(&self, values: I) -> usize
where
I: IntoIterator<Item = T>,
{
let mut pushed = 0;
for value in values {
match self.push(value) {
Ok(()) => pushed += 1,
Err(_) => break, }
}
pushed
}
pub fn pop_batch(&self, max_values: usize) -> Vec<T> {
let mut values = Vec::with_capacity(max_values);
for _ in 0..max_values {
match self.pop() {
Some(value) => values.push(value),
None => break, }
}
values
}
#[cfg(feature = "std")]
pub fn push_with_timeout(&self, value: T, timeout: Duration) -> Result<()>
where
T: Clone,
{
let start = std::time::Instant::now();
let mut backoff = Duration::from_nanos(100);
while start.elapsed() < timeout {
match self.push(value.clone()) {
Ok(()) => return Ok(()),
Err(Error::CapacityExceeded) => {
let elapsed = start.elapsed();
let remaining = timeout - elapsed;
if backoff > remaining {
break;
}
std::thread::sleep(backoff);
backoff = std::cmp::min(backoff * 2, Duration::from_millis(1));
}
Err(e) => return Err(e),
}
}
Err(Error::Timeout)
}
#[cfg(feature = "std")]
pub fn pop_with_timeout(&self, timeout: Duration) -> Option<T> {
let start = std::time::Instant::now();
let mut backoff = Duration::from_nanos(50);
while start.elapsed() < timeout {
if let Some(value) = self.pop() {
return Some(value);
}
let elapsed = start.elapsed();
let remaining = timeout - elapsed;
if backoff > remaining {
break;
}
std::thread::sleep(backoff);
backoff = std::cmp::min(backoff * 2, Duration::from_millis(1));
}
None
}
pub fn metrics(&self) -> QueueMetrics {
QueueMetrics {
capacity: self.capacity,
current_len: self.len(),
is_empty: self.is_empty(),
utilization_ratio: self.len() as f64 / self.capacity as f64,
}
}
}
/// Snapshot of queue utilization returned by [`MpmcQueue::metrics`].
#[derive(Debug, Clone)]
pub struct QueueMetrics {
    // Fixed slot count of the queue.
    pub capacity: usize,
    // Element count at the moment the snapshot was taken.
    pub current_len: usize,
    // Whether the queue was empty at snapshot time.
    pub is_empty: bool,
    // `current_len / capacity` in `[0.0, 1.0]`.
    pub utilization_ratio: f64,
}
#[cfg(feature = "unstable")]
// NOTE(review): this impl has several soundness problems that need a
// coordinated fix across `Node` (a `MaybeUninit<T>` value field) plus a new
// `Drop` impl; they are flagged inline below rather than patched piecemeal.
// Also: there is no `Drop` impl, so nodes still in the queue (and the
// sentinel) leak when the queue is dropped.
impl<T> UnboundedMpmcQueue<T> {
    /// Creates an empty queue seeded with a sentinel ("dummy") node.
    ///
    /// NOTE(review): `MaybeUninit::uninit().assume_init()` manufactures an
    /// uninitialized `T` — undefined behavior for any type with invalid bit
    /// patterns — and that value is later dropped when the sentinel is
    /// reclaimed in `pop`.
    pub fn new() -> Self {
        let guard = &epoch::pin();
        let dummy_shared = Owned::new(Node {
            value: unsafe { core::mem::MaybeUninit::uninit().assume_init() },
            next: Atomic::null(),
        })
        .into_shared(guard);
        Self {
            // Both ends start at the sentinel; the queue is logically empty.
            head: Atomic::from(dummy_shared),
            tail: Atomic::from(dummy_shared),
            size: AtomicUsize::new(0),
        }
    }
    /// Appends `value` at the tail; always returns `Ok(())`.
    ///
    /// NOTE(review): `tail.next` is overwritten with a plain `store`, so two
    /// concurrent producers that read the same tail silently lose (and leak)
    /// one node; `self.tail.store` is likewise unconditional. A Michael–Scott
    /// queue must CAS `next` from null and retry on failure.
    #[inline]
    pub fn push(&self, value: T) -> Result<()> {
        let guard = &epoch::pin();
        let new_node = Owned::new(Node {
            value,
            next: Atomic::null(),
        });
        let tail = self.tail.load(Ordering::Acquire, guard);
        unsafe {
            let tail_ref = tail.deref();
            tail_ref.next.store(new_node, Ordering::Release);
        }
        // Re-read `next` to obtain a Shared handle to the node just linked.
        let new_node_shared = unsafe { tail.deref().next.load(Ordering::Acquire, guard) };
        self.tail.store(new_node_shared, Ordering::Release);
        self.size.fetch_add(1, Ordering::Relaxed);
        Ok(())
    }
    /// Removes and returns the element behind the sentinel, or `None` when
    /// the queue is empty.
    ///
    /// NOTE(review): `head.store` is a plain store, so two consumers that
    /// load the same head can both `ptr::read` the same value — a double
    /// read and, via reclamation, a double drop. Additionally, the popped
    /// node becomes the new sentinel while still owning `value`; when a
    /// later `pop` calls `defer_destroy` on it, `value` is dropped again.
    #[inline]
    pub fn pop(&self) -> Option<T> {
        let guard = &epoch::pin();
        let head = self.head.load(Ordering::Acquire, guard);
        let next = unsafe { head.deref().next.load(Ordering::Acquire, guard) };
        if next.is_null() {
            return None;
        }
        let value = unsafe {
            let next_ref = next.deref();
            std::ptr::read(&next_ref.value)
        };
        self.head.store(next, Ordering::Release);
        self.size.fetch_sub(1, Ordering::Relaxed);
        unsafe {
            // Reclaim the old sentinel once no thread can still hold it.
            guard.defer_destroy(head);
        }
        Some(value)
    }
    /// Approximate element count (relaxed counter; may be stale).
    #[inline]
    pub fn len(&self) -> usize {
        self.size.load(Ordering::Relaxed)
    }
    /// Returns `true` when the approximate count is zero.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
    /// Alias for [`push`](Self::push); kept for API symmetry with the
    /// bounded queue.
    #[inline]
    pub fn try_push(&self, value: T) -> Result<()> {
        self.push(value)
    }
    /// Alias for [`pop`](Self::pop).
    #[inline]
    pub fn try_pop(&self) -> Option<T> {
        self.pop()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::thread;
    use std::time::Duration;
    use std::{format, vec};
    // Smoke test: push/pop/len/is_empty round trip on a fresh bounded queue.
    #[test]
    fn test_bounded_queue_basic_operations() {
        let queue: MpmcQueue<i32> = MpmcQueue::new(4);
        assert_eq!(queue.len(), 0);
        assert!(queue.is_empty());
        assert_eq!(queue.pop(), None);
        assert!(queue.push(1).is_ok());
        assert_eq!(queue.len(), 1);
        assert!(!queue.is_empty());
        assert_eq!(queue.pop(), Some(1));
        assert_eq!(queue.len(), 0);
        assert!(queue.is_empty());
    }
    // Capacities round up to the next power of two; exact powers are kept.
    #[test]
    fn test_bounded_queue_capacity_rounding() {
        let queue: MpmcQueue<i32> = MpmcQueue::new(5);
        assert_eq!(queue.capacity(), 8);
        let queue: MpmcQueue<i32> = MpmcQueue::new(16);
        assert_eq!(queue.capacity(), 16);
    }
    // A capacity-2 queue must accept exactly two items, reject the third,
    // and accept again after a pop frees a slot.
    #[test]
    fn test_bounded_queue_full_behavior() {
        let queue: MpmcQueue<i32> = MpmcQueue::new(2);
        assert!(queue.push(1).is_ok());
        assert!(queue.push(2).is_ok());
        assert_eq!(queue.len(), 2);
        assert!(queue.push(3).is_err());
        assert_eq!(queue.pop(), Some(1));
        assert!(queue.push(3).is_ok());
        assert_eq!(queue.pop(), Some(2));
        assert_eq!(queue.pop(), Some(3));
        assert_eq!(queue.pop(), None);
    }
    // Exercises index wrap-around: more operations than the ring has slots.
    #[test]
    fn test_bounded_queue_wrap_around() {
        let queue: MpmcQueue<i32> = MpmcQueue::new(4);
        for i in 0..10 {
            assert!(queue.push(i).is_ok());
            assert_eq!(queue.pop(), Some(i));
        }
        assert_eq!(queue.len(), 0);
        assert!(queue.is_empty());
    }
    // Elements must come out in the order they went in.
    #[test]
    fn test_bounded_queue_fifo_ordering() {
        let queue: MpmcQueue<i32> = MpmcQueue::new(10);
        for i in 0..5 {
            assert!(queue.push(i).is_ok());
        }
        for i in 0..5 {
            assert_eq!(queue.pop(), Some(i));
        }
        assert_eq!(queue.pop(), None);
    }
    // 4 producers and 4 consumers; producers spin-retry on full, consumers
    // spin until each has received its share. All items must be drained.
    #[test]
    fn test_bounded_queue_concurrent_access() {
        let queue = Arc::new(MpmcQueue::new(1000));
        let num_producers = 4;
        let num_consumers = 4;
        let items_per_producer = 1000;
        let mut handles = vec![];
        for producer_id in 0..num_producers {
            let queue = Arc::clone(&queue);
            let handle = thread::spawn(move || {
                for i in 0..items_per_producer {
                    let value = producer_id * items_per_producer + i;
                    while queue.push(value).is_err() {
                        thread::yield_now();
                    }
                }
            });
            handles.push(handle);
        }
        for _ in 0..num_consumers {
            let queue = Arc::clone(&queue);
            let handle = thread::spawn(move || {
                let mut received = 0;
                // Each consumer takes exactly 1/num_consumers of the total.
                while received < items_per_producer * num_producers / num_consumers {
                    if queue.pop().is_some() {
                        received += 1;
                    } else {
                        thread::yield_now();
                    }
                }
            });
            handles.push(handle);
        }
        for handle in handles {
            handle.join().unwrap();
        }
        assert!(queue.is_empty());
    }
    // Smoke test for the unbounded queue (unstable feature only).
    #[cfg(feature = "unstable")]
    #[test]
    fn test_unbounded_queue_basic_operations() {
        let queue: UnboundedMpmcQueue<i32> = UnboundedMpmcQueue::new();
        assert_eq!(queue.len(), 0);
        assert!(queue.is_empty());
        assert_eq!(queue.pop(), None);
        assert!(queue.push(1).is_ok());
        assert_eq!(queue.len(), 1);
        assert!(!queue.is_empty());
        assert_eq!(queue.pop(), Some(1));
        assert_eq!(queue.len(), 0);
        assert!(queue.is_empty());
    }
    // FIFO order on the unbounded queue.
    #[cfg(feature = "unstable")]
    #[test]
    fn test_unbounded_queue_fifo_ordering() {
        let queue: UnboundedMpmcQueue<i32> = UnboundedMpmcQueue::new();
        for i in 0..10 {
            assert!(queue.push(i).is_ok());
        }
        for i in 0..10 {
            assert_eq!(queue.pop(), Some(i));
        }
        assert_eq!(queue.pop(), None);
    }
    // Concurrent producers/consumers on the unbounded queue, then a
    // drop-accounting check with a drop-counting payload type.
    // NOTE(review): the final assertion expects all 100 DropTracker values
    // to be dropped after `drop(queue)`, but the current implementation has
    // no `Drop` impl for the queue — verify this test's expectation against
    // the intended reclamation behavior.
    #[cfg(feature = "unstable")]
    #[test]
    fn test_unbounded_queue_concurrent_access() {
        let queue = Arc::new(UnboundedMpmcQueue::new());
        let num_producers = 4;
        let num_consumers = 4;
        let items_per_producer = 1000;
        let mut handles = vec![];
        for producer_id in 0..num_producers {
            let queue = Arc::clone(&queue);
            let handle = thread::spawn(move || {
                for i in 0..items_per_producer {
                    let value = producer_id * items_per_producer + i;
                    queue.push(value).unwrap();
                }
            });
            handles.push(handle);
        }
        for _ in 0..num_consumers {
            let queue = Arc::clone(&queue);
            let handle = thread::spawn(move || {
                let mut received = 0;
                while received < items_per_producer * num_producers / num_consumers {
                    if queue.pop().is_some() {
                        received += 1;
                    } else {
                        thread::yield_now();
                    }
                }
            });
            handles.push(handle);
        }
        for handle in handles {
            handle.join().unwrap();
        }
        use std::sync::atomic::{AtomicUsize, Ordering};
        static DROP_COUNT: AtomicUsize = AtomicUsize::new(0);
        // Counts how many times a value of this type has been dropped.
        #[derive(Debug, PartialEq, Eq, Hash)]
        struct DropTracker {
            id: usize,
        }
        impl Drop for DropTracker {
            fn drop(&mut self) {
                DROP_COUNT.fetch_add(1, Ordering::Relaxed);
            }
        }
        let queue: UnboundedMpmcQueue<DropTracker> = UnboundedMpmcQueue::new();
        for i in 0..100 {
            queue.push(DropTracker { id: i }).unwrap();
        }
        for _ in 0..50 {
            queue.pop();
        }
        drop(queue);
        let dropped_items = DROP_COUNT.load(Ordering::Relaxed);
        assert_eq!(dropped_items, 100);
    }
    // Mixed push/pop workload from 8 threads; afterwards the queue must
    // drain to empty with no stuck elements.
    #[test]
    fn test_bounded_queue_stress() {
        let queue = Arc::new(MpmcQueue::new(1000));
        let num_threads = 8;
        let operations_per_thread = 10000;
        let mut handles = vec![];
        for thread_id in 0..num_threads {
            let queue = Arc::clone(&queue);
            let handle = thread::spawn(move || {
                for i in 0..operations_per_thread {
                    let value = thread_id * operations_per_thread + i;
                    match i % 3 {
                        0 => {
                            while queue.push(value).is_err() {
                                thread::yield_now();
                            }
                        }
                        1 => {
                            let _ = queue.pop();
                        }
                        2 => {
                            if queue.push(value).is_ok() {
                                let _ = queue.pop();
                            }
                        }
                        _ => unreachable!(),
                    }
                }
            });
            handles.push(handle);
        }
        for handle in handles {
            handle.join().unwrap();
        }
        while queue.pop().is_some() {}
        assert!(queue.is_empty());
    }
    // Same mixed workload against the unbounded queue (push never fails).
    #[cfg(feature = "unstable")]
    #[test]
    fn test_unbounded_queue_stress() {
        let queue = Arc::new(UnboundedMpmcQueue::new());
        let num_threads = 8;
        let operations_per_thread = 10000;
        let mut handles = vec![];
        for thread_id in 0..num_threads {
            let queue = Arc::clone(&queue);
            let handle = thread::spawn(move || {
                for i in 0..operations_per_thread {
                    let value = thread_id * operations_per_thread + i;
                    match i % 3 {
                        0 => {
                            queue.push(value).unwrap();
                        }
                        1 => {
                            let _ = queue.pop();
                        }
                        2 => {
                            queue.push(value).unwrap();
                            let _ = queue.pop();
                        }
                        _ => unreachable!(),
                    }
                }
            });
            handles.push(handle);
        }
        for handle in handles {
            handle.join().unwrap();
        }
        while queue.pop().is_some() {}
        assert!(queue.is_empty());
    }
    // Invariant check: len() always equals successful pushes minus
    // successful pops under single-threaded interleaving.
    #[test]
    fn test_bounded_queue_properties() {
        let queue: MpmcQueue<i32> = MpmcQueue::new(10);
        let mut pushed = 0;
        let mut popped = 0;
        for i in 0..100 {
            if queue.push(i).is_ok() {
                pushed += 1;
            }
            if i % 2 == 0 && queue.pop().is_some() {
                popped += 1;
            }
            assert_eq!(queue.len(), pushed - popped);
        }
    }
    // Same invariant for the unbounded queue (every push succeeds).
    #[cfg(feature = "unstable")]
    #[test]
    fn test_unbounded_queue_properties() {
        let queue: UnboundedMpmcQueue<i32> = UnboundedMpmcQueue::new();
        let mut pushed = 0;
        let mut popped = 0;
        for i in 0..100 {
            queue.push(i).unwrap();
            pushed += 1;
            if i % 2 == 0 && queue.pop().is_some() {
                popped += 1;
            }
            assert_eq!(queue.len(), pushed - popped);
        }
    }
    // Cache-line alignment assertions.
    // NOTE(review): neither queue struct carries `repr(align)` itself; the
    // bounded queue's 64-byte alignment comes from its `CachePadded` fields,
    // but `UnboundedMpmcQueue`'s fields are plain `Atomic`s — confirm the
    // unbounded assertion actually holds.
    #[test]
    fn test_cache_alignment() {
        use core::mem;
        assert_eq!(mem::align_of::<MpmcQueue<i32>>(), 64);
        #[cfg(feature = "unstable")]
        {
            assert_eq!(mem::align_of::<UnboundedMpmcQueue<i32>>(), 64);
            assert_eq!(mem::align_of::<Node<i32>>(), 64);
        }
    }
    // Debug formatting should include the type names.
    #[test]
    fn test_debug_format() {
        let bounded: MpmcQueue<i32> = MpmcQueue::new(10);
        let debug_str = format!("{:?}", bounded);
        assert!(debug_str.contains("MpmcQueue"));
        #[cfg(feature = "unstable")]
        {
            let unbounded: UnboundedMpmcQueue<i32> = UnboundedMpmcQueue::new();
            let debug_str = format!("{:?}", unbounded);
            assert!(debug_str.contains("UnboundedMpmcQueue"));
        }
    }
}
/// Explicitly named alias for the bounded queue, mirroring
/// [`UnboundedMpmcQueue`]'s naming.
pub use MpmcQueue as BoundedMpmcQueue;