dynamo_runtime/utils/tasks/critical.rs

// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

//! Utilities for running critical tasks whose failure must trigger a wider shutdown.
5
6use anyhow::{Context, Result};
7use std::future::Future;
8use tokio::runtime::Handle;
9use tokio::sync::oneshot;
10use tokio::task::JoinHandle;
11use tokio_util::sync::CancellationToken;
12
13/// Type alias for a critical task handler function.
14///
15/// The handler receives a [CancellationToken] and returns a [Future] that resolves to [Result<()>].
16/// The task should monitor the cancellation token and gracefully shut down when it's cancelled.
17pub type CriticalTaskHandler<Fut> = dyn FnOnce(CancellationToken) -> Fut + Send + 'static;
18
19/// The [CriticalTaskExecutionHandle] is a handle for a critical task that is expected to
20/// complete successfully. This handle provides two cancellation mechanisms:
21///
22/// 1. **Critical Failure**: If the task returns an error or panics, the parent cancellation
23///    token is triggered immediately via a monitoring task that detects failures.
24///
25/// 2. **Graceful Shutdown**: The task can be gracefully shut down via its child token,
26///    allowing it to complete cleanly without triggering system-wide cancellation.
27///
28/// This is useful for ensuring that critical detached tasks either complete successfully
29/// or trigger appropriate shutdown procedures when they fail.
30pub struct CriticalTaskExecutionHandle {
31    monitor_task: JoinHandle<()>,
32    graceful_shutdown_token: CancellationToken,
33    result_receiver: Option<oneshot::Receiver<Result<()>>>,
34    detached: bool,
35}
36
37impl CriticalTaskExecutionHandle {
38    pub fn new<Fut>(
39        task_fn: impl FnOnce(CancellationToken) -> Fut + Send + 'static,
40        parent_token: CancellationToken,
41        description: &str,
42    ) -> Result<Self>
43    where
44        Fut: Future<Output = Result<()>> + Send + 'static,
45    {
46        Self::new_with_runtime(task_fn, parent_token, description, &Handle::try_current()?)
47    }
48
49    /// Create a new [CriticalTaskExecutionHandle] for a critical task.
50    ///
51    /// # Arguments
52    /// * `task_fn` - A function that takes a cancellation token and returns the critical task future
53    /// * `parent_token` - Token that will be cancelled if this critical task fails
54    /// * `description` - Description for logging purposes
55    /// * `runtime` - The runtime to use for the task.
56    pub fn new_with_runtime<Fut>(
57        task_fn: impl FnOnce(CancellationToken) -> Fut + Send + 'static,
58        parent_token: CancellationToken,
59        description: &str,
60        runtime: &Handle,
61    ) -> Result<Self>
62    where
63        Fut: Future<Output = Result<()>> + Send + 'static,
64    {
65        let graceful_shutdown_token = parent_token.child_token();
66        let description = description.to_string();
67        let parent_token_clone = parent_token.clone();
68
69        // Create channel for communicating results from monitor to handle
70        let (result_sender, result_receiver) = oneshot::channel();
71
72        let graceful_shutdown_token_clone = graceful_shutdown_token.clone();
73        let description_clone = description.to_string();
74        let task = runtime.spawn(async move {
75            let future = task_fn(graceful_shutdown_token_clone);
76
77            match future.await {
78                Ok(()) => {
79                    tracing::debug!(
80                        "Critical task '{}' completed successfully",
81                        description_clone
82                    );
83                    Ok(())
84                }
85                Err(e) => {
86                    tracing::error!("Critical task '{}' failed: {:#}", description_clone, e);
87                    Err(e.context(format!("Critical task '{}' failed", description_clone)))
88                }
89            }
90        });
91
92        // Spawn monitor task that immediately joins the main task and detects failures
93        let monitor_task = {
94            let main_task_handle = task;
95            let parent_token_monitor = parent_token_clone.clone();
96            let description_monitor = description.clone();
97
98            runtime.spawn(async move {
99                let result = match main_task_handle.await {
100                    Ok(task_result) => {
101                        // Task completed normally (success or error)
102                        if task_result.is_err() {
103                            // Error - trigger parent cancellation immediately
104                            parent_token_monitor.cancel();
105                        }
106                        task_result
107                    }
108                    Err(join_error) => {
109                        // Task panicked - handle immediately
110                        if join_error.is_panic() {
111                            let panic_msg = if let Ok(reason) = join_error.try_into_panic() {
112                                if let Some(s) = reason.downcast_ref::<String>() {
113                                    s.clone()
114                                } else if let Some(s) = reason.downcast_ref::<&str>() {
115                                    s.to_string()
116                                } else {
117                                    "Unknown panic".to_string()
118                                }
119                            } else {
120                                "Panic occurred but reason unavailable".to_string()
121                            };
122
123                            tracing::error!(
124                                "Critical task '{}' panicked: {}",
125                                description_monitor,
126                                panic_msg
127                            );
128                            parent_token_monitor.cancel(); // Trigger parent cancellation immediately
129                            Err(anyhow::anyhow!(
130                                "Critical task '{}' panicked: {}",
131                                description_monitor,
132                                panic_msg
133                            ))
134                        } else {
135                            parent_token_monitor.cancel();
136                            Err(anyhow::anyhow!(
137                                "Failed to join critical task '{}': {}",
138                                description_monitor,
139                                join_error
140                            ))
141                        }
142                    }
143                };
144
145                // Send result to handle (ignore if receiver dropped)
146                let _ = result_sender.send(result);
147            })
148        };
149
150        Ok(Self {
151            monitor_task,
152            graceful_shutdown_token,
153            result_receiver: Some(result_receiver),
154            detached: false,
155        })
156    }
157
158    /// Check if the task awaiting on the [Server]s background event loop has finished.
159    pub fn is_finished(&self) -> bool {
160        self.monitor_task.is_finished()
161    }
162
163    /// Check if the server's event loop has been cancelled.
164    pub fn is_cancelled(&self) -> bool {
165        self.graceful_shutdown_token.is_cancelled()
166    }
167
168    /// Gracefully cancel this critical task without triggering system-wide shutdown.
169    ///
170    /// This signals the task to stop processing and exit cleanly. The task should
171    /// monitor its cancellation token and respond appropriately.
172    ///
173    /// This will not propagate to the parent [CancellationToken] unless an error
174    /// occurs during the shutdown process.
175    pub fn cancel(&self) {
176        self.graceful_shutdown_token.cancel();
177    }
178
179    /// Join on the critical task and return its actual result.
180    ///
181    /// This will return:
182    /// - `Ok(())` if the task completed successfully or was gracefully cancelled
183    /// - `Err(...)` if the task failed or panicked, preserving the original error
184    ///
185    /// Note: Both errors and panics trigger parent cancellation immediately via the monitor task.
186    pub async fn join(mut self) -> Result<()> {
187        self.detached = true;
188
189        match self.result_receiver.take().unwrap().await {
190            Ok(task_result) => task_result,
191            Err(_) => {
192                // This should rarely happen - means monitor task was dropped/cancelled
193                Err(anyhow::anyhow!("Critical task monitor was cancelled"))
194            }
195        }
196    }
197
198    /// Detach the task. This allows the task to continue running after the handle is dropped.
199    pub fn detach(mut self) {
200        self.detached = true;
201    }
202}
203
204impl Drop for CriticalTaskExecutionHandle {
205    fn drop(&mut self) {
206        if !self.detached {
207            panic!("Critical task was not detached prior to drop!");
208        }
209    }
210}
211
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicBool, AtomicU32, Ordering};
    use std::time::Duration;
    use tokio::time::timeout;

    #[tokio::test]
    async fn test_successful_task_completion() {
        // Test: A critical task that completes successfully without any issues
        // Verifies:
        // - Task executes and completes normally
        // - Result is Ok(())
        // - Parent token remains uncancelled (no critical failure)
        // - Task execution side effects occur (work gets done)
        let parent_token = CancellationToken::new();
        let completed = Arc::new(AtomicBool::new(false));
        let completed_clone = completed.clone();

        let handle = CriticalTaskExecutionHandle::new(
            |_cancel_token| async move {
                completed_clone.store(true, Ordering::SeqCst);
                Ok(())
            },
            parent_token.clone(),
            "test-success-task",
        )
        .unwrap();

        // Task should complete successfully
        let result = handle.join().await;
        assert!(result.is_ok());
        assert!(completed.load(Ordering::SeqCst));
        assert!(!parent_token.is_cancelled());
    }

    #[tokio::test]
    async fn test_task_failure_cancels_parent_token() {
        // Test: A critical task that returns an error (critical failure)
        // Verifies:
        // - Task error is properly propagated to caller
        // - Parent cancellation token is triggered (critical failure behavior)
        // - Error message is preserved and includes context
        // - Demonstrates the "critical" aspect - failures affect the entire system
        let parent_token = CancellationToken::new();

        let handle = CriticalTaskExecutionHandle::new(
            |_cancel_token| async move {
                anyhow::bail!("Critical task failed!");
            },
            parent_token.clone(),
            "test-failure-task",
        )
        .unwrap();

        // Task should fail and cancel parent token
        let result = handle.join().await;
        assert!(result.is_err());
        let error_msg = result.unwrap_err().to_string();
        // Check that the error contains either the original message or the context
        assert!(
            error_msg.contains("Critical task failed!")
                || error_msg.contains("Critical task 'test-failure-task' failed"),
            "Error message should contain failure context: {}",
            error_msg
        );

        // No sleep needed: the monitor cancels the parent token before it sends the
        // result, so the cancellation is already visible once join() returns.
        assert!(parent_token.is_cancelled());
    }

    #[tokio::test]
    async fn test_task_panic_is_caught_and_reported() {
        // Test: A critical task that panics during execution
        // Verifies:
        // - Tokio's JoinHandle catches panics automatically
        // - Panics are converted to proper Error types
        // - System doesn't crash, panic is contained
        // - Error message indicates a panic occurred
        // - Parent token is cancelled (panic is treated as critical failure)
        // - Demonstrates panic safety of the critical task system
        let parent_token = CancellationToken::new();

        let handle = CriticalTaskExecutionHandle::new(
            |_cancel_token| async move {
                panic!("Something went terribly wrong!");
            },
            parent_token.clone(),
            "test-panic-task",
        )
        .unwrap();

        // Panic should be caught and converted to error
        let result = handle.join().await;
        assert!(result.is_err());
        let error_msg = result.unwrap_err().to_string();
        assert!(
            error_msg.contains("panicked") || error_msg.contains("panic"),
            "Error message should indicate a panic occurred: {}",
            error_msg
        );

        // Parent token should be cancelled due to panic (critical failure)
        assert!(parent_token.is_cancelled());
    }

    #[tokio::test]
    async fn test_graceful_shutdown_via_cancellation_token() {
        // Test: A long-running task that responds to graceful shutdown signals
        // Verifies:
        // - Task can monitor its cancellation token and stop early
        // - Graceful cancellation does NOT trigger parent token cancellation
        // - Task can do partial work before stopping
        // - handle.cancel() triggers the child token, not parent token
        // - Demonstrates proper shutdown patterns for long-running tasks
        let parent_token = CancellationToken::new();
        let work_done = Arc::new(AtomicU32::new(0));
        let work_done_clone = work_done.clone();

        let handle = CriticalTaskExecutionHandle::new(
            |cancel_token| async move {
                for i in 0..100 {
                    if cancel_token.is_cancelled() {
                        break;
                    }
                    work_done_clone.store(i, Ordering::SeqCst);
                    tokio::time::sleep(Duration::from_millis(10)).await;
                }
                Ok(())
            },
            parent_token.clone(),
            "test-graceful-shutdown",
        )
        .unwrap();

        // Let task do some work
        tokio::time::sleep(Duration::from_millis(50)).await;

        // Request graceful shutdown
        handle.cancel();

        // Task should complete gracefully
        let result = handle.join().await;
        assert!(result.is_ok());

        // Task should have done some work but not all
        let final_work = work_done.load(Ordering::SeqCst);
        assert!(final_work > 0);
        assert!(final_work < 99);

        // Parent token should NOT be cancelled (graceful shutdown)
        assert!(!parent_token.is_cancelled());
    }

    #[tokio::test]
    async fn test_multiple_critical_tasks_one_failure() {
        // Test: Multiple critical tasks sharing a parent token, one fails
        // Verifies:
        // - Multiple critical tasks can share the same parent cancellation token
        // - When one critical task fails, all related tasks receive cancellation signal
        // - Tasks can respond to cancellation and stop gracefully
        // - System-wide shutdown behavior when critical components fail
        // - Demonstrates coordinated shutdown of related services
        let parent_token = CancellationToken::new();
        let task1_completed = Arc::new(AtomicBool::new(false));
        let task2_completed = Arc::new(AtomicBool::new(false));

        let task1_completed_clone = task1_completed.clone();
        let task2_completed_clone = task2_completed.clone();

        // Start two critical tasks
        let handle1 = CriticalTaskExecutionHandle::new(
            |cancel_token| async move {
                for _ in 0..50 {
                    if cancel_token.is_cancelled() {
                        return Ok(());
                    }
                    tokio::time::sleep(Duration::from_millis(10)).await;
                }
                task1_completed_clone.store(true, Ordering::SeqCst);
                Ok(())
            },
            parent_token.clone(),
            "long-running-task",
        )
        .unwrap();

        let handle2 = CriticalTaskExecutionHandle::new(
            |_cancel_token| async move {
                tokio::time::sleep(Duration::from_millis(100)).await;
                task2_completed_clone.store(true, Ordering::SeqCst);
                anyhow::bail!("Task 2 failed!");
            },
            parent_token.clone(),
            "failing-task",
        )
        .unwrap();

        // Wait for task 2 to fail
        let result2 = handle2.join().await;
        assert!(result2.is_err());

        // Parent token should be cancelled due to task 2 failure
        assert!(parent_token.is_cancelled());

        // Task 1 should complete early due to cancellation
        let result1 = handle1.join().await;
        assert!(result1.is_ok());
        assert!(!task1_completed.load(Ordering::SeqCst)); // Should not have completed normally
    }

    #[tokio::test]
    async fn test_status_checking_methods() {
        // Test: Non-blocking status checking methods on the handle
        // Verifies:
        // - is_finished() accurately reports task completion status
        // - is_cancelled() accurately reports cancellation status
        // - Status methods work before and after cancellation
        // - Methods are non-blocking and can be called multiple times
        // - Demonstrates monitoring patterns for task supervision
        let parent_token = CancellationToken::new();

        let handle = CriticalTaskExecutionHandle::new(
            |cancel_token| async move {
                tokio::time::sleep(Duration::from_millis(100)).await;
                if cancel_token.is_cancelled() {
                    return Ok(());
                }
                tokio::time::sleep(Duration::from_millis(100)).await;
                Ok(())
            },
            parent_token,
            "status-test-task",
        )
        .unwrap();

        // Initially task should be running
        assert!(!handle.is_finished());
        assert!(!handle.is_cancelled());

        // Cancel the task
        handle.cancel();

        // Task should now be cancelled but may not be finished yet
        assert!(handle.is_cancelled());

        // Wait for completion
        let result = handle.join().await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_task_with_select_pattern() {
        // Test: Task using tokio::select! for cancellation-aware operations
        // Verifies:
        // - Tasks can use idiomatic tokio patterns with cancellation tokens
        // - select! allows racing between work and cancellation
        // - Cancellation interrupts work immediately, not just at check points
        // - Demonstrates recommended pattern for responsive cancellation
        // - Shows how to handle cancellation in the middle of async operations
        let parent_token = CancellationToken::new();
        let work_completed = Arc::new(AtomicBool::new(false));
        let work_completed_clone = work_completed.clone();

        let handle = CriticalTaskExecutionHandle::new(
            |cancel_token| async move {
                tokio::select! {
                    _ = tokio::time::sleep(Duration::from_millis(200)) => {
                        work_completed_clone.store(true, Ordering::SeqCst);
                        Ok(())
                    }
                    _ = cancel_token.cancelled() => {
                        // Graceful shutdown requested
                        Ok(())
                    }
                }
            },
            parent_token,
            "select-pattern-task",
        )
        .unwrap();

        // Cancel after a short time
        tokio::time::sleep(Duration::from_millis(50)).await;
        handle.cancel();

        let result = handle.join().await;
        assert!(result.is_ok());
        assert!(!work_completed.load(Ordering::SeqCst)); // Should not have completed the work
    }

    #[tokio::test]
    async fn test_timeout_behavior() {
        // Test: External timeout vs task failure distinction
        // Verifies:
        // - External timeouts don't trigger parent token cancellation
        // - Tasks continue running in background even after timeout
        // - Difference between "waiting timeout" and "task failure"
        // - Client-side timeout vs server-side failure handling
        // - Demonstrates that join() timeout != critical task failure
        let parent_token = CancellationToken::new();

        let handle = CriticalTaskExecutionHandle::new(
            |_cancel_token| async move {
                // A task that takes a long time
                tokio::time::sleep(Duration::from_secs(10)).await;
                Ok(())
            },
            parent_token.clone(),
            "long-task",
        )
        .unwrap();

        // Test with timeout
        let result = timeout(Duration::from_millis(100), handle.join()).await;
        assert!(result.is_err()); // Should timeout

        // The parent token should NOT be cancelled since task didn't fail
        // (it's still running in the background, but we timed out waiting for it)
        assert!(!parent_token.is_cancelled());
    }

    #[tokio::test]
    async fn test_panic_triggers_immediate_parent_cancellation() {
        // Test: Verify that panics trigger parent cancellation immediately via monitor task
        // Verifies:
        // - Monitor task detects panics immediately when they occur
        // - Parent token cancellation happens immediately, not on join()
        // - System shutdown is triggered as soon as critical task panics
        // - Demonstrates true "critical task" behavior with immediate failure propagation
        let parent_token = CancellationToken::new();

        let handle = CriticalTaskExecutionHandle::new(
            |_cancel_token| async move {
                tokio::time::sleep(Duration::from_millis(50)).await;
                panic!("Critical failure!");
            },
            parent_token.clone(),
            "immediate-panic-task",
        )
        .unwrap();

        // Wait for the panic to be detected by monitor task
        tokio::time::sleep(Duration::from_millis(100)).await;

        // Parent token should be cancelled immediately via monitor task
        assert!(
            parent_token.is_cancelled(),
            "Parent token should be cancelled immediately when critical task panics"
        );
        assert!(handle.join().await.is_err());
    }

    #[tokio::test]
    async fn test_error_triggers_immediate_parent_cancellation() {
        // Test: Verify that regular errors also trigger parent cancellation immediately
        // Verifies:
        // - Parent token cancellation happens immediately when task returns error
        // - No need to call join() for critical failure detection
        // - Both panics AND regular errors trigger immediate system shutdown
        // - Demonstrates consistent critical failure behavior
        let parent_token = CancellationToken::new();

        let handle = CriticalTaskExecutionHandle::new(
            |_cancel_token| async move {
                tokio::time::sleep(Duration::from_millis(50)).await;
                anyhow::bail!("Critical error!");
            },
            parent_token.clone(),
            "immediate-error-task",
        )
        .unwrap();

        // Don't call join() - just wait for the error to be detected
        tokio::time::sleep(Duration::from_millis(100)).await;

        // Parent token should be cancelled even though we didn't call join()
        assert!(
            parent_token.is_cancelled(),
            "Parent token should be cancelled immediately when critical task errors"
        );
        assert!(handle.join().await.is_err());
    }

    #[tokio::test]
    #[should_panic]
    async fn test_task_detach() {
        // Dropping without detaching should panic
        let parent_token = CancellationToken::new();
        let _handle = CriticalTaskExecutionHandle::new(
            |_cancel_token| async move { Ok(()) },
            parent_token,
            "test-detach-task",
        )
        .unwrap();

        // Dropping without detaching should panic
    }
}