Skip to main content

aura_agent/
task_registry.rs

1//! Structured task supervision for agent background work.
2//!
3//! This module provides a root supervisor plus named task groups. Tasks are
4//! owned by a group, inherit cancellation from their ancestors, and must exit
5//! before the group is considered drained.
6
7#![allow(clippy::disallowed_types)]
8
9use std::collections::BTreeMap;
10use std::future::Future;
11use std::panic::AssertUnwindSafe;
12use std::sync::atomic::{AtomicU64, Ordering};
13use std::sync::Arc;
14use std::time::Duration;
15
16use crate::runtime::{
17    RuntimeDiagnostic, RuntimeDiagnosticKind, RuntimeDiagnosticSeverity, RuntimeDiagnosticSink,
18};
19use aura_core::effects::task::{CancellationToken, TaskSpawner};
20use aura_core::effects::PhysicalTimeEffects;
21use aura_core::{
22    execute_with_timeout_budget, OwnedShutdownToken, OwnedTaskHandle, TimeoutBudget,
23    TimeoutRunError,
24};
25use aura_effects::time::PhysicalTimeHandler;
26use futures::future::{BoxFuture, LocalBoxFuture};
27use futures::FutureExt;
28#[cfg(not(target_arch = "wasm32"))]
29use parking_lot::Mutex;
30#[cfg(target_arch = "wasm32")]
31use parking_lot::Mutex;
32use tokio::sync::watch;
33use tokio::sync::Notify;
34#[cfg(not(target_arch = "wasm32"))]
35use tokio::task::JoinHandle;
36#[cfg(target_arch = "wasm32")]
37use wasm_bindgen_futures::spawn_local;
38
/// Task name used when work arrives through the generic `TaskSpawner` interface,
/// which does not carry a caller-supplied name.
const DEFAULT_TASK_NAME: &str = "task.default";

/// Failure reported by task-group supervision operations.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TaskSupervisionError {
    /// The group still had registered tasks when the wait deadline expired.
    Timeout {
        /// Fully qualified name of the group that was waited on.
        group: String,
        /// Names of the tasks still registered when the wait gave up.
        active_tasks: Vec<String>,
    },
    /// Tasks were forcibly removed from the group instead of exiting on their own.
    ForcedAbort {
        /// Fully qualified name of the group that was aborted.
        group: String,
        /// Names of the tasks that were removed.
        aborted_tasks: Vec<String>,
    },
    /// Task `task` in `group` was cancelled.
    Cancelled {
        /// Fully qualified group name.
        group: String,
        /// Task name.
        task: String,
    },
    /// Task `task` in `group` panicked.
    Panicked {
        /// Fully qualified group name.
        group: String,
        /// Task name.
        task: String,
    },
}

impl std::fmt::Display for TaskSupervisionError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Timeout {
                group,
                active_tasks,
            } => write!(
                f,
                "task group '{group}' timed out waiting for tasks: {}",
                active_tasks.join(", ")
            ),
            Self::ForcedAbort {
                group,
                aborted_tasks,
            } => write!(
                f,
                "task group '{group}' force-aborted tasks: {}",
                aborted_tasks.join(", ")
            ),
            Self::Cancelled { group, task } => {
                write!(f, "task '{task}' in group '{group}' was cancelled")
            }
            Self::Panicked { group, task } => {
                write!(f, "task '{task}' in group '{group}' panicked")
            }
        }
    }
}

impl std::error::Error for TaskSupervisionError {}

/// How a supervised task finished; drives the logging and diagnostics emitted for it.
#[derive(Debug, Clone, PartialEq, Eq)]
enum TaskOutcome {
    /// The wrapped future ran to completion.
    Completed,
    /// A cancellation source (group shutdown, inherited token or external token) won the race.
    Cancelled,
    /// The wrapped future panicked and the panic was caught by the supervisor.
    Panicked,
}
98
99#[derive(Debug)]
100struct TaskMetadata {
101    task_name: String,
102    #[cfg(not(target_arch = "wasm32"))]
103    handle: Option<JoinHandle<()>>,
104}
105
106struct TaskGroupShared {
107    name: String,
108    next_task_id: AtomicU64,
109    shutdown_tx: watch::Sender<bool>,
110    inherited_cancellation: Option<Arc<dyn CancellationToken>>,
111    diagnostics: Option<Arc<RuntimeDiagnosticSink>>,
112    tasks: Mutex<BTreeMap<u64, TaskMetadata>>,
113    notify: Arc<Notify>,
114}
115
116#[derive(Clone)]
117pub struct TaskGroup {
118    shared: Arc<TaskGroupShared>,
119}
120
121#[derive(Clone)]
122pub struct TaskSupervisor {
123    root: TaskGroup,
124}
125
126impl TaskSupervisor {
127    pub fn new() -> Self {
128        Self {
129            root: TaskGroup::root("runtime", None),
130        }
131    }
132
133    pub fn with_diagnostics(diagnostics: Arc<RuntimeDiagnosticSink>) -> Self {
134        Self {
135            root: TaskGroup::root("runtime", Some(diagnostics)),
136        }
137    }
138
139    pub fn group(&self, name: impl Into<String>) -> TaskGroup {
140        self.root.group(name)
141    }
142
143    #[must_use = "retain or explicitly discard the owned task handle"]
144    pub fn spawn_named<F>(&self, name: impl Into<String>, fut: F) -> OwnedTaskHandle<u64>
145    where
146        F: Future<Output = ()> + Send + 'static,
147    {
148        self.root.spawn_named(name, fut)
149    }
150
151    #[must_use = "retain or explicitly discard the owned task handle"]
152    pub fn spawn_cancellable_named<F>(
153        &self,
154        name: impl Into<String>,
155        fut: F,
156    ) -> OwnedTaskHandle<u64>
157    where
158        F: Future<Output = ()> + Send + 'static,
159    {
160        self.root.spawn_cancellable_named(name, fut)
161    }
162
163    #[must_use = "retain or explicitly discard the owned task handle"]
164    pub fn spawn_local_named<F>(&self, name: impl Into<String>, fut: F) -> OwnedTaskHandle<u64>
165    where
166        F: Future<Output = ()> + 'static,
167    {
168        self.root.spawn_local_named(name, fut)
169    }
170
171    #[must_use = "retain or explicitly discard the owned task handle"]
172    pub fn spawn_local_cancellable_named<F>(
173        &self,
174        name: impl Into<String>,
175        fut: F,
176    ) -> OwnedTaskHandle<u64>
177    where
178        F: Future<Output = ()> + 'static,
179    {
180        self.root.spawn_local_cancellable_named(name, fut)
181    }
182
183    #[must_use = "retain or explicitly discard the owned task handle"]
184    pub fn spawn_interval_until_named<F, Fut>(
185        &self,
186        name: impl Into<String>,
187        time_effects: Arc<dyn PhysicalTimeEffects + Send + Sync>,
188        interval: Duration,
189        f: F,
190    ) -> OwnedTaskHandle<u64>
191    where
192        F: FnMut() -> Fut + Send + 'static,
193        Fut: Future<Output = bool> + Send + 'static,
194    {
195        self.root
196            .spawn_interval_until_named(name, time_effects, interval, f)
197    }
198
199    #[must_use = "retain or explicitly discard the owned task handle"]
200    pub fn spawn_local_interval_until_named<F, Fut>(
201        &self,
202        name: impl Into<String>,
203        time_effects: Arc<dyn PhysicalTimeEffects + Send + Sync>,
204        interval: Duration,
205        f: F,
206    ) -> OwnedTaskHandle<u64>
207    where
208        F: FnMut() -> Fut + 'static,
209        Fut: Future<Output = bool> + 'static,
210    {
211        self.root
212            .spawn_local_interval_until_named(name, time_effects, interval, f)
213    }
214
215    #[must_use = "retain or explicitly discard the owned task handle"]
216    pub fn spawn_child<F>(&self, name: impl Into<String>, fut: F) -> OwnedTaskHandle<u64>
217    where
218        F: Future<Output = ()> + Send + 'static,
219    {
220        self.spawn_named(name, fut)
221    }
222
223    #[must_use = "retain or explicitly discard the owned task handle"]
224    pub fn spawn_periodic<F, Fut>(
225        &self,
226        name: impl Into<String>,
227        time_effects: Arc<dyn PhysicalTimeEffects + Send + Sync>,
228        interval: Duration,
229        f: F,
230    ) -> OwnedTaskHandle<u64>
231    where
232        F: FnMut() -> Fut + Send + 'static,
233        Fut: Future<Output = bool> + Send + 'static,
234    {
235        self.spawn_interval_until_named(name, time_effects, interval, f)
236    }
237
238    pub fn request_cancellation(&self) {
239        self.root.request_cancellation();
240    }
241
242    pub async fn wait_for_idle(&self, timeout: Duration) -> Result<(), TaskSupervisionError> {
243        self.root.wait_for_idle(timeout).await
244    }
245
246    pub fn force_abort_remaining(&self) -> Result<(), TaskSupervisionError> {
247        self.root.force_abort_remaining()
248    }
249
250    pub fn abort_remaining(&self) -> Result<(), TaskSupervisionError> {
251        self.force_abort_remaining()
252    }
253
254    pub async fn shutdown_with_timeout(
255        &self,
256        timeout: Duration,
257    ) -> Result<(), TaskSupervisionError> {
258        self.root.shutdown_with_timeout(timeout).await
259    }
260
261    pub async fn shutdown_gracefully(&self, timeout: Duration) -> Result<(), TaskSupervisionError> {
262        self.shutdown_with_timeout(timeout).await
263    }
264
265    pub fn shutdown(&self) {
266        self.root.shutdown();
267    }
268
269    pub fn cancellation_token(&self) -> Arc<dyn CancellationToken> {
270        self.root.cancellation_token()
271    }
272
273    pub fn active_tasks(&self) -> Vec<String> {
274        self.root.active_tasks()
275    }
276}
277
278impl Default for TaskSupervisor {
279    fn default() -> Self {
280        Self::new()
281    }
282}
283
284impl Drop for TaskSupervisor {
285    fn drop(&mut self) {
286        self.shutdown();
287    }
288}
289
290impl TaskGroup {
291    fn root(name: impl Into<String>, diagnostics: Option<Arc<RuntimeDiagnosticSink>>) -> Self {
292        let (shutdown_tx, _shutdown_rx) = watch::channel(false);
293        Self {
294            shared: Arc::new(TaskGroupShared {
295                name: name.into(),
296                next_task_id: AtomicU64::new(1),
297                shutdown_tx,
298                inherited_cancellation: None,
299                diagnostics,
300                tasks: Mutex::new(BTreeMap::new()),
301                notify: Arc::new(Notify::new()),
302            }),
303        }
304    }
305
306    pub fn name(&self) -> &str {
307        &self.shared.name
308    }
309
310    pub fn group(&self, name: impl Into<String>) -> TaskGroup {
311        let name = name.into();
312        let full_name = format!("{}.{}", self.shared.name, name);
313        let (shutdown_tx, _shutdown_rx) = watch::channel(false);
314        TaskGroup {
315            shared: Arc::new(TaskGroupShared {
316                name: full_name,
317                next_task_id: AtomicU64::new(1),
318                shutdown_tx,
319                inherited_cancellation: Some(self.cancellation_token()),
320                diagnostics: self.shared.diagnostics.clone(),
321                tasks: Mutex::new(BTreeMap::new()),
322                notify: Arc::new(Notify::new()),
323            }),
324        }
325    }
326
327    #[must_use = "retain or explicitly discard the owned task handle"]
328    pub fn spawn_named<F>(&self, name: impl Into<String>, fut: F) -> OwnedTaskHandle<u64>
329    where
330        F: Future<Output = ()> + Send + 'static,
331    {
332        self.spawn_boxed(name.into(), Box::pin(fut), None)
333    }
334
335    #[must_use = "retain or explicitly discard the owned task handle"]
336    pub fn spawn_cancellable_named<F>(
337        &self,
338        name: impl Into<String>,
339        fut: F,
340    ) -> OwnedTaskHandle<u64>
341    where
342        F: Future<Output = ()> + Send + 'static,
343    {
344        self.spawn_boxed(name.into(), Box::pin(fut), None)
345    }
346
347    #[must_use = "retain or explicitly discard the owned task handle"]
348    pub fn spawn_local_named<F>(&self, name: impl Into<String>, fut: F) -> OwnedTaskHandle<u64>
349    where
350        F: Future<Output = ()> + 'static,
351    {
352        self.spawn_boxed_local(name.into(), Box::pin(fut), None)
353    }
354
355    #[must_use = "retain or explicitly discard the owned task handle"]
356    pub fn spawn_local_cancellable_named<F>(
357        &self,
358        name: impl Into<String>,
359        fut: F,
360    ) -> OwnedTaskHandle<u64>
361    where
362        F: Future<Output = ()> + 'static,
363    {
364        self.spawn_boxed_local(name.into(), Box::pin(fut), None)
365    }
366
367    #[must_use = "retain or explicitly discard the owned task handle"]
368    pub fn spawn_with_token<F>(
369        &self,
370        name: impl Into<String>,
371        fut: F,
372        token: Arc<dyn CancellationToken>,
373    ) -> OwnedTaskHandle<u64>
374    where
375        F: Future<Output = ()> + Send + 'static,
376    {
377        self.spawn_boxed(name.into(), Box::pin(fut), Some(token))
378    }
379
380    #[must_use = "retain or explicitly discard the owned task handle"]
381    pub fn spawn_child<F>(&self, name: impl Into<String>, fut: F) -> OwnedTaskHandle<u64>
382    where
383        F: Future<Output = ()> + Send + 'static,
384    {
385        self.spawn_named(name, fut)
386    }
387
388    #[must_use = "retain or explicitly discard the owned task handle"]
389    pub fn spawn_interval_until_named<F, Fut>(
390        &self,
391        name: impl Into<String>,
392        time_effects: Arc<dyn PhysicalTimeEffects + Send + Sync>,
393        interval: Duration,
394        mut f: F,
395    ) -> OwnedTaskHandle<u64>
396    where
397        F: FnMut() -> Fut + Send + 'static,
398        Fut: Future<Output = bool> + Send + 'static,
399    {
400        let interval_ms = interval.as_millis().try_into().unwrap_or(u64::MAX);
401        self.spawn_boxed(
402            name.into(),
403            Box::pin(async move {
404                loop {
405                    if !f().await {
406                        break;
407                    }
408
409                    if time_effects.sleep_ms(interval_ms).await.is_err() {
410                        break;
411                    }
412                }
413            }),
414            None,
415        )
416    }
417
418    #[must_use = "retain or explicitly discard the owned task handle"]
419    pub fn spawn_local_interval_until_named<F, Fut>(
420        &self,
421        name: impl Into<String>,
422        time_effects: Arc<dyn PhysicalTimeEffects + Send + Sync>,
423        interval: Duration,
424        mut f: F,
425    ) -> OwnedTaskHandle<u64>
426    where
427        F: FnMut() -> Fut + 'static,
428        Fut: Future<Output = bool> + 'static,
429    {
430        let interval_ms = interval.as_millis().try_into().unwrap_or(u64::MAX);
431        self.spawn_boxed_local(
432            name.into(),
433            Box::pin(async move {
434                loop {
435                    if !f().await {
436                        break;
437                    }
438
439                    if time_effects.sleep_ms(interval_ms).await.is_err() {
440                        break;
441                    }
442                }
443            }),
444            None,
445        )
446    }
447
448    #[must_use = "retain or explicitly discard the owned task handle"]
449    pub fn spawn_periodic<F, Fut>(
450        &self,
451        name: impl Into<String>,
452        time_effects: Arc<dyn PhysicalTimeEffects + Send + Sync>,
453        interval: Duration,
454        f: F,
455    ) -> OwnedTaskHandle<u64>
456    where
457        F: FnMut() -> Fut + Send + 'static,
458        Fut: Future<Output = bool> + Send + 'static,
459    {
460        self.spawn_interval_until_named(name, time_effects, interval, f)
461    }
462
463    pub fn request_cancellation(&self) {
464        let _ = self.shared.shutdown_tx.send(true);
465        tracing::debug!(
466            event = "runtime.task_group.cancel_requested",
467            task_group = %self.shared.name,
468            active_tasks = self.active_tasks().len(),
469            "Task group cancellation requested"
470        );
471        self.shared.notify.notify_waiters();
472    }
473
474    pub async fn wait_for_idle(&self, timeout: Duration) -> Result<(), TaskSupervisionError> {
475        let group_name = self.shared.name.clone();
476        let time = PhysicalTimeHandler::new();
477        let started_at = time
478            .physical_time()
479            .await
480            .map_err(|_| TaskSupervisionError::Timeout {
481                group: group_name.clone(),
482                active_tasks: self.active_tasks(),
483            })?;
484        let budget = TimeoutBudget::from_start_and_timeout(&started_at, timeout).map_err(|_| {
485            TaskSupervisionError::Timeout {
486                group: group_name.clone(),
487                active_tasks: self.active_tasks(),
488            }
489        })?;
490        let result = execute_with_timeout_budget(&time, &budget, || async {
491            loop {
492                if self.shared.tasks.lock().is_empty() {
493                    return Ok::<(), ()>(());
494                }
495                self.shared.notify.notified().await;
496            }
497        })
498        .await;
499
500        match result {
501            Ok(()) => Ok(()),
502            Err(TimeoutRunError::Timeout(_)) | Err(TimeoutRunError::Operation(_)) => {
503                Err(TaskSupervisionError::Timeout {
504                    group: group_name,
505                    active_tasks: self.active_tasks(),
506                })
507            }
508        }
509    }
510
511    pub fn force_abort_remaining(&self) -> Result<(), TaskSupervisionError> {
512        let mut tasks = self.shared.tasks.lock();
513        if tasks.is_empty() {
514            return Ok(());
515        }
516
517        let mut aborted_tasks = Vec::with_capacity(tasks.len());
518        #[cfg(not(target_arch = "wasm32"))]
519        for (_, entry) in tasks.iter() {
520            if let Some(handle) = &entry.handle {
521                handle.abort();
522            }
523            aborted_tasks.push(entry.task_name.clone());
524            emit_task_diagnostic(
525                self.shared.diagnostics.as_ref(),
526                RuntimeDiagnosticSeverity::Warn,
527                "task_supervisor",
528                format!(
529                    "force-aborted supervised task '{}' in group '{}'",
530                    entry.task_name, self.shared.name
531                ),
532            );
533            tracing::warn!(
534                event = "runtime.task.abort_forced",
535                task_group = %self.shared.name,
536                task_name = %entry.task_name,
537                "Force-aborted supervised task"
538            );
539        }
540
541        #[cfg(target_arch = "wasm32")]
542        for (_, entry) in tasks.iter() {
543            aborted_tasks.push(entry.task_name.clone());
544        }
545
546        tasks.clear();
547        self.shared.notify.notify_waiters();
548
549        Err(TaskSupervisionError::ForcedAbort {
550            group: self.shared.name.clone(),
551            aborted_tasks,
552        })
553    }
554
555    pub fn abort_remaining(&self) -> Result<(), TaskSupervisionError> {
556        self.force_abort_remaining()
557    }
558
559    pub async fn shutdown_with_timeout(
560        &self,
561        timeout: Duration,
562    ) -> Result<(), TaskSupervisionError> {
563        self.request_cancellation();
564        match self.wait_for_idle(timeout).await {
565            Ok(()) => Ok(()),
566            Err(TaskSupervisionError::Timeout { .. }) => self.force_abort_remaining(),
567            Err(other) => Err(other),
568        }
569    }
570
571    pub async fn shutdown_gracefully(&self, timeout: Duration) -> Result<(), TaskSupervisionError> {
572        self.shutdown_with_timeout(timeout).await
573    }
574
575    pub fn shutdown(&self) {
576        self.request_cancellation();
577        let _ = self.force_abort_remaining();
578    }
579
580    pub fn cancellation_token(&self) -> Arc<dyn CancellationToken> {
581        Arc::new(TaskGroupCancellationToken {
582            shutdown_rx: self.shared.shutdown_tx.subscribe(),
583            inherited: self.shared.inherited_cancellation.clone(),
584        })
585    }
586
587    pub fn active_tasks(&self) -> Vec<String> {
588        self.shared
589            .tasks
590            .lock()
591            .values()
592            .map(|task| task.task_name.clone())
593            .collect()
594    }
595
596    fn register_task(&self, task_id: u64, task_name: String) {
597        self.shared.tasks.lock().insert(
598            task_id,
599            TaskMetadata {
600                task_name,
601                #[cfg(not(target_arch = "wasm32"))]
602                handle: None,
603            },
604        );
605    }
606
607    #[cfg(not(target_arch = "wasm32"))]
608    fn attach_native_handle(&self, task_id: u64, handle: JoinHandle<()>) {
609        if let Some(metadata) = self.shared.tasks.lock().get_mut(&task_id) {
610            metadata.handle = Some(handle);
611        }
612    }
613
614    fn complete_task(&self, task_id: u64, task_name: &str, outcome: TaskOutcome) {
615        let removed = self.shared.tasks.lock().remove(&task_id);
616        if removed.is_none() {
617            return;
618        }
619
620        if matches!(outcome, TaskOutcome::Cancelled | TaskOutcome::Panicked) {
621            tracing::warn!(
622                event = "runtime.task.exit_non_success",
623                task_group = %self.shared.name,
624                task_name = %task_name,
625                outcome = ?outcome,
626                "Supervised task exited abnormally"
627            );
628        }
629
630        self.shared.notify.notify_waiters();
631    }
632
633    fn spawn_boxed(
634        &self,
635        task_name: String,
636        fut: BoxFuture<'static, ()>,
637        external_token: Option<Arc<dyn CancellationToken>>,
638    ) -> OwnedTaskHandle<u64> {
639        let task_id = self.shared.next_task_id.fetch_add(1, Ordering::Relaxed);
640        self.register_task(task_id, task_name.clone());
641        let group_name = self.shared.name.clone();
642        let mut shutdown_rx = self.shared.shutdown_tx.subscribe();
643        let inherited = self.shared.inherited_cancellation.clone();
644        let diagnostics = self.shared.diagnostics.clone();
645        let task_name_for_wrapper = task_name.clone();
646        let group = self.clone();
647
648        tracing::debug!(
649            event = "runtime.task.spawned",
650            task_group = %group_name,
651            task_name = %task_name,
652            task_id,
653            "Spawned supervised task"
654        );
655
656        #[cfg(not(target_arch = "wasm32"))]
657        let handle = tokio::spawn(async move {
658            let outcome = AssertUnwindSafe(async {
659                tokio::select! {
660                    _ = shutdown_cancelled(&mut shutdown_rx) => TaskOutcome::Cancelled,
661                    _ = inherited_cancelled(inherited.as_ref()) => TaskOutcome::Cancelled,
662                    _ = external_cancelled(external_token.as_deref()) => TaskOutcome::Cancelled,
663                    _ = fut => TaskOutcome::Completed,
664                }
665            })
666            .catch_unwind()
667            .await
668            .unwrap_or(TaskOutcome::Panicked);
669
670            emit_task_completion(
671                diagnostics.as_ref(),
672                &group_name,
673                &task_name_for_wrapper,
674                task_id,
675                &outcome,
676            );
677            group.complete_task(task_id, &task_name_for_wrapper, outcome);
678        });
679
680        #[cfg(not(target_arch = "wasm32"))]
681        self.attach_native_handle(task_id, handle);
682
683        #[cfg(target_arch = "wasm32")]
684        {
685            spawn_local(async move {
686                let outcome = AssertUnwindSafe(async {
687                    tokio::select! {
688                        _ = shutdown_cancelled(&mut shutdown_rx) => TaskOutcome::Cancelled,
689                        _ = inherited_cancelled(inherited.as_ref()) => TaskOutcome::Cancelled,
690                        _ = external_cancelled(external_token.as_deref()) => TaskOutcome::Cancelled,
691                        _ = fut => TaskOutcome::Completed,
692                    }
693                })
694                .catch_unwind()
695                .await
696                .unwrap_or(TaskOutcome::Panicked);
697
698                emit_task_completion(
699                    diagnostics.as_ref(),
700                    &group_name,
701                    &task_name_for_wrapper,
702                    task_id,
703                    &outcome,
704                );
705                group.complete_task(task_id, &task_name_for_wrapper, outcome);
706            });
707        }
708
709        OwnedTaskHandle::new(
710            task_id,
711            OwnedShutdownToken::attached(self.cancellation_token()),
712        )
713    }
714
715    fn spawn_boxed_local(
716        &self,
717        task_name: String,
718        fut: LocalBoxFuture<'static, ()>,
719        external_token: Option<Arc<dyn CancellationToken>>,
720    ) -> OwnedTaskHandle<u64> {
721        let task_id = self.shared.next_task_id.fetch_add(1, Ordering::Relaxed);
722        self.register_task(task_id, task_name.clone());
723        let mut shutdown_rx = self.shared.shutdown_tx.subscribe();
724        let inherited = self.shared.inherited_cancellation.clone();
725        let group_name = self.shared.name.clone();
726        let diagnostics = self.shared.diagnostics.clone();
727        let task_name_for_wrapper = task_name.clone();
728        let group = self.clone();
729
730        #[cfg(not(target_arch = "wasm32"))]
731        let handle = tokio::task::spawn_local(async move {
732            let outcome = AssertUnwindSafe(async {
733                tokio::select! {
734                    _ = shutdown_cancelled(&mut shutdown_rx) => TaskOutcome::Cancelled,
735                    _ = inherited_cancelled(inherited.as_ref()) => TaskOutcome::Cancelled,
736                    _ = external_cancelled(external_token.as_deref()) => TaskOutcome::Cancelled,
737                    _ = fut => TaskOutcome::Completed,
738                }
739            })
740            .catch_unwind()
741            .await
742            .unwrap_or(TaskOutcome::Panicked);
743
744            emit_task_completion(
745                diagnostics.as_ref(),
746                &group_name,
747                &task_name_for_wrapper,
748                task_id,
749                &outcome,
750            );
751            group.complete_task(task_id, &task_name_for_wrapper, outcome);
752        });
753
754        #[cfg(not(target_arch = "wasm32"))]
755        self.attach_native_handle(task_id, handle);
756
757        #[cfg(target_arch = "wasm32")]
758        {
759            spawn_local(async move {
760                let outcome = AssertUnwindSafe(async {
761                    tokio::select! {
762                        _ = shutdown_cancelled(&mut shutdown_rx) => TaskOutcome::Cancelled,
763                        _ = inherited_cancelled(inherited.as_ref()) => TaskOutcome::Cancelled,
764                        _ = external_cancelled(external_token.as_deref()) => TaskOutcome::Cancelled,
765                        _ = fut => TaskOutcome::Completed,
766                    }
767                })
768                .catch_unwind()
769                .await
770                .unwrap_or(TaskOutcome::Panicked);
771
772                emit_task_completion(
773                    diagnostics.as_ref(),
774                    &group_name,
775                    &task_name_for_wrapper,
776                    task_id,
777                    &outcome,
778                );
779                group.complete_task(task_id, &task_name_for_wrapper, outcome);
780            });
781        }
782
783        OwnedTaskHandle::new(
784            task_id,
785            OwnedShutdownToken::attached(self.cancellation_token()),
786        )
787    }
788}
789
790struct TaskGroupCancellationToken {
791    shutdown_rx: watch::Receiver<bool>,
792    inherited: Option<Arc<dyn CancellationToken>>,
793}
794
795#[async_trait::async_trait]
796impl CancellationToken for TaskGroupCancellationToken {
797    async fn cancelled(&self) {
798        if self.is_cancelled() {
799            return;
800        }
801
802        let mut shutdown_rx = self.shutdown_rx.clone();
803        match self.inherited.clone() {
804            Some(inherited) => {
805                tokio::select! {
806                    _ = shutdown_cancelled(&mut shutdown_rx) => {}
807                    _ = inherited.cancelled() => {}
808                }
809            }
810            None => {
811                shutdown_cancelled(&mut shutdown_rx).await;
812            }
813        }
814    }
815
816    fn is_cancelled(&self) -> bool {
817        *self.shutdown_rx.borrow()
818            || self
819                .inherited
820                .as_ref()
821                .map(|token| token.is_cancelled())
822                .unwrap_or(false)
823    }
824}
825
826impl TaskSpawner for TaskSupervisor {
827    fn spawn(&self, fut: BoxFuture<'static, ()>) {
828        let _ = self.spawn_named(DEFAULT_TASK_NAME, fut);
829    }
830
831    fn spawn_cancellable(&self, fut: BoxFuture<'static, ()>, token: Arc<dyn CancellationToken>) {
832        let _ = self
833            .root
834            .spawn_boxed(DEFAULT_TASK_NAME.to_string(), fut, Some(token));
835    }
836
837    fn spawn_local(&self, fut: LocalBoxFuture<'static, ()>) {
838        let _ = self
839            .root
840            .spawn_boxed_local(DEFAULT_TASK_NAME.to_string(), fut, None);
841    }
842
843    fn spawn_local_cancellable(
844        &self,
845        fut: LocalBoxFuture<'static, ()>,
846        token: Arc<dyn CancellationToken>,
847    ) {
848        let _ = self
849            .root
850            .spawn_boxed_local(DEFAULT_TASK_NAME.to_string(), fut, Some(token));
851    }
852
853    fn cancellation_token(&self) -> Arc<dyn CancellationToken> {
854        self.cancellation_token()
855    }
856}
857
858fn emit_task_completion(
859    diagnostics: Option<&Arc<RuntimeDiagnosticSink>>,
860    group: &str,
861    task_name: &str,
862    task_id: u64,
863    outcome: &TaskOutcome,
864) {
865    match outcome {
866        TaskOutcome::Completed => tracing::debug!(
867            event = "runtime.task.completed",
868            task_group = %group,
869            task_name = %task_name,
870            task_id,
871            "Supervised task completed"
872        ),
873        TaskOutcome::Cancelled => tracing::info!(
874            event = "runtime.task.cancelled",
875            task_group = %group,
876            task_name = %task_name,
877            task_id,
878            "Supervised task cancelled"
879        ),
880        TaskOutcome::Panicked => tracing::error!(
881            event = "runtime.task.panicked",
882            task_group = %group,
883            task_name = %task_name,
884            task_id,
885            "Supervised task panicked"
886        ),
887    }
888
889    if matches!(outcome, TaskOutcome::Panicked) {
890        emit_task_diagnostic(
891            diagnostics,
892            RuntimeDiagnosticSeverity::Error,
893            "task_supervisor",
894            format!("supervised task '{task_name}' in group '{group}' panicked"),
895        );
896    }
897}
898
899fn emit_task_diagnostic(
900    diagnostics: Option<&Arc<RuntimeDiagnosticSink>>,
901    severity: RuntimeDiagnosticSeverity,
902    component: &'static str,
903    message: String,
904) {
905    if let Some(diagnostics) = diagnostics {
906        diagnostics.emit(RuntimeDiagnostic {
907            severity,
908            kind: RuntimeDiagnosticKind::SupervisedTaskFailed,
909            component,
910            message,
911        });
912    }
913}
914
915async fn shutdown_cancelled(shutdown_rx: &mut watch::Receiver<bool>) {
916    loop {
917        if *shutdown_rx.borrow() {
918            return;
919        }
920        if shutdown_rx.changed().await.is_err() {
921            return;
922        }
923    }
924}
925
926async fn inherited_cancelled(token: Option<&Arc<dyn CancellationToken>>) {
927    match token {
928        Some(token) => token.cancelled().await,
929        None => futures::future::pending::<()>().await,
930    }
931}
932
933async fn external_cancelled(token: Option<&dyn CancellationToken>) {
934    match token {
935        Some(token) => token.cancelled().await,
936        None => futures::future::pending::<()>().await,
937    }
938}
939
#[cfg(test)]
mod tests {
    //! Tests for task supervision (shutdown, cancellation inheritance, idle
    //! timeouts, forced abort, diagnostics) plus two loom models of the
    //! shutdown protocols.

    // `RuntimeDiagnosticKind` and `RuntimeDiagnosticSeverity` already arrive
    // through this glob: the parent module imports them from `crate::runtime`.
    use super::*;
    use tokio::sync::oneshot;

    /// A never-finishing task must be cancelled by `shutdown_with_timeout`,
    /// leaving the supervisor with no active tasks.
    #[tokio::test]
    async fn shutdown_with_timeout_cancels_supervised_tasks() {
        let supervisor = TaskSupervisor::new();
        let (started_tx, started_rx) = oneshot::channel();

        let _task_handle = supervisor.spawn_named("test.pending", async move {
            let _ = started_tx.send(());
            futures::future::pending::<()>().await;
        });

        // Only shut down once the task is actually running.
        started_rx.await.expect("task should start");
        supervisor
            .shutdown_with_timeout(Duration::from_millis(50))
            .await
            .expect("shutdown should cancel pending task");
        assert!(supervisor.active_tasks().is_empty());
    }

    /// Requesting cancellation on the root supervisor must also stop tasks
    /// owned by a child group, so the child becomes idle within the timeout.
    #[tokio::test]
    async fn child_groups_inherit_parent_cancellation() {
        let supervisor = TaskSupervisor::new();
        let child = supervisor.group("child");
        let (started_tx, started_rx) = oneshot::channel();

        let _task_handle = child.spawn_named("test.pending", async move {
            let _ = started_tx.send(());
            futures::future::pending::<()>().await;
        });

        started_rx.await.expect("task should start");
        supervisor.request_cancellation();
        child
            .wait_for_idle(Duration::from_millis(50))
            .await
            .expect("child tasks should stop when parent is cancelled");
    }

    /// Without cancellation, `wait_for_idle` times out on a pending task;
    /// `force_abort_remaining` then reports the abort and clears the registry.
    #[tokio::test]
    async fn wait_for_idle_times_out_and_force_abort_reports_tasks() {
        let supervisor = TaskSupervisor::new();
        let (started_tx, started_rx) = oneshot::channel();

        let _task_handle = supervisor.spawn_named("test.pending", async move {
            let _ = started_tx.send(());
            futures::future::pending::<()>().await;
        });

        started_rx.await.expect("task should start");
        let timeout = supervisor.wait_for_idle(Duration::from_millis(10)).await;
        assert!(matches!(timeout, Err(TaskSupervisionError::Timeout { .. })));

        let abort = supervisor.force_abort_remaining();
        assert!(matches!(
            abort,
            Err(TaskSupervisionError::ForcedAbort { .. })
        ));
        assert!(supervisor.active_tasks().is_empty());
    }

    /// A forced abort must publish a `SupervisedTaskFailed` diagnostic at
    /// `Warn` severity to the attached sink.
    #[tokio::test]
    async fn force_abort_emits_runtime_diagnostic() {
        let diagnostics = Arc::new(RuntimeDiagnosticSink::new());
        let supervisor = TaskSupervisor::with_diagnostics(diagnostics.clone());
        let (started_tx, started_rx) = oneshot::channel();

        let _task_handle = supervisor.spawn_named("test.pending", async move {
            let _ = started_tx.send(());
            futures::future::pending::<()>().await;
        });

        started_rx.await.expect("task should start");
        let mut rx = diagnostics.subscribe();
        let abort = supervisor.force_abort_remaining();
        assert!(matches!(
            abort,
            Err(TaskSupervisionError::ForcedAbort { .. })
        ));

        let diagnostic = rx.try_recv().expect("diagnostic emitted");
        assert_eq!(diagnostic.kind, RuntimeDiagnosticKind::SupervisedTaskFailed);
        assert_eq!(diagnostic.severity, RuntimeDiagnosticSeverity::Warn);
    }

    /// Loom model of the registration/shutdown race. The registrant adds an
    /// entry and then re-checks the cancellation flag. Shutdown sets the flag
    /// and then clears entries. In every interleaving the entry list must end
    /// up empty.
    ///
    /// NOTE(review): this models the protocol with plain loom primitives; it
    /// does not drive `TaskSupervisor` itself.
    #[test]
    fn loom_shutdown_race_does_not_leave_task_registered() {
        loom::model(|| {
            use loom::sync::atomic::{AtomicBool, Ordering};
            use loom::sync::{Arc as LoomArc, Mutex as LoomMutex};
            use loom::thread;

            let active = LoomArc::new(LoomMutex::new(Vec::<u8>::new()));
            let cancelled = LoomArc::new(AtomicBool::new(false));

            let register_active = LoomArc::clone(&active);
            let register_cancelled = LoomArc::clone(&cancelled);
            let register = thread::spawn(move || {
                {
                    let mut tasks = register_active.lock().unwrap();
                    tasks.push(1);
                }
                // If shutdown's clean-up ran before our push, its flag store is
                // ordered before this load through the mutex hand-off, so we
                // observe `true` here and remove our own entry.
                if register_cancelled.load(Ordering::Acquire) {
                    let mut tasks = register_active.lock().unwrap();
                    tasks.retain(|task| *task != 1);
                }
            });

            let shutdown_active = LoomArc::clone(&active);
            let shutdown_cancelled = LoomArc::clone(&cancelled);
            let shutdown = thread::spawn(move || {
                // Publish cancellation *before* clearing so that a registrant
                // that slips in afterwards sees the flag on its re-check.
                shutdown_cancelled.store(true, Ordering::Release);
                let mut tasks = shutdown_active.lock().unwrap();
                tasks.retain(|task| *task != 1);
            });

            register.join().expect("register thread");
            shutdown.join().expect("shutdown thread");
            assert!(
                active.lock().unwrap().is_empty(),
                "task bookkeeping should not leak active entries across shutdown races"
            );
        });
    }

    /// Loom model of cancellation propagation. A child spins until the shared
    /// flag is set, and a parent sets it. Once both have joined, the child must
    /// have recorded that it observed the cancellation.
    ///
    /// NOTE(review): like the model above, this exercises loom primitives
    /// only, not the real cancellation tokens.
    #[test]
    fn loom_shutdown_token_propagation_reaches_child() {
        loom::model(|| {
            use loom::sync::atomic::{AtomicBool, Ordering};
            use loom::sync::Arc as LoomArc;
            use loom::thread;

            let cancelled = LoomArc::new(AtomicBool::new(false));
            let child_observed = LoomArc::new(AtomicBool::new(false));

            let child = {
                let cancelled = cancelled.clone();
                let child_observed = child_observed.clone();
                thread::spawn(move || {
                    // `yield_now` lets loom schedule the parent while we spin.
                    while !cancelled.load(Ordering::Acquire) {
                        thread::yield_now();
                    }
                    child_observed.store(true, Ordering::Release);
                })
            };

            let parent = {
                let cancelled = cancelled.clone();
                thread::spawn(move || {
                    cancelled.store(true, Ordering::Release);
                })
            };

            parent.join().expect("parent joins");
            child.join().expect("child joins");

            assert!(
                child_observed.load(Ordering::Acquire),
                "child cancellation observer must see parent-driven shutdown"
            );
        });
    }
}