use crate::core::providers::Provider;
use crate::utils::error::{GatewayError, Result};
use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::RwLock;
use tokio::time::interval;
use tracing::{debug, error, info, warn};
/// Periodically probes every registered provider and keeps a per-provider
/// health record that callers can query.
pub struct HealthChecker {
/// Shared provider registry, keyed by provider name.
providers: Arc<RwLock<HashMap<String, Arc<dyn Provider>>>>,
/// Latest health record for each provider, keyed by provider name.
statuses: Arc<RwLock<HashMap<String, ProviderHealthStatus>>>,
/// Period between two rounds of background health checks.
check_interval: Duration,
/// Upper bound on how long a single `health_check` call may take before it
/// is counted as a failure.
timeout: Duration,
/// Number of consecutive failures at which a provider is marked unhealthy.
max_failures: u32,
}
/// Health record kept for a single provider.
#[derive(Debug, Clone)]
pub struct ProviderHealthStatus {
    /// Whether the provider is currently considered usable.
    pub healthy: bool,
    /// Instant of the most recent successful check, if there has been one.
    pub last_success: Option<Instant>,
    /// Message of the most recent failure; cleared again on success.
    pub last_error: Option<String>,
    /// Duration of the most recent successful check.
    pub response_time: Option<Duration>,
    /// Failed checks since the last success.
    pub consecutive_failures: u32,
    /// Instant at which this record was last updated (or created).
    pub last_check: Instant,
}

impl Default for ProviderHealthStatus {
    /// Builds the record for a provider that has not been probed yet: it is
    /// assumed healthy, has no history, and `last_check` is the creation time.
    fn default() -> Self {
        let created_at = Instant::now();
        Self {
            last_check: created_at,
            healthy: true,
            consecutive_failures: 0,
            response_time: None,
            last_error: None,
            last_success: None,
        }
    }
}
impl HealthChecker {
pub async fn new(providers: Arc<RwLock<HashMap<String, Arc<dyn Provider>>>>) -> Result<Self> {
info!("Creating health checker");
let checker = Self {
providers,
statuses: Arc::new(RwLock::new(HashMap::new())),
check_interval: Duration::from_secs(30),
timeout: Duration::from_secs(10),
max_failures: 3,
};
checker.start_background_checks().await?;
Ok(checker)
}
async fn start_background_checks(&self) -> Result<()> {
let providers = self.providers.clone();
let statuses = self.statuses.clone();
let check_interval = self.check_interval;
let timeout = self.timeout;
let max_failures = self.max_failures;
tokio::spawn(async move {
let mut interval = interval(check_interval);
loop {
interval.tick().await;
let providers_guard = providers.read().await;
let provider_list: Vec<(String, Arc<dyn Provider>)> = providers_guard
.iter()
.map(|(name, provider)| (name.clone(), provider.clone()))
.collect();
drop(providers_guard);
for (name, provider) in provider_list {
let start_time = Instant::now();
match tokio::time::timeout(timeout, provider.health_check()).await {
Ok(Ok(())) => {
let response_time = start_time.elapsed();
let mut statuses_guard = statuses.write().await;
let status = statuses_guard.entry(name.clone()).or_default();
status.healthy = true;
status.last_success = Some(Instant::now());
status.response_time = Some(response_time);
status.consecutive_failures = 0;
status.last_check = Instant::now();
status.last_error = None;
debug!(
"Health check passed for provider {}: {}ms",
name,
response_time.as_millis()
);
}
Ok(Err(e)) => {
let mut statuses_guard = statuses.write().await;
let status = statuses_guard.entry(name.clone()).or_default();
status.consecutive_failures += 1;
status.last_error = Some(e.to_string());
status.last_check = Instant::now();
if status.consecutive_failures >= max_failures {
status.healthy = false;
warn!(
"Provider {} marked unhealthy after {} consecutive failures",
name, status.consecutive_failures
);
}
error!("Health check failed for provider {}: {}", name, e);
}
Err(_) => {
let mut statuses_guard = statuses.write().await;
let status = statuses_guard.entry(name.clone()).or_default();
status.consecutive_failures += 1;
status.last_error = Some("Health check timeout".to_string());
status.last_check = Instant::now();
if status.consecutive_failures >= max_failures {
status.healthy = false;
warn!("Provider {} marked unhealthy due to timeout", name);
}
error!("Health check timeout for provider {}", name);
}
}
}
}
});
Ok(())
}
pub async fn get_status(&self) -> Result<RouterHealthStatus> {
let statuses = self.statuses.read().await;
let provider_statuses = statuses.clone();
let overall_healthy = provider_statuses.values().any(|status| status.healthy);
Ok(RouterHealthStatus {
healthy: overall_healthy,
providers: provider_statuses,
last_check: Instant::now(),
})
}
pub async fn get_provider_status(&self, name: &str) -> Result<Option<ProviderHealthStatus>> {
let statuses = self.statuses.read().await;
Ok(statuses.get(name).cloned())
}
pub async fn get_healthy_providers(&self) -> Result<Vec<String>> {
let statuses = self.statuses.read().await;
let healthy_providers = statuses
.iter()
.filter(|(_, status)| status.healthy)
.map(|(name, _)| name.clone())
.collect();
Ok(healthy_providers)
}
pub async fn add_provider(&self, name: &str) -> Result<()> {
let mut statuses = self.statuses.write().await;
statuses.insert(name.to_string(), ProviderHealthStatus::default());
info!("Added provider {} to health checking", name);
Ok(())
}
pub async fn remove_provider(&self, name: &str) -> Result<()> {
let mut statuses = self.statuses.write().await;
statuses.remove(name);
info!("Removed provider {} from health checking", name);
Ok(())
}
pub async fn check_provider(&self, name: &str) -> Result<ProviderHealthStatus> {
let providers = self.providers.read().await;
let provider = providers
.get(name)
.ok_or_else(|| GatewayError::ProviderNotFound(name.to_string()))?;
let start_time = Instant::now();
match tokio::time::timeout(self.timeout, provider.health_check()).await {
Ok(Ok(())) => {
let response_time = start_time.elapsed();
let mut statuses = self.statuses.write().await;
let status = statuses.entry(name.to_string()).or_default();
status.healthy = true;
status.last_success = Some(Instant::now());
status.response_time = Some(response_time);
status.consecutive_failures = 0;
status.last_check = Instant::now();
status.last_error = None;
debug!(
"Manual health check passed for provider {}: {}ms",
name,
response_time.as_millis()
);
Ok(status.clone())
}
Ok(Err(e)) => {
let mut statuses = self.statuses.write().await;
let status = statuses.entry(name.to_string()).or_default();
status.consecutive_failures += 1;
status.last_error = Some(e.to_string());
status.last_check = Instant::now();
if status.consecutive_failures >= self.max_failures {
status.healthy = false;
}
error!("Manual health check failed for provider {}: {}", name, e);
Ok(status.clone())
}
Err(_) => {
let mut statuses = self.statuses.write().await;
let status = statuses.entry(name.to_string()).or_default();
status.consecutive_failures += 1;
status.last_error = Some("Health check timeout".to_string());
status.last_check = Instant::now();
if status.consecutive_failures >= self.max_failures {
status.healthy = false;
}
error!("Manual health check timeout for provider {}", name);
Ok(status.clone())
}
}
}
}
/// Aggregated health snapshot produced by `HealthChecker::get_status`.
#[derive(Debug, Clone)]
pub struct RouterHealthStatus {
/// `true` when at least one provider in `providers` is healthy.
pub healthy: bool,
/// Copy of the per-provider statuses at the time of the snapshot, keyed by
/// provider name.
pub providers: HashMap<String, ProviderHealthStatus>,
/// Instant at which the snapshot was taken.
pub last_check: Instant,
}