1use async_trait::async_trait;
15use std::collections::HashMap;
16use std::sync::atomic::{AtomicUsize, Ordering};
17use std::sync::Arc;
18use tokio::sync::RwLock;
19use tracing::{debug, trace, warn};
20
21use grapsus_common::errors::{GrapsusError, GrapsusResult};
22
23use super::{LoadBalancer, RequestContext, TargetSelection, UpstreamTarget};
24
/// Tuning knobs for [`WeightedLeastConnBalancer`].
#[derive(Debug, Clone)]
pub struct WeightedLeastConnConfig {
    /// Floor applied to each target's weight when scoring it
    /// (`weight.max(min_weight)`): a target whose weight is below this value
    /// is scored as if its weight were exactly `min_weight`. The `Default`
    /// implementation sets this to 1.
    pub min_weight: u32,
    /// Strategy used to choose among healthy targets that share the lowest
    /// `connections / weight` score.
    pub tie_breaker: TieBreakerStrategy,
}
33
34impl Default for WeightedLeastConnConfig {
35 fn default() -> Self {
36 Self {
37 min_weight: 1,
38 tie_breaker: TieBreakerStrategy::HigherWeight,
39 }
40 }
41}
42
/// How [`WeightedLeastConnBalancer`] chooses between healthy targets that
/// share the lowest `connections / weight` score.
///
/// Derives `PartialEq`/`Eq`/`Hash` so strategies can be compared (e.g. in
/// configuration checks and tests) and used as map keys.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub enum TieBreakerStrategy {
    /// Prefer the tied candidate with the largest configured weight.
    /// This is the default.
    #[default]
    HigherWeight,
    /// Prefer the tied candidate with the fewest tracked connections.
    FewerConnections,
    /// Rotate through the tied candidates using a counter shared by all
    /// selections.
    RoundRobin,
}
54
/// Load balancer that routes each request to the healthy target with the
/// lowest ratio of tracked connections to weight.
pub struct WeightedLeastConnBalancer {
    /// Upstream targets, in the order they were passed to `new`.
    targets: Vec<UpstreamTarget>,
    /// Connection count per target, keyed by `full_address()`; incremented by
    /// `select` and decremented (saturating) by `release`.
    connections: Arc<RwLock<HashMap<String, usize>>>,
    /// Health flag per target, keyed by `full_address()`; `select` treats an
    /// address with no entry as healthy.
    health_status: Arc<RwLock<HashMap<String, bool>>>,
    /// Counter advanced by the round-robin tie-breaker.
    tie_breaker_counter: AtomicUsize,
    /// Scoring and tie-breaking configuration.
    config: WeightedLeastConnConfig,
}
68
69impl WeightedLeastConnBalancer {
70 pub fn new(targets: Vec<UpstreamTarget>, config: WeightedLeastConnConfig) -> Self {
72 let mut health_status = HashMap::new();
73 let mut connections = HashMap::new();
74
75 for target in &targets {
76 let addr = target.full_address();
77 health_status.insert(addr.clone(), true);
78 connections.insert(addr, 0);
79 }
80
81 Self {
82 targets,
83 connections: Arc::new(RwLock::new(connections)),
84 health_status: Arc::new(RwLock::new(health_status)),
85 tie_breaker_counter: AtomicUsize::new(0),
86 config,
87 }
88 }
89
90 fn calculate_score(&self, connections: usize, weight: u32) -> f64 {
93 let effective_weight = weight.max(self.config.min_weight) as f64;
94 connections as f64 / effective_weight
95 }
96
97 fn break_tie<'a>(
99 &self,
100 candidates: &[(&'a UpstreamTarget, usize)],
101 ) -> Option<&'a UpstreamTarget> {
102 if candidates.is_empty() {
103 return None;
104 }
105 if candidates.len() == 1 {
106 return Some(candidates[0].0);
107 }
108
109 match self.config.tie_breaker {
110 TieBreakerStrategy::HigherWeight => candidates
111 .iter()
112 .max_by_key(|(t, _)| t.weight)
113 .map(|(t, _)| *t),
114 TieBreakerStrategy::FewerConnections => {
115 candidates.iter().min_by_key(|(_, c)| *c).map(|(t, _)| *t)
116 }
117 TieBreakerStrategy::RoundRobin => {
118 let idx =
119 self.tie_breaker_counter.fetch_add(1, Ordering::Relaxed) % candidates.len();
120 Some(candidates[idx].0)
121 }
122 }
123 }
124}
125
126#[async_trait]
127impl LoadBalancer for WeightedLeastConnBalancer {
128 async fn select(&self, _context: Option<&RequestContext>) -> GrapsusResult<TargetSelection> {
129 trace!(
130 total_targets = self.targets.len(),
131 algorithm = "weighted_least_conn",
132 "Selecting upstream target"
133 );
134
135 let health = self.health_status.read().await;
136 let conns = self.connections.read().await;
137
138 let scored_targets: Vec<_> = self
140 .targets
141 .iter()
142 .filter(|t| *health.get(&t.full_address()).unwrap_or(&true))
143 .map(|t| {
144 let addr = t.full_address();
145 let conn_count = *conns.get(&addr).unwrap_or(&0);
146 let score = self.calculate_score(conn_count, t.weight);
147 (t, conn_count, score)
148 })
149 .collect();
150
151 drop(health);
152
153 if scored_targets.is_empty() {
154 warn!(
155 total_targets = self.targets.len(),
156 algorithm = "weighted_least_conn",
157 "No healthy upstream targets available"
158 );
159 return Err(GrapsusError::NoHealthyUpstream);
160 }
161
162 let min_score = scored_targets
164 .iter()
165 .map(|(_, _, s)| *s)
166 .fold(f64::INFINITY, f64::min);
167
168 let candidates: Vec<_> = scored_targets
170 .iter()
171 .filter(|(_, _, s)| (*s - min_score).abs() < f64::EPSILON)
172 .map(|(t, c, _)| (*t, *c))
173 .collect();
174
175 let target = self
176 .break_tie(&candidates)
177 .ok_or(GrapsusError::NoHealthyUpstream)?;
178
179 drop(conns);
181 {
182 let mut conns = self.connections.write().await;
183 *conns.entry(target.full_address()).or_insert(0) += 1;
184 }
185
186 let conn_count = *self
187 .connections
188 .read()
189 .await
190 .get(&target.full_address())
191 .unwrap_or(&0);
192 let score = self.calculate_score(conn_count, target.weight);
193
194 trace!(
195 selected_target = %target.full_address(),
196 weight = target.weight,
197 connections = conn_count,
198 score = score,
199 healthy_count = scored_targets.len(),
200 algorithm = "weighted_least_conn",
201 "Selected target via weighted least connections"
202 );
203
204 Ok(TargetSelection {
205 address: target.full_address(),
206 weight: target.weight,
207 metadata: HashMap::new(),
208 })
209 }
210
211 async fn release(&self, selection: &TargetSelection) {
212 let mut conns = self.connections.write().await;
213 if let Some(count) = conns.get_mut(&selection.address) {
214 *count = count.saturating_sub(1);
215 trace!(
216 target = %selection.address,
217 connections = *count,
218 algorithm = "weighted_least_conn",
219 "Released connection"
220 );
221 }
222 }
223
224 async fn report_health(&self, address: &str, healthy: bool) {
225 trace!(
226 target = %address,
227 healthy = healthy,
228 algorithm = "weighted_least_conn",
229 "Updating target health status"
230 );
231 self.health_status
232 .write()
233 .await
234 .insert(address.to_string(), healthy);
235 }
236
237 async fn healthy_targets(&self) -> Vec<String> {
238 self.health_status
239 .read()
240 .await
241 .iter()
242 .filter_map(|(addr, &healthy)| if healthy { Some(addr.clone()) } else { None })
243 .collect()
244 }
245}
246
#[cfg(test)]
mod tests {
    use super::*;

    /// Three targets whose weights are in a 1:2:4 ratio.
    fn make_weighted_targets() -> Vec<UpstreamTarget> {
        vec![
            UpstreamTarget::new("backend-small", 8080, 50),
            UpstreamTarget::new("backend-medium", 8080, 100),
            UpstreamTarget::new("backend-large", 8080, 200),
        ]
    }

    /// Overwrites the tracked connection counts for the given addresses.
    async fn set_connections(balancer: &WeightedLeastConnBalancer, counts: &[(&str, usize)]) {
        let mut conns = balancer.connections.write().await;
        for &(addr, count) in counts {
            conns.insert(addr.to_string(), count);
        }
    }

    #[tokio::test]
    async fn test_prefers_higher_weight_when_empty() {
        let targets = make_weighted_targets();
        let balancer = WeightedLeastConnBalancer::new(targets, WeightedLeastConnConfig::default());

        // All scores are 0, so the HigherWeight tie-breaker picks the largest.
        let selection = balancer.select(None).await.unwrap();
        assert_eq!(selection.address, "backend-large:8080");
    }

    #[tokio::test]
    async fn test_weighted_connection_ratio() {
        let targets = make_weighted_targets();
        let balancer = WeightedLeastConnBalancer::new(targets, WeightedLeastConnConfig::default());

        // 5/50 == 10/100 == 20/200: an exact tie broken by weight.
        set_connections(
            &balancer,
            &[
                ("backend-small:8080", 5),
                ("backend-medium:8080", 10),
                ("backend-large:8080", 20),
            ],
        )
        .await;

        let selection = balancer.select(None).await.unwrap();
        assert_eq!(selection.address, "backend-large:8080");
    }

    #[tokio::test]
    async fn test_selects_lower_ratio() {
        let targets = make_weighted_targets();
        let balancer = WeightedLeastConnBalancer::new(targets, WeightedLeastConnConfig::default());

        // Scores: 0.2, 0.15, 0.1 — the large backend has the lowest ratio.
        set_connections(
            &balancer,
            &[
                ("backend-small:8080", 10),
                ("backend-medium:8080", 15),
                ("backend-large:8080", 20),
            ],
        )
        .await;

        let selection = balancer.select(None).await.unwrap();
        assert_eq!(selection.address, "backend-large:8080");
    }

    #[tokio::test]
    async fn test_selects_small_when_others_overloaded() {
        let targets = make_weighted_targets();
        let balancer = WeightedLeastConnBalancer::new(targets, WeightedLeastConnConfig::default());

        // Scores: 0.04, 0.2, 0.25 — the small backend wins despite its weight.
        set_connections(
            &balancer,
            &[
                ("backend-small:8080", 2),
                ("backend-medium:8080", 20),
                ("backend-large:8080", 50),
            ],
        )
        .await;

        let selection = balancer.select(None).await.unwrap();
        assert_eq!(selection.address, "backend-small:8080");
    }

    #[tokio::test]
    async fn test_connection_tracking() {
        let targets = vec![UpstreamTarget::new("backend", 8080, 100)];
        let balancer = WeightedLeastConnBalancer::new(targets, WeightedLeastConnConfig::default());

        let selection1 = balancer.select(None).await.unwrap();
        let selection2 = balancer.select(None).await.unwrap();
        assert_eq!(*balancer.connections.read().await.get("backend:8080").unwrap(), 2);

        balancer.release(&selection1).await;
        assert_eq!(*balancer.connections.read().await.get("backend:8080").unwrap(), 1);

        balancer.release(&selection2).await;
        assert_eq!(*balancer.connections.read().await.get("backend:8080").unwrap(), 0);
    }

    #[tokio::test]
    async fn test_fewer_connections_tie_breaker() {
        // Different weights let scores tie while connection counts differ,
        // which is the only situation where this strategy matters.
        let targets = vec![
            UpstreamTarget::new("backend-a", 8080, 100),
            UpstreamTarget::new("backend-b", 8080, 200),
        ];
        let config = WeightedLeastConnConfig {
            min_weight: 1,
            tie_breaker: TieBreakerStrategy::FewerConnections,
        };
        let balancer = WeightedLeastConnBalancer::new(targets, config);

        // Distinct scores (0.06 vs 0.015): the lower score wins outright.
        set_connections(&balancer, &[("backend-a:8080", 6), ("backend-b:8080", 3)]).await;
        let selection = balancer.select(None).await.unwrap();
        assert_eq!(selection.address, "backend-b:8080");

        // Tied scores (5/100 == 10/200): fewer connections wins, even though
        // backend-b has the higher weight.
        set_connections(&balancer, &[("backend-a:8080", 5), ("backend-b:8080", 10)]).await;
        let selection = balancer.select(None).await.unwrap();
        assert_eq!(selection.address, "backend-a:8080");
    }

    #[tokio::test]
    async fn test_round_robin_tie_breaker() {
        let targets = vec![
            UpstreamTarget::new("backend-a", 8080, 100),
            UpstreamTarget::new("backend-b", 8080, 100),
        ];
        let config = WeightedLeastConnConfig {
            min_weight: 1,
            tie_breaker: TieBreakerStrategy::RoundRobin,
        };
        let balancer = WeightedLeastConnBalancer::new(targets, config);

        let first = balancer.select(None).await.unwrap();
        // Restore the tie so the second selection goes through the tie-breaker.
        set_connections(&balancer, &[("backend-a:8080", 0), ("backend-b:8080", 0)]).await;
        let second = balancer.select(None).await.unwrap();

        assert_eq!(first.address, "backend-a:8080");
        assert_eq!(second.address, "backend-b:8080");
    }

    #[tokio::test]
    async fn test_respects_health_status() {
        let targets = make_weighted_targets();
        let balancer = WeightedLeastConnBalancer::new(targets, WeightedLeastConnConfig::default());

        balancer.report_health("backend-large:8080", false).await;

        for _ in 0..10 {
            let selection = balancer.select(None).await.unwrap();
            assert_ne!(selection.address, "backend-large:8080");
        }
    }

    #[tokio::test]
    async fn test_errors_when_no_healthy_targets() {
        let targets = vec![UpstreamTarget::new("backend", 8080, 100)];
        let balancer = WeightedLeastConnBalancer::new(targets, WeightedLeastConnConfig::default());

        balancer.report_health("backend:8080", false).await;

        assert!(matches!(
            balancer.select(None).await,
            Err(GrapsusError::NoHealthyUpstream)
        ));
        assert!(balancer.healthy_targets().await.is_empty());
    }
}