1use std::prelude::v1::*;
2
3use std::io;
4use std::sync::{Arc, Mutex};
5use std::sync::atomic::{AtomicUsize, Ordering};
6use std::sync::mpsc;
7use std::thread;
8use std::fmt;
9
10use futures_core::*;
11use futures_core::task::{self, Wake, Waker, LocalMap};
12use futures_core::executor::{Executor, SpawnError};
13use futures_core::never::Never;
14
15use enter;
16use num_cpus;
17use unpark_mutex::UnparkMutex;
18
/// A handle to a pool of worker threads that poll spawned futures.
///
/// Handles are cheap to clone: every clone shares the same `PoolState` and bumps its
/// handle count. When the last handle is dropped, one `Close` message is sent per worker
/// thread. Spawned tasks hold handles of their own, so the pool stays alive while
/// they exist.
pub struct ThreadPool {
    state: Arc<PoolState>,
}
29
/// Configuration for a [`ThreadPool`]; call `create` to spawn the worker threads.
pub struct ThreadPoolBuilder {
    /// Number of worker threads to spawn (defaults to `num_cpus::get()`).
    pool_size: usize,
    /// Stack size for each worker thread; `0` leaves the platform default in place.
    stack_size: usize,
    /// When set, worker `i` is named `format!("{}{}", prefix, i)`.
    name_prefix: Option<String>,
    /// Hook called on each worker thread, with its index, before it processes any message.
    after_start: Option<Arc<Fn(usize) + Send + Sync>>,
    /// Hook called on each worker thread, with its index, after it receives `Close`.
    before_stop: Option<Arc<Fn(usize) + Send + Sync>>,
}
38
// Compile-time check only: this impl type-checks only if `ThreadPool: Send + Sync`,
// so a regression that makes the handle non-thread-safe fails to build.
trait AssertSendSync: Send + Sync {}
impl AssertSendSync for ThreadPool {}
41
/// State shared by every `ThreadPool` handle and every worker thread.
struct PoolState {
    /// Sending half of the work queue, locked so it can be used through a shared `&PoolState`.
    tx: Mutex<mpsc::Sender<Message>>,
    /// Receiving half of the work queue. Workers lock it around `recv`, so at most one
    /// idle worker blocks in `recv` at a time.
    rx: Mutex<mpsc::Receiver<Message>>,
    /// Number of live `ThreadPool` handles, including the ones held by spawned tasks.
    cnt: AtomicUsize,
    /// Number of worker threads; this many `Close` messages are sent on shutdown.
    size: usize,
}
48
49impl fmt::Debug for ThreadPool {
50 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
51 f.debug_struct("ThreadPool")
52 .field("size", &self.state.size)
53 .finish()
54 }
55}
56
57impl fmt::Debug for ThreadPoolBuilder {
58 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
59 f.debug_struct("ThreadPoolBuilder")
60 .field("pool_size", &self.pool_size)
61 .field("name_prefix", &self.name_prefix)
62 .finish()
63 }
64}
65
/// Messages delivered to worker threads over the shared channel.
enum Message {
    /// Poll the contained task on whichever worker receives it.
    Run(Task),
    /// Make the receiving worker leave its loop. One is sent per worker when the last
    /// `ThreadPool` handle is dropped.
    Close,
}
70
71impl ThreadPool {
72 pub fn new() -> Result<ThreadPool, io::Error> {
78 ThreadPoolBuilder::new().create()
79 }
80
81 pub fn builder() -> ThreadPoolBuilder {
87 ThreadPoolBuilder::new()
88 }
89
90 pub fn run<F: Future>(&mut self, f: F) -> Result<F::Item, F::Error> {
100 ::LocalPool::new().run_until(f, self)
101 }
102}
103
104impl Executor for ThreadPool {
105 fn spawn(&mut self, f: Box<Future<Item = (), Error = Never> + Send>) -> Result<(), SpawnError> {
106 let task = Task {
107 spawn: f,
108 map: LocalMap::new(),
109 wake_handle: Arc::new(WakeHandle {
110 exec: self.clone(),
111 mutex: UnparkMutex::new(),
112 }),
113 exec: self.clone(),
114 };
115 self.state.send(Message::Run(task));
116 Ok(())
117 }
118}
119
120impl PoolState {
121 fn send(&self, msg: Message) {
122 self.tx.lock().unwrap().send(msg).unwrap();
123 }
124
125 fn work(&self,
126 idx: usize,
127 after_start: Option<Arc<Fn(usize) + Send + Sync>>,
128 before_stop: Option<Arc<Fn(usize) + Send + Sync>>) {
129 let _scope = enter().unwrap();
130 after_start.map(|fun| fun(idx));
131 loop {
132 let msg = self.rx.lock().unwrap().recv().unwrap();
133 match msg {
134 Message::Run(r) => r.run(),
135 Message::Close => break,
136 }
137 }
138 before_stop.map(|fun| fun(idx));
139 }
140}
141
142impl Clone for ThreadPool {
143 fn clone(&self) -> ThreadPool {
144 self.state.cnt.fetch_add(1, Ordering::Relaxed);
145 ThreadPool { state: self.state.clone() }
146 }
147}
148
149impl Drop for ThreadPool {
150 fn drop(&mut self) {
151 if self.state.cnt.fetch_sub(1, Ordering::Relaxed) == 1 {
152 for _ in 0..self.state.size {
153 self.state.send(Message::Close);
154 }
155 }
156 }
157}
158
159impl ThreadPoolBuilder {
160 pub fn new() -> ThreadPoolBuilder {
164 ThreadPoolBuilder {
165 pool_size: num_cpus::get(),
166 stack_size: 0,
167 name_prefix: None,
168 after_start: None,
169 before_stop: None,
170 }
171 }
172
173 pub fn pool_size(&mut self, size: usize) -> &mut Self {
178 self.pool_size = size;
179 self
180 }
181
182 pub fn stack_size(&mut self, stack_size: usize) -> &mut Self {
186 self.stack_size = stack_size;
187 self
188 }
189
190 pub fn name_prefix<S: Into<String>>(&mut self, name_prefix: S) -> &mut Self {
197 self.name_prefix = Some(name_prefix.into());
198 self
199 }
200
201 pub fn after_start<F>(&mut self, f: F) -> &mut Self
211 where F: Fn(usize) + Send + Sync + 'static
212 {
213 self.after_start = Some(Arc::new(f));
214 self
215 }
216
217 pub fn before_stop<F>(&mut self, f: F) -> &mut Self
226 where F: Fn(usize) + Send + Sync + 'static
227 {
228 self.before_stop = Some(Arc::new(f));
229 self
230 }
231
232 pub fn create(&mut self) -> Result<ThreadPool, io::Error> {
238 let (tx, rx) = mpsc::channel();
239 let pool = ThreadPool {
240 state: Arc::new(PoolState {
241 tx: Mutex::new(tx),
242 rx: Mutex::new(rx),
243 cnt: AtomicUsize::new(1),
244 size: self.pool_size,
245 }),
246 };
247 assert!(self.pool_size > 0);
248
249 for counter in 0..self.pool_size {
250 let state = pool.state.clone();
251 let after_start = self.after_start.clone();
252 let before_stop = self.before_stop.clone();
253 let mut thread_builder = thread::Builder::new();
254 if let Some(ref name_prefix) = self.name_prefix {
255 thread_builder = thread_builder.name(format!("{}{}", name_prefix, counter));
256 }
257 if self.stack_size > 0 {
258 thread_builder = thread_builder.stack_size(self.stack_size);
259 }
260 thread_builder.spawn(move || state.work(counter, after_start, before_stop))?;
261 }
262 Ok(pool)
263 }
264}
265
/// A spawned future together with everything needed to poll it on a worker thread.
struct Task {
    /// The future being driven.
    spawn: Box<Future<Item = (), Error = Never> + Send>,
    /// Task-local storage handed to `task::Context::new` on every poll.
    map: LocalMap,
    /// Executor handed to `task::Context::new` on every poll; this is the owning pool.
    exec: ThreadPool,
    /// Shared wake handle; a `Waker` built from it is handed to the context on every poll.
    wake_handle: Arc<WakeHandle>,
}
274
/// The `Wake` target for one task. Waking it may re-queue the task on `exec`.
struct WakeHandle {
    /// Holds the task while it is parked. The exact protocol is defined in the
    /// `unpark_mutex` module.
    mutex: UnparkMutex<Task>,
    /// Pool that a woken task is sent back to.
    exec: ThreadPool,
}
279
impl Task {
    /// Polls this task's future on the current thread until it completes or is parked.
    ///
    /// Consumes the `Task`. On `Ready`, the wake handle's mutex is marked complete. On
    /// `Pending`, the task is reassembled and handed to `wake_handle.mutex.wait`:
    /// - `Ok(())` returns, leaving the task with the mutex;
    /// - `Err(r)` gives the task back, and it is polled again.
    pub fn run(self) {
        let Task { mut spawn, wake_handle, mut map, mut exec } = self;
        // Waker built from this task's own handle, so wakeups go through `WakeHandle::wake`.
        let waker = Waker::from(wake_handle.clone());

        // SAFETY: NOTE(review): the `UnparkMutex` calls below are `unsafe`, and their
        // contract is defined in the `unpark_mutex` module, which is not visible here.
        // The discipline visible in this function:
        // - `start_poll` is called once before the first poll;
        // - after each poll, exactly one of `complete` (future finished) or `wait`
        //   (future pending) is called while this thread owns the task's parts.
        // Confirm this matches the module's documented requirements.
        unsafe {
            wake_handle.mutex.start_poll();

            loop {
                let res = {
                    // Scoped so the context's borrows of `map` and `exec` end before those
                    // values are moved into a new `Task` below.
                    let mut cx = task::Context::new(&mut map, &waker, &mut exec);
                    spawn.poll(&mut cx)
                };
                match res {
                    Ok(Async::Pending) => {}
                    Ok(Async::Ready(())) => return wake_handle.mutex.complete(),
                    // `Never` is uninhabited, so this arm can never execute.
                    Err(never) => match never {},
                }
                let task = Task {
                    spawn,
                    map,
                    wake_handle: wake_handle.clone(),
                    exec: exec
                };
                match wake_handle.mutex.wait(task) {
                    // The task now rests in the mutex. `WakeHandle::wake` re-queues it when
                    // `notify` hands it back, so this worker is done with it.
                    Ok(()) => return,
                    // The task was returned (presumably because a wakeup arrived while it
                    // was being polled). Restore its parts and poll again.
                    Err(r) => {
                        spawn = r.spawn;
                        map = r.map;
                        exec = r.exec;
                    }
                }
            }
        }
    }
}
320
321impl fmt::Debug for Task {
322 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
323 f.debug_struct("Task")
324 .field("contents", &"...")
325 .finish()
326 }
327}
328
329impl Wake for WakeHandle {
330 fn wake(arc_self: &Arc<Self>) {
331 match arc_self.mutex.notify() {
332 Ok(task) => arc_self.exec.state.send(Message::Run(task)),
333 Err(()) => {}
334 }
335 }
336}
337
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::mpsc;

    /// Checks that the `after_start` hook runs once per worker and that the pool releases
    /// it afterwards.
    #[test]
    fn test_drop_after_start() {
        let (started_tx, started_rx) = mpsc::sync_channel(2);
        // The builder must stay a temporary: it is dropped at the end of this statement,
        // and each worker drops its copy of the hook right after calling it. Once both
        // workers have started, the last sender is gone and the iterator below ends.
        let _pool = ThreadPoolBuilder::new()
            .pool_size(2)
            .after_start(move |_| started_tx.send(1).unwrap())
            .create()
            .unwrap();

        let started = started_rx.iter().count();
        assert_eq!(started, 2);
    }
}