1use crate::{
2 TaskFnPointer, TaskParamPointer, WorkItem, future::WorkFuture, queue::BatchQueue,
3 uniform_tasks_to_pointers, worker::Worker,
4};
5use std::sync::Arc;
6
7pub struct ThreadPool {
8 workers: Vec<Worker>,
9 queue: Arc<BatchQueue>,
10}
11
12impl ThreadPool {
13 pub fn new(worker_count: usize) -> Self {
14 assert!(worker_count > 0, "Must have at least one worker");
15
16 let queue = Arc::new(BatchQueue::new());
17
18 let workers: Vec<Worker> = (0..worker_count)
19 .map(|id| Worker::new(id, queue.clone()))
20 .collect();
21
22 ThreadPool { workers, queue }
23 }
24
25 pub fn submit_raw_task(&self, task_fn: TaskFnPointer, params: TaskParamPointer) -> WorkFuture {
26 self.queue.push_single_task(task_fn, params)
27 }
28
29 pub fn submit_raw_task_batch(&self, tasks: &[WorkItem]) -> WorkFuture {
30 self.queue.push_task_batch(tasks)
31 }
32
33 pub fn submit_task<T>(&self, task_fn: TaskFnPointer, params: &T) -> WorkFuture {
34 let params_ptr = params as *const T as TaskParamPointer;
35 self.submit_raw_task(task_fn, params_ptr)
36 }
37
38 pub fn submit_batch_uniform<T>(&self, task_fn: TaskFnPointer, params_vec: &[T]) -> WorkFuture {
39 let tasks = uniform_tasks_to_pointers(task_fn, params_vec);
40 self.submit_raw_task_batch(&tasks)
41 }
42
43 pub fn worker_count(&self) -> usize {
44 self.workers.len()
45 }
46
47 pub fn total_pending(&self) -> usize {
48 self.queue.len()
49 }
50}
51
52impl Drop for ThreadPool {
53 fn drop(&mut self) {
54 self.queue.shutdown();
55
56 let workers = std::mem::take(&mut self.workers);
57 for worker in workers {
58 let _ = worker.handle.join();
59 }
60 }
61}