1use crate::{LlmError, LlmProvider, LlmRequest, LlmResponse};
23use async_trait::async_trait;
24use std::collections::VecDeque;
25use std::sync::Arc;
26use std::time::{Duration, Instant};
27use tokio::sync::Mutex;
28
/// Tuning parameters for [`HealthCheckProvider`].
///
/// The `with_*` builder methods sanitize their inputs; assigning the public
/// fields directly bypasses that sanitization.
#[derive(Debug, Clone)]
pub struct HealthCheckConfig {
    /// Failure rate over the check window at or above which the status is
    /// `Unhealthy`. Default: `0.5`.
    pub failure_threshold: f64,
    /// Number of most recent request outcomes the failure rate is computed
    /// over. Default: `50`.
    pub check_window: usize,
    /// Minimum sample size before the failure rate is evaluated; until then
    /// the status stays `Healthy`. Default: `10`.
    pub min_requests: usize,
    /// Default: 60 seconds.
    ///
    /// NOTE(review): not read by any code in this module; the status currently
    /// recovers only as failures age out of `check_window`.
    pub recovery_timeout: Duration,
}

impl Default for HealthCheckConfig {
    fn default() -> Self {
        Self::new()
    }
}

impl HealthCheckConfig {
    /// Creates a config with the defaults: threshold `0.5`, window `50`,
    /// `10` minimum requests and a `60`s recovery timeout.
    pub fn new() -> Self {
        Self {
            failure_threshold: 0.5,
            check_window: 50,
            min_requests: 10,
            recovery_timeout: Duration::from_secs(60),
        }
    }

    /// Sets `failure_threshold`, clamped to `[0.0, 1.0]`.
    ///
    /// A NaN `threshold` is ignored and the current value is kept. NaN compares
    /// false against every failure rate, so storing it would stop the status
    /// from ever leaving `Healthy`.
    pub fn with_failure_threshold(mut self, threshold: f64) -> Self {
        if !threshold.is_nan() {
            self.failure_threshold = threshold.clamp(0.0, 1.0);
        }
        self
    }

    /// Sets `check_window`; values below `1` are raised to `1`.
    pub fn with_check_window(mut self, window: usize) -> Self {
        self.check_window = window.max(1);
        self
    }

    /// Sets `min_requests` (no clamping is applied).
    pub fn with_min_requests(mut self, min: usize) -> Self {
        self.min_requests = min;
        self
    }

    /// Sets `recovery_timeout` in whole seconds.
    pub fn with_recovery_timeout_secs(mut self, secs: u64) -> Self {
        self.recovery_timeout = Duration::from_secs(secs);
        self
    }
}
83
84#[derive(Debug, Clone, Copy, PartialEq, Eq)]
86pub enum HealthStatus {
87 Healthy,
89 Degraded,
91 Unhealthy,
93}
94
95#[derive(Debug, Clone, Copy)]
97enum RequestOutcome {
98 Success,
99 Failure,
100}
101
102#[derive(Debug)]
104struct HealthCheckState {
105 outcomes: VecDeque<RequestOutcome>,
106 status: HealthStatus,
107 last_failure_time: Option<Instant>,
108 total_requests: u64,
109 total_failures: u64,
110}
111
112impl HealthCheckState {
113 fn new() -> Self {
114 Self {
115 outcomes: VecDeque::new(),
116 status: HealthStatus::Healthy,
117 last_failure_time: None,
118 total_requests: 0,
119 total_failures: 0,
120 }
121 }
122
123 fn record_outcome(&mut self, outcome: RequestOutcome, config: &HealthCheckConfig) {
124 self.total_requests += 1;
125
126 if matches!(outcome, RequestOutcome::Failure) {
127 self.total_failures += 1;
128 self.last_failure_time = Some(Instant::now());
129 }
130
131 self.outcomes.push_back(outcome);
133
134 while self.outcomes.len() > config.check_window {
136 self.outcomes.pop_front();
137 }
138
139 self.update_status(config);
141 }
142
143 fn update_status(&mut self, config: &HealthCheckConfig) {
144 if self.outcomes.len() < config.min_requests {
146 self.status = HealthStatus::Healthy;
147 return;
148 }
149
150 let failure_count = self
151 .outcomes
152 .iter()
153 .filter(|o| matches!(o, RequestOutcome::Failure))
154 .count();
155
156 let failure_rate = failure_count as f64 / self.outcomes.len() as f64;
157
158 if failure_rate >= config.failure_threshold {
159 self.status = HealthStatus::Unhealthy;
160 } else if failure_rate >= config.failure_threshold * 0.7 {
161 self.status = HealthStatus::Degraded;
163 } else {
164 self.status = HealthStatus::Healthy;
165 }
166 }
167
168 fn get_stats(&self) -> (HealthStatus, f64, u64, u64) {
169 let failure_rate = if self.outcomes.is_empty() {
170 0.0
171 } else {
172 self.outcomes
173 .iter()
174 .filter(|o| matches!(o, RequestOutcome::Failure))
175 .count() as f64
176 / self.outcomes.len() as f64
177 };
178
179 (
180 self.status,
181 failure_rate,
182 self.total_requests,
183 self.total_failures,
184 )
185 }
186}
187
188pub struct HealthCheckProvider {
190 provider: Box<dyn LlmProvider>,
191 state: Arc<Mutex<HealthCheckState>>,
192 config: HealthCheckConfig,
193}
194
195impl HealthCheckProvider {
196 pub fn new(provider: Box<dyn LlmProvider>, config: HealthCheckConfig) -> Self {
198 Self {
199 provider,
200 state: Arc::new(Mutex::new(HealthCheckState::new())),
201 config,
202 }
203 }
204
205 pub async fn get_status(&self) -> HealthStatus {
207 self.state.lock().await.status
208 }
209
210 pub async fn get_stats(&self) -> HealthStats {
212 let state = self.state.lock().await;
213 let (status, failure_rate, total_requests, total_failures) = state.get_stats();
214
215 HealthStats {
216 status,
217 failure_rate,
218 total_requests,
219 total_failures,
220 is_healthy: status != HealthStatus::Unhealthy,
221 }
222 }
223
224 pub async fn reset(&self) {
226 let mut state = self.state.lock().await;
227 *state = HealthCheckState::new();
228 }
229}
230
231#[derive(Debug, Clone)]
233pub struct HealthStats {
234 pub status: HealthStatus,
236 pub failure_rate: f64,
238 pub total_requests: u64,
240 pub total_failures: u64,
242 pub is_healthy: bool,
244}
245
246#[async_trait]
247impl LlmProvider for HealthCheckProvider {
248 async fn complete(&self, request: LlmRequest) -> Result<LlmResponse, LlmError> {
249 let result = self.provider.complete(request).await;
251
252 {
254 let mut state = self.state.lock().await;
255 let outcome = if result.is_ok() {
256 RequestOutcome::Success
257 } else {
258 RequestOutcome::Failure
259 };
260 state.record_outcome(outcome, &self.config);
261 }
262
263 result
264 }
265}
266
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Usage;
    use std::sync::atomic::{AtomicU32, Ordering};

    /// Provider that fails its first `fail_until` calls and succeeds afterwards.
    struct MockProvider {
        call_count: Arc<AtomicU32>,
        fail_until: u32,
    }

    impl MockProvider {
        fn new(fail_until: u32) -> Self {
            Self {
                call_count: Arc::new(AtomicU32::new(0)),
                fail_until,
            }
        }
    }

    #[async_trait]
    impl LlmProvider for MockProvider {
        async fn complete(&self, _request: LlmRequest) -> Result<LlmResponse, LlmError> {
            let count = self.call_count.fetch_add(1, Ordering::SeqCst);

            if count < self.fail_until {
                Err(LlmError::ApiError("Simulated failure".to_string()))
            } else {
                Ok(LlmResponse {
                    content: "Success".to_string(),
                    model: "mock".to_string(),
                    usage: Some(Usage {
                        prompt_tokens: 10,
                        completion_tokens: 20,
                        total_tokens: 30,
                    }),
                    tool_calls: Vec::new(),
                })
            }
        }
    }

    /// Minimal request; its contents are irrelevant to the mock.
    fn test_request() -> LlmRequest {
        LlmRequest {
            prompt: "test".to_string(),
            system_prompt: None,
            temperature: None,
            max_tokens: None,
            tools: Vec::new(),
            images: Vec::new(),
        }
    }

    /// Sends `count` requests through `provider`, ignoring the results.
    async fn send_requests(provider: &HealthCheckProvider, count: usize) {
        for _ in 0..count {
            let _ = provider.complete(test_request()).await;
        }
    }

    /// 50% threshold over a 20-request window, evaluated after 10 requests.
    fn standard_config() -> HealthCheckConfig {
        HealthCheckConfig::new()
            .with_failure_threshold(0.5)
            .with_check_window(20)
            .with_min_requests(10)
    }

    #[tokio::test]
    async fn test_health_check_becomes_unhealthy() {
        let health_checked =
            HealthCheckProvider::new(Box::new(MockProvider::new(20)), standard_config());

        send_requests(&health_checked, 20).await;

        assert_eq!(health_checked.get_status().await, HealthStatus::Unhealthy);

        let stats = health_checked.get_stats().await;
        assert!(!stats.is_healthy);
        assert!(stats.failure_rate > 0.9);
    }

    #[tokio::test]
    async fn test_health_check_recovers() {
        let config = standard_config().with_recovery_timeout_secs(1);
        let health_checked = HealthCheckProvider::new(Box::new(MockProvider::new(15)), config);

        send_requests(&health_checked, 15).await;
        assert_eq!(health_checked.get_status().await, HealthStatus::Unhealthy);

        // Recovery comes from the sliding window: the 15 successes below push
        // most earlier failures out (window = 5 failures + 15 successes = 25%).
        // `recovery_timeout` is not consulted by the status logic, so no sleep
        // is needed here.
        for _ in 0..15 {
            let result = health_checked.complete(test_request()).await;
            assert!(result.is_ok());
        }

        assert_eq!(health_checked.get_status().await, HealthStatus::Healthy);

        let stats = health_checked.get_stats().await;
        assert!(stats.is_healthy);
        assert!(stats.failure_rate < 0.5);
    }

    #[test]
    fn test_health_check_config() {
        let config = HealthCheckConfig::new()
            .with_failure_threshold(0.3)
            .with_check_window(100)
            .with_min_requests(20)
            .with_recovery_timeout_secs(120);

        assert_eq!(config.failure_threshold, 0.3);
        assert_eq!(config.check_window, 100);
        assert_eq!(config.min_requests, 20);
        assert_eq!(config.recovery_timeout, Duration::from_secs(120));
    }

    #[tokio::test]
    async fn test_health_check_low_failure_rate_is_healthy() {
        // 6 failures in a 20-request window = 30%, below the 35% degraded band
        // (70% of the 50% threshold).
        let health_checked =
            HealthCheckProvider::new(Box::new(MockProvider::new(6)), standard_config());

        send_requests(&health_checked, 20).await;

        assert_eq!(health_checked.get_status().await, HealthStatus::Healthy);

        let stats = health_checked.get_stats().await;
        assert!(stats.failure_rate < 0.35);
    }

    #[tokio::test]
    async fn test_health_check_degraded_status() {
        // 8 failures in a 20-request window = 40%: at or above the 35% degraded
        // band but below the 50% unhealthy threshold.
        let health_checked =
            HealthCheckProvider::new(Box::new(MockProvider::new(8)), standard_config());

        send_requests(&health_checked, 20).await;

        assert_eq!(health_checked.get_status().await, HealthStatus::Degraded);

        let stats = health_checked.get_stats().await;
        assert!(stats.is_healthy, "Degraded still counts as healthy");
        assert!(stats.failure_rate >= 0.35 && stats.failure_rate < 0.5);
    }

    #[tokio::test]
    async fn test_health_check_reset() {
        let health_checked =
            HealthCheckProvider::new(Box::new(MockProvider::new(20)), standard_config());

        send_requests(&health_checked, 20).await;
        assert_eq!(health_checked.get_status().await, HealthStatus::Unhealthy);

        health_checked.reset().await;

        assert_eq!(health_checked.get_status().await, HealthStatus::Healthy);
        let stats = health_checked.get_stats().await;
        assert_eq!(stats.total_requests, 0);
        assert_eq!(stats.total_failures, 0);
    }
}