1use crate::DistributedRuntime;
5use crate::config::HealthStatus;
6use crate::engine::AsyncEngine;
7use crate::pipeline::SingleIn;
8use crate::protocols::maybe_error::MaybeError;
9use futures::StreamExt;
10use parking_lot::Mutex;
11use std::collections::HashMap;
12use std::sync::Arc;
13use std::time::Duration;
14use tokio::task::JoinHandle;
15use tracing::{debug, error, info, warn};
16
/// Timing parameters that control endpoint health checking.
///
/// `Debug`, `Clone`, `PartialEq` and `Eq` are derived so the configuration can be logged,
/// duplicated and compared in tests without hand-written impls.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct HealthCheckConfig {
    /// How long a health check task waits without being notified of endpoint activity before it
    /// sends a canary health check request.
    pub canary_wait_time: Duration,
    /// Maximum time a single health check request may take to produce its first response before
    /// the endpoint is marked not ready.
    pub request_timeout: Duration,
}
24
25impl Default for HealthCheckConfig {
26 fn default() -> Self {
27 Self {
28 canary_wait_time: Duration::from_secs(crate::config::DEFAULT_CANARY_WAIT_TIME_SECS),
29 request_timeout: Duration::from_secs(
30 crate::config::DEFAULT_HEALTH_CHECK_REQUEST_TIMEOUT_SECS,
31 ),
32 }
33 }
34}
35
/// Owns the per-endpoint health check tasks spawned for a [`DistributedRuntime`].
pub struct HealthCheckManager {
    /// Runtime handle; used to reach the shared system health state and the local endpoint
    /// registry.
    drt: DistributedRuntime,
    /// Timing parameters applied to every health check task and request.
    config: HealthCheckConfig,
    /// Join handles of the spawned health check tasks, keyed by endpoint subject.
    endpoint_tasks: Arc<Mutex<HashMap<String, JoinHandle<()>>>>,
}
44
45impl HealthCheckManager {
46 pub fn new(drt: DistributedRuntime, config: HealthCheckConfig) -> Self {
47 Self {
48 drt,
49 config,
50 endpoint_tasks: Arc::new(Mutex::new(HashMap::new())),
51 }
52 }
53
54 pub async fn start(self: Arc<Self>) -> anyhow::Result<()> {
56 let targets = self.drt.system_health().lock().get_health_check_targets();
58
59 info!(
60 "Starting health check tasks for {} endpoints with canary_wait_time: {:?}",
61 targets.len(),
62 self.config.canary_wait_time
63 );
64
65 for (endpoint_subject, _target) in targets {
67 self.spawn_endpoint_health_check_task(endpoint_subject);
68 }
69
70 self.spawn_new_endpoint_monitor().await?;
74
75 info!("HealthCheckManager started successfully with channel-based endpoint discovery");
76 Ok(())
77 }
78
79 fn spawn_endpoint_health_check_task(self: &Arc<Self>, endpoint_subject: String) {
81 let manager = self.clone();
82 let canary_wait = self.config.canary_wait_time;
83 let endpoint_subject_clone = endpoint_subject.clone();
84
85 let notifier = self
87 .drt
88 .system_health()
89 .lock()
90 .get_endpoint_health_check_notifier(&endpoint_subject)
91 .expect("Notifier should exist for registered endpoint");
92
93 let task = tokio::spawn(async move {
94 let endpoint_subject = endpoint_subject_clone;
95 info!("Health check task started for: {}", endpoint_subject);
96
97 loop {
98 tokio::select! {
100 _ = tokio::time::sleep(canary_wait) => {
101 debug!("Canary timer expired for {}, sending health check", endpoint_subject);
103
104 let target = manager.drt.system_health().lock().get_health_check_target(&endpoint_subject);
106
107 if let Some(target) = target {
108 if let Err(e) = manager.send_health_check_request(&endpoint_subject, &target.payload).await {
109 error!("Failed to send health check for {}: {}", endpoint_subject, e);
110 }
111 } else {
112 error!(
114 "CRITICAL: Health check target for {} disappeared unexpectedly! This indicates a bug. Stopping health check task.",
115 endpoint_subject
116 );
117 break;
118 }
119 }
120
121 _ = notifier.notified() => {
122 debug!("Activity detected for {}, resetting health check timer", endpoint_subject);
124 }
126 }
127 }
128
129 info!("Health check task for {} exiting", endpoint_subject);
130 });
131
132 self.endpoint_tasks
134 .lock()
135 .insert(endpoint_subject.clone(), task);
136
137 info!(
138 "Spawned health check task for endpoint: {}",
139 endpoint_subject
140 );
141 }
142
143 async fn spawn_new_endpoint_monitor(self: &Arc<Self>) -> anyhow::Result<()> {
146 let manager = self.clone();
147
148 let mut rx = manager
150 .drt
151 .system_health()
152 .lock()
153 .take_new_endpoint_receiver()
154 .ok_or_else(|| {
155 anyhow::anyhow!("Endpoint receiver already taken - this should only be called once")
156 })?;
157
158 tokio::spawn(async move {
159 info!("Starting dynamic endpoint discovery monitor with channel-based notifications");
160
161 while let Some(endpoint_subject) = rx.recv().await {
162 debug!(
163 "Received endpoint registration via channel: {}",
164 endpoint_subject
165 );
166
167 let already_exists = {
168 let tasks = manager.endpoint_tasks.lock();
169 tasks.contains_key(&endpoint_subject)
170 };
171
172 if already_exists {
173 error!(
174 "CRITICAL: Received registration for endpoint '{}' that already has a health check task!",
175 endpoint_subject
176 );
177 break;
178 }
179
180 info!(
181 "Spawning health check task for new endpoint: {}",
182 endpoint_subject
183 );
184 manager.spawn_endpoint_health_check_task(endpoint_subject);
185 }
186
187 info!("Endpoint discovery monitor exiting - no new endpoints will be monitored!");
188 });
189
190 info!("Dynamic endpoint discovery monitor started");
191 Ok(())
192 }
193
194 async fn send_health_check_request(
196 &self,
197 endpoint_subject: &str,
198 payload: &serde_json::Value,
199 ) -> anyhow::Result<()> {
200 debug!(
201 "Sending health check to {} via local registry",
202 endpoint_subject
203 );
204
205 let engine = self
206 .drt
207 .local_endpoint_registry()
208 .get(endpoint_subject)
209 .ok_or_else(|| {
210 anyhow::anyhow!(
211 "Endpoint '{}' not found in local registry, engine may still be initializing",
212 endpoint_subject
213 )
214 })?;
215
216 let system_health = self.drt.system_health().clone();
218 let endpoint_subject_owned = endpoint_subject.to_string();
219 let payload = payload.clone();
220 let timeout = self.config.request_timeout;
221
222 tokio::spawn(async move {
224 let result = tokio::time::timeout(timeout, async {
225 let request = SingleIn::new(payload);
226 match engine.generate(request).await {
227 Ok(mut response_stream) => {
228 let is_healthy = if let Some(response) = response_stream.next().await {
231 if let Some(error) = response.err() {
232 warn!(
233 "Health check error response from {}: {:?}",
234 endpoint_subject_owned, error
235 );
236 false
237 } else {
238 debug!("Health check successful for {}", endpoint_subject_owned);
239 true
240 }
241 } else {
242 warn!(
243 "Health check got no response from {}",
244 endpoint_subject_owned
245 );
246 false
247 };
248
249 tokio::spawn(async move {
250 response_stream.for_each(|_| async {}).await;
252 });
253
254 system_health.lock().set_endpoint_health_status(
256 &endpoint_subject_owned,
257 if is_healthy {
258 HealthStatus::Ready
259 } else {
260 HealthStatus::NotReady
261 },
262 );
263 }
264 Err(e) => {
265 error!(
266 "Health check request failed for {}: {}",
267 endpoint_subject_owned, e
268 );
269 system_health.lock().set_endpoint_health_status(
270 &endpoint_subject_owned,
271 HealthStatus::NotReady,
272 );
273 }
274 }
275 })
276 .await;
277
278 if result.is_err() {
280 warn!("Health check timeout for {}", endpoint_subject_owned);
281 system_health
282 .lock()
283 .set_endpoint_health_status(&endpoint_subject_owned, HealthStatus::NotReady);
284 }
285
286 debug!("Health check completed for {}", endpoint_subject_owned);
287 });
288
289 Ok(())
290 }
291}
292
293pub async fn start_health_check_manager(
295 drt: DistributedRuntime,
296 config: Option<HealthCheckConfig>,
297) -> anyhow::Result<()> {
298 let config = config.unwrap_or_default();
299 let manager = Arc::new(HealthCheckManager::new(drt, config));
300
301 manager.start().await?;
303
304 Ok(())
305}
306
307pub async fn get_health_check_status(
309 drt: &DistributedRuntime,
310) -> anyhow::Result<serde_json::Value> {
311 let endpoint_subjects: Vec<String> = drt.system_health().lock().get_health_check_endpoints();
313
314 let mut endpoint_statuses = HashMap::new();
315
316 {
318 let system_health = drt.system_health();
319 let system_health_lock = system_health.lock();
320 for endpoint_subject in &endpoint_subjects {
321 let health_status = system_health_lock
322 .get_endpoint_health_status(endpoint_subject)
323 .unwrap_or(HealthStatus::NotReady);
324
325 let is_healthy = matches!(health_status, HealthStatus::Ready);
326
327 endpoint_statuses.insert(
328 endpoint_subject.clone(),
329 serde_json::json!({
330 "healthy": is_healthy,
331 "status": format!("{:?}", health_status),
332 }),
333 );
334 }
335 }
336
337 let overall_healthy = endpoint_statuses
338 .values()
339 .all(|v| v["healthy"].as_bool().unwrap_or(false));
340
341 Ok(serde_json::json!({
342 "status": if overall_healthy { "ready" } else { "notready" },
343 "endpoints_checked": endpoint_subjects.len(),
344 "endpoint_statuses": endpoint_statuses,
345 }))
346}
347
#[cfg(all(test, feature = "integration"))]
mod integration_tests {
    use super::*;
    use crate::distributed::distributed_test_utils::create_test_drt_async;
    use std::sync::Arc;
    use std::time::Duration;

    /// The manager keeps the configuration it was constructed with.
    #[tokio::test]
    async fn test_initialization() {
        let drt = create_test_drt_async().await;

        let config = HealthCheckConfig {
            canary_wait_time: Duration::from_secs(5),
            request_timeout: Duration::from_secs(3),
        };
        let manager = HealthCheckManager::new(drt.clone(), config);

        assert_eq!(manager.config.canary_wait_time, Duration::from_secs(5));
        assert_eq!(manager.config.request_timeout, Duration::from_secs(3));
    }

    /// A registered target can be read back with its payload and appears in the endpoint list.
    #[tokio::test]
    async fn test_payload_registration() {
        let drt = create_test_drt_async().await;

        let subject = "test.endpoint";
        let payload = serde_json::json!({
            "prompt": "test",
            "_health_check": true
        });
        let instance = crate::component::Instance {
            component: "test_component".to_string(),
            endpoint: "test_endpoint".to_string(),
            namespace: "test_namespace".to_string(),
            instance_id: 12345,
            transport: crate::component::TransportType::Nats(subject.to_string()),
        };

        drt.system_health()
            .lock()
            .register_health_check_target(subject, instance, payload.clone());

        let stored_payload = drt
            .system_health()
            .lock()
            .get_health_check_target(subject)
            .map(|target| target.payload);
        assert_eq!(stored_payload, Some(payload));

        let known_endpoints = drt.system_health().lock().get_health_check_endpoints();
        assert!(known_endpoints.contains(&subject.to_string()));
    }

    /// Starting the manager creates exactly one task per pre-registered endpoint.
    #[tokio::test]
    async fn test_spawn_per_endpoint_tasks() {
        let drt = create_test_drt_async().await;

        for i in 0..3 {
            let subject = format!("test.endpoint.{}", i);
            let payload = serde_json::json!({
                "prompt": format!("test{}", i),
                "_health_check": true
            });
            let instance = crate::component::Instance {
                component: "test_component".to_string(),
                endpoint: format!("test_endpoint_{}", i),
                namespace: "test_namespace".to_string(),
                instance_id: i,
                transport: crate::component::TransportType::Nats(subject.clone()),
            };

            drt.system_health()
                .lock()
                .register_health_check_target(&subject, instance, payload);
        }

        let manager = Arc::new(HealthCheckManager::new(
            drt.clone(),
            HealthCheckConfig {
                canary_wait_time: Duration::from_secs(5),
                request_timeout: Duration::from_secs(1),
            },
        ));
        manager.clone().start().await.unwrap();

        let tasks = manager.endpoint_tasks.lock();
        assert_eq!(tasks.len(), 3);
        for expected in ["test.endpoint.0", "test.endpoint.1", "test.endpoint.2"].iter() {
            assert!(tasks.contains_key(*expected));
        }
    }

    /// Registering a target creates its notifier and leaves the endpoint `NotReady`.
    #[tokio::test]
    async fn test_endpoint_health_check_notifier_created() {
        let drt = create_test_drt_async().await;

        let subject = "test.endpoint.notifier";
        let payload = serde_json::json!({
            "prompt": "test",
            "_health_check": true
        });
        let instance = crate::component::Instance {
            component: "test_component".to_string(),
            endpoint: "test_endpoint_notifier".to_string(),
            namespace: "test_namespace".to_string(),
            instance_id: 999,
            transport: crate::component::TransportType::Nats(subject.to_string()),
        };

        drt.system_health()
            .lock()
            .register_health_check_target(subject, instance, payload.clone());

        let notifier = drt
            .system_health()
            .lock()
            .get_endpoint_health_check_notifier(subject)
            .expect("Endpoint should have a notifier created");

        // Signalling the notifier must not affect the recorded health status.
        notifier.notify_one();

        let status = drt
            .system_health()
            .lock()
            .get_endpoint_health_status(subject);
        assert_eq!(status, Some(HealthStatus::NotReady));
    }
}