1use async_trait::async_trait;
8use std::collections::HashMap;
9use std::sync::atomic::{AtomicUsize, Ordering};
10use std::sync::Arc;
11use tokio::sync::RwLock;
12use tracing::{debug, trace, warn};
13
14use sentinel_common::errors::{SentinelError, SentinelResult};
15
16use super::{LoadBalancer, RequestContext, TargetSelection, UpstreamTarget};
17
/// Configuration for [`LocalityAwareBalancer`].
#[derive(Debug, Clone)]
pub struct LocalityAwareConfig {
    /// Zone label this instance treats as local; compared for equality against
    /// the zone parsed out of each target's address.
    pub local_zone: String,
    /// What `select` does when the local zone does not have enough healthy
    /// targets (see `min_local_healthy`).
    pub fallback_strategy: LocalityFallback,
    /// `select` serves exclusively from the local zone when the number of
    /// healthy local targets is at least this value.
    pub min_local_healthy: usize,
    /// When true, round-robin selection honours each target's `weight`;
    /// otherwise targets are cycled one at a time.
    pub use_weights: bool,
    /// Ordering of non-local zones in the fallback candidate list; zones not
    /// listed here sort after all listed ones. Empty means no reordering.
    pub zone_priority: Vec<String>,
}
33
34impl Default for LocalityAwareConfig {
35 fn default() -> Self {
36 Self {
37 local_zone: std::env::var("SENTINEL_ZONE")
38 .or_else(|_| std::env::var("ZONE"))
39 .or_else(|_| std::env::var("REGION"))
40 .unwrap_or_else(|_| "default".to_string()),
41 fallback_strategy: LocalityFallback::RoundRobin,
42 min_local_healthy: 1,
43 use_weights: true,
44 zone_priority: Vec::new(),
45 }
46 }
47}
48
/// Behaviour of `select` when the local zone has fewer healthy targets than
/// `LocalityAwareConfig::min_local_healthy`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LocalityFallback {
    /// Round-robin over the healthy local targets (if any) followed by the
    /// healthy remote targets, using a counter separate from local selection.
    RoundRobin,
    /// Uniformly random pick from the same combined candidate list.
    Random,
    /// Never leave the local zone; `select` returns `NoHealthyUpstream`.
    FailLocal,
}
59
/// An upstream target paired with the zone label parsed from its address.
#[derive(Debug, Clone)]
struct ZonedTarget {
    /// The target with any zone annotation stripped from its address.
    target: UpstreamTarget,
    /// Parsed zone label, or `"unknown"` when the address carried none.
    zone: String,
}
66
/// Load balancer that prefers targets in the local zone and falls back to
/// other zones according to [`LocalityAwareConfig`].
pub struct LocalityAwareBalancer {
    /// All configured targets with their zone labels, in construction order.
    targets: Vec<ZonedTarget>,
    /// Health flag per address, seeded in `new` and written by `report_health`.
    /// Addresses with no entry are treated as healthy by the selection helpers.
    health_status: Arc<RwLock<HashMap<String, bool>>>,
    /// Round-robin cursor for selection among local targets.
    local_counter: AtomicUsize,
    /// Round-robin cursor for the fallback (local + remote) candidate list.
    fallback_counter: AtomicUsize,
    /// Balancer configuration.
    config: LocalityAwareConfig,
}
80
81impl LocalityAwareBalancer {
82 pub fn new(targets: Vec<UpstreamTarget>, config: LocalityAwareConfig) -> Self {
89 let mut health_status = HashMap::new();
90 let mut zoned_targets = Vec::with_capacity(targets.len());
91
92 for target in targets {
93 health_status.insert(target.full_address(), true);
94
95 let (zone, actual_target) = Self::parse_zone_from_target(&target);
98
99 zoned_targets.push(ZonedTarget {
100 target: actual_target,
101 zone,
102 });
103 }
104
105 debug!(
106 local_zone = %config.local_zone,
107 total_targets = zoned_targets.len(),
108 local_targets = zoned_targets.iter().filter(|t| t.zone == config.local_zone).count(),
109 "Created locality-aware balancer"
110 );
111
112 Self {
113 targets: zoned_targets,
114 health_status: Arc::new(RwLock::new(health_status)),
115 local_counter: AtomicUsize::new(0),
116 fallback_counter: AtomicUsize::new(0),
117 config,
118 }
119 }
120
121 fn parse_zone_from_target(target: &UpstreamTarget) -> (String, UpstreamTarget) {
128 let addr = &target.address;
129
130 if let Some(rest) = addr.strip_prefix("zone=") {
132 if let Some((zone, host)) = rest.split_once(',') {
133 return (
134 zone.to_string(),
135 UpstreamTarget::new(host, target.port, target.weight),
136 );
137 }
138 }
139
140 if let Some((zone, host)) = addr.split_once('/') {
142 if !zone.contains(':') && !zone.contains('.') {
144 return (
145 zone.to_string(),
146 UpstreamTarget::new(host, target.port, target.weight),
147 );
148 }
149 }
150
151 ("unknown".to_string(), target.clone())
153 }
154
155 async fn healthy_in_zone(&self, zone: &str) -> Vec<&ZonedTarget> {
157 let health = self.health_status.read().await;
158 self.targets
159 .iter()
160 .filter(|t| {
161 t.zone == zone && *health.get(&t.target.full_address()).unwrap_or(&true)
162 })
163 .collect()
164 }
165
166 async fn healthy_fallback(&self) -> Vec<&ZonedTarget> {
168 let health = self.health_status.read().await;
169 let local_zone = &self.config.local_zone;
170
171 let mut fallback: Vec<_> = self
172 .targets
173 .iter()
174 .filter(|t| {
175 t.zone != *local_zone && *health.get(&t.target.full_address()).unwrap_or(&true)
176 })
177 .collect();
178
179 if !self.config.zone_priority.is_empty() {
181 fallback.sort_by(|a, b| {
182 let priority_a = self
183 .config
184 .zone_priority
185 .iter()
186 .position(|z| z == &a.zone)
187 .unwrap_or(usize::MAX);
188 let priority_b = self
189 .config
190 .zone_priority
191 .iter()
192 .position(|z| z == &b.zone)
193 .unwrap_or(usize::MAX);
194 priority_a.cmp(&priority_b)
195 });
196 }
197
198 fallback
199 }
200
201 fn select_round_robin<'a>(
203 &self,
204 targets: &[&'a ZonedTarget],
205 counter: &AtomicUsize,
206 ) -> Option<&'a ZonedTarget> {
207 if targets.is_empty() {
208 return None;
209 }
210
211 if self.config.use_weights {
212 let total_weight: u32 = targets.iter().map(|t| t.target.weight).sum();
214 if total_weight == 0 {
215 return targets.first().copied();
216 }
217
218 let idx = counter.fetch_add(1, Ordering::Relaxed);
219 let mut weight_idx = (idx as u32) % total_weight;
220
221 for target in targets {
222 if weight_idx < target.target.weight {
223 return Some(target);
224 }
225 weight_idx -= target.target.weight;
226 }
227
228 targets.first().copied()
229 } else {
230 let idx = counter.fetch_add(1, Ordering::Relaxed) % targets.len();
231 Some(targets[idx])
232 }
233 }
234
235 fn select_random<'a>(&self, targets: &[&'a ZonedTarget]) -> Option<&'a ZonedTarget> {
237 use rand::seq::SliceRandom;
238
239 if targets.is_empty() {
240 return None;
241 }
242
243 let mut rng = rand::thread_rng();
244 targets.choose(&mut rng).copied()
245 }
246}
247
248#[async_trait]
249impl LoadBalancer for LocalityAwareBalancer {
250 async fn select(&self, _context: Option<&RequestContext>) -> SentinelResult<TargetSelection> {
251 trace!(
252 total_targets = self.targets.len(),
253 local_zone = %self.config.local_zone,
254 algorithm = "locality_aware",
255 "Selecting upstream target"
256 );
257
258 let local_healthy = self.healthy_in_zone(&self.config.local_zone).await;
260
261 if local_healthy.len() >= self.config.min_local_healthy {
262 let selected = self
264 .select_round_robin(&local_healthy, &self.local_counter)
265 .ok_or(SentinelError::NoHealthyUpstream)?;
266
267 trace!(
268 selected_target = %selected.target.full_address(),
269 zone = %selected.zone,
270 local_healthy = local_healthy.len(),
271 algorithm = "locality_aware",
272 "Selected local target"
273 );
274
275 return Ok(TargetSelection {
276 address: selected.target.full_address(),
277 weight: selected.target.weight,
278 metadata: {
279 let mut m = HashMap::new();
280 m.insert("zone".to_string(), selected.zone.clone());
281 m.insert("locality".to_string(), "local".to_string());
282 m
283 },
284 });
285 }
286
287 match self.config.fallback_strategy {
289 LocalityFallback::FailLocal => {
290 warn!(
291 local_zone = %self.config.local_zone,
292 local_healthy = local_healthy.len(),
293 min_required = self.config.min_local_healthy,
294 algorithm = "locality_aware",
295 "No healthy local targets and fallback disabled"
296 );
297 return Err(SentinelError::NoHealthyUpstream);
298 }
299 LocalityFallback::RoundRobin | LocalityFallback::Random => {
300 }
302 }
303
304 let fallback_targets = self.healthy_fallback().await;
306
307 let all_targets: Vec<&ZonedTarget> = if !local_healthy.is_empty() {
309 local_healthy
311 .into_iter()
312 .chain(fallback_targets.into_iter())
313 .collect()
314 } else {
315 fallback_targets
316 };
317
318 if all_targets.is_empty() {
319 warn!(
320 total_targets = self.targets.len(),
321 algorithm = "locality_aware",
322 "No healthy upstream targets available"
323 );
324 return Err(SentinelError::NoHealthyUpstream);
325 }
326
327 let selected = match self.config.fallback_strategy {
329 LocalityFallback::RoundRobin => {
330 self.select_round_robin(&all_targets, &self.fallback_counter)
331 }
332 LocalityFallback::Random => self.select_random(&all_targets),
333 LocalityFallback::FailLocal => unreachable!(),
334 }
335 .ok_or(SentinelError::NoHealthyUpstream)?;
336
337 let is_local = selected.zone == self.config.local_zone;
338 debug!(
339 selected_target = %selected.target.full_address(),
340 zone = %selected.zone,
341 is_local = is_local,
342 fallback_used = !is_local,
343 algorithm = "locality_aware",
344 "Selected target (fallback path)"
345 );
346
347 Ok(TargetSelection {
348 address: selected.target.full_address(),
349 weight: selected.target.weight,
350 metadata: {
351 let mut m = HashMap::new();
352 m.insert("zone".to_string(), selected.zone.clone());
353 m.insert(
354 "locality".to_string(),
355 if is_local { "local" } else { "remote" }.to_string(),
356 );
357 m
358 },
359 })
360 }
361
362 async fn report_health(&self, address: &str, healthy: bool) {
363 trace!(
364 target = %address,
365 healthy = healthy,
366 algorithm = "locality_aware",
367 "Updating target health status"
368 );
369 self.health_status
370 .write()
371 .await
372 .insert(address.to_string(), healthy);
373 }
374
375 async fn healthy_targets(&self) -> Vec<String> {
376 self.health_status
377 .read()
378 .await
379 .iter()
380 .filter_map(|(addr, &healthy)| if healthy { Some(addr.clone()) } else { None })
381 .collect()
382 }
383}
384
#[cfg(test)]
mod tests {
    use super::*;

    /// Five weight-100 targets on port 8080 using the `zone=<zone>,<host>`
    /// syntax: two in us-west-1, two in us-east-1, one in eu-west-1.
    fn make_zoned_targets() -> Vec<UpstreamTarget> {
        const RAW_ADDRESSES: [&str; 5] = [
            "zone=us-west-1,10.0.0.1",
            "zone=us-west-1,10.0.0.2",
            "zone=us-east-1,10.1.0.1",
            "zone=us-east-1,10.1.0.2",
            "zone=eu-west-1,10.2.0.1",
        ];
        RAW_ADDRESSES
            .iter()
            .map(|&raw| UpstreamTarget::new(raw, 8080, 100))
            .collect()
    }

    /// Builds a balancer over `make_zoned_targets()` whose local zone is
    /// us-west-1, letting each test adjust the remaining config fields.
    fn us_west_balancer(
        customize: impl FnOnce(&mut LocalityAwareConfig),
    ) -> LocalityAwareBalancer {
        let mut config = LocalityAwareConfig {
            local_zone: "us-west-1".to_string(),
            ..Default::default()
        };
        customize(&mut config);
        LocalityAwareBalancer::new(make_zoned_targets(), config)
    }

    /// Marks both us-west-1 targets unhealthy.
    async fn mark_local_zone_unhealthy(balancer: &LocalityAwareBalancer) {
        for addr in ["10.0.0.1:8080", "10.0.0.2:8080"] {
            balancer.report_health(addr, false).await;
        }
    }

    #[test]
    fn test_zone_parsing() {
        // (raw address, expected zone); every case's host is 10.0.0.1.
        let cases = [
            ("zone=us-west-1,10.0.0.1", "us-west-1"),
            ("us-east-1/10.0.0.1", "us-east-1"),
            ("10.0.0.1", "unknown"),
        ];
        for (raw, expected_zone) in cases {
            let target = UpstreamTarget::new(raw, 8080, 100);
            let (zone, parsed) = LocalityAwareBalancer::parse_zone_from_target(&target);
            assert_eq!(zone, expected_zone);
            assert_eq!(parsed.address, "10.0.0.1");
        }
    }

    #[tokio::test]
    async fn test_prefers_local_zone() {
        let balancer = us_west_balancer(|_| {});

        for _ in 0..10 {
            let selection = balancer.select(None).await.unwrap();
            assert!(
                selection.address.starts_with("10.0.0."),
                "Expected local target, got {}",
                selection.address
            );
            assert_eq!(selection.metadata.get("locality").unwrap(), "local");
        }
    }

    #[tokio::test]
    async fn test_fallback_when_local_unhealthy() {
        let balancer = us_west_balancer(|config| config.min_local_healthy = 1);
        mark_local_zone_unhealthy(&balancer).await;

        let selection = balancer.select(None).await.unwrap();
        assert!(
            !selection.address.starts_with("10.0.0."),
            "Expected fallback target, got {}",
            selection.address
        );
        assert_eq!(selection.metadata.get("locality").unwrap(), "remote");
    }

    #[tokio::test]
    async fn test_zone_priority() {
        let balancer = us_west_balancer(|config| {
            config.min_local_healthy = 1;
            config.zone_priority = vec!["us-east-1".to_string(), "eu-west-1".to_string()];
        });
        mark_local_zone_unhealthy(&balancer).await;

        let selection = balancer.select(None).await.unwrap();
        assert!(
            selection.address.starts_with("10.1.0."),
            "Expected us-east-1 target, got {}",
            selection.address
        );
    }

    #[tokio::test]
    async fn test_fail_local_strategy() {
        let balancer =
            us_west_balancer(|config| config.fallback_strategy = LocalityFallback::FailLocal);
        mark_local_zone_unhealthy(&balancer).await;

        let result = balancer.select(None).await;
        assert!(result.is_err());
    }
}