futures_executor/
thread_pool.rs

1use std::prelude::v1::*;
2
3use std::io;
4use std::sync::{Arc, Mutex};
5use std::sync::atomic::{AtomicUsize, Ordering};
6use std::sync::mpsc;
7use std::thread;
8use std::fmt;
9
10use futures_core::*;
11use futures_core::task::{self, Wake, Waker, LocalMap};
12use futures_core::executor::{Executor, SpawnError};
13use futures_core::never::Never;
14
15use enter;
16use num_cpus;
17use unpark_mutex::UnparkMutex;
18
19/// A general-purpose thread pool for scheduling asynchronous tasks.
20///
21/// The thread pool multiplexes any number of tasks onto a fixed number of
22/// worker threads.
23///
24/// This type is a clonable handle to the threadpool itself.
25/// Cloning it will only create a new reference, not a new threadpool.
26pub struct ThreadPool {
27    state: Arc<PoolState>,
28}
29
30/// Thread pool configuration object.
31pub struct ThreadPoolBuilder {
32    pool_size: usize,
33    stack_size: usize,
34    name_prefix: Option<String>,
35    after_start: Option<Arc<Fn(usize) + Send + Sync>>,
36    before_stop: Option<Arc<Fn(usize) + Send + Sync>>,
37}
38
39trait AssertSendSync: Send + Sync {}
40impl AssertSendSync for ThreadPool {}
41
42struct PoolState {
43    tx: Mutex<mpsc::Sender<Message>>,
44    rx: Mutex<mpsc::Receiver<Message>>,
45    cnt: AtomicUsize,
46    size: usize,
47}
48
49impl fmt::Debug for ThreadPool {
50    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
51        f.debug_struct("ThreadPool")
52            .field("size", &self.state.size)
53            .finish()
54    }
55}
56
57impl fmt::Debug for ThreadPoolBuilder {
58    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
59        f.debug_struct("ThreadPoolBuilder")
60            .field("pool_size", &self.pool_size)
61            .field("name_prefix", &self.name_prefix)
62            .finish()
63    }
64}
65
66enum Message {
67    Run(Task),
68    Close,
69}
70
71impl ThreadPool {
72    /// Creates a new thread pool with the default configuration.
73    ///
74    /// See documentation for the methods in
75    /// [`ThreadPoolBuilder`](::ThreadPoolBuilder) for details on the default
76    /// configuration.
77    pub fn new() -> Result<ThreadPool, io::Error> {
78        ThreadPoolBuilder::new().create()
79    }
80
81    /// Create a default thread pool configuration, which can then be customized.
82    ///
83    /// See documentation for the methods in
84    /// [`ThreadPoolBuilder`](::ThreadPoolBuilder) for details on the default
85    /// configuration.
86    pub fn builder() -> ThreadPoolBuilder {
87        ThreadPoolBuilder::new()
88    }
89
90    /// Runs the given future with this thread pool as the default executor for
91    /// spawning tasks.
92    ///
93    /// **This function will block the calling thread** until the given future
94    /// is complete. While executing that future, any tasks spawned onto the
95    /// default executor will be routed to this thread pool.
96    ///
97    /// Note that the function will return when the provided future completes,
98    /// even if some of the tasks it spawned are still running.
99    pub fn run<F: Future>(&mut self, f: F) -> Result<F::Item, F::Error> {
100        ::LocalPool::new().run_until(f, self)
101    }
102}
103
104impl Executor for ThreadPool {
105    fn spawn(&mut self, f: Box<Future<Item = (), Error = Never> + Send>) -> Result<(), SpawnError> {
106        let task = Task {
107            spawn: f,
108            map: LocalMap::new(),
109            wake_handle: Arc::new(WakeHandle {
110                exec: self.clone(),
111                mutex: UnparkMutex::new(),
112            }),
113            exec: self.clone(),
114        };
115        self.state.send(Message::Run(task));
116        Ok(())
117    }
118}
119
120impl PoolState {
121    fn send(&self, msg: Message) {
122        self.tx.lock().unwrap().send(msg).unwrap();
123    }
124
125    fn work(&self,
126            idx: usize,
127            after_start: Option<Arc<Fn(usize) + Send + Sync>>,
128            before_stop: Option<Arc<Fn(usize) + Send + Sync>>) {
129        let _scope = enter().unwrap();
130        after_start.map(|fun| fun(idx));
131        loop {
132            let msg = self.rx.lock().unwrap().recv().unwrap();
133            match msg {
134                Message::Run(r) => r.run(),
135                Message::Close => break,
136            }
137        }
138        before_stop.map(|fun| fun(idx));
139    }
140}
141
142impl Clone for ThreadPool {
143    fn clone(&self) -> ThreadPool {
144        self.state.cnt.fetch_add(1, Ordering::Relaxed);
145        ThreadPool { state: self.state.clone() }
146    }
147}
148
149impl Drop for ThreadPool {
150    fn drop(&mut self) {
151        if self.state.cnt.fetch_sub(1, Ordering::Relaxed) == 1 {
152            for _ in 0..self.state.size {
153                self.state.send(Message::Close);
154            }
155        }
156    }
157}
158
159impl ThreadPoolBuilder {
160    /// Create a default thread pool configuration.
161    ///
162    /// See the other methods on this type for details on the defaults.
163    pub fn new() -> ThreadPoolBuilder {
164        ThreadPoolBuilder {
165            pool_size: num_cpus::get(),
166            stack_size: 0,
167            name_prefix: None,
168            after_start: None,
169            before_stop: None,
170        }
171    }
172
173    /// Set size of a future ThreadPool
174    ///
175    /// The size of a thread pool is the number of worker threads spawned.  By
176    /// default, this is equal to the number of CPU cores.
177    pub fn pool_size(&mut self, size: usize) -> &mut Self {
178        self.pool_size = size;
179        self
180    }
181
182    /// Set stack size of threads in the pool.
183    ///
184    /// By default, worker threads use Rust's standard stack size.
185    pub fn stack_size(&mut self, stack_size: usize) -> &mut Self {
186        self.stack_size = stack_size;
187        self
188    }
189
190    /// Set thread name prefix of a future ThreadPool.
191    ///
192    /// Thread name prefix is used for generating thread names. For example, if prefix is
193    /// `my-pool-`, then threads in the pool will get names like `my-pool-1` etc.
194    ///
195    /// By default, worker threads are assigned Rust's standard thread name.
196    pub fn name_prefix<S: Into<String>>(&mut self, name_prefix: S) -> &mut Self {
197        self.name_prefix = Some(name_prefix.into());
198        self
199    }
200
201    /// Execute the closure `f` immediately after each worker thread is started,
202    /// but before running any tasks on it.
203    ///
204    /// This hook is intended for bookkeeping and monitoring.
205    /// The closure `f` will be dropped after the `builder` is dropped
206    /// and all worker threads in the pool have executed it.
207    ///
208    /// The closure provided will receive an index corresponding to the worker
209    /// thread it's running on.
210    pub fn after_start<F>(&mut self, f: F) -> &mut Self
211        where F: Fn(usize) + Send + Sync + 'static
212    {
213        self.after_start = Some(Arc::new(f));
214        self
215    }
216
217    /// Execute closure `f` just prior to shutting down each worker thread.
218    ///
219    /// This hook is intended for bookkeeping and monitoring.
220    /// The closure `f` will be dropped after the `builder` is droppped
221    /// and all threads in the pool have executed it.
222    ///
223    /// The closure provided will receive an index corresponding to the worker
224    /// thread it's running on.
225    pub fn before_stop<F>(&mut self, f: F) -> &mut Self
226        where F: Fn(usize) + Send + Sync + 'static
227    {
228        self.before_stop = Some(Arc::new(f));
229        self
230    }
231
232    /// Create a [`ThreadPool`](::ThreadPool) with the given configuration.
233    ///
234    /// # Panics
235    ///
236    /// Panics if `pool_size == 0`.
237    pub fn create(&mut self) -> Result<ThreadPool, io::Error> {
238        let (tx, rx) = mpsc::channel();
239        let pool = ThreadPool {
240            state: Arc::new(PoolState {
241                tx: Mutex::new(tx),
242                rx: Mutex::new(rx),
243                cnt: AtomicUsize::new(1),
244                size: self.pool_size,
245            }),
246        };
247        assert!(self.pool_size > 0);
248
249        for counter in 0..self.pool_size {
250            let state = pool.state.clone();
251            let after_start = self.after_start.clone();
252            let before_stop = self.before_stop.clone();
253            let mut thread_builder = thread::Builder::new();
254            if let Some(ref name_prefix) = self.name_prefix {
255                thread_builder = thread_builder.name(format!("{}{}", name_prefix, counter));
256            }
257            if self.stack_size > 0 {
258                thread_builder = thread_builder.stack_size(self.stack_size);
259            }
260            thread_builder.spawn(move || state.work(counter, after_start, before_stop))?;
261        }
262        Ok(pool)
263    }
264}
265
266/// Units of work submitted to an `Executor`, currently only created
267/// internally.
268struct Task {
269    spawn: Box<Future<Item = (), Error = Never> + Send>,
270    map: LocalMap,
271    exec: ThreadPool,
272    wake_handle: Arc<WakeHandle>,
273}
274
275struct WakeHandle {
276    mutex: UnparkMutex<Task>,
277    exec: ThreadPool,
278}
279
280impl Task {
281    /// Actually run the task (invoking `poll` on its future) on the current
282    /// thread.
283    pub fn run(self) {
284        let Task { mut spawn, wake_handle, mut map, mut exec } = self;
285        let waker = Waker::from(wake_handle.clone());
286
287        // SAFETY: the ownership of this `Task` object is evidence that
288        // we are in the `POLLING`/`REPOLL` state for the mutex.
289        unsafe {
290            wake_handle.mutex.start_poll();
291
292            loop {
293                let res = {
294                    let mut cx = task::Context::new(&mut map, &waker, &mut exec);
295                    spawn.poll(&mut cx)
296                };
297                match res {
298                    Ok(Async::Pending) => {}
299                    Ok(Async::Ready(())) => return wake_handle.mutex.complete(),
300                    Err(never) => match never {},
301                }
302                let task = Task {
303                    spawn,
304                    map,
305                    wake_handle: wake_handle.clone(),
306                    exec: exec
307                };
308                match wake_handle.mutex.wait(task) {
309                    Ok(()) => return,            // we've waited
310                    Err(r) => { // someone's notified us
311                        spawn = r.spawn;
312                        map = r.map;
313                        exec = r.exec;
314                    }
315                }
316            }
317        }
318    }
319}
320
321impl fmt::Debug for Task {
322    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
323        f.debug_struct("Task")
324            .field("contents", &"...")
325            .finish()
326    }
327}
328
329impl Wake for WakeHandle {
330    fn wake(arc_self: &Arc<Self>) {
331        match arc_self.mutex.notify() {
332            Ok(task) => arc_self.exec.state.send(Message::Run(task)),
333            Err(()) => {}
334        }
335    }
336}
337
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::mpsc;

    /// Checks that the `after_start` hook runs once per worker and is dropped
    /// afterwards.
    #[test]
    fn test_drop_after_start() {
        let (tx, rx) = mpsc::sync_channel(2);

        // The builder is deliberately a temporary of this statement: it holds
        // a reference to the hook (and thus to `tx`) and must be gone before
        // the receiver is drained below.
        let _cpu_pool = ThreadPool::builder()
            .pool_size(2)
            .after_start(move |_| tx.send(1).unwrap())
            .create()
            .unwrap();

        // After the ThreadPoolBuilder is destroyed and both workers have run
        // the hook, `tx` should be dropped, so iterating over `rx` terminates.
        let started = rx.into_iter().count();
        assert_eq!(started, 2);
    }
}