1use arc_swap::ArcSwap;
8use parking_lot::Mutex;
9use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
10use std::sync::Arc;
11use std::time::{Duration, Instant};
12use tracing::{debug, error, info, warn};
13
14use crate::error::{Error, Result};
15
/// Why a shutdown was initiated.
///
/// The value is `Copy`, so it can be handed to every subsystem without cloning.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ShutdownReason {
    /// An operating-system signal, carrying the raw signal number (e.g. `15`).
    Signal(i32),
    /// Shutdown was requested programmatically.
    Requested,
    /// Shutdown was triggered by an error.
    Error,
    /// Shutdown was triggered because a resource ran out.
    ResourceExhausted,
    /// Shutdown was escalated because the graceful phase did not finish in time.
    Forced,
}
30
31impl std::fmt::Display for ShutdownReason {
32 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
33 match self {
34 Self::Signal(sig) => write!(f, "Signal({sig})"),
35 Self::Requested => write!(f, "Requested"),
36 Self::Error => write!(f, "Error"),
37 Self::ResourceExhausted => write!(f, "ResourceExhausted"),
38 Self::Forced => write!(f, "Forced"),
39 }
40 }
41}
42
43#[derive(Debug, Clone)]
45pub struct ShutdownHandle {
46 inner: Arc<ShutdownInner>,
47 subsystem_id: u64,
48}
49
50impl ShutdownHandle {
51 const fn new(inner: Arc<ShutdownInner>, subsystem_id: u64) -> Self {
53 Self {
54 inner,
55 subsystem_id,
56 }
57 }
58
59 #[must_use]
61 pub fn is_shutdown(&self) -> bool {
62 self.inner.is_shutdown()
63 }
64
65 pub async fn cancelled(&mut self) {
68 #[cfg(feature = "tokio")]
70 {
71 let mut rx = self.inner.shutdown_tx.subscribe();
72 let _ = rx.recv().await;
73 }
74
75 #[cfg(all(feature = "async-std", not(feature = "tokio")))]
76 {
77 let shutdown_flag = &self.inner.shutdown_initiated;
79 loop {
80 if shutdown_flag.load(Ordering::Acquire) {
81 break;
82 }
83 async_std::task::sleep(Duration::from_millis(10)).await;
84 }
85 }
86 }
87
88 #[must_use]
90 pub fn shutdown_reason(&self) -> Option<ShutdownReason> {
91 if self.is_shutdown() {
92 Some(**self.inner.shutdown_reason.load())
93 } else {
94 None
95 }
96 }
97
98 #[must_use]
100 pub fn shutdown_time(&self) -> Option<Instant> {
101 *self.inner.shutdown_time.lock()
102 }
103
104 #[must_use]
106 pub fn is_forced(&self) -> bool {
107 matches!(self.shutdown_reason(), Some(ShutdownReason::Forced))
108 }
109
110 pub fn ready(&self) {
113 self.inner.mark_subsystem_ready(self.subsystem_id);
114 }
115
116 #[must_use]
118 pub fn time_remaining(&self) -> Option<Duration> {
119 self.shutdown_time().and_then(|shutdown_time| {
120 let elapsed = shutdown_time.elapsed();
121 let timeout =
122 Duration::from_millis(self.inner.force_timeout_ms.load(Ordering::Acquire));
123
124 if elapsed < timeout {
125 Some(timeout - elapsed)
126 } else {
127 None
128 }
129 })
130 }
131}
132
133#[derive(Debug)]
135struct ShutdownInner {
136 shutdown_initiated: AtomicBool,
138 shutdown_reason: ArcSwap<ShutdownReason>,
140 shutdown_time: Mutex<Option<Instant>>,
142 force_timeout_ms: AtomicU64,
144 kill_timeout_ms: AtomicU64,
146 subsystems: Mutex<Vec<SubsystemState>>,
148 #[cfg(feature = "tokio")]
150 shutdown_tx: tokio::sync::broadcast::Sender<ShutdownReason>,
151}
152
/// Bookkeeping for a single registered subsystem.
#[derive(Debug)]
struct SubsystemState {
    /// Id handed to the subsystem's [`ShutdownHandle`].
    id: u64,
    /// Human-readable name, used in logs and stats.
    name: String,
    /// Set once the subsystem reports it is ready for shutdown.
    ready: AtomicBool,
    /// Registration instant. It is currently only surfaced through `Debug`,
    /// hence the lint allowance.
    #[allow(dead_code)]
    registered_at: Instant,
}
162
163impl ShutdownInner {
164 fn new(force_timeout_ms: u64, kill_timeout_ms: u64) -> Self {
165 #[cfg(feature = "tokio")]
166 let (shutdown_tx, _) = tokio::sync::broadcast::channel(16);
167
168 Self {
169 shutdown_initiated: AtomicBool::new(false),
170 shutdown_reason: ArcSwap::new(Arc::new(ShutdownReason::Requested)),
171 shutdown_time: Mutex::new(None),
172 force_timeout_ms: AtomicU64::new(force_timeout_ms),
173 kill_timeout_ms: AtomicU64::new(kill_timeout_ms),
174 subsystems: Mutex::new(Vec::new()),
175 #[cfg(feature = "tokio")]
176 shutdown_tx,
177 }
178 }
179
180 #[must_use]
182 pub fn is_shutdown(&self) -> bool {
183 self.shutdown_initiated.load(Ordering::Acquire)
184 }
185
186 #[must_use]
189 pub fn initiate_shutdown(&self, reason: ShutdownReason) -> bool {
190 if self
192 .shutdown_initiated
193 .compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire)
194 .is_ok()
195 {
196 self.shutdown_reason.store(Arc::new(reason));
198 *self.shutdown_time.lock() = Some(Instant::now());
199
200 #[cfg(feature = "tokio")]
202 {
203 let _ = self.shutdown_tx.send(reason);
204 }
205
206 info!("Shutdown initiated: {}", reason);
207 true
208 } else {
209 debug!("Shutdown already initiated, ignoring additional request");
210 false
211 }
212 }
213
214 fn register_subsystem(&self, name: &str) -> u64 {
215 let id = fastrand::u64(..);
216 let state = SubsystemState {
217 id,
218 name: name.to_string(),
219 ready: AtomicBool::new(false),
220 registered_at: Instant::now(),
221 };
222
223 self.subsystems.lock().push(state);
224 debug!("Registered subsystem '{}' with ID {}", name, id);
225 id
226 }
227
228 fn mark_subsystem_ready(&self, subsystem_id: u64) {
229 let subsystems = self.subsystems.lock();
230 if let Some(subsystem) = subsystems.iter().find(|s| s.id == subsystem_id) {
231 subsystem.ready.store(true, Ordering::Release);
232 debug!(
233 "Subsystem '{}' marked as ready for shutdown",
234 subsystem.name
235 );
236 }
237 }
238
239 fn are_all_subsystems_ready(&self) -> bool {
240 let subsystems = self.subsystems.lock();
241 subsystems.iter().all(|s| s.ready.load(Ordering::Acquire))
242 }
243
244 fn get_subsystem_states(&self) -> Vec<(String, bool)> {
245 let subsystems = self.subsystems.lock();
246 subsystems
247 .iter()
248 .map(|s| (s.name.clone(), s.ready.load(Ordering::Acquire)))
249 .collect()
250 }
251}
252
253#[derive(Debug)]
255pub struct ShutdownCoordinator {
256 inner: Arc<ShutdownInner>,
257}
258
259impl ShutdownCoordinator {
260 #[must_use]
262 pub fn new(force_timeout_ms: u64, kill_timeout_ms: u64) -> Self {
263 Self {
264 inner: Arc::new(ShutdownInner::new(force_timeout_ms, kill_timeout_ms)),
265 }
266 }
267
268 pub fn create_handle<S: Into<String>>(&self, subsystem_name: S) -> ShutdownHandle {
270 let name = subsystem_name.into();
271 let subsystem_id = self.inner.register_subsystem(&name);
272 ShutdownHandle::new(Arc::clone(&self.inner), subsystem_id)
273 }
274
275 #[must_use]
277 pub fn initiate_shutdown(&self, reason: ShutdownReason) -> bool {
278 self.inner.initiate_shutdown(reason)
279 }
280
281 #[must_use]
283 pub fn is_shutdown(&self) -> bool {
284 self.inner.is_shutdown()
285 }
286
287 #[must_use]
289 pub fn get_reason(&self) -> Option<ShutdownReason> {
290 if self.is_shutdown() {
291 Some(**self.inner.shutdown_reason.load())
292 } else {
293 None
294 }
295 }
296
297 pub async fn wait_for_shutdown(&self) -> Result<()> {
304 if !self.is_shutdown() {
305 return Err(Error::invalid_state("Shutdown not initiated"));
306 }
307
308 let _shutdown_time = self
309 .inner
310 .shutdown_time
311 .lock()
312 .ok_or_else(|| Error::invalid_state("Shutdown time not set"))?;
313
314 let graceful_timeout =
315 Duration::from_millis(self.inner.force_timeout_ms.load(Ordering::Acquire));
316
317 info!(
318 "Waiting for subsystems to shutdown gracefully (timeout: {:?})",
319 graceful_timeout
320 );
321
322 let start = Instant::now();
324 while start.elapsed() < graceful_timeout {
325 if self.inner.are_all_subsystems_ready() {
326 info!(
327 "All subsystems shut down gracefully in {:?}",
328 start.elapsed()
329 );
330 return Ok(());
331 }
332
333 #[cfg(feature = "tokio")]
335 tokio::time::sleep(Duration::from_millis(50)).await;
336
337 #[cfg(all(feature = "async-std", not(feature = "tokio")))]
338 async_std::task::sleep(Duration::from_millis(50)).await;
339 }
340
341 let states = self.inner.get_subsystem_states();
343 let not_ready: Vec<String> = states
344 .into_iter()
345 .filter_map(|(name, ready)| if ready { None } else { Some(name) })
346 .collect();
347
348 warn!(
349 "Graceful shutdown timeout exceeded. Subsystems not ready: {:?}",
350 not_ready
351 );
352
353 let _ = self.inner.initiate_shutdown(ShutdownReason::Forced);
355
356 let timeout_ms = u64::try_from(graceful_timeout.as_millis()).unwrap_or(u64::MAX);
357 Err(Error::timeout("Graceful shutdown", timeout_ms))
358 }
359
360 pub async fn wait_for_force_shutdown(&self) -> Result<()> {
367 let force_timeout =
368 Duration::from_millis(self.inner.force_timeout_ms.load(Ordering::Acquire));
369
370 warn!("Waiting for forced shutdown timeout: {:?}", force_timeout);
371
372 #[cfg(feature = "tokio")]
373 tokio::time::sleep(force_timeout).await;
374
375 #[cfg(all(feature = "async-std", not(feature = "tokio")))]
376 async_std::task::sleep(force_timeout).await;
377
378 error!("Force shutdown timeout exceeded, exiting immediately");
379 Ok(())
380 }
381
382 #[must_use]
384 pub fn get_stats(&self) -> ShutdownStats {
385 let subsystems = self.inner.get_subsystem_states();
386 let total_subsystems = subsystems.len();
387 let ready_subsystems = subsystems.iter().filter(|(_, ready)| *ready).count();
388
389 ShutdownStats {
390 is_shutdown: self.is_shutdown(),
391 reason: if self.is_shutdown() {
392 Some(**self.inner.shutdown_reason.load())
393 } else {
394 None
395 },
396 shutdown_time: *self.inner.shutdown_time.lock(),
397 total_subsystems,
398 ready_subsystems,
399 subsystem_states: subsystems,
400 }
401 }
402
403 pub fn update_timeouts(&self, force_timeout_ms: u64, kill_timeout_ms: u64) {
405 self.inner
406 .force_timeout_ms
407 .store(force_timeout_ms, Ordering::Release);
408 self.inner
409 .kill_timeout_ms
410 .store(kill_timeout_ms, Ordering::Release);
411 debug!(
412 "Updated shutdown timeouts: force={}ms, kill={}ms",
413 force_timeout_ms, kill_timeout_ms
414 );
415 }
416}
417
418impl Clone for ShutdownCoordinator {
419 fn clone(&self) -> Self {
420 Self {
421 inner: Arc::clone(&self.inner),
422 }
423 }
424}
425
426#[derive(Debug, Clone)]
428pub struct ShutdownStats {
429 pub is_shutdown: bool,
431 pub reason: Option<ShutdownReason>,
433 pub shutdown_time: Option<Instant>,
435 pub total_subsystems: usize,
437 pub ready_subsystems: usize,
439 pub subsystem_states: Vec<(String, bool)>,
441}
442
443impl ShutdownStats {
444 #[must_use]
446 #[allow(clippy::cast_precision_loss)]
447 pub fn progress(&self) -> f64 {
448 if self.total_subsystems == 0 {
450 1.0
451 } else {
452 self.ready_subsystems as f64 / self.total_subsystems as f64
454 }
455 }
456
457 #[must_use]
459 pub const fn is_complete(&self) -> bool {
460 self.total_subsystems > 0 && self.ready_subsystems == self.total_subsystems
461 }
462
463 #[must_use]
465 pub fn elapsed(&self) -> Option<Duration> {
466 self.shutdown_time.map(|t| t.elapsed())
467 }
468}
469
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    #[cfg(feature = "tokio")]
    #[cfg_attr(miri, ignore)]
    #[tokio::test]
    async fn test_shutdown_coordination() {
        let test_result = tokio::time::timeout(Duration::from_secs(5), async {
            let coordinator = ShutdownCoordinator::new(100, 200);

            let handle1 = coordinator.create_handle("subsystem1");
            let handle2 = coordinator.create_handle("subsystem2");

            assert!(!coordinator.is_shutdown());
            assert!(!handle1.is_shutdown());

            assert!(coordinator.initiate_shutdown(ShutdownReason::Requested));

            // The coordinator and every handle observe the shared flag.
            assert!(coordinator.is_shutdown());
            assert!(handle1.is_shutdown());
            assert!(handle2.is_shutdown());

            handle1.ready();
            handle2.ready();

            let stats = coordinator.get_stats();
            assert!(stats.is_complete());
            let epsilon: f64 = 1e-6;
            assert!((stats.progress() - 1.0).abs() < epsilon);
        })
        .await;

        assert!(test_result.is_ok(), "Test timed out after 5 seconds");
    }

    // Previously defined outside this module; moved here so it is grouped and
    // gated with the other tests.
    #[cfg(all(feature = "async-std", not(feature = "tokio")))]
    #[async_std::test]
    async fn test_shutdown_coordination() {
        let test_result = async_std::future::timeout(Duration::from_secs(5), async {
            let coordinator = ShutdownCoordinator::new(100, 200);

            let handle1 = coordinator.create_handle("subsystem1");
            let handle2 = coordinator.create_handle("subsystem2");

            assert!(!coordinator.is_shutdown());
            assert!(!handle1.is_shutdown());

            assert!(coordinator.initiate_shutdown(ShutdownReason::Requested));

            // The coordinator and every handle observe the shared flag.
            assert!(coordinator.is_shutdown());
            assert!(handle1.is_shutdown());
            assert!(handle2.is_shutdown());

            handle1.ready();
            handle2.ready();

            let stats = coordinator.get_stats();
            assert!(stats.is_complete());
            let epsilon: f64 = 1e-6;
            assert!((stats.progress() - 1.0).abs() < epsilon);
        })
        .await;

        assert!(test_result.is_ok(), "Test timed out after 5 seconds");
    }

    #[cfg(feature = "tokio")]
    #[cfg_attr(miri, ignore)]
    #[tokio::test]
    async fn test_shutdown_timeout() {
        let test_result = tokio::time::timeout(Duration::from_secs(5), async {
            let coordinator = ShutdownCoordinator::new(100, 200);
            // Never reports ready, so the graceful wait must time out.
            let _handle1 = coordinator.create_handle("slow_subsystem");

            let _ = coordinator.initiate_shutdown(ShutdownReason::Requested);

            let result = coordinator.wait_for_shutdown().await;
            assert!(result.is_err());
            assert!(result.unwrap_err().is_timeout());
        })
        .await;

        assert!(test_result.is_ok(), "Test timed out after 5 seconds");
    }

    #[cfg(all(feature = "async-std", not(feature = "tokio")))]
    #[async_std::test]
    async fn test_shutdown_timeout() {
        let test_result = async_std::future::timeout(Duration::from_secs(5), async {
            let coordinator = ShutdownCoordinator::new(100, 200);
            // Never reports ready, so the graceful wait must time out.
            let _handle1 = coordinator.create_handle("slow_subsystem");

            let _ = coordinator.initiate_shutdown(ShutdownReason::Requested);

            let result = coordinator.wait_for_shutdown().await;
            assert!(result.is_err());
            assert!(result.unwrap_err().is_timeout());
        })
        .await;

        assert!(test_result.is_ok(), "Test timed out after 5 seconds");
    }

    #[cfg(feature = "tokio")]
    #[cfg_attr(miri, ignore)]
    #[tokio::test]
    async fn test_wait_for_shutdown_requires_initiation() {
        let coordinator = ShutdownCoordinator::new(100, 200);
        assert!(coordinator.wait_for_shutdown().await.is_err());
    }

    #[test]
    fn test_shutdown_reason_display() {
        assert_eq!(format!("{}", ShutdownReason::Signal(15)), "Signal(15)");
        assert_eq!(format!("{}", ShutdownReason::Requested), "Requested");
        assert_eq!(format!("{}", ShutdownReason::Error), "Error");
        assert_eq!(
            format!("{}", ShutdownReason::ResourceExhausted),
            "ResourceExhausted"
        );
        assert_eq!(format!("{}", ShutdownReason::Forced), "Forced");
    }

    #[cfg(feature = "tokio")]
    #[cfg_attr(miri, ignore)]
    #[tokio::test]
    async fn test_multiple_shutdown_initiation() {
        let test_result = tokio::time::timeout(Duration::from_secs(5), async {
            let coordinator = ShutdownCoordinator::new(5000, 10000);

            assert!(coordinator.initiate_shutdown(ShutdownReason::Requested));

            // Later requests are rejected and do not overwrite the first reason.
            assert!(!coordinator.initiate_shutdown(ShutdownReason::Signal(15)));
            assert!(!coordinator.initiate_shutdown(ShutdownReason::Error));

            assert_eq!(coordinator.get_reason(), Some(ShutdownReason::Requested));
        })
        .await;

        assert!(test_result.is_ok(), "Test timed out after 5 seconds");
    }

    #[cfg(all(feature = "async-std", not(feature = "tokio")))]
    #[async_std::test]
    async fn test_multiple_shutdown_initiation() {
        let test_result = async_std::future::timeout(Duration::from_secs(5), async {
            let coordinator = ShutdownCoordinator::new(5000, 10000);

            assert!(coordinator.initiate_shutdown(ShutdownReason::Requested));

            // Later requests are rejected and do not overwrite the first reason.
            assert!(!coordinator.initiate_shutdown(ShutdownReason::Signal(15)));
            assert!(!coordinator.initiate_shutdown(ShutdownReason::Error));

            let stats = coordinator.get_stats();
            assert_eq!(stats.reason, Some(ShutdownReason::Requested));
        })
        .await;

        assert!(test_result.is_ok(), "Test timed out after 5 seconds");
    }

    #[test]
    fn test_shutdown_stats() {
        let coordinator = ShutdownCoordinator::new(5000, 10000);
        let handle1 = coordinator.create_handle("test1");
        let handle2 = coordinator.create_handle("test2");

        let stats = coordinator.get_stats();
        assert_eq!(stats.total_subsystems, 2);
        assert_eq!(stats.ready_subsystems, 0);
        assert!(!stats.is_complete());

        let epsilon: f64 = 1e-6;
        assert!((stats.progress() - 0.0).abs() < epsilon);

        handle1.ready();
        let stats = coordinator.get_stats();
        assert_eq!(stats.ready_subsystems, 1);
        assert!((stats.progress() - 0.5).abs() < epsilon);

        handle2.ready();
        let stats = coordinator.get_stats();
        assert!(stats.is_complete());
        assert!((stats.progress() - 1.0).abs() < epsilon);
    }

    #[test]
    fn test_stats_with_no_subsystems() {
        let coordinator = ShutdownCoordinator::new(5000, 10000);
        let stats = coordinator.get_stats();
        assert_eq!(stats.total_subsystems, 0);
        assert!(!stats.is_shutdown);
        assert_eq!(stats.reason, None);
        // An empty registry reports full progress but is not "complete".
        assert!(!stats.is_complete());
        assert!((stats.progress() - 1.0).abs() < 1e-6);
        assert!(stats.elapsed().is_none());
    }

    #[test]
    fn test_handle_state_before_and_after_shutdown() {
        let coordinator = ShutdownCoordinator::new(5000, 10000);
        let handle = coordinator.create_handle("observer");

        assert_eq!(handle.shutdown_reason(), None);
        assert!(handle.shutdown_time().is_none());
        assert!(handle.time_remaining().is_none());
        assert!(!handle.is_forced());

        assert!(coordinator.initiate_shutdown(ShutdownReason::Signal(15)));

        assert_eq!(handle.shutdown_reason(), Some(ShutdownReason::Signal(15)));
        assert!(handle.shutdown_time().is_some());
        // A 5 s budget is still far from exhausted right after initiation.
        assert!(handle.time_remaining().is_some());
    }
}