jlizard_simple_threadpool/
threadpool.rs1use crate::common::Job;
9use crate::worker::Worker;
10use std::error::Error;
11
12#[cfg(feature = "log")]
13use log::debug;
14
15use std::fmt::{Display, Formatter};
16use std::sync::atomic::{AtomicBool, Ordering};
17use std::sync::mpsc::Sender;
18use std::sync::{Arc, Mutex, mpsc};
19use std::thread;
20
21pub struct ThreadPool {
22 workers: Vec<Worker>,
23 sender: Option<Sender<Job>>,
24 num_threads: u8,
25 kill_signal: Arc<AtomicBool>,
26}
27
28impl ThreadPool {
29 pub fn new(pool_size: u8) -> Self {
34 if pool_size == 0 {
35 Self::default()
36 } else if pool_size == 1 {
37 Self {
38 workers: Vec::new(),
39 sender: None,
40 num_threads: pool_size,
41 kill_signal: Arc::new(AtomicBool::new(false)),
42 }
43 } else {
44 let (sender, receiver) = mpsc::channel::<Job>();
45
46 let mut workers = Vec::with_capacity(pool_size as usize);
47
48 let receiver = Arc::new(Mutex::new(receiver));
49 let kill_signal = Arc::new(AtomicBool::new(false));
50
51 for id in 1..=pool_size {
52 workers.push(Worker::new(
53 id,
54 Arc::clone(&receiver),
55 Arc::clone(&kill_signal),
56 ));
57 }
58
59 Self {
60 workers,
61 sender: Some(sender),
62 num_threads: pool_size,
63 kill_signal,
64 }
65 }
66 }
67
68 pub fn execute<F>(&self, f: F) -> Result<(), Box<dyn Error>>
74 where
75 F: FnOnce() + Send + 'static,
76 {
77 if self.is_single_threaded() {
78 f();
79 Ok(())
80 } else {
81 self.sender
82 .as_ref()
83 .unwrap()
84 .send(Box::new(f))
85 .map_err(|e| e.into())
86 }
87 }
88
89 pub fn is_single_threaded(&self) -> bool {
96 self.sender.is_none() && self.workers.is_empty()
97 }
98
99 pub fn signal_stop(&self) {
105 self.kill_signal.store(true, Ordering::Relaxed);
106 }
107
108 pub fn get_kill_signal(&self) -> Arc<AtomicBool> {
131 Arc::clone(&self.kill_signal)
132 }
133}
134
135impl Drop for ThreadPool {
136 fn drop(&mut self) {
137 drop(self.sender.take());
140
141 #[cfg(feature = "log")]
142 {
143 debug!("Waiting for workers to finish");
144 }
145
146 for worker in &mut self.workers {
150 #[cfg(feature = "log")]
151 {
152 debug!("Shutting down worker {}", worker.id);
153 }
154 worker.thread.take().unwrap().join().unwrap();
155 }
156
157 #[cfg(feature = "log")]
158 {
159 debug!("All workers stopped");
160 }
161 }
162}
163
164impl Default for ThreadPool {
165 fn default() -> Self {
166 let max_threads = thread::available_parallelism().map(|e| e.get()).expect("Unable to find any threads to run with. Possible system-side restrictions or limitations");
167
168 ThreadPool::new(u8::try_from(max_threads).unwrap_or(u8::MAX))
170 }
171}
172
173impl Display for ThreadPool {
174 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
175 if self.is_single_threaded() {
176 write!(
177 f,
178 "Concurrency Disabled: running all jobs sequentially in main thread. A user override forced this through an VEX2PDF_MAX_JOBS or the --max-jobs cli argument"
179 )
180 } else {
181 write!(
182 f,
183 "Concurrency Enabled: running with {} jobs",
184 self.num_threads
185 )
186 }
187 }
188}
189
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::sync::{Arc, Mutex};
    use std::time::Duration;

    /// Size `ThreadPool::default()` should pick on this host: the reported
    /// parallelism (1 when it cannot be determined), clamped to `u8::MAX`.
    fn expected_default_size() -> u8 {
        let available = std::thread::available_parallelism()
            .map(|n| n.get())
            .unwrap_or(1);
        u8::try_from(available).unwrap_or(u8::MAX)
    }

    #[test]
    fn test_threadpool_creation_modes() {
        // Size 0 defers to the host's parallelism. On a one-core host (or a
        // CPU-quota-limited container) that is legitimately single-threaded, so
        // only require that the mode agrees with the chosen size.
        let pool_default = ThreadPool::new(0);
        assert!(pool_default.num_threads > 0);
        assert_eq!(
            pool_default.is_single_threaded(),
            pool_default.num_threads == 1
        );

        let pool_single = ThreadPool::new(1);
        assert_eq!(pool_single.num_threads, 1);
        assert!(pool_single.is_single_threaded());
        assert!(pool_single.workers.is_empty());
        assert!(pool_single.sender.is_none());

        let pool_multi = ThreadPool::new(4);
        assert_eq!(pool_multi.num_threads, 4);
        assert!(!pool_multi.is_single_threaded());
        assert_eq!(pool_multi.workers.len(), 4);
        assert!(pool_multi.sender.is_some());
    }

    #[test]
    fn test_single_threaded_execution() {
        let pool = ThreadPool::new(1);
        let counter = Arc::new(Mutex::new(0));
        let counter_clone = Arc::clone(&counter);

        // Runs inline, so the increment is visible as soon as `execute` returns.
        pool.execute(move || {
            let mut num = counter_clone.lock().unwrap();
            *num += 1;
        })
        .expect("Failed to execute job");

        let value = *counter.lock().unwrap();
        assert_eq!(value, 1);
    }

    #[test]
    fn test_multi_threaded_execution() {
        let pool = ThreadPool::new(2);
        let results = Arc::new(Mutex::new(Vec::new()));

        for i in 0..5 {
            let results_clone = Arc::clone(&results);
            pool.execute(move || {
                std::thread::sleep(Duration::from_millis(10));
                results_clone.lock().unwrap().push(i);
            })
            .expect("Failed to execute job");
        }

        // Dropping the pool joins the workers, so every queued job has finished.
        drop(pool);

        let final_results = results.lock().unwrap();
        assert_eq!(final_results.len(), 5);
        for i in 0..5 {
            assert!(final_results.contains(&i));
        }
    }

    #[test]
    fn test_get_num_threads() {
        let pool1 = ThreadPool::new(1);
        assert_eq!(pool1.num_threads, 1);

        let pool4 = ThreadPool::new(4);
        assert_eq!(pool4.num_threads, 4);

        let pool_default = ThreadPool::default();
        assert!(pool_default.num_threads > 0);
        assert_eq!(pool_default.num_threads, expected_default_size());
    }

    #[test]
    fn test_is_single_threaded() {
        let pool_single = ThreadPool::new(1);
        assert!(pool_single.is_single_threaded());

        let pool_multi = ThreadPool::new(2);
        assert!(!pool_multi.is_single_threaded());

        // Only multi-threaded when the host actually offers more than one thread.
        let pool_default = ThreadPool::default();
        assert_eq!(
            pool_default.is_single_threaded(),
            expected_default_size() == 1
        );
    }

    #[test]
    fn test_pool_graceful_shutdown() {
        let pool = ThreadPool::new(3);
        let completed = Arc::new(Mutex::new(0));

        for _ in 0..10 {
            let completed_clone = Arc::clone(&completed);
            pool.execute(move || {
                std::thread::sleep(Duration::from_millis(20));
                *completed_clone.lock().unwrap() += 1;
            })
            .expect("Failed to execute job");
        }

        drop(pool);

        assert_eq!(*completed.lock().unwrap(), 10);
    }

    #[test]
    fn test_signal_stop_method() {
        let pool = ThreadPool::new(4);
        let completed = Arc::new(Mutex::new(0));

        for _ in 0..5 {
            let completed_clone = Arc::clone(&completed);
            pool.execute(move || {
                std::thread::sleep(Duration::from_millis(10));
                *completed_clone.lock().unwrap() += 1;
            })
            .expect("Failed to execute job");
        }

        pool.signal_stop();

        drop(pool);

        let count = *completed.lock().unwrap();
        assert!(
            (1..=5).contains(&count),
            "unexpected completed job count: {}",
            count
        );
    }

    #[test]
    fn test_get_kill_signal() {
        let pool = ThreadPool::new(2);
        let kill_signal = pool.get_kill_signal();

        assert!(!kill_signal.load(Ordering::Relaxed));

        kill_signal.store(true, Ordering::Relaxed);

        // The returned handle shares state with the pool's own flag.
        assert!(pool.get_kill_signal().load(Ordering::Relaxed));

        drop(pool);
    }

    #[test]
    fn test_job_signals_stop_to_other_workers() {
        let pool = Arc::new(ThreadPool::new(4));
        let completed = Arc::new(Mutex::new(Vec::new()));
        let collision_found = Arc::new(AtomicBool::new(false));

        for i in 0..20 {
            let pool_clone = Arc::clone(&pool);
            let completed_clone = Arc::clone(&completed);
            let collision_found_clone = Arc::clone(&collision_found);
            let kill_signal = pool.get_kill_signal();

            pool.execute(move || {
                // Skip the work entirely once some job has requested a stop.
                if kill_signal.load(Ordering::Relaxed) {
                    return;
                }

                std::thread::sleep(Duration::from_millis(10));

                if i == 2 {
                    // Job 2 plays the role of the job that finds the "collision".
                    collision_found_clone.store(true, Ordering::Relaxed);
                    pool_clone.signal_stop();
                    completed_clone.lock().unwrap().push(i);
                } else if !collision_found_clone.load(Ordering::Relaxed) {
                    completed_clone.lock().unwrap().push(i);
                }
            })
            .expect("Failed to execute job");
        }

        std::thread::sleep(Duration::from_millis(150));

        drop(pool);

        assert!(collision_found.load(Ordering::Relaxed));

        let completed_jobs = completed.lock().unwrap();
        assert!(completed_jobs.len() < 20);
        assert!(completed_jobs.contains(&2));
    }

    #[test]
    fn test_workers_complete_current_job_before_stopping() {
        let pool = ThreadPool::new(2);
        let job_started = Arc::new(AtomicBool::new(false));
        let job_completed = Arc::new(AtomicBool::new(false));

        let job_started_clone = Arc::clone(&job_started);
        let job_completed_clone = Arc::clone(&job_completed);

        pool.execute(move || {
            job_started_clone.store(true, Ordering::Relaxed);
            std::thread::sleep(Duration::from_millis(100));
            job_completed_clone.store(true, Ordering::Relaxed);
        })
        .expect("Failed to execute job");

        std::thread::sleep(Duration::from_millis(50));
        assert!(job_started.load(Ordering::Relaxed));

        pool.signal_stop();

        drop(pool);

        // The stop signal must not cut an in-flight job short.
        assert!(job_completed.load(Ordering::Relaxed));
    }

    #[test]
    fn test_no_new_jobs_after_signal_stop() {
        let pool = ThreadPool::new(3);
        let executed = Arc::new(AtomicBool::new(false));
        let executed_clone = Arc::clone(&executed);

        pool.signal_stop();

        // NOTE(review): nothing is asserted about `executed`; whether a worker
        // that was already waiting on the channel still runs this job depends on
        // `Worker` internals and timing. This test only checks that submitting
        // after `signal_stop` and then dropping the pool does not fail or hang.
        pool.execute(move || {
            executed_clone.store(true, Ordering::Relaxed);
        })
        .expect("Failed to execute job");

        std::thread::sleep(Duration::from_millis(200));

        drop(pool);
    }

    #[test]
    fn test_kill_signal_in_single_threaded_mode() {
        let pool = ThreadPool::new(1);
        assert!(pool.is_single_threaded());

        let kill_signal = pool.get_kill_signal();
        assert!(!kill_signal.load(Ordering::Relaxed));

        pool.signal_stop();
        assert!(kill_signal.load(Ordering::Relaxed));

        drop(pool);
    }
}