use arc_swap::ArcSwap;
use parking_lot::Mutex;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};
use tracing::{debug, info, warn};
use crate::error::{Error, Result};
/// Why a shutdown was initiated.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ShutdownReason {
    /// Triggered by a signal; carries the signal number (e.g. `15`).
    Signal(i32),
    /// Explicitly requested.
    Requested,
    /// Triggered by an error.
    Error,
    /// Triggered by resource exhaustion.
    ResourceExhausted,
    /// Shutdown is being forced (e.g. after the graceful phase timed out).
    Forced,
}

impl std::fmt::Display for ShutdownReason {
    /// Renders the variant name; signals also include their number, e.g. `Signal(15)`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Signal(signo) => write!(f, "Signal({signo})"),
            Self::Requested => f.write_str("Requested"),
            Self::Error => f.write_str("Error"),
            Self::ResourceExhausted => f.write_str("ResourceExhausted"),
            Self::Forced => f.write_str("Forced"),
        }
    }
}
/// Per-subsystem view of a shared shutdown, obtained from
/// `ShutdownCoordinator::create_handle`.
///
/// Clones share both the coordinator state and the subsystem ID.
#[derive(Debug, Clone)]
pub struct ShutdownHandle {
    inner: Arc<ShutdownInner>,
    subsystem_id: u64,
}

impl ShutdownHandle {
    /// Pairs the shared state with the subsystem ID this handle reports for.
    const fn new(inner: Arc<ShutdownInner>, subsystem_id: u64) -> Self {
        Self { inner, subsystem_id }
    }

    /// Returns `true` once shutdown has been initiated.
    #[must_use]
    pub fn is_shutdown(&self) -> bool {
        self.inner.is_shutdown()
    }

    /// Resolves once shutdown has been initiated (immediately if it already was).
    ///
    /// With the `tokio` feature this waits on the coordinator's broadcast
    /// channel; with `async-std` (and not `tokio`) it polls the flag every
    /// 10ms. With neither feature enabled it returns without waiting.
    pub async fn cancelled(&mut self) {
        if self.inner.shutdown_initiated.load(Ordering::Relaxed) {
            return;
        }
        #[cfg(feature = "tokio")]
        {
            // Subscribe before re-checking the flag so a broadcast sent
            // between the two steps cannot be missed.
            let mut shutdown_rx = self.inner.shutdown_tx.subscribe();
            if !self.is_shutdown() {
                let _ = shutdown_rx.recv().await;
            }
        }
        #[cfg(all(feature = "async-std", not(feature = "tokio")))]
        {
            while !self.inner.shutdown_initiated.load(Ordering::Acquire) {
                async_std::task::sleep(Duration::from_millis(10)).await;
            }
        }
    }

    /// Returns the recorded shutdown reason, or `None` before shutdown.
    #[must_use]
    pub fn shutdown_reason(&self) -> Option<ShutdownReason> {
        self.is_shutdown()
            .then(|| **self.inner.shutdown_reason.load())
    }

    /// Returns when shutdown was initiated, if it has been.
    #[must_use]
    pub fn shutdown_time(&self) -> Option<Instant> {
        *self.inner.shutdown_time.lock()
    }

    /// Returns `true` if the recorded shutdown reason is `ShutdownReason::Forced`.
    #[must_use]
    pub fn is_forced(&self) -> bool {
        self.shutdown_reason() == Some(ShutdownReason::Forced)
    }

    /// Marks this handle's subsystem as ready for shutdown.
    pub fn ready(&self) {
        self.inner.mark_subsystem_ready(self.subsystem_id);
    }

    /// Returns how much of the graceful timeout is left, measured from the
    /// shutdown time.
    ///
    /// `None` before shutdown, and once the timeout has been reached or passed.
    #[must_use]
    pub fn time_remaining(&self) -> Option<Duration> {
        let elapsed = self.shutdown_time()?.elapsed();
        let graceful =
            Duration::from_millis(self.inner.graceful_timeout_ms.load(Ordering::Acquire));
        graceful
            .checked_sub(elapsed)
            .filter(|left| !left.is_zero())
    }
}
/// State shared by a `ShutdownCoordinator` and every `ShutdownHandle` it creates.
#[derive(Debug)]
struct ShutdownInner {
    /// Published with `Release` only after `shutdown_reason` and
    /// `shutdown_time` are written; readers load it with `Acquire`.
    shutdown_initiated: AtomicBool,
    /// Reason for the shutdown; holds a `Requested` placeholder until one starts.
    shutdown_reason: ArcSwap<ShutdownReason>,
    /// When shutdown was initiated. Its lock also serializes initiators so
    /// exactly one `initiate_shutdown` call succeeds.
    shutdown_time: Mutex<Option<Instant>>,
    /// Phase timeouts in milliseconds; `update_timeouts` may change them at runtime.
    graceful_timeout_ms: AtomicU64,
    force_timeout_ms: AtomicU64,
    kill_timeout_ms: AtomicU64,
    /// Every registered subsystem and whether it has reported ready.
    subsystems: Mutex<Vec<SubsystemState>>,
    /// Carries the reason to `ShutdownHandle::cancelled` subscribers.
    #[cfg(feature = "tokio")]
    shutdown_tx: tokio::sync::broadcast::Sender<ShutdownReason>,
}

/// Bookkeeping for one registered subsystem.
#[derive(Debug)]
struct SubsystemState {
    id: u64,
    name: String,
    /// Set once the subsystem's handle calls `ready`.
    ready: AtomicBool,
    #[allow(dead_code)] // only surfaced through the derived `Debug` output
    registered_at: Instant,
}

impl ShutdownInner {
    /// Creates the shared state with the given phase timeouts (milliseconds).
    fn new(graceful_timeout_ms: u64, force_timeout_ms: u64, kill_timeout_ms: u64) -> Self {
        // The initial receiver is discarded; handles subscribe on demand, and at
        // most one message is ever sent (by the winning initiator).
        #[cfg(feature = "tokio")]
        let (shutdown_tx, _) = tokio::sync::broadcast::channel(16);
        Self {
            shutdown_initiated: AtomicBool::new(false),
            shutdown_reason: ArcSwap::new(Arc::new(ShutdownReason::Requested)),
            shutdown_time: Mutex::new(None),
            graceful_timeout_ms: AtomicU64::new(graceful_timeout_ms),
            force_timeout_ms: AtomicU64::new(force_timeout_ms),
            kill_timeout_ms: AtomicU64::new(kill_timeout_ms),
            subsystems: Mutex::new(Vec::new()),
            #[cfg(feature = "tokio")]
            shutdown_tx,
        }
    }

    /// Returns `true` once shutdown has been initiated.
    ///
    /// `Acquire` pairs with the `Release` store in `initiate_shutdown`, so a
    /// caller that observes `true` also observes the stored reason and time.
    #[must_use]
    pub fn is_shutdown(&self) -> bool {
        self.shutdown_initiated.load(Ordering::Acquire)
    }

    /// Starts shutdown with `reason` unless it has already started.
    ///
    /// Returns `true` for the first caller. Later calls leave the original
    /// reason and time untouched, log at debug level, and return `false`.
    #[must_use]
    pub fn initiate_shutdown(&self, reason: ShutdownReason) -> bool {
        {
            let mut shutdown_time = self.shutdown_time.lock();
            if shutdown_time.is_some() {
                debug!("Shutdown already initiated, ignoring additional request");
                return false;
            }
            // Publish the flag last: a reader that sees `is_shutdown() == true`
            // must never see the placeholder reason or a missing time.
            self.shutdown_reason.store(Arc::new(reason));
            *shutdown_time = Some(Instant::now());
            self.shutdown_initiated.store(true, Ordering::Release);
        }
        #[cfg(feature = "tokio")]
        {
            // An error only means nobody is subscribed right now; late
            // subscribers check the flag before waiting.
            let _ = self.shutdown_tx.send(reason);
        }
        info!("Shutdown initiated: {}", reason);
        true
    }

    /// Registers a subsystem under `name` and returns its randomly drawn ID.
    fn register_subsystem(&self, name: &str) -> u64 {
        let id = fastrand::u64(..);
        let state = SubsystemState {
            id,
            name: name.to_string(),
            ready: AtomicBool::new(false),
            registered_at: Instant::now(),
        };
        self.subsystems.lock().push(state);
        debug!("Registered subsystem '{}' with ID {}", name, id);
        id
    }

    /// Marks the subsystem with `subsystem_id` as ready; unknown IDs are ignored.
    fn mark_subsystem_ready(&self, subsystem_id: u64) {
        let subsystems = self.subsystems.lock();
        if let Some(subsystem) = subsystems.iter().find(|s| s.id == subsystem_id) {
            // `Relaxed` suffices: every access to `ready` happens under the
            // `subsystems` lock.
            subsystem.ready.store(true, Ordering::Relaxed);
            debug!(
                "Subsystem '{}' marked as ready for shutdown",
                subsystem.name
            );
        }
    }

    /// Returns `true` if every registered subsystem is ready (vacuously `true` if none).
    fn are_all_subsystems_ready(&self) -> bool {
        let subsystems = self.subsystems.lock();
        subsystems.iter().all(|s| s.ready.load(Ordering::Relaxed))
    }

    /// Snapshot of `(name, ready)` for every registered subsystem, in registration order.
    fn get_subsystem_states(&self) -> Vec<(String, bool)> {
        let subsystems = self.subsystems.lock();
        subsystems
            .iter()
            .map(|s| (s.name.clone(), s.ready.load(Ordering::Relaxed)))
            .collect()
    }
}
/// Coordinates a phased shutdown (graceful, then force, then kill) across the
/// subsystems registered through `create_handle`.
///
/// Clones share the same underlying state.
#[derive(Debug, Clone)]
pub struct ShutdownCoordinator {
    inner: Arc<ShutdownInner>,
}

impl ShutdownCoordinator {
    /// Creates a coordinator; all timeouts are in milliseconds.
    #[must_use]
    pub fn new(graceful_timeout_ms: u64, force_timeout_ms: u64, kill_timeout_ms: u64) -> Self {
        Self {
            inner: Arc::new(ShutdownInner::new(
                graceful_timeout_ms,
                force_timeout_ms,
                kill_timeout_ms,
            )),
        }
    }

    /// Registers a subsystem under `subsystem_name` and returns the handle it
    /// uses to observe shutdown and report readiness.
    pub fn create_handle<S: Into<String>>(&self, subsystem_name: S) -> ShutdownHandle {
        let name = subsystem_name.into();
        let subsystem_id = self.inner.register_subsystem(&name);
        ShutdownHandle::new(Arc::clone(&self.inner), subsystem_id)
    }

    /// Initiates shutdown with `reason`.
    ///
    /// Returns `true` if this call started the shutdown and `false` if one was
    /// already in progress (the first reason is kept).
    #[must_use]
    pub fn initiate_shutdown(&self, reason: ShutdownReason) -> bool {
        self.inner.initiate_shutdown(reason)
    }

    /// Returns `true` once shutdown has been initiated.
    #[must_use]
    pub fn is_shutdown(&self) -> bool {
        self.inner.is_shutdown()
    }

    /// Returns the recorded shutdown reason, or `None` before shutdown.
    #[must_use]
    pub fn get_reason(&self) -> Option<ShutdownReason> {
        if self.is_shutdown() {
            Some(**self.inner.shutdown_reason.load())
        } else {
            None
        }
    }

    /// Waits up to the graceful timeout for every subsystem to report ready.
    ///
    /// The timeout is measured from when this method is called. If it expires,
    /// the recorded reason is changed to `ShutdownReason::Forced`.
    ///
    /// # Errors
    ///
    /// - An invalid-state error if shutdown has not been initiated or its time
    ///   is not yet recorded.
    /// - A timeout error if subsystems are still not ready when the graceful
    ///   timeout expires.
    pub async fn wait_for_shutdown(&self) -> Result<()> {
        if !self.is_shutdown() {
            return Err(Error::invalid_state("Shutdown not initiated"));
        }
        // The shutdown time is written after the reason, so once it is visible
        // the initiator's reason is already stored and cannot later overwrite
        // the `Forced` escalation below.
        let shutdown_time = *self.inner.shutdown_time.lock();
        if shutdown_time.is_none() {
            return Err(Error::invalid_state("Shutdown time not set"));
        }
        let graceful_timeout =
            Duration::from_millis(self.inner.graceful_timeout_ms.load(Ordering::Acquire));
        info!(
            "Waiting for subsystems to shutdown gracefully (timeout: {:?})",
            graceful_timeout
        );
        let start = Instant::now();
        if self.inner.are_all_subsystems_ready() {
            info!("All subsystems already shut down gracefully");
            return Ok(());
        }
        if self
            .poll_until_all_ready(start, graceful_timeout, Duration::from_millis(1))
            .await
        {
            info!(
                "All subsystems shut down gracefully in {:?}",
                start.elapsed()
            );
            return Ok(());
        }
        let not_ready: Vec<String> = self
            .inner
            .get_subsystem_states()
            .into_iter()
            .filter(|(_, ready)| !*ready)
            .map(|(name, _)| name)
            .collect();
        warn!(
            "Graceful shutdown timeout exceeded. Subsystems not ready: {:?}",
            not_ready
        );
        // `initiate_shutdown` is a no-op once shutdown has started, so calling it
        // with `Forced` would never change the reason. Record the escalation
        // directly so `get_reason` and `ShutdownHandle::is_forced` reflect it.
        self.inner
            .shutdown_reason
            .store(Arc::new(ShutdownReason::Forced));
        let timeout_ms = u64::try_from(graceful_timeout.as_millis()).unwrap_or(u64::MAX);
        Err(Error::timeout("Graceful shutdown", timeout_ms))
    }

    /// Waits up to the force timeout for every subsystem to report ready,
    /// polling every 50ms.
    ///
    /// # Errors
    ///
    /// Returns a timeout error if subsystems are still not ready when the force
    /// timeout expires.
    pub async fn wait_for_force_shutdown(&self) -> Result<()> {
        let force_timeout =
            Duration::from_millis(self.inner.force_timeout_ms.load(Ordering::Acquire));
        warn!("Waiting for forced shutdown timeout: {:?}", force_timeout);
        let start = Instant::now();
        if self
            .poll_until_all_ready(start, force_timeout, Duration::from_millis(50))
            .await
        {
            info!("All subsystems shut down during force phase");
            return Ok(());
        }
        let timeout_ms = u64::try_from(force_timeout.as_millis()).unwrap_or(u64::MAX);
        Err(Error::timeout("Force shutdown", timeout_ms))
    }

    /// Sleeps for the kill timeout (with the `tokio` or `async-std` feature)
    /// and then reports it as expired.
    ///
    /// # Errors
    ///
    /// Always returns a timeout error once the sleep completes.
    pub async fn wait_for_kill_shutdown(&self) -> Result<()> {
        let kill_timeout =
            Duration::from_millis(self.inner.kill_timeout_ms.load(Ordering::Acquire));
        warn!("Waiting for kill shutdown timeout: {:?}", kill_timeout);
        #[cfg(feature = "tokio")]
        tokio::time::sleep(kill_timeout).await;
        #[cfg(all(feature = "async-std", not(feature = "tokio")))]
        async_std::task::sleep(kill_timeout).await;
        let timeout_ms = u64::try_from(kill_timeout.as_millis()).unwrap_or(u64::MAX);
        Err(Error::timeout("Kill shutdown", timeout_ms))
    }

    /// Polls until all subsystems are ready or `timeout` has elapsed since
    /// `start`; returns `true` if they all became ready.
    ///
    /// The interval starts at `initial_interval` and doubles up to 50ms. Each
    /// sleep is clamped to the time left, and readiness is re-checked after
    /// the last one, so a subsystem that becomes ready just before the
    /// deadline is not reported as timed out.
    async fn poll_until_all_ready(
        &self,
        start: Instant,
        timeout: Duration,
        initial_interval: Duration,
    ) -> bool {
        let max_poll_interval = Duration::from_millis(50);
        let mut poll_interval = initial_interval;
        loop {
            if self.inner.are_all_subsystems_ready() {
                return true;
            }
            let remaining = match timeout.checked_sub(start.elapsed()) {
                Some(left) if !left.is_zero() => left,
                _ => return false,
            };
            #[cfg(feature = "tokio")]
            tokio::time::sleep(poll_interval.min(remaining)).await;
            #[cfg(all(feature = "async-std", not(feature = "tokio")))]
            async_std::task::sleep(poll_interval.min(remaining)).await;
            poll_interval = (poll_interval * 2).min(max_poll_interval);
        }
    }

    /// Returns a snapshot of shutdown progress.
    #[must_use]
    pub fn get_stats(&self) -> ShutdownStats {
        let subsystem_states = self.inner.get_subsystem_states();
        let total_subsystems = subsystem_states.len();
        let ready_subsystems = subsystem_states.iter().filter(|(_, ready)| *ready).count();
        // Read the flag once so `is_shutdown` and `reason` always agree.
        let reason = self.get_reason();
        ShutdownStats {
            is_shutdown: reason.is_some(),
            reason,
            shutdown_time: *self.inner.shutdown_time.lock(),
            total_subsystems,
            ready_subsystems,
            subsystem_states,
        }
    }

    /// Replaces all three phase timeouts (milliseconds).
    ///
    /// Each value is stored separately, so a concurrent reader may briefly
    /// observe a mix of old and new values.
    pub fn update_timeouts(
        &self,
        graceful_timeout_ms: u64,
        force_timeout_ms: u64,
        kill_timeout_ms: u64,
    ) {
        self.inner
            .graceful_timeout_ms
            .store(graceful_timeout_ms, Ordering::Release);
        self.inner
            .force_timeout_ms
            .store(force_timeout_ms, Ordering::Release);
        self.inner
            .kill_timeout_ms
            .store(kill_timeout_ms, Ordering::Release);
        debug!(
            "Updated shutdown timeouts: graceful={}ms, force={}ms, kill={}ms",
            graceful_timeout_ms, force_timeout_ms, kill_timeout_ms
        );
    }
}
/// Point-in-time snapshot of shutdown progress, produced by
/// `ShutdownCoordinator::get_stats`.
#[derive(Debug, Clone)]
pub struct ShutdownStats {
    /// Whether shutdown had been initiated when the snapshot was taken.
    pub is_shutdown: bool,
    /// Recorded shutdown reason; `None` before shutdown.
    pub reason: Option<ShutdownReason>,
    /// When shutdown was initiated, if it has been.
    pub shutdown_time: Option<Instant>,
    /// Number of registered subsystems.
    pub total_subsystems: usize,
    /// Number of subsystems that have reported ready.
    pub ready_subsystems: usize,
    /// `(name, ready)` for every registered subsystem.
    pub subsystem_states: Vec<(String, bool)>,
}

impl ShutdownStats {
    /// Ratio of ready subsystems to registered subsystems; `1.0` when none
    /// are registered.
    #[must_use]
    #[allow(clippy::cast_precision_loss)] // subsystem counts are far below f64's exact-integer range
    pub fn progress(&self) -> f64 {
        if self.total_subsystems == 0 {
            1.0
        } else {
            self.ready_subsystems as f64 / self.total_subsystems as f64
        }
    }

    /// Returns `true` when every registered subsystem is ready.
    ///
    /// With no registered subsystems this is vacuously `true`. That matches
    /// `progress()` (which reports `1.0`) and the coordinator's own readiness
    /// check, so a caller polling for completion does not wait forever.
    #[must_use]
    pub const fn is_complete(&self) -> bool {
        self.ready_subsystems == self.total_subsystems
    }

    /// Time since shutdown was initiated, measured now; `None` before shutdown.
    #[must_use]
    pub fn elapsed(&self) -> Option<Duration> {
        self.shutdown_time.map(|t| t.elapsed())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    #[cfg(feature = "tokio")]
    #[cfg_attr(miri, ignore)]
    #[tokio::test]
    async fn test_shutdown_coordination() {
        let test_result = tokio::time::timeout(Duration::from_secs(5), async {
            let coordinator = ShutdownCoordinator::new(100, 200, 300);
            let handle1 = coordinator.create_handle("subsystem1");
            let handle2 = coordinator.create_handle("subsystem2");
            assert!(!coordinator.is_shutdown());
            assert!(!handle1.is_shutdown());
            assert!(coordinator.initiate_shutdown(ShutdownReason::Requested));
            assert!(coordinator.is_shutdown());
            assert!(handle1.is_shutdown());
            assert!(handle2.is_shutdown());
            handle1.ready();
            handle2.ready();
            let stats = coordinator.get_stats();
            assert!(stats.is_complete());
            let epsilon: f64 = 1e-6;
            assert!((stats.progress() - 1.0).abs() < epsilon);
        })
        .await;
        assert!(test_result.is_ok(), "Test timed out after 5 seconds");
    }

    // Same scenario on async-std; kept inside the test module so it is
    // grouped with its tokio counterpart.
    #[cfg(all(feature = "async-std", not(feature = "tokio")))]
    #[async_std::test]
    async fn test_shutdown_coordination() {
        let test_result = async_std::future::timeout(Duration::from_secs(5), async {
            let coordinator = ShutdownCoordinator::new(100, 200, 300);
            let handle1 = coordinator.create_handle("subsystem1");
            let handle2 = coordinator.create_handle("subsystem2");
            assert!(!coordinator.is_shutdown());
            assert!(!handle1.is_shutdown());
            assert!(coordinator.initiate_shutdown(ShutdownReason::Requested));
            assert!(coordinator.is_shutdown());
            assert!(handle1.is_shutdown());
            assert!(handle2.is_shutdown());
            handle1.ready();
            handle2.ready();
            let stats = coordinator.get_stats();
            assert!(stats.is_complete());
            let epsilon: f64 = 1e-6;
            assert!((stats.progress() - 1.0).abs() < epsilon);
        })
        .await;
        assert!(test_result.is_ok(), "Test timed out after 5 seconds");
    }

    #[cfg(feature = "tokio")]
    #[cfg_attr(miri, ignore)]
    #[tokio::test]
    async fn test_shutdown_timeout() {
        let test_result = tokio::time::timeout(Duration::from_secs(5), async {
            let coordinator = ShutdownCoordinator::new(100, 200, 300);
            let _handle1 = coordinator.create_handle("slow_subsystem");
            let _ = coordinator.initiate_shutdown(ShutdownReason::Requested);
            let result = coordinator.wait_for_shutdown().await;
            assert!(result.is_err());
            assert!(result.unwrap_err().is_timeout());
        })
        .await;
        assert!(test_result.is_ok(), "Test timed out after 5 seconds");
    }

    #[cfg(all(feature = "async-std", not(feature = "tokio")))]
    #[async_std::test]
    async fn test_shutdown_timeout() {
        let test_result = async_std::future::timeout(Duration::from_secs(5), async {
            let coordinator = ShutdownCoordinator::new(100, 200, 300);
            let _handle1 = coordinator.create_handle("slow_subsystem");
            let _ = coordinator.initiate_shutdown(ShutdownReason::Requested);
            let result = coordinator.wait_for_shutdown().await;
            assert!(result.is_err());
            assert!(result.unwrap_err().is_timeout());
        })
        .await;
        assert!(test_result.is_ok(), "Test timed out after 5 seconds");
    }

    #[test]
    fn test_shutdown_reason_display() {
        assert_eq!(format!("{}", ShutdownReason::Signal(15)), "Signal(15)");
        assert_eq!(format!("{}", ShutdownReason::Requested), "Requested");
        assert_eq!(format!("{}", ShutdownReason::Error), "Error");
    }

    #[cfg(feature = "tokio")]
    #[cfg_attr(miri, ignore)]
    #[tokio::test]
    async fn test_multiple_shutdown_initiation() {
        let test_result = tokio::time::timeout(Duration::from_secs(5), async {
            let coordinator = ShutdownCoordinator::new(5000, 10000, 15000);
            assert!(coordinator.initiate_shutdown(ShutdownReason::Requested));
            assert!(!coordinator.initiate_shutdown(ShutdownReason::Signal(15)));
            assert!(!coordinator.initiate_shutdown(ShutdownReason::Error));
            assert_eq!(coordinator.get_reason(), Some(ShutdownReason::Requested));
        })
        .await;
        assert!(test_result.is_ok(), "Test timed out after 5 seconds");
    }

    #[cfg(all(feature = "async-std", not(feature = "tokio")))]
    #[async_std::test]
    async fn test_multiple_shutdown_initiation() {
        let test_result = async_std::future::timeout(Duration::from_secs(5), async {
            let coordinator = ShutdownCoordinator::new(5000, 10000, 15000);
            assert!(coordinator.initiate_shutdown(ShutdownReason::Requested));
            assert!(!coordinator.initiate_shutdown(ShutdownReason::Signal(15)));
            assert!(!coordinator.initiate_shutdown(ShutdownReason::Error));
            let stats = coordinator.get_stats();
            assert_eq!(stats.reason, Some(ShutdownReason::Requested));
        })
        .await;
        assert!(test_result.is_ok(), "Test timed out after 5 seconds");
    }

    #[test]
    fn test_shutdown_stats() {
        let coordinator = ShutdownCoordinator::new(5000, 10000, 15000);
        let handle1 = coordinator.create_handle("test1");
        let handle2 = coordinator.create_handle("test2");
        let stats = coordinator.get_stats();
        assert_eq!(stats.total_subsystems, 2);
        assert_eq!(stats.ready_subsystems, 0);
        assert!(!stats.is_complete());
        let epsilon: f64 = 1e-6;
        assert!((stats.progress() - 0.0).abs() < epsilon);
        handle1.ready();
        let stats = coordinator.get_stats();
        assert_eq!(stats.ready_subsystems, 1);
        assert!((stats.progress() - 0.5).abs() < epsilon);
        handle2.ready();
        let stats = coordinator.get_stats();
        assert!(stats.is_complete());
        assert!((stats.progress() - 1.0).abs() < epsilon);
    }
}