Skip to main content

dynamo_runtime/utils/
graceful_shutdown.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use std::sync::atomic::{AtomicUsize, Ordering};
5use tokio::sync::Notify;
6
7/// Tracks graceful shutdown state for endpoints
8pub struct GracefulShutdownTracker {
9    active_endpoints: AtomicUsize,
10    shutdown_complete: Notify,
11}
12
13impl std::fmt::Debug for GracefulShutdownTracker {
14    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
15        f.debug_struct("GracefulShutdownTracker")
16            .field(
17                "active_endpoints",
18                &self.active_endpoints.load(Ordering::SeqCst),
19            )
20            .finish()
21    }
22}
23
24impl GracefulShutdownTracker {
25    pub(crate) fn new() -> Self {
26        Self {
27            active_endpoints: AtomicUsize::new(0),
28            shutdown_complete: Notify::new(),
29        }
30    }
31
32    pub(crate) fn register_endpoint(&self) {
33        let count = self.active_endpoints.fetch_add(1, Ordering::SeqCst);
34        tracing::debug!(
35            "Endpoint registered, total active: {} -> {}",
36            count,
37            count + 1
38        );
39    }
40
41    pub(crate) fn unregister_endpoint(&self) {
42        let prev = self.active_endpoints.fetch_sub(1, Ordering::SeqCst);
43        tracing::debug!(
44            "Endpoint unregistered, remaining active: {} -> {}",
45            prev,
46            prev - 1
47        );
48        if prev == 1 {
49            // Last endpoint completed
50            tracing::info!("Last endpoint completed, notifying all waiters");
51            self.shutdown_complete.notify_waiters();
52        }
53    }
54
55    /// Get the current count of active endpoints
56    pub(crate) fn get_count(&self) -> usize {
57        self.active_endpoints.load(Ordering::Acquire)
58    }
59
60    pub(crate) async fn wait_for_completion(&self) {
61        loop {
62            // Create the waiter BEFORE checking the condition
63            let notified = self.shutdown_complete.notified();
64
65            let count = self.active_endpoints.load(Ordering::SeqCst);
66            tracing::trace!("Checking completion status, active endpoints: {}", count);
67
68            if count == 0 {
69                tracing::debug!("All endpoints completed");
70                break;
71            }
72
73            // Only wait if there are still active endpoints
74            tracing::debug!("Waiting for {} endpoints to complete", count);
75            notified.await;
76            tracing::trace!("Received notification, rechecking...");
77        }
78    }
79
80    // This method is no longer needed since we can access the tracker directly
81}