// thread_tree/lib.rs — crate root

1//!
2//! A hierarchical thread pool used for splitting work in a branching fashion.
3//!
4//! This thread pool is good for:
5//!
//! - You want to split work recursively into jobs that take approximately the same time.
7//! - You want thread pool overhead to be low
8//!
9//! This is not good for:
10//!
11//! - You need work stealing
12//! - When you have jobs of uneven size
13//!
14
15// Stack jobs and job execution implementation based on rayon-core by Niko Matsakis and Josh Stone
16//
17use crossbeam_channel::{Sender, bounded};
18
19use std::thread;
20
21mod unwind;
22mod job;
23
24use crate::job::{JobRef, StackJob};
25
26// ThreadTree message on the channel (is just a job ref)
27type TTreeMessage = JobRef;
28
/// A hierarchical thread pool used for splitting work in a branching fashion.
///
/// See [`ThreadTree::new_with_level()`] to create a new thread tree,
/// and see [`ThreadTree::top()`] for a usage example.
///
/// The thread tree has the benefit that at each level, jobs can be sent directly to the thread
/// that is going to execute it - that means there is no contention between waiting threads. The
/// downside is that the structure of the thread tree is rather static.
#[derive(Debug)]
pub struct ThreadTree {
    // Channel to this node's dedicated sibling thread; `None` at the bottom of
    // the tree (level 0), where jobs run inline on the current thread.
    sender: Option<Sender<TTreeMessage>>,
    // The two subtrees handed to jobs in `join`; `None` for level 0 and
    // level 1 trees (see `new_with_level`).
    child: Option<[Box<ThreadTree>; 2]>,
}
42
43// Only three threads needed to have four leaves, see below.
44//
45//           (root)
46//      (1)            2 
47// (1.1)   1.2   (2.1)   2.2
48//
49// Leaves 1.1, 1.2, 2.1 and 2.2 but only 1.2, 2, and 2.2 are new threads - the others inherit the
// current thread from the parent.  That means we have a fanout of four (leaves 1.1 through 2.2)
51// using the current thread and three additional threads.
52//
53// The implementation is such that the root holds ownership of leaf 2, and the root contains a
54// channel sender that passes jobs to the node 2.  Further nodes down continue the same way
55// recursively.
56//
57// Idea for later: implement reservations of (parts of) the tree?
58// So that a 2-2 tree can be used as two separate 1-2 trees simultaneously
59
60impl ThreadTree {
61    const BOTTOM: &'static Self = &ThreadTree::new_level0();
62
63    /// Create a level 0 tree (with no parallelism)
64    #[inline]
65    pub const fn new_level0() -> Self {
66        ThreadTree { sender: None, child: None }
67    }
68
69    /// Create an n-level thread tree with 2<sup>n</sup> leaves
70    ///
71    /// Level 0 has no parallelism
72    /// Level 1 has two nodes
73    /// Level 2 has four nodes (et.c.)
74    ///
75    /// Level must be <= 12; panics on invalid input
76    pub fn new_with_level(level: usize) -> Box<Self> {
77        assert!(level <= 12,
78                "Input exceeds maximum level 12 (equivalent to 2**12 - 1 threads), got level='{}'",
79                level);
80        if level == 0 {
81            Box::new(Self::new_level0())
82        } else if level == 1 {
83            Box::new(ThreadTree { sender: Some(Self::add_thread()), child: None })
84        } else {
85            let fork_2 = Self::new_with_level(level - 1);
86            let fork_3 = Self::new_with_level(level - 1);
87            Box::new(ThreadTree { sender: Some(Self::add_thread()), child: Some([fork_2, fork_3])})
88        }
89    }
90
91    /// Return true if this is a non-dummy pool which will parallelize in join
92    #[inline]
93    pub fn is_parallel(&self) -> bool {
94        self.sender.is_some()
95    }
96
97    /// Get the top thread tree context, where we can inject tasks with join.
98    /// Each job gets a sub-context that can be used to inject tasks further down the corresponding
99    /// branch of the tree.
100    ///
101    /// **Note** to avoid deadlocks, tasks should never be injected into a tree context that
102    /// doesn't belong to the current level. To avoid this should be easy - only call .top() at the
103    /// top level.
104    ///
105    /// The following example shows using a two-level tree and using context to spawn tasks.
106    ///
107    /// ```
108    /// use thread_tree::{ThreadTree, ThreadTreeCtx};
109    ///
110    /// let tp = ThreadTree::new_with_level(2);
111    ///
112    /// fn f(index: i32, ctx: ThreadTreeCtx<'_>) -> i32 {
113    ///     // do work in subtasks here
114    ///     let (a, b) = ctx.join(move |_| index + 1, |_| index + 2);
115    ///
116    ///     return a + b;
117    /// }
118    ///
119    /// let (r0, r1) = tp.top().join(|ctx| f(0, ctx), |ctx| f(1, ctx));
120    ///
121    /// assert_eq!(r0 + r1, (0 + 1) + (0 + 2) + (1 + 1) + (1 + 2));
122    /// ```
123    #[inline]
124    pub fn top(&self) -> ThreadTreeCtx<'_> {
125        ThreadTreeCtx::from(self)
126    }
127
128    // Create a new thread that executes jobs, and return the channel sender that feeds jobs to
129    // this thread.
130    fn add_thread() -> Sender<TTreeMessage> {
131        let (sender, receiver) = bounded::<TTreeMessage>(1); // buffered, we know we have a connection
132        std::thread::spawn(move || {
133            for job in receiver {
134                unsafe {
135                    job.execute()
136                }
137            }
138        });
139        sender
140    }
141}
142
/// A level-specific handle to the thread tree, that can be used to inject jobs.
///
/// See [`ThreadTree::top()`] for more information.
#[derive(Debug, Copy, Clone)]
pub struct ThreadTreeCtx<'a> {
    // The tree node (level) that this context injects jobs into.
    tree: &'a ThreadTree,
    // This handle is marked as non-Send/Sync as a help - there is nothing safety critical about it
    // - but it helps the user to avoid deadlocks - see the top method.
    // (Raw pointers are neither Send nor Sync, so this zero-data field opts
    // the whole struct out of both.)
    _not_send_sync: *const (),
}
153
impl ThreadTreeCtx<'_> {
    // Access the tree node this context refers to.
    #[inline]
    pub(crate) fn get(&self) -> &ThreadTree { self.tree }

    // Build a context handle for `tree`; the pointer field carries no data and
    // only exists to make the type !Send/!Sync (see the struct definition).
    #[inline]
    pub(crate) fn from(tree: &ThreadTree) -> ThreadTreeCtx<'_> {
        ThreadTreeCtx { tree, _not_send_sync: &() }
    }

    /// Return true if this level will parallelize in join; false at the bottom
    /// of the tree, where both jobs run sequentially on the current thread.
    #[inline]
    pub fn is_parallel(&self) -> bool {
        self.get().is_parallel()
    }

    /// Branch out and run a and b simultaneously and return their results jointly.
    ///
    /// Job `a` runs on the current thread while `b` runs on the sibling thread; each is passed
    /// a lower level of the thread tree.
    /// If the bottom of the tree is reached, where no sibling threads are available, both `a` and
    /// `b` run on the current thread.
    ///
    /// If either `a` or `b` panics, the panic is propagated here. If both jobs are executing,
    /// the panic will not propagate until after both jobs have finished.
    /// 
    /// Warning: You must not .join() into the same tree from nested jobs. Nested jobs must
    /// be spawned using the context that each job receives as the first parameter.
    pub fn join<A, B, RA, RB>(&self, a: A, b: B) -> (RA, RB)
        where A: FnOnce(ThreadTreeCtx) -> RA + Send,
              B: FnOnce(ThreadTreeCtx) -> RB + Send,
              RA: Send,
              RB: Send,
    {
        let bottom_level = ThreadTree::BOTTOM;
        let self_ = self.get();
        // Pick the subtree each job's sub-context will refer to; at the bottom
        // (no children) both jobs get the shared dummy level-0 tree.
        let (fork_a, fork_b) = match &self_.child {
            None => (bottom_level, bottom_level),
            Some([fa, fb]) => (&**fa, &**fb),
        };
        //assert!(self_.sender.is_some());

        // SAFETY: `b_job` is planted on this stack frame, and this function
        // does not return (even on panic in A, via the guard below) until B
        // has finished, so the JobRef handed to the sibling thread cannot
        // outlive the data it points to.
        unsafe {
            let a = move || a(ThreadTreeCtx::from(fork_a));
            let b = move || b(ThreadTreeCtx::from(fork_b));

            // first send B to the sibling thread
            let b_job = StackJob::new(b); // plant this safely on the stack
            let b_job_ref = JobRef::new(&b_job);
            // If this node has no sibling thread (bottom of the tree), keep
            // the ref so B runs on the current thread after A.
            let b_runs_here = match self_.sender {
                Some(ref s) => { s.send(b_job_ref).unwrap(); None }
                None => Some(b_job_ref),
            };

            let a_result;
            {
                // Ensure that we will later wait for B, if it is running on
                // another thread. Both in the case of A panic or regular scope exit.
                //
                // If job A panics, we still cannot return until we are sure that job
                // B is complete. This is because it may contain references into the
                // enclosing stack frame(s).
                let _wait_for_b_guard = match b_runs_here {
                    None => Some(WaitForJobGuard::new(&b_job)),
                    Some(_) => None,
                };

                // Execute task A
                a_result = a();

                if let Some(b_job_ref) = b_runs_here {
                    b_job_ref.execute();
                }
                // wait for b here (the guard's Drop spins until B completes)
            }
            // NOTE(review): presumably into_result() yields B's value or
            // re-raises B's panic — confirm against StackJob in job.rs.
            (a_result, b_job.into_result())
        }
    }

    /// Branch out twice and join, running three different jobs
    ///
    /// Branches twice on the left side and once on the right.
    /// The closure is called with corresponding thread tree context and an index in 0..3 for the job.
    pub fn join3l<A, RA>(&self, a: &A) -> ((RA, RA), RA)
        where A: Fn(ThreadTreeCtx, usize) -> RA + Sync,
              RA: Send,
    {
        self.join(
            move |ctx| ctx.join(move |ctx| a(ctx, 0), move |ctx| a(ctx, 1)),
            move |ctx| a(ctx, 2))
    }

    /// Branch out twice and join, running three different jobs
    ///
    /// Branches once on the left side and twice on the right.
    /// The closure is called with corresponding thread tree context and an index in 0..3 for the job.
    pub fn join3r<A, RA>(&self, a: &A) -> (RA, (RA, RA))
        where A: Fn(ThreadTreeCtx, usize) -> RA + Sync,
              RA: Send,
    {
        self.join(
            move |ctx| a(ctx, 0),
            move |ctx| ctx.join(move |ctx| a(ctx, 1), move |ctx| a(ctx, 2)))
    }

    /// Branch out twice and join, running four different jobs.
    ///
    /// Branches twice on each side.
    /// The closure is called with corresponding thread tree context and an index in 0..4 for the job.
    pub fn join4<A, RA>(&self, a: &A) -> ((RA, RA), (RA, RA))
        where A: Fn(ThreadTreeCtx, usize) -> RA + Sync,
              RA: Send,
    {
        self.join(
            move |ctx| ctx.join(move |ctx| a(ctx, 0), move |ctx| a(ctx, 1)),
            move |ctx| ctx.join(move |ctx| a(ctx, 2), move |ctx| a(ctx, 3)))
    }
}
271
272
273fn wait_for_job<F, R>(job: &StackJob<F, R>) {
274    while !job.probe() {
275        //spin_loop_hint();
276        thread::yield_now();
277    }
278}
279
// Drop guard: waits for the referenced stack job to complete when dropped,
// even if the current scope unwinds from a panic.
struct WaitForJobGuard<'a, F, R> {
    job: &'a StackJob<F, R>,
}
283
284impl<'a, F, R> WaitForJobGuard<'a, F, R>
285{
286    fn new(job: &'a StackJob<F, R>) -> Self {
287        Self { job }
288    }
289}
290
291impl<'a, F, R> Drop for WaitForJobGuard<'a, F, R> {
292    fn drop(&mut self) {
293        wait_for_job(self.job)
294    }
295}
296
#[cfg(test)]
mod thread_tree_tests {
    use super::*;

    use std::sync::atomic::AtomicUsize;
    use std::sync::atomic::Ordering;
    use std::sync::Mutex;
    use once_cell::sync::Lazy;
    use std::collections::HashSet;
    use std::thread;
    use std::thread::ThreadId;

    // Millisecond sleep helper. Uses thread::sleep + Duration instead of the
    // deprecated `thread::sleep_ms`, so no `#[allow(deprecated)]` is needed.
    fn sleep_ms(x: u32) {
        std::thread::sleep(std::time::Duration::from_millis(u64::from(x)))
    }

    // A level-0 tree must run both jobs, sequentially, on the calling thread.
    #[test]
    fn stub() {
        let tp = ThreadTree::new_level0();
        let a = AtomicUsize::new(0);
        let b = AtomicUsize::new(0);

        tp.top().join(|_| a.fetch_add(1, Ordering::SeqCst),
                |_| b.fetch_add(1, Ordering::SeqCst));
        assert_eq!(a.load(Ordering::SeqCst), 1);
        assert_eq!(b.load(Ordering::SeqCst), 1);

        let f = || thread::current().id();
        let (aid, bid) = tp.top().join(|_| f(), |_| f());
        assert_eq!(aid, bid);
        assert!(!tp.top().is_parallel());
    }

    // A level-1 tree runs the two jobs on distinct threads.
    #[test]
    fn new_level_1() {
        let tp = ThreadTree::new_with_level(1);
        let a = AtomicUsize::new(0);
        let b = AtomicUsize::new(0);

        tp.top().join(|_| a.fetch_add(1, Ordering::SeqCst),
                |_| b.fetch_add(1, Ordering::SeqCst));
        assert_eq!(a.load(Ordering::SeqCst), 1);
        assert_eq!(b.load(Ordering::SeqCst), 1);

        let f = || thread::current().id();
        let (aid, bid) = tp.top().join(|_| f(), |_| f());
        assert_ne!(aid, bid);
        assert!(tp.top().is_parallel());
    }

    // A level-2 tree fans out to four distinct threads via nested joins.
    #[test]
    fn build_level_2() {
        let tp = ThreadTree::new_with_level(2);
        let a = AtomicUsize::new(0);
        let b = AtomicUsize::new(0);

        tp.top().join(|_| a.fetch_add(1, Ordering::SeqCst),
                |_| b.fetch_add(1, Ordering::SeqCst));
        assert_eq!(a.load(Ordering::SeqCst), 1);
        assert_eq!(b.load(Ordering::SeqCst), 1);

        let f = || thread::current().id();
        let ((aid, bid), (cid, did)) = tp.top().join(
            |tp1| tp1.join(|_| f(), |_| f()),
            |tp1| tp1.join(|_| f(), |_| f()));
        assert_ne!(aid, bid);
        assert_ne!(aid, cid);
        assert_ne!(aid, did);
        assert_ne!(bid, cid);
        assert_ne!(bid, did);
        assert_ne!(cid, did);
    }

    // Two tree roots contending for the same level-2 tree: all 8 subwork
    // invocations must still complete and the total must add up.
    #[test]
    fn overload_2_2() {
        let global = ThreadTree::new_with_level(1);
        let tp = ThreadTree::new_with_level(2);
        let a = AtomicUsize::new(0);

        let range = 0..100;

        let work = |ctx: ThreadTreeCtx<'_>| {
            let subwork = || {
                for i in range.clone() {
                    a.fetch_add(i, Ordering::Relaxed);
                    sleep_ms(1);
                }
            };
            ctx.join(|_| subwork(), |_| subwork());
        };

        global.top().join(
            |_| tp.top().join(work, work),
            |_| tp.top().join(work, work));

        let sum = range.clone().sum::<usize>();

        // 2 global branches x 2 tp branches x 2 subwork calls = 8 = 4 * 2.
        assert_eq!(sum * 4 * 2, a.load(Ordering::SeqCst));

    }

    // Recurse past the bottom of the tree: every leaf thread is visited
    // (2^TREE_LEVEL distinct threads) and every invocation is counted
    // (2^MAX_DEPTH including the two top-level calls).
    #[test]
    fn deep_tree() {
        static THREADS: Lazy<Mutex<HashSet<ThreadId>>> = Lazy::new(|| Mutex::default());
        const TREE_LEVEL: usize = 8;
        const MAX_DEPTH: usize = 12;

        static COUNT: AtomicUsize = AtomicUsize::new(0);

        let tp = ThreadTree::new_with_level(TREE_LEVEL);

        fn f(tp: ThreadTreeCtx<'_>, depth: usize) {
            COUNT.fetch_add(1, Ordering::SeqCst);
            THREADS.lock().unwrap().insert(thread::current().id());
            if depth >= MAX_DEPTH {
                return;
            }
            tp.join(
                |ctx| {
                    f(ctx, depth + 1);
                },
                |ctx| {
                    f(ctx, depth + 1);
                });
        }

        COUNT.fetch_add(2, Ordering::SeqCst); // for the two invocations below.
        tp.top().join(|ctx| f(ctx, 2), |ctx| f(ctx, 2));
        let visited_threads = THREADS.lock().unwrap().len();
        assert_eq!(visited_threads, 1 << TREE_LEVEL);
        assert_eq!(COUNT.load(Ordering::SeqCst), 1 << MAX_DEPTH);
    }

    // A panic in job A propagates out of join.
    #[test]
    #[should_panic]
    fn panic_a() {
        let pool = ThreadTree::new_with_level(1);
        pool.top().join(|_| panic!("Panic in A"), |_| 1 + 1);
    }

    // A panic in job B (on the sibling thread) propagates out of join.
    #[test]
    #[should_panic]
    fn panic_b() {
        let pool = ThreadTree::new_with_level(1);
        pool.top().join(|_| 1 + 1, |_| panic!());
    }

    // Both jobs panicking on separate threads still panics the caller.
    #[test]
    #[should_panic]
    fn panic_both_in_threads() {
        let pool = ThreadTree::new_with_level(1);
        pool.top().join(|_| { sleep_ms(50); panic!("Panic in A") }, |_| panic!("Panic in B"));
    }

    // Both jobs panicking at the bottom (sequential execution) also panics.
    #[test]
    #[should_panic]
    fn panic_both_bottom() {
        let pool = ThreadTree::new_with_level(0);
        pool.top().join(|_| { sleep_ms(50); panic!("Panic in A") }, |_| panic!("Panic in B"));
    }

    // Even when A panics immediately, join must wait for B to run to
    // completion before unwinding (start/finish counters stay in lockstep).
    #[test]
    fn on_panic_a_wait_for_b() {
        let pool = ThreadTree::new_with_level(1);
        for i in 0..3 {
            let start = AtomicUsize::new(0);
            let finish = AtomicUsize::new(0);
            let result = unwind::halt_unwinding(|| {
                pool.top().join(
                    |_| panic!("Panic in A"),
                    |_| {
                        start.fetch_add(1, Ordering::SeqCst);
                        sleep_ms(50);
                        finish.fetch_add(1, Ordering::SeqCst);
                    });
            });
            let start_count = start.load(Ordering::SeqCst);
            let finish_count = finish.load(Ordering::SeqCst);
            assert_eq!(start_count, finish_count);
            assert!(result.is_err());
            println!("Pass {} with start: {} == finish {}", i,
                     start_count, finish_count);
        }
    }
}