dynamo_runtime/
health_check.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use crate::component::{Client, Component, Endpoint, Instance};
5use crate::pipeline::PushRouter;
6use crate::pipeline::{AsyncEngine, Context, ManyOut, SingleIn};
7use crate::protocols::annotated::Annotated;
8use crate::protocols::maybe_error::MaybeError;
9use crate::{DistributedRuntime, HealthStatus, SystemHealth};
10use futures::StreamExt;
11use serde::{Deserialize, Serialize};
12use std::collections::HashMap;
13use std::sync::{Arc, Mutex};
14use std::time::{Duration, Instant};
15use tokio::task::JoinHandle;
16use tokio::time::{MissedTickBehavior, interval};
17use tracing::{debug, error, info, warn};
18
/// Configuration for health check behavior
pub struct HealthCheckConfig {
    /// Wait time before sending canary health checks (when no activity).
    /// Any activity notification on an endpoint restarts this wait from zero.
    pub canary_wait_time: Duration,
    /// Timeout for a single health check request; on expiry the endpoint is
    /// marked `NotReady`.
    pub request_timeout: Duration,
}
26
27impl Default for HealthCheckConfig {
28    fn default() -> Self {
29        Self {
30            canary_wait_time: Duration::from_secs(crate::config::DEFAULT_CANARY_WAIT_TIME_SECS),
31            request_timeout: Duration::from_secs(
32                crate::config::DEFAULT_HEALTH_CHECK_REQUEST_TIMEOUT_SECS,
33            ),
34        }
35    }
36}
37
// Type alias for the router cache to improve readability.
// Maps endpoint subject -> shared PushRouter for that endpoint.
// (Payloads are NOT cached here — they are read from SystemHealth at send
// time; an earlier comment claiming "router and payload" was stale.)
type RouterCache =
    Arc<Mutex<HashMap<String, Arc<PushRouter<serde_json::Value, Annotated<serde_json::Value>>>>>>;
42
/// Health check manager that monitors endpoint health
///
/// Owns one monitoring task per registered endpoint plus a discovery task
/// that picks up endpoints registered after startup (see `start`).
pub struct HealthCheckManager {
    // Runtime handle; `system_health` inside it is the shared source of truth
    // for targets, notifiers, and health statuses.
    drt: DistributedRuntime,
    config: HealthCheckConfig,
    /// Cache of PushRouters for each endpoint (keyed by endpoint subject)
    router_cache: RouterCache,
    /// Track per-endpoint health check tasks
    /// Maps: endpoint_subject -> task_handle
    /// Handles are stored but never awaited/aborted here — tasks run for the
    /// life of the process unless they break out of their loop.
    endpoint_tasks: Arc<Mutex<HashMap<String, JoinHandle<()>>>>,
}
53
impl HealthCheckManager {
    /// Create a manager bound to `drt`. No tasks are spawned until
    /// [`Self::start`] is called.
    pub fn new(drt: DistributedRuntime, config: HealthCheckConfig) -> Self {
        Self {
            drt,
            config,
            router_cache: Arc::new(Mutex::new(HashMap::new())),
            endpoint_tasks: Arc::new(Mutex::new(HashMap::new())),
        }
    }

    /// Get or create a PushRouter for an endpoint
    ///
    /// `cache_key` is the endpoint subject; `endpoint` is only used when no
    /// router is cached yet.
    ///
    /// NOTE(review): the mutex is released between the cache check and the
    /// insert (required, since router construction awaits). Two concurrent
    /// callers can therefore both build a router; the later `insert`
    /// overwrites the earlier entry and the extra router is dropped. This
    /// looks harmless but is worth confirming against PushRouter semantics.
    async fn get_or_create_router(
        &self,
        cache_key: &str,
        endpoint: Endpoint,
    ) -> anyhow::Result<Arc<PushRouter<serde_json::Value, Annotated<serde_json::Value>>>> {
        let cache_key = cache_key.to_string();

        // Check cache first (scoped block so the std Mutex guard is dropped
        // before the awaits below — never hold a std guard across .await)
        {
            let cache = self.router_cache.lock().unwrap();
            if let Some(router) = cache.get(&cache_key) {
                return Ok(router.clone());
            }
        }

        // Create a client that discovers instances dynamically for this endpoint
        let client = Client::new_dynamic(endpoint).await?;

        // Create PushRouter - it will use direct routing when we call direct()
        let router: Arc<PushRouter<serde_json::Value, Annotated<serde_json::Value>>> = Arc::new(
            PushRouter::from_client(
                client,
                crate::pipeline::RouterMode::RoundRobin, // Default mode, we'll use direct() explicitly
            )
            .await?,
        );

        // Cache it
        self.router_cache
            .lock()
            .unwrap()
            .insert(cache_key, router.clone());

        Ok(router)
    }

    /// Start the health check manager by spawning per-endpoint monitoring tasks
    ///
    /// Endpoints already registered in `SystemHealth` each get a dedicated
    /// task; endpoints registered later are picked up by the channel-based
    /// monitor spawned at the end. Errors if the new-endpoint receiver was
    /// already taken (i.e. `start` was called more than once).
    pub async fn start(self: Arc<Self>) -> anyhow::Result<()> {
        // Get all registered endpoints at startup
        let targets = {
            let system_health = self.drt.system_health.lock().unwrap();
            system_health.get_health_check_targets()
        };

        info!(
            "Starting health check tasks for {} endpoints with canary_wait_time: {:?}",
            targets.len(),
            self.config.canary_wait_time
        );

        // Spawn a health check task for each registered endpoint
        for (endpoint_subject, _target) in targets {
            self.spawn_endpoint_health_check_task(endpoint_subject);
        }

        // CRITICAL: Spawn a task to monitor for NEW endpoints registered after startup
        // This uses a channel-based approach to guarantee no lost notifications
        // Will return an error if the receiver has already been taken
        self.spawn_new_endpoint_monitor().await?;

        info!("HealthCheckManager started successfully with channel-based endpoint discovery");
        Ok(())
    }

    /// Spawn a dedicated health check task for a specific endpoint
    ///
    /// The task alternates between a canary-wait timer and the endpoint's
    /// activity notifier: activity resets the timer; timer expiry triggers a
    /// health check probe. The task only exits if its target disappears from
    /// `SystemHealth` (treated as a bug).
    fn spawn_endpoint_health_check_task(self: &Arc<Self>, endpoint_subject: String) {
        let manager = self.clone();
        let canary_wait = self.config.canary_wait_time;
        let endpoint_subject_clone = endpoint_subject.clone();

        // Get the endpoint-specific notifier
        // expect() is justified: registration creates the notifier, and this
        // is only called for registered endpoints.
        let notifier = {
            let system_health = self.drt.system_health.lock().unwrap();
            system_health
                .get_endpoint_health_check_notifier(&endpoint_subject)
                .expect("Notifier should exist for registered endpoint")
        };

        let task = tokio::spawn(async move {
            let endpoint_subject = endpoint_subject_clone;
            info!("Health check task started for: {}", endpoint_subject);

            loop {
                // Wait for either timeout or activity notification.
                // The sleep future is re-created on every iteration, so each
                // notification restarts the full canary_wait interval.
                tokio::select! {
                    _ = tokio::time::sleep(canary_wait) => {
                        // Timeout - send health check for this specific endpoint
                        info!("Canary timer expired for {}, sending health check", endpoint_subject);

                        // Get the health check payload for this endpoint
                        // (lock scoped: guard dropped before the await below)
                        let target = {
                            let system_health = manager.drt.system_health.lock().unwrap();
                            system_health.get_health_check_target(&endpoint_subject)
                        };

                        if let Some(target) = target {
                            if let Err(e) = manager.send_health_check_request(&endpoint_subject, &target.payload).await {
                                error!("Failed to send health check for {}: {}", endpoint_subject, e);
                            }
                        } else {
                            // This should never happen - targets are registered at startup and never removed
                            error!(
                                "CRITICAL: Health check target for {} disappeared unexpectedly! This indicates a bug. Stopping health check task.",
                                endpoint_subject
                            );
                            break;
                        }
                    }

                    _ = notifier.notified() => {
                        // Activity detected - reset timer for this endpoint only
                        debug!("Activity detected for {}, resetting health check timer", endpoint_subject);
                        // Loop continues, timer resets
                    }
                }
            }

            info!("Health check task for {} exiting", endpoint_subject);
        });

        // Store the task handle
        self.endpoint_tasks
            .lock()
            .unwrap()
            .insert(endpoint_subject.clone(), task);

        info!(
            "Spawned health check task for endpoint: {}",
            endpoint_subject
        );
    }

    /// Spawn a task to monitor for newly registered endpoints
    /// Returns an error if duplicate endpoints are detected, indicating a bug in the system
    ///
    /// NOTE(review): on a duplicate registration the loop `break`s, which
    /// stops discovery for ALL future endpoints, not just the duplicate —
    /// confirm this fail-stop behavior is intended rather than `continue`.
    async fn spawn_new_endpoint_monitor(self: &Arc<Self>) -> anyhow::Result<()> {
        let manager = self.clone();

        // Get the receiver (can only be taken once)
        let mut rx = {
            let system_health = manager.drt.system_health.lock().unwrap();
            system_health.take_new_endpoint_receiver().ok_or_else(|| {
                anyhow::anyhow!("Endpoint receiver already taken - this should only be called once")
            })?
        };

        tokio::spawn(async move {
            info!("Starting dynamic endpoint discovery monitor with channel-based notifications");

            // recv() returns None when all senders are dropped, ending the loop.
            while let Some(endpoint_subject) = rx.recv().await {
                debug!(
                    "Received endpoint registration via channel: {}",
                    endpoint_subject
                );

                let already_exists = {
                    let tasks = manager.endpoint_tasks.lock().unwrap();
                    tasks.contains_key(&endpoint_subject)
                };

                if already_exists {
                    error!(
                        "CRITICAL: Received registration for endpoint '{}' that already has a health check task!",
                        endpoint_subject
                    );
                    break;
                }

                info!(
                    "Spawning health check task for new endpoint: {}",
                    endpoint_subject
                );
                manager.spawn_endpoint_health_check_task(endpoint_subject);
            }

            info!("Endpoint discovery monitor exiting - no new endpoints will be monitored!");
        });

        info!("Dynamic endpoint discovery monitor started");
        Ok(())
    }

    /// Send a health check request through AsyncEngine
    ///
    /// `payload` becomes the request body. The target is re-read from
    /// `SystemHealth` to obtain namespace/component/endpoint/instance info
    /// for routing. The probe itself runs in a spawned task, so this returns
    /// `Ok(())` as soon as the task is launched; the outcome (Ready/NotReady,
    /// including on timeout) is recorded asynchronously via
    /// `set_endpoint_health_status`.
    async fn send_health_check_request(
        &self,
        endpoint_subject: &str,
        payload: &serde_json::Value,
    ) -> anyhow::Result<()> {
        let target = {
            let system_health = self.drt.system_health.lock().unwrap();
            system_health
                .get_health_check_target(endpoint_subject)
                .ok_or_else(|| {
                    anyhow::anyhow!("No health check target found for {}", endpoint_subject)
                })?
        };

        debug!(
            "Sending health check to {} (instance_id: {})",
            endpoint_subject, target.instance.instance_id
        );

        // Create the Endpoint directly from the Instance info
        let namespace = self.drt.namespace(&target.instance.namespace)?;
        let component = namespace.component(&target.instance.component)?;
        let endpoint = component.endpoint(&target.instance.endpoint);

        // Get or create router for this endpoint
        let router = self
            .get_or_create_router(endpoint_subject, endpoint)
            .await?;

        // Create the request context
        let request: SingleIn<serde_json::Value> = Context::new(payload.clone());

        // Clone what we need for the spawned task
        let system_health = self.drt.system_health.clone();
        let endpoint_subject_owned = endpoint_subject.to_string();
        let instance_id = target.instance.instance_id;
        let timeout = self.config.request_timeout;

        // Spawn task to send health check and wait for response
        // (fire-and-forget: the JoinHandle is dropped; results land in SystemHealth)
        tokio::spawn(async move {
            let result = tokio::time::timeout(timeout, async {
                // Call direct() on the PushRouter to target specific instance
                match router.direct(request, instance_id).await {
                    Ok(mut response_stream) => {
                        // Get the first response to verify endpoint is alive
                        // (only the first stream item is examined; the rest is dropped)
                        let is_healthy = if let Some(response) = response_stream.next().await {
                            // Check if response indicates an error
                            if let Some(error) = response.err() {
                                warn!(
                                    "Health check error response from {}: {:?}",
                                    endpoint_subject_owned, error
                                );
                                false
                            } else {
                                info!("Health check successful for {}", endpoint_subject_owned);
                                true
                            }
                        } else {
                            warn!(
                                "Health check got no response from {}",
                                endpoint_subject_owned
                            );
                            false
                        };

                        // Update health status based on response
                        system_health.lock().unwrap().set_endpoint_health_status(
                            &endpoint_subject_owned,
                            if is_healthy {
                                HealthStatus::Ready
                            } else {
                                HealthStatus::NotReady
                            },
                        );
                    }
                    Err(e) => {
                        error!(
                            "Health check request failed for {}: {}",
                            endpoint_subject_owned, e
                        );
                        system_health.lock().unwrap().set_endpoint_health_status(
                            &endpoint_subject_owned,
                            HealthStatus::NotReady,
                        );
                    }
                }
            })
            .await;

            // Handle timeout: the inner future was cancelled, so no status was
            // written above — record NotReady here.
            if result.is_err() {
                warn!("Health check timeout for {}", endpoint_subject_owned);
                system_health
                    .lock()
                    .unwrap()
                    .set_endpoint_health_status(&endpoint_subject_owned, HealthStatus::NotReady);
            }

            debug!("Health check completed for {}", endpoint_subject_owned);
        });

        Ok(())
    }
}
351
352/// Start health check manager for the distributed runtime
353pub async fn start_health_check_manager(
354    drt: DistributedRuntime,
355    config: Option<HealthCheckConfig>,
356) -> anyhow::Result<()> {
357    let config = config.unwrap_or_default();
358    let manager = Arc::new(HealthCheckManager::new(drt, config));
359
360    // Start the health check manager (this spawns per-endpoint tasks internally)
361    manager.start().await?;
362
363    Ok(())
364}
365
366/// Get health check status for all endpoints
367pub async fn get_health_check_status(
368    drt: &DistributedRuntime,
369) -> anyhow::Result<serde_json::Value> {
370    // Get endpoints list from SystemHealth
371    let endpoint_subjects: Vec<String> = {
372        let system_health = drt.system_health.lock().unwrap();
373        system_health.get_health_check_endpoints()
374    };
375
376    let mut endpoint_statuses = HashMap::new();
377
378    // Check each endpoint's health status
379    {
380        let system_health = drt.system_health.lock().unwrap();
381        for endpoint_subject in &endpoint_subjects {
382            let health_status = system_health
383                .get_endpoint_health_status(endpoint_subject)
384                .unwrap_or(HealthStatus::NotReady);
385
386            let is_healthy = matches!(health_status, HealthStatus::Ready);
387
388            endpoint_statuses.insert(
389                endpoint_subject.clone(),
390                serde_json::json!({
391                    "healthy": is_healthy,
392                    "status": format!("{:?}", health_status),
393                }),
394            );
395        }
396    }
397
398    let overall_healthy = endpoint_statuses
399        .values()
400        .all(|v| v["healthy"].as_bool().unwrap_or(false));
401
402    Ok(serde_json::json!({
403        "status": if overall_healthy { "ready" } else { "notready" },
404        "endpoints_checked": endpoint_subjects.len(),
405        "endpoint_statuses": endpoint_statuses,
406    }))
407}
408
// ===============================
// Integration Tests (require DRT)
// ===============================
#[cfg(all(test, feature = "integration"))]
mod integration_tests {
    use super::*;
    use crate::HealthStatus;
    use crate::distributed::distributed_test_utils::create_test_drt_async;
    use std::sync::Arc;
    use std::time::Duration;

    /// Manager construction stores the config verbatim and shares the
    /// runtime's SystemHealth (same Arc, not a copy).
    #[tokio::test]
    async fn test_initialization() {
        let drt = create_test_drt_async().await;

        let canary_wait_time = Duration::from_secs(5);
        let request_timeout = Duration::from_secs(3);

        let config = HealthCheckConfig {
            canary_wait_time,
            request_timeout,
        };

        let manager = HealthCheckManager::new(drt.clone(), config);

        assert_eq!(manager.config.canary_wait_time, canary_wait_time);
        assert_eq!(manager.config.request_timeout, request_timeout);

        // Arc::ptr_eq proves the manager shares state rather than cloning it.
        assert!(Arc::ptr_eq(&manager.drt.system_health, &drt.system_health));
    }

    /// Registering a target makes its payload retrievable and lists the
    /// endpoint in get_health_check_endpoints().
    #[tokio::test]
    async fn test_payload_registration() {
        let drt = create_test_drt_async().await;

        let endpoint = "test.endpoint";
        let payload = serde_json::json!({
            "prompt": "test",
            "_health_check": true
        });

        drt.system_health
            .lock()
            .unwrap()
            .register_health_check_target(
                endpoint,
                crate::component::Instance {
                    component: "test_component".to_string(),
                    endpoint: "test_endpoint".to_string(),
                    namespace: "test_namespace".to_string(),
                    instance_id: 12345,
                    transport: crate::component::TransportType::NatsTcp(endpoint.to_string()),
                },
                payload.clone(),
            );

        // Payload round-trips unchanged through SystemHealth.
        let retrieved = drt
            .system_health
            .lock()
            .unwrap()
            .get_health_check_target(endpoint)
            .map(|t| t.payload);
        assert!(retrieved.is_some());
        assert_eq!(retrieved.unwrap(), payload);

        // Verify endpoint appears in the list
        let endpoints = drt
            .system_health
            .lock()
            .unwrap()
            .get_health_check_endpoints();
        assert!(endpoints.contains(&endpoint.to_string()));
    }

    /// start() spawns exactly one monitoring task per pre-registered endpoint.
    #[tokio::test]
    async fn test_spawn_per_endpoint_tasks() {
        let drt = create_test_drt_async().await;

        // Register three distinct endpoints before the manager starts.
        for i in 0..3 {
            let endpoint = format!("test.endpoint.{}", i);
            let payload = serde_json::json!({
                "prompt": format!("test{}", i),
                "_health_check": true
            });
            drt.system_health
                .lock()
                .unwrap()
                .register_health_check_target(
                    &endpoint,
                    crate::component::Instance {
                        component: "test_component".to_string(),
                        endpoint: format!("test_endpoint_{}", i),
                        namespace: "test_namespace".to_string(),
                        instance_id: i as i64,
                        transport: crate::component::TransportType::NatsTcp(endpoint.clone()),
                    },
                    payload,
                );
        }

        // Long canary wait so no health check fires during the test.
        let config = HealthCheckConfig {
            canary_wait_time: Duration::from_secs(5),
            request_timeout: Duration::from_secs(1),
        };

        let manager = Arc::new(HealthCheckManager::new(drt.clone(), config));
        manager.clone().start().await.unwrap();

        // Verify all endpoints have their own health check tasks
        let tasks = manager.endpoint_tasks.lock().unwrap();
        // Should have 3 tasks (one for each endpoint)
        assert_eq!(tasks.len(), 3);
        // Check that all endpoints are represented in tasks
        let endpoints: Vec<String> = tasks.keys().cloned().collect();
        assert!(endpoints.contains(&"test.endpoint.0".to_string()));
        assert!(endpoints.contains(&"test.endpoint.1".to_string()));
        assert!(endpoints.contains(&"test.endpoint.2".to_string()));
    }

    /// Registration creates a per-endpoint activity notifier and leaves the
    /// endpoint in the Ready state by default.
    #[tokio::test]
    async fn test_endpoint_health_check_notifier_created() {
        let drt = create_test_drt_async().await;

        let endpoint = "test.endpoint.notifier";
        let payload = serde_json::json!({
            "prompt": "test",
            "_health_check": true
        });

        // Register the endpoint
        drt.system_health
            .lock()
            .unwrap()
            .register_health_check_target(
                endpoint,
                crate::component::Instance {
                    component: "test_component".to_string(),
                    endpoint: "test_endpoint_notifier".to_string(),
                    namespace: "test_namespace".to_string(),
                    instance_id: 999,
                    transport: crate::component::TransportType::NatsTcp(endpoint.to_string()),
                },
                payload.clone(),
            );

        // Verify that a notifier was created for this endpoint
        let notifier = drt
            .system_health
            .lock()
            .unwrap()
            .get_endpoint_health_check_notifier(endpoint);

        assert!(
            notifier.is_some(),
            "Endpoint should have a notifier created"
        );

        // Verify we can notify it without panicking
        if let Some(notifier) = notifier {
            notifier.notify_one();
        }

        // Initially, the endpoint should be Ready (default after registration)
        let status = drt
            .system_health
            .lock()
            .unwrap()
            .get_endpoint_health_status(endpoint);
        assert_eq!(status, Some(HealthStatus::Ready));
    }
}