radiate_core/domain/sync/
thread_pool.rs1use std::{
2 fmt::Debug,
3 sync::{Arc, Mutex, OnceLock},
4};
5use std::{sync::mpsc, thread};
6
7struct FixedThreadPool {
15 inner: Arc<ThreadPool>,
16}
17
18impl FixedThreadPool {
19 pub(self) fn instance(num_workers: usize) -> &'static FixedThreadPool {
25 static INSTANCE: OnceLock<FixedThreadPool> = OnceLock::new();
26
27 INSTANCE.get_or_init(|| FixedThreadPool {
28 inner: Arc::new(ThreadPool::new(num_workers)),
29 })
30 }
31}
32
33pub fn get_thread_pool(num_workers: usize) -> Arc<ThreadPool> {
34 Arc::clone(&FixedThreadPool::instance(num_workers).inner)
35}
36
/// Handle to the value produced by a job submitted through
/// [`ThreadPool::submit_with_result`].
pub struct WorkResult<T> {
    receiver: mpsc::Receiver<T>,
}

impl<T> WorkResult<T> {
    /// Wraps the receiving end of the channel the job will send its value on.
    pub fn new(rx: mpsc::Receiver<T>) -> Self {
        WorkResult { receiver: rx }
    }

    /// Blocks until the job sends its value, then returns it.
    ///
    /// # Panics
    ///
    /// Panics if the sending side has been dropped and no value is buffered.
    /// For example, the job panicked, or the value was already received by an
    /// earlier call. Use [`WorkResult::try_result`] to handle that case without
    /// panicking.
    pub fn result(&self) -> T {
        self.try_result()
            .expect("WorkResult: job finished without producing a value (did it panic?)")
    }

    /// Blocks until the job sends its value, returning `Err` instead of
    /// panicking if no value will ever arrive.
    ///
    /// # Errors
    ///
    /// Returns [`mpsc::RecvError`] when the sending side has been dropped and
    /// no value is buffered.
    pub fn try_result(&self) -> Result<T, mpsc::RecvError> {
        self.receiver.recv()
    }
}
54
55pub struct ThreadPool {
56 sender: mpsc::Sender<Message>,
57 workers: Vec<Worker>,
58}
59
60impl ThreadPool {
61 pub fn new(size: usize) -> Self {
65 let (sender, receiver) = mpsc::channel();
66 let receiver = Arc::new(Mutex::new(receiver));
67
68 ThreadPool {
69 sender,
70 workers: (0..size)
71 .map(|id| Worker::new(id, Arc::clone(&receiver)))
72 .collect(),
73 }
74 }
75
76 pub fn num_workers(&self) -> usize {
77 self.workers.len()
78 }
79
80 pub fn is_alive(&self) -> bool {
81 self.workers.iter().any(|worker| worker.is_alive())
82 }
83
84 pub fn submit<F>(&self, f: F)
109 where
110 F: FnOnce() + Send + 'static,
111 {
112 let job = Box::new(f);
113 self.sender.send(Message::Work(job)).unwrap();
114 }
115
116 pub fn submit_with_result<F, T>(&self, f: F) -> WorkResult<T>
136 where
137 F: FnOnce() -> T + Send + 'static,
138 T: Send + 'static,
139 {
140 let (tx, rx) = mpsc::sync_channel(1);
141 let job = Box::new(move || tx.send(f()).unwrap());
142
143 self.sender.send(Message::Work(job)).unwrap();
144
145 WorkResult { receiver: rx }
146 }
147}
148
149impl Drop for ThreadPool {
152 fn drop(&mut self) {
153 for _ in self.workers.iter() {
154 self.sender.send(Message::Terminate).unwrap();
155 }
156
157 for worker in self.workers.iter_mut() {
158 if let Some(thread) = worker.thread.take() {
159 thread.join().unwrap();
160 }
161 }
162
163 assert!(!self.is_alive());
164 }
165}
166
/// A boxed, type-erased closure that a worker runs exactly once.
type Job = Box<dyn FnOnce() + Send + 'static>;

/// Messages sent from a [`ThreadPool`] to its workers over the shared queue.
enum Message {
    /// Run the contained job on the receiving worker.
    Work(Job),
    /// Make the receiving worker leave its loop so its thread can be joined.
    Terminate,
}
175
176struct Worker {
178 id: usize,
179 thread: Option<thread::JoinHandle<()>>,
180}
181
182impl Worker {
183 fn new(id: usize, receiver: Arc<Mutex<mpsc::Receiver<Message>>>) -> Self {
187 Worker {
188 id,
189 thread: Some(thread::spawn(move || {
190 loop {
191 let message = receiver.lock().unwrap().recv().unwrap();
192
193 match message {
194 Message::Work(job) => job(),
195 Message::Terminate => break,
196 }
197 }
198 })),
199 }
200 }
201
202 pub fn is_alive(&self) -> bool {
205 self.thread.is_some()
206 }
207}
208
209impl Debug for Worker {
210 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
211 f.debug_struct("Worker")
212 .field("id", &self.id)
213 .field("is_alive", &self.is_alive())
214 .finish()
215 }
216}
217
#[cfg(test)]
mod tests {
    use super::*;
    use crate::WaitGroup;
    use std::time::{Duration, Instant};

    #[test]
    fn test_thread_pool_creation() {
        let pool = ThreadPool::new(4);
        assert!(pool.is_alive());
        assert_eq!(pool.num_workers(), 4);
    }

    #[test]
    fn test_basic_job_execution() {
        let pool = ThreadPool::new(4);
        let counter = Arc::new(Mutex::new(0));

        for _ in 0..8 {
            let counter = Arc::clone(&counter);
            pool.submit(move || {
                let mut num = counter.lock().unwrap();
                *num += 1;
            });
        }

        // Dropping the pool queues one `Terminate` per worker *behind* the jobs
        // above and joins every worker, so all jobs have finished once it
        // returns. This replaces a fixed sleep, which was slow and could race.
        drop(pool);
        assert_eq!(*counter.lock().unwrap(), 8);
    }

    #[test]
    fn test_thread_pool() {
        let pool = ThreadPool::new(4);
        let (tx, rx) = mpsc::channel();

        for i in 0..8 {
            let tx = tx.clone();
            pool.submit(move || {
                let start_time = Instant::now();
                println!("Job {} started.", i);
                thread::sleep(Duration::from_millis(50));
                println!("Job {} finished in {:?}.", i, start_time.elapsed());
                tx.send(i).unwrap();
            });
        }

        // The iterator ends once every job has run and dropped its sender clone.
        drop(tx);
        let mut finished: Vec<i32> = rx.iter().collect();
        finished.sort_unstable();
        assert_eq!(finished, (0..8).collect::<Vec<_>>());
    }

    #[test]
    fn test_job_order() {
        // Jobs on different workers may complete in any order, so this checks
        // that every job ran exactly once, not the order in which they ran.
        let pool = ThreadPool::new(2);
        let results = Arc::new(Mutex::new(vec![]));

        for i in 0..5 {
            let results = Arc::clone(&results);
            pool.submit(move || {
                results.lock().unwrap().push(i);
            });
        }

        // Joins all workers after the queued jobs; see test_basic_job_execution.
        drop(pool);

        let mut results = results.lock().unwrap();
        results.sort_unstable();
        assert_eq!(*results, vec![0, 1, 2, 3, 4]);
    }

    #[test]
    fn test_thread_pool_process() {
        let pool = ThreadPool::new(4);

        let results = pool.submit_with_result(|| {
            let start_time = Instant::now();
            println!("Job started.");
            thread::sleep(Duration::from_millis(100));
            println!("Job finished in {:?}.", start_time.elapsed());
            42
        });

        let result = results.result();
        assert_eq!(result, 42);
    }

    #[test]
    fn test_max_concurrent_jobs() {
        let pool = ThreadPool::new(4);
        let (tx, rx) = mpsc::channel();
        let num_jobs = 20;
        let start_time = Instant::now();

        for i in 0..num_jobs {
            let tx = tx.clone();
            pool.submit(move || {
                thread::sleep(Duration::from_millis(100));
                tx.send(i).unwrap();
            });
        }

        let results: Vec<_> = (0..num_jobs).map(|_| rx.recv().unwrap()).collect();

        // 20 jobs x 100 ms on 4 workers is ~500 ms. Running serially would take
        // 2 s, so the generous bound still shows the jobs ran in parallel.
        let elapsed = start_time.elapsed();
        assert!(elapsed < Duration::from_secs(3));
        assert_eq!(results.len(), num_jobs);
        assert!(results.iter().all(|&x| x < num_jobs));
    }

    #[test]
    fn tests_thread_pool_submit_with_result_returns_correct_order() {
        let pool = ThreadPool::new(5);
        let num_jobs = 10;
        let mut work_results = vec![];

        // Later jobs sleep less and tend to finish first, yet each handle must
        // still yield the value of the job it was created for.
        for i in 0..num_jobs {
            let work_result = pool.submit_with_result(move || {
                thread::sleep(Duration::from_millis(20 * (num_jobs - i) as u64));
                i * i
            });
            work_results.push(work_result);
        }

        for (i, work_result) in work_results.into_iter().enumerate() {
            let result = work_result.result();
            assert_eq!(result, i * i);
        }
    }

    #[test]
    fn test_wait_group() {
        let pool = ThreadPool::new(4);
        let wg = WaitGroup::new();
        let num_tasks = 10;
        let total = Arc::new(Mutex::new(0));

        for _ in 0..num_tasks {
            let guard = wg.guard();
            let total = Arc::clone(&total);
            pool.submit(move || {
                thread::sleep(Duration::from_millis(100));
                let mut num = total.lock().unwrap();
                *num += 1;
                drop(guard);
            });
        }

        {
            // 10 tasks x 100 ms on 4 workers need at least ~300 ms, so they
            // cannot all be done immediately after submission.
            let total = total.lock().unwrap();
            assert_ne!(*total, num_tasks);
        }

        let total_tasks_waited_for = wg.wait();

        let total = total.lock().unwrap();
        assert_eq!(*total, num_tasks);
        assert_eq!(total_tasks_waited_for, num_tasks);
    }

    #[test]
    fn test_wait_group_zero_tasks() {
        let wg = WaitGroup::new();
        let total_tasks_waited_for = wg.wait();
        assert_eq!(total_tasks_waited_for, 0);
    }
}
383}