use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicU32, Ordering};
use crate::config::ValidatedUpstream;
/// Shared collection of upstream backends and their health state.
///
/// `Clone` only clones the `Arc` around the backend list, so every clone of a pool
/// refers to the same `UpstreamState` handles and therefore sees the same health flags.
#[derive(Debug, Clone)]
pub struct UpstreamPool {
/// One state handle per configured upstream, in the order given to `from_validated`.
backends: Arc<Vec<UpstreamState>>,
}
/// Health-tracking handle for a single upstream backend.
///
/// Cloning a handle clones the inner `Arc`, so all clones share one set of counters and
/// one health flag; updates made through any clone are visible through the others.
#[derive(Debug, Clone)]
pub struct UpstreamState {
/// Shared per-backend data; mutable parts are atomics, so `&self` methods can update it.
state: Arc<InnerState>,
}
/// Per-backend data shared by every clone of an `UpstreamState`.
#[derive(Debug)]
struct InnerState {
/// Backend address, cloned from the validated config entry.
uri: hyper::Uri,
/// Configured backend weight (NOTE(review): presumably consumed by a weighted balancer elsewhere — confirm).
weight: u32,
/// Failures recorded since the last success or `mark_healthy`; `mark_unhealthy` leaves it untouched.
consecutive_failures: AtomicU32,
/// Successes recorded while unhealthy; cleared on any failure, on recovery, and by `mark_*`.
consecutive_successes: AtomicU32,
/// Current health flag; starts `true` and is what `UpstreamPool::healthy` filters on.
healthy: AtomicBool,
}
impl UpstreamPool {
    /// Builds a pool with one fresh (healthy) [`UpstreamState`] per validated upstream,
    /// preserving the input order.
    pub fn from_validated(upstreams: &[ValidatedUpstream]) -> Self {
        let mut states = Vec::with_capacity(upstreams.len());
        for upstream in upstreams {
            states.push(UpstreamState::new(upstream));
        }
        Self {
            backends: Arc::new(states),
        }
    }

    /// Every backend in the pool, healthy or not, in configuration order.
    pub fn all(&self) -> &[UpstreamState] {
        self.backends.as_slice()
    }

    /// Snapshot of the backends whose health flag is currently set, in configuration order.
    ///
    /// Each flag is read independently, so the result may already be stale if other
    /// threads are recording outcomes concurrently.
    pub fn healthy(&self) -> Vec<&UpstreamState> {
        self.all()
            .iter()
            .filter(|backend| backend.is_healthy())
            .collect()
    }

    /// Total number of backends, regardless of health.
    pub fn len(&self) -> usize {
        self.all().len()
    }

    /// `true` when the pool was built from an empty upstream list.
    pub fn is_empty(&self) -> bool {
        self.all().is_empty()
    }
}
impl UpstreamState {
    /// Creates the runtime state for `backend`: healthy, with both streak counters at zero.
    pub fn new(backend: &ValidatedUpstream) -> Self {
        Self {
            state: Arc::new(InnerState {
                uri: backend.uri.clone(),
                weight: backend.weight,
                consecutive_failures: AtomicU32::new(0),
                consecutive_successes: AtomicU32::new(0),
                healthy: AtomicBool::new(true),
            }),
        }
    }

    /// The backend's URI.
    pub fn uri(&self) -> &hyper::Uri {
        &self.state.uri
    }

    /// The backend's configured weight.
    pub fn weight(&self) -> u32 {
        self.state.weight
    }

    /// Current value of the health flag.
    pub fn is_healthy(&self) -> bool {
        self.state.healthy.load(Ordering::Acquire)
    }

    /// Records a successful request/probe.
    ///
    /// Always clears the failure streak. While the backend is healthy the success streak
    /// is cleared too and nothing else happens. While unhealthy, the success streak is
    /// incremented; once it reaches `healthy_threshold` the streak is cleared and the
    /// backend is marked healthy.
    ///
    /// Returns `true` only for the single call that actually flips the backend from
    /// unhealthy to healthy, even when several threads cross the threshold at once.
    pub fn record_success(&self, healthy_threshold: u32) -> bool {
        self.state.consecutive_failures.store(0, Ordering::Release);
        if self.is_healthy() {
            self.state.consecutive_successes.store(0, Ordering::Release);
            return false;
        }
        let new_count = Self::saturating_increment(&self.state.consecutive_successes);
        if new_count < healthy_threshold {
            return false;
        }
        self.state.consecutive_successes.store(0, Ordering::Release);
        // `swap` instead of `store`: concurrent callers may all pass the threshold check
        // above, but only the one that observes the flag still `false` reports the edge.
        // This mirrors how `record_failure` reports the opposite transition.
        !self.state.healthy.swap(true, Ordering::AcqRel)
    }

    /// Records a failed request/probe.
    ///
    /// Clears the success streak and increments the failure streak. Once the streak
    /// reaches `threshold`, the backend is marked unhealthy.
    ///
    /// Returns `true` only for the call that actually flips the backend from healthy to
    /// unhealthy; further failures while already unhealthy return `false`.
    pub fn record_failure(&self, threshold: u32) -> bool {
        self.state.consecutive_successes.store(0, Ordering::Release);
        let new_count = Self::saturating_increment(&self.state.consecutive_failures);
        // Short-circuit: below the threshold the flag is not touched at all.
        new_count >= threshold && self.state.healthy.swap(false, Ordering::AcqRel)
    }

    /// Forces the backend healthy and clears both streak counters.
    pub fn mark_healthy(&self) {
        self.state.consecutive_failures.store(0, Ordering::Release);
        self.state.consecutive_successes.store(0, Ordering::Release);
        self.state.healthy.store(true, Ordering::Release);
    }

    /// Forces the backend unhealthy and clears the success streak.
    ///
    /// The failure streak is deliberately left as-is.
    pub fn mark_unhealthy(&self) {
        self.state.consecutive_successes.store(0, Ordering::Release);
        self.state.healthy.store(false, Ordering::Release);
    }

    /// Current length of the consecutive-failure streak.
    pub fn failure_count(&self) -> u32 {
        self.state.consecutive_failures.load(Ordering::Acquire)
    }

    /// Atomically increments `counter`, saturating at `u32::MAX` instead of wrapping to 0,
    /// and returns the new value.
    ///
    /// A plain `fetch_add` would wrap the stored value even if the returned count were
    /// saturated afterwards.
    fn saturating_increment(counter: &AtomicU32) -> u32 {
        // The closure always returns `Some`, so this never yields `Err`; both arms carry
        // the previous value either way.
        let prev = match counter.fetch_update(Ordering::AcqRel, Ordering::Acquire, |current| {
            Some(current.saturating_add(1))
        }) {
            Ok(prev) | Err(prev) => prev,
        };
        prev.saturating_add(1)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    const LOCAL: &str = "http://localhost:3000";

    /// Builds a validated upstream entry for tests.
    fn test_upstream(addr: &str, weight: u32) -> ValidatedUpstream {
        ValidatedUpstream {
            uri: addr.parse().unwrap(),
            weight,
        }
    }

    /// A fresh weight-1 state for the local test address.
    fn local_state() -> UpstreamState {
        UpstreamState::new(&test_upstream(LOCAL, 1))
    }

    /// Parses a URI literal that is known to be valid.
    fn parse_uri(text: &str) -> hyper::Uri {
        text.parse::<hyper::Uri>().unwrap()
    }

    #[test]
    fn new_upstream_starts_healthy() {
        let upstream = local_state();
        assert!(upstream.is_healthy());
        assert_eq!(upstream.failure_count(), 0);
    }

    #[test]
    fn record_success_resets_failures() {
        let upstream = local_state();
        for _ in 0..2 {
            upstream.record_failure(5);
        }
        assert_eq!(upstream.failure_count(), 2);

        upstream.record_success(1);
        assert_eq!(upstream.failure_count(), 0);
        assert!(upstream.is_healthy());
    }

    #[test]
    fn record_success_requires_threshold_for_recovery() {
        let upstream = local_state();
        upstream.mark_unhealthy();
        for _ in 0..2 {
            assert!(!upstream.record_success(3));
            assert!(!upstream.is_healthy());
        }
        assert!(upstream.record_success(3));
        assert!(upstream.is_healthy());
    }

    #[test]
    fn failure_resets_consecutive_successes() {
        let upstream = local_state();
        upstream.mark_unhealthy();
        upstream.record_success(3);
        upstream.record_success(3);

        upstream.record_failure(10);
        assert!(!upstream.is_healthy());

        // The streak restarted, so a single success is not enough to recover.
        upstream.record_success(3);
        assert!(!upstream.is_healthy());
    }

    #[test]
    fn record_failure_marks_unhealthy_at_threshold() {
        let upstream = local_state();
        let transitions: Vec<bool> = (0..3).map(|_| upstream.record_failure(3)).collect();
        assert_eq!(transitions, [false, false, true]);
        assert!(!upstream.is_healthy());
    }

    #[test]
    fn record_failure_beyond_threshold_does_not_retrigger() {
        let upstream = local_state();
        upstream.record_failure(2);
        assert!(upstream.record_failure(2));
        assert!(!upstream.record_failure(2));
    }

    #[test]
    fn pool_healthy_filters_unhealthy_backends() {
        let addrs = ["http://b1:3000", "http://b2:3000", "http://b3:3000"];
        let backends: Vec<ValidatedUpstream> =
            addrs.iter().map(|addr| test_upstream(addr, 1)).collect();
        let pool = UpstreamPool::from_validated(&backends);

        pool.all()[1].mark_unhealthy();

        let healthy = pool.healthy();
        assert_eq!(healthy.len(), 2);
        assert_eq!(healthy[0].uri(), &parse_uri("http://b1:3000"));
        assert_eq!(healthy[1].uri(), &parse_uri("http://b3:3000"));
    }
}