pingora_load_balancing/
lib.rs

1// Copyright 2025 Cloudflare, Inc.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! # Pingora Load Balancing utilities
16//! This crate provides common service discovery, health check and load balancing
17//! algorithms for proxies to use.
18
19// https://github.com/mcarton/rust-derivative/issues/112
20// False positive for macro generated code
21#![allow(clippy::non_canonical_partial_ord_impl)]
22
23use arc_swap::ArcSwap;
24use derivative::Derivative;
25use futures::FutureExt;
26pub use http::Extensions;
27use pingora_core::protocols::l4::socket::SocketAddr;
28use pingora_error::{ErrorType, OrErr, Result};
29use std::collections::hash_map::DefaultHasher;
30use std::collections::{BTreeSet, HashMap};
31use std::hash::{Hash, Hasher};
32use std::io::Result as IoResult;
33use std::net::ToSocketAddrs;
34use std::sync::Arc;
35use std::time::Duration;
36
37mod background;
38pub mod discovery;
39pub mod health_check;
40pub mod selection;
41
42use discovery::ServiceDiscovery;
43use health_check::Health;
44use selection::UniqueIterator;
45use selection::{BackendIter, BackendSelection};
46
47pub mod prelude {
48    pub use crate::health_check::TcpHealthCheck;
49    pub use crate::selection::RoundRobin;
50    pub use crate::LoadBalancer;
51}
52
53/// [Backend] represents a server to proxy or connect to.
54#[derive(Derivative)]
55#[derivative(Clone, Hash, PartialEq, PartialOrd, Eq, Ord, Debug)]
56pub struct Backend {
57    /// The address to the backend server.
58    pub addr: SocketAddr,
59    /// The relative weight of the server. Load balancing algorithms will
60    /// proportionally distributed traffic according to this value.
61    pub weight: usize,
62
63    /// The extension field to put arbitrary data to annotate the Backend.
64    /// The data added here is opaque to this crate hence the data is ignored by
65    /// functionalities of this crate. For example, two backends with the same
66    /// [SocketAddr] and the same weight but different `ext` data are considered
67    /// identical.
68    /// See [Extensions] for how to add and read the data.
69    #[derivative(PartialEq = "ignore")]
70    #[derivative(PartialOrd = "ignore")]
71    #[derivative(Hash = "ignore")]
72    #[derivative(Ord = "ignore")]
73    pub ext: Extensions,
74}
75
76impl Backend {
77    /// Create a new [Backend] with `weight` 1. The function will try to parse
78    ///  `addr` into a [std::net::SocketAddr].
79    pub fn new(addr: &str) -> Result<Self> {
80        Self::new_with_weight(addr, 1)
81    }
82
83    /// Creates a new [Backend] with the specified `weight`. The function will try to parse
84    /// `addr` into a [std::net::SocketAddr].
85    pub fn new_with_weight(addr: &str, weight: usize) -> Result<Self> {
86        let addr = addr
87            .parse()
88            .or_err(ErrorType::InternalError, "invalid socket addr")?;
89        Ok(Backend {
90            addr: SocketAddr::Inet(addr),
91            weight,
92            ext: Extensions::new(),
93        })
94        // TODO: UDS
95    }
96
97    pub(crate) fn hash_key(&self) -> u64 {
98        let mut hasher = DefaultHasher::new();
99        self.hash(&mut hasher);
100        hasher.finish()
101    }
102}
103
104impl std::ops::Deref for Backend {
105    type Target = SocketAddr;
106
107    fn deref(&self) -> &Self::Target {
108        &self.addr
109    }
110}
111
112impl std::ops::DerefMut for Backend {
113    fn deref_mut(&mut self) -> &mut Self::Target {
114        &mut self.addr
115    }
116}
117
118impl std::net::ToSocketAddrs for Backend {
119    type Iter = std::iter::Once<std::net::SocketAddr>;
120
121    fn to_socket_addrs(&self) -> std::io::Result<Self::Iter> {
122        self.addr.to_socket_addrs()
123    }
124}
125
126/// [Backends] is a collection of [Backend]s.
127///
128/// It includes a service discovery method (static or dynamic) to discover all
129/// the available backends as well as an optional health check method to probe the liveness
130/// of each backend.
131pub struct Backends {
132    discovery: Box<dyn ServiceDiscovery + Send + Sync + 'static>,
133    health_check: Option<Arc<dyn health_check::HealthCheck + Send + Sync + 'static>>,
134    backends: ArcSwap<BTreeSet<Backend>>,
135    health: ArcSwap<HashMap<u64, Health>>,
136}
137
138impl Backends {
139    /// Create a new [Backends] with the given [ServiceDiscovery] implementation.
140    ///
141    /// The health check method is by default empty.
142    pub fn new(discovery: Box<dyn ServiceDiscovery + Send + Sync + 'static>) -> Self {
143        Self {
144            discovery,
145            health_check: None,
146            backends: Default::default(),
147            health: Default::default(),
148        }
149    }
150
151    /// Set the health check method. See [health_check] for the methods provided.
152    pub fn set_health_check(
153        &mut self,
154        hc: Box<dyn health_check::HealthCheck + Send + Sync + 'static>,
155    ) {
156        self.health_check = Some(hc.into())
157    }
158
159    /// Updates backends when the new is different from the current set,
160    /// the callback will be invoked when the new set of backend is different
161    /// from the current one so that the caller can update the selector accordingly.
162    fn do_update<F>(
163        &self,
164        new_backends: BTreeSet<Backend>,
165        enablement: HashMap<u64, bool>,
166        callback: F,
167    ) where
168        F: Fn(Arc<BTreeSet<Backend>>),
169    {
170        if (**self.backends.load()) != new_backends {
171            let old_health = self.health.load();
172            let mut health = HashMap::with_capacity(new_backends.len());
173            for backend in new_backends.iter() {
174                let hash_key = backend.hash_key();
175                // use the default health if the backend is new
176                let backend_health = old_health.get(&hash_key).cloned().unwrap_or_default();
177
178                // override enablement
179                if let Some(backend_enabled) = enablement.get(&hash_key) {
180                    backend_health.enable(*backend_enabled);
181                }
182                health.insert(hash_key, backend_health);
183            }
184
185            // TODO: put this all under 1 ArcSwap so the update is atomic
186            // It's important the `callback()` executes first since computing selector backends might
187            // be expensive. For example, if a caller checks `backends` to see if any are available
188            // they may encounter false positives if the selector isn't ready yet.
189            let new_backends = Arc::new(new_backends);
190            callback(new_backends.clone());
191            self.backends.store(new_backends);
192            self.health.store(Arc::new(health));
193        } else {
194            // no backend change, just check enablement
195            for (hash_key, backend_enabled) in enablement.iter() {
196                // override enablement if set
197                // this get should always be Some(_) because we already populate `health`` for all known backends
198                if let Some(backend_health) = self.health.load().get(hash_key) {
199                    backend_health.enable(*backend_enabled);
200                }
201            }
202        }
203    }
204
205    /// Whether a certain [Backend] is ready to serve traffic.
206    ///
207    /// This function returns true when the backend is both healthy and enabled.
208    /// This function returns true when the health check is unset but the backend is enabled.
209    /// When the health check is set, this function will return false for the `backend` it
210    /// doesn't know.
211    pub fn ready(&self, backend: &Backend) -> bool {
212        self.health
213            .load()
214            .get(&backend.hash_key())
215            // Racing: return `None` when this function is called between the
216            // backend store and the health store
217            .map_or(self.health_check.is_none(), |h| h.ready())
218    }
219
220    /// Manually set if a [Backend] is ready to serve traffic.
221    ///
222    /// This method does not override the health of the backend. It is meant to be used
223    /// to stop a backend from accepting traffic when it is still healthy.
224    ///
225    /// This method is noop when the given backend doesn't exist in the service discovery.
226    pub fn set_enable(&self, backend: &Backend, enabled: bool) {
227        // this should always be Some(_) because health is always populated during update
228        if let Some(h) = self.health.load().get(&backend.hash_key()) {
229            h.enable(enabled)
230        };
231    }
232
233    /// Return the collection of the backends.
234    pub fn get_backend(&self) -> Arc<BTreeSet<Backend>> {
235        self.backends.load_full()
236    }
237
238    /// Call the service discovery method to update the collection of backends.
239    ///
240    /// The callback will be invoked when the new set of backend is different
241    /// from the current one so that the caller can update the selector accordingly.
242    pub async fn update<F>(&self, callback: F) -> Result<()>
243    where
244        F: Fn(Arc<BTreeSet<Backend>>),
245    {
246        let (new_backends, enablement) = self.discovery.discover().await?;
247        self.do_update(new_backends, enablement, callback);
248        Ok(())
249    }
250
251    /// Run health check on all backends if it is set.
252    ///
253    /// When `parallel: true`, all backends are checked in parallel instead of sequentially
254    pub async fn run_health_check(&self, parallel: bool) {
255        use crate::health_check::HealthCheck;
256        use log::{info, warn};
257        use pingora_runtime::current_handle;
258
259        async fn check_and_report(
260            backend: &Backend,
261            check: &Arc<dyn HealthCheck + Send + Sync>,
262            health_table: &HashMap<u64, Health>,
263        ) {
264            let errored = check.check(backend).await.err();
265            if let Some(h) = health_table.get(&backend.hash_key()) {
266                let flipped =
267                    h.observe_health(errored.is_none(), check.health_threshold(errored.is_none()));
268                if flipped {
269                    check.health_status_change(backend, errored.is_none()).await;
270                    if let Some(e) = errored {
271                        warn!("{backend:?} becomes unhealthy, {e}");
272                    } else {
273                        info!("{backend:?} becomes healthy");
274                    }
275                }
276            }
277        }
278
279        let Some(health_check) = self.health_check.as_ref() else {
280            return;
281        };
282
283        let backends = self.backends.load();
284        if parallel {
285            let health_table = self.health.load_full();
286            let runtime = current_handle();
287            let jobs = backends.iter().map(|backend| {
288                let backend = backend.clone();
289                let check = health_check.clone();
290                let ht = health_table.clone();
291                runtime.spawn(async move {
292                    check_and_report(&backend, &check, &ht).await;
293                })
294            });
295
296            futures::future::join_all(jobs).await;
297        } else {
298            for backend in backends.iter() {
299                check_and_report(backend, health_check, &self.health.load()).await;
300            }
301        }
302    }
303}
304
305/// A [LoadBalancer] instance contains the service discovery, health check and backend selection
306/// all together.
307///
308/// In order to run service discovery and health check at the designated frequencies, the [LoadBalancer]
309/// needs to be run as a [pingora_core::services::background::BackgroundService].
310pub struct LoadBalancer<S> {
311    backends: Backends,
312    selector: ArcSwap<S>,
313    /// How frequent the health check logic (if set) should run.
314    ///
315    /// If `None`, the health check logic will only run once at the beginning.
316    pub health_check_frequency: Option<Duration>,
317    /// How frequent the service discovery should run.
318    ///
319    /// If `None`, the service discovery will only run once at the beginning.
320    pub update_frequency: Option<Duration>,
321    /// Whether to run health check to all backends in parallel. Default is false.
322    pub parallel_health_check: bool,
323}
324
325impl<S: BackendSelection> LoadBalancer<S>
326where
327    S: BackendSelection + 'static,
328    S::Iter: BackendIter,
329{
330    /// Build a [LoadBalancer] with static backends created from the iter.
331    ///
332    /// Note: [ToSocketAddrs] will invoke blocking network IO for DNS lookup if
333    /// the input cannot be directly parsed as [SocketAddr].
334    pub fn try_from_iter<A, T: IntoIterator<Item = A>>(iter: T) -> IoResult<Self>
335    where
336        A: ToSocketAddrs,
337    {
338        let discovery = discovery::Static::try_from_iter(iter)?;
339        let backends = Backends::new(discovery);
340        let lb = Self::from_backends(backends);
341        lb.update()
342            .now_or_never()
343            .expect("static should not block")
344            .expect("static should not error");
345        Ok(lb)
346    }
347
348    /// Build a [LoadBalancer] with the given [Backends].
349    pub fn from_backends(backends: Backends) -> Self {
350        let selector = ArcSwap::new(Arc::new(S::build(&backends.get_backend())));
351        LoadBalancer {
352            backends,
353            selector,
354            health_check_frequency: None,
355            update_frequency: None,
356            parallel_health_check: false,
357        }
358    }
359
360    /// Run the service discovery and update the selection algorithm.
361    ///
362    /// This function will be called every `update_frequency` if this [LoadBalancer] instance
363    /// is running as a background service.
364    pub async fn update(&self) -> Result<()> {
365        self.backends
366            .update(|backends| self.selector.store(Arc::new(S::build(&backends))))
367            .await
368    }
369
370    /// Return the first healthy [Backend] according to the selection algorithm and the
371    /// health check results.
372    ///
373    /// The `key` is used for hash based selection and is ignored if the selection is random or
374    /// round robin.
375    ///
376    /// the `max_iterations` is there to bound the search time for the next Backend. In certain
377    /// algorithm like Ketama hashing, the search for the next backend is linear and could take
378    /// a lot steps.
379    // TODO: consider remove `max_iterations` as users have no idea how to set it.
380    pub fn select(&self, key: &[u8], max_iterations: usize) -> Option<Backend> {
381        self.select_with(key, max_iterations, |_, health| health)
382    }
383
384    /// Similar to [Self::select], return the first healthy [Backend] according to the selection algorithm
385    /// and the user defined `accept` function.
386    ///
387    /// The `accept` function takes two inputs, the backend being selected and the internal health of that
388    /// backend. The function can do things like ignoring the internal health checks or skipping this backend
389    /// because it failed before. The `accept` function is called multiple times iterating over backends
390    /// until it returns `true`.
391    pub fn select_with<F>(&self, key: &[u8], max_iterations: usize, accept: F) -> Option<Backend>
392    where
393        F: Fn(&Backend, bool) -> bool,
394    {
395        let selection = self.selector.load();
396        let mut iter = UniqueIterator::new(selection.iter(key), max_iterations);
397        while let Some(b) = iter.get_next() {
398            if accept(&b, self.backends.ready(&b)) {
399                return Some(b);
400            }
401        }
402        None
403    }
404
405    /// Set the health check method. See [health_check].
406    pub fn set_health_check(
407        &mut self,
408        hc: Box<dyn health_check::HealthCheck + Send + Sync + 'static>,
409    ) {
410        self.backends.set_health_check(hc);
411    }
412
413    /// Access the [Backends] of this [LoadBalancer]
414    pub fn backends(&self) -> &Backends {
415        &self.backends
416    }
417}
418
#[cfg(test)]
mod test {
    use std::sync::atomic::{AtomicBool, Ordering::Relaxed};

    use super::*;
    use async_trait::async_trait;

    #[tokio::test]
    async fn test_static_backends() {
        let backends: LoadBalancer<selection::RoundRobin> =
            LoadBalancer::try_from_iter(["1.1.1.1:80", "1.0.0.1:80"]).unwrap();

        let backend1 = Backend::new("1.1.1.1:80").unwrap();
        let backend2 = Backend::new("1.0.0.1:80").unwrap();
        let backend = backends.backends().get_backend();
        assert!(backend.contains(&backend1));
        assert!(backend.contains(&backend2));
    }

    #[tokio::test]
    async fn test_backends() {
        let discovery = discovery::Static::default();
        let good1 = Backend::new("1.1.1.1:80").unwrap();
        discovery.add(good1.clone());
        let good2 = Backend::new("1.0.0.1:80").unwrap();
        discovery.add(good2.clone());
        let bad = Backend::new("127.0.0.1:79").unwrap();
        discovery.add(bad.clone());

        let mut backends = Backends::new(Box::new(discovery));
        let check = health_check::TcpHealthCheck::new();
        backends.set_health_check(check);

        // true: new backend discovered
        let updated = AtomicBool::new(false);
        backends
            .update(|_| updated.store(true, Relaxed))
            .await
            .unwrap();
        assert!(updated.load(Relaxed));

        // false: no new backend discovered
        let updated = AtomicBool::new(false);
        backends
            .update(|_| updated.store(true, Relaxed))
            .await
            .unwrap();
        assert!(!updated.load(Relaxed));

        backends.run_health_check(false).await;

        let backend = backends.get_backend();
        assert!(backend.contains(&good1));
        assert!(backend.contains(&good2));
        assert!(backend.contains(&bad));

        // NOTE: these readiness assertions appear to depend on the network: the "good" public
        // addresses must be reachable and 127.0.0.1:79 must not be.
        assert!(backends.ready(&good1));
        assert!(backends.ready(&good2));
        assert!(!backends.ready(&bad));
    }
    #[tokio::test]
    async fn test_backends_with_ext() {
        let discovery = discovery::Static::default();
        let mut b1 = Backend::new("1.1.1.1:80").unwrap();
        b1.ext.insert(true);
        let mut b2 = Backend::new("1.0.0.1:80").unwrap();
        b2.ext.insert(1u8);
        discovery.add(b1.clone());
        discovery.add(b2.clone());

        let backends = Backends::new(Box::new(discovery));

        // fill in the backends
        backends.update(|_| {}).await.unwrap();

        let backend = backends.get_backend();
        assert!(backend.contains(&b1));
        assert!(backend.contains(&b2));

        // Relies on the `BTreeSet` ordering placing 1.0.0.1:80 (b2) before 1.1.1.1:80 (b1).
        let b2 = backend.first().unwrap();
        assert_eq!(b2.ext.get::<u8>(), Some(&1));

        let b1 = backend.last().unwrap();
        assert_eq!(b1.ext.get::<bool>(), Some(&true));
    }

    #[tokio::test]
    async fn test_discovery_readiness() {
        use discovery::Static;

        // Wraps `Static` discovery and always reports the "bad" backend as disabled.
        struct TestDiscovery(Static);
        #[async_trait]
        impl ServiceDiscovery for TestDiscovery {
            async fn discover(&self) -> Result<(BTreeSet<Backend>, HashMap<u64, bool>)> {
                let bad = Backend::new("127.0.0.1:79").unwrap();
                let (backends, mut readiness) = self.0.discover().await?;
                readiness.insert(bad.hash_key(), false);
                Ok((backends, readiness))
            }
        }
        let discovery = Static::default();
        let good1 = Backend::new("1.1.1.1:80").unwrap();
        discovery.add(good1.clone());
        let good2 = Backend::new("1.0.0.1:80").unwrap();
        discovery.add(good2.clone());
        let bad = Backend::new("127.0.0.1:79").unwrap();
        discovery.add(bad.clone());
        let discovery = TestDiscovery(discovery);

        let backends = Backends::new(Box::new(discovery));

        // true: new backend discovered
        let updated = AtomicBool::new(false);
        backends
            .update(|_| updated.store(true, Relaxed))
            .await
            .unwrap();
        assert!(updated.load(Relaxed));

        let backend = backends.get_backend();
        assert!(backend.contains(&good1));
        assert!(backend.contains(&good2));
        assert!(backend.contains(&bad));

        assert!(backends.ready(&good1));
        assert!(backends.ready(&good2));
        assert!(!backends.ready(&bad));
    }

    #[tokio::test]
    async fn test_parallel_health_check() {
        let discovery = discovery::Static::default();
        let good1 = Backend::new("1.1.1.1:80").unwrap();
        discovery.add(good1.clone());
        let good2 = Backend::new("1.0.0.1:80").unwrap();
        discovery.add(good2.clone());
        let bad = Backend::new("127.0.0.1:79").unwrap();
        discovery.add(bad.clone());

        let mut backends = Backends::new(Box::new(discovery));
        let check = health_check::TcpHealthCheck::new();
        backends.set_health_check(check);

        // true: new backend discovered
        let updated = AtomicBool::new(false);
        backends
            .update(|_| updated.store(true, Relaxed))
            .await
            .unwrap();
        assert!(updated.load(Relaxed));

        backends.run_health_check(true).await;

        assert!(backends.ready(&good1));
        assert!(backends.ready(&good2));
        assert!(!backends.ready(&bad));
    }

    mod thread_safety {
        use super::*;

        // Discovers `expected` backends on 1.1.1.1 with ports 0..expected. The enablement
        // keys are plain indices, not derived from `Backend::hash_key`.
        struct MockDiscovery {
            expected: usize,
        }
        #[async_trait]
        impl ServiceDiscovery for MockDiscovery {
            async fn discover(&self) -> Result<(BTreeSet<Backend>, HashMap<u64, bool>)> {
                let mut d = BTreeSet::new();
                let mut m = HashMap::with_capacity(self.expected);
                for i in 0..self.expected {
                    let b = Backend::new(&format!("1.1.1.1:{i}")).unwrap();
                    m.insert(i as u64, true);
                    d.insert(b);
                }
                Ok((d, m))
            }
        }

        #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
        async fn test_consistency() {
            let expected = 3000;
            let discovery = MockDiscovery { expected };
            let lb = Arc::new(LoadBalancer::<selection::Consistent>::from_backends(
                Backends::new(Box::new(discovery)),
            ));
            let lb2 = lb.clone();

            tokio::spawn(async move {
                assert!(lb2.update().await.is_ok());
            });

            // Poll until the concurrent update has published the backends. Yield between
            // polls so this loop does not monopolize its worker thread.
            let backend_count = loop {
                let count = lb.backends().backends.load_full().len();
                if count != 0 {
                    break count;
                }
                tokio::task::yield_now().await;
            };
            assert_eq!(backend_count, expected);
            // Once backends are visible the selector must already be built, because
            // `do_update` runs the selector callback before storing the backends.
            assert!(lb.select_with(b"test", 1, |_, _| true).is_some());
        }
    }
}