pingora_load_balancing/
lib.rs

// Copyright 2025 Cloudflare, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//! # Pingora Load Balancing utilities
//! This crate provides common service discovery, health check and load balancing
//! algorithms for proxies to use.
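//!
//! A minimal usage sketch (illustrative; the addresses are placeholders):
//!
//! ```no_run
//! use pingora_load_balancing::prelude::*;
//!
//! // build a round-robin load balancer over two static backends
//! let lb: LoadBalancer<RoundRobin> =
//!     LoadBalancer::try_from_iter(["1.1.1.1:80", "1.0.0.1:80"]).unwrap();
//!
//! // round robin ignores the key; 256 bounds the selection iterations
//! let backend = lb.select(b"", 256);
//! ```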

// https://github.com/mcarton/rust-derivative/issues/112
// False positive for macro generated code
#![allow(clippy::non_canonical_partial_ord_impl)]

use arc_swap::ArcSwap;
use derivative::Derivative;
use futures::FutureExt;
pub use http::Extensions;
use pingora_core::protocols::l4::socket::SocketAddr;
use pingora_error::{ErrorType, OrErr, Result};
use std::collections::hash_map::DefaultHasher;
use std::collections::{BTreeSet, HashMap};
use std::hash::{Hash, Hasher};
use std::io::Result as IoResult;
use std::net::ToSocketAddrs;
use std::sync::Arc;
use std::time::Duration;

mod background;
pub mod discovery;
pub mod health_check;
pub mod selection;

use discovery::ServiceDiscovery;
use health_check::Health;
use selection::UniqueIterator;
use selection::{BackendIter, BackendSelection};

pub mod prelude {
    pub use crate::health_check::TcpHealthCheck;
    pub use crate::selection::RoundRobin;
    pub use crate::LoadBalancer;
}

/// [Backend] represents a server to proxy or connect to.
#[derive(Derivative)]
#[derivative(Clone, Hash, PartialEq, PartialOrd, Eq, Ord, Debug)]
pub struct Backend {
    /// The address to the backend server.
    pub addr: SocketAddr,
    /// The relative weight of the server. Load balancing algorithms will
    /// proportionally distribute traffic according to this value.
    pub weight: usize,

    /// The extension field for arbitrary data used to annotate the [Backend].
    /// The data added here is opaque to this crate, hence it is ignored by this
    /// crate's functionality. For example, two backends with the same
    /// [SocketAddr] and the same weight but different `ext` data are considered
    /// identical.
    /// See [Extensions] for how to add and read the data.
    #[derivative(PartialEq = "ignore")]
    #[derivative(PartialOrd = "ignore")]
    #[derivative(Hash = "ignore")]
    #[derivative(Ord = "ignore")]
    pub ext: Extensions,
}

impl Backend {
    /// Create a new [Backend] with `weight` 1. The function will try to parse
    /// `addr` into a [std::net::SocketAddr].
    pub fn new(addr: &str) -> Result<Self> {
        Self::new_with_weight(addr, 1)
    }

    /// Creates a new [Backend] with the specified `weight`. The function will try to parse
    /// `addr` into a [std::net::SocketAddr].
    pub fn new_with_weight(addr: &str, weight: usize) -> Result<Self> {
        let addr = addr
            .parse()
            .or_err(ErrorType::InternalError, "invalid socket addr")?;
        Ok(Backend {
            addr: SocketAddr::Inet(addr),
            weight,
            ext: Extensions::new(),
        })
        // TODO: UDS
    }

    pub(crate) fn hash_key(&self) -> u64 {
        let mut hasher = DefaultHasher::new();
        self.hash(&mut hasher);
        hasher.finish()
    }
}

impl std::ops::Deref for Backend {
    type Target = SocketAddr;

    fn deref(&self) -> &Self::Target {
        &self.addr
    }
}

impl std::ops::DerefMut for Backend {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.addr
    }
}

impl std::net::ToSocketAddrs for Backend {
    type Iter = std::iter::Once<std::net::SocketAddr>;

    fn to_socket_addrs(&self) -> std::io::Result<Self::Iter> {
        self.addr.to_socket_addrs()
    }
}

/// [Backends] is a collection of [Backend]s.
///
/// It includes a service discovery method (static or dynamic) to discover all
/// the available backends as well as an optional health check method to probe the liveness
/// of each backend.
pub struct Backends {
    discovery: Box<dyn ServiceDiscovery + Send + Sync + 'static>,
    health_check: Option<Arc<dyn health_check::HealthCheck + Send + Sync + 'static>>,
    backends: ArcSwap<BTreeSet<Backend>>,
    health: ArcSwap<HashMap<u64, Health>>,
}

impl Backends {
    /// Create a new [Backends] with the given [ServiceDiscovery] implementation.
    ///
    /// The health check method is by default empty.
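    ///
    /// A sketch of wiring up static discovery with a TCP health check, mirroring
    /// this crate's tests (the address is a placeholder):
    ///
    /// ```ignore
    /// use pingora_load_balancing::{discovery, health_check, Backend, Backends};
    ///
    /// let discovery = discovery::Static::default();
    /// discovery.add(Backend::new("1.1.1.1:80").unwrap());
    ///
    /// let mut backends = Backends::new(Box::new(discovery));
    /// backends.set_health_check(health_check::TcpHealthCheck::new());
    /// ```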
    pub fn new(discovery: Box<dyn ServiceDiscovery + Send + Sync + 'static>) -> Self {
        Self {
            discovery,
            health_check: None,
            backends: Default::default(),
            health: Default::default(),
        }
    }

    /// Set the health check method. See [health_check] for the methods provided.
    pub fn set_health_check(
        &mut self,
        hc: Box<dyn health_check::HealthCheck + Send + Sync + 'static>,
    ) {
        self.health_check = Some(hc.into())
    }

    /// Update the backends to the new set when it differs from the current one.
    ///
    /// The callback is invoked when the new set of backends differs from the
    /// current one so that the caller can update the selector accordingly.
    fn do_update<F>(
        &self,
        new_backends: BTreeSet<Backend>,
        enablement: HashMap<u64, bool>,
        callback: F,
    ) where
        F: Fn(Arc<BTreeSet<Backend>>),
    {
        if (**self.backends.load()) != new_backends {
            let old_health = self.health.load();
            let mut health = HashMap::with_capacity(new_backends.len());
            for backend in new_backends.iter() {
                let hash_key = backend.hash_key();
                // use the default health if the backend is new
                let backend_health = old_health.get(&hash_key).cloned().unwrap_or_default();

                // override enablement
                if let Some(backend_enabled) = enablement.get(&hash_key) {
                    backend_health.enable(*backend_enabled);
                }
                health.insert(hash_key, backend_health);
            }

            // TODO: put this all under 1 ArcSwap so the update is atomic
            // It's important that `callback()` executes first since computing selector backends
            // might be expensive. For example, if a caller checks `backends` to see if any are
            // available, they may encounter false positives if the selector isn't ready yet.
            let new_backends = Arc::new(new_backends);
            callback(new_backends.clone());
            self.backends.store(new_backends);
            self.health.store(Arc::new(health));
        } else {
            // no backend change, just check enablement
            for (hash_key, backend_enabled) in enablement.iter() {
                // override enablement if set
                // this get should always be Some(_) because we already populate `health` for all known backends
                if let Some(backend_health) = self.health.load().get(hash_key) {
                    backend_health.enable(*backend_enabled);
                }
            }
        }
    }

    /// Whether a certain [Backend] is ready to serve traffic.
    ///
    /// This function returns true when the backend is both healthy and enabled.
    /// It also returns true when no health check is set, as long as the backend
    /// is enabled. When a health check is set, this function returns false for
    /// any `backend` it doesn't know about.
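    ///
    /// For example (a sketch; assumes `backends` has been updated and `backend`
    /// came from [Self::get_backend]):
    ///
    /// ```ignore
    /// if backends.ready(&backend) {
    ///     // safe to send traffic to `backend`
    /// }
    /// ```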
    pub fn ready(&self, backend: &Backend) -> bool {
        self.health
            .load()
            .get(&backend.hash_key())
            // Racing: return `None` when this function is called between the
            // backend store and the health store
            .map_or(self.health_check.is_none(), |h| h.ready())
    }

    /// Manually set whether a [Backend] is ready to serve traffic.
    ///
    /// This method does not override the health of the backend. It is meant to be
    /// used to stop a backend from accepting traffic while it is still healthy.
    ///
    /// This method is a no-op when the given backend doesn't exist in the service discovery.
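    ///
    /// A brief sketch (assumes `backend` was discovered by `backends`):
    ///
    /// ```ignore
    /// // drain traffic from the backend without marking it unhealthy
    /// backends.set_enable(&backend, false);
    /// // later, let it serve traffic again
    /// backends.set_enable(&backend, true);
    /// ```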
    pub fn set_enable(&self, backend: &Backend, enabled: bool) {
        // this should always be Some(_) because health is always populated during update
        if let Some(h) = self.health.load().get(&backend.hash_key()) {
            h.enable(enabled)
        };
    }

    /// Return the collection of the backends.
    pub fn get_backend(&self) -> Arc<BTreeSet<Backend>> {
        self.backends.load_full()
    }

    /// Call the service discovery method to update the collection of backends.
    ///
    /// The callback will be invoked when the new set of backends differs from the
    /// current one so that the caller can update the selector accordingly.
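    ///
    /// For example (mirroring this crate's tests; the empty callback ignores changes):
    ///
    /// ```ignore
    /// backends.update(|_| {}).await?;
    /// ```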
    pub async fn update<F>(&self, callback: F) -> Result<()>
    where
        F: Fn(Arc<BTreeSet<Backend>>),
    {
        let (new_backends, enablement) = self.discovery.discover().await?;
        self.do_update(new_backends, enablement, callback);
        Ok(())
    }

    /// Run the health check on all backends, if one is set.
    ///
    /// When `parallel: true`, all backends are checked in parallel instead of sequentially.
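    ///
    /// For example (assumes `backends` has a health check set):
    ///
    /// ```ignore
    /// backends.run_health_check(false).await; // sequential
    /// backends.run_health_check(true).await; // parallel
    /// ```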
    pub async fn run_health_check(&self, parallel: bool) {
        use crate::health_check::HealthCheck;
        use log::{info, warn};
        use pingora_runtime::current_handle;

        async fn check_and_report(
            backend: &Backend,
            check: &Arc<dyn HealthCheck + Send + Sync>,
            health_table: &HashMap<u64, Health>,
        ) {
            let errored = check.check(backend).await.err();
            if let Some(h) = health_table.get(&backend.hash_key()) {
                let flipped =
                    h.observe_health(errored.is_none(), check.health_threshold(errored.is_none()));
                if flipped {
                    check.health_status_change(backend, errored.is_none()).await;
                    let summary = check.backend_summary(backend);
                    if let Some(e) = errored {
                        warn!("{summary} becomes unhealthy, {e}");
                    } else {
                        info!("{summary} becomes healthy");
                    }
                }
            }
        }

        let Some(health_check) = self.health_check.as_ref() else {
            return;
        };

        let backends = self.backends.load();
        if parallel {
            let health_table = self.health.load_full();
            let runtime = current_handle();
            let jobs = backends.iter().map(|backend| {
                let backend = backend.clone();
                let check = health_check.clone();
                let ht = health_table.clone();
                runtime.spawn(async move {
                    check_and_report(&backend, &check, &ht).await;
                })
            });

            futures::future::join_all(jobs).await;
        } else {
            for backend in backends.iter() {
                check_and_report(backend, health_check, &self.health.load()).await;
            }
        }
    }
}

/// A [LoadBalancer] instance contains the service discovery, health check and backend selection
/// all together.
///
/// In order to run service discovery and health check at the designated frequencies, the
/// [LoadBalancer] needs to be run as a [pingora_core::services::background::BackgroundService].
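///
/// A sketch of running it as a background service (assumes the `background_service`
/// helper from `pingora_core::services::background` and the `RoundRobin` selector;
/// the addresses and frequency are placeholders):
///
/// ```ignore
/// use pingora_core::services::background::background_service;
///
/// let mut upstreams: LoadBalancer<RoundRobin> =
///     LoadBalancer::try_from_iter(["1.1.1.1:80", "1.0.0.1:80"]).unwrap();
/// upstreams.set_health_check(TcpHealthCheck::new());
/// upstreams.health_check_frequency = Some(Duration::from_secs(1));
///
/// // the background service runs discovery and health checks at the set frequencies
/// let background = background_service("health check", upstreams);
/// let upstreams = background.task();
/// ```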
pub struct LoadBalancer<S> {
    backends: Backends,
    selector: ArcSwap<S>,
    /// How frequently the health check logic (if set) should run.
    ///
    /// If `None`, the health check logic will only run once at the beginning.
    pub health_check_frequency: Option<Duration>,
    /// How frequently the service discovery should run.
    ///
    /// If `None`, the service discovery will only run once at the beginning.
    pub update_frequency: Option<Duration>,
    /// Whether to run the health check on all backends in parallel. Default is false.
    pub parallel_health_check: bool,
}

impl<S> LoadBalancer<S>
where
    S: BackendSelection + 'static,
    S::Iter: BackendIter,
{
    /// Build a [LoadBalancer] with static backends created from the iter.
    ///
    /// Note: [ToSocketAddrs] will invoke blocking network IO for a DNS lookup if
    /// the input cannot be directly parsed as a [SocketAddr].
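    ///
    /// For example (a direct `ip:port` string parses without any lookup, while a
    /// hostname such as `"localhost:80"` triggers a blocking DNS resolution):
    ///
    /// ```no_run
    /// use pingora_load_balancing::{selection::RoundRobin, LoadBalancer};
    ///
    /// let lb: LoadBalancer<RoundRobin> =
    ///     LoadBalancer::try_from_iter(["1.1.1.1:80", "1.0.0.1:80"]).unwrap();
    /// ```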
    pub fn try_from_iter<A, T: IntoIterator<Item = A>>(iter: T) -> IoResult<Self>
    where
        A: ToSocketAddrs,
    {
        let discovery = discovery::Static::try_from_iter(iter)?;
        let backends = Backends::new(discovery);
        let lb = Self::from_backends(backends);
        lb.update()
            .now_or_never()
            .expect("static should not block")
            .expect("static should not error");
        Ok(lb)
    }

    /// Build a [LoadBalancer] with the given [Backends].
    pub fn from_backends(backends: Backends) -> Self {
        let selector = ArcSwap::new(Arc::new(S::build(&backends.get_backend())));
        LoadBalancer {
            backends,
            selector,
            health_check_frequency: None,
            update_frequency: None,
            parallel_health_check: false,
        }
    }

    /// Run the service discovery and update the selection algorithm.
    ///
    /// This function will be called every `update_frequency` if this [LoadBalancer] instance
    /// is running as a background service.
    pub async fn update(&self) -> Result<()> {
        self.backends
            .update(|backends| self.selector.store(Arc::new(S::build(&backends))))
            .await
    }

    /// Return the first healthy [Backend] according to the selection algorithm and the
    /// health check results.
    ///
    /// The `key` is used for hash based selection and is ignored if the selection is random or
    /// round robin.
    ///
    /// The `max_iterations` bounds the search time for the next Backend. In certain
    /// algorithms, such as Ketama hashing, the search for the next backend is linear and
    /// could take many steps.
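    ///
    /// For example (illustrative; with a hash based selector the same key maps to the
    /// same backend, and 256 caps the search iterations):
    ///
    /// ```ignore
    /// let backend = lb.select(b"client_a", 256);
    /// ```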
    // TODO: consider removing `max_iterations` as users have no idea how to set it.
    pub fn select(&self, key: &[u8], max_iterations: usize) -> Option<Backend> {
        self.select_with(key, max_iterations, |_, health| health)
    }

    /// Similar to [Self::select], return the first healthy [Backend] according to the selection
    /// algorithm and the user defined `accept` function.
    ///
    /// The `accept` function takes two inputs: the backend being selected and the internal health
    /// of that backend. The function can do things like ignoring the internal health checks or
    /// skipping this backend because it failed before. The `accept` function is called multiple
    /// times, iterating over backends until it returns `true`.
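    ///
    /// A sketch (mirroring this crate's tests): accept every backend regardless of
    /// its internal health:
    ///
    /// ```ignore
    /// let backend = lb.select_with(b"key", 16, |_backend, _healthy| true);
    /// ```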
    pub fn select_with<F>(&self, key: &[u8], max_iterations: usize, accept: F) -> Option<Backend>
    where
        F: Fn(&Backend, bool) -> bool,
    {
        let selection = self.selector.load();
        let mut iter = UniqueIterator::new(selection.iter(key), max_iterations);
        while let Some(b) = iter.get_next() {
            if accept(&b, self.backends.ready(&b)) {
                return Some(b);
            }
        }
        None
    }

    /// Set the health check method. See [health_check].
    pub fn set_health_check(
        &mut self,
        hc: Box<dyn health_check::HealthCheck + Send + Sync + 'static>,
    ) {
        self.backends.set_health_check(hc);
    }

    /// Access the [Backends] of this [LoadBalancer].
    pub fn backends(&self) -> &Backends {
        &self.backends
    }
}

#[cfg(test)]
mod test {
    use std::sync::atomic::{AtomicBool, Ordering::Relaxed};

    use super::*;
    use async_trait::async_trait;

    #[tokio::test]
    async fn test_static_backends() {
        let backends: LoadBalancer<selection::RoundRobin> =
            LoadBalancer::try_from_iter(["1.1.1.1:80", "1.0.0.1:80"]).unwrap();

        let backend1 = Backend::new("1.1.1.1:80").unwrap();
        let backend2 = Backend::new("1.0.0.1:80").unwrap();
        let backend = backends.backends().get_backend();
        assert!(backend.contains(&backend1));
        assert!(backend.contains(&backend2));
    }

    #[tokio::test]
    async fn test_backends() {
        let discovery = discovery::Static::default();
        let good1 = Backend::new("1.1.1.1:80").unwrap();
        discovery.add(good1.clone());
        let good2 = Backend::new("1.0.0.1:80").unwrap();
        discovery.add(good2.clone());
        let bad = Backend::new("127.0.0.1:79").unwrap();
        discovery.add(bad.clone());

        let mut backends = Backends::new(Box::new(discovery));
        let check = health_check::TcpHealthCheck::new();
        backends.set_health_check(check);

        // true: new backend discovered
        let updated = AtomicBool::new(false);
        backends
            .update(|_| updated.store(true, Relaxed))
            .await
            .unwrap();
        assert!(updated.load(Relaxed));

        // false: no new backend discovered
        let updated = AtomicBool::new(false);
        backends
            .update(|_| updated.store(true, Relaxed))
            .await
            .unwrap();
        assert!(!updated.load(Relaxed));

        backends.run_health_check(false).await;

        let backend = backends.get_backend();
        assert!(backend.contains(&good1));
        assert!(backend.contains(&good2));
        assert!(backend.contains(&bad));

        assert!(backends.ready(&good1));
        assert!(backends.ready(&good2));
        assert!(!backends.ready(&bad));
    }

    #[tokio::test]
    async fn test_backends_with_ext() {
        let discovery = discovery::Static::default();
        let mut b1 = Backend::new("1.1.1.1:80").unwrap();
        b1.ext.insert(true);
        let mut b2 = Backend::new("1.0.0.1:80").unwrap();
        b2.ext.insert(1u8);
        discovery.add(b1.clone());
        discovery.add(b2.clone());

        let backends = Backends::new(Box::new(discovery));

        // fill in the backends
        backends.update(|_| {}).await.unwrap();

        let backend = backends.get_backend();
        assert!(backend.contains(&b1));
        assert!(backend.contains(&b2));

        let b2 = backend.first().unwrap();
        assert_eq!(b2.ext.get::<u8>(), Some(&1));

        let b1 = backend.last().unwrap();
        assert_eq!(b1.ext.get::<bool>(), Some(&true));
    }

    #[tokio::test]
    async fn test_discovery_readiness() {
        use discovery::Static;

        struct TestDiscovery(Static);
        #[async_trait]
        impl ServiceDiscovery for TestDiscovery {
            async fn discover(&self) -> Result<(BTreeSet<Backend>, HashMap<u64, bool>)> {
                let bad = Backend::new("127.0.0.1:79").unwrap();
                let (backends, mut readiness) = self.0.discover().await?;
                readiness.insert(bad.hash_key(), false);
                Ok((backends, readiness))
            }
        }
        let discovery = Static::default();
        let good1 = Backend::new("1.1.1.1:80").unwrap();
        discovery.add(good1.clone());
        let good2 = Backend::new("1.0.0.1:80").unwrap();
        discovery.add(good2.clone());
        let bad = Backend::new("127.0.0.1:79").unwrap();
        discovery.add(bad.clone());
        let discovery = TestDiscovery(discovery);

        let backends = Backends::new(Box::new(discovery));

        // true: new backend discovered
        let updated = AtomicBool::new(false);
        backends
            .update(|_| updated.store(true, Relaxed))
            .await
            .unwrap();
        assert!(updated.load(Relaxed));

        let backend = backends.get_backend();
        assert!(backend.contains(&good1));
        assert!(backend.contains(&good2));
        assert!(backend.contains(&bad));

        assert!(backends.ready(&good1));
        assert!(backends.ready(&good2));
        assert!(!backends.ready(&bad));
    }

    #[tokio::test]
    async fn test_parallel_health_check() {
        let discovery = discovery::Static::default();
        let good1 = Backend::new("1.1.1.1:80").unwrap();
        discovery.add(good1.clone());
        let good2 = Backend::new("1.0.0.1:80").unwrap();
        discovery.add(good2.clone());
        let bad = Backend::new("127.0.0.1:79").unwrap();
        discovery.add(bad.clone());

        let mut backends = Backends::new(Box::new(discovery));
        let check = health_check::TcpHealthCheck::new();
        backends.set_health_check(check);

        // true: new backend discovered
        let updated = AtomicBool::new(false);
        backends
            .update(|_| updated.store(true, Relaxed))
            .await
            .unwrap();
        assert!(updated.load(Relaxed));

        backends.run_health_check(true).await;

        assert!(backends.ready(&good1));
        assert!(backends.ready(&good2));
        assert!(!backends.ready(&bad));
    }

    mod thread_safety {
        use super::*;

        struct MockDiscovery {
            expected: usize,
        }
        #[async_trait]
        impl ServiceDiscovery for MockDiscovery {
            async fn discover(&self) -> Result<(BTreeSet<Backend>, HashMap<u64, bool>)> {
                let mut d = BTreeSet::new();
                let mut m = HashMap::with_capacity(self.expected);
                for i in 0..self.expected {
                    let b = Backend::new(&format!("1.1.1.1:{i}")).unwrap();
                    m.insert(i as u64, true);
                    d.insert(b);
                }
                Ok((d, m))
            }
        }

        #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
        async fn test_consistency() {
            let expected = 3000;
            let discovery = MockDiscovery { expected };
            let lb = Arc::new(LoadBalancer::<selection::Consistent>::from_backends(
                Backends::new(Box::new(discovery)),
            ));
            let lb2 = lb.clone();

            tokio::spawn(async move {
                assert!(lb2.update().await.is_ok());
            });
            let mut backend_count = 0;
            while backend_count == 0 {
                let backends = lb.backends();
                backend_count = backends.backends.load_full().len();
            }
            assert_eq!(backend_count, expected);
            assert!(lb.select_with(b"test", 1, |_, _| true).is_some());
        }
    }
}