// pingora_load_balancing/lib.rs
1// Copyright 2026 Cloudflare, Inc.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! # Pingora Load Balancing utilities
16//! This crate provides common service discovery, health check and load balancing
17//! algorithms for proxies to use.
18
19// https://github.com/mcarton/rust-derivative/issues/112
20// False positive for macro generated code
21#![allow(clippy::non_canonical_partial_ord_impl)]
22
23use arc_swap::ArcSwap;
24use derivative::Derivative;
25use futures::FutureExt;
26pub use http::Extensions;
27use pingora_core::protocols::l4::socket::SocketAddr;
28use pingora_error::{ErrorType, OrErr, Result};
29use std::collections::hash_map::DefaultHasher;
30use std::collections::{BTreeSet, HashMap};
31use std::hash::{Hash, Hasher};
32use std::io::Result as IoResult;
33use std::net::ToSocketAddrs;
34use std::sync::Arc;
35use std::time::Duration;
36
37mod background;
38pub mod discovery;
39pub mod health_check;
40pub mod selection;
41
42use discovery::ServiceDiscovery;
43use health_check::Health;
44use selection::UniqueIterator;
45use selection::{BackendIter, BackendSelection};
46
/// Commonly used types, re-exported for convenient glob import
/// (`use pingora_load_balancing::prelude::*;`).
pub mod prelude {
    pub use crate::health_check::TcpHealthCheck;
    pub use crate::selection::RoundRobin;
    pub use crate::LoadBalancer;
}
52
/// [Backend] represents a server to proxy or connect to.
#[derive(Derivative)]
#[derivative(Clone, Hash, PartialEq, PartialOrd, Eq, Ord, Debug)]
pub struct Backend {
    /// The address to the backend server.
    pub addr: SocketAddr,
    /// The relative weight of the server. Load balancing algorithms will
    /// proportionally distribute traffic according to this value.
    pub weight: usize,

    /// The extension field to put arbitrary data to annotate the Backend.
    /// The data added here is opaque to this crate hence the data is ignored by
    /// functionalities of this crate. For example, two backends with the same
    /// [SocketAddr] and the same weight but different `ext` data are considered
    /// identical.
    /// See [Extensions] for how to add and read the data.
    // Excluded from all derived comparisons and hashing (see note above).
    #[derivative(PartialEq = "ignore")]
    #[derivative(PartialOrd = "ignore")]
    #[derivative(Hash = "ignore")]
    #[derivative(Ord = "ignore")]
    pub ext: Extensions,
}
75
76impl Backend {
77    /// Create a new [Backend] with `weight` 1. The function will try to parse
78    ///  `addr` into a [std::net::SocketAddr].
79    pub fn new(addr: &str) -> Result<Self> {
80        Self::new_with_weight(addr, 1)
81    }
82
83    /// Creates a new [Backend] with the specified `weight`. The function will try to parse
84    /// `addr` into a [std::net::SocketAddr].
85    pub fn new_with_weight(addr: &str, weight: usize) -> Result<Self> {
86        let addr = addr
87            .parse()
88            .or_err(ErrorType::InternalError, "invalid socket addr")?;
89        Ok(Backend {
90            addr: SocketAddr::Inet(addr),
91            weight,
92            ext: Extensions::new(),
93        })
94        // TODO: UDS
95    }
96
97    pub(crate) fn hash_key(&self) -> u64 {
98        let mut hasher = DefaultHasher::new();
99        self.hash(&mut hasher);
100        hasher.finish()
101    }
102}
103
// Deref to the inner [SocketAddr] so a [Backend] can be used directly where
// only its address is needed.
impl std::ops::Deref for Backend {
    type Target = SocketAddr;

    fn deref(&self) -> &Self::Target {
        &self.addr
    }
}
111
// Mutable counterpart of the Deref impl above: expose the inner [SocketAddr]
// for in-place modification.
impl std::ops::DerefMut for Backend {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.addr
    }
}
117
// Let a [Backend] be passed to APIs that take [std::net::ToSocketAddrs];
// it resolves to exactly one address (hence `iter::Once`), delegating to the
// inner [SocketAddr].
impl std::net::ToSocketAddrs for Backend {
    type Iter = std::iter::Once<std::net::SocketAddr>;

    fn to_socket_addrs(&self) -> std::io::Result<Self::Iter> {
        self.addr.to_socket_addrs()
    }
}
125
/// [Backends] is a collection of [Backend]s.
///
/// It includes a service discovery method (static or dynamic) to discover all
/// the available backends as well as an optional health check method to probe the liveness
/// of each backend.
pub struct Backends {
    /// Source of truth for the set of backends (static list, DNS, etc.).
    discovery: Box<dyn ServiceDiscovery + Send + Sync + 'static>,
    /// Optional probe used by `run_health_check`; `None` disables health checking.
    health_check: Option<Arc<dyn health_check::HealthCheck + Send + Sync + 'static>>,
    /// Current set of discovered backends, atomically swapped on update.
    backends: ArcSwap<BTreeSet<Backend>>,
    /// Per-backend health state, keyed by [Backend::hash_key].
    health: ArcSwap<HashMap<u64, Health>>,
}
137
impl Backends {
    /// Create a new [Backends] with the given [ServiceDiscovery] implementation.
    ///
    /// The health check method is by default empty.
    pub fn new(discovery: Box<dyn ServiceDiscovery + Send + Sync + 'static>) -> Self {
        Self {
            discovery,
            health_check: None,
            backends: Default::default(),
            health: Default::default(),
        }
    }

    /// Set the health check method. See [health_check] for the methods provided.
    pub fn set_health_check(
        &mut self,
        hc: Box<dyn health_check::HealthCheck + Send + Sync + 'static>,
    ) {
        self.health_check = Some(hc.into())
    }

    /// Updates backends when the new set is different from the current one;
    /// the callback will be invoked when the new set of backends is different
    /// from the current one so that the caller can update the selector accordingly.
    fn do_update<F>(
        &self,
        new_backends: BTreeSet<Backend>,
        enablement: HashMap<u64, bool>,
        callback: F,
    ) where
        F: Fn(Arc<BTreeSet<Backend>>),
    {
        if (**self.backends.load()) != new_backends {
            // Rebuild the health table, carrying over the existing state of
            // backends that survive the update so their history is not lost.
            let old_health = self.health.load();
            let mut health = HashMap::with_capacity(new_backends.len());
            for backend in new_backends.iter() {
                let hash_key = backend.hash_key();
                // use the default health if the backend is new
                let backend_health = old_health.get(&hash_key).cloned().unwrap_or_default();

                // override enablement
                if let Some(backend_enabled) = enablement.get(&hash_key) {
                    backend_health.enable(*backend_enabled);
                }
                health.insert(hash_key, backend_health);
            }

            // TODO: put this all under 1 ArcSwap so the update is atomic
            // It's important the `callback()` executes first since computing selector backends might
            // be expensive. For example, if a caller checks `backends` to see if any are available
            // they may encounter false positives if the selector isn't ready yet.
            let new_backends = Arc::new(new_backends);
            callback(new_backends.clone());
            self.backends.store(new_backends);
            self.health.store(Arc::new(health));
        } else {
            // no backend change, just check enablement
            for (hash_key, backend_enabled) in enablement.iter() {
                // override enablement if set
                // this get should always be Some(_) because we already populate `health` for all known backends
                if let Some(backend_health) = self.health.load().get(hash_key) {
                    backend_health.enable(*backend_enabled);
                }
            }
        }
    }

    /// Whether a certain [Backend] is ready to serve traffic.
    ///
    /// This function returns true when the backend is both healthy and enabled.
    /// This function returns true when the health check is unset but the backend is enabled.
    /// When the health check is set, this function will return false for the `backend` it
    /// doesn't know.
    pub fn ready(&self, backend: &Backend) -> bool {
        self.health
            .load()
            .get(&backend.hash_key())
            // Racing: return `None` when this function is called between the
            // backend store and the health store
            .map_or(self.health_check.is_none(), |h| h.ready())
    }

    /// Manually set if a [Backend] is ready to serve traffic.
    ///
    /// This method does not override the health of the backend. It is meant to be used
    /// to stop a backend from accepting traffic when it is still healthy.
    ///
    /// This method is noop when the given backend doesn't exist in the service discovery.
    pub fn set_enable(&self, backend: &Backend, enabled: bool) {
        // this should always be Some(_) because health is always populated during update
        if let Some(h) = self.health.load().get(&backend.hash_key()) {
            h.enable(enabled)
        };
    }

    /// Return the collection of the backends.
    pub fn get_backend(&self) -> Arc<BTreeSet<Backend>> {
        self.backends.load_full()
    }

    /// Call the service discovery method to update the collection of backends.
    ///
    /// The callback will be invoked when the new set of backend is different
    /// from the current one so that the caller can update the selector accordingly.
    pub async fn update<F>(&self, callback: F) -> Result<()>
    where
        F: Fn(Arc<BTreeSet<Backend>>),
    {
        let (new_backends, enablement) = self.discovery.discover().await?;
        self.do_update(new_backends, enablement, callback);
        Ok(())
    }

    /// Run health check on all backends if it is set.
    ///
    /// When `parallel: true`, all backends are checked in parallel instead of sequentially
    pub async fn run_health_check(&self, parallel: bool) {
        use crate::health_check::HealthCheck;
        use log::{info, warn};
        use pingora_runtime::current_handle;

        // Probe a single backend; if its observed health flips, notify the
        // check implementation and log the transition.
        async fn check_and_report(
            backend: &Backend,
            check: &Arc<dyn HealthCheck + Send + Sync>,
            health_table: &HashMap<u64, Health>,
        ) {
            let errored = check.check(backend).await.err();
            if let Some(h) = health_table.get(&backend.hash_key()) {
                let flipped =
                    h.observe_health(errored.is_none(), check.health_threshold(errored.is_none()));
                if flipped {
                    check.health_status_change(backend, errored.is_none()).await;
                    let summary = check.backend_summary(backend);
                    if let Some(e) = errored {
                        warn!("{summary} becomes unhealthy, {e}");
                    } else {
                        info!("{summary} becomes healthy");
                    }
                }
            }
        }

        // No health check configured: nothing to do.
        let Some(health_check) = self.health_check.as_ref() else {
            return;
        };

        let backends = self.backends.load();
        if parallel {
            let health_table = self.health.load_full();
            let runtime = current_handle();
            // Spawn one task per backend, then wait for all of them.
            let jobs = backends.iter().map(|backend| {
                let backend = backend.clone();
                let check = health_check.clone();
                let ht = health_table.clone();
                runtime.spawn(async move {
                    check_and_report(&backend, &check, &ht).await;
                })
            });

            futures::future::join_all(jobs).await;
        } else {
            for backend in backends.iter() {
                check_and_report(backend, health_check, &self.health.load()).await;
            }
        }
    }
}
305
/// A [LoadBalancer] instance contains the service discovery, health check and backend selection
/// all together.
///
/// In order to run service discovery and health check at the designated frequencies, the [LoadBalancer]
/// needs to be run as a [pingora_core::services::background::BackgroundService].
pub struct LoadBalancer<S>
where
    S: BackendSelection,
{
    /// The discovered backends along with their health state.
    backends: Backends,
    /// The selection algorithm state, rebuilt and swapped on each backend update.
    selector: ArcSwap<S>,

    /// Optional selector configuration; when set it is used for every rebuild.
    config: Option<S::Config>,

    /// How frequent the health check logic (if set) should run.
    ///
    /// If `None`, the health check logic will only run once at the beginning.
    pub health_check_frequency: Option<Duration>,
    /// How frequent the service discovery should run.
    ///
    /// If `None`, the service discovery will only run once at the beginning.
    pub update_frequency: Option<Duration>,
    /// Whether to run health check to all backends in parallel. Default is false.
    pub parallel_health_check: bool,
}
331
332impl<S> LoadBalancer<S>
333where
334    S: BackendSelection + 'static,
335    S::Iter: BackendIter,
336{
337    /// Build a [LoadBalancer] with static backends created from the iter.
338    ///
339    /// Note: [ToSocketAddrs] will invoke blocking network IO for DNS lookup if
340    /// the input cannot be directly parsed as [SocketAddr].
341    pub fn try_from_iter<A, T: IntoIterator<Item = A>>(iter: T) -> IoResult<Self>
342    where
343        A: ToSocketAddrs,
344    {
345        let discovery = discovery::Static::try_from_iter(iter)?;
346        let backends = Backends::new(discovery);
347        let lb = Self::from_backends(backends);
348        lb.update()
349            .now_or_never()
350            .expect("static should not block")
351            .expect("static should not error");
352        Ok(lb)
353    }
354
355    /// Build a [LoadBalancer] with the given [Backends] and the config.
356    pub fn from_backends_with_config(backends: Backends, config_opt: Option<S::Config>) -> Self {
357        let selector_raw = if let Some(config) = config_opt.as_ref() {
358            S::build_with_config(&backends.get_backend(), config)
359        } else {
360            S::build(&backends.get_backend())
361        };
362
363        let selector = ArcSwap::new(Arc::new(selector_raw));
364
365        LoadBalancer {
366            backends,
367            selector,
368            config: config_opt,
369            health_check_frequency: None,
370            update_frequency: None,
371            parallel_health_check: false,
372        }
373    }
374
375    /// Build a [LoadBalancer] with the given [Backends].
376    pub fn from_backends(backends: Backends) -> Self {
377        Self::from_backends_with_config(backends, None)
378    }
379
380    /// Run the service discovery and update the selection algorithm.
381    ///
382    /// This function will be called every `update_frequency` if this [LoadBalancer] instance
383    /// is running as a background service.
384    pub async fn update(&self) -> Result<()> {
385        self.backends
386            .update(|backends| {
387                let selector = if let Some(config) = &self.config {
388                    S::build_with_config(&backends, config)
389                } else {
390                    S::build(&backends)
391                };
392
393                self.selector.store(Arc::new(selector))
394            })
395            .await
396    }
397
398    /// Return the first healthy [Backend] according to the selection algorithm and the
399    /// health check results.
400    ///
401    /// The `key` is used for hash based selection and is ignored if the selection is random or
402    /// round robin.
403    ///
404    /// the `max_iterations` is there to bound the search time for the next Backend. In certain
405    /// algorithm like Ketama hashing, the search for the next backend is linear and could take
406    /// a lot steps.
407    // TODO: consider remove `max_iterations` as users have no idea how to set it.
408    pub fn select(&self, key: &[u8], max_iterations: usize) -> Option<Backend> {
409        self.select_with(key, max_iterations, |_, health| health)
410    }
411
412    /// Similar to [Self::select], return the first healthy [Backend] according to the selection algorithm
413    /// and the user defined `accept` function.
414    ///
415    /// The `accept` function takes two inputs, the backend being selected and the internal health of that
416    /// backend. The function can do things like ignoring the internal health checks or skipping this backend
417    /// because it failed before. The `accept` function is called multiple times iterating over backends
418    /// until it returns `true`.
419    pub fn select_with<F>(&self, key: &[u8], max_iterations: usize, accept: F) -> Option<Backend>
420    where
421        F: Fn(&Backend, bool) -> bool,
422    {
423        let selection = self.selector.load();
424        let mut iter = UniqueIterator::new(selection.iter(key), max_iterations);
425        while let Some(b) = iter.get_next() {
426            if accept(&b, self.backends.ready(&b)) {
427                return Some(b);
428            }
429        }
430        None
431    }
432
433    /// Set the health check method. See [health_check].
434    pub fn set_health_check(
435        &mut self,
436        hc: Box<dyn health_check::HealthCheck + Send + Sync + 'static>,
437    ) {
438        self.backends.set_health_check(hc);
439    }
440
441    /// Access the [Backends] of this [LoadBalancer]
442    pub fn backends(&self) -> &Backends {
443        &self.backends
444    }
445}
446
#[cfg(test)]
mod test {
    use std::sync::atomic::{AtomicBool, Ordering::Relaxed};

    use super::*;
    use async_trait::async_trait;

    // Static discovery should surface every address given at construction.
    #[tokio::test]
    async fn test_static_backends() {
        let backends: LoadBalancer<selection::RoundRobin> =
            LoadBalancer::try_from_iter(["1.1.1.1:80", "1.0.0.1:80"]).unwrap();

        let backend1 = Backend::new("1.1.1.1:80").unwrap();
        let backend2 = Backend::new("1.0.0.1:80").unwrap();
        let backend = backends.backends().get_backend();
        assert!(backend.contains(&backend1));
        assert!(backend.contains(&backend2));
    }

    // End to end: discovery update callback semantics plus TCP health checks.
    // NOTE(review): relies on outbound network reachability of 1.1.1.1/1.0.0.1:80.
    #[tokio::test]
    async fn test_backends() {
        let discovery = discovery::Static::default();
        let good1 = Backend::new("1.1.1.1:80").unwrap();
        discovery.add(good1.clone());
        let good2 = Backend::new("1.0.0.1:80").unwrap();
        discovery.add(good2.clone());
        let bad = Backend::new("127.0.0.1:79").unwrap();
        discovery.add(bad.clone());

        let mut backends = Backends::new(Box::new(discovery));
        let check = health_check::TcpHealthCheck::new();
        backends.set_health_check(check);

        // true: new backend discovered
        let updated = AtomicBool::new(false);
        backends
            .update(|_| updated.store(true, Relaxed))
            .await
            .unwrap();
        assert!(updated.load(Relaxed));

        // false: no new backend discovered
        let updated = AtomicBool::new(false);
        backends
            .update(|_| updated.store(true, Relaxed))
            .await
            .unwrap();
        assert!(!updated.load(Relaxed));

        backends.run_health_check(false).await;

        let backend = backends.get_backend();
        assert!(backend.contains(&good1));
        assert!(backend.contains(&good2));
        assert!(backend.contains(&bad));

        assert!(backends.ready(&good1));
        assert!(backends.ready(&good2));
        assert!(!backends.ready(&bad));
    }
    // `ext` data must not affect identity but must survive discovery updates.
    #[tokio::test]
    async fn test_backends_with_ext() {
        let discovery = discovery::Static::default();
        let mut b1 = Backend::new("1.1.1.1:80").unwrap();
        b1.ext.insert(true);
        let mut b2 = Backend::new("1.0.0.1:80").unwrap();
        b2.ext.insert(1u8);
        discovery.add(b1.clone());
        discovery.add(b2.clone());

        let backends = Backends::new(Box::new(discovery));

        // fill in the backends
        backends.update(|_| {}).await.unwrap();

        let backend = backends.get_backend();
        assert!(backend.contains(&b1));
        assert!(backend.contains(&b2));

        // BTreeSet orders by addr: 1.0.0.1 sorts first, 1.1.1.1 last.
        let b2 = backend.first().unwrap();
        assert_eq!(b2.ext.get::<u8>(), Some(&1));

        let b1 = backend.last().unwrap();
        assert_eq!(b1.ext.get::<bool>(), Some(&true));
    }

    // A discovery that reports a backend as disabled should make it not ready
    // even without any health check configured.
    #[tokio::test]
    async fn test_discovery_readiness() {
        use discovery::Static;

        struct TestDiscovery(Static);
        #[async_trait]
        impl ServiceDiscovery for TestDiscovery {
            async fn discover(&self) -> Result<(BTreeSet<Backend>, HashMap<u64, bool>)> {
                let bad = Backend::new("127.0.0.1:79").unwrap();
                let (backends, mut readiness) = self.0.discover().await?;
                readiness.insert(bad.hash_key(), false);
                Ok((backends, readiness))
            }
        }
        let discovery = Static::default();
        let good1 = Backend::new("1.1.1.1:80").unwrap();
        discovery.add(good1.clone());
        let good2 = Backend::new("1.0.0.1:80").unwrap();
        discovery.add(good2.clone());
        let bad = Backend::new("127.0.0.1:79").unwrap();
        discovery.add(bad.clone());
        let discovery = TestDiscovery(discovery);

        let backends = Backends::new(Box::new(discovery));

        // true: new backend discovered
        let updated = AtomicBool::new(false);
        backends
            .update(|_| updated.store(true, Relaxed))
            .await
            .unwrap();
        assert!(updated.load(Relaxed));

        let backend = backends.get_backend();
        assert!(backend.contains(&good1));
        assert!(backend.contains(&good2));
        assert!(backend.contains(&bad));

        assert!(backends.ready(&good1));
        assert!(backends.ready(&good2));
        assert!(!backends.ready(&bad));
    }

    // Same readiness expectations as the sequential path, but exercised via
    // run_health_check(true) which spawns one task per backend.
    #[tokio::test]
    async fn test_parallel_health_check() {
        let discovery = discovery::Static::default();
        let good1 = Backend::new("1.1.1.1:80").unwrap();
        discovery.add(good1.clone());
        let good2 = Backend::new("1.0.0.1:80").unwrap();
        discovery.add(good2.clone());
        let bad = Backend::new("127.0.0.1:79").unwrap();
        discovery.add(bad.clone());

        let mut backends = Backends::new(Box::new(discovery));
        let check = health_check::TcpHealthCheck::new();
        backends.set_health_check(check);

        // true: new backend discovered
        let updated = AtomicBool::new(false);
        backends
            .update(|_| updated.store(true, Relaxed))
            .await
            .unwrap();
        assert!(updated.load(Relaxed));

        backends.run_health_check(true).await;

        assert!(backends.ready(&good1));
        assert!(backends.ready(&good2));
        assert!(!backends.ready(&bad));
    }

    mod thread_safety {
        use super::*;

        // Discovery stub producing `expected` distinct backends, no network.
        struct MockDiscovery {
            expected: usize,
        }
        #[async_trait]
        impl ServiceDiscovery for MockDiscovery {
            async fn discover(&self) -> Result<(BTreeSet<Backend>, HashMap<u64, bool>)> {
                let mut d = BTreeSet::new();
                let mut m = HashMap::with_capacity(self.expected);
                for i in 0..self.expected {
                    let b = Backend::new(&format!("1.1.1.1:{i}")).unwrap();
                    m.insert(i as u64, true);
                    d.insert(b);
                }
                Ok((d, m))
            }
        }

        // Run update() on one thread while reading on another: once the
        // backend set is visible, selection must succeed (callback runs before
        // the backend store, so the selector is ready first).
        #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
        async fn test_consistency() {
            let expected = 3000;
            let discovery = MockDiscovery { expected };
            let lb = Arc::new(LoadBalancer::<selection::Consistent>::from_backends(
                Backends::new(Box::new(discovery)),
            ));
            let lb2 = lb.clone();

            tokio::spawn(async move {
                assert!(lb2.update().await.is_ok());
            });
            // Busy-wait until the updater thread publishes the backend set.
            let mut backend_count = 0;
            while backend_count == 0 {
                let backends = lb.backends();
                backend_count = backends.backends.load_full().len();
            }
            assert_eq!(backend_count, expected);
            assert!(lb.select_with(b"test", 1, |_, _| true).is_some());
        }
    }
}