bzzz-core 0.1.0

Bzzz core library - Declarative orchestration engine for AI Agents
Documentation
//! Graceful Shutdown Module
//!
//! Handles graceful shutdown with SIGTERM/SIGINT support.

use std::sync::Arc;
use std::time::Duration;

use tokio::signal;
use tokio::sync::{broadcast, RwLock};
use tokio::time::timeout;

/// Shutdown signal
///
/// A shutdown notification built on a [`broadcast`] channel. Any number of
/// receivers can be obtained with [`ShutdownSignal::receiver`], and
/// [`ShutdownSignal::trigger`] wakes every receiver that already exists.
///
/// A `triggered` flag is kept next to the channel because a broadcast channel
/// does not replay past messages. Without the flag, a [`ShutdownSignal::wait`]
/// call made *after* `trigger` would block forever.
#[derive(Debug)]
pub struct ShutdownSignal {
    /// Sending half of the broadcast channel; also used to create receivers.
    tx: broadcast::Sender<()>,
    /// Set by `trigger` so that late callers of `wait` return immediately.
    triggered: std::sync::atomic::AtomicBool,
    /// Grace period in seconds (read by `GracefulShutdown`).
    timeout_secs: u64,
}

impl ShutdownSignal {
    /// Create a new shutdown signal with a grace period of `timeout_secs` seconds.
    pub fn new(timeout_secs: u64) -> Self {
        // Capacity 1 is enough: only the fact that shutdown was requested matters.
        // The initial receiver is dropped; receivers are created on demand.
        let (tx, _rx) = broadcast::channel(1);
        ShutdownSignal {
            tx,
            triggered: std::sync::atomic::AtomicBool::new(false),
            timeout_secs,
        }
    }

    /// Get a sender to trigger shutdown.
    ///
    /// Sending on this sender notifies existing receivers. It does not set the
    /// flag that [`ShutdownSignal::wait`] checks, so prefer
    /// [`ShutdownSignal::trigger`].
    pub fn sender(&self) -> broadcast::Sender<()> {
        self.tx.clone()
    }

    /// Get a receiver to wait for shutdown.
    ///
    /// The receiver only observes triggers sent after this call.
    pub fn receiver(&self) -> broadcast::Receiver<()> {
        self.tx.subscribe()
    }

    /// Trigger shutdown.
    ///
    /// Records that shutdown was requested, then notifies all current
    /// receivers. A send error only means there are no receivers right now.
    /// That is not a failure here, because later waiters see the flag instead.
    pub fn trigger(&self) {
        self.triggered
            .store(true, std::sync::atomic::Ordering::SeqCst);
        let _ = self.tx.send(());
    }

    /// Wait for shutdown signal.
    ///
    /// Returns immediately if [`ShutdownSignal::trigger`] has already been
    /// called. Otherwise waits for the next message on the channel.
    pub async fn wait(&self) {
        // Subscribe *before* reading the flag. A trigger that lands between the
        // two steps is then observed either via the flag or via `rx`.
        let mut rx = self.tx.subscribe();
        if self.triggered.load(std::sync::atomic::Ordering::SeqCst) {
            return;
        }
        // `Lagged` still means a message was sent, and `Closed` cannot happen
        // while `self` owns the sender, so the result is intentionally ignored.
        let _ = rx.recv().await;
    }
}

/// Graceful shutdown manager
pub struct GracefulShutdown {
    signal: Arc<ShutdownSignal>,
    running: Arc<RwLock<bool>>,
    active_tasks: Arc<RwLock<usize>>,
}

impl GracefulShutdown {
    /// Create a new graceful shutdown manager
    pub fn new(timeout_secs: u64) -> Self {
        GracefulShutdown {
            signal: Arc::new(ShutdownSignal::new(timeout_secs)),
            running: Arc::new(RwLock::new(true)),
            active_tasks: Arc::new(RwLock::new(0)),
        }
    }

    /// Register a task
    pub async fn register_task(&self) {
        let mut tasks = self.active_tasks.write().await;
        *tasks += 1;
    }

    /// Complete a task
    pub async fn complete_task(&self) {
        let mut tasks = self.active_tasks.write().await;
        if *tasks > 0 {
            *tasks -= 1;
        }
    }

    /// Get active task count
    pub async fn active_tasks(&self) -> usize {
        *self.active_tasks.read().await
    }

    /// Check if running
    pub async fn is_running(&self) -> bool {
        *self.running.read().await
    }

    /// Install signal handlers (SIGTERM, SIGINT)
    pub fn install_signal_handlers(&self) {
        let signal = self.signal.clone();
        let running = self.running.clone();

        tokio::spawn(async move {
            let ctrl_c = async {
                signal::ctrl_c()
                    .await
                    .expect("Failed to install Ctrl+C handler");
            };

            #[cfg(unix)]
            let terminate = async {
                signal::unix::signal(signal::unix::SignalKind::terminate())
                    .expect("Failed to install signal handler")
                    .recv()
                    .await;
            };

            #[cfg(not(unix))]
            let terminate = std::future::pending::<()>();

            tokio::select! {
                _ = ctrl_c => {},
                _ = terminate => {},
            }

            println!("Shutdown signal received, stopping gracefully...");

            // Mark as not running
            {
                let mut r = running.write().await;
                *r = false;
            }

            // Trigger shutdown
            signal.trigger();
        });
    }

    /// Wait for shutdown signal or timeout
    pub async fn wait_for_shutdown(&self) {
        let mut rx = self.signal.receiver();
        let timeout_duration = Duration::from_secs(self.signal.timeout_secs);

        match timeout(timeout_duration, rx.recv()).await {
            Ok(_) => {
                println!("Shutdown completed gracefully");
            }
            Err(_) => {
                println!("Shutdown timeout reached, forcing exit");
            }
        }
    }

    /// Wait for all tasks to complete or timeout
    pub async fn wait_for_tasks(&self) {
        let timeout_duration = Duration::from_secs(self.signal.timeout_secs);
        let start = std::time::Instant::now();

        loop {
            let active = self.active_tasks().await;
            if active == 0 {
                break;
            }

            if start.elapsed() > timeout_duration {
                println!(
                    "{} tasks still running after timeout, forcing shutdown",
                    active
                );
                break;
            }

            tokio::time::sleep(Duration::from_millis(100)).await;
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_shutdown_signal_creation() {
        let signal = ShutdownSignal::new(30);
        assert_eq!(signal.timeout_secs, 30);
    }

    #[tokio::test]
    async fn test_shutdown_signal_trigger() {
        let signal = ShutdownSignal::new(5);
        let mut rx = signal.receiver();

        signal.trigger();

        let result = rx.recv().await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_receiver_sees_send_from_sender() {
        let signal = ShutdownSignal::new(5);
        let mut rx = signal.receiver();

        // A receiver is subscribed, so the send cannot fail.
        signal.sender().send(()).unwrap();

        assert!(rx.recv().await.is_ok());
    }

    #[tokio::test]
    async fn test_graceful_shutdown_creation() {
        let shutdown = GracefulShutdown::new(30);
        assert!(shutdown.is_running().await);
        assert_eq!(shutdown.active_tasks().await, 0);
    }

    #[tokio::test]
    async fn test_task_tracking() {
        let shutdown = GracefulShutdown::new(30);

        shutdown.register_task().await;
        assert_eq!(shutdown.active_tasks().await, 1);

        shutdown.register_task().await;
        assert_eq!(shutdown.active_tasks().await, 2);

        shutdown.complete_task().await;
        assert_eq!(shutdown.active_tasks().await, 1);
    }

    #[tokio::test]
    async fn test_complete_task_saturates_at_zero() {
        let shutdown = GracefulShutdown::new(30);

        // Completing with no registered tasks must not underflow.
        shutdown.complete_task().await;
        assert_eq!(shutdown.active_tasks().await, 0);
    }

    #[tokio::test]
    async fn test_wait_for_tasks_returns_when_idle() {
        let shutdown = GracefulShutdown::new(30);

        // With no active tasks this must return well before the 30 s timeout.
        timeout(Duration::from_secs(1), shutdown.wait_for_tasks())
            .await
            .expect("wait_for_tasks should return immediately when idle");
    }

    #[tokio::test]
    async fn test_wait_for_tasks_gives_up_after_timeout() {
        let shutdown = GracefulShutdown::new(0);
        shutdown.register_task().await;

        timeout(Duration::from_secs(5), shutdown.wait_for_tasks())
            .await
            .expect("wait_for_tasks should stop once its timeout elapses");

        // Giving up does not cancel or uncount the stuck task.
        assert_eq!(shutdown.active_tasks().await, 1);
    }
}