1use super::{NodeEndpoint, NodeId, NodeRole, ProxyError, Result};
7use std::collections::HashMap;
8use std::sync::atomic::{AtomicU64, Ordering};
9use std::sync::Arc;
10use tokio::sync::RwLock;
11
12#[derive(Debug, Clone)]
14pub struct LoadBalancerConfig {
15 pub read_strategy: RoutingStrategy,
17 pub write_strategy: RoutingStrategy,
19 pub read_write_split: bool,
21 pub latency_threshold_ms: u64,
23 pub min_weight: u32,
25}
26
27impl Default for LoadBalancerConfig {
28 fn default() -> Self {
29 Self {
30 read_strategy: RoutingStrategy::RoundRobin,
31 write_strategy: RoutingStrategy::PrimaryOnly,
32 read_write_split: true,
33 latency_threshold_ms: 100,
34 min_weight: 1,
35 }
36 }
37}
38
/// Algorithms the balancer can use to pick one node from a candidate list.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RoutingStrategy {
    /// Pick the first candidate whose role is `NodeRole::Primary`.
    PrimaryOnly,
    /// Cycle through the candidates using the balancer's shared counter.
    RoundRobin,
    /// Cycle through the candidates in proportion to each endpoint's `weight`.
    WeightedRoundRobin,
    /// Pick the candidate with the fewest tracked connections.
    LeastConnections,
    /// Pick the candidate with the lowest average latency.
    LatencyBased,
    /// Pick a candidate at an index derived from the system clock.
    Random,
    /// Pick the least-connected loopback candidate (`localhost`, `::1`,
    /// `[::1]` or a `127.` address); otherwise the least-connected overall.
    PreferLocal,
}
57
/// Health classification the balancer keeps for every node.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum NodeHealth {
    /// Fully operational; eligible for reads and writes.
    #[default]
    Healthy,
    /// Impaired but still eligible for reads.
    Degraded,
    /// Changing state; the automatic health updaters leave such nodes alone.
    Transitioning,
    /// Not usable.
    Unhealthy,
}

impl NodeHealth {
    /// Returns `true` for `Healthy` and `Degraded`.
    pub fn can_serve_reads(&self) -> bool {
        match self {
            NodeHealth::Healthy | NodeHealth::Degraded => true,
            NodeHealth::Transitioning | NodeHealth::Unhealthy => false,
        }
    }

    /// Returns `true` only for `Healthy`.
    pub fn can_serve_writes(&self) -> bool {
        *self == NodeHealth::Healthy
    }

    /// Returns `true` for every state except `Unhealthy`.
    pub fn is_usable(&self) -> bool {
        *self != NodeHealth::Unhealthy
    }
}
92
/// Per-node bookkeeping held in the balancer's node table.
#[derive(Debug, Clone)]
struct NodeState {
    /// Connection details, role, weight and enabled flag of the node.
    endpoint: NodeEndpoint,
    /// Current health classification.
    health: NodeHealth,
    /// Most recently reported replication lag, in milliseconds.
    replication_lag_ms: u64,
    /// Connection count maintained by `increment_connections` /
    /// `decrement_connections`.
    connections: u64,
    /// Exponentially weighted moving average of reported latencies (ms).
    avg_latency_ms: f64,
    /// Number of times `increment_connections` was called for this node.
    requests: u64,
    /// Number of times `record_failure` was called for this node.
    failures: u64,
}
111
/// Routes read and write requests across a table of tracked nodes.
pub struct LoadBalancer {
    /// Routing and health-threshold settings.
    config: LoadBalancerConfig,
    /// Node table keyed by node id, behind an async read/write lock.
    nodes: Arc<RwLock<HashMap<NodeId, NodeState>>>,
    /// Ticket counter advanced by the round-robin strategies.
    rr_counter: AtomicU64,
    /// Incremented once per `select_for_read` / `select_for_write` call.
    total_requests: AtomicU64,
}
123
124impl LoadBalancer {
125 pub fn new(config: LoadBalancerConfig) -> Self {
127 Self {
128 config,
129 nodes: Arc::new(RwLock::new(HashMap::new())),
130 rr_counter: AtomicU64::new(0),
131 total_requests: AtomicU64::new(0),
132 }
133 }
134
135 pub fn add_node(&mut self, endpoint: NodeEndpoint) {
137 let node_id = endpoint.id;
138 let state = NodeState {
139 endpoint,
140 health: NodeHealth::Healthy,
141 replication_lag_ms: 0,
142 connections: 0,
143 avg_latency_ms: 0.0,
144 requests: 0,
145 failures: 0,
146 };
147
148 let nodes = self.nodes.clone();
151 tokio::spawn(async move {
152 nodes.write().await.insert(node_id, state);
153 });
154 }
155
156 pub fn remove_node(&mut self, node_id: &NodeId) {
158 let id = *node_id;
159 let nodes = self.nodes.clone();
160 tokio::spawn(async move {
161 nodes.write().await.remove(&id);
162 });
163 }
164
165 pub fn select_for_read(&self) -> Result<NodeEndpoint> {
167 self.total_requests.fetch_add(1, Ordering::SeqCst);
168
169 let rt = tokio::runtime::Handle::try_current();
171 let nodes_guard = match rt {
172 Ok(handle) => handle.block_on(async { self.nodes.read().await }),
173 Err(_) => {
174 return Err(ProxyError::Routing(
176 "No async runtime available".to_string(),
177 ));
178 }
179 };
180
181 let mut eligible: Vec<_> = nodes_guard
183 .values()
184 .filter(|n| n.health.can_serve_reads() && n.endpoint.enabled)
185 .filter(|n| {
186 self.config.read_write_split
187 || n.endpoint.role == NodeRole::Primary
188 || n.endpoint.role == NodeRole::Standby
189 || n.endpoint.role == NodeRole::ReadReplica
190 })
191 .collect();
192
193 if eligible.is_empty() {
195 eligible = nodes_guard
196 .values()
197 .filter(|n| n.health == NodeHealth::Transitioning && n.endpoint.enabled)
198 .collect();
199 }
200
201 if eligible.is_empty() {
202 return Err(ProxyError::NoHealthyNodes);
203 }
204
205 eligible.sort_by_key(|n| match n.health {
207 NodeHealth::Healthy => 0,
208 NodeHealth::Degraded => 1,
209 NodeHealth::Transitioning => 2,
210 NodeHealth::Unhealthy => 3,
211 });
212
213 let selected = self.select_by_strategy(&eligible, self.config.read_strategy)?;
214 Ok(selected.endpoint.clone())
215 }
216
217 pub fn select_for_write(&self) -> Result<NodeEndpoint> {
219 self.total_requests.fetch_add(1, Ordering::SeqCst);
220
221 let rt = tokio::runtime::Handle::try_current();
222 let nodes_guard = match rt {
223 Ok(handle) => handle.block_on(async { self.nodes.read().await }),
224 Err(_) => {
225 return Err(ProxyError::Routing(
226 "No async runtime available".to_string(),
227 ));
228 }
229 };
230
231 let primary = nodes_guard.values().find(|n| {
233 n.endpoint.role == NodeRole::Primary
234 && n.health.can_serve_writes()
235 && n.endpoint.enabled
236 });
237
238 match primary {
239 Some(node) => Ok(node.endpoint.clone()),
240 None => Err(ProxyError::NoHealthyNodes),
241 }
242 }
243
244 fn select_by_strategy<'a>(
246 &self,
247 nodes: &[&'a NodeState],
248 strategy: RoutingStrategy,
249 ) -> Result<&'a NodeState> {
250 match strategy {
251 RoutingStrategy::PrimaryOnly => nodes
252 .iter()
253 .find(|n| n.endpoint.role == NodeRole::Primary)
254 .copied()
255 .ok_or(ProxyError::NoHealthyNodes),
256 RoutingStrategy::RoundRobin => {
257 let idx = self.rr_counter.fetch_add(1, Ordering::SeqCst) as usize;
258 Ok(nodes[idx % nodes.len()])
259 }
260 RoutingStrategy::WeightedRoundRobin => {
261 let total_weight: u32 = nodes.iter().map(|n| n.endpoint.weight).sum();
263 if total_weight == 0 {
264 return Err(ProxyError::NoHealthyNodes);
265 }
266
267 let idx = self.rr_counter.fetch_add(1, Ordering::SeqCst);
268 let mut target = (idx % total_weight as u64) as u32;
269
270 for node in nodes {
271 if target < node.endpoint.weight {
272 return Ok(node);
273 }
274 target -= node.endpoint.weight;
275 }
276
277 Ok(nodes[0])
278 }
279 RoutingStrategy::LeastConnections => nodes
280 .iter()
281 .min_by_key(|n| n.connections)
282 .copied()
283 .ok_or(ProxyError::NoHealthyNodes),
284 RoutingStrategy::LatencyBased => nodes
285 .iter()
286 .min_by(|a, b| {
287 a.avg_latency_ms
288 .partial_cmp(&b.avg_latency_ms)
289 .unwrap_or(std::cmp::Ordering::Equal)
290 })
291 .copied()
292 .ok_or(ProxyError::NoHealthyNodes),
293 RoutingStrategy::Random => {
294 use std::time::{SystemTime, UNIX_EPOCH};
295 let seed = SystemTime::now()
296 .duration_since(UNIX_EPOCH)
297 .unwrap()
298 .as_nanos() as usize;
299 Ok(nodes[seed % nodes.len()])
300 }
301 RoutingStrategy::PreferLocal => {
302 fn is_local(host: &str) -> bool {
307 matches!(host, "::1" | "[::1]" | "localhost") || host.starts_with("127.")
308 }
309 nodes
310 .iter()
311 .filter(|n| is_local(&n.endpoint.host))
312 .min_by_key(|n| n.connections)
313 .or_else(|| nodes.iter().min_by_key(|n| n.connections))
314 .copied()
315 .ok_or(ProxyError::NoHealthyNodes)
316 }
317 }
318 }
319
320 pub async fn set_node_health(&self, node_id: &NodeId, health: NodeHealth) {
328 if let Some(node) = self.nodes.write().await.get_mut(node_id) {
329 let old_health = node.health;
330 node.health = health;
331 tracing::debug!(
332 "Node {:?} health changed: {:?} -> {:?}",
333 node_id,
334 old_health,
335 health
336 );
337 }
338 }
339
340 pub async fn set_node_healthy(&self, node_id: &NodeId, healthy: bool) {
342 let health = if healthy {
343 NodeHealth::Healthy
344 } else {
345 NodeHealth::Unhealthy
346 };
347 self.set_node_health(node_id, health).await;
348 }
349
350 pub async fn set_node_transitioning(&self, node_id: &NodeId) {
352 self.set_node_health(node_id, NodeHealth::Transitioning)
353 .await;
354 }
355
356 pub async fn update_latency(&self, node_id: &NodeId, latency_ms: f64) {
358 if let Some(node) = self.nodes.write().await.get_mut(node_id) {
359 let alpha = 0.2;
361 node.avg_latency_ms = alpha * latency_ms + (1.0 - alpha) * node.avg_latency_ms;
362
363 let threshold = self.config.latency_threshold_ms as f64;
365 let degraded_threshold = threshold * 0.7; if node.health != NodeHealth::Transitioning {
369 if latency_ms > threshold {
370 node.health = NodeHealth::Unhealthy;
371 tracing::warn!(
372 "Node {:?} marked unhealthy due to high latency: {}ms",
373 node_id,
374 latency_ms
375 );
376 } else if latency_ms > degraded_threshold {
377 node.health = NodeHealth::Degraded;
378 tracing::debug!(
379 "Node {:?} marked degraded due to elevated latency: {}ms",
380 node_id,
381 latency_ms
382 );
383 } else if node.health == NodeHealth::Degraded
384 || node.health == NodeHealth::Unhealthy
385 {
386 node.health = NodeHealth::Healthy;
388 tracing::info!("Node {:?} recovered, marked healthy", node_id);
389 }
390 }
391 }
392 }
393
394 pub async fn update_replication_lag(&self, node_id: &NodeId, lag_ms: u64) {
396 const DEGRADED_LAG_MS: u64 = 5000; const UNHEALTHY_LAG_MS: u64 = 30000; if let Some(node) = self.nodes.write().await.get_mut(node_id) {
401 node.replication_lag_ms = lag_ms;
402
403 if node.health != NodeHealth::Transitioning {
405 if lag_ms > UNHEALTHY_LAG_MS {
406 node.health = NodeHealth::Unhealthy;
407 tracing::warn!(
408 "Node {:?} marked unhealthy due to high replication lag: {}ms",
409 node_id,
410 lag_ms
411 );
412 } else if lag_ms > DEGRADED_LAG_MS {
413 node.health = NodeHealth::Degraded;
414 tracing::debug!(
415 "Node {:?} marked degraded due to replication lag: {}ms",
416 node_id,
417 lag_ms
418 );
419 } else if node.health == NodeHealth::Degraded
420 && node.avg_latency_ms < self.config.latency_threshold_ms as f64 * 0.7
421 {
422 node.health = NodeHealth::Healthy;
424 tracing::info!("Node {:?} recovered from lag, marked healthy", node_id);
425 }
426 }
427 }
428 }
429
430 #[allow(clippy::if_same_then_else)]
432 pub async fn update_node_metrics(
433 &self,
434 node_id: &NodeId,
435 latency_ms: f64,
436 replication_lag_ms: u64,
437 failure_rate: f64,
438 ) {
439 if let Some(node) = self.nodes.write().await.get_mut(node_id) {
440 node.avg_latency_ms = 0.2 * latency_ms + 0.8 * node.avg_latency_ms;
442 node.replication_lag_ms = replication_lag_ms;
443
444 if node.health != NodeHealth::Transitioning {
446 let new_health = if !Self::is_responsive(latency_ms) {
448 NodeHealth::Unhealthy
449 } else if replication_lag_ms > 30000 {
450 NodeHealth::Unhealthy
451 } else if replication_lag_ms > 5000
452 || failure_rate > 0.5
453 || latency_ms > self.config.latency_threshold_ms as f64
454 {
455 NodeHealth::Degraded
456 } else {
457 NodeHealth::Healthy
458 };
459
460 if new_health != node.health {
461 tracing::debug!(
462 "Node {:?} health: {:?} -> {:?}",
463 node_id,
464 node.health,
465 new_health
466 );
467 node.health = new_health;
468 }
469 }
470 }
471 }
472
473 fn is_responsive(latency_ms: f64) -> bool {
475 (0.0..5000.0).contains(&latency_ms)
477 }
478
479 pub async fn increment_connections(&self, node_id: &NodeId) {
481 if let Some(node) = self.nodes.write().await.get_mut(node_id) {
482 node.connections += 1;
483 node.requests += 1;
484 }
485 }
486
487 pub async fn decrement_connections(&self, node_id: &NodeId) {
489 if let Some(node) = self.nodes.write().await.get_mut(node_id) {
490 node.connections = node.connections.saturating_sub(1);
491 }
492 }
493
494 pub async fn record_failure(&self, node_id: &NodeId) {
496 if let Some(node) = self.nodes.write().await.get_mut(node_id) {
497 node.failures += 1;
498 }
499 }
500
501 pub fn requests_routed(&self) -> u64 {
503 self.total_requests.load(Ordering::SeqCst)
504 }
505
506 pub async fn node_stats(&self, node_id: &NodeId) -> Option<NodeStats> {
508 self.nodes.read().await.get(node_id).map(|n| NodeStats {
509 health: n.health,
510 replication_lag_ms: n.replication_lag_ms,
511 connections: n.connections,
512 avg_latency_ms: n.avg_latency_ms,
513 requests: n.requests,
514 failures: n.failures,
515 })
516 }
517
518 pub async fn all_stats(&self) -> HashMap<NodeId, NodeStats> {
520 self.nodes
521 .read()
522 .await
523 .iter()
524 .map(|(id, n)| {
525 (
526 *id,
527 NodeStats {
528 health: n.health,
529 replication_lag_ms: n.replication_lag_ms,
530 connections: n.connections,
531 avg_latency_ms: n.avg_latency_ms,
532 requests: n.requests,
533 failures: n.failures,
534 },
535 )
536 })
537 .collect()
538 }
539}
540
541#[derive(Debug, Clone)]
543pub struct NodeStats {
544 pub health: NodeHealth,
546 pub replication_lag_ms: u64,
548 pub connections: u64,
550 pub avg_latency_ms: f64,
552 pub requests: u64,
554 pub failures: u64,
556}
557
558impl NodeStats {
559 pub fn is_healthy(&self) -> bool {
561 self.health == NodeHealth::Healthy
562 }
563
564 pub fn can_serve_reads(&self) -> bool {
566 self.health.can_serve_reads()
567 }
568}
569
#[cfg(test)]
mod tests {
    use super::*;

    /// Wraps `endpoint` in a healthy node state with all counters at zero.
    fn fresh_state(endpoint: NodeEndpoint) -> NodeState {
        NodeState {
            endpoint,
            health: NodeHealth::Healthy,
            replication_lag_ms: 0,
            connections: 0,
            avg_latency_ms: 0.0,
            requests: 0,
            failures: 0,
        }
    }

    /// Builds a default balancer whose table holds exactly one node, inserted
    /// directly under a freshly generated id.
    async fn balancer_with(endpoint: NodeEndpoint) -> (LoadBalancer, NodeId) {
        let lb = LoadBalancer::new(LoadBalancerConfig::default());
        let node_id = NodeId::new();
        lb.nodes.write().await.insert(node_id, fresh_state(endpoint));
        (lb, node_id)
    }

    #[test]
    fn test_config_default() {
        let config = LoadBalancerConfig::default();
        assert_eq!(config.read_strategy, RoutingStrategy::RoundRobin);
        assert_eq!(config.write_strategy, RoutingStrategy::PrimaryOnly);
        assert!(config.read_write_split);
    }

    #[tokio::test]
    async fn test_set_node_health() {
        let endpoint = NodeEndpoint::new("localhost", 5432).with_role(NodeRole::Primary);
        let (lb, node_id) = balancer_with(endpoint).await;

        lb.set_node_health(&node_id, NodeHealth::Unhealthy).await;

        let stats = lb.node_stats(&node_id).await.unwrap();
        assert_eq!(stats.health, NodeHealth::Unhealthy);
        assert!(!stats.is_healthy());
    }

    #[tokio::test]
    async fn test_degraded_state() {
        let endpoint = NodeEndpoint::new("localhost", 5432).with_role(NodeRole::Standby);
        let (lb, node_id) = balancer_with(endpoint).await;

        lb.set_node_health(&node_id, NodeHealth::Degraded).await;

        let stats = lb.node_stats(&node_id).await.unwrap();
        assert_eq!(stats.health, NodeHealth::Degraded);
        // A degraded node still serves reads but is not reported as healthy.
        assert!(stats.can_serve_reads());
        assert!(!stats.is_healthy());
    }

    #[tokio::test]
    async fn test_update_latency() {
        let (lb, node_id) = balancer_with(NodeEndpoint::new("localhost", 5432)).await;

        lb.update_latency(&node_id, 50.0).await;

        let stats = lb.node_stats(&node_id).await.unwrap();
        assert!(stats.avg_latency_ms > 0.0);
    }

    #[tokio::test]
    async fn test_replication_lag_degrades_health() {
        let endpoint = NodeEndpoint::new("localhost", 5432).with_role(NodeRole::Standby);
        let (lb, node_id) = balancer_with(endpoint).await;

        // 10 s of lag is above the degraded limit but below the unhealthy one.
        lb.update_replication_lag(&node_id, 10000).await;

        let stats = lb.node_stats(&node_id).await.unwrap();
        assert_eq!(stats.health, NodeHealth::Degraded);
        assert_eq!(stats.replication_lag_ms, 10000);
    }

    #[tokio::test]
    async fn test_connection_tracking() {
        let (lb, node_id) = balancer_with(NodeEndpoint::new("localhost", 5432)).await;

        lb.increment_connections(&node_id).await;
        lb.increment_connections(&node_id).await;

        let stats = lb.node_stats(&node_id).await.unwrap();
        assert_eq!(stats.connections, 2);

        lb.decrement_connections(&node_id).await;
        let stats = lb.node_stats(&node_id).await.unwrap();
        assert_eq!(stats.connections, 1);
    }

    #[test]
    fn prefer_local_routes_to_loopback_then_least_loaded() {
        /// Healthy node on `host` that already carries `conns` connections.
        fn node_with_conns(host: &str, conns: u64) -> NodeState {
            NodeState {
                connections: conns,
                ..fresh_state(NodeEndpoint::new(host, 5432))
            }
        }

        let lb = LoadBalancer::new(LoadBalancerConfig::default());

        // A busy loopback node still beats an idle remote one.
        let remote = node_with_conns("10.0.0.5", 1);
        let local = node_with_conns("127.0.0.1", 9);
        let refs = vec![&remote, &local];
        let chosen = lb
            .select_by_strategy(&refs, RoutingStrategy::PreferLocal)
            .unwrap();
        assert_eq!(chosen.endpoint.host, "127.0.0.1");

        // Among loopback nodes the least-connected one wins.
        let local_busy = node_with_conns("127.0.0.1", 9);
        let local_free = node_with_conns("localhost", 2);
        let refs2 = vec![&local_busy, &local_free];
        let chosen2 = lb
            .select_by_strategy(&refs2, RoutingStrategy::PreferLocal)
            .unwrap();
        assert_eq!(chosen2.endpoint.host, "localhost");

        // Without any loopback node, fall back to least connections overall.
        let r1 = node_with_conns("10.0.0.1", 5);
        let r2 = node_with_conns("10.0.0.2", 2);
        let refs3 = vec![&r1, &r2];
        let chosen3 = lb
            .select_by_strategy(&refs3, RoutingStrategy::PreferLocal)
            .unwrap();
        assert_eq!(chosen3.endpoint.host, "10.0.0.2");
    }
}