use crate::error::{QuantumLogError, Result};
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::{broadcast, Notify, RwLock};
use tokio::time::timeout;
use tracing::{error, info, warn};
/// Cloneable handle that drives the shutdown sequence and broadcasts
/// [`ShutdownSignal`]s to every subscribed [`ShutdownListener`].
///
/// Clones share the broadcast channel, the lifecycle state and the completion
/// notifier; `timeout_duration` is a plain value, so each clone has its own copy.
#[derive(Clone, Debug)]
pub struct ShutdownHandle {
    /// Sending half of the shutdown broadcast channel.
    shutdown_tx: broadcast::Sender<ShutdownSignal>,
    /// Current lifecycle state, shared by all clones.
    state: Arc<RwLock<ShutdownState>>,
    /// Signalled by `notify_completion` to end the wait of a non-immediate shutdown.
    completion_notify: Arc<Notify>,
    /// Upper bound on how long a non-immediate shutdown waits for completion.
    timeout_duration: Duration,
}
/// Kind of shutdown broadcast to listeners.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum ShutdownSignal {
    /// The handle waits, up to its timeout, for components to report completion.
    Graceful,
    /// From the handle's side this waits for completion exactly like `Graceful`;
    /// any difference in behavior is left to the listeners.
    Force,
    /// The handle does not wait for components to report completion.
    Immediate,
}
/// Lifecycle state tracked by a shutdown handle.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum ShutdownState {
    /// Initial state; a shutdown may be started.
    Running,
    /// A shutdown has started and has not finished yet.
    Shutting,
    /// A shutdown finished successfully.
    Shutdown,
    /// A shutdown failed; carries the failure reason.
    Failed(String),
}
/// Receiving end of the shutdown broadcast for a single named component.
///
/// Obtained from `ShutdownHandle::create_listener`.
pub struct ShutdownListener {
    /// Subscription to the handle's shutdown broadcast channel.
    shutdown_rx: broadcast::Receiver<ShutdownSignal>,
    /// Component name used in this listener's log messages.
    component_name: String,
}
/// Summary returned by a completed shutdown.
#[derive(Clone, Debug, Default)]
pub struct ShutdownStats {
    /// Moment the shutdown started.
    pub start_time: Option<std::time::Instant>,
    /// Moment the shutdown finished waiting.
    pub end_time: Option<std::time::Instant>,
    /// Number of processed logs.
    /// NOTE(review): not populated by any code visible in this file.
    pub processed_logs: u64,
    /// Number of flushed batches.
    /// NOTE(review): not populated by any code visible in this file.
    pub flushed_batches: u32,
    /// Number of components that shut down.
    /// NOTE(review): not populated by any code visible in this file.
    pub shutdown_components: u32,
    /// Number of components that failed to shut down.
    /// NOTE(review): not populated by any code visible in this file.
    pub failed_components: u32,
}
impl ShutdownHandle {
    /// Creates a handle in the [`ShutdownState::Running`] state.
    ///
    /// `timeout_duration` bounds how long a graceful or force shutdown waits for
    /// [`ShutdownHandle::notify_completion`] before failing with a timeout.
    pub fn new(timeout_duration: Duration) -> Self {
        // The receiver returned by `channel` is dropped right away; listeners get
        // their own receivers through `create_listener`. The capacity of 16 bounds how
        // many unread signals a receiver may fall behind before it observes `Lagged`.
        let (shutdown_tx, _) = broadcast::channel(16);
        Self {
            shutdown_tx,
            state: Arc::new(RwLock::new(ShutdownState::Running)),
            completion_notify: Arc::new(Notify::new()),
            timeout_duration,
        }
    }

    /// Subscribes a new listener identified by `component_name`.
    ///
    /// The listener receives only signals broadcast after this call.
    pub fn create_listener(&self, component_name: impl Into<String>) -> ShutdownListener {
        ShutdownListener {
            shutdown_rx: self.shutdown_tx.subscribe(),
            component_name: component_name.into(),
        }
    }

    /// Broadcasts [`ShutdownSignal::Graceful`] and waits for completion.
    ///
    /// # Errors
    ///
    /// See [`ShutdownHandle::shutdown_immediate`]; additionally returns
    /// `QuantumLogError::ShutdownTimeout` if `notify_completion` is not called within
    /// the configured timeout.
    pub async fn shutdown_graceful(&self) -> Result<ShutdownStats> {
        self.shutdown_with_signal(ShutdownSignal::Graceful).await
    }

    /// Broadcasts [`ShutdownSignal::Force`] and waits for completion.
    ///
    /// # Errors
    ///
    /// Same as [`ShutdownHandle::shutdown_graceful`].
    pub async fn shutdown_force(&self) -> Result<ShutdownStats> {
        self.shutdown_with_signal(ShutdownSignal::Force).await
    }

    /// Broadcasts [`ShutdownSignal::Immediate`] without waiting for completion.
    ///
    /// # Errors
    ///
    /// - `QuantumLogError::ShutdownInProgress` if a shutdown is already running.
    /// - `QuantumLogError::AlreadyShutdown` if a shutdown already completed.
    /// - `QuantumLogError::ShutdownFailed` if a previous shutdown failed, or if the
    ///   signal could not be broadcast (e.g. no listener is subscribed).
    pub async fn shutdown_immediate(&self) -> Result<ShutdownStats> {
        self.shutdown_with_signal(ShutdownSignal::Immediate).await
    }

    /// Runs the shutdown sequence for `signal`.
    ///
    /// Moves the state `Running -> Shutting`, broadcasts `signal` and, unless it is
    /// [`ShutdownSignal::Immediate`], waits up to `timeout_duration` for
    /// [`ShutdownHandle::notify_completion`]. The state then becomes `Shutdown` on
    /// success or `Failed(reason)` on error. Only `start_time` and `end_time` of the
    /// returned [`ShutdownStats`] are filled in here.
    async fn shutdown_with_signal(&self, signal: ShutdownSignal) -> Result<ShutdownStats> {
        self.begin_shutdown(&signal).await?;

        let mut stats = ShutdownStats {
            start_time: Some(std::time::Instant::now()),
            ..Default::default()
        };

        // Register for completion *before* broadcasting. A `Notified` future observes
        // every `notify_waiters()` call made after it was created, even before it is
        // first polled. Creating it only after `send` would miss a completion reported
        // by a listener running on another worker thread, turning a successful
        // shutdown into a spurious timeout.
        let completed = self.completion_notify.notified();

        if let Err(e) = self.shutdown_tx.send(signal.clone()) {
            error!("Failed to send shutdown signal: {}", e);
            self.set_state(ShutdownState::Failed(format!("Failed to send signal: {}", e)))
                .await;
            return Err(QuantumLogError::ShutdownFailed(format!(
                "Signal send failed: {}",
                e
            )));
        }

        let result = match signal {
            // Immediate shutdown does not wait for components to acknowledge.
            ShutdownSignal::Immediate => Ok(()),
            ShutdownSignal::Graceful | ShutdownSignal::Force => {
                timeout(self.timeout_duration, completed)
                    .await
                    .map_err(|_| QuantumLogError::ShutdownTimeout)
            }
        };

        stats.end_time = Some(std::time::Instant::now());

        match result {
            Ok(()) => {
                self.set_state(ShutdownState::Shutdown).await;
                info!("Shutdown completed successfully");
                Ok(stats)
            }
            Err(e) => {
                self.set_state(ShutdownState::Failed(e.to_string())).await;
                error!("Shutdown failed: {}", e);
                Err(e)
            }
        }
    }

    /// Atomically moves `Running -> Shutting`, or reports why a shutdown cannot start.
    async fn begin_shutdown(&self, signal: &ShutdownSignal) -> Result<()> {
        let mut state = self.state.write().await;
        match *state {
            ShutdownState::Running => {
                *state = ShutdownState::Shutting;
                info!("Starting {} shutdown", signal_name(signal));
                Ok(())
            }
            ShutdownState::Shutting => {
                warn!("Shutdown already in progress");
                Err(QuantumLogError::ShutdownInProgress)
            }
            ShutdownState::Shutdown => {
                warn!("Already shutdown");
                Err(QuantumLogError::AlreadyShutdown)
            }
            ShutdownState::Failed(ref reason) => {
                warn!("Previous shutdown failed: {}", reason);
                Err(QuantumLogError::ShutdownFailed(reason.clone()))
            }
        }
    }

    /// Overwrites the shared lifecycle state.
    async fn set_state(&self, next: ShutdownState) {
        *self.state.write().await = next;
    }

    /// Wakes a shutdown that is waiting for completion.
    ///
    /// Uses `Notify::notify_waiters`, which stores no permit: a call made when no
    /// shutdown is waiting is not remembered.
    pub fn notify_completion(&self) {
        self.completion_notify.notify_waiters();
    }

    /// Returns a snapshot of the current lifecycle state.
    pub async fn get_state(&self) -> ShutdownState {
        self.state.read().await.clone()
    }

    /// Returns `true` while a shutdown is in progress or after it completed successfully.
    pub async fn is_shutting_down(&self) -> bool {
        matches!(
            *self.state.read().await,
            ShutdownState::Shutting | ShutdownState::Shutdown
        )
    }

    /// Returns `true` only after a shutdown completed successfully.
    pub async fn is_shutdown(&self) -> bool {
        matches!(*self.state.read().await, ShutdownState::Shutdown)
    }

    /// Sets the completion timeout used by later non-immediate shutdowns.
    ///
    /// Affects only this handle value; existing clones keep their own timeout.
    pub fn set_timeout(&mut self, timeout: Duration) {
        self.timeout_duration = timeout;
    }
}
impl ShutdownListener {
    /// Waits until a shutdown signal is broadcast and returns it.
    ///
    /// If this listener fell behind and the channel overwrote signals (`Lagged`), the
    /// loss is logged and waiting resumes with the oldest signal still buffered.
    ///
    /// # Errors
    ///
    /// Returns `QuantumLogError::ShutdownChannelClosed` once every sender has been
    /// dropped and nothing is left to receive.
    pub async fn wait_for_shutdown(&mut self) -> Result<ShutdownSignal> {
        // Loop instead of recursing on `Lagged`: each lag no longer allocates a boxed
        // future, and the call depth stays constant however often we lag.
        loop {
            match self.shutdown_rx.recv().await {
                Ok(signal) => {
                    info!(
                        "Component '{}' received shutdown signal: {:?}",
                        self.component_name, signal
                    );
                    return Ok(signal);
                }
                Err(broadcast::error::RecvError::Closed) => {
                    warn!(
                        "Shutdown channel closed for component '{}'",
                        self.component_name
                    );
                    return Err(QuantumLogError::ShutdownChannelClosed);
                }
                Err(broadcast::error::RecvError::Lagged(skipped)) => {
                    warn!(
                        "Component '{}' lagged behind, skipped {} signals",
                        self.component_name, skipped
                    );
                }
            }
        }
    }

    /// Polls for a shutdown signal without waiting.
    ///
    /// Returns `None` if no signal is pending. A closed channel is reported as
    /// `Some(ShutdownSignal::Immediate)`, and a lag as `Some(ShutdownSignal::Force)`.
    ///
    /// NOTE(review): `wait_for_shutdown` reports a closed channel as an error instead;
    /// confirm the two methods are meant to differ.
    pub fn try_recv_shutdown(&mut self) -> Option<ShutdownSignal> {
        match self.shutdown_rx.try_recv() {
            Ok(signal) => {
                info!(
                    "Component '{}' received shutdown signal: {:?}",
                    self.component_name, signal
                );
                Some(signal)
            }
            Err(broadcast::error::TryRecvError::Empty) => None,
            Err(broadcast::error::TryRecvError::Closed) => {
                warn!(
                    "Shutdown channel closed for component '{}'",
                    self.component_name
                );
                Some(ShutdownSignal::Immediate)
            }
            Err(broadcast::error::TryRecvError::Lagged(skipped)) => {
                warn!(
                    "Component '{}' lagged behind, skipped {} signals",
                    self.component_name, skipped
                );
                Some(ShutdownSignal::Force)
            }
        }
    }

    /// Returns the name this listener was registered with.
    pub fn component_name(&self) -> &str {
        &self.component_name
    }
}
fn signal_name(signal: &ShutdownSignal) -> &'static str {
match signal {
ShutdownSignal::Graceful => "graceful",
ShutdownSignal::Force => "force",
ShutdownSignal::Immediate => "immediate",
}
}
/// Timeout budget for the phases of a shutdown.
#[derive(Clone, Debug)]
pub struct ShutdownTimeouts {
    /// Completion timeout handed to the handle by `ShutdownCoordinator::new`.
    pub graceful: Duration,
    /// Presumably the timeout for a forced shutdown.
    /// NOTE(review): not read by any code in this file.
    pub force: Duration,
    /// Presumably a per-component timeout.
    /// NOTE(review): not read by any code in this file.
    pub component: Duration,
}

impl Default for ShutdownTimeouts {
    /// Returns 30 s graceful, 10 s force and 5 s component timeouts.
    fn default() -> Self {
        let secs = Duration::from_secs;
        ShutdownTimeouts {
            graceful: secs(30),
            force: secs(10),
            component: secs(5),
        }
    }
}
/// Registry of named components that all listen on one [`ShutdownHandle`].
pub struct ShutdownCoordinator {
    /// Handle used to broadcast shutdown to every registered component.
    handle: ShutdownHandle,
    /// Registered component names in registration order; duplicates are kept.
    components: Arc<RwLock<Vec<String>>>,
}

impl ShutdownCoordinator {
    /// Builds a coordinator whose handle uses `timeouts.graceful` as its timeout.
    ///
    /// NOTE(review): `timeouts.force` and `timeouts.component` are ignored here.
    pub fn new(timeouts: ShutdownTimeouts) -> Self {
        let handle = ShutdownHandle::new(timeouts.graceful);
        let components = Arc::new(RwLock::new(Vec::new()));
        Self { handle, components }
    }

    /// Records `name` and returns a listener subscribed to this coordinator's handle.
    pub async fn register_component(&self, name: impl Into<String>) -> ShutdownListener {
        let component: String = name.into();
        {
            let mut registry = self.components.write().await;
            registry.push(component.clone());
        }
        self.handle.create_listener(component)
    }

    /// Returns the shared shutdown handle.
    pub fn handle(&self) -> &ShutdownHandle {
        &self.handle
    }

    /// Returns a snapshot of the registered component names.
    pub async fn get_components(&self) -> Vec<String> {
        let registry = self.components.read().await;
        registry.iter().cloned().collect()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::time::sleep;

    #[tokio::test]
    async fn test_shutdown_handle_creation() {
        let handle = ShutdownHandle::new(Duration::from_secs(5));
        assert!(matches!(handle.get_state().await, ShutdownState::Running));
    }

    #[tokio::test]
    async fn test_shutdown_listener() {
        let handle = ShutdownHandle::new(Duration::from_secs(5));
        let mut listener = handle.create_listener("test_component");
        let handle_clone = handle.clone();
        tokio::spawn(async move {
            sleep(Duration::from_millis(100)).await;
            let _ = handle_clone.shutdown_tx.send(ShutdownSignal::Graceful);
        });
        let signal = listener.wait_for_shutdown().await.unwrap();
        assert_eq!(signal, ShutdownSignal::Graceful);
    }

    #[tokio::test]
    async fn test_shutdown_coordinator() {
        let coordinator = ShutdownCoordinator::new(ShutdownTimeouts::default());
        let _listener1 = coordinator.register_component("component1").await;
        let _listener2 = coordinator.register_component("component2").await;
        let components = coordinator.get_components().await;
        assert_eq!(components.len(), 2);
        assert!(components.contains(&"component1".to_string()));
        assert!(components.contains(&"component2".to_string()));
    }

    #[tokio::test]
    async fn test_shutdown_states() {
        let handle = ShutdownHandle::new(Duration::from_secs(1));
        assert!(!handle.is_shutting_down().await);
        assert!(!handle.is_shutdown().await);
        // Notifying while nothing waits must be harmless and must not change state.
        handle.notify_completion();
        assert_eq!(handle.get_state().await, ShutdownState::Running);
    }

    #[tokio::test]
    async fn test_immediate_shutdown_transitions_to_shutdown() {
        let handle = ShutdownHandle::new(Duration::from_secs(1));
        // Keep a receiver alive so the broadcast succeeds.
        let _listener = handle.create_listener("immediate");
        let stats = handle.shutdown_immediate().await.unwrap();
        assert!(stats.start_time.is_some());
        assert!(stats.end_time.is_some());
        assert!(handle.is_shutdown().await);
        assert!(handle.is_shutting_down().await);
        assert!(matches!(
            handle.shutdown_graceful().await,
            Err(QuantumLogError::AlreadyShutdown)
        ));
    }

    #[tokio::test]
    async fn test_graceful_shutdown_completes_when_notified() {
        let handle = ShutdownHandle::new(Duration::from_secs(5));
        let mut listener = handle.create_listener("worker");
        let (result, signal) = tokio::join!(handle.shutdown_graceful(), async {
            let signal = listener.wait_for_shutdown().await.unwrap();
            handle.notify_completion();
            signal
        });
        let stats = result.unwrap();
        assert!(stats.end_time.is_some());
        assert_eq!(signal, ShutdownSignal::Graceful);
        assert_eq!(handle.get_state().await, ShutdownState::Shutdown);
    }

    #[tokio::test]
    async fn test_graceful_shutdown_times_out_without_completion() {
        let handle = ShutdownHandle::new(Duration::from_millis(50));
        let _listener = handle.create_listener("silent");
        assert!(matches!(
            handle.shutdown_graceful().await,
            Err(QuantumLogError::ShutdownTimeout)
        ));
        assert!(matches!(handle.get_state().await, ShutdownState::Failed(_)));
        // A failed shutdown cannot be retried.
        assert!(matches!(
            handle.shutdown_immediate().await,
            Err(QuantumLogError::ShutdownFailed(_))
        ));
    }

    #[tokio::test]
    async fn test_try_recv_shutdown() {
        let handle = ShutdownHandle::new(Duration::from_secs(1));
        let mut listener = handle.create_listener("poller");
        assert_eq!(listener.component_name(), "poller");
        assert_eq!(listener.try_recv_shutdown(), None);
        handle.shutdown_immediate().await.unwrap();
        assert_eq!(
            listener.try_recv_shutdown(),
            Some(ShutdownSignal::Immediate)
        );
    }
}