1#![allow(clippy::non_canonical_partial_ord_impl)]
22
23use arc_swap::ArcSwap;
24use derivative::Derivative;
25use futures::FutureExt;
26pub use http::Extensions;
27use pingora_core::protocols::l4::socket::SocketAddr;
28use pingora_error::{ErrorType, OrErr, Result};
29use std::collections::hash_map::DefaultHasher;
30use std::collections::{BTreeSet, HashMap};
31use std::hash::{Hash, Hasher};
32use std::io::Result as IoResult;
33use std::net::ToSocketAddrs;
34use std::sync::Arc;
35use std::time::Duration;
36
37mod background;
38pub mod discovery;
39pub mod health_check;
40pub mod selection;
41
42use discovery::ServiceDiscovery;
43use health_check::Health;
44use selection::UniqueIterator;
45use selection::{BackendIter, BackendSelection};
46
47pub mod prelude {
48 pub use crate::health_check::TcpHealthCheck;
49 pub use crate::selection::RoundRobin;
50 pub use crate::LoadBalancer;
51}
52
53#[derive(Derivative)]
55#[derivative(Clone, Hash, PartialEq, PartialOrd, Eq, Ord, Debug)]
56pub struct Backend {
57 pub addr: SocketAddr,
59 pub weight: usize,
62
63 #[derivative(PartialEq = "ignore")]
70 #[derivative(PartialOrd = "ignore")]
71 #[derivative(Hash = "ignore")]
72 #[derivative(Ord = "ignore")]
73 pub ext: Extensions,
74}
75
76impl Backend {
77 pub fn new(addr: &str) -> Result<Self> {
80 Self::new_with_weight(addr, 1)
81 }
82
83 pub fn new_with_weight(addr: &str, weight: usize) -> Result<Self> {
86 let addr = addr
87 .parse()
88 .or_err(ErrorType::InternalError, "invalid socket addr")?;
89 Ok(Backend {
90 addr: SocketAddr::Inet(addr),
91 weight,
92 ext: Extensions::new(),
93 })
94 }
96
97 pub(crate) fn hash_key(&self) -> u64 {
98 let mut hasher = DefaultHasher::new();
99 self.hash(&mut hasher);
100 hasher.finish()
101 }
102}
103
104impl std::ops::Deref for Backend {
105 type Target = SocketAddr;
106
107 fn deref(&self) -> &Self::Target {
108 &self.addr
109 }
110}
111
112impl std::ops::DerefMut for Backend {
113 fn deref_mut(&mut self) -> &mut Self::Target {
114 &mut self.addr
115 }
116}
117
118impl std::net::ToSocketAddrs for Backend {
119 type Iter = std::iter::Once<std::net::SocketAddr>;
120
121 fn to_socket_addrs(&self) -> std::io::Result<Self::Iter> {
122 self.addr.to_socket_addrs()
123 }
124}
125
126pub struct Backends {
132 discovery: Box<dyn ServiceDiscovery + Send + Sync + 'static>,
133 health_check: Option<Arc<dyn health_check::HealthCheck + Send + Sync + 'static>>,
134 backends: ArcSwap<BTreeSet<Backend>>,
135 health: ArcSwap<HashMap<u64, Health>>,
136}
137
138impl Backends {
139 pub fn new(discovery: Box<dyn ServiceDiscovery + Send + Sync + 'static>) -> Self {
143 Self {
144 discovery,
145 health_check: None,
146 backends: Default::default(),
147 health: Default::default(),
148 }
149 }
150
151 pub fn set_health_check(
153 &mut self,
154 hc: Box<dyn health_check::HealthCheck + Send + Sync + 'static>,
155 ) {
156 self.health_check = Some(hc.into())
157 }
158
159 fn do_update<F>(
163 &self,
164 new_backends: BTreeSet<Backend>,
165 enablement: HashMap<u64, bool>,
166 callback: F,
167 ) where
168 F: Fn(Arc<BTreeSet<Backend>>),
169 {
170 if (**self.backends.load()) != new_backends {
171 let old_health = self.health.load();
172 let mut health = HashMap::with_capacity(new_backends.len());
173 for backend in new_backends.iter() {
174 let hash_key = backend.hash_key();
175 let backend_health = old_health.get(&hash_key).cloned().unwrap_or_default();
177
178 if let Some(backend_enabled) = enablement.get(&hash_key) {
180 backend_health.enable(*backend_enabled);
181 }
182 health.insert(hash_key, backend_health);
183 }
184
185 let new_backends = Arc::new(new_backends);
190 callback(new_backends.clone());
191 self.backends.store(new_backends);
192 self.health.store(Arc::new(health));
193 } else {
194 for (hash_key, backend_enabled) in enablement.iter() {
196 if let Some(backend_health) = self.health.load().get(hash_key) {
199 backend_health.enable(*backend_enabled);
200 }
201 }
202 }
203 }
204
205 pub fn ready(&self, backend: &Backend) -> bool {
212 self.health
213 .load()
214 .get(&backend.hash_key())
215 .map_or(self.health_check.is_none(), |h| h.ready())
218 }
219
220 pub fn set_enable(&self, backend: &Backend, enabled: bool) {
227 if let Some(h) = self.health.load().get(&backend.hash_key()) {
229 h.enable(enabled)
230 };
231 }
232
233 pub fn get_backend(&self) -> Arc<BTreeSet<Backend>> {
235 self.backends.load_full()
236 }
237
238 pub async fn update<F>(&self, callback: F) -> Result<()>
243 where
244 F: Fn(Arc<BTreeSet<Backend>>),
245 {
246 let (new_backends, enablement) = self.discovery.discover().await?;
247 self.do_update(new_backends, enablement, callback);
248 Ok(())
249 }
250
251 pub async fn run_health_check(&self, parallel: bool) {
255 use crate::health_check::HealthCheck;
256 use log::{info, warn};
257 use pingora_runtime::current_handle;
258
259 async fn check_and_report(
260 backend: &Backend,
261 check: &Arc<dyn HealthCheck + Send + Sync>,
262 health_table: &HashMap<u64, Health>,
263 ) {
264 let errored = check.check(backend).await.err();
265 if let Some(h) = health_table.get(&backend.hash_key()) {
266 let flipped =
267 h.observe_health(errored.is_none(), check.health_threshold(errored.is_none()));
268 if flipped {
269 check.health_status_change(backend, errored.is_none()).await;
270 if let Some(e) = errored {
271 warn!("{backend:?} becomes unhealthy, {e}");
272 } else {
273 info!("{backend:?} becomes healthy");
274 }
275 }
276 }
277 }
278
279 let Some(health_check) = self.health_check.as_ref() else {
280 return;
281 };
282
283 let backends = self.backends.load();
284 if parallel {
285 let health_table = self.health.load_full();
286 let runtime = current_handle();
287 let jobs = backends.iter().map(|backend| {
288 let backend = backend.clone();
289 let check = health_check.clone();
290 let ht = health_table.clone();
291 runtime.spawn(async move {
292 check_and_report(&backend, &check, &ht).await;
293 })
294 });
295
296 futures::future::join_all(jobs).await;
297 } else {
298 for backend in backends.iter() {
299 check_and_report(backend, health_check, &self.health.load()).await;
300 }
301 }
302 }
303}
304
305pub struct LoadBalancer<S> {
311 backends: Backends,
312 selector: ArcSwap<S>,
313 pub health_check_frequency: Option<Duration>,
317 pub update_frequency: Option<Duration>,
321 pub parallel_health_check: bool,
323}
324
325impl<S: BackendSelection> LoadBalancer<S>
326where
327 S: BackendSelection + 'static,
328 S::Iter: BackendIter,
329{
330 pub fn try_from_iter<A, T: IntoIterator<Item = A>>(iter: T) -> IoResult<Self>
335 where
336 A: ToSocketAddrs,
337 {
338 let discovery = discovery::Static::try_from_iter(iter)?;
339 let backends = Backends::new(discovery);
340 let lb = Self::from_backends(backends);
341 lb.update()
342 .now_or_never()
343 .expect("static should not block")
344 .expect("static should not error");
345 Ok(lb)
346 }
347
348 pub fn from_backends(backends: Backends) -> Self {
350 let selector = ArcSwap::new(Arc::new(S::build(&backends.get_backend())));
351 LoadBalancer {
352 backends,
353 selector,
354 health_check_frequency: None,
355 update_frequency: None,
356 parallel_health_check: false,
357 }
358 }
359
360 pub async fn update(&self) -> Result<()> {
365 self.backends
366 .update(|backends| self.selector.store(Arc::new(S::build(&backends))))
367 .await
368 }
369
370 pub fn select(&self, key: &[u8], max_iterations: usize) -> Option<Backend> {
381 self.select_with(key, max_iterations, |_, health| health)
382 }
383
384 pub fn select_with<F>(&self, key: &[u8], max_iterations: usize, accept: F) -> Option<Backend>
392 where
393 F: Fn(&Backend, bool) -> bool,
394 {
395 let selection = self.selector.load();
396 let mut iter = UniqueIterator::new(selection.iter(key), max_iterations);
397 while let Some(b) = iter.get_next() {
398 if accept(&b, self.backends.ready(&b)) {
399 return Some(b);
400 }
401 }
402 None
403 }
404
405 pub fn set_health_check(
407 &mut self,
408 hc: Box<dyn health_check::HealthCheck + Send + Sync + 'static>,
409 ) {
410 self.backends.set_health_check(hc);
411 }
412
413 pub fn backends(&self) -> &Backends {
415 &self.backends
416 }
417}
418
#[cfg(test)]
mod test {
    use std::sync::atomic::{AtomicBool, Ordering::Relaxed};

    use super::*;
    use async_trait::async_trait;

    /// A balancer built from literal addresses contains a backend for each address.
    #[tokio::test]
    async fn test_static_backends() {
        let backends: LoadBalancer<selection::RoundRobin> =
            LoadBalancer::try_from_iter(["1.1.1.1:80", "1.0.0.1:80"]).unwrap();

        let backend1 = Backend::new("1.1.1.1:80").unwrap();
        let backend2 = Backend::new("1.0.0.1:80").unwrap();
        let backend = backends.backends().get_backend();
        assert!(backend.contains(&backend1));
        assert!(backend.contains(&backend2));
    }

    /// The update callback fires only when membership changes, and a sequential
    /// health check marks the unreachable backend as not ready.
    #[tokio::test]
    async fn test_backends() {
        let discovery = discovery::Static::default();
        let good1 = Backend::new("1.1.1.1:80").unwrap();
        discovery.add(good1.clone());
        let good2 = Backend::new("1.0.0.1:80").unwrap();
        discovery.add(good2.clone());
        // The test assumes nothing listens on this local port.
        let bad = Backend::new("127.0.0.1:79").unwrap();
        discovery.add(bad.clone());

        let mut backends = Backends::new(Box::new(discovery));
        let check = health_check::TcpHealthCheck::new();
        backends.set_health_check(check);

        // First update: the set goes from empty to three backends, so the callback runs.
        let updated = AtomicBool::new(false);
        backends
            .update(|_| updated.store(true, Relaxed))
            .await
            .unwrap();
        assert!(updated.load(Relaxed));

        // Second update: same membership, so the callback must not run.
        let updated = AtomicBool::new(false);
        backends
            .update(|_| updated.store(true, Relaxed))
            .await
            .unwrap();
        assert!(!updated.load(Relaxed));

        backends.run_health_check(false).await;

        let backend = backends.get_backend();
        assert!(backend.contains(&good1));
        assert!(backend.contains(&good2));
        assert!(backend.contains(&bad));

        assert!(backends.ready(&good1));
        assert!(backends.ready(&good2));
        assert!(!backends.ready(&bad));
    }

    /// Extension data survives discovery even though it takes no part in
    /// backend equality or ordering.
    #[tokio::test]
    async fn test_backends_with_ext() {
        let discovery = discovery::Static::default();
        let mut b1 = Backend::new("1.1.1.1:80").unwrap();
        b1.ext.insert(true);
        let mut b2 = Backend::new("1.0.0.1:80").unwrap();
        b2.ext.insert(1u8);
        discovery.add(b1.clone());
        discovery.add(b2.clone());

        let backends = Backends::new(Box::new(discovery));

        backends.update(|_| {}).await.unwrap();

        let backend = backends.get_backend();
        assert!(backend.contains(&b1));
        assert!(backend.contains(&b2));

        // The set is ordered, and 1.0.0.1 sorts before 1.1.1.1.
        let b2 = backend.first().unwrap();
        assert_eq!(b2.ext.get::<u8>(), Some(&1));

        let b1 = backend.last().unwrap();
        assert_eq!(b1.ext.get::<bool>(), Some(&true));
    }

    /// An enablement override reported by discovery makes a backend not ready,
    /// without any health check being configured.
    #[tokio::test]
    async fn test_discovery_readiness() {
        use discovery::Static;

        /// Wraps [Static] and additionally reports the "bad" backend as disabled.
        struct TestDiscovery(Static);
        #[async_trait]
        impl ServiceDiscovery for TestDiscovery {
            async fn discover(&self) -> Result<(BTreeSet<Backend>, HashMap<u64, bool>)> {
                let bad = Backend::new("127.0.0.1:79").unwrap();
                let (backends, mut readiness) = self.0.discover().await?;
                readiness.insert(bad.hash_key(), false);
                Ok((backends, readiness))
            }
        }
        let discovery = Static::default();
        let good1 = Backend::new("1.1.1.1:80").unwrap();
        discovery.add(good1.clone());
        let good2 = Backend::new("1.0.0.1:80").unwrap();
        discovery.add(good2.clone());
        let bad = Backend::new("127.0.0.1:79").unwrap();
        discovery.add(bad.clone());
        let discovery = TestDiscovery(discovery);

        let backends = Backends::new(Box::new(discovery));

        let updated = AtomicBool::new(false);
        backends
            .update(|_| updated.store(true, Relaxed))
            .await
            .unwrap();
        assert!(updated.load(Relaxed));

        let backend = backends.get_backend();
        assert!(backend.contains(&good1));
        assert!(backend.contains(&good2));
        assert!(backend.contains(&bad));

        assert!(backends.ready(&good1));
        assert!(backends.ready(&good2));
        assert!(!backends.ready(&bad));
    }

    /// The parallel health check reaches the same verdicts as the sequential one.
    #[tokio::test]
    async fn test_parallel_health_check() {
        let discovery = discovery::Static::default();
        let good1 = Backend::new("1.1.1.1:80").unwrap();
        discovery.add(good1.clone());
        let good2 = Backend::new("1.0.0.1:80").unwrap();
        discovery.add(good2.clone());
        // The test assumes nothing listens on this local port.
        let bad = Backend::new("127.0.0.1:79").unwrap();
        discovery.add(bad.clone());

        let mut backends = Backends::new(Box::new(discovery));
        let check = health_check::TcpHealthCheck::new();
        backends.set_health_check(check);

        let updated = AtomicBool::new(false);
        backends
            .update(|_| updated.store(true, Relaxed))
            .await
            .unwrap();
        assert!(updated.load(Relaxed));

        backends.run_health_check(true).await;

        assert!(backends.ready(&good1));
        assert!(backends.ready(&good2));
        assert!(!backends.ready(&bad));
    }

    mod thread_safety {
        use super::*;

        /// Discovery that reports `expected` distinct backends, all enabled.
        struct MockDiscovery {
            expected: usize,
        }
        #[async_trait]
        impl ServiceDiscovery for MockDiscovery {
            async fn discover(&self) -> Result<(BTreeSet<Backend>, HashMap<u64, bool>)> {
                let mut d = BTreeSet::new();
                let mut m = HashMap::with_capacity(self.expected);
                for i in 0..self.expected {
                    let b = Backend::new(&format!("1.1.1.1:{i}")).unwrap();
                    // The enablement map is looked up by `Backend::hash_key()`, so key
                    // it that way; a loop index would never match any backend.
                    m.insert(b.hash_key(), true);
                    d.insert(b);
                }
                Ok((d, m))
            }
        }

        /// Once the new backend set is observable, the selector must already be
        /// able to pick from it.
        #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
        async fn test_consistency() {
            let expected = 3000;
            let discovery = MockDiscovery { expected };
            let lb = Arc::new(LoadBalancer::<selection::Consistent>::from_backends(
                Backends::new(Box::new(discovery)),
            ));
            let lb2 = lb.clone();

            tokio::spawn(async move {
                assert!(lb2.update().await.is_ok());
            });
            // Spin until the concurrent update publishes the backend set.
            let mut backend_count = 0;
            while backend_count == 0 {
                let backends = lb.backends();
                backend_count = backends.backends.load_full().len();
            }
            assert_eq!(backend_count, expected);
            assert!(lb.select_with(b"test", 1, |_, _| true).is_some());
        }
    }
}