Skip to main content

reqwest/dns/
dot.rs

1//! DNS-over-TLS (DoT) resolution via hickory-resolver
2
3use hickory_resolver::{
4    config::{LookupIpStrategy, NameServerConfig, ResolverConfig},
5    net::runtime::TokioRuntimeProvider,
6    TokioResolver,
7};
8
9use std::net::IpAddr;
10use std::str::FromStr;
11use std::sync::{Arc, Mutex};
12use std::time::Duration;
13
14use super::{Addrs, Name, Resolve, Resolving, SocketAddrs};
15use super::gai::GaiResolver;
16use crate::error::BoxError;
17
18/// A DNS-over-TLS resolver backed by hickory-resolver.
19pub struct DotResolver {
20    state: Arc<Mutex<Option<Arc<TokioResolver>>>>,
21    bootstrap: Arc<dyn Resolve>,
22    tls_host: String,
23    tls_port: u16,
24}
25
26impl std::fmt::Debug for DotResolver {
27    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
28        f.debug_struct("DotResolver")
29            .field("tls_host", &self.tls_host)
30            .field("tls_port", &self.tls_port)
31            .finish()
32    }
33}
34
35impl Clone for DotResolver {
36    fn clone(&self) -> Self {
37        Self {
38            state: self.state.clone(),
39            bootstrap: self.bootstrap.clone(),
40            tls_host: self.tls_host.clone(),
41            tls_port: self.tls_port,
42        }
43    }
44}
45
46impl DotResolver {
47    /// Create a new DoT resolver from a hostname like `"1.1.1.1"` or `"cloudflare-dns.com"`.
48    ///
49    /// The host is resolved via the system resolver (GaiResolver) on first lookup.
50    /// The default port is 853.
51    pub fn new(host: &str) -> Self {
52        Self::new_with_port(host, 853)
53    }
54
55    /// Create a new DoT resolver with a custom port.
56    pub fn new_with_port(host: &str, port: u16) -> Self {
57        let bootstrap: Arc<dyn Resolve> = Arc::new(GaiResolver::new());
58        Self {
59            state: Arc::new(Mutex::new(None)),
60            bootstrap,
61            tls_host: host.to_string(),
62            tls_port: port,
63        }
64    }
65
66    async fn get_resolver(&self) -> Result<Arc<TokioResolver>, BoxError> {
67        if let Some(ref resolver) = *self.state.lock().unwrap() {
68            return Ok(resolver.clone());
69        }
70
71        let addrs = self
72            .bootstrap
73            .resolve(Name::from_str(&self.tls_host)?)
74            .await?;
75        let ips: Vec<IpAddr> = addrs.map(|a| a.ip()).collect();
76
77        let name_servers: Vec<NameServerConfig> = ips
78            .iter()
79            .map(|&ip| NameServerConfig::tls(ip, self.tls_host.clone().into()))
80            .collect();
81        let config = ResolverConfig::from_parts(None, vec![], name_servers);
82
83        let mut builder =
84            TokioResolver::builder_with_config(config, TokioRuntimeProvider::default());
85        let opts = builder.options_mut();
86        opts.timeout = Duration::from_secs(5);
87        opts.ip_strategy = LookupIpStrategy::Ipv4AndIpv6;
88        let resolver = Arc::new(builder.build().expect("failed to build DoT resolver"));
89
90        let mut guard = self.state.lock().unwrap();
91        if guard.is_none() {
92            *guard = Some(resolver.clone());
93        }
94        Ok(guard.as_ref().unwrap().clone())
95    }
96}
97
98impl Resolve for DotResolver {
99    fn resolve(&self, name: Name) -> Resolving {
100        let this = self.clone();
101        Box::pin(async move {
102            let resolver = this.get_resolver().await?;
103            let lookup = resolver.lookup_ip(name.as_str()).await?;
104            let ips: Vec<IpAddr> = lookup.iter().collect();
105            let addrs: Addrs = Box::new(SocketAddrs {
106                iter: ips.into_iter(),
107            });
108            Ok(addrs)
109        })
110    }
111}
112
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn new_default_port() {
        let resolver = DotResolver::new("1.1.1.1");
        assert_eq!(resolver.tls_host, "1.1.1.1");
        assert_eq!(resolver.tls_port, 853);
    }

    #[test]
    fn new_custom_port() {
        let resolver = DotResolver::new_with_port("dns.google", 5353);
        assert_eq!(resolver.tls_host, "dns.google");
        assert_eq!(resolver.tls_port, 5353);
    }

    #[test]
    fn new_starts_without_cached_resolver() {
        let resolver = DotResolver::new("1.1.1.1");
        assert!(resolver.state.lock().unwrap().is_none());
    }

    #[test]
    fn clones_share_resolver_cache() {
        let resolver = DotResolver::new_with_port("dns.google", 5353);
        let cloned = resolver.clone();
        // `resolve` runs on a clone of `self`, so the lazily built resolver must be shared
        // with the original rather than rebuilt per clone.
        assert!(Arc::ptr_eq(&resolver.state, &cloned.state));
        assert!(Arc::ptr_eq(&resolver.bootstrap, &cloned.bootstrap));
        assert_eq!(cloned.tls_host, resolver.tls_host);
        assert_eq!(cloned.tls_port, resolver.tls_port);
    }

    #[test]
    fn debug_output() {
        let resolver = DotResolver::new_with_port("cloudflare-dns.com", 853);
        let debug = format!("{:?}", resolver);
        assert!(debug.contains("cloudflare-dns.com"), "{debug}");
        assert!(debug.contains("853"), "{debug}");
    }
}