Skip to main content

heliosdb_proxy/
load_balancer.rs

1//! Load Balancer - HeliosProxy
2//!
3//! Intelligent request routing with read/write splitting,
4//! multiple routing strategies, and latency-aware selection.
5
6use super::{NodeEndpoint, NodeId, NodeRole, ProxyError, Result};
7use std::collections::HashMap;
8use std::sync::atomic::{AtomicU64, Ordering};
9use std::sync::Arc;
10use tokio::sync::RwLock;
11
12/// Load balancer configuration
13#[derive(Debug, Clone)]
14pub struct LoadBalancerConfig {
15    /// Routing strategy for read queries
16    pub read_strategy: RoutingStrategy,
17    /// Routing strategy for write queries (usually Primary only)
18    pub write_strategy: RoutingStrategy,
19    /// Enable read/write splitting
20    pub read_write_split: bool,
21    /// Latency threshold for unhealthy marking (ms)
22    pub latency_threshold_ms: u64,
23    /// Minimum weight for a node to receive traffic
24    pub min_weight: u32,
25}
26
27impl Default for LoadBalancerConfig {
28    fn default() -> Self {
29        Self {
30            read_strategy: RoutingStrategy::RoundRobin,
31            write_strategy: RoutingStrategy::PrimaryOnly,
32            read_write_split: true,
33            latency_threshold_ms: 100,
34            min_weight: 1,
35        }
36    }
37}
38
/// How the load balancer picks one node out of the eligible candidates.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RoutingStrategy {
    /// Always pick the primary node (intended for writes).
    PrimaryOnly,
    /// Cycle through every eligible node in turn.
    RoundRobin,
    /// Cycle through eligible nodes in proportion to their weights.
    WeightedRoundRobin,
    /// Pick the node with the fewest connections.
    LeastConnections,
    /// Pick the node with the lowest latency.
    LatencyBased,
    /// Pick a node at random.
    Random,
    /// Favor a local node (same rack/zone).
    PreferLocal,
}
57
/// Health of a node as seen by the load balancer.
///
/// Besides the two extremes (`Healthy` / `Unhealthy`) there are two
/// intermediate states, so that failover can degrade service gradually
/// instead of flipping a node straight from "in rotation" to "out".
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NodeHealth {
    /// Normal operation: the node can take any traffic.
    Healthy,
    /// High latency or replication lag; the node remains usable for reads.
    Degraded,
    /// A failover is in progress on this node; new requests should be held.
    Transitioning,
    /// The node is down or unreachable and must not receive traffic.
    Unhealthy,
}
74
75impl NodeHealth {
76    /// Check if node can serve read requests
77    pub fn can_serve_reads(&self) -> bool {
78        matches!(self, NodeHealth::Healthy | NodeHealth::Degraded)
79    }
80
81    /// Check if node can serve write requests
82    pub fn can_serve_writes(&self) -> bool {
83        matches!(self, NodeHealth::Healthy)
84    }
85
86    /// Check if node is in a usable state
87    pub fn is_usable(&self) -> bool {
88        !matches!(self, NodeHealth::Unhealthy)
89    }
90}
91
92impl Default for NodeHealth {
93    fn default() -> Self {
94        NodeHealth::Healthy
95    }
96}
97
98/// Node state for load balancing
99#[derive(Debug, Clone)]
100struct NodeState {
101    /// Node endpoint
102    endpoint: NodeEndpoint,
103    /// Node health state (supports degraded/transitioning states)
104    health: NodeHealth,
105    /// Replication lag in milliseconds (for standby nodes)
106    replication_lag_ms: u64,
107    /// Current connection count
108    connections: u64,
109    /// Average latency (ms)
110    avg_latency_ms: f64,
111    /// Requests routed to this node
112    requests: u64,
113    /// Request failures
114    failures: u64,
115}
116
117/// Load Balancer
118pub struct LoadBalancer {
119    /// Configuration
120    config: LoadBalancerConfig,
121    /// Node states
122    nodes: Arc<RwLock<HashMap<NodeId, NodeState>>>,
123    /// Round-robin counter
124    rr_counter: AtomicU64,
125    /// Total requests routed
126    total_requests: AtomicU64,
127}
128
129impl LoadBalancer {
130    /// Create a new load balancer
131    pub fn new(config: LoadBalancerConfig) -> Self {
132        Self {
133            config,
134            nodes: Arc::new(RwLock::new(HashMap::new())),
135            rr_counter: AtomicU64::new(0),
136            total_requests: AtomicU64::new(0),
137        }
138    }
139
140    /// Add a node to the load balancer
141    pub fn add_node(&mut self, endpoint: NodeEndpoint) {
142        let node_id = endpoint.id;
143        let state = NodeState {
144            endpoint,
145            health: NodeHealth::Healthy,
146            replication_lag_ms: 0,
147            connections: 0,
148            avg_latency_ms: 0.0,
149            requests: 0,
150            failures: 0,
151        };
152
153        // Use blocking lock for simplicity in sync context
154        // In production, this should be async
155        let nodes = self.nodes.clone();
156        tokio::spawn(async move {
157            nodes.write().await.insert(node_id, state);
158        });
159    }
160
161    /// Remove a node from the load balancer
162    pub fn remove_node(&mut self, node_id: &NodeId) {
163        let id = *node_id;
164        let nodes = self.nodes.clone();
165        tokio::spawn(async move {
166            nodes.write().await.remove(&id);
167        });
168    }
169
170    /// Select a node for a read query
171    pub fn select_for_read(&self) -> Result<NodeEndpoint> {
172        self.total_requests.fetch_add(1, Ordering::SeqCst);
173
174        // Use blocking for sync compatibility
175        let rt = tokio::runtime::Handle::try_current();
176        let nodes_guard = match rt {
177            Ok(handle) => {
178                handle.block_on(async { self.nodes.read().await })
179            }
180            Err(_) => {
181                // Fallback: return error if no runtime
182                return Err(ProxyError::Routing("No async runtime available".to_string()));
183            }
184        };
185
186        // First, filter for healthy or degraded nodes (can serve reads)
187        let mut eligible: Vec<_> = nodes_guard
188            .values()
189            .filter(|n| n.health.can_serve_reads() && n.endpoint.enabled)
190            .filter(|n| {
191                self.config.read_write_split
192                    || n.endpoint.role == NodeRole::Primary
193                    || n.endpoint.role == NodeRole::Standby
194                    || n.endpoint.role == NodeRole::ReadReplica
195            })
196            .collect();
197
198        // If no healthy/degraded nodes, try transitioning nodes as last resort
199        if eligible.is_empty() {
200            eligible = nodes_guard
201                .values()
202                .filter(|n| n.health == NodeHealth::Transitioning && n.endpoint.enabled)
203                .collect();
204        }
205
206        if eligible.is_empty() {
207            return Err(ProxyError::NoHealthyNodes);
208        }
209
210        // Sort by health preference: Healthy first, then Degraded, then Transitioning
211        eligible.sort_by_key(|n| match n.health {
212            NodeHealth::Healthy => 0,
213            NodeHealth::Degraded => 1,
214            NodeHealth::Transitioning => 2,
215            NodeHealth::Unhealthy => 3,
216        });
217
218        let selected = self.select_by_strategy(&eligible, self.config.read_strategy)?;
219        Ok(selected.endpoint.clone())
220    }
221
222    /// Select a node for a write query
223    pub fn select_for_write(&self) -> Result<NodeEndpoint> {
224        self.total_requests.fetch_add(1, Ordering::SeqCst);
225
226        let rt = tokio::runtime::Handle::try_current();
227        let nodes_guard = match rt {
228            Ok(handle) => {
229                handle.block_on(async { self.nodes.read().await })
230            }
231            Err(_) => {
232                return Err(ProxyError::Routing("No async runtime available".to_string()));
233            }
234        };
235
236        // For writes, require fully healthy primary (not degraded)
237        let primary = nodes_guard
238            .values()
239            .find(|n| n.endpoint.role == NodeRole::Primary && n.health.can_serve_writes() && n.endpoint.enabled);
240
241        match primary {
242            Some(node) => Ok(node.endpoint.clone()),
243            None => Err(ProxyError::NoHealthyNodes),
244        }
245    }
246
247    /// Select by strategy
248    fn select_by_strategy<'a>(
249        &self,
250        nodes: &[&'a NodeState],
251        strategy: RoutingStrategy,
252    ) -> Result<&'a NodeState> {
253        match strategy {
254            RoutingStrategy::PrimaryOnly => {
255                nodes
256                    .iter()
257                    .find(|n| n.endpoint.role == NodeRole::Primary)
258                    .copied()
259                    .ok_or(ProxyError::NoHealthyNodes)
260            }
261            RoutingStrategy::RoundRobin => {
262                let idx = self.rr_counter.fetch_add(1, Ordering::SeqCst) as usize;
263                Ok(nodes[idx % nodes.len()])
264            }
265            RoutingStrategy::WeightedRoundRobin => {
266                // Simplified weighted selection
267                let total_weight: u32 = nodes.iter().map(|n| n.endpoint.weight).sum();
268                if total_weight == 0 {
269                    return Err(ProxyError::NoHealthyNodes);
270                }
271
272                let idx = self.rr_counter.fetch_add(1, Ordering::SeqCst);
273                let mut target = (idx % total_weight as u64) as u32;
274
275                for node in nodes {
276                    if target < node.endpoint.weight {
277                        return Ok(node);
278                    }
279                    target -= node.endpoint.weight;
280                }
281
282                Ok(nodes[0])
283            }
284            RoutingStrategy::LeastConnections => {
285                nodes
286                    .iter()
287                    .min_by_key(|n| n.connections)
288                    .copied()
289                    .ok_or(ProxyError::NoHealthyNodes)
290            }
291            RoutingStrategy::LatencyBased => {
292                nodes
293                    .iter()
294                    .min_by(|a, b| {
295                        a.avg_latency_ms
296                            .partial_cmp(&b.avg_latency_ms)
297                            .unwrap_or(std::cmp::Ordering::Equal)
298                    })
299                    .copied()
300                    .ok_or(ProxyError::NoHealthyNodes)
301            }
302            RoutingStrategy::Random => {
303                use std::time::{SystemTime, UNIX_EPOCH};
304                let seed = SystemTime::now()
305                    .duration_since(UNIX_EPOCH)
306                    .unwrap()
307                    .as_nanos() as usize;
308                Ok(nodes[seed % nodes.len()])
309            }
310            RoutingStrategy::PreferLocal => {
311                // For skeleton, just return first node
312                // In production, would check rack/zone affinity
313                nodes.first().copied().ok_or(ProxyError::NoHealthyNodes)
314            }
315        }
316    }
317
318    /// Set node health state
319    ///
320    /// Supports granular health states for graceful degradation:
321    /// - Healthy: Normal operation
322    /// - Degraded: High latency/lag but still usable for reads
323    /// - Transitioning: Failover in progress
324    /// - Unhealthy: Do not route traffic
325    pub async fn set_node_health(&self, node_id: &NodeId, health: NodeHealth) {
326        if let Some(node) = self.nodes.write().await.get_mut(node_id) {
327            let old_health = node.health;
328            node.health = health;
329            tracing::debug!("Node {:?} health changed: {:?} -> {:?}", node_id, old_health, health);
330        }
331    }
332
333    /// Legacy method for backward compatibility
334    pub async fn set_node_healthy(&self, node_id: &NodeId, healthy: bool) {
335        let health = if healthy { NodeHealth::Healthy } else { NodeHealth::Unhealthy };
336        self.set_node_health(node_id, health).await;
337    }
338
339    /// Mark node as transitioning (failover in progress)
340    pub async fn set_node_transitioning(&self, node_id: &NodeId) {
341        self.set_node_health(node_id, NodeHealth::Transitioning).await;
342    }
343
344    /// Update node latency and adjust health state accordingly
345    pub async fn update_latency(&self, node_id: &NodeId, latency_ms: f64) {
346        if let Some(node) = self.nodes.write().await.get_mut(node_id) {
347            // Exponential moving average
348            let alpha = 0.2;
349            node.avg_latency_ms = alpha * latency_ms + (1.0 - alpha) * node.avg_latency_ms;
350
351            // Adjust health based on latency thresholds
352            let threshold = self.config.latency_threshold_ms as f64;
353            let degraded_threshold = threshold * 0.7; // 70% of threshold = degraded
354
355            // Only adjust health if not transitioning (preserve failover state)
356            if node.health != NodeHealth::Transitioning {
357                if latency_ms > threshold {
358                    node.health = NodeHealth::Unhealthy;
359                    tracing::warn!(
360                        "Node {:?} marked unhealthy due to high latency: {}ms",
361                        node_id,
362                        latency_ms
363                    );
364                } else if latency_ms > degraded_threshold {
365                    node.health = NodeHealth::Degraded;
366                    tracing::debug!(
367                        "Node {:?} marked degraded due to elevated latency: {}ms",
368                        node_id,
369                        latency_ms
370                    );
371                } else if node.health == NodeHealth::Degraded || node.health == NodeHealth::Unhealthy {
372                    // Recovery: if latency is back to normal, restore to healthy
373                    node.health = NodeHealth::Healthy;
374                    tracing::info!("Node {:?} recovered, marked healthy", node_id);
375                }
376            }
377        }
378    }
379
380    /// Update node replication lag and adjust health state
381    pub async fn update_replication_lag(&self, node_id: &NodeId, lag_ms: u64) {
382        // Thresholds for replication lag (configurable in production)
383        const DEGRADED_LAG_MS: u64 = 5000;   // 5 seconds = degraded
384        const UNHEALTHY_LAG_MS: u64 = 30000; // 30 seconds = unhealthy
385
386        if let Some(node) = self.nodes.write().await.get_mut(node_id) {
387            node.replication_lag_ms = lag_ms;
388
389            // Only adjust health if not transitioning
390            if node.health != NodeHealth::Transitioning {
391                if lag_ms > UNHEALTHY_LAG_MS {
392                    node.health = NodeHealth::Unhealthy;
393                    tracing::warn!(
394                        "Node {:?} marked unhealthy due to high replication lag: {}ms",
395                        node_id,
396                        lag_ms
397                    );
398                } else if lag_ms > DEGRADED_LAG_MS {
399                    node.health = NodeHealth::Degraded;
400                    tracing::debug!(
401                        "Node {:?} marked degraded due to replication lag: {}ms",
402                        node_id,
403                        lag_ms
404                    );
405                } else if node.health == NodeHealth::Degraded && node.avg_latency_ms < self.config.latency_threshold_ms as f64 * 0.7 {
406                    // Recovery: lag is acceptable and latency is good
407                    node.health = NodeHealth::Healthy;
408                    tracing::info!("Node {:?} recovered from lag, marked healthy", node_id);
409                }
410            }
411        }
412    }
413
414    /// Update node health based on combined metrics
415    pub async fn update_node_metrics(&self, node_id: &NodeId, latency_ms: f64, replication_lag_ms: u64, failure_rate: f64) {
416        if let Some(node) = self.nodes.write().await.get_mut(node_id) {
417            // Update metrics
418            node.avg_latency_ms = 0.2 * latency_ms + 0.8 * node.avg_latency_ms;
419            node.replication_lag_ms = replication_lag_ms;
420
421            // Only adjust health if not transitioning
422            if node.health != NodeHealth::Transitioning {
423                // Determine health based on all factors
424                let new_health = if !Self::is_responsive(latency_ms) {
425                    NodeHealth::Unhealthy
426                } else if replication_lag_ms > 30000 {
427                    NodeHealth::Unhealthy
428                } else if replication_lag_ms > 5000 || failure_rate > 0.5 || latency_ms > self.config.latency_threshold_ms as f64 {
429                    NodeHealth::Degraded
430                } else {
431                    NodeHealth::Healthy
432                };
433
434                if new_health != node.health {
435                    tracing::debug!("Node {:?} health: {:?} -> {:?}", node_id, node.health, new_health);
436                    node.health = new_health;
437                }
438            }
439        }
440    }
441
442    /// Check if latency indicates node is responsive
443    fn is_responsive(latency_ms: f64) -> bool {
444        // Consider non-responsive if latency exceeds 5 seconds or is negative (timeout)
445        latency_ms >= 0.0 && latency_ms < 5000.0
446    }
447
448    /// Increment connection count for a node
449    pub async fn increment_connections(&self, node_id: &NodeId) {
450        if let Some(node) = self.nodes.write().await.get_mut(node_id) {
451            node.connections += 1;
452            node.requests += 1;
453        }
454    }
455
456    /// Decrement connection count for a node
457    pub async fn decrement_connections(&self, node_id: &NodeId) {
458        if let Some(node) = self.nodes.write().await.get_mut(node_id) {
459            node.connections = node.connections.saturating_sub(1);
460        }
461    }
462
463    /// Record a failure for a node
464    pub async fn record_failure(&self, node_id: &NodeId) {
465        if let Some(node) = self.nodes.write().await.get_mut(node_id) {
466            node.failures += 1;
467        }
468    }
469
470    /// Get total requests routed
471    pub fn requests_routed(&self) -> u64 {
472        self.total_requests.load(Ordering::SeqCst)
473    }
474
475    /// Get node statistics
476    pub async fn node_stats(&self, node_id: &NodeId) -> Option<NodeStats> {
477        self.nodes.read().await.get(node_id).map(|n| NodeStats {
478            health: n.health,
479            replication_lag_ms: n.replication_lag_ms,
480            connections: n.connections,
481            avg_latency_ms: n.avg_latency_ms,
482            requests: n.requests,
483            failures: n.failures,
484        })
485    }
486
487    /// Get all node statistics
488    pub async fn all_stats(&self) -> HashMap<NodeId, NodeStats> {
489        self.nodes
490            .read()
491            .await
492            .iter()
493            .map(|(id, n)| {
494                (
495                    *id,
496                    NodeStats {
497                        health: n.health,
498                        replication_lag_ms: n.replication_lag_ms,
499                        connections: n.connections,
500                        avg_latency_ms: n.avg_latency_ms,
501                        requests: n.requests,
502                        failures: n.failures,
503                    },
504                )
505            })
506            .collect()
507    }
508}
509
510/// Node statistics
511#[derive(Debug, Clone)]
512pub struct NodeStats {
513    /// Node health state
514    pub health: NodeHealth,
515    /// Replication lag (ms)
516    pub replication_lag_ms: u64,
517    /// Current connections
518    pub connections: u64,
519    /// Average latency (ms)
520    pub avg_latency_ms: f64,
521    /// Total requests
522    pub requests: u64,
523    /// Total failures
524    pub failures: u64,
525}
526
527impl NodeStats {
528    /// Check if node is healthy (backward compatibility)
529    pub fn is_healthy(&self) -> bool {
530        self.health == NodeHealth::Healthy
531    }
532
533    /// Check if node can serve reads
534    pub fn can_serve_reads(&self) -> bool {
535        self.health.can_serve_reads()
536    }
537}
538
#[cfg(test)]
mod tests {
    use super::*;

    /// Load balancer built from the default configuration.
    fn default_balancer() -> LoadBalancer {
        LoadBalancer::new(LoadBalancerConfig::default())
    }

    /// Put a fresh, healthy node with zeroed metrics straight into the node table.
    async fn register(lb: &LoadBalancer, node_id: NodeId, endpoint: NodeEndpoint) {
        let state = NodeState {
            endpoint,
            health: NodeHealth::Healthy,
            replication_lag_ms: 0,
            connections: 0,
            avg_latency_ms: 0.0,
            requests: 0,
            failures: 0,
        };
        lb.nodes.write().await.insert(node_id, state);
    }

    #[test]
    fn test_config_default() {
        let config = LoadBalancerConfig::default();
        assert_eq!(config.read_strategy, RoutingStrategy::RoundRobin);
        assert_eq!(config.write_strategy, RoutingStrategy::PrimaryOnly);
        assert!(config.read_write_split);
    }

    #[tokio::test]
    async fn test_set_node_health() {
        let lb = default_balancer();
        let node_id = NodeId::new();
        let primary = NodeEndpoint::new("localhost", 5432).with_role(NodeRole::Primary);
        register(&lb, node_id, primary).await;

        lb.set_node_health(&node_id, NodeHealth::Unhealthy).await;

        let stats = lb.node_stats(&node_id).await.unwrap();
        assert_eq!(stats.health, NodeHealth::Unhealthy);
        assert!(!stats.is_healthy());
    }

    #[tokio::test]
    async fn test_degraded_state() {
        let lb = default_balancer();
        let node_id = NodeId::new();
        let standby = NodeEndpoint::new("localhost", 5432).with_role(NodeRole::Standby);
        register(&lb, node_id, standby).await;

        lb.set_node_health(&node_id, NodeHealth::Degraded).await;

        let stats = lb.node_stats(&node_id).await.unwrap();
        assert_eq!(stats.health, NodeHealth::Degraded);
        // A degraded node still serves reads but is not reported as fully healthy.
        assert!(stats.can_serve_reads());
        assert!(!stats.is_healthy());
    }

    #[tokio::test]
    async fn test_update_latency() {
        let lb = default_balancer();
        let node_id = NodeId::new();
        register(&lb, node_id, NodeEndpoint::new("localhost", 5432)).await;

        lb.update_latency(&node_id, 50.0).await;

        let stats = lb.node_stats(&node_id).await.unwrap();
        assert!(stats.avg_latency_ms > 0.0);
    }

    #[tokio::test]
    async fn test_replication_lag_degrades_health() {
        let lb = default_balancer();
        let node_id = NodeId::new();
        let standby = NodeEndpoint::new("localhost", 5432).with_role(NodeRole::Standby);
        register(&lb, node_id, standby).await;

        // Ten seconds of lag.
        lb.update_replication_lag(&node_id, 10000).await;

        let stats = lb.node_stats(&node_id).await.unwrap();
        assert_eq!(stats.health, NodeHealth::Degraded);
        assert_eq!(stats.replication_lag_ms, 10000);
    }

    #[tokio::test]
    async fn test_connection_tracking() {
        let lb = default_balancer();
        let node_id = NodeId::new();
        register(&lb, node_id, NodeEndpoint::new("localhost", 5432)).await;

        lb.increment_connections(&node_id).await;
        lb.increment_connections(&node_id).await;

        let stats = lb.node_stats(&node_id).await.unwrap();
        assert_eq!(stats.connections, 2);

        lb.decrement_connections(&node_id).await;
        let stats = lb.node_stats(&node_id).await.unwrap();
        assert_eq!(stats.connections, 1);
    }
}