1use core::task;
2use std::{
3 collections::BTreeMap, fmt::Debug, net::IpAddr, pin::Pin, str::FromStr, sync::Arc, task::Poll,
4};
5
6use anyhow::Context;
7use arc_swap::ArcSwap;
8use async_trait::async_trait;
9use candid::Principal;
10use hickory_proto::rr::{Record, RecordType};
11use hickory_resolver::{
12 TokioResolver,
13 net::{DnsError as HickoryDnsError, NetError, runtime::TokioRuntimeProvider},
14};
15use hyper_util::client::legacy::connect::dns::Name as HyperName;
16use ic_agent::Agent;
17use ic_bn_lib_common::{
18 principal,
19 traits::{
20 Run,
21 dns::{CloneableDnsResolver, CloneableHyperDnsResolver, HyperDnsResolver, Resolves},
22 },
23 types::{
24 dns::{Options, SocketAddrs},
25 http::Error,
26 },
27};
28use reqwest::dns::{Addrs, Name, Resolve, Resolving};
29use tokio_util::sync::CancellationToken;
30use tower::Service;
31
32#[derive(thiserror::Error, Debug)]
33pub enum DnsError {
34 #[error("Resolver error: {0}")]
35 Resolver(#[from] NetError),
36 #[error("{0}")]
37 Other(#[from] anyhow::Error),
38}
39
40pub fn is_error_negative_lookup(e: &NetError) -> bool {
42 e.is_no_records_found()
43 || e.is_nx_domain()
44 || matches!(e, NetError::Dns(HickoryDnsError::Nsec { .. }))
45}
46
47#[derive(Debug, Clone)]
49pub struct Resolver(Arc<TokioResolver>);
50impl CloneableDnsResolver for Resolver {}
51impl HyperDnsResolver for Resolver {}
52impl CloneableHyperDnsResolver for Resolver {}
53
54impl Resolver {
55 pub fn new(o: Options) -> Result<Self, NetError> {
58 let resolver = TokioResolver::builder_with_config(o.cfg, TokioRuntimeProvider::default())
59 .with_options(o.opts)
60 .build()?;
61
62 Ok(Self(Arc::new(resolver)))
63 }
64}
65
66impl Default for Resolver {
67 fn default() -> Self {
68 Self::new(Options::default()).unwrap()
70 }
71}
72
73#[allow(clippy::needless_collect)]
76impl Resolve for Resolver {
77 fn resolve(&self, name: Name) -> Resolving {
78 let resolver = self.clone();
79
80 Box::pin(async move {
81 let lookup = resolver.0.lookup_ip(name.as_str()).await?;
82 let addresses: Vec<IpAddr> = lookup.iter().collect();
83
84 Ok(Box::new(SocketAddrs {
85 iter: Box::new(addresses.into_iter()),
86 }) as Addrs)
87 })
88 }
89}
90
91#[async_trait]
92impl Resolves for Resolver {
93 async fn resolve(&self, record_type: RecordType, name: &str) -> Result<Vec<Record>, NetError> {
94 let lookup = self.0.lookup(name, record_type).await?;
95 Ok(lookup.answers().to_vec())
96 }
97
98 fn flush_cache(&self) {
99 self.0.clear_cache();
100 }
101}
102
103#[allow(clippy::needless_collect)]
106impl Service<HyperName> for Resolver {
107 type Response = SocketAddrs;
108 type Error = Error;
109 #[allow(clippy::type_complexity)]
110 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
111
112 fn poll_ready(&mut self, _cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
113 Poll::Ready(Ok(()))
114 }
115
116 fn call(&mut self, name: HyperName) -> Self::Future {
117 let resolver = self.0.clone();
118
119 Box::pin(async move {
120 let lookup = resolver
121 .lookup_ip(name.as_str())
122 .await
123 .map_err(|e: NetError| Error::DnsError(e.to_string()))?;
124 let addresses: Vec<IpAddr> = lookup.iter().collect();
125
126 Ok(SocketAddrs {
127 iter: Box::new(addresses.into_iter()),
128 })
129 })
130 }
131}
132
133#[derive(Debug, Clone)]
136pub struct FixedResolver(Resolver, String, HyperName);
137impl CloneableDnsResolver for FixedResolver {}
138impl HyperDnsResolver for FixedResolver {}
139impl CloneableHyperDnsResolver for FixedResolver {}
140
141impl FixedResolver {
142 pub fn new(o: Options, name: String) -> Result<Self, DnsError> {
143 let resolver = Resolver::new(o)?;
144 let hyper_name = HyperName::from_str(&name).context("unable to parse name")?;
145
146 Ok(Self(resolver, name, hyper_name))
147 }
148}
149
150impl Resolve for FixedResolver {
152 fn resolve(&self, _name: Name) -> Resolving {
153 let name = Name::from_str(&self.1).unwrap();
156 reqwest::dns::Resolve::resolve(&self.0, name)
157 }
158}
159
160impl Service<HyperName> for FixedResolver {
162 type Response = SocketAddrs;
163 type Error = Error;
164 #[allow(clippy::type_complexity)]
165 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
166
167 fn poll_ready(&mut self, _cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
168 Poll::Ready(Ok(()))
169 }
170
171 fn call(&mut self, _name: HyperName) -> Self::Future {
172 self.0.call(self.2.clone())
173 }
174}
175
176#[derive(Debug, Clone)]
178pub struct StaticResolver(Arc<BTreeMap<String, Vec<IpAddr>>>);
179impl CloneableDnsResolver for StaticResolver {}
180impl HyperDnsResolver for StaticResolver {}
181impl CloneableHyperDnsResolver for StaticResolver {}
182
183impl StaticResolver {
184 pub fn new(items: impl IntoIterator<Item = (String, Vec<IpAddr>)>) -> Self {
185 Self(Arc::new(BTreeMap::from_iter(items)))
186 }
187
188 pub fn lookup(&self, name: &str) -> Option<Vec<IpAddr>> {
189 self.0.get(name).cloned()
190 }
191}
192
193impl Resolve for StaticResolver {
195 fn resolve(&self, name: Name) -> Resolving {
196 let addrs = self.lookup(name.as_str()).unwrap_or_default();
197
198 Box::pin(async move {
199 Ok(Box::new(SocketAddrs {
200 iter: Box::new(addrs.into_iter()),
201 }) as Addrs)
202 })
203 }
204}
205
206impl Service<HyperName> for StaticResolver {
208 type Response = SocketAddrs;
209 type Error = Error;
210 #[allow(clippy::type_complexity)]
211 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
212
213 fn poll_ready(&mut self, _cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
214 Poll::Ready(Ok(()))
215 }
216
217 fn call(&mut self, name: HyperName) -> Self::Future {
218 let addrs = self.lookup(name.as_str()).unwrap_or_default();
219
220 Box::pin(async move {
221 Ok(SocketAddrs {
222 iter: Box::new(addrs.into_iter()),
223 })
224 })
225 }
226}
227
228#[derive(Debug, Clone)]
232pub struct ApiBnResolver {
233 agent: Agent,
234 subnet: Principal,
235 resolver_static: Arc<ArcSwap<StaticResolver>>,
236 resolver_fallback: Resolver,
237}
238impl CloneableDnsResolver for ApiBnResolver {}
239impl HyperDnsResolver for ApiBnResolver {}
240impl CloneableHyperDnsResolver for ApiBnResolver {}
241
242impl ApiBnResolver {
243 pub fn new(resolver_fallback: Resolver, agent: Agent) -> Self {
244 let resolver_static = Arc::new(ArcSwap::new(Arc::new(StaticResolver::new(vec![]))));
245 let subnet = principal!("tdb26-jop6k-aogll-7ltgs-eruif-6kk7m-qpktf-gdiqx-mxtrf-vb5e6-eqe");
246
247 Self {
248 agent,
249 subnet,
250 resolver_static,
251 resolver_fallback,
252 }
253 }
254
255 async fn get_api_bns(&self) -> Result<Vec<(String, Vec<IpAddr>)>, Error> {
257 let api_bns = self
258 .agent
259 .fetch_api_boundary_nodes_by_subnet_id(self.subnet)
260 .await
261 .context("unable to get API BNs from IC")?;
262
263 let mut r = Vec::with_capacity(api_bns.len());
264 for n in api_bns {
265 let ipv6 = IpAddr::from_str(&n.ipv6_address)
266 .context(format!("unable to parse IPv6 address for {}", n.domain))?;
267 let mut addrs = vec![ipv6];
268
269 if let Some(v) = n.ipv4_address {
271 let ipv4 = IpAddr::from_str(&v)
272 .context(format!("unable to parse IPv4 address for {}", n.domain))?;
273 addrs.push(ipv4);
274 }
275
276 r.push((n.domain, addrs));
277 }
278
279 Ok(r)
280 }
281}
282
283impl Resolve for ApiBnResolver {
285 fn resolve(&self, name: Name) -> Resolving {
286 let api_bns = self.resolver_static.load_full().lookup(name.as_str());
287 let resolver_fallback = self.resolver_fallback.clone();
288
289 Box::pin(async move {
290 let addrs = match api_bns {
291 Some(v) => v,
292 None => {
293 resolver_fallback
295 .0
296 .lookup_ip(name.as_str())
297 .await
298 .map_err(|e| Error::DnsError(e.to_string()))?
299 .iter()
300 .collect()
301 }
302 };
303
304 Ok(Box::new(SocketAddrs {
305 iter: Box::new(addrs.into_iter()),
306 }) as Addrs)
307 })
308 }
309}
310
311impl Service<HyperName> for ApiBnResolver {
313 type Response = SocketAddrs;
314 type Error = Error;
315 #[allow(clippy::type_complexity)]
316 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
317
318 fn poll_ready(&mut self, _cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
319 Poll::Ready(Ok(()))
320 }
321
322 fn call(&mut self, name: HyperName) -> Self::Future {
323 let api_bns = self.resolver_static.load_full().lookup(name.as_str());
324 let resolver_fallback = self.resolver_fallback.clone();
325
326 Box::pin(async move {
327 let addrs = match api_bns {
328 Some(v) => v,
329 None => {
330 resolver_fallback
332 .0
333 .lookup_ip(name.as_str())
334 .await
335 .map_err(|e| Error::DnsError(e.to_string()))?
336 .iter()
337 .collect()
338 }
339 };
340
341 Ok(SocketAddrs {
342 iter: Box::new(addrs.into_iter()),
343 })
344 })
345 }
346}
347
348#[async_trait]
349impl Run for ApiBnResolver {
350 async fn run(&self, _token: CancellationToken) -> Result<(), anyhow::Error> {
351 let api_bns = self.get_api_bns().await?;
352 let resolver = StaticResolver::new(api_bns);
353 self.resolver_static.store(Arc::new(resolver));
354
355 Ok(())
356 }
357}
358
359#[derive(Debug, Clone)]
361pub struct SingleResolver(IpAddr);
362impl CloneableDnsResolver for SingleResolver {}
363
364impl SingleResolver {
365 pub const fn new(addr: IpAddr) -> Self {
366 Self(addr)
367 }
368}
369
370impl Resolve for SingleResolver {
372 fn resolve(&self, _name: Name) -> Resolving {
373 let addr = self.0;
374
375 Box::pin(async move {
376 Ok(Box::new(SocketAddrs {
377 iter: Box::new(vec![addr].into_iter()),
378 }) as Addrs)
379 })
380 }
381}
382
#[cfg(test)]
mod test {
    use std::net::{Ipv4Addr, SocketAddr};

    use ic_bn_lib_common::types::dns::Protocol;

    use super::*;

    /// `Protocol` parsing: default ports, explicit ports, and rejection of missing,
    /// non-numeric, negative and out-of-range ports.
    #[test]
    fn test_dns_protocol() {
        assert_eq!(Protocol::from_str("clear").unwrap(), Protocol::Clear(53));
        assert_eq!(Protocol::from_str("tls").unwrap(), Protocol::Tls(853));
        assert_eq!(Protocol::from_str("https").unwrap(), Protocol::Https(443));

        assert_eq!(
            Protocol::from_str("clear:8053").unwrap(),
            Protocol::Clear(8053)
        );
        assert_eq!(Protocol::from_str("tls:8853").unwrap(), Protocol::Tls(8853));
        assert_eq!(
            Protocol::from_str("https:8443").unwrap(),
            Protocol::Https(8443)
        );

        assert!(Protocol::from_str("clear:").is_err());
        assert!(Protocol::from_str("clear:x").is_err());
        assert!(Protocol::from_str("clear:-1").is_err());
        assert!(Protocol::from_str("clear:65537").is_err());
    }

    /// `SingleResolver` yields exactly its configured address (with port 0) for any name.
    #[tokio::test]
    async fn test_single_resolver() {
        let addr = IpAddr::V4(Ipv4Addr::new(1, 2, 3, 4));
        let resolver = SingleResolver::new(addr);

        let mut res = resolver
            .resolve(Name::from_str("foo.bar").unwrap())
            .await
            .unwrap();
        assert_eq!(res.next(), Some(SocketAddr::new(addr, 0)));
        assert_eq!(res.next(), None);
    }

    /// `StaticResolver::lookup` returns the registered addresses and `None` for unknown names.
    #[test]
    fn test_static_resolver_lookup() {
        let addr = IpAddr::V4(Ipv4Addr::new(1, 2, 3, 4));
        let resolver = StaticResolver::new(vec![("foo.bar".to_string(), vec![addr])]);

        assert_eq!(resolver.lookup("foo.bar"), Some(vec![addr]));
        assert_eq!(resolver.lookup("bar.foo"), None);
    }
}
424}