base_threadpool/
lib.rs

#![forbid(unsafe_code)]

//! A thread pool provides a way to manage and execute tasks concurrently using a fixed number of worker threads.
//! It allows you to submit tasks that will be executed by one of the available worker threads,
//! providing an efficient way to parallelize work across multiple threads.
//!
//! Maintaining a pool of threads, rather than creating a new thread for each task, restricts the
//! overhead of thread creation and destruction to the initial construction of the pool.
//!
//! # Examples
//!
//! ```
//! use base_threadpool::{ThreadPool, ThreadPoolBuilder};
//! use std::sync::{Arc, Mutex};
//!
//! let thread_pool = ThreadPoolBuilder::default().build();
//! let value = Arc::new(Mutex::new(0));
//!
//! (0..4).for_each(move |_| {
//!     let value = Arc::clone(&value);
//!     thread_pool.execute(move || {
//!         let mut ir = 0;
//!         (0..100_000_000).for_each(|_| {
//!             ir += 1;
//!         });
//!
//!         let mut lock = value.lock().unwrap();
//!         *lock += ir;
//!     });
//! });
//! ```

use std::{
    num::NonZero,
    sync::{mpsc, Arc, Mutex},
    thread::{self, JoinHandle},
};

/// `ThreadPool` provides a way to manage and execute tasks concurrently using a fixed number of worker threads.
///
/// It allows you to submit tasks that will be executed by one of the available worker threads,
/// providing an efficient way to parallelize work across multiple threads.
#[derive(Debug)]
pub struct ThreadPool {
    workers: Vec<ThreadWorker>,
    producer: Option<mpsc::Sender<ThreadJob>>,
}

impl ThreadPool {
    /// Discovery method for [`ThreadPoolBuilder`].
    ///
    /// Returns a default [`ThreadPoolBuilder`] for constructing a new [`ThreadPool`].
    ///
    /// This method provides a convenient way to start building a `ThreadPool` with default settings,
    /// which can then be customized as needed.
    ///
    /// # Examples
    ///
    /// ```
    /// use base_threadpool::ThreadPool;
    ///
    /// let pool = ThreadPool::builder().build();
    /// ```
    pub fn builder() -> ThreadPoolBuilder {
        ThreadPoolBuilder::default()
    }

    /// Schedules a task to be executed by the thread pool.
    ///
    /// # Panics
    ///
    /// Panics if the thread pool has been shut down or if there is an error sending the job.
    ///
    /// # Examples
    ///
    /// ```
    /// use base_threadpool::ThreadPool;
    /// use std::sync::{Arc, Mutex};
    ///
    /// // Create a thread pool with 4 worker threads.
    /// let pool = ThreadPool::builder().num_threads(4).build();
    ///
    /// // Create a list of items to process.
    /// let items = vec!["apple", "banana", "cherry", "date", "elderberry"];
    /// let processed_items = Arc::new(Mutex::new(Vec::new()));
    ///
    /// // Process each item concurrently.
    /// for item in items {
    ///     let processed_items = Arc::clone(&processed_items);
    ///     pool.execute(move || {
    ///         // Simulate some processing time.
    ///         std::thread::sleep(std::time::Duration::from_millis(100));
    ///
    ///         // Process the item (in this case, convert it to uppercase).
    ///         let processed = item.to_uppercase();
    ///
    ///         // Store the processed item.
    ///         processed_items.lock().unwrap().push(processed);
    ///     });
    /// }
    /// ```
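    ///
    /// Once the pool has been joined it no longer accepts work, and `execute` panics
    /// because the internal job sender has already been dropped. A minimal sketch of
    /// that case:
    ///
    /// ```should_panic
    /// use base_threadpool::ThreadPool;
    ///
    /// let mut pool = ThreadPool::builder().num_threads(1).build();
    /// pool.join();
    ///
    /// // The pool has shut down; this call panics.
    /// pool.execute(|| println!("never runs"));
    /// ```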
    #[inline]
    pub fn execute<F>(&self, f: F)
    where
        F: FnOnce() + Send + 'static,
    {
        let job = Box::new(f);

        self.producer
            .as_ref()
            .expect("err acquiring sender ref")
            .send(ThreadJob::Run(job))
            .expect("send error")
    }

    /// Waits for all tasks that have been submitted to the pool to finish, then shuts the pool down.
    ///
    /// This function blocks until every worker thread has exited.
    ///
    /// # Examples
    ///
    /// ```
    /// use base_threadpool::ThreadPool;
    /// use std::sync::{
    ///     atomic::{AtomicU16, Ordering},
    ///     Arc,
    /// };
    ///
    /// let mut pool = ThreadPool::builder().build();
    /// let counter = Arc::new(AtomicU16::new(0));
    ///
    /// (0..100).for_each(|_| {
    ///     let counter = Arc::clone(&counter);
    ///     pool.execute(move || {
    ///         let _ = counter.fetch_add(1, Ordering::SeqCst);
    ///     });
    /// });
    ///
    /// pool.join();
    /// assert_eq!(counter.load(Ordering::SeqCst), 100);
    /// ```
    pub fn join(&mut self) {
        // Send one stop message per worker; these are queued behind any jobs already submitted.
        (0..self.workers.len()).for_each(|_| {
            self.producer
                .as_ref()
                .unwrap()
                .send(ThreadJob::Stop)
                .unwrap();
        });

        // Make sure the channel gets closed once the thread pool is disposed.
        drop(self.producer.take());

        self.workers.iter_mut().for_each(|worker| {
            if let Some(thread) = worker.thread.take() {
                thread.join().unwrap();
            }
        });
    }

    /// Provides information about the level of concurrency available in the thread pool.
    ///
    /// Returns the number of worker threads in the pool.
    ///
    /// # Examples
    ///
    /// ```
    /// use base_threadpool::ThreadPool;
    ///
    /// let pool = ThreadPool::builder().num_threads(4).build();
    /// assert_eq!(pool.num_threads(), 4);
    /// ```
    #[doc(alias = "available_parallelism")]
    #[doc(alias = "available_concurrency")]
    #[doc(alias = "available_workers")]
    #[doc(alias = "available_threads")]
    pub fn num_threads(&self) -> usize {
        self.workers.len()
    }
}

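/// Dropping a [`ThreadPool`] waits for all work that has already been submitted to
/// finish, by delegating to [`ThreadPool::join`]. The example below is a minimal
/// sketch of that behaviour using only the public API.
///
/// ```
/// use base_threadpool::ThreadPool;
/// use std::sync::{
///     atomic::{AtomicBool, Ordering},
///     Arc,
/// };
///
/// let done = Arc::new(AtomicBool::new(false));
/// {
///     let pool = ThreadPool::builder().num_threads(1).build();
///     let done = Arc::clone(&done);
///     pool.execute(move || done.store(true, Ordering::SeqCst));
/// } // `pool` is dropped here, joining the outstanding task
/// assert!(done.load(Ordering::SeqCst));
/// ```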
impl Drop for ThreadPool {
    fn drop(&mut self) {
        if self.producer.is_some() {
            // Finish any work that has already been submitted before shutting down.
            self.join();
        }
    }
}

/// A builder for configuring and creating a [`ThreadPool`].
///
/// This builder allows you to set various parameters for the thread pool,
/// such as the number of threads, stack size, and a name prefix for the threads.
///
/// # Examples
///
/// Creating a thread pool with default settings:
///
/// ```
/// use base_threadpool::ThreadPoolBuilder;
///
/// let pool = ThreadPoolBuilder::default().build();
/// ```
///
/// Creating a customized thread pool:
///
/// ```
/// use base_threadpool::ThreadPoolBuilder;
///
/// let pool = ThreadPoolBuilder::default()
///     .num_threads(4)
///     .stack_size(3 * 1024 * 1024)
///     .name_prefix("worker".to_string())
///     .build();
/// ```
#[derive(Debug)]
pub struct ThreadPoolBuilder {
    num_threads: NonZero<usize>,
    stack_size: Option<usize>,
    name_prefix: Option<String>,
}

/// Default parameters for [`ThreadPoolBuilder`].
///
/// The default number of worker threads is [`std::thread::available_parallelism`].
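///
/// # Examples
///
/// A small sketch verifying that the default pool size matches the host's available parallelism:
///
/// ```
/// use base_threadpool::ThreadPoolBuilder;
/// use std::thread;
///
/// let pool = ThreadPoolBuilder::default().build();
/// assert_eq!(pool.num_threads(), thread::available_parallelism().unwrap().get());
/// ```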
impl Default for ThreadPoolBuilder {
    fn default() -> ThreadPoolBuilder {
        ThreadPoolBuilder {
            num_threads: thread::available_parallelism().unwrap(),
            stack_size: Option::default(),
            name_prefix: Option::default(),
        }
    }
}

impl ThreadPoolBuilder {
    /// Constructs a new instance of [`ThreadPoolBuilder`] with specified parameters.
    ///
    /// # Arguments
    ///
    /// * `num_threads` - The number of threads in the pool. Must be greater than 0.
    /// * `stack_size` - The stack size for each thread, in bytes.
    /// * `name_prefix` - A prefix for naming the threads in the pool.
    ///
    /// # Panics
    ///
    /// Panics if `num_threads` is 0.
    ///
    /// # Examples
    ///
    /// ```
    /// use base_threadpool::ThreadPoolBuilder;
    ///
    /// let builder = ThreadPoolBuilder::new(4, 2 * 1024 * 1024, "custom-worker".to_string());
    /// let pool = builder.build();
    /// ```
    pub fn new(num_threads: usize, stack_size: usize, name_prefix: String) -> ThreadPoolBuilder {
        assert!(num_threads > 0);

        ThreadPoolBuilder {
            num_threads: NonZero::new(num_threads).unwrap(),
            stack_size: Some(stack_size),
            name_prefix: Some(name_prefix),
        }
    }

    /// Builds and returns a new [`ThreadPool`] instance based on the current configuration.
    ///
    /// # Examples
    ///
    /// ```
    /// use base_threadpool::ThreadPoolBuilder;
    ///
    /// let pool = ThreadPoolBuilder::default().num_threads(2).build();
    /// ```
    pub fn build(&self) -> ThreadPool {
        let (producer, consumer) = mpsc::channel();
        let consumer = Arc::new(Mutex::new(consumer));

        let mut workers = Vec::with_capacity(self.num_threads.into());
        (0..self.num_threads.into()).for_each(|id| {
            let consumer = Arc::clone(&consumer);
            let mut builder = thread::Builder::new();

            if let Some(stack_size) = self.stack_size {
                builder = builder.stack_size(stack_size);
            }

            if let Some(prefix) = &self.name_prefix {
                builder = builder.name(format!("{}-{}", prefix, id));
            }

            let worker = ThreadWorker::new(id, consumer, builder);
            workers.push(worker);
        });

        ThreadPool {
            workers,
            producer: Some(producer),
        }
    }

    /// Sets the number of threads for the thread pool.
    ///
    /// # Panics
    ///
    /// Panics if `num_threads` is 0.
    ///
    /// # Examples
    ///
    /// ```
    /// use base_threadpool::ThreadPoolBuilder;
    ///
    /// let builder = ThreadPoolBuilder::default().num_threads(8);
    /// ```
    pub fn num_threads(mut self, num_threads: usize) -> ThreadPoolBuilder {
        assert!(num_threads > 0);

        self.num_threads = NonZero::new(num_threads).unwrap();
        self
    }

    /// Sets the stack size, in bytes, for each thread in the pool.
    ///
    /// # Examples
    ///
    /// ```
    /// use base_threadpool::ThreadPoolBuilder;
    ///
    /// let builder = ThreadPoolBuilder::default().stack_size(4 * 1024 * 1024);
    /// ```
    pub fn stack_size(mut self, stack_size: usize) -> ThreadPoolBuilder {
        self.stack_size = Some(stack_size);
        self
    }

    /// Sets the name prefix for threads in the pool.
    ///
    /// # Examples
    ///
    /// ```
    /// use base_threadpool::ThreadPoolBuilder;
    ///
    /// let builder = ThreadPoolBuilder::default().name_prefix("my-worker".to_string());
    /// ```
    pub fn name_prefix(mut self, name_prefix: String) -> ThreadPoolBuilder {
        self.name_prefix = Some(name_prefix);
        self
    }
}

/// A message sent from the pool to its workers: either run a job or stop the worker loop.
enum ThreadJob {
    Stop,
    Run(Box<dyn FnOnce() + Send + 'static>),
}

/// A single worker thread that pulls [`ThreadJob`]s off the shared receiver until it
/// is told to stop.
#[derive(Debug)]
struct ThreadWorker {
    id: usize,
    thread: Option<JoinHandle<()>>,
}

impl ThreadWorker {
    fn new(
        id: usize,
        consumer: Arc<Mutex<mpsc::Receiver<ThreadJob>>>,
        builder: thread::Builder,
    ) -> ThreadWorker {
        let thread = builder
            .spawn(move || loop {
                // The mutex guard is a temporary that is dropped at the end of this
                // statement, so the lock is not held while the job runs.
                let job = consumer.lock().unwrap().recv().unwrap();

                match job {
                    ThreadJob::Run(job) => job(),
                    ThreadJob::Stop => break,
                };
            })
            .unwrap();

        ThreadWorker {
            id,
            thread: Some(thread),
        }
    }
}

impl std::fmt::Display for ThreadWorker {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "[{}]", self.id)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::{
        sync::{
            atomic::{AtomicBool, Ordering},
            Arc, Mutex,
        },
        thread, time,
    };

    mod helpers {
        use super::*;

        const TARGET: usize = 500_000_000;

        pub fn get_sequential_speed() -> time::Duration {
            let mut value = 0;
            let start = time::Instant::now();

            (0..TARGET).for_each(|_| {
                value += 1;
            });

            start.elapsed()
        }

        pub fn get_parallel_speed() -> time::Duration {
            let mut pool = ThreadPoolBuilder::default().build();
            let num_threads = pool.num_threads();
            let value = Arc::new(Mutex::new(0));
            let start = time::Instant::now();

            assert!(num_threads > 0);

            (0..num_threads).for_each(|_| {
                let value = Arc::clone(&value);
                let mut ir = 0;
                pool.execute(move || {
                    (0..TARGET / num_threads).for_each(|_| {
                        ir += 1;
                    });

                    let mut value = value.lock().unwrap();
                    *value += ir;
                });
            });

            pool.join();
            start.elapsed()
        }
    }

    #[test]
    fn construct_pool() {
        let mut pool = ThreadPoolBuilder::default().build();

        let p = Arc::new(Mutex::new(5));
        let v = Arc::clone(&p);
        pool.execute(move || {
            let mut lock = v.lock().unwrap();
            *lock += 1;

            thread::sleep(time::Duration::from_secs(5));
        });

        pool.execute(|| {
            thread::sleep(time::Duration::from_secs(10));
        });

        pool.join();
        assert_eq!(*p.lock().unwrap(), 6);
    }

    #[test]
    fn test_sequential_vs_parallel_speed() {
        let sequential = helpers::get_sequential_speed();
        let parallel = helpers::get_parallel_speed();

        println!("sequential speed: {sequential:#?}\nparallel speed: {parallel:#?}");
        assert!(sequential > parallel);
        assert!(sequential > parallel / 2);
    }

    #[test]
    fn test_join_disposal() {
        let mut pool = ThreadPoolBuilder::default().num_threads(2).build();
        let task_completed = Arc::new(AtomicBool::new(false));
        let task_completed_clone = Arc::clone(&task_completed);

        pool.execute(move || {
            thread::sleep(time::Duration::from_millis(2500));
            task_completed_clone.store(true, Ordering::SeqCst);
        });
        pool.execute(|| {
            thread::sleep(time::Duration::from_secs(1));
        });
        pool.join();

        assert!(
            task_completed.load(Ordering::SeqCst),
            "task not completed before shutdown"
        );
        assert!(
            pool.producer.is_none(),
            "producer isn't none after pool join"
        );
    }

    #[test]
    fn test_setup_builder_default() {
        let pool = ThreadPoolBuilder::default();

        assert_eq!(pool.num_threads, thread::available_parallelism().unwrap());
        assert_eq!(pool.name_prefix, None);
        assert_eq!(pool.stack_size, None);
    }

    #[test]
    fn test_setup_builder_new() {
        let pool = ThreadPoolBuilder::new(1, 5 * 1024, "PrivatePool".to_string());

        assert_eq!(pool.num_threads, NonZero::new(1).unwrap());
        assert_eq!(pool.stack_size, Some(5 * 1024));
        assert_eq!(pool.name_prefix, Some("PrivatePool".to_string()));
    }

    #[test]
    fn test_setup_builder_num_threads() {
        let pool = ThreadPoolBuilder::default().num_threads(4).build();

        assert_eq!(pool.num_threads(), 4);
    }

    #[test]
    fn test_setup_builder_prefix_name() {
        let pool = ThreadPoolBuilder::default().name_prefix("DarkPrivatisedPool".to_string());

        assert_eq!(pool.name_prefix, Some("DarkPrivatisedPool".to_string()));
    }

    #[test]
    fn test_setup_builder_stack_size() {
        let pool = ThreadPoolBuilder::default().stack_size(5 * 1024 * 1024);

        assert_eq!(pool.stack_size, Some(5 * 1024 * 1024));
    }

    #[test]
    #[should_panic(expected = "err acquiring sender ref")]
    fn test_execute_after_join_panics() {
        let mut pool = ThreadPoolBuilder::default().num_threads(2).build();

        pool.join();
        pool.execute(|| {
            println!("shouldn't execute");
        });
    }
}