1#![allow(clippy::non_canonical_partial_ord_impl)]
22
23use arc_swap::ArcSwap;
24use derivative::Derivative;
25use futures::FutureExt;
26pub use http::Extensions;
27use pingora_core::protocols::l4::socket::SocketAddr;
28use pingora_error::{ErrorType, OrErr, Result};
29use std::collections::hash_map::DefaultHasher;
30use std::collections::{BTreeSet, HashMap};
31use std::hash::{Hash, Hasher};
32use std::io::Result as IoResult;
33use std::net::ToSocketAddrs;
34use std::sync::Arc;
35use std::time::Duration;
36
37mod background;
38pub mod discovery;
39pub mod health_check;
40pub mod selection;
41
42use discovery::ServiceDiscovery;
43use health_check::Health;
44use selection::UniqueIterator;
45use selection::{BackendIter, BackendSelection};
46
47pub mod prelude {
48 pub use crate::health_check::TcpHealthCheck;
49 pub use crate::selection::RoundRobin;
50 pub use crate::LoadBalancer;
51}
52
53#[derive(Derivative)]
55#[derivative(Clone, Hash, PartialEq, PartialOrd, Eq, Ord, Debug)]
56pub struct Backend {
57 pub addr: SocketAddr,
59 pub weight: usize,
62
63 #[derivative(PartialEq = "ignore")]
70 #[derivative(PartialOrd = "ignore")]
71 #[derivative(Hash = "ignore")]
72 #[derivative(Ord = "ignore")]
73 pub ext: Extensions,
74}
75
76impl Backend {
77 pub fn new(addr: &str) -> Result<Self> {
80 Self::new_with_weight(addr, 1)
81 }
82
83 pub fn new_with_weight(addr: &str, weight: usize) -> Result<Self> {
86 let addr = addr
87 .parse()
88 .or_err(ErrorType::InternalError, "invalid socket addr")?;
89 Ok(Backend {
90 addr: SocketAddr::Inet(addr),
91 weight,
92 ext: Extensions::new(),
93 })
94 }
96
97 pub(crate) fn hash_key(&self) -> u64 {
98 let mut hasher = DefaultHasher::new();
99 self.hash(&mut hasher);
100 hasher.finish()
101 }
102}
103
104impl std::ops::Deref for Backend {
105 type Target = SocketAddr;
106
107 fn deref(&self) -> &Self::Target {
108 &self.addr
109 }
110}
111
112impl std::ops::DerefMut for Backend {
113 fn deref_mut(&mut self) -> &mut Self::Target {
114 &mut self.addr
115 }
116}
117
118impl std::net::ToSocketAddrs for Backend {
119 type Iter = std::iter::Once<std::net::SocketAddr>;
120
121 fn to_socket_addrs(&self) -> std::io::Result<Self::Iter> {
122 self.addr.to_socket_addrs()
123 }
124}
125
126pub struct Backends {
132 discovery: Box<dyn ServiceDiscovery + Send + Sync + 'static>,
133 health_check: Option<Arc<dyn health_check::HealthCheck + Send + Sync + 'static>>,
134 backends: ArcSwap<BTreeSet<Backend>>,
135 health: ArcSwap<HashMap<u64, Health>>,
136}
137
138impl Backends {
139 pub fn new(discovery: Box<dyn ServiceDiscovery + Send + Sync + 'static>) -> Self {
143 Self {
144 discovery,
145 health_check: None,
146 backends: Default::default(),
147 health: Default::default(),
148 }
149 }
150
151 pub fn set_health_check(
153 &mut self,
154 hc: Box<dyn health_check::HealthCheck + Send + Sync + 'static>,
155 ) {
156 self.health_check = Some(hc.into())
157 }
158
159 fn do_update<F>(
163 &self,
164 new_backends: BTreeSet<Backend>,
165 enablement: HashMap<u64, bool>,
166 callback: F,
167 ) where
168 F: Fn(Arc<BTreeSet<Backend>>),
169 {
170 if (**self.backends.load()) != new_backends {
171 let old_health = self.health.load();
172 let mut health = HashMap::with_capacity(new_backends.len());
173 for backend in new_backends.iter() {
174 let hash_key = backend.hash_key();
175 let backend_health = old_health.get(&hash_key).cloned().unwrap_or_default();
177
178 if let Some(backend_enabled) = enablement.get(&hash_key) {
180 backend_health.enable(*backend_enabled);
181 }
182 health.insert(hash_key, backend_health);
183 }
184
185 let new_backends = Arc::new(new_backends);
190 callback(new_backends.clone());
191 self.backends.store(new_backends);
192 self.health.store(Arc::new(health));
193 } else {
194 for (hash_key, backend_enabled) in enablement.iter() {
196 if let Some(backend_health) = self.health.load().get(hash_key) {
199 backend_health.enable(*backend_enabled);
200 }
201 }
202 }
203 }
204
205 pub fn ready(&self, backend: &Backend) -> bool {
212 self.health
213 .load()
214 .get(&backend.hash_key())
215 .map_or(self.health_check.is_none(), |h| h.ready())
218 }
219
220 pub fn set_enable(&self, backend: &Backend, enabled: bool) {
227 if let Some(h) = self.health.load().get(&backend.hash_key()) {
229 h.enable(enabled)
230 };
231 }
232
233 pub fn get_backend(&self) -> Arc<BTreeSet<Backend>> {
235 self.backends.load_full()
236 }
237
238 pub async fn update<F>(&self, callback: F) -> Result<()>
243 where
244 F: Fn(Arc<BTreeSet<Backend>>),
245 {
246 let (new_backends, enablement) = self.discovery.discover().await?;
247 self.do_update(new_backends, enablement, callback);
248 Ok(())
249 }
250
251 pub async fn run_health_check(&self, parallel: bool) {
255 use crate::health_check::HealthCheck;
256 use log::{info, warn};
257 use pingora_runtime::current_handle;
258
259 async fn check_and_report(
260 backend: &Backend,
261 check: &Arc<dyn HealthCheck + Send + Sync>,
262 health_table: &HashMap<u64, Health>,
263 ) {
264 let errored = check.check(backend).await.err();
265 if let Some(h) = health_table.get(&backend.hash_key()) {
266 let flipped =
267 h.observe_health(errored.is_none(), check.health_threshold(errored.is_none()));
268 if flipped {
269 check.health_status_change(backend, errored.is_none()).await;
270 let summary = check.backend_summary(backend);
271 if let Some(e) = errored {
272 warn!("{summary} becomes unhealthy, {e}");
273 } else {
274 info!("{summary} becomes healthy");
275 }
276 }
277 }
278 }
279
280 let Some(health_check) = self.health_check.as_ref() else {
281 return;
282 };
283
284 let backends = self.backends.load();
285 if parallel {
286 let health_table = self.health.load_full();
287 let runtime = current_handle();
288 let jobs = backends.iter().map(|backend| {
289 let backend = backend.clone();
290 let check = health_check.clone();
291 let ht = health_table.clone();
292 runtime.spawn(async move {
293 check_and_report(&backend, &check, &ht).await;
294 })
295 });
296
297 futures::future::join_all(jobs).await;
298 } else {
299 for backend in backends.iter() {
300 check_and_report(backend, health_check, &self.health.load()).await;
301 }
302 }
303 }
304}
305
306pub struct LoadBalancer<S>
312where
313 S: BackendSelection,
314{
315 backends: Backends,
316 selector: ArcSwap<S>,
317
318 config: Option<S::Config>,
319
320 pub health_check_frequency: Option<Duration>,
324 pub update_frequency: Option<Duration>,
328 pub parallel_health_check: bool,
330}
331
332impl<S> LoadBalancer<S>
333where
334 S: BackendSelection + 'static,
335 S::Iter: BackendIter,
336{
337 pub fn try_from_iter<A, T: IntoIterator<Item = A>>(iter: T) -> IoResult<Self>
342 where
343 A: ToSocketAddrs,
344 {
345 let discovery = discovery::Static::try_from_iter(iter)?;
346 let backends = Backends::new(discovery);
347 let lb = Self::from_backends(backends);
348 lb.update()
349 .now_or_never()
350 .expect("static should not block")
351 .expect("static should not error");
352 Ok(lb)
353 }
354
355 pub fn from_backends_with_config(backends: Backends, config_opt: Option<S::Config>) -> Self {
357 let selector_raw = if let Some(config) = config_opt.as_ref() {
358 S::build_with_config(&backends.get_backend(), config)
359 } else {
360 S::build(&backends.get_backend())
361 };
362
363 let selector = ArcSwap::new(Arc::new(selector_raw));
364
365 LoadBalancer {
366 backends,
367 selector,
368 config: config_opt,
369 health_check_frequency: None,
370 update_frequency: None,
371 parallel_health_check: false,
372 }
373 }
374
375 pub fn from_backends(backends: Backends) -> Self {
377 Self::from_backends_with_config(backends, None)
378 }
379
380 pub async fn update(&self) -> Result<()> {
385 self.backends
386 .update(|backends| {
387 let selector = if let Some(config) = &self.config {
388 S::build_with_config(&backends, config)
389 } else {
390 S::build(&backends)
391 };
392
393 self.selector.store(Arc::new(selector))
394 })
395 .await
396 }
397
398 pub fn select(&self, key: &[u8], max_iterations: usize) -> Option<Backend> {
409 self.select_with(key, max_iterations, |_, health| health)
410 }
411
412 pub fn select_with<F>(&self, key: &[u8], max_iterations: usize, accept: F) -> Option<Backend>
420 where
421 F: Fn(&Backend, bool) -> bool,
422 {
423 let selection = self.selector.load();
424 let mut iter = UniqueIterator::new(selection.iter(key), max_iterations);
425 while let Some(b) = iter.get_next() {
426 if accept(&b, self.backends.ready(&b)) {
427 return Some(b);
428 }
429 }
430 None
431 }
432
433 pub fn set_health_check(
435 &mut self,
436 hc: Box<dyn health_check::HealthCheck + Send + Sync + 'static>,
437 ) {
438 self.backends.set_health_check(hc);
439 }
440
441 pub fn backends(&self) -> &Backends {
443 &self.backends
444 }
445}
446
#[cfg(test)]
mod test {
    use std::sync::atomic::{AtomicBool, Ordering::Relaxed};

    use super::*;
    use async_trait::async_trait;

    /// A balancer built from a static list contains exactly those backends.
    #[tokio::test]
    async fn test_static_backends() {
        let backends: LoadBalancer<selection::RoundRobin> =
            LoadBalancer::try_from_iter(["1.1.1.1:80", "1.0.0.1:80"]).unwrap();

        let backend1 = Backend::new("1.1.1.1:80").unwrap();
        let backend2 = Backend::new("1.0.0.1:80").unwrap();
        let backend = backends.backends().get_backend();
        assert!(backend.contains(&backend1));
        assert!(backend.contains(&backend2));
    }

    /// The update callback fires only when the set changes, and a sequential
    /// health check marks the unreachable backend as not ready.
    #[tokio::test]
    async fn test_backends() {
        let discovery = discovery::Static::default();
        let good1 = Backend::new("1.1.1.1:80").unwrap();
        discovery.add(good1.clone());
        let good2 = Backend::new("1.0.0.1:80").unwrap();
        discovery.add(good2.clone());
        let bad = Backend::new("127.0.0.1:79").unwrap();
        discovery.add(bad.clone());

        let mut backends = Backends::new(Box::new(discovery));
        let check = health_check::TcpHealthCheck::new();
        backends.set_health_check(check);

        // First update: the set goes from empty to three backends.
        let updated = AtomicBool::new(false);
        backends
            .update(|_| updated.store(true, Relaxed))
            .await
            .unwrap();
        assert!(updated.load(Relaxed));

        // Second update: nothing changed, so the callback must not run.
        let updated = AtomicBool::new(false);
        backends
            .update(|_| updated.store(true, Relaxed))
            .await
            .unwrap();
        assert!(!updated.load(Relaxed));

        backends.run_health_check(false).await;

        let backend = backends.get_backend();
        assert!(backend.contains(&good1));
        assert!(backend.contains(&good2));
        assert!(backend.contains(&bad));

        assert!(backends.ready(&good1));
        assert!(backends.ready(&good2));
        assert!(!backends.ready(&bad));
    }

    /// Extensions attached to a backend survive discovery and update.
    #[tokio::test]
    async fn test_backends_with_ext() {
        let discovery = discovery::Static::default();
        let mut b1 = Backend::new("1.1.1.1:80").unwrap();
        b1.ext.insert(true);
        let mut b2 = Backend::new("1.0.0.1:80").unwrap();
        b2.ext.insert(1u8);
        discovery.add(b1.clone());
        discovery.add(b2.clone());

        let backends = Backends::new(Box::new(discovery));

        backends.update(|_| {}).await.unwrap();

        let backend = backends.get_backend();
        assert!(backend.contains(&b1));
        assert!(backend.contains(&b2));

        // The set is ordered by address: 1.0.0.1 sorts before 1.1.1.1.
        let b2 = backend.first().unwrap();
        assert_eq!(b2.ext.get::<u8>(), Some(&1));

        let b1 = backend.last().unwrap();
        assert_eq!(b1.ext.get::<bool>(), Some(&true));
    }

    /// A backend that discovery reports as disabled is not ready, even
    /// without any health check configured.
    #[tokio::test]
    async fn test_discovery_readiness() {
        use discovery::Static;

        // Wraps the static discovery and additionally disables the `bad` backend.
        struct TestDiscovery(Static);
        #[async_trait]
        impl ServiceDiscovery for TestDiscovery {
            async fn discover(&self) -> Result<(BTreeSet<Backend>, HashMap<u64, bool>)> {
                let bad = Backend::new("127.0.0.1:79").unwrap();
                let (backends, mut readiness) = self.0.discover().await?;
                readiness.insert(bad.hash_key(), false);
                Ok((backends, readiness))
            }
        }
        let discovery = Static::default();
        let good1 = Backend::new("1.1.1.1:80").unwrap();
        discovery.add(good1.clone());
        let good2 = Backend::new("1.0.0.1:80").unwrap();
        discovery.add(good2.clone());
        let bad = Backend::new("127.0.0.1:79").unwrap();
        discovery.add(bad.clone());
        let discovery = TestDiscovery(discovery);

        let backends = Backends::new(Box::new(discovery));

        let updated = AtomicBool::new(false);
        backends
            .update(|_| updated.store(true, Relaxed))
            .await
            .unwrap();
        assert!(updated.load(Relaxed));

        let backend = backends.get_backend();
        assert!(backend.contains(&good1));
        assert!(backend.contains(&good2));
        assert!(backend.contains(&bad));

        assert!(backends.ready(&good1));
        assert!(backends.ready(&good2));
        assert!(!backends.ready(&bad));
    }

    /// The parallel health check reaches the same verdicts as the sequential one.
    #[tokio::test]
    async fn test_parallel_health_check() {
        let discovery = discovery::Static::default();
        let good1 = Backend::new("1.1.1.1:80").unwrap();
        discovery.add(good1.clone());
        let good2 = Backend::new("1.0.0.1:80").unwrap();
        discovery.add(good2.clone());
        let bad = Backend::new("127.0.0.1:79").unwrap();
        discovery.add(bad.clone());

        let mut backends = Backends::new(Box::new(discovery));
        let check = health_check::TcpHealthCheck::new();
        backends.set_health_check(check);

        let updated = AtomicBool::new(false);
        backends
            .update(|_| updated.store(true, Relaxed))
            .await
            .unwrap();
        assert!(updated.load(Relaxed));

        backends.run_health_check(true).await;

        assert!(backends.ready(&good1));
        assert!(backends.ready(&good2));
        assert!(!backends.ready(&bad));
    }

    mod thread_safety {
        use super::*;

        /// Discovery that reports `expected` enabled backends on 1.1.1.1,
        /// one per port starting at 0.
        struct MockDiscovery {
            expected: usize,
        }
        #[async_trait]
        impl ServiceDiscovery for MockDiscovery {
            async fn discover(&self) -> Result<(BTreeSet<Backend>, HashMap<u64, bool>)> {
                let mut d = BTreeSet::new();
                let mut m = HashMap::with_capacity(self.expected);
                for i in 0..self.expected {
                    let b = Backend::new(&format!("1.1.1.1:{i}")).unwrap();
                    // The enablement map is looked up by the backend's hash
                    // key, so key it that way rather than by loop index.
                    m.insert(b.hash_key(), true);
                    d.insert(b);
                }
                Ok((d, m))
            }
        }

        /// A reader that observes the new backend set must also be able to
        /// select from it, i.e. the selector is in place before the set is
        /// published.
        #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
        async fn test_consistency() {
            let expected = 3000;
            let discovery = MockDiscovery { expected };
            let lb = Arc::new(LoadBalancer::<selection::Consistent>::from_backends(
                Backends::new(Box::new(discovery)),
            ));
            let lb2 = lb.clone();

            let updater = tokio::spawn(async move {
                assert!(lb2.update().await.is_ok());
            });
            // Spin until the concurrently running update publishes the set.
            let mut backend_count = 0;
            while backend_count == 0 {
                let backends = lb.backends();
                backend_count = backends.backends.load_full().len();
            }
            assert_eq!(backend_count, expected);
            assert!(lb.select_with(b"test", 1, |_, _| true).is_some());

            // Join the task so a failed assertion inside it fails this test
            // instead of being dropped with the handle.
            updater.await.expect("update task should not panic");
        }
    }
}