1use crate::component::{Client, Component, Endpoint, Instance};
5use crate::config::HealthStatus;
6use crate::pipeline::PushRouter;
7use crate::pipeline::{AsyncEngine, Context, ManyOut, SingleIn};
8use crate::protocols::annotated::Annotated;
9use crate::protocols::maybe_error::MaybeError;
10use crate::{DistributedRuntime, SystemHealth};
11use futures::StreamExt;
12use parking_lot::Mutex;
13use serde::{Deserialize, Serialize};
14use std::collections::HashMap;
15use std::sync::Arc;
16use std::time::{Duration, Instant};
17use tokio::task::JoinHandle;
18use tokio::time::{MissedTickBehavior, interval};
19use tracing::{debug, error, info, warn};
20
/// Timing configuration for the health check subsystem.
pub struct HealthCheckConfig {
    // How long an endpoint may sit idle (no observed activity) before a
    // canary health check request is sent to it.
    pub canary_wait_time: Duration,
    // Maximum time to wait for a health check response before the endpoint
    // is marked NotReady.
    pub request_timeout: Duration,
}
28
29impl Default for HealthCheckConfig {
30 fn default() -> Self {
31 Self {
32 canary_wait_time: Duration::from_secs(crate::config::DEFAULT_CANARY_WAIT_TIME_SECS),
33 request_timeout: Duration::from_secs(
34 crate::config::DEFAULT_HEALTH_CHECK_REQUEST_TIMEOUT_SECS,
35 ),
36 }
37 }
38}
39
// Shared, mutex-guarded cache mapping an endpoint subject string to the
// PushRouter used to deliver health check requests to that endpoint.
// Routers are created lazily and reused across checks
// (see `HealthCheckManager::get_or_create_router`).
type RouterCache =
    Arc<Mutex<HashMap<String, Arc<PushRouter<serde_json::Value, Annotated<serde_json::Value>>>>>>;
44
/// Coordinates canary health checks for all registered endpoints.
///
/// One background task is spawned per endpoint (see
/// `spawn_endpoint_health_check_task`); each task sends a health check
/// request whenever its endpoint has been idle for `config.canary_wait_time`.
pub struct HealthCheckManager {
    // Runtime handle used to reach shared `SystemHealth` state and to build
    // namespace/component/endpoint handles for outgoing requests.
    drt: DistributedRuntime,
    // Canary wait time and per-request timeout.
    config: HealthCheckConfig,
    // Lazily built request routers, keyed by endpoint subject.
    router_cache: RouterCache,
    // Handles of the per-endpoint health check tasks, keyed by endpoint
    // subject. Used to detect duplicate registrations; the tasks themselves
    // are detached (handles are never awaited or aborted here).
    endpoint_tasks: Arc<Mutex<HashMap<String, JoinHandle<()>>>>,
}
55
impl HealthCheckManager {
    /// Create a manager with empty router/task caches. No background tasks
    /// run until [`Self::start`] is called.
    pub fn new(drt: DistributedRuntime, config: HealthCheckConfig) -> Self {
        Self {
            drt,
            config,
            router_cache: Arc::new(Mutex::new(HashMap::new())),
            endpoint_tasks: Arc::new(Mutex::new(HashMap::new())),
        }
    }

    /// Return the cached router for `cache_key`, building (and caching) one
    /// on first use.
    ///
    /// NOTE(review): the cache check and the insert are two separate lock
    /// acquisitions, so two concurrent callers that both miss will each build
    /// a router and the second insert wins. Beyond the duplicated setup work
    /// this looks harmless (both routers target the same endpoint) — confirm.
    async fn get_or_create_router(
        &self,
        cache_key: &str,
        endpoint: Endpoint,
    ) -> anyhow::Result<Arc<PushRouter<serde_json::Value, Annotated<serde_json::Value>>>> {
        let cache_key = cache_key.to_string();

        // Fast path: look up in a short-lived scope so the (non-async)
        // parking_lot mutex guard is dropped before any await below.
        {
            let cache = self.router_cache.lock();
            if let Some(router) = cache.get(&cache_key) {
                return Ok(router.clone());
            }
        }

        // Slow path: build a client and a round-robin router for the endpoint.
        let client = Client::new_dynamic(endpoint).await?;

        let router: Arc<PushRouter<serde_json::Value, Annotated<serde_json::Value>>> = Arc::new(
            PushRouter::from_client(client, crate::pipeline::RouterMode::RoundRobin).await?,
        );

        self.router_cache.lock().insert(cache_key, router.clone());

        Ok(router)
    }

    /// Spawn one health check task per already-registered endpoint, then start
    /// the monitor that picks up endpoints registered later.
    pub async fn start(self: Arc<Self>) -> anyhow::Result<()> {
        // Snapshot of the targets registered before startup; endpoints that
        // register after this point are handled by the monitor below.
        let targets = self.drt.system_health().lock().get_health_check_targets();

        info!(
            "Starting health check tasks for {} endpoints with canary_wait_time: {:?}",
            targets.len(),
            self.config.canary_wait_time
        );

        for (endpoint_subject, _target) in targets {
            self.spawn_endpoint_health_check_task(endpoint_subject);
        }

        // Fails if the new-endpoint channel receiver was already taken, i.e.
        // if start() is called twice.
        self.spawn_new_endpoint_monitor().await?;

        info!("HealthCheckManager started successfully with channel-based endpoint discovery");
        Ok(())
    }

    /// Spawn the detached per-endpoint task that sends a health check whenever
    /// the endpoint has been idle for `canary_wait_time`.
    ///
    /// NOTE(review): the `.expect` below panics if the notifier is missing;
    /// callers must only pass subjects that were registered with
    /// `SystemHealth` — confirm that invariant holds for both call sites
    /// (startup snapshot and the discovery monitor).
    fn spawn_endpoint_health_check_task(self: &Arc<Self>, endpoint_subject: String) {
        let manager = self.clone();
        let canary_wait = self.config.canary_wait_time;
        let endpoint_subject_clone = endpoint_subject.clone();

        // Notifier signalled on endpoint activity; used to reset the idle timer.
        let notifier = self
            .drt
            .system_health()
            .lock()
            .get_endpoint_health_check_notifier(&endpoint_subject)
            .expect("Notifier should exist for registered endpoint");

        let task = tokio::spawn(async move {
            let endpoint_subject = endpoint_subject_clone;
            info!("Health check task started for: {}", endpoint_subject);

            loop {
                // Each loop iteration re-creates the sleep future, so winning
                // the `notified()` branch effectively restarts the idle timer.
                tokio::select! {
                    _ = tokio::time::sleep(canary_wait) => {
                        info!("Canary timer expired for {}, sending health check", endpoint_subject);

                        // Re-fetch the target each time in case it changed.
                        let target = manager.drt.system_health().lock().get_health_check_target(&endpoint_subject);

                        if let Some(target) = target {
                            // Send failures are logged but the task keeps
                            // running; the next canary interval retries.
                            if let Err(e) = manager.send_health_check_request(&endpoint_subject, &target.payload).await {
                                error!("Failed to send health check for {}: {}", endpoint_subject, e);
                            }
                        } else {
                            // Target vanished after registration: treated as a
                            // bug; this endpoint stops being health-checked.
                            error!(
                                "CRITICAL: Health check target for {} disappeared unexpectedly! This indicates a bug. Stopping health check task.",
                                endpoint_subject
                            );
                            break;
                        }
                    }

                    _ = notifier.notified() => {
                        debug!("Activity detected for {}, resetting health check timer", endpoint_subject);
                    }
                }
            }

            info!("Health check task for {} exiting", endpoint_subject);
        });

        // Record the handle so duplicate registrations can be detected.
        // The task itself runs detached.
        self.endpoint_tasks
            .lock()
            .insert(endpoint_subject.clone(), task);

        info!(
            "Spawned health check task for endpoint: {}",
            endpoint_subject
        );
    }

    /// Spawn the detached monitor task that listens on the new-endpoint
    /// channel and starts a health check task for each registration.
    ///
    /// Errors if the channel receiver was already taken (i.e. this was
    /// called more than once).
    async fn spawn_new_endpoint_monitor(self: &Arc<Self>) -> anyhow::Result<()> {
        let manager = self.clone();

        let mut rx = manager
            .drt
            .system_health()
            .lock()
            .take_new_endpoint_receiver()
            .ok_or_else(|| {
                anyhow::anyhow!("Endpoint receiver already taken - this should only be called once")
            })?;

        tokio::spawn(async move {
            info!("Starting dynamic endpoint discovery monitor with channel-based notifications");

            while let Some(endpoint_subject) = rx.recv().await {
                debug!(
                    "Received endpoint registration via channel: {}",
                    endpoint_subject
                );

                // Lock scope kept minimal; guard dropped before any await.
                let already_exists = {
                    let tasks = manager.endpoint_tasks.lock();
                    tasks.contains_key(&endpoint_subject)
                };

                if already_exists {
                    // NOTE(review): `break` here stops the monitor for ALL
                    // future endpoints on a single duplicate registration —
                    // `continue` would skip just the duplicate. Confirm the
                    // fail-stop behavior is intentional.
                    error!(
                        "CRITICAL: Received registration for endpoint '{}' that already has a health check task!",
                        endpoint_subject
                    );
                    break;
                }

                info!(
                    "Spawning health check task for new endpoint: {}",
                    endpoint_subject
                );
                manager.spawn_endpoint_health_check_task(endpoint_subject);
            }

            info!("Endpoint discovery monitor exiting - no new endpoints will be monitored!");
        });

        info!("Dynamic endpoint discovery monitor started");
        Ok(())
    }

    /// Fire one health check at `endpoint_subject` and update its status in
    /// `SystemHealth` based on the outcome.
    ///
    /// The request itself runs in a detached task, so this returns as soon as
    /// the request has been dispatched; errors returned here cover only setup
    /// (missing target, namespace/component/endpoint resolution, router
    /// creation).
    ///
    /// NOTE(review): the target is re-fetched here even though the caller
    /// already passes `target.payload` — the `payload` parameter and the
    /// fetched `target.payload` come from the same map entry; confirm whether
    /// the parameter could be dropped.
    async fn send_health_check_request(
        &self,
        endpoint_subject: &str,
        payload: &serde_json::Value,
    ) -> anyhow::Result<()> {
        let target = self
            .drt
            .system_health()
            .lock()
            .get_health_check_target(endpoint_subject)
            .ok_or_else(|| {
                anyhow::anyhow!("No health check target found for {}", endpoint_subject)
            })?;

        debug!(
            "Sending health check to {} (instance_id: {})",
            endpoint_subject, target.instance.instance_id
        );

        // Resolve the concrete endpoint handle for the target instance.
        let namespace = self.drt.namespace(&target.instance.namespace)?;
        let component = namespace.component(&target.instance.component)?;
        let endpoint = component.endpoint(&target.instance.endpoint);

        let router = self
            .get_or_create_router(endpoint_subject, endpoint)
            .await?;

        let request: SingleIn<serde_json::Value> = Context::new(payload.clone());

        // Clone everything the detached task needs; it must not borrow self.
        let system_health = self.drt.system_health().clone();
        let endpoint_subject_owned = endpoint_subject.to_string();
        let instance_id = target.instance.instance_id;
        let timeout = self.config.request_timeout;

        tokio::spawn(async move {
            // The whole send-and-read-first-response sequence is bounded by
            // the configured request timeout.
            let result = tokio::time::timeout(timeout, async {
                // `direct` pins the request to this specific instance rather
                // than letting the router pick one.
                match router.direct(request, instance_id).await {
                    Ok(mut response_stream) => {
                        // Only the first stream item is inspected: an error
                        // item or an empty stream means unhealthy.
                        let is_healthy = if let Some(response) = response_stream.next().await {
                            if let Some(error) = response.err() {
                                warn!(
                                    "Health check error response from {}: {:?}",
                                    endpoint_subject_owned, error
                                );
                                false
                            } else {
                                info!("Health check successful for {}", endpoint_subject_owned);
                                true
                            }
                        } else {
                            warn!(
                                "Health check got no response from {}",
                                endpoint_subject_owned
                            );
                            false
                        };

                        system_health.lock().set_endpoint_health_status(
                            &endpoint_subject_owned,
                            if is_healthy {
                                HealthStatus::Ready
                            } else {
                                HealthStatus::NotReady
                            },
                        );
                    }
                    Err(e) => {
                        // Request could not be delivered at all.
                        error!(
                            "Health check request failed for {}: {}",
                            endpoint_subject_owned, e
                        );
                        system_health.lock().set_endpoint_health_status(
                            &endpoint_subject_owned,
                            HealthStatus::NotReady,
                        );
                    }
                }
            })
            .await;

            // Timeout elapsed before the inner future finished.
            if result.is_err() {
                warn!("Health check timeout for {}", endpoint_subject_owned);
                system_health
                    .lock()
                    .set_endpoint_health_status(&endpoint_subject_owned, HealthStatus::NotReady);
            }

            debug!("Health check completed for {}", endpoint_subject_owned);
        });

        Ok(())
    }
}
344
345pub async fn start_health_check_manager(
347 drt: DistributedRuntime,
348 config: Option<HealthCheckConfig>,
349) -> anyhow::Result<()> {
350 let config = config.unwrap_or_default();
351 let manager = Arc::new(HealthCheckManager::new(drt, config));
352
353 manager.start().await?;
355
356 Ok(())
357}
358
359pub async fn get_health_check_status(
361 drt: &DistributedRuntime,
362) -> anyhow::Result<serde_json::Value> {
363 let endpoint_subjects: Vec<String> = drt.system_health().lock().get_health_check_endpoints();
365
366 let mut endpoint_statuses = HashMap::new();
367
368 {
370 let system_health = drt.system_health();
371 let system_health_lock = system_health.lock();
372 for endpoint_subject in &endpoint_subjects {
373 let health_status = system_health_lock
374 .get_endpoint_health_status(endpoint_subject)
375 .unwrap_or(HealthStatus::NotReady);
376
377 let is_healthy = matches!(health_status, HealthStatus::Ready);
378
379 endpoint_statuses.insert(
380 endpoint_subject.clone(),
381 serde_json::json!({
382 "healthy": is_healthy,
383 "status": format!("{:?}", health_status),
384 }),
385 );
386 }
387 }
388
389 let overall_healthy = endpoint_statuses
390 .values()
391 .all(|v| v["healthy"].as_bool().unwrap_or(false));
392
393 Ok(serde_json::json!({
394 "status": if overall_healthy { "ready" } else { "notready" },
395 "endpoints_checked": endpoint_subjects.len(),
396 "endpoint_statuses": endpoint_statuses,
397 }))
398}
399
// Integration tests; compiled only with the "integration" feature since they
// require a live test DistributedRuntime.
#[cfg(all(test, feature = "integration"))]
mod integration_tests {
    use super::*;
    use crate::distributed::distributed_test_utils::create_test_drt_async;
    use std::sync::Arc;
    use std::time::Duration;

    /// The manager stores the config it was constructed with, unchanged.
    #[tokio::test]
    async fn test_initialization() {
        let drt = create_test_drt_async().await;

        let canary_wait_time = Duration::from_secs(5);
        let request_timeout = Duration::from_secs(3);

        let config = HealthCheckConfig {
            canary_wait_time,
            request_timeout,
        };

        let manager = HealthCheckManager::new(drt.clone(), config);

        assert_eq!(manager.config.canary_wait_time, canary_wait_time);
        assert_eq!(manager.config.request_timeout, request_timeout);
    }

    /// A registered target's payload round-trips through SystemHealth and the
    /// endpoint appears in the health check endpoint list.
    #[tokio::test]
    async fn test_payload_registration() {
        let drt = create_test_drt_async().await;

        let endpoint = "test.endpoint";
        let payload = serde_json::json!({
            "prompt": "test",
            "_health_check": true
        });

        drt.system_health().lock().register_health_check_target(
            endpoint,
            crate::component::Instance {
                component: "test_component".to_string(),
                endpoint: "test_endpoint".to_string(),
                namespace: "test_namespace".to_string(),
                instance_id: 12345,
                transport: crate::component::TransportType::Nats(endpoint.to_string()),
            },
            payload.clone(),
        );

        let retrieved = drt
            .system_health()
            .lock()
            .get_health_check_target(endpoint)
            .map(|t| t.payload);
        assert!(retrieved.is_some());
        assert_eq!(retrieved.unwrap(), payload);

        let endpoints = drt.system_health().lock().get_health_check_endpoints();
        assert!(endpoints.contains(&endpoint.to_string()));
    }

    /// `start()` spawns exactly one task per endpoint registered beforehand,
    /// keyed by endpoint subject.
    #[tokio::test]
    async fn test_spawn_per_endpoint_tasks() {
        let drt = create_test_drt_async().await;

        // Register three distinct targets before starting the manager.
        for i in 0..3 {
            let endpoint = format!("test.endpoint.{}", i);
            let payload = serde_json::json!({
                "prompt": format!("test{}", i),
                "_health_check": true
            });
            drt.system_health().lock().register_health_check_target(
                &endpoint,
                crate::component::Instance {
                    component: "test_component".to_string(),
                    endpoint: format!("test_endpoint_{}", i),
                    namespace: "test_namespace".to_string(),
                    instance_id: i,
                    transport: crate::component::TransportType::Nats(endpoint.clone()),
                },
                payload,
            );
        }

        // Long canary wait so no actual health check fires during the test.
        let config = HealthCheckConfig {
            canary_wait_time: Duration::from_secs(5),
            request_timeout: Duration::from_secs(1),
        };

        let manager = Arc::new(HealthCheckManager::new(drt.clone(), config));
        manager.clone().start().await.unwrap();

        let tasks = manager.endpoint_tasks.lock();
        assert_eq!(tasks.len(), 3);
        let endpoints: Vec<String> = tasks.keys().cloned().collect();
        assert!(endpoints.contains(&"test.endpoint.0".to_string()));
        assert!(endpoints.contains(&"test.endpoint.1".to_string()));
        assert!(endpoints.contains(&"test.endpoint.2".to_string()));
    }

    /// Registering a target also creates its activity notifier, and the
    /// initial recorded status is NotReady.
    #[tokio::test]
    async fn test_endpoint_health_check_notifier_created() {
        let drt = create_test_drt_async().await;

        let endpoint = "test.endpoint.notifier";
        let payload = serde_json::json!({
            "prompt": "test",
            "_health_check": true
        });

        drt.system_health().lock().register_health_check_target(
            endpoint,
            crate::component::Instance {
                component: "test_component".to_string(),
                endpoint: "test_endpoint_notifier".to_string(),
                namespace: "test_namespace".to_string(),
                instance_id: 999,
                transport: crate::component::TransportType::Nats(endpoint.to_string()),
            },
            payload.clone(),
        );

        let notifier = drt
            .system_health()
            .lock()
            .get_endpoint_health_check_notifier(endpoint);

        assert!(
            notifier.is_some(),
            "Endpoint should have a notifier created"
        );

        // Notifying with no task listening is a no-op beyond storing a permit.
        if let Some(notifier) = notifier {
            notifier.notify_one();
        }

        // No manager was started, so the status is still the registration
        // default (NotReady).
        let status = drt
            .system_health()
            .lock()
            .get_endpoint_health_status(endpoint);
        assert_eq!(status, Some(HealthStatus::NotReady));
    }
}