//! Health tracking for the distributed runtime: system-wide and per-endpoint
//! status, registered health-check targets, and a Prometheus uptime gauge.

use std::{
collections::HashMap,
sync::{Arc, OnceLock},
time::Instant,
};
use tokio::sync::mpsc;
use crate::component;
use crate::config::HealthStatus;
use crate::metrics::{MetricsHierarchy, prometheus_names::distributed_runtime};
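/// A registered health-check target: the instance to probe and the JSON
/// payload to send when probing it.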
#[derive(Clone, Debug)]
pub struct HealthCheckTarget {
pub instance: component::Instance,
pub payload: serde_json::Value,
}
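/// Tracks system-wide and per-endpoint health, registered health-check
/// targets, per-endpoint notifiers, and the process uptime gauge.
///
/// The maps, notifiers, and new-endpoint channel are shared across clones
/// via `Arc`; `system_health` and the path strings are copied per clone.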
#[derive(Clone)]
pub struct SystemHealth {
system_health: HealthStatus,
endpoint_health: Arc<std::sync::RwLock<HashMap<String, HealthStatus>>>,
health_check_targets: Arc<std::sync::RwLock<HashMap<String, HealthCheckTarget>>>,
health_check_notifiers: Arc<std::sync::RwLock<HashMap<String, Arc<tokio::sync::Notify>>>>,
new_endpoint_tx: mpsc::UnboundedSender<String>,
new_endpoint_rx: Arc<parking_lot::Mutex<Option<mpsc::UnboundedReceiver<String>>>>,
use_endpoint_health_status: Vec<String>,
health_path: String,
live_path: String,
start_time: Instant,
uptime_gauge: OnceLock<prometheus::Gauge>,
}
impl SystemHealth {
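    /// Builds a `SystemHealth` whose endpoints named in
    /// `use_endpoint_health_status` are pre-seeded with
    /// `starting_health_status`.
    ///
    /// A minimal construction sketch; the endpoint name and paths are
    /// hypothetical:
    ///
    /// ```ignore
    /// let health = SystemHealth::new(
    ///     HealthStatus::NotReady,
    ///     vec!["generate".to_string()],
    ///     "/health".to_string(),
    ///     "/live".to_string(),
    /// );
    /// assert!(health.get_endpoint_health_status("generate") == Some(HealthStatus::NotReady));
    /// ```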
pub fn new(
starting_health_status: HealthStatus,
use_endpoint_health_status: Vec<String>,
health_path: String,
live_path: String,
) -> Self {
let mut endpoint_health = HashMap::new();
for endpoint in &use_endpoint_health_status {
endpoint_health.insert(endpoint.clone(), starting_health_status.clone());
}
let (tx, rx) = mpsc::unbounded_channel();
SystemHealth {
system_health: starting_health_status,
endpoint_health: Arc::new(std::sync::RwLock::new(endpoint_health)),
health_check_targets: Arc::new(std::sync::RwLock::new(HashMap::new())),
health_check_notifiers: Arc::new(std::sync::RwLock::new(HashMap::new())),
new_endpoint_tx: tx,
new_endpoint_rx: Arc::new(parking_lot::Mutex::new(Some(rx))),
use_endpoint_health_status,
health_path,
live_path,
start_time: Instant::now(),
uptime_gauge: OnceLock::new(),
}
}
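    /// Sets the system-wide fallback status. `system_health` is a plain
    /// field, so this only affects the clone it is called on.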
pub fn set_health_status(&mut self, status: HealthStatus) {
self.system_health = status;
}
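    /// Records the latest status for one endpoint; visible to all clones.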
pub fn set_endpoint_health_status(&self, endpoint: &str, status: HealthStatus) {
let mut endpoint_health = self.endpoint_health.write().unwrap();
endpoint_health.insert(endpoint.to_string(), status);
}
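    /// Returns overall readiness plus a per-endpoint `"ready"`/`"notready"`
    /// map. Readiness is judged against the configured endpoint list when
    /// present, otherwise against all registered health-check targets, and
    /// finally against the system-wide status.
    ///
    /// A sketch of how an HTTP handler might consume this (the handler
    /// wiring itself is hypothetical):
    ///
    /// ```ignore
    /// let (healthy, endpoints) = health.get_health_status();
    /// let code = if healthy { 200 } else { 503 };
    /// let body = serde_json::json!({ "status": code, "endpoints": endpoints });
    /// ```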
pub fn get_health_status(&self) -> (bool, HashMap<String, String>) {
let health_check_targets = self.health_check_targets.read().unwrap();
let endpoint_health = self.endpoint_health.read().unwrap();
let mut endpoints: HashMap<String, String> = HashMap::new();
for (endpoint, status) in endpoint_health.iter() {
endpoints.insert(
endpoint.clone(),
if *status == HealthStatus::Ready {
"ready".to_string()
} else {
"notready".to_string()
},
);
}
        // Precedence: the configured endpoint list, then registered
        // health-check targets, then the system-wide status.
        let healthy = if !self.use_endpoint_health_status.is_empty() {
            self.use_endpoint_health_status.iter().all(|endpoint| {
                endpoint_health
                    .get(endpoint)
                    .is_some_and(|status| *status == HealthStatus::Ready)
            })
        } else if !health_check_targets.is_empty() {
            health_check_targets.keys().all(|endpoint_subject| {
                endpoint_health
                    .get(endpoint_subject)
                    .is_some_and(|status| *status == HealthStatus::Ready)
            })
        } else {
            self.system_health == HealthStatus::Ready
        };
(healthy, endpoints)
}
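    /// Registers a health-check target for `endpoint_subject`: stores the
    /// target, creates its notifier, seeds its status as `NotReady`, and
    /// announces the endpoint over the new-endpoint channel.
    /// Re-registering an existing subject is logged and ignored.
    ///
    /// A registration sketch; the subject string, `instance`, and payload
    /// below are placeholders:
    ///
    /// ```ignore
    /// health.register_health_check_target(
    ///     "namespace.component.generate",
    ///     instance,
    ///     serde_json::json!({ "prompt": "health check" }),
    /// );
    /// ```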
pub fn register_health_check_target(
&self,
endpoint_subject: &str,
instance: component::Instance,
payload: serde_json::Value,
) {
let key = endpoint_subject.to_owned();
let inserted = {
let mut targets = self.health_check_targets.write().unwrap();
match targets.entry(key.clone()) {
std::collections::hash_map::Entry::Occupied(_) => false,
std::collections::hash_map::Entry::Vacant(v) => {
v.insert(HealthCheckTarget { instance, payload });
true
}
}
};
if !inserted {
tracing::warn!(
"Attempted to re-register health check for endpoint '{}'; ignoring.",
key
);
return;
}
{
let mut notifiers = self.health_check_notifiers.write().unwrap();
notifiers
.entry(key.clone())
.or_insert_with(|| Arc::new(tokio::sync::Notify::new()));
}
{
let mut endpoint_health = self.endpoint_health.write().unwrap();
endpoint_health
.entry(key.clone())
.or_insert(HealthStatus::NotReady);
}
if let Err(e) = self.new_endpoint_tx.send(key.clone()) {
tracing::error!(
"Failed to send endpoint '{}' registration to health check manager: {}. \
Health checks will not be performed for this endpoint.",
key,
e
);
}
}
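    /// Returns a snapshot of all registered targets, keyed by endpoint
    /// subject.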
pub fn get_health_check_targets(&self) -> Vec<(String, HealthCheckTarget)> {
let targets = self.health_check_targets.read().unwrap();
targets
.iter()
.map(|(k, v)| (k.clone(), v.clone()))
.collect()
}
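    /// True if at least one health-check target is registered.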
pub fn has_health_check_targets(&self) -> bool {
let targets = self.health_check_targets.read().unwrap();
!targets.is_empty()
}
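    /// Returns the endpoint subjects of all registered targets.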
pub fn get_health_check_endpoints(&self) -> Vec<String> {
let targets = self.health_check_targets.read().unwrap();
targets.keys().cloned().collect()
}
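    /// Looks up the registered target for one endpoint, if any.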
pub fn get_health_check_target(&self, endpoint: &str) -> Option<HealthCheckTarget> {
let targets = self.health_check_targets.read().unwrap();
targets.get(endpoint).cloned()
}
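    /// Returns the last recorded status for an endpoint, if tracked.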
pub fn get_endpoint_health_status(&self, endpoint: &str) -> Option<HealthStatus> {
let endpoint_health = self.endpoint_health.read().unwrap();
endpoint_health.get(endpoint).cloned()
}
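    /// Returns the per-endpoint `Notify`, if one was registered; callers can
    /// use it to signal that endpoint's health-check task.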
pub fn get_endpoint_health_check_notifier(
&self,
endpoint_subject: &str,
) -> Option<Arc<tokio::sync::Notify>> {
let notifiers = self.health_check_notifiers.read().unwrap();
notifiers.get(endpoint_subject).cloned()
}
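    /// Takes the receiver for newly registered endpoints. Succeeds at most
    /// once; later calls return `None`.
    ///
    /// Sketch of a manager loop consuming it (assumes a Tokio runtime):
    ///
    /// ```ignore
    /// let mut rx = health.take_new_endpoint_receiver().expect("already taken");
    /// tokio::spawn(async move {
    ///     while let Some(endpoint) = rx.recv().await {
    ///         // start driving health checks for `endpoint` here
    ///     }
    /// });
    /// ```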
pub fn take_new_endpoint_receiver(&self) -> Option<mpsc::UnboundedReceiver<String>> {
self.new_endpoint_rx.lock().take()
}
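    /// Creates the uptime gauge on the given registry. A second call fails
    /// because the backing `OnceLock` is already set.
    ///
    /// A usage sketch; `metrics_root` stands in for any `MetricsHierarchy`
    /// implementor:
    ///
    /// ```ignore
    /// health.initialize_uptime_gauge(&metrics_root)?;
    /// health.update_uptime_gauge(); // e.g. on a periodic tick
    /// ```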
pub fn initialize_uptime_gauge<T: MetricsHierarchy>(&self, registry: &T) -> anyhow::Result<()> {
let gauge = registry.metrics().create_gauge(
distributed_runtime::UPTIME_SECONDS,
"Total uptime of the DistributedRuntime in seconds",
&[],
)?;
self.uptime_gauge
.set(gauge)
.map_err(|_| anyhow::anyhow!("uptime_gauge already initialized"))?;
Ok(())
}
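    /// Time elapsed since this `SystemHealth` was constructed.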
pub fn uptime(&self) -> std::time::Duration {
self.start_time.elapsed()
}
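    /// Pushes the current uptime into the gauge; a no-op until
    /// `initialize_uptime_gauge` has succeeded.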
pub fn update_uptime_gauge(&self) {
if let Some(gauge) = self.uptime_gauge.get() {
gauge.set(self.uptime().as_secs_f64());
}
}
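    /// HTTP path serving health (readiness) checks.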
pub fn health_path(&self) -> &str {
&self.health_path
}
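    /// HTTP path serving liveness checks.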
pub fn live_path(&self) -> &str {
&self.live_path
}
}