faucet_server/client/load_balancing/
ip_hash.rs1use super::LoadBalancingStrategy;
2use super::WorkerConfig;
3use crate::leak;
4use crate::{client::Client, error::FaucetResult};
5use std::net::IpAddr;
6use std::time::Duration;
7
8struct Targets {
9 targets: &'static [Client],
10}
11
12impl Targets {
13 fn new(configs: &[WorkerConfig]) -> FaucetResult<Self> {
14 let mut targets = Vec::new();
15 for state in configs {
16 let client = Client::builder(*state).build()?;
17 targets.push(client);
18 }
19 let targets = leak!(targets);
20 Ok(Targets { targets })
21 }
22}
23
24pub struct IpHash {
25 targets: Targets,
26 targets_len: usize,
27}
28
29impl IpHash {
30 pub(crate) fn new(targets: &[WorkerConfig]) -> FaucetResult<Self> {
31 Ok(Self {
32 targets_len: targets.as_ref().len(),
33 targets: Targets::new(targets)?,
34 })
35 }
36}
37
/// Hashes an IP address to a well-mixed 64-bit value.
///
/// The address is first reduced to 64 bits, then passed through the
/// MurmurHash3 `fmix64` finalizer, which spreads every input bit across the
/// whole output.
///
/// - IPv4 addresses use their 32 bits directly.
/// - IPv6 addresses XOR-fold the upper 64 bits onto the lower 64 bits.
///
/// The upper half typically holds the network prefix and the lower half the
/// interface identifier. Truncating to the low half alone would make every
/// network that shares an interface identifier hash identically
/// (e.g. `::1` and `2001:db8::1`).
fn calculate_hash(ip: IpAddr) -> u64 {
    let mut hash_value = match ip {
        IpAddr::V4(ip) => u64::from(ip.to_bits()),
        IpAddr::V6(ip) => {
            let bits = ip.to_bits();
            // Truncating casts are intentional: take each 64-bit half.
            (bits as u64) ^ ((bits >> 64) as u64)
        }
    };
    // MurmurHash3 fmix64 finalizer.
    hash_value ^= hash_value >> 33;
    hash_value = hash_value.wrapping_mul(0xff51afd7ed558ccd);
    hash_value ^= hash_value >> 33;
    hash_value = hash_value.wrapping_mul(0xc4ceb9fe1a85ec53);
    hash_value ^= hash_value >> 33;

    hash_value
}
51
52fn hash_to_index(value: IpAddr, length: usize) -> usize {
53 let hash = calculate_hash(value);
54 (hash % length as u64) as usize
55}
56
/// Delay before the first retry when a worker is offline.
const BASE_BACKOFF: Duration = Duration::from_millis(50);

/// Upper bound on a single retry delay. Without a cap, a request that has
/// already retried many times would sleep for minutes (or far longer) even
/// after its worker comes back online.
const MAX_BACKOFF: Duration = Duration::from_secs(5);

/// Returns the delay before retry number `retries` (0-based), computed as
/// `BASE_BACKOFF * 2^retries` and capped at [`MAX_BACKOFF`].
///
/// The arithmetic saturates, so arbitrarily large `retries` neither overflows
/// nor exceeds the cap. Plain `2u32.pow(32)` would panic in debug builds and
/// wrap to a zero delay (a busy loop) in release builds.
fn calculate_exponential_backoff(retries: u32) -> Duration {
    BASE_BACKOFF
        .saturating_mul(2u32.saturating_pow(retries))
        .min(MAX_BACKOFF)
}
63
64impl LoadBalancingStrategy for IpHash {
65 type Input = IpAddr;
66 async fn entry(&self, ip: IpAddr) -> Client {
67 let mut retries = 0;
68 let index = hash_to_index(ip, self.targets_len);
69 let client = self.targets.targets[index].clone();
70 loop {
71 if client.is_online() {
72 break client;
73 }
74
75 let backoff = calculate_exponential_backoff(retries);
76
77 log::debug!(
78 target: "faucet",
79 "IP {} tried to connect to offline {}, retrying in {:?}",
80 ip,
81 client.config.target,
82 backoff
83 );
84
85 tokio::time::sleep(backoff).await;
86 retries += 1;
87 }
88 }
89}
90
#[cfg(test)]
mod tests {

    use std::sync::{atomic::AtomicBool, Arc};

    use super::*;

    /// Number of random addresses sampled by each distribution test.
    const N_IP: usize = 100_000;

    /// Allowed absolute deviation of each bucket's share from `1 / buckets`.
    const TOLERANCE: f64 = 0.01;

    /// Generates `n` uniformly random IPv4 addresses.
    fn random_ipv4s(n: usize) -> Vec<IpAddr> {
        (0..n)
            .map(|_| IpAddr::V4(std::net::Ipv4Addr::from_bits(rand::random::<u32>())))
            .collect()
    }

    /// Generates `n` uniformly random IPv6 addresses.
    fn random_ipv6s(n: usize) -> Vec<IpAddr> {
        (0..n)
            .map(|_| IpAddr::V6(std::net::Ipv6Addr::from_bits(rand::random::<u128>())))
            .collect()
    }

    /// Hashes every address in `ips` into `buckets` buckets and asserts that
    /// each bucket's share is within `TOLERANCE` of the ideal `1 / buckets`.
    fn assert_uniform_distribution(ips: &[IpAddr], buckets: usize) {
        let mut counts = vec![0usize; buckets];
        for &ip in ips {
            counts[hash_to_index(ip, buckets)] += 1;
        }

        let expected = 1.0 / buckets as f64;
        for (bucket, &count) in counts.iter().enumerate() {
            let share = count as f64 / ips.len() as f64;
            assert!(
                (share - expected).abs() <= TOLERANCE,
                "bucket {bucket}/{buckets} received {share:.4} of the addresses, \
                 expected {expected:.4} +/- {TOLERANCE}"
            );
        }
    }

    #[test]
    fn ip_v4_test_distribution_of_hash_function_len_4() {
        assert_uniform_distribution(&random_ipv4s(N_IP), 4);
    }

    #[test]
    fn ip_v4_test_distribution_of_hash_function_len_3() {
        assert_uniform_distribution(&random_ipv4s(N_IP), 3);
    }

    #[test]
    fn ip_v4_test_distribution_of_hash_function_len_2() {
        assert_uniform_distribution(&random_ipv4s(N_IP), 2);
    }

    #[test]
    fn ip_v6_test_distribution_of_hash_function_len_4() {
        assert_uniform_distribution(&random_ipv6s(N_IP), 4);
    }

    #[test]
    fn ip_v6_test_distribution_of_hash_function_len_3() {
        assert_uniform_distribution(&random_ipv6s(N_IP), 3);
    }

    #[test]
    fn ip_v6_test_distribution_of_hash_function_len_2() {
        assert_uniform_distribution(&random_ipv6s(N_IP), 2);
    }

    #[test]
    fn test_new_targets() {
        let worker_state = WorkerConfig::dummy("test", "127.0.0.1:9999", true);
        let Targets { targets } = Targets::new(&[worker_state]).unwrap();

        assert_eq!(targets.len(), 1);
    }

    #[test]
    fn test_new_ip_hash() {
        let worker_state = WorkerConfig::dummy("test", "127.0.0.1:9999", true);
        let IpHash {
            targets,
            targets_len,
        } = IpHash::new(&[worker_state]).unwrap();

        assert_eq!(targets.targets.len(), 1);
        assert_eq!(targets_len, 1);
    }

    #[test]
    fn test_calculate_exponential_backoff() {
        assert_eq!(calculate_exponential_backoff(0), BASE_BACKOFF);
        assert_eq!(calculate_exponential_backoff(1), BASE_BACKOFF * 2);
        assert_eq!(calculate_exponential_backoff(2), BASE_BACKOFF * 4);
        assert_eq!(calculate_exponential_backoff(3), BASE_BACKOFF * 8);
    }

    #[tokio::test]
    async fn test_load_balancing_strategy() {
        use crate::client::ExtractSocketAddr;
        let workers = [
            WorkerConfig::dummy("test", "127.0.0.1:9999", true),
            WorkerConfig::dummy("test", "127.0.0.1:8888", true),
        ];
        let ip_hash = IpHash::new(&workers).unwrap();

        // The same IP must be routed to the same worker every time.
        let client1 = ip_hash.entry("192.168.0.1".parse().unwrap()).await;
        let client2 = ip_hash.entry("192.168.0.1".parse().unwrap()).await;
        assert_eq!(client1.socket_addr(), client2.socket_addr());

        let client3 = ip_hash.entry("192.168.0.10".parse().unwrap()).await;
        let client4 = ip_hash.entry("192.168.0.10".parse().unwrap()).await;
        assert_eq!(client3.socket_addr(), client4.socket_addr());

        // These two particular IPs hash to different workers.
        assert_ne!(client1.socket_addr(), client3.socket_addr());
    }

    #[tokio::test]
    async fn test_load_balancing_strategy_offline() {
        use crate::client::ExtractSocketAddr;

        // NOTE(review): `online` is never handed to `worker`. The dummy config
        // below is built independently (its `true` argument presumably marks
        // it online), so flipping the flag has no effect and this test does
        // not actually exercise the offline/backoff path. Wire the flag into
        // the worker's config to make the test meaningful.
        let online = Arc::new(AtomicBool::new(false));
        let worker = WorkerConfig::dummy("test", "127.0.0.1:9999", true);

        let ip_hash = IpHash::new(&[worker]).unwrap();

        tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(100)).await;
            online.store(true, std::sync::atomic::Ordering::SeqCst);
        });

        let entry = ip_hash.entry("192.168.0.1".parse().unwrap()).await;

        assert_eq!(entry.socket_addr(), "127.0.0.1:9999".parse().unwrap());
    }
}