use crate::runnable::Runnable;
use crossbeam_channel::Sender;
use crossbeam_deque::{Injector, Steal, Stealer, Worker};
use std::fmt::Debug;
use std::iter;
use std::sync::{
atomic::{AtomicBool, Ordering},
Arc, Mutex,
};
use std::thread;
use std::time::Duration;
/// Looks for the next task a worker should run.
///
/// Sources are tried in this order:
///
/// 1. the worker's own `local` queue;
/// 2. the `global` injector, moving a batch of tasks into `local` and popping
///    one of them;
/// 3. the stealers registered in the shared `stealers` list, in list order.
///
/// Steps 2 and 3 are repeated for as long as the combined steal result asks
/// for a retry. Returns `Some(task)` on success and `None` when the final
/// steal result carries no task.
///
/// # Panics
///
/// Panics if the `stealers` mutex is poisoned.
fn find_task<T>(
    local: &Worker<T>,
    global: &Injector<T>,
    stealers: &Arc<Mutex<Vec<Stealer<T>>>>,
) -> Option<T>
where
    T: Runnable + Send + Sync,
{
    local.pop().or_else(|| {
        iter::repeat_with(|| {
            global.steal_batch_and_pop(local).or_else(|| {
                // Lock through the borrowed `Arc`; cloning it first would only
                // add a reference-count round trip to every steal attempt.
                let registered = stealers.lock().unwrap();
                // Collecting `Steal` values folds them into a single result.
                registered.iter().map(Stealer::steal).collect()
            })
        })
        // Keep trying while the steal operation reports that it must be retried.
        .find(|attempt| !attempt.is_retry())
        .and_then(Steal::success)
    })
}
/// Runs a worker loop on the calling thread until `running` reads `false`.
///
/// The worker creates its own FIFO queue and appends that queue's stealer to
/// the shared `stealers` list. The stealer is not removed again when the
/// function returns.
///
/// On every iteration the worker asks `find_task` for work:
///
/// * if a task is found it is run; an `Ok` value is sent on `output`;
/// * if the task fails and `should_retry` returns `true`, the task is pushed
///   back onto `injector`; otherwise the error is sent on `output`;
/// * if no task is found the thread sleeps for 500 ms before checking again.
///
/// Failures to send on `output` are ignored.
///
/// # Panics
///
/// Panics if the `stealers` mutex is poisoned.
pub fn worker<T>(
    injector: &Injector<T>,
    stealers: &Arc<Mutex<Vec<Stealer<T>>>>,
    running: &Arc<AtomicBool>,
    output: &Sender<Result<<T as Runnable>::Ok, <T as Runnable>::Error>>,
) where
    T: Runnable + Debug + Send + Sync,
{
    let local = Worker::<T>::new_fifo();
    let own_stealer = local.stealer();
    // The temporary guard is released at the end of this statement.
    stealers.lock().unwrap().push(own_stealer);

    while running.load(Ordering::Acquire) {
        let mut task = match find_task(&local, injector, stealers) {
            Some(task) => task,
            None => {
                // Nothing to do right now: back off before polling again.
                thread::sleep(Duration::from_millis(500));
                continue;
            }
        };

        let failure = match task.run() {
            Ok(val) => {
                let _ = output.send(Ok(val));
                continue;
            }
            Err(e) => e,
        };

        if task.should_retry(&failure) {
            // Hand the task back to the global queue so any worker may retry it.
            injector.push(task);
        } else {
            let _ = output.send(Err(failure));
        }
    }
}
#[cfg(test)]
mod test {
    use super::worker;
    use crate::error::Err;
    use crate::runnable::Runnable;
    use crossbeam_channel::{select, unbounded};
    use crossbeam_deque::{Injector, Stealer};
    use crossbeam_utils::thread::scope;
    use std::result::Result;
    use std::sync::{
        atomic::{AtomicBool, Ordering},
        Arc, Mutex,
    };
    use std::thread;
    use std::time::{Duration, Instant};

    /// Every task whose `value` is a multiple of this constant fails.
    const ERROR_RATE: usize = 100;

    /// Test task that either succeeds quickly or fails slowly, depending on
    /// `value`, and counts how many times it has failed.
    #[derive(Debug, PartialEq)]
    struct Task {
        value: usize,
        error_count: usize,
    }

    impl Runnable for Task {
        type Ok = ();
        type Error = Err;

        /// Fails after 500 ms when `value` is a multiple of `ERROR_RATE`,
        /// otherwise succeeds after 100 ms.
        fn run(&mut self) -> Result<(), Err> {
            if self.value % ERROR_RATE == 0 {
                thread::sleep(Duration::from_millis(500));
                self.error_count += 1;
                Err("fail".into())
            } else {
                thread::sleep(Duration::from_millis(100));
                Ok(())
            }
        }

        /// Allows a retry until the task has failed twice.
        fn should_retry(&self, _error: &Err) -> bool {
            self.error_count < 2
        }

        fn store_result(&mut self, _result: Result<(), Err>) {}
    }

    /// Spawns 64 workers, feeds them 500 tasks through the injector and checks
    /// that exactly one result per task arrives, with the expected split
    /// between successes and failures.
    #[test]
    fn worker_threads() {
        let injector = Injector::new();
        let stealers: Arc<Mutex<Vec<Stealer<Task>>>> = Arc::new(Mutex::new(vec![]));
        let running = Arc::new(AtomicBool::new(true));
        let (sender, receiver) = unbounded();
        let task_count = 500;

        scope(|s| {
            for _ in 0..64 {
                let inj_ref = &injector;
                let stealers_c = stealers.clone();
                let running = Arc::clone(&running);
                let output = sender.clone();
                let _ = s.spawn(move |_| {
                    worker(inj_ref, &stealers_c, &running, &output);
                });
            }

            let now = Instant::now();
            for i in 0..task_count {
                injector.push(Task {
                    value: i,
                    error_count: 0,
                });
            }

            let mut count = 0;
            let mut ok_count = 0;
            let mut err_count = 0;
            loop {
                select! {
                    recv(receiver) -> msg => {
                        count += 1;
                        if let Ok(result) = msg {
                            match result {
                                Ok(_) => ok_count += 1,
                                Err(_) => err_count += 1,
                            }
                        } else {
                            // A receive error is counted as a failed task.
                            err_count += 1;
                        }
                    }
                    // No message for 200 ms: stop the workers once every task
                    // has reported a result.
                    default(Duration::from_millis(200)) => {
                        if count == task_count {
                            running.store(false, Ordering::Release);
                            break;
                        }
                    }
                }
            }
            let _elapsed = now.elapsed();

            assert_eq!(ok_count, task_count - (task_count / ERROR_RATE));
            assert_eq!(err_count, task_count / ERROR_RATE);
        })
        // `scope` reports panics of the spawned threads through its result;
        // discarding it would let a crashing worker go unnoticed.
        .expect("a worker thread panicked");
    }
}