pub mod task;
mod work_stealing;
mod worker;
use std::sync::{
atomic::{AtomicBool, Ordering},
Arc,
};
use super::queue::TaskQueue;
use super::stealer::work_stealing_queues;
use task::{spawn_task, BoxedTask, TaskHandle};
use worker::{worker_loop, WorkerHandle};
/// How the pool's worker threads obtain tasks, as reported by `ThreadPool::mode`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum FetchTaskMode {
    /// Every worker pops from one shared queue.
    GlobalQueue,
    /// Each worker owns a local deque that it refills from a shared injector,
    /// falling back to stealing from other workers.
    WorkStealing,
}
/// A pool of worker threads that execute closures submitted via `spawn`.
///
/// Construct it with `ThreadPool::new` (global-queue mode) or through
/// `ThreadPoolBuilder`, which can also select work-stealing mode.
pub struct ThreadPool {
    /// Shared run flag: set to `true` when the pool is built and cleared by
    /// `shutdown`. Every worker thread holds a clone of this `Arc`.
    running: Arc<AtomicBool>,
    /// One handle per spawned worker thread; `shutdown` joins them all.
    workers: Vec<WorkerHandle>,
    /// Fetch closure shared by the workers in global-queue mode. In
    /// work-stealing mode each worker owns its own fetch logic and this is an
    /// inert placeholder that always returns `None`. Not read by any
    /// `ThreadPool` method.
    _fetch_task: Arc<dyn Fn(usize) -> Option<BoxedTask> + Send + Sync>,
    /// Enqueues a task: pushes onto the injector (work-stealing mode) or onto
    /// the shared global queue (global-queue mode).
    submit_task: Arc<dyn Fn(BoxedTask) + Send + Sync>,
    /// `true` when the pool was built in work-stealing mode; read by `mode`.
    _work_stealing: bool,
}
impl ThreadPool {
    /// Creates a pool with `num_threads` workers and otherwise default builder
    /// settings (work stealing disabled, i.e. global-queue mode).
    pub fn new(num_threads: usize) -> Self {
        let builder = ThreadPoolBuilder::new();
        builder.num_threads(num_threads).build()
    }

    /// Submits `f` for execution on one of the pool's workers.
    ///
    /// Returns a [`TaskHandle`] tied to the task's eventual result of type `T`.
    pub fn spawn<F, T>(&self, f: F) -> TaskHandle<T>
    where
        F: FnOnce() -> T + Send + 'static,
        T: Send + 'static,
    {
        let (boxed_job, task_handle) = spawn_task(f);
        let submit = &*self.submit_task;
        submit(boxed_job);
        task_handle
    }

    /// Stops the pool: clears the shared run flag, then joins every worker.
    ///
    /// The flag is cleared before any join so that all workers can observe it
    /// and exit. Tasks still queued at that point are not drained by this
    /// method; global-queue workers stop polling as soon as they see the
    /// cleared flag.
    pub fn shutdown(mut self) {
        self.running.store(false, Ordering::Release);
        self.workers.iter_mut().for_each(|handle| {
            handle.join();
        });
    }

    /// Reports which task-fetching strategy this pool was built with.
    pub fn mode(&self) -> FetchTaskMode {
        match self._work_stealing {
            true => FetchTaskMode::WorkStealing,
            false => FetchTaskMode::GlobalQueue,
        }
    }
}
/// Configures and constructs a [`ThreadPool`].
///
/// `ThreadPoolBuilder::new` starts from 4 threads with work stealing disabled.
#[derive(Debug, Clone)]
pub struct ThreadPoolBuilder {
    /// Number of worker threads spawned by `build`.
    num_threads: usize,
    /// When `true`, `build` creates a work-stealing pool; otherwise a
    /// single-global-queue pool.
    work_stealing: bool,
}
impl ThreadPoolBuilder {
    /// Creates a builder with the default configuration: 4 worker threads and
    /// work stealing disabled (global-queue mode).
    pub fn new() -> Self {
        Self {
            num_threads: 4,
            work_stealing: false,
        }
    }

    /// Sets how many worker threads `build` will spawn.
    ///
    /// The value is not validated. With `n == 0` no worker threads are
    /// spawned, so tasks submitted to the resulting pool never run.
    pub fn num_threads(mut self, n: usize) -> Self {
        self.num_threads = n;
        self
    }

    /// Selects work-stealing mode (`true`) or the single shared global queue
    /// (`false`, the default).
    pub fn work_stealing(mut self, enable: bool) -> Self {
        self.work_stealing = enable;
        self
    }

    /// Spawns the worker threads and returns the running pool.
    ///
    /// # Panics
    ///
    /// Panics if the OS fails to create a worker thread (`std::thread::spawn`).
    /// In work-stealing mode it also panics if `work_stealing_queues` returns
    /// fewer local queues than `num_threads`.
    pub fn build(self) -> ThreadPool {
        let running = Arc::new(AtomicBool::new(true));
        if self.work_stealing {
            self.build_work_stealing(running)
        } else {
            self.build_global_queue(running)
        }
    }

    /// Builds a work-stealing pool: submissions go to a shared injector, and
    /// each worker owns a local deque it refills from the injector or by
    /// stealing from peers.
    fn build_work_stealing(self, running: Arc<AtomicBool>) -> ThreadPool {
        let (injector, stealers, workers_local) = work_stealing_queues(self.num_threads);
        // Thread `i` takes ownership of local queue `i`, so one queue per thread is required.
        assert!(
            workers_local.len() >= self.num_threads,
            "work_stealing_queues returned fewer local queues than worker threads"
        );
        let submit_task = {
            let injector = Arc::clone(&injector);
            Arc::new(move |task: BoxedTask| {
                injector.push(task);
            }) as Arc<dyn Fn(BoxedTask) + Send + Sync>
        };
        let stealers = Arc::new(stealers);
        let mut workers = Vec::with_capacity(self.num_threads);
        // Consume the queues front to back: same assignment order as repeated
        // `remove(0)`, without shifting the remaining elements each time.
        for (i, worker) in workers_local
            .into_iter()
            .take(self.num_threads)
            .enumerate()
        {
            let r = Arc::clone(&running);
            let injector = Arc::clone(&injector);
            let peers = Arc::clone(&stealers);
            let fetch_task = move || {
                // 1. Local queue first.
                if let Some(task) = worker.pop() {
                    return Some(task);
                }
                // 2. Move a batch from the injector into the local queue, then pop.
                //    `Empty` and `Retry` both fall through to stealing.
                if let crossbeam::deque::Steal::Success(_) = injector.steal_batch(&worker) {
                    if let Some(task) = worker.pop() {
                        return Some(task);
                    }
                }
                // 3. Take one task from the first peer that yields one.
                peers.iter().find_map(|st| st.steal().success())
            };
            let handle = std::thread::spawn(move || {
                worker_loop(r, fetch_task);
            });
            workers.push(WorkerHandle::new(i, handle));
        }
        ThreadPool {
            running,
            workers,
            // Workers own their fetch logic in this mode; keep an inert placeholder.
            _fetch_task: Arc::new(|_| None),
            submit_task,
            _work_stealing: true,
        }
    }

    /// Builds a pool whose workers all poll one shared `TaskQueue`, yielding
    /// the CPU whenever the queue is empty.
    fn build_global_queue(self, running: Arc<AtomicBool>) -> ThreadPool {
        let queue = TaskQueue::new();
        let fetch_task = {
            let queue_clone = queue.clone_inner();
            Arc::new(move |_id: usize| queue_clone.pop())
                as Arc<dyn Fn(usize) -> Option<BoxedTask> + Send + Sync>
        };
        let submit_task = {
            let queue_clone = queue.clone_inner();
            Arc::new(move |task: BoxedTask| {
                queue_clone.push(task);
            }) as Arc<dyn Fn(BoxedTask) + Send + Sync>
        };
        let mut workers = Vec::with_capacity(self.num_threads);
        for i in 0..self.num_threads {
            let r = Arc::clone(&running);
            let ft = Arc::clone(&fetch_task);
            let handle = std::thread::spawn(move || {
                // Busy-poll until the run flag is cleared. Tasks still queued
                // at that point are left unexecuted.
                while r.load(Ordering::Acquire) {
                    if let Some(task) = ft(i) {
                        task();
                    } else {
                        std::thread::yield_now();
                    }
                }
            });
            workers.push(WorkerHandle::new(i, handle));
        }
        ThreadPool {
            running,
            workers,
            _fetch_task: fetch_task,
            submit_task,
            _work_stealing: false,
        }
    }
}

impl Default for ThreadPoolBuilder {
    /// Same as `ThreadPoolBuilder::new`: 4 threads, work stealing disabled.
    fn default() -> Self {
        Self::new()
    }
}