// dynamo_runtime/health_check.rs
1// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use crate::component::{Client, Component, Endpoint, Instance};
5use crate::config::HealthStatus;
6use crate::pipeline::PushRouter;
7use crate::pipeline::{AsyncEngine, Context, ManyOut, SingleIn};
8use crate::protocols::annotated::Annotated;
9use crate::protocols::maybe_error::MaybeError;
10use crate::{DistributedRuntime, SystemHealth};
11use futures::StreamExt;
12use parking_lot::Mutex;
13use serde::{Deserialize, Serialize};
14use std::collections::HashMap;
15use std::sync::Arc;
16use std::time::{Duration, Instant};
17use tokio::task::JoinHandle;
18use tokio::time::{MissedTickBehavior, interval};
19use tracing::{debug, error, info, warn};
20
/// Configuration for health check behavior.
pub struct HealthCheckConfig {
    /// Idle time after which a canary health check is sent for an endpoint.
    /// Any observed activity on the endpoint resets this timer, so probes are
    /// only sent when the endpoint has been quiet for this long.
    pub canary_wait_time: Duration,
    /// Timeout for an individual health check request/response round trip.
    pub request_timeout: Duration,
}
28
29impl Default for HealthCheckConfig {
30    fn default() -> Self {
31        Self {
32            canary_wait_time: Duration::from_secs(crate::config::DEFAULT_CANARY_WAIT_TIME_SECS),
33            request_timeout: Duration::from_secs(
34                crate::config::DEFAULT_HEALTH_CHECK_REQUEST_TIMEOUT_SECS,
35            ),
36        }
37    }
38}
39
// Type alias for the router cache to improve readability.
// Maps endpoint subject -> shared PushRouter for that endpoint.
// (Health check payloads are NOT stored here; they are fetched on demand from
// SystemHealth's registered targets.)
type RouterCache =
    Arc<Mutex<HashMap<String, Arc<PushRouter<serde_json::Value, Annotated<serde_json::Value>>>>>>;
44
/// Health check manager that monitors endpoint health.
///
/// Spawns one canary task per registered endpoint plus a monitor task that
/// picks up endpoints registered after startup.
pub struct HealthCheckManager {
    drt: DistributedRuntime,
    config: HealthCheckConfig,
    /// Cache of PushRouters, one per endpoint subject. Payloads are re-read
    /// from SystemHealth each probe, not cached here.
    router_cache: RouterCache,
    /// Track per-endpoint health check tasks.
    /// Maps: endpoint_subject -> task_handle.
    /// Also used to detect duplicate registrations for the same endpoint.
    endpoint_tasks: Arc<Mutex<HashMap<String, JoinHandle<()>>>>,
}
55
impl HealthCheckManager {
    /// Construct a manager with the given runtime and config and empty
    /// router/task caches.
    pub fn new(drt: DistributedRuntime, config: HealthCheckConfig) -> Self {
        Self {
            drt,
            config,
            router_cache: Arc::default(),
            endpoint_tasks: Arc::default(),
        }
    }
65
66    /// Get or create a PushRouter for an endpoint
67    async fn get_or_create_router(
68        &self,
69        cache_key: &str,
70        endpoint: Endpoint,
71    ) -> anyhow::Result<Arc<PushRouter<serde_json::Value, Annotated<serde_json::Value>>>> {
72        let cache_key = cache_key.to_string();
73
74        // Check cache first
75        {
76            let cache = self.router_cache.lock();
77            if let Some(router) = cache.get(&cache_key) {
78                return Ok(router.clone());
79            }
80        }
81
82        // Create a client that discovers instances dynamically for this endpoint
83        let client = Client::new(endpoint).await?;
84
85        // Create PushRouter - it will use direct routing when we call direct()
86        let router: Arc<PushRouter<serde_json::Value, Annotated<serde_json::Value>>> = Arc::new(
87            PushRouter::from_client(
88                client,
89                crate::pipeline::RouterMode::RoundRobin, // Default mode, we'll use direct() explicitly
90            )
91            .await?,
92        );
93
94        // Cache it
95        self.router_cache.lock().insert(cache_key, router.clone());
96
97        Ok(router)
98    }
99
100    /// Start the health check manager by spawning per-endpoint monitoring tasks
101    pub async fn start(self: Arc<Self>) -> anyhow::Result<()> {
102        // Get all registered endpoints at startup
103        let targets = self.drt.system_health().lock().get_health_check_targets();
104
105        info!(
106            "Starting health check tasks for {} endpoints with canary_wait_time: {:?}",
107            targets.len(),
108            self.config.canary_wait_time
109        );
110
111        // Spawn a health check task for each registered endpoint
112        for (endpoint_subject, _target) in targets {
113            self.spawn_endpoint_health_check_task(endpoint_subject);
114        }
115
116        // CRITICAL: Spawn a task to monitor for NEW endpoints registered after startup
117        // This uses a channel-based approach to guarantee no lost notifications
118        // Will return an error if the receiver has already been taken
119        self.spawn_new_endpoint_monitor().await?;
120
121        info!("HealthCheckManager started successfully with channel-based endpoint discovery");
122        Ok(())
123    }
124
    /// Spawn a dedicated health check task for a specific endpoint.
    ///
    /// The task loops over a two-way race:
    /// - the canary timer expires (no activity for `canary_wait_time`), so an
    ///   active health check is sent; or
    /// - an activity notification arrives for this endpoint, which resets the
    ///   timer (recent traffic makes a probe unnecessary).
    fn spawn_endpoint_health_check_task(self: &Arc<Self>, endpoint_subject: String) {
        let manager = self.clone();
        let canary_wait = self.config.canary_wait_time;
        let endpoint_subject_clone = endpoint_subject.clone();

        // Get the endpoint-specific notifier. Registration creates one, so a
        // missing notifier for a registered endpoint is a programming error.
        let notifier = self
            .drt
            .system_health()
            .lock()
            .get_endpoint_health_check_notifier(&endpoint_subject)
            .expect("Notifier should exist for registered endpoint");

        let task = tokio::spawn(async move {
            let endpoint_subject = endpoint_subject_clone;
            info!("Health check task started for: {}", endpoint_subject);

            loop {
                // Each iteration creates a fresh sleep future, so whichever
                // branch fires, the canary timer restarts from zero.
                tokio::select! {
                    _ = tokio::time::sleep(canary_wait) => {
                        // Timeout - send health check for this specific endpoint.
                        debug!("Canary timer expired for {}, sending health check", endpoint_subject);

                        // Re-read the target each time; the payload lives in
                        // SystemHealth, not in this task.
                        let target = manager.drt.system_health().lock().get_health_check_target(&endpoint_subject);

                        if let Some(target) = target {
                            if let Err(e) = manager.send_health_check_request(&endpoint_subject, &target.payload).await {
                                error!("Failed to send health check for {}: {}", endpoint_subject, e);
                            }
                        } else {
                            // This should never happen - targets are registered at startup and never removed.
                            error!(
                                "CRITICAL: Health check target for {} disappeared unexpectedly! This indicates a bug. Stopping health check task.",
                                endpoint_subject
                            );
                            break;
                        }
                    }

                    _ = notifier.notified() => {
                        // Activity detected - reset timer for this endpoint only.
                        debug!("Activity detected for {}, resetting health check timer", endpoint_subject);
                        // Loop continues, timer resets.
                    }
                }
            }

            info!("Health check task for {} exiting", endpoint_subject);
        });

        // Store the task handle so duplicate registrations can be detected
        // and the handle is not silently dropped.
        self.endpoint_tasks
            .lock()
            .insert(endpoint_subject.clone(), task);

        info!(
            "Spawned health check task for endpoint: {}",
            endpoint_subject
        );
    }
188
    /// Spawn a task to monitor for newly registered endpoints.
    ///
    /// Returns an error if the registration receiver has already been taken
    /// (i.e. this was called more than once). The spawned task stops entirely
    /// if it ever sees a duplicate registration, since that indicates a bug.
    async fn spawn_new_endpoint_monitor(self: &Arc<Self>) -> anyhow::Result<()> {
        let manager = self.clone();

        // Get the receiver (can only be taken once).
        let mut rx = manager
            .drt
            .system_health()
            .lock()
            .take_new_endpoint_receiver()
            .ok_or_else(|| {
                anyhow::anyhow!("Endpoint receiver already taken - this should only be called once")
            })?;

        tokio::spawn(async move {
            info!("Starting dynamic endpoint discovery monitor with channel-based notifications");

            // recv() returns None when all senders are dropped, ending the loop.
            while let Some(endpoint_subject) = rx.recv().await {
                debug!(
                    "Received endpoint registration via channel: {}",
                    endpoint_subject
                );

                // Scope the lock so it is released before spawning the task.
                let already_exists = {
                    let tasks = manager.endpoint_tasks.lock();
                    tasks.contains_key(&endpoint_subject)
                };

                if already_exists {
                    error!(
                        "CRITICAL: Received registration for endpoint '{}' that already has a health check task!",
                        endpoint_subject
                    );
                    break;
                }

                info!(
                    "Spawning health check task for new endpoint: {}",
                    endpoint_subject
                );
                manager.spawn_endpoint_health_check_task(endpoint_subject);
            }

            info!("Endpoint discovery monitor exiting - no new endpoints will be monitored!");
        });

        info!("Dynamic endpoint discovery monitor started");
        Ok(())
    }
239
    /// Send a health check request through AsyncEngine.
    ///
    /// Resolves the target instance from SystemHealth, waits (bounded) for the
    /// router's client to finish initial instance discovery, then fires the
    /// actual request from a spawned task so the caller's canary loop is not
    /// blocked on the response. The endpoint's HealthStatus is updated from
    /// the response (or set to NotReady on error/timeout).
    ///
    /// NOTE(review): `payload` is passed in separately even though the
    /// re-fetched target also carries a payload - confirm callers always pass
    /// `target.payload` so the two cannot diverge.
    async fn send_health_check_request(
        &self,
        endpoint_subject: &str,
        payload: &serde_json::Value,
    ) -> anyhow::Result<()> {
        // Re-fetch the target for up-to-date instance info.
        let target = self
            .drt
            .system_health()
            .lock()
            .get_health_check_target(endpoint_subject)
            .ok_or_else(|| {
                anyhow::anyhow!("No health check target found for {}", endpoint_subject)
            })?;

        debug!(
            "Sending health check to {} (instance_id: {})",
            endpoint_subject, target.instance.instance_id
        );

        // Create the Endpoint directly from the Instance info.
        let namespace = self.drt.namespace(&target.instance.namespace)?;
        let component = namespace.component(&target.instance.component)?;
        let endpoint = component.endpoint(&target.instance.endpoint);

        // Get or create router for this endpoint.
        let router = self
            .get_or_create_router(endpoint_subject, endpoint)
            .await?;

        // Wait for watch stream to discover instances before checking.
        // This ensures the router's client has populated its instance list
        // from etcd before we attempt to send the health check request.
        // Without this, the first health check can fail due to a race condition
        // where the watch stream hasn't completed its initial discovery yet.
        match tokio::time::timeout(
            Duration::from_secs(10), // 10 second timeout for discovery
            router.client.wait_for_instances(),
        )
        .await
        {
            Ok(Ok(instances)) => {
                debug!(
                    "Health check for {}: watch stream ready, found {} instance(s)",
                    endpoint_subject,
                    instances.len()
                );
            }
            Ok(Err(e)) => {
                return Err(anyhow::anyhow!(
                    "Failed to discover instances for {} during health check: {}",
                    endpoint_subject,
                    e
                ));
            }
            Err(_) => {
                return Err(anyhow::anyhow!(
                    "Timeout waiting for instance discovery for {} during health check",
                    endpoint_subject
                ));
            }
        }

        // Create the request context.
        let request: SingleIn<serde_json::Value> = Context::new(payload.clone());

        // Clone what we need for the spawned task.
        let system_health = self.drt.system_health().clone();
        let endpoint_subject_owned = endpoint_subject.to_string();
        let instance_id = target.instance.instance_id;
        let timeout = self.config.request_timeout;

        // Spawn task to send health check and wait for response; the whole
        // exchange is bounded by `request_timeout`.
        tokio::spawn(async move {
            let result = tokio::time::timeout(timeout, async {
                // Call direct() on the PushRouter to target specific instance.
                match router.direct(request, instance_id).await {
                    Ok(mut response_stream) => {
                        // Get the first response to verify endpoint is alive.
                        let is_healthy = if let Some(response) = response_stream.next().await {
                            // Check if response indicates an error.
                            if let Some(error) = response.err() {
                                warn!(
                                    "Health check error response from {}: {:?}",
                                    endpoint_subject_owned, error
                                );
                                false
                            } else {
                                debug!("Health check successful for {}", endpoint_subject_owned);
                                true
                            }
                        } else {
                            warn!(
                                "Health check got no response from {}",
                                endpoint_subject_owned
                            );
                            false
                        };

                        tokio::spawn(async move {
                            // We need to consume the rest of the stream to avoid warnings on the frontend.
                            response_stream.for_each(|_| async {}).await;
                        });

                        // Update health status based on response.
                        system_health.lock().set_endpoint_health_status(
                            &endpoint_subject_owned,
                            if is_healthy {
                                HealthStatus::Ready
                            } else {
                                HealthStatus::NotReady
                            },
                        );
                    }
                    Err(e) => {
                        error!(
                            "Health check request failed for {}: {}",
                            endpoint_subject_owned, e
                        );
                        system_health.lock().set_endpoint_health_status(
                            &endpoint_subject_owned,
                            HealthStatus::NotReady,
                        );
                    }
                }
            })
            .await;

            // Handle timeout: no response within request_timeout means NotReady.
            if result.is_err() {
                warn!("Health check timeout for {}", endpoint_subject_owned);
                system_health
                    .lock()
                    .set_endpoint_health_status(&endpoint_subject_owned, HealthStatus::NotReady);
            }

            debug!("Health check completed for {}", endpoint_subject_owned);
        });

        Ok(())
    }
}
382
383/// Start health check manager for the distributed runtime
384pub async fn start_health_check_manager(
385    drt: DistributedRuntime,
386    config: Option<HealthCheckConfig>,
387) -> anyhow::Result<()> {
388    let config = config.unwrap_or_default();
389    let manager = Arc::new(HealthCheckManager::new(drt, config));
390
391    // Start the health check manager (this spawns per-endpoint tasks internally)
392    manager.start().await?;
393
394    Ok(())
395}
396
397/// Get health check status for all endpoints
398pub async fn get_health_check_status(
399    drt: &DistributedRuntime,
400) -> anyhow::Result<serde_json::Value> {
401    // Get endpoints list from SystemHealth
402    let endpoint_subjects: Vec<String> = drt.system_health().lock().get_health_check_endpoints();
403
404    let mut endpoint_statuses = HashMap::new();
405
406    // Check each endpoint's health status
407    {
408        let system_health = drt.system_health();
409        let system_health_lock = system_health.lock();
410        for endpoint_subject in &endpoint_subjects {
411            let health_status = system_health_lock
412                .get_endpoint_health_status(endpoint_subject)
413                .unwrap_or(HealthStatus::NotReady);
414
415            let is_healthy = matches!(health_status, HealthStatus::Ready);
416
417            endpoint_statuses.insert(
418                endpoint_subject.clone(),
419                serde_json::json!({
420                    "healthy": is_healthy,
421                    "status": format!("{:?}", health_status),
422                }),
423            );
424        }
425    }
426
427    let overall_healthy = endpoint_statuses
428        .values()
429        .all(|v| v["healthy"].as_bool().unwrap_or(false));
430
431    Ok(serde_json::json!({
432        "status": if overall_healthy { "ready" } else { "notready" },
433        "endpoints_checked": endpoint_subjects.len(),
434        "endpoint_statuses": endpoint_statuses,
435    }))
436}
437
// ===============================
// Integration Tests (require DRT)
// ===============================
#[cfg(all(test, feature = "integration"))]
mod integration_tests {
    use super::*;
    use crate::distributed::distributed_test_utils::create_test_drt_async;
    use std::sync::Arc;
    use std::time::Duration;

    // Construction should store the supplied config verbatim.
    #[tokio::test]
    async fn test_initialization() {
        let drt = create_test_drt_async().await;

        let canary_wait_time = Duration::from_secs(5);
        let request_timeout = Duration::from_secs(3);

        let config = HealthCheckConfig {
            canary_wait_time,
            request_timeout,
        };

        let manager = HealthCheckManager::new(drt.clone(), config);

        assert_eq!(manager.config.canary_wait_time, canary_wait_time);
        assert_eq!(manager.config.request_timeout, request_timeout);
    }

    // Registering a target should make its payload retrievable and list the
    // endpoint among the health check endpoints.
    #[tokio::test]
    async fn test_payload_registration() {
        let drt = create_test_drt_async().await;

        let endpoint = "test.endpoint";
        let payload = serde_json::json!({
            "prompt": "test",
            "_health_check": true
        });

        drt.system_health().lock().register_health_check_target(
            endpoint,
            crate::component::Instance {
                component: "test_component".to_string(),
                endpoint: "test_endpoint".to_string(),
                namespace: "test_namespace".to_string(),
                instance_id: 12345,
                transport: crate::component::TransportType::Nats(endpoint.to_string()),
            },
            payload.clone(),
        );

        let retrieved = drt
            .system_health()
            .lock()
            .get_health_check_target(endpoint)
            .map(|t| t.payload);
        assert!(retrieved.is_some());
        assert_eq!(retrieved.unwrap(), payload);

        // Verify endpoint appears in the list.
        let endpoints = drt.system_health().lock().get_health_check_endpoints();
        assert!(endpoints.contains(&endpoint.to_string()));
    }

    // start() should spawn exactly one health check task per registered
    // endpoint, keyed by endpoint subject.
    #[tokio::test]
    async fn test_spawn_per_endpoint_tasks() {
        let drt = create_test_drt_async().await;

        for i in 0..3 {
            let endpoint = format!("test.endpoint.{}", i);
            let payload = serde_json::json!({
                "prompt": format!("test{}", i),
                "_health_check": true
            });
            drt.system_health().lock().register_health_check_target(
                &endpoint,
                crate::component::Instance {
                    component: "test_component".to_string(),
                    endpoint: format!("test_endpoint_{}", i),
                    namespace: "test_namespace".to_string(),
                    instance_id: i,
                    transport: crate::component::TransportType::Nats(endpoint.clone()),
                },
                payload,
            );
        }

        let config = HealthCheckConfig {
            canary_wait_time: Duration::from_secs(5),
            request_timeout: Duration::from_secs(1),
        };

        let manager = Arc::new(HealthCheckManager::new(drt.clone(), config));
        manager.clone().start().await.unwrap();

        // Verify all endpoints have their own health check tasks.
        let tasks = manager.endpoint_tasks.lock();
        // Should have 3 tasks (one for each endpoint).
        assert_eq!(tasks.len(), 3);
        // Check that all endpoints are represented in tasks.
        let endpoints: Vec<String> = tasks.keys().cloned().collect();
        assert!(endpoints.contains(&"test.endpoint.0".to_string()));
        assert!(endpoints.contains(&"test.endpoint.1".to_string()));
        assert!(endpoints.contains(&"test.endpoint.2".to_string()));
    }

    // Registration should create a per-endpoint activity notifier.
    #[tokio::test]
    async fn test_endpoint_health_check_notifier_created() {
        let drt = create_test_drt_async().await;

        let endpoint = "test.endpoint.notifier";
        let payload = serde_json::json!({
            "prompt": "test",
            "_health_check": true
        });

        // Register the endpoint.
        drt.system_health().lock().register_health_check_target(
            endpoint,
            crate::component::Instance {
                component: "test_component".to_string(),
                endpoint: "test_endpoint_notifier".to_string(),
                namespace: "test_namespace".to_string(),
                instance_id: 999,
                transport: crate::component::TransportType::Nats(endpoint.to_string()),
            },
            payload.clone(),
        );

        // Verify that a notifier was created for this endpoint.
        let notifier = drt
            .system_health()
            .lock()
            .get_endpoint_health_check_notifier(endpoint);

        assert!(
            notifier.is_some(),
            "Endpoint should have a notifier created"
        );

        // Verify we can notify it without panicking.
        if let Some(notifier) = notifier {
            notifier.notify_one();
        }

        // Registration leaves the endpoint in NotReady until a health check
        // reports otherwise (the assertion below pins this default).
        let status = drt
            .system_health()
            .lock()
            .get_endpoint_health_status(endpoint);
        assert_eq!(status, Some(HealthStatus::NotReady));
    }
}
589}