use std::{fmt, sync::Arc, time::Duration};
use tokio_util::task::TaskTracker;
use tracing::warn;
use crate::codec::message::Question;
use super::{DEFAULT_QUERY_TIMEOUT, Error, ForwardResult, Result, UpstreamClient, UpstreamConfig};
/// Default number of *additional* upstreams tried after the first attempt fails.
///
/// Used by [`UpstreamPool::connect_with_defaults`]. A pool's per-query attempt limit is
/// this budget plus one, so the default allows at most two attempts per query.
pub const DEFAULT_FAILOVER_BUDGET: usize = 1;
/// Strategy deciding the order in which an [`UpstreamPool`]'s clients are tried.
///
/// The pool stores the selector behind an `Arc` and calls it once per forwarded query,
/// hence the `Debug + Send + Sync` bounds.
pub trait UpstreamSelector: fmt::Debug + Send + Sync {
/// Returns indices into the pool's client list (valid range `0..count`) in the order
/// they should be attempted.
///
/// The pool walks the returned list until one upstream succeeds or its attempt limit
/// is reached; any index `>= count` is skipped rather than causing a panic.
fn order(&self, count: usize) -> Vec<usize>;
}
/// Selector that yields every upstream index exactly once, in a random order.
///
/// A fresh permutation of `0..count` is drawn from `rand::rng()` on every call, so
/// first-attempt load is spread across all upstreams over many queries.
#[derive(Debug, Default, Clone)]
pub struct RandomSelector;

impl UpstreamSelector for RandomSelector {
    fn order(&self, count: usize) -> Vec<usize> {
        use rand::seq::SliceRandom as _;

        let mut rng = rand::rng();
        let mut permutation = (0..count).collect::<Vec<usize>>();
        permutation.shuffle(&mut rng);
        permutation
    }
}
/// A set of connected upstream clients queried with ordered failover.
///
/// For each query the [`UpstreamSelector`] chooses an order, and clients are tried one
/// at a time until one succeeds or `max_attempts` clients have been tried.
pub struct UpstreamPool {
/// Clients that connected successfully; configs that failed to connect are omitted.
clients: Vec<UpstreamClient>,
/// Chooses the order in which `clients` are tried on each query.
selector: Arc<dyn UpstreamSelector>,
/// Upper bound on clients tried per query (the failover budget plus one).
max_attempts: usize,
/// Timeout passed to each individual upstream attempt.
per_attempt_timeout: Duration,
}
/// Compact debug view: reports how many clients are connected instead of dumping each
/// client, and omits the selector (marked non-exhaustive).
impl fmt::Debug for UpstreamPool {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self {
            clients,
            max_attempts,
            per_attempt_timeout,
            ..
        } = self;

        let mut out = f.debug_struct("UpstreamPool");
        out.field("clients", &clients.len());
        out.field("max_attempts", max_attempts);
        out.field("per_attempt_timeout", per_attempt_timeout);
        out.finish_non_exhaustive()
    }
}
impl UpstreamPool {
    /// Connects to every upstream in `configs` and builds a pool from those that succeed.
    ///
    /// Each connected client's background future is spawned on `tracker`. Upstreams that
    /// fail to connect are logged and skipped, so the pool may hold fewer clients than
    /// `configs` — possibly none, in which case a warning is emitted and every query
    /// fails with `AllUpstreamsFailed { attempts: 0 }`.
    ///
    /// `failover_budget` is the number of *additional* upstreams tried after the first one
    /// fails. A query therefore makes at most `failover_budget + 1` attempts, saturating
    /// at `usize::MAX`. Each attempt is bounded by `per_attempt_timeout`.
    pub async fn connect(
        configs: &[UpstreamConfig],
        tracker: &TaskTracker,
        selector: Arc<dyn UpstreamSelector>,
        failover_budget: usize,
        per_attempt_timeout: Duration,
    ) -> Self {
        let mut clients = Vec::with_capacity(configs.len());
        for cfg in configs {
            match UpstreamClient::connect(cfg).await {
                Ok((client, bg)) => {
                    tracker.spawn(bg);
                    clients.push(client);
                }
                Err(e) => {
                    warn!(
                        transport = %cfg.transport,
                        addr = %cfg.addr,
                        error = %e,
                        "upstream failed to connect, skipping"
                    );
                }
            }
        }
        if clients.is_empty() && !configs.is_empty() {
            // Make the degenerate "nothing connected" case visible: otherwise the pool
            // silently fails every query.
            warn!(
                configured = configs.len(),
                "no upstream connected; pool is empty and all queries will fail"
            );
        }
        Self {
            clients,
            selector,
            // `saturating_add`: a plain `+ 1` panics in debug builds for `usize::MAX`. In
            // release builds it wraps to 0, so every query would fail without trying.
            max_attempts: failover_budget.saturating_add(1),
            per_attempt_timeout,
        }
    }

    /// Connects with a [`RandomSelector`], [`DEFAULT_FAILOVER_BUDGET`] and the default
    /// per-attempt query timeout.
    pub async fn connect_with_defaults(configs: &[UpstreamConfig], tracker: &TaskTracker) -> Self {
        Self::connect(
            configs,
            tracker,
            Arc::new(RandomSelector),
            DEFAULT_FAILOVER_BUDGET,
            DEFAULT_QUERY_TIMEOUT,
        )
        .await
    }

    /// Forwards `question` to upstreams in selector order, returning the first success.
    ///
    /// Indices returned by the selector that do not refer to a client are ignored and do
    /// not count against the attempt budget. Each failed attempt is logged before moving
    /// on to the next upstream.
    ///
    /// # Errors
    ///
    /// Returns [`Error::AllUpstreamsFailed`] when the pool is empty (`attempts: 0`).
    /// It is also returned when every attempted upstream fails; `attempts` is then the
    /// number of upstreams actually tried, which is at most `max_attempts`.
    pub async fn forward(&self, question: &Question) -> Result<ForwardResult> {
        if self.clients.is_empty() {
            return Err(Error::AllUpstreamsFailed { attempts: 0 });
        }
        // Resolve indices to clients *before* applying the budget, so a selector that
        // yields out-of-range indices cannot consume attempts without querying anything.
        let candidates = self
            .selector
            .order(self.clients.len())
            .into_iter()
            .filter_map(|idx| self.clients.get(idx).map(|client| (idx, client)))
            .take(self.max_attempts);

        let mut attempts: usize = 0;
        for (idx, client) in candidates {
            attempts += 1;
            match client.forward(question, self.per_attempt_timeout).await {
                Ok(result) => return Ok(result),
                Err(e) => {
                    warn!(
                        upstream_index = idx,
                        transport = %client.transport(),
                        error = %e,
                        "upstream failed, trying next"
                    );
                }
            }
        }
        Err(Error::AllUpstreamsFailed { attempts })
    }

    /// Number of successfully connected upstream clients.
    #[must_use]
    pub fn len(&self) -> usize {
        self.clients.len()
    }

    /// `true` when no upstream connected; every `forward` call will then fail.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.clients.is_empty()
    }

    /// Maximum number of upstreams tried for a single query.
    #[must_use]
    pub fn max_attempts(&self) -> usize {
        self.max_attempts
    }
}
/// An [`UpstreamPool`] that can be replaced atomically while other tasks keep using it.
///
/// Readers take a snapshot of the current pool; `store` swaps in a new pool without
/// blocking them.
pub struct SharedUpstreamPool {
/// The current pool, held in an `arc_swap::ArcSwap`.
inner: arc_swap::ArcSwap<UpstreamPool>,
}
impl fmt::Debug for SharedUpstreamPool {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("SharedUpstreamPool")
.field("pool", &*self.inner.load())
.finish()
}
}
impl SharedUpstreamPool {
    /// Wraps `pool` so it can be shared and later replaced with [`Self::store`].
    pub fn new(pool: UpstreamPool) -> Self {
        let inner = arc_swap::ArcSwap::new(Arc::new(pool));
        Self { inner }
    }

    /// Returns a short-lived guard to the current pool.
    ///
    /// Per `arc_swap` guidance, do not hold the guard for long (e.g. across an `.await`);
    /// use [`Self::forward`] for async work.
    pub fn load(&self) -> arc_swap::Guard<Arc<UpstreamPool>> {
        self.inner.load()
    }

    /// Atomically replaces the current pool.
    ///
    /// Callers already forwarding through the previous pool keep their own `Arc` to it
    /// and finish against it.
    pub fn store(&self, pool: UpstreamPool) {
        let replacement = Arc::new(pool);
        self.inner.store(replacement);
    }

    /// Forwards `question` through the pool that is current when the call starts.
    ///
    /// # Errors
    ///
    /// Propagates the error from [`UpstreamPool::forward`].
    pub async fn forward(&self, question: &Question) -> Result<ForwardResult> {
        // An owned Arc (not a Guard) is taken because it is held across the `.await` below.
        let snapshot: Arc<UpstreamPool> = self.inner.load_full();
        snapshot.forward(question).await
    }
}
// Unit tests for the selectors and integration-style tests for `UpstreamPool` failover,
// driven by in-process mock UDP DNS servers on 127.0.0.1.
#[cfg(test)]
mod tests {
use std::{net::SocketAddr, time::Duration};
use hickory_net::proto::op::{Message, MessageType, ResponseCode};
use hickory_net::proto::rr::{Name, RData, Record, rdata::A};
use tokio::net::UdpSocket;
use tokio::time::timeout;
use tokio_util::task::TaskTracker;
use super::*;
use crate::codec::message::{Qclass, Qtype, Question};
use crate::resolver::upstream::{UpstreamConfig, UpstreamTransport};
// Starts a mock UDP DNS server on an ephemeral localhost port and returns its address.
// Each received datagram (read into a 512-byte buffer) is parsed as a `Message` and
// passed to `handler`. A `Some` reply is serialized and sent back to the sender, while
// `None` means no reply. Unparsable datagrams are ignored, and the server task exits
// when `recv_from` returns an error.
async fn spawn_mock_udp<F>(mut handler: F) -> SocketAddr
where
F: FnMut(Message) -> Option<Message> + Send + 'static,
{
let sock = UdpSocket::bind("127.0.0.1:0").await.unwrap();
let addr = sock.local_addr().unwrap();
tokio::spawn(async move {
let mut buf = vec![0u8; 512];
loop {
let Ok((len, peer)) = sock.recv_from(&mut buf).await else {
break;
};
let Ok(req) = Message::from_vec(&buf[..len]) else {
continue;
};
if let Some(resp) = handler(req)
&& let Ok(resp_bytes) = resp.to_vec()
{
let _ = sock.send_to(&resp_bytes, peer).await;
}
}
});
addr
}
// Echoes the request back as a NOERROR response carrying one A record
// (example.com. 300 IN A 93.184.216.34).
fn positive_a_handler(req: Message) -> Option<Message> {
let mut resp = req.clone();
resp.metadata.message_type = MessageType::Response;
resp.metadata.response_code = ResponseCode::NoError;
let name = Name::from_ascii("example.com.").unwrap();
let rdata = RData::A(A::new(93, 184, 216, 34));
resp.add_answer(Record::from_rdata(name, 300, rdata));
Some(resp)
}
// Echoes the request back as an NXDOMAIN response with no answers.
fn nxdomain_handler(req: Message) -> Option<Message> {
let mut resp = req.clone();
resp.metadata.message_type = MessageType::Response;
resp.metadata.response_code = ResponseCode::NXDomain;
Some(resp)
}
// Never replies, so any query sent to this server runs into the per-attempt timeout.
fn silent_handler(_req: Message) -> Option<Message> {
None
}
// The query used by every forwarding test: example.com, type A, class IN.
fn stock_question() -> Question {
Question {
name: "example.com".parse().unwrap(),
qtype: Qtype::A,
qclass: Qclass::In,
}
}
// Plain-UDP upstream config pointing at `addr`, with no TLS or HTTP settings.
fn udp_config(addr: SocketAddr) -> UpstreamConfig {
UpstreamConfig {
addr,
transport: UpstreamTransport::Udp,
tls_server_name: None,
http_endpoint: None,
}
}
// Deterministic selector (0, 1, ..., count-1) so tests know which upstream is tried first.
#[derive(Debug)]
struct SequentialSelector;
impl UpstreamSelector for SequentialSelector {
fn order(&self, count: usize) -> Vec<usize> {
(0..count).collect()
}
}
// Every RandomSelector order must be a permutation of 0..count (and empty for count 0).
#[test]
fn random_selector_order_is_permutation() {
let sel = RandomSelector;
assert_eq!(sel.order(0), Vec::<usize>::new());
for _ in 0..100 {
let mut o = sel.order(3);
o.sort_unstable();
assert_eq!(o, vec![0, 1, 2]);
}
}
// Statistical check that each index comes first roughly 1/3 of the time. The accepted
// band (20%-47% of trials) is deliberately wide so spurious failures should be very rare.
#[test]
fn random_selector_spread() {
let sel = RandomSelector;
let trials = 3000usize;
let count = 3usize;
let mut first_tally = vec![0usize; count];
for _ in 0..trials {
let o = sel.order(count);
first_tally[o[0]] += 1;
}
let lo = (trials as f64 * 0.20) as usize;
let hi = (trials as f64 * 0.47) as usize;
for (i, &tally) in first_tally.iter().enumerate() {
assert!(
tally >= lo && tally <= hi,
"index {i} appeared {tally} times as first in {trials} trials \
(expected {lo}–{hi})"
);
}
}
// Sanity check for the test-only selector.
#[test]
fn sequential_selector_order() {
let sel = SequentialSelector;
assert_eq!(sel.order(0), Vec::<usize>::new());
assert_eq!(sel.order(1), vec![0]);
assert_eq!(sel.order(3), vec![0, 1, 2]);
}
// A pool built from no configs is empty, and forwarding fails without any attempt.
#[tokio::test]
async fn empty_pool_returns_all_failed_attempts_zero() {
let tracker = TaskTracker::new();
let pool = UpstreamPool::connect(
&[],
&tracker,
Arc::new(SequentialSelector),
1,
Duration::from_millis(150),
)
.await;
assert!(pool.is_empty());
assert_eq!(pool.len(), 0);
// The outer 5 s timeout only guards against a hung test; it is not the behavior under test.
let result = timeout(Duration::from_secs(5), pool.forward(&stock_question()))
.await
.expect("safety timeout");
assert!(
matches!(result, Err(Error::AllUpstreamsFailed { attempts: 0 })),
"expected AllUpstreamsFailed {{ attempts: 0 }}, got: {result:?}"
);
}
// Upstream 0 is silent and times out after 150 ms; with budget 1 (two attempts) the
// pool fails over to upstream 1, which answers positively.
#[tokio::test]
async fn failover_to_second_upstream_on_timeout() {
let silent_addr = spawn_mock_udp(silent_handler).await;
let answer_addr = spawn_mock_udp(positive_a_handler).await;
let configs = vec![udp_config(silent_addr), udp_config(answer_addr)];
let tracker = TaskTracker::new();
let pool = UpstreamPool::connect(
&configs,
&tracker,
Arc::new(SequentialSelector),
1, Duration::from_millis(150),
)
.await;
assert_eq!(pool.max_attempts(), 2);
let result = timeout(Duration::from_secs(5), pool.forward(&stock_question()))
.await
.expect("safety timeout")
.expect("forward must succeed after failover");
assert!(
!result.is_negative,
"failover result must be a positive answer"
);
}
// Both upstreams are silent, so both attempts time out and the error reports 2 attempts.
#[tokio::test]
async fn all_fail_returns_all_upstreams_failed() {
let s0 = spawn_mock_udp(silent_handler).await;
let s1 = spawn_mock_udp(silent_handler).await;
let configs = vec![udp_config(s0), udp_config(s1)];
let tracker = TaskTracker::new();
let pool = UpstreamPool::connect(
&configs,
&tracker,
Arc::new(SequentialSelector),
1, Duration::from_millis(150),
)
.await;
let result = timeout(Duration::from_secs(5), pool.forward(&stock_question()))
.await
.expect("safety timeout");
assert!(
matches!(result, Err(Error::AllUpstreamsFailed { attempts: 2 })),
"expected AllUpstreamsFailed {{ attempts: 2 }}, got: {result:?}"
);
}
// Three silent upstreams but a budget of 1: only two are tried, not all three.
#[tokio::test]
async fn budget_bounds_attempts() {
let s0 = spawn_mock_udp(silent_handler).await;
let s1 = spawn_mock_udp(silent_handler).await;
let s2 = spawn_mock_udp(silent_handler).await;
let configs = vec![udp_config(s0), udp_config(s1), udp_config(s2)];
let tracker = TaskTracker::new();
let pool = UpstreamPool::connect(
&configs,
&tracker,
Arc::new(SequentialSelector),
1, Duration::from_millis(150),
)
.await;
let result = timeout(Duration::from_secs(5), pool.forward(&stock_question()))
.await
.expect("safety timeout");
assert!(
matches!(result, Err(Error::AllUpstreamsFailed { attempts: 2 })),
"expected AllUpstreamsFailed {{ attempts: 2 }}, got: {result:?}"
);
}
// After `store`, subsequent forwards go to the new pool: pool_a answers positively,
// and pool_b (NXDOMAIN) is expected to yield a negative result.
#[tokio::test]
async fn shared_pool_swap_takes_effect() {
let positive_addr = spawn_mock_udp(positive_a_handler).await;
let nxdomain_addr = spawn_mock_udp(nxdomain_handler).await;
let tracker = TaskTracker::new();
// Budget 0: each pool makes exactly one attempt per query.
let pool_a = UpstreamPool::connect(
&[udp_config(positive_addr)],
&tracker,
Arc::new(SequentialSelector),
0, Duration::from_millis(500),
)
.await;
let pool_b = UpstreamPool::connect(
&[udp_config(nxdomain_addr)],
&tracker,
Arc::new(SequentialSelector),
0,
Duration::from_millis(500),
)
.await;
let shared = SharedUpstreamPool::new(pool_a);
let q = stock_question();
let res_a = timeout(Duration::from_secs(5), shared.forward(&q))
.await
.expect("safety timeout")
.expect("pool_a forward must succeed");
assert!(!res_a.is_negative, "pool_a must return a positive answer");
shared.store(pool_b);
let res_b = timeout(Duration::from_secs(5), shared.forward(&q))
.await
.expect("safety timeout")
.expect("pool_b forward must succeed");
assert!(res_b.is_negative, "pool_b must return a negative answer");
}
}