lft_rust/
threadpool.rs

1// Copyright 2022 @yucwang
2//
3// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
6// option. This file may not be copied, modified, or distributed
7// except according to those terms.
8
9//! A lock-free thread pool used to execute functions in parallel.
10//!
11//! Spawns a specified number of worker threads and replenishes the pool if any worker threads
12//! panic.
13//!
14//! # Examples
15//!
16//! ## Synchronized with a channel
17//!
//! Every job sends one message over the channel; the messages are then collected with `take()`.
19//!
20//! ```
21//! use crossbeam_channel::unbounded;
22//!
23//! let n_workers = 4;
24//! let n_jobs = 8;
25//! let pool = lft_rust::lft_builder()
26//!                 .num_workers(n_workers)
27//!                 .build();
28//!
29//! let (tx, rx) = unbounded();
30//! for _ in 0..n_jobs {
31//!     let tx = tx.clone();
32//!     pool.execute(move|| {
33//!         tx.send(1).expect("channel will be there waiting for the pool");
34//!     });
35//! }
36//! drop(tx);
37//!
38//! assert_eq!(rx.iter().take(n_jobs).fold(0, |a, b| a + b), 8);
39//! ```
40//!
41//! ## Synchronized with a barrier
42//!
43//! Keep in mind, if a barrier synchronizes more jobs than you have workers in the pool,
44//! you will end up with a [deadlock](https://en.wikipedia.org/wiki/Deadlock)
45//! at the barrier which is [not considered unsafe](
46//! https://doc.rust-lang.org/reference/behavior-not-considered-unsafe.html).
47//!
48//! ```
49//! use lft_rust::{ ThreadPool };
50//! use std::sync::{Arc, Barrier};
51//! use std::sync::atomic::{AtomicUsize, Ordering};
52//!
53//! // create at least as many workers as jobs or you will deadlock yourself
54//! let n_workers = 42;
55//! let n_jobs = 23;
56//! let pool = lft_rust::lft_builder()
57//!                 .num_workers(n_workers)
58//!                 .build();
59//! let an_atomic = Arc::new(AtomicUsize::new(0));
60//!
61//! assert!(n_jobs <= n_workers, "too many jobs, will deadlock");
62//!
63//! // create a barrier that waits for all jobs plus the starter thread
64//! let barrier = Arc::new(Barrier::new(n_jobs + 1));
65//! for _ in 0..n_jobs {
66//!     let barrier = barrier.clone();
67//!     let an_atomic = an_atomic.clone();
68//!
69//!     pool.execute(move|| {
70//!         // do the heavy work
71//!         an_atomic.fetch_add(1, Ordering::Relaxed);
72//!
73//!         // then wait for the other threads
74//!         barrier.wait();
75//!     });
76//! }
77//!
78//! // wait for the threads to finish the work
79//! barrier.wait();
80//! assert_eq!(an_atomic.load(Ordering::SeqCst), n_jobs);
81//! ```
82
83use num_cpus;
84
85use crossbeam_channel::{ unbounded, Receiver, Sender };
86use log::{ trace, warn };
87
88use std::fmt;
89use std::sync::atomic::{ AtomicI8, AtomicUsize, Ordering };
90use std::sync::{ Arc, Condvar, Mutex };
91use std::{ thread, time };
92
93#[cfg(test)]
94mod test;
95
96/// Creates a new thread pool with the same number of workers as CPUs are detected.
97///
98/// # Examples
99///
100/// Create a new thread pool capable of executing at least one jobs concurrently:
101///
102/// ```
103/// let pool = threadpool::auto_config();
104/// ```
105pub fn lft_auto_config() -> ThreadPool {
106    lft_builder().build()
107}
108
109/// Initiate a new [`Builder`].
110///
111/// [`Builder`]: struct.Builder.html
112///
113/// # Examples
114///
115/// ```
116/// let builder = lft_rust::lft_builder();
117/// ```
118pub const fn lft_builder() -> Builder {
119    Builder {
120        num_workers: None,
121        max_thread_count: None,
122        worker_name: None,
123        thread_stack_size: None,
124    }
125}
126
/// Object-safe stand-in for `FnOnce()`: lets a boxed closure be consumed and
/// invoked through a trait object.
trait FnBox {
    /// Consume the box and run the closure it holds.
    fn call_box(self: Box<Self>);
}

impl<F> FnBox for F
where
    F: FnOnce(),
{
    fn call_box(self: Box<Self>) {
        // Move the closure out of its heap allocation, then invoke it.
        let job = *self;
        job()
    }
}

/// A boxed, sendable unit of work queued for a worker thread.
type Thunk<'a> = Box<dyn FnBox + Send + 'a>;
138
139struct Sentinel<'a> {
140    shared_data: &'a Arc<ThreadPoolSharedData>,
141    receiver: &'a Arc<Receiver<Thunk<'static>>>,
142    num_jobs: &'a Arc<AtomicUsize>,
143    thread_closing: &'a Arc<AtomicI8>,
144    active: bool,
145}
146
147impl<'a> Sentinel<'a> {
148    fn new(shared_data: &'a Arc<ThreadPoolSharedData>, 
149           receiver: &'a Arc<Receiver<Thunk<'static>>>,
150           num_jobs: &'a Arc<AtomicUsize>,
151           thread_closing: &'a Arc<AtomicI8>) -> Sentinel<'a> {
152        Sentinel {
153            shared_data: shared_data,
154            receiver: receiver,
155            num_jobs: num_jobs,
156            thread_closing: thread_closing,
157            active: true,
158        }
159    }
160
161    /// Cancel and destroy this sentinel.
162    fn cancel(mut self) {
163        self.active = false;
164    }
165}
166
167impl<'a> Drop for Sentinel<'a> {
168    fn drop(&mut self) {
169        if self.active {
170            self.shared_data.active_count.fetch_sub(1, Ordering::SeqCst);
171            self.thread_closing.store(3, Ordering::SeqCst);
172            if thread::panicking() {
173                self.shared_data.panic_count.fetch_add(1, Ordering::SeqCst);
174            }
175            if self.num_jobs.load(Ordering::Acquire) == 0 {
176                self.shared_data.no_work_notify_all();
177            }
178
179            self.thread_closing.store(1, Ordering::SeqCst);
180            spawn_in_pool(self.shared_data.clone(), 
181                          self.receiver.clone(), 
182                          self.num_jobs.clone(), 
183                          self.thread_closing.clone())
184        }
185    }
186}
187
/// [`ThreadPool`] factory, which can be used in order to configure the properties of the
/// [`ThreadPool`].
///
/// The four configuration options available:
///
/// * `num_workers`: the number of worker threads that will be spawned upon building.
/// * `max_thread_count`: maximum number of threads that will be alive at any given moment by the
///   built [`ThreadPool`]
/// * `worker_name`: thread name for each of the threads spawned by the built [`ThreadPool`]
/// * `thread_stack_size`: stack size (in bytes) for each of the threads spawned by the built
///   [`ThreadPool`]
///
/// [`ThreadPool`]: struct.ThreadPool.html
///
/// # Examples
///
/// Build a [`ThreadPool`] that uses a maximum of eight threads simultaneously and each thread has
/// a 8 MB stack size:
///
/// ```
/// let pool = lft_rust::lft_builder()
///     .num_workers(8)
///     .thread_stack_size(8 * 1024 * 1024)
///     .build();
/// ```
#[derive(Clone, Debug, Default)]
pub struct Builder {
    // Each option is `None` until its setter is called.
    num_workers: Option<usize>,
    max_thread_count: Option<usize>,
    worker_name: Option<String>,
    thread_stack_size: Option<usize>,
}
220
221impl Builder {
222    /// Set the number of threads that will be spawned upon building. If it is not specified, it
223    /// will be the maximum number of threads that can be spawned.
224    ///
225    /// [`ThreadPool`]: struct.ThreadPool.html
226    ///
227    /// # Panics
228    ///
229    /// This method will panic if `num_workers` is 0.
230    ///
231    /// # Examples
232    ///
233    /// No more than eight threads will be alive simultaneously for this pool:
234    ///
235    /// ```
236    /// use std::thread;
237    ///
238    /// let pool = lft_rust::lft_builder()
239    ///     .num_workers(8)
240    ///     .build();
241    ///
242    /// for _ in 0..42 {
243    ///     pool.execute(|| {
244    ///         println!("Hello from a worker thread!")
245    ///     })
246    /// }
247    /// ```
248    pub fn num_workers(mut self, num_workers: usize) -> Builder {
249        assert!(num_workers > 0);
250        self.num_workers = Some(num_workers);
251        self
252    }
253
254    /// Set the maximum number of worker-threads that will be alive at any given moment by the built
255    /// [`ThreadPool`]. If not specified, defaults the number of threads to the number of CPUs.
256    ///
257    /// [`ThreadPool`]: struct.ThreadPool.html
258    ///
259    /// # Panics
260    ///
261    /// This method will panic if `max_thread_count` is 0.
262    ///
263    /// # Examples
264    ///
265    /// No more than eight threads will be alive simultaneously for this pool:
266    ///
267    /// ```
268    /// use std::thread;
269    ///
270    /// let pool = threadpool::builder()
271    ///     .max_thread_count(8)
272    ///     .build();
273    ///
274    /// for _ in 0..42 {
275    ///     pool.execute(|| {
276    ///         println!("Hello from a worker thread!")
277    ///     })
278    /// }
279    /// ```
280    pub fn max_thread_count(mut self, max_thread_count: usize) -> Builder {
281        assert!(max_thread_count > 0);
282        self.max_thread_count = Some(max_thread_count);
283        self
284    }
285
286    /// Set the thread name for each of the threads spawned by the built [`ThreadPool`]. If not
287    /// specified, threads spawned by the thread pool will be unnamed.
288    ///
289    /// [`ThreadPool`]: struct.ThreadPool.html
290    ///
291    /// # Examples
292    ///
293    /// Each thread spawned by this pool will have the name "foo":
294    ///
295    /// ```
296    /// use std::thread;
297    ///
298    /// let pool = lft_rust::lft_builder()
299    ///     .worker_name("foo")
300    ///     .build();
301    ///
302    /// for _ in 0..100 {
303    ///     pool.execute(|| {
304    ///         assert_eq!(thread::current().name(), Some("foo"));
305    ///     })
306    /// }
307    /// ```
308    pub fn worker_name<S: AsRef<str>>(mut self, name: S) -> Builder {
309        // TODO save the copy with Into<String>
310        self.worker_name = Some(name.as_ref().to_owned());
311        self
312    }
313
314    /// Set the stack size (in bytes) for each of the threads spawned by the built [`ThreadPool`].
315    /// If not specified, threads spawned by the threadpool will have a stack size [as specified in
316    /// the `std::thread` documentation][thread].
317    ///
318    /// [thread]: https://doc.rust-lang.org/nightly/std/thread/index.html#stack-size
319    /// [`ThreadPool`]: struct.ThreadPool.html
320    ///
321    /// # Examples
322    ///
323    /// Each thread spawned by this pool will have a 4 MB stack:
324    ///
325    /// ```
326    /// let pool = lft_rust::lft_builder()
327    ///     .thread_stack_size(4096 * 1024)
328    ///     .build();
329    ///
330    /// for _ in 0..100 {
331    ///     pool.execute(|| {
332    ///         println!("This thread has a 4 MB stack size!");
333    ///     })
334    /// }
335    /// ```
336    pub fn thread_stack_size(mut self, size: usize) -> Builder {
337        self.thread_stack_size = Some(size);
338        self
339    }
340
341    /// Finalize the [`Builder`] and build the [`ThreadPool`].
342    ///
343    /// [`Builder`]: struct.Builder.html
344    /// [`ThreadPool`]: struct.ThreadPool.html
345    ///
346    /// # Examples
347    ///
348    /// ```
349    /// let pool = lft_rust::lft_builder()
350    ///     .num_workers(8)
351    ///     .thread_stack_size(16*1024*1024)
352    ///     .build();
353    /// ```
354    pub fn build(self) -> ThreadPool {
355        let mut num_workers = self.num_workers.unwrap_or_else(num_cpus::get);
356        let max_thread_count = self.max_thread_count.unwrap_or_else(|| {num_workers});
357        if max_thread_count < num_workers {
358            warn!("Number of works is larger than max thread number, shrinking 
359                     the thread pool to max thread number {}.", max_thread_count);
360            num_workers = max_thread_count;
361        }
362
363        let mut num_jobs_list: Vec<Arc<AtomicUsize>> = Vec::with_capacity(max_thread_count);
364        let mut thread_closing_list: Vec<Arc<AtomicI8>> = Vec::with_capacity(max_thread_count);
365        let mut sender_list: Vec<Sender<Thunk<'static>>> = Vec::with_capacity(max_thread_count);
366        let mut receiver_list: Vec<Arc<Receiver<Thunk<'static>>>> = Vec::with_capacity(max_thread_count);
367        for i in 0..max_thread_count {
368            let (tx, rx) = unbounded::<Thunk<'static>>();
369            num_jobs_list.push(Arc::new(AtomicUsize::new(0)));
370            sender_list.push(tx);
371            receiver_list.push(Arc::new(rx));
372            if i < num_workers {
373                thread_closing_list.push(Arc::new(AtomicI8::new(1)));
374            } else {
375                thread_closing_list.push(Arc::new(AtomicI8::new(3)));
376            }
377        }
378
379        let context = Arc::new(ThreadPoolContext {
380            queued_count: num_jobs_list.clone(),
381            thread_closing: thread_closing_list.clone(),
382            senders: sender_list,
383            receivers: receiver_list.clone(),
384        });
385
386        let shared_data = Arc::new(ThreadPoolSharedData {
387            name: self.worker_name,
388            // job_receiver: Mutex::new(rx),
389            empty_condvar: Condvar::new(),
390            empty_trigger: Mutex::new(()),
391            join_generation: AtomicUsize::new(0),
392            queued_count: AtomicUsize::new(0),
393            active_count: AtomicUsize::new(0),
394            num_workers: AtomicUsize::new(num_workers),
395            max_thread_count: AtomicUsize::new(max_thread_count),
396            panic_count: AtomicUsize::new(0),
397            stack_size: self.thread_stack_size,
398        });
399
400        // Threadpool threads
401        let sleep_duration = time::Duration::from_millis(8);
402        for i in 0..max_thread_count {
403            spawn_in_pool(shared_data.clone(), 
404                          receiver_list[i].clone(), 
405                          num_jobs_list[i].clone(),
406                          thread_closing_list[i].clone());
407
408            while thread_closing_list[i].load(Ordering::SeqCst) != 0 {
409                thread::sleep(sleep_duration);
410            }
411        }
412
413        ThreadPool {
414            // jobs: tx,
415            shared_data: shared_data,
416            context: context,
417        }
418    }
419}
420
/// Pool-wide settings and counters shared by every handle and worker.
struct ThreadPoolSharedData {
    // Optional thread name applied to spawned workers.
    name: Option<String>,
    // Mutex/condvar pair that `join` waits on.
    empty_trigger: Mutex<()>,
    empty_condvar: Condvar,
    // Bumped by the first joiner to leave `join`.
    join_generation: AtomicUsize,
    // Pool-wide job and worker counters.
    queued_count: AtomicUsize,
    active_count: AtomicUsize,
    num_workers: AtomicUsize,
    max_thread_count: AtomicUsize,
    panic_count: AtomicUsize,
    // Optional stack size (bytes) applied to spawned workers.
    stack_size: Option<usize>,
}

impl ThreadPoolSharedData {
    /// `true` while at least one job is queued or being executed.
    fn has_work(&self) -> bool {
        [&self.queued_count, &self.active_count]
            .iter()
            .any(|counter| counter.load(Ordering::SeqCst) > 0)
    }

    /// Wake every thread blocked in `join`, but only when the pool is idle.
    fn no_work_notify_all(&self) {
        if self.has_work() {
            return;
        }
        // Take and immediately release the mutex before notifying.
        let guard = self
            .empty_trigger
            .lock()
            .expect("Unable to notify all joining threads");
        drop(guard);
        self.empty_condvar.notify_all();
    }
}
450
/// Per-slot state shared by every `ThreadPool` handle.
///
/// `Builder::build` fills each vector with `max_thread_count` entries; index
/// `i` in every vector refers to the same worker slot.
struct ThreadPoolContext {
    // Jobs counted against slot `i`; `execute` increments it when it sends a job.
    // NOTE(review): the matching decrement is not visible in this part of the
    // file -- presumably done by the worker loop; confirm.
    queued_count: Vec<Arc<AtomicUsize>>,
    // Sending half of slot `i`'s unbounded job channel.
    senders: Vec<Sender<Thunk<'static>>>,
    // Receiving half of slot `i`'s channel, handed to `spawn_in_pool`.
    receivers: Vec<Arc<Receiver<Thunk<'static>>>>,
    // Slot state flag. 0: `execute` may send jobs to the slot. Any value > 0
    // makes `execute` skip the slot: 1 is stored before a worker is (re)spawned,
    // 2 is stored by `shutdown_one_worker` on the slot it picked, 3 marks a
    // slot that `spawn_extra_one_worker` may claim.
    thread_closing: Vec<Arc<AtomicI8>>,
}
457
/// Abstraction of a thread pool for basic parallelism.
///
/// A `ThreadPool` value is a handle: cloning it clones the two `Arc`s below,
/// so all clones refer to the same pool.
pub struct ThreadPool {
    // Pool-wide configuration and counters.
    shared_data: Arc<ThreadPoolSharedData>,
    // Per-slot channels, queue counters and state flags used to hand jobs to
    // the worker threads.
    context: Arc<ThreadPoolContext>,
}
468
469impl ThreadPool {
470    /// Executes the function `job` on a thread in the pool.
471    ///
472    /// # Examples
473    ///
474    /// Execute four jobs on a thread pool that can run two jobs concurrently:
475    ///
476    /// ```
477    /// let pool = lft_rust::lft_auto_config();
478    /// pool.execute(|| println!("hello"));
479    /// pool.execute(|| println!("world"));
480    /// pool.execute(|| println!("foo"));
481    /// pool.execute(|| println!("bar"));
482    /// pool.join();
483    /// ```
484    pub fn execute<F>(&self, job: F)
485    where
486        F: FnOnce() + Send + 'static,
487    {
488        let max_thread_count = self.shared_data.max_thread_count.load(Ordering::Relaxed);
489
490        loop {
491            let mut target_thread_id = max_thread_count + 1;
492            let mut min_jobs_counted: usize = 0;
493            for i in 0..max_thread_count {
494                if self.context.thread_closing[i].load(Ordering::Relaxed) > 0 {
495                    // The thread is closed or about to close.
496                    continue;
497                }
498                if self.context.queued_count[i].load(Ordering::SeqCst) == 0 {
499                    target_thread_id = i;
500                    break;
501                }
502                if target_thread_id > max_thread_count 
503                    || self.context.queued_count[i].load(Ordering::SeqCst) < min_jobs_counted {
504                    target_thread_id = i;
505                    min_jobs_counted = self.context.queued_count[i].load(Ordering::Relaxed);
506                }
507            }
508
509
510            if target_thread_id < max_thread_count && 
511                self.context.thread_closing[target_thread_id].load(Ordering::SeqCst) == 0 {
512                // trace!("Target thread id: {}.", target_thread_id);
513                self.shared_data.queued_count.fetch_add(1, Ordering::SeqCst);
514                self.context.queued_count[target_thread_id].fetch_add(1, Ordering::SeqCst);
515                self.context.senders[target_thread_id]
516                    .send(Box::new(job))
517                    .expect("ThreadPool::execute unable to send job into queue.");
518                break;
519            }
520            let ten_millis = time::Duration::from_millis(10);
521            thread::sleep(ten_millis);
522        }
523    }
524
525    /// Returns the number of jobs waiting to executed in the pool.
526    ///
527    /// # Examples
528    ///
529    /// ```
530    /// use threadpool::ThreadPool;
531    /// use std::time::Duration;
532    /// use std::thread::sleep;
533    ///
534    /// let pool = lft_rust::lft_builder()
535    ///                     .num_workers(2)
536    ///                     .build();
537    /// for _ in 0..10 {
538    ///     pool.execute(|| {
539    ///         sleep(Duration::from_secs(100));
540    ///     });
541    /// }
542    ///
543    /// sleep(Duration::from_secs(1)); // wait for threads to start
544    /// assert_eq!(8, pool.queued_count());
545    /// ```
546    pub fn queued_count(&self) -> usize {
547        self.shared_data.queued_count.load(Ordering::Relaxed)
548    }
549
550    /// Returns the number of currently active worker threads.
551    ///
552    /// # Examples
553    ///
554    /// ```
555    /// use std::time::Duration;
556    /// use std::thread::sleep;
557    ///
558    /// let pool = lft_rust::lft_builder()
559    ///                     .num_workers(4)
560    ///                     .build();
561    /// for _ in 0..10 {
562    ///     pool.execute(move || {
563    ///         sleep(Duration::from_secs(100));
564    ///     });
565    /// }
566    ///
567    /// sleep(Duration::from_secs(1)); // wait for threads to start
568    /// assert_eq!(4, pool.active_count());
569    /// ```
570    pub fn active_count(&self) -> usize {
571        self.shared_data.active_count.load(Ordering::SeqCst)
572    }
573
574    /// Returns the maximum number of threads the pool will execute concurrently.
575    ///
576    /// # Examples
577    ///
578    /// ```
579    /// let pool = lft_rust::lft_builder()
580    ///                     .num_workers(4)
581    ///                     .build();
582    /// assert_eq!(4, pool.max_count());
583    ///
584    /// pool.set_num_workers(8);
585    /// assert_eq!(8, pool.max_count());
586    /// ```
587    pub fn max_count(&self) -> usize {
588        self.shared_data.max_thread_count.load(Ordering::Relaxed)
589    }
590
591    /// Returns the number of workers running in the threadpool.
592    pub fn num_workers(&self) -> usize {
593        self.shared_data.num_workers.load(Ordering::Relaxed)
594    }
595
596    /// Returns the number of panicked threads over the lifetime of the pool.
597    ///
598    /// # Examples
599    ///
600    /// ```
601    /// let pool = lft_rust::lft_builder()
602    ///                     .num_workers(4)
603    ///                     .build();
604    /// for n in 0..10 {
605    ///     pool.execute(move || {
606    ///         // simulate a panic
607    ///         if n % 2 == 0 {
608    ///             panic!()
609    ///         }
610    ///     });
611    /// }
612    /// pool.join();
613    ///
614    /// assert_eq!(5, pool.panic_count());
615    /// ```
616    pub fn panic_count(&self) -> usize {
617        self.shared_data.panic_count.load(Ordering::Relaxed)
618    }
619
620    /// Sets the number of worker-threads to use as `num_workers`.
621    /// Can be used to change the threadpool size during runtime.
622    /// Will not abort already running or waiting threads.
623    ///
624    /// # Panics
625    ///
626    /// This function will panic if `num_workers` is 0.
627    ///
628    /// # Examples
629    ///
630    /// ```
631    /// use std::time::Duration;
632    /// use std::thread::sleep;
633    ///
634    /// let  pool = threadpool::builder().num_workers(4).build();
635    /// for _ in 0..10 {
636    ///     pool.execute(move || {
637    ///         sleep(Duration::from_secs(100));
638    ///     });
639    /// }
640    ///
641    /// sleep(Duration::from_secs(1)); // wait for threads to start
642    /// assert_eq!(4, pool.active_count());
643    /// assert_eq!(6, pool.queued_count());
644    ///
645    /// // Increase thread capacity of the pool
646    /// pool.set_num_workers(8);
647    ///
648    /// sleep(Duration::from_secs(1)); // wait for new threads to start
649    /// assert_eq!(8, pool.active_count());
650    /// assert_eq!(2, pool.queued_count());
651    ///
652    /// // Decrease thread capacity of the pool
653    /// // No active threads are killed
654    /// pool.set_num_workers(4);
655    ///
656    /// assert_eq!(8, pool.active_count());
657    /// assert_eq!(2, pool.queued_count());
658    /// ```
659    // pub fn set_num_workers(&self, num_workers: usize) {
660    //     assert!(num_workers >= 1);
661    //     let prev_num_workers = self
662    //         .shared_data
663    //         .max_thread_count
664    //         .swap(num_workers, Ordering::Release);
665    //     if let Some(num_spawn) = num_workers.checked_sub(prev_num_workers) {
666    //         // Spawn new threads
667    //         for _ in 0..num_spawn {
668    //             spawn_in_pool(self.shared_data.clone());
669    //         }
670    //     }
671    // }
672
673    /// Spawn an extra worker thread. Can be used to increase the number of
674    /// work threads during runtimes. If the number of threads is already 
675    /// the maximum, it will print out an warning.
676    ///
677    /// # Examples
678    /// ```
679    /// use std::time::Duration;
680    /// use std::thread::sleep;
681    ///
682    /// let  pool = threadpool::builder().num_workers(4).build();
683    /// for _ in 0..10 {
684    ///     pool.execute(move || {
685    ///         sleep(Duration::from_secs(100));
686    ///     });
687    /// pool.spawn_extra_one_worker();
688    /// ```
689    pub fn spawn_extra_one_worker(&self) {
690        if self.shared_data.num_workers.load(Ordering::Acquire) 
691            >= self.shared_data.max_thread_count.load(Ordering::Relaxed) {
692                warn!("Max thread number exceeded.");
693                ()
694        } 
695        self.shared_data.num_workers.fetch_add(1, Ordering::SeqCst);
696
697        let mut spawn_completed = false;
698        while !spawn_completed {
699            let max_thread_count = self.shared_data.max_thread_count.load(Ordering::Relaxed);
700            for i in 0..max_thread_count {
701                if self.context.thread_closing[i].compare_exchange(3, 
702                                                                   1, 
703                                                                   Ordering::SeqCst, 
704                                                                   Ordering::Relaxed) == Ok(3) {
705                    // If one thread is closed, we will try to open this thread.
706                    spawn_in_pool(self.shared_data.clone(), 
707                                  self.context.receivers[i].clone(), 
708                                  self.context.queued_count[i].clone(),
709                                  self.context.thread_closing[i].clone());
710                    spawn_completed = true;
711                    break;
712                }
713            }
714
715            if self.shared_data.num_workers.load(Ordering::SeqCst) 
716                >= self.shared_data.max_thread_count.load(Ordering::Relaxed) {
717                    warn!("Max thread number exceeded.");
718                    break;
719            } 
720        }
721
722    }
723
724    /// Shutdown a worker thread in the threadpool. 
725    /// Can be used to increase the number of work threads during runtimes. 
726    /// If the number of threads is already 0, it will print out an warning.
727    ///
728    /// # Examples
729    /// ```
730    /// use std::time::Duration;
731    /// use std::thread::sleep;
732    ///
733    /// let  pool = threadpool::builder().num_workers(4).build();
734    /// for _ in 0..10 {
735    ///     pool.execute(move || {
736    ///         sleep(Duration::from_secs(100));
737    ///     });
738    /// pool.spawn_extra_one_worker();
739    /// ```
740    pub fn shutdown_one_worker(&self) {
741        if self.shared_data.num_workers.load(Ordering::SeqCst) <= 0 {
742            warn!("No thread to shutdown");
743            ()
744        }
745        self.shared_data.num_workers.fetch_sub(1, Ordering::SeqCst);
746
747        loop {
748            let max_thread_count = self.shared_data.max_thread_count.load(Ordering::Relaxed);
749            let mut target_thread_id = max_thread_count + 1;
750            let mut min_num_of_jobs = 0;
751
752            for i in 0..max_thread_count {
753                if self.context.thread_closing[i].load(Ordering::Relaxed) > 0 {
754                    continue;
755                }
756
757                if target_thread_id > max_thread_count || 
758                    min_num_of_jobs > self.context.queued_count[i].load(Ordering::Relaxed) {
759                        target_thread_id = i;
760                        min_num_of_jobs = self.context.queued_count[i].load(Ordering::Acquire);
761                }
762            }
763
764            // CAS to check if the target thread is still running.
765            if target_thread_id < max_thread_count && 
766                self.context.thread_closing[target_thread_id].compare_exchange(0,
767                                                                               2,
768                                                                               Ordering::SeqCst,
769                                                                               Ordering::Relaxed) == Ok(0) {
770                    trace!("Closing thread id: {}.", target_thread_id);
771                    break;
772            }
773           
774            if self.shared_data.num_workers.load(Ordering::SeqCst) <= 0 {
775                warn!("No thread to shutdown");
776                break;
777            }
778        }
779    }
780
781    /// Block the current thread until all jobs in the pool have been executed.
782    ///
783    /// Calling `join` on an empty pool will cause an immediate return.
784    /// `join` may be called from multiple threads concurrently.
785    /// A `join` is an atomic point in time. All threads joining before the join
786    /// event will exit together even if the pool is processing new jobs by the
787    /// time they get scheduled.
788    ///
789    /// Calling `join` from a thread within the pool will cause a deadlock. This
790    /// behavior is considered safe.
791    ///
792    /// **Note:** Join will not stop the worker threads. You will need to `drop`
793    /// all instances of `ThreadPool` for the worker threads to terminate.
794    ///
795    /// # Examples
796    ///
797    /// ```
798    /// use threadpool::ThreadPool;
799    /// use std::sync::Arc;
800    /// use std::sync::atomic::{AtomicUsize, Ordering};
801    ///
802    /// let pool = lft_rust::lft_auto_config();
803    /// let test_count = Arc::new(AtomicUsize::new(0));
804    ///
805    /// for _ in 0..42 {
806    ///     let test_count = test_count.clone();
807    ///     pool.execute(move || {
808    ///         test_count.fetch_add(1, Ordering::Relaxed);
809    ///     });
810    /// }
811    ///
812    /// pool.join();
813    /// assert_eq!(42, test_count.load(Ordering::Relaxed));
814    /// ```
815    pub fn join(&self) {
816        // fast path requires no mutex
817        if self.shared_data.has_work() == false {
818            return ();
819        }
820
821        let generation = self.shared_data.join_generation.load(Ordering::SeqCst);
822        let mut lock = self.shared_data.empty_trigger.lock().unwrap();
823
824        while generation == self.shared_data.join_generation.load(Ordering::Relaxed)
825            && self.shared_data.has_work()
826        {
827            lock = self.shared_data.empty_condvar.wait(lock).unwrap();
828        }
829
830        // increase generation if we are the first joining thread to come out of the loop
831        let _ = self.shared_data.join_generation.compare_exchange(
832            generation,
833            generation.wrapping_add(1),
834            Ordering::SeqCst,
835            Ordering::SeqCst,
836        );
837    }
838}
839
840impl Clone for ThreadPool {
841    /// Cloning a pool will create a new handle to the pool.
842    /// The behavior is similar to [Arc](https://doc.rust-lang.org/stable/std/sync/struct.Arc.html).
843    ///
844    /// We could for example submit jobs from multiple threads concurrently.
845    ///
846    /// ```
847    /// use std::thread;
848    /// use crossbeam_channel::unbounded;
849    ///
850    /// let pool = lft_rust::lft_builder()
851    ///                 .worker_name("clone example")
852    ///                 .num_workers(2)
853    ///                 .build();
854    ///
855    /// let results = (0..2)
856    ///     .map(|i| {
857    ///         let pool = pool.clone();
858    ///         thread::spawn(move || {
859    ///             let (tx, rx) = unbounded();
860    ///             for i in 1..12 {
861    ///                 let tx = tx.clone();
862    ///                 pool.execute(move || {
863    ///                     tx.send(i).expect("channel will be waiting");
864    ///                 });
865    ///             }
866    ///             drop(tx);
867    ///             if i == 0 {
868    ///                 rx.iter().fold(0, |accumulator, element| accumulator + element)
869    ///             } else {
870    ///                 rx.iter().fold(1, |accumulator, element| accumulator * element)
871    ///             }
872    ///         })
873    ///     })
874    ///     .map(|join_handle| join_handle.join().expect("collect results from threads"))
875    ///     .collect::<Vec<usize>>();
876    ///
877    /// assert_eq!(vec![66, 39916800], results);
878    /// ```
879    fn clone(&self) -> ThreadPool {
880        ThreadPool {
881            // jobs: self.jobs.clone(),
882            shared_data: self.shared_data.clone(),
883            context: self.context.clone(),
884        }
885    }
886}
887
888impl fmt::Debug for ThreadPool {
889    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
890        f.debug_struct("ThreadPool")
891            .field("name", &self.shared_data.name)
892            .field("queued_count", &self.queued_count())
893            .field("active_count", &self.active_count())
894            .field("max_count", &self.max_count())
895            .field("num_workers", &self.num_workers())
896            .finish()
897    }
898}
899
900impl PartialEq for ThreadPool {
901    /// Check if you are working with the same pool
902    ///
903    /// ```
904    /// let a = lft_rust::lft_auto_config();
905    /// let b = lft_rust::lft_auto_config();
906    ///
907    /// assert_eq!(a, a);
908    /// assert_eq!(b, b);
909    ///
910    /// assert_ne!(a, b);
911    /// assert_ne!(b, a);
912    /// ```
913    fn eq(&self, other: &ThreadPool) -> bool {
914        Arc::ptr_eq(&self.shared_data, &other.shared_data)
915    }
916}
917impl Eq for ThreadPool {}
918
919fn spawn_in_pool(shared_data: Arc<ThreadPoolSharedData>, 
920                 receiver: Arc<Receiver<Thunk<'static>>>,
921                 num_jobs: Arc<AtomicUsize>,
922                 thread_closing: Arc<AtomicI8>) {
923    let mut builder = thread::Builder::new();
924    if let Some(ref name) = shared_data.name {
925        builder = builder.name(name.clone());
926    }
927    if let Some(ref stack_size) = shared_data.stack_size {
928        builder = builder.stack_size(stack_size.to_owned());
929    }
930    builder
931        .spawn(move || {
932            // Will spawn a new thread on panic unless it is cancelled.
933            let sentinel = Sentinel::new(&shared_data, &receiver, &num_jobs, &thread_closing);
934            // thread_closing.swap(0, Ordering::SeqCst);
935
936            if thread_closing.compare_exchange(1, 0, Ordering::SeqCst, Ordering::Relaxed) == Ok(1) {
937                loop {
938                    // Shutdown this thread if the pool has become smaller
939                    // let thread_counter_val = shared_data.num_workers.load(Ordering::Acquire);
940                    // let max_thread_count_val = shared_data.max_thread_count.load(Ordering::Relaxed);
941                    if thread_closing.load(Ordering::SeqCst) == 2
942                        && num_jobs.load(Ordering::SeqCst) == 0 {
943                        break;
944                    }
945                    let message = {
946                        // Each thread will have a job queue, thread will fetch its own
947                        // work from its job queue.
948                        receiver.recv()
949                    };
950
951                    let job = match message {
952                        Ok(job) => job,
953                        // The ThreadPool was dropped.
954                        Err(..) => break,
955                    };
956                    // Do not allow IR around the job execution
957                    shared_data.active_count.fetch_add(1, Ordering::SeqCst);
958
959                    job.call_box();
960
961                    num_jobs.fetch_sub(1, Ordering::SeqCst);
962                    shared_data.queued_count.fetch_sub(1, Ordering::SeqCst);
963                    shared_data.active_count.fetch_sub(1, Ordering::SeqCst);
964                    if num_jobs.load(Ordering::SeqCst) == 0 {
965                        shared_data.no_work_notify_all();
966                    }
967                }
968
969                if thread_closing.compare_exchange(0, 
970                                                   3, 
971                                                   Ordering::SeqCst, 
972                                                   Ordering::Relaxed) == Ok(0) {
973                    shared_data.num_workers.fetch_sub(1, Ordering::SeqCst);
974                } else {
975                    let _ = thread_closing.compare_exchange(2, 
976                                                            3, 
977                                                            Ordering::SeqCst, 
978                                                            Ordering::Relaxed);
979                }
980            }
981
982            sentinel.cancel();
983        })
984        .unwrap();
985}