faucet_server/client/load_balancing/
cookie_hash.rs1use uuid::Uuid;
2
3use super::LoadBalancingStrategy;
4use super::WorkerConfig;
5use crate::leak;
6use crate::{client::Client, error::FaucetResult};
7use std::time::Duration;
8
9struct Targets {
10 targets: &'static [Client],
11}
12
13impl Targets {
14 fn new(configs: &[WorkerConfig]) -> FaucetResult<Self> {
15 let mut targets = Vec::new();
16 for state in configs {
17 let client = Client::builder(*state).build()?;
18 targets.push(client);
19 }
20 let targets = leak!(targets);
21 Ok(Targets { targets })
22 }
23}
24
25pub struct CookieHash {
26 targets: Targets,
27 targets_len: usize,
28}
29
30impl CookieHash {
31 pub(crate) fn new(targets: &[WorkerConfig]) -> FaucetResult<Self> {
32 Ok(Self {
33 targets_len: targets.as_ref().len(),
34 targets: Targets::new(targets)?,
35 })
36 }
37}
38
39fn calculate_hash(cookie_uuid: Uuid) -> u64 {
40 let mut hash_value = cookie_uuid.as_u128() as u64;
41 hash_value ^= hash_value >> 33;
42 hash_value = hash_value.wrapping_mul(0xff51afd7ed558ccd);
43 hash_value ^= hash_value >> 33;
44 hash_value = hash_value.wrapping_mul(0xc4ceb9fe1a85ec53);
45 hash_value ^= hash_value >> 33;
46
47 hash_value
48}
49
50fn hash_to_index(value: Uuid, length: usize) -> usize {
51 let hash = calculate_hash(value);
52 (hash % length as u64) as usize
53}
54
/// Delay before the first retry when a session's worker is offline.
const BASE_BACKOFF: Duration = Duration::from_millis(1);

/// Upper bound on a single retry delay.
///
/// Without a cap the doubling delay grows without limit (about 17 minutes by
/// the 20th retry), so a waiting session could keep sleeping long after its
/// worker came back online. The cap bounds that extra latency.
const MAX_BACKOFF: Duration = Duration::from_millis(500);

/// Returns the delay before retry number `retries` (0-based):
/// `BASE_BACKOFF * 2^retries`, capped at `MAX_BACKOFF`.
///
/// Overflow is handled explicitly. A plain `2u32.pow(retries)` overflows at
/// `retries == 32`: that panics in debug builds and wraps to a zero delay (a
/// busy retry loop) in release builds. Here an overflowing factor or product
/// saturates to `MAX_BACKOFF`, so this never panics.
fn calculate_exponential_backoff(retries: u32) -> Duration {
    2u32.checked_pow(retries)
        .and_then(|factor| BASE_BACKOFF.checked_mul(factor))
        .map_or(MAX_BACKOFF, |delay| delay.min(MAX_BACKOFF))
}
61
62impl LoadBalancingStrategy for CookieHash {
63 type Input = Uuid;
64 async fn entry(&self, id: Uuid) -> Client {
65 let mut retries = 0;
66 let index = hash_to_index(id, self.targets_len);
67 let client = self.targets.targets[index].clone();
68 loop {
69 if client.is_online() {
70 break client;
71 }
72
73 let backoff = calculate_exponential_backoff(retries);
74
75 log::debug!(
76 target: "faucet",
77 "LB Session {} tried to connect to offline {}, retrying in {:?}",
78 id,
79 client.config.target,
80 backoff
81 );
82
83 tokio::time::sleep(backoff).await;
84 retries += 1;
85 }
86 }
87}
88
#[cfg(test)]
mod tests {
    use super::*;
    use crate::client::ExtractSocketAddr;

    use uuid::Uuid;

    /// Number of UUIDs hashed by each distribution test.
    const N_UUIDS: usize = 100_000;

    /// Hashes `N_UUIDS` fresh v7 UUIDs into `buckets` slots and asserts that
    /// every slot receives `1 / buckets` of them, within `tolerance`.
    ///
    /// With 100k samples the standard error of each share is below 0.002, so
    /// a tolerance of 0.01 leaves a margin of more than six standard errors.
    fn assert_uniform_distribution(buckets: usize, tolerance: f64) {
        let mut counts = vec![0usize; buckets];
        for _ in 0..N_UUIDS {
            counts[hash_to_index(Uuid::now_v7(), buckets)] += 1;
        }

        let expected = 1.0 / buckets as f64;
        for (bucket, &count) in counts.iter().enumerate() {
            let share = count as f64 / N_UUIDS as f64;
            assert!(
                (share - expected).abs() <= tolerance,
                "bucket {bucket} got {share:.4} of keys, expected {expected:.4} +/- {tolerance}"
            );
        }
    }

    #[test]
    fn uuid_test_distribution_of_hash_function_len_4() {
        assert_uniform_distribution(4, 0.01);
    }

    #[test]
    fn uuid_test_distribution_of_hash_function_len_3() {
        assert_uniform_distribution(3, 0.01);
    }

    #[test]
    fn uuid_test_distribution_of_hash_function_len_2() {
        assert_uniform_distribution(2, 0.01);
    }

    #[test]
    fn test_new_targets() {
        let worker_state = WorkerConfig::dummy("test", "127.0.0.1:9999", true);
        let Targets { targets } = Targets::new(&[worker_state]).unwrap();

        assert_eq!(targets.len(), 1);
    }

    #[test]
    fn test_new_cookie_hash() {
        let worker_state = WorkerConfig::dummy("test", "127.0.0.1:9999", true);
        let CookieHash {
            targets,
            targets_len,
        } = CookieHash::new(&[worker_state]).unwrap();

        assert_eq!(targets.targets.len(), 1);
        assert_eq!(targets_len, 1);
    }

    #[test]
    fn test_calculate_exponential_backoff() {
        assert_eq!(calculate_exponential_backoff(0), BASE_BACKOFF);
        assert_eq!(calculate_exponential_backoff(1), BASE_BACKOFF * 2);
        assert_eq!(calculate_exponential_backoff(2), BASE_BACKOFF * 4);
        assert_eq!(calculate_exponential_backoff(3), BASE_BACKOFF * 8);
    }

    #[tokio::test]
    async fn test_load_balancing_strategy() {
        let workers = [
            WorkerConfig::dummy("test1", "127.0.0.1:9999", true),
            WorkerConfig::dummy("test2", "127.0.0.1:8888", true),
        ];
        let cookie_hash = CookieHash::new(&workers).unwrap();

        // The same session id must always be routed to the same worker.
        let uuid1 = Uuid::now_v7();
        let addr1 = cookie_hash.entry(uuid1).await.socket_addr();
        assert_eq!(cookie_hash.entry(uuid1).await.socket_addr(), addr1);

        // With two workers some other session id must land on the other one.
        // The chance that 100 random ids all share `uuid1`'s worker is ~2^-100,
        // so failing to find one indicates a broken hash, not bad luck.
        let mut other = None;
        for _ in 0..100 {
            let candidate = Uuid::now_v7();
            let addr = cookie_hash.entry(candidate).await.socket_addr();
            if addr != addr1 {
                other = Some((candidate, addr));
                break;
            }
        }
        let (uuid2, addr2) =
            other.expect("no session id was routed to the second worker in 100 attempts");

        // Routing for the second session must be sticky as well.
        assert_eq!(cookie_hash.entry(uuid2).await.socket_addr(), addr2);
        assert_eq!(cookie_hash.entry(uuid2).await.socket_addr(), addr2);
        assert_ne!(addr1, addr2);
    }
}