ant_quic/
shutdown.rs

1//! Coordinated shutdown for ant-quic endpoints
2//!
3//! Implements staged shutdown:
4//! 1. Stop accepting new work
5//! 2. Drain existing work with timeout
6//! 3. Cancel remaining tasks
7//! 4. Clean up resources
8
9use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
10use std::sync::{Arc, Mutex};
11use std::time::Duration;
12
13use tokio::sync::Notify;
14use tokio::task::JoinHandle;
15use tokio::time::timeout;
16use tokio_util::sync::CancellationToken;
17use tracing::{debug, info, warn};
18
/// Upper bound on how long [`ShutdownCoordinator::shutdown`] lets background work
/// drain on its own before it starts forcing tasks to stop.
pub const DEFAULT_SHUTDOWN_TIMEOUT: Duration = Duration::from_millis(500);

/// How long shutdown waits for aborted tasks to actually stop after `abort()`
/// has been requested on them.
pub const TASK_ABORT_TIMEOUT: Duration = Duration::from_millis(100);
24
25/// Coordinates shutdown across all endpoint components
26pub struct ShutdownCoordinator {
27    /// Token cancelled when shutdown starts (stop accepting new work)
28    close_start: CancellationToken,
29
30    /// Token cancelled after connections drained
31    close_complete: CancellationToken,
32
33    /// Whether shutdown has been initiated
34    shutdown_initiated: AtomicBool,
35
36    /// Count of active background tasks
37    active_tasks: Arc<AtomicUsize>,
38
39    /// Notified when all tasks complete
40    tasks_complete: Arc<Notify>,
41
42    /// Tracked task handles
43    task_handles: Mutex<Vec<JoinHandle<()>>>,
44}
45
46impl std::fmt::Debug for ShutdownCoordinator {
47    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
48        f.debug_struct("ShutdownCoordinator")
49            .field("shutdown_initiated", &self.shutdown_initiated)
50            .field("active_tasks", &self.active_tasks)
51            .finish_non_exhaustive()
52    }
53}
54
55impl ShutdownCoordinator {
56    /// Create a new shutdown coordinator
57    pub fn new() -> Arc<Self> {
58        Arc::new(Self {
59            close_start: CancellationToken::new(),
60            close_complete: CancellationToken::new(),
61            shutdown_initiated: AtomicBool::new(false),
62            active_tasks: Arc::new(AtomicUsize::new(0)),
63            tasks_complete: Arc::new(Notify::new()),
64            task_handles: Mutex::new(Vec::new()),
65        })
66    }
67
68    /// Get a token that is cancelled when shutdown starts
69    pub fn close_start_token(&self) -> CancellationToken {
70        self.close_start.clone()
71    }
72
73    /// Get a token that is cancelled when shutdown completes
74    pub fn close_complete_token(&self) -> CancellationToken {
75        self.close_complete.clone()
76    }
77
78    /// Check if shutdown has been initiated
79    pub fn is_shutting_down(&self) -> bool {
80        self.shutdown_initiated.load(Ordering::SeqCst)
81    }
82
83    /// Register a background task for tracking
84    pub fn register_task(&self, handle: JoinHandle<()>) {
85        self.active_tasks.fetch_add(1, Ordering::SeqCst);
86        if let Ok(mut handles) = self.task_handles.lock() {
87            handles.push(handle);
88        }
89    }
90
91    /// Spawn a tracked task that respects the shutdown token
92    pub fn spawn_tracked<F>(self: &Arc<Self>, future: F) -> JoinHandle<()>
93    where
94        F: std::future::Future<Output = ()> + Send + 'static,
95    {
96        let tasks_complete = Arc::clone(&self.tasks_complete);
97        let task_counter = Arc::clone(&self.active_tasks);
98
99        // Increment task count before spawning
100        self.active_tasks.fetch_add(1, Ordering::SeqCst);
101
102        tokio::spawn(async move {
103            future.await;
104            // Decrement and notify if last task
105            if task_counter.fetch_sub(1, Ordering::SeqCst) == 1 {
106                tasks_complete.notify_waiters();
107            }
108        })
109    }
110
111    /// Get count of active tasks
112    pub fn active_task_count(&self) -> usize {
113        self.active_tasks.load(Ordering::SeqCst)
114    }
115
116    /// Execute coordinated shutdown
117    pub async fn shutdown(&self) {
118        // Prevent multiple shutdown attempts
119        if self.shutdown_initiated.swap(true, Ordering::SeqCst) {
120            debug!("Shutdown already in progress");
121            return;
122        }
123
124        info!("Starting coordinated shutdown");
125
126        // Stage 1: Signal close start (stop accepting new work)
127        debug!("Stage 1: Signaling close start");
128        self.close_start.cancel();
129
130        // Stage 2: Wait for tasks with timeout
131        debug!("Stage 2: Waiting for tasks to complete");
132        let wait_result = timeout(DEFAULT_SHUTDOWN_TIMEOUT, self.wait_for_tasks()).await;
133
134        if wait_result.is_err() {
135            warn!("Shutdown timeout - aborting remaining tasks");
136        }
137
138        // Stage 3: Abort any remaining tasks
139        debug!("Stage 3: Aborting remaining tasks");
140        self.abort_remaining_tasks().await;
141
142        // Stage 4: Signal close complete
143        debug!("Stage 4: Signaling close complete");
144        self.close_complete.cancel();
145
146        info!("Shutdown complete");
147    }
148
149    /// Wait for all tasks to complete
150    async fn wait_for_tasks(&self) {
151        while self.active_tasks.load(Ordering::SeqCst) > 0 {
152            self.tasks_complete.notified().await;
153        }
154    }
155
156    /// Abort any tasks that didn't complete gracefully
157    async fn abort_remaining_tasks(&self) {
158        let handles: Vec<_> = if let Ok(mut guard) = self.task_handles.lock() {
159            guard.drain(..).collect()
160        } else {
161            Vec::new()
162        };
163
164        for handle in handles {
165            if !handle.is_finished() {
166                handle.abort();
167                // Give a moment for abort to take effect
168                let _ = timeout(TASK_ABORT_TIMEOUT, async {
169                    // Wait for task to actually finish
170                    let _ = handle.await;
171                })
172                .await;
173            }
174        }
175
176        self.active_tasks.store(0, Ordering::SeqCst);
177    }
178}
179
180impl Default for ShutdownCoordinator {
181    fn default() -> Self {
182        Self {
183            close_start: CancellationToken::new(),
184            close_complete: CancellationToken::new(),
185            shutdown_initiated: AtomicBool::new(false),
186            active_tasks: Arc::new(AtomicUsize::new(0)),
187            tasks_complete: Arc::new(Notify::new()),
188            task_handles: Mutex::new(Vec::new()),
189        }
190    }
191}
192
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Instant;

    /// Shared body of the token tests: `pick` selects which coordinator token to
    /// observe; it must be live before shutdown and cancelled afterwards.
    async fn assert_token_cancelled_by_shutdown(
        pick: fn(&ShutdownCoordinator) -> CancellationToken,
    ) {
        let coord = ShutdownCoordinator::new();
        let token = pick(&*coord);

        assert!(!token.is_cancelled());
        coord.shutdown().await;
        assert!(token.is_cancelled());
    }

    #[tokio::test]
    async fn test_shutdown_completes_within_timeout() {
        let coord = ShutdownCoordinator::new();

        let began = Instant::now();
        coord.shutdown().await;

        let limit = DEFAULT_SHUTDOWN_TIMEOUT + Duration::from_millis(100);
        assert!(began.elapsed() < limit);
    }

    #[tokio::test]
    async fn test_shutdown_is_idempotent() {
        let coord = ShutdownCoordinator::new();

        // Repeated calls must not panic.
        for _ in 0..3 {
            coord.shutdown().await;
        }
    }

    #[tokio::test]
    async fn test_is_shutting_down_flag() {
        let coord = ShutdownCoordinator::new();

        let before = coord.is_shutting_down();
        coord.shutdown().await;
        let after = coord.is_shutting_down();

        assert!(!before);
        assert!(after);
    }

    #[tokio::test]
    async fn test_close_start_token_cancelled() {
        assert_token_cancelled_by_shutdown(ShutdownCoordinator::close_start_token).await;
    }

    #[tokio::test]
    async fn test_close_complete_token_cancelled() {
        assert_token_cancelled_by_shutdown(ShutdownCoordinator::close_complete_token).await;
    }

    #[tokio::test]
    async fn test_spawn_tracked_increments_count() {
        let coord = ShutdownCoordinator::new();
        assert_eq!(coord.active_task_count(), 0);

        // A long sleep that would outlive the test if left alone.
        let _sleeper = coord.spawn_tracked(tokio::time::sleep(Duration::from_secs(10)));
        assert!(coord.active_task_count() >= 1);

        coord.shutdown().await;
    }

    #[tokio::test]
    async fn test_shutdown_with_long_running_tasks() {
        let coord = ShutdownCoordinator::new();

        // This task only ends once the close-start token fires.
        let token = coord.close_start_token();
        let _waiter = coord.spawn_tracked(async move { token.cancelled().await });

        let began = Instant::now();
        coord.shutdown().await;

        let limit = DEFAULT_SHUTDOWN_TIMEOUT + Duration::from_millis(200);
        assert!(began.elapsed() < limit);
    }

    #[tokio::test]
    async fn test_task_completes_before_shutdown() {
        let coord = ShutdownCoordinator::new();

        let short = coord.spawn_tracked(async {
            tokio::time::sleep(Duration::from_millis(10)).await;
        });
        let _ = short.await;

        // Nothing is left running, so shutdown has nothing to wait for.
        let began = Instant::now();
        coord.shutdown().await;
        assert!(began.elapsed() < Duration::from_millis(100));
    }

    #[tokio::test]
    async fn test_multiple_tracked_tasks() {
        let coord = ShutdownCoordinator::new();
        let token = coord.close_start_token();

        (0..5).for_each(|_| {
            let t = token.clone();
            coord.spawn_tracked(async move {
                t.cancelled().await;
            });
        });

        assert!(coord.active_task_count() >= 5);

        coord.shutdown().await;
    }

    #[tokio::test]
    async fn test_task_decrements_on_completion() {
        let coord = ShutdownCoordinator::new();

        // An empty task finishes on its first poll.
        let _ = coord.spawn_tracked(async {}).await;

        // Short pause before reading the counter.
        tokio::time::sleep(Duration::from_millis(10)).await;

        assert_eq!(coord.active_task_count(), 0);
    }
}