1use hickory_resolver::{
4 config::{LookupIpStrategy, NameServerConfig, ResolverConfig},
5 net::runtime::TokioRuntimeProvider,
6 TokioResolver,
7};
8
9use std::net::IpAddr;
10use std::str::FromStr;
11use std::sync::{Arc, Mutex};
12use std::time::Duration;
13
14use super::{Addrs, Name, Resolve, Resolving, SocketAddrs};
15use super::gai::GaiResolver;
16use crate::error::BoxError;
17
18pub struct DotResolver {
20 state: Arc<Mutex<Option<Arc<TokioResolver>>>>,
21 bootstrap: Arc<dyn Resolve>,
22 tls_host: String,
23 tls_port: u16,
24}
25
26impl std::fmt::Debug for DotResolver {
27 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
28 f.debug_struct("DotResolver")
29 .field("tls_host", &self.tls_host)
30 .field("tls_port", &self.tls_port)
31 .finish()
32 }
33}
34
35impl Clone for DotResolver {
36 fn clone(&self) -> Self {
37 Self {
38 state: self.state.clone(),
39 bootstrap: self.bootstrap.clone(),
40 tls_host: self.tls_host.clone(),
41 tls_port: self.tls_port,
42 }
43 }
44}
45
46impl DotResolver {
47 pub fn new(host: &str) -> Self {
52 Self::new_with_port(host, 853)
53 }
54
55 pub fn new_with_port(host: &str, port: u16) -> Self {
57 let bootstrap: Arc<dyn Resolve> = Arc::new(GaiResolver::new());
58 Self {
59 state: Arc::new(Mutex::new(None)),
60 bootstrap,
61 tls_host: host.to_string(),
62 tls_port: port,
63 }
64 }
65
66 async fn get_resolver(&self) -> Result<Arc<TokioResolver>, BoxError> {
67 if let Some(ref resolver) = *self.state.lock().unwrap() {
68 return Ok(resolver.clone());
69 }
70
71 let addrs = self
72 .bootstrap
73 .resolve(Name::from_str(&self.tls_host)?)
74 .await?;
75 let ips: Vec<IpAddr> = addrs.map(|a| a.ip()).collect();
76
77 let name_servers: Vec<NameServerConfig> = ips
78 .iter()
79 .map(|&ip| NameServerConfig::tls(ip, self.tls_host.clone().into()))
80 .collect();
81 let config = ResolverConfig::from_parts(None, vec![], name_servers);
82
83 let mut builder =
84 TokioResolver::builder_with_config(config, TokioRuntimeProvider::default());
85 let opts = builder.options_mut();
86 opts.timeout = Duration::from_secs(5);
87 opts.ip_strategy = LookupIpStrategy::Ipv4AndIpv6;
88 let resolver = Arc::new(builder.build().expect("failed to build DoT resolver"));
89
90 let mut guard = self.state.lock().unwrap();
91 if guard.is_none() {
92 *guard = Some(resolver.clone());
93 }
94 Ok(guard.as_ref().unwrap().clone())
95 }
96}
97
98impl Resolve for DotResolver {
99 fn resolve(&self, name: Name) -> Resolving {
100 let this = self.clone();
101 Box::pin(async move {
102 let resolver = this.get_resolver().await?;
103 let lookup = resolver.lookup_ip(name.as_str()).await?;
104 let ips: Vec<IpAddr> = lookup.iter().collect();
105 let addrs: Addrs = Box::new(SocketAddrs {
106 iter: ips.into_iter(),
107 });
108 Ok(addrs)
109 })
110 }
111}
112
#[cfg(test)]
mod tests {
    use super::*;

    /// `new` records the host and falls back to port 853.
    #[test]
    fn new_default_port() {
        let dot = DotResolver::new("1.1.1.1");
        assert_eq!(dot.tls_host, "1.1.1.1");
        assert_eq!(dot.tls_port, 853);
    }

    /// `new_with_port` records both the host and the explicit port.
    #[test]
    fn new_custom_port() {
        let dot = DotResolver::new_with_port("dns.google", 5353);
        assert_eq!(dot.tls_host, "dns.google");
        assert_eq!(dot.tls_port, 5353);
    }

    /// The `Debug` rendering mentions the configured host and port.
    #[test]
    fn debug_output() {
        let debug = format!(
            "{:?}",
            DotResolver::new_with_port("cloudflare-dns.com", 853)
        );
        for expected in ["cloudflare-dns.com", "853"] {
            assert!(debug.contains(expected), "{debug}");
        }
    }
}