use crate::{ZoeyError, Result};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::{
atomic::{AtomicBool, AtomicU64, Ordering},
Arc, RwLock,
};
use std::time::{Duration, Instant};
use tokio::sync::{broadcast, Semaphore};
use tracing::{debug, error, info, warn};
/// Receiving half of the shutdown broadcast channel; yields [`ShutdownSignal`] values.
pub type ShutdownReceiver = broadcast::Receiver<ShutdownSignal>;
/// Sending half of the shutdown broadcast channel; carries [`ShutdownSignal`] values.
pub type ShutdownSender = broadcast::Sender<ShutdownSignal>;
/// Signal broadcast to subscribers of the shutdown channel.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ShutdownSignal {
/// Broadcast by `ShutdownManager::shutdown` once shutdown has begun.
Graceful,
/// NOTE(review): not sent by any code visible in this file; presumably a request
/// to stop without draining — confirm against the consumers.
Immediate,
/// Broadcast by `ShutdownManager::checkpoint` before state is persisted.
Checkpoint,
}
/// Summary returned by `ShutdownManager::shutdown` once the shutdown sequence has run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ShutdownState {
/// Value of the in-flight request counter when the summary was built.
pub in_flight_requests: u64,
/// Always `0` when produced by `ShutdownManager::shutdown`.
pub pending_tasks: u64,
/// Names of the shutdown hooks that returned `Ok(())` within their time limit.
pub stopped_services: Vec<String>,
/// Always empty when produced by `ShutdownManager::shutdown`.
pub running_services: Vec<String>,
/// `chrono::Utc::now().timestamp()` captured when shutdown started.
pub shutdown_started_at: Option<i64>,
/// Set to `false` by `ShutdownManager::shutdown` in the summary it returns.
pub is_shutting_down: bool,
}
/// Coordinates shutdown: broadcasts signals, counts in-flight requests, and holds
/// the registered shutdown hooks and the optional state persister.
pub struct ShutdownManager {
/// Broadcast sender used to publish `ShutdownSignal` values to subscribers.
sender: ShutdownSender,
/// Set to `true` once `shutdown` has been called; never reset in this file.
shutting_down: Arc<AtomicBool>,
/// Number of live `RequestGuard`s handed out by `track_request`.
in_flight: Arc<AtomicU64>,
/// Created with 1000 permits in `new`; not acquired anywhere in the code visible here.
drain_semaphore: Arc<Semaphore>,
/// Hooks run by `shutdown`, stored in registration order.
hooks: Arc<RwLock<Vec<Arc<dyn ShutdownHook>>>>,
/// Maximum time `shutdown` waits for in-flight requests to finish.
timeout: Duration,
/// Optional persister invoked by `shutdown` and `checkpoint`.
state_persister: Arc<RwLock<Option<Arc<dyn StatePersister>>>>,
}
impl ShutdownManager {
pub fn new(timeout: Duration) -> Self {
let (sender, _) = broadcast::channel(16);
Self {
sender,
shutting_down: Arc::new(AtomicBool::new(false)),
in_flight: Arc::new(AtomicU64::new(0)),
drain_semaphore: Arc::new(Semaphore::new(1000)), hooks: Arc::new(RwLock::new(Vec::new())),
timeout,
state_persister: Arc::new(RwLock::new(None)),
}
}
pub fn subscribe(&self) -> ShutdownReceiver {
self.sender.subscribe()
}
pub fn is_shutting_down(&self) -> bool {
self.shutting_down.load(Ordering::SeqCst)
}
pub fn register_hook<H: ShutdownHook + 'static>(&self, hook: H) {
self.hooks.write().unwrap().push(Arc::new(hook));
}
pub fn set_state_persister<P: StatePersister + 'static>(&self, persister: P) {
*self.state_persister.write().unwrap() = Some(Arc::new(persister));
}
pub fn clear_state_persister(&self) {
*self.state_persister.write().unwrap() = None;
}
pub fn track_request(&self) -> Option<RequestGuard> {
if self.shutting_down.load(Ordering::SeqCst) {
return None;
}
self.in_flight.fetch_add(1, Ordering::SeqCst);
Some(RequestGuard {
counter: Arc::clone(&self.in_flight),
})
}
pub fn in_flight_count(&self) -> u64 {
self.in_flight.load(Ordering::SeqCst)
}
pub async fn shutdown(&self) -> Result<ShutdownState> {
info!("Initiating graceful shutdown...");
if self.shutting_down.swap(true, Ordering::SeqCst) {
warn!("Shutdown already in progress");
return Err(ZoeyError::other("Shutdown already in progress"));
}
let shutdown_start = Instant::now();
let shutdown_started_at = chrono::Utc::now().timestamp();
let _ = self.sender.send(ShutdownSignal::Graceful);
info!(
"Waiting for {} in-flight requests to complete...",
self.in_flight.load(Ordering::SeqCst)
);
let drain_result = self.drain_requests().await;
if let Err(e) = drain_result {
warn!("Failed to drain all requests: {}", e);
}
let persister_opt = self.state_persister.read().unwrap().clone();
if let Some(persister) = persister_opt {
info!("Persisting runtime state...");
if let Err(e) = persister.persist_state().await {
error!("Failed to persist state: {}", e);
}
}
let mut stopped_services = Vec::new();
let hooks: Vec<Arc<dyn ShutdownHook>> = self.hooks.read().unwrap().clone();
for hook in hooks {
let name = hook.name().to_string();
info!("Running shutdown hook: {}", name);
match tokio::time::timeout(Duration::from_secs(30), hook.on_shutdown()).await {
Ok(Ok(())) => {
stopped_services.push(name);
}
Ok(Err(e)) => {
error!("Shutdown hook '{}' failed: {}", name, e);
}
Err(_) => {
error!("Shutdown hook '{}' timed out", name);
}
}
}
let elapsed = shutdown_start.elapsed();
info!("Shutdown completed in {:?}", elapsed);
Ok(ShutdownState {
in_flight_requests: self.in_flight.load(Ordering::SeqCst),
pending_tasks: 0,
stopped_services,
running_services: Vec::new(),
shutdown_started_at: Some(shutdown_started_at),
is_shutting_down: false,
})
}
async fn drain_requests(&self) -> Result<()> {
let start = Instant::now();
loop {
let count = self.in_flight.load(Ordering::SeqCst);
if count == 0 {
debug!("All requests drained");
return Ok(());
}
if start.elapsed() > self.timeout {
warn!(
"Drain timeout reached with {} requests still in flight",
count
);
return Err(ZoeyError::other(format!(
"Timeout: {} requests still in flight",
count
)));
}
debug!("Waiting for {} in-flight requests...", count);
tokio::time::sleep(Duration::from_millis(100)).await;
}
}
pub async fn checkpoint(&self) -> Result<()> {
info!("Creating checkpoint...");
let _ = self.sender.send(ShutdownSignal::Checkpoint);
let persister_opt = self.state_persister.read().unwrap().clone();
if let Some(persister) = persister_opt {
persister.persist_state().await?;
}
info!("Checkpoint completed");
Ok(())
}
}
impl Default for ShutdownManager {
fn default() -> Self {
Self::new(Duration::from_secs(30))
}
}
/// Token representing one tracked in-flight request.
///
/// The shared counter is decremented when the guard is dropped.
pub struct RequestGuard {
    /// Counter shared with the manager that issued this guard.
    counter: Arc<AtomicU64>,
}

impl Drop for RequestGuard {
    /// Un-registers the request by decrementing the shared counter by one.
    fn drop(&mut self) {
        let in_flight: &AtomicU64 = &self.counter;
        in_flight.fetch_sub(1, Ordering::SeqCst);
    }
}
/// A unit of cleanup work run by `ShutdownManager::shutdown`.
#[async_trait::async_trait]
pub trait ShutdownHook: Send + Sync {
/// Name used in log messages and recorded in `ShutdownState::stopped_services`
/// when the hook succeeds.
fn name(&self) -> &str;
/// Performs the cleanup. `ShutdownManager::shutdown` gives each hook 30 seconds
/// and logs an error or timeout instead of propagating it.
async fn on_shutdown(&self) -> Result<()>;
/// Priority of this hook; defaults to `0`.
/// NOTE(review): nothing visible in this file reads this value, so whether higher
/// or lower runs first is unconfirmed.
fn priority(&self) -> i32 {
0
}
}
/// Saves and restores runtime state on behalf of `ShutdownManager`.
#[async_trait::async_trait]
pub trait StatePersister: Send + Sync {
/// Persists the current state; called by `ShutdownManager::shutdown` and
/// `ShutdownManager::checkpoint`.
async fn persist_state(&self) -> Result<()>;
/// Restores previously persisted state.
/// NOTE(review): not called by any code visible in this file.
async fn restore_state(&self) -> Result<()>;
}
/// Serializable snapshot of runtime state, produced by the `state_provider`
/// closures of the persisters in this file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PersistableRuntimeState {
/// Identifier of the agent the snapshot belongs to, as a string.
pub agent_id: String,
/// Identifier of the current run, if any — presumably absent when idle; TODO confirm.
pub current_run_id: Option<String>,
/// Settings captured as arbitrary JSON values keyed by name.
pub settings: HashMap<String, serde_json::Value>,
/// Cached state captured as arbitrary JSON values keyed by name.
pub state_cache: HashMap<String, serde_json::Value>,
/// Pending tasks, as strings — presumably task identifiers; verify against the provider.
pub pending_tasks: Vec<String>,
/// Time of the snapshot — presumably a Unix timestamp; the unit is not established here.
pub timestamp: i64,
}
/// `StatePersister` that writes the runtime state to a file as pretty-printed JSON.
pub struct FileStatePersister {
/// Destination file for the serialized state.
path: std::path::PathBuf,
/// Closure invoked on every persist to obtain the state to write.
state_provider: Arc<dyn Fn() -> PersistableRuntimeState + Send + Sync>,
}
impl FileStatePersister {
pub fn new<F>(path: std::path::PathBuf, state_provider: F) -> Self
where
F: Fn() -> PersistableRuntimeState + Send + Sync + 'static,
{
Self {
path,
state_provider: Arc::new(state_provider),
}
}
}
#[async_trait::async_trait]
impl StatePersister for FileStatePersister {
    /// Serializes the provider's current state as pretty-printed JSON and
    /// stores it at `self.path`.
    ///
    /// The JSON is first written to a sibling staging file (`<path>.tmp`) and
    /// then renamed over the target. A crash or kill mid-write therefore leaves
    /// the previous state file intact instead of a truncated one. Concurrent
    /// calls on the same path share the staging file and are not coordinated.
    ///
    /// # Errors
    ///
    /// Returns an error if serialization, the staging write, or the final
    /// rename fails. The staging file is removed on a best-effort basis.
    async fn persist_state(&self) -> Result<()> {
        let state = (self.state_provider)();
        let json = serde_json::to_string_pretty(&state)
            .map_err(|e| ZoeyError::other(format!("Failed to serialize state: {}", e)))?;

        // Append ".tmp" to the full file name rather than swapping the
        // extension, so "state.json" stages as "state.json.tmp".
        let mut staging = self.path.clone().into_os_string();
        staging.push(".tmp");
        let staging = std::path::PathBuf::from(staging);

        if let Err(e) = tokio::fs::write(&staging, json).await {
            let _ = tokio::fs::remove_file(&staging).await;
            return Err(ZoeyError::other(format!(
                "Failed to write state file: {}",
                e
            )));
        }
        if let Err(e) = tokio::fs::rename(&staging, &self.path).await {
            let _ = tokio::fs::remove_file(&staging).await;
            return Err(ZoeyError::other(format!(
                "Failed to write state file: {}",
                e
            )));
        }

        info!("State persisted to {:?}", self.path);
        Ok(())
    }

    /// Currently a no-op that always succeeds; nothing is read back from disk.
    async fn restore_state(&self) -> Result<()> {
        Ok(())
    }
}
/// Holder for a database adapter, an agent id and a state provider.
/// NOTE(review): no `StatePersister` implementation for this type is visible in
/// this part of the file.
pub struct DatabaseStatePersister<A> {
/// Shared handle to the database adapter.
adapter: Arc<A>,
/// Identifier of the agent whose state is handled.
agent_id: uuid::Uuid,
/// Closure that produces the state snapshot on demand.
state_provider: Arc<dyn Fn() -> PersistableRuntimeState + Send + Sync>,
}
impl<A: Send + Sync + 'static> DatabaseStatePersister<A> {
pub fn new<F>(adapter: Arc<A>, agent_id: uuid::Uuid, state_provider: F) -> Self
where
F: Fn() -> PersistableRuntimeState + Send + Sync + 'static,
{
Self {
adapter,
agent_id,
state_provider: Arc::new(state_provider),
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh manager is idle: not shutting down and tracking nothing.
    #[tokio::test]
    async fn test_shutdown_manager_creation() {
        let manager = ShutdownManager::new(Duration::from_secs(5));
        assert!(!manager.is_shutting_down());
        assert_eq!(manager.in_flight_count(), 0);
    }

    /// Guards increment the counter while alive and decrement it on drop.
    #[tokio::test]
    async fn test_request_tracking() {
        let manager = ShutdownManager::new(Duration::from_secs(5));
        {
            let _guard = manager.track_request().unwrap();
            assert_eq!(manager.in_flight_count(), 1);
            let _guard2 = manager.track_request().unwrap();
            assert_eq!(manager.in_flight_count(), 2);
        }
        assert_eq!(manager.in_flight_count(), 0);
    }

    /// Once shutdown has run, new requests are rejected and nothing stays tracked.
    #[tokio::test]
    async fn test_shutdown_blocks_new_requests() {
        let manager = ShutdownManager::new(Duration::from_millis(100));
        // Before shutdown, requests are accepted (the guard is dropped right away).
        assert!(manager.track_request().is_some());

        // Awaiting shutdown directly avoids racing a spawned task against a sleep.
        assert!(manager.shutdown().await.is_ok());

        assert!(manager.is_shutting_down());
        assert!(manager.track_request().is_none());
        assert_eq!(manager.in_flight_count(), 0);
    }

    /// A second shutdown on the same manager is refused.
    #[tokio::test]
    async fn test_second_shutdown_is_rejected() {
        let manager = ShutdownManager::new(Duration::from_millis(100));
        assert!(manager.shutdown().await.is_ok());
        assert!(manager.shutdown().await.is_err());
    }

    /// Hook that records whether it has been executed.
    struct TestHook {
        name: String,
        executed: Arc<AtomicBool>,
    }

    #[async_trait::async_trait]
    impl ShutdownHook for TestHook {
        fn name(&self) -> &str {
            &self.name
        }

        async fn on_shutdown(&self) -> Result<()> {
            self.executed.store(true, Ordering::SeqCst);
            Ok(())
        }
    }

    /// Registered hooks run during shutdown and are reported as stopped.
    #[tokio::test]
    async fn test_shutdown_hooks() {
        let manager = ShutdownManager::new(Duration::from_secs(1));
        let executed = Arc::new(AtomicBool::new(false));
        manager.register_hook(TestHook {
            name: "test_hook".to_string(),
            executed: Arc::clone(&executed),
        });

        let state = match manager.shutdown().await {
            Ok(state) => state,
            Err(e) => panic!("shutdown failed: {}", e),
        };

        assert!(executed.load(Ordering::SeqCst));
        assert_eq!(state.stopped_services, vec!["test_hook".to_string()]);
        assert_eq!(state.in_flight_requests, 0);
    }
}