ricecoder_providers/
health_check.rs

1//! Health check system with caching and timeout support
2
3use std::collections::HashMap;
4use std::sync::Arc;
5use std::time::{Duration, Instant};
6use tokio::sync::RwLock;
7use tracing::{debug, warn};
8
9use crate::error::ProviderError;
10use crate::provider::Provider;
11
/// Outcome of a single provider health probe, stamped with the moment it was recorded.
#[derive(Clone, Debug)]
pub struct HealthCheckResult {
    /// `true` when the probe reported the provider as healthy.
    pub is_healthy: bool,
    /// Instant at which this result was recorded.
    pub checked_at: Instant,
    /// Failure description; `None` when no error message was captured.
    pub error: Option<String>,
}

impl HealthCheckResult {
    /// Returns `true` while this result is younger than `ttl`.
    ///
    /// A result whose age equals or exceeds `ttl` counts as expired, so a
    /// zero `ttl` never validates.
    pub fn is_valid(&self, ttl: Duration) -> bool {
        let age = self.checked_at.elapsed();
        ttl > age
    }
}
29
30/// Health check cache for providers
31pub struct HealthCheckCache {
32    /// Cache of health check results by provider ID
33    cache: Arc<RwLock<HashMap<String, HealthCheckResult>>>,
34    /// Time-to-live for cached results
35    ttl: Duration,
36    /// Timeout for health check operations
37    timeout: Duration,
38}
39
40impl HealthCheckCache {
41    /// Create a new health check cache
42    pub fn new(ttl: Duration, timeout: Duration) -> Self {
43        Self {
44            cache: Arc::new(RwLock::new(HashMap::new())),
45            ttl,
46            timeout,
47        }
48    }
49}
50
51impl Default for HealthCheckCache {
52    /// Create a new health check cache with default settings
53    /// - TTL: 5 minutes
54    /// - Timeout: 10 seconds
55    fn default() -> Self {
56        Self::new(Duration::from_secs(300), Duration::from_secs(10))
57    }
58}
59
60impl HealthCheckCache {
61    /// Check provider health with caching
62    pub async fn check_health(&self, provider: &Arc<dyn Provider>) -> Result<bool, ProviderError> {
63        let provider_id = provider.id();
64
65        // Check cache first
66        {
67            let cache = self.cache.read().await;
68            if let Some(result) = cache.get(provider_id) {
69                if result.is_valid(self.ttl) {
70                    debug!(
71                        "Using cached health check for provider: {} (healthy: {})",
72                        provider_id, result.is_healthy
73                    );
74                    return if result.is_healthy {
75                        Ok(true)
76                    } else {
77                        Err(ProviderError::ProviderError(
78                            result
79                                .error
80                                .clone()
81                                .unwrap_or_else(|| "Provider unhealthy".to_string()),
82                        ))
83                    };
84                }
85            }
86        }
87
88        // Perform health check with timeout
89        debug!("Performing health check for provider: {}", provider_id);
90        let result = match tokio::time::timeout(self.timeout, provider.health_check()).await {
91            Ok(Ok(is_healthy)) => HealthCheckResult {
92                is_healthy,
93                checked_at: Instant::now(),
94                error: None,
95            },
96            Ok(Err(e)) => {
97                warn!("Health check failed for provider {}: {}", provider_id, e);
98                HealthCheckResult {
99                    is_healthy: false,
100                    checked_at: Instant::now(),
101                    error: Some(e.to_string()),
102                }
103            }
104            Err(_) => {
105                warn!("Health check timeout for provider: {}", provider_id);
106                HealthCheckResult {
107                    is_healthy: false,
108                    checked_at: Instant::now(),
109                    error: Some("Health check timeout".to_string()),
110                }
111            }
112        };
113
114        // Cache the result
115        {
116            let mut cache = self.cache.write().await;
117            cache.insert(provider_id.to_string(), result.clone());
118        }
119
120        if result.is_healthy {
121            Ok(true)
122        } else {
123            Err(ProviderError::ProviderError(
124                result
125                    .error
126                    .unwrap_or_else(|| "Provider unhealthy".to_string()),
127            ))
128        }
129    }
130
131    /// Invalidate cache for a specific provider
132    pub async fn invalidate(&self, provider_id: &str) {
133        let mut cache = self.cache.write().await;
134        cache.remove(provider_id);
135        debug!(
136            "Invalidated health check cache for provider: {}",
137            provider_id
138        );
139    }
140
141    /// Invalidate all cached results
142    pub async fn invalidate_all(&self) {
143        let mut cache = self.cache.write().await;
144        cache.clear();
145        debug!("Invalidated all health check cache");
146    }
147
148    /// Get cached result without performing a new check
149    pub async fn get_cached(&self, provider_id: &str) -> Option<HealthCheckResult> {
150        let cache = self.cache.read().await;
151        cache.get(provider_id).cloned()
152    }
153
154    /// Set TTL for cached results
155    pub fn with_ttl(mut self, ttl: Duration) -> Self {
156        self.ttl = ttl;
157        self
158    }
159
160    /// Set timeout for health check operations
161    pub fn with_timeout(mut self, timeout: Duration) -> Self {
162        self.timeout = timeout;
163        self
164    }
165}
166
167#[cfg(test)]
168mod tests {
169    use super::*;
170    use crate::models::ChatRequest;
171    use crate::models::{ChatResponse, FinishReason, TokenUsage};
172    use crate::provider::Provider;
173    use async_trait::async_trait;
174
175    struct MockHealthyProvider;
176
177    #[async_trait]
178    impl Provider for MockHealthyProvider {
179        fn id(&self) -> &str {
180            "mock-healthy"
181        }
182
183        fn name(&self) -> &str {
184            "Mock Healthy"
185        }
186
187        fn models(&self) -> Vec<crate::models::ModelInfo> {
188            vec![]
189        }
190
191        async fn chat(&self, _request: ChatRequest) -> Result<ChatResponse, ProviderError> {
192            Ok(ChatResponse {
193                content: "test".to_string(),
194                model: "test".to_string(),
195                usage: TokenUsage {
196                    prompt_tokens: 0,
197                    completion_tokens: 0,
198                    total_tokens: 0,
199                },
200                finish_reason: FinishReason::Stop,
201            })
202        }
203
204        async fn chat_stream(
205            &self,
206            _request: ChatRequest,
207        ) -> Result<crate::provider::ChatStream, ProviderError> {
208            Err(ProviderError::NotFound("Not implemented".to_string()))
209        }
210
211        fn count_tokens(&self, _content: &str, _model: &str) -> Result<usize, ProviderError> {
212            Ok(0)
213        }
214
215        async fn health_check(&self) -> Result<bool, ProviderError> {
216            Ok(true)
217        }
218    }
219
220    struct MockUnhealthyProvider;
221
222    #[async_trait]
223    impl Provider for MockUnhealthyProvider {
224        fn id(&self) -> &str {
225            "mock-unhealthy"
226        }
227
228        fn name(&self) -> &str {
229            "Mock Unhealthy"
230        }
231
232        fn models(&self) -> Vec<crate::models::ModelInfo> {
233            vec![]
234        }
235
236        async fn chat(&self, _request: ChatRequest) -> Result<ChatResponse, ProviderError> {
237            Ok(ChatResponse {
238                content: "test".to_string(),
239                model: "test".to_string(),
240                usage: TokenUsage {
241                    prompt_tokens: 0,
242                    completion_tokens: 0,
243                    total_tokens: 0,
244                },
245                finish_reason: FinishReason::Stop,
246            })
247        }
248
249        async fn chat_stream(
250            &self,
251            _request: ChatRequest,
252        ) -> Result<crate::provider::ChatStream, ProviderError> {
253            Err(ProviderError::NotFound("Not implemented".to_string()))
254        }
255
256        fn count_tokens(&self, _content: &str, _model: &str) -> Result<usize, ProviderError> {
257            Ok(0)
258        }
259
260        async fn health_check(&self) -> Result<bool, ProviderError> {
261            Err(ProviderError::ProviderError("Provider is down".to_string()))
262        }
263    }
264
265    #[tokio::test]
266    async fn test_health_check_cache_healthy() {
267        let cache = HealthCheckCache::default();
268        let provider: Arc<dyn Provider> = Arc::new(MockHealthyProvider);
269
270        let result = cache.check_health(&provider).await;
271        assert!(result.is_ok());
272        assert!(result.unwrap());
273    }
274
275    #[tokio::test]
276    async fn test_health_check_cache_unhealthy() {
277        let cache = HealthCheckCache::default();
278        let provider: Arc<dyn Provider> = Arc::new(MockUnhealthyProvider);
279
280        let result = cache.check_health(&provider).await;
281        assert!(result.is_err());
282    }
283
284    #[tokio::test]
285    async fn test_health_check_caching() {
286        let cache = HealthCheckCache::default();
287        let provider: Arc<dyn Provider> = Arc::new(MockHealthyProvider);
288
289        // First check
290        let result1 = cache.check_health(&provider).await;
291        assert!(result1.is_ok());
292
293        // Second check should use cache
294        let result2 = cache.check_health(&provider).await;
295        assert!(result2.is_ok());
296
297        // Verify cache entry exists
298        let cached = cache.get_cached("mock-healthy").await;
299        assert!(cached.is_some());
300    }
301
302    #[tokio::test]
303    async fn test_health_check_invalidate() {
304        let cache = HealthCheckCache::default();
305        let provider: Arc<dyn Provider> = Arc::new(MockHealthyProvider);
306
307        // First check
308        cache.check_health(&provider).await.ok();
309
310        // Verify cache entry exists
311        let cached = cache.get_cached("mock-healthy").await;
312        assert!(cached.is_some());
313
314        // Invalidate
315        cache.invalidate("mock-healthy").await;
316
317        // Verify cache entry is gone
318        let cached = cache.get_cached("mock-healthy").await;
319        assert!(cached.is_none());
320    }
321
322    #[tokio::test]
323    async fn test_health_check_invalidate_all() {
324        let cache = HealthCheckCache::default();
325        let provider1: Arc<dyn Provider> = Arc::new(MockHealthyProvider);
326        let provider2: Arc<dyn Provider> = Arc::new(MockUnhealthyProvider);
327
328        // Perform checks
329        cache.check_health(&provider1).await.ok();
330        cache.check_health(&provider2).await.ok();
331
332        // Verify cache entries exist
333        assert!(cache.get_cached("mock-healthy").await.is_some());
334        assert!(cache.get_cached("mock-unhealthy").await.is_some());
335
336        // Invalidate all
337        cache.invalidate_all().await;
338
339        // Verify all cache entries are gone
340        assert!(cache.get_cached("mock-healthy").await.is_none());
341        assert!(cache.get_cached("mock-unhealthy").await.is_none());
342    }
343
344    #[tokio::test]
345    async fn test_health_check_timeout() {
346        let cache = HealthCheckCache::new(Duration::from_secs(300), Duration::from_millis(1));
347
348        struct SlowProvider;
349
350        #[async_trait]
351        impl Provider for SlowProvider {
352            fn id(&self) -> &str {
353                "slow"
354            }
355
356            fn name(&self) -> &str {
357                "Slow"
358            }
359
360            fn models(&self) -> Vec<crate::models::ModelInfo> {
361                vec![]
362            }
363
364            async fn chat(&self, _request: ChatRequest) -> Result<ChatResponse, ProviderError> {
365                Ok(ChatResponse {
366                    content: "test".to_string(),
367                    model: "test".to_string(),
368                    usage: TokenUsage {
369                        prompt_tokens: 0,
370                        completion_tokens: 0,
371                        total_tokens: 0,
372                    },
373                    finish_reason: FinishReason::Stop,
374                })
375            }
376
377            async fn chat_stream(
378                &self,
379                _request: ChatRequest,
380            ) -> Result<crate::provider::ChatStream, ProviderError> {
381                Err(ProviderError::NotFound("Not implemented".to_string()))
382            }
383
384            fn count_tokens(&self, _content: &str, _model: &str) -> Result<usize, ProviderError> {
385                Ok(0)
386            }
387
388            async fn health_check(&self) -> Result<bool, ProviderError> {
389                tokio::time::sleep(Duration::from_secs(10)).await;
390                Ok(true)
391            }
392        }
393
394        let provider: Arc<dyn Provider> = Arc::new(SlowProvider);
395        let result = cache.check_health(&provider).await;
396        assert!(result.is_err());
397    }
398}