Skip to main content

grapsus_proxy/upstream/
locality.rs

1//! Locality-aware load balancer
2//!
3//! Prefers targets in the same zone/region as the proxy, falling back to
4//! other zones when local targets are unhealthy or overloaded. Useful for
5//! multi-region deployments to minimize latency and cross-zone traffic costs.
6
7use async_trait::async_trait;
8use rand::seq::IndexedRandom;
9use std::collections::HashMap;
10use std::sync::atomic::{AtomicUsize, Ordering};
11use std::sync::Arc;
12use tokio::sync::RwLock;
13use tracing::{debug, trace, warn};
14
15use grapsus_common::errors::{GrapsusError, GrapsusResult};
16
17use super::{LoadBalancer, RequestContext, TargetSelection, UpstreamTarget};
18
19/// Configuration for locality-aware load balancing
20#[derive(Debug, Clone)]
21pub struct LocalityAwareConfig {
22    /// The local zone/region identifier for this proxy instance
23    pub local_zone: String,
24    /// Fallback strategy when no local targets are healthy
25    pub fallback_strategy: LocalityFallback,
26    /// Minimum healthy local targets before considering fallback
27    pub min_local_healthy: usize,
28    /// Whether to use weighted selection within a zone
29    pub use_weights: bool,
30    /// Zone priority order for fallback (closest first)
31    /// If empty, all non-local zones are treated equally
32    pub zone_priority: Vec<String>,
33}
34
35impl Default for LocalityAwareConfig {
36    fn default() -> Self {
37        Self {
38            local_zone: std::env::var("GRAPSUS_ZONE")
39                .or_else(|_| std::env::var("ZONE"))
40                .or_else(|_| std::env::var("REGION"))
41                .unwrap_or_else(|_| "default".to_string()),
42            fallback_strategy: LocalityFallback::RoundRobin,
43            min_local_healthy: 1,
44            use_weights: true,
45            zone_priority: Vec::new(),
46        }
47    }
48}
49
/// How the balancer behaves once the local zone cannot serve a request alone.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum LocalityFallback {
    /// Rotate through the fallback candidates in order.
    RoundRobin,
    /// Pick a fallback candidate at random.
    Random,
    /// Never route outside the local zone; selection fails instead of
    /// falling back.
    FailLocal,
}
60
61/// Target with zone information
62#[derive(Debug, Clone)]
63struct ZonedTarget {
64    target: UpstreamTarget,
65    zone: String,
66}
67
68/// Locality-aware load balancer
69pub struct LocalityAwareBalancer {
70    /// All targets with zone information
71    targets: Vec<ZonedTarget>,
72    /// Health status per target address
73    health_status: Arc<RwLock<HashMap<String, bool>>>,
74    /// Round-robin counter for local zone
75    local_counter: AtomicUsize,
76    /// Round-robin counter for fallback
77    fallback_counter: AtomicUsize,
78    /// Configuration
79    config: LocalityAwareConfig,
80}
81
82impl LocalityAwareBalancer {
83    /// Create a new locality-aware balancer
84    ///
85    /// Zone information is extracted from target addresses using the format:
86    /// - `zone:host:port` - explicit zone prefix
87    /// - Or via target metadata (weight field encodes zone in high bits)
88    /// - Or defaults to "unknown" zone
89    pub fn new(targets: Vec<UpstreamTarget>, config: LocalityAwareConfig) -> Self {
90        let mut health_status = HashMap::new();
91        let mut zoned_targets = Vec::with_capacity(targets.len());
92
93        for target in targets {
94            health_status.insert(target.full_address(), true);
95
96            // Extract zone from address if it contains zone prefix
97            // Format: "zone:host:port" or just "host:port"
98            let (zone, actual_target) = Self::parse_zone_from_target(&target);
99
100            zoned_targets.push(ZonedTarget {
101                target: actual_target,
102                zone,
103            });
104        }
105
106        debug!(
107            local_zone = %config.local_zone,
108            total_targets = zoned_targets.len(),
109            local_targets = zoned_targets.iter().filter(|t| t.zone == config.local_zone).count(),
110            "Created locality-aware balancer"
111        );
112
113        Self {
114            targets: zoned_targets,
115            health_status: Arc::new(RwLock::new(health_status)),
116            local_counter: AtomicUsize::new(0),
117            fallback_counter: AtomicUsize::new(0),
118            config,
119        }
120    }
121
122    /// Parse zone from target address
123    ///
124    /// Supports formats:
125    /// - `zone=us-west-1,host:port` - zone in metadata prefix
126    /// - `us-west-1/host:port` - zone as path prefix
127    /// - `host:port` - no zone, defaults to "unknown"
128    fn parse_zone_from_target(target: &UpstreamTarget) -> (String, UpstreamTarget) {
129        let addr = &target.address;
130
131        // Check for zone= prefix (e.g., "zone=us-west-1,10.0.0.1")
132        if let Some(rest) = addr.strip_prefix("zone=") {
133            if let Some((zone, host)) = rest.split_once(',') {
134                return (
135                    zone.to_string(),
136                    UpstreamTarget::new(host, target.port, target.weight),
137                );
138            }
139        }
140
141        // Check for zone/ prefix (e.g., "us-west-1/10.0.0.1")
142        if let Some((zone, host)) = addr.split_once('/') {
143            // Ensure it's not an IP with port
144            if !zone.contains(':') && !zone.contains('.') {
145                return (
146                    zone.to_string(),
147                    UpstreamTarget::new(host, target.port, target.weight),
148                );
149            }
150        }
151
152        // No zone prefix, return as-is with unknown zone
153        ("unknown".to_string(), target.clone())
154    }
155
156    /// Get healthy targets in a specific zone
157    async fn healthy_in_zone(&self, zone: &str) -> Vec<&ZonedTarget> {
158        let health = self.health_status.read().await;
159        self.targets
160            .iter()
161            .filter(|t| t.zone == zone && *health.get(&t.target.full_address()).unwrap_or(&true))
162            .collect()
163    }
164
165    /// Get all healthy targets not in the local zone, sorted by priority
166    async fn healthy_fallback(&self) -> Vec<&ZonedTarget> {
167        let health = self.health_status.read().await;
168        let local_zone = &self.config.local_zone;
169
170        let mut fallback: Vec<_> = self
171            .targets
172            .iter()
173            .filter(|t| {
174                t.zone != *local_zone && *health.get(&t.target.full_address()).unwrap_or(&true)
175            })
176            .collect();
177
178        // Sort by zone priority if specified
179        if !self.config.zone_priority.is_empty() {
180            fallback.sort_by(|a, b| {
181                let priority_a = self
182                    .config
183                    .zone_priority
184                    .iter()
185                    .position(|z| z == &a.zone)
186                    .unwrap_or(usize::MAX);
187                let priority_b = self
188                    .config
189                    .zone_priority
190                    .iter()
191                    .position(|z| z == &b.zone)
192                    .unwrap_or(usize::MAX);
193                priority_a.cmp(&priority_b)
194            });
195        }
196
197        fallback
198    }
199
200    /// Select from targets using round-robin
201    fn select_round_robin<'a>(
202        &self,
203        targets: &[&'a ZonedTarget],
204        counter: &AtomicUsize,
205    ) -> Option<&'a ZonedTarget> {
206        if targets.is_empty() {
207            return None;
208        }
209
210        if self.config.use_weights {
211            // Weighted round-robin
212            let total_weight: u32 = targets.iter().map(|t| t.target.weight).sum();
213            if total_weight == 0 {
214                return targets.first().copied();
215            }
216
217            let idx = counter.fetch_add(1, Ordering::Relaxed);
218            let mut weight_idx = (idx as u32) % total_weight;
219
220            for target in targets {
221                if weight_idx < target.target.weight {
222                    return Some(target);
223                }
224                weight_idx -= target.target.weight;
225            }
226
227            targets.first().copied()
228        } else {
229            let idx = counter.fetch_add(1, Ordering::Relaxed) % targets.len();
230            Some(targets[idx])
231        }
232    }
233
234    /// Select from targets using random selection
235    fn select_random<'a>(&self, targets: &[&'a ZonedTarget]) -> Option<&'a ZonedTarget> {
236        use rand::seq::SliceRandom;
237
238        if targets.is_empty() {
239            return None;
240        }
241
242        let mut rng = rand::rng();
243        targets.choose(&mut rng).copied()
244    }
245}
246
247#[async_trait]
248impl LoadBalancer for LocalityAwareBalancer {
249    async fn select(&self, _context: Option<&RequestContext>) -> GrapsusResult<TargetSelection> {
250        trace!(
251            total_targets = self.targets.len(),
252            local_zone = %self.config.local_zone,
253            algorithm = "locality_aware",
254            "Selecting upstream target"
255        );
256
257        // First, try local zone
258        let local_healthy = self.healthy_in_zone(&self.config.local_zone).await;
259
260        if local_healthy.len() >= self.config.min_local_healthy {
261            // Use local targets
262            let selected = self
263                .select_round_robin(&local_healthy, &self.local_counter)
264                .ok_or(GrapsusError::NoHealthyUpstream)?;
265
266            trace!(
267                selected_target = %selected.target.full_address(),
268                zone = %selected.zone,
269                local_healthy = local_healthy.len(),
270                algorithm = "locality_aware",
271                "Selected local target"
272            );
273
274            return Ok(TargetSelection {
275                address: selected.target.full_address(),
276                weight: selected.target.weight,
277                metadata: {
278                    let mut m = HashMap::new();
279                    m.insert("zone".to_string(), selected.zone.clone());
280                    m.insert("locality".to_string(), "local".to_string());
281                    m
282                },
283            });
284        }
285
286        // Not enough local targets, check fallback strategy
287        match self.config.fallback_strategy {
288            LocalityFallback::FailLocal => {
289                warn!(
290                    local_zone = %self.config.local_zone,
291                    local_healthy = local_healthy.len(),
292                    min_required = self.config.min_local_healthy,
293                    algorithm = "locality_aware",
294                    "No healthy local targets and fallback disabled"
295                );
296                return Err(GrapsusError::NoHealthyUpstream);
297            }
298            LocalityFallback::RoundRobin | LocalityFallback::Random => {
299                // Fall back to remote zones
300            }
301        }
302
303        // Get fallback targets (sorted by zone priority)
304        let fallback_targets = self.healthy_fallback().await;
305
306        // If we have some local targets, combine them with fallback
307        let all_targets: Vec<&ZonedTarget> = if !local_healthy.is_empty() {
308            // Local first, then fallback
309            local_healthy.into_iter().chain(fallback_targets).collect()
310        } else {
311            fallback_targets
312        };
313
314        if all_targets.is_empty() {
315            warn!(
316                total_targets = self.targets.len(),
317                algorithm = "locality_aware",
318                "No healthy upstream targets available"
319            );
320            return Err(GrapsusError::NoHealthyUpstream);
321        }
322
323        // Select based on fallback strategy
324        let selected = match self.config.fallback_strategy {
325            LocalityFallback::RoundRobin => {
326                self.select_round_robin(&all_targets, &self.fallback_counter)
327            }
328            LocalityFallback::Random => self.select_random(&all_targets),
329            LocalityFallback::FailLocal => unreachable!(),
330        }
331        .ok_or(GrapsusError::NoHealthyUpstream)?;
332
333        let is_local = selected.zone == self.config.local_zone;
334        debug!(
335            selected_target = %selected.target.full_address(),
336            zone = %selected.zone,
337            is_local = is_local,
338            fallback_used = !is_local,
339            algorithm = "locality_aware",
340            "Selected target (fallback path)"
341        );
342
343        Ok(TargetSelection {
344            address: selected.target.full_address(),
345            weight: selected.target.weight,
346            metadata: {
347                let mut m = HashMap::new();
348                m.insert("zone".to_string(), selected.zone.clone());
349                m.insert(
350                    "locality".to_string(),
351                    if is_local { "local" } else { "remote" }.to_string(),
352                );
353                m
354            },
355        })
356    }
357
358    async fn report_health(&self, address: &str, healthy: bool) {
359        trace!(
360            target = %address,
361            healthy = healthy,
362            algorithm = "locality_aware",
363            "Updating target health status"
364        );
365        self.health_status
366            .write()
367            .await
368            .insert(address.to_string(), healthy);
369    }
370
371    async fn healthy_targets(&self) -> Vec<String> {
372        self.health_status
373            .read()
374            .await
375            .iter()
376            .filter_map(|(addr, &healthy)| if healthy { Some(addr.clone()) } else { None })
377            .collect()
378    }
379}
380
#[cfg(test)]
mod tests {
    use super::*;

    /// Zone treated as local by every balancer built in these tests.
    const LOCAL_ZONE: &str = "us-west-1";

    /// Five equally weighted targets on port 8080:
    /// two in us-west-1 (10.0.0.x), two in us-east-1 (10.1.0.x) and one in
    /// eu-west-1 (10.2.0.x).
    fn make_zoned_targets() -> Vec<UpstreamTarget> {
        [
            "zone=us-west-1,10.0.0.1",
            "zone=us-west-1,10.0.0.2",
            "zone=us-east-1,10.1.0.1",
            "zone=us-east-1,10.1.0.2",
            "zone=eu-west-1,10.2.0.1",
        ]
        .into_iter()
        .map(|addr| UpstreamTarget::new(addr, 8080, 100))
        .collect()
    }

    /// Build a balancer over `make_zoned_targets` with the given config.
    fn balancer_with(config: LocalityAwareConfig) -> LocalityAwareBalancer {
        LocalityAwareBalancer::new(make_zoned_targets(), config)
    }

    /// Report both local (us-west-1) targets as unhealthy.
    async fn mark_local_unhealthy(balancer: &LocalityAwareBalancer) {
        for addr in ["10.0.0.1:8080", "10.0.0.2:8080"] {
            balancer.report_health(addr, false).await;
        }
    }

    #[test]
    fn test_zone_parsing() {
        let cases = [
            ("zone=us-west-1,10.0.0.1", "us-west-1"),
            ("us-east-1/10.0.0.1", "us-east-1"),
            ("10.0.0.1", "unknown"),
        ];

        for (raw, expected_zone) in cases {
            let target = UpstreamTarget::new(raw, 8080, 100);
            let (zone, parsed) = LocalityAwareBalancer::parse_zone_from_target(&target);
            assert_eq!(zone, expected_zone, "zone parsed from {raw}");
            assert_eq!(parsed.address, "10.0.0.1", "address parsed from {raw}");
        }
    }

    #[tokio::test]
    async fn test_prefers_local_zone() {
        let balancer = balancer_with(LocalityAwareConfig {
            local_zone: LOCAL_ZONE.to_string(),
            ..Default::default()
        });

        for _ in 0..10 {
            let selection = balancer.select(None).await.unwrap();
            assert!(
                selection.address.starts_with("10.0.0."),
                "Expected local target, got {}",
                selection.address
            );
            assert_eq!(
                selection.metadata.get("locality").map(String::as_str),
                Some("local")
            );
        }
    }

    #[tokio::test]
    async fn test_fallback_when_local_unhealthy() {
        let balancer = balancer_with(LocalityAwareConfig {
            local_zone: LOCAL_ZONE.to_string(),
            min_local_healthy: 1,
            ..Default::default()
        });
        mark_local_unhealthy(&balancer).await;

        let selection = balancer.select(None).await.unwrap();
        assert!(
            !selection.address.starts_with("10.0.0."),
            "Expected fallback target, got {}",
            selection.address
        );
        assert_eq!(
            selection.metadata.get("locality").map(String::as_str),
            Some("remote")
        );
    }

    #[tokio::test]
    async fn test_zone_priority() {
        let balancer = balancer_with(LocalityAwareConfig {
            local_zone: LOCAL_ZONE.to_string(),
            min_local_healthy: 1,
            zone_priority: vec!["us-east-1".to_string(), "eu-west-1".to_string()],
            ..Default::default()
        });
        mark_local_unhealthy(&balancer).await;

        let selection = balancer.select(None).await.unwrap();
        assert!(
            selection.address.starts_with("10.1.0."),
            "Expected us-east-1 target, got {}",
            selection.address
        );
    }

    #[tokio::test]
    async fn test_fail_local_strategy() {
        let balancer = balancer_with(LocalityAwareConfig {
            local_zone: LOCAL_ZONE.to_string(),
            fallback_strategy: LocalityFallback::FailLocal,
            ..Default::default()
        });
        mark_local_unhealthy(&balancer).await;

        assert!(balancer.select(None).await.is_err());
    }
}