Skip to main content

pingora_load_balancing/selection/
weighted.rs

1// Copyright 2026 Cloudflare, Inc.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! Weighted Selection
16
17use super::{Backend, BackendIter, BackendSelection, SelectionAlgorithm};
18use fnv::FnvHasher;
19use std::collections::BTreeSet;
20use std::sync::Arc;
21
22/// Weighted selection with a given selection algorithm
23///
24/// The default algorithm is [FnvHasher]. See [super::algorithms] for more choices.
25pub struct Weighted<H = FnvHasher> {
26    backends: Box<[Backend]>,
27    // each item is an index to the `backends`, use u16 to save memory, support up to 2^16 backends
28    weighted: Box<[u16]>,
29    algorithm: H,
30}
31
32impl<H: SelectionAlgorithm> BackendSelection for Weighted<H> {
33    type Iter = WeightedIterator<H>;
34
35    type Config = ();
36
37    fn build(backends: &BTreeSet<Backend>) -> Self {
38        assert!(
39            backends.len() <= u16::MAX as usize,
40            "support up to 2^16 backends"
41        );
42        let backends = Vec::from_iter(backends.iter().cloned()).into_boxed_slice();
43        let mut weighted = Vec::with_capacity(backends.len());
44        for (index, b) in backends.iter().enumerate() {
45            for _ in 0..b.weight {
46                weighted.push(index as u16);
47            }
48        }
49        Weighted {
50            backends,
51            weighted: weighted.into_boxed_slice(),
52            algorithm: H::new(),
53        }
54    }
55
56    fn iter(self: &Arc<Self>, key: &[u8]) -> Self::Iter {
57        WeightedIterator::new(key, self.clone())
58    }
59}
60
61/// An iterator over the backends of a [Weighted] selection.
62///
63/// See [super::BackendSelection] for more information.
64pub struct WeightedIterator<H> {
65    // the unbounded index seed
66    index: u64,
67    backend: Arc<Weighted<H>>,
68    first: bool,
69}
70
71impl<H: SelectionAlgorithm> WeightedIterator<H> {
72    /// Constructs a new [WeightedIterator].
73    fn new(input: &[u8], backend: Arc<Weighted<H>>) -> Self {
74        Self {
75            index: backend.algorithm.next(input),
76            backend,
77            first: true,
78        }
79    }
80}
81
82impl<H: SelectionAlgorithm> BackendIter for WeightedIterator<H> {
83    fn next(&mut self) -> Option<&Backend> {
84        if self.backend.backends.is_empty() {
85            // short circuit if empty
86            return None;
87        }
88
89        if self.first {
90            // initial hash, select from the weighted list
91            self.first = false;
92            let len = self.backend.weighted.len();
93            let index = self.backend.weighted[self.index as usize % len];
94            Some(&self.backend.backends[index as usize])
95        } else {
96            // fallback, select from the unique list
97            // deterministically select the next item
98            self.index = self.backend.algorithm.next(&self.index.to_le_bytes());
99            let len = self.backend.backends.len();
100            Some(&self.backend.backends[self.index as usize % len])
101        }
102    }
103}
104
#[cfg(test)]
mod test {
    use super::super::algorithms::*;
    use super::*;
    use std::collections::HashMap;

    /// FNV hashing: the first pick honours the weights, later picks walk the unique list.
    #[test]
    fn test_fnv() {
        let b1 = Backend::new("1.1.1.1:80").unwrap();
        let mut b2 = Backend::new("1.0.0.1:80").unwrap();
        b2.weight = 10; // 10x than the rest
        let b3 = Backend::new("1.0.0.255:80").unwrap();
        let backends = BTreeSet::from_iter([b1.clone(), b2.clone(), b3.clone()]);
        let hash: Arc<Weighted> = Arc::new(Weighted::build(&backends));

        // Walk one iterator for key "test": entry 0 comes from the weighted table,
        // the remaining entries from the uniform (unweighted) fallback.
        let walk = [&b2, &b2, &b2, &b1, &b3, &b2, &b2, &b1, &b2, &b3, &b1];
        let mut iter = hash.iter(b"test");
        for expected in walk {
            assert_eq!(iter.next(), Some(expected));
        }

        // A fresh iterator per key: only the first (weighted) pick is checked.
        let first_picks: [(&[u8], &Backend); 7] = [
            (b"test1", &b2),
            (b"test2", &b2),
            (b"test3", &b3),
            (b"test4", &b1),
            (b"test5", &b2),
            (b"test6", &b2),
            (b"test7", &b2),
        ];
        for (key, expected) in first_picks {
            let mut iter = hash.iter(key);
            assert_eq!(iter.next(), Some(expected));
        }
    }

    /// Round robin ignores the key: each algorithm call advances one shared counter.
    #[test]
    fn test_round_robin() {
        let b1 = Backend::new("1.1.1.1:80").unwrap();
        let mut b2 = Backend::new("1.0.0.1:80").unwrap();
        b2.weight = 8; // 8x than the rest
        let b3 = Backend::new("1.0.0.255:80").unwrap();
        // sorted with: [b2, b3, b1]
        // weighted: [0, 0, 0, 0, 0, 0, 0, 0, 1, 2]
        let backends = BTreeSet::from_iter([b1.clone(), b2.clone(), b3.clone()]);
        let hash: Arc<Weighted<RoundRobin>> = Arc::new(Weighted::build(&backends));

        // First pick from the weighted table, then round robin over the unique list.
        // weighted: [0, 0, 0, 0, 0, 0, 0, 0, 1, 2]
        //            ^
        let mut iter = hash.iter(b"test");
        for expected in [&b2, &b3, &b1, &b2, &b3] {
            assert_eq!(iter.next(), Some(expected));
        }

        // The counter has advanced 5 steps, so each new iterator continues from there
        // whatever the key is, wrapping around at the end of the table.
        // weighted: [0, 0, 0, 0, 0, 0, 0, 0, 1, 2]
        //                           ^
        for expected in [&b2, &b2, &b2, &b3, &b1, &b2, &b2] {
            let mut iter = hash.iter(b"test1");
            assert_eq!(iter.next(), Some(expected));
        }
    }

    /// Random: with weights 8/1/1 the heavy backend should win about 8 of every 10 first picks.
    #[test]
    fn test_random() {
        let b1 = Backend::new("1.1.1.1:80").unwrap();
        let mut b2 = Backend::new("1.0.0.1:80").unwrap();
        b2.weight = 8; // 8x than the rest
        let b3 = Backend::new("1.0.0.255:80").unwrap();
        let backends = BTreeSet::from_iter([b1.clone(), b2.clone(), b3.clone()]);
        let hash: Arc<Weighted<Random>> = Arc::new(Weighted::build(&backends));

        let mut count: HashMap<Backend, i32> = [&b1, &b2, &b3]
            .into_iter()
            .map(|b| (b.clone(), 0))
            .collect();

        for _ in 0..10000 {
            let mut iter = hash.iter(b"test");
            let picked = iter.next().unwrap();
            *count.get_mut(picked).unwrap() += 1;
        }

        let b2_count = count[&b2];
        assert!((7000..=9000).contains(&b2_count));
    }
}