use crate::utils::error::{GatewayError, Result};
use crate::utils::error_recovery::CircuitBreaker;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use std::time::{Duration, Instant};
use tokio::time::interval;
use tracing::{debug, error, info, warn};
use uuid::Uuid;
/// Coarse health classification of a provider.
///
/// Only `Healthy` and `Degraded` allow requests (see `HealthStatus::allows_requests`).
#[derive(Clone, Debug, Eq, PartialEq)]
#[derive(Deserialize, Serialize)]
pub enum HealthStatus {
    /// Fully operational; score 100.
    Healthy,
    /// Still serving requests but impaired; score 70.
    Degraded,
    /// Not accepting requests; score 30.
    Unhealthy,
    /// Not accepting requests; score 0.
    Down,
}
impl HealthStatus {
pub fn allows_requests(&self) -> bool {
matches!(self, HealthStatus::Healthy | HealthStatus::Degraded)
}
pub fn score(&self) -> u32 {
match self {
HealthStatus::Healthy => 100,
HealthStatus::Degraded => 70,
HealthStatus::Unhealthy => 30,
HealthStatus::Down => 0,
}
}
}
/// Outcome of a single provider health check.
#[derive(Clone, Debug)]
#[derive(Deserialize, Serialize)]
pub struct HealthCheckResult {
    /// Status assigned to this check.
    pub status: HealthStatus,
    /// Measured duration of the check, in milliseconds.
    pub response_time_ms: u64,
    /// UTC creation time of the result (the constructors use `chrono::Utc::now()`).
    pub timestamp: chrono::DateTime<chrono::Utc>,
    /// Free-form explanation; `HealthCheckResult::degraded` stores its reason here.
    pub details: Option<String>,
    /// Error message; `HealthCheckResult::unhealthy` stores its error here.
    pub error: Option<String>,
    /// Additional named measurements.
    /// NOTE(review): the constructors leave this empty and nothing in this file fills it —
    /// confirm who populates it.
    pub metrics: HashMap<String, f64>,
}
impl HealthCheckResult {
pub fn healthy(response_time_ms: u64) -> Self {
Self {
status: HealthStatus::Healthy,
response_time_ms,
timestamp: chrono::Utc::now(),
details: None,
error: None,
metrics: HashMap::new(),
}
}
pub fn unhealthy(error: String, response_time_ms: u64) -> Self {
Self {
status: HealthStatus::Unhealthy,
response_time_ms,
timestamp: chrono::Utc::now(),
details: None,
error: Some(error),
metrics: HashMap::new(),
}
}
pub fn degraded(reason: String, response_time_ms: u64) -> Self {
Self {
status: HealthStatus::Degraded,
response_time_ms,
timestamp: chrono::Utc::now(),
details: Some(reason),
error: None,
metrics: HashMap::new(),
}
}
}
/// Rolling health record for one provider, updated from successive check results.
#[derive(Clone, Debug)]
#[derive(Deserialize, Serialize)]
pub struct ProviderHealth {
    /// Identifier of the provider this record describes.
    pub provider_id: String,
    /// Status from the most recent check (`Healthy` for a freshly created record).
    pub status: HealthStatus,
    /// Most recent check result, if any check has been recorded.
    pub last_check: Option<HealthCheckResult>,
    /// Recent check results, oldest first, capped at 50 entries by `ProviderHealth::update`.
    pub history: Vec<HealthCheckResult>,
    /// Mean `response_time_ms` over `history`.
    pub avg_response_time_ms: f64,
    /// Percentage (0–100) of `history` entries whose status is `Healthy` or `Degraded`.
    pub success_rate: f64,
    /// Number of consecutive checks counted as failures; `ProviderHealth::update` resets it
    /// on a passing check (see that method for which statuses pass).
    pub consecutive_failures: u32,
    /// Timestamp of the most recent `Healthy` result (creation time for a new record).
    pub last_healthy: Option<chrono::DateTime<chrono::Utc>>,
    /// Additional named measurements.
    /// NOTE(review): nothing in this file writes to it — confirm who populates it.
    pub metrics: HashMap<String, f64>,
}
impl ProviderHealth {
    /// Number of past check results kept in `history`; the oldest entries are dropped first.
    const MAX_HISTORY: usize = 50;
    /// A provider with this many consecutive failed checks is treated as unavailable.
    const MAX_CONSECUTIVE_FAILURES: u32 = 5;

    /// Creates a record for `provider_id`.
    ///
    /// The record starts out `Healthy` with a 100% success rate, no history, and
    /// `last_healthy` set to the current UTC time.
    pub fn new(provider_id: String) -> Self {
        Self {
            provider_id,
            status: HealthStatus::Healthy,
            last_check: None,
            history: Vec::new(),
            avg_response_time_ms: 0.0,
            success_rate: 100.0,
            consecutive_failures: 0,
            last_healthy: Some(chrono::Utc::now()),
            metrics: HashMap::new(),
        }
    }

    /// Records a new check result and refreshes the derived statistics.
    ///
    /// A result whose status still allows requests (`Healthy` or `Degraded`) resets
    /// `consecutive_failures`; any other result increments it (saturating). Only a
    /// `Healthy` result moves `last_healthy` forward. `history` is trimmed to the most
    /// recent `MAX_HISTORY` entries.
    pub fn update(&mut self, result: HealthCheckResult) {
        self.status = result.status.clone();
        // `Degraded` counts as a pass here, matching `allows_requests` and the
        // success-rate calculation. Counting it as a failure would drop a provider that
        // is merely slow from routing after MAX_CONSECUTIVE_FAILURES checks, even though
        // its status says it accepts requests.
        if result.status.allows_requests() {
            self.consecutive_failures = 0;
        } else {
            self.consecutive_failures = self.consecutive_failures.saturating_add(1);
        }
        if result.status == HealthStatus::Healthy {
            self.last_healthy = Some(result.timestamp);
        }
        self.history.push(result.clone());
        // `history` is public and may have been extended externally, so trim down to the
        // cap instead of assuming at most one surplus entry.
        if self.history.len() > Self::MAX_HISTORY {
            let excess = self.history.len() - Self::MAX_HISTORY;
            self.history.drain(..excess);
        }
        self.last_check = Some(result);
        self.recalculate_metrics();
    }

    /// Recomputes `avg_response_time_ms` and `success_rate` from `history`.
    ///
    /// Leaves both fields untouched when `history` is empty.
    fn recalculate_metrics(&mut self) {
        if self.history.is_empty() {
            return;
        }
        let samples = self.history.len() as f64;
        let total_time: u64 = self.history.iter().map(|h| h.response_time_ms).sum();
        self.avg_response_time_ms = total_time as f64 / samples;
        let successful_checks = self
            .history
            .iter()
            .filter(|h| h.status.allows_requests())
            .count();
        self.success_rate = (successful_checks as f64 / samples) * 100.0;
    }

    /// Whether traffic may be routed to this provider.
    ///
    /// Requires a request-allowing status and fewer than `MAX_CONSECUTIVE_FAILURES`
    /// consecutive failed checks.
    pub fn is_available(&self) -> bool {
        self.status.allows_requests()
            && self.consecutive_failures < Self::MAX_CONSECUTIVE_FAILURES
    }

    /// Routing weight in `[0, 1]`; `0.0` when the provider is not available.
    ///
    /// This is the mean of three factors:
    /// - the status score divided by 100;
    /// - the success rate divided by 100;
    /// - a latency factor `1 / (1 + avg_response_time_s)`.
    pub fn routing_weight(&self) -> f64 {
        if !self.is_available() {
            return 0.0;
        }
        let status_weight = self.status.score() as f64 / 100.0;
        let success_weight = self.success_rate / 100.0;
        let latency_weight = if self.avg_response_time_ms > 0.0 {
            1.0 / (1.0 + self.avg_response_time_ms / 1000.0)
        } else {
            1.0
        };
        (status_weight + success_weight + latency_weight) / 3.0
    }
}
/// Tuning knobs for [`HealthMonitor`].
#[derive(Debug, Clone)]
pub struct HealthMonitorConfig {
    /// Period of the background check loop for each provider.
    pub check_interval: Duration,
    /// Upper bound on a single check; exceeding it records an unhealthy result.
    pub check_timeout: Duration,
    /// NOTE(review): not read anywhere in this file (availability uses a fixed limit in
    /// `ProviderHealth`) — confirm intended use.
    pub failure_threshold: u32,
    /// NOTE(review): not read anywhere in this file — confirm intended use.
    pub recovery_threshold: u32,
    /// When true, registering a provider also spawns its background check loop.
    pub auto_check_enabled: bool,
    /// Check latency (ms) above which a successful check is recorded as degraded.
    pub degraded_threshold_ms: u64,
}

impl Default for HealthMonitorConfig {
    /// 30 s interval, 10 s timeout, thresholds 3/2, auto checks on, 2000 ms degraded cut-off.
    fn default() -> Self {
        let check_interval = Duration::from_secs(30);
        let check_timeout = Duration::from_secs(10);
        HealthMonitorConfig {
            check_interval,
            check_timeout,
            failure_threshold: 3,
            recovery_threshold: 2,
            auto_check_enabled: true,
            degraded_threshold_ms: 2_000,
        }
    }
}
/// Tracks per-provider health, per-provider circuit breakers, and the background
/// check tasks spawned for registered providers.
pub struct HealthMonitor {
    /// Settings applied to every registered provider.
    config: HealthMonitorConfig,
    /// Health records keyed by provider id; a clone of this `Arc` is moved into each
    /// background check task.
    provider_health: Arc<RwLock<HashMap<String, ProviderHealth>>>,
    /// One circuit breaker per registered provider id.
    circuit_breakers: Arc<RwLock<HashMap<String, CircuitBreaker>>>,
    /// Join handles of the spawned check loops, keyed by provider id.
    check_tasks: Arc<RwLock<HashMap<String, tokio::task::JoinHandle<()>>>>,
}
impl HealthMonitor {
pub fn new(config: HealthMonitorConfig) -> Self {
Self {
config,
provider_health: Arc::new(RwLock::new(HashMap::new())),
circuit_breakers: Arc::new(RwLock::new(HashMap::new())),
check_tasks: Arc::new(RwLock::new(HashMap::new())),
}
}
pub async fn register_provider(&self, provider_id: String) {
info!("Registering provider for health monitoring: {}", provider_id);
{
let mut health = self.provider_health.write().unwrap();
health.insert(provider_id.clone(), ProviderHealth::new(provider_id.clone()));
}
{
let mut breakers = self.circuit_breakers.write().unwrap();
let breaker_config = crate::utils::error_recovery::CircuitBreakerConfig::default();
breakers.insert(provider_id.clone(), CircuitBreaker::new(breaker_config));
}
if self.config.auto_check_enabled {
self.start_health_check_task(provider_id).await;
}
}
async fn start_health_check_task(&self, provider_id: String) {
let provider_health = self.provider_health.clone();
let check_interval = self.config.check_interval;
let check_timeout = self.config.check_timeout;
let degraded_threshold = self.config.degraded_threshold_ms;
let task = tokio::spawn(async move {
let mut interval = interval(check_interval);
loop {
interval.tick().await;
debug!("Running health check for provider: {}", provider_id);
let start_time = Instant::now();
let result = match tokio::time::timeout(check_timeout, Self::perform_health_check(&provider_id)).await {
Ok(Ok(response_time)) => {
let response_time_ms = response_time.as_millis() as u64;
if response_time_ms > degraded_threshold {
HealthCheckResult::degraded(
format!("High latency: {}ms", response_time_ms),
response_time_ms
)
} else {
HealthCheckResult::healthy(response_time_ms)
}
}
Ok(Err(error)) => {
let elapsed = start_time.elapsed().as_millis() as u64;
HealthCheckResult::unhealthy(error.to_string(), elapsed)
}
Err(_) => {
let elapsed = start_time.elapsed().as_millis() as u64;
HealthCheckResult::unhealthy("Health check timeout".to_string(), elapsed)
}
};
if let Ok(mut health_map) = provider_health.write() {
if let Some(provider_health) = health_map.get_mut(&provider_id) {
provider_health.update(result);
debug!("Updated health for {}: {:?}", provider_id, provider_health.status);
}
}
}
});
{
let mut tasks = self.check_tasks.write().unwrap();
tasks.insert(provider_id, task);
}
}
async fn perform_health_check(provider_id: &str) -> Result<Duration> {
let start_time = Instant::now();
let delay = match provider_id {
id if id.contains("openai") => Duration::from_millis(100 + rand::random::<u64>() % 200),
id if id.contains("anthropic") => Duration::from_millis(150 + rand::random::<u64>() % 300),
_ => Duration::from_millis(50 + rand::random::<u64>() % 100),
};
tokio::time::sleep(delay).await;
if rand::random::<f64>() < 0.05 {
return Err(GatewayError::External("Simulated health check failure".to_string()));
}
Ok(start_time.elapsed())
}
pub fn get_provider_health(&self, provider_id: &str) -> Option<ProviderHealth> {
self.provider_health.read().unwrap().get(provider_id).cloned()
}
pub fn get_all_provider_health(&self) -> HashMap<String, ProviderHealth> {
self.provider_health.read().unwrap().clone()
}
pub fn get_healthy_providers(&self) -> Vec<(String, f64)> {
let health_map = self.provider_health.read().unwrap();
let mut providers: Vec<_> = health_map
.iter()
.filter(|(_, health)| health.is_available())
.map(|(id, health)| (id.clone(), health.routing_weight()))
.collect();
providers.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
providers
}
pub fn update_provider_health(&self, provider_id: &str, result: HealthCheckResult) {
if let Ok(mut health_map) = self.provider_health.write() {
if let Some(provider_health) = health_map.get_mut(provider_id) {
provider_health.update(result);
info!("Manually updated health for {}: {:?}", provider_id, provider_health.status);
}
}
}
pub fn get_circuit_breaker(&self, provider_id: &str) -> Option<CircuitBreaker> {
self.circuit_breakers.read().unwrap().get(provider_id).cloned()
}
pub async fn shutdown(&self) {
info!("Shutting down health monitoring");
let tasks = {
let mut task_map = self.check_tasks.write().unwrap();
let tasks: Vec<_> = task_map.drain().map(|(_, task)| task).collect();
tasks
};
for task in tasks {
task.abort();
}
info!("Health monitoring shutdown complete");
}
}
/// Point-in-time snapshot of the health of every monitored provider.
#[derive(Debug, Clone)]
pub struct SystemHealth {
    /// Per-provider health records as passed to `SystemHealth::new`
    /// (presumably keyed by provider id — confirm with callers).
    provider_health: HashMap<String, ProviderHealth>,
    /// UTC time at which the snapshot was created.
    last_updated: chrono::DateTime<chrono::Utc>,
}
impl SystemHealth {
    /// Captures `provider_health` as a snapshot stamped with the current UTC time.
    pub fn new(provider_health: HashMap<String, ProviderHealth>) -> Self {
        Self {
            provider_health,
            last_updated: chrono::Utc::now(),
        }
    }

    /// UTC time at which this snapshot was taken.
    pub fn last_updated(&self) -> chrono::DateTime<chrono::Utc> {
        self.last_updated
    }

    /// Returns the provider counts `(total, healthy, available)`.
    ///
    /// "Healthy" means status `Healthy`; "available" means [`ProviderHealth::is_available`].
    fn counts(&self) -> (usize, usize, usize) {
        let total = self.provider_health.len();
        let healthy = self
            .provider_health
            .values()
            .filter(|h| h.status == HealthStatus::Healthy)
            .count();
        let available = self
            .provider_health
            .values()
            .filter(|h| h.is_available())
            .count();
        (total, healthy, available)
    }

    /// Aggregate status across all providers.
    ///
    /// The first matching rule wins:
    /// 1. `Down` when there are no providers or none is available.
    /// 2. `Healthy` when every provider is healthy.
    /// 3. `Degraded` when at least half are available.
    /// 4. `Unhealthy` otherwise.
    pub fn overall_status(&self) -> HealthStatus {
        let (total, healthy, available) = self.counts();
        if total == 0 || available == 0 {
            HealthStatus::Down
        } else if healthy == total {
            HealthStatus::Healthy
        } else if available * 2 >= total {
            // `available >= total / 2` truncated for odd totals (e.g. 1 of 3 passed);
            // doubling avoids the integer-division rounding.
            HealthStatus::Degraded
        } else {
            HealthStatus::Unhealthy
        }
    }

    /// Summary metrics keyed by name; empty when there are no providers.
    ///
    /// Keys: `total_providers`, `healthy_providers`, `available_providers`,
    /// `health_percentage`, `availability_percentage`, and `avg_response_time_ms`.
    /// The last is the unweighted mean of each provider's average response time.
    pub fn metrics(&self) -> HashMap<String, f64> {
        let mut metrics = HashMap::new();
        let (total, healthy, available) = self.counts();
        if total == 0 {
            return metrics;
        }
        let total = total as f64;
        let healthy = healthy as f64;
        let available = available as f64;
        metrics.insert("total_providers".to_string(), total);
        metrics.insert("healthy_providers".to_string(), healthy);
        metrics.insert("available_providers".to_string(), available);
        metrics.insert("health_percentage".to_string(), (healthy / total) * 100.0);
        metrics.insert("availability_percentage".to_string(), (available / total) * 100.0);
        let avg_response_time: f64 = self
            .provider_health
            .values()
            .map(|h| h.avg_response_time_ms)
            .sum::<f64>()
            / total;
        metrics.insert("avg_response_time_ms".to_string(), avg_response_time);
        metrics
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Each status maps to its documented score and request-gating flag.
    #[test]
    fn test_health_status_scoring() {
        let table = [
            (HealthStatus::Healthy, 100, true),
            (HealthStatus::Degraded, 70, true),
            (HealthStatus::Unhealthy, 30, false),
            (HealthStatus::Down, 0, false),
        ];
        for (status, expected_score, expected_allows) in table {
            assert_eq!(status.score(), expected_score);
            assert_eq!(status.allows_requests(), expected_allows);
        }
    }

    /// A failing check bumps the failure counter; a healthy one clears it.
    #[test]
    fn test_provider_health_update() {
        let mut provider = ProviderHealth::new("test-provider".to_string());
        assert_eq!(provider.status, HealthStatus::Healthy);
        assert_eq!(provider.consecutive_failures, 0);

        provider.update(HealthCheckResult::unhealthy("test error".to_string(), 1000));
        assert_eq!(provider.status, HealthStatus::Unhealthy);
        assert_eq!(provider.consecutive_failures, 1);

        provider.update(HealthCheckResult::healthy(500));
        assert_eq!(provider.status, HealthStatus::Healthy);
        assert_eq!(provider.consecutive_failures, 0);
    }

    /// A fast, healthy provider gets a high weight; a down one gets zero.
    #[test]
    fn test_provider_routing_weight() {
        let mut provider = ProviderHealth::new("test-provider".to_string());
        provider.update(HealthCheckResult::healthy(100));
        assert!(provider.routing_weight() > 0.8);

        provider.status = HealthStatus::Down;
        assert_eq!(provider.routing_weight(), 0.0);
    }

    /// Registration without auto checks still creates a health record.
    #[tokio::test]
    async fn test_health_monitor_registration() {
        let monitor = HealthMonitor::new(HealthMonitorConfig {
            auto_check_enabled: false,
            ..Default::default()
        });
        monitor.register_provider("test-provider".to_string()).await;

        let health = monitor
            .get_provider_health("test-provider")
            .expect("registered provider should have a health record");
        assert_eq!(health.provider_id, "test-provider");
    }

    /// One healthy and one unhealthy provider aggregate to `Degraded`.
    #[test]
    fn test_system_health() {
        let healthy_peer = ProviderHealth::new("provider1".to_string());
        let mut unhealthy_peer = ProviderHealth::new("provider2".to_string());
        unhealthy_peer.status = HealthStatus::Unhealthy;

        let mut providers = HashMap::new();
        providers.insert("provider1".to_string(), healthy_peer);
        providers.insert("provider2".to_string(), unhealthy_peer);

        let system_health = SystemHealth::new(providers);
        assert_eq!(system_health.overall_status(), HealthStatus::Degraded);

        let metrics = system_health.metrics();
        assert_eq!(metrics.get("total_providers"), Some(&2.0));
        assert_eq!(metrics.get("healthy_providers"), Some(&1.0));
    }
}