radiate_core/domain/
thread_pool.rs1use std::{
2 fmt::Debug,
3 sync::{
4 Arc, Condvar, Mutex,
5 atomic::{AtomicUsize, Ordering},
6 },
7};
8use std::{sync::mpsc, thread};
9
/// Handle to the value produced by a job queued with
/// [`ThreadPool::submit_with_result`].
#[derive(Debug)]
pub struct WorkResult<T> {
    receiver: mpsc::Receiver<T>,
}

impl<T> WorkResult<T> {
    /// Blocks until the job has produced its value and returns it.
    ///
    /// # Panics
    ///
    /// Panics if the sending side is dropped without delivering a value, for
    /// example because the job panicked before returning. It also panics if
    /// called again after the value has already been taken, once the sender
    /// is gone.
    pub fn result(&self) -> T {
        self.receiver
            .recv()
            .expect("job finished without producing a result (did it panic?)")
    }
}
24
25pub struct ThreadPool {
26 sender: mpsc::Sender<Message>,
27 workers: Vec<Worker>,
28}
29
30impl ThreadPool {
31 pub fn new(size: usize) -> Self {
35 let (sender, receiver) = mpsc::channel();
36 let receiver = Arc::new(Mutex::new(receiver));
37
38 ThreadPool {
39 sender,
40 workers: (0..size)
41 .map(|id| Worker::new(id, Arc::clone(&receiver)))
42 .collect(),
43 }
44 }
45
46 pub fn group_submit(&self, wg: &WaitGroup, f: impl FnOnce() + Send + 'static) {
47 let guard = wg.guard();
48
49 self.submit(move || {
50 f();
51 drop(guard);
52 });
53 }
54
55 pub fn submit<F>(&self, f: F)
57 where
58 F: FnOnce() + Send + 'static,
59 {
60 let job = Box::new(f);
61 self.sender.send(Message::NewJob(job)).unwrap();
62 }
63
64 pub fn submit_with_result<F, T>(&self, f: F) -> WorkResult<T>
66 where
67 F: FnOnce() -> T + Send + 'static,
68 T: Send + 'static,
69 {
70 let (tx, rx) = mpsc::sync_channel(1);
71 let job = Box::new(move || tx.send(f()).unwrap());
72
73 self.sender.send(Message::NewJob(job)).unwrap();
74 WorkResult { receiver: rx }
75 }
76
77 pub fn num_workers(&self) -> usize {
78 self.workers.len()
79 }
80
81 pub fn is_alive(&self) -> bool {
82 self.workers.iter().any(|worker| worker.is_alive())
83 }
84}
85
86impl Drop for ThreadPool {
89 fn drop(&mut self) {
90 for _ in self.workers.iter() {
91 self.sender.send(Message::Terminate).unwrap();
92 }
93
94 for worker in self.workers.iter_mut() {
95 if let Some(thread) = worker.thread.take() {
96 thread.join().unwrap();
97 }
98 }
99
100 assert!(!self.is_alive());
101 }
102}
103
/// A boxed closure queued for execution on one of the pool's workers.
type Job = Box<dyn FnOnce() + Send + 'static>;

/// Instructions delivered to workers over the pool's shared channel.
enum Message {
    /// Run this job.
    NewJob(Job),
    /// Leave the receive loop so the thread can be joined.
    Terminate,
}

/// One pool thread that repeatedly takes a [`Message`] off the shared queue.
struct Worker {
    id: usize,
    /// `Some` until the handle is taken for joining.
    thread: Option<thread::JoinHandle<()>>,
}

impl Worker {
    /// Spawns the worker thread, which serves `receiver` until it receives
    /// [`Message::Terminate`].
    ///
    /// A job that panics is caught and discarded. The worker keeps serving the
    /// queue and can still be joined cleanly later. Values the job owned are
    /// dropped as the panic unwinds.
    fn new(id: usize, receiver: Arc<Mutex<mpsc::Receiver<Message>>>) -> Self {
        let handle = thread::spawn(move || loop {
            // The mutex guard is a temporary of this `let`, so the lock is
            // released as soon as `recv` returns a message, before the job
            // runs. Other workers can dequeue in the meantime.
            let message = receiver.lock().unwrap().recv().unwrap();

            match message {
                Message::NewJob(job) => {
                    // Without this, a panicking job would unwind out of the
                    // thread and permanently shrink the pool. The panic hook
                    // still reports the panic. `AssertUnwindSafe` is acceptable
                    // because the job is consumed and never observed again here.
                    let _ = std::panic::catch_unwind(std::panic::AssertUnwindSafe(job));
                }
                Message::Terminate => break,
            }
        });

        Worker {
            id,
            thread: Some(handle),
        }
    }

    /// Returns `true` while this worker still owns its thread handle.
    ///
    /// This does not check whether the OS thread is running.
    pub fn is_alive(&self) -> bool {
        self.thread.is_some()
    }
}

impl Debug for Worker {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Worker")
            .field("id", &self.id)
            .field("is_alive", &self.is_alive())
            .finish()
    }
}
156
/// Tracks in-flight tasks and lets a thread block until all of them finish.
///
/// Cloning a `WaitGroup` produces another handle to the same shared counters.
#[derive(Clone)]
pub struct WaitGroup {
    inner: Arc<Inner>,
    /// Total number of guards ever handed out; never decreases.
    total_count: Arc<AtomicUsize>,
}

/// State shared by every clone of a [`WaitGroup`].
struct Inner {
    /// Number of guards currently alive.
    counter: AtomicUsize,
    /// Protects no data; it only orders "check counter then sleep" against
    /// "reach zero then notify".
    lock: Mutex<()>,
    cvar: Condvar,
}

/// Token for one registered task.
///
/// The task counts as finished once its guard is dropped.
pub struct WaitGuard {
    wg: WaitGroup,
}

impl Drop for WaitGuard {
    fn drop(&mut self) {
        // Only the guard that brings the live count to zero wakes waiters.
        if self.wg.inner.counter.fetch_sub(1, Ordering::AcqRel) == 1 {
            // Notify while holding the lock. A waiter checks the counter under
            // the same lock, so it cannot miss this wake-up.
            let _guard = self.wg.inner.lock.lock().unwrap();
            self.wg.inner.cvar.notify_all();
        }
    }
}

impl WaitGroup {
    /// Creates a group with no registered tasks.
    pub fn new() -> Self {
        Self {
            inner: Arc::new(Inner {
                counter: AtomicUsize::new(0),
                lock: Mutex::new(()),
                cvar: Condvar::new(),
            }),
            total_count: Arc::new(AtomicUsize::new(0)),
        }
    }

    /// Total number of guards ever created from this group or its clones.
    pub fn get_count(&self) -> usize {
        self.total_count.load(Ordering::Acquire)
    }

    /// Registers one more task and returns the guard that completes it on drop.
    pub fn guard(&self) -> WaitGuard {
        self.inner.counter.fetch_add(1, Ordering::AcqRel);
        self.total_count.fetch_add(1, Ordering::AcqRel);
        WaitGuard { wg: self.clone() }
    }

    /// Blocks until no guard is alive, then returns [`WaitGroup::get_count`].
    ///
    /// The return value is the same whether or not the call had to sleep.
    /// Previously a group whose tasks had all finished already reported `0`.
    pub fn wait(&self) -> usize {
        if self.inner.counter.load(Ordering::Acquire) != 0 {
            let lock = self.inner.lock.lock().unwrap();
            // A poisoned result is ignored on purpose: the mutex guards no data.
            let _guard = self
                .inner
                .cvar
                .wait_while(lock, |_| self.inner.counter.load(Ordering::Acquire) != 0);
        }

        self.get_count()
    }
}

impl Default for WaitGroup {
    fn default() -> Self {
        Self::new()
    }
}
220
#[cfg(test)]
mod tests {
    use std::time::{Duration, Instant};

    use super::*;

    #[test]
    fn test_thread_pool_creation() {
        let pool = ThreadPool::new(4);
        assert!(pool.is_alive());
        assert_eq!(pool.num_workers(), 4);
    }

    #[test]
    fn test_basic_job_execution() {
        let pool = ThreadPool::new(4);
        let counter = Arc::new(Mutex::new(0));

        for _ in 0..8 {
            let counter = Arc::clone(&counter);
            pool.submit(move || {
                *counter.lock().unwrap() += 1;
            });
        }

        // Dropping the pool runs every queued job and joins the workers, so the
        // assertion does not depend on a sleep being long enough.
        drop(pool);
        assert_eq!(*counter.lock().unwrap(), 8);
    }

    #[test]
    fn test_thread_pool() {
        let pool = ThreadPool::new(4);
        let (tx, rx) = mpsc::channel();

        for i in 0..8 {
            let tx = tx.clone();
            pool.submit(move || {
                thread::sleep(Duration::from_millis(10));
                tx.send(i).unwrap();
            });
        }
        drop(tx);

        // The iterator ends once every job has run and dropped its sender clone.
        let mut finished: Vec<i32> = rx.iter().collect();
        finished.sort_unstable();
        assert_eq!(finished, (0..8).collect::<Vec<_>>());
    }

    #[test]
    fn test_job_order() {
        let pool = ThreadPool::new(2);
        let results = Arc::new(Mutex::new(vec![]));

        for i in 0..5 {
            let results = Arc::clone(&results);
            pool.submit(move || {
                results.lock().unwrap().push(i);
            });
        }

        drop(pool);
        let mut results = results.lock().unwrap();
        results.sort_unstable();
        assert_eq!(*results, vec![0, 1, 2, 3, 4]);
    }

    #[test]
    fn test_thread_pool_process() {
        let pool = ThreadPool::new(4);

        let work = pool.submit_with_result(|| {
            thread::sleep(Duration::from_millis(20));
            42
        });

        assert_eq!(work.result(), 42);
    }

    #[test]
    fn test_max_concurrent_jobs() {
        let pool = ThreadPool::new(4);
        let (tx, rx) = mpsc::channel();
        let num_jobs = 20;
        let job_time = Duration::from_millis(50);
        let start_time = Instant::now();

        for i in 0..num_jobs {
            let tx = tx.clone();
            pool.submit(move || {
                thread::sleep(job_time);
                tx.send(i).unwrap();
            });
        }

        let mut results: Vec<usize> = (0..num_jobs).map(|_| rx.recv().unwrap()).collect();
        let elapsed = start_time.elapsed();

        // Run one after another, the jobs would need at least
        // `num_jobs * job_time`. Finishing sooner proves they overlapped.
        assert!(elapsed < job_time * num_jobs as u32);
        results.sort_unstable();
        assert_eq!(results, (0..num_jobs).collect::<Vec<_>>());
    }
}