ic_bn_lib/http/dns/
mod.rs

1use core::task;
2use std::{
3    collections::BTreeMap, fmt::Debug, net::IpAddr, pin::Pin, str::FromStr, sync::Arc, task::Poll,
4};
5
6use anyhow::Context;
7use arc_swap::ArcSwap;
8use async_trait::async_trait;
9use candid::Principal;
10use hickory_proto::rr::{Record, RecordType};
11use hickory_resolver::{
12    ResolveError, TokioResolver,
13    config::{NameServerConfigGroup, ResolveHosts, ResolverConfig, ResolverOpts},
14    name_server::TokioConnectionProvider,
15};
16use hyper_util::client::legacy::connect::dns::Name as HyperName;
17use ic_agent::Agent;
18use ic_bn_lib_common::{
19    principal,
20    traits::{
21        Run,
22        dns::{CloneableDnsResolver, CloneableHyperDnsResolver, HyperDnsResolver, Resolves},
23    },
24    types::{
25        dns::{Options, Protocol, SocketAddrs},
26        http::Error,
27    },
28};
29use reqwest::dns::{Addrs, Name, Resolve, Resolving};
30use tokio_util::sync::CancellationToken;
31use tower::Service;
32
33/// DNS-resolver based on Hickory
34#[derive(Debug, Clone)]
35pub struct Resolver(Arc<TokioResolver>);
36impl CloneableDnsResolver for Resolver {}
37impl HyperDnsResolver for Resolver {}
38impl CloneableHyperDnsResolver for Resolver {}
39
40impl Resolver {
41    /// Creates a new resolver with given options.
42    /// It must be called in Tokio context.
43    pub fn new(o: Options) -> Self {
44        let name_servers = match o.protocol {
45            Protocol::Clear(p) => NameServerConfigGroup::from_ips_clear(&o.servers, p, true),
46            Protocol::Tls(p) => {
47                NameServerConfigGroup::from_ips_tls(&o.servers, p, o.tls_name, true)
48            }
49            Protocol::Https(p) => {
50                NameServerConfigGroup::from_ips_https(&o.servers, p, o.tls_name, true)
51            }
52        };
53
54        let cfg = ResolverConfig::from_parts(None, vec![], name_servers);
55
56        let mut opts = ResolverOpts::default();
57        opts.cache_size = o.cache_size;
58        opts.timeout = o.timeout;
59        opts.ip_strategy = o.lookup_ip_strategy;
60        opts.use_hosts_file = ResolveHosts::Never;
61        opts.preserve_intermediates = false;
62        opts.try_tcp_on_error = true;
63
64        let builder = TokioResolver::builder_with_config(cfg, TokioConnectionProvider::default())
65            .with_options(opts);
66
67        Self(Arc::new(builder.build()))
68    }
69}
70
71impl Default for Resolver {
72    fn default() -> Self {
73        Self::new(Options::default())
74    }
75}
76
77// Implement resolving for Reqwest
78impl Resolve for Resolver {
79    fn resolve(&self, name: Name) -> Resolving {
80        let resolver = self.clone();
81
82        Box::pin(async move {
83            let lookup = resolver.0.lookup_ip(name.as_str()).await?;
84            let addrs: Addrs = Box::new(SocketAddrs {
85                iter: Box::new(lookup.into_iter()),
86            });
87
88            Ok(addrs)
89        })
90    }
91}
92
93#[async_trait]
94impl Resolves for Resolver {
95    async fn resolve(
96        &self,
97        record_type: RecordType,
98        name: &str,
99    ) -> Result<Vec<Record>, ResolveError> {
100        let lookup = self.0.lookup(name, record_type).await?;
101        Ok(lookup.records().to_vec())
102    }
103
104    fn flush_cache(&self) {
105        self.0.clear_cache();
106    }
107}
108
109/// Implement resolving for Hyper
110impl Service<HyperName> for Resolver {
111    type Response = SocketAddrs;
112    type Error = Error;
113    #[allow(clippy::type_complexity)]
114    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
115
116    fn poll_ready(&mut self, _cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
117        Poll::Ready(Ok(()))
118    }
119
120    fn call(&mut self, name: HyperName) -> Self::Future {
121        let resolver = self.0.clone();
122
123        Box::pin(async move {
124            let response = resolver
125                .lookup_ip(name.as_str())
126                .await
127                .map_err(|e| Error::DnsError(e.to_string()))?;
128            let addresses = response.into_iter();
129
130            Ok(SocketAddrs {
131                iter: Box::new(addresses),
132            })
133        })
134    }
135}
136
137/// Resolver that always resolves the predefined hostname instead of provided one.
138/// Wraps `Resolver`.
139#[derive(Debug, Clone)]
140pub struct FixedResolver(Resolver, String, HyperName);
141impl CloneableDnsResolver for FixedResolver {}
142impl HyperDnsResolver for FixedResolver {}
143impl CloneableHyperDnsResolver for FixedResolver {}
144
145impl FixedResolver {
146    pub fn new(o: Options, name: String) -> Result<Self, Error> {
147        let resolver = Resolver::new(o);
148        let hyper_name = HyperName::from_str(&name).context("unable to parse name")?;
149
150        Ok(Self(resolver, name, hyper_name))
151    }
152}
153
154/// Implement resolving for Reqwest
155impl Resolve for FixedResolver {
156    fn resolve(&self, _name: Name) -> Resolving {
157        // Name cannot be cloned so we have to parse it each time.
158        // If new() succeeded then this will always succeed too.
159        let name = Name::from_str(&self.1).unwrap();
160        reqwest::dns::Resolve::resolve(&self.0, name)
161    }
162}
163
164/// Implement resolving for Hyper
165impl Service<HyperName> for FixedResolver {
166    type Response = SocketAddrs;
167    type Error = Error;
168    #[allow(clippy::type_complexity)]
169    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
170
171    fn poll_ready(&mut self, _cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
172        Poll::Ready(Ok(()))
173    }
174
175    fn call(&mut self, _name: HyperName) -> Self::Future {
176        self.0.call(self.2.clone())
177    }
178}
179
180/// Resolver that resolves from the provided mappings
181#[derive(Debug, Clone)]
182pub struct StaticResolver(Arc<BTreeMap<String, Vec<IpAddr>>>);
183impl CloneableDnsResolver for StaticResolver {}
184impl HyperDnsResolver for StaticResolver {}
185impl CloneableHyperDnsResolver for StaticResolver {}
186
187impl StaticResolver {
188    pub fn new(items: impl IntoIterator<Item = (String, Vec<IpAddr>)>) -> Self {
189        Self(Arc::new(BTreeMap::from_iter(items)))
190    }
191
192    pub fn lookup(&self, name: &str) -> Option<Vec<IpAddr>> {
193        self.0.get(name).cloned()
194    }
195}
196
197/// Implement resolving for Reqwest
198impl Resolve for StaticResolver {
199    fn resolve(&self, name: Name) -> Resolving {
200        let addrs = self.lookup(name.as_str()).unwrap_or_default();
201
202        Box::pin(async move {
203            Ok(Box::new(SocketAddrs {
204                iter: Box::new(addrs.into_iter()),
205            }) as Addrs)
206        })
207    }
208}
209
210/// Implement resolving for Hyper
211impl Service<HyperName> for StaticResolver {
212    type Response = SocketAddrs;
213    type Error = Error;
214    #[allow(clippy::type_complexity)]
215    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
216
217    fn poll_ready(&mut self, _cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
218        Poll::Ready(Ok(()))
219    }
220
221    fn call(&mut self, name: HyperName) -> Self::Future {
222        let addrs = self.lookup(name.as_str()).unwrap_or_default();
223
224        Box::pin(async move {
225            Ok(SocketAddrs {
226                iter: Box::new(addrs.into_iter()),
227            })
228        })
229    }
230}
231
232/// Resolver that resolves the API BN IPs using the registry.
233/// If the registry doesn't contain the requested host - use the normal fallback
234/// DNS resolver to look it up.
235#[derive(Debug, Clone)]
236pub struct ApiBnResolver {
237    agent: Agent,
238    subnet: Principal,
239    resolver_static: Arc<ArcSwap<StaticResolver>>,
240    resolver_fallback: Resolver,
241}
242impl CloneableDnsResolver for ApiBnResolver {}
243impl HyperDnsResolver for ApiBnResolver {}
244impl CloneableHyperDnsResolver for ApiBnResolver {}
245
246impl ApiBnResolver {
247    pub fn new(resolver_fallback: Resolver, agent: Agent) -> Self {
248        let resolver_static = Arc::new(ArcSwap::new(Arc::new(StaticResolver::new(vec![]))));
249        let subnet = principal!("tdb26-jop6k-aogll-7ltgs-eruif-6kk7m-qpktf-gdiqx-mxtrf-vb5e6-eqe");
250
251        Self {
252            agent,
253            subnet,
254            resolver_static,
255            resolver_fallback,
256        }
257    }
258
259    /// Gets a list of API BN domains and their IP addresses from the registry
260    async fn get_api_bns(&self) -> Result<Vec<(String, Vec<IpAddr>)>, Error> {
261        let api_bns = self
262            .agent
263            .fetch_api_boundary_nodes_by_subnet_id(self.subnet)
264            .await
265            .context("unable to get API BNs from IC")?;
266
267        let mut r = Vec::with_capacity(api_bns.len());
268        for n in api_bns {
269            let ipv6 = IpAddr::from_str(&n.ipv6_address)
270                .context(format!("unable to parse IPv6 address for {}", n.domain))?;
271            let mut addrs = vec![ipv6];
272
273            // See if there's an IPv4 too
274            if let Some(v) = n.ipv4_address {
275                let ipv4 = IpAddr::from_str(&v)
276                    .context(format!("unable to parse IPv4 address for {}", n.domain))?;
277                addrs.push(ipv4);
278            }
279
280            r.push((n.domain, addrs));
281        }
282
283        Ok(r)
284    }
285}
286
287/// Implement resolving for Reqwest
288impl Resolve for ApiBnResolver {
289    fn resolve(&self, name: Name) -> Resolving {
290        let api_bns = self.resolver_static.load_full().lookup(name.as_str());
291        let resolver_fallback = self.resolver_fallback.clone();
292
293        Box::pin(async move {
294            let addrs = match api_bns {
295                Some(v) => v,
296                None => {
297                    // Look up using a fallback resolver if nothing was found in the static one
298                    resolver_fallback
299                        .0
300                        .lookup_ip(name.as_str())
301                        .await
302                        .map_err(|e| Error::DnsError(e.to_string()))?
303                        .into_iter()
304                        .collect()
305                }
306            };
307
308            Ok(Box::new(SocketAddrs {
309                iter: Box::new(addrs.into_iter()),
310            }) as Addrs)
311        })
312    }
313}
314
315/// Implement resolving for Hyper
316impl Service<HyperName> for ApiBnResolver {
317    type Response = SocketAddrs;
318    type Error = Error;
319    #[allow(clippy::type_complexity)]
320    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
321
322    fn poll_ready(&mut self, _cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
323        Poll::Ready(Ok(()))
324    }
325
326    fn call(&mut self, name: HyperName) -> Self::Future {
327        let api_bns = self.resolver_static.load_full().lookup(name.as_str());
328        let resolver_fallback = self.resolver_fallback.clone();
329
330        Box::pin(async move {
331            let addrs = match api_bns {
332                Some(v) => v,
333                None => {
334                    // Look up using a fallback resolver if nothing was found in the static one
335                    resolver_fallback
336                        .0
337                        .lookup_ip(name.as_str())
338                        .await
339                        .map_err(|e| Error::DnsError(e.to_string()))?
340                        .into_iter()
341                        .collect()
342                }
343            };
344
345            Ok(SocketAddrs {
346                iter: Box::new(addrs.into_iter()),
347            })
348        })
349    }
350}
351
352#[async_trait]
353impl Run for ApiBnResolver {
354    async fn run(&self, _token: CancellationToken) -> Result<(), anyhow::Error> {
355        let api_bns = self.get_api_bns().await?;
356        let resolver = StaticResolver::new(api_bns);
357        self.resolver_static.store(Arc::new(resolver));
358
359        Ok(())
360    }
361}
362
363/// Resolver that resolves all hostnames to the single IP address
364#[derive(Debug, Clone)]
365pub struct SingleResolver(IpAddr);
366impl CloneableDnsResolver for SingleResolver {}
367
368impl SingleResolver {
369    pub const fn new(addr: IpAddr) -> Self {
370        Self(addr)
371    }
372}
373
374/// Implement resolving for Reqwest
375impl Resolve for SingleResolver {
376    fn resolve(&self, _name: Name) -> Resolving {
377        let addr = self.0;
378
379        Box::pin(async move {
380            Ok(Box::new(SocketAddrs {
381                iter: Box::new(vec![addr].into_iter()),
382            }) as Addrs)
383        })
384    }
385}
386
#[cfg(test)]
mod test {
    use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr};

    use super::*;

    #[test]
    fn test_dns_protocol() {
        assert_eq!(Protocol::from_str("clear").unwrap(), Protocol::Clear(53));
        assert_eq!(Protocol::from_str("tls").unwrap(), Protocol::Tls(853));
        assert_eq!(Protocol::from_str("https").unwrap(), Protocol::Https(443));

        assert_eq!(
            Protocol::from_str("clear:8053").unwrap(),
            Protocol::Clear(8053)
        );
        assert_eq!(Protocol::from_str("tls:8853").unwrap(), Protocol::Tls(8853));
        assert_eq!(
            Protocol::from_str("https:8443").unwrap(),
            Protocol::Https(8443)
        );

        assert!(Protocol::from_str("clear:").is_err());
        assert!(Protocol::from_str("clear:x").is_err());
        assert!(Protocol::from_str("clear:-1").is_err());
        assert!(Protocol::from_str("clear:65537").is_err());
    }

    #[tokio::test]
    async fn test_single_resolver() {
        let addr = IpAddr::V4(Ipv4Addr::new(1, 2, 3, 4));
        let resolver = SingleResolver::new(addr);

        let mut res = resolver
            .resolve(Name::from_str("foo.bar").unwrap())
            .await
            .unwrap();
        assert_eq!(res.next(), Some(SocketAddr::new(addr, 0)));
    }

    #[tokio::test]
    async fn test_static_resolver() {
        let v4 = IpAddr::V4(Ipv4Addr::new(1, 2, 3, 4));
        let v6 = IpAddr::V6(Ipv6Addr::LOCALHOST);
        let resolver = StaticResolver::new(vec![("foo.bar".to_string(), vec![v6, v4])]);

        // Direct lookups
        assert_eq!(resolver.lookup("foo.bar"), Some(vec![v6, v4]));
        assert_eq!(resolver.lookup("baz.bar"), None);

        // Known name: all configured addresses, in order
        let mut res = resolver
            .resolve(Name::from_str("foo.bar").unwrap())
            .await
            .unwrap();
        assert_eq!(res.next(), Some(SocketAddr::new(v6, 0)));
        assert_eq!(res.next(), Some(SocketAddr::new(v4, 0)));
        assert_eq!(res.next(), None);

        // Unknown name: empty result rather than an error
        let mut res = resolver
            .resolve(Name::from_str("baz.bar").unwrap())
            .await
            .unwrap();
        assert_eq!(res.next(), None);
    }
}