use crossbeam_channel::{Sender, bounded};
use std::thread;
mod unwind;
mod job;
use crate::job::{JobRef, StackJob};
// The only message type sent to worker threads: an erased, raw
// reference to a job living on some caller's stack.
type TTreeMessage = JobRef;

/// A binary tree of dedicated worker threads used for recursive
/// fork-join parallelism: each `join` on a node may ship one closure
/// to this node's worker and descend into the two child subtrees.
#[derive(Debug)]
pub struct ThreadTree {
    // Channel to this node's dedicated worker thread; `None` for a
    // leaf node, whose joins run sequentially on the caller.
    sender: Option<Sender<TTreeMessage>>,
    // The two child subtrees; `None` for level-0 and level-1 trees.
    child: Option<[Box<ThreadTree>; 2]>,
}
impl ThreadTree {
const BOTTOM: &'static Self = &ThreadTree::new_level0();
#[inline]
pub const fn new_level0() -> Self {
ThreadTree { sender: None, child: None }
}
pub fn new_with_level(level: usize) -> Box<Self> {
assert!(level <= 12,
"Input exceeds maximum level 12 (equivalent to 2**12 - 1 threads), got level='{}'",
level);
if level == 0 {
Box::new(Self::new_level0())
} else if level == 1 {
Box::new(ThreadTree { sender: Some(Self::add_thread()), child: None })
} else {
let fork_2 = Self::new_with_level(level - 1);
let fork_3 = Self::new_with_level(level - 1);
Box::new(ThreadTree { sender: Some(Self::add_thread()), child: Some([fork_2, fork_3])})
}
}
#[inline]
pub fn is_parallel(&self) -> bool {
self.sender.is_some()
}
#[inline]
pub fn top(&self) -> ThreadTreeCtx<'_> {
ThreadTreeCtx::from(self)
}
fn add_thread() -> Sender<TTreeMessage> {
let (sender, receiver) = bounded::<TTreeMessage>(1); std::thread::spawn(move || {
for job in receiver {
unsafe {
job.execute()
}
}
});
sender
}
}
/// A borrowed handle to one node of a [`ThreadTree`], passed to the
/// closures given to `join`. Cheap to copy, but tied to the tree's
/// lifetime and pinned to the thread it was handed to.
#[derive(Debug, Copy, Clone)]
pub struct ThreadTreeCtx<'a> {
    tree: &'a ThreadTree,
    // Raw pointers are !Send and !Sync, so this field opts the whole
    // context out of Send/Sync without any runtime cost.
    _not_send_sync: *const (),
}
impl ThreadTreeCtx<'_> {
    /// Access the tree node this context refers to.
    #[inline]
    pub(crate) fn get(&self) -> &ThreadTree { self.tree }

    /// Wrap a tree node in a context handle.
    #[inline]
    pub(crate) fn from(tree: &ThreadTree) -> ThreadTreeCtx<'_> {
        // The `&()` collapses to a dangling-but-valid pointer; its only
        // purpose is to make the context !Send + !Sync.
        ThreadTreeCtx { tree, _not_send_sync: &() }
    }

    /// True when this node has a worker thread, so `join` may run its
    /// two closures in parallel.
    #[inline]
    pub fn is_parallel(&self) -> bool {
        self.get().is_parallel()
    }

    /// Run `a` and `b`, possibly in parallel, and return both results.
    ///
    /// `b` is shipped to this node's worker thread when one exists;
    /// otherwise both closures run sequentially on the current thread.
    /// Each closure receives the context of the corresponding child
    /// subtree (or the shared leaf context when there are no children).
    pub fn join<A, B, RA, RB>(&self, a: A, b: B) -> (RA, RB)
        where A: FnOnce(ThreadTreeCtx) -> RA + Send,
              B: FnOnce(ThreadTreeCtx) -> RB + Send,
              RA: Send,
              RB: Send,
    {
        let bottom_level = ThreadTree::BOTTOM;
        let self_ = self.get();
        // Child contexts: the real subtrees when present, otherwise the
        // shared bottom (leaf) node for both sides.
        let (fork_a, fork_b) = match &self_.child {
            None => (bottom_level, bottom_level),
            Some([fa, fb]) => (&**fa, &**fb),
        };
        unsafe {
            let a = move || a(ThreadTreeCtx::from(fork_a));
            let b = move || b(ThreadTreeCtx::from(fork_b));
            // SAFETY: `b_job` lives on this stack frame and `b_job_ref`
            // is a raw pointer into it. The code below guarantees the
            // job has finished (here or on the worker) before this
            // frame is left — including during unwinding, via the wait
            // guard. Do not reorder these statements.
            let b_job = StackJob::new(b); let b_job_ref = JobRef::new(&b_job);
            // Send `b` to the worker when this node has one; otherwise
            // keep the ref to run it on this thread after `a`.
            let b_runs_here = match self_.sender {
                Some(ref s) => { s.send(b_job_ref).unwrap(); None }
                None => Some(b_job_ref),
            };
            let a_result;
            {
                // When `b` is on the worker, this guard's Drop waits
                // for it to finish — so even if `a` panics we never
                // unwind past `b_job` while the worker may still be
                // touching it.
                let _wait_for_b_guard = match b_runs_here {
                    None => Some(WaitForJobGuard::new(&b_job)),
                    Some(_) => None,
                };
                a_result = a();
                if let Some(b_job_ref) = b_runs_here {
                    b_job_ref.execute();
                }
            }
            // NOTE(review): into_result presumably also re-raises a
            // panic captured from `b` — confirm against job.rs.
            (a_result, b_job.into_result())
        }
    }

    /// Call `a` three times with indices 0, 1, 2, grouping the first
    /// two on the left child: returns `((a0, a1), a2)`.
    pub fn join3l<A, RA>(&self, a: &A) -> ((RA, RA), RA)
        where A: Fn(ThreadTreeCtx, usize) -> RA + Sync,
              RA: Send,
    {
        self.join(
            move |ctx| ctx.join(move |ctx| a(ctx, 0), move |ctx| a(ctx, 1)),
            move |ctx| a(ctx, 2))
    }

    /// Call `a` three times with indices 0, 1, 2, grouping the last
    /// two on the right child: returns `(a0, (a1, a2))`.
    pub fn join3r<A, RA>(&self, a: &A) -> (RA, (RA, RA))
        where A: Fn(ThreadTreeCtx, usize) -> RA + Sync,
              RA: Send,
    {
        self.join(
            move |ctx| a(ctx, 0),
            move |ctx| ctx.join(move |ctx| a(ctx, 1), move |ctx| a(ctx, 2)))
    }

    /// Call `a` four times with indices 0..=3, two per child subtree:
    /// returns `((a0, a1), (a2, a3))`.
    pub fn join4<A, RA>(&self, a: &A) -> ((RA, RA), (RA, RA))
        where A: Fn(ThreadTreeCtx, usize) -> RA + Sync,
              RA: Send,
    {
        self.join(
            move |ctx| ctx.join(move |ctx| a(ctx, 0), move |ctx| a(ctx, 1)),
            move |ctx| ctx.join(move |ctx| a(ctx, 2), move |ctx| a(ctx, 3)))
    }
}
/// Busy-wait (yielding the CPU between probes) until `job` reports
/// completion.
fn wait_for_job<F, R>(job: &StackJob<F, R>) {
    loop {
        if job.probe() {
            break;
        }
        thread::yield_now();
    }
}
// RAII guard that blocks in Drop until the referenced job completes;
// used to keep a stack-allocated job alive across a panic in `join`.
struct WaitForJobGuard<'a, F, R> {
    job: &'a StackJob<F, R>,
}
impl<'a, F, R> WaitForJobGuard<'a, F, R>
{
fn new(job: &'a StackJob<F, R>) -> Self {
Self { job }
}
}
impl<'a, F, R> Drop for WaitForJobGuard<'a, F, R> {
    fn drop(&mut self) {
        // Runs even during unwinding: guarantees the remote job has
        // finished before the stack frame that owns it is torn down.
        wait_for_job(self.job)
    }
}
#[cfg(test)]
mod thread_tree_tests {
    use super::*;
    use std::sync::atomic::AtomicUsize;
    use std::sync::atomic::Ordering;
    use std::sync::Mutex;
    use once_cell::sync::Lazy;
    use std::collections::HashSet;
    use std::thread;
    use std::thread::ThreadId;

    /// Millisecond sleep helper. Uses `thread::sleep` with a
    /// `Duration` instead of the deprecated `thread::sleep_ms`, so no
    /// `#[allow(deprecated)]` attributes are needed.
    fn sleep_ms(x: u32) {
        std::thread::sleep(std::time::Duration::from_millis(u64::from(x)))
    }

    // A level-0 tree runs both closures on the calling thread.
    #[test]
    fn stub() {
        let tp = ThreadTree::new_level0();
        let a = AtomicUsize::new(0);
        let b = AtomicUsize::new(0);

        tp.top().join(|_| a.fetch_add(1, Ordering::SeqCst),
                      |_| b.fetch_add(1, Ordering::SeqCst));
        assert_eq!(a.load(Ordering::SeqCst), 1);
        assert_eq!(b.load(Ordering::SeqCst), 1);

        let f = || thread::current().id();
        let (aid, bid) = tp.top().join(|_| f(), |_| f());
        // No worker thread: both sides see the same thread id.
        assert_eq!(aid, bid);
        assert!(!tp.top().is_parallel());
    }

    // A level-1 tree has one worker, so the two sides of a join run on
    // different threads.
    #[test]
    fn new_level_1() {
        let tp = ThreadTree::new_with_level(1);
        let a = AtomicUsize::new(0);
        let b = AtomicUsize::new(0);

        tp.top().join(|_| a.fetch_add(1, Ordering::SeqCst),
                      |_| b.fetch_add(1, Ordering::SeqCst));
        assert_eq!(a.load(Ordering::SeqCst), 1);
        assert_eq!(b.load(Ordering::SeqCst), 1);

        let f = || thread::current().id();
        let (aid, bid) = tp.top().join(|_| f(), |_| f());
        assert_ne!(aid, bid);
        assert!(tp.top().is_parallel());
    }

    // A level-2 tree gives four pairwise-distinct threads to a nested
    // join-of-joins.
    #[test]
    fn build_level_2() {
        let tp = ThreadTree::new_with_level(2);
        let a = AtomicUsize::new(0);
        let b = AtomicUsize::new(0);

        tp.top().join(|_| a.fetch_add(1, Ordering::SeqCst),
                      |_| b.fetch_add(1, Ordering::SeqCst));
        assert_eq!(a.load(Ordering::SeqCst), 1);
        assert_eq!(b.load(Ordering::SeqCst), 1);

        let f = || thread::current().id();
        let ((aid, bid), (cid, did)) = tp.top().join(
            |tp1| tp1.join(|_| f(), |_| f()),
            |tp1| tp1.join(|_| f(), |_| f()));
        assert_ne!(aid, bid);
        assert_ne!(aid, cid);
        assert_ne!(aid, did);
        assert_ne!(bid, cid);
        assert_ne!(bid, did);
        assert_ne!(cid, did);
    }

    // Two outer tasks share one inner tree: the work must still all be
    // accounted for even when the inner tree is oversubscribed.
    #[test]
    fn overload_2_2() {
        let global = ThreadTree::new_with_level(1);
        let tp = ThreadTree::new_with_level(2);
        let a = AtomicUsize::new(0);
        let range = 0..100;

        let work = |ctx: ThreadTreeCtx<'_>| {
            let subwork = || {
                for i in range.clone() {
                    a.fetch_add(i, Ordering::Relaxed);
                    sleep_ms(1);
                }
            };
            ctx.join(|_| subwork(), |_| subwork());
        };

        global.top().join(
            |_| tp.top().join(work, work),
            |_| tp.top().join(work, work));

        // subwork runs 4 (joins) x 2 (outer tasks) times.
        let sum = range.clone().sum::<usize>();
        assert_eq!(sum * 4 * 2, a.load(Ordering::SeqCst));
    }

    // Recurse deeper than the tree itself; both the total call count
    // and the set of distinct threads must come out exactly.
    #[test]
    fn deep_tree() {
        static THREADS: Lazy<Mutex<HashSet<ThreadId>>> = Lazy::new(|| Mutex::default());
        const TREE_LEVEL: usize = 8;
        const MAX_DEPTH: usize = 12;
        static COUNT: AtomicUsize = AtomicUsize::new(0);

        let tp = ThreadTree::new_with_level(TREE_LEVEL);

        fn f(tp: ThreadTreeCtx<'_>, depth: usize) {
            COUNT.fetch_add(1, Ordering::SeqCst);
            THREADS.lock().unwrap().insert(thread::current().id());
            if depth >= MAX_DEPTH {
                return;
            }
            tp.join(
                |ctx| {
                    f(ctx, depth + 1);
                },
                |ctx| {
                    f(ctx, depth + 1);
                });
        }

        // Pre-count the two virtual depth-0/1 nodes; the recursion then
        // contributes 2**MAX_DEPTH - 2 calls.
        COUNT.fetch_add(2, Ordering::SeqCst);
        tp.top().join(|ctx| f(ctx, 2), |ctx| f(ctx, 2));

        // 2**TREE_LEVEL - 1 workers plus the main thread.
        let visited_threads = THREADS.lock().unwrap().len();
        assert_eq!(visited_threads, 1 << TREE_LEVEL);
        assert_eq!(COUNT.load(Ordering::SeqCst), 1 << MAX_DEPTH);
    }

    #[test]
    #[should_panic]
    fn panic_a() {
        let pool = ThreadTree::new_with_level(1);
        pool.top().join(|_| panic!("Panic in A"), |_| 1 + 1);
    }

    #[test]
    #[should_panic]
    fn panic_b() {
        let pool = ThreadTree::new_with_level(1);
        pool.top().join(|_| 1 + 1, |_| panic!());
    }

    #[test]
    #[should_panic]
    fn panic_both_in_threads() {
        let pool = ThreadTree::new_with_level(1);
        pool.top().join(|_| { sleep_ms(50); panic!("Panic in A") }, |_| panic!("Panic in B"));
    }

    #[test]
    #[should_panic]
    fn panic_both_bottom() {
        let pool = ThreadTree::new_with_level(0);
        pool.top().join(|_| { sleep_ms(50); panic!("Panic in A") }, |_| panic!("Panic in B"));
    }

    // When A panics, the join must still wait for B to finish before
    // unwinding (B's start/finish counters must always agree).
    #[test]
    fn on_panic_a_wait_for_b() {
        let pool = ThreadTree::new_with_level(1);
        for i in 0..3 {
            let start = AtomicUsize::new(0);
            let finish = AtomicUsize::new(0);
            let result = unwind::halt_unwinding(|| {
                pool.top().join(
                    |_| panic!("Panic in A"),
                    |_| {
                        start.fetch_add(1, Ordering::SeqCst);
                        sleep_ms(50);
                        finish.fetch_add(1, Ordering::SeqCst);
                    });
            });
            let start_count = start.load(Ordering::SeqCst);
            let finish_count = finish.load(Ordering::SeqCst);
            assert_eq!(start_count, finish_count);
            assert!(result.is_err());
            println!("Pass {} with start: {} == finish {}", i,
                     start_count, finish_count);
        }
    }
}