use super::{current, park, Builder, JoinInner, Result, Thread};
use crate::fmt;
use crate::io;
use crate::marker::PhantomData;
use crate::panic::{catch_unwind, resume_unwind, AssertUnwindSafe};
use crate::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use crate::sync::Arc;
/// A scope in which scoped threads can be spawned.
///
/// A `Scope` is created by [`scope`] and passed by reference to the closure
/// given to it. Threads spawned through it ([`Scope::spawn`],
/// [`Builder::spawn_scoped`]) may borrow data that outlives `'scope`, because
/// [`scope`] does not return until the scope's running-thread count is zero.
#[stable(feature = "scoped_threads", since = "1.63.0")]
pub struct Scope<'scope, 'env: 'scope> {
    // Bookkeeping shared with every thread spawned in this scope; a clone of
    // this `Arc` is passed along each time `Builder::spawn_scoped` runs.
    data: Arc<ScopeData>,
    // `&'scope mut &'scope ()` makes `Scope` invariant over `'scope`, so the
    // compiler cannot shrink or extend that lifetime to make a borrow fit.
    scope: PhantomData<&'scope mut &'scope ()>,
    // Likewise makes `Scope` invariant over `'env`, the lifetime of data
    // borrowed from outside the `scope` call.
    env: PhantomData<&'env mut &'env ()>,
}
/// An owned permission to join on a scoped thread (block on its termination).
///
/// Returned by [`Scope::spawn`] and [`Builder::spawn_scoped`]. The `'scope`
/// lifetime ties the handle to the scope the thread was spawned in; `T` is the
/// type the thread's closure returns.
#[stable(feature = "scoped_threads", since = "1.63.0")]
pub struct ScopedJoinHandle<'scope, T>(JoinInner<'scope, T>);
/// Bookkeeping shared (through an `Arc`) between a [`Scope`] and the threads
/// spawned in it.
pub(super) struct ScopeData {
    // Maintained by `increment_num_running_threads` /
    // `decrement_num_running_threads`; `scope` blocks until it reaches zero.
    num_running_threads: AtomicUsize,
    // Set by `decrement_num_running_threads(true)`; once every thread is done,
    // `scope` checks it and panics if it is set.
    a_thread_panicked: AtomicBool,
    // The thread that called `scope`; unparked when the running count drops
    // from 1 to 0.
    main_thread: Thread,
}
impl ScopeData {
    /// Records that one more scoped thread is running.
    ///
    /// # Panics
    ///
    /// Panics with "too many running threads in thread scope" if the counter
    /// was already above `usize::MAX / 2`; the increment is rolled back first.
    pub(super) fn increment_num_running_threads(&self) {
        // Treat anything above `usize::MAX / 2` as overflow so the counter can
        // never wrap around to 0. A wrapped count would let `scope`'s wait
        // loop exit while threads that borrow from it are still running.
        if self.num_running_threads.fetch_add(1, Ordering::Relaxed) > usize::MAX / 2 {
            self.overflow();
        }
    }

    /// Slow path of `increment_num_running_threads`: undo the increment that
    /// was just made, then panic.
    ///
    /// Kept out of line and marked `#[cold]` so the common increment path
    /// stays small and the compiler lays this branch out as unlikely.
    #[cold]
    fn overflow(&self) {
        self.decrement_num_running_threads(false);
        panic!("too many running threads in thread scope");
    }

    /// Records that a scoped thread has finished, noting whether it panicked.
    ///
    /// If this was the last running thread, the thread that called `scope`
    /// is unparked so that it can stop waiting.
    pub(super) fn decrement_num_running_threads(&self, panic: bool) {
        if panic {
            // Relaxed is enough: the Release `fetch_sub` below orders this
            // store before the Acquire load that ends `scope`'s wait loop.
            self.a_thread_panicked.store(true, Ordering::Relaxed);
        }
        // Release pairs with the Acquire load in `scope`'s wait loop.
        // Everything the finishing thread wrote before this call is therefore
        // visible to `scope` once it observes a count of zero.
        if self.num_running_threads.fetch_sub(1, Ordering::Release) == 1 {
            self.main_thread.unpark();
        }
    }
}
/// Creates a scope for spawning scoped threads.
///
/// `f` receives a [`Scope`] through which threads that borrow non-`'static`
/// data can be spawned. After `f` returns (or panics), this function blocks
/// until the scope's running-thread count reaches zero, and only then returns
/// `f`'s result.
///
/// # Panics
///
/// If `f` panicked, its panic is resumed once all scoped threads have
/// finished. Otherwise, if any scoped thread reported a panic, this panics
/// with "a scoped thread panicked".
#[track_caller]
#[stable(feature = "scoped_threads", since = "1.63.0")]
pub fn scope<'env, F, T>(f: F) -> T
where
    F: for<'scope> FnOnce(&'scope Scope<'scope, 'env>) -> T,
{
    let shared = Arc::new(ScopeData {
        num_running_threads: AtomicUsize::new(0),
        main_thread: current(),
        a_thread_panicked: AtomicBool::new(false),
    });
    let scope = Scope { data: shared, env: PhantomData, scope: PhantomData };

    // Catch a panic from `f` so that we still wait for every scoped thread
    // before leaving this function; those threads may borrow from `'env`.
    let outcome = catch_unwind(AssertUnwindSafe(|| f(&scope)));

    // The last thread to call `decrement_num_running_threads` unparks us.
    // `park` can return without a matching unpark, so re-check each time.
    // Acquire pairs with that Release `fetch_sub`.
    while scope.data.num_running_threads.load(Ordering::Acquire) != 0 {
        park();
    }

    let value = match outcome {
        Ok(value) => value,
        // `f` itself panicked: propagate its payload unchanged.
        Err(payload) => resume_unwind(payload),
    };
    if scope.data.a_thread_panicked.load(Ordering::Relaxed) {
        panic!("a scoped thread panicked");
    }
    value
}
impl<'scope, 'env> Scope<'scope, 'env> {
    /// Spawns a new thread within this scope and returns a
    /// [`ScopedJoinHandle`] for it.
    ///
    /// Uses a default [`Builder`]. Use [`Builder::spawn_scoped`] directly to
    /// configure the thread or to handle a spawn failure without panicking.
    ///
    /// # Panics
    ///
    /// Panics with "failed to spawn thread" if `Builder::spawn_scoped`
    /// returns an error.
    #[stable(feature = "scoped_threads", since = "1.63.0")]
    pub fn spawn<F, T>(&'scope self, f: F) -> ScopedJoinHandle<'scope, T>
    where
        F: FnOnce() -> T + Send + 'scope,
        T: Send + 'scope,
    {
        let builder = Builder::new();
        let spawned = builder.spawn_scoped(self, f);
        spawned.expect("failed to spawn thread")
    }
}
impl Builder {
#[stable(feature = "scoped_threads", since = "1.63.0")]
pub fn spawn_scoped<'scope, 'env, F, T>(
self,
scope: &'scope Scope<'scope, 'env>,
f: F,
) -> io::Result<ScopedJoinHandle<'scope, T>>
where
F: FnOnce() -> T + Send + 'scope,
T: Send + 'scope,
{
Ok(ScopedJoinHandle(unsafe { self.spawn_unchecked_(f, Some(scope.data.clone())) }?))
}
}
impl<'scope, T> ScopedJoinHandle<'scope, T> {
#[must_use]
#[stable(feature = "scoped_threads", since = "1.63.0")]
pub fn thread(&self) -> &Thread {
&self.0.thread
}
#[stable(feature = "scoped_threads", since = "1.63.0")]
pub fn join(self) -> Result<T> {
self.0.join()
}
#[stable(feature = "scoped_threads", since = "1.63.0")]
pub fn is_finished(&self) -> bool {
Arc::strong_count(&self.0.packet) == 1
}
}
/// Shows a snapshot of the scope's bookkeeping: the running-thread count, the
/// panic flag, and the thread that called `scope`.
#[stable(feature = "scoped_threads", since = "1.63.0")]
impl fmt::Debug for Scope<'_, '_> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("Scope");
        let data = &self.data;
        // Relaxed loads: this is a diagnostic snapshot that may be stale by
        // the time it is printed.
        let running = data.num_running_threads.load(Ordering::Relaxed);
        out.field("num_running_threads", &running);
        let panicked = data.a_thread_panicked.load(Ordering::Relaxed);
        out.field("a_thread_panicked", &panicked);
        out.field("main_thread", &data.main_thread);
        out.finish_non_exhaustive()
    }
}
/// Formats as an opaque `ScopedJoinHandle { .. }`.
#[stable(feature = "scoped_threads", since = "1.63.0")]
impl<'scope, T> fmt::Debug for ScopedJoinHandle<'scope, T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // No fields are printed, which is why this impl needs no `T: Debug` bound.
        let mut out = f.debug_struct("ScopedJoinHandle");
        out.finish_non_exhaustive()
    }
}