pingora_load_balancing/selection/
weighted.rs1use super::{Backend, BackendIter, BackendSelection, SelectionAlgorithm};
18use fnv::FnvHasher;
19use std::collections::BTreeSet;
20use std::sync::Arc;
21
22pub struct Weighted<H = FnvHasher> {
26 backends: Box<[Backend]>,
27 weighted: Box<[u16]>,
29 algorithm: H,
30}
31
32impl<H: SelectionAlgorithm> BackendSelection for Weighted<H> {
33 type Iter = WeightedIterator<H>;
34
35 type Config = ();
36
37 fn build(backends: &BTreeSet<Backend>) -> Self {
38 assert!(
39 backends.len() <= u16::MAX as usize,
40 "support up to 2^16 backends"
41 );
42 let backends = Vec::from_iter(backends.iter().cloned()).into_boxed_slice();
43 let mut weighted = Vec::with_capacity(backends.len());
44 for (index, b) in backends.iter().enumerate() {
45 for _ in 0..b.weight {
46 weighted.push(index as u16);
47 }
48 }
49 Weighted {
50 backends,
51 weighted: weighted.into_boxed_slice(),
52 algorithm: H::new(),
53 }
54 }
55
56 fn iter(self: &Arc<Self>, key: &[u8]) -> Self::Iter {
57 WeightedIterator::new(key, self.clone())
58 }
59}
60
61pub struct WeightedIterator<H> {
65 index: u64,
67 backend: Arc<Weighted<H>>,
68 first: bool,
69}
70
71impl<H: SelectionAlgorithm> WeightedIterator<H> {
72 fn new(input: &[u8], backend: Arc<Weighted<H>>) -> Self {
74 Self {
75 index: backend.algorithm.next(input),
76 backend,
77 first: true,
78 }
79 }
80}
81
82impl<H: SelectionAlgorithm> BackendIter for WeightedIterator<H> {
83 fn next(&mut self) -> Option<&Backend> {
84 if self.backend.backends.is_empty() {
85 return None;
87 }
88
89 if self.first {
90 self.first = false;
92 let len = self.backend.weighted.len();
93 let index = self.backend.weighted[self.index as usize % len];
94 Some(&self.backend.backends[index as usize])
95 } else {
96 self.index = self.backend.algorithm.next(&self.index.to_le_bytes());
99 let len = self.backend.backends.len();
100 Some(&self.backend.backends[self.index as usize % len])
101 }
102 }
103}
104
#[cfg(test)]
mod test {
    use super::super::algorithms::*;
    use super::*;
    use std::collections::HashMap;

    /// Default (FNV) algorithm: pins the exact backend sequence for fixed keys.
    #[test]
    fn test_fnv() {
        let b1 = Backend::new("1.1.1.1:80").unwrap();
        let mut b2 = Backend::new("1.0.0.1:80").unwrap();
        // Only b2 gets an explicit weight; b1 and b3 keep what `Backend::new` assigns.
        b2.weight = 10;
        let b3 = Backend::new("1.0.0.255:80").unwrap();
        let backends = BTreeSet::from_iter([b1.clone(), b2.clone(), b3.clone()]);
        let hash: Arc<Weighted> = Arc::new(Weighted::build(&backends));

        // A single iterator: the first pick is weighted, later picks are fallbacks.
        let expected_sequence = [
            &b2, &b2, &b2, &b1, &b3, &b2, &b2, &b1, &b2, &b3, &b1,
        ];
        let mut iter = hash.iter(b"test");
        for expected in expected_sequence {
            assert_eq!(iter.next(), Some(expected));
        }

        // Fresh iterators: only the first (weighted) pick of each key is checked.
        let first_picks = [
            (b"test1", &b2),
            (b"test2", &b2),
            (b"test3", &b3),
            (b"test4", &b1),
            (b"test5", &b2),
            (b"test6", &b2),
            (b"test7", &b2),
        ];
        for (key, expected) in first_picks {
            let mut iter = hash.iter(key);
            assert_eq!(iter.next(), Some(expected));
        }
    }

    /// Round-robin algorithm: pins the sequence produced across successive iterators.
    #[test]
    fn test_round_robin() {
        let b1 = Backend::new("1.1.1.1:80").unwrap();
        let mut b2 = Backend::new("1.0.0.1:80").unwrap();
        // Only b2 gets an explicit weight; b1 and b3 keep what `Backend::new` assigns.
        b2.weight = 8;
        let b3 = Backend::new("1.0.0.255:80").unwrap();
        let backends = BTreeSet::from_iter([b1.clone(), b2.clone(), b3.clone()]);
        let hash: Arc<Weighted<RoundRobin>> = Arc::new(Weighted::build(&backends));

        // A single iterator: the first pick is weighted, later picks are fallbacks.
        let mut iter = hash.iter(b"test");
        for expected in [&b2, &b3, &b1, &b2, &b3] {
            assert_eq!(iter.next(), Some(expected));
        }

        // Seven fresh iterators with the same key, created one after another; the
        // expected first picks are listed in creation order.
        for expected in [&b2, &b2, &b2, &b3, &b1, &b2, &b2] {
            let mut iter = hash.iter(b"test1");
            assert_eq!(iter.next(), Some(expected));
        }
    }

    /// Random algorithm: checks that b2's share of first picks lands in a wide band.
    #[test]
    fn test_random() {
        let b1 = Backend::new("1.1.1.1:80").unwrap();
        let mut b2 = Backend::new("1.0.0.1:80").unwrap();
        // Only b2 gets an explicit weight; b1 and b3 keep what `Backend::new` assigns.
        b2.weight = 8;
        let b3 = Backend::new("1.0.0.255:80").unwrap();
        let backends = BTreeSet::from_iter([b1.clone(), b2.clone(), b3.clone()]);
        let hash: Arc<Weighted<Random>> = Arc::new(Weighted::build(&backends));

        // Start every backend's tally at zero.
        let mut count: HashMap<Backend, i32> = [&b1, &b2, &b3]
            .into_iter()
            .map(|b| (b.clone(), 0))
            .collect();

        for _ in 0..10000 {
            let mut iter = hash.iter(b"test");
            let picked = iter.next().unwrap();
            *count.get_mut(picked).unwrap() += 1;
        }

        // Accept anywhere from 7000 to 9000 of the 10000 first picks landing on b2.
        let b2_count = *count.get(&b2).unwrap();
        assert!((7000..=9000).contains(&b2_count));
    }
}