bp3d_threads/thread_pool/core.rs

// Copyright (c) 2021, BlockProject 3D
//
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without modification,
// are permitted provided that the following conditions are met:
//
//     * Redistributions of source code must retain the above copyright notice,
//       this list of conditions and the following disclaimer.
//     * Redistributions in binary form must reproduce the above copyright notice,
//       this list of conditions and the following disclaimer in the documentation
//       and/or other materials provided with the distribution.
//     * Neither the name of BlockProject 3D nor the names of its contributors
//       may be used to endorse or promote products derived from this software
//       without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
// LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
// NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

//! A thread pool with support for function results

use crossbeam::deque::{Injector, Stealer, Worker};
use crossbeam::queue::{ArrayQueue, SegQueue};
use std::iter::repeat_with;
use std::sync::Arc;
use std::time::Duration;
use std::vec::IntoIter;

const INNER_RESULT_BUFFER: usize = 16;

struct Task<'env, T: Send + 'static> {
    func: Box<dyn FnOnce(usize) -> T + Send + 'env>,
    id: usize,
}

struct WorkThread<'env, T: Send + 'static> {
    id: usize,
    worker: Worker<Task<'env, T>>,
    task_queue: Arc<Injector<Task<'env, T>>>,
    task_stealers: Box<[Option<Stealer<Task<'env, T>>>]>,
    term_queue: Arc<ArrayQueue<usize>>,
    end_queue: Arc<SegQueue<Vec<T>>>,
}

impl<'env, T: Send + 'static> WorkThread<'env, T> {
    pub fn new(
        id: usize,
        task_queue: Arc<Injector<Task<'env, T>>>,
        worker: Worker<Task<'env, T>>,
        task_stealers: Box<[Option<Stealer<Task<'env, T>>>]>,
        term_queue: Arc<ArrayQueue<usize>>,
        end_queue: Arc<SegQueue<Vec<T>>>,
    ) -> WorkThread<'env, T> {
        WorkThread {
            id,
            worker,
            task_queue,
            task_stealers,
            term_queue,
            end_queue,
        }
    }

    fn attempt_steal_task(&self) -> Option<Task<'env, T>> {
        self.worker.pop().or_else(|| {
            std::iter::repeat_with(|| {
                self.task_queue
                    .steal_batch_and_pop(&self.worker)
                    .or_else(|| {
                        self.task_stealers
                            .iter()
                            .flatten()
                            .map(|v| v.steal_batch_and_pop(&self.worker))
                            .collect()
                    })
            })
            .find(|v| !v.is_retry())
            .and_then(|v| v.success())
        })
    }

    fn empty_inner_buffer(&self, mut inner: Vec<T>) -> Vec<T> {
        if !inner.is_empty() {
            let buffer = std::mem::replace(&mut inner, Vec::with_capacity(INNER_RESULT_BUFFER));
            self.end_queue.push(buffer);
        }
        inner
    }

    fn check_empty_inner_buffer(&self, mut inner: Vec<T>) -> Vec<T> {
        if inner.len() >= INNER_RESULT_BUFFER {
            inner = self.empty_inner_buffer(inner);
        }
        inner
    }

    fn iteration(&self) {
        let mut inner = Vec::with_capacity(INNER_RESULT_BUFFER);
        while let Some(task) = self.attempt_steal_task() {
            let res = (task.func)(task.id);
            inner.push(res);
            inner = self.check_empty_inner_buffer(inner);
        }
        self.empty_inner_buffer(inner);
    }

    fn main_loop(&self) {
        self.iteration();
        /*if self.error_flag.get() {
            self.term_channel_in.send(self.id).unwrap();
            return;
        }*/
        // Wait 100ms and try once more before shutting down, to give the main thread a chance to refill the task queue.
        std::thread::sleep(Duration::from_millis(100));
        self.iteration();
        self.term_queue.push(self.id).unwrap();
    }
}

/// Trait to access the join function of a thread handle.
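///
/// # Examples
///
/// A minimal sketch of an implementation for a hypothetical handle type (illustrative only;
/// it assumes `Join` is re-exported at the crate root like `ThreadPool`):
///
/// ```
/// use bp3d_threads::Join;
///
/// // Hypothetical handle for a "thread" that completes immediately.
/// struct NoopHandle;
///
/// impl Join for NoopHandle {
///     fn join(self) -> std::thread::Result<()> {
///         // Nothing to wait for in this dummy handle.
///         Ok(())
///     }
/// }
///
/// assert!(NoopHandle.join().is_ok());
/// ```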
pub trait Join {
    /// Joins this thread.
    fn join(self) -> std::thread::Result<()>;
}

/// Trait to handle spawning generic threads.
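///
/// # Examples
///
/// A minimal sketch of a manager backed by std::thread::spawn for 'static tasks (illustrative
/// only; see `UnscopedThreadManager` for the crate's own equivalent, and note this assumes
/// `ThreadManager` and `Join` are re-exported at the crate root like `ThreadPool`):
///
/// ```
/// use bp3d_threads::{Join, ThreadManager};
///
/// // Hypothetical handle wrapping a standard thread handle.
/// struct StdHandle(std::thread::JoinHandle<()>);
///
/// impl Join for StdHandle {
///     fn join(self) -> std::thread::Result<()> {
///         // Forward the join result (including any panic payload).
///         self.0.join()
///     }
/// }
///
/// // Hypothetical manager spawning plain standard threads.
/// struct StdManager;
///
/// impl ThreadManager<'static> for StdManager {
///     type Handle = StdHandle;
///
///     fn spawn_thread<F: FnOnce() + Send + 'static>(&self, func: F) -> Self::Handle {
///         StdHandle(std::thread::spawn(func))
///     }
/// }
/// ```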
pub trait ThreadManager<'env> {
    /// The type of thread handle (must have a join() function).
    type Handle: Join;

    /// Spawns a thread using this manager.
    ///
    /// # Arguments
    ///
    /// * `func`: the function to run in the thread.
    ///
    /// returns: Self::Handle
    fn spawn_thread<F: FnOnce() + Send + 'env>(&self, func: F) -> Self::Handle;
}

struct Inner<'env, M: ThreadManager<'env>, T: Send + 'static> {
    end_queue: Arc<SegQueue<Vec<T>>>,
    threads: Box<[Option<M::Handle>]>,
    task_stealers: Box<[Option<Stealer<Task<'env, T>>>]>,
    term_queue: Arc<ArrayQueue<usize>>,
    running_threads: usize,
    n_threads: usize,
}

/// An iterator into a thread pool.
pub struct Iter<'a, 'env, M: ThreadManager<'env>, T: Send + 'static> {
    inner: &'a mut Inner<'env, M, T>,
    batch: Option<IntoIter<T>>,
    thread_id: usize,
}

impl<'a, 'env, M: ThreadManager<'env>, T: Send + 'static> Iter<'a, 'env, M, T> {
    fn pump_next_batch(&mut self) -> Option<std::thread::Result<()>> {
        while self.batch.is_none() {
            if self.inner.running_threads == 0 {
                return None;
            }
            if let Some(h) = self.inner.threads[self.thread_id].take() {
                if let Err(e) = h.join() {
                    return Some(Err(e));
                }
                self.inner.term_queue.pop();
                self.inner.running_threads -= 1;
                let mut megabatch = Vec::new();
                while let Some(batch) = self.inner.end_queue.pop() {
                    megabatch.extend(batch);
                }
                self.batch = Some(megabatch.into_iter());
                return Some(Ok(()));
            }
            self.inner.task_stealers[self.thread_id] = None;
            self.thread_id += 1;
        }
        Some(Ok(()))
    }
}

impl<'a, 'env, M: ThreadManager<'env>, T: Send + 'static> Iter<'a, 'env, M, Vec<T>> {
    /// Collect this iterator into a single [Vec](std::vec::Vec) when each task returns a
    /// [Vec](std::vec::Vec).
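    ///
    /// # Examples
    ///
    /// A small sketch where each task returns a `Vec` of partial results which get flattened
    /// into a single `Vec` (illustrative only):
    ///
    /// ```
    /// use bp3d_threads::UnscopedThreadManager;
    /// use bp3d_threads::ThreadPool;
    /// let manager = UnscopedThreadManager::new();
    /// let mut pool: ThreadPool<UnscopedThreadManager, Vec<usize>> = ThreadPool::new(4);
    /// for i in 0..4 {
    ///     pool.send(&manager, move |_| vec![i, i]);
    /// }
    /// // Join all threads and flatten every per-task Vec into one Vec.
    /// let flat = pool.reduce().to_vec().unwrap();
    /// assert_eq!(flat.len(), 8);
    /// ```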
    pub fn to_vec(mut self) -> std::thread::Result<Vec<T>> {
        let mut v = Vec::new();
        for i in 0..self.inner.n_threads {
            if let Some(h) = self.inner.threads[i].take() {
                h.join()?;
                self.inner.term_queue.pop();
                self.inner.running_threads -= 1;
                while let Some(batch) = self.inner.end_queue.pop() {
                    for r in batch {
                        v.extend(r);
                    }
                }
            }
            self.inner.task_stealers[i] = None;
        }
        Ok(v)
    }
}

impl<'a, 'env, M: ThreadManager<'env>, T: Send + 'static, E: Send + 'static>
    Iter<'a, 'env, M, Result<Vec<T>, E>>
{
    /// Collect this iterator into a single [Result](std::result::Result) of [Vec](std::vec::Vec)
    /// when each task returns a [Result](std::result::Result) of [Vec](std::vec::Vec).
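    ///
    /// # Examples
    ///
    /// A small sketch of error propagation, assuming each task returns a
    /// `Result<Vec<u32>, String>` (illustrative only):
    ///
    /// ```
    /// use bp3d_threads::UnscopedThreadManager;
    /// use bp3d_threads::ThreadPool;
    /// let manager = UnscopedThreadManager::new();
    /// let mut pool: ThreadPool<UnscopedThreadManager, Result<Vec<u32>, String>> = ThreadPool::new(4);
    /// pool.send(&manager, |_| Ok(vec![1, 2]));
    /// pool.send(&manager, |_| Err("failed".to_string()));
    /// // The outer Result reports panics, the inner one the first task error encountered.
    /// let res = pool.reduce().to_vec().unwrap();
    /// assert!(res.is_err());
    /// ```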
    pub fn to_vec(mut self) -> std::thread::Result<Result<Vec<T>, E>> {
        let mut v = Vec::new();
        for i in 0..self.inner.n_threads {
            if let Some(h) = self.inner.threads[i].take() {
                h.join()?;
                self.inner.term_queue.pop();
                self.inner.running_threads -= 1;
                while let Some(batch) = self.inner.end_queue.pop() {
                    for r in batch {
                        match r {
                            Ok(items) => v.extend(items),
                            Err(e) => return Ok(Err(e)),
                        }
                    }
                }
            }
            self.inner.task_stealers[i] = None;
        }
        Ok(Ok(v))
    }
}

impl<'a, 'env, M: ThreadManager<'env>, T: Send + 'static> Iterator for Iter<'a, 'env, M, T> {
    type Item = std::thread::Result<T>;

    fn next(&mut self) -> Option<Self::Item> {
        match self.pump_next_batch() {
            None => return None,
            Some(v) => match v {
                Ok(_) => (),
                Err(e) => return Some(Err(e)),
            },
        };
        // SAFETY: pump_next_batch loops while self.batch is None and only returns Some(Ok(()))
        // after storing a batch, so self.batch is guaranteed to be Some at this point.
        let batch = unsafe { self.batch.as_mut().unwrap_unchecked() };
        let item = batch.next();
        match item {
            None => {
                self.batch = None;
                self.next()
            }
            Some(v) => Some(Ok(v)),
        }
    }
}

/// Core thread pool.
pub struct ThreadPool<'env, M: ThreadManager<'env>, T: Send + 'static> {
    task_queue: Arc<Injector<Task<'env, T>>>,
    end_batch: Option<Vec<T>>,
    inner: Inner<'env, M, T>,
    task_id: usize,
}

impl<'env, M: ThreadManager<'env>, T: Send> ThreadPool<'env, M, T> {
    /// Creates a new thread pool.
    ///
    /// # Arguments
    ///
    /// * `n_threads`: maximum number of threads allowed to run at the same time.
    ///
    /// returns: ThreadPool<M, T>
    ///
    /// # Examples
    ///
    /// ```
    /// use bp3d_threads::UnscopedThreadManager;
    /// use bp3d_threads::ThreadPool;
    /// let _: ThreadPool<UnscopedThreadManager, ()> = ThreadPool::new(4);
    /// ```
    pub fn new(n_threads: usize) -> Self {
        Self {
            task_queue: Arc::new(Injector::new()),
            inner: Inner {
                task_stealers: vec![None; n_threads].into_boxed_slice(),
                end_queue: Arc::new(SegQueue::new()),
                term_queue: Arc::new(ArrayQueue::new(n_threads)),
                n_threads,
                running_threads: 0,
                threads: repeat_with(|| None)
                    .take(n_threads)
                    .collect::<Vec<Option<M::Handle>>>()
                    .into_boxed_slice(),
            },
            end_batch: None,
            task_id: 0,
        }
    }

    fn rearm_one_thread_if_possible(&mut self, manager: &M) {
        if self.inner.running_threads < self.inner.n_threads {
            for (i, handle) in self.inner.threads.iter_mut().enumerate() {
                if handle.is_none() {
                    let worker = Worker::new_fifo();
                    let stealer = worker.stealer();
                    // Clone these handles outside the closure: capturing `self` directly in the
                    // `move` closure would pull in the Manager and Handle types and force them to be Send.
                    let rust_hack_1 = self.task_queue.clone();
                    let rust_hack_2 = self.inner.task_stealers.clone();
                    let rust_hack_3 = self.inner.end_queue.clone();
                    let rust_hack_4 = self.inner.term_queue.clone();
                    self.inner.task_stealers[i] = Some(stealer);
                    *handle = Some(manager.spawn_thread(move || {
                        let thread = WorkThread::new(
                            i,
                            rust_hack_1,
                            worker,
                            rust_hack_2,
                            rust_hack_4,
                            rust_hack_3,
                        );
                        thread.main_loop()
                    }));
                    break;
                }
            }
            self.inner.running_threads += 1;
        }
    }

    /// Send a new task to the injector queue.
    ///
    /// **The task execution order is not guaranteed; however, the task index is guaranteed to
    /// match the order of calls to `send`.**
    ///
    /// **If a task panics, it will leave a dead thread in the corresponding slot until .wait() is called.**
    ///
    /// # Arguments
    ///
    /// * `manager`: the thread manager to spawn a new thread if needed.
    /// * `f`: the task function to execute.
    ///
    /// # Examples
    ///
    /// ```
    /// use bp3d_threads::UnscopedThreadManager;
    /// use bp3d_threads::ThreadPool;
    /// let manager = UnscopedThreadManager::new();
    /// let mut pool: ThreadPool<UnscopedThreadManager, ()> = ThreadPool::new(4);
    /// pool.send(&manager, |_| ());
    /// ```
    pub fn send<F: FnOnce(usize) -> T + Send + 'env>(&mut self, manager: &M, f: F) {
        let task = Task {
            func: Box::new(f),
            id: self.task_id,
        };
        self.task_queue.push(task);
        self.task_id += 1;
        self.rearm_one_thread_if_possible(manager);
    }

    /// Schedule a new task to run.
    ///
    /// Returns true if the task was successfully scheduled, false otherwise.
    ///
    /// *NOTE: Since version 1.1.0, failure is no longer possible, so this function always returns true.*
    ///
    /// **The task execution order is not guaranteed; however, the task index is guaranteed to
    /// match the order of calls to `dispatch`.**
    ///
    /// **If a task panics, it will leave a dead thread in the corresponding slot until .join() is called.**
    ///
    /// # Arguments
    ///
    /// * `manager`: the thread manager to spawn a new thread if needed.
    /// * `f`: the task function to execute.
    ///
    /// returns: bool
    #[deprecated(since = "1.1.0", note = "Please use `send` instead")]
    pub fn dispatch<F: FnOnce(usize) -> T + Send + 'env>(&mut self, manager: &M, f: F) -> bool {
        self.send(manager, f);
        true
    }

    /// Returns true if this thread pool is idle.
    ///
    /// **An idle thread pool has neither running threads nor waiting tasks,
    /// but it may still have waiting results to poll.**
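    ///
    /// # Examples
    ///
    /// A small illustrative sketch:
    ///
    /// ```
    /// use bp3d_threads::UnscopedThreadManager;
    /// use bp3d_threads::ThreadPool;
    /// let manager = UnscopedThreadManager::new();
    /// let mut pool: ThreadPool<UnscopedThreadManager, ()> = ThreadPool::new(4);
    /// pool.send(&manager, |_| ());
    /// assert!(!pool.is_idle()); // A worker thread is still registered as running.
    /// pool.wait().unwrap();
    /// assert!(pool.is_idle()); // No running threads and no waiting tasks remain.
    /// ```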
    pub fn is_idle(&self) -> bool {
        self.task_queue.is_empty() && self.inner.running_threads == 0
    }

    /// Polls a result from this thread pool, returning `None` if no result is available.
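    ///
    /// # Examples
    ///
    /// A small sketch draining results after the tasks have finished (illustrative only):
    ///
    /// ```
    /// use bp3d_threads::UnscopedThreadManager;
    /// use bp3d_threads::ThreadPool;
    /// let manager = UnscopedThreadManager::new();
    /// let mut pool: ThreadPool<UnscopedThreadManager, usize> = ThreadPool::new(4);
    /// for i in 0..4 {
    ///     pool.send(&manager, move |_| i * 2);
    /// }
    /// // Wait for every task to finish, then drain the available results.
    /// pool.wait().unwrap();
    /// let mut results = Vec::new();
    /// while let Some(v) = pool.poll() {
    ///     results.push(v);
    /// }
    /// assert_eq!(results.len(), 4);
    /// ```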
    pub fn poll(&mut self) -> Option<T> {
        if let Some(v) = self.inner.term_queue.pop() {
            self.inner.threads[v] = None;
            self.inner.task_stealers[v] = None;
            self.inner.running_threads -= 1;
        }
        if self.end_batch.is_none() {
            self.end_batch = self.inner.end_queue.pop();
        }
        let value = match self.end_batch.as_mut() {
            None => None,
            Some(v) => {
                let val = v.pop();
                if v.is_empty() {
                    self.end_batch = None;
                }
                val
            }
        };
        value
    }

    /// Waits for all tasks to finish execution and stops all threads while iterating over task
    /// results.
    ///
    /// *Use this to periodically clean up the thread pool if you know that some tasks may panic.*
    ///
    /// **Use this function in map-reduce style scenarios.**
    ///
    /// # Errors
    ///
    /// The returned iterator yields an error if a thread panicked.
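    ///
    /// # Examples
    ///
    /// A small map-reduce style sketch (illustrative only):
    ///
    /// ```
    /// use bp3d_threads::UnscopedThreadManager;
    /// use bp3d_threads::ThreadPool;
    /// let manager = UnscopedThreadManager::new();
    /// let mut pool: ThreadPool<UnscopedThreadManager, usize> = ThreadPool::new(4);
    /// for i in 0..4 {
    ///     pool.send(&manager, move |_| i + 1);
    /// }
    /// // Each item is a std::thread::Result; unwrap it and reduce.
    /// let sum: usize = pool.reduce().map(|r| r.unwrap()).sum();
    /// assert_eq!(sum, 10);
    /// ```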
    pub fn reduce(&mut self) -> Iter<'_, 'env, M, T> {
        Iter {
            inner: &mut self.inner,
            batch: None,
            thread_id: 0,
        }
    }

    /// Waits for all tasks to finish execution and stops all threads.
    ///
    /// *Use this to periodically clean up the thread pool if you know that some tasks may panic.*
    ///
    /// # Errors
    ///
    /// Returns an error if a thread panicked.
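    ///
    /// # Examples
    ///
    /// A small sketch showing how a panicking task surfaces as an error (illustrative only):
    ///
    /// ```
    /// use bp3d_threads::UnscopedThreadManager;
    /// use bp3d_threads::ThreadPool;
    /// let manager = UnscopedThreadManager::new();
    /// let mut pool: ThreadPool<UnscopedThreadManager, ()> = ThreadPool::new(1);
    /// pool.send(&manager, |_| panic!("oh no"));
    /// // The worker thread dies on the panic; wait() reports it as an error.
    /// assert!(pool.wait().is_err());
    /// ```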
    pub fn wait(&mut self) -> std::thread::Result<()> {
        for i in 0..self.inner.n_threads {
            if let Some(h) = self.inner.threads[i].take() {
                h.join()?;
                self.inner.term_queue.pop();
                self.inner.running_threads -= 1;
            }
            self.inner.task_stealers[i] = None;
        }
        Ok(())
    }

    /// Waits for all tasks to finish execution and stops all threads.
    ///
    /// *Use this to periodically clean up the thread pool if you know that some tasks may panic.*
    ///
    /// # Errors
    ///
    /// Returns an error if a thread panicked.
    #[deprecated(since = "1.1.0", note = "Please use `wait` or `reduce` instead")]
    pub fn join(&mut self) -> std::thread::Result<()> {
        self.wait()
    }
}