futures_executor/
thread_pool.rs1use crate::enter;
2use crate::unpark_mutex::UnparkMutex;
3use futures_core::future::Future;
4use futures_core::task::{Context, Poll};
5use futures_task::{waker_ref, ArcWake};
6use futures_task::{FutureObj, Spawn, SpawnError};
7use futures_util::future::FutureExt;
8use std::boxed::Box;
9use std::fmt;
10use std::format;
11use std::io;
12use std::string::String;
13use std::sync::atomic::{AtomicUsize, Ordering};
14use std::sync::mpsc;
15use std::sync::{Arc, Mutex};
16use std::thread;
17
18#[cfg_attr(docsrs, doc(cfg(feature = "thread-pool")))]
30pub struct ThreadPool {
31 state: Arc<PoolState>,
32}
33
/// Configures and constructs a [`ThreadPool`].
///
/// Obtain one from [`ThreadPool::builder`] or [`ThreadPoolBuilder::new`],
/// adjust the settings, then call [`create`](ThreadPoolBuilder::create).
#[cfg_attr(docsrs, doc(cfg(feature = "thread-pool")))]
pub struct ThreadPoolBuilder {
    /// Number of worker threads `create` spawns.
    pool_size: usize,
    /// Per-worker stack size in bytes; `0` leaves the platform default alone.
    stack_size: usize,
    /// Optional worker-name prefix; the worker index is appended to it.
    name_prefix: Option<String>,
    /// Hook called on each worker, with its index, before it takes any task.
    after_start: Option<Arc<dyn Fn(usize) + Send + Sync>>,
    /// Hook called on each worker, with its index, after it leaves its loop.
    before_stop: Option<Arc<dyn Fn(usize) + Send + Sync>>,
}
46
47#[allow(dead_code)]
48trait AssertSendSync: Send + Sync {}
49impl AssertSendSync for ThreadPool {}
50
51struct PoolState {
52 tx: Mutex<mpsc::Sender<Message>>,
53 rx: Mutex<mpsc::Receiver<Message>>,
54 cnt: AtomicUsize,
55 size: usize,
56}
57
58impl fmt::Debug for ThreadPool {
59 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
60 f.debug_struct("ThreadPool").field("size", &self.state.size).finish()
61 }
62}
63
64impl fmt::Debug for ThreadPoolBuilder {
65 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
66 f.debug_struct("ThreadPoolBuilder")
67 .field("pool_size", &self.pool_size)
68 .field("name_prefix", &self.name_prefix)
69 .finish()
70 }
71}
72
73enum Message {
74 Run(Task),
75 Close,
76}
77
78impl ThreadPool {
79 pub fn new() -> Result<Self, io::Error> {
85 ThreadPoolBuilder::new().create()
86 }
87
88 pub fn builder() -> ThreadPoolBuilder {
94 ThreadPoolBuilder::new()
95 }
96
97 pub fn spawn_obj_ok(&self, future: FutureObj<'static, ()>) {
102 let task = Task {
103 future,
104 wake_handle: Arc::new(WakeHandle { exec: self.clone(), mutex: UnparkMutex::new() }),
105 exec: self.clone(),
106 };
107 self.state.send(Message::Run(task));
108 }
109
110 pub fn spawn_ok<Fut>(&self, future: Fut)
128 where
129 Fut: Future<Output = ()> + Send + 'static,
130 {
131 self.spawn_obj_ok(FutureObj::new(Box::new(future)))
132 }
133}
134
135impl Spawn for ThreadPool {
136 fn spawn_obj(&self, future: FutureObj<'static, ()>) -> Result<(), SpawnError> {
137 self.spawn_obj_ok(future);
138 Ok(())
139 }
140}
141
142impl PoolState {
143 fn send(&self, msg: Message) {
144 self.tx.lock().unwrap().send(msg).unwrap();
145 }
146
147 fn work(
148 &self,
149 idx: usize,
150 after_start: Option<Arc<dyn Fn(usize) + Send + Sync>>,
151 before_stop: Option<Arc<dyn Fn(usize) + Send + Sync>>,
152 ) {
153 let _scope = enter().unwrap();
154 if let Some(after_start) = after_start {
155 after_start(idx);
156 }
157 loop {
158 let msg = self.rx.lock().unwrap().recv().unwrap();
159 match msg {
160 Message::Run(task) => task.run(),
161 Message::Close => break,
162 }
163 }
164 if let Some(before_stop) = before_stop {
165 before_stop(idx);
166 }
167 }
168}
169
170impl Clone for ThreadPool {
171 fn clone(&self) -> Self {
172 self.state.cnt.fetch_add(1, Ordering::Relaxed);
173 Self { state: self.state.clone() }
174 }
175}
176
177impl Drop for ThreadPool {
178 fn drop(&mut self) {
179 if self.state.cnt.fetch_sub(1, Ordering::Relaxed) == 1 {
180 for _ in 0..self.state.size {
181 self.state.send(Message::Close);
182 }
183 }
184 }
185}
186
187impl ThreadPoolBuilder {
188 pub fn new() -> Self {
192 let pool_size = thread::available_parallelism().map_or(1, |p| p.get());
193 Self { pool_size, stack_size: 0, name_prefix: None, after_start: None, before_stop: None }
194 }
195
196 pub fn pool_size(&mut self, size: usize) -> &mut Self {
205 assert!(size > 0);
206 self.pool_size = size;
207 self
208 }
209
210 pub fn stack_size(&mut self, stack_size: usize) -> &mut Self {
214 self.stack_size = stack_size;
215 self
216 }
217
218 pub fn name_prefix<S: Into<String>>(&mut self, name_prefix: S) -> &mut Self {
225 self.name_prefix = Some(name_prefix.into());
226 self
227 }
228
229 pub fn after_start<F>(&mut self, f: F) -> &mut Self
239 where
240 F: Fn(usize) + Send + Sync + 'static,
241 {
242 self.after_start = Some(Arc::new(f));
243 self
244 }
245
246 pub fn before_stop<F>(&mut self, f: F) -> &mut Self
255 where
256 F: Fn(usize) + Send + Sync + 'static,
257 {
258 self.before_stop = Some(Arc::new(f));
259 self
260 }
261
262 pub fn create(&mut self) -> Result<ThreadPool, io::Error> {
264 let (tx, rx) = mpsc::channel();
265 let pool = ThreadPool {
266 state: Arc::new(PoolState {
267 tx: Mutex::new(tx),
268 rx: Mutex::new(rx),
269 cnt: AtomicUsize::new(1),
270 size: self.pool_size,
271 }),
272 };
273
274 for counter in 0..self.pool_size {
275 let state = pool.state.clone();
276 let after_start = self.after_start.clone();
277 let before_stop = self.before_stop.clone();
278 let mut thread_builder = thread::Builder::new();
279 if let Some(ref name_prefix) = self.name_prefix {
280 thread_builder = thread_builder.name(format!("{name_prefix}{counter}"));
281 }
282 if self.stack_size > 0 {
283 thread_builder = thread_builder.stack_size(self.stack_size);
284 }
285 thread_builder.spawn(move || state.work(counter, after_start, before_stop))?;
286 }
287 Ok(pool)
288 }
289}
290
291impl Default for ThreadPoolBuilder {
292 fn default() -> Self {
293 Self::new()
294 }
295}
296
297struct Task {
299 future: FutureObj<'static, ()>,
300 exec: ThreadPool,
301 wake_handle: Arc<WakeHandle>,
302}
303
304struct WakeHandle {
305 mutex: UnparkMutex<Task>,
306 exec: ThreadPool,
307}
308
309impl Task {
310 fn run(self) {
313 let Self { mut future, wake_handle, mut exec } = self;
314 let waker = waker_ref(&wake_handle);
315 let mut cx = Context::from_waker(&waker);
316
317 unsafe {
320 wake_handle.mutex.start_poll();
321
322 loop {
323 let res = future.poll_unpin(&mut cx);
324 match res {
325 Poll::Pending => {}
326 Poll::Ready(()) => return wake_handle.mutex.complete(),
327 }
328 let task = Self { future, wake_handle: wake_handle.clone(), exec };
329 match wake_handle.mutex.wait(task) {
330 Ok(()) => return, Err(task) => {
332 future = task.future;
334 exec = task.exec;
335 }
336 }
337 }
338 }
339 }
340}
341
342impl fmt::Debug for Task {
343 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
344 f.debug_struct("Task").field("contents", &"...").finish()
345 }
346}
347
348impl ArcWake for WakeHandle {
349 fn wake_by_ref(arc_self: &Arc<Self>) {
350 if let Ok(task) = arc_self.mutex.notify() {
351 arc_self.exec.state.send(Message::Run(task))
352 }
353 }
354}
355
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks that `after_start` runs once per worker and that the hook, along
    /// with the channel sender it captures, is released afterwards.
    #[test]
    fn test_drop_after_start() {
        const WORKERS: usize = 2;
        {
            let (tx, rx) = mpsc::sync_channel(WORKERS);
            let _cpu_pool = ThreadPoolBuilder::new()
                .pool_size(WORKERS)
                .after_start(move |_| tx.send(1).unwrap())
                .create()
                .unwrap();

            // The builder temporary is dropped at the end of the statement
            // above, and each worker drops its copy of the hook after running
            // it. Once both are gone, `tx` is dropped, so iterating `rx` ends.
            let received = rx.into_iter().count();
            assert_eq!(received, WORKERS);
        }
        // Give the worker threads time to exit after the pool is dropped.
        std::thread::sleep(std::time::Duration::from_millis(500));
    }
}