Skip to main content

dynamo_runtime/
health_check.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use crate::DistributedRuntime;
5use crate::config::HealthStatus;
6use crate::engine::AsyncEngine;
7use crate::pipeline::SingleIn;
8use crate::protocols::maybe_error::MaybeError;
9use futures::StreamExt;
10use parking_lot::Mutex;
11use std::collections::HashMap;
12use std::sync::Arc;
13use std::time::Duration;
14use tokio::task::JoinHandle;
15use tracing::{debug, error, info, warn};
16
/// Configuration for health check behavior.
///
/// Derives `Debug` and `Clone` so the configuration can be logged and
/// duplicated by callers without hand-written boilerplate.
#[derive(Debug, Clone)]
pub struct HealthCheckConfig {
    /// Wait time before sending canary health checks (when no activity)
    pub canary_wait_time: Duration,
    /// Timeout for health check requests
    pub request_timeout: Duration,
}
24
25impl Default for HealthCheckConfig {
26    fn default() -> Self {
27        Self {
28            canary_wait_time: Duration::from_secs(crate::config::DEFAULT_CANARY_WAIT_TIME_SECS),
29            request_timeout: Duration::from_secs(
30                crate::config::DEFAULT_HEALTH_CHECK_REQUEST_TIMEOUT_SECS,
31            ),
32        }
33    }
34}
35
/// Health check manager that monitors endpoint health
///
/// One background task is kept per endpoint; each task's handle is stored in
/// `endpoint_tasks` under the endpoint subject.
pub struct HealthCheckManager {
    /// Runtime handle used to reach the system health state and the local
    /// endpoint registry.
    drt: DistributedRuntime,
    /// Canary wait time and request timeout used by the health check tasks.
    config: HealthCheckConfig,
    /// Track per-endpoint health check tasks
    /// Maps: endpoint_subject -> task_handle
    endpoint_tasks: Arc<Mutex<HashMap<String, JoinHandle<()>>>>,
}
44
45impl HealthCheckManager {
46    pub fn new(drt: DistributedRuntime, config: HealthCheckConfig) -> Self {
47        Self {
48            drt,
49            config,
50            endpoint_tasks: Arc::new(Mutex::new(HashMap::new())),
51        }
52    }
53
54    /// Start the health check manager by spawning per-endpoint monitoring tasks
55    pub async fn start(self: Arc<Self>) -> anyhow::Result<()> {
56        // Get all registered endpoints at startup
57        let targets = self.drt.system_health().lock().get_health_check_targets();
58
59        info!(
60            "Starting health check tasks for {} endpoints with canary_wait_time: {:?}",
61            targets.len(),
62            self.config.canary_wait_time
63        );
64
65        // Spawn a health check task for each registered endpoint
66        for (endpoint_subject, _target) in targets {
67            self.spawn_endpoint_health_check_task(endpoint_subject);
68        }
69
70        // CRITICAL: Spawn a task to monitor for NEW endpoints registered after startup
71        // This uses a channel-based approach to guarantee no lost notifications
72        // Will return an error if the receiver has already been taken
73        self.spawn_new_endpoint_monitor().await?;
74
75        info!("HealthCheckManager started successfully with channel-based endpoint discovery");
76        Ok(())
77    }
78
79    /// Spawn a dedicated health check task for a specific endpoint
80    fn spawn_endpoint_health_check_task(self: &Arc<Self>, endpoint_subject: String) {
81        let manager = self.clone();
82        let canary_wait = self.config.canary_wait_time;
83        let endpoint_subject_clone = endpoint_subject.clone();
84
85        // Get the endpoint-specific notifier
86        let notifier = self
87            .drt
88            .system_health()
89            .lock()
90            .get_endpoint_health_check_notifier(&endpoint_subject)
91            .expect("Notifier should exist for registered endpoint");
92
93        let task = tokio::spawn(async move {
94            let endpoint_subject = endpoint_subject_clone;
95            info!("Health check task started for: {}", endpoint_subject);
96
97            loop {
98                // Wait for either timeout or activity notification
99                tokio::select! {
100                    _ = tokio::time::sleep(canary_wait) => {
101                        // Timeout - send health check for this specific endpoint
102                        debug!("Canary timer expired for {}, sending health check", endpoint_subject);
103
104                        // Get the health check payload for this endpoint
105                        let target = manager.drt.system_health().lock().get_health_check_target(&endpoint_subject);
106
107                        if let Some(target) = target {
108                            if let Err(e) = manager.send_health_check_request(&endpoint_subject, &target.payload).await {
109                                error!("Failed to send health check for {}: {}", endpoint_subject, e);
110                            }
111                        } else {
112                            // This should never happen - targets are registered at startup and never removed
113                            error!(
114                                "CRITICAL: Health check target for {} disappeared unexpectedly! This indicates a bug. Stopping health check task.",
115                                endpoint_subject
116                            );
117                            break;
118                        }
119                    }
120
121                    _ = notifier.notified() => {
122                        // Activity detected - reset timer for this endpoint only.
123                        // A notification means push_handler successfully streamed
124                        // a non-error response chunk, proving the engine is healthy.
125                        debug!("Activity detected for {}, resetting health check timer", endpoint_subject);
126                        manager.drt.system_health().lock().set_endpoint_health_status(
127                            &endpoint_subject,
128                            crate::config::HealthStatus::Ready,
129                        );
130                    }
131                }
132            }
133
134            info!("Health check task for {} exiting", endpoint_subject);
135        });
136
137        // Store the task handle
138        self.endpoint_tasks
139            .lock()
140            .insert(endpoint_subject.clone(), task);
141
142        info!(
143            "Spawned health check task for endpoint: {}",
144            endpoint_subject
145        );
146    }
147
148    /// Spawn a task to monitor for newly registered endpoints
149    /// Returns an error if duplicate endpoints are detected, indicating a bug in the system
150    async fn spawn_new_endpoint_monitor(self: &Arc<Self>) -> anyhow::Result<()> {
151        let manager = self.clone();
152
153        // Get the receiver (can only be taken once)
154        let mut rx = manager
155            .drt
156            .system_health()
157            .lock()
158            .take_new_endpoint_receiver()
159            .ok_or_else(|| {
160                anyhow::anyhow!("Endpoint receiver already taken - this should only be called once")
161            })?;
162
163        tokio::spawn(async move {
164            info!("Starting dynamic endpoint discovery monitor with channel-based notifications");
165
166            while let Some(endpoint_subject) = rx.recv().await {
167                debug!(
168                    "Received endpoint registration via channel: {}",
169                    endpoint_subject
170                );
171
172                let already_exists = {
173                    let tasks = manager.endpoint_tasks.lock();
174                    tasks.contains_key(&endpoint_subject)
175                };
176
177                if already_exists {
178                    error!(
179                        "CRITICAL: Received registration for endpoint '{}' that already has a health check task!",
180                        endpoint_subject
181                    );
182                    break;
183                }
184
185                info!(
186                    "Spawning health check task for new endpoint: {}",
187                    endpoint_subject
188                );
189                manager.spawn_endpoint_health_check_task(endpoint_subject);
190            }
191
192            info!("Endpoint discovery monitor exiting - no new endpoints will be monitored!");
193        });
194
195        info!("Dynamic endpoint discovery monitor started");
196        Ok(())
197    }
198
199    /// Send a health check request via the local endpoint registry (in-process).
200    async fn send_health_check_request(
201        &self,
202        endpoint_subject: &str,
203        payload: &serde_json::Value,
204    ) -> anyhow::Result<()> {
205        debug!(
206            "Sending health check to {} via local registry",
207            endpoint_subject
208        );
209
210        let engine = self
211            .drt
212            .local_endpoint_registry()
213            .get(endpoint_subject)
214            .ok_or_else(|| {
215                anyhow::anyhow!(
216                    "Endpoint '{}' not found in local registry, engine may still be initializing",
217                    endpoint_subject
218                )
219            })?;
220
221        // Clone what we need for the spawned task
222        let system_health = self.drt.system_health().clone();
223        let endpoint_subject_owned = endpoint_subject.to_string();
224        let payload = payload.clone();
225        let timeout = self.config.request_timeout;
226
227        // Spawn task to send health check and wait for response
228        tokio::spawn(async move {
229            let result = tokio::time::timeout(timeout, async {
230                let request = SingleIn::new(payload);
231                match engine.generate(request).await {
232                    Ok(mut response_stream) => {
233                        // Get the first response to verify endpoint is alive.
234                        // Check for errors
235                        let is_healthy = if let Some(response) = response_stream.next().await {
236                            if let Some(error) = response.err() {
237                                warn!(
238                                    "Health check error response from {}: {:?}",
239                                    endpoint_subject_owned, error
240                                );
241                                false
242                            } else {
243                                debug!("Health check successful for {}", endpoint_subject_owned);
244                                true
245                            }
246                        } else {
247                            warn!(
248                                "Health check got no response from {}",
249                                endpoint_subject_owned
250                            );
251                            false
252                        };
253
254                        tokio::spawn(async move {
255                            // We need to consume the rest of the stream to avoid warnings on the frontend.
256                            response_stream.for_each(|_| async {}).await;
257                        });
258
259                        // Update health status based on response
260                        system_health.lock().set_endpoint_health_status(
261                            &endpoint_subject_owned,
262                            if is_healthy {
263                                HealthStatus::Ready
264                            } else {
265                                HealthStatus::NotReady
266                            },
267                        );
268                    }
269                    Err(e) => {
270                        error!(
271                            "Health check request failed for {}: {}",
272                            endpoint_subject_owned, e
273                        );
274                        system_health.lock().set_endpoint_health_status(
275                            &endpoint_subject_owned,
276                            HealthStatus::NotReady,
277                        );
278                    }
279                }
280            })
281            .await;
282
283            // Handle timeout
284            if result.is_err() {
285                warn!("Health check timeout for {}", endpoint_subject_owned);
286                system_health
287                    .lock()
288                    .set_endpoint_health_status(&endpoint_subject_owned, HealthStatus::NotReady);
289            }
290
291            debug!("Health check completed for {}", endpoint_subject_owned);
292        });
293
294        Ok(())
295    }
296}
297
298/// Start health check manager for the distributed runtime
299pub async fn start_health_check_manager(
300    drt: DistributedRuntime,
301    config: Option<HealthCheckConfig>,
302) -> anyhow::Result<()> {
303    let config = config.unwrap_or_default();
304    let manager = Arc::new(HealthCheckManager::new(drt, config));
305
306    // Start the health check manager (this spawns per-endpoint tasks internally)
307    manager.start().await?;
308
309    Ok(())
310}
311
312/// Get health check status for all endpoints
313pub async fn get_health_check_status(
314    drt: &DistributedRuntime,
315) -> anyhow::Result<serde_json::Value> {
316    // Get endpoints list from SystemHealth
317    let endpoint_subjects: Vec<String> = drt.system_health().lock().get_health_check_endpoints();
318
319    let mut endpoint_statuses = HashMap::new();
320
321    // Check each endpoint's health status
322    {
323        let system_health = drt.system_health();
324        let system_health_lock = system_health.lock();
325        for endpoint_subject in &endpoint_subjects {
326            let health_status = system_health_lock
327                .get_endpoint_health_status(endpoint_subject)
328                .unwrap_or(HealthStatus::NotReady);
329
330            let is_healthy = matches!(health_status, HealthStatus::Ready);
331
332            endpoint_statuses.insert(
333                endpoint_subject.clone(),
334                serde_json::json!({
335                    "healthy": is_healthy,
336                    "status": format!("{:?}", health_status),
337                }),
338            );
339        }
340    }
341
342    let overall_healthy = endpoint_statuses
343        .values()
344        .all(|v| v["healthy"].as_bool().unwrap_or(false));
345
346    Ok(serde_json::json!({
347        "status": if overall_healthy { "ready" } else { "notready" },
348        "endpoints_checked": endpoint_subjects.len(),
349        "endpoint_statuses": endpoint_statuses,
350    }))
351}
352
353// ============================================================
354// Full pipeline tests: push_handler → notify → HealthCheckManager
355// These tests use the real HealthCheckManager (spawn_endpoint_health_check_task)
356// and the real push_handler pipeline (TwoPartCodec + TCP + engine.generate()).
357// ============================================================
358#[cfg(all(test, feature = "integration"))]
359mod push_handler_notify_tests {
360    use super::*;
361    use crate::component::{Instance, TransportType};
362    use crate::config::HealthStatus;
363    use crate::distributed::distributed_test_utils::create_test_drt_async;
364    use crate::engine::{AsyncEngine, AsyncEngineContextProvider};
365    use crate::local_endpoint_registry::LocalAsyncEngine;
366    use crate::pipeline::network::codec::{TwoPartCodec, TwoPartMessage};
367    use crate::pipeline::network::tcp::server::{ServerOptions, TcpStreamServer};
368    use crate::pipeline::network::{
369        ConnectionInfo, Ingress, PushWorkHandler, ResponseService, StreamOptions,
370    };
371    use crate::pipeline::{ManyOut, ResponseStream, SingleIn};
372    use crate::protocols::annotated::Annotated;
373    use async_trait::async_trait;
374    use bytes::Bytes;
375    use futures::stream;
376    use std::sync::Arc;
377    use std::time::Duration;
378
379    type TestRequest = serde_json::Value;
380    type TestResponse = Annotated<serde_json::Value>;
381
382    /// A mock engine that streams a configurable sequence of success/error chunks.
383    /// Used both as the push_handler pipeline engine and registered in
384    /// the local endpoint registry for health check requests.
385    struct MockStreamingEngine {
386        num_chunks: usize,
387        /// If set, chunks at these indices will be error responses.
388        error_indices: Vec<usize>,
389    }
390
391    impl MockStreamingEngine {
392        fn success(num_chunks: usize) -> Arc<Self> {
393            Arc::new(Self {
394                num_chunks,
395                error_indices: vec![],
396            })
397        }
398
399        fn all_errors(num_chunks: usize) -> Arc<Self> {
400            Arc::new(Self {
401                num_chunks,
402                error_indices: (0..num_chunks).collect(),
403            })
404        }
405
406        fn with_error_at(num_chunks: usize, error_indices: Vec<usize>) -> Arc<Self> {
407            Arc::new(Self {
408                num_chunks,
409                error_indices,
410            })
411        }
412    }
413
414    #[async_trait]
415    impl AsyncEngine<SingleIn<TestRequest>, ManyOut<TestResponse>, anyhow::Error>
416        for MockStreamingEngine
417    {
418        async fn generate(
419            &self,
420            input: SingleIn<TestRequest>,
421        ) -> anyhow::Result<ManyOut<TestResponse>> {
422            let (_data, ctx) = input.into_parts();
423            let chunks: Vec<TestResponse> = (0..self.num_chunks)
424                .map(|i| {
425                    if self.error_indices.contains(&i) {
426                        Annotated::from_error(format!("mock error at chunk {i}"))
427                    } else {
428                        Annotated::from_data(serde_json::json!({"token": i}))
429                    }
430                })
431                .collect();
432            Ok(ResponseStream::new(
433                Box::pin(stream::iter(chunks)),
434                ctx.context(),
435            ))
436        }
437    }
438
439    /// Encodes a request as a TwoPartCodec payload with the given connection info.
440    fn encode_request(
441        request_id: &str,
442        connection_info: ConnectionInfo,
443        request_body: &serde_json::Value,
444    ) -> Bytes {
445        let control = serde_json::json!({
446            "id": request_id,
447            "request_type": "single_in",
448            "response_type": "many_out",
449            "connection_info": connection_info,
450        });
451        let header = serde_json::to_vec(&control).unwrap();
452        let data = serde_json::to_vec(request_body).unwrap();
453        let msg = TwoPartMessage::new(Bytes::from(header), Bytes::from(data));
454        TwoPartCodec::default().encode_message(msg).unwrap()
455    }
456
457    /// Sets up a TCP server and registers a response stream for push_handler
458    /// to connect back to.
459    async fn setup_tcp_receiver(request_id: &str) -> (Arc<TcpStreamServer>, ConnectionInfo) {
460        let options = ServerOptions::builder().port(0).build().unwrap();
461        let server = TcpStreamServer::new(options).await.unwrap();
462
463        let context = crate::pipeline::Context::with_id_and_metadata(
464            (),
465            request_id.to_string(),
466            Default::default(),
467        );
468        let stream_options = StreamOptions::builder()
469            .context(context.context())
470            .enable_request_stream(false)
471            .enable_response_stream(true)
472            .build()
473            .unwrap();
474
475        let pending = server.register(stream_options).await;
476        let connection_info = pending
477            .recv_stream
478            .as_ref()
479            .unwrap()
480            .connection_info
481            .clone();
482
483        (server, connection_info)
484    }
485
486    /// Registers an endpoint in the DRT with the given engine in local registry.
487    /// Returns the notifier that the real HealthCheckManager will listen on.
488    fn register_endpoint(
489        drt: &crate::DistributedRuntime,
490        endpoint_name: &str,
491        local_engine: LocalAsyncEngine,
492    ) -> Arc<tokio::sync::Notify> {
493        let payload = serde_json::json!({
494            "prompt": "health",
495            "_health_check": true
496        });
497        drt.system_health().lock().register_health_check_target(
498            endpoint_name,
499            Instance {
500                component: "test_component".to_string(),
501                endpoint: endpoint_name.to_string(),
502                namespace: "test_namespace".to_string(),
503                instance_id: 0,
504                transport: TransportType::Nats(endpoint_name.to_string()),
505                device_type: None,
506            },
507            payload,
508        );
509
510        drt.local_endpoint_registry()
511            .register(endpoint_name.to_string(), local_engine);
512
513        drt.system_health()
514            .lock()
515            .get_endpoint_health_check_notifier(endpoint_name)
516            .expect("Notifier should exist for registered endpoint")
517    }
518
519    /// Helper: send a request through the ingress pipeline.
520    async fn send_request(ingress: &Ingress<SingleIn<TestRequest>, ManyOut<TestResponse>>) {
521        let request_id = uuid::Uuid::new_v4().to_string();
522        let (_server, connection_info) = setup_tcp_receiver(&request_id).await;
523        let payload = encode_request(
524            &request_id,
525            connection_info,
526            &serde_json::json!({"prompt": "test"}),
527        );
528        let result = ingress.handle_payload(payload, Some(request_id)).await;
529        assert!(result.is_ok(), "handle_payload should succeed");
530    }
531
532    /// Helper: assert endpoint health status.
533    fn assert_status(
534        drt: &crate::DistributedRuntime,
535        endpoint_name: &str,
536        expected: HealthStatus,
537        msg: &str,
538    ) {
539        let status = drt
540            .system_health()
541            .lock()
542            .get_endpoint_health_status(endpoint_name);
543        assert_eq!(status, Some(expected), "{msg}");
544    }
545
546    /// Helper: create ingress pipeline with given engine and notifier.
547    fn create_ingress(
548        engine: Arc<MockStreamingEngine>,
549        notifier: Arc<tokio::sync::Notify>,
550    ) -> Arc<Ingress<SingleIn<TestRequest>, ManyOut<TestResponse>>> {
551        let ingress =
552            Ingress::<SingleIn<TestRequest>, ManyOut<TestResponse>>::for_engine(engine).unwrap();
553        ingress
554            .set_endpoint_health_check_notifier(notifier)
555            .unwrap();
556        ingress
557    }
558
559    /// Helper: start HealthCheckManager with given canary wait.
560    async fn start_manager(drt: &crate::DistributedRuntime, canary_wait_ms: u64) {
561        let config = HealthCheckConfig {
562            canary_wait_time: Duration::from_millis(canary_wait_ms),
563            request_timeout: Duration::from_secs(1),
564        };
565        let manager = Arc::new(HealthCheckManager::new(drt.clone(), config));
566        manager.start().await.unwrap();
567    }
568
569    // =================================================================
570    // Test 1: Successful streaming → notification → Ready
571    // Canary engine returns errors, so Ready can only come from notify.
572    // =================================================================
573    #[tokio::test]
574    async fn test_successful_streaming_sets_ready() {
575        let drt = create_test_drt_async().await;
576        let endpoint = "test.successful_streaming";
577
578        let notifier = register_endpoint(&drt, endpoint, MockStreamingEngine::all_errors(1));
579        assert_status(&drt, endpoint, HealthStatus::NotReady, "initial");
580
581        let ingress = create_ingress(MockStreamingEngine::success(5), notifier);
582        start_manager(&drt, 500).await;
583
584        send_request(&ingress).await;
585        tokio::time::sleep(Duration::from_millis(200)).await;
586
587        // Ready can only come from notification (canary engine errors)
588        assert_status(
589            &drt,
590            endpoint,
591            HealthStatus::Ready,
592            "successful streaming should set Ready via notification path",
593        );
594    }
595
596    // =================================================================
597    // Test 2: Idle engine → canary fires → successful health check → Ready
598    // =================================================================
599    #[tokio::test]
600    async fn test_canary_fires_on_idle_engine() {
601        let drt = create_test_drt_async().await;
602        let endpoint = "test.canary_idle";
603
604        let _notifier = register_endpoint(&drt, endpoint, MockStreamingEngine::success(1));
605        assert_status(&drt, endpoint, HealthStatus::NotReady, "initial");
606
607        start_manager(&drt, 50).await;
608        tokio::time::sleep(Duration::from_millis(300)).await;
609
610        // No requests sent — canary fired and succeeded
611        assert_status(
612            &drt,
613            endpoint,
614            HealthStatus::Ready,
615            "canary should fire and set Ready on idle engine",
616        );
617    }
618
619    // =================================================================
620    // Test 3: Error streaming → no notification → canary errors → NotReady
621    // =================================================================
622    #[tokio::test]
623    async fn test_error_streaming_stays_not_ready() {
624        let drt = create_test_drt_async().await;
625        let endpoint = "test.error_streaming";
626
627        let notifier = register_endpoint(&drt, endpoint, MockStreamingEngine::all_errors(1));
628        assert_status(&drt, endpoint, HealthStatus::NotReady, "initial");
629
630        // Pipeline streams only errors — no notifications sent
631        let ingress = create_ingress(MockStreamingEngine::all_errors(3), notifier);
632        start_manager(&drt, 50).await;
633
634        send_request(&ingress).await;
635        // Wait for canary to fire (50ms wait + margin)
636        tokio::time::sleep(Duration::from_millis(300)).await;
637
638        // Error streaming didn't notify, canary fired but engine also errored
639        assert_status(
640            &drt,
641            endpoint,
642            HealthStatus::NotReady,
643            "error streaming should not notify, canary also errors — stays NotReady",
644        );
645    }
646
647    // =================================================================
648    // Test 4: Idle engine → canary fires → failing health check → NotReady
649    // =================================================================
650    #[tokio::test]
651    async fn test_idle_engine_with_failing_canary() {
652        let drt = create_test_drt_async().await;
653        let endpoint = "test.canary_fails";
654
655        let _notifier = register_endpoint(&drt, endpoint, MockStreamingEngine::all_errors(1));
656        assert_status(&drt, endpoint, HealthStatus::NotReady, "initial");
657
658        start_manager(&drt, 50).await;
659        tokio::time::sleep(Duration::from_millis(300)).await;
660
661        // No requests sent, canary fired but engine returned error
662        assert_status(
663            &drt,
664            endpoint,
665            HealthStatus::NotReady,
666            "canary fired but engine errored, status stays NotReady",
667        );
668    }
669
670    // =================================================================
671    // Test 5: Mixed streaming (success + trailing error) → Ready
672    // Successful chunks notify before the error, so status becomes Ready.
673    // Canary engine errors, proving Ready came from notification path.
674    // =================================================================
675    #[tokio::test]
676    async fn test_mixed_streaming_sets_ready() {
677        let drt = create_test_drt_async().await;
678        let endpoint = "test.mixed_streaming";
679
680        let notifier = register_endpoint(&drt, endpoint, MockStreamingEngine::all_errors(1));
681        assert_status(&drt, endpoint, HealthStatus::NotReady, "initial");
682
683        // 5 chunks: 4 success + error at index 4
684        let ingress = create_ingress(MockStreamingEngine::with_error_at(5, vec![4]), notifier);
685        start_manager(&drt, 500).await;
686
687        send_request(&ingress).await;
688        tokio::time::sleep(Duration::from_millis(200)).await;
689
690        // Successful chunks notified before the error chunk
691        assert_status(
692            &drt,
693            endpoint,
694            HealthStatus::Ready,
695            "successful chunks should set Ready despite trailing error",
696        );
697    }
698}
699
700// ===============================
701// Integration Tests (require DRT)
702// ===============================
703#[cfg(all(test, feature = "integration"))]
704mod integration_tests {
705    use super::*;
706    use crate::distributed::distributed_test_utils::create_test_drt_async;
707    use std::sync::Arc;
708    use std::time::Duration;
709
710    #[tokio::test]
711    async fn test_initialization() {
712        let drt = create_test_drt_async().await;
713
714        let canary_wait_time = Duration::from_secs(5);
715        let request_timeout = Duration::from_secs(3);
716
717        let config = HealthCheckConfig {
718            canary_wait_time,
719            request_timeout,
720        };
721
722        let manager = HealthCheckManager::new(drt.clone(), config);
723
724        assert_eq!(manager.config.canary_wait_time, canary_wait_time);
725        assert_eq!(manager.config.request_timeout, request_timeout);
726    }
727
728    #[tokio::test]
729    async fn test_payload_registration() {
730        let drt = create_test_drt_async().await;
731
732        let endpoint = "test.endpoint";
733        let payload = serde_json::json!({
734            "prompt": "test",
735            "_health_check": true
736        });
737
738        drt.system_health().lock().register_health_check_target(
739            endpoint,
740            crate::component::Instance {
741                component: "test_component".to_string(),
742                endpoint: "test_endpoint".to_string(),
743                namespace: "test_namespace".to_string(),
744                instance_id: 12345,
745                transport: crate::component::TransportType::Nats(endpoint.to_string()),
746                device_type: None,
747            },
748            payload.clone(),
749        );
750
751        let retrieved = drt
752            .system_health()
753            .lock()
754            .get_health_check_target(endpoint)
755            .map(|t| t.payload);
756        assert!(retrieved.is_some());
757        assert_eq!(retrieved.unwrap(), payload);
758
759        // Verify endpoint appears in the list
760        let endpoints = drt.system_health().lock().get_health_check_endpoints();
761        assert!(endpoints.contains(&endpoint.to_string()));
762    }
763
764    #[tokio::test]
765    async fn test_spawn_per_endpoint_tasks() {
766        let drt = create_test_drt_async().await;
767
768        for i in 0..3 {
769            let endpoint = format!("test.endpoint.{}", i);
770            let payload = serde_json::json!({
771                "prompt": format!("test{}", i),
772                "_health_check": true
773            });
774            drt.system_health().lock().register_health_check_target(
775                &endpoint,
776                crate::component::Instance {
777                    component: "test_component".to_string(),
778                    endpoint: format!("test_endpoint_{}", i),
779                    namespace: "test_namespace".to_string(),
780                    instance_id: i,
781                    transport: crate::component::TransportType::Nats(endpoint.clone()),
782                    device_type: None,
783                },
784                payload,
785            );
786        }
787
788        let config = HealthCheckConfig {
789            canary_wait_time: Duration::from_secs(5),
790            request_timeout: Duration::from_secs(1),
791        };
792
793        let manager = Arc::new(HealthCheckManager::new(drt.clone(), config));
794        manager.clone().start().await.unwrap();
795
796        // Verify all endpoints have their own health check tasks
797        let tasks = manager.endpoint_tasks.lock();
798        // Should have 3 tasks (one for each endpoint)
799        assert_eq!(tasks.len(), 3);
800        // Check that all endpoints are represented in tasks
801        let endpoints: Vec<String> = tasks.keys().cloned().collect();
802        assert!(endpoints.contains(&"test.endpoint.0".to_string()));
803        assert!(endpoints.contains(&"test.endpoint.1".to_string()));
804        assert!(endpoints.contains(&"test.endpoint.2".to_string()));
805    }
806
807    #[tokio::test]
808    async fn test_endpoint_health_check_notifier_created() {
809        let drt = create_test_drt_async().await;
810
811        let endpoint = "test.endpoint.notifier";
812        let payload = serde_json::json!({
813            "prompt": "test",
814            "_health_check": true
815        });
816
817        // Register the endpoint
818        drt.system_health().lock().register_health_check_target(
819            endpoint,
820            crate::component::Instance {
821                component: "test_component".to_string(),
822                endpoint: "test_endpoint_notifier".to_string(),
823                namespace: "test_namespace".to_string(),
824                instance_id: 999,
825                transport: crate::component::TransportType::Nats(endpoint.to_string()),
826                device_type: None,
827            },
828            payload.clone(),
829        );
830
831        // Verify that a notifier was created for this endpoint
832        let notifier = drt
833            .system_health()
834            .lock()
835            .get_endpoint_health_check_notifier(endpoint);
836
837        assert!(
838            notifier.is_some(),
839            "Endpoint should have a notifier created"
840        );
841
842        // Verify we can notify it without panicking
843        if let Some(notifier) = notifier {
844            notifier.notify_one();
845        }
846
847        // Initially, the endpoint should be Ready (default after registration)
848        let status = drt
849            .system_health()
850            .lock()
851            .get_endpoint_health_status(endpoint);
852        assert_eq!(status, Some(HealthStatus::NotReady));
853    }
854}