1use std::cell::RefCell;
2use std::sync::{Arc, Barrier, Mutex};
3use std::sync::atomic::{AtomicBool, Ordering};
4use std::thread::{Builder, JoinHandle, LocalKey};
5use std::time::Duration;
6
7use anyhow::anyhow;
8
9use crate::blocking_queue_adapter::BlockingQueueAdapter;
10use crate::command::Command;
11use crate::queue_type::QueueType;
12use crate::shutdown_mode::ShutdownMode;
13
/// A command that carries no data and does nothing when executed.
///
/// The pool enqueues these during shutdown so that each worker receives one
/// more item from the queue.
struct EmptyCommand;
15
16impl EmptyCommand {
17 pub fn new() -> EmptyCommand {
18 EmptyCommand {}
19 }
20}
21
22impl Command for EmptyCommand {
23 fn execute(&self) -> Result<(), anyhow::Error> {
24 Ok(())
25 }
26}
27
/// Command that runs a shared closure and then waits on a barrier.
struct RunInAllThreadsCommand {
    /// Closure invoked by the command; shared between all copies of the command.
    f: Arc<dyn Fn() + Send + Sync>,
    /// Barrier the command waits on after the closure has returned.
    b: Arc<Barrier>,
}
32
33impl RunInAllThreadsCommand {
34 pub fn new(f: Arc<dyn Fn() + Send + Sync>, b: Arc<Barrier>) -> RunInAllThreadsCommand {
35 RunInAllThreadsCommand {
36 f,
37 b,
38 }
39 }
40}
41
42impl Command for RunInAllThreadsCommand {
43 fn execute(&self) -> Result<(), anyhow::Error> {
44 {
45 (self.f)();
46 }
47 self.b.wait();
48 Ok(())
49 }
50}
51
/// Command that runs a shared, mutex-guarded `FnMut` closure and then waits
/// on a barrier.
struct RunMutInAllThreadsCommand {
    /// Closure invoked by the command; the mutex serialises calls to it.
    f: Arc<Mutex<dyn FnMut() + Send + Sync>>,
    /// Barrier the command waits on after the closure has returned.
    b: Arc<Barrier>,
}
56
57impl RunMutInAllThreadsCommand {
58 pub fn new(f: Arc<Mutex<dyn FnMut() + Send + Sync>>, b: Arc<Barrier>) -> RunMutInAllThreadsCommand {
59 RunMutInAllThreadsCommand {
60 f,
61 b,
62 }
63 }
64}
65
66impl Command for RunMutInAllThreadsCommand {
67 fn execute(&self) -> Result<(), anyhow::Error> {
68 {
69 let mut f = self.f.lock().unwrap();
70 f();
71 }
72 self.b.wait();
73 Ok(())
74 }
75}
76
77pub struct ThreadPool {
95 name: String,
96 tasks: usize,
97 queue: Arc<BlockingQueueAdapter<Box<dyn Command + Send + Sync>>>,
98 threads: Vec<JoinHandle<Result<(), anyhow::Error>>>,
99 join_error_handler: fn(String, String),
100 shutdown_mode: ShutdownMode,
101 stopped: Arc<AtomicBool>,
102 expired: bool,
103}
104
105impl ThreadPool {
106 pub(crate) fn new(
107 name: String,
108 tasks: usize,
109 queue_type: QueueType,
110 queue_size: usize,
111 join_error_handler: fn(String, String),
112 shutdown_mode: ShutdownMode,
113 ) -> Result<ThreadPool, anyhow::Error> {
114 let start_barrier = Arc::new(Barrier::new(tasks + 1));
115 let stopped = Arc::new(AtomicBool::new(false));
116 let mut threads = Vec::<JoinHandle<Result<(), anyhow::Error>>>::new();
117 let queue = Arc::new(BlockingQueueAdapter::new(queue_type, queue_size));
118 for i in 0..tasks {
119 let barrier = start_barrier.clone();
120 let t = Self::create_thread(
121 &name,
122 i,
123 barrier,
124 queue.clone(),
125 stopped.clone(),
126 );
127 threads.push(t.unwrap());
128 }
129
130 start_barrier.wait();
131
132 Ok(
133 ThreadPool {
134 name,
135 tasks,
136 queue: queue.clone(),
137 threads,
138 join_error_handler,
139 shutdown_mode,
140 stopped: stopped.clone(),
141 expired: false,
142 }
143 )
144 }
145
146 pub fn tasks(&self) -> usize {
148 self.tasks
149 }
150
151 fn create_thread(
152 name: &String,
153 index: usize,
154 barrier: Arc<Barrier>,
155 queue: Arc<BlockingQueueAdapter<Box<dyn Command + Send + Sync>>>,
156 stopped: Arc<AtomicBool>,
157 ) -> Result<JoinHandle<Result<(), anyhow::Error>>, anyhow::Error> {
158 let builder = Builder::new();
159 Ok(builder
160 .name(format!("{name}-{index}"))
161 .spawn(move || {
162 barrier.wait();
163 let mut r: Result<(), anyhow::Error> = Ok(());
164 while !stopped.load(Ordering::SeqCst) {
165 let command = queue.dequeue();
166 if let Some(c) = command {
167 match c.execute() {
168 Ok(_) => {}
169 Err(e) => {
170 r = Err(e);
171 }
172 }
173 }
174 }
175 r
176 }
177 )?
178 )
179 }
180
181 pub fn in_all_threads_mut(&self, f: Arc<Mutex<dyn FnMut() + Send + Sync>>) {
190 let b = Arc::new(Barrier::new(self.tasks + 1));
191 for _i in 0..self.tasks {
192 self.submit(Box::new(RunMutInAllThreadsCommand::new(f.clone(), b.clone())));
193 }
194 b.wait();
195 }
196
197 pub fn in_all_threads(&self, f: Arc<dyn Fn() + Send + Sync>) {
206 let b = Arc::new(Barrier::new(self.tasks + 1));
207 for _i in 0..self.tasks {
208 self.submit(Box::new(RunInAllThreadsCommand::new(f.clone(), b.clone())));
209 }
210 b.wait();
211 }
212
213 pub fn set_thread_local<T>(&self, local_key: &'static LocalKey<RefCell<T>>, val: T)
217 where T: Sync + Send + Clone {
218 self.in_all_threads(
219 Arc::new(
220 move || {
221 local_key.with(
222 |value| {
223 value.replace(val.clone())
224 }
225 );
226 }
227 )
228 );
229 }
230
231 pub fn shutdown(&mut self) {
237 self.expired = true;
238 match self.shutdown_mode {
239 ShutdownMode::Immediate => {
240 self.stopped.store(true, Ordering::SeqCst);
241 }
242 ShutdownMode::CompletePending => {
243 self.queue.wait_empty(Duration::MAX);
244 self.stopped.store(true, Ordering::SeqCst);
245 }
246 }
247 for _i in 0..self.tasks {
248 self.unchecked_submit(Box::new(EmptyCommand::new()));
249 }
250 }
251
252 pub fn join(&mut self) -> Result<(), anyhow::Error> {
254 let mut join_errors = Vec::<String>::new();
255 while let Some(t) = self.threads.pop() {
256 let name = t.thread().name().unwrap_or("unnamed").to_string();
257 match t.join() {
258 Ok(r) => {
259 match r {
260 Ok(_) => {}
261 Err(e) => {
262 let message = format!("{e:?}");
263 join_errors.push(message.clone());
264 (self.join_error_handler)(name, message);
265 }
266 }
267 }
268 Err(e) => {
269 let mut message = "Unknown error".to_string();
270 if let Some(error) = e.downcast_ref::<&'static str>() {
271 message = error.to_string();
272 }
273 join_errors.push(message.clone());
274 (self.join_error_handler)(name, message);
275 }
276 }
277 }
278 if join_errors.is_empty() {
279 Ok(())
280 } else {
281 Err(anyhow!("Errors occurred while joining threads in the {} pool: {}", self.name, join_errors.join(", "))
282 )
283 }
284 }
285
286 pub fn submit(&self, command: Box<dyn Command + Send + Sync>) {
288 self.try_submit(command, Duration::MAX);
289 }
290
291 pub fn unchecked_submit(&self, command: Box<dyn Command + Send + Sync>) {
292 self.queue.enqueue(command);
293 }
294
295 pub fn try_submit(&self, command: Box<dyn Command + Send + Sync>, timeout: Duration) -> Option<Box<dyn Command + Send + Sync>> {
299 assert!(!self.expired);
300 self.queue.try_enqueue(command, timeout)
301 }
302}
303
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicUsize, Ordering};

    use crate::shutdown_mode::ShutdownMode;
    use crate::shutdown_mode::ShutdownMode::CompletePending;
    use crate::thread_pool_builder::ThreadPoolBuilder;

    use super::*;

    /// Command that counts how many times it has been executed.
    struct TestCommand {
        _payload: i32,
        execution_counter: Arc<AtomicUsize>,
    }

    impl TestCommand {
        pub fn new(payload: i32, execution_counter: Arc<AtomicUsize>) -> TestCommand {
            TestCommand {
                _payload: payload,
                execution_counter,
            }
        }
    }

    impl Command for TestCommand {
        fn execute(&self) -> Result<(), anyhow::Error> {
            self.execution_counter.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }
    }

    /// A pool can be built, shut down and joined without errors.
    #[test]
    fn test_create() {
        let mut thread_pool_builder = ThreadPoolBuilder::new();
        let mut tp = thread_pool_builder
            .with_name("t".to_string())
            .with_tasks(4)
            .with_queue_size(8)
            .build()
            .expect("Failed to build thread pool");

        tp.shutdown();
        tp.join().expect("Failed to join thread pool");
    }

    /// Commands can be submitted and the pool still joins cleanly.
    #[test]
    fn test_submit() {
        let mut thread_pool_builder = ThreadPoolBuilder::new();
        let mut tp = thread_pool_builder
            .with_name("t".to_string())
            .with_tasks(4)
            .with_queue_size(2048)
            .build()
            .unwrap();

        let execution_counter = Arc::new(AtomicUsize::from(0));
        for _ in 0..1024 {
            tp.submit(Box::new(TestCommand::new(4, Arc::clone(&execution_counter))));
        }

        tp.shutdown();
        tp.join().expect("Failed to join thread pool");
        // A second join finds no handles left and must still succeed.
        assert!(tp.join().is_ok());
    }

    /// With `CompletePending`, every submitted command runs before shutdown.
    #[test]
    fn test_shutdown_complete_pending() {
        let mut thread_pool_builder = ThreadPoolBuilder::new();
        let mut tp = thread_pool_builder
            .with_name("t".to_string())
            .with_tasks(4)
            .with_queue_size(2048)
            .with_shutdown_mode(ShutdownMode::CompletePending)
            .build()
            .unwrap();

        let execution_counter = Arc::new(AtomicUsize::from(0));
        for _ in 0..1024 {
            tp.submit(Box::new(TestCommand::new(4, Arc::clone(&execution_counter))));
        }

        tp.shutdown();
        tp.join().expect("Failed to join thread pool");
        // A second join finds no handles left and must still succeed.
        assert!(tp.join().is_ok());
        assert_eq!(execution_counter.load(Ordering::SeqCst), 1024);
    }

    /// Command whose execution always returns an error (it does not panic).
    struct FailingTestCommand {}

    impl FailingTestCommand {
        pub fn new() -> FailingTestCommand {
            FailingTestCommand {}
        }
    }

    impl Command for FailingTestCommand {
        fn execute(&self) -> Result<(), anyhow::Error> {
            Err(anyhow!("simulating error during command execution"))
        }
    }

    /// A failing command makes `join` return an error.
    #[test]
    fn test_join_error_handler() {
        let mut thread_pool_builder = ThreadPoolBuilder::new();
        let mut tp = thread_pool_builder
            .with_name("t".to_string())
            .with_tasks(4)
            .with_shutdown_mode(CompletePending)
            .with_queue_size(8)
            .with_join_error_handler(|name, message| {
                println!("Thread {name} ended with an error {message}")
            })
            .build()
            .unwrap();

        for _ in 0..2 {
            tp.submit(Box::new(FailingTestCommand::new()));
        }

        tp.shutdown();
        let r = tp.join();
        assert!(r.is_err());
    }

    /// Submitting to a pool that has been shut down and joined must panic.
    #[test]
    #[should_panic]
    fn test_use_after_join() {
        let mut thread_pool_builder = ThreadPoolBuilder::new();
        let mut tp = thread_pool_builder
            .with_name("t".to_string())
            .with_tasks(4)
            .with_queue_size(2048)
            .with_shutdown_mode(ShutdownMode::CompletePending)
            .build()
            .unwrap();

        let execution_counter = Arc::new(AtomicUsize::from(0));
        for _ in 0..1024 {
            tp.submit(Box::new(TestCommand::new(4, Arc::clone(&execution_counter))));
        }

        tp.shutdown();
        tp.join().expect("Failed to join thread pool");

        // The pool is expired now, so the first submit is expected to panic.
        let execution_counter = Arc::new(AtomicUsize::from(0));
        for _ in 0..1024 {
            tp.submit(Box::new(TestCommand::new(4, Arc::clone(&execution_counter))));
        }
    }
}