// dynamo_runtime/utils/tasks/critical.rs

// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
15
//! Utilities for handling critical tasks, whose failure (error or panic) cancels a parent cancellation token.
17
18use anyhow::{Context, Result};
19use std::future::Future;
20use tokio::runtime::Handle;
21use tokio::sync::oneshot;
22use tokio::task::JoinHandle;
23use tokio_util::sync::CancellationToken;
24
25/// Type alias for a critical task handler function.
26///
27/// The handler receives a [CancellationToken] and returns a [Future] that resolves to [Result<()>].
28/// The task should monitor the cancellation token and gracefully shut down when it's cancelled.
29pub type CriticalTaskHandler<Fut> = dyn FnOnce(CancellationToken) -> Fut + Send + 'static;
30
31/// The [CriticalTaskExecutionHandle] is a handle for a critical task that is expected to
32/// complete successfully. This handle provides two cancellation mechanisms:
33///
34/// 1. **Critical Failure**: If the task returns an error or panics, the parent cancellation
35///    token is triggered immediately via a monitoring task that detects failures.
36///
37/// 2. **Graceful Shutdown**: The task can be gracefully shut down via its child token,
38///    allowing it to complete cleanly without triggering system-wide cancellation.
39///
40/// This is useful for ensuring that critical detached tasks either complete successfully
41/// or trigger appropriate shutdown procedures when they fail.
42pub struct CriticalTaskExecutionHandle {
43    monitor_task: JoinHandle<()>,
44    graceful_shutdown_token: CancellationToken,
45    result_receiver: Option<oneshot::Receiver<Result<()>>>,
46    detached: bool,
47}
48
49impl CriticalTaskExecutionHandle {
50    pub fn new<Fut>(
51        task_fn: impl FnOnce(CancellationToken) -> Fut + Send + 'static,
52        parent_token: CancellationToken,
53        description: &str,
54    ) -> Result<Self>
55    where
56        Fut: Future<Output = Result<()>> + Send + 'static,
57    {
58        Self::new_with_runtime(task_fn, parent_token, description, &Handle::try_current()?)
59    }
60
61    /// Create a new [CriticalTaskExecutionHandle] for a critical task.
62    ///
63    /// # Arguments
64    /// * `task_fn` - A function that takes a cancellation token and returns the critical task future
65    /// * `parent_token` - Token that will be cancelled if this critical task fails
66    /// * `description` - Description for logging purposes
67    /// * `runtime` - The runtime to use for the task.
68    pub fn new_with_runtime<Fut>(
69        task_fn: impl FnOnce(CancellationToken) -> Fut + Send + 'static,
70        parent_token: CancellationToken,
71        description: &str,
72        runtime: &Handle,
73    ) -> Result<Self>
74    where
75        Fut: Future<Output = Result<()>> + Send + 'static,
76    {
77        let graceful_shutdown_token = parent_token.child_token();
78        let description = description.to_string();
79        let parent_token_clone = parent_token.clone();
80
81        // Create channel for communicating results from monitor to handle
82        let (result_sender, result_receiver) = oneshot::channel();
83
84        let graceful_shutdown_token_clone = graceful_shutdown_token.clone();
85        let description_clone = description.to_string();
86        let task = runtime.spawn(async move {
87            let future = task_fn(graceful_shutdown_token_clone);
88
89            match future.await {
90                Ok(()) => {
91                    tracing::debug!(
92                        "Critical task '{}' completed successfully",
93                        description_clone
94                    );
95                    Ok(())
96                }
97                Err(e) => {
98                    tracing::error!("Critical task '{}' failed: {:#}", description_clone, e);
99                    Err(e.context(format!("Critical task '{}' failed", description_clone)))
100                }
101            }
102        });
103
104        // Spawn monitor task that immediately joins the main task and detects failures
105        let monitor_task = {
106            let main_task_handle = task;
107            let parent_token_monitor = parent_token_clone.clone();
108            let description_monitor = description.clone();
109
110            runtime.spawn(async move {
111                let result = match main_task_handle.await {
112                    Ok(task_result) => {
113                        // Task completed normally (success or error)
114                        if task_result.is_err() {
115                            // Error - trigger parent cancellation immediately
116                            parent_token_monitor.cancel();
117                        }
118                        task_result
119                    }
120                    Err(join_error) => {
121                        // Task panicked - handle immediately
122                        if join_error.is_panic() {
123                            let panic_msg = if let Ok(reason) = join_error.try_into_panic() {
124                                if let Some(s) = reason.downcast_ref::<String>() {
125                                    s.clone()
126                                } else if let Some(s) = reason.downcast_ref::<&str>() {
127                                    s.to_string()
128                                } else {
129                                    "Unknown panic".to_string()
130                                }
131                            } else {
132                                "Panic occurred but reason unavailable".to_string()
133                            };
134
135                            tracing::error!(
136                                "Critical task '{}' panicked: {}",
137                                description_monitor,
138                                panic_msg
139                            );
140                            parent_token_monitor.cancel(); // Trigger parent cancellation immediately
141                            Err(anyhow::anyhow!(
142                                "Critical task '{}' panicked: {}",
143                                description_monitor,
144                                panic_msg
145                            ))
146                        } else {
147                            parent_token_monitor.cancel();
148                            Err(anyhow::anyhow!(
149                                "Failed to join critical task '{}': {}",
150                                description_monitor,
151                                join_error
152                            ))
153                        }
154                    }
155                };
156
157                // Send result to handle (ignore if receiver dropped)
158                let _ = result_sender.send(result);
159            })
160        };
161
162        Ok(Self {
163            monitor_task,
164            graceful_shutdown_token,
165            result_receiver: Some(result_receiver),
166            detached: false,
167        })
168    }
169
170    /// Check if the task awaiting on the [Server]s background event loop has finished.
171    pub fn is_finished(&self) -> bool {
172        self.monitor_task.is_finished()
173    }
174
175    /// Check if the server's event loop has been cancelled.
176    pub fn is_cancelled(&self) -> bool {
177        self.graceful_shutdown_token.is_cancelled()
178    }
179
180    /// Gracefully cancel this critical task without triggering system-wide shutdown.
181    ///
182    /// This signals the task to stop processing and exit cleanly. The task should
183    /// monitor its cancellation token and respond appropriately.
184    ///
185    /// This will not propagate to the parent [CancellationToken] unless an error
186    /// occurs during the shutdown process.
187    pub fn cancel(&self) {
188        self.graceful_shutdown_token.cancel();
189    }
190
191    /// Join on the critical task and return its actual result.
192    ///
193    /// This will return:
194    /// - `Ok(())` if the task completed successfully or was gracefully cancelled
195    /// - `Err(...)` if the task failed or panicked, preserving the original error
196    ///
197    /// Note: Both errors and panics trigger parent cancellation immediately via the monitor task.
198    pub async fn join(mut self) -> Result<()> {
199        self.detached = true;
200
201        match self.result_receiver.take().unwrap().await {
202            Ok(task_result) => task_result,
203            Err(_) => {
204                // This should rarely happen - means monitor task was dropped/cancelled
205                Err(anyhow::anyhow!("Critical task monitor was cancelled"))
206            }
207        }
208    }
209
210    /// Detach the task. This allows the task to continue running after the handle is dropped.
211    pub fn detach(mut self) {
212        self.detached = true;
213    }
214}
215
216impl Drop for CriticalTaskExecutionHandle {
217    fn drop(&mut self) {
218        if !self.detached {
219            panic!("Critical task was not detached prior to drop!");
220        }
221    }
222}
223
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicBool, AtomicU32, Ordering};
    use std::time::Duration;
    use tokio::time::timeout;

    #[tokio::test]
    async fn test_successful_task_completion() {
        // Test: A critical task that completes successfully without any issues
        // Verifies:
        // - Task executes and completes normally
        // - Result is Ok(())
        // - Parent token remains uncancelled (no critical failure)
        // - Task execution side effects occur (work gets done)
        let parent_token = CancellationToken::new();
        let completed = Arc::new(AtomicBool::new(false));
        let completed_clone = completed.clone();

        let handle = CriticalTaskExecutionHandle::new(
            |_cancel_token| async move {
                completed_clone.store(true, Ordering::SeqCst);
                Ok(())
            },
            parent_token.clone(),
            "test-success-task",
        )
        .unwrap();

        // Task should complete successfully
        let result = handle.join().await;
        assert!(result.is_ok());
        assert!(completed.load(Ordering::SeqCst));
        assert!(!parent_token.is_cancelled());
    }

    #[tokio::test]
    async fn test_task_failure_cancels_parent_token() {
        // Test: A critical task that returns an error (critical failure)
        // Verifies:
        // - Task error is properly propagated to caller
        // - Parent cancellation token is triggered (critical failure behavior)
        // - Error message is preserved and includes context
        // - Demonstrates the "critical" aspect - failures affect the entire system
        let parent_token = CancellationToken::new();

        let handle = CriticalTaskExecutionHandle::new(
            |_cancel_token| async move {
                anyhow::bail!("Critical task failed!");
            },
            parent_token.clone(),
            "test-failure-task",
        )
        .unwrap();

        // Task should fail and cancel parent token
        let result = handle.join().await;
        assert!(result.is_err());
        let error_msg = result.unwrap_err().to_string();
        // Check that the error contains either the original message or the context
        assert!(
            error_msg.contains("Critical task failed!")
                || error_msg.contains("Critical task 'test-failure-task' failed"),
            "Error message should contain failure context: {}",
            error_msg
        );

        // The monitor cancels the parent token before it sends the result, so the
        // cancellation is already visible once `join` has returned.
        assert!(parent_token.is_cancelled());
    }

    #[tokio::test]
    async fn test_task_panic_is_caught_and_reported() {
        // Test: A critical task that panics during execution
        // Verifies:
        // - Tokio's JoinHandle catches panics automatically
        // - Panics are converted to proper Error types
        // - System doesn't crash, panic is contained
        // - Error message indicates a panic occurred
        // - Parent token is cancelled (panic is treated as critical failure)
        // - Demonstrates panic safety of the critical task system
        let parent_token = CancellationToken::new();

        let handle = CriticalTaskExecutionHandle::new(
            |_cancel_token| async move {
                panic!("Something went terribly wrong!");
            },
            parent_token.clone(),
            "test-panic-task",
        )
        .unwrap();

        // Panic should be caught and converted to error
        let result = handle.join().await;
        assert!(result.is_err());
        let error_msg = result.unwrap_err().to_string();
        assert!(
            error_msg.contains("panicked") || error_msg.contains("panic"),
            "Error message should indicate a panic occurred: {}",
            error_msg
        );

        // Parent token should be cancelled due to panic (critical failure)
        assert!(parent_token.is_cancelled());
    }

    #[tokio::test]
    async fn test_graceful_shutdown_via_cancellation_token() {
        // Test: A long-running task that responds to graceful shutdown signals
        // Verifies:
        // - Task can monitor its cancellation token and stop early
        // - Graceful cancellation does NOT trigger parent token cancellation
        // - Task can do partial work before stopping
        // - handle.cancel() triggers the child token, not parent token
        // - Demonstrates proper shutdown patterns for long-running tasks
        let parent_token = CancellationToken::new();
        let work_done = Arc::new(AtomicU32::new(0));
        let work_done_clone = work_done.clone();

        let handle = CriticalTaskExecutionHandle::new(
            |cancel_token| async move {
                for i in 0..100 {
                    if cancel_token.is_cancelled() {
                        break;
                    }
                    work_done_clone.store(i, Ordering::SeqCst);
                    tokio::time::sleep(Duration::from_millis(10)).await;
                }
                Ok(())
            },
            parent_token.clone(),
            "test-graceful-shutdown",
        )
        .unwrap();

        // Let task do some work
        tokio::time::sleep(Duration::from_millis(50)).await;

        // Request graceful shutdown
        handle.cancel();

        // Task should complete gracefully
        let result = handle.join().await;
        assert!(result.is_ok());

        // Task should have done some work but not all
        let final_work = work_done.load(Ordering::SeqCst);
        assert!(final_work > 0);
        assert!(final_work < 99);

        // Parent token should NOT be cancelled (graceful shutdown)
        assert!(!parent_token.is_cancelled());
    }

    #[tokio::test]
    async fn test_multiple_critical_tasks_one_failure() {
        // Test: Multiple critical tasks sharing a parent token, one fails
        // Verifies:
        // - Multiple critical tasks can share the same parent cancellation token
        // - When one critical task fails, all related tasks receive cancellation signal
        // - Tasks can respond to cancellation and stop gracefully
        // - System-wide shutdown behavior when critical components fail
        // - Demonstrates coordinated shutdown of related services
        let parent_token = CancellationToken::new();
        let task1_completed = Arc::new(AtomicBool::new(false));
        let task2_completed = Arc::new(AtomicBool::new(false));

        let task1_completed_clone = task1_completed.clone();
        let task2_completed_clone = task2_completed.clone();

        // Start two critical tasks
        let handle1 = CriticalTaskExecutionHandle::new(
            |cancel_token| async move {
                for _ in 0..50 {
                    if cancel_token.is_cancelled() {
                        return Ok(());
                    }
                    tokio::time::sleep(Duration::from_millis(10)).await;
                }
                task1_completed_clone.store(true, Ordering::SeqCst);
                Ok(())
            },
            parent_token.clone(),
            "long-running-task",
        )
        .unwrap();

        let handle2 = CriticalTaskExecutionHandle::new(
            |_cancel_token| async move {
                tokio::time::sleep(Duration::from_millis(100)).await;
                task2_completed_clone.store(true, Ordering::SeqCst);
                anyhow::bail!("Task 2 failed!");
            },
            parent_token.clone(),
            "failing-task",
        )
        .unwrap();

        // Wait for task 2 to fail
        let result2 = handle2.join().await;
        assert!(result2.is_err());

        // Parent token should be cancelled due to task 2 failure
        assert!(parent_token.is_cancelled());

        // Task 1 should complete early due to cancellation
        let result1 = handle1.join().await;
        assert!(result1.is_ok());
        assert!(!task1_completed.load(Ordering::SeqCst)); // Should not have completed normally
    }

    #[tokio::test]
    async fn test_status_checking_methods() {
        // Test: Non-blocking status checking methods on the handle
        // Verifies:
        // - is_finished() accurately reports task completion status
        // - is_cancelled() accurately reports cancellation status
        // - Status methods work before and after cancellation
        // - Methods are non-blocking and can be called multiple times
        // - Demonstrates monitoring patterns for task supervision
        let parent_token = CancellationToken::new();

        let handle = CriticalTaskExecutionHandle::new(
            |cancel_token| async move {
                tokio::time::sleep(Duration::from_millis(100)).await;
                if cancel_token.is_cancelled() {
                    return Ok(());
                }
                tokio::time::sleep(Duration::from_millis(100)).await;
                Ok(())
            },
            parent_token,
            "status-test-task",
        )
        .unwrap();

        // Initially task should be running
        assert!(!handle.is_finished());
        assert!(!handle.is_cancelled());

        // Cancel the task
        handle.cancel();

        // Task should now be cancelled but may not be finished yet
        assert!(handle.is_cancelled());

        // Wait for completion
        let result = handle.join().await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_task_with_select_pattern() {
        // Test: Task using tokio::select! for cancellation-aware operations
        // Verifies:
        // - Tasks can use idiomatic tokio patterns with cancellation tokens
        // - select! allows racing between work and cancellation
        // - Cancellation interrupts work immediately, not just at check points
        // - Demonstrates recommended pattern for responsive cancellation
        // - Shows how to handle cancellation in the middle of async operations
        let parent_token = CancellationToken::new();
        let work_completed = Arc::new(AtomicBool::new(false));
        let work_completed_clone = work_completed.clone();

        let handle = CriticalTaskExecutionHandle::new(
            |cancel_token| async move {
                tokio::select! {
                    _ = tokio::time::sleep(Duration::from_millis(200)) => {
                        work_completed_clone.store(true, Ordering::SeqCst);
                        Ok(())
                    }
                    _ = cancel_token.cancelled() => {
                        // Graceful shutdown requested
                        Ok(())
                    }
                }
            },
            parent_token,
            "select-pattern-task",
        )
        .unwrap();

        // Cancel after a short time
        tokio::time::sleep(Duration::from_millis(50)).await;
        handle.cancel();

        let result = handle.join().await;
        assert!(result.is_ok());
        assert!(!work_completed.load(Ordering::SeqCst)); // Should not have completed the work
    }

    #[tokio::test]
    async fn test_timeout_behavior() {
        // Test: External timeout vs task failure distinction
        // Verifies:
        // - External timeouts don't trigger parent token cancellation
        // - Tasks continue running in background even after timeout
        // - Difference between "waiting timeout" and "task failure"
        // - Client-side timeout vs server-side failure handling
        // - Demonstrates that join() timeout != critical task failure
        let parent_token = CancellationToken::new();

        let handle = CriticalTaskExecutionHandle::new(
            |_cancel_token| async move {
                // A task that takes a long time
                tokio::time::sleep(Duration::from_secs(10)).await;
                Ok(())
            },
            parent_token.clone(),
            "long-task",
        )
        .unwrap();

        // Test with timeout
        let result = timeout(Duration::from_millis(100), handle.join()).await;
        assert!(result.is_err()); // Should timeout

        // The parent token should NOT be cancelled since task didn't fail
        // (it's still running in the background, but we timed out waiting for it)
        assert!(!parent_token.is_cancelled());
    }

    #[tokio::test]
    async fn test_panic_triggers_immediate_parent_cancellation() {
        // Test: Verify that panics trigger parent cancellation immediately via monitor task
        // Verifies:
        // - Monitor task detects panics immediately when they occur
        // - Parent token cancellation happens immediately, not on join()
        // - System shutdown is triggered as soon as critical task panics
        // - Demonstrates true "critical task" behavior with immediate failure propagation
        let parent_token = CancellationToken::new();

        let handle = CriticalTaskExecutionHandle::new(
            |_cancel_token| async move {
                tokio::time::sleep(Duration::from_millis(50)).await;
                panic!("Critical failure!");
            },
            parent_token.clone(),
            "immediate-panic-task",
        )
        .unwrap();

        // Wait for the panic to be detected by monitor task
        tokio::time::sleep(Duration::from_millis(100)).await;

        // Parent token should be cancelled immediately via monitor task
        assert!(
            parent_token.is_cancelled(),
            "Parent token should be cancelled immediately when critical task panics"
        );
        assert!(handle.join().await.is_err());
    }

    #[tokio::test]
    async fn test_error_triggers_immediate_parent_cancellation() {
        // Test: Verify that regular errors also trigger parent cancellation immediately
        // Verifies:
        // - Parent token cancellation happens immediately when task returns error
        // - No need to call join() for critical failure detection
        // - Both panics AND regular errors trigger immediate system shutdown
        // - Demonstrates consistent critical failure behavior
        let parent_token = CancellationToken::new();

        let handle = CriticalTaskExecutionHandle::new(
            |_cancel_token| async move {
                tokio::time::sleep(Duration::from_millis(50)).await;
                anyhow::bail!("Critical error!");
            },
            parent_token.clone(),
            "immediate-error-task",
        )
        .unwrap();

        // Don't call join() - just wait for the error to be detected
        tokio::time::sleep(Duration::from_millis(100)).await;

        // Parent token should be cancelled even though we didn't call join()
        assert!(
            parent_token.is_cancelled(),
            "Parent token should be cancelled immediately when critical task errors"
        );
        assert!(handle.join().await.is_err());
    }

    #[tokio::test]
    #[should_panic]
    async fn test_task_detach() {
        // Dropping without detaching should panic
        let parent_token = CancellationToken::new();
        let _handle = CriticalTaskExecutionHandle::new(
            |_cancel_token| async move { Ok(()) },
            parent_token,
            "test-detach-task",
        )
        .unwrap();

        // Dropping without detaching should panic
    }
}