Skip to main content

dynamo_runtime/
system_health.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3//
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8// http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16//! System health monitoring and health check management
17
18use std::{
19    collections::HashMap,
20    sync::{Arc, OnceLock},
21    time::Instant,
22};
23use tokio::sync::mpsc;
24
25use crate::component;
26use crate::config::HealthStatus;
27use crate::metrics::{MetricsHierarchy, prometheus_names::distributed_runtime};
28
29/// Health check target containing instance info and payload
30#[derive(Clone, Debug)]
31pub struct HealthCheckTarget {
32    pub instance: component::Instance,
33    pub payload: serde_json::Value,
34}
35
36/// Current Health Status
37/// If use_endpoint_health_status is set then
38/// initialize the endpoint_health hashmap to the
39/// starting health status
40#[derive(Clone)]
41pub struct SystemHealth {
42    system_health: HealthStatus,
43    endpoint_health: Arc<std::sync::RwLock<HashMap<String, HealthStatus>>>,
44    /// Maps endpoint subject to health check target (instance + payload)
45    health_check_targets: Arc<std::sync::RwLock<HashMap<String, HealthCheckTarget>>>,
46    /// Maps endpoint subject to its specific health check notifier
47    health_check_notifiers: Arc<std::sync::RwLock<HashMap<String, Arc<tokio::sync::Notify>>>>,
48    /// Channel for new endpoint registrations
49    /// This solves the race condition where HealthCheckManager starts before endpoints are registered
50    /// Using a channel ensures no registrations are lost.
51    new_endpoint_tx: mpsc::UnboundedSender<String>,
52    new_endpoint_rx: Arc<parking_lot::Mutex<Option<mpsc::UnboundedReceiver<String>>>>,
53    use_endpoint_health_status: Vec<String>,
54    health_path: String,
55    live_path: String,
56    start_time: Instant,
57    uptime_gauge: OnceLock<prometheus::Gauge>,
58}
59
60impl SystemHealth {
61    pub fn new(
62        starting_health_status: HealthStatus,
63        use_endpoint_health_status: Vec<String>,
64        health_path: String,
65        live_path: String,
66    ) -> Self {
67        let mut endpoint_health = HashMap::new();
68        for endpoint in &use_endpoint_health_status {
69            endpoint_health.insert(endpoint.clone(), starting_health_status.clone());
70        }
71
72        // Create the channel for endpoint registration notifications
73        let (tx, rx) = mpsc::unbounded_channel();
74
75        SystemHealth {
76            system_health: starting_health_status,
77            endpoint_health: Arc::new(std::sync::RwLock::new(endpoint_health)),
78            health_check_targets: Arc::new(std::sync::RwLock::new(HashMap::new())),
79            health_check_notifiers: Arc::new(std::sync::RwLock::new(HashMap::new())),
80            new_endpoint_tx: tx,
81            new_endpoint_rx: Arc::new(parking_lot::Mutex::new(Some(rx))),
82            use_endpoint_health_status,
83            health_path,
84            live_path,
85            start_time: Instant::now(),
86            uptime_gauge: OnceLock::new(),
87        }
88    }
89    pub fn set_health_status(&mut self, status: HealthStatus) {
90        self.system_health = status;
91    }
92
93    pub fn set_endpoint_health_status(&self, endpoint: &str, status: HealthStatus) {
94        let mut endpoint_health = self.endpoint_health.write().unwrap();
95        endpoint_health.insert(endpoint.to_string(), status);
96    }
97
98    /// Returns the overall health status and endpoint health statuses
99    /// System health is determined by ALL endpoints that have registered health checks
100    pub fn get_health_status(&self) -> (bool, HashMap<String, String>) {
101        let health_check_targets = self.health_check_targets.read().unwrap();
102        let endpoint_health = self.endpoint_health.read().unwrap();
103        let mut endpoints: HashMap<String, String> = HashMap::new();
104
105        for (endpoint, status) in endpoint_health.iter() {
106            endpoints.insert(
107                endpoint.clone(),
108                if *status == HealthStatus::Ready {
109                    "ready".to_string()
110                } else {
111                    "notready".to_string()
112                },
113            );
114        }
115
116        let healthy = if !self.use_endpoint_health_status.is_empty() {
117            self.use_endpoint_health_status.iter().all(|endpoint| {
118                endpoint_health
119                    .get(endpoint)
120                    .is_some_and(|status| *status == HealthStatus::Ready)
121            })
122        } else {
123            // If we have registered health check targets, use them to determine health
124            if !health_check_targets.is_empty() {
125                health_check_targets
126                    .iter()
127                    .all(|(endpoint_subject, _target)| {
128                        endpoint_health
129                            .get(endpoint_subject)
130                            .is_some_and(|status| *status == HealthStatus::Ready)
131                    })
132            } else {
133                // No health check targets registered, use simple system health
134                self.system_health == HealthStatus::Ready
135            }
136        };
137
138        (healthy, endpoints)
139    }
140
141    /// Register a health check target for an endpoint
142    pub fn register_health_check_target(
143        &self,
144        endpoint_subject: &str,
145        instance: component::Instance,
146        payload: serde_json::Value,
147    ) {
148        let key = endpoint_subject.to_owned();
149
150        // Atomically check+insert under a single write lock to avoid races.
151        let inserted = {
152            let mut targets = self.health_check_targets.write().unwrap();
153            match targets.entry(key.clone()) {
154                std::collections::hash_map::Entry::Occupied(_) => false,
155                std::collections::hash_map::Entry::Vacant(v) => {
156                    v.insert(HealthCheckTarget { instance, payload });
157                    true
158                }
159            }
160        };
161
162        if !inserted {
163            tracing::warn!(
164                "Attempted to re-register health check for endpoint '{}'; ignoring.",
165                key
166            );
167            return;
168        }
169
170        // Create and store a unique notifier for this endpoint (idempotent).
171        {
172            let mut notifiers = self.health_check_notifiers.write().unwrap();
173            notifiers
174                .entry(key.clone())
175                .or_insert_with(|| Arc::new(tokio::sync::Notify::new()));
176        }
177
178        // Initialize endpoint health status conservatively to NotReady.
179        {
180            let mut endpoint_health = self.endpoint_health.write().unwrap();
181            endpoint_health
182                .entry(key.clone())
183                .or_insert(HealthStatus::NotReady);
184        }
185
186        if let Err(e) = self.new_endpoint_tx.send(key.clone()) {
187            tracing::error!(
188                "Failed to send endpoint '{}' registration to health check manager: {}. \
189                 Health checks will not be performed for this endpoint.",
190                key,
191                e
192            );
193        }
194    }
195
196    /// Get all health check targets
197    pub fn get_health_check_targets(&self) -> Vec<(String, HealthCheckTarget)> {
198        let targets = self.health_check_targets.read().unwrap();
199        targets
200            .iter()
201            .map(|(k, v)| (k.clone(), v.clone()))
202            .collect()
203    }
204
205    /// Check if any health check targets are registered
206    pub fn has_health_check_targets(&self) -> bool {
207        let targets = self.health_check_targets.read().unwrap();
208        !targets.is_empty()
209    }
210
211    /// Get list of endpoints with health check targets
212    pub fn get_health_check_endpoints(&self) -> Vec<String> {
213        let targets = self.health_check_targets.read().unwrap();
214        targets.keys().cloned().collect()
215    }
216
217    /// Get health check target for a specific endpoint
218    pub fn get_health_check_target(&self, endpoint: &str) -> Option<HealthCheckTarget> {
219        let targets = self.health_check_targets.read().unwrap();
220        targets.get(endpoint).cloned()
221    }
222
223    /// Get the endpoint health status (Ready/NotReady)
224    pub fn get_endpoint_health_status(&self, endpoint: &str) -> Option<HealthStatus> {
225        let endpoint_health = self.endpoint_health.read().unwrap();
226        endpoint_health.get(endpoint).cloned()
227    }
228
229    /// Get the endpoint-specific health check notifier
230    pub fn get_endpoint_health_check_notifier(
231        &self,
232        endpoint_subject: &str,
233    ) -> Option<Arc<tokio::sync::Notify>> {
234        let notifiers = self.health_check_notifiers.read().unwrap();
235        notifiers.get(endpoint_subject).cloned()
236    }
237
238    /// Take the receiver for new endpoint registrations (can only be called once)
239    /// This is used by HealthCheckManager to receive notifications of new endpoints
240    pub fn take_new_endpoint_receiver(&self) -> Option<mpsc::UnboundedReceiver<String>> {
241        self.new_endpoint_rx.lock().take()
242    }
243
244    /// Initialize the uptime gauge using the provided metrics registry
245    pub fn initialize_uptime_gauge<T: MetricsHierarchy>(&self, registry: &T) -> anyhow::Result<()> {
246        let gauge = registry.metrics().create_gauge(
247            distributed_runtime::UPTIME_SECONDS,
248            "Total uptime of the DistributedRuntime in seconds",
249            &[],
250        )?;
251        self.uptime_gauge
252            .set(gauge)
253            .map_err(|_| anyhow::anyhow!("uptime_gauge already initialized"))?;
254        Ok(())
255    }
256
257    /// Get the current uptime as a Duration
258    pub fn uptime(&self) -> std::time::Duration {
259        self.start_time.elapsed()
260    }
261
262    /// Update the uptime gauge with the current uptime value
263    pub fn update_uptime_gauge(&self) {
264        if let Some(gauge) = self.uptime_gauge.get() {
265            gauge.set(self.uptime().as_secs_f64());
266        }
267    }
268
269    /// Get the health check path
270    pub fn health_path(&self) -> &str {
271        &self.health_path
272    }
273
274    /// Get the liveness check path
275    pub fn live_path(&self) -> &str {
276        &self.live_path
277    }
278}