use async_trait::async_trait;
use std::collections::HashMap;
use std::hash::{Hash, Hasher};
use std::sync::Arc;
use tokio::sync::RwLock;
use tracing::{debug, trace, warn};
use xxhash_rust::xxh3::xxh3_64;
use zentinel_common::errors::{ZentinelError, ZentinelResult};
use super::{LoadBalancer, RequestContext, TargetSelection, UpstreamTarget};
#[derive(Debug, Clone)]
pub struct MaglevConfig {
pub table_size: usize,
pub key_source: MaglevKeySource,
}
impl Default for MaglevConfig {
fn default() -> Self {
Self {
table_size: 65537,
key_source: MaglevKeySource::ClientIp,
}
}
}
#[derive(Debug, Clone)]
pub enum MaglevKeySource {
ClientIp,
Header(String),
Cookie(String),
Path,
}
pub struct MaglevBalancer {
targets: Vec<UpstreamTarget>,
lookup_table: Arc<RwLock<Vec<Option<usize>>>>,
health_status: Arc<RwLock<HashMap<String, bool>>>,
config: MaglevConfig,
generation: Arc<RwLock<u64>>,
}
impl MaglevBalancer {
pub fn new(targets: Vec<UpstreamTarget>, config: MaglevConfig) -> Self {
let mut health_status = HashMap::new();
for target in &targets {
health_status.insert(target.full_address(), true);
}
let table_size = config.table_size;
let balancer = Self {
targets,
lookup_table: Arc::new(RwLock::new(vec![None; table_size])),
health_status: Arc::new(RwLock::new(health_status)),
config,
generation: Arc::new(RwLock::new(0)),
};
let targets_clone = balancer.targets.clone();
let table_size = balancer.config.table_size;
let table = Self::build_lookup_table(&targets_clone, table_size);
if let Ok(mut lookup) = balancer.lookup_table.try_write() {
*lookup = table;
}
balancer
}
fn build_lookup_table(targets: &[UpstreamTarget], table_size: usize) -> Vec<Option<usize>> {
if targets.is_empty() {
return vec![None; table_size];
}
let n = targets.len();
let m = table_size;
let permutations: Vec<Vec<usize>> = targets
.iter()
.map(|target| Self::generate_permutation(&target.full_address(), m))
.collect();
let mut table = vec![None; m];
let mut next = vec![0usize; n]; let mut filled = 0;
while filled < m {
for i in 0..n {
loop {
let c = permutations[i][next[i]];
next[i] += 1;
if table[c].is_none() {
table[c] = Some(i);
filled += 1;
break;
}
if next[i] >= m {
next[i] = 0;
break;
}
}
if filled >= m {
break;
}
}
}
table
}
fn generate_permutation(name: &str, table_size: usize) -> Vec<usize> {
let m = table_size;
let h1 = xxh3_64(name.as_bytes()) as usize;
let h2 = {
let mut hasher = std::collections::hash_map::DefaultHasher::new();
name.hash(&mut hasher);
hasher.finish() as usize
};
let offset = h1 % m;
let skip = (h2 % (m - 1)) + 1;
(0..m).map(|i| (offset + i * skip) % m).collect()
}
async fn rebuild_table_for_healthy(&self) {
let health = self.health_status.read().await;
let healthy_targets: Vec<_> = self
.targets
.iter()
.filter(|t| *health.get(&t.full_address()).unwrap_or(&true))
.cloned()
.collect();
drop(health);
if healthy_targets.is_empty() {
return;
}
let table = Self::build_lookup_table(&healthy_targets, self.config.table_size);
let mut lookup = self.lookup_table.write().await;
*lookup = table;
let mut gen = self.generation.write().await;
*gen += 1;
debug!(
healthy_count = healthy_targets.len(),
total_count = self.targets.len(),
generation = *gen,
"Maglev lookup table rebuilt"
);
}
fn extract_key(&self, context: Option<&RequestContext>) -> String {
match &self.config.key_source {
MaglevKeySource::ClientIp => context
.and_then(|c| c.client_ip.map(|ip| ip.ip().to_string()))
.unwrap_or_else(|| "default".to_string()),
MaglevKeySource::Header(name) => context
.and_then(|c| c.headers.get(name).cloned())
.unwrap_or_else(|| "default".to_string()),
MaglevKeySource::Cookie(name) => context
.and_then(|c| {
c.headers.get("cookie").and_then(|cookies| {
cookies.split(';').find_map(|cookie| {
let (key, value) = cookie.trim().split_once('=')?;
if key == name {
Some(value.to_string())
} else {
None
}
})
})
})
.unwrap_or_else(|| "default".to_string()),
MaglevKeySource::Path => context
.map(|c| c.path.clone())
.unwrap_or_else(|| "/".to_string()),
}
}
async fn get_healthy_targets(&self) -> Vec<&UpstreamTarget> {
let health = self.health_status.read().await;
self.targets
.iter()
.filter(|t| *health.get(&t.full_address()).unwrap_or(&true))
.collect()
}
}
#[async_trait]
impl LoadBalancer for MaglevBalancer {
async fn select(&self, context: Option<&RequestContext>) -> ZentinelResult<TargetSelection> {
trace!(
total_targets = self.targets.len(),
algorithm = "maglev",
"Selecting upstream target"
);
let health = self.health_status.read().await;
let healthy_targets: Vec<_> = self
.targets
.iter()
.enumerate()
.filter(|(_, t)| *health.get(&t.full_address()).unwrap_or(&true))
.collect();
drop(health);
if healthy_targets.is_empty() {
warn!(
total_targets = self.targets.len(),
algorithm = "maglev",
"No healthy upstream targets available"
);
return Err(ZentinelError::NoHealthyUpstream);
}
let key = self.extract_key(context);
let hash = xxh3_64(key.as_bytes()) as usize;
let table_index = hash % self.config.table_size;
let lookup = self.lookup_table.read().await;
let target_index = lookup[table_index];
drop(lookup);
let target = if let Some(idx) = target_index {
if idx < self.targets.len() {
let t = &self.targets[idx];
let health = self.health_status.read().await;
if *health.get(&t.full_address()).unwrap_or(&true) {
t
} else {
healthy_targets
.first()
.map(|(_, t)| *t)
.ok_or(ZentinelError::NoHealthyUpstream)?
}
} else {
healthy_targets
.first()
.map(|(_, t)| *t)
.ok_or(ZentinelError::NoHealthyUpstream)?
}
} else {
healthy_targets
.first()
.map(|(_, t)| *t)
.ok_or(ZentinelError::NoHealthyUpstream)?
};
trace!(
selected_target = %target.full_address(),
hash_key = %key,
table_index = table_index,
healthy_count = healthy_targets.len(),
algorithm = "maglev",
"Selected target via Maglev consistent hashing"
);
Ok(TargetSelection {
address: target.full_address(),
weight: target.weight,
metadata: HashMap::new(),
})
}
async fn report_health(&self, address: &str, healthy: bool) {
let prev_health = {
let health = self.health_status.read().await;
*health.get(address).unwrap_or(&true)
};
if prev_health != healthy {
trace!(
target = %address,
healthy = healthy,
algorithm = "maglev",
"Target health changed, rebuilding lookup table"
);
self.health_status
.write()
.await
.insert(address.to_string(), healthy);
self.rebuild_table_for_healthy().await;
} else {
self.health_status
.write()
.await
.insert(address.to_string(), healthy);
}
}
async fn healthy_targets(&self) -> Vec<String> {
self.health_status
.read()
.await
.iter()
.filter_map(|(addr, &healthy)| if healthy { Some(addr.clone()) } else { None })
.collect()
}
}
#[cfg(test)]
mod tests {
use super::*;
fn make_targets(count: usize) -> Vec<UpstreamTarget> {
(0..count)
.map(|i| UpstreamTarget::new(format!("backend-{}", i), 8080, 100))
.collect()
}
#[test]
fn test_build_lookup_table() {
let targets = make_targets(3);
let table = MaglevBalancer::build_lookup_table(&targets, 65537);
assert!(table.iter().all(|entry| entry.is_some()));
let mut counts = vec![0usize; 3];
for idx in table.iter().flatten() {
counts[*idx] += 1;
}
let expected = 65537 / 3;
for count in counts {
assert!(
(count as i64 - expected as i64).abs() < (expected as i64 / 10),
"Uneven distribution: {} vs expected ~{}",
count,
expected
);
}
}
#[test]
fn test_permutation_generation() {
let perm1 = MaglevBalancer::generate_permutation("backend-1", 65537);
let perm2 = MaglevBalancer::generate_permutation("backend-2", 65537);
assert_ne!(perm1[0..100], perm2[0..100]);
let mut seen = vec![false; 65537];
for &idx in &perm1 {
seen[idx] = true;
}
assert!(seen.iter().all(|&s| s));
}
#[tokio::test]
async fn test_consistent_selection() {
let targets = make_targets(5);
let balancer = MaglevBalancer::new(targets, MaglevConfig::default());
let context = RequestContext {
client_ip: Some("192.168.1.100:12345".parse().unwrap()),
headers: HashMap::new(),
path: "/api/test".to_string(),
method: "GET".to_string(),
};
let selection1 = balancer.select(Some(&context)).await.unwrap();
let selection2 = balancer.select(Some(&context)).await.unwrap();
let selection3 = balancer.select(Some(&context)).await.unwrap();
assert_eq!(selection1.address, selection2.address);
assert_eq!(selection2.address, selection3.address);
}
#[tokio::test]
async fn test_minimal_disruption() {
let targets = make_targets(5);
let balancer = MaglevBalancer::new(targets.clone(), MaglevConfig::default());
let mut original_selections = HashMap::new();
for i in 0..1000 {
let context = RequestContext {
client_ip: Some(format!("192.168.1.{}:12345", i % 256).parse().unwrap()),
headers: HashMap::new(),
path: format!("/api/test/{}", i),
method: "GET".to_string(),
};
let selection = balancer.select(Some(&context)).await.unwrap();
original_selections.insert(i, selection.address);
}
balancer.report_health("backend-2:8080", false).await;
let mut changed = 0;
for i in 0..1000 {
let context = RequestContext {
client_ip: Some(format!("192.168.1.{}:12345", i % 256).parse().unwrap()),
headers: HashMap::new(),
path: format!("/api/test/{}", i),
method: "GET".to_string(),
};
let selection = balancer.select(Some(&context)).await.unwrap();
if selection.address != original_selections[&i] {
changed += 1;
}
}
assert!(
changed < 800,
"Too many selections changed: {} (expected less than 800 for 1/5 backend removal)",
changed
);
assert!(
changed < 1000 - 100,
"Too few stable selections: only {} unchanged",
1000 - changed
);
}
}