1use dashmap::DashMap;
20use std::net::SocketAddr;
21use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
22use std::sync::Arc;
23use std::time::Duration;
24use tracing::{debug, warn};
25
/// Number of consecutive failed health probes after which the background
/// health checker marks a backend as unhealthy.
const UNHEALTHY_THRESHOLD: u64 = 3;
28
/// Strategy a [`BackendGroup`] uses to pick the next backend.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LbStrategy {
    /// Rotate through the backends in order, skipping unhealthy ones.
    RoundRobin,
    /// Pick the healthy backend with the fewest active connections.
    LeastConnections,
}
41
/// Health state of a single [`Backend`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum HealthStatus {
    /// The backend is eligible for selection.
    Healthy,
    /// The backend is skipped by selection until it is marked healthy again.
    Unhealthy,
}
52
/// A single upstream endpoint together with its live load and health state.
///
/// All mutable state uses interior mutability (atomics and a lock), so a
/// backend can be shared behind an [`Arc`] and updated through `&self`.
pub struct Backend {
    /// Socket address of the backend.
    pub addr: SocketAddr,
    /// Number of [`ConnectionGuard`]s currently alive for this backend.
    active_connections: AtomicU64,
    /// Current health status; a new backend starts out healthy.
    health: std::sync::RwLock<HealthStatus>,
    /// Failures recorded since the counter was last reset.
    consecutive_failures: AtomicU64,
}
68
69impl std::fmt::Debug for Backend {
70 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
71 f.debug_struct("Backend")
72 .field("addr", &self.addr)
73 .field(
74 "active_connections",
75 &self.active_connections.load(Ordering::Relaxed),
76 )
77 .field("health", &*self.health.read().unwrap())
78 .field(
79 "consecutive_failures",
80 &self.consecutive_failures.load(Ordering::Relaxed),
81 )
82 .finish()
83 }
84}
85
86impl Backend {
87 #[must_use]
89 pub fn new(addr: SocketAddr) -> Self {
90 Self {
91 addr,
92 active_connections: AtomicU64::new(0),
93 health: std::sync::RwLock::new(HealthStatus::Healthy),
94 consecutive_failures: AtomicU64::new(0),
95 }
96 }
97
98 pub fn track_connection(self: &Arc<Self>) -> ConnectionGuard {
101 self.active_connections.fetch_add(1, Ordering::Relaxed);
102 ConnectionGuard {
103 backend: Arc::clone(self),
104 }
105 }
106
107 pub fn active_connections(&self) -> u64 {
109 self.active_connections.load(Ordering::Relaxed)
110 }
111
112 pub fn is_healthy(&self) -> bool {
118 *self.health.read().unwrap() == HealthStatus::Healthy
119 }
120
121 pub fn set_healthy(&self) {
127 *self.health.write().unwrap() = HealthStatus::Healthy;
128 }
129
130 pub fn set_unhealthy(&self) {
136 *self.health.write().unwrap() = HealthStatus::Unhealthy;
137 }
138
139 pub fn record_failure(&self) {
141 self.consecutive_failures.fetch_add(1, Ordering::Relaxed);
142 }
143
144 pub fn reset_failures(&self) {
146 self.consecutive_failures.store(0, Ordering::Relaxed);
147 }
148
149 pub fn consecutive_failures(&self) -> u64 {
151 self.consecutive_failures.load(Ordering::Relaxed)
152 }
153}
154
155pub struct ConnectionGuard {
161 backend: Arc<Backend>,
162}
163
164impl Drop for ConnectionGuard {
165 fn drop(&mut self) {
166 self.backend
167 .active_connections
168 .fetch_sub(1, Ordering::Relaxed);
169 }
170}
171
/// A set of backends that share one load-balancing strategy.
pub struct BackendGroup {
    /// Backends in this group, in selection order.
    pub backends: Vec<Arc<Backend>>,
    /// Strategy used by [`BackendGroup::select`].
    pub strategy: LbStrategy,
    /// Counter that determines the starting index for round-robin selection;
    /// it is incremented on every round-robin pick.
    rr_counter: AtomicUsize,
}
185
186impl BackendGroup {
187 #[must_use]
189 pub fn new(strategy: LbStrategy) -> Self {
190 Self {
191 backends: Vec::new(),
192 strategy,
193 rr_counter: AtomicUsize::new(0),
194 }
195 }
196
197 pub fn select(&self) -> Option<Arc<Backend>> {
201 if self.backends.is_empty() {
202 return None;
203 }
204
205 match self.strategy {
206 LbStrategy::RoundRobin => self.select_round_robin(),
207 LbStrategy::LeastConnections => self.select_least_connections(),
208 }
209 }
210
211 fn select_round_robin(&self) -> Option<Arc<Backend>> {
214 let len = self.backends.len();
215 let start = self.rr_counter.fetch_add(1, Ordering::Relaxed) % len;
216
217 for i in 0..len {
218 let idx = (start + i) % len;
219 let backend = &self.backends[idx];
220 if backend.is_healthy() {
221 return Some(Arc::clone(backend));
222 }
223 }
224
225 None
226 }
227
228 fn select_least_connections(&self) -> Option<Arc<Backend>> {
231 self.backends
232 .iter()
233 .filter(|b| b.is_healthy())
234 .min_by_key(|b| b.active_connections())
235 .cloned()
236 }
237
238 pub fn update_backends(&mut self, addrs: Vec<SocketAddr>) {
245 let mut new_backends = Vec::with_capacity(addrs.len());
246
247 for addr in addrs {
248 if let Some(existing) = self.backends.iter().find(|b| b.addr == addr) {
249 new_backends.push(Arc::clone(existing));
250 } else {
251 new_backends.push(Arc::new(Backend::new(addr)));
252 }
253 }
254
255 self.backends = new_backends;
256 }
257
258 pub fn add_backend(&mut self, addr: SocketAddr) {
260 if !self.backends.iter().any(|b| b.addr == addr) {
261 self.backends.push(Arc::new(Backend::new(addr)));
262 }
263 }
264
265 pub fn remove_backend(&mut self, addr: &SocketAddr) {
267 self.backends.retain(|b| b.addr != *addr);
268 }
269}
270
/// Point-in-time copy of one backend's state, as produced by
/// [`LoadBalancer::group_snapshot`].
#[derive(Debug, Clone)]
pub struct BackendSnapshot {
    /// Socket address of the backend.
    pub addr: SocketAddr,
    /// Whether the backend was marked healthy when the snapshot was taken.
    pub healthy: bool,
    /// Active-connection count when the snapshot was taken.
    pub active_connections: u64,
    /// Consecutive-failure count when the snapshot was taken.
    pub consecutive_failures: u64,
}
287
/// Point-in-time copy of a backend group's strategy and backend states.
#[derive(Debug, Clone)]
pub struct BackendGroupSnapshot {
    /// Strategy the group was using.
    pub strategy: LbStrategy,
    /// One entry per backend, in the group's order.
    pub backends: Vec<BackendSnapshot>,
}
296
/// Load balancer that maps service names to groups of backends.
///
/// All methods take `&self`; the groups live in a concurrent map.
pub struct LoadBalancer {
    /// Backend groups keyed by service name.
    groups: DashMap<String, BackendGroup>,
}
301
impl Default for LoadBalancer {
    /// Equivalent to [`LoadBalancer::new`].
    fn default() -> Self {
        Self::new()
    }
}
307
308impl LoadBalancer {
309 #[must_use]
311 pub fn new() -> Self {
312 Self {
313 groups: DashMap::new(),
314 }
315 }
316
317 pub fn register(&self, service: &str, addrs: Vec<SocketAddr>, strategy: LbStrategy) {
319 let mut group = BackendGroup::new(strategy);
320 group.backends = addrs
321 .into_iter()
322 .map(|a| Arc::new(Backend::new(a)))
323 .collect();
324 self.groups.insert(service.to_string(), group);
325 }
326
327 #[must_use]
330 pub fn select(&self, service: &str) -> Option<Arc<Backend>> {
331 self.groups.get(service).and_then(|g| g.select())
332 }
333
334 pub fn update_backends(&self, service: &str, addrs: Vec<SocketAddr>) {
337 if let Some(mut group) = self.groups.get_mut(service) {
338 group.update_backends(addrs);
339 }
340 }
341
342 pub fn unregister(&self, service: &str) {
344 self.groups.remove(service);
345 }
346
347 pub fn add_backend(&self, service: &str, addr: SocketAddr) {
349 if let Some(mut group) = self.groups.get_mut(service) {
350 group.add_backend(addr);
351 debug!(service = service, backend = %addr, total = group.backends.len(), "Added backend to LB group");
352 } else {
353 warn!(service = service, backend = %addr, "Cannot add backend: LB group not registered");
354 }
355 }
356
357 pub fn remove_backend(&self, service: &str, addr: &SocketAddr) {
359 if let Some(mut group) = self.groups.get_mut(service) {
360 group.remove_backend(addr);
361 }
362 }
363
364 #[must_use]
367 pub fn backend_count(&self, service: &str) -> usize {
368 self.groups.get(service).map_or(0, |g| g.backends.len())
369 }
370
371 #[must_use]
374 pub fn healthy_count(&self, service: &str) -> usize {
375 self.groups
376 .get(service)
377 .map_or(0, |g| g.backends.iter().filter(|b| b.is_healthy()).count())
378 }
379
380 pub fn mark_health(&self, service: &str, addr: &SocketAddr, healthy: bool) {
385 if let Some(group) = self.groups.get(service) {
386 if let Some(backend) = group.backends.iter().find(|b| b.addr == *addr) {
387 if healthy {
388 backend.set_healthy();
389 backend.reset_failures();
390 } else {
391 backend.set_unhealthy();
392 backend.record_failure();
393 }
394 }
395 }
396 }
397
398 #[must_use]
400 pub fn list_service_names(&self) -> Vec<String> {
401 self.groups.iter().map(|e| e.key().clone()).collect()
402 }
403
404 #[must_use]
409 pub fn group_snapshot(&self, service: &str) -> Option<BackendGroupSnapshot> {
410 self.groups.get(service).map(|g| BackendGroupSnapshot {
411 strategy: g.strategy,
412 backends: g
413 .backends
414 .iter()
415 .map(|b| BackendSnapshot {
416 addr: b.addr,
417 healthy: b.is_healthy(),
418 active_connections: b.active_connections(),
419 consecutive_failures: b.consecutive_failures(),
420 })
421 .collect(),
422 })
423 }
424
425 #[must_use]
440 pub fn spawn_health_checker(
441 self: &Arc<Self>,
442 interval: Duration,
443 timeout: Duration,
444 ) -> tokio::task::JoinHandle<()> {
445 let lb = Arc::clone(self);
446
447 tokio::spawn(async move {
448 let semaphore = Arc::new(tokio::sync::Semaphore::new(64));
449
450 loop {
451 let backends: Vec<Arc<Backend>> = lb
453 .groups
454 .iter()
455 .flat_map(|entry| entry.value().backends.clone())
456 .collect();
457
458 debug!(
459 backend_count = backends.len(),
460 "Starting health check sweep"
461 );
462
463 let mut handles = Vec::with_capacity(backends.len());
464
465 for backend in backends {
466 let sem = Arc::clone(&semaphore);
467 let probe_timeout = timeout;
468
469 handles.push(tokio::spawn(async move {
470 let _permit = sem.acquire().await.expect("semaphore closed");
471 let addr = backend.addr;
472
473 match tokio::time::timeout(
474 probe_timeout,
475 tokio::net::TcpStream::connect(addr),
476 )
477 .await
478 {
479 Ok(Ok(_stream)) => {
480 if !backend.is_healthy() {
481 debug!(%addr, "Backend recovered");
482 }
483 backend.set_healthy();
484 backend.reset_failures();
485 }
486 Ok(Err(e)) => {
487 backend.record_failure();
488 let failures = backend.consecutive_failures();
489 if failures >= UNHEALTHY_THRESHOLD {
490 if backend.is_healthy() {
491 warn!(
492 %addr,
493 error = %e,
494 failures,
495 "Backend marked unhealthy after consecutive failures"
496 );
497 }
498 backend.set_unhealthy();
499 } else {
500 debug!(
501 %addr,
502 error = %e,
503 failures,
504 "Health check failed ({failures}/{UNHEALTHY_THRESHOLD} before unhealthy)"
505 );
506 }
507 }
508 Err(_elapsed) => {
509 backend.record_failure();
510 let failures = backend.consecutive_failures();
511 if failures >= UNHEALTHY_THRESHOLD {
512 if backend.is_healthy() {
513 warn!(
514 %addr,
515 failures,
516 "Backend marked unhealthy after consecutive timeout failures"
517 );
518 }
519 backend.set_unhealthy();
520 } else {
521 debug!(
522 %addr,
523 failures,
524 "Health check timed out ({failures}/{UNHEALTHY_THRESHOLD} before unhealthy)"
525 );
526 }
527 }
528 }
529 }));
530 }
531
532 for handle in handles {
534 let _ = handle.await;
535 }
536
537 tokio::time::sleep(interval).await;
538 }
539 })
540 }
541}
542
/// Unit tests for backend state tracking, selection strategies and the
/// `LoadBalancer` registry operations. The health checker task is not
/// exercised here.
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a loopback socket address with the given port.
    fn addr(port: u16) -> SocketAddr {
        format!("127.0.0.1:{port}").parse().unwrap()
    }

    /// Round-robin hands out backends in list order and wraps around.
    #[test]
    fn test_round_robin_selection() {
        let mut group = BackendGroup::new(LbStrategy::RoundRobin);
        group.backends = vec![
            Arc::new(Backend::new(addr(8001))),
            Arc::new(Backend::new(addr(8002))),
            Arc::new(Backend::new(addr(8003))),
        ];

        let a = group.select().unwrap();
        let b = group.select().unwrap();
        let c = group.select().unwrap();
        let d = group.select().unwrap();

        assert_eq!(a.addr, addr(8001));
        assert_eq!(b.addr, addr(8002));
        assert_eq!(c.addr, addr(8003));
        // The fourth pick wraps back to the first backend.
        assert_eq!(d.addr, addr(8001));
    }

    /// Least-connections avoids backends that hold tracked connections.
    #[test]
    fn test_least_connections_selection() {
        let mut group = BackendGroup::new(LbStrategy::LeastConnections);
        let b1 = Arc::new(Backend::new(addr(8001)));
        let b2 = Arc::new(Backend::new(addr(8002)));
        let b3 = Arc::new(Backend::new(addr(8003)));

        // b1 now has one active connection; the guard must stay alive.
        let _guard = b1.track_connection();

        group.backends = vec![b1, Arc::clone(&b2), b3];

        let selected = group.select().unwrap();
        // b2 and b3 are both idle, so either is acceptable; b1 is not.
        assert_ne!(selected.addr, addr(8001));
        assert!(selected.addr == addr(8002) || selected.addr == addr(8003));

        // With b2 busy as well, b3 is the only backend left with zero connections.
        let _guard2 = b2.track_connection();
        let selected = group.select().unwrap();
        assert_eq!(selected.addr, addr(8003));
    }

    /// Round-robin never returns a backend that is marked unhealthy.
    #[test]
    fn test_unhealthy_backends_skipped() {
        let mut group = BackendGroup::new(LbStrategy::RoundRobin);
        let b1 = Arc::new(Backend::new(addr(8001)));
        let b2 = Arc::new(Backend::new(addr(8002)));
        let b3 = Arc::new(Backend::new(addr(8003)));

        b2.set_unhealthy();

        group.backends = vec![b1, b2, Arc::clone(&b3)];

        // Enough iterations to cycle through every starting position several times.
        for _ in 0..10 {
            let selected = group.select().unwrap();
            assert_ne!(selected.addr, addr(8002), "Unhealthy backend was selected");
        }
    }

    /// Each guard adds one to the connection count and removes it on drop.
    #[test]
    fn test_connection_guard_decrement() {
        let backend = Arc::new(Backend::new(addr(9000)));
        assert_eq!(backend.active_connections(), 0);

        let guard1 = backend.track_connection();
        assert_eq!(backend.active_connections(), 1);

        let guard2 = backend.track_connection();
        assert_eq!(backend.active_connections(), 2);

        drop(guard1);
        assert_eq!(backend.active_connections(), 1);

        drop(guard2);
        assert_eq!(backend.active_connections(), 0);
    }

    /// `update_backends` keeps existing backends (and their state), adds new
    /// ones and drops those that are no longer listed.
    #[test]
    fn test_update_backends_preserves_state() {
        let mut group = BackendGroup::new(LbStrategy::RoundRobin);
        let b1 = Arc::new(Backend::new(addr(8001)));
        let b2 = Arc::new(Backend::new(addr(8002)));

        // Give both backends some observable state before the update.
        let _guard = b1.track_connection();
        b2.set_unhealthy();

        group.backends = vec![Arc::clone(&b1), Arc::clone(&b2)];

        // Keep 8001, drop 8002, add 8003.
        group.update_backends(vec![addr(8001), addr(8003)]);

        assert_eq!(group.backends.len(), 2);

        // The surviving backend still reports its tracked connection.
        let preserved = group
            .backends
            .iter()
            .find(|b| b.addr == addr(8001))
            .unwrap();
        assert_eq!(preserved.active_connections(), 1);

        // The newly added backend starts from a clean state.
        let new_backend = group
            .backends
            .iter()
            .find(|b| b.addr == addr(8003))
            .unwrap();
        assert_eq!(new_backend.active_connections(), 0);
        assert!(new_backend.is_healthy());

        // The backend that was not listed is gone.
        assert!(group.backends.iter().all(|b| b.addr != addr(8002)));
    }

    /// Both strategies return `None` when every backend is unhealthy.
    #[test]
    fn test_all_unhealthy_returns_none() {
        let mut group = BackendGroup::new(LbStrategy::RoundRobin);
        let b1 = Arc::new(Backend::new(addr(8001)));
        let b2 = Arc::new(Backend::new(addr(8002)));

        b1.set_unhealthy();
        b2.set_unhealthy();

        group.backends = vec![b1, b2];

        assert!(group.select().is_none());

        // Same expectation for the other strategy.
        group.strategy = LbStrategy::LeastConnections;
        assert!(group.select().is_none());
    }

    /// A registered service yields one of its backends; an unknown one yields `None`.
    #[test]
    fn test_register_and_select() {
        let lb = LoadBalancer::new();
        lb.register("web", vec![addr(8080), addr(8081)], LbStrategy::RoundRobin);

        let backend = lb.select("web").unwrap();
        assert!(backend.addr == addr(8080) || backend.addr == addr(8081));

        assert!(lb.select("nonexistent").is_none());
    }

    /// Adding is idempotent per address and removal deletes the matching backend.
    #[test]
    fn test_add_remove_backend() {
        let lb = LoadBalancer::new();
        lb.register("api", vec![addr(9001)], LbStrategy::RoundRobin);

        lb.add_backend("api", addr(9002));

        // Map guards are scoped in blocks so they are released before the
        // next mutating call on the same service.
        {
            let group = lb.groups.get("api").unwrap();
            assert_eq!(group.backends.len(), 2);
        }

        // Adding the same address again must not create a duplicate.
        lb.add_backend("api", addr(9002));
        {
            let group = lb.groups.get("api").unwrap();
            assert_eq!(group.backends.len(), 2);
        }

        lb.remove_backend("api", &addr(9001));
        {
            let group = lb.groups.get("api").unwrap();
            assert_eq!(group.backends.len(), 1);
            assert_eq!(group.backends[0].addr, addr(9002));
        }
    }

    /// After `unregister` the service can no longer be selected.
    #[test]
    fn test_unregister() {
        let lb = LoadBalancer::new();
        lb.register("svc", vec![addr(5000)], LbStrategy::RoundRobin);
        assert!(lb.select("svc").is_some());

        lb.unregister("svc");
        assert!(lb.select("svc").is_none());
    }

    /// `LoadBalancer::update_backends` replaces the backend list of a service.
    #[test]
    fn test_update_backends_via_lb() {
        let lb = LoadBalancer::new();
        lb.register("svc", vec![addr(3000)], LbStrategy::RoundRobin);

        lb.update_backends("svc", vec![addr(3001), addr(3002)]);

        let group = lb.groups.get("svc").unwrap();
        assert_eq!(group.backends.len(), 2);
        assert!(group.backends.iter().any(|b| b.addr == addr(3001)));
        assert!(group.backends.iter().any(|b| b.addr == addr(3002)));
    }

    /// An empty group returns `None` for both strategies.
    #[test]
    fn test_empty_group_returns_none() {
        let group = BackendGroup::new(LbStrategy::RoundRobin);
        assert!(group.select().is_none());

        let group_lc = BackendGroup::new(LbStrategy::LeastConnections);
        assert!(group_lc.select().is_none());
    }

    /// The failure counter increments per recorded failure and resets to zero.
    #[test]
    fn test_failure_tracking() {
        let backend = Backend::new(addr(7000));
        assert_eq!(backend.consecutive_failures(), 0);

        backend.record_failure();
        backend.record_failure();
        assert_eq!(backend.consecutive_failures(), 2);

        backend.reset_failures();
        assert_eq!(backend.consecutive_failures(), 0);
    }

    /// A backend starts healthy and follows explicit health transitions.
    #[test]
    fn test_health_transitions() {
        let backend = Backend::new(addr(7001));
        assert!(backend.is_healthy());

        backend.set_unhealthy();
        assert!(!backend.is_healthy());

        backend.set_healthy();
        assert!(backend.is_healthy());
    }
}