1use std::collections::HashMap;
4use std::sync::Arc;
5use std::time::{Duration, Instant};
6use tokio::sync::RwLock;
7use tracing::{debug, warn};
8
9use crate::error::ProviderError;
10use crate::provider::Provider;
11
/// Outcome of one provider health probe, as recorded by `HealthCheckCache`.
#[derive(Clone, Debug)]
pub struct HealthCheckResult {
    /// `true` when the probe succeeded and the provider reported itself healthy.
    pub is_healthy: bool,
    /// Moment the probe finished; freshness is measured from here.
    pub checked_at: Instant,
    /// Failure text when the probe returned an error or timed out; `None` otherwise.
    pub error: Option<String>,
}

impl HealthCheckResult {
    /// Reports whether this result is still fresh, i.e. strictly younger than `ttl`.
    ///
    /// Because the comparison is strict, a zero `ttl` never yields a valid result.
    pub fn is_valid(&self, ttl: Duration) -> bool {
        let age = self.checked_at.elapsed();
        ttl > age
    }
}
29
30pub struct HealthCheckCache {
32 cache: Arc<RwLock<HashMap<String, HealthCheckResult>>>,
34 ttl: Duration,
36 timeout: Duration,
38}
39
40impl HealthCheckCache {
41 pub fn new(ttl: Duration, timeout: Duration) -> Self {
43 Self {
44 cache: Arc::new(RwLock::new(HashMap::new())),
45 ttl,
46 timeout,
47 }
48 }
49}
50
51impl Default for HealthCheckCache {
52 fn default() -> Self {
56 Self::new(Duration::from_secs(300), Duration::from_secs(10))
57 }
58}
59
60impl HealthCheckCache {
61 pub async fn check_health(&self, provider: &Arc<dyn Provider>) -> Result<bool, ProviderError> {
63 let provider_id = provider.id();
64
65 {
67 let cache = self.cache.read().await;
68 if let Some(result) = cache.get(provider_id) {
69 if result.is_valid(self.ttl) {
70 debug!(
71 "Using cached health check for provider: {} (healthy: {})",
72 provider_id, result.is_healthy
73 );
74 return if result.is_healthy {
75 Ok(true)
76 } else {
77 Err(ProviderError::ProviderError(
78 result
79 .error
80 .clone()
81 .unwrap_or_else(|| "Provider unhealthy".to_string()),
82 ))
83 };
84 }
85 }
86 }
87
88 debug!("Performing health check for provider: {}", provider_id);
90 let result = match tokio::time::timeout(self.timeout, provider.health_check()).await {
91 Ok(Ok(is_healthy)) => HealthCheckResult {
92 is_healthy,
93 checked_at: Instant::now(),
94 error: None,
95 },
96 Ok(Err(e)) => {
97 warn!("Health check failed for provider {}: {}", provider_id, e);
98 HealthCheckResult {
99 is_healthy: false,
100 checked_at: Instant::now(),
101 error: Some(e.to_string()),
102 }
103 }
104 Err(_) => {
105 warn!("Health check timeout for provider: {}", provider_id);
106 HealthCheckResult {
107 is_healthy: false,
108 checked_at: Instant::now(),
109 error: Some("Health check timeout".to_string()),
110 }
111 }
112 };
113
114 {
116 let mut cache = self.cache.write().await;
117 cache.insert(provider_id.to_string(), result.clone());
118 }
119
120 if result.is_healthy {
121 Ok(true)
122 } else {
123 Err(ProviderError::ProviderError(
124 result
125 .error
126 .unwrap_or_else(|| "Provider unhealthy".to_string()),
127 ))
128 }
129 }
130
131 pub async fn invalidate(&self, provider_id: &str) {
133 let mut cache = self.cache.write().await;
134 cache.remove(provider_id);
135 debug!(
136 "Invalidated health check cache for provider: {}",
137 provider_id
138 );
139 }
140
141 pub async fn invalidate_all(&self) {
143 let mut cache = self.cache.write().await;
144 cache.clear();
145 debug!("Invalidated all health check cache");
146 }
147
148 pub async fn get_cached(&self, provider_id: &str) -> Option<HealthCheckResult> {
150 let cache = self.cache.read().await;
151 cache.get(provider_id).cloned()
152 }
153
154 pub fn with_ttl(mut self, ttl: Duration) -> Self {
156 self.ttl = ttl;
157 self
158 }
159
160 pub fn with_timeout(mut self, timeout: Duration) -> Self {
162 self.timeout = timeout;
163 self
164 }
165}
166
#[cfg(test)]
mod tests {
    use super::*;
    use crate::models::ChatRequest;
    use crate::models::{ChatResponse, FinishReason, TokenUsage};
    use crate::provider::Provider;
    use async_trait::async_trait;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// How a [`MockProvider`] answers `health_check`.
    #[derive(Clone, Copy)]
    enum HealthBehavior {
        /// Returns `Ok(true)`.
        Healthy,
        /// Returns `Ok(false)` without an error.
        ReportsUnhealthy,
        /// Returns `Err(ProviderError::ProviderError(..))`.
        Failing,
        /// Sleeps far longer than any test timeout, then returns `Ok(true)`.
        Slow,
    }

    /// Configurable provider double that counts how often it is probed, so
    /// tests can tell a cache hit from a fresh health check.
    struct MockProvider {
        id: &'static str,
        name: &'static str,
        behavior: HealthBehavior,
        health_checks: AtomicUsize,
    }

    impl MockProvider {
        fn new(id: &'static str, name: &'static str, behavior: HealthBehavior) -> Self {
            Self {
                id,
                name,
                behavior,
                health_checks: AtomicUsize::new(0),
            }
        }

        fn healthy() -> Self {
            Self::new("mock-healthy", "Mock Healthy", HealthBehavior::Healthy)
        }

        fn unhealthy() -> Self {
            Self::new("mock-unhealthy", "Mock Unhealthy", HealthBehavior::Failing)
        }

        /// Number of times `health_check` has been invoked on this mock.
        fn health_check_calls(&self) -> usize {
            self.health_checks.load(Ordering::SeqCst)
        }
    }

    /// Wraps `mock` in an `Arc` and returns it alongside a trait-object handle
    /// to the same provider, which is what `check_health` takes.
    fn as_provider(mock: MockProvider) -> (Arc<MockProvider>, Arc<dyn Provider>) {
        let mock = Arc::new(mock);
        let provider: Arc<dyn Provider> = mock.clone();
        (mock, provider)
    }

    #[async_trait]
    impl Provider for MockProvider {
        fn id(&self) -> &str {
            self.id
        }

        fn name(&self) -> &str {
            self.name
        }

        fn models(&self) -> Vec<crate::models::ModelInfo> {
            vec![]
        }

        async fn chat(&self, _request: ChatRequest) -> Result<ChatResponse, ProviderError> {
            Ok(ChatResponse {
                content: "test".to_string(),
                model: "test".to_string(),
                usage: TokenUsage {
                    prompt_tokens: 0,
                    completion_tokens: 0,
                    total_tokens: 0,
                },
                finish_reason: FinishReason::Stop,
            })
        }

        async fn chat_stream(
            &self,
            _request: ChatRequest,
        ) -> Result<crate::provider::ChatStream, ProviderError> {
            Err(ProviderError::NotFound("Not implemented".to_string()))
        }

        fn count_tokens(&self, _content: &str, _model: &str) -> Result<usize, ProviderError> {
            Ok(0)
        }

        async fn health_check(&self) -> Result<bool, ProviderError> {
            self.health_checks.fetch_add(1, Ordering::SeqCst);
            match self.behavior {
                HealthBehavior::Healthy => Ok(true),
                HealthBehavior::ReportsUnhealthy => Ok(false),
                HealthBehavior::Failing => {
                    Err(ProviderError::ProviderError("Provider is down".to_string()))
                }
                HealthBehavior::Slow => {
                    tokio::time::sleep(Duration::from_secs(10)).await;
                    Ok(true)
                }
            }
        }
    }

    #[tokio::test]
    async fn test_health_check_cache_healthy() {
        let cache = HealthCheckCache::default();
        let (_mock, provider) = as_provider(MockProvider::healthy());

        let result = cache.check_health(&provider).await;
        assert!(matches!(result, Ok(true)));
    }

    #[tokio::test]
    async fn test_health_check_cache_unhealthy() {
        let cache = HealthCheckCache::default();
        let (_mock, provider) = as_provider(MockProvider::unhealthy());

        let result = cache.check_health(&provider).await;
        assert!(matches!(result, Err(ProviderError::ProviderError(_))));
    }

    #[tokio::test]
    async fn test_health_check_reports_unhealthy_without_error() {
        let cache = HealthCheckCache::default();
        let (_mock, provider) = as_provider(MockProvider::new(
            "mock-false",
            "Mock False",
            HealthBehavior::ReportsUnhealthy,
        ));

        let result = cache.check_health(&provider).await;
        assert!(matches!(
            &result,
            Err(ProviderError::ProviderError(msg)) if msg == "Provider unhealthy"
        ));
    }

    #[tokio::test]
    async fn test_health_check_caching() {
        let cache = HealthCheckCache::default();
        let (mock, provider) = as_provider(MockProvider::healthy());

        assert!(matches!(cache.check_health(&provider).await, Ok(true)));
        assert!(matches!(cache.check_health(&provider).await, Ok(true)));

        // The second call must be answered from the cache, not by the provider.
        assert_eq!(mock.health_check_calls(), 1);
        assert!(cache.get_cached("mock-healthy").await.is_some());
    }

    #[tokio::test]
    async fn test_health_check_caches_failures() {
        let cache = HealthCheckCache::default();
        let (mock, provider) = as_provider(MockProvider::unhealthy());

        assert!(cache.check_health(&provider).await.is_err());
        assert!(cache.check_health(&provider).await.is_err());

        // Failures are cached too, so the provider is probed only once.
        assert_eq!(mock.health_check_calls(), 1);
    }

    #[tokio::test]
    async fn test_health_check_expired_entry_is_reprobed() {
        // A zero TTL makes every cached entry immediately stale.
        let cache = HealthCheckCache::new(Duration::ZERO, Duration::from_secs(10));
        let (mock, provider) = as_provider(MockProvider::healthy());

        assert!(cache.check_health(&provider).await.is_ok());
        assert!(cache.check_health(&provider).await.is_ok());
        assert_eq!(mock.health_check_calls(), 2);
    }

    #[tokio::test]
    async fn test_health_check_invalidate() {
        let cache = HealthCheckCache::default();
        let (mock, provider) = as_provider(MockProvider::healthy());

        cache.check_health(&provider).await.ok();
        assert!(cache.get_cached("mock-healthy").await.is_some());

        cache.invalidate("mock-healthy").await;
        assert!(cache.get_cached("mock-healthy").await.is_none());

        // After invalidation the provider must be probed again.
        cache.check_health(&provider).await.ok();
        assert_eq!(mock.health_check_calls(), 2);
    }

    #[tokio::test]
    async fn test_health_check_invalidate_all() {
        let cache = HealthCheckCache::default();
        let (_healthy, provider1) = as_provider(MockProvider::healthy());
        let (_unhealthy, provider2) = as_provider(MockProvider::unhealthy());

        cache.check_health(&provider1).await.ok();
        cache.check_health(&provider2).await.ok();

        assert!(cache.get_cached("mock-healthy").await.is_some());
        assert!(cache.get_cached("mock-unhealthy").await.is_some());

        cache.invalidate_all().await;

        assert!(cache.get_cached("mock-healthy").await.is_none());
        assert!(cache.get_cached("mock-unhealthy").await.is_none());
    }

    #[tokio::test]
    async fn test_health_check_timeout() {
        let cache = HealthCheckCache::new(Duration::from_secs(300), Duration::from_millis(1));
        let (_mock, provider) = as_provider(MockProvider::new("slow", "Slow", HealthBehavior::Slow));

        let result = cache.check_health(&provider).await;
        assert!(matches!(
            &result,
            Err(ProviderError::ProviderError(msg)) if msg == "Health check timeout"
        ));

        // The timeout is recorded as an unhealthy entry.
        let cached = cache.get_cached("slow").await;
        assert!(matches!(cached, Some(ref r) if !r.is_healthy));
    }
}