use std::{
marker::PhantomData,
panic::{catch_unwind, resume_unwind, AssertUnwindSafe},
sync::{
atomic::{AtomicBool, AtomicUsize, Ordering},
Arc,
},
};
use super::{signal::Signal, utils::is_web_worker_thread, Builder, JoinInner};
/// Handle through which threads that may borrow from the caller's environment
/// are spawned.
///
/// Constructed by [`scope`], which passes a reference to it into the user's
/// closure; threads are started with [`Scope::spawn`].
pub struct Scope<'scope, 'env: 'scope> {
    /// Bookkeeping shared with spawned threads (a clone of this `Arc` is handed
    /// to each spawn in `Builder::spawn_scoped`).
    data: Arc<ScopeData>,
    /// `&'env mut &'env ()` keeps `'env` invariant.
    env: PhantomData<&'env mut &'env ()>,
    /// `&'scope mut &'scope ()` keeps `'scope` invariant.
    scope: PhantomData<&'scope mut &'scope ()>,
}
/// Owned handle to a thread spawned inside a [`Scope`].
///
/// Wraps the crate's `JoinInner`; the thread's result is retrieved with
/// `join` or `join_async`. The `'scope` lifetime ties the handle to the scope
/// it was spawned in, so it cannot outlive that scope.
pub struct ScopedJoinHandle<'scope, T>(JoinInner<'scope, T>);
/// State shared between a [`Scope`] and the threads spawned in it.
pub(crate) struct ScopeData {
    /// Signalled when `num_running_threads` drops back to zero; `scope` waits on it.
    signal: Signal,
    /// Set to `true` when a thread is decremented with `panic == true`.
    a_thread_panicked: AtomicBool,
    /// Number of increments not yet matched by a decrement.
    num_running_threads: AtomicUsize,
}
impl ScopeData {
    /// Registers one more running thread with this scope.
    ///
    /// # Panics
    ///
    /// Panics with "too many running threads in thread scope" if the counter
    /// was already above `usize::MAX / 2`. That guard keeps the counter far from
    /// overflow. The increment is rolled back before panicking so the count stays
    /// balanced.
    pub(crate) fn increment_num_running_threads(&self) {
        let before = self.num_running_threads.fetch_add(1, Ordering::Relaxed);
        if before <= usize::MAX / 2 {
            return;
        }
        // Undo our own increment; `false` because this is not a thread panic.
        self.decrement_num_running_threads(false);
        panic!("too many running threads in thread scope");
    }

    /// Unregisters a running thread, recording whether it panicked.
    ///
    /// When this call brings the count to zero, the scope's signal is fired so a
    /// waiting `scope` call can proceed.
    pub(crate) fn decrement_num_running_threads(&self, panic: bool) {
        if panic {
            self.a_thread_panicked.store(true, Ordering::Relaxed);
        }
        // Release pairs with the Acquire load in `scope`, so the waiter also
        // observes the `a_thread_panicked` store above.
        let was_last = self.num_running_threads.fetch_sub(1, Ordering::Release) == 1;
        if was_last {
            self.signal.signal();
        }
    }
}
/// Runs `f` with a [`Scope`] through which threads that borrow from the
/// caller's environment can be spawned.
///
/// Before returning, this waits until the scope's running-thread counter is back
/// at zero. It waits even when `f` itself panicked.
///
/// # Panics
///
/// - With "scope is not allowed on the main thread" when
///   `is_web_worker_thread()` returns `false`.
/// - With `f`'s own panic payload (resumed) if `f` panicked.
/// - With "a scoped thread panicked" if `f` returned normally but some scoped
///   thread was recorded as panicked.
pub fn scope<'env, F, T>(f: F) -> T
where
    F: for<'scope> FnOnce(&'scope Scope<'scope, 'env>) -> T,
{
    // The wait below is a blocking wait, which this crate forbids on the main thread.
    assert!(is_web_worker_thread(), "scope is not allowed on the main thread");

    // Held in an `Arc` so a spawned thread still owns the state while it performs
    // its final decrement and signal, even if this function has already returned.
    let shared = Arc::new(ScopeData {
        num_running_threads: AtomicUsize::new(0),
        a_thread_panicked: AtomicBool::new(false),
        signal: Signal::new(),
    });
    let scope = Scope {
        data: shared,
        scope: PhantomData,
        env: PhantomData,
    };

    // Catch a panic from `f` so the spawned threads are still waited for first.
    let outcome = catch_unwind(AssertUnwindSafe(|| f(&scope)));

    // Acquire pairs with the Release decrement, making the panic flag visible.
    // NOTE(review): this relies on `Signal::wait` not missing a `signal()` that
    // fires between the load and the wait; confirm `Signal`'s semantics.
    loop {
        if scope.data.num_running_threads.load(Ordering::Acquire) == 0 {
            break;
        }
        scope.data.signal.wait();
    }

    let value = match outcome {
        Ok(value) => value,
        Err(payload) => resume_unwind(payload),
    };
    if scope.data.a_thread_panicked.load(Ordering::Relaxed) {
        panic!("a scoped thread panicked");
    }
    value
}
impl<'scope, 'env> Scope<'scope, 'env> {
    /// Spawns a thread in this scope using a default [`Builder`].
    ///
    /// `f` and its result only need to live for `'scope`, so they may borrow
    /// data from the environment that encloses the scope.
    ///
    /// # Panics
    ///
    /// Panics with "failed to spawn thread" if the builder reports an error.
    pub fn spawn<F, T>(&'scope self, f: F) -> ScopedJoinHandle<'scope, T>
    where
        F: FnOnce() -> T + Send + 'scope,
        T: Send + 'scope,
    {
        let builder = Builder::new();
        let spawned = builder.spawn_scoped(self, f);
        spawned.expect("failed to spawn thread")
    }
}
impl Builder {
pub fn spawn_scoped<'scope, 'env, F, T>(
self,
scope: &'scope Scope<'scope, 'env>,
f: F,
) -> std::io::Result<ScopedJoinHandle<'scope, T>>
where
F: FnOnce() -> T + Send + 'scope,
T: Send + 'scope,
{
Ok(ScopedJoinHandle(unsafe {
self.spawn_unchecked_(f, Some(scope.data.clone()))
}?))
}
}
impl<'scope, T> ScopedJoinHandle<'scope, T> {
pub fn join(self) -> super::Result<T> {
self.0.join()
}
pub async fn join_async(self) -> super::Result<T> {
self.0.join_async().await
}
}