use std::fmt;
use std::io;
use std::marker::PhantomData;
use std::mem;
use std::panic;
use std::sync::{Arc, Mutex};
use std::thread;
use crate::sync::WaitGroup;
use cfg_if::cfg_if;
/// A `Vec` that several threads can share and mutate, guarded by a mutex.
type SharedVec<T> = std::sync::Arc<std::sync::Mutex<Vec<T>>>;
/// An optional value that several threads can share and fill/take, guarded by a mutex.
type SharedOption<T> = std::sync::Arc<std::sync::Mutex<Option<T>>>;
/// Creates a scope for spawning threads that may borrow data from the enclosing environment.
///
/// `f` is called with a [`Scope`]. Every thread spawned through it (or through the scopes those
/// threads receive) has been joined by the time this function returns.
///
/// # Returns
///
/// * If `f` panics, its panic is resumed once all threads have been waited for and joined.
/// * Otherwise, if any thread that was *not* explicitly joined through its `ScopedJoinHandle`
///   panicked, returns `Err` holding a `Vec` of those panic payloads.
/// * Otherwise returns `Ok` with the value produced by `f`.
pub fn scope<'env, F, R>(f: F) -> thread::Result<R>
where
    F: FnOnce(&Scope<'env>) -> R,
{
    let root_group = WaitGroup::new();
    let root_scope = Scope::<'env> {
        handles: SharedVec::default(),
        wait_group: root_group.clone(),
        _marker: PhantomData,
    };

    // Run the user closure; a panic is captured so the threads can still be joined below.
    let outcome = panic::catch_unwind(panic::AssertUnwindSafe(|| f(&root_scope)));

    // Give up our own wait-group token, then block until every other clone of it is dropped.
    drop(root_scope.wait_group);
    root_group.wait();

    // Join every thread whose handle slot is still filled (i.e. not joined explicitly) and
    // remember the payloads of the ones that panicked.
    let mut child_panics = Vec::new();
    {
        let mut pending = root_scope.handles.lock().unwrap();
        for shared in pending.drain(..) {
            let taken = shared.lock().unwrap().take();
            if let Some(Err(payload)) = taken.map(|join_handle| join_handle.join()) {
                child_panics.push(payload);
            }
        }
    }

    // A panic in `f` takes priority over panics in child threads.
    let value = match outcome {
        Ok(value) => value,
        Err(payload) => panic::resume_unwind(payload),
    };

    if child_panics.is_empty() {
        Ok(value)
    } else {
        Err(Box::new(child_panics))
    }
}
/// A scope in which threads that borrow from the enclosing environment (`'env`) can be spawned.
///
/// The root `Scope` is created by `scope`; each spawned thread receives its own `Scope` value
/// that shares the same handle list and holds a clone of the same wait group.
pub struct Scope<'env> {
// Handles of every thread spawned in this scope (shared with the child scopes). A slot is
// emptied when its thread is joined, by `ScopedJoinHandle::join` or at the end of `scope`.
handles: SharedVec<SharedOption<thread::JoinHandle<()>>>,
// This scope's wait-group token; `scope` waits until every clone has been dropped.
wait_group: WaitGroup,
// Makes `Scope` invariant over `'env` (`&mut T` is invariant in `T`, and here `T = &'env ()`).
_marker: PhantomData<&'env mut &'env ()>,
}
// SAFETY: NOTE(review): within this file, `&Scope` is only used to clone/lock the mutex-protected
// `handles` Arc and to clone `wait_group`. Whether this explicit impl is needed at all depends on
// the auto traits of `WaitGroup` (defined elsewhere) — confirm.
unsafe impl Sync for Scope<'_> {}
impl<'env> Scope<'env> {
    /// Spawns a scoped thread running `f` with default builder settings.
    ///
    /// The closure receives the child thread's own [`Scope`], so it can spawn further threads.
    ///
    /// # Panics
    ///
    /// Panics if the OS refuses to create the thread. Use [`Scope::builder`] to handle that
    /// error instead.
    pub fn spawn<'scope, F, T>(&'scope self, f: F) -> ScopedJoinHandle<'scope, T>
    where
        F: FnOnce(&Scope<'env>) -> T,
        F: Send + 'env,
        T: Send + 'env,
    {
        let spawned = self.builder().spawn(f);
        spawned.expect("failed to spawn scoped thread")
    }

    /// Returns a builder that spawns a thread into this scope, with configurable name and
    /// stack size.
    pub fn builder<'scope>(&'scope self) -> ScopedThreadBuilder<'scope, 'env> {
        let builder = thread::Builder::new();
        ScopedThreadBuilder {
            builder,
            scope: self,
        }
    }
}
impl fmt::Debug for Scope<'_> {
    /// Writes an opaque `Scope { .. }` placeholder. Going through `Formatter::pad` keeps any
    /// width, fill or alignment flags working.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        const PLACEHOLDER: &str = "Scope { .. }";
        fmt::Formatter::pad(f, PLACEHOLDER)
    }
}
/// Configures and spawns a thread inside a [`Scope`], wrapping a [`std::thread::Builder`].
///
/// Obtained from [`Scope::builder`].
#[derive(Debug)]
pub struct ScopedThreadBuilder<'scope, 'env> {
// The scope the thread will be spawned into.
scope: &'scope Scope<'env>,
// Underlying std builder that carries the name / stack-size configuration.
builder: thread::Builder,
}
impl<'scope, 'env> ScopedThreadBuilder<'scope, 'env> {
/// Names the thread to be spawned (forwards to [`thread::Builder::name`]).
pub fn name(mut self, name: String) -> ScopedThreadBuilder<'scope, 'env> {
self.builder = self.builder.name(name);
self
}
/// Sets the stack size of the thread to be spawned (forwards to
/// [`thread::Builder::stack_size`]).
pub fn stack_size(mut self, size: usize) -> ScopedThreadBuilder<'scope, 'env> {
self.builder = self.builder.stack_size(size);
self
}
/// Spawns a thread in the builder's scope that runs `f`, and returns a handle to it.
///
/// `f` receives a fresh `Scope` that shares this scope's handle list and holds a clone of its
/// wait group, so the thread can spawn further threads into the same scope. The thread's
/// return value is stored in a shared slot that [`ScopedJoinHandle::join`] reads.
///
/// # Errors
///
/// Returns the error from [`thread::Builder::spawn`] if the thread could not be created.
pub fn spawn<F, T>(self, f: F) -> io::Result<ScopedJoinHandle<'scope, T>>
where
F: FnOnce(&Scope<'env>) -> T,
F: Send + 'env,
T: Send + 'env,
{
// Slot for the thread's return value, shared by the thread and the returned handle.
let result = SharedOption::default();
let (handle, thread) = {
let result = Arc::clone(&result);
// The child's own scope: the same handle list, plus a wait-group token that keeps the
// enclosing `scope` call waiting until this thread drops it.
let scope = Scope::<'env> {
handles: Arc::clone(&self.scope.handles),
wait_group: self.scope.wait_group.clone(),
_marker: PhantomData,
};
let handle = {
let closure = move || {
// Rebind so the thread owns the whole `Scope`. It is dropped (releasing the
// wait-group token) when this closure body finishes or unwinds.
let scope: Scope<'env> = scope;
let res = f(&scope);
*result.lock().unwrap() = Some(res);
};
let closure: Box<dyn FnOnce() + Send + 'env> = Box::new(closure);
// SAFETY: `thread::Builder::spawn` requires a `'static` closure, but `closure` may
// borrow data that only lives for `'env`. Erasing the lifetime is sound only if the
// thread is joined before `'env` ends. Its `JoinHandle` is pushed into the scope's shared
// handle list below, and `scope` joins every handle still in that list before returning
// (an explicit `ScopedJoinHandle::join` joins it earlier).
// NOTE(review): if the `lock().unwrap()` on the handle list below panicked (poisoned
// mutex), the handle would be dropped unregistered and the thread detached. Confirm that
// poisoning cannot happen here.
let closure: Box<dyn FnOnce() + Send + 'static> =
unsafe { mem::transmute(closure) };
self.builder.spawn(move || closure())?
};
let thread = handle.thread().clone();
let handle = Arc::new(Mutex::new(Some(handle)));
(handle, thread)
};
// Register the handle so the enclosing `scope` call joins it if the caller never does.
self.scope.handles.lock().unwrap().push(Arc::clone(&handle));
Ok(ScopedJoinHandle {
handle,
result,
thread,
_marker: PhantomData,
})
}
}
// SAFETY: NOTE(review): these impls are unconditional in `T`. Within this file a
// `ScopedJoinHandle` is only constructed by `ScopedThreadBuilder::spawn`, which requires
// `T: Send`, so moving a handle to another thread (and receiving `T` from `join` there) only
// moves `Send` data. Through `&ScopedJoinHandle` only the `Thread` and the raw OS handle are
// reachable, never `T`. Consider bounding the `Send` impl on `T: Send` to make this explicit.
unsafe impl<T> Send for ScopedJoinHandle<'_, T> {}
unsafe impl<T> Sync for ScopedJoinHandle<'_, T> {}
/// An owned permission to join a scoped thread and obtain its result.
///
/// Returned by `Scope::spawn` / `ScopedThreadBuilder::spawn`. If it is never joined, the thread
/// is joined when the enclosing `scope` call finishes.
pub struct ScopedJoinHandle<'scope, T> {
// Slot holding the std join handle; the same `Arc` is also stored in the scope's handle list.
handle: SharedOption<thread::JoinHandle<()>>,
// Slot into which the spawned thread writes its return value.
result: SharedOption<T>,
// Clone of the spawned thread's `Thread`, returned by `thread()` without locking `handle`.
thread: thread::Thread,
// Ties the handle to the borrow of the `Scope` it was spawned from.
_marker: PhantomData<&'scope ()>,
}
impl<T> ScopedJoinHandle<'_, T> {
pub fn join(self) -> thread::Result<T> {
let handle = self.handle.lock().unwrap().take().unwrap();
handle
.join()
.map(|()| self.result.lock().unwrap().take().unwrap())
}
pub fn thread(&self) -> &thread::Thread {
&self.thread
}
}
// Platform-specific access to the OS-level handle of a scoped thread.
cfg_if! {
if #[cfg(unix)] {
use std::os::unix::thread::{JoinHandleExt, RawPthread};
impl<T> JoinHandleExt for ScopedJoinHandle<'_, T> {
// Returns the `pthread_t` of the spawned thread. Panics if the join-handle slot is empty.
// That slot is only emptied by `join` (which consumes the handle) and when `scope` drains its
// handles, so presumably this cannot be observed through a live handle — TODO confirm.
fn as_pthread_t(&self) -> RawPthread {
let handle = self.handle.lock().unwrap();
handle.as_ref().unwrap().as_pthread_t()
}
// NOTE(review): unlike `std`'s `JoinHandle::into_pthread_t`, this does NOT transfer
// ownership. The std `JoinHandle` stays in the shared slot (a clone of which is in the
// scope's handle list), so the thread is still joined later by this crate. Callers must not
// join or detach the returned `pthread_t` themselves.
fn into_pthread_t(self) -> RawPthread {
self.as_pthread_t()
}
}
} else if #[cfg(windows)] {
use std::os::windows::io::{AsRawHandle, IntoRawHandle, RawHandle};
impl<T> AsRawHandle for ScopedJoinHandle<'_, T> {
// Returns the raw thread `HANDLE`. Panics if the join-handle slot is empty (see the unix
// `as_pthread_t` case for when that can happen).
fn as_raw_handle(&self) -> RawHandle {
let handle = self.handle.lock().unwrap();
handle.as_ref().unwrap().as_raw_handle()
}
}
impl<T> IntoRawHandle for ScopedJoinHandle<'_, T> {
// NOTE(review): ownership is NOT transferred. The std `JoinHandle` still owns the `HANDLE`
// and is joined later by this crate, so callers must not close or wait-and-close it.
fn into_raw_handle(self) -> RawHandle {
self.as_raw_handle()
}
}
}
}
impl<T> fmt::Debug for ScopedJoinHandle<'_, T> {
    /// Writes an opaque `ScopedJoinHandle { .. }` placeholder. Going through `Formatter::pad`
    /// keeps any width, fill or alignment flags working.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        const PLACEHOLDER: &str = "ScopedJoinHandle { .. }";
        fmt::Formatter::pad(f, PLACEHOLDER)
    }
}