futures_executor/thread_pool.rs

use crate::enter;
use crate::unpark_mutex::UnparkMutex;
use futures_core::future::Future;
use futures_core::task::{Context, Poll};
use futures_task::{waker_ref, ArcWake};
use futures_task::{FutureObj, Spawn, SpawnError};
use futures_util::future::FutureExt;
use std::boxed::Box;
use std::fmt;
use std::format;
use std::io;
use std::string::String;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::mpsc;
use std::sync::{Arc, Mutex};
use std::thread;

/// A general-purpose thread pool for scheduling tasks that poll futures to
/// completion.
///
/// The thread pool multiplexes any number of tasks onto a fixed number of
/// worker threads.
///
/// This type is a cloneable handle to the thread pool itself.
/// Cloning it will only create a new reference, not a new thread pool.
///
/// This type is only available when the `thread-pool` feature of this
/// library is activated.
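///
/// A minimal sketch of sharing a pool between handles (illustrative use only;
/// it assumes the `futures` facade re-exports this type as
/// `futures::executor::ThreadPool`):
///
/// ```
/// # {
/// use futures::executor::ThreadPool;
///
/// let pool = ThreadPool::new().unwrap();
///
/// // Cloning only bumps a reference count; both handles drive the
/// // same set of worker threads.
/// let handle = pool.clone();
/// handle.spawn_ok(async { /* ... */ });
/// # }
/// # std::thread::sleep(std::time::Duration::from_millis(500)); // wait for background threads closed: https://github.com/rust-lang/miri/issues/1371
/// ```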
#[cfg_attr(docsrs, doc(cfg(feature = "thread-pool")))]
pub struct ThreadPool {
    state: Arc<PoolState>,
}

/// Thread pool configuration object.
///
/// This type is only available when the `thread-pool` feature of this
/// library is activated.
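///
/// A sketch of a typical configuration chain (the sizes and prefix are
/// illustrative; it assumes the `futures` facade re-exports this type as
/// `futures::executor::ThreadPoolBuilder`):
///
/// ```
/// # {
/// use futures::executor::ThreadPoolBuilder;
///
/// let pool = ThreadPoolBuilder::new()
///     .pool_size(2)
///     .name_prefix("my-pool-")
///     .create()
///     .unwrap();
///
/// pool.spawn_ok(async { /* ... */ });
/// # }
/// # std::thread::sleep(std::time::Duration::from_millis(500)); // wait for background threads closed: https://github.com/rust-lang/miri/issues/1371
/// ```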
#[cfg_attr(docsrs, doc(cfg(feature = "thread-pool")))]
pub struct ThreadPoolBuilder {
    pool_size: usize,
    stack_size: usize,
    name_prefix: Option<String>,
    after_start: Option<Arc<dyn Fn(usize) + Send + Sync>>,
    before_stop: Option<Arc<dyn Fn(usize) + Send + Sync>>,
}

#[allow(dead_code)]
trait AssertSendSync: Send + Sync {}
impl AssertSendSync for ThreadPool {}

struct PoolState {
    tx: Mutex<mpsc::Sender<Message>>,
    rx: Mutex<mpsc::Receiver<Message>>,
    cnt: AtomicUsize,
    size: usize,
}

impl fmt::Debug for ThreadPool {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("ThreadPool").field("size", &self.state.size).finish()
    }
}

impl fmt::Debug for ThreadPoolBuilder {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("ThreadPoolBuilder")
            .field("pool_size", &self.pool_size)
            .field("name_prefix", &self.name_prefix)
            .finish()
    }
}

enum Message {
    Run(Task),
    Close,
}

impl ThreadPool {
    /// Creates a new thread pool with the default configuration.
    ///
    /// See documentation for the methods in
    /// [`ThreadPoolBuilder`](ThreadPoolBuilder) for details on the default
    /// configuration.
    pub fn new() -> Result<Self, io::Error> {
        ThreadPoolBuilder::new().create()
    }

    /// Create a default thread pool configuration, which can then be customized.
    ///
    /// See documentation for the methods in
    /// [`ThreadPoolBuilder`](ThreadPoolBuilder) for details on the default
    /// configuration.
    pub fn builder() -> ThreadPoolBuilder {
        ThreadPoolBuilder::new()
    }

    /// Spawns a future that will be run to completion.
    ///
    /// > **Note**: This method is similar to `Spawn::spawn_obj`, except that
    /// >           it is guaranteed to always succeed.
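    ///
    /// A minimal sketch of spawning a boxed future (it assumes the `futures`
    /// facade re-exports `FutureObj` as `futures::task::FutureObj`):
    ///
    /// ```
    /// # {
    /// use futures::executor::ThreadPool;
    /// use futures::task::FutureObj;
    ///
    /// let pool = ThreadPool::new().unwrap();
    ///
    /// let future = FutureObj::new(Box::new(async { /* ... */ }));
    /// pool.spawn_obj_ok(future);
    /// # }
    /// # std::thread::sleep(std::time::Duration::from_millis(500)); // wait for background threads closed: https://github.com/rust-lang/miri/issues/1371
    /// ```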
    pub fn spawn_obj_ok(&self, future: FutureObj<'static, ()>) {
        let task = Task {
            future,
            wake_handle: Arc::new(WakeHandle { exec: self.clone(), mutex: UnparkMutex::new() }),
            exec: self.clone(),
        };
        self.state.send(Message::Run(task));
    }

    /// Spawns a task that polls the given future with output `()` to
    /// completion.
    ///
    /// ```
    /// # {
    /// use futures::executor::ThreadPool;
    ///
    /// let pool = ThreadPool::new().unwrap();
    ///
    /// let future = async { /* ... */ };
    /// pool.spawn_ok(future);
    /// # }
    /// # std::thread::sleep(std::time::Duration::from_millis(500)); // wait for background threads closed: https://github.com/rust-lang/miri/issues/1371
    /// ```
    ///
    /// > **Note**: This method is similar to `SpawnExt::spawn`, except that
    /// >           it is guaranteed to always succeed.
    pub fn spawn_ok<Fut>(&self, future: Fut)
    where
        Fut: Future<Output = ()> + Send + 'static,
    {
        self.spawn_obj_ok(FutureObj::new(Box::new(future)))
    }
}

impl Spawn for ThreadPool {
    fn spawn_obj(&self, future: FutureObj<'static, ()>) -> Result<(), SpawnError> {
        self.spawn_obj_ok(future);
        Ok(())
    }
}

impl PoolState {
    fn send(&self, msg: Message) {
        self.tx.lock().unwrap().send(msg).unwrap();
    }

    fn work(
        &self,
        idx: usize,
        after_start: Option<Arc<dyn Fn(usize) + Send + Sync>>,
        before_stop: Option<Arc<dyn Fn(usize) + Send + Sync>>,
    ) {
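        // Mark this thread as running an executor for the duration of `work`;
        // `enter()` errors if the thread has already entered an executor
        // context, so the `unwrap` guards against accidental nesting.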
        let _scope = enter().unwrap();
        if let Some(after_start) = after_start {
            after_start(idx);
        }
        loop {
            let msg = self.rx.lock().unwrap().recv().unwrap();
            match msg {
                Message::Run(task) => task.run(),
                Message::Close => break,
            }
        }
        if let Some(before_stop) = before_stop {
            before_stop(idx);
        }
    }
}

impl Clone for ThreadPool {
    fn clone(&self) -> Self {
        self.state.cnt.fetch_add(1, Ordering::Relaxed);
        Self { state: self.state.clone() }
    }
}

impl Drop for ThreadPool {
    fn drop(&mut self) {
        // If this was the last handle to the pool, send one `Close` message
        // per worker so that every worker thread eventually exits its loop.
        if self.state.cnt.fetch_sub(1, Ordering::Relaxed) == 1 {
            for _ in 0..self.state.size {
                self.state.send(Message::Close);
            }
        }
    }
}

impl ThreadPoolBuilder {
    /// Create a default thread pool configuration.
    ///
    /// See the other methods on this type for details on the defaults.
    pub fn new() -> Self {
        let pool_size = thread::available_parallelism().map_or(1, |p| p.get());
        Self { pool_size, stack_size: 0, name_prefix: None, after_start: None, before_stop: None }
    }

    /// Set the size of the future `ThreadPool`.
    ///
    /// The size of a thread pool is the number of worker threads spawned. By
    /// default, this is the parallelism reported by
    /// `std::thread::available_parallelism` (typically the number of CPU
    /// cores).
    ///
    /// # Panics
    ///
    /// Panics if `pool_size == 0`.
    pub fn pool_size(&mut self, size: usize) -> &mut Self {
        assert!(size > 0);
        self.pool_size = size;
        self
    }

    /// Set the stack size, in bytes, of worker threads in the pool.
    ///
    /// By default, worker threads use Rust's standard stack size.
    pub fn stack_size(&mut self, stack_size: usize) -> &mut Self {
        self.stack_size = stack_size;
        self
    }

    /// Set the thread name prefix for the future `ThreadPool`.
    ///
    /// The prefix is used to generate thread names. For example, with the
    /// prefix `my-pool-`, worker threads are named `my-pool-0`, `my-pool-1`,
    /// and so on.
    ///
    /// By default, worker threads are assigned Rust's standard thread name.
    pub fn name_prefix<S: Into<String>>(&mut self, name_prefix: S) -> &mut Self {
        self.name_prefix = Some(name_prefix.into());
        self
    }

    /// Execute the closure `f` immediately after each worker thread is started,
    /// but before running any tasks on it.
    ///
    /// This hook is intended for bookkeeping and monitoring.
    /// The closure `f` will be dropped after the `builder` is dropped
    /// and all worker threads in the pool have executed it.
    ///
    /// The closure provided will receive an index corresponding to the worker
    /// thread it's running on.
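    ///
    /// A sketch of counting started workers from this hook (the counter and
    /// pool size are illustrative; it assumes the `futures` facade re-exports
    /// this type as `futures::executor::ThreadPoolBuilder`):
    ///
    /// ```
    /// # {
    /// use std::sync::Arc;
    /// use std::sync::atomic::{AtomicUsize, Ordering};
    /// use futures::executor::ThreadPoolBuilder;
    ///
    /// let started = Arc::new(AtomicUsize::new(0));
    /// let counter = started.clone();
    ///
    /// let _pool = ThreadPoolBuilder::new()
    ///     .pool_size(2)
    ///     .after_start(move |_idx| { counter.fetch_add(1, Ordering::SeqCst); })
    ///     .create()
    ///     .unwrap();
    /// # }
    /// # std::thread::sleep(std::time::Duration::from_millis(500)); // wait for background threads closed: https://github.com/rust-lang/miri/issues/1371
    /// ```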
    pub fn after_start<F>(&mut self, f: F) -> &mut Self
    where
        F: Fn(usize) + Send + Sync + 'static,
    {
        self.after_start = Some(Arc::new(f));
        self
    }

    /// Execute closure `f` just prior to shutting down each worker thread.
    ///
    /// This hook is intended for bookkeeping and monitoring.
    /// The closure `f` will be dropped after the `builder` is dropped
    /// and all threads in the pool have executed it.
    ///
    /// The closure provided will receive an index corresponding to the worker
    /// thread it's running on.
    pub fn before_stop<F>(&mut self, f: F) -> &mut Self
    where
        F: Fn(usize) + Send + Sync + 'static,
    {
        self.before_stop = Some(Arc::new(f));
        self
    }

    /// Create a [`ThreadPool`](ThreadPool) with the given configuration.
    pub fn create(&mut self) -> Result<ThreadPool, io::Error> {
        let (tx, rx) = mpsc::channel();
        let pool = ThreadPool {
            state: Arc::new(PoolState {
                tx: Mutex::new(tx),
                rx: Mutex::new(rx),
                cnt: AtomicUsize::new(1),
                size: self.pool_size,
            }),
        };

        for counter in 0..self.pool_size {
            let state = pool.state.clone();
            let after_start = self.after_start.clone();
            let before_stop = self.before_stop.clone();
            let mut thread_builder = thread::Builder::new();
            if let Some(ref name_prefix) = self.name_prefix {
                thread_builder = thread_builder.name(format!("{name_prefix}{counter}"));
            }
            if self.stack_size > 0 {
                thread_builder = thread_builder.stack_size(self.stack_size);
            }
            thread_builder.spawn(move || state.work(counter, after_start, before_stop))?;
        }
        Ok(pool)
    }
}

impl Default for ThreadPoolBuilder {
    fn default() -> Self {
        Self::new()
    }
}

/// A task responsible for polling a future to completion.
struct Task {
    future: FutureObj<'static, ()>,
    exec: ThreadPool,
    wake_handle: Arc<WakeHandle>,
}

struct WakeHandle {
    mutex: UnparkMutex<Task>,
    exec: ThreadPool,
}

impl Task {
    /// Actually run the task (invoking `poll` on the future) on the current
    /// thread.
    fn run(self) {
        let Self { mut future, wake_handle, mut exec } = self;
        let waker = waker_ref(&wake_handle);
        let mut cx = Context::from_waker(&waker);

        // Safety: The ownership of this `Task` object is evidence that
        // we are in the `POLLING`/`REPOLL` state for the mutex.
        unsafe {
            wake_handle.mutex.start_poll();

            loop {
                let res = future.poll_unpin(&mut cx);
                match res {
                    Poll::Pending => {}
                    Poll::Ready(()) => return wake_handle.mutex.complete(),
                }
                // The future is pending: rebuild the task and try to park it
                // in the mutex until a wakeup arrives.
                let task = Self { future, wake_handle: wake_handle.clone(), exec };
                match wake_handle.mutex.wait(task) {
                    // Parked successfully; a wakeup will re-send the task to
                    // the pool via `ArcWake::wake_by_ref`.
                    Ok(()) => return,
                    // A wakeup raced in before the task could be parked: take
                    // it back and poll the future again.
                    Err(task) => {
                        future = task.future;
                        exec = task.exec;
                    }
                }
            }
        }
    }
}

impl fmt::Debug for Task {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("Task").field("contents", &"...").finish()
    }
}

impl ArcWake for WakeHandle {
    fn wake_by_ref(arc_self: &Arc<Self>) {
        // `notify` hands the task back only if it was parked waiting; if so,
        // re-enqueue it on the pool so that a worker thread polls it again.
        if let Ok(task) = arc_self.mutex.notify() {
            arc_self.exec.state.send(Message::Run(task))
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_drop_after_start() {
        {
            let (tx, rx) = mpsc::sync_channel(2);
            let _cpu_pool = ThreadPoolBuilder::new()
                .pool_size(2)
                .after_start(move |_| tx.send(1).unwrap())
                .create()
                .unwrap();

            // Once the builder is dropped and both workers have run the hook,
            // the closure (and thus `tx`) is dropped, so `rx` can be drained
            // as a finite iterator.
            let count = rx.into_iter().count();
            assert_eq!(count, 2);
        }
        std::thread::sleep(std::time::Duration::from_millis(500)); // wait for background threads closed: https://github.com/rust-lang/miri/issues/1371
    }
}