ricecoder_mcp/
health_check.rs

//! Health checking and reconnection logic for MCP servers

use crate::error::{Error, Result};
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::RwLock;
use tokio::time::sleep;
use tracing::{debug, error, info, warn};

/// Server health status
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum HealthStatus {
    Healthy,
    Unhealthy,
    Unknown,
}

/// Server availability information
#[derive(Debug, Clone)]
pub struct ServerAvailability {
    pub server_id: String,
    pub is_available: bool,
    pub status: HealthStatus,
    pub last_check: std::time::Instant,
    pub consecutive_failures: u32,
}

impl ServerAvailability {
    /// Creates a new server availability tracker
    pub fn new(server_id: String) -> Self {
        Self {
            server_id,
            is_available: true,
            status: HealthStatus::Unknown,
            last_check: std::time::Instant::now(),
            consecutive_failures: 0,
        }
    }

    /// Marks the server as healthy
    pub fn mark_healthy(&mut self) {
        self.is_available = true;
        self.status = HealthStatus::Healthy;
        self.consecutive_failures = 0;
        self.last_check = std::time::Instant::now();
    }

    /// Marks the server as unhealthy
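    ///
    /// After three consecutive failures the server is also marked unavailable.
    ///
    /// # Example
    ///
    /// A minimal sketch (the `ricecoder_mcp::health_check` module path is assumed):
    ///
    /// ```
    /// use ricecoder_mcp::health_check::ServerAvailability; // path assumed
    ///
    /// let mut avail = ServerAvailability::new("server1".to_string());
    /// avail.mark_unhealthy();
    /// avail.mark_unhealthy();
    /// avail.mark_unhealthy();
    /// assert!(!avail.is_available);
    /// ```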
    pub fn mark_unhealthy(&mut self) {
        self.consecutive_failures += 1;
        self.status = HealthStatus::Unhealthy;
        self.last_check = std::time::Instant::now();

        // Mark as unavailable after 3 consecutive failures
        if self.consecutive_failures >= 3 {
            self.is_available = false;
        }
    }

    /// Resets the failure counter
    pub fn reset_failures(&mut self) {
        self.consecutive_failures = 0;
    }
}

/// Configuration for health checking
#[derive(Debug, Clone)]
pub struct HealthCheckConfig {
    pub check_interval_ms: u64,
    pub timeout_ms: u64,
    pub max_retries: u32,
    pub backoff_multiplier: f64,
    pub max_backoff_ms: u64,
}

impl Default for HealthCheckConfig {
    fn default() -> Self {
        Self {
            check_interval_ms: 5000,
            timeout_ms: 2000,
            max_retries: 3,
            backoff_multiplier: 2.0,
            max_backoff_ms: 60000,
        }
    }
}

/// Health checker for MCP servers
#[derive(Debug, Clone)]
pub struct HealthChecker {
    config: HealthCheckConfig,
    availability: Arc<RwLock<std::collections::HashMap<String, ServerAvailability>>>,
}

impl HealthChecker {
    /// Creates a new health checker with default configuration
    pub fn new() -> Self {
        Self::with_config(HealthCheckConfig::default())
    }

    /// Creates a new health checker with custom configuration
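    ///
    /// # Example
    ///
    /// A minimal sketch of overriding one field and keeping the remaining
    /// defaults (the `ricecoder_mcp::health_check` module path is assumed):
    ///
    /// ```
    /// use ricecoder_mcp::health_check::{HealthCheckConfig, HealthChecker}; // path assumed
    ///
    /// let config = HealthCheckConfig {
    ///     check_interval_ms: 10_000,
    ///     ..Default::default()
    /// };
    /// let checker = HealthChecker::with_config(config);
    /// ```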
    pub fn with_config(config: HealthCheckConfig) -> Self {
        Self {
            config,
            availability: Arc::new(RwLock::new(std::collections::HashMap::new())),
        }
    }

    /// Registers a server for health checking
    ///
    /// # Arguments
    /// * `server_id` - The server ID to register
    pub async fn register_server(&self, server_id: &str) {
        debug!("Registering server for health checking: {}", server_id);

        let mut availability = self.availability.write().await;
        availability.insert(
            server_id.to_string(),
            ServerAvailability::new(server_id.to_string()),
        );

        info!("Server registered for health checking: {}", server_id);
    }

    /// Unregisters a server from health checking
    ///
    /// # Arguments
    /// * `server_id` - The server ID to unregister
    pub async fn unregister_server(&self, server_id: &str) {
        debug!("Unregistering server from health checking: {}", server_id);

        let mut availability = self.availability.write().await;
        availability.remove(server_id);

        info!("Server unregistered from health checking: {}", server_id);
    }

    /// Performs a health check on a server
    ///
    /// # Arguments
    /// * `server_id` - The server ID to check
    ///
    /// # Returns
    /// `true` if the server is healthy, `false` otherwise
    pub async fn check_health(&self, server_id: &str) -> Result<bool> {
        debug!("Checking health of server: {}", server_id);

        let mut availability = self.availability.write().await;
        let server_avail = availability
            .get_mut(server_id)
            .ok_or_else(|| Error::ConnectionError(format!("Server not registered: {}", server_id)))?;

        // Placeholder check: a real implementation would ping the server and report the result
        let is_healthy = true;

        if is_healthy {
            server_avail.mark_healthy();
            info!("Server health check passed: {}", server_id);
            Ok(true)
        } else {
            server_avail.mark_unhealthy();
            warn!("Server health check failed: {}", server_id);
            Ok(false)
        }
    }

    /// Detects server disconnection
    ///
    /// # Arguments
    /// * `server_id` - The server ID to check
    ///
    /// # Returns
    /// `true` if a registered server has been marked unavailable; unregistered
    /// servers are not considered disconnected
    pub async fn is_disconnected(&self, server_id: &str) -> bool {
        let availability = self.availability.read().await;
        availability
            .get(server_id)
            .map(|a| !a.is_available)
            .unwrap_or(false)
    }

    /// Detects server unavailability
    ///
    /// # Arguments
    /// * `server_id` - The server ID to check
    ///
    /// # Returns
    /// `true` if the server is marked unavailable or is not registered
    pub async fn is_unavailable(&self, server_id: &str) -> bool {
        let availability = self.availability.read().await;
        availability
            .get(server_id)
            .map(|a| !a.is_available)
            .unwrap_or(true)
    }

    /// Gets the availability status of a server
    ///
    /// # Arguments
    /// * `server_id` - The server ID to check
    ///
    /// # Returns
    /// Server availability information, if the server is registered
    pub async fn get_availability(&self, server_id: &str) -> Option<ServerAvailability> {
        let availability = self.availability.read().await;
        availability.get(server_id).cloned()
    }

    /// Performs periodic availability detection
    ///
    /// This method should be run in a background task to continuously monitor server health.
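    ///
    /// # Example
    ///
    /// A minimal sketch of spawning the loop on a Tokio runtime; `HealthChecker`
    /// is `Clone`, so a clone can be moved into the task (module path assumed):
    ///
    /// ```no_run
    /// use ricecoder_mcp::health_check::HealthChecker; // path assumed
    ///
    /// # async fn demo() {
    /// let checker = HealthChecker::new();
    /// let background = checker.clone();
    /// tokio::spawn(async move {
    ///     // Runs health checks forever at the configured interval
    ///     background.periodic_check().await;
    /// });
    /// # }
    /// ```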
    pub async fn periodic_check(&self) {
        debug!("Starting periodic health checks");

        loop {
            let availability = self.availability.read().await;
            let server_ids: Vec<String> = availability.keys().cloned().collect();
            drop(availability);

            for server_id in server_ids {
                if let Err(e) = self.check_health(&server_id).await {
                    error!("Health check error for server {}: {}", server_id, e);
                }
            }

            sleep(Duration::from_millis(self.config.check_interval_ms)).await;
        }
    }

    /// Implements reconnection logic with exponential backoff
    ///
    /// # Arguments
    /// * `server_id` - The server ID to reconnect to
    /// * `on_reconnect` - Callback function to attempt reconnection
    ///
    /// # Returns
    /// Result indicating success or failure
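    ///
    /// # Example
    ///
    /// A minimal sketch of wiring up the boxed-future callback; the reconnect
    /// body is a stand-in and the module path is assumed:
    ///
    /// ```no_run
    /// use ricecoder_mcp::health_check::HealthChecker; // path assumed
    ///
    /// # async fn demo() -> ricecoder_mcp::error::Result<()> {
    /// let checker = HealthChecker::new();
    /// checker.register_server("server1").await;
    /// checker
    ///     .reconnect_with_backoff("server1", || {
    ///         Box::pin(async {
    ///             // Re-establish the transport connection here
    ///             Ok(())
    ///         })
    ///     })
    ///     .await?;
    /// # Ok(())
    /// # }
    /// ```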
    pub async fn reconnect_with_backoff<F>(&self, server_id: &str, mut on_reconnect: F) -> Result<()>
    where
        F: FnMut() -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<()>> + Send>>,
    {
        debug!("Starting reconnection with backoff for server: {}", server_id);

        let mut backoff_ms = 100u64;
        let mut attempt = 0;

        loop {
            attempt += 1;
            info!(
                "Reconnection attempt {} for server: {} (backoff: {}ms)",
                attempt, server_id, backoff_ms
            );

            match on_reconnect().await {
                Ok(()) => {
                    info!("Successfully reconnected to server: {}", server_id);
                    let mut availability = self.availability.write().await;
                    if let Some(avail) = availability.get_mut(server_id) {
                        avail.mark_healthy();
                    }
                    return Ok(());
                }
                Err(e) => {
                    if attempt >= self.config.max_retries {
                        error!(
                            "Failed to reconnect to server {} after {} attempts: {}",
                            server_id, attempt, e
                        );
                        let mut availability = self.availability.write().await;
                        if let Some(avail) = availability.get_mut(server_id) {
                            avail.mark_unhealthy();
                        }
                        return Err(Error::ConnectionError(format!(
                            "Failed to reconnect to server {} after {} attempts",
                            server_id, attempt
                        )));
                    }

                    warn!(
                        "Reconnection attempt {} failed for server {}: {}. Retrying in {}ms",
                        attempt, server_id, e, backoff_ms
                    );

                    sleep(Duration::from_millis(backoff_ms)).await;

                    // Calculate next backoff with exponential increase
                    backoff_ms = std::cmp::min(
                        (backoff_ms as f64 * self.config.backoff_multiplier) as u64,
                        self.config.max_backoff_ms,
                    );
                }
            }
        }
    }

    /// Reports persistent failures to the user
    ///
    /// # Arguments
    /// * `server_id` - The server ID that failed
    ///
    /// # Returns
    /// Error message for the user
    pub async fn report_failure(&self, server_id: &str) -> String {
        let availability = self.availability.read().await;
        if let Some(avail) = availability.get(server_id) {
            format!(
                "Server '{}' is unavailable after {} consecutive failures. Please check the server status.",
                server_id, avail.consecutive_failures
            )
        } else {
            format!("Server '{}' is unavailable.", server_id)
        }
    }

    /// Gets health statistics for all servers
    pub async fn get_health_stats(&self) -> Vec<ServerAvailability> {
        let availability = self.availability.read().await;
        availability.values().cloned().collect()
    }
}

impl Default for HealthChecker {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_create_health_checker() {
        let checker = HealthChecker::new();
        let stats = checker.get_health_stats().await;
        assert_eq!(stats.len(), 0);
    }

    #[tokio::test]
    async fn test_register_server() {
        let checker = HealthChecker::new();
        checker.register_server("server1").await;

        let avail = checker.get_availability("server1").await;
        assert!(avail.is_some());
        assert!(avail.unwrap().is_available);
    }

    #[tokio::test]
    async fn test_unregister_server() {
        let checker = HealthChecker::new();
        checker.register_server("server1").await;
        checker.unregister_server("server1").await;

        let avail = checker.get_availability("server1").await;
        assert!(avail.is_none());
    }

    #[tokio::test]
    async fn test_check_health() {
        let checker = HealthChecker::new();
        checker.register_server("server1").await;

        let result = checker.check_health("server1").await;
        assert!(result.is_ok());
        assert!(result.unwrap());

        let avail = checker.get_availability("server1").await.unwrap();
        assert_eq!(avail.status, HealthStatus::Healthy);
    }

    #[tokio::test]
    async fn test_server_availability() {
        let mut avail = ServerAvailability::new("server1".to_string());
        assert!(avail.is_available);

        avail.mark_unhealthy();
        avail.mark_unhealthy();
        avail.mark_unhealthy();
        assert!(!avail.is_available);

        avail.mark_healthy();
        assert!(avail.is_available);
        assert_eq!(avail.consecutive_failures, 0);
    }

    #[tokio::test]
    async fn test_is_disconnected() {
        let checker = HealthChecker::new();
        checker.register_server("server1").await;

        assert!(!checker.is_disconnected("server1").await);

        // get_availability returns a clone, so mutate it and write it back
        let mut avail = checker.get_availability("server1").await.unwrap();
        avail.mark_unhealthy();
        avail.mark_unhealthy();
        avail.mark_unhealthy();

        let mut availability = checker.availability.write().await;
        availability.insert("server1".to_string(), avail);
        drop(availability);

        assert!(checker.is_disconnected("server1").await);
    }

    #[tokio::test]
    async fn test_report_failure() {
        let checker = HealthChecker::new();
        checker.register_server("server1").await;

        let message = checker.report_failure("server1").await;
        assert!(message.contains("server1"));
    }

    #[tokio::test]
    async fn test_get_health_stats() {
        let checker = HealthChecker::new();
        checker.register_server("server1").await;
        checker.register_server("server2").await;

        let stats = checker.get_health_stats().await;
        assert_eq!(stats.len(), 2);
    }
}