use crate::util::CachePadded;
use crate::{Error, Result};
use core::sync::atomic::{AtomicIsize, AtomicUsize, Ordering};
use crate::metrics::MetricsCollector;
#[cfg(feature = "std")]
use std::boxed::Box;
#[cfg(feature = "std")]
use std::vec::Vec;
/// Fixed-capacity double-ended work queue.
///
/// The owner pushes and pops at the bottom end (LIFO); `steal` removes from
/// the top end (FIFO). Elements live in a ring buffer of `capacity` slots,
/// addressed by masking the monotonically increasing `bottom`/`top` indices.
/// All mutating operations take `&mut self`.
#[derive(Debug)]
pub struct WorkStealingDeque<T> {
/// Ring buffer of `capacity` slots; `None` marks an empty slot.
buffer: CachePadded<Box<[Option<T>]>>,
/// Number of slots; the constructor rounds the requested size up to a power of two.
capacity: usize,
/// `capacity - 1`; `index & mask` maps a logical index onto a buffer slot.
mask: usize,
/// Owner-end index: the next `push` writes to slot `bottom & mask`.
bottom: CachePadded<AtomicIsize>,
/// Steal-end index: the oldest stored element sits at slot `top & mask`.
top: CachePadded<AtomicIsize>,
// Initialised to 0 by the constructor but never read or written by any
// visible method, hence the lint allowance. NOTE(review): presumably reserved
// for future reclamation logic; confirm before removing.
#[allow(dead_code)]
epoch: CachePadded<AtomicUsize>,
}
impl<T> Clone for WorkStealingDeque<T> {
fn clone(&self) -> Self {
let mut new_deque = WorkStealingDeque::<T>::with_capacity(self.capacity);
new_deque.capacity = self.capacity;
new_deque.mask = self.mask;
new_deque
}
}
impl<T> WorkStealingDeque<T> {
pub fn new(capacity: usize) -> Self {
Self::with_capacity(capacity)
}
fn with_capacity(capacity: usize) -> Self {
assert!(capacity > 0, "Deque capacity must be greater than 0");
let capacity = if capacity.is_power_of_two() {
capacity
} else {
capacity.next_power_of_two()
};
let mask = capacity - 1;
let mut buffer = Vec::with_capacity(capacity);
buffer.resize_with(capacity, || None);
let buffer = buffer.into_boxed_slice();
Self {
buffer: CachePadded::new(buffer),
capacity,
mask,
bottom: CachePadded::new(std::sync::atomic::AtomicIsize::new(0)),
top: CachePadded::new(std::sync::atomic::AtomicIsize::new(0)),
epoch: CachePadded::new(std::sync::atomic::AtomicUsize::new(0)),
}
}
#[inline]
pub fn push(&mut self, value: T) -> Result<()> {
let bottom = self.bottom.get().load(Ordering::Relaxed);
let top = self.top.get().load(Ordering::Acquire);
if bottom - top >= self.capacity as isize {
return Err(Error::WouldBlock);
}
let index = (bottom as usize) & self.mask;
self.buffer.inner_mut()[index] = Some(value);
self.bottom.get().store(bottom + 1, Ordering::Release);
Ok(())
}
#[inline]
pub fn pop(&mut self) -> Option<T> {
let bottom = self.bottom.get().load(Ordering::Relaxed);
if bottom == 0 {
return None;
}
self.bottom.get().store(bottom - 1, Ordering::Relaxed);
let top = self.top.get().load(Ordering::Acquire);
if top < bottom {
let index = ((bottom - 1) as usize) & self.mask;
let value = self.buffer.inner_mut()[index].take();
if top == bottom - 1 {
if self
.top
.get()
.compare_exchange(top, top + 1, Ordering::Release, Ordering::Relaxed)
.is_ok()
{
self.bottom.get().store(bottom, Ordering::Relaxed);
return value;
} else {
self.bottom.get().store(bottom, Ordering::Relaxed);
return None;
}
}
value
} else {
self.bottom.get().store(bottom, Ordering::Relaxed);
None
}
}
#[inline]
pub fn steal(&mut self) -> Option<T> {
let top = self.top.get().load(Ordering::Acquire);
let bottom = self.bottom.get().load(Ordering::Acquire);
if top >= bottom {
return None;
}
let index = (top as usize) & self.mask;
if let Some(value) = self.buffer.inner_mut()[index].take() {
if self
.top
.get()
.compare_exchange(top, top + 1, Ordering::Release, Ordering::Relaxed)
.is_ok()
{
Some(value)
} else {
self.buffer.inner_mut()[index] = Some(value);
None
}
} else {
None
}
}
#[inline]
pub fn len(&self) -> usize {
let bottom = self.bottom.get().load(Ordering::Acquire);
let top = self.top.get().load(Ordering::Acquire);
(bottom - top).max(0) as usize
}
#[inline]
pub fn is_empty(&self) -> bool {
let bottom = self.bottom.get().load(Ordering::Acquire);
let top = self.top.get().load(Ordering::Acquire);
bottom == top
}
#[inline]
pub const fn capacity(&self) -> usize {
self.capacity
}
#[inline]
pub fn try_push(&mut self, value: T) -> Result<()> {
self.push(value)
}
#[inline]
pub fn try_pop(&mut self) -> Option<T> {
self.pop()
}
#[inline]
pub fn try_steal(&mut self) -> Option<T> {
self.steal()
}
}
#[cfg(feature = "std")]
impl<T> MetricsCollector for WorkStealingDeque<T> {
    /// The deque records no metrics, so this always reports the default snapshot.
    fn metrics(&self) -> crate::metrics::PerformanceMetrics {
        crate::metrics::PerformanceMetrics::default()
    }

    /// No-op: there are no counters to clear.
    fn reset_metrics(&self) {}

    /// No-op: collection cannot be switched on for this type; the flag is ignored.
    fn set_metrics_enabled(&self, _enabled: bool) {}

    /// Always `false`, because collection is never enabled.
    fn is_metrics_enabled(&self) -> bool {
        false
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::format;

    #[test]
    fn test_basic_operations() {
        let mut deque: WorkStealingDeque<i32> = WorkStealingDeque::new(4);
        assert_eq!(deque.len(), 0);
        assert!(deque.is_empty());
        assert_eq!(deque.pop(), None);
        assert_eq!(deque.steal(), None);
        assert!(deque.push(1).is_ok());
        assert_eq!(deque.len(), 1);
        assert!(!deque.is_empty());
        assert_eq!(deque.pop(), Some(1));
        assert_eq!(deque.len(), 0);
        assert!(deque.is_empty());
    }

    #[test]
    fn test_lifo_behavior() {
        let mut deque: WorkStealingDeque<i32> = WorkStealingDeque::new(4);
        assert!(deque.push(1).is_ok());
        assert!(deque.push(2).is_ok());
        assert!(deque.push(3).is_ok());
        assert_eq!(deque.pop(), Some(3));
        assert_eq!(deque.pop(), Some(2));
        assert_eq!(deque.pop(), Some(1));
        assert_eq!(deque.pop(), None);
    }

    #[test]
    fn test_fifo_stealing() {
        let mut deque: WorkStealingDeque<i32> = WorkStealingDeque::new(4);
        assert!(deque.push(1).is_ok());
        assert!(deque.push(2).is_ok());
        assert!(deque.push(3).is_ok());
        assert_eq!(deque.steal(), Some(1));
        assert_eq!(deque.steal(), Some(2));
        assert_eq!(deque.steal(), Some(3));
        assert_eq!(deque.steal(), None);
    }

    #[test]
    fn test_mixed_operations() {
        let mut deque: WorkStealingDeque<i32> = WorkStealingDeque::new(4);
        assert!(deque.push(1).is_ok());
        assert!(deque.push(2).is_ok());
        assert!(deque.push(3).is_ok());
        assert_eq!(deque.pop(), Some(3));
        assert_eq!(deque.steal(), Some(1));
        // Popping the last remaining element goes through the `top` CAS path.
        assert_eq!(deque.pop(), Some(2));
        assert_eq!(deque.pop(), None);
        assert_eq!(deque.steal(), None);
    }

    #[test]
    fn test_full_deque() {
        let mut deque: WorkStealingDeque<i32> = WorkStealingDeque::new(2);
        assert!(deque.push(1).is_ok());
        assert!(deque.push(2).is_ok());
        assert_eq!(deque.len(), 2);
        assert!(deque.push(3).is_err());
        assert_eq!(deque.pop(), Some(2));
        assert!(deque.push(3).is_ok());
        assert_eq!(deque.pop(), Some(3));
        assert_eq!(deque.pop(), Some(1));
    }

    #[test]
    fn test_wrap_around() {
        let mut deque: WorkStealingDeque<i32> = WorkStealingDeque::new(4);
        for i in 0..10 {
            assert!(deque.push(i).is_ok());
            assert_eq!(deque.pop(), Some(i));
        }
    }

    #[test]
    fn test_capacity_rounds_up_to_power_of_two() {
        assert_eq!(WorkStealingDeque::<i32>::new(1).capacity(), 1);
        assert_eq!(WorkStealingDeque::<i32>::new(3).capacity(), 4);
        assert_eq!(WorkStealingDeque::<i32>::new(4).capacity(), 4);
        assert_eq!(WorkStealingDeque::<i32>::new(5).capacity(), 8);
    }

    #[test]
    #[should_panic(expected = "Deque capacity must be greater than 0")]
    fn test_zero_capacity_panics() {
        let _ = WorkStealingDeque::<i32>::new(0);
    }

    #[test]
    fn test_rounded_capacity_is_usable() {
        // A request for 3 slots is rounded to 4, and all 4 are usable.
        let mut deque: WorkStealingDeque<i32> = WorkStealingDeque::new(3);
        for i in 0..4 {
            assert!(deque.push(i).is_ok());
        }
        assert!(deque.push(4).is_err());
        assert_eq!(deque.len(), 4);
    }

    #[test]
    fn test_clone_is_empty_with_same_capacity() {
        let mut deque: WorkStealingDeque<i32> = WorkStealingDeque::new(3);
        assert!(deque.push(7).is_ok());
        let mut copy = deque.clone();
        assert!(copy.is_empty());
        assert_eq!(copy.capacity(), deque.capacity());
        assert_eq!(copy.pop(), None);
        // The original is unaffected by cloning.
        assert_eq!(deque.pop(), Some(7));
    }

    #[test]
    fn test_try_aliases() {
        let mut deque: WorkStealingDeque<i32> = WorkStealingDeque::new(2);
        assert!(deque.try_push(1).is_ok());
        assert!(deque.try_push(2).is_ok());
        assert!(deque.try_push(3).is_err());
        assert_eq!(deque.try_steal(), Some(1));
        assert_eq!(deque.try_pop(), Some(2));
        assert_eq!(deque.try_pop(), None);
        assert_eq!(deque.try_steal(), None);
    }

    #[test]
    fn test_cache_alignment() {
        use core::mem;
        assert_eq!(mem::align_of::<WorkStealingDeque<i32>>(), 64);
    }

    #[test]
    fn test_debug_format() {
        let deque: WorkStealingDeque<i32> = WorkStealingDeque::new(4);
        let debug_str = format!("{:?}", deque);
        assert!(debug_str.contains("WorkStealingDeque"));
        assert!(debug_str.contains("capacity"));
    }
}