1use crate::DistributedRuntime;
5use crate::config::HealthStatus;
6use crate::engine::AsyncEngine;
7use crate::pipeline::SingleIn;
8use crate::protocols::maybe_error::MaybeError;
9use futures::StreamExt;
10use parking_lot::Mutex;
11use std::collections::HashMap;
12use std::sync::Arc;
13use std::time::Duration;
14use tokio::task::JoinHandle;
15use tracing::{debug, error, info, warn};
16
/// Timing settings for [`HealthCheckManager`].
#[derive(Debug, Clone)]
pub struct HealthCheckConfig {
    /// How long an endpoint may go without reported activity before a canary
    /// health-check request is sent to it. Every activity notification for the
    /// endpoint restarts this timer.
    pub canary_wait_time: Duration,
    /// Upper bound on a single canary request (engine call, first response and
    /// status update). Exceeding it marks the endpoint `NotReady`.
    pub request_timeout: Duration,
}
24
25impl Default for HealthCheckConfig {
26 fn default() -> Self {
27 Self {
28 canary_wait_time: Duration::from_secs(crate::config::DEFAULT_CANARY_WAIT_TIME_SECS),
29 request_timeout: Duration::from_secs(
30 crate::config::DEFAULT_HEALTH_CHECK_REQUEST_TIMEOUT_SECS,
31 ),
32 }
33 }
34}
35
36pub struct HealthCheckManager {
38 drt: DistributedRuntime,
39 config: HealthCheckConfig,
40 endpoint_tasks: Arc<Mutex<HashMap<String, JoinHandle<()>>>>,
43}
44
45impl HealthCheckManager {
46 pub fn new(drt: DistributedRuntime, config: HealthCheckConfig) -> Self {
47 Self {
48 drt,
49 config,
50 endpoint_tasks: Arc::new(Mutex::new(HashMap::new())),
51 }
52 }
53
54 pub async fn start(self: Arc<Self>) -> anyhow::Result<()> {
56 let targets = self.drt.system_health().lock().get_health_check_targets();
58
59 info!(
60 "Starting health check tasks for {} endpoints with canary_wait_time: {:?}",
61 targets.len(),
62 self.config.canary_wait_time
63 );
64
65 for (endpoint_subject, _target) in targets {
67 self.spawn_endpoint_health_check_task(endpoint_subject);
68 }
69
70 self.spawn_new_endpoint_monitor().await?;
74
75 info!("HealthCheckManager started successfully with channel-based endpoint discovery");
76 Ok(())
77 }
78
79 fn spawn_endpoint_health_check_task(self: &Arc<Self>, endpoint_subject: String) {
81 let manager = self.clone();
82 let canary_wait = self.config.canary_wait_time;
83 let endpoint_subject_clone = endpoint_subject.clone();
84
85 let notifier = self
87 .drt
88 .system_health()
89 .lock()
90 .get_endpoint_health_check_notifier(&endpoint_subject)
91 .expect("Notifier should exist for registered endpoint");
92
93 let task = tokio::spawn(async move {
94 let endpoint_subject = endpoint_subject_clone;
95 info!("Health check task started for: {}", endpoint_subject);
96
97 loop {
98 tokio::select! {
100 _ = tokio::time::sleep(canary_wait) => {
101 debug!("Canary timer expired for {}, sending health check", endpoint_subject);
103
104 let target = manager.drt.system_health().lock().get_health_check_target(&endpoint_subject);
106
107 if let Some(target) = target {
108 if let Err(e) = manager.send_health_check_request(&endpoint_subject, &target.payload).await {
109 error!("Failed to send health check for {}: {}", endpoint_subject, e);
110 }
111 } else {
112 error!(
114 "CRITICAL: Health check target for {} disappeared unexpectedly! This indicates a bug. Stopping health check task.",
115 endpoint_subject
116 );
117 break;
118 }
119 }
120
121 _ = notifier.notified() => {
122 debug!("Activity detected for {}, resetting health check timer", endpoint_subject);
126 manager.drt.system_health().lock().set_endpoint_health_status(
127 &endpoint_subject,
128 crate::config::HealthStatus::Ready,
129 );
130 }
131 }
132 }
133
134 info!("Health check task for {} exiting", endpoint_subject);
135 });
136
137 self.endpoint_tasks
139 .lock()
140 .insert(endpoint_subject.clone(), task);
141
142 info!(
143 "Spawned health check task for endpoint: {}",
144 endpoint_subject
145 );
146 }
147
148 async fn spawn_new_endpoint_monitor(self: &Arc<Self>) -> anyhow::Result<()> {
151 let manager = self.clone();
152
153 let mut rx = manager
155 .drt
156 .system_health()
157 .lock()
158 .take_new_endpoint_receiver()
159 .ok_or_else(|| {
160 anyhow::anyhow!("Endpoint receiver already taken - this should only be called once")
161 })?;
162
163 tokio::spawn(async move {
164 info!("Starting dynamic endpoint discovery monitor with channel-based notifications");
165
166 while let Some(endpoint_subject) = rx.recv().await {
167 debug!(
168 "Received endpoint registration via channel: {}",
169 endpoint_subject
170 );
171
172 let already_exists = {
173 let tasks = manager.endpoint_tasks.lock();
174 tasks.contains_key(&endpoint_subject)
175 };
176
177 if already_exists {
178 error!(
179 "CRITICAL: Received registration for endpoint '{}' that already has a health check task!",
180 endpoint_subject
181 );
182 break;
183 }
184
185 info!(
186 "Spawning health check task for new endpoint: {}",
187 endpoint_subject
188 );
189 manager.spawn_endpoint_health_check_task(endpoint_subject);
190 }
191
192 info!("Endpoint discovery monitor exiting - no new endpoints will be monitored!");
193 });
194
195 info!("Dynamic endpoint discovery monitor started");
196 Ok(())
197 }
198
199 async fn send_health_check_request(
201 &self,
202 endpoint_subject: &str,
203 payload: &serde_json::Value,
204 ) -> anyhow::Result<()> {
205 debug!(
206 "Sending health check to {} via local registry",
207 endpoint_subject
208 );
209
210 let engine = self
211 .drt
212 .local_endpoint_registry()
213 .get(endpoint_subject)
214 .ok_or_else(|| {
215 anyhow::anyhow!(
216 "Endpoint '{}' not found in local registry, engine may still be initializing",
217 endpoint_subject
218 )
219 })?;
220
221 let system_health = self.drt.system_health().clone();
223 let endpoint_subject_owned = endpoint_subject.to_string();
224 let payload = payload.clone();
225 let timeout = self.config.request_timeout;
226
227 tokio::spawn(async move {
229 let result = tokio::time::timeout(timeout, async {
230 let request = SingleIn::new(payload);
231 match engine.generate(request).await {
232 Ok(mut response_stream) => {
233 let is_healthy = if let Some(response) = response_stream.next().await {
236 if let Some(error) = response.err() {
237 warn!(
238 "Health check error response from {}: {:?}",
239 endpoint_subject_owned, error
240 );
241 false
242 } else {
243 debug!("Health check successful for {}", endpoint_subject_owned);
244 true
245 }
246 } else {
247 warn!(
248 "Health check got no response from {}",
249 endpoint_subject_owned
250 );
251 false
252 };
253
254 tokio::spawn(async move {
255 response_stream.for_each(|_| async {}).await;
257 });
258
259 system_health.lock().set_endpoint_health_status(
261 &endpoint_subject_owned,
262 if is_healthy {
263 HealthStatus::Ready
264 } else {
265 HealthStatus::NotReady
266 },
267 );
268 }
269 Err(e) => {
270 error!(
271 "Health check request failed for {}: {}",
272 endpoint_subject_owned, e
273 );
274 system_health.lock().set_endpoint_health_status(
275 &endpoint_subject_owned,
276 HealthStatus::NotReady,
277 );
278 }
279 }
280 })
281 .await;
282
283 if result.is_err() {
285 warn!("Health check timeout for {}", endpoint_subject_owned);
286 system_health
287 .lock()
288 .set_endpoint_health_status(&endpoint_subject_owned, HealthStatus::NotReady);
289 }
290
291 debug!("Health check completed for {}", endpoint_subject_owned);
292 });
293
294 Ok(())
295 }
296}
297
298pub async fn start_health_check_manager(
300 drt: DistributedRuntime,
301 config: Option<HealthCheckConfig>,
302) -> anyhow::Result<()> {
303 let config = config.unwrap_or_default();
304 let manager = Arc::new(HealthCheckManager::new(drt, config));
305
306 manager.start().await?;
308
309 Ok(())
310}
311
312pub async fn get_health_check_status(
314 drt: &DistributedRuntime,
315) -> anyhow::Result<serde_json::Value> {
316 let endpoint_subjects: Vec<String> = drt.system_health().lock().get_health_check_endpoints();
318
319 let mut endpoint_statuses = HashMap::new();
320
321 {
323 let system_health = drt.system_health();
324 let system_health_lock = system_health.lock();
325 for endpoint_subject in &endpoint_subjects {
326 let health_status = system_health_lock
327 .get_endpoint_health_status(endpoint_subject)
328 .unwrap_or(HealthStatus::NotReady);
329
330 let is_healthy = matches!(health_status, HealthStatus::Ready);
331
332 endpoint_statuses.insert(
333 endpoint_subject.clone(),
334 serde_json::json!({
335 "healthy": is_healthy,
336 "status": format!("{:?}", health_status),
337 }),
338 );
339 }
340 }
341
342 let overall_healthy = endpoint_statuses
343 .values()
344 .all(|v| v["healthy"].as_bool().unwrap_or(false));
345
346 Ok(serde_json::json!({
347 "status": if overall_healthy { "ready" } else { "notready" },
348 "endpoints_checked": endpoint_subjects.len(),
349 "endpoint_statuses": endpoint_statuses,
350 }))
351}
352
#[cfg(all(test, feature = "integration"))]
/// Integration tests for how an `Ingress` push handler's activity notifier and
/// the canary request together drive an endpoint's `HealthStatus`.
mod push_handler_notify_tests {
    use super::*;
    use crate::component::{Instance, TransportType};
    use crate::config::HealthStatus;
    use crate::distributed::distributed_test_utils::create_test_drt_async;
    use crate::engine::{AsyncEngine, AsyncEngineContextProvider};
    use crate::local_endpoint_registry::LocalAsyncEngine;
    use crate::pipeline::network::codec::{TwoPartCodec, TwoPartMessage};
    use crate::pipeline::network::tcp::server::{ServerOptions, TcpStreamServer};
    use crate::pipeline::network::{
        ConnectionInfo, Ingress, PushWorkHandler, ResponseService, StreamOptions,
    };
    use crate::pipeline::{ManyOut, ResponseStream, SingleIn};
    use crate::protocols::annotated::Annotated;
    use async_trait::async_trait;
    use bytes::Bytes;
    use futures::stream;
    use std::sync::Arc;
    use std::time::Duration;

    type TestRequest = serde_json::Value;
    type TestResponse = Annotated<serde_json::Value>;

    /// Engine that streams `num_chunks` responses. Chunk `i` is an error if
    /// `i` is in `error_indices`; otherwise it is the data `{"token": i}`.
    struct MockStreamingEngine {
        num_chunks: usize,
        error_indices: Vec<usize>,
    }

    impl MockStreamingEngine {
        /// Every chunk succeeds.
        fn success(num_chunks: usize) -> Arc<Self> {
            Arc::new(Self {
                num_chunks,
                error_indices: vec![],
            })
        }

        /// Every chunk is an error.
        fn all_errors(num_chunks: usize) -> Arc<Self> {
            Arc::new(Self {
                num_chunks,
                error_indices: (0..num_chunks).collect(),
            })
        }

        /// Only the chunks listed in `error_indices` are errors.
        fn with_error_at(num_chunks: usize, error_indices: Vec<usize>) -> Arc<Self> {
            Arc::new(Self {
                num_chunks,
                error_indices,
            })
        }
    }

    #[async_trait]
    impl AsyncEngine<SingleIn<TestRequest>, ManyOut<TestResponse>, anyhow::Error>
        for MockStreamingEngine
    {
        /// Ignores the request body and returns the precomputed chunks,
        /// reusing the request's context for the response stream.
        async fn generate(
            &self,
            input: SingleIn<TestRequest>,
        ) -> anyhow::Result<ManyOut<TestResponse>> {
            let (_data, ctx) = input.into_parts();
            let chunks: Vec<TestResponse> = (0..self.num_chunks)
                .map(|i| {
                    if self.error_indices.contains(&i) {
                        Annotated::from_error(format!("mock error at chunk {i}"))
                    } else {
                        Annotated::from_data(serde_json::json!({"token": i}))
                    }
                })
                .collect();
            Ok(ResponseStream::new(
                Box::pin(stream::iter(chunks)),
                ctx.context(),
            ))
        }
    }

    /// Encodes a request as a two-part message.
    ///
    /// The first part is a JSON control header carrying the id, the
    /// single_in/many_out request and response types, and the connection
    /// info. The second part is the JSON request body.
    /// NOTE(review): presumably mirrors the wire format `Ingress::handle_payload`
    /// expects — confirm if that format changes.
    fn encode_request(
        request_id: &str,
        connection_info: ConnectionInfo,
        request_body: &serde_json::Value,
    ) -> Bytes {
        let control = serde_json::json!({
            "id": request_id,
            "request_type": "single_in",
            "response_type": "many_out",
            "connection_info": connection_info,
        });
        let header = serde_json::to_vec(&control).unwrap();
        let data = serde_json::to_vec(request_body).unwrap();
        let msg = TwoPartMessage::new(Bytes::from(header), Bytes::from(data));
        TwoPartCodec::default().encode_message(msg).unwrap()
    }

    /// Starts a TCP stream server on an OS-assigned port (port 0) and registers
    /// a response-only stream for `request_id`.
    ///
    /// Returns the server and the connection info the request header should
    /// carry.
    async fn setup_tcp_receiver(request_id: &str) -> (Arc<TcpStreamServer>, ConnectionInfo) {
        let options = ServerOptions::builder().port(0).build().unwrap();
        let server = TcpStreamServer::new(options).await.unwrap();

        let context = crate::pipeline::Context::with_id((), request_id.to_string());
        let stream_options = StreamOptions::builder()
            .context(context.context())
            .enable_request_stream(false)
            .enable_response_stream(true)
            .build()
            .unwrap();

        // NOTE(review): `pending` is dropped when this function returns; only
        // its connection info is kept.
        let pending = server.register(stream_options).await;
        let connection_info = pending
            .recv_stream
            .as_ref()
            .unwrap()
            .connection_info
            .clone();

        (server, connection_info)
    }

    /// Registers `endpoint_name` as a health-check target with a fixed canary
    /// payload.
    ///
    /// Also registers `local_engine` in the local endpoint registry; that is
    /// the engine canary requests are sent to. Returns the endpoint's activity
    /// notifier.
    fn register_endpoint(
        drt: &crate::DistributedRuntime,
        endpoint_name: &str,
        local_engine: LocalAsyncEngine,
    ) -> Arc<tokio::sync::Notify> {
        let payload = serde_json::json!({
            "prompt": "health",
            "_health_check": true
        });
        drt.system_health().lock().register_health_check_target(
            endpoint_name,
            Instance {
                component: "test_component".to_string(),
                endpoint: endpoint_name.to_string(),
                namespace: "test_namespace".to_string(),
                instance_id: 0,
                transport: TransportType::Nats(endpoint_name.to_string()),
                device_type: None,
            },
            payload,
        );

        drt.local_endpoint_registry()
            .register(endpoint_name.to_string(), local_engine);

        drt.system_health()
            .lock()
            .get_endpoint_health_check_notifier(endpoint_name)
            .expect("Notifier should exist for registered endpoint")
    }

    /// Sends one `{"prompt": "test"}` request through `ingress.handle_payload`
    /// and asserts it succeeded.
    ///
    /// `_server` stays bound until this function returns.
    async fn send_request(ingress: &Ingress<SingleIn<TestRequest>, ManyOut<TestResponse>>) {
        let request_id = uuid::Uuid::new_v4().to_string();
        let (_server, connection_info) = setup_tcp_receiver(&request_id).await;
        let payload = encode_request(
            &request_id,
            connection_info,
            &serde_json::json!({"prompt": "test"}),
        );
        let result = ingress.handle_payload(payload, Some(request_id)).await;
        assert!(result.is_ok(), "handle_payload should succeed");
    }

    /// Asserts that `endpoint_name`'s recorded status is exactly `expected`.
    fn assert_status(
        drt: &crate::DistributedRuntime,
        endpoint_name: &str,
        expected: HealthStatus,
        msg: &str,
    ) {
        let status = drt
            .system_health()
            .lock()
            .get_endpoint_health_status(endpoint_name);
        assert_eq!(status, Some(expected), "{msg}");
    }

    /// Wraps `engine` in an `Ingress` and attaches `notifier` as its endpoint
    /// health-check notifier.
    fn create_ingress(
        engine: Arc<MockStreamingEngine>,
        notifier: Arc<tokio::sync::Notify>,
    ) -> Arc<Ingress<SingleIn<TestRequest>, ManyOut<TestResponse>>> {
        let ingress =
            Ingress::<SingleIn<TestRequest>, ManyOut<TestResponse>>::for_engine(engine).unwrap();
        ingress
            .set_endpoint_health_check_notifier(notifier)
            .unwrap();
        ingress
    }

    /// Starts a `HealthCheckManager` with a `canary_wait_ms` canary wait and a
    /// 1 s request timeout.
    async fn start_manager(drt: &crate::DistributedRuntime, canary_wait_ms: u64) {
        let config = HealthCheckConfig {
            canary_wait_time: Duration::from_millis(canary_wait_ms),
            request_timeout: Duration::from_secs(1),
        };
        let manager = Arc::new(HealthCheckManager::new(drt.clone(), config));
        manager.start().await.unwrap();
    }

    #[tokio::test]
    /// The canary engine always errors, and the canary wait (500 ms) exceeds
    /// the observation window (200 ms). So `Ready` can only come from the
    /// ingress's activity notification after a fully successful stream.
    async fn test_successful_streaming_sets_ready() {
        let drt = create_test_drt_async().await;
        let endpoint = "test.successful_streaming";

        let notifier = register_endpoint(&drt, endpoint, MockStreamingEngine::all_errors(1));
        assert_status(&drt, endpoint, HealthStatus::NotReady, "initial");

        let ingress = create_ingress(MockStreamingEngine::success(5), notifier);
        start_manager(&drt, 500).await;

        send_request(&ingress).await;
        tokio::time::sleep(Duration::from_millis(200)).await;

        assert_status(
            &drt,
            endpoint,
            HealthStatus::Ready,
            "successful streaming should set Ready via notification path",
        );
    }

    #[tokio::test]
    /// With no traffic, the canary fires (50 ms wait) against a succeeding
    /// local engine and sets `Ready`.
    async fn test_canary_fires_on_idle_engine() {
        let drt = create_test_drt_async().await;
        let endpoint = "test.canary_idle";

        let _notifier = register_endpoint(&drt, endpoint, MockStreamingEngine::success(1));
        assert_status(&drt, endpoint, HealthStatus::NotReady, "initial");

        start_manager(&drt, 50).await;
        tokio::time::sleep(Duration::from_millis(300)).await;

        assert_status(
            &drt,
            endpoint,
            HealthStatus::Ready,
            "canary should fire and set Ready on idle engine",
        );
    }

    #[tokio::test]
    /// Both the ingress stream and the canary engine produce only errors, so
    /// neither path may set `Ready`.
    async fn test_error_streaming_stays_not_ready() {
        let drt = create_test_drt_async().await;
        let endpoint = "test.error_streaming";

        let notifier = register_endpoint(&drt, endpoint, MockStreamingEngine::all_errors(1));
        assert_status(&drt, endpoint, HealthStatus::NotReady, "initial");

        let ingress = create_ingress(MockStreamingEngine::all_errors(3), notifier);
        start_manager(&drt, 50).await;

        send_request(&ingress).await;
        tokio::time::sleep(Duration::from_millis(300)).await;

        assert_status(
            &drt,
            endpoint,
            HealthStatus::NotReady,
            "error streaming should not notify, canary also errors — stays NotReady",
        );
    }

    #[tokio::test]
    /// With no traffic, the canary fires against an erroring local engine,
    /// and the status stays `NotReady`.
    async fn test_idle_engine_with_failing_canary() {
        let drt = create_test_drt_async().await;
        let endpoint = "test.canary_fails";

        let _notifier = register_endpoint(&drt, endpoint, MockStreamingEngine::all_errors(1));
        assert_status(&drt, endpoint, HealthStatus::NotReady, "initial");

        start_manager(&drt, 50).await;
        tokio::time::sleep(Duration::from_millis(300)).await;

        assert_status(
            &drt,
            endpoint,
            HealthStatus::NotReady,
            "canary fired but engine errored, status stays NotReady",
        );
    }

    #[tokio::test]
    /// A stream whose first four chunks succeed and whose last chunk errors
    /// still yields `Ready` via the notification path. The canary engine
    /// errors and does not fire within the window.
    async fn test_mixed_streaming_sets_ready() {
        let drt = create_test_drt_async().await;
        let endpoint = "test.mixed_streaming";

        let notifier = register_endpoint(&drt, endpoint, MockStreamingEngine::all_errors(1));
        assert_status(&drt, endpoint, HealthStatus::NotReady, "initial");

        let ingress = create_ingress(MockStreamingEngine::with_error_at(5, vec![4]), notifier);
        start_manager(&drt, 500).await;

        send_request(&ingress).await;
        tokio::time::sleep(Duration::from_millis(200)).await;

        assert_status(
            &drt,
            endpoint,
            HealthStatus::Ready,
            "successful chunks should set Ready despite trailing error",
        );
    }
}
695
#[cfg(all(test, feature = "integration"))]
/// Integration tests for `HealthCheckManager` construction, health-check target
/// registration, per-endpoint task spawning, and notifier creation.
mod integration_tests {
    use super::*;
    use crate::distributed::distributed_test_utils::create_test_drt_async;
    use std::sync::Arc;
    use std::time::Duration;

    /// The manager stores the config it was constructed with.
    #[tokio::test]
    async fn test_initialization() {
        let drt = create_test_drt_async().await;

        let canary_wait_time = Duration::from_secs(5);
        let request_timeout = Duration::from_secs(3);

        let config = HealthCheckConfig {
            canary_wait_time,
            request_timeout,
        };

        let manager = HealthCheckManager::new(drt.clone(), config);

        assert_eq!(manager.config.canary_wait_time, canary_wait_time);
        assert_eq!(manager.config.request_timeout, request_timeout);
    }

    /// A registered target's payload can be read back. The endpoint is listed
    /// among the health-check endpoints.
    #[tokio::test]
    async fn test_payload_registration() {
        let drt = create_test_drt_async().await;

        let endpoint = "test.endpoint";
        let payload = serde_json::json!({
            "prompt": "test",
            "_health_check": true
        });

        drt.system_health().lock().register_health_check_target(
            endpoint,
            crate::component::Instance {
                component: "test_component".to_string(),
                endpoint: "test_endpoint".to_string(),
                namespace: "test_namespace".to_string(),
                instance_id: 12345,
                transport: crate::component::TransportType::Nats(endpoint.to_string()),
                device_type: None,
            },
            payload.clone(),
        );

        let retrieved = drt
            .system_health()
            .lock()
            .get_health_check_target(endpoint)
            .map(|t| t.payload);
        assert!(retrieved.is_some());
        assert_eq!(retrieved.unwrap(), payload);

        let endpoints = drt.system_health().lock().get_health_check_endpoints();
        assert!(endpoints.contains(&endpoint.to_string()));
    }

    /// `start` spawns exactly one task per endpoint registered before it runs.
    /// The 5 s canary wait keeps canaries from firing during the test.
    #[tokio::test]
    async fn test_spawn_per_endpoint_tasks() {
        let drt = create_test_drt_async().await;

        for i in 0..3 {
            let endpoint = format!("test.endpoint.{}", i);
            let payload = serde_json::json!({
                "prompt": format!("test{}", i),
                "_health_check": true
            });
            drt.system_health().lock().register_health_check_target(
                &endpoint,
                crate::component::Instance {
                    component: "test_component".to_string(),
                    endpoint: format!("test_endpoint_{}", i),
                    namespace: "test_namespace".to_string(),
                    instance_id: i,
                    transport: crate::component::TransportType::Nats(endpoint.clone()),
                    device_type: None,
                },
                payload,
            );
        }

        let config = HealthCheckConfig {
            canary_wait_time: Duration::from_secs(5),
            request_timeout: Duration::from_secs(1),
        };

        let manager = Arc::new(HealthCheckManager::new(drt.clone(), config));
        manager.clone().start().await.unwrap();

        let tasks = manager.endpoint_tasks.lock();
        assert_eq!(tasks.len(), 3);
        let endpoints: Vec<String> = tasks.keys().cloned().collect();
        assert!(endpoints.contains(&"test.endpoint.0".to_string()));
        assert!(endpoints.contains(&"test.endpoint.1".to_string()));
        assert!(endpoints.contains(&"test.endpoint.2".to_string()));
    }

    /// Registering a target creates an activity notifier for it. Without a
    /// running health-check task, notifying does not change the status, which
    /// stays `NotReady`.
    #[tokio::test]
    async fn test_endpoint_health_check_notifier_created() {
        let drt = create_test_drt_async().await;

        let endpoint = "test.endpoint.notifier";
        let payload = serde_json::json!({
            "prompt": "test",
            "_health_check": true
        });

        drt.system_health().lock().register_health_check_target(
            endpoint,
            crate::component::Instance {
                component: "test_component".to_string(),
                endpoint: "test_endpoint_notifier".to_string(),
                namespace: "test_namespace".to_string(),
                instance_id: 999,
                transport: crate::component::TransportType::Nats(endpoint.to_string()),
                device_type: None,
            },
            payload.clone(),
        );

        let notifier = drt
            .system_health()
            .lock()
            .get_endpoint_health_check_notifier(endpoint);

        assert!(
            notifier.is_some(),
            "Endpoint should have a notifier created"
        );

        // No manager was started, so nothing consumes this notification.
        if let Some(notifier) = notifier {
            notifier.notify_one();
        }

        let status = drt
            .system_health()
            .lock()
            .get_endpoint_health_status(endpoint);
        assert_eq!(status, Some(HealthStatus::NotReady));
    }
}