1use crate::Backend;
18use arc_swap::ArcSwap;
19use async_trait::async_trait;
20use pingora_core::connectors::{http::Connector as HttpConnector, TransportConnector};
21use pingora_core::upstreams::peer::{BasicPeer, HttpPeer, Peer};
22use pingora_error::{Error, ErrorType::CustomCode, Result};
23use pingora_http::{RequestHeader, ResponseHeader};
24use std::sync::Arc;
25use std::time::Duration;
26
27#[async_trait]
30pub trait HealthObserve {
31 async fn observe(&self, target: &Backend, healthy: bool);
33}
34pub type HealthObserveCallback = Box<dyn HealthObserve + Send + Sync>;
36
37pub type BackendSummary = Box<dyn Fn(&Backend) -> String + Send + Sync>;
39
40#[async_trait]
42pub trait HealthCheck {
43 async fn check(&self, target: &Backend) -> Result<()>;
47
48 async fn health_status_change(&self, _target: &Backend, _healthy: bool) {}
50
51 fn backend_summary(&self, target: &Backend) -> String {
53 format!("{target:?}")
54 }
55
56 fn health_threshold(&self, success: bool) -> usize;
61}
62
63pub struct TcpHealthCheck {
67 pub consecutive_success: usize,
69 pub consecutive_failure: usize,
71 pub peer_template: BasicPeer,
80 connector: TransportConnector,
81 pub health_changed_callback: Option<HealthObserveCallback>,
83}
84
85impl Default for TcpHealthCheck {
86 fn default() -> Self {
87 let mut peer_template = BasicPeer::new("0.0.0.0:1");
88 peer_template.options.connection_timeout = Some(Duration::from_secs(1));
89 TcpHealthCheck {
90 consecutive_success: 1,
91 consecutive_failure: 1,
92 peer_template,
93 connector: TransportConnector::new(None),
94 health_changed_callback: None,
95 }
96 }
97}
98
99impl TcpHealthCheck {
100 pub fn new() -> Box<Self> {
105 Box::<TcpHealthCheck>::default()
106 }
107
108 pub fn new_tls(sni: &str) -> Box<Self> {
112 let mut new = Self::default();
113 new.peer_template.sni = sni.into();
114 Box::new(new)
115 }
116
117 pub fn set_connector(&mut self, connector: TransportConnector) {
119 self.connector = connector;
120 }
121}
122
123#[async_trait]
124impl HealthCheck for TcpHealthCheck {
125 fn health_threshold(&self, success: bool) -> usize {
126 if success {
127 self.consecutive_success
128 } else {
129 self.consecutive_failure
130 }
131 }
132
133 async fn check(&self, target: &Backend) -> Result<()> {
134 let mut peer = self.peer_template.clone();
135 peer._address = target.addr.clone();
136 self.connector.get_stream(&peer).await.map(|_| {})
137 }
138
139 async fn health_status_change(&self, target: &Backend, healthy: bool) {
140 if let Some(callback) = &self.health_changed_callback {
141 callback.observe(target, healthy).await;
142 }
143 }
144}
145
146type Validator = Box<dyn Fn(&ResponseHeader) -> Result<()> + Send + Sync>;
147
148pub struct HttpHealthCheck {
152 pub consecutive_success: usize,
154 pub consecutive_failure: usize,
156 pub peer_template: HttpPeer,
164 pub reuse_connection: bool,
171 pub req: RequestHeader,
173 connector: HttpConnector,
174 pub validator: Option<Validator>,
178 pub port_override: Option<u16>,
181 pub health_changed_callback: Option<HealthObserveCallback>,
183 pub backend_summary_callback: Option<BackendSummary>,
185}
186
187impl HttpHealthCheck {
188 pub fn new(host: &str, tls: bool) -> Self {
197 let mut req = RequestHeader::build("GET", b"/", None).unwrap();
198 req.append_header("Host", host).unwrap();
199 let sni = if tls { host.into() } else { String::new() };
200 let mut peer_template = HttpPeer::new("0.0.0.0:1", tls, sni);
201 peer_template.options.connection_timeout = Some(Duration::from_secs(1));
202 peer_template.options.read_timeout = Some(Duration::from_secs(1));
203 HttpHealthCheck {
204 consecutive_success: 1,
205 consecutive_failure: 1,
206 peer_template,
207 connector: HttpConnector::new(None),
208 reuse_connection: false,
209 req,
210 validator: None,
211 port_override: None,
212 health_changed_callback: None,
213 backend_summary_callback: None,
214 }
215 }
216
217 pub fn set_connector(&mut self, connector: HttpConnector) {
219 self.connector = connector;
220 }
221
222 pub fn set_backend_summary<F>(&mut self, callback: F)
223 where
224 F: Fn(&Backend) -> String + Send + Sync + 'static,
225 {
226 self.backend_summary_callback = Some(Box::new(callback));
227 }
228}
229
230#[async_trait]
231impl HealthCheck for HttpHealthCheck {
232 fn health_threshold(&self, success: bool) -> usize {
233 if success {
234 self.consecutive_success
235 } else {
236 self.consecutive_failure
237 }
238 }
239
240 async fn check(&self, target: &Backend) -> Result<()> {
241 let mut peer = self.peer_template.clone();
242 peer._address = target.addr.clone();
243 if let Some(port) = self.port_override {
244 peer._address.set_port(port);
245 }
246 let session = self.connector.get_http_session(&peer).await?;
247
248 let mut session = session.0;
249 let req = Box::new(self.req.clone());
250 session.write_request_header(req).await?;
251 session.finish_request_body().await?;
252
253 if let Some(read_timeout) = peer.options.read_timeout {
254 session.set_read_timeout(Some(read_timeout));
255 }
256
257 session.read_response_header().await?;
258
259 let resp = session.response_header().expect("just read");
260
261 if let Some(validator) = self.validator.as_ref() {
262 validator(resp)?;
263 } else if resp.status != 200 {
264 return Error::e_explain(
265 CustomCode("non 200 code", resp.status.as_u16()),
266 "during http healthcheck",
267 );
268 };
269
270 while session.read_response_body().await?.is_some() {
271 }
273
274 if self.reuse_connection {
275 let idle_timeout = peer.idle_timeout();
276 self.connector
277 .release_http_session(session, &peer, idle_timeout)
278 .await;
279 }
280
281 Ok(())
282 }
283 async fn health_status_change(&self, target: &Backend, healthy: bool) {
284 if let Some(callback) = &self.health_changed_callback {
285 callback.observe(target, healthy).await;
286 }
287 }
288 fn backend_summary(&self, target: &Backend) -> String {
289 if let Some(callback) = &self.backend_summary_callback {
290 callback(target)
291 } else {
292 format!("{target:?}")
293 }
294 }
295}
296
/// Snapshot of a backend's health state, swapped as a whole by [`Health`].
#[derive(Clone)]
struct HealthInner {
    /// Whether the backend is currently considered healthy.
    healthy: bool,
    /// Whether the backend is enabled; `Health::ready` requires this too.
    enabled: bool,
    /// Number of consecutive observations that contradicted `healthy`.
    /// Reset to zero on a flip or when a matching observation arrives.
    consecutive_counter: usize,
}
308
309pub(crate) struct Health(ArcSwap<HealthInner>);
311
312impl Default for Health {
313 fn default() -> Self {
314 Health(ArcSwap::new(Arc::new(HealthInner {
315 healthy: true, enabled: true,
317 consecutive_counter: 0,
318 })))
319 }
320}
321
322impl Clone for Health {
323 fn clone(&self) -> Self {
324 let inner = self.0.load_full();
325 Health(ArcSwap::new(inner))
326 }
327}
328
329impl Health {
330 pub fn ready(&self) -> bool {
331 let h = self.0.load();
332 h.healthy && h.enabled
333 }
334
335 pub fn enable(&self, enabled: bool) {
336 let h = self.0.load();
337 if h.enabled != enabled {
338 let mut new_health = (**h).clone();
340 new_health.enabled = enabled;
341 self.0.store(Arc::new(new_health));
342 };
343 }
344
345 pub fn observe_health(&self, health: bool, flip_threshold: usize) -> bool {
347 let h = self.0.load();
348 let mut flipped = false;
349 if h.healthy != health {
350 let mut new_health = (**h).clone();
353 new_health.consecutive_counter += 1;
354 if new_health.consecutive_counter >= flip_threshold {
355 new_health.healthy = health;
356 new_health.consecutive_counter = 0;
357 flipped = true;
358 }
359 self.0.store(Arc::new(new_health));
360 } else if h.consecutive_counter > 0 {
361 let mut new_health = (**h).clone();
364 new_health.consecutive_counter = 0;
365 self.0.store(Arc::new(new_health));
366 }
367 flipped
368 }
369}
370
#[cfg(test)]
mod test {
    use std::{
        collections::{BTreeSet, HashMap},
        sync::atomic::{AtomicU16, Ordering},
    };

    use super::*;
    use crate::{discovery, Backends, SocketAddr};
    use async_trait::async_trait;
    use http::Extensions;

    /// Builds a weight-1 backend with empty extensions for an `ip:port` string.
    fn inet_backend(addr: &str) -> Backend {
        Backend {
            addr: SocketAddr::Inet(addr.parse().unwrap()),
            weight: 1,
            ext: Extensions::new(),
        }
    }

    // NOTE(review): the checks below target 1.1.1.1, so they presumably
    // need outbound network access — confirm for the CI environment.

    #[tokio::test]
    async fn test_tcp_check() {
        let tcp_check = TcpHealthCheck::default();

        let reachable = inet_backend("1.1.1.1:80");
        assert!(tcp_check.check(&reachable).await.is_ok());

        let unreachable = inet_backend("1.1.1.1:79");
        assert!(tcp_check.check(&unreachable).await.is_err());
    }

    #[cfg(feature = "any_tls")]
    #[tokio::test]
    async fn test_tls_check() {
        let tls_check = TcpHealthCheck::new_tls("one.one.one.one");
        let backend = inet_backend("1.1.1.1:443");

        assert!(tls_check.check(&backend).await.is_ok());
    }

    #[cfg(feature = "any_tls")]
    #[tokio::test]
    async fn test_https_check() {
        let https_check = HttpHealthCheck::new("one.one.one.one", true);
        let backend = inet_backend("1.1.1.1:443");

        assert!(https_check.check(&backend).await.is_ok());
    }

    #[tokio::test]
    async fn test_http_custom_check() {
        // Accept a 301 response instead of the default 200.
        let expect_redirect = |resp: &ResponseHeader| {
            if resp.status == 301 {
                Ok(())
            } else {
                Error::e_explain(
                    CustomCode("non 301 code", resp.status.as_u16()),
                    "during http healthcheck",
                )
            }
        };

        let mut http_check = HttpHealthCheck::new("one.one.one.one", false);
        http_check.validator = Some(Box::new(expect_redirect));

        let backend = inet_backend("1.1.1.1:80");

        http_check.check(&backend).await.unwrap();

        assert!(http_check.check(&backend).await.is_ok());
    }

    #[tokio::test]
    async fn test_health_observe() {
        /// Counts how many times a backend was reported unhealthy.
        struct Observe {
            unhealthy_count: Arc<AtomicU16>,
        }
        #[async_trait]
        impl HealthObserve for Observe {
            async fn observe(&self, _target: &Backend, healthy: bool) {
                if !healthy {
                    self.unhealthy_count.fetch_add(1, Ordering::Relaxed);
                }
            }
        }

        let good_backend = Backend::new("127.0.0.1:79").unwrap();
        let new_good_backends = || -> (BTreeSet<Backend>, HashMap<u64, bool>) {
            let healthy = HashMap::from([(good_backend.hash_key(), true)]);
            let backends = BTreeSet::from([good_backend.clone()]);
            (backends, healthy)
        };

        // TCP check with an observer attached.
        {
            let unhealthy_count = Arc::new(AtomicU16::new(0));
            let observer = Box::new(Observe {
                unhealthy_count: unhealthy_count.clone(),
            });
            let tcp_check = TcpHealthCheck {
                health_changed_callback: Some(observer),
                ..Default::default()
            };

            let mut backends = Backends::new(Box::new(discovery::Static::default()));
            backends.set_health_check(Box::new(tcp_check));
            let (set, health) = new_good_backends();
            backends.do_update(set, health, |_backend: Arc<BTreeSet<Backend>>| {});
            assert!(backends.ready(&good_backend));

            backends.run_health_check(false).await;
            assert!(unhealthy_count.load(Ordering::Relaxed) == 1);
            assert!(!backends.ready(&good_backend));
        }

        // HTTPS check with an observer attached.
        {
            let unhealthy_count = Arc::new(AtomicU16::new(0));
            let observer = Box::new(Observe {
                unhealthy_count: unhealthy_count.clone(),
            });

            let mut https_check = HttpHealthCheck::new("one.one.one.one", true);
            https_check.health_changed_callback = Some(observer);

            let mut backends = Backends::new(Box::new(discovery::Static::default()));
            backends.set_health_check(Box::new(https_check));
            let (set, health) = new_good_backends();
            backends.do_update(set, health, |_backend: Arc<BTreeSet<Backend>>| {});
            assert!(backends.ready(&good_backend));

            backends.run_health_check(false).await;
            assert!(unhealthy_count.load(Ordering::Relaxed) == 1);
            assert!(!backends.ready(&good_backend));
        }
    }
}