// grapsus_proxy/health.rs
1//! Health checking module for Grapsus proxy
2//!
3//! This module implements active and passive health checking for upstream servers,
4//! supporting HTTP, TCP, and gRPC health checks with configurable thresholds.
5
6use async_trait::async_trait;
7use std::collections::HashMap;
8use std::net::SocketAddr;
9use std::sync::Arc;
10use std::time::{Duration, Instant};
11use tokio::io::{AsyncReadExt, AsyncWriteExt};
12use tokio::net::TcpStream;
13use tokio::sync::RwLock;
14use tokio::time;
15use tracing::{debug, info, trace, warn};
16
17use grapsus_common::{errors::GrapsusResult, types::HealthCheckType};
18use grapsus_config::{HealthCheck as HealthCheckConfig, UpstreamTarget};
19
/// Active health checker for upstream targets
///
/// Performs periodic health checks on upstream targets using HTTP, TCP, or gRPC
/// protocols to determine their availability for load balancing.
pub struct ActiveHealthChecker {
    /// Check configuration (interval, timeout, healthy/unhealthy thresholds)
    config: HealthCheckConfig,
    /// Protocol-specific health checker implementation, selected in `new`
    checker: Arc<dyn HealthCheckImpl>,
    /// Health status per target, keyed by the target's address string
    health_status: Arc<RwLock<HashMap<String, TargetHealthInfo>>>,
    /// Handles of the spawned per-target check tasks (awaited in `stop`)
    check_handles: Arc<RwLock<Vec<tokio::task::JoinHandle<()>>>>,
    /// Broadcast channel used to tell all check tasks to shut down
    shutdown_tx: Arc<tokio::sync::broadcast::Sender<()>>,
}
36
/// Health status information for a target
#[derive(Debug, Clone)]
pub struct TargetHealthInfo {
    /// Is target healthy (targets start out healthy; see `new`)
    pub healthy: bool,
    /// Consecutive successes (reset to 0 by any failed check)
    pub consecutive_successes: u32,
    /// Consecutive failures (reset to 0 by any successful check)
    pub consecutive_failures: u32,
    /// Time of the most recent check, successful or not
    pub last_check: Instant,
    /// Time of the most recent successful check, if any
    pub last_success: Option<Instant>,
    /// Error message from the most recent failure (cleared on success)
    pub last_error: Option<String>,
    /// Total checks performed over the lifetime of this record
    pub total_checks: u64,
    /// Total successful checks over the lifetime of this record
    pub total_successes: u64,
    /// Cumulative mean response time of successful checks (ms)
    pub avg_response_time: f64,
}
59
/// Health check implementation trait
///
/// Implementations are protocol-specific and are invoked repeatedly from
/// per-target background tasks, so they must be `Send + Sync`.
#[async_trait]
trait HealthCheckImpl: Send + Sync {
    /// Perform health check on a target
    ///
    /// `target` is a socket address string (implementations parse it as a
    /// `SocketAddr`). Returns the measured response time on success, or a
    /// human-readable error message on failure.
    async fn check(&self, target: &str) -> Result<Duration, String>;

    /// Get check type name (e.g. "HTTP", "TCP"), used for logging
    fn check_type(&self) -> &str;
}
69
/// HTTP health check implementation
struct HttpHealthCheck {
    /// Request path to probe (e.g. `/healthz`)
    path: String,
    /// Status code that counts as healthy
    expected_status: u16,
    /// Optional `Host` header override; defaults to the target address
    host: Option<String>,
    /// Per-check timeout
    timeout: Duration,
}
77
/// TCP health check implementation
///
/// A target is healthy if a TCP connection completes within the timeout.
struct TcpHealthCheck {
    /// Connect timeout
    timeout: Duration,
}
82
/// gRPC health check implementation.
///
/// Currently uses TCP connectivity check as a fallback since full gRPC
/// health checking protocol (grpc.health.v1.Health) requires the `tonic`
/// crate for HTTP/2 and Protocol Buffers support.
///
/// Full implementation would:
/// 1. Establish HTTP/2 connection
/// 2. Call `grpc.health.v1.Health/Check` with service name
/// 3. Parse `HealthCheckResponse` for SERVING/NOT_SERVING status
///
/// See: https://github.com/grpc/grpc/blob/master/doc/health-checking.md
struct GrpcHealthCheck {
    /// gRPC service name to query (currently only logged by the TCP fallback)
    service: String,
    /// Per-check timeout
    timeout: Duration,
}
99
/// Inference health check implementation for LLM/AI backends.
///
/// Probes the models endpoint to verify the inference server is running
/// and expected models are available. Typically used with OpenAI-compatible
/// APIs that expose a `/v1/models` endpoint.
///
/// The check:
/// 1. Sends GET request to the configured endpoint (default: `/v1/models`)
/// 2. Expects HTTP 200 response
/// 3. Optionally parses response to verify expected models are available
struct InferenceHealthCheck {
    /// Models endpoint path to probe
    endpoint: String,
    /// Model ids that must appear in the response body (empty = skip check)
    expected_models: Vec<String>,
    /// Per-check timeout
    timeout: Duration,
}
115
/// Inference probe health check - sends minimal completion request
///
/// Verifies model can actually process requests, not just that server is running.
struct InferenceProbeCheck {
    /// Probe configuration (endpoint, model, prompt, max_tokens, optional latency limit)
    config: grapsus_common::InferenceProbeConfig,
    /// Per-check timeout, taken from the probe config's `timeout_secs`
    timeout: Duration,
}
123
/// Model status endpoint health check
///
/// Queries provider-specific status endpoints to verify model readiness.
struct ModelStatusCheck {
    /// Status configuration: `models` list plus an `endpoint_pattern`
    /// containing a `{model}` placeholder
    config: grapsus_common::ModelStatusConfig,
    /// Per-check timeout
    timeout: Duration,
}
131
/// Queue depth health check
///
/// Monitors queue depth from headers or response body to detect overload.
struct QueueDepthCheck {
    /// Queue depth thresholds/configuration
    config: grapsus_common::QueueDepthConfig,
    /// Endpoint to query — set to the base inference check's models endpoint
    /// (its check implementation is defined later in this file)
    models_endpoint: String,
    /// Per-check timeout
    timeout: Duration,
}
140
/// Composite inference health check that runs multiple sub-checks
///
/// Runs base inference check plus any configured readiness checks.
/// All enabled checks must pass for the target to be considered healthy.
struct CompositeInferenceHealthCheck {
    /// Always-run models-endpoint check
    base_check: InferenceHealthCheck,
    /// Optional minimal-completion probe
    inference_probe: Option<InferenceProbeCheck>,
    /// Optional per-model status-endpoint check
    model_status: Option<ModelStatusCheck>,
    /// Optional queue-depth check
    queue_depth: Option<QueueDepthCheck>,
}
151
impl ActiveHealthChecker {
    /// Create new active health checker
    ///
    /// Selects the protocol-specific checker from `config.check_type`.
    /// This only builds the checker and the shutdown channel; no background
    /// tasks are spawned until [`Self::start`] is called.
    pub fn new(config: HealthCheckConfig) -> Self {
        debug!(
            check_type = ?config.check_type,
            interval_secs = config.interval_secs,
            timeout_secs = config.timeout_secs,
            healthy_threshold = config.healthy_threshold,
            unhealthy_threshold = config.unhealthy_threshold,
            "Creating active health checker"
        );

        let checker: Arc<dyn HealthCheckImpl> = match &config.check_type {
            HealthCheckType::Http {
                path,
                expected_status,
                host,
            } => {
                trace!(
                    path = %path,
                    expected_status = expected_status,
                    host = host.as_deref().unwrap_or("(default)"),
                    "Configuring HTTP health check"
                );
                Arc::new(HttpHealthCheck {
                    path: path.clone(),
                    expected_status: *expected_status,
                    host: host.clone(),
                    timeout: Duration::from_secs(config.timeout_secs),
                })
            }
            HealthCheckType::Tcp => {
                trace!("Configuring TCP health check");
                Arc::new(TcpHealthCheck {
                    timeout: Duration::from_secs(config.timeout_secs),
                })
            }
            HealthCheckType::Grpc { service } => {
                trace!(
                    service = %service,
                    "Configuring gRPC health check"
                );
                Arc::new(GrpcHealthCheck {
                    service: service.clone(),
                    timeout: Duration::from_secs(config.timeout_secs),
                })
            }
            HealthCheckType::Inference {
                endpoint,
                expected_models,
                readiness,
            } => {
                trace!(
                    endpoint = %endpoint,
                    expected_models = ?expected_models,
                    has_readiness = readiness.is_some(),
                    "Configuring inference health check"
                );

                let base_timeout = Duration::from_secs(config.timeout_secs);
                let base_check = InferenceHealthCheck {
                    endpoint: endpoint.clone(),
                    expected_models: expected_models.clone(),
                    timeout: base_timeout,
                };

                if let Some(ref readiness_config) = readiness {
                    // Create composite check with sub-checks; each sub-check
                    // carries its own timeout from its own config, not the
                    // top-level `timeout_secs`.
                    let inference_probe =
                        readiness_config
                            .inference_probe
                            .as_ref()
                            .map(|cfg| InferenceProbeCheck {
                                config: cfg.clone(),
                                timeout: Duration::from_secs(cfg.timeout_secs),
                            });

                    let model_status =
                        readiness_config
                            .model_status
                            .as_ref()
                            .map(|cfg| ModelStatusCheck {
                                config: cfg.clone(),
                                timeout: Duration::from_secs(cfg.timeout_secs),
                            });

                    let queue_depth =
                        readiness_config
                            .queue_depth
                            .as_ref()
                            .map(|cfg| QueueDepthCheck {
                                config: cfg.clone(),
                                models_endpoint: endpoint.clone(),
                                timeout: Duration::from_secs(cfg.timeout_secs),
                            });

                    Arc::new(CompositeInferenceHealthCheck {
                        base_check,
                        inference_probe,
                        model_status,
                        queue_depth,
                    })
                } else {
                    // Simple inference check without readiness sub-checks
                    Arc::new(base_check)
                }
            }
        };

        // Capacity 1 suffices: shutdown is a one-shot broadcast signal.
        let (shutdown_tx, _) = tokio::sync::broadcast::channel(1);

        Self {
            config,
            checker,
            health_status: Arc::new(RwLock::new(HashMap::new())),
            check_handles: Arc::new(RwLock::new(Vec::new())),
            shutdown_tx: Arc::new(shutdown_tx),
        }
    }

    /// Start health checking for targets
    ///
    /// Initializes each target's status (optimistically healthy, see
    /// [`TargetHealthInfo::new`]) and spawns one background check task per
    /// target.
    ///
    /// NOTE(review): calling `start` twice spawns duplicate tasks for the
    /// same targets — confirm callers invoke it only once.
    pub async fn start(&self, targets: &[UpstreamTarget]) -> GrapsusResult<()> {
        info!(
            target_count = targets.len(),
            interval_secs = self.config.interval_secs,
            check_type = self.checker.check_type(),
            "Starting health checking"
        );

        let mut handles = self.check_handles.write().await;

        for target in targets {
            let address = target.address.clone();

            trace!(
                target = %address,
                "Initializing health status for target"
            );

            // Initialize health status (overwrites any existing entry)
            self.health_status
                .write()
                .await
                .insert(address.clone(), TargetHealthInfo::new());

            // Spawn health check task
            debug!(
                target = %address,
                "Spawning health check task"
            );
            let handle = self.spawn_check_task(address);
            handles.push(handle);
        }

        info!(
            target_count = targets.len(),
            interval_secs = self.config.interval_secs,
            healthy_threshold = self.config.healthy_threshold,
            unhealthy_threshold = self.config.unhealthy_threshold,
            "Health checking started successfully"
        );

        Ok(())
    }

    /// Spawn health check task for a target
    ///
    /// The task ticks every `interval_secs`, performs one check per tick,
    /// updates the shared status map, and exits when the shutdown broadcast
    /// fires.
    fn spawn_check_task(&self, target: String) -> tokio::task::JoinHandle<()> {
        let interval = Duration::from_secs(self.config.interval_secs);
        let checker = Arc::clone(&self.checker);
        let health_status = Arc::clone(&self.health_status);
        let healthy_threshold = self.config.healthy_threshold;
        let unhealthy_threshold = self.config.unhealthy_threshold;
        let check_type = self.checker.check_type().to_string();
        let mut shutdown_rx = self.shutdown_tx.subscribe();

        tokio::spawn(async move {
            let mut interval_timer = time::interval(interval);
            // If a check outlasts the interval, skip missed ticks instead of
            // firing a burst of catch-up checks.
            interval_timer.set_missed_tick_behavior(time::MissedTickBehavior::Skip);

            trace!(
                target = %target,
                check_type = %check_type,
                interval_ms = interval.as_millis(),
                "Health check task started"
            );

            loop {
                tokio::select! {
                    _ = interval_timer.tick() => {
                        // Perform health check (no lock held while awaiting)
                        trace!(
                            target = %target,
                            check_type = %check_type,
                            "Performing health check"
                        );
                        let start = Instant::now();
                        let result = checker.check(&target).await;
                        let check_duration = start.elapsed();

                        // Update health status under the write lock
                        let mut status_map = health_status.write().await;
                        if let Some(status) = status_map.get_mut(&target) {
                            status.last_check = Instant::now();
                            status.total_checks += 1;

                            match result {
                                Ok(response_time) => {
                                    status.consecutive_successes += 1;
                                    status.consecutive_failures = 0;
                                    status.last_success = Some(Instant::now());
                                    status.last_error = None;
                                    status.total_successes += 1;

                                    // Update average response time (cumulative mean;
                                    // total_successes was just incremented, so the
                                    // divisor is at least 1)
                                    let response_ms = response_time.as_millis() as f64;
                                    status.avg_response_time =
                                        (status.avg_response_time * (status.total_successes - 1) as f64
                                        + response_ms) / status.total_successes as f64;

                                    // Check if should mark as healthy
                                    if !status.healthy && status.consecutive_successes >= healthy_threshold {
                                        status.healthy = true;
                                        info!(
                                            target = %target,
                                            consecutive_successes = status.consecutive_successes,
                                            avg_response_ms = format!("{:.2}", status.avg_response_time),
                                            total_checks = status.total_checks,
                                            "Target marked as healthy"
                                        );
                                    }

                                    trace!(
                                        target = %target,
                                        response_time_ms = response_ms,
                                        check_duration_ms = check_duration.as_millis(),
                                        consecutive_successes = status.consecutive_successes,
                                        health_score = format!("{:.2}", status.health_score()),
                                        "Health check succeeded"
                                    );
                                }
                                Err(error) => {
                                    status.consecutive_failures += 1;
                                    status.consecutive_successes = 0;
                                    status.last_error = Some(error.clone());

                                    // Check if should mark as unhealthy
                                    if status.healthy && status.consecutive_failures >= unhealthy_threshold {
                                        status.healthy = false;
                                        warn!(
                                            target = %target,
                                            consecutive_failures = status.consecutive_failures,
                                            error = %error,
                                            total_checks = status.total_checks,
                                            health_score = format!("{:.2}", status.health_score()),
                                            "Target marked as unhealthy"
                                        );
                                    } else {
                                        debug!(
                                            target = %target,
                                            error = %error,
                                            consecutive_failures = status.consecutive_failures,
                                            unhealthy_threshold = unhealthy_threshold,
                                            "Health check failed"
                                        );
                                    }
                                }
                            }
                        }
                    }
                    _ = shutdown_rx.recv() => {
                        info!(target = %target, "Stopping health check task");
                        break;
                    }
                }
            }

            debug!(target = %target, "Health check task stopped");
        })
    }

    /// Stop health checking
    ///
    /// Broadcasts the shutdown signal and awaits every spawned check task.
    pub async fn stop(&self) {
        let task_count = self.check_handles.read().await.len();
        info!(task_count = task_count, "Stopping health checker");

        // Send shutdown signal (ignore error if no task is listening)
        let _ = self.shutdown_tx.send(());

        // Wait for all tasks to complete
        let mut handles = self.check_handles.write().await;
        for handle in handles.drain(..) {
            let _ = handle.await;
        }

        info!("Health checker stopped successfully");
    }

    /// Get health status for a target (None if the target is unknown)
    pub async fn get_status(&self, target: &str) -> Option<TargetHealthInfo> {
        self.health_status.read().await.get(target).cloned()
    }

    /// Get a snapshot of all health statuses
    pub async fn get_all_statuses(&self) -> HashMap<String, TargetHealthInfo> {
        self.health_status.read().await.clone()
    }

    /// Check if target is healthy (unknown targets are reported unhealthy)
    pub async fn is_healthy(&self, target: &str) -> bool {
        self.health_status
            .read()
            .await
            .get(target)
            .map(|s| s.healthy)
            .unwrap_or(false)
    }

    /// Get addresses of all currently healthy targets
    pub async fn get_healthy_targets(&self) -> Vec<String> {
        self.health_status
            .read()
            .await
            .iter()
            .filter_map(|(target, status)| {
                if status.healthy {
                    Some(target.clone())
                } else {
                    None
                }
            })
            .collect()
    }

    /// Mark target as unhealthy (for passive health checking)
    ///
    /// Only transitions healthy -> unhealthy; it is a no-op for unknown or
    /// already-unhealthy targets.
    pub async fn mark_unhealthy(&self, target: &str, reason: String) {
        if let Some(status) = self.health_status.write().await.get_mut(target) {
            if status.healthy {
                status.healthy = false;
                // Seed failures at the threshold so the record looks the same
                // as a demotion performed by the active checker.
                status.consecutive_failures = self.config.unhealthy_threshold;
                status.consecutive_successes = 0;
                status.last_error = Some(reason.clone());
                warn!(
                    target = %target,
                    reason = %reason,
                    "Target marked unhealthy by passive check"
                );
            }
        }
    }
}
502
503impl Default for TargetHealthInfo {
504    fn default() -> Self {
505        Self::new()
506    }
507}
508
509impl TargetHealthInfo {
510    /// Create new health status (initially healthy)
511    pub fn new() -> Self {
512        Self {
513            healthy: true,
514            consecutive_successes: 0,
515            consecutive_failures: 0,
516            last_check: Instant::now(),
517            last_success: Some(Instant::now()),
518            last_error: None,
519            total_checks: 0,
520            total_successes: 0,
521            avg_response_time: 0.0,
522        }
523    }
524
525    /// Get health score (0.0 - 1.0)
526    pub fn health_score(&self) -> f64 {
527        if self.total_checks == 0 {
528            return 1.0;
529        }
530        self.total_successes as f64 / self.total_checks as f64
531    }
532
533    /// Check if status is degraded (healthy but with recent failures)
534    pub fn is_degraded(&self) -> bool {
535        self.healthy && self.consecutive_failures > 0
536    }
537}
538
539#[async_trait]
540impl HealthCheckImpl for HttpHealthCheck {
541    async fn check(&self, target: &str) -> Result<Duration, String> {
542        let start = Instant::now();
543
544        // Parse target address
545        let addr: SocketAddr = target
546            .parse()
547            .map_err(|e| format!("Invalid address: {}", e))?;
548
549        // Connect with timeout
550        let stream = time::timeout(self.timeout, TcpStream::connect(addr))
551            .await
552            .map_err(|_| format!("Connection timeout after {:?}", self.timeout))?
553            .map_err(|e| format!("Connection failed: {}", e))?;
554
555        // Build HTTP request
556        let host = self.host.as_deref().unwrap_or(target);
557        let request = format!(
558            "GET {} HTTP/1.1\r\nHost: {}\r\nUser-Agent: Grapsus-HealthCheck/1.0\r\nConnection: close\r\n\r\n",
559            self.path,
560            host
561        );
562
563        // Send request and read response
564        let mut stream = stream;
565        stream
566            .write_all(request.as_bytes())
567            .await
568            .map_err(|e| format!("Failed to send request: {}", e))?;
569
570        let mut response = vec![0u8; 1024];
571        let n = stream
572            .read(&mut response)
573            .await
574            .map_err(|e| format!("Failed to read response: {}", e))?;
575
576        if n == 0 {
577            return Err("Empty response".to_string());
578        }
579
580        // Parse status code
581        let response_str = String::from_utf8_lossy(&response[..n]);
582        let status_code = parse_http_status(&response_str)
583            .ok_or_else(|| "Failed to parse HTTP status".to_string())?;
584
585        if status_code == self.expected_status {
586            Ok(start.elapsed())
587        } else {
588            Err(format!(
589                "Unexpected status code: {} (expected {})",
590                status_code, self.expected_status
591            ))
592        }
593    }
594
595    fn check_type(&self) -> &str {
596        "HTTP"
597    }
598}
599
600#[async_trait]
601impl HealthCheckImpl for TcpHealthCheck {
602    async fn check(&self, target: &str) -> Result<Duration, String> {
603        let start = Instant::now();
604
605        // Parse target address
606        let addr: SocketAddr = target
607            .parse()
608            .map_err(|e| format!("Invalid address: {}", e))?;
609
610        // Connect with timeout
611        time::timeout(self.timeout, TcpStream::connect(addr))
612            .await
613            .map_err(|_| format!("Connection timeout after {:?}", self.timeout))?
614            .map_err(|e| format!("Connection failed: {}", e))?;
615
616        Ok(start.elapsed())
617    }
618
619    fn check_type(&self) -> &str {
620        "TCP"
621    }
622}
623
624#[async_trait]
625impl HealthCheckImpl for GrpcHealthCheck {
626    async fn check(&self, target: &str) -> Result<Duration, String> {
627        let start = Instant::now();
628
629        // NOTE: Full gRPC health check requires `tonic` crate for HTTP/2 support.
630        // This implementation uses TCP connectivity as a reasonable fallback.
631        // The gRPC health checking protocol (grpc.health.v1.Health/Check) would
632        // return SERVING, NOT_SERVING, or UNKNOWN for the specified service.
633
634        let addr: SocketAddr = target
635            .parse()
636            .map_err(|e| format!("Invalid address: {}", e))?;
637
638        // TCP connectivity check as fallback
639        let stream = time::timeout(self.timeout, TcpStream::connect(addr))
640            .await
641            .map_err(|_| format!("Connection timeout after {:?}", self.timeout))?
642            .map_err(|e| format!("Connection failed: {}", e))?;
643
644        // Verify connection is writable (basic health indicator)
645        stream
646            .writable()
647            .await
648            .map_err(|e| format!("Connection not writable: {}", e))?;
649
650        debug!(
651            target = %target,
652            service = %self.service,
653            "gRPC health check using TCP fallback (full gRPC protocol requires tonic)"
654        );
655
656        Ok(start.elapsed())
657    }
658
659    fn check_type(&self) -> &str {
660        "gRPC"
661    }
662}
663
664#[async_trait]
665impl HealthCheckImpl for InferenceHealthCheck {
666    async fn check(&self, target: &str) -> Result<Duration, String> {
667        let start = Instant::now();
668
669        // Parse target address
670        let addr: SocketAddr = target
671            .parse()
672            .map_err(|e| format!("Invalid address: {}", e))?;
673
674        // Connect with timeout
675        let stream = time::timeout(self.timeout, TcpStream::connect(addr))
676            .await
677            .map_err(|_| format!("Connection timeout after {:?}", self.timeout))?
678            .map_err(|e| format!("Connection failed: {}", e))?;
679
680        // Build HTTP request for the models endpoint
681        let request = format!(
682            "GET {} HTTP/1.1\r\nHost: {}\r\nUser-Agent: Grapsus-HealthCheck/1.0\r\nAccept: application/json\r\nConnection: close\r\n\r\n",
683            self.endpoint,
684            target
685        );
686
687        // Send request and read response
688        let mut stream = stream;
689        stream
690            .write_all(request.as_bytes())
691            .await
692            .map_err(|e| format!("Failed to send request: {}", e))?;
693
694        // Read response (larger buffer for JSON response)
695        let mut response = vec![0u8; 8192];
696        let n = stream
697            .read(&mut response)
698            .await
699            .map_err(|e| format!("Failed to read response: {}", e))?;
700
701        if n == 0 {
702            return Err("Empty response".to_string());
703        }
704
705        // Parse status code
706        let response_str = String::from_utf8_lossy(&response[..n]);
707        let status_code = parse_http_status(&response_str)
708            .ok_or_else(|| "Failed to parse HTTP status".to_string())?;
709
710        if status_code != 200 {
711            return Err(format!(
712                "Unexpected status code: {} (expected 200)",
713                status_code
714            ));
715        }
716
717        // If expected models are specified, verify they're in the response
718        if !self.expected_models.is_empty() {
719            // Find the JSON body (after headers)
720            if let Some(body_start) = response_str.find("\r\n\r\n") {
721                let body = &response_str[body_start + 4..];
722
723                // Check if each expected model is mentioned in the response
724                for model in &self.expected_models {
725                    if !body.contains(model) {
726                        return Err(format!("Expected model '{}' not found in response", model));
727                    }
728                }
729
730                debug!(
731                    target = %target,
732                    endpoint = %self.endpoint,
733                    expected_models = ?self.expected_models,
734                    "All expected models found in inference health check"
735                );
736            } else {
737                return Err("Could not find response body".to_string());
738            }
739        }
740
741        trace!(
742            target = %target,
743            endpoint = %self.endpoint,
744            response_time_ms = start.elapsed().as_millis(),
745            "Inference health check passed"
746        );
747
748        Ok(start.elapsed())
749    }
750
751    fn check_type(&self) -> &str {
752        "Inference"
753    }
754}
755
756#[async_trait]
757impl HealthCheckImpl for InferenceProbeCheck {
758    async fn check(&self, target: &str) -> Result<Duration, String> {
759        let start = Instant::now();
760
761        // Parse target address
762        let addr: SocketAddr = target
763            .parse()
764            .map_err(|e| format!("Invalid address: {}", e))?;
765
766        // Connect with timeout
767        let stream = time::timeout(self.timeout, TcpStream::connect(addr))
768            .await
769            .map_err(|_| format!("Connection timeout after {:?}", self.timeout))?
770            .map_err(|e| format!("Connection failed: {}", e))?;
771
772        // Build completion request body
773        let body = format!(
774            r#"{{"model":"{}","prompt":"{}","max_tokens":{}}}"#,
775            self.config.model, self.config.prompt, self.config.max_tokens
776        );
777
778        // Build HTTP request
779        let request = format!(
780            "POST {} HTTP/1.1\r\nHost: {}\r\nUser-Agent: Grapsus-HealthCheck/1.0\r\nContent-Type: application/json\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}",
781            self.config.endpoint,
782            target,
783            body.len(),
784            body
785        );
786
787        // Send request
788        let mut stream = stream;
789        stream
790            .write_all(request.as_bytes())
791            .await
792            .map_err(|e| format!("Failed to send request: {}", e))?;
793
794        // Read response
795        let mut response = vec![0u8; 16384];
796        let n = stream
797            .read(&mut response)
798            .await
799            .map_err(|e| format!("Failed to read response: {}", e))?;
800
801        if n == 0 {
802            return Err("Empty response".to_string());
803        }
804
805        let latency = start.elapsed();
806
807        // Parse status code
808        let response_str = String::from_utf8_lossy(&response[..n]);
809        let status_code = parse_http_status(&response_str)
810            .ok_or_else(|| "Failed to parse HTTP status".to_string())?;
811
812        if status_code != 200 {
813            return Err(format!(
814                "Inference probe failed: status {} (expected 200)",
815                status_code
816            ));
817        }
818
819        // Verify response contains choices array
820        if let Some(body_start) = response_str.find("\r\n\r\n") {
821            let body = &response_str[body_start + 4..];
822            if !body.contains("\"choices\"") {
823                return Err("Inference probe response missing 'choices' field".to_string());
824            }
825        }
826
827        // Check latency threshold if configured
828        if let Some(max_ms) = self.config.max_latency_ms {
829            if latency.as_millis() as u64 > max_ms {
830                return Err(format!(
831                    "Inference probe latency {}ms exceeds threshold {}ms",
832                    latency.as_millis(),
833                    max_ms
834                ));
835            }
836        }
837
838        trace!(
839            target = %target,
840            model = %self.config.model,
841            latency_ms = latency.as_millis(),
842            "Inference probe health check passed"
843        );
844
845        Ok(latency)
846    }
847
848    fn check_type(&self) -> &str {
849        "InferenceProbe"
850    }
851}
852
853#[async_trait]
854impl HealthCheckImpl for ModelStatusCheck {
855    async fn check(&self, target: &str) -> Result<Duration, String> {
856        let start = Instant::now();
857
858        // Parse target address
859        let addr: SocketAddr = target
860            .parse()
861            .map_err(|e| format!("Invalid address: {}", e))?;
862
863        // Check each model's status
864        for model in &self.config.models {
865            let endpoint = self.config.endpoint_pattern.replace("{model}", model);
866
867            // Connect with timeout
868            let stream = time::timeout(self.timeout, TcpStream::connect(addr))
869                .await
870                .map_err(|_| format!("Connection timeout after {:?}", self.timeout))?
871                .map_err(|e| format!("Connection failed: {}", e))?;
872
873            // Build HTTP request
874            let request = format!(
875                "GET {} HTTP/1.1\r\nHost: {}\r\nUser-Agent: Grapsus-HealthCheck/1.0\r\nAccept: application/json\r\nConnection: close\r\n\r\n",
876                endpoint,
877                target
878            );
879
880            // Send request
881            let mut stream = stream;
882            stream
883                .write_all(request.as_bytes())
884                .await
885                .map_err(|e| format!("Failed to send request: {}", e))?;
886
887            // Read response
888            let mut response = vec![0u8; 8192];
889            let n = stream
890                .read(&mut response)
891                .await
892                .map_err(|e| format!("Failed to read response: {}", e))?;
893
894            if n == 0 {
895                return Err(format!("Empty response for model '{}'", model));
896            }
897
898            let response_str = String::from_utf8_lossy(&response[..n]);
899            let status_code = parse_http_status(&response_str)
900                .ok_or_else(|| "Failed to parse HTTP status".to_string())?;
901
902            if status_code != 200 {
903                return Err(format!(
904                    "Model '{}' status check failed: HTTP {}",
905                    model, status_code
906                ));
907            }
908
909            // Extract status field from JSON body
910            if let Some(body_start) = response_str.find("\r\n\r\n") {
911                let body = &response_str[body_start + 4..];
912                let status = extract_json_field(body, &self.config.status_field);
913
914                match status {
915                    Some(s) if s == self.config.expected_status => {
916                        trace!(
917                            target = %target,
918                            model = %model,
919                            status = %s,
920                            "Model status check passed"
921                        );
922                    }
923                    Some(s) => {
924                        return Err(format!(
925                            "Model '{}' status '{}' != expected '{}'",
926                            model, s, self.config.expected_status
927                        ));
928                    }
929                    None => {
930                        return Err(format!(
931                            "Model '{}' status field '{}' not found",
932                            model, self.config.status_field
933                        ));
934                    }
935                }
936            }
937        }
938
939        Ok(start.elapsed())
940    }
941
942    fn check_type(&self) -> &str {
943        "ModelStatus"
944    }
945}
946
#[async_trait]
impl HealthCheckImpl for QueueDepthCheck {
    /// Probe the target's queue-depth endpoint and fail when the reported
    /// depth reaches the unhealthy threshold.
    ///
    /// The depth is read either from a response header (`config.header`) or
    /// from a JSON body field (`config.body_field`); one of the two must be
    /// configured. Reaching `degraded_threshold` only logs a warning, while
    /// reaching `unhealthy_threshold` fails the check.
    async fn check(&self, target: &str) -> Result<Duration, String> {
        let start = Instant::now();

        // Parse target address ("host:port" form)
        let addr: SocketAddr = target
            .parse()
            .map_err(|e| format!("Invalid address: {}", e))?;

        // Use the configured endpoint, falling back to the default models endpoint
        let endpoint = self
            .config
            .endpoint
            .as_ref()
            .unwrap_or(&self.models_endpoint);

        // Connect with timeout
        let stream = time::timeout(self.timeout, TcpStream::connect(addr))
            .await
            .map_err(|_| format!("Connection timeout after {:?}", self.timeout))?
            .map_err(|e| format!("Connection failed: {}", e))?;

        // Build HTTP request; `Connection: close` lets a single read below
        // observe the full (small) response.
        let request = format!(
            "GET {} HTTP/1.1\r\nHost: {}\r\nUser-Agent: Grapsus-HealthCheck/1.0\r\nAccept: application/json\r\nConnection: close\r\n\r\n",
            endpoint,
            target
        );

        // Send request
        let mut stream = stream;
        stream
            .write_all(request.as_bytes())
            .await
            .map_err(|e| format!("Failed to send request: {}", e))?;

        // Read response. NOTE(review): a single 8 KiB read assumes the reply
        // arrives in one segment and fits the buffer — confirm upstream
        // responses are small.
        let mut response = vec![0u8; 8192];
        let n = stream
            .read(&mut response)
            .await
            .map_err(|e| format!("Failed to read response: {}", e))?;

        if n == 0 {
            return Err("Empty response".to_string());
        }

        let response_str = String::from_utf8_lossy(&response[..n]);

        // Extract queue depth from header or body.
        // NOTE(review): the HTTP status code is not validated here —
        // presumably depth may be reported regardless of status; confirm.
        let queue_depth = if let Some(ref header_name) = self.config.header {
            extract_header_value(&response_str, header_name).and_then(|v| v.parse::<u64>().ok())
        } else if let Some(ref field) = self.config.body_field {
            if let Some(body_start) = response_str.find("\r\n\r\n") {
                let body = &response_str[body_start + 4..];
                extract_json_field(body, field).and_then(|v| v.parse::<u64>().ok())
            } else {
                None
            }
        } else {
            return Err("No queue depth source configured (header or body_field)".to_string());
        };

        let depth = queue_depth.ok_or_else(|| "Could not extract queue depth".to_string())?;

        // Unhealthy threshold is a hard failure
        if depth >= self.config.unhealthy_threshold {
            return Err(format!(
                "Queue depth {} exceeds unhealthy threshold {}",
                depth, self.config.unhealthy_threshold
            ));
        }

        // Degraded threshold is only a warning; the check still passes
        if depth >= self.config.degraded_threshold {
            warn!(
                target = %target,
                queue_depth = depth,
                threshold = self.config.degraded_threshold,
                "Queue depth exceeds degraded threshold"
            );
        }

        trace!(
            target = %target,
            queue_depth = depth,
            "Queue depth check passed"
        );

        Ok(start.elapsed())
    }

    fn check_type(&self) -> &str {
        "QueueDepth"
    }
}
1042
1043#[async_trait]
1044impl HealthCheckImpl for CompositeInferenceHealthCheck {
1045    async fn check(&self, target: &str) -> Result<Duration, String> {
1046        let start = Instant::now();
1047
1048        // Run base inference check first (always required)
1049        self.base_check.check(target).await?;
1050
1051        // Run optional sub-checks (all must pass)
1052        if let Some(ref probe) = self.inference_probe {
1053            probe.check(target).await?;
1054        }
1055
1056        if let Some(ref status) = self.model_status {
1057            status.check(target).await?;
1058        }
1059
1060        if let Some(ref queue) = self.queue_depth {
1061            queue.check(target).await?;
1062        }
1063
1064        trace!(
1065            target = %target,
1066            total_time_ms = start.elapsed().as_millis(),
1067            "Composite inference health check passed"
1068        );
1069
1070        Ok(start.elapsed())
1071    }
1072
1073    fn check_type(&self) -> &str {
1074        "CompositeInference"
1075    }
1076}
1077
/// Extract a header value from a raw HTTP response.
///
/// Matches the header name case-insensitively and returns the trimmed value
/// of the first matching header, or `None` if absent. Scanning stops at the
/// blank line separating headers from body, so body content can never be
/// mistaken for a header.
fn extract_header_value(response: &str, header_name: &str) -> Option<String> {
    for line in response.lines() {
        // `str::lines` strips a trailing `\r`, so an empty line (or a stray
        // bare carriage return) marks the end of the header section.
        if line.is_empty() || line == "\r" {
            break;
        }
        if let Some((name, value)) = line.split_once(':') {
            // HTTP header names are ASCII tokens; compare case-insensitively
            // without allocating lowercase copies on every line.
            if name.trim().eq_ignore_ascii_case(header_name) {
                return Some(value.trim().to_string());
            }
        }
    }
    None
}
1093
1094/// Extract a field from JSON body using dot notation (e.g., "status" or "state.loaded")
1095fn extract_json_field(body: &str, field_path: &str) -> Option<String> {
1096    let json: serde_json::Value = serde_json::from_str(body).ok()?;
1097    let parts: Vec<&str> = field_path.split('.').collect();
1098    let mut current = &json;
1099
1100    for part in parts {
1101        current = current.get(part)?;
1102    }
1103
1104    match current {
1105        serde_json::Value::String(s) => Some(s.clone()),
1106        serde_json::Value::Number(n) => Some(n.to_string()),
1107        serde_json::Value::Bool(b) => Some(b.to_string()),
1108        _ => None,
1109    }
1110}
1111
/// Parse the numeric status code out of an HTTP response's status line,
/// e.g. "HTTP/1.1 200 OK" -> 200. Returns `None` for malformed input.
fn parse_http_status(response: &str) -> Option<u16> {
    let status_line = response.lines().next()?;
    let mut tokens = status_line.split_whitespace();
    // Skip the "HTTP/x.y" version token; the second token is the code.
    let _version = tokens.next()?;
    tokens.next()?.parse().ok()
}
1122
/// Passive health checker that monitors request outcomes
///
/// Observes request success/failure rates to detect unhealthy targets
/// without performing explicit health checks. Works in combination with
/// `ActiveHealthChecker` for comprehensive health monitoring.
pub struct PassiveHealthChecker {
    /// Failure rate threshold (0.0 - 1.0); exceeding it escalates the target
    failure_rate_threshold: f64,
    /// Maximum number of recent outcomes kept per target (sliding window)
    window_size: usize,
    /// Recent request outcomes per target, oldest first; full windows evict
    /// from the front (ring-buffer semantics over a Vec)
    outcomes: Arc<RwLock<HashMap<String, Vec<bool>>>>,
    /// Most recent error message per target; cleared on success
    last_errors: Arc<RwLock<HashMap<String, String>>>,
    /// Optional active checker notified when a target crosses the threshold
    active_checker: Option<Arc<ActiveHealthChecker>>,
}
1140
1141impl PassiveHealthChecker {
1142    /// Create new passive health checker
1143    pub fn new(
1144        failure_rate_threshold: f64,
1145        window_size: usize,
1146        active_checker: Option<Arc<ActiveHealthChecker>>,
1147    ) -> Self {
1148        debug!(
1149            failure_rate_threshold = format!("{:.2}", failure_rate_threshold),
1150            window_size = window_size,
1151            has_active_checker = active_checker.is_some(),
1152            "Creating passive health checker"
1153        );
1154        Self {
1155            failure_rate_threshold,
1156            window_size,
1157            outcomes: Arc::new(RwLock::new(HashMap::new())),
1158            last_errors: Arc::new(RwLock::new(HashMap::new())),
1159            active_checker,
1160        }
1161    }
1162
1163    /// Record request outcome with optional error message
1164    pub async fn record_outcome(&self, target: &str, success: bool, error: Option<&str>) {
1165        trace!(
1166            target = %target,
1167            success = success,
1168            error = ?error,
1169            "Recording request outcome"
1170        );
1171
1172        // Track last error
1173        if let Some(err_msg) = error {
1174            self.last_errors
1175                .write()
1176                .await
1177                .insert(target.to_string(), err_msg.to_string());
1178        } else if success {
1179            // Clear last error on success
1180            self.last_errors.write().await.remove(target);
1181        }
1182
1183        let mut outcomes = self.outcomes.write().await;
1184        let target_outcomes = outcomes
1185            .entry(target.to_string())
1186            .or_insert_with(|| Vec::with_capacity(self.window_size));
1187
1188        // Add outcome to ring buffer
1189        if target_outcomes.len() >= self.window_size {
1190            target_outcomes.remove(0);
1191        }
1192        target_outcomes.push(success);
1193
1194        // Calculate failure rate
1195        let failures = target_outcomes.iter().filter(|&&s| !s).count();
1196        let failure_rate = failures as f64 / target_outcomes.len() as f64;
1197
1198        trace!(
1199            target = %target,
1200            failure_rate = format!("{:.2}", failure_rate),
1201            window_samples = target_outcomes.len(),
1202            failures = failures,
1203            "Updated failure rate"
1204        );
1205
1206        // Mark unhealthy if failure rate exceeds threshold
1207        if failure_rate > self.failure_rate_threshold {
1208            warn!(
1209                target = %target,
1210                failure_rate = format!("{:.2}", failure_rate * 100.0),
1211                threshold = format!("{:.2}", self.failure_rate_threshold * 100.0),
1212                window_samples = target_outcomes.len(),
1213                "Failure rate exceeds threshold"
1214            );
1215            if let Some(ref checker) = self.active_checker {
1216                checker
1217                    .mark_unhealthy(
1218                        target,
1219                        format!(
1220                            "Failure rate {:.2}% exceeds threshold",
1221                            failure_rate * 100.0
1222                        ),
1223                    )
1224                    .await;
1225            }
1226        }
1227    }
1228
1229    /// Get failure rate for a target
1230    pub async fn get_failure_rate(&self, target: &str) -> Option<f64> {
1231        let outcomes = self.outcomes.read().await;
1232        outcomes.get(target).map(|target_outcomes| {
1233            let failures = target_outcomes.iter().filter(|&&s| !s).count();
1234            failures as f64 / target_outcomes.len() as f64
1235        })
1236    }
1237
1238    /// Get last error for a target
1239    pub async fn get_last_error(&self, target: &str) -> Option<String> {
1240        self.last_errors.read().await.get(target).cloned()
1241    }
1242}
1243
1244// ============================================================================
1245// Warmth Tracker (Passive Cold Model Detection)
1246// ============================================================================
1247
1248use dashmap::DashMap;
1249use std::sync::atomic::{AtomicBool, AtomicU32, AtomicU64, Ordering};
1250use grapsus_common::{ColdModelAction, WarmthDetectionConfig};
1251
/// Warmth tracker for detecting cold models after idle periods
///
/// This is a passive tracker that observes actual request latency rather than
/// sending active probes. It tracks baseline latency per target and detects
/// when first-request latency after an idle period indicates a cold model.
pub struct WarmthTracker {
    /// Configuration for warmth detection (baseline sample size, cold
    /// threshold multiplier, idle timeout, and action on detection)
    config: WarmthDetectionConfig,
    /// Per-target warmth state, keyed by target address; DashMap allows
    /// concurrent access without an external lock
    targets: DashMap<String, TargetWarmthState>,
}
1263
/// Per-target warmth tracking state
///
/// All fields are atomics so the state can be updated from concurrent
/// request paths without additional locking.
struct TargetWarmthState {
    /// Baseline latency in milliseconds (simple mean while warming up,
    /// then EWMA — see `update_baseline`)
    baseline_latency_ms: AtomicU64,
    /// Number of samples collected for baseline
    sample_count: AtomicU32,
    /// Last request timestamp in millis since the Unix epoch (0 = never seen)
    last_request_ms: AtomicU64,
    /// Currently considered cold (set on detection, cleared by a normal request)
    is_cold: AtomicBool,
    /// Total cold starts detected (for metrics)
    cold_start_count: AtomicU64,
}
1277
1278impl TargetWarmthState {
1279    fn new() -> Self {
1280        Self {
1281            baseline_latency_ms: AtomicU64::new(0),
1282            sample_count: AtomicU32::new(0),
1283            last_request_ms: AtomicU64::new(0),
1284            is_cold: AtomicBool::new(false),
1285            cold_start_count: AtomicU64::new(0),
1286        }
1287    }
1288
1289    fn update_baseline(&self, latency_ms: u64, sample_size: u32) {
1290        let count = self.sample_count.fetch_add(1, Ordering::Relaxed);
1291        let current = self.baseline_latency_ms.load(Ordering::Relaxed);
1292
1293        if count < sample_size {
1294            // Building initial baseline - simple average
1295            let new_baseline = if count == 0 {
1296                latency_ms
1297            } else {
1298                (current * count as u64 + latency_ms) / (count as u64 + 1)
1299            };
1300            self.baseline_latency_ms
1301                .store(new_baseline, Ordering::Relaxed);
1302        } else {
1303            // EWMA update: new = alpha * sample + (1 - alpha) * old
1304            // Using alpha = 0.1 for smooth updates
1305            let alpha = 0.1_f64;
1306            let new_baseline = (alpha * latency_ms as f64 + (1.0 - alpha) * current as f64) as u64;
1307            self.baseline_latency_ms
1308                .store(new_baseline, Ordering::Relaxed);
1309        }
1310    }
1311}
1312
impl WarmthTracker {
    /// Create a new warmth tracker with the given configuration
    pub fn new(config: WarmthDetectionConfig) -> Self {
        Self {
            config,
            targets: DashMap::new(),
        }
    }

    /// Create a warmth tracker with default configuration:
    /// 10-sample baseline, 3x latency multiplier, 5-minute idle window,
    /// and log-only on detection.
    pub fn with_defaults() -> Self {
        Self::new(WarmthDetectionConfig {
            sample_size: 10,
            cold_threshold_multiplier: 3.0,
            idle_cold_timeout_secs: 300,
            cold_action: ColdModelAction::LogOnly,
        })
    }

    /// Record a completed request and detect cold starts
    ///
    /// A cold start is flagged when (a) the target was idle for at least
    /// `idle_cold_timeout_secs`, (b) a baseline exists, and (c) this
    /// request's latency exceeds `baseline * cold_threshold_multiplier`.
    /// A detected cold sample is excluded from the baseline (early return)
    /// so one spike cannot inflate it; a normal sample clears the cold flag
    /// and updates the baseline.
    ///
    /// Returns true if a cold start was detected
    pub fn record_request(&self, target: &str, latency: Duration) -> bool {
        // Wall-clock millis since the Unix epoch; falls back to 0 if the
        // clock reads pre-epoch (idle detection then simply won't trigger).
        let now_ms = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .map(|d| d.as_millis() as u64)
            .unwrap_or(0);

        let latency_ms = latency.as_millis() as u64;
        let idle_threshold_ms = self.config.idle_cold_timeout_secs * 1000;

        // Lazily create state on first sight of a target
        let state = self
            .targets
            .entry(target.to_string())
            .or_insert_with(TargetWarmthState::new);

        // last_request == 0 means "never seen": idle duration is 0, so the
        // first-ever request can't be classified as a cold start.
        let last_request = state.last_request_ms.load(Ordering::Relaxed);
        let idle_duration_ms = if last_request > 0 {
            now_ms.saturating_sub(last_request)
        } else {
            0
        };

        // Update last request time
        state.last_request_ms.store(now_ms, Ordering::Relaxed);

        // Check if this might be a cold start (first request after idle period)
        if idle_duration_ms >= idle_threshold_ms {
            let baseline = state.baseline_latency_ms.load(Ordering::Relaxed);

            // Only check if we have a baseline (0 = not yet established)
            if baseline > 0 {
                let threshold = (baseline as f64 * self.config.cold_threshold_multiplier) as u64;

                if latency_ms > threshold {
                    // Cold start detected! Return early so this outlier
                    // sample stays out of the baseline.
                    state.is_cold.store(true, Ordering::Release);
                    state.cold_start_count.fetch_add(1, Ordering::Relaxed);

                    warn!(
                        target = %target,
                        latency_ms = latency_ms,
                        baseline_ms = baseline,
                        threshold_ms = threshold,
                        idle_duration_secs = idle_duration_ms / 1000,
                        cold_action = ?self.config.cold_action,
                        "Cold model detected - latency spike after idle period"
                    );

                    return true;
                }
            }
        }

        // Normal request - update baseline and clear cold flag
        state.is_cold.store(false, Ordering::Release);
        state.update_baseline(latency_ms, self.config.sample_size);

        trace!(
            target = %target,
            latency_ms = latency_ms,
            baseline_ms = state.baseline_latency_ms.load(Ordering::Relaxed),
            sample_count = state.sample_count.load(Ordering::Relaxed),
            "Recorded request latency for warmth tracking"
        );

        false
    }

    /// Check if a target is currently considered cold
    /// (false for targets never observed)
    pub fn is_cold(&self, target: &str) -> bool {
        self.targets
            .get(target)
            .map(|s| s.is_cold.load(Ordering::Acquire))
            .unwrap_or(false)
    }

    /// Get the configured action for cold models
    pub fn cold_action(&self) -> ColdModelAction {
        self.config.cold_action
    }

    /// Get baseline latency for a target (in ms);
    /// `None` if the target has never been observed
    pub fn baseline_latency_ms(&self, target: &str) -> Option<u64> {
        self.targets
            .get(target)
            .map(|s| s.baseline_latency_ms.load(Ordering::Relaxed))
    }

    /// Get cold start count for a target (0 if never observed)
    pub fn cold_start_count(&self, target: &str) -> u64 {
        self.targets
            .get(target)
            .map(|s| s.cold_start_count.load(Ordering::Relaxed))
            .unwrap_or(0)
    }

    /// Check if warmth tracking should affect load balancing for this target
    ///
    /// Only true when the target is currently cold AND the configured action
    /// requests degradation; `LogOnly` never influences routing.
    pub fn should_deprioritize(&self, target: &str) -> bool {
        if !self.is_cold(target) {
            return false;
        }

        match self.config.cold_action {
            ColdModelAction::LogOnly => false,
            ColdModelAction::MarkDegraded | ColdModelAction::MarkUnhealthy => true,
        }
    }
}
1442
// Unit tests for passive health checking, HTTP parsing helpers, and the
// warmth tracker's cold-start detection.
#[cfg(test)]
mod tests {
    use super::*;

    // A freshly created health info starts healthy with a perfect score.
    #[tokio::test]
    async fn test_health_status() {
        let status = TargetHealthInfo::new();
        assert!(status.healthy);
        assert_eq!(status.health_score(), 1.0);
        assert!(!status.is_degraded());
    }

    // 3 failures out of 8 outcomes => failure rate 0.375, within (0.3, 0.4).
    #[tokio::test]
    async fn test_passive_health_checker() {
        let checker = PassiveHealthChecker::new(0.5, 10, None);

        // Record some outcomes
        for _ in 0..5 {
            checker.record_outcome("target1", true, None).await;
        }
        for _ in 0..3 {
            checker
                .record_outcome("target1", false, Some("HTTP 503"))
                .await;
        }

        let failure_rate = checker.get_failure_rate("target1").await.unwrap();
        assert!(failure_rate > 0.3 && failure_rate < 0.4);
    }

    // Status-line parsing: valid codes extracted, garbage yields None.
    #[test]
    fn test_parse_http_status() {
        let response = "HTTP/1.1 200 OK\r\n";
        assert_eq!(parse_http_status(response), Some(200));

        let response = "HTTP/1.1 404 Not Found\r\n";
        assert_eq!(parse_http_status(response), Some(404));

        let response = "Invalid response";
        assert_eq!(parse_http_status(response), None);
    }

    // Steady 100ms samples should settle the baseline in (0, 100].
    #[test]
    fn test_warmth_tracker_baseline() {
        let tracker = WarmthTracker::with_defaults();

        // First few requests should build baseline
        for i in 0..10 {
            let cold = tracker.record_request("target1", Duration::from_millis(100));
            assert!(!cold, "Should not detect cold on request {}", i);
        }

        // Check baseline was established
        let baseline = tracker.baseline_latency_ms("target1");
        assert!(baseline.is_some());
        assert!(baseline.unwrap() > 0 && baseline.unwrap() <= 100);
    }

    // With a zero idle timeout, a 3x latency spike (above the 2x threshold)
    // must be flagged as a cold start and counted.
    #[test]
    fn test_warmth_tracker_cold_detection() {
        let config = WarmthDetectionConfig {
            sample_size: 5,
            cold_threshold_multiplier: 2.0,
            idle_cold_timeout_secs: 0, // Immediate idle for testing
            cold_action: ColdModelAction::MarkDegraded,
        };
        let tracker = WarmthTracker::new(config);

        // Build baseline with 100ms latency
        for _ in 0..5 {
            tracker.record_request("target1", Duration::from_millis(100));
        }

        // Wait a tiny bit to simulate idle
        std::thread::sleep(Duration::from_millis(10));

        // Next request with 3x latency (> 2x threshold) should detect cold
        let cold = tracker.record_request("target1", Duration::from_millis(300));
        assert!(cold, "Should detect cold start");
        assert!(tracker.is_cold("target1"));
        assert_eq!(tracker.cold_start_count("target1"), 1);
    }

    // A 1.5x spike stays below the 3x multiplier, so no cold start.
    #[test]
    fn test_warmth_tracker_no_cold_on_normal_latency() {
        let config = WarmthDetectionConfig {
            sample_size: 5,
            cold_threshold_multiplier: 3.0,
            idle_cold_timeout_secs: 0,
            cold_action: ColdModelAction::LogOnly,
        };
        let tracker = WarmthTracker::new(config);

        // Build baseline
        for _ in 0..5 {
            tracker.record_request("target1", Duration::from_millis(100));
        }

        std::thread::sleep(Duration::from_millis(10));

        // Request with only 1.5x latency (< 3x threshold) should not detect cold
        let cold = tracker.record_request("target1", Duration::from_millis(150));
        assert!(!cold, "Should not detect cold for normal variation");
        assert!(!tracker.is_cold("target1"));
    }

    // MarkDegraded means a cold target is deprioritized until a normal
    // request clears the cold flag again.
    #[test]
    fn test_warmth_tracker_deprioritize() {
        let config = WarmthDetectionConfig {
            sample_size: 2,
            cold_threshold_multiplier: 2.0,
            idle_cold_timeout_secs: 0,
            cold_action: ColdModelAction::MarkDegraded,
        };
        let tracker = WarmthTracker::new(config);

        // Build baseline and trigger cold
        tracker.record_request("target1", Duration::from_millis(100));
        tracker.record_request("target1", Duration::from_millis(100));
        std::thread::sleep(Duration::from_millis(10));
        tracker.record_request("target1", Duration::from_millis(300));

        // Should deprioritize when cold and action is MarkDegraded
        assert!(tracker.should_deprioritize("target1"));

        // New normal request clears cold flag
        tracker.record_request("target1", Duration::from_millis(100));
        assert!(!tracker.should_deprioritize("target1"));
    }
}