jlizard_simple_threadpool/
threadpool.rs1use crate::common::Job;
9use crate::worker::Worker;
10use std::error::Error;
11
12#[cfg(feature = "log")]
13use log::debug;
14
15use std::fmt::{Display, Formatter};
16use std::sync::mpsc::Sender;
17use std::sync::{Arc, Mutex, mpsc};
18use std::thread;
19
20pub struct ThreadPool {
21 workers: Vec<Worker>,
22 sender: Option<Sender<Job>>,
23 num_threads: u8,
24}
25
26impl ThreadPool {
27 pub(crate) fn new(pool_size: u8) -> Self {
32 if pool_size == 0 {
33 Self::default()
34 } else if pool_size == 1 {
35 Self {
36 workers: Vec::new(),
37 sender: None,
38 num_threads: pool_size,
39 }
40 } else {
41 let (sender, receiver) = mpsc::channel::<Job>();
42
43 let mut workers = Vec::with_capacity(pool_size as usize);
44
45 let receiver = Arc::new(Mutex::new(receiver));
46
47 for id in 1..=pool_size {
48 workers.push(Worker::new(id, Arc::clone(&receiver)));
49 }
50
51 Self {
52 workers,
53 sender: Some(sender),
54 num_threads: pool_size,
55 }
56 }
57 }
58
59 pub fn execute<F>(&self, f: F) -> Result<(), Box<dyn Error>>
65 where
66 F: FnOnce() + Send + 'static,
67 {
68 if self.is_single_threaded() {
69 f();
70 Ok(())
71 } else {
72 self.sender
73 .as_ref()
74 .unwrap()
75 .send(Box::new(f))
76 .map_err(|e| e.into())
77 }
78 }
79
80 pub fn is_single_threaded(&self) -> bool {
87 self.sender.is_none() && self.workers.is_empty()
88 }
89}
90
91impl Drop for ThreadPool {
92 fn drop(&mut self) {
93 drop(self.sender.take());
95 for worker in &mut self.workers {
98 #[cfg(feature = "log")]
99 {
100 debug!("Shutting down worker {}", worker.id);
101 }
102 worker.thread.take().unwrap().join().unwrap();
103 }
104 }
105}
106
107impl Default for ThreadPool {
108 fn default() -> Self {
109 let max_threads = thread::available_parallelism().map(|e| e.get()).expect("Unable to find any threads to run with. Possible system-side restrictions or limitations");
110
111 ThreadPool::new(u8::try_from(max_threads).unwrap_or(u8::MAX))
113 }
114}
115
116impl Display for ThreadPool {
117 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
118 if self.is_single_threaded() {
119 write!(
120 f,
121 "Concurrency Disabled: running all jobs sequentially in main thread. A user override forced this through an VEX2PDF_MAX_JOBS or the --max-jobs cli argument"
122 )
123 } else {
124 write!(
125 f,
126 "Concurrency Enabled: running with {} jobs",
127 self.num_threads
128 )
129 }
130 }
131}
132
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Arc, Mutex};
    use std::time::Duration;

    #[test]
    fn test_threadpool_creation_modes() {
        // Size 0 resolves through `Default`, which follows the host's available parallelism.
        // On a single-core host that legitimately produces a single-threaded pool, so only
        // assert the invariant that holds everywhere.
        let pool_default = ThreadPool::new(0);
        assert!(pool_default.num_threads > 0);
        assert_eq!(
            pool_default.is_single_threaded(),
            pool_default.num_threads == 1
        );

        let pool_single = ThreadPool::new(1);
        assert_eq!(pool_single.num_threads, 1);
        assert!(pool_single.is_single_threaded());
        assert!(pool_single.workers.is_empty());
        assert!(pool_single.sender.is_none());

        let pool_multi = ThreadPool::new(4);
        assert_eq!(pool_multi.num_threads, 4);
        assert!(!pool_multi.is_single_threaded());
        assert_eq!(pool_multi.workers.len(), 4);
        assert!(pool_multi.sender.is_some());
    }

    #[test]
    fn test_single_threaded_execution() {
        let pool = ThreadPool::new(1);
        let counter = Arc::new(Mutex::new(0));
        let counter_clone = Arc::clone(&counter);

        pool.execute(move || {
            let mut num = counter_clone.lock().unwrap();
            *num += 1;
        })
        .expect("Failed to execute job");

        // Single-threaded mode runs the job inline, so its effect is visible immediately.
        let value = *counter.lock().unwrap();
        assert_eq!(value, 1);
    }

    #[test]
    fn test_multi_threaded_execution() {
        let pool = ThreadPool::new(2);
        let results = Arc::new(Mutex::new(Vec::new()));

        for i in 0..5 {
            let results_clone = Arc::clone(&results);
            pool.execute(move || {
                std::thread::sleep(Duration::from_millis(10));
                results_clone.lock().unwrap().push(i);
            })
            .expect("Failed to execute job");
        }

        // Dropping the pool joins the workers, so every queued job has completed afterwards.
        drop(pool);

        let final_results = results.lock().unwrap();
        assert_eq!(final_results.len(), 5);
        for i in 0..5 {
            assert!(final_results.contains(&i));
        }
    }

    #[test]
    fn test_get_num_threads() {
        let pool1 = ThreadPool::new(1);
        assert_eq!(pool1.num_threads, 1);

        let pool4 = ThreadPool::new(4);
        assert_eq!(pool4.num_threads, 4);

        let pool_default = ThreadPool::default();
        assert!(pool_default.num_threads > 0);
    }

    #[test]
    fn test_is_single_threaded() {
        let pool_single = ThreadPool::new(1);
        assert!(pool_single.is_single_threaded());

        let pool_multi = ThreadPool::new(2);
        assert!(!pool_multi.is_single_threaded());

        // The default pool is single-threaded exactly when the host reports one thread.
        let pool_default = ThreadPool::default();
        assert_eq!(
            pool_default.is_single_threaded(),
            pool_default.num_threads == 1
        );
    }

    #[test]
    fn test_display_reports_mode() {
        let pool_single = ThreadPool::new(1);
        assert!(pool_single.to_string().starts_with("Concurrency Disabled"));

        let pool_multi = ThreadPool::new(4);
        assert_eq!(
            pool_multi.to_string(),
            "Concurrency Enabled: running with 4 jobs"
        );
    }

    #[test]
    fn test_pool_graceful_shutdown() {
        let pool = ThreadPool::new(3);
        let completed = Arc::new(Mutex::new(0));

        for _ in 0..10 {
            let completed_clone = Arc::clone(&completed);
            pool.execute(move || {
                std::thread::sleep(Duration::from_millis(20));
                *completed_clone.lock().unwrap() += 1;
            })
            .expect("Failed to execute job");
        }

        drop(pool);

        assert_eq!(*completed.lock().unwrap(), 10);
    }
}