1use crate::config::ReloadableConfig;
10use std::fmt;
11use std::path::PathBuf;
12use std::sync::Arc;
13use std::sync::atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering};
14use std::time::{Duration, Instant};
15use tokio::sync::{Mutex, broadcast, watch};
16use tracing::{debug, error, info, warn};
17
/// Lifecycle phase of the graceful-shutdown sequence.
///
/// The coordinator in this file walks the phases in declaration order:
/// `Running` -> `Draining` -> `FlushingState` -> `Terminated`.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum ShutdownPhase {
    /// Shutdown has not been started.
    Running,
    /// In-flight requests are being waited on.
    Draining,
    /// Registered shutdown hooks are being executed.
    FlushingState,
    /// The shutdown sequence has finished.
    Terminated,
}

impl ShutdownPhase {
    /// Encodes the phase as an integer suitable for storage in an `AtomicU64`.
    fn as_u64(self) -> u64 {
        match self {
            Self::Running => 0,
            Self::Draining => 1,
            Self::FlushingState => 2,
            Self::Terminated => 3,
        }
    }

    /// Decodes a value produced by [`ShutdownPhase::as_u64`].
    ///
    /// Any value outside `0..=2` (including unknown ones) decodes to `Terminated`.
    fn from_u64(val: u64) -> Self {
        match val {
            0 => Self::Running,
            1 => Self::Draining,
            2 => Self::FlushingState,
            _ => Self::Terminated,
        }
    }
}

impl fmt::Display for ShutdownPhase {
    /// Writes the variant name verbatim (e.g. `"Draining"`).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            Self::Running => "Running",
            Self::Draining => "Draining",
            Self::FlushingState => "FlushingState",
            Self::Terminated => "Terminated",
        };
        f.write_str(label)
    }
}
67
/// Timing knobs for the shutdown sequence.
#[derive(Clone, Debug)]
pub struct DrainConfig {
    /// Maximum time to wait for in-flight requests before giving up on them.
    pub drain_timeout: Duration,
    /// Pause between successive checks of the in-flight request count.
    pub check_interval: Duration,
    /// Maximum time granted to each shutdown hook.
    pub flush_timeout: Duration,
}

impl Default for DrainConfig {
    /// 30 s drain timeout, 1 s check interval, 30 s per-hook flush timeout.
    fn default() -> Self {
        let half_minute = Duration::from_secs(30);
        Self {
            drain_timeout: half_minute,
            check_interval: Duration::from_millis(1_000),
            flush_timeout: half_minute,
        }
    }
}
92
/// A unit of cleanup work executed while the process is shutting down.
///
/// Implementations must be `Send + Sync`. The coordinator in this file runs the
/// registered hooks one after another, each under a timeout.
#[async_trait::async_trait]
pub trait ShutdownHook: Send + Sync {
    /// Name of the hook; used in log lines and recorded in `HookExecutionResult::hook_name`.
    fn name(&self) -> &str;

    /// Performs the hook's shutdown work.
    ///
    /// # Errors
    /// Returns a boxed error describing why the work could not be completed.
    async fn on_shutdown(&self) -> Result<(), Box<dyn std::error::Error + Send + Sync>>;
}
110
/// Write-ahead-log writer as consumed by `WalFlushHook`.
pub trait WalWriter: Send + Sync {
    /// Syncs the log. Presumably this forces buffered records to durable storage;
    /// the exact guarantee is implementation-defined (TODO confirm with implementors).
    ///
    /// # Errors
    /// Returns a boxed error when the sync fails.
    fn sync(&self) -> Result<(), Box<dyn std::error::Error + Send + Sync>>;
    /// Current size of the log; `WalFlushHook` logs this value as a byte count.
    fn current_size(&self) -> u64;
}
122
/// Memtable flusher as consumed by `MemtableFlushHook`.
pub trait MemtableFlusher: Send + Sync {
    /// Flushes the memtable to an SSTable.
    ///
    /// The returned count is logged by `MemtableFlushHook` as the number of entries flushed.
    ///
    /// # Errors
    /// Returns a boxed error when the flush fails.
    fn flush_to_sstable(&self) -> Result<usize, Box<dyn std::error::Error + Send + Sync>>;
}
128
/// Outcome of running a single shutdown hook.
#[derive(Clone, Debug)]
pub struct HookExecutionResult {
    /// Name reported by the hook.
    pub hook_name: String,
    /// `true` when the hook returned `Ok(())` within its time budget.
    pub success: bool,
    /// Wall-clock time spent on the hook.
    pub duration: Duration,
    /// Failure description; `None` when the hook succeeded.
    pub error: Option<String>,
}
145
146pub struct WalFlushHook {
152 pub timeout: Duration,
154 writer: Option<Arc<dyn WalWriter>>,
156}
157
158impl WalFlushHook {
159 pub fn with_writer(writer: Arc<dyn WalWriter>, timeout: Duration) -> Self {
161 Self {
162 timeout,
163 writer: Some(writer),
164 }
165 }
166}
167
168impl Default for WalFlushHook {
169 fn default() -> Self {
170 Self {
171 timeout: Duration::from_secs(10),
172 writer: None,
173 }
174 }
175}
176
177#[async_trait::async_trait]
178impl ShutdownHook for WalFlushHook {
179 fn name(&self) -> &str {
180 "WalFlush"
181 }
182
183 async fn on_shutdown(&self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
184 match &self.writer {
185 Some(writer) => {
186 let size = writer.current_size();
187 info!("Flushing WAL to disk ({} bytes)", size);
188 writer.sync()?;
189 info!("WAL flush complete ({} bytes synced)", size);
190 }
191 None => {
192 info!("No WAL writer configured - skipping flush");
193 }
194 }
195 Ok(())
196 }
197}
198
199pub struct MemtableFlushHook {
201 pub timeout: Duration,
203 flusher: Option<Arc<dyn MemtableFlusher>>,
205}
206
207impl MemtableFlushHook {
208 pub fn with_flusher(flusher: Arc<dyn MemtableFlusher>, timeout: Duration) -> Self {
210 Self {
211 timeout,
212 flusher: Some(flusher),
213 }
214 }
215}
216
217impl Default for MemtableFlushHook {
218 fn default() -> Self {
219 Self {
220 timeout: Duration::from_secs(15),
221 flusher: None,
222 }
223 }
224}
225
226#[async_trait::async_trait]
227impl ShutdownHook for MemtableFlushHook {
228 fn name(&self) -> &str {
229 "MemtableFlush"
230 }
231
232 async fn on_shutdown(&self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
233 match &self.flusher {
234 Some(flusher) => {
235 info!("Flushing active memtable to SSTable");
236 let entries = flusher.flush_to_sstable()?;
237 info!("Memtable flush complete ({} entries flushed)", entries);
238 }
239 None => {
240 info!("No memtable flusher configured - skipping flush");
241 }
242 }
243 Ok(())
244 }
245}
246
247pub struct ConnectionDrainHook {
249 active_connections: Arc<AtomicUsize>,
251 drain_timeout: Duration,
253 poll_interval: Duration,
255}
256
257impl ConnectionDrainHook {
258 pub fn new(active_connections: Arc<AtomicUsize>, drain_timeout: Duration) -> Self {
264 Self {
265 active_connections,
266 drain_timeout,
267 poll_interval: Duration::from_millis(100),
268 }
269 }
270
271 pub fn with_poll_interval(mut self, interval: Duration) -> Self {
273 self.poll_interval = interval;
274 self
275 }
276}
277
278#[async_trait::async_trait]
279impl ShutdownHook for ConnectionDrainHook {
280 fn name(&self) -> &str {
281 "ConnectionDrain"
282 }
283
284 async fn on_shutdown(&self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
285 let deadline = Instant::now() + self.drain_timeout;
286
287 loop {
288 let remaining = self.active_connections.load(Ordering::SeqCst);
289 if remaining == 0 {
290 info!("All connections drained");
291 return Ok(());
292 }
293
294 if Instant::now() >= deadline {
295 warn!(
296 "Connection drain timeout ({:?}) exceeded with {} connections remaining",
297 self.drain_timeout, remaining
298 );
299 return Err(format!(
300 "connection drain timed out with {} connections remaining",
301 remaining
302 )
303 .into());
304 }
305
306 info!("Draining connections: {} remaining", remaining);
307 tokio::time::sleep(self.poll_interval).await;
308 }
309 }
310}
311
312pub struct MetricsSnapshotHook {
314 pub timeout: Duration,
316 metrics_path: Option<PathBuf>,
318 metrics_provider: Option<Arc<dyn Fn() -> Vec<u8> + Send + Sync>>,
320}
321
322impl MetricsSnapshotHook {
323 pub fn with_provider(
325 provider: Arc<dyn Fn() -> Vec<u8> + Send + Sync>,
326 path: PathBuf,
327 timeout: Duration,
328 ) -> Self {
329 Self {
330 timeout,
331 metrics_path: Some(path),
332 metrics_provider: Some(provider),
333 }
334 }
335}
336
337impl Default for MetricsSnapshotHook {
338 fn default() -> Self {
339 Self {
340 timeout: Duration::from_secs(5),
341 metrics_path: None,
342 metrics_provider: None,
343 }
344 }
345}
346
347#[async_trait::async_trait]
348impl ShutdownHook for MetricsSnapshotHook {
349 fn name(&self) -> &str {
350 "MetricsSnapshot"
351 }
352
353 async fn on_shutdown(&self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
354 match (&self.metrics_provider, &self.metrics_path) {
355 (Some(provider), Some(path)) => {
356 let data = provider();
357 info!(
358 "Writing {} bytes of metrics to {}",
359 data.len(),
360 path.display()
361 );
362 std::fs::write(path, &data)?;
363 info!("Metrics snapshot saved successfully");
364 }
365 _ => {
366 info!("No metrics provider/path configured - skipping snapshot");
367 }
368 }
369 Ok(())
370 }
371}
372
373#[derive(Debug, Clone)]
379pub struct ShutdownStatus {
380 pub phase: ShutdownPhase,
382 pub active_requests: usize,
384 pub hooks_completed: usize,
386 pub hooks_total: usize,
388 pub elapsed_ms: u64,
390}
391
392#[derive(Clone)]
401pub struct ShutdownCoordinator {
402 inner: Arc<ShutdownInner>,
403}
404
405struct ShutdownInner {
406 sender: broadcast::Sender<()>,
408 phase_tx: watch::Sender<ShutdownPhase>,
410 phase_rx: watch::Receiver<ShutdownPhase>,
411 shutdown_initiated: AtomicBool,
413 phase: AtomicU64,
415 active_requests: AtomicUsize,
417 hooks: Mutex<Vec<Box<dyn ShutdownHook>>>,
419 hooks_completed: AtomicUsize,
421 hook_results: Mutex<Vec<HookExecutionResult>>,
423 drain_config: DrainConfig,
425 shutdown_start: Mutex<Option<Instant>>,
427}
428
429impl ShutdownCoordinator {
430 pub fn new() -> Self {
432 Self::with_config(DrainConfig::default())
433 }
434
435 pub fn with_config(config: DrainConfig) -> Self {
437 let (sender, _) = broadcast::channel(16);
438 let (phase_tx, phase_rx) = watch::channel(ShutdownPhase::Running);
439
440 Self {
441 inner: Arc::new(ShutdownInner {
442 sender,
443 phase_tx,
444 phase_rx,
445 shutdown_initiated: AtomicBool::new(false),
446 phase: AtomicU64::new(ShutdownPhase::Running.as_u64()),
447 active_requests: AtomicUsize::new(0),
448 hooks: Mutex::new(Vec::new()),
449 hooks_completed: AtomicUsize::new(0),
450 hook_results: Mutex::new(Vec::new()),
451 drain_config: config,
452 shutdown_start: Mutex::new(None),
453 }),
454 }
455 }
456
457 pub fn subscribe(&self) -> broadcast::Receiver<()> {
461 self.inner.sender.subscribe()
462 }
463
464 pub fn phase_watch(&self) -> watch::Receiver<ShutdownPhase> {
466 self.inner.phase_rx.clone()
467 }
468
469 pub fn request_start(&self) {
473 self.inner.active_requests.fetch_add(1, Ordering::SeqCst);
474 }
475
476 pub fn request_end(&self) {
478 self.inner.active_requests.fetch_sub(1, Ordering::SeqCst);
479 }
480
481 pub fn active_request_count(&self) -> usize {
483 self.inner.active_requests.load(Ordering::SeqCst)
484 }
485
486 pub async fn register_shutdown_hook(&self, hook: Box<dyn ShutdownHook>) {
490 let mut hooks = self.inner.hooks.lock().await;
491 info!("Registered shutdown hook: {}", hook.name());
492 hooks.push(hook);
493 }
494
495 pub fn current_phase(&self) -> ShutdownPhase {
499 ShutdownPhase::from_u64(self.inner.phase.load(Ordering::SeqCst))
500 }
501
502 pub fn is_accepting(&self) -> bool {
504 self.current_phase() == ShutdownPhase::Running
505 }
506
507 fn set_phase(&self, phase: ShutdownPhase) {
508 self.inner.phase.store(phase.as_u64(), Ordering::SeqCst);
509 let _ = self.inner.phase_tx.send(phase);
511 info!("Shutdown phase: {}", phase);
512 }
513
514 pub fn is_shutting_down(&self) -> bool {
518 self.inner.shutdown_initiated.load(Ordering::SeqCst)
519 }
520
521 pub fn health_status_label(&self) -> &'static str {
526 match self.current_phase() {
527 ShutdownPhase::Running => "ok",
528 _ => "shutting_down",
529 }
530 }
531
532 pub fn status(&self) -> ShutdownStatus {
534 let elapsed_ms = {
535 if let Ok(guard) = self.inner.shutdown_start.try_lock() {
537 guard.map(|s| s.elapsed().as_millis() as u64).unwrap_or(0)
538 } else {
539 0
540 }
541 };
542
543 let hooks_total = if let Ok(hooks) = self.inner.hooks.try_lock() {
544 hooks.len()
545 } else {
546 0
547 };
548
549 ShutdownStatus {
550 phase: self.current_phase(),
551 active_requests: self.active_request_count(),
552 hooks_completed: self.inner.hooks_completed.load(Ordering::SeqCst),
553 hooks_total,
554 elapsed_ms,
555 }
556 }
557
558 pub fn shutdown(&self) {
566 if self.inner.shutdown_initiated.swap(true, Ordering::SeqCst) {
567 debug!("Shutdown already initiated - ignoring duplicate signal");
569 return;
570 }
571
572 info!("Initiating graceful shutdown");
573
574 if let Ok(mut guard) = self.inner.shutdown_start.try_lock() {
576 *guard = Some(Instant::now());
577 }
578
579 if let Err(e) = self.inner.sender.send(()) {
581 warn!("Failed to broadcast shutdown signal: {}", e);
582 }
583
584 let coord = self.clone();
586 tokio::spawn(async move {
587 coord.run_shutdown_sequence().await;
588 });
589 }
590
591 async fn run_shutdown_sequence(&self) {
593 self.set_phase(ShutdownPhase::Draining);
595 self.drain_connections().await;
596
597 self.set_phase(ShutdownPhase::FlushingState);
599 self.run_hooks().await;
600
601 self.set_phase(ShutdownPhase::Terminated);
603 info!("Shutdown complete");
604 }
605
606 async fn drain_connections(&self) {
608 let cfg = &self.inner.drain_config;
609 let deadline = Instant::now() + cfg.drain_timeout;
610
611 loop {
612 let remaining = self.active_request_count();
613 if remaining == 0 {
614 info!("All in-flight requests drained");
615 return;
616 }
617
618 if Instant::now() >= deadline {
619 warn!(
620 "Drain timeout ({:?}) exceeded with {} requests remaining - force-closing",
621 cfg.drain_timeout, remaining
622 );
623 return;
624 }
625
626 info!("Draining: {} requests remaining", remaining);
627 tokio::time::sleep(cfg.check_interval).await;
628 }
629 }
630
631 pub async fn hook_results(&self) -> Vec<HookExecutionResult> {
636 self.inner.hook_results.lock().await.clone()
637 }
638
639 async fn run_hooks(&self) {
641 let hooks = {
642 let mut guard = self.inner.hooks.lock().await;
643 std::mem::take(&mut *guard)
644 };
645
646 if hooks.is_empty() {
647 info!("No shutdown hooks registered");
648 return;
649 }
650
651 let flush_timeout = self.inner.drain_config.flush_timeout;
652 info!("Executing {} shutdown hook(s)", hooks.len());
653
654 for hook in &hooks {
655 let name = hook.name().to_string();
656 info!("Running shutdown hook: {}", name);
657
658 let start = Instant::now();
659 let result = match tokio::time::timeout(flush_timeout, hook.on_shutdown()).await {
660 Ok(Ok(())) => {
661 info!("Shutdown hook '{}' completed successfully", name);
662 HookExecutionResult {
663 hook_name: name,
664 success: true,
665 duration: start.elapsed(),
666 error: None,
667 }
668 }
669 Ok(Err(e)) => {
670 let msg = e.to_string();
671 error!("Shutdown hook '{}' failed: {}", name, msg);
672 HookExecutionResult {
673 hook_name: name,
674 success: false,
675 duration: start.elapsed(),
676 error: Some(msg),
677 }
678 }
679 Err(_) => {
680 let msg = format!("timed out after {:?}", flush_timeout);
681 error!("Shutdown hook '{}' {}", name, msg);
682 HookExecutionResult {
683 hook_name: name,
684 success: false,
685 duration: start.elapsed(),
686 error: Some(msg),
687 }
688 }
689 };
690
691 {
692 let mut results = self.inner.hook_results.lock().await;
693 results.push(result);
694 }
695 self.inner.hooks_completed.fetch_add(1, Ordering::SeqCst);
696 }
697
698 info!(
699 "All shutdown hooks processed ({} total)",
700 self.inner.hooks_completed.load(Ordering::SeqCst)
701 );
702 }
703}
704
705impl Default for ShutdownCoordinator {
706 fn default() -> Self {
707 Self::new()
708 }
709}
710
711pub async fn setup_signal_handlers(coordinator: ShutdownCoordinator) {
719 tokio::spawn(async move {
720 if let Err(e) = wait_for_signal().await {
721 warn!("Error setting up signal handlers: {}", e);
722 return;
723 }
724
725 info!("Received shutdown signal");
726 coordinator.shutdown();
727 });
728}
729
730#[cfg(unix)]
735pub async fn setup_sighup_handler(config: ReloadableConfig) {
736 tokio::spawn(async move {
737 let mut sighup = match tokio::signal::unix::signal(tokio::signal::unix::SignalKind::hangup())
738 {
739 Ok(s) => s,
740 Err(e) => {
741 warn!("Failed to setup SIGHUP handler: {}", e);
742 return;
743 }
744 };
745
746 loop {
747 sighup.recv().await;
748 info!("Received SIGHUP - reloading configuration");
749
750 match config.reload_from_stored_path() {
751 Ok(report) => {
752 if report.success {
753 info!("Configuration reload completed: {}", report);
754 } else {
755 error!("Configuration reload failed: {}", report);
756 }
757 }
758 Err(e) => {
759 error!("Configuration reload error: {}", e);
760 }
761 }
762 }
763 });
764}
765
766#[cfg(not(unix))]
770pub async fn setup_sighup_handler(_config: ReloadableConfig) {
771 info!("SIGHUP handler not available on this platform; use manual_reload() instead");
772}
773
774async fn wait_for_signal() -> Result<(), std::io::Error> {
776 #[cfg(unix)]
777 {
778 use tokio::signal::unix::{SignalKind, signal};
779
780 let mut sigterm = signal(SignalKind::terminate())?;
781 let mut sigint = signal(SignalKind::interrupt())?;
782
783 tokio::select! {
784 _ = sigterm.recv() => {
785 info!("Received SIGTERM");
786 }
787 _ = sigint.recv() => {
788 info!("Received SIGINT");
789 }
790 }
791 }
792
793 #[cfg(not(unix))]
794 {
795 use tokio::signal;
796 signal::ctrl_c().await?;
797 info!("Received Ctrl+C");
798 }
799
800 Ok(())
801}
802
803pub struct ShutdownGuard {
811 coordinator: ShutdownCoordinator,
812 disarmed: Arc<AtomicBool>,
813}
814
815impl ShutdownGuard {
816 pub fn new(coordinator: ShutdownCoordinator) -> Self {
818 Self {
819 coordinator,
820 disarmed: Arc::new(AtomicBool::new(false)),
821 }
822 }
823
824 pub fn disarm(&self) {
826 self.disarmed.store(true, Ordering::SeqCst);
827 }
828}
829
830impl Drop for ShutdownGuard {
831 fn drop(&mut self) {
832 if !self.disarmed.load(Ordering::SeqCst) {
833 warn!("ShutdownGuard dropped without disarming - triggering shutdown");
834 self.coordinator.shutdown();
835 }
836 }
837}
838
839pub struct RequestGuard {
849 coordinator: ShutdownCoordinator,
850}
851
852impl RequestGuard {
853 pub fn new(coordinator: ShutdownCoordinator) -> Self {
855 coordinator.request_start();
856 Self { coordinator }
857 }
858}
859
860impl Drop for RequestGuard {
861 fn drop(&mut self) {
862 self.coordinator.request_end();
863 }
864}
865
866#[cfg(test)]
871mod tests {
872 use super::*;
873 use std::sync::atomic::AtomicBool as StdAtomicBool;
874 use std::time::Duration;
875 use tokio::time::timeout;
876
877 async fn wait_terminated(coordinator: &ShutdownCoordinator, dur: Duration) {
879 let mut watcher = coordinator.phase_watch();
880 let _ = timeout(dur, async {
881 loop {
882 if *watcher.borrow() == ShutdownPhase::Terminated {
883 return;
884 }
885 if watcher.changed().await.is_err() {
886 return;
887 }
888 }
889 })
890 .await;
891 }
892
893 #[tokio::test]
894 async fn test_shutdown_coordinator() {
895 let coordinator = ShutdownCoordinator::new();
896 let mut receiver = coordinator.subscribe();
897
898 assert!(!coordinator.is_shutting_down());
899 assert_eq!(coordinator.current_phase(), ShutdownPhase::Running);
900
901 coordinator.shutdown();
902
903 assert!(coordinator.is_shutting_down());
904
905 let result = timeout(Duration::from_millis(100), receiver.recv()).await;
907 assert!(result.is_ok());
908 }
909
910 #[tokio::test]
911 async fn test_multiple_subscribers() {
912 let coordinator = ShutdownCoordinator::new();
913 let mut rx1 = coordinator.subscribe();
914 let mut rx2 = coordinator.subscribe();
915 let mut rx3 = coordinator.subscribe();
916
917 coordinator.shutdown();
918
919 assert!(
920 timeout(Duration::from_millis(100), rx1.recv())
921 .await
922 .is_ok()
923 );
924 assert!(
925 timeout(Duration::from_millis(100), rx2.recv())
926 .await
927 .is_ok()
928 );
929 assert!(
930 timeout(Duration::from_millis(100), rx3.recv())
931 .await
932 .is_ok()
933 );
934 }
935
936 #[tokio::test]
937 async fn test_shutdown_idempotent() {
938 let coordinator = ShutdownCoordinator::new();
939
940 coordinator.shutdown();
941 coordinator.shutdown(); assert!(coordinator.is_shutting_down());
944
945 wait_terminated(&coordinator, Duration::from_secs(2)).await;
947 assert_eq!(coordinator.current_phase(), ShutdownPhase::Terminated);
948 }
949
/// A disarmed guard does not trigger shutdown when dropped.
#[test]
fn test_shutdown_guard_disarm() {
    let coordinator = ShutdownCoordinator::new();

    {
        let guard = ShutdownGuard::new(coordinator.clone());
        guard.disarm();
    } // guard dropped here

    assert!(!coordinator.is_shutting_down());
}
960
961 #[tokio::test]
962 async fn test_shutdown_guard_trigger() {
963 let coordinator = ShutdownCoordinator::new();
964 let guard = ShutdownGuard::new(coordinator.clone());
965
966 drop(guard);
967
968 assert!(coordinator.is_shutting_down());
969
970 wait_terminated(&coordinator, Duration::from_secs(2)).await;
972 }
973
974 #[tokio::test]
977 async fn test_phase_transitions() {
978 let config = DrainConfig {
979 drain_timeout: Duration::from_millis(200),
980 check_interval: Duration::from_millis(50),
981 flush_timeout: Duration::from_millis(200),
982 };
983 let coordinator = ShutdownCoordinator::with_config(config);
984
985 assert_eq!(coordinator.current_phase(), ShutdownPhase::Running);
986
987 coordinator.shutdown();
988
989 wait_terminated(&coordinator, Duration::from_secs(2)).await;
990 assert_eq!(coordinator.current_phase(), ShutdownPhase::Terminated);
991 }
992
993 #[tokio::test]
994 async fn test_drain_waits_for_in_flight_requests() {
995 let config = DrainConfig {
996 drain_timeout: Duration::from_secs(2),
997 check_interval: Duration::from_millis(50),
998 flush_timeout: Duration::from_millis(200),
999 };
1000 let coordinator = ShutdownCoordinator::with_config(config);
1001
1002 coordinator.request_start();
1004 coordinator.request_start();
1005 coordinator.request_start();
1006 assert_eq!(coordinator.active_request_count(), 3);
1007
1008 coordinator.shutdown();
1009
1010 tokio::time::sleep(Duration::from_millis(80)).await;
1012 assert_eq!(coordinator.current_phase(), ShutdownPhase::Draining);
1013
1014 coordinator.request_end();
1016 tokio::time::sleep(Duration::from_millis(60)).await;
1017 coordinator.request_end();
1018 tokio::time::sleep(Duration::from_millis(60)).await;
1019 coordinator.request_end();
1020
1021 wait_terminated(&coordinator, Duration::from_secs(2)).await;
1022 assert_eq!(coordinator.current_phase(), ShutdownPhase::Terminated);
1023 }
1024
1025 #[tokio::test]
1026 async fn test_drain_timeout_forces_termination() {
1027 let config = DrainConfig {
1028 drain_timeout: Duration::from_millis(150),
1029 check_interval: Duration::from_millis(30),
1030 flush_timeout: Duration::from_millis(100),
1031 };
1032 let coordinator = ShutdownCoordinator::with_config(config);
1033
1034 coordinator.request_start();
1036
1037 coordinator.shutdown();
1038
1039 wait_terminated(&coordinator, Duration::from_secs(2)).await;
1040 assert_eq!(coordinator.current_phase(), ShutdownPhase::Terminated);
1041 assert_eq!(coordinator.active_request_count(), 1);
1043 }
1044
1045 #[tokio::test]
1046 async fn test_shutdown_hooks_execute_in_order() {
1047 let config = DrainConfig {
1048 drain_timeout: Duration::from_millis(100),
1049 check_interval: Duration::from_millis(20),
1050 flush_timeout: Duration::from_secs(1),
1051 };
1052 let coordinator = ShutdownCoordinator::with_config(config);
1053
1054 let order = Arc::new(Mutex::new(Vec::<String>::new()));
1055
1056 struct OrderHook {
1057 hook_name: String,
1058 order: Arc<Mutex<Vec<String>>>,
1059 }
1060
1061 #[async_trait::async_trait]
1062 impl ShutdownHook for OrderHook {
1063 fn name(&self) -> &str {
1064 &self.hook_name
1065 }
1066 async fn on_shutdown(&self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
1067 let mut guard = self.order.lock().await;
1068 guard.push(self.hook_name.clone());
1069 Ok(())
1070 }
1071 }
1072
1073 coordinator
1074 .register_shutdown_hook(Box::new(OrderHook {
1075 hook_name: "first".to_string(),
1076 order: order.clone(),
1077 }))
1078 .await;
1079 coordinator
1080 .register_shutdown_hook(Box::new(OrderHook {
1081 hook_name: "second".to_string(),
1082 order: order.clone(),
1083 }))
1084 .await;
1085 coordinator
1086 .register_shutdown_hook(Box::new(OrderHook {
1087 hook_name: "third".to_string(),
1088 order: order.clone(),
1089 }))
1090 .await;
1091
1092 coordinator.shutdown();
1093 wait_terminated(&coordinator, Duration::from_secs(2)).await;
1094
1095 let executed = order.lock().await;
1096 assert_eq!(*executed, vec!["first", "second", "third"]);
1097 }
1098
1099 #[tokio::test]
1100 async fn test_hook_failure_does_not_block_others() {
1101 let config = DrainConfig {
1102 drain_timeout: Duration::from_millis(50),
1103 check_interval: Duration::from_millis(10),
1104 flush_timeout: Duration::from_secs(1),
1105 };
1106 let coordinator = ShutdownCoordinator::with_config(config);
1107
1108 let completed = Arc::new(StdAtomicBool::new(false));
1109
1110 struct FailingHook;
1111
1112 #[async_trait::async_trait]
1113 impl ShutdownHook for FailingHook {
1114 fn name(&self) -> &str {
1115 "failing"
1116 }
1117 async fn on_shutdown(&self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
1118 Err("intentional failure".into())
1119 }
1120 }
1121
1122 struct SuccessHook {
1123 completed: Arc<StdAtomicBool>,
1124 }
1125
1126 #[async_trait::async_trait]
1127 impl ShutdownHook for SuccessHook {
1128 fn name(&self) -> &str {
1129 "success_after_failure"
1130 }
1131 async fn on_shutdown(&self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
1132 self.completed.store(true, Ordering::SeqCst);
1133 Ok(())
1134 }
1135 }
1136
1137 coordinator
1138 .register_shutdown_hook(Box::new(FailingHook))
1139 .await;
1140 coordinator
1141 .register_shutdown_hook(Box::new(SuccessHook {
1142 completed: completed.clone(),
1143 }))
1144 .await;
1145
1146 coordinator.shutdown();
1147 wait_terminated(&coordinator, Duration::from_secs(2)).await;
1148
1149 assert!(
1150 completed.load(Ordering::SeqCst),
1151 "Hook after failing hook should still run"
1152 );
1153 assert_eq!(coordinator.inner.hooks_completed.load(Ordering::SeqCst), 2);
1154 }
1155
1156 #[tokio::test]
1157 async fn test_status_reporting() {
1158 let config = DrainConfig {
1159 drain_timeout: Duration::from_secs(1),
1160 check_interval: Duration::from_millis(50),
1161 flush_timeout: Duration::from_millis(200),
1162 };
1163 let coordinator = ShutdownCoordinator::with_config(config);
1164
1165 let st = coordinator.status();
1167 assert_eq!(st.phase, ShutdownPhase::Running);
1168 assert_eq!(st.active_requests, 0);
1169 assert_eq!(st.hooks_completed, 0);
1170 assert_eq!(st.elapsed_ms, 0);
1171
1172 coordinator.request_start();
1173 coordinator.request_start();
1174
1175 let st = coordinator.status();
1176 assert_eq!(st.active_requests, 2);
1177
1178 coordinator.request_end();
1179 coordinator.request_end();
1180
1181 coordinator.shutdown();
1182
1183 tokio::time::sleep(Duration::from_millis(20)).await;
1185
1186 let st = coordinator.status();
1188 assert!(st.elapsed_ms > 0, "elapsed_ms should be > 0 after shutdown");
1189
1190 wait_terminated(&coordinator, Duration::from_secs(2)).await;
1192
1193 let st = coordinator.status();
1194 assert_eq!(st.phase, ShutdownPhase::Terminated);
1195 }
1196
1197 #[tokio::test]
1198 async fn test_zero_active_requests_fast_shutdown() {
1199 let config = DrainConfig {
1200 drain_timeout: Duration::from_secs(30),
1201 check_interval: Duration::from_millis(50),
1202 flush_timeout: Duration::from_millis(100),
1203 };
1204 let coordinator = ShutdownCoordinator::with_config(config);
1205
1206 let start = Instant::now();
1207 coordinator.shutdown();
1208
1209 wait_terminated(&coordinator, Duration::from_secs(1)).await;
1210
1211 assert_eq!(coordinator.current_phase(), ShutdownPhase::Terminated);
1212 let elapsed = start.elapsed();
1214 assert!(
1215 elapsed < Duration::from_secs(1),
1216 "Fast shutdown should complete quickly, took {:?}",
1217 elapsed
1218 );
1219 }
1220
1221 #[tokio::test]
1222 async fn test_health_status_label() {
1223 let coordinator = ShutdownCoordinator::new();
1224 assert_eq!(coordinator.health_status_label(), "ok");
1225
1226 coordinator.shutdown();
1227 tokio::time::sleep(Duration::from_millis(50)).await;
1228
1229 assert_eq!(coordinator.health_status_label(), "shutting_down");
1231 }
1232
1233 #[tokio::test]
1234 async fn test_request_guard_raii() {
1235 let coordinator = ShutdownCoordinator::new();
1236 assert_eq!(coordinator.active_request_count(), 0);
1237
1238 {
1239 let _g1 = RequestGuard::new(coordinator.clone());
1240 assert_eq!(coordinator.active_request_count(), 1);
1241 {
1242 let _g2 = RequestGuard::new(coordinator.clone());
1243 assert_eq!(coordinator.active_request_count(), 2);
1244 }
1245 assert_eq!(coordinator.active_request_count(), 1);
1247 }
1248 assert_eq!(coordinator.active_request_count(), 0);
1250 }
1251
1252 #[tokio::test]
1253 async fn test_is_accepting() {
1254 let coordinator = ShutdownCoordinator::new();
1255 assert!(coordinator.is_accepting());
1256
1257 coordinator.shutdown();
1258 tokio::time::sleep(Duration::from_millis(50)).await;
1259
1260 assert!(!coordinator.is_accepting());
1261 }
1262
1263 #[tokio::test]
1264 async fn test_built_in_hooks() {
1265 let config = DrainConfig {
1266 drain_timeout: Duration::from_millis(50),
1267 check_interval: Duration::from_millis(10),
1268 flush_timeout: Duration::from_secs(5),
1269 };
1270 let coordinator = ShutdownCoordinator::with_config(config);
1271
1272 coordinator
1273 .register_shutdown_hook(Box::new(WalFlushHook::default()))
1274 .await;
1275 coordinator
1276 .register_shutdown_hook(Box::new(MemtableFlushHook::default()))
1277 .await;
1278 coordinator
1279 .register_shutdown_hook(Box::new(MetricsSnapshotHook::default()))
1280 .await;
1281
1282 let st = coordinator.status();
1283 assert_eq!(st.hooks_total, 3);
1284
1285 coordinator.shutdown();
1286 wait_terminated(&coordinator, Duration::from_secs(2)).await;
1287
1288 assert_eq!(coordinator.inner.hooks_completed.load(Ordering::SeqCst), 3);
1289 }
1290
1291 #[tokio::test]
1292 async fn test_multiple_shutdown_signals_idempotent() {
1293 let coordinator = ShutdownCoordinator::new();
1294 let mut rx = coordinator.subscribe();
1295
1296 coordinator.shutdown();
1298 let recv_result = timeout(Duration::from_millis(100), rx.recv()).await;
1299 assert!(recv_result.is_ok());
1300
1301 coordinator.shutdown();
1303 coordinator.shutdown();
1304 coordinator.shutdown();
1305
1306 assert!(coordinator.is_shutting_down());
1307
1308 wait_terminated(&coordinator, Duration::from_secs(2)).await;
1309 assert_eq!(coordinator.current_phase(), ShutdownPhase::Terminated);
1310 }
1311
1312 #[tokio::test]
1313 async fn test_drain_config_default() {
1314 let cfg = DrainConfig::default();
1315 assert_eq!(cfg.drain_timeout, Duration::from_secs(30));
1316 assert_eq!(cfg.check_interval, Duration::from_secs(1));
1317 assert_eq!(cfg.flush_timeout, Duration::from_secs(30));
1318 }
1319
1320 #[tokio::test]
1321 async fn test_phase_display() {
1322 assert_eq!(format!("{}", ShutdownPhase::Running), "Running");
1323 assert_eq!(format!("{}", ShutdownPhase::Draining), "Draining");
1324 assert_eq!(format!("{}", ShutdownPhase::FlushingState), "FlushingState");
1325 assert_eq!(format!("{}", ShutdownPhase::Terminated), "Terminated");
1326 }
1327
1328 struct MockWalWriter {
1332 sync_called: Arc<StdAtomicBool>,
1333 size: u64,
1334 should_fail: bool,
1335 }
1336
1337 impl WalWriter for MockWalWriter {
1338 fn sync(&self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
1339 self.sync_called.store(true, Ordering::SeqCst);
1340 if self.should_fail {
1341 return Err("WAL sync failed".into());
1342 }
1343 Ok(())
1344 }
1345
1346 fn current_size(&self) -> u64 {
1347 self.size
1348 }
1349 }
1350
1351 struct MockMemtableFlusher {
1353 flush_called: Arc<StdAtomicBool>,
1354 entries: usize,
1355 should_fail: bool,
1356 }
1357
1358 impl MemtableFlusher for MockMemtableFlusher {
1359 fn flush_to_sstable(&self) -> Result<usize, Box<dyn std::error::Error + Send + Sync>> {
1360 self.flush_called.store(true, Ordering::SeqCst);
1361 if self.should_fail {
1362 return Err("memtable flush failed".into());
1363 }
1364 Ok(self.entries)
1365 }
1366 }
1367
1368 #[tokio::test]
1369 async fn test_wal_flush_hook_calls_sync() {
1370 let sync_called = Arc::new(StdAtomicBool::new(false));
1371 let writer = Arc::new(MockWalWriter {
1372 sync_called: sync_called.clone(),
1373 size: 4096,
1374 should_fail: false,
1375 });
1376
1377 let hook = WalFlushHook::with_writer(writer, Duration::from_secs(5));
1378 let result = hook.on_shutdown().await;
1379
1380 assert!(result.is_ok());
1381 assert!(
1382 sync_called.load(Ordering::SeqCst),
1383 "sync() should have been called"
1384 );
1385 }
1386
1387 #[tokio::test]
1388 async fn test_wal_flush_hook_no_writer() {
1389 let hook = WalFlushHook::default();
1390 let result = hook.on_shutdown().await;
1391 assert!(result.is_ok(), "no-writer hook should succeed");
1392 }
1393
1394 #[tokio::test]
1395 async fn test_wal_flush_hook_error() {
1396 let sync_called = Arc::new(StdAtomicBool::new(false));
1397 let writer = Arc::new(MockWalWriter {
1398 sync_called: sync_called.clone(),
1399 size: 1024,
1400 should_fail: true,
1401 });
1402
1403 let hook = WalFlushHook::with_writer(writer, Duration::from_secs(5));
1404 let result = hook.on_shutdown().await;
1405
1406 assert!(result.is_err());
1407 assert!(
1408 sync_called.load(Ordering::SeqCst),
1409 "sync() should have been called even on failure"
1410 );
1411 let err_msg = result.expect_err("should be error").to_string();
1412 assert!(
1413 err_msg.contains("WAL sync failed"),
1414 "error message should propagate"
1415 );
1416 }
1417
1418 #[tokio::test]
1419 async fn test_memtable_flush_hook_calls_flush() {
1420 let flush_called = Arc::new(StdAtomicBool::new(false));
1421 let flusher = Arc::new(MockMemtableFlusher {
1422 flush_called: flush_called.clone(),
1423 entries: 42,
1424 should_fail: false,
1425 });
1426
1427 let hook = MemtableFlushHook::with_flusher(flusher, Duration::from_secs(5));
1428 let result = hook.on_shutdown().await;
1429
1430 assert!(result.is_ok());
1431 assert!(
1432 flush_called.load(Ordering::SeqCst),
1433 "flush_to_sstable() should have been called"
1434 );
1435 }
1436
1437 #[tokio::test]
1438 async fn test_memtable_flush_hook_no_flusher() {
1439 let hook = MemtableFlushHook::default();
1440 let result = hook.on_shutdown().await;
1441 assert!(result.is_ok(), "no-flusher hook should succeed");
1442 }
1443
1444 #[tokio::test]
1445 async fn test_connection_drain_immediate() {
1446 let conns = Arc::new(AtomicUsize::new(0));
1447 let hook = ConnectionDrainHook::new(conns, Duration::from_secs(5));
1448
1449 let start = Instant::now();
1450 let result = hook.on_shutdown().await;
1451 let elapsed = start.elapsed();
1452
1453 assert!(result.is_ok());
1454 assert!(
1455 elapsed < Duration::from_millis(50),
1456 "should return immediately with 0 connections, took {:?}",
1457 elapsed
1458 );
1459 }
1460
1461 #[tokio::test]
1462 async fn test_connection_drain_waits_for_zero() {
1463 let conns = Arc::new(AtomicUsize::new(5));
1464 let hook = ConnectionDrainHook::new(conns.clone(), Duration::from_secs(5))
1465 .with_poll_interval(Duration::from_millis(50));
1466
1467 let conns_clone = conns.clone();
1469 tokio::spawn(async move {
1470 for _ in 0..5 {
1471 tokio::time::sleep(Duration::from_millis(30)).await;
1472 conns_clone.fetch_sub(1, Ordering::SeqCst);
1473 }
1474 });
1475
1476 let result = hook.on_shutdown().await;
1477 assert!(result.is_ok());
1478 assert_eq!(conns.load(Ordering::SeqCst), 0);
1479 }
1480
1481 #[tokio::test]
1482 async fn test_connection_drain_timeout() {
1483 let conns = Arc::new(AtomicUsize::new(10));
1484 let hook = ConnectionDrainHook::new(conns.clone(), Duration::from_millis(200))
1485 .with_poll_interval(Duration::from_millis(50));
1486
1487 let start = Instant::now();
1488 let result = hook.on_shutdown().await;
1489 let elapsed = start.elapsed();
1490
1491 assert!(result.is_err(), "should error on timeout");
1492 let err_msg = result.expect_err("should be error").to_string();
1493 assert!(
1494 err_msg.contains("timed out"),
1495 "error should mention timeout"
1496 );
1497 assert!(
1498 elapsed >= Duration::from_millis(200),
1499 "should have waited at least the timeout duration, elapsed {:?}",
1500 elapsed
1501 );
1502 }
1503
1504 #[tokio::test]
1505 async fn test_hook_execution_result_captured() {
1506 let config = DrainConfig {
1507 drain_timeout: Duration::from_millis(50),
1508 check_interval: Duration::from_millis(10),
1509 flush_timeout: Duration::from_secs(1),
1510 };
1511 let coordinator = ShutdownCoordinator::with_config(config);
1512
1513 struct NamedHook {
1514 hook_name: String,
1515 }
1516
1517 #[async_trait::async_trait]
1518 impl ShutdownHook for NamedHook {
1519 fn name(&self) -> &str {
1520 &self.hook_name
1521 }
1522 async fn on_shutdown(&self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
1523 Ok(())
1524 }
1525 }
1526
1527 coordinator
1528 .register_shutdown_hook(Box::new(NamedHook {
1529 hook_name: "test_hook".to_string(),
1530 }))
1531 .await;
1532
1533 coordinator.shutdown();
1534 wait_terminated(&coordinator, Duration::from_secs(2)).await;
1535
1536 let results = coordinator.hook_results().await;
1537 assert_eq!(results.len(), 1);
1538 assert_eq!(results[0].hook_name, "test_hook");
1539 assert!(results[0].success);
1540 assert!(results[0].error.is_none());
1541 assert!(results[0].duration < Duration::from_secs(1));
1542 }
1543
1544 #[tokio::test]
1545 async fn test_hook_error_result() {
1546 let config = DrainConfig {
1547 drain_timeout: Duration::from_millis(50),
1548 check_interval: Duration::from_millis(10),
1549 flush_timeout: Duration::from_secs(1),
1550 };
1551 let coordinator = ShutdownCoordinator::with_config(config);
1552
1553 struct FailHook;
1554
1555 #[async_trait::async_trait]
1556 impl ShutdownHook for FailHook {
1557 fn name(&self) -> &str {
1558 "fail_hook"
1559 }
1560 async fn on_shutdown(&self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
1561 Err("catastrophic failure".into())
1562 }
1563 }
1564
1565 coordinator.register_shutdown_hook(Box::new(FailHook)).await;
1566
1567 coordinator.shutdown();
1568 wait_terminated(&coordinator, Duration::from_secs(2)).await;
1569
1570 let results = coordinator.hook_results().await;
1571 assert_eq!(results.len(), 1);
1572 assert_eq!(results[0].hook_name, "fail_hook");
1573 assert!(!results[0].success);
1574 assert!(results[0].error.is_some());
1575 let err = results[0].error.as_ref().expect("error should be present");
1576 assert!(
1577 err.contains("catastrophic failure"),
1578 "error should contain the failure message"
1579 );
1580 }
1581
1582 #[tokio::test]
1583 async fn test_metrics_snapshot_writes_file() {
1584 let dir = tempfile::tempdir().expect("failed to create temp dir");
1585 let path = dir.path().join("metrics.bin");
1586
1587 let expected_data = b"metric1=42\nmetric2=100\n".to_vec();
1588 let expected_clone = expected_data.clone();
1589 let provider: Arc<dyn Fn() -> Vec<u8> + Send + Sync> =
1590 Arc::new(move || expected_clone.clone());
1591
1592 let hook =
1593 MetricsSnapshotHook::with_provider(provider, path.clone(), Duration::from_secs(5));
1594 let result = hook.on_shutdown().await;
1595
1596 assert!(result.is_ok());
1597 let written = std::fs::read(&path).expect("should be able to read metrics file");
1598 assert_eq!(written, expected_data);
1599 }
1600
1601 #[tokio::test]
1602 async fn test_metrics_snapshot_no_provider() {
1603 let hook = MetricsSnapshotHook::default();
1604 let result = hook.on_shutdown().await;
1605 assert!(result.is_ok(), "no-provider hook should succeed");
1606 }
1607
1608 #[tokio::test]
1609 async fn test_connection_drain_poll_interval() {
1610 let conns = Arc::new(AtomicUsize::new(1));
1612 let poll_interval = Duration::from_millis(80);
1613 let hook = ConnectionDrainHook::new(conns.clone(), Duration::from_secs(5))
1614 .with_poll_interval(poll_interval);
1615
1616 let conns_clone = conns.clone();
1618 tokio::spawn(async move {
1619 tokio::time::sleep(Duration::from_millis(150)).await;
1620 conns_clone.store(0, Ordering::SeqCst);
1621 });
1622
1623 let start = Instant::now();
1624 let result = hook.on_shutdown().await;
1625 let elapsed = start.elapsed();
1626
1627 assert!(result.is_ok());
1628 assert!(
1630 elapsed >= Duration::from_millis(100),
1631 "should have polled at least once before completion, elapsed {:?}",
1632 elapsed
1633 );
1634 }
1635}