use async_trait::async_trait;
use dashmap::DashMap;
use std::sync::Arc;
use std::time::Instant;
/// Result of a health probe, or the aggregate of several probes.
///
/// The non-healthy variants carry a free-form reason string.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum HealthStatus {
    /// No problem reported.
    Healthy,
    /// Ranked between `Healthy` and `Unhealthy` when statuses are aggregated.
    Degraded(String),
    /// The most severe status.
    Unhealthy(String),
}

impl HealthStatus {
    /// Returns `true` only for [`HealthStatus::Healthy`].
    pub fn is_healthy(&self) -> bool {
        match self {
            Self::Healthy => true,
            Self::Degraded(_) | Self::Unhealthy(_) => false,
        }
    }

    /// Short lowercase name of the variant, without the attached reason.
    pub fn label(&self) -> &'static str {
        use HealthStatus::{Degraded, Healthy, Unhealthy};

        match *self {
            Healthy => "healthy",
            Degraded(..) => "degraded",
            Unhealthy(..) => "unhealthy",
        }
    }
}
/// An asynchronous probe that reports a [`HealthStatus`].
///
/// The `Send + Sync + 'static` supertraits let implementations be stored as
/// `Arc<dyn HealthCheck>` trait objects, which is how `HealthRegistry` holds them.
#[async_trait]
pub trait HealthCheck: Send + Sync + 'static {
/// Runs the probe once and returns its current status.
async fn check(&self) -> HealthStatus;
}
/// Health check that compares a caller-supplied usage count with a soft limit.
pub struct CapacityCheck {
    /// Callback returning the current usage count; called on every check.
    current: Arc<dyn Fn() -> usize + Send + Sync>,
    /// Usage count at which the check stops reporting healthy/degraded.
    /// Never zero when the value was built through [`CapacityCheck::new`].
    soft_limit: usize,
}

impl CapacityCheck {
    /// Builds a check around the `current` callback and the given `soft_limit`.
    ///
    /// A `soft_limit` of zero is raised to one, so the percentage computation
    /// performed by the check can never divide by zero.
    pub fn new(current: impl Fn() -> usize + Send + Sync + 'static, soft_limit: usize) -> Self {
        let soft_limit = if soft_limit == 0 { 1 } else { soft_limit };
        let current: Arc<dyn Fn() -> usize + Send + Sync> = Arc::new(current);
        Self {
            current,
            soft_limit,
        }
    }
}
#[async_trait]
impl HealthCheck for CapacityCheck {
    /// Samples the usage callback once and classifies the result.
    ///
    /// * `Unhealthy` when the count has reached or passed the soft limit.
    /// * `Degraded` when the count is at 90% of the soft limit or more
    ///   (integer percentage, rounded down).
    /// * `Healthy` otherwise.
    async fn check(&self) -> HealthStatus {
        let n = (self.current)();
        let limit = self.soft_limit;

        // Test the hard condition first: it needs no arithmetic, so an
        // arbitrarily large count from the callback cannot overflow anything.
        if n >= limit {
            return HealthStatus::Unhealthy(format!("at capacity: {n}/{limit}"));
        }

        // Here `n < limit`, so `limit` is non-zero. Widen to u128 (lossless for
        // usize) because `n * 100` can still exceed `usize::MAX` when the limit
        // itself is very large.
        let pct = (n as u128) * 100 / (limit as u128);
        if pct >= 90 {
            HealthStatus::Degraded(format!("near capacity: {n}/{limit}"))
        } else {
            HealthStatus::Healthy
        }
    }
}
/// Collection of named health checks.
///
/// Cloning is cheap: the map of checks sits behind an `Arc`, so every clone
/// shares the same set of registered checks.
#[derive(Clone)]
pub struct HealthRegistry {
// Registered checks keyed by name; shared between clones through the `Arc`.
checks: Arc<DashMap<&'static str, Arc<dyn HealthCheck>>>,
// Instant captured when the registry was created; used for `uptime_secs`.
started_at: Instant,
}
impl Default for HealthRegistry {
fn default() -> Self {
Self::new()
}
}
/// Snapshot produced by one call to `HealthRegistry::run_all`.
#[derive(Debug, Clone)]
pub struct HealthReport {
/// Most severe status among all checks; `Healthy` when none are registered.
pub overall: HealthStatus,
/// Per-check results as `(name, status)` pairs, sorted by name.
pub checks: Vec<(&'static str, HealthStatus)>,
/// Whole seconds elapsed since the registry was created.
pub uptime_secs: u64,
}
impl HealthRegistry {
    /// Creates an empty registry and records the current instant as its start time.
    pub fn new() -> Self {
        Self {
            checks: Arc::new(DashMap::new()),
            started_at: Instant::now(),
        }
    }

    /// Registers `check` under `name`.
    ///
    /// The check is inserted into the shared map, so it is visible to every
    /// clone of this registry. Inserting under a name that is already present
    /// replaces the earlier check.
    pub fn register(&self, name: &'static str, check: impl HealthCheck) {
        self.checks.insert(name, Arc::new(check));
    }

    /// Runs every registered check, one after another, and aggregates the results.
    ///
    /// Checks are executed in ascending name order. The report's `checks` list
    /// is in that same order, and `overall` is the most severe status seen; when
    /// several checks share the highest severity, the one whose name sorts
    /// first supplies the reason.
    pub async fn run_all(&self) -> HealthReport {
        // Snapshot the map first so no map guard is held across an `.await`.
        let mut entries: Vec<(&'static str, Arc<dyn HealthCheck>)> = self
            .checks
            .iter()
            .map(|entry| (*entry.key(), Arc::clone(entry.value())))
            .collect();

        // Map iteration order is hash-dependent. Sorting before the run makes
        // both the execution order and the tie-break for `overall` stable.
        // Names are unique map keys, so an unstable sort is sufficient.
        entries.sort_unstable_by_key(|(name, _)| *name);

        let mut overall = HealthStatus::Healthy;
        let mut results = Vec::with_capacity(entries.len());
        for (name, check) in entries {
            let status = check.check().await;
            overall = worst(overall, &status);
            results.push((name, status));
        }

        HealthReport {
            overall,
            checks: results,
            uptime_secs: self.started_at.elapsed().as_secs(),
        }
    }
}
/// Folds `next` into `acc`, keeping whichever status is more severe.
///
/// Severity order is `Healthy` < `Degraded` < `Unhealthy`. On a tie `acc` is
/// returned unchanged, so the earlier reason string is retained; `next` is
/// cloned only when it actually replaces the accumulator.
fn worst(acc: HealthStatus, next: &HealthStatus) -> HealthStatus {
    let severity = |status: &HealthStatus| -> u8 {
        match status {
            HealthStatus::Healthy => 0,
            HealthStatus::Degraded(_) => 1,
            HealthStatus::Unhealthy(_) => 2,
        }
    };

    if severity(&acc) >= severity(next) {
        return acc;
    }
    next.clone()
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Test double that always reports the status it was constructed with.
    struct Fixed(HealthStatus);

    #[async_trait]
    impl HealthCheck for Fixed {
        async fn check(&self) -> HealthStatus {
            let Fixed(status) = self;
            status.clone()
        }
    }

    /// The overall status follows the most severe registered check.
    #[tokio::test]
    async fn aggregates_to_worst_status() {
        let registry = HealthRegistry::new();
        registry.register("a", Fixed(HealthStatus::Healthy));
        registry.register("b", Fixed(HealthStatus::Degraded("warm".into())));

        let first = registry.run_all().await;
        assert_eq!(first.overall, HealthStatus::Degraded("warm".into()));
        assert_eq!(first.checks.len(), 2);

        // Adding an unhealthy check must escalate the aggregate.
        registry.register("c", Fixed(HealthStatus::Unhealthy("down".into())));
        let second = registry.run_all().await;
        assert_eq!(second.overall, HealthStatus::Unhealthy("down".into()));
    }

    /// Walks a capacity check through healthy, degraded and unhealthy.
    #[tokio::test]
    async fn capacity_check_degrades_then_fails() {
        let load = Arc::new(AtomicUsize::new(0));
        let check = {
            let load = Arc::clone(&load);
            CapacityCheck::new(move || load.load(Ordering::Relaxed), 10)
        };

        assert!(check.check().await.is_healthy());

        load.store(9, Ordering::Relaxed);
        let at_ninety = check.check().await;
        assert!(matches!(at_ninety, HealthStatus::Degraded(_)));

        load.store(10, Ordering::Relaxed);
        let at_limit = check.check().await;
        assert!(matches!(at_limit, HealthStatus::Unhealthy(_)));
    }
}