1use async_trait::async_trait;
10use std::collections::HashMap;
11use std::hash::{Hash, Hasher};
12use std::sync::Arc;
13use tokio::sync::RwLock;
14use tracing::{debug, trace, warn};
15use xxhash_rust::xxh3::xxh3_64;
16
17use grapsus_common::errors::{GrapsusError, GrapsusResult};
18
19use super::{LoadBalancer, RequestContext, TargetSelection, UpstreamTarget};
20
21#[derive(Debug, Clone)]
23pub struct MaglevConfig {
24 pub table_size: usize,
26 pub key_source: MaglevKeySource,
28}
29
30impl Default for MaglevConfig {
31 fn default() -> Self {
32 Self {
33 table_size: 65537,
35 key_source: MaglevKeySource::ClientIp,
36 }
37 }
38}
39
/// Selects which request attribute is hashed to choose a Maglev slot.
///
/// When the attribute is absent from the request, the balancer hashes the
/// literal key `"default"` instead (or `"/"` for [`MaglevKeySource::Path`]
/// when there is no request context at all).
#[derive(Clone, Debug)]
pub enum MaglevKeySource {
    /// The client's IP address (the port is not part of the key).
    ClientIp,
    /// The value of the named request header.
    Header(String),
    /// The value of the named cookie, parsed out of the `cookie` header.
    Cookie(String),
    /// The request path.
    Path,
}
52
53pub struct MaglevBalancer {
55 targets: Vec<UpstreamTarget>,
57 lookup_table: Arc<RwLock<Vec<Option<usize>>>>,
59 health_status: Arc<RwLock<HashMap<String, bool>>>,
61 config: MaglevConfig,
63 generation: Arc<RwLock<u64>>,
65}
66
67impl MaglevBalancer {
68 pub fn new(targets: Vec<UpstreamTarget>, config: MaglevConfig) -> Self {
70 let mut health_status = HashMap::new();
71 for target in &targets {
72 health_status.insert(target.full_address(), true);
73 }
74
75 let table_size = config.table_size;
76 let balancer = Self {
77 targets,
78 lookup_table: Arc::new(RwLock::new(vec![None; table_size])),
79 health_status: Arc::new(RwLock::new(health_status)),
80 config,
81 generation: Arc::new(RwLock::new(0)),
82 };
83
84 let targets_clone = balancer.targets.clone();
87 let table_size = balancer.config.table_size;
88 let table = Self::build_lookup_table(&targets_clone, table_size);
89
90 if let Ok(mut lookup) = balancer.lookup_table.try_write() {
92 *lookup = table;
93 }
94
95 balancer
96 }
97
98 fn build_lookup_table(targets: &[UpstreamTarget], table_size: usize) -> Vec<Option<usize>> {
100 if targets.is_empty() {
101 return vec![None; table_size];
102 }
103
104 let n = targets.len();
105 let m = table_size;
106
107 let permutations: Vec<Vec<usize>> = targets
109 .iter()
110 .map(|target| Self::generate_permutation(&target.full_address(), m))
111 .collect();
112
113 let mut table = vec![None; m];
115 let mut next = vec![0usize; n]; let mut filled = 0;
117
118 while filled < m {
119 for i in 0..n {
120 loop {
122 let c = permutations[i][next[i]];
123 next[i] += 1;
124
125 if table[c].is_none() {
126 table[c] = Some(i);
127 filled += 1;
128 break;
129 }
130
131 if next[i] >= m {
133 next[i] = 0;
134 break;
135 }
136 }
137
138 if filled >= m {
139 break;
140 }
141 }
142 }
143
144 table
145 }
146
147 fn generate_permutation(name: &str, table_size: usize) -> Vec<usize> {
149 let m = table_size;
150
151 let h1 = xxh3_64(name.as_bytes()) as usize;
153 let h2 = {
154 let mut hasher = std::collections::hash_map::DefaultHasher::new();
155 name.hash(&mut hasher);
156 hasher.finish() as usize
157 };
158
159 let offset = h1 % m;
161 let skip = (h2 % (m - 1)) + 1; (0..m).map(|i| (offset + i * skip) % m).collect()
165 }
166
167 async fn rebuild_table_for_healthy(&self) {
169 let health = self.health_status.read().await;
170 let healthy_targets: Vec<_> = self
171 .targets
172 .iter()
173 .filter(|t| *health.get(&t.full_address()).unwrap_or(&true))
174 .cloned()
175 .collect();
176 drop(health);
177
178 if healthy_targets.is_empty() {
179 return;
181 }
182
183 let table = Self::build_lookup_table(&healthy_targets, self.config.table_size);
184
185 let mut lookup = self.lookup_table.write().await;
186 *lookup = table;
187
188 let mut gen = self.generation.write().await;
189 *gen += 1;
190
191 debug!(
192 healthy_count = healthy_targets.len(),
193 total_count = self.targets.len(),
194 generation = *gen,
195 "Maglev lookup table rebuilt"
196 );
197 }
198
199 fn extract_key(&self, context: Option<&RequestContext>) -> String {
201 match &self.config.key_source {
202 MaglevKeySource::ClientIp => context
203 .and_then(|c| c.client_ip.map(|ip| ip.ip().to_string()))
204 .unwrap_or_else(|| "default".to_string()),
205 MaglevKeySource::Header(name) => context
206 .and_then(|c| c.headers.get(name).cloned())
207 .unwrap_or_else(|| "default".to_string()),
208 MaglevKeySource::Cookie(name) => context
209 .and_then(|c| {
210 c.headers.get("cookie").and_then(|cookies| {
211 cookies.split(';').find_map(|cookie| {
212 let (key, value) = cookie.trim().split_once('=')?;
213 if key == name {
214 Some(value.to_string())
215 } else {
216 None
217 }
218 })
219 })
220 })
221 .unwrap_or_else(|| "default".to_string()),
222 MaglevKeySource::Path => context
223 .map(|c| c.path.clone())
224 .unwrap_or_else(|| "/".to_string()),
225 }
226 }
227
228 async fn get_healthy_targets(&self) -> Vec<&UpstreamTarget> {
230 let health = self.health_status.read().await;
231 self.targets
232 .iter()
233 .filter(|t| *health.get(&t.full_address()).unwrap_or(&true))
234 .collect()
235 }
236}
237
238#[async_trait]
239impl LoadBalancer for MaglevBalancer {
240 async fn select(&self, context: Option<&RequestContext>) -> GrapsusResult<TargetSelection> {
241 trace!(
242 total_targets = self.targets.len(),
243 algorithm = "maglev",
244 "Selecting upstream target"
245 );
246
247 let health = self.health_status.read().await;
249 let healthy_targets: Vec<_> = self
250 .targets
251 .iter()
252 .enumerate()
253 .filter(|(_, t)| *health.get(&t.full_address()).unwrap_or(&true))
254 .collect();
255 drop(health);
256
257 if healthy_targets.is_empty() {
258 warn!(
259 total_targets = self.targets.len(),
260 algorithm = "maglev",
261 "No healthy upstream targets available"
262 );
263 return Err(GrapsusError::NoHealthyUpstream);
264 }
265
266 let key = self.extract_key(context);
268 let hash = xxh3_64(key.as_bytes()) as usize;
269 let table_index = hash % self.config.table_size;
270
271 let lookup = self.lookup_table.read().await;
273 let target_index = lookup[table_index];
274 drop(lookup);
275
276 let target = if let Some(idx) = target_index {
278 if idx < self.targets.len() {
280 let t = &self.targets[idx];
281 let health = self.health_status.read().await;
282 if *health.get(&t.full_address()).unwrap_or(&true) {
283 t
284 } else {
285 healthy_targets
287 .first()
288 .map(|(_, t)| *t)
289 .ok_or(GrapsusError::NoHealthyUpstream)?
290 }
291 } else {
292 healthy_targets
294 .first()
295 .map(|(_, t)| *t)
296 .ok_or(GrapsusError::NoHealthyUpstream)?
297 }
298 } else {
299 healthy_targets
301 .first()
302 .map(|(_, t)| *t)
303 .ok_or(GrapsusError::NoHealthyUpstream)?
304 };
305
306 trace!(
307 selected_target = %target.full_address(),
308 hash_key = %key,
309 table_index = table_index,
310 healthy_count = healthy_targets.len(),
311 algorithm = "maglev",
312 "Selected target via Maglev consistent hashing"
313 );
314
315 Ok(TargetSelection {
316 address: target.full_address(),
317 weight: target.weight,
318 metadata: HashMap::new(),
319 })
320 }
321
322 async fn report_health(&self, address: &str, healthy: bool) {
323 let prev_health = {
324 let health = self.health_status.read().await;
325 *health.get(address).unwrap_or(&true)
326 };
327
328 if prev_health != healthy {
329 trace!(
330 target = %address,
331 healthy = healthy,
332 algorithm = "maglev",
333 "Target health changed, rebuilding lookup table"
334 );
335
336 self.health_status
337 .write()
338 .await
339 .insert(address.to_string(), healthy);
340
341 self.rebuild_table_for_healthy().await;
343 } else {
344 self.health_status
345 .write()
346 .await
347 .insert(address.to_string(), healthy);
348 }
349 }
350
351 async fn healthy_targets(&self) -> Vec<String> {
352 self.health_status
353 .read()
354 .await
355 .iter()
356 .filter_map(|(addr, &healthy)| if healthy { Some(addr.clone()) } else { None })
357 .collect()
358 }
359}
360
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds `count` targets named `backend-0`, `backend-1`, ... on port
    /// 8080 with weight 100.
    fn make_targets(count: usize) -> Vec<UpstreamTarget> {
        let mut targets = Vec::with_capacity(count);
        for i in 0..count {
            targets.push(UpstreamTarget::new(format!("backend-{}", i), 8080, 100));
        }
        targets
    }

    /// GET request from `192.168.1.{i % 256}:12345` for `/api/test/{i}`.
    fn numbered_context(i: usize) -> RequestContext {
        RequestContext {
            client_ip: Some(format!("192.168.1.{}:12345", i % 256).parse().unwrap()),
            headers: HashMap::new(),
            path: format!("/api/test/{}", i),
            method: "GET".to_string(),
        }
    }

    #[test]
    fn test_build_lookup_table() {
        const SLOTS: usize = 65537;
        let table = MaglevBalancer::build_lookup_table(&make_targets(3), SLOTS);

        // Every slot must be owned by some target.
        assert!(table.iter().all(Option::is_some));

        // Round-robin population should give each target roughly a third.
        let counts = table.iter().flatten().fold(vec![0usize; 3], |mut acc, &owner| {
            acc[owner] += 1;
            acc
        });
        let expected = SLOTS / 3;
        let tolerance = expected as i64 / 10;
        for count in counts {
            let deviation = (count as i64 - expected as i64).abs();
            assert!(
                deviation < tolerance,
                "Uneven distribution: {} vs expected ~{}",
                count,
                expected
            );
        }
    }

    #[test]
    fn test_permutation_generation() {
        const SLOTS: usize = 65537;
        let first = MaglevBalancer::generate_permutation("backend-1", SLOTS);
        let second = MaglevBalancer::generate_permutation("backend-2", SLOTS);

        // Different names should yield different preference orders.
        assert_ne!(first[0..100], second[0..100]);

        // A permutation must visit every slot.
        let mut visited = vec![false; SLOTS];
        for &slot in first.iter() {
            visited[slot] = true;
        }
        assert!(!visited.contains(&false));
    }

    #[tokio::test]
    async fn test_consistent_selection() {
        let balancer = MaglevBalancer::new(make_targets(5), MaglevConfig::default());
        let context = RequestContext {
            client_ip: Some("192.168.1.100:12345".parse().unwrap()),
            headers: HashMap::new(),
            path: "/api/test".to_string(),
            method: "GET".to_string(),
        };

        let mut picks = Vec::with_capacity(3);
        for _ in 0..3 {
            picks.push(balancer.select(Some(&context)).await.unwrap().address);
        }

        assert_eq!(picks[0], picks[1]);
        assert_eq!(picks[1], picks[2]);
    }

    #[tokio::test]
    async fn test_minimal_disruption() {
        const REQUESTS: usize = 1000;
        let balancer = MaglevBalancer::new(make_targets(5), MaglevConfig::default());

        let mut before = Vec::with_capacity(REQUESTS);
        for i in 0..REQUESTS {
            let selection = balancer.select(Some(&numbered_context(i))).await.unwrap();
            before.push(selection.address);
        }

        balancer.report_health("backend-2:8080", false).await;

        let mut changed = 0;
        for (i, previous) in before.iter().enumerate() {
            let selection = balancer.select(Some(&numbered_context(i))).await.unwrap();
            if &selection.address != previous {
                changed += 1;
            }
        }

        assert!(
            changed < 800,
            "Too many selections changed: {} (expected less than 800 for 1/5 backend removal)",
            changed
        );

        assert!(
            changed < 1000 - 100,
            "Too few stable selections: only {} unchanged",
            1000 - changed
        );
    }
}