Skip to main content

radiate_core/domain/sync/
thread_pool.rs

1use std::{
2    fmt::Debug,
3    sync::{Arc, Mutex, OnceLock},
4};
5use std::{sync::mpsc, thread};
6
/// A fixed-size thread pool implementation. This thread pool will create a fixed number of worker threads
/// that will be reused for executing jobs. This is useful for limiting the number of concurrent threads
/// in the application.
///
/// The thread pool within the `FixedThreadPool` is created only once and will be reused for the lifetime of the program.
/// Meaning that the first time you request a thread pool with a specific number of workers, that number will be used.
/// Subsequent requests with different numbers will be ignored.
struct FixedThreadPool {
    // Shared handle to the single global pool; cloned out by `get_thread_pool`.
    inner: Arc<ThreadPool>,
}
17
18impl FixedThreadPool {
19    /// Returns the global instance of the threadpool.
20    ///
21    /// This thread pool is fixed in size and will be created only once. This means that
22    /// the first time you call this method with a specific number of workers, that number will be used
23    /// for the lifetime of the program. Subsequent calls with different numbers will be ignored.
24    pub(self) fn instance(num_workers: usize) -> &'static FixedThreadPool {
25        static INSTANCE: OnceLock<FixedThreadPool> = OnceLock::new();
26
27        INSTANCE.get_or_init(|| FixedThreadPool {
28            inner: Arc::new(ThreadPool::new(num_workers)),
29        })
30    }
31}
32
33pub fn get_thread_pool(num_workers: usize) -> Arc<ThreadPool> {
34    Arc::clone(&FixedThreadPool::instance(num_workers).inner)
35}
36
/// [WorkResult] wraps the receiving half of a channel so the caller can
/// retrieve the output of a job submitted to the thread pool. It behaves
/// like a `Future`, but resolved synchronously by blocking.
pub struct WorkResult<T> {
    receiver: mpsc::Receiver<T>,
}

impl<T> WorkResult<T> {
    /// Wrap a receiver whose sending half is held by the submitted job.
    pub fn new(rx: mpsc::Receiver<T>) -> Self {
        Self { receiver: rx }
    }

    /// Get the result of the job.
    /// **Note**: This method will block until the result is available.
    pub fn result(&self) -> T {
        let received = self.receiver.recv();
        received.unwrap()
    }
}
54
/// A basic thread pool: jobs are pushed onto a single channel and pulled
/// off by a fixed set of long-lived worker threads.
pub struct ThreadPool {
    // Sending half of the job queue; the workers share the receiving half.
    sender: mpsc::Sender<Message>,
    // Handles to the worker threads, joined when the pool is dropped.
    workers: Vec<Worker>,
}
59
60impl ThreadPool {
61    /// Basic thread pool implementation.
62    ///
63    /// Create a new ThreadPool with the given size.
64    pub fn new(size: usize) -> Self {
65        let (sender, receiver) = mpsc::channel();
66        let receiver = Arc::new(Mutex::new(receiver));
67
68        ThreadPool {
69            sender,
70            workers: (0..size)
71                .map(|id| Worker::new(id, Arc::clone(&receiver)))
72                .collect(),
73        }
74    }
75
76    pub fn num_workers(&self) -> usize {
77        self.workers.len()
78    }
79
80    pub fn is_alive(&self) -> bool {
81        self.workers.iter().any(|worker| worker.is_alive())
82    }
83
84    /// Execute a job in the thread pool. This method does not return anything
85    /// and as such can be thought of as a 'fire-and-forget' job submission.
86    ///
87    /// # Example
88    /// ```rust,ignore
89    /// use radiate_core::domain::thread_pool::ThreadPool;
90    /// use std::sync::{Arc, Mutex};
91    ///
92    /// let pool = ThreadPool::new(4);
93    /// let counter = Arc::new(Mutex::new(0));
94    ///
95    /// for _ in 0..8 {
96    ///     let counter = Arc::clone(&counter);
97    ///     pool.submit(move || {
98    ///         let mut num = counter.lock().unwrap();
99    ///         *num += 1;
100    ///     });
101    /// }
102    ///
103    /// // Drop the pool to join all threads
104    /// drop(pool);
105    ///
106    /// assert_eq!(*counter.lock().unwrap(), 8);
107    /// ```
108    pub fn submit<F>(&self, f: F)
109    where
110        F: FnOnce() + Send + 'static,
111    {
112        let job = Box::new(f);
113        self.sender.send(Message::Work(job)).unwrap();
114    }
115
116    /// Execute a job in the thread pool and return a [WorkResult]
117    /// that can be used to get the result of the job. This method
118    /// is similar to a 'future' in that it allows the user to get
119    /// the result of the job at a later time. It should be noted that the [WorkResult]
120    /// will block when calling `result()` until the job is complete.
121    ///
122    /// # Example
123    /// ```rust,ignore
124    /// use radiate_core::domain::thread_pool::ThreadPool;
125    ///
126    /// let pool = ThreadPool::new(4);
127    /// let work_result = pool.submit_with_result(|| 10 + 32);
128    ///
129    /// // Drop the pool to join all threads
130    /// drop(pool);
131    ///
132    /// let result = work_result.result();
133    /// assert_eq!(result, 42);
134    /// ```
135    pub fn submit_with_result<F, T>(&self, f: F) -> WorkResult<T>
136    where
137        F: FnOnce() -> T + Send + 'static,
138        T: Send + 'static,
139    {
140        let (tx, rx) = mpsc::sync_channel(1);
141        let job = Box::new(move || tx.send(f()).unwrap());
142
143        self.sender.send(Message::Work(job)).unwrap();
144
145        WorkResult { receiver: rx }
146    }
147}
148
149/// Drop implementation for ThreadPool. This will terminate all workers when the ThreadPool is dropped.
150/// We need to make sure that all workers are terminated before the ThreadPool is dropped.
151impl Drop for ThreadPool {
152    fn drop(&mut self) {
153        for _ in self.workers.iter() {
154            self.sender.send(Message::Terminate).unwrap();
155        }
156
157        for worker in self.workers.iter_mut() {
158            if let Some(thread) = worker.thread.take() {
159                thread.join().unwrap();
160            }
161        }
162
163        assert!(!self.is_alive());
164    }
165}
166
/// Job type that can be executed in the thread pool.
type Job = Box<dyn FnOnce() + Send + 'static>;

/// Message type that can be sent to the worker threads.
enum Message {
    // A job to run on whichever worker receives it first.
    Work(Job),
    // Tell the receiving worker to exit its loop.
    Terminate,
}
175
/// Worker struct that listens for incoming `Message`s and executes the `Job`s or terminates.
struct Worker {
    // Identifier for this worker; used only in the Debug output.
    id: usize,
    // Join handle for the worker thread; `None` once taken (joined) on drop.
    thread: Option<thread::JoinHandle<()>>,
}
181
182impl Worker {
183    /// Create a new Worker.
184    ///
185    /// Runs jobs on a long-lived worker thread that pulls tasks from the queue.
186    fn new(id: usize, receiver: Arc<Mutex<mpsc::Receiver<Message>>>) -> Self {
187        Worker {
188            id,
189            thread: Some(thread::spawn(move || {
190                loop {
191                    let message = receiver.lock().unwrap().recv().unwrap();
192
193                    match message {
194                        Message::Work(job) => job(),
195                        Message::Terminate => break,
196                    }
197                }
198            })),
199        }
200    }
201
202    /// Simple check if the worker is alive. The thread is 'taken' when the worker is dropped.
203    /// So if the thread is 'None' the worker is no longer alive.
204    pub fn is_alive(&self) -> bool {
205        self.thread.is_some()
206    }
207}
208
209impl Debug for Worker {
210    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
211        f.debug_struct("Worker")
212            .field("id", &self.id)
213            .field("is_alive", &self.is_alive())
214            .finish()
215    }
216}
217
#[cfg(test)]
mod tests {
    use super::*;
    use crate::WaitGroup;
    use std::time::{Duration, Instant};

    #[test]
    fn test_thread_pool_creation() {
        let pool = ThreadPool::new(4);
        assert!(pool.is_alive());
    }

    #[test]
    fn test_basic_job_execution() {
        let pool = ThreadPool::new(4);
        let counter = Arc::new(Mutex::new(0));

        for _ in 0..8 {
            let counter = Arc::clone(&counter);
            pool.submit(move || {
                let mut num = counter.lock().unwrap();
                *num += 1;
            });
        }

        // Dropping the pool joins every worker thread, which guarantees all
        // submitted jobs have run — deterministic, unlike a fixed sleep.
        drop(pool);
        assert_eq!(*counter.lock().unwrap(), 8);
    }

    #[test]
    fn test_thread_pool() {
        let pool = ThreadPool::new(4);

        for i in 0..8 {
            pool.submit(move || {
                let start_time = std::time::SystemTime::now();
                println!("Job {} started.", i);
                thread::sleep(Duration::from_secs(1));
                println!("Job {} finished in {:?}.", i, start_time.elapsed().unwrap());
            });
        }
    }

    #[test]
    fn test_job_order() {
        let pool = ThreadPool::new(2);
        let results = Arc::new(Mutex::new(vec![]));

        for i in 0..5 {
            let results = Arc::clone(&results);
            pool.submit(move || {
                results.lock().unwrap().push(i);
            });
        }

        // Join all workers by dropping the pool so every job has completed.
        drop(pool);
        let mut results = results.lock().unwrap();
        results.sort(); // Order may not be guaranteed
        assert_eq!(*results, vec![0, 1, 2, 3, 4]);
    }

    #[test]
    fn test_thread_pool_process() {
        let pool = ThreadPool::new(4);

        let results = pool.submit_with_result(|| {
            let start_time = std::time::SystemTime::now();
            println!("Job started.");
            thread::sleep(Duration::from_secs(2));
            println!("Job finished in {:?}.", start_time.elapsed().unwrap());
            42
        });

        let result = results.result();
        assert_eq!(result, 42);
    }

    #[test]
    fn test_max_concurrent_jobs() {
        let pool = ThreadPool::new(4);
        let (tx, rx) = mpsc::channel();
        let num_jobs = 20;
        let start_time = Instant::now();

        // Submit 20 jobs
        for i in 0..num_jobs {
            let tx = tx.clone();
            pool.submit(move || {
                thread::sleep(Duration::from_millis(100));
                tx.send(i).unwrap();
            });
        }

        // Wait for all jobs to finish
        let mut results = vec![];
        for _ in 0..num_jobs {
            results.push(rx.recv().unwrap());
        }

        // 20 jobs of 100ms on 4 workers is ~500ms of wall time; 3s is a
        // generous upper bound that still proves the jobs ran concurrently.
        let elapsed = start_time.elapsed();
        assert!(elapsed < Duration::from_secs(3));
        assert_eq!(results.len(), num_jobs);
        assert!(results.iter().all(|&x| x < num_jobs));
    }

    #[test]
    fn tests_thread_pool_submit_with_result_returns_correct_order() {
        let pool = ThreadPool::new(5);
        let num_jobs = 10;
        let mut work_results = vec![];

        // Later jobs sleep less, so they finish out of submission order —
        // each WorkResult must still pair with its own job's output.
        for i in 0..num_jobs {
            let work_result = pool.submit_with_result(move || {
                thread::sleep(Duration::from_millis(50 * (num_jobs - i) as u64));
                i * i
            });
            work_results.push(work_result);
        }

        for (i, work_result) in work_results.into_iter().enumerate() {
            let result = work_result.result();
            assert_eq!(result, i * i);
        }
    }

    #[test]
    fn test_wait_group() {
        let pool = ThreadPool::new(4);
        let wg = WaitGroup::new();
        let num_tasks = 10;
        let total = Arc::new(Mutex::new(0));

        for _ in 0..num_tasks {
            let guard = wg.guard();
            let total = Arc::clone(&total);
            pool.submit(move || {
                thread::sleep(Duration::from_millis(100));
                let mut num = total.lock().unwrap();
                *num += 1;
                drop(guard);
            });
        }

        // Not all tasks should be done yet - so the total should be less than
        // num_tasks. (Each task sleeps 100ms, so reaching this check first is
        // effectively certain, though technically timing-dependent.)
        {
            let total = total.lock().unwrap();
            assert_ne!(*total, num_tasks);
        }

        let total_tasks_waited_for = wg.wait();

        // Now all tasks should be done - so the total should equal num_tasks
        let total = total.lock().unwrap();
        assert_eq!(*total, num_tasks);
        assert_eq!(total_tasks_waited_for, num_tasks);
    }

    #[test]
    fn test_wait_group_zero_tasks() {
        let wg = WaitGroup::new();
        let total_tasks_waited_for = wg.wait();
        assert_eq!(total_tasks_waited_for, 0);
    }
}