use crate::Backend;
use arc_swap::ArcSwap;
use async_trait::async_trait;
use pingora_core::connectors::http::custom;
use pingora_core::connectors::{http::Connector as HttpConnector, TransportConnector};
use pingora_core::custom_session;
use pingora_core::protocols::http::custom::client::Session;
use pingora_core::upstreams::peer::{BasicPeer, HttpPeer, Peer};
use pingora_error::{Error, ErrorType::CustomCode, Result};
use pingora_http::{RequestHeader, ResponseHeader};
use std::sync::Arc;
use std::time::Duration;
/// Observer that is told about health-status changes of a backend.
///
/// The health checks in this file forward their `health_status_change` hook to
/// [`HealthObserve::observe`] when a callback has been configured.
#[async_trait]
pub trait HealthObserve {
    /// Receive the new health status of `target` (`true` means healthy).
    async fn observe(&self, target: &Backend, healthy: bool);
}
/// Boxed, thread-safe [`HealthObserve`] implementation used as a health-change callback.
pub type HealthObserveCallback = Box<dyn HealthObserve + Send + Sync>;

/// Boxed, thread-safe function that renders a textual summary of a [`Backend`].
pub type BackendSummary = Box<dyn Fn(&Backend) -> String + Send + Sync>;
/// A strategy for probing a single backend.
#[async_trait]
pub trait HealthCheck {
    /// Probe `target` once.
    ///
    /// Returns `Ok(())` when the probe succeeds and an error describing the
    /// failure otherwise.
    async fn check(&self, target: &Backend) -> Result<()>;

    /// Hook invoked with `target` and its new health status.
    ///
    /// NOTE(review): presumably called by the health-check runner when a backend
    /// flips state; confirm against the caller. The default does nothing.
    async fn health_status_change(&self, _target: &Backend, _healthy: bool) {}

    /// Human-readable description of `target`.
    ///
    /// Defaults to the `Debug` representation of the backend.
    fn backend_summary(&self, target: &Backend) -> String {
        format!("{target:?}")
    }

    /// Threshold associated with an observation outcome of `success`.
    ///
    /// The implementations in this file return `consecutive_success` for `true`
    /// and `consecutive_failure` for `false`. Presumably this is the number of
    /// consecutive observations needed to flip a backend's status; verify against
    /// the caller.
    fn health_threshold(&self, success: bool) -> usize;
}
/// Health check that succeeds when a transport connection to the backend can be
/// established via the configured [`TransportConnector`].
pub struct TcpHealthCheck {
    /// Value returned by `health_threshold(true)`.
    pub consecutive_success: usize,
    /// Value returned by `health_threshold(false)`.
    pub consecutive_failure: usize,
    /// Peer cloned for every probe. Its address is replaced with the target
    /// backend's address before connecting.
    pub peer_template: BasicPeer,
    /// Connector used to open the probe connection.
    connector: TransportConnector,
    /// Optional observer notified from `health_status_change`.
    pub health_changed_callback: Option<HealthObserveCallback>,
}
impl Default for TcpHealthCheck {
    /// Thresholds of 1 and a placeholder peer (`0.0.0.0:1`) with a 1-second
    /// connection timeout. No health-change callback is set.
    fn default() -> Self {
        let peer_template = {
            let mut template = BasicPeer::new("0.0.0.0:1");
            template.options.connection_timeout = Some(Duration::from_secs(1));
            template
        };
        Self {
            consecutive_success: 1,
            consecutive_failure: 1,
            peer_template,
            connector: TransportConnector::new(None),
            health_changed_callback: None,
        }
    }
}
impl TcpHealthCheck {
    /// Boxed TCP health check with the [`Default`] settings.
    pub fn new() -> Box<Self> {
        Box::<Self>::default()
    }

    /// Boxed health check with the [`Default`] settings and the peer template's
    /// SNI set to `sni`.
    pub fn new_tls(sni: &str) -> Box<Self> {
        let mut tls_check = Self::new();
        tls_check.peer_template.sni = sni.into();
        tls_check
    }

    /// Replace the connector used for probes.
    pub fn set_connector(&mut self, connector: TransportConnector) {
        self.connector = connector;
    }
}
#[async_trait]
impl HealthCheck for TcpHealthCheck {
    fn health_threshold(&self, success: bool) -> usize {
        match success {
            true => self.consecutive_success,
            false => self.consecutive_failure,
        }
    }

    /// Succeeds when `get_stream` can connect to `target.addr`.
    ///
    /// The resulting stream is not used and is dropped right away.
    async fn check(&self, target: &Backend) -> Result<()> {
        let mut probe_peer = self.peer_template.clone();
        probe_peer._address = target.addr.clone();
        self.connector.get_stream(&probe_peer).await?;
        Ok(())
    }

    /// Forward the status change to `health_changed_callback`, if one is set.
    async fn health_status_change(&self, target: &Backend, healthy: bool) {
        if let Some(observer) = self.health_changed_callback.as_ref() {
            observer.observe(target, healthy).await;
        }
    }
}
/// Custom response check for [`HttpHealthCheck`].
///
/// Returning `Err` fails the probe; it replaces the default "status must be 200" rule.
type Validator = Box<dyn Fn(&ResponseHeader) -> Result<()> + Send + Sync>;
/// Health check that sends an HTTP request to the backend and inspects the response.
///
/// `C` is the custom connector type of the underlying [`HttpConnector`]. The
/// default `()` is used by [`HttpHealthCheck::new`].
pub struct HttpHealthCheck<C = ()>
where
    C: custom::Connector,
{
    /// Value returned by `health_threshold(true)`.
    pub consecutive_success: usize,
    /// Value returned by `health_threshold(false)`.
    pub consecutive_failure: usize,
    /// Peer cloned for every probe. Its address is replaced with the target's
    /// address, and its port with `port_override` when that is set.
    pub peer_template: HttpPeer,
    /// When true, a successful probe hands the session back to the connector via
    /// `release_http_session` instead of dropping it.
    pub reuse_connection: bool,
    /// Request header sent on every probe.
    pub req: RequestHeader,
    /// Connector used to obtain HTTP sessions.
    connector: HttpConnector<C>,
    /// Optional response validator used instead of the "status == 200" rule.
    pub validator: Option<Validator>,
    /// If set, replaces the port of the target address for the probe.
    pub port_override: Option<u16>,
    /// Optional observer notified from `health_status_change`.
    pub health_changed_callback: Option<HealthObserveCallback>,
    /// Optional formatter used by `backend_summary`.
    pub backend_summary_callback: Option<BackendSummary>,
}
impl HttpHealthCheck<()> {
    /// Create an HTTP(S) health check that uses the default connector.
    ///
    /// This is exactly [`HttpHealthCheck::new_custom`] with `HttpConnector::new(None)`.
    /// The check sends `GET /` with a `Host: {host}` header and uses `host` as SNI
    /// when `tls` is true. It applies 1-second connection and read timeouts, sets
    /// both thresholds to 1, and does not reuse connections.
    ///
    /// # Panics
    ///
    /// Panics if `host` cannot be appended as the `Host` header value (that
    /// result is unwrapped).
    pub fn new(host: &str, tls: bool) -> Self {
        // Delegate so the default configuration is defined in exactly one place
        // and cannot drift between the two constructors.
        Self::new_custom(host, tls, HttpConnector::new(None))
    }
}
impl<C> HttpHealthCheck<C>
where
    C: custom::Connector,
{
    /// Create an HTTP(S) health check that uses the supplied connector.
    ///
    /// The check sends `GET /` with a `Host: {host}` header and targets a
    /// placeholder peer (`0.0.0.0:1`, replaced per probe). It uses `host` as SNI
    /// only when `tls` is true, applies 1-second connection and read timeouts,
    /// sets both thresholds to 1, and does not reuse connections. No validator,
    /// port override or callbacks are set.
    ///
    /// # Panics
    ///
    /// Panics if `host` cannot be appended as the `Host` header value (that
    /// result is unwrapped).
    pub fn new_custom(host: &str, tls: bool, custom: HttpConnector<C>) -> Self {
        let mut req = RequestHeader::build("GET", b"/", None).unwrap();
        req.append_header("Host", host).unwrap();

        let sni = if tls { host.to_owned() } else { String::new() };
        let mut peer_template = HttpPeer::new("0.0.0.0:1", tls, sni);
        let one_second = Some(Duration::from_secs(1));
        peer_template.options.connection_timeout = one_second;
        peer_template.options.read_timeout = one_second;

        Self {
            consecutive_success: 1,
            consecutive_failure: 1,
            peer_template,
            connector: custom,
            reuse_connection: false,
            req,
            validator: None,
            port_override: None,
            health_changed_callback: None,
            backend_summary_callback: None,
        }
    }

    /// Replace the connector used to obtain HTTP sessions.
    pub fn set_connector(&mut self, connector: HttpConnector<C>) {
        self.connector = connector;
    }

    /// Install a custom formatter used by `backend_summary`.
    pub fn set_backend_summary<F>(&mut self, callback: F)
    where
        F: Fn(&Backend) -> String + Send + Sync + 'static,
    {
        self.backend_summary_callback = Some(Box::new(callback));
    }
}
#[async_trait]
impl<C> HealthCheck for HttpHealthCheck<C>
where
    C: custom::Connector,
{
    fn health_threshold(&self, success: bool) -> usize {
        match success {
            true => self.consecutive_success,
            false => self.consecutive_failure,
        }
    }

    /// Send `self.req` to `target` and validate the response.
    ///
    /// The response is accepted when `validator` returns `Ok`. Without a
    /// validator, it is accepted when the status is 200. The body (and any custom
    /// messages) are then read to completion. With `reuse_connection`, the
    /// session is released back to the connector.
    ///
    /// # Errors
    ///
    /// Returns any connection or I/O error, the validator's error, or a
    /// `CustomCode("non 200 code", status)` error when no validator is set.
    async fn check(&self, target: &Backend) -> Result<()> {
        // Aim a copy of the template at this backend, optionally on another port.
        let mut peer = self.peer_template.clone();
        peer._address = target.addr.clone();
        if let Some(port) = self.port_override {
            peer._address.set_port(port);
        }

        let mut session = self.connector.get_http_session(&peer).await?.0;
        session
            .write_request_header(Box::new(self.req.clone()))
            .await?;
        session.finish_request_body().await?;
        custom_session!(session.finish_custom().await?);

        if let Some(timeout) = peer.options.read_timeout {
            session.set_read_timeout(Some(timeout));
        }
        session.read_response_header().await?;

        let resp = session.response_header().expect("just read");
        match self.validator.as_ref() {
            Some(validate) => validate(resp)?,
            None if resp.status != 200 => {
                return Error::e_explain(
                    CustomCode("non 200 code", resp.status.as_u16()),
                    "during http healthcheck",
                );
            }
            None => {}
        }

        // Consume any remaining body. Presumably needed before the session can be
        // handed back for reuse.
        while session.read_response_body().await?.is_some() {}
        custom_session!(session.drain_custom_messages().await?);

        if self.reuse_connection {
            let idle_timeout = peer.idle_timeout();
            self.connector
                .release_http_session(session, &peer, idle_timeout)
                .await;
        }
        Ok(())
    }

    /// Forward the status change to `health_changed_callback`, if one is set.
    async fn health_status_change(&self, target: &Backend, healthy: bool) {
        if let Some(observer) = self.health_changed_callback.as_ref() {
            observer.observe(target, healthy).await;
        }
    }

    /// Use `backend_summary_callback` when set, otherwise the `Debug` form of `target`.
    fn backend_summary(&self, target: &Backend) -> String {
        match &self.backend_summary_callback {
            Some(summarize) => summarize(target),
            None => format!("{target:?}"),
        }
    }
}
/// Immutable snapshot of a backend's health state, swapped as a whole by [`Health`].
#[derive(Clone)]
struct HealthInner {
    /// Result of the health checks after thresholds are applied.
    healthy: bool,
    /// Administrative on/off switch, independent of `healthy`.
    enabled: bool,
    /// Length of the current streak of observations that disagree with `healthy`.
    consecutive_counter: usize,
}
/// Health state of one backend, stored as an atomically swappable [`HealthInner`] snapshot.
pub(crate) struct Health(ArcSwap<HealthInner>);
impl Default for Health {
fn default() -> Self {
Health(ArcSwap::new(Arc::new(HealthInner {
healthy: true, enabled: true,
consecutive_counter: 0,
})))
}
}
impl Clone for Health {
fn clone(&self) -> Self {
let inner = self.0.load_full();
Health(ArcSwap::new(inner))
}
}
impl Health {
pub fn ready(&self) -> bool {
let h = self.0.load();
h.healthy && h.enabled
}
pub fn enable(&self, enabled: bool) {
let h = self.0.load();
if h.enabled != enabled {
let mut new_health = (**h).clone();
new_health.enabled = enabled;
self.0.store(Arc::new(new_health));
};
}
pub fn observe_health(&self, health: bool, flip_threshold: usize) -> bool {
let h = self.0.load();
let mut flipped = false;
if h.healthy != health {
let mut new_health = (**h).clone();
new_health.consecutive_counter += 1;
if new_health.consecutive_counter >= flip_threshold {
new_health.healthy = health;
new_health.consecutive_counter = 0;
flipped = true;
}
self.0.store(Arc::new(new_health));
} else if h.consecutive_counter > 0 {
let mut new_health = (**h).clone();
new_health.consecutive_counter = 0;
self.0.store(Arc::new(new_health));
}
flipped
}
}
#[cfg(test)]
mod test {
    use std::{
        collections::{BTreeSet, HashMap},
        sync::atomic::{AtomicU16, Ordering},
    };

    use super::*;
    use crate::{discovery, Backends, SocketAddr};
    use async_trait::async_trait;
    use http::Extensions;

    // NOTE: several tests below connect to 1.1.1.1 and need network access.

    /// Weight-1 backend with no extensions for an `ip:port` literal.
    fn inet_backend(addr: &str) -> Backend {
        Backend {
            addr: SocketAddr::Inet(addr.parse().unwrap()),
            weight: 1,
            ext: Extensions::new(),
        }
    }

    #[tokio::test]
    async fn test_tcp_check() {
        let tcp_check = TcpHealthCheck::default();
        assert!(tcp_check.check(&inet_backend("1.1.1.1:80")).await.is_ok());
        // Port 79 is expected not to accept connections.
        assert!(tcp_check.check(&inet_backend("1.1.1.1:79")).await.is_err());
    }

    #[cfg(feature = "any_tls")]
    #[tokio::test]
    async fn test_tls_check() {
        let tls_check = TcpHealthCheck::new_tls("one.one.one.one");
        assert!(tls_check.check(&inet_backend("1.1.1.1:443")).await.is_ok());
    }

    #[cfg(feature = "any_tls")]
    #[tokio::test]
    async fn test_https_check() {
        let https_check = HttpHealthCheck::new("one.one.one.one", true);
        assert!(https_check.check(&inet_backend("1.1.1.1:443")).await.is_ok());
    }

    #[tokio::test]
    async fn test_http_custom_check() {
        let mut http_check = HttpHealthCheck::new("one.one.one.one", false);
        // Plain HTTP on port 80 is expected to answer with a 301 redirect.
        http_check.validator = Some(Box::new(|resp: &ResponseHeader| {
            if resp.status == 301 {
                Ok(())
            } else {
                Error::e_explain(
                    CustomCode("non 301 code", resp.status.as_u16()),
                    "during http healthcheck",
                )
            }
        }));

        let backend = inet_backend("1.1.1.1:80");
        http_check.check(&backend).await.unwrap();
        assert!(http_check.check(&backend).await.is_ok());
    }

    #[tokio::test]
    async fn test_health_observe() {
        struct Observe {
            unhealthy_count: Arc<AtomicU16>,
        }
        #[async_trait]
        impl HealthObserve for Observe {
            async fn observe(&self, _target: &Backend, healthy: bool) {
                if !healthy {
                    self.unhealthy_count.fetch_add(1, Ordering::Relaxed);
                }
            }
        }

        // Initially marked healthy, but nothing is expected to listen on port 79,
        // so one health-check round should flip it to unhealthy.
        let good_backend = Backend::new("127.0.0.1:79").unwrap();
        let new_good_backends = || -> (BTreeSet<Backend>, HashMap<u64, bool>) {
            let mut healthy = HashMap::new();
            healthy.insert(good_backend.hash_key(), true);
            let mut backends = BTreeSet::new();
            backends.insert(good_backend.clone());
            (backends, healthy)
        };

        // TCP health check
        {
            let unhealthy_count = Arc::new(AtomicU16::new(0));
            let ob = Observe {
                unhealthy_count: unhealthy_count.clone(),
            };
            let bob = Box::new(ob);
            let tcp_check = TcpHealthCheck {
                health_changed_callback: Some(bob),
                ..Default::default()
            };

            let discovery = discovery::Static::default();
            let mut backends = Backends::new(Box::new(discovery));
            backends.set_health_check(Box::new(tcp_check));
            let (backend_set, health) = new_good_backends();
            backends.do_update(backend_set, health, |_backend: Arc<BTreeSet<Backend>>| {});
            assert!(backends.ready(&good_backend));

            backends.run_health_check(false).await;
            assert_eq!(unhealthy_count.load(Ordering::Relaxed), 1);
            assert!(!backends.ready(&good_backend));
        }

        // HTTP(S) health check
        {
            let unhealthy_count = Arc::new(AtomicU16::new(0));
            let ob = Observe {
                unhealthy_count: unhealthy_count.clone(),
            };
            let bob = Box::new(ob);
            let mut https_check = HttpHealthCheck::new("one.one.one.one", true);
            https_check.health_changed_callback = Some(bob);

            let discovery = discovery::Static::default();
            let mut backends = Backends::new(Box::new(discovery));
            backends.set_health_check(Box::new(https_check));
            let (backend_set, health) = new_good_backends();
            backends.do_update(backend_set, health, |_backend: Arc<BTreeSet<Backend>>| {});
            assert!(backends.ready(&good_backend));

            backends.run_health_check(false).await;
            assert_eq!(unhealthy_count.load(Ordering::Relaxed), 1);
            assert!(!backends.ready(&good_backend));
        }
    }
}
}