1#![allow(clippy::non_canonical_partial_ord_impl)]
22
23use arc_swap::ArcSwap;
24use derivative::Derivative;
25use futures::FutureExt;
26pub use http::Extensions;
27use pingora_core::protocols::l4::socket::SocketAddr;
28use pingora_error::{ErrorType, OrErr, Result};
29use std::collections::hash_map::DefaultHasher;
30use std::collections::{BTreeSet, HashMap};
31use std::hash::{Hash, Hasher};
32use std::io::Result as IoResult;
33use std::net::ToSocketAddrs;
34use std::sync::Arc;
35use std::time::Duration;
36
37mod background;
38pub mod discovery;
39pub mod health_check;
40pub mod selection;
41
42use discovery::ServiceDiscovery;
43use health_check::Health;
44use selection::UniqueIterator;
45use selection::{BackendIter, BackendSelection};
46
47pub mod prelude {
48 pub use crate::health_check::TcpHealthCheck;
49 pub use crate::selection::RoundRobin;
50 pub use crate::LoadBalancer;
51}
52
53#[derive(Derivative)]
55#[derivative(Clone, Hash, PartialEq, PartialOrd, Eq, Ord, Debug)]
56pub struct Backend {
57 pub addr: SocketAddr,
59 pub weight: usize,
62
63 #[derivative(PartialEq = "ignore")]
70 #[derivative(PartialOrd = "ignore")]
71 #[derivative(Hash = "ignore")]
72 #[derivative(Ord = "ignore")]
73 pub ext: Extensions,
74}
75
76impl Backend {
77 pub fn new(addr: &str) -> Result<Self> {
80 Self::new_with_weight(addr, 1)
81 }
82
83 pub fn new_with_weight(addr: &str, weight: usize) -> Result<Self> {
86 let addr = addr
87 .parse()
88 .or_err(ErrorType::InternalError, "invalid socket addr")?;
89 Ok(Backend {
90 addr: SocketAddr::Inet(addr),
91 weight,
92 ext: Extensions::new(),
93 })
94 }
96
97 pub(crate) fn hash_key(&self) -> u64 {
98 let mut hasher = DefaultHasher::new();
99 self.hash(&mut hasher);
100 hasher.finish()
101 }
102}
103
104impl std::ops::Deref for Backend {
105 type Target = SocketAddr;
106
107 fn deref(&self) -> &Self::Target {
108 &self.addr
109 }
110}
111
112impl std::ops::DerefMut for Backend {
113 fn deref_mut(&mut self) -> &mut Self::Target {
114 &mut self.addr
115 }
116}
117
118impl std::net::ToSocketAddrs for Backend {
119 type Iter = std::iter::Once<std::net::SocketAddr>;
120
121 fn to_socket_addrs(&self) -> std::io::Result<Self::Iter> {
122 self.addr.to_socket_addrs()
123 }
124}
125
126pub struct Backends {
132 discovery: Box<dyn ServiceDiscovery + Send + Sync + 'static>,
133 health_check: Option<Arc<dyn health_check::HealthCheck + Send + Sync + 'static>>,
134 backends: ArcSwap<BTreeSet<Backend>>,
135 health: ArcSwap<HashMap<u64, Health>>,
136}
137
138impl Backends {
139 pub fn new(discovery: Box<dyn ServiceDiscovery + Send + Sync + 'static>) -> Self {
143 Self {
144 discovery,
145 health_check: None,
146 backends: Default::default(),
147 health: Default::default(),
148 }
149 }
150
151 pub fn set_health_check(
153 &mut self,
154 hc: Box<dyn health_check::HealthCheck + Send + Sync + 'static>,
155 ) {
156 self.health_check = Some(hc.into())
157 }
158
159 fn do_update<F>(
163 &self,
164 new_backends: BTreeSet<Backend>,
165 enablement: HashMap<u64, bool>,
166 callback: F,
167 ) where
168 F: Fn(Arc<BTreeSet<Backend>>),
169 {
170 if (**self.backends.load()) != new_backends {
171 let old_health = self.health.load();
172 let mut health = HashMap::with_capacity(new_backends.len());
173 for backend in new_backends.iter() {
174 let hash_key = backend.hash_key();
175 let backend_health = old_health.get(&hash_key).cloned().unwrap_or_default();
177
178 if let Some(backend_enabled) = enablement.get(&hash_key) {
180 backend_health.enable(*backend_enabled);
181 }
182 health.insert(hash_key, backend_health);
183 }
184
185 let new_backends = Arc::new(new_backends);
190 callback(new_backends.clone());
191 self.backends.store(new_backends);
192 self.health.store(Arc::new(health));
193 } else {
194 for (hash_key, backend_enabled) in enablement.iter() {
196 if let Some(backend_health) = self.health.load().get(hash_key) {
199 backend_health.enable(*backend_enabled);
200 }
201 }
202 }
203 }
204
205 pub fn ready(&self, backend: &Backend) -> bool {
212 self.health
213 .load()
214 .get(&backend.hash_key())
215 .map_or(self.health_check.is_none(), |h| h.ready())
218 }
219
220 pub fn set_enable(&self, backend: &Backend, enabled: bool) {
227 if let Some(h) = self.health.load().get(&backend.hash_key()) {
229 h.enable(enabled)
230 };
231 }
232
233 pub fn get_backend(&self) -> Arc<BTreeSet<Backend>> {
235 self.backends.load_full()
236 }
237
238 pub async fn update<F>(&self, callback: F) -> Result<()>
243 where
244 F: Fn(Arc<BTreeSet<Backend>>),
245 {
246 let (new_backends, enablement) = self.discovery.discover().await?;
247 self.do_update(new_backends, enablement, callback);
248 Ok(())
249 }
250
251 pub async fn run_health_check(&self, parallel: bool) {
255 use crate::health_check::HealthCheck;
256 use log::{info, warn};
257 use pingora_runtime::current_handle;
258
259 async fn check_and_report(
260 backend: &Backend,
261 check: &Arc<dyn HealthCheck + Send + Sync>,
262 health_table: &HashMap<u64, Health>,
263 ) {
264 let errored = check.check(backend).await.err();
265 if let Some(h) = health_table.get(&backend.hash_key()) {
266 let flipped =
267 h.observe_health(errored.is_none(), check.health_threshold(errored.is_none()));
268 if flipped {
269 check.health_status_change(backend, errored.is_none()).await;
270 let summary = check.backend_summary(backend);
271 if let Some(e) = errored {
272 warn!("{summary} becomes unhealthy, {e}");
273 } else {
274 info!("{summary} becomes healthy");
275 }
276 }
277 }
278 }
279
280 let Some(health_check) = self.health_check.as_ref() else {
281 return;
282 };
283
284 let backends = self.backends.load();
285 if parallel {
286 let health_table = self.health.load_full();
287 let runtime = current_handle();
288 let jobs = backends.iter().map(|backend| {
289 let backend = backend.clone();
290 let check = health_check.clone();
291 let ht = health_table.clone();
292 runtime.spawn(async move {
293 check_and_report(&backend, &check, &ht).await;
294 })
295 });
296
297 futures::future::join_all(jobs).await;
298 } else {
299 for backend in backends.iter() {
300 check_and_report(backend, health_check, &self.health.load()).await;
301 }
302 }
303 }
304}
305
306pub struct LoadBalancer<S> {
312 backends: Backends,
313 selector: ArcSwap<S>,
314 pub health_check_frequency: Option<Duration>,
318 pub update_frequency: Option<Duration>,
322 pub parallel_health_check: bool,
324}
325
326impl<S> LoadBalancer<S>
327where
328 S: BackendSelection + 'static,
329 S::Iter: BackendIter,
330{
331 pub fn try_from_iter<A, T: IntoIterator<Item = A>>(iter: T) -> IoResult<Self>
336 where
337 A: ToSocketAddrs,
338 {
339 let discovery = discovery::Static::try_from_iter(iter)?;
340 let backends = Backends::new(discovery);
341 let lb = Self::from_backends(backends);
342 lb.update()
343 .now_or_never()
344 .expect("static should not block")
345 .expect("static should not error");
346 Ok(lb)
347 }
348
349 pub fn from_backends_with_config(backends: Backends, config: &S::Config) -> Self {
351 let selector = ArcSwap::new(Arc::new(S::build_with_config(
352 &backends.get_backend(),
353 config,
354 )));
355
356 LoadBalancer {
357 backends,
358 selector,
359 health_check_frequency: None,
360 update_frequency: None,
361 parallel_health_check: false,
362 }
363 }
364
365 pub fn from_backends(backends: Backends) -> Self {
367 let selector = ArcSwap::new(Arc::new(S::build(&backends.get_backend())));
368 LoadBalancer {
369 backends,
370 selector,
371 health_check_frequency: None,
372 update_frequency: None,
373 parallel_health_check: false,
374 }
375 }
376
377 pub async fn update(&self) -> Result<()> {
382 self.backends
383 .update(|backends| self.selector.store(Arc::new(S::build(&backends))))
384 .await
385 }
386
387 pub fn select(&self, key: &[u8], max_iterations: usize) -> Option<Backend> {
398 self.select_with(key, max_iterations, |_, health| health)
399 }
400
401 pub fn select_with<F>(&self, key: &[u8], max_iterations: usize, accept: F) -> Option<Backend>
409 where
410 F: Fn(&Backend, bool) -> bool,
411 {
412 let selection = self.selector.load();
413 let mut iter = UniqueIterator::new(selection.iter(key), max_iterations);
414 while let Some(b) = iter.get_next() {
415 if accept(&b, self.backends.ready(&b)) {
416 return Some(b);
417 }
418 }
419 None
420 }
421
422 pub fn set_health_check(
424 &mut self,
425 hc: Box<dyn health_check::HealthCheck + Send + Sync + 'static>,
426 ) {
427 self.backends.set_health_check(hc);
428 }
429
430 pub fn backends(&self) -> &Backends {
432 &self.backends
433 }
434}
435
#[cfg(test)]
mod test {
    use std::sync::atomic::{AtomicBool, Ordering::Relaxed};

    use super::*;
    use async_trait::async_trait;

    // `try_from_iter` runs the initial static update, so both addresses are
    // already in the backend set when it returns.
    #[tokio::test]
    async fn test_static_backends() {
        let backends: LoadBalancer<selection::RoundRobin> =
            LoadBalancer::try_from_iter(["1.1.1.1:80", "1.0.0.1:80"]).unwrap();

        let backend1 = Backend::new("1.1.1.1:80").unwrap();
        let backend2 = Backend::new("1.0.0.1:80").unwrap();
        let backend = backends.backends().get_backend();
        assert!(backend.contains(&backend1));
        assert!(backend.contains(&backend2));
    }

    // Update callback semantics plus a sequential TCP health check.
    // NOTE(review): depends on outbound connectivity to 1.1.1.1:80 and 1.0.0.1:80,
    // and on nothing listening on 127.0.0.1:79.
    #[tokio::test]
    async fn test_backends() {
        let discovery = discovery::Static::default();
        let good1 = Backend::new("1.1.1.1:80").unwrap();
        discovery.add(good1.clone());
        let good2 = Backend::new("1.0.0.1:80").unwrap();
        discovery.add(good2.clone());
        let bad = Backend::new("127.0.0.1:79").unwrap();
        discovery.add(bad.clone());

        let mut backends = Backends::new(Box::new(discovery));
        let check = health_check::TcpHealthCheck::new();
        backends.set_health_check(check);

        // First update: the set changes from empty, so the callback must fire.
        let updated = AtomicBool::new(false);
        backends
            .update(|_| updated.store(true, Relaxed))
            .await
            .unwrap();
        assert!(updated.load(Relaxed));

        // Second update: discovery returns the same set, so the callback is skipped.
        let updated = AtomicBool::new(false);
        backends
            .update(|_| updated.store(true, Relaxed))
            .await
            .unwrap();
        assert!(!updated.load(Relaxed));

        backends.run_health_check(false).await;

        let backend = backends.get_backend();
        assert!(backend.contains(&good1));
        assert!(backend.contains(&good2));
        assert!(backend.contains(&bad));

        assert!(backends.ready(&good1));
        assert!(backends.ready(&good2));
        assert!(!backends.ready(&bad));
    }
    // `ext` data survives discovery and update. The `first()`/`last()` assertions
    // expect 1.0.0.1:80 to sort before 1.1.1.1:80 in the BTreeSet.
    #[tokio::test]
    async fn test_backends_with_ext() {
        let discovery = discovery::Static::default();
        let mut b1 = Backend::new("1.1.1.1:80").unwrap();
        b1.ext.insert(true);
        let mut b2 = Backend::new("1.0.0.1:80").unwrap();
        b2.ext.insert(1u8);
        discovery.add(b1.clone());
        discovery.add(b2.clone());

        let backends = Backends::new(Box::new(discovery));

        // Fill in the backends.
        backends.update(|_| {}).await.unwrap();

        let backend = backends.get_backend();
        assert!(backend.contains(&b1));
        assert!(backend.contains(&b2));

        let b2 = backend.first().unwrap();
        assert_eq!(b2.ext.get::<u8>(), Some(&1));

        let b1 = backend.last().unwrap();
        assert_eq!(b1.ext.get::<bool>(), Some(&true));
    }

    // A discovery that reports `bad` as disabled through the enablement map. No
    // health check is configured, so readiness reflects enablement alone.
    #[tokio::test]
    async fn test_discovery_readiness() {
        use discovery::Static;

        struct TestDiscovery(Static);
        #[async_trait]
        impl ServiceDiscovery for TestDiscovery {
            async fn discover(&self) -> Result<(BTreeSet<Backend>, HashMap<u64, bool>)> {
                let bad = Backend::new("127.0.0.1:79").unwrap();
                let (backends, mut readiness) = self.0.discover().await?;
                readiness.insert(bad.hash_key(), false);
                Ok((backends, readiness))
            }
        }
        let discovery = Static::default();
        let good1 = Backend::new("1.1.1.1:80").unwrap();
        discovery.add(good1.clone());
        let good2 = Backend::new("1.0.0.1:80").unwrap();
        discovery.add(good2.clone());
        let bad = Backend::new("127.0.0.1:79").unwrap();
        discovery.add(bad.clone());
        let discovery = TestDiscovery(discovery);

        let backends = Backends::new(Box::new(discovery));

        // First update: the set changes from empty, so the callback must fire.
        let updated = AtomicBool::new(false);
        backends
            .update(|_| updated.store(true, Relaxed))
            .await
            .unwrap();
        assert!(updated.load(Relaxed));

        let backend = backends.get_backend();
        assert!(backend.contains(&good1));
        assert!(backend.contains(&good2));
        assert!(backend.contains(&bad));

        assert!(backends.ready(&good1));
        assert!(backends.ready(&good2));
        assert!(!backends.ready(&bad));
    }

    // Same scenario as `test_backends`, but with the parallel health check.
    // NOTE(review): also depends on outbound network access.
    #[tokio::test]
    async fn test_parallel_health_check() {
        let discovery = discovery::Static::default();
        let good1 = Backend::new("1.1.1.1:80").unwrap();
        discovery.add(good1.clone());
        let good2 = Backend::new("1.0.0.1:80").unwrap();
        discovery.add(good2.clone());
        let bad = Backend::new("127.0.0.1:79").unwrap();
        discovery.add(bad.clone());

        let mut backends = Backends::new(Box::new(discovery));
        let check = health_check::TcpHealthCheck::new();
        backends.set_health_check(check);

        // First update: the set changes from empty, so the callback must fire.
        let updated = AtomicBool::new(false);
        backends
            .update(|_| updated.store(true, Relaxed))
            .await
            .unwrap();
        assert!(updated.load(Relaxed));

        backends.run_health_check(true).await;

        assert!(backends.ready(&good1));
        assert!(backends.ready(&good2));
        assert!(!backends.ready(&bad));
    }

    mod thread_safety {
        use super::*;

        // Produces `expected` backends on 1.1.1.1, ports 0..expected. The
        // enablement keys are plain indices rather than `hash_key()` values, so
        // they presumably match no backend.
        struct MockDiscovery {
            expected: usize,
        }
        #[async_trait]
        impl ServiceDiscovery for MockDiscovery {
            async fn discover(&self) -> Result<(BTreeSet<Backend>, HashMap<u64, bool>)> {
                let mut d = BTreeSet::new();
                let mut m = HashMap::with_capacity(self.expected);
                for i in 0..self.expected {
                    let b = Backend::new(&format!("1.1.1.1:{i}")).unwrap();
                    m.insert(i as u64, true);
                    d.insert(b);
                }
                Ok((d, m))
            }
        }

        // Spins until the update task publishes the backend set, then checks that
        // the selector is already usable. This relies on `do_update` running the
        // selector-rebuilding callback before storing the new set.
        #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
        async fn test_consistency() {
            let expected = 3000;
            let discovery = MockDiscovery { expected };
            let lb = Arc::new(LoadBalancer::<selection::Consistent>::from_backends(
                Backends::new(Box::new(discovery)),
            ));
            let lb2 = lb.clone();

            tokio::spawn(async move {
                assert!(lb2.update().await.is_ok());
            });
            let mut backend_count = 0;
            while backend_count == 0 {
                let backends = lb.backends();
                backend_count = backends.backends.load_full().len();
            }
            assert_eq!(backend_count, expected);
            assert!(lb.select_with(b"test", 1, |_, _| true).is_some());
        }
    }
}