batata_client/naming/
balancer.rs

1//! Load balancer implementations for service instance selection
2
3use std::sync::atomic::{AtomicI64, Ordering};
4
5use dashmap::DashMap;
6
7use crate::api::naming::Instance;
8
9/// Load balancer trait for selecting service instances
10pub trait LoadBalancer: Send + Sync {
11    /// Select one instance from the given list
12    ///
13    /// # Arguments
14    /// * `service_key` - Unique key for the service (used for stateful balancers)
15    /// * `instances` - List of available instances to select from
16    ///
17    /// # Returns
18    /// Selected instance, or None if no suitable instance found
19    fn select(&self, service_key: &str, instances: &[Instance]) -> Option<Instance>;
20
21    /// Reset balancer state for a service (called when instances change)
22    fn reset(&self, service_key: &str);
23
24    /// Get balancer name
25    fn name(&self) -> &str;
26}
27
28/// Smooth Weighted Round Robin (WRR) load balancer
29///
30/// This balancer distributes traffic proportionally to instance weights.
31/// For example, with weights [1, 2, 3], traffic is distributed as 1:2:3.
32///
33/// The algorithm maintains current weights for each instance and selects
34/// the instance with the highest current weight, then reduces its weight
35/// by the total weight of all instances.
36pub struct WeightedRoundRobinBalancer {
37    /// Current weight state per service: service_key -> (instance_key -> current_weight)
38    states: DashMap<String, DashMap<String, AtomicI64>>,
39}
40
41impl WeightedRoundRobinBalancer {
42    /// Create a new WRR balancer
43    pub fn new() -> Self {
44        Self {
45            states: DashMap::new(),
46        }
47    }
48}
49
50impl Default for WeightedRoundRobinBalancer {
51    fn default() -> Self {
52        Self::new()
53    }
54}
55
56impl LoadBalancer for WeightedRoundRobinBalancer {
57    fn select(&self, service_key: &str, instances: &[Instance]) -> Option<Instance> {
58        if instances.is_empty() {
59            return None;
60        }
61
62        // Filter healthy and enabled instances
63        let healthy_instances: Vec<&Instance> = instances
64            .iter()
65            .filter(|i| i.healthy && i.enabled && i.weight > 0.0)
66            .collect();
67
68        if healthy_instances.is_empty() {
69            return None;
70        }
71
72        // If only one instance, return it directly
73        if healthy_instances.len() == 1 {
74            return Some(healthy_instances[0].clone());
75        }
76
77        // Get or create state for this service
78        let service_state = self
79            .states
80            .entry(service_key.to_string())
81            .or_insert_with(DashMap::new);
82
83        // Calculate total weight (multiply by 100 to handle floating point precision)
84        let total_weight: i64 = healthy_instances
85            .iter()
86            .map(|i| (i.weight * 100.0) as i64)
87            .sum();
88
89        if total_weight == 0 {
90            return Some(healthy_instances[0].clone());
91        }
92
93        let mut max_current_weight: i64 = i64::MIN;
94        let mut selected: Option<&Instance> = None;
95        let mut selected_key: Option<String> = None;
96
97        // Smooth WRR algorithm
98        for instance in &healthy_instances {
99            let key = instance.key();
100            let weight = (instance.weight * 100.0) as i64;
101
102            // Get or create current weight for this instance
103            let current = service_state
104                .entry(key.clone())
105                .or_insert_with(|| AtomicI64::new(0));
106
107            // Add effective weight to current weight
108            let new_current = current.fetch_add(weight, Ordering::SeqCst) + weight;
109
110            // Select instance with highest current weight
111            if new_current > max_current_weight {
112                max_current_weight = new_current;
113                selected = Some(instance);
114                selected_key = Some(key);
115            }
116        }
117
118        // Subtract total weight from selected instance
119        if let Some(key) = selected_key {
120            if let Some(current) = service_state.get(&key) {
121                current.fetch_sub(total_weight, Ordering::SeqCst);
122            }
123        }
124
125        selected.cloned()
126    }
127
128    fn reset(&self, service_key: &str) {
129        self.states.remove(service_key);
130    }
131
132    fn name(&self) -> &str {
133        "WeightedRoundRobin"
134    }
135}
136
/// Random load balancer (original behavior)
///
/// Selects instances randomly without considering weights. The type holds no
/// state, so it is trivially copyable and printable for diagnostics.
#[derive(Debug, Clone, Copy)]
pub struct RandomBalancer;
141
142impl RandomBalancer {
143    /// Create a new random balancer
144    pub fn new() -> Self {
145        Self
146    }
147}
148
149impl Default for RandomBalancer {
150    fn default() -> Self {
151        Self::new()
152    }
153}
154
155impl LoadBalancer for RandomBalancer {
156    fn select(&self, _service_key: &str, instances: &[Instance]) -> Option<Instance> {
157        if instances.is_empty() {
158            return None;
159        }
160
161        // Filter healthy and enabled instances
162        let healthy_instances: Vec<&Instance> = instances
163            .iter()
164            .filter(|i| i.healthy && i.enabled)
165            .collect();
166
167        if healthy_instances.is_empty() {
168            return None;
169        }
170
171        let index = rand_index(healthy_instances.len());
172        Some(healthy_instances[index].clone())
173    }
174
175    fn reset(&self, _service_key: &str) {
176        // No state to reset
177    }
178
179    fn name(&self) -> &str {
180        "Random"
181    }
182}
183
/// Returns a pseudo-random index in `0..max`.
///
/// The value comes from a freshly seeded standard-library `RandomState`
/// hasher rather than from the wall clock: clock readings with coarse
/// granularity are always multiples of the tick size, which makes
/// `nanos % max` collapse to the same index for many values of `max`, and
/// reading the clock can fail when it is set before the Unix epoch.
///
/// For `max` of `0` or `1` the function returns `0` instead of panicking;
/// callers must still not index an empty collection with the result.
///
/// This is not cryptographically secure; it is only meant for spreading load.
fn rand_index(max: usize) -> usize {
    use std::collections::hash_map::RandomState;
    use std::hash::{BuildHasher, Hasher};

    if max <= 1 {
        return 0;
    }

    // Every `RandomState::new()` carries different keys, so finishing an
    // empty hash yields a different, well-mixed 64-bit value per call.
    let random = RandomState::new().build_hasher().finish();

    // `usize` is at most 64 bits wide, and the remainder is below `max`, so
    // neither cast can lose information.
    (random % max as u64) as usize
}
193
#[cfg(test)]
mod tests {
    use super::*;

    /// Three local instances on ports 8080-8082 weighted 1, 2 and 3.
    fn create_test_instances() -> Vec<Instance> {
        let light = Instance::new("127.0.0.1", 8080).with_weight(1.0);
        let medium = Instance::new("127.0.0.1", 8081).with_weight(2.0);
        let heavy = Instance::new("127.0.0.1", 8082).with_weight(3.0);
        vec![light, medium, heavy]
    }

    #[test]
    fn test_wrr_distribution() {
        let service_key = "test-service";
        let instances = create_test_instances();
        let balancer = WeightedRoundRobinBalancer::new();

        // Tally how often each port is chosen over 600 selections; weights
        // 1:2:3 should give roughly 100:200:300.
        let mut counts: std::collections::HashMap<_, u32> = std::collections::HashMap::new();
        (0..600)
            .filter_map(|_| balancer.select(service_key, &instances))
            .for_each(|picked| *counts.entry(picked.port).or_insert(0) += 1);

        let port_8080 = counts.get(&8080).copied().unwrap_or(0);
        let port_8081 = counts.get(&8081).copied().unwrap_or(0);
        let port_8082 = counts.get(&8082).copied().unwrap_or(0);

        // Each count must land within 20 of its ideal share.
        assert!(
            (80..=120).contains(&port_8080),
            "port 8080 count: {}",
            port_8080
        );
        assert!(
            (180..=220).contains(&port_8081),
            "port 8081 count: {}",
            port_8081
        );
        assert!(
            (280..=320).contains(&port_8082),
            "port 8082 count: {}",
            port_8082
        );
    }

    #[test]
    fn test_wrr_empty_instances() {
        let no_instances: Vec<Instance> = Vec::new();
        let balancer = WeightedRoundRobinBalancer::new();

        let picked = balancer.select("test-service", &no_instances);
        assert!(picked.is_none());
    }

    #[test]
    fn test_wrr_single_instance() {
        let only = vec![Instance::new("127.0.0.1", 8080).with_weight(1.0)];
        let balancer = WeightedRoundRobinBalancer::new();

        let picked = balancer.select("test-service", &only);
        assert!(picked.is_some());
        assert_eq!(picked.unwrap().port, 8080);
    }

    #[test]
    fn test_wrr_reset() {
        let service_key = "test-service";
        let instances = create_test_instances();
        let balancer = WeightedRoundRobinBalancer::new();

        // Build up some per-service state first.
        (0..10).for_each(|_| {
            balancer.select(service_key, &instances);
        });

        balancer.reset(service_key);

        // After the reset no entry may remain for the service.
        assert!(!balancer.states.contains_key(service_key));
    }

    #[test]
    fn test_random_balancer() {
        let instances = create_test_instances();
        let balancer = RandomBalancer::new();

        // Any instance is acceptable, as long as one is returned.
        assert!(balancer.select("test-service", &instances).is_some());
    }

    #[test]
    fn test_random_empty_instances() {
        let no_instances: Vec<Instance> = Vec::new();
        let balancer = RandomBalancer::new();

        let picked = balancer.select("test-service", &no_instances);
        assert!(picked.is_none());
    }
}