use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use crossbeam::deque::{Steal, Stealer, Worker};
use crossbeam::sync::{Parker, Unparker};
/// Shared coordination state for a group of work-stealing workers.
///
/// Built by `Pool::new`, which also hands back the per-worker queues and
/// parkers that are not stored here.
pub struct Pool<T: Send> {
/// One stealer per worker queue, in worker-id order; scanned by `try_steal`.
stealers: Vec<Stealer<T>>,
/// One unparker per worker's `Parker`, in worker-id order.
unparkers: Vec<Unparker>,
/// Number of workers currently between "announced idle" and "woke up again";
/// incremented and decremented by `worker_loop`, read by `maybe_unpark`.
parked_count: AtomicUsize,
/// Outstanding-work counter adjusted through `add_dirs` / `sub_dirs`.
/// `worker_loop` treats a value of zero, observed while idle, as "all work done".
/// NOTE(review): the name suggests the unit is directories still to be walked — confirm with callers.
dir_pending: AtomicUsize,
/// Set once by `signal_shutdown`; workers check it after waking from `park`.
shutdown: AtomicBool,
}
impl<T: Send> Pool<T> {
pub fn new(num_workers: usize) -> (Self, Vec<Worker<T>>, Vec<Parker>) {
let n = num_workers.max(1);
let mut workers = Vec::with_capacity(n);
for _ in 0..n {
workers.push(Worker::<T>::new_lifo());
}
let stealers: Vec<Stealer<T>> = workers.iter().map(|w| w.stealer()).collect();
let mut parkers = Vec::with_capacity(n);
let mut unparkers = Vec::with_capacity(n);
for _ in 0..n {
let p = Parker::new();
unparkers.push(p.unparker().clone());
parkers.push(p);
}
let pool = Pool {
stealers,
unparkers,
parked_count: AtomicUsize::new(0),
dir_pending: AtomicUsize::new(0),
shutdown: AtomicBool::new(false),
};
(pool, workers, parkers)
}
#[inline]
pub fn add_dirs(&self, n: usize) {
if n > 0 {
self.dir_pending.fetch_add(n, Ordering::Release);
}
}
#[inline]
pub fn sub_dirs(&self, n: usize) {
if n > 0 {
self.dir_pending.fetch_sub(n, Ordering::AcqRel);
}
}
#[inline]
pub fn maybe_unpark(&self) {
if self.parked_count.load(Ordering::Acquire) > 0 {
for u in &self.unparkers {
u.unpark();
}
}
}
#[inline]
pub fn try_steal(&self, my_id: usize) -> Option<T> {
let n = self.stealers.len();
for k in 1..n {
let idx = (my_id + k) % n;
loop {
match self.stealers[idx].steal() {
Steal::Success(t) => return Some(t),
Steal::Empty => break,
Steal::Retry => continue,
}
}
}
None
}
fn signal_shutdown(&self) {
self.shutdown.store(true, Ordering::Release);
for u in &self.unparkers {
u.unpark();
}
}
}
/// Runs one worker until the pool has no outstanding work left.
///
/// * `id` – this worker's index, used by `Pool::try_steal` to pick victims.
/// * `local` – this worker's own queue; it is also handed to `walk`.
/// * `pool` – the shared coordination state.
/// * `parker` – this worker's parker, paired with an unparker held by `pool`.
/// * `walk` – invoked once for every task obtained, together with `local`.
///
/// Each round takes a task from `local` or, failing that, from another worker.
/// When both come up empty the worker announces itself as idle, looks at the
/// other queues once more, and then either terminates the whole pool (the
/// outstanding-work counter is zero) or parks until it is unparked.
pub fn worker_loop<T, F>(id: usize, local: &Worker<T>, pool: &Pool<T>, parker: &Parker, mut walk: F)
where
    T: Send,
    F: FnMut(T, &Worker<T>),
{
    loop {
        // Fast path: own queue first, other workers second.
        if let Some(task) = local.pop().or_else(|| pool.try_steal(id)) {
            walk(task, local);
            continue;
        }

        // Slow path: publish the idle state *before* the final look at the
        // other queues, so a producer that checks `parked_count` after pushing
        // cannot miss this worker.
        pool.parked_count.fetch_add(1, Ordering::SeqCst);

        match pool.try_steal(id) {
            Some(task) => {
                // Work appeared in the meantime; withdraw the idle announcement.
                pool.parked_count.fetch_sub(1, Ordering::Relaxed);
                walk(task, local);
            }
            None if pool.dir_pending.load(Ordering::Acquire) == 0 => {
                // Nothing queued and nothing outstanding: stop every worker.
                pool.parked_count.fetch_sub(1, Ordering::Relaxed);
                pool.signal_shutdown();
                return;
            }
            None => {
                // Work is still outstanding elsewhere; sleep until unparked.
                parker.park();
                pool.parked_count.fetch_sub(1, Ordering::Relaxed);
                if pool.shutdown.load(Ordering::Acquire) {
                    return;
                }
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::thread;

    /// A single worker processes everything in its own queue and then stops
    /// on its own once the outstanding-work counter reaches zero.
    #[test]
    fn pool_drains_with_one_worker() {
        let (pool, mut workers, mut parkers) = Pool::<u32>::new(1);
        let pool = Arc::new(pool);

        pool.add_dirs(3);
        for v in 1..=3 {
            workers[0].push(v);
        }

        // Exactly one worker was requested, so `pop` yields worker 0's handles.
        let local = workers.pop().unwrap();
        let parker = parkers.pop().unwrap();
        let processed = Arc::new(AtomicUsize::new(0));

        let handle = {
            let pool = Arc::clone(&pool);
            let processed = Arc::clone(&processed);
            thread::spawn(move || {
                worker_loop(0, &local, &pool, &parker, |_item, _queue| {
                    processed.fetch_add(1, Ordering::Relaxed);
                    pool.sub_dirs(1);
                });
            })
        };
        handle.join().unwrap();

        assert_eq!(processed.load(Ordering::Relaxed), 3);
    }

    /// All items start in worker 0's queue; the other three workers can only
    /// obtain them by stealing. Every item must be processed exactly once and
    /// every thread must terminate.
    #[test]
    fn pool_drains_with_stealing() {
        let (pool, workers, parkers) = Pool::<u32>::new(4);
        let pool = Arc::new(pool);

        let items: Vec<u32> = (0..40).collect();
        pool.add_dirs(items.len());
        for &v in &items {
            workers[0].push(v);
        }

        let processed = Arc::new(AtomicUsize::new(0));

        // `collect` spawns every thread before any of them is joined.
        let handles: Vec<_> = workers
            .into_iter()
            .zip(parkers)
            .enumerate()
            .map(|(id, (local, parker))| {
                let pool = Arc::clone(&pool);
                let processed = Arc::clone(&processed);
                thread::spawn(move || {
                    worker_loop(id, &local, &pool, &parker, |_item, _queue| {
                        processed.fetch_add(1, Ordering::Relaxed);
                        pool.sub_dirs(1);
                    });
                })
            })
            .collect();

        for handle in handles {
            handle.join().unwrap();
        }

        assert_eq!(processed.load(Ordering::Relaxed), items.len());
    }
}