1use std::collections::HashMap;
28use std::future::Future;
29use std::pin::Pin;
30use std::sync::Arc;
31use std::time::{Duration, Instant};
32use tokio::sync::RwLock;
33
/// Health state reported for a single component.
///
/// Variants are declared best-to-worst. [`HealthStatus::score`] maps them onto a
/// numeric scale on which a higher value means healthier. Besides the usual
/// value traits, the type derives `Hash` so statuses can key a `HashMap` or
/// populate a `HashSet` (e.g. when tallying results by status).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum HealthStatus {
    /// Best state (score 2).
    Healthy,
    /// Intermediate state (score 1).
    Degraded,
    /// Worst state (score 0).
    Unhealthy,
}

impl HealthStatus {
    /// Numeric rank of this status: `Healthy` = 2, `Degraded` = 1, `Unhealthy` = 0.
    ///
    /// Higher is healthier, so the minimum score across several statuses
    /// identifies the worst of them. Usable in `const` contexts.
    #[must_use]
    pub const fn score(&self) -> u8 {
        match *self {
            Self::Unhealthy => 0,
            Self::Degraded => 1,
            Self::Healthy => 2,
        }
    }
}
56
57#[derive(Debug, Clone)]
59pub struct HealthCheckResult {
60 pub component: String,
62 pub status: HealthStatus,
64 pub message: Option<String>,
66 pub check_duration: Duration,
68 pub timestamp: Instant,
70}
71
72impl HealthCheckResult {
73 pub fn new(component: String, status: HealthStatus) -> Self {
75 Self {
76 component,
77 status,
78 message: None,
79 check_duration: Duration::ZERO,
80 timestamp: Instant::now(),
81 }
82 }
83
84 pub fn with_message(mut self, message: impl Into<String>) -> Self {
86 self.message = Some(message.into());
87 self
88 }
89
90 pub fn with_duration(mut self, duration: Duration) -> Self {
92 self.check_duration = duration;
93 self
94 }
95}
96
97type HealthCheckFn = Box<
99 dyn Fn() -> Pin<Box<dyn Future<Output = Result<HealthStatus, String>> + Send>> + Send + Sync,
100>;
101
102pub struct HealthChecker {
104 checks: Arc<RwLock<HashMap<String, HealthCheckFn>>>,
105}
106
107impl HealthChecker {
108 #[must_use]
110 pub fn new() -> Self {
111 Self {
112 checks: Arc::new(RwLock::new(HashMap::new())),
113 }
114 }
115
116 pub async fn register<F, Fut>(&mut self, component: impl Into<String>, check: F)
118 where
119 F: Fn() -> Fut + Send + Sync + 'static,
120 Fut: Future<Output = Result<HealthStatus, String>> + Send + 'static,
121 {
122 let component_name = component.into();
123 let check_fn: HealthCheckFn = Box::new(move || Box::pin(check()));
124 self.checks.write().await.insert(component_name, check_fn);
125 }
126
127 pub async fn unregister(&mut self, component: &str) -> bool {
129 self.checks.write().await.remove(component).is_some()
130 }
131
132 pub async fn check(&self, component: &str) -> Option<HealthCheckResult> {
134 let checks = self.checks.read().await;
135 let check_fn = checks.get(component)?;
136
137 let start = Instant::now();
138 let result = check_fn().await;
139 let duration = start.elapsed();
140
141 let (status, message) = match result {
142 Ok(status) => (status, None),
143 Err(msg) => (HealthStatus::Unhealthy, Some(msg)),
144 };
145
146 Some(
147 HealthCheckResult::new(component.to_string(), status)
148 .with_duration(duration)
149 .with_message(message.unwrap_or_default()),
150 )
151 }
152
153 pub async fn check_all(&self) -> HealthReport {
155 let checks = self.checks.read().await;
156 let mut results = Vec::new();
157
158 for component in checks.keys() {
159 if let Some(result) = self.check(component).await {
160 results.push(result);
161 }
162 }
163
164 HealthReport { results }
165 }
166
167 #[must_use]
169 pub async fn count(&self) -> usize {
170 self.checks.read().await.len()
171 }
172}
173
174impl Default for HealthChecker {
175 fn default() -> Self {
176 Self::new()
177 }
178}
179
180#[derive(Debug, Clone)]
182pub struct HealthReport {
183 results: Vec<HealthCheckResult>,
184}
185
186impl HealthReport {
187 #[must_use]
189 #[inline]
190 pub fn results(&self) -> &[HealthCheckResult] {
191 &self.results
192 }
193
194 #[must_use]
196 #[inline]
197 pub fn overall_status(&self) -> HealthStatus {
198 if self.results.is_empty() {
199 return HealthStatus::Healthy;
200 }
201
202 if self
204 .results
205 .iter()
206 .any(|r| r.status == HealthStatus::Unhealthy)
207 {
208 return HealthStatus::Unhealthy;
209 }
210
211 if self
213 .results
214 .iter()
215 .any(|r| r.status == HealthStatus::Degraded)
216 {
217 return HealthStatus::Degraded;
218 }
219
220 HealthStatus::Healthy
221 }
222
223 #[must_use]
225 #[inline]
226 pub fn healthy_count(&self) -> usize {
227 self.results
228 .iter()
229 .filter(|r| r.status == HealthStatus::Healthy)
230 .count()
231 }
232
233 #[must_use]
235 #[inline]
236 pub fn degraded_count(&self) -> usize {
237 self.results
238 .iter()
239 .filter(|r| r.status == HealthStatus::Degraded)
240 .count()
241 }
242
243 #[must_use]
245 #[inline]
246 pub fn unhealthy_count(&self) -> usize {
247 self.results
248 .iter()
249 .filter(|r| r.status == HealthStatus::Unhealthy)
250 .count()
251 }
252
253 #[must_use]
255 #[inline]
256 pub fn total_duration(&self) -> Duration {
257 self.results.iter().map(|r| r.check_duration).sum()
258 }
259}
260
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a checker holding one constant-status check per `(name, status)` pair.
    async fn checker_with(specs: &[(&'static str, HealthStatus)]) -> HealthChecker {
        let mut checker = HealthChecker::new();
        for &(name, status) in specs {
            checker.register(name, move || async move { Ok(status) }).await;
        }
        checker
    }

    #[test]
    fn test_health_status_score() {
        let expected = [
            (HealthStatus::Healthy, 2),
            (HealthStatus::Degraded, 1),
            (HealthStatus::Unhealthy, 0),
        ];
        for (status, score) in expected {
            assert_eq!(status.score(), score);
        }
    }

    #[test]
    fn test_health_check_result() {
        let result = HealthCheckResult::new("storage".to_string(), HealthStatus::Healthy)
            .with_message("All systems operational");

        assert_eq!(result.component, "storage");
        assert_eq!(result.status, HealthStatus::Healthy);
        assert_eq!(result.message.as_deref(), Some("All systems operational"));
    }

    #[tokio::test]
    async fn test_health_checker_register() {
        let checker = checker_with(&[("test", HealthStatus::Healthy)]).await;
        assert_eq!(checker.count().await, 1);
    }

    #[tokio::test]
    async fn test_health_checker_unregister() {
        let mut checker = checker_with(&[("test", HealthStatus::Healthy)]).await;

        assert!(checker.unregister("test").await);
        assert_eq!(checker.count().await, 0);
        assert!(!checker.unregister("nonexistent").await);
    }

    #[tokio::test]
    async fn test_health_checker_check() {
        let checker = checker_with(&[("storage", HealthStatus::Healthy)]).await;

        let found = checker.check("storage").await;
        assert!(found.is_some());

        let result = found.unwrap();
        assert_eq!(result.component, "storage");
        assert_eq!(result.status, HealthStatus::Healthy);
    }

    #[tokio::test]
    async fn test_health_checker_check_all() {
        let mut checker = checker_with(&[
            ("storage", HealthStatus::Healthy),
            ("network", HealthStatus::Degraded),
        ])
        .await;
        checker
            .register("database", || async { Err("Connection failed".to_string()) })
            .await;

        let report = checker.check_all().await;
        assert_eq!(report.results().len(), 3);
        assert_eq!(report.healthy_count(), 1);
        assert_eq!(report.degraded_count(), 1);
        assert_eq!(report.unhealthy_count(), 1);
        assert_eq!(report.overall_status(), HealthStatus::Unhealthy);
    }

    #[tokio::test]
    async fn test_health_report_overall_status() {
        // (status of the "network" check, expected overall status);
        // "storage" is always Healthy.
        let cases = [
            (HealthStatus::Healthy, HealthStatus::Healthy),
            (HealthStatus::Degraded, HealthStatus::Degraded),
            (HealthStatus::Unhealthy, HealthStatus::Unhealthy),
        ];

        for (network, expected) in cases {
            let checker =
                checker_with(&[("storage", HealthStatus::Healthy), ("network", network)]).await;
            let report = checker.check_all().await;
            assert_eq!(report.overall_status(), expected);
        }
    }

    #[tokio::test]
    async fn test_health_report_empty() {
        let report = HealthChecker::new().check_all().await;

        assert!(report.results().is_empty());
        assert_eq!(report.overall_status(), HealthStatus::Healthy);
        assert_eq!(report.total_duration(), Duration::ZERO);
    }
}