Skip to main content

hickory_resolver/
connection_provider.rs

1// Copyright 2015-2019 Benjamin Fry <benjaminfry@me.com>
2//
3// Licensed under the Apache License, Version 2.0, <LICENSE-APACHE or
4// https://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
5// https://opensource.org/licenses/MIT>, at your option. This file may not be
6// copied, modified, or distributed except according to those terms.
7
8use std::future::Future;
9use std::marker::Unpin;
10use std::net::{IpAddr, SocketAddr};
11#[cfg(feature = "__quic")]
12use std::net::{Ipv4Addr, Ipv6Addr};
13use std::pin::Pin;
14#[cfg(any(feature = "__tls", feature = "__https"))]
15use std::sync::Arc;
16
17#[cfg(feature = "__https")]
18use hickory_net::h2::HttpsClientStream;
19#[cfg(feature = "__tls")]
20use rustls::DigitallySignedStruct;
21#[cfg(feature = "__tls")]
22use rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier};
23#[cfg(feature = "__tls")]
24use rustls::crypto::{CryptoProvider, verify_tls12_signature, verify_tls13_signature};
25#[cfg(feature = "__tls")]
26use rustls::pki_types::{CertificateDer, ServerName, UnixTime};
27#[cfg(not(feature = "__tls"))]
28use tracing::warn;
29
30#[cfg(feature = "__h3")]
31use crate::net::h3::H3ClientStream;
32#[cfg(feature = "__quic")]
33use crate::net::quic::QuicClientStream;
34#[cfg(feature = "__tls")]
35use crate::net::tls::{client_config, default_provider, tls_exchange};
36use crate::{
37    config::{ConnectionConfig, ProtocolConfig},
38    name_server_pool::PoolContext,
39    net::{
40        NetError,
41        runtime::RuntimeProvider,
42        tcp::TcpClientStream,
43        udp::UdpClientStream,
44        xfer::{DnsExchange, DnsHandle},
45    },
46};
47
48/// Create `DnsHandle` with the help of `RuntimeProvider`.
49/// This trait is designed for customization.
50pub trait ConnectionProvider: 'static + Clone + Send + Sync + Unpin {
51    /// The handle to the connection for sending DNS requests.
52    type Conn: DnsHandle + Clone + Send + Sync + 'static;
53    /// Ths future is responsible for spawning any background tasks as necessary.
54    type FutureConn: Future<Output = Result<Self::Conn, NetError>> + Send + 'static;
55    /// Provider that handles the underlying I/O and timing.
56    type RuntimeProvider: RuntimeProvider;
57
58    /// Create a new connection.
59    fn new_connection(
60        &self,
61        ip: IpAddr,
62        config: &ConnectionConfig,
63        cx: &PoolContext,
64    ) -> Result<Self::FutureConn, NetError>;
65
66    /// Get a reference to a [`RuntimeProvider`].
67    fn runtime_provider(&self) -> &Self::RuntimeProvider;
68}
69
70impl<P: RuntimeProvider> ConnectionProvider for P {
71    type Conn = DnsExchange<P>;
72    type FutureConn = Pin<Box<dyn Future<Output = Result<Self::Conn, NetError>> + Send + 'static>>;
73    type RuntimeProvider = P;
74
75    fn new_connection(
76        &self,
77        ip: IpAddr,
78        config: &ConnectionConfig,
79        cx: &PoolContext,
80    ) -> Result<Self::FutureConn, NetError> {
81        let remote_addr = SocketAddr::new(ip, config.port);
82        match (&config.protocol, self.quic_binder()) {
83            (ProtocolConfig::Udp, _) => {
84                let (timeout, os_port_selection, avoid_local_udp_ports, bind_addr, provider) = (
85                    cx.options.timeout,
86                    cx.options.os_port_selection,
87                    cx.options.avoid_local_udp_ports.clone(),
88                    config.bind_addr,
89                    self.clone(),
90                );
91
92                Ok(Box::pin(async move {
93                    Ok(UdpClientStream::builder(remote_addr, provider)
94                        .with_timeout(Some(timeout))
95                        .with_os_port_selection(os_port_selection)
96                        .avoid_local_ports(avoid_local_udp_ports)
97                        .with_bind_addr(bind_addr)
98                        .exchange())
99                }))
100            }
101            (ProtocolConfig::Tcp, _) => Ok(Box::pin(TcpClientStream::exchange(
102                remote_addr,
103                config.bind_addr,
104                cx.options.timeout,
105                Some(cx.options.max_active_requests),
106                self.clone(),
107            ))),
108            #[cfg(feature = "__tls")]
109            (ProtocolConfig::Tls { server_name }, _) => {
110                let Ok(server_name) = ServerName::try_from(&**server_name) else {
111                    return Err(NetError::from(format!(
112                        "invalid server name: {server_name}"
113                    )));
114                };
115
116                let server_name = server_name.to_owned();
117                Ok(Box::pin(tls_exchange(
118                    remote_addr,
119                    server_name,
120                    cx.tls.clone(),
121                    cx.options.timeout,
122                    Some(cx.options.max_active_requests),
123                    self.clone(),
124                )))
125            }
126            #[cfg(feature = "__https")]
127            (ProtocolConfig::Https { server_name, path }, _) => Ok(Box::pin(
128                HttpsClientStream::builder(Arc::new(cx.tls.clone()), self.clone()).exchange(
129                    remote_addr,
130                    server_name.clone(),
131                    path.clone(),
132                ),
133            )),
134
135            #[cfg(feature = "__quic")]
136            (ProtocolConfig::Quic { server_name }, Some(binder)) => {
137                let bind_addr = config.bind_addr.unwrap_or(match remote_addr {
138                    SocketAddr::V4(_) => SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0),
139                    SocketAddr::V6(_) => SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0),
140                });
141
142                Ok(Box::pin(
143                    QuicClientStream::builder()
144                        .crypto_config(cx.tls.clone())
145                        .exchange(
146                            binder.bind_quic(bind_addr, remote_addr)?,
147                            remote_addr,
148                            server_name.clone(),
149                            self.clone(),
150                        ),
151                ))
152            }
153            #[cfg(feature = "__h3")]
154            (
155                ProtocolConfig::H3 {
156                    server_name,
157                    path,
158                    disable_grease,
159                },
160                Some(binder),
161            ) => {
162                let bind_addr = config.bind_addr.unwrap_or(match remote_addr {
163                    SocketAddr::V4(_) => SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0),
164                    SocketAddr::V6(_) => SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0),
165                });
166
167                Ok(Box::pin(
168                    H3ClientStream::builder()
169                        .crypto_config(cx.tls.clone())
170                        .disable_grease(*disable_grease)
171                        .exchange(
172                            binder.bind_quic(bind_addr, remote_addr)?,
173                            remote_addr,
174                            server_name.clone(),
175                            path.clone(),
176                            self.clone(),
177                        ),
178                ))
179            }
180            #[cfg(feature = "__quic")]
181            (ProtocolConfig::Quic { .. }, None) => {
182                Err(NetError::from("runtime provider does not support QUIC"))
183            }
184            #[cfg(feature = "__h3")]
185            (ProtocolConfig::H3 { .. }, None) => {
186                Err(NetError::from("runtime provider does not support QUIC"))
187            }
188        }
189    }
190
191    fn runtime_provider(&self) -> &Self::RuntimeProvider {
192        self
193    }
194}
195
196/// TLS configuration for the connection provider.
197pub struct TlsConfig {
198    /// The TLS configuration to use for secure connections.
199    #[cfg(feature = "__tls")]
200    pub config: rustls::ClientConfig,
201}
202
203impl TlsConfig {
204    /// Create a new `TlsConfig` with default settings.
205    pub fn new() -> Result<Self, NetError> {
206        Ok(Self {
207            #[cfg(feature = "__tls")]
208            config: client_config()?,
209        })
210    }
211
212    /// Disable certificate verification.
213    ///
214    /// This is typically unsafe and insecure, except in the context of RFC 9539 opportunistic
215    /// encryption which requires the peer certificate not be verified.
216    #[cfg(feature = "__tls")]
217    pub fn insecure_skip_verify(&mut self) {
218        self.config
219            .dangerous()
220            .set_certificate_verifier(Arc::new(NoCertificateVerification::default()))
221    }
222
223    /// Disable certificate verification.
224    ///
225    /// This is typically unsafe and insecure, except in the context of RFC 9539 opportunistic
226    /// encryption which requires the peer certificate not be verified.
227    #[cfg(not(feature = "__tls"))]
228    pub fn insecure_skip_verify(&mut self) {
229        warn!("asked to skip TLS verification without TLS support")
230    }
231}
232
/// A rustls [`ServerCertVerifier`] that accepts every server certificate without checking it.
///
/// Only certificate validation is skipped. TLS 1.2 and 1.3 handshake signatures are still
/// verified against the signature algorithms of the wrapped [`CryptoProvider`]. Skipping
/// certificate verification is insecure and could allow person-in-the-middle attacks, so this
/// should only be used with great care.
#[cfg(feature = "__tls")]
#[derive(Debug)]
struct NoCertificateVerification(CryptoProvider);

#[cfg(feature = "__tls")]
impl Default for NoCertificateVerification {
    /// Wraps the provider returned by `default_provider()`.
    fn default() -> Self {
        let provider = default_provider();
        NoCertificateVerification(provider)
    }
}

#[cfg(feature = "__tls")]
impl ServerCertVerifier for NoCertificateVerification {
    /// Reports every certificate as valid, regardless of its contents.
    fn verify_server_cert(
        &self,
        _end_entity: &CertificateDer<'_>,
        _intermediates: &[CertificateDer<'_>],
        _server_name: &ServerName<'_>,
        _ocsp: &[u8],
        _now: UnixTime,
    ) -> Result<ServerCertVerified, rustls::Error> {
        let accepted = ServerCertVerified::assertion();
        Ok(accepted)
    }

    /// Checks a TLS 1.2 handshake signature with the wrapped provider's algorithms.
    fn verify_tls12_signature(
        &self,
        msg: &[u8],
        peer_cert: &CertificateDer<'_>,
        signed: &DigitallySignedStruct,
    ) -> Result<HandshakeSignatureValid, rustls::Error> {
        let algorithms = &self.0.signature_verification_algorithms;
        verify_tls12_signature(msg, peer_cert, signed, algorithms)
    }

    /// Checks a TLS 1.3 handshake signature with the wrapped provider's algorithms.
    fn verify_tls13_signature(
        &self,
        msg: &[u8],
        peer_cert: &CertificateDer<'_>,
        signed: &DigitallySignedStruct,
    ) -> Result<HandshakeSignatureValid, rustls::Error> {
        let algorithms = &self.0.signature_verification_algorithms;
        verify_tls13_signature(msg, peer_cert, signed, algorithms)
    }

    /// Advertises exactly the schemes the wrapped provider can verify.
    fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
        let NoCertificateVerification(provider) = self;
        provider.signature_verification_algorithms.supported_schemes()
    }
}
293
#[cfg(all(
    test,
    feature = "tokio",
    any(feature = "webpki-roots", feature = "rustls-platform-verifier"),
    any(
        feature = "__tls",
        feature = "__https",
        feature = "__quic",
        feature = "__h3"
    )
))]
mod tests {
    #[cfg(feature = "__quic")]
    use std::net::IpAddr;

    use test_support::subscribe;

    use crate::TokioResolver;
    #[cfg(any(feature = "__tls", feature = "__https"))]
    use crate::config::CLOUDFLARE;
    #[cfg(any(
        feature = "__tls",
        feature = "__https",
        feature = "__quic",
        feature = "__h3"
    ))]
    use crate::config::GOOGLE;
    use crate::config::ResolverConfig;
    #[cfg(feature = "__quic")]
    use crate::config::ServerGroup;
    #[cfg(feature = "__quic")]
    use crate::config::ServerOrderingStrategy;
    use crate::net::runtime::TokioRuntimeProvider;
    #[cfg(feature = "__quic")]
    use crate::net::tls::client_config;

    /// Resolves `www.example.com.` through `resolver` and asserts at least one address came back.
    async fn assert_lookup_succeeds(resolver: &TokioResolver) {
        let answer = resolver
            .lookup_ip("www.example.com.")
            .await
            .expect("failed to run lookup");
        assert_ne!(answer.iter().count(), 0);
    }

    #[cfg(feature = "__h3")]
    #[tokio::test]
    async fn test_google_h3() {
        subscribe();
        h3_test(ResolverConfig::h3(&GOOGLE)).await
    }

    #[cfg(feature = "__h3")]
    async fn h3_test(config: ResolverConfig) {
        let mut builder =
            TokioResolver::builder_with_config(config, TokioRuntimeProvider::default());
        // Prefer IPv4 addresses for this test.
        builder.options_mut().server_ordering_strategy = ServerOrderingStrategy::UserProvidedOrder;
        let resolver = builder.build().unwrap();

        assert_lookup_succeeds(&resolver).await;
        // Second lookup on the same resolver: check if there is another connection created.
        assert_lookup_succeeds(&resolver).await;
    }

    #[cfg(feature = "__quic")]
    #[tokio::test]
    async fn test_adguard_quic() {
        subscribe();

        // AdGuard requires SNI.
        let tls_config = client_config().unwrap();

        let adguard = ServerGroup {
            ips: &[
                IpAddr::from([94, 140, 14, 140]),
                IpAddr::from([94, 140, 14, 141]),
                IpAddr::from([0x2a10, 0x50c0, 0, 0, 0, 0, 0x1, 0xff]),
                IpAddr::from([0x2a10, 0x50c0, 0, 0, 0, 0, 0x2, 0xff]),
            ],
            server_name: "unfiltered.adguard-dns.com",
            path: "/dns-query",
        };

        quic_test(ResolverConfig::quic(&adguard), tls_config).await
    }

    #[cfg(feature = "__quic")]
    async fn quic_test(config: ResolverConfig, tls_config: rustls::ClientConfig) {
        let mut builder =
            TokioResolver::builder_with_config(config, TokioRuntimeProvider::default());
        {
            let options = builder.options_mut();
            options.try_tcp_on_error = true;
            // Prefer IPv4 addresses for this test.
            options.server_ordering_strategy = ServerOrderingStrategy::UserProvidedOrder;
        }
        builder = builder.with_tls_config(tls_config);
        let resolver = builder.build().unwrap();

        assert_lookup_succeeds(&resolver).await;
        // Second lookup on the same resolver: check if there is another connection created.
        assert_lookup_succeeds(&resolver).await;
    }

    #[cfg(feature = "__https")]
    #[tokio::test]
    async fn test_google_https() {
        subscribe();
        https_test(ResolverConfig::https(&GOOGLE)).await
    }

    #[cfg(feature = "__https")]
    #[tokio::test]
    async fn test_cloudflare_https() {
        subscribe();
        https_test(ResolverConfig::https(&CLOUDFLARE)).await
    }

    #[cfg(feature = "__https")]
    async fn https_test(config: ResolverConfig) {
        let mut builder =
            TokioResolver::builder_with_config(config, TokioRuntimeProvider::default());
        builder.options_mut().try_tcp_on_error = true;
        let resolver = builder.build().unwrap();

        assert_lookup_succeeds(&resolver).await;
        // Second lookup on the same resolver: check if there is another connection created.
        assert_lookup_succeeds(&resolver).await;
    }

    #[cfg(feature = "__tls")]
    #[tokio::test]
    async fn test_google_tls() {
        subscribe();
        tls_test(ResolverConfig::tls(&GOOGLE)).await
    }

    #[cfg(feature = "__tls")]
    #[tokio::test]
    async fn test_cloudflare_tls() {
        subscribe();
        tls_test(ResolverConfig::tls(&CLOUDFLARE)).await
    }

    #[cfg(feature = "__tls")]
    async fn tls_test(config: ResolverConfig) {
        let mut builder =
            TokioResolver::builder_with_config(config, TokioRuntimeProvider::default());
        builder.options_mut().try_tcp_on_error = true;
        let resolver = builder.build().unwrap();

        assert_lookup_succeeds(&resolver).await;
    }
}