1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
//! Implements [`LookupService`] using DNS queries.

use crate::{LookupService, ServiceDefinition};
use anyhow::Context;
use std::collections::HashSet;
use std::net::SocketAddr;
use trust_dns_resolver::{system_conf, AsyncResolver, TokioAsyncResolver};

/// A [`LookupService`] that resolves [`ServiceDefinition::hostname`] by
/// issuing DNS queries.
pub struct DnsResolver {
    /// Resolver used for every lookup. It talks to the configured DNS servers
    /// itself instead of going through the operating system's resolver, so
    /// OS-specific DNS caching is bypassed.
    dns: TokioAsyncResolver,
}

impl DnsResolver {
    /// Construct a new [`DnsResolver`] from the system DNS configuration,
    /// e.g. `resolv.conf`.
    ///
    /// The resolver's own cache is disabled (`cache_size = 0`), so results
    /// are not cached by this resolver.
    ///
    /// # Errors
    ///
    /// Returns an error if the system DNS configuration cannot be read, or if
    /// a resolver cannot be constructed from that configuration.
    pub async fn from_system_config() -> Result<Self, anyhow::Error> {
        let (config, mut opts) = system_conf::read_system_conf()
            .context("failed to read dns services from system configuration")?;

        // We do not want any caching on our side.
        opts.cache_size = 0;

        // Construction is fallible. This function already returns a `Result`,
        // so hand the failure to the caller instead of panicking.
        let dns = AsyncResolver::tokio(config, opts)
            .context("failed to construct dns resolver from system configuration")?;

        Ok(Self { dns })
    }
}

#[async_trait::async_trait]
impl LookupService for DnsResolver {
    /// Resolves `definition.hostname()` to IP addresses via DNS and pairs each
    /// address with `definition.port()`.
    ///
    /// # Errors
    ///
    /// Returns an error if the DNS lookup fails. The error carries the failing
    /// definition as context. The underlying resolver error is kept as the
    /// source, so it can still be downcast.
    #[tracing::instrument(level = "debug", skip(self))]
    async fn resolve_service_endpoints(
        &self,
        definition: &ServiceDefinition,
    ) -> Result<HashSet<SocketAddr>, anyhow::Error> {
        let lookup = self
            .dns
            .lookup_ip(definition.hostname())
            .await
            .with_context(|| format!("dns lookup failed for {:?}", definition))?;

        tracing::debug!("dns query expires in: {:?}", lookup.valid_until());

        // Every resolved address uses the same port, so fetch it once.
        let port = definition.port();
        Ok(lookup
            .iter()
            .map(|ip_addr| {
                tracing::debug!("result: ip {}", ip_addr);
                SocketAddr::from((ip_addr, port))
            })
            .collect())
    }
}