use pingora_load_balancing::{
discovery::Static,
health_check::{HealthCheck as PingoraHealthCheck, HttpHealthCheck, TcpHealthCheck},
Backend, Backends,
};
use std::collections::BTreeSet;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::RwLock;
use tracing::{debug, info, trace, warn};
use crate::grpc_health::GrpcHealthCheck;
use crate::upstream::inference_health::InferenceHealthCheck;
use zentinel_common::types::HealthCheckType;
use zentinel_config::{HealthCheck as HealthCheckConfig, UpstreamConfig};
/// Actively probes the backends of a single upstream and tracks their health.
///
/// Built from `UpstreamConfig::health_check`; the actual probing is delegated
/// to pingora's `Backends::run_health_check`.
pub struct ActiveHealthChecker {
    // ID of the upstream this checker belongs to (copied from `UpstreamConfig::id`).
    upstream_id: String,
    // Pingora backend set carrying the static discovery and the configured
    // health-check implementation.
    backends: Arc<Backends>,
    // How often a full health-check cycle should run for this upstream.
    interval: Duration,
    // Whether pingora probes all backends of this upstream concurrently.
    parallel: bool,
    // Optional callback for health transitions.
    // NOTE(review): settable via `set_health_callback` but never invoked in
    // this file — confirm it is wired up elsewhere or remove it.
    health_callback: Arc<RwLock<Option<HealthChangeCallback>>>,
}
/// Callback fired with `(backend_address, is_healthy)` on health changes.
pub type HealthChangeCallback = Box<dyn Fn(&str, bool) + Send + Sync>;
impl ActiveHealthChecker {
    /// Creates a checker for `config`.
    ///
    /// Returns `None` when the upstream has no `health_check` section, or when
    /// none of its targets could be turned into a pingora `Backend`.
    pub fn new(config: &UpstreamConfig) -> Option<Self> {
        let health_config = config.health_check.as_ref()?;
        info!(
            upstream_id = %config.id,
            check_type = ?health_config.check_type,
            interval_secs = health_config.interval_secs,
            "Creating active health checker"
        );
        // Collect targets into the backend set pingora's static discovery
        // expects; an individual bad address is skipped (logged), not fatal.
        let mut backend_set = BTreeSet::new();
        for target in &config.targets {
            match Backend::new_with_weight(&target.address, target.weight as usize) {
                Ok(backend) => {
                    debug!(
                        upstream_id = %config.id,
                        target = %target.address,
                        weight = target.weight,
                        "Added backend for health checking"
                    );
                    backend_set.insert(backend);
                }
                Err(e) => {
                    warn!(
                        upstream_id = %config.id,
                        target = %target.address,
                        error = %e,
                        "Failed to create backend for health checking"
                    );
                }
            }
        }
        if backend_set.is_empty() {
            warn!(
                upstream_id = %config.id,
                "No backends created for health checking"
            );
            return None;
        }
        let discovery = Static::new(backend_set);
        let mut backends = Backends::new(discovery);
        let health_check: Box<dyn PingoraHealthCheck + Send + Sync> =
            Self::create_health_check(health_config, &config.id);
        backends.set_health_check(health_check);
        Some(Self {
            upstream_id: config.id.clone(),
            backends: Arc::new(backends),
            interval: Duration::from_secs(health_config.interval_secs),
            // NOTE(review): parallel probing is always on; make this
            // configurable if very large upstreams need sequential checks.
            parallel: true,
            health_callback: Arc::new(RwLock::new(None)),
        })
    }

    /// Translates the configured check type into a pingora `HealthCheck`
    /// implementation carrying the configured success/failure thresholds.
    fn create_health_check(
        config: &HealthCheckConfig,
        upstream_id: &str,
    ) -> Box<dyn PingoraHealthCheck + Send + Sync> {
        match &config.check_type {
            HealthCheckType::Http {
                path,
                expected_status,
                host,
            } => {
                let hostname = host.as_deref().unwrap_or("localhost");
                // Second argument `false` = plain HTTP (no TLS) for the probe.
                let mut hc = HttpHealthCheck::new(hostname, false);
                hc.consecutive_success = config.healthy_threshold as usize;
                hc.consecutive_failure = config.unhealthy_threshold as usize;
                // NOTE(review): `expected_status` is only logged below, not
                // enforced — pingora's default response validation is used.
                // Enforcing it would require setting `hc.validator`; confirm
                // whether that gap is intentional.
                if path != "/" {
                    match pingora_http::RequestHeader::build("GET", path.as_bytes(), None) {
                        Ok(req) => hc.req = req,
                        // Previously a build failure was silently swallowed and
                        // the default "GET /" probe was used; surface it.
                        Err(e) => warn!(
                            upstream_id = %upstream_id,
                            path = %path,
                            error = %e,
                            "Failed to build health check request, falling back to default probe"
                        ),
                    }
                }
                debug!(
                    upstream_id = %upstream_id,
                    path = %path,
                    expected_status = expected_status,
                    host = hostname,
                    consecutive_success = hc.consecutive_success,
                    consecutive_failure = hc.consecutive_failure,
                    "Created HTTP health check"
                );
                Box::new(hc)
            }
            HealthCheckType::Tcp => {
                // `TcpHealthCheck::new()` already returns a `Box`, hence no
                // extra `Box::new` when returning this arm.
                let mut hc = TcpHealthCheck::new();
                hc.consecutive_success = config.healthy_threshold as usize;
                hc.consecutive_failure = config.unhealthy_threshold as usize;
                debug!(
                    upstream_id = %upstream_id,
                    consecutive_success = hc.consecutive_success,
                    consecutive_failure = hc.consecutive_failure,
                    "Created TCP health check"
                );
                hc
            }
            HealthCheckType::Grpc { service } => {
                let timeout = Duration::from_secs(config.timeout_secs);
                let mut hc = GrpcHealthCheck::new(service.clone(), timeout);
                hc.consecutive_success = config.healthy_threshold as usize;
                hc.consecutive_failure = config.unhealthy_threshold as usize;
                info!(
                    upstream_id = %upstream_id,
                    service = %service,
                    timeout_secs = config.timeout_secs,
                    consecutive_success = hc.consecutive_success,
                    consecutive_failure = hc.consecutive_failure,
                    "Created gRPC health check"
                );
                Box::new(hc)
            }
            HealthCheckType::Inference {
                endpoint,
                expected_models,
                // `readiness` is intentionally unused here; the inference
                // check only verifies the endpoint and its served models.
                readiness: _,
            } => {
                let timeout = Duration::from_secs(config.timeout_secs);
                let mut hc =
                    InferenceHealthCheck::new(endpoint.clone(), expected_models.clone(), timeout);
                hc.consecutive_success = config.healthy_threshold as usize;
                hc.consecutive_failure = config.unhealthy_threshold as usize;
                info!(
                    upstream_id = %upstream_id,
                    endpoint = %endpoint,
                    expected_models = ?expected_models,
                    timeout_secs = config.timeout_secs,
                    consecutive_success = hc.consecutive_success,
                    consecutive_failure = hc.consecutive_failure,
                    "Created inference health check with model verification"
                );
                Box::new(hc)
            }
        }
    }

    /// Installs `callback`, replacing any previously installed one.
    ///
    /// NOTE(review): nothing in this file invokes the stored callback —
    /// verify the invocation site exists elsewhere.
    pub async fn set_health_callback(&self, callback: HealthChangeCallback) {
        *self.health_callback.write().await = Some(callback);
    }

    /// Runs one health-check cycle over all backends of this upstream.
    pub async fn run_health_check(&self) {
        trace!(
            upstream_id = %self.upstream_id,
            parallel = self.parallel,
            "Running health check cycle"
        );
        self.backends.run_health_check(self.parallel).await;
    }

    /// Reports whether the backend at `address` is currently healthy.
    ///
    /// Fails open: an address this checker is not tracking is reported as
    /// healthy so traffic is not dropped for backends outside its view.
    pub fn is_backend_healthy(&self, address: &str) -> bool {
        self.backends
            .get_backend()
            .iter()
            .find(|b| b.addr.to_string() == address)
            .map_or(true, |b| self.backends.ready(b))
    }

    /// Returns `(address, healthy)` for every tracked backend.
    pub fn get_health_statuses(&self) -> Vec<(String, bool)> {
        self.backends
            .get_backend()
            .iter()
            .map(|b| (b.addr.to_string(), self.backends.ready(b)))
            .collect()
    }

    /// Configured check interval for this upstream.
    pub fn interval(&self) -> Duration {
        self.interval
    }

    /// ID of the upstream this checker probes.
    pub fn upstream_id(&self) -> &str {
        &self.upstream_id
    }
}
/// Owns a set of [`ActiveHealthChecker`]s and drives them on a shared loop.
pub struct HealthCheckRunner {
    // All registered checkers, one per upstream with a health_check config.
    checkers: Vec<ActiveHealthChecker>,
    // Shared flag the `run` loop polls; `stop` flips it to false.
    running: Arc<RwLock<bool>>,
}
impl HealthCheckRunner {
    /// Creates a runner with no registered checkers.
    pub fn new() -> Self {
        Self {
            checkers: Vec::new(),
            running: Arc::new(RwLock::new(false)),
        }
    }

    /// Registers `checker`; it starts being driven once [`run`](Self::run) runs.
    pub fn add_checker(&mut self, checker: ActiveHealthChecker) {
        info!(
            upstream_id = %checker.upstream_id,
            interval_secs = checker.interval.as_secs(),
            "Added health checker to runner"
        );
        self.checkers.push(checker);
    }

    /// Number of registered checkers.
    pub fn checker_count(&self) -> usize {
        self.checkers.len()
    }

    /// Drives all registered checkers until [`stop`](Self::stop) is called.
    ///
    /// The loop ticks at the smallest configured interval, but each checker is
    /// only run once its own interval has elapsed. (Previously every checker
    /// ran on every tick, so upstreams with long intervals were over-probed at
    /// the shortest configured interval.)
    pub async fn run(&self) {
        use std::time::Instant;
        if self.checkers.is_empty() {
            info!("No health checkers configured, skipping health check loop");
            return;
        }
        *self.running.write().await = true;
        info!(
            checker_count = self.checkers.len(),
            "Starting health check runner"
        );
        let min_interval = self
            .checkers
            .iter()
            .map(|c| c.interval)
            .min()
            .unwrap_or(Duration::from_secs(10));
        let mut interval = tokio::time::interval(min_interval);
        // Skip (rather than burst through) ticks missed while a cycle ran long.
        interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
        // Last completed cycle per checker; `None` means "never run", so every
        // checker fires on the first tick.
        let mut last_run: Vec<Option<Instant>> = vec![None; self.checkers.len()];
        loop {
            interval.tick().await;
            if !*self.running.read().await {
                info!("Health check runner stopped");
                break;
            }
            let now = Instant::now();
            for (idx, checker) in self.checkers.iter().enumerate() {
                // Honor each checker's own interval, not just the loop cadence.
                let due = last_run[idx]
                    .map_or(true, |at| now.duration_since(at) >= checker.interval);
                if !due {
                    continue;
                }
                checker.run_health_check().await;
                last_run[idx] = Some(now);
                let statuses = checker.get_health_statuses();
                for (addr, healthy) in &statuses {
                    trace!(
                        upstream_id = %checker.upstream_id,
                        backend = %addr,
                        healthy = healthy,
                        "Backend health status"
                    );
                }
            }
        }
    }

    /// Signals the [`run`](Self::run) loop to exit after its current tick.
    pub async fn stop(&self) {
        info!("Stopping health check runner");
        *self.running.write().await = false;
    }

    /// Health of one backend, or `None` when `upstream_id` has no checker.
    pub fn get_health(&self, upstream_id: &str, address: &str) -> Option<bool> {
        self.checkers
            .iter()
            .find(|c| c.upstream_id == upstream_id)
            .map(|c| c.is_backend_healthy(address))
    }

    /// Health of all backends of `upstream_id`, or `None` when it has no checker.
    pub fn get_upstream_health(&self, upstream_id: &str) -> Option<Vec<(String, bool)>> {
        self.checkers
            .iter()
            .find(|c| c.upstream_id == upstream_id)
            .map(|c| c.get_health_statuses())
    }
}
impl Default for HealthCheckRunner {
    /// Equivalent to [`HealthCheckRunner::new`].
    fn default() -> Self {
        Self::new()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;
    use std::sync::Once;
    use zentinel_common::types::LoadBalancingAlgorithm;
    use zentinel_config::{
        ConnectionPoolConfig, HttpVersionConfig, UpstreamTarget, UpstreamTimeouts,
    };

    static INIT: Once = Once::new();

    // Installs the rustls crypto provider exactly once across the whole test
    // binary; repeated installs would otherwise return an error we ignore.
    fn init_crypto_provider() {
        INIT.call_once(|| {
            let _ = rustls::crypto::aws_lc_rs::default_provider().install_default();
        });
    }

    // Minimal upstream fixture: a single target plus an HTTP health check on
    // /health with a 5s interval.
    fn create_test_config() -> UpstreamConfig {
        let target = UpstreamTarget {
            address: "127.0.0.1:8081".to_string(),
            weight: 1,
            max_requests: None,
            metadata: HashMap::new(),
        };
        let health = HealthCheckConfig {
            check_type: HealthCheckType::Http {
                path: "/health".to_string(),
                expected_status: 200,
                host: None,
            },
            interval_secs: 5,
            timeout_secs: 2,
            healthy_threshold: 2,
            unhealthy_threshold: 3,
        };
        UpstreamConfig {
            id: "test-upstream".to_string(),
            targets: vec![target],
            load_balancing: LoadBalancingAlgorithm::RoundRobin,
            sticky_session: None,
            health_check: Some(health),
            connection_pool: ConnectionPoolConfig::default(),
            timeouts: UpstreamTimeouts::default(),
            tls: None,
            http_version: HttpVersionConfig::default(),
        }
    }

    #[test]
    fn test_create_health_checker() {
        init_crypto_provider();
        let checker = ActiveHealthChecker::new(&create_test_config())
            .expect("config with health_check should yield a checker");
        assert_eq!(checker.upstream_id, "test-upstream");
        assert_eq!(checker.interval, Duration::from_secs(5));
    }

    #[test]
    fn test_no_health_check_config() {
        let mut config = create_test_config();
        config.health_check = None;
        assert!(ActiveHealthChecker::new(&config).is_none());
    }

    #[test]
    fn test_health_check_runner() {
        init_crypto_provider();
        let mut runner = HealthCheckRunner::new();
        assert_eq!(runner.checker_count(), 0);
        if let Some(checker) = ActiveHealthChecker::new(&create_test_config()) {
            runner.add_checker(checker);
            assert_eq!(runner.checker_count(), 1);
        }
    }
}