1use arc_swap::ArcSwap;
8use parking_lot::Mutex;
9use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
10use std::sync::Arc;
11use std::time::{Duration, Instant};
12use tracing::{debug, info, warn};
13
14use crate::error::{Error, Result};
15
/// Why a shutdown was started.
///
/// The value is `Copy`, so it can be handed around and compared freely.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ShutdownReason {
    /// Carries an `i32` payload — presumably the OS signal number (the tests use 15);
    /// confirm against the code that constructs it.
    Signal(i32),
    /// Also the placeholder value stored in `ShutdownInner` before any shutdown
    /// has been initiated.
    Requested,
    /// NOTE(review): not constructed anywhere in this file outside the tests;
    /// presumably reported by callers on an unrecoverable error.
    Error,
    /// NOTE(review): not constructed anywhere in this file; presumably reported
    /// by callers when a resource limit is hit.
    ResourceExhausted,
    /// Used by `ShutdownCoordinator::wait_for_shutdown` when the graceful
    /// timeout elapses, and tested for by `ShutdownHandle::is_forced`.
    Forced,
}
30
31impl std::fmt::Display for ShutdownReason {
32 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
33 match self {
34 Self::Signal(sig) => write!(f, "Signal({sig})"),
35 Self::Requested => write!(f, "Requested"),
36 Self::Error => write!(f, "Error"),
37 Self::ResourceExhausted => write!(f, "ResourceExhausted"),
38 Self::Forced => write!(f, "Forced"),
39 }
40 }
41}
42
/// A subsystem's view of the shared shutdown state.
///
/// Handles are produced by `ShutdownCoordinator::create_handle`. The derived
/// `Clone` copies `subsystem_id`, so a cloned handle refers to the same
/// registered subsystem as the original.
#[derive(Debug, Clone)]
pub struct ShutdownHandle {
    /// State shared with the coordinator and all other handles.
    inner: Arc<ShutdownInner>,
    /// ID under which this subsystem was registered; passed to
    /// `mark_subsystem_ready` by `ready()`.
    subsystem_id: u64,
}
49
50impl ShutdownHandle {
51 const fn new(inner: Arc<ShutdownInner>, subsystem_id: u64) -> Self {
53 Self {
54 inner,
55 subsystem_id,
56 }
57 }
58
59 #[must_use]
61 pub fn is_shutdown(&self) -> bool {
62 self.inner.is_shutdown()
63 }
64
/// Waits until shutdown has been initiated.
///
/// Returns immediately when the shutdown flag is already set. Otherwise, with
/// the `tokio` feature the wait is a `recv` on the broadcast channel; with
/// `async-std` (and without `tokio`) the flag is polled every 10 ms.
///
/// NOTE(review): when neither feature is enabled both `cfg` blocks compile
/// away, so this returns right after the first check even though no shutdown
/// has happened — confirm that configuration is not supported.
pub async fn cancelled(&mut self) {
    // Fast path: nothing to wait for.
    if self.inner.shutdown_initiated.load(Ordering::Relaxed) {
        return;
    }

    #[cfg(feature = "tokio")]
    {
        // Subscribe first, then re-check the flag. `initiate_shutdown` sets the
        // flag before it sends on the channel, so a shutdown that started before
        // the subscription is caught by the re-check, and one that starts after
        // it is delivered to `rx`.
        let mut rx = self.inner.shutdown_tx.subscribe();
        if self.is_shutdown() {
            return;
        }
        // The result is deliberately discarded: whatever `recv` yields, the wait is over.
        let _ = rx.recv().await;
    }

    #[cfg(all(feature = "async-std", not(feature = "tokio")))]
    {
        let shutdown_flag = &self.inner.shutdown_initiated;
        // No notification primitive is used on this path; poll the flag instead.
        loop {
            if shutdown_flag.load(Ordering::Acquire) {
                break;
            }
            async_std::task::sleep(Duration::from_millis(10)).await;
        }
    }
}
95
96 #[must_use]
98 pub fn shutdown_reason(&self) -> Option<ShutdownReason> {
99 if self.is_shutdown() {
100 Some(**self.inner.shutdown_reason.load())
101 } else {
102 None
103 }
104 }
105
106 #[must_use]
108 pub fn shutdown_time(&self) -> Option<Instant> {
109 *self.inner.shutdown_time.lock()
110 }
111
112 #[must_use]
114 pub fn is_forced(&self) -> bool {
115 matches!(self.shutdown_reason(), Some(ShutdownReason::Forced))
116 }
117
118 pub fn ready(&self) {
121 self.inner.mark_subsystem_ready(self.subsystem_id);
122 }
123
124 #[must_use]
126 pub fn time_remaining(&self) -> Option<Duration> {
127 self.shutdown_time().and_then(|shutdown_time| {
128 let elapsed = shutdown_time.elapsed();
129 let timeout =
130 Duration::from_millis(self.inner.graceful_timeout_ms.load(Ordering::Acquire));
131
132 if elapsed < timeout {
133 timeout.checked_sub(elapsed)
134 } else {
135 None
136 }
137 })
138 }
139}
140
/// Shutdown state shared between a `ShutdownCoordinator` and all of its `ShutdownHandle`s.
#[derive(Debug)]
struct ShutdownInner {
    /// Flipped from `false` to `true` by the first successful `initiate_shutdown` call.
    shutdown_initiated: AtomicBool,
    /// Reason for the shutdown. Starts out as `ShutdownReason::Requested` as a
    /// placeholder; the accessors only expose it once `shutdown_initiated` is set.
    shutdown_reason: ArcSwap<ShutdownReason>,
    /// Instant at which shutdown was initiated; `None` until then.
    shutdown_time: Mutex<Option<Instant>>,
    /// Graceful-phase timeout in milliseconds (read by `wait_for_shutdown` and `time_remaining`).
    graceful_timeout_ms: AtomicU64,
    /// Force-phase timeout in milliseconds (read by `wait_for_force_shutdown`).
    force_timeout_ms: AtomicU64,
    /// Kill-phase timeout in milliseconds (read by `wait_for_kill_shutdown`).
    kill_timeout_ms: AtomicU64,
    /// Registered subsystems; appended to by `register_subsystem`.
    subsystems: Mutex<Vec<SubsystemState>>,
    /// Broadcasts the shutdown reason to tasks waiting in `ShutdownHandle::cancelled`.
    #[cfg(feature = "tokio")]
    shutdown_tx: tokio::sync::broadcast::Sender<ShutdownReason>,
}
162
/// Registry entry for one subsystem taking part in shutdown.
#[derive(Debug)]
struct SubsystemState {
    /// Identifier handed back by `register_subsystem` and stored in the handle.
    id: u64,
    /// Name supplied at registration; used in log messages and stats.
    name: String,
    /// Set to `true` by `mark_subsystem_ready`; starts as `false`.
    ready: AtomicBool,
    /// Instant of registration. Not read by any code visible in this file,
    /// hence the `dead_code` allowance.
    #[allow(dead_code)]
    registered_at: Instant,
}
172
173impl ShutdownInner {
174 fn new(graceful_timeout_ms: u64, force_timeout_ms: u64, kill_timeout_ms: u64) -> Self {
175 #[cfg(feature = "tokio")]
176 let (shutdown_tx, _) = tokio::sync::broadcast::channel(16);
177
178 Self {
179 shutdown_initiated: AtomicBool::new(false),
180 shutdown_reason: ArcSwap::new(Arc::new(ShutdownReason::Requested)),
181 shutdown_time: Mutex::new(None),
182 graceful_timeout_ms: AtomicU64::new(graceful_timeout_ms),
183 force_timeout_ms: AtomicU64::new(force_timeout_ms),
184 kill_timeout_ms: AtomicU64::new(kill_timeout_ms),
185 subsystems: Mutex::new(Vec::new()),
186 #[cfg(feature = "tokio")]
187 shutdown_tx,
188 }
189 }
190
191 #[must_use]
193 pub fn is_shutdown(&self) -> bool {
194 self.shutdown_initiated.load(Ordering::Relaxed)
196 }
197
198 #[must_use]
201 pub fn initiate_shutdown(&self, reason: ShutdownReason) -> bool {
202 if self
204 .shutdown_initiated
205 .compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire)
206 .is_ok()
207 {
208 self.shutdown_reason.store(Arc::new(reason));
210 *self.shutdown_time.lock() = Some(Instant::now());
211
212 #[cfg(feature = "tokio")]
214 {
215 let _ = self.shutdown_tx.send(reason);
216 }
217
218 info!("Shutdown initiated: {}", reason);
219 true
220 } else {
221 debug!("Shutdown already initiated, ignoring additional request");
222 false
223 }
224 }
225
226 fn register_subsystem(&self, name: &str) -> u64 {
227 let id = fastrand::u64(..);
228 let state = SubsystemState {
229 id,
230 name: name.to_string(),
231 ready: AtomicBool::new(false),
232 registered_at: Instant::now(),
233 };
234
235 self.subsystems.lock().push(state);
236 debug!("Registered subsystem '{}' with ID {}", name, id);
237 id
238 }
239
240 fn mark_subsystem_ready(&self, subsystem_id: u64) {
241 let subsystems = self.subsystems.lock();
242 if let Some(subsystem) = subsystems.iter().find(|s| s.id == subsystem_id) {
244 subsystem.ready.store(true, Ordering::Relaxed);
245 debug!(
246 "Subsystem '{}' marked as ready for shutdown",
247 subsystem.name
248 );
249 }
250 }
251
252 fn are_all_subsystems_ready(&self) -> bool {
253 let subsystems = self.subsystems.lock();
254 subsystems.iter().all(|s| s.ready.load(Ordering::Relaxed))
256 }
257
258 fn get_subsystem_states(&self) -> Vec<(String, bool)> {
259 let subsystems = self.subsystems.lock();
260 subsystems
261 .iter()
262 .map(|s| (s.name.clone(), s.ready.load(Ordering::Relaxed)))
263 .collect()
264 }
265}
266
/// Owner-side API for the shutdown procedure: registers subsystems, initiates
/// shutdown, and waits for the graceful / force / kill phases.
///
/// Clones share the same underlying state.
#[derive(Debug)]
pub struct ShutdownCoordinator {
    /// State shared with every handle and every clone of this coordinator.
    inner: Arc<ShutdownInner>,
}
272
273impl ShutdownCoordinator {
274 #[must_use]
276 pub fn new(graceful_timeout_ms: u64, force_timeout_ms: u64, kill_timeout_ms: u64) -> Self {
277 Self {
278 inner: Arc::new(ShutdownInner::new(
279 graceful_timeout_ms,
280 force_timeout_ms,
281 kill_timeout_ms,
282 )),
283 }
284 }
285
286 pub fn create_handle<S: Into<String>>(&self, subsystem_name: S) -> ShutdownHandle {
288 let name = subsystem_name.into();
289 let subsystem_id = self.inner.register_subsystem(&name);
290 ShutdownHandle::new(Arc::clone(&self.inner), subsystem_id)
291 }
292
293 #[must_use]
295 pub fn initiate_shutdown(&self, reason: ShutdownReason) -> bool {
296 self.inner.initiate_shutdown(reason)
297 }
298
299 #[must_use]
301 pub fn is_shutdown(&self) -> bool {
302 self.inner.is_shutdown()
303 }
304
305 #[must_use]
307 pub fn get_reason(&self) -> Option<ShutdownReason> {
308 if self.is_shutdown() {
309 Some(**self.inner.shutdown_reason.load())
310 } else {
311 None
312 }
313 }
314
315 pub async fn wait_for_shutdown(&self) -> Result<()> {
322 if !self.is_shutdown() {
323 return Err(Error::invalid_state("Shutdown not initiated"));
324 }
325
326 let shutdown_time = *self.inner.shutdown_time.lock();
327 if shutdown_time.is_none() {
328 return Err(Error::invalid_state("Shutdown time not set"));
329 }
330
331 let graceful_timeout =
332 Duration::from_millis(self.inner.graceful_timeout_ms.load(Ordering::Acquire));
333
334 info!(
335 "Waiting for subsystems to shutdown gracefully (timeout: {:?})",
336 graceful_timeout
337 );
338
339 let start = Instant::now();
341
342 if self.inner.are_all_subsystems_ready() {
344 info!("All subsystems already shut down gracefully");
345 return Ok(());
346 }
347
348 let mut poll_interval = Duration::from_millis(1);
350 let max_poll_interval = Duration::from_millis(50);
351
352 while start.elapsed() < graceful_timeout {
353 if self.inner.are_all_subsystems_ready() {
354 info!(
355 "All subsystems shut down gracefully in {:?}",
356 start.elapsed()
357 );
358 return Ok(());
359 }
360
361 #[cfg(feature = "tokio")]
363 tokio::time::sleep(poll_interval).await;
364
365 #[cfg(all(feature = "async-std", not(feature = "tokio")))]
366 async_std::task::sleep(poll_interval).await;
367
368 poll_interval = (poll_interval * 2).min(max_poll_interval);
370 }
371
372 let states = self.inner.get_subsystem_states();
374 let not_ready: Vec<String> = states
375 .into_iter()
376 .filter_map(|(name, ready)| if ready { None } else { Some(name) })
377 .collect();
378
379 warn!(
380 "Graceful shutdown timeout exceeded. Subsystems not ready: {:?}",
381 not_ready
382 );
383
384 let _ = self.inner.initiate_shutdown(ShutdownReason::Forced);
386
387 let timeout_ms = u64::try_from(graceful_timeout.as_millis()).unwrap_or(u64::MAX);
388 Err(Error::timeout("Graceful shutdown", timeout_ms))
389 }
390
391 pub async fn wait_for_force_shutdown(&self) -> Result<()> {
398 let force_timeout =
399 Duration::from_millis(self.inner.force_timeout_ms.load(Ordering::Acquire));
400
401 warn!("Waiting for forced shutdown timeout: {:?}", force_timeout);
402
403 let start = Instant::now();
404 while start.elapsed() < force_timeout {
405 if self.inner.are_all_subsystems_ready() {
406 info!("All subsystems shut down during force phase");
407 return Ok(());
408 }
409
410 #[cfg(feature = "tokio")]
411 tokio::time::sleep(Duration::from_millis(50)).await;
412
413 #[cfg(all(feature = "async-std", not(feature = "tokio")))]
414 async_std::task::sleep(Duration::from_millis(50)).await;
415 }
416
417 let timeout_ms = u64::try_from(force_timeout.as_millis()).unwrap_or(u64::MAX);
418 Err(Error::timeout("Force shutdown", timeout_ms))
419 }
420
421 pub async fn wait_for_kill_shutdown(&self) -> Result<()> {
427 let kill_timeout =
428 Duration::from_millis(self.inner.kill_timeout_ms.load(Ordering::Acquire));
429
430 warn!("Waiting for kill shutdown timeout: {:?}", kill_timeout);
431
432 #[cfg(feature = "tokio")]
433 tokio::time::sleep(kill_timeout).await;
434
435 #[cfg(all(feature = "async-std", not(feature = "tokio")))]
436 async_std::task::sleep(kill_timeout).await;
437
438 let timeout_ms = u64::try_from(kill_timeout.as_millis()).unwrap_or(u64::MAX);
439 Err(Error::timeout("Kill shutdown", timeout_ms))
440 }
441
442 #[must_use]
444 pub fn get_stats(&self) -> ShutdownStats {
445 let subsystems = self.inner.get_subsystem_states();
446 let total_subsystems = subsystems.len();
447 let ready_subsystems = subsystems.iter().filter(|(_, ready)| *ready).count();
448
449 ShutdownStats {
450 is_shutdown: self.is_shutdown(),
451 reason: if self.is_shutdown() {
452 Some(**self.inner.shutdown_reason.load())
453 } else {
454 None
455 },
456 shutdown_time: *self.inner.shutdown_time.lock(),
457 total_subsystems,
458 ready_subsystems,
459 subsystem_states: subsystems,
460 }
461 }
462
463 pub fn update_timeouts(
465 &self,
466 graceful_timeout_ms: u64,
467 force_timeout_ms: u64,
468 kill_timeout_ms: u64,
469 ) {
470 self.inner
471 .graceful_timeout_ms
472 .store(graceful_timeout_ms, Ordering::Release);
473 self.inner
474 .force_timeout_ms
475 .store(force_timeout_ms, Ordering::Release);
476 self.inner
477 .kill_timeout_ms
478 .store(kill_timeout_ms, Ordering::Release);
479 debug!(
480 "Updated shutdown timeouts: graceful={}ms, force={}ms, kill={}ms",
481 graceful_timeout_ms, force_timeout_ms, kill_timeout_ms
482 );
483 }
484}
485
486impl Clone for ShutdownCoordinator {
487 fn clone(&self) -> Self {
488 Self {
489 inner: Arc::clone(&self.inner),
490 }
491 }
492}
493
/// Point-in-time snapshot produced by `ShutdownCoordinator::get_stats`.
#[derive(Debug, Clone)]
pub struct ShutdownStats {
    /// Whether shutdown had been initiated when the snapshot was taken.
    pub is_shutdown: bool,
    /// Shutdown reason; `None` while shutdown has not been initiated.
    pub reason: Option<ShutdownReason>,
    /// Instant at which shutdown was initiated, if recorded.
    pub shutdown_time: Option<Instant>,
    /// Number of registered subsystems.
    pub total_subsystems: usize,
    /// Number of registered subsystems that have reported ready.
    pub ready_subsystems: usize,
    /// `(name, ready)` for every registered subsystem.
    pub subsystem_states: Vec<(String, bool)>,
}
510
511impl ShutdownStats {
512 #[must_use]
514 #[allow(clippy::cast_precision_loss)]
515 pub fn progress(&self) -> f64 {
516 if self.total_subsystems == 0 {
518 1.0
519 } else {
520 self.ready_subsystems as f64 / self.total_subsystems as f64
522 }
523 }
524
525 #[must_use]
527 pub const fn is_complete(&self) -> bool {
528 self.total_subsystems > 0 && self.ready_subsystems == self.total_subsystems
529 }
530
531 #[must_use]
533 pub fn elapsed(&self) -> Option<Duration> {
534 self.shutdown_time.map(|t| t.elapsed())
535 }
536}
537
538#[cfg(test)]
539mod tests {
540 use super::*;
541 use std::time::Duration;
542
/// Full happy path on tokio: register two subsystems, initiate shutdown,
/// mark both ready, and check the resulting stats.
#[cfg(feature = "tokio")]
#[cfg_attr(miri, ignore)]
#[tokio::test]
async fn test_shutdown_coordination() {
    let scenario = async {
        let coordinator = ShutdownCoordinator::new(100, 200, 300);

        let first = coordinator.create_handle("subsystem1");
        let second = coordinator.create_handle("subsystem2");

        // Nothing has been initiated yet.
        assert!(!coordinator.is_shutdown());
        assert!(!first.is_shutdown());

        assert!(coordinator.initiate_shutdown(ShutdownReason::Requested));

        // Coordinator and handles observe the same flag.
        assert!(coordinator.is_shutdown());
        assert!(first.is_shutdown());
        for handle in [&first, &second] {
            assert!(handle.is_shutdown());
        }

        first.ready();
        second.ready();

        let stats = coordinator.get_stats();
        assert!(stats.is_complete());
        let epsilon: f64 = 1e-6;
        assert!((stats.progress() - 1.0).abs() < epsilon);
    };

    // Guard against the scenario hanging.
    let test_result = tokio::time::timeout(Duration::from_secs(5), scenario).await;
    assert!(test_result.is_ok(), "Test timed out after 5 seconds");
}
587
/// A subsystem that never reports ready must make `wait_for_shutdown` time out (tokio).
#[cfg(feature = "tokio")]
#[cfg_attr(miri, ignore)]
#[tokio::test]
async fn test_shutdown_timeout() {
    let scenario = async {
        let coordinator = ShutdownCoordinator::new(100, 200, 300);
        // Kept alive but never marked ready.
        let _slow = coordinator.create_handle("slow_subsystem");

        let _ = coordinator.initiate_shutdown(ShutdownReason::Requested);

        let outcome = coordinator.wait_for_shutdown().await;
        assert!(outcome.is_err());
        assert!(outcome.unwrap_err().is_timeout());
    };

    // Guard against the scenario hanging.
    let test_result = tokio::time::timeout(Duration::from_secs(5), scenario).await;
    assert!(test_result.is_ok(), "Test timed out after 5 seconds");
}
610
/// A subsystem that never reports ready must make `wait_for_shutdown` time out (async-std).
#[cfg(all(feature = "async-std", not(feature = "tokio")))]
#[async_std::test]
async fn test_shutdown_timeout() {
    let scenario = async {
        let coordinator = ShutdownCoordinator::new(100, 200, 300);
        // Kept alive but never marked ready.
        let _slow = coordinator.create_handle("slow_subsystem");

        let _ = coordinator.initiate_shutdown(ShutdownReason::Requested);

        let outcome = coordinator.wait_for_shutdown().await;
        assert!(outcome.is_err());
        assert!(outcome.unwrap_err().is_timeout());
    };

    // Guard against the scenario hanging.
    let test_result = async_std::future::timeout(Duration::from_secs(5), scenario).await;
    assert!(test_result.is_ok(), "Test timed out after 5 seconds");
}
632
/// `Display` must render every `ShutdownReason` variant as its name,
/// with the payload in parentheses for `Signal`.
#[test]
fn test_shutdown_reason_display() {
    // All five variants are listed so a newly added or renamed variant is noticed here.
    let cases = [
        (ShutdownReason::Signal(15), "Signal(15)"),
        (ShutdownReason::Requested, "Requested"),
        (ShutdownReason::Error, "Error"),
        (ShutdownReason::ResourceExhausted, "ResourceExhausted"),
        (ShutdownReason::Forced, "Forced"),
    ];

    for (reason, expected) in cases {
        assert_eq!(format!("{}", reason), expected);
    }
}
639
/// Only the first `initiate_shutdown` call wins; later calls leave the reason untouched (tokio).
#[cfg(feature = "tokio")]
#[cfg_attr(miri, ignore)]
#[tokio::test]
async fn test_multiple_shutdown_initiation() {
    let scenario = async {
        let coordinator = ShutdownCoordinator::new(5000, 10000, 15000);

        // First request starts the shutdown.
        assert!(coordinator.initiate_shutdown(ShutdownReason::Requested));

        // Follow-up requests are rejected.
        assert!(!coordinator.initiate_shutdown(ShutdownReason::Signal(15)));
        assert!(!coordinator.initiate_shutdown(ShutdownReason::Error));

        // The reason of the first request is kept.
        assert_eq!(coordinator.get_reason(), Some(ShutdownReason::Requested));
    };

    // Guard against the scenario hanging.
    let test_result = tokio::time::timeout(Duration::from_secs(5), scenario).await;
    assert!(test_result.is_ok(), "Test timed out after 5 seconds");
}
662
/// Only the first `initiate_shutdown` call wins; later calls leave the reason untouched (async-std).
#[cfg(all(feature = "async-std", not(feature = "tokio")))]
#[async_std::test]
async fn test_multiple_shutdown_initiation() {
    let scenario = async {
        let coordinator = ShutdownCoordinator::new(5000, 10000, 15000);

        // First request starts the shutdown.
        assert!(coordinator.initiate_shutdown(ShutdownReason::Requested));

        // Follow-up requests are rejected.
        assert!(!coordinator.initiate_shutdown(ShutdownReason::Signal(15)));
        assert!(!coordinator.initiate_shutdown(ShutdownReason::Error));

        // The reason of the first request is kept.
        let snapshot = coordinator.get_stats();
        assert_eq!(snapshot.reason, Some(ShutdownReason::Requested));
    };

    // Guard against the scenario hanging.
    let test_result = async_std::future::timeout(Duration::from_secs(5), scenario).await;
    assert!(test_result.is_ok(), "Test timed out after 5 seconds");
}
685
/// Stats must track readiness as handles report in: 0/2, then 1/2, then 2/2.
#[test]
fn test_shutdown_stats() {
    let epsilon: f64 = 1e-6;
    let close_to = |actual: f64, expected: f64| (actual - expected).abs() < epsilon;

    let coordinator = ShutdownCoordinator::new(5000, 10000, 15000);
    let first = coordinator.create_handle("test1");
    let second = coordinator.create_handle("test2");

    // Nothing ready yet.
    let snapshot = coordinator.get_stats();
    assert_eq!(snapshot.total_subsystems, 2);
    assert_eq!(snapshot.ready_subsystems, 0);
    assert!(!snapshot.is_complete());
    assert!(close_to(snapshot.progress(), 0.0));

    // One of two ready.
    first.ready();
    let snapshot = coordinator.get_stats();
    assert_eq!(snapshot.ready_subsystems, 1);
    assert!(close_to(snapshot.progress(), 0.5));

    // Both ready.
    second.ready();
    let snapshot = coordinator.get_stats();
    assert!(snapshot.is_complete());
    assert!(close_to(snapshot.progress(), 1.0));
}
713}
714
/// Full happy path on async-std: register two subsystems, initiate shutdown,
/// mark both ready, and check the resulting stats.
///
/// This test sits at file scope, after the closing brace of `mod tests`, so it
/// resolves `Duration`, `ShutdownCoordinator` and `ShutdownReason` through the
/// file's own top-level items rather than through `use super::*`.
#[cfg(all(feature = "async-std", not(feature = "tokio")))]
#[async_std::test]
async fn test_shutdown_coordination() {
    let scenario = async {
        let coordinator = ShutdownCoordinator::new(100, 200, 300);

        let first = coordinator.create_handle("subsystem1");
        let second = coordinator.create_handle("subsystem2");

        // Nothing has been initiated yet.
        assert!(!coordinator.is_shutdown());
        assert!(!first.is_shutdown());

        assert!(coordinator.initiate_shutdown(ShutdownReason::Requested));

        // Coordinator and handles observe the same flag.
        assert!(coordinator.is_shutdown());
        assert!(first.is_shutdown());
        for handle in [&first, &second] {
            assert!(handle.is_shutdown());
        }

        first.ready();
        second.ready();

        let stats = coordinator.get_stats();
        assert!(stats.is_complete());
        let epsilon: f64 = 1e-6;
        assert!((stats.progress() - 1.0).abs() < epsilon);
    };

    // Guard against the scenario hanging.
    let test_result = async_std::future::timeout(Duration::from_secs(5), scenario).await;
    assert!(test_result.is_ok(), "Test timed out after 5 seconds");
}