// radiate_core/domain/thread_pool.rs

1use std::{
2    fmt::Debug,
3    sync::{
4        Arc, Condvar, Mutex, OnceLock,
5        atomic::{AtomicUsize, Ordering},
6    },
7};
8use std::{sync::mpsc, thread};
9
10struct FixedThreadPool {
11    inner: Arc<ThreadPool>,
12}
13
14impl FixedThreadPool {
15    /// Returns the global instance of the registry.
16    pub(self) fn instance(num_workers: usize) -> &'static FixedThreadPool {
17        static INSTANCE: OnceLock<FixedThreadPool> = OnceLock::new();
18
19        INSTANCE.get_or_init(|| FixedThreadPool {
20            inner: Arc::new(ThreadPool::new(num_workers)),
21        })
22    }
23}
24
25pub fn get_thread_pool(num_workers: usize) -> Arc<ThreadPool> {
26    Arc::clone(&FixedThreadPool::instance(num_workers).inner)
27}
28
/// [WorkResult] is a simple wrapper around a `Receiver` that allows the user to get
/// the result of a job that was executed in the thread pool. It kinda acts like
/// a `Future` in a synchronous way.
#[must_use = "a WorkResult does nothing until `result` is called"]
pub struct WorkResult<T> {
    receiver: mpsc::Receiver<T>,
}

impl<T> WorkResult<T> {
    /// Get the result of the job.
    ///
    /// **Note**: This method will block until the result is available.
    ///
    /// # Panics
    /// Panics if the worker disconnected without sending a value — e.g. the
    /// job itself panicked before it could produce a result.
    pub fn result(&self) -> T {
        self.receiver
            .recv()
            .expect("worker dropped the result channel before sending a value")
    }
}
43
/// A fixed-size pool of worker threads fed by a single job channel.
pub struct ThreadPool {
    // Sending half of the job channel; every worker shares the receiving half.
    sender: mpsc::Sender<Message>,
    // Worker handles, retained so `Drop` can join every thread on shutdown.
    workers: Vec<Worker>,
}
48
49impl ThreadPool {
50    /// Basic thread pool implementation.
51    ///
52    /// Create a new ThreadPool with the given size.
53    pub fn new(size: usize) -> Self {
54        let (sender, receiver) = mpsc::channel();
55        let receiver = Arc::new(Mutex::new(receiver));
56
57        ThreadPool {
58            sender,
59            workers: (0..size)
60                .map(|id| Worker::new(id, Arc::clone(&receiver)))
61                .collect(),
62        }
63    }
64
65    pub fn group_submit(&self, wg: &WaitGroup, f: impl FnOnce() + Send + 'static) {
66        let guard = wg.guard();
67
68        self.submit(move || {
69            f();
70            drop(guard);
71        });
72    }
73
74    /// Execute a job in the thread pool. This is a 'fire and forget' method.
75    pub fn submit<F>(&self, f: F)
76    where
77        F: FnOnce() + Send + 'static,
78    {
79        let job = Box::new(f);
80        self.sender.send(Message::NewJob(job)).unwrap();
81    }
82
83    /// Execute a job in the thread pool and return a `WorkResult` that can be used to get the result of the job.
84    pub fn submit_with_result<F, T>(&self, f: F) -> WorkResult<T>
85    where
86        F: FnOnce() -> T + Send + 'static,
87        T: Send + 'static,
88    {
89        let (tx, rx) = mpsc::sync_channel(1);
90        let job = Box::new(move || tx.send(f()).unwrap());
91
92        self.sender.send(Message::NewJob(job)).unwrap();
93        WorkResult { receiver: rx }
94    }
95
96    pub fn num_workers(&self) -> usize {
97        self.workers.len()
98    }
99
100    pub fn is_alive(&self) -> bool {
101        self.workers.iter().any(|worker| worker.is_alive())
102    }
103}
104
105/// Drop implementation for ThreadPool. This will terminate all workers when the ThreadPool is dropped.
106/// We need to make sure that all workers are terminated before the ThreadPool is dropped.
107impl Drop for ThreadPool {
108    fn drop(&mut self) {
109        for _ in self.workers.iter() {
110            self.sender.send(Message::Terminate).unwrap();
111        }
112
113        for worker in self.workers.iter_mut() {
114            if let Some(thread) = worker.thread.take() {
115                thread.join().unwrap();
116            }
117        }
118
119        assert!(!self.is_alive());
120    }
121}
122
/// Job type that can be executed in the thread pool.
type Job = Box<dyn FnOnce() + Send + 'static>;

/// Message type that can be sent to the worker threads.
enum Message {
    NewJob(Job),
    Terminate,
}

/// Worker struct that listens for incoming `Message`s and executes the `Job`s or terminates.
struct Worker {
    id: usize,
    thread: Option<thread::JoinHandle<()>>,
}

impl Worker {
    /// Create a new Worker.
    ///
    /// The Worker will listen for incoming jobs on the given receiver.
    /// When a job is received, the mutex is released before the job runs,
    /// allowing another job to be received by a different worker.
    fn new(id: usize, receiver: Arc<Mutex<mpsc::Receiver<Message>>>) -> Self {
        Worker {
            id,
            thread: Some(thread::spawn(move || {
                loop {
                    // The lock guard is a temporary dropped at the end of this
                    // statement — i.e. *before* the job executes.
                    let message = receiver.lock().unwrap().recv();

                    match message {
                        Ok(Message::NewJob(job)) => job(),
                        // Exit cleanly on an explicit Terminate *or* when the
                        // sending side has been dropped. The previous
                        // `recv().unwrap()` would panic the worker thread on
                        // disconnect instead of shutting down gracefully.
                        Ok(Message::Terminate) | Err(_) => break,
                    }
                }
            })),
        }
    }

    /// Simple check if the worker is alive. The thread is 'taken' when the worker is dropped.
    /// So if the thread is 'None' the worker is no longer alive.
    pub fn is_alive(&self) -> bool {
        self.thread.is_some()
    }
}
166
167impl Debug for Worker {
168    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
169        f.debug_struct("Worker")
170            .field("id", &self.id)
171            .field("is_alive", &self.is_alive())
172            .finish()
173    }
174}
175
/// Counter-based synchronization primitive: each in-flight unit of work holds
/// a [WaitGuard]; `wait` blocks until every guard has been dropped.
///
/// Cloning a [WaitGroup] yields a handle to the *same* group (state is shared
/// behind `Arc`s).
#[derive(Clone)]
pub struct WaitGroup {
    inner: Arc<Inner>,
    // Total number of guards ever created; never decremented. Reported by
    // `get_count` and `wait`.
    total_count: Arc<AtomicUsize>,
}

/// Shared state: the live-guard counter plus the lock/condvar pair used to
/// block and wake waiters.
struct Inner {
    counter: AtomicUsize,
    lock: Mutex<()>,
    cvar: Condvar,
}

/// Scoped handle representing one in-flight unit of work; dropping it
/// decrements the group's live counter.
pub struct WaitGuard {
    wg: WaitGroup,
}

impl Drop for WaitGuard {
    fn drop(&mut self) {
        // If this was the last outstanding guard, wake all waiters. The lock
        // is taken before notifying so a waiter in `wait_while` cannot miss
        // the wakeup between its predicate check and going to sleep.
        if self.wg.inner.counter.fetch_sub(1, Ordering::AcqRel) == 1 {
            let _guard = self.wg.inner.lock.lock().unwrap();
            self.wg.inner.cvar.notify_all();
        }
    }
}

impl Default for WaitGroup {
    fn default() -> Self {
        Self::new()
    }
}

impl WaitGroup {
    /// Creates an empty group: counter at zero, nothing to wait for.
    pub fn new() -> Self {
        Self {
            inner: Arc::new(Inner {
                counter: AtomicUsize::new(0),
                lock: Mutex::new(()),
                cvar: Condvar::new(),
            }),
            total_count: Arc::new(AtomicUsize::new(0)),
        }
    }

    /// Total number of guards ever created for this group.
    pub fn get_count(&self) -> usize {
        self.total_count.load(Ordering::Acquire)
    }

    /// Adds one to the counter and returns a scoped guard that will decrement when dropped.
    pub fn guard(&self) -> WaitGuard {
        self.inner.counter.fetch_add(1, Ordering::AcqRel);
        self.total_count.fetch_add(1, Ordering::AcqRel);
        WaitGuard { wg: self.clone() }
    }

    /// Waits until the counter reaches zero, then returns the total number of
    /// guards ever created (`get_count`).
    pub fn wait(&self) -> usize {
        // Fast path: nothing outstanding. This previously returned a
        // hard-coded 0, which disagreed with the blocking path below — the
        // return value depended on whether the caller happened to block.
        if self.inner.counter.load(Ordering::Acquire) == 0 {
            return self.get_count();
        }

        let lock = self.inner.lock.lock().unwrap();
        let _unused = self
            .inner
            .cvar
            .wait_while(lock, |_| self.inner.counter.load(Ordering::Acquire) != 0);

        self.get_count()
    }
}
239
#[cfg(test)]
mod tests {
    use std::time::{Duration, Instant};

    use super::*;

    #[test]
    fn test_thread_pool_creation() {
        let pool = ThreadPool::new(4);
        assert!(pool.is_alive());
    }

    #[test]
    fn test_basic_job_execution() {
        let pool = ThreadPool::new(4);
        let wg = WaitGroup::new();
        let counter = Arc::new(Mutex::new(0));

        for _ in 0..8 {
            let counter = Arc::clone(&counter);
            // group_submit gives a deterministic completion signal, replacing
            // the previous fixed 1-second sleep (slow and racy).
            pool.group_submit(&wg, move || {
                let mut num = counter.lock().unwrap();
                *num += 1;
            });
        }

        wg.wait();
        assert_eq!(*counter.lock().unwrap(), 8);
    }

    #[test]
    fn test_thread_pool() {
        let pool = ThreadPool::new(4);

        for i in 0..8 {
            pool.submit(move || {
                // Instant is monotonic — SystemTime can go backwards and is
                // the wrong clock for measuring elapsed time.
                let start_time = Instant::now();
                println!("Job {} started.", i);
                thread::sleep(Duration::from_millis(50));
                println!("Job {} finished in {:?}.", i, start_time.elapsed());
            });
        }
    }

    #[test]
    fn test_job_order() {
        let pool = ThreadPool::new(2);
        let wg = WaitGroup::new();
        let results = Arc::new(Mutex::new(vec![]));

        for i in 0..5 {
            let results = Arc::clone(&results);
            pool.group_submit(&wg, move || {
                results.lock().unwrap().push(i);
            });
        }

        // Deterministic wait instead of a fixed sleep.
        wg.wait();
        let mut results = results.lock().unwrap();
        results.sort(); // Order may not be guaranteed
        assert_eq!(*results, vec![0, 1, 2, 3, 4]);
    }

    #[test]
    fn test_thread_pool_process() {
        let pool = ThreadPool::new(4);

        let results = pool.submit_with_result(|| {
            let start_time = Instant::now();
            println!("Job started.");
            thread::sleep(Duration::from_millis(100));
            println!("Job finished in {:?}.", start_time.elapsed());
            42
        });

        let result = results.result();
        assert_eq!(result, 42);
    }

    #[test]
    fn test_max_concurrent_jobs() {
        let pool = ThreadPool::new(4);
        let (tx, rx) = mpsc::channel();
        let num_jobs = 20;
        let start_time = Instant::now();

        // Submit 20 jobs
        for i in 0..num_jobs {
            let tx = tx.clone();
            pool.submit(move || {
                thread::sleep(Duration::from_millis(100));
                tx.send(i).unwrap();
            });
        }

        // Wait for all jobs to finish
        let mut results = vec![];
        for _ in 0..num_jobs {
            results.push(rx.recv().unwrap());
        }

        // 20 jobs x 100 ms on 4 workers is ~500 ms; 3 s is a generous bound
        // proving the jobs ran concurrently rather than serially (2 s+).
        let elapsed = start_time.elapsed();
        assert!(elapsed < Duration::from_secs(3));
        assert_eq!(results.len(), num_jobs);
        assert!(results.iter().all(|&x| x < num_jobs));
    }
}
346}