1use async_trait::async_trait;
7use std::collections::HashMap;
8use std::net::SocketAddr;
9use std::sync::Arc;
10use std::time::{Duration, Instant};
11use tokio::io::{AsyncReadExt, AsyncWriteExt};
12use tokio::net::TcpStream;
13use tokio::sync::RwLock;
14use tokio::time;
15use tracing::{debug, info, trace, warn};
16
17use grapsus_common::{errors::GrapsusResult, types::HealthCheckType};
18use grapsus_config::{HealthCheck as HealthCheckConfig, UpstreamTarget};
19
/// Actively probes upstream targets on a fixed interval and maintains
/// per-target health state consumed by load-balancing decisions.
pub struct ActiveHealthChecker {
    // Interval, timeout, and healthy/unhealthy thresholds for all checks.
    config: HealthCheckConfig,
    // Protocol-specific probe (HTTP/TCP/gRPC/inference), shared by all check tasks.
    checker: Arc<dyn HealthCheckImpl>,
    // Health state keyed by target address; written by spawned check tasks.
    health_status: Arc<RwLock<HashMap<String, TargetHealthInfo>>>,
    // Join handles for the per-target background check tasks (drained on stop).
    check_handles: Arc<RwLock<Vec<tokio::task::JoinHandle<()>>>>,
    // Broadcast sender used to signal every check task to shut down.
    shutdown_tx: Arc<tokio::sync::broadcast::Sender<()>>,
}
36
/// Snapshot of one target's health, updated by its check task and by
/// passive (request-outcome) reporting.
#[derive(Debug, Clone)]
pub struct TargetHealthInfo {
    // Current verdict; flips only after threshold consecutive results.
    pub healthy: bool,
    // Successes since the last failure (reset to 0 on failure).
    pub consecutive_successes: u32,
    // Failures since the last success (reset to 0 on success).
    pub consecutive_failures: u32,
    // When the most recent check (of either outcome) completed.
    pub last_check: Instant,
    // When the most recent successful check completed, if any.
    pub last_success: Option<Instant>,
    // Most recent failure reason, cleared on success.
    pub last_error: Option<String>,
    // Lifetime counters for health-score computation.
    pub total_checks: u64,
    pub total_successes: u64,
    // Running mean response time of successful checks, in milliseconds.
    pub avg_response_time: f64,
}
59
/// A single health-probe strategy. Implementations are shared across
/// spawned tasks, hence the `Send + Sync` bound.
#[async_trait]
trait HealthCheckImpl: Send + Sync {
    /// Probes `target` (a parseable socket address). Returns the observed
    /// response time on success, or a human-readable error message.
    async fn check(&self, target: &str) -> Result<Duration, String>;

    /// Short label used in log fields (e.g. "HTTP", "TCP").
    fn check_type(&self) -> &str;
}
69
/// HTTP/1.1 GET probe: healthy when the status line matches `expected_status`.
struct HttpHealthCheck {
    // Request path, e.g. "/healthz".
    path: String,
    // Exact status code required for the check to pass.
    expected_status: u16,
    // Optional Host header override; falls back to the target address.
    host: Option<String>,
    // Applies to the TCP connect only (not the read/write phases).
    timeout: Duration,
}
77
/// TCP probe: healthy when a connection can be established within `timeout`.
struct TcpHealthCheck {
    timeout: Duration,
}
82
/// gRPC probe. Currently a TCP-connect fallback (see the impl); `service`
/// is retained for logging and for a future real grpc.health.v1 probe.
struct GrpcHealthCheck {
    service: String,
    timeout: Duration,
}
99
/// Inference-server probe: GETs a models endpoint and optionally verifies
/// that each expected model name appears in the response body.
struct InferenceHealthCheck {
    // Models listing endpoint, e.g. "/v1/models".
    endpoint: String,
    // Model names that must appear in the body; empty means status-only check.
    expected_models: Vec<String>,
    timeout: Duration,
}
115
/// Readiness probe that issues a tiny real inference request and checks
/// the response (and optionally its latency).
struct InferenceProbeCheck {
    config: grapsus_common::InferenceProbeConfig,
    // Per-probe timeout taken from the probe config, not the base check.
    timeout: Duration,
}
123
/// Readiness probe that queries a per-model status endpoint and compares
/// a JSON status field against an expected value.
struct ModelStatusCheck {
    config: grapsus_common::ModelStatusConfig,
    timeout: Duration,
}
131
/// Readiness probe that reads a queue-depth metric (from a response header
/// or JSON body field) and fails when it crosses the unhealthy threshold.
struct QueueDepthCheck {
    config: grapsus_common::QueueDepthConfig,
    // Fallback endpoint when the queue-depth config has no endpoint of its own.
    models_endpoint: String,
    timeout: Duration,
}
140
/// Base inference check plus any configured readiness sub-checks; all
/// present checks must pass for the target to be considered healthy.
struct CompositeInferenceHealthCheck {
    base_check: InferenceHealthCheck,
    inference_probe: Option<InferenceProbeCheck>,
    model_status: Option<ModelStatusCheck>,
    queue_depth: Option<QueueDepthCheck>,
}
151
impl ActiveHealthChecker {
    /// Builds a checker for `config`, selecting the probe implementation
    /// from `config.check_type`. No background tasks run until `start`.
    pub fn new(config: HealthCheckConfig) -> Self {
        debug!(
            check_type = ?config.check_type,
            interval_secs = config.interval_secs,
            timeout_secs = config.timeout_secs,
            healthy_threshold = config.healthy_threshold,
            unhealthy_threshold = config.unhealthy_threshold,
            "Creating active health checker"
        );

        let checker: Arc<dyn HealthCheckImpl> = match &config.check_type {
            HealthCheckType::Http {
                path,
                expected_status,
                host,
            } => {
                trace!(
                    path = %path,
                    expected_status = expected_status,
                    host = host.as_deref().unwrap_or("(default)"),
                    "Configuring HTTP health check"
                );
                Arc::new(HttpHealthCheck {
                    path: path.clone(),
                    expected_status: *expected_status,
                    host: host.clone(),
                    timeout: Duration::from_secs(config.timeout_secs),
                })
            }
            HealthCheckType::Tcp => {
                trace!("Configuring TCP health check");
                Arc::new(TcpHealthCheck {
                    timeout: Duration::from_secs(config.timeout_secs),
                })
            }
            HealthCheckType::Grpc { service } => {
                trace!(
                    service = %service,
                    "Configuring gRPC health check"
                );
                Arc::new(GrpcHealthCheck {
                    service: service.clone(),
                    timeout: Duration::from_secs(config.timeout_secs),
                })
            }
            HealthCheckType::Inference {
                endpoint,
                expected_models,
                readiness,
            } => {
                trace!(
                    endpoint = %endpoint,
                    expected_models = ?expected_models,
                    has_readiness = readiness.is_some(),
                    "Configuring inference health check"
                );

                let base_timeout = Duration::from_secs(config.timeout_secs);
                let base_check = InferenceHealthCheck {
                    endpoint: endpoint.clone(),
                    expected_models: expected_models.clone(),
                    timeout: base_timeout,
                };

                // With a readiness config, wrap the base check in a composite
                // that also runs whichever sub-checks are configured. Each
                // sub-check carries its own timeout from its own config.
                if let Some(ref readiness_config) = readiness {
                    let inference_probe =
                        readiness_config
                            .inference_probe
                            .as_ref()
                            .map(|cfg| InferenceProbeCheck {
                                config: cfg.clone(),
                                timeout: Duration::from_secs(cfg.timeout_secs),
                            });

                    let model_status =
                        readiness_config
                            .model_status
                            .as_ref()
                            .map(|cfg| ModelStatusCheck {
                                config: cfg.clone(),
                                timeout: Duration::from_secs(cfg.timeout_secs),
                            });

                    let queue_depth =
                        readiness_config
                            .queue_depth
                            .as_ref()
                            .map(|cfg| QueueDepthCheck {
                                config: cfg.clone(),
                                models_endpoint: endpoint.clone(),
                                timeout: Duration::from_secs(cfg.timeout_secs),
                            });

                    Arc::new(CompositeInferenceHealthCheck {
                        base_check,
                        inference_probe,
                        model_status,
                        queue_depth,
                    })
                } else {
                    Arc::new(base_check)
                }
            }
        };

        // Capacity 1 is enough: shutdown is a one-shot signal and receivers
        // only need to observe that it happened.
        let (shutdown_tx, _) = tokio::sync::broadcast::channel(1);

        Self {
            config,
            checker,
            health_status: Arc::new(RwLock::new(HashMap::new())),
            check_handles: Arc::new(RwLock::new(Vec::new())),
            shutdown_tx: Arc::new(shutdown_tx),
        }
    }

    /// Registers `targets` and spawns one background check task per target.
    /// Each target starts with default (healthy) state; see
    /// `TargetHealthInfo::new`.
    pub async fn start(&self, targets: &[UpstreamTarget]) -> GrapsusResult<()> {
        info!(
            target_count = targets.len(),
            interval_secs = self.config.interval_secs,
            check_type = self.checker.check_type(),
            "Starting health checking"
        );

        let mut handles = self.check_handles.write().await;

        for target in targets {
            let address = target.address.clone();

            trace!(
                target = %address,
                "Initializing health status for target"
            );

            self.health_status
                .write()
                .await
                .insert(address.clone(), TargetHealthInfo::new());

            debug!(
                target = %address,
                "Spawning health check task"
            );
            let handle = self.spawn_check_task(address);
            handles.push(handle);
        }

        info!(
            target_count = targets.len(),
            interval_secs = self.config.interval_secs,
            healthy_threshold = self.config.healthy_threshold,
            unhealthy_threshold = self.config.unhealthy_threshold,
            "Health checking started successfully"
        );

        Ok(())
    }

    /// Spawns the periodic check loop for one target. The loop runs until a
    /// shutdown broadcast is received.
    fn spawn_check_task(&self, target: String) -> tokio::task::JoinHandle<()> {
        // Clone everything the task needs; the task must not borrow `self`.
        let interval = Duration::from_secs(self.config.interval_secs);
        let checker = Arc::clone(&self.checker);
        let health_status = Arc::clone(&self.health_status);
        let healthy_threshold = self.config.healthy_threshold;
        let unhealthy_threshold = self.config.unhealthy_threshold;
        let check_type = self.checker.check_type().to_string();
        let mut shutdown_rx = self.shutdown_tx.subscribe();

        tokio::spawn(async move {
            let mut interval_timer = time::interval(interval);
            // Skip missed ticks rather than bursting to catch up if a check
            // runs longer than the interval.
            interval_timer.set_missed_tick_behavior(time::MissedTickBehavior::Skip);

            trace!(
                target = %target,
                check_type = %check_type,
                interval_ms = interval.as_millis(),
                "Health check task started"
            );

            loop {
                tokio::select! {
                    _ = interval_timer.tick() => {
                        trace!(
                            target = %target,
                            check_type = %check_type,
                            "Performing health check"
                        );
                        // Run the probe *before* taking the write lock so the
                        // shared map is never held across network I/O.
                        let start = Instant::now();
                        let result = checker.check(&target).await;
                        let check_duration = start.elapsed();

                        let mut status_map = health_status.write().await;
                        if let Some(status) = status_map.get_mut(&target) {
                            status.last_check = Instant::now();
                            status.total_checks += 1;

                            match result {
                                Ok(response_time) => {
                                    status.consecutive_successes += 1;
                                    status.consecutive_failures = 0;
                                    status.last_success = Some(Instant::now());
                                    status.last_error = None;
                                    status.total_successes += 1;

                                    // Incremental running mean over successful
                                    // checks only (total_successes >= 1 here).
                                    let response_ms = response_time.as_millis() as f64;
                                    status.avg_response_time =
                                        (status.avg_response_time * (status.total_successes - 1) as f64
                                        + response_ms) / status.total_successes as f64;

                                    // Flip to healthy only after enough
                                    // consecutive successes (hysteresis).
                                    if !status.healthy && status.consecutive_successes >= healthy_threshold {
                                        status.healthy = true;
                                        info!(
                                            target = %target,
                                            consecutive_successes = status.consecutive_successes,
                                            avg_response_ms = format!("{:.2}", status.avg_response_time),
                                            total_checks = status.total_checks,
                                            "Target marked as healthy"
                                        );
                                    }

                                    trace!(
                                        target = %target,
                                        response_time_ms = response_ms,
                                        check_duration_ms = check_duration.as_millis(),
                                        consecutive_successes = status.consecutive_successes,
                                        health_score = format!("{:.2}", status.health_score()),
                                        "Health check succeeded"
                                    );
                                }
                                Err(error) => {
                                    status.consecutive_failures += 1;
                                    status.consecutive_successes = 0;
                                    status.last_error = Some(error.clone());

                                    // Flip to unhealthy only after enough
                                    // consecutive failures.
                                    if status.healthy && status.consecutive_failures >= unhealthy_threshold {
                                        status.healthy = false;
                                        warn!(
                                            target = %target,
                                            consecutive_failures = status.consecutive_failures,
                                            error = %error,
                                            total_checks = status.total_checks,
                                            health_score = format!("{:.2}", status.health_score()),
                                            "Target marked as unhealthy"
                                        );
                                    } else {
                                        debug!(
                                            target = %target,
                                            error = %error,
                                            consecutive_failures = status.consecutive_failures,
                                            unhealthy_threshold = unhealthy_threshold,
                                            "Health check failed"
                                        );
                                    }
                                }
                            }
                        }
                    }
                    _ = shutdown_rx.recv() => {
                        info!(target = %target, "Stopping health check task");
                        break;
                    }
                }
            }

            debug!(target = %target, "Health check task stopped");
        })
    }

    /// Signals all check tasks to stop and waits for each to finish.
    pub async fn stop(&self) {
        let task_count = self.check_handles.read().await.len();
        info!(task_count = task_count, "Stopping health checker");

        // Send may fail if no tasks are running (no receivers); that's fine.
        let _ = self.shutdown_tx.send(());

        let mut handles = self.check_handles.write().await;
        for handle in handles.drain(..) {
            let _ = handle.await;
        }

        info!("Health checker stopped successfully");
    }

    /// Returns a clone of the health snapshot for `target`, if tracked.
    pub async fn get_status(&self, target: &str) -> Option<TargetHealthInfo> {
        self.health_status.read().await.get(target).cloned()
    }

    /// Returns a clone of the full target -> health map.
    pub async fn get_all_statuses(&self) -> HashMap<String, TargetHealthInfo> {
        self.health_status.read().await.clone()
    }

    /// True when `target` is tracked and currently healthy; unknown targets
    /// are reported unhealthy.
    pub async fn is_healthy(&self, target: &str) -> bool {
        self.health_status
            .read()
            .await
            .get(target)
            .map(|s| s.healthy)
            .unwrap_or(false)
    }

    /// Addresses of all currently healthy targets.
    pub async fn get_healthy_targets(&self) -> Vec<String> {
        self.health_status
            .read()
            .await
            .iter()
            .filter_map(|(target, status)| {
                if status.healthy {
                    Some(target.clone())
                } else {
                    None
                }
            })
            .collect()
    }

    /// Immediately marks `target` unhealthy (used by passive health checking).
    /// Sets consecutive_failures to the unhealthy threshold so a full run of
    /// consecutive active successes is required to recover.
    pub async fn mark_unhealthy(&self, target: &str, reason: String) {
        if let Some(status) = self.health_status.write().await.get_mut(target) {
            if status.healthy {
                status.healthy = false;
                status.consecutive_failures = self.config.unhealthy_threshold;
                status.consecutive_successes = 0;
                status.last_error = Some(reason.clone());
                warn!(
                    target = %target,
                    reason = %reason,
                    "Target marked unhealthy by passive check"
                );
            }
        }
    }
}
502
503impl Default for TargetHealthInfo {
504 fn default() -> Self {
505 Self::new()
506 }
507}
508
509impl TargetHealthInfo {
510 pub fn new() -> Self {
512 Self {
513 healthy: true,
514 consecutive_successes: 0,
515 consecutive_failures: 0,
516 last_check: Instant::now(),
517 last_success: Some(Instant::now()),
518 last_error: None,
519 total_checks: 0,
520 total_successes: 0,
521 avg_response_time: 0.0,
522 }
523 }
524
525 pub fn health_score(&self) -> f64 {
527 if self.total_checks == 0 {
528 return 1.0;
529 }
530 self.total_successes as f64 / self.total_checks as f64
531 }
532
533 pub fn is_degraded(&self) -> bool {
535 self.healthy && self.consecutive_failures > 0
536 }
537}
538
539#[async_trait]
540impl HealthCheckImpl for HttpHealthCheck {
541 async fn check(&self, target: &str) -> Result<Duration, String> {
542 let start = Instant::now();
543
544 let addr: SocketAddr = target
546 .parse()
547 .map_err(|e| format!("Invalid address: {}", e))?;
548
549 let stream = time::timeout(self.timeout, TcpStream::connect(addr))
551 .await
552 .map_err(|_| format!("Connection timeout after {:?}", self.timeout))?
553 .map_err(|e| format!("Connection failed: {}", e))?;
554
555 let host = self.host.as_deref().unwrap_or(target);
557 let request = format!(
558 "GET {} HTTP/1.1\r\nHost: {}\r\nUser-Agent: Grapsus-HealthCheck/1.0\r\nConnection: close\r\n\r\n",
559 self.path,
560 host
561 );
562
563 let mut stream = stream;
565 stream
566 .write_all(request.as_bytes())
567 .await
568 .map_err(|e| format!("Failed to send request: {}", e))?;
569
570 let mut response = vec![0u8; 1024];
571 let n = stream
572 .read(&mut response)
573 .await
574 .map_err(|e| format!("Failed to read response: {}", e))?;
575
576 if n == 0 {
577 return Err("Empty response".to_string());
578 }
579
580 let response_str = String::from_utf8_lossy(&response[..n]);
582 let status_code = parse_http_status(&response_str)
583 .ok_or_else(|| "Failed to parse HTTP status".to_string())?;
584
585 if status_code == self.expected_status {
586 Ok(start.elapsed())
587 } else {
588 Err(format!(
589 "Unexpected status code: {} (expected {})",
590 status_code, self.expected_status
591 ))
592 }
593 }
594
595 fn check_type(&self) -> &str {
596 "HTTP"
597 }
598}
599
600#[async_trait]
601impl HealthCheckImpl for TcpHealthCheck {
602 async fn check(&self, target: &str) -> Result<Duration, String> {
603 let start = Instant::now();
604
605 let addr: SocketAddr = target
607 .parse()
608 .map_err(|e| format!("Invalid address: {}", e))?;
609
610 time::timeout(self.timeout, TcpStream::connect(addr))
612 .await
613 .map_err(|_| format!("Connection timeout after {:?}", self.timeout))?
614 .map_err(|e| format!("Connection failed: {}", e))?;
615
616 Ok(start.elapsed())
617 }
618
619 fn check_type(&self) -> &str {
620 "TCP"
621 }
622}
623
624#[async_trait]
625impl HealthCheckImpl for GrpcHealthCheck {
626 async fn check(&self, target: &str) -> Result<Duration, String> {
627 let start = Instant::now();
628
629 let addr: SocketAddr = target
635 .parse()
636 .map_err(|e| format!("Invalid address: {}", e))?;
637
638 let stream = time::timeout(self.timeout, TcpStream::connect(addr))
640 .await
641 .map_err(|_| format!("Connection timeout after {:?}", self.timeout))?
642 .map_err(|e| format!("Connection failed: {}", e))?;
643
644 stream
646 .writable()
647 .await
648 .map_err(|e| format!("Connection not writable: {}", e))?;
649
650 debug!(
651 target = %target,
652 service = %self.service,
653 "gRPC health check using TCP fallback (full gRPC protocol requires tonic)"
654 );
655
656 Ok(start.elapsed())
657 }
658
659 fn check_type(&self) -> &str {
660 "gRPC"
661 }
662}
663
#[async_trait]
impl HealthCheckImpl for InferenceHealthCheck {
    /// GETs the configured models endpoint; requires HTTP 200 and, when
    /// `expected_models` is non-empty, that each model name appears
    /// somewhere in the response body.
    async fn check(&self, target: &str) -> Result<Duration, String> {
        let start = Instant::now();

        let addr: SocketAddr = target
            .parse()
            .map_err(|e| format!("Invalid address: {}", e))?;

        // Timeout bounds the connect only, not the request/response phases.
        let stream = time::timeout(self.timeout, TcpStream::connect(addr))
            .await
            .map_err(|_| format!("Connection timeout after {:?}", self.timeout))?
            .map_err(|e| format!("Connection failed: {}", e))?;

        let request = format!(
            "GET {} HTTP/1.1\r\nHost: {}\r\nUser-Agent: Grapsus-HealthCheck/1.0\r\nAccept: application/json\r\nConnection: close\r\n\r\n",
            self.endpoint,
            target
        );

        let mut stream = stream;
        stream
            .write_all(request.as_bytes())
            .await
            .map_err(|e| format!("Failed to send request: {}", e))?;

        // Single 8 KiB read. NOTE(review): a models listing larger than one
        // read could be truncated and fail the substring match — TODO confirm
        // expected response sizes.
        let mut response = vec![0u8; 8192];
        let n = stream
            .read(&mut response)
            .await
            .map_err(|e| format!("Failed to read response: {}", e))?;

        if n == 0 {
            return Err("Empty response".to_string());
        }

        let response_str = String::from_utf8_lossy(&response[..n]);
        let status_code = parse_http_status(&response_str)
            .ok_or_else(|| "Failed to parse HTTP status".to_string())?;

        if status_code != 200 {
            return Err(format!(
                "Unexpected status code: {} (expected 200)",
                status_code
            ));
        }

        if !self.expected_models.is_empty() {
            // Body begins after the blank line separating the headers.
            if let Some(body_start) = response_str.find("\r\n\r\n") {
                let body = &response_str[body_start + 4..];

                // Plain substring match rather than JSON parsing: each
                // expected model name just has to occur in the body text.
                for model in &self.expected_models {
                    if !body.contains(model) {
                        return Err(format!("Expected model '{}' not found in response", model));
                    }
                }

                debug!(
                    target = %target,
                    endpoint = %self.endpoint,
                    expected_models = ?self.expected_models,
                    "All expected models found in inference health check"
                );
            } else {
                return Err("Could not find response body".to_string());
            }
        }

        trace!(
            target = %target,
            endpoint = %self.endpoint,
            response_time_ms = start.elapsed().as_millis(),
            "Inference health check passed"
        );

        Ok(start.elapsed())
    }

    fn check_type(&self) -> &str {
        "Inference"
    }
}
755
#[async_trait]
impl HealthCheckImpl for InferenceProbeCheck {
    /// Issues a tiny real completion request and requires HTTP 200, a
    /// `"choices"` field in the body, and (optionally) a latency under the
    /// configured ceiling.
    async fn check(&self, target: &str) -> Result<Duration, String> {
        let start = Instant::now();

        let addr: SocketAddr = target
            .parse()
            .map_err(|e| format!("Invalid address: {}", e))?;

        let stream = time::timeout(self.timeout, TcpStream::connect(addr))
            .await
            .map_err(|_| format!("Connection timeout after {:?}", self.timeout))?
            .map_err(|e| format!("Connection failed: {}", e))?;

        // Hand-built JSON body. NOTE(review): model/prompt are interpolated
        // without JSON escaping — a quote or backslash in config would break
        // the payload; confirm these values are operator-controlled.
        let body = format!(
            r#"{{"model":"{}","prompt":"{}","max_tokens":{}}}"#,
            self.config.model, self.config.prompt, self.config.max_tokens
        );

        let request = format!(
            "POST {} HTTP/1.1\r\nHost: {}\r\nUser-Agent: Grapsus-HealthCheck/1.0\r\nContent-Type: application/json\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}",
            self.config.endpoint,
            target,
            body.len(),
            body
        );

        let mut stream = stream;
        stream
            .write_all(request.as_bytes())
            .await
            .map_err(|e| format!("Failed to send request: {}", e))?;

        // Single 16 KiB read; completion responses are expected to fit.
        let mut response = vec![0u8; 16384];
        let n = stream
            .read(&mut response)
            .await
            .map_err(|e| format!("Failed to read response: {}", e))?;

        if n == 0 {
            return Err("Empty response".to_string());
        }

        // Latency is measured through the full request/response round trip.
        let latency = start.elapsed();

        let response_str = String::from_utf8_lossy(&response[..n]);
        let status_code = parse_http_status(&response_str)
            .ok_or_else(|| "Failed to parse HTTP status".to_string())?;

        if status_code != 200 {
            return Err(format!(
                "Inference probe failed: status {} (expected 200)",
                status_code
            ));
        }

        // Sanity-check the body shape via substring, not full JSON parsing.
        if let Some(body_start) = response_str.find("\r\n\r\n") {
            let body = &response_str[body_start + 4..];
            if !body.contains("\"choices\"") {
                return Err("Inference probe response missing 'choices' field".to_string());
            }
        }

        // Optional latency ceiling: a slow-but-correct answer still fails.
        if let Some(max_ms) = self.config.max_latency_ms {
            if latency.as_millis() as u64 > max_ms {
                return Err(format!(
                    "Inference probe latency {}ms exceeds threshold {}ms",
                    latency.as_millis(),
                    max_ms
                ));
            }
        }

        trace!(
            target = %target,
            model = %self.config.model,
            latency_ms = latency.as_millis(),
            "Inference probe health check passed"
        );

        Ok(latency)
    }

    fn check_type(&self) -> &str {
        "InferenceProbe"
    }
}
852
#[async_trait]
impl HealthCheckImpl for ModelStatusCheck {
    /// For each configured model, GETs its status endpoint (derived from
    /// `endpoint_pattern` with `{model}` substituted) and compares the
    /// configured JSON status field to the expected value. Fails fast on
    /// the first model that does not match.
    async fn check(&self, target: &str) -> Result<Duration, String> {
        let start = Instant::now();

        let addr: SocketAddr = target
            .parse()
            .map_err(|e| format!("Invalid address: {}", e))?;

        // One fresh connection per model (Connection: close semantics).
        for model in &self.config.models {
            let endpoint = self.config.endpoint_pattern.replace("{model}", model);

            let stream = time::timeout(self.timeout, TcpStream::connect(addr))
                .await
                .map_err(|_| format!("Connection timeout after {:?}", self.timeout))?
                .map_err(|e| format!("Connection failed: {}", e))?;

            let request = format!(
                "GET {} HTTP/1.1\r\nHost: {}\r\nUser-Agent: Grapsus-HealthCheck/1.0\r\nAccept: application/json\r\nConnection: close\r\n\r\n",
                endpoint,
                target
            );

            let mut stream = stream;
            stream
                .write_all(request.as_bytes())
                .await
                .map_err(|e| format!("Failed to send request: {}", e))?;

            // Single 8 KiB read per model.
            let mut response = vec![0u8; 8192];
            let n = stream
                .read(&mut response)
                .await
                .map_err(|e| format!("Failed to read response: {}", e))?;

            if n == 0 {
                return Err(format!("Empty response for model '{}'", model));
            }

            let response_str = String::from_utf8_lossy(&response[..n]);
            let status_code = parse_http_status(&response_str)
                .ok_or_else(|| "Failed to parse HTTP status".to_string())?;

            if status_code != 200 {
                return Err(format!(
                    "Model '{}' status check failed: HTTP {}",
                    model, status_code
                ));
            }

            // Extract the (possibly dot-pathed) status field from the body.
            // NOTE(review): a response with no body separator is silently
            // treated as passing for this model — confirm that is intended.
            if let Some(body_start) = response_str.find("\r\n\r\n") {
                let body = &response_str[body_start + 4..];
                let status = extract_json_field(body, &self.config.status_field);

                match status {
                    Some(s) if s == self.config.expected_status => {
                        trace!(
                            target = %target,
                            model = %model,
                            status = %s,
                            "Model status check passed"
                        );
                    }
                    Some(s) => {
                        return Err(format!(
                            "Model '{}' status '{}' != expected '{}'",
                            model, s, self.config.expected_status
                        ));
                    }
                    None => {
                        return Err(format!(
                            "Model '{}' status field '{}' not found",
                            model, self.config.status_field
                        ));
                    }
                }
            }
        }

        Ok(start.elapsed())
    }

    fn check_type(&self) -> &str {
        "ModelStatus"
    }
}
946
#[async_trait]
impl HealthCheckImpl for QueueDepthCheck {
    /// Reads the target's queue depth from a response header or JSON body
    /// field. Depth >= unhealthy_threshold fails the check; depth >=
    /// degraded_threshold only logs a warning.
    async fn check(&self, target: &str) -> Result<Duration, String> {
        let start = Instant::now();

        let addr: SocketAddr = target
            .parse()
            .map_err(|e| format!("Invalid address: {}", e))?;

        // Prefer a dedicated endpoint from config; otherwise reuse the
        // models endpoint the base inference check hits.
        let endpoint = self
            .config
            .endpoint
            .as_ref()
            .unwrap_or(&self.models_endpoint);

        let stream = time::timeout(self.timeout, TcpStream::connect(addr))
            .await
            .map_err(|_| format!("Connection timeout after {:?}", self.timeout))?
            .map_err(|e| format!("Connection failed: {}", e))?;

        let request = format!(
            "GET {} HTTP/1.1\r\nHost: {}\r\nUser-Agent: Grapsus-HealthCheck/1.0\r\nAccept: application/json\r\nConnection: close\r\n\r\n",
            endpoint,
            target
        );

        let mut stream = stream;
        stream
            .write_all(request.as_bytes())
            .await
            .map_err(|e| format!("Failed to send request: {}", e))?;

        let mut response = vec![0u8; 8192];
        let n = stream
            .read(&mut response)
            .await
            .map_err(|e| format!("Failed to read response: {}", e))?;

        if n == 0 {
            return Err("Empty response".to_string());
        }

        let response_str = String::from_utf8_lossy(&response[..n]);

        // Header source takes precedence over the body field when both are
        // configured; neither configured is a hard error.
        let queue_depth = if let Some(ref header_name) = self.config.header {
            extract_header_value(&response_str, header_name).and_then(|v| v.parse::<u64>().ok())
        } else if let Some(ref field) = self.config.body_field {
            if let Some(body_start) = response_str.find("\r\n\r\n") {
                let body = &response_str[body_start + 4..];
                extract_json_field(body, field).and_then(|v| v.parse::<u64>().ok())
            } else {
                None
            }
        } else {
            return Err("No queue depth source configured (header or body_field)".to_string());
        };

        let depth = queue_depth.ok_or_else(|| "Could not extract queue depth".to_string())?;

        if depth >= self.config.unhealthy_threshold {
            return Err(format!(
                "Queue depth {} exceeds unhealthy threshold {}",
                depth, self.config.unhealthy_threshold
            ));
        }

        // Degraded is advisory only: warn but still report healthy.
        if depth >= self.config.degraded_threshold {
            warn!(
                target = %target,
                queue_depth = depth,
                threshold = self.config.degraded_threshold,
                "Queue depth exceeds degraded threshold"
            );
        }

        trace!(
            target = %target,
            queue_depth = depth,
            "Queue depth check passed"
        );

        Ok(start.elapsed())
    }

    fn check_type(&self) -> &str {
        "QueueDepth"
    }
}
1042
1043#[async_trait]
1044impl HealthCheckImpl for CompositeInferenceHealthCheck {
1045 async fn check(&self, target: &str) -> Result<Duration, String> {
1046 let start = Instant::now();
1047
1048 self.base_check.check(target).await?;
1050
1051 if let Some(ref probe) = self.inference_probe {
1053 probe.check(target).await?;
1054 }
1055
1056 if let Some(ref status) = self.model_status {
1057 status.check(target).await?;
1058 }
1059
1060 if let Some(ref queue) = self.queue_depth {
1061 queue.check(target).await?;
1062 }
1063
1064 trace!(
1065 target = %target,
1066 total_time_ms = start.elapsed().as_millis(),
1067 "Composite inference health check passed"
1068 );
1069
1070 Ok(start.elapsed())
1071 }
1072
1073 fn check_type(&self) -> &str {
1074 "CompositeInference"
1075 }
1076}
1077
/// Extracts the value of `header_name` (compared case-insensitively) from a
/// raw HTTP response, scanning only the header section.
///
/// Returns `None` when the header is absent or the response is malformed.
fn extract_header_value(response: &str, header_name: &str) -> Option<String> {
    for line in response.lines() {
        // A blank line terminates the header section; never scan the body.
        // (`lines()` strips a trailing `\r`, but keep the bare-`"\r"` guard
        // for degenerate separators.)
        if line.is_empty() || line == "\r" {
            break;
        }
        if let Some((name, value)) = line.split_once(':') {
            // ASCII case-insensitive compare avoids allocating a lowercased
            // String per header line (HTTP header names are ASCII tokens).
            if name.trim().eq_ignore_ascii_case(header_name) {
                return Some(value.trim().to_string());
            }
        }
    }
    None
}
1093
1094fn extract_json_field(body: &str, field_path: &str) -> Option<String> {
1096 let json: serde_json::Value = serde_json::from_str(body).ok()?;
1097 let parts: Vec<&str> = field_path.split('.').collect();
1098 let mut current = &json;
1099
1100 for part in parts {
1101 current = current.get(part)?;
1102 }
1103
1104 match current {
1105 serde_json::Value::String(s) => Some(s.clone()),
1106 serde_json::Value::Number(n) => Some(n.to_string()),
1107 serde_json::Value::Bool(b) => Some(b.to_string()),
1108 _ => None,
1109 }
1110}
1111
/// Pulls the numeric status code out of an HTTP status line, e.g.
/// `"HTTP/1.1 503 Service Unavailable"` yields `Some(503)`.
///
/// Returns `None` when the response is empty, the status line has no
/// second token, or that token is not a valid `u16`.
fn parse_http_status(response: &str) -> Option<u16> {
    let status_line = response.lines().next()?;
    let code_token = status_line.split_whitespace().nth(1)?;
    code_token.parse::<u16>().ok()
}
1122
/// Observes real request outcomes (rather than probing) and, when a
/// target's failure rate over a sliding window crosses the threshold,
/// asks the active checker to mark it unhealthy.
pub struct PassiveHealthChecker {
    // Failure-rate fraction (0.0–1.0) above which a target is flagged.
    failure_rate_threshold: f64,
    // Number of most-recent outcomes kept per target.
    window_size: usize,
    // Per-target sliding window of outcomes (true = success).
    outcomes: Arc<RwLock<HashMap<String, Vec<bool>>>>,
    // Most recent error message seen per target.
    last_errors: Arc<RwLock<HashMap<String, String>>>,
    // Optional hook into the active checker for marking targets unhealthy.
    active_checker: Option<Arc<ActiveHealthChecker>>,
}
1140
impl PassiveHealthChecker {
    /// Builds a passive checker. `active_checker`, when provided, is
    /// notified so threshold breaches demote targets in the shared state.
    pub fn new(
        failure_rate_threshold: f64,
        window_size: usize,
        active_checker: Option<Arc<ActiveHealthChecker>>,
    ) -> Self {
        debug!(
            failure_rate_threshold = format!("{:.2}", failure_rate_threshold),
            window_size = window_size,
            has_active_checker = active_checker.is_some(),
            "Creating passive health checker"
        );
        Self {
            failure_rate_threshold,
            window_size,
            outcomes: Arc::new(RwLock::new(HashMap::new())),
            last_errors: Arc::new(RwLock::new(HashMap::new())),
            active_checker,
        }
    }

    /// Records one request outcome for `target`, updates the sliding
    /// window, and escalates to the active checker when the failure rate
    /// exceeds the threshold.
    pub async fn record_outcome(&self, target: &str, success: bool, error: Option<&str>) {
        trace!(
            target = %target,
            success = success,
            error = ?error,
            "Recording request outcome"
        );

        // Track the most recent error; a later success clears it. The
        // last_errors lock is taken and released before the outcomes lock,
        // so the two locks are never held together.
        if let Some(err_msg) = error {
            self.last_errors
                .write()
                .await
                .insert(target.to_string(), err_msg.to_string());
        } else if success {
            self.last_errors.write().await.remove(target);
        }

        let mut outcomes = self.outcomes.write().await;
        let target_outcomes = outcomes
            .entry(target.to_string())
            .or_insert_with(|| Vec::with_capacity(self.window_size));

        // Sliding window: evict the oldest outcome once at capacity.
        // (remove(0) is O(window_size); windows are expected to be small.)
        if target_outcomes.len() >= self.window_size {
            target_outcomes.remove(0);
        }
        target_outcomes.push(success);

        // Recompute the failure rate over the current window (len >= 1 here).
        let failures = target_outcomes.iter().filter(|&&s| !s).count();
        let failure_rate = failures as f64 / target_outcomes.len() as f64;

        trace!(
            target = %target,
            failure_rate = format!("{:.2}", failure_rate),
            window_samples = target_outcomes.len(),
            failures = failures,
            "Updated failure rate"
        );

        if failure_rate > self.failure_rate_threshold {
            warn!(
                target = %target,
                failure_rate = format!("{:.2}", failure_rate * 100.0),
                threshold = format!("{:.2}", self.failure_rate_threshold * 100.0),
                window_samples = target_outcomes.len(),
                "Failure rate exceeds threshold"
            );
            if let Some(ref checker) = self.active_checker {
                checker
                    .mark_unhealthy(
                        target,
                        format!(
                            "Failure rate {:.2}% exceeds threshold",
                            failure_rate * 100.0
                        ),
                    )
                    .await;
            }
        }
    }

    /// Failure rate over the current window, or `None` if no outcomes have
    /// been recorded for `target`.
    pub async fn get_failure_rate(&self, target: &str) -> Option<f64> {
        let outcomes = self.outcomes.read().await;
        outcomes.get(target).map(|target_outcomes| {
            let failures = target_outcomes.iter().filter(|&&s| !s).count();
            failures as f64 / target_outcomes.len() as f64
        })
    }

    /// Most recent error recorded for `target`, if any.
    pub async fn get_last_error(&self, target: &str) -> Option<String> {
        self.last_errors.read().await.get(target).cloned()
    }
}
1243
1244use dashmap::DashMap;
1249use std::sync::atomic::{AtomicBool, AtomicU32, AtomicU64, Ordering};
1250use grapsus_common::{ColdModelAction, WarmthDetectionConfig};
1251
/// Detects "cold" targets — latency spikes after an idle period, typical of
/// inference backends that evict idle models — by comparing observed
/// latency against a learned per-target baseline.
pub struct WarmthTracker {
    // Sample size, cold multiplier, idle timeout, and action on cold detection.
    config: WarmthDetectionConfig,
    // Lock-free per-target state, keyed by target address.
    targets: DashMap<String, TargetWarmthState>,
}
1263
/// Per-target warmth state; all fields are atomics so updates need no lock.
struct TargetWarmthState {
    // Learned baseline latency in milliseconds (0 until first sample).
    baseline_latency_ms: AtomicU64,
    // Number of latency samples folded into the baseline.
    sample_count: AtomicU32,
    // Wall-clock timestamp (ms since Unix epoch) of the last request.
    last_request_ms: AtomicU64,
    // Whether the most recent observation flagged this target as cold.
    is_cold: AtomicBool,
    // Lifetime count of detected cold starts.
    cold_start_count: AtomicU64,
}
1277
1278impl TargetWarmthState {
1279 fn new() -> Self {
1280 Self {
1281 baseline_latency_ms: AtomicU64::new(0),
1282 sample_count: AtomicU32::new(0),
1283 last_request_ms: AtomicU64::new(0),
1284 is_cold: AtomicBool::new(false),
1285 cold_start_count: AtomicU64::new(0),
1286 }
1287 }
1288
1289 fn update_baseline(&self, latency_ms: u64, sample_size: u32) {
1290 let count = self.sample_count.fetch_add(1, Ordering::Relaxed);
1291 let current = self.baseline_latency_ms.load(Ordering::Relaxed);
1292
1293 if count < sample_size {
1294 let new_baseline = if count == 0 {
1296 latency_ms
1297 } else {
1298 (current * count as u64 + latency_ms) / (count as u64 + 1)
1299 };
1300 self.baseline_latency_ms
1301 .store(new_baseline, Ordering::Relaxed);
1302 } else {
1303 let alpha = 0.1_f64;
1306 let new_baseline = (alpha * latency_ms as f64 + (1.0 - alpha) * current as f64) as u64;
1307 self.baseline_latency_ms
1308 .store(new_baseline, Ordering::Relaxed);
1309 }
1310 }
1311}
1312
impl WarmthTracker {
    /// Builds a tracker with the given detection parameters.
    pub fn new(config: WarmthDetectionConfig) -> Self {
        Self {
            config,
            targets: DashMap::new(),
        }
    }

    /// Convenience constructor: 10-sample warm-up, 3x-baseline cold
    /// threshold, 5-minute idle timeout, log-only action.
    pub fn with_defaults() -> Self {
        Self::new(WarmthDetectionConfig {
            sample_size: 10,
            cold_threshold_multiplier: 3.0,
            idle_cold_timeout_secs: 300,
            cold_action: ColdModelAction::LogOnly,
        })
    }

    /// Records one request latency for `target`. Returns `true` when this
    /// request is classified as a cold start (latency spike after an idle
    /// period); in that case the sample is NOT folded into the baseline.
    pub fn record_request(&self, target: &str, latency: Duration) -> bool {
        // Wall-clock ms since the Unix epoch (0 if the clock is pre-epoch).
        let now_ms = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .map(|d| d.as_millis() as u64)
            .unwrap_or(0);

        let latency_ms = latency.as_millis() as u64;
        let idle_threshold_ms = self.config.idle_cold_timeout_secs * 1000;

        // DashMap entry guard; held for the remainder of the function.
        let state = self
            .targets
            .entry(target.to_string())
            .or_insert_with(TargetWarmthState::new);

        // How long the target sat idle before this request (0 for the first
        // request ever seen).
        let last_request = state.last_request_ms.load(Ordering::Relaxed);
        let idle_duration_ms = if last_request > 0 {
            now_ms.saturating_sub(last_request)
        } else {
            0
        };

        state.last_request_ms.store(now_ms, Ordering::Relaxed);

        // Cold detection only applies after a long-enough idle gap and once
        // a baseline exists.
        if idle_duration_ms >= idle_threshold_ms {
            let baseline = state.baseline_latency_ms.load(Ordering::Relaxed);

            if baseline > 0 {
                let threshold = (baseline as f64 * self.config.cold_threshold_multiplier) as u64;

                if latency_ms > threshold {
                    // Release pairing with the Acquire load in `is_cold()`.
                    state.is_cold.store(true, Ordering::Release);
                    state.cold_start_count.fetch_add(1, Ordering::Relaxed);

                    warn!(
                        target = %target,
                        latency_ms = latency_ms,
                        baseline_ms = baseline,
                        threshold_ms = threshold,
                        idle_duration_secs = idle_duration_ms / 1000,
                        cold_action = ?self.config.cold_action,
                        "Cold model detected - latency spike after idle period"
                    );

                    // Early return: cold-start latency is excluded from the
                    // baseline so it doesn't inflate future thresholds.
                    return true;
                }
            }
        }

        // Warm request: clear the cold flag and fold the sample in.
        state.is_cold.store(false, Ordering::Release);
        state.update_baseline(latency_ms, self.config.sample_size);

        trace!(
            target = %target,
            latency_ms = latency_ms,
            baseline_ms = state.baseline_latency_ms.load(Ordering::Relaxed),
            sample_count = state.sample_count.load(Ordering::Relaxed),
            "Recorded request latency for warmth tracking"
        );

        false
    }

    /// Whether the most recent observation flagged `target` as cold.
    pub fn is_cold(&self, target: &str) -> bool {
        self.targets
            .get(target)
            .map(|s| s.is_cold.load(Ordering::Acquire))
            .unwrap_or(false)
    }

    /// The configured action to take when a cold model is detected.
    pub fn cold_action(&self) -> ColdModelAction {
        self.config.cold_action
    }

    /// Current learned baseline latency for `target`, if tracked.
    pub fn baseline_latency_ms(&self, target: &str) -> Option<u64> {
        self.targets
            .get(target)
            .map(|s| s.baseline_latency_ms.load(Ordering::Relaxed))
    }

    /// Lifetime number of cold starts detected for `target` (0 if untracked).
    pub fn cold_start_count(&self, target: &str) -> u64 {
        self.targets
            .get(target)
            .map(|s| s.cold_start_count.load(Ordering::Relaxed))
            .unwrap_or(0)
    }

    /// Whether routing should deprioritize `target`: it must currently be
    /// cold AND the configured action must demand demotion.
    pub fn should_deprioritize(&self, target: &str) -> bool {
        if !self.is_cold(target) {
            return false;
        }

        match self.config.cold_action {
            ColdModelAction::LogOnly => false,
            ColdModelAction::MarkDegraded | ColdModelAction::MarkUnhealthy => true,
        }
    }
}
1442
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_health_status() {
        // A freshly constructed target starts out fully healthy.
        let info = TargetHealthInfo::new();
        assert!(info.healthy);
        assert_eq!(info.health_score(), 1.0);
        assert!(!info.is_degraded());
    }

    #[tokio::test]
    async fn test_passive_health_checker() {
        let checker = PassiveHealthChecker::new(0.5, 10, None);

        // 5 successes then 3 failures => failure rate 3/8 = 0.375.
        for _ in 0..5 {
            checker.record_outcome("target1", true, None).await;
        }
        for _ in 0..3 {
            checker
                .record_outcome("target1", false, Some("HTTP 503"))
                .await;
        }

        let rate = checker.get_failure_rate("target1").await.unwrap();
        assert!(rate > 0.3 && rate < 0.4);
    }

    #[test]
    fn test_parse_http_status() {
        // Well-formed status lines yield their numeric code.
        assert_eq!(parse_http_status("HTTP/1.1 200 OK\r\n"), Some(200));
        assert_eq!(parse_http_status("HTTP/1.1 404 Not Found\r\n"), Some(404));
        // Anything that is not an HTTP status line parses to None.
        assert_eq!(parse_http_status("Invalid response"), None);
    }

    #[test]
    fn test_warmth_tracker_baseline() {
        let tracker = WarmthTracker::with_defaults();

        // Warm-up samples at a steady 100ms must never register as cold.
        for i in 0..10 {
            let detected = tracker.record_request("target1", Duration::from_millis(100));
            assert!(!detected, "Should not detect cold on request {}", i);
        }

        // Every sample was 100ms, so the averaged baseline sits in (0, 100].
        let baseline = tracker
            .baseline_latency_ms("target1")
            .expect("baseline should exist after warm-up samples");
        assert!(baseline > 0 && baseline <= 100);
    }

    #[test]
    fn test_warmth_tracker_cold_detection() {
        // Zero idle timeout makes every request eligible for the cold check.
        let tracker = WarmthTracker::new(WarmthDetectionConfig {
            sample_size: 5,
            cold_threshold_multiplier: 2.0,
            idle_cold_timeout_secs: 0,
            cold_action: ColdModelAction::MarkDegraded,
        });

        // Establish a ~100ms baseline.
        for _ in 0..5 {
            tracker.record_request("target1", Duration::from_millis(100));
        }

        std::thread::sleep(Duration::from_millis(10));

        // 300ms exceeds the 2x threshold (200ms) => cold start.
        let detected = tracker.record_request("target1", Duration::from_millis(300));
        assert!(detected, "Should detect cold start");
        assert!(tracker.is_cold("target1"));
        assert_eq!(tracker.cold_start_count("target1"), 1);
    }

    #[test]
    fn test_warmth_tracker_no_cold_on_normal_latency() {
        let tracker = WarmthTracker::new(WarmthDetectionConfig {
            sample_size: 5,
            cold_threshold_multiplier: 3.0,
            idle_cold_timeout_secs: 0,
            cold_action: ColdModelAction::LogOnly,
        });

        for _ in 0..5 {
            tracker.record_request("target1", Duration::from_millis(100));
        }

        std::thread::sleep(Duration::from_millis(10));

        // 150ms stays under the 3x threshold (300ms) => still warm.
        let detected = tracker.record_request("target1", Duration::from_millis(150));
        assert!(!detected, "Should not detect cold for normal variation");
        assert!(!tracker.is_cold("target1"));
    }

    #[test]
    fn test_warmth_tracker_deprioritize() {
        let tracker = WarmthTracker::new(WarmthDetectionConfig {
            sample_size: 2,
            cold_threshold_multiplier: 2.0,
            idle_cold_timeout_secs: 0,
            cold_action: ColdModelAction::MarkDegraded,
        });

        // Two warm samples seed the baseline, then a 3x spike goes cold.
        tracker.record_request("target1", Duration::from_millis(100));
        tracker.record_request("target1", Duration::from_millis(100));
        std::thread::sleep(Duration::from_millis(10));
        tracker.record_request("target1", Duration::from_millis(300));

        // MarkDegraded + cold => deprioritize.
        assert!(tracker.should_deprioritize("target1"));

        // A normal-latency request clears the cold flag again.
        tracker.record_request("target1", Duration::from_millis(100));
        assert!(!tracker.should_deprioritize("target1"));
    }
}