use async_trait::async_trait;
use axum::{http::StatusCode, response::IntoResponse, Json};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Instant;
/// Coarse health state, used both for individual checks and for the aggregated report.
///
/// Serialized in lowercase: `"healthy"`, `"degraded"`, `"unhealthy"`.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum Status {
    /// Maps to `200 OK`.
    Healthy,
    /// Also maps to `200 OK`; only the body distinguishes it from `Healthy`.
    Degraded,
    /// Maps to `503 Service Unavailable`.
    Unhealthy,
}

impl Status {
    /// Returns the HTTP status code used when responding with this status.
    ///
    /// Only [`Status::Unhealthy`] produces `503 Service Unavailable`; both
    /// [`Status::Healthy`] and [`Status::Degraded`] produce `200 OK`.
    pub fn status_code(&self) -> StatusCode {
        match *self {
            Self::Unhealthy => StatusCode::SERVICE_UNAVAILABLE,
            Self::Healthy | Self::Degraded => StatusCode::OK,
        }
    }
}
/// Outcome of a single health check.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CheckResult {
    /// Health state reported by the check.
    pub status: Status,
    /// Optional human-readable detail; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub message: Option<String>,
    /// Duration in milliseconds. Every constructor sets it to 0;
    /// [`CheckResult::with_duration`] overrides it.
    pub duration_ms: u64,
}

impl CheckResult {
    /// Shared constructor behind the public builders; always starts at zero duration.
    fn from_parts(status: Status, message: Option<String>) -> Self {
        Self {
            status,
            message,
            duration_ms: 0,
        }
    }

    /// A [`Status::Healthy`] result with no message.
    pub fn healthy() -> Self {
        Self::from_parts(Status::Healthy, None)
    }

    /// A [`Status::Healthy`] result carrying an informational message.
    pub fn healthy_with_message(message: impl Into<String>) -> Self {
        Self::from_parts(Status::Healthy, Some(message.into()))
    }

    /// A [`Status::Unhealthy`] result carrying an explanatory message.
    pub fn unhealthy(message: impl Into<String>) -> Self {
        Self::from_parts(Status::Unhealthy, Some(message.into()))
    }

    /// A [`Status::Degraded`] result carrying an explanatory message.
    pub fn degraded(message: impl Into<String>) -> Self {
        Self::from_parts(Status::Degraded, Some(message.into()))
    }

    /// Returns this result with `duration_ms` replaced; other fields are untouched.
    pub fn with_duration(self, duration_ms: u64) -> Self {
        Self {
            duration_ms,
            ..self
        }
    }
}
/// Aggregated health report: overall status, service version, and per-check results.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HealthStatus {
    /// Worst status among `checks` (`Healthy` when there are none); recomputed by
    /// [`HealthStatus::with_check`].
    pub status: Status,
    /// Version string passed to [`HealthStatus::new`].
    pub version: String,
    /// Individual check results keyed by check name.
    pub checks: HashMap<String, CheckResult>,
}

impl HealthStatus {
    /// Creates an empty, healthy report for the given version.
    pub fn new(version: impl Into<String>) -> Self {
        let version = version.into();
        Self {
            status: Status::Healthy,
            version,
            checks: HashMap::new(),
        }
    }

    /// Adds (or replaces, if `name` already exists) a check result and refreshes
    /// the overall status.
    pub fn with_check(mut self, name: impl Into<String>, result: CheckResult) -> Self {
        let key = name.into();
        self.checks.insert(key, result);
        self.update_overall_status();
        self
    }

    /// Sets `status` to the most severe status among all checks.
    ///
    /// Severity order: `Unhealthy` > `Degraded` > `Healthy`. An empty map leaves it `Healthy`.
    fn update_overall_status(&mut self) {
        let mut worst = Status::Healthy;
        for result in self.checks.values() {
            match result.status {
                Status::Unhealthy => {
                    // Nothing is more severe, so stop scanning.
                    worst = Status::Unhealthy;
                    break;
                }
                Status::Degraded => worst = Status::Degraded,
                Status::Healthy => {}
            }
        }
        self.status = worst;
    }

    /// HTTP status code corresponding to the overall status.
    pub fn status_code(&self) -> StatusCode {
        self.status.status_code()
    }
}
/// Renders the report as a JSON body, using the overall status to choose the HTTP code.
impl IntoResponse for HealthStatus {
    fn into_response(self) -> axum::response::Response {
        // `Status` is `Copy`, so the code can be read before `self` moves into the body.
        let code = self.status.status_code();
        let body = Json(self);
        (code, body).into_response()
    }
}
/// A named, asynchronous health probe.
///
/// The `Send + Sync` bound allows implementors to be stored as `Arc<dyn HealthCheck>`.
#[async_trait]
pub trait HealthCheck: Send + Sync {
    /// Key under which this check's result appears in the report.
    fn name(&self) -> &str;

    /// Runs the probe and returns its result.
    ///
    /// `HealthChecker::check_health` replaces the returned `duration_ms` with its own
    /// measurement.
    async fn check(&self) -> CheckResult;
}
/// Runs a set of registered [`HealthCheck`]s and aggregates them into a [`HealthStatus`].
pub struct HealthChecker {
    checks: Vec<Arc<dyn HealthCheck>>,
    version: String,
}

impl HealthChecker {
    /// Creates a checker with no checks; `version` is copied into every report.
    pub fn new(version: impl Into<String>) -> Self {
        Self {
            checks: Vec::new(),
            version: version.into(),
        }
    }

    /// Registers a check, builder style.
    ///
    /// Results are keyed by [`HealthCheck::name`]. If two checks share a name, the one
    /// registered later overwrites the earlier one in the report.
    #[must_use]
    pub fn add_check(mut self, check: Arc<dyn HealthCheck>) -> Self {
        self.checks.push(check);
        self
    }

    /// Runs every registered check and collects the results into a report.
    ///
    /// Checks are awaited one after another in registration order, so total latency is
    /// the sum of the individual check times. Each result's `duration_ms` is replaced
    /// with the elapsed wall-clock time measured around its `check()` call.
    pub async fn check_health(&self) -> HealthStatus {
        let mut status = HealthStatus::new(self.version.as_str());
        for check in &self.checks {
            let start = Instant::now();
            let result = check.check().await;
            // `as_millis` returns u128; clamp before narrowing so the cast can never
            // silently truncate (it saturates at u64::MAX instead).
            let elapsed_ms = start.elapsed().as_millis().min(u128::from(u64::MAX)) as u64;
            status = status.with_check(check.name(), result.with_duration(elapsed_ms));
        }
        status
    }
}
/// A [`HealthCheck`] whose `check` always returns [`CheckResult::healthy`].
pub struct AlwaysHealthyCheck {
    /// Name reported through [`HealthCheck::name`].
    name: String,
}

impl AlwaysHealthyCheck {
    /// Creates the check under the given name.
    pub fn new(name: impl Into<String>) -> Self {
        let name = name.into();
        Self { name }
    }
}

#[async_trait]
impl HealthCheck for AlwaysHealthyCheck {
    fn name(&self) -> &str {
        self.name.as_str()
    }

    /// Unconditionally healthy, with no message.
    async fn check(&self) -> CheckResult {
        CheckResult::healthy()
    }
}
/// Axum handler that runs all registered checks and returns the aggregated report.
///
/// The body is the JSON-serialized [`HealthStatus`]. The HTTP status is `503` when any
/// check is unhealthy, and `200` otherwise.
pub async fn health_handler(
    state: axum::extract::State<Arc<HealthChecker>>,
) -> impl IntoResponse {
    let axum::extract::State(checker) = state;
    checker.check_health().await
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a "1.0.0" report by feeding each `(name, result)` pair through `with_check`.
    fn report<'a>(checks: impl IntoIterator<Item = (&'a str, CheckResult)>) -> HealthStatus {
        checks
            .into_iter()
            .fold(HealthStatus::new("1.0.0"), |acc, (name, result)| {
                acc.with_check(name, result)
            })
    }

    #[test]
    fn test_status_code() {
        let cases = [
            (Status::Healthy, StatusCode::OK),
            (Status::Degraded, StatusCode::OK),
            (Status::Unhealthy, StatusCode::SERVICE_UNAVAILABLE),
        ];
        for (status, expected) in &cases {
            assert_eq!(status.status_code(), *expected, "{:?}", status);
        }
    }

    #[test]
    fn test_check_result_builders() {
        let ok = CheckResult::healthy();
        assert_eq!(ok.status, Status::Healthy);
        assert_eq!(ok.message, None);

        let down = CheckResult::unhealthy("Database down");
        assert_eq!(down.status, Status::Unhealthy);
        assert_eq!(down.message.as_deref(), Some("Database down"));

        assert_eq!(
            CheckResult::degraded("Cache unavailable").status,
            Status::Degraded
        );
    }

    #[test]
    fn test_health_status_overall() {
        let all_ok = report([
            ("db", CheckResult::healthy()),
            ("cache", CheckResult::healthy()),
        ]);
        assert_eq!(all_ok.status, Status::Healthy);

        let slow = report([
            ("db", CheckResult::healthy()),
            ("cache", CheckResult::degraded("Slow response")),
        ]);
        assert_eq!(slow.status, Status::Degraded);

        let broken = report([
            ("db", CheckResult::unhealthy("Connection refused")),
            ("cache", CheckResult::healthy()),
        ]);
        assert_eq!(broken.status, Status::Unhealthy);
    }

    #[tokio::test]
    async fn test_health_checker() {
        let health = HealthChecker::new("1.0.0")
            .add_check(Arc::new(AlwaysHealthyCheck::new("test")))
            .check_health()
            .await;
        assert_eq!(health.status, Status::Healthy);
        assert_eq!(health.version, "1.0.0");
        assert!(health.checks.contains_key("test"));
    }

    #[tokio::test]
    async fn test_always_healthy_check() {
        let probe = AlwaysHealthyCheck::new("test");
        assert_eq!(probe.name(), "test");
        assert_eq!(probe.check().await.status, Status::Healthy);
    }
}