Skip to main content

tokio_thread_pool/
lib.rs

1use std::sync::Arc;
2
3use tokio::{
4    runtime::{self, Runtime},
5    sync::Semaphore,
6    task::JoinHandle,
7};
8
9/// A small wrapper around the tokio runtime supporting multithreading with max concurrency limits
10///
11/// ```rust
12/// use tokio_thread_pool::ThreadPool;
13///
14/// // Create a pool with default settings
15/// let my_pool = ThreadPool::new(
16///   None, // optional max task concurrency (usize),
17///   None, // optional max number of threads defaulting to the number of CPU cores on the system (usize),
18///   None, // an optional tokio runtime that you provide with your own custom settings (tokio::Runtime)
19/// );
20///
21/// // Create a pool with a limit on task concurrency
22/// let my_pool = ThreadPool::new(Some(10), None, None); // maximimum of ten concurrent tasks running at once
23///
24/// // Create a pool with a limit on spawned threads
25/// let my_pool = ThreadPool::new(None, Some(4), None); // maximimum of four threads for task allocation
26///
27/// // Create a pool with your own runtime provided
28/// let my_pool = ThreadPool::new(
29///   None,
30///   None,
31///   Some(tokio::runtime::Builder::new_multi_thread().build().unwrap())
32/// );
33///
34/// // Spawn async tasks
35/// let handle = my_pool.spawn(async move || {}); // return any value
36/// // Spawn sync tasks
37/// let handle = my_pool.spawn_blocking(move || {}); // return any value
38///
39/// // Get result
40/// let result = handle.await;
41/// ```
42pub struct ThreadPool {
43    pub pool: Runtime,
44    semaphore: Arc<Semaphore>,
45}
46
47impl ThreadPool {
48    /// Constructs a new ThreadPool instance
49    ///
50    /// #### Arguments
51    ///
52    /// `max_concurrency` (`Option<usize>`): optional max task concurrency
53    ///
54    /// `max_threads` (`Option<usize>`): optional max number of threads defaulting to the number of CPU cores on the system
55    ///
56    /// `pool_override` (`Option<tokio::Runtime>`): an optional tokio runtime that you provide with your own custom settings
57    ///
58    /// ```rust
59    /// // Create a pool with default settings
60    /// let my_pool = ThreadPool::new(None, None, None);
61    ///
62    /// // Create a pool with a limit on task concurrency
63    /// let my_pool = ThreadPool::new(Some(10), None, None); // maximimum of ten concurrent tasks running at once
64    ///
65    /// // Create a pool with a limit on spawned threads
66    /// let my_pool = ThreadPool::new(None, Some(4), None); // maximimum of four threads for task allocation
67    ///
68    /// // Create a pool with your own runtime provided
69    /// let my_pool = ThreadPool::new(
70    ///   None,
71    ///   None,
72    ///   Some(tokio::runtime::Builder::new_multi_thread().build().unwrap())
73    /// );
74    /// ```
75    pub fn new(
76        max_concurrency: Option<usize>,
77        max_threads: Option<usize>,
78        pool_override: Option<Runtime>,
79    ) -> ThreadPool {
80        let pool = pool_override.unwrap_or(ThreadPool::create_pool(max_threads));
81        let semaphore = Arc::new(Semaphore::new(
82            max_concurrency.unwrap_or(Semaphore::MAX_PERMITS),
83        ));
84        ThreadPool { pool, semaphore }
85    }
86
87    /// Spawns an async task and returns its `Handler<T>`
88    ///
89    /// #### Arguments
90    ///
91    /// `task` (`(Fn() -> T)`): The task to execute inside of the thread pool
92    ///
93    /// ```rust
94    /// // Create a pool with default settings
95    /// let my_pool = ThreadPool::new(None, None, None);
96    ///
97    /// let my_handle = my_pool.spawn(async move || {});
98    ///
99    /// let result = my_handle.await;
100    /// ```
101    pub fn spawn<T: Send + 'static, F: (Fn() -> T) + 'static + Send>(
102        &mut self,
103        task: F,
104    ) -> JoinHandle<T> {
105        let concurrecy = self.semaphore.clone();
106        self.pool.spawn(async move {
107            let _ticket = concurrecy.acquire().await.unwrap();
108            task()
109        })
110    }
111
112    /// Spawns a synchronous task and returns its `Handler<T>`
113    ///
114    /// #### Arguments
115    ///
116    /// `task` (`(Fn() -> T)`): The task to execute inside of the thread pool
117    ///
118    /// ```rust
119    /// // Create a pool with default settings
120    /// let my_pool = ThreadPool::new(None, None, None);
121    ///
122    /// let my_handle = my_pool.spawn_blocking(async move || {});
123    ///
124    /// let result = my_handle.await;
125    /// ```
126    pub fn spawn_blocking<T: Send + 'static, F: (Fn() -> T) + 'static + Send>(
127        &mut self,
128        task: F,
129    ) -> JoinHandle<T> {
130        self.pool.spawn_blocking(task)
131    }
132
133    fn create_pool(threads: Option<usize>) -> Runtime {
134        let mut pool = runtime::Builder::new_multi_thread();
135        pool.enable_all();
136        match threads {
137            Some(size) => pool.worker_threads(size),
138            None => &pool,
139        };
140        pool.build().unwrap()
141    }
142}