1#![allow(clippy::non_canonical_partial_ord_impl)]
22
23use arc_swap::ArcSwap;
24use derivative::Derivative;
25use futures::FutureExt;
26pub use http::Extensions;
27use pingora_core::protocols::l4::socket::SocketAddr;
28use pingora_error::{ErrorType, OrErr, Result};
29use std::collections::hash_map::DefaultHasher;
30use std::collections::{BTreeSet, HashMap};
31use std::hash::{Hash, Hasher};
32use std::io::Result as IoResult;
33use std::net::ToSocketAddrs;
34use std::sync::Arc;
35use std::time::Duration;
36
37mod background;
38pub mod discovery;
39pub mod health_check;
40pub mod selection;
41
42use discovery::ServiceDiscovery;
43use health_check::Health;
44use selection::UniqueIterator;
45use selection::{BackendIter, BackendSelection};
46
/// Convenience re-exports of the types most callers need to build a load balancer:
/// the [`LoadBalancer`] itself, the [`RoundRobin`] selection algorithm and the
/// [`TcpHealthCheck`] health check.
pub mod prelude {
    pub use crate::health_check::TcpHealthCheck;
    pub use crate::selection::RoundRobin;
    pub use crate::LoadBalancer;
}
52
/// A single upstream server that the load balancer can select.
///
/// Equality, ordering and hashing consider only `addr` and `weight` (in that field order);
/// `ext` is ignored by all of them. Two backends that differ only in `ext` therefore compare
/// equal and produce the same [`Backend::hash_key`].
#[derive(Derivative)]
#[derivative(Clone, Hash, PartialEq, PartialOrd, Eq, Ord, Debug)]
pub struct Backend {
    /// The address of the backend server.
    pub addr: SocketAddr,
    /// The relative weight of this backend; [`Backend::new`] uses 1.
    /// NOTE(review): how the weight is interpreted is up to the selection algorithm in use.
    pub weight: usize,

    /// Arbitrary user data attached to the backend.
    ///
    /// Excluded from equality, ordering and hashing, so attaching or changing data here
    /// does not alter the backend's identity.
    #[derivative(PartialEq = "ignore")]
    #[derivative(PartialOrd = "ignore")]
    #[derivative(Hash = "ignore")]
    #[derivative(Ord = "ignore")]
    pub ext: Extensions,
}
75
76impl Backend {
77 pub fn new(addr: &str) -> Result<Self> {
80 Self::new_with_weight(addr, 1)
81 }
82
83 pub fn new_with_weight(addr: &str, weight: usize) -> Result<Self> {
86 let addr = addr
87 .parse()
88 .or_err(ErrorType::InternalError, "invalid socket addr")?;
89 Ok(Backend {
90 addr: SocketAddr::Inet(addr),
91 weight,
92 ext: Extensions::new(),
93 })
94 }
96
97 pub(crate) fn hash_key(&self) -> u64 {
98 let mut hasher = DefaultHasher::new();
99 self.hash(&mut hasher);
100 hasher.finish()
101 }
102}
103
104impl std::ops::Deref for Backend {
105 type Target = SocketAddr;
106
107 fn deref(&self) -> &Self::Target {
108 &self.addr
109 }
110}
111
112impl std::ops::DerefMut for Backend {
113 fn deref_mut(&mut self) -> &mut Self::Target {
114 &mut self.addr
115 }
116}
117
/// Lets a [`Backend`] be passed anywhere a `std::net::ToSocketAddrs` is accepted by
/// delegating to its `addr`.
impl std::net::ToSocketAddrs for Backend {
    type Iter = std::iter::Once<std::net::SocketAddr>;

    // Yields at most one address, that of `self.addr`.
    // NOTE(review): behavior for non-Inet (e.g. Unix domain) addresses is defined by
    // pingora's `SocketAddr` impl and is not visible here; presumably an error.
    fn to_socket_addrs(&self) -> std::io::Result<Self::Iter> {
        self.addr.to_socket_addrs()
    }
}
125
/// A set of [`Backend`]s together with their health state.
///
/// The set is populated by a [`ServiceDiscovery`] implementation and, optionally, kept
/// healthy by a health check.
pub struct Backends {
    // Source of the backend set, queried on every `update()`.
    discovery: Box<dyn ServiceDiscovery + Send + Sync + 'static>,
    // Optional health check; when `None`, backends absent from `health` are reported ready.
    health_check: Option<Arc<dyn health_check::HealthCheck + Send + Sync + 'static>>,
    // Current backend set, swapped atomically on change.
    backends: ArcSwap<BTreeSet<Backend>>,
    // Health state keyed by `Backend::hash_key()`, rebuilt whenever the set changes.
    health: ArcSwap<HashMap<u64, Health>>,
}
137
138impl Backends {
139 pub fn new(discovery: Box<dyn ServiceDiscovery + Send + Sync + 'static>) -> Self {
143 Self {
144 discovery,
145 health_check: None,
146 backends: Default::default(),
147 health: Default::default(),
148 }
149 }
150
151 pub fn set_health_check(
153 &mut self,
154 hc: Box<dyn health_check::HealthCheck + Send + Sync + 'static>,
155 ) {
156 self.health_check = Some(hc.into())
157 }
158
159 fn do_update<F>(
163 &self,
164 new_backends: BTreeSet<Backend>,
165 enablement: HashMap<u64, bool>,
166 callback: F,
167 ) where
168 F: Fn(Arc<BTreeSet<Backend>>),
169 {
170 if (**self.backends.load()) != new_backends {
171 let old_health = self.health.load();
172 let mut health = HashMap::with_capacity(new_backends.len());
173 for backend in new_backends.iter() {
174 let hash_key = backend.hash_key();
175 let backend_health = old_health.get(&hash_key).cloned().unwrap_or_default();
177
178 if let Some(backend_enabled) = enablement.get(&hash_key) {
180 backend_health.enable(*backend_enabled);
181 }
182 health.insert(hash_key, backend_health);
183 }
184
185 let new_backends = Arc::new(new_backends);
190 callback(new_backends.clone());
191 self.backends.store(new_backends);
192 self.health.store(Arc::new(health));
193 } else {
194 for (hash_key, backend_enabled) in enablement.iter() {
196 if let Some(backend_health) = self.health.load().get(hash_key) {
199 backend_health.enable(*backend_enabled);
200 }
201 }
202 }
203 }
204
205 pub fn ready(&self, backend: &Backend) -> bool {
212 self.health
213 .load()
214 .get(&backend.hash_key())
215 .map_or(self.health_check.is_none(), |h| h.ready())
218 }
219
220 pub fn set_enable(&self, backend: &Backend, enabled: bool) {
227 if let Some(h) = self.health.load().get(&backend.hash_key()) {
229 h.enable(enabled)
230 };
231 }
232
233 pub fn get_backend(&self) -> Arc<BTreeSet<Backend>> {
235 self.backends.load_full()
236 }
237
238 pub async fn update<F>(&self, callback: F) -> Result<()>
243 where
244 F: Fn(Arc<BTreeSet<Backend>>),
245 {
246 let (new_backends, enablement) = self.discovery.discover().await?;
247 self.do_update(new_backends, enablement, callback);
248 Ok(())
249 }
250
251 pub async fn run_health_check(&self, parallel: bool) {
255 use crate::health_check::HealthCheck;
256 use log::{info, warn};
257 use pingora_runtime::current_handle;
258
259 async fn check_and_report(
260 backend: &Backend,
261 check: &Arc<dyn HealthCheck + Send + Sync>,
262 health_table: &HashMap<u64, Health>,
263 ) {
264 let errored = check.check(backend).await.err();
265 if let Some(h) = health_table.get(&backend.hash_key()) {
266 let flipped =
267 h.observe_health(errored.is_none(), check.health_threshold(errored.is_none()));
268 if flipped {
269 check.health_status_change(backend, errored.is_none()).await;
270 let summary = check.backend_summary(backend);
271 if let Some(e) = errored {
272 warn!("{summary} becomes unhealthy, {e}");
273 } else {
274 info!("{summary} becomes healthy");
275 }
276 }
277 }
278 }
279
280 let Some(health_check) = self.health_check.as_ref() else {
281 return;
282 };
283
284 let backends = self.backends.load();
285 if parallel {
286 let health_table = self.health.load_full();
287 let runtime = current_handle();
288 let jobs = backends.iter().map(|backend| {
289 let backend = backend.clone();
290 let check = health_check.clone();
291 let ht = health_table.clone();
292 runtime.spawn(async move {
293 check_and_report(&backend, &check, &ht).await;
294 })
295 });
296
297 futures::future::join_all(jobs).await;
298 } else {
299 for backend in backends.iter() {
300 check_and_report(backend, health_check, &self.health.load()).await;
301 }
302 }
303 }
304}
305
/// A load balancer that picks a [`Backend`] out of a [`Backends`] set using the selection
/// algorithm `S`.
pub struct LoadBalancer<S> {
    // The discovered backends and their health state.
    backends: Backends,
    // Selector built from the current backend set; rebuilt and swapped atomically when
    // `update()` observes a changed set.
    selector: ArcSwap<S>,
    /// How often health checks should run; `from_backends` initializes it to `None`.
    /// NOTE(review): presumably consumed by the background service in `background`
    /// (not visible here) — confirm the meaning of `None` there.
    pub health_check_frequency: Option<Duration>,
    /// How often service discovery should be re-run; `from_backends` initializes it to `None`.
    /// NOTE(review): presumably consumed by the background service in `background` — confirm.
    pub update_frequency: Option<Duration>,
    /// Whether health checks run in parallel; passed as the `parallel` flag of
    /// `Backends::run_health_check` (assumed — the caller is not visible here).
    pub parallel_health_check: bool,
}
325
326impl<S: BackendSelection> LoadBalancer<S>
327where
328 S: BackendSelection + 'static,
329 S::Iter: BackendIter,
330{
331 pub fn try_from_iter<A, T: IntoIterator<Item = A>>(iter: T) -> IoResult<Self>
336 where
337 A: ToSocketAddrs,
338 {
339 let discovery = discovery::Static::try_from_iter(iter)?;
340 let backends = Backends::new(discovery);
341 let lb = Self::from_backends(backends);
342 lb.update()
343 .now_or_never()
344 .expect("static should not block")
345 .expect("static should not error");
346 Ok(lb)
347 }
348
349 pub fn from_backends(backends: Backends) -> Self {
351 let selector = ArcSwap::new(Arc::new(S::build(&backends.get_backend())));
352 LoadBalancer {
353 backends,
354 selector,
355 health_check_frequency: None,
356 update_frequency: None,
357 parallel_health_check: false,
358 }
359 }
360
361 pub async fn update(&self) -> Result<()> {
366 self.backends
367 .update(|backends| self.selector.store(Arc::new(S::build(&backends))))
368 .await
369 }
370
371 pub fn select(&self, key: &[u8], max_iterations: usize) -> Option<Backend> {
382 self.select_with(key, max_iterations, |_, health| health)
383 }
384
385 pub fn select_with<F>(&self, key: &[u8], max_iterations: usize, accept: F) -> Option<Backend>
393 where
394 F: Fn(&Backend, bool) -> bool,
395 {
396 let selection = self.selector.load();
397 let mut iter = UniqueIterator::new(selection.iter(key), max_iterations);
398 while let Some(b) = iter.get_next() {
399 if accept(&b, self.backends.ready(&b)) {
400 return Some(b);
401 }
402 }
403 None
404 }
405
406 pub fn set_health_check(
408 &mut self,
409 hc: Box<dyn health_check::HealthCheck + Send + Sync + 'static>,
410 ) {
411 self.backends.set_health_check(hc);
412 }
413
414 pub fn backends(&self) -> &Backends {
416 &self.backends
417 }
418}
419
#[cfg(test)]
mod test {
    use std::sync::atomic::{AtomicBool, Ordering::Relaxed};

    use super::*;
    use async_trait::async_trait;

    // `try_from_iter` populates the backend set immediately from static addresses.
    #[tokio::test]
    async fn test_static_backends() {
        let backends: LoadBalancer<selection::RoundRobin> =
            LoadBalancer::try_from_iter(["1.1.1.1:80", "1.0.0.1:80"]).unwrap();

        let backend1 = Backend::new("1.1.1.1:80").unwrap();
        let backend2 = Backend::new("1.0.0.1:80").unwrap();
        let backend = backends.backends().get_backend();
        assert!(backend.contains(&backend1));
        assert!(backend.contains(&backend2));
    }

    // Covers three behaviors:
    // - the update callback fires only when the backend set changes;
    // - a TCP health check marks unreachable backends as not ready;
    // - the health check keeps reachable backends ready.
    // NOTE(review): relies on outbound network access to 1.1.1.1:80 / 1.0.0.1:80 and
    // assumes nothing listens on 127.0.0.1:79.
    #[tokio::test]
    async fn test_backends() {
        let discovery = discovery::Static::default();
        let good1 = Backend::new("1.1.1.1:80").unwrap();
        discovery.add(good1.clone());
        let good2 = Backend::new("1.0.0.1:80").unwrap();
        discovery.add(good2.clone());
        let bad = Backend::new("127.0.0.1:79").unwrap();
        discovery.add(bad.clone());

        let mut backends = Backends::new(Box::new(discovery));
        let check = health_check::TcpHealthCheck::new();
        backends.set_health_check(check);

        // First update: the set goes from empty to three backends, so the callback fires.
        let updated = AtomicBool::new(false);
        backends
            .update(|_| updated.store(true, Relaxed))
            .await
            .unwrap();
        assert!(updated.load(Relaxed));

        // Second update: discovery returns the same set, so the callback must not fire.
        let updated = AtomicBool::new(false);
        backends
            .update(|_| updated.store(true, Relaxed))
            .await
            .unwrap();
        assert!(!updated.load(Relaxed));

        backends.run_health_check(false).await;

        let backend = backends.get_backend();
        assert!(backend.contains(&good1));
        assert!(backend.contains(&good2));
        assert!(backend.contains(&bad));

        assert!(backends.ready(&good1));
        assert!(backends.ready(&good2));
        assert!(!backends.ready(&bad));
    }
    // `ext` data survives discovery and update even though it is ignored for equality.
    #[tokio::test]
    async fn test_backends_with_ext() {
        let discovery = discovery::Static::default();
        let mut b1 = Backend::new("1.1.1.1:80").unwrap();
        b1.ext.insert(true);
        let mut b2 = Backend::new("1.0.0.1:80").unwrap();
        b2.ext.insert(1u8);
        discovery.add(b1.clone());
        discovery.add(b2.clone());

        let backends = Backends::new(Box::new(discovery));

        backends.update(|_| {}).await.unwrap();

        let backend = backends.get_backend();
        assert!(backend.contains(&b1));
        assert!(backend.contains(&b2));

        // The set is ordered by address, so 1.0.0.1 is expected first and 1.1.1.1 last.
        let b2 = backend.first().unwrap();
        assert_eq!(b2.ext.get::<u8>(), Some(&1));

        let b1 = backend.last().unwrap();
        assert_eq!(b1.ext.get::<bool>(), Some(&true));
    }

    // The enablement map returned by discovery can disable a backend without any health check.
    #[tokio::test]
    async fn test_discovery_readiness() {
        use discovery::Static;

        // Wraps a static discovery and additionally reports 127.0.0.1:79 as disabled.
        struct TestDiscovery(Static);
        #[async_trait]
        impl ServiceDiscovery for TestDiscovery {
            async fn discover(&self) -> Result<(BTreeSet<Backend>, HashMap<u64, bool>)> {
                let bad = Backend::new("127.0.0.1:79").unwrap();
                let (backends, mut readiness) = self.0.discover().await?;
                readiness.insert(bad.hash_key(), false);
                Ok((backends, readiness))
            }
        }
        let discovery = Static::default();
        let good1 = Backend::new("1.1.1.1:80").unwrap();
        discovery.add(good1.clone());
        let good2 = Backend::new("1.0.0.1:80").unwrap();
        discovery.add(good2.clone());
        let bad = Backend::new("127.0.0.1:79").unwrap();
        discovery.add(bad.clone());
        let discovery = TestDiscovery(discovery);

        let backends = Backends::new(Box::new(discovery));

        let updated = AtomicBool::new(false);
        backends
            .update(|_| updated.store(true, Relaxed))
            .await
            .unwrap();
        assert!(updated.load(Relaxed));

        let backend = backends.get_backend();
        assert!(backend.contains(&good1));
        assert!(backend.contains(&good2));
        assert!(backend.contains(&bad));

        assert!(backends.ready(&good1));
        assert!(backends.ready(&good2));
        assert!(!backends.ready(&bad));
    }

    // Same readiness expectations as `test_backends`, but with checks spawned in parallel.
    // NOTE(review): also relies on outbound network access.
    #[tokio::test]
    async fn test_parallel_health_check() {
        let discovery = discovery::Static::default();
        let good1 = Backend::new("1.1.1.1:80").unwrap();
        discovery.add(good1.clone());
        let good2 = Backend::new("1.0.0.1:80").unwrap();
        discovery.add(good2.clone());
        let bad = Backend::new("127.0.0.1:79").unwrap();
        discovery.add(bad.clone());

        let mut backends = Backends::new(Box::new(discovery));
        let check = health_check::TcpHealthCheck::new();
        backends.set_health_check(check);

        let updated = AtomicBool::new(false);
        backends
            .update(|_| updated.store(true, Relaxed))
            .await
            .unwrap();
        assert!(updated.load(Relaxed));

        backends.run_health_check(true).await;

        assert!(backends.ready(&good1));
        assert!(backends.ready(&good2));
        assert!(!backends.ready(&bad));
    }

    mod thread_safety {
        use super::*;

        // Discovery producing `expected` backends on 1.1.1.1 with ports 0..expected.
        struct MockDiscovery {
            expected: usize,
        }
        #[async_trait]
        impl ServiceDiscovery for MockDiscovery {
            async fn discover(&self) -> Result<(BTreeSet<Backend>, HashMap<u64, bool>)> {
                let mut d = BTreeSet::new();
                let mut m = HashMap::with_capacity(self.expected);
                for i in 0..self.expected {
                    let b = Backend::new(&format!("1.1.1.1:{i}")).unwrap();
                    // NOTE(review): keys are plain indices, not `Backend::hash_key()` values,
                    // so these enablement entries are unlikely to match any backend.
                    m.insert(i as u64, true);
                    d.insert(b);
                }
                Ok((d, m))
            }
        }

        // Once the new backend set is visible, the selector must already reflect it,
        // because `do_update` runs the selector-building callback before storing the set.
        #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
        async fn test_consistency() {
            let expected = 3000;
            let discovery = MockDiscovery { expected };
            let lb = Arc::new(LoadBalancer::<selection::Consistent>::from_backends(
                Backends::new(Box::new(discovery)),
            ));
            let lb2 = lb.clone();

            tokio::spawn(async move {
                assert!(lb2.update().await.is_ok());
            });
            // Busy-wait until the spawned update publishes the backend set.
            let mut backend_count = 0;
            while backend_count == 0 {
                let backends = lb.backends();
                backend_count = backends.backends.load_full().len();
            }
            assert_eq!(backend_count, expected);
            assert!(lb.select_with(b"test", 1, |_, _| true).is_some());
        }
    }
}