use std::any::{Any, TypeId};
use std::cell::{Cell, RefCell};
use std::collections::{HashMap, VecDeque};
use std::fmt;
use std::future::Future;
use std::rc::{Rc, Weak};
use auralis_signal::{Memo, Signal};
use crate::executor;
use crate::Priority;
/// Identifier of a scope; unique per thread (see [`alloc_scope_id`]).
type ScopeId = u64;
/// Identifier of a task spawned on an executor.
type TaskId = u64;
thread_local! {
    /// The id that the next call to `alloc_scope_id` on this thread hands out.
    static NEXT_SCOPE_ID: Cell<ScopeId> = const { Cell::new(1) };
}
/// Hands out the next scope id for the current thread (starting at 1) and
/// advances the per-thread counter.
fn alloc_scope_id() -> ScopeId {
    // `replace` stores the incremented counter and yields the previous value,
    // which is the id being allocated.
    NEXT_SCOPE_ID.with(|next| next.replace(next.get() + 1))
}
/// Owns an optional cleanup closure that runs when the handle is dropped.
///
/// A panic raised by the cleanup is caught and discarded, so dropping a
/// handle never unwinds.
pub struct CallbackHandle {
    cleanup: Option<Box<dyn FnOnce() + 'static>>,
}
impl CallbackHandle {
    /// Wraps `cleanup` so that it runs exactly once, when this handle drops.
    pub fn new(cleanup: impl FnOnce() + 'static) -> Self {
        let boxed: Box<dyn FnOnce() + 'static> = Box::new(cleanup);
        Self {
            cleanup: Some(boxed),
        }
    }
    /// A handle whose drop does nothing.
    #[must_use]
    pub fn noop() -> Self {
        Self {
            cleanup: Option::None,
        }
    }
}
impl Drop for CallbackHandle {
    fn drop(&mut self) {
        let Some(cleanup) = self.cleanup.take() else {
            return;
        };
        // Unwinding out of `drop` could abort the process if we are already
        // unwinding, so the cleanup's panic (if any) is swallowed here.
        drop(std::panic::catch_unwind(std::panic::AssertUnwindSafe(
            cleanup,
        )));
    }
}
/// Registry value: weak links to a scope's inner state and to its
/// `suspended` flag, so the registry never keeps a scope alive.
type ScopeRegistryEntry = (Weak<RefCell<TaskScopeInner>>, Weak<Cell<bool>>);
thread_local! {
    /// Per-thread map from scope id to weak handles (read by `find_scope`).
    static SCOPE_REGISTRY: RefCell<HashMap<ScopeId, ScopeRegistryEntry>> =
        RefCell::new(HashMap::new());
}
/// Applies `edit` to this thread's registry map. Silently does nothing when
/// the thread-local is unavailable or the map is already borrowed.
fn edit_scope_registry(edit: impl FnOnce(&mut HashMap<ScopeId, ScopeRegistryEntry>)) {
    let _ = SCOPE_REGISTRY.try_with(|reg| {
        if let Ok(mut map) = reg.try_borrow_mut() {
            edit(&mut map);
        }
    });
}
/// Records weak handles for scope `id` (best effort; see `edit_scope_registry`).
fn register_scope(id: ScopeId, inner: &Rc<RefCell<TaskScopeInner>>, suspended: &Rc<Cell<bool>>) {
    edit_scope_registry(|map| {
        map.insert(id, (Rc::downgrade(inner), Rc::downgrade(suspended)));
    });
}
/// Forgets scope `id` (best effort; see `edit_scope_registry`).
fn unregister_scope(id: ScopeId) {
    edit_scope_registry(|map| {
        map.remove(&id);
    });
}
#[must_use]
pub fn find_scope(scope_id: ScopeId) -> Option<TaskScope> {
SCOPE_REGISTRY
.try_with(|reg| {
if let Ok(r) = reg.try_borrow() {
r.get(&scope_id).and_then(|(inner_weak, suspended_weak)| {
let inner = inner_weak.upgrade()?;
let suspended = suspended_weak.upgrade()?;
let cancelled = inner.borrow().cancelled.clone();
Some(TaskScope {
inner,
cancelled,
suspended,
})
})
} else {
None
}
})
.ok()
.flatten()
}
/// Debug helper: the label of the live scope `scope_id`, if the scope exists
/// and has a label.
#[cfg(feature = "debug")]
#[doc(hidden)]
#[must_use]
pub fn scope_debug_label(scope_id: ScopeId) -> Option<String> {
    let scope = find_scope(scope_id)?;
    scope.label()
}
/// Drops every entry of this thread's scope registry.
///
/// Best effort: does nothing if the thread-local is unavailable or the
/// registry is currently borrowed.
#[doc(hidden)]
pub fn clear_scope_registry() {
    let _ = SCOPE_REGISTRY.try_with(|reg| {
        let Ok(mut map) = reg.try_borrow_mut() else {
            return;
        };
        map.clear();
    });
}
/// Hook that replaces the "current scope".
type ScopeSetFn = fn(Option<TaskScope>);
/// Hook that reads the "current scope".
type ScopeGetFn = fn() -> Option<TaskScope>;
/// Pluggable backend for the "current scope" slot.
///
/// Exactly one store is used for the whole process. It is either installed
/// explicitly via [`set_scope_store`] or defaults to a thread-local slot on
/// first use.
#[derive(Debug)]
pub struct ScopeStore {
    /// Replaces the current scope with the given value.
    pub set_fn: ScopeSetFn,
    /// Returns a handle to the current scope, if any.
    pub get_fn: ScopeGetFn,
}
use std::sync::OnceLock;
/// The process-wide store; fixed once initialised.
static SCOPE_STORE: OnceLock<ScopeStore> = OnceLock::new();
/// Returns the installed store, installing the thread-local default if no
/// store has been chosen yet.
fn ensure_default_store() -> &'static ScopeStore {
    fn thread_local_store() -> ScopeStore {
        ScopeStore {
            set_fn: thread_local_set,
            get_fn: thread_local_get,
        }
    }
    SCOPE_STORE.get_or_init(thread_local_store)
}
/// Installs `store` as the process-wide scope store.
///
/// # Errors
///
/// Returns `Err(store)` if a store is already installed. This includes the
/// default store, which any scope operation installs lazily.
pub fn set_scope_store(store: ScopeStore) -> Result<(), ScopeStore> {
    SCOPE_STORE.set(store)
}
thread_local! {
    /// Current-scope slot backing the default (thread-local) store.
    static CURRENT_SCOPE: RefCell<Option<TaskScope>> = const { RefCell::new(None) };
}
/// Default `set_fn`: swaps `scope` into this thread's slot. The previous value
/// is dropped only after the swap, when no borrow of the slot is active.
fn thread_local_set(scope: Option<TaskScope>) {
    CURRENT_SCOPE.with(|slot| drop(slot.replace(scope)));
}
/// Default `get_fn`: a fresh handle to this thread's current scope, if any.
fn thread_local_get() -> Option<TaskScope> {
    CURRENT_SCOPE.with(|slot| Option::clone(&slot.borrow()))
}
/// Sets the current scope through whichever store is installed.
pub(crate) fn set_scope_direct(scope: Option<TaskScope>) {
    (ensure_default_store().set_fn)(scope);
}
/// Reads the current scope through whichever store is installed.
pub(crate) fn get_scope_direct() -> Option<TaskScope> {
    (ensure_default_store().get_fn)()
}
#[cfg(feature = "ssr-tokio")]
tokio::task_local! {
    /// Task-local slot backing the tokio scope store. It can only be read or
    /// written inside a future run through `with_tokio_scope_slot`; outside
    /// one, reads yield `None` and writes are ignored.
    static TK_SCOPE: RefCell<Option<TaskScope>>;
}

/// Replaces the default thread-local scope store with one backed by a tokio
/// task-local. Each future run through [`with_tokio_scope_slot`] then tracks
/// its own current scope.
///
/// Previously the task-local was declared inside this function, so nothing
/// could ever enter its `scope`. Every store read returned `None` and every
/// write was silently dropped.
///
/// # Panics
///
/// Panics if a scope store is already installed. This includes the default
/// store, which the first scope operation installs lazily, so this must run
/// before any scope operation.
#[cfg(feature = "ssr-tokio")]
pub fn init_scope_store_tokio() {
    // If we happen to be inside a slot already, start it out empty.
    let _ = TK_SCOPE.try_with(|cell| {
        cell.replace(None);
    });
    set_scope_store(ScopeStore {
        set_fn: |s| {
            let _ = TK_SCOPE.try_with(|cell| {
                cell.replace(s);
            });
        },
        get_fn: || {
            TK_SCOPE
                .try_with(|cell| cell.borrow().clone())
                .ok()
                .flatten()
        },
    })
    .expect("init_scope_store_tokio must be called BEFORE any scope operations");
}

/// Runs `fut` with a fresh, empty current-scope slot for the tokio-backed
/// store installed by [`init_scope_store_tokio`], and returns its output.
#[cfg(feature = "ssr-tokio")]
pub async fn with_tokio_scope_slot<F: Future>(fut: F) -> F::Output {
    TK_SCOPE.scope(RefCell::new(None), fut).await
}
pub fn with_current_scope<R>(scope: &TaskScope, f: impl FnOnce() -> R) -> R {
let store = ensure_default_store();
let prev = (store.get_fn)();
(store.set_fn)(Some(scope.clone_inner()));
let result = f();
(store.set_fn)(prev);
result
}
#[must_use]
pub fn current_scope() -> Option<TaskScope> {
let store = ensure_default_store();
(store.get_fn)()
}
/// Shared mutable state behind every `TaskScope` handle for one scope.
struct TaskScopeInner {
    /// Id from `alloc_scope_id`; also this scope's key in `SCOPE_REGISTRY`.
    id: ScopeId,
    /// Ids of tasks spawned through this scope. Only appended to here and
    /// taken during teardown; finished tasks are not pruned in this file.
    task_ids: Vec<TaskId>,
    /// Strong handles to child scopes, keeping children alive while this
    /// scope lives.
    children: Vec<TaskScope>,
    /// Weak link to the parent scope; `consume` walks it to find context.
    parent: Option<Weak<RefCell<TaskScopeInner>>>,
    /// Values registered with `provide`, keyed by their `TypeId`.
    context: RefCell<HashMap<TypeId, Rc<dyn Any>>>,
    /// Cleanup handles; dropping them (at teardown) runs their closures.
    callbacks: RefCell<Vec<CallbackHandle>>,
    /// The same `Rc` as `TaskScope::cancelled` on every handle to this scope.
    cancelled: Rc<Cell<bool>>,
    /// Optional human-readable label (`set_label` / `label`).
    label: Option<String>,
    /// Executor this scope spawns its tasks on (inherited by children).
    executor: executor::ExecutorRef,
}
/// Handle to a task spawned through a `TaskScope`.
///
/// `task_id` is `None` when the spawn was refused because the scope had
/// already been cancelled; such a handle reports itself as finished.
pub struct JoinHandle {
    task_id: Option<TaskId>,
    executor: executor::ExecutorRef,
}
impl JoinHandle {
    /// Asks the executor to cancel the task; does nothing for a refused spawn.
    pub fn cancel(&self) {
        let Some(tid) = self.task_id else {
            return;
        };
        executor::cancel_task(&self.executor, tid);
    }
    /// Whether the executor reports the task as finished (always `true` for a
    /// refused spawn).
    #[must_use]
    pub fn is_finished(&self) -> bool {
        let Some(tid) = self.task_id else {
            return true;
        };
        executor::is_task_finished(&self.executor, tid)
    }
    /// The executor's id for this task, or `None` for a refused spawn.
    #[must_use]
    pub fn task_id(&self) -> Option<TaskId> {
        self.task_id
    }
}
impl fmt::Debug for JoinHandle {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut dbg = f.debug_struct("JoinHandle");
        dbg.field("task_id", &self.task_id);
        // The executor reference is intentionally not printed.
        dbg.finish_non_exhaustive()
    }
}
#[must_use]
pub struct TaskScope {
inner: Rc<RefCell<TaskScopeInner>>,
cancelled: Rc<Cell<bool>>,
suspended: Rc<Cell<bool>>,
}
impl TaskScope {
pub fn new() -> Self {
Self::with_executor(&executor::current_executor_instance())
}
pub fn with_executor(ex: &executor::ExecutorRef) -> Self {
let cancelled = Rc::new(Cell::new(false));
let inner = Rc::new(RefCell::new(TaskScopeInner {
id: alloc_scope_id(),
task_ids: Vec::new(),
children: Vec::new(),
parent: None,
context: RefCell::new(HashMap::new()),
callbacks: RefCell::new(Vec::new()),
cancelled: Rc::clone(&cancelled),
label: None,
executor: Rc::clone(ex),
}));
let id = inner.borrow().id;
let suspended = Rc::new(Cell::new(false));
register_scope(id, &inner, &suspended);
Self {
inner,
cancelled,
suspended,
}
}
pub fn new_child(parent: &Self) -> Self {
let ex = parent.inner.borrow().executor.clone();
let cancelled = Rc::new(Cell::new(false));
let inner = Rc::new(RefCell::new(TaskScopeInner {
id: alloc_scope_id(),
task_ids: Vec::new(),
children: Vec::new(),
parent: Some(Rc::downgrade(&parent.inner)),
context: RefCell::new(HashMap::new()),
callbacks: RefCell::new(Vec::new()),
cancelled: Rc::clone(&cancelled),
label: None,
executor: ex,
}));
let id = inner.borrow().id;
let suspended = Rc::new(Cell::new(false));
register_scope(id, &inner, &suspended);
let child = Self {
inner,
cancelled,
suspended,
};
parent.inner.borrow_mut().children.push(child.clone_inner());
child
}
pub fn spawn(&self, future: impl Future<Output = ()> + 'static) -> JoinHandle {
self.spawn_with_priority(Priority::Low, future)
}
pub fn spawn_with_priority(
&self,
priority: Priority,
future: impl Future<Output = ()> + 'static,
) -> JoinHandle {
let (cancelled, ex, scope_id) = {
let inner = self.inner.borrow();
(inner.cancelled.get(), Rc::clone(&inner.executor), inner.id)
};
if cancelled {
return JoinHandle {
task_id: None,
executor: ex,
};
}
let task_id = executor::with_executor(&ex, || {
with_current_scope(self, || {
executor::spawn_scoped_on(&ex, priority, scope_id, future)
})
});
self.inner.borrow_mut().task_ids.push(task_id);
JoinHandle {
task_id: Some(task_id),
executor: ex,
}
}
pub fn watch<T: Clone + 'static>(
&self,
sig: &Signal<T>,
f: impl FnMut(&T) + 'static,
) -> JoinHandle {
let s = sig.clone();
let mut f = f;
self.spawn(async move {
loop {
s.changed().await;
f(&s.read());
}
})
}
pub fn watch_effect(&self, effect: impl Fn() + 'static) -> JoinHandle {
let memo = Memo::new(effect);
self.spawn(async move {
loop {
memo.changed().await;
#[allow(clippy::let_unit_value, clippy::ignored_unit_patterns)]
let _ = memo.read();
}
})
}
pub fn register_callback_handle(&self, handle: CallbackHandle) {
let inner = self.inner.borrow();
if inner.cancelled.get() {
return;
}
inner.callbacks.borrow_mut().push(handle);
}
pub fn on_cleanup(&self, f: impl FnOnce() + 'static) {
self.register_callback_handle(CallbackHandle::new(f));
}
pub fn provide<T: 'static>(&self, value: T) {
self.inner
.borrow()
.context
.borrow_mut()
.insert(TypeId::of::<T>(), Rc::new(value));
}
#[must_use]
pub fn consume<T: 'static>(&self) -> Option<Rc<T>> {
let mut current = Some(Rc::clone(&self.inner));
while let Some(inner) = current {
{
let inner_ref = inner.borrow();
let ctx = inner_ref.context.borrow();
if let Some(val) = ctx.get(&TypeId::of::<T>()) {
if let Ok(downcast) = val.clone().downcast::<T>() {
return Some(downcast);
}
}
}
let parent = {
let inner_ref = inner.borrow();
inner_ref.parent.as_ref().and_then(Weak::upgrade)
};
current = parent;
}
None
}
#[must_use]
#[track_caller]
pub fn expect_context<T: 'static>(&self) -> Rc<T> {
self.consume::<T>()
.unwrap_or_else(|| panic!("context not found: {}", std::any::type_name::<T>()))
}
#[must_use]
pub fn is_cancelled(&self) -> bool {
self.cancelled.get()
}
pub fn set_label(&self, label: impl Into<String>) {
self.inner.borrow_mut().label = Some(label.into());
}
#[must_use]
pub fn label(&self) -> Option<String> {
self.inner.borrow().label.clone()
}
#[cfg(feature = "debug")]
#[doc(hidden)]
#[deprecated(note = "use `set_label` instead")]
pub fn set_debug_label(&self, label: impl Into<String>) {
self.set_label(label);
}
#[cfg(test)]
#[must_use]
pub fn task_count(&self) -> usize {
self.inner.borrow().task_ids.len()
}
#[cfg(test)]
#[must_use]
pub fn child_count(&self) -> usize {
self.inner.borrow().children.len()
}
fn clone_inner(&self) -> Self {
Self {
inner: Rc::clone(&self.inner),
cancelled: Rc::clone(&self.cancelled),
suspended: Rc::clone(&self.suspended),
}
}
pub fn enter<R>(&self, f: impl FnOnce() -> R) -> R {
with_current_scope(self, f)
}
pub fn suspend(&self) {
if self.suspended.get() {
return;
}
self.suspended.set(true);
let children: Vec<TaskScope> = {
self.inner
.borrow()
.children
.iter()
.map(TaskScope::clone_inner)
.collect()
};
for child in &children {
child.suspend();
}
}
pub fn resume(&self) {
if !self.suspended.get() {
return;
}
self.suspended.set(false);
let (task_ids, children) = {
let inner = self.inner.borrow();
let tids = inner.task_ids.clone();
let children: Vec<TaskScope> =
inner.children.iter().map(TaskScope::clone_inner).collect();
(tids, children)
};
let ex = Rc::clone(&self.inner.borrow().executor);
executor::enqueue_scope_tasks_on(&ex, &task_ids);
for child in &children {
child.resume();
}
}
#[must_use]
pub fn is_suspended(&self) -> bool {
self.suspended.get()
}
}
impl Default for TaskScope {
fn default() -> Self {
Self::new()
}
}
impl Clone for TaskScope {
fn clone(&self) -> Self {
self.clone_inner()
}
}
impl Drop for TaskScope {
fn drop(&mut self) {
if Rc::strong_count(&self.inner) > 1 {
return;
}
self.cancelled.set(true);
let Ok(mut inner) = self.inner.try_borrow_mut() else {
eprintln!(
"[auralis-task] WARNING: TaskScope::drop cannot borrow inner \
(already borrowed). Tasks and callbacks in this scope will \
be cleaned up on the next executor flush. Avoid dropping \
the last TaskScope clone inside a callback."
);
return;
};
inner.callbacks.borrow_mut().clear();
let mut descendants: Vec<Rc<RefCell<TaskScopeInner>>> = Vec::new();
{
let mut queue: VecDeque<Rc<RefCell<TaskScopeInner>>> = VecDeque::new();
for child in &inner.children {
queue.push_back(Rc::clone(&child.inner));
}
while let Some(scope_rc) = queue.pop_front() {
let scope = scope_rc.borrow();
for child in &scope.children {
queue.push_back(Rc::clone(&child.inner));
}
descendants.push(Rc::clone(&scope_rc));
}
}
for scope_rc in descendants.iter().rev() {
let mut scope = scope_rc.borrow_mut();
if scope.cancelled.get() {
continue;
}
scope.cancelled.set(true);
scope.callbacks.borrow_mut().clear();
if !scope.task_ids.is_empty() {
let ex = Rc::clone(&scope.executor);
let task_ids = std::mem::take(&mut scope.task_ids);
let dropped_futures = executor::cancel_scope_tasks_on(&ex, &task_ids);
drop(dropped_futures);
}
scope.context.borrow_mut().clear();
unregister_scope(scope.id);
}
if !inner.task_ids.is_empty() {
let ex = Rc::clone(&inner.executor);
let task_ids = std::mem::take(&mut inner.task_ids);
let dropped_futures = executor::cancel_scope_tasks_on(&ex, &task_ids);
drop(dropped_futures);
}
inner.context.borrow_mut().clear();
inner.children.clear();
unregister_scope(inner.id);
}
}
/// Shorthand for `$scope.provide($value)`. An optional trailing comma is
/// accepted.
#[macro_export]
macro_rules! provide_context {
    ($scope:expr, $value:expr $(,)?) => {
        $scope.provide($value)
    };
}
/// Shorthand for `$scope.consume::<$ty>()`. An optional trailing comma is
/// accepted.
#[macro_export]
macro_rules! consume_context {
    ($scope:expr, $ty:ty $(,)?) => {
        $scope.consume::<$ty>()
    };
}
/// Debug snapshot of one scope and its subtree, as built by `scope_tree`.
#[cfg(feature = "debug")]
#[derive(Debug, Clone, serde::Serialize)]
pub struct ScopeTreeNode {
    /// Scope id (its key in the scope registry).
    pub id: ScopeId,
    /// The scope's label, if one was set.
    pub label: Option<String>,
    /// Tasks the executor snapshot attributes to this scope, sorted by id.
    pub tasks: Vec<TaskNode>,
    /// Child scopes, sorted by id.
    pub children: Vec<ScopeTreeNode>,
}
/// Debug snapshot of one task, as built by `scope_tree`.
#[cfg(feature = "debug")]
#[derive(Debug, Clone, serde::Serialize)]
pub struct TaskNode {
    /// Task id.
    pub id: TaskId,
    /// `"H"` for `Priority::High`, `"L"` for `Priority::Low`.
    pub priority: &'static str,
    /// Whether the task id appeared in `executor::debug_queued_task_ids()`.
    pub queued: bool,
    /// Poll count from `executor::debug_task_timing()` (0 if absent).
    pub total_poll_count: u64,
    /// Last poll duration from `executor::debug_task_timing()` (0 if absent);
    /// presumably microseconds, per the field name.
    pub last_poll_duration_us: u64,
}
/// Removes node `id` from `scope_map` and recursively attaches its children
/// (in ascending id order) according to `child_map`. If `id` is no longer in
/// `scope_map`, an empty placeholder node is produced in its place.
#[cfg(feature = "debug")]
fn attach_children(
    id: u64,
    scope_map: &mut std::collections::HashMap<u64, ScopeTreeNode>,
    child_map: &std::collections::HashMap<u64, Vec<u64>>,
) -> ScopeTreeNode {
    let mut node = scope_map.remove(&id).unwrap_or_else(|| ScopeTreeNode {
        id,
        label: None,
        tasks: Vec::new(),
        children: Vec::new(),
    });
    let mut kids: Vec<u64> = child_map.get(&id).cloned().unwrap_or_default();
    kids.sort_unstable();
    node.children.extend(
        kids.into_iter()
            .map(|cid| attach_children(cid, scope_map, child_map)),
    );
    node
}
/// Builds a debug snapshot of every live registered scope on this thread as
/// a forest.
///
/// - Each node carries the tasks the executor attributes to that scope,
///   sorted by task id.
/// - Roots are scopes without a live parent.
/// - Roots and children are ordered by id.
///
/// Scopes whose state is mutably borrowed at the time of the call (e.g.
/// mid-teardown) are skipped rather than causing a panic.
#[cfg(feature = "debug")]
#[must_use]
pub fn scope_tree() -> Vec<ScopeTreeNode> {
    let task_snap = executor::debug_task_snapshot();
    let queued: std::collections::HashSet<u64> =
        executor::debug_queued_task_ids().into_iter().collect();
    let timing = executor::debug_task_timing();

    // Group the executor's task snapshot by owning scope.
    let mut tasks_by_scope: HashMap<u64, Vec<TaskNode>> = HashMap::new();
    for (tid, pri, sid) in &task_snap {
        let (poll_count, last_us) = timing.get(tid).copied().unwrap_or((0, 0));
        tasks_by_scope.entry(*sid).or_default().push(TaskNode {
            id: *tid,
            priority: match pri {
                Priority::High => "H",
                Priority::Low => "L",
            },
            queued: queued.contains(tid),
            total_poll_count: poll_count,
            last_poll_duration_us: last_us,
        });
    }

    // One pass over the registry builds the nodes and the parent links.
    let mut scope_map: HashMap<u64, ScopeTreeNode> = HashMap::new();
    let mut roots: Vec<u64> = Vec::new();
    let mut child_map: HashMap<u64, Vec<u64>> = HashMap::new();
    let _ = SCOPE_REGISTRY.try_with(|reg| {
        let Ok(registry) = reg.try_borrow() else {
            return;
        };
        for (&id, (inner_weak, _suspended_weak)) in registry.iter() {
            let Some(inner) = inner_weak.upgrade() else {
                continue;
            };
            let Ok(state) = inner.try_borrow() else {
                continue;
            };
            let mut tasks = tasks_by_scope.remove(&id).unwrap_or_default();
            tasks.sort_by_key(|t| t.id);
            scope_map.insert(
                id,
                ScopeTreeNode {
                    id,
                    label: state.label.clone(),
                    tasks,
                    children: Vec::new(),
                },
            );
            let parent_id = state
                .parent
                .as_ref()
                .and_then(Weak::upgrade)
                .and_then(|p| p.try_borrow().map(|pb| pb.id).ok());
            match parent_id {
                Some(pid) => child_map.entry(pid).or_default().push(id),
                None => roots.push(id),
            }
        }
    });

    roots.sort_unstable();
    let mut tree: Vec<ScopeTreeNode> = roots
        .into_iter()
        .map(|rid| attach_children(rid, &mut scope_map, &child_map))
        .collect();

    // Nodes not reached from a root have a live parent that is not in the
    // registry (e.g. a parent unregistered by an ancestor's teardown while a
    // handle to it is still held). Attach those subtrees in id order.
    // Skip ids already consumed by an earlier subtree: otherwise they would
    // come back as duplicate, empty placeholder nodes.
    let mut remaining: Vec<u64> = scope_map.keys().copied().collect();
    remaining.sort_unstable();
    for id in remaining {
        if scope_map.contains_key(&id) {
            tree.push(attach_children(id, &mut scope_map, &child_map));
        }
    }
    tree
}
// Unit tests for this module are kept in the sibling file `scope_tests.rs`.
#[cfg(test)]
#[allow(clippy::items_after_statements)]
#[path = "scope_tests.rs"]
mod tests;