use std::sync::atomic::{AtomicI64, Ordering};
use dashmap::DashMap;
use crate::api::naming::Instance;
/// Strategy for picking one instance of a service out of a list of candidates.
///
/// The `Send + Sync` supertraits are required of every implementation.
pub trait LoadBalancer: Send + Sync {
/// Chooses one of `instances` for the service identified by `service_key`.
///
/// Returns a clone of the chosen instance, or `None` when nothing can be
/// chosen. Both implementations in this file return `None` for an empty
/// slice and when no instance passes their own eligibility filter.
fn select(&self, service_key: &str, instances: &[Instance]) -> Option<Instance>;
/// Discards whatever state the balancer keeps for `service_key`.
///
/// Implementations that keep no per-service state may do nothing here.
fn reset(&self, service_key: &str);
/// Returns a short name identifying the strategy
/// (for example `"WeightedRoundRobin"` or `"Random"`).
fn name(&self) -> &str;
}
pub struct WeightedRoundRobinBalancer {
states: DashMap<String, DashMap<String, AtomicI64>>,
}
impl WeightedRoundRobinBalancer {
pub fn new() -> Self {
Self {
states: DashMap::new(),
}
}
}
impl Default for WeightedRoundRobinBalancer {
fn default() -> Self {
Self::new()
}
}
impl LoadBalancer for WeightedRoundRobinBalancer {
    /// Picks the next instance for `service_key` by smooth weighted round-robin.
    ///
    /// Only instances that are healthy, enabled and have a weight above zero
    /// are eligible. Each weight is multiplied by 100 and rounded to the
    /// nearest integer, so two decimal places of the configured weight are
    /// honoured (truncating instead would turn `0.29` into 28, not 29).
    ///
    /// On every call each eligible instance's counter is raised by its weight,
    /// the instance with the highest counter wins (the earliest one on a tie),
    /// and the winner's counter is lowered by the sum of all eligible weights.
    ///
    /// Returns `None` when no instance is eligible. A single eligible instance
    /// is returned directly without touching any stored state, as is the first
    /// eligible instance when every scaled weight rounds to zero.
    ///
    /// Counters of instances that are no longer passed in are left in place
    /// until `reset` is called for the service.
    fn select(&self, service_key: &str, instances: &[Instance]) -> Option<Instance> {
        // Pair every eligible instance with its integer weight so the
        // float-to-integer conversion happens exactly once per instance.
        let candidates: Vec<(&Instance, i64)> = instances
            .iter()
            .filter(|i| i.healthy && i.enabled && i.weight > 0.0)
            .map(|i| (i, (i.weight * 100.0).round() as i64))
            .collect();

        // `?` covers both the empty input and the "nothing eligible" case.
        let (first, _) = *candidates.first()?;
        if candidates.len() == 1 {
            return Some(first.clone());
        }

        let total_weight: i64 = candidates.iter().map(|&(_, weight)| weight).sum();
        if total_weight == 0 {
            // Every weight rounded down to zero: there is nothing to balance on.
            return Some(first.clone());
        }

        // This guard is kept until the function returns.
        let service_state = self
            .states
            .entry(service_key.to_string())
            .or_insert_with(DashMap::new);

        // (raised counter value, instance, instance key) of the best candidate so far.
        let mut best: Option<(i64, &Instance, String)> = None;
        for &(instance, weight) in &candidates {
            let key = instance.key();
            // The per-instance entry guard is a temporary that is released at
            // the end of this statement, before the next key is looked up.
            let raised = service_state
                .entry(key.clone())
                .or_insert_with(|| AtomicI64::new(0))
                .fetch_add(weight, Ordering::SeqCst)
                + weight;
            // Strict comparison keeps the earliest instance on a tie.
            if best.as_ref().map_or(true, |(top, _, _)| raised > *top) {
                best = Some((raised, instance, key));
            }
        }

        let (_, winner, winner_key) = best?;
        // Push the winner back by the total so the others catch up over the
        // following calls.
        if let Some(counter) = service_state.get(&winner_key) {
            counter.fetch_sub(total_weight, Ordering::SeqCst);
        }
        Some(winner.clone())
    }

    /// Forgets all counters recorded for `service_key`.
    fn reset(&self, service_key: &str) {
        self.states.remove(service_key);
    }

    /// Returns the fixed strategy name `"WeightedRoundRobin"`.
    fn name(&self) -> &str {
        "WeightedRoundRobin"
    }
}
/// Stateless balancer marker type; it carries no fields.
///
/// `Debug`, `Clone` and `Default` are derived because the type has no data
/// that would need custom handling.
#[derive(Debug, Clone, Default)]
pub struct RandomBalancer;

impl RandomBalancer {
    /// Creates a new balancer.
    pub fn new() -> Self {
        Self
    }
}
impl LoadBalancer for RandomBalancer {
    /// Returns a clone of one healthy, enabled instance chosen via `rand_index`.
    ///
    /// The service key is ignored and instance weights are not consulted.
    /// Yields `None` when the slice is empty or holds no healthy, enabled
    /// instance.
    fn select(&self, _service_key: &str, instances: &[Instance]) -> Option<Instance> {
        let mut eligible: Vec<&Instance> = Vec::new();
        for candidate in instances {
            if candidate.healthy && candidate.enabled {
                eligible.push(candidate);
            }
        }

        // An empty input slice also lands in the `0` arm.
        match eligible.len() {
            0 => None,
            count => Some(eligible[rand_index(count)].clone()),
        }
    }

    /// Does nothing: this balancer keeps no per-service state.
    fn reset(&self, _service_key: &str) {}

    /// Returns the fixed strategy name `"Random"`.
    fn name(&self) -> &str {
        "Random"
    }
}
/// Returns a pseudo-random index in `0..max`.
///
/// For `max` of `0` or `1` the result is `0`, so the function never divides
/// by zero; callers must still make sure the collection they index is not
/// empty.
///
/// The value comes from the standard library's randomly keyed hasher: every
/// `RandomState::new()` carries different keys, so finishing a fresh hasher
/// yields a different well-mixed `u64` per call. This does not depend on the
/// system clock and cannot panic. It is not suitable for anything
/// security-sensitive.
fn rand_index(max: usize) -> usize {
    use std::collections::hash_map::RandomState;
    use std::hash::{BuildHasher, Hasher};

    if max <= 1 {
        return 0;
    }

    let draw = RandomState::new().build_hasher().finish();
    // The remainder is strictly below `max`, so it always fits back in `usize`.
    (draw % (max as u64)) as usize
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds three loopback instances on ports 8080, 8081 and 8082 with
    /// weights 1.0, 2.0 and 3.0 respectively.
    fn create_test_instances() -> Vec<Instance> {
        Vec::from([
            Instance::new("127.0.0.1", 8080).with_weight(1.0),
            Instance::new("127.0.0.1", 8081).with_weight(2.0),
            Instance::new("127.0.0.1", 8082).with_weight(3.0),
        ])
    }

    /// 600 selections over weights 1 : 2 : 3 should land near 100 / 200 / 300.
    #[test]
    fn test_wrr_distribution() {
        let instances = create_test_instances();
        let balancer = WeightedRoundRobinBalancer::new();
        let service_key = "test-service";

        let mut counts = std::collections::HashMap::new();
        (0..600)
            .filter_map(|_| balancer.select(service_key, &instances))
            .for_each(|picked| *counts.entry(picked.port).or_insert(0) += 1);

        let port_8080 = counts.get(&8080).copied().unwrap_or(0);
        let port_8081 = counts.get(&8081).copied().unwrap_or(0);
        let port_8082 = counts.get(&8082).copied().unwrap_or(0);

        assert!(
            (80..=120).contains(&port_8080),
            "port 8080 count: {}",
            port_8080
        );
        assert!(
            (180..=220).contains(&port_8081),
            "port 8081 count: {}",
            port_8081
        );
        assert!(
            (280..=320).contains(&port_8082),
            "port 8082 count: {}",
            port_8082
        );
    }

    /// An empty candidate list yields no selection.
    #[test]
    fn test_wrr_empty_instances() {
        let no_instances: Vec<Instance> = Vec::new();
        let balancer = WeightedRoundRobinBalancer::new();
        assert!(balancer.select("test-service", &no_instances).is_none());
    }

    /// A lone instance is always the one returned.
    #[test]
    fn test_wrr_single_instance() {
        let only = vec![Instance::new("127.0.0.1", 8080).with_weight(1.0)];
        let balancer = WeightedRoundRobinBalancer::new();

        let picked = balancer.select("test-service", &only);
        assert!(picked.is_some());
        assert_eq!(picked.unwrap().port, 8080);
    }

    /// `reset` removes the state recorded for the service.
    #[test]
    fn test_wrr_reset() {
        let instances = create_test_instances();
        let balancer = WeightedRoundRobinBalancer::new();
        let service_key = "test-service";

        (0..10).for_each(|_| {
            balancer.select(service_key, &instances);
        });
        balancer.reset(service_key);

        assert!(!balancer.states.contains_key(service_key));
    }

    /// The random balancer returns something when instances are available.
    #[test]
    fn test_random_balancer() {
        let instances = create_test_instances();
        let balancer = RandomBalancer::new();
        assert!(balancer.select("test-service", &instances).is_some());
    }

    /// The random balancer returns nothing for an empty candidate list.
    #[test]
    fn test_random_empty_instances() {
        let no_instances: Vec<Instance> = Vec::new();
        let balancer = RandomBalancer::new();
        assert!(balancer.select("test-service", &no_instances).is_none());
    }
}