azoth_balancer/
shutdown.rs1use std::time::Duration;
10use thiserror::Error;
11use tokio::sync::watch;
12use tokio::task::{JoinError, JoinSet};
13use tracing::{error, info};
14
15#[derive(Debug, Error)]
16pub enum ShutdownError {
17 #[error("A background task panicked during shutdown")]
18 Panic(#[from] JoinError),
19 #[error("Graceful shutdown timed out after {0:?}")]
20 Timeout(Duration),
21}
22
23pub struct ShutdownManager {
25 tasks: JoinSet<()>,
26 shutdown_tx: watch::Sender<()>,
27}
28
29impl ShutdownManager {
30 pub fn new() -> Self {
32 let (shutdown_tx, _) = watch::channel(());
33 Self { tasks: JoinSet::new(), shutdown_tx }
34 }
35
36 pub fn spawn_task<F>(&mut self, task: F)
44 where
45 F: std::future::Future<Output = ()> + Send + 'static,
46 {
47 self.tasks.spawn(task);
48 }
49
50 pub fn subscribe(&self) -> watch::Receiver<()> {
55 self.shutdown_tx.subscribe()
56 }
57
58 pub fn abort_all(&mut self) {
64 self.tasks.abort_all();
65 }
66
67 pub async fn graceful_shutdown(self, timeout: Duration) -> Result<(), ShutdownError> {
77 let ShutdownManager { mut tasks, shutdown_tx } = self;
80
81 info!("Broadcasting shutdown signal to all {} background tasks...", tasks.len());
82 drop(shutdown_tx);
84
85 info!("Waiting for tasks to complete...");
86
87 let join_all_logic = async {
88 while let Some(res) = tasks.join_next().await {
89 res?;
91 }
92 Ok(())
93 };
94
95 match tokio::time::timeout(timeout, join_all_logic).await {
96 Ok(Ok(_)) => {
97 info!("All background tasks completed gracefully.");
98 Ok(())
99 }
100 Ok(Err(e)) => {
101 error!(error = %e, "A background task panicked during shutdown.");
102 Err(ShutdownError::Panic(e))
103 }
104 Err(_) => {
105 error!("Shutdown timeout of {:?} exceeded. Aborting remaining tasks.", timeout);
106 tasks.abort_all();
107 Err(ShutdownError::Timeout(timeout))
108 }
109 }
110 }
111}
112
113impl Default for ShutdownManager {
114 fn default() -> Self {
115 Self::new()
116 }
117}
118
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::time::{sleep, Duration};
    use tracing::info;

    /// Grace period for tests whose tasks are expected to stop promptly.
    const GRACE: Duration = Duration::from_secs(1);
    /// Deliberately short grace period for tests that expect a timeout.
    const SHORT_GRACE: Duration = Duration::from_millis(100);

    /// A single cooperative task stops once shutdown is signalled.
    #[tokio::test]
    async fn test_basic_shutdown() {
        let mut mgr = ShutdownManager::new();
        let mut shutdown_rx = mgr.subscribe();
        mgr.spawn_task(async move {
            info!("Task started, waiting for shutdown...");
            let _ = shutdown_rx.changed().await;
            info!("Task received shutdown signal.");
        });

        let outcome = mgr.graceful_shutdown(GRACE).await;
        assert!(outcome.is_ok(), "Expected graceful shutdown to succeed");
    }

    /// A task that outlives the grace period produces a timeout error.
    #[tokio::test]
    async fn test_timeout() {
        let mut mgr = ShutdownManager::new();
        mgr.spawn_task(async {
            info!("Long task started...");
            sleep(Duration::from_secs(10)).await;
        });

        let outcome = mgr.graceful_shutdown(SHORT_GRACE).await;
        assert!(outcome.is_err(), "Expected shutdown to return error due to timeout");
        assert!(matches!(outcome, Err(ShutdownError::Timeout(_))), "Expected a timeout error");
    }

    /// A panicking task surfaces as `ShutdownError::Panic`.
    #[tokio::test]
    async fn test_panic_propagation() {
        let mut mgr = ShutdownManager::new();
        mgr.spawn_task(async {
            info!("Task about to panic...");
            panic!("Simulated panic");
        });

        let outcome = mgr.graceful_shutdown(GRACE).await;
        assert!(outcome.is_err(), "Expected shutdown to return error due to task panic");
        assert!(matches!(outcome, Err(ShutdownError::Panic(_))), "Expected a panic error");
    }

    /// Several cooperative tasks, each with its own receiver, all stop.
    #[tokio::test]
    async fn test_multiple_tasks() {
        let mut mgr = ShutdownManager::new();
        let mut first_rx = mgr.subscribe();
        let mut second_rx = mgr.subscribe();

        mgr.spawn_task(async move {
            info!("Task 1 waiting for shutdown...");
            let _ = first_rx.changed().await;
            info!("Task 1 shutdown complete");
        });
        mgr.spawn_task(async move {
            info!("Task 2 waiting for shutdown...");
            let _ = second_rx.changed().await;
            info!("Task 2 shutdown complete");
        });

        let outcome = mgr.graceful_shutdown(GRACE).await;
        assert!(outcome.is_ok(), "Expected all tasks to shutdown gracefully");
    }

    /// Shutting down an empty manager is a no-op that succeeds.
    #[tokio::test]
    async fn test_shutdown_with_no_tasks() {
        let mgr = ShutdownManager::new();
        let outcome = mgr.graceful_shutdown(GRACE).await;
        assert!(outcome.is_ok(), "Shutdown should succeed immediately with no tasks");
    }

    /// Dropping a manager that still tracks a running task must not panic.
    #[tokio::test]
    async fn test_drop_does_not_panic() {
        let mut mgr = ShutdownManager::new();
        mgr.spawn_task(async { sleep(Duration::from_secs(10)).await });
        // `mgr` goes out of scope here with the task still pending.
    }

    /// A task that never looks at the shutdown signal ends in a timeout.
    #[tokio::test]
    async fn test_task_ignores_shutdown() {
        let mut mgr = ShutdownManager::new();
        mgr.spawn_task(async {
            info!("Task ignoring shutdown...");
            sleep(Duration::from_secs(10)).await;
        });

        let outcome = mgr.graceful_shutdown(SHORT_GRACE).await;
        assert!(outcome.is_err(), "Expected timeout error");
        assert!(matches!(outcome, Err(ShutdownError::Timeout(_))), "Expected a timeout error");
    }

    /// `abort_all` cancels tracked tasks; joining them yields a cancelled error.
    #[tokio::test]
    async fn test_abort_all() {
        let mut mgr = ShutdownManager::new();
        mgr.spawn_task(async {
            sleep(Duration::from_secs(60)).await;
        });

        mgr.abort_all();

        // Reach into the private set to observe the aborted task's result.
        let joined = mgr.tasks.join_next().await;
        assert!(joined.is_some(), "Expected a result from the aborted task");
        let task_outcome = joined.unwrap();
        assert!(task_outcome.is_err(), "Expected the task result to be an error");
        assert!(
            task_outcome.unwrap_err().is_cancelled(),
            "Expected the JoinError to be of type 'cancelled'"
        );
    }

    /// Tasks that already finished before shutdown do not cause an error.
    #[tokio::test]
    async fn test_task_finishes_before_shutdown() {
        let mut mgr = ShutdownManager::new();
        mgr.spawn_task(async {
            info!("Task that finishes early has started and finished.");
        });

        // Give the task time to run to completion before shutting down.
        sleep(Duration::from_millis(50)).await;

        let outcome = mgr.graceful_shutdown(GRACE).await;
        assert!(outcome.is_ok(), "Shutdown should succeed even if tasks are already complete");
    }

    /// One panicking task among well-behaved ones still yields a panic error.
    #[tokio::test]
    async fn test_partial_panic_scenario() {
        let mut mgr = ShutdownManager::new();
        let mut shutdown_rx = mgr.subscribe();

        mgr.spawn_task(async move {
            info!("Normal task waiting for shutdown...");
            let _ = shutdown_rx.changed().await;
            info!("Normal task completed");
        });

        mgr.spawn_task(async {
            info!("Task about to panic...");
            panic!("Simulated panic in mixed scenario");
        });

        let outcome = mgr.graceful_shutdown(GRACE).await;
        assert!(outcome.is_err(), "Expected error due to panic");
        assert!(
            matches!(outcome, Err(ShutdownError::Panic(_))),
            "Expected panic error even with other tasks completing normally"
        );
    }

    /// A tracked task may spawn and await its own untracked subtask.
    #[tokio::test]
    async fn test_task_spawns_subtask() {
        let mut mgr = ShutdownManager::new();
        let mut shutdown_rx = mgr.subscribe();

        mgr.spawn_task(async move {
            info!("Parent task starting...");

            let child = tokio::spawn(async {
                sleep(Duration::from_millis(100)).await;
                info!("Subtask completed");
            });

            // Wait for the shutdown signal first, then for the child to finish.
            let _ = shutdown_rx.changed().await;
            let _ = child.await;
            info!("Parent task completed after subtask");
        });

        let outcome = mgr.graceful_shutdown(GRACE).await;
        assert!(outcome.is_ok(), "Should handle tasks that spawn their own subtasks");
    }

    /// A single task holding two receivers observes the signal on both.
    #[tokio::test]
    async fn test_multiple_receivers_same_task() {
        let mut mgr = ShutdownManager::new();
        let mut first_rx = mgr.subscribe();
        let mut second_rx = mgr.subscribe();

        mgr.spawn_task(async move {
            info!("Task with multiple receivers starting...");

            let _ = first_rx.changed().await;
            info!("First receiver got signal");

            let _ = second_rx.changed().await;
            info!("Second receiver got signal");
        });

        let outcome = mgr.graceful_shutdown(GRACE).await;
        assert!(outcome.is_ok(), "Task using multiple receivers should complete");
    }
}