amaters_net/
balancer.rs

1//! Load balancing strategies for distributing requests across endpoints
2//!
3//! Provides multiple strategies for selecting endpoints based on different criteria.
4
5use crate::error::{NetError, NetResult};
6use parking_lot::RwLock;
7use std::collections::HashMap;
8use std::sync::Arc;
9use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
10
/// String key that names an endpoint (used for lookup, removal and affinity).
pub type EndpointId = String;

/// Relative weight of an endpoint, consulted by weighted balancing.
pub type Weight = u32;
16
17/// Endpoint information
18#[derive(Debug, Clone)]
19pub struct Endpoint {
20    /// Unique endpoint identifier
21    pub id: EndpointId,
22    /// Endpoint address (e.g., "localhost:50051")
23    pub address: String,
24    /// Endpoint weight (for weighted balancing)
25    pub weight: Weight,
26    /// Number of active connections
27    pub active_connections: Arc<AtomicUsize>,
28    /// Total requests handled
29    pub total_requests: Arc<AtomicU64>,
30    /// Whether endpoint is healthy
31    pub healthy: Arc<parking_lot::RwLock<bool>>,
32}
33
34impl Endpoint {
35    /// Create a new endpoint
36    pub fn new(id: EndpointId, address: String) -> Self {
37        Self::with_weight(id, address, 1)
38    }
39
40    /// Create a new endpoint with weight
41    pub fn with_weight(id: EndpointId, address: String, weight: Weight) -> Self {
42        Self {
43            id,
44            address,
45            weight,
46            active_connections: Arc::new(AtomicUsize::new(0)),
47            total_requests: Arc::new(AtomicU64::new(0)),
48            healthy: Arc::new(parking_lot::RwLock::new(true)),
49        }
50    }
51
52    /// Check if endpoint is healthy
53    pub fn is_healthy(&self) -> bool {
54        *self.healthy.read()
55    }
56
57    /// Mark endpoint as healthy
58    pub fn mark_healthy(&self) {
59        *self.healthy.write() = true;
60    }
61
62    /// Mark endpoint as unhealthy
63    pub fn mark_unhealthy(&self) {
64        *self.healthy.write() = false;
65    }
66
67    /// Get active connection count
68    pub fn active_connections(&self) -> usize {
69        self.active_connections.load(Ordering::Relaxed)
70    }
71
72    /// Increment active connections
73    pub fn increment_connections(&self) {
74        self.active_connections.fetch_add(1, Ordering::Relaxed);
75        self.total_requests.fetch_add(1, Ordering::Relaxed);
76    }
77
78    /// Decrement active connections
79    pub fn decrement_connections(&self) {
80        self.active_connections.fetch_sub(1, Ordering::Relaxed);
81    }
82
83    /// Get total requests handled
84    pub fn total_requests(&self) -> u64 {
85        self.total_requests.load(Ordering::Relaxed)
86    }
87}
88
/// Policy used by the load balancer to choose among healthy endpoints.
///
/// Derives `Hash` in addition to `Eq` so strategies can be used as map or set
/// keys (e.g. per-strategy metrics).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BalancingStrategy {
    /// Cycle through the healthy endpoints in order.
    RoundRobin,
    /// Choose the healthy endpoint with the fewest active connections.
    LeastConnections,
    /// Distribute selections in proportion to each endpoint's weight.
    Weighted,
}
99
100/// Connection affinity (sticky sessions)
101#[derive(Debug, Clone)]
102pub struct Affinity {
103    /// Session ID to endpoint mapping
104    sessions: Arc<RwLock<HashMap<String, EndpointId>>>,
105}
106
107impl Affinity {
108    /// Create new affinity tracker
109    pub fn new() -> Self {
110        Self {
111            sessions: Arc::new(RwLock::new(HashMap::new())),
112        }
113    }
114
115    /// Get endpoint for session
116    pub fn get(&self, session_id: &str) -> Option<EndpointId> {
117        self.sessions.read().get(session_id).cloned()
118    }
119
120    /// Set endpoint for session
121    pub fn set(&self, session_id: String, endpoint_id: EndpointId) {
122        self.sessions.write().insert(session_id, endpoint_id);
123    }
124
125    /// Remove session
126    pub fn remove(&self, session_id: &str) {
127        self.sessions.write().remove(session_id);
128    }
129
130    /// Clear all sessions
131    pub fn clear(&self) {
132        self.sessions.write().clear();
133    }
134}
135
136impl Default for Affinity {
137    fn default() -> Self {
138        Self::new()
139    }
140}
141
142/// Load balancer for distributing requests across endpoints
143#[derive(Debug)]
144pub struct LoadBalancer {
145    /// Load balancing strategy
146    strategy: BalancingStrategy,
147    /// Available endpoints
148    endpoints: Arc<RwLock<Vec<Arc<Endpoint>>>>,
149    /// Current round-robin index
150    round_robin_index: AtomicUsize,
151    /// Connection affinity
152    affinity: Affinity,
153}
154
155impl LoadBalancer {
156    /// Create a new load balancer with the given strategy
157    pub fn new(strategy: BalancingStrategy) -> Self {
158        Self {
159            strategy,
160            endpoints: Arc::new(RwLock::new(Vec::new())),
161            round_robin_index: AtomicUsize::new(0),
162            affinity: Affinity::new(),
163        }
164    }
165
166    /// Add an endpoint to the load balancer
167    pub fn add_endpoint(&self, endpoint: Endpoint) {
168        self.endpoints.write().push(Arc::new(endpoint));
169    }
170
171    /// Remove an endpoint from the load balancer
172    pub fn remove_endpoint(&self, endpoint_id: &str) -> bool {
173        let mut endpoints = self.endpoints.write();
174        if let Some(pos) = endpoints.iter().position(|e| e.id == endpoint_id) {
175            endpoints.remove(pos);
176            true
177        } else {
178            false
179        }
180    }
181
182    /// Get all endpoints
183    pub fn endpoints(&self) -> Vec<Arc<Endpoint>> {
184        self.endpoints.read().clone()
185    }
186
187    /// Get healthy endpoints
188    pub fn healthy_endpoints(&self) -> Vec<Arc<Endpoint>> {
189        self.endpoints
190            .read()
191            .iter()
192            .filter(|e| e.is_healthy())
193            .cloned()
194            .collect()
195    }
196
197    /// Select an endpoint using the configured strategy
198    pub fn select_endpoint(&self) -> NetResult<Arc<Endpoint>> {
199        let healthy_endpoints = self.healthy_endpoints();
200
201        if healthy_endpoints.is_empty() {
202            return Err(NetError::ServerUnavailable(
203                "No healthy endpoints available".to_string(),
204            ));
205        }
206
207        match self.strategy {
208            BalancingStrategy::RoundRobin => self.select_round_robin(&healthy_endpoints),
209            BalancingStrategy::LeastConnections => {
210                self.select_least_connections(&healthy_endpoints)
211            }
212            BalancingStrategy::Weighted => self.select_weighted(&healthy_endpoints),
213        }
214    }
215
216    /// Select endpoint with affinity (sticky session)
217    pub fn select_with_affinity(&self, session_id: &str) -> NetResult<Arc<Endpoint>> {
218        // Check if session has an existing endpoint
219        if let Some(endpoint_id) = self.affinity.get(session_id) {
220            // Find the endpoint
221            if let Some(endpoint) = self
222                .healthy_endpoints()
223                .iter()
224                .find(|e| e.id == endpoint_id)
225            {
226                return Ok(Arc::clone(endpoint));
227            }
228        }
229
230        // No existing endpoint or unhealthy - select a new one
231        let endpoint = self.select_endpoint()?;
232        self.affinity
233            .set(session_id.to_string(), endpoint.id.clone());
234        Ok(endpoint)
235    }
236
237    /// Clear session affinity
238    pub fn clear_affinity(&self, session_id: &str) {
239        self.affinity.remove(session_id);
240    }
241
242    /// Get load balancing statistics
243    pub fn stats(&self) -> BalancerStats {
244        let endpoints = self.endpoints.read();
245        let total_endpoints = endpoints.len();
246        let healthy_endpoints = endpoints.iter().filter(|e| e.is_healthy()).count();
247        let total_connections: usize = endpoints.iter().map(|e| e.active_connections()).sum();
248        let total_requests: u64 = endpoints.iter().map(|e| e.total_requests()).sum();
249
250        BalancerStats {
251            total_endpoints,
252            healthy_endpoints,
253            total_connections,
254            total_requests,
255            strategy: self.strategy,
256        }
257    }
258
259    /// Round-robin selection
260    fn select_round_robin(&self, endpoints: &[Arc<Endpoint>]) -> NetResult<Arc<Endpoint>> {
261        if endpoints.is_empty() {
262            return Err(NetError::ServerUnavailable(
263                "No endpoints available".to_string(),
264            ));
265        }
266
267        let index = self.round_robin_index.fetch_add(1, Ordering::Relaxed);
268        let endpoint = &endpoints[index % endpoints.len()];
269        Ok(Arc::clone(endpoint))
270    }
271
272    /// Least connections selection
273    fn select_least_connections(&self, endpoints: &[Arc<Endpoint>]) -> NetResult<Arc<Endpoint>> {
274        endpoints
275            .iter()
276            .min_by_key(|e| e.active_connections())
277            .map(Arc::clone)
278            .ok_or_else(|| NetError::ServerUnavailable("No endpoints available".to_string()))
279    }
280
281    /// Weighted selection using weighted random
282    fn select_weighted(&self, endpoints: &[Arc<Endpoint>]) -> NetResult<Arc<Endpoint>> {
283        if endpoints.is_empty() {
284            return Err(NetError::ServerUnavailable(
285                "No endpoints available".to_string(),
286            ));
287        }
288
289        // Calculate total weight
290        let total_weight: u32 = endpoints.iter().map(|e| e.weight).sum();
291
292        if total_weight == 0 {
293            // If all weights are zero, fall back to round-robin
294            return self.select_round_robin(endpoints);
295        }
296
297        // Use round-robin counter as pseudo-random selector
298        let selector = self.round_robin_index.fetch_add(1, Ordering::Relaxed) as u32;
299        let target = selector % total_weight;
300
301        // Find endpoint based on weighted selection
302        let mut cumulative = 0u32;
303        for endpoint in endpoints {
304            cumulative += endpoint.weight;
305            if target < cumulative {
306                return Ok(Arc::clone(endpoint));
307            }
308        }
309
310        // Fallback to last endpoint (shouldn't happen)
311        Ok(Arc::clone(&endpoints[endpoints.len() - 1]))
312    }
313}
314
315/// Load balancer statistics
316#[derive(Debug, Clone)]
317pub struct BalancerStats {
318    /// Total number of endpoints
319    pub total_endpoints: usize,
320    /// Number of healthy endpoints
321    pub healthy_endpoints: usize,
322    /// Total active connections across all endpoints
323    pub total_connections: usize,
324    /// Total requests handled
325    pub total_requests: u64,
326    /// Current balancing strategy
327    pub strategy: BalancingStrategy,
328}
329
330/// Connection guard that automatically decrements connection count
331pub struct ConnectionGuard {
332    endpoint: Arc<Endpoint>,
333}
334
335impl ConnectionGuard {
336    /// Create a new connection guard
337    pub fn new(endpoint: Arc<Endpoint>) -> Self {
338        endpoint.increment_connections();
339        Self { endpoint }
340    }
341
342    /// Get the endpoint
343    pub fn endpoint(&self) -> &Arc<Endpoint> {
344        &self.endpoint
345    }
346}
347
348impl Drop for ConnectionGuard {
349    fn drop(&mut self) {
350        self.endpoint.decrement_connections();
351    }
352}
353
#[cfg(test)]
mod tests {
    use super::*;

    /// Two unit-weight endpoints shared by several tests.
    const TWO: [(&str, &str); 2] = [("ep1", "localhost:50051"), ("ep2", "localhost:50052")];

    /// Build a balancer pre-populated with unit-weight `(id, address)` endpoints.
    fn balancer(strategy: BalancingStrategy, specs: &[(&str, &str)]) -> LoadBalancer {
        let lb = LoadBalancer::new(strategy);
        for &(id, addr) in specs {
            lb.add_endpoint(Endpoint::new(id.to_string(), addr.to_string()));
        }
        lb
    }

    /// Select one endpoint, failing the test with a uniform message otherwise.
    fn pick(lb: &LoadBalancer) -> Arc<Endpoint> {
        lb.select_endpoint().expect("should select endpoint")
    }

    /// Fresh standalone endpoint `ep1`.
    fn ep1() -> Endpoint {
        Endpoint::new("ep1".to_string(), "localhost:50051".to_string())
    }

    #[test]
    fn test_endpoint_creation() {
        let ep = ep1();
        assert_eq!(ep.id, "ep1");
        assert_eq!(ep.address, "localhost:50051");
        assert_eq!(ep.weight, 1);
        assert!(ep.is_healthy());
    }

    #[test]
    fn test_endpoint_health() {
        let ep = ep1();
        assert!(ep.is_healthy());

        ep.mark_unhealthy();
        assert!(!ep.is_healthy());

        ep.mark_healthy();
        assert!(ep.is_healthy());
    }

    #[test]
    fn test_endpoint_connections() {
        let ep = ep1();
        assert_eq!(ep.active_connections(), 0);

        for expected in 1..=2 {
            ep.increment_connections();
            assert_eq!(ep.active_connections(), expected);
        }

        ep.decrement_connections();
        assert_eq!(ep.active_connections(), 1);
    }

    #[test]
    fn test_load_balancer_round_robin() {
        let lb = balancer(
            BalancingStrategy::RoundRobin,
            &[
                ("ep1", "localhost:50051"),
                ("ep2", "localhost:50052"),
                ("ep3", "localhost:50053"),
            ],
        );

        // Rotates in insertion order and wraps around after the last one.
        let picked: Vec<String> = (0..4).map(|_| pick(&lb).id.clone()).collect();
        assert_eq!(picked, ["ep1", "ep2", "ep3", "ep1"]);
    }

    #[test]
    fn test_load_balancer_least_connections() {
        let lb = balancer(BalancingStrategy::LeastConnections, &TWO);

        let first = pick(&lb);
        first.increment_connections();

        // ep2 has fewer connections now.
        let second = pick(&lb);
        assert_eq!(second.id, "ep2");

        second.increment_connections();
        second.increment_connections(); // ep2 now has 2, ep1 has 1

        let third = pick(&lb);
        assert_eq!(third.id, "ep1");
    }

    #[test]
    fn test_load_balancer_weighted() {
        let lb = LoadBalancer::new(BalancingStrategy::Weighted);
        lb.add_endpoint(Endpoint::with_weight(
            "ep1".to_string(),
            "localhost:50051".to_string(),
            3,
        ));
        lb.add_endpoint(Endpoint::with_weight(
            "ep2".to_string(),
            "localhost:50052".to_string(),
            1,
        ));

        let mut counts: HashMap<String, usize> = HashMap::new();
        for _ in 0..40 {
            *counts.entry(pick(&lb).id.clone()).or_default() += 1;
        }

        let ep1_count = counts.get("ep1").copied().unwrap_or(0);
        let ep2_count = counts.get("ep2").copied().unwrap_or(0);

        // A 3:1 weighting over 40 picks should land near 30:10.
        assert!(ep1_count > ep2_count);
        assert!(ep1_count >= 20);
    }

    #[test]
    fn test_load_balancer_no_endpoints() {
        let lb = LoadBalancer::new(BalancingStrategy::RoundRobin);
        assert!(lb.select_endpoint().is_err());
    }

    #[test]
    fn test_load_balancer_unhealthy_endpoints() {
        let lb = LoadBalancer::new(BalancingStrategy::RoundRobin);

        let sick = ep1();
        sick.mark_unhealthy();
        lb.add_endpoint(sick);
        lb.add_endpoint(Endpoint::new(
            "ep2".to_string(),
            "localhost:50052".to_string(),
        ));

        // Only the healthy endpoint is ever handed out.
        for _ in 0..5 {
            assert_eq!(pick(&lb).id, "ep2");
        }
    }

    #[test]
    fn test_load_balancer_affinity() {
        let lb = balancer(BalancingStrategy::RoundRobin, &TWO);
        let session_id = "session123";

        let sticky = |lb: &LoadBalancer| {
            lb.select_with_affinity(session_id)
                .expect("should select endpoint")
        };

        // The first call pins the session; later calls must return the same endpoint.
        let a = sticky(&lb);
        let b = sticky(&lb);
        let c = sticky(&lb);
        assert_eq!(a.id, b.id);
        assert_eq!(b.id, c.id);

        // After clearing, a new selection must still succeed (possibly elsewhere).
        lb.clear_affinity(session_id);
        let _d = sticky(&lb);
    }

    #[test]
    fn test_load_balancer_remove_endpoint() {
        let lb = balancer(BalancingStrategy::RoundRobin, &TWO);
        assert_eq!(lb.endpoints().len(), 2);

        lb.remove_endpoint("ep1");
        assert_eq!(lb.endpoints().len(), 1);

        assert_eq!(pick(&lb).id, "ep2");
    }

    #[test]
    fn test_load_balancer_stats() {
        let lb = balancer(BalancingStrategy::LeastConnections, &TWO);

        let stats = lb.stats();
        assert_eq!(stats.total_endpoints, 2);
        assert_eq!(stats.healthy_endpoints, 2);
        assert_eq!(stats.total_connections, 0);
        assert_eq!(stats.strategy, BalancingStrategy::LeastConnections);
    }

    #[test]
    fn test_connection_guard() {
        let endpoint = Arc::new(ep1());
        assert_eq!(endpoint.active_connections(), 0);

        {
            let _guard = ConnectionGuard::new(Arc::clone(&endpoint));
            assert_eq!(endpoint.active_connections(), 1);
        }

        // Leaving the scope drops the guard, which releases the connection.
        assert_eq!(endpoint.active_connections(), 0);
    }

    #[test]
    fn test_affinity() {
        let affinity = Affinity::new();

        affinity.set("session1".to_string(), "ep1".to_string());
        affinity.set("session2".to_string(), "ep2".to_string());

        assert_eq!(affinity.get("session1"), Some("ep1".to_string()));
        assert_eq!(affinity.get("session2"), Some("ep2".to_string()));
        assert_eq!(affinity.get("session3"), None);

        affinity.remove("session1");
        assert_eq!(affinity.get("session1"), None);

        affinity.clear();
        assert_eq!(affinity.get("session2"), None);
    }
}