1use core::task;
2use std::{
3 collections::BTreeMap, fmt::Debug, net::IpAddr, pin::Pin, str::FromStr, sync::Arc, task::Poll,
4};
5
6use anyhow::Context;
7use arc_swap::ArcSwap;
8use async_trait::async_trait;
9use candid::Principal;
10use hickory_proto::rr::{Record, RecordType};
11use hickory_resolver::{
12 ResolveError, TokioResolver,
13 config::{NameServerConfigGroup, ResolveHosts, ResolverConfig, ResolverOpts},
14 name_server::TokioConnectionProvider,
15};
16use hyper_util::client::legacy::connect::dns::Name as HyperName;
17use ic_agent::Agent;
18use ic_bn_lib_common::{
19 principal,
20 traits::{
21 Run,
22 dns::{CloneableDnsResolver, CloneableHyperDnsResolver, HyperDnsResolver, Resolves},
23 },
24 types::{
25 dns::{Options, Protocol, SocketAddrs},
26 http::Error,
27 },
28};
29use reqwest::dns::{Addrs, Name, Resolve, Resolving};
30use tokio_util::sync::CancellationToken;
31use tower::Service;
32
33#[derive(Debug, Clone)]
35pub struct Resolver(Arc<TokioResolver>);
36impl CloneableDnsResolver for Resolver {}
37impl HyperDnsResolver for Resolver {}
38impl CloneableHyperDnsResolver for Resolver {}
39
40impl Resolver {
41 pub fn new(o: Options) -> Self {
44 let name_servers = match o.protocol {
45 Protocol::Clear(p) => NameServerConfigGroup::from_ips_clear(&o.servers, p, true),
46 Protocol::Tls(p) => {
47 NameServerConfigGroup::from_ips_tls(&o.servers, p, o.tls_name, true)
48 }
49 Protocol::Https(p) => {
50 NameServerConfigGroup::from_ips_https(&o.servers, p, o.tls_name, true)
51 }
52 };
53
54 let cfg = ResolverConfig::from_parts(None, vec![], name_servers);
55
56 let mut opts = ResolverOpts::default();
57 opts.cache_size = o.cache_size;
58 opts.timeout = o.timeout;
59 opts.ip_strategy = o.lookup_ip_strategy;
60 opts.use_hosts_file = ResolveHosts::Never;
61 opts.preserve_intermediates = false;
62 opts.try_tcp_on_error = true;
63
64 let builder = TokioResolver::builder_with_config(cfg, TokioConnectionProvider::default())
65 .with_options(opts);
66
67 Self(Arc::new(builder.build()))
68 }
69}
70
71impl Default for Resolver {
72 fn default() -> Self {
73 Self::new(Options::default())
74 }
75}
76
77impl Resolve for Resolver {
79 fn resolve(&self, name: Name) -> Resolving {
80 let resolver = self.clone();
81
82 Box::pin(async move {
83 let lookup = resolver.0.lookup_ip(name.as_str()).await?;
84 let addrs: Addrs = Box::new(SocketAddrs {
85 iter: Box::new(lookup.into_iter()),
86 });
87
88 Ok(addrs)
89 })
90 }
91}
92
93#[async_trait]
94impl Resolves for Resolver {
95 async fn resolve(
96 &self,
97 record_type: RecordType,
98 name: &str,
99 ) -> Result<Vec<Record>, ResolveError> {
100 let lookup = self.0.lookup(name, record_type).await?;
101 Ok(lookup.records().to_vec())
102 }
103
104 fn flush_cache(&self) {
105 self.0.clear_cache();
106 }
107}
108
109impl Service<HyperName> for Resolver {
111 type Response = SocketAddrs;
112 type Error = Error;
113 #[allow(clippy::type_complexity)]
114 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
115
116 fn poll_ready(&mut self, _cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
117 Poll::Ready(Ok(()))
118 }
119
120 fn call(&mut self, name: HyperName) -> Self::Future {
121 let resolver = self.0.clone();
122
123 Box::pin(async move {
124 let response = resolver
125 .lookup_ip(name.as_str())
126 .await
127 .map_err(|e| Error::DnsError(e.to_string()))?;
128 let addresses = response.into_iter();
129
130 Ok(SocketAddrs {
131 iter: Box::new(addresses),
132 })
133 })
134 }
135}
136
137#[derive(Debug, Clone)]
140pub struct FixedResolver(Resolver, String, HyperName);
141impl CloneableDnsResolver for FixedResolver {}
142impl HyperDnsResolver for FixedResolver {}
143impl CloneableHyperDnsResolver for FixedResolver {}
144
145impl FixedResolver {
146 pub fn new(o: Options, name: String) -> Result<Self, Error> {
147 let resolver = Resolver::new(o);
148 let hyper_name = HyperName::from_str(&name).context("unable to parse name")?;
149
150 Ok(Self(resolver, name, hyper_name))
151 }
152}
153
154impl Resolve for FixedResolver {
156 fn resolve(&self, _name: Name) -> Resolving {
157 let name = Name::from_str(&self.1).unwrap();
160 reqwest::dns::Resolve::resolve(&self.0, name)
161 }
162}
163
164impl Service<HyperName> for FixedResolver {
166 type Response = SocketAddrs;
167 type Error = Error;
168 #[allow(clippy::type_complexity)]
169 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
170
171 fn poll_ready(&mut self, _cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
172 Poll::Ready(Ok(()))
173 }
174
175 fn call(&mut self, _name: HyperName) -> Self::Future {
176 self.0.call(self.2.clone())
177 }
178}
179
180#[derive(Debug, Clone)]
182pub struct StaticResolver(Arc<BTreeMap<String, Vec<IpAddr>>>);
183impl CloneableDnsResolver for StaticResolver {}
184impl HyperDnsResolver for StaticResolver {}
185impl CloneableHyperDnsResolver for StaticResolver {}
186
187impl StaticResolver {
188 pub fn new(items: impl IntoIterator<Item = (String, Vec<IpAddr>)>) -> Self {
189 Self(Arc::new(BTreeMap::from_iter(items)))
190 }
191
192 pub fn lookup(&self, name: &str) -> Option<Vec<IpAddr>> {
193 self.0.get(name).cloned()
194 }
195}
196
197impl Resolve for StaticResolver {
199 fn resolve(&self, name: Name) -> Resolving {
200 let addrs = self.lookup(name.as_str()).unwrap_or_default();
201
202 Box::pin(async move {
203 Ok(Box::new(SocketAddrs {
204 iter: Box::new(addrs.into_iter()),
205 }) as Addrs)
206 })
207 }
208}
209
210impl Service<HyperName> for StaticResolver {
212 type Response = SocketAddrs;
213 type Error = Error;
214 #[allow(clippy::type_complexity)]
215 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
216
217 fn poll_ready(&mut self, _cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
218 Poll::Ready(Ok(()))
219 }
220
221 fn call(&mut self, name: HyperName) -> Self::Future {
222 let addrs = self.lookup(name.as_str()).unwrap_or_default();
223
224 Box::pin(async move {
225 Ok(SocketAddrs {
226 iter: Box::new(addrs.into_iter()),
227 })
228 })
229 }
230}
231
232#[derive(Debug, Clone)]
236pub struct ApiBnResolver {
237 agent: Agent,
238 subnet: Principal,
239 resolver_static: Arc<ArcSwap<StaticResolver>>,
240 resolver_fallback: Resolver,
241}
242impl CloneableDnsResolver for ApiBnResolver {}
243impl HyperDnsResolver for ApiBnResolver {}
244impl CloneableHyperDnsResolver for ApiBnResolver {}
245
246impl ApiBnResolver {
247 pub fn new(resolver_fallback: Resolver, agent: Agent) -> Self {
248 let resolver_static = Arc::new(ArcSwap::new(Arc::new(StaticResolver::new(vec![]))));
249 let subnet = principal!("tdb26-jop6k-aogll-7ltgs-eruif-6kk7m-qpktf-gdiqx-mxtrf-vb5e6-eqe");
250
251 Self {
252 agent,
253 subnet,
254 resolver_static,
255 resolver_fallback,
256 }
257 }
258
259 async fn get_api_bns(&self) -> Result<Vec<(String, Vec<IpAddr>)>, Error> {
261 let api_bns = self
262 .agent
263 .fetch_api_boundary_nodes_by_subnet_id(self.subnet)
264 .await
265 .context("unable to get API BNs from IC")?;
266
267 let mut r = Vec::with_capacity(api_bns.len());
268 for n in api_bns {
269 let ipv6 = IpAddr::from_str(&n.ipv6_address)
270 .context(format!("unable to parse IPv6 address for {}", n.domain))?;
271 let mut addrs = vec![ipv6];
272
273 if let Some(v) = n.ipv4_address {
275 let ipv4 = IpAddr::from_str(&v)
276 .context(format!("unable to parse IPv4 address for {}", n.domain))?;
277 addrs.push(ipv4);
278 }
279
280 r.push((n.domain, addrs));
281 }
282
283 Ok(r)
284 }
285}
286
287impl Resolve for ApiBnResolver {
289 fn resolve(&self, name: Name) -> Resolving {
290 let api_bns = self.resolver_static.load_full().lookup(name.as_str());
291 let resolver_fallback = self.resolver_fallback.clone();
292
293 Box::pin(async move {
294 let addrs = match api_bns {
295 Some(v) => v,
296 None => {
297 resolver_fallback
299 .0
300 .lookup_ip(name.as_str())
301 .await
302 .map_err(|e| Error::DnsError(e.to_string()))?
303 .into_iter()
304 .collect()
305 }
306 };
307
308 Ok(Box::new(SocketAddrs {
309 iter: Box::new(addrs.into_iter()),
310 }) as Addrs)
311 })
312 }
313}
314
315impl Service<HyperName> for ApiBnResolver {
317 type Response = SocketAddrs;
318 type Error = Error;
319 #[allow(clippy::type_complexity)]
320 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
321
322 fn poll_ready(&mut self, _cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
323 Poll::Ready(Ok(()))
324 }
325
326 fn call(&mut self, name: HyperName) -> Self::Future {
327 let api_bns = self.resolver_static.load_full().lookup(name.as_str());
328 let resolver_fallback = self.resolver_fallback.clone();
329
330 Box::pin(async move {
331 let addrs = match api_bns {
332 Some(v) => v,
333 None => {
334 resolver_fallback
336 .0
337 .lookup_ip(name.as_str())
338 .await
339 .map_err(|e| Error::DnsError(e.to_string()))?
340 .into_iter()
341 .collect()
342 }
343 };
344
345 Ok(SocketAddrs {
346 iter: Box::new(addrs.into_iter()),
347 })
348 })
349 }
350}
351
352#[async_trait]
353impl Run for ApiBnResolver {
354 async fn run(&self, _token: CancellationToken) -> Result<(), anyhow::Error> {
355 let api_bns = self.get_api_bns().await?;
356 let resolver = StaticResolver::new(api_bns);
357 self.resolver_static.store(Arc::new(resolver));
358
359 Ok(())
360 }
361}
362
363#[derive(Debug, Clone)]
365pub struct SingleResolver(IpAddr);
366impl CloneableDnsResolver for SingleResolver {}
367
368impl SingleResolver {
369 pub const fn new(addr: IpAddr) -> Self {
370 Self(addr)
371 }
372}
373
374impl Resolve for SingleResolver {
376 fn resolve(&self, _name: Name) -> Resolving {
377 let addr = self.0;
378
379 Box::pin(async move {
380 Ok(Box::new(SocketAddrs {
381 iter: Box::new(vec![addr].into_iter()),
382 }) as Addrs)
383 })
384 }
385}
386
#[cfg(test)]
mod test {
    use std::net::{Ipv4Addr, SocketAddr};

    use super::*;

    /// `Protocol` parsing: default ports, explicit ports and malformed ports.
    #[test]
    fn test_dns_protocol() {
        // Bare protocol names fall back to their default ports,
        // an explicit `:port` suffix overrides them.
        let accepted = [
            ("clear", Protocol::Clear(53)),
            ("tls", Protocol::Tls(853)),
            ("https", Protocol::Https(443)),
            ("clear:8053", Protocol::Clear(8053)),
            ("tls:8853", Protocol::Tls(8853)),
            ("https:8443", Protocol::Https(8443)),
        ];
        for (input, expected) in accepted {
            assert_eq!(Protocol::from_str(input).unwrap(), expected);
        }

        // Empty, non-numeric, negative and out-of-range ports are rejected.
        for input in ["clear:", "clear:x", "clear:-1", "clear:65537"] {
            assert!(Protocol::from_str(input).is_err());
        }
    }

    /// `SingleResolver` returns its address (with port 0) for any name.
    #[tokio::test]
    async fn test_single_resolver() {
        let expected_ip = IpAddr::V4(Ipv4Addr::new(1, 2, 3, 4));
        let resolver = SingleResolver::new(expected_ip);
        let query = Name::from_str("foo.bar").unwrap();

        let mut resolved = resolver.resolve(query).await.unwrap();
        assert_eq!(resolved.next(), Some(SocketAddr::new(expected_ip, 0)));
    }
}