Skip to main content

dynamo_runtime/utils/
graceful_shutdown.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use std::sync::Arc;
5use std::sync::atomic::{AtomicUsize, Ordering};
6use tokio::sync::Notify;
7
8/// Tracks graceful shutdown state for endpoints
9pub struct GracefulShutdownTracker {
10    active_endpoints: AtomicUsize,
11    shutdown_complete: Notify,
12}
13
14/// RAII handle that holds a `GracefulShutdownTracker` registration. Drop
15/// it to release the registration. Used by long-running shutdown
16/// orchestrators (e.g. backend `Worker`) to keep `Runtime::shutdown`'s
17/// Phase 2 wait alive until they finish — otherwise Phase 3 cancels the
18/// main token and tears down NATS/etcd while the orchestrator is still
19/// running drain/cleanup.
20pub struct GracefulTaskGuard {
21    tracker: Arc<GracefulShutdownTracker>,
22}
23
24impl Drop for GracefulTaskGuard {
25    fn drop(&mut self) {
26        self.tracker.unregister_endpoint();
27    }
28}
29
30impl std::fmt::Debug for GracefulShutdownTracker {
31    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
32        f.debug_struct("GracefulShutdownTracker")
33            .field(
34                "active_endpoints",
35                &self.active_endpoints.load(Ordering::SeqCst),
36            )
37            .finish()
38    }
39}
40
41impl GracefulShutdownTracker {
42    pub(crate) fn new() -> Self {
43        Self {
44            active_endpoints: AtomicUsize::new(0),
45            shutdown_complete: Notify::new(),
46        }
47    }
48
49    /// Acquire a guard that participates in the graceful-shutdown wait.
50    /// `Runtime::shutdown`'s Phase 2 will not advance to Phase 3 (NATS/etcd
51    /// teardown) until every outstanding guard is dropped.
52    pub fn register_task(self: &Arc<Self>) -> GracefulTaskGuard {
53        self.register_endpoint();
54        GracefulTaskGuard {
55            tracker: self.clone(),
56        }
57    }
58
59    pub(crate) fn register_endpoint(&self) {
60        let count = self.active_endpoints.fetch_add(1, Ordering::SeqCst);
61        tracing::debug!(
62            "Endpoint registered, total active: {} -> {}",
63            count,
64            count + 1
65        );
66    }
67
68    pub(crate) fn unregister_endpoint(&self) {
69        let prev = self.active_endpoints.fetch_sub(1, Ordering::SeqCst);
70        tracing::debug!(
71            "Endpoint unregistered, remaining active: {} -> {}",
72            prev,
73            prev - 1
74        );
75        if prev == 1 {
76            // Last endpoint completed
77            tracing::info!("Last endpoint completed, notifying all waiters");
78            self.shutdown_complete.notify_waiters();
79        }
80    }
81
82    /// Get the current count of active endpoints
83    pub(crate) fn get_count(&self) -> usize {
84        self.active_endpoints.load(Ordering::Acquire)
85    }
86
87    pub(crate) async fn wait_for_completion(&self) {
88        loop {
89            // Create the waiter BEFORE checking the condition
90            let notified = self.shutdown_complete.notified();
91
92            let count = self.active_endpoints.load(Ordering::SeqCst);
93            tracing::trace!("Checking completion status, active endpoints: {count}");
94
95            if count == 0 {
96                tracing::debug!("All endpoints completed");
97                break;
98            }
99
100            // Only wait if there are still active endpoints
101            tracing::debug!("Waiting for {count} endpoints to complete");
102            notified.await;
103            tracing::trace!("Received notification, rechecking...");
104        }
105    }
106
107    // This method is no longer needed since we can access the tracker directly
108}