use std::future::Future;
use std::marker::Unpin;
use std::net::{IpAddr, SocketAddr};
#[cfg(feature = "__quic")]
use std::net::{Ipv4Addr, Ipv6Addr};
use std::pin::Pin;
#[cfg(any(feature = "__tls", feature = "__https"))]
use std::sync::Arc;
#[cfg(feature = "__https")]
use hickory_net::h2::HttpsClientStream;
#[cfg(feature = "__tls")]
use rustls::DigitallySignedStruct;
#[cfg(feature = "__tls")]
use rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier};
#[cfg(feature = "__tls")]
use rustls::crypto::{CryptoProvider, verify_tls12_signature, verify_tls13_signature};
#[cfg(feature = "__tls")]
use rustls::pki_types::{CertificateDer, ServerName, UnixTime};
#[cfg(not(feature = "__tls"))]
use tracing::warn;
#[cfg(feature = "__h3")]
use crate::net::h3::H3ClientStream;
#[cfg(feature = "__quic")]
use crate::net::quic::QuicClientStream;
#[cfg(feature = "__tls")]
use crate::net::tls::{client_config, default_provider, tls_exchange};
use crate::{
config::{ConnectionConfig, ProtocolConfig},
name_server_pool::PoolContext,
net::{
NetError,
runtime::RuntimeProvider,
tcp::TcpClientStream,
udp::UdpClientStream,
xfer::{DnsExchange, DnsHandle},
},
};
/// Source of new name server connections.
///
/// Each call to [`Self::new_connection`] yields a future that, once awaited, produces a
/// cloneable [`Self::Conn`] handle (a [`DnsHandle`]) for talking to a single name server.
pub trait ConnectionProvider: Clone + Send + Sync + Unpin + 'static {
    /// Handle type produced by a successfully established connection.
    type Conn: DnsHandle + Clone + Send + Sync + 'static;

    /// Future that resolves to a connected [`Self::Conn`], or to a [`NetError`].
    type FutureConn: Future<Output = Result<Self::Conn, NetError>> + Send + 'static;

    /// Runtime abstraction exposed through [`Self::runtime_provider`].
    type RuntimeProvider: RuntimeProvider;

    /// Begins connecting to `ip`, as described by `config`, within the pool context `cx`.
    ///
    /// # Errors
    ///
    /// Implementations may reject a connection request synchronously (returning `Err`
    /// instead of a future) when it cannot even be attempted.
    fn new_connection(
        &self,
        ip: IpAddr,
        config: &ConnectionConfig,
        cx: &PoolContext,
    ) -> Result<Self::FutureConn, NetError>;

    /// Returns the runtime provider backing this connection provider.
    fn runtime_provider(&self) -> &Self::RuntimeProvider;
}
/// Blanket [`ConnectionProvider`] for every [`RuntimeProvider`].
///
/// All protocols share one boxed `FutureConn` type. Every established connection is a
/// [`DnsExchange`] driven by the runtime `P` itself.
impl<P: RuntimeProvider> ConnectionProvider for P {
    type Conn = DnsExchange<P>;
    type FutureConn = Pin<Box<dyn Future<Output = Result<Self::Conn, NetError>> + Send + 'static>>;
    type RuntimeProvider = P;

    /// Starts connecting to `ip` on `config.port`, using the protocol selected in `config`.
    ///
    /// Timeouts and UDP port-selection settings are read from `cx.options`. The encrypted
    /// protocols are handed `cx.tls` as their TLS client configuration.
    ///
    /// # Errors
    ///
    /// Returns `Err` immediately, without creating a future, when:
    /// - a TLS server name cannot be converted into a rustls `ServerName`;
    /// - QUIC or HTTP/3 is requested but `self.quic_binder()` returns `None`;
    /// - binding the local QUIC socket fails.
    fn new_connection(
        &self,
        ip: IpAddr,
        config: &ConnectionConfig,
        cx: &PoolContext,
    ) -> Result<Self::FutureConn, NetError> {
        let remote_addr = SocketAddr::new(ip, config.port);
        match (&config.protocol, self.quic_binder()) {
            (ProtocolConfig::Udp, _) => {
                // Copy out everything the UDP builder needs, so the boxed future owns its
                // inputs and satisfies the `'static` bound on `FutureConn`.
                let (timeout, os_port_selection, avoid_local_udp_ports, bind_addr, provider) = (
                    cx.options.timeout,
                    cx.options.os_port_selection,
                    cx.options.avoid_local_udp_ports.clone(),
                    config.bind_addr,
                    self.clone(),
                );
                Ok(Box::pin(async move {
                    Ok(UdpClientStream::builder(remote_addr, provider)
                        .with_timeout(Some(timeout))
                        .with_os_port_selection(os_port_selection)
                        .avoid_local_ports(avoid_local_udp_ports)
                        .with_bind_addr(bind_addr)
                        .exchange())
                }))
            }
            (ProtocolConfig::Tcp, _) => Ok(Box::pin(TcpClientStream::exchange(
                remote_addr,
                config.bind_addr,
                cx.options.timeout,
                self.clone(),
            ))),
            #[cfg(feature = "__tls")]
            (ProtocolConfig::Tls { server_name }, _) => {
                // The `else` branch still sees the original (unparsed) `server_name`.
                let Ok(server_name) = ServerName::try_from(&**server_name) else {
                    return Err(NetError::from(format!(
                        "invalid server name: {server_name}"
                    )));
                };
                let server_name = server_name.to_owned();
                Ok(Box::pin(tls_exchange(
                    remote_addr,
                    server_name,
                    cx.tls.clone(),
                    cx.options.timeout,
                    self.clone(),
                )))
            }
            #[cfg(feature = "__https")]
            (ProtocolConfig::Https { server_name, path }, _) => Ok(Box::pin(
                HttpsClientStream::builder(Arc::new(cx.tls.clone()), self.clone()).exchange(
                    remote_addr,
                    server_name.clone(),
                    path.clone(),
                ),
            )),
            #[cfg(feature = "__quic")]
            (ProtocolConfig::Quic { server_name }, Some(binder)) => {
                // Without an explicit bind address, bind the wildcard address of the
                // remote's IP family.
                let bind_addr = config
                    .bind_addr
                    .unwrap_or_else(|| unspecified_bind_addr(remote_addr));
                Ok(Box::pin(
                    QuicClientStream::builder()
                        .crypto_config(cx.tls.clone())
                        .exchange(
                            binder.bind_quic(bind_addr, remote_addr)?,
                            remote_addr,
                            server_name.clone(),
                            self.clone(),
                        ),
                ))
            }
            #[cfg(feature = "__h3")]
            (
                ProtocolConfig::H3 {
                    server_name,
                    path,
                    disable_grease,
                },
                Some(binder),
            ) => {
                // Same default bind address policy as plain QUIC.
                let bind_addr = config
                    .bind_addr
                    .unwrap_or_else(|| unspecified_bind_addr(remote_addr));
                Ok(Box::pin(
                    H3ClientStream::builder()
                        .crypto_config(cx.tls.clone())
                        .disable_grease(*disable_grease)
                        .exchange(
                            binder.bind_quic(bind_addr, remote_addr)?,
                            remote_addr,
                            server_name.clone(),
                            path.clone(),
                            self.clone(),
                        ),
                ))
            }
            #[cfg(feature = "__quic")]
            (ProtocolConfig::Quic { .. }, None) => {
                Err(NetError::from("runtime provider does not support QUIC"))
            }
            #[cfg(feature = "__h3")]
            (ProtocolConfig::H3 { .. }, None) => {
                Err(NetError::from("runtime provider does not support QUIC"))
            }
        }
    }

    /// The blanket impl is its own runtime provider.
    fn runtime_provider(&self) -> &Self::RuntimeProvider {
        self
    }
}

/// Returns the unspecified (wildcard) address of the same IP family as `remote`, with port 0.
///
/// Used as the local QUIC/H3 bind address when the connection config supplies none.
#[cfg(feature = "__quic")]
fn unspecified_bind_addr(remote: SocketAddr) -> SocketAddr {
    let ip = match remote {
        SocketAddr::V4(_) => IpAddr::V4(Ipv4Addr::UNSPECIFIED),
        SocketAddr::V6(_) => IpAddr::V6(Ipv6Addr::UNSPECIFIED),
    };
    SocketAddr::new(ip, 0)
}
/// Wrapper around the TLS client configuration.
///
/// When TLS support is not compiled in (`__tls` disabled), this type carries no data.
#[derive(Debug)]
pub struct TlsConfig {
    /// The rustls client configuration.
    #[cfg(feature = "__tls")]
    pub config: rustls::ClientConfig,
}
impl TlsConfig {
    /// Builds a `TlsConfig` from the crate's default client configuration.
    ///
    /// # Errors
    ///
    /// With `__tls` enabled, propagates any error from `client_config()`. Without TLS
    /// support compiled in, this always succeeds.
    pub fn new() -> Result<Self, NetError> {
        #[cfg(feature = "__tls")]
        let config = client_config()?;
        Ok(Self {
            #[cfg(feature = "__tls")]
            config,
        })
    }

    /// Replaces the certificate verifier with one that accepts any server certificate.
    ///
    /// This is dangerous: it removes server authentication. The installed verifier still
    /// checks TLS 1.2/1.3 handshake signatures using the default crypto provider's
    /// algorithms.
    #[cfg(feature = "__tls")]
    pub fn insecure_skip_verify(&mut self) {
        let verifier = Arc::new(NoCertificateVerification::default());
        self.config.dangerous().set_certificate_verifier(verifier);
    }

    /// Without TLS support there is no verifier to replace, so only a warning is logged.
    #[cfg(not(feature = "__tls"))]
    pub fn insecure_skip_verify(&mut self) {
        warn!("asked to skip TLS verification without TLS support");
    }
}
/// Certificate verifier that trusts every server certificate it is shown.
///
/// The wrapped `CryptoProvider` supplies the signature verification algorithms that are
/// still used to check handshake signatures.
#[cfg(feature = "__tls")]
#[derive(Debug)]
struct NoCertificateVerification(CryptoProvider);

#[cfg(feature = "__tls")]
impl Default for NoCertificateVerification {
    /// Wraps the crate's default crypto provider.
    fn default() -> Self {
        let provider = default_provider();
        NoCertificateVerification(provider)
    }
}
#[cfg(feature = "__tls")]
impl ServerCertVerifier for NoCertificateVerification {
    /// Accepts the presented certificate chain without inspecting any of it.
    fn verify_server_cert(
        &self,
        _leaf: &CertificateDer<'_>,
        _chain: &[CertificateDer<'_>],
        _expected_name: &ServerName<'_>,
        _ocsp_response: &[u8],
        _at: UnixTime,
    ) -> Result<ServerCertVerified, rustls::Error> {
        Ok(ServerCertVerified::assertion())
    }

    /// Checks a TLS 1.2 handshake signature against the wrapped provider's algorithms.
    fn verify_tls12_signature(
        &self,
        msg: &[u8],
        certificate: &CertificateDer<'_>,
        signature: &DigitallySignedStruct,
    ) -> Result<HandshakeSignatureValid, rustls::Error> {
        let algorithms = &self.0.signature_verification_algorithms;
        verify_tls12_signature(msg, certificate, signature, algorithms)
    }

    /// Checks a TLS 1.3 handshake signature against the wrapped provider's algorithms.
    fn verify_tls13_signature(
        &self,
        msg: &[u8],
        certificate: &CertificateDer<'_>,
        signature: &DigitallySignedStruct,
    ) -> Result<HandshakeSignatureValid, rustls::Error> {
        let algorithms = &self.0.signature_verification_algorithms;
        verify_tls13_signature(msg, certificate, signature, algorithms)
    }

    /// Lists the signature schemes the wrapped provider can verify.
    fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
        let algorithms = &self.0.signature_verification_algorithms;
        algorithms.supported_schemes()
    }
}
#[cfg(all(
    test,
    feature = "tokio",
    any(feature = "webpki-roots", feature = "rustls-platform-verifier"),
    any(
        feature = "__tls",
        feature = "__https",
        feature = "__quic",
        feature = "__h3"
    )
))]
mod tests {
    #[cfg(feature = "__quic")]
    use std::net::IpAddr;
    use test_support::subscribe;
    use crate::TokioResolver;
    #[cfg(any(feature = "__tls", feature = "__https"))]
    use crate::config::CLOUDFLARE;
    #[cfg(any(
        feature = "__tls",
        feature = "__https",
        feature = "__quic",
        feature = "__h3"
    ))]
    use crate::config::GOOGLE;
    use crate::config::ResolverConfig;
    #[cfg(feature = "__quic")]
    use crate::config::ServerGroup;
    #[cfg(feature = "__quic")]
    use crate::config::ServerOrderingStrategy;
    use crate::net::runtime::TokioRuntimeProvider;
    #[cfg(feature = "__quic")]
    use crate::net::tls::client_config;

    /// Resolves `www.example.com.` once and asserts at least one address is returned.
    async fn assert_resolves(resolver: &TokioResolver) {
        let addrs = resolver
            .lookup_ip("www.example.com.")
            .await
            .expect("failed to run lookup");
        assert_ne!(addrs.iter().count(), 0);
    }

    #[cfg(feature = "__h3")]
    #[tokio::test]
    async fn test_google_h3() {
        subscribe();
        h3_test(ResolverConfig::h3(&GOOGLE)).await
    }

    /// Builds a resolver that keeps the configured server order, then resolves twice.
    #[cfg(feature = "__h3")]
    async fn h3_test(config: ResolverConfig) {
        let mut builder =
            TokioResolver::builder_with_config(config, TokioRuntimeProvider::default());
        builder.options_mut().server_ordering_strategy = ServerOrderingStrategy::UserProvidedOrder;
        let resolver = builder.build().unwrap();
        for _ in 0..2 {
            assert_resolves(&resolver).await;
        }
    }

    #[cfg(feature = "__quic")]
    #[tokio::test]
    async fn test_adguard_quic() {
        subscribe();
        let tls_config = client_config().unwrap();
        let adguard = ServerGroup {
            ips: &[
                IpAddr::from([94, 140, 14, 140]),
                IpAddr::from([94, 140, 14, 141]),
                IpAddr::from([0x2a10, 0x50c0, 0, 0, 0, 0, 0x1, 0xff]),
                IpAddr::from([0x2a10, 0x50c0, 0, 0, 0, 0, 0x2, 0xff]),
            ],
            server_name: "unfiltered.adguard-dns.com",
            path: "/dns-query",
        };
        quic_test(ResolverConfig::quic(&adguard), tls_config).await
    }

    /// Builds a QUIC resolver with the given TLS config, then resolves twice.
    #[cfg(feature = "__quic")]
    async fn quic_test(config: ResolverConfig, tls_config: rustls::ClientConfig) {
        let mut builder =
            TokioResolver::builder_with_config(config, TokioRuntimeProvider::default());
        {
            let opts = builder.options_mut();
            opts.try_tcp_on_error = true;
            opts.server_ordering_strategy = ServerOrderingStrategy::UserProvidedOrder;
        }
        let resolver = builder.with_tls_config(tls_config).build().unwrap();
        for _ in 0..2 {
            assert_resolves(&resolver).await;
        }
    }

    #[cfg(feature = "__https")]
    #[tokio::test]
    async fn test_google_https() {
        subscribe();
        https_test(ResolverConfig::https(&GOOGLE)).await
    }

    #[cfg(feature = "__https")]
    #[tokio::test]
    async fn test_cloudflare_https() {
        subscribe();
        https_test(ResolverConfig::https(&CLOUDFLARE)).await
    }

    /// Builds a DNS-over-HTTPS resolver, then resolves twice.
    #[cfg(feature = "__https")]
    async fn https_test(config: ResolverConfig) {
        let mut builder =
            TokioResolver::builder_with_config(config, TokioRuntimeProvider::default());
        builder.options_mut().try_tcp_on_error = true;
        let resolver = builder.build().unwrap();
        for _ in 0..2 {
            assert_resolves(&resolver).await;
        }
    }

    #[cfg(feature = "__tls")]
    #[tokio::test]
    async fn test_google_tls() {
        subscribe();
        tls_test(ResolverConfig::tls(&GOOGLE)).await
    }

    #[cfg(feature = "__tls")]
    #[tokio::test]
    async fn test_cloudflare_tls() {
        subscribe();
        tls_test(ResolverConfig::tls(&CLOUDFLARE)).await
    }

    /// Builds a DNS-over-TLS resolver, then resolves once.
    #[cfg(feature = "__tls")]
    async fn tls_test(config: ResolverConfig) {
        let mut builder =
            TokioResolver::builder_with_config(config, TokioRuntimeProvider::default());
        builder.options_mut().try_tcp_on_error = true;
        let resolver = builder.build().unwrap();
        assert_resolves(&resolver).await;
    }
}