dynamo_runtime/health_check.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use crate::component::{Client, Component, Endpoint, Instance};
5use crate::pipeline::PushRouter;
6use crate::pipeline::{AsyncEngine, Context, ManyOut, SingleIn};
7use crate::protocols::annotated::Annotated;
8use crate::protocols::maybe_error::MaybeError;
9use crate::{DistributedRuntime, HealthStatus, SystemHealth};
10use futures::StreamExt;
11use parking_lot::Mutex;
12use serde::{Deserialize, Serialize};
13use std::collections::HashMap;
14use std::sync::Arc;
15use std::time::{Duration, Instant};
16use tokio::task::JoinHandle;
17use tokio::time::{MissedTickBehavior, interval};
18use tracing::{debug, error, info, warn};
19
/// Configuration for health check behavior
#[derive(Debug, Clone)]
pub struct HealthCheckConfig {
    /// Wait time before sending canary health checks (when no activity)
    pub canary_wait_time: Duration,
    /// Timeout for health check requests
    pub request_timeout: Duration,
}
27
28impl Default for HealthCheckConfig {
29    fn default() -> Self {
30        Self {
31            canary_wait_time: Duration::from_secs(crate::config::DEFAULT_CANARY_WAIT_TIME_SECS),
32            request_timeout: Duration::from_secs(
33                crate::config::DEFAULT_HEALTH_CHECK_REQUEST_TIMEOUT_SECS,
34            ),
35        }
36    }
37}
38
// Type alias for the router cache to improve readability.
// Maps endpoint subject -> shared PushRouter for that endpoint.
// (Payloads are NOT stored here; they live in SystemHealth's targets.)
type RouterCache =
    Arc<Mutex<HashMap<String, Arc<PushRouter<serde_json::Value, Annotated<serde_json::Value>>>>>>;
43
/// Health check manager that monitors endpoint health
///
/// Spawns one long-lived task per registered endpoint plus a discovery task
/// for endpoints registered after startup; results are written back into the
/// runtime's `SystemHealth`.
pub struct HealthCheckManager {
    /// Runtime handle; used for `system_health` access and namespace lookup
    drt: DistributedRuntime,
    /// Timing configuration (canary wait + per-request timeout)
    config: HealthCheckConfig,
    /// Cache of PushRouters for each endpoint, keyed by endpoint subject
    router_cache: RouterCache,
    /// Track per-endpoint health check tasks
    /// Maps: endpoint_subject -> task_handle
    endpoint_tasks: Arc<Mutex<HashMap<String, JoinHandle<()>>>>,
}
54
55impl HealthCheckManager {
56    pub fn new(drt: DistributedRuntime, config: HealthCheckConfig) -> Self {
57        Self {
58            drt,
59            config,
60            router_cache: Arc::new(Mutex::new(HashMap::new())),
61            endpoint_tasks: Arc::new(Mutex::new(HashMap::new())),
62        }
63    }
64
65    /// Get or create a PushRouter for an endpoint
66    async fn get_or_create_router(
67        &self,
68        cache_key: &str,
69        endpoint: Endpoint,
70    ) -> anyhow::Result<Arc<PushRouter<serde_json::Value, Annotated<serde_json::Value>>>> {
71        let cache_key = cache_key.to_string();
72
73        // Check cache first
74        {
75            let cache = self.router_cache.lock();
76            if let Some(router) = cache.get(&cache_key) {
77                return Ok(router.clone());
78            }
79        }
80
81        // Create a client that discovers instances dynamically for this endpoint
82        let client = Client::new_dynamic(endpoint).await?;
83
84        // Create PushRouter - it will use direct routing when we call direct()
85        let router: Arc<PushRouter<serde_json::Value, Annotated<serde_json::Value>>> = Arc::new(
86            PushRouter::from_client(
87                client,
88                crate::pipeline::RouterMode::RoundRobin, // Default mode, we'll use direct() explicitly
89            )
90            .await?,
91        );
92
93        // Cache it
94        self.router_cache.lock().insert(cache_key, router.clone());
95
96        Ok(router)
97    }
98
99    /// Start the health check manager by spawning per-endpoint monitoring tasks
100    pub async fn start(self: Arc<Self>) -> anyhow::Result<()> {
101        // Get all registered endpoints at startup
102        let targets = {
103            let system_health = self.drt.system_health.lock();
104            system_health.get_health_check_targets()
105        };
106
107        info!(
108            "Starting health check tasks for {} endpoints with canary_wait_time: {:?}",
109            targets.len(),
110            self.config.canary_wait_time
111        );
112
113        // Spawn a health check task for each registered endpoint
114        for (endpoint_subject, _target) in targets {
115            self.spawn_endpoint_health_check_task(endpoint_subject);
116        }
117
118        // CRITICAL: Spawn a task to monitor for NEW endpoints registered after startup
119        // This uses a channel-based approach to guarantee no lost notifications
120        // Will return an error if the receiver has already been taken
121        self.spawn_new_endpoint_monitor().await?;
122
123        info!("HealthCheckManager started successfully with channel-based endpoint discovery");
124        Ok(())
125    }
126
    /// Spawn a dedicated health check task for a specific endpoint
    ///
    /// The task alternates between a `canary_wait`-long sleep and the
    /// endpoint's activity notifier: any activity restarts the sleep, so a
    /// canary health check is only sent after `canary_wait` of inactivity.
    /// The task handle is stored in `endpoint_tasks`, keyed by subject.
    fn spawn_endpoint_health_check_task(self: &Arc<Self>, endpoint_subject: String) {
        let manager = self.clone();
        let canary_wait = self.config.canary_wait_time;
        let endpoint_subject_clone = endpoint_subject.clone();

        // Get the endpoint-specific notifier.
        // Invariant: registration creates the notifier, so it must exist here;
        // a panic indicates a registration bug upstream.
        let notifier = {
            let system_health = self.drt.system_health.lock();
            system_health
                .get_endpoint_health_check_notifier(&endpoint_subject)
                .expect("Notifier should exist for registered endpoint")
        };

        let task = tokio::spawn(async move {
            let endpoint_subject = endpoint_subject_clone;
            info!("Health check task started for: {}", endpoint_subject);

            loop {
                // Wait for either timeout or activity notification.
                // The sleep future is re-created each iteration, so whichever
                // arm completes effectively resets the canary timer.
                tokio::select! {
                    _ = tokio::time::sleep(canary_wait) => {
                        // Timeout - send health check for this specific endpoint
                        info!("Canary timer expired for {}, sending health check", endpoint_subject);

                        // Get the health check payload for this endpoint
                        // (guard dropped before the await below).
                        let target = {
                            let system_health = manager.drt.system_health.lock();
                            system_health.get_health_check_target(&endpoint_subject)
                        };

                        if let Some(target) = target {
                            if let Err(e) = manager.send_health_check_request(&endpoint_subject, &target.payload).await {
                                error!("Failed to send health check for {}: {}", endpoint_subject, e);
                            }
                        } else {
                            // This should never happen - targets are registered at startup and never removed
                            error!(
                                "CRITICAL: Health check target for {} disappeared unexpectedly! This indicates a bug. Stopping health check task.",
                                endpoint_subject
                            );
                            break;
                        }
                    }

                    _ = notifier.notified() => {
                        // Activity detected - reset timer for this endpoint only
                        debug!("Activity detected for {}, resetting health check timer", endpoint_subject);
                        // Loop continues, timer resets
                    }
                }
            }

            info!("Health check task for {} exiting", endpoint_subject);
        });

        // Store the task handle so the task's existence is observable
        // (and duplicates are detectable by the discovery monitor).
        self.endpoint_tasks
            .lock()
            .insert(endpoint_subject.clone(), task);

        info!(
            "Spawned health check task for endpoint: {}",
            endpoint_subject
        );
    }
193
194    /// Spawn a task to monitor for newly registered endpoints
195    /// Returns an error if duplicate endpoints are detected, indicating a bug in the system
196    async fn spawn_new_endpoint_monitor(self: &Arc<Self>) -> anyhow::Result<()> {
197        let manager = self.clone();
198
199        // Get the receiver (can only be taken once)
200        let mut rx = {
201            let system_health = manager.drt.system_health.lock();
202            system_health.take_new_endpoint_receiver().ok_or_else(|| {
203                anyhow::anyhow!("Endpoint receiver already taken - this should only be called once")
204            })?
205        };
206
207        tokio::spawn(async move {
208            info!("Starting dynamic endpoint discovery monitor with channel-based notifications");
209
210            while let Some(endpoint_subject) = rx.recv().await {
211                debug!(
212                    "Received endpoint registration via channel: {}",
213                    endpoint_subject
214                );
215
216                let already_exists = {
217                    let tasks = manager.endpoint_tasks.lock();
218                    tasks.contains_key(&endpoint_subject)
219                };
220
221                if already_exists {
222                    error!(
223                        "CRITICAL: Received registration for endpoint '{}' that already has a health check task!",
224                        endpoint_subject
225                    );
226                    break;
227                }
228
229                info!(
230                    "Spawning health check task for new endpoint: {}",
231                    endpoint_subject
232                );
233                manager.spawn_endpoint_health_check_task(endpoint_subject);
234            }
235
236            info!("Endpoint discovery monitor exiting - no new endpoints will be monitored!");
237        });
238
239        info!("Dynamic endpoint discovery monitor started");
240        Ok(())
241    }
242
    /// Send a health check request through AsyncEngine
    ///
    /// Resolves the current target instance for `endpoint_subject`, routes
    /// `payload` directly to that instance, and updates the endpoint's status
    /// in `SystemHealth` based on the outcome (first response, error, or
    /// timeout). The send and response handling run in a spawned task, so
    /// this returns as soon as the request has been dispatched.
    async fn send_health_check_request(
        &self,
        endpoint_subject: &str,
        payload: &serde_json::Value,
    ) -> anyhow::Result<()> {
        // Re-read the target here (rather than trusting the caller's copy)
        // so routing uses the freshest instance info.
        let target = {
            let system_health = self.drt.system_health.lock();
            system_health
                .get_health_check_target(endpoint_subject)
                .ok_or_else(|| {
                    anyhow::anyhow!("No health check target found for {}", endpoint_subject)
                })?
        };

        debug!(
            "Sending health check to {} (instance_id: {})",
            endpoint_subject, target.instance.instance_id
        );

        // Create the Endpoint directly from the Instance info
        let namespace = self.drt.namespace(&target.instance.namespace)?;
        let component = namespace.component(&target.instance.component)?;
        let endpoint = component.endpoint(&target.instance.endpoint);

        // Get or create router for this endpoint (cached after first use)
        let router = self
            .get_or_create_router(endpoint_subject, endpoint)
            .await?;

        // Create the request context from the caller-supplied payload
        let request: SingleIn<serde_json::Value> = Context::new(payload.clone());

        // Clone what we need for the spawned task
        let system_health = self.drt.system_health.clone();
        let endpoint_subject_owned = endpoint_subject.to_string();
        let instance_id = target.instance.instance_id;
        let timeout = self.config.request_timeout;

        // Spawn task to send health check and wait for response.
        // The whole exchange is bounded by `timeout`.
        tokio::spawn(async move {
            let result = tokio::time::timeout(timeout, async {
                // Call direct() on the PushRouter to target specific instance
                match router.direct(request, instance_id).await {
                    Ok(mut response_stream) => {
                        // Only the first stream item is examined to decide
                        // liveness; the rest of the stream is dropped.
                        let is_healthy = if let Some(response) = response_stream.next().await {
                            // Check if response indicates an error (MaybeError)
                            if let Some(error) = response.err() {
                                warn!(
                                    "Health check error response from {}: {:?}",
                                    endpoint_subject_owned, error
                                );
                                false
                            } else {
                                info!("Health check successful for {}", endpoint_subject_owned);
                                true
                            }
                        } else {
                            // Stream closed without yielding anything.
                            warn!(
                                "Health check got no response from {}",
                                endpoint_subject_owned
                            );
                            false
                        };

                        // Update health status based on response
                        system_health.lock().set_endpoint_health_status(
                            &endpoint_subject_owned,
                            if is_healthy {
                                HealthStatus::Ready
                            } else {
                                HealthStatus::NotReady
                            },
                        );
                    }
                    Err(e) => {
                        // Request could not be delivered at all.
                        error!(
                            "Health check request failed for {}: {}",
                            endpoint_subject_owned, e
                        );
                        system_health.lock().set_endpoint_health_status(
                            &endpoint_subject_owned,
                            HealthStatus::NotReady,
                        );
                    }
                }
            })
            .await;

            // Handle timeout: no response within `timeout` marks NotReady.
            if result.is_err() {
                warn!("Health check timeout for {}", endpoint_subject_owned);
                system_health
                    .lock()
                    .set_endpoint_health_status(&endpoint_subject_owned, HealthStatus::NotReady);
            }

            debug!("Health check completed for {}", endpoint_subject_owned);
        });

        Ok(())
    }
346}
347
348/// Start health check manager for the distributed runtime
349pub async fn start_health_check_manager(
350    drt: DistributedRuntime,
351    config: Option<HealthCheckConfig>,
352) -> anyhow::Result<()> {
353    let config = config.unwrap_or_default();
354    let manager = Arc::new(HealthCheckManager::new(drt, config));
355
356    // Start the health check manager (this spawns per-endpoint tasks internally)
357    manager.start().await?;
358
359    Ok(())
360}
361
362/// Get health check status for all endpoints
363pub async fn get_health_check_status(
364    drt: &DistributedRuntime,
365) -> anyhow::Result<serde_json::Value> {
366    // Get endpoints list from SystemHealth
367    let endpoint_subjects: Vec<String> = {
368        let system_health = drt.system_health.lock();
369        system_health.get_health_check_endpoints()
370    };
371
372    let mut endpoint_statuses = HashMap::new();
373
374    // Check each endpoint's health status
375    {
376        let system_health = drt.system_health.lock();
377        for endpoint_subject in &endpoint_subjects {
378            let health_status = system_health
379                .get_endpoint_health_status(endpoint_subject)
380                .unwrap_or(HealthStatus::NotReady);
381
382            let is_healthy = matches!(health_status, HealthStatus::Ready);
383
384            endpoint_statuses.insert(
385                endpoint_subject.clone(),
386                serde_json::json!({
387                    "healthy": is_healthy,
388                    "status": format!("{:?}", health_status),
389                }),
390            );
391        }
392    }
393
394    let overall_healthy = endpoint_statuses
395        .values()
396        .all(|v| v["healthy"].as_bool().unwrap_or(false));
397
398    Ok(serde_json::json!({
399        "status": if overall_healthy { "ready" } else { "notready" },
400        "endpoints_checked": endpoint_subjects.len(),
401        "endpoint_statuses": endpoint_statuses,
402    }))
403}
404
405// ===============================
406// Integration Tests (require DRT)
407// ===============================
408#[cfg(all(test, feature = "integration"))]
409mod integration_tests {
410    use super::*;
411    use crate::HealthStatus;
412    use crate::distributed::distributed_test_utils::create_test_drt_async;
413    use std::sync::Arc;
414    use std::time::Duration;
415
416    #[tokio::test]
417    async fn test_initialization() {
418        let drt = create_test_drt_async().await;
419
420        let canary_wait_time = Duration::from_secs(5);
421        let request_timeout = Duration::from_secs(3);
422
423        let config = HealthCheckConfig {
424            canary_wait_time,
425            request_timeout,
426        };
427
428        let manager = HealthCheckManager::new(drt.clone(), config);
429
430        assert_eq!(manager.config.canary_wait_time, canary_wait_time);
431        assert_eq!(manager.config.request_timeout, request_timeout);
432
433        assert!(Arc::ptr_eq(&manager.drt.system_health, &drt.system_health));
434    }
435
436    #[tokio::test]
437    async fn test_payload_registration() {
438        let drt = create_test_drt_async().await;
439
440        let endpoint = "test.endpoint";
441        let payload = serde_json::json!({
442            "prompt": "test",
443            "_health_check": true
444        });
445
446        drt.system_health.lock().register_health_check_target(
447            endpoint,
448            crate::component::Instance {
449                component: "test_component".to_string(),
450                endpoint: "test_endpoint".to_string(),
451                namespace: "test_namespace".to_string(),
452                instance_id: 12345,
453                transport: crate::component::TransportType::NatsTcp(endpoint.to_string()),
454            },
455            payload.clone(),
456        );
457
458        let retrieved = drt
459            .system_health
460            .lock()
461            .get_health_check_target(endpoint)
462            .map(|t| t.payload);
463        assert!(retrieved.is_some());
464        assert_eq!(retrieved.unwrap(), payload);
465
466        // Verify endpoint appears in the list
467        let endpoints = drt.system_health.lock().get_health_check_endpoints();
468        assert!(endpoints.contains(&endpoint.to_string()));
469    }
470
471    #[tokio::test]
472    async fn test_spawn_per_endpoint_tasks() {
473        let drt = create_test_drt_async().await;
474
475        for i in 0..3 {
476            let endpoint = format!("test.endpoint.{}", i);
477            let payload = serde_json::json!({
478                "prompt": format!("test{}", i),
479                "_health_check": true
480            });
481            drt.system_health.lock().register_health_check_target(
482                &endpoint,
483                crate::component::Instance {
484                    component: "test_component".to_string(),
485                    endpoint: format!("test_endpoint_{}", i),
486                    namespace: "test_namespace".to_string(),
487                    instance_id: i,
488                    transport: crate::component::TransportType::NatsTcp(endpoint.clone()),
489                },
490                payload,
491            );
492        }
493
494        let config = HealthCheckConfig {
495            canary_wait_time: Duration::from_secs(5),
496            request_timeout: Duration::from_secs(1),
497        };
498
499        let manager = Arc::new(HealthCheckManager::new(drt.clone(), config));
500        manager.clone().start().await.unwrap();
501
502        // Verify all endpoints have their own health check tasks
503        let tasks = manager.endpoint_tasks.lock();
504        // Should have 3 tasks (one for each endpoint)
505        assert_eq!(tasks.len(), 3);
506        // Check that all endpoints are represented in tasks
507        let endpoints: Vec<String> = tasks.keys().cloned().collect();
508        assert!(endpoints.contains(&"test.endpoint.0".to_string()));
509        assert!(endpoints.contains(&"test.endpoint.1".to_string()));
510        assert!(endpoints.contains(&"test.endpoint.2".to_string()));
511    }
512
    #[tokio::test]
    async fn test_endpoint_health_check_notifier_created() {
        let drt = create_test_drt_async().await;

        let endpoint = "test.endpoint.notifier";
        let payload = serde_json::json!({
            "prompt": "test",
            "_health_check": true
        });

        // Register the endpoint
        drt.system_health.lock().register_health_check_target(
            endpoint,
            crate::component::Instance {
                component: "test_component".to_string(),
                endpoint: "test_endpoint_notifier".to_string(),
                namespace: "test_namespace".to_string(),
                instance_id: 999,
                transport: crate::component::TransportType::NatsTcp(endpoint.to_string()),
            },
            payload.clone(),
        );

        // Verify that a notifier was created for this endpoint
        let notifier = drt
            .system_health
            .lock()
            .get_endpoint_health_check_notifier(endpoint);

        assert!(
            notifier.is_some(),
            "Endpoint should have a notifier created"
        );

        // Verify we can notify it without panicking
        if let Some(notifier) = notifier {
            notifier.notify_one();
        }

        // A freshly registered endpoint starts as NotReady until a health
        // check succeeds (the assertion below pins that default).
        let status = drt
            .system_health
            .lock()
            .get_endpoint_health_status(endpoint);
        assert_eq!(status, Some(HealthStatus::NotReady));
    }
559}