1use async_trait::async_trait;
8use rand::seq::IndexedRandom;
9use std::collections::HashMap;
10use std::sync::atomic::{AtomicUsize, Ordering};
11use std::sync::Arc;
12use tokio::sync::RwLock;
13use tracing::{debug, trace, warn};
14
15use grapsus_common::errors::{GrapsusError, GrapsusResult};
16
17use super::{LoadBalancer, RequestContext, TargetSelection, UpstreamTarget};
18
/// Configuration for [`LocalityAwareBalancer`].
#[derive(Debug, Clone)]
pub struct LocalityAwareConfig {
    /// Zone treated as local; each target's parsed zone is compared against this value.
    pub local_zone: String,
    /// How to select a target when the local zone does not have enough healthy targets.
    pub fallback_strategy: LocalityFallback,
    /// Number of healthy local targets required before selection is restricted to the
    /// local zone only.
    pub min_local_healthy: usize,
    /// Whether round-robin selection honours target weights. Random selection ignores it.
    pub use_weights: bool,
    /// Preferred order of non-local zones in the fallback candidate list. Zones not
    /// listed are placed after the listed ones.
    pub zone_priority: Vec<String>,
}
34
35impl Default for LocalityAwareConfig {
36 fn default() -> Self {
37 Self {
38 local_zone: std::env::var("GRAPSUS_ZONE")
39 .or_else(|_| std::env::var("ZONE"))
40 .or_else(|_| std::env::var("REGION"))
41 .unwrap_or_else(|_| "default".to_string()),
42 fallback_strategy: LocalityFallback::RoundRobin,
43 min_local_healthy: 1,
44 use_weights: true,
45 zone_priority: Vec::new(),
46 }
47 }
48}
49
/// Behaviour of [`LocalityAwareBalancer`] when the local zone has fewer healthy targets
/// than `min_local_healthy`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LocalityFallback {
    /// Rotate over the healthy local and remote targets using a dedicated counter.
    RoundRobin,
    /// Pick uniformly at random among the healthy local and remote targets.
    Random,
    /// Never leave the local zone: selection fails with `NoHealthyUpstream`.
    FailLocal,
}
60
/// An upstream target paired with the zone parsed from its configured address.
#[derive(Debug, Clone)]
struct ZonedTarget {
    /// The target with any zone annotation stripped from its address.
    target: UpstreamTarget,
    /// Zone name, or `"unknown"` when the address carried no zone annotation.
    zone: String,
}
67
/// Load balancer that prefers targets in the local zone and optionally falls back to
/// other zones when too few local targets are healthy.
pub struct LocalityAwareBalancer {
    /// All configured targets with their parsed zones, in configuration order.
    targets: Vec<ZonedTarget>,
    /// Health flags keyed by target address; addresses with no entry are treated as
    /// healthy by the selection filters.
    health_status: Arc<RwLock<HashMap<String, bool>>>,
    /// Rotation counter for round-robin selection within the local zone.
    local_counter: AtomicUsize,
    /// Rotation counter for round-robin selection on the fallback path.
    fallback_counter: AtomicUsize,
    /// Balancer configuration.
    config: LocalityAwareConfig,
}
81
82impl LocalityAwareBalancer {
83 pub fn new(targets: Vec<UpstreamTarget>, config: LocalityAwareConfig) -> Self {
90 let mut health_status = HashMap::new();
91 let mut zoned_targets = Vec::with_capacity(targets.len());
92
93 for target in targets {
94 health_status.insert(target.full_address(), true);
95
96 let (zone, actual_target) = Self::parse_zone_from_target(&target);
99
100 zoned_targets.push(ZonedTarget {
101 target: actual_target,
102 zone,
103 });
104 }
105
106 debug!(
107 local_zone = %config.local_zone,
108 total_targets = zoned_targets.len(),
109 local_targets = zoned_targets.iter().filter(|t| t.zone == config.local_zone).count(),
110 "Created locality-aware balancer"
111 );
112
113 Self {
114 targets: zoned_targets,
115 health_status: Arc::new(RwLock::new(health_status)),
116 local_counter: AtomicUsize::new(0),
117 fallback_counter: AtomicUsize::new(0),
118 config,
119 }
120 }
121
122 fn parse_zone_from_target(target: &UpstreamTarget) -> (String, UpstreamTarget) {
129 let addr = &target.address;
130
131 if let Some(rest) = addr.strip_prefix("zone=") {
133 if let Some((zone, host)) = rest.split_once(',') {
134 return (
135 zone.to_string(),
136 UpstreamTarget::new(host, target.port, target.weight),
137 );
138 }
139 }
140
141 if let Some((zone, host)) = addr.split_once('/') {
143 if !zone.contains(':') && !zone.contains('.') {
145 return (
146 zone.to_string(),
147 UpstreamTarget::new(host, target.port, target.weight),
148 );
149 }
150 }
151
152 ("unknown".to_string(), target.clone())
154 }
155
156 async fn healthy_in_zone(&self, zone: &str) -> Vec<&ZonedTarget> {
158 let health = self.health_status.read().await;
159 self.targets
160 .iter()
161 .filter(|t| t.zone == zone && *health.get(&t.target.full_address()).unwrap_or(&true))
162 .collect()
163 }
164
165 async fn healthy_fallback(&self) -> Vec<&ZonedTarget> {
167 let health = self.health_status.read().await;
168 let local_zone = &self.config.local_zone;
169
170 let mut fallback: Vec<_> = self
171 .targets
172 .iter()
173 .filter(|t| {
174 t.zone != *local_zone && *health.get(&t.target.full_address()).unwrap_or(&true)
175 })
176 .collect();
177
178 if !self.config.zone_priority.is_empty() {
180 fallback.sort_by(|a, b| {
181 let priority_a = self
182 .config
183 .zone_priority
184 .iter()
185 .position(|z| z == &a.zone)
186 .unwrap_or(usize::MAX);
187 let priority_b = self
188 .config
189 .zone_priority
190 .iter()
191 .position(|z| z == &b.zone)
192 .unwrap_or(usize::MAX);
193 priority_a.cmp(&priority_b)
194 });
195 }
196
197 fallback
198 }
199
200 fn select_round_robin<'a>(
202 &self,
203 targets: &[&'a ZonedTarget],
204 counter: &AtomicUsize,
205 ) -> Option<&'a ZonedTarget> {
206 if targets.is_empty() {
207 return None;
208 }
209
210 if self.config.use_weights {
211 let total_weight: u32 = targets.iter().map(|t| t.target.weight).sum();
213 if total_weight == 0 {
214 return targets.first().copied();
215 }
216
217 let idx = counter.fetch_add(1, Ordering::Relaxed);
218 let mut weight_idx = (idx as u32) % total_weight;
219
220 for target in targets {
221 if weight_idx < target.target.weight {
222 return Some(target);
223 }
224 weight_idx -= target.target.weight;
225 }
226
227 targets.first().copied()
228 } else {
229 let idx = counter.fetch_add(1, Ordering::Relaxed) % targets.len();
230 Some(targets[idx])
231 }
232 }
233
234 fn select_random<'a>(&self, targets: &[&'a ZonedTarget]) -> Option<&'a ZonedTarget> {
236 use rand::seq::SliceRandom;
237
238 if targets.is_empty() {
239 return None;
240 }
241
242 let mut rng = rand::rng();
243 targets.choose(&mut rng).copied()
244 }
245}
246
247#[async_trait]
248impl LoadBalancer for LocalityAwareBalancer {
249 async fn select(&self, _context: Option<&RequestContext>) -> GrapsusResult<TargetSelection> {
250 trace!(
251 total_targets = self.targets.len(),
252 local_zone = %self.config.local_zone,
253 algorithm = "locality_aware",
254 "Selecting upstream target"
255 );
256
257 let local_healthy = self.healthy_in_zone(&self.config.local_zone).await;
259
260 if local_healthy.len() >= self.config.min_local_healthy {
261 let selected = self
263 .select_round_robin(&local_healthy, &self.local_counter)
264 .ok_or(GrapsusError::NoHealthyUpstream)?;
265
266 trace!(
267 selected_target = %selected.target.full_address(),
268 zone = %selected.zone,
269 local_healthy = local_healthy.len(),
270 algorithm = "locality_aware",
271 "Selected local target"
272 );
273
274 return Ok(TargetSelection {
275 address: selected.target.full_address(),
276 weight: selected.target.weight,
277 metadata: {
278 let mut m = HashMap::new();
279 m.insert("zone".to_string(), selected.zone.clone());
280 m.insert("locality".to_string(), "local".to_string());
281 m
282 },
283 });
284 }
285
286 match self.config.fallback_strategy {
288 LocalityFallback::FailLocal => {
289 warn!(
290 local_zone = %self.config.local_zone,
291 local_healthy = local_healthy.len(),
292 min_required = self.config.min_local_healthy,
293 algorithm = "locality_aware",
294 "No healthy local targets and fallback disabled"
295 );
296 return Err(GrapsusError::NoHealthyUpstream);
297 }
298 LocalityFallback::RoundRobin | LocalityFallback::Random => {
299 }
301 }
302
303 let fallback_targets = self.healthy_fallback().await;
305
306 let all_targets: Vec<&ZonedTarget> = if !local_healthy.is_empty() {
308 local_healthy.into_iter().chain(fallback_targets).collect()
310 } else {
311 fallback_targets
312 };
313
314 if all_targets.is_empty() {
315 warn!(
316 total_targets = self.targets.len(),
317 algorithm = "locality_aware",
318 "No healthy upstream targets available"
319 );
320 return Err(GrapsusError::NoHealthyUpstream);
321 }
322
323 let selected = match self.config.fallback_strategy {
325 LocalityFallback::RoundRobin => {
326 self.select_round_robin(&all_targets, &self.fallback_counter)
327 }
328 LocalityFallback::Random => self.select_random(&all_targets),
329 LocalityFallback::FailLocal => unreachable!(),
330 }
331 .ok_or(GrapsusError::NoHealthyUpstream)?;
332
333 let is_local = selected.zone == self.config.local_zone;
334 debug!(
335 selected_target = %selected.target.full_address(),
336 zone = %selected.zone,
337 is_local = is_local,
338 fallback_used = !is_local,
339 algorithm = "locality_aware",
340 "Selected target (fallback path)"
341 );
342
343 Ok(TargetSelection {
344 address: selected.target.full_address(),
345 weight: selected.target.weight,
346 metadata: {
347 let mut m = HashMap::new();
348 m.insert("zone".to_string(), selected.zone.clone());
349 m.insert(
350 "locality".to_string(),
351 if is_local { "local" } else { "remote" }.to_string(),
352 );
353 m
354 },
355 })
356 }
357
358 async fn report_health(&self, address: &str, healthy: bool) {
359 trace!(
360 target = %address,
361 healthy = healthy,
362 algorithm = "locality_aware",
363 "Updating target health status"
364 );
365 self.health_status
366 .write()
367 .await
368 .insert(address.to_string(), healthy);
369 }
370
371 async fn healthy_targets(&self) -> Vec<String> {
372 self.health_status
373 .read()
374 .await
375 .iter()
376 .filter_map(|(addr, &healthy)| if healthy { Some(addr.clone()) } else { None })
377 .collect()
378 }
379}
380
#[cfg(test)]
mod tests {
    use super::*;

    /// Five targets on port 8080 with weight 100 each:
    /// - two in `us-west-1` (10.0.0.x)
    /// - two in `us-east-1` (10.1.0.x)
    /// - one in `eu-west-1` (10.2.0.x)
    fn make_zoned_targets() -> Vec<UpstreamTarget> {
        let annotated = [
            "zone=us-west-1,10.0.0.1",
            "zone=us-west-1,10.0.0.2",
            "zone=us-east-1,10.1.0.1",
            "zone=us-east-1,10.1.0.2",
            "zone=eu-west-1,10.2.0.1",
        ];
        annotated
            .iter()
            .copied()
            .map(|addr| UpstreamTarget::new(addr, 8080, 100))
            .collect()
    }

    /// Builds a balancer over [`make_zoned_targets`] with the given configuration.
    fn balancer_with(config: LocalityAwareConfig) -> LocalityAwareBalancer {
        LocalityAwareBalancer::new(make_zoned_targets(), config)
    }

    /// Marks both `us-west-1` targets as unhealthy.
    async fn take_us_west_down(balancer: &LocalityAwareBalancer) {
        for addr in ["10.0.0.1:8080", "10.0.0.2:8080"] {
            balancer.report_health(addr, false).await;
        }
    }

    #[test]
    fn test_zone_parsing() {
        let cases = [
            ("zone=us-west-1,10.0.0.1", "us-west-1"),
            ("us-east-1/10.0.0.1", "us-east-1"),
            ("10.0.0.1", "unknown"),
        ];

        for (raw, expected_zone) in cases {
            let target = UpstreamTarget::new(raw, 8080, 100);
            let (zone, parsed) = LocalityAwareBalancer::parse_zone_from_target(&target);
            assert_eq!(zone, expected_zone, "zone parsed from {}", raw);
            assert_eq!(parsed.address, "10.0.0.1", "address parsed from {}", raw);
        }
    }

    #[tokio::test]
    async fn test_prefers_local_zone() {
        let balancer = balancer_with(LocalityAwareConfig {
            local_zone: "us-west-1".to_string(),
            ..Default::default()
        });

        for _ in 0..10 {
            let selection = balancer.select(None).await.unwrap();
            assert!(
                selection.address.starts_with("10.0.0."),
                "Expected local target, got {}",
                selection.address
            );
            assert_eq!(selection.metadata.get("locality").unwrap(), "local");
        }
    }

    #[tokio::test]
    async fn test_fallback_when_local_unhealthy() {
        let balancer = balancer_with(LocalityAwareConfig {
            local_zone: "us-west-1".to_string(),
            min_local_healthy: 1,
            ..Default::default()
        });
        take_us_west_down(&balancer).await;

        let selection = balancer.select(None).await.unwrap();

        assert!(
            !selection.address.starts_with("10.0.0."),
            "Expected fallback target, got {}",
            selection.address
        );
        assert_eq!(selection.metadata.get("locality").unwrap(), "remote");
    }

    #[tokio::test]
    async fn test_zone_priority() {
        let balancer = balancer_with(LocalityAwareConfig {
            local_zone: "us-west-1".to_string(),
            min_local_healthy: 1,
            zone_priority: vec!["us-east-1".to_string(), "eu-west-1".to_string()],
            ..Default::default()
        });
        take_us_west_down(&balancer).await;

        let selection = balancer.select(None).await.unwrap();

        assert!(
            selection.address.starts_with("10.1.0."),
            "Expected us-east-1 target, got {}",
            selection.address
        );
    }

    #[tokio::test]
    async fn test_fail_local_strategy() {
        let balancer = balancer_with(LocalityAwareConfig {
            local_zone: "us-west-1".to_string(),
            fallback_strategy: LocalityFallback::FailLocal,
            ..Default::default()
        });
        take_us_west_down(&balancer).await;

        assert!(balancer.select(None).await.is_err());
    }
}
506}