command_executor/
thread_pool.rs

1use std::cell::RefCell;
2use std::sync::{Arc, Barrier, Mutex};
3use std::sync::atomic::{AtomicBool, Ordering};
4use std::thread::{Builder, JoinHandle, LocalKey};
5use std::time::Duration;
6
7use anyhow::anyhow;
8
9use crate::blocking_queue_adapter::BlockingQueueAdapter;
10use crate::command::Command;
11use crate::queue_type::QueueType;
12use crate::shutdown_mode::ShutdownMode;
13
/// A command that does nothing; it is enqueued once per worker during shutdown so that
/// workers blocked waiting for work get something to dequeue.
struct EmptyCommand {}

impl EmptyCommand {
    /// Build the no-op command.
    pub fn new() -> Self {
        Self {}
    }
}
21
22impl Command for EmptyCommand {
23    fn execute(&self) -> Result<(), anyhow::Error> {
24        Ok(())
25    }
26}
27
/// A command carrying a shared closure `f` and a shared barrier `b`; one instance is submitted
/// per worker by `ThreadPool::in_all_threads`.
struct RunInAllThreadsCommand {
    /// Closure to run on the worker that dequeues this command.
    f: Arc<dyn Fn() + Send + Sync>,
    /// Barrier shared by all instances and the submitting thread.
    b: Arc<Barrier>,
}

impl RunInAllThreadsCommand {
    /// Bundle the closure and the barrier into a command.
    pub fn new(f: Arc<dyn Fn() + Send + Sync>, b: Arc<Barrier>) -> Self {
        Self { f, b }
    }
}
41
42impl Command for RunInAllThreadsCommand {
43    fn execute(&self) -> Result<(), anyhow::Error> {
44        {
45            (self.f)();
46        }
47        self.b.wait();
48        Ok(())
49    }
50}
51
/// A command carrying a shared, mutex-protected `FnMut` and a shared barrier; one instance is
/// submitted per worker by `ThreadPool::in_all_threads_mut`.
struct RunMutInAllThreadsCommand {
    /// Closure to run; the mutex serializes calls coming from different workers.
    f: Arc<Mutex<dyn FnMut() + Send + Sync>>,
    /// Barrier shared by all instances and the submitting thread.
    b: Arc<Barrier>,
}

impl RunMutInAllThreadsCommand {
    /// Bundle the guarded closure and the barrier into a command.
    pub fn new(f: Arc<Mutex<dyn FnMut() + Send + Sync>>, b: Arc<Barrier>) -> Self {
        Self { f, b }
    }
}
65
66impl Command for RunMutInAllThreadsCommand {
67    fn execute(&self) -> Result<(), anyhow::Error> {
68        {
69            let mut f = self.f.lock().unwrap();
70            f();
71        }
72        self.b.wait();
73        Ok(())
74    }
75}
76
77/// Execute tasks concurrently while maintaining bounds on memory consumption
78///
79/// To demonstrate the use case this implementation solves let's consider a program that reads
80/// lines from a file and writes those lines to another file after some processing. The processing
81/// itself is stateless and can be done in parallel on each line, but the reading and writing must
82/// be sequential. Using this implementation we will read the input in the main thread, submit it
83/// for concurrent processing to a processing thread pool and collect it for writing in a writing thread pool
84/// with a single thread. See ./examples/read_process_write_pipeline.rs. The submission to a thread
85/// pool is done through a blocking bounded queue, so if the processing thread pool or the writing
86/// thread pool cannot keep up, their blocking queues will fill up and create a backpressure that
87/// will pause the reading. So the resulting pipeline will stabilize on a throughput commanded by the
88/// slowest stage with the memory consumption determined by sizes of queues and number of
89/// threads in each thread pool.
90///
91/// For reference see [Command Pattern](https://en.wikipedia.org/wiki/Command_pattern) and
92/// [Producer-Consumer](https://en.wikipedia.org/wiki/Producer%E2%80%93consumer_problem)
93///
94pub struct ThreadPool {
95    name: String,
96    tasks: usize,
97    queue: Arc<BlockingQueueAdapter<Box<dyn Command + Send + Sync>>>,
98    threads: Vec<JoinHandle<Result<(), anyhow::Error>>>,
99    join_error_handler: fn(String, String),
100    shutdown_mode: ShutdownMode,
101    stopped: Arc<AtomicBool>,
102    expired: bool,
103}
104
105impl ThreadPool {
106    pub(crate) fn new(
107        name: String,
108        tasks: usize,
109        queue_type: QueueType,
110        queue_size: usize,
111        join_error_handler: fn(String, String),
112        shutdown_mode: ShutdownMode,
113    ) -> Result<ThreadPool, anyhow::Error> {
114        let start_barrier = Arc::new(Barrier::new(tasks + 1));
115        let stopped = Arc::new(AtomicBool::new(false));
116        let mut threads = Vec::<JoinHandle<Result<(), anyhow::Error>>>::new();
117        let queue = Arc::new(BlockingQueueAdapter::new(queue_type, queue_size));
118        for i in 0..tasks {
119            let barrier = start_barrier.clone();
120            let t = Self::create_thread(
121                &name,
122                i,
123                barrier,
124                queue.clone(),
125                stopped.clone(),
126            );
127            threads.push(t.unwrap());
128        }
129
130        start_barrier.wait();
131
132        Ok(
133            ThreadPool {
134                name,
135                tasks,
136                queue: queue.clone(),
137                threads,
138                join_error_handler,
139                shutdown_mode,
140                stopped: stopped.clone(),
141                expired: false,
142            }
143        )
144    }
145
146    /// Get the number of concurrent threads in the thread pool
147    pub fn tasks(&self) -> usize {
148        self.tasks
149    }
150
151    fn create_thread(
152        name: &String,
153        index: usize,
154        barrier: Arc<Barrier>,
155        queue: Arc<BlockingQueueAdapter<Box<dyn Command + Send + Sync>>>,
156        stopped: Arc<AtomicBool>,
157    ) -> Result<JoinHandle<Result<(), anyhow::Error>>, anyhow::Error> {
158        let builder = Builder::new();
159        Ok(builder
160            .name(format!("{name}-{index}"))
161            .spawn(move || {
162                barrier.wait();
163                let mut r: Result<(), anyhow::Error> = Ok(());
164                while !stopped.load(Ordering::SeqCst) {
165                    let command = queue.dequeue();
166                    if let Some(c) = command {
167                        match c.execute() {
168                            Ok(_) => {}
169                            Err(e) => {
170                                r = Err(e);
171                            }
172                        }
173                    }
174                }
175                r
176            }
177            )?
178        )
179    }
180
181    /// Execute f in all threads.
182    ///
183    /// This function returns only after f had completed in all threads. Can be used to collect
184    /// data produced by the threads. See ./examples/fetch_thread_local.rs.
185    ///
186    /// Caveat: this is a [barrier](https://en.wikipedia.org/wiki/Barrier_%28computer_science%29)
187    /// function. So if one of the threads is busy with a long running task or is deadlocked, this
188    /// will halt all the threads until f can be executed.
189    pub fn in_all_threads_mut(&self, f: Arc<Mutex<dyn FnMut() + Send + Sync>>) {
190        let b = Arc::new(Barrier::new(self.tasks + 1));
191        for _i in 0..self.tasks {
192            self.submit(Box::new(RunMutInAllThreadsCommand::new(f.clone(), b.clone())));
193        }
194        b.wait();
195    }
196
197    /// Execute f in all threads.
198    ///
199    /// This function returns only after f had completed in all threads. Can be used to flush
200    /// data produced by the threads or simply execute work concurrently. See ./examples/flush_thread_local.rs.
201    ///
202    /// Caveat: this is a [barrier](https://en.wikipedia.org/wiki/Barrier_%28computer_science%29)
203    /// function. So if one of the threads is busy with a long running task or is deadlocked, this
204    /// will halt all the threads until f can be executed.
205    pub fn in_all_threads(&self, f: Arc<dyn Fn() + Send + Sync>) {
206        let b = Arc::new(Barrier::new(self.tasks + 1));
207        for _i in 0..self.tasks {
208            self.submit(Box::new(RunInAllThreadsCommand::new(f.clone(), b.clone())));
209        }
210        b.wait();
211    }
212
213    /// Initializes the `local_key` to contain `val`.
214    ///
215    /// See ./examples/thread_local.rs
216    pub fn set_thread_local<T>(&self, local_key: &'static LocalKey<RefCell<T>>, val: T)
217        where T: Sync + Send + Clone {
218        self.in_all_threads(
219            Arc::new(
220                move || {
221                    local_key.with(
222                        |value| {
223                            value.replace(val.clone())
224                        }
225                    );
226                }
227            )
228        );
229    }
230
231    /// Shut down the thread pool.
232    ///
233    /// This will shut down the thread pool according to configuration. When configured with
234    /// * [ShutdownMode::Immediate] - terminate each tread after completing the current task
235    /// * [ShutdownMode::CompletePending] - terminate after completing all pending tasks
236    pub fn shutdown(&mut self) {
237        self.expired = true;
238        match self.shutdown_mode {
239            ShutdownMode::Immediate => {
240                self.stopped.store(true, Ordering::SeqCst);
241            }
242            ShutdownMode::CompletePending => {
243                self.queue.wait_empty(Duration::MAX);
244                self.stopped.store(true, Ordering::SeqCst);
245            }
246        }
247        for _i in 0..self.tasks {
248            self.unchecked_submit(Box::new(EmptyCommand::new()));
249        }
250    }
251
252    /// Wait until all thread pool threads completed.
253    pub fn join(&mut self) -> Result<(), anyhow::Error> {
254        let mut join_errors = Vec::<String>::new();
255        while let Some(t) = self.threads.pop() {
256            let name = t.thread().name().unwrap_or("unnamed").to_string();
257            match t.join() {
258                Ok(r) => {
259                    match r {
260                        Ok(_) => {}
261                        Err(e) => {
262                            let message = format!("{e:?}");
263                            join_errors.push(message.clone());
264                            (self.join_error_handler)(name, message);
265                        }
266                    }
267                }
268                Err(e) => {
269                    let mut message = "Unknown error".to_string();
270                    if let Some(error) = e.downcast_ref::<&'static str>() {
271                        message = error.to_string();
272                    }
273                    join_errors.push(message.clone());
274                    (self.join_error_handler)(name, message);
275                }
276            }
277        }
278        if join_errors.is_empty() {
279            Ok(())
280        } else {
281            Err(anyhow!("Errors occurred while joining threads in the {} pool: {}", self.name, join_errors.join(", "))
282            )
283        }
284    }
285
286    /// Submit command for execution
287    pub fn submit(&self, command: Box<dyn Command + Send + Sync>) {
288        self.try_submit(command, Duration::MAX);
289    }
290
291    pub fn unchecked_submit(&self, command: Box<dyn Command + Send + Sync>) {
292        self.queue.enqueue(command);
293    }
294
295    /// Submit command for execution with timeout
296    ///
297    /// Returns the command on failure and None on success
298    pub fn try_submit(&self, command: Box<dyn Command + Send + Sync>, timeout: Duration) -> Option<Box<dyn Command + Send + Sync>> {
299        assert!(!self.expired);
300        self.queue.try_enqueue(command, timeout)
301    }
302}
303
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicUsize, Ordering};

    use crate::shutdown_mode::ShutdownMode;
    use crate::thread_pool_builder::ThreadPoolBuilder;

    use super::*;

    /// Command that counts how many times it has been executed.
    struct TestCommand {
        _payload: i32,
        execution_counter: Arc<AtomicUsize>,
    }

    impl TestCommand {
        pub fn new(payload: i32, execution_counter: Arc<AtomicUsize>) -> TestCommand {
            TestCommand {
                _payload: payload,
                execution_counter,
            }
        }
    }

    impl Command for TestCommand {
        fn execute(&self) -> Result<(), anyhow::Error> {
            self.execution_counter.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }
    }

    /// Command whose execution always fails with an error (it does not panic).
    struct FailingTestCommand {}

    impl FailingTestCommand {
        pub fn new() -> FailingTestCommand {
            FailingTestCommand {}
        }
    }

    impl Command for FailingTestCommand {
        fn execute(&self) -> Result<(), anyhow::Error> {
            Err(anyhow!("simulating error during command execution"))
        }
    }

    /// Submit `count` counting commands to `tp` and return the counter they share.
    fn submit_counting(tp: &ThreadPool, count: usize) -> Arc<AtomicUsize> {
        let execution_counter = Arc::new(AtomicUsize::new(0));
        for _ in 0..count {
            tp.submit(Box::new(TestCommand::new(4, Arc::clone(&execution_counter))));
        }
        execution_counter
    }

    #[test]
    fn test_create() {
        let mut thread_pool_builder = ThreadPoolBuilder::new();
        let mut tp = thread_pool_builder
            .with_name("t".to_string())
            .with_tasks(4)
            .with_queue_size(8)
            .build()
            .expect("Failed to build thread pool");

        tp.shutdown();
        tp.join().expect("Failed to join thread pool");
    }

    #[test]
    fn test_submit() {
        let mut thread_pool_builder = ThreadPoolBuilder::new();
        let mut tp = thread_pool_builder
            .with_name("t".to_string())
            .with_tasks(4)
            .with_queue_size(2048)
            .build()
            .expect("Failed to build thread pool");

        // With the default shutdown mode, how many commands run before the workers stop is
        // timing dependent, so the counter is deliberately not asserted here.
        let _execution_counter = submit_counting(&tp, 1024);

        tp.shutdown();
        tp.join().expect("Failed to join thread pool");
        // Joining an already joined pool is a no-op.
        tp.join().expect("Second join should succeed");
    }

    #[test]
    fn test_shutdown_complete_pending() {
        let mut thread_pool_builder = ThreadPoolBuilder::new();
        let mut tp = thread_pool_builder
            .with_name("t".to_string())
            .with_tasks(4)
            .with_queue_size(2048)
            .with_shutdown_mode(ShutdownMode::CompletePending)
            .build()
            .expect("Failed to build thread pool");

        let execution_counter = submit_counting(&tp, 1024);

        tp.shutdown();
        tp.join().expect("Failed to join thread pool");
        // Joining an already joined pool is a no-op.
        tp.join().expect("Second join should succeed");
        assert_eq!(execution_counter.load(Ordering::SeqCst), 1024);
    }

    #[test]
    fn test_join_error_handler() {
        let mut thread_pool_builder = ThreadPoolBuilder::new();
        let mut tp = thread_pool_builder
            .with_name("t".to_string())
            .with_tasks(4)
            .with_shutdown_mode(ShutdownMode::CompletePending)
            .with_queue_size(8)
            .with_join_error_handler(|name, message| {
                println!("Thread {name} ended with an error {message}")
            })
            .build()
            .expect("Failed to build thread pool");

        for _ in 0..2 {
            tp.submit(Box::new(FailingTestCommand::new()));
        }

        tp.shutdown();
        assert!(tp.join().is_err());
    }

    #[test]
    #[should_panic]
    fn test_use_after_join() {
        let mut thread_pool_builder = ThreadPoolBuilder::new();
        let mut tp = thread_pool_builder
            .with_name("t".to_string())
            .with_tasks(4)
            .with_queue_size(2048)
            .with_shutdown_mode(ShutdownMode::CompletePending)
            .build()
            .expect("Failed to build thread pool");

        submit_counting(&tp, 1024);

        tp.shutdown();
        tp.join().expect("Failed to join thread pool");
        // Submitting after shutdown must panic.
        submit_counting(&tp, 1024);
    }
}
469}