1#![allow(clippy::disallowed_types)]
8
9use std::collections::BTreeMap;
10use std::future::Future;
11use std::panic::AssertUnwindSafe;
12use std::sync::atomic::{AtomicU64, Ordering};
13use std::sync::Arc;
14use std::time::Duration;
15
16use crate::runtime::{
17 RuntimeDiagnostic, RuntimeDiagnosticKind, RuntimeDiagnosticSeverity, RuntimeDiagnosticSink,
18};
19use aura_core::effects::task::{CancellationToken, TaskSpawner};
20use aura_core::effects::PhysicalTimeEffects;
21use aura_core::{
22 execute_with_timeout_budget, OwnedShutdownToken, OwnedTaskHandle, TimeoutBudget,
23 TimeoutRunError,
24};
25use aura_effects::time::PhysicalTimeHandler;
26use futures::future::{BoxFuture, LocalBoxFuture};
27use futures::FutureExt;
28#[cfg(not(target_arch = "wasm32"))]
29use parking_lot::Mutex;
30#[cfg(target_arch = "wasm32")]
31use parking_lot::Mutex;
32use tokio::sync::watch;
33use tokio::sync::Notify;
34#[cfg(not(target_arch = "wasm32"))]
35use tokio::task::JoinHandle;
36#[cfg(target_arch = "wasm32")]
37use wasm_bindgen_futures::spawn_local;
38
/// Name recorded for tasks spawned through the [`TaskSpawner`] trait methods,
/// which do not take a caller-supplied task name.
const DEFAULT_TASK_NAME: &str = "task.default";
40
/// Errors reported by task supervision operations.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TaskSupervisionError {
    /// Waiting for a task group did not finish within the allowed time.
    Timeout {
        /// Name of the task group that was being waited on.
        group: String,
        /// Names of the tasks that were still registered.
        active_tasks: Vec<String>,
    },
    /// Tasks that were still registered had to be force-aborted.
    ForcedAbort {
        /// Name of the task group whose tasks were aborted.
        group: String,
        /// Names of the tasks that were aborted.
        aborted_tasks: Vec<String>,
    },
    /// A single task was cancelled.
    Cancelled {
        /// Name of the task group owning the task.
        group: String,
        /// Name of the cancelled task.
        task: String,
    },
    /// A single task panicked.
    Panicked {
        /// Name of the task group owning the task.
        group: String,
        /// Name of the task that panicked.
        task: String,
    },
}
60
61impl std::fmt::Display for TaskSupervisionError {
62 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
63 match self {
64 Self::Timeout {
65 group,
66 active_tasks,
67 } => write!(
68 f,
69 "task group '{group}' timed out waiting for tasks: {}",
70 active_tasks.join(", ")
71 ),
72 Self::ForcedAbort {
73 group,
74 aborted_tasks,
75 } => write!(
76 f,
77 "task group '{group}' force-aborted tasks: {}",
78 aborted_tasks.join(", ")
79 ),
80 Self::Cancelled { group, task } => {
81 write!(f, "task '{task}' in group '{group}' was cancelled")
82 }
83 Self::Panicked { group, task } => {
84 write!(f, "task '{task}' in group '{group}' panicked")
85 }
86 }
87 }
88}
89
90impl std::error::Error for TaskSupervisionError {}
91
/// How a supervised task's wrapper finished.
#[derive(Debug, Clone, PartialEq, Eq)]
enum TaskOutcome {
    /// The task's future ran to completion.
    Completed,
    /// A cancellation signal fired before the future completed.
    Cancelled,
    /// The task's future panicked while being polled.
    Panicked,
}
98
99#[derive(Debug)]
100struct TaskMetadata {
101 task_name: String,
102 #[cfg(not(target_arch = "wasm32"))]
103 handle: Option<JoinHandle<()>>,
104}
105
106struct TaskGroupShared {
107 name: String,
108 next_task_id: AtomicU64,
109 shutdown_tx: watch::Sender<bool>,
110 inherited_cancellation: Option<Arc<dyn CancellationToken>>,
111 diagnostics: Option<Arc<RuntimeDiagnosticSink>>,
112 tasks: Mutex<BTreeMap<u64, TaskMetadata>>,
113 notify: Arc<Notify>,
114}
115
116#[derive(Clone)]
117pub struct TaskGroup {
118 shared: Arc<TaskGroupShared>,
119}
120
121#[derive(Clone)]
122pub struct TaskSupervisor {
123 root: TaskGroup,
124}
125
126impl TaskSupervisor {
127 pub fn new() -> Self {
128 Self {
129 root: TaskGroup::root("runtime", None),
130 }
131 }
132
133 pub fn with_diagnostics(diagnostics: Arc<RuntimeDiagnosticSink>) -> Self {
134 Self {
135 root: TaskGroup::root("runtime", Some(diagnostics)),
136 }
137 }
138
139 pub fn group(&self, name: impl Into<String>) -> TaskGroup {
140 self.root.group(name)
141 }
142
143 #[must_use = "retain or explicitly discard the owned task handle"]
144 pub fn spawn_named<F>(&self, name: impl Into<String>, fut: F) -> OwnedTaskHandle<u64>
145 where
146 F: Future<Output = ()> + Send + 'static,
147 {
148 self.root.spawn_named(name, fut)
149 }
150
151 #[must_use = "retain or explicitly discard the owned task handle"]
152 pub fn spawn_cancellable_named<F>(
153 &self,
154 name: impl Into<String>,
155 fut: F,
156 ) -> OwnedTaskHandle<u64>
157 where
158 F: Future<Output = ()> + Send + 'static,
159 {
160 self.root.spawn_cancellable_named(name, fut)
161 }
162
163 #[must_use = "retain or explicitly discard the owned task handle"]
164 pub fn spawn_local_named<F>(&self, name: impl Into<String>, fut: F) -> OwnedTaskHandle<u64>
165 where
166 F: Future<Output = ()> + 'static,
167 {
168 self.root.spawn_local_named(name, fut)
169 }
170
171 #[must_use = "retain or explicitly discard the owned task handle"]
172 pub fn spawn_local_cancellable_named<F>(
173 &self,
174 name: impl Into<String>,
175 fut: F,
176 ) -> OwnedTaskHandle<u64>
177 where
178 F: Future<Output = ()> + 'static,
179 {
180 self.root.spawn_local_cancellable_named(name, fut)
181 }
182
183 #[must_use = "retain or explicitly discard the owned task handle"]
184 pub fn spawn_interval_until_named<F, Fut>(
185 &self,
186 name: impl Into<String>,
187 time_effects: Arc<dyn PhysicalTimeEffects + Send + Sync>,
188 interval: Duration,
189 f: F,
190 ) -> OwnedTaskHandle<u64>
191 where
192 F: FnMut() -> Fut + Send + 'static,
193 Fut: Future<Output = bool> + Send + 'static,
194 {
195 self.root
196 .spawn_interval_until_named(name, time_effects, interval, f)
197 }
198
199 #[must_use = "retain or explicitly discard the owned task handle"]
200 pub fn spawn_local_interval_until_named<F, Fut>(
201 &self,
202 name: impl Into<String>,
203 time_effects: Arc<dyn PhysicalTimeEffects + Send + Sync>,
204 interval: Duration,
205 f: F,
206 ) -> OwnedTaskHandle<u64>
207 where
208 F: FnMut() -> Fut + 'static,
209 Fut: Future<Output = bool> + 'static,
210 {
211 self.root
212 .spawn_local_interval_until_named(name, time_effects, interval, f)
213 }
214
215 #[must_use = "retain or explicitly discard the owned task handle"]
216 pub fn spawn_child<F>(&self, name: impl Into<String>, fut: F) -> OwnedTaskHandle<u64>
217 where
218 F: Future<Output = ()> + Send + 'static,
219 {
220 self.spawn_named(name, fut)
221 }
222
223 #[must_use = "retain or explicitly discard the owned task handle"]
224 pub fn spawn_periodic<F, Fut>(
225 &self,
226 name: impl Into<String>,
227 time_effects: Arc<dyn PhysicalTimeEffects + Send + Sync>,
228 interval: Duration,
229 f: F,
230 ) -> OwnedTaskHandle<u64>
231 where
232 F: FnMut() -> Fut + Send + 'static,
233 Fut: Future<Output = bool> + Send + 'static,
234 {
235 self.spawn_interval_until_named(name, time_effects, interval, f)
236 }
237
238 pub fn request_cancellation(&self) {
239 self.root.request_cancellation();
240 }
241
242 pub async fn wait_for_idle(&self, timeout: Duration) -> Result<(), TaskSupervisionError> {
243 self.root.wait_for_idle(timeout).await
244 }
245
246 pub fn force_abort_remaining(&self) -> Result<(), TaskSupervisionError> {
247 self.root.force_abort_remaining()
248 }
249
250 pub fn abort_remaining(&self) -> Result<(), TaskSupervisionError> {
251 self.force_abort_remaining()
252 }
253
254 pub async fn shutdown_with_timeout(
255 &self,
256 timeout: Duration,
257 ) -> Result<(), TaskSupervisionError> {
258 self.root.shutdown_with_timeout(timeout).await
259 }
260
261 pub async fn shutdown_gracefully(&self, timeout: Duration) -> Result<(), TaskSupervisionError> {
262 self.shutdown_with_timeout(timeout).await
263 }
264
265 pub fn shutdown(&self) {
266 self.root.shutdown();
267 }
268
269 pub fn cancellation_token(&self) -> Arc<dyn CancellationToken> {
270 self.root.cancellation_token()
271 }
272
273 pub fn active_tasks(&self) -> Vec<String> {
274 self.root.active_tasks()
275 }
276}
277
278impl Default for TaskSupervisor {
279 fn default() -> Self {
280 Self::new()
281 }
282}
283
284impl Drop for TaskSupervisor {
285 fn drop(&mut self) {
286 self.shutdown();
287 }
288}
289
290impl TaskGroup {
291 fn root(name: impl Into<String>, diagnostics: Option<Arc<RuntimeDiagnosticSink>>) -> Self {
292 let (shutdown_tx, _shutdown_rx) = watch::channel(false);
293 Self {
294 shared: Arc::new(TaskGroupShared {
295 name: name.into(),
296 next_task_id: AtomicU64::new(1),
297 shutdown_tx,
298 inherited_cancellation: None,
299 diagnostics,
300 tasks: Mutex::new(BTreeMap::new()),
301 notify: Arc::new(Notify::new()),
302 }),
303 }
304 }
305
306 pub fn name(&self) -> &str {
307 &self.shared.name
308 }
309
310 pub fn group(&self, name: impl Into<String>) -> TaskGroup {
311 let name = name.into();
312 let full_name = format!("{}.{}", self.shared.name, name);
313 let (shutdown_tx, _shutdown_rx) = watch::channel(false);
314 TaskGroup {
315 shared: Arc::new(TaskGroupShared {
316 name: full_name,
317 next_task_id: AtomicU64::new(1),
318 shutdown_tx,
319 inherited_cancellation: Some(self.cancellation_token()),
320 diagnostics: self.shared.diagnostics.clone(),
321 tasks: Mutex::new(BTreeMap::new()),
322 notify: Arc::new(Notify::new()),
323 }),
324 }
325 }
326
327 #[must_use = "retain or explicitly discard the owned task handle"]
328 pub fn spawn_named<F>(&self, name: impl Into<String>, fut: F) -> OwnedTaskHandle<u64>
329 where
330 F: Future<Output = ()> + Send + 'static,
331 {
332 self.spawn_boxed(name.into(), Box::pin(fut), None)
333 }
334
335 #[must_use = "retain or explicitly discard the owned task handle"]
336 pub fn spawn_cancellable_named<F>(
337 &self,
338 name: impl Into<String>,
339 fut: F,
340 ) -> OwnedTaskHandle<u64>
341 where
342 F: Future<Output = ()> + Send + 'static,
343 {
344 self.spawn_boxed(name.into(), Box::pin(fut), None)
345 }
346
347 #[must_use = "retain or explicitly discard the owned task handle"]
348 pub fn spawn_local_named<F>(&self, name: impl Into<String>, fut: F) -> OwnedTaskHandle<u64>
349 where
350 F: Future<Output = ()> + 'static,
351 {
352 self.spawn_boxed_local(name.into(), Box::pin(fut), None)
353 }
354
355 #[must_use = "retain or explicitly discard the owned task handle"]
356 pub fn spawn_local_cancellable_named<F>(
357 &self,
358 name: impl Into<String>,
359 fut: F,
360 ) -> OwnedTaskHandle<u64>
361 where
362 F: Future<Output = ()> + 'static,
363 {
364 self.spawn_boxed_local(name.into(), Box::pin(fut), None)
365 }
366
367 #[must_use = "retain or explicitly discard the owned task handle"]
368 pub fn spawn_with_token<F>(
369 &self,
370 name: impl Into<String>,
371 fut: F,
372 token: Arc<dyn CancellationToken>,
373 ) -> OwnedTaskHandle<u64>
374 where
375 F: Future<Output = ()> + Send + 'static,
376 {
377 self.spawn_boxed(name.into(), Box::pin(fut), Some(token))
378 }
379
380 #[must_use = "retain or explicitly discard the owned task handle"]
381 pub fn spawn_child<F>(&self, name: impl Into<String>, fut: F) -> OwnedTaskHandle<u64>
382 where
383 F: Future<Output = ()> + Send + 'static,
384 {
385 self.spawn_named(name, fut)
386 }
387
388 #[must_use = "retain or explicitly discard the owned task handle"]
389 pub fn spawn_interval_until_named<F, Fut>(
390 &self,
391 name: impl Into<String>,
392 time_effects: Arc<dyn PhysicalTimeEffects + Send + Sync>,
393 interval: Duration,
394 mut f: F,
395 ) -> OwnedTaskHandle<u64>
396 where
397 F: FnMut() -> Fut + Send + 'static,
398 Fut: Future<Output = bool> + Send + 'static,
399 {
400 let interval_ms = interval.as_millis().try_into().unwrap_or(u64::MAX);
401 self.spawn_boxed(
402 name.into(),
403 Box::pin(async move {
404 loop {
405 if !f().await {
406 break;
407 }
408
409 if time_effects.sleep_ms(interval_ms).await.is_err() {
410 break;
411 }
412 }
413 }),
414 None,
415 )
416 }
417
418 #[must_use = "retain or explicitly discard the owned task handle"]
419 pub fn spawn_local_interval_until_named<F, Fut>(
420 &self,
421 name: impl Into<String>,
422 time_effects: Arc<dyn PhysicalTimeEffects + Send + Sync>,
423 interval: Duration,
424 mut f: F,
425 ) -> OwnedTaskHandle<u64>
426 where
427 F: FnMut() -> Fut + 'static,
428 Fut: Future<Output = bool> + 'static,
429 {
430 let interval_ms = interval.as_millis().try_into().unwrap_or(u64::MAX);
431 self.spawn_boxed_local(
432 name.into(),
433 Box::pin(async move {
434 loop {
435 if !f().await {
436 break;
437 }
438
439 if time_effects.sleep_ms(interval_ms).await.is_err() {
440 break;
441 }
442 }
443 }),
444 None,
445 )
446 }
447
448 #[must_use = "retain or explicitly discard the owned task handle"]
449 pub fn spawn_periodic<F, Fut>(
450 &self,
451 name: impl Into<String>,
452 time_effects: Arc<dyn PhysicalTimeEffects + Send + Sync>,
453 interval: Duration,
454 f: F,
455 ) -> OwnedTaskHandle<u64>
456 where
457 F: FnMut() -> Fut + Send + 'static,
458 Fut: Future<Output = bool> + Send + 'static,
459 {
460 self.spawn_interval_until_named(name, time_effects, interval, f)
461 }
462
463 pub fn request_cancellation(&self) {
464 let _ = self.shared.shutdown_tx.send(true);
465 tracing::debug!(
466 event = "runtime.task_group.cancel_requested",
467 task_group = %self.shared.name,
468 active_tasks = self.active_tasks().len(),
469 "Task group cancellation requested"
470 );
471 self.shared.notify.notify_waiters();
472 }
473
474 pub async fn wait_for_idle(&self, timeout: Duration) -> Result<(), TaskSupervisionError> {
475 let group_name = self.shared.name.clone();
476 let time = PhysicalTimeHandler::new();
477 let started_at = time
478 .physical_time()
479 .await
480 .map_err(|_| TaskSupervisionError::Timeout {
481 group: group_name.clone(),
482 active_tasks: self.active_tasks(),
483 })?;
484 let budget = TimeoutBudget::from_start_and_timeout(&started_at, timeout).map_err(|_| {
485 TaskSupervisionError::Timeout {
486 group: group_name.clone(),
487 active_tasks: self.active_tasks(),
488 }
489 })?;
490 let result = execute_with_timeout_budget(&time, &budget, || async {
491 loop {
492 if self.shared.tasks.lock().is_empty() {
493 return Ok::<(), ()>(());
494 }
495 self.shared.notify.notified().await;
496 }
497 })
498 .await;
499
500 match result {
501 Ok(()) => Ok(()),
502 Err(TimeoutRunError::Timeout(_)) | Err(TimeoutRunError::Operation(_)) => {
503 Err(TaskSupervisionError::Timeout {
504 group: group_name,
505 active_tasks: self.active_tasks(),
506 })
507 }
508 }
509 }
510
511 pub fn force_abort_remaining(&self) -> Result<(), TaskSupervisionError> {
512 let mut tasks = self.shared.tasks.lock();
513 if tasks.is_empty() {
514 return Ok(());
515 }
516
517 let mut aborted_tasks = Vec::with_capacity(tasks.len());
518 #[cfg(not(target_arch = "wasm32"))]
519 for (_, entry) in tasks.iter() {
520 if let Some(handle) = &entry.handle {
521 handle.abort();
522 }
523 aborted_tasks.push(entry.task_name.clone());
524 emit_task_diagnostic(
525 self.shared.diagnostics.as_ref(),
526 RuntimeDiagnosticSeverity::Warn,
527 "task_supervisor",
528 format!(
529 "force-aborted supervised task '{}' in group '{}'",
530 entry.task_name, self.shared.name
531 ),
532 );
533 tracing::warn!(
534 event = "runtime.task.abort_forced",
535 task_group = %self.shared.name,
536 task_name = %entry.task_name,
537 "Force-aborted supervised task"
538 );
539 }
540
541 #[cfg(target_arch = "wasm32")]
542 for (_, entry) in tasks.iter() {
543 aborted_tasks.push(entry.task_name.clone());
544 }
545
546 tasks.clear();
547 self.shared.notify.notify_waiters();
548
549 Err(TaskSupervisionError::ForcedAbort {
550 group: self.shared.name.clone(),
551 aborted_tasks,
552 })
553 }
554
555 pub fn abort_remaining(&self) -> Result<(), TaskSupervisionError> {
556 self.force_abort_remaining()
557 }
558
559 pub async fn shutdown_with_timeout(
560 &self,
561 timeout: Duration,
562 ) -> Result<(), TaskSupervisionError> {
563 self.request_cancellation();
564 match self.wait_for_idle(timeout).await {
565 Ok(()) => Ok(()),
566 Err(TaskSupervisionError::Timeout { .. }) => self.force_abort_remaining(),
567 Err(other) => Err(other),
568 }
569 }
570
571 pub async fn shutdown_gracefully(&self, timeout: Duration) -> Result<(), TaskSupervisionError> {
572 self.shutdown_with_timeout(timeout).await
573 }
574
575 pub fn shutdown(&self) {
576 self.request_cancellation();
577 let _ = self.force_abort_remaining();
578 }
579
580 pub fn cancellation_token(&self) -> Arc<dyn CancellationToken> {
581 Arc::new(TaskGroupCancellationToken {
582 shutdown_rx: self.shared.shutdown_tx.subscribe(),
583 inherited: self.shared.inherited_cancellation.clone(),
584 })
585 }
586
587 pub fn active_tasks(&self) -> Vec<String> {
588 self.shared
589 .tasks
590 .lock()
591 .values()
592 .map(|task| task.task_name.clone())
593 .collect()
594 }
595
596 fn register_task(&self, task_id: u64, task_name: String) {
597 self.shared.tasks.lock().insert(
598 task_id,
599 TaskMetadata {
600 task_name,
601 #[cfg(not(target_arch = "wasm32"))]
602 handle: None,
603 },
604 );
605 }
606
607 #[cfg(not(target_arch = "wasm32"))]
608 fn attach_native_handle(&self, task_id: u64, handle: JoinHandle<()>) {
609 if let Some(metadata) = self.shared.tasks.lock().get_mut(&task_id) {
610 metadata.handle = Some(handle);
611 }
612 }
613
614 fn complete_task(&self, task_id: u64, task_name: &str, outcome: TaskOutcome) {
615 let removed = self.shared.tasks.lock().remove(&task_id);
616 if removed.is_none() {
617 return;
618 }
619
620 if matches!(outcome, TaskOutcome::Cancelled | TaskOutcome::Panicked) {
621 tracing::warn!(
622 event = "runtime.task.exit_non_success",
623 task_group = %self.shared.name,
624 task_name = %task_name,
625 outcome = ?outcome,
626 "Supervised task exited abnormally"
627 );
628 }
629
630 self.shared.notify.notify_waiters();
631 }
632
633 fn spawn_boxed(
634 &self,
635 task_name: String,
636 fut: BoxFuture<'static, ()>,
637 external_token: Option<Arc<dyn CancellationToken>>,
638 ) -> OwnedTaskHandle<u64> {
639 let task_id = self.shared.next_task_id.fetch_add(1, Ordering::Relaxed);
640 self.register_task(task_id, task_name.clone());
641 let group_name = self.shared.name.clone();
642 let mut shutdown_rx = self.shared.shutdown_tx.subscribe();
643 let inherited = self.shared.inherited_cancellation.clone();
644 let diagnostics = self.shared.diagnostics.clone();
645 let task_name_for_wrapper = task_name.clone();
646 let group = self.clone();
647
648 tracing::debug!(
649 event = "runtime.task.spawned",
650 task_group = %group_name,
651 task_name = %task_name,
652 task_id,
653 "Spawned supervised task"
654 );
655
656 #[cfg(not(target_arch = "wasm32"))]
657 let handle = tokio::spawn(async move {
658 let outcome = AssertUnwindSafe(async {
659 tokio::select! {
660 _ = shutdown_cancelled(&mut shutdown_rx) => TaskOutcome::Cancelled,
661 _ = inherited_cancelled(inherited.as_ref()) => TaskOutcome::Cancelled,
662 _ = external_cancelled(external_token.as_deref()) => TaskOutcome::Cancelled,
663 _ = fut => TaskOutcome::Completed,
664 }
665 })
666 .catch_unwind()
667 .await
668 .unwrap_or(TaskOutcome::Panicked);
669
670 emit_task_completion(
671 diagnostics.as_ref(),
672 &group_name,
673 &task_name_for_wrapper,
674 task_id,
675 &outcome,
676 );
677 group.complete_task(task_id, &task_name_for_wrapper, outcome);
678 });
679
680 #[cfg(not(target_arch = "wasm32"))]
681 self.attach_native_handle(task_id, handle);
682
683 #[cfg(target_arch = "wasm32")]
684 {
685 spawn_local(async move {
686 let outcome = AssertUnwindSafe(async {
687 tokio::select! {
688 _ = shutdown_cancelled(&mut shutdown_rx) => TaskOutcome::Cancelled,
689 _ = inherited_cancelled(inherited.as_ref()) => TaskOutcome::Cancelled,
690 _ = external_cancelled(external_token.as_deref()) => TaskOutcome::Cancelled,
691 _ = fut => TaskOutcome::Completed,
692 }
693 })
694 .catch_unwind()
695 .await
696 .unwrap_or(TaskOutcome::Panicked);
697
698 emit_task_completion(
699 diagnostics.as_ref(),
700 &group_name,
701 &task_name_for_wrapper,
702 task_id,
703 &outcome,
704 );
705 group.complete_task(task_id, &task_name_for_wrapper, outcome);
706 });
707 }
708
709 OwnedTaskHandle::new(
710 task_id,
711 OwnedShutdownToken::attached(self.cancellation_token()),
712 )
713 }
714
715 fn spawn_boxed_local(
716 &self,
717 task_name: String,
718 fut: LocalBoxFuture<'static, ()>,
719 external_token: Option<Arc<dyn CancellationToken>>,
720 ) -> OwnedTaskHandle<u64> {
721 let task_id = self.shared.next_task_id.fetch_add(1, Ordering::Relaxed);
722 self.register_task(task_id, task_name.clone());
723 let mut shutdown_rx = self.shared.shutdown_tx.subscribe();
724 let inherited = self.shared.inherited_cancellation.clone();
725 let group_name = self.shared.name.clone();
726 let diagnostics = self.shared.diagnostics.clone();
727 let task_name_for_wrapper = task_name.clone();
728 let group = self.clone();
729
730 #[cfg(not(target_arch = "wasm32"))]
731 let handle = tokio::task::spawn_local(async move {
732 let outcome = AssertUnwindSafe(async {
733 tokio::select! {
734 _ = shutdown_cancelled(&mut shutdown_rx) => TaskOutcome::Cancelled,
735 _ = inherited_cancelled(inherited.as_ref()) => TaskOutcome::Cancelled,
736 _ = external_cancelled(external_token.as_deref()) => TaskOutcome::Cancelled,
737 _ = fut => TaskOutcome::Completed,
738 }
739 })
740 .catch_unwind()
741 .await
742 .unwrap_or(TaskOutcome::Panicked);
743
744 emit_task_completion(
745 diagnostics.as_ref(),
746 &group_name,
747 &task_name_for_wrapper,
748 task_id,
749 &outcome,
750 );
751 group.complete_task(task_id, &task_name_for_wrapper, outcome);
752 });
753
754 #[cfg(not(target_arch = "wasm32"))]
755 self.attach_native_handle(task_id, handle);
756
757 #[cfg(target_arch = "wasm32")]
758 {
759 spawn_local(async move {
760 let outcome = AssertUnwindSafe(async {
761 tokio::select! {
762 _ = shutdown_cancelled(&mut shutdown_rx) => TaskOutcome::Cancelled,
763 _ = inherited_cancelled(inherited.as_ref()) => TaskOutcome::Cancelled,
764 _ = external_cancelled(external_token.as_deref()) => TaskOutcome::Cancelled,
765 _ = fut => TaskOutcome::Completed,
766 }
767 })
768 .catch_unwind()
769 .await
770 .unwrap_or(TaskOutcome::Panicked);
771
772 emit_task_completion(
773 diagnostics.as_ref(),
774 &group_name,
775 &task_name_for_wrapper,
776 task_id,
777 &outcome,
778 );
779 group.complete_task(task_id, &task_name_for_wrapper, outcome);
780 });
781 }
782
783 OwnedTaskHandle::new(
784 task_id,
785 OwnedShutdownToken::attached(self.cancellation_token()),
786 )
787 }
788}
789
790struct TaskGroupCancellationToken {
791 shutdown_rx: watch::Receiver<bool>,
792 inherited: Option<Arc<dyn CancellationToken>>,
793}
794
795#[async_trait::async_trait]
796impl CancellationToken for TaskGroupCancellationToken {
797 async fn cancelled(&self) {
798 if self.is_cancelled() {
799 return;
800 }
801
802 let mut shutdown_rx = self.shutdown_rx.clone();
803 match self.inherited.clone() {
804 Some(inherited) => {
805 tokio::select! {
806 _ = shutdown_cancelled(&mut shutdown_rx) => {}
807 _ = inherited.cancelled() => {}
808 }
809 }
810 None => {
811 shutdown_cancelled(&mut shutdown_rx).await;
812 }
813 }
814 }
815
816 fn is_cancelled(&self) -> bool {
817 *self.shutdown_rx.borrow()
818 || self
819 .inherited
820 .as_ref()
821 .map(|token| token.is_cancelled())
822 .unwrap_or(false)
823 }
824}
825
826impl TaskSpawner for TaskSupervisor {
827 fn spawn(&self, fut: BoxFuture<'static, ()>) {
828 let _ = self.spawn_named(DEFAULT_TASK_NAME, fut);
829 }
830
831 fn spawn_cancellable(&self, fut: BoxFuture<'static, ()>, token: Arc<dyn CancellationToken>) {
832 let _ = self
833 .root
834 .spawn_boxed(DEFAULT_TASK_NAME.to_string(), fut, Some(token));
835 }
836
837 fn spawn_local(&self, fut: LocalBoxFuture<'static, ()>) {
838 let _ = self
839 .root
840 .spawn_boxed_local(DEFAULT_TASK_NAME.to_string(), fut, None);
841 }
842
843 fn spawn_local_cancellable(
844 &self,
845 fut: LocalBoxFuture<'static, ()>,
846 token: Arc<dyn CancellationToken>,
847 ) {
848 let _ = self
849 .root
850 .spawn_boxed_local(DEFAULT_TASK_NAME.to_string(), fut, Some(token));
851 }
852
853 fn cancellation_token(&self) -> Arc<dyn CancellationToken> {
854 self.cancellation_token()
855 }
856}
857
858fn emit_task_completion(
859 diagnostics: Option<&Arc<RuntimeDiagnosticSink>>,
860 group: &str,
861 task_name: &str,
862 task_id: u64,
863 outcome: &TaskOutcome,
864) {
865 match outcome {
866 TaskOutcome::Completed => tracing::debug!(
867 event = "runtime.task.completed",
868 task_group = %group,
869 task_name = %task_name,
870 task_id,
871 "Supervised task completed"
872 ),
873 TaskOutcome::Cancelled => tracing::info!(
874 event = "runtime.task.cancelled",
875 task_group = %group,
876 task_name = %task_name,
877 task_id,
878 "Supervised task cancelled"
879 ),
880 TaskOutcome::Panicked => tracing::error!(
881 event = "runtime.task.panicked",
882 task_group = %group,
883 task_name = %task_name,
884 task_id,
885 "Supervised task panicked"
886 ),
887 }
888
889 if matches!(outcome, TaskOutcome::Panicked) {
890 emit_task_diagnostic(
891 diagnostics,
892 RuntimeDiagnosticSeverity::Error,
893 "task_supervisor",
894 format!("supervised task '{task_name}' in group '{group}' panicked"),
895 );
896 }
897}
898
899fn emit_task_diagnostic(
900 diagnostics: Option<&Arc<RuntimeDiagnosticSink>>,
901 severity: RuntimeDiagnosticSeverity,
902 component: &'static str,
903 message: String,
904) {
905 if let Some(diagnostics) = diagnostics {
906 diagnostics.emit(RuntimeDiagnostic {
907 severity,
908 kind: RuntimeDiagnosticKind::SupervisedTaskFailed,
909 component,
910 message,
911 });
912 }
913}
914
915async fn shutdown_cancelled(shutdown_rx: &mut watch::Receiver<bool>) {
916 loop {
917 if *shutdown_rx.borrow() {
918 return;
919 }
920 if shutdown_rx.changed().await.is_err() {
921 return;
922 }
923 }
924}
925
926async fn inherited_cancelled(token: Option<&Arc<dyn CancellationToken>>) {
927 match token {
928 Some(token) => token.cancelled().await,
929 None => futures::future::pending::<()>().await,
930 }
931}
932
933async fn external_cancelled(token: Option<&dyn CancellationToken>) {
934 match token {
935 Some(token) => token.cancelled().await,
936 None => futures::future::pending::<()>().await,
937 }
938}
939
#[cfg(test)]
mod tests {
    use super::*;
    use crate::runtime::{RuntimeDiagnosticKind, RuntimeDiagnosticSeverity};
    use tokio::sync::oneshot;

    /// A pending task must be gone after a graceful shutdown.
    #[tokio::test]
    async fn shutdown_with_timeout_cancels_supervised_tasks() {
        let sup = TaskSupervisor::new();
        let (ready_tx, ready_rx) = oneshot::channel();

        let _task_handle = sup.spawn_named("test.pending", async move {
            let _ = ready_tx.send(());
            futures::future::pending::<()>().await;
        });

        ready_rx.await.expect("task should start");
        let shutdown = sup.shutdown_with_timeout(Duration::from_millis(50)).await;
        shutdown.expect("shutdown should cancel pending task");
        assert!(sup.active_tasks().is_empty());
    }

    /// Cancelling the parent must stop tasks running in a child group.
    #[tokio::test]
    async fn child_groups_inherit_parent_cancellation() {
        let sup = TaskSupervisor::new();
        let child_group = sup.group("child");
        let (ready_tx, ready_rx) = oneshot::channel();

        let _task_handle = child_group.spawn_named("test.pending", async move {
            let _ = ready_tx.send(());
            futures::future::pending::<()>().await;
        });

        ready_rx.await.expect("task should start");
        sup.request_cancellation();
        let idle = child_group.wait_for_idle(Duration::from_millis(50)).await;
        idle.expect("child tasks should stop when parent is cancelled");
    }

    /// Without cancellation the wait times out and force-abort reports the task.
    #[tokio::test]
    async fn wait_for_idle_times_out_and_force_abort_reports_tasks() {
        let sup = TaskSupervisor::new();
        let (ready_tx, ready_rx) = oneshot::channel();

        let _task_handle = sup.spawn_named("test.pending", async move {
            let _ = ready_tx.send(());
            futures::future::pending::<()>().await;
        });

        ready_rx.await.expect("task should start");
        let wait_result = sup.wait_for_idle(Duration::from_millis(10)).await;
        assert!(matches!(
            wait_result,
            Err(TaskSupervisionError::Timeout { .. })
        ));

        let abort_result = sup.force_abort_remaining();
        assert!(matches!(
            abort_result,
            Err(TaskSupervisionError::ForcedAbort { .. })
        ));
        assert!(sup.active_tasks().is_empty());
    }

    /// Force-aborting a task must emit a warning diagnostic.
    #[tokio::test]
    async fn force_abort_emits_runtime_diagnostic() {
        let sink = Arc::new(RuntimeDiagnosticSink::new());
        let sup = TaskSupervisor::with_diagnostics(sink.clone());
        let (ready_tx, ready_rx) = oneshot::channel();

        let _task_handle = sup.spawn_named("test.pending", async move {
            let _ = ready_tx.send(());
            futures::future::pending::<()>().await;
        });

        ready_rx.await.expect("task should start");
        let mut diagnostics_rx = sink.subscribe();
        let abort_result = sup.force_abort_remaining();
        assert!(matches!(
            abort_result,
            Err(TaskSupervisionError::ForcedAbort { .. })
        ));

        let emitted = diagnostics_rx.try_recv().expect("diagnostic emitted");
        assert_eq!(emitted.kind, RuntimeDiagnosticKind::SupervisedTaskFailed);
        assert_eq!(emitted.severity, RuntimeDiagnosticSeverity::Warn);
    }

    /// Loom model of a registration racing with a shutdown.
    #[test]
    fn loom_shutdown_race_does_not_leave_task_registered() {
        loom::model(|| {
            use loom::sync::atomic::{AtomicBool, Ordering};
            use loom::sync::{Arc as LoomArc, Mutex as LoomMutex};
            use loom::thread;

            let registry = LoomArc::new(LoomMutex::new(Vec::<u8>::new()));
            let cancel_flag = LoomArc::new(AtomicBool::new(false));

            // Registering side: add the entry, then undo it if shutdown was seen.
            let registry_for_register = LoomArc::clone(&registry);
            let flag_for_register = LoomArc::clone(&cancel_flag);
            let register = thread::spawn(move || {
                {
                    let mut entries = registry_for_register.lock().unwrap();
                    entries.push(1);
                }
                if flag_for_register.load(Ordering::Acquire) {
                    let mut entries = registry_for_register.lock().unwrap();
                    entries.retain(|entry| *entry != 1);
                }
            });

            // Shutdown side: publish the flag, then clear the entry.
            let registry_for_shutdown = LoomArc::clone(&registry);
            let flag_for_shutdown = LoomArc::clone(&cancel_flag);
            let shutdown = thread::spawn(move || {
                flag_for_shutdown.store(true, Ordering::Release);
                let mut entries = registry_for_shutdown.lock().unwrap();
                entries.retain(|entry| *entry != 1);
            });

            register.join().expect("register thread");
            shutdown.join().expect("shutdown thread");
            assert!(
                registry.lock().unwrap().is_empty(),
                "task bookkeeping should not leak active entries across shutdown races"
            );
        });
    }

    /// Loom model of a child observing a flag set by its parent.
    #[test]
    fn loom_shutdown_token_propagation_reaches_child() {
        loom::model(|| {
            use loom::sync::atomic::{AtomicBool, Ordering};
            use loom::sync::Arc as LoomArc;
            use loom::thread;

            let cancel_flag = LoomArc::new(AtomicBool::new(false));
            let observed_flag = LoomArc::new(AtomicBool::new(false));

            let child = {
                let cancel_flag = cancel_flag.clone();
                let observed_flag = observed_flag.clone();
                thread::spawn(move || {
                    while !cancel_flag.load(Ordering::Acquire) {
                        thread::yield_now();
                    }
                    observed_flag.store(true, Ordering::Release);
                })
            };

            let parent = {
                let cancel_flag = cancel_flag.clone();
                thread::spawn(move || {
                    cancel_flag.store(true, Ordering::Release);
                })
            };

            parent.join().expect("parent joins");
            child.join().expect("child joins");

            assert!(
                observed_flag.load(Ordering::Acquire),
                "child cancellation observer must see parent-driven shutdown"
            );
        });
    }
}