//! A hierarchy of worker threads that coordinate through fork-join calls.

use crossbeam_channel::{Sender, bounded};

use std::thread;

mod unwind;
mod job;

use crate::job::{JobRef, StackJob};

type TTreeMessage = JobRef;

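/// A tree of threads that supports fork-join parallelism at each level.
///
/// A node with a `sender` owns one background thread; a node with children
/// owns two subtrees that serve the next level of nested joins. A tree of
/// level `n` owns `2**n - 1` threads in total.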
#[derive(Debug)]
pub struct ThreadTree {
    sender: Option<Sender<TTreeMessage>>,
    child: Option<[Box<ThreadTree>; 2]>,
}

impl ThreadTree {
    const BOTTOM: &'static Self = &ThreadTree::new_level0();

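    /// Create a tree with no parallelism: both arms of a `join` run in
    /// sequence on the calling thread.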
    #[inline]
    pub const fn new_level0() -> Self {
        ThreadTree { sender: None, child: None }
    }

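    /// Create a tree of the given level, with `2**level - 1` background
    /// threads; level 0 is equivalent to `new_level0`.
    ///
    /// Panics if `level` exceeds 12.
    ///
    /// A usage sketch (marked `ignore` because the doctest would need the
    /// crate's public import paths):
    ///
    /// ```ignore
    /// let tree = ThreadTree::new_with_level(2); // 3 background threads
    /// let (a, b) = tree.top().join(|_| 1 + 1, |_| 2 + 2);
    /// assert_eq!((a, b), (2, 4));
    /// ```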
    pub fn new_with_level(level: usize) -> Box<Self> {
        assert!(level <= 12,
                "Input exceeds maximum level 12 (equivalent to 2**12 - 1 threads), got level='{}'",
                level);
        if level == 0 {
            Box::new(Self::new_level0())
        } else if level == 1 {
            Box::new(ThreadTree { sender: Some(Self::add_thread()), child: None })
        } else {
            let fork_2 = Self::new_with_level(level - 1);
            let fork_3 = Self::new_with_level(level - 1);
            Box::new(ThreadTree { sender: Some(Self::add_thread()), child: Some([fork_2, fork_3]) })
        }
    }

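    /// Return `true` if this node has a background thread, i.e. a `join`
    /// here can actually run its second closure on another thread.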
    #[inline]
    pub fn is_parallel(&self) -> bool {
        self.sender.is_some()
    }

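    /// Get the context for the tree's top level, the entry point for `join`
    /// and the `join3l`/`join3r`/`join4` helpers.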
    #[inline]
    pub fn top(&self) -> ThreadTreeCtx<'_> {
        ThreadTreeCtx::from(self)
    }

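    /// Spawn a worker that executes incoming jobs until its channel's
    /// sender is dropped.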
    fn add_thread() -> Sender<TTreeMessage> {
        let (sender, receiver) = bounded::<TTreeMessage>(1);
        thread::spawn(move || {
            for job in receiver {
                unsafe {
                    // Safety: the JobRef's creator in `join` keeps the
                    // backing StackJob alive until the job has executed.
                    job.execute()
                }
            }
        });
        sender
    }
}

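/// A borrowed context for one level of a `ThreadTree`.
///
/// The `*const ()` member makes this type `!Send` and `!Sync`, so a context
/// cannot escape to another thread; each closure passed to `join` receives
/// a fresh context for the level below.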
#[derive(Debug, Copy, Clone)]
pub struct ThreadTreeCtx<'a> {
    tree: &'a ThreadTree,
    _not_send_sync: *const (),
}

impl ThreadTreeCtx<'_> {
    #[inline]
    pub(crate) fn get(&self) -> &ThreadTree { self.tree }

    #[inline]
    pub(crate) fn from(tree: &ThreadTree) -> ThreadTreeCtx<'_> {
        ThreadTreeCtx { tree, _not_send_sync: &() }
    }

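    /// Return `true` if `join` at this level can run its arms in parallel.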
    #[inline]
    pub fn is_parallel(&self) -> bool {
        self.get().is_parallel()
    }

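    /// Branch into two tasks and return both results.
    ///
    /// `a` runs on the current thread while `b` is sent to this level's
    /// background thread, if there is one; otherwise both run sequentially
    /// here. If `b` was sent to another thread, a panic in `a` still waits
    /// for `b` to finish before unwinding continues.
    ///
    /// A sketch of nested use (`ignore`d for the same doctest path reason
    /// as above):
    ///
    /// ```ignore
    /// let tree = ThreadTree::new_with_level(2);
    /// let ((a, b), (c, d)) = tree.top().join(
    ///     |ctx| ctx.join(|_| "a", |_| "b"),
    ///     |ctx| ctx.join(|_| "c", |_| "d"));
    /// assert_eq!((a, b, c, d), ("a", "b", "c", "d"));
    /// ```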
    pub fn join<A, B, RA, RB>(&self, a: A, b: B) -> (RA, RB)
        where A: FnOnce(ThreadTreeCtx) -> RA + Send,
              B: FnOnce(ThreadTreeCtx) -> RB + Send,
              RA: Send,
              RB: Send,
    {
        let bottom_level = ThreadTree::BOTTOM;
        let self_ = self.get();
        let (fork_a, fork_b) = match &self_.child {
            None => (bottom_level, bottom_level),
            Some([fa, fb]) => (&**fa, &**fb),
        };
        unsafe {
            let a = move || a(ThreadTreeCtx::from(fork_a));
            let b = move || b(ThreadTreeCtx::from(fork_b));

            // The job for b lives on this stack frame; it must not be
            // dropped until b has run to completion.
            let b_job = StackJob::new(b);
            let b_job_ref = JobRef::new(&b_job);
            let b_runs_here = match self_.sender {
                Some(ref s) => { s.send(b_job_ref).unwrap(); None }
                None => Some(b_job_ref),
            };

            let a_result;
            {
                // If a panics, the guard in scope still waits for the
                // remotely running b before unwinding continues.
                let _wait_for_b_guard = match b_runs_here {
                    None => Some(WaitForJobGuard::new(&b_job)),
                    Some(_) => None,
                };

                a_result = a();

                if let Some(b_job_ref) = b_runs_here {
                    b_job_ref.execute();
                }
            }
            // b has finished here; into_result yields its value or resumes
            // its panic.
            (a_result, b_job.into_result())
        }
    }

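    /// Three-way split with the pair grouped on the left.
    ///
    /// Calls `a` with indices `0`, `1`, `2` and returns `((r0, r1), r2)`;
    /// the pair shares one child subtree while index `2` gets the other.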
    pub fn join3l<A, RA>(&self, a: &A) -> ((RA, RA), RA)
        where A: Fn(ThreadTreeCtx, usize) -> RA + Sync,
              RA: Send,
    {
        self.join(
            move |ctx| ctx.join(move |ctx| a(ctx, 0), move |ctx| a(ctx, 1)),
            move |ctx| a(ctx, 2))
    }

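    /// Like `join3l`, but with the pair grouped on the right: calls `a`
    /// with indices `0`, `1`, `2` and returns `(r0, (r1, r2))`.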
    pub fn join3r<A, RA>(&self, a: &A) -> (RA, (RA, RA))
        where A: Fn(ThreadTreeCtx, usize) -> RA + Sync,
              RA: Send,
    {
        self.join(
            move |ctx| a(ctx, 0),
            move |ctx| ctx.join(move |ctx| a(ctx, 1), move |ctx| a(ctx, 2)))
    }

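    /// Four-way split: calls `a` with indices `0..4` across both child
    /// subtrees and returns `((r0, r1), (r2, r3))`.
    ///
    /// Sketch (`ignore`d; see `new_with_level`):
    ///
    /// ```ignore
    /// let tree = ThreadTree::new_with_level(2);
    /// let ((r0, r1), (r2, r3)) = tree.top().join4(&|_, i| i * i);
    /// assert_eq!((r0, r1, r2, r3), (0, 1, 4, 9));
    /// ```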
    pub fn join4<A, RA>(&self, a: &A) -> ((RA, RA), (RA, RA))
        where A: Fn(ThreadTreeCtx, usize) -> RA + Sync,
              RA: Send,
    {
        self.join(
            move |ctx| ctx.join(move |ctx| a(ctx, 0), move |ctx| a(ctx, 1)),
            move |ctx| ctx.join(move |ctx| a(ctx, 2), move |ctx| a(ctx, 3)))
    }
}

/// Spin-wait with `yield_now` until `job` has been executed.
fn wait_for_job<F, R>(job: &StackJob<F, R>) {
    while !job.probe() {
        thread::yield_now();
    }
}

/// On drop, blocks until the watched job has completed, so that a panic in
/// the current frame still waits for a job that borrows from it.
struct WaitForJobGuard<'a, F, R> {
    job: &'a StackJob<F, R>,
}

impl<'a, F, R> WaitForJobGuard<'a, F, R> {
    fn new(job: &'a StackJob<F, R>) -> Self {
        Self { job }
    }
}

impl<'a, F, R> Drop for WaitForJobGuard<'a, F, R> {
    fn drop(&mut self) {
        wait_for_job(self.job)
    }
}

#[cfg(test)]
mod thread_tree_tests {
    use super::*;

    use std::sync::atomic::AtomicUsize;
    use std::sync::atomic::Ordering;
    use std::sync::Mutex;
    use once_cell::sync::Lazy;
    use std::collections::HashSet;
    use std::thread;
    use std::thread::ThreadId;

    #[allow(deprecated)]
    fn sleep_ms(x: u32) {
        std::thread::sleep_ms(x)
    }

    #[test]
    fn stub() {
        let tp = ThreadTree::new_level0();
        let a = AtomicUsize::new(0);
        let b = AtomicUsize::new(0);

        tp.top().join(|_| a.fetch_add(1, Ordering::SeqCst),
                      |_| b.fetch_add(1, Ordering::SeqCst));
        assert_eq!(a.load(Ordering::SeqCst), 1);
        assert_eq!(b.load(Ordering::SeqCst), 1);

        let f = || thread::current().id();
        let (aid, bid) = tp.top().join(|_| f(), |_| f());
        assert_eq!(aid, bid);
        assert!(!tp.top().is_parallel());
    }

    #[test]
    fn new_level_1() {
        let tp = ThreadTree::new_with_level(1);
        let a = AtomicUsize::new(0);
        let b = AtomicUsize::new(0);

        tp.top().join(|_| a.fetch_add(1, Ordering::SeqCst),
                      |_| b.fetch_add(1, Ordering::SeqCst));
        assert_eq!(a.load(Ordering::SeqCst), 1);
        assert_eq!(b.load(Ordering::SeqCst), 1);

        let f = || thread::current().id();
        let (aid, bid) = tp.top().join(|_| f(), |_| f());
        assert_ne!(aid, bid);
        assert!(tp.top().is_parallel());
    }

    #[test]
    fn build_level_2() {
        let tp = ThreadTree::new_with_level(2);
        let a = AtomicUsize::new(0);
        let b = AtomicUsize::new(0);

        tp.top().join(|_| a.fetch_add(1, Ordering::SeqCst),
                      |_| b.fetch_add(1, Ordering::SeqCst));
        assert_eq!(a.load(Ordering::SeqCst), 1);
        assert_eq!(b.load(Ordering::SeqCst), 1);

        // A 4-way split should land on four distinct threads.
        let f = || thread::current().id();
        let ((aid, bid), (cid, did)) = tp.top().join(
            |tp1| tp1.join(|_| f(), |_| f()),
            |tp1| tp1.join(|_| f(), |_| f()));
        assert_ne!(aid, bid);
        assert_ne!(aid, cid);
        assert_ne!(aid, did);
        assert_ne!(bid, cid);
        assert_ne!(bid, did);
        assert_ne!(cid, did);
    }

    #[test]
    fn overload_2_2() {
        let global = ThreadTree::new_with_level(1);
        let tp = ThreadTree::new_with_level(2);
        let a = AtomicUsize::new(0);

        let range = 0..100;

        let work = |ctx: ThreadTreeCtx<'_>| {
            let subwork = || {
                for i in range.clone() {
                    a.fetch_add(i, Ordering::Relaxed);
                    sleep_ms(1);
                }
            };
            ctx.join(|_| subwork(), |_| subwork());
        };

        // Two joins on `global` drive four `work` calls on `tp`, and each
        // `work` runs `subwork` twice: 4 * 2 executions in total.
        global.top().join(
            |_| tp.top().join(work, work),
            |_| tp.top().join(work, work));

        let sum = range.clone().sum::<usize>();

        assert_eq!(sum * 4 * 2, a.load(Ordering::SeqCst));
    }

    #[test]
    fn deep_tree() {
        static THREADS: Lazy<Mutex<HashSet<ThreadId>>> = Lazy::new(|| Mutex::default());
        const TREE_LEVEL: usize = 8;
        const MAX_DEPTH: usize = 12;

        static COUNT: AtomicUsize = AtomicUsize::new(0);

        let tp = ThreadTree::new_with_level(TREE_LEVEL);

        fn f(tp: ThreadTreeCtx<'_>, depth: usize) {
            COUNT.fetch_add(1, Ordering::SeqCst);
            THREADS.lock().unwrap().insert(thread::current().id());
            if depth >= MAX_DEPTH {
                return;
            }
            tp.join(
                |ctx| {
                    f(ctx, depth + 1);
                },
                |ctx| {
                    f(ctx, depth + 1);
                });
        }

        // The recursion below starts at depth 2 and makes 2**MAX_DEPTH - 2
        // calls; add 2 so COUNT comes out to exactly 1 << MAX_DEPTH.
        COUNT.fetch_add(2, Ordering::SeqCst);
        tp.top().join(|ctx| f(ctx, 2), |ctx| f(ctx, 2));

        // All 2**TREE_LEVEL - 1 worker threads plus the main thread.
        let visited_threads = THREADS.lock().unwrap().len();
        assert_eq!(visited_threads, 1 << TREE_LEVEL);
        assert_eq!(COUNT.load(Ordering::SeqCst), 1 << MAX_DEPTH);
    }

    #[test]
    #[should_panic]
    fn panic_a() {
        let pool = ThreadTree::new_with_level(1);
        pool.top().join(|_| panic!("Panic in A"), |_| 1 + 1);
    }

    #[test]
    #[should_panic]
    fn panic_b() {
        let pool = ThreadTree::new_with_level(1);
        pool.top().join(|_| 1 + 1, |_| panic!());
    }

    #[test]
    #[should_panic]
    fn panic_both_in_threads() {
        let pool = ThreadTree::new_with_level(1);
        pool.top().join(|_| { sleep_ms(50); panic!("Panic in A") }, |_| panic!("Panic in B"));
    }

    #[test]
    #[should_panic]
    fn panic_both_bottom() {
        let pool = ThreadTree::new_with_level(0);
        pool.top().join(|_| { sleep_ms(50); panic!("Panic in A") }, |_| panic!("Panic in B"));
    }

    #[test]
    fn on_panic_a_wait_for_b() {
        let pool = ThreadTree::new_with_level(1);
        for i in 0..3 {
            let start = AtomicUsize::new(0);
            let finish = AtomicUsize::new(0);
            let result = unwind::halt_unwinding(|| {
                pool.top().join(
                    |_| panic!("Panic in A"),
                    |_| {
                        start.fetch_add(1, Ordering::SeqCst);
                        sleep_ms(50);
                        finish.fetch_add(1, Ordering::SeqCst);
                    });
            });
            // If b started, it must also have finished, even though a panicked.
            let start_count = start.load(Ordering::SeqCst);
            let finish_count = finish.load(Ordering::SeqCst);
            assert_eq!(start_count, finish_count);
            assert!(result.is_err());
            println!("Pass {} with start: {} == finish {}", i,
                     start_count, finish_count);
        }
    }
}