1use std::future::Future;
9use std::marker::Unpin;
10use std::net::{IpAddr, SocketAddr};
11#[cfg(feature = "__quic")]
12use std::net::{Ipv4Addr, Ipv6Addr};
13use std::pin::Pin;
14#[cfg(any(feature = "__tls", feature = "__https"))]
15use std::sync::Arc;
16
17#[cfg(feature = "__https")]
18use hickory_net::h2::HttpsClientStream;
19#[cfg(feature = "__tls")]
20use rustls::DigitallySignedStruct;
21#[cfg(feature = "__tls")]
22use rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier};
23#[cfg(feature = "__tls")]
24use rustls::crypto::{CryptoProvider, verify_tls12_signature, verify_tls13_signature};
25#[cfg(feature = "__tls")]
26use rustls::pki_types::{CertificateDer, ServerName, UnixTime};
27#[cfg(not(feature = "__tls"))]
28use tracing::warn;
29
30#[cfg(feature = "__h3")]
31use crate::net::h3::H3ClientStream;
32#[cfg(feature = "__quic")]
33use crate::net::quic::QuicClientStream;
34#[cfg(feature = "__tls")]
35use crate::net::tls::{client_config, default_provider, tls_exchange};
36use crate::{
37 config::{ConnectionConfig, ProtocolConfig},
38 name_server_pool::PoolContext,
39 net::{
40 NetError,
41 runtime::RuntimeProvider,
42 tcp::TcpClientStream,
43 udp::UdpClientStream,
44 xfer::{DnsExchange, DnsHandle},
45 },
46};
47
48pub trait ConnectionProvider: 'static + Clone + Send + Sync + Unpin {
51 type Conn: DnsHandle + Clone + Send + Sync + 'static;
53 type FutureConn: Future<Output = Result<Self::Conn, NetError>> + Send + 'static;
55 type RuntimeProvider: RuntimeProvider;
57
58 fn new_connection(
60 &self,
61 ip: IpAddr,
62 config: &ConnectionConfig,
63 cx: &PoolContext,
64 ) -> Result<Self::FutureConn, NetError>;
65
66 fn runtime_provider(&self) -> &Self::RuntimeProvider;
68}
69
70impl<P: RuntimeProvider> ConnectionProvider for P {
71 type Conn = DnsExchange<P>;
72 type FutureConn = Pin<Box<dyn Future<Output = Result<Self::Conn, NetError>> + Send + 'static>>;
73 type RuntimeProvider = P;
74
75 fn new_connection(
76 &self,
77 ip: IpAddr,
78 config: &ConnectionConfig,
79 cx: &PoolContext,
80 ) -> Result<Self::FutureConn, NetError> {
81 let remote_addr = SocketAddr::new(ip, config.port);
82 match (&config.protocol, self.quic_binder()) {
83 (ProtocolConfig::Udp, _) => {
84 let (timeout, os_port_selection, avoid_local_udp_ports, bind_addr, provider) = (
85 cx.options.timeout,
86 cx.options.os_port_selection,
87 cx.options.avoid_local_udp_ports.clone(),
88 config.bind_addr,
89 self.clone(),
90 );
91
92 Ok(Box::pin(async move {
93 Ok(UdpClientStream::builder(remote_addr, provider)
94 .with_timeout(Some(timeout))
95 .with_os_port_selection(os_port_selection)
96 .avoid_local_ports(avoid_local_udp_ports)
97 .with_bind_addr(bind_addr)
98 .exchange())
99 }))
100 }
101 (ProtocolConfig::Tcp, _) => Ok(Box::pin(TcpClientStream::exchange(
102 remote_addr,
103 config.bind_addr,
104 cx.options.timeout,
105 Some(cx.options.max_active_requests),
106 self.clone(),
107 ))),
108 #[cfg(feature = "__tls")]
109 (ProtocolConfig::Tls { server_name }, _) => {
110 let Ok(server_name) = ServerName::try_from(&**server_name) else {
111 return Err(NetError::from(format!(
112 "invalid server name: {server_name}"
113 )));
114 };
115
116 let server_name = server_name.to_owned();
117 Ok(Box::pin(tls_exchange(
118 remote_addr,
119 server_name,
120 cx.tls.clone(),
121 cx.options.timeout,
122 Some(cx.options.max_active_requests),
123 self.clone(),
124 )))
125 }
126 #[cfg(feature = "__https")]
127 (ProtocolConfig::Https { server_name, path }, _) => Ok(Box::pin(
128 HttpsClientStream::builder(Arc::new(cx.tls.clone()), self.clone()).exchange(
129 remote_addr,
130 server_name.clone(),
131 path.clone(),
132 ),
133 )),
134
135 #[cfg(feature = "__quic")]
136 (ProtocolConfig::Quic { server_name }, Some(binder)) => {
137 let bind_addr = config.bind_addr.unwrap_or(match remote_addr {
138 SocketAddr::V4(_) => SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0),
139 SocketAddr::V6(_) => SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0),
140 });
141
142 Ok(Box::pin(
143 QuicClientStream::builder()
144 .crypto_config(cx.tls.clone())
145 .exchange(
146 binder.bind_quic(bind_addr, remote_addr)?,
147 remote_addr,
148 server_name.clone(),
149 self.clone(),
150 ),
151 ))
152 }
153 #[cfg(feature = "__h3")]
154 (
155 ProtocolConfig::H3 {
156 server_name,
157 path,
158 disable_grease,
159 },
160 Some(binder),
161 ) => {
162 let bind_addr = config.bind_addr.unwrap_or(match remote_addr {
163 SocketAddr::V4(_) => SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0),
164 SocketAddr::V6(_) => SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0),
165 });
166
167 Ok(Box::pin(
168 H3ClientStream::builder()
169 .crypto_config(cx.tls.clone())
170 .disable_grease(*disable_grease)
171 .exchange(
172 binder.bind_quic(bind_addr, remote_addr)?,
173 remote_addr,
174 server_name.clone(),
175 path.clone(),
176 self.clone(),
177 ),
178 ))
179 }
180 #[cfg(feature = "__quic")]
181 (ProtocolConfig::Quic { .. }, None) => {
182 Err(NetError::from("runtime provider does not support QUIC"))
183 }
184 #[cfg(feature = "__h3")]
185 (ProtocolConfig::H3 { .. }, None) => {
186 Err(NetError::from("runtime provider does not support QUIC"))
187 }
188 }
189 }
190
191 fn runtime_provider(&self) -> &Self::RuntimeProvider {
192 self
193 }
194}
195
/// Client-side TLS settings.
///
/// The struct is empty unless the `__tls` feature is enabled. With it enabled, it wraps a
/// rustls client configuration.
pub struct TlsConfig {
    /// Underlying rustls client configuration.
    #[cfg(feature = "__tls")]
    pub config: rustls::ClientConfig,
}
202
203impl TlsConfig {
204 pub fn new() -> Result<Self, NetError> {
206 Ok(Self {
207 #[cfg(feature = "__tls")]
208 config: client_config()?,
209 })
210 }
211
212 #[cfg(feature = "__tls")]
217 pub fn insecure_skip_verify(&mut self) {
218 self.config
219 .dangerous()
220 .set_certificate_verifier(Arc::new(NoCertificateVerification::default()))
221 }
222
223 #[cfg(not(feature = "__tls"))]
228 pub fn insecure_skip_verify(&mut self) {
229 warn!("asked to skip TLS verification without TLS support")
230 }
231}
232
/// Certificate verifier that accepts every server certificate.
///
/// The wrapped [`CryptoProvider`] is only consulted for its signature verification algorithms.
/// Those are used to check handshake signatures and to advertise the supported schemes.
#[cfg(feature = "__tls")]
#[derive(Debug)]
struct NoCertificateVerification(CryptoProvider);
240
/// Wraps the crate's `default_provider()`.
#[cfg(feature = "__tls")]
impl Default for NoCertificateVerification {
    fn default() -> Self {
        let provider = default_provider();
        Self(provider)
    }
}
247
#[cfg(feature = "__tls")]
impl ServerCertVerifier for NoCertificateVerification {
    /// Unconditionally accepts the presented certificate chain; nothing is inspected.
    fn verify_server_cert(
        &self,
        _leaf: &CertificateDer<'_>,
        _chain: &[CertificateDer<'_>],
        _expected_name: &ServerName<'_>,
        _ocsp_response: &[u8],
        _at: UnixTime,
    ) -> Result<ServerCertVerified, rustls::Error> {
        Ok(ServerCertVerified::assertion())
    }

    /// Checks a TLS 1.2 handshake signature with the wrapped provider's algorithms.
    fn verify_tls12_signature(
        &self,
        msg: &[u8],
        certificate: &CertificateDer<'_>,
        signature: &DigitallySignedStruct,
    ) -> Result<HandshakeSignatureValid, rustls::Error> {
        let algorithms = &self.0.signature_verification_algorithms;
        verify_tls12_signature(msg, certificate, signature, algorithms)
    }

    /// Checks a TLS 1.3 handshake signature with the wrapped provider's algorithms.
    fn verify_tls13_signature(
        &self,
        msg: &[u8],
        certificate: &CertificateDer<'_>,
        signature: &DigitallySignedStruct,
    ) -> Result<HandshakeSignatureValid, rustls::Error> {
        let algorithms = &self.0.signature_verification_algorithms;
        verify_tls13_signature(msg, certificate, signature, algorithms)
    }

    /// Advertises the signature schemes supported by the wrapped provider.
    fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
        let algorithms = &self.0.signature_verification_algorithms;
        algorithms.supported_schemes()
    }
}
293
#[cfg(all(
    test,
    feature = "tokio",
    any(feature = "webpki-roots", feature = "rustls-platform-verifier"),
    any(
        feature = "__tls",
        feature = "__https",
        feature = "__quic",
        feature = "__h3"
    )
))]
mod tests {
    #[cfg(feature = "__quic")]
    use std::net::IpAddr;

    use test_support::subscribe;

    use crate::TokioResolver;
    #[cfg(any(feature = "__tls", feature = "__https"))]
    use crate::config::CLOUDFLARE;
    #[cfg(any(
        feature = "__tls",
        feature = "__https",
        feature = "__quic",
        feature = "__h3"
    ))]
    use crate::config::GOOGLE;
    use crate::config::ResolverConfig;
    #[cfg(feature = "__quic")]
    use crate::config::ServerGroup;
    #[cfg(feature = "__quic")]
    use crate::config::ServerOrderingStrategy;
    use crate::net::runtime::TokioRuntimeProvider;
    #[cfg(feature = "__quic")]
    use crate::net::tls::client_config;

    /// Looks up `www.example.com.` through `resolver` and asserts at least one IP came back.
    async fn assert_resolves_example(resolver: &TokioResolver) {
        let lookup = resolver
            .lookup_ip("www.example.com.")
            .await
            .expect("failed to run lookup");
        assert_ne!(lookup.iter().count(), 0);
    }

    #[cfg(feature = "__h3")]
    #[tokio::test]
    async fn test_google_h3() {
        subscribe();
        h3_test(ResolverConfig::h3(&GOOGLE)).await
    }

    /// Queries twice over HTTP/3, trying servers in the order the config lists them.
    #[cfg(feature = "__h3")]
    async fn h3_test(config: ResolverConfig) {
        let mut resolver_builder =
            TokioResolver::builder_with_config(config, TokioRuntimeProvider::default());
        resolver_builder.options_mut().server_ordering_strategy =
            ServerOrderingStrategy::UserProvidedOrder;
        let resolver = resolver_builder.build().unwrap();

        for _ in 0..2 {
            assert_resolves_example(&resolver).await;
        }
    }

    #[cfg(feature = "__quic")]
    #[tokio::test]
    async fn test_adguard_quic() {
        subscribe();

        let tls_config = client_config().unwrap();

        let adguard = ServerGroup {
            ips: &[
                IpAddr::from([94, 140, 14, 140]),
                IpAddr::from([94, 140, 14, 141]),
                IpAddr::from([0x2a10, 0x50c0, 0, 0, 0, 0, 0x1, 0xff]),
                IpAddr::from([0x2a10, 0x50c0, 0, 0, 0, 0, 0x2, 0xff]),
            ],
            server_name: "unfiltered.adguard-dns.com",
            path: "/dns-query",
        };

        quic_test(ResolverConfig::quic(&adguard), tls_config).await
    }

    /// Queries twice over QUIC with an explicit rustls configuration.
    #[cfg(feature = "__quic")]
    async fn quic_test(config: ResolverConfig, tls_config: rustls::ClientConfig) {
        let mut resolver_builder =
            TokioResolver::builder_with_config(config, TokioRuntimeProvider::default());
        {
            let opts = resolver_builder.options_mut();
            opts.try_tcp_on_error = true;
            opts.server_ordering_strategy = ServerOrderingStrategy::UserProvidedOrder;
        }
        let resolver = resolver_builder.with_tls_config(tls_config).build().unwrap();

        for _ in 0..2 {
            assert_resolves_example(&resolver).await;
        }
    }

    #[cfg(feature = "__https")]
    #[tokio::test]
    async fn test_google_https() {
        subscribe();
        https_test(ResolverConfig::https(&GOOGLE)).await
    }

    #[cfg(feature = "__https")]
    #[tokio::test]
    async fn test_cloudflare_https() {
        subscribe();
        https_test(ResolverConfig::https(&CLOUDFLARE)).await
    }

    /// Queries twice over DNS-over-HTTPS.
    #[cfg(feature = "__https")]
    async fn https_test(config: ResolverConfig) {
        let mut resolver_builder =
            TokioResolver::builder_with_config(config, TokioRuntimeProvider::default());
        resolver_builder.options_mut().try_tcp_on_error = true;
        let resolver = resolver_builder.build().unwrap();

        for _ in 0..2 {
            assert_resolves_example(&resolver).await;
        }
    }

    #[cfg(feature = "__tls")]
    #[tokio::test]
    async fn test_google_tls() {
        subscribe();
        tls_test(ResolverConfig::tls(&GOOGLE)).await
    }

    #[cfg(feature = "__tls")]
    #[tokio::test]
    async fn test_cloudflare_tls() {
        subscribe();
        tls_test(ResolverConfig::tls(&CLOUDFLARE)).await
    }

    /// Queries once over DNS-over-TLS.
    #[cfg(feature = "__tls")]
    async fn tls_test(config: ResolverConfig) {
        let mut resolver_builder =
            TokioResolver::builder_with_config(config, TokioRuntimeProvider::default());
        resolver_builder.options_mut().try_tcp_on_error = true;
        let resolver = resolver_builder.build().unwrap();

        assert_resolves_example(&resolver).await;
    }
}