1use std::sync::Arc;
7use std::sync::atomic::{AtomicU8, AtomicU32, Ordering};
8use std::time::Duration;
9
10use thiserror::Error;
11use tokio::sync::broadcast;
12use tokio::time::timeout;
13use tracing::{debug, warn};
14
/// Lifecycle phase tracked by the shutdown coordinator.
///
/// Every variant has an explicit discriminant so the phase can be stored in,
/// and recovered from, a single `u8`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ShutdownState {
    /// Normal operation; shutdown has not been requested.
    Running = 0,
    /// Shutdown was requested and in-flight work is being waited on.
    Draining = 1,
    /// The drain phase is over.
    Terminated = 2,
}

impl ShutdownState {
    /// Decodes a raw discriminant back into a state.
    ///
    /// Any value that is not the discriminant of `Running` or `Draining`
    /// (including unknown values) decodes to `Terminated`.
    fn from_u8(value: u8) -> Self {
        match value {
            raw if raw == Self::Running as u8 => Self::Running,
            raw if raw == Self::Draining as u8 => Self::Draining,
            _ => Self::Terminated,
        }
    }

    /// Encodes this state as its raw discriminant.
    fn as_u8(self) -> u8 {
        self as u8
    }
}
42
43#[derive(Debug, Error)]
45pub enum ShutdownError {
46 #[error("shutdown already initiated")]
48 AlreadyShuttingDown,
49
50 #[error("drain timeout expired with {remaining} tasks still in flight")]
52 DrainTimeout {
53 remaining: u32,
55 },
56
57 #[error("broadcast error: {0}")]
59 BroadcastError(String),
60}
61
62pub type ShutdownResult<T> = Result<T, ShutdownError>;
64
65#[derive(Debug)]
84pub struct ShutdownCoordinator {
85 state: AtomicU8,
86 drain_timeout: Duration,
87 in_flight: AtomicU32,
88 shutdown_signal: broadcast::Sender<()>,
89}
90
91impl ShutdownCoordinator {
92 pub fn new(drain_timeout: Duration) -> Self {
98 let (tx, _) = broadcast::channel(1);
99 Self {
100 state: AtomicU8::new(ShutdownState::Running.as_u8()),
101 drain_timeout,
102 in_flight: AtomicU32::new(0),
103 shutdown_signal: tx,
104 }
105 }
106
107 pub fn register_work(self: &Arc<Self>) -> ShutdownResult<WorkGuard> {
116 let current_state = ShutdownState::from_u8(self.state.load(Ordering::SeqCst));
117
118 if current_state != ShutdownState::Running {
119 return Err(ShutdownError::AlreadyShuttingDown);
120 }
121
122 self.in_flight.fetch_add(1, Ordering::SeqCst);
123 Ok(WorkGuard {
124 coordinator: Arc::clone(self),
125 })
126 }
127
128 pub async fn shutdown(&self) -> ShutdownResult<()> {
139 let old_state = self.state.compare_exchange(
141 ShutdownState::Running.as_u8(),
142 ShutdownState::Draining.as_u8(),
143 Ordering::SeqCst,
144 Ordering::SeqCst,
145 );
146
147 if old_state.is_err() {
148 return Err(ShutdownError::AlreadyShuttingDown);
149 }
150
151 debug!("shutdown initiated, draining in-flight work");
152
153 let _ = self.shutdown_signal.send(());
155
156 let result = timeout(self.drain_timeout, self.wait_for_drain()).await;
158
159 self.state
161 .store(ShutdownState::Terminated.as_u8(), Ordering::SeqCst);
162
163 match result {
164 Ok(_) => {
165 debug!("shutdown complete, all work drained");
166 Ok(())
167 }
168 Err(_) => {
169 let remaining = self.in_flight.load(Ordering::SeqCst);
170 warn!(remaining, "drain timeout expired");
171 Err(ShutdownError::DrainTimeout { remaining })
172 }
173 }
174 }
175
176 #[must_use]
178 pub fn state(&self) -> ShutdownState {
179 ShutdownState::from_u8(self.state.load(Ordering::SeqCst))
180 }
181
182 #[must_use]
186 pub fn subscribe(&self) -> broadcast::Receiver<()> {
187 self.shutdown_signal.subscribe()
188 }
189
190 #[must_use]
192 pub fn in_flight_count(&self) -> u32 {
193 self.in_flight.load(Ordering::SeqCst)
194 }
195
196 async fn wait_for_drain(&self) {
198 loop {
199 if self.in_flight.load(Ordering::SeqCst) == 0 {
200 break;
201 }
202 tokio::time::sleep(Duration::from_millis(10)).await;
203 }
204 }
205}
206
207pub struct WorkGuard {
211 coordinator: Arc<ShutdownCoordinator>,
212}
213
214impl Drop for WorkGuard {
215 fn drop(&mut self) {
216 self.coordinator.in_flight.fetch_sub(1, Ordering::SeqCst);
217 }
218}
219
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a shared coordinator with the given drain timeout.
    fn coordinator_with_timeout(drain_timeout: Duration) -> Arc<ShutdownCoordinator> {
        Arc::new(ShutdownCoordinator::new(drain_timeout))
    }

    #[test]
    fn initial_state_is_running() {
        let coordinator = coordinator_with_timeout(Duration::from_secs(1));
        assert_eq!(coordinator.state(), ShutdownState::Running);
        assert_eq!(coordinator.in_flight_count(), 0);
    }

    #[test]
    fn register_work_increments_counter() {
        let coordinator = coordinator_with_timeout(Duration::from_secs(1));
        let _guard = coordinator.register_work().unwrap();
        assert_eq!(coordinator.in_flight_count(), 1);
    }

    #[test]
    fn work_guard_drop_decrements_counter() {
        let coordinator = coordinator_with_timeout(Duration::from_secs(1));
        {
            let _guard = coordinator.register_work().unwrap();
            assert_eq!(coordinator.in_flight_count(), 1);
        }
        assert_eq!(coordinator.in_flight_count(), 0);
    }

    #[tokio::test]
    async fn register_work_fails_during_shutdown() {
        let coordinator = coordinator_with_timeout(Duration::from_secs(5));
        let guard = coordinator.register_work().unwrap();
        // Subscribe before spawning so the shutdown signal cannot be missed.
        let mut rx = coordinator.subscribe();

        let coordinator_clone = Arc::clone(&coordinator);
        let shutdown_task = tokio::spawn(async move { coordinator_clone.shutdown().await });

        // The signal is sent only after the state has left `Running`, so
        // receiving it orders the check below after the transition without
        // relying on a sleep.
        let signal = tokio::time::timeout(Duration::from_secs(5), rx.recv()).await;
        assert!(matches!(signal, Ok(Ok(()))));

        assert!(matches!(
            coordinator.register_work(),
            Err(ShutdownError::AlreadyShuttingDown)
        ));
        // The rejected registration must not stay counted.
        assert_eq!(coordinator.in_flight_count(), 1);

        // Let the drain finish so the spawned task ends cleanly.
        drop(guard);
        assert!(shutdown_task.await.unwrap().is_ok());
    }

    #[tokio::test]
    async fn shutdown_waits_for_work_to_complete() {
        let coordinator = coordinator_with_timeout(Duration::from_secs(5));
        let guard = coordinator.register_work().unwrap();

        let coordinator_clone = Arc::clone(&coordinator);
        let shutdown_task = tokio::spawn(async move { coordinator_clone.shutdown().await });

        tokio::time::sleep(Duration::from_millis(100)).await;
        drop(guard);

        let result = shutdown_task.await.unwrap();
        assert!(result.is_ok());
        assert_eq!(coordinator.state(), ShutdownState::Terminated);
        assert_eq!(coordinator.in_flight_count(), 0);
    }

    #[tokio::test]
    async fn shutdown_timeout_with_pending_work() {
        let coordinator = coordinator_with_timeout(Duration::from_millis(100));
        let _guard = coordinator.register_work().unwrap();

        let result = coordinator.shutdown().await;
        assert!(matches!(
            result,
            Err(ShutdownError::DrainTimeout { remaining: 1 })
        ));
        assert_eq!(coordinator.state(), ShutdownState::Terminated);
    }

    #[tokio::test]
    async fn shutdown_signal_broadcast() {
        let coordinator = coordinator_with_timeout(Duration::from_secs(1));
        let mut rx = coordinator.subscribe();

        let coordinator_clone = Arc::clone(&coordinator);
        tokio::spawn(async move {
            let _ = coordinator_clone.shutdown().await;
        });

        // Check the received value, not merely that the wait did not time out.
        let result = tokio::time::timeout(Duration::from_secs(1), rx.recv()).await;
        assert!(matches!(result, Ok(Ok(()))));
    }

    #[tokio::test]
    async fn repeated_shutdown_is_rejected() {
        let coordinator = coordinator_with_timeout(Duration::from_secs(1));
        assert!(coordinator.shutdown().await.is_ok());
        assert_eq!(coordinator.state(), ShutdownState::Terminated);

        assert!(matches!(
            coordinator.shutdown().await,
            Err(ShutdownError::AlreadyShuttingDown)
        ));
        assert!(matches!(
            coordinator.register_work(),
            Err(ShutdownError::AlreadyShuttingDown)
        ));
        assert_eq!(coordinator.in_flight_count(), 0);
    }
}