Skip to main content

pingora_load_balancing/
lib.rs

1// Copyright 2026 Cloudflare, Inc.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! # Pingora Load Balancing utilities
16//! This crate provides common service discovery, health check and load balancing
17//! algorithms for proxies to use.
18
19// https://github.com/mcarton/rust-derivative/issues/112
20// False positive for macro generated code
21#![allow(clippy::non_canonical_partial_ord_impl)]
22
23use arc_swap::ArcSwap;
24use derivative::Derivative;
25use futures::FutureExt;
26pub use http::Extensions;
27use pingora_core::protocols::l4::socket::SocketAddr;
28use pingora_error::{ErrorType, OrErr, Result};
29use std::collections::hash_map::DefaultHasher;
30use std::collections::{BTreeSet, HashMap};
31use std::hash::{Hash, Hasher};
32use std::io::Result as IoResult;
33use std::net::ToSocketAddrs;
34use std::sync::Arc;
35use std::time::Duration;
36
37mod background;
38pub mod discovery;
39pub mod health_check;
40pub mod selection;
41
42use discovery::ServiceDiscovery;
43use health_check::Health;
44use selection::UniqueIterator;
45use selection::{BackendIter, BackendSelection};
46
47pub mod prelude {
48    pub use crate::health_check::TcpHealthCheck;
49    pub use crate::selection::RoundRobin;
50    pub use crate::LoadBalancer;
51}
52
53/// [Backend] represents a server to proxy or connect to.
54#[derive(Derivative)]
55#[derivative(Clone, Hash, PartialEq, PartialOrd, Eq, Ord, Debug)]
56pub struct Backend {
57    /// The address to the backend server.
58    pub addr: SocketAddr,
59    /// The relative weight of the server. Load balancing algorithms will
60    /// proportionally distributed traffic according to this value.
61    pub weight: usize,
62
63    /// The extension field to put arbitrary data to annotate the Backend.
64    /// The data added here is opaque to this crate hence the data is ignored by
65    /// functionalities of this crate. For example, two backends with the same
66    /// [SocketAddr] and the same weight but different `ext` data are considered
67    /// identical.
68    /// See [Extensions] for how to add and read the data.
69    #[derivative(PartialEq = "ignore")]
70    #[derivative(PartialOrd = "ignore")]
71    #[derivative(Hash = "ignore")]
72    #[derivative(Ord = "ignore")]
73    pub ext: Extensions,
74}
75
76impl Backend {
77    /// Create a new [Backend] with `weight` 1. The function will try to parse
78    ///  `addr` into a [std::net::SocketAddr].
79    pub fn new(addr: &str) -> Result<Self> {
80        Self::new_with_weight(addr, 1)
81    }
82
83    /// Creates a new [Backend] with the specified `weight`. The function will try to parse
84    /// `addr` into a [std::net::SocketAddr].
85    pub fn new_with_weight(addr: &str, weight: usize) -> Result<Self> {
86        let addr = addr
87            .parse()
88            .or_err(ErrorType::InternalError, "invalid socket addr")?;
89        Ok(Backend {
90            addr: SocketAddr::Inet(addr),
91            weight,
92            ext: Extensions::new(),
93        })
94        // TODO: UDS
95    }
96
97    pub(crate) fn hash_key(&self) -> u64 {
98        let mut hasher = DefaultHasher::new();
99        self.hash(&mut hasher);
100        hasher.finish()
101    }
102}
103
104impl std::ops::Deref for Backend {
105    type Target = SocketAddr;
106
107    fn deref(&self) -> &Self::Target {
108        &self.addr
109    }
110}
111
112impl std::ops::DerefMut for Backend {
113    fn deref_mut(&mut self) -> &mut Self::Target {
114        &mut self.addr
115    }
116}
117
118impl std::net::ToSocketAddrs for Backend {
119    type Iter = std::iter::Once<std::net::SocketAddr>;
120
121    fn to_socket_addrs(&self) -> std::io::Result<Self::Iter> {
122        self.addr.to_socket_addrs()
123    }
124}
125
126/// [Backends] is a collection of [Backend]s.
127///
128/// It includes a service discovery method (static or dynamic) to discover all
129/// the available backends as well as an optional health check method to probe the liveness
130/// of each backend.
131pub struct Backends {
132    discovery: Box<dyn ServiceDiscovery + Send + Sync + 'static>,
133    health_check: Option<Arc<dyn health_check::HealthCheck + Send + Sync + 'static>>,
134    backends: ArcSwap<BTreeSet<Backend>>,
135    health: ArcSwap<HashMap<u64, Health>>,
136}
137
138impl Backends {
139    /// Create a new [Backends] with the given [ServiceDiscovery] implementation.
140    ///
141    /// The health check method is by default empty.
142    pub fn new(discovery: Box<dyn ServiceDiscovery + Send + Sync + 'static>) -> Self {
143        Self {
144            discovery,
145            health_check: None,
146            backends: Default::default(),
147            health: Default::default(),
148        }
149    }
150
151    /// Set the health check method. See [health_check] for the methods provided.
152    pub fn set_health_check(
153        &mut self,
154        hc: Box<dyn health_check::HealthCheck + Send + Sync + 'static>,
155    ) {
156        self.health_check = Some(hc.into())
157    }
158
159    /// Updates backends when the new is different from the current set,
160    /// the callback will be invoked when the new set of backend is different
161    /// from the current one so that the caller can update the selector accordingly.
162    fn do_update<F>(
163        &self,
164        new_backends: BTreeSet<Backend>,
165        enablement: HashMap<u64, bool>,
166        callback: F,
167    ) where
168        F: Fn(Arc<BTreeSet<Backend>>),
169    {
170        if (**self.backends.load()) != new_backends {
171            let old_health = self.health.load();
172            let mut health = HashMap::with_capacity(new_backends.len());
173            for backend in new_backends.iter() {
174                let hash_key = backend.hash_key();
175                // use the default health if the backend is new
176                let backend_health = old_health.get(&hash_key).cloned().unwrap_or_default();
177
178                // override enablement
179                if let Some(backend_enabled) = enablement.get(&hash_key) {
180                    backend_health.enable(*backend_enabled);
181                }
182                health.insert(hash_key, backend_health);
183            }
184
185            // TODO: put this all under 1 ArcSwap so the update is atomic
186            // It's important the `callback()` executes first since computing selector backends might
187            // be expensive. For example, if a caller checks `backends` to see if any are available
188            // they may encounter false positives if the selector isn't ready yet.
189            let new_backends = Arc::new(new_backends);
190            callback(new_backends.clone());
191            self.backends.store(new_backends);
192            self.health.store(Arc::new(health));
193        } else {
194            // no backend change, just check enablement
195            for (hash_key, backend_enabled) in enablement.iter() {
196                // override enablement if set
197                // this get should always be Some(_) because we already populate `health`` for all known backends
198                if let Some(backend_health) = self.health.load().get(hash_key) {
199                    backend_health.enable(*backend_enabled);
200                }
201            }
202        }
203    }
204
205    /// Whether a certain [Backend] is ready to serve traffic.
206    ///
207    /// This function returns true when the backend is both healthy and enabled.
208    /// This function returns true when the health check is unset but the backend is enabled.
209    /// When the health check is set, this function will return false for the `backend` it
210    /// doesn't know.
211    pub fn ready(&self, backend: &Backend) -> bool {
212        self.health
213            .load()
214            .get(&backend.hash_key())
215            // Racing: return `None` when this function is called between the
216            // backend store and the health store
217            .map_or(self.health_check.is_none(), |h| h.ready())
218    }
219
220    /// Manually set if a [Backend] is ready to serve traffic.
221    ///
222    /// This method does not override the health of the backend. It is meant to be used
223    /// to stop a backend from accepting traffic when it is still healthy.
224    ///
225    /// This method is noop when the given backend doesn't exist in the service discovery.
226    pub fn set_enable(&self, backend: &Backend, enabled: bool) {
227        // this should always be Some(_) because health is always populated during update
228        if let Some(h) = self.health.load().get(&backend.hash_key()) {
229            h.enable(enabled)
230        };
231    }
232
233    /// Return the collection of the backends.
234    pub fn get_backend(&self) -> Arc<BTreeSet<Backend>> {
235        self.backends.load_full()
236    }
237
238    /// Call the service discovery method to update the collection of backends.
239    ///
240    /// The callback will be invoked when the new set of backend is different
241    /// from the current one so that the caller can update the selector accordingly.
242    pub async fn update<F>(&self, callback: F) -> Result<()>
243    where
244        F: Fn(Arc<BTreeSet<Backend>>),
245    {
246        let (new_backends, enablement) = self.discovery.discover().await?;
247        self.do_update(new_backends, enablement, callback);
248        Ok(())
249    }
250
251    /// Run health check on all backends if it is set.
252    ///
253    /// When `parallel: true`, all backends are checked in parallel instead of sequentially
254    pub async fn run_health_check(&self, parallel: bool) {
255        use crate::health_check::HealthCheck;
256        use log::{info, warn};
257        use pingora_runtime::current_handle;
258
259        async fn check_and_report(
260            backend: &Backend,
261            check: &Arc<dyn HealthCheck + Send + Sync>,
262            health_table: &HashMap<u64, Health>,
263        ) {
264            let errored = check.check(backend).await.err();
265            if let Some(h) = health_table.get(&backend.hash_key()) {
266                let flipped =
267                    h.observe_health(errored.is_none(), check.health_threshold(errored.is_none()));
268                if flipped {
269                    check.health_status_change(backend, errored.is_none()).await;
270                    let summary = check.backend_summary(backend);
271                    if let Some(e) = errored {
272                        warn!("{summary} becomes unhealthy, {e}");
273                    } else {
274                        info!("{summary} becomes healthy");
275                    }
276                }
277            }
278        }
279
280        let Some(health_check) = self.health_check.as_ref() else {
281            return;
282        };
283
284        let backends = self.backends.load();
285        if parallel {
286            let health_table = self.health.load_full();
287            let runtime = current_handle();
288            let jobs = backends.iter().map(|backend| {
289                let backend = backend.clone();
290                let check = health_check.clone();
291                let ht = health_table.clone();
292                runtime.spawn(async move {
293                    check_and_report(&backend, &check, &ht).await;
294                })
295            });
296
297            futures::future::join_all(jobs).await;
298        } else {
299            for backend in backends.iter() {
300                check_and_report(backend, health_check, &self.health.load()).await;
301            }
302        }
303    }
304}
305
306/// A [LoadBalancer] instance contains the service discovery, health check and backend selection
307/// all together.
308///
309/// In order to run service discovery and health check at the designated frequencies, the [LoadBalancer]
310/// needs to be run as a [pingora_core::services::background::BackgroundService].
311pub struct LoadBalancer<S> {
312    backends: Backends,
313    selector: ArcSwap<S>,
314    /// How frequent the health check logic (if set) should run.
315    ///
316    /// If `None`, the health check logic will only run once at the beginning.
317    pub health_check_frequency: Option<Duration>,
318    /// How frequent the service discovery should run.
319    ///
320    /// If `None`, the service discovery will only run once at the beginning.
321    pub update_frequency: Option<Duration>,
322    /// Whether to run health check to all backends in parallel. Default is false.
323    pub parallel_health_check: bool,
324}
325
326impl<S> LoadBalancer<S>
327where
328    S: BackendSelection + 'static,
329    S::Iter: BackendIter,
330{
331    /// Build a [LoadBalancer] with static backends created from the iter.
332    ///
333    /// Note: [ToSocketAddrs] will invoke blocking network IO for DNS lookup if
334    /// the input cannot be directly parsed as [SocketAddr].
335    pub fn try_from_iter<A, T: IntoIterator<Item = A>>(iter: T) -> IoResult<Self>
336    where
337        A: ToSocketAddrs,
338    {
339        let discovery = discovery::Static::try_from_iter(iter)?;
340        let backends = Backends::new(discovery);
341        let lb = Self::from_backends(backends);
342        lb.update()
343            .now_or_never()
344            .expect("static should not block")
345            .expect("static should not error");
346        Ok(lb)
347    }
348
349    /// Build a [LoadBalancer] with the given [Backends] and the config.
350    pub fn from_backends_with_config(backends: Backends, config: &S::Config) -> Self {
351        let selector = ArcSwap::new(Arc::new(S::build_with_config(
352            &backends.get_backend(),
353            config,
354        )));
355
356        LoadBalancer {
357            backends,
358            selector,
359            health_check_frequency: None,
360            update_frequency: None,
361            parallel_health_check: false,
362        }
363    }
364
365    /// Build a [LoadBalancer] with the given [Backends].
366    pub fn from_backends(backends: Backends) -> Self {
367        let selector = ArcSwap::new(Arc::new(S::build(&backends.get_backend())));
368        LoadBalancer {
369            backends,
370            selector,
371            health_check_frequency: None,
372            update_frequency: None,
373            parallel_health_check: false,
374        }
375    }
376
377    /// Run the service discovery and update the selection algorithm.
378    ///
379    /// This function will be called every `update_frequency` if this [LoadBalancer] instance
380    /// is running as a background service.
381    pub async fn update(&self) -> Result<()> {
382        self.backends
383            .update(|backends| self.selector.store(Arc::new(S::build(&backends))))
384            .await
385    }
386
387    /// Return the first healthy [Backend] according to the selection algorithm and the
388    /// health check results.
389    ///
390    /// The `key` is used for hash based selection and is ignored if the selection is random or
391    /// round robin.
392    ///
393    /// the `max_iterations` is there to bound the search time for the next Backend. In certain
394    /// algorithm like Ketama hashing, the search for the next backend is linear and could take
395    /// a lot steps.
396    // TODO: consider remove `max_iterations` as users have no idea how to set it.
397    pub fn select(&self, key: &[u8], max_iterations: usize) -> Option<Backend> {
398        self.select_with(key, max_iterations, |_, health| health)
399    }
400
401    /// Similar to [Self::select], return the first healthy [Backend] according to the selection algorithm
402    /// and the user defined `accept` function.
403    ///
404    /// The `accept` function takes two inputs, the backend being selected and the internal health of that
405    /// backend. The function can do things like ignoring the internal health checks or skipping this backend
406    /// because it failed before. The `accept` function is called multiple times iterating over backends
407    /// until it returns `true`.
408    pub fn select_with<F>(&self, key: &[u8], max_iterations: usize, accept: F) -> Option<Backend>
409    where
410        F: Fn(&Backend, bool) -> bool,
411    {
412        let selection = self.selector.load();
413        let mut iter = UniqueIterator::new(selection.iter(key), max_iterations);
414        while let Some(b) = iter.get_next() {
415            if accept(&b, self.backends.ready(&b)) {
416                return Some(b);
417            }
418        }
419        None
420    }
421
422    /// Set the health check method. See [health_check].
423    pub fn set_health_check(
424        &mut self,
425        hc: Box<dyn health_check::HealthCheck + Send + Sync + 'static>,
426    ) {
427        self.backends.set_health_check(hc);
428    }
429
430    /// Access the [Backends] of this [LoadBalancer]
431    pub fn backends(&self) -> &Backends {
432        &self.backends
433    }
434}
435
#[cfg(test)]
mod test {
    use std::sync::atomic::{AtomicBool, Ordering::Relaxed};

    use super::*;
    use async_trait::async_trait;

    /// `try_from_iter` populates the backend set right away.
    #[tokio::test]
    async fn test_static_backends() {
        let backends: LoadBalancer<selection::RoundRobin> =
            LoadBalancer::try_from_iter(["1.1.1.1:80", "1.0.0.1:80"]).unwrap();

        let backend1 = Backend::new("1.1.1.1:80").unwrap();
        let backend2 = Backend::new("1.0.0.1:80").unwrap();
        let backend = backends.backends().get_backend();
        assert!(backend.contains(&backend1));
        assert!(backend.contains(&backend2));
    }

    /// `update` fires its callback only when the set changes. The TCP health check marks
    /// the unreachable backend as not ready.
    // NOTE(review): presumably needs outbound network access to 1.1.1.1 / 1.0.0.1, and
    // assumes nothing listens on 127.0.0.1:79.
    #[tokio::test]
    async fn test_backends() {
        let discovery = discovery::Static::default();
        let good1 = Backend::new("1.1.1.1:80").unwrap();
        discovery.add(good1.clone());
        let good2 = Backend::new("1.0.0.1:80").unwrap();
        discovery.add(good2.clone());
        let bad = Backend::new("127.0.0.1:79").unwrap();
        discovery.add(bad.clone());

        let mut backends = Backends::new(Box::new(discovery));
        let check = health_check::TcpHealthCheck::new();
        backends.set_health_check(check);

        // true: new backend discovered
        let updated = AtomicBool::new(false);
        backends
            .update(|_| updated.store(true, Relaxed))
            .await
            .unwrap();
        assert!(updated.load(Relaxed));

        // false: no new backend discovered
        let updated = AtomicBool::new(false);
        backends
            .update(|_| updated.store(true, Relaxed))
            .await
            .unwrap();
        assert!(!updated.load(Relaxed));

        backends.run_health_check(false).await;

        let backend = backends.get_backend();
        assert!(backend.contains(&good1));
        assert!(backend.contains(&good2));
        assert!(backend.contains(&bad));

        assert!(backends.ready(&good1));
        assert!(backends.ready(&good2));
        assert!(!backends.ready(&bad));
    }

    /// `ext` data survives discovery and is readable from the stored backends.
    #[tokio::test]
    async fn test_backends_with_ext() {
        let discovery = discovery::Static::default();
        let mut b1 = Backend::new("1.1.1.1:80").unwrap();
        b1.ext.insert(true);
        let mut b2 = Backend::new("1.0.0.1:80").unwrap();
        b2.ext.insert(1u8);
        discovery.add(b1.clone());
        discovery.add(b2.clone());

        let backends = Backends::new(Box::new(discovery));

        // fill in the backends
        backends.update(|_| {}).await.unwrap();

        let backend = backends.get_backend();
        assert!(backend.contains(&b1));
        assert!(backend.contains(&b2));

        // BTreeSet order follows the derived `Ord` on `addr`: 1.0.0.1 sorts before 1.1.1.1.
        let b2 = backend.first().unwrap();
        assert_eq!(b2.ext.get::<u8>(), Some(&1));

        let b1 = backend.last().unwrap();
        assert_eq!(b1.ext.get::<bool>(), Some(&true));
    }

    /// Enablement overrides returned by discovery are applied during `update`.
    #[tokio::test]
    async fn test_discovery_readiness() {
        use discovery::Static;

        struct TestDiscovery(Static);
        #[async_trait]
        impl ServiceDiscovery for TestDiscovery {
            async fn discover(&self) -> Result<(BTreeSet<Backend>, HashMap<u64, bool>)> {
                let bad = Backend::new("127.0.0.1:79").unwrap();
                let (backends, mut readiness) = self.0.discover().await?;
                readiness.insert(bad.hash_key(), false);
                Ok((backends, readiness))
            }
        }
        let discovery = Static::default();
        let good1 = Backend::new("1.1.1.1:80").unwrap();
        discovery.add(good1.clone());
        let good2 = Backend::new("1.0.0.1:80").unwrap();
        discovery.add(good2.clone());
        let bad = Backend::new("127.0.0.1:79").unwrap();
        discovery.add(bad.clone());
        let discovery = TestDiscovery(discovery);

        let backends = Backends::new(Box::new(discovery));

        // true: new backend discovered
        let updated = AtomicBool::new(false);
        backends
            .update(|_| updated.store(true, Relaxed))
            .await
            .unwrap();
        assert!(updated.load(Relaxed));

        let backend = backends.get_backend();
        assert!(backend.contains(&good1));
        assert!(backend.contains(&good2));
        assert!(backend.contains(&bad));

        assert!(backends.ready(&good1));
        assert!(backends.ready(&good2));
        assert!(!backends.ready(&bad));
    }

    /// Same expectations as `test_backends`, using the parallel health check path.
    #[tokio::test]
    async fn test_parallel_health_check() {
        let discovery = discovery::Static::default();
        let good1 = Backend::new("1.1.1.1:80").unwrap();
        discovery.add(good1.clone());
        let good2 = Backend::new("1.0.0.1:80").unwrap();
        discovery.add(good2.clone());
        let bad = Backend::new("127.0.0.1:79").unwrap();
        discovery.add(bad.clone());

        let mut backends = Backends::new(Box::new(discovery));
        let check = health_check::TcpHealthCheck::new();
        backends.set_health_check(check);

        // true: new backend discovered
        let updated = AtomicBool::new(false);
        backends
            .update(|_| updated.store(true, Relaxed))
            .await
            .unwrap();
        assert!(updated.load(Relaxed));

        backends.run_health_check(true).await;

        assert!(backends.ready(&good1));
        assert!(backends.ready(&good2));
        assert!(!backends.ready(&bad));
    }

    mod thread_safety {
        use super::*;

        /// Discovery that returns `expected` backends on 1.1.1.1 with ports `0..expected`.
        struct MockDiscovery {
            expected: usize,
        }
        #[async_trait]
        impl ServiceDiscovery for MockDiscovery {
            async fn discover(&self) -> Result<(BTreeSet<Backend>, HashMap<u64, bool>)> {
                let mut d = BTreeSet::new();
                let mut m = HashMap::with_capacity(self.expected);
                for i in 0..self.expected {
                    let b = Backend::new(&format!("1.1.1.1:{i}")).unwrap();
                    // NOTE(review): these keys are plain indices, not `hash_key()` values, so
                    // they presumably match no backend and are ignored.
                    m.insert(i as u64, true);
                    d.insert(b);
                }
                Ok((d, m))
            }
        }

        /// Once the backend set is visible, the selector must already be built for it.
        #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
        async fn test_consistency() {
            let expected = 3000;
            let discovery = MockDiscovery { expected };
            let lb = Arc::new(LoadBalancer::<selection::Consistent>::from_backends(
                Backends::new(Box::new(discovery)),
            ));
            let lb2 = lb.clone();

            // Keep the handle so a failed or panicking update surfaces here instead of being
            // lost in a detached task.
            let update = tokio::spawn(async move { lb2.update().await });

            // Spin without yielding so the first published backend set is observed as early
            // as possible. The deadline turns an update that never publishes into a test
            // failure instead of an endless loop.
            let deadline = std::time::Instant::now() + Duration::from_secs(30);
            let mut backend_count = 0;
            while backend_count == 0 {
                assert!(
                    std::time::Instant::now() < deadline,
                    "update never published any backends"
                );
                let backends = lb.backends();
                backend_count = backends.backends.load_full().len();
            }
            assert_eq!(backend_count, expected);
            assert!(lb.select_with(b"test", 1, |_, _| true).is_some());
            assert!(update.await.expect("update task panicked").is_ok());
        }
    }
}