use std::sync::atomic::{AtomicBool, AtomicU32, Ordering};
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::watch;
use crate::service::health::HealthState;
/// Timing configuration for graceful shutdown.
///
/// Construct with [`ShutdownConfig::new`] (or [`Default::default`]) and adjust with the
/// `with_*` builder methods.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ShutdownConfig {
    /// Maximum time `ShutdownManager::wait_for_drain` waits for the active-stream
    /// count to reach zero before giving up. Defaults to 30 seconds.
    pub drain_timeout: Duration,
    /// Additional grace period. Defaults to 5 seconds.
    ///
    /// NOTE(review): `ShutdownManager` only stores this value; presumably it is consumed
    /// by the caller driving shutdown — confirm.
    pub grace_period: Duration,
}
impl Default for ShutdownConfig {
    /// Produces a config with a 30 second drain timeout and a 5 second grace period.
    fn default() -> Self {
        const DEFAULT_DRAIN_SECS: u64 = 30;
        const DEFAULT_GRACE_SECS: u64 = 5;
        Self {
            grace_period: Duration::from_secs(DEFAULT_GRACE_SECS),
            drain_timeout: Duration::from_secs(DEFAULT_DRAIN_SECS),
        }
    }
}
impl ShutdownConfig {
    /// Builder entry point; identical to [`ShutdownConfig::default`].
    #[must_use]
    pub fn new() -> Self {
        <Self as Default>::default()
    }

    /// Returns this config with `drain_timeout` replaced by `timeout`.
    #[must_use]
    pub const fn with_drain_timeout(self, timeout: Duration) -> Self {
        Self {
            drain_timeout: timeout,
            ..self
        }
    }

    /// Returns this config with `grace_period` replaced by `period`.
    #[must_use]
    pub const fn with_grace_period(self, period: Duration) -> Self {
        Self {
            grace_period: period,
            ..self
        }
    }
}
/// Coordinates graceful shutdown.
///
/// It counts in-flight streams, records whether shutdown was requested, and broadcasts the
/// shutdown signal through a `tokio::sync::watch` channel.
pub struct ShutdownManager {
    /// Timing configuration supplied at construction.
    config: ShutdownConfig,
    /// Shared health reporter; marked as draining when shutdown is triggered.
    health_state: Arc<HealthState>,
    /// Sending half of the shutdown signal (`false` initially, `true` once triggered).
    shutdown_tx: watch::Sender<bool>,
    /// Retained receiver; cloned to serve new subscribers.
    shutdown_rx: watch::Receiver<bool>,
    /// Number of streams currently registered as in flight.
    active_streams: AtomicU32,
    /// Set once shutdown has been requested.
    shutdown_triggered: AtomicBool,
}
impl ShutdownManager {
    /// Creates a manager with zero active streams and shutdown not yet triggered.
    ///
    /// The shutdown watch channel starts at `false`.
    #[must_use]
    pub fn new(config: ShutdownConfig, health_state: Arc<HealthState>) -> Self {
        let (shutdown_tx, shutdown_rx) = watch::channel(false);
        Self {
            config,
            health_state,
            active_streams: AtomicU32::new(0),
            shutdown_triggered: AtomicBool::new(false),
            shutdown_tx,
            shutdown_rx,
        }
    }

    /// Registers a new in-flight stream.
    ///
    /// Every call should be balanced by one call to [`stream_finished`](Self::stream_finished).
    pub fn stream_started(&self) {
        self.active_streams.fetch_add(1, Ordering::SeqCst);
    }

    /// Unregisters an in-flight stream.
    ///
    /// An unmatched call (counter already at zero) is ignored and logged. It does not wrap
    /// the counter to `u32::MAX`; a wrapped counter would make
    /// [`wait_for_drain`](Self::wait_for_drain) always run to its timeout.
    pub fn stream_finished(&self) {
        let decremented = self
            .active_streams
            .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |count| {
                count.checked_sub(1)
            });
        if decremented.is_err() {
            tracing::warn!("stream_finished called with no active streams; ignoring");
        }
    }

    /// Returns the current number of in-flight streams.
    #[must_use]
    pub fn active_count(&self) -> u32 {
        self.active_streams.load(Ordering::SeqCst)
    }

    /// Returns `true` once [`trigger_shutdown`](Self::trigger_shutdown) has been called.
    #[must_use]
    pub fn is_shutting_down(&self) -> bool {
        self.shutdown_triggered.load(Ordering::SeqCst)
    }

    /// Starts shutdown.
    ///
    /// Sets the shutdown flag, marks the health state as draining, and publishes `true` to
    /// every subscriber. Only the first call has any effect. Later calls return
    /// immediately, so subscribers are not re-notified and the log line is emitted once.
    pub fn trigger_shutdown(&self) {
        if self.shutdown_triggered.swap(true, Ordering::SeqCst) {
            return;
        }
        self.health_state.set_draining(true);
        // `send` only errors when every receiver is gone; `self.shutdown_rx` keeps one alive.
        let _ = self.shutdown_tx.send(true);
        tracing::info!("Shutdown triggered, starting drain");
    }

    /// Returns a receiver for the shutdown signal (`true` once shutdown is triggered).
    #[must_use]
    pub fn subscribe(&self) -> watch::Receiver<bool> {
        self.shutdown_rx.clone()
    }

    /// Returns the configuration this manager was built with.
    #[must_use]
    pub const fn config(&self) -> &ShutdownConfig {
        &self.config
    }

    /// Waits until the active-stream count reaches zero or `drain_timeout` elapses.
    ///
    /// Returns `true` if all streams drained, `false` if the timeout was hit first. The
    /// count is polled every 100 ms. The final sleep is clipped to the time remaining, so
    /// the call does not overshoot the timeout by a whole polling interval.
    ///
    /// Elapsed time is measured with `tokio::time::Instant`, the same clock
    /// `tokio::time::sleep` uses. With std's clock, a paused tokio clock would turn this
    /// loop into a busy spin.
    pub async fn wait_for_drain(&self) -> bool {
        const CHECK_INTERVAL: Duration = Duration::from_millis(100);
        let timeout = self.config.drain_timeout;
        let start = tokio::time::Instant::now();
        tracing::info!(
            active_streams = self.active_count(),
            timeout_secs = timeout.as_secs(),
            "Waiting for streams to drain"
        );
        loop {
            if self.active_count() == 0 {
                tracing::info!("All streams drained successfully");
                return true;
            }
            let elapsed = start.elapsed();
            if elapsed >= timeout {
                tracing::warn!(
                    remaining_streams = self.active_count(),
                    "Drain timeout reached"
                );
                return false;
            }
            // `elapsed < timeout` here, so the subtraction cannot underflow.
            tokio::time::sleep(CHECK_INTERVAL.min(timeout - elapsed)).await;
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_shutdown_config_default() {
        let config = ShutdownConfig::default();
        assert_eq!(config.drain_timeout, Duration::from_secs(30));
        assert_eq!(config.grace_period, Duration::from_secs(5));
    }

    #[test]
    fn test_shutdown_config_builder() {
        let config = ShutdownConfig::new()
            .with_drain_timeout(Duration::from_secs(60))
            .with_grace_period(Duration::from_secs(10));
        assert_eq!(config.drain_timeout, Duration::from_secs(60));
        assert_eq!(config.grace_period, Duration::from_secs(10));
    }

    #[test]
    fn test_shutdown_manager_new() {
        let health_state = Arc::new(HealthState::new());
        let config = ShutdownConfig::default();
        let manager = ShutdownManager::new(config, health_state);
        assert_eq!(manager.active_count(), 0);
        assert!(!manager.is_shutting_down());
    }

    #[test]
    fn test_stream_counting() {
        let health_state = Arc::new(HealthState::new());
        let config = ShutdownConfig::default();
        let manager = ShutdownManager::new(config, health_state);
        assert_eq!(manager.active_count(), 0);
        manager.stream_started();
        assert_eq!(manager.active_count(), 1);
        manager.stream_started();
        assert_eq!(manager.active_count(), 2);
        manager.stream_finished();
        assert_eq!(manager.active_count(), 1);
        manager.stream_finished();
        assert_eq!(manager.active_count(), 0);
    }

    #[test]
    fn test_trigger_shutdown() {
        let health_state = Arc::new(HealthState::new());
        let config = ShutdownConfig::default();
        let manager = ShutdownManager::new(config, Arc::clone(&health_state));
        assert!(!manager.is_shutting_down());
        assert!(!health_state.is_draining());
        manager.trigger_shutdown();
        assert!(manager.is_shutting_down());
        assert!(health_state.is_draining());
    }

    #[test]
    fn test_subscribe() {
        let health_state = Arc::new(HealthState::new());
        let config = ShutdownConfig::default();
        let manager = ShutdownManager::new(config, health_state);
        let rx = manager.subscribe();
        assert!(!*rx.borrow());
        manager.trigger_shutdown();
        assert!(rx.has_changed().expect("sender is alive"));
        assert!(*rx.borrow());
    }

    #[tokio::test]
    async fn test_drain_completes_when_empty() {
        let health_state = Arc::new(HealthState::new());
        let config = ShutdownConfig::new().with_drain_timeout(Duration::from_secs(1));
        let manager = ShutdownManager::new(config, health_state);
        let result = manager.wait_for_drain().await;
        assert!(result);
    }

    #[tokio::test]
    async fn test_drain_waits_for_streams() {
        let health_state = Arc::new(HealthState::new());
        let config = ShutdownConfig::new().with_drain_timeout(Duration::from_secs(2));
        let manager = Arc::new(ShutdownManager::new(config, health_state));
        manager.stream_started();
        let manager_clone = Arc::clone(&manager);
        tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(100)).await;
            manager_clone.stream_finished();
        });
        let result = manager.wait_for_drain().await;
        assert!(result);
        assert_eq!(manager.active_count(), 0);
    }

    #[tokio::test]
    async fn test_drain_timeout_enforced() {
        let health_state = Arc::new(HealthState::new());
        let config = ShutdownConfig::new().with_drain_timeout(Duration::from_millis(100));
        let manager = ShutdownManager::new(config, health_state);
        manager.stream_started();
        let result = manager.wait_for_drain().await;
        assert!(!result);
        assert_eq!(manager.active_count(), 1);
    }

    #[test]
    fn test_config_accessor() {
        let health_state = Arc::new(HealthState::new());
        let config = ShutdownConfig::new().with_drain_timeout(Duration::from_secs(45));
        let manager = ShutdownManager::new(config, health_state);
        assert_eq!(manager.config().drain_timeout, Duration::from_secs(45));
    }

    #[tokio::test]
    async fn test_concurrent_stream_registrations() {
        let health_state = Arc::new(HealthState::new());
        let config = ShutdownConfig::default();
        let manager = Arc::new(ShutdownManager::new(config, health_state));
        let mut handles = vec![];
        for _ in 0..100 {
            let manager_clone = Arc::clone(&manager);
            handles.push(tokio::spawn(async move {
                manager_clone.stream_started();
                tokio::time::sleep(Duration::from_millis(1)).await;
                manager_clone.stream_finished();
            }));
        }
        for handle in handles {
            handle.await.expect("task should complete");
        }
        assert_eq!(manager.active_count(), 0);
    }

    #[tokio::test]
    async fn test_concurrent_stream_registrations_during_drain() {
        let health_state = Arc::new(HealthState::new());
        let config = ShutdownConfig::new().with_drain_timeout(Duration::from_secs(1));
        let manager = Arc::new(ShutdownManager::new(config, health_state));
        for _ in 0..5 {
            manager.stream_started();
        }
        let manager_clone = Arc::clone(&manager);
        let drain_handle = tokio::spawn(async move { manager_clone.wait_for_drain().await });
        tokio::time::sleep(Duration::from_millis(50)).await;
        for _ in 0..5 {
            manager.stream_finished();
            tokio::time::sleep(Duration::from_millis(50)).await;
        }
        let result = drain_handle.await.expect("drain should complete");
        assert!(result);
    }

    #[tokio::test]
    async fn test_drain_with_partial_completion() {
        let health_state = Arc::new(HealthState::new());
        let config = ShutdownConfig::new().with_drain_timeout(Duration::from_millis(200));
        let manager = Arc::new(ShutdownManager::new(config, health_state));
        for _ in 0..10 {
            manager.stream_started();
        }
        let manager_clone = Arc::clone(&manager);
        tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(50)).await;
            for _ in 0..5 {
                manager_clone.stream_finished();
            }
        });
        let result = manager.wait_for_drain().await;
        assert!(!result);
        assert_eq!(manager.active_count(), 5);
    }

    #[tokio::test]
    async fn test_drain_with_slow_completion() {
        let health_state = Arc::new(HealthState::new());
        let config = ShutdownConfig::new().with_drain_timeout(Duration::from_secs(1));
        let manager = Arc::new(ShutdownManager::new(config, health_state));
        for _ in 0..3 {
            manager.stream_started();
        }
        let manager_clone = Arc::clone(&manager);
        tokio::spawn(async move {
            for _ in 0..3 {
                tokio::time::sleep(Duration::from_millis(100)).await;
                manager_clone.stream_finished();
            }
        });
        let result = manager.wait_for_drain().await;
        assert!(result);
        assert_eq!(manager.active_count(), 0);
    }

    #[test]
    fn test_multiple_shutdown_subscribers() {
        let health_state = Arc::new(HealthState::new());
        let config = ShutdownConfig::default();
        let manager = Arc::new(ShutdownManager::new(config, health_state));
        let mut receivers = vec![];
        for _ in 0..5 {
            receivers.push(manager.subscribe());
        }
        for rx in &receivers {
            assert!(!*rx.borrow());
        }
        manager.trigger_shutdown();
        for rx in receivers {
            assert!(rx.has_changed().expect("sender is alive"));
            assert!(*rx.borrow());
        }
    }

    #[test]
    fn test_subscribe_after_shutdown_triggered() {
        let health_state = Arc::new(HealthState::new());
        let config = ShutdownConfig::default();
        let manager = ShutdownManager::new(config, health_state);
        manager.trigger_shutdown();
        let rx = manager.subscribe();
        assert!(*rx.borrow());
    }

    #[test]
    fn test_grace_period_in_config() {
        let config = ShutdownConfig::new()
            .with_drain_timeout(Duration::from_secs(30))
            .with_grace_period(Duration::from_secs(10));
        assert_eq!(config.grace_period, Duration::from_secs(10));
    }

    #[test]
    fn test_zero_grace_period() {
        let config = ShutdownConfig::new().with_grace_period(Duration::from_secs(0));
        assert_eq!(config.grace_period, Duration::from_secs(0));
    }

    #[test]
    fn test_stream_underflow_protection() {
        let health_state = Arc::new(HealthState::new());
        let config = ShutdownConfig::default();
        let manager = ShutdownManager::new(config, health_state);
        // An unmatched finish must not wrap the counter.
        manager.stream_finished();
        assert_eq!(manager.active_count(), 0);
        // Counting continues to work normally afterwards.
        manager.stream_started();
        assert_eq!(manager.active_count(), 1);
        manager.stream_finished();
        assert_eq!(manager.active_count(), 0);
    }

    #[tokio::test]
    async fn test_unbalanced_finish_does_not_block_drain() {
        let health_state = Arc::new(HealthState::new());
        let config = ShutdownConfig::new().with_drain_timeout(Duration::from_secs(10));
        let manager = ShutdownManager::new(config, health_state);
        manager.stream_finished();
        let start = std::time::Instant::now();
        assert!(manager.wait_for_drain().await);
        assert!(start.elapsed() < Duration::from_millis(500));
    }

    #[test]
    fn test_stream_count_large_values() {
        let health_state = Arc::new(HealthState::new());
        let config = ShutdownConfig::default();
        let manager = ShutdownManager::new(config, health_state);
        for _ in 0..1000 {
            manager.stream_started();
        }
        assert_eq!(manager.active_count(), 1000);
        for _ in 0..1000 {
            manager.stream_finished();
        }
        assert_eq!(manager.active_count(), 0);
    }

    #[test]
    fn test_shutdown_state_idempotent() {
        let health_state = Arc::new(HealthState::new());
        let config = ShutdownConfig::default();
        let manager = ShutdownManager::new(config, health_state);
        let mut rx = manager.subscribe();
        manager.trigger_shutdown();
        assert!(rx.has_changed().expect("sender is alive"));
        assert!(*rx.borrow_and_update());
        // Repeated triggers must not re-notify subscribers.
        manager.trigger_shutdown();
        manager.trigger_shutdown();
        assert!(manager.is_shutting_down());
        assert!(!rx.has_changed().expect("sender is alive"));
        assert!(*rx.borrow());
    }

    #[tokio::test]
    async fn test_drain_immediate_return_when_no_streams() {
        let health_state = Arc::new(HealthState::new());
        let config = ShutdownConfig::new().with_drain_timeout(Duration::from_secs(10));
        let manager = ShutdownManager::new(config, health_state);
        let start = std::time::Instant::now();
        let result = manager.wait_for_drain().await;
        let elapsed = start.elapsed();
        assert!(result);
        assert!(elapsed < Duration::from_millis(500));
    }

    #[tokio::test]
    async fn test_drain_with_zero_timeout() {
        let health_state = Arc::new(HealthState::new());
        let config = ShutdownConfig::new().with_drain_timeout(Duration::from_millis(0));
        let manager = ShutdownManager::new(config, health_state);
        manager.stream_started();
        let result = manager.wait_for_drain().await;
        assert!(!result);
        assert_eq!(manager.active_count(), 1);
    }
}