sentinel_proxy/upstream/
health.rs

//! Active health checking using Pingora's HttpHealthCheck
//!
//! This module provides active health probing for upstream backends using
//! Pingora's built-in health check infrastructure. It complements the passive
//! health tracking in load balancers by periodically probing backends.
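//!
//! A minimal usage sketch (marked `ignore`; it assumes a Tokio runtime and a
//! populated `UpstreamConfig` named `config`):
//!
//! ```ignore
//! let mut runner = HealthCheckRunner::new();
//! if let Some(checker) = ActiveHealthChecker::new(&config) {
//!     runner.add_checker(checker);
//! }
//! // Share the runner so one task can drive the loop while another stops it.
//! let runner = std::sync::Arc::new(runner);
//! let bg = runner.clone();
//! tokio::spawn(async move { bg.run().await });
//! // ... later, during shutdown:
//! // runner.stop().await;
//! ```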

use pingora_load_balancing::{
    discovery::Static,
    health_check::{HealthCheck as PingoraHealthCheck, HttpHealthCheck, TcpHealthCheck},
    Backend, Backends,
};
use std::collections::BTreeSet;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::RwLock;
use tracing::{debug, info, trace, warn};

use crate::grpc_health::GrpcHealthCheck;
use crate::upstream::inference_health::InferenceHealthCheck;

use sentinel_common::types::HealthCheckType;
use sentinel_config::{HealthCheck as HealthCheckConfig, UpstreamConfig};

/// Active health checker for an upstream pool
///
/// This wraps Pingora's `Backends` struct with health checking enabled.
/// It runs periodic health probes and reports status back to the load balancer.
pub struct ActiveHealthChecker {
    /// Upstream ID
    upstream_id: String,
    /// Pingora backends with health checking
    backends: Arc<Backends>,
    /// Health check interval
    interval: Duration,
    /// Whether to run checks in parallel
    parallel: bool,
    /// Callback to notify load balancer of health changes
    health_callback: Arc<RwLock<Option<HealthChangeCallback>>>,
}

/// Callback type for health status changes
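///
/// For example, a callback that forwards status changes into tracing
/// (illustrative; any `Fn(&str, bool) + Send + Sync` works):
///
/// ```ignore
/// let callback: HealthChangeCallback = Box::new(|addr, healthy| {
///     tracing::info!(backend = %addr, healthy, "backend health changed");
/// });
/// ```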
pub type HealthChangeCallback = Box<dyn Fn(&str, bool) + Send + Sync>;

impl ActiveHealthChecker {
    /// Create a new active health checker from upstream config
    pub fn new(config: &UpstreamConfig) -> Option<Self> {
        let health_config = config.health_check.as_ref()?;

        info!(
            upstream_id = %config.id,
            check_type = ?health_config.check_type,
            interval_secs = health_config.interval_secs,
            "Creating active health checker"
        );

        // Create backends from targets
        let mut backend_set = BTreeSet::new();
        for target in &config.targets {
            match Backend::new_with_weight(&target.address, target.weight as usize) {
                Ok(backend) => {
                    debug!(
                        upstream_id = %config.id,
                        target = %target.address,
                        weight = target.weight,
                        "Added backend for health checking"
                    );
                    backend_set.insert(backend);
                }
                Err(e) => {
                    warn!(
                        upstream_id = %config.id,
                        target = %target.address,
                        error = %e,
                        "Failed to create backend for health checking"
                    );
                }
            }
        }

        if backend_set.is_empty() {
            warn!(
                upstream_id = %config.id,
                "No backends created for health checking"
            );
            return None;
        }

        // Create static discovery (Static::new returns Box<Self>)
        let discovery = Static::new(backend_set);
        let mut backends = Backends::new(discovery);
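        // Note: Pingora populates `Backends` from its discovery via an async
        // `update()` call; a freshly constructed instance starts empty. If
        // probes or `get_backend()` ever see no targets, the discovery refresh
        // likely has not run yet (assumption; worth verifying against the
        // Pingora version in use).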

        // Create and configure health check
        let health_check: Box<dyn PingoraHealthCheck + Send + Sync> =
            Self::create_health_check(health_config, &config.id);

        backends.set_health_check(health_check);

        Some(Self {
            upstream_id: config.id.clone(),
            backends: Arc::new(backends),
            interval: Duration::from_secs(health_config.interval_secs),
            parallel: true,
            health_callback: Arc::new(RwLock::new(None)),
        })
    }

    /// Create the appropriate health check based on config
    fn create_health_check(
        config: &HealthCheckConfig,
        upstream_id: &str,
    ) -> Box<dyn PingoraHealthCheck + Send + Sync> {
        match &config.check_type {
            HealthCheckType::Http {
                path,
                expected_status,
                host,
            } => {
                let hostname = host.as_deref().unwrap_or("localhost");
                let mut hc = HttpHealthCheck::new(hostname, false);

                // Configure thresholds
                hc.consecutive_success = config.healthy_threshold as usize;
                hc.consecutive_failure = config.unhealthy_threshold as usize;

                // Configure the request path.
                // HttpHealthCheck sends `GET /` by default; for non-root paths
                // we replace `hc.req` with a custom request header.
                if path != "/" {
                    if let Ok(mut req) =
                        pingora_http::RequestHeader::build("GET", path.as_bytes(), None)
                    {
                        // Rebuilding the request drops the Host header that
                        // HttpHealthCheck::new installed, so re-add it here.
                        let _ = req.append_header("Host", hostname);
                        hc.req = req;
                    }
                }
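                // Note: `expected_status` is recorded in the log below but not
                // enforced here; Pingora's HttpHealthCheck treats only HTTP 200
                // as healthy unless a custom `validator` is installed
                // (assumption based on Pingora's default behavior).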

                // Note: health_changed_callback requires implementing HealthObserve trait
                // We use polling via run_health_check() and get_health_statuses() instead

                debug!(
                    upstream_id = %upstream_id,
                    path = %path,
                    expected_status = expected_status,
                    host = hostname,
                    consecutive_success = hc.consecutive_success,
                    consecutive_failure = hc.consecutive_failure,
                    "Created HTTP health check"
                );

                Box::new(hc)
            }
            HealthCheckType::Tcp => {
                // TcpHealthCheck::new() returns Box<Self>
                let mut hc = TcpHealthCheck::new();
                hc.consecutive_success = config.healthy_threshold as usize;
                hc.consecutive_failure = config.unhealthy_threshold as usize;

                debug!(
                    upstream_id = %upstream_id,
                    consecutive_success = hc.consecutive_success,
                    consecutive_failure = hc.consecutive_failure,
                    "Created TCP health check"
                );

                hc
            }
            HealthCheckType::Grpc { service } => {
                let timeout = Duration::from_secs(config.timeout_secs);
                let mut hc = GrpcHealthCheck::new(service.clone(), timeout);
                hc.consecutive_success = config.healthy_threshold as usize;
                hc.consecutive_failure = config.unhealthy_threshold as usize;

                info!(
                    upstream_id = %upstream_id,
                    service = %service,
                    timeout_secs = config.timeout_secs,
                    consecutive_success = hc.consecutive_success,
                    consecutive_failure = hc.consecutive_failure,
                    "Created gRPC health check"
                );

                Box::new(hc)
            }
            HealthCheckType::Inference {
                endpoint,
                expected_models,
                readiness: _,
            } => {
                // Inference health check that verifies expected models are available
                let timeout = Duration::from_secs(config.timeout_secs);
                let mut hc = InferenceHealthCheck::new(
                    endpoint.clone(),
                    expected_models.clone(),
                    timeout,
                );
                hc.consecutive_success = config.healthy_threshold as usize;
                hc.consecutive_failure = config.unhealthy_threshold as usize;

                info!(
                    upstream_id = %upstream_id,
                    endpoint = %endpoint,
                    expected_models = ?expected_models,
                    timeout_secs = config.timeout_secs,
                    consecutive_success = hc.consecutive_success,
                    consecutive_failure = hc.consecutive_failure,
                    "Created inference health check with model verification"
                );

                Box::new(hc)
            }
        }
    }

    /// Set callback for health status changes
    pub async fn set_health_callback(&self, callback: HealthChangeCallback) {
        *self.health_callback.write().await = Some(callback);
    }

    /// Run a single health check cycle
    pub async fn run_health_check(&self) {
        trace!(
            upstream_id = %self.upstream_id,
            parallel = self.parallel,
            "Running health check cycle"
        );

        self.backends.run_health_check(self.parallel).await;
    }

    /// Check if a specific backend is healthy
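    ///
    /// `address` is matched against each backend's `Display` form (e.g.
    /// `"127.0.0.1:8081"`). Addresses this checker does not track are
    /// reported as healthy so they are not spuriously dropped.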
    pub fn is_backend_healthy(&self, address: &str) -> bool {
        let backends = self.backends.get_backend();
        for backend in backends.iter() {
            if backend.addr.to_string() == address {
                return self.backends.ready(backend);
            }
        }
        // Unknown backend: we have no probe data for it, so assume healthy
        // rather than blocking traffic to an untracked address.
        true
    }

    /// Get all backend health statuses
    pub fn get_health_statuses(&self) -> Vec<(String, bool)> {
        let backends = self.backends.get_backend();
        backends
            .iter()
            .map(|b| {
                let addr = b.addr.to_string();
                let healthy = self.backends.ready(b);
                (addr, healthy)
            })
            .collect()
    }

    /// Get the health check interval
    pub fn interval(&self) -> Duration {
        self.interval
    }

    /// Get the upstream ID
    pub fn upstream_id(&self) -> &str {
        &self.upstream_id
    }
}

/// Health check runner that manages multiple upstream health checkers
pub struct HealthCheckRunner {
    /// Health checkers per upstream
    checkers: Vec<ActiveHealthChecker>,
    /// Whether the runner is active
    running: Arc<RwLock<bool>>,
}

impl HealthCheckRunner {
    /// Create a new health check runner
    pub fn new() -> Self {
        Self {
            checkers: Vec::new(),
            running: Arc::new(RwLock::new(false)),
        }
    }

    /// Add a health checker for an upstream
    pub fn add_checker(&mut self, checker: ActiveHealthChecker) {
        info!(
            upstream_id = %checker.upstream_id,
            interval_secs = checker.interval.as_secs(),
            "Added health checker to runner"
        );
        self.checkers.push(checker);
    }

    /// Get the number of health checkers
    pub fn checker_count(&self) -> usize {
        self.checkers.len()
    }

    /// Start the health check loop (runs until stopped)
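    ///
    /// Because `run` and `stop` both take `&self`, sharing the runner through
    /// an `Arc` lets one task drive the loop while another stops it
    /// (illustrative sketch; the loop notices the stop flag at its next tick):
    ///
    /// ```ignore
    /// let runner = Arc::new(runner);
    /// let bg = runner.clone();
    /// let handle = tokio::spawn(async move { bg.run().await });
    /// // ... later:
    /// runner.stop().await;
    /// handle.await.unwrap();
    /// ```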
    pub async fn run(&self) {
        if self.checkers.is_empty() {
            info!("No health checkers configured, skipping health check loop");
            return;
        }

        *self.running.write().await = true;

        info!(
            checker_count = self.checkers.len(),
            "Starting health check runner"
        );

        // Find minimum interval
        let min_interval = self
            .checkers
            .iter()
            .map(|c| c.interval)
            .min()
            .unwrap_or(Duration::from_secs(10));
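        // Note: a single timer at the minimum interval drives every checker,
        // so upstreams configured with longer intervals are probed more often
        // than requested; honoring each configured interval exactly would need
        // one timer per checker.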

        let mut interval = tokio::time::interval(min_interval);
        interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);

        loop {
            interval.tick().await;

            if !*self.running.read().await {
                info!("Health check runner stopped");
                break;
            }

            // Run health checks for all upstreams
            for checker in &self.checkers {
                checker.run_health_check().await;

                // Log current health statuses
                let statuses = checker.get_health_statuses();
                for (addr, healthy) in &statuses {
                    trace!(
                        upstream_id = %checker.upstream_id,
                        backend = %addr,
                        healthy = healthy,
                        "Backend health status"
                    );
                }
            }
        }
    }

    /// Stop the health check loop
    pub async fn stop(&self) {
        info!("Stopping health check runner");
        *self.running.write().await = false;
    }

    /// Get health status for a specific upstream and backend
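    ///
    /// Returns `None` when no checker is registered for `upstream_id`.
    /// Illustrative lookup (hypothetical IDs):
    ///
    /// ```ignore
    /// if let Some(healthy) = runner.get_health("api", "127.0.0.1:8081") {
    ///     println!("actively probed: healthy = {healthy}");
    /// }
    /// ```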
    pub fn get_health(&self, upstream_id: &str, address: &str) -> Option<bool> {
        self.checkers
            .iter()
            .find(|c| c.upstream_id == upstream_id)
            .map(|c| c.is_backend_healthy(address))
    }

    /// Get all health statuses for an upstream
    pub fn get_upstream_health(&self, upstream_id: &str) -> Option<Vec<(String, bool)>> {
        self.checkers
            .iter()
            .find(|c| c.upstream_id == upstream_id)
            .map(|c| c.get_health_statuses())
    }
}

impl Default for HealthCheckRunner {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use sentinel_common::types::LoadBalancingAlgorithm;
    use sentinel_config::{
        ConnectionPoolConfig, HttpVersionConfig, UpstreamTarget, UpstreamTimeouts,
    };
    use std::collections::HashMap;
    use std::sync::Once;

    static INIT: Once = Once::new();

    fn init_crypto_provider() {
        INIT.call_once(|| {
            let _ = rustls::crypto::aws_lc_rs::default_provider().install_default();
        });
    }

    fn create_test_config() -> UpstreamConfig {
        UpstreamConfig {
            id: "test-upstream".to_string(),
            targets: vec![UpstreamTarget {
                address: "127.0.0.1:8081".to_string(),
                weight: 1,
                max_requests: None,
                metadata: HashMap::new(),
            }],
            load_balancing: LoadBalancingAlgorithm::RoundRobin,
            health_check: Some(HealthCheckConfig {
                check_type: HealthCheckType::Http {
                    path: "/health".to_string(),
                    expected_status: 200,
                    host: None,
                },
                interval_secs: 5,
                timeout_secs: 2,
                healthy_threshold: 2,
                unhealthy_threshold: 3,
            }),
            connection_pool: ConnectionPoolConfig::default(),
            timeouts: UpstreamTimeouts::default(),
            tls: None,
            http_version: HttpVersionConfig::default(),
        }
    }

    #[test]
    fn test_create_health_checker() {
        init_crypto_provider();
        let config = create_test_config();
        let checker = ActiveHealthChecker::new(&config);
        assert!(checker.is_some());

        let checker = checker.unwrap();
        assert_eq!(checker.upstream_id, "test-upstream");
        assert_eq!(checker.interval, Duration::from_secs(5));
    }

    #[test]
    fn test_no_health_check_config() {
        let mut config = create_test_config();
        config.health_check = None;

        let checker = ActiveHealthChecker::new(&config);
        assert!(checker.is_none());
    }

    #[test]
    fn test_health_check_runner() {
        init_crypto_provider();
        let mut runner = HealthCheckRunner::new();
        assert_eq!(runner.checker_count(), 0);

        let config = create_test_config();
        if let Some(checker) = ActiveHealthChecker::new(&config) {
            runner.add_checker(checker);
            assert_eq!(runner.checker_count(), 1);
        }
    }
}
457}