use std::fmt;
use std::io;
use std::marker::PhantomData;
use std::mem;
use std::panic;
use std::sync::{Arc, Mutex};
use std::thread;
use crate::wait_group::WaitGroup;
use cfg_if::cfg_if;
use panic::catch_unwind;
/// A vector shared between threads behind a mutex.
type SharedVec<T> = Arc<Mutex<Vec<T>>>;
/// An optional value shared between threads behind a mutex; used as a one-shot
/// slot whose consumer `take()`s the value out, leaving `None`.
type SharedOption<T> = Arc<Mutex<Option<T>>>;
/// A group of threads whose closures may borrow data that lives for `'env`.
///
/// Created by `ThreadGroup::new`; threads are added with `spawn` / `builder`
/// and collected by `ThreadGroup::join`.
pub struct ThreadGroup<'env, R> {
    /// Handles of every thread spawned through this group or any nested view of
    /// it: the outer `Arc` is cloned into each nested view, so all of them push
    /// into the same list. Each slot is emptied when its handle is taken to join.
    handles: SharedVec<SharedOption<thread::JoinHandle<()>>>,
    /// Cloned into every spawned thread; `join` calls `wait()` on it, and a
    /// panicking thread records its id here via `set_panic_id`.
    /// NOTE(review): exact wait semantics live in `crate::wait_group` — confirm.
    wait_group: WaitGroup,
    /// Outcome (with any panic caught) of the closure passed to `new`; `None` for
    /// the per-thread views created by `ThreadGroupBuilder::spawn`.
    res: Option<thread::Result<R>>,
    /// Ties the group to `'env`; `&'env mut &'env ()` makes `'env` invariant.
    _marker: PhantomData<&'env mut &'env ()>,
}
impl<'env, R> ThreadGroup<'env, R> {
    /// Creates a group and immediately runs `f` on the current thread, passing it
    /// a reference to the group so it can spawn threads.
    ///
    /// A panic in `f` is caught and stored; it is reported by [`ThreadGroup::join`].
    /// `new` itself does not wait for the threads `f` spawned — call `join` for that.
    ///
    /// NOTE(review): spawned threads may borrow `'env` data, but nothing visible
    /// here waits for them if the group is dropped or forgotten instead of being
    /// joined — confirm callers cannot let a thread outlive its borrows.
    pub fn new<F>(f: F) -> Self
    where
        F: FnOnce(&Self) -> R,
    {
        let mut group = Self {
            handles: SharedVec::default(),
            wait_group: WaitGroup::new(),
            res: None,
            _marker: PhantomData,
        };
        // `f` only ever sees `&group`; the caught payload is surfaced by `join`.
        let res = panic::catch_unwind(panic::AssertUnwindSafe(|| f(&group)));
        group.res = Some(res);
        group
    }

    /// Waits for every thread spawned through this group (including nested
    /// spawns), joins all of them, and returns the result of the closure that
    /// was passed to [`ThreadGroup::new`].
    ///
    /// # Errors
    ///
    /// All threads are joined first; then, in this order of precedence:
    /// 1. if the closure passed to `new` panicked, its panic payload is returned;
    /// 2. if the wait group reported a panicked thread, that thread's payload;
    /// 3. if any spawned thread panicked, a boxed `Vec` of all their payloads.
    ///
    /// # Panics
    ///
    /// Panics if called on a group that holds no result; a group built by `new`
    /// always holds one.
    pub fn join(self) -> thread::Result<R> {
        // Wait before inspecting any result: returning early would let spawned
        // threads keep running against borrowed `'env` data.
        let panicked_id = self.wait_group.wait().err();

        // Move the slots out so the list's lock is not held while joining.
        let slots: Vec<_> = self.handles.lock().unwrap().drain(..).collect();

        // Join every handle still present (none is detached). The payload of the
        // thread the wait group reported is kept apart from all the others.
        let mut reported = None;
        let mut panics = Vec::new();
        for handle in slots.iter().filter_map(|slot| slot.lock().unwrap().take()) {
            let id = handle.thread().id();
            if let Err(err) = handle.join() {
                if reported.is_none() && Some(id) == panicked_id {
                    reported = Some(err);
                } else {
                    panics.push(err);
                }
            }
        }

        let res = match self
            .res
            .expect("a group created by `ThreadGroup::new` always holds a result")
        {
            Ok(res) => res,
            // The closure passed to `new` panicked: that takes precedence.
            Err(err) => return Err(err),
        };
        if let Some(err) = reported {
            return Err(err);
        }
        if panics.is_empty() {
            Ok(res)
        } else {
            Err(Box::new(panics))
        }
    }
}
// SAFETY: `ThreadGroup` is never automatically `Sync`, because `res` holds a
// `thread::Result<R>` whose `Box<dyn Any + Send>` error type is not `Sync`. In
// the visible code no `&self` method reads `res` (it is only consumed by
// `join(self)`), and `handles` is an `Arc<Mutex<..>>`, which is safe to share.
// NOTE(review): this also assumes `WaitGroup` is safe to use through a shared
// reference from several threads — confirm against its definition.
unsafe impl<R> Sync for ThreadGroup<'_, R> {}
impl<'env, R: 'env + Send> ThreadGroup<'env, R> {
    /// Spawns `f` on a new thread in this group with default thread settings.
    ///
    /// `f` receives a reference to the group, so it can spawn nested threads.
    ///
    /// # Panics
    ///
    /// Panics if the thread cannot be created; use [`ThreadGroup::builder`] to
    /// receive that failure as an `io::Result` instead.
    pub fn spawn<'group, F, T>(&'group self, f: F)
    where
        F: FnOnce(&ThreadGroup<'env, R>) -> T,
        F: Send + 'env,
        T: Send + 'env,
    {
        let outcome = self.builder().spawn(f);
        outcome.expect("failed to spawn thread in group")
    }

    /// Returns a builder that configures a thread (name, stack size) before it
    /// is spawned into this group.
    pub fn builder<'group>(&'group self) -> ThreadGroupBuilder<'group, 'env, R> {
        let builder = thread::Builder::new();
        ThreadGroupBuilder {
            builder,
            group: self,
        }
    }
}
impl<R> fmt::Debug for ThreadGroup<'_, R> {
    // Opaque on purpose: none of the group's internals are printed. `pad` is
    // used so width / fill / alignment flags still apply.
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        let opaque = "ThreadGroup { .. }";
        formatter.pad(opaque)
    }
}
#[derive(Debug)]
pub struct ThreadGroupBuilder<'group, 'env, R> {
group: &'group ThreadGroup<'env, R>,
builder: thread::Builder,
}
impl<'group, 'env, R: 'env + Send> ThreadGroupBuilder<'group, 'env, R> {
pub fn name(mut self, name: String) -> Self {
self.builder = self.builder.name(name);
self
}
pub fn stack_size(mut self, size: usize) -> Self {
self.builder = self.builder.stack_size(size);
self
}
pub fn spawn<F, T>(self, f: F) -> io::Result<()>
where
F: FnOnce(&ThreadGroup<'env, R>) -> T,
F: Send + 'env,
T: Send + 'env,
{
let result = SharedOption::default();
let result = Arc::clone(&result);
let group = ThreadGroup::<'env> {
handles: Arc::clone(&self.group.handles),
wait_group: self.group.wait_group.clone(),
res: None,
_marker: PhantomData,
};
let closure = move || {
let group: ThreadGroup<'env, R> = group;
match catch_unwind(panic::AssertUnwindSafe(|| f(&group))) {
Ok(res) => *result.lock().unwrap() = Some(res),
Err(err) => {
group.wait_group.set_panic_id(thread::current().id());
panic::resume_unwind(err);
}
};
};
let closure: Box<dyn FnOnce() + Send + 'env> = Box::new(closure);
let closure: Box<dyn FnOnce() + Send + 'static> = unsafe { mem::transmute(closure) };
let handle = self.builder.spawn(move || closure())?;
let handle = Arc::new(Mutex::new(Some(handle)));
self.group.handles.lock().unwrap().push(handle);
Ok(())
}
}
// SAFETY: when `T: Send` every field is already `Send + Sync`: the
// `Arc<Mutex<Option<_>>>` slots of a `Send` payload and of a
// `thread::JoinHandle<()>`, a `thread::Thread`, and a `PhantomData<&()>`. These
// impls therefore only restate the auto traits. The `T: Send` bound is required;
// without it a handle for a non-`Send` `T` could move that `T` to another thread
// through `join`.
unsafe impl<T: Send> Send for JoinHandle<'_, T> {}
unsafe impl<T: Send> Sync for JoinHandle<'_, T> {}
/// A handle to one thread of a [`ThreadGroup`], used to join it and retrieve the
/// value it produced.
pub struct JoinHandle<'group, T> {
    /// Slot holding the std handle; emptied when the thread is joined.
    handle: SharedOption<thread::JoinHandle<()>>,
    /// Slot from which the thread's result is taken after a successful join.
    result: SharedOption<T>,
    /// The thread this handle refers to.
    thread: thread::Thread,
    /// Ties the handle to `'group` (presumably the borrow of its group — confirm).
    _marker: PhantomData<&'group ()>,
}
impl<T> JoinHandle<'_, T> {
    /// Blocks until the thread finishes and returns the value it produced.
    ///
    /// # Errors
    ///
    /// Returns the thread's panic payload if it panicked.
    ///
    /// # Panics
    ///
    /// Panics if the std handle was already taken out of its slot, or if the
    /// thread finished normally but left its result slot empty.
    pub fn join(self) -> thread::Result<T> {
        // Release the slot's lock before blocking on the join.
        let inner = {
            let mut slot = self.handle.lock().unwrap();
            slot.take().unwrap()
        };
        match inner.join() {
            Ok(()) => Ok(self.result.lock().unwrap().take().unwrap()),
            Err(payload) => Err(payload),
        }
    }

    /// Returns the [`thread::Thread`] this handle refers to.
    pub fn thread(&self) -> &thread::Thread {
        &self.thread
    }
}
cfg_if! {
    if #[cfg(unix)] {
        use std::os::unix::thread::{JoinHandleExt, RawPthread};

        // Implemented for this module's `JoinHandle`; the previous target type
        // `ScopedJoinHandle` does not exist here and broke unix builds.
        impl<T> JoinHandleExt for JoinHandle<'_, T> {
            /// Returns the `pthread_t` of the underlying thread.
            ///
            /// # Panics
            ///
            /// Panics if the std handle has already been taken out of its slot.
            fn as_pthread_t(&self) -> RawPthread {
                let handle = self.handle.lock().unwrap();
                handle.as_ref().unwrap().as_pthread_t()
            }

            /// Returns the same value as `as_pthread_t`; the std handle is left in
            /// its shared slot rather than moved out.
            ///
            /// NOTE(review): std's contract transfers ownership of the `pthread_t`
            /// to the caller, which this does not do. If nothing else references
            /// the slot, dropping `self` drops the std handle (detaching the
            /// thread) — confirm the slot is always shared with the group.
            fn into_pthread_t(self) -> RawPthread {
                self.as_pthread_t()
            }
        }
    } else if #[cfg(windows)] {
        use std::os::windows::io::{AsRawHandle, IntoRawHandle, RawHandle};

        impl<T> AsRawHandle for JoinHandle<'_, T> {
            /// Returns the raw OS handle of the underlying thread.
            ///
            /// # Panics
            ///
            /// Panics if the std handle has already been taken out of its slot.
            fn as_raw_handle(&self) -> RawHandle {
                let handle = self.handle.lock().unwrap();
                handle.as_ref().unwrap().as_raw_handle()
            }
        }

        impl<T> IntoRawHandle for JoinHandle<'_, T> {
            /// Returns the same value as `as_raw_handle`; the std handle is left in
            /// its shared slot rather than moved out.
            ///
            /// NOTE(review): if nothing else references the slot, dropping `self`
            /// drops the std handle, so the returned raw handle would not be owned
            /// by the caller — confirm the slot is always shared with the group.
            fn into_raw_handle(self) -> RawHandle {
                self.as_raw_handle()
            }
        }
    }
}
impl<T> fmt::Debug for JoinHandle<'_, T> {
    // Opaque on purpose: neither the thread nor its result is printed. `pad`
    // keeps width / fill / alignment flags working.
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        let opaque = "JoinHandle { .. }";
        formatter.pad(opaque)
    }
}