use std::net::SocketAddr;
use std::sync::{Arc, OnceLock};
use std::time::Duration;
use hickory_client::client::Client;
use hickory_client::proto::runtime::TokioRuntimeProvider;
use hickory_client::proto::tcp::TcpClientStream;
use hickory_client::proto::udp::UdpClientStream;
use rustls::ClientConfig;
use super::common::transport::Transport;
pub(super) async fn build_udp_client(addr: SocketAddr, timeout: Duration) -> Option<Client> {
let stream =
UdpClientStream::<TokioRuntimeProvider>::builder(addr, TokioRuntimeProvider::new())
.with_timeout(Some(timeout))
.build();
let (client, bg) = match Client::connect(stream).await {
Ok(pair) => pair,
Err(e) => {
tracing::warn!(upstream = %addr, error = %e, "failed to build UDP DNS client");
return None;
}
};
tokio::spawn(bg);
Some(client)
}
/// Builds a DNS client that talks to `addr` over plain TCP.
///
/// `timeout` is used twice: as the TCP connect timeout for the stream, and as
/// the per-request timeout of the client. The DoT builder treats its `timeout`
/// the same way. On success the client's background exchange task is spawned
/// onto the current Tokio runtime. On failure a warning is logged and `None`
/// is returned.
pub(super) async fn build_tcp_client(addr: SocketAddr, timeout: Duration) -> Option<Client> {
    // The `Some(timeout)` here bounds only establishing the TCP connection;
    // it does not limit how long an individual query may take afterwards.
    let (stream, sender) =
        TcpClientStream::new(addr, None, Some(timeout), TokioRuntimeProvider::new());
    // `Client::new` would use hickory's built-in default request timeout and
    // ignore the caller's `timeout`, so pass it explicitly, as the DoT path does.
    let (client, bg) = match Client::with_timeout(stream, sender, timeout, None).await {
        Ok(pair) => pair,
        Err(e) => {
            tracing::warn!(upstream = %addr, error = %e, "failed to build TCP DNS client");
            return None;
        }
    };
    tokio::spawn(bg);
    Some(client)
}
pub(super) async fn build_dot_client(
addr: SocketAddr,
sni: String,
timeout: Duration,
) -> Option<Client> {
use hickory_proto::rustls::tls_client_connect;
use rustls::pki_types::ServerName;
let server_name = match ServerName::try_from(sni.clone()) {
Ok(name) => name,
Err(e) => {
tracing::warn!(upstream = %addr, sni = %sni, error = %e, "invalid SNI for DoT");
return None;
}
};
let client_config = dot_upstream_client_config();
let (stream_future, sender) = tls_client_connect(
addr,
server_name,
client_config,
TokioRuntimeProvider::new(),
);
let (client, bg) = match Client::with_timeout(stream_future, sender, timeout, None).await {
Ok(pair) => pair,
Err(e) => {
tracing::warn!(upstream = %addr, error = %e, "failed to build DoT client");
return None;
}
};
tokio::spawn(bg);
Some(client)
}
/// Builds an upstream DNS client for `addr` using the requested `transport`.
///
/// `sni` is consulted only for `Transport::Dot`. In that case it is used as the
/// TLS server name, or, when absent, the textual form of `addr`'s IP address is
/// used instead. `timeout` is forwarded to whichever builder is selected.
/// Returns `None` if that builder fails.
pub(super) async fn build_direct_client(
    addr: SocketAddr,
    transport: Transport,
    sni: Option<&str>,
    timeout: Duration,
) -> Option<Client> {
    match transport {
        Transport::Dot => {
            // No explicit SNI: fall back to the upstream's IP literal as the server name.
            let server_name = sni.map_or_else(|| addr.ip().to_string(), str::to_owned);
            build_dot_client(addr, server_name, timeout).await
        }
        Transport::Tcp => build_tcp_client(addr, timeout).await,
        Transport::Udp => build_udp_client(addr, timeout).await,
    }
}
/// Returns the process-wide rustls client config used for DoT upstreams.
///
/// The config is built lazily on first call from the platform's native root
/// certificates, with no client authentication. It is cached in a `OnceLock`
/// for the rest of the process, and every call returns a clone of the same
/// `Arc`. Problems loading or parsing the native roots are logged rather than
/// treated as fatal. An empty trust store is reported at error level, because
/// every DoT certificate verification will then fail.
fn dot_upstream_client_config() -> Arc<ClientConfig> {
    static CONFIG: OnceLock<Arc<ClientConfig>> = OnceLock::new();
    let config = CONFIG.get_or_init(|| {
        let mut root_store = rustls::RootCertStore::empty();
        let certs = rustls_native_certs::load_native_certs();
        if !certs.errors.is_empty() {
            tracing::warn!(
                count = certs.errors.len(),
                "errors loading native certificates for DoT upstream"
            );
        }
        // Skip certificates rustls cannot parse instead of aborting. Report how
        // many were dropped rather than discarding each error silently.
        let (_added, ignored) = root_store.add_parsable_certificates(certs.certs);
        if ignored > 0 {
            tracing::warn!(
                count = ignored,
                "ignored unparsable native certificates for DoT upstream"
            );
        }
        if root_store.is_empty() {
            tracing::error!(
                "no native root certificates loaded — DoT upstream will fail to verify any resolver"
            );
        }
        let client_config = ClientConfig::builder()
            .with_root_certificates(root_store)
            .with_no_client_auth();
        Arc::new(client_config)
    });
    // Refcount bump only; the config itself is shared.
    Arc::clone(config)
}