dynamo_runtime/utils/
graceful_shutdown.rs1use std::sync::Arc;
5use std::sync::atomic::{AtomicUsize, Ordering};
6use tokio::sync::Notify;
7
8pub struct GracefulShutdownTracker {
10 active_endpoints: AtomicUsize,
11 shutdown_complete: Notify,
12}
13
14pub struct GracefulTaskGuard {
21 tracker: Arc<GracefulShutdownTracker>,
22}
23
24impl Drop for GracefulTaskGuard {
25 fn drop(&mut self) {
26 self.tracker.unregister_endpoint();
27 }
28}
29
30impl std::fmt::Debug for GracefulShutdownTracker {
31 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
32 f.debug_struct("GracefulShutdownTracker")
33 .field(
34 "active_endpoints",
35 &self.active_endpoints.load(Ordering::SeqCst),
36 )
37 .finish()
38 }
39}
40
41impl GracefulShutdownTracker {
42 pub(crate) fn new() -> Self {
43 Self {
44 active_endpoints: AtomicUsize::new(0),
45 shutdown_complete: Notify::new(),
46 }
47 }
48
49 pub fn register_task(self: &Arc<Self>) -> GracefulTaskGuard {
53 self.register_endpoint();
54 GracefulTaskGuard {
55 tracker: self.clone(),
56 }
57 }
58
59 pub(crate) fn register_endpoint(&self) {
60 let count = self.active_endpoints.fetch_add(1, Ordering::SeqCst);
61 tracing::debug!(
62 "Endpoint registered, total active: {} -> {}",
63 count,
64 count + 1
65 );
66 }
67
68 pub(crate) fn unregister_endpoint(&self) {
69 let prev = self.active_endpoints.fetch_sub(1, Ordering::SeqCst);
70 tracing::debug!(
71 "Endpoint unregistered, remaining active: {} -> {}",
72 prev,
73 prev - 1
74 );
75 if prev == 1 {
76 tracing::info!("Last endpoint completed, notifying all waiters");
78 self.shutdown_complete.notify_waiters();
79 }
80 }
81
82 pub(crate) fn get_count(&self) -> usize {
84 self.active_endpoints.load(Ordering::Acquire)
85 }
86
87 pub(crate) async fn wait_for_completion(&self) {
88 loop {
89 let notified = self.shutdown_complete.notified();
91
92 let count = self.active_endpoints.load(Ordering::SeqCst);
93 tracing::trace!("Checking completion status, active endpoints: {count}");
94
95 if count == 0 {
96 tracing::debug!("All endpoints completed");
97 break;
98 }
99
100 tracing::debug!("Waiting for {count} endpoints to complete");
102 notified.await;
103 tracing::trace!("Received notification, rechecking...");
104 }
105 }
106
107 }