1use crate::DistributedRuntime;
5use crate::config::HealthStatus;
6use crate::engine::AsyncEngine;
7use crate::pipeline::SingleIn;
8use crate::protocols::maybe_error::MaybeError;
9use futures::StreamExt;
10use parking_lot::Mutex;
11use std::collections::HashMap;
12use std::sync::Arc;
13use std::time::Duration;
14use tokio::task::JoinHandle;
15use tracing::{debug, error, info, warn};
16
/// Timing configuration for [`HealthCheckManager`].
#[derive(Debug, Clone)]
pub struct HealthCheckConfig {
    /// How long an endpoint may go without observed activity before a canary health check
    /// request is sent to it.
    pub canary_wait_time: Duration,
    /// Maximum time a canary request may take to start and yield its first response before
    /// the endpoint is marked `NotReady`.
    pub request_timeout: Duration,
}
24
25impl Default for HealthCheckConfig {
26 fn default() -> Self {
27 Self {
28 canary_wait_time: Duration::from_secs(crate::config::DEFAULT_CANARY_WAIT_TIME_SECS),
29 request_timeout: Duration::from_secs(
30 crate::config::DEFAULT_HEALTH_CHECK_REQUEST_TIMEOUT_SECS,
31 ),
32 }
33 }
34}
35
/// Runs one long-lived health check task per registered endpoint.
///
/// Each task waits for either activity on the endpoint's notifier (which marks the endpoint
/// `Ready`) or `canary_wait_time` without activity (which sends the endpoint's registered
/// health check payload through the local endpoint registry).
pub struct HealthCheckManager {
    /// Runtime providing the system-health state and the local endpoint registry.
    drt: DistributedRuntime,
    /// Canary interval and request timeout.
    config: HealthCheckConfig,
    /// Handles of the per-endpoint health check tasks, keyed by endpoint subject.
    endpoint_tasks: Arc<Mutex<HashMap<String, JoinHandle<()>>>>,
}
44
45impl HealthCheckManager {
46 pub fn new(drt: DistributedRuntime, config: HealthCheckConfig) -> Self {
47 Self {
48 drt,
49 config,
50 endpoint_tasks: Arc::new(Mutex::new(HashMap::new())),
51 }
52 }
53
54 pub async fn start(self: Arc<Self>) -> anyhow::Result<()> {
56 let targets = self.drt.system_health().lock().get_health_check_targets();
58
59 info!(
60 "Starting health check tasks for {} endpoints with canary_wait_time: {:?}",
61 targets.len(),
62 self.config.canary_wait_time
63 );
64
65 for (endpoint_subject, _target) in targets {
67 self.spawn_endpoint_health_check_task(endpoint_subject);
68 }
69
70 self.spawn_new_endpoint_monitor().await?;
74
75 info!("HealthCheckManager started successfully with channel-based endpoint discovery");
76 Ok(())
77 }
78
79 fn spawn_endpoint_health_check_task(self: &Arc<Self>, endpoint_subject: String) {
81 let manager = self.clone();
82 let canary_wait = self.config.canary_wait_time;
83 let endpoint_subject_clone = endpoint_subject.clone();
84
85 let notifier = self
87 .drt
88 .system_health()
89 .lock()
90 .get_endpoint_health_check_notifier(&endpoint_subject)
91 .expect("Notifier should exist for registered endpoint");
92
93 let task = tokio::spawn(async move {
94 let endpoint_subject = endpoint_subject_clone;
95 info!("Health check task started for: {}", endpoint_subject);
96
97 loop {
98 tokio::select! {
100 _ = tokio::time::sleep(canary_wait) => {
101 debug!("Canary timer expired for {}, sending health check", endpoint_subject);
103
104 let target = manager.drt.system_health().lock().get_health_check_target(&endpoint_subject);
106
107 if let Some(target) = target {
108 if let Err(e) = manager.send_health_check_request(&endpoint_subject, &target.payload).await {
109 error!("Failed to send health check for {}: {}", endpoint_subject, e);
110 }
111 } else {
112 error!(
114 "CRITICAL: Health check target for {} disappeared unexpectedly! This indicates a bug. Stopping health check task.",
115 endpoint_subject
116 );
117 break;
118 }
119 }
120
121 _ = notifier.notified() => {
122 debug!("Activity detected for {}, resetting health check timer", endpoint_subject);
126 manager.drt.system_health().lock().set_endpoint_health_status(
127 &endpoint_subject,
128 crate::config::HealthStatus::Ready,
129 );
130 }
131 }
132 }
133
134 info!("Health check task for {} exiting", endpoint_subject);
135 });
136
137 self.endpoint_tasks
139 .lock()
140 .insert(endpoint_subject.clone(), task);
141
142 info!(
143 "Spawned health check task for endpoint: {}",
144 endpoint_subject
145 );
146 }
147
148 async fn spawn_new_endpoint_monitor(self: &Arc<Self>) -> anyhow::Result<()> {
151 let manager = self.clone();
152
153 let mut rx = manager
155 .drt
156 .system_health()
157 .lock()
158 .take_new_endpoint_receiver()
159 .ok_or_else(|| {
160 anyhow::anyhow!("Endpoint receiver already taken - this should only be called once")
161 })?;
162
163 tokio::spawn(async move {
164 info!("Starting dynamic endpoint discovery monitor with channel-based notifications");
165
166 while let Some(endpoint_subject) = rx.recv().await {
167 debug!(
168 "Received endpoint registration via channel: {}",
169 endpoint_subject
170 );
171
172 let already_exists = {
173 let tasks = manager.endpoint_tasks.lock();
174 tasks.contains_key(&endpoint_subject)
175 };
176
177 if already_exists {
178 error!(
179 "CRITICAL: Received registration for endpoint '{}' that already has a health check task!",
180 endpoint_subject
181 );
182 break;
183 }
184
185 info!(
186 "Spawning health check task for new endpoint: {}",
187 endpoint_subject
188 );
189 manager.spawn_endpoint_health_check_task(endpoint_subject);
190 }
191
192 info!("Endpoint discovery monitor exiting - no new endpoints will be monitored!");
193 });
194
195 info!("Dynamic endpoint discovery monitor started");
196 Ok(())
197 }
198
199 async fn send_health_check_request(
201 &self,
202 endpoint_subject: &str,
203 payload: &serde_json::Value,
204 ) -> anyhow::Result<()> {
205 debug!(
206 "Sending health check to {} via local registry",
207 endpoint_subject
208 );
209
210 let engine = self
211 .drt
212 .local_endpoint_registry()
213 .get(endpoint_subject)
214 .ok_or_else(|| {
215 anyhow::anyhow!(
216 "Endpoint '{}' not found in local registry, engine may still be initializing",
217 endpoint_subject
218 )
219 })?;
220
221 let system_health = self.drt.system_health().clone();
223 let endpoint_subject_owned = endpoint_subject.to_string();
224 let payload = payload.clone();
225 let timeout = self.config.request_timeout;
226
227 tokio::spawn(async move {
229 let result = tokio::time::timeout(timeout, async {
230 let request = SingleIn::new(payload);
231 match engine.generate(request).await {
232 Ok(mut response_stream) => {
233 let is_healthy = if let Some(response) = response_stream.next().await {
236 if let Some(error) = response.err() {
237 warn!(
238 "Health check error response from {}: {:?}",
239 endpoint_subject_owned, error
240 );
241 false
242 } else {
243 debug!("Health check successful for {}", endpoint_subject_owned);
244 true
245 }
246 } else {
247 warn!(
248 "Health check got no response from {}",
249 endpoint_subject_owned
250 );
251 false
252 };
253
254 tokio::spawn(async move {
255 response_stream.for_each(|_| async {}).await;
257 });
258
259 system_health.lock().set_endpoint_health_status(
261 &endpoint_subject_owned,
262 if is_healthy {
263 HealthStatus::Ready
264 } else {
265 HealthStatus::NotReady
266 },
267 );
268 }
269 Err(e) => {
270 error!(
271 "Health check request failed for {}: {}",
272 endpoint_subject_owned, e
273 );
274 system_health.lock().set_endpoint_health_status(
275 &endpoint_subject_owned,
276 HealthStatus::NotReady,
277 );
278 }
279 }
280 })
281 .await;
282
283 if result.is_err() {
285 warn!("Health check timeout for {}", endpoint_subject_owned);
286 system_health
287 .lock()
288 .set_endpoint_health_status(&endpoint_subject_owned, HealthStatus::NotReady);
289 }
290
291 debug!("Health check completed for {}", endpoint_subject_owned);
292 });
293
294 Ok(())
295 }
296}
297
298pub async fn start_health_check_manager(
300 drt: DistributedRuntime,
301 config: Option<HealthCheckConfig>,
302) -> anyhow::Result<()> {
303 let config = config.unwrap_or_default();
304 let manager = Arc::new(HealthCheckManager::new(drt, config));
305
306 manager.start().await?;
308
309 Ok(())
310}
311
312pub async fn get_health_check_status(
314 drt: &DistributedRuntime,
315) -> anyhow::Result<serde_json::Value> {
316 let endpoint_subjects: Vec<String> = drt.system_health().lock().get_health_check_endpoints();
318
319 let mut endpoint_statuses = HashMap::new();
320
321 {
323 let system_health = drt.system_health();
324 let system_health_lock = system_health.lock();
325 for endpoint_subject in &endpoint_subjects {
326 let health_status = system_health_lock
327 .get_endpoint_health_status(endpoint_subject)
328 .unwrap_or(HealthStatus::NotReady);
329
330 let is_healthy = matches!(health_status, HealthStatus::Ready);
331
332 endpoint_statuses.insert(
333 endpoint_subject.clone(),
334 serde_json::json!({
335 "healthy": is_healthy,
336 "status": format!("{:?}", health_status),
337 }),
338 );
339 }
340 }
341
342 let overall_healthy = endpoint_statuses
343 .values()
344 .all(|v| v["healthy"].as_bool().unwrap_or(false));
345
346 Ok(serde_json::json!({
347 "status": if overall_healthy { "ready" } else { "notready" },
348 "endpoints_checked": endpoint_subjects.len(),
349 "endpoint_statuses": endpoint_statuses,
350 }))
351}
352
/// Integration tests for how the health status is driven by two paths: notifications from
/// an `Ingress` handling real requests, and canary requests sent by `HealthCheckManager`.
#[cfg(all(test, feature = "integration"))]
mod push_handler_notify_tests {
    use super::*;
    use crate::component::{Instance, TransportType};
    use crate::config::HealthStatus;
    use crate::distributed::distributed_test_utils::create_test_drt_async;
    use crate::engine::{AsyncEngine, AsyncEngineContextProvider};
    use crate::local_endpoint_registry::LocalAsyncEngine;
    use crate::pipeline::network::codec::{TwoPartCodec, TwoPartMessage};
    use crate::pipeline::network::tcp::server::{ServerOptions, TcpStreamServer};
    use crate::pipeline::network::{
        ConnectionInfo, Ingress, PushWorkHandler, ResponseService, StreamOptions,
    };
    use crate::pipeline::{ManyOut, ResponseStream, SingleIn};
    use crate::protocols::annotated::Annotated;
    use async_trait::async_trait;
    use bytes::Bytes;
    use futures::stream;
    use std::sync::Arc;
    use std::time::Duration;

    type TestRequest = serde_json::Value;
    type TestResponse = Annotated<serde_json::Value>;

    /// Engine that streams `num_chunks` responses, ignoring the request body.
    struct MockStreamingEngine {
        /// Number of responses yielded per request.
        num_chunks: usize,
        /// Indices of responses that are errors; all other responses carry `{"token": i}`.
        error_indices: Vec<usize>,
    }

    impl MockStreamingEngine {
        /// Every chunk succeeds.
        fn success(num_chunks: usize) -> Arc<Self> {
            Arc::new(Self {
                num_chunks,
                error_indices: vec![],
            })
        }

        /// Every chunk is an error.
        fn all_errors(num_chunks: usize) -> Arc<Self> {
            Arc::new(Self {
                num_chunks,
                error_indices: (0..num_chunks).collect(),
            })
        }

        /// Only the chunks at `error_indices` are errors.
        fn with_error_at(num_chunks: usize, error_indices: Vec<usize>) -> Arc<Self> {
            Arc::new(Self {
                num_chunks,
                error_indices,
            })
        }
    }

    #[async_trait]
    impl AsyncEngine<SingleIn<TestRequest>, ManyOut<TestResponse>, anyhow::Error>
        for MockStreamingEngine
    {
        async fn generate(
            &self,
            input: SingleIn<TestRequest>,
        ) -> anyhow::Result<ManyOut<TestResponse>> {
            let (_data, ctx) = input.into_parts();
            let chunks: Vec<TestResponse> = (0..self.num_chunks)
                .map(|i| {
                    if self.error_indices.contains(&i) {
                        Annotated::from_error(format!("mock error at chunk {i}"))
                    } else {
                        Annotated::from_data(serde_json::json!({"token": i}))
                    }
                })
                .collect();
            // Reuse the request's context for the response stream.
            Ok(ResponseStream::new(
                Box::pin(stream::iter(chunks)),
                ctx.context(),
            ))
        }
    }

    /// Encodes a `single_in` / `many_out` request as a two-part message: a JSON control
    /// header (id, request/response types, connection info for responses) plus the JSON
    /// request body.
    fn encode_request(
        request_id: &str,
        connection_info: ConnectionInfo,
        request_body: &serde_json::Value,
    ) -> Bytes {
        let control = serde_json::json!({
            "id": request_id,
            "request_type": "single_in",
            "response_type": "many_out",
            "connection_info": connection_info,
        });
        let header = serde_json::to_vec(&control).unwrap();
        let data = serde_json::to_vec(request_body).unwrap();
        let msg = TwoPartMessage::new(Bytes::from(header), Bytes::from(data));
        TwoPartCodec::default().encode_message(msg).unwrap()
    }

    /// Starts a TCP stream server and registers a response stream for `request_id`.
    ///
    /// Returns the server and the connection info that the ingress uses to send responses
    /// back.
    async fn setup_tcp_receiver(request_id: &str) -> (Arc<TcpStreamServer>, ConnectionInfo) {
        // Port 0: let the OS pick a free port.
        let options = ServerOptions::builder().port(0).build().unwrap();
        let server = TcpStreamServer::new(options).await.unwrap();

        let context = crate::pipeline::Context::with_id_and_metadata(
            (),
            request_id.to_string(),
            Default::default(),
        );
        // Only a response stream is needed: the request travels inside the encoded payload.
        let stream_options = StreamOptions::builder()
            .context(context.context())
            .enable_request_stream(false)
            .enable_response_stream(true)
            .build()
            .unwrap();

        let pending = server.register(stream_options).await;
        let connection_info = pending
            .recv_stream
            .as_ref()
            .unwrap()
            .connection_info
            .clone();

        (server, connection_info)
    }

    /// Registers `endpoint_name` as a health check target with a fixed canary payload.
    /// Registers `local_engine` in the local endpoint registry, which is the engine that
    /// canary requests are sent to. Returns the endpoint's notifier.
    fn register_endpoint(
        drt: &crate::DistributedRuntime,
        endpoint_name: &str,
        local_engine: LocalAsyncEngine,
    ) -> Arc<tokio::sync::Notify> {
        let payload = serde_json::json!({
            "prompt": "health",
            "_health_check": true
        });
        drt.system_health().lock().register_health_check_target(
            endpoint_name,
            Instance {
                component: "test_component".to_string(),
                endpoint: endpoint_name.to_string(),
                namespace: "test_namespace".to_string(),
                instance_id: 0,
                transport: TransportType::Nats(endpoint_name.to_string()),
                device_type: None,
            },
            payload,
        );

        drt.local_endpoint_registry()
            .register(endpoint_name.to_string(), local_engine);

        drt.system_health()
            .lock()
            .get_endpoint_health_check_notifier(endpoint_name)
            .expect("Notifier should exist for registered endpoint")
    }

    /// Pushes one encoded request through `ingress` and asserts it was accepted.
    async fn send_request(ingress: &Ingress<SingleIn<TestRequest>, ManyOut<TestResponse>>) {
        let request_id = uuid::Uuid::new_v4().to_string();
        // `_server` (unlike `_`) keeps the receiving server alive until this function returns.
        let (_server, connection_info) = setup_tcp_receiver(&request_id).await;
        let payload = encode_request(
            &request_id,
            connection_info,
            &serde_json::json!({"prompt": "test"}),
        );
        let result = ingress.handle_payload(payload, Some(request_id)).await;
        assert!(result.is_ok(), "handle_payload should succeed");
    }

    /// Asserts the recorded health status of `endpoint_name` equals `expected`.
    fn assert_status(
        drt: &crate::DistributedRuntime,
        endpoint_name: &str,
        expected: HealthStatus,
        msg: &str,
    ) {
        let status = drt
            .system_health()
            .lock()
            .get_endpoint_health_status(endpoint_name);
        assert_eq!(status, Some(expected), "{msg}");
    }

    /// Builds an ingress around `engine` and attaches the endpoint's health check notifier.
    /// NOTE(review): the tests below assume the ingress notifies only on successful
    /// streaming; that logic lives in `Ingress`, not in this file.
    fn create_ingress(
        engine: Arc<MockStreamingEngine>,
        notifier: Arc<tokio::sync::Notify>,
    ) -> Arc<Ingress<SingleIn<TestRequest>, ManyOut<TestResponse>>> {
        let ingress =
            Ingress::<SingleIn<TestRequest>, ManyOut<TestResponse>>::for_engine(engine).unwrap();
        ingress
            .set_endpoint_health_check_notifier(notifier)
            .unwrap();
        ingress
    }

    /// Starts a `HealthCheckManager` with the given canary interval and a 1s request timeout.
    async fn start_manager(drt: &crate::DistributedRuntime, canary_wait_ms: u64) {
        let config = HealthCheckConfig {
            canary_wait_time: Duration::from_millis(canary_wait_ms),
            request_timeout: Duration::from_secs(1),
        };
        let manager = Arc::new(HealthCheckManager::new(drt.clone(), config));
        manager.start().await.unwrap();
    }

    /// The canary engine always errors, and the canary interval (500ms) is longer than the
    /// wait (200ms). So `Ready` can only come from the notification path of a successful
    /// request.
    #[tokio::test]
    async fn test_successful_streaming_sets_ready() {
        let drt = create_test_drt_async().await;
        let endpoint = "test.successful_streaming";

        let notifier = register_endpoint(&drt, endpoint, MockStreamingEngine::all_errors(1));
        assert_status(&drt, endpoint, HealthStatus::NotReady, "initial");

        let ingress = create_ingress(MockStreamingEngine::success(5), notifier);
        start_manager(&drt, 500).await;

        send_request(&ingress).await;
        tokio::time::sleep(Duration::from_millis(200)).await;

        assert_status(
            &drt,
            endpoint,
            HealthStatus::Ready,
            "successful streaming should set Ready via notification path",
        );
    }

    /// With no traffic, the 50ms canary fires against a succeeding engine within the 300ms
    /// wait.
    #[tokio::test]
    async fn test_canary_fires_on_idle_engine() {
        let drt = create_test_drt_async().await;
        let endpoint = "test.canary_idle";

        let _notifier = register_endpoint(&drt, endpoint, MockStreamingEngine::success(1));
        assert_status(&drt, endpoint, HealthStatus::NotReady, "initial");

        start_manager(&drt, 50).await;
        tokio::time::sleep(Duration::from_millis(300)).await;

        assert_status(
            &drt,
            endpoint,
            HealthStatus::Ready,
            "canary should fire and set Ready on idle engine",
        );
    }

    /// Both the served traffic and the canary engine produce only errors.
    #[tokio::test]
    async fn test_error_streaming_stays_not_ready() {
        let drt = create_test_drt_async().await;
        let endpoint = "test.error_streaming";

        let notifier = register_endpoint(&drt, endpoint, MockStreamingEngine::all_errors(1));
        assert_status(&drt, endpoint, HealthStatus::NotReady, "initial");

        let ingress = create_ingress(MockStreamingEngine::all_errors(3), notifier);
        start_manager(&drt, 50).await;

        send_request(&ingress).await;
        tokio::time::sleep(Duration::from_millis(300)).await;

        assert_status(
            &drt,
            endpoint,
            HealthStatus::NotReady,
            "error streaming should not notify, canary also errors — stays NotReady",
        );
    }

    /// With no traffic, the canary fires but its engine errors.
    #[tokio::test]
    async fn test_idle_engine_with_failing_canary() {
        let drt = create_test_drt_async().await;
        let endpoint = "test.canary_fails";

        let _notifier = register_endpoint(&drt, endpoint, MockStreamingEngine::all_errors(1));
        assert_status(&drt, endpoint, HealthStatus::NotReady, "initial");

        start_manager(&drt, 50).await;
        tokio::time::sleep(Duration::from_millis(300)).await;

        assert_status(
            &drt,
            endpoint,
            HealthStatus::NotReady,
            "canary fired but engine errored, status stays NotReady",
        );
    }

    /// Four successful chunks followed by a trailing error (index 4). The canary engine
    /// always errors and cannot fire within the wait, so `Ready` must come from the
    /// notification path.
    #[tokio::test]
    async fn test_mixed_streaming_sets_ready() {
        let drt = create_test_drt_async().await;
        let endpoint = "test.mixed_streaming";

        let notifier = register_endpoint(&drt, endpoint, MockStreamingEngine::all_errors(1));
        assert_status(&drt, endpoint, HealthStatus::NotReady, "initial");

        let ingress = create_ingress(MockStreamingEngine::with_error_at(5, vec![4]), notifier);
        start_manager(&drt, 500).await;

        send_request(&ingress).await;
        tokio::time::sleep(Duration::from_millis(200)).await;

        assert_status(
            &drt,
            endpoint,
            HealthStatus::Ready,
            "successful chunks should set Ready despite trailing error",
        );
    }
}
699
/// Integration tests for `HealthCheckManager` construction, target registration, and
/// per-endpoint task spawning.
#[cfg(all(test, feature = "integration"))]
mod integration_tests {
    use super::*;
    use crate::distributed::distributed_test_utils::create_test_drt_async;
    use std::sync::Arc;
    use std::time::Duration;

    /// The manager stores the config it was constructed with.
    #[tokio::test]
    async fn test_initialization() {
        let drt = create_test_drt_async().await;

        let canary_wait_time = Duration::from_secs(5);
        let request_timeout = Duration::from_secs(3);

        let config = HealthCheckConfig {
            canary_wait_time,
            request_timeout,
        };

        let manager = HealthCheckManager::new(drt.clone(), config);

        assert_eq!(manager.config.canary_wait_time, canary_wait_time);
        assert_eq!(manager.config.request_timeout, request_timeout);
    }

    /// A registered target's payload can be read back, and the endpoint is listed.
    #[tokio::test]
    async fn test_payload_registration() {
        let drt = create_test_drt_async().await;

        let endpoint = "test.endpoint";
        let payload = serde_json::json!({
            "prompt": "test",
            "_health_check": true
        });

        drt.system_health().lock().register_health_check_target(
            endpoint,
            crate::component::Instance {
                component: "test_component".to_string(),
                endpoint: "test_endpoint".to_string(),
                namespace: "test_namespace".to_string(),
                instance_id: 12345,
                transport: crate::component::TransportType::Nats(endpoint.to_string()),
                device_type: None,
            },
            payload.clone(),
        );

        let retrieved = drt
            .system_health()
            .lock()
            .get_health_check_target(endpoint)
            .map(|t| t.payload);
        assert!(retrieved.is_some());
        assert_eq!(retrieved.unwrap(), payload);

        let endpoints = drt.system_health().lock().get_health_check_endpoints();
        assert!(endpoints.contains(&endpoint.to_string()));
    }

    /// `start` spawns exactly one task per target registered beforehand.
    #[tokio::test]
    async fn test_spawn_per_endpoint_tasks() {
        let drt = create_test_drt_async().await;

        for i in 0..3 {
            let endpoint = format!("test.endpoint.{}", i);
            let payload = serde_json::json!({
                "prompt": format!("test{}", i),
                "_health_check": true
            });
            drt.system_health().lock().register_health_check_target(
                &endpoint,
                crate::component::Instance {
                    component: "test_component".to_string(),
                    endpoint: format!("test_endpoint_{}", i),
                    namespace: "test_namespace".to_string(),
                    instance_id: i,
                    transport: crate::component::TransportType::Nats(endpoint.clone()),
                    device_type: None,
                },
                payload,
            );
        }

        // Long canary interval so no canary fires during the test.
        let config = HealthCheckConfig {
            canary_wait_time: Duration::from_secs(5),
            request_timeout: Duration::from_secs(1),
        };

        let manager = Arc::new(HealthCheckManager::new(drt.clone(), config));
        manager.clone().start().await.unwrap();

        // Tasks for pre-registered targets are inserted synchronously inside `start`.
        let tasks = manager.endpoint_tasks.lock();
        assert_eq!(tasks.len(), 3);
        let endpoints: Vec<String> = tasks.keys().cloned().collect();
        assert!(endpoints.contains(&"test.endpoint.0".to_string()));
        assert!(endpoints.contains(&"test.endpoint.1".to_string()));
        assert!(endpoints.contains(&"test.endpoint.2".to_string()));
    }

    /// Registering a target creates a notifier for it.
    #[tokio::test]
    async fn test_endpoint_health_check_notifier_created() {
        let drt = create_test_drt_async().await;

        let endpoint = "test.endpoint.notifier";
        let payload = serde_json::json!({
            "prompt": "test",
            "_health_check": true
        });

        drt.system_health().lock().register_health_check_target(
            endpoint,
            crate::component::Instance {
                component: "test_component".to_string(),
                endpoint: "test_endpoint_notifier".to_string(),
                namespace: "test_namespace".to_string(),
                instance_id: 999,
                transport: crate::component::TransportType::Nats(endpoint.to_string()),
                device_type: None,
            },
            payload.clone(),
        );

        let notifier = drt
            .system_health()
            .lock()
            .get_endpoint_health_check_notifier(endpoint);

        assert!(
            notifier.is_some(),
            "Endpoint should have a notifier created"
        );

        // No manager is running, so nothing consumes this notification.
        if let Some(notifier) = notifier {
            notifier.notify_one();
        }

        // Without a manager task, notifying alone does not change the status.
        let status = drt
            .system_health()
            .lock()
            .get_endpoint_health_status(endpoint);
        assert_eq!(status, Some(HealthStatus::NotReady));
    }
}