#![allow(dead_code)]
#![allow(clippy::module_name_repetitions)]
use std::sync::atomic::{AtomicI64, Ordering};
use std::sync::{Arc, Mutex};
/// Outcome of a single attempt to steal an item from another worker's deque.
#[derive(PartialEq, Eq, Debug)]
pub enum StealResult<T> {
    /// An item was taken from the victim's top (oldest) end.
    Success(T),
    /// No item was visible to the thief.
    Empty,
    /// The attempt did not complete (for example, it lost a race for the candidate item);
    /// the deque may still hold work, so trying again can succeed.
    Retry,
    /// An out-of-range worker index was supplied to the pool.
    InvalidIndex,
}
/// Ring of optional slots addressed by logical `i64` indices.
///
/// A logical index maps to physical slot `index & (capacity - 1)`, which is why
/// `capacity` is always kept a power of two.
struct CircularBuffer<T> {
    slots: Vec<Option<T>>,
    capacity: usize,
}

impl<T: Clone> CircularBuffer<T> {
    /// Builds an all-empty buffer with room for at least `cap` items
    /// (rounded up to a power of two, never fewer than 2).
    fn with_capacity(cap: usize) -> Self {
        let capacity = std::cmp::max(cap.next_power_of_two(), 2);
        let slots = std::iter::repeat_with(|| None).take(capacity).collect();
        Self { slots, capacity }
    }

    /// Bit mask that maps a logical index onto a physical slot position.
    #[inline]
    fn mask(&self) -> usize {
        self.capacity - 1
    }

    /// Physical slot position for logical `index`.
    fn slot_of(&self, index: i64) -> usize {
        (index as usize) & self.mask()
    }

    /// Returns a clone of the value at logical `index`, or `None` if that slot is empty.
    fn get(&self, index: i64) -> Option<T> {
        let pos = self.slot_of(index);
        self.slots.get(pos).cloned().flatten()
    }

    /// Overwrites the slot for logical `index` with `value` (dropping whatever was there).
    fn store(&mut self, index: i64, value: Option<T>) {
        let pos = self.slot_of(index);
        if let Some(cell) = self.slots.get_mut(pos) {
            *cell = value;
        }
    }

    /// Writes `val` into the slot for logical `index`.
    fn put(&mut self, index: i64, val: T) {
        self.store(index, Some(val));
    }

    /// Empties the slot for logical `index`.
    fn clear_slot(&mut self, index: i64) {
        self.store(index, None);
    }

    /// Returns a new buffer of twice the capacity holding clones of the live range
    /// `top..bottom`, each re-mapped to its slot under the larger mask.
    fn grow(&self, top: i64, bottom: i64) -> Self {
        let doubled = self.capacity * 2;
        let mut bigger = Self {
            slots: std::iter::repeat_with(|| None).take(doubled).collect(),
            capacity: doubled,
        };
        for logical in top..bottom {
            let from = self.slot_of(logical);
            let to = bigger.slot_of(logical);
            if let (Some(dst), Some(src)) = (bigger.slots.get_mut(to), self.slots.get(from)) {
                *dst = src.clone();
            }
        }
        bigger
    }
}
/// A Chase–Lev-style work-stealing deque.
///
/// The owning worker pushes and pops at the *bottom* end (LIFO order); other threads
/// [`steal`](Self::steal) from the *top* end (FIFO order). The `top`/`bottom` indices are
/// coordinated with atomics and fences, while slot storage sits behind a `Mutex` so a
/// buffer resize can never race with a reader.
///
/// Contract: `push` and `pop` update `bottom` with a plain load/store pair (not an atomic
/// read-modify-write), so they must only be called from one owner thread at a time.
/// `steal`, `len`, `is_empty` and `capacity` may be called from any thread.
pub struct WorkStealDeque<T: Clone + Send + 'static> {
    /// Logical index of the oldest live item; only ever advanced, via CAS.
    top: AtomicI64,
    /// Logical index one past the newest live item; written only by the owner.
    bottom: AtomicI64,
    /// Slot storage, replaced wholesale when the deque grows.
    buf: Mutex<CircularBuffer<T>>,
}

impl<T: Clone + Send + 'static> WorkStealDeque<T> {
    /// Creates an empty deque whose buffer holds at least `initial_cap` items before it
    /// has to grow.
    #[must_use]
    pub fn new(initial_cap: usize) -> Self {
        Self {
            top: AtomicI64::new(0),
            bottom: AtomicI64::new(0),
            buf: Mutex::new(CircularBuffer::with_capacity(initial_cap)),
        }
    }

    /// Locks the slot buffer, recovering the guard if a previous holder panicked.
    ///
    /// Every critical section leaves the buffer usable: a grown buffer is fully built
    /// before it is installed, and all other writes replace a single slot. Propagating
    /// poison instead would leave `pop` with `bottom` already decremented (so a later
    /// `push` overwrites a live item) and make every `steal` report `Retry` forever.
    fn lock_buf(&self) -> std::sync::MutexGuard<'_, CircularBuffer<T>> {
        self.buf
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Moves the item at logical `index` out of the buffer and empties its slot, so a
    /// consumed item is dropped now instead of lingering until the slot is reused.
    fn take_slot(buf: &mut CircularBuffer<T>, index: i64) -> Option<T> {
        let item = buf.get(index);
        buf.clear_slot(index);
        item
    }

    /// Number of items currently in the deque (a racy snapshot while other threads run).
    #[must_use]
    pub fn len(&self) -> usize {
        let b = self.bottom.load(Ordering::Relaxed);
        let t = self.top.load(Ordering::Relaxed);
        // `bottom` can briefly dip below `top` while `pop` runs; report that as empty.
        usize::try_from(b - t).unwrap_or(0)
    }

    /// Whether [`len`](Self::len) is zero.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Current number of slots in the underlying buffer.
    #[must_use]
    pub fn capacity(&self) -> usize {
        self.lock_buf().capacity
    }

    /// Pushes `item` onto the bottom (owner end), doubling the buffer when it is full.
    ///
    /// Must only be called by the owning thread.
    ///
    /// # Errors
    ///
    /// Does not fail in this implementation; the `Result` is kept for API compatibility.
    pub fn push(&self, item: T) -> Result<(), T> {
        let b = self.bottom.load(Ordering::Relaxed);
        let t = self.top.load(Ordering::Acquire);
        let mut guard = self.lock_buf();
        // `t` may be stale, but thieves only ever advance `top`, so `b - t` can only
        // over-estimate the live count: we may grow early but never overwrite a live slot.
        if b - t >= guard.capacity as i64 {
            let grown = guard.grow(t, b);
            *guard = grown;
        }
        guard.put(b, item);
        drop(guard);
        // Release publishes the slot write to thieves that Acquire-load `bottom`.
        self.bottom.store(b + 1, Ordering::Release);
        Ok(())
    }

    /// Pops the most recently pushed item from the bottom (owner end).
    ///
    /// Must only be called by the owning thread. Returns `None` when the deque is empty or
    /// when a concurrent thief wins the race for the last remaining item.
    pub fn pop(&self) -> Option<T> {
        // Reserve index `b` *before* reading `top`; the SeqCst fence pairs with the one in
        // `steal` so the owner and a thief cannot both miss each other's claim.
        let b = self.bottom.load(Ordering::Relaxed) - 1;
        self.bottom.store(b, Ordering::Relaxed);
        std::sync::atomic::fence(Ordering::SeqCst);
        let t = self.top.load(Ordering::Relaxed);

        if b < t {
            // The deque was already empty: undo the reservation.
            self.bottom.store(b + 1, Ordering::Relaxed);
            return None;
        }

        let mut guard = self.lock_buf();
        if b > t {
            // Other items remain above `top`, so no thief can claim index `b`.
            return Self::take_slot(&mut guard, b);
        }

        // Exactly one item left: race the thieves for it by advancing `top`. Thieves also
        // CAS `top` only while holding the buffer lock, so whoever wins takes the slot
        // before anyone else can touch it.
        let won = self
            .top
            .compare_exchange(t, t + 1, Ordering::SeqCst, Ordering::Relaxed)
            .is_ok();
        let item = if won { Self::take_slot(&mut guard, b) } else { None };
        // Either way `top` is now `b + 1`; restore `bottom` to match (empty deque).
        self.bottom.store(b + 1, Ordering::Relaxed);
        item
    }

    /// Attempts to steal the oldest item from the top (thief end).
    ///
    /// May be called from any thread. Returns [`StealResult::Empty`] when no item is
    /// visible, and [`StealResult::Retry`] when the owner or another thief claimed the
    /// candidate item first (the deque may still hold other items).
    pub fn steal(&self) -> StealResult<T> {
        let t = self.top.load(Ordering::Acquire);
        std::sync::atomic::fence(Ordering::SeqCst);
        let b = self.bottom.load(Ordering::Acquire);
        if t >= b {
            return StealResult::Empty;
        }
        // Claim index `t` while holding the buffer lock, so a concurrent `push` can neither
        // grow the buffer nor recycle slot `t` between winning the CAS and taking the item.
        let mut guard = self.lock_buf();
        if self
            .top
            .compare_exchange(t, t + 1, Ordering::SeqCst, Ordering::Relaxed)
            .is_err()
        {
            return StealResult::Retry;
        }
        match Self::take_slot(&mut guard, t) {
            Some(item) => StealResult::Success(item),
            // Should not happen: index `t` was written before `bottom` advanced past it.
            // Kept as a defensive fallback rather than a panic.
            None => StealResult::Empty,
        }
    }
}
/// A fixed set of per-worker [`WorkStealDeque`]s addressed by worker index.
///
/// `push`/`pop` on a given worker's deque must only be driven by the thread acting as that
/// worker (the deque's owner contract); any thread may [`steal`](Self::steal) from any deque.
pub struct WorkStealPool<T: Clone + Send + 'static> {
    deques: Vec<Arc<WorkStealDeque<T>>>,
}

impl<T: Clone + Send + 'static> WorkStealPool<T> {
    /// Initial per-worker deque capacity used by [`new`](Self::new).
    pub const DEFAULT_DEQUE_CAPACITY: usize = 16;

    /// Creates a pool of `num_workers` empty deques, each starting with
    /// [`DEFAULT_DEQUE_CAPACITY`](Self::DEFAULT_DEQUE_CAPACITY) slots.
    #[must_use]
    pub fn new(num_workers: usize) -> Self {
        Self::with_deque_capacity(num_workers, Self::DEFAULT_DEQUE_CAPACITY)
    }

    /// Creates a pool of `num_workers` empty deques, each pre-sized for at least
    /// `initial_cap` items, so workloads with a known burst size avoid repeated growth.
    #[must_use]
    pub fn with_deque_capacity(num_workers: usize, initial_cap: usize) -> Self {
        let deques = (0..num_workers)
            .map(|_| Arc::new(WorkStealDeque::new(initial_cap)))
            .collect();
        Self { deques }
    }

    /// Number of workers (and therefore deques) in the pool.
    #[must_use]
    pub fn num_workers(&self) -> usize {
        self.deques.len()
    }

    /// Pushes `item` onto `worker_id`'s deque.
    ///
    /// # Errors
    ///
    /// Returns the item back if `worker_id` is out of range or the deque rejects it.
    pub fn push(&self, worker_id: usize, item: T) -> Result<(), T> {
        match self.deques.get(worker_id) {
            Some(d) => d.push(item),
            None => Err(item),
        }
    }

    /// Pops the newest item from `worker_id`'s deque; `None` if the index is out of range
    /// or there is nothing to pop.
    pub fn pop(&self, worker_id: usize) -> Option<T> {
        self.deques.get(worker_id)?.pop()
    }

    /// Steals the oldest item from `target_id`'s deque on behalf of `thief_id`.
    ///
    /// Returns [`StealResult::InvalidIndex`] if either index is out of range. `thief_id`
    /// is only validated; it does not otherwise influence which item is taken.
    pub fn steal(&self, thief_id: usize, target_id: usize) -> StealResult<T> {
        if thief_id >= self.deques.len() {
            return StealResult::InvalidIndex;
        }
        match self.deques.get(target_id) {
            Some(target) => target.steal(),
            None => StealResult::InvalidIndex,
        }
    }

    /// Sum of all deque lengths (a racy snapshot while workers are active).
    #[must_use]
    pub fn total_len(&self) -> usize {
        self.deques.iter().map(|d| d.len()).sum()
    }

    /// Whether every deque currently reports itself empty.
    #[must_use]
    pub fn is_globally_empty(&self) -> bool {
        self.deques.iter().all(|d| d.is_empty())
    }

    /// Shared handle to `worker_id`'s deque, or `None` if the index is out of range.
    #[must_use]
    pub fn deque(&self, worker_id: usize) -> Option<Arc<WorkStealDeque<T>>> {
        self.deques.get(worker_id).cloned()
    }

    /// Length of `worker_id`'s deque; 0 for an out-of-range index.
    #[must_use]
    pub fn worker_len(&self, worker_id: usize) -> usize {
        self.deques.get(worker_id).map_or(0, |d| d.len())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn push_pop_lifo() {
        let pool = WorkStealPool::<u32>::new(1);
        pool.push(0, 1).expect("ok");
        pool.push(0, 2).expect("ok");
        pool.push(0, 3).expect("ok");
        assert_eq!(pool.pop(0), Some(3));
        assert_eq!(pool.pop(0), Some(2));
        assert_eq!(pool.pop(0), Some(1));
        assert_eq!(pool.pop(0), None);
    }

    #[test]
    fn steal_fifo_order() {
        let pool = WorkStealPool::<u32>::new(2);
        for i in 0..4_u32 {
            pool.push(0, i).expect("ok");
        }
        match pool.steal(1, 0) {
            StealResult::Success(v) => assert_eq!(v, 0),
            other => panic!("expected success, got {other:?}"),
        }
        match pool.steal(1, 0) {
            StealResult::Success(v) => assert_eq!(v, 1),
            other => panic!("expected success, got {other:?}"),
        }
    }

    #[test]
    fn steal_empty() {
        let pool = WorkStealPool::<u32>::new(2);
        assert_eq!(pool.steal(1, 0), StealResult::Empty);
    }

    #[test]
    fn pop_empty() {
        let pool = WorkStealPool::<u32>::new(1);
        assert_eq!(pool.pop(0), None);
    }

    #[test]
    fn total_len() {
        let pool = WorkStealPool::<i32>::new(3);
        pool.push(0, 10).expect("ok");
        pool.push(1, 20).expect("ok");
        pool.push(2, 30).expect("ok");
        assert_eq!(pool.total_len(), 3);
    }

    #[test]
    fn globally_empty() {
        let pool = WorkStealPool::<i32>::new(2);
        assert!(pool.is_globally_empty());
        pool.push(0, 99).expect("ok");
        assert!(!pool.is_globally_empty());
        assert_eq!(pool.pop(0), Some(99));
        assert!(pool.is_globally_empty());
    }

    #[test]
    fn steal_invalid_index() {
        let pool = WorkStealPool::<u32>::new(2);
        assert_eq!(pool.steal(99, 0), StealResult::InvalidIndex);
        assert_eq!(pool.steal(0, 99), StealResult::InvalidIndex);
    }

    #[test]
    fn deque_grows_beyond_capacity() {
        let pool = WorkStealPool::<u32>::new(1);
        for i in 0..40_u32 {
            pool.push(0, i).expect("push should succeed after grow");
        }
        assert_eq!(pool.worker_len(0), 40);
        assert!(pool.deque(0).expect("valid index").capacity() >= 40);
        // Growth must preserve both ends of the live range.
        assert_eq!(pool.steal(0, 0), StealResult::Success(0));
        assert_eq!(pool.pop(0), Some(39));
    }

    #[test]
    fn num_workers() {
        let pool = WorkStealPool::<u8>::new(5);
        assert_eq!(pool.num_workers(), 5);
    }

    #[test]
    fn deque_handle() {
        let pool = WorkStealPool::<u32>::new(2);
        pool.push(0, 42).expect("ok");
        let d = pool.deque(0).expect("valid index");
        assert_eq!(d.len(), 1);
        assert!(pool.deque(99).is_none());
    }

    #[test]
    fn push_invalid_worker() {
        let pool = WorkStealPool::<u32>::new(2);
        assert!(pool.push(99, 1).is_err());
    }

    #[test]
    fn steal_last_item() {
        let pool = WorkStealPool::<u32>::new(2);
        pool.push(0, 7).expect("ok");
        // Single-threaded, so there is no race that could produce `Retry`.
        assert_eq!(pool.steal(1, 0), StealResult::Success(7));
        assert_eq!(pool.pop(0), None);
        assert!(pool.is_globally_empty());
    }

    #[test]
    fn interleaved_owner_and_thief() {
        let pool = WorkStealPool::<u32>::new(2);
        for i in 0..4_u32 {
            pool.push(0, i).expect("ok");
        }
        assert_eq!(pool.steal(1, 0), StealResult::Success(0));
        assert_eq!(pool.pop(0), Some(3));
        assert_eq!(pool.steal(1, 0), StealResult::Success(1));
        // Last item: the owner takes it through the CAS path.
        assert_eq!(pool.pop(0), Some(2));
        assert_eq!(pool.pop(0), None);
        assert_eq!(pool.steal(1, 0), StealResult::Empty);
    }

    #[test]
    fn wraparound_reuses_slots() {
        let d: WorkStealDeque<u32> = WorkStealDeque::new(4);
        for i in 0..3_u32 {
            d.push(i).expect("ok");
        }
        assert_eq!(d.steal(), StealResult::Success(0));
        assert_eq!(d.steal(), StealResult::Success(1));
        // These pushes wrap around onto slots freed by the steals, without growing.
        for i in 3..6_u32 {
            d.push(i).expect("ok");
        }
        assert_eq!(d.capacity(), 4);
        assert_eq!(d.len(), 4);
        for expected in (2..6_u32).rev() {
            assert_eq!(d.pop(), Some(expected));
        }
        assert!(d.is_empty());
    }

    #[test]
    fn deque_len_and_is_empty() {
        let d: WorkStealDeque<u32> = WorkStealDeque::new(4);
        assert!(d.is_empty());
        d.push(1).expect("ok");
        d.push(2).expect("ok");
        assert_eq!(d.len(), 2);
        assert_eq!(d.pop(), Some(2));
        assert_eq!(d.len(), 1);
    }

    #[test]
    fn worker_len_invalid() {
        let pool = WorkStealPool::<u32>::new(2);
        assert_eq!(pool.worker_len(99), 0);
    }

    #[test]
    fn threaded_push_steal() {
        use std::collections::HashSet;
        use std::thread;

        let pool = Arc::new(WorkStealPool::<u32>::new(2));
        for i in 0..100_u32 {
            pool.push(0, i).expect("ok");
        }
        let pool_clone = Arc::clone(&pool);
        let thief = thread::spawn(move || {
            let mut stolen = Vec::new();
            for _ in 0..100 {
                match pool_clone.steal(1, 0) {
                    StealResult::Success(v) => stolen.push(v),
                    StealResult::Empty => break,
                    StealResult::Retry | StealResult::InvalidIndex => {}
                }
            }
            stolen
        });
        let mut taken: Vec<u32> = (0..50).filter_map(|_| pool.pop(0)).collect();
        let stolen = thief.join().expect("thief thread panicked");
        taken.extend(stolen);

        // Every item must be handed out at most once, and none may vanish:
        // whatever was not taken is still in the deque.
        let unique: HashSet<u32> = taken.iter().copied().collect();
        assert_eq!(unique.len(), taken.len(), "an item was handed out twice");
        assert_eq!(taken.len() + pool.worker_len(0), 100, "an item was lost");
    }
}