1use core::task;
2use std::{
3 collections::BTreeMap, fmt::Debug, net::IpAddr, pin::Pin, str::FromStr, sync::Arc, task::Poll,
4};
5
6use anyhow::Context;
7use arc_swap::ArcSwap;
8use async_trait::async_trait;
9use candid::Principal;
10use hickory_proto::rr::{Record, RecordType};
11use hickory_resolver::{
12 ResolveError, TokioResolver,
13 config::{NameServerConfigGroup, ResolveHosts, ResolverConfig, ResolverOpts},
14 name_server::TokioConnectionProvider,
15};
16use hyper_util::client::legacy::connect::dns::Name as HyperName;
17use ic_agent::Agent;
18use ic_bn_lib_common::{
19 principal,
20 traits::{
21 Run,
22 dns::{CloneableDnsResolver, CloneableHyperDnsResolver, HyperDnsResolver, Resolves},
23 },
24 types::{
25 dns::{Options, Protocol, SocketAddrs},
26 http::Error,
27 },
28};
29use reqwest::dns::{Addrs, Name, Resolve, Resolving};
30use tokio_util::sync::CancellationToken;
31use tower::Service;
32
/// DNS resolver backed by a hickory [`TokioResolver`].
///
/// The inner resolver is held in an `Arc`, so clones of this type share the
/// same underlying resolver instance.
#[derive(Debug, Clone)]
pub struct Resolver(Arc<TokioResolver>);
// Opt-in trait impls with empty bodies: nothing is implemented for them here.
impl CloneableDnsResolver for Resolver {}
impl HyperDnsResolver for Resolver {}
impl CloneableHyperDnsResolver for Resolver {}
39
40impl Resolver {
41 pub fn new(o: Options) -> Self {
44 let name_servers = match o.protocol {
45 Protocol::Clear(p) => NameServerConfigGroup::from_ips_clear(&o.servers, p, true),
46 Protocol::Tls(p) => {
47 NameServerConfigGroup::from_ips_tls(&o.servers, p, o.tls_name, true)
48 }
49 Protocol::Https(p) => {
50 NameServerConfigGroup::from_ips_https(&o.servers, p, o.tls_name, true)
51 }
52 };
53
54 let cfg = ResolverConfig::from_parts(None, vec![], name_servers);
55
56 let mut opts = ResolverOpts::default();
57 opts.cache_size = o.cache_size;
58 opts.timeout = o.timeout;
59 opts.ip_strategy = o.lookup_ip_strategy;
60 opts.use_hosts_file = ResolveHosts::Never;
61 opts.preserve_intermediates = false;
62 opts.validate = !o.dnssec_disabled;
63 opts.try_tcp_on_error = true;
64
65 let builder = TokioResolver::builder_with_config(cfg, TokioConnectionProvider::default())
66 .with_options(opts);
67
68 Self(Arc::new(builder.build()))
69 }
70}
71
72impl Default for Resolver {
73 fn default() -> Self {
74 Self::new(Options::default())
75 }
76}
77
78impl Resolve for Resolver {
80 fn resolve(&self, name: Name) -> Resolving {
81 let resolver = self.clone();
82
83 Box::pin(async move {
84 let lookup = resolver.0.lookup_ip(name.as_str()).await?;
85 let addrs: Addrs = Box::new(SocketAddrs {
86 iter: Box::new(lookup.into_iter()),
87 });
88
89 Ok(addrs)
90 })
91 }
92}
93
94#[async_trait]
95impl Resolves for Resolver {
96 async fn resolve(
97 &self,
98 record_type: RecordType,
99 name: &str,
100 ) -> Result<Vec<Record>, ResolveError> {
101 let lookup = self.0.lookup(name, record_type).await?;
102 Ok(lookup.records().to_vec())
103 }
104
105 fn flush_cache(&self) {
106 self.0.clear_cache();
107 }
108}
109
110impl Service<HyperName> for Resolver {
112 type Response = SocketAddrs;
113 type Error = Error;
114 #[allow(clippy::type_complexity)]
115 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
116
117 fn poll_ready(&mut self, _cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
118 Poll::Ready(Ok(()))
119 }
120
121 fn call(&mut self, name: HyperName) -> Self::Future {
122 let resolver = self.0.clone();
123
124 Box::pin(async move {
125 let response = resolver
126 .lookup_ip(name.as_str())
127 .await
128 .map_err(|e| Error::DnsError(e.to_string()))?;
129 let addresses = response.into_iter();
130
131 Ok(SocketAddrs {
132 iter: Box::new(addresses),
133 })
134 })
135 }
136}
137
/// Resolver that ignores the requested name and always resolves one fixed
/// name instead.
///
/// Fields, in order: the underlying [`Resolver`], the fixed name as supplied
/// to `new`, and that same name parsed as a hyper `Name`.
#[derive(Debug, Clone)]
pub struct FixedResolver(Resolver, String, HyperName);
// Opt-in trait impls with empty bodies: nothing is implemented for them here.
impl CloneableDnsResolver for FixedResolver {}
impl HyperDnsResolver for FixedResolver {}
impl CloneableHyperDnsResolver for FixedResolver {}
145
146impl FixedResolver {
147 pub fn new(o: Options, name: String) -> Result<Self, Error> {
148 let resolver = Resolver::new(o);
149 let hyper_name = HyperName::from_str(&name).context("unable to parse name")?;
150
151 Ok(Self(resolver, name, hyper_name))
152 }
153}
154
155impl Resolve for FixedResolver {
157 fn resolve(&self, _name: Name) -> Resolving {
158 let name = Name::from_str(&self.1).unwrap();
161 reqwest::dns::Resolve::resolve(&self.0, name)
162 }
163}
164
165impl Service<HyperName> for FixedResolver {
167 type Response = SocketAddrs;
168 type Error = Error;
169 #[allow(clippy::type_complexity)]
170 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
171
172 fn poll_ready(&mut self, _cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
173 Poll::Ready(Ok(()))
174 }
175
176 fn call(&mut self, _name: HyperName) -> Self::Future {
177 self.0.call(self.2.clone())
178 }
179}
180
/// Resolver backed by a fixed in-memory table mapping a name to its list of
/// IP addresses.
///
/// The table lives behind an `Arc`, so clones share the same map.
#[derive(Debug, Clone)]
pub struct StaticResolver(Arc<BTreeMap<String, Vec<IpAddr>>>);
// Opt-in trait impls with empty bodies: nothing is implemented for them here.
impl CloneableDnsResolver for StaticResolver {}
impl HyperDnsResolver for StaticResolver {}
impl CloneableHyperDnsResolver for StaticResolver {}
187
188impl StaticResolver {
189 pub fn new(items: impl IntoIterator<Item = (String, Vec<IpAddr>)>) -> Self {
190 Self(Arc::new(BTreeMap::from_iter(items)))
191 }
192
193 pub fn lookup(&self, name: &str) -> Option<Vec<IpAddr>> {
194 self.0.get(name).cloned()
195 }
196}
197
198impl Resolve for StaticResolver {
200 fn resolve(&self, name: Name) -> Resolving {
201 let addrs = self.lookup(name.as_str()).unwrap_or_default();
202
203 Box::pin(async move {
204 Ok(Box::new(SocketAddrs {
205 iter: Box::new(addrs.into_iter()),
206 }) as Addrs)
207 })
208 }
209}
210
211impl Service<HyperName> for StaticResolver {
213 type Response = SocketAddrs;
214 type Error = Error;
215 #[allow(clippy::type_complexity)]
216 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
217
218 fn poll_ready(&mut self, _cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
219 Poll::Ready(Ok(()))
220 }
221
222 fn call(&mut self, name: HyperName) -> Self::Future {
223 let addrs = self.lookup(name.as_str()).unwrap_or_default();
224
225 Box::pin(async move {
226 Ok(SocketAddrs {
227 iter: Box::new(addrs.into_iter()),
228 })
229 })
230 }
231}
232
/// Resolver that answers from a table of API boundary nodes fetched from the
/// IC, and falls back to a regular [`Resolver`] for names not in the table.
///
/// The table starts empty and is replaced each time `run` is called.
#[derive(Debug, Clone)]
pub struct ApiBnResolver {
    /// Agent used to fetch the list of API boundary nodes.
    agent: Agent,
    /// Subnet whose API boundary nodes are fetched.
    subnet: Principal,
    /// Current name-to-addresses table; swapped as a whole when refreshed.
    resolver_static: Arc<ArcSwap<StaticResolver>>,
    /// Resolver used when the requested name is not in the table.
    resolver_fallback: Resolver,
}
// Opt-in trait impls with empty bodies: nothing is implemented for them here.
impl CloneableDnsResolver for ApiBnResolver {}
impl HyperDnsResolver for ApiBnResolver {}
impl CloneableHyperDnsResolver for ApiBnResolver {}
246
247impl ApiBnResolver {
248 pub fn new(resolver_fallback: Resolver, agent: Agent) -> Self {
249 let resolver_static = Arc::new(ArcSwap::new(Arc::new(StaticResolver::new(vec![]))));
250 let subnet = principal!("tdb26-jop6k-aogll-7ltgs-eruif-6kk7m-qpktf-gdiqx-mxtrf-vb5e6-eqe");
251
252 Self {
253 agent,
254 subnet,
255 resolver_static,
256 resolver_fallback,
257 }
258 }
259
260 async fn get_api_bns(&self) -> Result<Vec<(String, Vec<IpAddr>)>, Error> {
262 let api_bns = self
263 .agent
264 .fetch_api_boundary_nodes_by_subnet_id(self.subnet)
265 .await
266 .context("unable to get API BNs from IC")?;
267
268 let mut r = Vec::with_capacity(api_bns.len());
269 for n in api_bns {
270 let ipv6 = IpAddr::from_str(&n.ipv6_address)
271 .context(format!("unable to parse IPv6 address for {}", n.domain))?;
272 let mut addrs = vec![ipv6];
273
274 if let Some(v) = n.ipv4_address {
276 let ipv4 = IpAddr::from_str(&v)
277 .context(format!("unable to parse IPv4 address for {}", n.domain))?;
278 addrs.push(ipv4);
279 }
280
281 r.push((n.domain, addrs));
282 }
283
284 Ok(r)
285 }
286}
287
288impl Resolve for ApiBnResolver {
290 fn resolve(&self, name: Name) -> Resolving {
291 let api_bns = self.resolver_static.load_full().lookup(name.as_str());
292 let resolver_fallback = self.resolver_fallback.clone();
293
294 Box::pin(async move {
295 let addrs = match api_bns {
296 Some(v) => v,
297 None => {
298 resolver_fallback
300 .0
301 .lookup_ip(name.as_str())
302 .await
303 .map_err(|e| Error::DnsError(e.to_string()))?
304 .into_iter()
305 .collect()
306 }
307 };
308
309 Ok(Box::new(SocketAddrs {
310 iter: Box::new(addrs.into_iter()),
311 }) as Addrs)
312 })
313 }
314}
315
316impl Service<HyperName> for ApiBnResolver {
318 type Response = SocketAddrs;
319 type Error = Error;
320 #[allow(clippy::type_complexity)]
321 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
322
323 fn poll_ready(&mut self, _cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
324 Poll::Ready(Ok(()))
325 }
326
327 fn call(&mut self, name: HyperName) -> Self::Future {
328 let api_bns = self.resolver_static.load_full().lookup(name.as_str());
329 let resolver_fallback = self.resolver_fallback.clone();
330
331 Box::pin(async move {
332 let addrs = match api_bns {
333 Some(v) => v,
334 None => {
335 resolver_fallback
337 .0
338 .lookup_ip(name.as_str())
339 .await
340 .map_err(|e| Error::DnsError(e.to_string()))?
341 .into_iter()
342 .collect()
343 }
344 };
345
346 Ok(SocketAddrs {
347 iter: Box::new(addrs.into_iter()),
348 })
349 })
350 }
351}
352
353#[async_trait]
354impl Run for ApiBnResolver {
355 async fn run(&self, _token: CancellationToken) -> Result<(), anyhow::Error> {
356 let api_bns = self.get_api_bns().await?;
357 let resolver = StaticResolver::new(api_bns);
358 self.resolver_static.store(Arc::new(resolver));
359
360 Ok(())
361 }
362}
363
/// Resolver that answers every query with the same single IP address,
/// whatever name is asked for.
#[derive(Debug, Clone)]
pub struct SingleResolver(IpAddr);
// Opt-in trait impl with an empty body: nothing is implemented for it here.
impl CloneableDnsResolver for SingleResolver {}
368
369impl SingleResolver {
370 pub const fn new(addr: IpAddr) -> Self {
371 Self(addr)
372 }
373}
374
375impl Resolve for SingleResolver {
377 fn resolve(&self, _name: Name) -> Resolving {
378 let addr = self.0;
379
380 Box::pin(async move {
381 Ok(Box::new(SocketAddrs {
382 iter: Box::new(vec![addr].into_iter()),
383 }) as Addrs)
384 })
385 }
386}
387
#[cfg(test)]
mod test {
    use std::net::{Ipv4Addr, SocketAddr};

    use super::*;

    /// `Protocol::from_str` accepts a bare scheme (default port) or
    /// `scheme:port`, and rejects missing, non-numeric or out-of-range ports.
    #[test]
    fn test_dns_protocol() {
        let accepted = [
            ("clear", Protocol::Clear(53)),
            ("tls", Protocol::Tls(853)),
            ("https", Protocol::Https(443)),
            ("clear:8053", Protocol::Clear(8053)),
            ("tls:8853", Protocol::Tls(8853)),
            ("https:8443", Protocol::Https(8443)),
        ];
        for (input, expected) in accepted {
            assert_eq!(Protocol::from_str(input).unwrap(), expected);
        }

        for input in ["clear:", "clear:x", "clear:-1", "clear:65537"] {
            assert!(
                Protocol::from_str(input).is_err(),
                "expected `{input}` to be rejected"
            );
        }
    }

    /// `SingleResolver` yields its configured address (with port 0) for any name.
    #[tokio::test]
    async fn test_single_resolver() {
        let ip: IpAddr = Ipv4Addr::new(1, 2, 3, 4).into();
        let name = Name::from_str("foo.bar").unwrap();

        let mut addrs = SingleResolver::new(ip).resolve(name).await.unwrap();

        assert_eq!(addrs.next(), Some(SocketAddr::from((ip, 0))));
    }
}