1use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
10use std::sync::{Arc, Mutex};
11use std::time::Duration;
12
13use tokio::sync::Notify;
14use tokio::task::JoinHandle;
15use tokio::time::timeout;
16use tokio_util::sync::CancellationToken;
17use tracing::{debug, info, warn};
18
/// Upper bound on how long `ShutdownCoordinator::shutdown` waits for the active-task
/// count to reach zero before it moves on to aborting registered tasks.
pub const DEFAULT_SHUTDOWN_TIMEOUT: Duration = Duration::from_millis(500);

/// Upper bound on how long shutdown waits for each aborted task's handle to resolve.
pub const TASK_ABORT_TIMEOUT: Duration = Duration::from_millis(100);
24
25pub struct ShutdownCoordinator {
27 close_start: CancellationToken,
29
30 close_complete: CancellationToken,
32
33 shutdown_initiated: AtomicBool,
35
36 active_tasks: Arc<AtomicUsize>,
38
39 tasks_complete: Arc<Notify>,
41
42 task_handles: Mutex<Vec<JoinHandle<()>>>,
44}
45
46impl std::fmt::Debug for ShutdownCoordinator {
47 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
48 f.debug_struct("ShutdownCoordinator")
49 .field("shutdown_initiated", &self.shutdown_initiated)
50 .field("active_tasks", &self.active_tasks)
51 .finish_non_exhaustive()
52 }
53}
54
55impl ShutdownCoordinator {
56 pub fn new() -> Arc<Self> {
58 Arc::new(Self {
59 close_start: CancellationToken::new(),
60 close_complete: CancellationToken::new(),
61 shutdown_initiated: AtomicBool::new(false),
62 active_tasks: Arc::new(AtomicUsize::new(0)),
63 tasks_complete: Arc::new(Notify::new()),
64 task_handles: Mutex::new(Vec::new()),
65 })
66 }
67
68 pub fn close_start_token(&self) -> CancellationToken {
70 self.close_start.clone()
71 }
72
73 pub fn close_complete_token(&self) -> CancellationToken {
75 self.close_complete.clone()
76 }
77
78 pub fn is_shutting_down(&self) -> bool {
80 self.shutdown_initiated.load(Ordering::SeqCst)
81 }
82
83 pub fn register_task(&self, handle: JoinHandle<()>) {
85 self.active_tasks.fetch_add(1, Ordering::SeqCst);
86 if let Ok(mut handles) = self.task_handles.lock() {
87 handles.push(handle);
88 }
89 }
90
91 pub fn spawn_tracked<F>(self: &Arc<Self>, future: F) -> JoinHandle<()>
93 where
94 F: std::future::Future<Output = ()> + Send + 'static,
95 {
96 let tasks_complete = Arc::clone(&self.tasks_complete);
97 let task_counter = Arc::clone(&self.active_tasks);
98
99 self.active_tasks.fetch_add(1, Ordering::SeqCst);
101
102 tokio::spawn(async move {
103 future.await;
104 if task_counter.fetch_sub(1, Ordering::SeqCst) == 1 {
106 tasks_complete.notify_waiters();
107 }
108 })
109 }
110
111 pub fn active_task_count(&self) -> usize {
113 self.active_tasks.load(Ordering::SeqCst)
114 }
115
116 pub async fn shutdown(&self) {
118 if self.shutdown_initiated.swap(true, Ordering::SeqCst) {
120 debug!("Shutdown already in progress");
121 return;
122 }
123
124 info!("Starting coordinated shutdown");
125
126 debug!("Stage 1: Signaling close start");
128 self.close_start.cancel();
129
130 debug!("Stage 2: Waiting for tasks to complete");
132 let wait_result = timeout(DEFAULT_SHUTDOWN_TIMEOUT, self.wait_for_tasks()).await;
133
134 if wait_result.is_err() {
135 warn!("Shutdown timeout - aborting remaining tasks");
136 }
137
138 debug!("Stage 3: Aborting remaining tasks");
140 self.abort_remaining_tasks().await;
141
142 debug!("Stage 4: Signaling close complete");
144 self.close_complete.cancel();
145
146 info!("Shutdown complete");
147 }
148
149 async fn wait_for_tasks(&self) {
151 while self.active_tasks.load(Ordering::SeqCst) > 0 {
152 self.tasks_complete.notified().await;
153 }
154 }
155
156 async fn abort_remaining_tasks(&self) {
158 let handles: Vec<_> = if let Ok(mut guard) = self.task_handles.lock() {
159 guard.drain(..).collect()
160 } else {
161 Vec::new()
162 };
163
164 for handle in handles {
165 if !handle.is_finished() {
166 handle.abort();
167 let _ = timeout(TASK_ABORT_TIMEOUT, async {
169 let _ = handle.await;
171 })
172 .await;
173 }
174 }
175
176 self.active_tasks.store(0, Ordering::SeqCst);
177 }
178}
179
180impl Default for ShutdownCoordinator {
181 fn default() -> Self {
182 Self {
183 close_start: CancellationToken::new(),
184 close_complete: CancellationToken::new(),
185 shutdown_initiated: AtomicBool::new(false),
186 active_tasks: Arc::new(AtomicUsize::new(0)),
187 tasks_complete: Arc::new(Notify::new()),
188 task_handles: Mutex::new(Vec::new()),
189 }
190 }
191}
192
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Instant;

    #[tokio::test]
    async fn test_shutdown_completes_within_timeout() {
        let coordinator = ShutdownCoordinator::new();

        let start = Instant::now();
        coordinator.shutdown().await;

        assert!(start.elapsed() < DEFAULT_SHUTDOWN_TIMEOUT + Duration::from_millis(100));
    }

    #[tokio::test]
    async fn test_shutdown_is_idempotent() {
        let coordinator = ShutdownCoordinator::new();
        let close_complete = coordinator.close_complete_token();

        coordinator.shutdown().await;
        assert!(close_complete.is_cancelled());

        // Repeated calls must return promptly and leave the completed state intact.
        let start = Instant::now();
        coordinator.shutdown().await;
        coordinator.shutdown().await;
        assert!(start.elapsed() < Duration::from_millis(100));
        assert!(coordinator.is_shutting_down());
        assert!(close_complete.is_cancelled());
    }

    #[tokio::test]
    async fn test_is_shutting_down_flag() {
        let coordinator = ShutdownCoordinator::new();

        assert!(!coordinator.is_shutting_down());
        coordinator.shutdown().await;
        assert!(coordinator.is_shutting_down());
    }

    #[tokio::test]
    async fn test_close_start_token_cancelled() {
        let coordinator = ShutdownCoordinator::new();
        let token = coordinator.close_start_token();

        assert!(!token.is_cancelled());
        coordinator.shutdown().await;
        assert!(token.is_cancelled());
    }

    #[tokio::test]
    async fn test_close_complete_token_cancelled() {
        let coordinator = ShutdownCoordinator::new();
        let token = coordinator.close_complete_token();

        assert!(!token.is_cancelled());
        coordinator.shutdown().await;
        assert!(token.is_cancelled());
    }

    #[tokio::test]
    async fn test_spawn_tracked_increments_count() {
        let coordinator = ShutdownCoordinator::new();

        assert_eq!(coordinator.active_task_count(), 0);

        let _handle = coordinator.spawn_tracked(async {
            tokio::time::sleep(Duration::from_secs(10)).await;
        });

        // The count is bumped synchronously inside `spawn_tracked`, before the task runs.
        assert_eq!(coordinator.active_task_count(), 1);

        coordinator.shutdown().await;
    }

    #[tokio::test]
    async fn test_shutdown_with_long_running_tasks() {
        let coordinator = ShutdownCoordinator::new();

        // The task runs until shutdown signals close start.
        let token = coordinator.close_start_token();
        let _handle = coordinator.spawn_tracked(async move {
            token.cancelled().await;
        });

        let start = Instant::now();
        coordinator.shutdown().await;

        assert!(start.elapsed() < DEFAULT_SHUTDOWN_TIMEOUT + Duration::from_millis(200));
        assert_eq!(coordinator.active_task_count(), 0);
    }

    #[tokio::test]
    async fn test_task_completes_before_shutdown() {
        let coordinator = ShutdownCoordinator::new();

        let handle = coordinator.spawn_tracked(async {
            tokio::time::sleep(Duration::from_millis(10)).await;
        });

        let _ = handle.await;

        // Nothing is active, so the wait stage should return immediately.
        let start = Instant::now();
        coordinator.shutdown().await;
        assert!(start.elapsed() < Duration::from_millis(100));
    }

    #[tokio::test]
    async fn test_multiple_tracked_tasks() {
        let coordinator = ShutdownCoordinator::new();
        let token = coordinator.close_start_token();

        for _ in 0..5 {
            let t = token.clone();
            coordinator.spawn_tracked(async move {
                t.cancelled().await;
            });
        }

        // None of the tasks can finish before close start is signalled.
        assert_eq!(coordinator.active_task_count(), 5);

        coordinator.shutdown().await;
        assert_eq!(coordinator.active_task_count(), 0);
    }

    #[tokio::test]
    async fn test_task_decrements_on_completion() {
        let coordinator = ShutdownCoordinator::new();

        let handle = coordinator.spawn_tracked(async {});

        // The count is released inside the task before its handle resolves,
        // so no extra sleep is needed here.
        let _ = handle.await;

        assert_eq!(coordinator.active_task_count(), 0);
    }
}