dynamo_runtime/
health_check.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use crate::component::{Client, Component, Endpoint, Instance};
5use crate::config::HealthStatus;
6use crate::pipeline::PushRouter;
7use crate::pipeline::{AsyncEngine, Context, ManyOut, SingleIn};
8use crate::protocols::annotated::Annotated;
9use crate::protocols::maybe_error::MaybeError;
10use crate::{DistributedRuntime, SystemHealth};
11use futures::StreamExt;
12use parking_lot::Mutex;
13use serde::{Deserialize, Serialize};
14use std::collections::HashMap;
15use std::sync::Arc;
16use std::time::{Duration, Instant};
17use tokio::task::JoinHandle;
18use tokio::time::{MissedTickBehavior, interval};
19use tracing::{debug, error, info, warn};
20
/// Configuration for health check behavior.
///
/// Public config types should be printable and copyable so callers can log
/// the effective settings and share them across components.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct HealthCheckConfig {
    /// Wait time before sending canary health checks (when no activity)
    pub canary_wait_time: Duration,
    /// Timeout for health check requests
    pub request_timeout: Duration,
}
28
29impl Default for HealthCheckConfig {
30    fn default() -> Self {
31        Self {
32            canary_wait_time: Duration::from_secs(crate::config::DEFAULT_CANARY_WAIT_TIME_SECS),
33            request_timeout: Duration::from_secs(
34                crate::config::DEFAULT_HEALTH_CHECK_REQUEST_TIMEOUT_SECS,
35            ),
36        }
37    }
38}
39
// Type alias for the router cache to improve readability.
// Maps endpoint subject -> shared PushRouter for that endpoint.
// (Routers only — health-check payloads live in SystemHealth's targets.)
type RouterCache =
    Arc<Mutex<HashMap<String, Arc<PushRouter<serde_json::Value, Annotated<serde_json::Value>>>>>>;
44
/// Health check manager that monitors endpoint health
pub struct HealthCheckManager {
    /// Handle to the distributed runtime whose registered endpoints are monitored.
    drt: DistributedRuntime,
    /// Timing configuration: canary wait time and per-request timeout.
    config: HealthCheckConfig,
    /// Cache of PushRouters for each endpoint, keyed by endpoint subject.
    // NOTE(review): entries are inserted but never evicted in this file —
    // confirm a stale router is harmless if an endpoint's instances change.
    router_cache: RouterCache,
    /// Track per-endpoint health check tasks
    /// Maps: endpoint_subject -> task_handle
    // Handles are stored but not awaited/aborted here; also used to detect
    // duplicate registrations in the new-endpoint monitor.
    endpoint_tasks: Arc<Mutex<HashMap<String, JoinHandle<()>>>>,
}
55
impl HealthCheckManager {
    /// Construct a manager with empty router and task caches.
    /// Nothing runs until [`HealthCheckManager::start`] is called.
    pub fn new(drt: DistributedRuntime, config: HealthCheckConfig) -> Self {
        Self {
            drt,
            config,
            router_cache: Arc::new(Mutex::new(HashMap::new())),
            endpoint_tasks: Arc::new(Mutex::new(HashMap::new())),
        }
    }

    /// Get or create a PushRouter for an endpoint.
    ///
    /// Routers are cached under `cache_key` (the endpoint subject) so repeated
    /// health checks reuse one client instead of re-discovering instances.
    /// NOTE(review): cache entries are never evicted — confirm a stale router
    /// is harmless when an endpoint's instance set changes.
    async fn get_or_create_router(
        &self,
        cache_key: &str,
        endpoint: Endpoint,
    ) -> anyhow::Result<Arc<PushRouter<serde_json::Value, Annotated<serde_json::Value>>>> {
        let cache_key = cache_key.to_string();

        // Check cache first. Scoped so the parking_lot guard is dropped before
        // the awaits below — the guard must not be held across an await point.
        {
            let cache = self.router_cache.lock();
            if let Some(router) = cache.get(&cache_key) {
                return Ok(router.clone());
            }
        }

        // Create a client that discovers instances dynamically for this endpoint
        let client = Client::new_dynamic(endpoint).await?;

        // Create PushRouter - it will use direct routing when we call direct()
        let router: Arc<PushRouter<serde_json::Value, Annotated<serde_json::Value>>> = Arc::new(
            PushRouter::from_client(
                client,
                crate::pipeline::RouterMode::RoundRobin, // Default mode, we'll use direct() explicitly
            )
            .await?,
        );

        // Cache it. Two tasks may race past the check above; the later insert
        // simply replaces an equivalent router, so the race is benign.
        self.router_cache.lock().insert(cache_key, router.clone());

        Ok(router)
    }

    /// Start the health check manager by spawning per-endpoint monitoring tasks.
    ///
    /// Takes `Arc<Self>` because each spawned task keeps a clone of the manager.
    /// Intended to be called once: a second call fails because the new-endpoint
    /// receiver can only be taken once (see `spawn_new_endpoint_monitor`).
    pub async fn start(self: Arc<Self>) -> anyhow::Result<()> {
        // Get all registered endpoints at startup
        let targets = self.drt.system_health().lock().get_health_check_targets();

        info!(
            "Starting health check tasks for {} endpoints with canary_wait_time: {:?}",
            targets.len(),
            self.config.canary_wait_time
        );

        // Spawn a health check task for each registered endpoint
        for (endpoint_subject, _target) in targets {
            self.spawn_endpoint_health_check_task(endpoint_subject);
        }

        // CRITICAL: Spawn a task to monitor for NEW endpoints registered after startup
        // This uses a channel-based approach to guarantee no lost notifications
        // Will return an error if the receiver has already been taken
        self.spawn_new_endpoint_monitor().await?;

        info!("HealthCheckManager started successfully with channel-based endpoint discovery");
        Ok(())
    }

    /// Spawn a dedicated health check task for a specific endpoint.
    ///
    /// The task loops on a `select!` between a canary timer and the endpoint's
    /// activity notifier: activity restarts the timer; timer expiry sends a
    /// health check. The `JoinHandle` is stored in `endpoint_tasks` but is not
    /// awaited or aborted anywhere in this file.
    ///
    /// # Panics
    /// Panics if SystemHealth has no notifier for `endpoint_subject`; callers
    /// must only pass subjects that are already registered.
    fn spawn_endpoint_health_check_task(self: &Arc<Self>, endpoint_subject: String) {
        let manager = self.clone();
        let canary_wait = self.config.canary_wait_time;
        let endpoint_subject_clone = endpoint_subject.clone();

        // Get the endpoint-specific notifier
        let notifier = self
            .drt
            .system_health()
            .lock()
            .get_endpoint_health_check_notifier(&endpoint_subject)
            .expect("Notifier should exist for registered endpoint");

        let task = tokio::spawn(async move {
            let endpoint_subject = endpoint_subject_clone;
            info!("Health check task started for: {}", endpoint_subject);

            loop {
                // Wait for either timeout or activity notification.
                // The sleep future is recreated each iteration, so every
                // notification effectively restarts the canary countdown.
                tokio::select! {
                    _ = tokio::time::sleep(canary_wait) => {
                        // Timeout - send health check for this specific endpoint
                        info!("Canary timer expired for {}, sending health check", endpoint_subject);

                        // Get the health check payload for this endpoint
                        let target = manager.drt.system_health().lock().get_health_check_target(&endpoint_subject);

                        if let Some(target) = target {
                            if let Err(e) = manager.send_health_check_request(&endpoint_subject, &target.payload).await {
                                error!("Failed to send health check for {}: {}", endpoint_subject, e);
                            }
                        } else {
                            // This should never happen - targets are registered at startup and never removed
                            error!(
                                "CRITICAL: Health check target for {} disappeared unexpectedly! This indicates a bug. Stopping health check task.",
                                endpoint_subject
                            );
                            break;
                        }
                    }

                    _ = notifier.notified() => {
                        // Activity detected - reset timer for this endpoint only
                        debug!("Activity detected for {}, resetting health check timer", endpoint_subject);
                        // Loop continues, timer resets
                    }
                }
            }

            info!("Health check task for {} exiting", endpoint_subject);
        });

        // Store the task handle
        self.endpoint_tasks
            .lock()
            .insert(endpoint_subject.clone(), task);

        info!(
            "Spawned health check task for endpoint: {}",
            endpoint_subject
        );
    }

    /// Spawn a task to monitor for newly registered endpoints.
    ///
    /// Takes the single new-endpoint receiver out of SystemHealth, so this can
    /// succeed at most once per runtime. Each subject received over the channel
    /// gets its own health check task; a duplicate subject is treated as a bug
    /// and stops the monitor loop.
    ///
    /// Returns an error if the receiver has already been taken.
    async fn spawn_new_endpoint_monitor(self: &Arc<Self>) -> anyhow::Result<()> {
        let manager = self.clone();

        // Get the receiver (can only be taken once)
        let mut rx = manager
            .drt
            .system_health()
            .lock()
            .take_new_endpoint_receiver()
            .ok_or_else(|| {
                anyhow::anyhow!("Endpoint receiver already taken - this should only be called once")
            })?;

        tokio::spawn(async move {
            info!("Starting dynamic endpoint discovery monitor with channel-based notifications");

            while let Some(endpoint_subject) = rx.recv().await {
                debug!(
                    "Received endpoint registration via channel: {}",
                    endpoint_subject
                );

                // Scoped lock: check for an existing task without holding the
                // guard while spawning below.
                let already_exists = {
                    let tasks = manager.endpoint_tasks.lock();
                    tasks.contains_key(&endpoint_subject)
                };

                if already_exists {
                    error!(
                        "CRITICAL: Received registration for endpoint '{}' that already has a health check task!",
                        endpoint_subject
                    );
                    break;
                }

                info!(
                    "Spawning health check task for new endpoint: {}",
                    endpoint_subject
                );
                manager.spawn_endpoint_health_check_task(endpoint_subject);
            }

            info!("Endpoint discovery monitor exiting - no new endpoints will be monitored!");
        });

        info!("Dynamic endpoint discovery monitor started");
        Ok(())
    }

    /// Send a health check request through AsyncEngine.
    ///
    /// Resolves the target instance for `endpoint_subject`, routes `payload`
    /// directly to that instance, and spawns a task that awaits the first
    /// response (bounded by `config.request_timeout`) and records Ready /
    /// NotReady in SystemHealth. Returns as soon as the request task is
    /// spawned — it does not wait for the health check result.
    ///
    /// NOTE(review): `payload` is supplied by the caller even though the
    /// target (which also carries the payload) is re-fetched here — confirm
    /// the two cannot diverge between the caller's fetch and this one.
    async fn send_health_check_request(
        &self,
        endpoint_subject: &str,
        payload: &serde_json::Value,
    ) -> anyhow::Result<()> {
        let target = self
            .drt
            .system_health()
            .lock()
            .get_health_check_target(endpoint_subject)
            .ok_or_else(|| {
                anyhow::anyhow!("No health check target found for {}", endpoint_subject)
            })?;

        debug!(
            "Sending health check to {} (instance_id: {})",
            endpoint_subject, target.instance.instance_id
        );

        // Create the Endpoint directly from the Instance info
        let namespace = self.drt.namespace(&target.instance.namespace)?;
        let component = namespace.component(&target.instance.component)?;
        let endpoint = component.endpoint(&target.instance.endpoint);

        // Get or create router for this endpoint
        let router = self
            .get_or_create_router(endpoint_subject, endpoint)
            .await?;

        // Create the request context
        let request: SingleIn<serde_json::Value> = Context::new(payload.clone());

        // Clone what we need for the spawned task
        let system_health = self.drt.system_health().clone();
        let endpoint_subject_owned = endpoint_subject.to_string();
        let instance_id = target.instance.instance_id;
        let timeout = self.config.request_timeout;

        // Spawn task to send health check and wait for response
        tokio::spawn(async move {
            let result = tokio::time::timeout(timeout, async {
                // Call direct() on the PushRouter to target specific instance
                match router.direct(request, instance_id).await {
                    Ok(mut response_stream) => {
                        // Get the first response to verify endpoint is alive
                        let is_healthy = if let Some(response) = response_stream.next().await {
                            // Check if response indicates an error (via the MaybeError trait)
                            if let Some(error) = response.err() {
                                warn!(
                                    "Health check error response from {}: {:?}",
                                    endpoint_subject_owned, error
                                );
                                false
                            } else {
                                info!("Health check successful for {}", endpoint_subject_owned);
                                true
                            }
                        } else {
                            // Stream ended without yielding any response
                            warn!(
                                "Health check got no response from {}",
                                endpoint_subject_owned
                            );
                            false
                        };

                        // Update health status based on response
                        system_health.lock().set_endpoint_health_status(
                            &endpoint_subject_owned,
                            if is_healthy {
                                HealthStatus::Ready
                            } else {
                                HealthStatus::NotReady
                            },
                        );
                    }
                    Err(e) => {
                        // The request could not be sent or routed at all
                        error!(
                            "Health check request failed for {}: {}",
                            endpoint_subject_owned, e
                        );
                        system_health.lock().set_endpoint_health_status(
                            &endpoint_subject_owned,
                            HealthStatus::NotReady,
                        );
                    }
                }
            })
            .await;

            // Handle timeout: no response within `timeout` marks the endpoint NotReady
            if result.is_err() {
                warn!("Health check timeout for {}", endpoint_subject_owned);
                system_health
                    .lock()
                    .set_endpoint_health_status(&endpoint_subject_owned, HealthStatus::NotReady);
            }

            debug!("Health check completed for {}", endpoint_subject_owned);
        });

        Ok(())
    }
}
344
345/// Start health check manager for the distributed runtime
346pub async fn start_health_check_manager(
347    drt: DistributedRuntime,
348    config: Option<HealthCheckConfig>,
349) -> anyhow::Result<()> {
350    let config = config.unwrap_or_default();
351    let manager = Arc::new(HealthCheckManager::new(drt, config));
352
353    // Start the health check manager (this spawns per-endpoint tasks internally)
354    manager.start().await?;
355
356    Ok(())
357}
358
359/// Get health check status for all endpoints
360pub async fn get_health_check_status(
361    drt: &DistributedRuntime,
362) -> anyhow::Result<serde_json::Value> {
363    // Get endpoints list from SystemHealth
364    let endpoint_subjects: Vec<String> = drt.system_health().lock().get_health_check_endpoints();
365
366    let mut endpoint_statuses = HashMap::new();
367
368    // Check each endpoint's health status
369    {
370        let system_health = drt.system_health();
371        let system_health_lock = system_health.lock();
372        for endpoint_subject in &endpoint_subjects {
373            let health_status = system_health_lock
374                .get_endpoint_health_status(endpoint_subject)
375                .unwrap_or(HealthStatus::NotReady);
376
377            let is_healthy = matches!(health_status, HealthStatus::Ready);
378
379            endpoint_statuses.insert(
380                endpoint_subject.clone(),
381                serde_json::json!({
382                    "healthy": is_healthy,
383                    "status": format!("{:?}", health_status),
384                }),
385            );
386        }
387    }
388
389    let overall_healthy = endpoint_statuses
390        .values()
391        .all(|v| v["healthy"].as_bool().unwrap_or(false));
392
393    Ok(serde_json::json!({
394        "status": if overall_healthy { "ready" } else { "notready" },
395        "endpoints_checked": endpoint_subjects.len(),
396        "endpoint_statuses": endpoint_statuses,
397    }))
398}
399
// ===============================
// Integration Tests (require DRT)
// ===============================
#[cfg(all(test, feature = "integration"))]
mod integration_tests {
    use super::*;
    use crate::distributed::distributed_test_utils::create_test_drt_async;
    use std::sync::Arc;
    use std::time::Duration;

    /// Config values passed to `HealthCheckManager::new` are stored verbatim.
    #[tokio::test]
    async fn test_initialization() {
        let drt = create_test_drt_async().await;

        let canary_wait_time = Duration::from_secs(5);
        let request_timeout = Duration::from_secs(3);

        let config = HealthCheckConfig {
            canary_wait_time,
            request_timeout,
        };

        let manager = HealthCheckManager::new(drt.clone(), config);

        assert_eq!(manager.config.canary_wait_time, canary_wait_time);
        assert_eq!(manager.config.request_timeout, request_timeout);
    }

    /// A registered health check target can be retrieved with its payload,
    /// and the endpoint appears in the health-check endpoint list.
    #[tokio::test]
    async fn test_payload_registration() {
        let drt = create_test_drt_async().await;

        let endpoint = "test.endpoint";
        let payload = serde_json::json!({
            "prompt": "test",
            "_health_check": true
        });

        drt.system_health().lock().register_health_check_target(
            endpoint,
            crate::component::Instance {
                component: "test_component".to_string(),
                endpoint: "test_endpoint".to_string(),
                namespace: "test_namespace".to_string(),
                instance_id: 12345,
                transport: crate::component::TransportType::Nats(endpoint.to_string()),
            },
            payload.clone(),
        );

        let retrieved = drt
            .system_health()
            .lock()
            .get_health_check_target(endpoint)
            .map(|t| t.payload);
        assert!(retrieved.is_some());
        assert_eq!(retrieved.unwrap(), payload);

        // Verify endpoint appears in the list
        let endpoints = drt.system_health().lock().get_health_check_endpoints();
        assert!(endpoints.contains(&endpoint.to_string()));
    }

    /// `start` spawns exactly one health check task per registered endpoint.
    #[tokio::test]
    async fn test_spawn_per_endpoint_tasks() {
        let drt = create_test_drt_async().await;

        for i in 0..3 {
            let endpoint = format!("test.endpoint.{}", i);
            let payload = serde_json::json!({
                "prompt": format!("test{}", i),
                "_health_check": true
            });
            drt.system_health().lock().register_health_check_target(
                &endpoint,
                crate::component::Instance {
                    component: "test_component".to_string(),
                    endpoint: format!("test_endpoint_{}", i),
                    namespace: "test_namespace".to_string(),
                    instance_id: i,
                    transport: crate::component::TransportType::Nats(endpoint.clone()),
                },
                payload,
            );
        }

        let config = HealthCheckConfig {
            canary_wait_time: Duration::from_secs(5),
            request_timeout: Duration::from_secs(1),
        };

        let manager = Arc::new(HealthCheckManager::new(drt.clone(), config));
        manager.clone().start().await.unwrap();

        // Verify all endpoints have their own health check tasks
        let tasks = manager.endpoint_tasks.lock();
        // Should have 3 tasks (one for each endpoint)
        assert_eq!(tasks.len(), 3);
        // Check that all endpoints are represented in tasks
        let endpoints: Vec<String> = tasks.keys().cloned().collect();
        assert!(endpoints.contains(&"test.endpoint.0".to_string()));
        assert!(endpoints.contains(&"test.endpoint.1".to_string()));
        assert!(endpoints.contains(&"test.endpoint.2".to_string()));
    }

    /// Registering a target also creates a per-endpoint activity notifier,
    /// and the endpoint starts out NotReady.
    #[tokio::test]
    async fn test_endpoint_health_check_notifier_created() {
        let drt = create_test_drt_async().await;

        let endpoint = "test.endpoint.notifier";
        let payload = serde_json::json!({
            "prompt": "test",
            "_health_check": true
        });

        // Register the endpoint
        drt.system_health().lock().register_health_check_target(
            endpoint,
            crate::component::Instance {
                component: "test_component".to_string(),
                endpoint: "test_endpoint_notifier".to_string(),
                namespace: "test_namespace".to_string(),
                instance_id: 999,
                transport: crate::component::TransportType::Nats(endpoint.to_string()),
            },
            payload.clone(),
        );

        // Verify that a notifier was created for this endpoint
        let notifier = drt
            .system_health()
            .lock()
            .get_endpoint_health_check_notifier(endpoint);

        assert!(
            notifier.is_some(),
            "Endpoint should have a notifier created"
        );

        // Verify we can notify it without panicking
        if let Some(notifier) = notifier {
            notifier.notify_one();
        }

        // After registration the endpoint reports NotReady until a successful
        // health check flips it to Ready.
        let status = drt
            .system_health()
            .lock()
            .get_endpoint_health_status(endpoint);
        assert_eq!(status, Some(HealthStatus::NotReady));
    }
}
551}