//! Active health checking for upstream backends, built on Pingora's
//! health-check primitives (HTTP, TCP) plus custom gRPC and inference checks.

use pingora_load_balancing::{
    discovery::Static,
    health_check::{HealthCheck as PingoraHealthCheck, HttpHealthCheck, TcpHealthCheck},
    Backend, Backends,
};
use std::collections::BTreeSet;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::RwLock;
use tracing::{debug, info, trace, warn};

use crate::grpc_health::GrpcHealthCheck;
use crate::upstream::inference_health::InferenceHealthCheck;

use sentinel_common::types::HealthCheckType;
use sentinel_config::{HealthCheck as HealthCheckConfig, UpstreamConfig};

/// Performs active health checks against the backends of a single upstream.
pub struct ActiveHealthChecker {
    /// ID of the upstream this checker monitors.
    upstream_id: String,
    /// Pingora backend collection with the health check attached.
    backends: Arc<Backends>,
    /// How often a health check cycle runs.
    interval: Duration,
    /// Whether backends are checked in parallel within a cycle.
    parallel: bool,
    /// Optional callback for reporting backend health transitions.
    health_callback: Arc<RwLock<Option<HealthChangeCallback>>>,
}

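/// Callback for backend health transitions: receives the backend address and
/// the new health state (`true` = healthy).
///
/// A minimal sketch of registering one (the logging sink is illustrative):
///
/// ```ignore
/// checker
///     .set_health_callback(Box::new(|addr, healthy| {
///         tracing::info!(backend = %addr, healthy, "backend health changed");
///     }))
///     .await;
/// ```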
pub type HealthChangeCallback = Box<dyn Fn(&str, bool) + Send + Sync>;

impl ActiveHealthChecker {
    /// Builds a checker from an upstream config.
    ///
    /// Returns `None` if the config has no health check section or if no
    /// valid backends could be created from its targets.
    pub fn new(config: &UpstreamConfig) -> Option<Self> {
        let health_config = config.health_check.as_ref()?;

        info!(
            upstream_id = %config.id,
            check_type = ?health_config.check_type,
            interval_secs = health_config.interval_secs,
            "Creating active health checker"
        );

        let mut backend_set = BTreeSet::new();
        for target in &config.targets {
            match Backend::new_with_weight(&target.address, target.weight as usize) {
                Ok(backend) => {
                    debug!(
                        upstream_id = %config.id,
                        target = %target.address,
                        weight = target.weight,
                        "Added backend for health checking"
                    );
                    backend_set.insert(backend);
                }
                Err(e) => {
                    warn!(
                        upstream_id = %config.id,
                        target = %target.address,
                        error = %e,
                        "Failed to create backend for health checking"
                    );
                }
            }
        }

        if backend_set.is_empty() {
            warn!(
                upstream_id = %config.id,
                "No backends created for health checking"
            );
            return None;
        }

        let discovery = Static::new(backend_set);
        let mut backends = Backends::new(discovery);

        let health_check: Box<dyn PingoraHealthCheck + Send + Sync> =
            Self::create_health_check(health_config, &config.id);

        backends.set_health_check(health_check);

        Some(Self {
            upstream_id: config.id.clone(),
            backends: Arc::new(backends),
            interval: Duration::from_secs(health_config.interval_secs),
            parallel: true,
            health_callback: Arc::new(RwLock::new(None)),
        })
    }

    /// Instantiates the Pingora health check matching the configured type.
    fn create_health_check(
        config: &HealthCheckConfig,
        upstream_id: &str,
    ) -> Box<dyn PingoraHealthCheck + Send + Sync> {
        match &config.check_type {
            HealthCheckType::Http {
                path,
                expected_status,
                host,
            } => {
                let hostname = host.as_deref().unwrap_or("localhost");
                let mut hc = HttpHealthCheck::new(hostname, false);

                hc.consecutive_success = config.healthy_threshold as usize;
                hc.consecutive_failure = config.unhealthy_threshold as usize;

                // Replace the default probe request when a custom path is
                // configured. `expected_status` is logged below but not
                // enforced here; see the note after this arm.
                if path != "/" {
                    if let Ok(req) =
                        pingora_http::RequestHeader::build("GET", path.as_bytes(), None)
                    {
                        hc.req = req;
                    }
                }

                debug!(
                    upstream_id = %upstream_id,
                    path = %path,
                    expected_status = expected_status,
                    host = hostname,
                    consecutive_success = hc.consecutive_success,
                    consecutive_failure = hc.consecutive_failure,
                    "Created HTTP health check"
                );

                Box::new(hc)
            }
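            // Sketch (assumption, kept commented out): enforcing a non-200
            // `expected_status` would mean installing a response validator on
            // the Pingora check. Roughly, with the exact error construction
            // depending on the pingora-error version in use:
            //
            //     let expected = *expected_status;
            //     hc.validator = Some(Box::new(move |resp| {
            //         if resp.status.as_u16() == expected {
            //             Ok(())
            //         } else {
            //             pingora_error::Error::e_explain(
            //                 pingora_error::ErrorType::Custom("unexpected health status"),
            //                 format!("expected {expected}, got {}", resp.status),
            //             )
            //         }
            //     }));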
            HealthCheckType::Tcp => {
                // `TcpHealthCheck::new()` already returns a boxed check, so
                // this arm needs no extra `Box::new`.
                let mut hc = TcpHealthCheck::new();
                hc.consecutive_success = config.healthy_threshold as usize;
                hc.consecutive_failure = config.unhealthy_threshold as usize;

                debug!(
                    upstream_id = %upstream_id,
                    consecutive_success = hc.consecutive_success,
                    consecutive_failure = hc.consecutive_failure,
                    "Created TCP health check"
                );

                hc
            }
            HealthCheckType::Grpc { service } => {
                let timeout = Duration::from_secs(config.timeout_secs);
                let mut hc = GrpcHealthCheck::new(service.clone(), timeout);
                hc.consecutive_success = config.healthy_threshold as usize;
                hc.consecutive_failure = config.unhealthy_threshold as usize;

                info!(
                    upstream_id = %upstream_id,
                    service = %service,
                    timeout_secs = config.timeout_secs,
                    consecutive_success = hc.consecutive_success,
                    consecutive_failure = hc.consecutive_failure,
                    "Created gRPC health check"
                );

                Box::new(hc)
            }
            HealthCheckType::Inference {
                endpoint,
                expected_models,
                readiness: _,
            } => {
                let timeout = Duration::from_secs(config.timeout_secs);
                let mut hc = InferenceHealthCheck::new(
                    endpoint.clone(),
                    expected_models.clone(),
                    timeout,
                );
                hc.consecutive_success = config.healthy_threshold as usize;
                hc.consecutive_failure = config.unhealthy_threshold as usize;

                info!(
                    upstream_id = %upstream_id,
                    endpoint = %endpoint,
                    expected_models = ?expected_models,
                    timeout_secs = config.timeout_secs,
                    consecutive_success = hc.consecutive_success,
                    consecutive_failure = hc.consecutive_failure,
                    "Created inference health check with model verification"
                );

                Box::new(hc)
            }
        }
    }

    /// Registers a callback for backend health transitions.
    pub async fn set_health_callback(&self, callback: HealthChangeCallback) {
        *self.health_callback.write().await = Some(callback);
    }

    /// Runs one health check cycle over all backends.
    pub async fn run_health_check(&self) {
        trace!(
            upstream_id = %self.upstream_id,
            parallel = self.parallel,
            "Running health check cycle"
        );

        self.backends.run_health_check(self.parallel).await;
    }

    /// Returns whether the backend at `address` is currently healthy.
    pub fn is_backend_healthy(&self, address: &str) -> bool {
        let backends = self.backends.get_backend();
        for backend in backends.iter() {
            if backend.addr.to_string() == address {
                return self.backends.ready(backend);
            }
        }
        // Addresses the checker does not track are assumed healthy so that
        // unchecked backends are not excluded from load balancing.
        true
    }

    /// Returns `(address, healthy)` for every tracked backend.
    pub fn get_health_statuses(&self) -> Vec<(String, bool)> {
        let backends = self.backends.get_backend();
        backends
            .iter()
            .map(|b| {
                let addr = b.addr.to_string();
                let healthy = self.backends.ready(b);
                (addr, healthy)
            })
            .collect()
    }

    /// The configured check interval.
    pub fn interval(&self) -> Duration {
        self.interval
    }

    /// ID of the upstream this checker monitors.
    pub fn upstream_id(&self) -> &str {
        &self.upstream_id
    }
}

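/// Owns a set of [`ActiveHealthChecker`]s and drives them on a shared timer.
///
/// A minimal wiring sketch (assumes the caller has a tokio runtime and an
/// `UpstreamConfig` in hand; names here are illustrative):
///
/// ```ignore
/// let mut runner = HealthCheckRunner::new();
/// if let Some(checker) = ActiveHealthChecker::new(&upstream_config) {
///     runner.add_checker(checker);
/// }
/// let runner = std::sync::Arc::new(runner);
/// let handle = {
///     let runner = runner.clone();
///     tokio::spawn(async move { runner.run().await })
/// };
/// // ... later, to shut down:
/// runner.stop().await;
/// ```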
pub struct HealthCheckRunner {
    /// One checker per upstream with active health checking enabled.
    checkers: Vec<ActiveHealthChecker>,
    /// Set to `false` by `stop()` to end the run loop.
    running: Arc<RwLock<bool>>,
}

impl HealthCheckRunner {
    pub fn new() -> Self {
        Self {
            checkers: Vec::new(),
            running: Arc::new(RwLock::new(false)),
        }
    }

    /// Registers a checker to be driven by the run loop.
    pub fn add_checker(&mut self, checker: ActiveHealthChecker) {
        info!(
            upstream_id = %checker.upstream_id,
            interval_secs = checker.interval.as_secs(),
            "Added health checker to runner"
        );
        self.checkers.push(checker);
    }

    pub fn checker_count(&self) -> usize {
        self.checkers.len()
    }

    /// Runs health checks until `stop()` is called.
    ///
    /// Note that all checkers share a single timer set to the shortest
    /// configured interval, so checkers with longer intervals run more
    /// frequently than configured.
    pub async fn run(&self) {
        if self.checkers.is_empty() {
            info!("No health checkers configured, skipping health check loop");
            return;
        }

        *self.running.write().await = true;

        info!(
            checker_count = self.checkers.len(),
            "Starting health check runner"
        );

        let min_interval = self
            .checkers
            .iter()
            .map(|c| c.interval)
            .min()
            .unwrap_or(Duration::from_secs(10));

        let mut interval = tokio::time::interval(min_interval);
        interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);

        loop {
            interval.tick().await;

            if !*self.running.read().await {
                info!("Health check runner stopped");
                break;
            }

            for checker in &self.checkers {
                checker.run_health_check().await;

                let statuses = checker.get_health_statuses();
                for (addr, healthy) in &statuses {
                    trace!(
                        upstream_id = %checker.upstream_id,
                        backend = %addr,
                        healthy = healthy,
                        "Backend health status"
                    );
                }
            }
        }
    }

    /// Signals the run loop to exit; takes effect at the next tick.
    pub async fn stop(&self) {
        info!("Stopping health check runner");
        *self.running.write().await = false;
    }

    /// Health of a single backend, or `None` if the upstream is unknown.
    pub fn get_health(&self, upstream_id: &str, address: &str) -> Option<bool> {
        self.checkers
            .iter()
            .find(|c| c.upstream_id == upstream_id)
            .map(|c| c.is_backend_healthy(address))
    }

    /// Health of every backend of an upstream, or `None` if unknown.
    pub fn get_upstream_health(&self, upstream_id: &str) -> Option<Vec<(String, bool)>> {
        self.checkers
            .iter()
            .find(|c| c.upstream_id == upstream_id)
            .map(|c| c.get_health_statuses())
    }
}

impl Default for HealthCheckRunner {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use sentinel_common::types::LoadBalancingAlgorithm;
    use sentinel_config::{
        ConnectionPoolConfig, HttpVersionConfig, UpstreamTarget, UpstreamTimeouts,
    };
    use std::collections::HashMap;
    use std::sync::Once;

    static INIT: Once = Once::new();

    fn init_crypto_provider() {
        INIT.call_once(|| {
            let _ = rustls::crypto::aws_lc_rs::default_provider().install_default();
        });
    }

    fn create_test_config() -> UpstreamConfig {
        UpstreamConfig {
            id: "test-upstream".to_string(),
            targets: vec![UpstreamTarget {
                address: "127.0.0.1:8081".to_string(),
                weight: 1,
                max_requests: None,
                metadata: HashMap::new(),
            }],
            load_balancing: LoadBalancingAlgorithm::RoundRobin,
            health_check: Some(HealthCheckConfig {
                check_type: HealthCheckType::Http {
                    path: "/health".to_string(),
                    expected_status: 200,
                    host: None,
                },
                interval_secs: 5,
                timeout_secs: 2,
                healthy_threshold: 2,
                unhealthy_threshold: 3,
            }),
            connection_pool: ConnectionPoolConfig::default(),
            timeouts: UpstreamTimeouts::default(),
            tls: None,
            http_version: HttpVersionConfig::default(),
        }
    }

    #[test]
    fn test_create_health_checker() {
        init_crypto_provider();
        let config = create_test_config();
        let checker = ActiveHealthChecker::new(&config);
        assert!(checker.is_some());

        let checker = checker.unwrap();
        assert_eq!(checker.upstream_id, "test-upstream");
        assert_eq!(checker.interval, Duration::from_secs(5));
    }

    #[test]
    fn test_no_health_check_config() {
        let mut config = create_test_config();
        config.health_check = None;

        let checker = ActiveHealthChecker::new(&config);
        assert!(checker.is_none());
    }

    #[test]
    fn test_health_check_runner() {
        init_crypto_provider();
        let mut runner = HealthCheckRunner::new();
        assert_eq!(runner.checker_count(), 0);

        let config = create_test_config();
        if let Some(checker) = ActiveHealthChecker::new(&config) {
            runner.add_checker(checker);
            assert_eq!(runner.checker_count(), 1);
        }
    }
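
    // Added coverage: lookups against an upstream id the runner does not
    // track should return `None` from both accessors.
    #[test]
    fn test_health_lookup_unknown_upstream() {
        init_crypto_provider();
        let mut runner = HealthCheckRunner::new();
        let config = create_test_config();
        if let Some(checker) = ActiveHealthChecker::new(&config) {
            runner.add_checker(checker);
        }

        assert!(runner
            .get_health("unknown-upstream", "127.0.0.1:8081")
            .is_none());
        assert!(runner.get_upstream_health("unknown-upstream").is_none());
    }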
}