1use std::sync::Arc;
7use std::sync::atomic::{AtomicU8, AtomicU32, Ordering};
8use std::time::Duration;
9
10use thiserror::Error;
11use tokio::sync::broadcast;
12use tokio::time::timeout;
13use tracing::{debug, warn};
14
/// Lifecycle phase of a [`ShutdownCoordinator`].
///
/// The explicit discriminants are the raw values kept in the coordinator's
/// `AtomicU8` state field.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ShutdownState {
    /// Accepting new work.
    Running = 0,
    /// Shutdown has started; new work is refused while in-flight work drains.
    Draining = 1,
    /// Shutdown has finished, whether or not draining completed in time.
    Terminated = 2,
}

impl ShutdownState {
    /// Decodes a raw value read from the atomic state field.
    ///
    /// Anything other than the `Running` or `Draining` discriminant decodes
    /// as `Terminated`.
    fn from_u8(value: u8) -> Self {
        if value == Self::Running as u8 {
            Self::Running
        } else if value == Self::Draining as u8 {
            Self::Draining
        } else {
            Self::Terminated
        }
    }

    /// Encodes this state as its discriminant for atomic storage.
    fn as_u8(self) -> u8 {
        self as u8
    }
}
41
42#[derive(Debug, Error)]
44pub enum ShutdownError {
45 #[error("shutdown already initiated")]
47 AlreadyShuttingDown,
48
49 #[error("drain timeout expired with {remaining} tasks still in flight")]
51 DrainTimeout {
52 remaining: u32,
54 },
55
56 #[error("broadcast error: {0}")]
58 BroadcastError(String),
59}
60
61pub type ShutdownResult<T> = Result<T, ShutdownError>;
63
64#[derive(Debug)]
83pub struct ShutdownCoordinator {
84 state: AtomicU8,
85 drain_timeout: Duration,
86 in_flight: AtomicU32,
87 shutdown_signal: broadcast::Sender<()>,
88}
89
90impl ShutdownCoordinator {
91 #[must_use]
97 pub fn new(drain_timeout: Duration) -> Self {
98 let (tx, _) = broadcast::channel(1);
99 Self {
100 state: AtomicU8::new(ShutdownState::Running.as_u8()),
101 drain_timeout,
102 in_flight: AtomicU32::new(0),
103 shutdown_signal: tx,
104 }
105 }
106
107 pub fn register_work(self: &Arc<Self>) -> ShutdownResult<WorkGuard> {
116 let current_state = ShutdownState::from_u8(self.state.load(Ordering::SeqCst));
117
118 if current_state != ShutdownState::Running {
119 return Err(ShutdownError::AlreadyShuttingDown);
120 }
121
122 self.in_flight.fetch_add(1, Ordering::SeqCst);
123 Ok(WorkGuard {
124 coordinator: Arc::clone(self),
125 })
126 }
127
128 pub async fn shutdown(&self) -> ShutdownResult<()> {
139 let old_state = self.state.compare_exchange(
141 ShutdownState::Running.as_u8(),
142 ShutdownState::Draining.as_u8(),
143 Ordering::SeqCst,
144 Ordering::SeqCst,
145 );
146
147 if old_state.is_err() {
148 return Err(ShutdownError::AlreadyShuttingDown);
149 }
150
151 debug!("shutdown initiated, draining in-flight work");
152
153 let _ = self.shutdown_signal.send(());
155
156 let result = timeout(self.drain_timeout, self.wait_for_drain()).await;
158
159 self.state
161 .store(ShutdownState::Terminated.as_u8(), Ordering::SeqCst);
162
163 if result.is_ok() {
164 debug!("shutdown complete, all work drained");
165 Ok(())
166 } else {
167 let remaining = self.in_flight.load(Ordering::SeqCst);
168 warn!(remaining, "drain timeout expired");
169 Err(ShutdownError::DrainTimeout { remaining })
170 }
171 }
172
173 #[must_use]
175 pub fn state(&self) -> ShutdownState {
176 ShutdownState::from_u8(self.state.load(Ordering::SeqCst))
177 }
178
179 #[must_use]
183 pub fn subscribe(&self) -> broadcast::Receiver<()> {
184 self.shutdown_signal.subscribe()
185 }
186
187 #[must_use]
189 pub fn in_flight_count(&self) -> u32 {
190 self.in_flight.load(Ordering::SeqCst)
191 }
192
193 async fn wait_for_drain(&self) {
195 loop {
196 if self.in_flight.load(Ordering::SeqCst) == 0 {
197 break;
198 }
199 tokio::time::sleep(Duration::from_millis(10)).await;
200 }
201 }
202}
203
204pub struct WorkGuard {
208 coordinator: Arc<ShutdownCoordinator>,
209}
210
211impl Drop for WorkGuard {
212 fn drop(&mut self) {
213 self.coordinator.in_flight.fetch_sub(1, Ordering::SeqCst);
214 }
215}
216
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn initial_state_is_running() {
        let coordinator = Arc::new(ShutdownCoordinator::new(Duration::from_secs(1)));
        assert_eq!(coordinator.state(), ShutdownState::Running);
        assert_eq!(coordinator.in_flight_count(), 0);
    }

    #[test]
    fn register_work_increments_counter() {
        let coordinator = Arc::new(ShutdownCoordinator::new(Duration::from_secs(1)));
        let _guard = coordinator.register_work().unwrap();
        assert_eq!(coordinator.in_flight_count(), 1);
    }

    #[test]
    fn work_guard_drop_decrements_counter() {
        let coordinator = Arc::new(ShutdownCoordinator::new(Duration::from_secs(1)));
        {
            let _guard = coordinator.register_work().unwrap();
            assert_eq!(coordinator.in_flight_count(), 1);
        }
        assert_eq!(coordinator.in_flight_count(), 0);
    }

    #[tokio::test]
    async fn register_work_fails_during_shutdown() {
        let coordinator = Arc::new(ShutdownCoordinator::new(Duration::from_secs(1)));
        let _guard = coordinator.register_work().unwrap();

        let coordinator_clone = Arc::clone(&coordinator);
        tokio::spawn(async move {
            let _ = coordinator_clone.shutdown().await;
        });

        // Wait until the spawned task has actually begun shutdown instead of
        // relying on a fixed sleep.
        while coordinator.state() == ShutdownState::Running {
            tokio::task::yield_now().await;
        }

        let result = coordinator.register_work();
        assert!(matches!(result, Err(ShutdownError::AlreadyShuttingDown)));
        // A rejected registration must not leave the counter elevated.
        assert_eq!(coordinator.in_flight_count(), 1);
    }

    #[tokio::test]
    async fn shutdown_waits_for_work_to_complete() {
        let coordinator = Arc::new(ShutdownCoordinator::new(Duration::from_secs(5)));
        let guard = coordinator.register_work().unwrap();

        let coordinator_clone = Arc::clone(&coordinator);
        let shutdown_task = tokio::spawn(async move { coordinator_clone.shutdown().await });

        tokio::time::sleep(Duration::from_millis(100)).await;
        drop(guard);

        let result = shutdown_task.await.unwrap();
        assert!(result.is_ok());
        assert_eq!(coordinator.state(), ShutdownState::Terminated);
    }

    #[tokio::test]
    async fn shutdown_timeout_with_pending_work() {
        let coordinator = Arc::new(ShutdownCoordinator::new(Duration::from_millis(100)));
        let _guard = coordinator.register_work().unwrap();

        let result = coordinator.shutdown().await;
        assert!(matches!(
            result,
            Err(ShutdownError::DrainTimeout { remaining: 1 })
        ));
        assert_eq!(coordinator.state(), ShutdownState::Terminated);
    }

    #[tokio::test]
    async fn second_shutdown_and_late_registration_are_rejected() {
        let coordinator = Arc::new(ShutdownCoordinator::new(Duration::from_secs(1)));
        assert!(coordinator.shutdown().await.is_ok());

        assert!(matches!(
            coordinator.shutdown().await,
            Err(ShutdownError::AlreadyShuttingDown)
        ));
        assert!(matches!(
            coordinator.register_work(),
            Err(ShutdownError::AlreadyShuttingDown)
        ));
        assert_eq!(coordinator.in_flight_count(), 0);
        assert_eq!(coordinator.state(), ShutdownState::Terminated);
    }

    #[tokio::test]
    async fn shutdown_signal_broadcast() {
        let coordinator = Arc::new(ShutdownCoordinator::new(Duration::from_secs(1)));
        let mut rx = coordinator.subscribe();

        let coordinator_clone = Arc::clone(&coordinator);
        tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(50)).await;
            let _ = coordinator_clone.shutdown().await;
        });

        let result = tokio::time::timeout(Duration::from_secs(1), rx.recv()).await;
        assert!(result.is_ok());
    }
}