faucet_server/client/load_balancing/cookie_hash.rs

1use uuid::Uuid;
2
3use super::LoadBalancingStrategy;
4use super::WorkerConfig;
5use crate::leak;
6use crate::{client::Client, error::FaucetResult};
7use std::time::Duration;
8
9struct Targets {
10    targets: &'static [Client],
11}
12
13impl Targets {
14    fn new(configs: &[WorkerConfig]) -> FaucetResult<Self> {
15        let mut targets = Vec::new();
16        for state in configs {
17            let client = Client::builder(*state).build()?;
18            targets.push(client);
19        }
20        let targets = leak!(targets);
21        Ok(Targets { targets })
22    }
23}
24
25pub struct CookieHash {
26    targets: Targets,
27    targets_len: usize,
28}
29
30impl CookieHash {
31    pub(crate) fn new(targets: &[WorkerConfig]) -> FaucetResult<Self> {
32        Ok(Self {
33            targets_len: targets.as_ref().len(),
34            targets: Targets::new(targets)?,
35        })
36    }
37}
38
39fn calculate_hash(cookie_uuid: Uuid) -> u64 {
40    let mut hash_value = cookie_uuid.as_u128() as u64;
41    hash_value ^= hash_value >> 33;
42    hash_value = hash_value.wrapping_mul(0xff51afd7ed558ccd);
43    hash_value ^= hash_value >> 33;
44    hash_value = hash_value.wrapping_mul(0xc4ceb9fe1a85ec53);
45    hash_value ^= hash_value >> 33;
46
47    hash_value
48}
49
50fn hash_to_index(value: Uuid, length: usize) -> usize {
51    let hash = calculate_hash(value);
52    (hash % length as u64) as usize
53}
54
/// Delay before the first retry (1ms); each subsequent retry doubles it.
const BASE_BACKOFF: Duration = Duration::from_millis(1);

/// Returns the delay to wait before retry number `retries` (0-based), i.e.
/// `BASE_BACKOFF * 2^retries`.
///
/// The arithmetic saturates instead of overflowing. The previous `2u32.pow(retries)`
/// panicked in debug builds once `retries >= 32`. In builds without overflow checks it
/// wrapped to 0, producing a zero delay and a tight retry loop. Large retry counts now
/// clamp to `BASE_BACKOFF * u32::MAX`.
fn calculate_exponential_backoff(retries: u32) -> Duration {
    let factor = 2u32.saturating_pow(retries);
    BASE_BACKOFF.saturating_mul(factor)
}
61
62impl LoadBalancingStrategy for CookieHash {
63    type Input = Uuid;
64    async fn entry(&self, id: Uuid) -> Client {
65        let mut retries = 0;
66        let index = hash_to_index(id, self.targets_len);
67        let client = self.targets.targets[index].clone();
68        loop {
69            if client.is_online() {
70                break client;
71            }
72
73            let backoff = calculate_exponential_backoff(retries);
74
75            log::debug!(
76                target: "faucet",
77                "LB Session {} tried to connect to offline {}, retrying in {:?}",
78                id,
79                client.config.target,
80                backoff
81            );
82
83            tokio::time::sleep(backoff).await;
84            retries += 1;
85        }
86    }
87}
88
#[cfg(test)]
mod tests {
    use super::*;
    use crate::client::ExtractSocketAddr;

    use uuid::Uuid;

    /// Hashes 100k fresh v7 UUIDs into `buckets` buckets and asserts that each bucket
    /// receives `1 / buckets` of them, within an absolute `tolerance`.
    ///
    /// With 100k samples the standard error of each share is below 0.0016, so a
    /// tolerance of 0.01 is more than 6 standard errors wide.
    fn assert_uniform_distribution(buckets: usize, tolerance: f64) {
        const N_UUIDS: usize = 100_000;

        let mut counts = vec![0usize; buckets];
        for _ in 0..N_UUIDS {
            counts[hash_to_index(Uuid::now_v7(), buckets)] += 1;
        }

        let expected = 1.0 / buckets as f64;
        for (bucket, &count) in counts.iter().enumerate() {
            let share = count as f64 / N_UUIDS as f64;
            assert!(
                (share - expected).abs() <= tolerance,
                "bucket {bucket} received {share:.4} of the UUIDs, expected {expected:.4} +/- {tolerance}"
            );
        }
    }

    #[test]
    fn uuid_test_distribution_of_hash_function_len_4() {
        assert_uniform_distribution(4, 0.01);
    }

    #[test]
    fn uuid_test_distribution_of_hash_function_len_3() {
        assert_uniform_distribution(3, 0.01);
    }

    #[test]
    fn uuid_test_distribution_of_hash_function_len_2() {
        assert_uniform_distribution(2, 0.01);
    }

    #[test]
    fn test_new_targets() {
        let worker_state = WorkerConfig::dummy("test", "127.0.0.1:9999", true);
        let Targets { targets } = Targets::new(&[worker_state]).unwrap();

        assert_eq!(targets.len(), 1);
    }

    #[test]
    fn test_new_cookie_hash() {
        let worker_state = WorkerConfig::dummy("test", "127.0.0.1:9999", true);
        let CookieHash {
            targets,
            targets_len,
        } = CookieHash::new(&[worker_state]).unwrap();

        assert_eq!(targets.targets.len(), 1);
        assert_eq!(targets_len, 1);
    }

    #[test]
    fn test_calculate_exponential_backoff() {
        assert_eq!(calculate_exponential_backoff(0), BASE_BACKOFF);
        assert_eq!(calculate_exponential_backoff(1), BASE_BACKOFF * 2);
        assert_eq!(calculate_exponential_backoff(2), BASE_BACKOFF * 4);
        assert_eq!(calculate_exponential_backoff(3), BASE_BACKOFF * 8);
    }

    #[tokio::test]
    async fn test_load_balancing_strategy() {
        let workers = [
            WorkerConfig::dummy("test1", "127.0.0.1:9999", true),
            WorkerConfig::dummy("test2", "127.0.0.1:8888", true),
        ];
        let cookie_hash = CookieHash::new(&workers).unwrap();

        // Tracks which targets have been selected at least once.
        let mut hit = vec![false; workers.len()];

        for _ in 0..100 {
            let id = Uuid::now_v7();
            let index = hash_to_index(id, cookie_hash.targets_len);
            let expected = cookie_hash.targets.targets[index].socket_addr();

            // `entry` must route to exactly the hashed target, and do so consistently.
            let first = cookie_hash.entry(id).await;
            let second = cookie_hash.entry(id).await;
            assert_eq!(first.socket_addr(), expected);
            assert_eq!(second.socket_addr(), expected);

            hit[index] = true;
        }

        // With 100 random ids over 2 targets, never selecting one of them has
        // probability 2^-99, so a failure here indicates a real hashing problem.
        assert!(
            hit.iter().all(|&selected| selected),
            "some target was never selected: {hit:?}"
        );
    }
}