use std::{collections::HashMap, fmt, net::SocketAddr, sync::Arc, time::Duration};
use tokio::task::JoinSet;
use tokio_util::task::TaskTracker;
use tracing::warn;
use crate::codec::message::Question;
use super::{
DEFAULT_QUERY_TIMEOUT, Error, ForwardResult, Result, UpstreamClient, UpstreamConfig,
health::UpstreamHealth,
};
/// Default number of *extra* upstreams tried after the first one fails, as
/// used by `UpstreamPool::connect_with_defaults` (two attempts in total).
pub const DEFAULT_FAILOVER_BUDGET: usize = 1;
/// Per-upstream statistics handed to an [`UpstreamSelector`]; the pool builds
/// one per client from its health tracker before every query.
#[derive(Debug, Clone, Copy)]
pub struct UpstreamObservation {
    /// Smoothed latency in milliseconds (an EWMA, per the name), or `None`
    /// when no latency sample has been recorded for this upstream.
    pub ewma_latency_ms: Option<f64>,
    /// Observed success rate — presumably a fraction in `[0, 1]`; confirm
    /// against `UpstreamHealth`. `0.0` when the upstream has no health record.
    pub success_rate: f64,
    /// Number of attempts recorded for this upstream (`0` when none).
    pub attempts: u64,
}
/// Strategy deciding in which order a pool's upstreams are tried.
///
/// `upstreams` holds one observation per pool client, in the pool's client
/// order. Implementations return indices into that slice, first-to-try first.
/// The pool only consumes a prefix of the result (bounded by its attempt
/// budget or fan-out), and indices that do not name a client are skipped.
pub trait UpstreamSelector: fmt::Debug + Send + Sync {
    /// Returns the try order as indices into `upstreams`.
    fn order(&self, upstreams: &[UpstreamObservation]) -> Vec<usize>;
}
/// Selector that ignores health data and returns a freshly shuffled order of
/// all upstreams on every call.
#[derive(Debug, Default, Clone)]
pub struct RandomSelector;

impl UpstreamSelector for RandomSelector {
    fn order(&self, upstreams: &[UpstreamObservation]) -> Vec<usize> {
        use rand::seq::SliceRandom as _;
        let mut permutation = (0..upstreams.len()).collect::<Vec<_>>();
        let mut generator = rand::rng();
        permutation.shuffle(&mut generator);
        permutation
    }
}
/// Selector that favours upstreams with low observed latency and a high
/// success rate.
///
/// Each upstream gets the weight `success / latency_ms`, and the try order is
/// a weighted random permutation of those weights, so slower upstreams still
/// lead occasionally instead of being starved.
#[derive(Debug, Default, Clone)]
pub struct LatencyWeightedSelector;

impl LatencyWeightedSelector {
    /// Latency (ms) assumed for an upstream that has no latency sample yet.
    const BASELINE_MS: f64 = 50.0;
    /// Floor on latency so a near-zero sample cannot divide by zero.
    const MIN_LATENCY_MS: f64 = 1.0;
    /// Floor on the success factor so a failing upstream keeps a positive weight.
    const MIN_SUCCESS: f64 = 0.05;

    /// Computes the unnormalised selection weight for one upstream.
    fn weight(obs: &UpstreamObservation) -> f64 {
        // With no recorded attempts the success rate carries no information,
        // so an untried upstream is treated as fully healthy.
        let success = match obs.attempts {
            0 => 1.0,
            _ => obs.success_rate.max(Self::MIN_SUCCESS),
        };
        let latency = match obs.ewma_latency_ms {
            Some(ms) => ms,
            None => Self::BASELINE_MS,
        };
        success / latency.max(Self::MIN_LATENCY_MS)
    }
}

impl UpstreamSelector for LatencyWeightedSelector {
    fn order(&self, upstreams: &[UpstreamObservation]) -> Vec<usize> {
        let mut weights = Vec::with_capacity(upstreams.len());
        weights.extend(upstreams.iter().map(Self::weight));
        weighted_permutation(&weights)
    }
}
/// Returns a random permutation of `0..weights.len()` built by weighted
/// sampling without replacement: at every step, each index not yet placed is
/// drawn with probability proportional to its weight.
///
/// Weights that are zero, negative, NaN or infinite are treated as zero.
/// Such indices are never drawn while any positive weight remains; they are
/// appended at the end in ascending index order. An empty slice yields an
/// empty permutation.
fn weighted_permutation(weights: &[f64]) -> Vec<usize> {
    use rand::RngExt as _;
    // Sanitize up front so one bad weight cannot poison the running total:
    // a NaN total makes every comparison below false, which would silently
    // degrade selection to "always take the last index".
    let weights: Vec<f64> = weights
        .iter()
        .map(|&w| if w.is_finite() && w > 0.0 { w } else { 0.0 })
        .collect();
    let n = weights.len();
    let mut remaining: Vec<usize> = (0..n).collect();
    let mut order = Vec::with_capacity(n);
    let mut rng = rand::rng();
    while !remaining.is_empty() {
        let total: f64 = remaining.iter().map(|&i| weights[i]).sum();
        // No positive mass left (or the sum overflowed): keep the rest in
        // index order.
        if total <= 0.0 || !total.is_finite() {
            order.append(&mut remaining);
            break;
        }
        // `pick` is uniform in [0, total). Walk the cumulative weights and
        // take the first index whose slice of that interval contains it;
        // zero-weight indices own an empty slice and are never taken here.
        let mut pick = rng.random::<f64>() * total;
        // Fallback in case float rounding leaves `pick` past the final slice:
        // the last candidate that actually has a positive weight.
        let mut chosen = remaining
            .iter()
            .rposition(|&i| weights[i] > 0.0)
            .unwrap_or(remaining.len() - 1);
        for (pos, &i) in remaining.iter().enumerate() {
            let w = weights[i];
            if pick < w {
                chosen = pos;
                break;
            }
            pick -= w;
        }
        order.push(remaining.remove(chosen));
    }
    order
}
/// A set of connected upstream clients plus the policy for spreading queries
/// across them.
///
/// By default upstreams are tried one at a time with failover; after
/// `with_parallel_fanout` several are raced concurrently instead.
pub struct UpstreamPool {
    /// Clients that connected successfully in `connect`; configs that failed
    /// to connect are not represented.
    clients: Vec<UpstreamClient>,
    /// Chooses the try order for each query from per-client observations.
    selector: Arc<dyn UpstreamSelector>,
    /// Upper bound on upstreams tried per query in sequential mode (the
    /// failover budget given to `connect`, plus one).
    max_attempts: usize,
    /// Timeout passed to every individual `UpstreamClient::forward` call.
    per_attempt_timeout: Duration,
    /// `None` for sequential failover; `Some(n)` to race the first `n`
    /// upstreams in selector order.
    parallel_fanout: Option<usize>,
}
/// Summarises the pool: reports how many clients it holds rather than the
/// clients themselves, and omits the selector (signalled by a trailing `..`).
impl fmt::Debug for UpstreamPool {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let client_count = self.clients.len();
        let mut builder = f.debug_struct("UpstreamPool");
        builder.field("clients", &client_count);
        builder.field("max_attempts", &self.max_attempts);
        builder.field("per_attempt_timeout", &self.per_attempt_timeout);
        builder.field("parallel_fanout", &self.parallel_fanout);
        builder.finish_non_exhaustive()
    }
}
impl UpstreamPool {
pub async fn connect(
configs: &[UpstreamConfig],
tracker: &TaskTracker,
selector: Arc<dyn UpstreamSelector>,
failover_budget: usize,
per_attempt_timeout: Duration,
) -> Self {
let mut clients = Vec::with_capacity(configs.len());
for cfg in configs {
match UpstreamClient::connect(cfg).await {
Ok((client, bg)) => {
tracker.spawn(bg);
clients.push(client);
}
Err(e) => {
warn!(
transport = %cfg.transport,
addr = %cfg.addr,
error = %e,
"upstream failed to connect, skipping"
);
}
}
}
Self {
clients,
selector,
max_attempts: failover_budget + 1,
per_attempt_timeout,
parallel_fanout: None,
}
}
pub async fn connect_with_defaults(configs: &[UpstreamConfig], tracker: &TaskTracker) -> Self {
Self::connect(
configs,
tracker,
Arc::new(RandomSelector),
DEFAULT_FAILOVER_BUDGET,
DEFAULT_QUERY_TIMEOUT,
)
.await
}
#[must_use]
pub fn with_parallel_fanout(mut self, fanout: usize) -> Self {
self.parallel_fanout = Some(fanout.max(1));
self
}
pub async fn forward(
&self,
question: &Question,
health: &UpstreamHealth,
) -> Result<ForwardResult> {
if self.clients.is_empty() {
return Err(Error::AllUpstreamsFailed { attempts: 0 });
}
let order = self.selector.order(&self.observations(health));
match self.parallel_fanout {
Some(fanout) => {
self.forward_parallel(question, health, &order, fanout)
.await
}
None => self.forward_sequential(question, health, &order).await,
}
}
fn observations(&self, health: &UpstreamHealth) -> Vec<UpstreamObservation> {
let by_addr: HashMap<SocketAddr, _> = health
.snapshot()
.into_iter()
.map(|row| (row.addr, row))
.collect();
self.clients
.iter()
.map(|client| match by_addr.get(&client.addr()) {
Some(row) => UpstreamObservation {
ewma_latency_ms: row.ewma_latency_ms,
success_rate: row.success_rate,
attempts: row.attempts(),
},
None => UpstreamObservation {
ewma_latency_ms: None,
success_rate: 0.0,
attempts: 0,
},
})
.collect()
}
async fn forward_sequential(
&self,
question: &Question,
health: &UpstreamHealth,
order: &[usize],
) -> Result<ForwardResult> {
let mut attempts: usize = 0;
for idx in order.iter().take(self.max_attempts) {
let Some(client) = self.clients.get(*idx) else {
continue;
};
attempts += 1;
match client.forward(question, self.per_attempt_timeout).await {
Ok(result) => {
health.record_success(client.addr(), result.latency);
return Ok(result);
}
Err(e) => {
health.record_failure(client.addr(), e.to_string());
warn!(
upstream_index = idx,
transport = %client.transport(),
error = %e,
"upstream failed, trying next"
);
}
}
}
Err(Error::AllUpstreamsFailed { attempts })
}
async fn forward_parallel(
&self,
question: &Question,
health: &UpstreamHealth,
order: &[usize],
fanout: usize,
) -> Result<ForwardResult> {
let mut set: JoinSet<(SocketAddr, Result<ForwardResult>)> = JoinSet::new();
for idx in order.iter().take(fanout) {
let Some(client) = self.clients.get(*idx) else {
continue;
};
let client = client.clone();
let question = question.clone();
let timeout = self.per_attempt_timeout;
set.spawn(async move {
let result = client.forward(&question, timeout).await;
(client.addr(), result)
});
}
let mut attempts: usize = 0;
while let Some(joined) = set.join_next().await {
let (addr, result) = joined.expect("upstream forward task panicked");
attempts += 1;
match result {
Ok(forward) => {
health.record_success(addr, forward.latency);
return Ok(forward);
}
Err(e) => {
health.record_failure(addr, e.to_string());
warn!(
upstream = %addr,
error = %e,
"parallel upstream attempt failed"
);
}
}
}
Err(Error::AllUpstreamsFailed { attempts })
}
pub fn len(&self) -> usize {
self.clients.len()
}
pub fn is_empty(&self) -> bool {
self.clients.is_empty()
}
pub fn max_attempts(&self) -> usize {
self.max_attempts
}
pub fn parallel_fanout(&self) -> Option<usize> {
self.parallel_fanout
}
}
/// An [`UpstreamPool`] that can be replaced atomically while queries keep
/// flowing, together with a health tracker that survives replacements.
pub struct SharedUpstreamPool {
    /// The currently installed pool; `store` swaps in a new one.
    inner: arc_swap::ArcSwap<UpstreamPool>,
    /// Per-upstream health shared by every pool installed here; `store` does
    /// not reset it.
    health: Arc<UpstreamHealth>,
}
/// Shows the currently installed pool. The shared health tracker is omitted,
/// which `finish_non_exhaustive` marks with a trailing `..`, matching the
/// `UpstreamPool` `Debug` output.
impl fmt::Debug for SharedUpstreamPool {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("SharedUpstreamPool")
            .field("pool", &*self.inner.load())
            .finish_non_exhaustive()
    }
}
impl SharedUpstreamPool {
    /// Installs `pool` as the initial pool and starts with an empty health
    /// tracker.
    pub fn new(pool: UpstreamPool) -> Self {
        let inner = arc_swap::ArcSwap::from_pointee(pool);
        let health = Arc::new(UpstreamHealth::new());
        Self { inner, health }
    }

    /// The health tracker shared by every pool installed in this handle.
    pub fn health(&self) -> &Arc<UpstreamHealth> {
        &self.health
    }

    /// Borrows the currently installed pool.
    ///
    /// The returned guard is meant to be short-lived. Code that needs the
    /// pool across an `.await` should hold an owned `Arc` instead, as
    /// [`forward`](Self::forward) does.
    pub fn load(&self) -> arc_swap::Guard<Arc<UpstreamPool>> {
        self.inner.load()
    }

    /// Atomically replaces the installed pool.
    ///
    /// Queries already in flight keep using the pool they started with, and
    /// health data is retained.
    pub fn store(&self, pool: UpstreamPool) {
        let replacement = Arc::new(pool);
        self.inner.store(replacement);
    }

    /// Forwards `question` through the currently installed pool, recording
    /// outcomes in the shared health tracker.
    ///
    /// # Errors
    ///
    /// Propagates the error returned by [`UpstreamPool::forward`].
    pub async fn forward(&self, question: &Question) -> Result<ForwardResult> {
        // Take an owned snapshot so a concurrent `store` cannot drop the pool
        // while this query is awaiting.
        let snapshot = self.inner.load_full();
        snapshot.forward(question, &self.health).await
    }
}
// Unit tests for the selectors and for UpstreamPool / SharedUpstreamPool
// forwarding, using mock UDP upstreams from `crate::test_support`.
#[cfg(test)]
mod tests {
    use std::{net::SocketAddr, time::Duration};
    use tokio::time::timeout;
    use tokio_util::task::TaskTracker;
    use super::*;
    use crate::resolver::upstream::{UpstreamConfig, UpstreamHealth, UpstreamTransport};
    use crate::test_support::{
        mock_udp_upstream, nxdomain_handler, positive_a_handler, silent_handler, stock_question,
    };

    /// Builds a plain-UDP upstream config for `addr` (no TLS server name and
    /// no HTTP endpoint).
    fn udp_config(addr: SocketAddr) -> UpstreamConfig {
        UpstreamConfig {
            addr,
            transport: UpstreamTransport::Udp,
            tls_server_name: None,
            http_endpoint: None,
        }
    }

    /// Deterministic selector that tries upstreams in configuration order, so
    /// each test controls which upstream is attempted first.
    #[derive(Debug)]
    struct SequentialSelector;
    impl UpstreamSelector for SequentialSelector {
        fn order(&self, upstreams: &[UpstreamObservation]) -> Vec<usize> {
            (0..upstreams.len()).collect()
        }
    }

    /// `n` identical observations describing upstreams with no history.
    fn obs(n: usize) -> Vec<UpstreamObservation> {
        vec![
            UpstreamObservation {
                ewma_latency_ms: None,
                success_rate: 0.0,
                attempts: 0,
            };
            n
        ]
    }

    // The random order must always be a permutation of all indices (and empty
    // for no upstreams).
    #[test]
    fn random_selector_order_is_permutation() {
        let sel = RandomSelector;
        assert_eq!(sel.order(&obs(0)), Vec::<usize>::new());
        for _ in 0..100 {
            let mut o = sel.order(&obs(3));
            o.sort_unstable();
            assert_eq!(o, vec![0, 1, 2]);
        }
    }

    // With three candidates each index should lead about a third of the time
    // (~1000 of 3000 trials). The 20%–47% band leaves wide slack so a correct
    // shuffle is very unlikely to fail spuriously.
    #[test]
    fn random_selector_spread() {
        let sel = RandomSelector;
        let trials = 3000usize;
        let count = 3usize;
        let mut first_tally = vec![0usize; count];
        for _ in 0..trials {
            let o = sel.order(&obs(count));
            first_tally[o[0]] += 1;
        }
        let lo = (trials as f64 * 0.20) as usize;
        let hi = (trials as f64 * 0.47) as usize;
        for (i, &tally) in first_tally.iter().enumerate() {
            assert!(
                tally >= lo && tally <= hi,
                "index {i} appeared {tally} times as first in {trials} trials \
                 (expected {lo}–{hi})"
            );
        }
    }

    // The test-only selector must reproduce input order exactly.
    #[test]
    fn sequential_selector_order() {
        let sel = SequentialSelector;
        assert_eq!(sel.order(&obs(0)), Vec::<usize>::new());
        assert_eq!(sel.order(&obs(1)), vec![0]);
        assert_eq!(sel.order(&obs(3)), vec![0, 1, 2]);
    }

    // Weights are 1/5 vs 1/200, so the fast upstream should lead with
    // probability 0.2 / 0.205 ≈ 97.6%; the assertion only requires > 80%.
    #[test]
    fn latency_weighted_favors_faster_upstream() {
        let sel = LatencyWeightedSelector;
        let upstreams = vec![
            UpstreamObservation {
                ewma_latency_ms: Some(5.0),
                success_rate: 1.0,
                attempts: 100,
            },
            UpstreamObservation {
                ewma_latency_ms: Some(200.0),
                success_rate: 1.0,
                attempts: 100,
            },
        ];
        let trials = 2000usize;
        let mut fast_first = 0usize;
        for _ in 0..trials {
            let order = sel.order(&upstreams);
            assert_eq!(order.len(), 2);
            assert!(order.contains(&0) && order.contains(&1));
            if order[0] == 0 {
                fast_first += 1;
            }
        }
        let lo = (trials as f64 * 0.8) as usize;
        assert!(
            fast_first > lo,
            "fast upstream led {fast_first}/{trials} times (expected > {lo})"
        );
    }

    // An untried upstream is weighted as a healthy 50 ms baseline (0.02 vs 0.2),
    // so it should lead about 9% of the time (~180 of 2000); the assertion
    // only requires that it is not starved (> 50).
    #[test]
    fn latency_weighted_explores_unknown_upstream() {
        let sel = LatencyWeightedSelector;
        let upstreams = vec![
            UpstreamObservation {
                ewma_latency_ms: Some(5.0),
                success_rate: 1.0,
                attempts: 100,
            },
            UpstreamObservation {
                ewma_latency_ms: None,
                success_rate: 0.0,
                attempts: 0,
            },
        ];
        let mut unknown_first = 0usize;
        for _ in 0..2000 {
            if sel.order(&upstreams)[0] == 1 {
                unknown_first += 1;
            }
        }
        assert!(
            unknown_first > 50,
            "unknown upstream was starved: led only {unknown_first}/2000"
        );
    }

    // Races both upstreams (fan-out 2). The outer 500 ms limit is far below the
    // 2 s per-attempt timeout, so this fails if the pool waits for the silent
    // upstream instead of returning the first answer.
    #[tokio::test]
    async fn parallel_returns_fastest_and_ignores_silent() {
        let silent_addr = mock_udp_upstream(silent_handler).await;
        let answer_addr = mock_udp_upstream(positive_a_handler).await;
        let configs = vec![udp_config(silent_addr), udp_config(answer_addr)];
        let tracker = TaskTracker::new();
        let pool = UpstreamPool::connect(
            &configs,
            &tracker,
            Arc::new(SequentialSelector),
            1,
            Duration::from_secs(2),
        )
        .await
        .with_parallel_fanout(2);
        let health = UpstreamHealth::new();
        let result = timeout(
            Duration::from_millis(500),
            pool.forward(&stock_question(), &health),
        )
        .await
        .expect("parallel must not block on the silent upstream's timeout")
        .expect("forward must succeed via the fast upstream");
        assert!(!result.is_negative);
        assert_eq!(
            result.upstream, answer_addr,
            "the fast upstream must win the race"
        );
        // Only the winner is checked: the silent upstream's attempt is still in
        // flight when the pool returns, so it has no recorded outcome.
        let snap = health.snapshot();
        let answer = snap
            .iter()
            .find(|r| r.addr == answer_addr)
            .expect("winner tracked");
        assert_eq!(answer.successes, 1);
    }

    // A pool whose config list is empty must fail immediately with zero
    // attempts.
    #[tokio::test]
    async fn empty_pool_returns_all_failed_attempts_zero() {
        let tracker = TaskTracker::new();
        let pool = UpstreamPool::connect(
            &[],
            &tracker,
            Arc::new(SequentialSelector),
            1,
            Duration::from_millis(150),
        )
        .await;
        assert!(pool.is_empty());
        assert_eq!(pool.len(), 0);
        let result = timeout(
            Duration::from_secs(5),
            pool.forward(&stock_question(), &UpstreamHealth::new()),
        )
        .await
        .expect("safety timeout");
        assert!(
            matches!(result, Err(Error::AllUpstreamsFailed { attempts: 0 })),
            "expected AllUpstreamsFailed {{ attempts: 0 }}, got: {result:?}"
        );
    }

    // The silent upstream is tried first (SequentialSelector); once its 150 ms
    // attempt times out, the single unit of failover budget moves on to the
    // answering upstream.
    #[tokio::test]
    async fn failover_to_second_upstream_on_timeout() {
        let silent_addr = mock_udp_upstream(silent_handler).await;
        let answer_addr = mock_udp_upstream(positive_a_handler).await;
        let configs = vec![udp_config(silent_addr), udp_config(answer_addr)];
        let tracker = TaskTracker::new();
        let pool = UpstreamPool::connect(
            &configs,
            &tracker,
            Arc::new(SequentialSelector),
            // Failover budget 1 (two attempts), 150 ms per attempt.
            1, Duration::from_millis(150),
        )
        .await;
        assert_eq!(pool.max_attempts(), 2);
        let result = timeout(
            Duration::from_secs(5),
            pool.forward(&stock_question(), &UpstreamHealth::new()),
        )
        .await
        .expect("safety timeout")
        .expect("forward must succeed after failover");
        assert!(
            !result.is_negative,
            "failover result must be a positive answer"
        );
        assert_eq!(
            result.upstream, answer_addr,
            "must record the upstream that answered after failover"
        );
        assert!(result.latency > Duration::ZERO, "latency must be measured");
    }

    // After a failover, health must hold one failure (no latency, an error
    // string) for the silent upstream and one success (with latency) for the
    // answering upstream.
    #[tokio::test]
    async fn forward_records_per_upstream_health() {
        let silent_addr = mock_udp_upstream(silent_handler).await;
        let answer_addr = mock_udp_upstream(positive_a_handler).await;
        let configs = vec![udp_config(silent_addr), udp_config(answer_addr)];
        let tracker = TaskTracker::new();
        let pool = UpstreamPool::connect(
            &configs,
            &tracker,
            Arc::new(SequentialSelector),
            1,
            Duration::from_millis(150),
        )
        .await;
        let health = UpstreamHealth::new();
        timeout(
            Duration::from_secs(5),
            pool.forward(&stock_question(), &health),
        )
        .await
        .expect("safety timeout")
        .expect("forward succeeds after failover");
        let snap = health.snapshot();
        let silent = snap
            .iter()
            .find(|r| r.addr == silent_addr)
            .expect("silent upstream tracked");
        assert_eq!(silent.failures, 1, "silent upstream recorded a failure");
        assert_eq!(silent.successes, 0);
        assert_eq!(silent.ewma_latency_ms, None, "a failure has no latency");
        assert!(
            silent.last_error.is_some(),
            "failure retains an error string"
        );
        let answer = snap
            .iter()
            .find(|r| r.addr == answer_addr)
            .expect("answering upstream tracked");
        assert_eq!(answer.successes, 1, "answering upstream recorded a success");
        assert_eq!(answer.failures, 0);
        assert!(
            answer.ewma_latency_ms.is_some(),
            "success records a latency"
        );
    }

    // Both upstreams are silent: both attempts time out and the error reports
    // two attempts.
    #[tokio::test]
    async fn all_fail_returns_all_upstreams_failed() {
        let s0 = mock_udp_upstream(silent_handler).await;
        let s1 = mock_udp_upstream(silent_handler).await;
        let configs = vec![udp_config(s0), udp_config(s1)];
        let tracker = TaskTracker::new();
        let pool = UpstreamPool::connect(
            &configs,
            &tracker,
            Arc::new(SequentialSelector),
            // Failover budget 1 (two attempts), 150 ms per attempt.
            1, Duration::from_millis(150),
        )
        .await;
        let result = timeout(
            Duration::from_secs(5),
            pool.forward(&stock_question(), &UpstreamHealth::new()),
        )
        .await
        .expect("safety timeout");
        assert!(
            matches!(result, Err(Error::AllUpstreamsFailed { attempts: 2 })),
            "expected AllUpstreamsFailed {{ attempts: 2 }}, got: {result:?}"
        );
    }

    // Three silent upstreams but a budget of 1: only two are tried, proving the
    // budget (not the pool size) bounds the attempts.
    #[tokio::test]
    async fn budget_bounds_attempts() {
        let s0 = mock_udp_upstream(silent_handler).await;
        let s1 = mock_udp_upstream(silent_handler).await;
        let s2 = mock_udp_upstream(silent_handler).await;
        let configs = vec![udp_config(s0), udp_config(s1), udp_config(s2)];
        let tracker = TaskTracker::new();
        let pool = UpstreamPool::connect(
            &configs,
            &tracker,
            Arc::new(SequentialSelector),
            // Failover budget 1 (two attempts), 150 ms per attempt.
            1, Duration::from_millis(150),
        )
        .await;
        let result = timeout(
            Duration::from_secs(5),
            pool.forward(&stock_question(), &UpstreamHealth::new()),
        )
        .await
        .expect("safety timeout");
        assert!(
            matches!(result, Err(Error::AllUpstreamsFailed { attempts: 2 })),
            "expected AllUpstreamsFailed {{ attempts: 2 }}, got: {result:?}"
        );
    }

    // Swapping the installed pool must redirect the next query: pool_a answers
    // positively, pool_b with NXDOMAIN.
    #[tokio::test]
    async fn shared_pool_swap_takes_effect() {
        let positive_addr = mock_udp_upstream(positive_a_handler).await;
        let nxdomain_addr = mock_udp_upstream(nxdomain_handler).await;
        let tracker = TaskTracker::new();
        let pool_a = UpstreamPool::connect(
            &[udp_config(positive_addr)],
            &tracker,
            Arc::new(SequentialSelector),
            // Failover budget 0 (single attempt), 500 ms per attempt.
            0, Duration::from_millis(500),
        )
        .await;
        let pool_b = UpstreamPool::connect(
            &[udp_config(nxdomain_addr)],
            &tracker,
            Arc::new(SequentialSelector),
            0,
            Duration::from_millis(500),
        )
        .await;
        let shared = SharedUpstreamPool::new(pool_a);
        let q = stock_question();
        let res_a = timeout(Duration::from_secs(5), shared.forward(&q))
            .await
            .expect("safety timeout")
            .expect("pool_a forward must succeed");
        assert!(!res_a.is_negative, "pool_a must return a positive answer");
        shared.store(pool_b);
        let res_b = timeout(Duration::from_secs(5), shared.forward(&q))
            .await
            .expect("safety timeout")
            .expect("pool_b forward must succeed");
        assert!(res_b.is_negative, "pool_b must return a negative answer");
    }
}