use std::net::SocketAddr;
use std::sync::{Arc, OnceLock};
use std::time::Duration;
use hickory_client::client::Client;
use hickory_client::proto::runtime::TokioRuntimeProvider;
use hickory_client::proto::tcp::TcpClientStream;
use hickory_client::proto::udp::UdpClientStream;
use rustls::ClientConfig;
use super::common::transport::Transport;
pub(super) async fn build_udp_client(addr: SocketAddr, timeout: Duration) -> Option<Client> {
let stream =
UdpClientStream::<TokioRuntimeProvider>::builder(addr, TokioRuntimeProvider::new())
.with_timeout(Some(timeout))
.build();
let (client, bg) = match Client::connect(stream).await {
Ok(pair) => pair,
Err(e) => {
tracing::warn!(upstream = %addr, error = %e, "failed to build UDP DNS client");
return None;
}
};
tokio::spawn(bg);
Some(client)
}
pub(super) async fn build_tcp_client(addr: SocketAddr, timeout: Duration) -> Option<Client> {
let (stream, sender) =
TcpClientStream::new(addr, None, Some(timeout), TokioRuntimeProvider::new());
let (client, bg) = match Client::with_timeout(stream, sender, timeout, None).await {
Ok(pair) => pair,
Err(e) => {
tracing::warn!(upstream = %addr, error = %e, "failed to build TCP DNS client");
return None;
}
};
tokio::spawn(bg);
Some(client)
}
/// Builds a DNS-over-TLS client that talks to `addr`, verifying the server certificate for `sni`.
///
/// `timeout` is used twice:
/// - as the per-request timeout given to `Client::with_timeout`;
/// - as an upper bound on establishing the connection (TCP connect + TLS handshake).
///
/// Returns `None` (after logging a warning) in three cases: `sni` is not a valid server name,
/// the connection fails, or establishing it takes longer than `timeout`. On success, the
/// background future is spawned onto the Tokio runtime and detached.
pub(super) async fn build_dot_client(
    addr: SocketAddr,
    sni: String,
    timeout: Duration,
) -> Option<Client> {
    use hickory_proto::rustls::tls_client_connect;
    use rustls::pki_types::ServerName;

    let server_name = match ServerName::try_from(sni.clone()) {
        Ok(name) => name,
        Err(e) => {
            tracing::warn!(upstream = %addr, sni = %sni, error = %e, "invalid SNI for DoT");
            return None;
        }
    };
    let client_config = dot_upstream_client_config();
    let (stream_future, sender) = tls_client_connect(
        addr,
        server_name,
        client_config,
        TokioRuntimeProvider::new(),
    );

    // `tls_client_connect` takes no timeout, unlike the TCP path, which passes
    // `Some(timeout)` to `TcpClientStream::new`. Awaiting `Client::with_timeout` drives the
    // TCP connect and TLS handshake to completion. Without an outer deadline, an upstream that
    // accepts TCP but never finishes the handshake would stall this call indefinitely.
    // Dropping the future on timeout also drops the half-open connection.
    let connect = Client::with_timeout(stream_future, sender, timeout, None);
    let (client, bg) = match tokio::time::timeout(timeout, connect).await {
        Ok(Ok(pair)) => pair,
        Ok(Err(e)) => {
            tracing::warn!(upstream = %addr, error = %e, "failed to build DoT client");
            return None;
        }
        Err(_elapsed) => {
            tracing::warn!(
                upstream = %addr,
                timeout = ?timeout,
                "timed out establishing DoT connection"
            );
            return None;
        }
    };
    tokio::spawn(bg);
    Some(client)
}
/// Builds a DNS client for `addr` using the requested `transport`.
///
/// `sni` is consulted only for `Transport::Dot`. When it is absent, the textual form of the
/// upstream IP address is used as the TLS server name. `timeout` is forwarded unchanged to the
/// transport-specific builder, and that builder's result is returned as-is.
pub(super) async fn build_direct_client(
    addr: SocketAddr,
    transport: Transport,
    sni: Option<&str>,
    timeout: Duration,
) -> Option<Client> {
    match transport {
        Transport::Udp => build_udp_client(addr, timeout).await,
        Transport::Tcp => build_tcp_client(addr, timeout).await,
        Transport::Dot => {
            let server_name = match sni {
                Some(name) => name.to_owned(),
                None => addr.ip().to_string(),
            };
            build_dot_client(addr, server_name, timeout).await
        }
    }
}
/// Returns the rustls client configuration shared by all DoT upstream connections.
///
/// The configuration is built once, on first use, and cached in a process-wide `OnceLock`. It
/// trusts the platform's native root certificates and uses no client authentication. Every call
/// returns a clone of the same `Arc`.
///
/// Problems with the native trust store are logged rather than treated as fatal:
/// - errors reported while loading it;
/// - certificates rustls rejects;
/// - an empty resulting store.
fn dot_upstream_client_config() -> Arc<ClientConfig> {
    static CONFIG: OnceLock<Arc<ClientConfig>> = OnceLock::new();
    let config = CONFIG.get_or_init(|| {
        let certs = rustls_native_certs::load_native_certs();
        if !certs.errors.is_empty() {
            tracing::warn!(
                count = certs.errors.len(),
                "errors loading native certificates for DoT upstream"
            );
        }

        let mut root_store = rustls::RootCertStore::empty();
        // Certificates rustls cannot parse are skipped instead of failing the whole store. The
        // count is reported so a broken system trust store does not go unnoticed.
        let (added, ignored) = root_store.add_parsable_certificates(certs.certs);
        if ignored > 0 {
            tracing::warn!(
                added,
                ignored,
                "some native root certificates were rejected for DoT upstream"
            );
        }
        if root_store.is_empty() {
            tracing::error!(
                "no native root certificates loaded — DoT upstream will fail to verify any resolver"
            );
        }

        let client_config = ClientConfig::builder()
            .with_root_certificates(root_store)
            .with_no_client_auth();
        Arc::new(client_config)
    });
    Arc::clone(config)
}
#[cfg(test)]
mod tests {
    use super::*;
    use futures::StreamExt;
    use hickory_client::proto::op::{Message, MessageType, OpCode, Query};
    use hickory_client::proto::rr::{Name, RecordType};
    use hickory_client::proto::xfer::{DnsHandle, DnsRequest};
    use std::time::Instant;
    use tokio::io::AsyncReadExt;

    /// Query timeout handed to the client builders under test.
    ///
    /// Keep it well below hickory's 5s default so that a fallback to the default is
    /// distinguishable from honoring this value.
    const QUERY_TIMEOUT: Duration = Duration::from_secs(2);

    /// Hard ceiling on how long a single query may take before the test declares a hang.
    const HANG_LIMIT: Duration = Duration::from_secs(20);

    /// Builds a recursive `A` query for `example.com.` with a fixed message id.
    fn example_query() -> Message {
        let mut msg = Message::new(0x4242, MessageType::Query, OpCode::Query);
        msg.set_recursion_desired(true);
        msg.add_query(Query::query(
            Name::from_ascii("example.com.").unwrap(),
            RecordType::A,
        ));
        msg
    }

    /// Binds a loopback TCP listener that accepts connections and drains them without replying.
    async fn blackhole_tcp() -> SocketAddr {
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        tokio::spawn(async move {
            while let Ok((mut sock, _)) = listener.accept().await {
                tokio::spawn(async move {
                    let mut buf = [0u8; 4096];
                    while sock.read(&mut buf).await.unwrap_or(0) > 0 {}
                });
            }
        });
        addr
    }

    /// Binds a loopback UDP socket that swallows every datagram without replying.
    async fn blackhole_udp() -> SocketAddr {
        let sock = tokio::net::UdpSocket::bind("127.0.0.1:0").await.unwrap();
        let addr = sock.local_addr().unwrap();
        tokio::spawn(async move {
            let mut buf = [0u8; 4096];
            loop {
                let _ = sock.recv_from(&mut buf).await;
            }
        });
        addr
    }

    /// Sends one query through the client produced by `build` to a blackhole upstream.
    ///
    /// Asserts that the query:
    /// - finishes at all (within `HANG_LIMIT`);
    /// - does not yield a successful response;
    /// - gives up in less than twice `query_timeout`.
    async fn assert_upstream_honors_timeout(
        label: &str,
        query_timeout: Duration,
        build: impl std::future::Future<Output = Option<Client>>,
    ) {
        let client = build
            .await
            .unwrap_or_else(|| panic!("failed to build {label} client"));
        let mut responses = client.send(DnsRequest::from(example_query()));

        let start = Instant::now();
        let outcome = tokio::time::timeout(HANG_LIMIT, responses.next()).await;
        let elapsed = start.elapsed();

        let Ok(first) = outcome else {
            panic!(
                "{label} send.next() HUNG > {HANG_LIMIT:?} (no per-request timeout); \
                 elapsed {elapsed:?}"
            );
        };
        assert!(
            !matches!(first, Some(Ok(_))),
            "{label} unexpectedly received a response from a blackhole upstream"
        );
        assert!(
            elapsed < query_timeout * 2,
            "{label} did not honor the {query_timeout:?} query_timeout (elapsed {elapsed:?}); \
             likely fell back to hickory's 5s default"
        );
    }

    #[tokio::test]
    async fn tcp_upstream_honors_query_timeout() {
        let addr = blackhole_tcp().await;
        assert_upstream_honors_timeout("TCP", QUERY_TIMEOUT, build_tcp_client(addr, QUERY_TIMEOUT))
            .await;
    }

    #[tokio::test]
    async fn udp_upstream_honors_query_timeout() {
        let addr = blackhole_udp().await;
        assert_upstream_honors_timeout("UDP", QUERY_TIMEOUT, build_udp_client(addr, QUERY_TIMEOUT))
            .await;
    }
}