jlizard_simple_threadpool/
threadpool.rs

//! Thread pool implementation for concurrent PDF generation.
//!
//! Supports single-threaded mode (`max_jobs=1`) for debugging and sequential processing,
//! or multi-threaded mode for parallel processing of multiple BOM files.
//!
//! When `max_jobs` is 0 or not set, the pool uses all available CPU cores for maximum parallelism.
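//!
//! A minimal usage sketch (a two-worker pool; queued jobs complete before the pool drops):
//!
//! ```no_run
//! use jlizard_simple_threadpool::threadpool::ThreadPool;
//!
//! let pool = ThreadPool::new(2);
//! pool.execute(|| println!("hello from a worker")).expect("failed to queue job");
//! drop(pool); // blocks until all queued jobs have completed
//! ```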

use crate::common::Job;
use crate::worker::Worker;
use std::error::Error;

#[cfg(feature = "log")]
use log::debug;

use std::fmt::{Display, Formatter};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::mpsc::Sender;
use std::sync::{Arc, Mutex, mpsc};
use std::thread;

pub struct ThreadPool {
    workers: Vec<Worker>,
    sender: Option<Sender<Job>>,
    num_threads: u8,
    kill_signal: Arc<AtomicBool>,
}

impl ThreadPool {
    /// Creates a new thread pool with the following behavior constraints:
    /// - `pool_size == 0`: runs in the multi-threaded default mode using maximum available parallelism
    /// - `pool_size == 1`: runs in single-threaded mode (all jobs are run in the calling thread)
    /// - `1 < pool_size <= 255`: runs in multi-threaded mode with `pool_size` worker threads
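    ///
    /// A short sketch of the three modes:
    ///
    /// ```no_run
    /// use jlizard_simple_threadpool::threadpool::ThreadPool;
    ///
    /// let auto = ThreadPool::new(0);   // one worker per available CPU core
    /// let serial = ThreadPool::new(1); // no workers; jobs run in the caller's thread
    /// let fixed = ThreadPool::new(4);  // exactly four worker threads
    /// ```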
    pub fn new(pool_size: u8) -> Self {
        if pool_size == 0 {
            Self::default()
        } else if pool_size == 1 {
            Self {
                workers: Vec::new(),
                sender: None,
                num_threads: pool_size,
                kill_signal: Arc::new(AtomicBool::new(false)),
            }
        } else {
            let (sender, receiver) = mpsc::channel::<Job>();

            let mut workers = Vec::with_capacity(pool_size as usize);

            let receiver = Arc::new(Mutex::new(receiver));
            let kill_signal = Arc::new(AtomicBool::new(false));

            for id in 1..=pool_size {
                workers.push(Worker::new(
                    id,
                    Arc::clone(&receiver),
                    Arc::clone(&kill_signal),
                ));
            }

            Self {
                workers,
                sender: Some(sender),
                num_threads: pool_size,
                kill_signal,
            }
        }
    }

    /// Executes a job on the thread pool.
    ///
    /// # Behavior
    /// - **Single-threaded mode** (`max_jobs=1`): Job executes synchronously in the calling thread
    /// - **Multi-threaded mode**: Job is queued and executed asynchronously by worker threads
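    ///
    /// A minimal sketch of queueing work and waiting for it by dropping the pool:
    ///
    /// ```no_run
    /// use jlizard_simple_threadpool::threadpool::ThreadPool;
    /// use std::sync::{Arc, Mutex};
    ///
    /// let pool = ThreadPool::new(2);
    /// let results = Arc::new(Mutex::new(Vec::new()));
    /// for i in 0..4 {
    ///     let results = Arc::clone(&results);
    ///     pool.execute(move || results.lock().unwrap().push(i))
    ///         .expect("failed to queue job");
    /// }
    /// drop(pool); // waits for all queued jobs to finish
    /// ```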
    pub fn execute<F>(&self, f: F) -> Result<(), Box<dyn Error>>
    where
        F: FnOnce() + Send + 'static,
    {
        if self.is_single_threaded() {
            f();
            Ok(())
        } else {
            self.sender
                .as_ref()
                .unwrap()
                .send(Box::new(f))
                .map_err(|e| e.into())
        }
    }

    /// Returns `true` if running in single-threaded mode.
    ///
    /// Single-threaded mode is active when `max_jobs=1`, resulting in:
    /// - No worker threads spawned
    /// - No message passing channel created
    /// - All jobs executed synchronously in the calling thread
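    ///
    /// For example:
    ///
    /// ```no_run
    /// use jlizard_simple_threadpool::threadpool::ThreadPool;
    ///
    /// assert!(ThreadPool::new(1).is_single_threaded());
    /// assert!(!ThreadPool::new(4).is_single_threaded());
    /// ```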
    pub fn is_single_threaded(&self) -> bool {
        self.sender.is_none() && self.workers.is_empty()
    }

    /// Signals all worker threads to stop processing after completing their current jobs.
    ///
    /// This method sets the kill signal, which workers check periodically.
    /// Workers will complete their current job before stopping.
    /// The pool's Drop implementation will wait for all workers to finish.
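    ///
    /// A minimal sketch of a cooperative shutdown:
    ///
    /// ```no_run
    /// use jlizard_simple_threadpool::threadpool::ThreadPool;
    ///
    /// let pool = ThreadPool::new(4);
    /// pool.execute(|| { /* long-running work */ }).expect("failed to queue job");
    /// pool.signal_stop(); // workers stop once their current job is done
    /// drop(pool);         // blocks until every worker has exited
    /// ```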
    pub fn signal_stop(&self) {
        self.kill_signal.store(true, Ordering::Relaxed);
    }

    /// Returns a clone of the kill signal Arc that can be passed into jobs.
    ///
    /// This allows jobs to signal the thread pool to stop from within the job itself.
    /// Useful for scenarios like finding a hash collision where one worker needs to
    /// stop all other workers.
    ///
    /// # Example
    /// ```no_run
    /// use jlizard_simple_threadpool::threadpool::ThreadPool;
    /// use std::sync::atomic::Ordering;
    ///
    /// let pool = ThreadPool::new(4);
    /// let kill_signal = pool.get_kill_signal();
    ///
    /// pool.execute(move || {
    ///     // Do some work...
    ///     if /* condition met */ true {
    ///         // Signal all workers to stop
    ///         kill_signal.store(true, Ordering::Relaxed);
    ///     }
    /// }).expect("Failed to execute");
    /// ```
    pub fn get_kill_signal(&self) -> Arc<AtomicBool> {
        Arc::clone(&self.kill_signal)
    }
}

impl Drop for ThreadPool {
    fn drop(&mut self) {
        // Drop the sender first which causes receivers to error out gracefully
        // This allows pending jobs in the queue to complete
        drop(self.sender.take());

        #[cfg(feature = "log")]
        {
            debug!("Waiting for workers to finish");
        }

        // Workers will exit naturally when:
        // 1. The channel is closed (sender dropped) and queue is empty, OR
        // 2. The kill_signal was set (via signal_stop())
        for worker in &mut self.workers {
            #[cfg(feature = "log")]
            {
                debug!("Shutting down worker {}", worker.id);
            }
            worker.thread.take().unwrap().join().unwrap();
        }

        #[cfg(feature = "log")]
        {
            debug!("All workers stopped");
        }
    }
}

impl Default for ThreadPool {
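    /// Builds a pool with one worker per available CPU core, saturating at `u8::MAX` (255) workers.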
    fn default() -> Self {
        let max_threads = thread::available_parallelism()
            .map(|e| e.get())
            .expect("Unable to find any threads to run with. Possible system-side restrictions or limitations");

        // saturate to u8::MAX if number of threads is larger than what u8 can hold
        ThreadPool::new(u8::try_from(max_threads).unwrap_or(u8::MAX))
    }
}

impl Display for ThreadPool {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        if self.is_single_threaded() {
            write!(
                f,
                "Concurrency Disabled: running all jobs sequentially in the main thread. This was forced by a user override via the VEX2PDF_MAX_JOBS environment variable or the --max-jobs CLI argument"
            )
        } else {
            write!(
                f,
                "Concurrency Enabled: running with {} parallel jobs",
                self.num_threads
            )
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Arc, Mutex};
    use std::time::Duration;

    #[test]
    fn test_threadpool_creation_modes() {
        // Test pool with size 0 (default - max parallelism)
        let pool_default = ThreadPool::new(0);
        assert!(pool_default.num_threads > 0);
        assert!(!pool_default.is_single_threaded());

        // Test pool with size 1 (single-threaded)
        let pool_single = ThreadPool::new(1);
        assert_eq!(pool_single.num_threads, 1);
        assert!(pool_single.is_single_threaded());
        assert!(pool_single.workers.is_empty());
        assert!(pool_single.sender.is_none());

        // Test pool with size 4 (multi-threaded)
        let pool_multi = ThreadPool::new(4);
        assert_eq!(pool_multi.num_threads, 4);
        assert!(!pool_multi.is_single_threaded());
        assert_eq!(pool_multi.workers.len(), 4);
        assert!(pool_multi.sender.is_some());
    }

    #[test]
    fn test_single_threaded_execution() {
        let pool = ThreadPool::new(1);
        let counter = Arc::new(Mutex::new(0));
        let counter_clone = Arc::clone(&counter);

        // Execute job synchronously
        pool.execute(move || {
            let mut num = counter_clone.lock().unwrap();
            *num += 1;
        })
        .expect("Failed to execute job");

        // In single-threaded mode, job executes immediately
        let value = *counter.lock().unwrap();
        assert_eq!(value, 1);
    }

    #[test]
    fn test_multi_threaded_execution() {
        let pool = ThreadPool::new(2);
        let results = Arc::new(Mutex::new(Vec::new()));

        // Execute multiple jobs
        for i in 0..5 {
            let results_clone = Arc::clone(&results);
            pool.execute(move || {
                std::thread::sleep(Duration::from_millis(10));
                results_clone.lock().unwrap().push(i);
            })
            .expect("Failed to execute job");
        }

        // Drop pool to wait for all jobs to complete
        drop(pool);

        // Verify all jobs completed
        let final_results = results.lock().unwrap();
        assert_eq!(final_results.len(), 5);
        // Results may be in any order due to concurrency
        for i in 0..5 {
            assert!(final_results.contains(&i));
        }
    }

    #[test]
    fn test_get_num_threads() {
        let pool1 = ThreadPool::new(1);
        assert_eq!(pool1.num_threads, 1);

        let pool4 = ThreadPool::new(4);
        assert_eq!(pool4.num_threads, 4);

        let pool_default = ThreadPool::default();
        assert!(pool_default.num_threads > 0);
    }

    #[test]
    fn test_is_single_threaded() {
        let pool_single = ThreadPool::new(1);
        assert!(pool_single.is_single_threaded());

        let pool_multi = ThreadPool::new(2);
        assert!(!pool_multi.is_single_threaded());

        let pool_default = ThreadPool::default();
        assert!(!pool_default.is_single_threaded());
    }

    #[test]
    fn test_pool_graceful_shutdown() {
        let pool = ThreadPool::new(3);
        let completed = Arc::new(Mutex::new(0));

        // Execute several jobs
        for _ in 0..10 {
            let completed_clone = Arc::clone(&completed);
            pool.execute(move || {
                std::thread::sleep(Duration::from_millis(20));
                *completed_clone.lock().unwrap() += 1;
            })
            .expect("Failed to execute job");
        }

        // Drop pool - should wait for all jobs to complete
        drop(pool);

        // All jobs should have completed
        assert_eq!(*completed.lock().unwrap(), 10);
    }

    #[test]
    fn test_signal_stop_method() {
        let pool = ThreadPool::new(4);
        let completed = Arc::new(Mutex::new(0));

        // Execute several quick jobs
        for _ in 0..5 {
            let completed_clone = Arc::clone(&completed);
            pool.execute(move || {
                std::thread::sleep(Duration::from_millis(10));
                *completed_clone.lock().unwrap() += 1;
            })
            .expect("Failed to execute job");
        }

        // Signal stop
        pool.signal_stop();

        // Drop pool to wait for shutdown
        drop(pool);

        // At least some jobs should have completed before stop
        let count = *completed.lock().unwrap();
        assert!(count >= 1 && count <= 5);
    }

    #[test]
    fn test_get_kill_signal() {
        let pool = ThreadPool::new(2);
        let kill_signal = pool.get_kill_signal();

        // Verify we can read the signal
        assert!(!kill_signal.load(std::sync::atomic::Ordering::Relaxed));

        // Signal stop through the cloned kill signal
        kill_signal.store(true, std::sync::atomic::Ordering::Relaxed);

        // Pool should also see it
        drop(pool);
    }

    #[test]
    fn test_job_signals_stop_to_other_workers() {
        use std::sync::atomic::Ordering;

        let pool = Arc::new(ThreadPool::new(4));
        let completed = Arc::new(Mutex::new(Vec::new()));
        let collision_found = Arc::new(AtomicBool::new(false));

        // Submit jobs where job 2 will "find a collision" early
        for i in 0..20 {
            let pool_clone = Arc::clone(&pool);
            let completed_clone = Arc::clone(&completed);
            let collision_found_clone = Arc::clone(&collision_found);
            let kill_signal = pool.get_kill_signal();

            pool.execute(move || {
                // Check if already stopped before starting work
                if kill_signal.load(Ordering::Relaxed) {
                    return;
                }

                std::thread::sleep(Duration::from_millis(10));

                // Job 2 simulates finding a collision (early job to ensure it runs)
                if i == 2 {
                    collision_found_clone.store(true, Ordering::Relaxed);
                    pool_clone.signal_stop();
                    completed_clone.lock().unwrap().push(i);
                } else {
                    // Only complete if collision not yet found
                    if !collision_found_clone.load(Ordering::Relaxed) {
                        completed_clone.lock().unwrap().push(i);
                    }
                }
            })
            .expect("Failed to execute job");
        }

        // Give jobs time to execute
        std::thread::sleep(Duration::from_millis(150));

        // Wait for pool to finish
        drop(pool);

        // Verify collision was found
        assert!(collision_found.load(Ordering::Relaxed));

        // Not all jobs should have completed (some were stopped)
        let completed_jobs = completed.lock().unwrap();
        assert!(completed_jobs.len() < 20);
        assert!(completed_jobs.contains(&2)); // The collision job completed
    }

    #[test]
    fn test_workers_complete_current_job_before_stopping() {
        use std::sync::atomic::Ordering;

        let pool = ThreadPool::new(2);
        let job_started = Arc::new(AtomicBool::new(false));
        let job_completed = Arc::new(AtomicBool::new(false));

        let job_started_clone = Arc::clone(&job_started);
        let job_completed_clone = Arc::clone(&job_completed);

        // Start a long job
        pool.execute(move || {
            job_started_clone.store(true, Ordering::Relaxed);
            std::thread::sleep(Duration::from_millis(100));
            job_completed_clone.store(true, Ordering::Relaxed);
        })
        .expect("Failed to execute job");

        // Wait for job to start
        std::thread::sleep(Duration::from_millis(50));
        assert!(job_started.load(Ordering::Relaxed));

        // Signal stop while job is running
        pool.signal_stop();

        // Drop pool to wait for completion
        drop(pool);

        // Job should have completed before worker stopped
        assert!(job_completed.load(Ordering::Relaxed));
    }

    #[test]
    fn test_no_new_jobs_after_signal_stop() {
        use std::sync::atomic::Ordering;

        let pool = ThreadPool::new(3);
        let executed = Arc::new(AtomicBool::new(false));
        let executed_clone = Arc::clone(&executed);

        // Signal stop immediately
        pool.signal_stop();

        // Try to execute a job
        pool.execute(move || {
            executed_clone.store(true, Ordering::Relaxed);
        })
        .expect("Failed to execute job");

        // Give some time for potential execution
        std::thread::sleep(Duration::from_millis(200));

        // Job might or might not execute depending on timing
        // This is expected behavior - jobs in queue may still execute
        // The important thing is workers stop checking for new jobs

        drop(pool);
        // Test passes if we don't hang
    }

    #[test]
    fn test_kill_signal_in_single_threaded_mode() {
        let pool = ThreadPool::new(1);
        assert!(pool.is_single_threaded());

        // Get kill signal - should exist even in single-threaded mode
        let kill_signal = pool.get_kill_signal();
        assert!(!kill_signal.load(std::sync::atomic::Ordering::Relaxed));

        // Signal stop
        pool.signal_stop();
        assert!(kill_signal.load(std::sync::atomic::Ordering::Relaxed));

        // Drop should work fine
        drop(pool);
    }
}