//! `mill_io/thread_pool.rs`
//!
//! Thread-pool implementations: a round-robin [`ThreadPool`] dispatching over
//! per-worker channels, and a priority-ordered [`ComputeThreadPool`] with
//! activity metrics.

1#[cfg(feature = "unstable-mpmc")]
2use std::sync::mpmc as channel;
3#[cfg(not(feature = "unstable-mpmc"))]
4use std::sync::mpsc as channel;
5
6use parking_lot::{Condvar, Mutex};
7use std::{
8    cmp::Ordering as CmpOrdering,
9    collections::BinaryHeap,
10    panic::{catch_unwind, AssertUnwindSafe},
11    sync::{
12        atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering},
13        Arc, Barrier,
14    },
15    thread::{Builder, JoinHandle},
16    time::Instant,
17};
18
19use crate::error::Result;
20
/// Fallback worker count used when `std::thread::available_parallelism`
/// cannot be queried.
pub const DEFAULT_POOL_CAPACITY: usize = 4;

/// A unit of work: a boxed closure executed exactly once on a worker thread.
pub type Task = Box<dyn FnOnce() + Send + 'static>;
24
/// Message delivered to a worker over its channel.
enum WorkerMessage {
    /// Execute the boxed closure on the worker thread.
    Task(Task),
    /// Exit the worker's receive loop so its thread can finish.
    Terminate,
}
29
/// Fixed-size pool of worker threads. Tasks are dispatched round-robin;
/// each worker owns its own channel receiver.
pub struct ThreadPool {
    workers: Vec<Worker>,
    // One sender per worker, index-aligned with `workers`.
    senders: Vec<channel::Sender<WorkerMessage>>,
    // Monotonic counter; taken modulo the worker count to pick the next worker.
    next_worker: AtomicUsize,
}
35
36impl Default for ThreadPool {
37    fn default() -> Self {
38        let default_capacity = std::thread::available_parallelism()
39            .map(|n| n.get())
40            .unwrap_or(DEFAULT_POOL_CAPACITY);
41        Self::new(default_capacity)
42    }
43}
44
45impl ThreadPool {
46    pub fn new(capacity: usize) -> Self {
47        let mut workers = Vec::with_capacity(capacity);
48        let mut senders = Vec::with_capacity(capacity);
49
50        for id in 0..capacity {
51            let (sender, receiver) = channel::channel::<WorkerMessage>();
52            workers.push(Worker::new(id, receiver));
53            senders.push(sender);
54        }
55
56        Self {
57            workers,
58            senders,
59            next_worker: AtomicUsize::new(0),
60        }
61    }
62
63    pub fn exec<F>(&self, task: F) -> Result<()>
64    where
65        F: FnOnce() + Send + 'static,
66    {
67        // Round-robin dispatch
68        let index = self.next_worker.fetch_add(1, Ordering::Relaxed) % self.senders.len();
69        Ok(self.senders[index].send(WorkerMessage::Task(Box::new(task)))?)
70    }
71
72    pub fn workers_len(&self) -> usize {
73        self.workers.len()
74    }
75}
76
77impl Drop for ThreadPool {
78    fn drop(&mut self) {
79        for sender in &self.senders {
80            let _ = sender.send(WorkerMessage::Terminate);
81        }
82        for worker in &mut self.workers {
83            if let Some(t) = worker.take_thread() {
84                t.join().unwrap();
85            }
86        }
87    }
88}
89
/// A single pool worker: its id and the join handle of its thread.
struct Worker {
    #[allow(dead_code)]
    id: usize,
    // `Some` until the handle is taken for joining during pool drop.
    thread: Option<JoinHandle<()>>,
}
95
96impl Worker {
97    pub fn new(id: usize, receiver: channel::Receiver<WorkerMessage>) -> Self {
98        let thread = Some(
99            Builder::new()
100                .name(format!("thread-pool-worker-{id}"))
101                .spawn(move || {
102                    while let Ok(message) = receiver.recv() {
103                        match message {
104                            WorkerMessage::Task(task) => task(),
105                            WorkerMessage::Terminate => break,
106                        }
107                    }
108                })
109                .expect("Couldn't create the worker thread id={id}"),
110        );
111
112        Self { id, thread }
113    }
114
115    pub fn take_thread(&mut self) -> Option<JoinHandle<()>> {
116        self.thread.take()
117    }
118}
119
/// Priority classes for `ComputeThreadPool` tasks. The derived `Ord`
/// follows declaration order, so `Low < Normal < High < Critical`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum TaskPriority {
    Low = 0,
    Normal = 1,
    High = 2,
    Critical = 3,
}
127
/// Lock-free counters describing `ComputeThreadPool` activity.
///
/// All fields are updated with relaxed atomics: each individual read is
/// exact, but a snapshot across several fields is not atomic.
#[derive(Debug, Default)]
pub struct ComputePoolMetrics {
    /// Tasks handed to `spawn` so far.
    pub tasks_submitted: AtomicU64,
    /// Tasks whose closure returned normally.
    pub tasks_completed: AtomicU64,
    /// Tasks whose closure panicked (the panic is caught by the worker).
    pub tasks_failed: AtomicU64,
    /// Workers currently executing a task.
    pub active_workers: AtomicUsize,
    /// Queued-but-not-yet-popped tasks, per priority level.
    pub queue_depth_low: AtomicUsize,
    pub queue_depth_normal: AtomicUsize,
    pub queue_depth_high: AtomicUsize,
    pub queue_depth_critical: AtomicUsize,
    /// Accumulated wall-clock execution time of finished tasks, nanoseconds.
    pub total_execution_time_ns: AtomicU64,
}
140
141impl ComputePoolMetrics {
142    pub fn tasks_submitted(&self) -> u64 {
143        self.tasks_submitted.load(Ordering::Relaxed)
144    }
145
146    pub fn tasks_completed(&self) -> u64 {
147        self.tasks_completed.load(Ordering::Relaxed)
148    }
149
150    pub fn tasks_failed(&self) -> u64 {
151        self.tasks_failed.load(Ordering::Relaxed)
152    }
153
154    pub fn active_workers(&self) -> usize {
155        self.active_workers.load(Ordering::Relaxed)
156    }
157
158    pub fn queue_depth_low(&self) -> usize {
159        self.queue_depth_low.load(Ordering::Relaxed)
160    }
161
162    pub fn queue_depth_normal(&self) -> usize {
163        self.queue_depth_normal.load(Ordering::Relaxed)
164    }
165
166    pub fn queue_depth_high(&self) -> usize {
167        self.queue_depth_high.load(Ordering::Relaxed)
168    }
169
170    pub fn queue_depth_critical(&self) -> usize {
171        self.queue_depth_critical.load(Ordering::Relaxed)
172    }
173
174    pub fn total_execution_time_ns(&self) -> u64 {
175        self.total_execution_time_ns.load(Ordering::Relaxed)
176    }
177}
178
/// A queued task tagged with its priority and a submission sequence number.
struct PriorityTask {
    task: Task,
    priority: TaskPriority,
    // Monotonically increasing submission index; breaks priority ties so
    // equal-priority tasks run in FIFO order.
    sequence: u64,
}
184
185impl PartialEq for PriorityTask {
186    fn eq(&self, other: &Self) -> bool {
187        self.priority == other.priority && self.sequence == other.sequence
188    }
189}
190
191impl Eq for PriorityTask {}
192
impl PartialOrd for PriorityTask {
    /// Delegates to the total order defined by `Ord`, keeping the two
    /// implementations consistent.
    fn partial_cmp(&self, other: &Self) -> Option<CmpOrdering> {
        Some(self.cmp(other))
    }
}
198
199impl Ord for PriorityTask {
200    fn cmp(&self, other: &Self) -> CmpOrdering {
201        match self.priority.cmp(&other.priority) {
202            CmpOrdering::Equal => other.sequence.cmp(&self.sequence),
203            ord => ord,
204        }
205    }
206}
207
/// State shared between `ComputeThreadPool` and its worker threads.
struct ComputeSharedState {
    // Max-heap ordered by `PriorityTask`'s `Ord` (priority, then FIFO).
    queue: Mutex<BinaryHeap<PriorityTask>>,
    // Workers sleep here while the queue is empty.
    condvar: Condvar,
    // Set once during pool drop; workers drain the queue and then exit.
    shutdown: AtomicBool,
}
213
/// Thread pool that executes tasks in descending `TaskPriority` order,
/// FIFO within a priority level, while recording activity metrics.
pub struct ComputeThreadPool {
    workers: Vec<JoinHandle<()>>,
    state: Arc<ComputeSharedState>,
    // Monotonic tie-breaker giving FIFO order within a priority level.
    sequence: AtomicU64,
    metrics: Arc<ComputePoolMetrics>,
}
220
221impl Default for ComputeThreadPool {
222    fn default() -> Self {
223        let default_capacity = std::thread::available_parallelism()
224            .map(|n| n.get())
225            .unwrap_or(DEFAULT_POOL_CAPACITY);
226        Self::new(default_capacity)
227    }
228}
229
impl ComputeThreadPool {
    /// Creates a pool with `capacity` workers sharing one priority queue.
    ///
    /// Does not return until every worker thread has spawned and reached its
    /// run loop (synchronized via a barrier).
    pub fn new(capacity: usize) -> Self {
        let state = Arc::new(ComputeSharedState {
            queue: Mutex::new(BinaryHeap::new()),
            condvar: Condvar::new(),
            shutdown: AtomicBool::new(false),
        });
        let metrics = Arc::new(ComputePoolMetrics::default());

        let mut workers = Vec::with_capacity(capacity);
        // barrier to ensure all workers are started before returning
        let barrier = Arc::new(Barrier::new(capacity + 1));

        for id in 0..capacity {
            let state_clone = Arc::clone(&state);
            let barrier_clone = Arc::clone(&barrier);
            let metrics_clone = Arc::clone(&metrics);
            let thread = Builder::new()
                .name(format!("compute-worker-{id}"))
                .spawn(move || {
                    // wait for all workers to be ready
                    barrier_clone.wait();

                    loop {
                        // Scope the lock: it is released at the end of this
                        // block, before the task runs, so other workers can
                        // dequeue concurrently.
                        let task = {
                            let mut queue = state_clone.queue.lock();

                            // Sleep until work arrives or shutdown is set.
                            // Re-checked in a loop: another worker may have
                            // taken the task between notify and wake-up.
                            while queue.is_empty() && !state_clone.shutdown.load(Ordering::Relaxed)
                            {
                                state_clone.condvar.wait(&mut queue);
                            }

                            // On shutdown, keep draining queued tasks and
                            // exit only once the queue is empty.
                            if state_clone.shutdown.load(Ordering::Relaxed) && queue.is_empty() {
                                break;
                            }

                            let t = queue.pop();
                            // Decrement the matching per-priority depth gauge
                            // now that the task has left the queue.
                            if let Some(ref pt) = t {
                                match pt.priority {
                                    TaskPriority::Low => metrics_clone
                                        .queue_depth_low
                                        .fetch_sub(1, Ordering::Relaxed),
                                    TaskPriority::Normal => metrics_clone
                                        .queue_depth_normal
                                        .fetch_sub(1, Ordering::Relaxed),
                                    TaskPriority::High => metrics_clone
                                        .queue_depth_high
                                        .fetch_sub(1, Ordering::Relaxed),
                                    TaskPriority::Critical => metrics_clone
                                        .queue_depth_critical
                                        .fetch_sub(1, Ordering::Relaxed),
                                };
                            }
                            t
                        };

                        if let Some(priority_task) = task {
                            metrics_clone.active_workers.fetch_add(1, Ordering::Relaxed);
                            let start = Instant::now();

                            // Catch panics so a failing task is recorded in
                            // `tasks_failed` instead of killing the worker.
                            let result = catch_unwind(AssertUnwindSafe(|| (priority_task.task)()));

                            let duration = start.elapsed();
                            metrics_clone
                                .total_execution_time_ns
                                .fetch_add(duration.as_nanos() as u64, Ordering::Relaxed);
                            metrics_clone.active_workers.fetch_sub(1, Ordering::Relaxed);

                            if result.is_ok() {
                                metrics_clone
                                    .tasks_completed
                                    .fetch_add(1, Ordering::Relaxed);
                            } else {
                                metrics_clone.tasks_failed.fetch_add(1, Ordering::Relaxed);
                            }
                        }
                    }
                })
                .expect("Failed to create compute worker thread");
            workers.push(thread);
        }

        // wait for all workers to start
        barrier.wait();

        Self {
            workers,
            state,
            sequence: AtomicU64::new(0),
            metrics,
        }
    }

    /// Queues `task` at the given `priority`.
    ///
    /// Tasks run in descending priority order; within a priority they run in
    /// submission (FIFO) order thanks to the sequence counter.
    pub fn spawn<F>(&self, task: F, priority: TaskPriority)
    where
        F: FnOnce() + Send + 'static,
    {
        let sequence = self.sequence.fetch_add(1, Ordering::Relaxed);
        let priority_task = PriorityTask {
            task: Box::new(task),
            priority,
            sequence,
        };

        // Update submission count and the per-priority depth gauge. Note the
        // gauge is incremented before the actual push; the worker decrements
        // it after popping.
        self.metrics.tasks_submitted.fetch_add(1, Ordering::Relaxed);
        match priority {
            TaskPriority::Low => self.metrics.queue_depth_low.fetch_add(1, Ordering::Relaxed),
            TaskPriority::Normal => self
                .metrics
                .queue_depth_normal
                .fetch_add(1, Ordering::Relaxed),
            TaskPriority::High => self
                .metrics
                .queue_depth_high
                .fetch_add(1, Ordering::Relaxed),
            TaskPriority::Critical => self
                .metrics
                .queue_depth_critical
                .fetch_add(1, Ordering::Relaxed),
        };

        let mut queue = self.state.queue.lock();
        queue.push(priority_task);
        // Wake one sleeping worker; notifying while holding the lock closes
        // the window between a worker's emptiness check and its wait.
        self.state.condvar.notify_one();
    }

    /// Shared handle to the pool's metrics counters.
    pub fn metrics(&self) -> Arc<ComputePoolMetrics> {
        self.metrics.clone()
    }
}
360
361impl Drop for ComputeThreadPool {
362    fn drop(&mut self) {
363        self.state.shutdown.store(true, Ordering::SeqCst);
364        self.state.condvar.notify_all();
365
366        for worker in self.workers.drain(..) {
367            let _ = worker.join();
368        }
369    }
370}
371
#[cfg(test)]
mod tests {
    use std::{
        sync::atomic::{AtomicUsize, Ordering},
        sync::{Arc, Barrier, Mutex},
        time::Duration,
    };

    use super::*;

    #[test]
    fn test_thread_pool_creation() {
        let pool = ThreadPool::new(4);
        assert_eq!(pool.workers_len(), 4);
    }

    #[test]
    fn test_task_execution() {
        let pool = ThreadPool::new(2);
        let counter = Arc::new(AtomicUsize::new(0));
        let counter_clone = counter.clone();

        pool.exec(move || {
            counter_clone.fetch_add(1, Ordering::SeqCst);
        })
        .unwrap();

        // NOTE(review): sleep-based synchronization; could flake on a
        // heavily loaded machine.
        std::thread::sleep(Duration::from_millis(100));
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }
    #[test]
    fn test_multiple_tasks() {
        let pool = ThreadPool::new(4);
        let counter = Arc::new(AtomicUsize::new(0));

        for _ in 0..10 {
            let counter_clone = counter.clone();
            pool.exec(move || {
                counter_clone.fetch_add(1, Ordering::SeqCst);
            })
            .unwrap();
        }

        // NOTE(review): sleep-based synchronization; could flake on a
        // heavily loaded machine.
        std::thread::sleep(Duration::from_millis(200));
        assert_eq!(counter.load(Ordering::SeqCst), 10);
    }

    #[test]
    fn test_pool_cleanup() {
        let counter = Arc::new(AtomicUsize::new(0));
        {
            let pool = ThreadPool::new(2);
            let counter_clone = counter.clone();

            pool.exec(move || {
                std::thread::sleep(Duration::from_millis(50));
                counter_clone.fetch_add(1, Ordering::SeqCst);
            })
            .unwrap();
        }
        // Dropping the pool joins its workers, so the task must have
        // completed by this point.

        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn test_compute_pool_priority() {
        let pool = ComputeThreadPool::new(1); // Single thread to ensure order execution
        let result = Arc::new(Mutex::new(Vec::new()));

        // use a barrier to ensure the first task is running and blocking the worker
        let barrier = Arc::new(Barrier::new(2));
        let b_clone = barrier.clone();

        let r1 = result.clone();
        pool.spawn(
            move || {
                b_clone.wait(); // signal that we started
                std::thread::sleep(Duration::from_millis(50)); // block worker
                r1.lock().unwrap().push(1);
            },
            TaskPriority::Low,
        );

        // wait for Task 1 to start
        barrier.wait();

        // these should be queued while the first one runs
        let r2 = result.clone();
        pool.spawn(
            move || {
                r2.lock().unwrap().push(2);
            },
            TaskPriority::Low,
        );

        let r3 = result.clone();
        pool.spawn(
            move || {
                r3.lock().unwrap().push(3);
            },
            TaskPriority::High,
        );

        let r4 = result.clone();
        pool.spawn(
            move || {
                r4.lock().unwrap().push(4);
            },
            TaskPriority::Normal,
        );

        // wait for tasks to finish
        // NOTE(review): fixed sleep; assumes all queued tasks finish
        // within 200ms.
        std::thread::sleep(Duration::from_millis(200));

        let res = result.lock().unwrap();
        // 1 runs first (started immediately).
        // Then 3 (High), 4 (Normal), 2 (Low).
        assert_eq!(*res, vec![1, 3, 4, 2]);
    }

    #[test]
    fn test_compute_pool_metrics() {
        let pool = ComputeThreadPool::new(2);
        let metrics = pool.metrics();

        let barrier = Arc::new(Barrier::new(3)); // 2 workers + main thread
        let barrier_clone = barrier.clone();

        // Task 1: Occupy worker 1
        pool.spawn(
            move || {
                barrier_clone.wait(); // wait for main thread to check metrics
            },
            TaskPriority::Normal,
        );

        let barrier_clone2 = barrier.clone();
        // Task 2: Occupy worker 2
        pool.spawn(
            move || {
                barrier_clone2.wait(); // wait for main thread to check metrics
            },
            TaskPriority::Normal,
        );

        // wait a bit for workers to pick up tasks
        // NOTE(review): assumes both workers dequeue within 50ms.
        std::thread::sleep(Duration::from_millis(50));

        // Task 3: Queue (Low)
        pool.spawn(|| {}, TaskPriority::Low);

        // Task 4: Queue (High)
        pool.spawn(|| {}, TaskPriority::High);

        // check intermediate metrics
        assert_eq!(metrics.tasks_submitted(), 4);
        // both workers should be busy
        assert_eq!(metrics.active_workers(), 2);
        // queued tasks
        assert_eq!(metrics.queue_depth_low(), 1);
        assert_eq!(metrics.queue_depth_high(), 1);
        // running tasks are popped, so normal queue depth is 0
        assert_eq!(metrics.queue_depth_normal(), 0);

        barrier.wait();

        // wait for completion
        let start = std::time::Instant::now();
        while metrics.tasks_completed() < 4 {
            if start.elapsed() > Duration::from_secs(2) {
                panic!("Timed out waiting for tasks to complete");
            }
            std::thread::sleep(Duration::from_millis(10));
        }

        // check final metrics
        assert_eq!(metrics.tasks_completed(), 4);
        assert_eq!(metrics.active_workers(), 0);
        assert_eq!(metrics.queue_depth_low(), 0);
        assert_eq!(metrics.queue_depth_high(), 0);
        assert!(metrics.total_execution_time_ns() > 0);
    }
}