use std::fmt;
use std::io;
use std::marker::PhantomData;
use std::mem;
use std::panic;
use std::sync::{Arc, Mutex};
use std::thread;
use sync::WaitGroup;
/// A `Vec` shared between threads: reference-counted and guarded by a mutex.
type SharedVec<T> = Arc<Mutex<Vec<T>>>;
/// An optional value shared between threads: reference-counted and guarded by a mutex.
/// `SharedOption::default()` starts out holding `None`.
type SharedOption<T> = Arc<Mutex<Option<T>>>;
/// Creates a scope in which threads that borrow data living for `'env` can be spawned.
///
/// `f` is called with a [`Scope`]. Every thread spawned through it (including threads spawned
/// from inside other scoped threads) has finished before this function returns:
/// - it first waits until every `Scope` clone held by a spawned thread has been dropped;
/// - it then joins each thread whose handle was not already joined via
///   [`ScopedJoinHandle::join`].
///
/// # Errors
///
/// If `f` returns normally but one or more of the automatically joined threads panicked,
/// returns `Err` holding a `Vec` of their panic payloads (boxed as `dyn Any + Send`).
///
/// # Panics
///
/// If `f` itself panics, all spawned threads are still waited for and joined, and then the
/// panic from `f` is resumed. Panics from spawned threads are discarded in that case.
pub fn scope<'env, F, R>(f: F) -> thread::Result<R>
where
    F: FnOnce(&Scope<'env>) -> R,
{
    let wg = WaitGroup::new();
    let scope = Scope::<'env> {
        handles: SharedVec::default(),
        wait_group: wg.clone(),
        _marker: PhantomData,
    };

    // Catch a panic from `f` so that spawned threads are still waited for and joined before
    // this frame is left; they may borrow data that is only valid for `'env`.
    let result = panic::catch_unwind(panic::AssertUnwindSafe(|| f(&scope)));

    // Release our own wait-group reference, then block until every clone handed to a
    // spawned thread has been dropped too.
    drop(scope.wait_group);
    wg.wait();

    // Join every thread whose handle is still present (i.e. not joined by the user) and keep
    // the payloads of those that panicked. The `handles` lock is held until `collect` ends.
    let panics: Vec<_> = scope
        .handles
        .lock()
        .unwrap()
        .drain(..)
        .filter_map(|handle| handle.lock().unwrap().take())
        .filter_map(|handle| handle.join().err())
        .collect();

    match result {
        Err(err) => panic::resume_unwind(err),
        Ok(_) if !panics.is_empty() => Err(Box::new(panics)),
        Ok(res) => Ok(res),
    }
}
/// A scope in which threads borrowing data that lives for `'env` can be spawned.
///
/// Created by [`scope`]. Each spawned thread receives its own `Scope` value that shares the
/// same handle list and holds a clone of the same wait group.
pub struct Scope<'env> {
    /// Join handles of every thread spawned into this scope. Each one sits in a shared
    /// `Option` so it can be taken either by `ScopedJoinHandle::join` or by `scope` when
    /// joining the remaining threads.
    handles: SharedVec<SharedOption<thread::JoinHandle<()>>>,
    /// Wait group that `scope` waits on until every `Scope` clone has been dropped.
    wait_group: WaitGroup,
    /// `&'env mut &'env ()` makes `Scope` invariant over `'env`.
    _marker: PhantomData<&'env mut &'env ()>,
}
// SAFETY: through `&Scope` only `spawn`/`builder` are reachable. They lock `handles` (a
// `Mutex`) and clone `wait_group`.
// NOTE(review): `handles` and `_marker` are `Sync` on their own. If `WaitGroup` is also
// `Sync`, this manual impl is redundant; confirm against the `sync` module.
unsafe impl<'env> Sync for Scope<'env> {}
impl<'env> Scope<'env> {
    /// Spawns a thread in this scope with default settings and returns a handle to join it.
    ///
    /// `f` is given its own `&Scope<'env>`, which can be used to spawn further threads into
    /// the same scope. Threads that are not joined through their [`ScopedJoinHandle`] are
    /// joined automatically before [`scope`] returns.
    ///
    /// # Panics
    ///
    /// Panics if the thread cannot be created, i.e. when [`ScopedThreadBuilder::spawn`]
    /// returns an error. Use [`Scope::builder`] to handle that error instead.
    pub fn spawn<'scope, F, T>(&'scope self, f: F) -> ScopedJoinHandle<'scope, T>
    where
        F: FnOnce(&Scope<'env>) -> T,
        F: Send + 'env,
        T: Send + 'env,
    {
        self.builder()
            .spawn(f)
            .expect("failed to spawn scoped thread")
    }

    /// Creates a [`ScopedThreadBuilder`] for configuring a thread (name, stack size) before
    /// spawning it into this scope.
    pub fn builder<'scope>(&'scope self) -> ScopedThreadBuilder<'scope, 'env> {
        ScopedThreadBuilder {
            scope: self,
            builder: thread::Builder::new(),
        }
    }
}
impl<'env> fmt::Debug for Scope<'env> {
    /// Writes the fixed placeholder `Scope { .. }` without exposing any fields, while still
    /// applying the formatter's padding flags.
    fn fmt(&self, out: &mut fmt::Formatter) -> fmt::Result {
        let placeholder = "Scope { .. }";
        out.pad(placeholder)
    }
}
/// Configures a scoped thread (name, stack size) before it is spawned.
///
/// Created by `Scope::builder`.
#[derive(Debug)]
pub struct ScopedThreadBuilder<'scope, 'env: 'scope> {
    /// The scope the thread will be spawned into.
    scope: &'scope Scope<'env>,
    /// Underlying std builder that carries the thread configuration.
    builder: thread::Builder,
}
impl<'scope, 'env> ScopedThreadBuilder<'scope, 'env> {
    /// Sets the name of the thread (forwarded to `thread::Builder::name`).
    pub fn name(mut self, name: String) -> ScopedThreadBuilder<'scope, 'env> {
        self.builder = self.builder.name(name);
        self
    }

    /// Sets the stack size of the thread (forwarded to `thread::Builder::stack_size`).
    pub fn stack_size(mut self, size: usize) -> ScopedThreadBuilder<'scope, 'env> {
        self.builder = self.builder.stack_size(size);
        self
    }

    /// Spawns a scoped thread with this configuration.
    ///
    /// `f` receives its own `Scope<'env>` that shares the parent's handle list and holds a
    /// clone of its wait group, so it can spawn further threads into the same scope. The
    /// value returned by `f` is stored in a shared slot that the returned
    /// `ScopedJoinHandle` reads on `join`.
    ///
    /// # Errors
    ///
    /// Returns the `io::Error` from `thread::Builder::spawn` if the thread cannot be
    /// created. In that case nothing is registered in the scope's handle list.
    pub fn spawn<F, T>(self, f: F) -> io::Result<ScopedJoinHandle<'scope, T>>
    where
        F: FnOnce(&Scope<'env>) -> T,
        F: Send + 'env,
        T: Send + 'env,
    {
        // Slot for `f`'s return value, shared between the thread and the returned handle.
        let result = SharedOption::default();
        let (handle, thread) = {
            let result = Arc::clone(&result);
            // The new thread's own `Scope`. Its wait-group clone keeps `scope` from
            // finishing its wait until this value is dropped at the end of the thread.
            let scope = Scope::<'env> {
                handles: Arc::clone(&self.scope.handles),
                wait_group: self.scope.wait_group.clone(),
                _marker: PhantomData,
            };
            let handle = {
                let closure = move || {
                    // Moves the whole `Scope` into a local, so it (and its wait-group clone)
                    // is dropped when the closure body finishes.
                    let scope: Scope<'env> = scope;
                    let res = f(&scope);
                    *result.lock().unwrap() = Some(res);
                };
                // Wraps the `FnOnce` closure in an `FnMut` (via `Option::take`) so it can be
                // boxed as `dyn FnMut()` and called through the box.
                // NOTE(review): presumably done because calling a boxed `FnOnce` was not
                // supported on the toolchains this targeted.
                let mut closure = Some(closure);
                let closure = move || closure.take().unwrap()();
                let closure: Box<dyn FnMut() + Send + 'env> = Box::new(closure);
                // SAFETY: only the lifetime bound changes (`'env` -> `'static`); the boxed
                // trait-object layout is identical. Erasing `'env` is sound only because the
                // thread cannot outlive `'env`:
                // - it holds a wait-group clone, which `scope` waits on;
                // - its `JoinHandle` is registered in `handles` below, and `scope` joins it
                //   unless the user already did so via `ScopedJoinHandle::join`.
                let closure: Box<dyn FnMut() + Send + 'static> = unsafe { mem::transmute(closure) };
                let mut closure = closure;
                self.builder.spawn(move || closure())?
            };
            let thread = handle.thread().clone();
            // Shared so either `ScopedJoinHandle::join` or `scope` can take and join it.
            let handle = Arc::new(Mutex::new(Some(handle)));
            (handle, thread)
        };
        // Register the handle so `scope` can join the thread if the user never does.
        self.scope.handles.lock().unwrap().push(Arc::clone(&handle));
        Ok(ScopedJoinHandle {
            handle,
            result,
            thread,
            _marker: PhantomData,
        })
    }
}
// SAFETY: in this file a `ScopedJoinHandle` is only constructed by
// `ScopedThreadBuilder::spawn`, which requires `T: Send`. Moving the handle to another
// thread, and receiving `T` there from `join`, is therefore fine.
// NOTE(review): these impls are unconditional. The auto traits appear to apply anyway
// whenever `T: Send`, so consider removing them or adding a `T: Send` bound.
unsafe impl<'scope, T> Send for ScopedJoinHandle<'scope, T> {}
// SAFETY: through `&ScopedJoinHandle` only `thread()` and `Debug` are reachable, and neither
// touches the `T` slot. `T` is accessed only by `join`, which takes `self` by value.
unsafe impl<'scope, T> Sync for ScopedJoinHandle<'scope, T> {}
/// An owned permission to join a scoped thread and obtain its result.
pub struct ScopedJoinHandle<'scope, T> {
    /// The thread's std join handle; `None` once it has been taken for joining.
    handle: SharedOption<thread::JoinHandle<()>>,
    /// Slot into which the thread stores its return value.
    result: SharedOption<T>,
    /// Handle to the underlying thread, exposed by `thread()`.
    thread: thread::Thread,
    /// Ties this handle to the `&'scope Scope` borrow it was spawned from.
    _marker: PhantomData<&'scope ()>,
}
impl<'scope, T> ScopedJoinHandle<'scope, T> {
    /// Waits for the thread to finish and returns the value its closure produced.
    ///
    /// # Errors
    ///
    /// If the thread panicked, returns `Err` with the panic payload.
    pub fn join(self) -> thread::Result<T> {
        // The handle is taken only here (which consumes `self`) or by `scope` when it joins
        // the remaining threads after the scope closure has returned.
        let handle = self
            .handle
            .lock()
            .unwrap()
            .take()
            .expect("scoped thread handle was already joined");
        handle.join().map(|()| {
            // A thread that exits normally has already stored its closure's return value.
            self.result
                .lock()
                .unwrap()
                .take()
                .expect("scoped thread finished without storing its result")
        })
    }

    /// Returns the `Thread` handle of the underlying thread.
    pub fn thread(&self) -> &thread::Thread {
        &self.thread
    }
}
impl<'scope, T> fmt::Debug for ScopedJoinHandle<'scope, T> {
    /// Writes the fixed placeholder `ScopedJoinHandle { .. }`, applying padding flags.
    fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
        // No field is printed, so this impl needs no `T: Debug` bound.
        let text: &str = "ScopedJoinHandle { .. }";
        formatter.pad(text)
    }
}