base_threadpool/lib.rs
1#![forbid(unsafe_code)]
2
3//! Threadpool provides a way to manage and execute tasks concurrently using a fixed number of worker threads.
4//! It allows you to submit tasks that will be executed by one of the available worker threads,
5//! providing an efficient way to parallelize work across multiple threads.
6//!
//! Maintaining a pool of threads, rather than creating a new thread for each task, means the
//! overhead of thread creation and destruction is paid once per pool instead of once per task.
9//!
10//! # Examples
11//!
12//! ```
13//! use base_threadpool::{ThreadPool, ThreadPoolBuilder};
14//! use std::sync::{Arc, Mutex};
15//!
16//! let thread_pool = ThreadPoolBuilder::default().build();
17//! let value = Arc::new(Mutex::new(0));
18//!
19//! (0..4).for_each(move |_| {
20//! let value = Arc::clone(&value);
21//! thread_pool.execute(move || {
22//! let mut ir = 0;
23//! (0..100_000_000).for_each(|_| {
24//! ir += 1;
25//! });
26//!
27//! let mut lock = value.lock().unwrap();
28//! *lock += ir;
29//! });
30//! });
31//! ```
32use std::{
33 num::NonZero,
34 sync::{mpsc, Arc, Mutex},
35 thread::{self, JoinHandle},
36};
37
/// `ThreadPool` provides a way to manage and execute tasks concurrently using a fixed number of worker threads.
///
/// It allows you to submit tasks that will be executed by one of the available worker threads,
/// providing an efficient way to parallelize work across multiple threads.
#[derive(Debug)]
pub struct ThreadPool {
    /// The spawned worker threads, one entry per configured thread.
    workers: Vec<ThreadWorker>,
    /// Sending half of the job channel shared by all workers; `None` once the pool has been joined.
    producer: Option<mpsc::Sender<ThreadJob>>,
}
47
48impl ThreadPool {
49 /// Discovery method for [`ThreadPoolBuilder`].
50 ///
51 /// Returns a default [`ThreadPoolBuilder`] for constructing a new [`ThreadPool`].
52 ///
53 /// This method provides a convenient way to start building a `ThreadPool` with default settings,
54 /// which can then be customized as needed.
55 ///
56 /// # Examples
57 ///
58 /// ```
59 /// use base_threadpool::ThreadPool;
60 ///
61 /// let pool = ThreadPool::builder().build();
62 /// ```
63 pub fn builder() -> ThreadPoolBuilder {
64 ThreadPoolBuilder::default()
65 }
66
67 /// Schedules a task to be executed by the thread pool.
68 ///
69 /// Panics if the thread pool has been shut down or if there's an error sending the job.
70 ///
71 /// # Examples
72 ///
73 /// ```
74 /// use base_threadpool::ThreadPool;
75 /// use std::sync::{Arc, Mutex};
76 ///
77 /// // Create a thread pool with 4 worker threads
78 /// let pool = ThreadPool::builder().num_threads(4).build();
79 ///
80 /// // Create a list of items to process
81 /// let items = vec!["apple", "banana", "cherry", "date", "elderberry"];
82 /// let processed_items = Arc::new(Mutex::new(Vec::new()));
83 ///
84 /// // Process each item concurrently
85 /// for item in items {
86 /// let processed_items = Arc::clone(&processed_items);
87 /// pool.execute(move || {
88 /// // Simulate some processing time
89 /// std::thread::sleep(std::time::Duration::from_millis(100));
90 ///
91 /// // Process the item (in this case, convert to uppercase)
92 /// let processed = item.to_uppercase();
93 ///
94 /// // Store the processed item
95 /// processed_items.lock().unwrap().push(processed);
96 /// });
97 /// }
98 /// ```
99 #[inline]
100 pub fn execute<F>(&self, f: F)
101 where
102 F: FnOnce() + Send + 'static,
103 {
104 let job = Box::new(f);
105
106 self.producer
107 .as_ref()
108 .expect("err acquiring sender ref")
109 .send(ThreadJob::Run(job))
110 .expect("send error")
111 }
112
113 /// Waits for all worker threads in the pool to finish their current tasks and then shuts down the pool.
114 ///
115 /// This function will block until all workers have completed.
116 ///
117 /// # Examples
118 ///
119 /// ```
120 /// use base_threadpool::ThreadPool;
121 /// use std::sync::{
122 /// atomic::{AtomicU16, Ordering},
123 /// Arc,
124 /// };
125 ///
126 /// let mut pool = ThreadPool::builder().build();
127 /// let counter = Arc::new(AtomicU16::new(0));
128 ///
129 /// (0..100).for_each(|_| {
130 /// let counter = Arc::clone(&counter);
131 /// pool.execute(move || {
132 /// let _ = counter.fetch_add(1, Ordering::SeqCst);
133 /// });
134 /// });
135 ///
136 /// pool.join();
137 /// assert_eq!(counter.load(Ordering::SeqCst), 100);
138 /// ```
139 pub fn join(&mut self) {
140 (0..self.workers.len()).for_each(|_| {
141 self.producer
142 .as_ref()
143 .unwrap()
144 .send(ThreadJob::Stop)
145 .unwrap();
146 });
147
148 // make sure that the channel gets closed once the thread pool is disposed
149 drop(self.producer.take());
150
151 self.workers.iter_mut().for_each(|worker| {
152 if let Some(thread) = worker.thread.take() {
153 thread.join().unwrap();
154 }
155 });
156 }
157
158 /// Provides information about the level of concurrency available in the thread pool.
159 ///
160 /// Returns the number of worker threads in the pool.
161 ///
162 /// # Examples
163 ///
164 /// ```
165 /// use base_threadpool::ThreadPool;
166 ///
167 /// let pool = ThreadPool::builder().num_threads(4).build();
168 /// assert_eq!(pool.num_threads(), 4);
169 /// ```
170 #[doc(alias = "available_parallelism")]
171 #[doc(alias = "available_concurrency")]
172 #[doc(alias = "available_workers")]
173 #[doc(alias = "available_threads")]
174 pub fn num_threads(&self) -> usize {
175 self.workers.len()
176 }
177}
178
179impl Drop for ThreadPool {
180 fn drop(&mut self) {
181 if self.producer.is_some() {
182 // finish up the work that has already been picked up
183 self.join();
184 }
185 }
186}
187
/// A builder for configuring and creating a [`ThreadPool`].
///
/// This builder allows you to set various parameters for the thread pool,
/// such as the number of threads, stack size, and a name prefix for the threads.
///
/// # Examples
///
/// Creating a thread pool with default settings:
///
/// ```
/// use base_threadpool::ThreadPoolBuilder;
///
/// let pool = ThreadPoolBuilder::default().build();
/// ```
///
/// Creating a customized thread pool:
///
/// ```
/// use base_threadpool::ThreadPoolBuilder;
///
/// let pool = ThreadPoolBuilder::default()
///     .num_threads(4)
///     .stack_size(3 * 1024 * 1024)
///     .name_prefix("worker".to_string())
///     .build();
/// ```
#[derive(Debug)]
pub struct ThreadPoolBuilder {
    /// Number of worker threads to spawn; never zero.
    num_threads: NonZero<usize>,
    /// Stack size in bytes for each worker; `None` keeps `std::thread`'s default.
    stack_size: Option<usize>,
    /// Prefix used to name worker threads; `None` leaves them unnamed.
    name_prefix: Option<String>,
}

/// Default parameters for [`ThreadPoolBuilder`].
///
/// The default number of threads available for the [`ThreadPool`] is
/// [`std::thread::available_parallelism`]. If that value cannot be determined on the current
/// platform, a single thread is used. No custom stack size or name prefix is set.
impl Default for ThreadPoolBuilder {
    fn default() -> ThreadPoolBuilder {
        ThreadPoolBuilder {
            // `available_parallelism` is fallible (unsupported platform, restricted
            // environment); fall back to one worker instead of panicking.
            num_threads: thread::available_parallelism().unwrap_or(NonZero::<usize>::MIN),
            stack_size: None,
            name_prefix: None,
        }
    }
}
232
233impl ThreadPoolBuilder {
234 /// Constructs a new instance of [`ThreadPoolBuilder`] with specified parameters.
235 ///
236 /// # Arguments
237 ///
238 /// * `num_threads` - The number of threads in the pool. Must be greater than 0.
239 /// * `stack_size` - The stack size for each thread in bytes.
240 /// * `name_prefix` - A prefix for naming the threads in the pool.
241 ///
242 /// # Panics
243 ///
244 /// Panics if `num_threads` is 0.
245 // # Examples
246 ///
247 /// ```
248 /// use base_threadpool::ThreadPoolBuilder;
249 ///
250 /// let builder = ThreadPoolBuilder::new(4, 2 * 1024 * 1024, "custom-worker".to_string());
251 /// let pool = builder.build();
252 /// ```
253 pub fn new(num_threads: usize, stack_size: usize, name_prefix: String) -> ThreadPoolBuilder {
254 assert!(num_threads > 0);
255
256 ThreadPoolBuilder {
257 num_threads: NonZero::new(num_threads).unwrap(),
258 stack_size: Some(stack_size),
259 name_prefix: Some(name_prefix),
260 }
261 }
262
263 /// Builds and returns a new [`ThreadPool`] instance based on the current configuration.
264 ///
265 /// # Examples
266 ///
267 /// ```
268 /// use base_threadpool::ThreadPoolBuilder;
269 ///
270 /// let pool = ThreadPoolBuilder::default().num_threads(2).build();
271 /// ```
272 pub fn build(&self) -> ThreadPool {
273 let (producer, consumer) = mpsc::channel();
274 let consumer = Arc::new(Mutex::new(consumer));
275
276 let mut workers = Vec::with_capacity(self.num_threads.into());
277 (0..self.num_threads.into()).for_each(|id| {
278 let consumer = Arc::clone(&consumer);
279 let mut builder = thread::Builder::new();
280
281 if let Some(stack_size) = self.stack_size {
282 builder = builder.stack_size(stack_size);
283 }
284
285 if let Some(prefix) = &self.name_prefix {
286 builder = builder.name(format!("{}-{}", prefix, id));
287 }
288
289 let worker = ThreadWorker::new(id, consumer, builder);
290 workers.push(worker);
291 });
292
293 ThreadPool {
294 workers,
295 producer: Some(producer),
296 }
297 }
298
299 /// Sets the number of threads for the thread pool.
300 ///
301 /// # Panics
302 ///
303 /// Panics if `num_threads` is 0.
304 ///
305 /// # Examples
306 ///
307 /// ```
308 /// use base_threadpool::ThreadPoolBuilder;
309 ///
310 /// let builder = ThreadPoolBuilder::default().num_threads(8);
311 /// ```
312 pub fn num_threads(mut self, num_threads: usize) -> ThreadPoolBuilder {
313 assert!(num_threads > 0);
314
315 self.num_threads = NonZero::new(num_threads).unwrap();
316 self
317 }
318
319 /// Sets the stack size, in bytes, for each thread in the pool.
320 ///
321 /// # Examples
322 ///
323 /// ```
324 /// use base_threadpool::ThreadPoolBuilder;
325 ///
326 /// let builder = ThreadPoolBuilder::default().stack_size(4 * 1024 * 1024);
327 /// ```
328 pub fn stack_size(mut self, stack_size: usize) -> ThreadPoolBuilder {
329 self.stack_size = Some(stack_size);
330 self
331 }
332
333 /// Sets the name prefix for threads in the pool.
334 ///
335 /// # Examples
336 ///
337 /// ```
338 /// use base_threadpool::ThreadPoolBuilder;
339 ///
340 /// let builder = ThreadPoolBuilder::default().name_prefix("my-worker".to_string());
341 /// ```
342 pub fn name_prefix(mut self, name_prefix: String) -> ThreadPoolBuilder {
343 self.name_prefix = Some(name_prefix);
344 self
345 }
346}
347
/// Message sent from the pool to its workers over the shared job channel.
enum ThreadJob {
    /// Tells the receiving worker to leave its loop and exit.
    Stop,
    /// A boxed task for the receiving worker to run.
    Run(Box<dyn FnOnce() + Send + 'static>),
}

/// A single worker thread owned by the pool.
#[derive(Debug)]
struct ThreadWorker {
    /// Zero-based index of this worker.
    id: usize,
    /// Join handle of the spawned thread; `None` once it has been taken for joining.
    thread: Option<JoinHandle<()>>,
}

impl ThreadWorker {
    /// Spawns a thread from `builder` that repeatedly pulls jobs from the shared `consumer`
    /// and runs them, until it receives [`ThreadJob::Stop`] or the channel is closed.
    ///
    /// # Panics
    ///
    /// Panics if the thread cannot be spawned.
    fn new(
        id: usize,
        consumer: Arc<Mutex<mpsc::Receiver<ThreadJob>>>,
        builder: thread::Builder,
    ) -> ThreadWorker {
        let thread = builder
            .spawn(move || loop {
                // The guard is a temporary of this statement, so the lock is released before
                // the job runs and the other workers can keep receiving in the meantime.
                let message = consumer
                    .lock()
                    .expect("job receiver mutex poisoned")
                    .recv();

                match message {
                    Ok(ThreadJob::Run(job)) => job(),
                    // A closed channel means every sender is gone and no job can ever
                    // arrive again, so exit quietly instead of panicking.
                    Ok(ThreadJob::Stop) | Err(_) => break,
                }
            })
            .expect("failed to spawn worker thread");

        ThreadWorker {
            id,
            thread: Some(thread),
        }
    }
}

impl std::fmt::Display for ThreadWorker {
    /// Formats the worker as its id in square brackets, e.g. `[3]`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "[{}]", self.id)
    }
}
387
#[cfg(test)]
mod tests {
    use super::*;
    use std::{
        sync::{
            atomic::{AtomicBool, Ordering},
            Arc, Mutex,
        },
        thread, time,
    };

    mod helpers {
        use super::*;

        /// Total number of increments performed by each timing helper.
        const TARGET: usize = 100_000_000;

        /// Times `TARGET` increments on the current thread.
        pub fn get_sequential_speed() -> time::Duration {
            let mut value = 0;
            let start = time::Instant::now();

            (0..TARGET).for_each(|_| {
                // `black_box` stops the optimizer from folding the loop into a constant,
                // which would make the comparison meaningless in release builds.
                value = std::hint::black_box(value + 1);
            });

            let elapsed = start.elapsed();
            std::hint::black_box(value);
            elapsed
        }

        /// Times the same amount of work split evenly across a default-sized pool.
        pub fn get_parallel_speed() -> time::Duration {
            let mut pool = ThreadPoolBuilder::default().build();
            let num_threads = pool.num_threads();
            let value = Arc::new(Mutex::new(0));
            let start = time::Instant::now();

            assert!(num_threads > 0);

            (0..num_threads).for_each(|_| {
                let value = Arc::clone(&value);
                pool.execute(move || {
                    let mut ir = 0;
                    (0..TARGET / num_threads).for_each(|_| {
                        ir = std::hint::black_box(ir + 1);
                    });

                    let mut value = value.lock().unwrap();
                    *value += ir;
                });
            });

            pool.join();
            start.elapsed()
        }
    }

    #[test]
    fn construct_pool() {
        let mut pool = ThreadPoolBuilder::default().build();

        let p = Arc::new(Mutex::new(5));
        let v = Arc::clone(&p);
        pool.execute(move || {
            let mut lock = v.lock().unwrap();
            *lock += 1;

            thread::sleep(time::Duration::from_millis(50));
        });

        pool.execute(|| {
            thread::sleep(time::Duration::from_millis(100));
        });

        pool.join();
        assert_eq!(*p.lock().unwrap(), 6);
    }

    #[test]
    fn test_sequential_vs_parallel_speed() {
        // With a single hardware thread the pool cannot outrun sequential code, so the
        // comparison would only measure scheduling overhead.
        if thread::available_parallelism().map_or(1, |n| n.get()) < 2 {
            return;
        }

        let sequential = helpers::get_sequential_speed();
        let parallel = helpers::get_parallel_speed();

        println!("sequential speed: {sequential:#?}\nparallel speed: {parallel:#?}");
        assert!(sequential > parallel);
    }

    #[test]
    fn test_join_disposal() {
        let mut pool = ThreadPoolBuilder::default().num_threads(2).build();
        let task_completed = Arc::new(AtomicBool::new(false));
        let task_completed_clone = Arc::clone(&task_completed);

        pool.execute(move || {
            thread::sleep(time::Duration::from_millis(250));
            task_completed_clone.store(true, Ordering::SeqCst);
        });
        pool.execute(|| {
            thread::sleep(time::Duration::from_millis(100));
        });
        pool.join();

        assert!(
            task_completed.load(Ordering::SeqCst),
            "task not completed before shutdown"
        );
        assert!(
            pool.producer.is_none(),
            "producer isn't none after pool join"
        );
    }

    #[test]
    fn test_setup_builder_default() {
        let pool = ThreadPoolBuilder::default();

        assert_eq!(pool.num_threads, thread::available_parallelism().unwrap());
        assert_eq!(pool.name_prefix, None);
        assert_eq!(pool.stack_size, None);
    }

    #[test]
    fn test_setup_builder_new() {
        let pool = ThreadPoolBuilder::new(1, 5 * 1024, "PrivatePool".to_string());

        assert_eq!(pool.num_threads, NonZero::new(1).unwrap());
        assert_eq!(pool.stack_size, Some(5 * 1024));
        assert_eq!(pool.name_prefix, Some("PrivatePool".to_string()));
    }

    #[test]
    fn test_setup_builder_num_threads() {
        let pool = ThreadPoolBuilder::default().num_threads(4).build();

        assert_eq!(pool.num_threads(), 4);
    }

    #[test]
    fn test_setup_builder_prefix_name() {
        let pool = ThreadPoolBuilder::default().name_prefix("DarkPrivatisedPool".to_string());

        assert_eq!(pool.name_prefix, Some("DarkPrivatisedPool".to_string()));
    }

    #[test]
    fn test_setup_builder_stack_size() {
        let pool = ThreadPoolBuilder::default().stack_size(5 * 1024 * 1024);

        assert_eq!(pool.stack_size, Some(5 * 1024 * 1024));
    }

    #[test]
    #[should_panic(expected = "err acquiring sender ref")]
    fn test_execute_after_join_panics() {
        let mut pool = ThreadPoolBuilder::default().num_threads(2).build();

        pool.join();
        pool.execute(|| {
            println!("shouldn't execute");
        });
    }
}