batata_client/naming/
balancer.rs1use std::sync::atomic::{AtomicI64, Ordering};
4
5use dashmap::DashMap;
6
7use crate::api::naming::Instance;
8
9pub trait LoadBalancer: Send + Sync {
11 fn select(&self, service_key: &str, instances: &[Instance]) -> Option<Instance>;
20
21 fn reset(&self, service_key: &str);
23
24 fn name(&self) -> &str;
26}
27
28pub struct WeightedRoundRobinBalancer {
37 states: DashMap<String, DashMap<String, AtomicI64>>,
39}
40
41impl WeightedRoundRobinBalancer {
42 pub fn new() -> Self {
44 Self {
45 states: DashMap::new(),
46 }
47 }
48}
49
50impl Default for WeightedRoundRobinBalancer {
51 fn default() -> Self {
52 Self::new()
53 }
54}
55
56impl LoadBalancer for WeightedRoundRobinBalancer {
57 fn select(&self, service_key: &str, instances: &[Instance]) -> Option<Instance> {
58 if instances.is_empty() {
59 return None;
60 }
61
62 let healthy_instances: Vec<&Instance> = instances
64 .iter()
65 .filter(|i| i.healthy && i.enabled && i.weight > 0.0)
66 .collect();
67
68 if healthy_instances.is_empty() {
69 return None;
70 }
71
72 if healthy_instances.len() == 1 {
74 return Some(healthy_instances[0].clone());
75 }
76
77 let service_state = self
79 .states
80 .entry(service_key.to_string())
81 .or_insert_with(DashMap::new);
82
83 let total_weight: i64 = healthy_instances
85 .iter()
86 .map(|i| (i.weight * 100.0) as i64)
87 .sum();
88
89 if total_weight == 0 {
90 return Some(healthy_instances[0].clone());
91 }
92
93 let mut max_current_weight: i64 = i64::MIN;
94 let mut selected: Option<&Instance> = None;
95 let mut selected_key: Option<String> = None;
96
97 for instance in &healthy_instances {
99 let key = instance.key();
100 let weight = (instance.weight * 100.0) as i64;
101
102 let current = service_state
104 .entry(key.clone())
105 .or_insert_with(|| AtomicI64::new(0));
106
107 let new_current = current.fetch_add(weight, Ordering::SeqCst) + weight;
109
110 if new_current > max_current_weight {
112 max_current_weight = new_current;
113 selected = Some(instance);
114 selected_key = Some(key);
115 }
116 }
117
118 if let Some(key) = selected_key {
120 if let Some(current) = service_state.get(&key) {
121 current.fetch_sub(total_weight, Ordering::SeqCst);
122 }
123 }
124
125 selected.cloned()
126 }
127
128 fn reset(&self, service_key: &str) {
129 self.states.remove(service_key);
130 }
131
132 fn name(&self) -> &str {
133 "WeightedRoundRobin"
134 }
135}
136
/// Stateless balancer that picks one eligible instance at random.
///
/// Instance weights are not used to bias the pick.
#[derive(Default)]
pub struct RandomBalancer;

impl RandomBalancer {
    /// Creates a random balancer; it carries no state.
    pub fn new() -> Self {
        Self::default()
    }
}
154
155impl LoadBalancer for RandomBalancer {
156 fn select(&self, _service_key: &str, instances: &[Instance]) -> Option<Instance> {
157 if instances.is_empty() {
158 return None;
159 }
160
161 let healthy_instances: Vec<&Instance> = instances
163 .iter()
164 .filter(|i| i.healthy && i.enabled)
165 .collect();
166
167 if healthy_instances.is_empty() {
168 return None;
169 }
170
171 let index = rand_index(healthy_instances.len());
172 Some(healthy_instances[index].clone())
173 }
174
175 fn reset(&self, _service_key: &str) {
176 }
178
179 fn name(&self) -> &str {
180 "Random"
181 }
182}
183
/// Returns a pseudo-random index in `0..max`.
///
/// The value is the hash of a process-wide call counter under a freshly
/// seeded std `RandomState`, so successive calls differ even when they land
/// in the same clock tick. Good enough for spreading load; not suitable for
/// anything security-sensitive.
///
/// # Panics
///
/// Panics if `max` is `0`.
fn rand_index(max: usize) -> usize {
    use std::collections::hash_map::RandomState;
    use std::hash::{BuildHasher, Hasher};
    use std::sync::atomic::{AtomicU64, Ordering};

    // Why not the clock: on platforms with coarse clock resolution the
    // sub-second nanoseconds are always a multiple of 100 or 1000, which made
    // `nanos % max` always 0 for many `max` (2, 4, 5, 10, ...). Back-to-back
    // calls within one tick also returned the same index.
    static CALLS: AtomicU64 = AtomicU64::new(0);

    assert!(max > 0, "rand_index: max must be greater than zero");

    let mut hasher = RandomState::new().build_hasher();
    hasher.write_u64(CALLS.fetch_add(1, Ordering::Relaxed));
    // Modulo bias is negligible while `max` is small relative to 2^64.
    (hasher.finish() % max as u64) as usize
}
193
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Service key shared by every test.
    const SERVICE: &str = "test-service";

    /// Three instances on 127.0.0.1: ports 8080/8081/8082 with weights 1/2/3.
    fn create_test_instances() -> Vec<Instance> {
        [(8080, 1.0), (8081, 2.0), (8082, 3.0)]
            .iter()
            .map(|&(port, weight)| Instance::new("127.0.0.1", port).with_weight(weight))
            .collect()
    }

    #[test]
    fn test_wrr_distribution() {
        let balancer = WeightedRoundRobinBalancer::new();
        let instances = create_test_instances();

        // 600 picks over weights 1:2:3 should land close to 100/200/300.
        let mut counts = HashMap::new();
        for picked in (0..600).filter_map(|_| balancer.select(SERVICE, &instances)) {
            *counts.entry(picked.port).or_insert(0) += 1;
        }

        let expectations = [(8080, 80..=120), (8081, 180..=220), (8082, 280..=320)];
        for (port, allowed) in expectations {
            let seen = counts.get(&port).copied().unwrap_or(0);
            assert!(allowed.contains(&seen), "port {} count: {}", port, seen);
        }
    }

    #[test]
    fn test_wrr_empty_instances() {
        let balancer = WeightedRoundRobinBalancer::new();
        assert!(balancer.select(SERVICE, &[]).is_none());
    }

    #[test]
    fn test_wrr_single_instance() {
        let balancer = WeightedRoundRobinBalancer::new();
        let only = [Instance::new("127.0.0.1", 8080).with_weight(1.0)];

        let picked = balancer.select(SERVICE, &only);
        assert!(picked.is_some());
        assert_eq!(picked.unwrap().port, 8080);
    }

    #[test]
    fn test_wrr_reset() {
        let balancer = WeightedRoundRobinBalancer::new();
        let instances = create_test_instances();

        for _ in 0..10 {
            let _ = balancer.select(SERVICE, &instances);
        }
        balancer.reset(SERVICE);

        assert!(!balancer.states.contains_key(SERVICE));
    }

    #[test]
    fn test_random_balancer() {
        let balancer = RandomBalancer::new();
        let picked = balancer.select(SERVICE, &create_test_instances());
        assert!(picked.is_some());
    }

    #[test]
    fn test_random_empty_instances() {
        assert!(RandomBalancer::new().select(SERVICE, &[]).is_none());
    }
}