use super::provider::ProviderHealth;
use crate::utils::error::recovery::circuit_breaker::CircuitBreaker;
use crate::utils::error::recovery::types::CircuitBreakerConfig;
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::RwLock;
use tracing::info;
/// Tunable settings for [`HealthMonitor`].
///
/// The circuit-breaker related fields (`failure_threshold`, `success_threshold`,
/// `min_requests`) are copied into the `CircuitBreakerConfig` built for each
/// provider at registration time; the remaining fields drive the periodic
/// background health check.
#[derive(Debug, Clone)]
pub struct HealthMonitorConfig {
/// Period between two consecutive background health checks of a provider.
pub check_interval: Duration,
/// Upper bound on the duration of a single health check; a check that
/// exceeds it is recorded as unhealthy with the message "Health check timeout".
pub check_timeout: Duration,
/// Forwarded to `CircuitBreakerConfig::failure_threshold` for every
/// registered provider.
pub failure_threshold: u32,
/// NOTE(review): not read anywhere in this part of the file; presumably the
/// number of successful checks required before a provider counts as
/// recovered — confirm against `ProviderHealth`.
pub recovery_threshold: u32,
/// When `true`, registering a provider also starts its periodic health
/// check task.
pub auto_check_enabled: bool,
/// Response time, in milliseconds, above which an otherwise successful
/// health check is recorded as degraded ("High latency").
pub degraded_threshold_ms: u64,
/// Forwarded to `CircuitBreakerConfig::min_requests` for every registered
/// provider.
pub min_requests: u32,
/// Forwarded to `CircuitBreakerConfig::success_threshold` for every
/// registered provider.
pub success_threshold: u32,
}
impl Default for HealthMonitorConfig {
    /// Returns the stock configuration: a check every 30 s with a 10 s
    /// timeout, automatic checks enabled, a 2000 ms degraded-latency
    /// threshold, and thresholds of 3 failures / 2 recoveries / 3 successes
    /// with a minimum of 10 requests.
    fn default() -> Self {
        // Timing of the periodic check.
        let check_interval = Duration::from_secs(30);
        let check_timeout = Duration::from_secs(10);

        Self {
            check_interval,
            check_timeout,
            auto_check_enabled: true,
            degraded_threshold_ms: 2_000,
            failure_threshold: 3,
            recovery_threshold: 2,
            success_threshold: 3,
            min_requests: 10,
        }
    }
}
/// Tracks per-provider health state, circuit breakers, and the background
/// tasks that periodically run health checks.
///
/// Every map is keyed by provider id and wrapped in `Arc<RwLock<..>>` so it
/// can be shared with spawned tasks.
pub struct HealthMonitor {
/// Settings supplied at construction time.
pub(crate) config: HealthMonitorConfig,
/// Latest health record of each registered provider; updated by the
/// provider's periodic check task.
pub(crate) provider_health: Arc<RwLock<HashMap<String, ProviderHealth>>>,
/// Circuit breaker created for each provider when it is registered.
pub(crate) circuit_breakers: Arc<RwLock<HashMap<String, Arc<CircuitBreaker>>>>,
/// Join handles of the spawned periodic check tasks; drained and aborted
/// by `shutdown`.
pub(crate) check_tasks: Arc<RwLock<HashMap<String, tokio::task::JoinHandle<()>>>>,
}
impl HealthMonitor {
pub fn new(config: HealthMonitorConfig) -> Self {
Self {
config,
provider_health: Arc::new(RwLock::new(HashMap::new())),
circuit_breakers: Arc::new(RwLock::new(HashMap::new())),
check_tasks: Arc::new(RwLock::new(HashMap::new())),
}
}
pub async fn register_provider(&self, provider_id: String) {
info!(
"Registering provider for health monitoring: {}",
provider_id
);
{
let mut health = self.provider_health.write().await;
health.insert(
provider_id.clone(),
ProviderHealth::new(provider_id.clone()),
);
}
{
let mut breakers = self.circuit_breakers.write().await;
let breaker_config = CircuitBreakerConfig {
failure_threshold: self.config.failure_threshold,
success_threshold: self.config.success_threshold,
min_requests: self.config.min_requests,
..CircuitBreakerConfig::default()
};
breakers.insert(
provider_id.clone(),
Arc::new(CircuitBreaker::new(breaker_config)),
);
}
if self.config.auto_check_enabled {
self.start_health_check_task(provider_id).await;
}
}
pub async fn get_circuit_breaker(&self, provider_id: &str) -> Option<Arc<CircuitBreaker>> {
let breakers = self.circuit_breakers.read().await;
breakers.get(provider_id).cloned()
}
pub async fn shutdown(&self) {
info!("Shutting down health monitoring");
let tasks = {
let mut task_map = self.check_tasks.write().await;
task_map.drain().map(|(_, task)| task).collect::<Vec<_>>()
};
for task in tasks {
task.abort();
}
info!("Health monitoring shutdown complete");
}
pub(crate) async fn start_health_check_task(&self, provider_id: String) {
use super::checker::perform_health_check;
use super::types::HealthCheckResult;
use std::time::Instant;
use tokio::time::interval;
use tracing::debug;
let provider_health = self.provider_health.clone();
let check_interval = self.config.check_interval;
let check_timeout = self.config.check_timeout;
let degraded_threshold = self.config.degraded_threshold_ms;
let provider_id_clone = provider_id.clone();
let task = tokio::spawn(async move {
let provider_id = provider_id_clone;
let mut interval = interval(check_interval);
loop {
interval.tick().await;
debug!("Running health check for provider: {}", provider_id);
let start_time = Instant::now();
let result =
match tokio::time::timeout(check_timeout, perform_health_check(&provider_id))
.await
{
Ok(Ok(response_time)) => {
let response_time_ms = response_time.as_millis() as u64;
if response_time_ms > degraded_threshold {
HealthCheckResult::degraded(
format!("High latency: {}ms", response_time_ms),
response_time_ms,
)
} else {
HealthCheckResult::healthy(response_time_ms)
}
}
Ok(Err(error)) => {
let elapsed = start_time.elapsed().as_millis() as u64;
HealthCheckResult::unhealthy(error.to_string(), elapsed)
}
Err(_) => {
let elapsed = start_time.elapsed().as_millis() as u64;
HealthCheckResult::unhealthy(
"Health check timeout".to_string(),
elapsed,
)
}
};
let mut health_map = provider_health.write().await;
if let Some(provider_health) = health_map.get_mut(&provider_id) {
provider_health.update(result);
debug!(
"Updated health for {}: {:?}",
provider_id, provider_health.status
);
}
}
});
let mut tasks = self.check_tasks.write().await;
tasks.insert(provider_id, task);
}
}