//! mill_io/thread_pool.rs — fixed-size round-robin and priority-scheduling thread pools.

1#[cfg(feature = "unstable-mpmc")]
2use std::sync::mpmc as channel;
3#[cfg(not(feature = "unstable-mpmc"))]
4use std::sync::mpsc as channel;
5use std::{
6    cmp::Ordering as CmpOrdering,
7    collections::BinaryHeap,
8    panic::{catch_unwind, AssertUnwindSafe},
9    sync::{
10        atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering},
11        Arc, Barrier, Condvar, Mutex,
12    },
13    thread::{Builder, JoinHandle},
14    time::Instant,
15};
16
17use crate::error::Result;
18
/// Fallback worker count used when `available_parallelism` cannot be queried.
pub const DEFAULT_POOL_CAPACITY: usize = 4;

/// A boxed, sendable, run-once unit of work.
pub type Task = Box<dyn FnOnce() + Send + 'static>;

/// Control messages delivered to a worker over its private channel.
enum WorkerMessage {
    /// Run the boxed task on the receiving worker.
    Task(Task),
    /// Stop the worker's receive loop; its thread then exits.
    Terminate,
}

/// Fixed-size pool that dispatches tasks round-robin over per-worker
/// channels; see `exec` for the dispatch rule.
pub struct ThreadPool {
    workers: Vec<Worker>,
    // One dedicated sender per worker; index-aligned with `workers`.
    senders: Vec<channel::Sender<WorkerMessage>>,
    // Monotonic counter used for round-robin worker selection in `exec`.
    next_worker: AtomicUsize,
}
33
34impl Default for ThreadPool {
35    fn default() -> Self {
36        let default_capacity = std::thread::available_parallelism()
37            .map(|n| n.get())
38            .unwrap_or(DEFAULT_POOL_CAPACITY);
39        Self::new(default_capacity)
40    }
41}
42
43impl ThreadPool {
44    pub fn new(capacity: usize) -> Self {
45        let mut workers = Vec::with_capacity(capacity);
46        let mut senders = Vec::with_capacity(capacity);
47
48        for id in 0..capacity {
49            let (sender, receiver) = channel::channel::<WorkerMessage>();
50            workers.push(Worker::new(id, receiver));
51            senders.push(sender);
52        }
53
54        Self {
55            workers,
56            senders,
57            next_worker: AtomicUsize::new(0),
58        }
59    }
60
61    pub fn exec<F>(&self, task: F) -> Result<()>
62    where
63        F: FnOnce() + Send + 'static,
64    {
65        // Round-robin dispatch
66        let index = self.next_worker.fetch_add(1, Ordering::Relaxed) % self.senders.len();
67        Ok(self.senders[index].send(WorkerMessage::Task(Box::new(task)))?)
68    }
69
70    pub fn workers_len(&self) -> usize {
71        self.workers.len()
72    }
73}
74
75impl Drop for ThreadPool {
76    fn drop(&mut self) {
77        for sender in &self.senders {
78            let _ = sender.send(WorkerMessage::Terminate);
79        }
80        for worker in &mut self.workers {
81            if let Some(t) = worker.take_thread() {
82                t.join().unwrap();
83            }
84        }
85    }
86}
87
/// Handle to a single pool thread; the `JoinHandle` is kept in an
/// `Option` so `Drop` can take and join it exactly once.
struct Worker {
    #[allow(dead_code)]
    id: usize,
    thread: Option<JoinHandle<()>>,
}
93
94impl Worker {
95    pub fn new(id: usize, receiver: channel::Receiver<WorkerMessage>) -> Self {
96        let thread = Some(
97            Builder::new()
98                .name(format!("thread-pool-worker-{id}"))
99                .spawn(move || {
100                    while let Ok(message) = receiver.recv() {
101                        match message {
102                            WorkerMessage::Task(task) => task(),
103                            WorkerMessage::Terminate => break,
104                        }
105                    }
106                })
107                .expect("Couldn't create the worker thread id={id}"),
108        );
109
110        Self { id, thread }
111    }
112
113    pub fn take_thread(&mut self) -> Option<JoinHandle<()>> {
114        self.thread.take()
115    }
116}
117
/// Scheduling priority for `ComputeThreadPool` tasks; higher discriminant
/// values are dequeued first (see `Ord for PriorityTask`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum TaskPriority {
    Low = 0,
    Normal = 1,
    High = 2,
    Critical = 3,
}
125
/// Lock-free counters describing `ComputeThreadPool` activity.
///
/// All fields are updated with relaxed atomics: each reading is exact on
/// its own, but a set of readings is not a consistent snapshot.
#[derive(Debug, Default)]
pub struct ComputePoolMetrics {
    // Tasks accepted by `spawn`.
    pub tasks_submitted: AtomicU64,
    // Tasks whose closure returned without panicking.
    pub tasks_completed: AtomicU64,
    // Tasks whose closure panicked (caught via `catch_unwind`).
    pub tasks_failed: AtomicU64,
    // Workers currently inside a task body.
    pub active_workers: AtomicUsize,
    // Tasks currently queued, per priority level (incremented in `spawn`,
    // decremented when a worker pops the task).
    pub queue_depth_low: AtomicUsize,
    pub queue_depth_normal: AtomicUsize,
    pub queue_depth_high: AtomicUsize,
    pub queue_depth_critical: AtomicUsize,
    // Cumulative wall-clock time spent inside task bodies.
    pub total_execution_time_ns: AtomicU64,
}
138
139impl ComputePoolMetrics {
140    pub fn tasks_submitted(&self) -> u64 {
141        self.tasks_submitted.load(Ordering::Relaxed)
142    }
143
144    pub fn tasks_completed(&self) -> u64 {
145        self.tasks_completed.load(Ordering::Relaxed)
146    }
147
148    pub fn tasks_failed(&self) -> u64 {
149        self.tasks_failed.load(Ordering::Relaxed)
150    }
151
152    pub fn active_workers(&self) -> usize {
153        self.active_workers.load(Ordering::Relaxed)
154    }
155
156    pub fn queue_depth_low(&self) -> usize {
157        self.queue_depth_low.load(Ordering::Relaxed)
158    }
159
160    pub fn queue_depth_normal(&self) -> usize {
161        self.queue_depth_normal.load(Ordering::Relaxed)
162    }
163
164    pub fn queue_depth_high(&self) -> usize {
165        self.queue_depth_high.load(Ordering::Relaxed)
166    }
167
168    pub fn queue_depth_critical(&self) -> usize {
169        self.queue_depth_critical.load(Ordering::Relaxed)
170    }
171
172    pub fn total_execution_time_ns(&self) -> u64 {
173        self.total_execution_time_ns.load(Ordering::Relaxed)
174    }
175}
176
/// A queued task plus the keys (`priority`, then FIFO `sequence`) used to
/// order it inside the scheduling `BinaryHeap`.
struct PriorityTask {
    task: Task,
    priority: TaskPriority,
    // Monotonic submission counter; breaks priority ties FIFO.
    sequence: u64,
}
182
183impl PartialEq for PriorityTask {
184    fn eq(&self, other: &Self) -> bool {
185        self.priority == other.priority && self.sequence == other.sequence
186    }
187}
188
// `PartialEq` compares only (priority, sequence), which is a full
// equivalence relation, so `Eq` holds.
impl Eq for PriorityTask {}
190
impl PartialOrd for PriorityTask {
    /// Delegates to the total order defined by `Ord`, keeping the two
    /// traits consistent as the `BinaryHeap` requires.
    fn partial_cmp(&self, other: &Self) -> Option<CmpOrdering> {
        Some(self.cmp(other))
    }
}
196
197impl Ord for PriorityTask {
198    fn cmp(&self, other: &Self) -> CmpOrdering {
199        match self.priority.cmp(&other.priority) {
200            CmpOrdering::Equal => other.sequence.cmp(&self.sequence),
201            ord => ord,
202        }
203    }
204}
205
/// State shared between the `ComputeThreadPool` handle and its workers.
struct ComputeSharedState {
    // Max-heap of pending tasks, ordered by `PriorityTask`'s `Ord`.
    queue: Mutex<BinaryHeap<PriorityTask>>,
    // Signalled when a task is pushed or shutdown is requested.
    condvar: Condvar,
    // Set once during pool drop; workers drain the queue and then exit.
    shutdown: AtomicBool,
}
211
/// Pool that executes tasks in priority order (see `TaskPriority`), FIFO
/// within a priority level, and records execution metrics.
pub struct ComputeThreadPool {
    workers: Vec<JoinHandle<()>>,
    state: Arc<ComputeSharedState>,
    // Source of the FIFO sequence numbers given to each `PriorityTask`.
    sequence: AtomicU64,
    metrics: Arc<ComputePoolMetrics>,
}
218
219impl Default for ComputeThreadPool {
220    fn default() -> Self {
221        let default_capacity = std::thread::available_parallelism()
222            .map(|n| n.get())
223            .unwrap_or(DEFAULT_POOL_CAPACITY);
224        Self::new(default_capacity)
225    }
226}
227
impl ComputeThreadPool {
    /// Creates a pool with `capacity` worker threads pulling from a shared
    /// priority queue. Does not return until every worker has started.
    pub fn new(capacity: usize) -> Self {
        let state = Arc::new(ComputeSharedState {
            queue: Mutex::new(BinaryHeap::new()),
            condvar: Condvar::new(),
            shutdown: AtomicBool::new(false),
        });
        let metrics = Arc::new(ComputePoolMetrics::default());

        let mut workers = Vec::with_capacity(capacity);
        // barrier to ensure all workers are started before returning
        // (`capacity` worker threads + this constructing thread)
        let barrier = Arc::new(Barrier::new(capacity + 1));

        for id in 0..capacity {
            let state_clone = Arc::clone(&state);
            let barrier_clone = Arc::clone(&barrier);
            let metrics_clone = Arc::clone(&metrics);
            let thread = Builder::new()
                .name(format!("compute-worker-{id}"))
                .spawn(move || {
                    // wait for all workers to be ready
                    barrier_clone.wait();

                    loop {
                        // This block scopes the mutex guard: the lock is
                        // released before the task body runs below.
                        let task = {
                            let mut queue = state_clone.queue.lock().unwrap();

                            // Re-check the predicate after every wakeup to
                            // tolerate spurious condvar wakeups.
                            while queue.is_empty() && !state_clone.shutdown.load(Ordering::Relaxed)
                            {
                                queue = state_clone.condvar.wait(queue).unwrap();
                            }

                            // Drain-then-exit: after shutdown, keep popping
                            // until the queue is empty so queued work runs.
                            if state_clone.shutdown.load(Ordering::Relaxed) && queue.is_empty() {
                                break;
                            }

                            let t = queue.pop();
                            // Undo the per-priority depth increment that
                            // `spawn` made for this task.
                            if let Some(ref pt) = t {
                                match pt.priority {
                                    TaskPriority::Low => metrics_clone
                                        .queue_depth_low
                                        .fetch_sub(1, Ordering::Relaxed),
                                    TaskPriority::Normal => metrics_clone
                                        .queue_depth_normal
                                        .fetch_sub(1, Ordering::Relaxed),
                                    TaskPriority::High => metrics_clone
                                        .queue_depth_high
                                        .fetch_sub(1, Ordering::Relaxed),
                                    TaskPriority::Critical => metrics_clone
                                        .queue_depth_critical
                                        .fetch_sub(1, Ordering::Relaxed),
                                };
                            }
                            t
                        };

                        if let Some(priority_task) = task {
                            metrics_clone.active_workers.fetch_add(1, Ordering::Relaxed);
                            let start = Instant::now();

                            // Catch panics so a failing task doesn't kill
                            // the worker; recorded as `tasks_failed` below.
                            let result = catch_unwind(AssertUnwindSafe(|| (priority_task.task)()));

                            let duration = start.elapsed();
                            metrics_clone
                                .total_execution_time_ns
                                .fetch_add(duration.as_nanos() as u64, Ordering::Relaxed);
                            metrics_clone.active_workers.fetch_sub(1, Ordering::Relaxed);

                            if result.is_ok() {
                                metrics_clone
                                    .tasks_completed
                                    .fetch_add(1, Ordering::Relaxed);
                            } else {
                                metrics_clone.tasks_failed.fetch_add(1, Ordering::Relaxed);
                            }
                        }
                    }
                })
                .expect("Failed to create compute worker thread");
            workers.push(thread);
        }

        // wait for all workers to start
        barrier.wait();

        Self {
            workers,
            state,
            sequence: AtomicU64::new(0),
            metrics,
        }
    }

    /// Enqueues `task` at the given `priority`. Equal-priority tasks run
    /// in submission (FIFO) order; this call never blocks on execution.
    pub fn spawn<F>(&self, task: F, priority: TaskPriority)
    where
        F: FnOnce() + Send + 'static,
    {
        let sequence = self.sequence.fetch_add(1, Ordering::Relaxed);
        let priority_task = PriorityTask {
            task: Box::new(task),
            priority,
            sequence,
        };

        // Depth gauges are bumped before the push: the worker's matching
        // decrement only happens after it pops the task, so the counters
        // cannot underflow.
        self.metrics.tasks_submitted.fetch_add(1, Ordering::Relaxed);
        match priority {
            TaskPriority::Low => self.metrics.queue_depth_low.fetch_add(1, Ordering::Relaxed),
            TaskPriority::Normal => self
                .metrics
                .queue_depth_normal
                .fetch_add(1, Ordering::Relaxed),
            TaskPriority::High => self
                .metrics
                .queue_depth_high
                .fetch_add(1, Ordering::Relaxed),
            TaskPriority::Critical => self
                .metrics
                .queue_depth_critical
                .fetch_add(1, Ordering::Relaxed),
        };

        let mut queue = self.state.queue.lock().unwrap();
        queue.push(priority_task);
        // Wake one sleeping worker to pick up the new task.
        self.state.condvar.notify_one();
    }

    /// Returns a shared handle to the pool's metrics counters.
    pub fn metrics(&self) -> Arc<ComputePoolMetrics> {
        self.metrics.clone()
    }
}
358
359impl Drop for ComputeThreadPool {
360    fn drop(&mut self) {
361        self.state.shutdown.store(true, Ordering::SeqCst);
362        self.state.condvar.notify_all();
363
364        for worker in self.workers.drain(..) {
365            let _ = worker.join();
366        }
367    }
368}
369
#[cfg(test)]
mod tests {
    use std::{
        sync::atomic::{AtomicUsize, Ordering},
        sync::{Arc, Barrier, Mutex},
        time::Duration,
    };

    use super::*;

    #[test]
    fn test_thread_pool_creation() {
        let pool = ThreadPool::new(4);
        assert_eq!(pool.workers_len(), 4);
    }

    #[test]
    fn test_task_execution() {
        let pool = ThreadPool::new(2);
        let counter = Arc::new(AtomicUsize::new(0));
        let counter_clone = counter.clone();

        pool.exec(move || {
            counter_clone.fetch_add(1, Ordering::SeqCst);
        })
        .unwrap();

        // NOTE(review): timing-based — assumes 100ms is enough for the
        // worker to run the task; could flake on a heavily loaded machine.
        std::thread::sleep(Duration::from_millis(100));
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }
    #[test]
    fn test_multiple_tasks() {
        let pool = ThreadPool::new(4);
        let counter = Arc::new(AtomicUsize::new(0));

        for _ in 0..10 {
            let counter_clone = counter.clone();
            pool.exec(move || {
                counter_clone.fetch_add(1, Ordering::SeqCst);
            })
            .unwrap();
        }

        // NOTE(review): timing-based; see test_task_execution.
        std::thread::sleep(Duration::from_millis(200));
        assert_eq!(counter.load(Ordering::SeqCst), 10);
    }

    #[test]
    fn test_pool_cleanup() {
        let counter = Arc::new(AtomicUsize::new(0));
        {
            let pool = ThreadPool::new(2);
            let counter_clone = counter.clone();

            pool.exec(move || {
                std::thread::sleep(Duration::from_millis(50));
                counter_clone.fetch_add(1, Ordering::SeqCst);
            })
            .unwrap();
        }
        // Deterministic: the pool's Drop joins every worker, so the task
        // must have finished by the time the scope above exits.

        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn test_compute_pool_priority() {
        let pool = ComputeThreadPool::new(1); // Single thread to ensure order execution
        let result = Arc::new(Mutex::new(Vec::new()));

        // use a barrier to ensure the first task is running and blocking the worker
        let barrier = Arc::new(Barrier::new(2));
        let b_clone = barrier.clone();

        let r1 = result.clone();
        pool.spawn(
            move || {
                b_clone.wait(); // signal that we started
                std::thread::sleep(Duration::from_millis(50)); // block worker
                r1.lock().unwrap().push(1);
            },
            TaskPriority::Low,
        );

        // wait for Task 1 to start
        barrier.wait();

        // these should be queued while the first one runs
        let r2 = result.clone();
        pool.spawn(
            move || {
                r2.lock().unwrap().push(2);
            },
            TaskPriority::Low,
        );

        let r3 = result.clone();
        pool.spawn(
            move || {
                r3.lock().unwrap().push(3);
            },
            TaskPriority::High,
        );

        let r4 = result.clone();
        pool.spawn(
            move || {
                r4.lock().unwrap().push(4);
            },
            TaskPriority::Normal,
        );

        // wait for tasks to finish
        // NOTE(review): relies on 200ms covering the 50ms sleep plus the
        // three queued tasks; could flake under heavy load.
        std::thread::sleep(Duration::from_millis(200));

        let res = result.lock().unwrap();
        // 1 runs first (started immediately).
        // Then 3 (High), 4 (Normal), 2 (Low).
        assert_eq!(*res, vec![1, 3, 4, 2]);
    }

    #[test]
    fn test_compute_pool_metrics() {
        let pool = ComputeThreadPool::new(2);
        let metrics = pool.metrics();

        let barrier = Arc::new(Barrier::new(3)); // 2 workers + main thread
        let barrier_clone = barrier.clone();

        // Task 1: Occupy worker 1
        pool.spawn(
            move || {
                barrier_clone.wait(); // wait for main thread to check metrics
            },
            TaskPriority::Normal,
        );

        let barrier_clone2 = barrier.clone();
        // Task 2: Occupy worker 2
        pool.spawn(
            move || {
                barrier_clone2.wait(); // wait for main thread to check metrics
            },
            TaskPriority::Normal,
        );

        // wait a bit for workers to pick up tasks
        // NOTE(review): timing assumption — 50ms must suffice for both
        // workers to dequeue; the active_workers assert below depends on it.
        std::thread::sleep(Duration::from_millis(50));

        // Task 3: Queue (Low)
        pool.spawn(|| {}, TaskPriority::Low);

        // Task 4: Queue (High)
        pool.spawn(|| {}, TaskPriority::High);

        // check intermediate metrics
        assert_eq!(metrics.tasks_submitted(), 4);
        // both workers should be busy
        assert_eq!(metrics.active_workers(), 2);
        // queued tasks
        assert_eq!(metrics.queue_depth_low(), 1);
        assert_eq!(metrics.queue_depth_high(), 1);
        // running tasks are popped, so normal queue depth is 0
        assert_eq!(metrics.queue_depth_normal(), 0);

        barrier.wait();

        // wait for completion (polling with a hard 2s timeout)
        let start = std::time::Instant::now();
        while metrics.tasks_completed() < 4 {
            if start.elapsed() > Duration::from_secs(2) {
                panic!("Timed out waiting for tasks to complete");
            }
            std::thread::sleep(Duration::from_millis(10));
        }

        // check final metrics
        assert_eq!(metrics.tasks_completed(), 4);
        assert_eq!(metrics.active_workers(), 0);
        assert_eq!(metrics.queue_depth_low(), 0);
        assert_eq!(metrics.queue_depth_high(), 0);
        assert!(metrics.total_execution_time_ns() > 0);
    }
}