Skip to main content

amaters_server/
shutdown.rs

1//! Graceful shutdown handling
2//!
3//! This module provides signal handling for graceful server shutdown,
4//! coordinating the shutdown of all components in the correct order.
5//! It supports connection draining, phased shutdown, state persistence
6//! via shutdown hooks, and detailed status reporting.
7//! It also handles SIGHUP for hot configuration reload on Unix platforms.
8
9use crate::config::ReloadableConfig;
10use std::fmt;
11use std::path::PathBuf;
12use std::sync::Arc;
13use std::sync::atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering};
14use std::time::{Duration, Instant};
15use tokio::sync::{Mutex, broadcast, watch};
16use tracing::{debug, error, info, warn};
17
18// ---------------------------------------------------------------------------
19// Shutdown phase
20// ---------------------------------------------------------------------------
21
/// Lifecycle stages a server passes through while shutting down.
///
/// The coordinator moves through them in order:
/// `Running` -> `Draining` -> `FlushingState` -> `Terminated`.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum ShutdownPhase {
    /// Normal operation; new requests are accepted
    Running,
    /// New connections are refused while in-flight requests finish
    Draining,
    /// Shutdown hooks (WAL flush, memtable flush, metrics snapshot, etc.) are running
    FlushingState,
    /// Shutdown has finished
    Terminated,
}

impl ShutdownPhase {
    /// Encode the phase as an integer so it can be stored in an `AtomicU64`.
    fn as_u64(self) -> u64 {
        match self {
            ShutdownPhase::Running => 0,
            ShutdownPhase::Draining => 1,
            ShutdownPhase::FlushingState => 2,
            ShutdownPhase::Terminated => 3,
        }
    }

    /// Decode a value produced by [`as_u64`](Self::as_u64).
    ///
    /// `3` and every unknown value decode to `Terminated`, so an unexpected
    /// encoding is never reported as a live phase.
    fn from_u64(val: u64) -> Self {
        match val {
            0 => ShutdownPhase::Running,
            1 => ShutdownPhase::Draining,
            2 => ShutdownPhase::FlushingState,
            _ => ShutdownPhase::Terminated,
        }
    }
}

impl fmt::Display for ShutdownPhase {
    /// Render the variant name, e.g. `FlushingState`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            ShutdownPhase::Running => "Running",
            ShutdownPhase::Draining => "Draining",
            ShutdownPhase::FlushingState => "FlushingState",
            ShutdownPhase::Terminated => "Terminated",
        };
        f.write_str(label)
    }
}
67
68// ---------------------------------------------------------------------------
69// Drain configuration
70// ---------------------------------------------------------------------------
71
/// Tunables for the draining and hook-flushing phases of shutdown.
#[derive(Debug, Clone)]
pub struct DrainConfig {
    /// Upper bound on how long to wait for in-flight requests to finish
    pub drain_timeout: Duration,
    /// Pause between successive checks of the in-flight request count while
    /// draining (progress is logged on each check)
    pub check_interval: Duration,
    /// Timeout applied to each hook during the flushing-state phase
    pub flush_timeout: Duration,
}

impl Default for DrainConfig {
    /// 30 s drain timeout, 1 s check interval, 30 s per-hook flush timeout.
    fn default() -> Self {
        let thirty_seconds = Duration::from_secs(30);
        DrainConfig {
            drain_timeout: thirty_seconds,
            check_interval: Duration::from_secs(1),
            flush_timeout: thirty_seconds,
        }
    }
}
92
93// ---------------------------------------------------------------------------
94// Shutdown hook trait
95// ---------------------------------------------------------------------------
96
/// A hook that runs during the `FlushingState` phase of shutdown.
///
/// Implementors should perform any final persistence / cleanup work
/// (e.g. flushing WAL, memtable, metrics) in [`on_shutdown`](ShutdownHook::on_shutdown).
///
/// The [`ShutdownCoordinator`] stores hooks as `Box<dyn ShutdownHook>` and runs
/// them one at a time, in registration order, each bounded by
/// [`DrainConfig::flush_timeout`].
#[async_trait::async_trait]
pub trait ShutdownHook: Send + Sync {
    /// Human-readable name of this hook (used in logging and recorded in
    /// `HookExecutionResult::hook_name`)
    fn name(&self) -> &str;

    /// Execute the hook. Errors are logged but do **not** prevent other hooks
    /// from running.
    ///
    /// If the returned future does not finish within the coordinator's flush
    /// timeout it is dropped and the hook is recorded as failed. A timeout can
    /// only fire at an `.await` point, so long synchronous work inside this
    /// method should be offloaded (e.g. with `tokio::task::spawn_blocking`).
    async fn on_shutdown(&self) -> Result<(), Box<dyn std::error::Error + Send + Sync>>;
}
110
111// ---------------------------------------------------------------------------
112// Storage integration traits
113// ---------------------------------------------------------------------------
114
/// Trait for WAL sync operations during shutdown
///
/// Implemented by the storage layer and consumed by [`WalFlushHook`], which
/// reads `current_size` (for logging) and then calls `sync`. Both methods are
/// synchronous.
pub trait WalWriter: Send + Sync {
    /// Sync all pending WAL data to disk
    ///
    /// NOTE(review): presumably blocks until the data is durable (fsync-like);
    /// confirm against the storage implementation.
    fn sync(&self) -> Result<(), Box<dyn std::error::Error + Send + Sync>>;
    /// Return the current size of the WAL in bytes
    fn current_size(&self) -> u64;
}
122
/// Trait for memtable flush operations during shutdown
///
/// Implemented by the storage layer and consumed by [`MemtableFlushHook`].
/// The call is synchronous.
pub trait MemtableFlusher: Send + Sync {
    /// Flush the active memtable to an SSTable, returning the number of entries flushed
    fn flush_to_sstable(&self) -> Result<usize, Box<dyn std::error::Error + Send + Sync>>;
}
128
129// ---------------------------------------------------------------------------
130// Hook execution result
131// ---------------------------------------------------------------------------
132
/// Result of a single shutdown hook execution
///
/// The coordinator records one entry for every hook it runs during the
/// `FlushingState` phase: on success, on error, and on timeout. Entries are
/// available through `ShutdownCoordinator::hook_results`.
#[derive(Debug, Clone)]
pub struct HookExecutionResult {
    /// Name of the hook that was executed (as returned by `ShutdownHook::name`)
    pub hook_name: String,
    /// Whether the hook completed successfully
    pub success: bool,
    /// How long the hook took to execute (wall-clock, measured by the coordinator)
    pub duration: Duration,
    /// Error message if the hook failed (its error text, or a "timed out after
    /// ..." message); `None` when `success` is true
    pub error: Option<String>,
}
145
146// ---------------------------------------------------------------------------
147// Built-in hooks
148// ---------------------------------------------------------------------------
149
150/// Flush the Write-Ahead Log to disk
151pub struct WalFlushHook {
152    /// Timeout for this individual hook
153    pub timeout: Duration,
154    /// Optional WAL writer for real storage integration
155    writer: Option<Arc<dyn WalWriter>>,
156}
157
158impl WalFlushHook {
159    /// Create a WAL flush hook with a real writer
160    pub fn with_writer(writer: Arc<dyn WalWriter>, timeout: Duration) -> Self {
161        Self {
162            timeout,
163            writer: Some(writer),
164        }
165    }
166}
167
168impl Default for WalFlushHook {
169    fn default() -> Self {
170        Self {
171            timeout: Duration::from_secs(10),
172            writer: None,
173        }
174    }
175}
176
177#[async_trait::async_trait]
178impl ShutdownHook for WalFlushHook {
179    fn name(&self) -> &str {
180        "WalFlush"
181    }
182
183    async fn on_shutdown(&self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
184        match &self.writer {
185            Some(writer) => {
186                let size = writer.current_size();
187                info!("Flushing WAL to disk ({} bytes)", size);
188                writer.sync()?;
189                info!("WAL flush complete ({} bytes synced)", size);
190            }
191            None => {
192                info!("No WAL writer configured - skipping flush");
193            }
194        }
195        Ok(())
196    }
197}
198
199/// Flush the active memtable to an SSTable
200pub struct MemtableFlushHook {
201    /// Timeout for this individual hook
202    pub timeout: Duration,
203    /// Optional memtable flusher for real storage integration
204    flusher: Option<Arc<dyn MemtableFlusher>>,
205}
206
207impl MemtableFlushHook {
208    /// Create a memtable flush hook with a real flusher
209    pub fn with_flusher(flusher: Arc<dyn MemtableFlusher>, timeout: Duration) -> Self {
210        Self {
211            timeout,
212            flusher: Some(flusher),
213        }
214    }
215}
216
217impl Default for MemtableFlushHook {
218    fn default() -> Self {
219        Self {
220            timeout: Duration::from_secs(15),
221            flusher: None,
222        }
223    }
224}
225
226#[async_trait::async_trait]
227impl ShutdownHook for MemtableFlushHook {
228    fn name(&self) -> &str {
229        "MemtableFlush"
230    }
231
232    async fn on_shutdown(&self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
233        match &self.flusher {
234            Some(flusher) => {
235                info!("Flushing active memtable to SSTable");
236                let entries = flusher.flush_to_sstable()?;
237                info!("Memtable flush complete ({} entries flushed)", entries);
238            }
239            None => {
240                info!("No memtable flusher configured - skipping flush");
241            }
242        }
243        Ok(())
244    }
245}
246
247/// Drain active connections before shutdown
248pub struct ConnectionDrainHook {
249    /// Shared counter of active connections
250    active_connections: Arc<AtomicUsize>,
251    /// Maximum time to wait for connections to drain
252    drain_timeout: Duration,
253    /// Interval between polling the connection counter
254    poll_interval: Duration,
255}
256
257impl ConnectionDrainHook {
258    /// Create a new connection drain hook
259    ///
260    /// # Arguments
261    /// * `active_connections` - Shared atomic counter of active connections
262    /// * `drain_timeout` - Maximum time to wait for all connections to close
263    pub fn new(active_connections: Arc<AtomicUsize>, drain_timeout: Duration) -> Self {
264        Self {
265            active_connections,
266            drain_timeout,
267            poll_interval: Duration::from_millis(100),
268        }
269    }
270
271    /// Set a custom poll interval (default is 100ms)
272    pub fn with_poll_interval(mut self, interval: Duration) -> Self {
273        self.poll_interval = interval;
274        self
275    }
276}
277
278#[async_trait::async_trait]
279impl ShutdownHook for ConnectionDrainHook {
280    fn name(&self) -> &str {
281        "ConnectionDrain"
282    }
283
284    async fn on_shutdown(&self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
285        let deadline = Instant::now() + self.drain_timeout;
286
287        loop {
288            let remaining = self.active_connections.load(Ordering::SeqCst);
289            if remaining == 0 {
290                info!("All connections drained");
291                return Ok(());
292            }
293
294            if Instant::now() >= deadline {
295                warn!(
296                    "Connection drain timeout ({:?}) exceeded with {} connections remaining",
297                    self.drain_timeout, remaining
298                );
299                return Err(format!(
300                    "connection drain timed out with {} connections remaining",
301                    remaining
302                )
303                .into());
304            }
305
306            info!("Draining connections: {} remaining", remaining);
307            tokio::time::sleep(self.poll_interval).await;
308        }
309    }
310}
311
312/// Save a final metrics snapshot
313pub struct MetricsSnapshotHook {
314    /// Timeout for this individual hook
315    pub timeout: Duration,
316    /// Optional path to write metrics data to
317    metrics_path: Option<PathBuf>,
318    /// Optional provider that produces metrics data as bytes
319    metrics_provider: Option<Arc<dyn Fn() -> Vec<u8> + Send + Sync>>,
320}
321
322impl MetricsSnapshotHook {
323    /// Create a metrics snapshot hook with a provider and output path
324    pub fn with_provider(
325        provider: Arc<dyn Fn() -> Vec<u8> + Send + Sync>,
326        path: PathBuf,
327        timeout: Duration,
328    ) -> Self {
329        Self {
330            timeout,
331            metrics_path: Some(path),
332            metrics_provider: Some(provider),
333        }
334    }
335}
336
337impl Default for MetricsSnapshotHook {
338    fn default() -> Self {
339        Self {
340            timeout: Duration::from_secs(5),
341            metrics_path: None,
342            metrics_provider: None,
343        }
344    }
345}
346
347#[async_trait::async_trait]
348impl ShutdownHook for MetricsSnapshotHook {
349    fn name(&self) -> &str {
350        "MetricsSnapshot"
351    }
352
353    async fn on_shutdown(&self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
354        match (&self.metrics_provider, &self.metrics_path) {
355            (Some(provider), Some(path)) => {
356                let data = provider();
357                info!(
358                    "Writing {} bytes of metrics to {}",
359                    data.len(),
360                    path.display()
361                );
362                std::fs::write(path, &data)?;
363                info!("Metrics snapshot saved successfully");
364            }
365            _ => {
366                info!("No metrics provider/path configured - skipping snapshot");
367            }
368        }
369        Ok(())
370    }
371}
372
373// ---------------------------------------------------------------------------
374// Shutdown status
375// ---------------------------------------------------------------------------
376
/// Snapshot of the current shutdown progress
///
/// Built by `ShutdownCoordinator::status`. The fields are read one at a time
/// (not under a single lock), so while shutdown is in progress they may
/// reflect slightly different moments.
#[derive(Debug, Clone)]
pub struct ShutdownStatus {
    /// Current phase
    pub phase: ShutdownPhase,
    /// Number of in-flight requests still being processed
    pub active_requests: usize,
    /// Number of hooks that have completed (successfully, with an error, or by timing out)
    pub hooks_completed: usize,
    /// Total registered hooks
    pub hooks_total: usize,
    /// Milliseconds elapsed since shutdown was initiated (0 if not yet initiated)
    pub elapsed_ms: u64,
}
391
392// ---------------------------------------------------------------------------
393// Shutdown coordinator
394// ---------------------------------------------------------------------------
395
396/// Shutdown coordinator
397///
398/// Manages graceful shutdown across all server components, including connection
399/// draining, phased shutdown, and hook execution.
400#[derive(Clone)]
401pub struct ShutdownCoordinator {
402    inner: Arc<ShutdownInner>,
403}
404
405struct ShutdownInner {
406    /// Broadcast channel for the initial shutdown signal
407    sender: broadcast::Sender<()>,
408    /// Watch channel so late subscribers can observe phase changes
409    phase_tx: watch::Sender<ShutdownPhase>,
410    phase_rx: watch::Receiver<ShutdownPhase>,
411    /// Atomic flag indicating shutdown initiated (idempotent)
412    shutdown_initiated: AtomicBool,
413    /// Current phase stored atomically for lock-free reads
414    phase: AtomicU64,
415    /// Number of active (in-flight) requests
416    active_requests: AtomicUsize,
417    /// Registered shutdown hooks (protected by Mutex for append + iteration)
418    hooks: Mutex<Vec<Box<dyn ShutdownHook>>>,
419    /// Number of hooks completed so far
420    hooks_completed: AtomicUsize,
421    /// Results from hook execution
422    hook_results: Mutex<Vec<HookExecutionResult>>,
423    /// Drain configuration
424    drain_config: DrainConfig,
425    /// Instant when shutdown was initiated (set once, then read-only)
426    shutdown_start: Mutex<Option<Instant>>,
427}
428
429impl ShutdownCoordinator {
430    /// Create a new shutdown coordinator with default drain configuration
431    pub fn new() -> Self {
432        Self::with_config(DrainConfig::default())
433    }
434
435    /// Create a new shutdown coordinator with the given drain configuration
436    pub fn with_config(config: DrainConfig) -> Self {
437        let (sender, _) = broadcast::channel(16);
438        let (phase_tx, phase_rx) = watch::channel(ShutdownPhase::Running);
439
440        Self {
441            inner: Arc::new(ShutdownInner {
442                sender,
443                phase_tx,
444                phase_rx,
445                shutdown_initiated: AtomicBool::new(false),
446                phase: AtomicU64::new(ShutdownPhase::Running.as_u64()),
447                active_requests: AtomicUsize::new(0),
448                hooks: Mutex::new(Vec::new()),
449                hooks_completed: AtomicUsize::new(0),
450                hook_results: Mutex::new(Vec::new()),
451                drain_config: config,
452                shutdown_start: Mutex::new(None),
453            }),
454        }
455    }
456
457    // -- Subscription -------------------------------------------------------
458
459    /// Subscribe to the initial shutdown broadcast signal
460    pub fn subscribe(&self) -> broadcast::Receiver<()> {
461        self.inner.sender.subscribe()
462    }
463
464    /// Subscribe to phase changes via a `watch` channel
465    pub fn phase_watch(&self) -> watch::Receiver<ShutdownPhase> {
466        self.inner.phase_rx.clone()
467    }
468
469    // -- Active request tracking --------------------------------------------
470
471    /// Increment the active request counter (called when a new request arrives)
472    pub fn request_start(&self) {
473        self.inner.active_requests.fetch_add(1, Ordering::SeqCst);
474    }
475
476    /// Decrement the active request counter (called when a request completes)
477    pub fn request_end(&self) {
478        self.inner.active_requests.fetch_sub(1, Ordering::SeqCst);
479    }
480
481    /// Current number of active (in-flight) requests
482    pub fn active_request_count(&self) -> usize {
483        self.inner.active_requests.load(Ordering::SeqCst)
484    }
485
486    // -- Hook registration --------------------------------------------------
487
488    /// Register a shutdown hook that will run during the `FlushingState` phase
489    pub async fn register_shutdown_hook(&self, hook: Box<dyn ShutdownHook>) {
490        let mut hooks = self.inner.hooks.lock().await;
491        info!("Registered shutdown hook: {}", hook.name());
492        hooks.push(hook);
493    }
494
495    // -- Phase management ---------------------------------------------------
496
497    /// Get the current shutdown phase
498    pub fn current_phase(&self) -> ShutdownPhase {
499        ShutdownPhase::from_u64(self.inner.phase.load(Ordering::SeqCst))
500    }
501
502    /// Whether we are currently accepting new connections
503    pub fn is_accepting(&self) -> bool {
504        self.current_phase() == ShutdownPhase::Running
505    }
506
507    fn set_phase(&self, phase: ShutdownPhase) {
508        self.inner.phase.store(phase.as_u64(), Ordering::SeqCst);
509        // Ignore send error -- it only fails when there are no receivers
510        let _ = self.inner.phase_tx.send(phase);
511        info!("Shutdown phase: {}", phase);
512    }
513
514    // -- Query --------------------------------------------------------------
515
516    /// Check if shutdown has been initiated
517    pub fn is_shutting_down(&self) -> bool {
518        self.inner.shutdown_initiated.load(Ordering::SeqCst)
519    }
520
521    /// Returns `"shutting_down"` when draining/flushing/terminated, otherwise `"ok"`.
522    ///
523    /// Load balancers can poll this to detect that the server is shutting down
524    /// and stop routing new requests.
525    pub fn health_status_label(&self) -> &'static str {
526        match self.current_phase() {
527            ShutdownPhase::Running => "ok",
528            _ => "shutting_down",
529        }
530    }
531
532    /// Build a [`ShutdownStatus`] snapshot
533    pub fn status(&self) -> ShutdownStatus {
534        let elapsed_ms = {
535            // Try-lock: if the mutex is contended we just report 0
536            if let Ok(guard) = self.inner.shutdown_start.try_lock() {
537                guard.map(|s| s.elapsed().as_millis() as u64).unwrap_or(0)
538            } else {
539                0
540            }
541        };
542
543        let hooks_total = if let Ok(hooks) = self.inner.hooks.try_lock() {
544            hooks.len()
545        } else {
546            0
547        };
548
549        ShutdownStatus {
550            phase: self.current_phase(),
551            active_requests: self.active_request_count(),
552            hooks_completed: self.inner.hooks_completed.load(Ordering::SeqCst),
553            hooks_total,
554            elapsed_ms,
555        }
556    }
557
558    // -- Initiate shutdown --------------------------------------------------
559
560    /// Initiate a graceful shutdown.
561    ///
562    /// This is idempotent -- calling it more than once is a no-op.
563    /// The method broadcasts the shutdown signal, then (in a spawned task)
564    /// drives the phase machine: Draining -> FlushingState -> Terminated.
565    pub fn shutdown(&self) {
566        if self.inner.shutdown_initiated.swap(true, Ordering::SeqCst) {
567            // Already initiated
568            debug!("Shutdown already initiated - ignoring duplicate signal");
569            return;
570        }
571
572        info!("Initiating graceful shutdown");
573
574        // Record start time
575        if let Ok(mut guard) = self.inner.shutdown_start.try_lock() {
576            *guard = Some(Instant::now());
577        }
578
579        // Broadcast to legacy subscribers
580        if let Err(e) = self.inner.sender.send(()) {
581            warn!("Failed to broadcast shutdown signal: {}", e);
582        }
583
584        // Spawn the phase-driver task
585        let coord = self.clone();
586        tokio::spawn(async move {
587            coord.run_shutdown_sequence().await;
588        });
589    }
590
591    /// Execute the full shutdown sequence: Draining -> FlushingState -> Terminated
592    async fn run_shutdown_sequence(&self) {
593        // ---- Phase 1: Draining ----
594        self.set_phase(ShutdownPhase::Draining);
595        self.drain_connections().await;
596
597        // ---- Phase 2: Flushing state (run hooks) ----
598        self.set_phase(ShutdownPhase::FlushingState);
599        self.run_hooks().await;
600
601        // ---- Phase 3: Terminated ----
602        self.set_phase(ShutdownPhase::Terminated);
603        info!("Shutdown complete");
604    }
605
606    /// Wait for all active requests to drain, up to `drain_timeout`.
607    async fn drain_connections(&self) {
608        let cfg = &self.inner.drain_config;
609        let deadline = Instant::now() + cfg.drain_timeout;
610
611        loop {
612            let remaining = self.active_request_count();
613            if remaining == 0 {
614                info!("All in-flight requests drained");
615                return;
616            }
617
618            if Instant::now() >= deadline {
619                warn!(
620                    "Drain timeout ({:?}) exceeded with {} requests remaining - force-closing",
621                    cfg.drain_timeout, remaining
622                );
623                return;
624            }
625
626            info!("Draining: {} requests remaining", remaining);
627            tokio::time::sleep(cfg.check_interval).await;
628        }
629    }
630
631    /// Retrieve the results from all executed hooks.
632    ///
633    /// Returns an empty vector if shutdown has not been initiated or hooks
634    /// have not yet finished executing.
635    pub async fn hook_results(&self) -> Vec<HookExecutionResult> {
636        self.inner.hook_results.lock().await.clone()
637    }
638
639    /// Execute all registered shutdown hooks, each with the global flush timeout.
640    async fn run_hooks(&self) {
641        let hooks = {
642            let mut guard = self.inner.hooks.lock().await;
643            std::mem::take(&mut *guard)
644        };
645
646        if hooks.is_empty() {
647            info!("No shutdown hooks registered");
648            return;
649        }
650
651        let flush_timeout = self.inner.drain_config.flush_timeout;
652        info!("Executing {} shutdown hook(s)", hooks.len());
653
654        for hook in &hooks {
655            let name = hook.name().to_string();
656            info!("Running shutdown hook: {}", name);
657
658            let start = Instant::now();
659            let result = match tokio::time::timeout(flush_timeout, hook.on_shutdown()).await {
660                Ok(Ok(())) => {
661                    info!("Shutdown hook '{}' completed successfully", name);
662                    HookExecutionResult {
663                        hook_name: name,
664                        success: true,
665                        duration: start.elapsed(),
666                        error: None,
667                    }
668                }
669                Ok(Err(e)) => {
670                    let msg = e.to_string();
671                    error!("Shutdown hook '{}' failed: {}", name, msg);
672                    HookExecutionResult {
673                        hook_name: name,
674                        success: false,
675                        duration: start.elapsed(),
676                        error: Some(msg),
677                    }
678                }
679                Err(_) => {
680                    let msg = format!("timed out after {:?}", flush_timeout);
681                    error!("Shutdown hook '{}' {}", name, msg);
682                    HookExecutionResult {
683                        hook_name: name,
684                        success: false,
685                        duration: start.elapsed(),
686                        error: Some(msg),
687                    }
688                }
689            };
690
691            {
692                let mut results = self.inner.hook_results.lock().await;
693                results.push(result);
694            }
695            self.inner.hooks_completed.fetch_add(1, Ordering::SeqCst);
696        }
697
698        info!(
699            "All shutdown hooks processed ({} total)",
700            self.inner.hooks_completed.load(Ordering::SeqCst)
701        );
702    }
703}
704
705impl Default for ShutdownCoordinator {
706    fn default() -> Self {
707        Self::new()
708    }
709}
710
711// ---------------------------------------------------------------------------
712// Signal handler setup
713// ---------------------------------------------------------------------------
714
715/// Setup signal handlers for graceful shutdown
716///
717/// Listens for SIGTERM and SIGINT signals and triggers shutdown
718pub async fn setup_signal_handlers(coordinator: ShutdownCoordinator) {
719    tokio::spawn(async move {
720        if let Err(e) = wait_for_signal().await {
721            warn!("Error setting up signal handlers: {}", e);
722            return;
723        }
724
725        info!("Received shutdown signal");
726        coordinator.shutdown();
727    });
728}
729
730/// Setup SIGHUP handler for hot configuration reload (Unix only)
731///
732/// On SIGHUP, reloads configuration from the stored config file path.
733/// On non-Unix platforms, this is a no-op; use `ReloadableConfig::manual_reload()` instead.
734#[cfg(unix)]
735pub async fn setup_sighup_handler(config: ReloadableConfig) {
736    tokio::spawn(async move {
737        let mut sighup = match tokio::signal::unix::signal(tokio::signal::unix::SignalKind::hangup())
738        {
739            Ok(s) => s,
740            Err(e) => {
741                warn!("Failed to setup SIGHUP handler: {}", e);
742                return;
743            }
744        };
745
746        loop {
747            sighup.recv().await;
748            info!("Received SIGHUP - reloading configuration");
749
750            match config.reload_from_stored_path() {
751                Ok(report) => {
752                    if report.success {
753                        info!("Configuration reload completed: {}", report);
754                    } else {
755                        error!("Configuration reload failed: {}", report);
756                    }
757                }
758                Err(e) => {
759                    error!("Configuration reload error: {}", e);
760                }
761            }
762        }
763    });
764}
765
/// Fallback SIGHUP setup for platforms without Unix signals.
///
/// Only logs that signal-driven reload is unavailable; configuration can
/// still be reloaded by calling `ReloadableConfig::manual_reload()`.
#[cfg(not(unix))]
pub async fn setup_sighup_handler(_config: ReloadableConfig) {
    info!(
        "SIGHUP handler not available on this platform; use manual_reload() instead"
    );
}
773
774/// Wait for shutdown signal (SIGTERM or SIGINT)
775async fn wait_for_signal() -> Result<(), std::io::Error> {
776    #[cfg(unix)]
777    {
778        use tokio::signal::unix::{SignalKind, signal};
779
780        let mut sigterm = signal(SignalKind::terminate())?;
781        let mut sigint = signal(SignalKind::interrupt())?;
782
783        tokio::select! {
784            _ = sigterm.recv() => {
785                info!("Received SIGTERM");
786            }
787            _ = sigint.recv() => {
788                info!("Received SIGINT");
789            }
790        }
791    }
792
793    #[cfg(not(unix))]
794    {
795        use tokio::signal;
796        signal::ctrl_c().await?;
797        info!("Received Ctrl+C");
798    }
799
800    Ok(())
801}
802
803// ---------------------------------------------------------------------------
804// Shutdown guard
805// ---------------------------------------------------------------------------
806
807/// Shutdown guard for automatic cleanup
808///
809/// Triggers shutdown when dropped (useful for panic recovery)
810pub struct ShutdownGuard {
811    coordinator: ShutdownCoordinator,
812    disarmed: Arc<AtomicBool>,
813}
814
815impl ShutdownGuard {
816    /// Create a new shutdown guard
817    pub fn new(coordinator: ShutdownCoordinator) -> Self {
818        Self {
819            coordinator,
820            disarmed: Arc::new(AtomicBool::new(false)),
821        }
822    }
823
824    /// Disarm the guard (won't trigger shutdown on drop)
825    pub fn disarm(&self) {
826        self.disarmed.store(true, Ordering::SeqCst);
827    }
828}
829
830impl Drop for ShutdownGuard {
831    fn drop(&mut self) {
832        if !self.disarmed.load(Ordering::SeqCst) {
833            warn!("ShutdownGuard dropped without disarming - triggering shutdown");
834            self.coordinator.shutdown();
835        }
836    }
837}
838
839// ---------------------------------------------------------------------------
840// Request guard (RAII active-request tracking)
841// ---------------------------------------------------------------------------
842
843/// RAII guard that tracks an active request.
844///
845/// Calls `request_start()` on creation and `request_end()` on drop,
846/// ensuring the active-request counter stays accurate even if the
847/// request handler panics.
848pub struct RequestGuard {
849    coordinator: ShutdownCoordinator,
850}
851
852impl RequestGuard {
853    /// Create a new request guard, incrementing the active request count
854    pub fn new(coordinator: ShutdownCoordinator) -> Self {
855        coordinator.request_start();
856        Self { coordinator }
857    }
858}
859
860impl Drop for RequestGuard {
861    fn drop(&mut self) {
862        self.coordinator.request_end();
863    }
864}
865
866// ---------------------------------------------------------------------------
867// Tests
868// ---------------------------------------------------------------------------
869
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::AtomicBool as StdAtomicBool;
    use std::time::Duration;
    use tokio::time::timeout;

    /// Helper: wait for a coordinator to reach the Terminated phase.
    ///
    /// Panics if `Terminated` is not observed within `dur`. A hung or stalled
    /// shutdown then fails the calling test right here with a clear message,
    /// instead of being silently ignored.
    async fn wait_terminated(coordinator: &ShutdownCoordinator, dur: Duration) {
        let mut watcher = coordinator.phase_watch();
        let reached = timeout(dur, async {
            loop {
                if *watcher.borrow() == ShutdownPhase::Terminated {
                    return true;
                }
                if watcher.changed().await.is_err() {
                    // The sender is gone, so no further transitions can arrive.
                    // Report whatever phase was last published.
                    return *watcher.borrow() == ShutdownPhase::Terminated;
                }
            }
        })
        .await
        .unwrap_or(false);
        assert!(
            reached,
            "coordinator did not reach Terminated within {:?} (current phase: {})",
            dur,
            coordinator.current_phase()
        );
    }

    #[tokio::test]
    async fn test_shutdown_coordinator() {
        let coordinator = ShutdownCoordinator::new();
        let mut receiver = coordinator.subscribe();

        assert!(!coordinator.is_shutting_down());
        assert_eq!(coordinator.current_phase(), ShutdownPhase::Running);

        coordinator.shutdown();

        assert!(coordinator.is_shutting_down());

        // Should receive shutdown signal
        let result = timeout(Duration::from_millis(100), receiver.recv()).await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_multiple_subscribers() {
        let coordinator = ShutdownCoordinator::new();
        let mut rx1 = coordinator.subscribe();
        let mut rx2 = coordinator.subscribe();
        let mut rx3 = coordinator.subscribe();

        coordinator.shutdown();

        assert!(
            timeout(Duration::from_millis(100), rx1.recv())
                .await
                .is_ok()
        );
        assert!(
            timeout(Duration::from_millis(100), rx2.recv())
                .await
                .is_ok()
        );
        assert!(
            timeout(Duration::from_millis(100), rx3.recv())
                .await
                .is_ok()
        );
    }

    #[tokio::test]
    async fn test_shutdown_idempotent() {
        let coordinator = ShutdownCoordinator::new();

        coordinator.shutdown();
        coordinator.shutdown(); // Second call should be a no-op

        assert!(coordinator.is_shutting_down());

        // Let the phase driver complete
        wait_terminated(&coordinator, Duration::from_secs(2)).await;
        assert_eq!(coordinator.current_phase(), ShutdownPhase::Terminated);
    }

    #[test]
    fn test_shutdown_guard_disarm() {
        let coordinator = ShutdownCoordinator::new();
        let guard = ShutdownGuard::new(coordinator.clone());

        guard.disarm();
        drop(guard);

        assert!(!coordinator.is_shutting_down());
    }

    #[tokio::test]
    async fn test_shutdown_guard_trigger() {
        let coordinator = ShutdownCoordinator::new();
        let guard = ShutdownGuard::new(coordinator.clone());

        drop(guard);

        assert!(coordinator.is_shutting_down());

        // Let the spawned phase-driver complete
        wait_terminated(&coordinator, Duration::from_secs(2)).await;
    }

    // -- Phase transition tests ---------------------------------------------

    #[tokio::test]
    async fn test_phase_transitions() {
        let config = DrainConfig {
            drain_timeout: Duration::from_millis(200),
            check_interval: Duration::from_millis(50),
            flush_timeout: Duration::from_millis(200),
        };
        let coordinator = ShutdownCoordinator::with_config(config);

        assert_eq!(coordinator.current_phase(), ShutdownPhase::Running);

        coordinator.shutdown();

        wait_terminated(&coordinator, Duration::from_secs(2)).await;
        assert_eq!(coordinator.current_phase(), ShutdownPhase::Terminated);
    }

    #[tokio::test]
    async fn test_drain_waits_for_in_flight_requests() {
        let config = DrainConfig {
            drain_timeout: Duration::from_secs(2),
            check_interval: Duration::from_millis(50),
            flush_timeout: Duration::from_millis(200),
        };
        let coordinator = ShutdownCoordinator::with_config(config);

        // Simulate 3 in-flight requests
        coordinator.request_start();
        coordinator.request_start();
        coordinator.request_start();
        assert_eq!(coordinator.active_request_count(), 3);

        coordinator.shutdown();

        // Give the drainer a moment to start
        tokio::time::sleep(Duration::from_millis(80)).await;
        assert_eq!(coordinator.current_phase(), ShutdownPhase::Draining);

        // Complete requests one by one
        coordinator.request_end();
        tokio::time::sleep(Duration::from_millis(60)).await;
        coordinator.request_end();
        tokio::time::sleep(Duration::from_millis(60)).await;
        coordinator.request_end();

        wait_terminated(&coordinator, Duration::from_secs(2)).await;
        assert_eq!(coordinator.current_phase(), ShutdownPhase::Terminated);
    }

    #[tokio::test]
    async fn test_drain_timeout_forces_termination() {
        let config = DrainConfig {
            drain_timeout: Duration::from_millis(150),
            check_interval: Duration::from_millis(30),
            flush_timeout: Duration::from_millis(100),
        };
        let coordinator = ShutdownCoordinator::with_config(config);

        // Simulate a request that never finishes
        coordinator.request_start();

        coordinator.shutdown();

        wait_terminated(&coordinator, Duration::from_secs(2)).await;
        assert_eq!(coordinator.current_phase(), ShutdownPhase::Terminated);
        // The stuck request is still counted
        assert_eq!(coordinator.active_request_count(), 1);
    }

    #[tokio::test]
    async fn test_shutdown_hooks_execute_in_order() {
        let config = DrainConfig {
            drain_timeout: Duration::from_millis(100),
            check_interval: Duration::from_millis(20),
            flush_timeout: Duration::from_secs(1),
        };
        let coordinator = ShutdownCoordinator::with_config(config);

        let order = Arc::new(Mutex::new(Vec::<String>::new()));

        struct OrderHook {
            hook_name: String,
            order: Arc<Mutex<Vec<String>>>,
        }

        #[async_trait::async_trait]
        impl ShutdownHook for OrderHook {
            fn name(&self) -> &str {
                &self.hook_name
            }
            async fn on_shutdown(&self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
                let mut guard = self.order.lock().await;
                guard.push(self.hook_name.clone());
                Ok(())
            }
        }

        coordinator
            .register_shutdown_hook(Box::new(OrderHook {
                hook_name: "first".to_string(),
                order: order.clone(),
            }))
            .await;
        coordinator
            .register_shutdown_hook(Box::new(OrderHook {
                hook_name: "second".to_string(),
                order: order.clone(),
            }))
            .await;
        coordinator
            .register_shutdown_hook(Box::new(OrderHook {
                hook_name: "third".to_string(),
                order: order.clone(),
            }))
            .await;

        coordinator.shutdown();
        wait_terminated(&coordinator, Duration::from_secs(2)).await;

        let executed = order.lock().await;
        assert_eq!(*executed, vec!["first", "second", "third"]);
    }

    #[tokio::test]
    async fn test_hook_failure_does_not_block_others() {
        let config = DrainConfig {
            drain_timeout: Duration::from_millis(50),
            check_interval: Duration::from_millis(10),
            flush_timeout: Duration::from_secs(1),
        };
        let coordinator = ShutdownCoordinator::with_config(config);

        let completed = Arc::new(StdAtomicBool::new(false));

        struct FailingHook;

        #[async_trait::async_trait]
        impl ShutdownHook for FailingHook {
            fn name(&self) -> &str {
                "failing"
            }
            async fn on_shutdown(&self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
                Err("intentional failure".into())
            }
        }

        struct SuccessHook {
            completed: Arc<StdAtomicBool>,
        }

        #[async_trait::async_trait]
        impl ShutdownHook for SuccessHook {
            fn name(&self) -> &str {
                "success_after_failure"
            }
            async fn on_shutdown(&self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
                self.completed.store(true, Ordering::SeqCst);
                Ok(())
            }
        }

        coordinator
            .register_shutdown_hook(Box::new(FailingHook))
            .await;
        coordinator
            .register_shutdown_hook(Box::new(SuccessHook {
                completed: completed.clone(),
            }))
            .await;

        coordinator.shutdown();
        wait_terminated(&coordinator, Duration::from_secs(2)).await;

        assert!(
            completed.load(Ordering::SeqCst),
            "Hook after failing hook should still run"
        );
        assert_eq!(coordinator.inner.hooks_completed.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn test_status_reporting() {
        let config = DrainConfig {
            drain_timeout: Duration::from_secs(1),
            check_interval: Duration::from_millis(50),
            flush_timeout: Duration::from_millis(200),
        };
        let coordinator = ShutdownCoordinator::with_config(config);

        // Before shutdown
        let st = coordinator.status();
        assert_eq!(st.phase, ShutdownPhase::Running);
        assert_eq!(st.active_requests, 0);
        assert_eq!(st.hooks_completed, 0);
        assert_eq!(st.elapsed_ms, 0);

        coordinator.request_start();
        coordinator.request_start();

        let st = coordinator.status();
        assert_eq!(st.active_requests, 2);

        coordinator.request_end();
        coordinator.request_end();

        coordinator.shutdown();

        // Give it a moment so elapsed_ms is measurably > 0
        tokio::time::sleep(Duration::from_millis(20)).await;

        // Check while still running or after completion
        let st = coordinator.status();
        assert!(st.elapsed_ms > 0, "elapsed_ms should be > 0 after shutdown");

        // Wait for full completion
        wait_terminated(&coordinator, Duration::from_secs(2)).await;

        let st = coordinator.status();
        assert_eq!(st.phase, ShutdownPhase::Terminated);
    }

    #[tokio::test]
    async fn test_zero_active_requests_fast_shutdown() {
        let config = DrainConfig {
            drain_timeout: Duration::from_secs(30),
            check_interval: Duration::from_millis(50),
            flush_timeout: Duration::from_millis(100),
        };
        let coordinator = ShutdownCoordinator::with_config(config);

        let start = Instant::now();
        coordinator.shutdown();

        wait_terminated(&coordinator, Duration::from_secs(1)).await;

        assert_eq!(coordinator.current_phase(), ShutdownPhase::Terminated);
        // With zero requests the drain phase should be nearly instant
        let elapsed = start.elapsed();
        assert!(
            elapsed < Duration::from_secs(1),
            "Fast shutdown should complete quickly, took {:?}",
            elapsed
        );
    }

    #[tokio::test]
    async fn test_health_status_label() {
        let coordinator = ShutdownCoordinator::new();
        assert_eq!(coordinator.health_status_label(), "ok");

        coordinator.shutdown();
        tokio::time::sleep(Duration::from_millis(50)).await;

        // Once shutdown begins, label changes
        assert_eq!(coordinator.health_status_label(), "shutting_down");
    }

    #[tokio::test]
    async fn test_request_guard_raii() {
        let coordinator = ShutdownCoordinator::new();
        assert_eq!(coordinator.active_request_count(), 0);

        {
            let _g1 = RequestGuard::new(coordinator.clone());
            assert_eq!(coordinator.active_request_count(), 1);
            {
                let _g2 = RequestGuard::new(coordinator.clone());
                assert_eq!(coordinator.active_request_count(), 2);
            }
            // g2 dropped
            assert_eq!(coordinator.active_request_count(), 1);
        }
        // g1 dropped
        assert_eq!(coordinator.active_request_count(), 0);
    }

    #[tokio::test]
    async fn test_is_accepting() {
        let coordinator = ShutdownCoordinator::new();
        assert!(coordinator.is_accepting());

        coordinator.shutdown();
        tokio::time::sleep(Duration::from_millis(50)).await;

        assert!(!coordinator.is_accepting());
    }

    #[tokio::test]
    async fn test_built_in_hooks() {
        let config = DrainConfig {
            drain_timeout: Duration::from_millis(50),
            check_interval: Duration::from_millis(10),
            flush_timeout: Duration::from_secs(5),
        };
        let coordinator = ShutdownCoordinator::with_config(config);

        coordinator
            .register_shutdown_hook(Box::new(WalFlushHook::default()))
            .await;
        coordinator
            .register_shutdown_hook(Box::new(MemtableFlushHook::default()))
            .await;
        coordinator
            .register_shutdown_hook(Box::new(MetricsSnapshotHook::default()))
            .await;

        let st = coordinator.status();
        assert_eq!(st.hooks_total, 3);

        coordinator.shutdown();
        wait_terminated(&coordinator, Duration::from_secs(2)).await;

        assert_eq!(coordinator.inner.hooks_completed.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_multiple_shutdown_signals_idempotent() {
        let coordinator = ShutdownCoordinator::new();
        let mut rx = coordinator.subscribe();

        // First call should succeed
        coordinator.shutdown();
        let recv_result = timeout(Duration::from_millis(100), rx.recv()).await;
        assert!(recv_result.is_ok());

        // Subsequent calls are no-ops (no additional broadcast)
        coordinator.shutdown();
        coordinator.shutdown();
        coordinator.shutdown();

        assert!(coordinator.is_shutting_down());

        wait_terminated(&coordinator, Duration::from_secs(2)).await;
        assert_eq!(coordinator.current_phase(), ShutdownPhase::Terminated);
    }

    // Pure value checks: no async runtime is needed.
    #[test]
    fn test_drain_config_default() {
        let cfg = DrainConfig::default();
        assert_eq!(cfg.drain_timeout, Duration::from_secs(30));
        assert_eq!(cfg.check_interval, Duration::from_secs(1));
        assert_eq!(cfg.flush_timeout, Duration::from_secs(30));
    }

    #[test]
    fn test_phase_display() {
        assert_eq!(format!("{}", ShutdownPhase::Running), "Running");
        assert_eq!(format!("{}", ShutdownPhase::Draining), "Draining");
        assert_eq!(format!("{}", ShutdownPhase::FlushingState), "FlushingState");
        assert_eq!(format!("{}", ShutdownPhase::Terminated), "Terminated");
    }

    // -- Storage integration hook tests -------------------------------------

    /// Mock WalWriter for testing
    struct MockWalWriter {
        sync_called: Arc<StdAtomicBool>,
        size: u64,
        should_fail: bool,
    }

    impl WalWriter for MockWalWriter {
        fn sync(&self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
            self.sync_called.store(true, Ordering::SeqCst);
            if self.should_fail {
                return Err("WAL sync failed".into());
            }
            Ok(())
        }

        fn current_size(&self) -> u64 {
            self.size
        }
    }

    /// Mock MemtableFlusher for testing
    struct MockMemtableFlusher {
        flush_called: Arc<StdAtomicBool>,
        entries: usize,
        should_fail: bool,
    }

    impl MemtableFlusher for MockMemtableFlusher {
        fn flush_to_sstable(&self) -> Result<usize, Box<dyn std::error::Error + Send + Sync>> {
            self.flush_called.store(true, Ordering::SeqCst);
            if self.should_fail {
                return Err("memtable flush failed".into());
            }
            Ok(self.entries)
        }
    }

    #[tokio::test]
    async fn test_wal_flush_hook_calls_sync() {
        let sync_called = Arc::new(StdAtomicBool::new(false));
        let writer = Arc::new(MockWalWriter {
            sync_called: sync_called.clone(),
            size: 4096,
            should_fail: false,
        });

        let hook = WalFlushHook::with_writer(writer, Duration::from_secs(5));
        let result = hook.on_shutdown().await;

        assert!(result.is_ok());
        assert!(
            sync_called.load(Ordering::SeqCst),
            "sync() should have been called"
        );
    }

    #[tokio::test]
    async fn test_wal_flush_hook_no_writer() {
        let hook = WalFlushHook::default();
        let result = hook.on_shutdown().await;
        assert!(result.is_ok(), "no-writer hook should succeed");
    }

    #[tokio::test]
    async fn test_wal_flush_hook_error() {
        let sync_called = Arc::new(StdAtomicBool::new(false));
        let writer = Arc::new(MockWalWriter {
            sync_called: sync_called.clone(),
            size: 1024,
            should_fail: true,
        });

        let hook = WalFlushHook::with_writer(writer, Duration::from_secs(5));
        let result = hook.on_shutdown().await;

        assert!(result.is_err());
        assert!(
            sync_called.load(Ordering::SeqCst),
            "sync() should have been called even on failure"
        );
        let err_msg = result.expect_err("should be error").to_string();
        assert!(
            err_msg.contains("WAL sync failed"),
            "error message should propagate"
        );
    }

    #[tokio::test]
    async fn test_memtable_flush_hook_calls_flush() {
        let flush_called = Arc::new(StdAtomicBool::new(false));
        let flusher = Arc::new(MockMemtableFlusher {
            flush_called: flush_called.clone(),
            entries: 42,
            should_fail: false,
        });

        let hook = MemtableFlushHook::with_flusher(flusher, Duration::from_secs(5));
        let result = hook.on_shutdown().await;

        assert!(result.is_ok());
        assert!(
            flush_called.load(Ordering::SeqCst),
            "flush_to_sstable() should have been called"
        );
    }

    #[tokio::test]
    async fn test_memtable_flush_hook_no_flusher() {
        let hook = MemtableFlushHook::default();
        let result = hook.on_shutdown().await;
        assert!(result.is_ok(), "no-flusher hook should succeed");
    }

    #[tokio::test]
    async fn test_connection_drain_immediate() {
        let conns = Arc::new(AtomicUsize::new(0));
        let hook = ConnectionDrainHook::new(conns, Duration::from_secs(5));

        let start = Instant::now();
        let result = hook.on_shutdown().await;
        let elapsed = start.elapsed();

        assert!(result.is_ok());
        assert!(
            elapsed < Duration::from_millis(50),
            "should return immediately with 0 connections, took {:?}",
            elapsed
        );
    }

    #[tokio::test]
    async fn test_connection_drain_waits_for_zero() {
        let conns = Arc::new(AtomicUsize::new(5));
        let hook = ConnectionDrainHook::new(conns.clone(), Duration::from_secs(5))
            .with_poll_interval(Duration::from_millis(50));

        // Spawn a task that decrements connections over time
        let conns_clone = conns.clone();
        tokio::spawn(async move {
            for _ in 0..5 {
                tokio::time::sleep(Duration::from_millis(30)).await;
                conns_clone.fetch_sub(1, Ordering::SeqCst);
            }
        });

        let result = hook.on_shutdown().await;
        assert!(result.is_ok());
        assert_eq!(conns.load(Ordering::SeqCst), 0);
    }

    #[tokio::test]
    async fn test_connection_drain_timeout() {
        let conns = Arc::new(AtomicUsize::new(10));
        let hook = ConnectionDrainHook::new(conns.clone(), Duration::from_millis(200))
            .with_poll_interval(Duration::from_millis(50));

        let start = Instant::now();
        let result = hook.on_shutdown().await;
        let elapsed = start.elapsed();

        assert!(result.is_err(), "should error on timeout");
        let err_msg = result.expect_err("should be error").to_string();
        assert!(
            err_msg.contains("timed out"),
            "error should mention timeout"
        );
        assert!(
            elapsed >= Duration::from_millis(200),
            "should have waited at least the timeout duration, elapsed {:?}",
            elapsed
        );
    }

    #[tokio::test]
    async fn test_hook_execution_result_captured() {
        let config = DrainConfig {
            drain_timeout: Duration::from_millis(50),
            check_interval: Duration::from_millis(10),
            flush_timeout: Duration::from_secs(1),
        };
        let coordinator = ShutdownCoordinator::with_config(config);

        struct NamedHook {
            hook_name: String,
        }

        #[async_trait::async_trait]
        impl ShutdownHook for NamedHook {
            fn name(&self) -> &str {
                &self.hook_name
            }
            async fn on_shutdown(&self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
                Ok(())
            }
        }

        coordinator
            .register_shutdown_hook(Box::new(NamedHook {
                hook_name: "test_hook".to_string(),
            }))
            .await;

        coordinator.shutdown();
        wait_terminated(&coordinator, Duration::from_secs(2)).await;

        let results = coordinator.hook_results().await;
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].hook_name, "test_hook");
        assert!(results[0].success);
        assert!(results[0].error.is_none());
        assert!(results[0].duration < Duration::from_secs(1));
    }

    #[tokio::test]
    async fn test_hook_error_result() {
        let config = DrainConfig {
            drain_timeout: Duration::from_millis(50),
            check_interval: Duration::from_millis(10),
            flush_timeout: Duration::from_secs(1),
        };
        let coordinator = ShutdownCoordinator::with_config(config);

        struct FailHook;

        #[async_trait::async_trait]
        impl ShutdownHook for FailHook {
            fn name(&self) -> &str {
                "fail_hook"
            }
            async fn on_shutdown(&self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
                Err("catastrophic failure".into())
            }
        }

        coordinator.register_shutdown_hook(Box::new(FailHook)).await;

        coordinator.shutdown();
        wait_terminated(&coordinator, Duration::from_secs(2)).await;

        let results = coordinator.hook_results().await;
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].hook_name, "fail_hook");
        assert!(!results[0].success);
        assert!(results[0].error.is_some());
        let err = results[0].error.as_ref().expect("error should be present");
        assert!(
            err.contains("catastrophic failure"),
            "error should contain the failure message"
        );
    }

    #[tokio::test]
    async fn test_metrics_snapshot_writes_file() {
        let dir = tempfile::tempdir().expect("failed to create temp dir");
        let path = dir.path().join("metrics.bin");

        let expected_data = b"metric1=42\nmetric2=100\n".to_vec();
        let expected_clone = expected_data.clone();
        let provider: Arc<dyn Fn() -> Vec<u8> + Send + Sync> =
            Arc::new(move || expected_clone.clone());

        let hook =
            MetricsSnapshotHook::with_provider(provider, path.clone(), Duration::from_secs(5));
        let result = hook.on_shutdown().await;

        assert!(result.is_ok());
        let written = std::fs::read(&path).expect("should be able to read metrics file");
        assert_eq!(written, expected_data);
    }

    #[tokio::test]
    async fn test_metrics_snapshot_no_provider() {
        let hook = MetricsSnapshotHook::default();
        let result = hook.on_shutdown().await;
        assert!(result.is_ok(), "no-provider hook should succeed");
    }

    #[tokio::test]
    async fn test_connection_drain_poll_interval() {
        // Use a connection count that will require multiple polls
        let conns = Arc::new(AtomicUsize::new(1));
        let poll_interval = Duration::from_millis(80);
        let hook = ConnectionDrainHook::new(conns.clone(), Duration::from_secs(5))
            .with_poll_interval(poll_interval);

        // Spawn task to zero out connections after ~150ms
        let conns_clone = conns.clone();
        tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(150)).await;
            conns_clone.store(0, Ordering::SeqCst);
        });

        let start = Instant::now();
        let result = hook.on_shutdown().await;
        let elapsed = start.elapsed();

        assert!(result.is_ok());
        // Should have polled at least once (~80ms) before connections hit 0 at ~150ms
        assert!(
            elapsed >= Duration::from_millis(100),
            "should have polled at least once before completion, elapsed {:?}",
            elapsed
        );
    }
}