dynamo_runtime/utils/tasks/critical.rs
1// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4//! Utilities for handling tasks.
5
6use anyhow::{Context, Result};
7use std::future::Future;
8use tokio::runtime::Handle;
9use tokio::sync::oneshot;
10use tokio::task::JoinHandle;
11use tokio_util::sync::CancellationToken;
12
/// Type alias for a critical task handler function.
///
/// The handler is called once with a [`CancellationToken`] and returns a future `Fut`.
/// The constructors of [`CriticalTaskExecutionHandle`] require that future to resolve to
/// `Result<()>`; the alias itself places no bound on `Fut`.
///
/// The task should monitor the cancellation token and shut down gracefully once it is
/// cancelled.
pub type CriticalTaskHandler<Fut> = dyn FnOnce(CancellationToken) -> Fut + Send + 'static;
18
19/// The [CriticalTaskExecutionHandle] is a handle for a critical task that is expected to
20/// complete successfully. This handle provides two cancellation mechanisms:
21///
22/// 1. **Critical Failure**: If the task returns an error or panics, the parent cancellation
23/// token is triggered immediately via a monitoring task that detects failures.
24///
25/// 2. **Graceful Shutdown**: The task can be gracefully shut down via its child token,
26/// allowing it to complete cleanly without triggering system-wide cancellation.
27///
28/// This is useful for ensuring that critical detached tasks either complete successfully
29/// or trigger appropriate shutdown procedures when they fail.
30pub struct CriticalTaskExecutionHandle {
31 monitor_task: JoinHandle<()>,
32 graceful_shutdown_token: CancellationToken,
33 result_receiver: Option<oneshot::Receiver<Result<()>>>,
34 detached: bool,
35}
36
37impl CriticalTaskExecutionHandle {
38 pub fn new<Fut>(
39 task_fn: impl FnOnce(CancellationToken) -> Fut + Send + 'static,
40 parent_token: CancellationToken,
41 description: &str,
42 ) -> Result<Self>
43 where
44 Fut: Future<Output = Result<()>> + Send + 'static,
45 {
46 Self::new_with_runtime(task_fn, parent_token, description, &Handle::try_current()?)
47 }
48
49 /// Create a new [CriticalTaskExecutionHandle] for a critical task.
50 ///
51 /// # Arguments
52 /// * `task_fn` - A function that takes a cancellation token and returns the critical task future
53 /// * `parent_token` - Token that will be cancelled if this critical task fails
54 /// * `description` - Description for logging purposes
55 /// * `runtime` - The runtime to use for the task.
56 pub fn new_with_runtime<Fut>(
57 task_fn: impl FnOnce(CancellationToken) -> Fut + Send + 'static,
58 parent_token: CancellationToken,
59 description: &str,
60 runtime: &Handle,
61 ) -> Result<Self>
62 where
63 Fut: Future<Output = Result<()>> + Send + 'static,
64 {
65 let graceful_shutdown_token = parent_token.child_token();
66 let description = description.to_string();
67 let parent_token_clone = parent_token.clone();
68
69 // Create channel for communicating results from monitor to handle
70 let (result_sender, result_receiver) = oneshot::channel();
71
72 let graceful_shutdown_token_clone = graceful_shutdown_token.clone();
73 let description_clone = description.to_string();
74 let task = runtime.spawn(async move {
75 let future = task_fn(graceful_shutdown_token_clone);
76
77 match future.await {
78 Ok(()) => {
79 tracing::debug!(
80 "Critical task '{}' completed successfully",
81 description_clone
82 );
83 Ok(())
84 }
85 Err(e) => {
86 tracing::error!("Critical task '{}' failed: {:#}", description_clone, e);
87 Err(e.context(format!("Critical task '{}' failed", description_clone)))
88 }
89 }
90 });
91
92 // Spawn monitor task that immediately joins the main task and detects failures
93 let monitor_task = {
94 let main_task_handle = task;
95 let parent_token_monitor = parent_token_clone.clone();
96 let description_monitor = description.clone();
97
98 runtime.spawn(async move {
99 let result = match main_task_handle.await {
100 Ok(task_result) => {
101 // Task completed normally (success or error)
102 if task_result.is_err() {
103 // Error - trigger parent cancellation immediately
104 parent_token_monitor.cancel();
105 }
106 task_result
107 }
108 Err(join_error) => {
109 // Task panicked - handle immediately
110 if join_error.is_panic() {
111 let panic_msg = if let Ok(reason) = join_error.try_into_panic() {
112 if let Some(s) = reason.downcast_ref::<String>() {
113 s.clone()
114 } else if let Some(s) = reason.downcast_ref::<&str>() {
115 s.to_string()
116 } else {
117 "Unknown panic".to_string()
118 }
119 } else {
120 "Panic occurred but reason unavailable".to_string()
121 };
122
123 tracing::error!(
124 "Critical task '{}' panicked: {}",
125 description_monitor,
126 panic_msg
127 );
128 parent_token_monitor.cancel(); // Trigger parent cancellation immediately
129 Err(anyhow::anyhow!(
130 "Critical task '{}' panicked: {}",
131 description_monitor,
132 panic_msg
133 ))
134 } else {
135 parent_token_monitor.cancel();
136 Err(anyhow::anyhow!(
137 "Failed to join critical task '{}': {}",
138 description_monitor,
139 join_error
140 ))
141 }
142 }
143 };
144
145 // Send result to handle (ignore if receiver dropped)
146 let _ = result_sender.send(result);
147 })
148 };
149
150 Ok(Self {
151 monitor_task,
152 graceful_shutdown_token,
153 result_receiver: Some(result_receiver),
154 detached: false,
155 })
156 }
157
158 /// Check if the task awaiting on the [Server]s background event loop has finished.
159 pub fn is_finished(&self) -> bool {
160 self.monitor_task.is_finished()
161 }
162
163 /// Check if the server's event loop has been cancelled.
164 pub fn is_cancelled(&self) -> bool {
165 self.graceful_shutdown_token.is_cancelled()
166 }
167
168 /// Gracefully cancel this critical task without triggering system-wide shutdown.
169 ///
170 /// This signals the task to stop processing and exit cleanly. The task should
171 /// monitor its cancellation token and respond appropriately.
172 ///
173 /// This will not propagate to the parent [CancellationToken] unless an error
174 /// occurs during the shutdown process.
175 pub fn cancel(&self) {
176 self.graceful_shutdown_token.cancel();
177 }
178
179 /// Join on the critical task and return its actual result.
180 ///
181 /// This will return:
182 /// - `Ok(())` if the task completed successfully or was gracefully cancelled
183 /// - `Err(...)` if the task failed or panicked, preserving the original error
184 ///
185 /// Note: Both errors and panics trigger parent cancellation immediately via the monitor task.
186 pub async fn join(mut self) -> Result<()> {
187 self.detached = true;
188
189 match self.result_receiver.take().unwrap().await {
190 Ok(task_result) => task_result,
191 Err(_) => {
192 // This should rarely happen - means monitor task was dropped/cancelled
193 Err(anyhow::anyhow!("Critical task monitor was cancelled"))
194 }
195 }
196 }
197
198 /// Detach the task. This allows the task to continue running after the handle is dropped.
199 pub fn detach(mut self) {
200 self.detached = true;
201 }
202}
203
204impl Drop for CriticalTaskExecutionHandle {
205 fn drop(&mut self) {
206 if !self.detached {
207 panic!("Critical task was not detached prior to drop!");
208 }
209 }
210}
211
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicBool, AtomicU32, Ordering};
    use std::time::Duration;
    use tokio::time::timeout;

    #[tokio::test]
    async fn test_successful_task_completion() {
        // Test: A critical task that completes successfully without any issues
        // Verifies:
        // - Task executes and completes normally
        // - Result is Ok(())
        // - Parent token remains uncancelled (no critical failure)
        // - Task execution side effects occur (work gets done)
        let parent_token = CancellationToken::new();
        let completed = Arc::new(AtomicBool::new(false));
        let completed_clone = completed.clone();

        let handle = CriticalTaskExecutionHandle::new(
            |_cancel_token| async move {
                completed_clone.store(true, Ordering::SeqCst);
                Ok(())
            },
            parent_token.clone(),
            "test-success-task",
        )
        .unwrap();

        // Task should complete successfully
        let result = handle.join().await;
        assert!(result.is_ok());
        assert!(completed.load(Ordering::SeqCst));
        assert!(!parent_token.is_cancelled());
    }

    #[tokio::test]
    async fn test_task_failure_cancels_parent_token() {
        // Test: A critical task that returns an error (critical failure)
        // Verifies:
        // - Task error is properly propagated to caller
        // - Parent cancellation token is triggered (critical failure behavior)
        // - Error message is preserved and includes context
        // - Demonstrates the "critical" aspect - failures affect the entire system
        let parent_token = CancellationToken::new();

        let handle = CriticalTaskExecutionHandle::new(
            |_cancel_token| async move {
                anyhow::bail!("Critical task failed!");
            },
            parent_token.clone(),
            "test-failure-task",
        )
        .unwrap();

        // Task should fail and cancel parent token
        let result = handle.join().await;
        assert!(result.is_err());
        let error_msg = result.unwrap_err().to_string();
        // Check that the error contains either the original message or the context
        assert!(
            error_msg.contains("Critical task failed!")
                || error_msg.contains("Critical task 'test-failure-task' failed"),
            "Error message should contain failure context: {}",
            error_msg
        );

        // Give a moment for the cancellation to propagate
        tokio::time::sleep(Duration::from_millis(10)).await;
        assert!(parent_token.is_cancelled());
    }

    #[tokio::test]
    async fn test_task_panic_is_caught_and_reported() {
        // Test: A critical task that panics during execution
        // Verifies:
        // - Tokio's JoinHandle catches panics automatically
        // - Panics are converted to proper Error types
        // - System doesn't crash, panic is contained
        // - Error message indicates a panic occurred
        // - Parent token is cancelled (panic is treated as critical failure)
        // - Demonstrates panic safety of the critical task system
        let parent_token = CancellationToken::new();

        let handle = CriticalTaskExecutionHandle::new(
            |_cancel_token| async move {
                panic!("Something went terribly wrong!");
            },
            parent_token.clone(),
            "test-panic-task",
        )
        .unwrap();

        // Panic should be caught and converted to error
        let result = handle.join().await;
        assert!(result.is_err());
        let error_msg = result.unwrap_err().to_string();
        assert!(
            error_msg.contains("panicked") || error_msg.contains("panic"),
            "Error message should indicate a panic occurred: {}",
            error_msg
        );

        // Parent token should be cancelled due to panic (critical failure)
        assert!(parent_token.is_cancelled());
    }

    #[tokio::test]
    async fn test_graceful_shutdown_via_cancellation_token() {
        // Test: A long-running task that responds to graceful shutdown signals
        // Verifies:
        // - Task can monitor its cancellation token and stop early
        // - Graceful cancellation does NOT trigger parent token cancellation
        // - Task can do partial work before stopping
        // - handle.cancel() triggers the child token, not parent token
        // - Demonstrates proper shutdown patterns for long-running tasks
        let parent_token = CancellationToken::new();
        let work_done = Arc::new(AtomicU32::new(0));
        let work_done_clone = work_done.clone();

        let handle = CriticalTaskExecutionHandle::new(
            |cancel_token| async move {
                for i in 0..100 {
                    if cancel_token.is_cancelled() {
                        break;
                    }
                    work_done_clone.store(i, Ordering::SeqCst);
                    tokio::time::sleep(Duration::from_millis(10)).await;
                }
                Ok(())
            },
            parent_token.clone(),
            "test-graceful-shutdown",
        )
        .unwrap();

        // Let task do some work
        tokio::time::sleep(Duration::from_millis(50)).await;

        // Request graceful shutdown
        handle.cancel();

        // Task should complete gracefully
        let result = handle.join().await;
        assert!(result.is_ok());

        // Task should have done some work but not all
        let final_work = work_done.load(Ordering::SeqCst);
        assert!(final_work > 0);
        assert!(final_work < 99);

        // Parent token should NOT be cancelled (graceful shutdown)
        assert!(!parent_token.is_cancelled());
    }

    #[tokio::test]
    async fn test_multiple_critical_tasks_one_failure() {
        // Test: Multiple critical tasks sharing a parent token, one fails
        // Verifies:
        // - Multiple critical tasks can share the same parent cancellation token
        // - When one critical task fails, all related tasks receive cancellation signal
        // - Tasks can respond to cancellation and stop gracefully
        // - System-wide shutdown behavior when critical components fail
        // - Demonstrates coordinated shutdown of related services
        let parent_token = CancellationToken::new();
        let task1_completed = Arc::new(AtomicBool::new(false));
        let task2_completed = Arc::new(AtomicBool::new(false));

        let task1_completed_clone = task1_completed.clone();
        let task2_completed_clone = task2_completed.clone();

        // Start two critical tasks
        let handle1 = CriticalTaskExecutionHandle::new(
            |cancel_token| async move {
                for _ in 0..50 {
                    if cancel_token.is_cancelled() {
                        return Ok(());
                    }
                    tokio::time::sleep(Duration::from_millis(10)).await;
                }
                task1_completed_clone.store(true, Ordering::SeqCst);
                Ok(())
            },
            parent_token.clone(),
            "long-running-task",
        )
        .unwrap();

        let handle2 = CriticalTaskExecutionHandle::new(
            |_cancel_token| async move {
                tokio::time::sleep(Duration::from_millis(100)).await;
                task2_completed_clone.store(true, Ordering::SeqCst);
                anyhow::bail!("Task 2 failed!");
            },
            parent_token.clone(),
            "failing-task",
        )
        .unwrap();

        // Wait for task 2 to fail
        let result2 = handle2.join().await;
        assert!(result2.is_err());

        // Parent token should be cancelled due to task 2 failure
        assert!(parent_token.is_cancelled());

        // Task 1 should complete early due to cancellation
        let result1 = handle1.join().await;
        assert!(result1.is_ok());
        assert!(!task1_completed.load(Ordering::SeqCst)); // Should not have completed normally
    }

    #[tokio::test]
    async fn test_status_checking_methods() {
        // Test: Non-blocking status checking methods on the handle
        // Verifies:
        // - is_finished() accurately reports task completion status
        // - is_cancelled() accurately reports cancellation status
        // - Status methods work before and after cancellation
        // - Methods are non-blocking and can be called multiple times
        // - Demonstrates monitoring patterns for task supervision
        let parent_token = CancellationToken::new();

        let handle = CriticalTaskExecutionHandle::new(
            |cancel_token| async move {
                tokio::time::sleep(Duration::from_millis(100)).await;
                if cancel_token.is_cancelled() {
                    return Ok(());
                }
                tokio::time::sleep(Duration::from_millis(100)).await;
                Ok(())
            },
            parent_token,
            "status-test-task",
        )
        .unwrap();

        // Initially task should be running
        assert!(!handle.is_finished());
        assert!(!handle.is_cancelled());

        // Cancel the task
        handle.cancel();

        // Task should now be cancelled but may not be finished yet
        assert!(handle.is_cancelled());

        // Wait for completion
        let result = handle.join().await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_task_with_select_pattern() {
        // Test: Task using tokio::select! for cancellation-aware operations
        // Verifies:
        // - Tasks can use idiomatic tokio patterns with cancellation tokens
        // - select! allows racing between work and cancellation
        // - Cancellation interrupts work immediately, not just at check points
        // - Demonstrates recommended pattern for responsive cancellation
        // - Shows how to handle cancellation in the middle of async operations
        let parent_token = CancellationToken::new();
        let work_completed = Arc::new(AtomicBool::new(false));
        let work_completed_clone = work_completed.clone();

        let handle = CriticalTaskExecutionHandle::new(
            |cancel_token| async move {
                tokio::select! {
                    _ = tokio::time::sleep(Duration::from_millis(200)) => {
                        work_completed_clone.store(true, Ordering::SeqCst);
                        Ok(())
                    }
                    _ = cancel_token.cancelled() => {
                        // Graceful shutdown requested
                        Ok(())
                    }
                }
            },
            parent_token,
            "select-pattern-task",
        )
        .unwrap();

        // Cancel after a short time
        tokio::time::sleep(Duration::from_millis(50)).await;
        handle.cancel();

        let result = handle.join().await;
        assert!(result.is_ok());
        assert!(!work_completed.load(Ordering::SeqCst)); // Should not have completed the work
    }

    #[tokio::test]
    async fn test_timeout_behavior() {
        // Test: External timeout vs task failure distinction
        // Verifies:
        // - External timeouts don't trigger parent token cancellation
        // - Tasks continue running in background even after timeout
        // - Difference between "waiting timeout" and "task failure"
        // - Client-side timeout vs server-side failure handling
        // - Demonstrates that join() timeout != critical task failure
        let parent_token = CancellationToken::new();

        let handle = CriticalTaskExecutionHandle::new(
            |_cancel_token| async move {
                // A task that takes a long time
                tokio::time::sleep(Duration::from_secs(10)).await;
                Ok(())
            },
            // Keep our own copy of the token so the "not cancelled" claim can be asserted.
            parent_token.clone(),
            "long-task",
        )
        .unwrap();

        // Test with timeout
        let result = timeout(Duration::from_millis(100), handle.join()).await;
        assert!(result.is_err()); // Should timeout

        // The parent token should NOT be cancelled since task didn't fail
        // (it's still running in the background, but we timed out waiting for it)
        assert!(
            !parent_token.is_cancelled(),
            "Timing out on join() must not cancel the parent token"
        );
    }

    #[tokio::test]
    async fn test_panic_triggers_immediate_parent_cancellation() {
        // Test: Verify that panics trigger parent cancellation immediately via monitor task
        // Verifies:
        // - Monitor task detects panics immediately when they occur
        // - Parent token cancellation happens immediately, not on join()
        // - System shutdown is triggered as soon as critical task panics
        // - Demonstrates true "critical task" behavior with immediate failure propagation
        let parent_token = CancellationToken::new();

        let handle = CriticalTaskExecutionHandle::new(
            |_cancel_token| async move {
                tokio::time::sleep(Duration::from_millis(50)).await;
                panic!("Critical failure!");
            },
            parent_token.clone(),
            "immediate-panic-task",
        )
        .unwrap();

        // Wait for the panic to be detected by monitor task
        tokio::time::sleep(Duration::from_millis(100)).await;

        // Parent token should be cancelled immediately via monitor task
        assert!(
            parent_token.is_cancelled(),
            "Parent token should be cancelled immediately when critical task panics"
        );
        assert!(handle.join().await.is_err());
    }

    #[tokio::test]
    async fn test_error_triggers_immediate_parent_cancellation() {
        // Test: Verify that regular errors also trigger parent cancellation immediately
        // Verifies:
        // - Parent token cancellation happens immediately when task returns error
        // - No need to call join() for critical failure detection
        // - Both panics AND regular errors trigger immediate system shutdown
        // - Demonstrates consistent critical failure behavior
        let parent_token = CancellationToken::new();

        let handle = CriticalTaskExecutionHandle::new(
            |_cancel_token| async move {
                tokio::time::sleep(Duration::from_millis(50)).await;
                anyhow::bail!("Critical error!");
            },
            parent_token.clone(),
            "immediate-error-task",
        )
        .unwrap();

        // Don't call join() - just wait for the error to be detected
        tokio::time::sleep(Duration::from_millis(100)).await;

        // Parent token should be cancelled even though we didn't call join()
        assert!(
            parent_token.is_cancelled(),
            "Parent token should be cancelled immediately when critical task errors"
        );
        assert!(handle.join().await.is_err());
    }

    #[tokio::test]
    #[should_panic(expected = "Critical task was not detached prior to drop!")]
    async fn test_task_detach() {
        // Test: Dropping a handle that was neither joined nor detached
        // Verifies:
        // - The Drop impl panics with its dedicated message (not some unrelated panic)
        let parent_token = CancellationToken::new();
        let _handle = CriticalTaskExecutionHandle::new(
            |_cancel_token| async move { Ok(()) },
            parent_token,
            "test-detach-task",
        )
        .unwrap();

        // Dropping without detaching should panic
    }

    #[tokio::test]
    async fn test_detached_task_runs_to_completion() {
        // Test: A handle that is explicitly detached
        // Verifies:
        // - detach() consumes the handle without panicking on drop
        // - The task keeps running after the handle is gone
        // - A successful detached task leaves the parent token uncancelled
        let parent_token = CancellationToken::new();
        let (done_sender, done_receiver) = oneshot::channel();

        CriticalTaskExecutionHandle::new(
            |_cancel_token| async move {
                let _ = done_sender.send(());
                Ok(())
            },
            parent_token.clone(),
            "test-detached-task",
        )
        .unwrap()
        .detach();

        timeout(Duration::from_secs(1), done_receiver)
            .await
            .expect("detached task should finish promptly")
            .expect("detached task should signal completion");
        assert!(!parent_token.is_cancelled());
    }
}