Skip to main content

xrpc/
loadbalancer.rs

1use parking_lot::RwLock;
2use rand::Rng;
3use std::collections::HashMap;
4use std::sync::Arc;
5use std::sync::atomic::{AtomicUsize, Ordering};
6use std::time::{Duration, Instant};
7
8use crate::discovery::{DiscoveryEvent, Endpoint, ServiceDiscovery};
9use crate::error::Result;
10use crate::streaming::StreamId;
11
/// Coarse health classification of a backend server.
///
/// Only [`ServerHealth::Unhealthy`] takes a server out of selection; every
/// other variant (including the default, `Unknown`) is still eligible.
#[derive(Default, Debug, Copy, Clone, Eq, PartialEq)]
pub enum ServerHealth {
    /// Set after a successful request or an explicit `mark_healthy`.
    Healthy,
    /// Partially healthy; still eligible for selection.
    Degraded,
    /// Excluded from selection; set after too many consecutive failures or an
    /// explicit `mark_unhealthy`.
    Unhealthy,
    /// No outcome recorded yet; still eligible for selection.
    #[default]
    Unknown,
}
20
21#[derive(Debug, Clone)]
22pub struct ServerState<S = ()> {
23    pub endpoint: Endpoint,
24    pub health: ServerHealth,
25    pub status: Option<S>,
26    pub last_update: Instant,
27    pub active_requests: usize,
28    pub total_requests: u64,
29    pub total_errors: u64,
30    pub consecutive_failures: u32,
31}
32
33impl<S> ServerState<S> {
34    pub fn new(endpoint: Endpoint) -> Self {
35        Self {
36            endpoint,
37            health: ServerHealth::Unknown,
38            status: None,
39            last_update: Instant::now(),
40            active_requests: 0,
41            total_requests: 0,
42            total_errors: 0,
43            consecutive_failures: 0,
44        }
45    }
46
47    pub fn is_available(&self) -> bool {
48        matches!(
49            self.health,
50            ServerHealth::Healthy | ServerHealth::Degraded | ServerHealth::Unknown
51        )
52    }
53
54    pub fn record_success(&mut self) {
55        self.consecutive_failures = 0;
56        self.total_requests += 1;
57        self.last_update = Instant::now();
58    }
59
60    pub fn record_failure(&mut self) {
61        self.consecutive_failures += 1;
62        self.total_errors += 1;
63        self.total_requests += 1;
64        self.last_update = Instant::now();
65    }
66}
67
68pub trait LoadBalanceStrategy: Send + Sync {
69    type Status: Clone + Send + Sync + 'static;
70
71    fn select(&self, servers: &[ServerState<Self::Status>]) -> Option<usize>;
72
73    fn update_status(&self, _server_idx: usize, _status: Self::Status) {}
74
75    fn on_success(&self, _server_idx: usize) {}
76
77    fn on_failure(&self, _server_idx: usize) {}
78
79    fn name(&self) -> &'static str;
80}
81
/// Cycles through the available servers in order.
pub struct RoundRobin {
    /// Number of selections made so far (wrapping); reduced modulo the
    /// number of available servers at selection time.
    counter: AtomicUsize,
}

impl RoundRobin {
    /// Creates a round-robin strategy whose first pick is the first available server.
    pub fn new() -> Self {
        let counter = AtomicUsize::new(0);
        RoundRobin { counter }
    }
}

impl Default for RoundRobin {
    fn default() -> Self {
        RoundRobin::new()
    }
}
99
100impl LoadBalanceStrategy for RoundRobin {
101    type Status = ();
102
103    fn select(&self, servers: &[ServerState<()>]) -> Option<usize> {
104        let available: Vec<_> = servers
105            .iter()
106            .enumerate()
107            .filter(|(_, s)| s.is_available())
108            .collect();
109
110        if available.is_empty() {
111            return None;
112        }
113
114        let idx = self.counter.fetch_add(1, Ordering::Relaxed);
115        Some(available[idx % available.len()].0)
116    }
117
118    fn name(&self) -> &'static str {
119        "RoundRobin"
120    }
121}
122
/// Picks uniformly at random among the available servers.
pub struct Random;

impl Random {
    /// Creates the (stateless) random strategy.
    pub fn new() -> Self {
        Random
    }
}

impl Default for Random {
    fn default() -> Self {
        Random
    }
}
136
137impl LoadBalanceStrategy for Random {
138    type Status = ();
139
140    fn select(&self, servers: &[ServerState<()>]) -> Option<usize> {
141        let available: Vec<_> = servers
142            .iter()
143            .enumerate()
144            .filter(|(_, s)| s.is_available())
145            .map(|(i, _)| i)
146            .collect();
147
148        if available.is_empty() {
149            return None;
150        }
151
152        let idx = rand::thread_rng().gen_range(0..available.len());
153        Some(available[idx])
154    }
155
156    fn name(&self) -> &'static str {
157        "Random"
158    }
159}
160
/// Picks the available server with the fewest in-flight requests.
pub struct LeastConnections;

impl LeastConnections {
    /// Creates the (stateless) least-connections strategy.
    pub fn new() -> Self {
        LeastConnections
    }
}

impl Default for LeastConnections {
    fn default() -> Self {
        LeastConnections
    }
}
174
175impl LoadBalanceStrategy for LeastConnections {
176    type Status = ();
177
178    fn select(&self, servers: &[ServerState<()>]) -> Option<usize> {
179        servers
180            .iter()
181            .enumerate()
182            .filter(|(_, s)| s.is_available())
183            .min_by_key(|(_, s)| s.active_requests)
184            .map(|(i, _)| i)
185    }
186
187    fn name(&self) -> &'static str {
188        "LeastConnections"
189    }
190}
191
/// Weight bookkeeping for one server in [`WeightedRoundRobin`].
#[derive(Debug, Clone)]
pub struct ServerWeight {
    /// Configured share of traffic relative to the other servers.
    pub weight: u32,
    /// Running score used by the smooth weighted round-robin selection.
    pub current_weight: i32,
}

impl Default for ServerWeight {
    /// Weight 1 with a zeroed running score.
    fn default() -> Self {
        ServerWeight {
            current_weight: 0,
            weight: 1,
        }
    }
}
206
207pub struct WeightedRoundRobin {
208    weights: parking_lot::Mutex<Vec<ServerWeight>>,
209}
210
211impl WeightedRoundRobin {
212    pub fn new() -> Self {
213        Self {
214            weights: parking_lot::Mutex::new(Vec::new()),
215        }
216    }
217
218    pub fn with_weights(weights: Vec<u32>) -> Self {
219        let sw: Vec<_> = weights
220            .into_iter()
221            .map(|w| ServerWeight {
222                weight: w,
223                current_weight: 0,
224            })
225            .collect();
226        Self {
227            weights: parking_lot::Mutex::new(sw),
228        }
229    }
230
231    pub fn set_weight(&self, server_idx: usize, weight: u32) {
232        let mut weights = self.weights.lock();
233        while weights.len() <= server_idx {
234            weights.push(ServerWeight::default());
235        }
236        weights[server_idx].weight = weight;
237    }
238}
239
240impl Default for WeightedRoundRobin {
241    fn default() -> Self {
242        Self::new()
243    }
244}
245
246impl LoadBalanceStrategy for WeightedRoundRobin {
247    type Status = ();
248
249    fn select(&self, servers: &[ServerState<()>]) -> Option<usize> {
250        let mut weights = self.weights.lock();
251
252        while weights.len() < servers.len() {
253            weights.push(ServerWeight::default());
254        }
255
256        let mut total_weight = 0i32;
257        let mut best_idx = None;
258        let mut best_weight = i32::MIN;
259
260        for (i, (server, sw)) in servers.iter().zip(weights.iter_mut()).enumerate() {
261            if !server.is_available() {
262                continue;
263            }
264
265            sw.current_weight += sw.weight as i32;
266            total_weight += sw.weight as i32;
267
268            if sw.current_weight > best_weight {
269                best_weight = sw.current_weight;
270                best_idx = Some(i);
271            }
272        }
273
274        if let Some(idx) = best_idx {
275            weights[idx].current_weight -= total_weight;
276        }
277
278        best_idx
279    }
280
281    fn name(&self) -> &'static str {
282        "WeightedRoundRobin"
283    }
284}
285
/// Chooses the available server with the lowest reported score, ignoring
/// servers whose fresh score is at or above `threshold`.
pub struct ScoreBased {
    /// A fresh score must be below this value for the server to be selected.
    pub threshold: f32,
    /// Scores older than this are stale. The server stays eligible whatever
    /// its score, but is ranked after servers with fresh scores.
    pub stale_timeout: Duration,
}

impl ScoreBased {
    /// Creates a strategy with a `0.95` threshold and a 30-second stale timeout.
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns `self` with the score threshold replaced.
    pub fn with_threshold(self, threshold: f32) -> Self {
        Self { threshold, ..self }
    }

    /// Returns `self` with the stale timeout replaced.
    pub fn with_stale_timeout(self, timeout: Duration) -> Self {
        Self {
            stale_timeout: timeout,
            ..self
        }
    }
}

impl Default for ScoreBased {
    fn default() -> Self {
        ScoreBased {
            stale_timeout: Duration::from_secs(30),
            threshold: 0.95,
        }
    }
}
315
316impl LoadBalanceStrategy for ScoreBased {
317    type Status = f32;
318
319    fn select(&self, servers: &[ServerState<f32>]) -> Option<usize> {
320        servers
321            .iter()
322            .enumerate()
323            .filter(|(_, s)| {
324                if !s.is_available() {
325                    return false;
326                }
327                if s.last_update.elapsed() > self.stale_timeout {
328                    return true;
329                }
330                s.status.map_or(true, |score| score < self.threshold)
331            })
332            .min_by(|(_, a), (_, b)| {
333                let a_stale = a.last_update.elapsed() > self.stale_timeout;
334                let b_stale = b.last_update.elapsed() > self.stale_timeout;
335
336                match (a_stale, b_stale) {
337                    (true, false) => std::cmp::Ordering::Greater,
338                    (false, true) => std::cmp::Ordering::Less,
339                    _ => {
340                        let a_score = a.status.unwrap_or(0.0);
341                        let b_score = b.status.unwrap_or(0.0);
342                        a_score
343                            .partial_cmp(&b_score)
344                            .unwrap_or(std::cmp::Ordering::Equal)
345                    }
346                }
347            })
348            .map(|(i, _)| i)
349    }
350
351    fn name(&self) -> &'static str {
352        "ScoreBased"
353    }
354}
355
/// Tuning knobs for [`LoadBalancer`].
#[derive(Debug, Clone)]
pub struct LoadBalancerConfig {
    /// Consecutive failures after which `record_failure` marks a server unhealthy.
    pub max_failures: u32,
    /// Health-check interval.
    /// NOTE(review): not read by `LoadBalancer` in this file; presumably used by
    /// an external health checker — confirm.
    pub health_check_interval: Duration,
    /// Whether health checks run automatically.
    /// NOTE(review): not read in this file — confirm the consumer.
    pub auto_health_check: bool,
    /// Whether failed requests may be retried on another server.
    /// NOTE(review): not read in this file — presumably consumed by the caller.
    pub failover_enabled: bool,
    /// Maximum number of failover attempts.
    /// NOTE(review): not read in this file — presumably consumed by the caller.
    pub max_failover_attempts: u32,
}

impl Default for LoadBalancerConfig {
    /// Three failures to unhealthy, a 10-second health-check interval, and
    /// automatic health checks and failover enabled with up to two attempts.
    fn default() -> Self {
        LoadBalancerConfig {
            auto_health_check: true,
            failover_enabled: true,
            health_check_interval: Duration::from_secs(10),
            max_failover_attempts: 2,
            max_failures: 3,
        }
    }
}
376
377pub struct LoadBalancer<S: LoadBalanceStrategy> {
378    discovery: Arc<dyn ServiceDiscovery>,
379    strategy: S,
380    servers: Arc<RwLock<Vec<ServerState<S::Status>>>>,
381    stream_affinity: RwLock<HashMap<StreamId, usize>>,
382    config: LoadBalancerConfig,
383}
384
385impl<S: LoadBalanceStrategy + 'static> LoadBalancer<S> {
386    pub fn new(discovery: Arc<dyn ServiceDiscovery>, strategy: S) -> Self {
387        Self::with_config(discovery, strategy, LoadBalancerConfig::default())
388    }
389
390    pub fn with_config(
391        discovery: Arc<dyn ServiceDiscovery>,
392        strategy: S,
393        config: LoadBalancerConfig,
394    ) -> Self {
395        Self {
396            discovery,
397            strategy,
398            servers: Arc::new(RwLock::new(Vec::new())),
399            stream_affinity: RwLock::new(HashMap::new()),
400            config,
401        }
402    }
403
404    pub async fn init(&self) -> Result<()> {
405        let endpoints = self.discovery.discover().await?;
406        self.update_endpoints(endpoints);
407        Ok(())
408    }
409
410    pub fn start(&self) -> LoadBalancerHandle {
411        let mut handles = Vec::new();
412
413        if let Some(mut rx) = self.discovery.watch() {
414            let servers = self.servers.clone();
415            let h = tokio::spawn(async move {
416                while let Ok(event) = rx.recv().await {
417                    match event {
418                        DiscoveryEvent::Updated(endpoints) => {
419                            Self::update_endpoints_static(&servers, endpoints);
420                        }
421                        DiscoveryEvent::Added(endpoint) => {
422                            servers.write().push(ServerState::new(endpoint));
423                        }
424                        DiscoveryEvent::Removed(endpoint) => {
425                            servers.write().retain(|s| s.endpoint != endpoint);
426                        }
427                    }
428                }
429            });
430            handles.push(h);
431        }
432
433        LoadBalancerHandle { handles }
434    }
435
436    fn update_endpoints(&self, endpoints: Vec<Endpoint>) {
437        Self::update_endpoints_static(&self.servers, endpoints);
438    }
439
440    fn update_endpoints_static(
441        servers: &RwLock<Vec<ServerState<S::Status>>>,
442        endpoints: Vec<Endpoint>,
443    ) {
444        let mut servers = servers.write();
445        servers.retain(|s| endpoints.contains(&s.endpoint));
446        for ep in endpoints {
447            if !servers.iter().any(|s| s.endpoint == ep) {
448                servers.push(ServerState::new(ep));
449            }
450        }
451    }
452
453    pub fn select(&self) -> Option<usize> {
454        let servers = self.servers.read();
455        self.strategy.select(&servers)
456    }
457
458    pub fn select_for_stream(&self, stream_id: StreamId) -> Option<usize> {
459        {
460            let affinity = self.stream_affinity.read();
461            if let Some(&idx) = affinity.get(&stream_id) {
462                let servers = self.servers.read();
463                if servers.get(idx).map_or(false, |s| s.is_available()) {
464                    return Some(idx);
465                }
466            }
467        }
468
469        let idx = self.select()?;
470        self.stream_affinity.write().insert(stream_id, idx);
471        Some(idx)
472    }
473
474    pub fn release_stream(&self, stream_id: StreamId) {
475        self.stream_affinity.write().remove(&stream_id);
476    }
477
478    pub fn get_endpoint(&self, server_idx: usize) -> Option<Endpoint> {
479        self.servers
480            .read()
481            .get(server_idx)
482            .map(|s| s.endpoint.clone())
483    }
484
485    pub fn report_status(&self, server_idx: usize, status: S::Status) {
486        if let Some(server) = self.servers.write().get_mut(server_idx) {
487            server.status = Some(status.clone());
488            server.last_update = Instant::now();
489        }
490        self.strategy.update_status(server_idx, status);
491    }
492
493    pub fn record_success(&self, server_idx: usize) {
494        if let Some(server) = self.servers.write().get_mut(server_idx) {
495            server.record_success();
496            server.health = ServerHealth::Healthy;
497        }
498        self.strategy.on_success(server_idx);
499    }
500
501    pub fn record_failure(&self, server_idx: usize) {
502        let should_mark_unhealthy = {
503            let mut servers = self.servers.write();
504            if let Some(server) = servers.get_mut(server_idx) {
505                server.record_failure();
506                server.consecutive_failures >= self.config.max_failures
507            } else {
508                false
509            }
510        };
511
512        if should_mark_unhealthy {
513            self.mark_unhealthy(server_idx);
514        }
515
516        self.strategy.on_failure(server_idx);
517    }
518
519    pub fn mark_unhealthy(&self, server_idx: usize) {
520        if let Some(server) = self.servers.write().get_mut(server_idx) {
521            server.health = ServerHealth::Unhealthy;
522        }
523    }
524
525    pub fn mark_healthy(&self, server_idx: usize) {
526        if let Some(server) = self.servers.write().get_mut(server_idx) {
527            server.health = ServerHealth::Healthy;
528            server.consecutive_failures = 0;
529        }
530    }
531
532    pub fn available_count(&self) -> usize {
533        self.servers
534            .read()
535            .iter()
536            .filter(|s| s.is_available())
537            .count()
538    }
539
540    pub fn server_count(&self) -> usize {
541        self.servers.read().len()
542    }
543
544    pub fn strategy_name(&self) -> &'static str {
545        self.strategy.name()
546    }
547
548    pub fn acquire(&self, server_idx: usize) {
549        if let Some(server) = self.servers.write().get_mut(server_idx) {
550            server.active_requests += 1;
551        }
552    }
553
554    pub fn release(&self, server_idx: usize) {
555        if let Some(server) = self.servers.write().get_mut(server_idx) {
556            server.active_requests = server.active_requests.saturating_sub(1);
557        }
558    }
559
560    pub fn config(&self) -> &LoadBalancerConfig {
561        &self.config
562    }
563}
564
565pub struct LoadBalancerHandle {
566    handles: Vec<tokio::task::JoinHandle<()>>,
567}
568
569impl LoadBalancerHandle {
570    pub async fn shutdown(self) {
571        for h in self.handles {
572            h.abort();
573            let _ = h.await;
574        }
575    }
576}
577
#[cfg(test)]
mod tests {
    use super::*;
    use crate::discovery::StaticDiscovery;

    /// Builds the endpoint `127.0.0.1:<8000 + i>`.
    ///
    /// Adding to a base port, instead of appending `i` to `"800"`, keeps the
    /// port valid for `i >= 10` (e.g. `80010` is not a valid port).
    fn test_endpoint(i: usize) -> Endpoint {
        Endpoint::tcp_from_str(&format!("127.0.0.1:{}", 8000 + i)).unwrap()
    }

    /// `count` fresh servers on consecutive ports starting at 8000.
    fn create_test_servers(count: usize) -> Vec<ServerState<()>> {
        (0..count).map(|i| ServerState::new(test_endpoint(i))).collect()
    }

    #[test]
    fn test_round_robin() {
        let strategy = RoundRobin::new();
        let servers = create_test_servers(3);

        let selections: Vec<_> = (0..6).filter_map(|_| strategy.select(&servers)).collect();
        assert_eq!(selections, vec![0, 1, 2, 0, 1, 2]);
    }

    #[test]
    fn test_round_robin_skip_unhealthy() {
        let strategy = RoundRobin::new();
        let mut servers = create_test_servers(3);
        servers[1].health = ServerHealth::Unhealthy;

        let selections: Vec<_> = (0..4).filter_map(|_| strategy.select(&servers)).collect();
        assert_eq!(selections, vec![0, 2, 0, 2]);
    }

    #[test]
    fn test_round_robin_none_when_all_unhealthy() {
        let strategy = RoundRobin::new();
        let mut servers = create_test_servers(2);
        for s in &mut servers {
            s.health = ServerHealth::Unhealthy;
        }
        assert_eq!(strategy.select(&servers), None);
    }

    #[test]
    fn test_random_only_picks_available() {
        let strategy = Random::new();
        let mut servers = create_test_servers(3);
        servers[0].health = ServerHealth::Unhealthy;
        servers[2].health = ServerHealth::Unhealthy;

        for _ in 0..20 {
            assert_eq!(strategy.select(&servers), Some(1));
        }
        assert_eq!(strategy.select(&[]), None);
    }

    #[test]
    fn test_least_connections() {
        let strategy = LeastConnections::new();
        let mut servers = create_test_servers(3);
        servers[0].active_requests = 5;
        servers[1].active_requests = 2;
        servers[2].active_requests = 3;

        assert_eq!(strategy.select(&servers), Some(1));
    }

    #[test]
    fn test_least_connections_tie_prefers_lowest_index() {
        let strategy = LeastConnections::new();
        let mut servers = create_test_servers(3);
        servers[0].active_requests = 4;
        servers[1].active_requests = 1;
        servers[2].active_requests = 1;

        assert_eq!(strategy.select(&servers), Some(1));
    }

    #[test]
    fn test_weighted_round_robin() {
        let strategy = WeightedRoundRobin::with_weights(vec![2, 1, 1]);
        let servers = create_test_servers(3);

        let mut counts = [0usize; 3];
        for _ in 0..8 {
            if let Some(idx) = strategy.select(&servers) {
                counts[idx] += 1;
            }
        }

        assert!(counts[0] > counts[1]);
        assert!(counts[0] > counts[2]);
        // Smooth WRR with weights 2:1:1 repeats the pattern 0, 1, 2, 0.
        assert_eq!(counts, [4, 2, 2]);
    }

    #[test]
    fn test_score_based_prefers_lowest_score_below_threshold() {
        let strategy = ScoreBased::new();
        let mut servers: Vec<ServerState<f32>> =
            (0..3).map(|i| ServerState::new(test_endpoint(i))).collect();
        servers[0].status = Some(0.5);
        servers[1].status = Some(0.2);
        servers[2].status = Some(0.99);
        assert_eq!(strategy.select(&servers), Some(1));

        // Every fresh score at or above the 0.95 threshold: nothing qualifies.
        servers[0].status = Some(0.96);
        servers[1].status = Some(0.97);
        assert_eq!(strategy.select(&servers), None);
    }

    #[tokio::test]
    async fn test_load_balancer_init() {
        let endpoints = vec![
            Endpoint::tcp_from_str("127.0.0.1:8001").unwrap(),
            Endpoint::tcp_from_str("127.0.0.1:8002").unwrap(),
        ];

        let discovery = Arc::new(StaticDiscovery::new(endpoints));
        let lb = LoadBalancer::new(discovery, RoundRobin::new());
        lb.init().await.unwrap();

        assert_eq!(lb.server_count(), 2);
        assert_eq!(lb.available_count(), 2);
    }

    #[tokio::test]
    async fn test_stream_affinity() {
        let endpoints = vec![
            Endpoint::tcp_from_str("127.0.0.1:8001").unwrap(),
            Endpoint::tcp_from_str("127.0.0.1:8002").unwrap(),
        ];

        let discovery = Arc::new(StaticDiscovery::new(endpoints));
        let lb = LoadBalancer::new(discovery, RoundRobin::new());
        lb.init().await.unwrap();

        let stream_id = 42;
        let first = lb.select_for_stream(stream_id);
        let second = lb.select_for_stream(stream_id);
        let third = lb.select_for_stream(stream_id);

        assert_eq!(first, second);
        assert_eq!(second, third);

        lb.release_stream(stream_id);
    }

    #[tokio::test]
    async fn test_failure_tracking() {
        let endpoints = vec![Endpoint::tcp_from_str("127.0.0.1:8001").unwrap()];
        let discovery = Arc::new(StaticDiscovery::new(endpoints));
        let config = LoadBalancerConfig {
            max_failures: 2,
            ..Default::default()
        };
        let lb = LoadBalancer::with_config(discovery, RoundRobin::new(), config);
        lb.init().await.unwrap();

        lb.record_failure(0);
        assert_eq!(lb.available_count(), 1);

        lb.record_failure(0);
        assert_eq!(lb.available_count(), 0);

        lb.mark_healthy(0);
        assert_eq!(lb.available_count(), 1);
    }

    #[tokio::test]
    async fn test_success_resets_failure_streak() {
        let endpoints = vec![test_endpoint(1)];
        let discovery = Arc::new(StaticDiscovery::new(endpoints));
        let config = LoadBalancerConfig {
            max_failures: 2,
            ..Default::default()
        };
        let lb = LoadBalancer::with_config(discovery, RoundRobin::new(), config);
        lb.init().await.unwrap();

        lb.record_failure(0);
        lb.record_success(0);
        lb.record_failure(0);
        // Only one consecutive failure since the last success: still available.
        assert_eq!(lb.available_count(), 1);
    }
}