use crossbeam_deque::{Injector, Steal, Stealer, Worker};
use std::sync::{Arc, Mutex};
/// State shared by every clone of a [`WorkQueue`] handle.
struct Inner<T> {
    /// Global queue that `push` feeds; consumers pull batches out of it.
    injector: Injector<T>,
    /// Steal handles, created alongside (and index-aligned with) `workers`.
    stealers: Vec<Stealer<T>>,
    /// Local deques, each behind a mutex so any consumer thread may use it
    /// (crossbeam's `Worker` is not `Sync`, so it cannot be shared bare).
    workers: Vec<Mutex<Worker<T>>>,
    /// Approximate number of queued tasks. Signed because a steal may
    /// decrement before the matching push has incremented; readers clamp at 0.
    len: std::sync::atomic::AtomicIsize,
}
/// A handle to a shared multi-producer, multi-consumer work queue.
///
/// Cloning a handle is cheap: every clone points at the same queue through an
/// [`Arc`], so a task pushed through one handle can be stolen through any other.
pub struct WorkQueue<T: Send + 'static> {
    inner: Arc<Inner<T>>,
}

// Implemented by hand instead of `#[derive(Clone)]`: the derive would add a
// `T: Clone` bound, yet cloning the handle only bumps the `Arc` refcount and
// must also work for non-cloneable tasks such as boxed closures.
impl<T: Send + 'static> Clone for WorkQueue<T> {
    fn clone(&self) -> Self {
        Self {
            inner: Arc::clone(&self.inner),
        }
    }
}
impl<T: Send + 'static> WorkQueue<T> {
#[must_use]
pub fn new(workers: usize) -> Self {
let num = workers.max(1);
let injector = Injector::new();
let mut worker_deques = Vec::with_capacity(num);
let mut stealers = Vec::with_capacity(num);
for _ in 0..num {
let w: Worker<T> = Worker::new_fifo();
stealers.push(w.stealer());
worker_deques.push(Mutex::new(w));
}
Self {
inner: Arc::new(Inner {
injector,
stealers,
workers: worker_deques,
len: std::sync::atomic::AtomicIsize::new(0),
}),
}
}
pub fn push(&self, task: T) {
self.inner.injector.push(task);
self.inner
.len
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
}
pub fn steal(&self) -> Option<T> {
if let Ok(guard) = self.inner.workers[0].lock() {
loop {
match self.inner.injector.steal_batch_and_pop(&guard) {
Steal::Success(v) => {
self.inner
.len
.fetch_sub(1, std::sync::atomic::Ordering::Relaxed);
return Some(v);
}
Steal::Retry => continue,
Steal::Empty => break,
}
}
}
for w_mutex in &self.inner.workers {
if let Ok(guard) = w_mutex.lock() {
if let Some(item) = guard.pop() {
self.inner
.len
.fetch_sub(1, std::sync::atomic::Ordering::Relaxed);
return Some(item);
}
}
}
for stealer in &self.inner.stealers {
loop {
match stealer.steal() {
Steal::Success(v) => {
self.inner
.len
.fetch_sub(1, std::sync::atomic::Ordering::Relaxed);
return Some(v);
}
Steal::Retry => continue,
Steal::Empty => break,
}
}
}
None
}
#[must_use]
pub fn len(&self) -> usize {
let v = self.inner.len.load(std::sync::atomic::Ordering::Relaxed);
v.max(0) as usize
}
#[must_use]
pub fn is_empty(&self) -> bool {
self.len() == 0
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;

    /// Steals from `wq` until more than `max_misses` consecutive attempts come
    /// back empty, returning every task obtained (in the order received).
    fn drain<T: Send + 'static>(wq: &WorkQueue<T>, max_misses: usize) -> Vec<T> {
        let mut taken = Vec::new();
        let mut misses = 0usize;
        while misses <= max_misses {
            match wq.steal() {
                Some(task) => {
                    taken.push(task);
                    misses = 0;
                }
                None => {
                    misses += 1;
                    std::hint::spin_loop();
                }
            }
        }
        taken
    }

    #[test]
    fn push_and_steal_basic() {
        let wq = WorkQueue::<u32>::new(1);
        wq.push(10_u32);
        wq.push(20_u32);
        let mut got = [
            wq.steal().expect("first task should be available"),
            wq.steal().expect("second task should be available"),
        ];
        got.sort_unstable();
        assert_eq!(got, [10, 20]);
        assert!(wq.steal().is_none());
        assert_eq!(wq.len(), 0);
    }

    #[test]
    fn steal_empty_returns_none() {
        let wq = WorkQueue::<u32>::new(2);
        assert!(wq.steal().is_none());
    }

    #[test]
    fn zero_workers_is_treated_as_one() {
        let wq = WorkQueue::<u32>::new(0);
        wq.push(5_u32);
        assert_eq!(wq.steal(), Some(5));
        assert!(wq.steal().is_none());
    }

    #[test]
    fn len_tracks_count() {
        let wq = WorkQueue::<u32>::new(2);
        assert_eq!(wq.len(), 0);
        wq.push(1_u32);
        assert_eq!(wq.len(), 1);
        wq.push(2_u32);
        assert_eq!(wq.len(), 2);
        assert!(wq.steal().is_some());
        assert_eq!(wq.len(), 1);
    }

    #[test]
    fn is_empty_basic() {
        let wq = WorkQueue::<u32>::new(2);
        assert!(wq.is_empty());
        wq.push(1_u32);
        assert!(!wq.is_empty());
    }

    #[test]
    fn clone_shares_state() {
        let wq = WorkQueue::<u32>::new(2);
        let wq2 = wq.clone();
        wq.push(99_u32);
        let stolen = wq2.steal();
        assert_eq!(stolen, Some(99_u32));
    }

    #[test]
    fn threaded_stress_10000_tasks() {
        const TASKS: u32 = 10_000;
        const WORKERS: usize = 4;
        let wq = WorkQueue::<u32>::new(WORKERS);
        for i in 0..TASKS {
            wq.push(i);
        }
        let handles: Vec<_> = (0..WORKERS)
            .map(|_| {
                let wq = wq.clone();
                thread::spawn(move || drain(&wq, 200))
            })
            .collect();
        let mut all: Vec<u32> = handles
            .into_iter()
            .flat_map(|h| h.join().expect("worker thread panicked"))
            .collect();
        all.sort_unstable();
        let total = all.len();
        assert_eq!(
            total, TASKS as usize,
            "expected all {TASKS} tasks to be consumed, got {total}"
        );
        // Count alone would miss a task delivered twice while another is lost.
        assert!(
            all.iter().copied().eq(0..TASKS),
            "every task must be consumed exactly once"
        );
        assert!(wq.is_empty());
    }

    #[test]
    fn multi_producer_multi_consumer() {
        const PER_PRODUCER: usize = 1_000;
        const PRODUCERS: usize = 4;
        const CONSUMERS: usize = 4;
        const TOTAL: usize = PER_PRODUCER * PRODUCERS;
        let wq = WorkQueue::<usize>::new(CONSUMERS);
        let producers: Vec<_> = (0..PRODUCERS)
            .map(|p| {
                let wq = wq.clone();
                thread::spawn(move || {
                    for i in 0..PER_PRODUCER {
                        wq.push(p * PER_PRODUCER + i);
                    }
                })
            })
            .collect();
        for h in producers {
            h.join().expect("producer panicked");
        }
        let consumers: Vec<_> = (0..CONSUMERS)
            .map(|_| {
                let wq = wq.clone();
                thread::spawn(move || drain(&wq, 500))
            })
            .collect();
        let mut all: Vec<usize> = consumers
            .into_iter()
            .flat_map(|h| h.join().expect("consumer panicked"))
            .collect();
        all.sort_unstable();
        assert_eq!(all.len(), TOTAL);
        assert!(
            all.iter().copied().eq(0..TOTAL),
            "every produced task must be consumed exactly once"
        );
    }
}