use std::sync::Arc;
use std::time::Duration;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::UnixStream;
use tokio::sync::Mutex;
use tracing::{debug, warn};
use breaker_machines::CircuitBreaker;
use chrono_machines::{BackoffStrategy, ExponentialBackoff};
use dashmap::DashMap;
use rama::http::{client::EasyHttpWebClient, service::client::HttpClientExt};
use rand::{SeedableRng, rngs::SmallRng};
/// Value reported by `HealthChecker::initial_delay`.
const HEALTH_CHECK_INITIAL_DELAY: Duration = Duration::from_millis(7_000);
/// Value reported by `HealthChecker::interval`.
const HEALTH_CHECK_INTERVAL: Duration = Duration::from_millis(5_000);
/// Timeout applied to each individual check attempt.
const HEALTH_CHECK_TIMEOUT: Duration = Duration::from_millis(2_000);
/// `max_attempts` configured on the exponential retry backoff.
const RETRY_ATTEMPTS: u8 = 3;
/// `base_delay_ms` configured on the exponential retry backoff.
const RETRY_BASE_DELAY_MS: u64 = 100;
/// `max_delay_ms` configured on the exponential retry backoff.
const RETRY_MAX_DELAY_MS: u64 = 1_000;
/// `failure_threshold` given to every per-ship circuit breaker.
const CIRCUIT_FAILURE_THRESHOLD: usize = 5;
/// `failure_window_secs` given to every per-ship circuit breaker.
const CIRCUIT_FAILURE_WINDOW_SECS: f64 = 60_f64;
/// `half_open_timeout_secs` given to every per-ship circuit breaker.
const CIRCUIT_HALF_OPEN_TIMEOUT_SECS: f64 = 60_f64;
/// `success_threshold` given to every per-ship circuit breaker.
const CIRCUIT_SUCCESS_THRESHOLD: usize = 1;
/// Health checker for ship endpoints (HTTP or `unix://` sockets), with one circuit breaker
/// per ship name.
///
/// Cloning is cheap: the breaker map lives behind an `Arc`, so every clone shares the same
/// per-ship breaker state.
#[derive(Clone)]
pub struct HealthChecker {
    /// Circuit breakers keyed by ship name; entries are created lazily on first check.
    circuits: Arc<DashMap<String, Arc<Mutex<CircuitBreaker>>>>,
}
impl Default for HealthChecker {
fn default() -> Self {
Self::new()
}
}
impl HealthChecker {
pub fn new() -> Self {
Self {
circuits: Arc::new(DashMap::new()),
}
}
fn get_circuit(&self, ship_name: &str) -> Arc<Mutex<CircuitBreaker>> {
self.circuits
.entry(ship_name.to_string())
.or_insert_with(|| {
let breaker = CircuitBreaker::builder(format!("health:{}", ship_name))
.failure_threshold(CIRCUIT_FAILURE_THRESHOLD)
.failure_window_secs(CIRCUIT_FAILURE_WINDOW_SECS)
.half_open_timeout_secs(CIRCUIT_HALF_OPEN_TIMEOUT_SECS)
.success_threshold(CIRCUIT_SUCCESS_THRESHOLD)
.build();
Arc::new(Mutex::new(breaker))
})
.clone()
}
pub async fn check_http(&self, ship_name: &str, url: &str) -> bool {
let circuit = self.get_circuit(ship_name);
{
let breaker = circuit.lock().await;
if breaker.is_open() {
debug!(
ship = ship_name,
url = url,
"Health check skipped - circuit breaker open"
);
return false;
}
}
let result = if let Some(rest) = url.strip_prefix("unix://") {
self.check_unix_socket_with_retry(ship_name, rest).await
} else {
self.check_http_with_retry(ship_name, url).await
};
{
let mut breaker = circuit.lock().await;
if result {
breaker.record_success(0.0);
} else {
breaker.record_failure_and_maybe_trip(0.0);
}
}
result
}
async fn check_http_with_retry(&self, ship_name: &str, url: &str) -> bool {
let backoff = ExponentialBackoff::new()
.base_delay_ms(RETRY_BASE_DELAY_MS)
.max_delay_ms(RETRY_MAX_DELAY_MS)
.max_attempts(RETRY_ATTEMPTS);
let mut rng = SmallRng::from_os_rng();
let mut attempt = 0u8;
loop {
attempt += 1;
let client = EasyHttpWebClient::default();
match tokio::time::timeout(HEALTH_CHECK_TIMEOUT, client.get(url).send()).await {
Ok(Ok(resp)) => {
let healthy = resp.status().is_success();
debug!(
ship = ship_name,
url = url,
status = %resp.status(),
attempt = attempt,
healthy = healthy,
"Health check"
);
return healthy;
}
Ok(Err(e)) => {
match backoff.delay(attempt, &mut rng) {
Some(delay_ms) => {
debug!(
ship = ship_name,
url = url,
attempt = attempt,
delay_ms = delay_ms,
error = %e,
"Health check failed, retrying"
);
tokio::time::sleep(Duration::from_millis(delay_ms)).await;
}
None => {
warn!(
ship = ship_name,
url = url,
attempts = attempt,
error = %e,
"Health check failed after all retries"
);
return false;
}
}
}
Err(_) => {
match backoff.delay(attempt, &mut rng) {
Some(delay_ms) => {
debug!(
ship = ship_name,
url = url,
attempt = attempt,
delay_ms = delay_ms,
"Health check timed out, retrying"
);
tokio::time::sleep(Duration::from_millis(delay_ms)).await;
}
None => {
warn!(
ship = ship_name,
url = url,
attempts = attempt,
"Health check timed out after all retries"
);
return false;
}
}
}
}
}
}
async fn check_unix_socket_with_retry(&self, ship_name: &str, path_and_endpoint: &str) -> bool {
let (socket_path, endpoint) = match Self::split_socket_path(path_and_endpoint) {
Some(parts) => parts,
None => {
warn!(
ship = ship_name,
path = path_and_endpoint,
"Could not find Unix socket file"
);
return false;
}
};
let backoff = ExponentialBackoff::new()
.base_delay_ms(RETRY_BASE_DELAY_MS)
.max_delay_ms(RETRY_MAX_DELAY_MS)
.max_attempts(RETRY_ATTEMPTS);
let mut rng = SmallRng::from_os_rng();
let mut attempt = 0u8;
loop {
attempt += 1;
let result = tokio::time::timeout(HEALTH_CHECK_TIMEOUT, async {
let mut stream = UnixStream::connect(&socket_path).await?;
let request = format!(
"GET {} HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n",
endpoint
);
stream.write_all(request.as_bytes()).await?;
let mut response = vec![0u8; 1024];
let n = stream.read(&mut response).await?;
let response_str = String::from_utf8_lossy(&response[..n]);
Ok::<bool, std::io::Error>(
response_str.starts_with("HTTP/1.1 2")
|| response_str.starts_with("HTTP/1.0 2"),
)
})
.await;
match result {
Ok(Ok(healthy)) => {
debug!(
ship = ship_name,
socket = socket_path,
endpoint = endpoint,
attempt = attempt,
healthy = healthy,
"Unix socket health check"
);
return healthy;
}
Ok(Err(e)) => {
let error_msg = format!("Connection failed: {}", e);
match backoff.delay(attempt, &mut rng) {
Some(delay_ms) => {
debug!(
ship = ship_name,
socket = socket_path,
attempt = attempt,
delay_ms = delay_ms,
error = error_msg,
"Unix socket health check failed, retrying"
);
tokio::time::sleep(Duration::from_millis(delay_ms)).await;
}
None => {
warn!(
ship = ship_name,
socket = socket_path,
attempts = attempt,
error = error_msg,
"Unix socket health check failed after all retries"
);
return false;
}
}
}
Err(_) => {
let error_msg = "Timed out".to_string();
match backoff.delay(attempt, &mut rng) {
Some(delay_ms) => {
debug!(
ship = ship_name,
socket = socket_path,
attempt = attempt,
delay_ms = delay_ms,
error = error_msg,
"Unix socket health check failed, retrying"
);
tokio::time::sleep(Duration::from_millis(delay_ms)).await;
}
None => {
warn!(
ship = ship_name,
socket = socket_path,
attempts = attempt,
error = error_msg,
"Unix socket health check failed after all retries"
);
return false;
}
}
}
}
}
}
fn split_socket_path(path: &str) -> Option<(String, String)> {
for ext in [".sock", ".socket", ".s"] {
if let Some(idx) = path.find(ext) {
let socket_end = idx + ext.len();
let socket_path = &path[..socket_end];
let endpoint = if socket_end < path.len() {
&path[socket_end..]
} else {
"/"
};
return Some((socket_path.to_string(), endpoint.to_string()));
}
}
let parts: Vec<&str> = path.split('/').collect();
for i in (1..parts.len()).rev() {
let potential_socket = parts[..=i].join("/");
if std::path::Path::new(&potential_socket).exists() {
let endpoint = if i + 1 < parts.len() {
format!("/{}", parts[i + 1..].join("/"))
} else {
"/".to_string()
};
return Some((potential_socket, endpoint));
}
}
None
}
pub fn interval(&self) -> Duration {
HEALTH_CHECK_INTERVAL
}
pub fn initial_delay(&self) -> Duration {
HEALTH_CHECK_INITIAL_DELAY
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Nothing is expected to listen on this port, so the check must report unhealthy.
    #[tokio::test]
    async fn test_health_checker() {
        let checker = HealthChecker::new();
        let healthy = checker
            .check_http("test-ship", "http://127.0.0.1:19999/health")
            .await;
        assert!(!healthy);
    }

    /// A socket path that does not exist must be reported unhealthy, not panic or hang.
    #[tokio::test]
    async fn missing_unix_socket_is_unhealthy() {
        let checker = HealthChecker::new();
        let healthy = checker
            .check_http("test-ship", "unix:///nonexistent-dir/app.sock/health")
            .await;
        assert!(!healthy);
    }

    #[test]
    fn split_socket_path_separates_endpoint() {
        assert_eq!(
            HealthChecker::split_socket_path("/run/app.sock/health"),
            Some(("/run/app.sock".to_string(), "/health".to_string()))
        );
        assert_eq!(
            HealthChecker::split_socket_path("/var/run/docker.sock/v1.41/_ping"),
            Some(("/var/run/docker.sock".to_string(), "/v1.41/_ping".to_string()))
        );
    }

    #[test]
    fn split_socket_path_defaults_endpoint_to_root() {
        assert_eq!(
            HealthChecker::split_socket_path("/run/app.sock"),
            Some(("/run/app.sock".to_string(), "/".to_string()))
        );
    }
}