use std::sync::Mutex;
use std::time::{Duration, Instant};
/// Smoothing factor of the RTT moving average: the share given to the newest
/// sample, with the remaining `1 - EWMA_ALPHA` kept from the previous average.
const EWMA_ALPHA: f64 = 0.20;
/// Fraction of the median replica weight used as the minimum weight, so that
/// no healthy replica ends up with a weight below this share of the median.
const WEIGHT_FLOOR_FRACTION: f64 = 0.10;
/// Default number of consecutive timeouts after which an endpoint is marked
/// unhealthy.
pub const DEFAULT_TIMEOUT_THRESHOLD: u32 = 3;
/// Default minimum time between two probes of the same unhealthy endpoint.
pub const DEFAULT_PROBE_INTERVAL: Duration = Duration::from_secs(10);
/// Describes a cluster as one primary URL plus zero or more replica URLs.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ClusterMembership {
    /// URL of the primary endpoint.
    pub primary: String,
    /// URLs of the replica endpoints, in the order supplied by the caller.
    pub replicas: Vec<String>,
}

impl ClusterMembership {
    /// Bundles a primary URL and its replica URLs into a membership value.
    pub fn new(primary: String, replicas: Vec<String>) -> Self {
        ClusterMembership { primary, replicas }
    }

    /// Total number of endpoints: every replica plus the primary.
    pub fn len(&self) -> usize {
        self.replicas.len() + 1
    }

    /// Returns owned copies of all URLs, primary first, then the replicas in
    /// their stored order.
    fn urls(&self) -> Vec<String> {
        std::iter::once(&self.primary)
            .chain(&self.replicas)
            .cloned()
            .collect()
    }
}
/// Result of one request against an endpoint, as fed to the router's
/// `observe_index` / `observe_url` methods.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Outcome {
/// The request completed; carries the measured round-trip time.
Rtt(Duration),
/// The request timed out.
Timeout,
}
/// Source of the current time, abstracted behind a trait so the router can be
/// given a different clock implementation (it stores a `Box<dyn Clock>`).
pub trait Clock: Send + Sync + 'static {
/// Returns the current instant according to this clock.
fn now(&self) -> Instant;
}
/// [`Clock`] implementation backed by the real monotonic clock.
#[derive(Debug, Default, Clone, Copy)]
pub struct SystemClock;
impl Clock for SystemClock {
/// Delegates to [`Instant::now`].
fn now(&self) -> Instant {
Instant::now()
}
}
/// Manually driven [`Clock`]: its time is frozen at construction and only
/// moves when [`FakeClock::advance`] is called.
#[derive(Debug)]
pub struct FakeClock {
    /// Current fake time, behind a mutex so it can be advanced through `&self`.
    inner: Mutex<Instant>,
}

impl FakeClock {
    /// Creates a clock whose initial reading is the real current instant.
    pub fn new() -> Self {
        let start = Instant::now();
        FakeClock {
            inner: Mutex::new(start),
        }
    }

    /// Moves the fake time forward by `d`.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned.
    pub fn advance(&self, d: Duration) {
        *self.inner.lock().unwrap() += d;
    }
}

impl Default for FakeClock {
    /// Equivalent to [`FakeClock::new`].
    fn default() -> Self {
        FakeClock::new()
    }
}

impl Clock for FakeClock {
    /// Returns the stored fake time; it does not change between calls unless
    /// `advance` was invoked in between.
    fn now(&self) -> Instant {
        let current = self.inner.lock().unwrap();
        *current
    }
}
/// Per-endpoint bookkeeping: smoothed latency, timeout streak, circuit state
/// and the time of the last probe.
#[derive(Debug, Clone)]
struct EndpointHealth {
    /// URL identifying the endpoint.
    url: String,
    /// Exponentially weighted moving average of observed RTTs, in seconds;
    /// `None` until the first RTT sample arrives.
    ewma_rtt_secs: Option<f64>,
    /// Number of RTT samples recorded (saturating).
    samples: u64,
    /// Length of the current run of timeouts without an RTT in between.
    consecutive_timeouts: u32,
    /// `false` once the timeout streak reached the threshold; only `admit`
    /// sets it back to `true`.
    healthy: bool,
    /// When the endpoint was last probed, if ever.
    last_probe: Option<Instant>,
}

impl EndpointHealth {
    /// Fresh state for `url`: healthy, no samples, no timeouts, never probed.
    fn new(url: String) -> Self {
        EndpointHealth {
            healthy: true,
            consecutive_timeouts: 0,
            samples: 0,
            ewma_rtt_secs: None,
            last_probe: None,
            url,
        }
    }

    /// Folds one successful round trip into the moving average and clears the
    /// timeout streak. The `healthy` flag is deliberately left untouched.
    fn record_rtt(&mut self, rtt: Duration) {
        // Clamp to one microsecond so a zero RTT never yields a zero average.
        let sample = rtt.as_secs_f64().max(1e-6);
        let smoothed = self
            .ewma_rtt_secs
            .map_or(sample, |prev| EWMA_ALPHA * sample + (1.0 - EWMA_ALPHA) * prev);
        self.ewma_rtt_secs = Some(smoothed);
        self.consecutive_timeouts = 0;
        self.samples = self.samples.saturating_add(1);
    }

    /// Extends the timeout streak by one and marks the endpoint unhealthy once
    /// the streak is at least `threshold`.
    fn record_timeout(&mut self, threshold: u32) {
        let streak = self.consecutive_timeouts.saturating_add(1);
        self.consecutive_timeouts = streak;
        if streak >= threshold {
            self.healthy = false;
        }
    }

    /// Re-admits the endpoint: healthy again with an empty timeout streak.
    fn admit(&mut self) {
        self.consecutive_timeouts = 0;
        self.healthy = true;
    }
}
/// Tunable parameters of the router's circuit breaker and probe scheduling.
#[derive(Debug, Clone)]
pub struct RouterConfig {
/// Number of consecutive timeouts after which an endpoint is marked unhealthy.
pub timeout_threshold: u32,
/// Minimum time that must pass after a probe before the same unhealthy
/// endpoint is reported as due for probing again.
pub probe_interval: Duration,
}
impl Default for RouterConfig {
/// Uses `DEFAULT_TIMEOUT_THRESHOLD` and `DEFAULT_PROBE_INTERVAL`.
fn default() -> Self {
Self {
timeout_threshold: DEFAULT_TIMEOUT_THRESHOLD,
probe_interval: DEFAULT_PROBE_INTERVAL,
}
}
}
/// Chooses which endpoint should serve a read, based on per-endpoint health
/// and latency observations.
pub struct HealthAwareRouter {
// Per-endpoint state. Built from `ClusterMembership::urls`, so index 0 is the
// primary and indices 1.. are the replicas.
endpoints: Mutex<Vec<EndpointHealth>>,
// Timeout threshold and probe interval.
config: RouterConfig,
// Time source used to timestamp probes and decide when the next one is due.
clock: Box<dyn Clock>,
// When true, `pick_read_index` always returns the primary (index 0).
force_primary: bool,
// Counter advanced by one on every weighted pick; drives the rotation.
rr_counter: Mutex<u64>,
}
impl std::fmt::Debug for HealthAwareRouter {
    /// Formats the router's endpoints, configuration and `force_primary` flag.
    ///
    /// Debug output is commonly produced from logging and panic paths, so this
    /// never blocks and never panics on the endpoint mutex: a poisoned lock
    /// still has its contents printed, and a lock that is currently held is
    /// rendered as `<locked>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut dbg = f.debug_struct("HealthAwareRouter");
        match self.endpoints.try_lock() {
            Ok(guard) => {
                dbg.field("endpoints", &*guard);
            }
            // Another holder panicked; the data is still readable for diagnostics.
            Err(std::sync::TryLockError::Poisoned(poisoned)) => {
                dbg.field("endpoints", &*poisoned.into_inner());
            }
            // Waiting here could deadlock if this thread already owns the lock.
            Err(std::sync::TryLockError::WouldBlock) => {
                dbg.field("endpoints", &format_args!("<locked>"));
            }
        }
        dbg.field("config", &self.config)
            .field("force_primary", &self.force_primary)
            .finish()
    }
}
impl HealthAwareRouter {
pub fn new(membership: ClusterMembership) -> Self {
Self::with_config(membership, RouterConfig::default(), Box::new(SystemClock))
}
pub fn with_force_primary(membership: ClusterMembership, force_primary: bool) -> Self {
let mut r = Self::new(membership);
r.force_primary = force_primary;
r
}
pub fn with_config(
membership: ClusterMembership,
config: RouterConfig,
clock: Box<dyn Clock>,
) -> Self {
let endpoints: Vec<EndpointHealth> = membership
.urls()
.into_iter()
.map(EndpointHealth::new)
.collect();
Self {
endpoints: Mutex::new(endpoints),
config,
clock,
force_primary: false,
rr_counter: Mutex::new(0),
}
}
pub fn len(&self) -> usize {
self.endpoints.lock().unwrap().len()
}
pub fn force_primary(&self) -> bool {
self.force_primary
}
pub fn pick_read_index(&self) -> usize {
let endpoints = self.endpoints.lock().unwrap();
if self.force_primary || endpoints.len() == 1 {
return 0;
}
let healthy_replicas: Vec<usize> = (1..endpoints.len())
.filter(|&i| endpoints[i].healthy)
.collect();
if healthy_replicas.is_empty() {
return 0;
}
let weights: Vec<f64> = healthy_replicas
.iter()
.map(|&i| weight_for(&endpoints[i]))
.collect();
let weights = apply_floor(&weights);
let mut counter = self.rr_counter.lock().unwrap();
let idx_in_healthy = weighted_pick(&weights, *counter);
*counter = counter.wrapping_add(1);
healthy_replicas[idx_in_healthy]
}
pub fn observe_index(&self, index: usize, outcome: Outcome) {
let mut endpoints = self.endpoints.lock().unwrap();
if let Some(ep) = endpoints.get_mut(index) {
match outcome {
Outcome::Rtt(rtt) => ep.record_rtt(rtt),
Outcome::Timeout => ep.record_timeout(self.config.timeout_threshold),
}
}
}
pub fn observe_url(&self, url: &str, outcome: Outcome) {
let mut endpoints = self.endpoints.lock().unwrap();
if let Some(ep) = endpoints.iter_mut().find(|ep| ep.url == url) {
match outcome {
Outcome::Rtt(rtt) => ep.record_rtt(rtt),
Outcome::Timeout => ep.record_timeout(self.config.timeout_threshold),
}
}
}
pub fn endpoints_due_for_probe(&self) -> Vec<ProbeTarget> {
let endpoints = self.endpoints.lock().unwrap();
let now = self.clock.now();
endpoints
.iter()
.enumerate()
.filter(|(_, ep)| !ep.healthy)
.filter(|(_, ep)| match ep.last_probe {
None => true,
Some(t) => now.duration_since(t) >= self.config.probe_interval,
})
.map(|(i, ep)| ProbeTarget {
index: i,
url: ep.url.clone(),
})
.collect()
}
pub fn record_probe_result(&self, index: usize, success: bool) {
let mut endpoints = self.endpoints.lock().unwrap();
if let Some(ep) = endpoints.get_mut(index) {
ep.last_probe = Some(self.clock.now());
if success {
ep.admit();
}
}
}
pub fn update_membership(&mut self, new_membership: ClusterMembership) {
let mut endpoints = self.endpoints.lock().unwrap();
let new_urls = new_membership.urls();
let mut next: Vec<EndpointHealth> = Vec::with_capacity(new_urls.len());
for url in new_urls {
if let Some(existing) = endpoints.iter().find(|ep| ep.url == url) {
next.push(existing.clone());
} else {
next.push(EndpointHealth::new(url));
}
}
*endpoints = next;
}
pub fn endpoint_url(&self, index: usize) -> Option<String> {
self.endpoints
.lock()
.unwrap()
.get(index)
.map(|ep| ep.url.clone())
}
#[cfg(test)]
fn snapshot(&self, index: usize) -> Option<EndpointHealth> {
self.endpoints.lock().unwrap().get(index).cloned()
}
}
/// An unhealthy endpoint that `HealthAwareRouter::endpoints_due_for_probe`
/// reported as due for a probe.
#[derive(Debug, Clone)]
pub struct ProbeTarget {
/// Position of the endpoint in the router's endpoint list; this is the index
/// that `record_probe_result` takes.
pub index: usize,
/// URL of the endpoint.
pub url: String,
}
/// Routing weight of an endpoint: the inverse of its smoothed RTT in seconds.
///
/// An endpoint without any RTT sample gets the neutral weight `1.0`. The RTT
/// is clamped to at least one microsecond so the division stays bounded.
fn weight_for(ep: &EndpointHealth) -> f64 {
    ep.ewma_rtt_secs
        .map_or(1.0, |rtt_secs| 1.0 / rtt_secs.max(1e-6))
}
/// Raises every weight to at least `WEIGHT_FLOOR_FRACTION` times the median
/// weight, preserving the input order.
///
/// The median is the element at position `len / 2` of the sorted weights (the
/// upper middle for even lengths). An empty input yields an empty vector.
fn apply_floor(weights: &[f64]) -> Vec<f64> {
    let mut ordered: Vec<f64> = weights.to_vec();
    // Incomparable pairs (NaN) are treated as equal so sorting cannot panic.
    ordered.sort_by(|lhs, rhs| lhs.partial_cmp(rhs).unwrap_or(std::cmp::Ordering::Equal));
    match ordered.get(ordered.len() / 2) {
        None => Vec::new(),
        Some(&median) => {
            let floor = WEIGHT_FLOOR_FRACTION * median;
            weights.iter().map(|&w| w.max(floor)).collect()
        }
    }
}
/// Maps a rotation `counter` onto an index into `weights`, proportionally to
/// the weights.
///
/// The counter is reduced modulo the weight total and the first index whose
/// cumulative weight exceeds that remainder is returned. If the total is not a
/// positive finite number the function falls back to `counter % weights.len()`.
///
/// NOTE(review): the caller advances the counter by one per pick, so an entry
/// with a weight well above 1 receives a run of consecutive picks before the
/// rotation moves on; the proportions hold over a full cycle, not pick by pick.
///
/// # Panics
///
/// Panics on an empty `weights` slice (remainder by zero in the fallback).
fn weighted_pick(weights: &[f64], counter: u64) -> usize {
    let sum = weights.iter().sum::<f64>();
    if !(sum > 0.0 && sum.is_finite()) {
        return (counter as usize) % weights.len();
    }
    let target = (counter as f64) % sum;
    let mut running = 0.0;
    weights
        .iter()
        .position(|&w| {
            running += w;
            target < running
        })
        // Only reachable through floating-point rounding at the very end.
        .unwrap_or(weights.len() - 1)
}
// Unit tests for routing, the circuit breaker, probing and membership updates.
#[cfg(test)]
mod tests {
use super::*;
use std::collections::HashMap;
// Builds a `ClusterMembership` from string slices.
fn membership(primary: &str, replicas: &[&str]) -> ClusterMembership {
ClusterMembership::new(
primary.to_string(),
replicas.iter().map(|s| s.to_string()).collect(),
)
}
// With no replicas the primary (index 0) is the only possible answer.
#[test]
fn single_endpoint_always_returns_primary() {
let router = HealthAwareRouter::new(membership("primary", &[]));
for _ in 0..50 {
assert_eq!(router.pick_read_index(), 0);
}
}
// `force_primary` wins even when healthy replicas exist.
#[test]
fn force_primary_short_circuits() {
let router = HealthAwareRouter::with_force_primary(membership("p", &["r1", "r2"]), true);
for _ in 0..50 {
assert_eq!(router.pick_read_index(), 0);
}
}
// Without RTT samples every replica has weight 1.0, so picks should be split
// roughly evenly across replicas and never land on the primary.
#[test]
fn cold_start_distributes_across_replicas() {
let router = HealthAwareRouter::new(membership("p", &["r1", "r2", "r3"]));
let mut hits: HashMap<usize, u32> = HashMap::new();
for _ in 0..3000 {
*hits.entry(router.pick_read_index()).or_insert(0) += 1;
}
assert_eq!(hits.get(&0).copied().unwrap_or(0), 0);
for idx in 1..=3 {
let n = hits.get(&idx).copied().unwrap_or(0);
assert!(n > 800 && n < 1200, "replica {idx} got {n} hits");
}
}
// After the threshold number of consecutive timeouts a replica is excluded
// from selection and its streak equals the threshold.
#[test]
fn circuit_breaker_opens_after_k_consecutive_timeouts() {
let router = HealthAwareRouter::new(membership("p", &["r1", "r2"]));
for _ in 0..DEFAULT_TIMEOUT_THRESHOLD {
router.observe_index(1, Outcome::Timeout);
}
for _ in 0..200 {
assert_eq!(router.pick_read_index(), 2);
}
let snap = router.snapshot(1).unwrap();
assert!(!snap.healthy);
assert_eq!(snap.consecutive_timeouts, DEFAULT_TIMEOUT_THRESHOLD);
}
// An RTT sample before the threshold is reached clears the timeout streak.
#[test]
fn rtt_observation_resets_consecutive_timeouts() {
let router = HealthAwareRouter::new(membership("p", &["r1"]));
router.observe_index(1, Outcome::Timeout);
router.observe_index(1, Outcome::Timeout);
router.observe_index(1, Outcome::Rtt(Duration::from_millis(5)));
let snap = router.snapshot(1).unwrap();
assert_eq!(snap.consecutive_timeouts, 0);
assert!(snap.healthy);
}
// When every replica is unhealthy, reads fall back to the primary.
#[test]
fn all_unhealthy_replicas_fall_back_to_primary() {
let router = HealthAwareRouter::new(membership("p", &["r1", "r2"]));
for _ in 0..DEFAULT_TIMEOUT_THRESHOLD {
router.observe_index(1, Outcome::Timeout);
router.observe_index(2, Outcome::Timeout);
}
for _ in 0..50 {
assert_eq!(router.pick_read_index(), 0);
}
}
// An unhealthy replica is reported as due for a probe, and a successful probe
// marks it healthy again.
#[test]
fn probe_readmits_endpoint() {
let clock = std::sync::Arc::new(FakeClock::new());
let router = HealthAwareRouter::with_config(
membership("p", &["r1", "r2"]),
RouterConfig::default(),
Box::new(FakeClockHandle(clock.clone())),
);
for _ in 0..DEFAULT_TIMEOUT_THRESHOLD {
router.observe_index(1, Outcome::Timeout);
}
assert_eq!(router.pick_read_index(), 2);
let due = router.endpoints_due_for_probe();
assert_eq!(due.len(), 1);
assert_eq!(due[0].index, 1);
router.record_probe_result(1, true);
let snap = router.snapshot(1).unwrap();
assert!(snap.healthy);
}
// After a failed probe the endpoint is not due again until the probe interval
// (10 s here) has elapsed on the fake clock: 5 s is too early, 11 s is enough.
#[test]
fn probe_cadence_respects_interval_under_fake_clock() {
let clock = std::sync::Arc::new(FakeClock::new());
let router = HealthAwareRouter::with_config(
membership("p", &["r1"]),
RouterConfig {
timeout_threshold: 1,
probe_interval: Duration::from_secs(10),
},
Box::new(FakeClockHandle(clock.clone())),
);
router.observe_index(1, Outcome::Timeout);
assert_eq!(router.endpoints_due_for_probe().len(), 1);
router.record_probe_result(1, false);
assert!(router.endpoints_due_for_probe().is_empty());
clock.advance(Duration::from_secs(5));
assert!(router.endpoints_due_for_probe().is_empty());
clock.advance(Duration::from_secs(6));
assert_eq!(router.endpoints_due_for_probe().len(), 1);
}
// A membership change keeps the state of URLs that remain ("r1") and starts
// new URLs ("r3") from scratch.
#[test]
fn membership_update_preserves_known_endpoints() {
let mut router = HealthAwareRouter::new(membership("p", &["r1", "r2"]));
router.observe_index(1, Outcome::Rtt(Duration::from_millis(10)));
let prev_samples = router.snapshot(1).unwrap().samples;
assert_eq!(prev_samples, 1);
router.update_membership(membership("p", &["r1", "r3"]));
assert_eq!(router.snapshot(1).unwrap().samples, 1);
assert_eq!(router.snapshot(2).unwrap().samples, 0);
assert_eq!(router.snapshot(2).unwrap().url, "r3");
}
// A replica with a 1 ms RTT should receive about ten times the picks of one
// with a 10 ms RTT, since weights are inverse RTTs.
#[test]
fn weighted_distribution_favours_faster_replicas() {
let router = HealthAwareRouter::new(membership("p", &["fast", "slow"]));
for _ in 0..200 {
router.observe_index(1, Outcome::Rtt(Duration::from_millis(1)));
router.observe_index(2, Outcome::Rtt(Duration::from_millis(10)));
}
let mut hits: HashMap<usize, u32> = HashMap::new();
for _ in 0..10_000 {
*hits.entry(router.pick_read_index()).or_insert(0) += 1;
}
let fast = hits.get(&1).copied().unwrap_or(0) as f64;
let slow = hits.get(&2).copied().unwrap_or(0) as f64;
let ratio = fast / slow;
assert!(
(9.0..=11.0).contains(&ratio),
"expected ~10:1 fast/slow ratio, got {ratio}"
);
}
// Adapter that lets a test keep an `Arc<FakeClock>` to advance time while the
// router owns a boxed `Clock` reading the same fake clock.
struct FakeClockHandle(std::sync::Arc<FakeClock>);
impl Clock for FakeClockHandle {
fn now(&self) -> Instant {
self.0.now()
}
}
}
// Property-based tests (via the `proptest` crate) for weighting and the
// circuit breaker.
#[cfg(test)]
mod proptest_router {
use super::*;
use proptest::prelude::*;
use std::collections::HashMap;
proptest! {
// Property: for 2 to 5 replicas, each fed a constant RTT between 1 and 49 ms,
// the number of picks per replica over 10,000 calls stays within
// 15% + 50 picks of the share implied by its floored inverse-RTT weight.
#[test]
fn weighted_distribution_tracks_inverse_rtt(
rtts in proptest::collection::vec(1u64..50u64, 2..6usize),
) {
let names: Vec<String> = (0..rtts.len()).map(|i| format!("r{i}")).collect();
let replicas: Vec<&str> = names.iter().map(|s| s.as_str()).collect();
let router = HealthAwareRouter::new(
ClusterMembership::new("primary".into(), replicas.iter().map(|s| s.to_string()).collect())
);
// Replica i lives at router index i + 1 because index 0 is the primary.
for (i, &rtt_ms) in rtts.iter().enumerate() {
let idx = i + 1;
for _ in 0..200 {
router.observe_index(idx, Outcome::Rtt(Duration::from_millis(rtt_ms)));
}
}
let n_calls = 10_000usize;
let mut hits: HashMap<usize, u32> = HashMap::new();
for _ in 0..n_calls {
*hits.entry(router.pick_read_index()).or_insert(0) += 1;
}
// Recompute the expected weights the same way the router does:
// inverse RTT in seconds, then the median-based floor.
let raw_weights: Vec<f64> = rtts.iter().map(|&r| 1.0 / (r as f64 / 1000.0)).collect();
let expected_weights = apply_floor(&raw_weights);
let total: f64 = expected_weights.iter().sum();
for (i, &w) in expected_weights.iter().enumerate() {
let idx = i + 1;
let expected = (w / total) * (n_calls as f64);
let actual = hits.get(&idx).copied().unwrap_or(0) as f64;
let slack = 0.15 * expected + 50.0;
prop_assert!(
(actual - expected).abs() <= slack,
"replica {idx}: expected ~{expected:.0}, got {actual} (slack {slack:.0}); rtts={rtts:?}"
);
}
}
// Property: for any sequence of timeouts (`true`) and RTT samples (`false`),
// once the threshold number of consecutive timeouts has occurred the endpoint
// ends up unhealthy (RTT samples do not re-admit it), and the stored streak
// always equals the trailing run of timeouts.
#[test]
fn circuit_breaker_open_on_k_consecutive(
seq in proptest::collection::vec(any::<bool>(), 1..40usize),
) {
let router = HealthAwareRouter::with_config(
ClusterMembership::new("p".into(), vec!["r1".into()]),
RouterConfig { timeout_threshold: DEFAULT_TIMEOUT_THRESHOLD, probe_interval: DEFAULT_PROBE_INTERVAL },
Box::new(SystemClock),
);
// Reference model of the breaker, tracked alongside the router.
let mut consecutive = 0u32;
let mut should_be_unhealthy = false;
for &is_timeout in &seq {
if is_timeout {
router.observe_index(1, Outcome::Timeout);
consecutive += 1;
if consecutive >= DEFAULT_TIMEOUT_THRESHOLD {
should_be_unhealthy = true;
}
} else {
router.observe_index(1, Outcome::Rtt(Duration::from_millis(2)));
consecutive = 0;
}
}
let snap = router.snapshot(1).unwrap();
if should_be_unhealthy {
prop_assert!(!snap.healthy);
}
let trailing = seq.iter().rev().take_while(|&&b| b).count() as u32;
prop_assert_eq!(snap.consecutive_timeouts, trailing);
}
}
}