faucet_server/client/load_balancing/cookie_hash.rs

1use uuid::Uuid;
2
3use super::LoadBalancingStrategy;
4use super::WorkerConfig;
5use crate::client::Client;
6use crate::leak;
7use std::time::Duration;
8
9struct Targets {
10    targets: &'static [Client],
11}
12
13impl Targets {
14    fn new(configs: &[&'static WorkerConfig]) -> Self {
15        let mut targets = Vec::new();
16        for state in configs {
17            let client = Client::new(state);
18            targets.push(client);
19        }
20        let targets = leak!(targets);
21        Targets { targets }
22    }
23}
24
25pub struct CookieHash {
26    targets: Targets,
27    targets_len: usize,
28}
29
30impl CookieHash {
31    pub(crate) async fn new(configs: &[&'static WorkerConfig]) -> Self {
32        // Start the process of each config
33        for config in configs {
34            config.spawn_worker_task().await;
35        }
36        Self {
37            targets_len: configs.as_ref().len(),
38            targets: Targets::new(configs),
39        }
40    }
41}
42
43fn calculate_hash(cookie_uuid: Uuid) -> u64 {
44    let mut hash_value = cookie_uuid.as_u128() as u64;
45    hash_value ^= hash_value >> 33;
46    hash_value = hash_value.wrapping_mul(0xff51afd7ed558ccd);
47    hash_value ^= hash_value >> 33;
48    hash_value = hash_value.wrapping_mul(0xc4ceb9fe1a85ec53);
49    hash_value ^= hash_value >> 33;
50
51    hash_value
52}
53
54fn hash_to_index(value: Uuid, length: usize) -> usize {
55    let hash = calculate_hash(value);
56    (hash % length as u64) as usize
57}
58
/// First (smallest) delay of the exponential backoff used while waiting for
/// an offline target: 1ms. Retry `n` waits `BASE_BACKOFF * 2^n`, capped at
/// `MAX_BACKOFF`.
const BASE_BACKOFF: Duration = Duration::from_millis(1);

/// Upper bound on a single backoff sleep. Without it the delay doubles
/// forever, so a request that started waiting during a long outage could keep
/// sleeping for minutes after its target is back online.
const MAX_BACKOFF: Duration = Duration::from_secs(1);

/// Returns how long to sleep before retry number `retries` (0-based):
/// `BASE_BACKOFF * 2^retries`, clamped to `MAX_BACKOFF`.
///
/// Saturating arithmetic keeps large `retries` from overflowing; plain
/// `2u32.pow(retries)` panics in debug builds and wraps to a zero delay
/// (a busy loop) in release builds once `retries >= 32`.
fn calculate_exponential_backoff(retries: u32) -> Duration {
    BASE_BACKOFF
        .saturating_mul(2u32.saturating_pow(retries))
        .min(MAX_BACKOFF)
}
65
66impl LoadBalancingStrategy for CookieHash {
67    type Input = Uuid;
68    async fn entry(&self, id: Uuid) -> Client {
69        let mut retries = 0;
70        let index = hash_to_index(id, self.targets_len);
71        let client = self.targets.targets[index].clone();
72        loop {
73            if client.is_online() {
74                break client;
75            }
76
77            let backoff = calculate_exponential_backoff(retries);
78
79            log::debug!(
80                target: "faucet",
81                "LB Session {} tried to connect to offline {}, retrying in {:?}",
82                id,
83                client.config.target,
84                backoff
85            );
86
87            tokio::time::sleep(backoff).await;
88            retries += 1;
89        }
90    }
91}
92
#[cfg(test)]
mod tests {
    use super::*;
    use crate::client::ExtractSocketAddr;

    use std::ops::RangeInclusive;
    use uuid::Uuid;

    /// Number of fresh v7 UUIDs hashed by each distribution test.
    const N_UUIDS: usize = 100_000;

    /// Hashes `N_UUIDS` fresh v7 UUIDs into `buckets` targets and asserts that
    /// the fraction of UUIDs landing on every target lies within `expected`.
    fn assert_distribution(buckets: usize, expected: RangeInclusive<f64>) {
        let mut counts = vec![0usize; buckets];
        for _ in 0..N_UUIDS {
            counts[hash_to_index(Uuid::now_v7(), buckets)] += 1;
        }

        for (index, count) in counts.into_iter().enumerate() {
            let share = count as f64 / N_UUIDS as f64;
            assert!(
                expected.contains(&share),
                "target {index} of {buckets} received {share:.4} of UUIDs, expected {expected:?}"
            );
        }
    }

    #[test]
    fn uuid_test_distribution_of_hash_function_len_4() {
        assert_distribution(4, 0.24..=0.26);
    }

    #[test]
    fn uuid_test_distribution_of_hash_function_len_3() {
        assert_distribution(3, 0.32..=0.34);
    }

    #[test]
    fn uuid_test_distribution_of_hash_function_len_2() {
        assert_distribution(2, 0.49..=0.51);
    }

    #[test]
    fn test_new_targets() {
        let worker_state: &'static WorkerConfig = Box::leak(Box::new(WorkerConfig::dummy(
            "test",
            "127.0.0.1:9999",
            true,
        )));
        let Targets { targets } = Targets::new(&[worker_state]);

        assert_eq!(targets.len(), 1);
    }

    #[tokio::test]
    async fn test_new_cookie_hash() {
        let worker_state: &'static WorkerConfig = Box::leak(Box::new(WorkerConfig::dummy(
            "test",
            "127.0.0.1:9999",
            true,
        )));
        let CookieHash {
            targets,
            targets_len,
        } = CookieHash::new(&[worker_state]).await;

        assert_eq!(targets.targets.len(), 1);
        assert_eq!(targets_len, 1);

        worker_state.wait_until_done().await;
    }

    #[test]
    fn test_calculate_exponential_backoff() {
        assert_eq!(calculate_exponential_backoff(0), BASE_BACKOFF);
        assert_eq!(calculate_exponential_backoff(1), BASE_BACKOFF * 2);
        assert_eq!(calculate_exponential_backoff(2), BASE_BACKOFF * 4);
        assert_eq!(calculate_exponential_backoff(3), BASE_BACKOFF * 8);
    }

    #[tokio::test]
    async fn test_load_balancing_strategy() {
        let worker1: &'static WorkerConfig = Box::leak(Box::new(WorkerConfig::dummy(
            "test1",
            "127.0.0.1:9999",
            true,
        )));
        let worker2: &'static WorkerConfig = Box::leak(Box::new(WorkerConfig::dummy(
            "test2",
            "127.0.0.1:8888",
            true,
        )));
        let workers_static_refs = [worker1, worker2];
        let cookie_hash = CookieHash::new(&workers_static_refs).await;

        // The same id must always be routed to the same target.
        let uuid1 = Uuid::now_v7();
        let client1_a = cookie_hash.entry(uuid1).await;
        let client1_b = cookie_hash.entry(uuid1).await;
        assert_eq!(client1_a.socket_addr(), client1_b.socket_addr());

        // With two targets and a uniform hash, 100 fresh ids all landing on the
        // first id's target has probability 2^-100, so not finding the other
        // target means the hash is broken rather than unlucky.
        let mut other = None;
        for _ in 0..100 {
            let candidate = Uuid::now_v7();
            let client = cookie_hash.entry(candidate).await;
            if client.socket_addr() != client1_a.socket_addr() {
                other = Some((candidate, client.socket_addr()));
                break;
            }
        }
        let (uuid2, client2_addr) =
            other.expect("no id out of 100 was routed to the second target");

        // The second id must also be routed consistently, and to a different
        // target than the first.
        let client2_a = cookie_hash.entry(uuid2).await;
        let client2_b = cookie_hash.entry(uuid2).await;
        assert_eq!(client2_a.socket_addr(), client2_b.socket_addr());
        assert_eq!(client2_a.socket_addr(), client2_addr);
        assert_ne!(client1_a.socket_addr(), client2_a.socket_addr());

        for worker_config in workers_static_refs.iter() {
            worker_config.wait_until_done().await;
        }
    }
}