// avila_parallel/thread_pool.rs

1//! Thread pool implementation
2
3use std::sync::{Arc, Mutex, Condvar};
4use std::sync::mpsc::{channel, Sender, Receiver};
5use std::thread;
6
/// A unit of work: any one-shot closure that can be sent across threads.
type Job = Box<dyn FnOnce() + Send + 'static>;

/// A thread pool for executing tasks concurrently.
///
/// Jobs are distributed to a fixed set of worker threads over an mpsc
/// channel. `active` counts jobs that have been submitted but not yet
/// finished (queued + running), so `wait` can block until the pool is
/// truly idle.
pub struct ThreadPool {
    workers: Vec<Worker>,
    sender: Sender<Message>,
    // Number of submitted-but-unfinished jobs.
    active: Arc<Mutex<usize>>,
    // Signalled each time a job finishes, waking `wait`.
    condvar: Arc<Condvar>,
}

/// Control messages delivered to workers.
enum Message {
    /// Run this job.
    NewJob(Job),
    /// Shut the worker down.
    Terminate,
}

impl ThreadPool {
    /// Create a new thread pool with the specified number of threads.
    ///
    /// # Panics
    /// Panics if `size` is zero.
    pub fn new(size: usize) -> Self {
        assert!(size > 0);

        let (sender, receiver) = channel();
        // All workers share a single receiver behind a mutex.
        let receiver = Arc::new(Mutex::new(receiver));
        let active = Arc::new(Mutex::new(0));
        let condvar = Arc::new(Condvar::new());

        let mut workers = Vec::with_capacity(size);
        for id in 0..size {
            workers.push(Worker::new(
                id,
                Arc::clone(&receiver),
                Arc::clone(&active),
                Arc::clone(&condvar),
            ));
        }

        ThreadPool {
            workers,
            sender,
            active,
            condvar,
        }
    }

    /// Execute a job on the thread pool.
    ///
    /// The job counter is incremented *before* the job is queued. This
    /// closes the race where `wait` could observe `active == 0` while
    /// jobs were still sitting in the channel, unclaimed by any worker,
    /// and return prematurely.
    pub fn execute<F>(&self, f: F)
    where
        F: FnOnce() + Send + 'static,
    {
        *self.active.lock().unwrap() += 1;
        let job = Box::new(f);
        self.sender.send(Message::NewJob(job)).unwrap();
    }

    /// Block until every submitted job has completed.
    pub fn wait(&self) {
        let mut active = self.active.lock().unwrap();
        // Condvar waits can wake spuriously, so always re-check the
        // predicate in a loop.
        while *active > 0 {
            active = self.condvar.wait(active).unwrap();
        }
    }

    /// Get number of worker threads.
    pub fn size(&self) -> usize {
        self.workers.len()
    }
}

impl Drop for ThreadPool {
    /// Ask every worker to shut down, then join them all so no thread
    /// outlives the pool.
    fn drop(&mut self) {
        for _ in &self.workers {
            self.sender.send(Message::Terminate).unwrap();
        }

        for worker in &mut self.workers {
            if let Some(thread) = worker.thread.take() {
                thread.join().unwrap();
            }
        }
    }
}

/// A single worker thread in the pool.
struct Worker {
    #[allow(dead_code)] // kept for debugging/identification
    id: usize,
    // `None` once the thread has been joined in `Drop`.
    thread: Option<thread::JoinHandle<()>>,
}

impl Worker {
    /// Spawn a worker that pulls messages off the shared receiver until
    /// it is told to terminate.
    fn new(
        id: usize,
        receiver: Arc<Mutex<Receiver<Message>>>,
        active: Arc<Mutex<usize>>,
        condvar: Arc<Condvar>,
    ) -> Worker {
        let thread = thread::spawn(move || loop {
            // The guard is a temporary, so the receiver lock is released
            // as soon as one message has been pulled off the channel —
            // other workers are not blocked while this job runs.
            let message = receiver.lock().unwrap().recv().unwrap();

            match message {
                Message::NewJob(job) => {
                    // Contain panics so a failing job cannot desync the
                    // job counter (which would hang `wait` forever) or
                    // kill the worker (which would make `Drop`'s join
                    // panic).
                    let _ = std::panic::catch_unwind(std::panic::AssertUnwindSafe(job));
                    // The counter was incremented by `execute`; decrement
                    // once the job is done and wake any `wait` callers.
                    *active.lock().unwrap() -= 1;
                    condvar.notify_all();
                }
                Message::Terminate => break,
            }
        });

        Worker {
            id,
            thread: Some(thread),
        }
    }
}
124
/// Get the number of available CPU cores.
///
/// Falls back to `1` when the degree of parallelism cannot be
/// determined (e.g. the platform does not expose it).
pub fn num_cpus() -> usize {
    match thread::available_parallelism() {
        Ok(n) => n.get(),
        Err(_) => 1,
    }
}
131
132/// Global thread pool (lazy static pattern)
133static mut GLOBAL_POOL: Option<ThreadPool> = None;
134static INIT: std::sync::Once = std::sync::Once::new();
135
136/// Get or initialize the global thread pool
137#[allow(static_mut_refs)]
138pub fn global_pool() -> &'static ThreadPool {
139    unsafe {
140        INIT.call_once(|| {
141            GLOBAL_POOL = Some(ThreadPool::new(num_cpus()));
142        });
143        GLOBAL_POOL.as_ref().unwrap()
144    }
145}