easy_threadpool/
lib.rs

1#![forbid(missing_docs)]
2// License at https://github.com/NicoElbers/easy_threadpool
3// It's MIT
4
5//! A simple thread pool to execute jobs in parallel
6//!
7//! A simple crate without dependencies which allows you to create a threadpool
8//! that has a specified amount of threads which execute given jobs. Threads don't
9//! crash when a job panics!
10//!
11//! # Examples
12//!
13//! ## Basic usage
14//!
15//! A basic use of the threadpool
16//!
17//! ```rust
18//! # use std::error::Error;
19//! # fn main() -> Result<(), Box<dyn Error>> {
20//! use easy_threadpool::ThreadPoolBuilder;
21//!
22//! fn job() {
23//!     println!("Hello world!");
24//! }
25//!
26//! let builder = ThreadPoolBuilder::with_max_threads()?;
27//! let pool = builder.build()?;
28//!
29//! for _ in 0..10 {
30//!     pool.send_job(job);
31//! }
32//!
33//! assert!(pool.wait_until_finished().is_ok());
34//! # Ok(())
35//! # }
36//! ```
37//!
38//! ## More advanced usage
39//!
40//! A slightly more advanced usage of the threadpool
41//!
42//! ```rust
43//! # use std::error::Error;
44//! # fn main() -> Result<(), Box<dyn Error>> {
45//! use easy_threadpool::ThreadPoolBuilder;
46//! use std::sync::mpsc::channel;
47//!
48//! let builder = ThreadPoolBuilder::with_max_threads()?;
49//! let pool = builder.build()?;
50//!
51//! let (tx, rx) = channel();
52//!
53//! for _ in 0..10 {
54//!     let tx = tx.clone();
55//!     pool.send_job(move || {
56//!         tx.send(1).expect("Receiver should still exist");
57//!     });
58//! }
59//!
60//! assert!(pool.wait_until_finished().is_ok());
61//!
62//! assert_eq!(rx.iter().take(10).fold(0, |a, b| a + b), 10);
63//! # Ok(())
64//! # }
65//! ```
66//!
67//! ## Dealing with panics
68//!
//! This threadpool implementation is resistant to panicking jobs.
70//!
71//! ```rust
72//! # use std::error::Error;
73//! # fn main() -> Result<(), Box<dyn Error>> {
74//! use easy_threadpool::ThreadPoolBuilder;
75//! use std::sync::mpsc::channel;
76//! use std::num::NonZeroUsize;
77//!
78//! fn panic_fn() {
79//!     panic!("Test panic");
80//! }
81//!
82//! let num = NonZeroUsize::try_from(1)?;
83//! let builder = ThreadPoolBuilder::with_thread_amount(num);
84//! let pool = builder.build()?;
85//!
86//! let (tx, rx) = channel();
87//! for _ in 0..10 {
88//!     let tx = tx.clone();
89//!     pool.send_job(move || {
90//!         tx.send(1).expect("Receiver should still exist");
91//!         panic!("Test panic");
92//!     });
93//! }
94//!
95//! assert!(pool.wait_until_finished().is_err());
96//! pool.wait_until_finished_unchecked();
97//!
98//! assert_eq!(pool.jobs_paniced(), 10);
99//! assert_eq!(rx.iter().take(10).fold(0, |a, b| a + b), 10);
100//! # Ok(())
101//! # }
102//! ```
103
104use std::{
105    error::Error,
106    fmt::{Debug, Display},
107    io,
108    num::{NonZeroUsize, TryFromIntError},
109    panic::{catch_unwind, UnwindSafe},
110    sync::{
111        atomic::{AtomicBool, AtomicUsize, Ordering},
112        mpsc::{channel, Sender},
113        Arc, Condvar, Mutex,
114    },
115    thread::{self, available_parallelism},
116};
117
118type ThreadPoolFunctionBoxed = Box<dyn FnOnce() + Send + UnwindSafe>;
119
120/// Simple error to indicate that a job has paniced in the threadpool
121#[derive(Debug)]
122pub struct JobHasPanicedError {}
123
124impl Display for JobHasPanicedError {
125    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
126        write!(f, "At least one job in the threadpool has caused a panic")
127    }
128}
129
130impl Error for JobHasPanicedError {}
131
132// /// Simple error to indicate a function passed to do_until_finished has paniced
133// #[derive(Debug)]
134// pub struct DoUntilFinishedFunctionPanicedError {}
135
136// impl Display for DoUntilFinishedFunctionPanicedError {
137//     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
138//         write!(f, "The function passed to do_until_finished has paniced")
139//     }
140// }
141
142// impl Error for DoUntilFinishedFunctionPanicedError {}
143
144// /// An enum to combine both errors previously defined
145// #[derive(Debug)]
146// pub enum Errors {
147//     /// Enum representation of [`JobHasPanicedError`]
148//     JobHasPanicedError,
149//     /// Enum representation of [`DoUntilFinishedFunctionPanicedError`]
150//     DoUntilFinishedFunctionPanicedError,
151// }
152
153// impl Display for Errors {
154//     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
155//         match self {
156//             Errors::DoUntilFinishedFunctionPanicedError => {
157//                 Display::fmt(&DoUntilFinishedFunctionPanicedError {}, f)
158//             }
159//             Errors::JobHasPanicedError => Display::fmt(&JobHasPanicedError {}, f),
160//         }
161//     }
162// }
163
164// impl Error for Errors {}
165
/// Job bookkeeping shared between one [`ThreadPool`] instance and every job it
/// sends.
///
/// The counters are atomics so they can be read without locking. However, every
/// update that can change whether the instance is finished (`job_queued` and
/// `job_finished`) happens while holding the `is_finished` mutex. This keeps the
/// flag consistent with the counters, and it means a thread that checked the flag
/// under the lock and then waits on a condition variable cannot miss the
/// notification sent after a job finishes.
#[derive(Debug)]
struct SharedState {
    /// Jobs sent to the workers that have not started running yet.
    jobs_queued: AtomicUsize,
    /// Jobs currently executing on a worker thread.
    jobs_running: AtomicUsize,
    /// Number of jobs that panicked.
    jobs_paniced: AtomicUsize,
    /// `true` when no job is queued or running.
    is_finished: Mutex<bool>,
    /// Set once any job has panicked; never cleared.
    has_paniced: AtomicBool,
}

impl Default for SharedState {
    /// Same as [`SharedState::new`]: no jobs, finished, and no panics.
    ///
    /// Implemented by hand because a derived `Default` would start with
    /// `is_finished == false`, an "unfinished" state that no job would ever end.
    fn default() -> Self {
        Self::new()
    }
}

impl Display for SharedState {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "SharedState<jobs_queued: {}, jobs_running: {}, jobs_paniced: {}, is_finished: {}, has_paniced: {}>",
            self.jobs_queued.load(Ordering::Relaxed),
            self.jobs_running.load(Ordering::Relaxed),
            self.jobs_paniced.load(Ordering::Relaxed),
            self.is_finished.lock().expect("Shared state should never panic"),
            self.has_paniced.load(Ordering::Relaxed)
        )
    }
}

impl SharedState {
    /// Creates an empty state: nothing queued or running, finished, no panics.
    fn new() -> Self {
        Self {
            jobs_running: AtomicUsize::new(0),
            jobs_queued: AtomicUsize::new(0),
            jobs_paniced: AtomicUsize::new(0),
            is_finished: Mutex::new(true),
            has_paniced: AtomicBool::new(false),
        }
    }

    /// Moves one job from "queued" to "running".
    fn job_starting(&self) {
        debug_assert!(
            self.jobs_queued.load(Ordering::Acquire) > 0,
            "Negative jobs queued"
        );

        // Increment `running` before decrementing `queued` so the two counters
        // never both read as zero while this job is still in flight.
        self.jobs_running.fetch_add(1, Ordering::SeqCst);
        self.jobs_queued.fetch_sub(1, Ordering::SeqCst);
    }

    /// Marks one running job as done and sets `is_finished` if it was the last
    /// outstanding job.
    fn job_finished(&self) {
        debug_assert!(
            self.jobs_running.load(Ordering::Acquire) > 0,
            "Negative jobs running"
        );

        // Hold the lock across the decrement and the check. Otherwise a
        // concurrent `job_queued` could run between them, and its
        // `is_finished = false` would be overwritten with `true` while a job is
        // still queued. Taking the lock here also guarantees that the
        // notification sent after this call cannot race a waiter that has
        // checked the flag but not started waiting yet.
        let mut is_finished = self
            .is_finished
            .lock()
            .expect("Shared state should never panic");

        self.jobs_running.fetch_sub(1, Ordering::SeqCst);

        if self.jobs_queued.load(Ordering::Acquire) == 0
            && self.jobs_running.load(Ordering::Acquire) == 0
        {
            *is_finished = true;
        }
    }

    /// Records a newly sent job and clears `is_finished`.
    fn job_queued(&self) {
        // Same lock as `job_finished`, so the increment and the flag update are
        // observed together.
        let mut is_finished = self
            .is_finished
            .lock()
            .expect("Shared state should never panic");

        self.jobs_queued.fetch_add(1, Ordering::SeqCst);
        *is_finished = false;
    }

    /// Records that a job panicked.
    fn job_paniced(&self) {
        // Count first so that anyone who sees `has_paniced == true` also sees
        // `jobs_paniced >= 1`.
        self.jobs_paniced.fetch_add(1, Ordering::SeqCst);
        self.has_paniced.store(true, Ordering::SeqCst);
    }
}
250
251/// Threadpool abstraction to keep some state
252#[derive(Debug)]
253pub struct ThreadPool {
254    thread_amount: NonZeroUsize,
255    job_sender: Arc<Sender<ThreadPoolFunctionBoxed>>,
256    shared_state: Arc<SharedState>,
257    cvar: Arc<Condvar>,
258}
259
260impl Clone for ThreadPool {
261    fn clone(&self) -> Self {
262        Self {
263            thread_amount: self.thread_amount,
264            job_sender: self.job_sender.clone(),
265            shared_state: self.shared_state.clone(),
266            cvar: self.cvar.clone(),
267        }
268    }
269}
270
271impl Display for ThreadPool {
272    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
273        write!(
274            f,
275            "Threadpool< thread_amount: {}, shared_state: {}>",
276            self.thread_amount, self.shared_state
277        )
278    }
279}
280
281impl ThreadPool {
282    fn new(builder: ThreadPoolBuilder) -> io::Result<Self> {
283        let thread_amount = builder.thread_amount;
284
285        let (job_sender, job_receiver) = channel::<ThreadPoolFunctionBoxed>();
286        let job_sender = Arc::new(job_sender);
287        let shareable_job_reciever = Arc::new(Mutex::new(job_receiver));
288
289        let shared_state = Arc::new(SharedState::new());
290        let cvar = Arc::new(Condvar::new());
291
292        for thread_num in 0..thread_amount.get() {
293            let job_reciever = shareable_job_reciever.clone();
294
295            let thread_name = format!("Threadpool worker {thread_num}");
296
297            thread::Builder::new().name(thread_name).spawn(move || {
298                loop {
299                    let job = {
300                        let lock = job_reciever //
301                            .lock()
302                            .expect("Cannot get reciever");
303
304                        lock.recv()
305                    };
306
307                    // NOTE: Breaking on error ensures that all threads will stop
308                    // when the threadpool is dropped and all jobs have been executed
309                    match job {
310                        Ok(job) => job(),
311                        Err(_) => break,
312                    };
313                }
314            })?;
315        }
316
317        Ok(Self {
318            thread_amount,
319            job_sender,
320            shared_state,
321            cvar,
322        })
323    }
324
325    /// The `send_job` function takes in a function or closure without any arguments
326    /// and sends it to the threadpool to be executed. Jobs will be taken from the
327    /// job queue in order of them being sent, but that in no way guarantees they will
328    /// be executed in order.
329    ///
330    /// `job`s must implement `Send` in order to be safely sent across threads and
331    /// `UnwindSafe` to allow catching panics when executing the jobs. Both of these
332    /// traits are auto implemented.
333    ///
334    /// # Examples
335    ///
336    /// Sending a function or closure to the threadpool
337    /// ```rust
338    /// # use std::error::Error;
339    /// # fn main() -> Result<(), Box<dyn Error>> {
340    /// use easy_threadpool::ThreadPoolBuilder;
341    ///
342    /// fn job() {
343    ///     println!("Hello world from a function!");
344    /// }
345    ///
346    /// let builder = ThreadPoolBuilder::with_max_threads()?;
347    /// let pool = builder.build()?;
348    ///
349    /// pool.send_job(job);
350    ///
351    /// pool.send_job(|| println!("Hello world from a closure!"));
352    /// # Ok(())
353    /// # }
354    /// ```
355    pub fn send_job(&self, job: impl FnOnce() + Send + UnwindSafe + 'static) {
356        // NOTE: It is essential that the shared state is updated FIRST otherwise
357        // we have a race condidition that the job is transmitted and read before
358        // the shared state is updated, leading to a negative amount of jobs queued
359        self.shared_state.job_queued();
360
361        debug_assert!(self.jobs_queued() > 0, "Job didn't queue properly");
362        debug_assert!(!self.is_finished(), "Finish wasn't properly set to false");
363
364        // Pass our own state to the job. This makes it so that multiple threadpools
365        // with different states can send jobs to the same threads without getting
366        // eachothers panics for example
367        let state = self.shared_state.clone();
368        let cvar = self.cvar.clone();
369        let job_with_state = Self::job_function(Box::new(job), state, cvar);
370
371        self.job_sender
372            .send(Box::new(job_with_state))
373            .expect("The sender cannot be deallocated while the threadpool is in use")
374    }
375
376    fn job_function(
377        job: ThreadPoolFunctionBoxed,
378        state: Arc<SharedState>,
379        cvar: Arc<Condvar>,
380    ) -> impl FnOnce() + Send + 'static {
381        move || {
382            state.job_starting();
383
384            // NOTE: The use of catch_unwind means that the thread will not
385            // panic from any of the jobs it was sent. This is useful because
386            // we won't ever have to restart a thread.
387            let result = catch_unwind(job);
388
389            println!("{result:?}");
390
391            // NOTE: Do the panic check first otherwise we have a race condition
392            // where the final job panics and the wait_until_finished function
393            // doesn't detect it
394            if result.is_err() {
395                state.job_paniced();
396            }
397
398            state.job_finished();
399
400            cvar.notify_all();
401        }
402    }
403
404    /// This function will wait until all jobs have finished sending. Additionally
405    /// it will return early if any job panics.
406    ///
407    /// Be careful though, returning early DOES NOT mean that the sent jobs are
408    /// cancelled. They will remain running. Cancelling jobs that are queued is not
409    /// a feature provided by this crate as of now.
410    ///
411    /// # Errors
412    ///
413    /// This function will error if any job sent to the threadpool has errored.
414    /// This includes any errors since either the threadpool was created or since
415    /// the state was reset.
416    ///
417    /// # Examples
418    ///
419    /// ```rust
420    /// # use std::error::Error;
421    /// # fn main() -> Result<(), Box<dyn Error>> {
422    /// use easy_threadpool::ThreadPoolBuilder;
423    ///
424    /// let builder = ThreadPoolBuilder::with_max_threads()?;
425    /// let pool = builder.build()?;
426    ///
427    /// for _ in 0..10 {
428    ///     pool.send_job(|| println!("Hello world"));
429    /// }
430    ///
431    /// assert!(pool.wait_until_finished().is_ok());
432    /// assert!(pool.is_finished());
433    ///
434    /// pool.send_job(|| panic!("Test panic"));
435    ///
436    /// assert!(pool.wait_until_finished().is_err());
437    /// assert!(pool.has_paniced());
438    /// # Ok(())
439    /// # }
440    /// ```
441    pub fn wait_until_finished(&self) -> Result<(), JobHasPanicedError> {
442        let mut is_finished = self
443            .shared_state
444            .is_finished
445            .lock()
446            .expect("Shared state should never panic");
447
448        while !*is_finished && !self.has_paniced() {
449            is_finished = self
450                .cvar
451                .wait(is_finished)
452                .expect("Shared state should never panic");
453        }
454
455        println!("panic {}", self.has_paniced());
456
457        debug_assert!(
458            self.has_paniced() || self.jobs_running() == 0,
459            "wait_until_finished stopped {} jobs running and {} panics",
460            self.jobs_running(),
461            self.jobs_paniced()
462        );
463        debug_assert!(
464            self.has_paniced() || self.jobs_queued() == 0,
465            "wait_until_finished stopped while {} jobs queued and {} panics",
466            self.jobs_queued(),
467            self.jobs_paniced()
468        );
469
470        println!("WERE DONE WAITING");
471
472        match self.shared_state.has_paniced.load(Ordering::Acquire) {
473            true => Err(JobHasPanicedError {}),
474            false => Ok(()),
475        }
476    }
477
478    /// This function will wait until one job finished after calling the function.
479    /// Additionally, if the threadpool is finished this function will also return.
480    /// Additionally it will return early if any job panics.
481    ///
482    /// Be careful though, returning early DOES NOT mean that the sent jobs are
483    /// cancelled. They will remain running. Cancelling jobs that are queued is not
484    /// a feature provided by this crate as of now.
485    ///
486    /// # Errors
487    ///
488    /// This function will error if any job sent to the threadpool has errored.
489    /// This includes any errors since either the threadpool was created or since
490    /// the state was reset.
491    ///
492    /// # Examples
493    ///
494    /// ```rust
495    /// # use std::error::Error;
496    /// # fn main() -> Result<(), Box<dyn Error>> {
497    /// use easy_threadpool::ThreadPoolBuilder;
498    ///
499    /// let builder = ThreadPoolBuilder::with_max_threads()?;
500    /// let pool = builder.build()?;
501    ///
502    /// assert!(pool.wait_until_job_done().is_ok());
503    /// assert!(pool.is_finished());
504    ///
505    /// pool.send_job(|| panic!("Test panic"));
506    ///
507    /// assert!(pool.wait_until_job_done().is_err());
508    /// assert!(pool.has_paniced());
509    /// # Ok(())
510    /// # }
511    /// ```
512    pub fn wait_until_job_done(&self) -> Result<(), JobHasPanicedError> {
513        fn paniced(state: &SharedState) -> bool {
514            state.jobs_paniced.load(Ordering::Acquire) != 0
515        }
516
517        let is_finished = self
518            .shared_state
519            .is_finished
520            .lock()
521            .expect("Shared state should never panic");
522
523        if *is_finished {
524            return Ok(());
525        };
526
527        drop(self.cvar.wait(is_finished));
528
529        // Keep the guard so we don't have to drop the lock only to reaquire it
530        if paniced(&self.shared_state) {
531            Err(JobHasPanicedError {})
532        } else {
533            Ok(())
534        }
535    }
536
537    /// This function will wait until all jobs have finished sending. It will continue
538    /// waiting if a job panics in the thread pool.
539    ///
540    /// I highly doubt this has much of a performance improvement, but it's very
541    /// useful if you know that for whatever reason your jobs might panic and that
542    /// would be fine.
543    ///
544    /// # Examples
545    ///
546    /// ```rust
547    /// # use std::error::Error;
548    /// # fn main() -> Result<(), Box<dyn Error>> {
549    /// use easy_threadpool::ThreadPoolBuilder;
550    ///
551    /// let builder = ThreadPoolBuilder::with_max_threads()?;
552    /// let pool = builder.build()?;
553    ///
554    /// for _ in 0..10 {
555    ///     pool.send_job(|| println!("Hello world"));
556    /// }
557    ///
558    /// pool.wait_until_finished_unchecked();
559    /// assert!(pool.is_finished());
560    ///
561    /// pool.send_job(|| panic!("Test panic"));
562    ///
563    /// pool.wait_until_finished_unchecked();
564    /// assert!(pool.has_paniced());
565    /// # Ok(())
566    /// # }
567    /// ```
568    pub fn wait_until_finished_unchecked(&self) {
569        let mut is_finished = self
570            .shared_state
571            .is_finished
572            .lock()
573            .expect("Shared state sould never panic");
574
575        if *is_finished {
576            return;
577        }
578
579        while !*is_finished {
580            is_finished = self
581                .cvar
582                .wait(is_finished)
583                .expect("Shared state should never panic")
584        }
585
586        debug_assert!(
587            self.shared_state.jobs_running.load(Ordering::Acquire) == 0,
588            "Job still running after wait_until_finished_unchecked"
589        );
590        debug_assert!(
591            self.shared_state.jobs_queued.load(Ordering::Acquire) == 0,
592            "Job still queued after wait_until_finished_unchecked"
593        );
594    }
595
596    /// This function will wait until one job finished after calling the function.
597    /// Additionally, if the threadpool is finished this function will also return.
598    ///
599    /// Be careful though, returning early DOES NOT mean that the sent jobs are
600    /// cancelled. They will remain running. Cancelling jobs that are queued is not
601    /// a feature provided by this crate as of now.
602    ///
603    /// # Examples
604    ///
605    /// ```rust
606    /// # use std::error::Error;
607    /// # fn main() -> Result<(), Box<dyn Error>> {
608    /// use easy_threadpool::ThreadPoolBuilder;
609    ///
610    /// let builder = ThreadPoolBuilder::with_max_threads()?;
611    /// let pool = builder.build()?;
612    ///
613    /// assert!(pool.wait_until_job_done().is_ok());
614    /// assert!(pool.is_finished());
615    ///
616    /// pool.send_job(|| panic!("Test panic"));
617    ///
618    /// assert!(pool.wait_until_job_done().is_err());
619    /// assert!(pool.has_paniced());
620    /// # Ok(())
621    /// # }
622    /// ```
623    pub fn wait_until_job_done_unchecked(&self) {
624        let is_finished = self
625            .shared_state
626            .is_finished
627            .lock()
628            .expect("Shared state should never panic");
629
630        // This is guaranteed to work because jobs cannot finish without having
631        // the shared state lock, and we keep the lock until we start waiting for
632        // the condvar
633        if *is_finished {
634            return;
635        };
636
637        drop(self.cvar.wait(is_finished));
638    }
639
640    /// This function will reset the state of this instance of the threadpool.
641    ///
642    /// When resetting the state you lose all information about previously sent jobs.
643    /// If a job you previously sent panics, you will not be notified, nor can  you
644    /// wait until your previously sent jobs are done running. HOWEVER they will still
645    /// be running. Be very careful to not see this as a "stop" button.
646    ///
647    /// # Examples
648    ///
649    /// ```rust
650    /// # use std::error::Error;
651    /// # fn main() -> Result<(), Box<dyn Error>> {
652    /// use easy_threadpool::ThreadPoolBuilder;
653    ///
654    /// let builder = ThreadPoolBuilder::with_max_threads()?;
655    /// let mut pool = builder.build()?;
656    ///
657    /// pool.send_job(|| panic!("Test panic"));
658    ///
659    /// assert!(pool.wait_until_finished().is_err());
660    /// assert!(pool.has_paniced());
661    ///
662    /// pool.reset_state();
663    ///
664    /// assert!(pool.wait_until_finished().is_ok());
665    /// assert!(!pool.has_paniced());
666    /// # Ok(())
667    /// # }
668    /// ```
669    pub fn reset_state(&mut self) {
670        let cvar = Arc::new(Condvar::new());
671        let shared_state = Arc::new(SharedState::new());
672
673        self.cvar = cvar;
674        self.shared_state = shared_state;
675    }
676
677    /// This function will clone the threadpool and then reset its state. This
678    /// makes it so you can have 2 different states operate on the same threads,
679    /// effectively sharing the threads.
680    ///
681    /// Note however that there is no mechanism
682    /// to give different instances equal CPU time, jobs are executed on a first
683    /// come first server basis.
684    ///
685    /// # Examples
686    ///
687    /// ```rust
688    /// # use std::error::Error;
689    /// # fn main() -> Result<(), Box<dyn Error>> {
690    /// use easy_threadpool::ThreadPoolBuilder;
691    ///
692    /// let builder = ThreadPoolBuilder::with_max_threads()?;
693    /// let pool = builder.build()?;
694    ///
695    /// let pool_clone = pool.clone_with_new_state();
696    ///
697    /// pool.send_job(|| panic!("Test panic"));
698    ///
699    /// assert!(pool.wait_until_finished().is_err());
700    /// assert!(pool.has_paniced());
701    ///
702    /// assert!(pool_clone.wait_until_finished().is_ok());
703    /// assert!(!pool_clone.has_paniced());
704    /// # Ok(())
705    /// # }
706    /// ```
707    pub fn clone_with_new_state(&self) -> Self {
708        let mut new_pool = self.clone();
709        new_pool.reset_state();
710        new_pool
711    }
712
713    /// Returns the amount of jobs currently being ran by this instance of the
714    /// thread pool. If muliple different instances of this threadpool (see [`clone_with_new_state`])
715    /// this number might be lower than the max amount of threads, even if there
716    /// are still jobs queued
717    ///
718    /// # Examples
719    ///
720    /// ```rust
721    /// # use std::error::Error;
722    /// # fn main() -> Result<(), Box<dyn Error>> {
723    /// use easy_threadpool::ThreadPoolBuilder;
724    /// use std::{
725    ///     num::NonZeroUsize,
726    ///     sync::{Arc, Barrier},
727    /// };
728    /// let threads = 16;
729    /// let tasks = threads * 10;
730    ///
731    /// let num = NonZeroUsize::try_from(threads)?;
732    /// let pool = ThreadPoolBuilder::with_thread_amount(num).build()?;
733    ///
734    /// let b0 = Arc::new(Barrier::new(threads + 1));
735    /// let b1 = Arc::new(Barrier::new(threads + 1));
736    ///
737    /// for i in 0..tasks {
738    ///     let b0_copy = b0.clone();
739    ///     let b1_copy = b1.clone();
740    ///
741    ///     pool.send_job(move || {
742    ///         if i < threads {
743    ///             b0_copy.wait();
744    ///             b1_copy.wait();
745    ///         }
746    ///     });
747    /// }
748    ///
749    /// b0.wait();
750    /// assert_eq!(pool.jobs_running(), threads);
751    /// # b1.wait();
752    /// # Ok(())
753    /// # }
754    /// ```
755    pub fn jobs_running(&self) -> usize {
756        self.shared_state.jobs_running.load(Ordering::Acquire)
757    }
758
759    /// Returns the amount of jobs currently queued by this threadpool instance.
760    /// There might be more jobs queued that we don't know about if there are other
761    /// instances of this threadpool (see [`clone_with_new_state`]).
762    ///
763    /// # Examples
764    ///
765    /// ```rust
766    /// # use std::error::Error;
767    /// # fn main() -> Result<(), Box<dyn Error>> {
768    /// use easy_threadpool::ThreadPoolBuilder;
769    /// use std::{
770    ///     num::NonZeroUsize,
771    ///     sync::{Arc, Barrier},
772    /// };
773    /// let threads = 16;
774    /// let tasks = 100;
775    ///
776    /// let num = NonZeroUsize::try_from(threads)?;
777    /// let pool = ThreadPoolBuilder::with_thread_amount(num).build()?;
778    ///
779    /// let b0 = Arc::new(Barrier::new(threads + 1));
780    /// let b1 = Arc::new(Barrier::new(threads + 1));
781    ///
782    /// for i in 0..tasks {
783    ///     let b0_copy = b0.clone();
784    ///     let b1_copy = b1.clone();
785    ///
786    ///     pool.send_job(move || {
787    ///         if i < threads {
788    ///             b0_copy.wait();
789    ///             b1_copy.wait();
790    ///         }
791    ///     });
792    /// }
793    ///
794    /// b0.wait();
795    /// assert_eq!(pool.jobs_queued(), tasks - threads);
796    /// # b1.wait();
797    /// # Ok(())
798    /// # }
799    /// ```
800    pub fn jobs_queued(&self) -> usize {
801        self.shared_state.jobs_queued.load(Ordering::Acquire)
802    }
803
804    /// Returns the amount of jobs that were sent by this instance of the threadpool
805    /// and that paniced.
806    ///
807    /// # Examples
808    ///
809    /// ```rust
810    /// # use std::error::Error;
811    /// # fn main() -> Result<(), Box<dyn Error>> {
812    /// use easy_threadpool::ThreadPoolBuilder;
813    ///
814    /// let pool = ThreadPoolBuilder::with_max_threads()?.build()?;
815    ///
816    /// for i in 0..10 {
817    ///     pool.send_job(|| panic!("Test panic"));
818    /// }
819    ///
820    /// pool.wait_until_finished_unchecked();
821    ///
822    /// assert_eq!(pool.jobs_paniced(), 10);
823    /// # Ok(())
824    /// # }
825    /// ```
826    pub fn jobs_paniced(&self) -> usize {
827        self.shared_state.jobs_paniced.load(Ordering::Acquire)
828    }
829
830    /// Returns whether a thread has had any jobs panic at all
831    ///
832    /// # Examples
833    ///
834    /// ```rust
835    /// # use std::error::Error;
836    /// # fn main() -> Result<(), Box<dyn Error>> {
837    /// use easy_threadpool::ThreadPoolBuilder;
838    ///
839    /// let pool = ThreadPoolBuilder::with_max_threads()?.build()?;
840    ///
841    /// pool.send_job(|| panic!("Test panic"));
842    ///
843    /// pool.wait_until_finished_unchecked();
844    ///
845    /// assert!(pool.has_paniced());
846    /// # Ok(())
847    /// # }
848    /// ```
849    pub fn has_paniced(&self) -> bool {
850        self.shared_state.has_paniced.load(Ordering::Acquire)
851    }
852
853    /// Returns whether a threadpool instance has no jobs running and no jobs queued,
854    /// in other words if it's finished.
855    ///
856    /// # Examples
857    ///
858    /// ```rust
859    /// # use std::error::Error;
860    /// # fn main() -> Result<(), Box<dyn Error>> {
861    /// use easy_threadpool::ThreadPoolBuilder;
862    /// use std::{
863    ///     num::NonZeroUsize,
864    ///     sync::{Arc, Barrier},
865    /// };
866    /// let pool = ThreadPoolBuilder::with_max_threads()?.build()?;
867    ///
868    /// let b = Arc::new(Barrier::new(2));
869    ///
870    /// assert!(pool.is_finished());
871    ///
872    /// let b_clone = b.clone();
873    /// pool.send_job(move || { b_clone.wait(); });
874    ///
875    /// assert!(!pool.is_finished());
876    /// # b.wait();
877    /// # Ok(())
878    /// # }
879    /// ```
880    pub fn is_finished(&self) -> bool {
881        *self
882            .shared_state
883            .is_finished
884            .lock()
885            .expect("Shared state should never panic")
886    }
887
888    /// This function returns the amount of threads used to create the threadpool
889    ///
890    /// # Examples
891    ///
892    /// ```rust
893    /// # use std::error::Error;
894    /// # fn main() -> Result<(), Box<dyn Error>> {
895    /// use easy_threadpool::ThreadPoolBuilder;
896    /// use std::num::NonZeroUsize;
897    ///
898    /// let threads = 10;
899    ///
900    /// let num = NonZeroUsize::try_from(threads)?;
901    /// let pool = ThreadPoolBuilder::with_thread_amount(num).build()?;
902    ///
903    /// assert_eq!(pool.threads().get(), threads);
904    /// # Ok(())
905    /// # }
906    /// ```
907    pub const fn threads(&self) -> NonZeroUsize {
908        self.thread_amount
909    }
910}
911
/// A builder to configure and create a [`ThreadPool`].
///
/// Currently the only setting is the number of worker threads. The [`Default`]
/// builder uses a single thread.
pub struct ThreadPoolBuilder {
    /// Number of worker threads the built pool will spawn.
    thread_amount: NonZeroUsize,
    // thread_name: Option<String>,
}

impl Default for ThreadPoolBuilder {
    /// Creates a builder for a pool with exactly one worker thread.
    fn default() -> Self {
        let thread_amount = NonZeroUsize::new(1).expect("1 is not zero");
        Self { thread_amount }
    }
}
925
926impl ThreadPoolBuilder {
927    /// Initialize the amount of threads the builder will build to `thread_amount`
928    pub fn with_thread_amount(thread_amount: NonZeroUsize) -> ThreadPoolBuilder {
929        ThreadPoolBuilder { thread_amount }
930    }
931
932    /// Initialize the amount of threads the builder will build to `thread_amount`.
933    ///
934    /// # Errors
935    ///
936    /// If `thread_amount` cannot be converted to a [`std::num::NonZeroUsize`] (aka it is 0).
937    pub fn with_thread_amount_usize(
938        thread_amount: usize,
939    ) -> Result<ThreadPoolBuilder, TryFromIntError> {
940        let thread_amount = NonZeroUsize::try_from(thread_amount)?;
941        Ok(Self::with_thread_amount(thread_amount))
942    }
943
944    /// Initialize the amount of threads the builder will build to the available parallelism
945    /// as provided by [`std::thread::available_parallelism`]
946    ///
947    /// # Errors
948    ///
949    /// Taken from the available_parallelism() documentation:
950    /// This function will, but is not limited to, return errors in the following
951    /// cases:
952    ///
953    /// * If the amount of parallelism is not known for the target platform.
954    /// * If the program lacks permission to query the amount of parallelism made
955    ///   available to it.
956    ///
957    pub fn with_max_threads() -> io::Result<ThreadPoolBuilder> {
958        let max_threads = available_parallelism()?;
959        Ok(ThreadPoolBuilder {
960            thread_amount: max_threads,
961        })
962    }
963
964    // pub fn with_thread_name(thread_name: String) -> ThreadPoolBuilder {
965    //     ThreadPoolBuilder {
966    //         thread_name: Some(thread_name),
967    //         ..Default::default()
968    //     }
969    // }
970
971    /// Set the thead amount in the builder
972    pub fn set_thread_amount(mut self, thread_amount: NonZeroUsize) -> ThreadPoolBuilder {
973        self.thread_amount = thread_amount;
974        self
975    }
976
977    /// Set the thead amount in the builder from usize
978    ///
979    /// # Errors
980    ///
981    /// If `thread_amount` cannot be turned into NonZeroUsize (aka it is 0)
982    pub fn set_thread_amount_usize(
983        self,
984        thread_amount: usize,
985    ) -> Result<ThreadPoolBuilder, TryFromIntError> {
986        let thread_amount = NonZeroUsize::try_from(thread_amount)?;
987        Ok(self.set_thread_amount(thread_amount))
988    }
989
990    /// set the amount of threads the builder will build to the available parallelism
991    /// as provided by [`std::thread::available_parallelism`]
992    ///
993    /// # Errors
994    ///
995    /// Taken from the available_parallelism() documentation:
996    /// This function will, but is not limited to, return errors in the following
997    /// cases:
998    ///
999    /// * If the amount of parallelism is not known for the target platform.
1000    /// * If the program lacks permission to query the amount of parallelism made
1001    ///   available to it.
1002    ///
1003    pub fn set_max_threads(mut self) -> io::Result<ThreadPoolBuilder> {
1004        let max_threads = available_parallelism()?;
1005        self.thread_amount = max_threads;
1006        Ok(self)
1007    }
1008
1009    // pub fn set_thread_name(mut self, thread_name: String) -> ThreadPoolBuilder {
1010    //     self.thread_name = Some(thread_name);
1011    //     self
1012    // }
1013
1014    /// Build the builder into a threadpool, taking all the initialized values
1015    /// from the builder and using defaults for those not initialized.
1016    ///
1017    /// # Errors
1018    ///
1019    /// Taken from [`std::thread::Builder::spawn`]:
1020    ///
1021    /// Unlike the [`spawn`](https://doc.rust-lang.org/stable/std/thread/fn.spawn.html) free function, this method yields an
1022    /// [`io::Result`] to capture any failure to create the thread at
1023    /// the OS level.
1024    pub fn build(self) -> io::Result<ThreadPool> {
1025        ThreadPool::new(self)
1026    }
1027}
1028
#[cfg(test)]
mod test {
    use std::{
        num::NonZeroUsize,
        sync::{mpsc::channel, Arc, Barrier},
        thread::sleep,
        time::Duration,
    };

    use crate::ThreadPoolBuilder;

    // Test multiple panics on a single thread, this ensures that a thread can
    // handle panics
    #[test]
    fn deal_with_panics() {
        fn panic_fn() {
            panic!("Test panic");
        }

        let thread_num: NonZeroUsize = 1.try_into().unwrap();
        let builder = ThreadPoolBuilder::with_thread_amount(thread_num);

        let pool = builder.build().unwrap();

        for _ in 0..10 {
            pool.send_job(panic_fn);
        }

        assert!(
            pool.wait_until_finished().is_err(),
            "Pool didn't detect panic in wait_until_finished"
        );

        assert!(
            pool.has_paniced(),
            "Pool didn't detect panic in has_paniced"
        );
        pool.wait_until_finished_unchecked();

        assert_eq!(
            pool.jobs_queued(),
            0,
            "Incorrect amount of jobs queued after wait"
        );
        assert_eq!(
            pool.jobs_running(),
            0,
            "Incorrect amount of jobs running after wait"
        );
        assert_eq!(
            pool.jobs_paniced(),
            10,
            "Incorrect amount of jobs paniced after wait"
        );
    }

    #[test]
    fn receive_value() {
        let (tx, rx) = channel::<u32>();

        let func = move || {
            tx.send(69).unwrap();
        };

        let pool = ThreadPoolBuilder::default().build().unwrap();

        pool.send_job(func);

        assert_eq!(rx.recv(), Ok(69), "Incorrect value received");
    }

    // The first THREADS jobs block on two barriers shared with the test
    // thread: passing `b0` proves every worker is busy, and releasing `b1`
    // lets them finish. The same pattern is used by the tests below.
    #[test]
    fn test_wait() {
        const TASKS: usize = 1000;
        const THREADS: usize = 16;

        let b0 = Arc::new(Barrier::new(THREADS + 1));
        let b1 = Arc::new(Barrier::new(THREADS + 1));

        let pool = ThreadPoolBuilder::with_thread_amount_usize(THREADS)
            .unwrap()
            .build()
            .unwrap();

        for i in 0..TASKS {
            let b0 = b0.clone();
            let b1 = b1.clone();

            pool.send_job(move || {
                if i < THREADS {
                    b0.wait();
                    b1.wait();
                }
            });
        }

        b0.wait();

        assert_eq!(
            pool.jobs_running(),
            THREADS,
            "Incorrect amount of jobs running"
        );
        assert_eq!(
            pool.jobs_paniced(),
            0,
            "Incorrect amount of threads paniced"
        );

        b1.wait();

        assert!(
            pool.wait_until_finished().is_ok(),
            "wait_until_finished incorrectly detected a panic"
        );

        assert_eq!(
            pool.jobs_queued(),
            0,
            "Incorrect amount of jobs queued after wait"
        );
        assert_eq!(
            pool.jobs_running(),
            0,
            "Incorrect amount of jobs running after wait"
        );
        assert_eq!(
            pool.jobs_paniced(),
            0,
            "Incorrect amount of threads paniced after wait"
        );
    }

    #[test]
    fn test_wait_unchecked() {
        const TASKS: usize = 1000;
        const THREADS: usize = 16;

        let b0 = Arc::new(Barrier::new(THREADS + 1));
        let b1 = Arc::new(Barrier::new(THREADS + 1));

        let builder = ThreadPoolBuilder::with_thread_amount_usize(THREADS).unwrap();
        let pool = builder.build().unwrap();

        for i in 0..TASKS {
            let b0 = b0.clone();
            let b1 = b1.clone();

            pool.send_job(move || {
                if i < THREADS {
                    b0.wait();
                    b1.wait();
                }
                panic!("Test panic");
            });
        }

        b0.wait();

        assert_eq!(
            pool.jobs_running(),
            THREADS,
            "Incorrect amount of jobs running"
        );
        assert_eq!(pool.jobs_paniced(), 0);

        b1.wait();

        pool.wait_until_finished_unchecked();

        assert_eq!(pool.jobs_queued(), 0);
        assert_eq!(pool.jobs_running(), 0);
        assert_eq!(pool.jobs_paniced(), TASKS);
    }

    #[test]
    fn test_clones() {
        const TASKS: usize = 1000;
        const THREADS: usize = 16;

        let pool = ThreadPoolBuilder::with_thread_amount_usize(THREADS)
            .unwrap()
            .build()
            .unwrap();
        let clone = pool.clone();
        let clone_with_new_state = pool.clone_with_new_state();

        let b0 = Arc::new(Barrier::new(THREADS + 1));
        let b1 = Arc::new(Barrier::new(THREADS + 1));

        for i in 0..TASKS {
            let b0_copy = b0.clone();
            let b1_copy = b1.clone();

            pool.send_job(move || {
                if i < THREADS / 2 {
                    b0_copy.wait();
                    b1_copy.wait();
                }
            });

            let b0_copy = b0.clone();
            let b1_copy = b1.clone();

            clone_with_new_state.send_job(move || {
                if i < THREADS / 2 {
                    b0_copy.wait();
                    b1_copy.wait();
                }
                panic!("Test panic")
            });
        }

        b0.wait();

        // The /2 is guaranteed because jobs are received in order
        assert_eq!(
            pool.jobs_running(),
            THREADS / 2,
            "Incorrect amount of jobs running in pool"
        );
        assert_eq!(
            pool.jobs_paniced(),
            0,
            "Incorrect amount of jobs paniced in pool"
        );

        // The /2 is guaranteed because jobs are received in order
        assert_eq!(
            clone_with_new_state.jobs_running(),
            THREADS / 2,
            "Incorrect amount of jobs running in clone_with_new_state"
        );
        assert_eq!(
            clone_with_new_state.jobs_paniced(),
            0,
            "Incorrect amount of jobs paniced in clone_with_new_state"
        );

        b1.wait();
        assert!(
            clone_with_new_state.wait_until_finished().is_err(),
            "Clone with new state didn't detect panic"
        );

        assert!(
            clone.wait_until_finished().is_ok(),
            "Pool incorrectly detected panic"
        );

        assert_eq!(
            pool.jobs_queued(),
            0,
            "Incorrect amount of jobs queued in pool after wait"
        );
        assert_eq!(
            pool.jobs_running(),
            0,
            "Incorrect amount of jobs running in pool after wait"
        );
        assert_eq!(
            pool.jobs_paniced(),
            0,
            "Incorrect amount of jobs paniced in pool after wait"
        );

        clone_with_new_state.wait_until_finished_unchecked();
        assert!(
            clone_with_new_state.wait_until_finished().is_err(),
            "clone_with_new_state didn't detect panics after wait"
        );

        assert_eq!(
            clone_with_new_state.jobs_queued(),
            0,
            "Incorrect amount of jobs queued in clone_with_new_state after wait"
        );
        assert_eq!(
            clone_with_new_state.jobs_running(),
            0,
            "Incorrect amount of jobs running in clone_with_new_state after wait"
        );
        assert_eq!(
            clone_with_new_state.jobs_paniced(),
            TASKS,
            "Incorrect panics in clone"
        );

        assert_eq!(
            pool.jobs_queued(),
            0,
            "Incorrect amount of jobs queued in pool after everything"
        );
        assert_eq!(
            pool.jobs_running(),
            0,
            "Incorrect amount of jobs running in pool after everything"
        );
        assert_eq!(
            pool.jobs_paniced(),
            0,
            "Incorrect amount of jobs paniced in pool after everything"
        );
    }

    #[test]
    fn reset_state_while_running() {
        const TASKS: usize = 32;
        const THREADS: usize = 16;

        let mut pool = ThreadPoolBuilder::with_thread_amount_usize(THREADS)
            .unwrap()
            .build()
            .unwrap();

        let b0 = Arc::new(Barrier::new(THREADS + 1));
        let b1 = Arc::new(Barrier::new(THREADS + 1));

        for i in 0..TASKS {
            let b0_copy = b0.clone();
            let b1_copy = b1.clone();

            pool.send_job(move || {
                if i < THREADS {
                    b0_copy.wait();
                    b1_copy.wait();
                }
            });
        }

        b0.wait();

        assert_ne!(pool.jobs_queued(), 0);
        assert_ne!(pool.jobs_running(), 0);

        pool.reset_state();

        assert_eq!(pool.jobs_queued(), 0);
        assert_eq!(pool.jobs_running(), 0);
        assert_eq!(pool.jobs_paniced(), 0);

        b1.wait();
        pool.wait_until_finished().expect("Nothing should panic");

        // Give time for the jobs to execute
        sleep(Duration::from_secs(1));

        assert_eq!(pool.jobs_queued(), 0);
        assert_eq!(pool.jobs_running(), 0);
        assert_eq!(pool.jobs_paniced(), 0);
    }

    #[test]
    fn reset_panic_test() {
        const TASKS: usize = 32;
        const THREADS: usize = 16;

        let num = NonZeroUsize::try_from(THREADS).unwrap();
        let mut pool = ThreadPoolBuilder::with_thread_amount(num).build().unwrap();

        let b0 = Arc::new(Barrier::new(THREADS + 1));
        let b1 = Arc::new(Barrier::new(THREADS + 1));

        for i in 0..TASKS {
            let b0_copy = b0.clone();
            let b1_copy = b1.clone();

            pool.send_job(move || {
                if i < THREADS {
                    b0_copy.wait();
                    b1_copy.wait();
                }
                panic!("Test panic");
            });
        }

        b0.wait();

        assert_ne!(pool.jobs_queued(), 0);
        assert_ne!(pool.jobs_running(), 0);
        assert_eq!(pool.jobs_paniced(), 0);

        pool.reset_state();

        assert_eq!(pool.jobs_queued(), 0);
        assert_eq!(pool.jobs_running(), 0);
        assert_eq!(pool.jobs_paniced(), 0);

        b1.wait();
        pool.wait_until_finished().expect("Nothing should panic");

        // Give time for the jobs to execute
        sleep(Duration::from_secs(1));

        assert_eq!(pool.jobs_queued(), 0);
        assert_eq!(pool.jobs_running(), 0);
        assert_eq!(pool.jobs_paniced(), 0);
    }

    #[test]
    fn test_wait_until_job_done() {
        const THREADS: usize = 1;

        let builder = ThreadPoolBuilder::with_thread_amount_usize(THREADS).unwrap();
        let pool = builder.build().unwrap();

        assert!(pool.wait_until_job_done().is_ok());

        pool.send_job(|| {});

        assert!(pool.wait_until_job_done().is_ok());

        assert_eq!(pool.jobs_queued(), 0);
        assert_eq!(pool.jobs_running(), 0);
        assert_eq!(pool.jobs_paniced(), 0);

        pool.send_job(|| panic!("Test panic"));

        assert!(pool.wait_until_job_done().is_err());

        assert_eq!(pool.jobs_queued(), 0);
        assert_eq!(pool.jobs_running(), 0);
        assert_eq!(pool.jobs_paniced(), 1);
    }

    #[test]
    fn test_wait_until_job_done_unchecked() {
        const THREADS: usize = 1;

        let builder = ThreadPoolBuilder::with_thread_amount_usize(THREADS).unwrap();
        let pool = builder.build().unwrap();

        // This doesn't block forever
        pool.wait_until_job_done_unchecked();

        pool.send_job(|| {});

        pool.wait_until_job_done_unchecked();

        assert_eq!(pool.jobs_queued(), 0);
        assert_eq!(pool.jobs_running(), 0);
        assert_eq!(pool.jobs_paniced(), 0);

        pool.send_job(|| panic!("Test panic"));

        pool.wait_until_job_done_unchecked();

        assert_eq!(pool.jobs_queued(), 0);
        assert_eq!(pool.jobs_running(), 0);
        assert_eq!(pool.jobs_paniced(), 1);
    }

    // Re-run every test several times in one process to surface races that a
    // single run might miss.
    #[test]
    fn test_flakiness() {
        for _ in 0..10 {
            test_wait();
            test_wait_unchecked();
            deal_with_panics();
            receive_value();
            test_clones();
            reset_state_while_running();
            test_wait_until_job_done_unchecked();
            test_wait_until_job_done();
            reset_panic_test();
        }
    }
}