pingora_load_balancing/selection/
weighted.rs

1// Copyright 2025 Cloudflare, Inc.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! Weighted Selection
16
17use super::{Backend, BackendIter, BackendSelection, SelectionAlgorithm};
18use fnv::FnvHasher;
19use std::collections::BTreeSet;
20use std::sync::Arc;
21
22/// Weighted selection with a given selection algorithm
23///
24/// The default algorithm is [FnvHasher]. See [super::algorithms] for more choices.
25pub struct Weighted<H = FnvHasher> {
26    backends: Box<[Backend]>,
27    // each item is an index to the `backends`, use u16 to save memory, support up to 2^16 backends
28    weighted: Box<[u16]>,
29    algorithm: H,
30}
31
32impl<H: SelectionAlgorithm> BackendSelection for Weighted<H> {
33    type Iter = WeightedIterator<H>;
34
35    fn build(backends: &BTreeSet<Backend>) -> Self {
36        assert!(
37            backends.len() <= u16::MAX as usize,
38            "support up to 2^16 backends"
39        );
40        let backends = Vec::from_iter(backends.iter().cloned()).into_boxed_slice();
41        let mut weighted = Vec::with_capacity(backends.len());
42        for (index, b) in backends.iter().enumerate() {
43            for _ in 0..b.weight {
44                weighted.push(index as u16);
45            }
46        }
47        Weighted {
48            backends,
49            weighted: weighted.into_boxed_slice(),
50            algorithm: H::new(),
51        }
52    }
53
54    fn iter(self: &Arc<Self>, key: &[u8]) -> Self::Iter {
55        WeightedIterator::new(key, self.clone())
56    }
57}
58
59/// An iterator over the backends of a [Weighted] selection.
60///
61/// See [super::BackendSelection] for more information.
62pub struct WeightedIterator<H> {
63    // the unbounded index seed
64    index: u64,
65    backend: Arc<Weighted<H>>,
66    first: bool,
67}
68
69impl<H: SelectionAlgorithm> WeightedIterator<H> {
70    /// Constructs a new [WeightedIterator].
71    fn new(input: &[u8], backend: Arc<Weighted<H>>) -> Self {
72        Self {
73            index: backend.algorithm.next(input),
74            backend,
75            first: true,
76        }
77    }
78}
79
80impl<H: SelectionAlgorithm> BackendIter for WeightedIterator<H> {
81    fn next(&mut self) -> Option<&Backend> {
82        if self.backend.backends.is_empty() {
83            // short circuit if empty
84            return None;
85        }
86
87        if self.first {
88            // initial hash, select from the weighted list
89            self.first = false;
90            let len = self.backend.weighted.len();
91            let index = self.backend.weighted[self.index as usize % len];
92            Some(&self.backend.backends[index as usize])
93        } else {
94            // fallback, select from the unique list
95            // deterministically select the next item
96            self.index = self.backend.algorithm.next(&self.index.to_le_bytes());
97            let len = self.backend.backends.len();
98            Some(&self.backend.backends[self.index as usize % len])
99        }
100    }
101}
102
#[cfg(test)]
mod test {
    // Tests pinning the exact picks of `Weighted` for the FNV, round-robin and random
    // selection algorithms.
    use super::super::algorithms::*;
    use super::*;
    use std::collections::HashMap;

    #[test]
    fn test_fnv() {
        let b1 = Backend::new("1.1.1.1:80").unwrap();
        let mut b2 = Backend::new("1.0.0.1:80").unwrap();
        b2.weight = 10; // ten times the weight of the others
        let b3 = Backend::new("1.0.0.255:80").unwrap();
        let backends = BTreeSet::from_iter([b1.clone(), b2.clone(), b3.clone()]);
        let hash: Arc<Weighted> = Arc::new(Weighted::build(&backends));

        // One iterator for a single key: the first pick is weighted, every pick after it is
        // a uniform (unweighted) fallback.
        let expected_walk = [&b2, &b2, &b2, &b1, &b3, &b2, &b2, &b1, &b2, &b3, &b1];
        let mut iter = hash.iter(b"test");
        for expected in expected_walk {
            assert_eq!(iter.next(), Some(expected));
        }

        // A fresh iterator per key "test1".."test7": only the first (weighted) pick matters.
        let expected_firsts = [&b2, &b2, &b3, &b1, &b2, &b2, &b2];
        for (n, expected) in (1..).zip(expected_firsts) {
            let key = format!("test{n}");
            let mut iter = hash.iter(key.as_bytes());
            assert_eq!(iter.next(), Some(expected), "first pick for key {key}");
        }
    }

    #[test]
    fn test_round_robin() {
        let b1 = Backend::new("1.1.1.1:80").unwrap();
        let mut b2 = Backend::new("1.0.0.1:80").unwrap();
        b2.weight = 8; // eight times the weight of the others
        let b3 = Backend::new("1.0.0.255:80").unwrap();
        // sorted order is [b2, b3, b1], so the weighted table is
        // [0, 0, 0, 0, 0, 0, 0, 0, 1, 2]
        let backends = BTreeSet::from_iter([b1.clone(), b2.clone(), b3.clone()]);
        let hash: Arc<Weighted<RoundRobin>> = Arc::new(Weighted::build(&backends));

        // The first pick lands on slot 0 of the weighted table (b2); the fallbacks then step
        // through the unique list one backend at a time.
        let mut iter = hash.iter(b"test");
        for expected in [&b2, &b3, &b1, &b2, &b3] {
            assert_eq!(iter.next(), Some(expected));
        }

        // Round robin ignores the key. The shared counter has advanced 5 steps, so fresh
        // iterators start at slot 5 of the weighted table and wrap around after slot 9.
        for expected in [&b2, &b2, &b2, &b3, &b1, &b2, &b2] {
            let mut iter = hash.iter(b"test1");
            assert_eq!(iter.next(), Some(expected));
        }
    }

    #[test]
    fn test_random() {
        let b1 = Backend::new("1.1.1.1:80").unwrap();
        let mut b2 = Backend::new("1.0.0.1:80").unwrap();
        b2.weight = 8; // eight times the weight of the others
        let b3 = Backend::new("1.0.0.255:80").unwrap();
        let backends = BTreeSet::from_iter([b1.clone(), b2.clone(), b3.clone()]);
        let hash: Arc<Weighted<Random>> = Arc::new(Weighted::build(&backends));

        let mut count: HashMap<Backend, i32> = HashMap::new();
        for b in [&b1, &b2, &b3] {
            count.insert(b.clone(), 0);
        }

        // Only the first (weighted) pick of each iterator is tallied, so b2 should land in
        // roughly 8 out of every 10 draws.
        for _ in 0..10_000 {
            let mut iter = hash.iter(b"test");
            let picked = iter.next().unwrap();
            *count.get_mut(picked).unwrap() += 1;
        }
        assert!((7000..=9000).contains(&count[&b2]));
    }
}
216}