oxify_connect_llm/
health_check.rs

//! Health checking and monitoring for LLM providers.
//!
//! Tracks each provider's recent success/failure rate over a sliding window and
//! classifies it as healthy, degraded, or unhealthy. The wrapper is informational
//! only: every request is still forwarded to the inner provider, so callers that
//! want to prevent cascading failures should consult
//! [`HealthCheckProvider::get_status`] and route around unhealthy providers.
//!
//! # Example
//!
//! ```rust,no_run
//! use oxify_connect_llm::{HealthCheckProvider, HealthCheckConfig, LlmProvider};
//!
//! # async fn example() -> Result<(), Box<dyn std::error::Error>> {
//! # let provider: Box<dyn LlmProvider> = todo!();
//! let config = HealthCheckConfig::new()
//!     .with_failure_threshold(0.3) // Unhealthy at a 30% failure rate
//!     .with_check_window(100);     // Measured over the last 100 requests
//!
//! let health_checked = HealthCheckProvider::new(provider, config);
//! # Ok(())
//! # }
//! ```
21
22use crate::{LlmError, LlmProvider, LlmRequest, LlmResponse};
23use async_trait::async_trait;
24use std::collections::VecDeque;
25use std::sync::Arc;
26use std::time::{Duration, Instant};
27use tokio::sync::Mutex;
28
/// Tunable parameters for provider health tracking.
///
/// Build one with [`HealthCheckConfig::new`] (or `Default`) and refine it with
/// the chainable `with_*` methods.
#[derive(Debug, Clone)]
pub struct HealthCheckConfig {
    /// Failure rate (0.0-1.0) at or above which the provider is marked unhealthy
    pub failure_threshold: f64,
    /// Number of most recent requests kept for the failure-rate calculation
    pub check_window: usize,
    /// Minimum number of recorded requests before health evaluation is active
    pub min_requests: usize,
    /// How long to wait before re-enabling an unhealthy provider.
    ///
    /// NOTE(review): this value is stored but not read anywhere in this file.
    pub recovery_timeout: Duration,
}

impl Default for HealthCheckConfig {
    /// Same values as [`HealthCheckConfig::new`].
    fn default() -> Self {
        Self::new()
    }
}

impl HealthCheckConfig {
    /// Create a configuration with the defaults: 50% failure threshold,
    /// a 50-request window, 10 minimum requests and a 60 second recovery timeout.
    pub fn new() -> Self {
        Self {
            failure_threshold: 0.5,
            check_window: 50,
            min_requests: 10,
            recovery_timeout: Duration::from_secs(60),
        }
    }

    /// Set the failure threshold, clamped into `0.0..=1.0`.
    ///
    /// A NaN `threshold` is ignored and the previous value is kept: `f64::clamp`
    /// passes NaN through, and a NaN threshold would make every comparison false,
    /// silently preventing the provider from ever being marked unhealthy.
    pub fn with_failure_threshold(mut self, threshold: f64) -> Self {
        if !threshold.is_nan() {
            self.failure_threshold = threshold.clamp(0.0, 1.0);
        }
        self
    }

    /// Set the check window size (raised to 1 if `window` is 0).
    pub fn with_check_window(mut self, window: usize) -> Self {
        self.check_window = window.max(1);
        self
    }

    /// Set the minimum number of requests before health checking is active.
    pub fn with_min_requests(mut self, min: usize) -> Self {
        self.min_requests = min;
        self
    }

    /// Set the recovery timeout in whole seconds.
    pub fn with_recovery_timeout_secs(mut self, secs: u64) -> Self {
        self.recovery_timeout = Duration::from_secs(secs);
        self
    }
}
83
/// Coarse health classification derived from the recent failure rate.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum HealthStatus {
    /// Failure rate is comfortably below the configured threshold
    Healthy,
    /// Failure rate is approaching the configured threshold
    Degraded,
    /// Failure rate has reached or exceeded the configured threshold
    Unhealthy,
}
94
95/// Request outcome for health tracking
96#[derive(Debug, Clone, Copy)]
97enum RequestOutcome {
98    Success,
99    Failure,
100}
101
102/// Health check state
103#[derive(Debug)]
104struct HealthCheckState {
105    outcomes: VecDeque<RequestOutcome>,
106    status: HealthStatus,
107    last_failure_time: Option<Instant>,
108    total_requests: u64,
109    total_failures: u64,
110}
111
112impl HealthCheckState {
113    fn new() -> Self {
114        Self {
115            outcomes: VecDeque::new(),
116            status: HealthStatus::Healthy,
117            last_failure_time: None,
118            total_requests: 0,
119            total_failures: 0,
120        }
121    }
122
123    fn record_outcome(&mut self, outcome: RequestOutcome, config: &HealthCheckConfig) {
124        self.total_requests += 1;
125
126        if matches!(outcome, RequestOutcome::Failure) {
127            self.total_failures += 1;
128            self.last_failure_time = Some(Instant::now());
129        }
130
131        // Add to window
132        self.outcomes.push_back(outcome);
133
134        // Maintain window size
135        while self.outcomes.len() > config.check_window {
136            self.outcomes.pop_front();
137        }
138
139        // Update health status
140        self.update_status(config);
141    }
142
143    fn update_status(&mut self, config: &HealthCheckConfig) {
144        // Need minimum requests before health checking
145        if self.outcomes.len() < config.min_requests {
146            self.status = HealthStatus::Healthy;
147            return;
148        }
149
150        let failure_count = self
151            .outcomes
152            .iter()
153            .filter(|o| matches!(o, RequestOutcome::Failure))
154            .count();
155
156        let failure_rate = failure_count as f64 / self.outcomes.len() as f64;
157
158        if failure_rate >= config.failure_threshold {
159            self.status = HealthStatus::Unhealthy;
160        } else if failure_rate >= config.failure_threshold * 0.7 {
161            // 70% of threshold = degraded
162            self.status = HealthStatus::Degraded;
163        } else {
164            self.status = HealthStatus::Healthy;
165        }
166    }
167
168    fn get_stats(&self) -> (HealthStatus, f64, u64, u64) {
169        let failure_rate = if self.outcomes.is_empty() {
170            0.0
171        } else {
172            self.outcomes
173                .iter()
174                .filter(|o| matches!(o, RequestOutcome::Failure))
175                .count() as f64
176                / self.outcomes.len() as f64
177        };
178
179        (
180            self.status,
181            failure_rate,
182            self.total_requests,
183            self.total_failures,
184        )
185    }
186}
187
188/// Health check provider wrapper
189pub struct HealthCheckProvider {
190    provider: Box<dyn LlmProvider>,
191    state: Arc<Mutex<HealthCheckState>>,
192    config: HealthCheckConfig,
193}
194
195impl HealthCheckProvider {
196    /// Create a new health-checked provider
197    pub fn new(provider: Box<dyn LlmProvider>, config: HealthCheckConfig) -> Self {
198        Self {
199            provider,
200            state: Arc::new(Mutex::new(HealthCheckState::new())),
201            config,
202        }
203    }
204
205    /// Get current health status
206    pub async fn get_status(&self) -> HealthStatus {
207        self.state.lock().await.status
208    }
209
210    /// Get health statistics
211    pub async fn get_stats(&self) -> HealthStats {
212        let state = self.state.lock().await;
213        let (status, failure_rate, total_requests, total_failures) = state.get_stats();
214
215        HealthStats {
216            status,
217            failure_rate,
218            total_requests,
219            total_failures,
220            is_healthy: status != HealthStatus::Unhealthy,
221        }
222    }
223
224    /// Manually reset health state
225    pub async fn reset(&self) {
226        let mut state = self.state.lock().await;
227        *state = HealthCheckState::new();
228    }
229}
230
231/// Health statistics
232#[derive(Debug, Clone)]
233pub struct HealthStats {
234    /// Current health status
235    pub status: HealthStatus,
236    /// Current failure rate (0.0-1.0)
237    pub failure_rate: f64,
238    /// Total requests processed
239    pub total_requests: u64,
240    /// Total failures
241    pub total_failures: u64,
242    /// Whether the provider is currently healthy
243    pub is_healthy: bool,
244}
245
246#[async_trait]
247impl LlmProvider for HealthCheckProvider {
248    async fn complete(&self, request: LlmRequest) -> Result<LlmResponse, LlmError> {
249        // Make the request (health check is informational only, doesn't block)
250        let result = self.provider.complete(request).await;
251
252        // Record outcome for health monitoring
253        {
254            let mut state = self.state.lock().await;
255            let outcome = if result.is_ok() {
256                RequestOutcome::Success
257            } else {
258                RequestOutcome::Failure
259            };
260            state.record_outcome(outcome, &self.config);
261        }
262
263        result
264    }
265}
266
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Usage;
    use std::sync::atomic::{AtomicU32, Ordering};

    /// Provider stub that fails its first `fail_until` calls and succeeds afterwards.
    struct MockProvider {
        call_count: Arc<AtomicU32>,
        fail_until: u32,
    }

    impl MockProvider {
        /// Build a mock whose first `fail_until` calls return an error.
        fn failing_first(fail_until: u32) -> Self {
            Self {
                call_count: Arc::new(AtomicU32::new(0)),
                fail_until,
            }
        }
    }

    #[async_trait]
    impl LlmProvider for MockProvider {
        async fn complete(&self, _request: LlmRequest) -> Result<LlmResponse, LlmError> {
            let count = self.call_count.fetch_add(1, Ordering::SeqCst);

            if count < self.fail_until {
                Err(LlmError::ApiError("Simulated failure".to_string()))
            } else {
                Ok(LlmResponse {
                    content: "Success".to_string(),
                    model: "mock".to_string(),
                    usage: Some(Usage {
                        prompt_tokens: 10,
                        completion_tokens: 20,
                        total_tokens: 30,
                    }),
                    tool_calls: Vec::new(),
                })
            }
        }
    }

    /// Minimal request used by every test.
    fn test_request() -> LlmRequest {
        LlmRequest {
            prompt: "test".to_string(),
            system_prompt: None,
            temperature: None,
            max_tokens: None,
            tools: Vec::new(),
            images: Vec::new(),
        }
    }

    /// Configuration shared by the tests: 50% threshold, 20-request window,
    /// health evaluation active after 10 requests.
    fn test_config() -> HealthCheckConfig {
        HealthCheckConfig::new()
            .with_failure_threshold(0.5)
            .with_check_window(20)
            .with_min_requests(10)
    }

    /// Send `count` requests through `provider` and return how many succeeded.
    async fn send_requests(provider: &HealthCheckProvider, count: usize) -> usize {
        let mut succeeded = 0;
        for _ in 0..count {
            if provider.complete(test_request()).await.is_ok() {
                succeeded += 1;
            }
        }
        succeeded
    }

    #[tokio::test]
    async fn test_health_check_becomes_unhealthy() {
        let health_checked =
            HealthCheckProvider::new(Box::new(MockProvider::failing_first(20)), test_config());

        // Make 20 failing requests
        send_requests(&health_checked, 20).await;

        let status = health_checked.get_status().await;
        assert_eq!(status, HealthStatus::Unhealthy);

        let stats = health_checked.get_stats().await;
        assert!(!stats.is_healthy);
        assert!(stats.failure_rate > 0.9); // Should be 100% failures
    }

    #[tokio::test]
    async fn test_health_check_recovers() {
        let config = test_config().with_recovery_timeout_secs(1);
        let health_checked =
            HealthCheckProvider::new(Box::new(MockProvider::failing_first(15)), config);

        // Make 15 failing requests
        send_requests(&health_checked, 15).await;
        assert_eq!(health_checked.get_status().await, HealthStatus::Unhealthy);

        // The wrapper never blocks requests, so recovery depends only on the
        // window contents; no need to sleep for the recovery timeout.
        // Make 15 successful requests to push the failure rate back down.
        let succeeded = send_requests(&health_checked, 15).await;
        assert_eq!(succeeded, 15);

        // Should be healthy again after successful requests
        let status = health_checked.get_status().await;
        assert_eq!(status, HealthStatus::Healthy);

        let stats = health_checked.get_stats().await;
        assert!(stats.is_healthy);
        // Window is 20, so we have last 5 failures + 15 successes = 25% failure rate
        assert!(stats.failure_rate < 0.5); // Should be below threshold now
    }

    #[tokio::test]
    async fn test_health_check_config() {
        let config = HealthCheckConfig::new()
            .with_failure_threshold(0.3)
            .with_check_window(100)
            .with_min_requests(20)
            .with_recovery_timeout_secs(120);

        assert_eq!(config.failure_threshold, 0.3);
        assert_eq!(config.check_window, 100);
        assert_eq!(config.min_requests, 20);
        assert_eq!(config.recovery_timeout, Duration::from_secs(120));
    }

    #[tokio::test]
    async fn test_health_check_degraded_status() {
        // Fail 6 out of 20 = 30% (degraded starts at 70% of 50% = 35%)
        let health_checked =
            HealthCheckProvider::new(Box::new(MockProvider::failing_first(6)), test_config());

        send_requests(&health_checked, 20).await;

        let status = health_checked.get_status().await;
        // 30% failure rate should be below degraded threshold (35%)
        assert_eq!(status, HealthStatus::Healthy);

        let stats = health_checked.get_stats().await;
        assert!(stats.failure_rate < 0.35);
    }

    #[tokio::test]
    async fn test_health_check_reports_degraded() {
        // Fail 8 out of 20 = 40%: at or above the degraded mark (35%) but
        // below the unhealthy threshold (50%).
        let health_checked =
            HealthCheckProvider::new(Box::new(MockProvider::failing_first(8)), test_config());

        send_requests(&health_checked, 20).await;

        assert_eq!(health_checked.get_status().await, HealthStatus::Degraded);

        let stats = health_checked.get_stats().await;
        assert!(stats.is_healthy); // Degraded still counts as healthy
        assert!(stats.failure_rate >= 0.35);
        assert!(stats.failure_rate < 0.5);
    }

    #[tokio::test]
    async fn test_health_check_reset() {
        let health_checked =
            HealthCheckProvider::new(Box::new(MockProvider::failing_first(20)), test_config());

        // Make failing requests
        send_requests(&health_checked, 20).await;
        assert_eq!(health_checked.get_status().await, HealthStatus::Unhealthy);

        // Manual reset
        health_checked.reset().await;

        assert_eq!(health_checked.get_status().await, HealthStatus::Healthy);
    }
}