use std::{net::SocketAddr, sync::Arc, time::Duration};
use hickory_net::client::Client;
use hickory_net::runtime::TokioRuntimeProvider;
use super::{Error, Result, UpstreamConfig, UpstreamTransport};
/// Timeout handed to hickory when building upstream streams: the UDP stream
/// builder's timeout and the TCP connect timeout (five seconds).
const CONNECT_TIMEOUT: Duration = Duration::from_millis(5_000);
/// Type-erased background future returned alongside every [`UpstreamClient`].
///
/// It is boxed and pinned so the connectors for all transports can return one
/// common type, and it is `Send + 'static` so it can be handed to a spawner.
pub type UpstreamBackground =
    std::pin::Pin<Box<dyn std::future::Future<Output = ()> + 'static + Send>>;
/// A DNS client bound to a single upstream resolver and transport.
///
/// Instances are produced by [`UpstreamClient::connect`]. `Clone` is derived;
/// NOTE(review): whether clones share one connection depends on hickory's
/// `Client` clone semantics — presumably a cheap handle clone, confirm.
#[derive(Clone)]
pub struct UpstreamClient {
    /// Upstream server address, copied from the config used to connect.
    addr: SocketAddr,
    /// Transport the client was opened with, copied from the config.
    transport: UpstreamTransport,
    /// hickory client used to send requests to the upstream.
    handle: Client<TokioRuntimeProvider>,
}
impl std::fmt::Debug for UpstreamClient {
    /// Prints the transport and address. The hickory client handle is left
    /// out, which `finish_non_exhaustive` signals with a trailing `..`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("UpstreamClient");
        builder.field("transport", &self.transport);
        builder.field("addr", &self.addr);
        builder.finish_non_exhaustive()
    }
}
impl UpstreamClient {
pub async fn connect(config: &UpstreamConfig) -> Result<(Self, UpstreamBackground)> {
match config.transport {
UpstreamTransport::Udp => Self::connect_udp(config).await,
UpstreamTransport::Tcp => Self::connect_tcp(config).await,
UpstreamTransport::Dot => Self::connect_dot(config).await,
UpstreamTransport::Doh => Self::connect_doh(config).await,
}
}
pub fn handle(&self) -> &Client<TokioRuntimeProvider> {
&self.handle
}
pub fn transport(&self) -> UpstreamTransport {
self.transport
}
pub fn addr(&self) -> SocketAddr {
self.addr
}
async fn connect_udp(config: &UpstreamConfig) -> Result<(Self, UpstreamBackground)> {
use hickory_net::udp::UdpClientStream;
let provider = TokioRuntimeProvider::default();
let stream = UdpClientStream::builder(config.addr, provider)
.with_timeout(Some(CONNECT_TIMEOUT))
.build();
let (client, bg) = Client::<TokioRuntimeProvider>::from_sender(stream);
Ok((
Self {
handle: client,
transport: config.transport,
addr: config.addr,
},
Box::pin(bg),
))
}
async fn connect_tcp(config: &UpstreamConfig) -> Result<(Self, UpstreamBackground)> {
use hickory_net::tcp::TcpClientStream;
let provider = TokioRuntimeProvider::default();
let (connect, handle) =
TcpClientStream::new(config.addr, None, Some(CONNECT_TIMEOUT), provider);
let stream = connect.await.map_err(|e| Error::Connect {
transport: UpstreamTransport::Tcp,
source: e,
})?;
let (client, bg) = Client::<TokioRuntimeProvider>::new(stream, handle);
Ok((
Self {
handle: client,
transport: config.transport,
addr: config.addr,
},
Box::pin(bg),
))
}
async fn connect_dot(config: &UpstreamConfig) -> Result<(Self, UpstreamBackground)> {
use hickory_net::tls::tls_client_connect;
let name = config
.tls_server_name
.clone()
.ok_or_else(|| Error::InvalidServerName("DoT requires tls_server_name".into()))?;
let server_name = rustls_pki_types::ServerName::try_from(name.clone())
.map_err(|_| Error::InvalidServerName(name))?;
let tls_config = Arc::new(
hickory_net::tls::client_config().map_err(|e| Error::Transport(e.to_string()))?,
);
let provider = TokioRuntimeProvider::default();
let (connect, handle) = tls_client_connect(config.addr, server_name, tls_config, provider);
let stream = connect.await.map_err(|e| Error::Connect {
transport: UpstreamTransport::Dot,
source: e,
})?;
let (client, bg) = Client::<TokioRuntimeProvider>::new(stream, handle);
Ok((
Self {
handle: client,
transport: config.transport,
addr: config.addr,
},
Box::pin(bg),
))
}
async fn connect_doh(config: &UpstreamConfig) -> Result<(Self, UpstreamBackground)> {
use hickory_net::h2::HttpsClientStream;
let name = config
.tls_server_name
.clone()
.ok_or_else(|| Error::InvalidServerName("DoH requires tls_server_name".into()))?;
let path = config
.http_endpoint
.clone()
.unwrap_or_else(|| "/dns-query".to_owned());
let tls_config = Arc::new(
hickory_net::tls::client_config().map_err(|e| Error::Transport(e.to_string()))?,
);
let provider = TokioRuntimeProvider::default();
let stream = HttpsClientStream::builder(tls_config, provider)
.build(
config.addr,
Arc::from(name.as_str()),
Arc::from(path.as_str()),
)
.await
.map_err(|e| Error::Connect {
transport: UpstreamTransport::Doh,
source: e,
})?;
let (client, bg) = Client::<TokioRuntimeProvider>::from_sender(stream);
Ok((
Self {
handle: client,
transport: config.transport,
addr: config.addr,
},
Box::pin(bg),
))
}
}
#[cfg(test)]
mod tests {
    use tokio::time::timeout;

    use super::*;
    use crate::resolver::upstream::UpstreamConfig;

    /// Round-trips one query through a local UDP mock that echoes the request
    /// back with the QR (response) bit set.
    #[tokio::test]
    async fn udp_client_construction_and_roundtrip() {
        use hickory_net::proto::op::{DnsRequest, DnsRequestOptions, Message, Query};
        use hickory_net::proto::rr::{Name, RecordType};
        use hickory_net::xfer::{DnsHandle, FirstAnswer as _};
        use tokio::net::UdpSocket;

        let server_sock = UdpSocket::bind("127.0.0.1:0").await.unwrap();
        let server_addr = server_sock.local_addr().unwrap();
        tokio::spawn(async move {
            let mut buf = vec![0u8; 512];
            let (len, peer) = server_sock.recv_from(&mut buf).await.unwrap();
            if len >= 3 {
                // Byte 2 holds the QR flag in its high bit; set it to mark a response.
                buf[2] |= 0x80;
            }
            server_sock.send_to(&buf[..len], peer).await.unwrap();
        });

        let cfg = UpstreamConfig {
            addr: server_addr,
            transport: UpstreamTransport::Udp,
            tls_server_name: None,
            http_endpoint: None,
        };
        let (client, bg) = UpstreamClient::connect(&cfg).await.unwrap();
        tokio::spawn(bg);

        let name = Name::from_ascii("example.com.").unwrap();
        let query = Query::query(name, RecordType::A);
        let mut msg = Message::query();
        msg.add_query(query);
        let request = DnsRequest::new(msg, DnsRequestOptions::default());

        let response = timeout(
            Duration::from_secs(5),
            client.handle().send(request).first_answer(),
        )
        .await
        .expect("timed out waiting for a UDP response from the local mock")
        .expect("expected a DNS response from the local mock upstream");

        let buf = response.into_buffer();
        assert!(buf.len() >= 3, "response datagram too short: {}", buf.len());
        assert_eq!(buf[2] & 0x80, 0x80, "QR bit must be set on the response");
    }

    /// Connecting over TCP to a local listener succeeds within the timeout.
    #[tokio::test]
    async fn tcp_client_construction() {
        use tokio::net::TcpListener;

        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        tokio::spawn(async move {
            let _ = listener.accept().await;
        });

        let cfg = UpstreamConfig {
            addr,
            transport: UpstreamTransport::Tcp,
            tls_server_name: None,
            http_endpoint: None,
        };
        let (_client, bg) = timeout(Duration::from_secs(5), UpstreamClient::connect(&cfg))
            .await
            .expect("timeout during TCP connect")
            .expect("TCP connect failed");
        tokio::spawn(bg);
    }

    /// DoT and DoH must reject a config without `tls_server_name` before doing
    /// any network I/O.
    #[tokio::test]
    async fn tls_transports_require_server_name() {
        for transport in [UpstreamTransport::Dot, UpstreamTransport::Doh] {
            let cfg = UpstreamConfig {
                addr: "127.0.0.1:853".parse().unwrap(),
                transport,
                tls_server_name: None,
                http_endpoint: None,
            };
            let result = UpstreamClient::connect(&cfg).await;
            assert!(
                matches!(result, Err(Error::InvalidServerName(_))),
                "{transport:?} without tls_server_name must fail with InvalidServerName"
            );
        }
    }

    #[ignore = "requires network access to 1.1.1.1:853"]
    #[tokio::test]
    async fn dot_client_connects_to_cloudflare() {
        let cfg = UpstreamConfig {
            addr: "1.1.1.1:853".parse().unwrap(),
            transport: UpstreamTransport::Dot,
            tls_server_name: Some("cloudflare-dns.com".to_owned()),
            http_endpoint: None,
        };
        let (client, bg) = UpstreamClient::connect(&cfg)
            .await
            .expect("DoT connect failed");
        tokio::spawn(bg);
        assert_eq!(client.transport(), UpstreamTransport::Dot);
    }

    #[ignore = "requires network access to 1.1.1.1:443"]
    #[tokio::test]
    async fn doh_client_connects_to_cloudflare() {
        let cfg = UpstreamConfig {
            addr: "1.1.1.1:443".parse().unwrap(),
            transport: UpstreamTransport::Doh,
            tls_server_name: Some("cloudflare-dns.com".to_owned()),
            http_endpoint: Some("/dns-query".to_owned()),
        };
        let (client, bg) = UpstreamClient::connect(&cfg)
            .await
            .expect("DoH connect failed");
        tokio::spawn(bg);
        assert_eq!(client.transport(), UpstreamTransport::Doh);
    }
}