#![allow(clippy::non_canonical_partial_ord_impl)]
use arc_swap::ArcSwap;
use derivative::Derivative;
use futures::FutureExt;
pub use http::Extensions;
use pingora_core::protocols::l4::socket::SocketAddr;
use pingora_error::{ErrorType, OrErr, Result};
use std::collections::hash_map::DefaultHasher;
use std::collections::{BTreeSet, HashMap};
use std::hash::{Hash, Hasher};
use std::io::Result as IoResult;
use std::net::ToSocketAddrs;
use std::sync::Arc;
use std::time::Duration;
mod background;
pub mod discovery;
pub mod health_check;
pub mod selection;
use discovery::ServiceDiscovery;
use health_check::Health;
use selection::UniqueIterator;
use selection::{BackendIter, BackendSelection};
pub mod prelude {
pub use crate::health_check::TcpHealthCheck;
pub use crate::selection::RoundRobin;
pub use crate::LoadBalancer;
}
/// A backend, described by its socket address and a weight.
///
/// `Clone`, `Hash`, `PartialEq`, `PartialOrd`, `Eq`, `Ord` and `Debug` are derived.
/// The `ext` field is excluded from equality, ordering and hashing via the
/// `ignore` attributes, so two backends with the same `addr` and `weight`
/// compare equal and hash identically regardless of their extensions.
#[derive(Derivative)]
#[derivative(Clone, Hash, PartialEq, PartialOrd, Eq, Ord, Debug)]
pub struct Backend {
/// The address of the backend.
pub addr: SocketAddr,
/// The weight of the backend; [`Backend::new`] sets it to 1.
/// NOTE(review): presumably interpreted by the selection algorithms — confirm there.
pub weight: usize,
/// Typed extension data attached to the backend.
/// Ignored by `PartialEq`, `PartialOrd`, `Hash` and `Ord` (see the attributes below).
#[derivative(PartialEq = "ignore")]
#[derivative(PartialOrd = "ignore")]
#[derivative(Hash = "ignore")]
#[derivative(Ord = "ignore")]
pub ext: Extensions,
}
impl Backend {
pub fn new(addr: &str) -> Result<Self> {
Self::new_with_weight(addr, 1)
}
pub fn new_with_weight(addr: &str, weight: usize) -> Result<Self> {
let addr = addr
.parse()
.or_err(ErrorType::InternalError, "invalid socket addr")?;
Ok(Backend {
addr: SocketAddr::Inet(addr),
weight,
ext: Extensions::new(),
})
}
pub(crate) fn hash_key(&self) -> u64 {
let mut hasher = DefaultHasher::new();
self.hash(&mut hasher);
hasher.finish()
}
}
/// Dereferencing a [`Backend`] yields a shared reference to its `addr`.
impl std::ops::Deref for Backend {
    type Target = SocketAddr;

    fn deref(&self) -> &SocketAddr {
        &self.addr
    }
}
impl std::ops::DerefMut for Backend {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.addr
}
}
impl std::net::ToSocketAddrs for Backend {
type Iter = std::iter::Once<std::net::SocketAddr>;
fn to_socket_addrs(&self) -> std::io::Result<Self::Iter> {
self.addr.to_socket_addrs()
}
}
/// The backends produced by a service discovery, together with their health state.
pub struct Backends {
/// Source of the backend set; queried by `update`.
discovery: Box<dyn ServiceDiscovery + Send + Sync + 'static>,
/// Optional health check; `run_health_check` returns immediately while this is `None`.
health_check: Option<Arc<dyn health_check::HealthCheck + Send + Sync + 'static>>,
/// The current backend set, replaced as a whole when an update sees a different set.
backends: ArcSwap<BTreeSet<Backend>>,
/// Health state per backend, keyed by `Backend::hash_key`.
health: ArcSwap<HashMap<u64, Health>>,
}
impl Backends {
pub fn new(discovery: Box<dyn ServiceDiscovery + Send + Sync + 'static>) -> Self {
Self {
discovery,
health_check: None,
backends: Default::default(),
health: Default::default(),
}
}
pub fn set_health_check(
&mut self,
hc: Box<dyn health_check::HealthCheck + Send + Sync + 'static>,
) {
self.health_check = Some(hc.into())
}
/// Applies a discovery result to the stored backend set and health table.
///
/// * `new_backends` - the backend set reported by discovery.
/// * `enablement` - enablement overrides keyed by `Backend::hash_key`.
/// * `callback` - invoked with the new set, but only when the set differs
///   from the currently stored one.
///
/// The comparison uses `Backend`'s derived equality, which ignores `ext`; a
/// discovery result that differs only in extensions is treated as unchanged.
fn do_update<F>(
&self,
new_backends: BTreeSet<Backend>,
enablement: HashMap<u64, bool>,
callback: F,
) where
F: Fn(Arc<BTreeSet<Backend>>),
{
if (**self.backends.load()) != new_backends {
// The set changed: rebuild the health table for exactly the new backends.
let old_health = self.health.load();
let mut health = HashMap::with_capacity(new_backends.len());
for backend in new_backends.iter() {
let hash_key = backend.hash_key();
// Carry over the health of a backend that was already known; a new
// backend starts with the default health.
let backend_health = old_health.get(&hash_key).cloned().unwrap_or_default();
// An explicit enablement entry overrides the carried-over state.
if let Some(backend_enabled) = enablement.get(&hash_key) {
backend_health.enable(*backend_enabled);
}
health.insert(hash_key, backend_health);
}
let new_backends = Arc::new(new_backends);
// The callback runs before the new set is published below.
callback(new_backends.clone());
// NOTE(review): the backend set and the health table are published by two
// separate stores, so a reader can observe the new set before its health
// table; `ready` falls back to a default for backends missing from the table.
self.backends.store(new_backends);
self.health.store(Arc::new(health));
} else {
// The set is unchanged: only apply the enablement overrides.
for (hash_key, backend_enabled) in enablement.iter() {
// Keys that are not present in the health table are skipped.
if let Some(backend_health) = self.health.load().get(hash_key) {
backend_health.enable(*backend_enabled);
}
}
}
}
/// Reports whether `backend` is ready.
///
/// For a backend present in the health table this is the result of its
/// `Health::ready()`. For a backend missing from the table the answer is
/// `true` only when no health check is configured.
pub fn ready(&self, backend: &Backend) -> bool {
    let table = self.health.load();
    match table.get(&backend.hash_key()) {
        Some(entry) => entry.ready(),
        None => self.health_check.is_none(),
    }
}
/// Sets the enablement of `backend` to `enabled`.
///
/// Does nothing when `backend` has no entry in the health table.
pub fn set_enable(&self, backend: &Backend, enabled: bool) {
    let table = self.health.load();
    if let Some(entry) = table.get(&backend.hash_key()) {
        entry.enable(enabled);
    }
}
/// Returns a shared handle to the currently stored backend set.
///
/// The returned `Arc` is a snapshot: a later update stores a new set and
/// does not modify the one handed out here.
pub fn get_backend(&self) -> Arc<BTreeSet<Backend>> {
self.backends.load_full()
}
/// Runs service discovery once and applies its result.
///
/// `callback` is forwarded to `do_update`, which invokes it with the new
/// backend set only when that set differs from the stored one.
///
/// # Errors
/// Returns the error produced by the discovery; nothing is changed in that case.
pub async fn update<F>(&self, callback: F) -> Result<()>
where
    F: Fn(Arc<BTreeSet<Backend>>),
{
    let discovered = self.discovery.discover().await;
    discovered.map(|(backends, enablement)| self.do_update(backends, enablement, callback))
}
/// Runs the configured health check once against every stored backend.
///
/// Returns immediately when no health check has been set.
///
/// * `parallel` - when `true`, every backend is checked in its own task
///   spawned on the current runtime handle and all tasks are awaited; when
///   `false`, the backends are checked one after another.
///
/// A spawned task that fails to complete (for example because it panicked)
/// is reported with a warning instead of being silently discarded.
pub async fn run_health_check(&self, parallel: bool) {
    use crate::health_check::HealthCheck;
    use log::{info, warn};
    use pingora_runtime::current_handle;

    /// Checks one backend, records the observation in `health_table` and,
    /// when the backend's status flipped, notifies the check and logs it.
    /// Backends without an entry in `health_table` are checked but not recorded.
    async fn check_and_report(
        backend: &Backend,
        check: &Arc<dyn HealthCheck + Send + Sync>,
        health_table: &HashMap<u64, Health>,
    ) {
        let errored = check.check(backend).await.err();
        if let Some(h) = health_table.get(&backend.hash_key()) {
            let flipped =
                h.observe_health(errored.is_none(), check.health_threshold(errored.is_none()));
            if flipped {
                check.health_status_change(backend, errored.is_none()).await;
                let summary = check.backend_summary(backend);
                if let Some(e) = errored {
                    warn!("{summary} becomes unhealthy, {e}");
                } else {
                    info!("{summary} becomes healthy");
                }
            }
        }
    }

    let Some(health_check) = self.health_check.as_ref() else {
        return;
    };

    let backends = self.backends.load();
    if parallel {
        // One snapshot of the health table is shared by all spawned tasks.
        let health_table = self.health.load_full();
        let runtime = current_handle();
        let jobs = backends.iter().map(|backend| {
            let backend = backend.clone();
            let check = health_check.clone();
            let ht = health_table.clone();
            runtime.spawn(async move {
                check_and_report(&backend, &check, &ht).await;
            })
        });
        // Surface tasks that did not run to completion rather than dropping
        // their join errors.
        for outcome in futures::future::join_all(jobs).await {
            if let Err(e) = outcome {
                warn!("health check task did not complete: {e}");
            }
        }
    } else {
        for backend in backends.iter() {
            // The health table is re-loaded for every backend.
            check_and_report(backend, health_check, &self.health.load()).await;
        }
    }
}
}
/// A load balancer that pairs a [`Backends`] collection with a selection algorithm `S`.
pub struct LoadBalancer<S>
where
S: BackendSelection,
{
/// The backends and their health state.
backends: Backends,
/// The current selector; `update` stores a freshly built one when the backend set changes.
selector: ArcSwap<S>,
/// Optional selector configuration; when present, selectors are built with
/// `S::build_with_config`, otherwise with `S::build`.
config: Option<S::Config>,
/// Interval between health checks. Initialised to `None` by the constructors here.
/// NOTE(review): not read in this file; presumably consumed by the `background`
/// module — confirm what `None` means there.
pub health_check_frequency: Option<Duration>,
/// Interval between backend updates. Initialised to `None` by the constructors here.
/// NOTE(review): not read in this file; presumably consumed by the `background`
/// module — confirm what `None` means there.
pub update_frequency: Option<Duration>,
/// Whether health checks should run in parallel. Initialised to `false`.
/// NOTE(review): presumably forwarded to `Backends::run_health_check` — confirm.
pub parallel_health_check: bool,
}
impl<S> LoadBalancer<S>
where
S: BackendSelection + 'static,
S::Iter: BackendIter,
{
/// Builds a [`LoadBalancer`] over a static discovery created from `iter`.
///
/// The load balancer is updated once before it is returned, so its backend
/// set is already populated.
///
/// # Errors
/// Returns the I/O error produced by `discovery::Static::try_from_iter`.
///
/// # Panics
/// Panics if the initial update does not complete immediately or fails.
pub fn try_from_iter<A, T: IntoIterator<Item = A>>(iter: T) -> IoResult<Self>
where
    A: ToSocketAddrs,
{
    let static_discovery = discovery::Static::try_from_iter(iter)?;
    let lb = Self::from_backends(Backends::new(static_discovery));
    // Poll the update future exactly once; it is expected to be ready at once.
    let first_update = lb.update().now_or_never();
    first_update
        .expect("static should not block")
        .expect("static should not error");
    Ok(lb)
}
pub fn from_backends_with_config(backends: Backends, config_opt: Option<S::Config>) -> Self {
let selector_raw = if let Some(config) = config_opt.as_ref() {
S::build_with_config(&backends.get_backend(), config)
} else {
S::build(&backends.get_backend())
};
let selector = ArcSwap::new(Arc::new(selector_raw));
LoadBalancer {
backends,
selector,
config: config_opt,
health_check_frequency: None,
update_frequency: None,
parallel_health_check: false,
}
}
/// Builds a [`LoadBalancer`] from `backends` without a selector configuration.
///
/// Equivalent to calling `from_backends_with_config` with `None`.
pub fn from_backends(backends: Backends) -> Self {
Self::from_backends_with_config(backends, None)
}
pub async fn update(&self) -> Result<()> {
self.backends
.update(|backends| {
let selector = if let Some(config) = &self.config {
S::build_with_config(&backends, config)
} else {
S::build(&backends)
};
self.selector.store(Arc::new(selector))
})
.await
}
/// Selects a backend for `key`, accepting only backends reported as ready.
///
/// This is `select_with` with an `accept` function that returns the
/// readiness flag unchanged. Returns `None` when no candidate is accepted.
pub fn select(&self, key: &[u8], max_iterations: usize) -> Option<Backend> {
    self.select_with(key, max_iterations, |_backend, is_ready| is_ready)
}
/// Selects the first candidate backend for `key` that `accept` approves.
///
/// Candidates come from the current selector's iterator for `key`, wrapped
/// in a `UniqueIterator` constructed with `max_iterations`. `accept` receives
/// each candidate together with the result of `Backends::ready` for it.
///
/// Returns `None` once the iterator is exhausted without an accepted candidate.
pub fn select_with<F>(&self, key: &[u8], max_iterations: usize, accept: F) -> Option<Backend>
where
    F: Fn(&Backend, bool) -> bool,
{
    let selection = self.selector.load();
    let mut candidates = UniqueIterator::new(selection.iter(key), max_iterations);
    loop {
        // `?` ends the search with `None` when no candidates are left.
        let candidate = candidates.get_next()?;
        let is_ready = self.backends.ready(&candidate);
        if accept(&candidate, is_ready) {
            return Some(candidate);
        }
    }
}
/// Installs `hc` as the health check of the underlying [`Backends`].
pub fn set_health_check(
&mut self,
hc: Box<dyn health_check::HealthCheck + Send + Sync + 'static>,
) {
self.backends.set_health_check(hc);
}
/// Returns a reference to the [`Backends`] owned by this load balancer.
pub fn backends(&self) -> &Backends {
&self.backends
}
}
#[cfg(test)]
mod test {
use std::sync::atomic::{AtomicBool, Ordering::Relaxed};
use super::*;
use async_trait::async_trait;
/// A load balancer built from a static address list contains one backend per address.
#[tokio::test]
async fn test_static_backends() {
    let lb: LoadBalancer<selection::RoundRobin> =
        LoadBalancer::try_from_iter(["1.1.1.1:80", "1.0.0.1:80"]).unwrap();

    let discovered = lb.backends().get_backend();
    for addr in ["1.1.1.1:80", "1.0.0.1:80"] {
        let expected = Backend::new(addr).unwrap();
        assert!(discovered.contains(&expected));
    }
}
// Verifies that the update callback fires only when the backend set changes,
// and that a sequential TCP health check separates ready from non-ready backends.
// NOTE(review): appears to need outbound TCP connectivity to 1.1.1.1:80 and
// 1.0.0.1:80, and assumes nothing accepts connections on 127.0.0.1:79 — confirm
// for the CI environment.
#[tokio::test]
async fn test_backends() {
let discovery = discovery::Static::default();
let good1 = Backend::new("1.1.1.1:80").unwrap();
discovery.add(good1.clone());
let good2 = Backend::new("1.0.0.1:80").unwrap();
discovery.add(good2.clone());
let bad = Backend::new("127.0.0.1:79").unwrap();
discovery.add(bad.clone());
let mut backends = Backends::new(Box::new(discovery));
let check = health_check::TcpHealthCheck::new();
backends.set_health_check(check);
// First update: the set goes from empty to three backends, so the callback runs.
let updated = AtomicBool::new(false);
backends
.update(|_| updated.store(true, Relaxed))
.await
.unwrap();
assert!(updated.load(Relaxed));
// Second update: the discovered set is unchanged, so the callback must not run.
let updated = AtomicBool::new(false);
backends
.update(|_| updated.store(true, Relaxed))
.await
.unwrap();
assert!(!updated.load(Relaxed));
// Check the backends one after another (`parallel == false`).
backends.run_health_check(false).await;
let backend = backends.get_backend();
assert!(backend.contains(&good1));
assert!(backend.contains(&good2));
assert!(backend.contains(&bad));
assert!(backends.ready(&good1));
assert!(backends.ready(&good2));
assert!(!backends.ready(&bad));
}
/// Extension data attached to a backend is still present after discovery and update.
#[tokio::test]
async fn test_backends_with_ext() {
    let discovery = discovery::Static::default();

    let mut with_bool = Backend::new("1.1.1.1:80").unwrap();
    with_bool.ext.insert(true);
    let mut with_u8 = Backend::new("1.0.0.1:80").unwrap();
    with_u8.ext.insert(1u8);
    discovery.add(with_bool.clone());
    discovery.add(with_u8.clone());

    let backends = Backends::new(Box::new(discovery));
    backends.update(|_| {}).await.unwrap();

    let stored = backends.get_backend();
    assert!(stored.contains(&with_bool));
    assert!(stored.contains(&with_u8));

    // The test expects the `u8`-carrying backend to sort first in the set
    // and the `bool`-carrying one last.
    let lowest = stored.first().unwrap();
    assert_eq!(lowest.ext.get::<u8>(), Some(&1));
    let highest = stored.last().unwrap();
    assert_eq!(highest.ext.get::<bool>(), Some(&true));
}
// Verifies that an enablement override returned by discovery marks a backend
// as not ready, without any health check being configured.
#[tokio::test]
async fn test_discovery_readiness() {
use discovery::Static;
// Wraps a `Static` discovery and additionally reports the `bad` backend as disabled.
struct TestDiscovery(Static);
#[async_trait]
impl ServiceDiscovery for TestDiscovery {
async fn discover(&self) -> Result<(BTreeSet<Backend>, HashMap<u64, bool>)> {
let bad = Backend::new("127.0.0.1:79").unwrap();
let (backends, mut readiness) = self.0.discover().await?;
// The override is keyed by the backend's hash key.
readiness.insert(bad.hash_key(), false);
Ok((backends, readiness))
}
}
let discovery = Static::default();
let good1 = Backend::new("1.1.1.1:80").unwrap();
discovery.add(good1.clone());
let good2 = Backend::new("1.0.0.1:80").unwrap();
discovery.add(good2.clone());
let bad = Backend::new("127.0.0.1:79").unwrap();
discovery.add(bad.clone());
let discovery = TestDiscovery(discovery);
let backends = Backends::new(Box::new(discovery));
// The set goes from empty to three backends, so the callback runs.
let updated = AtomicBool::new(false);
backends
.update(|_| updated.store(true, Relaxed))
.await
.unwrap();
assert!(updated.load(Relaxed));
let backend = backends.get_backend();
assert!(backend.contains(&good1));
assert!(backend.contains(&good2));
assert!(backend.contains(&bad));
// No health check ran; only `bad` was disabled through the discovery override.
assert!(backends.ready(&good1));
assert!(backends.ready(&good2));
assert!(!backends.ready(&bad));
}
/// Same expectations as the sequential health-check test, but with the
/// checks run in parallel.
#[tokio::test]
async fn test_parallel_health_check() {
    let discovery = discovery::Static::default();
    let good_a = Backend::new("1.1.1.1:80").unwrap();
    discovery.add(good_a.clone());
    let good_b = Backend::new("1.0.0.1:80").unwrap();
    discovery.add(good_b.clone());
    let bad = Backend::new("127.0.0.1:79").unwrap();
    discovery.add(bad.clone());

    let mut backends = Backends::new(Box::new(discovery));
    backends.set_health_check(health_check::TcpHealthCheck::new());

    // The first update changes the (empty) backend set, so the callback must fire.
    let callback_ran = AtomicBool::new(false);
    let update_result = backends.update(|_| callback_ran.store(true, Relaxed)).await;
    update_result.unwrap();
    assert!(callback_ran.load(Relaxed));

    backends.run_health_check(true).await;

    assert!(backends.ready(&good_a));
    assert!(backends.ready(&good_b));
    assert!(!backends.ready(&bad));
}
// Tests that observe a load balancer from one task while another task updates it.
mod thread_safety {
use super::*;
// Discovery that reports `expected` backends at 1.1.1.1, one per port in `0..expected`.
struct MockDiscovery {
expected: usize,
}
#[async_trait]
impl ServiceDiscovery for MockDiscovery {
async fn discover(&self) -> Result<(BTreeSet<Backend>, HashMap<u64, bool>)> {
let mut d = BTreeSet::new();
let mut m = HashMap::with_capacity(self.expected);
for i in 0..self.expected {
let b = Backend::new(&format!("1.1.1.1:{i}")).unwrap();
// NOTE(review): the enablement map is keyed by the loop index here rather
// than by `Backend::hash_key`, so these entries presumably match no
// backend — confirm that this is intended.
m.insert(i as u64, true);
d.insert(b);
}
Ok((d, m))
}
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn test_consistency() {
let expected = 3000;
let discovery = MockDiscovery { expected };
let lb = Arc::new(LoadBalancer::<selection::Consistent>::from_backends(
Backends::new(Box::new(discovery)),
));
let lb2 = lb.clone();
// Run the update on a separate task while this task observes the backend set.
tokio::spawn(async move {
assert!(lb2.update().await.is_ok());
});
// Spin, without yielding, until the backend set becomes visible.
let mut backend_count = 0;
while backend_count == 0 {
let backends = lb.backends();
backend_count = backends.backends.load_full().len();
}
assert_eq!(backend_count, expected);
// As soon as the backends are visible, a selection must already succeed.
assert!(lb.select_with(b"test", 1, |_, _| true).is_some());
}
}
}