batata-client 0.0.2

Rust client for Batata/Nacos service discovery and configuration management
Documentation
//! Load balancer implementations for service instance selection

use std::sync::atomic::{AtomicI64, Ordering};

use dashmap::DashMap;

use crate::api::naming::Instance;

/// Strategy for choosing one service instance out of a list of candidates.
///
/// Implementors must be `Send + Sync` so a single balancer can be shared between threads.
pub trait LoadBalancer: Send + Sync {
    /// Short identifier of this balancing strategy.
    fn name(&self) -> &str;

    /// Chooses one instance from `instances`.
    ///
    /// # Arguments
    /// * `service_key` - Identifies the service; stateful strategies keep their state under it
    /// * `instances` - The candidate instances to choose among
    ///
    /// # Returns
    /// The chosen instance, or `None` when no candidate qualifies
    fn select(&self, service_key: &str, instances: &[Instance]) -> Option<Instance>;

    /// Drops any state kept for `service_key` (intended to be called when its instances change).
    fn reset(&self, service_key: &str);
}

/// Smooth weighted round-robin (WRR) load balancer.
///
/// Traffic is spread across instances in proportion to their weights; weights
/// `[1, 2, 3]` produce a 1:2:3 split.
///
/// Every instance carries a running "current weight". On each selection, every eligible
/// instance's current weight grows by its own weight. The instance with the largest current
/// weight is picked, and its current weight is then lowered by the sum of all eligible weights.
pub struct WeightedRoundRobinBalancer {
    /// Running weights, keyed first by service key and then by instance key.
    states: DashMap<String, DashMap<String, AtomicI64>>,
}

impl WeightedRoundRobinBalancer {
    /// Creates a balancer that has not yet recorded state for any service.
    pub fn new() -> Self {
        let states = DashMap::new();
        Self { states }
    }
}

impl Default for WeightedRoundRobinBalancer {
    fn default() -> Self {
        Self::new()
    }
}

impl LoadBalancer for WeightedRoundRobinBalancer {
    /// Picks an instance using smooth weighted round robin.
    ///
    /// Only instances that are `healthy`, `enabled` and have a positive `weight` are
    /// eligible; `None` is returned when there are none. Weights are scaled by 100 and
    /// rounded to integers, so fractional weights keep their ratio (e.g. `0.29` scales to
    /// `29`, not `28`). Scaled weights are capped at `u32::MAX` so the running sums cannot
    /// overflow, even for absurd or infinite configured weights.
    ///
    /// State for instances that are not eligible in this call is discarded. This keeps the
    /// per-service map from growing without bound as instances come and go, and an instance
    /// that comes back starts from a neutral weight instead of a stale one.
    fn select(&self, service_key: &str, instances: &[Instance]) -> Option<Instance> {
        use std::collections::HashSet;

        // Cap for a single scaled weight; far below the i64 range, so sums stay safe.
        const MAX_SCALED_WEIGHT: i64 = u32::MAX as i64;

        // Eligible instances paired with their integer (x100) weights, computed once.
        let candidates: Vec<(&Instance, i64)> = instances
            .iter()
            .filter(|i| i.healthy && i.enabled && i.weight > 0.0)
            .map(|i| {
                // Round rather than truncate: 0.29 * 100.0 == 28.999... in floating point.
                // A float-to-int `as` cast saturates, and the cap bounds it further.
                let scaled = (i.weight * 100.0).round() as i64;
                (i, scaled.min(MAX_SCALED_WEIGHT))
            })
            .collect();

        // Zero or one candidate: nothing to balance.
        if candidates.len() <= 1 {
            return candidates.first().map(|&(instance, _)| instance.clone());
        }

        let total_weight: i64 = candidates.iter().map(|&(_, weight)| weight).sum();
        if total_weight == 0 {
            // Every weight rounded to zero, so there is no ratio to honour.
            return Some(candidates[0].0.clone());
        }

        // Get or create the state for this service. The returned guard is held until the
        // end of the call.
        let service_state = self
            .states
            .entry(service_key.to_string())
            .or_insert_with(DashMap::new);

        let keys: Vec<String> = candidates
            .iter()
            .map(|(instance, _)| instance.key())
            .collect();

        // Forget running weights of instances that are no longer eligible.
        let live: HashSet<&str> = keys.iter().map(String::as_str).collect();
        service_state.retain(|key, _| live.contains(key.as_str()));

        let mut max_current_weight = i64::MIN;
        let mut selected: Option<usize> = None;

        // Smooth WRR: raise every current weight by its own weight and track the maximum.
        for (idx, (&(_, weight), key)) in candidates.iter().zip(&keys).enumerate() {
            let current = service_state
                .entry(key.clone())
                .or_insert_with(|| AtomicI64::new(0));
            let new_current = current.fetch_add(weight, Ordering::SeqCst) + weight;

            // Strict `>`: on ties the earliest candidate wins.
            if new_current > max_current_weight {
                max_current_weight = new_current;
                selected = Some(idx);
            }
        }

        // Lower the winner by the total eligible weight.
        let idx = selected?;
        if let Some(current) = service_state.get(&keys[idx]) {
            current.fetch_sub(total_weight, Ordering::SeqCst);
        }

        Some(candidates[idx].0.clone())
    }

    fn reset(&self, service_key: &str) {
        self.states.remove(service_key);
    }

    fn name(&self) -> &str {
        "WeightedRoundRobin"
    }
}

/// Load balancer that picks a random eligible instance.
///
/// Instance weights play no part in the choice.
pub struct RandomBalancer;

impl RandomBalancer {
    /// Create a new random balancer
    pub fn new() -> Self {
        Self
    }
}

impl Default for RandomBalancer {
    fn default() -> Self {
        Self::new()
    }
}

impl LoadBalancer for RandomBalancer {
    /// Picks one of the healthy, enabled instances at random; weights are ignored.
    ///
    /// Returns `None` when no instance is both healthy and enabled (including an empty list).
    fn select(&self, _service_key: &str, instances: &[Instance]) -> Option<Instance> {
        let eligible: Vec<&Instance> = instances
            .iter()
            .filter(|candidate| candidate.enabled && candidate.healthy)
            .collect();

        match eligible.len() {
            0 => None,
            count => Some(eligible[rand_index(count)].clone()),
        }
    }

    /// Stateless balancer: nothing to clear.
    fn reset(&self, _service_key: &str) {}

    fn name(&self) -> &str {
        "Random"
    }
}

/// Returns a pseudo-random index in `0..max`.
///
/// The randomness comes from the standard library's `RandomState`, which is initialized
/// with random keys on each `RandomState::new()` call. Hashing with those keys gives an
/// unpredictable 64-bit value without any external crate.
///
/// A clock-based `subsec_nanos() % max` is avoided for two reasons. Clocks with coarse
/// resolution bias it heavily; a microsecond clock always yields multiples of 1000, so
/// `% 8` is always 0. It also required unwrapping `duration_since(UNIX_EPOCH)`, which
/// panics when the system clock is set before the epoch.
///
/// Not suitable for anything security-sensitive.
///
/// # Panics
/// Panics if `max` is zero.
fn rand_index(max: usize) -> usize {
    use std::collections::hash_map::RandomState;
    use std::hash::{BuildHasher, Hasher};

    assert!(max > 0, "rand_index requires a non-empty range");

    let random = RandomState::new().build_hasher().finish();
    // The remainder is < max, so narrowing back to usize is lossless.
    (random % (max as u64)) as usize
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Three instances on ports 8080/8081/8082 with weights 1/2/3.
    /// Relies on `Instance::new` producing healthy, enabled instances by default.
    fn create_test_instances() -> Vec<Instance> {
        vec![
            Instance::new("127.0.0.1", 8080).with_weight(1.0),
            Instance::new("127.0.0.1", 8081).with_weight(2.0),
            Instance::new("127.0.0.1", 8082).with_weight(3.0),
        ]
    }

    #[test]
    fn test_wrr_distribution() {
        let balancer = WeightedRoundRobinBalancer::new();
        let instances = create_test_instances();
        let service_key = "test-service";

        let mut counts = std::collections::HashMap::new();

        // Smooth WRR is deterministic. Starting from fresh state, weights 1:2:3 produce a
        // 6-pick cycle with exactly 1, 2 and 3 picks, so 600 picks give exactly 100:200:300.
        for _ in 0..600 {
            let instance = balancer
                .select(service_key, &instances)
                .expect("healthy instances must yield a selection");
            *counts.entry(instance.port).or_insert(0) += 1;
        }

        assert_eq!(counts.get(&8080).copied().unwrap_or(0), 100, "port 8080");
        assert_eq!(counts.get(&8081).copied().unwrap_or(0), 200, "port 8081");
        assert_eq!(counts.get(&8082).copied().unwrap_or(0), 300, "port 8082");
    }

    #[test]
    fn test_wrr_ignores_disabled_instance_in_rotation() {
        let balancer = WeightedRoundRobinBalancer::new();
        let mut instances = create_test_instances();
        instances[0].enabled = false;

        let mut counts = std::collections::HashMap::new();
        // Remaining weights 2:3 form a 5-pick cycle.
        for _ in 0..500 {
            let port = balancer
                .select("test-service", &instances)
                .expect("two eligible instances remain")
                .port;
            *counts.entry(port).or_insert(0) += 1;
        }

        assert_eq!(counts.get(&8080), None);
        assert_eq!(counts.get(&8081).copied(), Some(200));
        assert_eq!(counts.get(&8082).copied(), Some(300));
    }

    #[test]
    fn test_wrr_skips_unhealthy_and_zero_weight() {
        let balancer = WeightedRoundRobinBalancer::new();
        let mut instances = create_test_instances();
        instances[0].healthy = false;
        instances[1].weight = 0.0;

        for _ in 0..10 {
            let selected = balancer.select("test-service", &instances);
            assert_eq!(selected.map(|i| i.port), Some(8082));
        }
    }

    #[test]
    fn test_wrr_empty_instances() {
        let balancer = WeightedRoundRobinBalancer::new();
        let instances: Vec<Instance> = vec![];

        assert!(balancer.select("test-service", &instances).is_none());
    }

    #[test]
    fn test_wrr_single_instance() {
        let balancer = WeightedRoundRobinBalancer::new();
        let instances = vec![Instance::new("127.0.0.1", 8080).with_weight(1.0)];

        let selected = balancer.select("test-service", &instances);
        assert_eq!(selected.map(|i| i.port), Some(8080));
    }

    #[test]
    fn test_wrr_reset() {
        let balancer = WeightedRoundRobinBalancer::new();
        let instances = create_test_instances();
        let service_key = "test-service";

        // Make some selections so state exists for the service.
        for _ in 0..10 {
            balancer.select(service_key, &instances);
        }
        assert!(balancer.states.contains_key(service_key));

        balancer.reset(service_key);

        // State should be cleared.
        assert!(!balancer.states.contains_key(service_key));
    }

    #[test]
    fn test_random_balancer() {
        let balancer = RandomBalancer::new();
        let instances = create_test_instances();

        for _ in 0..20 {
            let port = balancer
                .select("test-service", &instances)
                .expect("healthy instances must yield a selection")
                .port;
            assert!([8080, 8081, 8082].contains(&port), "unexpected port {}", port);
        }
    }

    #[test]
    fn test_random_skips_ineligible_instances() {
        let balancer = RandomBalancer::new();
        let mut instances = create_test_instances();
        instances[0].healthy = false;
        instances[2].enabled = false;

        for _ in 0..10 {
            let selected = balancer.select("test-service", &instances);
            assert_eq!(selected.map(|i| i.port), Some(8081));
        }
    }

    #[test]
    fn test_random_all_unhealthy() {
        let balancer = RandomBalancer::new();
        let mut instances = create_test_instances();
        for instance in &mut instances {
            instance.healthy = false;
        }

        assert!(balancer.select("test-service", &instances).is_none());
    }

    #[test]
    fn test_random_empty_instances() {
        let balancer = RandomBalancer::new();
        let instances: Vec<Instance> = vec![];

        assert!(balancer.select("test-service", &instances).is_none());
    }

    #[test]
    fn test_rand_index_in_range() {
        for max in 1..16 {
            for _ in 0..32 {
                assert!(rand_index(max) < max);
            }
        }
    }
}