faucet_server/client/load_balancing/
ip_hash.rs

1use super::LoadBalancingStrategy;
2use super::WorkerConfig;
3use crate::client::Client;
4use crate::leak;
5use std::net::IpAddr;
6use std::time::Duration;
7
8struct Targets {
9    targets: &'static [Client],
10}
11
12impl Targets {
13    fn new(configs: &[&'static WorkerConfig]) -> Self {
14        let mut targets = Vec::new();
15        for state in configs {
16            let client = Client::new(state);
17            targets.push(client);
18        }
19        let targets = leak!(targets);
20        Targets { targets }
21    }
22}
23
24pub struct IpHash {
25    targets: Targets,
26    targets_len: usize,
27}
28
29impl IpHash {
30    pub(crate) async fn new(configs: &[&'static WorkerConfig]) -> Self {
31        // Start the process of each config
32        for config in configs {
33            config.spawn_worker_task().await;
34        }
35        Self {
36            targets_len: configs.as_ref().len(),
37            targets: Targets::new(configs),
38        }
39    }
40}
41
/// Hashes an IP address into a well-mixed 64-bit value.
///
/// The address bits go through the 64-bit finalizer of MurmurHash3
/// (`fmix64`), which spreads every input bit across the whole output so that
/// numerically close addresses land in unrelated buckets.
///
/// IPv6 addresses are 128 bits wide. Both 64-bit halves are XOR-folded
/// together before mixing so that the upper half (the network prefix)
/// contributes to the hash. A plain `as u64` cast would silently truncate it,
/// making e.g. `2001:db8::1` and `2001:db9::1` always collide.
fn calculate_hash(ip: IpAddr) -> u64 {
    let mut hash_value = match ip {
        IpAddr::V4(ip) => u64::from(ip.to_bits()),
        IpAddr::V6(ip) => {
            let bits = ip.to_bits();
            (bits as u64) ^ ((bits >> 64) as u64)
        }
    };
    // MurmurHash3 fmix64 finalizer.
    hash_value ^= hash_value >> 33;
    hash_value = hash_value.wrapping_mul(0xff51afd7ed558ccd);
    hash_value ^= hash_value >> 33;
    hash_value = hash_value.wrapping_mul(0xc4ceb9fe1a85ec53);
    hash_value ^= hash_value >> 33;

    hash_value
}

/// Maps `value` to a bucket index in `0..length`.
///
/// The same address always yields the same index for a given `length`.
///
/// # Panics
///
/// Panics if `length` is zero (remainder by zero).
fn hash_to_index(value: IpAddr, length: usize) -> usize {
    let hash = calculate_hash(value);
    // `usize` is never wider than 64 bits, so widening `length` is lossless,
    // and the remainder is `< length`, so narrowing it back cannot truncate.
    (hash % length as u64) as usize
}
60
/// Delay before the first retry (50 ms); every later retry doubles it.
const BASE_BACKOFF: Duration = Duration::from_millis(50);

/// Returns the delay to wait before retry number `retries` (0-based):
/// `BASE_BACKOFF * 2^retries`.
///
/// The power is computed with saturating arithmetic. A plain `2u32.pow(32)`
/// panics in debug builds and wraps to 0 in release builds, which would turn
/// the backoff into a zero delay. For `retries >= 32` the result is pinned at
/// `BASE_BACKOFF * u32::MAX`, and the multiplication cannot overflow
/// `Duration`.
///
/// NOTE(review): the delay is uncapped. After 10 retries it already exceeds
/// 50 s, so a caller may keep sleeping long after the target is back.
/// Consider a ceiling.
fn calculate_exponential_backoff(retries: u32) -> Duration {
    BASE_BACKOFF * 2u32.saturating_pow(retries)
}
67
68impl LoadBalancingStrategy for IpHash {
69    type Input = IpAddr;
70    async fn entry(&self, ip: IpAddr) -> Client {
71        let mut retries = 0;
72        let index = hash_to_index(ip, self.targets_len);
73        let client = self.targets.targets[index].clone();
74        loop {
75            if client.is_online() {
76                break client;
77            }
78
79            let backoff = calculate_exponential_backoff(retries);
80
81            log::debug!(
82                target: "faucet",
83                "IP {} tried to connect to offline {}, retrying in {:?}",
84                ip,
85                client.config.target,
86                backoff
87            );
88
89            tokio::time::sleep(backoff).await;
90            retries += 1;
91        }
92    }
93}
94
#[cfg(test)]
mod tests {

    use std::ops::RangeInclusive;
    use std::sync::{atomic::AtomicBool, Arc};

    use super::*;

    /// Number of random addresses sampled by each distribution test.
    const N_IP: usize = 100_000;

    /// Generates `N_IP` uniformly random IPv4 addresses.
    fn random_ipv4_addrs() -> Vec<IpAddr> {
        (0..N_IP)
            .map(|_| IpAddr::V4(std::net::Ipv4Addr::from_bits(rand::random::<u32>())))
            .collect()
    }

    /// Generates `N_IP` uniformly random IPv6 addresses.
    fn random_ipv6_addrs() -> Vec<IpAddr> {
        (0..N_IP)
            .map(|_| IpAddr::V6(std::net::Ipv6Addr::from_bits(rand::random::<u128>())))
            .collect()
    }

    /// Hashes every address into `buckets` buckets and asserts that each
    /// bucket receives a share of the addresses within `expected_share`.
    fn assert_even_distribution(
        ips: &[IpAddr],
        buckets: usize,
        expected_share: RangeInclusive<f64>,
    ) {
        let mut counts = vec![0usize; buckets];
        for ip in ips {
            counts[hash_to_index(*ip, buckets)] += 1;
        }
        for (bucket, count) in counts.iter().enumerate() {
            let share = *count as f64 / ips.len() as f64;
            assert!(
                expected_share.contains(&share),
                "bucket {bucket} got {share:.4} of the addresses, expected {expected_share:?}"
            );
        }
    }

    #[test]
    fn ip_v4_test_distribution_of_hash_function_len_4() {
        assert_even_distribution(&random_ipv4_addrs(), 4, 0.24..=0.26);
    }

    #[test]
    fn ip_v4_test_distribution_of_hash_function_len_3() {
        assert_even_distribution(&random_ipv4_addrs(), 3, 0.32..=0.34);
    }

    #[test]
    fn ip_v4_test_distribution_of_hash_function_len_2() {
        assert_even_distribution(&random_ipv4_addrs(), 2, 0.49..=0.51);
    }

    #[test]
    fn ip_v6_test_distribution_of_hash_function_len_4() {
        assert_even_distribution(&random_ipv6_addrs(), 4, 0.24..=0.26);
    }

    #[test]
    fn ip_v6_test_distribution_of_hash_function_len_3() {
        assert_even_distribution(&random_ipv6_addrs(), 3, 0.32..=0.34);
    }

    #[test]
    fn ip_v6_test_distribution_of_hash_function_len_2() {
        assert_even_distribution(&random_ipv6_addrs(), 2, 0.49..=0.51);
    }

    #[test]
    fn test_new_targets() {
        let worker_state: &'static WorkerConfig = Box::leak(Box::new(WorkerConfig::dummy(
            "test",
            "127.0.0.1:9999",
            true,
        )));
        let Targets { targets } = Targets::new(&[worker_state]);

        assert_eq!(targets.len(), 1);
    }

    #[tokio::test]
    async fn test_new_ip_hash() {
        let worker_state: &'static WorkerConfig = Box::leak(Box::new(WorkerConfig::dummy(
            "test",
            "127.0.0.1:9999",
            true,
        )));
        let IpHash {
            targets,
            targets_len,
        } = IpHash::new(&[worker_state]).await;

        assert_eq!(targets.targets.len(), 1);
        assert_eq!(targets_len, 1);

        worker_state.wait_until_done().await;
    }

    #[test]
    fn test_calculate_exponential_backoff() {
        assert_eq!(calculate_exponential_backoff(0), BASE_BACKOFF);
        assert_eq!(calculate_exponential_backoff(1), BASE_BACKOFF * 2);
        assert_eq!(calculate_exponential_backoff(2), BASE_BACKOFF * 4);
        assert_eq!(calculate_exponential_backoff(3), BASE_BACKOFF * 8);
    }

    #[tokio::test]
    async fn test_load_balancing_strategy() {
        use crate::client::ExtractSocketAddr;
        let worker1: &'static WorkerConfig = Box::leak(Box::new(WorkerConfig::dummy(
            "test",
            "127.0.0.1:9999",
            true,
        )));
        let worker2: &'static WorkerConfig = Box::leak(Box::new(WorkerConfig::dummy(
            "test",
            "127.0.0.1:8888",
            true,
        )));
        let workers_static_refs = [worker1, worker2];
        let ip_hash = IpHash::new(&workers_static_refs).await;

        // The same IP must always be routed to the same worker.
        let client1 = ip_hash.entry("192.168.0.1".parse().unwrap()).await;
        let client2 = ip_hash.entry("192.168.0.1".parse().unwrap()).await;
        assert_eq!(client1.socket_addr(), client2.socket_addr());

        // With two workers, this IP address hashes to the other index.
        let client3 = ip_hash.entry("192.168.0.10".parse().unwrap()).await;
        let client4 = ip_hash.entry("192.168.0.10".parse().unwrap()).await;
        assert_eq!(client3.socket_addr(), client4.socket_addr());

        assert_ne!(client1.socket_addr(), client3.socket_addr());

        for worker_config in workers_static_refs.iter() {
            worker_config.wait_until_done().await;
        }
    }

    #[tokio::test]
    async fn test_load_balancing_strategy_offline() {
        use crate::client::ExtractSocketAddr;

        // NOTE(review): `online` is never connected to `worker`, and the worker
        // is built with `dummy(.., true)`. This test may therefore not exercise
        // the offline/backoff path at all. Confirm what the third `dummy`
        // argument controls and wire the flag into the worker's online state.
        let online = Arc::new(AtomicBool::new(false));
        let worker: &'static WorkerConfig = Box::leak(Box::new(WorkerConfig::dummy(
            "test",
            "127.0.0.1:9999",
            true,
        )));

        let ip_hash = IpHash::new(&[worker]).await;

        tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(100)).await;
            online.store(true, std::sync::atomic::Ordering::SeqCst);
        });

        let entry = ip_hash.entry("192.168.0.1".parse().unwrap()).await;

        assert_eq!(entry.socket_addr(), "127.0.0.1:9999".parse().unwrap());

        worker.wait_until_done().await;
    }
}
353}