1mod types;
2
3pub use types::{BackendHealth, HealthMetrics, HealthStatus};
4
5use crate::protocol::{DATE, ResponseParser};
6use crate::types::BackendId;
7use std::collections::HashMap;
8use std::sync::Arc;
9use std::time::Duration;
10use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
11use tokio::sync::RwLock;
12use tokio::time;
13
/// Tuning parameters for [`HealthChecker`].
///
/// Derives `PartialEq`/`Eq` so configurations can be compared directly
/// (e.g. in tests or when detecting configuration changes).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct HealthCheckConfig {
    /// Period of the background probe loop.
    ///
    /// Must be non-zero: `tokio::time::interval` panics on a zero period.
    pub check_interval: Duration,
    /// Upper bound on a single probe (obtaining a pooled connection plus the
    /// `DATE` round-trip). A probe that exceeds it is recorded as a failure.
    pub check_timeout: Duration,
    /// Threshold handed to `BackendHealth::record_failure` on every failed probe;
    /// once this many failures have been recorded the backend is reported unhealthy.
    pub unhealthy_threshold: u32,
}
24
25impl Default for HealthCheckConfig {
26 fn default() -> Self {
27 Self {
28 check_interval: Duration::from_secs(30),
29 check_timeout: Duration::from_secs(5),
30 unhealthy_threshold: 3,
31 }
32 }
33}
34
35pub struct HealthChecker {
37 backend_health: Arc<RwLock<HashMap<BackendId, BackendHealth>>>,
39 config: HealthCheckConfig,
41}
42
43impl HealthChecker {
44 pub fn new(config: HealthCheckConfig) -> Self {
46 Self {
47 backend_health: Arc::new(RwLock::new(HashMap::new())),
48 config,
49 }
50 }
51
52 pub async fn register_backend(&self, backend_id: BackendId) {
54 let mut health = self.backend_health.write().await;
55 health.entry(backend_id).or_insert_with(BackendHealth::new);
56 }
57
58 pub fn start_health_checks(
60 self: Arc<Self>,
61 providers: Vec<crate::pool::DeadpoolConnectionProvider>,
62 ) {
63 tokio::spawn(async move {
64 let mut interval = time::interval(self.config.check_interval);
65 loop {
66 interval.tick().await;
67
68 for (i, provider) in providers.iter().enumerate() {
70 let backend_id = BackendId::from_index(i);
71 self.clone()
72 .check_backend(provider.clone(), backend_id)
73 .await;
74 }
75 }
76 });
77 }
78
79 async fn check_backend(
81 &self,
82 provider: crate::pool::DeadpoolConnectionProvider,
83 backend_id: BackendId,
84 ) {
85 {
87 let health = self.backend_health.read().await;
88 if let Some(backend_health) = health.get(&backend_id)
89 && !backend_health.needs_check(self.config.check_interval)
90 {
91 return;
92 }
93 }
94
95 let check_result = time::timeout(
97 self.config.check_timeout,
98 self.perform_health_check(provider.clone(), backend_id),
99 )
100 .await;
101
102 let mut health = self.backend_health.write().await;
104 let backend_health = health.entry(backend_id).or_insert_with(BackendHealth::new);
105
106 match check_result {
107 Ok(Ok(())) => {
108 backend_health.record_success();
109 }
110 Ok(Err(_)) | Err(_) => {
111 backend_health.record_failure(self.config.unhealthy_threshold);
112 }
113 }
114 }
115
116 async fn perform_health_check(
118 &self,
119 provider: crate::pool::DeadpoolConnectionProvider,
120 _backend_id: BackendId,
121 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
122 let mut conn = provider.get_pooled_connection().await?;
124
125 conn.write_all(DATE).await?;
127
128 let mut reader = BufReader::new(&mut *conn);
130 let mut response = Vec::with_capacity(64);
132 reader.read_until(b'\n', &mut response).await?;
133
134 if ResponseParser::is_response_code(&response, 111) {
136 Ok(())
137 } else {
138 Err("Unexpected response from DATE command".into())
139 }
140 }
141
142 pub async fn is_healthy(&self, backend_id: BackendId) -> bool {
144 let health = self.backend_health.read().await;
145 health
146 .get(&backend_id)
147 .map(|h| h.status == HealthStatus::Healthy)
148 .unwrap_or(true) }
150
151 pub async fn get_backend_health(&self, backend_id: BackendId) -> Option<BackendHealth> {
153 let health = self.backend_health.read().await;
154 health.get(&backend_id).cloned()
155 }
156
157 pub async fn get_metrics(&self) -> HealthMetrics {
159 let health = self.backend_health.read().await;
160
161 let mut metrics = HealthMetrics {
162 total_checks: health
163 .values()
164 .map(|h| h.total_successes + h.total_failures)
165 .sum(),
166 ..Default::default()
167 };
168
169 for backend_health in health.values() {
170 match backend_health.status {
171 HealthStatus::Healthy => metrics.healthy_count += 1,
172 HealthStatus::Unhealthy => metrics.unhealthy_count += 1,
173 }
174 }
175
176 metrics
177 }
178
179 pub async fn get_healthy_backends(&self) -> Vec<BackendId> {
181 let health = self.backend_health.read().await;
182 health
183 .iter()
184 .filter(|(_, h)| h.status == HealthStatus::Healthy)
185 .map(|(id, _)| *id)
186 .collect()
187 }
188}
189
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Builds a config with short timings; none of these tests actually wait on them.
    fn fast_config(unhealthy_threshold: u32) -> HealthCheckConfig {
        HealthCheckConfig {
            check_interval: Duration::from_millis(100),
            check_timeout: Duration::from_millis(50),
            unhealthy_threshold,
        }
    }

    /// Records `count` failures (using `threshold`) against an already-registered backend.
    ///
    /// Panics if the backend was never registered, so a wrong id cannot silently turn
    /// the test setup into a no-op.
    async fn record_failures(
        checker: &HealthChecker,
        backend_id: BackendId,
        threshold: u32,
        count: usize,
    ) {
        let mut health = checker.backend_health.write().await;
        let backend_health = health
            .get_mut(&backend_id)
            .expect("backend must be registered before recording failures");
        for _ in 0..count {
            backend_health.record_failure(threshold);
        }
    }

    #[tokio::test]
    async fn test_health_checker_creation() {
        let checker = HealthChecker::new(HealthCheckConfig::default());

        let metrics = checker.get_metrics().await;
        assert_eq!(metrics.healthy_count, 0);
        assert_eq!(metrics.unhealthy_count, 0);
    }

    #[tokio::test]
    async fn test_register_backend() {
        let checker = HealthChecker::new(HealthCheckConfig::default());

        checker.register_backend(BackendId::from_index(0)).await;

        let metrics = checker.get_metrics().await;
        assert_eq!(metrics.healthy_count, 1);
        assert_eq!(metrics.unhealthy_count, 0);
    }

    #[tokio::test]
    async fn test_register_backend_keeps_existing_state() {
        let checker = HealthChecker::new(HealthCheckConfig::default());
        let backend_id = BackendId::from_index(0);

        checker.register_backend(backend_id).await;
        record_failures(&checker, backend_id, 3, 3).await;
        // Re-registering must neither reset the record nor add a duplicate entry.
        checker.register_backend(backend_id).await;

        let metrics = checker.get_metrics().await;
        assert_eq!(metrics.healthy_count, 0);
        assert_eq!(metrics.unhealthy_count, 1);
        assert!(!checker.is_healthy(backend_id).await);
    }

    #[tokio::test]
    async fn test_multiple_backend_registration() {
        let checker = HealthChecker::new(HealthCheckConfig::default());

        for i in 0..3 {
            checker.register_backend(BackendId::from_index(i)).await;
        }

        let metrics = checker.get_metrics().await;
        assert_eq!(metrics.healthy_count, 3);
        assert_eq!(metrics.unhealthy_count, 0);
    }

    #[tokio::test]
    async fn test_get_healthy_backends() {
        let checker = HealthChecker::new(HealthCheckConfig::default());

        let backend_ids = [
            BackendId::from_index(0),
            BackendId::from_index(1),
            BackendId::from_index(2),
        ];
        for id in &backend_ids {
            checker.register_backend(*id).await;
        }

        let healthy = checker.get_healthy_backends().await;
        assert_eq!(healthy.len(), 3);
    }

    #[tokio::test]
    async fn test_is_healthy_and_lookup() {
        let checker = HealthChecker::new(HealthCheckConfig::default());
        let backend_id = BackendId::from_index(0);

        // Untracked backends have no record but are reported healthy.
        assert!(checker.get_backend_health(backend_id).await.is_none());
        assert!(checker.is_healthy(backend_id).await);

        checker.register_backend(backend_id).await;
        assert!(checker.get_backend_health(backend_id).await.is_some());
        assert!(checker.is_healthy(backend_id).await);

        record_failures(&checker, backend_id, 3, 3).await;
        assert!(!checker.is_healthy(backend_id).await);
    }

    #[test]
    fn test_health_check_config_default() {
        let config = HealthCheckConfig::default();
        assert_eq!(config.check_interval, Duration::from_secs(30));
        assert_eq!(config.check_timeout, Duration::from_secs(5));
        assert_eq!(config.unhealthy_threshold, 3);
    }

    #[test]
    fn test_health_check_config_custom() {
        let config = HealthCheckConfig {
            check_interval: Duration::from_secs(10),
            check_timeout: Duration::from_secs(2),
            unhealthy_threshold: 5,
        };

        let checker = HealthChecker::new(config);
        assert_eq!(checker.config.check_interval, Duration::from_secs(10));
        assert_eq!(checker.config.check_timeout, Duration::from_secs(2));
        assert_eq!(checker.config.unhealthy_threshold, 5);
    }

    #[tokio::test]
    async fn test_simulated_connection_failure() {
        let checker = HealthChecker::new(fast_config(2));
        let backend_id = BackendId::from_index(0);

        checker.register_backend(backend_id).await;
        // Reaching the threshold of 2 must flip the backend to unhealthy.
        record_failures(&checker, backend_id, 2, 2).await;

        let metrics = checker.get_metrics().await;
        assert_eq!(metrics.unhealthy_count, 1);
        assert_eq!(metrics.healthy_count, 0);
    }

    #[tokio::test]
    async fn test_recovery_after_failures() {
        let checker = HealthChecker::new(fast_config(2));
        let backend_id = BackendId::from_index(0);

        checker.register_backend(backend_id).await;
        record_failures(&checker, backend_id, 2, 2).await;
        {
            // A single success must bring the backend back to healthy.
            let mut health = checker.backend_health.write().await;
            health
                .get_mut(&backend_id)
                .expect("backend was registered above")
                .record_success();
        }

        let metrics = checker.get_metrics().await;
        assert_eq!(metrics.healthy_count, 1);
        assert_eq!(metrics.unhealthy_count, 0);
    }

    #[tokio::test]
    async fn test_health_metrics_mixed_states() {
        let checker = HealthChecker::new(HealthCheckConfig::default());

        for i in 0..5 {
            checker.register_backend(BackendId::from_index(i)).await;
        }
        // Fail backends 1 and 3 up to the default threshold of 3.
        record_failures(&checker, BackendId::from_index(1), 3, 3).await;
        record_failures(&checker, BackendId::from_index(3), 3, 3).await;

        let metrics = checker.get_metrics().await;
        assert_eq!(metrics.healthy_count, 3);
        assert_eq!(metrics.unhealthy_count, 2);

        let healthy = checker.get_healthy_backends().await;
        assert_eq!(healthy.len(), 3);
        assert!(healthy.contains(&BackendId::from_index(0)));
        assert!(healthy.contains(&BackendId::from_index(2)));
        assert!(healthy.contains(&BackendId::from_index(4)));
    }

    #[tokio::test]
    async fn test_backend_health_isolation() {
        let checker = HealthChecker::new(HealthCheckConfig::default());

        let backend1 = BackendId::from_index(0);
        let backend2 = BackendId::from_index(1);

        checker.register_backend(backend1).await;
        checker.register_backend(backend2).await;
        // Failing one backend must not affect the other.
        record_failures(&checker, backend1, 3, 3).await;

        let healthy = checker.get_healthy_backends().await;
        assert_eq!(healthy.len(), 1);
        assert_eq!(healthy[0], backend2);
    }
}