radiate_core/domain/
thread_pool.rs1use std::{
2 fmt::Debug,
3 sync::{
4 Arc, Condvar, Mutex, OnceLock,
5 atomic::{AtomicUsize, Ordering},
6 },
7};
8use std::{sync::mpsc, thread};
9
10struct FixedThreadPool {
11 inner: Arc<ThreadPool>,
12}
13
14impl FixedThreadPool {
15 pub(self) fn instance(num_workers: usize) -> &'static FixedThreadPool {
17 static INSTANCE: OnceLock<FixedThreadPool> = OnceLock::new();
18
19 INSTANCE.get_or_init(|| FixedThreadPool {
20 inner: Arc::new(ThreadPool::new(num_workers)),
21 })
22 }
23}
24
25pub fn get_thread_pool(num_workers: usize) -> Arc<ThreadPool> {
26 Arc::clone(&FixedThreadPool::instance(num_workers).inner)
27}
28
/// Handle to the value produced by a job submitted with
/// `ThreadPool::submit_with_result`.
pub struct WorkResult<T> {
    /// Receives the job's return value once the job has run.
    receiver: mpsc::Receiver<T>,
}

impl<T> WorkResult<T> {
    /// Blocks the calling thread until the job has produced its value, then
    /// returns it.
    ///
    /// # Panics
    ///
    /// Panics if the sending side is dropped without sending a value, for
    /// example because the job itself panicked. A job sends at most one value,
    /// so a second call also ends in this panic once the sender is gone.
    pub fn result(&self) -> T {
        self.receiver
            .recv()
            .expect("WorkResult: job finished without producing a result (did it panic?)")
    }
}
43
44pub struct ThreadPool {
45 sender: mpsc::Sender<Message>,
46 workers: Vec<Worker>,
47}
48
49impl ThreadPool {
50 pub fn new(size: usize) -> Self {
54 let (sender, receiver) = mpsc::channel();
55 let receiver = Arc::new(Mutex::new(receiver));
56
57 ThreadPool {
58 sender,
59 workers: (0..size)
60 .map(|id| Worker::new(id, Arc::clone(&receiver)))
61 .collect(),
62 }
63 }
64
65 pub fn group_submit(&self, wg: &WaitGroup, f: impl FnOnce() + Send + 'static) {
66 let guard = wg.guard();
67
68 self.submit(move || {
69 f();
70 drop(guard);
71 });
72 }
73
74 pub fn submit<F>(&self, f: F)
76 where
77 F: FnOnce() + Send + 'static,
78 {
79 let job = Box::new(f);
80 self.sender.send(Message::NewJob(job)).unwrap();
81 }
82
83 pub fn submit_with_result<F, T>(&self, f: F) -> WorkResult<T>
85 where
86 F: FnOnce() -> T + Send + 'static,
87 T: Send + 'static,
88 {
89 let (tx, rx) = mpsc::sync_channel(1);
90 let job = Box::new(move || tx.send(f()).unwrap());
91
92 self.sender.send(Message::NewJob(job)).unwrap();
93 WorkResult { receiver: rx }
94 }
95
96 pub fn num_workers(&self) -> usize {
97 self.workers.len()
98 }
99
100 pub fn is_alive(&self) -> bool {
101 self.workers.iter().any(|worker| worker.is_alive())
102 }
103}
104
105impl Drop for ThreadPool {
108 fn drop(&mut self) {
109 for _ in self.workers.iter() {
110 self.sender.send(Message::Terminate).unwrap();
111 }
112
113 for worker in self.workers.iter_mut() {
114 if let Some(thread) = worker.thread.take() {
115 thread.join().unwrap();
116 }
117 }
118
119 assert!(!self.is_alive());
120 }
121}
122
/// A unit of work executed once by some worker thread.
type Job = Box<dyn FnOnce() + Send + 'static>;

/// Messages sent from the [`ThreadPool`] to its workers over the shared queue.
enum Message {
    /// Tells the receiving worker to leave its loop so its thread can be joined.
    Terminate,
    /// A job for the receiving worker to run.
    NewJob(Job),
}
131
/// One thread of a [`ThreadPool`], identified by its index in the pool.
struct Worker {
    /// Index assigned by the pool at construction time (0-based).
    id: usize,
    /// Join handle of the worker thread. It becomes `None` once the pool has
    /// taken it to join the thread.
    thread: Option<thread::JoinHandle<()>>,
}
137
138impl Worker {
139 fn new(id: usize, receiver: Arc<Mutex<mpsc::Receiver<Message>>>) -> Self {
145 Worker {
146 id,
147 thread: Some(thread::spawn(move || {
148 loop {
149 let message = receiver.lock().unwrap().recv().unwrap();
150
151 match message {
152 Message::NewJob(job) => job(),
153 Message::Terminate => break,
154 }
155 }
156 })),
157 }
158 }
159
160 pub fn is_alive(&self) -> bool {
163 self.thread.is_some()
164 }
165}
166
167impl Debug for Worker {
168 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
169 f.debug_struct("Worker")
170 .field("id", &self.id)
171 .field("is_alive", &self.is_alive())
172 .finish()
173 }
174}
175
176#[derive(Clone)]
177pub struct WaitGroup {
178 inner: Arc<Inner>,
179 total_count: Arc<AtomicUsize>,
180}
181
182struct Inner {
183 counter: AtomicUsize,
184 lock: Mutex<()>,
185 cvar: Condvar,
186}
187
188pub struct WaitGuard {
189 wg: WaitGroup,
190}
191
192impl Drop for WaitGuard {
193 fn drop(&mut self) {
194 if self.wg.inner.counter.fetch_sub(1, Ordering::AcqRel) == 1 {
195 let _guard = self.wg.inner.lock.lock().unwrap();
196 self.wg.inner.cvar.notify_all();
197 }
198 }
199}
200
201impl WaitGroup {
202 pub fn new() -> Self {
203 Self {
204 inner: Arc::new(Inner {
205 counter: AtomicUsize::new(0),
206 lock: Mutex::new(()),
207 cvar: Condvar::new(),
208 }),
209 total_count: Arc::new(AtomicUsize::new(0)),
210 }
211 }
212
213 pub fn get_count(&self) -> usize {
214 self.total_count.load(Ordering::Acquire)
215 }
216
217 pub fn guard(&self) -> WaitGuard {
219 self.inner.counter.fetch_add(1, Ordering::AcqRel);
220 self.total_count.fetch_add(1, Ordering::AcqRel);
221 WaitGuard { wg: self.clone() }
222 }
223
224 pub fn wait(&self) -> usize {
226 if self.inner.counter.load(Ordering::Acquire) == 0 {
227 return 0;
228 }
229
230 let lock = self.inner.lock.lock().unwrap();
231 let _unused = self
232 .inner
233 .cvar
234 .wait_while(lock, |_| self.inner.counter.load(Ordering::Acquire) != 0);
235
236 self.get_count()
237 }
238}
239
#[cfg(test)]
mod tests {
    use std::time::{Duration, Instant};

    use super::*;

    #[test]
    fn test_thread_pool_creation() {
        let pool = ThreadPool::new(4);
        assert!(pool.is_alive());
        assert_eq!(pool.num_workers(), 4);
    }

    #[test]
    fn test_basic_job_execution() {
        let pool = ThreadPool::new(4);
        let counter = Arc::new(Mutex::new(0));
        let wg = WaitGroup::new();

        for _ in 0..8 {
            let counter = Arc::clone(&counter);
            pool.group_submit(&wg, move || {
                let mut num = counter.lock().unwrap();
                *num += 1;
            });
        }

        // A deterministic barrier instead of sleeping and hoping the jobs finished.
        wg.wait();
        assert_eq!(*counter.lock().unwrap(), 8);
    }

    #[test]
    fn test_thread_pool() {
        let pool = ThreadPool::new(4);
        let wg = WaitGroup::new();
        let completed = Arc::new(AtomicUsize::new(0));

        for i in 0..8 {
            let completed = Arc::clone(&completed);
            pool.group_submit(&wg, move || {
                let start_time = Instant::now();
                println!("Job {} started.", i);
                thread::sleep(Duration::from_millis(50));
                println!("Job {} finished in {:?}.", i, start_time.elapsed());
                completed.fetch_add(1, Ordering::SeqCst);
            });
        }

        wg.wait();
        assert_eq!(wg.get_count(), 8);
        assert_eq!(completed.load(Ordering::SeqCst), 8);
    }

    #[test]
    fn test_job_order() {
        let pool = ThreadPool::new(2);
        let results = Arc::new(Mutex::new(vec![]));

        for i in 0..5 {
            let results = Arc::clone(&results);
            pool.submit(move || {
                results.lock().unwrap().push(i);
            });
        }

        // Dropping the pool drains the queue and joins every worker.
        drop(pool);

        let mut results = results.lock().unwrap();
        results.sort();
        assert_eq!(*results, vec![0, 1, 2, 3, 4]);
    }

    #[test]
    fn test_thread_pool_process() {
        let pool = ThreadPool::new(4);

        let results = pool.submit_with_result(|| {
            let start_time = Instant::now();
            println!("Job started.");
            thread::sleep(Duration::from_millis(50));
            println!("Job finished in {:?}.", start_time.elapsed());
            42
        });

        let result = results.result();
        assert_eq!(result, 42);
    }

    #[test]
    fn test_max_concurrent_jobs() {
        let pool = ThreadPool::new(4);
        let (tx, rx) = mpsc::channel();
        let num_jobs = 20;
        let start_time = Instant::now();

        for i in 0..num_jobs {
            let tx = tx.clone();
            pool.submit(move || {
                thread::sleep(Duration::from_millis(100));
                tx.send(i).unwrap();
            });
        }

        let mut results = Vec::with_capacity(num_jobs);
        for _ in 0..num_jobs {
            results.push(rx.recv().unwrap());
        }

        let elapsed = start_time.elapsed();
        assert!(elapsed < Duration::from_secs(3));
        assert_eq!(results.len(), num_jobs);
        assert!(results.iter().all(|&x| x < num_jobs));
    }
}