easy_threadpool/lib.rs
#![forbid(missing_docs)]
// License at https://github.com/NicoElbers/easy_threadpool
// It's MIT

//! A simple thread pool to execute jobs in parallel
//!
//! A simple crate without dependencies which allows you to create a threadpool
//! with a specified number of threads that execute the jobs you send. Threads don't
//! crash when a job panics!
//!
//! # Examples
//!
//! ## Basic usage
//!
//! A basic use of the threadpool
//!
//! ```rust
//! # use std::error::Error;
//! # fn main() -> Result<(), Box<dyn Error>> {
//! use easy_threadpool::ThreadPoolBuilder;
//!
//! fn job() {
//!     println!("Hello world!");
//! }
//!
//! let builder = ThreadPoolBuilder::with_max_threads()?;
//! let pool = builder.build()?;
//!
//! for _ in 0..10 {
//!     pool.send_job(job);
//! }
//!
//! assert!(pool.wait_until_finished().is_ok());
//! # Ok(())
//! # }
//! ```
//!
//! ## More advanced usage
//!
//! A slightly more advanced usage of the threadpool
//!
//! ```rust
//! # use std::error::Error;
//! # fn main() -> Result<(), Box<dyn Error>> {
//! use easy_threadpool::ThreadPoolBuilder;
//! use std::sync::mpsc::channel;
//!
//! let builder = ThreadPoolBuilder::with_max_threads()?;
//! let pool = builder.build()?;
//!
//! let (tx, rx) = channel();
//!
//! for _ in 0..10 {
//!     let tx = tx.clone();
//!     pool.send_job(move || {
//!         tx.send(1).expect("Receiver should still exist");
//!     });
//! }
//!
//! assert!(pool.wait_until_finished().is_ok());
//!
//! assert_eq!(rx.iter().take(10).fold(0, |a, b| a + b), 10);
//! # Ok(())
//! # }
//! ```
//!
//! ## Dealing with panics
//!
//! This threadpool implementation is resistant to jobs panicking
//!
//! ```rust
//! # use std::error::Error;
//! # fn main() -> Result<(), Box<dyn Error>> {
//! use easy_threadpool::ThreadPoolBuilder;
//! use std::sync::mpsc::channel;
//! use std::num::NonZeroUsize;
//!
//! fn panic_fn() {
//!     panic!("Test panic");
//! }
//!
//! let num = NonZeroUsize::try_from(1)?;
//! let builder = ThreadPoolBuilder::with_thread_amount(num);
//! let pool = builder.build()?;
//!
//! let (tx, rx) = channel();
//! for _ in 0..10 {
//!     let tx = tx.clone();
//!     pool.send_job(move || {
//!         tx.send(1).expect("Receiver should still exist");
//!         panic!("Test panic");
//!     });
//! }
//!
//! assert!(pool.wait_until_finished().is_err());
//! pool.wait_until_finished_unchecked();
//!
//! assert_eq!(pool.jobs_paniced(), 10);
//! assert_eq!(rx.iter().take(10).fold(0, |a, b| a + b), 10);
//! # Ok(())
//! # }
//! ```

use std::{
    error::Error,
    fmt::{Debug, Display},
    io,
    num::{NonZeroUsize, TryFromIntError},
    panic::{catch_unwind, UnwindSafe},
    sync::{
        atomic::{AtomicBool, AtomicUsize, Ordering},
        mpsc::{channel, Sender},
        Arc, Condvar, Mutex,
    },
    thread::{self, available_parallelism},
};

type ThreadPoolFunctionBoxed = Box<dyn FnOnce() + Send + UnwindSafe>;

/// Simple error to indicate that a job has panicked in the threadpool
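///
/// # Examples
///
/// A short sketch of how this error surfaces from
/// [`ThreadPool::wait_until_finished`] when a job panics:
///
/// ```rust
/// # use std::error::Error;
/// # fn main() -> Result<(), Box<dyn Error>> {
/// use easy_threadpool::ThreadPoolBuilder;
///
/// let pool = ThreadPoolBuilder::with_max_threads()?.build()?;
///
/// pool.send_job(|| panic!("Test panic"));
///
/// // The panic is reported once all jobs have been handled.
/// let err = pool.wait_until_finished().unwrap_err();
/// println!("{err}");
/// # Ok(())
/// # }
/// ```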
121#[derive(Debug)]
122pub struct JobHasPanicedError {}
123
124impl Display for JobHasPanicedError {
125 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
126 write!(f, "At least one job in the threadpool has caused a panic")
127 }
128}
129
130impl Error for JobHasPanicedError {}
131
132// /// Simple error to indicate a function passed to do_until_finished has paniced
133// #[derive(Debug)]
134// pub struct DoUntilFinishedFunctionPanicedError {}
135
136// impl Display for DoUntilFinishedFunctionPanicedError {
137// fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
138// write!(f, "The function passed to do_until_finished has paniced")
139// }
140// }
141
142// impl Error for DoUntilFinishedFunctionPanicedError {}
143
144// /// An enum to combine both errors previously defined
145// #[derive(Debug)]
146// pub enum Errors {
147// /// Enum representation of [`JobHasPanicedError`]
148// JobHasPanicedError,
149// /// Enum representation of [`DoUntilFinishedFunctionPanicedError`]
150// DoUntilFinishedFunctionPanicedError,
151// }
152
153// impl Display for Errors {
154// fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
155// match self {
156// Errors::DoUntilFinishedFunctionPanicedError => {
157// Display::fmt(&DoUntilFinishedFunctionPanicedError {}, f)
158// }
159// Errors::JobHasPanicedError => Display::fmt(&JobHasPanicedError {}, f),
160// }
161// }
162// }
163
164// impl Error for Errors {}
165
166#[derive(Debug, Default)]
167struct SharedState {
168 jobs_queued: AtomicUsize,
169 jobs_running: AtomicUsize,
170 jobs_paniced: AtomicUsize,
171 is_finished: Mutex<bool>,
172 has_paniced: AtomicBool,
173}
174
175impl Display for SharedState {
176 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
177 write!(
178 f,
179 "SharedState<jobs_queued: {}, jobs_running: {}, jobs_paniced: {}, is_finished: {}, has_paniced: {}>",
180 self.jobs_queued.load(Ordering::Relaxed),
181 self.jobs_running.load(Ordering::Relaxed),
182 self.jobs_paniced.load(Ordering::Relaxed),
183 self.is_finished.lock().expect("Shared state should never panic"),
184 self.has_paniced.load(Ordering::Relaxed)
185 )
186 }
187}
188
189impl SharedState {
190 fn new() -> Self {
191 Self {
192 jobs_running: AtomicUsize::new(0),
193 jobs_queued: AtomicUsize::new(0),
194 jobs_paniced: AtomicUsize::new(0),
195 is_finished: Mutex::new(true),
196 has_paniced: AtomicBool::new(false),
197 }
198 }
199
200 fn job_starting(&self) {
201 debug_assert!(
202 self.jobs_queued.load(Ordering::Acquire) > 0,
203 "Negative jobs queued"
204 );
205
206 self.jobs_running.fetch_add(1, Ordering::SeqCst);
207 self.jobs_queued.fetch_sub(1, Ordering::SeqCst);
208 }
209
210 fn job_finished(&self) {
211 debug_assert!(
212 self.jobs_running.load(Ordering::Acquire) > 0,
213 "Negative jobs running"
214 );
215
216 self.jobs_running.fetch_sub(1, Ordering::SeqCst);
217
218 if self.jobs_queued.load(Ordering::Acquire) == 0
219 && self.jobs_running.load(Ordering::Acquire) == 0
220 {
221 let mut is_finished = self
222 .is_finished
223 .lock()
224 .expect("Shared state should never panic");
225
226 *is_finished = true;
227 }
228 }
229
230 fn job_queued(&self) {
231 self.jobs_queued.fetch_add(1, Ordering::SeqCst);
232
233 let mut is_finished = self
234 .is_finished
235 .lock()
236 .expect("Shared state should never panic");
237
238 *is_finished = false;
239 }
240
    fn job_paniced(&self) {
        self.has_paniced.store(true, Ordering::SeqCst);
        self.jobs_paniced.fetch_add(1, Ordering::SeqCst);
    }
249}
250
/// A handle to a pool of worker threads, together with the state of the jobs sent through it
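///
/// Cloning a `ThreadPool` yields another handle to the same worker threads and
/// the same job state; use [`ThreadPool::clone_with_new_state`] if you want to
/// share the threads but track jobs independently. A small sketch:
///
/// ```rust
/// # use std::error::Error;
/// # fn main() -> Result<(), Box<dyn Error>> {
/// use easy_threadpool::ThreadPoolBuilder;
///
/// let pool = ThreadPoolBuilder::with_max_threads()?.build()?;
/// let handle = pool.clone();
///
/// handle.send_job(|| println!("Sent through the clone"));
///
/// // Both handles observe the same state.
/// assert!(pool.wait_until_finished().is_ok());
/// # Ok(())
/// # }
/// ```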
252#[derive(Debug)]
253pub struct ThreadPool {
254 thread_amount: NonZeroUsize,
255 job_sender: Arc<Sender<ThreadPoolFunctionBoxed>>,
256 shared_state: Arc<SharedState>,
257 cvar: Arc<Condvar>,
258}
259
260impl Clone for ThreadPool {
261 fn clone(&self) -> Self {
262 Self {
263 thread_amount: self.thread_amount,
264 job_sender: self.job_sender.clone(),
265 shared_state: self.shared_state.clone(),
266 cvar: self.cvar.clone(),
267 }
268 }
269}
270
271impl Display for ThreadPool {
272 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
273 write!(
274 f,
275 "Threadpool< thread_amount: {}, shared_state: {}>",
276 self.thread_amount, self.shared_state
277 )
278 }
279}
280
impl ThreadPool {
    fn new(builder: ThreadPoolBuilder) -> io::Result<Self> {
        let thread_amount = builder.thread_amount;

        let (job_sender, job_receiver) = channel::<ThreadPoolFunctionBoxed>();
        let job_sender = Arc::new(job_sender);
        let shareable_job_receiver = Arc::new(Mutex::new(job_receiver));

        let shared_state = Arc::new(SharedState::new());
        let cvar = Arc::new(Condvar::new());

        for thread_num in 0..thread_amount.get() {
            let job_receiver = shareable_job_receiver.clone();

            let thread_name = format!("Threadpool worker {thread_num}");

            thread::Builder::new().name(thread_name).spawn(move || {
                loop {
                    let job = {
                        let lock = job_receiver
                            .lock()
                            .expect("Failed to lock the job receiver");

                        lock.recv()
                    };

                    // NOTE: Breaking on error ensures that all threads will stop
                    // when the threadpool is dropped and all jobs have been executed
                    match job {
                        Ok(job) => job(),
                        Err(_) => break,
                    };
                }
            })?;
        }

        Ok(Self {
            thread_amount,
            job_sender,
            shared_state,
            cvar,
        })
    }
324
    /// The `send_job` function takes a function or closure without any arguments
    /// and sends it to the threadpool to be executed. Jobs are taken from the
    /// job queue in the order they were sent, but that in no way guarantees they
    /// will finish in that order.
    ///
    /// Jobs must implement `Send` in order to be safely sent across threads and
    /// `UnwindSafe` to allow catching panics when executing them. Both of these
    /// traits are usually automatically implemented.
333 ///
334 /// # Examples
335 ///
336 /// Sending a function or closure to the threadpool
337 /// ```rust
338 /// # use std::error::Error;
339 /// # fn main() -> Result<(), Box<dyn Error>> {
340 /// use easy_threadpool::ThreadPoolBuilder;
341 ///
    /// fn job() {
    ///     println!("Hello world from a function!");
    /// }
345 ///
346 /// let builder = ThreadPoolBuilder::with_max_threads()?;
347 /// let pool = builder.build()?;
348 ///
349 /// pool.send_job(job);
350 ///
351 /// pool.send_job(|| println!("Hello world from a closure!"));
352 /// # Ok(())
353 /// # }
354 /// ```
355 pub fn send_job(&self, job: impl FnOnce() + Send + UnwindSafe + 'static) {
        // NOTE: It is essential that the shared state is updated FIRST, otherwise
        // we have a race condition where the job is transmitted and picked up before
        // the shared state is updated, leading to a negative amount of jobs queued
359 self.shared_state.job_queued();
360
361 debug_assert!(self.jobs_queued() > 0, "Job didn't queue properly");
362 debug_assert!(!self.is_finished(), "Finish wasn't properly set to false");
363
        // Pass our own state to the job. This makes it so that multiple threadpools
        // with different states can send jobs to the same threads without picking up
        // each other's panics, for example
367 let state = self.shared_state.clone();
368 let cvar = self.cvar.clone();
369 let job_with_state = Self::job_function(Box::new(job), state, cvar);
370
371 self.job_sender
372 .send(Box::new(job_with_state))
373 .expect("The sender cannot be deallocated while the threadpool is in use")
374 }
375
376 fn job_function(
377 job: ThreadPoolFunctionBoxed,
378 state: Arc<SharedState>,
379 cvar: Arc<Condvar>,
380 ) -> impl FnOnce() + Send + 'static {
381 move || {
382 state.job_starting();
383
384 // NOTE: The use of catch_unwind means that the thread will not
385 // panic from any of the jobs it was sent. This is useful because
386 // we won't ever have to restart a thread.
            let result = catch_unwind(job);
390
391 // NOTE: Do the panic check first otherwise we have a race condition
392 // where the final job panics and the wait_until_finished function
393 // doesn't detect it
394 if result.is_err() {
395 state.job_paniced();
396 }
397
398 state.job_finished();
399
400 cvar.notify_all();
401 }
402 }
403
    /// This function will wait until all jobs sent by this threadpool instance
    /// have finished executing. Additionally, it will return early if any job panics.
406 ///
407 /// Be careful though, returning early DOES NOT mean that the sent jobs are
408 /// cancelled. They will remain running. Cancelling jobs that are queued is not
409 /// a feature provided by this crate as of now.
410 ///
411 /// # Errors
412 ///
    /// This function will return an error if any job sent to the threadpool has
    /// panicked. This includes any panics since either the threadpool was created
    /// or since the state was reset.
416 ///
417 /// # Examples
418 ///
419 /// ```rust
420 /// # use std::error::Error;
421 /// # fn main() -> Result<(), Box<dyn Error>> {
422 /// use easy_threadpool::ThreadPoolBuilder;
423 ///
424 /// let builder = ThreadPoolBuilder::with_max_threads()?;
425 /// let pool = builder.build()?;
426 ///
    /// for _ in 0..10 {
    ///     pool.send_job(|| println!("Hello world"));
    /// }
430 ///
431 /// assert!(pool.wait_until_finished().is_ok());
432 /// assert!(pool.is_finished());
433 ///
434 /// pool.send_job(|| panic!("Test panic"));
435 ///
436 /// assert!(pool.wait_until_finished().is_err());
437 /// assert!(pool.has_paniced());
438 /// # Ok(())
439 /// # }
440 /// ```
441 pub fn wait_until_finished(&self) -> Result<(), JobHasPanicedError> {
442 let mut is_finished = self
443 .shared_state
444 .is_finished
445 .lock()
446 .expect("Shared state should never panic");
447
448 while !*is_finished && !self.has_paniced() {
449 is_finished = self
450 .cvar
451 .wait(is_finished)
452 .expect("Shared state should never panic");
453 }
454
        debug_assert!(
            self.has_paniced() || self.jobs_running() == 0,
            "wait_until_finished stopped while {} jobs running and {} panics",
            self.jobs_running(),
            self.jobs_paniced()
        );
        debug_assert!(
            self.has_paniced() || self.jobs_queued() == 0,
            "wait_until_finished stopped while {} jobs queued and {} panics",
            self.jobs_queued(),
            self.jobs_paniced()
        );
471
472 match self.shared_state.has_paniced.load(Ordering::Acquire) {
473 true => Err(JobHasPanicedError {}),
474 false => Ok(()),
475 }
476 }
477
    /// This function will wait until at least one job has finished after this
    /// call was made. It also returns immediately if the threadpool is already
    /// finished, and it will return early if any job panics.
481 ///
482 /// Be careful though, returning early DOES NOT mean that the sent jobs are
483 /// cancelled. They will remain running. Cancelling jobs that are queued is not
484 /// a feature provided by this crate as of now.
485 ///
486 /// # Errors
487 ///
    /// This function will return an error if any job sent to the threadpool has
    /// panicked. This includes any panics since either the threadpool was created
    /// or since the state was reset.
491 ///
492 /// # Examples
493 ///
494 /// ```rust
495 /// # use std::error::Error;
496 /// # fn main() -> Result<(), Box<dyn Error>> {
497 /// use easy_threadpool::ThreadPoolBuilder;
498 ///
499 /// let builder = ThreadPoolBuilder::with_max_threads()?;
500 /// let pool = builder.build()?;
501 ///
502 /// assert!(pool.wait_until_job_done().is_ok());
503 /// assert!(pool.is_finished());
504 ///
505 /// pool.send_job(|| panic!("Test panic"));
506 ///
507 /// assert!(pool.wait_until_job_done().is_err());
508 /// assert!(pool.has_paniced());
509 /// # Ok(())
510 /// # }
511 /// ```
512 pub fn wait_until_job_done(&self) -> Result<(), JobHasPanicedError> {
513 fn paniced(state: &SharedState) -> bool {
514 state.jobs_paniced.load(Ordering::Acquire) != 0
515 }
516
517 let is_finished = self
518 .shared_state
519 .is_finished
520 .lock()
521 .expect("Shared state should never panic");
522
523 if *is_finished {
524 return Ok(());
525 };
526
        // Wait until at least one job signals completion; the returned guard is
        // not needed, so it is dropped immediately
        drop(self.cvar.wait(is_finished));

        if paniced(&self.shared_state) {
531 Err(JobHasPanicedError {})
532 } else {
533 Ok(())
534 }
535 }
536
    /// This function will wait until all jobs sent by this threadpool instance
    /// have finished executing. It will keep waiting even if a job panics in the
    /// thread pool.
    ///
    /// This is unlikely to be a performance improvement, but it is very useful
    /// if you know that, for whatever reason, your jobs might panic and that is
    /// acceptable.
543 ///
544 /// # Examples
545 ///
546 /// ```rust
547 /// # use std::error::Error;
548 /// # fn main() -> Result<(), Box<dyn Error>> {
549 /// use easy_threadpool::ThreadPoolBuilder;
550 ///
551 /// let builder = ThreadPoolBuilder::with_max_threads()?;
552 /// let pool = builder.build()?;
553 ///
    /// for _ in 0..10 {
    ///     pool.send_job(|| println!("Hello world"));
    /// }
557 ///
558 /// pool.wait_until_finished_unchecked();
559 /// assert!(pool.is_finished());
560 ///
561 /// pool.send_job(|| panic!("Test panic"));
562 ///
563 /// pool.wait_until_finished_unchecked();
564 /// assert!(pool.has_paniced());
565 /// # Ok(())
566 /// # }
567 /// ```
568 pub fn wait_until_finished_unchecked(&self) {
569 let mut is_finished = self
570 .shared_state
571 .is_finished
572 .lock()
            .expect("Shared state should never panic");
574
575 if *is_finished {
576 return;
577 }
578
579 while !*is_finished {
580 is_finished = self
581 .cvar
582 .wait(is_finished)
583 .expect("Shared state should never panic")
584 }
585
586 debug_assert!(
587 self.shared_state.jobs_running.load(Ordering::Acquire) == 0,
588 "Job still running after wait_until_finished_unchecked"
589 );
590 debug_assert!(
591 self.shared_state.jobs_queued.load(Ordering::Acquire) == 0,
592 "Job still queued after wait_until_finished_unchecked"
593 );
594 }
595
    /// This function will wait until at least one job has finished after this
    /// call was made, or return immediately if the threadpool is already finished.
    /// Unlike [`ThreadPool::wait_until_job_done`], it does not report an error if
    /// a job has panicked.
598 ///
599 /// Be careful though, returning early DOES NOT mean that the sent jobs are
600 /// cancelled. They will remain running. Cancelling jobs that are queued is not
601 /// a feature provided by this crate as of now.
602 ///
603 /// # Examples
604 ///
605 /// ```rust
606 /// # use std::error::Error;
607 /// # fn main() -> Result<(), Box<dyn Error>> {
608 /// use easy_threadpool::ThreadPoolBuilder;
609 ///
610 /// let builder = ThreadPoolBuilder::with_max_threads()?;
611 /// let pool = builder.build()?;
612 ///
    /// pool.wait_until_job_done_unchecked();
614 /// assert!(pool.is_finished());
615 ///
616 /// pool.send_job(|| panic!("Test panic"));
617 ///
    /// pool.wait_until_job_done_unchecked();
619 /// assert!(pool.has_paniced());
620 /// # Ok(())
621 /// # }
622 /// ```
623 pub fn wait_until_job_done_unchecked(&self) {
624 let is_finished = self
625 .shared_state
626 .is_finished
627 .lock()
628 .expect("Shared state should never panic");
629
630 // This is guaranteed to work because jobs cannot finish without having
631 // the shared state lock, and we keep the lock until we start waiting for
632 // the condvar
633 if *is_finished {
634 return;
635 };
636
637 drop(self.cvar.wait(is_finished));
638 }
639
640 /// This function will reset the state of this instance of the threadpool.
641 ///
642 /// When resetting the state you lose all information about previously sent jobs.
    /// If a job you previously sent panics, you will not be notified, nor can you
    /// wait until your previously sent jobs are done running. However, those jobs
    /// will keep running. Be very careful not to treat this as a "stop" button.
646 ///
647 /// # Examples
648 ///
649 /// ```rust
650 /// # use std::error::Error;
651 /// # fn main() -> Result<(), Box<dyn Error>> {
652 /// use easy_threadpool::ThreadPoolBuilder;
653 ///
654 /// let builder = ThreadPoolBuilder::with_max_threads()?;
655 /// let mut pool = builder.build()?;
656 ///
657 /// pool.send_job(|| panic!("Test panic"));
658 ///
659 /// assert!(pool.wait_until_finished().is_err());
660 /// assert!(pool.has_paniced());
661 ///
662 /// pool.reset_state();
663 ///
664 /// assert!(pool.wait_until_finished().is_ok());
665 /// assert!(!pool.has_paniced());
666 /// # Ok(())
667 /// # }
668 /// ```
669 pub fn reset_state(&mut self) {
670 let cvar = Arc::new(Condvar::new());
671 let shared_state = Arc::new(SharedState::new());
672
673 self.cvar = cvar;
674 self.shared_state = shared_state;
675 }
676
677 /// This function will clone the threadpool and then reset its state. This
678 /// makes it so you can have 2 different states operate on the same threads,
679 /// effectively sharing the threads.
680 ///
    /// Note however that there is no mechanism to give different instances equal
    /// CPU time; jobs are executed on a first come, first served basis.
684 ///
685 /// # Examples
686 ///
687 /// ```rust
688 /// # use std::error::Error;
689 /// # fn main() -> Result<(), Box<dyn Error>> {
690 /// use easy_threadpool::ThreadPoolBuilder;
691 ///
692 /// let builder = ThreadPoolBuilder::with_max_threads()?;
693 /// let pool = builder.build()?;
694 ///
695 /// let pool_clone = pool.clone_with_new_state();
696 ///
697 /// pool.send_job(|| panic!("Test panic"));
698 ///
699 /// assert!(pool.wait_until_finished().is_err());
700 /// assert!(pool.has_paniced());
701 ///
702 /// assert!(pool_clone.wait_until_finished().is_ok());
703 /// assert!(!pool_clone.has_paniced());
704 /// # Ok(())
705 /// # }
706 /// ```
707 pub fn clone_with_new_state(&self) -> Self {
708 let mut new_pool = self.clone();
709 new_pool.reset_state();
710 new_pool
711 }
712
    /// Returns the number of jobs currently being run by this instance of the
    /// thread pool. If multiple instances of this threadpool exist (see
    /// [`clone_with_new_state`](Self::clone_with_new_state)), this number might be
    /// lower than the maximum number of threads, even if there are still jobs queued.
717 ///
718 /// # Examples
719 ///
720 /// ```rust
721 /// # use std::error::Error;
722 /// # fn main() -> Result<(), Box<dyn Error>> {
723 /// use easy_threadpool::ThreadPoolBuilder;
724 /// use std::{
725 /// num::NonZeroUsize,
726 /// sync::{Arc, Barrier},
727 /// };
728 /// let threads = 16;
729 /// let tasks = threads * 10;
730 ///
731 /// let num = NonZeroUsize::try_from(threads)?;
732 /// let pool = ThreadPoolBuilder::with_thread_amount(num).build()?;
733 ///
734 /// let b0 = Arc::new(Barrier::new(threads + 1));
735 /// let b1 = Arc::new(Barrier::new(threads + 1));
736 ///
    /// for i in 0..tasks {
    ///     let b0_copy = b0.clone();
    ///     let b1_copy = b1.clone();
    ///
    ///     pool.send_job(move || {
    ///         if i < threads {
    ///             b0_copy.wait();
    ///             b1_copy.wait();
    ///         }
    ///     });
    /// }
748 ///
749 /// b0.wait();
750 /// assert_eq!(pool.jobs_running(), threads);
751 /// # b1.wait();
752 /// # Ok(())
753 /// # }
754 /// ```
755 pub fn jobs_running(&self) -> usize {
756 self.shared_state.jobs_running.load(Ordering::Acquire)
757 }
758
    /// Returns the number of jobs currently queued by this threadpool instance.
    /// There might be more jobs queued that this instance doesn't know about if
    /// there are other instances of this threadpool (see
    /// [`clone_with_new_state`](Self::clone_with_new_state)).
762 ///
763 /// # Examples
764 ///
765 /// ```rust
766 /// # use std::error::Error;
767 /// # fn main() -> Result<(), Box<dyn Error>> {
768 /// use easy_threadpool::ThreadPoolBuilder;
769 /// use std::{
770 /// num::NonZeroUsize,
771 /// sync::{Arc, Barrier},
772 /// };
773 /// let threads = 16;
774 /// let tasks = 100;
775 ///
776 /// let num = NonZeroUsize::try_from(threads)?;
777 /// let pool = ThreadPoolBuilder::with_thread_amount(num).build()?;
778 ///
779 /// let b0 = Arc::new(Barrier::new(threads + 1));
780 /// let b1 = Arc::new(Barrier::new(threads + 1));
781 ///
    /// for i in 0..tasks {
    ///     let b0_copy = b0.clone();
    ///     let b1_copy = b1.clone();
    ///
    ///     pool.send_job(move || {
    ///         if i < threads {
    ///             b0_copy.wait();
    ///             b1_copy.wait();
    ///         }
    ///     });
    /// }
793 ///
794 /// b0.wait();
795 /// assert_eq!(pool.jobs_queued(), tasks - threads);
796 /// # b1.wait();
797 /// # Ok(())
798 /// # }
799 /// ```
800 pub fn jobs_queued(&self) -> usize {
801 self.shared_state.jobs_queued.load(Ordering::Acquire)
802 }
803
    /// Returns the number of jobs that were sent by this instance of the threadpool
    /// and that panicked.
806 ///
807 /// # Examples
808 ///
809 /// ```rust
810 /// # use std::error::Error;
811 /// # fn main() -> Result<(), Box<dyn Error>> {
812 /// use easy_threadpool::ThreadPoolBuilder;
813 ///
814 /// let pool = ThreadPoolBuilder::with_max_threads()?.build()?;
815 ///
    /// for _ in 0..10 {
    ///     pool.send_job(|| panic!("Test panic"));
    /// }
819 ///
820 /// pool.wait_until_finished_unchecked();
821 ///
822 /// assert_eq!(pool.jobs_paniced(), 10);
823 /// # Ok(())
824 /// # }
825 /// ```
826 pub fn jobs_paniced(&self) -> usize {
827 self.shared_state.jobs_paniced.load(Ordering::Acquire)
828 }
829
    /// Returns whether any job sent through this threadpool instance has panicked
831 ///
832 /// # Examples
833 ///
834 /// ```rust
835 /// # use std::error::Error;
836 /// # fn main() -> Result<(), Box<dyn Error>> {
837 /// use easy_threadpool::ThreadPoolBuilder;
838 ///
839 /// let pool = ThreadPoolBuilder::with_max_threads()?.build()?;
840 ///
841 /// pool.send_job(|| panic!("Test panic"));
842 ///
843 /// pool.wait_until_finished_unchecked();
844 ///
845 /// assert!(pool.has_paniced());
846 /// # Ok(())
847 /// # }
848 /// ```
849 pub fn has_paniced(&self) -> bool {
850 self.shared_state.has_paniced.load(Ordering::Acquire)
851 }
852
    /// Returns whether this threadpool instance has no jobs running and no jobs
    /// queued; in other words, whether it is finished.
855 ///
856 /// # Examples
857 ///
858 /// ```rust
859 /// # use std::error::Error;
860 /// # fn main() -> Result<(), Box<dyn Error>> {
861 /// use easy_threadpool::ThreadPoolBuilder;
862 /// use std::{
863 /// num::NonZeroUsize,
864 /// sync::{Arc, Barrier},
865 /// };
866 /// let pool = ThreadPoolBuilder::with_max_threads()?.build()?;
867 ///
868 /// let b = Arc::new(Barrier::new(2));
869 ///
870 /// assert!(pool.is_finished());
871 ///
872 /// let b_clone = b.clone();
873 /// pool.send_job(move || { b_clone.wait(); });
874 ///
875 /// assert!(!pool.is_finished());
876 /// # b.wait();
877 /// # Ok(())
878 /// # }
879 /// ```
880 pub fn is_finished(&self) -> bool {
881 *self
882 .shared_state
883 .is_finished
884 .lock()
885 .expect("Shared state should never panic")
886 }
887
    /// This function returns the number of threads used to create the threadpool
889 ///
890 /// # Examples
891 ///
892 /// ```rust
893 /// # use std::error::Error;
894 /// # fn main() -> Result<(), Box<dyn Error>> {
895 /// use easy_threadpool::ThreadPoolBuilder;
896 /// use std::num::NonZeroUsize;
897 ///
898 /// let threads = 10;
899 ///
900 /// let num = NonZeroUsize::try_from(threads)?;
901 /// let pool = ThreadPoolBuilder::with_thread_amount(num).build()?;
902 ///
903 /// assert_eq!(pool.threads().get(), threads);
904 /// # Ok(())
905 /// # }
906 /// ```
907 pub const fn threads(&self) -> NonZeroUsize {
908 self.thread_amount
909 }
910}
911
/// A `ThreadPoolBuilder` is a builder to easily create a [`ThreadPool`]
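///
/// # Examples
///
/// A minimal sketch of the builder flow:
///
/// ```rust
/// # use std::error::Error;
/// # fn main() -> Result<(), Box<dyn Error>> {
/// use easy_threadpool::ThreadPoolBuilder;
/// use std::num::NonZeroUsize;
///
/// // Either pick an explicit thread count...
/// let num = NonZeroUsize::try_from(4)?;
/// let pool = ThreadPoolBuilder::with_thread_amount(num).build()?;
/// assert_eq!(pool.threads().get(), 4);
///
/// // ...or use the available parallelism.
/// let pool = ThreadPoolBuilder::with_max_threads()?.build()?;
/// pool.send_job(|| println!("Hello from the pool"));
/// assert!(pool.wait_until_finished().is_ok());
/// # Ok(())
/// # }
/// ```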
913pub struct ThreadPoolBuilder {
914 thread_amount: NonZeroUsize,
915 // thread_name: Option<String>,
916}
917
918impl Default for ThreadPoolBuilder {
919 fn default() -> Self {
920 Self {
921 thread_amount: NonZeroUsize::try_from(1).unwrap(),
922 }
923 }
924}
925
926impl ThreadPoolBuilder {
    /// Create a builder that will build a threadpool with `thread_amount` threads
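    ///
    /// # Examples
    ///
    /// A short sketch of building a pool with a fixed thread count:
    ///
    /// ```rust
    /// # use std::error::Error;
    /// # fn main() -> Result<(), Box<dyn Error>> {
    /// use easy_threadpool::ThreadPoolBuilder;
    /// use std::num::NonZeroUsize;
    ///
    /// let num = NonZeroUsize::try_from(8)?;
    /// let pool = ThreadPoolBuilder::with_thread_amount(num).build()?;
    /// assert_eq!(pool.threads().get(), 8);
    /// # Ok(())
    /// # }
    /// ```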
928 pub fn with_thread_amount(thread_amount: NonZeroUsize) -> ThreadPoolBuilder {
929 ThreadPoolBuilder { thread_amount }
930 }
931
    /// Create a builder that will build a threadpool with `thread_amount` threads,
    /// where `thread_amount` is given as a `usize`.
    ///
    /// # Errors
    ///
    /// If `thread_amount` cannot be converted to a [`std::num::NonZeroUsize`]
    /// (i.e. it is 0).
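    ///
    /// # Examples
    ///
    /// A short sketch, including the zero-thread error case:
    ///
    /// ```rust
    /// # use std::error::Error;
    /// # fn main() -> Result<(), Box<dyn Error>> {
    /// use easy_threadpool::ThreadPoolBuilder;
    ///
    /// let pool = ThreadPoolBuilder::with_thread_amount_usize(4)?.build()?;
    /// assert_eq!(pool.threads().get(), 4);
    ///
    /// // Zero threads is rejected when creating the builder.
    /// assert!(ThreadPoolBuilder::with_thread_amount_usize(0).is_err());
    /// # Ok(())
    /// # }
    /// ```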
937 pub fn with_thread_amount_usize(
938 thread_amount: usize,
939 ) -> Result<ThreadPoolBuilder, TryFromIntError> {
940 let thread_amount = NonZeroUsize::try_from(thread_amount)?;
941 Ok(Self::with_thread_amount(thread_amount))
942 }
943
    /// Create a builder that will build a threadpool with as many threads as the
    /// available parallelism reported by [`std::thread::available_parallelism`]
946 ///
947 /// # Errors
948 ///
949 /// Taken from the available_parallelism() documentation:
950 /// This function will, but is not limited to, return errors in the following
951 /// cases:
952 ///
953 /// * If the amount of parallelism is not known for the target platform.
    /// * If the program lacks permission to query the amount of parallelism made
    ///   available to it.
956 ///
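    /// # Examples
    ///
    /// A short sketch, assuming the available parallelism can be queried:
    ///
    /// ```rust
    /// # use std::error::Error;
    /// # fn main() -> Result<(), Box<dyn Error>> {
    /// use easy_threadpool::ThreadPoolBuilder;
    /// use std::thread::available_parallelism;
    ///
    /// let pool = ThreadPoolBuilder::with_max_threads()?.build()?;
    /// assert_eq!(pool.threads(), available_parallelism()?);
    /// # Ok(())
    /// # }
    /// ```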
957 pub fn with_max_threads() -> io::Result<ThreadPoolBuilder> {
958 let max_threads = available_parallelism()?;
959 Ok(ThreadPoolBuilder {
960 thread_amount: max_threads,
961 })
962 }
963
964 // pub fn with_thread_name(thread_name: String) -> ThreadPoolBuilder {
965 // ThreadPoolBuilder {
966 // thread_name: Some(thread_name),
967 // ..Default::default()
968 // }
969 // }
970
    /// Set the thread amount in the builder
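    ///
    /// # Examples
    ///
    /// A short sketch using the builder's default and overriding the thread count:
    ///
    /// ```rust
    /// # use std::error::Error;
    /// # fn main() -> Result<(), Box<dyn Error>> {
    /// use easy_threadpool::ThreadPoolBuilder;
    /// use std::num::NonZeroUsize;
    ///
    /// let num = NonZeroUsize::try_from(2)?;
    /// let pool = ThreadPoolBuilder::default().set_thread_amount(num).build()?;
    /// assert_eq!(pool.threads().get(), 2);
    /// # Ok(())
    /// # }
    /// ```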
972 pub fn set_thread_amount(mut self, thread_amount: NonZeroUsize) -> ThreadPoolBuilder {
973 self.thread_amount = thread_amount;
974 self
975 }
976
    /// Set the thread amount in the builder from a `usize`
    ///
    /// # Errors
    ///
    /// If `thread_amount` cannot be converted to a [`NonZeroUsize`] (i.e. it is 0)
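    ///
    /// # Examples
    ///
    /// A short sketch with a plain `usize` thread count:
    ///
    /// ```rust
    /// # use std::error::Error;
    /// # fn main() -> Result<(), Box<dyn Error>> {
    /// use easy_threadpool::ThreadPoolBuilder;
    ///
    /// let pool = ThreadPoolBuilder::default().set_thread_amount_usize(2)?.build()?;
    /// assert_eq!(pool.threads().get(), 2);
    /// # Ok(())
    /// # }
    /// ```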
982 pub fn set_thread_amount_usize(
983 self,
984 thread_amount: usize,
985 ) -> Result<ThreadPoolBuilder, TryFromIntError> {
986 let thread_amount = NonZeroUsize::try_from(thread_amount)?;
987 Ok(self.set_thread_amount(thread_amount))
988 }
989
    /// Set the number of threads the builder will build to the available parallelism
    /// as provided by [`std::thread::available_parallelism`]
992 ///
993 /// # Errors
994 ///
995 /// Taken from the available_parallelism() documentation:
996 /// This function will, but is not limited to, return errors in the following
997 /// cases:
998 ///
999 /// * If the amount of parallelism is not known for the target platform.
    /// * If the program lacks permission to query the amount of parallelism made
    ///   available to it.
1002 ///
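    /// # Examples
    ///
    /// A short sketch, assuming the available parallelism can be queried:
    ///
    /// ```rust
    /// # use std::error::Error;
    /// # fn main() -> Result<(), Box<dyn Error>> {
    /// use easy_threadpool::ThreadPoolBuilder;
    /// use std::thread::available_parallelism;
    ///
    /// let pool = ThreadPoolBuilder::default().set_max_threads()?.build()?;
    /// assert_eq!(pool.threads(), available_parallelism()?);
    /// # Ok(())
    /// # }
    /// ```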
1003 pub fn set_max_threads(mut self) -> io::Result<ThreadPoolBuilder> {
1004 let max_threads = available_parallelism()?;
1005 self.thread_amount = max_threads;
1006 Ok(self)
1007 }
1008
1009 // pub fn set_thread_name(mut self, thread_name: String) -> ThreadPoolBuilder {
1010 // self.thread_name = Some(thread_name);
1011 // self
1012 // }
1013
1014 /// Build the builder into a threadpool, taking all the initialized values
1015 /// from the builder and using defaults for those not initialized.
1016 ///
1017 /// # Errors
1018 ///
1019 /// Taken from [`std::thread::Builder::spawn`]:
1020 ///
1021 /// Unlike the [`spawn`](https://doc.rust-lang.org/stable/std/thread/fn.spawn.html) free function, this method yields an
1022 /// [`io::Result`] to capture any failure to create the thread at
1023 /// the OS level.
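    ///
    /// # Examples
    ///
    /// A short sketch of building with the defaults:
    ///
    /// ```rust
    /// # use std::error::Error;
    /// # fn main() -> Result<(), Box<dyn Error>> {
    /// use easy_threadpool::ThreadPoolBuilder;
    ///
    /// let pool = ThreadPoolBuilder::default().build()?;
    /// pool.send_job(|| println!("Hello world!"));
    /// assert!(pool.wait_until_finished().is_ok());
    /// # Ok(())
    /// # }
    /// ```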
1024 pub fn build(self) -> io::Result<ThreadPool> {
1025 ThreadPool::new(self)
1026 }
1027}
1028
1029#[cfg(test)]
1030mod test {
1031 use core::panic;
1032 use std::{
1033 num::NonZeroUsize,
1034 sync::{mpsc::channel, Arc, Barrier},
1035 thread::sleep,
1036 time::Duration,
1037 };
1038
1039 use crate::ThreadPoolBuilder;
1040
1041 #[test]
1042 // Test multiple panics on a single thread, this ensures that a thread can
1043 // handle panics
1044 fn deal_with_panics() {
1045 fn panic_fn() {
1046 panic!("Test panic");
1047 }
1048
1049 let thread_num: NonZeroUsize = 1.try_into().unwrap();
1050 let builder = ThreadPoolBuilder::with_thread_amount(thread_num);
1051
1052 let pool = builder.build().unwrap();
1053
1054 for _ in 0..10 {
1055 pool.send_job(panic_fn);
1056 }
1057
1058 assert!(
1059 pool.wait_until_finished().is_err(),
1060 "Pool didn't detect panic in wait_until_finished"
1061 );
1062
1063 assert!(
1064 pool.has_paniced(),
1065 "Pool didn't detect panic in has_paniced"
1066 );
1067 pool.wait_until_finished_unchecked();
1068
1069 assert!(
1070 pool.jobs_queued() == 0,
1071 "Incorrect amount of jobs queued after wait"
1072 );
1073 assert!(
1074 pool.jobs_running() == 0,
1075 "Incorrect amount of jobs running after wait"
1076 );
1077 assert!(
1078 pool.jobs_paniced() == 10,
1079 "Incorrect amount of jobs paniced after wait"
1080 );
1081 }
1082
1083 #[test]
1084 fn receive_value() {
1085 let (tx, rx) = channel::<u32>();
1086
1087 let func = move || {
1088 tx.send(69).unwrap();
1089 };
1090
1091 let pool = ThreadPoolBuilder::default().build().unwrap();
1092
1093 pool.send_job(func);
1094
1095 assert_eq!(rx.recv(), Ok(69), "Incorrect value received");
1096 }
1097
1098 #[test]
1099 fn test_wait() {
1100 const TASKS: usize = 1000;
1101 const THREADS: usize = 16;
1102
1103 let b0 = Arc::new(Barrier::new(THREADS + 1));
1104 let b1 = Arc::new(Barrier::new(THREADS + 1));
1105
1106 let pool = ThreadPoolBuilder::with_thread_amount_usize(THREADS)
1107 .unwrap()
1108 .build()
1109 .unwrap();
1110
1111 for i in 0..TASKS {
1112 let b0 = b0.clone();
1113 let b1 = b1.clone();
1114
1115 pool.send_job(move || {
1116 if i < THREADS {
1117 b0.wait();
1118 b1.wait();
1119 }
1120 });
1121 }
1122
1123 b0.wait();
1124
1125 assert_eq!(
1126 pool.jobs_running(),
1127 THREADS,
1128 "Incorrect amount of jobs running"
1129 );
1130 assert_eq!(
1131 pool.jobs_paniced(),
1132 0,
1133 "Incorrect amount of threads paniced"
1134 );
1135
1136 b1.wait();
1137
1138 assert!(
1139 pool.wait_until_finished().is_ok(),
1140 "wait_until_finished incorrectly detected a panic"
1141 );
1142
1143 assert_eq!(
1144 pool.jobs_queued(),
1145 0,
1146 "Incorrect amount of jobs queued after wait"
1147 );
1148 assert_eq!(
1149 pool.jobs_running(),
1150 0,
1151 "Incorrect amount of jobs running after wait"
1152 );
1153 assert_eq!(
1154 pool.jobs_paniced(),
1155 0,
1156 "Incorrect amount of threads paniced after wait"
1157 );
1158 }
1159
1160 #[test]
1161 fn test_wait_unchecked() {
1162 const TASKS: usize = 1000;
1163 const THREADS: usize = 16;
1164
1165 let b0 = Arc::new(Barrier::new(THREADS + 1));
1166 let b1 = Arc::new(Barrier::new(THREADS + 1));
1167
1168 let builder = ThreadPoolBuilder::with_thread_amount_usize(THREADS).unwrap();
1169 let pool = builder.build().unwrap();
1170
1171 for i in 0..TASKS {
1172 let b0 = b0.clone();
1173 let b1 = b1.clone();
1174
1175 pool.send_job(move || {
1176 if i < THREADS {
1177 b0.wait();
1178 b1.wait();
1179 }
1180 panic!("Test panic");
1181 });
1182 }
1183
1184 b0.wait();
1185
1186 assert_eq!(
1187 pool.jobs_running(),
1188 THREADS,
1189 "Incorrect amount of jobs running"
1190 );
1191 assert_eq!(pool.jobs_paniced(), 0);
1192
1193 b1.wait();
1194
1195 pool.wait_until_finished_unchecked();
1196
1197 assert_eq!(pool.jobs_queued(), 0);
1198 assert_eq!(pool.jobs_running(), 0);
1199 assert_eq!(pool.jobs_paniced(), TASKS);
1200 }
1201
1202 #[test]
1203 fn test_clones() {
1204 const TASKS: usize = 1000;
1205 const THREADS: usize = 16;
1206
1207 let pool = ThreadPoolBuilder::with_thread_amount_usize(THREADS)
1208 .unwrap()
1209 .build()
1210 .unwrap();
1211 let clone = pool.clone();
1212 let clone_with_new_state = pool.clone_with_new_state();
1213
1214 let b0 = Arc::new(Barrier::new(THREADS + 1));
1215 let b1 = Arc::new(Barrier::new(THREADS + 1));
1216
1217 for i in 0..TASKS {
1218 let b0_copy = b0.clone();
1219 let b1_copy = b1.clone();
1220
1221 pool.send_job(move || {
1222 if i < THREADS / 2 {
1223 b0_copy.wait();
1224 b1_copy.wait();
1225 }
1226 });
1227
1228 let b0_copy = b0.clone();
1229 let b1_copy = b1.clone();
1230
1231 clone_with_new_state.send_job(move || {
1232 if i < THREADS / 2 {
1233 b0_copy.wait();
1234 b1_copy.wait();
1235 }
1236 panic!("Test panic")
1237 });
1238 }
1239
1240 b0.wait();
1241
1242 // The /2 is guaranteed because jobs are received in order
1243 assert_eq!(
1244 pool.jobs_running(),
1245 THREADS / 2,
1246 "Incorrect amount of jobs running in pool"
1247 );
1248 assert_eq!(
1249 pool.jobs_paniced(),
1250 0,
1251 "Incorrect amount of jobs paniced in pool"
1252 );
1253
1254 // The /2 is guaranteed because jobs are received in order
1255 assert_eq!(
1256 clone_with_new_state.jobs_running(),
1257 THREADS / 2,
1258 "Incorrect amount of jobs running in clone_with_new_state"
1259 );
1260 assert_eq!(
1261 clone_with_new_state.jobs_paniced(),
1262 0,
1263 "Incorrect amount of jobs paniced in clone_with_new_state"
1264 );
1265
1266 b1.wait();
1267 assert!(
1268 clone_with_new_state.wait_until_finished().is_err(),
1269 "Clone with new state didn't detect panic"
1270 );
1271
1272 assert!(
1273 clone.wait_until_finished().is_ok(),
1274 "Pool incorrectly detected panic"
1275 );
1276
1277 assert_eq!(
1278 pool.jobs_queued(),
1279 0,
1280 "Incorrect amount of jobs queued in pool after wait"
1281 );
1282 assert_eq!(
1283 pool.jobs_running(),
1284 0,
1285 "Incorrect amount of jobs running in pool after wait"
1286 );
1287 assert_eq!(
1288 pool.jobs_paniced(),
1289 0,
1290 "Incorrect amount of jobs paniced in pool after wait"
1291 );
1292
1293 clone_with_new_state.wait_until_finished_unchecked();
1294 assert!(
1295 clone_with_new_state.wait_until_finished().is_err(),
1296 "clone_with_new_state didn't detect panics after wait"
1297 );
1298
1299 assert_eq!(
1300 clone_with_new_state.jobs_queued(),
1301 0,
1302 "Incorrect amount of jobs queued in clone_with_new_state after wait"
1303 );
1304 assert_eq!(
1305 clone_with_new_state.jobs_running(),
1306 0,
1307 "Incorrect amount of jobs running in clone_with_new_state after wait"
1308 );
1309 assert_eq!(
1310 clone_with_new_state.jobs_paniced(),
1311 TASKS,
1312 "Incorrect panics in clone"
1313 );
1314
1315 assert_eq!(
1316 pool.jobs_queued(),
1317 0,
1318 "Incorrect amount of jobs queued in pool after everything"
1319 );
1320 assert_eq!(
1321 pool.jobs_running(),
1322 0,
1323 "Incorrect amount of jobs running in pool after everything"
1324 );
1325 assert_eq!(
1326 pool.jobs_paniced(),
1327 0,
1328 "Incorrect amount of jobs paniced in pool after everything"
1329 );
1330 }
1331
1332 #[test]
1333 fn reset_state_while_running() {
1334 const TASKS: usize = 32;
1335 const THREADS: usize = 16;
1336
1337 let mut pool = ThreadPoolBuilder::with_thread_amount_usize(THREADS)
1338 .unwrap()
1339 .build()
1340 .unwrap();
1341
1342 let b0 = Arc::new(Barrier::new(THREADS + 1));
1343 let b1 = Arc::new(Barrier::new(THREADS + 1));
1344
1345 for i in 0..TASKS {
1346 let b0_copy = b0.clone();
1347 let b1_copy = b1.clone();
1348
1349 pool.send_job(move || {
1350 if i < THREADS {
1351 b0_copy.wait();
1352 b1_copy.wait();
1353 }
1354 });
1355 }
1356
1357 b0.wait();
1358
1359 assert_ne!(pool.jobs_queued(), 0);
1360 assert_ne!(pool.jobs_running(), 0);
1361
1362 pool.reset_state();
1363
1364 assert_eq!(pool.jobs_queued(), 0);
1365 assert_eq!(pool.jobs_running(), 0);
1366 assert_eq!(pool.jobs_paniced(), 0);
1367
1368 b1.wait();
1369 pool.wait_until_finished().expect("Nothing should panic");
1370
1371 // Give time for the jobs to execute
1372 sleep(Duration::from_secs(1));
1373
1374 assert_eq!(pool.jobs_queued(), 0);
1375 assert_eq!(pool.jobs_running(), 0);
1376 assert_eq!(pool.jobs_paniced(), 0);
1377 }
1378
1379 #[test]
1380 fn reset_panic_test() {
1381 const TASKS: usize = 32;
1382 const THREADS: usize = 16;
1383
1384 let num = NonZeroUsize::try_from(THREADS).unwrap();
1385 let mut pool = ThreadPoolBuilder::with_thread_amount(num).build().unwrap();
1386
1387 let b0 = Arc::new(Barrier::new(THREADS + 1));
1388 let b1 = Arc::new(Barrier::new(THREADS + 1));
1389
1390 for i in 0..TASKS {
1391 let b0_copy = b0.clone();
1392 let b1_copy = b1.clone();
1393
1394 pool.send_job(move || {
1395 if i < THREADS {
1396 b0_copy.wait();
1397 b1_copy.wait();
1398 }
1399 panic!("Test panic");
1400 });
1401 }
1402
1403 b0.wait();
1404
1405 assert_ne!(pool.jobs_queued(), 0);
1406 assert_ne!(pool.jobs_running(), 0);
1407 assert_eq!(pool.jobs_paniced(), 0);
1408
1409 pool.reset_state();
1410
1411 assert_eq!(pool.jobs_queued(), 0);
1412 assert_eq!(pool.jobs_running(), 0);
1413 assert_eq!(pool.jobs_paniced(), 0);
1414
1415 b1.wait();
1416 pool.wait_until_finished().expect("Nothing should panic");
1417
1418 // Give time for the jobs to execute
1419 sleep(Duration::from_secs(1));
1420
1421 assert_eq!(pool.jobs_queued(), 0);
1422 assert_eq!(pool.jobs_running(), 0);
1423 assert_eq!(pool.jobs_paniced(), 0);
1424 }
1425
1426 #[test]
1427 fn test_wait_until_job_done() {
1428 const THREADS: usize = 1;
1429
1430 let builder = ThreadPoolBuilder::with_thread_amount_usize(THREADS).unwrap();
1431 let pool = builder.build().unwrap();
1432
1433 assert!(pool.wait_until_job_done().is_ok());
1434
1435 pool.send_job(|| {});
1436
1437 assert!(pool.wait_until_job_done().is_ok());
1438
1439 assert_eq!(pool.jobs_queued(), 0);
1440 assert_eq!(pool.jobs_running(), 0);
1441 assert_eq!(pool.jobs_paniced(), 0);
1442
1443 pool.send_job(|| panic!("Test panic"));
1444
1445 assert!(pool.wait_until_job_done().is_err());
1446
1447 assert_eq!(pool.jobs_queued(), 0);
1448 assert_eq!(pool.jobs_running(), 0);
1449 assert_eq!(pool.jobs_paniced(), 1);
1450 }
1451
1452 #[test]
1453 fn test_wait_until_job_done_unchecked() {
1454 const THREADS: usize = 1;
1455
1456 let builder = ThreadPoolBuilder::with_thread_amount_usize(THREADS).unwrap();
1457 let pool = builder.build().unwrap();
1458
1459 // This doesn't block forever
1460 pool.wait_until_job_done_unchecked();
1461
1462 pool.send_job(|| {});
1463
1464 pool.wait_until_job_done_unchecked();
1465
1466 assert_eq!(pool.jobs_queued(), 0);
1467 assert_eq!(pool.jobs_running(), 0);
1468 assert_eq!(pool.jobs_paniced(), 0);
1469
1470 pool.send_job(|| panic!("Test panic"));
1471
1472 pool.wait_until_job_done_unchecked();
1473
1474 assert_eq!(pool.jobs_queued(), 0);
1475 assert_eq!(pool.jobs_running(), 0);
1476 assert_eq!(pool.jobs_paniced(), 1);
1477 }
1478
1479 #[test]
1480 #[allow(dead_code)]
1481 fn test_flakiness() {
1482 for _ in 0..10 {
1483 test_wait();
1484 test_wait_unchecked();
1485 deal_with_panics();
1486 receive_value();
1487 test_clones();
1488 reset_state_while_running();
1489 test_wait_until_job_done_unchecked();
1490 test_wait_until_job_done();
1491 reset_panic_test();
1492 }
1493 }
1494}