ricecoder_mcp/health_check.rs

use crate::error::{Error, Result};
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::RwLock;
use tokio::time::sleep;
use tracing::{debug, error, info, warn};

/// Health status reported for a server.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum HealthStatus {
    Healthy,
    Unhealthy,
    Unknown,
}

/// Availability record for a single server.
#[derive(Debug, Clone)]
pub struct ServerAvailability {
    pub server_id: String,
    pub is_available: bool,
    pub status: HealthStatus,
    pub last_check: std::time::Instant,
    pub consecutive_failures: u32,
}

impl ServerAvailability {
    /// Creates a new record; servers start available with an unknown status.
    pub fn new(server_id: String) -> Self {
        Self {
            server_id,
            is_available: true,
            status: HealthStatus::Unknown,
            last_check: std::time::Instant::now(),
            consecutive_failures: 0,
        }
    }

    /// Marks the server healthy and resets the failure counter.
    pub fn mark_healthy(&mut self) {
        self.is_available = true;
        self.status = HealthStatus::Healthy;
        self.consecutive_failures = 0;
        self.last_check = std::time::Instant::now();
    }

    /// Records a failed check; after three consecutive failures the server
    /// is considered unavailable.
    pub fn mark_unhealthy(&mut self) {
        self.consecutive_failures += 1;
        self.status = HealthStatus::Unhealthy;
        self.last_check = std::time::Instant::now();

        if self.consecutive_failures >= 3 {
            self.is_available = false;
        }
    }

    /// Clears the consecutive-failure counter without changing status.
    pub fn reset_failures(&mut self) {
        self.consecutive_failures = 0;
    }
}

/// Configuration for the health checker. All `_ms` fields are milliseconds.
#[derive(Debug, Clone)]
pub struct HealthCheckConfig {
    pub check_interval_ms: u64,
    pub timeout_ms: u64,
    pub max_retries: u32,
    pub backoff_multiplier: f64,
    pub max_backoff_ms: u64,
}

impl Default for HealthCheckConfig {
    fn default() -> Self {
        Self {
            check_interval_ms: 5000,
            timeout_ms: 2000,
            max_retries: 3,
            backoff_multiplier: 2.0,
            max_backoff_ms: 60000,
        }
    }
}
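
// A minimal sketch of overriding individual settings while keeping the rest
// of the defaults, via struct-update syntax:
//
//     let config = HealthCheckConfig {
//         check_interval_ms: 1000,
//         ..Default::default()
//     };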

/// Tracks server health and exposes reconnection helpers. Cloning is cheap:
/// clones share the same underlying availability map.
#[derive(Debug, Clone)]
pub struct HealthChecker {
    config: HealthCheckConfig,
    availability: Arc<RwLock<std::collections::HashMap<String, ServerAvailability>>>,
}

impl HealthChecker {
    /// Creates a checker with the default configuration.
    pub fn new() -> Self {
        Self::with_config(HealthCheckConfig::default())
    }

    /// Creates a checker with a custom configuration.
    pub fn with_config(config: HealthCheckConfig) -> Self {
        Self {
            config,
            availability: Arc::new(RwLock::new(std::collections::HashMap::new())),
        }
    }

    /// Registers a server for health checking. A freshly registered server
    /// starts as available with an unknown status.
    pub async fn register_server(&self, server_id: &str) {
        debug!("Registering server for health checking: {}", server_id);

        let mut availability = self.availability.write().await;
        availability.insert(
            server_id.to_string(),
            ServerAvailability::new(server_id.to_string()),
        );

        info!("Server registered for health checking: {}", server_id);
    }

    /// Removes a server from health checking.
    pub async fn unregister_server(&self, server_id: &str) {
        debug!("Unregistering server from health checking: {}", server_id);

        let mut availability = self.availability.write().await;
        availability.remove(server_id);

        info!("Server unregistered from health checking: {}", server_id);
    }

    /// Runs a single health check against a registered server, updating its
    /// availability record. Returns an error if the server is not registered.
    pub async fn check_health(&self, server_id: &str) -> Result<bool> {
        debug!("Checking health of server: {}", server_id);

        let mut availability = self.availability.write().await;
        let server_avail = availability
            .get_mut(server_id)
            .ok_or_else(|| Error::ConnectionError(format!("Server not registered: {}", server_id)))?;

        // Placeholder: a real implementation would probe the server (e.g. a
        // ping over its transport, bounded by `timeout_ms`) rather than
        // assuming success.
        let is_healthy = true;

        if is_healthy {
            server_avail.mark_healthy();
            info!("Server health check passed: {}", server_id);
            Ok(true)
        } else {
            server_avail.mark_unhealthy();
            warn!("Server health check failed: {}", server_id);
            Ok(false)
        }
    }

    /// Returns true if the server is registered but currently unavailable.
    /// Unregistered servers return false.
    pub async fn is_disconnected(&self, server_id: &str) -> bool {
        let availability = self.availability.read().await;
        availability
            .get(server_id)
            .map(|a| !a.is_available)
            .unwrap_or(false)
    }

    /// Returns true if the server is unavailable or not registered at all.
    /// This differs from `is_disconnected` only in how unknown servers are
    /// treated.
    pub async fn is_unavailable(&self, server_id: &str) -> bool {
        let availability = self.availability.read().await;
        availability
            .get(server_id)
            .map(|a| !a.is_available)
            .unwrap_or(true)
    }

    /// Returns a snapshot of a server's availability record, if registered.
    pub async fn get_availability(&self, server_id: &str) -> Option<ServerAvailability> {
        let availability = self.availability.read().await;
        availability.get(server_id).cloned()
    }

    /// Checks every registered server at `check_interval_ms` intervals. This
    /// loops forever, so it is intended to be spawned as a background task.
    pub async fn periodic_check(&self) {
        debug!("Starting periodic health checks");

        loop {
            // Collect the IDs and release the read lock before checking,
            // since check_health needs the write lock.
            let availability = self.availability.read().await;
            let server_ids: Vec<String> = availability.keys().cloned().collect();
            drop(availability);

            for server_id in server_ids {
                if let Err(e) = self.check_health(&server_id).await {
                    error!("Health check error for server {}: {}", server_id, e);
                }
            }

            sleep(Duration::from_millis(self.config.check_interval_ms)).await;
        }
    }

    /// Repeatedly invokes `on_reconnect` until it succeeds or `max_retries`
    /// attempts have failed, sleeping between attempts with exponential
    /// backoff (starting at 100ms, scaled by `backoff_multiplier`, capped at
    /// `max_backoff_ms`). Updates the server's availability on both outcomes.
    pub async fn reconnect_with_backoff<F>(&self, server_id: &str, mut on_reconnect: F) -> Result<()>
    where
        F: FnMut() -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<()>> + Send>>,
    {
        debug!("Starting reconnection with backoff for server: {}", server_id);

        let mut backoff_ms = 100u64;
        let mut attempt = 0;

        loop {
            attempt += 1;
            info!(
                "Reconnection attempt {} for server: {} (backoff: {}ms)",
                attempt, server_id, backoff_ms
            );

            match on_reconnect().await {
                Ok(()) => {
                    info!("Successfully reconnected to server: {}", server_id);
                    let mut availability = self.availability.write().await;
                    if let Some(avail) = availability.get_mut(server_id) {
                        avail.mark_healthy();
                    }
                    return Ok(());
                }
                Err(e) => {
                    if attempt >= self.config.max_retries {
                        error!(
                            "Failed to reconnect to server {} after {} attempts: {}",
                            server_id, attempt, e
                        );
                        let mut availability = self.availability.write().await;
                        if let Some(avail) = availability.get_mut(server_id) {
                            avail.mark_unhealthy();
                        }
                        return Err(Error::ConnectionError(format!(
                            "Failed to reconnect to server {} after {} attempts",
                            server_id, attempt
                        )));
                    }

                    warn!(
                        "Reconnection attempt {} failed for server {}: {}. Retrying in {}ms",
                        attempt, server_id, e, backoff_ms
                    );

                    sleep(Duration::from_millis(backoff_ms)).await;

                    // Grow the delay exponentially, capped at max_backoff_ms.
                    backoff_ms = std::cmp::min(
                        (backoff_ms as f64 * self.config.backoff_multiplier) as u64,
                        self.config.max_backoff_ms,
                    );
                }
            }
        }
    }

    /// Builds a human-readable failure message for a server, including its
    /// consecutive-failure count when the server is registered.
    pub async fn report_failure(&self, server_id: &str) -> String {
        let availability = self.availability.read().await;
        if let Some(avail) = availability.get(server_id) {
            format!(
                "Server '{}' is unavailable after {} consecutive failures. Please check the server status.",
                server_id, avail.consecutive_failures
            )
        } else {
            format!("Server '{}' is unavailable.", server_id)
        }
    }

    /// Returns a snapshot of the availability records for all servers.
    pub async fn get_health_stats(&self) -> Vec<ServerAvailability> {
        let availability = self.availability.read().await;
        availability.values().cloned().collect()
    }
}

impl Default for HealthChecker {
    fn default() -> Self {
        Self::new()
    }
}
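
// A minimal usage sketch (`connect` is a hypothetical helper standing in for
// whatever actually re-establishes the server's transport): spawn the
// periodic checker in the background and reconnect with backoff on
// disconnect.
//
//     let checker = HealthChecker::new();
//     checker.register_server("server1").await;
//
//     let bg = checker.clone();
//     tokio::spawn(async move { bg.periodic_check().await });
//
//     if checker.is_disconnected("server1").await {
//         checker
//             .reconnect_with_backoff("server1", || Box::pin(connect()))
//             .await?;
//     }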

#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_create_health_checker() {
        let checker = HealthChecker::new();
        let stats = checker.get_health_stats().await;
        assert_eq!(stats.len(), 0);
    }

    #[tokio::test]
    async fn test_register_server() {
        let checker = HealthChecker::new();
        checker.register_server("server1").await;

        let avail = checker.get_availability("server1").await;
        assert!(avail.is_some());
        assert!(avail.unwrap().is_available);
    }

    #[tokio::test]
    async fn test_unregister_server() {
        let checker = HealthChecker::new();
        checker.register_server("server1").await;
        checker.unregister_server("server1").await;

        let avail = checker.get_availability("server1").await;
        assert!(avail.is_none());
    }

    #[tokio::test]
    async fn test_check_health() {
        let checker = HealthChecker::new();
        checker.register_server("server1").await;

        let result = checker.check_health("server1").await;
        assert!(result.is_ok());
        assert!(result.unwrap());

        let avail = checker.get_availability("server1").await.unwrap();
        assert_eq!(avail.status, HealthStatus::Healthy);
    }

    #[tokio::test]
    async fn test_server_availability() {
        let mut avail = ServerAvailability::new("server1".to_string());
        assert!(avail.is_available);

        avail.mark_unhealthy();
        avail.mark_unhealthy();
        avail.mark_unhealthy();
        assert!(!avail.is_available);

        avail.mark_healthy();
        assert!(avail.is_available);
        assert_eq!(avail.consecutive_failures, 0);
    }

    #[tokio::test]
    async fn test_is_disconnected() {
        let checker = HealthChecker::new();
        checker.register_server("server1").await;

        assert!(!checker.is_disconnected("server1").await);

        let mut avail = checker.get_availability("server1").await.unwrap();
        avail.mark_unhealthy();
        avail.mark_unhealthy();
        avail.mark_unhealthy();

        let mut availability = checker.availability.write().await;
        availability.insert("server1".to_string(), avail);
        drop(availability);

        assert!(checker.is_disconnected("server1").await);
    }

    #[tokio::test]
    async fn test_report_failure() {
        let checker = HealthChecker::new();
        checker.register_server("server1").await;

        let message = checker.report_failure("server1").await;
        assert!(message.contains("server1"));
    }

    #[tokio::test]
    async fn test_get_health_stats() {
        let checker = HealthChecker::new();
        checker.register_server("server1").await;
        checker.register_server("server2").await;

        let stats = checker.get_health_stats().await;
        assert_eq!(stats.len(), 2);
    }
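
    // Sketch: exercise reconnect_with_backoff with a closure that fails once
    // and then succeeds. The default config allows three attempts, so the
    // call should return Ok after the second try.
    #[tokio::test]
    async fn test_reconnect_with_backoff() {
        use std::sync::atomic::{AtomicU32, Ordering};

        let checker = HealthChecker::new();
        checker.register_server("server1").await;

        let attempts = Arc::new(AtomicU32::new(0));
        let counter = Arc::clone(&attempts);

        let result = checker
            .reconnect_with_backoff("server1", move || {
                let counter = Arc::clone(&counter);
                Box::pin(async move {
                    // Fail the first attempt, succeed afterwards.
                    if counter.fetch_add(1, Ordering::SeqCst) == 0 {
                        Err(Error::ConnectionError("transient".to_string()))
                    } else {
                        Ok(())
                    }
                })
            })
            .await;

        assert!(result.is_ok());
        assert_eq!(attempts.load(Ordering::SeqCst), 2);
    }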
}