faucet_server/client/load_balancing/
ip_hash.rs1use super::LoadBalancingStrategy;
2use super::WorkerConfig;
3use crate::client::Client;
4use crate::leak;
5use std::net::IpAddr;
6use std::time::Duration;
7
8struct Targets {
9 targets: &'static [Client],
10}
11
12impl Targets {
13 fn new(configs: &[&'static WorkerConfig]) -> Self {
14 let mut targets = Vec::new();
15 for state in configs {
16 let client = Client::new(state);
17 targets.push(client);
18 }
19 let targets = leak!(targets);
20 Targets { targets }
21 }
22}
23
24pub struct IpHash {
25 targets: Targets,
26 targets_len: usize,
27}
28
29impl IpHash {
30 pub(crate) async fn new(configs: &[&'static WorkerConfig]) -> Self {
31 for config in configs {
33 config.spawn_worker_task().await;
34 }
35 Self {
36 targets_len: configs.as_ref().len(),
37 targets: Targets::new(configs),
38 }
39 }
40}
41
/// MurmurHash3 64-bit finalizer (`fmix64`).
///
/// A bijective mix: the xor-shifts are invertible and both multipliers are odd.
/// It maps `0` to `0`.
fn fmix64(mut value: u64) -> u64 {
    value ^= value >> 33;
    value = value.wrapping_mul(0xff51afd7ed558ccd);
    value ^= value >> 33;
    value = value.wrapping_mul(0xc4ceb9fe1a85ec53);
    value ^= value >> 33;
    value
}

/// Hashes a client address to a well-mixed 64-bit value for bucket selection.
///
/// IPv4-mapped IPv6 addresses (`::ffff:a.b.c.d`, as reported by dual-stack
/// listeners) are first converted to plain IPv4. This way a client hashes the
/// same regardless of which address family accepted it.
///
/// For IPv6, all 128 bits contribute: the upper half is mixed and folded into
/// the lower half. Truncating to the low 64 bits would make addresses that
/// differ only in their network prefix (e.g. `2001:db8:a::1` and
/// `2001:db8:b::1`) always collide. Because `fmix64(0) == 0`, IPv6 addresses
/// whose upper 64 bits are zero hash exactly as they would under truncation.
fn calculate_hash(ip: IpAddr) -> u64 {
    let key = match ip.to_canonical() {
        IpAddr::V4(v4) => u64::from(v4.to_bits()),
        IpAddr::V6(v6) => {
            let bits = v6.to_bits();
            let high = (bits >> 64) as u64;
            let low = bits as u64;
            low ^ fmix64(high)
        }
    };
    fmix64(key)
}
55
56fn hash_to_index(value: IpAddr, length: usize) -> usize {
57 let hash = calculate_hash(value);
58 (hash % length as u64) as usize
59}
60
/// Delay before the first retry; each later retry doubles it, up to `MAX_BACKOFF`.
const BASE_BACKOFF: Duration = Duration::from_millis(50);

/// Upper bound on a single retry delay.
///
/// Without a cap the delay doubles forever (about 51 s after 10 retries and
/// about 14.5 h after 20). A waiting request only re-checks its worker after
/// each sleep, so it could keep sleeping long after the worker is back online.
const MAX_BACKOFF: Duration = Duration::from_secs(5);

/// Returns the delay to wait before retry number `retries` (0-based).
///
/// The delay is `BASE_BACKOFF * 2^retries`, clamped to `MAX_BACKOFF`.
/// Saturating arithmetic keeps this panic-free for every `retries` value,
/// whereas `2u32.pow(retries)` overflows once `retries >= 32`.
fn calculate_exponential_backoff(retries: u32) -> Duration {
    let factor = 2u32.saturating_pow(retries);
    BASE_BACKOFF.saturating_mul(factor).min(MAX_BACKOFF)
}
67
68impl LoadBalancingStrategy for IpHash {
69 type Input = IpAddr;
70 async fn entry(&self, ip: IpAddr) -> Client {
71 let mut retries = 0;
72 let index = hash_to_index(ip, self.targets_len);
73 let client = self.targets.targets[index].clone();
74 loop {
75 if client.is_online() {
76 break client;
77 }
78
79 let backoff = calculate_exponential_backoff(retries);
80
81 log::debug!(
82 target: "faucet",
83 "IP {} tried to connect to offline {}, retrying in {:?}",
84 ip,
85 client.config.target,
86 backoff
87 );
88
89 tokio::time::sleep(backoff).await;
90 retries += 1;
91 }
92 }
93}
94
#[cfg(test)]
mod tests {
    use std::sync::{atomic::AtomicBool, Arc};

    use super::*;

    /// Number of random addresses hashed by each distribution test.
    const N_IP: usize = 100_000;

    /// Maximum allowed deviation of a bucket's share from the ideal `1 / buckets`.
    const TOLERANCE: f64 = 0.01;

    fn random_v4() -> IpAddr {
        IpAddr::V4(std::net::Ipv4Addr::from_bits(rand::random::<u32>()))
    }

    fn random_v6() -> IpAddr {
        IpAddr::V6(std::net::Ipv6Addr::from_bits(rand::random::<u128>()))
    }

    /// Hashes `N_IP` addresses from `make_ip` into `buckets` slots. Asserts that
    /// every slot receives `1 / buckets` of them, within `TOLERANCE`.
    fn assert_uniform_distribution(make_ip: fn() -> IpAddr, buckets: usize) {
        let mut counts = vec![0usize; buckets];
        for _ in 0..N_IP {
            counts[hash_to_index(make_ip(), buckets)] += 1;
        }

        let expected = 1.0 / buckets as f64;
        for (bucket, &count) in counts.iter().enumerate() {
            let share = count as f64 / N_IP as f64;
            assert!(
                (share - expected).abs() <= TOLERANCE,
                "bucket {} of {} received {:.4} of addresses, expected {:.4} +/- {}",
                bucket,
                buckets,
                share,
                expected,
                TOLERANCE
            );
        }
    }

    #[test]
    fn ip_v4_test_distribution_of_hash_function_len_4() {
        assert_uniform_distribution(random_v4, 4);
    }

    #[test]
    fn ip_v4_test_distribution_of_hash_function_len_3() {
        assert_uniform_distribution(random_v4, 3);
    }

    #[test]
    fn ip_v4_test_distribution_of_hash_function_len_2() {
        assert_uniform_distribution(random_v4, 2);
    }

    #[test]
    fn ip_v6_test_distribution_of_hash_function_len_4() {
        assert_uniform_distribution(random_v6, 4);
    }

    #[test]
    fn ip_v6_test_distribution_of_hash_function_len_3() {
        assert_uniform_distribution(random_v6, 3);
    }

    #[test]
    fn ip_v6_test_distribution_of_hash_function_len_2() {
        assert_uniform_distribution(random_v6, 2);
    }

    #[test]
    fn test_new_targets() {
        let worker_state: &'static WorkerConfig = Box::leak(Box::new(WorkerConfig::dummy(
            "test",
            "127.0.0.1:9999",
            true,
        )));
        let Targets { targets } = Targets::new(&[worker_state]);

        assert_eq!(targets.len(), 1);
    }

    #[tokio::test]
    async fn test_new_ip_hash() {
        let worker_state: &'static WorkerConfig = Box::leak(Box::new(WorkerConfig::dummy(
            "test",
            "127.0.0.1:9999",
            true,
        )));
        let IpHash {
            targets,
            targets_len,
        } = IpHash::new(&[worker_state]).await;

        assert_eq!(targets.targets.len(), 1);
        assert_eq!(targets_len, 1);

        worker_state.wait_until_done().await;
    }

    #[test]
    fn test_calculate_exponential_backoff() {
        assert_eq!(calculate_exponential_backoff(0), BASE_BACKOFF);
        assert_eq!(calculate_exponential_backoff(1), BASE_BACKOFF * 2);
        assert_eq!(calculate_exponential_backoff(2), BASE_BACKOFF * 4);
        assert_eq!(calculate_exponential_backoff(3), BASE_BACKOFF * 8);
    }

    #[tokio::test]
    async fn test_load_balancing_strategy() {
        use crate::client::ExtractSocketAddr;
        let worker1: &'static WorkerConfig = Box::leak(Box::new(WorkerConfig::dummy(
            "test",
            "127.0.0.1:9999",
            true,
        )));
        let worker2: &'static WorkerConfig = Box::leak(Box::new(WorkerConfig::dummy(
            "test",
            "127.0.0.1:8888",
            true,
        )));
        let workers_static_refs = [worker1, worker2];
        let ip_hash = IpHash::new(&workers_static_refs).await;

        // The same IP must always be routed to the same worker.
        let client1 = ip_hash.entry("192.168.0.1".parse().unwrap()).await;
        let client2 = ip_hash.entry("192.168.0.1".parse().unwrap()).await;
        assert_eq!(client1.socket_addr(), client2.socket_addr());

        let client3 = ip_hash.entry("192.168.0.10".parse().unwrap()).await;
        let client4 = ip_hash.entry("192.168.0.10".parse().unwrap()).await;
        assert_eq!(client3.socket_addr(), client4.socket_addr());

        // These two specific addresses hash to different workers.
        assert_ne!(client1.socket_addr(), client3.socket_addr());

        for worker_config in workers_static_refs.iter() {
            worker_config.wait_until_done().await;
        }
    }

    #[tokio::test]
    async fn test_load_balancing_strategy_offline() {
        use crate::client::ExtractSocketAddr;

        // NOTE(review): `online` is only flipped inside the spawned task below and
        // is never handed to `worker` or `ip_hash`, so it does not control the
        // worker's state. Whether `entry`'s retry path runs depends solely on what
        // `WorkerConfig::dummy(.., true)` reports (presumably already online). Confirm,
        // and wire the flag into the worker if this test should exercise retries.
        let online = Arc::new(AtomicBool::new(false));
        let worker: &'static WorkerConfig = Box::leak(Box::new(WorkerConfig::dummy(
            "test",
            "127.0.0.1:9999",
            true,
        )));

        let ip_hash = IpHash::new(&[worker]).await;

        tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(100)).await;
            online.store(true, std::sync::atomic::Ordering::SeqCst);
        });

        let entry = ip_hash.entry("192.168.0.1".parse().unwrap()).await;

        assert_eq!(entry.socket_addr(), "127.0.0.1:9999".parse().unwrap());

        worker.wait_until_done().await;
    }
}
353}