faucet_server/client/load_balancing/ip_hash.rs

1use super::LoadBalancingStrategy;
2use super::WorkerConfig;
3use crate::leak;
4use crate::{client::Client, error::FaucetResult};
5use std::net::IpAddr;
6use std::time::Duration;
7
8struct Targets {
9    targets: &'static [Client],
10}
11
12impl Targets {
13    fn new(configs: &[WorkerConfig]) -> FaucetResult<Self> {
14        let mut targets = Vec::new();
15        for state in configs {
16            let client = Client::builder(*state).build()?;
17            targets.push(client);
18        }
19        let targets = leak!(targets);
20        Ok(Targets { targets })
21    }
22}
23
24pub struct IpHash {
25    targets: Targets,
26    targets_len: usize,
27}
28
29impl IpHash {
30    pub(crate) fn new(targets: &[WorkerConfig]) -> FaucetResult<Self> {
31        Ok(Self {
32            targets_len: targets.as_ref().len(),
33            targets: Targets::new(targets)?,
34        })
35    }
36}
37
/// Maps an IP address to a well-mixed 64-bit hash.
///
/// The address is first reduced to a `u64` key:
/// - IPv4 addresses use their 32-bit value directly.
/// - IPv6 addresses XOR their upper and lower 64-bit halves together, so every
///   bit of the address influences the key.
///
/// The key is then passed through the MurmurHash3 64-bit finalizer (`fmix64`).
/// Its xor-shift and odd-constant multiply steps are each invertible, so
/// distinct keys never produce the same hash.
fn calculate_hash(ip: IpAddr) -> u64 {
    let mut hash_value = match ip {
        IpAddr::V4(ip) => u64::from(ip.to_bits()),
        IpAddr::V6(ip) => {
            let bits = ip.to_bits();
            // A plain `bits as u64` keeps only the low 64 bits. Addresses that
            // differ only in their upper half (e.g. `2001:db8::1` vs
            // `2001:db9::1`) would then always land on the same worker.
            (bits as u64) ^ ((bits >> 64) as u64)
        }
    };
    // MurmurHash3 fmix64 finalizer.
    hash_value ^= hash_value >> 33;
    hash_value = hash_value.wrapping_mul(0xff51afd7ed558ccd);
    hash_value ^= hash_value >> 33;
    hash_value = hash_value.wrapping_mul(0xc4ceb9fe1a85ec53);
    hash_value ^= hash_value >> 33;

    hash_value
}
51
52fn hash_to_index(value: IpAddr, length: usize) -> usize {
53    let hash = calculate_hash(value);
54    (hash % length as u64) as usize
55}
56
/// Delay before the first re-check of an offline worker; every further retry
/// doubles it.
const BASE_BACKOFF: Duration = Duration::from_millis(50);

/// Returns how long to wait before retry number `retries` (0-based), i.e.
/// `BASE_BACKOFF * 2^retries`.
///
/// Both the power of two and the multiplication saturate instead of
/// overflowing. With a plain `2u32.pow(retries)`, `retries >= 32` would panic
/// in debug builds. In release builds it would wrap to 0, and the caller's
/// retry loop would degrade into a busy loop of zero-length sleeps.
///
/// NOTE(review): the delay is uncapped. After 10 retries a waiting request
/// already sleeps ~51 s between checks; consider clamping to a maximum.
fn calculate_exponential_backoff(retries: u32) -> Duration {
    BASE_BACKOFF.saturating_mul(2u32.saturating_pow(retries))
}
63
64impl LoadBalancingStrategy for IpHash {
65    type Input = IpAddr;
66    async fn entry(&self, ip: IpAddr) -> Client {
67        let mut retries = 0;
68        let index = hash_to_index(ip, self.targets_len);
69        let client = self.targets.targets[index].clone();
70        loop {
71            if client.is_online() {
72                break client;
73            }
74
75            let backoff = calculate_exponential_backoff(retries);
76
77            log::debug!(
78                target: "faucet",
79                "IP {} tried to connect to offline {}, retrying in {:?}",
80                ip,
81                client.config.target,
82                backoff
83            );
84
85            tokio::time::sleep(backoff).await;
86            retries += 1;
87        }
88    }
89}
90
#[cfg(test)]
mod tests {

    use std::sync::{atomic::AtomicBool, Arc};

    use super::*;

    /// Number of random addresses hashed by each distribution test.
    const N_IP: usize = 100_000;

    /// Maximum allowed deviation of any bucket's share of addresses from the
    /// ideal `1 / buckets` share.
    const TOLERANCE: f64 = 0.01;

    /// Generates `n` uniformly random IPv4 addresses.
    fn random_ipv4s(n: usize) -> Vec<IpAddr> {
        (0..n)
            .map(|_| IpAddr::V4(std::net::Ipv4Addr::from_bits(rand::random::<u32>())))
            .collect()
    }

    /// Generates `n` uniformly random IPv6 addresses.
    fn random_ipv6s(n: usize) -> Vec<IpAddr> {
        (0..n)
            .map(|_| IpAddr::V6(std::net::Ipv6Addr::from_bits(rand::random::<u128>())))
            .collect()
    }

    /// Hashes every address in `ips` into `buckets` slots and asserts that
    /// each slot received `1 / buckets` of the addresses, within `TOLERANCE`.
    fn assert_uniform_distribution(ips: &[IpAddr], buckets: usize) {
        let mut counts = vec![0usize; buckets];
        for ip in ips {
            counts[hash_to_index(*ip, buckets)] += 1;
        }

        let expected = 1.0 / buckets as f64;
        let allowed = (expected - TOLERANCE)..=(expected + TOLERANCE);
        for (bucket, &count) in counts.iter().enumerate() {
            let share = count as f64 / ips.len() as f64;
            assert!(
                allowed.contains(&share),
                "bucket {} received {:.4} of addresses, expected {:.4} +/- {}",
                bucket,
                share,
                expected,
                TOLERANCE
            );
        }
    }

    #[test]
    fn ip_v4_test_distribution_of_hash_function_len_4() {
        assert_uniform_distribution(&random_ipv4s(N_IP), 4);
    }

    #[test]
    fn ip_v4_test_distribution_of_hash_function_len_3() {
        assert_uniform_distribution(&random_ipv4s(N_IP), 3);
    }

    #[test]
    fn ip_v4_test_distribution_of_hash_function_len_2() {
        assert_uniform_distribution(&random_ipv4s(N_IP), 2);
    }

    #[test]
    fn ip_v6_test_distribution_of_hash_function_len_4() {
        assert_uniform_distribution(&random_ipv6s(N_IP), 4);
    }

    #[test]
    fn ip_v6_test_distribution_of_hash_function_len_3() {
        assert_uniform_distribution(&random_ipv6s(N_IP), 3);
    }

    #[test]
    fn ip_v6_test_distribution_of_hash_function_len_2() {
        assert_uniform_distribution(&random_ipv6s(N_IP), 2);
    }

    #[test]
    fn test_new_targets() {
        let worker_state = WorkerConfig::dummy("test", "127.0.0.1:9999", true);
        let Targets { targets } = Targets::new(&[worker_state]).unwrap();

        assert_eq!(targets.len(), 1);
    }

    #[test]
    fn test_new_ip_hash() {
        let worker_state = WorkerConfig::dummy("test", "127.0.0.1:9999", true);
        let IpHash {
            targets,
            targets_len,
        } = IpHash::new(&[worker_state]).unwrap();

        assert_eq!(targets.targets.len(), 1);
        assert_eq!(targets_len, 1);
    }

    #[test]
    fn test_calculate_exponential_backoff() {
        assert_eq!(calculate_exponential_backoff(0), BASE_BACKOFF);
        assert_eq!(calculate_exponential_backoff(1), BASE_BACKOFF * 2);
        assert_eq!(calculate_exponential_backoff(2), BASE_BACKOFF * 4);
        assert_eq!(calculate_exponential_backoff(3), BASE_BACKOFF * 8);
    }

    #[tokio::test]
    async fn test_load_balancing_strategy() {
        use crate::client::ExtractSocketAddr;
        let workers = [
            WorkerConfig::dummy("test", "127.0.0.1:9999", true),
            WorkerConfig::dummy("test", "127.0.0.1:8888", true),
        ];
        let ip_hash = IpHash::new(&workers).unwrap();

        // The same address must always be routed to the same worker.
        let client1 = ip_hash.entry("192.168.0.1".parse().unwrap()).await;
        let client2 = ip_hash.entry("192.168.0.1".parse().unwrap()).await;
        assert_eq!(client1.socket_addr(), client2.socket_addr());

        // This IP address hashes to the other worker.
        let client3 = ip_hash.entry("192.168.0.10".parse().unwrap()).await;
        let client4 = ip_hash.entry("192.168.0.10".parse().unwrap()).await;

        assert_eq!(client3.socket_addr(), client4.socket_addr());
        assert_eq!(client1.socket_addr(), client2.socket_addr());

        assert_ne!(client1.socket_addr(), client3.socket_addr());
    }

    #[tokio::test]
    async fn test_load_balancing_strategy_offline() {
        use crate::client::ExtractSocketAddr;

        // NOTE(review): `online` is never handed to the worker config or its
        // client, so flipping it below cannot change `is_online()`. Nothing
        // here can bring an offline worker online, so this test presumably
        // passes only because the dummy worker already reports online. Wire
        // the flag into the worker (or rename the test) to actually cover the
        // offline/backoff path.
        let online = Arc::new(AtomicBool::new(false));
        let worker = WorkerConfig::dummy("test", "127.0.0.1:9999", true);

        let ip_hash = IpHash::new(&[worker]).unwrap();

        tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(100)).await;
            online.store(true, std::sync::atomic::Ordering::SeqCst);
        });

        let entry = ip_hash.entry("192.168.0.1".parse().unwrap()).await;

        assert_eq!(entry.socket_addr(), "127.0.0.1:9999".parse().unwrap());
    }
}