pub mod modes;
pub(crate) mod task;
mod worker;
use std::sync::{
atomic::{AtomicBool, Ordering},
Arc,
};
use super::queue::TaskQueue;
use super::stealer::work_stealing_queues;
use crate::{
metrics::MetricsCollector, priority_stealer::prioritized_work_stealing_queues,
queue::PriorityQueue,
};
use modes::{
GlobalQueueMode, NonPriorityTaskFetchMode, PriorityGlobalQueueMode, PriorityTaskFetchMode,
PriorityWorkStealingMode, TaskFetchMode, WorkStealingMode,
};
use task::{spawn_task, BoxedTask, Priority, PriorityTask, TaskHandle};
use worker::{priority_worker_loop, worker_loop, WorkerHandle};
/// Type-erased submission callback held by a [`ThreadPool`].
///
/// In this module, non-priority pools are built with the `Function` variant and
/// priority pools with the `Struct` variant.
pub enum SubmitTaskType {
    /// Enqueues a plain [`BoxedTask`].
    Function(Arc<dyn Fn(BoxedTask) + Sync + Send>),
    /// Enqueues a [`PriorityTask`].
    Struct(Arc<dyn Fn(PriorityTask) + Sync + Send>),
}
/// Type-erased fetch callback: given a worker index, yields the next task if any.
///
/// The `Function` variant produces plain [`BoxedTask`]s; the `Struct` variant
/// produces [`PriorityTask`]s.
pub enum FetchTaskType {
    /// Fetches a plain [`BoxedTask`] for the worker with the given index.
    Function(Arc<dyn Fn(usize) -> Option<BoxedTask> + Sync + Send>),
    /// Fetches a [`PriorityTask`] for the worker with the given index.
    Struct(Arc<dyn Fn(usize) -> Option<PriorityTask> + Sync + Send>),
}
/// A pool of worker threads whose scheduling strategy is selected by the mode `M`.
///
/// Pools are assembled by the `build` methods of [`ThreadPoolBuilder`].
pub struct ThreadPool<M>
where
    M: TaskFetchMode,
{
    /// Run flag shared with every worker; `shutdown` clears it.
    running: Arc<AtomicBool>,
    /// Handles of the spawned worker threads, joined by `shutdown`.
    workers: Vec<WorkerHandle>,
    /// Fetch callback kept alongside the pool; not read anywhere in this file.
    _fetch_task: FetchTaskType,
    /// Callback through which `spawn` / `spawn_with_priority` enqueue work.
    submit_task: SubmitTaskType,
    /// Records whether a work-stealing mode was built; not read in this file.
    _work_stealing: bool,
    /// Optional metrics sink, notified when a task is submitted.
    metrics_collector: Option<Arc<dyn MetricsCollector>>,
    /// Mode marker value; `mode()` reports its name.
    mode: M,
}
impl<M: NonPriorityTaskFetchMode> ThreadPool<M> {
    /// Submits `f` to the pool and returns a [`TaskHandle`] typed by `f`'s result `T`.
    ///
    /// After the task is handed to the submission callback, the metrics
    /// collector (if configured) receives `on_task_submitted`.
    pub fn spawn<F, T>(&self, f: F) -> TaskHandle<T>
    where
        F: FnOnce() -> T + Send + 'static,
        T: Send + 'static,
    {
        let (task, handle) = spawn_task(f);
        // The non-priority builders in this module always install a `Function`
        // submitter, so the `Struct` variant is not expected here.
        if let SubmitTaskType::Function(submit) = &self.submit_task {
            submit(task);
            if let Some(metrics) = &self.metrics_collector {
                metrics.on_task_submitted();
            }
        }
        handle
    }
}
impl<M: PriorityTaskFetchMode> ThreadPool<M> {
    /// Submits `f` with the given `priority` and returns a [`TaskHandle`] typed by
    /// `f`'s result `T`.
    ///
    /// The closure is wrapped in a [`PriorityTask`] carrying `Priority(priority)`.
    /// How priorities are ordered is decided by the underlying queue (not visible
    /// here). After submission, the metrics collector (if configured) receives
    /// `on_task_submitted`.
    pub fn spawn_with_priority<F, T>(&self, f: F, priority: usize) -> TaskHandle<T>
    where
        F: FnOnce() -> T + Send + 'static,
        T: Send + 'static,
    {
        let (task, handle) = spawn_task(f);
        // The priority builders in this module always install a `Struct` submitter.
        if let SubmitTaskType::Struct(submit) = &self.submit_task {
            submit(PriorityTask::new(task, Priority(priority)));
            if let Some(metrics) = &self.metrics_collector {
                metrics.on_task_submitted();
            }
        }
        handle
    }
}
impl<M> ThreadPool<M>
where
    M: TaskFetchMode,
{
    /// Stops the pool: clears the shared run flag, then joins every worker in order.
    ///
    /// Workers leave their loop the next time they observe the cleared flag.
    /// NOTE(review): tasks still queued at that moment are presumably not run;
    /// confirm for the `worker_loop` based modes.
    pub fn shutdown(mut self) {
        let stop_flag = &self.running;
        stop_flag.store(false, Ordering::Release);
        self.workers.iter_mut().for_each(|handle| {
            handle.join();
        });
    }

    /// Returns the name reported by this pool's mode marker.
    pub fn mode(&self) -> &str {
        let marker = &self.mode;
        marker.mode()
    }
}
/// Marker trait for the type-state parameter of [`ThreadPoolBuilder`].
pub trait IBuilderState {}

/// Initial builder state; its `build` produces a [`GlobalQueueMode`] pool.
pub struct DefaultModeState;
/// Reached via `set_work_stealing`; its `build` produces a [`WorkStealingMode`] pool.
pub struct WorkStealingState;
/// Reached via `enable_priority`; its `build` produces a [`PriorityGlobalQueueMode`] pool.
pub struct PriorityState;
/// Work stealing plus priorities; its `build` produces a [`PriorityWorkStealingMode`] pool.
pub struct PriorityWorkStealingState;

impl IBuilderState for DefaultModeState {}
impl IBuilderState for WorkStealingState {}
impl IBuilderState for PriorityState {}
impl IBuilderState for PriorityWorkStealingState {}
/// Type-state builder for [`ThreadPool`]; `S` tracks which features were enabled.
pub struct ThreadPoolBuilder<S = DefaultModeState>
where
    S: IBuilderState,
{
    /// Number of worker threads the built pool spawns.
    num_threads: usize,
    /// Optional metrics sink handed to the built pool and its workers.
    metrics_collector: Option<Arc<dyn MetricsCollector>>,
    /// Zero-sized marker carrying the builder state.
    _state: std::marker::PhantomData<S>,
}
impl<S> ThreadPoolBuilder<S>
where
    S: IBuilderState,
{
    /// Installs `collector` as the metrics sink, replacing any earlier one.
    pub fn with_metrics_collector(self, collector: Arc<dyn MetricsCollector>) -> Self {
        Self {
            metrics_collector: Some(collector),
            ..self
        }
    }

    /// Sets how many worker threads the built pool will spawn.
    pub fn num_threads(self, n: usize) -> Self {
        Self {
            num_threads: n,
            ..self
        }
    }
}
impl ThreadPoolBuilder<DefaultModeState> {
    /// Creates a builder configured for 4 worker threads and no metrics collector.
    pub fn new() -> Self {
        Self {
            num_threads: 4,
            metrics_collector: None,
            _state: std::marker::PhantomData,
        }
    }

    /// Switches to the work-stealing mode, keeping the thread count and collector.
    pub fn set_work_stealing(self) -> ThreadPoolBuilder<WorkStealingState> {
        ThreadPoolBuilder {
            num_threads: self.num_threads,
            metrics_collector: self.metrics_collector,
            _state: std::marker::PhantomData,
        }
    }

    /// Switches to the priority global-queue mode, keeping the thread count and collector.
    pub fn enable_priority(self) -> ThreadPoolBuilder<PriorityState> {
        ThreadPoolBuilder {
            num_threads: self.num_threads,
            metrics_collector: self.metrics_collector,
            _state: std::marker::PhantomData,
        }
    }

    /// Builds a pool whose workers all pull from one shared [`TaskQueue`].
    ///
    /// Each worker polls the queue while the run flag is set and calls
    /// `yield_now` whenever the queue is empty. Once the flag is cleared,
    /// workers stop polling, so tasks still queued at that point are not run.
    ///
    /// Metrics, when configured:
    /// - `on_task_started` and `on_task_completed` surround each task on its worker.
    /// - `on_worker_stopped` is reported by the worker as it exits.
    /// - `on_worker_started` is reported from the building thread right after each spawn.
    pub fn build(self) -> ThreadPool<GlobalQueueMode> {
        let running = Arc::new(AtomicBool::new(true));
        let queue = TaskQueue::new();
        let fetch_task = {
            let queue = queue.clone_inner();
            Arc::new(move |_id: usize| queue.pop())
                as Arc<dyn Fn(usize) -> Option<BoxedTask> + Send + Sync>
        };
        let submit_task = {
            let queue = queue.clone_inner();
            Arc::new(move |task: BoxedTask| {
                queue.push(task);
            }) as Arc<dyn Fn(BoxedTask) + Send + Sync>
        };

        let mut workers = Vec::with_capacity(self.num_threads);
        for i in 0..self.num_threads {
            let run_flag = Arc::clone(&running);
            let metrics = self.metrics_collector.clone();
            let fetch = Arc::clone(&fetch_task);
            let handle = std::thread::spawn(move || {
                while run_flag.load(Ordering::Acquire) {
                    match fetch(i) {
                        Some(task) => {
                            if let Some(m) = &metrics {
                                m.on_task_started();
                            }
                            task();
                            if let Some(m) = &metrics {
                                m.on_task_completed();
                            }
                        }
                        None => std::thread::yield_now(),
                    }
                }
                if let Some(m) = &metrics {
                    m.on_worker_stopped();
                }
            });
            workers.push(WorkerHandle::new(i, handle));
            if let Some(m) = &self.metrics_collector {
                m.on_worker_started();
            }
        }

        ThreadPool::<GlobalQueueMode> {
            running,
            workers,
            _fetch_task: FetchTaskType::Function(fetch_task),
            submit_task: SubmitTaskType::Function(submit_task),
            _work_stealing: false,
            metrics_collector: self.metrics_collector,
            mode: GlobalQueueMode,
        }
    }
}
impl ThreadPoolBuilder<WorkStealingState> {
    /// Adds priority scheduling on top of work stealing, keeping the thread count
    /// and collector.
    pub fn enable_priority(self) -> ThreadPoolBuilder<PriorityWorkStealingState> {
        ThreadPoolBuilder {
            num_threads: self.num_threads,
            metrics_collector: self.metrics_collector,
            _state: std::marker::PhantomData,
        }
    }

    /// Builds a work-stealing pool.
    ///
    /// Submitted tasks are pushed onto a shared injector. Each worker's fetch
    /// tries, in order:
    /// 1. its own local queue;
    /// 2. a batch stolen from the injector into that local queue;
    /// 3. every stealer handle, returning the first successful steal.
    ///
    /// A `Retry` result is treated like `Empty` for that attempt. The per-thread
    /// loop itself lives in `worker_loop`.
    ///
    /// # Panics
    /// Panics if `work_stealing_queues` yields fewer local queues than `num_threads`.
    pub fn build(self) -> ThreadPool<WorkStealingMode> {
        let running = Arc::new(AtomicBool::new(true));
        let (injector, stealers, workers_local) =
            work_stealing_queues(self.num_threads, self.metrics_collector.clone());
        let submit_task = {
            let injector = Arc::clone(&injector);
            Arc::new(move |task: BoxedTask| {
                injector.push(task);
            }) as Arc<dyn Fn(BoxedTask) + Send + Sync>
        };
        let stealers = Arc::new(stealers);

        // Hand the local queues out front-to-back; iterating avoids the O(n)
        // element shift that `Vec::remove(0)` performs on every call.
        let mut local_queues = workers_local.into_iter();
        let mut workers = Vec::with_capacity(self.num_threads);
        for i in 0..self.num_threads {
            let worker = local_queues
                .next()
                .expect("work_stealing_queues returned fewer local queues than num_threads");
            let run_flag = Arc::clone(&running);
            let metrics = self.metrics_collector.clone();
            let injector = Arc::clone(&injector);
            let stealers = Arc::clone(&stealers);
            let fetch_task = move || {
                if let Some(task) = worker.pop() {
                    return Some(task);
                }
                if let crossbeam::deque::Steal::Success(_) = injector.steal_batch(&worker) {
                    if let Some(task) = worker.pop() {
                        return Some(task);
                    }
                }
                stealers
                    .iter()
                    .find_map(|stealer| stealer.steal().success())
            };
            let handle = std::thread::spawn(move || {
                worker_loop(run_flag, fetch_task, metrics);
            });
            workers.push(WorkerHandle::new(i, handle));
        }

        ThreadPool::<WorkStealingMode> {
            running,
            workers,
            _fetch_task: FetchTaskType::Function(Arc::new(|_| None)),
            submit_task: SubmitTaskType::Function(submit_task),
            _work_stealing: true,
            metrics_collector: self.metrics_collector,
            mode: WorkStealingMode,
        }
    }
}
impl ThreadPoolBuilder<PriorityState> {
    /// Builds a pool whose workers all pull [`PriorityTask`]s from one shared
    /// [`PriorityQueue`].
    ///
    /// Each worker polls the queue while the run flag is set and calls
    /// `yield_now` whenever the queue is empty. Once the flag is cleared,
    /// workers stop polling, so tasks still queued at that point are not run.
    ///
    /// Metrics, when configured:
    /// - `on_task_started` and `on_task_completed` surround each task on its worker.
    /// - `on_worker_stopped` is reported by the worker as it exits.
    /// - `on_worker_started` is reported from the building thread right after each spawn.
    pub fn build(self) -> ThreadPool<PriorityGlobalQueueMode> {
        let running = Arc::new(AtomicBool::new(true));
        let queue = PriorityQueue::new();
        let fetch_task = {
            let queue = queue.clone();
            Arc::new(move |_: usize| queue.pop())
                as Arc<dyn Fn(usize) -> Option<PriorityTask> + Send + Sync>
        };
        let submit_task = {
            let queue = queue.clone();
            Arc::new(move |task: PriorityTask| {
                queue.push(task);
            }) as Arc<dyn Fn(PriorityTask) + Send + Sync>
        };

        let mut workers = Vec::with_capacity(self.num_threads);
        for i in 0..self.num_threads {
            let run_flag = Arc::clone(&running);
            let metrics = self.metrics_collector.clone();
            let fetch = Arc::clone(&fetch_task);
            let handle = std::thread::spawn(move || {
                while run_flag.load(Ordering::Acquire) {
                    match fetch(i) {
                        Some(pt) => {
                            if let Some(m) = &metrics {
                                m.on_task_started();
                            }
                            (pt.task)();
                            if let Some(m) = &metrics {
                                m.on_task_completed();
                            }
                        }
                        None => std::thread::yield_now(),
                    }
                }
                if let Some(m) = &metrics {
                    m.on_worker_stopped();
                }
            });
            workers.push(WorkerHandle::new(i, handle));
            if let Some(m) = &self.metrics_collector {
                m.on_worker_started();
            }
        }

        ThreadPool::<PriorityGlobalQueueMode> {
            running,
            workers,
            _fetch_task: FetchTaskType::Struct(fetch_task),
            submit_task: SubmitTaskType::Struct(submit_task),
            _work_stealing: false,
            metrics_collector: self.metrics_collector,
            mode: PriorityGlobalQueueMode,
        }
    }
}
impl ThreadPoolBuilder<PriorityWorkStealingState> {
    /// Builds a priority-aware work-stealing pool.
    ///
    /// Submitted [`PriorityTask`]s are pushed onto a shared injector. Each
    /// worker's fetch tries, in order:
    /// 1. its own local queue;
    /// 2. a batch stolen from the injector into that local queue;
    /// 3. every stealer handle, returning the first successful steal.
    ///
    /// A `Retry` result is treated like `Empty` for that attempt. The per-thread
    /// loop itself lives in `priority_worker_loop`.
    ///
    /// # Panics
    /// Panics if `prioritized_work_stealing_queues` yields fewer local queues than
    /// `num_threads`.
    pub fn build(self) -> ThreadPool<PriorityWorkStealingMode> {
        let running = Arc::new(AtomicBool::new(true));
        let (injector, stealers, workers_local) =
            prioritized_work_stealing_queues(self.num_threads, self.metrics_collector.clone());
        let submit_task = {
            let injector = Arc::clone(&injector);
            Arc::new(move |task: PriorityTask| {
                injector.push(task);
            }) as Arc<dyn Fn(PriorityTask) + Send + Sync>
        };
        let stealers = Arc::new(stealers);

        // Hand the local queues out front-to-back; iterating avoids the O(n)
        // element shift that `Vec::remove(0)` performs on every call.
        let mut local_queues = workers_local.into_iter();
        let mut workers = Vec::with_capacity(self.num_threads);
        for i in 0..self.num_threads {
            let worker = local_queues.next().expect(
                "prioritized_work_stealing_queues returned fewer local queues than num_threads",
            );
            let run_flag = Arc::clone(&running);
            let metrics = self.metrics_collector.clone();
            let injector = Arc::clone(&injector);
            let stealers = Arc::clone(&stealers);
            let fetch_task = move || {
                if let Some(task) = worker.pop() {
                    return Some(task);
                }
                if let crate::priority_stealer::Steal::Success(_) = injector.steal_batch(&worker) {
                    if let Some(task) = worker.pop() {
                        return Some(task);
                    }
                }
                for stealer in stealers.iter() {
                    if let crate::priority_stealer::Steal::Success(task) = stealer.steal() {
                        return Some(task);
                    }
                }
                None
            };
            let handle = std::thread::spawn(move || {
                priority_worker_loop(run_flag, fetch_task, metrics);
            });
            workers.push(WorkerHandle::new(i, handle));
        }

        ThreadPool::<PriorityWorkStealingMode> {
            running,
            workers,
            _fetch_task: FetchTaskType::Struct(Arc::new(|_| None)),
            submit_task: SubmitTaskType::Struct(submit_task),
            _work_stealing: true,
            metrics_collector: self.metrics_collector,
            mode: PriorityWorkStealingMode,
        }
    }
}