blocking_threadpool/lib.rs
1// Copyright 2014 The Rust Project Developers. See the COPYRIGHT
2// file at the top-level directory of this distribution and at
3// http://rust-lang.org/COPYRIGHT.
4//
5// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
6// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
7// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
8// option. This file may not be copied, modified, or distributed
9// except according to those terms.
10
11//! A thread pool used to execute functions in parallel.
12//!
13//! Spawns a specified number of worker threads and replenishes the pool if any worker threads
14//! panic.
15//!
16//! # Examples
17//!
18//! ## Synchronized with a channel
19//!
//! Every job sends one message over the channel, and the messages are then collected with `take()`.
21//!
22//! ```
23//! use blocking_threadpool::ThreadPool;
24//! use std::sync::mpsc::channel;
25//!
26//! let n_workers = 4;
27//! let n_jobs = 8;
28//! let pool = ThreadPool::new(n_workers);
29//!
30//! let (tx, rx) = channel();
31//! for _ in 0..n_jobs {
32//! let tx = tx.clone();
33//! pool.execute(move|| {
34//! tx.send(1).expect("channel will be there waiting for the pool");
35//! });
36//! }
37//!
38//! assert_eq!(rx.iter().take(n_jobs).fold(0, |a, b| a + b), 8);
39//! ```
40//!
41//! ## Synchronized with a barrier
42//!
43//! Keep in mind, if a barrier synchronizes more jobs than you have workers in the pool,
44//! you will end up with a [deadlock](https://en.wikipedia.org/wiki/Deadlock)
45//! at the barrier which is [not considered unsafe](
46//! https://doc.rust-lang.org/reference/behavior-not-considered-unsafe.html).
47//!
48//! ```
49//! use blocking_threadpool::ThreadPool;
50//! use std::sync::{Arc, Barrier};
51//! use std::sync::atomic::{AtomicUsize, Ordering};
52//!
53//! // create at least as many workers as jobs or you will deadlock yourself
54//! let n_workers = 42;
55//! let n_jobs = 23;
56//! let pool = ThreadPool::new(n_workers);
57//! let an_atomic = Arc::new(AtomicUsize::new(0));
58//!
59//! assert!(n_jobs <= n_workers, "too many jobs, will deadlock");
60//!
61//! // create a barrier that waits for all jobs plus the starter thread
62//! let barrier = Arc::new(Barrier::new(n_jobs + 1));
63//! for _ in 0..n_jobs {
64//! let barrier = barrier.clone();
65//! let an_atomic = an_atomic.clone();
66//!
67//! pool.execute(move|| {
68//! // do the heavy work
69//! an_atomic.fetch_add(1, Ordering::Relaxed);
70//!
71//! // then wait for the other threads
72//! barrier.wait();
73//! });
74//! }
75//!
76//! // wait for the threads to finish the work
77//! barrier.wait();
78//! assert_eq!(an_atomic.load(Ordering::SeqCst), /* n_jobs = */ 23);
79//! ```
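//!
//! ## Surviving worker panics
//!
//! A minimal sketch of the panic-recovery behavior described above: a panicking job only
//! kills the worker running it, the pool spawns a replacement, and the panics are tallied
//! in `panic_count()`.
//!
//! ```
//! use blocking_threadpool::ThreadPool;
//!
//! let pool = ThreadPool::new(2);
//! for i in 0..4 {
//!     pool.execute(move || {
//!         if i % 2 == 0 {
//!             panic!("this panic is expected; the pool recovers from it");
//!         }
//!     });
//! }
//! pool.join();
//! assert_eq!(pool.panic_count(), 2);
//! ```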
80
81use std::fmt;
82use std::sync::atomic::{AtomicUsize, Ordering};
83use std::sync::{Arc, Condvar, Mutex};
84use std::thread;
85
86use crossbeam_channel as cbc;
87
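// `FnBox` is a workaround for older Rust versions in which a boxed `FnOnce`
// closure could not be called directly (`Box<dyn FnOnce()>` only became
// callable in Rust 1.35). `call_box` consumes the box and invokes the closure;
// `Thunk` below is the boxed job type that gets sent over the channel to the
// worker threads.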
88trait FnBox {
89 fn call_box(self: Box<Self>);
90}
91
92impl<F: FnOnce()> FnBox for F {
93 fn call_box(self: Box<F>) {
94 (*self)()
95 }
96}
97
98type Thunk<'a> = Box<dyn FnBox + Send + 'a>;
99
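// `Sentinel` implements the "replenish the pool on panic" behavior from the
// module docs. Every worker owns one (see `spawn_in_pool`): if the worker
// unwinds while the sentinel is still active, `Drop` decrements the active
// count, bumps the panic count, wakes any joiners, and spawns a replacement
// worker. A clean shutdown calls `cancel()` first so no replacement is made.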
100struct Sentinel<'a> {
101 shared_data: &'a Arc<ThreadPoolSharedData>,
102 active: bool,
103}
104
105impl<'a> Sentinel<'a> {
106 fn new(shared_data: &'a Arc<ThreadPoolSharedData>) -> Sentinel<'a> {
107 Sentinel {
108 shared_data,
109 active: true,
110 }
111 }
112
113 /// Cancel and destroy this sentinel.
114 fn cancel(mut self) {
115 self.active = false;
116 }
117}
118
119impl<'a> Drop for Sentinel<'a> {
120 fn drop(&mut self) {
121 if self.active {
122 self.shared_data.active_count.fetch_sub(1, Ordering::SeqCst);
123 if thread::panicking() {
124 self.shared_data.panic_count.fetch_add(1, Ordering::SeqCst);
125 }
126 self.shared_data.no_work_notify_all();
127 spawn_in_pool(self.shared_data.clone())
128 }
129 }
130}
131
132/// [`ThreadPool`] factory, which can be used in order to configure the properties of the
133/// [`ThreadPool`].
134///
/// The four configuration options available:
///
/// * `num_threads`: maximum number of threads that will be alive at any given moment in the built
///   [`ThreadPool`]
/// * `thread_name`: thread name for each of the threads spawned by the built [`ThreadPool`]
/// * `thread_stack_size`: stack size (in bytes) for each of the threads spawned by the built
///   [`ThreadPool`]
/// * `queue_len`: maximum number of pending jobs that can be queued to the built [`ThreadPool`]
///   before `execute` blocks
142///
143/// [`ThreadPool`]: struct.ThreadPool.html
144///
145/// # Examples
146///
/// Build a [`ThreadPool`] that uses a maximum of eight threads simultaneously, each with
/// an 8 MB stack size:
149///
150/// ```
151/// let pool = blocking_threadpool::Builder::new()
152/// .num_threads(8)
153/// .thread_stack_size(8_000_000)
154/// .build();
155/// ```
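///
/// A fuller sketch using all four options (the concrete values here are arbitrary, pick
/// whatever fits your workload):
///
/// ```
/// let pool = blocking_threadpool::Builder::new()
///     .num_threads(4)
///     .thread_name("worker".into())
///     .thread_stack_size(4_000_000)
///     .queue_len(16)
///     .build();
/// ```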
156#[derive(Clone, Default)]
157pub struct Builder {
158 num_threads: Option<usize>,
159 thread_name: Option<String>,
160 thread_stack_size: Option<usize>,
161 queue_len: Option<usize>,
162}
163
164impl Builder {
165 /// Initiate a new [`Builder`].
166 ///
167 /// [`Builder`]: struct.Builder.html
168 ///
169 /// # Examples
170 ///
171 /// ```
172 /// let builder = blocking_threadpool::Builder::new();
173 /// ```
174 pub fn new() -> Builder {
175 Builder {
176 num_threads: None,
177 thread_name: None,
178 thread_stack_size: None,
179 queue_len: None,
180 }
181 }
182
    /// Set the maximum number of worker threads that will be alive at any given moment in the
    /// built [`ThreadPool`]. If not specified, the number of threads defaults to the number of CPUs.
185 ///
186 /// [`ThreadPool`]: struct.ThreadPool.html
187 ///
188 /// # Panics
189 ///
190 /// This method will panic if `num_threads` is 0.
191 ///
192 /// # Examples
193 ///
194 /// No more than eight threads will be alive simultaneously for this pool:
195 ///
196 /// ```
197 /// use std::thread;
198 ///
199 /// let pool = blocking_threadpool::Builder::new()
200 /// .num_threads(8)
201 /// .build();
202 ///
203 /// for _ in 0..100 {
204 /// pool.execute(|| {
205 /// println!("Hello from a worker thread!")
206 /// })
207 /// }
208 /// ```
209 pub fn num_threads(mut self, num_threads: usize) -> Builder {
210 assert!(num_threads > 0);
211 self.num_threads = Some(num_threads);
212 self
213 }
214
    /// Set the maximum number of pending jobs that can be queued to
    /// the [`ThreadPool`]. Once the queue is full, further calls to
    /// `execute` will block until slots become available. A `len` of 0
    /// makes every call block until a worker thread is free. If not
    /// specified, the queue is unbounded.
    ///
    /// [`ThreadPool`]: struct.ThreadPool.html
    ///
227 /// # Examples
228 ///
    /// With a single thread and a queue length of 1, the final `execute`
    /// has to wait until the first job finishes before it can be queued.
231 ///
232 /// ```
233 /// use std::thread;
234 /// use std::time::Duration;
235 ///
236 /// let pool = blocking_threadpool::Builder::new()
237 /// .num_threads(1)
238 /// .queue_len(1)
239 /// .build();
240 ///
241 /// for _ in 0..2 {
242 /// pool.execute(|| {
243 /// println!("Hello from a worker thread! I'm going to rest now...");
244 /// thread::sleep(Duration::from_secs(10));
245 /// println!("All done!");
246 /// })
247 /// }
248 ///
249 /// pool.execute(|| {
250 /// println!("Hello from 10 seconds in the future!");
251 /// });
252 /// ```
253 pub fn queue_len(mut self, len: usize) -> Builder {
254 self.queue_len = Some(len);
255 self
256 }
257
258 /// Set the thread name for each of the threads spawned by the built [`ThreadPool`]. If not
259 /// specified, threads spawned by the thread pool will be unnamed.
260 ///
261 /// [`ThreadPool`]: struct.ThreadPool.html
262 ///
263 /// # Examples
264 ///
265 /// Each thread spawned by this pool will have the name "foo":
266 ///
267 /// ```
268 /// use std::thread;
269 ///
270 /// let pool = blocking_threadpool::Builder::new()
271 /// .thread_name("foo".into())
272 /// .build();
273 ///
274 /// for _ in 0..100 {
275 /// pool.execute(|| {
276 /// assert_eq!(thread::current().name(), Some("foo"));
277 /// })
278 /// }
279 /// ```
280 pub fn thread_name(mut self, name: String) -> Builder {
281 self.thread_name = Some(name);
282 self
283 }
284
285 /// Set the stack size (in bytes) for each of the threads spawned by the built [`ThreadPool`].
286 /// If not specified, threads spawned by the threadpool will have a stack size [as specified in
287 /// the `std::thread` documentation][thread].
288 ///
289 /// [thread]: https://doc.rust-lang.org/nightly/std/thread/index.html#stack-size
290 /// [`ThreadPool`]: struct.ThreadPool.html
291 ///
292 /// # Examples
293 ///
294 /// Each thread spawned by this pool will have a 4 MB stack:
295 ///
296 /// ```
297 /// let pool = blocking_threadpool::Builder::new()
298 /// .thread_stack_size(4_000_000)
299 /// .build();
300 ///
301 /// for _ in 0..100 {
302 /// pool.execute(|| {
303 /// println!("This thread has a 4 MB stack size!");
304 /// })
305 /// }
306 /// ```
307 pub fn thread_stack_size(mut self, size: usize) -> Builder {
308 self.thread_stack_size = Some(size);
309 self
310 }
311
312 /// Finalize the [`Builder`] and build the [`ThreadPool`].
313 ///
314 /// [`Builder`]: struct.Builder.html
315 /// [`ThreadPool`]: struct.ThreadPool.html
316 ///
317 /// # Examples
318 ///
319 /// ```
320 /// let pool = blocking_threadpool::Builder::new()
321 /// .num_threads(8)
322 /// .thread_stack_size(4_000_000)
323 /// .build();
324 /// ```
325 pub fn build(self) -> ThreadPool {
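        // Job queue: `queue_len == None` gives an unbounded crossbeam channel, so
        // `execute` never blocks on a full queue; `queue_len == Some(n)` gives a
        // bounded channel of capacity `n`, and `n == 0` is a rendezvous channel
        // where every submitted job waits for a free worker.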
326 let (tx, rx) = self.queue_len.map_or_else(
327 cbc::unbounded,
328 cbc::bounded
329 );
330
331 let num_threads = self.num_threads.unwrap_or_else(num_cpus::get);
332
333 let shared_data = Arc::new(ThreadPoolSharedData {
334 name: self.thread_name,
335 job_receiver: rx,
336 empty_condvar: Condvar::new(),
337 empty_trigger: Mutex::new(()),
338 join_generation: AtomicUsize::new(0),
339 queued_count: AtomicUsize::new(0),
340 active_count: AtomicUsize::new(0),
341 max_thread_count: AtomicUsize::new(num_threads),
342 panic_count: AtomicUsize::new(0),
343 stack_size: self.thread_stack_size,
344 });
345
346 // Threadpool threads
347 for _ in 0..num_threads {
348 spawn_in_pool(shared_data.clone());
349 }
350
351 ThreadPool {
352 jobs: tx,
353 shared_data,
354 }
355 }
356}
357
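// State shared between all `ThreadPool` handles and all worker threads.
// The counters are plain atomics; `empty_trigger` and `empty_condvar` exist so
// `join()` can sleep until the pool drains, and `join_generation` groups
// joiners into waves (see `ThreadPool::join`).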
358struct ThreadPoolSharedData {
359 name: Option<String>,
360 job_receiver: cbc::Receiver<Thunk<'static>>,
361 empty_trigger: Mutex<()>,
362 empty_condvar: Condvar,
363 join_generation: AtomicUsize,
364 queued_count: AtomicUsize,
365 active_count: AtomicUsize,
366 max_thread_count: AtomicUsize,
367 panic_count: AtomicUsize,
368 stack_size: Option<usize>,
369}
370
371impl ThreadPoolSharedData {
372 fn has_work(&self) -> bool {
373 self.queued_count.load(Ordering::SeqCst) > 0 || self.active_count.load(Ordering::SeqCst) > 0
374 }
375
376 /// Notify all observers joining this pool if there is no more work to do.
377 fn no_work_notify_all(&self) {
378 if !self.has_work() {
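            // Take the `empty_trigger` lock before notifying. This pairs with the
            // lock held by `join()` around its `has_work()` re-check, so a joiner
            // cannot check for work, miss this notification, and only then start
            // waiting on the condvar.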
379 let _lock = self.empty_trigger
380 .lock()
381 .expect("Unable to notify all joining threads");
382 self.empty_condvar.notify_all();
383 }
384 }
385}
386
387/// Abstraction of a thread pool for basic parallelism.
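///
/// A minimal sketch of the basic flow: submit closures with `execute` and wait for all of
/// them with `join`.
///
/// ```
/// use blocking_threadpool::ThreadPool;
///
/// let pool = ThreadPool::new(4);
/// pool.execute(|| println!("hello from the pool"));
/// pool.join();
/// ```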
388pub struct ThreadPool {
389 // How the threadpool communicates with subthreads.
390 //
391 // This is the only such Sender, so when it is dropped all subthreads will
392 // quit.
393 jobs: cbc::Sender<Thunk<'static>>,
394 shared_data: Arc<ThreadPoolSharedData>,
395}
396
397impl ThreadPool {
    /// Creates a new thread pool capable of executing `num_threads` jobs concurrently.
399 ///
400 /// # Panics
401 ///
402 /// This function will panic if `num_threads` is 0.
403 ///
404 /// # Examples
405 ///
406 /// Create a new thread pool capable of executing four jobs concurrently:
407 ///
408 /// ```
409 /// use blocking_threadpool::ThreadPool;
410 ///
411 /// let pool = ThreadPool::new(4);
412 /// ```
413 pub fn new(num_threads: usize) -> ThreadPool {
414 Builder::new().num_threads(num_threads).build()
415 }
416
    /// Creates a new thread pool capable of executing `num_threads` jobs concurrently.
418 /// Each thread will have the [name][thread name] `name`.
419 ///
420 /// # Panics
421 ///
422 /// This function will panic if `num_threads` is 0.
423 ///
424 /// # Examples
425 ///
426 /// ```rust
427 /// use std::thread;
428 /// use blocking_threadpool::ThreadPool;
429 ///
430 /// let pool = ThreadPool::with_name("worker".into(), 2);
431 /// for _ in 0..2 {
432 /// pool.execute(|| {
433 /// assert_eq!(
434 /// thread::current().name(),
435 /// Some("worker")
436 /// );
437 /// });
438 /// }
439 /// pool.join();
440 /// ```
441 ///
442 /// [thread name]: https://doc.rust-lang.org/std/thread/struct.Thread.html#method.name
443 pub fn with_name(name: String, num_threads: usize) -> ThreadPool {
444 Builder::new()
445 .num_threads(num_threads)
446 .thread_name(name)
447 .build()
448 }
449
450 /// **Deprecated: Use [`ThreadPool::with_name`](#method.with_name)**
451 #[inline(always)]
    #[deprecated(since = "1.4.0", note = "use ThreadPool::with_name")]
453 pub fn new_with_name(name: String, num_threads: usize) -> ThreadPool {
454 Self::with_name(name, num_threads)
455 }
456
457 /// Executes the function `job` on a thread in the pool.
458 ///
459 /// # Examples
460 ///
461 /// Execute four jobs on a thread pool that can run two jobs concurrently:
462 ///
463 /// ```
464 /// use blocking_threadpool::ThreadPool;
465 ///
466 /// let pool = ThreadPool::new(2);
467 /// pool.execute(|| println!("hello"));
468 /// pool.execute(|| println!("world"));
469 /// pool.execute(|| println!("foo"));
470 /// pool.execute(|| println!("bar"));
471 /// pool.join();
472 /// ```
473 pub fn execute<F>(&self, job: F)
474 where
475 F: FnOnce() + Send + 'static,
476 {
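        // Bump the queued counter before the (possibly blocking) send, so a job
        // waiting for a slot on a bounded queue already shows up in `queued_count`.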
477 self.shared_data.queued_count.fetch_add(1, Ordering::SeqCst);
478 self.jobs
479 .send(Box::new(job))
480 .expect("ThreadPool::execute unable to send job into queue.");
481 }
482
    /// Returns the number of jobs waiting to be executed in the pool.
484 ///
485 /// # Examples
486 ///
487 /// ```
488 /// use blocking_threadpool::ThreadPool;
489 /// use std::time::Duration;
490 /// use std::thread::sleep;
491 ///
492 /// let pool = ThreadPool::new(2);
493 /// for _ in 0..10 {
494 /// pool.execute(|| {
495 /// sleep(Duration::from_secs(100));
496 /// });
497 /// }
498 ///
499 /// sleep(Duration::from_secs(1)); // wait for threads to start
500 /// assert_eq!(8, pool.queued_count());
501 /// ```
502 pub fn queued_count(&self) -> usize {
503 self.shared_data.queued_count.load(Ordering::Relaxed)
504 }
505
506 /// Returns the number of currently active threads.
507 ///
508 /// # Examples
509 ///
510 /// ```
511 /// use blocking_threadpool::ThreadPool;
512 /// use std::time::Duration;
513 /// use std::thread::sleep;
514 ///
515 /// let pool = ThreadPool::new(4);
516 /// for _ in 0..10 {
517 /// pool.execute(move || {
518 /// sleep(Duration::from_secs(100));
519 /// });
520 /// }
521 ///
522 /// sleep(Duration::from_secs(1)); // wait for threads to start
523 /// assert_eq!(4, pool.active_count());
524 /// ```
525 pub fn active_count(&self) -> usize {
526 self.shared_data.active_count.load(Ordering::SeqCst)
527 }
528
529 /// Returns the maximum number of threads the pool will execute concurrently.
530 ///
531 /// # Examples
532 ///
533 /// ```
534 /// use blocking_threadpool::ThreadPool;
535 ///
536 /// let mut pool = ThreadPool::new(4);
537 /// assert_eq!(4, pool.max_count());
538 ///
539 /// pool.set_num_threads(8);
540 /// assert_eq!(8, pool.max_count());
541 /// ```
542 pub fn max_count(&self) -> usize {
543 self.shared_data.max_thread_count.load(Ordering::Relaxed)
544 }
545
546 /// Returns the number of panicked threads over the lifetime of the pool.
547 ///
548 /// # Examples
549 ///
550 /// ```
551 /// use blocking_threadpool::ThreadPool;
552 ///
553 /// let pool = ThreadPool::new(4);
554 /// for n in 0..10 {
555 /// pool.execute(move || {
556 /// // simulate a panic
557 /// if n % 2 == 0 {
558 /// panic!()
559 /// }
560 /// });
561 /// }
562 /// pool.join();
563 ///
564 /// assert_eq!(5, pool.panic_count());
565 /// ```
566 pub fn panic_count(&self) -> usize {
567 self.shared_data.panic_count.load(Ordering::Relaxed)
568 }
569
570 /// **Deprecated: Use [`ThreadPool::set_num_threads`](#method.set_num_threads)**
    #[deprecated(since = "1.3.0", note = "use ThreadPool::set_num_threads")]
572 pub fn set_threads(&mut self, num_threads: usize) {
573 self.set_num_threads(num_threads)
574 }
575
    /// Sets the number of worker threads to `num_threads`.
    /// Can be used to change the pool size at runtime.
    /// Will not abort already running or waiting threads.
579 ///
580 /// # Panics
581 ///
582 /// This function will panic if `num_threads` is 0.
583 ///
584 /// # Examples
585 ///
586 /// ```
587 /// use blocking_threadpool::ThreadPool;
588 /// use std::time::Duration;
589 /// use std::thread::sleep;
590 ///
591 /// let mut pool = ThreadPool::new(4);
592 /// for _ in 0..10 {
593 /// pool.execute(move || {
594 /// sleep(Duration::from_secs(100));
595 /// });
596 /// }
597 ///
598 /// sleep(Duration::from_secs(1)); // wait for threads to start
599 /// assert_eq!(4, pool.active_count());
600 /// assert_eq!(6, pool.queued_count());
601 ///
602 /// // Increase thread capacity of the pool
603 /// pool.set_num_threads(8);
604 ///
605 /// sleep(Duration::from_secs(1)); // wait for new threads to start
606 /// assert_eq!(8, pool.active_count());
607 /// assert_eq!(2, pool.queued_count());
608 ///
609 /// // Decrease thread capacity of the pool
610 /// // No active threads are killed
611 /// pool.set_num_threads(4);
612 ///
613 /// assert_eq!(8, pool.active_count());
614 /// assert_eq!(2, pool.queued_count());
615 /// ```
616 pub fn set_num_threads(&mut self, num_threads: usize) {
617 assert!(num_threads >= 1);
618 let prev_num_threads = self
619 .shared_data
620 .max_thread_count
621 .swap(num_threads, Ordering::Release);
622 if let Some(num_spawn) = num_threads.checked_sub(prev_num_threads) {
623 // Spawn new threads
624 for _ in 0..num_spawn {
625 spawn_in_pool(self.shared_data.clone());
626 }
627 }
628 }
629
630 /// Block the current thread until all jobs in the pool have been executed.
631 ///
632 /// Calling `join` on an empty pool will cause an immediate return.
633 /// `join` may be called from multiple threads concurrently.
634 /// A `join` is an atomic point in time. All threads joining before the join
635 /// event will exit together even if the pool is processing new jobs by the
636 /// time they get scheduled.
637 ///
638 /// Calling `join` from a thread within the pool will cause a deadlock. This
639 /// behavior is considered safe.
640 ///
641 /// # Examples
642 ///
643 /// ```
644 /// use blocking_threadpool::ThreadPool;
645 /// use std::sync::Arc;
646 /// use std::sync::atomic::{AtomicUsize, Ordering};
647 ///
648 /// let pool = ThreadPool::new(8);
649 /// let test_count = Arc::new(AtomicUsize::new(0));
650 ///
651 /// for _ in 0..42 {
652 /// let test_count = test_count.clone();
653 /// pool.execute(move || {
654 /// test_count.fetch_add(1, Ordering::Relaxed);
655 /// });
656 /// }
657 ///
658 /// pool.join();
659 /// assert_eq!(42, test_count.load(Ordering::Relaxed));
660 /// ```
661 pub fn join(&self) {
662 // fast path requires no mutex
663 if !self.shared_data.has_work() {
664 return;
665 }
666
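        // Each joiner records the generation it entered with. Workers notify the
        // condvar whenever the pool becomes idle; the first joiner to observe the
        // idle pool bumps the generation, which releases every other joiner from
        // the same wave even if new jobs have been submitted in the meantime.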
667 let generation = self.shared_data.join_generation.load(Ordering::SeqCst);
668 let mut lock = self.shared_data.empty_trigger.lock().unwrap();
669
670 while generation == self.shared_data.join_generation.load(Ordering::Relaxed)
671 && self.shared_data.has_work()
672 {
673 lock = self.shared_data.empty_condvar.wait(lock).unwrap();
674 }
675
676 // increase generation if we are the first thread to come out of the loop
677 let _ = self.shared_data.join_generation.compare_exchange(
678 generation,
679 generation.wrapping_add(1),
680 Ordering::SeqCst,
681 Ordering::SeqCst);
682 }
683}
684
685impl Clone for ThreadPool {
686 /// Cloning a pool will create a new handle to the pool.
687 /// The behavior is similar to [Arc](https://doc.rust-lang.org/stable/std/sync/struct.Arc.html).
688 ///
689 /// We could for example submit jobs from multiple threads concurrently.
690 ///
691 /// ```
692 /// use blocking_threadpool::ThreadPool;
693 /// use std::thread;
694 /// use std::sync::mpsc::channel;
695 ///
696 /// let pool = ThreadPool::with_name("clone example".into(), 2);
697 ///
698 /// let results = (0..2)
699 /// .map(|i| {
700 /// let pool = pool.clone();
701 /// thread::spawn(move || {
702 /// let (tx, rx) = channel();
703 /// for i in 1..12 {
704 /// let tx = tx.clone();
705 /// pool.execute(move || {
706 /// tx.send(i).expect("channel will be waiting");
707 /// });
708 /// }
709 /// drop(tx);
710 /// if i == 0 {
711 /// rx.iter().fold(0, |accumulator, element| accumulator + element)
712 /// } else {
713 /// rx.iter().fold(1, |accumulator, element| accumulator * element)
714 /// }
715 /// })
716 /// })
717 /// .map(|join_handle| join_handle.join().expect("collect results from threads"))
718 /// .collect::<Vec<usize>>();
719 ///
720 /// assert_eq!(vec![66, 39916800], results);
721 /// ```
722 fn clone(&self) -> ThreadPool {
723 ThreadPool {
724 jobs: self.jobs.clone(),
725 shared_data: self.shared_data.clone(),
726 }
727 }
728}
729
730/// Create a thread pool with one thread per CPU.
731/// On machines with hyperthreading,
732/// this will create one thread per hyperthread.
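///
/// A minimal sketch (the exact thread count depends on the machine):
///
/// ```
/// use blocking_threadpool::ThreadPool;
///
/// let pool = ThreadPool::default();
/// assert!(pool.max_count() >= 1);
/// pool.execute(|| println!("hello from the default pool"));
/// pool.join();
/// ```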
733impl Default for ThreadPool {
734 fn default() -> Self {
735 ThreadPool::new(num_cpus::get())
736 }
737}
738
739impl fmt::Debug for ThreadPool {
740 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
741 f.debug_struct("ThreadPool")
742 .field("name", &self.shared_data.name)
743 .field("queued_count", &self.queued_count())
744 .field("active_count", &self.active_count())
745 .field("max_count", &self.max_count())
746 .finish()
747 }
748}
749
750impl PartialEq for ThreadPool {
751 /// Check if you are working with the same pool
752 ///
753 /// ```
754 /// use blocking_threadpool::ThreadPool;
755 ///
756 /// let a = ThreadPool::new(2);
757 /// let b = ThreadPool::new(2);
758 ///
759 /// assert_eq!(a, a);
760 /// assert_eq!(b, b);
761 ///
762 /// # // TODO: change this to assert_ne in the future
763 /// assert!(a != b);
764 /// assert!(b != a);
765 /// ```
766 fn eq(&self, other: &ThreadPool) -> bool {
767 let a: &ThreadPoolSharedData = &self.shared_data;
768 let b: &ThreadPoolSharedData = &other.shared_data;
769 std::ptr::eq(a, b)
        // with Rust 1.17 and later:
771 // Arc::ptr_eq(&self.shared_data, &other.shared_data)
772 }
773}
774impl Eq for ThreadPool {}
775
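// Spawn one worker thread. The worker pulls jobs from `shared_data.job_receiver`
// until either the job channel is closed (the last `ThreadPool` handle was
// dropped) or the pool has been shrunk, i.e. the number of currently active
// workers already reaches `max_thread_count`. A panic inside a job unwinds the
// thread and the `Sentinel` spawns a replacement.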
776fn spawn_in_pool(shared_data: Arc<ThreadPoolSharedData>) {
777 let mut builder = thread::Builder::new();
778 if let Some(ref name) = shared_data.name {
779 builder = builder.name(name.clone());
780 }
781 if let Some(ref stack_size) = shared_data.stack_size {
782 builder = builder.stack_size(stack_size.to_owned());
783 }
784 builder
785 .spawn(move || {
786 // Will spawn a new thread on panic unless it is cancelled.
787 let sentinel = Sentinel::new(&shared_data);
788
789 loop {
            // Shut down this thread if the pool has shrunk below the current thread count
791 let thread_counter_val = shared_data.active_count.load(Ordering::Acquire);
792 let max_thread_count_val = shared_data.max_thread_count.load(Ordering::Relaxed);
793 if thread_counter_val >= max_thread_count_val {
794 break;
795 }
796 let message = shared_data
797 .job_receiver
798 .recv();
799
800 let job = match message {
801 Ok(job) => job,
802 // The ThreadPool was dropped.
803 Err(..) => break,
804 };
            // Do not allow instruction reordering around the job execution
806 shared_data.active_count.fetch_add(1, Ordering::SeqCst);
807 shared_data.queued_count.fetch_sub(1, Ordering::SeqCst);
808
809 job.call_box();
810
811 shared_data.active_count.fetch_sub(1, Ordering::SeqCst);
812 shared_data.no_work_notify_all();
813 }
814
815 sentinel.cancel();
816 })
817 .unwrap();
818}
819
820#[cfg(test)]
821mod test {
822 use super::{Builder, ThreadPool};
823 use std::sync::atomic::{AtomicUsize, Ordering};
824 use std::sync::mpsc::{channel, sync_channel};
825 use std::sync::{Arc, Barrier, Mutex};
826 use std::thread::{self, sleep};
827 use std::time::Duration;
828
829 const TEST_TASKS: usize = 4;
830
831 #[test]
832 fn test_set_num_threads_increasing() {
833 let new_thread_amount = TEST_TASKS + 8;
834 let mut pool = ThreadPool::new(TEST_TASKS);
835 for _ in 0..TEST_TASKS {
836 pool.execute(move || sleep(Duration::from_secs(23)));
837 }
838 sleep(Duration::from_secs(1));
839 assert_eq!(pool.active_count(), TEST_TASKS);
840
841 pool.set_num_threads(new_thread_amount);
842
843 for _ in 0..(new_thread_amount - TEST_TASKS) {
844 pool.execute(move || sleep(Duration::from_secs(23)));
845 }
846 sleep(Duration::from_secs(1));
847 assert_eq!(pool.active_count(), new_thread_amount);
848
849 pool.join();
850 }
851
852 #[test]
853 fn test_set_num_threads_decreasing() {
854 let new_thread_amount = 2;
855 let mut pool = ThreadPool::new(TEST_TASKS);
856 for _ in 0..TEST_TASKS {
857 pool.execute(move || {
858 assert_eq!(1, 1);
859 });
860 }
861 pool.set_num_threads(new_thread_amount);
862 for _ in 0..new_thread_amount {
863 pool.execute(move || sleep(Duration::from_secs(23)));
864 }
865 sleep(Duration::from_secs(1));
866 assert_eq!(pool.active_count(), new_thread_amount);
867
868 pool.join();
869 }
870
871 #[test]
872 fn test_active_count() {
873 let pool = ThreadPool::new(TEST_TASKS);
874 for _ in 0..2 * TEST_TASKS {
875 pool.execute(move || loop {
876 sleep(Duration::from_secs(10))
877 });
878 }
879 sleep(Duration::from_secs(1));
880 let active_count = pool.active_count();
881 assert_eq!(active_count, TEST_TASKS);
882 let initialized_count = pool.max_count();
883 assert_eq!(initialized_count, TEST_TASKS);
884 }
885
886 #[test]
887 fn test_works() {
888 let pool = ThreadPool::new(TEST_TASKS);
889
890 let (tx, rx) = channel();
891 for _ in 0..TEST_TASKS {
892 let tx = tx.clone();
893 pool.execute(move || {
894 tx.send(1).unwrap();
895 });
896 }
897
898 assert_eq!(rx.iter().take(TEST_TASKS).sum::<usize>(), TEST_TASKS);
899 }
900
901 #[test]
902 #[should_panic]
903 fn test_zero_tasks_panic() {
904 ThreadPool::new(0);
905 }
906
907 #[test]
908 fn test_recovery_from_subtask_panic() {
909 let pool = ThreadPool::new(TEST_TASKS);
910
911 // Panic all the existing threads.
912 for _ in 0..TEST_TASKS {
913 pool.execute(move || panic!("Ignore this panic, it must!"));
914 }
915 pool.join();
916
917 assert_eq!(pool.panic_count(), TEST_TASKS);
918
919 // Ensure new threads were spawned to compensate.
920 let (tx, rx) = channel();
921 for _ in 0..TEST_TASKS {
922 let tx = tx.clone();
923 pool.execute(move || {
924 tx.send(1).unwrap();
925 });
926 }
927
928 assert_eq!(rx.iter().take(TEST_TASKS).sum::<usize>(), TEST_TASKS);
929 }
930
931 #[test]
932 fn test_should_not_panic_on_drop_if_subtasks_panic_after_drop() {
933 let pool = ThreadPool::new(TEST_TASKS);
934 let waiter = Arc::new(Barrier::new(TEST_TASKS + 1));
935
936 // Panic all the existing threads in a bit.
937 for _ in 0..TEST_TASKS {
938 let waiter = waiter.clone();
939 pool.execute(move || {
940 waiter.wait();
941 panic!("Ignore this panic, it should!");
942 });
943 }
944
945 drop(pool);
946
947 // Kick off the failure.
948 waiter.wait();
949 }
950
951 #[test]
952 fn test_massive_task_creation() {
953 let test_tasks = 4_200_000;
954
955 let pool = ThreadPool::new(TEST_TASKS);
956 let b0 = Arc::new(Barrier::new(TEST_TASKS + 1));
957 let b1 = Arc::new(Barrier::new(TEST_TASKS + 1));
958
959 let (tx, rx) = channel();
960
961 for i in 0..test_tasks {
962 let tx = tx.clone();
963 let (b0, b1) = (b0.clone(), b1.clone());
964
965 pool.execute(move || {
966 // Wait until the pool has been filled once.
967 if i < TEST_TASKS {
968 b0.wait();
969 // wait so the pool can be measured
970 b1.wait();
971 }
972
973 tx.send(1).unwrap();
974 });
975 }
976
977 b0.wait();
978 assert_eq!(pool.active_count(), TEST_TASKS);
979 b1.wait();
980
981 assert_eq!(rx.iter().take(test_tasks).sum::<usize>(), test_tasks);
982 pool.join();
983
984 let atomic_active_count = pool.active_count();
985 assert!(
986 atomic_active_count == 0,
987 "atomic_active_count: {}",
988 atomic_active_count
989 );
990 }
991
992 #[test]
993 fn test_shrink() {
994 let test_tasks_begin = TEST_TASKS + 2;
995
996 let mut pool = ThreadPool::new(test_tasks_begin);
997 let b0 = Arc::new(Barrier::new(test_tasks_begin + 1));
998 let b1 = Arc::new(Barrier::new(test_tasks_begin + 1));
999
1000 for _ in 0..test_tasks_begin {
1001 let (b0, b1) = (b0.clone(), b1.clone());
1002 pool.execute(move || {
1003 b0.wait();
1004 b1.wait();
1005 });
1006 }
1007
1008 let b2 = Arc::new(Barrier::new(TEST_TASKS + 1));
1009 let b3 = Arc::new(Barrier::new(TEST_TASKS + 1));
1010
1011 for _ in 0..TEST_TASKS {
1012 let (b2, b3) = (b2.clone(), b3.clone());
1013 pool.execute(move || {
1014 b2.wait();
1015 b3.wait();
1016 });
1017 }
1018
1019 b0.wait();
1020 pool.set_num_threads(TEST_TASKS);
1021
1022 assert_eq!(pool.active_count(), test_tasks_begin);
1023 b1.wait();
1024
1025 b2.wait();
1026 assert_eq!(pool.active_count(), TEST_TASKS);
1027 b3.wait();
1028 }
1029
1030 #[test]
1031 fn test_name() {
1032 let name = "test";
1033 let mut pool = ThreadPool::with_name(name.to_owned(), 2);
1034 let (tx, rx) = sync_channel(0);
1035
        // the initial threads should share the name "test"
1037 for _ in 0..2 {
1038 let tx = tx.clone();
1039 pool.execute(move || {
1040 let name = thread::current().name().unwrap().to_owned();
1041 tx.send(name).unwrap();
1042 });
1043 }
1044
        // a newly spawned thread should share the name "test" too
1046 pool.set_num_threads(3);
1047 let tx_clone = tx.clone();
1048 pool.execute(move || {
1049 let name = thread::current().name().unwrap().to_owned();
1050 tx_clone.send(name).unwrap();
1051 panic!();
1052 });
1053
        // the thread spawned to replace the panicked one should share the name "test" too
1055 pool.execute(move || {
1056 let name = thread::current().name().unwrap().to_owned();
1057 tx.send(name).unwrap();
1058 });
1059
1060 for thread_name in rx.iter().take(4) {
1061 assert_eq!(name, thread_name);
1062 }
1063 }
1064
1065 #[test]
1066 fn test_debug() {
1067 let pool = ThreadPool::new(4);
1068 let debug = format!("{:?}", pool);
1069 assert_eq!(
1070 debug,
1071 "ThreadPool { name: None, queued_count: 0, active_count: 0, max_count: 4 }"
1072 );
1073
1074 let pool = ThreadPool::with_name("hello".into(), 4);
1075 let debug = format!("{:?}", pool);
1076 assert_eq!(
1077 debug,
1078 "ThreadPool { name: Some(\"hello\"), queued_count: 0, active_count: 0, max_count: 4 }"
1079 );
1080
1081 let pool = ThreadPool::new(4);
1082 pool.execute(move || sleep(Duration::from_secs(5)));
1083 sleep(Duration::from_secs(1));
1084 let debug = format!("{:?}", pool);
1085 assert_eq!(
1086 debug,
1087 "ThreadPool { name: None, queued_count: 0, active_count: 1, max_count: 4 }"
1088 );
1089 }
1090
1091 #[test]
1092 fn test_repeate_join() {
1093 let pool = ThreadPool::with_name("repeate join test".into(), 8);
1094 let test_count = Arc::new(AtomicUsize::new(0));
1095
1096 for _ in 0..42 {
1097 let test_count = test_count.clone();
1098 pool.execute(move || {
1099 sleep(Duration::from_secs(2));
1100 test_count.fetch_add(1, Ordering::Release);
1101 });
1102 }
1103
1104 println!("{:?}", pool);
1105 pool.join();
1106 assert_eq!(42, test_count.load(Ordering::Acquire));
1107
1108 for _ in 0..42 {
1109 let test_count = test_count.clone();
1110 pool.execute(move || {
1111 sleep(Duration::from_secs(2));
1112 test_count.fetch_add(1, Ordering::Relaxed);
1113 });
1114 }
1115 pool.join();
1116 assert_eq!(84, test_count.load(Ordering::Relaxed));
1117 }
1118
1119 #[test]
1120 fn test_multi_join() {
1121 use std::sync::mpsc::TryRecvError::*;
1122
1123 // Toggle the following lines to debug the deadlock
1124 fn error(_s: String) {
1125 //use ::std::io::Write;
1126 //let stderr = ::std::io::stderr();
1127 //let mut stderr = stderr.lock();
1128 //stderr.write(&_s.as_bytes()).is_ok();
1129 }
1130
1131 let pool0 = ThreadPool::with_name("multi join pool0".into(), 4);
1132 let pool1 = ThreadPool::with_name("multi join pool1".into(), 4);
1133 let (tx, rx) = channel();
1134
1135 for i in 0..8 {
1136 let pool1 = pool1.clone();
1137 let pool0_ = pool0.clone();
1138 let tx = tx.clone();
1139 pool0.execute(move || {
1140 pool1.execute(move || {
1141 error(format!("p1: {} -=- {:?}\n", i, pool0_));
1142 pool0_.join();
1143 error(format!("p1: send({})\n", i));
1144 tx.send(i).expect("send i from pool1 -> main");
1145 });
1146 error(format!("p0: {}\n", i));
1147 });
1148 }
1149 drop(tx);
1150
1151 assert_eq!(rx.try_recv(), Err(Empty));
1152 error(format!("{:?}\n{:?}\n", pool0, pool1));
1153 pool0.join();
1154 error(format!("pool0.join() complete =-= {:?}", pool1));
1155 pool1.join();
1156 error("pool1.join() complete\n".into());
1157 assert_eq!(
1158 rx.iter().sum::<i32>(),
1159 1 + 2 + 3 + 4 + 5 + 6 + 7
1160 );
1161 }
1162
1163 #[test]
1164 fn test_empty_pool() {
        // Joining an empty pool must return immediately
1166 let pool = ThreadPool::new(4);
1167
1168 pool.join();
1169
1170 assert_eq!(0, pool.jobs.len());
1171 }
1172
1173 #[test]
1174 fn test_no_fun_or_joy() {
1175 // What happens when you keep adding jobs after a join
1176
1177 fn sleepy_function() {
1178 sleep(Duration::from_secs(6));
1179 }
1180
1181 let pool = ThreadPool::with_name("no fun or joy".into(), 8);
1182
1183 pool.execute(sleepy_function);
1184
1185 let p_t = pool.clone();
1186 thread::spawn(move || {
1187 (0..23).map(|_| p_t.execute(sleepy_function)).count();
1188 });
1189
1190 pool.join();
1191 }
1192
1193 #[test]
1194 fn test_clone() {
1195 let pool = ThreadPool::with_name("clone example".into(), 2);
1196
1197 // This batch of jobs will occupy the pool for some time
1198 for _ in 0..6 {
1199 pool.execute(move || {
1200 sleep(Duration::from_secs(2));
1201 });
1202 }
1203
1204 // The following jobs will be inserted into the pool in a random fashion
1205 let t0 = {
1206 let pool = pool.clone();
1207 thread::spawn(move || {
1208 // wait for the first batch of tasks to finish
1209 pool.join();
1210
1211 let (tx, rx) = channel();
1212 for i in 0..42 {
1213 let tx = tx.clone();
1214 pool.execute(move || {
1215 tx.send(i).expect("channel will be waiting");
1216 });
1217 }
1218 drop(tx);
1219 rx.iter()
1220 .sum::<i32>()
1221 })
1222 };
1223 let t1 = {
1224 let pool = pool.clone();
1225 thread::spawn(move || {
1226 // wait for the first batch of tasks to finish
1227 pool.join();
1228
1229 let (tx, rx) = channel();
1230 for i in 1..12 {
1231 let tx = tx.clone();
1232 pool.execute(move || {
1233 tx.send(i).expect("channel will be waiting");
1234 });
1235 }
1236 drop(tx);
1237 rx.iter()
1238 .product::<i32>()
1239 })
1240 };
1241
1242 assert_eq!(
1243 861,
1244 t0.join()
1245 .expect("thread 0 will return after calculating additions",)
1246 );
1247 assert_eq!(
1248 39916800,
1249 t1.join()
1250 .expect("thread 1 will return after calculating multiplications",)
1251 );
1252 }
1253
1254 #[test]
1255 fn test_sync_shared_data() {
1256 fn assert_sync<T: Sync>() {}
1257 assert_sync::<super::ThreadPoolSharedData>();
1258 }
1259
1260 #[test]
1261 fn test_send_shared_data() {
1262 fn assert_send<T: Send>() {}
1263 assert_send::<super::ThreadPoolSharedData>();
1264 }
1265
1266 #[test]
1267 fn test_send() {
1268 fn assert_send<T: Send>() {}
1269 assert_send::<ThreadPool>();
1270 }
1271
1272 #[test]
1273 fn test_cloned_eq() {
1274 let a = ThreadPool::new(2);
1275
1276 assert_eq!(a, a.clone());
1277 }
1278
1279 #[test]
    /// The scenario: threads blocked in `join` should not stay stuck once their
    /// wave of joins has completed. So once one thread joining on a pool has
    /// succeeded, other threads joining on the same pool must get out even if
    /// the pool is used for other jobs while the first group is finishing
    /// their join.
1285 ///
1286 /// In this example this means the waiting threads will exit the join in
1287 /// groups of four because the waiter pool has four workers.
1288 fn test_join_wavesurfer() {
1289 let n_cycles = 4;
1290 let n_workers = 4;
1291 let (tx, rx) = channel();
1292 let builder = Builder::new()
1293 .num_threads(n_workers)
1294 .thread_name("join wavesurfer".into());
1295 let p_waiter = builder.clone().build();
1296 let p_clock = builder.build();
1297
1298 let barrier = Arc::new(Barrier::new(3));
1299 let wave_clock = Arc::new(AtomicUsize::new(0));
1300 let clock_thread = {
1301 let barrier = barrier.clone();
1302 let wave_clock = wave_clock.clone();
1303 thread::spawn(move || {
1304 barrier.wait();
1305 for wave_num in 0..n_cycles {
1306 wave_clock.store(wave_num, Ordering::SeqCst);
1307 sleep(Duration::from_secs(1));
1308 }
1309 })
1310 };
1311
1312 {
1313 let barrier = barrier.clone();
1314 p_clock.execute(move || {
1315 barrier.wait();
1316 // this sleep is for stabilisation on weaker platforms
1317 sleep(Duration::from_millis(100));
1318 });
1319 }
1320
1321 // prepare three waves of jobs
1322 for i in 0..3 * n_workers {
1323 let p_clock = p_clock.clone();
1324 let tx = tx.clone();
1325 let wave_clock = wave_clock.clone();
1326 p_waiter.execute(move || {
1327 let now = wave_clock.load(Ordering::SeqCst);
1328 p_clock.join();
1329 // submit jobs for the second wave
1330 p_clock.execute(|| sleep(Duration::from_secs(1)));
1331 let clock = wave_clock.load(Ordering::SeqCst);
1332 tx.send((now, clock, i)).unwrap();
1333 });
1334 }
1335 println!("all scheduled at {}", wave_clock.load(Ordering::SeqCst));
1336 barrier.wait();
1337
1338 p_clock.join();
1339 //p_waiter.join();
1340
1341 drop(tx);
1342 let mut hist = vec![0; n_cycles];
1343 let mut data = vec![];
1344 for (now, after, i) in rx.iter() {
1345 let mut dur = after - now;
1346 if dur >= n_cycles - 1 {
1347 dur = n_cycles - 1;
1348 }
1349 hist[dur] += 1;
1350
1351 data.push((now, after, i));
1352 }
1353 for (i, n) in hist.iter().enumerate() {
1354 println!(
1355 "\t{}: {} {}",
1356 i,
1357 n,
1358 &*(0..*n).fold("".to_owned(), |s, _| s + "*")
1359 );
1360 }
1361 assert!(data.iter().all(|&(cycle, stop, i)| if i < n_workers {
1362 cycle == stop
1363 } else {
1364 cycle < stop
1365 }));
1366
1367 clock_thread.join().unwrap();
1368 }
1369
1370 #[test]
1371 fn test_bounded_pool() {
1372 let pool = Builder::new()
1373 .num_threads(1)
1374 .queue_len(1)
1375 .build();
1376 let end = Arc::new(Barrier::new(2));
1377 let count = Arc::new(Mutex::new(0));
1378
1379 fn inc_wait(c: &Arc<Mutex<i64>>, val: i64, millis: i64) -> bool{
1380 for _ in 0..millis/10 {
1381 {
1382 let l = c.lock().unwrap();
1383 if *l == val {
1384 return true;
1385 }
1386 }
1387 sleep(Duration::from_millis(10));
1388 }
1389 false
1390 }
1391
1392 // Lock up the only thread
1393 let e1 = end.clone();
1394 let c1 = count.clone();
1395 pool.execute(move || {
1396 {
1397 let mut c = c1.lock().unwrap();
1398 *c += 1;
1399 }
1400 e1.wait();
1401 });
1402
1403 // Wait for it to be ready
1404 assert!(inc_wait(&count, 1, 1000));
1405 assert_eq!(pool.queued_count(), 0);
1406
1407 // Schedule 2nd job; sits on the queue
1408 let e2 = end.clone();
1409 let c2 = count.clone();
1410 pool.execute(move || {
1411 {
1412 let mut c = c2.lock().unwrap();
1413 *c += 1;
1414 }
1415 e2.wait();
1416 });
1417
1418 assert!(!inc_wait(&count, 2, 1000));
1419 assert_eq!(pool.queued_count(), 1);
1420
1421 // Third attempt should block
1422 let c3 = count.clone();
1423 thread::spawn(move || {
1424 pool.execute(move || {
1425 } );
1426 {
1427 let mut c = c3.lock().unwrap();
1428 *c += 1;
1429 }
1430 });
1431 assert!(!inc_wait(&count, 2, 1000));
1432
1433 end.wait();
1434 }
1435}