use crossbeam_channel::{Sender, bounded};
#[cfg(feature="unstable-thread-sea")]
use std::sync::Arc;
#[cfg(feature="unstable-thread-sea")]
use crossbeam_channel::Receiver;
#[cfg(feature="unstable-thread-sea")]
use std::sync::atomic::AtomicUsize;
#[cfg(feature="unstable-thread-sea")]
use std::sync::atomic::Ordering;
use std::thread;
mod unwind;
mod job;
use crate::job::{JobRef, StackJob};
// A job sent to a sea worker.
#[cfg(feature="unstable-thread-sea")]
type Message = JobRef;
// What workers wait for between reservations: a per-pool job channel.
#[cfg(feature="unstable-thread-sea")]
type GroupMessage = Receiver<Message>;
/// A shared "sea" of worker threads from which fixed-size [`ThreadPool`]s
/// can be carved out via [`ThreadSea::reserve`].
///
/// Workers block on `receiver` for a `GroupMessage` (a job channel for one
/// reservation), then drain jobs from that channel until it disconnects.
#[cfg(feature="unstable-thread-sea")]
#[derive(Debug)]
pub struct ThreadSea {
    // `reserve` pushes one clone of the pool's job channel per granted worker.
    sender: Sender<GroupMessage>,
    // Cloned into every worker; workers park here between reservations.
    receiver: Receiver<GroupMessage>,
    // Thread count fixed at construction; read by `thread_count()`.
    thread_count: AtomicUsize,
    // NOTE(review): initialized but never updated or read afterwards —
    // looks like bookkeeping reserved for future use; confirm before relying on it.
    threads_available: Arc<AtomicUsize>,
    // Monotonic counter handing each spawned worker a distinct id.
    thread_id: AtomicUsize,
}
/// Per-worker state moved into each sea thread at spawn time.
#[cfg(feature="unstable-thread-sea")]
#[derive(Debug)]
struct SeaLocalInfo {
    // Source of job channels (one per reservation this worker joins).
    receiver: Receiver<GroupMessage>,
    // Shared counter from the sea; not consulted by the worker loop itself.
    threads_available: Arc<AtomicUsize>,
    // This worker's unique id (assigned from `ThreadSea::thread_id`).
    thread_id: usize,
}
#[cfg(feature="unstable-thread-sea")]
impl ThreadSea {
    /// Create a sea with `thread_count` worker threads.
    ///
    /// Every worker immediately blocks waiting for a job channel handed
    /// out by [`ThreadSea::reserve`].
    pub fn new(thread_count: usize) -> Self {
        let nthreads = thread_count;
        // Buffered channel: up to `nthreads` reservations can be parked
        // before workers pick them up.
        let (sender, receiver) = bounded(nthreads);
        let pool = ThreadSea {
            sender,
            receiver,
            thread_count: AtomicUsize::new(nthreads),
            threads_available: Arc::new(AtomicUsize::new(nthreads)),
            thread_id: AtomicUsize::new(0),
        };
        for _ in 0..nthreads {
            pool.add_thread();
        }
        pool
    }

    /// Number of worker threads this sea was built with.
    pub fn thread_count(&self) -> usize {
        self.thread_count.load(Ordering::Acquire)
    }

    /// Reserve up to `thread_count` workers as a dedicated [`ThreadPool`].
    ///
    /// One clone of the pool's job channel is enqueued per worker; the
    /// returned pool records how many were actually granted (`try_send`
    /// fails once the group queue is full, so this may be fewer than asked).
    pub fn reserve(&self, thread_count: usize) -> ThreadPool {
        // Rendezvous channel for jobs: a send completes only when some
        // worker is ready to take the job.
        let (sender, receiver) = bounded(0);
        let mut granted = 0;
        for _ in 0..thread_count {
            if self.sender.try_send(receiver.clone()).is_ok() {
                granted += 1;
            }
        }
        ThreadPool {
            sender,
            thread_count: granted,
        }
    }

    /// Bookkeeping bundle moved into a newly spawned worker.
    fn local_info(&self) -> SeaLocalInfo {
        SeaLocalInfo {
            receiver: self.receiver.clone(),
            threads_available: self.threads_available.clone(),
            thread_id: self.thread_id.fetch_add(1, Ordering::Relaxed),
        }
    }

    /// Spawn one worker thread. It repeatedly takes a reservation's job
    /// channel, drains it until that pool disconnects, then waits for the
    /// next reservation; it exits when the sea itself is dropped.
    fn add_thread(&self) {
        let local = self.local_info();
        std::thread::spawn(move || {
            while let Ok(group) = local.receiver.recv() {
                while let Ok(job) = group.recv() {
                    // SAFETY: the enqueueing caller keeps the job's backing
                    // storage alive until the job signals completion.
                    unsafe { job.execute() }
                }
            }
        });
    }
}
// A job sent down one edge of the thread tree.
type TTreeMessage = JobRef;
/// A static binary tree of worker threads supporting nested fork-join.
///
/// A node with `sender: Some(..)` owns one dedicated worker thread; its
/// two children (if any) serve the two halves of a `join`. A node with
/// `sender: None` runs both closures on the calling thread.
#[derive(Debug)]
pub struct ThreadTree {
    // Channel to this node's dedicated worker; `None` at the leaf level.
    sender: Option<Sender<TTreeMessage>>,
    // Subtrees used by the two branches of `join`; `None` means both
    // branch contexts fall back to the shared bottom (serial) node.
    child: Option<[Box<ThreadTree>; 2]>,
}
impl ThreadTree {
    // Shared serial leaf: contexts built from it run everything inline.
    const BOTTOM: &'static Self = &ThreadTree::new_level0();

    /// A single-node tree with no worker threads; `join` runs both
    /// closures sequentially on the caller's thread.
    #[inline]
    pub const fn new_level0() -> Self {
        ThreadTree { sender: None, child: None }
    }

    /// Build a complete tree of the given level, spawning one dedicated
    /// worker thread per internal node.
    ///
    /// # Panics
    ///
    /// Panics if `level > 12`.
    pub fn new_with_level(level: usize) -> Box<Self> {
        assert!(level <= 12,
                "Input exceeds maximum level 12 (equivalent to 2**12 - 1 threads), got level='{}'",
                level);
        match level {
            0 => Box::new(Self::new_level0()),
            // Level 1: one worker, no children (branch contexts are BOTTOM).
            1 => Box::new(ThreadTree { sender: Some(Self::add_thread()), child: None }),
            _ => {
                let left = Self::new_with_level(level - 1);
                let right = Self::new_with_level(level - 1);
                Box::new(ThreadTree {
                    sender: Some(Self::add_thread()),
                    child: Some([left, right]),
                })
            }
        }
    }

    /// Does this node offload one branch of `join` to a worker thread?
    #[inline]
    pub fn is_parallel(&self) -> bool {
        self.sender.is_some()
    }

    /// Borrow this tree as a join context for the current thread.
    #[inline]
    pub fn top(&self) -> ThreadTreeCtx<'_> {
        ThreadTreeCtx::from(self)
    }

    /// Spawn this node's dedicated worker and return the channel used to
    /// hand it jobs. The worker exits when the sender side is dropped.
    fn add_thread() -> Sender<TTreeMessage> {
        let (sender, receiver) = bounded::<TTreeMessage>(1);
        std::thread::spawn(move || {
            while let Ok(job) = receiver.recv() {
                // SAFETY: the `join` that sent the job blocks until it
                // completes, keeping its stack storage alive.
                unsafe { job.execute() }
            }
        });
        sender
    }
}
/// A borrow of a [`ThreadTree`] node, used as the capability to `join`.
///
/// Deliberately `!Send`/`!Sync` (via the raw pointer) so a context cannot
/// migrate to another thread; each closure receives a fresh context for
/// its own subtree.
#[derive(Debug, Copy, Clone)]
pub struct ThreadTreeCtx<'a> {
    tree: &'a ThreadTree,
    // Zero-sized marker whose only purpose is to strip Send/Sync.
    _not_send_sync: *const (),
}
impl ThreadTreeCtx<'_> {
    /// The tree node this context borrows.
    #[inline]
    pub(crate) fn get(&self) -> &ThreadTree { self.tree }

    /// Wrap a tree node in a context.
    #[inline]
    pub(crate) fn from(tree: &ThreadTree) -> ThreadTreeCtx<'_> {
        // `&()` is a const-promoted zero-sized reference; only its type
        // (a raw pointer) matters, to make the context !Send + !Sync.
        ThreadTreeCtx { tree, _not_send_sync: &() }
    }

    /// Whether `join` on this context actually uses a worker thread.
    #[inline]
    pub fn is_parallel(&self) -> bool {
        self.get().is_parallel()
    }

    /// Run `a` and `b`, potentially in parallel, and return both results.
    ///
    /// If this node has a worker thread, `b` is shipped to it while `a`
    /// runs on the current thread; otherwise both run here, `a` first.
    /// Each closure receives a context for its child subtree.
    ///
    /// Panics in either closure propagate to the caller; if `a` panics
    /// while `b` runs remotely, the unwind still waits for `b` to finish
    /// (see `WaitForJobGuard`) before this stack frame is torn down.
    pub fn join<A, B, RA, RB>(&self, a: A, b: B) -> (RA, RB)
        where A: FnOnce(ThreadTreeCtx) -> RA + Send,
              B: FnOnce(ThreadTreeCtx) -> RB + Send,
              RA: Send,
              RB: Send,
    {
        let bottom_level = ThreadTree::BOTTOM;
        let self_ = self.get();
        // Child contexts for the two closures; a childless node hands both
        // closures the shared serial bottom node.
        let (fork_a, fork_b) = match &self_.child {
            None => (bottom_level, bottom_level),
            Some([fa, fb]) => (&**fa, &**fb),
        };
        unsafe {
            let a = move || a(ThreadTreeCtx::from(fork_a));
            let b = move || b(ThreadTreeCtx::from(fork_b));
            // `b_job` lives on this stack frame; `b_job_ref` is a type-erased
            // pointer to it. Sound only because we never leave this scope
            // before `b` has completed (executed below or waited on).
            let b_job = StackJob::new(b);
            let b_job_ref = JobRef::new(&b_job);
            let b_runs_here = match self_.sender {
                // Worker available: hand `b` off. `unwrap`: the worker only
                // exits when this sender is dropped, so send can't fail here.
                Some(ref s) => { s.send(b_job_ref).unwrap(); None }
                // No worker: keep the job to run on this thread after `a`.
                None => Some(b_job_ref),
            };
            let a_result;
            {
                // Armed only when `b` runs remotely: its Drop blocks until
                // `b` completes, even if `a()` panics and we unwind.
                let _wait_for_b_guard = match b_runs_here {
                    None => Some(WaitForJobGuard::new(&b_job)),
                    Some(_) => None,
                };
                a_result = a();
                if let Some(b_job_ref) = b_runs_here {
                    b_job_ref.execute();
                }
            }
            // Collect b's result; presumably re-raises b's panic if it
            // panicked — see StackJob::into_result.
            (a_result, b_job.into_result())
        }
    }

    /// Three-way join, left-leaning: `((a(0), a(1)), a(2))` with the first
    /// pair split one level deeper.
    pub fn join3l<A, RA>(&self, a: &A) -> ((RA, RA), RA)
        where A: Fn(ThreadTreeCtx, usize) -> RA + Sync,
              RA: Send,
    {
        self.join(
            move |ctx| ctx.join(move |ctx| a(ctx, 0), move |ctx| a(ctx, 1)),
            move |ctx| a(ctx, 2))
    }

    /// Three-way join, right-leaning: `(a(0), (a(1), a(2)))`.
    pub fn join3r<A, RA>(&self, a: &A) -> (RA, (RA, RA))
        where A: Fn(ThreadTreeCtx, usize) -> RA + Sync,
              RA: Send,
    {
        self.join(
            move |ctx| a(ctx, 0),
            move |ctx| ctx.join(move |ctx| a(ctx, 1), move |ctx| a(ctx, 2)))
    }

    /// Four-way join across two tree levels: `((a(0), a(1)), (a(2), a(3)))`.
    pub fn join4<A, RA>(&self, a: &A) -> ((RA, RA), (RA, RA))
        where A: Fn(ThreadTreeCtx, usize) -> RA + Sync,
              RA: Send,
    {
        self.join(
            move |ctx| ctx.join(move |ctx| a(ctx, 0), move |ctx| a(ctx, 1)),
            move |ctx| ctx.join(move |ctx| a(ctx, 2), move |ctx| a(ctx, 3)))
    }
}
/// A flat pool of workers reserved from a [`ThreadSea`] (or built directly),
/// sharing a single rendezvous job channel.
#[cfg(feature="unstable-thread-sea")]
#[derive(Debug)]
pub struct ThreadPool {
    // Zero-capacity job channel: a send succeeds only when a worker is
    // ready to receive, so `try_send` doubles as an "is anyone idle" probe.
    sender: Sender<JobRef>,
    // Number of workers serving this pool (fixed at creation/reservation).
    thread_count: usize,
}
/// Per-worker state moved into each pool thread at spawn time.
#[cfg(feature="unstable-thread-sea")]
#[derive(Debug)]
struct LocalInfo {
    // The pool's shared job channel; the worker exits when it disconnects.
    receiver: Receiver<JobRef>,
}
#[cfg(feature="unstable-thread-sea")]
impl ThreadPool {
    /// Spawn a pool with `thread_count` dedicated workers sharing one
    /// rendezvous job channel.
    pub fn new(thread_count: usize) -> Self {
        let (sender, receiver) = bounded(0);
        let pool = ThreadPool { sender, thread_count };
        for _ in 0..thread_count {
            pool.add_thread(&receiver);
        }
        pool
    }

    /// Number of worker threads serving this pool.
    pub fn thread_count(&self) -> usize { self.thread_count }

    /// Spawn one worker that drains jobs until the pool's sender is dropped.
    fn add_thread(&self, receiver: &Receiver<JobRef>) {
        let local = LocalInfo { receiver: receiver.clone() };
        std::thread::spawn(move || {
            let my_local = local;
            for job in my_local.receiver {
                // SAFETY: the `join` that sent the job blocks until the job
                // completes, keeping its stack storage alive.
                unsafe {
                    job.execute();
                }
            }
        });
    }

    /// Run `a` and `b`, potentially in parallel, returning both results.
    ///
    /// `b` is offered to an idle worker via `try_send` on the rendezvous
    /// channel; if none is ready, `b` runs on this thread after `a`.
    /// Panics propagate; if `a` panics while `b` runs remotely, the unwind
    /// still waits for `b` (see `WaitForJobGuard`) before freeing `b`'s
    /// stack storage.
    pub fn join<A, B, RA, RB>(&self, a: A, b: B) -> (RA, RB)
        where A: FnOnce() -> RA + Send,
              B: FnOnce() -> RB + Send,
              RA: Send,
              RB: Send,
    {
        unsafe {
            // `b_job` lives on this frame; `b_job_ref` is a type-erased
            // pointer to it — sound because we don't return until `b` is done.
            let b_job = StackJob::new(b);
            let b_job_ref = JobRef::new(&b_job);
            let b_runs_here = match self.sender.try_send(b_job_ref) {
                // A worker took the job.
                Ok(_) => None,
                // No idle worker (or disconnected): recover the job and run
                // it on this thread after `a`.
                Err(e) => Some(e.into_inner()),
            };
            let a_result;
            {
                // Armed only when `b` runs remotely: its Drop blocks until
                // `b` completes, even if `a()` panics and we unwind.
                let _wait_for_b_guard = match b_runs_here {
                    None => Some(WaitForJobGuard::new(&b_job)),
                    Some(_) => None,
                };
                a_result = a();
                if let Some(b_job_ref) = b_runs_here {
                    b_job_ref.execute();
                }
            }
            // Collect b's result; presumably re-raises b's panic if it
            // panicked — see StackJob::into_result.
            (a_result, b_job.into_result())
        }
    }

    /// Recursively split `seed` with `splitter`, process leaves with
    /// `for_each` (in parallel via `join`), and merge results with `combine`.
    pub fn recursive_fork_join<S, FS, FE, FC, R>(&self, seed: S, splitter: FS, for_each: FE, combine: FC) -> R
        where FS: Fn(S) -> (S, Option<S>) + Sync,
              FE: Fn(S) -> R + Sync,
              FC: Fn(R, R) -> R + Sync,
              R: Send,
              S: Send,
    {
        self.recursive_join_(seed, &splitter, &for_each, &combine)
    }

    // Worker for `recursive_fork_join`: takes the callables by reference so
    // both branches of each `join` can share them.
    fn recursive_join_<S, FS, FE, FC, R>(&self, seed: S, splitter: &FS, for_each: &FE, combine: &FC) -> R
        where FS: Fn(S) -> (S, Option<S>) + Sync,
              FE: Fn(S) -> R + Sync,
              FC: Fn(R, R) -> R + Sync,
              R: Send,
              S: Send,
    {
        match splitter(seed) {
            // Splitter declined to split: this is a leaf.
            (single, None) => for_each(single),
            (first, Some(second)) => {
                let (a, b) = self.join(
                    move || self.recursive_join_(first, splitter, for_each, combine),
                    move || self.recursive_join_(second, splitter, for_each, combine));
                combine(a, b)
            }
        }
    }
}
/// Busy-wait (yielding the CPU between probes) until `job` is marked done.
fn wait_for_job<F, R>(job: &StackJob<F, R>) {
    loop {
        if job.probe() {
            break;
        }
        // The job is executing on another thread; let it run.
        thread::yield_now();
    }
}
/// Guard that blocks until a remotely-executing job finishes, even during
/// an unwind — this keeps the job's stack storage alive until it is done.
struct WaitForJobGuard<'a, F, R> {
    job: &'a StackJob<F, R>,
}
impl<'a, F, R> WaitForJobGuard<'a, F, R>
{
fn new(job: &'a StackJob<F, R>) -> Self {
Self { job }
}
}
impl<'a, F, R> Drop for WaitForJobGuard<'a, F, R> {
    // Runs on normal exit AND during unwinding: a panic in the local
    // closure must not free the frame while the remote job still points
    // into it, so we spin here until the job reports completion.
    fn drop(&mut self) {
        wait_for_job(self.job)
    }
}
#[cfg(feature="unstable-thread-sea")]
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::AtomicUsize;
    use std::sync::atomic::Ordering;
    // Single shim so the deprecated `sleep_ms` is allowed in one place.
    #[allow(deprecated)]
    fn sleep_ms(x: u32) {
        std::thread::sleep_ms(x)
    }
    // Smoke test: nested joins on a shared &ThreadPool must not deadlock.
    #[test]
    fn it_works() {
        let pool = ThreadPool::new(10);
        pool.join(
            || {
                println!("I'm f!");
                sleep_ms(100);
                pool.join(|| {
                    println!("f.0");
                    pool.join(|| {
                        println!("f.0.0");
                        sleep_ms(500);
                    },
                    || {
                        println!("f.0.1");
                    });
                },
                || {
                    println!("f.1");
                    pool.join(|| {
                        println!("f.1.0");
                    },
                    || {
                        println!("f.1.1");
                    });
                });
            },
            || {
                println!("I'm g!"); sleep_ms(100)
            },
            );
        drop(pool);
        sleep_ms(100);
    }
    // Divide-and-conquer sum over 0..127 must equal the sequential sum.
    #[test]
    fn recursive() {
        let pool = ThreadPool::new(50);
        let ret = pool.recursive_fork_join(0..127, |x| {
            let len = x.end - x.start;
            let mid = x.start + len / 2;
            // Split ranges longer than 3; shorter ranges become leaves.
            if len > 3 {
                (x.start..mid, Some(mid..x.end))
            } else {
                (x, None)
            }
        },
        |value| {
            println!("Thread: {:?}", value);
            value.sum::<i32>()
        },
        |a, b| a + b);
        assert_eq!(ret, (0..127).sum());
    }
    // A panic in closure `a` must propagate out of `join`.
    #[test]
    #[should_panic]
    fn panic_a() {
        let pool = ThreadPool::new(2);
        pool.join(|| panic!(), || 1 + 1);
    }
    // A panic in closure `b` (possibly on a worker) must propagate too.
    #[test]
    #[should_panic]
    fn panic_b() {
        let pool = ThreadPool::new(2);
        pool.join(|| 1 + 1, || panic!());
    }
    // Both sides panicking must still surface a panic, not hang.
    #[test]
    #[should_panic]
    fn panic_both() {
        let pool = ThreadPool::new(2);
        pool.join(|| { sleep_ms(50); panic!("Panic in A") }, || panic!("Panic in B"));
    }
    // Panic-safety: when `a` panics, `join` must still wait for the remote
    // `b` to run to completion (start count == finish count every pass).
    #[test]
    fn on_panic_a_wait_for_b() {
        let pool = ThreadPool::new(2);
        for i in 0..3 {
            let start = AtomicUsize::new(0);
            let finish = AtomicUsize::new(0);
            let result = unwind::halt_unwinding(|| {
                pool.join(
                    || panic!("Panic in A"),
                    || {
                        start.fetch_add(1, Ordering::SeqCst);
                        sleep_ms(50);
                        finish.fetch_add(1, Ordering::SeqCst);
                    });
            });
            let start_count = start.load(Ordering::SeqCst);
            let finish_count = finish.load(Ordering::SeqCst);
            // If b started, it must also have finished before join unwound.
            assert_eq!(start_count, finish_count);
            assert!(result.is_err());
            println!("Pass {} with start: {} == finish {}", i,
                start_count, finish_count);
        }
    }
}
#[cfg(feature="unstable-thread-sea")]
#[cfg(test)]
mod sea_tests {
    use super::*;
    // Degenerate case: an empty sea and an empty reservation must not
    // hang or panic (the bindings are intentionally unused).
    #[allow(deprecated)]
    #[test]
    fn thread_count_0() {
        let sea = ThreadSea::new(0);
        let pool1 = sea.reserve(0);
    }
    // Reserve two pools in sequence from one sea; workers released by
    // dropping pool1 must be reusable by pool2.
    #[test]
    fn recursive() {
        let sea = ThreadSea::new(50);
        let pool1 = sea.reserve(25);
        pool1.recursive_fork_join(0..127, |x| {
            let len = x.end - x.start;
            let mid = x.start + len / 2;
            if len > 3 {
                (x.start..mid, Some(mid..x.end))
            } else {
                (x, None)
            }
        },
        |value| {
            println!("Thread: {:?}", value);
        },
        |_, _| ()
        );
        let pool2 = sea.reserve(50);
        // Disconnects pool1's job channel, returning its workers to the sea.
        drop(pool1);
        pool2.recursive_fork_join(0..127, |x| {
            let len = x.end - x.start;
            let mid = x.start + len / 2;
            if len > 3 {
                (x.start..mid, Some(mid..x.end))
            } else {
                (x, None)
            }
        },
        |value| {
            println!("Thread: {:?}", value);
        },
        |_, _| ()
        );
    }
}
#[cfg(test)]
mod thread_tree_tests {
    use super::*;
    #[allow(deprecated)]
    use std::sync::atomic::AtomicUsize;
    use std::sync::atomic::Ordering;
    use std::sync::Mutex;
    use once_cell::sync::Lazy;
    use std::collections::HashSet;
    use std::thread;
    use std::thread::ThreadId;
    // Single shim so the deprecated `sleep_ms` is allowed in one place.
    #[allow(deprecated)]
    fn sleep_ms(x: u32) {
        std::thread::sleep_ms(x)
    }
    // Level-0 tree: both closures run, on the SAME (calling) thread.
    #[test]
    fn stub() {
        let tp = ThreadTree::new_level0();
        let a = AtomicUsize::new(0);
        let b = AtomicUsize::new(0);
        tp.top().join(|_| a.fetch_add(1, Ordering::SeqCst),
            |_| b.fetch_add(1, Ordering::SeqCst));
        assert_eq!(a.load(Ordering::SeqCst), 1);
        assert_eq!(b.load(Ordering::SeqCst), 1);
        let f = || thread::current().id();
        let (aid, bid) = tp.top().join(|_| f(), |_| f());
        // Serial node: both halves observe the same thread id.
        assert_eq!(aid, bid);
        assert!(!tp.top().is_parallel());
    }
    // Level-1 tree: `b` runs on the node's worker, so ids differ.
    #[test]
    fn new_level_1() {
        let tp = ThreadTree::new_with_level(1);
        let a = AtomicUsize::new(0);
        let b = AtomicUsize::new(0);
        tp.top().join(|_| a.fetch_add(1, Ordering::SeqCst),
            |_| b.fetch_add(1, Ordering::SeqCst));
        assert_eq!(a.load(Ordering::SeqCst), 1);
        assert_eq!(b.load(Ordering::SeqCst), 1);
        let f = || thread::current().id();
        let (aid, bid) = tp.top().join(|_| f(), |_| f());
        assert_ne!(aid, bid);
        assert!(tp.top().is_parallel());
    }
    // Level-2 tree: a 2x2 nested join lands on four distinct threads.
    #[test]
    fn build_level_2() {
        let tp = ThreadTree::new_with_level(2);
        let a = AtomicUsize::new(0);
        let b = AtomicUsize::new(0);
        tp.top().join(|_| a.fetch_add(1, Ordering::SeqCst),
            |_| b.fetch_add(1, Ordering::SeqCst));
        assert_eq!(a.load(Ordering::SeqCst), 1);
        assert_eq!(b.load(Ordering::SeqCst), 1);
        let f = || thread::current().id();
        let ((aid, bid), (cid, did)) = tp.top().join(
            |tp1| tp1.join(|_| f(), |_| f()),
            |tp1| tp1.join(|_| f(), |_| f()));
        // All four leaf closures must have run on pairwise distinct threads.
        assert_ne!(aid, bid);
        assert_ne!(aid, cid);
        assert_ne!(aid, did);
        assert_ne!(bid, cid);
        assert_ne!(bid, did);
        assert_ne!(cid, did);
    }
    // Oversubscription: two outer branches each drive the same level-2
    // tree; `work` runs 4 times, each doing 2 subworks => sum * 4 * 2.
    #[test]
    fn overload_2_2() {
        let global = ThreadTree::new_with_level(1);
        let tp = ThreadTree::new_with_level(2);
        let a = AtomicUsize::new(0);
        let range = 0..100;
        let work = |ctx: ThreadTreeCtx<'_>| {
            let subwork = || {
                for i in range.clone() {
                    a.fetch_add(i, Ordering::Relaxed);
                    sleep_ms(1);
                }
            };
            ctx.join(|_| subwork(), |_| subwork());
        };
        global.top().join(
            |_| tp.top().join(work, work),
            |_| tp.top().join(work, work));
        let sum = range.clone().sum::<usize>();
        assert_eq!(sum * 4 * 2, a.load(Ordering::SeqCst));
    }
    // Recurse past the tree depth: call count is exactly 2^MAX_DEPTH and
    // the set of visited threads is the 2^TREE_LEVEL - 1 workers plus the
    // main thread, i.e. 1 << TREE_LEVEL distinct ids.
    #[test]
    fn deep_tree() {
        static THREADS: Lazy<Mutex<HashSet<ThreadId>>> = Lazy::new(|| Mutex::default());
        const TREE_LEVEL: usize = 8;
        const MAX_DEPTH: usize = 12;
        static COUNT: AtomicUsize = AtomicUsize::new(0);
        let tp = ThreadTree::new_with_level(TREE_LEVEL);
        fn f(tp: ThreadTreeCtx<'_>, depth: usize) {
            COUNT.fetch_add(1, Ordering::SeqCst);
            THREADS.lock().unwrap().insert(thread::current().id());
            if depth >= MAX_DEPTH {
                return;
            }
            tp.join(
                |ctx| {
                    f(ctx, depth + 1);
                },
                |ctx| {
                    f(ctx, depth + 1);
                });
        }
        // Account for the two implicit depth-0/1 levels the explicit top
        // join below skips (f starts at depth 2).
        COUNT.fetch_add(2, Ordering::SeqCst);
        tp.top().join(|ctx| f(ctx, 2), |ctx| f(ctx, 2));
        let visited_threads = THREADS.lock().unwrap().len();
        assert_eq!(visited_threads, 1 << TREE_LEVEL);
        assert_eq!(COUNT.load(Ordering::SeqCst), 1 << MAX_DEPTH);
    }
    // Panic in `a` propagates out of a parallel join.
    #[test]
    #[should_panic]
    fn panic_a() {
        let pool = ThreadTree::new_with_level(1);
        pool.top().join(|_| panic!("Panic in A"), |_| 1 + 1);
    }
    // Panic in `b` (on the worker) propagates too.
    #[test]
    #[should_panic]
    fn panic_b() {
        let pool = ThreadTree::new_with_level(1);
        pool.top().join(|_| 1 + 1, |_| panic!());
    }
    // Both sides panic across two threads: still a panic, not a hang.
    #[test]
    #[should_panic]
    fn panic_both_in_threads() {
        let pool = ThreadTree::new_with_level(1);
        pool.top().join(|_| { sleep_ms(50); panic!("Panic in A") }, |_| panic!("Panic in B"));
    }
    // Both sides panic on the serial bottom node.
    #[test]
    #[should_panic]
    fn panic_both_bottom() {
        let pool = ThreadTree::new_with_level(0);
        pool.top().join(|_| { sleep_ms(50); panic!("Panic in A") }, |_| panic!("Panic in B"));
    }
    // Panic-safety: when `a` panics, `join` must still wait for the remote
    // `b` to run to completion (start count == finish count every pass).
    #[test]
    fn on_panic_a_wait_for_b() {
        let pool = ThreadTree::new_with_level(1);
        for i in 0..3 {
            let start = AtomicUsize::new(0);
            let finish = AtomicUsize::new(0);
            let result = unwind::halt_unwinding(|| {
                pool.top().join(
                    |_| panic!("Panic in A"),
                    |_| {
                        start.fetch_add(1, Ordering::SeqCst);
                        sleep_ms(50);
                        finish.fetch_add(1, Ordering::SeqCst);
                    });
            });
            let start_count = start.load(Ordering::SeqCst);
            let finish_count = finish.load(Ordering::SeqCst);
            assert_eq!(start_count, finish_count);
            assert!(result.is_err());
            println!("Pass {} with start: {} == finish {}", i,
                start_count, finish_count);
        }
    }
}