1use crate::connection_pool::PooledConnection;
7use anyhow::Result;
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10use std::sync::Arc;
11use std::time::{Duration, Instant};
12use tokio::sync::{broadcast, RwLock};
13use tokio::time::interval;
14use tracing::{debug, error, info, warn};
15
16#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct HealthCheckConfig {
19 pub check_interval: Duration,
21 pub check_timeout: Duration,
23 pub failure_threshold: u32,
25 pub recovery_threshold: u32,
27 pub enable_statistics: bool,
29 pub retry_attempts: u32,
31 pub retry_delay: Duration,
33}
34
35impl Default for HealthCheckConfig {
36 fn default() -> Self {
37 Self {
38 check_interval: Duration::from_secs(30),
39 check_timeout: Duration::from_secs(5),
40 failure_threshold: 3,
41 recovery_threshold: 2,
42 enable_statistics: true,
43 retry_attempts: 2,
44 retry_delay: Duration::from_millis(500),
45 }
46 }
47}
48
49#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
51pub enum HealthStatus {
52 Healthy,
54 Degraded,
56 Unhealthy,
58 Dead,
60 Unknown,
62}
63
64#[derive(Debug, Clone, Serialize, Deserialize)]
66pub struct HealthStatistics {
67 pub total_checks: u64,
69 pub successful_checks: u64,
71 pub failed_checks: u64,
73 pub avg_response_time_ms: f64,
75 pub min_response_time_ms: f64,
77 pub max_response_time_ms: f64,
79 pub consecutive_failures: u32,
81 pub consecutive_successes: u32,
83 #[serde(skip)]
85 pub last_check: Option<Instant>,
86 #[serde(skip)]
88 pub last_success: Option<Instant>,
89 #[serde(skip)]
91 pub last_failure: Option<Instant>,
92 pub error_counts: HashMap<String, u64>,
94}
95
96impl Default for HealthStatistics {
97 fn default() -> Self {
98 Self {
99 total_checks: 0,
100 successful_checks: 0,
101 failed_checks: 0,
102 avg_response_time_ms: 0.0,
103 min_response_time_ms: f64::MAX,
104 max_response_time_ms: 0.0,
105 consecutive_failures: 0,
106 consecutive_successes: 0,
107 last_check: None,
108 last_success: None,
109 last_failure: None,
110 error_counts: HashMap::new(),
111 }
112 }
113}
114
115#[derive(Debug, Clone)]
117pub struct ConnectionHealthRecord {
118 pub connection_id: String,
120 pub status: HealthStatus,
122 pub statistics: HealthStatistics,
124 pub metadata: HashMap<String, String>,
126 pub history: Vec<HealthCheckResult>,
128}
129
130#[derive(Debug, Clone, Serialize, Deserialize)]
132pub struct HealthCheckResult {
133 #[serde(skip, default = "Instant::now")]
134 pub timestamp: Instant,
135 pub success: bool,
136 pub response_time_ms: f64,
137 pub error: Option<String>,
138}
139
140pub struct HealthMonitor<T: PooledConnection> {
142 config: HealthCheckConfig,
143 health_records: Arc<RwLock<HashMap<String, ConnectionHealthRecord>>>,
145 event_sender: broadcast::Sender<HealthEvent>,
147 shutdown_signal: Arc<RwLock<bool>>,
149 _phantom: std::marker::PhantomData<T>,
151}
152
153#[derive(Debug, Clone)]
155pub enum HealthEvent {
156 StatusChanged {
158 connection_id: String,
159 old_status: HealthStatus,
160 new_status: HealthStatus,
161 },
162 ConnectionDead {
164 connection_id: String,
165 reason: String,
166 },
167 ConnectionRecovered { connection_id: String },
169 HealthCheckFailed {
171 connection_id: String,
172 error: String,
173 },
174}
175
176impl<T: PooledConnection> HealthMonitor<T> {
177 pub fn new(config: HealthCheckConfig) -> Self {
179 let (event_sender, _) = broadcast::channel(1000);
180
181 Self {
182 config,
183 health_records: Arc::new(RwLock::new(HashMap::new())),
184 event_sender,
185 shutdown_signal: Arc::new(RwLock::new(false)),
186 _phantom: std::marker::PhantomData,
187 }
188 }
189
190 pub async fn register_connection(
192 &self,
193 connection_id: String,
194 metadata: HashMap<String, String>,
195 ) {
196 let mut records = self.health_records.write().await;
197
198 let record = ConnectionHealthRecord {
199 connection_id: connection_id.clone(),
200 status: HealthStatus::Unknown,
201 statistics: HealthStatistics::default(),
202 metadata,
203 history: Vec::with_capacity(100),
204 };
205
206 records.insert(connection_id.clone(), record);
207 info!(
208 "Registered connection {} for health monitoring",
209 connection_id
210 );
211 }
212
213 pub async fn unregister_connection(&self, connection_id: &str) {
215 let mut records = self.health_records.write().await;
216 if records.remove(connection_id).is_some() {
217 info!(
218 "Unregistered connection {} from health monitoring",
219 connection_id
220 );
221 }
222 }
223
224 pub async fn check_connection_health(
226 &self,
227 connection_id: &str,
228 connection: &T,
229 ) -> Result<HealthStatus> {
230 let start_time = Instant::now();
231 let mut attempts = 0;
232 let mut last_error = None;
233
234 while attempts < self.config.retry_attempts {
236 attempts += 1;
237
238 match tokio::time::timeout(self.config.check_timeout, connection.is_healthy()).await {
239 Ok(true) => {
240 let response_time = start_time.elapsed();
241 self.record_health_check_result(connection_id, true, response_time, None)
242 .await?;
243
244 return Ok(self.determine_health_status(connection_id).await);
245 }
246 Ok(false) => {
247 last_error = Some("Health check returned false".to_string());
248 }
249 Err(_) => {
250 last_error = Some("Health check timed out".to_string());
251 }
252 }
253
254 if attempts < self.config.retry_attempts {
255 tokio::time::sleep(self.config.retry_delay).await;
256 }
257 }
258
259 let response_time = start_time.elapsed();
261 self.record_health_check_result(connection_id, false, response_time, last_error.clone())
262 .await?;
263
264 let status = self.determine_health_status(connection_id).await;
265
266 if let Some(error) = last_error {
267 let _ = self.event_sender.send(HealthEvent::HealthCheckFailed {
268 connection_id: connection_id.to_string(),
269 error,
270 });
271 }
272
273 Ok(status)
274 }
275
276 async fn record_health_check_result(
278 &self,
279 connection_id: &str,
280 success: bool,
281 response_time: Duration,
282 error: Option<String>,
283 ) -> Result<()> {
284 let mut records = self.health_records.write().await;
285
286 if let Some(record) = records.get_mut(connection_id) {
287 let response_time_ms = response_time.as_millis() as f64;
288 let stats = &mut record.statistics;
289
290 stats.total_checks += 1;
292 stats.last_check = Some(Instant::now());
293
294 if success {
295 stats.successful_checks += 1;
296 stats.consecutive_successes += 1;
297 stats.consecutive_failures = 0;
298 stats.last_success = Some(Instant::now());
299 } else {
300 stats.failed_checks += 1;
301 stats.consecutive_failures += 1;
302 stats.consecutive_successes = 0;
303 stats.last_failure = Some(Instant::now());
304
305 if let Some(ref err) = error {
306 *stats.error_counts.entry(err.clone()).or_insert(0) += 1;
307 }
308 }
309
310 stats.min_response_time_ms = stats.min_response_time_ms.min(response_time_ms);
312 stats.max_response_time_ms = stats.max_response_time_ms.max(response_time_ms);
313
314 let alpha = 0.1;
316 if stats.total_checks == 1 {
317 stats.avg_response_time_ms = response_time_ms;
318 } else {
319 stats.avg_response_time_ms =
320 alpha * response_time_ms + (1.0 - alpha) * stats.avg_response_time_ms;
321 }
322
323 let result = HealthCheckResult {
325 timestamp: Instant::now(),
326 success,
327 response_time_ms,
328 error,
329 };
330
331 record.history.push(result);
332 if record.history.len() > 100 {
333 record.history.remove(0);
334 }
335 }
336
337 Ok(())
338 }
339
340 async fn determine_health_status(&self, connection_id: &str) -> HealthStatus {
342 let records = self.health_records.read().await;
343
344 if let Some(record) = records.get(connection_id) {
345 let stats = &record.statistics;
346 let old_status = record.status.clone();
347 let consecutive_failures = stats.consecutive_failures; let new_status = if stats.consecutive_failures >= self.config.failure_threshold * 2 {
350 HealthStatus::Dead
351 } else if stats.consecutive_failures >= self.config.failure_threshold {
352 HealthStatus::Unhealthy
353 } else if stats.consecutive_successes >= self.config.recovery_threshold {
354 HealthStatus::Healthy
355 } else if stats.consecutive_failures > 0 {
356 HealthStatus::Degraded
357 } else {
358 HealthStatus::Unknown
359 };
360
361 if old_status != new_status {
363 drop(records); let _ = self.event_sender.send(HealthEvent::StatusChanged {
366 connection_id: connection_id.to_string(),
367 old_status: old_status.clone(), new_status: new_status.clone(),
369 });
370
371 match new_status {
372 HealthStatus::Dead => {
373 let _ = self.event_sender.send(HealthEvent::ConnectionDead {
374 connection_id: connection_id.to_string(),
375 reason: format!("{consecutive_failures} consecutive failures"), });
377 }
378 HealthStatus::Healthy if old_status == HealthStatus::Unhealthy => {
379 let _ = self.event_sender.send(HealthEvent::ConnectionRecovered {
380 connection_id: connection_id.to_string(),
381 });
382 }
383 _ => {}
384 }
385
386 let mut records = self.health_records.write().await;
388 if let Some(record) = records.get_mut(connection_id) {
389 record.status = new_status.clone();
390 }
391 }
392
393 new_status
394 } else {
395 HealthStatus::Unknown
396 }
397 }
398
399 pub async fn start_monitoring(&self, connections: Arc<RwLock<HashMap<String, T>>>) {
401 let health_records = self.health_records.clone();
402 let config = self.config.clone();
403 let shutdown_signal = self.shutdown_signal.clone();
404 let event_sender = self.event_sender.clone();
405
406 tokio::spawn(async move {
407 let mut check_interval = interval(config.check_interval);
408
409 loop {
410 check_interval.tick().await;
411
412 if *shutdown_signal.read().await {
414 info!("Health monitor shutting down");
415 break;
416 }
417
418 let connections_guard = connections.read().await;
420 let connection_ids: Vec<String> = connections_guard.keys().cloned().collect();
421 drop(connections_guard);
422
423 for conn_id in connection_ids {
424 let start_time = Instant::now();
425
426 let health_check_result = {
428 let connection_guard = connections.read().await;
429 let connection = match connection_guard.get(&conn_id) {
430 Some(conn) => conn,
431 None => continue, };
433
434 tokio::time::timeout(config.check_timeout, connection.is_healthy()).await
436 };
437
438 match health_check_result {
439 Ok(healthy) => {
440 let response_time = start_time.elapsed();
441 let response_time_ms = response_time.as_millis() as f64;
442
443 let mut records = health_records.write().await;
445 if let Some(record) = records.get_mut(&conn_id) {
446 let stats = &mut record.statistics;
447 stats.total_checks += 1;
448 stats.last_check = Some(Instant::now());
449
450 if healthy {
451 stats.successful_checks += 1;
452 stats.consecutive_successes += 1;
453 stats.consecutive_failures = 0;
454 stats.last_success = Some(Instant::now());
455
456 debug!(
457 "Connection {} health check passed in {:.2}ms",
458 conn_id, response_time_ms
459 );
460 } else {
461 stats.failed_checks += 1;
462 stats.consecutive_failures += 1;
463 stats.consecutive_successes = 0;
464 stats.last_failure = Some(Instant::now());
465
466 warn!("Connection {} health check failed", conn_id);
467 }
468
469 let old_status = record.status.clone();
471 let new_status = if stats.consecutive_failures
472 >= config.failure_threshold * 2
473 {
474 HealthStatus::Dead
475 } else if stats.consecutive_failures >= config.failure_threshold {
476 HealthStatus::Unhealthy
477 } else if stats.consecutive_successes >= config.recovery_threshold {
478 HealthStatus::Healthy
479 } else {
480 old_status.clone()
481 };
482
483 if old_status != new_status {
484 record.status = new_status.clone();
485 let _ = event_sender.send(HealthEvent::StatusChanged {
486 connection_id: conn_id.clone(),
487 old_status,
488 new_status,
489 });
490 }
491 }
492 }
493 Err(_) => {
494 error!("Health check timeout for connection {}", conn_id);
495
496 let mut records = health_records.write().await;
497 if let Some(record) = records.get_mut(&conn_id) {
498 record.statistics.failed_checks += 1;
499 record.statistics.consecutive_failures += 1;
500 record.statistics.consecutive_successes = 0;
501 *record
502 .statistics
503 .error_counts
504 .entry("timeout".to_string())
505 .or_insert(0) += 1;
506 }
507 }
508 }
509 }
510 }
511 });
512 }
513
514 pub async fn stop_monitoring(&self) {
516 *self.shutdown_signal.write().await = true;
517 }
518
519 pub async fn get_connection_health(
521 &self,
522 connection_id: &str,
523 ) -> Option<ConnectionHealthRecord> {
524 self.health_records.read().await.get(connection_id).cloned()
525 }
526
527 pub async fn get_unhealthy_connections(&self) -> Vec<String> {
529 self.health_records
530 .read()
531 .await
532 .iter()
533 .filter(|(_, record)| {
534 matches!(
535 record.status,
536 HealthStatus::Unhealthy | HealthStatus::Dead | HealthStatus::Degraded
537 )
538 })
539 .map(|(id, _)| id.clone())
540 .collect()
541 }
542
543 pub async fn get_dead_connections(&self) -> Vec<String> {
545 self.health_records
546 .read()
547 .await
548 .iter()
549 .filter(|(_, record)| record.status == HealthStatus::Dead)
550 .map(|(id, _)| id.clone())
551 .collect()
552 }
553
554 pub async fn get_overall_statistics(&self) -> OverallHealthStatistics {
556 let records = self.health_records.read().await;
557
558 let total_connections = records.len();
559 let healthy_connections = records
560 .values()
561 .filter(|r| r.status == HealthStatus::Healthy)
562 .count();
563 let degraded_connections = records
564 .values()
565 .filter(|r| r.status == HealthStatus::Degraded)
566 .count();
567 let unhealthy_connections = records
568 .values()
569 .filter(|r| r.status == HealthStatus::Unhealthy)
570 .count();
571 let dead_connections = records
572 .values()
573 .filter(|r| r.status == HealthStatus::Dead)
574 .count();
575
576 let total_checks: u64 = records.values().map(|r| r.statistics.total_checks).sum();
577 let successful_checks: u64 = records
578 .values()
579 .map(|r| r.statistics.successful_checks)
580 .sum();
581 let failed_checks: u64 = records.values().map(|r| r.statistics.failed_checks).sum();
582
583 let avg_response_time_ms = if total_connections > 0 {
584 records
585 .values()
586 .map(|r| r.statistics.avg_response_time_ms)
587 .sum::<f64>()
588 / total_connections as f64
589 } else {
590 0.0
591 };
592
593 OverallHealthStatistics {
594 total_connections,
595 healthy_connections,
596 degraded_connections,
597 unhealthy_connections,
598 dead_connections,
599 total_checks,
600 successful_checks,
601 failed_checks,
602 success_rate: if total_checks > 0 {
603 (successful_checks as f64 / total_checks as f64) * 100.0
604 } else {
605 0.0
606 },
607 avg_response_time_ms,
608 }
609 }
610
611 pub fn subscribe(&self) -> broadcast::Receiver<HealthEvent> {
613 self.event_sender.subscribe()
614 }
615}
616
617#[derive(Debug, Clone, Serialize, Deserialize)]
619pub struct OverallHealthStatistics {
620 pub total_connections: usize,
621 pub healthy_connections: usize,
622 pub degraded_connections: usize,
623 pub unhealthy_connections: usize,
624 pub dead_connections: usize,
625 pub total_checks: u64,
626 pub successful_checks: u64,
627 pub failed_checks: u64,
628 pub success_rate: f64,
629 pub avg_response_time_ms: f64,
630}
631
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicBool, Ordering};

    /// Connection double whose `is_healthy` result is driven by a shared flag.
    #[derive(Clone)]
    struct TestConnection {
        healthy: Arc<AtomicBool>,
    }

    #[async_trait::async_trait]
    impl PooledConnection for TestConnection {
        async fn is_healthy(&self) -> bool {
            self.healthy.load(Ordering::Relaxed)
        }

        async fn close(&mut self) -> Result<()> {
            Ok(())
        }

        fn clone_connection(&self) -> Box<dyn PooledConnection> {
            Box::new(TestConnection {
                healthy: Arc::new(AtomicBool::new(self.healthy.load(Ordering::Relaxed))),
            })
        }

        fn created_at(&self) -> Instant {
            Instant::now()
        }

        fn last_activity(&self) -> Instant {
            Instant::now()
        }

        fn update_activity(&mut self) {}
    }

    /// Default configuration without a pause between retries. Otherwise every
    /// failing check would sleep `retry_delay` (500 ms by default) and slow
    /// the suite down.
    fn fast_config() -> HealthCheckConfig {
        HealthCheckConfig {
            retry_delay: Duration::ZERO,
            ..Default::default()
        }
    }

    /// Runs `times` active health checks of `connection` under `connection_id`.
    async fn run_checks(
        monitor: &HealthMonitor<TestConnection>,
        connection_id: &str,
        connection: &TestConnection,
        times: usize,
    ) {
        for _ in 0..times {
            monitor
                .check_connection_health(connection_id, connection)
                .await
                .unwrap();
        }
    }

    #[tokio::test]
    async fn test_health_monitoring() {
        let monitor = HealthMonitor::<TestConnection>::new(fast_config());
        monitor
            .register_connection("test-conn-1".to_string(), HashMap::new())
            .await;

        let healthy_flag = Arc::new(AtomicBool::new(true));
        let connection = TestConnection {
            healthy: healthy_flag.clone(),
        };

        // A single success is below the default recovery threshold of 2.
        let status = monitor
            .check_connection_health("test-conn-1", &connection)
            .await
            .unwrap();
        assert_eq!(status, HealthStatus::Unknown);

        run_checks(&monitor, "test-conn-1", &connection, 3).await;

        let health = monitor.get_connection_health("test-conn-1").await.unwrap();
        assert_eq!(health.status, HealthStatus::Healthy);
        assert_eq!(health.statistics.consecutive_successes, 4);

        healthy_flag.store(false, Ordering::Relaxed);
        run_checks(&monitor, "test-conn-1", &connection, 3).await;

        let health = monitor.get_connection_health("test-conn-1").await.unwrap();
        assert_eq!(health.status, HealthStatus::Unhealthy);
        assert_eq!(health.statistics.consecutive_failures, 3);

        let unhealthy = monitor.get_unhealthy_connections().await;
        assert!(unhealthy.contains(&"test-conn-1".to_string()));
    }

    #[tokio::test]
    async fn test_dead_connection_detection() {
        let config = HealthCheckConfig {
            failure_threshold: 2,
            ..fast_config()
        };
        let monitor = HealthMonitor::<TestConnection>::new(config);
        monitor
            .register_connection("test-conn-1".to_string(), HashMap::new())
            .await;

        let connection = TestConnection {
            healthy: Arc::new(AtomicBool::new(false)),
        };
        run_checks(&monitor, "test-conn-1", &connection, 5).await;

        let health = monitor.get_connection_health("test-conn-1").await.unwrap();
        assert_eq!(health.status, HealthStatus::Dead);

        let dead = monitor.get_dead_connections().await;
        assert!(dead.contains(&"test-conn-1".to_string()));
    }

    #[tokio::test]
    async fn test_health_events() {
        let monitor = HealthMonitor::<TestConnection>::new(fast_config());
        let mut event_receiver = monitor.subscribe();

        monitor
            .register_connection("test-conn-1".to_string(), HashMap::new())
            .await;

        let healthy_flag = Arc::new(AtomicBool::new(true));
        let connection = TestConnection {
            healthy: healthy_flag.clone(),
        };

        run_checks(&monitor, "test-conn-1", &connection, 3).await;
        healthy_flag.store(false, Ordering::Relaxed);
        run_checks(&monitor, "test-conn-1", &connection, 3).await;

        // Keep reading past unrelated events and lag notices. A `Lagged`
        // error must not end the wait and let the test pass without ever
        // seeing a status change.
        let saw_status_change = tokio::time::timeout(Duration::from_secs(1), async {
            loop {
                match event_receiver.recv().await {
                    Ok(HealthEvent::StatusChanged { .. }) => return true,
                    Ok(_) | Err(broadcast::error::RecvError::Lagged(_)) => continue,
                    Err(broadcast::error::RecvError::Closed) => return false,
                }
            }
        })
        .await
        .expect("Should receive status change event");
        assert!(saw_status_change, "Should receive status change event");
    }
}