use axum::{Json, extract::State, http::StatusCode};
use serde::Serialize;
use crate::handlers::AppState;
/// Top-level status value carried in the `status` field of [`HealthResponse`].
///
/// There is a single variant, so the field can only ever report `"OK"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)]
pub enum HealthStatus {
/// Serialized as the upper-case string `"OK"` rather than the variant name.
#[serde(rename = "OK")]
Ok,
}
/// Per-subsystem status used by the tracking, metrics-recording and
/// background-task fields of [`HealthResponse`].
///
/// Variants are serialized in lowercase (`"operational"`, `"degraded"`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)]
#[serde(rename_all = "lowercase")]
pub enum HealthTrackingStatus {
/// Assigned by [`HealthResponse::new`] when the failure count is zero.
Operational,
/// Assigned by [`HealthResponse::new`] when the failure count is greater than zero.
Degraded,
}
/// Body returned by the health [`handler`].
///
/// Field names are emitted verbatim as the serialized keys.
#[derive(Debug, Serialize)]
pub struct HealthResponse {
/// Overall status; always [`HealthStatus::Ok`] when built via [`HealthResponse::new`].
status: HealthStatus,
/// Derived from the health-tracking failure count.
health_tracking_status: HealthTrackingStatus,
/// Derived from the metrics-recording failure count.
metrics_recording_status: HealthTrackingStatus,
/// Derived from the background-task failure count.
background_task_status: HealthTrackingStatus,
/// Raw background-task failure count, passed through unchanged.
background_task_failures: u64,
}
impl HealthResponse {
pub fn new(
health_tracking_failures: u64,
metrics_recording_failures: u64,
background_task_failures: u64,
) -> Self {
let health_tracking_status = if health_tracking_failures > 0 {
HealthTrackingStatus::Degraded
} else {
HealthTrackingStatus::Operational
};
let metrics_recording_status = if metrics_recording_failures > 0 {
HealthTrackingStatus::Degraded
} else {
HealthTrackingStatus::Operational
};
let background_task_status = if background_task_failures > 0 {
HealthTrackingStatus::Degraded
} else {
HealthTrackingStatus::Operational
};
Self {
status: HealthStatus::Ok,
health_tracking_status,
metrics_recording_status,
background_task_status,
background_task_failures,
}
}
}
/// Axum handler that reports the service's health.
///
/// Reads the health-tracking, metrics-recording and background-task failure
/// counters from the application state's metrics and turns them into a
/// [`HealthResponse`]. The HTTP status is always `200 OK`; degraded
/// subsystems are reported only in the JSON body.
pub async fn handler(State(state): State<AppState>) -> (StatusCode, Json<HealthResponse>) {
    // Counters are read in the same order as the constructor's parameters.
    let body = HealthResponse::new(
        state.metrics().health_tracking_failures_count(),
        state.metrics().metrics_recording_failures_count(),
        state.metrics().background_task_failures_count(),
    );

    (StatusCode::OK, Json(body))
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::Config;
    use axum::extract::State;
    use std::sync::Arc;

    /// Builds an `AppState` from a minimal in-memory TOML config with one
    /// endpoint per model tier.
    fn create_test_state() -> AppState {
        // The fixture is kept flush-left so the string content stays unchanged.
        let toml = r#"
[server]
host = "127.0.0.1"
port = 3000
request_timeout_seconds = 30
[[models.fast]]
name = "fast-1"
base_url = "http://localhost:1234/v1"
max_tokens = 2048
temperature = 0.7
weight = 1.0
priority = 1
[[models.balanced]]
name = "balanced-1"
base_url = "http://localhost:1235/v1"
max_tokens = 4096
temperature = 0.7
weight = 1.0
priority = 1
[[models.deep]]
name = "deep-1"
base_url = "http://localhost:1236/v1"
max_tokens = 8192
temperature = 0.7
weight = 1.0
priority = 1
[routing]
strategy = "rule"
default_importance = "normal"
router_tier = "balanced"
"#;
        let config: Config = toml::from_str(toml).expect("should parse test config");
        AppState::new(Arc::new(config)).expect("should create AppState")
    }

    #[tokio::test]
    async fn test_health_handler_returns_ok() {
        let state = create_test_state();
        let (status, Json(response)) = handler(State(state)).await;
        assert_eq!(status, StatusCode::OK);
        let json = serde_json::to_value(&response).expect("Should serialize");
        assert_eq!(json["status"], "OK");
        assert_eq!(json["health_tracking_status"], "operational");
        assert_eq!(json["metrics_recording_status"], "operational");
    }

    #[tokio::test]
    async fn test_health_handler_shows_degraded_when_health_tracking_fails() {
        let state = create_test_state();
        state
            .metrics()
            .health_tracking_failure("test-endpoint", "unknown_endpoint");
        let (status, Json(response)) = handler(State(state)).await;
        assert_eq!(status, StatusCode::OK);
        let json = serde_json::to_value(&response).expect("Should serialize");
        assert_eq!(json["status"], "OK");
        assert_eq!(json["health_tracking_status"], "degraded");
        assert_eq!(json["metrics_recording_status"], "operational");
    }

    #[tokio::test]
    async fn test_health_handler_shows_degraded_when_metrics_recording_fails() {
        let state = create_test_state();
        state.metrics().metrics_recording_failure("record_request");
        let (status, Json(response)) = handler(State(state)).await;
        assert_eq!(status, StatusCode::OK);
        let json = serde_json::to_value(&response).expect("Should serialize");
        assert_eq!(json["status"], "OK");
        assert_eq!(json["health_tracking_status"], "operational");
        assert_eq!(json["metrics_recording_status"], "degraded");
    }

    #[tokio::test]
    async fn test_health_handler_shows_both_degraded_when_both_fail() {
        let state = create_test_state();
        state
            .metrics()
            .health_tracking_failure("test-endpoint", "unknown_endpoint");
        state.metrics().metrics_recording_failure("record_request");
        let (status, Json(response)) = handler(State(state)).await;
        assert_eq!(status, StatusCode::OK);
        let json = serde_json::to_value(&response).expect("Should serialize");
        assert_eq!(json["status"], "OK");
        assert_eq!(json["health_tracking_status"], "degraded");
        assert_eq!(json["metrics_recording_status"], "degraded");
    }

    // The background-task fields are checked through the constructor directly,
    // so these tests do not depend on how failures are recorded on the metrics.
    #[test]
    fn test_health_response_background_tasks_operational_without_failures() {
        let response = HealthResponse::new(0, 0, 0);
        let json = serde_json::to_value(&response).expect("Should serialize");
        assert_eq!(json["status"], "OK");
        assert_eq!(json["background_task_status"], "operational");
        assert_eq!(json["background_task_failures"], 0u64);
    }

    #[test]
    fn test_health_response_shows_degraded_when_background_tasks_fail() {
        let response = HealthResponse::new(0, 0, 3);
        let json = serde_json::to_value(&response).expect("Should serialize");
        assert_eq!(json["status"], "OK");
        assert_eq!(json["health_tracking_status"], "operational");
        assert_eq!(json["metrics_recording_status"], "operational");
        assert_eq!(json["background_task_status"], "degraded");
        assert_eq!(json["background_task_failures"], 3u64);
    }
}