1use crate::component::{Client, Component, Endpoint, Instance};
5use crate::pipeline::PushRouter;
6use crate::pipeline::{AsyncEngine, Context, ManyOut, SingleIn};
7use crate::protocols::annotated::Annotated;
8use crate::protocols::maybe_error::MaybeError;
9use crate::{DistributedRuntime, HealthStatus, SystemHealth};
10use futures::StreamExt;
11use serde::{Deserialize, Serialize};
12use std::collections::HashMap;
13use std::sync::{Arc, Mutex};
14use std::time::{Duration, Instant};
15use tokio::task::JoinHandle;
16use tokio::time::{MissedTickBehavior, interval};
17use tracing::{debug, error, info, warn};
18
/// Timing configuration for the health check manager.
///
/// `Debug` and `Clone` are derived per Rust convention for public config
/// types (both fields are plain `Duration`s, so the derives are free).
#[derive(Debug, Clone)]
pub struct HealthCheckConfig {
    /// How long an endpoint may sit idle (no activity notification) before a
    /// canary health check request is sent to it.
    pub canary_wait_time: Duration,
    /// Timeout applied to a single health check round trip.
    pub request_timeout: Duration,
}
26
27impl Default for HealthCheckConfig {
28 fn default() -> Self {
29 Self {
30 canary_wait_time: Duration::from_secs(crate::config::DEFAULT_CANARY_WAIT_TIME_SECS),
31 request_timeout: Duration::from_secs(
32 crate::config::DEFAULT_HEALTH_CHECK_REQUEST_TIMEOUT_SECS,
33 ),
34 }
35 }
36}
37
/// Shared, mutex-guarded cache of JSON-in/JSON-out push routers, keyed by
/// endpoint subject, so repeated health checks reuse an existing router.
type RouterCache =
    Arc<Mutex<HashMap<String, Arc<PushRouter<serde_json::Value, Annotated<serde_json::Value>>>>>>;
42
/// Drives canary-style health checks: one background task per registered
/// endpoint, plus a monitor task that picks up endpoints registered later.
pub struct HealthCheckManager {
    // Runtime handle; holds the shared `system_health` state this manager
    // reads targets from and writes statuses to.
    drt: DistributedRuntime,
    // Timing knobs (canary wait, per-request timeout).
    config: HealthCheckConfig,
    // Routers cached by endpoint subject so a health check does not rebuild
    // a client/router on every send.
    router_cache: RouterCache,
    // Join handles of the per-endpoint canary tasks, keyed by subject.
    endpoint_tasks: Arc<Mutex<HashMap<String, JoinHandle<()>>>>,
}
53
impl HealthCheckManager {
    /// Creates a manager with empty router and task caches.
    pub fn new(drt: DistributedRuntime, config: HealthCheckConfig) -> Self {
        Self {
            drt,
            config,
            router_cache: Arc::new(Mutex::new(HashMap::new())),
            endpoint_tasks: Arc::new(Mutex::new(HashMap::new())),
        }
    }

    /// Returns the cached router for `cache_key`, creating one for
    /// `endpoint` (round-robin mode) on a miss and caching it.
    ///
    /// NOTE(review): two callers racing on the same miss may both build a
    /// router; the later insert wins the cache slot while the earlier caller
    /// keeps its own Arc — wasteful but not a correctness issue.
    async fn get_or_create_router(
        &self,
        cache_key: &str,
        endpoint: Endpoint,
    ) -> anyhow::Result<Arc<PushRouter<serde_json::Value, Annotated<serde_json::Value>>>> {
        let cache_key = cache_key.to_string();

        // Fast path: scoped so the std Mutex guard is dropped before any
        // await below.
        {
            let cache = self.router_cache.lock().unwrap();
            if let Some(router) = cache.get(&cache_key) {
                return Ok(router.clone());
            }
        }

        let client = Client::new_dynamic(endpoint).await?;

        let router: Arc<PushRouter<serde_json::Value, Annotated<serde_json::Value>>> = Arc::new(
            PushRouter::from_client(
                client,
                crate::pipeline::RouterMode::RoundRobin, )
            .await?,
        );

        self.router_cache
            .lock()
            .unwrap()
            .insert(cache_key, router.clone());

        Ok(router)
    }

    /// Spawns one canary task per currently-registered endpoint, then starts
    /// the monitor that handles endpoints registered after startup.
    pub async fn start(self: Arc<Self>) -> anyhow::Result<()> {
        // Snapshot targets under the lock; release before spawning tasks.
        let targets = {
            let system_health = self.drt.system_health.lock().unwrap();
            system_health.get_health_check_targets()
        };

        info!(
            "Starting health check tasks for {} endpoints with canary_wait_time: {:?}",
            targets.len(),
            self.config.canary_wait_time
        );

        for (endpoint_subject, _target) in targets {
            self.spawn_endpoint_health_check_task(endpoint_subject);
        }

        self.spawn_new_endpoint_monitor().await?;

        info!("HealthCheckManager started successfully with channel-based endpoint discovery");
        Ok(())
    }

    /// Spawns the per-endpoint canary loop: if `canary_wait` elapses with no
    /// activity notification, a health check request is sent; any activity
    /// notification restarts the wait.
    ///
    /// Panics (via `expect`) if the endpoint has no registered notifier —
    /// that would violate the registration invariant.
    fn spawn_endpoint_health_check_task(self: &Arc<Self>, endpoint_subject: String) {
        let manager = self.clone();
        let canary_wait = self.config.canary_wait_time;
        let endpoint_subject_clone = endpoint_subject.clone();

        let notifier = {
            let system_health = self.drt.system_health.lock().unwrap();
            system_health
                .get_endpoint_health_check_notifier(&endpoint_subject)
                .expect("Notifier should exist for registered endpoint")
        };

        let task = tokio::spawn(async move {
            let endpoint_subject = endpoint_subject_clone;
            info!("Health check task started for: {}", endpoint_subject);

            loop {
                // The sleep future is recreated on every iteration, so the
                // notifier branch effectively resets the canary timer.
                tokio::select! {
                    _ = tokio::time::sleep(canary_wait) => {
                        info!("Canary timer expired for {}, sending health check", endpoint_subject);

                        // Re-read the target each time: it may have been
                        // deregistered since the task started.
                        let target = {
                            let system_health = manager.drt.system_health.lock().unwrap();
                            system_health.get_health_check_target(&endpoint_subject)
                        };

                        if let Some(target) = target {
                            if let Err(e) = manager.send_health_check_request(&endpoint_subject, &target.payload).await {
                                error!("Failed to send health check for {}: {}", endpoint_subject, e);
                            }
                        } else {
                            // Deliberate fail-fast: a vanished target is an
                            // invariant violation, so the task stops.
                            error!(
                                "CRITICAL: Health check target for {} disappeared unexpectedly! This indicates a bug. Stopping health check task.",
                                endpoint_subject
                            );
                            break;
                        }
                    }

                    _ = notifier.notified() => {
                        debug!("Activity detected for {}, resetting health check timer", endpoint_subject);
                    }
                }
            }

            info!("Health check task for {} exiting", endpoint_subject);
        });

        self.endpoint_tasks
            .lock()
            .unwrap()
            .insert(endpoint_subject.clone(), task);

        info!(
            "Spawned health check task for endpoint: {}",
            endpoint_subject
        );
    }

    /// Starts the background task that listens on the new-endpoint channel
    /// and spawns a canary task for each newly registered endpoint.
    ///
    /// May only be called once per runtime: the receiver is taken out of
    /// `SystemHealth`, and a second call returns an error.
    async fn spawn_new_endpoint_monitor(self: &Arc<Self>) -> anyhow::Result<()> {
        let manager = self.clone();

        let mut rx = {
            let system_health = manager.drt.system_health.lock().unwrap();
            system_health.take_new_endpoint_receiver().ok_or_else(|| {
                anyhow::anyhow!("Endpoint receiver already taken - this should only be called once")
            })?
        };

        tokio::spawn(async move {
            info!("Starting dynamic endpoint discovery monitor with channel-based notifications");

            while let Some(endpoint_subject) = rx.recv().await {
                debug!(
                    "Received endpoint registration via channel: {}",
                    endpoint_subject
                );

                let already_exists = {
                    let tasks = manager.endpoint_tasks.lock().unwrap();
                    tasks.contains_key(&endpoint_subject)
                };

                // Deliberate fail-fast on a duplicate registration: the whole
                // monitor stops (see the exit log below), surfacing the bug
                // loudly rather than masking it.
                if already_exists {
                    error!(
                        "CRITICAL: Received registration for endpoint '{}' that already has a health check task!",
                        endpoint_subject
                    );
                    break;
                }

                info!(
                    "Spawning health check task for new endpoint: {}",
                    endpoint_subject
                );
                manager.spawn_endpoint_health_check_task(endpoint_subject);
            }

            info!("Endpoint discovery monitor exiting - no new endpoints will be monitored!");
        });

        info!("Dynamic endpoint discovery monitor started");
        Ok(())
    }

    /// Sends one health check to `endpoint_subject` and updates its status
    /// from the first response item, a request error, or a timeout.
    ///
    /// Returns `Err` only for setup failures (unknown target, namespace /
    /// component / router errors); the request itself runs in a detached task
    /// so the caller's canary loop is never blocked on the response.
    async fn send_health_check_request(
        &self,
        endpoint_subject: &str,
        payload: &serde_json::Value,
    ) -> anyhow::Result<()> {
        let target = {
            let system_health = self.drt.system_health.lock().unwrap();
            system_health
                .get_health_check_target(endpoint_subject)
                .ok_or_else(|| {
                    anyhow::anyhow!("No health check target found for {}", endpoint_subject)
                })?
        };

        debug!(
            "Sending health check to {} (instance_id: {})",
            endpoint_subject, target.instance.instance_id
        );

        let namespace = self.drt.namespace(&target.instance.namespace)?;
        let component = namespace.component(&target.instance.component)?;
        let endpoint = component.endpoint(&target.instance.endpoint);

        let router = self
            .get_or_create_router(endpoint_subject, endpoint)
            .await?;

        let request: SingleIn<serde_json::Value> = Context::new(payload.clone());

        // Clone everything the detached task needs; `self` is not moved in.
        let system_health = self.drt.system_health.clone();
        let endpoint_subject_owned = endpoint_subject.to_string();
        let instance_id = target.instance.instance_id;
        let timeout = self.config.request_timeout;

        tokio::spawn(async move {
            let result = tokio::time::timeout(timeout, async {
                match router.direct(request, instance_id).await {
                    Ok(mut response_stream) => {
                        // Only the first stream item decides health; an empty
                        // stream counts as unhealthy.
                        let is_healthy = if let Some(response) = response_stream.next().await {
                            if let Some(error) = response.err() {
                                warn!(
                                    "Health check error response from {}: {:?}",
                                    endpoint_subject_owned, error
                                );
                                false
                            } else {
                                info!("Health check successful for {}", endpoint_subject_owned);
                                true
                            }
                        } else {
                            warn!(
                                "Health check got no response from {}",
                                endpoint_subject_owned
                            );
                            false
                        };

                        system_health.lock().unwrap().set_endpoint_health_status(
                            &endpoint_subject_owned,
                            if is_healthy {
                                HealthStatus::Ready
                            } else {
                                HealthStatus::NotReady
                            },
                        );
                    }
                    Err(e) => {
                        error!(
                            "Health check request failed for {}: {}",
                            endpoint_subject_owned, e
                        );
                        system_health.lock().unwrap().set_endpoint_health_status(
                            &endpoint_subject_owned,
                            HealthStatus::NotReady,
                        );
                    }
                }
            })
            .await;

            // `timeout` elapsed before the inner future finished: NotReady.
            if result.is_err() {
                warn!("Health check timeout for {}", endpoint_subject_owned);
                system_health
                    .lock()
                    .unwrap()
                    .set_endpoint_health_status(&endpoint_subject_owned, HealthStatus::NotReady);
            }

            debug!("Health check completed for {}", endpoint_subject_owned);
        });

        Ok(())
    }
}
351
352pub async fn start_health_check_manager(
354 drt: DistributedRuntime,
355 config: Option<HealthCheckConfig>,
356) -> anyhow::Result<()> {
357 let config = config.unwrap_or_default();
358 let manager = Arc::new(HealthCheckManager::new(drt, config));
359
360 manager.start().await?;
362
363 Ok(())
364}
365
366pub async fn get_health_check_status(
368 drt: &DistributedRuntime,
369) -> anyhow::Result<serde_json::Value> {
370 let endpoint_subjects: Vec<String> = {
372 let system_health = drt.system_health.lock().unwrap();
373 system_health.get_health_check_endpoints()
374 };
375
376 let mut endpoint_statuses = HashMap::new();
377
378 {
380 let system_health = drt.system_health.lock().unwrap();
381 for endpoint_subject in &endpoint_subjects {
382 let health_status = system_health
383 .get_endpoint_health_status(endpoint_subject)
384 .unwrap_or(HealthStatus::NotReady);
385
386 let is_healthy = matches!(health_status, HealthStatus::Ready);
387
388 endpoint_statuses.insert(
389 endpoint_subject.clone(),
390 serde_json::json!({
391 "healthy": is_healthy,
392 "status": format!("{:?}", health_status),
393 }),
394 );
395 }
396 }
397
398 let overall_healthy = endpoint_statuses
399 .values()
400 .all(|v| v["healthy"].as_bool().unwrap_or(false));
401
402 Ok(serde_json::json!({
403 "status": if overall_healthy { "ready" } else { "notready" },
404 "endpoints_checked": endpoint_subjects.len(),
405 "endpoint_statuses": endpoint_statuses,
406 }))
407}
408
// Integration tests for the health check manager; require a live test
// runtime, hence gated behind the `integration` feature.
#[cfg(all(test, feature = "integration"))]
mod integration_tests {
    use super::*;
    use crate::HealthStatus;
    use crate::distributed::distributed_test_utils::create_test_drt_async;
    use std::sync::Arc;
    use std::time::Duration;

    // Constructing a manager stores the config as given and shares (not
    // copies) the runtime's SystemHealth handle.
    #[tokio::test]
    async fn test_initialization() {
        let drt = create_test_drt_async().await;

        let canary_wait_time = Duration::from_secs(5);
        let request_timeout = Duration::from_secs(3);

        let config = HealthCheckConfig {
            canary_wait_time,
            request_timeout,
        };

        let manager = HealthCheckManager::new(drt.clone(), config);

        assert_eq!(manager.config.canary_wait_time, canary_wait_time);
        assert_eq!(manager.config.request_timeout, request_timeout);

        // Same Arc, not a clone of the inner state.
        assert!(Arc::ptr_eq(&manager.drt.system_health, &drt.system_health));
    }

    // Registering a health check target makes its payload retrievable and
    // lists the endpoint among the health-checked subjects.
    #[tokio::test]
    async fn test_payload_registration() {
        let drt = create_test_drt_async().await;

        let endpoint = "test.endpoint";
        let payload = serde_json::json!({
            "prompt": "test",
            "_health_check": true
        });

        drt.system_health
            .lock()
            .unwrap()
            .register_health_check_target(
                endpoint,
                crate::component::Instance {
                    component: "test_component".to_string(),
                    endpoint: "test_endpoint".to_string(),
                    namespace: "test_namespace".to_string(),
                    instance_id: 12345,
                    transport: crate::component::TransportType::NatsTcp(endpoint.to_string()),
                },
                payload.clone(),
            );

        let retrieved = drt
            .system_health
            .lock()
            .unwrap()
            .get_health_check_target(endpoint)
            .map(|t| t.payload);
        assert!(retrieved.is_some());
        assert_eq!(retrieved.unwrap(), payload);

        let endpoints = drt
            .system_health
            .lock()
            .unwrap()
            .get_health_check_endpoints();
        assert!(endpoints.contains(&endpoint.to_string()));
    }

    // `start()` spawns exactly one canary task per pre-registered endpoint,
    // keyed by the endpoint subject.
    #[tokio::test]
    async fn test_spawn_per_endpoint_tasks() {
        let drt = create_test_drt_async().await;

        for i in 0..3 {
            let endpoint = format!("test.endpoint.{}", i);
            let payload = serde_json::json!({
                "prompt": format!("test{}", i),
                "_health_check": true
            });
            drt.system_health
                .lock()
                .unwrap()
                .register_health_check_target(
                    &endpoint,
                    crate::component::Instance {
                        component: "test_component".to_string(),
                        endpoint: format!("test_endpoint_{}", i),
                        namespace: "test_namespace".to_string(),
                        instance_id: i as i64,
                        transport: crate::component::TransportType::NatsTcp(endpoint.clone()),
                    },
                    payload,
                );
        }

        let config = HealthCheckConfig {
            canary_wait_time: Duration::from_secs(5),
            request_timeout: Duration::from_secs(1),
        };

        let manager = Arc::new(HealthCheckManager::new(drt.clone(), config));
        manager.clone().start().await.unwrap();

        let tasks = manager.endpoint_tasks.lock().unwrap();
        assert_eq!(tasks.len(), 3);
        let endpoints: Vec<String> = tasks.keys().cloned().collect();
        assert!(endpoints.contains(&"test.endpoint.0".to_string()));
        assert!(endpoints.contains(&"test.endpoint.1".to_string()));
        assert!(endpoints.contains(&"test.endpoint.2".to_string()));
    }

    // Registration creates an activity notifier for the endpoint, and the
    // endpoint starts out Ready.
    #[tokio::test]
    async fn test_endpoint_health_check_notifier_created() {
        let drt = create_test_drt_async().await;

        let endpoint = "test.endpoint.notifier";
        let payload = serde_json::json!({
            "prompt": "test",
            "_health_check": true
        });

        drt.system_health
            .lock()
            .unwrap()
            .register_health_check_target(
                endpoint,
                crate::component::Instance {
                    component: "test_component".to_string(),
                    endpoint: "test_endpoint_notifier".to_string(),
                    namespace: "test_namespace".to_string(),
                    instance_id: 999,
                    transport: crate::component::TransportType::NatsTcp(endpoint.to_string()),
                },
                payload.clone(),
            );

        let notifier = drt
            .system_health
            .lock()
            .unwrap()
            .get_endpoint_health_check_notifier(endpoint);

        assert!(
            notifier.is_some(),
            "Endpoint should have a notifier created"
        );

        // Exercise the notifier path (no assertion on its effect here).
        if let Some(notifier) = notifier {
            notifier.notify_one();
        }

        let status = drt
            .system_health
            .lock()
            .unwrap()
            .get_endpoint_health_status(endpoint);
        assert_eq!(status, Some(HealthStatus::Ready));
    }
}
579}