Skip to main content

ic_bn_lib/http/dns/
mod.rs

1use core::task;
2use std::{
3    collections::BTreeMap, fmt::Debug, net::IpAddr, pin::Pin, str::FromStr, sync::Arc, task::Poll,
4};
5
6use anyhow::Context;
7use arc_swap::ArcSwap;
8use async_trait::async_trait;
9use candid::Principal;
10use hickory_proto::rr::{Record, RecordType};
11use hickory_resolver::{
12    ResolveError, TokioResolver,
13    config::{NameServerConfigGroup, ResolveHosts, ResolverConfig, ResolverOpts},
14    name_server::TokioConnectionProvider,
15};
16use hyper_util::client::legacy::connect::dns::Name as HyperName;
17use ic_agent::Agent;
18use ic_bn_lib_common::{
19    principal,
20    traits::{
21        Run,
22        dns::{CloneableDnsResolver, CloneableHyperDnsResolver, HyperDnsResolver, Resolves},
23    },
24    types::{
25        dns::{Options, Protocol, SocketAddrs},
26        http::Error,
27    },
28};
29use reqwest::dns::{Addrs, Name, Resolve, Resolving};
30use tokio_util::sync::CancellationToken;
31use tower::Service;
32
33/// DNS-resolver based on Hickory
34#[derive(Debug, Clone)]
35pub struct Resolver(Arc<TokioResolver>);
36impl CloneableDnsResolver for Resolver {}
37impl HyperDnsResolver for Resolver {}
38impl CloneableHyperDnsResolver for Resolver {}
39
40impl Resolver {
41    /// Creates a new resolver with given options.
42    /// It must be called in Tokio context.
43    pub fn new(o: Options) -> Self {
44        let name_servers = match o.protocol {
45            Protocol::Clear(p) => NameServerConfigGroup::from_ips_clear(&o.servers, p, true),
46            Protocol::Tls(p) => {
47                NameServerConfigGroup::from_ips_tls(&o.servers, p, o.tls_name, true)
48            }
49            Protocol::Https(p) => {
50                NameServerConfigGroup::from_ips_https(&o.servers, p, o.tls_name, true)
51            }
52        };
53
54        let cfg = ResolverConfig::from_parts(None, vec![], name_servers);
55
56        let mut opts = ResolverOpts::default();
57        opts.cache_size = o.cache_size;
58        opts.timeout = o.timeout;
59        opts.ip_strategy = o.lookup_ip_strategy;
60        opts.use_hosts_file = ResolveHosts::Never;
61        opts.preserve_intermediates = false;
62        opts.validate = !o.dnssec_disabled;
63        opts.try_tcp_on_error = true;
64
65        let builder = TokioResolver::builder_with_config(cfg, TokioConnectionProvider::default())
66            .with_options(opts);
67
68        Self(Arc::new(builder.build()))
69    }
70}
71
72impl Default for Resolver {
73    fn default() -> Self {
74        Self::new(Options::default())
75    }
76}
77
78// Implement resolving for Reqwest
79impl Resolve for Resolver {
80    fn resolve(&self, name: Name) -> Resolving {
81        let resolver = self.clone();
82
83        Box::pin(async move {
84            let lookup = resolver.0.lookup_ip(name.as_str()).await?;
85            let addrs: Addrs = Box::new(SocketAddrs {
86                iter: Box::new(lookup.into_iter()),
87            });
88
89            Ok(addrs)
90        })
91    }
92}
93
94#[async_trait]
95impl Resolves for Resolver {
96    async fn resolve(
97        &self,
98        record_type: RecordType,
99        name: &str,
100    ) -> Result<Vec<Record>, ResolveError> {
101        let lookup = self.0.lookup(name, record_type).await?;
102        Ok(lookup.records().to_vec())
103    }
104
105    fn flush_cache(&self) {
106        self.0.clear_cache();
107    }
108}
109
110/// Implement resolving for Hyper
111impl Service<HyperName> for Resolver {
112    type Response = SocketAddrs;
113    type Error = Error;
114    #[allow(clippy::type_complexity)]
115    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
116
117    fn poll_ready(&mut self, _cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
118        Poll::Ready(Ok(()))
119    }
120
121    fn call(&mut self, name: HyperName) -> Self::Future {
122        let resolver = self.0.clone();
123
124        Box::pin(async move {
125            let response = resolver
126                .lookup_ip(name.as_str())
127                .await
128                .map_err(|e| Error::DnsError(e.to_string()))?;
129            let addresses = response.into_iter();
130
131            Ok(SocketAddrs {
132                iter: Box::new(addresses),
133            })
134        })
135    }
136}
137
138/// Resolver that always resolves the predefined hostname instead of provided one.
139/// Wraps `Resolver`.
140#[derive(Debug, Clone)]
141pub struct FixedResolver(Resolver, String, HyperName);
142impl CloneableDnsResolver for FixedResolver {}
143impl HyperDnsResolver for FixedResolver {}
144impl CloneableHyperDnsResolver for FixedResolver {}
145
146impl FixedResolver {
147    pub fn new(o: Options, name: String) -> Result<Self, Error> {
148        let resolver = Resolver::new(o);
149        let hyper_name = HyperName::from_str(&name).context("unable to parse name")?;
150
151        Ok(Self(resolver, name, hyper_name))
152    }
153}
154
155/// Implement resolving for Reqwest
156impl Resolve for FixedResolver {
157    fn resolve(&self, _name: Name) -> Resolving {
158        // Name cannot be cloned so we have to parse it each time.
159        // If new() succeeded then this will always succeed too.
160        let name = Name::from_str(&self.1).unwrap();
161        reqwest::dns::Resolve::resolve(&self.0, name)
162    }
163}
164
165/// Implement resolving for Hyper
166impl Service<HyperName> for FixedResolver {
167    type Response = SocketAddrs;
168    type Error = Error;
169    #[allow(clippy::type_complexity)]
170    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
171
172    fn poll_ready(&mut self, _cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
173        Poll::Ready(Ok(()))
174    }
175
176    fn call(&mut self, _name: HyperName) -> Self::Future {
177        self.0.call(self.2.clone())
178    }
179}
180
181/// Resolver that resolves from the provided mappings
182#[derive(Debug, Clone)]
183pub struct StaticResolver(Arc<BTreeMap<String, Vec<IpAddr>>>);
184impl CloneableDnsResolver for StaticResolver {}
185impl HyperDnsResolver for StaticResolver {}
186impl CloneableHyperDnsResolver for StaticResolver {}
187
188impl StaticResolver {
189    pub fn new(items: impl IntoIterator<Item = (String, Vec<IpAddr>)>) -> Self {
190        Self(Arc::new(BTreeMap::from_iter(items)))
191    }
192
193    pub fn lookup(&self, name: &str) -> Option<Vec<IpAddr>> {
194        self.0.get(name).cloned()
195    }
196}
197
198/// Implement resolving for Reqwest
199impl Resolve for StaticResolver {
200    fn resolve(&self, name: Name) -> Resolving {
201        let addrs = self.lookup(name.as_str()).unwrap_or_default();
202
203        Box::pin(async move {
204            Ok(Box::new(SocketAddrs {
205                iter: Box::new(addrs.into_iter()),
206            }) as Addrs)
207        })
208    }
209}
210
211/// Implement resolving for Hyper
212impl Service<HyperName> for StaticResolver {
213    type Response = SocketAddrs;
214    type Error = Error;
215    #[allow(clippy::type_complexity)]
216    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
217
218    fn poll_ready(&mut self, _cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
219        Poll::Ready(Ok(()))
220    }
221
222    fn call(&mut self, name: HyperName) -> Self::Future {
223        let addrs = self.lookup(name.as_str()).unwrap_or_default();
224
225        Box::pin(async move {
226            Ok(SocketAddrs {
227                iter: Box::new(addrs.into_iter()),
228            })
229        })
230    }
231}
232
233/// Resolver that resolves the API BN IPs using the registry.
234/// If the registry doesn't contain the requested host - use the normal fallback
235/// DNS resolver to look it up.
236#[derive(Debug, Clone)]
237pub struct ApiBnResolver {
238    agent: Agent,
239    subnet: Principal,
240    resolver_static: Arc<ArcSwap<StaticResolver>>,
241    resolver_fallback: Resolver,
242}
243impl CloneableDnsResolver for ApiBnResolver {}
244impl HyperDnsResolver for ApiBnResolver {}
245impl CloneableHyperDnsResolver for ApiBnResolver {}
246
247impl ApiBnResolver {
248    pub fn new(resolver_fallback: Resolver, agent: Agent) -> Self {
249        let resolver_static = Arc::new(ArcSwap::new(Arc::new(StaticResolver::new(vec![]))));
250        let subnet = principal!("tdb26-jop6k-aogll-7ltgs-eruif-6kk7m-qpktf-gdiqx-mxtrf-vb5e6-eqe");
251
252        Self {
253            agent,
254            subnet,
255            resolver_static,
256            resolver_fallback,
257        }
258    }
259
260    /// Gets a list of API BN domains and their IP addresses from the registry
261    async fn get_api_bns(&self) -> Result<Vec<(String, Vec<IpAddr>)>, Error> {
262        let api_bns = self
263            .agent
264            .fetch_api_boundary_nodes_by_subnet_id(self.subnet)
265            .await
266            .context("unable to get API BNs from IC")?;
267
268        let mut r = Vec::with_capacity(api_bns.len());
269        for n in api_bns {
270            let ipv6 = IpAddr::from_str(&n.ipv6_address)
271                .context(format!("unable to parse IPv6 address for {}", n.domain))?;
272            let mut addrs = vec![ipv6];
273
274            // See if there's an IPv4 too
275            if let Some(v) = n.ipv4_address {
276                let ipv4 = IpAddr::from_str(&v)
277                    .context(format!("unable to parse IPv4 address for {}", n.domain))?;
278                addrs.push(ipv4);
279            }
280
281            r.push((n.domain, addrs));
282        }
283
284        Ok(r)
285    }
286}
287
288/// Implement resolving for Reqwest
289impl Resolve for ApiBnResolver {
290    fn resolve(&self, name: Name) -> Resolving {
291        let api_bns = self.resolver_static.load_full().lookup(name.as_str());
292        let resolver_fallback = self.resolver_fallback.clone();
293
294        Box::pin(async move {
295            let addrs = match api_bns {
296                Some(v) => v,
297                None => {
298                    // Look up using a fallback resolver if nothing was found in the static one
299                    resolver_fallback
300                        .0
301                        .lookup_ip(name.as_str())
302                        .await
303                        .map_err(|e| Error::DnsError(e.to_string()))?
304                        .into_iter()
305                        .collect()
306                }
307            };
308
309            Ok(Box::new(SocketAddrs {
310                iter: Box::new(addrs.into_iter()),
311            }) as Addrs)
312        })
313    }
314}
315
316/// Implement resolving for Hyper
317impl Service<HyperName> for ApiBnResolver {
318    type Response = SocketAddrs;
319    type Error = Error;
320    #[allow(clippy::type_complexity)]
321    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
322
323    fn poll_ready(&mut self, _cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
324        Poll::Ready(Ok(()))
325    }
326
327    fn call(&mut self, name: HyperName) -> Self::Future {
328        let api_bns = self.resolver_static.load_full().lookup(name.as_str());
329        let resolver_fallback = self.resolver_fallback.clone();
330
331        Box::pin(async move {
332            let addrs = match api_bns {
333                Some(v) => v,
334                None => {
335                    // Look up using a fallback resolver if nothing was found in the static one
336                    resolver_fallback
337                        .0
338                        .lookup_ip(name.as_str())
339                        .await
340                        .map_err(|e| Error::DnsError(e.to_string()))?
341                        .into_iter()
342                        .collect()
343                }
344            };
345
346            Ok(SocketAddrs {
347                iter: Box::new(addrs.into_iter()),
348            })
349        })
350    }
351}
352
353#[async_trait]
354impl Run for ApiBnResolver {
355    async fn run(&self, _token: CancellationToken) -> Result<(), anyhow::Error> {
356        let api_bns = self.get_api_bns().await?;
357        let resolver = StaticResolver::new(api_bns);
358        self.resolver_static.store(Arc::new(resolver));
359
360        Ok(())
361    }
362}
363
364/// Resolver that resolves all hostnames to the single IP address
365#[derive(Debug, Clone)]
366pub struct SingleResolver(IpAddr);
367impl CloneableDnsResolver for SingleResolver {}
368
369impl SingleResolver {
370    pub const fn new(addr: IpAddr) -> Self {
371        Self(addr)
372    }
373}
374
375/// Implement resolving for Reqwest
376impl Resolve for SingleResolver {
377    fn resolve(&self, _name: Name) -> Resolving {
378        let addr = self.0;
379
380        Box::pin(async move {
381            Ok(Box::new(SocketAddrs {
382                iter: Box::new(vec![addr].into_iter()),
383            }) as Addrs)
384        })
385    }
386}
387
#[cfg(test)]
mod test {
    use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr};

    use super::*;

    #[test]
    fn test_dns_protocol() {
        assert_eq!(Protocol::from_str("clear").unwrap(), Protocol::Clear(53));
        assert_eq!(Protocol::from_str("tls").unwrap(), Protocol::Tls(853));
        assert_eq!(Protocol::from_str("https").unwrap(), Protocol::Https(443));

        assert_eq!(
            Protocol::from_str("clear:8053").unwrap(),
            Protocol::Clear(8053)
        );
        assert_eq!(Protocol::from_str("tls:8853").unwrap(), Protocol::Tls(8853));
        assert_eq!(
            Protocol::from_str("https:8443").unwrap(),
            Protocol::Https(8443)
        );

        assert!(Protocol::from_str("clear:").is_err());
        assert!(Protocol::from_str("clear:x").is_err());
        assert!(Protocol::from_str("clear:-1").is_err());
        assert!(Protocol::from_str("clear:65537").is_err());
    }

    #[tokio::test]
    async fn test_single_resolver() {
        let addr = IpAddr::V4(Ipv4Addr::new(1, 2, 3, 4));
        let resolver = SingleResolver::new(addr);

        let mut res = resolver
            .resolve(Name::from_str("foo.bar").unwrap())
            .await
            .unwrap();
        assert_eq!(res.next(), Some(SocketAddr::new(addr, 0)));
    }

    #[tokio::test]
    async fn test_static_resolver() {
        let v4 = IpAddr::V4(Ipv4Addr::new(1, 2, 3, 4));
        let v6 = IpAddr::V6(Ipv6Addr::LOCALHOST);
        let resolver = StaticResolver::new(vec![("foo.bar".to_string(), vec![v4, v6])]);

        // Direct table lookups
        assert_eq!(resolver.lookup("foo.bar"), Some(vec![v4, v6]));
        assert_eq!(resolver.lookup("unknown"), None);

        // Reqwest path: known names yield the mapped addresses, in order, with port 0
        let res: Vec<SocketAddr> = resolver
            .resolve(Name::from_str("foo.bar").unwrap())
            .await
            .unwrap()
            .collect();
        assert_eq!(res, vec![SocketAddr::new(v4, 0), SocketAddr::new(v6, 0)]);

        // Unknown names resolve to an empty list rather than an error
        let mut res = resolver
            .resolve(Name::from_str("unknown").unwrap())
            .await
            .unwrap();
        assert_eq!(res.next(), None);

        // Hyper path
        let mut svc = resolver.clone();
        let res: Vec<SocketAddr> = svc
            .call(HyperName::from_str("foo.bar").unwrap())
            .await
            .unwrap()
            .collect();
        assert_eq!(res, vec![SocketAddr::new(v4, 0), SocketAddr::new(v6, 0)]);
    }
}