use std::sync::Arc;
use std::time::Duration;
use tokio::sync::{broadcast, RwLock};
use tokio::time::timeout;
/// Reason attached to a shutdown request.
///
/// The value passed to `ShutdownCoordinator::shutdown` is broadcast unchanged
/// to every subscribed receiver. Nothing in this module maps the variants to
/// operating-system signals; the caller decides which one to send.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ShutdownSignal {
    /// Termination request (presumably SIGTERM-style; confirm against the caller).
    Terminate,
    /// Interrupt request (presumably Ctrl-C / SIGINT-style; confirm against the caller).
    Interrupt,
    /// Shutdown requested programmatically.
    Manual,
}
/// Lifecycle phase of the shutdown coordinator.
///
/// The coordinator starts in `Running`, moves to `Draining` when a shutdown is
/// requested, and ends in `Stopped` once the drain finishes or times out.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ShutdownState {
    /// Normal operation; new operations may be registered.
    Running,
    /// Shutdown requested; new registrations are refused while pending
    /// operations are given time to finish.
    Draining,
    /// Shutdown finished, either cleanly or because the drain timed out.
    Stopped,
}
/// Coordinates a graceful shutdown: broadcasts the shutdown signal, refuses
/// new operations, and waits (up to a timeout) for in-flight ones to finish.
#[derive(Debug)]
pub struct ShutdownCoordinator {
    /// Sending half of the broadcast channel on which the shutdown signal is published.
    shutdown_tx: broadcast::Sender<ShutdownSignal>,
    /// Current lifecycle phase.
    state: Arc<RwLock<ShutdownState>>,
    /// Number of operations that were registered and have not yet been unregistered.
    pending_operations: Arc<RwLock<usize>>,
    /// Upper bound on how long a shutdown waits for pending operations to drain.
    timeout: Duration,
}
impl ShutdownCoordinator {
    /// How often the drain loop re-reads the pending-operation counter.
    const DRAIN_POLL_INTERVAL: Duration = Duration::from_millis(100);

    /// Creates a coordinator in the [`ShutdownState::Running`] state with no
    /// pending operations.
    ///
    /// `timeout` is the longest [`shutdown`](Self::shutdown) will wait for
    /// pending operations to finish.
    pub fn new(timeout: Duration) -> Self {
        // The initial receiver is discarded on purpose; listeners attach
        // later through `subscribe`.
        let (shutdown_tx, _) = broadcast::channel(100);
        Self {
            shutdown_tx,
            state: Arc::new(RwLock::new(ShutdownState::Running)),
            pending_operations: Arc::new(RwLock::new(0)),
            timeout,
        }
    }

    /// Returns a new receiver for the shutdown signal.
    pub fn subscribe(&self) -> broadcast::Receiver<ShutdownSignal> {
        self.shutdown_tx.subscribe()
    }

    /// Starts the shutdown sequence and waits for pending operations to drain.
    ///
    /// The state moves to [`ShutdownState::Draining`], `signal` is broadcast
    /// to all subscribers, and the call then waits up to the configured
    /// timeout for the pending-operation count to reach zero. The state is
    /// [`ShutdownState::Stopped`] when this returns from a started shutdown.
    ///
    /// # Errors
    ///
    /// * `"Shutdown already in progress"` if the coordinator is not in the
    ///   `Running` state when called (nothing is changed in that case).
    /// * `"Shutdown timeout exceeded, N operations incomplete"` if operations
    ///   are still pending once the timeout has elapsed.
    pub async fn shutdown(&self, signal: ShutdownSignal) -> Result<(), String> {
        {
            // Check and transition under a single write lock so only one
            // caller can ever start the shutdown.
            let mut state = self.state.write().await;
            if *state != ShutdownState::Running {
                return Err("Shutdown already in progress".to_string());
            }
            *state = ShutdownState::Draining;
        }

        // `send` only fails when there are no receivers, which is not an
        // error for the shutdown itself.
        let _ = self.shutdown_tx.send(signal);

        let drained = timeout(self.timeout, self.drain_pending_operations())
            .await
            .is_ok();

        // The drain loop only polls periodically, so the last operation may
        // have finished between its final poll and the deadline. Re-read the
        // counter instead of trusting the timeout alone, otherwise a fully
        // drained coordinator could be reported as timed out.
        let remaining = if drained {
            0
        } else {
            *self.pending_operations.read().await
        };

        *self.state.write().await = ShutdownState::Stopped;

        if remaining == 0 {
            Ok(())
        } else {
            Err(format!(
                "Shutdown timeout exceeded, {} operations incomplete",
                remaining
            ))
        }
    }

    /// Polls the pending-operation counter until it reaches zero.
    ///
    /// Never returns on its own while operations remain; the caller bounds it
    /// with a timeout.
    async fn drain_pending_operations(&self) {
        loop {
            // Copy the value out so the read lock is released before sleeping.
            let pending = *self.pending_operations.read().await;
            if pending == 0 {
                return;
            }
            tokio::time::sleep(Self::DRAIN_POLL_INTERVAL).await;
        }
    }

    /// Registers a new in-flight operation.
    ///
    /// Returns `true` and increments the pending count when the coordinator
    /// is `Running`; returns `false` without counting otherwise.
    pub async fn register_operation(&self) -> bool {
        // The state read guard is deliberately held until the counter has
        // been incremented: `shutdown` needs the write lock to leave
        // `Running`, so it cannot slip in between the check and the increment.
        let state = self.state.read().await;
        if *state != ShutdownState::Running {
            return false;
        }
        *self.pending_operations.write().await += 1;
        true
    }

    /// Marks one operation as finished.
    ///
    /// The count never goes below zero; surplus calls are ignored.
    pub async fn unregister_operation(&self) {
        let mut pending = self.pending_operations.write().await;
        *pending = pending.saturating_sub(1);
    }

    /// Returns the current lifecycle state.
    pub async fn state(&self) -> ShutdownState {
        *self.state.read().await
    }

    /// Returns the number of operations currently registered.
    pub async fn pending_operations(&self) -> usize {
        *self.pending_operations.read().await
    }

    /// Returns `true` once shutdown has been requested (state is `Draining`
    /// or `Stopped`).
    pub async fn is_shutting_down(&self) -> bool {
        self.state().await != ShutdownState::Running
    }
}
/// Scope guard for one operation registered with a [`ShutdownCoordinator`].
///
/// The operation counts as pending for as long as the guard is alive;
/// dropping the guard hands the registration back to the coordinator.
#[must_use = "the operation is unregistered as soon as the guard is dropped"]
pub struct OperationGuard {
    /// Coordinator on which the operation was registered.
    coordinator: Arc<ShutdownCoordinator>,
}
impl OperationGuard {
    /// Registers an operation on `coordinator` and wraps it in a guard.
    ///
    /// Returns `None` when the coordinator refuses the registration, which it
    /// does once it is no longer in the `Running` state.
    pub async fn new(coordinator: Arc<ShutdownCoordinator>) -> Option<Self> {
        let accepted = coordinator.register_operation().await;
        if !accepted {
            return None;
        }
        Some(Self { coordinator })
    }
}
impl Drop for OperationGuard {
fn drop(&mut self) {
let coordinator = self.coordinator.clone();
tokio::spawn(async move {
coordinator.unregister_operation().await;
});
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh coordinator is running and has nothing pending.
    #[tokio::test]
    async fn test_shutdown_coordinator_creation() {
        let coordinator = ShutdownCoordinator::new(Duration::from_secs(5));
        assert_eq!(coordinator.state().await, ShutdownState::Running);
        assert_eq!(coordinator.pending_operations().await, 0);
    }

    /// Registering and unregistering moves the pending count up and back down.
    #[tokio::test]
    async fn test_operation_registration() {
        let coordinator = ShutdownCoordinator::new(Duration::from_secs(5));
        assert!(coordinator.register_operation().await);
        assert_eq!(coordinator.pending_operations().await, 1);
        coordinator.unregister_operation().await;
        assert_eq!(coordinator.pending_operations().await, 0);
    }

    /// Once shutdown has started, new registrations are refused.
    #[tokio::test]
    async fn test_shutdown_blocks_new_operations() {
        let coordinator = Arc::new(ShutdownCoordinator::new(Duration::from_secs(5)));
        let shutdown_handle = {
            let coordinator = coordinator.clone();
            tokio::spawn(async move { coordinator.shutdown(ShutdownSignal::Manual).await })
        };
        // Give the spawned shutdown time to leave the `Running` state.
        tokio::time::sleep(Duration::from_millis(100)).await;
        assert!(!coordinator.register_operation().await);
        shutdown_handle.await.unwrap().unwrap();
    }

    /// Shutdown stays in `Draining` until every pending operation is released.
    #[tokio::test]
    async fn test_shutdown_waits_for_operations() {
        let coordinator = Arc::new(ShutdownCoordinator::new(Duration::from_secs(5)));
        assert!(coordinator.register_operation().await);
        assert!(coordinator.register_operation().await);
        assert_eq!(coordinator.pending_operations().await, 2);

        let shutdown_coordinator = coordinator.clone();
        let shutdown_handle = tokio::spawn(async move {
            shutdown_coordinator.shutdown(ShutdownSignal::Manual).await
        });

        // Give the spawned shutdown time to enter the drain phase.
        tokio::time::sleep(Duration::from_millis(100)).await;
        assert_eq!(coordinator.state().await, ShutdownState::Draining);

        coordinator.unregister_operation().await;
        coordinator.unregister_operation().await;

        let result = shutdown_handle.await.unwrap();
        assert!(result.is_ok());
        assert_eq!(coordinator.state().await, ShutdownState::Stopped);
    }

    /// An operation that never finishes makes shutdown fail after the timeout.
    #[tokio::test]
    async fn test_shutdown_timeout() {
        let coordinator = Arc::new(ShutdownCoordinator::new(Duration::from_millis(200)));
        assert!(coordinator.register_operation().await);
        let result = coordinator.shutdown(ShutdownSignal::Manual).await;
        assert!(result.is_err());
        assert_eq!(coordinator.state().await, ShutdownState::Stopped);
    }

    /// Subscribers receive the signal that was passed to `shutdown`.
    #[tokio::test]
    async fn test_shutdown_signal_broadcast() {
        let coordinator = ShutdownCoordinator::new(Duration::from_secs(5));
        let mut receiver = coordinator.subscribe();
        tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(100)).await;
            let _ = coordinator.shutdown(ShutdownSignal::Terminate).await;
        });
        let signal = receiver.recv().await.unwrap();
        assert_eq!(signal, ShutdownSignal::Terminate);
    }

    /// Only the first shutdown call succeeds; later ones are rejected.
    #[tokio::test]
    async fn test_multiple_shutdown_attempts() {
        let coordinator = Arc::new(ShutdownCoordinator::new(Duration::from_secs(5)));
        let result = coordinator.shutdown(ShutdownSignal::Manual).await;
        assert!(result.is_ok());
        let coordinator2 = coordinator.clone();
        let result = coordinator2.shutdown(ShutdownSignal::Manual).await;
        assert!(result.is_err());
    }

    /// Holding a guard counts as a pending operation; dropping it releases it.
    #[tokio::test]
    async fn test_operation_guard_releases_on_drop() {
        let coordinator = Arc::new(ShutdownCoordinator::new(Duration::from_secs(5)));
        let guard = OperationGuard::new(Arc::clone(&coordinator)).await;
        assert!(guard.is_some());
        assert_eq!(coordinator.pending_operations().await, 1);

        drop(guard);

        // The release may be carried out by a background task, so poll
        // (bounded) rather than asserting immediately after the drop.
        let released = timeout(Duration::from_secs(1), async {
            while coordinator.pending_operations().await != 0 {
                tokio::task::yield_now().await;
            }
        })
        .await;
        assert!(released.is_ok());
    }

    /// No guard can be obtained once the coordinator has shut down.
    #[tokio::test]
    async fn test_operation_guard_rejected_after_shutdown() {
        let coordinator = Arc::new(ShutdownCoordinator::new(Duration::from_secs(5)));
        coordinator.shutdown(ShutdownSignal::Manual).await.unwrap();
        assert!(OperationGuard::new(Arc::clone(&coordinator)).await.is_none());
        assert_eq!(coordinator.pending_operations().await, 0);
    }
}