dynamo_runtime/utils/tasks/critical.rs
1// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3//
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8// http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16//! Utilities for handling tasks.
17
18use anyhow::{Context, Result};
19use std::future::Future;
20use tokio::runtime::Handle;
21use tokio::sync::oneshot;
22use tokio::task::JoinHandle;
23use tokio_util::sync::CancellationToken;
24
25/// Type alias for a critical task handler function.
26///
27/// The handler receives a [CancellationToken] and returns a [Future] that resolves to [Result<()>].
28/// The task should monitor the cancellation token and gracefully shut down when it's cancelled.
29pub type CriticalTaskHandler<Fut> = dyn FnOnce(CancellationToken) -> Fut + Send + 'static;
30
31/// The [CriticalTaskExecutionHandle] is a handle for a critical task that is expected to
32/// complete successfully. This handle provides two cancellation mechanisms:
33///
34/// 1. **Critical Failure**: If the task returns an error or panics, the parent cancellation
35/// token is triggered immediately via a monitoring task that detects failures.
36///
37/// 2. **Graceful Shutdown**: The task can be gracefully shut down via its child token,
38/// allowing it to complete cleanly without triggering system-wide cancellation.
39///
40/// This is useful for ensuring that critical detached tasks either complete successfully
41/// or trigger appropriate shutdown procedures when they fail.
42pub struct CriticalTaskExecutionHandle {
43 monitor_task: JoinHandle<()>,
44 graceful_shutdown_token: CancellationToken,
45 result_receiver: Option<oneshot::Receiver<Result<()>>>,
46 detached: bool,
47}
48
49impl CriticalTaskExecutionHandle {
50 pub fn new<Fut>(
51 task_fn: impl FnOnce(CancellationToken) -> Fut + Send + 'static,
52 parent_token: CancellationToken,
53 description: &str,
54 ) -> Result<Self>
55 where
56 Fut: Future<Output = Result<()>> + Send + 'static,
57 {
58 Self::new_with_runtime(task_fn, parent_token, description, &Handle::try_current()?)
59 }
60
61 /// Create a new [CriticalTaskExecutionHandle] for a critical task.
62 ///
63 /// # Arguments
64 /// * `task_fn` - A function that takes a cancellation token and returns the critical task future
65 /// * `parent_token` - Token that will be cancelled if this critical task fails
66 /// * `description` - Description for logging purposes
67 /// * `runtime` - The runtime to use for the task.
68 pub fn new_with_runtime<Fut>(
69 task_fn: impl FnOnce(CancellationToken) -> Fut + Send + 'static,
70 parent_token: CancellationToken,
71 description: &str,
72 runtime: &Handle,
73 ) -> Result<Self>
74 where
75 Fut: Future<Output = Result<()>> + Send + 'static,
76 {
77 let graceful_shutdown_token = parent_token.child_token();
78 let description = description.to_string();
79 let parent_token_clone = parent_token.clone();
80
81 // Create channel for communicating results from monitor to handle
82 let (result_sender, result_receiver) = oneshot::channel();
83
84 let graceful_shutdown_token_clone = graceful_shutdown_token.clone();
85 let description_clone = description.to_string();
86 let task = runtime.spawn(async move {
87 let future = task_fn(graceful_shutdown_token_clone);
88
89 match future.await {
90 Ok(()) => {
91 tracing::debug!(
92 "Critical task '{}' completed successfully",
93 description_clone
94 );
95 Ok(())
96 }
97 Err(e) => {
98 tracing::error!("Critical task '{}' failed: {:#}", description_clone, e);
99 Err(e.context(format!("Critical task '{}' failed", description_clone)))
100 }
101 }
102 });
103
104 // Spawn monitor task that immediately joins the main task and detects failures
105 let monitor_task = {
106 let main_task_handle = task;
107 let parent_token_monitor = parent_token_clone.clone();
108 let description_monitor = description.clone();
109
110 runtime.spawn(async move {
111 let result = match main_task_handle.await {
112 Ok(task_result) => {
113 // Task completed normally (success or error)
114 if task_result.is_err() {
115 // Error - trigger parent cancellation immediately
116 parent_token_monitor.cancel();
117 }
118 task_result
119 }
120 Err(join_error) => {
121 // Task panicked - handle immediately
122 if join_error.is_panic() {
123 let panic_msg = if let Ok(reason) = join_error.try_into_panic() {
124 if let Some(s) = reason.downcast_ref::<String>() {
125 s.clone()
126 } else if let Some(s) = reason.downcast_ref::<&str>() {
127 s.to_string()
128 } else {
129 "Unknown panic".to_string()
130 }
131 } else {
132 "Panic occurred but reason unavailable".to_string()
133 };
134
135 tracing::error!(
136 "Critical task '{}' panicked: {}",
137 description_monitor,
138 panic_msg
139 );
140 parent_token_monitor.cancel(); // Trigger parent cancellation immediately
141 Err(anyhow::anyhow!(
142 "Critical task '{}' panicked: {}",
143 description_monitor,
144 panic_msg
145 ))
146 } else {
147 parent_token_monitor.cancel();
148 Err(anyhow::anyhow!(
149 "Failed to join critical task '{}': {}",
150 description_monitor,
151 join_error
152 ))
153 }
154 }
155 };
156
157 // Send result to handle (ignore if receiver dropped)
158 let _ = result_sender.send(result);
159 })
160 };
161
162 Ok(Self {
163 monitor_task,
164 graceful_shutdown_token,
165 result_receiver: Some(result_receiver),
166 detached: false,
167 })
168 }
169
170 /// Check if the task awaiting on the [Server]s background event loop has finished.
171 pub fn is_finished(&self) -> bool {
172 self.monitor_task.is_finished()
173 }
174
175 /// Check if the server's event loop has been cancelled.
176 pub fn is_cancelled(&self) -> bool {
177 self.graceful_shutdown_token.is_cancelled()
178 }
179
180 /// Gracefully cancel this critical task without triggering system-wide shutdown.
181 ///
182 /// This signals the task to stop processing and exit cleanly. The task should
183 /// monitor its cancellation token and respond appropriately.
184 ///
185 /// This will not propagate to the parent [CancellationToken] unless an error
186 /// occurs during the shutdown process.
187 pub fn cancel(&self) {
188 self.graceful_shutdown_token.cancel();
189 }
190
191 /// Join on the critical task and return its actual result.
192 ///
193 /// This will return:
194 /// - `Ok(())` if the task completed successfully or was gracefully cancelled
195 /// - `Err(...)` if the task failed or panicked, preserving the original error
196 ///
197 /// Note: Both errors and panics trigger parent cancellation immediately via the monitor task.
198 pub async fn join(mut self) -> Result<()> {
199 self.detached = true;
200
201 match self.result_receiver.take().unwrap().await {
202 Ok(task_result) => task_result,
203 Err(_) => {
204 // This should rarely happen - means monitor task was dropped/cancelled
205 Err(anyhow::anyhow!("Critical task monitor was cancelled"))
206 }
207 }
208 }
209
210 /// Detach the task. This allows the task to continue running after the handle is dropped.
211 pub fn detach(mut self) {
212 self.detached = true;
213 }
214}
215
216impl Drop for CriticalTaskExecutionHandle {
217 fn drop(&mut self) {
218 if !self.detached {
219 panic!("Critical task was not detached prior to drop!");
220 }
221 }
222}
223
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicBool, AtomicU32, Ordering};
    use std::time::Duration;
    use tokio::time::timeout;

    #[tokio::test]
    async fn test_successful_task_completion() {
        // Test: A critical task that completes successfully without any issues
        // Verifies:
        // - Task executes and completes normally
        // - Result is Ok(())
        // - Parent token remains uncancelled (no critical failure)
        // - Task execution side effects occur (work gets done)
        let parent_token = CancellationToken::new();
        let completed = Arc::new(AtomicBool::new(false));
        let completed_clone = completed.clone();

        let handle = CriticalTaskExecutionHandle::new(
            |_cancel_token| async move {
                completed_clone.store(true, Ordering::SeqCst);
                Ok(())
            },
            parent_token.clone(),
            "test-success-task",
        )
        .unwrap();

        // Task should complete successfully
        let result = handle.join().await;
        assert!(result.is_ok());
        assert!(completed.load(Ordering::SeqCst));
        assert!(!parent_token.is_cancelled());
    }

    #[tokio::test]
    async fn test_task_failure_cancels_parent_token() {
        // Test: A critical task that returns an error (critical failure)
        // Verifies:
        // - Task error is properly propagated to caller
        // - Parent cancellation token is triggered (critical failure behavior)
        // - Error message is preserved and includes context
        // - Demonstrates the "critical" aspect - failures affect the entire system
        let parent_token = CancellationToken::new();

        let handle = CriticalTaskExecutionHandle::new(
            |_cancel_token| async move {
                anyhow::bail!("Critical task failed!");
            },
            parent_token.clone(),
            "test-failure-task",
        )
        .unwrap();

        // Task should fail and cancel parent token
        let result = handle.join().await;
        assert!(result.is_err());
        let error_msg = result.unwrap_err().to_string();
        // Check that the error contains either the original message or the context
        assert!(
            error_msg.contains("Critical task failed!")
                || error_msg.contains("Critical task 'test-failure-task' failed"),
            "Error message should contain failure context: {}",
            error_msg
        );

        // The monitor cancels the parent token before it publishes the result, so the
        // cancellation is already visible once join() has returned.
        assert!(parent_token.is_cancelled());
    }

    #[tokio::test]
    async fn test_task_panic_is_caught_and_reported() {
        // Test: A critical task that panics during execution
        // Verifies:
        // - Tokio's JoinHandle catches panics automatically
        // - Panics are converted to proper Error types
        // - System doesn't crash, panic is contained
        // - Error message indicates a panic occurred
        // - Parent token is cancelled (panic is treated as critical failure)
        // - Demonstrates panic safety of the critical task system
        let parent_token = CancellationToken::new();

        let handle = CriticalTaskExecutionHandle::new(
            |_cancel_token| async move {
                panic!("Something went terribly wrong!");
            },
            parent_token.clone(),
            "test-panic-task",
        )
        .unwrap();

        // Panic should be caught and converted to error
        let result = handle.join().await;
        assert!(result.is_err());
        let error_msg = result.unwrap_err().to_string();
        assert!(
            error_msg.contains("panicked") || error_msg.contains("panic"),
            "Error message should indicate a panic occurred: {}",
            error_msg
        );

        // Parent token should be cancelled due to panic (critical failure)
        assert!(parent_token.is_cancelled());
    }

    #[tokio::test]
    async fn test_graceful_shutdown_via_cancellation_token() {
        // Test: A long-running task that responds to graceful shutdown signals
        // Verifies:
        // - Task can monitor its cancellation token and stop early
        // - Graceful cancellation does NOT trigger parent token cancellation
        // - Task can do partial work before stopping
        // - handle.cancel() triggers the child token, not parent token
        // - Demonstrates proper shutdown patterns for long-running tasks
        let parent_token = CancellationToken::new();
        let work_done = Arc::new(AtomicU32::new(0));
        let work_done_clone = work_done.clone();

        let handle = CriticalTaskExecutionHandle::new(
            |cancel_token| async move {
                for i in 0..100 {
                    if cancel_token.is_cancelled() {
                        break;
                    }
                    work_done_clone.store(i, Ordering::SeqCst);
                    tokio::time::sleep(Duration::from_millis(10)).await;
                }
                Ok(())
            },
            parent_token.clone(),
            "test-graceful-shutdown",
        )
        .unwrap();

        // Let task do some work
        tokio::time::sleep(Duration::from_millis(50)).await;

        // Request graceful shutdown
        handle.cancel();

        // Task should complete gracefully
        let result = handle.join().await;
        assert!(result.is_ok());

        // Task should have done some work but not all
        let final_work = work_done.load(Ordering::SeqCst);
        assert!(final_work > 0);
        assert!(final_work < 99);

        // Parent token should NOT be cancelled (graceful shutdown)
        assert!(!parent_token.is_cancelled());
    }

    #[tokio::test]
    async fn test_multiple_critical_tasks_one_failure() {
        // Test: Multiple critical tasks sharing a parent token, one fails
        // Verifies:
        // - Multiple critical tasks can share the same parent cancellation token
        // - When one critical task fails, all related tasks receive cancellation signal
        // - Tasks can respond to cancellation and stop gracefully
        // - System-wide shutdown behavior when critical components fail
        // - Demonstrates coordinated shutdown of related services
        let parent_token = CancellationToken::new();
        let task1_completed = Arc::new(AtomicBool::new(false));
        let task2_completed = Arc::new(AtomicBool::new(false));

        let task1_completed_clone = task1_completed.clone();
        let task2_completed_clone = task2_completed.clone();

        // Start two critical tasks
        let handle1 = CriticalTaskExecutionHandle::new(
            |cancel_token| async move {
                for _ in 0..50 {
                    if cancel_token.is_cancelled() {
                        return Ok(());
                    }
                    tokio::time::sleep(Duration::from_millis(10)).await;
                }
                task1_completed_clone.store(true, Ordering::SeqCst);
                Ok(())
            },
            parent_token.clone(),
            "long-running-task",
        )
        .unwrap();

        let handle2 = CriticalTaskExecutionHandle::new(
            |_cancel_token| async move {
                tokio::time::sleep(Duration::from_millis(100)).await;
                task2_completed_clone.store(true, Ordering::SeqCst);
                anyhow::bail!("Task 2 failed!");
            },
            parent_token.clone(),
            "failing-task",
        )
        .unwrap();

        // Wait for task 2 to fail
        let result2 = handle2.join().await;
        assert!(result2.is_err());

        // Parent token should be cancelled due to task 2 failure
        assert!(parent_token.is_cancelled());

        // Task 1 should complete early due to cancellation
        let result1 = handle1.join().await;
        assert!(result1.is_ok());
        assert!(!task1_completed.load(Ordering::SeqCst)); // Should not have completed normally
    }

    #[tokio::test]
    async fn test_status_checking_methods() {
        // Test: Non-blocking status checking methods on the handle
        // Verifies:
        // - is_finished() accurately reports task completion status
        // - is_cancelled() accurately reports cancellation status
        // - Status methods work before and after cancellation
        // - Methods are non-blocking and can be called multiple times
        // - Demonstrates monitoring patterns for task supervision
        let parent_token = CancellationToken::new();

        let handle = CriticalTaskExecutionHandle::new(
            |cancel_token| async move {
                tokio::time::sleep(Duration::from_millis(100)).await;
                if cancel_token.is_cancelled() {
                    return Ok(());
                }
                tokio::time::sleep(Duration::from_millis(100)).await;
                Ok(())
            },
            parent_token,
            "status-test-task",
        )
        .unwrap();

        // Initially task should be running
        assert!(!handle.is_finished());
        assert!(!handle.is_cancelled());

        // Cancel the task
        handle.cancel();

        // Task should now be cancelled but may not be finished yet
        assert!(handle.is_cancelled());

        // Wait for completion
        let result = handle.join().await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_task_with_select_pattern() {
        // Test: Task using tokio::select! for cancellation-aware operations
        // Verifies:
        // - Tasks can use idiomatic tokio patterns with cancellation tokens
        // - select! allows racing between work and cancellation
        // - Cancellation interrupts work immediately, not just at check points
        // - Demonstrates recommended pattern for responsive cancellation
        // - Shows how to handle cancellation in the middle of async operations
        let parent_token = CancellationToken::new();
        let work_completed = Arc::new(AtomicBool::new(false));
        let work_completed_clone = work_completed.clone();

        let handle = CriticalTaskExecutionHandle::new(
            |cancel_token| async move {
                tokio::select! {
                    _ = tokio::time::sleep(Duration::from_millis(200)) => {
                        work_completed_clone.store(true, Ordering::SeqCst);
                        Ok(())
                    }
                    _ = cancel_token.cancelled() => {
                        // Graceful shutdown requested
                        Ok(())
                    }
                }
            },
            parent_token,
            "select-pattern-task",
        )
        .unwrap();

        // Cancel after a short time
        tokio::time::sleep(Duration::from_millis(50)).await;
        handle.cancel();

        let result = handle.join().await;
        assert!(result.is_ok());
        assert!(!work_completed.load(Ordering::SeqCst)); // Should not have completed the work
    }

    #[tokio::test]
    async fn test_timeout_behavior() {
        // Test: External timeout vs task failure distinction
        // Verifies:
        // - External timeouts don't trigger parent token cancellation
        // - Tasks continue running in background even after timeout
        // - Difference between "waiting timeout" and "task failure"
        // - Client-side timeout vs server-side failure handling
        // - Demonstrates that join() timeout != critical task failure
        let parent_token = CancellationToken::new();

        let handle = CriticalTaskExecutionHandle::new(
            |_cancel_token| async move {
                // A task that takes a long time
                tokio::time::sleep(Duration::from_secs(10)).await;
                Ok(())
            },
            parent_token.clone(),
            "long-task",
        )
        .unwrap();

        // Test with timeout
        let result = timeout(Duration::from_millis(100), handle.join()).await;
        assert!(result.is_err()); // Should timeout

        // The parent token should NOT be cancelled since task didn't fail
        // (it's still running in the background, but we timed out waiting for it)
        assert!(!parent_token.is_cancelled());
    }

    #[tokio::test]
    async fn test_panic_triggers_immediate_parent_cancellation() {
        // Test: Verify that panics trigger parent cancellation immediately via monitor task
        // Verifies:
        // - Monitor task detects panics immediately when they occur
        // - Parent token cancellation happens immediately, not on join()
        // - System shutdown is triggered as soon as critical task panics
        // - Demonstrates true "critical task" behavior with immediate failure propagation
        let parent_token = CancellationToken::new();

        let handle = CriticalTaskExecutionHandle::new(
            |_cancel_token| async move {
                tokio::time::sleep(Duration::from_millis(50)).await;
                panic!("Critical failure!");
            },
            parent_token.clone(),
            "immediate-panic-task",
        )
        .unwrap();

        // Wait for the panic to be detected by monitor task
        tokio::time::sleep(Duration::from_millis(100)).await;

        // Parent token should be cancelled immediately via monitor task
        assert!(
            parent_token.is_cancelled(),
            "Parent token should be cancelled immediately when critical task panics"
        );
        assert!(handle.join().await.is_err());
    }

    #[tokio::test]
    async fn test_error_triggers_immediate_parent_cancellation() {
        // Test: Verify that regular errors also trigger parent cancellation immediately
        // Verifies:
        // - Parent token cancellation happens immediately when task returns error
        // - No need to call join() for critical failure detection
        // - Both panics AND regular errors trigger immediate system shutdown
        // - Demonstrates consistent critical failure behavior
        let parent_token = CancellationToken::new();

        let handle = CriticalTaskExecutionHandle::new(
            |_cancel_token| async move {
                tokio::time::sleep(Duration::from_millis(50)).await;
                anyhow::bail!("Critical error!");
            },
            parent_token.clone(),
            "immediate-error-task",
        )
        .unwrap();

        // Don't call join() - just wait for the error to be detected
        tokio::time::sleep(Duration::from_millis(100)).await;

        // Parent token should be cancelled even though we didn't call join()
        assert!(
            parent_token.is_cancelled(),
            "Parent token should be cancelled immediately when critical task errors"
        );
        assert!(handle.join().await.is_err());
    }

    #[tokio::test]
    #[should_panic(expected = "Critical task was not detached prior to drop!")]
    async fn test_task_detach() {
        // Test: Dropping a handle that was neither joined nor detached
        // Verifies:
        // - The handle's Drop implementation panics with its dedicated message
        let parent_token = CancellationToken::new();
        let _handle = CriticalTaskExecutionHandle::new(
            |_cancel_token| async move { Ok(()) },
            parent_token,
            "test-detach-task",
        )
        .unwrap();

        // `_handle` goes out of scope here without join()/detach(), which must panic
    }

    #[tokio::test]
    async fn test_detached_task_runs_to_completion() {
        // Test: A handle that is explicitly detached
        // Verifies:
        // - detach() consumes the handle without panicking on drop
        // - The task keeps running in the background after the handle is gone
        // - A successful detached task does not cancel the parent token
        let parent_token = CancellationToken::new();
        let completed = Arc::new(AtomicBool::new(false));
        let completed_clone = completed.clone();
        let (done_tx, done_rx) = oneshot::channel::<()>();

        CriticalTaskExecutionHandle::new(
            |_cancel_token| async move {
                tokio::time::sleep(Duration::from_millis(20)).await;
                completed_clone.store(true, Ordering::SeqCst);
                let _ = done_tx.send(());
                Ok(())
            },
            parent_token.clone(),
            "test-detached-task",
        )
        .unwrap()
        .detach();

        // Wait for the detached task to signal completion
        timeout(Duration::from_secs(5), done_rx)
            .await
            .expect("detached task should finish in time")
            .expect("detached task should signal completion");

        assert!(completed.load(Ordering::SeqCst));
        assert!(!parent_token.is_cancelled());
    }
}
623}