use crate::error::{Error, Result};
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use std::path::PathBuf;
use std::process::Stdio;
use std::time::Duration;
use tokio::process::Command;
use tracing::{debug, info, warn};
/// Configuration describing what a [`HealthCheck`] probes and how.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HealthCheckConfig {
    /// The probe to run: an HTTP endpoint or an executable script.
    pub check_type: HealthCheckType,
    /// Upper bound on a single probe; exceeding it counts as a failure.
    pub timeout: Duration,
    /// Intended delay between probes. Not read by `HealthCheck::check` itself;
    /// presumably consumed by a scheduler elsewhere — confirm against callers.
    pub interval: Duration,
    /// Consecutive failures after which the state is forced to `Unhealthy`.
    pub retries: u32,
    /// When `false`, `HealthCheck::check` returns the status without probing.
    pub enabled: bool,
}
/// The kind of probe a health check performs.
///
/// Serialized as an internally tagged object whose `"type"` field is the
/// lowercased variant name, e.g. `{"type":"http","url":"..."}`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum HealthCheckType {
    /// Send a GET request to `url`; a 2xx response counts as healthy.
    Http { url: String },
    /// Execute `path` with no arguments; a successful exit status counts as healthy.
    Script { path: PathBuf },
}
/// Running record of health-check outcomes for a single target.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HealthStatus {
    /// The currently reported state.
    pub state: HealthState,
    /// Start time of the most recent enabled check, if any has run.
    pub last_check: Option<DateTime<Utc>>,
    /// Start time of the most recent successful check, if any.
    pub last_success: Option<DateTime<Utc>>,
    /// Failures since the last success; reset to 0 when a check passes.
    pub consecutive_failures: u32,
    /// Number of checks performed while enabled.
    pub total_checks: u64,
    /// Message from the most recent failure; cleared when a check passes.
    pub error_message: Option<String>,
}
/// Coarse health of a checked target.
///
/// Serialized as `"healthy"`, `"unhealthy"`, or `"unknown"`.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum HealthState {
    /// Set on a successful check and kept while fewer than `retries`
    /// consecutive failures have followed it.
    Healthy,
    /// Failure threshold reached, or a failure occurred before any success.
    Unhealthy,
    /// No enabled check has run yet; this is the default.
    Unknown,
}
/// Runs the probe described by a [`HealthCheckConfig`] and tracks the
/// resulting [`HealthStatus`].
pub struct HealthCheck {
    /// Probe configuration; no public setter is exposed.
    config: HealthCheckConfig,
    /// Outcome of the checks performed so far.
    status: HealthStatus,
    /// Client used for HTTP probes; `HealthCheck::new` always sets it to `Some`.
    http_client: Option<reqwest::Client>,
}
/// Default per-probe timeout, in seconds.
const DEFAULT_TIMEOUT_SECS: u64 = 30;
/// Default delay between probes, in seconds.
const DEFAULT_INTERVAL_SECS: u64 = 30;
/// Default number of consecutive failures tolerated before `Unhealthy`.
const DEFAULT_RETRIES: u32 = 3;
impl Default for HealthCheckConfig {
    /// A disabled HTTP check against `http://localhost:9615/health` using the
    /// `DEFAULT_*` timeout, interval, and retry values.
    fn default() -> Self {
        let check_type = HealthCheckType::Http {
            url: String::from("http://localhost:9615/health"),
        };
        Self {
            enabled: false,
            retries: DEFAULT_RETRIES,
            interval: Duration::from_secs(DEFAULT_INTERVAL_SECS),
            timeout: Duration::from_secs(DEFAULT_TIMEOUT_SECS),
            check_type,
        }
    }
}
impl Default for HealthStatus {
fn default() -> Self {
Self {
state: HealthState::Unknown,
last_check: None,
last_success: None,
consecutive_failures: 0,
total_checks: 0,
error_message: None,
}
}
}
impl Default for HealthState {
fn default() -> Self {
Self::Unknown
}
}
impl HealthCheckConfig {
    /// Default configuration (disabled, default timings) probing `url` over HTTP.
    pub fn http<S: Into<String>>(url: S) -> Self {
        let check_type = HealthCheckType::Http { url: url.into() };
        Self {
            check_type,
            ..Self::default()
        }
    }

    /// Default configuration (disabled, default timings) running the script at `path`.
    pub fn script<P: Into<PathBuf>>(path: P) -> Self {
        let check_type = HealthCheckType::Script { path: path.into() };
        Self {
            check_type,
            ..Self::default()
        }
    }

    /// Returns this configuration with the per-probe timeout replaced.
    pub fn timeout(self, timeout: Duration) -> Self {
        Self { timeout, ..self }
    }

    /// Returns this configuration with the probe interval replaced.
    pub fn interval(self, interval: Duration) -> Self {
        Self { interval, ..self }
    }

    /// Returns this configuration with the failure threshold replaced.
    pub fn retries(self, retries: u32) -> Self {
        Self { retries, ..self }
    }

    /// Returns this configuration with checking switched on or off.
    pub fn enabled(self, enabled: bool) -> Self {
        Self { enabled, ..self }
    }
}
impl HealthStatus {
    /// Whether the current state is [`HealthState::Healthy`].
    pub fn is_healthy(&self) -> bool {
        self.state == HealthState::Healthy
    }

    /// Whether the current state is [`HealthState::Unhealthy`].
    pub fn is_unhealthy(&self) -> bool {
        self.state == HealthState::Unhealthy
    }

    /// Whether the current state is [`HealthState::Unknown`].
    pub fn is_unknown(&self) -> bool {
        self.state == HealthState::Unknown
    }

    /// Wall-clock time elapsed since the last successful check, or `None` if
    /// none has succeeded. May be negative if the system clock moved backwards.
    pub fn time_since_last_success(&self) -> Option<chrono::Duration> {
        let succeeded_at = self.last_success?;
        Some(Utc::now() - succeeded_at)
    }

    /// Wall-clock time elapsed since the last check, or `None` if none has run.
    /// May be negative if the system clock moved backwards.
    pub fn time_since_last_check(&self) -> Option<chrono::Duration> {
        let checked_at = self.last_check?;
        Some(Utc::now() - checked_at)
    }
}
impl HealthCheck {
pub fn new(config: HealthCheckConfig) -> Self {
Self {
config,
status: HealthStatus::default(),
http_client: Some(reqwest::Client::new()),
}
}
pub fn status(&self) -> &HealthStatus {
&self.status
}
pub fn config(&self) -> &HealthCheckConfig {
&self.config
}
pub fn is_enabled(&self) -> bool {
self.config.enabled
}
pub async fn check(&mut self) -> Result<&HealthStatus> {
if !self.config.enabled {
debug!("Health checks are disabled");
return Ok(&self.status);
}
debug!("Performing health check: {:?}", self.config.check_type);
let now = Utc::now();
self.status.last_check = Some(now);
self.status.total_checks += 1;
let check_result = match &self.config.check_type {
HealthCheckType::Http { url } => self.check_http(url).await,
HealthCheckType::Script { path } => self.check_script(path).await,
};
match check_result {
Ok(()) => {
info!("Health check passed");
self.status.state = HealthState::Healthy;
self.status.last_success = Some(now);
self.status.consecutive_failures = 0;
self.status.error_message = None;
}
Err(e) => {
warn!("Health check failed: {}", e);
self.status.consecutive_failures += 1;
self.status.error_message = Some(e.to_string());
if self.status.consecutive_failures >= self.config.retries {
self.status.state = HealthState::Unhealthy;
} else {
if self.status.state == HealthState::Unknown {
self.status.state = HealthState::Unhealthy;
}
}
}
}
Ok(&self.status)
}
async fn check_http(&self, url: &str) -> Result<()> {
let client = self
.http_client
.as_ref()
.ok_or_else(|| Error::health_check("HTTP client not initialized"))?;
debug!("Performing HTTP health check to: {}", url);
let response = tokio::time::timeout(self.config.timeout, client.get(url).send())
.await
.map_err(|_| {
Error::health_check(format!(
"HTTP health check timed out after {:?}",
self.config.timeout
))
})?
.map_err(|e| Error::health_check(format!("HTTP request failed: {}", e)))?;
if response.status().is_success() {
debug!("HTTP health check successful: {}", response.status());
Ok(())
} else {
Err(Error::health_check(format!(
"HTTP health check failed with status: {}",
response.status()
)))
}
}
async fn check_script(&self, path: &PathBuf) -> Result<()> {
debug!("Performing script health check: {:?}", path);
let mut cmd = Command::new(path);
cmd.stdout(Stdio::null())
.stderr(Stdio::null())
.stdin(Stdio::null());
let output = tokio::time::timeout(self.config.timeout, cmd.output())
.await
.map_err(|_| {
Error::health_check(format!(
"Script health check timed out after {:?}",
self.config.timeout
))
})?
.map_err(|e| {
Error::health_check(format!("Failed to execute health check script: {}", e))
})?;
if output.status.success() {
debug!("Script health check successful");
Ok(())
} else {
let exit_code = output.status.code().unwrap_or(-1);
Err(Error::health_check(format!(
"Script health check failed with exit code: {}",
exit_code
)))
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use pretty_assertions::assert_eq;
    use std::time::Duration;

    #[test]
    fn test_health_check_config_default() {
        let config = HealthCheckConfig::default();
        assert!(!config.enabled);
        assert_eq!(config.timeout, Duration::from_secs(DEFAULT_TIMEOUT_SECS));
        assert_eq!(config.interval, Duration::from_secs(DEFAULT_INTERVAL_SECS));
        assert_eq!(config.retries, DEFAULT_RETRIES);
        assert!(matches!(config.check_type, HealthCheckType::Http { .. }));
    }

    #[test]
    fn test_health_check_config_http() {
        let config = HealthCheckConfig::http("http://localhost:8080/health")
            .timeout(Duration::from_secs(10))
            .interval(Duration::from_secs(60))
            .retries(5)
            .enabled(true);
        assert!(config.enabled);
        assert_eq!(config.timeout, Duration::from_secs(10));
        assert_eq!(config.interval, Duration::from_secs(60));
        assert_eq!(config.retries, 5);
        if let HealthCheckType::Http { url } = &config.check_type {
            assert_eq!(url, "http://localhost:8080/health");
        } else {
            panic!("Expected HTTP health check type");
        }
    }

    #[test]
    fn test_health_check_config_script() {
        let config = HealthCheckConfig::script("./health-check.sh")
            .timeout(Duration::from_secs(15))
            .retries(2)
            .enabled(true);
        assert!(config.enabled);
        assert_eq!(config.timeout, Duration::from_secs(15));
        assert_eq!(config.retries, 2);
        if let HealthCheckType::Script { path } = &config.check_type {
            assert_eq!(path, &PathBuf::from("./health-check.sh"));
        } else {
            panic!("Expected Script health check type");
        }
    }

    #[test]
    fn test_health_status_default() {
        let status = HealthStatus::default();
        assert_eq!(status.state, HealthState::Unknown);
        assert!(status.last_check.is_none());
        assert!(status.last_success.is_none());
        assert_eq!(status.consecutive_failures, 0);
        assert_eq!(status.total_checks, 0);
        assert!(status.error_message.is_none());
    }

    #[test]
    fn test_health_status_is_healthy() {
        let mut status = HealthStatus::default();
        assert!(!status.is_healthy());
        assert!(!status.is_unhealthy());
        assert!(status.is_unknown());
        status.state = HealthState::Healthy;
        assert!(status.is_healthy());
        assert!(!status.is_unhealthy());
        assert!(!status.is_unknown());
        status.state = HealthState::Unhealthy;
        assert!(!status.is_healthy());
        assert!(status.is_unhealthy());
        assert!(!status.is_unknown());
    }

    #[test]
    fn test_health_status_time_since_methods() {
        let mut status = HealthStatus::default();
        assert!(status.time_since_last_check().is_none());
        assert!(status.time_since_last_success().is_none());
        let now = Utc::now();
        status.last_check = Some(now);
        status.last_success = Some(now);
        assert!(status.time_since_last_check().is_some());
        assert!(status.time_since_last_success().is_some());
        let check_duration = status.time_since_last_check().unwrap();
        let success_duration = status.time_since_last_success().unwrap();
        assert!(check_duration.num_milliseconds() >= 0);
        assert!(success_duration.num_milliseconds() >= 0);
    }

    #[test]
    fn test_health_check_new() {
        let config = HealthCheckConfig::http("http://localhost:9615/health").enabled(true);
        let health_check = HealthCheck::new(config);
        assert!(health_check.is_enabled());
        assert_eq!(health_check.status().state, HealthState::Unknown);
        assert!(health_check.http_client.is_some());
    }

    #[test]
    fn test_health_check_disabled() {
        let config = HealthCheckConfig::http("http://localhost:9615/health").enabled(false);
        let health_check = HealthCheck::new(config);
        assert!(!health_check.is_enabled());
    }

    #[tokio::test]
    async fn test_health_check_disabled_check() {
        let config = HealthCheckConfig::http("http://localhost:9615/health").enabled(false);
        let mut health_check = HealthCheck::new(config);
        let status = health_check.check().await.unwrap();
        assert_eq!(status.state, HealthState::Unknown);
        assert_eq!(status.total_checks, 0);
    }

    #[tokio::test]
    async fn test_health_check_script_success() {
        let config = if cfg!(windows) {
            HealthCheckConfig::script("cmd")
                .timeout(Duration::from_secs(5))
                .enabled(true)
        } else {
            HealthCheckConfig::script("true")
                .timeout(Duration::from_secs(5))
                .enabled(true)
        };
        let mut health_check = HealthCheck::new(config);
        let status = health_check.check().await.unwrap();
        assert_eq!(status.state, HealthState::Healthy);
        assert_eq!(status.consecutive_failures, 0);
        assert_eq!(status.total_checks, 1);
        assert!(status.last_check.is_some());
        assert!(status.last_success.is_some());
        assert!(status.error_message.is_none());
    }

    #[tokio::test]
    async fn test_health_check_script_failure() {
        let config = if cfg!(windows) {
            HealthCheckConfig::script("cmd /c exit 1")
                .timeout(Duration::from_secs(5))
                .retries(1)
                .enabled(true)
        } else {
            HealthCheckConfig::script("false")
                .timeout(Duration::from_secs(5))
                .retries(1)
                .enabled(true)
        };
        let mut health_check = HealthCheck::new(config);
        let status = health_check.check().await.unwrap();
        assert_eq!(status.state, HealthState::Unhealthy);
        assert_eq!(status.consecutive_failures, 1);
        assert_eq!(status.total_checks, 1);
        assert!(status.last_check.is_some());
        assert!(status.last_success.is_none());
        assert!(status.error_message.is_some());
        let error_msg = status.error_message.as_deref().unwrap_or_default();
        if cfg!(windows) {
            // The whole string "cmd /c exit 1" is used as the program name (the
            // script check passes no arguments), so on Windows the failure
            // presumably comes from spawning rather than from an exit code.
            assert!(!error_msg.is_empty());
        } else {
            assert!(
                error_msg.contains("exit code: 1"),
                "unexpected error message: {}",
                error_msg
            );
        }
    }

    #[tokio::test]
    async fn test_health_check_script_retry_logic() {
        let config = if cfg!(windows) {
            HealthCheckConfig::script("cmd /c exit 1")
                .timeout(Duration::from_secs(5))
                .retries(3)
                .enabled(true)
        } else {
            HealthCheckConfig::script("false")
                .timeout(Duration::from_secs(5))
                .retries(3)
                .enabled(true)
        };
        let mut health_check = HealthCheck::new(config);
        let status = health_check.check().await.unwrap();
        assert_eq!(status.consecutive_failures, 1);
        assert_eq!(status.state, HealthState::Unhealthy);
        let status = health_check.check().await.unwrap();
        assert_eq!(status.consecutive_failures, 2);
        assert_eq!(status.state, HealthState::Unhealthy);
        let status = health_check.check().await.unwrap();
        assert_eq!(status.consecutive_failures, 3);
        assert_eq!(status.state, HealthState::Unhealthy);
    }

    #[cfg(unix)]
    #[tokio::test]
    async fn test_health_check_script_grace_period_and_recovery() {
        let config = HealthCheckConfig::script("true")
            .timeout(Duration::from_secs(5))
            .retries(3)
            .enabled(true);
        let mut health_check = HealthCheck::new(config);
        assert_eq!(health_check.check().await.unwrap().state, HealthState::Healthy);

        // Tests live in a child module, so they can swap the private config to
        // turn a healthy target into a failing one.
        health_check.config.check_type = HealthCheckType::Script {
            path: PathBuf::from("false"),
        };
        for expected_failures in 1..3u32 {
            let status = health_check.check().await.unwrap();
            assert_eq!(status.consecutive_failures, expected_failures);
            // Fewer than `retries` failures after a success: still Healthy.
            assert_eq!(status.state, HealthState::Healthy);
            assert!(status.error_message.is_some());
        }
        let status = health_check.check().await.unwrap();
        assert_eq!(status.consecutive_failures, 3);
        assert_eq!(status.state, HealthState::Unhealthy);
        assert!(status.last_success.is_some());

        health_check.config.check_type = HealthCheckType::Script {
            path: PathBuf::from("true"),
        };
        let status = health_check.check().await.unwrap();
        assert_eq!(status.state, HealthState::Healthy);
        assert_eq!(status.consecutive_failures, 0);
        assert!(status.error_message.is_none());
        assert_eq!(status.total_checks, 5);
    }

    #[test]
    fn test_health_check_config_serialization() {
        let config = HealthCheckConfig::http("http://localhost:8080/health")
            .timeout(Duration::from_secs(10))
            .interval(Duration::from_secs(30))
            .retries(3)
            .enabled(true);
        let serialized = serde_json::to_string(&config).unwrap();
        let deserialized: HealthCheckConfig = serde_json::from_str(&serialized).unwrap();
        assert_eq!(config.enabled, deserialized.enabled);
        assert_eq!(config.timeout, deserialized.timeout);
        assert_eq!(config.interval, deserialized.interval);
        assert_eq!(config.retries, deserialized.retries);
        match (&config.check_type, &deserialized.check_type) {
            (HealthCheckType::Http { url: url1 }, HealthCheckType::Http { url: url2 }) => {
                assert_eq!(url1, url2);
            }
            _ => panic!("Health check types don't match"),
        }
    }

    #[test]
    fn test_health_status_serialization() {
        let status = HealthStatus {
            state: HealthState::Healthy,
            consecutive_failures: 2,
            total_checks: 10,
            error_message: Some("Test error".to_string()),
            ..Default::default()
        };
        let serialized = serde_json::to_string(&status).unwrap();
        let deserialized: HealthStatus = serde_json::from_str(&serialized).unwrap();
        assert_eq!(status.state, deserialized.state);
        assert_eq!(
            status.consecutive_failures,
            deserialized.consecutive_failures
        );
        assert_eq!(status.total_checks, deserialized.total_checks);
        assert_eq!(status.error_message, deserialized.error_message);
    }

    #[test]
    fn test_health_state_serialization() {
        let healthy = HealthState::Healthy;
        let unhealthy = HealthState::Unhealthy;
        let unknown = HealthState::Unknown;
        assert_eq!(serde_json::to_string(&healthy).unwrap(), "\"healthy\"");
        assert_eq!(serde_json::to_string(&unhealthy).unwrap(), "\"unhealthy\"");
        assert_eq!(serde_json::to_string(&unknown).unwrap(), "\"unknown\"");
        let healthy_deser: HealthState = serde_json::from_str("\"healthy\"").unwrap();
        let unhealthy_deser: HealthState = serde_json::from_str("\"unhealthy\"").unwrap();
        let unknown_deser: HealthState = serde_json::from_str("\"unknown\"").unwrap();
        assert_eq!(healthy, healthy_deser);
        assert_eq!(unhealthy, unhealthy_deser);
        assert_eq!(unknown, unknown_deser);
    }
}