blocking_threadpool/
lib.rs

1// Copyright 2014 The Rust Project Developers. See the COPYRIGHT
2// file at the top-level directory of this distribution and at
3// http://rust-lang.org/COPYRIGHT.
4//
5// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
6// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
7// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
8// option. This file may not be copied, modified, or distributed
9// except according to those terms.
10
11//! A thread pool used to execute functions in parallel.
12//!
13//! Spawns a specified number of worker threads and replenishes the pool if any worker threads
14//! panic.
15//!
16//! # Examples
17//!
18//! ## Synchronized with a channel
19//!
//! Every job sends one message over the channel; the messages are then collected with `take()`.
21//!
22//! ```
23//! use blocking_threadpool::ThreadPool;
24//! use std::sync::mpsc::channel;
25//!
26//! let n_workers = 4;
27//! let n_jobs = 8;
28//! let pool = ThreadPool::new(n_workers);
29//!
30//! let (tx, rx) = channel();
31//! for _ in 0..n_jobs {
32//!     let tx = tx.clone();
33//!     pool.execute(move|| {
34//!         tx.send(1).expect("channel will be there waiting for the pool");
35//!     });
36//! }
37//!
38//! assert_eq!(rx.iter().take(n_jobs).fold(0, |a, b| a + b), 8);
39//! ```
40//!
41//! ## Synchronized with a barrier
42//!
43//! Keep in mind, if a barrier synchronizes more jobs than you have workers in the pool,
44//! you will end up with a [deadlock](https://en.wikipedia.org/wiki/Deadlock)
45//! at the barrier which is [not considered unsafe](
46//! https://doc.rust-lang.org/reference/behavior-not-considered-unsafe.html).
47//!
48//! ```
49//! use blocking_threadpool::ThreadPool;
50//! use std::sync::{Arc, Barrier};
51//! use std::sync::atomic::{AtomicUsize, Ordering};
52//!
53//! // create at least as many workers as jobs or you will deadlock yourself
54//! let n_workers = 42;
55//! let n_jobs = 23;
56//! let pool = ThreadPool::new(n_workers);
57//! let an_atomic = Arc::new(AtomicUsize::new(0));
58//!
59//! assert!(n_jobs <= n_workers, "too many jobs, will deadlock");
60//!
61//! // create a barrier that waits for all jobs plus the starter thread
62//! let barrier = Arc::new(Barrier::new(n_jobs + 1));
63//! for _ in 0..n_jobs {
64//!     let barrier = barrier.clone();
65//!     let an_atomic = an_atomic.clone();
66//!
67//!     pool.execute(move|| {
68//!         // do the heavy work
69//!         an_atomic.fetch_add(1, Ordering::Relaxed);
70//!
71//!         // then wait for the other threads
72//!         barrier.wait();
73//!     });
74//! }
75//!
76//! // wait for the threads to finish the work
77//! barrier.wait();
78//! assert_eq!(an_atomic.load(Ordering::SeqCst), /* n_jobs = */ 23);
79//! ```
80
81use std::fmt;
82use std::sync::atomic::{AtomicUsize, Ordering};
83use std::sync::{Arc, Condvar, Mutex};
84use std::thread;
85
86use crossbeam_channel as cbc;
87
/// Lets a boxed closure be consumed and invoked through a trait object.
///
/// `call_box` takes `self: Box<Self>`, so a `Box<dyn FnBox>` can move its
/// closure out of the box and run it exactly once.
trait FnBox {
    /// Unbox the closure and run it, consuming it in the process.
    fn call_box(self: Box<Self>);
}

impl<F> FnBox for F
where
    F: FnOnce(),
{
    fn call_box(self: Box<F>) {
        let closure: F = *self;
        closure()
    }
}

/// A type-erased job: any sendable `FnBox` that outlives `'a`.
/// The pool's queue carries `Thunk<'static>` values.
type Thunk<'a> = Box<dyn FnBox + Send + 'a>;
99
100struct Sentinel<'a> {
101    shared_data: &'a Arc<ThreadPoolSharedData>,
102    active: bool,
103}
104
105impl<'a> Sentinel<'a> {
106    fn new(shared_data: &'a Arc<ThreadPoolSharedData>) -> Sentinel<'a> {
107        Sentinel {
108            shared_data,
109            active: true,
110        }
111    }
112
113    /// Cancel and destroy this sentinel.
114    fn cancel(mut self) {
115        self.active = false;
116    }
117}
118
119impl<'a> Drop for Sentinel<'a> {
120    fn drop(&mut self) {
121        if self.active {
122            self.shared_data.active_count.fetch_sub(1, Ordering::SeqCst);
123            if thread::panicking() {
124                self.shared_data.panic_count.fetch_add(1, Ordering::SeqCst);
125            }
126            self.shared_data.no_work_notify_all();
127            spawn_in_pool(self.shared_data.clone())
128        }
129    }
130}
131
/// [`ThreadPool`] factory, which can be used in order to configure the properties of the
/// [`ThreadPool`].
///
/// The four configuration options available:
///
/// * `num_threads`: maximum number of threads that will be alive at any given moment by the built
///   [`ThreadPool`] (defaults to the number of CPUs)
/// * `thread_name`: thread name for each of the threads spawned by the built [`ThreadPool`]
/// * `thread_stack_size`: stack size (in bytes) for each of the threads spawned by the built
///   [`ThreadPool`]
/// * `queue_len`: maximum number of pending jobs; once the queue is full,
///   [`ThreadPool::execute`] blocks until a slot frees up (defaults to unbounded)
///
/// [`ThreadPool`]: struct.ThreadPool.html
/// [`ThreadPool::execute`]: struct.ThreadPool.html#method.execute
///
/// # Examples
///
/// Build a [`ThreadPool`] that uses a maximum of eight threads simultaneously and each thread has
/// an 8 MB stack size:
///
/// ```
/// let pool = blocking_threadpool::Builder::new()
///     .num_threads(8)
///     .thread_stack_size(8_000_000)
///     .build();
/// ```
#[derive(Clone, Debug, Default)]
pub struct Builder {
    num_threads: Option<usize>,
    thread_name: Option<String>,
    thread_stack_size: Option<usize>,
    queue_len: Option<usize>,
}
163
164impl Builder {
165    /// Initiate a new [`Builder`].
166    ///
167    /// [`Builder`]: struct.Builder.html
168    ///
169    /// # Examples
170    ///
171    /// ```
172    /// let builder = blocking_threadpool::Builder::new();
173    /// ```
174    pub fn new() -> Builder {
175        Builder {
176            num_threads: None,
177            thread_name: None,
178            thread_stack_size: None,
179            queue_len: None,
180        }
181    }
182
183    /// Set the maximum number of worker-threads that will be alive at any given moment by the built
184    /// [`ThreadPool`]. If not specified, defaults the number of threads to the number of CPUs.
185    ///
186    /// [`ThreadPool`]: struct.ThreadPool.html
187    ///
188    /// # Panics
189    ///
190    /// This method will panic if `num_threads` is 0.
191    ///
192    /// # Examples
193    ///
194    /// No more than eight threads will be alive simultaneously for this pool:
195    ///
196    /// ```
197    /// use std::thread;
198    ///
199    /// let pool = blocking_threadpool::Builder::new()
200    ///     .num_threads(8)
201    ///     .build();
202    ///
203    /// for _ in 0..100 {
204    ///     pool.execute(|| {
205    ///         println!("Hello from a worker thread!")
206    ///     })
207    /// }
208    /// ```
209    pub fn num_threads(mut self, num_threads: usize) -> Builder {
210        assert!(num_threads > 0);
211        self.num_threads = Some(num_threads);
212        self
213    }
214
215    /// Set the maximum number of pending jobs that can be queued to
216    /// the [`ThreadPool`]. Once the queue is full further calls will
217    /// block until slots become available. A `len` of 0 will always
218    /// block until a thread is available.  If not specified, defaults
219    /// to unlimited.
220    ///
221    /// [`ThreadPool`]: struct.ThreadPool.html
222    ///
223    /// # Panics
224    ///
225    /// This method will panic if `len` is less-than 0;
226    ///
227    /// # Examples
228    ///
229    /// With a single thread and a queue len of 1, the final execute
230    /// will have to wait until the first job finishes to be queued.
231    ///
232    /// ```
233    /// use std::thread;
234    /// use std::time::Duration;
235    ///
236    /// let pool = blocking_threadpool::Builder::new()
237    ///     .num_threads(1)
238    ///     .queue_len(1)
239    ///     .build();
240    ///
241    /// for _ in 0..2 {
242    ///     pool.execute(|| {
243    ///         println!("Hello from a worker thread! I'm going to rest now...");
244    ///         thread::sleep(Duration::from_secs(10));
245    ///         println!("All done!");
246    ///     })
247    /// }
248    ///
249    /// pool.execute(|| {
250    ///   println!("Hello from 10 seconds in the future!");
251    /// });
252    /// ```
253    pub fn queue_len(mut self, len: usize) -> Builder {
254        self.queue_len = Some(len);
255        self
256    }
257
258    /// Set the thread name for each of the threads spawned by the built [`ThreadPool`]. If not
259    /// specified, threads spawned by the thread pool will be unnamed.
260    ///
261    /// [`ThreadPool`]: struct.ThreadPool.html
262    ///
263    /// # Examples
264    ///
265    /// Each thread spawned by this pool will have the name "foo":
266    ///
267    /// ```
268    /// use std::thread;
269    ///
270    /// let pool = blocking_threadpool::Builder::new()
271    ///     .thread_name("foo".into())
272    ///     .build();
273    ///
274    /// for _ in 0..100 {
275    ///     pool.execute(|| {
276    ///         assert_eq!(thread::current().name(), Some("foo"));
277    ///     })
278    /// }
279    /// ```
280    pub fn thread_name(mut self, name: String) -> Builder {
281        self.thread_name = Some(name);
282        self
283    }
284
285    /// Set the stack size (in bytes) for each of the threads spawned by the built [`ThreadPool`].
286    /// If not specified, threads spawned by the threadpool will have a stack size [as specified in
287    /// the `std::thread` documentation][thread].
288    ///
289    /// [thread]: https://doc.rust-lang.org/nightly/std/thread/index.html#stack-size
290    /// [`ThreadPool`]: struct.ThreadPool.html
291    ///
292    /// # Examples
293    ///
294    /// Each thread spawned by this pool will have a 4 MB stack:
295    ///
296    /// ```
297    /// let pool = blocking_threadpool::Builder::new()
298    ///     .thread_stack_size(4_000_000)
299    ///     .build();
300    ///
301    /// for _ in 0..100 {
302    ///     pool.execute(|| {
303    ///         println!("This thread has a 4 MB stack size!");
304    ///     })
305    /// }
306    /// ```
307    pub fn thread_stack_size(mut self, size: usize) -> Builder {
308        self.thread_stack_size = Some(size);
309        self
310    }
311
312    /// Finalize the [`Builder`] and build the [`ThreadPool`].
313    ///
314    /// [`Builder`]: struct.Builder.html
315    /// [`ThreadPool`]: struct.ThreadPool.html
316    ///
317    /// # Examples
318    ///
319    /// ```
320    /// let pool = blocking_threadpool::Builder::new()
321    ///     .num_threads(8)
322    ///     .thread_stack_size(4_000_000)
323    ///     .build();
324    /// ```
325    pub fn build(self) -> ThreadPool {
326        let (tx, rx) = self.queue_len.map_or_else(
327            cbc::unbounded,
328            cbc::bounded
329        );
330
331        let num_threads = self.num_threads.unwrap_or_else(num_cpus::get);
332
333        let shared_data = Arc::new(ThreadPoolSharedData {
334            name: self.thread_name,
335            job_receiver: rx,
336            empty_condvar: Condvar::new(),
337            empty_trigger: Mutex::new(()),
338            join_generation: AtomicUsize::new(0),
339            queued_count: AtomicUsize::new(0),
340            active_count: AtomicUsize::new(0),
341            max_thread_count: AtomicUsize::new(num_threads),
342            panic_count: AtomicUsize::new(0),
343            stack_size: self.thread_stack_size,
344        });
345
346        // Threadpool threads
347        for _ in 0..num_threads {
348            spawn_in_pool(shared_data.clone());
349        }
350
351        ThreadPool {
352            jobs: tx,
353            shared_data,
354        }
355    }
356}
357
358struct ThreadPoolSharedData {
359    name: Option<String>,
360    job_receiver: cbc::Receiver<Thunk<'static>>,
361    empty_trigger: Mutex<()>,
362    empty_condvar: Condvar,
363    join_generation: AtomicUsize,
364    queued_count: AtomicUsize,
365    active_count: AtomicUsize,
366    max_thread_count: AtomicUsize,
367    panic_count: AtomicUsize,
368    stack_size: Option<usize>,
369}
370
371impl ThreadPoolSharedData {
372    fn has_work(&self) -> bool {
373        self.queued_count.load(Ordering::SeqCst) > 0 || self.active_count.load(Ordering::SeqCst) > 0
374    }
375
376    /// Notify all observers joining this pool if there is no more work to do.
377    fn no_work_notify_all(&self) {
378        if !self.has_work() {
379            let _lock = self.empty_trigger
380                .lock()
381                .expect("Unable to notify all joining threads");
382            self.empty_condvar.notify_all();
383        }
384    }
385}
386
387/// Abstraction of a thread pool for basic parallelism.
388pub struct ThreadPool {
389    // How the threadpool communicates with subthreads.
390    //
391    // This is the only such Sender, so when it is dropped all subthreads will
392    // quit.
393    jobs: cbc::Sender<Thunk<'static>>,
394    shared_data: Arc<ThreadPoolSharedData>,
395}
396
397impl ThreadPool {
398    /// Creates a new thread pool capable of executing `num_threads` number of jobs concurrently.
399    ///
400    /// # Panics
401    ///
402    /// This function will panic if `num_threads` is 0.
403    ///
404    /// # Examples
405    ///
406    /// Create a new thread pool capable of executing four jobs concurrently:
407    ///
408    /// ```
409    /// use blocking_threadpool::ThreadPool;
410    ///
411    /// let pool = ThreadPool::new(4);
412    /// ```
413    pub fn new(num_threads: usize) -> ThreadPool {
414        Builder::new().num_threads(num_threads).build()
415    }
416
417    /// Creates a new thread pool capable of executing `num_threads` number of jobs concurrently.
418    /// Each thread will have the [name][thread name] `name`.
419    ///
420    /// # Panics
421    ///
422    /// This function will panic if `num_threads` is 0.
423    ///
424    /// # Examples
425    ///
426    /// ```rust
427    /// use std::thread;
428    /// use blocking_threadpool::ThreadPool;
429    ///
430    /// let pool = ThreadPool::with_name("worker".into(), 2);
431    /// for _ in 0..2 {
432    ///     pool.execute(|| {
433    ///         assert_eq!(
434    ///             thread::current().name(),
435    ///             Some("worker")
436    ///         );
437    ///     });
438    /// }
439    /// pool.join();
440    /// ```
441    ///
442    /// [thread name]: https://doc.rust-lang.org/std/thread/struct.Thread.html#method.name
443    pub fn with_name(name: String, num_threads: usize) -> ThreadPool {
444        Builder::new()
445            .num_threads(num_threads)
446            .thread_name(name)
447            .build()
448    }
449
450    /// **Deprecated: Use [`ThreadPool::with_name`](#method.with_name)**
451    #[inline(always)]
452    #[deprecated(since = "1.4.0", note = "use blocking_threadpool::with_name")]
453    pub fn new_with_name(name: String, num_threads: usize) -> ThreadPool {
454        Self::with_name(name, num_threads)
455    }
456
457    /// Executes the function `job` on a thread in the pool.
458    ///
459    /// # Examples
460    ///
461    /// Execute four jobs on a thread pool that can run two jobs concurrently:
462    ///
463    /// ```
464    /// use blocking_threadpool::ThreadPool;
465    ///
466    /// let pool = ThreadPool::new(2);
467    /// pool.execute(|| println!("hello"));
468    /// pool.execute(|| println!("world"));
469    /// pool.execute(|| println!("foo"));
470    /// pool.execute(|| println!("bar"));
471    /// pool.join();
472    /// ```
473    pub fn execute<F>(&self, job: F)
474    where
475        F: FnOnce() + Send + 'static,
476    {
477        self.shared_data.queued_count.fetch_add(1, Ordering::SeqCst);
478        self.jobs
479            .send(Box::new(job))
480            .expect("ThreadPool::execute unable to send job into queue.");
481    }
482
483    /// Returns the number of jobs waiting to executed in the pool.
484    ///
485    /// # Examples
486    ///
487    /// ```
488    /// use blocking_threadpool::ThreadPool;
489    /// use std::time::Duration;
490    /// use std::thread::sleep;
491    ///
492    /// let pool = ThreadPool::new(2);
493    /// for _ in 0..10 {
494    ///     pool.execute(|| {
495    ///         sleep(Duration::from_secs(100));
496    ///     });
497    /// }
498    ///
499    /// sleep(Duration::from_secs(1)); // wait for threads to start
500    /// assert_eq!(8, pool.queued_count());
501    /// ```
502    pub fn queued_count(&self) -> usize {
503        self.shared_data.queued_count.load(Ordering::Relaxed)
504    }
505
506    /// Returns the number of currently active threads.
507    ///
508    /// # Examples
509    ///
510    /// ```
511    /// use blocking_threadpool::ThreadPool;
512    /// use std::time::Duration;
513    /// use std::thread::sleep;
514    ///
515    /// let pool = ThreadPool::new(4);
516    /// for _ in 0..10 {
517    ///     pool.execute(move || {
518    ///         sleep(Duration::from_secs(100));
519    ///     });
520    /// }
521    ///
522    /// sleep(Duration::from_secs(1)); // wait for threads to start
523    /// assert_eq!(4, pool.active_count());
524    /// ```
525    pub fn active_count(&self) -> usize {
526        self.shared_data.active_count.load(Ordering::SeqCst)
527    }
528
529    /// Returns the maximum number of threads the pool will execute concurrently.
530    ///
531    /// # Examples
532    ///
533    /// ```
534    /// use blocking_threadpool::ThreadPool;
535    ///
536    /// let mut pool = ThreadPool::new(4);
537    /// assert_eq!(4, pool.max_count());
538    ///
539    /// pool.set_num_threads(8);
540    /// assert_eq!(8, pool.max_count());
541    /// ```
542    pub fn max_count(&self) -> usize {
543        self.shared_data.max_thread_count.load(Ordering::Relaxed)
544    }
545
546    /// Returns the number of panicked threads over the lifetime of the pool.
547    ///
548    /// # Examples
549    ///
550    /// ```
551    /// use blocking_threadpool::ThreadPool;
552    ///
553    /// let pool = ThreadPool::new(4);
554    /// for n in 0..10 {
555    ///     pool.execute(move || {
556    ///         // simulate a panic
557    ///         if n % 2 == 0 {
558    ///             panic!()
559    ///         }
560    ///     });
561    /// }
562    /// pool.join();
563    ///
564    /// assert_eq!(5, pool.panic_count());
565    /// ```
566    pub fn panic_count(&self) -> usize {
567        self.shared_data.panic_count.load(Ordering::Relaxed)
568    }
569
570    /// **Deprecated: Use [`ThreadPool::set_num_threads`](#method.set_num_threads)**
571    #[deprecated(since = "1.3.0", note = "use blocking_threadpool::set_num_threads")]
572    pub fn set_threads(&mut self, num_threads: usize) {
573        self.set_num_threads(num_threads)
574    }
575
576    /// Sets the number of worker-threads to use as `num_threads`.
577    /// Can be used to change the threadpool size during runtime.
578    /// Will not abort already running or waiting threads.
579    ///
580    /// # Panics
581    ///
582    /// This function will panic if `num_threads` is 0.
583    ///
584    /// # Examples
585    ///
586    /// ```
587    /// use blocking_threadpool::ThreadPool;
588    /// use std::time::Duration;
589    /// use std::thread::sleep;
590    ///
591    /// let mut pool = ThreadPool::new(4);
592    /// for _ in 0..10 {
593    ///     pool.execute(move || {
594    ///         sleep(Duration::from_secs(100));
595    ///     });
596    /// }
597    ///
598    /// sleep(Duration::from_secs(1)); // wait for threads to start
599    /// assert_eq!(4, pool.active_count());
600    /// assert_eq!(6, pool.queued_count());
601    ///
602    /// // Increase thread capacity of the pool
603    /// pool.set_num_threads(8);
604    ///
605    /// sleep(Duration::from_secs(1)); // wait for new threads to start
606    /// assert_eq!(8, pool.active_count());
607    /// assert_eq!(2, pool.queued_count());
608    ///
609    /// // Decrease thread capacity of the pool
610    /// // No active threads are killed
611    /// pool.set_num_threads(4);
612    ///
613    /// assert_eq!(8, pool.active_count());
614    /// assert_eq!(2, pool.queued_count());
615    /// ```
616    pub fn set_num_threads(&mut self, num_threads: usize) {
617        assert!(num_threads >= 1);
618        let prev_num_threads = self
619            .shared_data
620            .max_thread_count
621            .swap(num_threads, Ordering::Release);
622        if let Some(num_spawn) = num_threads.checked_sub(prev_num_threads) {
623            // Spawn new threads
624            for _ in 0..num_spawn {
625                spawn_in_pool(self.shared_data.clone());
626            }
627        }
628    }
629
630    /// Block the current thread until all jobs in the pool have been executed.
631    ///
632    /// Calling `join` on an empty pool will cause an immediate return.
633    /// `join` may be called from multiple threads concurrently.
634    /// A `join` is an atomic point in time. All threads joining before the join
635    /// event will exit together even if the pool is processing new jobs by the
636    /// time they get scheduled.
637    ///
638    /// Calling `join` from a thread within the pool will cause a deadlock. This
639    /// behavior is considered safe.
640    ///
641    /// # Examples
642    ///
643    /// ```
644    /// use blocking_threadpool::ThreadPool;
645    /// use std::sync::Arc;
646    /// use std::sync::atomic::{AtomicUsize, Ordering};
647    ///
648    /// let pool = ThreadPool::new(8);
649    /// let test_count = Arc::new(AtomicUsize::new(0));
650    ///
651    /// for _ in 0..42 {
652    ///     let test_count = test_count.clone();
653    ///     pool.execute(move || {
654    ///         test_count.fetch_add(1, Ordering::Relaxed);
655    ///     });
656    /// }
657    ///
658    /// pool.join();
659    /// assert_eq!(42, test_count.load(Ordering::Relaxed));
660    /// ```
661    pub fn join(&self) {
662        // fast path requires no mutex
663        if !self.shared_data.has_work() {
664            return;
665        }
666
667        let generation = self.shared_data.join_generation.load(Ordering::SeqCst);
668        let mut lock = self.shared_data.empty_trigger.lock().unwrap();
669
670        while generation == self.shared_data.join_generation.load(Ordering::Relaxed)
671            && self.shared_data.has_work()
672        {
673            lock = self.shared_data.empty_condvar.wait(lock).unwrap();
674        }
675
676        // increase generation if we are the first thread to come out of the loop
677        let _ = self.shared_data.join_generation.compare_exchange(
678            generation,
679            generation.wrapping_add(1),
680            Ordering::SeqCst,
681            Ordering::SeqCst);
682    }
683}
684
685impl Clone for ThreadPool {
686    /// Cloning a pool will create a new handle to the pool.
687    /// The behavior is similar to [Arc](https://doc.rust-lang.org/stable/std/sync/struct.Arc.html).
688    ///
689    /// We could for example submit jobs from multiple threads concurrently.
690    ///
691    /// ```
692    /// use blocking_threadpool::ThreadPool;
693    /// use std::thread;
694    /// use std::sync::mpsc::channel;
695    ///
696    /// let pool = ThreadPool::with_name("clone example".into(), 2);
697    ///
698    /// let results = (0..2)
699    ///     .map(|i| {
700    ///         let pool = pool.clone();
701    ///         thread::spawn(move || {
702    ///             let (tx, rx) = channel();
703    ///             for i in 1..12 {
704    ///                 let tx = tx.clone();
705    ///                 pool.execute(move || {
706    ///                     tx.send(i).expect("channel will be waiting");
707    ///                 });
708    ///             }
709    ///             drop(tx);
710    ///             if i == 0 {
711    ///                 rx.iter().fold(0, |accumulator, element| accumulator + element)
712    ///             } else {
713    ///                 rx.iter().fold(1, |accumulator, element| accumulator * element)
714    ///             }
715    ///         })
716    ///     })
717    ///     .map(|join_handle| join_handle.join().expect("collect results from threads"))
718    ///     .collect::<Vec<usize>>();
719    ///
720    /// assert_eq!(vec![66, 39916800], results);
721    /// ```
722    fn clone(&self) -> ThreadPool {
723        ThreadPool {
724            jobs: self.jobs.clone(),
725            shared_data: self.shared_data.clone(),
726        }
727    }
728}
729
730/// Create a thread pool with one thread per CPU.
731/// On machines with hyperthreading,
732/// this will create one thread per hyperthread.
733impl Default for ThreadPool {
734    fn default() -> Self {
735        ThreadPool::new(num_cpus::get())
736    }
737}
738
739impl fmt::Debug for ThreadPool {
740    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
741        f.debug_struct("ThreadPool")
742            .field("name", &self.shared_data.name)
743            .field("queued_count", &self.queued_count())
744            .field("active_count", &self.active_count())
745            .field("max_count", &self.max_count())
746            .finish()
747    }
748}
749
750impl PartialEq for ThreadPool {
751    /// Check if you are working with the same pool
752    ///
753    /// ```
754    /// use blocking_threadpool::ThreadPool;
755    ///
756    /// let a = ThreadPool::new(2);
757    /// let b = ThreadPool::new(2);
758    ///
759    /// assert_eq!(a, a);
760    /// assert_eq!(b, b);
761    ///
762    /// # // TODO: change this to assert_ne in the future
763    /// assert!(a != b);
764    /// assert!(b != a);
765    /// ```
766    fn eq(&self, other: &ThreadPool) -> bool {
767        let a: &ThreadPoolSharedData = &self.shared_data;
768        let b: &ThreadPoolSharedData = &other.shared_data;
769        std::ptr::eq(a, b)
770        // with rust 1.17 and late:
771        // Arc::ptr_eq(&self.shared_data, &other.shared_data)
772    }
773}
774impl Eq for ThreadPool {}
775
776fn spawn_in_pool(shared_data: Arc<ThreadPoolSharedData>) {
777    let mut builder = thread::Builder::new();
778    if let Some(ref name) = shared_data.name {
779        builder = builder.name(name.clone());
780    }
781    if let Some(ref stack_size) = shared_data.stack_size {
782        builder = builder.stack_size(stack_size.to_owned());
783    }
784    builder
785        .spawn(move || {
786            // Will spawn a new thread on panic unless it is cancelled.
787            let sentinel = Sentinel::new(&shared_data);
788
789            loop {
790                // Shutdown this thread if the pool has become smaller
791                let thread_counter_val = shared_data.active_count.load(Ordering::Acquire);
792                let max_thread_count_val = shared_data.max_thread_count.load(Ordering::Relaxed);
793                if thread_counter_val >= max_thread_count_val {
794                    break;
795                }
796                let message = shared_data
797                    .job_receiver
798                    .recv();
799
800                let job = match message {
801                    Ok(job) => job,
802                    // The ThreadPool was dropped.
803                    Err(..) => break,
804                };
805                // Do not allow IR around the job execution
806                shared_data.active_count.fetch_add(1, Ordering::SeqCst);
807                shared_data.queued_count.fetch_sub(1, Ordering::SeqCst);
808
809                job.call_box();
810
811                shared_data.active_count.fetch_sub(1, Ordering::SeqCst);
812                shared_data.no_work_notify_all();
813            }
814
815            sentinel.cancel();
816        })
817        .unwrap();
818}
819
820#[cfg(test)]
821mod test {
822    use super::{Builder, ThreadPool};
823    use std::sync::atomic::{AtomicUsize, Ordering};
824    use std::sync::mpsc::{channel, sync_channel};
825    use std::sync::{Arc, Barrier, Mutex};
826    use std::thread::{self, sleep};
827    use std::time::Duration;
828
829    const TEST_TASKS: usize = 4;
830
831    #[test]
832    fn test_set_num_threads_increasing() {
833        let new_thread_amount = TEST_TASKS + 8;
834        let mut pool = ThreadPool::new(TEST_TASKS);
835        for _ in 0..TEST_TASKS {
836            pool.execute(move || sleep(Duration::from_secs(23)));
837        }
838        sleep(Duration::from_secs(1));
839        assert_eq!(pool.active_count(), TEST_TASKS);
840
841        pool.set_num_threads(new_thread_amount);
842
843        for _ in 0..(new_thread_amount - TEST_TASKS) {
844            pool.execute(move || sleep(Duration::from_secs(23)));
845        }
846        sleep(Duration::from_secs(1));
847        assert_eq!(pool.active_count(), new_thread_amount);
848
849        pool.join();
850    }
851
852    #[test]
853    fn test_set_num_threads_decreasing() {
854        let new_thread_amount = 2;
855        let mut pool = ThreadPool::new(TEST_TASKS);
856        for _ in 0..TEST_TASKS {
857            pool.execute(move || {
858                assert_eq!(1, 1);
859            });
860        }
861        pool.set_num_threads(new_thread_amount);
862        for _ in 0..new_thread_amount {
863            pool.execute(move || sleep(Duration::from_secs(23)));
864        }
865        sleep(Duration::from_secs(1));
866        assert_eq!(pool.active_count(), new_thread_amount);
867
868        pool.join();
869    }
870
871    #[test]
872    fn test_active_count() {
873        let pool = ThreadPool::new(TEST_TASKS);
874        for _ in 0..2 * TEST_TASKS {
875            pool.execute(move || loop {
876                sleep(Duration::from_secs(10))
877            });
878        }
879        sleep(Duration::from_secs(1));
880        let active_count = pool.active_count();
881        assert_eq!(active_count, TEST_TASKS);
882        let initialized_count = pool.max_count();
883        assert_eq!(initialized_count, TEST_TASKS);
884    }
885
886    #[test]
887    fn test_works() {
888        let pool = ThreadPool::new(TEST_TASKS);
889
890        let (tx, rx) = channel();
891        for _ in 0..TEST_TASKS {
892            let tx = tx.clone();
893            pool.execute(move || {
894                tx.send(1).unwrap();
895            });
896        }
897
898        assert_eq!(rx.iter().take(TEST_TASKS).sum::<usize>(), TEST_TASKS);
899    }
900
901    #[test]
902    #[should_panic]
903    fn test_zero_tasks_panic() {
904        ThreadPool::new(0);
905    }
906
907    #[test]
908    fn test_recovery_from_subtask_panic() {
909        let pool = ThreadPool::new(TEST_TASKS);
910
911        // Panic all the existing threads.
912        for _ in 0..TEST_TASKS {
913            pool.execute(move || panic!("Ignore this panic, it must!"));
914        }
915        pool.join();
916
917        assert_eq!(pool.panic_count(), TEST_TASKS);
918
919        // Ensure new threads were spawned to compensate.
920        let (tx, rx) = channel();
921        for _ in 0..TEST_TASKS {
922            let tx = tx.clone();
923            pool.execute(move || {
924                tx.send(1).unwrap();
925            });
926        }
927
928        assert_eq!(rx.iter().take(TEST_TASKS).sum::<usize>(), TEST_TASKS);
929    }
930
931    #[test]
932    fn test_should_not_panic_on_drop_if_subtasks_panic_after_drop() {
933        let pool = ThreadPool::new(TEST_TASKS);
934        let waiter = Arc::new(Barrier::new(TEST_TASKS + 1));
935
936        // Panic all the existing threads in a bit.
937        for _ in 0..TEST_TASKS {
938            let waiter = waiter.clone();
939            pool.execute(move || {
940                waiter.wait();
941                panic!("Ignore this panic, it should!");
942            });
943        }
944
945        drop(pool);
946
947        // Kick off the failure.
948        waiter.wait();
949    }
950
951    #[test]
952    fn test_massive_task_creation() {
953        let test_tasks = 4_200_000;
954
955        let pool = ThreadPool::new(TEST_TASKS);
956        let b0 = Arc::new(Barrier::new(TEST_TASKS + 1));
957        let b1 = Arc::new(Barrier::new(TEST_TASKS + 1));
958
959        let (tx, rx) = channel();
960
961        for i in 0..test_tasks {
962            let tx = tx.clone();
963            let (b0, b1) = (b0.clone(), b1.clone());
964
965            pool.execute(move || {
966                // Wait until the pool has been filled once.
967                if i < TEST_TASKS {
968                    b0.wait();
969                    // wait so the pool can be measured
970                    b1.wait();
971                }
972
973                tx.send(1).unwrap();
974            });
975        }
976
977        b0.wait();
978        assert_eq!(pool.active_count(), TEST_TASKS);
979        b1.wait();
980
981        assert_eq!(rx.iter().take(test_tasks).sum::<usize>(), test_tasks);
982        pool.join();
983
984        let atomic_active_count = pool.active_count();
985        assert!(
986            atomic_active_count == 0,
987            "atomic_active_count: {}",
988            atomic_active_count
989        );
990    }
991
992    #[test]
993    fn test_shrink() {
994        let test_tasks_begin = TEST_TASKS + 2;
995
996        let mut pool = ThreadPool::new(test_tasks_begin);
997        let b0 = Arc::new(Barrier::new(test_tasks_begin + 1));
998        let b1 = Arc::new(Barrier::new(test_tasks_begin + 1));
999
1000        for _ in 0..test_tasks_begin {
1001            let (b0, b1) = (b0.clone(), b1.clone());
1002            pool.execute(move || {
1003                b0.wait();
1004                b1.wait();
1005            });
1006        }
1007
1008        let b2 = Arc::new(Barrier::new(TEST_TASKS + 1));
1009        let b3 = Arc::new(Barrier::new(TEST_TASKS + 1));
1010
1011        for _ in 0..TEST_TASKS {
1012            let (b2, b3) = (b2.clone(), b3.clone());
1013            pool.execute(move || {
1014                b2.wait();
1015                b3.wait();
1016            });
1017        }
1018
1019        b0.wait();
1020        pool.set_num_threads(TEST_TASKS);
1021
1022        assert_eq!(pool.active_count(), test_tasks_begin);
1023        b1.wait();
1024
1025        b2.wait();
1026        assert_eq!(pool.active_count(), TEST_TASKS);
1027        b3.wait();
1028    }
1029
1030    #[test]
1031    fn test_name() {
1032        let name = "test";
1033        let mut pool = ThreadPool::with_name(name.to_owned(), 2);
1034        let (tx, rx) = sync_channel(0);
1035
1036        // initial thread should share the name "test"
1037        for _ in 0..2 {
1038            let tx = tx.clone();
1039            pool.execute(move || {
1040                let name = thread::current().name().unwrap().to_owned();
1041                tx.send(name).unwrap();
1042            });
1043        }
1044
1045        // new spawn thread should share the name "test" too.
1046        pool.set_num_threads(3);
1047        let tx_clone = tx.clone();
1048        pool.execute(move || {
1049            let name = thread::current().name().unwrap().to_owned();
1050            tx_clone.send(name).unwrap();
1051            panic!();
1052        });
1053
1054        // recover thread should share the name "test" too.
1055        pool.execute(move || {
1056            let name = thread::current().name().unwrap().to_owned();
1057            tx.send(name).unwrap();
1058        });
1059
1060        for thread_name in rx.iter().take(4) {
1061            assert_eq!(name, thread_name);
1062        }
1063    }
1064
1065    #[test]
1066    fn test_debug() {
1067        let pool = ThreadPool::new(4);
1068        let debug = format!("{:?}", pool);
1069        assert_eq!(
1070            debug,
1071            "ThreadPool { name: None, queued_count: 0, active_count: 0, max_count: 4 }"
1072        );
1073
1074        let pool = ThreadPool::with_name("hello".into(), 4);
1075        let debug = format!("{:?}", pool);
1076        assert_eq!(
1077            debug,
1078            "ThreadPool { name: Some(\"hello\"), queued_count: 0, active_count: 0, max_count: 4 }"
1079        );
1080
1081        let pool = ThreadPool::new(4);
1082        pool.execute(move || sleep(Duration::from_secs(5)));
1083        sleep(Duration::from_secs(1));
1084        let debug = format!("{:?}", pool);
1085        assert_eq!(
1086            debug,
1087            "ThreadPool { name: None, queued_count: 0, active_count: 1, max_count: 4 }"
1088        );
1089    }
1090
1091    #[test]
1092    fn test_repeate_join() {
1093        let pool = ThreadPool::with_name("repeate join test".into(), 8);
1094        let test_count = Arc::new(AtomicUsize::new(0));
1095
1096        for _ in 0..42 {
1097            let test_count = test_count.clone();
1098            pool.execute(move || {
1099                sleep(Duration::from_secs(2));
1100                test_count.fetch_add(1, Ordering::Release);
1101            });
1102        }
1103
1104        println!("{:?}", pool);
1105        pool.join();
1106        assert_eq!(42, test_count.load(Ordering::Acquire));
1107
1108        for _ in 0..42 {
1109            let test_count = test_count.clone();
1110            pool.execute(move || {
1111                sleep(Duration::from_secs(2));
1112                test_count.fetch_add(1, Ordering::Relaxed);
1113            });
1114        }
1115        pool.join();
1116        assert_eq!(84, test_count.load(Ordering::Relaxed));
1117    }
1118
1119    #[test]
1120    fn test_multi_join() {
1121        use std::sync::mpsc::TryRecvError::*;
1122
1123        // Toggle the following lines to debug the deadlock
1124        fn error(_s: String) {
1125            //use ::std::io::Write;
1126            //let stderr = ::std::io::stderr();
1127            //let mut stderr = stderr.lock();
1128            //stderr.write(&_s.as_bytes()).is_ok();
1129        }
1130
1131        let pool0 = ThreadPool::with_name("multi join pool0".into(), 4);
1132        let pool1 = ThreadPool::with_name("multi join pool1".into(), 4);
1133        let (tx, rx) = channel();
1134
1135        for i in 0..8 {
1136            let pool1 = pool1.clone();
1137            let pool0_ = pool0.clone();
1138            let tx = tx.clone();
1139            pool0.execute(move || {
1140                pool1.execute(move || {
1141                    error(format!("p1: {} -=- {:?}\n", i, pool0_));
1142                    pool0_.join();
1143                    error(format!("p1: send({})\n", i));
1144                    tx.send(i).expect("send i from pool1 -> main");
1145                });
1146                error(format!("p0: {}\n", i));
1147            });
1148        }
1149        drop(tx);
1150
1151        assert_eq!(rx.try_recv(), Err(Empty));
1152        error(format!("{:?}\n{:?}\n", pool0, pool1));
1153        pool0.join();
1154        error(format!("pool0.join() complete =-= {:?}", pool1));
1155        pool1.join();
1156        error("pool1.join() complete\n".into());
1157        assert_eq!(
1158            rx.iter().sum::<i32>(),
1159            1 + 2 + 3 + 4 + 5 + 6 + 7
1160        );
1161    }
1162
1163    #[test]
1164    fn test_empty_pool() {
1165        // Joining an empty pool must return imminently
1166        let pool = ThreadPool::new(4);
1167
1168        pool.join();
1169
1170        assert_eq!(0, pool.jobs.len());
1171    }
1172
1173    #[test]
1174    fn test_no_fun_or_joy() {
1175        // What happens when you keep adding jobs after a join
1176
1177        fn sleepy_function() {
1178            sleep(Duration::from_secs(6));
1179        }
1180
1181        let pool = ThreadPool::with_name("no fun or joy".into(), 8);
1182
1183        pool.execute(sleepy_function);
1184
1185        let p_t = pool.clone();
1186        thread::spawn(move || {
1187            (0..23).map(|_| p_t.execute(sleepy_function)).count();
1188        });
1189
1190        pool.join();
1191    }
1192
1193    #[test]
1194    fn test_clone() {
1195        let pool = ThreadPool::with_name("clone example".into(), 2);
1196
1197        // This batch of jobs will occupy the pool for some time
1198        for _ in 0..6 {
1199            pool.execute(move || {
1200                sleep(Duration::from_secs(2));
1201            });
1202        }
1203
1204        // The following jobs will be inserted into the pool in a random fashion
1205        let t0 = {
1206            let pool = pool.clone();
1207            thread::spawn(move || {
1208                // wait for the first batch of tasks to finish
1209                pool.join();
1210
1211                let (tx, rx) = channel();
1212                for i in 0..42 {
1213                    let tx = tx.clone();
1214                    pool.execute(move || {
1215                        tx.send(i).expect("channel will be waiting");
1216                    });
1217                }
1218                drop(tx);
1219                rx.iter()
1220                    .sum::<i32>()
1221            })
1222        };
1223        let t1 = {
1224            let pool = pool.clone();
1225            thread::spawn(move || {
1226                // wait for the first batch of tasks to finish
1227                pool.join();
1228
1229                let (tx, rx) = channel();
1230                for i in 1..12 {
1231                    let tx = tx.clone();
1232                    pool.execute(move || {
1233                        tx.send(i).expect("channel will be waiting");
1234                    });
1235                }
1236                drop(tx);
1237                rx.iter()
1238                    .product::<i32>()
1239            })
1240        };
1241
1242        assert_eq!(
1243            861,
1244            t0.join()
1245                .expect("thread 0 will return after calculating additions",)
1246        );
1247        assert_eq!(
1248            39916800,
1249            t1.join()
1250                .expect("thread 1 will return after calculating multiplications",)
1251        );
1252    }
1253
1254    #[test]
1255    fn test_sync_shared_data() {
1256        fn assert_sync<T: Sync>() {}
1257        assert_sync::<super::ThreadPoolSharedData>();
1258    }
1259
1260    #[test]
1261    fn test_send_shared_data() {
1262        fn assert_send<T: Send>() {}
1263        assert_send::<super::ThreadPoolSharedData>();
1264    }
1265
1266    #[test]
1267    fn test_send() {
1268        fn assert_send<T: Send>() {}
1269        assert_send::<ThreadPool>();
1270    }
1271
1272    #[test]
1273    fn test_cloned_eq() {
1274        let a = ThreadPool::new(2);
1275
1276        assert_eq!(a, a.clone());
1277    }
1278
1279    #[test]
1280    /// The scenario is joining threads should not be stuck once their wave
1281    /// of joins has completed. So once one thread joining on a pool has
1282    /// succeded other threads joining on the same pool must get out even if
1283    /// the thread is used for other jobs while the first group is finishing
1284    /// their join
1285    ///
1286    /// In this example this means the waiting threads will exit the join in
1287    /// groups of four because the waiter pool has four workers.
1288    fn test_join_wavesurfer() {
1289        let n_cycles = 4;
1290        let n_workers = 4;
1291        let (tx, rx) = channel();
1292        let builder = Builder::new()
1293            .num_threads(n_workers)
1294            .thread_name("join wavesurfer".into());
1295        let p_waiter = builder.clone().build();
1296        let p_clock = builder.build();
1297
1298        let barrier = Arc::new(Barrier::new(3));
1299        let wave_clock = Arc::new(AtomicUsize::new(0));
1300        let clock_thread = {
1301            let barrier = barrier.clone();
1302            let wave_clock = wave_clock.clone();
1303            thread::spawn(move || {
1304                barrier.wait();
1305                for wave_num in 0..n_cycles {
1306                    wave_clock.store(wave_num, Ordering::SeqCst);
1307                    sleep(Duration::from_secs(1));
1308                }
1309            })
1310        };
1311
1312        {
1313            let barrier = barrier.clone();
1314            p_clock.execute(move || {
1315                barrier.wait();
1316                // this sleep is for stabilisation on weaker platforms
1317                sleep(Duration::from_millis(100));
1318            });
1319        }
1320
1321        // prepare three waves of jobs
1322        for i in 0..3 * n_workers {
1323            let p_clock = p_clock.clone();
1324            let tx = tx.clone();
1325            let wave_clock = wave_clock.clone();
1326            p_waiter.execute(move || {
1327                let now = wave_clock.load(Ordering::SeqCst);
1328                p_clock.join();
1329                // submit jobs for the second wave
1330                p_clock.execute(|| sleep(Duration::from_secs(1)));
1331                let clock = wave_clock.load(Ordering::SeqCst);
1332                tx.send((now, clock, i)).unwrap();
1333            });
1334        }
1335        println!("all scheduled at {}", wave_clock.load(Ordering::SeqCst));
1336        barrier.wait();
1337
1338        p_clock.join();
1339        //p_waiter.join();
1340
1341        drop(tx);
1342        let mut hist = vec![0; n_cycles];
1343        let mut data = vec![];
1344        for (now, after, i) in rx.iter() {
1345            let mut dur = after - now;
1346            if dur >= n_cycles - 1 {
1347                dur = n_cycles - 1;
1348            }
1349            hist[dur] += 1;
1350
1351            data.push((now, after, i));
1352        }
1353        for (i, n) in hist.iter().enumerate() {
1354            println!(
1355                "\t{}: {} {}",
1356                i,
1357                n,
1358                &*(0..*n).fold("".to_owned(), |s, _| s + "*")
1359            );
1360        }
1361        assert!(data.iter().all(|&(cycle, stop, i)| if i < n_workers {
1362            cycle == stop
1363        } else {
1364            cycle < stop
1365        }));
1366
1367        clock_thread.join().unwrap();
1368    }
1369
1370    #[test]
1371    fn test_bounded_pool() {
1372        let pool = Builder::new()
1373            .num_threads(1)
1374            .queue_len(1)
1375            .build();
1376        let end = Arc::new(Barrier::new(2));
1377        let count = Arc::new(Mutex::new(0));
1378
1379        fn inc_wait(c: &Arc<Mutex<i64>>, val: i64, millis: i64) -> bool{
1380            for _ in 0..millis/10 {
1381                {
1382                    let l = c.lock().unwrap();
1383                    if *l == val {
1384                        return true;
1385                    }
1386                }
1387                sleep(Duration::from_millis(10));
1388            }
1389            false
1390        }
1391
1392        // Lock up the only thread
1393        let e1 = end.clone();
1394        let c1 = count.clone();
1395        pool.execute(move || {
1396            {
1397                let mut c = c1.lock().unwrap();
1398                *c += 1;
1399            }
1400            e1.wait();
1401        });
1402
1403        // Wait for it to be ready
1404        assert!(inc_wait(&count, 1, 1000));
1405        assert_eq!(pool.queued_count(), 0);
1406
1407        // Schedule 2nd job; sits on the queue
1408        let e2 = end.clone();
1409        let c2 = count.clone();
1410        pool.execute(move || {
1411            {
1412                let mut c = c2.lock().unwrap();
1413                *c += 1;
1414            }
1415            e2.wait();
1416        });
1417
1418        assert!(!inc_wait(&count, 2, 1000));
1419        assert_eq!(pool.queued_count(), 1);
1420
1421        // Third attempt should block
1422        let c3 = count.clone();
1423        thread::spawn(move || {
1424            pool.execute(move || {
1425            } );
1426            {
1427                let mut c = c3.lock().unwrap();
1428                *c += 1;
1429            }
1430        });
1431        assert!(!inc_wait(&count, 2, 1000));
1432
1433        end.wait();
1434    }
1435}