Skip to main content

dynamo_runtime/
health_check.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use crate::DistributedRuntime;
5use crate::config::HealthStatus;
6use crate::engine::AsyncEngine;
7use crate::pipeline::SingleIn;
8use crate::protocols::maybe_error::MaybeError;
9use futures::StreamExt;
10use parking_lot::Mutex;
11use std::collections::HashMap;
12use std::sync::Arc;
13use std::time::Duration;
14use tokio::task::JoinHandle;
15use tracing::{debug, error, info, warn};
16
/// Configuration for health check behavior.
///
/// Both values are plain [`Duration`]s, so the configuration is cheap to clone
/// and can be printed with `{:?}` when logging the effective settings.
#[derive(Debug, Clone)]
pub struct HealthCheckConfig {
    /// Wait time before sending canary health checks (when no activity)
    pub canary_wait_time: Duration,
    /// Timeout for health check requests
    pub request_timeout: Duration,
}
24
25impl Default for HealthCheckConfig {
26    fn default() -> Self {
27        Self {
28            canary_wait_time: Duration::from_secs(crate::config::DEFAULT_CANARY_WAIT_TIME_SECS),
29            request_timeout: Duration::from_secs(
30                crate::config::DEFAULT_HEALTH_CHECK_REQUEST_TIMEOUT_SECS,
31            ),
32        }
33    }
34}
35
36/// Health check manager that monitors endpoint health
37pub struct HealthCheckManager {
38    drt: DistributedRuntime,
39    config: HealthCheckConfig,
40    /// Track per-endpoint health check tasks
41    /// Maps: endpoint_subject -> task_handle
42    endpoint_tasks: Arc<Mutex<HashMap<String, JoinHandle<()>>>>,
43}
44
45impl HealthCheckManager {
46    pub fn new(drt: DistributedRuntime, config: HealthCheckConfig) -> Self {
47        Self {
48            drt,
49            config,
50            endpoint_tasks: Arc::new(Mutex::new(HashMap::new())),
51        }
52    }
53
54    /// Start the health check manager by spawning per-endpoint monitoring tasks
55    pub async fn start(self: Arc<Self>) -> anyhow::Result<()> {
56        // Get all registered endpoints at startup
57        let targets = self.drt.system_health().lock().get_health_check_targets();
58
59        info!(
60            "Starting health check tasks for {} endpoints with canary_wait_time: {:?}",
61            targets.len(),
62            self.config.canary_wait_time
63        );
64
65        // Spawn a health check task for each registered endpoint
66        for (endpoint_subject, _target) in targets {
67            self.spawn_endpoint_health_check_task(endpoint_subject);
68        }
69
70        // CRITICAL: Spawn a task to monitor for NEW endpoints registered after startup
71        // This uses a channel-based approach to guarantee no lost notifications
72        // Will return an error if the receiver has already been taken
73        self.spawn_new_endpoint_monitor().await?;
74
75        info!("HealthCheckManager started successfully with channel-based endpoint discovery");
76        Ok(())
77    }
78
79    /// Spawn a dedicated health check task for a specific endpoint
80    fn spawn_endpoint_health_check_task(self: &Arc<Self>, endpoint_subject: String) {
81        let manager = self.clone();
82        let canary_wait = self.config.canary_wait_time;
83        let endpoint_subject_clone = endpoint_subject.clone();
84
85        // Get the endpoint-specific notifier
86        let notifier = self
87            .drt
88            .system_health()
89            .lock()
90            .get_endpoint_health_check_notifier(&endpoint_subject)
91            .expect("Notifier should exist for registered endpoint");
92
93        let task = tokio::spawn(async move {
94            let endpoint_subject = endpoint_subject_clone;
95            info!("Health check task started for: {}", endpoint_subject);
96
97            loop {
98                // Wait for either timeout or activity notification
99                tokio::select! {
100                    _ = tokio::time::sleep(canary_wait) => {
101                        // Timeout - send health check for this specific endpoint
102                        debug!("Canary timer expired for {}, sending health check", endpoint_subject);
103
104                        // Get the health check payload for this endpoint
105                        let target = manager.drt.system_health().lock().get_health_check_target(&endpoint_subject);
106
107                        if let Some(target) = target {
108                            if let Err(e) = manager.send_health_check_request(&endpoint_subject, &target.payload).await {
109                                error!("Failed to send health check for {}: {}", endpoint_subject, e);
110                            }
111                        } else {
112                            // This should never happen - targets are registered at startup and never removed
113                            error!(
114                                "CRITICAL: Health check target for {} disappeared unexpectedly! This indicates a bug. Stopping health check task.",
115                                endpoint_subject
116                            );
117                            break;
118                        }
119                    }
120
121                    _ = notifier.notified() => {
122                        // Activity detected - reset timer for this endpoint only
123                        debug!("Activity detected for {}, resetting health check timer", endpoint_subject);
124                        // Loop continues, timer resets
125                    }
126                }
127            }
128
129            info!("Health check task for {} exiting", endpoint_subject);
130        });
131
132        // Store the task handle
133        self.endpoint_tasks
134            .lock()
135            .insert(endpoint_subject.clone(), task);
136
137        info!(
138            "Spawned health check task for endpoint: {}",
139            endpoint_subject
140        );
141    }
142
143    /// Spawn a task to monitor for newly registered endpoints
144    /// Returns an error if duplicate endpoints are detected, indicating a bug in the system
145    async fn spawn_new_endpoint_monitor(self: &Arc<Self>) -> anyhow::Result<()> {
146        let manager = self.clone();
147
148        // Get the receiver (can only be taken once)
149        let mut rx = manager
150            .drt
151            .system_health()
152            .lock()
153            .take_new_endpoint_receiver()
154            .ok_or_else(|| {
155                anyhow::anyhow!("Endpoint receiver already taken - this should only be called once")
156            })?;
157
158        tokio::spawn(async move {
159            info!("Starting dynamic endpoint discovery monitor with channel-based notifications");
160
161            while let Some(endpoint_subject) = rx.recv().await {
162                debug!(
163                    "Received endpoint registration via channel: {}",
164                    endpoint_subject
165                );
166
167                let already_exists = {
168                    let tasks = manager.endpoint_tasks.lock();
169                    tasks.contains_key(&endpoint_subject)
170                };
171
172                if already_exists {
173                    error!(
174                        "CRITICAL: Received registration for endpoint '{}' that already has a health check task!",
175                        endpoint_subject
176                    );
177                    break;
178                }
179
180                info!(
181                    "Spawning health check task for new endpoint: {}",
182                    endpoint_subject
183                );
184                manager.spawn_endpoint_health_check_task(endpoint_subject);
185            }
186
187            info!("Endpoint discovery monitor exiting - no new endpoints will be monitored!");
188        });
189
190        info!("Dynamic endpoint discovery monitor started");
191        Ok(())
192    }
193
194    /// Send a health check request via the local endpoint registry (in-process).
195    async fn send_health_check_request(
196        &self,
197        endpoint_subject: &str,
198        payload: &serde_json::Value,
199    ) -> anyhow::Result<()> {
200        debug!(
201            "Sending health check to {} via local registry",
202            endpoint_subject
203        );
204
205        let engine = self
206            .drt
207            .local_endpoint_registry()
208            .get(endpoint_subject)
209            .ok_or_else(|| {
210                anyhow::anyhow!(
211                    "Endpoint '{}' not found in local registry, engine may still be initializing",
212                    endpoint_subject
213                )
214            })?;
215
216        // Clone what we need for the spawned task
217        let system_health = self.drt.system_health().clone();
218        let endpoint_subject_owned = endpoint_subject.to_string();
219        let payload = payload.clone();
220        let timeout = self.config.request_timeout;
221
222        // Spawn task to send health check and wait for response
223        tokio::spawn(async move {
224            let result = tokio::time::timeout(timeout, async {
225                let request = SingleIn::new(payload);
226                match engine.generate(request).await {
227                    Ok(mut response_stream) => {
228                        // Get the first response to verify endpoint is alive.
229                        // Check for errors
230                        let is_healthy = if let Some(response) = response_stream.next().await {
231                            if let Some(error) = response.err() {
232                                warn!(
233                                    "Health check error response from {}: {:?}",
234                                    endpoint_subject_owned, error
235                                );
236                                false
237                            } else {
238                                debug!("Health check successful for {}", endpoint_subject_owned);
239                                true
240                            }
241                        } else {
242                            warn!(
243                                "Health check got no response from {}",
244                                endpoint_subject_owned
245                            );
246                            false
247                        };
248
249                        tokio::spawn(async move {
250                            // We need to consume the rest of the stream to avoid warnings on the frontend.
251                            response_stream.for_each(|_| async {}).await;
252                        });
253
254                        // Update health status based on response
255                        system_health.lock().set_endpoint_health_status(
256                            &endpoint_subject_owned,
257                            if is_healthy {
258                                HealthStatus::Ready
259                            } else {
260                                HealthStatus::NotReady
261                            },
262                        );
263                    }
264                    Err(e) => {
265                        error!(
266                            "Health check request failed for {}: {}",
267                            endpoint_subject_owned, e
268                        );
269                        system_health.lock().set_endpoint_health_status(
270                            &endpoint_subject_owned,
271                            HealthStatus::NotReady,
272                        );
273                    }
274                }
275            })
276            .await;
277
278            // Handle timeout
279            if result.is_err() {
280                warn!("Health check timeout for {}", endpoint_subject_owned);
281                system_health
282                    .lock()
283                    .set_endpoint_health_status(&endpoint_subject_owned, HealthStatus::NotReady);
284            }
285
286            debug!("Health check completed for {}", endpoint_subject_owned);
287        });
288
289        Ok(())
290    }
291}
292
293/// Start health check manager for the distributed runtime
294pub async fn start_health_check_manager(
295    drt: DistributedRuntime,
296    config: Option<HealthCheckConfig>,
297) -> anyhow::Result<()> {
298    let config = config.unwrap_or_default();
299    let manager = Arc::new(HealthCheckManager::new(drt, config));
300
301    // Start the health check manager (this spawns per-endpoint tasks internally)
302    manager.start().await?;
303
304    Ok(())
305}
306
307/// Get health check status for all endpoints
308pub async fn get_health_check_status(
309    drt: &DistributedRuntime,
310) -> anyhow::Result<serde_json::Value> {
311    // Get endpoints list from SystemHealth
312    let endpoint_subjects: Vec<String> = drt.system_health().lock().get_health_check_endpoints();
313
314    let mut endpoint_statuses = HashMap::new();
315
316    // Check each endpoint's health status
317    {
318        let system_health = drt.system_health();
319        let system_health_lock = system_health.lock();
320        for endpoint_subject in &endpoint_subjects {
321            let health_status = system_health_lock
322                .get_endpoint_health_status(endpoint_subject)
323                .unwrap_or(HealthStatus::NotReady);
324
325            let is_healthy = matches!(health_status, HealthStatus::Ready);
326
327            endpoint_statuses.insert(
328                endpoint_subject.clone(),
329                serde_json::json!({
330                    "healthy": is_healthy,
331                    "status": format!("{:?}", health_status),
332                }),
333            );
334        }
335    }
336
337    let overall_healthy = endpoint_statuses
338        .values()
339        .all(|v| v["healthy"].as_bool().unwrap_or(false));
340
341    Ok(serde_json::json!({
342        "status": if overall_healthy { "ready" } else { "notready" },
343        "endpoints_checked": endpoint_subjects.len(),
344        "endpoint_statuses": endpoint_statuses,
345    }))
346}
347
348// ===============================
349// Integration Tests (require DRT)
350// ===============================
#[cfg(all(test, feature = "integration"))]
mod integration_tests {
    use super::*;
    use crate::distributed::distributed_test_utils::create_test_drt_async;
    use std::sync::Arc;
    use std::time::Duration;

    /// The manager keeps exactly the configuration it was constructed with.
    #[tokio::test]
    async fn test_initialization() {
        let drt = create_test_drt_async().await;

        let expected_wait = Duration::from_secs(5);
        let expected_timeout = Duration::from_secs(3);

        let manager = HealthCheckManager::new(
            drt.clone(),
            HealthCheckConfig {
                canary_wait_time: expected_wait,
                request_timeout: expected_timeout,
            },
        );

        assert_eq!(manager.config.canary_wait_time, expected_wait);
        assert_eq!(manager.config.request_timeout, expected_timeout);
    }

    /// A registered target can be read back and shows up in the endpoint list.
    #[tokio::test]
    async fn test_payload_registration() {
        let drt = create_test_drt_async().await;

        let subject = "test.endpoint";
        let payload = serde_json::json!({
            "prompt": "test",
            "_health_check": true
        });
        let instance = crate::component::Instance {
            component: "test_component".to_string(),
            endpoint: "test_endpoint".to_string(),
            namespace: "test_namespace".to_string(),
            instance_id: 12345,
            transport: crate::component::TransportType::Nats(subject.to_string()),
        };

        drt.system_health()
            .lock()
            .register_health_check_target(subject, instance, payload.clone());

        // The stored payload must be the one that was registered.
        let stored_payload = drt
            .system_health()
            .lock()
            .get_health_check_target(subject)
            .map(|target| target.payload)
            .expect("registered target should be retrievable");
        assert_eq!(stored_payload, payload);

        // Verify endpoint appears in the list
        let known_endpoints = drt.system_health().lock().get_health_check_endpoints();
        assert!(known_endpoints.iter().any(|e| e.as_str() == subject));
    }

    /// Starting the manager creates one task per endpoint registered beforehand.
    #[tokio::test]
    async fn test_spawn_per_endpoint_tasks() {
        let drt = create_test_drt_async().await;

        for i in 0..3 {
            let subject = format!("test.endpoint.{}", i);
            let instance = crate::component::Instance {
                component: "test_component".to_string(),
                endpoint: format!("test_endpoint_{}", i),
                namespace: "test_namespace".to_string(),
                instance_id: i,
                transport: crate::component::TransportType::Nats(subject.clone()),
            };
            let payload = serde_json::json!({
                "prompt": format!("test{}", i),
                "_health_check": true
            });

            drt.system_health()
                .lock()
                .register_health_check_target(&subject, instance, payload);
        }

        let manager = Arc::new(HealthCheckManager::new(
            drt.clone(),
            HealthCheckConfig {
                canary_wait_time: Duration::from_secs(5),
                request_timeout: Duration::from_secs(1),
            },
        ));
        manager.clone().start().await.unwrap();

        // Verify all endpoints have their own health check tasks
        let tasks = manager.endpoint_tasks.lock();
        // Should have 3 tasks (one for each endpoint)
        assert_eq!(tasks.len(), 3);
        // Check that all endpoints are represented in tasks
        for idx in 0..3 {
            assert!(tasks.contains_key(&format!("test.endpoint.{}", idx)));
        }
    }

    /// Registration creates a notifier and an initial health status entry.
    #[tokio::test]
    async fn test_endpoint_health_check_notifier_created() {
        let drt = create_test_drt_async().await;

        let subject = "test.endpoint.notifier";
        let payload = serde_json::json!({
            "prompt": "test",
            "_health_check": true
        });
        let instance = crate::component::Instance {
            component: "test_component".to_string(),
            endpoint: "test_endpoint_notifier".to_string(),
            namespace: "test_namespace".to_string(),
            instance_id: 999,
            transport: crate::component::TransportType::Nats(subject.to_string()),
        };

        // Register the endpoint
        drt.system_health()
            .lock()
            .register_health_check_target(subject, instance, payload.clone());

        // Verify that a notifier was created for this endpoint, and that it
        // can be notified without panicking.
        drt.system_health()
            .lock()
            .get_endpoint_health_check_notifier(subject)
            .expect("Endpoint should have a notifier created")
            .notify_one();

        // Right after registration, before any health check has run, this test
        // expects the endpoint to be reported as NotReady.
        // NOTE(review): confirm against SystemHealth that NotReady is the
        // intended post-registration default.
        let status = drt
            .system_health()
            .lock()
            .get_endpoint_health_status(subject);
        assert_eq!(status, Some(HealthStatus::NotReady));
    }
}