// jtp/thread_pool.rs

1use crate::{
2    task::{Task, TaskFn, TaskListeners},
3    worker::Worker,
4    ThreadPoolBuilder,
5};
6
7use crossbeam_channel::{bounded, Receiver, Sender, TrySendError};
8
9use std::{
10    sync::{atomic::AtomicUsize, Arc, Mutex},
11    thread,
12    time::Duration,
13};
14
/// A factory function that is used to create a custom [`thread::Builder`]
/// for spawning worker threads.
pub type ThreadFactory = dyn Fn() -> thread::Builder + Send + Sync + 'static;

/// Crate-internal shorthand for results of thread-pool operations.
type TPResult<T> = Result<T, TPError>;
19
/// An error returned from the [`ThreadPool::execute`].
///
/// [`ThreadPool::execute`]: crate::ThreadPool::execute
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TPError {
    /// The task could not be executed because the task is rejected
    /// by [`RejectedTaskHandler::Abort`] when the thread pool and the
    /// task channel are both full.
    Abort,

    /// The task could not be executed because the thread pool is closed.
    Closed,
}

impl std::error::Error for TPError {}

impl std::fmt::Display for TPError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Use `write!`, not `writeln!`: `Display` output must not carry a
        // trailing newline, since callers typically embed it in their own
        // messages (e.g. via `{}` or `.to_string()`).
        match self {
            TPError::Abort => write!(f, "task abortion error."),
            TPError::Closed => write!(f, "the thread pool is closed."),
        }
    }
}
44
/// If a task is rejected, the task will be handled by this.
///
/// The handler is chosen when building the pool (see
/// [`ThreadPoolBuilder`](crate::ThreadPoolBuilder)) and applied whenever
/// both the task channel and the worker pool are full.
// `Debug` lets callers log the configured policy; `Copy` is free for a
// fieldless enum and keeps the existing `Clone` behavior intact.
#[derive(Debug, Clone, Copy)]
pub enum RejectedTaskHandler {
    /// Returns [`TPError::Abort`].
    Abort,

    /// Nothing to do, just return [`Ok`].
    Discard,

    /// Immediately run the rejected task in the caller thread.
    CallerRuns,
}
57
/// State shared between every clone of a [`ThreadPool`] and its workers.
///
/// Each `Option` is `take`n (set to `None`) exactly once: the sender on
/// `shutdown`, the worker vectors on `wait`, so the underlying resources
/// are dropped/joined a single time even with multiple pool handles.
pub(crate) struct ThreadPoolSharedData {
    // Sender side of the task channel; `None` once the pool is closed.
    pub(crate) sender: Mutex<Option<Sender<Task>>>,
    // Core workers; kept alive until the pool is closed and drained.
    pub(crate) core_workers: Mutex<Option<Vec<Worker>>>,
    // Non-core workers; may finish when idle and be restarted later.
    pub(crate) workers: Mutex<Option<Vec<Worker>>>,
    // Monotonically increasing id handed to each created task.
    pub(crate) next_task_id: AtomicUsize,
}
64
65impl ThreadPoolSharedData {
66    pub(crate) fn num_of_core_workers(&self) -> usize {
67        self.core_workers
68            .lock()
69            .unwrap()
70            .as_ref()
71            .map_or(0, Vec::len)
72    }
73
74    pub(crate) fn num_of_active_workers(&self) -> usize {
75        self.workers.lock().unwrap().as_ref().map_or(0, |x| {
76            x.iter().filter(|worker| !worker.is_finished()).count()
77        })
78    }
79}
80
/// A `ThreadPool` consists of a collection of reusable threads and
/// a bounded channel that is used to transfer and hold submitted
/// tasks.
///
/// # Bounded Channel(Queue)
///
/// A bounded channel is a concurrent structure that can help to
/// transfer messages across multiple threads. It holds a finite
/// number of elements, which is useful to prevent resource exhaustion.
///
/// A bounded channel consists of two sides: `Sender` and `Receiver`.
///
/// ## Sender Side
///
/// A sender is used to send a message into a channel. In the thread
/// pool, we use the sender to transfer submitted tasks to available
/// worker threads.
///
/// ## Receiver Side
///
/// A receiver is used to fetch messages from the channel. There are
/// limited worker threads in a thread pool, each of which
/// contains a receiver that is used to fetch tasks from the channel.
///
/// # Worker Thread
///
/// We use a special structure, `Worker`, to represent a thread that
/// is always receiving tasks and executing them. In this library,
/// there are two kinds of worker:
/// 1. Core worker: the worker thread is never terminated unless the
/// associated thread pool is closed and the channel is empty.
/// 2. Non-core worker: a thread in this worker can be idle if no task
/// is received for a certain period of time(`keep_alive_time`).
///
/// This thread pool will store core workers and non-core workers in
/// two vectors. When you execute a task, it creates a core worker to
/// process the task if the core worker vector is not full, otherwise
/// the task will be sent to the task channel. If the channel buffer
/// is full, it attempts to find an idle worker thread or creates a
/// new non-core worker thread to process the task.
///
/// Worker threads will keep fetching tasks from the channel and
/// executing them until the queue is empty and the sender is dropped,
/// or no tasks were received for a long time (non-core threads only).
///
/// # Rejected Task
///
/// New tasks will be rejected when the channel is full and the number
/// of worker threads in the thread pool reaches a certain
/// number(`max_pool_size`). A rejected task will be handled by the
/// [`RejectedTaskHandler`]. You can set the handler for rejected
/// tasks when you build a thread pool with [`ThreadPoolBuilder`].
// Cloning a `ThreadPool` is cheap: the channel receiver and the `Arc`ed
// shared state are reference-counted handles to the same pool.
#[derive(Clone)]
pub struct ThreadPool {
    // NOTE(review): "reciver"/"task_lisenters" are misspelled, but they are
    // `pub(crate)` and referenced from other modules, so renaming is a
    // crate-wide change — left as-is here.
    pub(crate) reciver: Receiver<Task>,
    pub(crate) share: Arc<ThreadPoolSharedData>,

    pub(crate) core_pool_size: usize,
    pub(crate) max_pool_size: usize,
    pub(crate) keep_alive_time: Duration,
    pub(crate) rejected_task_handler: RejectedTaskHandler,
    pub(crate) task_lisenters: Arc<TaskListeners>,
    pub(crate) thread_factory: Arc<ThreadFactory>,
}
144
impl ThreadPool {
    /// Builds a thread pool from a configuration(builder).
    ///
    /// This assumes arguments of the builder are valid.
    pub(crate) fn from_builder(builder: ThreadPoolBuilder) -> Self {
        // The bounded channel carries tasks from `execute` to the workers.
        let (sender, reciver) = bounded(builder.channel_capacity);
        Self {
            reciver,
            share: Arc::new(ThreadPoolSharedData {
                sender: Mutex::new(Some(sender)),
                core_workers: Mutex::new(Some(Vec::default())),
                workers: Mutex::new(Some(Vec::default())),
                next_task_id: AtomicUsize::new(0),
            }),
            core_pool_size: builder.core_pool_size,
            max_pool_size: builder.max_pool_size,
            keep_alive_time: builder.keep_alive_time,
            rejected_task_handler: builder.rejected_task_handler,
            task_lisenters: Arc::new(builder.task_lisenters),
            thread_factory: builder.thread_factory,
        }
    }

    /// Executes the given task in the future.
    ///
    /// If the task queue is full and no worker thread can be
    /// allocated to execute the task, the task will be handled by the
    /// configured [`RejectedTaskHandler`].
    ///
    /// # Errors
    ///
    /// 1. [`Abort`]: The task was handled by the [`RejectedTaskHandler::Abort`].
    ///
    /// 2. [`Closed`]: The channel was closed.
    ///
    /// [`Abort`]: crate::TPError::Abort
    /// [`Closed`]: crate::TPError::Closed
    pub fn execute<F>(&self, task_fn: F) -> Result<(), TPError>
    where
        F: FnOnce() + Send + 'static,
    {
        // Fast-path rejection; `send_task` re-checks under the sender lock,
        // so a concurrent `shutdown` between here and the send is still safe.
        if self.is_closed() {
            return Err(TPError::Closed);
        }

        let task = self.create_task(Box::new(task_fn));
        let mut core_workers = self.share.core_workers.lock().unwrap();
        if let Some(core_workers) = core_workers.as_mut() {
            // Add a new worker to the core thread pool.
            if core_workers.len() < self.core_pool_size {
                let worker = self.create_worker(task, true);
                core_workers.push(worker);
                return Ok(());
            }
        }
        // Release lock.
        drop(core_workers);
        self.send_task(task)
    }

    /// Counts all active worker threads and returns it.
    #[must_use]
    pub fn active_count(&self) -> usize {
        self.share.num_of_active_workers() + self.share.num_of_core_workers()
    }

    /// Closes a thread pool.
    ///
    /// A closed thread pool will not accept any tasks, but will still
    /// process tasks in the channel(queue).
    ///
    /// # Examples
    ///
    /// ```
    /// use jtp::ThreadPoolBuilder;
    /// let thread_pool = ThreadPoolBuilder::default()
    ///     .build();
    ///
    /// thread_pool.shutdown();
    ///
    /// assert!(thread_pool.execute(|| {
    ///     println!("Hello");
    /// }).is_err());
    /// ```
    pub fn shutdown(&self) {
        // Dropping the sender disconnects the channel; workers terminate
        // once the remaining queued tasks are drained.
        self.share.sender.lock().unwrap().take();
    }

    /// Returns `true` if the thread pool is closed.
    #[must_use]
    pub fn is_closed(&self) -> bool {
        // "Closed" is encoded as the sender having been taken by `shutdown`.
        self.share.sender.lock().unwrap().is_none()
    }

    /// Waits for all worker threads to finish. Note that this waits for
    /// worker threads rather than individual tasks.
    ///
    /// If this is called in a worker thread, then the worker thread
    /// will not be joined.
    ///
    /// Note that this function will close the thread pool because
    /// if the thread pool is not closed, worker threads are never
    /// be terminated.
    ///
    /// # Errors
    /// An error is returned if a thread panics.
    ///
    /// # Examples
    ///
    /// ```
    /// use jtp::ThreadPoolBuilder;
    /// use std::sync::atomic::{AtomicUsize, Ordering};
    /// use std::sync::Arc;
    ///
    /// let mut thread_pool = ThreadPoolBuilder::default()
    ///     .build();
    ///
    /// let sum = Arc::new(AtomicUsize::new(0));
    /// for _ in 0..10 {
    ///     let sum = sum.clone();
    ///     thread_pool.execute(move || {
    ///         // Increase `sum`.
    ///         sum.fetch_add(1, Ordering::SeqCst);
    ///     });
    /// }
    ///
    /// // Block current thread until all worker threads are finished.
    /// thread_pool.wait().unwrap();
    /// assert_eq!(10, sum.load(Ordering::Relaxed));
    /// ```
    pub fn wait(&self) -> std::thread::Result<()> {
        self.shutdown();
        // Taking the vectors out of the mutexes means each worker is joined
        // exactly once, even if `wait` is called from several handles.
        Self::wait_workers(self.share.core_workers.lock().unwrap().take())?;
        Self::wait_workers(self.share.workers.lock().unwrap().take())
    }

    // Joins every worker in the (possibly already-taken) vector, stopping
    // at the first panic payload.
    fn wait_workers(workers: Option<Vec<Worker>>) -> std::thread::Result<()> {
        if let Some(workers) = workers {
            for worker in workers {
                worker.join()?;
            }
        }
        Ok(())
    }

    // Wraps `task_fn` into a `Task` with a fresh id and the shared listeners.
    fn create_task(&self, task_fn: TaskFn) -> Task {
        let id = self
            .share
            .next_task_id
            .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
        Task::create(id, task_fn, self.task_lisenters.clone())
    }

    // Spawns a worker that runs `task` first, then keeps pulling from the
    // channel; `is_core` selects whether it may time out when idle.
    fn create_worker(&self, task: Task, is_core: bool) -> Worker {
        Worker::new(
            is_core,
            self.keep_alive_time,
            self.thread_factory.clone(),
            self.reciver.clone(),
            task,
        )
    }

    // Tries to enqueue the task; falls back to the overflow path when the
    // channel is full, or fails with `Closed` when the sender is gone.
    fn send_task(&self, task: Task) -> TPResult<()> {
        let sender = self.share.sender.lock().unwrap();
        if sender.is_none() {
            return Err(TPError::Closed);
        }

        if let Err(err) = sender.as_ref().unwrap().try_send(task) {
            // Release lock.
            drop(sender);
            return match err {
                TrySendError::Full(task) => self.process_task_if_channel_full(task),
                TrySendError::Disconnected(_) => Err(TPError::Closed),
            };
        }
        Ok(())
    }

    // Overflow path: reuse an idle non-core worker, grow the non-core pool,
    // or hand the task to the rejection handler.
    fn process_task_if_channel_full(&self, task: Task) -> TPResult<()> {
        let mut workers = self.share.workers.lock().unwrap();
        if workers.is_none() {
            // the `wait` function will take workers and close thread
            // pool but the task is accepted.
            return self.reject(task);
        }

        let non_core_workers = workers.as_mut().unwrap();
        // Attempt to find an idle worker.
        let idle_worker = non_core_workers
            .iter_mut()
            .find(|worker| worker.is_finished());
        if let Some(idle_worker) = idle_worker {
            idle_worker.restart(task);
            return Ok(());
        }

        // NOTE(review): assumes max_pool_size >= core_pool_size (stated to be
        // validated by the builder) — otherwise this subtraction underflows.
        if non_core_workers.len() < self.max_pool_size - self.core_pool_size {
            let worker = self.create_worker(task, false);
            non_core_workers.push(worker);
            Ok(())
        } else {
            // Release lock.
            drop(workers);
            // Reject the task if there is no place in the channel and
            // the size of the thread pool reaches max.
            self.reject(task)
        }
    }

    // Applies the configured rejection policy to a task that could not be
    // queued or assigned to a worker.
    fn reject(&self, task: Task) -> Result<(), TPError> {
        match &self.rejected_task_handler {
            RejectedTaskHandler::Abort => Err(TPError::Abort),
            RejectedTaskHandler::CallerRuns => {
                // Runs synchronously on the caller's thread.
                task.run();
                Ok(())
            }
            RejectedTaskHandler::Discard => Ok(()),
        }
    }
}
367
#[cfg(test)]
mod tests {

    use crate::{RejectedTaskHandler, ThreadPoolBuilder};
    use std::{
        collections::HashSet,
        sync::{
            atomic::{AtomicUsize, Ordering},
            Arc, Mutex,
        },
        thread,
        time::Duration,
    };

    // Submits 100 tasks from 10 OS threads and checks the pool's worker
    // accounting plus the final counter value after `wait`.
    #[test]
    fn test_execute_in_multiple_threads() {
        let thread_pool = ThreadPoolBuilder::default()
            .core_pool_size(4)
            .max_pool_size(10)
            .channel_capacity(100)
            .keep_alive_time(Duration::from_secs(100))
            .build();

        let sum = Arc::new(AtomicUsize::new(0));
        let mut handles = Vec::new();
        for _ in 0..10 {
            let sum = sum.clone();
            let thread_pool = thread_pool.clone();
            handles.push(thread::spawn(move || {
                for _ in 0..10 {
                    let sum = sum.clone();
                    thread_pool
                        .execute(move || {
                            sum.fetch_add(1, Ordering::SeqCst);
                        })
                        .ok();
                }
            }));
        }

        for handle in handles {
            handle.join().unwrap();
        }

        // check shared data.
        assert!(thread_pool.share.sender.lock().unwrap().is_some());
        assert_eq!(4, thread_pool.share.num_of_core_workers());
        // At most max_pool_size - core_pool_size non-core workers.
        assert!(thread_pool.share.num_of_active_workers() <= 6);

        thread_pool.wait().unwrap();
        assert_eq!(100, sum.load(Ordering::Relaxed));
    }

    // `shutdown` must be idempotent across handles: exactly one thread sees
    // the pool open (the very first one), the other 99 see it closed.
    #[test]
    fn test_shutdown_in_multiple_threads() {
        let thread_pool = ThreadPoolBuilder::default().build();
        let counter = Arc::new(AtomicUsize::new(0));
        let mut handles = Vec::new();
        for _ in 0..100 {
            let thread_pool = thread_pool.clone();
            let counter = counter.clone();
            handles.push(thread::spawn(move || {
                if thread_pool.is_closed() {
                    counter.fetch_add(1, Ordering::SeqCst);
                    assert!(thread_pool.execute(|| ()).is_err());
                }
                thread_pool.shutdown();
                assert!(thread_pool.execute(|| ()).is_err());
            }));
        }

        for handle in handles {
            handle.join().unwrap();
        }

        assert_eq!(99, counter.load(Ordering::Relaxed));
    }

    // The before/after listeners must both observe every task id.
    #[test]
    fn test_lisenters() {
        let map0 = Arc::new(Mutex::new(HashSet::new()));
        let map1 = map0.clone();
        let map2 = map0.clone();

        let thread_pool = ThreadPoolBuilder::default()
            .lisenter_before_execute(move |id| {
                let mut map = map0.lock().unwrap();
                map.insert(id);
            })
            .lisenter_after_execute(move |id| {
                // The before-listener must have recorded this id already.
                assert!(map1.lock().unwrap().contains(&id));
            })
            .channel_capacity(50)
            .build();

        for _ in 0..50 {
            thread_pool
                .execute(|| {
                    thread::sleep(Duration::from_millis(20));
                })
                .unwrap();
        }
        thread_pool.shutdown();
        thread_pool.wait().unwrap();
        assert_eq!(50, map2.lock().unwrap().len());
    }

    // Every worker (core and non-core) must be spawned through the custom
    // thread factory, observable via the thread name.
    #[test]
    fn test_thread_factory() {
        let thread_pool = ThreadPoolBuilder::new()
            .thread_factory_fn(|| thread::Builder::new().name("test".into()))
            .core_pool_size(2)
            .max_pool_size(5)
            .channel_capacity(5)
            .rejected_handler(RejectedTaskHandler::Discard)
            .build();

        // 20 slow tasks overflow the 5-slot channel, forcing non-core
        // workers to be created (overflow beyond 5 threads is discarded).
        for _ in 0..20 {
            thread_pool
                .execute(|| thread::sleep(Duration::from_millis(20)))
                .unwrap();
        }

        let workers = thread_pool.share.core_workers.lock().unwrap();
        assert!(workers.as_ref().unwrap().len() == 2);
        for core_worker in workers.as_ref().unwrap() {
            assert_eq!(Some("test"), core_worker.handle.thread().name());
        }

        let workers = thread_pool.share.workers.lock().unwrap();
        assert!(workers.as_ref().unwrap().len() == 3);
        for worker in workers.as_ref().unwrap() {
            assert_eq!(Some("test"), worker.handle.thread().name());
        }
    }
}