1use crate::component::{Client, Component, Endpoint, Instance};
5use crate::config::HealthStatus;
6use crate::pipeline::PushRouter;
7use crate::pipeline::{AsyncEngine, Context, ManyOut, SingleIn};
8use crate::protocols::annotated::Annotated;
9use crate::protocols::maybe_error::MaybeError;
10use crate::{DistributedRuntime, SystemHealth};
11use futures::StreamExt;
12use parking_lot::Mutex;
13use serde::{Deserialize, Serialize};
14use std::collections::HashMap;
15use std::sync::Arc;
16use std::time::{Duration, Instant};
17use tokio::task::JoinHandle;
18use tokio::time::{MissedTickBehavior, interval};
19use tracing::{debug, error, info, warn};
20
/// Tunable timings for the health-check subsystem.
pub struct HealthCheckConfig {
    /// How long an endpoint may sit idle (no activity notifications) before a
    /// canary health-check request is sent to it.
    pub canary_wait_time: Duration,
    /// Maximum time to wait for a single health-check request to complete
    /// before marking the endpoint `NotReady`.
    pub request_timeout: Duration,
}
28
29impl Default for HealthCheckConfig {
30 fn default() -> Self {
31 Self {
32 canary_wait_time: Duration::from_secs(crate::config::DEFAULT_CANARY_WAIT_TIME_SECS),
33 request_timeout: Duration::from_secs(
34 crate::config::DEFAULT_HEALTH_CHECK_REQUEST_TIMEOUT_SECS,
35 ),
36 }
37 }
38}
39
/// Shared cache mapping an endpoint subject to its JSON-typed `PushRouter`,
/// so the router (and its underlying client) is built once per endpoint and
/// reused across health-check requests.
type RouterCache =
    Arc<Mutex<HashMap<String, Arc<PushRouter<serde_json::Value, Annotated<serde_json::Value>>>>>>;
44
/// Owns the per-endpoint health-check tasks and the endpoint-discovery
/// monitor; sends canary requests to idle endpoints and records the results
/// in `SystemHealth`.
pub struct HealthCheckManager {
    /// Runtime handle used to reach `SystemHealth` and to build endpoint clients.
    drt: DistributedRuntime,
    /// Timing configuration (canary wait, per-request timeout).
    config: HealthCheckConfig,
    /// Lazily built routers, one per endpoint subject.
    router_cache: RouterCache,
    /// Join handles of the spawned per-endpoint check loops, keyed by subject.
    /// Handles are stored but never awaited/aborted here — tasks run detached.
    endpoint_tasks: Arc<Mutex<HashMap<String, JoinHandle<()>>>>,
}
55
impl HealthCheckManager {
    /// Creates a manager with empty router and task caches; nothing runs
    /// until [`start`] is called.
    pub fn new(drt: DistributedRuntime, config: HealthCheckConfig) -> Self {
        Self {
            drt,
            config,
            router_cache: Arc::new(Mutex::new(HashMap::new())),
            endpoint_tasks: Arc::new(Mutex::new(HashMap::new())),
        }
    }

    /// Returns the cached router for `cache_key`, building and caching one on
    /// first use.
    ///
    /// The parking_lot mutex guard is confined to short non-async scopes so it
    /// is never held across an `.await`.
    async fn get_or_create_router(
        &self,
        cache_key: &str,
        endpoint: Endpoint,
    ) -> anyhow::Result<Arc<PushRouter<serde_json::Value, Annotated<serde_json::Value>>>> {
        let cache_key = cache_key.to_string();

        // Fast path: already built. Scoped block releases the lock before the
        // construction awaits below.
        {
            let cache = self.router_cache.lock();
            if let Some(router) = cache.get(&cache_key) {
                return Ok(router.clone());
            }
        }

        let client = Client::new(endpoint).await?;

        let router: Arc<PushRouter<serde_json::Value, Annotated<serde_json::Value>>> = Arc::new(
            PushRouter::from_client(
                client,
                crate::pipeline::RouterMode::RoundRobin,
            )
            .await?,
        );

        // NOTE(review): two concurrent callers that miss the fast path will
        // both build a router and the second insert wins; harmless beyond the
        // duplicated construction cost.
        self.router_cache.lock().insert(cache_key, router.clone());

        Ok(router)
    }

    /// Spawns one health-check task per already-registered endpoint, then
    /// starts the monitor that picks up endpoints registered later.
    pub async fn start(self: Arc<Self>) -> anyhow::Result<()> {
        let targets = self.drt.system_health().lock().get_health_check_targets();

        info!(
            "Starting health check tasks for {} endpoints with canary_wait_time: {:?}",
            targets.len(),
            self.config.canary_wait_time
        );

        for (endpoint_subject, _target) in targets {
            self.spawn_endpoint_health_check_task(endpoint_subject);
        }

        // Takes the single-use new-endpoint receiver; calling start() twice
        // therefore fails here.
        self.spawn_new_endpoint_monitor().await?;

        info!("HealthCheckManager started successfully with channel-based endpoint discovery");
        Ok(())
    }

    /// Spawns the detached per-endpoint loop: when `canary_wait_time` elapses
    /// with no notification, a health-check request is sent; any activity
    /// notification restarts the window (the `sleep` future is re-created on
    /// each `select!` iteration).
    fn spawn_endpoint_health_check_task(self: &Arc<Self>, endpoint_subject: String) {
        let manager = self.clone();
        let canary_wait = self.config.canary_wait_time;
        let endpoint_subject_clone = endpoint_subject.clone();

        // Registration is expected to have created the notifier; a missing one
        // is a programming error, hence expect().
        let notifier = self
            .drt
            .system_health()
            .lock()
            .get_endpoint_health_check_notifier(&endpoint_subject)
            .expect("Notifier should exist for registered endpoint");

        let task = tokio::spawn(async move {
            let endpoint_subject = endpoint_subject_clone;
            info!("Health check task started for: {}", endpoint_subject);

            loop {
                tokio::select! {
                    // Endpoint idle for a full canary window: probe it.
                    _ = tokio::time::sleep(canary_wait) => {
                        debug!("Canary timer expired for {}, sending health check", endpoint_subject);

                        // Re-fetch the target each round; payload may have been updated.
                        let target = manager.drt.system_health().lock().get_health_check_target(&endpoint_subject);

                        if let Some(target) = target {
                            if let Err(e) = manager.send_health_check_request(&endpoint_subject, &target.payload).await {
                                error!("Failed to send health check for {}: {}", endpoint_subject, e);
                            }
                        } else {
                            // Targets are not expected to be unregistered while this
                            // task lives; treat disappearance as a bug and stop.
                            error!(
                                "CRITICAL: Health check target for {} disappeared unexpectedly! This indicates a bug. Stopping health check task.",
                                endpoint_subject
                            );
                            break;
                        }
                    }

                    // Real traffic observed: restart the canary window.
                    _ = notifier.notified() => {
                        debug!("Activity detected for {}, resetting health check timer", endpoint_subject);
                    }
                }
            }

            info!("Health check task for {} exiting", endpoint_subject);
        });

        // Keep the handle so duplicate registrations can be detected; the task
        // itself runs detached.
        self.endpoint_tasks
            .lock()
            .insert(endpoint_subject.clone(), task);

        info!(
            "Spawned health check task for endpoint: {}",
            endpoint_subject
        );
    }

    /// Starts the detached monitor that spawns a health-check task for every
    /// endpoint subject arriving on the new-endpoint channel.
    ///
    /// Errors if the receiver was already taken (i.e. called more than once).
    async fn spawn_new_endpoint_monitor(self: &Arc<Self>) -> anyhow::Result<()> {
        let manager = self.clone();

        let mut rx = manager
            .drt
            .system_health()
            .lock()
            .take_new_endpoint_receiver()
            .ok_or_else(|| {
                anyhow::anyhow!("Endpoint receiver already taken - this should only be called once")
            })?;

        tokio::spawn(async move {
            info!("Starting dynamic endpoint discovery monitor with channel-based notifications");

            while let Some(endpoint_subject) = rx.recv().await {
                debug!(
                    "Received endpoint registration via channel: {}",
                    endpoint_subject
                );

                // Scoped lock: just a membership check.
                let already_exists = {
                    let tasks = manager.endpoint_tasks.lock();
                    tasks.contains_key(&endpoint_subject)
                };

                if already_exists {
                    // NOTE(review): a duplicate registration stops the whole
                    // monitor (break), not just this endpoint — after this, no
                    // new endpoints are picked up. Confirm this is intended.
                    error!(
                        "CRITICAL: Received registration for endpoint '{}' that already has a health check task!",
                        endpoint_subject
                    );
                    break;
                }

                info!(
                    "Spawning health check task for new endpoint: {}",
                    endpoint_subject
                );
                manager.spawn_endpoint_health_check_task(endpoint_subject);
            }

            info!("Endpoint discovery monitor exiting - no new endpoints will be monitored!");
        });

        info!("Dynamic endpoint discovery monitor started");
        Ok(())
    }

    /// Sends one health-check request to `endpoint_subject`'s registered
    /// instance and records the outcome in `SystemHealth`.
    ///
    /// Waits (bounded) for instance discovery, then fires the request on a
    /// detached task so the caller's canary loop is not blocked by
    /// `request_timeout`; this function's `Ok(())` therefore only means the
    /// request was dispatched, not that the endpoint is healthy.
    async fn send_health_check_request(
        &self,
        endpoint_subject: &str,
        payload: &serde_json::Value,
    ) -> anyhow::Result<()> {
        // Re-resolve the target; it can vanish between the caller's lookup and
        // here since the lock is not held across the gap.
        let target = self
            .drt
            .system_health()
            .lock()
            .get_health_check_target(endpoint_subject)
            .ok_or_else(|| {
                anyhow::anyhow!("No health check target found for {}", endpoint_subject)
            })?;

        debug!(
            "Sending health check to {} (instance_id: {})",
            endpoint_subject, target.instance.instance_id
        );

        let namespace = self.drt.namespace(&target.instance.namespace)?;
        let component = namespace.component(&target.instance.component)?;
        let endpoint = component.endpoint(&target.instance.endpoint);

        let router = self
            .get_or_create_router(endpoint_subject, endpoint)
            .await?;

        // Make sure instance discovery has produced at least one instance
        // before routing directly to target.instance.instance_id.
        // NOTE(review): this 10s discovery timeout is hard-coded and separate
        // from config.request_timeout — consider making it configurable.
        match tokio::time::timeout(
            Duration::from_secs(10),
            router.client.wait_for_instances(),
        )
        .await
        {
            Ok(Ok(instances)) => {
                debug!(
                    "Health check for {}: watch stream ready, found {} instance(s)",
                    endpoint_subject,
                    instances.len()
                );
            }
            Ok(Err(e)) => {
                return Err(anyhow::anyhow!(
                    "Failed to discover instances for {} during health check: {}",
                    endpoint_subject,
                    e
                ));
            }
            Err(_) => {
                return Err(anyhow::anyhow!(
                    "Timeout waiting for instance discovery for {} during health check",
                    endpoint_subject
                ));
            }
        }

        let request: SingleIn<serde_json::Value> = Context::new(payload.clone());

        // Clone what the detached task needs; `self` is not moved in.
        let system_health = self.drt.system_health().clone();
        let endpoint_subject_owned = endpoint_subject.to_string();
        let instance_id = target.instance.instance_id;
        let timeout = self.config.request_timeout;

        tokio::spawn(async move {
            let result = tokio::time::timeout(timeout, async {
                // Route directly to the registered instance, bypassing
                // round-robin selection.
                match router.direct(request, instance_id).await {
                    Ok(mut response_stream) => {
                        // Healthy = a first response arrives and carries no error.
                        let is_healthy = if let Some(response) = response_stream.next().await {
                            if let Some(error) = response.err() {
                                warn!(
                                    "Health check error response from {}: {:?}",
                                    endpoint_subject_owned, error
                                );
                                false
                            } else {
                                debug!("Health check successful for {}", endpoint_subject_owned);
                                true
                            }
                        } else {
                            warn!(
                                "Health check got no response from {}",
                                endpoint_subject_owned
                            );
                            false
                        };

                        // Drain the remaining responses in the background so the
                        // stream is consumed to completion without delaying the
                        // status update below.
                        tokio::spawn(async move {
                            response_stream.for_each(|_| async {}).await;
                        });

                        system_health.lock().set_endpoint_health_status(
                            &endpoint_subject_owned,
                            if is_healthy {
                                HealthStatus::Ready
                            } else {
                                HealthStatus::NotReady
                            },
                        );
                    }
                    Err(e) => {
                        error!(
                            "Health check request failed for {}: {}",
                            endpoint_subject_owned, e
                        );
                        system_health.lock().set_endpoint_health_status(
                            &endpoint_subject_owned,
                            HealthStatus::NotReady,
                        );
                    }
                }
            })
            .await;

            // Overall timeout elapsed before a verdict: mark NotReady.
            if result.is_err() {
                warn!("Health check timeout for {}", endpoint_subject_owned);
                system_health
                    .lock()
                    .set_endpoint_health_status(&endpoint_subject_owned, HealthStatus::NotReady);
            }

            debug!("Health check completed for {}", endpoint_subject_owned);
        });

        Ok(())
    }
}
382
383pub async fn start_health_check_manager(
385 drt: DistributedRuntime,
386 config: Option<HealthCheckConfig>,
387) -> anyhow::Result<()> {
388 let config = config.unwrap_or_default();
389 let manager = Arc::new(HealthCheckManager::new(drt, config));
390
391 manager.start().await?;
393
394 Ok(())
395}
396
397pub async fn get_health_check_status(
399 drt: &DistributedRuntime,
400) -> anyhow::Result<serde_json::Value> {
401 let endpoint_subjects: Vec<String> = drt.system_health().lock().get_health_check_endpoints();
403
404 let mut endpoint_statuses = HashMap::new();
405
406 {
408 let system_health = drt.system_health();
409 let system_health_lock = system_health.lock();
410 for endpoint_subject in &endpoint_subjects {
411 let health_status = system_health_lock
412 .get_endpoint_health_status(endpoint_subject)
413 .unwrap_or(HealthStatus::NotReady);
414
415 let is_healthy = matches!(health_status, HealthStatus::Ready);
416
417 endpoint_statuses.insert(
418 endpoint_subject.clone(),
419 serde_json::json!({
420 "healthy": is_healthy,
421 "status": format!("{:?}", health_status),
422 }),
423 );
424 }
425 }
426
427 let overall_healthy = endpoint_statuses
428 .values()
429 .all(|v| v["healthy"].as_bool().unwrap_or(false));
430
431 Ok(serde_json::json!({
432 "status": if overall_healthy { "ready" } else { "notready" },
433 "endpoints_checked": endpoint_subjects.len(),
434 "endpoint_statuses": endpoint_statuses,
435 }))
436}
437
// Integration tests: require the "integration" feature and a running test
// distributed runtime (created via create_test_drt_async).
#[cfg(all(test, feature = "integration"))]
mod integration_tests {
    use super::*;
    use crate::distributed::distributed_test_utils::create_test_drt_async;
    use std::sync::Arc;
    use std::time::Duration;

    /// Constructing a manager stores the provided config verbatim.
    #[tokio::test]
    async fn test_initialization() {
        let drt = create_test_drt_async().await;

        let canary_wait_time = Duration::from_secs(5);
        let request_timeout = Duration::from_secs(3);

        let config = HealthCheckConfig {
            canary_wait_time,
            request_timeout,
        };

        let manager = HealthCheckManager::new(drt.clone(), config);

        assert_eq!(manager.config.canary_wait_time, canary_wait_time);
        assert_eq!(manager.config.request_timeout, request_timeout);
    }

    /// Registering a target makes its payload retrievable and lists the
    /// endpoint among the health-check endpoints.
    #[tokio::test]
    async fn test_payload_registration() {
        let drt = create_test_drt_async().await;

        let endpoint = "test.endpoint";
        let payload = serde_json::json!({
            "prompt": "test",
            "_health_check": true
        });

        drt.system_health().lock().register_health_check_target(
            endpoint,
            crate::component::Instance {
                component: "test_component".to_string(),
                endpoint: "test_endpoint".to_string(),
                namespace: "test_namespace".to_string(),
                instance_id: 12345,
                transport: crate::component::TransportType::Nats(endpoint.to_string()),
            },
            payload.clone(),
        );

        // Payload round-trips unchanged.
        let retrieved = drt
            .system_health()
            .lock()
            .get_health_check_target(endpoint)
            .map(|t| t.payload);
        assert!(retrieved.is_some());
        assert_eq!(retrieved.unwrap(), payload);

        let endpoints = drt.system_health().lock().get_health_check_endpoints();
        assert!(endpoints.contains(&endpoint.to_string()));
    }

    /// start() spawns exactly one task per pre-registered endpoint, keyed by
    /// its subject.
    #[tokio::test]
    async fn test_spawn_per_endpoint_tasks() {
        let drt = create_test_drt_async().await;

        // Register three distinct endpoints before starting the manager.
        for i in 0..3 {
            let endpoint = format!("test.endpoint.{}", i);
            let payload = serde_json::json!({
                "prompt": format!("test{}", i),
                "_health_check": true
            });
            drt.system_health().lock().register_health_check_target(
                &endpoint,
                crate::component::Instance {
                    component: "test_component".to_string(),
                    endpoint: format!("test_endpoint_{}", i),
                    namespace: "test_namespace".to_string(),
                    instance_id: i,
                    transport: crate::component::TransportType::Nats(endpoint.clone()),
                },
                payload,
            );
        }

        // Long canary wait so no actual health check fires during the test.
        let config = HealthCheckConfig {
            canary_wait_time: Duration::from_secs(5),
            request_timeout: Duration::from_secs(1),
        };

        let manager = Arc::new(HealthCheckManager::new(drt.clone(), config));
        manager.clone().start().await.unwrap();

        let tasks = manager.endpoint_tasks.lock();
        assert_eq!(tasks.len(), 3);
        let endpoints: Vec<String> = tasks.keys().cloned().collect();
        assert!(endpoints.contains(&"test.endpoint.0".to_string()));
        assert!(endpoints.contains(&"test.endpoint.1".to_string()));
        assert!(endpoints.contains(&"test.endpoint.2".to_string()));
    }

    /// Registration creates a per-endpoint activity notifier, and an
    /// unchecked endpoint reports NotReady.
    #[tokio::test]
    async fn test_endpoint_health_check_notifier_created() {
        let drt = create_test_drt_async().await;

        let endpoint = "test.endpoint.notifier";
        let payload = serde_json::json!({
            "prompt": "test",
            "_health_check": true
        });

        drt.system_health().lock().register_health_check_target(
            endpoint,
            crate::component::Instance {
                component: "test_component".to_string(),
                endpoint: "test_endpoint_notifier".to_string(),
                namespace: "test_namespace".to_string(),
                instance_id: 999,
                transport: crate::component::TransportType::Nats(endpoint.to_string()),
            },
            payload.clone(),
        );

        let notifier = drt
            .system_health()
            .lock()
            .get_endpoint_health_check_notifier(endpoint);

        assert!(
            notifier.is_some(),
            "Endpoint should have a notifier created"
        );

        // Notifying must not panic even with no task listening.
        if let Some(notifier) = notifier {
            notifier.notify_one();
        }

        // No health check has run yet, so the status defaults to NotReady.
        let status = drt
            .system_health()
            .lock()
            .get_endpoint_health_status(endpoint);
        assert_eq!(status, Some(HealthStatus::NotReady));
    }
}