1use super::{NodeEndpoint, NodeId, NodeRole, ProxyError, Result};
7use std::collections::HashMap;
8use std::sync::atomic::{AtomicU64, Ordering};
9use std::sync::Arc;
10use tokio::sync::RwLock;
11
12#[derive(Debug, Clone)]
14pub struct LoadBalancerConfig {
15 pub read_strategy: RoutingStrategy,
17 pub write_strategy: RoutingStrategy,
19 pub read_write_split: bool,
21 pub latency_threshold_ms: u64,
23 pub min_weight: u32,
25}
26
27impl Default for LoadBalancerConfig {
28 fn default() -> Self {
29 Self {
30 read_strategy: RoutingStrategy::RoundRobin,
31 write_strategy: RoutingStrategy::PrimaryOnly,
32 read_write_split: true,
33 latency_threshold_ms: 100,
34 min_weight: 1,
35 }
36 }
37}
38
/// Policy used to pick one node out of a list of eligible candidates.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum RoutingStrategy {
    /// Pick the candidate whose role is `Primary`.
    PrimaryOnly,
    /// Cycle through the candidates using a shared counter.
    RoundRobin,
    /// Cycle through the candidates, giving each a share proportional to its
    /// endpoint weight.
    WeightedRoundRobin,
    /// Pick the candidate with the fewest tracked connections.
    LeastConnections,
    /// Pick the candidate with the lowest average latency.
    LatencyBased,
    /// Pick a candidate at an index derived from the system clock.
    Random,
    /// Pick the first candidate in the list.
    PreferLocal,
}
57
/// Health classification of a node, from best to worst.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum NodeHealth {
    /// Serves both reads and writes.
    Healthy,
    /// Serves reads but not writes.
    Degraded,
    /// Serves neither reads nor writes, but still counts as usable.
    Transitioning,
    /// Not usable.
    Unhealthy,
}

impl NodeHealth {
    /// Returns `true` for `Healthy` and `Degraded`.
    pub fn can_serve_reads(&self) -> bool {
        match self {
            NodeHealth::Healthy | NodeHealth::Degraded => true,
            NodeHealth::Transitioning | NodeHealth::Unhealthy => false,
        }
    }

    /// Returns `true` only for `Healthy`.
    pub fn can_serve_writes(&self) -> bool {
        *self == NodeHealth::Healthy
    }

    /// Returns `true` for every state except `Unhealthy`.
    pub fn is_usable(&self) -> bool {
        *self != NodeHealth::Unhealthy
    }
}

impl Default for NodeHealth {
    /// A node is considered `Healthy` until proven otherwise.
    fn default() -> Self {
        Self::Healthy
    }
}
97
98#[derive(Debug, Clone)]
100struct NodeState {
101 endpoint: NodeEndpoint,
103 health: NodeHealth,
105 replication_lag_ms: u64,
107 connections: u64,
109 avg_latency_ms: f64,
111 requests: u64,
113 failures: u64,
115}
116
117pub struct LoadBalancer {
119 config: LoadBalancerConfig,
121 nodes: Arc<RwLock<HashMap<NodeId, NodeState>>>,
123 rr_counter: AtomicU64,
125 total_requests: AtomicU64,
127}
128
129impl LoadBalancer {
130 pub fn new(config: LoadBalancerConfig) -> Self {
132 Self {
133 config,
134 nodes: Arc::new(RwLock::new(HashMap::new())),
135 rr_counter: AtomicU64::new(0),
136 total_requests: AtomicU64::new(0),
137 }
138 }
139
140 pub fn add_node(&mut self, endpoint: NodeEndpoint) {
142 let node_id = endpoint.id;
143 let state = NodeState {
144 endpoint,
145 health: NodeHealth::Healthy,
146 replication_lag_ms: 0,
147 connections: 0,
148 avg_latency_ms: 0.0,
149 requests: 0,
150 failures: 0,
151 };
152
153 let nodes = self.nodes.clone();
156 tokio::spawn(async move {
157 nodes.write().await.insert(node_id, state);
158 });
159 }
160
161 pub fn remove_node(&mut self, node_id: &NodeId) {
163 let id = *node_id;
164 let nodes = self.nodes.clone();
165 tokio::spawn(async move {
166 nodes.write().await.remove(&id);
167 });
168 }
169
170 pub fn select_for_read(&self) -> Result<NodeEndpoint> {
172 self.total_requests.fetch_add(1, Ordering::SeqCst);
173
174 let rt = tokio::runtime::Handle::try_current();
176 let nodes_guard = match rt {
177 Ok(handle) => {
178 handle.block_on(async { self.nodes.read().await })
179 }
180 Err(_) => {
181 return Err(ProxyError::Routing("No async runtime available".to_string()));
183 }
184 };
185
186 let mut eligible: Vec<_> = nodes_guard
188 .values()
189 .filter(|n| n.health.can_serve_reads() && n.endpoint.enabled)
190 .filter(|n| {
191 self.config.read_write_split
192 || n.endpoint.role == NodeRole::Primary
193 || n.endpoint.role == NodeRole::Standby
194 || n.endpoint.role == NodeRole::ReadReplica
195 })
196 .collect();
197
198 if eligible.is_empty() {
200 eligible = nodes_guard
201 .values()
202 .filter(|n| n.health == NodeHealth::Transitioning && n.endpoint.enabled)
203 .collect();
204 }
205
206 if eligible.is_empty() {
207 return Err(ProxyError::NoHealthyNodes);
208 }
209
210 eligible.sort_by_key(|n| match n.health {
212 NodeHealth::Healthy => 0,
213 NodeHealth::Degraded => 1,
214 NodeHealth::Transitioning => 2,
215 NodeHealth::Unhealthy => 3,
216 });
217
218 let selected = self.select_by_strategy(&eligible, self.config.read_strategy)?;
219 Ok(selected.endpoint.clone())
220 }
221
222 pub fn select_for_write(&self) -> Result<NodeEndpoint> {
224 self.total_requests.fetch_add(1, Ordering::SeqCst);
225
226 let rt = tokio::runtime::Handle::try_current();
227 let nodes_guard = match rt {
228 Ok(handle) => {
229 handle.block_on(async { self.nodes.read().await })
230 }
231 Err(_) => {
232 return Err(ProxyError::Routing("No async runtime available".to_string()));
233 }
234 };
235
236 let primary = nodes_guard
238 .values()
239 .find(|n| n.endpoint.role == NodeRole::Primary && n.health.can_serve_writes() && n.endpoint.enabled);
240
241 match primary {
242 Some(node) => Ok(node.endpoint.clone()),
243 None => Err(ProxyError::NoHealthyNodes),
244 }
245 }
246
247 fn select_by_strategy<'a>(
249 &self,
250 nodes: &[&'a NodeState],
251 strategy: RoutingStrategy,
252 ) -> Result<&'a NodeState> {
253 match strategy {
254 RoutingStrategy::PrimaryOnly => {
255 nodes
256 .iter()
257 .find(|n| n.endpoint.role == NodeRole::Primary)
258 .copied()
259 .ok_or(ProxyError::NoHealthyNodes)
260 }
261 RoutingStrategy::RoundRobin => {
262 let idx = self.rr_counter.fetch_add(1, Ordering::SeqCst) as usize;
263 Ok(nodes[idx % nodes.len()])
264 }
265 RoutingStrategy::WeightedRoundRobin => {
266 let total_weight: u32 = nodes.iter().map(|n| n.endpoint.weight).sum();
268 if total_weight == 0 {
269 return Err(ProxyError::NoHealthyNodes);
270 }
271
272 let idx = self.rr_counter.fetch_add(1, Ordering::SeqCst);
273 let mut target = (idx % total_weight as u64) as u32;
274
275 for node in nodes {
276 if target < node.endpoint.weight {
277 return Ok(node);
278 }
279 target -= node.endpoint.weight;
280 }
281
282 Ok(nodes[0])
283 }
284 RoutingStrategy::LeastConnections => {
285 nodes
286 .iter()
287 .min_by_key(|n| n.connections)
288 .copied()
289 .ok_or(ProxyError::NoHealthyNodes)
290 }
291 RoutingStrategy::LatencyBased => {
292 nodes
293 .iter()
294 .min_by(|a, b| {
295 a.avg_latency_ms
296 .partial_cmp(&b.avg_latency_ms)
297 .unwrap_or(std::cmp::Ordering::Equal)
298 })
299 .copied()
300 .ok_or(ProxyError::NoHealthyNodes)
301 }
302 RoutingStrategy::Random => {
303 use std::time::{SystemTime, UNIX_EPOCH};
304 let seed = SystemTime::now()
305 .duration_since(UNIX_EPOCH)
306 .unwrap()
307 .as_nanos() as usize;
308 Ok(nodes[seed % nodes.len()])
309 }
310 RoutingStrategy::PreferLocal => {
311 nodes.first().copied().ok_or(ProxyError::NoHealthyNodes)
314 }
315 }
316 }
317
318 pub async fn set_node_health(&self, node_id: &NodeId, health: NodeHealth) {
326 if let Some(node) = self.nodes.write().await.get_mut(node_id) {
327 let old_health = node.health;
328 node.health = health;
329 tracing::debug!("Node {:?} health changed: {:?} -> {:?}", node_id, old_health, health);
330 }
331 }
332
333 pub async fn set_node_healthy(&self, node_id: &NodeId, healthy: bool) {
335 let health = if healthy { NodeHealth::Healthy } else { NodeHealth::Unhealthy };
336 self.set_node_health(node_id, health).await;
337 }
338
339 pub async fn set_node_transitioning(&self, node_id: &NodeId) {
341 self.set_node_health(node_id, NodeHealth::Transitioning).await;
342 }
343
344 pub async fn update_latency(&self, node_id: &NodeId, latency_ms: f64) {
346 if let Some(node) = self.nodes.write().await.get_mut(node_id) {
347 let alpha = 0.2;
349 node.avg_latency_ms = alpha * latency_ms + (1.0 - alpha) * node.avg_latency_ms;
350
351 let threshold = self.config.latency_threshold_ms as f64;
353 let degraded_threshold = threshold * 0.7; if node.health != NodeHealth::Transitioning {
357 if latency_ms > threshold {
358 node.health = NodeHealth::Unhealthy;
359 tracing::warn!(
360 "Node {:?} marked unhealthy due to high latency: {}ms",
361 node_id,
362 latency_ms
363 );
364 } else if latency_ms > degraded_threshold {
365 node.health = NodeHealth::Degraded;
366 tracing::debug!(
367 "Node {:?} marked degraded due to elevated latency: {}ms",
368 node_id,
369 latency_ms
370 );
371 } else if node.health == NodeHealth::Degraded || node.health == NodeHealth::Unhealthy {
372 node.health = NodeHealth::Healthy;
374 tracing::info!("Node {:?} recovered, marked healthy", node_id);
375 }
376 }
377 }
378 }
379
380 pub async fn update_replication_lag(&self, node_id: &NodeId, lag_ms: u64) {
382 const DEGRADED_LAG_MS: u64 = 5000; const UNHEALTHY_LAG_MS: u64 = 30000; if let Some(node) = self.nodes.write().await.get_mut(node_id) {
387 node.replication_lag_ms = lag_ms;
388
389 if node.health != NodeHealth::Transitioning {
391 if lag_ms > UNHEALTHY_LAG_MS {
392 node.health = NodeHealth::Unhealthy;
393 tracing::warn!(
394 "Node {:?} marked unhealthy due to high replication lag: {}ms",
395 node_id,
396 lag_ms
397 );
398 } else if lag_ms > DEGRADED_LAG_MS {
399 node.health = NodeHealth::Degraded;
400 tracing::debug!(
401 "Node {:?} marked degraded due to replication lag: {}ms",
402 node_id,
403 lag_ms
404 );
405 } else if node.health == NodeHealth::Degraded && node.avg_latency_ms < self.config.latency_threshold_ms as f64 * 0.7 {
406 node.health = NodeHealth::Healthy;
408 tracing::info!("Node {:?} recovered from lag, marked healthy", node_id);
409 }
410 }
411 }
412 }
413
414 pub async fn update_node_metrics(&self, node_id: &NodeId, latency_ms: f64, replication_lag_ms: u64, failure_rate: f64) {
416 if let Some(node) = self.nodes.write().await.get_mut(node_id) {
417 node.avg_latency_ms = 0.2 * latency_ms + 0.8 * node.avg_latency_ms;
419 node.replication_lag_ms = replication_lag_ms;
420
421 if node.health != NodeHealth::Transitioning {
423 let new_health = if !Self::is_responsive(latency_ms) {
425 NodeHealth::Unhealthy
426 } else if replication_lag_ms > 30000 {
427 NodeHealth::Unhealthy
428 } else if replication_lag_ms > 5000 || failure_rate > 0.5 || latency_ms > self.config.latency_threshold_ms as f64 {
429 NodeHealth::Degraded
430 } else {
431 NodeHealth::Healthy
432 };
433
434 if new_health != node.health {
435 tracing::debug!("Node {:?} health: {:?} -> {:?}", node_id, node.health, new_health);
436 node.health = new_health;
437 }
438 }
439 }
440 }
441
442 fn is_responsive(latency_ms: f64) -> bool {
444 latency_ms >= 0.0 && latency_ms < 5000.0
446 }
447
448 pub async fn increment_connections(&self, node_id: &NodeId) {
450 if let Some(node) = self.nodes.write().await.get_mut(node_id) {
451 node.connections += 1;
452 node.requests += 1;
453 }
454 }
455
456 pub async fn decrement_connections(&self, node_id: &NodeId) {
458 if let Some(node) = self.nodes.write().await.get_mut(node_id) {
459 node.connections = node.connections.saturating_sub(1);
460 }
461 }
462
463 pub async fn record_failure(&self, node_id: &NodeId) {
465 if let Some(node) = self.nodes.write().await.get_mut(node_id) {
466 node.failures += 1;
467 }
468 }
469
470 pub fn requests_routed(&self) -> u64 {
472 self.total_requests.load(Ordering::SeqCst)
473 }
474
475 pub async fn node_stats(&self, node_id: &NodeId) -> Option<NodeStats> {
477 self.nodes.read().await.get(node_id).map(|n| NodeStats {
478 health: n.health,
479 replication_lag_ms: n.replication_lag_ms,
480 connections: n.connections,
481 avg_latency_ms: n.avg_latency_ms,
482 requests: n.requests,
483 failures: n.failures,
484 })
485 }
486
487 pub async fn all_stats(&self) -> HashMap<NodeId, NodeStats> {
489 self.nodes
490 .read()
491 .await
492 .iter()
493 .map(|(id, n)| {
494 (
495 *id,
496 NodeStats {
497 health: n.health,
498 replication_lag_ms: n.replication_lag_ms,
499 connections: n.connections,
500 avg_latency_ms: n.avg_latency_ms,
501 requests: n.requests,
502 failures: n.failures,
503 },
504 )
505 })
506 .collect()
507 }
508}
509
510#[derive(Debug, Clone)]
512pub struct NodeStats {
513 pub health: NodeHealth,
515 pub replication_lag_ms: u64,
517 pub connections: u64,
519 pub avg_latency_ms: f64,
521 pub requests: u64,
523 pub failures: u64,
525}
526
527impl NodeStats {
528 pub fn is_healthy(&self) -> bool {
530 self.health == NodeHealth::Healthy
531 }
532
533 pub fn can_serve_reads(&self) -> bool {
535 self.health.can_serve_reads()
536 }
537}
538
#[cfg(test)]
mod tests {
    use super::*;

    /// Inserts a fresh, healthy entry for `endpoint` under `node_id` directly
    /// into the balancer's node table, so it is present before the test
    /// continues.
    async fn seed_node(lb: &LoadBalancer, node_id: NodeId, endpoint: NodeEndpoint) {
        let fresh = NodeState {
            endpoint,
            health: NodeHealth::Healthy,
            replication_lag_ms: 0,
            connections: 0,
            avg_latency_ms: 0.0,
            requests: 0,
            failures: 0,
        };
        lb.nodes.write().await.insert(node_id, fresh);
    }

    /// The default configuration uses round-robin reads, primary-only writes
    /// and read/write splitting.
    #[test]
    fn test_config_default() {
        let config = LoadBalancerConfig::default();
        assert_eq!(config.read_strategy, RoutingStrategy::RoundRobin);
        assert_eq!(config.write_strategy, RoutingStrategy::PrimaryOnly);
        assert!(config.read_write_split);
    }

    /// Setting a node unhealthy is reflected in its stats.
    #[tokio::test]
    async fn test_set_node_health() {
        let lb = LoadBalancer::new(LoadBalancerConfig::default());
        let node_id = NodeId::new();
        seed_node(
            &lb,
            node_id,
            NodeEndpoint::new("localhost", 5432).with_role(NodeRole::Primary),
        )
        .await;

        lb.set_node_health(&node_id, NodeHealth::Unhealthy).await;

        let stats = lb.node_stats(&node_id).await.unwrap();
        assert_eq!(stats.health, NodeHealth::Unhealthy);
        assert!(!stats.is_healthy());
    }

    /// A degraded node still serves reads but is not reported as healthy.
    #[tokio::test]
    async fn test_degraded_state() {
        let lb = LoadBalancer::new(LoadBalancerConfig::default());
        let node_id = NodeId::new();
        seed_node(
            &lb,
            node_id,
            NodeEndpoint::new("localhost", 5432).with_role(NodeRole::Standby),
        )
        .await;

        lb.set_node_health(&node_id, NodeHealth::Degraded).await;

        let stats = lb.node_stats(&node_id).await.unwrap();
        assert_eq!(stats.health, NodeHealth::Degraded);
        assert!(stats.can_serve_reads());
        assert!(!stats.is_healthy());
    }

    /// A latency sample moves the average away from zero.
    #[tokio::test]
    async fn test_update_latency() {
        let lb = LoadBalancer::new(LoadBalancerConfig::default());
        let node_id = NodeId::new();
        seed_node(&lb, node_id, NodeEndpoint::new("localhost", 5432)).await;

        lb.update_latency(&node_id, 50.0).await;

        let stats = lb.node_stats(&node_id).await.unwrap();
        assert!(stats.avg_latency_ms > 0.0);
    }

    /// Ten seconds of replication lag degrades a healthy standby.
    #[tokio::test]
    async fn test_replication_lag_degrades_health() {
        let lb = LoadBalancer::new(LoadBalancerConfig::default());
        let node_id = NodeId::new();
        seed_node(
            &lb,
            node_id,
            NodeEndpoint::new("localhost", 5432).with_role(NodeRole::Standby),
        )
        .await;

        lb.update_replication_lag(&node_id, 10000).await;

        let stats = lb.node_stats(&node_id).await.unwrap();
        assert_eq!(stats.health, NodeHealth::Degraded);
        assert_eq!(stats.replication_lag_ms, 10000);
    }

    /// Connection counts follow increments and decrements.
    #[tokio::test]
    async fn test_connection_tracking() {
        let lb = LoadBalancer::new(LoadBalancerConfig::default());
        let node_id = NodeId::new();
        seed_node(&lb, node_id, NodeEndpoint::new("localhost", 5432)).await;

        lb.increment_connections(&node_id).await;
        lb.increment_connections(&node_id).await;

        let stats = lb.node_stats(&node_id).await.unwrap();
        assert_eq!(stats.connections, 2);

        lb.decrement_connections(&node_id).await;
        let stats = lb.node_stats(&node_id).await.unwrap();
        assert_eq!(stats.connections, 1);
    }
}