Skip to main content

simple_agents_router/
health.rs

1//! Provider health tracking for routing decisions.
2//!
3//! Maintains per-provider metrics and health state.
4
5use simple_agent_type::prelude::{ProviderHealth, ProviderMetrics};
6use std::sync::Mutex;
7use std::time::Duration;
8
/// Tunable parameters that drive provider health classification.
///
/// The two thresholds are compared against a provider's failure rate with
/// `>=`, so a rate exactly equal to a threshold already triggers it.
#[derive(Debug, Clone, Copy)]
pub struct HealthTrackerConfig {
    /// Failure rate at or above which a provider is reported as degraded.
    pub degrade_threshold: f32,
    /// Failure rate at or above which a provider is reported as unavailable.
    ///
    /// Checked before `degrade_threshold`, so it wins when both are reached.
    pub unavailable_threshold: f32,
    /// Weight given to the newest latency sample in the exponential moving
    /// average; the previous average receives `1.0 - latency_alpha`.
    pub latency_alpha: f64,
}
19
20impl Default for HealthTrackerConfig {
21    fn default() -> Self {
22        Self {
23            degrade_threshold: 0.2,
24            unavailable_threshold: 0.5,
25            latency_alpha: 0.2,
26        }
27    }
28}
29
30/// Health tracker for providers.
31#[derive(Debug)]
32pub struct HealthTracker {
33    metrics: Mutex<Vec<ProviderMetrics>>,
34    config: HealthTrackerConfig,
35}
36
37impl HealthTracker {
38    /// Create a tracker for the given number of providers.
39    pub fn new(provider_count: usize, config: HealthTrackerConfig) -> Self {
40        let metrics = vec![ProviderMetrics::default(); provider_count];
41        Self {
42            metrics: Mutex::new(metrics),
43            config,
44        }
45    }
46
47    /// Record a successful request.
48    pub fn record_success(&self, provider_index: usize, latency: Duration) {
49        let mut metrics = self.metrics.lock().expect("health tracker lock poisoned");
50        if let Some(entry) = metrics.get_mut(provider_index) {
51            entry.total_requests = entry.total_requests.saturating_add(1);
52            entry.successful_requests = entry.successful_requests.saturating_add(1);
53            entry.avg_latency =
54                update_latency(entry.avg_latency, latency, self.config.latency_alpha);
55            entry.health = compute_health_with_config(entry, self.config);
56        }
57    }
58
59    /// Record a failed request.
60    pub fn record_failure(&self, provider_index: usize, latency: Option<Duration>) {
61        let mut metrics = self.metrics.lock().expect("health tracker lock poisoned");
62        if let Some(entry) = metrics.get_mut(provider_index) {
63            entry.total_requests = entry.total_requests.saturating_add(1);
64            entry.failed_requests = entry.failed_requests.saturating_add(1);
65            if let Some(value) = latency {
66                entry.avg_latency =
67                    update_latency(entry.avg_latency, value, self.config.latency_alpha);
68            }
69            entry.health = compute_health_with_config(entry, self.config);
70        }
71    }
72
73    /// Get metrics for a provider.
74    pub fn metrics(&self, provider_index: usize) -> Option<ProviderMetrics> {
75        let metrics = self.metrics.lock().expect("health tracker lock poisoned");
76        metrics.get(provider_index).copied()
77    }
78
79    /// Get health for a provider.
80    pub fn health(&self, provider_index: usize) -> Option<ProviderHealth> {
81        self.metrics(provider_index).map(|entry| entry.health)
82    }
83}
84
/// Fold `new_value` into the running latency average `current` using an
/// exponential moving average: `alpha * new + (1 - alpha) * current`.
///
/// A `current` of exactly zero is treated as "no sample yet" and the new
/// value is returned unchanged, seeding the average.
///
/// The result keeps nanosecond precision. Truncating to whole milliseconds
/// on every update makes the average get stuck whenever
/// `alpha * (new - current)` is below one millisecond (e.g. 100 ms never
/// moves towards repeated 104 ms samples with `alpha = 0.2`).
///
/// A negative or NaN intermediate result (possible only when `alpha` is
/// outside `[0, 1]` or NaN) yields a zero duration instead of panicking.
fn update_latency(current: Duration, new_value: Duration, alpha: f64) -> Duration {
    if current.is_zero() {
        return new_value;
    }
    let current_ns = current.as_nanos() as f64;
    let new_ns = new_value.as_nanos() as f64;
    let ema_ns = (alpha * new_ns) + ((1.0 - alpha) * current_ns);
    // Float-to-int `as` saturates and maps NaN to 0, so this cast cannot
    // panic or wrap; `max(0.0)` also turns NaN into 0.0 before rounding.
    Duration::from_nanos(ema_ns.max(0.0).round() as u64)
}
94
95fn compute_health_with_config(
96    metrics: &ProviderMetrics,
97    config: HealthTrackerConfig,
98) -> ProviderHealth {
99    let failure_rate = metrics.failure_rate();
100    if failure_rate >= config.unavailable_threshold {
101        ProviderHealth::Unavailable
102    } else if failure_rate >= config.degrade_threshold {
103        ProviderHealth::Degraded
104    } else {
105        ProviderHealth::Healthy
106    }
107}
108
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a tracker that manages exactly one provider (index 0).
    fn single_provider_tracker(config: HealthTrackerConfig) -> HealthTracker {
        HealthTracker::new(1, config)
    }

    /// A single success bumps the total and success counters only and
    /// leaves the provider healthy.
    #[test]
    fn success_updates_metrics() {
        let tracker = single_provider_tracker(HealthTrackerConfig::default());

        tracker.record_success(0, Duration::from_millis(100));

        let snapshot = tracker.metrics(0).unwrap();
        assert_eq!(snapshot.total_requests, 1);
        assert_eq!(snapshot.successful_requests, 1);
        assert_eq!(snapshot.failed_requests, 0);
        assert_eq!(snapshot.health, ProviderHealth::Healthy);
    }

    /// Three failures out of four requests exceed the unavailable threshold.
    #[test]
    fn failures_degrade_health() {
        let tracker = single_provider_tracker(HealthTrackerConfig {
            degrade_threshold: 0.2,
            unavailable_threshold: 0.5,
            latency_alpha: 0.2,
        });
        let latency = Duration::from_millis(50);

        // Request outcomes in order: fail, fail, succeed, fail.
        for &succeeded in &[false, false, true, false] {
            if succeeded {
                tracker.record_success(0, latency);
            } else {
                tracker.record_failure(0, Some(latency));
            }
        }

        let snapshot = tracker.metrics(0).unwrap();
        assert_eq!(snapshot.total_requests, 4);
        assert_eq!(snapshot.failed_requests, 3);
        assert_eq!(snapshot.health, ProviderHealth::Unavailable);
    }

    /// Lookups past the last provider index yield `None` rather than panic.
    #[test]
    fn metrics_out_of_range_returns_none() {
        let tracker = single_provider_tracker(HealthTrackerConfig::default());

        assert!(tracker.metrics(5).is_none());
        assert!(tracker.health(2).is_none());
    }
}