bp3d_threads/thread_pool/core.rs
1// Copyright (c) 2021, BlockProject 3D
2//
3// All rights reserved.
4//
5// Redistribution and use in source and binary forms, with or without modification,
6// are permitted provided that the following conditions are met:
7//
8// * Redistributions of source code must retain the above copyright notice,
9// this list of conditions and the following disclaimer.
10// * Redistributions in binary form must reproduce the above copyright notice,
11// this list of conditions and the following disclaimer in the documentation
12// and/or other materials provided with the distribution.
13// * Neither the name of BlockProject 3D nor the names of its contributors
14// may be used to endorse or promote products derived from this software
15// without specific prior written permission.
16//
17// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
18// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
19// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
20// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
21// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
22// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
23// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
24// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
25// LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
26// NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
27// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28
29//! A thread pool with support for function results
30
31use crossbeam::deque::{Injector, Stealer, Worker};
32use crossbeam::queue::{ArrayQueue, SegQueue};
33use std::iter::repeat_with;
34use std::sync::Arc;
35use std::time::Duration;
36use std::vec::IntoIter;
37
/// Number of task results a worker buffers locally before publishing them as one batch.
const INNER_RESULT_BUFFER: usize = 16;
39
/// A unit of work queued on the pool: a boxed closure and the sequence number it was
/// submitted with. The closure receives that number as its argument when it runs.
struct Task<'env, T>
where
    T: Send + 'static,
{
    /// The closure to execute.
    func: Box<dyn FnOnce(usize) -> T + Send + 'env>,
    /// Sequence number assigned when the task was submitted.
    id: usize,
}
44
45struct WorkThread<'env, T: Send + 'static> {
46 id: usize,
47 worker: Worker<Task<'env, T>>,
48 task_queue: Arc<Injector<Task<'env, T>>>,
49 task_stealers: Box<[Option<Stealer<Task<'env, T>>>]>,
50 term_queue: Arc<ArrayQueue<usize>>,
51 end_queue: Arc<SegQueue<Vec<T>>>,
52}
53
54impl<'env, T: Send + 'static> WorkThread<'env, T> {
55 pub fn new(
56 id: usize,
57 task_queue: Arc<Injector<Task<'env, T>>>,
58 worker: Worker<Task<'env, T>>,
59 task_stealers: Box<[Option<Stealer<Task<'env, T>>>]>,
60 term_queue: Arc<ArrayQueue<usize>>,
61 end_queue: Arc<SegQueue<Vec<T>>>,
62 ) -> WorkThread<'env, T> {
63 WorkThread {
64 id,
65 worker,
66 task_queue,
67 task_stealers,
68 term_queue,
69 end_queue,
70 }
71 }
72
73 fn attempt_steal_task(&self) -> Option<Task<'env, T>> {
74 self.worker.pop().or_else(|| {
75 std::iter::repeat_with(|| {
76 self.task_queue
77 .steal_batch_and_pop(&self.worker)
78 .or_else(|| {
79 self.task_stealers
80 .iter()
81 .filter_map(|v| if let Some(v) = v { Some(v) } else { None })
82 .map(|v| v.steal_batch_and_pop(&self.worker))
83 .collect()
84 })
85 })
86 .find(|v| !v.is_retry())
87 .and_then(|v| v.success())
88 })
89 }
90
91 fn empty_inner_buffer(&self, mut inner: Vec<T>) -> Vec<T> {
92 if !inner.is_empty() {
93 let buffer = std::mem::replace(&mut inner, Vec::with_capacity(INNER_RESULT_BUFFER));
94 self.end_queue.push(buffer);
95 }
96 inner
97 }
98
99 fn check_empty_inner_buffer(&self, mut inner: Vec<T>) -> Vec<T> {
100 if inner.len() >= INNER_RESULT_BUFFER {
101 inner = self.empty_inner_buffer(inner);
102 }
103 inner
104 }
105
106 fn iteration(&self) {
107 let mut inner = Vec::with_capacity(INNER_RESULT_BUFFER);
108 while let Some(task) = self.attempt_steal_task() {
109 let res = (task.func)(task.id);
110 inner.push(res);
111 inner = self.check_empty_inner_buffer(inner);
112 }
113 self.empty_inner_buffer(inner);
114 }
115
116 fn main_loop(&self) {
117 self.iteration();
118 /*if self.error_flag.get() {
119 self.term_channel_in.send(self.id).unwrap();
120 return;
121 }*/
122 // Wait 100ms and give another try before shutting down to let a chance to the main thread to refill the task channel.
123 std::thread::sleep(Duration::from_millis(100));
124 self.iteration();
125 self.term_queue.push(self.id).unwrap();
126 }
127}
128
/// Anything that can be joined like a thread handle.
pub trait Join {
    /// Blocks until the underlying thread finishes.
    ///
    /// Returns `Err` with the panic payload if the thread panicked.
    fn join(self) -> std::thread::Result<()>;
}
134
135/// Trait to handle spawning generic threads.
136pub trait ThreadManager<'env> {
137 /// The type of thread handle (must have a join() function).
138 type Handle: Join;
139
140 /// Spawns a thread using this manager.
141 ///
142 /// # Arguments
143 ///
144 /// * `func`: the function to run in the thread.
145 ///
146 /// returns: Self::Handle
147 fn spawn_thread<F: FnOnce() + Send + 'env>(&self, func: F) -> Self::Handle;
148}
149
150struct Inner<'env, M: ThreadManager<'env>, T: Send + 'static> {
151 end_queue: Arc<SegQueue<Vec<T>>>,
152 threads: Box<[Option<M::Handle>]>,
153 task_stealers: Box<[Option<Stealer<Task<'env, T>>>]>,
154 term_queue: Arc<ArrayQueue<usize>>,
155 running_threads: usize,
156 n_threads: usize,
157}
158
159/// An iterator into a thread pool.
160pub struct Iter<'a, 'env, M: ThreadManager<'env>, T: Send + 'static> {
161 inner: &'a mut Inner<'env, M, T>,
162 batch: Option<IntoIter<T>>,
163 thread_id: usize,
164}
165
166impl<'a, 'env, M: ThreadManager<'env>, T: Send + 'static> Iter<'a, 'env, M, T> {
167 fn pump_next_batch(&mut self) -> Option<std::thread::Result<()>> {
168 while self.batch.is_none() {
169 if self.inner.running_threads == 0 {
170 return None;
171 }
172 if let Some(h) = self.inner.threads[self.thread_id].take() {
173 if let Err(e) = h.join() {
174 return Some(Err(e));
175 }
176 self.inner.term_queue.pop();
177 self.inner.running_threads -= 1;
178 let mut megabatch = Vec::new();
179 while let Some(batch) = self.inner.end_queue.pop() {
180 megabatch.extend(batch);
181 }
182 self.batch = Some(megabatch.into_iter());
183 return Some(Ok(()));
184 }
185 self.inner.task_stealers[self.thread_id] = None;
186 self.thread_id += 1;
187 }
188 Some(Ok(()))
189 }
190}
191
192impl<'a, 'env, M: ThreadManager<'env>, T: Send + 'static> Iter<'a, 'env, M, Vec<T>> {
193 /// Collect this iterator into a single [Vec](std::vec::Vec) when each task returns a
194 /// [Vec](std::vec::Vec).
195 pub fn to_vec(mut self) -> std::thread::Result<Vec<T>> {
196 let mut v = Vec::new();
197 for i in 0..self.inner.n_threads {
198 if let Some(h) = self.inner.threads[i].take() {
199 h.join()?;
200 self.inner.term_queue.pop();
201 self.inner.running_threads -= 1;
202 while let Some(batch) = self.inner.end_queue.pop() {
203 for r in batch {
204 v.extend(r);
205 }
206 }
207 }
208 self.inner.task_stealers[i] = None;
209 }
210 Ok(v)
211 }
212}
213
214impl<'a, 'env, M: ThreadManager<'env>, T: Send + 'static, E: Send + 'static>
215 Iter<'a, 'env, M, Result<Vec<T>, E>>
216{
217 /// Collect this iterator into a single [Result](std::result::Result) of [Vec](std::vec::Vec)
218 /// when each task returns a [Result](std::result::Result) of [Vec](std::vec::Vec).
219 pub fn to_vec(mut self) -> std::thread::Result<Result<Vec<T>, E>> {
220 let mut v = Vec::new();
221 for i in 0..self.inner.n_threads {
222 if let Some(h) = self.inner.threads[i].take() {
223 h.join()?;
224 self.inner.term_queue.pop();
225 self.inner.running_threads -= 1;
226 while let Some(batch) = self.inner.end_queue.pop() {
227 for r in batch {
228 match r {
229 Ok(items) => v.extend(items),
230 Err(e) => return Ok(Err(e)),
231 }
232 }
233 }
234 }
235 self.inner.task_stealers[i] = None;
236 }
237 Ok(Ok(v))
238 }
239}
240
241impl<'a, 'env, M: ThreadManager<'env>, T: Send + 'static> Iterator for Iter<'a, 'env, M, T> {
242 type Item = std::thread::Result<T>;
243
244 fn next(&mut self) -> Option<Self::Item> {
245 match self.pump_next_batch() {
246 None => return None,
247 Some(v) => match v {
248 Ok(_) => (),
249 Err(e) => return Some(Err(e)),
250 },
251 };
252 // SAFETY: always safe because while self.batch.is_none(). So if this is reached then
253 // batch has to be Some.
254 let batch = unsafe { self.batch.as_mut().unwrap_unchecked() };
255 let item = batch.next();
256 match item {
257 None => {
258 self.batch = None;
259 self.next()
260 }
261 Some(v) => Some(Ok(v)),
262 }
263 }
264}
265
266/// Core thread pool.
267pub struct ThreadPool<'env, M: ThreadManager<'env>, T: Send + 'static> {
268 task_queue: Arc<Injector<Task<'env, T>>>,
269 end_batch: Option<Vec<T>>,
270 inner: Inner<'env, M, T>,
271 task_id: usize,
272}
273
274impl<'env, M: ThreadManager<'env>, T: Send> ThreadPool<'env, M, T> {
275 /// Creates a new thread pool
276 ///
277 /// # Arguments
278 ///
279 /// * `n_threads`: maximum number of threads allowed to run at the same time.
280 ///
281 /// returns: ThreadPool<T, Manager>
282 ///
283 /// # Examples
284 ///
285 /// ```
286 /// use bp3d_threads::UnscopedThreadManager;
287 /// use bp3d_threads::ThreadPool;
288 /// let _: ThreadPool<UnscopedThreadManager, ()> = ThreadPool::new(4);
289 /// ```
290 pub fn new(n_threads: usize) -> Self {
291 Self {
292 task_queue: Arc::new(Injector::new()),
293 inner: Inner {
294 task_stealers: vec![None; n_threads].into_boxed_slice(),
295 end_queue: Arc::new(SegQueue::new()),
296 term_queue: Arc::new(ArrayQueue::new(n_threads)),
297 n_threads,
298 running_threads: 0,
299 threads: repeat_with(|| None)
300 .take(n_threads)
301 .collect::<Vec<Option<M::Handle>>>()
302 .into_boxed_slice(),
303 },
304 end_batch: None,
305 task_id: 0,
306 }
307 }
308
309 fn rearm_one_thread_if_possible(&mut self, manager: &M) {
310 if self.inner.running_threads < self.inner.n_threads {
311 for (i, handle) in self.inner.threads.iter_mut().enumerate() {
312 if handle.is_none() {
313 let worker = Worker::new_fifo();
314 let stealer = worker.stealer();
315 // Required due to a bug in rust: rust believes that Handle and Manager have to
316 // be Send when Task doesn't have anything to do with the Manager or the Handle!
317 let rust_hack_1 = self.task_queue.clone();
318 let rust_hack_2 = self.inner.task_stealers.clone();
319 let rust_hack_3 = self.inner.end_queue.clone();
320 let rust_hack_4 = self.inner.term_queue.clone();
321 self.inner.task_stealers[i] = Some(stealer);
322 *handle = Some(manager.spawn_thread(move || {
323 let thread = WorkThread::new(
324 i,
325 rust_hack_1,
326 worker,
327 rust_hack_2,
328 rust_hack_4,
329 rust_hack_3,
330 );
331 thread.main_loop()
332 }));
333 break;
334 }
335 }
336 self.inner.running_threads += 1;
337 }
338 }
339
340 /// Send a new task to the injector queue.
341 ///
342 /// **The task execution order is not guaranteed,
343 /// however the task index is guaranteed to be the order of the call to dispatch.**
344 ///
345 /// **If a task panics it will leave a dead thread in the corresponding slot until .wait() is called.**
346 ///
347 /// # Arguments
348 ///
349 /// * `manager`: the thread manager to spawn a new thread if needed.
350 /// * `f`: the task function to execute.
351 ///
352 /// # Examples
353 ///
354 /// ```
355 /// use bp3d_threads::UnscopedThreadManager;
356 /// use bp3d_threads::ThreadPool;
357 /// let manager = UnscopedThreadManager::new();
358 /// let mut pool: ThreadPool<UnscopedThreadManager, ()> = ThreadPool::new(4);
359 /// pool.send(&manager, |_| ());
360 /// ```
361 pub fn send<F: FnOnce(usize) -> T + Send + 'env>(&mut self, manager: &M, f: F) {
362 let task = Task {
363 func: Box::new(f),
364 id: self.task_id,
365 };
366 self.task_queue.push(task);
367 self.task_id += 1;
368 self.rearm_one_thread_if_possible(manager);
369 }
370
371 /// Schedule a new task to run.
372 ///
373 /// Returns true if the task was successfully scheduled, false otherwise.
374 ///
375 /// *NOTE: Since version 1.1.0, failure is no longer possible so this function will never return false.*
376 ///
377 /// **The task execution order is not guaranteed,
378 /// however the task index is guaranteed to be the order of the call to dispatch.**
379 ///
380 /// **If a task panics it will leave a dead thread in the corresponding slot until .join() is called.**
381 ///
382 /// # Arguments
383 ///
384 /// * `manager`: the thread manager to spawn a new thread if needed.
385 /// * `f`: the task function to execute.
386 ///
387 /// returns: bool
388 #[deprecated(since = "1.1.0", note = "Please use `send` instead")]
389 pub fn dispatch<F: FnOnce(usize) -> T + Send + 'env>(&mut self, manager: &M, f: F) -> bool {
390 self.send(manager, f);
391 true
392 }
393
394 /// Returns true if this thread pool is idle.
395 ///
396 /// **An idle thread pool does neither have running threads nor waiting tasks
397 /// but may still have waiting results to poll.**
398 pub fn is_idle(&self) -> bool {
399 self.task_queue.is_empty() && self.inner.running_threads == 0
400 }
401
402 /// Poll a result from this thread pool if any, returns None if no result is available.
403 pub fn poll(&mut self) -> Option<T> {
404 if let Some(v) = self.inner.term_queue.pop() {
405 self.inner.threads[v] = None;
406 self.inner.task_stealers[v] = None;
407 self.inner.running_threads -= 1;
408 }
409 if self.end_batch.is_none() {
410 self.end_batch = self.inner.end_queue.pop();
411 }
412 let value = match self.end_batch.as_mut() {
413 None => None,
414 Some(v) => {
415 let val = v.pop();
416 if v.is_empty() {
417 self.end_batch = None;
418 }
419 val
420 }
421 };
422 value
423 }
424
425 /// Waits for all tasks to finish execution and stops all threads while iterating over task
426 /// results.
427 ///
428 /// *Use this to periodically clean-up the thread pool, if you know that some tasks may panic.*
429 ///
430 /// **Use this function in map-reduce kind of scenarios.**
431 ///
432 /// # Errors
433 ///
434 /// Returns an error if a thread did panic.
435 pub fn reduce(&mut self) -> Iter<'_, 'env, M, T> {
436 Iter {
437 inner: &mut self.inner,
438 batch: None,
439 thread_id: 0,
440 }
441 }
442
443 /// Waits for all tasks to finish execution and stops all threads.
444 ///
445 /// *Use this to periodically clean-up the thread pool, if you know that some tasks may panic.*
446 ///
447 /// # Errors
448 ///
449 /// Returns an error if a thread did panic.
450 pub fn wait(&mut self) -> std::thread::Result<()> {
451 for i in 0..self.inner.n_threads {
452 if let Some(h) = self.inner.threads[i].take() {
453 h.join()?;
454 self.inner.term_queue.pop();
455 self.inner.running_threads -= 1;
456 }
457 self.inner.task_stealers[i] = None;
458 }
459 Ok(())
460 }
461
462 /// Waits for all tasks to finish execution and stops all threads.
463 ///
464 /// *Use this to periodically clean-up the thread pool, if you know that some tasks may panic.*
465 ///
466 /// # Errors
467 ///
468 /// Returns an error if a thread did panic.
469 #[deprecated(since = "1.1.0", note = "Please use `wait` or `reduce` instead")]
470 pub fn join(&mut self) -> std::thread::Result<()> {
471 self.wait()
472 }
473}