nntp_proxy/health/mod.rs

1mod types;
2
3pub use types::{BackendHealth, HealthMetrics, HealthStatus};
4
5use crate::protocol::{DATE, ResponseParser};
6use crate::types::BackendId;
7use std::collections::HashMap;
8use std::sync::Arc;
9use std::time::Duration;
10use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
11use tokio::sync::RwLock;
12use tokio::time;
13
/// Configuration for health checking.
///
/// See [`HealthCheckConfig::default`] for the default values
/// (30s interval, 5s timeout, threshold of 3).
#[derive(Debug, Clone)]
pub struct HealthCheckConfig {
    /// Interval between health checks for each backend.
    pub check_interval: Duration,
    /// Timeout applied to each individual health check; an elapsed
    /// timeout is counted as a failure.
    pub check_timeout: Duration,
    /// Number of consecutive failures before a backend is marked unhealthy.
    pub unhealthy_threshold: u32,
}
24
25impl Default for HealthCheckConfig {
26    fn default() -> Self {
27        Self {
28            check_interval: Duration::from_secs(30),
29            check_timeout: Duration::from_secs(5),
30            unhealthy_threshold: 3,
31        }
32    }
33}
34
/// Health checker for backend connections.
///
/// Tracks per-backend health state behind an `Arc<RwLock<..>>` so the
/// background checking task and request-path readers can share it.
pub struct HealthChecker {
    /// Health status for each backend, keyed by [`BackendId`].
    backend_health: Arc<RwLock<HashMap<BackendId, BackendHealth>>>,
    /// Configuration (interval, timeout, failure threshold).
    config: HealthCheckConfig,
}
42
43impl HealthChecker {
44    /// Create a new health checker
45    pub fn new(config: HealthCheckConfig) -> Self {
46        Self {
47            backend_health: Arc::new(RwLock::new(HashMap::new())),
48            config,
49        }
50    }
51
52    /// Initialize health tracking for a backend
53    pub async fn register_backend(&self, backend_id: BackendId) {
54        let mut health = self.backend_health.write().await;
55        health.entry(backend_id).or_insert_with(BackendHealth::new);
56    }
57
58    /// Start the background health checking task
59    pub fn start_health_checks(
60        self: Arc<Self>,
61        providers: Vec<crate::pool::DeadpoolConnectionProvider>,
62    ) {
63        tokio::spawn(async move {
64            let mut interval = time::interval(self.config.check_interval);
65            loop {
66                interval.tick().await;
67
68                // Check each backend
69                for (i, provider) in providers.iter().enumerate() {
70                    let backend_id = BackendId::from_index(i);
71                    self.clone()
72                        .check_backend(provider.clone(), backend_id)
73                        .await;
74                }
75            }
76        });
77    }
78
79    /// Perform a health check on a single backend
80    async fn check_backend(
81        &self,
82        provider: crate::pool::DeadpoolConnectionProvider,
83        backend_id: BackendId,
84    ) {
85        // Check if this backend needs a check
86        {
87            let health = self.backend_health.read().await;
88            if let Some(backend_health) = health.get(&backend_id)
89                && !backend_health.needs_check(self.config.check_interval)
90            {
91                return;
92            }
93        }
94
95        // Perform the health check with timeout
96        let check_result = time::timeout(
97            self.config.check_timeout,
98            self.perform_health_check(provider.clone(), backend_id),
99        )
100        .await;
101
102        // Update health status
103        let mut health = self.backend_health.write().await;
104        let backend_health = health.entry(backend_id).or_insert_with(BackendHealth::new);
105
106        match check_result {
107            Ok(Ok(())) => {
108                backend_health.record_success();
109            }
110            Ok(Err(_)) | Err(_) => {
111                backend_health.record_failure(self.config.unhealthy_threshold);
112            }
113        }
114    }
115
116    /// Perform the actual health check by sending DATE command
117    async fn perform_health_check(
118        &self,
119        provider: crate::pool::DeadpoolConnectionProvider,
120        _backend_id: BackendId,
121    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
122        // Get a pooled connection
123        let mut conn = provider.get_pooled_connection().await?;
124
125        // Send DATE command
126        conn.write_all(DATE).await?;
127
128        // Read response
129        let mut reader = BufReader::new(&mut *conn);
130        // Pre-allocate for typical NNTP DATE response ("111 YYYYMMDDhhmmss\r\n" ~20-30 bytes)
131        let mut response = Vec::with_capacity(64);
132        reader.read_until(b'\n', &mut response).await?;
133
134        // Check if response indicates success (111 response)
135        if ResponseParser::is_response_code(&response, 111) {
136            Ok(())
137        } else {
138            Err("Unexpected response from DATE command".into())
139        }
140    }
141
142    /// Check if a backend is healthy
143    pub async fn is_healthy(&self, backend_id: BackendId) -> bool {
144        let health = self.backend_health.read().await;
145        health
146            .get(&backend_id)
147            .map(|h| h.status == HealthStatus::Healthy)
148            .unwrap_or(true) // Assume healthy if not tracked yet
149    }
150
151    /// Get health status for a specific backend
152    pub async fn get_backend_health(&self, backend_id: BackendId) -> Option<BackendHealth> {
153        let health = self.backend_health.read().await;
154        health.get(&backend_id).cloned()
155    }
156
157    /// Get aggregated health metrics
158    pub async fn get_metrics(&self) -> HealthMetrics {
159        let health = self.backend_health.read().await;
160
161        let mut metrics = HealthMetrics {
162            total_checks: health
163                .values()
164                .map(|h| h.total_successes + h.total_failures)
165                .sum(),
166            ..Default::default()
167        };
168
169        for backend_health in health.values() {
170            match backend_health.status {
171                HealthStatus::Healthy => metrics.healthy_count += 1,
172                HealthStatus::Unhealthy => metrics.unhealthy_count += 1,
173            }
174        }
175
176        metrics
177    }
178
179    /// Get all healthy backend IDs
180    pub async fn get_healthy_backends(&self) -> Vec<BackendId> {
181        let health = self.backend_health.read().await;
182        health
183            .iter()
184            .filter(|(_, h)| h.status == HealthStatus::Healthy)
185            .map(|(id, _)| *id)
186            .collect()
187    }
188}
189
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Build a checker with the default configuration.
    fn default_checker() -> HealthChecker {
        HealthChecker::new(HealthCheckConfig::default())
    }

    /// Record `count` consecutive failures for `id` against `threshold`,
    /// bypassing the network probe entirely.
    async fn record_failures(checker: &HealthChecker, id: BackendId, threshold: u32, count: u32) {
        let mut map = checker.backend_health.write().await;
        if let Some(state) = map.get_mut(&id) {
            for _ in 0..count {
                state.record_failure(threshold);
            }
        }
    }

    #[tokio::test]
    async fn test_health_checker_creation() {
        let checker = default_checker();

        // A fresh checker tracks nothing.
        let metrics = checker.get_metrics().await;
        assert_eq!(metrics.healthy_count, 0);
        assert_eq!(metrics.unhealthy_count, 0);
    }

    #[tokio::test]
    async fn test_register_backend() {
        let checker = default_checker();
        checker.register_backend(BackendId::from_index(0)).await;

        // Newly registered backends start out healthy.
        let metrics = checker.get_metrics().await;
        assert_eq!(metrics.healthy_count, 1);
        assert_eq!(metrics.unhealthy_count, 0);
    }

    #[tokio::test]
    async fn test_multiple_backend_registration() {
        let checker = default_checker();
        for idx in 0..3 {
            checker.register_backend(BackendId::from_index(idx)).await;
        }

        let metrics = checker.get_metrics().await;
        assert_eq!(metrics.healthy_count, 3);
        assert_eq!(metrics.unhealthy_count, 0);
    }

    #[tokio::test]
    async fn test_get_healthy_backends() {
        let checker = default_checker();
        let ids: Vec<BackendId> = (0..3).map(BackendId::from_index).collect();
        for id in &ids {
            checker.register_backend(*id).await;
        }

        assert_eq!(checker.get_healthy_backends().await.len(), 3);
    }

    #[tokio::test]
    async fn test_health_check_config_default() {
        let config = HealthCheckConfig::default();
        assert_eq!(config.check_interval, Duration::from_secs(30));
        assert_eq!(config.check_timeout, Duration::from_secs(5));
        assert_eq!(config.unhealthy_threshold, 3);
    }

    #[tokio::test]
    async fn test_health_check_config_custom() {
        let config = HealthCheckConfig {
            check_interval: Duration::from_secs(10),
            check_timeout: Duration::from_secs(2),
            unhealthy_threshold: 5,
        };

        // The checker must store the config it was given, unmodified.
        let checker = HealthChecker::new(config.clone());
        assert_eq!(checker.config.check_interval, Duration::from_secs(10));
        assert_eq!(checker.config.check_timeout, Duration::from_secs(2));
        assert_eq!(checker.config.unhealthy_threshold, 5);
    }

    #[tokio::test]
    async fn test_simulated_connection_failure() {
        let checker = HealthChecker::new(HealthCheckConfig {
            check_interval: Duration::from_millis(100),
            check_timeout: Duration::from_millis(50),
            unhealthy_threshold: 2,
        });
        let target = BackendId::from_index(0);
        checker.register_backend(target).await;

        // Two failures at threshold 2 flips the backend to unhealthy.
        record_failures(&checker, target, 2, 2).await;

        let metrics = checker.get_metrics().await;
        assert_eq!(metrics.unhealthy_count, 1);
        assert_eq!(metrics.healthy_count, 0);
    }

    #[tokio::test]
    async fn test_recovery_after_failures() {
        let checker = HealthChecker::new(HealthCheckConfig {
            check_interval: Duration::from_millis(100),
            check_timeout: Duration::from_millis(50),
            unhealthy_threshold: 2,
        });
        let target = BackendId::from_index(0);
        checker.register_backend(target).await;

        // Drive the backend unhealthy, then record one success to recover it.
        record_failures(&checker, target, 2, 2).await;
        {
            let mut map = checker.backend_health.write().await;
            if let Some(state) = map.get_mut(&target) {
                state.record_success();
            }
        }

        let metrics = checker.get_metrics().await;
        assert_eq!(metrics.healthy_count, 1);
        assert_eq!(metrics.unhealthy_count, 0);
    }

    #[tokio::test]
    async fn test_health_metrics_mixed_states() {
        let checker = default_checker();
        for idx in 0..5 {
            checker.register_backend(BackendId::from_index(idx)).await;
        }

        // Drive backends 1 and 3 unhealthy: three failures at threshold 3 each.
        record_failures(&checker, BackendId::from_index(1), 3, 3).await;
        record_failures(&checker, BackendId::from_index(3), 3, 3).await;

        let metrics = checker.get_metrics().await;
        assert_eq!(metrics.healthy_count, 3);
        assert_eq!(metrics.unhealthy_count, 2);

        let healthy = checker.get_healthy_backends().await;
        assert_eq!(healthy.len(), 3);
        for idx in [0, 2, 4] {
            assert!(healthy.contains(&BackendId::from_index(idx)));
        }
    }

    #[tokio::test]
    async fn test_backend_health_isolation() {
        let checker = default_checker();
        let first = BackendId::from_index(0);
        let second = BackendId::from_index(1);
        checker.register_backend(first).await;
        checker.register_backend(second).await;

        // Only the first backend fails; the second must remain healthy.
        record_failures(&checker, first, 3, 3).await;

        let healthy = checker.get_healthy_backends().await;
        assert_eq!(healthy.len(), 1);
        assert_eq!(healthy[0], second);
    }
}