1use crate::Backend;
18use arc_swap::ArcSwap;
19use async_trait::async_trait;
20use pingora_core::connectors::http::custom;
21use pingora_core::connectors::{http::Connector as HttpConnector, TransportConnector};
22use pingora_core::custom_session;
23use pingora_core::protocols::http::custom::client::Session;
24use pingora_core::upstreams::peer::{BasicPeer, HttpPeer, Peer};
25use pingora_error::{Error, ErrorType::CustomCode, Result};
26use pingora_http::{RequestHeader, ResponseHeader};
27use std::sync::Arc;
28use std::time::Duration;
29
/// Observer interface for health notifications about a [`Backend`].
///
/// The health checks in this file forward the arguments of their
/// `health_status_change` hook to an implementation of this trait when one is
/// installed as their `health_changed_callback`.
#[async_trait]
pub trait HealthObserve {
    /// Receives the backend concerned and the `healthy` flag being reported for it.
    async fn observe(&self, target: &Backend, healthy: bool);
}
/// Boxed [`HealthObserve`] implementation that is `Send + Sync`; this is the type
/// of the `health_changed_callback` field on the health checks in this file.
pub type HealthObserveCallback = Box<dyn HealthObserve + Send + Sync>;
39
/// Boxed `Send + Sync` closure that renders a [`Backend`] as a `String` summary.
pub type BackendSummary = Box<dyn Fn(&Backend) -> String + Send + Sync>;
42
/// Interface implemented by a backend health check.
#[async_trait]
pub trait HealthCheck {
    /// Probe `target`.
    ///
    /// `Ok(())` means the probe passed; any `Err` means it failed.
    async fn check(&self, target: &Backend) -> Result<()>;

    /// Hook that receives a backend and the `healthy` flag reported for it.
    ///
    /// The default implementation does nothing.
    async fn health_status_change(&self, _target: &Backend, _healthy: bool) {}

    /// Textual description of `target`.
    ///
    /// The default implementation returns the backend's `Debug` output.
    fn backend_summary(&self, target: &Backend) -> String {
        format!("{target:?}")
    }

    /// Threshold associated with the given check outcome.
    ///
    /// The implementations in this file return their `consecutive_success` when
    /// `success` is `true` and their `consecutive_failure` otherwise.
    // NOTE(review): presumably the number of consecutive checks needed to flip a
    // backend's health; confirm against the caller of this method.
    fn health_threshold(&self, success: bool) -> usize;
}
65
/// Health check that passes when the connector can obtain a stream to the backend.
pub struct TcpHealthCheck {
    /// Value returned by `health_threshold(true)`.
    pub consecutive_success: usize,
    /// Value returned by `health_threshold(false)`.
    pub consecutive_failure: usize,
    /// Template peer cloned for every check.
    ///
    /// Its address is overwritten with the backend's address before connecting, so
    /// the address stored here is only a placeholder. The remaining settings (for
    /// example `options.connection_timeout` and `sni`) are used as configured.
    // NOTE(review): a non-empty `sni` presumably makes the connector perform a TLS
    // handshake as well; confirm against the connector implementation.
    pub peer_template: BasicPeer,
    /// Connector used to open the stream to the backend.
    connector: TransportConnector,
    /// Optional observer invoked from `health_status_change`.
    pub health_changed_callback: Option<HealthObserveCallback>,
}
87
88impl Default for TcpHealthCheck {
89 fn default() -> Self {
90 let mut peer_template = BasicPeer::new("0.0.0.0:1");
91 peer_template.options.connection_timeout = Some(Duration::from_secs(1));
92 TcpHealthCheck {
93 consecutive_success: 1,
94 consecutive_failure: 1,
95 peer_template,
96 connector: TransportConnector::new(None),
97 health_changed_callback: None,
98 }
99 }
100}
101
102impl TcpHealthCheck {
103 pub fn new() -> Box<Self> {
108 Box::<TcpHealthCheck>::default()
109 }
110
111 pub fn new_tls(sni: &str) -> Box<Self> {
115 let mut new = Self::default();
116 new.peer_template.sni = sni.into();
117 Box::new(new)
118 }
119
120 pub fn set_connector(&mut self, connector: TransportConnector) {
122 self.connector = connector;
123 }
124}
125
126#[async_trait]
127impl HealthCheck for TcpHealthCheck {
128 fn health_threshold(&self, success: bool) -> usize {
129 if success {
130 self.consecutive_success
131 } else {
132 self.consecutive_failure
133 }
134 }
135
136 async fn check(&self, target: &Backend) -> Result<()> {
137 let mut peer = self.peer_template.clone();
138 peer._address = target.addr.clone();
139 self.connector.get_stream(&peer).await.map(|_| {})
140 }
141
142 async fn health_status_change(&self, target: &Backend, healthy: bool) {
143 if let Some(callback) = &self.health_changed_callback {
144 callback.observe(target, healthy).await;
145 }
146 }
147}
148
/// Boxed `Send + Sync` closure that inspects a response header and returns
/// `Ok(())` to accept it or an `Err` to reject it.
type Validator = Box<dyn Fn(&ResponseHeader) -> Result<()> + Send + Sync>;
150
/// Health check that sends an HTTP request to the backend and validates the
/// response header.
///
/// `C` is the custom connector type carried by the HTTP connector; it defaults to `()`.
pub struct HttpHealthCheck<C = ()>
where
    C: custom::Connector,
{
    /// Value returned by `health_threshold(true)`.
    pub consecutive_success: usize,
    /// Value returned by `health_threshold(false)`.
    pub consecutive_failure: usize,
    /// Template peer cloned for every check.
    ///
    /// Its address is overwritten with the backend's address (and optionally
    /// `port_override`) before connecting, so the address stored here is only a
    /// placeholder. Its `options.read_timeout`, when set, is applied to the
    /// session before the response header is read.
    pub peer_template: HttpPeer,
    /// When `true`, the session is released back to the connector after a
    /// passing check; when `false`, it is not released.
    pub reuse_connection: bool,
    /// Request header that is cloned and sent on every check.
    pub req: RequestHeader,
    /// Connector used to obtain and release HTTP sessions.
    connector: HttpConnector<C>,
    /// Optional response validator.
    ///
    /// When set, its result decides whether the check passes. When unset, any
    /// response whose status is not 200 fails the check.
    pub validator: Option<Validator>,
    /// When set, replaces the port of the backend's address for the check.
    pub port_override: Option<u16>,
    /// Optional observer invoked from `health_status_change`.
    pub health_changed_callback: Option<HealthObserveCallback>,
    /// Optional closure used by `backend_summary`; when unset, the backend's
    /// `Debug` output is used.
    pub backend_summary_callback: Option<BackendSummary>,
}
192
193impl HttpHealthCheck<()> {
194 pub fn new(host: &str, tls: bool) -> Self {
203 let mut req = RequestHeader::build("GET", b"/", None).unwrap();
204 req.append_header("Host", host).unwrap();
205 let sni = if tls { host.into() } else { String::new() };
206 let mut peer_template = HttpPeer::new("0.0.0.0:1", tls, sni);
207 peer_template.options.connection_timeout = Some(Duration::from_secs(1));
208 peer_template.options.read_timeout = Some(Duration::from_secs(1));
209 HttpHealthCheck {
210 consecutive_success: 1,
211 consecutive_failure: 1,
212 peer_template,
213 connector: HttpConnector::new(None),
214 reuse_connection: false,
215 req,
216 validator: None,
217 port_override: None,
218 health_changed_callback: None,
219 backend_summary_callback: None,
220 }
221 }
222}
223
224impl<C> HttpHealthCheck<C>
225where
226 C: custom::Connector,
227{
228 pub fn new_custom(host: &str, tls: bool, custom: HttpConnector<C>) -> Self {
237 let mut req = RequestHeader::build("GET", b"/", None).unwrap();
238 req.append_header("Host", host).unwrap();
239 let sni = if tls { host.into() } else { String::new() };
240 let mut peer_template = HttpPeer::new("0.0.0.0:1", tls, sni);
241 peer_template.options.connection_timeout = Some(Duration::from_secs(1));
242 peer_template.options.read_timeout = Some(Duration::from_secs(1));
243 HttpHealthCheck {
244 consecutive_success: 1,
245 consecutive_failure: 1,
246 peer_template,
247 connector: custom,
248 reuse_connection: false,
249 req,
250 validator: None,
251 port_override: None,
252 health_changed_callback: None,
253 backend_summary_callback: None,
254 }
255 }
256
257 pub fn set_connector(&mut self, connector: HttpConnector<C>) {
259 self.connector = connector;
260 }
261
262 pub fn set_backend_summary<F>(&mut self, callback: F)
263 where
264 F: Fn(&Backend) -> String + Send + Sync + 'static,
265 {
266 self.backend_summary_callback = Some(Box::new(callback));
267 }
268}
269
270#[async_trait]
271impl<C> HealthCheck for HttpHealthCheck<C>
272where
273 C: custom::Connector,
274{
275 fn health_threshold(&self, success: bool) -> usize {
276 if success {
277 self.consecutive_success
278 } else {
279 self.consecutive_failure
280 }
281 }
282
283 async fn check(&self, target: &Backend) -> Result<()> {
284 let mut peer = self.peer_template.clone();
285 peer._address = target.addr.clone();
286 if let Some(port) = self.port_override {
287 peer._address.set_port(port);
288 }
289 let session = self.connector.get_http_session(&peer).await?;
290
291 let mut session = session.0;
292 let req = Box::new(self.req.clone());
293 session.write_request_header(req).await?;
294 session.finish_request_body().await?;
295
296 custom_session!(session.finish_custom().await?);
297
298 if let Some(read_timeout) = peer.options.read_timeout {
299 session.set_read_timeout(Some(read_timeout));
300 }
301
302 session.read_response_header().await?;
303
304 let resp = session.response_header().expect("just read");
305
306 if let Some(validator) = self.validator.as_ref() {
307 validator(resp)?;
308 } else if resp.status != 200 {
309 return Error::e_explain(
310 CustomCode("non 200 code", resp.status.as_u16()),
311 "during http healthcheck",
312 );
313 };
314
315 while session.read_response_body().await?.is_some() {
316 }
318
319 custom_session!(session.drain_custom_messages().await?);
321
322 if self.reuse_connection {
323 let idle_timeout = peer.idle_timeout();
324 self.connector
325 .release_http_session(session, &peer, idle_timeout)
326 .await;
327 }
328
329 Ok(())
330 }
331 async fn health_status_change(&self, target: &Backend, healthy: bool) {
332 if let Some(callback) = &self.health_changed_callback {
333 callback.observe(target, healthy).await;
334 }
335 }
336 fn backend_summary(&self, target: &Backend) -> String {
337 if let Some(callback) = &self.backend_summary_callback {
338 callback(target)
339 } else {
340 format!("{target:?}")
341 }
342 }
343}
344
/// Snapshot of a backend's health state, stored inside [`Health`].
#[derive(Clone)]
struct HealthInner {
    /// Current health verdict; changed by `Health::observe_health` once the flip
    /// threshold is reached.
    healthy: bool,
    /// Switch set by `Health::enable`; `Health::ready` requires both this and
    /// `healthy` to be true.
    enabled: bool,
    /// Number of consecutive observations that disagreed with `healthy`.
    ///
    /// Reset to 0 when the health flips or when an observation agrees with the
    /// current `healthy` value.
    consecutive_counter: usize,
}
356
/// Health state of a backend, held in an [`ArcSwap`] so that it can be replaced
/// through a shared reference.
pub(crate) struct Health(ArcSwap<HealthInner>);
359
360impl Default for Health {
361 fn default() -> Self {
362 Health(ArcSwap::new(Arc::new(HealthInner {
363 healthy: true, enabled: true,
365 consecutive_counter: 0,
366 })))
367 }
368}
369
370impl Clone for Health {
371 fn clone(&self) -> Self {
372 let inner = self.0.load_full();
373 Health(ArcSwap::new(inner))
374 }
375}
376
377impl Health {
378 pub fn ready(&self) -> bool {
379 let h = self.0.load();
380 h.healthy && h.enabled
381 }
382
383 pub fn enable(&self, enabled: bool) {
384 let h = self.0.load();
385 if h.enabled != enabled {
386 let mut new_health = (**h).clone();
388 new_health.enabled = enabled;
389 self.0.store(Arc::new(new_health));
390 };
391 }
392
393 pub fn observe_health(&self, health: bool, flip_threshold: usize) -> bool {
395 let h = self.0.load();
396 let mut flipped = false;
397 if h.healthy != health {
398 let mut new_health = (**h).clone();
401 new_health.consecutive_counter += 1;
402 if new_health.consecutive_counter >= flip_threshold {
403 new_health.healthy = health;
404 new_health.consecutive_counter = 0;
405 flipped = true;
406 }
407 self.0.store(Arc::new(new_health));
408 } else if h.consecutive_counter > 0 {
409 let mut new_health = (**h).clone();
412 new_health.consecutive_counter = 0;
413 self.0.store(Arc::new(new_health));
414 }
415 flipped
416 }
417}
418
// Tests for the health checks in this file.
//
// Most of them connect to the external address 1.1.1.1, so they need network access.
#[cfg(test)]
mod test {
    use std::{
        collections::{BTreeSet, HashMap},
        sync::atomic::{AtomicU16, Ordering},
    };

    use super::*;
    use crate::{discovery, Backends, SocketAddr};
    use async_trait::async_trait;
    use http::Extensions;

    // The TCP check is expected to pass on port 80 and to fail on port 79.
    #[tokio::test]
    async fn test_tcp_check() {
        let tcp_check = TcpHealthCheck::default();

        let backend = Backend {
            addr: SocketAddr::Inet("1.1.1.1:80".parse().unwrap()),
            weight: 1,
            ext: Extensions::new(),
        };

        assert!(tcp_check.check(&backend).await.is_ok());

        let backend = Backend {
            addr: SocketAddr::Inet("1.1.1.1:79".parse().unwrap()),
            weight: 1,
            ext: Extensions::new(),
        };

        assert!(tcp_check.check(&backend).await.is_err());
    }

    // A check created with `new_tls` is expected to pass against port 443.
    #[cfg(feature = "any_tls")]
    #[tokio::test]
    async fn test_tls_check() {
        let tls_check = TcpHealthCheck::new_tls("one.one.one.one");
        let backend = Backend {
            addr: SocketAddr::Inet("1.1.1.1:443".parse().unwrap()),
            weight: 1,
            ext: Extensions::new(),
        };

        assert!(tls_check.check(&backend).await.is_ok());
    }

    // The HTTP check with TLS enabled and default settings is expected to pass.
    #[cfg(feature = "any_tls")]
    #[tokio::test]
    async fn test_https_check() {
        let https_check = HttpHealthCheck::new("one.one.one.one", true);

        let backend = Backend {
            addr: SocketAddr::Inet("1.1.1.1:443".parse().unwrap()),
            weight: 1,
            ext: Extensions::new(),
        };

        assert!(https_check.check(&backend).await.is_ok());
    }

    // A custom validator replaces the default "status must be 200" rule; here
    // only a 301 response is accepted.
    #[tokio::test]
    async fn test_http_custom_check() {
        let mut http_check = HttpHealthCheck::new("one.one.one.one", false);
        http_check.validator = Some(Box::new(|resp: &ResponseHeader| {
            if resp.status == 301 {
                Ok(())
            } else {
                Error::e_explain(
                    CustomCode("non 301 code", resp.status.as_u16()),
                    "during http healthcheck",
                )
            }
        }));

        let backend = Backend {
            addr: SocketAddr::Inet("1.1.1.1:80".parse().unwrap()),
            weight: 1,
            ext: Extensions::new(),
        };

        // The check is run twice: once unwrapping the result, once asserting on it.
        http_check.check(&backend).await.unwrap();

        assert!(http_check.check(&backend).await.is_ok());
    }

    // The health-changed callback is expected to be notified exactly once when a
    // backend that starts out ready fails its health check.
    #[tokio::test]
    async fn test_health_observe() {
        // Observer that counts the notifications reporting an unhealthy backend.
        struct Observe {
            unhealthy_count: Arc<AtomicU16>,
        }
        #[async_trait]
        impl HealthObserve for Observe {
            async fn observe(&self, _target: &Backend, healthy: bool) {
                if !healthy {
                    self.unhealthy_count.fetch_add(1, Ordering::Relaxed);
                }
            }
        }

        // Registered as healthy below; the assertions expect the checks against it to fail.
        let good_backend = Backend::new("127.0.0.1:79").unwrap();
        // Builds the backend set plus a map marking the backend's hash key as `true`.
        let new_good_backends = || -> (BTreeSet<Backend>, HashMap<u64, bool>) {
            let mut healthy = HashMap::new();
            healthy.insert(good_backend.hash_key(), true);
            let mut backends = BTreeSet::new();
            backends.extend(vec![good_backend.clone()]);
            (backends, healthy)
        };
        // Scenario 1: TCP health check.
        {
            let unhealthy_count = Arc::new(AtomicU16::new(0));
            let ob = Observe {
                unhealthy_count: unhealthy_count.clone(),
            };
            let bob = Box::new(ob);
            let tcp_check = TcpHealthCheck {
                health_changed_callback: Some(bob),
                ..Default::default()
            };

            let discovery = discovery::Static::default();
            let mut backends = Backends::new(Box::new(discovery));
            backends.set_health_check(Box::new(tcp_check));
            let result = new_good_backends();
            backends.do_update(result.0, result.1, |_backend: Arc<BTreeSet<Backend>>| {});
            // Ready before any health check has run.
            assert!(backends.ready(&good_backend));

            backends.run_health_check(false).await;
            // One unhealthy notification, and the backend is no longer ready.
            assert!(1 == unhealthy_count.load(Ordering::Relaxed));
            assert!(!backends.ready(&good_backend));
        }

        // Scenario 2: HTTP health check with TLS enabled.
        {
            let unhealthy_count = Arc::new(AtomicU16::new(0));
            let ob = Observe {
                unhealthy_count: unhealthy_count.clone(),
            };
            let bob = Box::new(ob);

            let mut https_check = HttpHealthCheck::new("one.one.one.one", true);
            https_check.health_changed_callback = Some(bob);

            let discovery = discovery::Static::default();
            let mut backends = Backends::new(Box::new(discovery));
            backends.set_health_check(Box::new(https_check));
            let result = new_good_backends();
            backends.do_update(result.0, result.1, |_backend: Arc<BTreeSet<Backend>>| {});
            // Ready before any health check has run.
            assert!(backends.ready(&good_backend));
            backends.run_health_check(false).await;
            // One unhealthy notification, and the backend is no longer ready.
            assert!(1 == unhealthy_count.load(Ordering::Relaxed));
            assert!(!backends.ready(&good_backend));
        }
    }
}