1use crate::component::{Client, Component, Endpoint, Instance};
5use crate::pipeline::PushRouter;
6use crate::pipeline::{AsyncEngine, Context, ManyOut, SingleIn};
7use crate::protocols::annotated::Annotated;
8use crate::protocols::maybe_error::MaybeError;
9use crate::{DistributedRuntime, HealthStatus, SystemHealth};
10use futures::StreamExt;
11use parking_lot::Mutex;
12use serde::{Deserialize, Serialize};
13use std::collections::HashMap;
14use std::sync::Arc;
15use std::time::{Duration, Instant};
16use tokio::task::JoinHandle;
17use tokio::time::{MissedTickBehavior, interval};
18use tracing::{debug, error, info, warn};
19
/// Tunable timing parameters for endpoint health checking.
pub struct HealthCheckConfig {
    // How long an endpoint may go without observed activity before a canary
    // health-check request is sent to it.
    pub canary_wait_time: Duration,
    // Maximum time to wait for a single health-check response before the
    // endpoint is marked NotReady.
    pub request_timeout: Duration,
}
27
28impl Default for HealthCheckConfig {
29 fn default() -> Self {
30 Self {
31 canary_wait_time: Duration::from_secs(crate::config::DEFAULT_CANARY_WAIT_TIME_SECS),
32 request_timeout: Duration::from_secs(
33 crate::config::DEFAULT_HEALTH_CHECK_REQUEST_TIMEOUT_SECS,
34 ),
35 }
36 }
37}
38
/// Shared cache of JSON-in/JSON-out routers keyed by endpoint subject, so each
/// endpoint's router is built once and reused across health checks.
type RouterCache =
    Arc<Mutex<HashMap<String, Arc<PushRouter<serde_json::Value, Annotated<serde_json::Value>>>>>>;
43
/// Drives periodic endpoint health checks: one background task per registered
/// endpoint, plus a monitor task that picks up endpoints registered after startup.
pub struct HealthCheckManager {
    // Runtime handle providing the shared SystemHealth state and endpoint access.
    drt: DistributedRuntime,
    // Timing knobs (canary wait, per-request timeout).
    config: HealthCheckConfig,
    // Lazily built per-endpoint routers, reused across checks.
    router_cache: RouterCache,
    // Join handles of the per-endpoint health-check tasks, keyed by subject.
    endpoint_tasks: Arc<Mutex<HashMap<String, JoinHandle<()>>>>,
}
54
impl HealthCheckManager {
    /// Constructs a manager with empty caches; no tasks run until `start` is called.
    pub fn new(drt: DistributedRuntime, config: HealthCheckConfig) -> Self {
        Self {
            drt,
            config,
            router_cache: Arc::new(Mutex::new(HashMap::new())),
            endpoint_tasks: Arc::new(Mutex::new(HashMap::new())),
        }
    }

    /// Returns the cached router for `cache_key`, creating (and caching) one on miss.
    ///
    /// Lock guards are scoped so that none is held across the `await`s below.
    ///
    /// NOTE(review): this is check-then-act — two concurrent callers that both miss
    /// can each build a router, and the later `insert` overwrites the earlier one.
    /// Both routers remain usable, so this looks benign, but confirm that building
    /// a duplicate `PushRouter` has no unwanted side effects.
    async fn get_or_create_router(
        &self,
        cache_key: &str,
        endpoint: Endpoint,
    ) -> anyhow::Result<Arc<PushRouter<serde_json::Value, Annotated<serde_json::Value>>>> {
        let cache_key = cache_key.to_string();

        // Fast path: reuse an existing router without awaiting.
        {
            let cache = self.router_cache.lock();
            if let Some(router) = cache.get(&cache_key) {
                return Ok(router.clone());
            }
        }

        let client = Client::new_dynamic(endpoint).await?;

        let router: Arc<PushRouter<serde_json::Value, Annotated<serde_json::Value>>> = Arc::new(
            PushRouter::from_client(
                client,
                crate::pipeline::RouterMode::RoundRobin,
            )
            .await?,
        );

        self.router_cache.lock().insert(cache_key, router.clone());

        Ok(router)
    }

    /// Spawns one health-check task per already-registered endpoint, then starts the
    /// monitor that picks up endpoints registered later.
    pub async fn start(self: Arc<Self>) -> anyhow::Result<()> {
        // Snapshot the targets under the lock, then release it before spawning.
        let targets = {
            let system_health = self.drt.system_health.lock();
            system_health.get_health_check_targets()
        };

        info!(
            "Starting health check tasks for {} endpoints with canary_wait_time: {:?}",
            targets.len(),
            self.config.canary_wait_time
        );

        for (endpoint_subject, _target) in targets {
            self.spawn_endpoint_health_check_task(endpoint_subject);
        }

        self.spawn_new_endpoint_monitor().await?;

        info!("HealthCheckManager started successfully with channel-based endpoint discovery");
        Ok(())
    }

    /// Spawns a detached task that sends a health check to `endpoint_subject`
    /// whenever no activity has been observed for `canary_wait_time`.
    ///
    /// # Panics
    /// Panics if no notifier was registered for the endpoint; registration is
    /// expected to have created one.
    fn spawn_endpoint_health_check_task(self: &Arc<Self>, endpoint_subject: String) {
        let manager = self.clone();
        let canary_wait = self.config.canary_wait_time;
        let endpoint_subject_clone = endpoint_subject.clone();

        let notifier = {
            let system_health = self.drt.system_health.lock();
            system_health
                .get_endpoint_health_check_notifier(&endpoint_subject)
                .expect("Notifier should exist for registered endpoint")
        };

        let task = tokio::spawn(async move {
            let endpoint_subject = endpoint_subject_clone;
            info!("Health check task started for: {}", endpoint_subject);

            loop {
                // Each loop iteration re-creates the sleep future, so a wakeup from
                // the activity notifier effectively resets the canary timer.
                tokio::select! {
                    _ = tokio::time::sleep(canary_wait) => {
                        info!("Canary timer expired for {}, sending health check", endpoint_subject);

                        let target = {
                            let system_health = manager.drt.system_health.lock();
                            system_health.get_health_check_target(&endpoint_subject)
                        };

                        if let Some(target) = target {
                            if let Err(e) = manager.send_health_check_request(&endpoint_subject, &target.payload).await {
                                error!("Failed to send health check for {}: {}", endpoint_subject, e);
                            }
                        } else {
                            // Targets are never expected to disappear; treat a missing
                            // one as a bug and stop checking this endpoint.
                            error!(
                                "CRITICAL: Health check target for {} disappeared unexpectedly! This indicates a bug. Stopping health check task.",
                                endpoint_subject
                            );
                            break;
                        }
                    }

                    _ = notifier.notified() => {
                        debug!("Activity detected for {}, resetting health check timer", endpoint_subject);
                    }
                }
            }

            info!("Health check task for {} exiting", endpoint_subject);
        });

        self.endpoint_tasks
            .lock()
            .insert(endpoint_subject.clone(), task);

        info!(
            "Spawned health check task for endpoint: {}",
            endpoint_subject
        );
    }

    /// Starts the detached task that listens on the new-endpoint channel and spawns
    /// a health-check task for each newly registered endpoint.
    ///
    /// # Errors
    /// Fails if the new-endpoint receiver was already taken, i.e. this was called twice.
    async fn spawn_new_endpoint_monitor(self: &Arc<Self>) -> anyhow::Result<()> {
        let manager = self.clone();

        let mut rx = {
            let system_health = manager.drt.system_health.lock();
            system_health.take_new_endpoint_receiver().ok_or_else(|| {
                anyhow::anyhow!("Endpoint receiver already taken - this should only be called once")
            })?
        };

        tokio::spawn(async move {
            info!("Starting dynamic endpoint discovery monitor with channel-based notifications");

            while let Some(endpoint_subject) = rx.recv().await {
                debug!(
                    "Received endpoint registration via channel: {}",
                    endpoint_subject
                );

                let already_exists = {
                    let tasks = manager.endpoint_tasks.lock();
                    tasks.contains_key(&endpoint_subject)
                };

                // NOTE(review): a duplicate registration terminates the entire
                // discovery loop via `break`, so no later endpoints will ever be
                // monitored. If fail-fast is not intended here, `continue` would
                // skip the duplicate while keeping discovery alive — confirm which
                // behavior is wanted.
                if already_exists {
                    error!(
                        "CRITICAL: Received registration for endpoint '{}' that already has a health check task!",
                        endpoint_subject
                    );
                    break;
                }

                info!(
                    "Spawning health check task for new endpoint: {}",
                    endpoint_subject
                );
                manager.spawn_endpoint_health_check_task(endpoint_subject);
            }

            info!("Endpoint discovery monitor exiting - no new endpoints will be monitored!");
        });

        info!("Dynamic endpoint discovery monitor started");
        Ok(())
    }

    /// Sends one health-check request to the endpoint and records the outcome.
    ///
    /// The request itself runs in a detached task bounded by `config.request_timeout`;
    /// this function returns as soon as that task is spawned. The endpoint is marked
    /// `Ready` on a successful first response and `NotReady` on an error response, an
    /// empty stream, a transport failure, or a timeout.
    async fn send_health_check_request(
        &self,
        endpoint_subject: &str,
        payload: &serde_json::Value,
    ) -> anyhow::Result<()> {
        // Re-fetch the target so the current instance is used rather than trusting
        // any snapshot the caller might hold.
        let target = {
            let system_health = self.drt.system_health.lock();
            system_health
                .get_health_check_target(endpoint_subject)
                .ok_or_else(|| {
                    anyhow::anyhow!("No health check target found for {}", endpoint_subject)
                })?
        };

        debug!(
            "Sending health check to {} (instance_id: {})",
            endpoint_subject, target.instance.instance_id
        );

        let namespace = self.drt.namespace(&target.instance.namespace)?;
        let component = namespace.component(&target.instance.component)?;
        let endpoint = component.endpoint(&target.instance.endpoint);

        let router = self
            .get_or_create_router(endpoint_subject, endpoint)
            .await?;

        let request: SingleIn<serde_json::Value> = Context::new(payload.clone());

        // Clone everything the detached task needs so `self` is not captured.
        let system_health = self.drt.system_health.clone();
        let endpoint_subject_owned = endpoint_subject.to_string();
        let instance_id = target.instance.instance_id;
        let timeout = self.config.request_timeout;

        tokio::spawn(async move {
            let result = tokio::time::timeout(timeout, async {
                // Route directly to the specific instance being health-checked.
                match router.direct(request, instance_id).await {
                    Ok(mut response_stream) => {
                        // Only the first stream item decides health; the rest are ignored.
                        let is_healthy = if let Some(response) = response_stream.next().await {
                            if let Some(error) = response.err() {
                                warn!(
                                    "Health check error response from {}: {:?}",
                                    endpoint_subject_owned, error
                                );
                                false
                            } else {
                                info!("Health check successful for {}", endpoint_subject_owned);
                                true
                            }
                        } else {
                            warn!(
                                "Health check got no response from {}",
                                endpoint_subject_owned
                            );
                            false
                        };

                        system_health.lock().set_endpoint_health_status(
                            &endpoint_subject_owned,
                            if is_healthy {
                                HealthStatus::Ready
                            } else {
                                HealthStatus::NotReady
                            },
                        );
                    }
                    Err(e) => {
                        error!(
                            "Health check request failed for {}: {}",
                            endpoint_subject_owned, e
                        );
                        system_health.lock().set_endpoint_health_status(
                            &endpoint_subject_owned,
                            HealthStatus::NotReady,
                        );
                    }
                }
            })
            .await;

            // The timeout elapsed before the inner future finished: mark NotReady.
            if result.is_err() {
                warn!("Health check timeout for {}", endpoint_subject_owned);
                system_health
                    .lock()
                    .set_endpoint_health_status(&endpoint_subject_owned, HealthStatus::NotReady);
            }

            debug!("Health check completed for {}", endpoint_subject_owned);
        });

        Ok(())
    }
}
347
348pub async fn start_health_check_manager(
350 drt: DistributedRuntime,
351 config: Option<HealthCheckConfig>,
352) -> anyhow::Result<()> {
353 let config = config.unwrap_or_default();
354 let manager = Arc::new(HealthCheckManager::new(drt, config));
355
356 manager.start().await?;
358
359 Ok(())
360}
361
362pub async fn get_health_check_status(
364 drt: &DistributedRuntime,
365) -> anyhow::Result<serde_json::Value> {
366 let endpoint_subjects: Vec<String> = {
368 let system_health = drt.system_health.lock();
369 system_health.get_health_check_endpoints()
370 };
371
372 let mut endpoint_statuses = HashMap::new();
373
374 {
376 let system_health = drt.system_health.lock();
377 for endpoint_subject in &endpoint_subjects {
378 let health_status = system_health
379 .get_endpoint_health_status(endpoint_subject)
380 .unwrap_or(HealthStatus::NotReady);
381
382 let is_healthy = matches!(health_status, HealthStatus::Ready);
383
384 endpoint_statuses.insert(
385 endpoint_subject.clone(),
386 serde_json::json!({
387 "healthy": is_healthy,
388 "status": format!("{:?}", health_status),
389 }),
390 );
391 }
392 }
393
394 let overall_healthy = endpoint_statuses
395 .values()
396 .all(|v| v["healthy"].as_bool().unwrap_or(false));
397
398 Ok(serde_json::json!({
399 "status": if overall_healthy { "ready" } else { "notready" },
400 "endpoints_checked": endpoint_subjects.len(),
401 "endpoint_statuses": endpoint_statuses,
402 }))
403}
404
#[cfg(all(test, feature = "integration"))]
mod integration_tests {
    use super::*;
    use crate::HealthStatus;
    use crate::distributed::distributed_test_utils::create_test_drt_async;
    use std::sync::Arc;
    use std::time::Duration;

    #[tokio::test]
    async fn test_initialization() {
        let drt = create_test_drt_async().await;

        let expected_canary = Duration::from_secs(5);
        let expected_timeout = Duration::from_secs(3);

        let manager = HealthCheckManager::new(
            drt.clone(),
            HealthCheckConfig {
                canary_wait_time: expected_canary,
                request_timeout: expected_timeout,
            },
        );

        // Config values must be stored verbatim.
        assert_eq!(manager.config.canary_wait_time, expected_canary);
        assert_eq!(manager.config.request_timeout, expected_timeout);

        // The manager must share the runtime's SystemHealth handle, not a copy.
        assert!(Arc::ptr_eq(&manager.drt.system_health, &drt.system_health));
    }

    #[tokio::test]
    async fn test_payload_registration() {
        let drt = create_test_drt_async().await;

        let subject = "test.endpoint";
        let payload = serde_json::json!({
            "prompt": "test",
            "_health_check": true
        });

        drt.system_health.lock().register_health_check_target(
            subject,
            Instance {
                component: "test_component".to_string(),
                endpoint: "test_endpoint".to_string(),
                namespace: "test_namespace".to_string(),
                instance_id: 12345,
                transport: crate::component::TransportType::NatsTcp(subject.to_string()),
            },
            payload.clone(),
        );

        // The stored payload must round-trip unchanged.
        let stored = drt
            .system_health
            .lock()
            .get_health_check_target(subject)
            .map(|target| target.payload);
        assert_eq!(stored, Some(payload));

        // The endpoint must appear in the enumerated health-check set.
        assert!(
            drt.system_health
                .lock()
                .get_health_check_endpoints()
                .contains(&subject.to_string())
        );
    }

    #[tokio::test]
    async fn test_spawn_per_endpoint_tasks() {
        let drt = create_test_drt_async().await;

        // Register three distinct endpoints before the manager starts.
        for i in 0..3 {
            let subject = format!("test.endpoint.{}", i);
            drt.system_health.lock().register_health_check_target(
                &subject,
                Instance {
                    component: "test_component".to_string(),
                    endpoint: format!("test_endpoint_{}", i),
                    namespace: "test_namespace".to_string(),
                    instance_id: i,
                    transport: crate::component::TransportType::NatsTcp(subject.clone()),
                },
                serde_json::json!({
                    "prompt": format!("test{}", i),
                    "_health_check": true
                }),
            );
        }

        let manager = Arc::new(HealthCheckManager::new(
            drt.clone(),
            HealthCheckConfig {
                canary_wait_time: Duration::from_secs(5),
                request_timeout: Duration::from_secs(1),
            },
        ));
        Arc::clone(&manager).start().await.unwrap();

        // Exactly one health-check task per registered endpoint.
        let tasks = manager.endpoint_tasks.lock();
        assert_eq!(tasks.len(), 3);
        for i in 0..3 {
            assert!(tasks.contains_key(&format!("test.endpoint.{}", i)));
        }
    }

    #[tokio::test]
    async fn test_endpoint_health_check_notifier_created() {
        let drt = create_test_drt_async().await;

        let subject = "test.endpoint.notifier";
        drt.system_health.lock().register_health_check_target(
            subject,
            Instance {
                component: "test_component".to_string(),
                endpoint: "test_endpoint_notifier".to_string(),
                namespace: "test_namespace".to_string(),
                instance_id: 999,
                transport: crate::component::TransportType::NatsTcp(subject.to_string()),
            },
            serde_json::json!({
                "prompt": "test",
                "_health_check": true
            }),
        );

        let notifier = drt
            .system_health
            .lock()
            .get_endpoint_health_check_notifier(subject);
        assert!(
            notifier.is_some(),
            "Endpoint should have a notifier created"
        );

        if let Some(notifier) = notifier {
            notifier.notify_one();
        }

        // Registration alone leaves the endpoint NotReady until a check succeeds.
        let status = drt
            .system_health
            .lock()
            .get_endpoint_health_status(subject);
        assert_eq!(status, Some(HealthStatus::NotReady));
    }
}