#![doc(test(
no_crate_inject,
attr(
deny(warnings, rust_2018_idioms),
allow(dead_code, unused_assignments, unused_variables)
)
))]
#![deny(missing_docs, missing_debug_implementations, rust_2018_idioms)]
use std::fmt;
use std::io;
use std::marker::PhantomData;
use std::mem;
use std::panic;
use std::sync::{Arc, Mutex};
use std::thread;
use cfg_if::cfg_if;
use crossbeam_utils::sync::WaitGroup;
use std::sync::atomic::{AtomicBool, Ordering};
// A `Vec` shared between threads behind a mutex (used for the scope's list of
// join handles).
type SharedVec<T> = Arc<Mutex<Vec<T>>>;
// An optional value shared between threads behind a mutex: a slot that is
// filled once and later emptied with `take()` (used for join handles and for
// thread results).
type SharedOption<T> = Arc<Mutex<Option<T>>>;
/// Creates a cancellable thread scope, runs `f` in it, and waits for every
/// thread spawned inside it before returning.
///
/// `f` receives the root [`Context`]. Threads spawned through it, and through
/// the contexts those threads receive, share one cancellation flag. They are
/// all waited for before this function returns, which is what lets them borrow
/// data living for `'env`.
///
/// # Errors
///
/// Returns `Err` if any spawned thread that was not joined explicitly
/// panicked. The error boxes a `Vec` holding the collected panic payloads.
///
/// # Panics
///
/// If `f` panics, the scope is cancelled, all spawned threads are waited for
/// and joined, and then `f`'s panic is resumed.
pub fn scope<'env, F, R>(f: F) -> thread::Result<R>
where
    F: FnOnce(&Context<'env>) -> R,
{
    let all_done = WaitGroup::new();
    let root = Context::<'env> {
        done: Arc::new(AtomicBool::new(false)),
        handles: SharedVec::default(),
        wait_group: all_done.clone(),
        _marker: PhantomData,
    };

    let outcome = panic::catch_unwind(panic::AssertUnwindSafe(|| f(&root)));
    if outcome.is_err() {
        // Raise the flag so spawned threads polling it can stop; otherwise the
        // wait below could block forever.
        root.cancel();
    }

    // Give up the root's wait-group token, then block until every spawned
    // context (each holds its own clone of the group) has been dropped.
    drop(root.wait_group);
    all_done.wait();

    // No thread can register a new handle any more. Join the ones the caller
    // left behind and keep the payloads of those that panicked.
    let mut panics = Vec::new();
    for slot in root.handles.lock().unwrap().drain(..) {
        let pending = slot.lock().unwrap().take();
        if let Some(join_handle) = pending {
            if let Err(payload) = join_handle.join() {
                panics.push(payload);
            }
        }
    }

    // A panic in `f` takes precedence over panics of spawned threads.
    let value = outcome.unwrap_or_else(|payload| panic::resume_unwind(payload));
    if panics.is_empty() {
        Ok(value)
    } else {
        Err(Box::new(panics))
    }
}
/// A handle to a cancellable thread scope.
///
/// [`scope`] passes one to its closure. Every thread spawned through a
/// `Context` receives its own `Context` that shares the same cancellation flag,
/// handle list and wait group.
pub struct Context<'env> {
// Cancellation flag shared by every context of the scope; set by `cancel`.
done: Arc<AtomicBool>,
// Join handles of every thread spawned in the scope (including nested spawns).
// `scope` joins whichever ones were not joined explicitly.
handles: SharedVec<SharedOption<thread::JoinHandle<()>>>,
// Each context holds a clone; `scope` waits until all clones are dropped.
wait_group: WaitGroup,
// Ties the context to `'env`; `&'env mut &'env ()` makes `'env` invariant.
_marker: PhantomData<&'env mut &'env ()>,
}
// SAFETY: the std-typed fields (`Arc<AtomicBool>`, `Arc<Mutex<..>>`,
// `PhantomData<&mut &()>`) are `Sync`. The `&self` methods mutate shared state
// only through the atomic flag or the mutex, and clone `wait_group`.
// NOTE(review): if `WaitGroup` is `Sync` (and its `clone` is thread-safe), this
// impl is redundant because the auto trait would already apply. Confirm against
// the crossbeam-utils version in use.
unsafe impl Sync for Context<'_> {}
impl<'env> Context<'env> {
    /// Returns `true` once [`cancel`](Context::cancel) has been called on any
    /// context of this scope. All of them share one flag.
    ///
    /// The load uses `Acquire` to pair with the `Release` store in `cancel`.
    /// A thread that observes `true` therefore also sees the writes the
    /// cancelling thread made before it cancelled.
    pub fn done(&self) -> bool {
        self.done.load(Ordering::Acquire)
    }

    /// Marks the scope as cancelled.
    ///
    /// Cancellation is cooperative. Running threads are not interrupted; they
    /// observe it through [`done`](Context::done) or
    /// [`active`](Context::active). Calling this repeatedly has no further
    /// effect.
    pub fn cancel(&self) {
        self.done.store(true, Ordering::Release)
    }

    /// Returns `true` while the scope has not been cancelled.
    ///
    /// This is the negation of [`done`](Context::done).
    pub fn active(&self) -> bool {
        !self.done()
    }

    /// Spawns a scoped thread that runs `f` with its own [`Context`] of this
    /// scope, and returns a handle for joining it.
    ///
    /// If `f` panics, the scope is cancelled before the panic is resumed, so
    /// sibling threads polling [`done`](Context::done) can stop. The panic is
    /// returned as an `Err` by [`ContextJoinHandle::join`]. If the handle is
    /// never joined, [`scope`] collects the panic instead.
    ///
    /// # Panics
    ///
    /// Panics if the OS fails to create the thread. Use
    /// [`builder`](Context::builder) to handle that error instead.
    pub fn spawn<'scope, F, T>(&'scope self, f: F) -> ContextJoinHandle<'scope, T>
    where
        F: FnOnce(&Context<'env>) -> T,
        F: Send + 'env,
        T: Send + 'env,
    {
        self.builder()
            .spawn(|ctx| {
                panic::catch_unwind(panic::AssertUnwindSafe(|| f(ctx))).unwrap_or_else(|payload| {
                    // Cancel first so siblings can stop, then re-raise the
                    // original panic so it is still reported on join.
                    ctx.cancel();
                    panic::resume_unwind(payload)
                })
            })
            .expect("failed to spawn scoped thread")
    }

    /// Returns a builder for configuring a thread (name, stack size) before
    /// spawning it in this scope.
    pub fn builder<'scope>(&'scope self) -> ContextThreadBuilder<'scope, 'env> {
        ContextThreadBuilder {
            scope: self,
            builder: thread::Builder::new(),
        }
    }
}
impl fmt::Debug for Context<'_> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Report the actual type name. The fields are internal bookkeeping
        // (flag, handle list, wait group) and are intentionally elided.
        f.pad("Context { .. }")
    }
}
/// Configures a thread before spawning it inside a [`Context`].
///
/// Obtained from [`Context::builder`]; wraps a [`std::thread::Builder`].
#[derive(Debug)]
pub struct ContextThreadBuilder<'scope, 'env> {
// The context whose flag, handle list and wait group the new thread shares.
scope: &'scope Context<'env>,
// Carries the thread name and stack size set via `name` / `stack_size`.
builder: thread::Builder,
}
impl<'scope, 'env> ContextThreadBuilder<'scope, 'env> {
pub fn name(mut self, name: String) -> ContextThreadBuilder<'scope, 'env> {
self.builder = self.builder.name(name);
self
}
pub fn stack_size(mut self, size: usize) -> ContextThreadBuilder<'scope, 'env> {
self.builder = self.builder.stack_size(size);
self
}
pub fn spawn<F, T>(self, f: F) -> io::Result<ContextJoinHandle<'scope, T>>
where
F: FnOnce(&Context<'env>) -> T,
F: Send + 'env,
T: Send + 'env,
{
let result = SharedOption::default();
let (handle, thread) = {
let result = Arc::clone(&result);
let ctx = Context::<'env> {
done: self.scope.done.clone(),
handles: Arc::clone(&self.scope.handles),
wait_group: self.scope.wait_group.clone(),
_marker: PhantomData,
};
let handle = {
let closure = move || {
let scope: Context<'env> = ctx;
let res = f(&scope);
*result.lock().unwrap() = Some(res);
};
let closure: Box<dyn FnOnce() + Send + 'env> = Box::new(closure);
let closure: Box<dyn FnOnce() + Send + 'static> =
unsafe { mem::transmute(closure) };
self.builder.spawn(move || closure())?
};
let thread = handle.thread().clone();
let handle = Arc::new(Mutex::new(Some(handle)));
(handle, thread)
};
self.scope.handles.lock().unwrap().push(Arc::clone(&handle));
Ok(ContextJoinHandle {
handle,
result,
thread,
_marker: PhantomData,
})
}
}
// SAFETY: a `ContextJoinHandle<T>` can only be built by
// `ContextThreadBuilder::spawn` (its fields are private), which requires
// `T: Send`. Moving the handle, and later the `T` that `join` takes out, to
// another thread is therefore fine.
unsafe impl<T> Send for ContextJoinHandle<'_, T> {}
// SAFETY: through `&ContextJoinHandle` only the `Thread` and the mutex-guarded
// OS join handle are reachable, never the `T` result slot.
unsafe impl<T> Sync for ContextJoinHandle<'_, T> {}
/// A handle for joining a thread spawned in a [`Context`].
///
/// Dropping it without calling [`join`](ContextJoinHandle::join) does not
/// detach the thread: [`scope`] still joins it before returning.
pub struct ContextJoinHandle<'scope, T> {
// Shared with the scope's handle list; `None` once the thread has been joined.
handle: SharedOption<thread::JoinHandle<()>>,
// Where the thread stores its return value; emptied by `join`.
result: SharedOption<T>,
// The spawned thread, exposed through `thread()` and shown by `Debug`.
thread: thread::Thread,
// Ties the handle to the borrow of the `Context` that spawned it.
_marker: PhantomData<&'scope ()>,
}
impl<T> ContextJoinHandle<'_, T> {
    /// Waits for the thread to finish and returns the value it produced.
    ///
    /// # Errors
    ///
    /// Returns `Err` with the panic payload if the thread panicked.
    pub fn join(self) -> thread::Result<T> {
        // The slot is only emptied here (which consumes `self`) or by `scope`
        // after the spawning closure has returned, so it is still filled.
        let handle = self
            .handle
            .lock()
            .unwrap()
            .take()
            .expect("scoped thread was already joined");
        handle.join().map(|()| {
            // A thread that exits without panicking always stores its result
            // before returning.
            self.result
                .lock()
                .unwrap()
                .take()
                .expect("scoped thread finished without storing its result")
        })
    }

    /// Returns the underlying [`thread::Thread`] (name, id, unpark).
    pub fn thread(&self) -> &thread::Thread {
        &self.thread
    }
}
cfg_if! {
    if #[cfg(unix)] {
        use std::os::unix::thread::{JoinHandleExt, RawPthread};

        // Exposes the pthread handle of the wrapped thread.
        impl<T> JoinHandleExt for ContextJoinHandle<'_, T> {
            fn as_pthread_t(&self) -> RawPthread {
                let guard = self.handle.lock().unwrap();
                let inner = guard.as_ref().unwrap();
                inner.as_pthread_t()
            }

            // Unlike `std`, consuming `self` does not hand the thread to the
            // caller: the inner `JoinHandle` stays registered with the scope,
            // which joins it before `scope` returns.
            // NOTE(review): callers must not join or detach the returned value.
            fn into_pthread_t(self) -> RawPthread {
                JoinHandleExt::as_pthread_t(&self)
            }
        }
    } else if #[cfg(windows)] {
        use std::os::windows::io::{AsRawHandle, IntoRawHandle, RawHandle};

        // Exposes the OS handle of the wrapped thread.
        impl<T> AsRawHandle for ContextJoinHandle<'_, T> {
            fn as_raw_handle(&self) -> RawHandle {
                let guard = self.handle.lock().unwrap();
                let inner = guard.as_ref().unwrap();
                inner.as_raw_handle()
            }
        }

        // Unlike `std`, ownership of the handle is not transferred: the scope
        // still owns and joins the thread.
        // NOTE(review): callers must not close the returned handle.
        impl<T> IntoRawHandle for ContextJoinHandle<'_, T> {
            fn into_raw_handle(self) -> RawHandle {
                AsRawHandle::as_raw_handle(&self)
            }
        }
    }
}
impl<T> fmt::Debug for ContextJoinHandle<'_, T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Name the real type and show only the thread name. The join handle
        // and result slot are internal. `debug_struct` avoids the temporary
        // `String` the old `format!` + `pad` needed.
        f.debug_struct("ContextJoinHandle")
            .field("name", &self.thread.name())
            .finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A thread spinning on the flag must be released once a nested thread,
    /// spawned from a sibling, calls `cancel`.
    #[test]
    fn test_cancellation_nested() {
        let outcome = scope(|root| {
            root.spawn(|waiter| {
                while waiter.active() {}
            });
            root.spawn(|spawner| {
                while !spawner.done() {
                    spawner.spawn(|nested| nested.cancel());
                }
            });
        });
        outcome.unwrap()
    }

    /// A panicking thread must make the test panic. Either the `active()`
    /// assertion fails (cancellation already visible), or `scope` returns
    /// `Err` and the unwrap panics.
    #[test]
    #[should_panic]
    fn test_panic_cancellation() {
        let outcome = scope(|root| {
            root.spawn(|_| panic!());
            assert!(root.active())
        });
        outcome.unwrap()
    }
}