1use crate::Backend;
18use arc_swap::ArcSwap;
19use async_trait::async_trait;
20use pingora_core::connectors::{http::Connector as HttpConnector, TransportConnector};
21use pingora_core::upstreams::peer::{BasicPeer, HttpPeer, Peer};
22use pingora_error::{Error, ErrorType::CustomCode, Result};
23use pingora_http::{RequestHeader, ResponseHeader};
24use std::sync::Arc;
25use std::time::Duration;
26
/// Observer interface for backend health transitions.
///
/// Implementations are boxed as [`HealthObserveCallback`] and stored in the
/// `health_changed_callback` field of the health checks in this module, whose
/// `health_status_change` hook forwards to [`HealthObserve::observe`].
#[async_trait]
pub trait HealthObserve {
    /// Receives the affected backend (`target`) and the `healthy` flag that was
    /// passed to `health_status_change`.
    async fn observe(&self, target: &Backend, healthy: bool);
}
/// Boxed, thread-safe [`HealthObserve`] implementation used as the
/// health-change callback of the checks in this module.
pub type HealthObserveCallback = Box<dyn HealthObserve + Send + Sync>;
36
/// Interface implemented by every health-check strategy in this module.
#[async_trait]
pub trait HealthCheck {
    /// Probes `target` once.
    ///
    /// `Ok(())` means the probe passed; any `Err` means it failed.
    async fn check(&self, target: &Backend) -> Result<()>;

    /// Hook for reacting to a health change of `_target`.
    ///
    /// The default implementation does nothing.
    // NOTE(review): presumably invoked by the caller only when the health
    // actually flips — confirm against the call site outside this file.
    async fn health_status_change(&self, _target: &Backend, _healthy: bool) {}

    /// Returns the threshold associated with the given check outcome
    /// (`success == true` for passing checks, `false` for failing ones).
    // NOTE(review): presumably the number of consecutive opposite observations
    // needed to flip a backend's health (cf. `Health::observe_health`'s
    // `flip_threshold`) — confirm against the caller.
    fn health_threshold(&self, success: bool) -> usize;
}
54
/// Health check that passes when a connection to the backend can be
/// established through a [`TransportConnector`].
pub struct TcpHealthCheck {
    /// Threshold returned by `health_threshold(true)`.
    pub consecutive_success: usize,
    /// Threshold returned by `health_threshold(false)`.
    pub consecutive_failure: usize,
    /// Connection settings (e.g. the connection timeout) used for each check.
    ///
    /// The address stored here is only a placeholder: `check` clones the
    /// template and overwrites the address with the backend's. `new_tls`
    /// additionally sets the template's `sni`.
    pub peer_template: BasicPeer,
    /// Connector used to open the stream; replaceable via `set_connector`.
    connector: TransportConnector,
    /// Optional observer notified from `health_status_change`.
    pub health_changed_callback: Option<HealthObserveCallback>,
}
76
77impl Default for TcpHealthCheck {
78 fn default() -> Self {
79 let mut peer_template = BasicPeer::new("0.0.0.0:1");
80 peer_template.options.connection_timeout = Some(Duration::from_secs(1));
81 TcpHealthCheck {
82 consecutive_success: 1,
83 consecutive_failure: 1,
84 peer_template,
85 connector: TransportConnector::new(None),
86 health_changed_callback: None,
87 }
88 }
89}
90
91impl TcpHealthCheck {
92 pub fn new() -> Box<Self> {
97 Box::<TcpHealthCheck>::default()
98 }
99
100 pub fn new_tls(sni: &str) -> Box<Self> {
104 let mut new = Self::default();
105 new.peer_template.sni = sni.into();
106 Box::new(new)
107 }
108
109 pub fn set_connector(&mut self, connector: TransportConnector) {
111 self.connector = connector;
112 }
113}
114
115#[async_trait]
116impl HealthCheck for TcpHealthCheck {
117 fn health_threshold(&self, success: bool) -> usize {
118 if success {
119 self.consecutive_success
120 } else {
121 self.consecutive_failure
122 }
123 }
124
125 async fn check(&self, target: &Backend) -> Result<()> {
126 let mut peer = self.peer_template.clone();
127 peer._address = target.addr.clone();
128 self.connector.get_stream(&peer).await.map(|_| {})
129 }
130
131 async fn health_status_change(&self, target: &Backend, healthy: bool) {
132 if let Some(callback) = &self.health_changed_callback {
133 callback.observe(target, healthy).await;
134 }
135 }
136}
137
/// Boxed, thread-safe predicate over a response header: `Ok(())` accepts the
/// response, `Err` rejects it. Used by `HttpHealthCheck::validator`.
type Validator = Box<dyn Fn(&ResponseHeader) -> Result<()> + Send + Sync>;
139
/// Health check that sends an HTTP request to the backend and inspects the
/// response header.
pub struct HttpHealthCheck {
    /// Threshold returned by `health_threshold(true)`.
    pub consecutive_success: usize,
    /// Threshold returned by `health_threshold(false)`.
    pub consecutive_failure: usize,
    /// Connection settings (timeouts, TLS/SNI) used for each check.
    ///
    /// The address stored here is only a placeholder: `check` clones the
    /// template and overwrites the address with the backend's.
    pub peer_template: HttpPeer,
    /// When `true`, `check` hands the session back to the connector after a
    /// successful check instead of dropping it.
    pub reuse_connection: bool,
    /// Request header sent on every check (`new` builds a `GET /` with a
    /// `Host` header).
    pub req: RequestHeader,
    /// Connector used to obtain HTTP sessions; replaceable via `set_connector`.
    connector: HttpConnector,
    /// Optional custom response validation. When `None`, `check` accepts only
    /// a 200 status.
    pub validator: Option<Validator>,
    /// When set, `check` probes this port instead of the backend's own port.
    pub port_override: Option<u16>,
    /// Optional observer notified from `health_status_change`.
    pub health_changed_callback: Option<HealthObserveCallback>,
}
176
177impl HttpHealthCheck {
178 pub fn new(host: &str, tls: bool) -> Self {
187 let mut req = RequestHeader::build("GET", b"/", None).unwrap();
188 req.append_header("Host", host).unwrap();
189 let sni = if tls { host.into() } else { String::new() };
190 let mut peer_template = HttpPeer::new("0.0.0.0:1", tls, sni);
191 peer_template.options.connection_timeout = Some(Duration::from_secs(1));
192 peer_template.options.read_timeout = Some(Duration::from_secs(1));
193 HttpHealthCheck {
194 consecutive_success: 1,
195 consecutive_failure: 1,
196 peer_template,
197 connector: HttpConnector::new(None),
198 reuse_connection: false,
199 req,
200 validator: None,
201 port_override: None,
202 health_changed_callback: None,
203 }
204 }
205
206 pub fn set_connector(&mut self, connector: HttpConnector) {
208 self.connector = connector;
209 }
210}
211
212#[async_trait]
213impl HealthCheck for HttpHealthCheck {
214 fn health_threshold(&self, success: bool) -> usize {
215 if success {
216 self.consecutive_success
217 } else {
218 self.consecutive_failure
219 }
220 }
221
222 async fn check(&self, target: &Backend) -> Result<()> {
223 let mut peer = self.peer_template.clone();
224 peer._address = target.addr.clone();
225 if let Some(port) = self.port_override {
226 peer._address.set_port(port);
227 }
228 let session = self.connector.get_http_session(&peer).await?;
229
230 let mut session = session.0;
231 let req = Box::new(self.req.clone());
232 session.write_request_header(req).await?;
233 session.finish_request_body().await?;
234
235 if let Some(read_timeout) = peer.options.read_timeout {
236 session.set_read_timeout(read_timeout);
237 }
238
239 session.read_response_header().await?;
240
241 let resp = session.response_header().expect("just read");
242
243 if let Some(validator) = self.validator.as_ref() {
244 validator(resp)?;
245 } else if resp.status != 200 {
246 return Error::e_explain(
247 CustomCode("non 200 code", resp.status.as_u16()),
248 "during http healthcheck",
249 );
250 };
251
252 while session.read_response_body().await?.is_some() {
253 }
255
256 if self.reuse_connection {
257 let idle_timeout = peer.idle_timeout();
258 self.connector
259 .release_http_session(session, &peer, idle_timeout)
260 .await;
261 }
262
263 Ok(())
264 }
265 async fn health_status_change(&self, target: &Backend, healthy: bool) {
266 if let Some(callback) = &self.health_changed_callback {
267 callback.observe(target, healthy).await;
268 }
269 }
270}
271
/// Snapshot of a backend's health state; replaced wholesale on every update
/// made through [`Health`].
#[derive(Clone)]
struct HealthInner {
    /// Last settled health verdict; flipped by `Health::observe_health` once
    /// the threshold is reached.
    healthy: bool,
    /// Administrative switch set by `Health::enable`, independent of `healthy`.
    enabled: bool,
    /// Number of consecutive observations that contradict `healthy`; reset to
    /// zero on a flip or when an observation agrees with `healthy`.
    consecutive_counter: usize,
}
283
/// Health state of a backend, stored behind an [`ArcSwap`] so it can be read
/// and replaced through a shared reference.
pub(crate) struct Health(ArcSwap<HealthInner>);
286
287impl Default for Health {
288 fn default() -> Self {
289 Health(ArcSwap::new(Arc::new(HealthInner {
290 healthy: true, enabled: true,
292 consecutive_counter: 0,
293 })))
294 }
295}
296
297impl Clone for Health {
298 fn clone(&self) -> Self {
299 let inner = self.0.load_full();
300 Health(ArcSwap::new(inner))
301 }
302}
303
304impl Health {
305 pub fn ready(&self) -> bool {
306 let h = self.0.load();
307 h.healthy && h.enabled
308 }
309
310 pub fn enable(&self, enabled: bool) {
311 let h = self.0.load();
312 if h.enabled != enabled {
313 let mut new_health = (**h).clone();
315 new_health.enabled = enabled;
316 self.0.store(Arc::new(new_health));
317 };
318 }
319
320 pub fn observe_health(&self, health: bool, flip_threshold: usize) -> bool {
322 let h = self.0.load();
323 let mut flipped = false;
324 if h.healthy != health {
325 let mut new_health = (**h).clone();
328 new_health.consecutive_counter += 1;
329 if new_health.consecutive_counter >= flip_threshold {
330 new_health.healthy = health;
331 new_health.consecutive_counter = 0;
332 flipped = true;
333 }
334 self.0.store(Arc::new(new_health));
335 } else if h.consecutive_counter > 0 {
336 let mut new_health = (**h).clone();
339 new_health.consecutive_counter = 0;
340 self.0.store(Arc::new(new_health));
341 }
342 flipped
343 }
344}
345
// NOTE: several tests below connect to the literal address 1.1.1.1 and
// therefore need outbound network access to pass.
#[cfg(test)]
mod test {
    use std::{
        collections::{BTreeSet, HashMap},
        sync::atomic::{AtomicU16, Ordering},
    };

    use super::*;
    use crate::{discovery, Backends, SocketAddr};
    use async_trait::async_trait;
    use http::Extensions;

    /// A TCP check is expected to pass against port 80 and fail against port 79.
    #[tokio::test]
    async fn test_tcp_check() {
        let tcp_check = TcpHealthCheck::default();

        let backend = Backend {
            addr: SocketAddr::Inet("1.1.1.1:80".parse().unwrap()),
            weight: 1,
            ext: Extensions::new(),
        };

        assert!(tcp_check.check(&backend).await.is_ok());

        let backend = Backend {
            addr: SocketAddr::Inet("1.1.1.1:79".parse().unwrap()),
            weight: 1,
            ext: Extensions::new(),
        };

        assert!(tcp_check.check(&backend).await.is_err());
    }

    /// A check built with `new_tls` is expected to pass against port 443.
    #[cfg(feature = "any_tls")]
    #[tokio::test]
    async fn test_tls_check() {
        let tls_check = TcpHealthCheck::new_tls("one.one.one.one");
        let backend = Backend {
            addr: SocketAddr::Inet("1.1.1.1:443".parse().unwrap()),
            weight: 1,
            ext: Extensions::new(),
        };

        assert!(tls_check.check(&backend).await.is_ok());
    }

    /// An HTTP check with TLS enabled is expected to pass against port 443.
    #[cfg(feature = "any_tls")]
    #[tokio::test]
    async fn test_https_check() {
        let https_check = HttpHealthCheck::new("one.one.one.one", true);

        let backend = Backend {
            addr: SocketAddr::Inet("1.1.1.1:443".parse().unwrap()),
            weight: 1,
            ext: Extensions::new(),
        };

        assert!(https_check.check(&backend).await.is_ok());
    }

    /// A custom validator replaces the default "status must be 200" rule: here
    /// only a 301 response is accepted.
    #[tokio::test]
    async fn test_http_custom_check() {
        let mut http_check = HttpHealthCheck::new("one.one.one.one", false);
        http_check.validator = Some(Box::new(|resp: &ResponseHeader| {
            if resp.status == 301 {
                Ok(())
            } else {
                Error::e_explain(
                    CustomCode("non 301 code", resp.status.as_u16()),
                    "during http healthcheck",
                )
            }
        }));

        let backend = Backend {
            addr: SocketAddr::Inet("1.1.1.1:80".parse().unwrap()),
            weight: 1,
            ext: Extensions::new(),
        };

        // First call surfaces the error details on failure via `unwrap`.
        http_check.check(&backend).await.unwrap();

        assert!(http_check.check(&backend).await.is_ok());
    }

    /// The health-change callback is expected to fire exactly once when a
    /// backend goes from ready to not ready, for both check types.
    #[tokio::test]
    async fn test_health_observe() {
        // Observer that counts how many times it was told a backend is unhealthy.
        struct Observe {
            unhealthy_count: Arc<AtomicU16>,
        }
        #[async_trait]
        impl HealthObserve for Observe {
            async fn observe(&self, _target: &Backend, healthy: bool) {
                if !healthy {
                    self.unhealthy_count.fetch_add(1, Ordering::Relaxed);
                }
            }
        }

        // The checks below are expected to fail against this local address.
        let good_backend = Backend::new("127.0.0.1:79").unwrap();
        // Builds the backend set plus a map marking the backend as healthy.
        let new_good_backends = || -> (BTreeSet<Backend>, HashMap<u64, bool>) {
            let mut healthy = HashMap::new();
            healthy.insert(good_backend.hash_key(), true);
            let mut backends = BTreeSet::new();
            backends.extend(vec![good_backend.clone()]);
            (backends, healthy)
        };
        {
            // TCP health check callback
            let unhealthy_count = Arc::new(AtomicU16::new(0));
            let ob = Observe {
                unhealthy_count: unhealthy_count.clone(),
            };
            let bob = Box::new(ob);
            let tcp_check = TcpHealthCheck {
                health_changed_callback: Some(bob),
                ..Default::default()
            };

            let discovery = discovery::Static::default();
            let mut backends = Backends::new(Box::new(discovery));
            backends.set_health_check(Box::new(tcp_check));
            let result = new_good_backends();
            backends.do_update(result.0, result.1, |_backend: Arc<BTreeSet<Backend>>| {});
            // after the initial update the backend is reported as ready
            assert!(backends.ready(&good_backend));

            // one failing round (default failure threshold is 1)
            backends.run_health_check(false).await;
            assert!(1 == unhealthy_count.load(Ordering::Relaxed));
            // the backend is no longer ready
            assert!(!backends.ready(&good_backend));
        }

        {
            // HTTPS health check callback
            let unhealthy_count = Arc::new(AtomicU16::new(0));
            let ob = Observe {
                unhealthy_count: unhealthy_count.clone(),
            };
            let bob = Box::new(ob);

            let mut https_check = HttpHealthCheck::new("one.one.one.one", true);
            https_check.health_changed_callback = Some(bob);

            let discovery = discovery::Static::default();
            let mut backends = Backends::new(Box::new(discovery));
            backends.set_health_check(Box::new(https_check));
            let result = new_good_backends();
            backends.do_update(result.0, result.1, |_backend: Arc<BTreeSet<Backend>>| {});
            // after the initial update the backend is reported as ready
            assert!(backends.ready(&good_backend));
            // one failing round (default failure threshold is 1)
            backends.run_health_check(false).await;
            assert!(1 == unhealthy_count.load(Ordering::Relaxed));
            // the backend is no longer ready
            assert!(!backends.ready(&good_backend));
        }
    }
}