// radiate_core/domain/thread_pool.rs
//
// Basic thread pool and wait-group synchronization primitives.

1use std::{
2    fmt::Debug,
3    sync::{
4        Arc, Condvar, Mutex,
5        atomic::{AtomicUsize, Ordering},
6    },
7};
8use std::{sync::mpsc, thread};
9
/// `WorkResult` is a simple wrapper around a `Receiver` that allows the user to get
/// the result of a job that was executed in the thread pool. It kinda acts like
/// a `Future` in a synchronous way.
pub struct WorkResult<T> {
    receiver: mpsc::Receiver<T>,
}

impl<T> WorkResult<T> {
    /// Get the result of the job.
    /// **Note**: This method will block until the result is available.
    ///
    /// # Panics
    /// Panics if the producing side disconnected without sending a value
    /// (e.g. the job panicked before it could send its result).
    pub fn result(&self) -> T {
        self.receiver
            .recv()
            .expect("job disconnected before sending a result")
    }

    /// Non-panicking variant of [`WorkResult::result`].
    ///
    /// Blocks until the result is available and returns `Err` if the
    /// producing side disconnected before sending a value.
    pub fn try_result(&self) -> Result<T, mpsc::RecvError> {
        self.receiver.recv()
    }
}
24
/// A fixed-size pool of worker threads that execute submitted jobs.
pub struct ThreadPool {
    // Sending half of the shared job channel; every worker reads the other
    // end through a shared `Arc<Mutex<Receiver>>`.
    sender: mpsc::Sender<Message>,
    // Worker handles; their threads are joined when the pool is dropped.
    workers: Vec<Worker>,
}
29
30impl ThreadPool {
31    /// Basic thread pool implementation.
32    ///
33    /// Create a new ThreadPool with the given size.
34    pub fn new(size: usize) -> Self {
35        let (sender, receiver) = mpsc::channel();
36        let receiver = Arc::new(Mutex::new(receiver));
37
38        ThreadPool {
39            sender,
40            workers: (0..size)
41                .map(|id| Worker::new(id, Arc::clone(&receiver)))
42                .collect(),
43        }
44    }
45
46    pub fn group_submit(&self, wg: &WaitGroup, f: impl FnOnce() + Send + 'static) {
47        let guard = wg.guard();
48
49        self.submit(move || {
50            f();
51            drop(guard);
52        });
53    }
54
55    /// Execute a job in the thread pool. This is a 'fire and forget' method.
56    pub fn submit<F>(&self, f: F)
57    where
58        F: FnOnce() + Send + 'static,
59    {
60        let job = Box::new(f);
61        self.sender.send(Message::NewJob(job)).unwrap();
62    }
63
64    /// Execute a job in the thread pool and return a `WorkResult` that can be used to get the result of the job.
65    pub fn submit_with_result<F, T>(&self, f: F) -> WorkResult<T>
66    where
67        F: FnOnce() -> T + Send + 'static,
68        T: Send + 'static,
69    {
70        let (tx, rx) = mpsc::sync_channel(1);
71        let job = Box::new(move || tx.send(f()).unwrap());
72
73        self.sender.send(Message::NewJob(job)).unwrap();
74        WorkResult { receiver: rx }
75    }
76
77    pub fn num_workers(&self) -> usize {
78        self.workers.len()
79    }
80
81    pub fn is_alive(&self) -> bool {
82        self.workers.iter().any(|worker| worker.is_alive())
83    }
84}
85
86/// Drop implementation for ThreadPool. This will terminate all workers when the ThreadPool is dropped.
87/// We need to make sure that all workers are terminated before the ThreadPool is dropped.
88impl Drop for ThreadPool {
89    fn drop(&mut self) {
90        for _ in self.workers.iter() {
91            self.sender.send(Message::Terminate).unwrap();
92        }
93
94        for worker in self.workers.iter_mut() {
95            if let Some(thread) = worker.thread.take() {
96                thread.join().unwrap();
97            }
98        }
99
100        assert!(!self.is_alive());
101    }
102}
103
/// Job type that can be executed in the thread pool.
type Job = Box<dyn FnOnce() + Send + 'static>;

/// Message type that can be sent to the worker threads.
enum Message {
    NewJob(Job),
    Terminate,
}

/// Worker struct that listens for incoming `Message`s and executes the `Job`s or terminates.
struct Worker {
    id: usize,
    thread: Option<thread::JoinHandle<()>>,
}

impl Worker {
    /// Create a new Worker.
    ///
    /// The Worker will listen for incoming jobs on the given receiver.
    /// When a job is received, it will be executed in a new thread and the
    /// mutex will release allowing another job to be received from a different worker.
    fn new(id: usize, receiver: Arc<Mutex<mpsc::Receiver<Message>>>) -> Self {
        Worker {
            id,
            thread: Some(thread::spawn(move || {
                loop {
                    // The mutex guard is a temporary that only lives for this
                    // statement: the lock is held while blocking on `recv`,
                    // then released before the job runs, so other workers can
                    // wait on the channel while this one is busy.
                    let message = receiver.lock().unwrap().recv();

                    match message {
                        Ok(Message::NewJob(job)) => job(),
                        // Explicit shutdown, or the sending side was dropped
                        // without a Terminate: exit the loop cleanly instead
                        // of panicking on a disconnected channel.
                        // NOTE(review): a panicking job still kills this
                        // worker thread; panics are not caught here.
                        Ok(Message::Terminate) | Err(_) => break,
                    }
                }
            })),
        }
    }

    /// Simple check if the worker is alive. The thread is 'taken' when the worker is dropped.
    /// So if the thread is 'None' the worker is no longer alive.
    pub fn is_alive(&self) -> bool {
        self.thread.is_some()
    }
}
147
148impl Debug for Worker {
149    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
150        f.debug_struct("Worker")
151            .field("id", &self.id)
152            .field("is_alive", &self.is_alive())
153            .finish()
154    }
155}
156
/// A counter-based synchronization primitive: `guard()` increments a counter,
/// dropping the returned [`WaitGuard`] decrements it, and `wait()` blocks
/// until the counter reaches zero. Clones share the same underlying state.
#[derive(Clone)]
pub struct WaitGroup {
    inner: Arc<Inner>,
    // Cumulative number of guards ever created; never decremented.
    total_count: Arc<AtomicUsize>,
}

/// Shared state: the live-guard counter plus the mutex/condvar pair used to
/// park and wake waiters.
struct Inner {
    counter: AtomicUsize,
    lock: Mutex<()>,
    cvar: Condvar,
}

/// Scoped handle returned by [`WaitGroup::guard`]; decrements the group's
/// counter when dropped (including during unwinding).
pub struct WaitGuard {
    wg: WaitGroup,
}

impl Drop for WaitGuard {
    fn drop(&mut self) {
        // `fetch_sub` returns the previous value, so `== 1` means this was
        // the last outstanding guard.
        if self.wg.inner.counter.fetch_sub(1, Ordering::AcqRel) == 1 {
            // Take the lock before notifying so a waiter cannot check the
            // counter and block between our decrement and our notify
            // (classic lost-wakeup prevention).
            let _guard = self.wg.inner.lock.lock().unwrap();
            self.wg.inner.cvar.notify_all();
        }
    }
}

impl WaitGroup {
    /// Creates an empty wait group with both counters at zero.
    pub fn new() -> Self {
        Self {
            inner: Arc::new(Inner {
                counter: AtomicUsize::new(0),
                lock: Mutex::new(()),
                cvar: Condvar::new(),
            }),
            total_count: Arc::new(AtomicUsize::new(0)),
        }
    }

    /// Total number of guards ever created on this group.
    pub fn get_count(&self) -> usize {
        self.total_count.load(Ordering::Acquire)
    }

    /// Adds one to the counter and returns a scoped guard that will decrement when dropped.
    pub fn guard(&self) -> WaitGuard {
        self.inner.counter.fetch_add(1, Ordering::AcqRel);
        self.total_count.fetch_add(1, Ordering::AcqRel);
        WaitGuard { wg: self.clone() }
    }

    /// Waits until the counter reaches zero and returns the total number of
    /// guards created on this group.
    pub fn wait(&self) -> usize {
        if self.inner.counter.load(Ordering::Acquire) == 0 {
            // BUG FIX: this fast path used to return a hard-coded 0, while
            // the blocking path below returns `get_count()`. If all guards
            // completed before `wait` was called, callers saw 0 instead of
            // the real total. Return the same value on both paths.
            return self.get_count();
        }

        let lock = self.inner.lock.lock().unwrap();
        // Re-check the counter under the lock; `wait_while` handles spurious
        // wakeups. The returned guard is unused on purpose.
        let _unused = self
            .inner
            .cvar
            .wait_while(lock, |_| self.inner.counter.load(Ordering::Acquire) != 0);

        self.get_count()
    }
}

impl Default for WaitGroup {
    fn default() -> Self {
        Self::new()
    }
}
220
#[cfg(test)]
mod tests {
    use std::time::{Duration, Instant};

    use super::*;

    #[test]
    fn test_thread_pool_creation() {
        let pool = ThreadPool::new(4);
        assert!(pool.is_alive());
    }

    #[test]
    fn test_basic_job_execution() {
        let pool = ThreadPool::new(4);
        let counter = Arc::new(Mutex::new(0));
        let wg = WaitGroup::new();

        for _ in 0..8 {
            let counter = Arc::clone(&counter);
            pool.group_submit(&wg, move || {
                let mut num = counter.lock().unwrap();
                *num += 1;
            });
        }

        // Deterministic: block until every job signals completion instead of
        // sleeping a fixed second (which was both slow and racy).
        wg.wait();
        assert_eq!(*counter.lock().unwrap(), 8);
    }

    #[test]
    fn test_thread_pool() {
        let pool = ThreadPool::new(4);

        for i in 0..8 {
            pool.submit(move || {
                let start_time = std::time::SystemTime::now();
                println!("Job {} started.", i);
                // Short simulated work; dropping the pool below joins all
                // workers, so every job finishes before the test returns.
                thread::sleep(Duration::from_millis(100));
                println!("Job {} finished in {:?}.", i, start_time.elapsed().unwrap());
            });
        }
    }

    #[test]
    fn test_job_order() {
        let pool = ThreadPool::new(2);
        let results = Arc::new(Mutex::new(vec![]));
        let wg = WaitGroup::new();

        for i in 0..5 {
            let results = Arc::clone(&results);
            pool.group_submit(&wg, move || {
                results.lock().unwrap().push(i);
            });
        }

        // Wait for all five jobs instead of sleeping.
        wg.wait();
        let mut results = results.lock().unwrap();
        results.sort(); // Order may not be guaranteed
        assert_eq!(*results, vec![0, 1, 2, 3, 4]);
    }

    #[test]
    fn test_thread_pool_process() {
        let pool = ThreadPool::new(4);

        let results = pool.submit_with_result(|| {
            let start_time = std::time::SystemTime::now();
            println!("Job started.");
            // Simulated work; `result()` below blocks until it completes.
            thread::sleep(Duration::from_millis(200));
            println!("Job finished in {:?}.", start_time.elapsed().unwrap());
            42
        });

        let result = results.result();
        assert_eq!(result, 42);
    }

    #[test]
    fn test_max_concurrent_jobs() {
        let pool = ThreadPool::new(4);
        let (tx, rx) = mpsc::channel();
        let num_jobs = 20;
        let start_time = Instant::now();

        // Submit 20 jobs; each signals completion over the channel.
        for i in 0..num_jobs {
            let tx = tx.clone();
            pool.submit(move || {
                thread::sleep(Duration::from_millis(100));
                tx.send(i).unwrap();
            });
        }

        // Wait for all jobs to finish.
        let mut results = vec![];
        for _ in 0..num_jobs {
            results.push(rx.recv().unwrap());
        }

        // 20 x 100ms jobs on 4 workers ≈ 500ms; 3s is a generous bound that
        // still proves the jobs ran concurrently rather than serially (2s+).
        let elapsed = start_time.elapsed();
        assert!(elapsed < Duration::from_secs(3));
        assert_eq!(results.len(), num_jobs);
        assert!(results.iter().all(|&x| x < num_jobs));
    }
}