use crate::connection_pool::PooledConnection;
use anyhow::Result;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::{broadcast, RwLock};
use tokio::time::interval;
use tracing::{debug, error, info, warn};
/// Tunable parameters for connection health monitoring.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HealthCheckConfig {
/// Period of the background loop spawned by `start_monitoring`.
pub check_interval: Duration,
/// Upper bound on a single `is_healthy()` probe; exceeding it counts as a timeout.
pub check_timeout: Duration,
/// Consecutive failures at which a connection is classified `Unhealthy`;
/// at twice this value it is classified `Dead`.
pub failure_threshold: u32,
/// Consecutive successes at which a connection is classified `Healthy`.
pub recovery_threshold: u32,
/// NOTE(review): not read anywhere in this file — presumably toggles
/// statistics collection; confirm against other users of the config.
pub enable_statistics: bool,
/// Upper bound on probe attempts per `check_connection_health` call.
pub retry_attempts: u32,
/// Pause between consecutive probe attempts within one `check_connection_health` call.
pub retry_delay: Duration,
}
impl Default for HealthCheckConfig {
    /// Returns the stock configuration: a 30 s check period, 5 s probe
    /// timeout, 3 failures to mark unhealthy, 2 successes to mark healthy,
    /// statistics enabled, and 2 probe attempts spaced 500 ms apart.
    fn default() -> Self {
        // Durations are named up front so the struct literal stays compact.
        let check_interval = Duration::from_secs(30);
        let check_timeout = Duration::from_secs(5);
        let retry_delay = Duration::from_millis(500);

        Self {
            check_interval,
            check_timeout,
            retry_delay,
            retry_attempts: 2,
            failure_threshold: 3,
            recovery_threshold: 2,
            enable_statistics: true,
        }
    }
}
/// Coarse health classification of a monitored connection.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum HealthStatus {
/// Consecutive successes have reached `recovery_threshold`.
Healthy,
/// At least one consecutive failure, but fewer than `failure_threshold`.
Degraded,
/// Consecutive failures have reached `failure_threshold`.
Unhealthy,
/// Consecutive failures have reached twice `failure_threshold`.
Dead,
/// No classification available: the initial status at registration and the
/// result reported for ids that have no health record.
Unknown,
}
/// Running counters and timing figures for one connection's health checks.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HealthStatistics {
/// Number of recorded checks, successful or not.
pub total_checks: u64,
/// Number of recorded checks that succeeded.
pub successful_checks: u64,
/// Number of recorded checks that failed.
pub failed_checks: u64,
/// Exponentially weighted moving average of response time in milliseconds
/// (weight 0.1 for the newest sample), seeded with the first sample.
pub avg_response_time_ms: f64,
/// Smallest response time seen, in milliseconds; starts at `f64::MAX`
/// until the first sample is recorded.
pub min_response_time_ms: f64,
/// Largest response time seen, in milliseconds.
pub max_response_time_ms: f64,
/// Length of the current failure streak; reset to 0 by a success.
pub consecutive_failures: u32,
/// Length of the current success streak; reset to 0 by a failure.
pub consecutive_successes: u32,
/// When the most recent check was recorded. Excluded from (de)serialization.
#[serde(skip)]
pub last_check: Option<Instant>,
/// When the most recent successful check was recorded. Excluded from (de)serialization.
#[serde(skip)]
pub last_success: Option<Instant>,
/// When the most recent failed check was recorded. Excluded from (de)serialization.
#[serde(skip)]
pub last_failure: Option<Instant>,
/// Failure tallies keyed by error message.
pub error_counts: HashMap<String, u64>,
}
impl Default for HealthStatistics {
    /// Returns empty statistics: all counters zero, no timestamps, no error
    /// tallies, and the minimum response time primed with `f64::MAX` so the
    /// first recorded sample always replaces it.
    fn default() -> Self {
        Self {
            // Check counters.
            total_checks: 0,
            successful_checks: 0,
            failed_checks: 0,
            // Streak counters.
            consecutive_failures: 0,
            consecutive_successes: 0,
            // Response-time aggregates.
            min_response_time_ms: f64::MAX,
            max_response_time_ms: 0.0,
            avg_response_time_ms: 0.0,
            // Timestamps remain unset until a check is recorded.
            last_check: None,
            last_success: None,
            last_failure: None,
            error_counts: HashMap::new(),
        }
    }
}
/// Everything the monitor tracks for a single registered connection.
#[derive(Debug, Clone)]
pub struct ConnectionHealthRecord {
/// Identifier the connection was registered under (also the map key).
pub connection_id: String,
/// Most recently stored classification; `Unknown` right after registration.
pub status: HealthStatus,
/// Running counters and timings for this connection.
pub statistics: HealthStatistics,
/// Caller-supplied key/value pairs passed to `register_connection`.
pub metadata: HashMap<String, String>,
/// Most recent check results, oldest first; trimmed to 100 entries by
/// `record_health_check_result`.
pub history: Vec<HealthCheckResult>,
}
/// Outcome of one recorded health check.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HealthCheckResult {
/// When the result was recorded. Not serialized; on deserialization it is
/// filled in with `Instant::now()`.
#[serde(skip, default = "Instant::now")]
pub timestamp: Instant,
/// Whether the check succeeded.
pub success: bool,
/// Measured duration of the check, in milliseconds.
pub response_time_ms: f64,
/// Failure description, if one was supplied with the result.
pub error: Option<String>,
}
/// Tracks per-connection health records for connections of type `T` and
/// broadcasts `HealthEvent`s about them.
pub struct HealthMonitor<T: PooledConnection> {
// Thresholds, timeouts and retry settings used by every check.
config: HealthCheckConfig,
// Health records keyed by connection id; shared with the background task.
health_records: Arc<RwLock<HashMap<String, ConnectionHealthRecord>>>,
// Sending half of the event channel; receivers come from `subscribe`.
event_sender: broadcast::Sender<HealthEvent>,
// Set to `true` by `stop_monitoring`; polled by the background loop.
shutdown_signal: Arc<RwLock<bool>>,
// Ties the monitor to the connection type without storing a connection.
_phantom: std::marker::PhantomData<T>,
}
/// Notifications broadcast to subscribers of a `HealthMonitor`.
#[derive(Debug, Clone)]
pub enum HealthEvent {
/// A connection's stored classification changed.
StatusChanged {
connection_id: String,
old_status: HealthStatus,
new_status: HealthStatus,
},
/// A connection was newly classified `Dead`; `reason` is a human-readable
/// explanation (the failure streak length).
ConnectionDead {
connection_id: String,
reason: String,
},
/// Emitted by `determine_health_status` when a connection returns to
/// `Healthy` from a failed state.
ConnectionRecovered { connection_id: String },
/// `check_connection_health` recorded a failed check; `error` describes
/// the last failed attempt.
HealthCheckFailed {
connection_id: String,
error: String,
},
}
impl<T: PooledConnection> HealthMonitor<T> {
pub fn new(config: HealthCheckConfig) -> Self {
let (event_sender, _) = broadcast::channel(1000);
Self {
config,
health_records: Arc::new(RwLock::new(HashMap::new())),
event_sender,
shutdown_signal: Arc::new(RwLock::new(false)),
_phantom: std::marker::PhantomData,
}
}
/// Starts tracking `connection_id` with fresh statistics and `Unknown`
/// status.
///
/// Any record already stored under the same id is replaced.
pub async fn register_connection(
    &self,
    connection_id: String,
    metadata: HashMap<String, String>,
) {
    let mut registry = self.health_records.write().await;
    registry.insert(
        connection_id.clone(),
        ConnectionHealthRecord {
            connection_id: connection_id.clone(),
            status: HealthStatus::Unknown,
            statistics: HealthStatistics::default(),
            metadata,
            // Room for the full retained history up front.
            history: Vec::with_capacity(100),
        },
    );
    info!(
        "Registered connection {} for health monitoring",
        connection_id
    );
}
/// Stops tracking `connection_id`, discarding its record.
///
/// Unknown ids are ignored without logging.
pub async fn unregister_connection(&self, connection_id: &str) {
    let mut registry = self.health_records.write().await;
    let was_registered = registry.remove(connection_id).is_some();
    if !was_registered {
        return;
    }
    info!(
        "Unregistered connection {} from health monitoring",
        connection_id
    );
}
pub async fn check_connection_health(
&self,
connection_id: &str,
connection: &T,
) -> Result<HealthStatus> {
let start_time = Instant::now();
let mut attempts = 0;
let mut last_error = None;
while attempts < self.config.retry_attempts {
attempts += 1;
match tokio::time::timeout(self.config.check_timeout, connection.is_healthy()).await {
Ok(true) => {
let response_time = start_time.elapsed();
self.record_health_check_result(connection_id, true, response_time, None)
.await?;
return Ok(self.determine_health_status(connection_id).await);
}
Ok(false) => {
last_error = Some("Health check returned false".to_string());
}
Err(_) => {
last_error = Some("Health check timed out".to_string());
}
}
if attempts < self.config.retry_attempts {
tokio::time::sleep(self.config.retry_delay).await;
}
}
let response_time = start_time.elapsed();
self.record_health_check_result(connection_id, false, response_time, last_error.clone())
.await?;
let status = self.determine_health_status(connection_id).await;
if let Some(error) = last_error {
let _ = self.event_sender.send(HealthEvent::HealthCheckFailed {
connection_id: connection_id.to_string(),
error,
});
}
Ok(status)
}
/// Folds one check outcome into the statistics and history of
/// `connection_id`.
///
/// Updates the check counters, success/failure streaks, timestamps,
/// per-error tallies (failures only), min/max response time and the
/// exponentially weighted average response time, then appends the result to
/// the bounded history. Ids with no record are ignored.
///
/// # Errors
/// Currently always returns `Ok(())`; the `Result` is kept for callers that
/// propagate with `?`.
async fn record_health_check_result(
    &self,
    connection_id: &str,
    success: bool,
    response_time: Duration,
    error: Option<String>,
) -> Result<()> {
    /// Number of most-recent results retained per connection.
    const MAX_HISTORY: usize = 100;
    /// Weight of the newest sample in the moving average.
    const EMA_ALPHA: f64 = 0.1;

    let mut records = self.health_records.write().await;
    let record = match records.get_mut(connection_id) {
        Some(record) => record,
        None => return Ok(()),
    };

    // Keep sub-millisecond precision; `as_millis()` would truncate fast
    // probes to 0.
    let response_time_ms = response_time.as_secs_f64() * 1000.0;
    let now = Instant::now();

    let stats = &mut record.statistics;
    stats.total_checks += 1;
    stats.last_check = Some(now);
    if success {
        stats.successful_checks += 1;
        stats.consecutive_successes += 1;
        stats.consecutive_failures = 0;
        stats.last_success = Some(now);
    } else {
        stats.failed_checks += 1;
        stats.consecutive_failures += 1;
        stats.consecutive_successes = 0;
        stats.last_failure = Some(now);
        if let Some(err) = &error {
            *stats.error_counts.entry(err.clone()).or_insert(0) += 1;
        }
    }

    stats.min_response_time_ms = stats.min_response_time_ms.min(response_time_ms);
    stats.max_response_time_ms = stats.max_response_time_ms.max(response_time_ms);
    stats.avg_response_time_ms = if stats.total_checks == 1 {
        // The first sample seeds the average.
        response_time_ms
    } else {
        EMA_ALPHA * response_time_ms + (1.0 - EMA_ALPHA) * stats.avg_response_time_ms
    };

    // Evict the oldest entry before pushing so the vector never grows past
    // its reserved capacity.
    if record.history.len() >= MAX_HISTORY {
        record.history.remove(0);
    }
    record.history.push(HealthCheckResult {
        timestamp: now,
        success,
        response_time_ms,
        error,
    });

    Ok(())
}
/// Re-classifies `connection_id` from its current streak counters, stores
/// the new status and broadcasts events for any change.
///
/// Classification, first match wins:
/// 1. failures >= 2 x `failure_threshold` -> `Dead`
/// 2. failures >= `failure_threshold` -> `Unhealthy`
/// 3. successes >= `recovery_threshold` -> `Healthy`
/// 4. failures > 0 -> `Degraded`
/// 5. otherwise the previous status is kept, so a connection that is still
///    working its way to `recovery_threshold` does not flip to `Unknown`.
///
/// On a change `StatusChanged` is sent, followed by `ConnectionDead` when the
/// new status is `Dead`, or `ConnectionRecovered` when the connection
/// becomes `Healthy` after being `Unhealthy` or `Dead`.
///
/// Returns `HealthStatus::Unknown` for ids with no record.
async fn determine_health_status(&self, connection_id: &str) -> HealthStatus {
    // Read, decide and store under a single write lock so a concurrent
    // update cannot slip in between the decision and the store.
    let (old_status, new_status, consecutive_failures) = {
        let mut records = self.health_records.write().await;
        let record = match records.get_mut(connection_id) {
            Some(record) => record,
            None => return HealthStatus::Unknown,
        };

        let failures = record.statistics.consecutive_failures;
        let successes = record.statistics.consecutive_successes;
        let failure_threshold = self.config.failure_threshold;

        // `saturating_mul` avoids overflow for very large thresholds.
        let new_status = if failures >= failure_threshold.saturating_mul(2) {
            HealthStatus::Dead
        } else if failures >= failure_threshold {
            HealthStatus::Unhealthy
        } else if successes >= self.config.recovery_threshold {
            HealthStatus::Healthy
        } else if failures > 0 {
            HealthStatus::Degraded
        } else {
            record.status.clone()
        };

        if record.status == new_status {
            return new_status;
        }
        let old_status = std::mem::replace(&mut record.status, new_status.clone());
        (old_status, new_status, failures)
    };

    // The lock is released; events are sent only after the new status is
    // visible to readers. Send errors just mean there are no subscribers.
    let _ = self.event_sender.send(HealthEvent::StatusChanged {
        connection_id: connection_id.to_string(),
        old_status: old_status.clone(),
        new_status: new_status.clone(),
    });
    if new_status == HealthStatus::Dead {
        let _ = self.event_sender.send(HealthEvent::ConnectionDead {
            connection_id: connection_id.to_string(),
            reason: format!("{consecutive_failures} consecutive failures"),
        });
    } else if new_status == HealthStatus::Healthy
        && matches!(old_status, HealthStatus::Unhealthy | HealthStatus::Dead)
    {
        let _ = self.event_sender.send(HealthEvent::ConnectionRecovered {
            connection_id: connection_id.to_string(),
        });
    }

    new_status
}
pub async fn start_monitoring(&self, connections: Arc<RwLock<HashMap<String, T>>>) {
let health_records = self.health_records.clone();
let config = self.config.clone();
let shutdown_signal = self.shutdown_signal.clone();
let event_sender = self.event_sender.clone();
tokio::spawn(async move {
let mut check_interval = interval(config.check_interval);
loop {
check_interval.tick().await;
if *shutdown_signal.read().await {
info!("Health monitor shutting down");
break;
}
let connections_guard = connections.read().await;
let connection_ids: Vec<String> = connections_guard.keys().cloned().collect();
drop(connections_guard);
for conn_id in connection_ids {
let start_time = Instant::now();
let health_check_result = {
let connection_guard = connections.read().await;
let connection = match connection_guard.get(&conn_id) {
Some(conn) => conn,
None => continue, };
tokio::time::timeout(config.check_timeout, connection.is_healthy()).await
};
match health_check_result {
Ok(healthy) => {
let response_time = start_time.elapsed();
let response_time_ms = response_time.as_millis() as f64;
let mut records = health_records.write().await;
if let Some(record) = records.get_mut(&conn_id) {
let stats = &mut record.statistics;
stats.total_checks += 1;
stats.last_check = Some(Instant::now());
if healthy {
stats.successful_checks += 1;
stats.consecutive_successes += 1;
stats.consecutive_failures = 0;
stats.last_success = Some(Instant::now());
debug!(
"Connection {} health check passed in {:.2}ms",
conn_id, response_time_ms
);
} else {
stats.failed_checks += 1;
stats.consecutive_failures += 1;
stats.consecutive_successes = 0;
stats.last_failure = Some(Instant::now());
warn!("Connection {} health check failed", conn_id);
}
let old_status = record.status.clone();
let new_status = if stats.consecutive_failures
>= config.failure_threshold * 2
{
HealthStatus::Dead
} else if stats.consecutive_failures >= config.failure_threshold {
HealthStatus::Unhealthy
} else if stats.consecutive_successes >= config.recovery_threshold {
HealthStatus::Healthy
} else {
old_status.clone()
};
if old_status != new_status {
record.status = new_status.clone();
let _ = event_sender.send(HealthEvent::StatusChanged {
connection_id: conn_id.clone(),
old_status,
new_status,
});
}
}
}
Err(_) => {
error!("Health check timeout for connection {}", conn_id);
let mut records = health_records.write().await;
if let Some(record) = records.get_mut(&conn_id) {
record.statistics.failed_checks += 1;
record.statistics.consecutive_failures += 1;
record.statistics.consecutive_successes = 0;
*record
.statistics
.error_counts
.entry("timeout".to_string())
.or_insert(0) += 1;
}
}
}
}
}
});
}
/// Raises the shutdown flag that the background monitoring loop checks
/// after each interval tick.
pub async fn stop_monitoring(&self) {
    let mut stop_requested = self.shutdown_signal.write().await;
    *stop_requested = true;
}
/// Returns a snapshot (clone) of the health record for `connection_id`, or
/// `None` when the id is not registered.
pub async fn get_connection_health(
    &self,
    connection_id: &str,
) -> Option<ConnectionHealthRecord> {
    let registry = self.health_records.read().await;
    let record = registry.get(connection_id)?;
    Some(record.clone())
}
/// Lists the ids of all connections whose stored status is `Degraded`,
/// `Unhealthy` or `Dead`, in map iteration order.
pub async fn get_unhealthy_connections(&self) -> Vec<String> {
    let registry = self.health_records.read().await;
    let mut flagged = Vec::new();
    for (id, record) in registry.iter() {
        match record.status {
            HealthStatus::Unhealthy | HealthStatus::Dead | HealthStatus::Degraded => {
                flagged.push(id.clone());
            }
            HealthStatus::Healthy | HealthStatus::Unknown => {}
        }
    }
    flagged
}
/// Lists the ids of all connections whose stored status is `Dead`, in map
/// iteration order.
pub async fn get_dead_connections(&self) -> Vec<String> {
    let registry = self.health_records.read().await;
    registry
        .iter()
        .filter_map(|(id, record)| {
            if record.status == HealthStatus::Dead {
                Some(id.clone())
            } else {
                None
            }
        })
        .collect()
}
/// Aggregates all health records into pool-wide figures.
///
/// Counts connections per status (`Unknown` ones appear only in
/// `total_connections`) and sums the check counters. `success_rate` is a
/// percentage of all recorded checks, 0.0 when there are none.
///
/// `avg_response_time_ms` is the mean of the per-connection averages over
/// connections that have at least one recorded check. Connections that have
/// never been checked are left out so their placeholder 0.0 does not drag
/// the figure down. It is 0.0 when no connection has been checked.
pub async fn get_overall_statistics(&self) -> OverallHealthStatistics {
    let records = self.health_records.read().await;

    let mut summary = OverallHealthStatistics {
        total_connections: records.len(),
        healthy_connections: 0,
        degraded_connections: 0,
        unhealthy_connections: 0,
        dead_connections: 0,
        total_checks: 0,
        successful_checks: 0,
        failed_checks: 0,
        success_rate: 0.0,
        avg_response_time_ms: 0.0,
    };
    let mut sampled_connections = 0usize;
    let mut response_time_sum = 0.0f64;

    // One pass over the map gathers every figure.
    for record in records.values() {
        match record.status {
            HealthStatus::Healthy => summary.healthy_connections += 1,
            HealthStatus::Degraded => summary.degraded_connections += 1,
            HealthStatus::Unhealthy => summary.unhealthy_connections += 1,
            HealthStatus::Dead => summary.dead_connections += 1,
            HealthStatus::Unknown => {}
        }

        let stats = &record.statistics;
        summary.total_checks += stats.total_checks;
        summary.successful_checks += stats.successful_checks;
        summary.failed_checks += stats.failed_checks;
        if stats.total_checks > 0 {
            sampled_connections += 1;
            response_time_sum += stats.avg_response_time_ms;
        }
    }

    if summary.total_checks > 0 {
        summary.success_rate =
            (summary.successful_checks as f64 / summary.total_checks as f64) * 100.0;
    }
    if sampled_connections > 0 {
        summary.avg_response_time_ms = response_time_sum / sampled_connections as f64;
    }
    summary
}
/// Creates a new receiver for the monitor's `HealthEvent` broadcast channel.
pub fn subscribe(&self) -> broadcast::Receiver<HealthEvent> {
    let sender = &self.event_sender;
    sender.subscribe()
}
}
/// Pool-wide summary produced by `HealthMonitor::get_overall_statistics`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OverallHealthStatistics {
/// Number of registered connections, whatever their status.
pub total_connections: usize,
/// Connections whose stored status is `Healthy`.
pub healthy_connections: usize,
/// Connections whose stored status is `Degraded`.
pub degraded_connections: usize,
/// Connections whose stored status is `Unhealthy`.
pub unhealthy_connections: usize,
/// Connections whose stored status is `Dead`.
pub dead_connections: usize,
/// Sum of `total_checks` over all connections.
pub total_checks: u64,
/// Sum of `successful_checks` over all connections.
pub successful_checks: u64,
/// Sum of `failed_checks` over all connections.
pub failed_checks: u64,
/// `successful_checks / total_checks` as a percentage; 0.0 when no checks exist.
pub success_rate: f64,
/// Mean of the per-connection average response times, in milliseconds.
pub avg_response_time_ms: f64,
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicBool, Ordering};

    /// Connection stub whose health is driven by a shared flag.
    #[derive(Clone)]
    struct TestConnection {
        healthy: Arc<AtomicBool>,
    }

    #[async_trait::async_trait]
    impl PooledConnection for TestConnection {
        async fn is_healthy(&self) -> bool {
            self.healthy.load(Ordering::Relaxed)
        }
        async fn close(&mut self) -> Result<()> {
            Ok(())
        }
        fn clone_connection(&self) -> Box<dyn PooledConnection> {
            Box::new(TestConnection {
                healthy: Arc::new(AtomicBool::new(self.healthy.load(Ordering::Relaxed))),
            })
        }
        fn created_at(&self) -> Instant {
            Instant::now()
        }
        fn last_activity(&self) -> Instant {
            Instant::now()
        }
        fn update_activity(&mut self) {}
    }

    /// Default configuration with a negligible retry delay, so tests that
    /// exercise failing probes do not spend seconds sleeping between retries.
    fn fast_config() -> HealthCheckConfig {
        HealthCheckConfig {
            retry_delay: Duration::from_millis(1),
            ..Default::default()
        }
    }

    /// Walks one connection through Unknown -> Healthy -> Unhealthy.
    #[tokio::test]
    async fn test_health_monitoring() {
        let monitor = HealthMonitor::<TestConnection>::new(fast_config());
        monitor
            .register_connection("test-conn-1".to_string(), HashMap::new())
            .await;
        let healthy_flag = Arc::new(AtomicBool::new(true));
        let connection = TestConnection {
            healthy: healthy_flag.clone(),
        };

        // One success is below the recovery threshold of 2.
        let status = monitor
            .check_connection_health("test-conn-1", &connection)
            .await
            .unwrap();
        assert_eq!(status, HealthStatus::Unknown);

        for _ in 0..3 {
            monitor
                .check_connection_health("test-conn-1", &connection)
                .await
                .unwrap();
        }
        let health = monitor.get_connection_health("test-conn-1").await.unwrap();
        assert_eq!(health.status, HealthStatus::Healthy);
        assert_eq!(health.statistics.consecutive_successes, 4);

        // Three failures reach the default failure threshold.
        healthy_flag.store(false, Ordering::Relaxed);
        for _ in 0..3 {
            monitor
                .check_connection_health("test-conn-1", &connection)
                .await
                .unwrap();
        }
        let health = monitor.get_connection_health("test-conn-1").await.unwrap();
        assert_eq!(health.status, HealthStatus::Unhealthy);
        assert_eq!(health.statistics.consecutive_failures, 3);
        let unhealthy = monitor.get_unhealthy_connections().await;
        assert!(unhealthy.contains(&"test-conn-1".to_string()));
    }

    /// Five failures with a threshold of 2 exceed the 2x limit for `Dead`.
    #[tokio::test]
    async fn test_dead_connection_detection() {
        let config = HealthCheckConfig {
            failure_threshold: 2,
            ..fast_config()
        };
        let monitor = HealthMonitor::<TestConnection>::new(config);
        monitor
            .register_connection("test-conn-1".to_string(), HashMap::new())
            .await;
        let connection = TestConnection {
            healthy: Arc::new(AtomicBool::new(false)),
        };
        for _ in 0..5 {
            monitor
                .check_connection_health("test-conn-1", &connection)
                .await
                .unwrap();
        }
        let health = monitor.get_connection_health("test-conn-1").await.unwrap();
        assert_eq!(health.status, HealthStatus::Dead);
        let dead = monitor.get_dead_connections().await;
        assert!(dead.contains(&"test-conn-1".to_string()));
    }

    /// A subscriber must observe at least one `StatusChanged` event.
    #[tokio::test]
    async fn test_health_events() {
        let monitor = HealthMonitor::<TestConnection>::new(fast_config());
        let mut event_receiver = monitor.subscribe();
        monitor
            .register_connection("test-conn-1".to_string(), HashMap::new())
            .await;
        let healthy_flag = Arc::new(AtomicBool::new(true));
        let connection = TestConnection {
            healthy: healthy_flag.clone(),
        };
        for _ in 0..3 {
            monitor
                .check_connection_health("test-conn-1", &connection)
                .await
                .unwrap();
        }
        healthy_flag.store(false, Ordering::Relaxed);
        for _ in 0..3 {
            monitor
                .check_connection_health("test-conn-1", &connection)
                .await
                .unwrap();
        }
        tokio::time::timeout(Duration::from_secs(1), async {
            while let Ok(event) = event_receiver.recv().await {
                if matches!(event, HealthEvent::StatusChanged { .. }) {
                    return;
                }
            }
        })
        .await
        .expect("Should receive status change event");
    }
}