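//! DNS resolution for `secure_exec_kernel` (`dns.rs`): IP literals, configured overrides,
//! and lookups through a pluggable [`DnsResolver`] backend such as [`HickoryDnsResolver`].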

use hickory_resolver::config::{NameServerConfig, ResolverConfig};
use hickory_resolver::net::runtime::TokioRuntimeProvider;
use hickory_resolver::proto::rr::domain::Name;
use hickory_resolver::proto::rr::rdata::{A, AAAA};
use hickory_resolver::proto::rr::{RData, Record, RecordType};
use hickory_resolver::TokioResolver;
use std::collections::{BTreeMap, BTreeSet};
use std::error::Error;
use std::fmt;
use std::future::Future;
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
use std::sync::Arc;

/// DNS settings consulted by [`resolve_dns`] and [`resolve_dns_records`].
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub struct DnsConfig {
    /// Name servers forwarded to the [`DnsResolver`] backend with each request.
    ///
    /// When empty, [`HickoryDnsResolver`] falls back to the system resolver configuration;
    /// other backends may behave differently.
    pub name_servers: Vec<SocketAddr>,
    /// Static hostname → address answers that take precedence over the resolver.
    ///
    /// Keys are matched against the normalized hostname (trimmed, trailing dots removed,
    /// ASCII-lowercased), so they should be stored in that form.
    pub overrides: BTreeMap<String, Vec<IpAddr>>,
}

/// Whether a DNS lookup should go through permission checks.
///
/// NOTE(review): this enum is not consumed anywhere in this file; the exact meaning of each
/// variant is defined by its callers — confirm there.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum DnsLookupPolicy {
    /// Permission checks are expected to run before the lookup.
    CheckPermissions,
    /// Permission checks are expected to be bypassed.
    SkipPermissions,
}

/// Input to [`DnsResolver::lookup_ip`]: the hostname to resolve and the servers to ask.
///
/// An empty `name_servers` list leaves server selection to the resolver backend
/// ([`HickoryDnsResolver`] then uses the system configuration).
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct DnsLookupRequest {
    hostname: String,
    name_servers: Vec<SocketAddr>,
}

impl DnsLookupRequest {
    /// Creates a request for `hostname` against `name_servers`.
    pub fn new(hostname: impl Into<String>, name_servers: Vec<SocketAddr>) -> Self {
        let hostname = hostname.into();
        Self {
            hostname,
            name_servers,
        }
    }

    /// The hostname to resolve, exactly as passed to [`DnsLookupRequest::new`].
    pub fn hostname(&self) -> &str {
        self.hostname.as_str()
    }

    /// The name servers to query.
    pub fn name_servers(&self) -> &[SocketAddr] {
        self.name_servers.as_slice()
    }
}

48#[derive(Debug, Clone, PartialEq, Eq)]
49pub struct DnsRecordLookupRequest {
50    hostname: String,
51    name_servers: Vec<SocketAddr>,
52    record_type: RecordType,
53}
54
55impl DnsRecordLookupRequest {
56    pub fn new(
57        hostname: impl Into<String>,
58        name_servers: Vec<SocketAddr>,
59        record_type: RecordType,
60    ) -> Self {
61        Self {
62            hostname: hostname.into(),
63            name_servers,
64            record_type,
65        }
66    }
67
68    pub fn hostname(&self) -> &str {
69        &self.hostname
70    }
71
72    pub fn name_servers(&self) -> &[SocketAddr] {
73        &self.name_servers
74    }
75
76    pub const fn record_type(&self) -> RecordType {
77        self.record_type
78    }
79}
80
/// Where the addresses or records of a resolution result came from.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum DnsResolutionSource {
    /// The hostname was itself an IP address literal.
    Literal,
    /// The answer came from [`DnsConfig::overrides`].
    Override,
    /// The answer came from a [`DnsResolver`] backend.
    Resolver,
}

impl DnsResolutionSource {
    /// Lowercase label for this source: `"literal"`, `"override"`, or `"resolver"`.
    pub const fn as_str(self) -> &'static str {
        use DnsResolutionSource as Source;

        match self {
            Source::Literal => "literal",
            Source::Override => "override",
            Source::Resolver => "resolver",
        }
    }
}

98#[derive(Debug, Clone, PartialEq, Eq)]
99pub struct DnsResolution {
100    hostname: String,
101    source: DnsResolutionSource,
102    addresses: Vec<IpAddr>,
103}
104
105impl DnsResolution {
106    pub fn new(
107        hostname: impl Into<String>,
108        source: DnsResolutionSource,
109        addresses: Vec<IpAddr>,
110    ) -> Self {
111        Self {
112            hostname: hostname.into(),
113            source,
114            addresses,
115        }
116    }
117
118    pub fn hostname(&self) -> &str {
119        &self.hostname
120    }
121
122    pub const fn source(&self) -> DnsResolutionSource {
123        self.source
124    }
125
126    pub fn addresses(&self) -> &[IpAddr] {
127        &self.addresses
128    }
129}
130
131#[derive(Debug, Clone, PartialEq, Eq)]
132pub struct DnsRecordResolution {
133    hostname: String,
134    source: DnsResolutionSource,
135    records: Vec<Record>,
136}
137
138impl DnsRecordResolution {
139    pub fn new(
140        hostname: impl Into<String>,
141        source: DnsResolutionSource,
142        records: Vec<Record>,
143    ) -> Self {
144        Self {
145            hostname: hostname.into(),
146            source,
147            records,
148        }
149    }
150
151    pub fn hostname(&self) -> &str {
152        &self.hostname
153    }
154
155    pub const fn source(&self) -> DnsResolutionSource {
156        self.source
157    }
158
159    pub fn records(&self) -> &[Record] {
160        &self.records
161    }
162}
163
/// Broad category of a [`DnsResolverError`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum DnsResolverErrorKind {
    /// The input was unusable, e.g. an empty or unparsable hostname.
    InvalidInput,
    /// Resolution was attempted but produced no usable answer.
    LookupFailed,
}

/// Error produced by the DNS helpers in this module.
///
/// Carries a [`DnsResolverErrorKind`] for programmatic handling plus a human-readable
/// message; `Display` prints the message and nothing else.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct DnsResolverError {
    kind: DnsResolverErrorKind,
    message: String,
}

impl DnsResolverError {
    /// Shared constructor behind the public kind-specific ones.
    fn with_kind(kind: DnsResolverErrorKind, message: impl Into<String>) -> Self {
        Self {
            kind,
            message: message.into(),
        }
    }

    /// Creates a [`DnsResolverErrorKind::InvalidInput`] error carrying `message`.
    pub fn invalid_input(message: impl Into<String>) -> Self {
        Self::with_kind(DnsResolverErrorKind::InvalidInput, message)
    }

    /// Creates a [`DnsResolverErrorKind::LookupFailed`] error carrying `message`.
    pub fn lookup_failed(message: impl Into<String>) -> Self {
        Self::with_kind(DnsResolverErrorKind::LookupFailed, message)
    }

    /// The category of this error.
    pub const fn kind(&self) -> DnsResolverErrorKind {
        self.kind
    }
}

impl fmt::Display for DnsResolverError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(&self.message)
    }
}

impl Error for DnsResolverError {}

204pub trait DnsResolver {
205    fn lookup_ip(&self, request: &DnsLookupRequest) -> Result<Vec<IpAddr>, DnsResolverError>;
206    fn lookup_records(
207        &self,
208        request: &DnsRecordLookupRequest,
209    ) -> Result<Vec<Record>, DnsResolverError>;
210}
211
212pub type SharedDnsResolver = Arc<dyn DnsResolver + Send + Sync>;
213
214#[derive(Debug, Default)]
215pub struct HickoryDnsResolver;
216
217impl DnsResolver for HickoryDnsResolver {
218    fn lookup_ip(&self, request: &DnsLookupRequest) -> Result<Vec<IpAddr>, DnsResolverError> {
219        let resolver_config = resolver_config_from_name_servers(request.name_servers());
220        let hostname = request.hostname().to_owned();
221        std::thread::spawn(move || -> Result<Vec<IpAddr>, DnsResolverError> {
222            let runtime = tokio::runtime::Runtime::new().map_err(|error| {
223                DnsResolverError::lookup_failed(format!("failed to create DNS runtime: {error}"))
224            })?;
225
226            runtime.block_on(async move {
227                let builder = if let Some(config) = resolver_config {
228                    TokioResolver::builder_with_config(config, TokioRuntimeProvider::default())
229                } else {
230                    TokioResolver::builder_tokio().map_err(|error| {
231                        DnsResolverError::lookup_failed(format!(
232                            "failed to initialize DNS resolver from system configuration: {error}"
233                        ))
234                    })?
235                };
236
237                let resolver = builder.build().map_err(|error| {
238                    DnsResolverError::lookup_failed(format!(
239                        "failed to build DNS resolver: {error}"
240                    ))
241                })?;
242                let lookup = resolver.lookup_ip(&hostname).await.map_err(|error| {
243                    DnsResolverError::lookup_failed(format!(
244                        "failed to resolve DNS address {hostname}: {error}"
245                    ))
246                })?;
247
248                let mut addresses = Vec::new();
249                let mut seen = BTreeSet::new();
250                for ip in lookup.iter() {
251                    if seen.insert(ip) {
252                        addresses.push(ip);
253                    }
254                }
255
256                if addresses.is_empty() {
257                    return Err(DnsResolverError::lookup_failed(format!(
258                        "failed to resolve DNS address {hostname}"
259                    )));
260                }
261
262                Ok(addresses)
263            })
264        })
265        .join()
266        .map_err(|_| DnsResolverError::lookup_failed("dns resolver thread panicked"))?
267    }
268
269    fn lookup_records(
270        &self,
271        request: &DnsRecordLookupRequest,
272    ) -> Result<Vec<Record>, DnsResolverError> {
273        let resolver_config = resolver_config_from_name_servers(request.name_servers());
274        let hostname = request.hostname().to_owned();
275        let record_type = request.record_type();
276        std::thread::spawn(move || -> Result<Vec<Record>, DnsResolverError> {
277            let runtime = tokio::runtime::Runtime::new().map_err(|error| {
278                DnsResolverError::lookup_failed(format!("failed to create DNS runtime: {error}"))
279            })?;
280
281            runtime.block_on(async move {
282                let builder = if let Some(config) = resolver_config {
283                    TokioResolver::builder_with_config(config, TokioRuntimeProvider::default())
284                } else {
285                    TokioResolver::builder_tokio().map_err(|error| {
286                        DnsResolverError::lookup_failed(format!(
287                            "failed to initialize DNS resolver from system configuration: {error}"
288                        ))
289                    })?
290                };
291
292                let resolver = builder.build().map_err(|error| {
293                    DnsResolverError::lookup_failed(format!(
294                        "failed to build DNS resolver: {error}"
295                    ))
296                })?;
297                let lookup = resolver
298                    .lookup(&hostname, record_type)
299                    .await
300                    .map_err(|error| {
301                        DnsResolverError::lookup_failed(format!(
302                            "failed to resolve DNS {record_type} record {hostname}: {error}"
303                        ))
304                    })?;
305                let records = lookup.answers().to_vec();
306                if records.is_empty() {
307                    return Err(DnsResolverError::lookup_failed(format!(
308                        "failed to resolve DNS {record_type} record {hostname}"
309                    )));
310                }
311                Ok(records)
312            })
313        })
314        .join()
315        .map_err(|_| DnsResolverError::lookup_failed("dns resolver thread panicked"))?
316    }
317}
318
319pub fn normalize_dns_hostname(hostname: &str) -> Result<String, DnsResolverError> {
320    let normalized = hostname.trim().trim_end_matches('.').to_ascii_lowercase();
321    if normalized.is_empty() {
322        return Err(DnsResolverError::invalid_input(
323            "DNS hostname must not be empty",
324        ));
325    }
326    Ok(normalized)
327}
328
329pub fn format_dns_resource(hostname: &str) -> Result<String, DnsResolverError> {
330    Ok(format!("dns://{}", canonical_dns_subject(hostname)?))
331}
332
333pub fn resolve_dns(
334    config: &DnsConfig,
335    resolver: &dyn DnsResolver,
336    hostname: &str,
337) -> Result<DnsResolution, DnsResolverError> {
338    let trimmed = hostname.trim();
339    if let Ok(ip_addr) = trimmed.parse::<IpAddr>() {
340        return Ok(DnsResolution::new(
341            ip_addr.to_string(),
342            DnsResolutionSource::Literal,
343            vec![ip_addr],
344        ));
345    }
346
347    let normalized_hostname = normalize_dns_hostname(trimmed)?;
348    if let Some(addresses) = config.overrides.get(&normalized_hostname) {
349        return Ok(DnsResolution::new(
350            normalized_hostname,
351            DnsResolutionSource::Override,
352            addresses.clone(),
353        ));
354    }
355
356    let request = DnsLookupRequest::new(normalized_hostname.clone(), config.name_servers.clone());
357    let addresses = resolver.lookup_ip(&request)?;
358    if addresses.is_empty() {
359        return Err(DnsResolverError::lookup_failed(format!(
360            "failed to resolve DNS address {normalized_hostname}"
361        )));
362    }
363
364    Ok(DnsResolution::new(
365        normalized_hostname,
366        DnsResolutionSource::Resolver,
367        dedupe_addresses(addresses),
368    ))
369}
370
371pub fn resolve_dns_records(
372    config: &DnsConfig,
373    resolver: &dyn DnsResolver,
374    hostname: &str,
375    record_type: RecordType,
376) -> Result<DnsRecordResolution, DnsResolverError> {
377    let trimmed = hostname.trim();
378    let normalized_hostname = normalize_dns_hostname(trimmed)?;
379    let owner_name = normalized_hostname.parse::<Name>().map_err(|error| {
380        DnsResolverError::invalid_input(format!("invalid DNS hostname: {error}"))
381    })?;
382
383    if let Some(records) = records_from_literal(trimmed, owner_name.clone(), record_type) {
384        return Ok(DnsRecordResolution::new(
385            normalized_hostname,
386            DnsResolutionSource::Literal,
387            records,
388        ));
389    }
390
391    if let Some(addresses) = config.overrides.get(&normalized_hostname) {
392        let records = records_from_addresses(owner_name.clone(), addresses, record_type);
393        if !records.is_empty() {
394            return Ok(DnsRecordResolution::new(
395                normalized_hostname,
396                DnsResolutionSource::Override,
397                records,
398            ));
399        }
400    }
401
402    let request = DnsRecordLookupRequest::new(
403        normalized_hostname.clone(),
404        config.name_servers.clone(),
405        record_type,
406    );
407    let records = resolver.lookup_records(&request)?;
408    if records.is_empty() {
409        return Err(DnsResolverError::lookup_failed(format!(
410            "failed to resolve DNS {record_type} record {normalized_hostname}"
411        )));
412    }
413
414    Ok(DnsRecordResolution::new(
415        normalized_hostname,
416        DnsResolutionSource::Resolver,
417        records,
418    ))
419}
420
421fn canonical_dns_subject(hostname: &str) -> Result<String, DnsResolverError> {
422    let trimmed = hostname.trim();
423    if let Ok(ip_addr) = trimmed.parse::<IpAddr>() {
424        return Ok(ip_addr.to_string());
425    }
426
427    normalize_dns_hostname(trimmed)
428}
429
430fn resolver_config_from_name_servers(name_servers: &[SocketAddr]) -> Option<ResolverConfig> {
431    if name_servers.is_empty() {
432        return None;
433    }
434
435    let name_servers = name_servers
436        .iter()
437        .map(|server| {
438            let mut config = NameServerConfig::udp_and_tcp(server.ip());
439            for connection in &mut config.connections {
440                connection.port = server.port();
441                connection.bind_addr = Some(SocketAddr::new(
442                    if server.is_ipv6() {
443                        IpAddr::V6(Ipv6Addr::UNSPECIFIED)
444                    } else {
445                        IpAddr::V4(Ipv4Addr::UNSPECIFIED)
446                    },
447                    0,
448                ));
449            }
450            config
451        })
452        .collect();
453
454    Some(ResolverConfig::from_parts(None, vec![], name_servers))
455}
456
/// Removes repeated addresses, keeping the first occurrence of each in its original order.
fn dedupe_addresses(mut addresses: Vec<IpAddr>) -> Vec<IpAddr> {
    let mut seen = BTreeSet::new();
    // `insert` returns false for an address already seen, which drops the repeat.
    addresses.retain(|address| seen.insert(*address));
    addresses
}

468fn records_from_literal(
469    hostname: &str,
470    owner_name: Name,
471    record_type: RecordType,
472) -> Option<Vec<Record>> {
473    let ip_addr = hostname.parse::<IpAddr>().ok()?;
474    let records = records_from_addresses(owner_name, &[ip_addr], record_type);
475    if records.is_empty() {
476        return None;
477    }
478    Some(records)
479}
480
481fn records_from_addresses(
482    owner_name: Name,
483    addresses: &[IpAddr],
484    record_type: RecordType,
485) -> Vec<Record> {
486    addresses
487        .iter()
488        .filter_map(|ip| match (record_type, ip) {
489            (RecordType::A, IpAddr::V4(ipv4)) | (RecordType::ANY, IpAddr::V4(ipv4)) => Some(
490                Record::from_rdata(owner_name.clone(), 60, RData::A(A::from(*ipv4))),
491            ),
492            (RecordType::AAAA, IpAddr::V6(ipv6)) | (RecordType::ANY, IpAddr::V6(ipv6)) => Some(
493                Record::from_rdata(owner_name.clone(), 60, RData::AAAA(AAAA::from(*ipv6))),
494            ),
495            _ => None,
496        })
497        .collect()
498}