Skip to main content

ic_bn_lib/http/dns/
mod.rs

1use core::task;
2use std::{
3    collections::BTreeMap, fmt::Debug, net::IpAddr, pin::Pin, str::FromStr, sync::Arc, task::Poll,
4};
5
6use anyhow::Context;
7use arc_swap::ArcSwap;
8use async_trait::async_trait;
9use candid::Principal;
10use hickory_proto::rr::{Record, RecordType};
11use hickory_resolver::{
12    TokioResolver,
13    net::{DnsError as HickoryDnsError, NetError, runtime::TokioRuntimeProvider},
14};
15use hyper_util::client::legacy::connect::dns::Name as HyperName;
16use ic_agent::Agent;
17use ic_bn_lib_common::{
18    principal,
19    traits::{
20        Run,
21        dns::{CloneableDnsResolver, CloneableHyperDnsResolver, HyperDnsResolver, Resolves},
22    },
23    types::{
24        dns::{Options, SocketAddrs},
25        http::Error,
26    },
27};
28use reqwest::dns::{Addrs, Name, Resolve, Resolving};
29use tokio_util::sync::CancellationToken;
30use tower::Service;
31
32#[derive(thiserror::Error, Debug)]
33pub enum DnsError {
34    #[error("Resolver error: {0}")]
35    Resolver(#[from] NetError),
36    #[error("{0}")]
37    Other(#[from] anyhow::Error),
38}
39
40/// Checks if given Hickory error means there was a negative lookup
41pub fn is_error_negative_lookup(e: &NetError) -> bool {
42    e.is_no_records_found()
43        || e.is_nx_domain()
44        || matches!(e, NetError::Dns(HickoryDnsError::Nsec { .. }))
45}
46
47/// DNS-resolver based on Hickory
48#[derive(Debug, Clone)]
49pub struct Resolver(Arc<TokioResolver>);
50impl CloneableDnsResolver for Resolver {}
51impl HyperDnsResolver for Resolver {}
52impl CloneableHyperDnsResolver for Resolver {}
53
54impl Resolver {
55    /// Creates a new resolver with given options.
56    /// It must be called in Tokio context.
57    pub fn new(o: Options) -> Result<Self, NetError> {
58        let resolver = TokioResolver::builder_with_config(o.cfg, TokioRuntimeProvider::default())
59            .with_options(o.opts)
60            .build()?;
61
62        Ok(Self(Arc::new(resolver)))
63    }
64}
65
66impl Default for Resolver {
67    fn default() -> Self {
68        // SAFETY: the defaults shouldn't fail
69        Self::new(Options::default()).unwrap()
70    }
71}
72
73// Implement resolving for Reqwest
74// Clippy being stupid here (lifetimes issue).
75#[allow(clippy::needless_collect)]
76impl Resolve for Resolver {
77    fn resolve(&self, name: Name) -> Resolving {
78        let resolver = self.clone();
79
80        Box::pin(async move {
81            let lookup = resolver.0.lookup_ip(name.as_str()).await?;
82            let addresses: Vec<IpAddr> = lookup.iter().collect();
83
84            Ok(Box::new(SocketAddrs {
85                iter: Box::new(addresses.into_iter()),
86            }) as Addrs)
87        })
88    }
89}
90
91#[async_trait]
92impl Resolves for Resolver {
93    async fn resolve(&self, record_type: RecordType, name: &str) -> Result<Vec<Record>, NetError> {
94        let lookup = self.0.lookup(name, record_type).await?;
95        Ok(lookup.answers().to_vec())
96    }
97
98    fn flush_cache(&self) {
99        self.0.clear_cache();
100    }
101}
102
103// Implement resolving for Hyper
104// Clippy being stupid here (lifetimes issue).
105#[allow(clippy::needless_collect)]
106impl Service<HyperName> for Resolver {
107    type Response = SocketAddrs;
108    type Error = Error;
109    #[allow(clippy::type_complexity)]
110    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
111
112    fn poll_ready(&mut self, _cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
113        Poll::Ready(Ok(()))
114    }
115
116    fn call(&mut self, name: HyperName) -> Self::Future {
117        let resolver = self.0.clone();
118
119        Box::pin(async move {
120            let lookup = resolver
121                .lookup_ip(name.as_str())
122                .await
123                .map_err(|e: NetError| Error::DnsError(e.to_string()))?;
124            let addresses: Vec<IpAddr> = lookup.iter().collect();
125
126            Ok(SocketAddrs {
127                iter: Box::new(addresses.into_iter()),
128            })
129        })
130    }
131}
132
133/// Resolver that always resolves the predefined hostname instead of provided one.
134/// Wraps `Resolver`.
135#[derive(Debug, Clone)]
136pub struct FixedResolver(Resolver, String, HyperName);
137impl CloneableDnsResolver for FixedResolver {}
138impl HyperDnsResolver for FixedResolver {}
139impl CloneableHyperDnsResolver for FixedResolver {}
140
141impl FixedResolver {
142    pub fn new(o: Options, name: String) -> Result<Self, DnsError> {
143        let resolver = Resolver::new(o)?;
144        let hyper_name = HyperName::from_str(&name).context("unable to parse name")?;
145
146        Ok(Self(resolver, name, hyper_name))
147    }
148}
149
150/// Implement resolving for Reqwest
151impl Resolve for FixedResolver {
152    fn resolve(&self, _name: Name) -> Resolving {
153        // Name cannot be cloned so we have to parse it each time.
154        // If new() succeeded then this will always succeed too.
155        let name = Name::from_str(&self.1).unwrap();
156        reqwest::dns::Resolve::resolve(&self.0, name)
157    }
158}
159
160/// Implement resolving for Hyper
161impl Service<HyperName> for FixedResolver {
162    type Response = SocketAddrs;
163    type Error = Error;
164    #[allow(clippy::type_complexity)]
165    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
166
167    fn poll_ready(&mut self, _cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
168        Poll::Ready(Ok(()))
169    }
170
171    fn call(&mut self, _name: HyperName) -> Self::Future {
172        self.0.call(self.2.clone())
173    }
174}
175
176/// Resolver that resolves from the provided mappings
177#[derive(Debug, Clone)]
178pub struct StaticResolver(Arc<BTreeMap<String, Vec<IpAddr>>>);
179impl CloneableDnsResolver for StaticResolver {}
180impl HyperDnsResolver for StaticResolver {}
181impl CloneableHyperDnsResolver for StaticResolver {}
182
183impl StaticResolver {
184    pub fn new(items: impl IntoIterator<Item = (String, Vec<IpAddr>)>) -> Self {
185        Self(Arc::new(BTreeMap::from_iter(items)))
186    }
187
188    pub fn lookup(&self, name: &str) -> Option<Vec<IpAddr>> {
189        self.0.get(name).cloned()
190    }
191}
192
193/// Implement resolving for Reqwest
194impl Resolve for StaticResolver {
195    fn resolve(&self, name: Name) -> Resolving {
196        let addrs = self.lookup(name.as_str()).unwrap_or_default();
197
198        Box::pin(async move {
199            Ok(Box::new(SocketAddrs {
200                iter: Box::new(addrs.into_iter()),
201            }) as Addrs)
202        })
203    }
204}
205
206/// Implement resolving for Hyper
207impl Service<HyperName> for StaticResolver {
208    type Response = SocketAddrs;
209    type Error = Error;
210    #[allow(clippy::type_complexity)]
211    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
212
213    fn poll_ready(&mut self, _cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
214        Poll::Ready(Ok(()))
215    }
216
217    fn call(&mut self, name: HyperName) -> Self::Future {
218        let addrs = self.lookup(name.as_str()).unwrap_or_default();
219
220        Box::pin(async move {
221            Ok(SocketAddrs {
222                iter: Box::new(addrs.into_iter()),
223            })
224        })
225    }
226}
227
228/// Resolver that resolves the API BN IPs using the registry.
229/// If the registry doesn't contain the requested host - use the normal fallback
230/// DNS resolver to look it up.
231#[derive(Debug, Clone)]
232pub struct ApiBnResolver {
233    agent: Agent,
234    subnet: Principal,
235    resolver_static: Arc<ArcSwap<StaticResolver>>,
236    resolver_fallback: Resolver,
237}
238impl CloneableDnsResolver for ApiBnResolver {}
239impl HyperDnsResolver for ApiBnResolver {}
240impl CloneableHyperDnsResolver for ApiBnResolver {}
241
242impl ApiBnResolver {
243    pub fn new(resolver_fallback: Resolver, agent: Agent) -> Self {
244        let resolver_static = Arc::new(ArcSwap::new(Arc::new(StaticResolver::new(vec![]))));
245        let subnet = principal!("tdb26-jop6k-aogll-7ltgs-eruif-6kk7m-qpktf-gdiqx-mxtrf-vb5e6-eqe");
246
247        Self {
248            agent,
249            subnet,
250            resolver_static,
251            resolver_fallback,
252        }
253    }
254
255    /// Gets a list of API BN domains and their IP addresses from the registry
256    async fn get_api_bns(&self) -> Result<Vec<(String, Vec<IpAddr>)>, Error> {
257        let api_bns = self
258            .agent
259            .fetch_api_boundary_nodes_by_subnet_id(self.subnet)
260            .await
261            .context("unable to get API BNs from IC")?;
262
263        let mut r = Vec::with_capacity(api_bns.len());
264        for n in api_bns {
265            let ipv6 = IpAddr::from_str(&n.ipv6_address)
266                .context(format!("unable to parse IPv6 address for {}", n.domain))?;
267            let mut addrs = vec![ipv6];
268
269            // See if there's an IPv4 too
270            if let Some(v) = n.ipv4_address {
271                let ipv4 = IpAddr::from_str(&v)
272                    .context(format!("unable to parse IPv4 address for {}", n.domain))?;
273                addrs.push(ipv4);
274            }
275
276            r.push((n.domain, addrs));
277        }
278
279        Ok(r)
280    }
281}
282
283/// Implement resolving for Reqwest
284impl Resolve for ApiBnResolver {
285    fn resolve(&self, name: Name) -> Resolving {
286        let api_bns = self.resolver_static.load_full().lookup(name.as_str());
287        let resolver_fallback = self.resolver_fallback.clone();
288
289        Box::pin(async move {
290            let addrs = match api_bns {
291                Some(v) => v,
292                None => {
293                    // Look up using a fallback resolver if nothing was found in the static one
294                    resolver_fallback
295                        .0
296                        .lookup_ip(name.as_str())
297                        .await
298                        .map_err(|e| Error::DnsError(e.to_string()))?
299                        .iter()
300                        .collect()
301                }
302            };
303
304            Ok(Box::new(SocketAddrs {
305                iter: Box::new(addrs.into_iter()),
306            }) as Addrs)
307        })
308    }
309}
310
311/// Implement resolving for Hyper
312impl Service<HyperName> for ApiBnResolver {
313    type Response = SocketAddrs;
314    type Error = Error;
315    #[allow(clippy::type_complexity)]
316    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
317
318    fn poll_ready(&mut self, _cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
319        Poll::Ready(Ok(()))
320    }
321
322    fn call(&mut self, name: HyperName) -> Self::Future {
323        let api_bns = self.resolver_static.load_full().lookup(name.as_str());
324        let resolver_fallback = self.resolver_fallback.clone();
325
326        Box::pin(async move {
327            let addrs = match api_bns {
328                Some(v) => v,
329                None => {
330                    // Look up using a fallback resolver if nothing was found in the static one
331                    resolver_fallback
332                        .0
333                        .lookup_ip(name.as_str())
334                        .await
335                        .map_err(|e| Error::DnsError(e.to_string()))?
336                        .iter()
337                        .collect()
338                }
339            };
340
341            Ok(SocketAddrs {
342                iter: Box::new(addrs.into_iter()),
343            })
344        })
345    }
346}
347
348#[async_trait]
349impl Run for ApiBnResolver {
350    async fn run(&self, _token: CancellationToken) -> Result<(), anyhow::Error> {
351        let api_bns = self.get_api_bns().await?;
352        let resolver = StaticResolver::new(api_bns);
353        self.resolver_static.store(Arc::new(resolver));
354
355        Ok(())
356    }
357}
358
359/// Resolver that resolves all hostnames to the single IP address
360#[derive(Debug, Clone)]
361pub struct SingleResolver(IpAddr);
362impl CloneableDnsResolver for SingleResolver {}
363
364impl SingleResolver {
365    pub const fn new(addr: IpAddr) -> Self {
366        Self(addr)
367    }
368}
369
370/// Implement resolving for Reqwest
371impl Resolve for SingleResolver {
372    fn resolve(&self, _name: Name) -> Resolving {
373        let addr = self.0;
374
375        Box::pin(async move {
376            Ok(Box::new(SocketAddrs {
377                iter: Box::new(vec![addr].into_iter()),
378            }) as Addrs)
379        })
380    }
381}
382
#[cfg(test)]
mod test {
    use std::net::{Ipv4Addr, SocketAddr};

    use ic_bn_lib_common::types::dns::Protocol;

    use super::*;

    #[test]
    fn test_dns_protocol() {
        // Bare protocol names get their default port; `name:port` overrides it.
        let valid = [
            ("clear", Protocol::Clear(53)),
            ("tls", Protocol::Tls(853)),
            ("https", Protocol::Https(443)),
            ("clear:8053", Protocol::Clear(8053)),
            ("tls:8853", Protocol::Tls(8853)),
            ("https:8443", Protocol::Https(8443)),
        ];
        for (input, expected) in valid {
            assert_eq!(Protocol::from_str(input).unwrap(), expected, "input: {input}");
        }

        // Malformed ports: empty, non-numeric, negative and out of range.
        for input in ["clear:", "clear:x", "clear:-1", "clear:65537"] {
            assert!(Protocol::from_str(input).is_err(), "input: {input}");
        }
    }

    #[tokio::test]
    async fn test_single_resolver() {
        let ip = IpAddr::V4(Ipv4Addr::new(1, 2, 3, 4));
        let single = SingleResolver::new(ip);

        let name = Name::from_str("foo.bar").unwrap();
        let mut resolved = single.resolve(name).await.unwrap();

        // Reqwest address iterators carry port 0; the caller fills in the real port.
        assert_eq!(resolved.next(), Some(SocketAddr::new(ip, 0)));
    }
}