pingora_load_balancing/selection/
weighted.rs1use super::{Backend, BackendIter, BackendSelection, SelectionAlgorithm};
18use fnv::FnvHasher;
19use std::collections::BTreeSet;
20use std::sync::Arc;
21
22pub struct Weighted<H = FnvHasher> {
26 backends: Box<[Backend]>,
27 weighted: Box<[u16]>,
29 algorithm: H,
30}
31
32impl<H: SelectionAlgorithm> BackendSelection for Weighted<H> {
33 type Iter = WeightedIterator<H>;
34
35 fn build(backends: &BTreeSet<Backend>) -> Self {
36 assert!(
37 backends.len() <= u16::MAX as usize,
38 "support up to 2^16 backends"
39 );
40 let backends = Vec::from_iter(backends.iter().cloned()).into_boxed_slice();
41 let mut weighted = Vec::with_capacity(backends.len());
42 for (index, b) in backends.iter().enumerate() {
43 for _ in 0..b.weight {
44 weighted.push(index as u16);
45 }
46 }
47 Weighted {
48 backends,
49 weighted: weighted.into_boxed_slice(),
50 algorithm: H::new(),
51 }
52 }
53
54 fn iter(self: &Arc<Self>, key: &[u8]) -> Self::Iter {
55 WeightedIterator::new(key, self.clone())
56 }
57}
58
59pub struct WeightedIterator<H> {
63 index: u64,
65 backend: Arc<Weighted<H>>,
66 first: bool,
67}
68
69impl<H: SelectionAlgorithm> WeightedIterator<H> {
70 fn new(input: &[u8], backend: Arc<Weighted<H>>) -> Self {
72 Self {
73 index: backend.algorithm.next(input),
74 backend,
75 first: true,
76 }
77 }
78}
79
80impl<H: SelectionAlgorithm> BackendIter for WeightedIterator<H> {
81 fn next(&mut self) -> Option<&Backend> {
82 if self.backend.backends.is_empty() {
83 return None;
85 }
86
87 if self.first {
88 self.first = false;
90 let len = self.backend.weighted.len();
91 let index = self.backend.weighted[self.index as usize % len];
92 Some(&self.backend.backends[index as usize])
93 } else {
94 self.index = self.backend.algorithm.next(&self.index.to_le_bytes());
97 let len = self.backend.backends.len();
98 Some(&self.backend.backends[self.index as usize % len])
99 }
100 }
101}
102
#[cfg(test)]
mod test {
    use super::super::algorithms::*;
    use super::*;
    use std::collections::HashMap;

    /// Draws one pick from `iter` per entry of `expected` and checks each, in order.
    fn assert_picks<I: BackendIter>(iter: &mut I, expected: &[&Backend]) {
        for want in expected {
            assert_eq!(iter.next(), Some(*want));
        }
    }

    #[test]
    fn test_fnv() {
        let b1 = Backend::new("1.1.1.1:80").unwrap();
        let mut b2 = Backend::new("1.0.0.1:80").unwrap();
        b2.weight = 10;
        let b3 = Backend::new("1.0.0.255:80").unwrap();
        let backends = BTreeSet::from_iter([b1.clone(), b2.clone(), b3.clone()]);
        let hash: Arc<Weighted> = Arc::new(Weighted::build(&backends));

        // One key: the first pick comes from the weighted table, the rest are fallbacks.
        assert_picks(
            &mut hash.iter(b"test"),
            &[&b2, &b2, &b2, &b1, &b3, &b2, &b2, &b1, &b2, &b3, &b1],
        );

        // A fresh iterator per key; only its first pick is checked.
        let first_picks: [(&[u8], &Backend); 7] = [
            (b"test1", &b2),
            (b"test2", &b2),
            (b"test3", &b3),
            (b"test4", &b1),
            (b"test5", &b2),
            (b"test6", &b2),
            (b"test7", &b2),
        ];
        for (key, want) in first_picks {
            let mut iter = hash.iter(key);
            assert_eq!(iter.next(), Some(want));
        }
    }

    #[test]
    fn test_round_robin() {
        let b1 = Backend::new("1.1.1.1:80").unwrap();
        let mut b2 = Backend::new("1.0.0.1:80").unwrap();
        b2.weight = 8;
        let b3 = Backend::new("1.0.0.255:80").unwrap();
        let backends = BTreeSet::from_iter([b1.clone(), b2.clone(), b3.clone()]);
        let hash: Arc<Weighted<RoundRobin>> = Arc::new(Weighted::build(&backends));

        // First pick from the weighted table, then fallback picks.
        assert_picks(&mut hash.iter(b"test"), &[&b2, &b3, &b1, &b2, &b3]);

        // Same key every time: one new iterator per expected first pick, in order.
        for want in [&b2, &b2, &b2, &b3, &b1, &b2, &b2] {
            let mut iter = hash.iter(b"test1");
            assert_eq!(iter.next(), Some(want));
        }
    }

    #[test]
    fn test_random() {
        let b1 = Backend::new("1.1.1.1:80").unwrap();
        let mut b2 = Backend::new("1.0.0.1:80").unwrap();
        b2.weight = 8;
        let b3 = Backend::new("1.0.0.255:80").unwrap();
        let backends = BTreeSet::from_iter([b1.clone(), b2.clone(), b3.clone()]);
        let hash: Arc<Weighted<Random>> = Arc::new(Weighted::build(&backends));

        // Tally the first pick of many independent iterators.
        let mut count: HashMap<Backend, i32> = [&b1, &b2, &b3]
            .into_iter()
            .map(|b| (b.clone(), 0))
            .collect();
        for _ in 0..10000 {
            let mut iter = hash.iter(b"test");
            let picked = iter.next().unwrap();
            *count.get_mut(picked).unwrap() += 1;
        }

        let b2_count = count[&b2];
        assert!((7000..=9000).contains(&b2_count));
    }
}