Skip to main content

grapsus_proxy/upstream/
weighted_least_conn.rs

1//! Weighted Least Connections load balancer
2//!
3//! Combines weight-based selection with connection counting. The algorithm
4//! selects the backend with the lowest ratio of active connections to weight.
5//!
6//! Score = active_connections / weight
7//!
8//! A backend with weight 200 and 10 connections (score: 0.05) is preferred
9//! over a backend with weight 100 and 6 connections (score: 0.06).
10//!
11//! This is useful when backends have different capacities - higher weight
12//! backends can handle more concurrent connections proportionally.
13
14use async_trait::async_trait;
15use std::collections::HashMap;
16use std::sync::atomic::{AtomicUsize, Ordering};
17use std::sync::Arc;
18use tokio::sync::RwLock;
19use tracing::{debug, trace, warn};
20
21use grapsus_common::errors::{GrapsusError, GrapsusResult};
22
23use super::{LoadBalancer, RequestContext, TargetSelection, UpstreamTarget};
24
25/// Configuration for Weighted Least Connections
26#[derive(Debug, Clone)]
27pub struct WeightedLeastConnConfig {
28    /// Minimum weight to prevent division by zero (default: 1)
29    pub min_weight: u32,
30    /// Tie-breaker strategy when scores are equal
31    pub tie_breaker: TieBreakerStrategy,
32}
33
34impl Default for WeightedLeastConnConfig {
35    fn default() -> Self {
36        Self {
37            min_weight: 1,
38            tie_breaker: TieBreakerStrategy::HigherWeight,
39        }
40    }
41}
42
/// Policy for choosing among backends whose scores are tied
#[derive(Debug, Clone, Copy, Default)]
pub enum TieBreakerStrategy {
    /// Pick the backend advertising the largest weight (most capacity)
    #[default]
    HigherWeight,
    /// Pick the backend currently serving the fewest connections (most headroom)
    FewerConnections,
    /// Rotate through the tied backends in turn
    RoundRobin,
}
54
55/// Weighted Least Connections load balancer
56pub struct WeightedLeastConnBalancer {
57    /// Target list
58    targets: Vec<UpstreamTarget>,
59    /// Active connections per target
60    connections: Arc<RwLock<HashMap<String, usize>>>,
61    /// Health status per target
62    health_status: Arc<RwLock<HashMap<String, bool>>>,
63    /// Round-robin counter for tie-breaking
64    tie_breaker_counter: AtomicUsize,
65    /// Configuration
66    config: WeightedLeastConnConfig,
67}
68
69impl WeightedLeastConnBalancer {
70    /// Create a new Weighted Least Connections balancer
71    pub fn new(targets: Vec<UpstreamTarget>, config: WeightedLeastConnConfig) -> Self {
72        let mut health_status = HashMap::new();
73        let mut connections = HashMap::new();
74
75        for target in &targets {
76            let addr = target.full_address();
77            health_status.insert(addr.clone(), true);
78            connections.insert(addr, 0);
79        }
80
81        Self {
82            targets,
83            connections: Arc::new(RwLock::new(connections)),
84            health_status: Arc::new(RwLock::new(health_status)),
85            tie_breaker_counter: AtomicUsize::new(0),
86            config,
87        }
88    }
89
90    /// Calculate the weighted connection score for a target
91    /// Lower score = better candidate
92    fn calculate_score(&self, connections: usize, weight: u32) -> f64 {
93        let effective_weight = weight.max(self.config.min_weight) as f64;
94        connections as f64 / effective_weight
95    }
96
97    /// Break ties between targets with the same score
98    fn break_tie<'a>(
99        &self,
100        candidates: &[(&'a UpstreamTarget, usize)],
101    ) -> Option<&'a UpstreamTarget> {
102        if candidates.is_empty() {
103            return None;
104        }
105        if candidates.len() == 1 {
106            return Some(candidates[0].0);
107        }
108
109        match self.config.tie_breaker {
110            TieBreakerStrategy::HigherWeight => candidates
111                .iter()
112                .max_by_key(|(t, _)| t.weight)
113                .map(|(t, _)| *t),
114            TieBreakerStrategy::FewerConnections => {
115                candidates.iter().min_by_key(|(_, c)| *c).map(|(t, _)| *t)
116            }
117            TieBreakerStrategy::RoundRobin => {
118                let idx =
119                    self.tie_breaker_counter.fetch_add(1, Ordering::Relaxed) % candidates.len();
120                Some(candidates[idx].0)
121            }
122        }
123    }
124}
125
126#[async_trait]
127impl LoadBalancer for WeightedLeastConnBalancer {
128    async fn select(&self, _context: Option<&RequestContext>) -> GrapsusResult<TargetSelection> {
129        trace!(
130            total_targets = self.targets.len(),
131            algorithm = "weighted_least_conn",
132            "Selecting upstream target"
133        );
134
135        let health = self.health_status.read().await;
136        let conns = self.connections.read().await;
137
138        // Calculate scores for healthy targets
139        let scored_targets: Vec<_> = self
140            .targets
141            .iter()
142            .filter(|t| *health.get(&t.full_address()).unwrap_or(&true))
143            .map(|t| {
144                let addr = t.full_address();
145                let conn_count = *conns.get(&addr).unwrap_or(&0);
146                let score = self.calculate_score(conn_count, t.weight);
147                (t, conn_count, score)
148            })
149            .collect();
150
151        drop(health);
152
153        if scored_targets.is_empty() {
154            warn!(
155                total_targets = self.targets.len(),
156                algorithm = "weighted_least_conn",
157                "No healthy upstream targets available"
158            );
159            return Err(GrapsusError::NoHealthyUpstream);
160        }
161
162        // Find minimum score
163        let min_score = scored_targets
164            .iter()
165            .map(|(_, _, s)| *s)
166            .fold(f64::INFINITY, f64::min);
167
168        // Get all targets with the minimum score (for tie-breaking)
169        let candidates: Vec<_> = scored_targets
170            .iter()
171            .filter(|(_, _, s)| (*s - min_score).abs() < f64::EPSILON)
172            .map(|(t, c, _)| (*t, *c))
173            .collect();
174
175        let target = self
176            .break_tie(&candidates)
177            .ok_or(GrapsusError::NoHealthyUpstream)?;
178
179        // Increment connection count
180        drop(conns);
181        {
182            let mut conns = self.connections.write().await;
183            *conns.entry(target.full_address()).or_insert(0) += 1;
184        }
185
186        let conn_count = *self
187            .connections
188            .read()
189            .await
190            .get(&target.full_address())
191            .unwrap_or(&0);
192        let score = self.calculate_score(conn_count, target.weight);
193
194        trace!(
195            selected_target = %target.full_address(),
196            weight = target.weight,
197            connections = conn_count,
198            score = score,
199            healthy_count = scored_targets.len(),
200            algorithm = "weighted_least_conn",
201            "Selected target via weighted least connections"
202        );
203
204        Ok(TargetSelection {
205            address: target.full_address(),
206            weight: target.weight,
207            metadata: HashMap::new(),
208        })
209    }
210
211    async fn release(&self, selection: &TargetSelection) {
212        let mut conns = self.connections.write().await;
213        if let Some(count) = conns.get_mut(&selection.address) {
214            *count = count.saturating_sub(1);
215            trace!(
216                target = %selection.address,
217                connections = *count,
218                algorithm = "weighted_least_conn",
219                "Released connection"
220            );
221        }
222    }
223
224    async fn report_health(&self, address: &str, healthy: bool) {
225        trace!(
226            target = %address,
227            healthy = healthy,
228            algorithm = "weighted_least_conn",
229            "Updating target health status"
230        );
231        self.health_status
232            .write()
233            .await
234            .insert(address.to_string(), healthy);
235    }
236
237    async fn healthy_targets(&self) -> Vec<String> {
238        self.health_status
239            .read()
240            .await
241            .iter()
242            .filter_map(|(addr, &healthy)| if healthy { Some(addr.clone()) } else { None })
243            .collect()
244    }
245}
246
#[cfg(test)]
mod tests {
    use super::*;

    /// Three backends with 50/100/200 weights (small/medium/large capacity).
    fn make_weighted_targets() -> Vec<UpstreamTarget> {
        vec![
            UpstreamTarget::new("backend-small", 8080, 50), // Low capacity
            UpstreamTarget::new("backend-medium", 8080, 100), // Medium capacity
            UpstreamTarget::new("backend-large", 8080, 200), // High capacity
        ]
    }

    #[tokio::test]
    async fn test_prefers_higher_weight_when_empty() {
        let targets = make_weighted_targets();
        let balancer = WeightedLeastConnBalancer::new(targets, WeightedLeastConnConfig::default());

        // With no connections, all have score 0; the default tie-breaker
        // prefers the highest weight.
        let selection = balancer.select(None).await.unwrap();
        assert_eq!(selection.address, "backend-large:8080");
    }

    #[tokio::test]
    async fn test_weighted_connection_ratio() {
        let targets = make_weighted_targets();
        let balancer = WeightedLeastConnBalancer::new(targets, WeightedLeastConnConfig::default());

        // Add connections proportional to weight so every ratio is equal.
        {
            let mut conns = balancer.connections.write().await;
            conns.insert("backend-small:8080".to_string(), 5); // 5/50 = 0.10
            conns.insert("backend-medium:8080".to_string(), 10); // 10/100 = 0.10
            conns.insert("backend-large:8080".to_string(), 20); // 20/200 = 0.10
        }

        // All have the same ratio; tie-breaker picks the highest weight.
        let selection = balancer.select(None).await.unwrap();
        assert_eq!(selection.address, "backend-large:8080");
    }

    #[tokio::test]
    async fn test_selects_lower_ratio() {
        let targets = make_weighted_targets();
        let balancer = WeightedLeastConnBalancer::new(targets, WeightedLeastConnConfig::default());

        // backend-large has the best (lowest) ratio.
        {
            let mut conns = balancer.connections.write().await;
            conns.insert("backend-small:8080".to_string(), 10); // 10/50 = 0.20
            conns.insert("backend-medium:8080".to_string(), 15); // 15/100 = 0.15
            conns.insert("backend-large:8080".to_string(), 20); // 20/200 = 0.10 (best)
        }

        let selection = balancer.select(None).await.unwrap();
        assert_eq!(selection.address, "backend-large:8080");
    }

    #[tokio::test]
    async fn test_selects_small_when_others_overloaded() {
        let targets = make_weighted_targets();
        let balancer = WeightedLeastConnBalancer::new(targets, WeightedLeastConnConfig::default());

        // backend-small has the best ratio despite its low weight.
        {
            let mut conns = balancer.connections.write().await;
            conns.insert("backend-small:8080".to_string(), 2); // 2/50 = 0.04 (best)
            conns.insert("backend-medium:8080".to_string(), 20); // 20/100 = 0.20
            conns.insert("backend-large:8080".to_string(), 50); // 50/200 = 0.25
        }

        let selection = balancer.select(None).await.unwrap();
        assert_eq!(selection.address, "backend-small:8080");
    }

    #[tokio::test]
    async fn test_connection_tracking() {
        let targets = vec![UpstreamTarget::new("backend", 8080, 100)];
        let balancer = WeightedLeastConnBalancer::new(targets, WeightedLeastConnConfig::default());

        // Each select increments the connection count.
        let selection1 = balancer.select(None).await.unwrap();
        let selection2 = balancer.select(None).await.unwrap();

        {
            let conns = balancer.connections.read().await;
            assert_eq!(*conns.get("backend:8080").unwrap(), 2);
        }

        // Each release decrements it back down.
        balancer.release(&selection1).await;

        {
            let conns = balancer.connections.read().await;
            assert_eq!(*conns.get("backend:8080").unwrap(), 1);
        }

        balancer.release(&selection2).await;

        {
            let conns = balancer.connections.read().await;
            assert_eq!(*conns.get("backend:8080").unwrap(), 0);
        }
    }

    #[tokio::test]
    async fn test_fewer_connections_tie_breaker() {
        // NOTE: the previous version of this test never called select() and
        // contained no assertion, so the FewerConnections strategy was never
        // actually exercised. Use different weights so the scores tie while
        // the absolute connection counts differ — the two strategies would
        // then disagree, making the assertion meaningful.
        let targets = vec![
            UpstreamTarget::new("backend-a", 8080, 100),
            UpstreamTarget::new("backend-b", 8080, 200),
        ];
        let config = WeightedLeastConnConfig {
            min_weight: 1,
            tie_breaker: TieBreakerStrategy::FewerConnections,
        };
        let balancer = WeightedLeastConnBalancer::new(targets, config);

        // Tied scores: 5/100 == 10/200 == 0.05.
        {
            let mut conns = balancer.connections.write().await;
            conns.insert("backend-a:8080".to_string(), 5);
            conns.insert("backend-b:8080".to_string(), 10);
        }

        // FewerConnections must pick backend-a (5 < 10); the default
        // HigherWeight strategy would have picked backend-b instead.
        let selection = balancer.select(None).await.unwrap();
        assert_eq!(selection.address, "backend-a:8080");
    }

    #[tokio::test]
    async fn test_respects_health_status() {
        let targets = make_weighted_targets();
        let balancer = WeightedLeastConnBalancer::new(targets, WeightedLeastConnConfig::default());

        // Mark the large backend as unhealthy.
        balancer.report_health("backend-large:8080", false).await;

        // It must never be selected while unhealthy.
        for _ in 0..10 {
            let selection = balancer.select(None).await.unwrap();
            assert_ne!(selection.address, "backend-large:8080");
        }
    }
}