Skip to main content

mail_auth/common/
resolver.rs

1/*
2 * SPDX-FileCopyrightText: 2020 Stalwart Labs LLC <hello@stalw.art>
3 *
4 * SPDX-License-Identifier: Apache-2.0 OR MIT
5 */
6
7use super::{parse::TxtRecordParser, verify::DomainKey};
8use crate::{
9    Error, IpLookupStrategy, MX, MessageAuthenticator, ResolverCache, Txt,
10    dkim::{Atps, DomainKeyReport},
11    dmarc::Dmarc,
12    mta_sts::{MtaSts, TlsRpt},
13    spf::{Macro, Spf},
14};
15use hickory_resolver::{
16    Name, TokioResolver,
17    config::{ResolverConfig, ResolverOpts},
18    name_server::TokioConnectionProvider,
19    proto::{ProtoError, ProtoErrorKind},
20    system_conf::read_system_conf,
21};
22use std::{
23    borrow::Cow,
24    net::{IpAddr, Ipv4Addr, Ipv6Addr},
25    sync::Arc,
26    time::Instant,
27};
28
/// A resolved DNS value together with the instant at which it stops being valid.
///
/// The raw lookup methods of `MessageAuthenticator` in this module fill `expires`
/// from the resolver lookup's `valid_until()`.
pub struct DnsEntry<T> {
    /// The resolved value (for example the list of addresses returned by a lookup).
    pub entry: T,
    /// Instant after which `entry` should no longer be trusted.
    pub expires: Instant,
}
33
34impl MessageAuthenticator {
35    pub fn new_cloudflare_tls() -> Result<Self, ProtoError> {
36        Self::new(ResolverConfig::cloudflare_tls(), ResolverOpts::default())
37    }
38
39    pub fn new_cloudflare() -> Result<Self, ProtoError> {
40        Self::new(ResolverConfig::cloudflare(), ResolverOpts::default())
41    }
42
43    pub fn new_google() -> Result<Self, ProtoError> {
44        Self::new(ResolverConfig::google(), ResolverOpts::default())
45    }
46
47    pub fn new_quad9() -> Result<Self, ProtoError> {
48        Self::new(ResolverConfig::quad9(), ResolverOpts::default())
49    }
50
51    pub fn new_quad9_tls() -> Result<Self, ProtoError> {
52        Self::new(ResolverConfig::quad9_tls(), ResolverOpts::default())
53    }
54
55    pub fn new_system_conf() -> Result<Self, ProtoError> {
56        let (config, options) = read_system_conf()?;
57        Self::new(config, options)
58    }
59
60    pub fn new(config: ResolverConfig, options: ResolverOpts) -> Result<Self, ProtoError> {
61        Ok(MessageAuthenticator(
62            TokioResolver::builder_with_config(config, TokioConnectionProvider::default())
63                .with_options(options)
64                .build(),
65        ))
66    }
67
68    pub fn resolver(&self) -> &TokioResolver {
69        &self.0
70    }
71
72    pub async fn txt_raw_lookup(&self, key: impl ToFqdn) -> crate::Result<Vec<u8>> {
73        let mut result = vec![];
74        for record in self
75            .0
76            .txt_lookup(Name::from_str_relaxed::<&str>(key.to_fqdn().as_ref())?)
77            .await?
78            .as_lookup()
79            .record_iter()
80        {
81            if let Some(txt_data) = record.data().as_txt() {
82                for item in txt_data.txt_data() {
83                    result.extend_from_slice(item);
84                }
85            }
86        }
87
88        Ok(result)
89    }
90
91    pub async fn txt_lookup<T: TxtRecordParser + Into<Txt> + UnwrapTxtRecord>(
92        &self,
93        key: impl ToFqdn,
94        cache: Option<&impl ResolverCache<Box<str>, Txt>>,
95    ) -> crate::Result<Arc<T>> {
96        let key = key.to_fqdn();
97        if let Some(value) = cache.as_ref().and_then(|c| c.get::<str>(key.as_ref())) {
98            return T::unwrap_txt(value);
99        }
100
101        #[cfg(any(test, feature = "test"))]
102        if true {
103            return mock_resolve(key.as_ref());
104        }
105
106        let txt_lookup = self
107            .0
108            .txt_lookup(Name::from_str_relaxed::<&str>(key.as_ref())?)
109            .await?;
110        let mut result = Err(Error::InvalidRecordType);
111        let records = txt_lookup.as_lookup().record_iter().filter_map(|r| {
112            let txt_data = r.data().as_txt()?.txt_data();
113            match txt_data.len() {
114                1 => Some(Cow::from(txt_data[0].as_ref())),
115                0 => None,
116                _ => {
117                    let mut entry = Vec::with_capacity(255 * txt_data.len());
118                    for data in txt_data {
119                        entry.extend_from_slice(data);
120                    }
121                    Some(Cow::from(entry))
122                }
123            }
124        });
125
126        for record in records {
127            result = T::parse(record.as_ref());
128            if result.is_ok() {
129                break;
130            }
131        }
132
133        let result: Txt = result.into();
134
135        if let Some(cache) = cache {
136            cache.insert(key, result.clone(), txt_lookup.valid_until());
137        }
138
139        T::unwrap_txt(result)
140    }
141
142    pub async fn mx_lookup(
143        &self,
144        key: impl ToFqdn,
145        cache: Option<&impl ResolverCache<Box<str>, Arc<[MX]>>>,
146    ) -> crate::Result<Arc<[MX]>> {
147        let key = key.to_fqdn();
148        if let Some(value) = cache.as_ref().and_then(|c| c.get::<str>(key.as_ref())) {
149            return Ok(value);
150        }
151
152        #[cfg(any(test, feature = "test"))]
153        if true {
154            return mock_resolve(key.as_ref());
155        }
156
157        let mx_lookup = self
158            .0
159            .mx_lookup(Name::from_str_relaxed::<&str>(key.as_ref())?)
160            .await?;
161        let mx_records = mx_lookup.as_lookup().records();
162        let mut records: Vec<(u16, Vec<Box<str>>)> = Vec::with_capacity(mx_records.len());
163        for mx_record in mx_records {
164            if let Some(mx) = mx_record.data().as_mx() {
165                let preference = mx.preference();
166                let exchange = mx.exchange().to_lowercase().to_string().into_boxed_str();
167
168                if let Some(record) = records.iter_mut().find(|r| r.0 == preference) {
169                    record.1.push(exchange);
170                } else {
171                    records.push((preference, vec![exchange]));
172                }
173            }
174        }
175
176        records.sort_unstable_by(|a, b| a.0.cmp(&b.0));
177        let records: Arc<[MX]> = records
178            .into_iter()
179            .map(|(preference, exchanges)| MX {
180                preference,
181                exchanges: exchanges.into_boxed_slice(),
182            })
183            .collect::<Arc<[MX]>>();
184
185        if let Some(cache) = cache {
186            cache.insert(key, records.clone(), mx_lookup.valid_until());
187        }
188
189        Ok(records)
190    }
191
192    pub async fn ipv4_lookup(
193        &self,
194        key: impl ToFqdn,
195        cache: Option<&impl ResolverCache<Box<str>, Arc<[Ipv4Addr]>>>,
196    ) -> crate::Result<Arc<[Ipv4Addr]>> {
197        let key = key.to_fqdn();
198        if let Some(value) = cache.as_ref().and_then(|c| c.get::<str>(key.as_ref())) {
199            return Ok(value);
200        }
201
202        let ipv4_lookup = self.ipv4_lookup_raw(key.as_ref()).await?;
203
204        if let Some(cache) = cache {
205            cache.insert(key, ipv4_lookup.entry.clone(), ipv4_lookup.expires);
206        }
207
208        Ok(ipv4_lookup.entry)
209    }
210
211    pub async fn ipv4_lookup_raw(&self, key: &str) -> crate::Result<DnsEntry<Arc<[Ipv4Addr]>>> {
212        #[cfg(any(test, feature = "test"))]
213        if true {
214            return mock_resolve(key);
215        }
216
217        let ipv4_lookup = self
218            .0
219            .ipv4_lookup(Name::from_str_relaxed::<&str>(key)?)
220            .await?;
221        let ips: Arc<[Ipv4Addr]> = ipv4_lookup
222            .as_lookup()
223            .record_iter()
224            .filter_map(|r| r.data().as_a()?.0.into())
225            .collect::<Vec<Ipv4Addr>>()
226            .into();
227
228        Ok(DnsEntry {
229            entry: ips,
230            expires: ipv4_lookup.valid_until(),
231        })
232    }
233
234    pub async fn ipv6_lookup(
235        &self,
236        key: impl ToFqdn,
237        cache: Option<&impl ResolverCache<Box<str>, Arc<[Ipv6Addr]>>>,
238    ) -> crate::Result<Arc<[Ipv6Addr]>> {
239        let key = key.to_fqdn();
240        if let Some(value) = cache.as_ref().and_then(|c| c.get::<str>(key.as_ref())) {
241            return Ok(value);
242        }
243
244        let ipv6_lookup = self.ipv6_lookup_raw(key.as_ref()).await?;
245
246        if let Some(cache) = cache {
247            cache.insert(key, ipv6_lookup.entry.clone(), ipv6_lookup.expires);
248        }
249
250        Ok(ipv6_lookup.entry)
251    }
252
253    pub async fn ipv6_lookup_raw(&self, key: &str) -> crate::Result<DnsEntry<Arc<[Ipv6Addr]>>> {
254        #[cfg(any(test, feature = "test"))]
255        if true {
256            return mock_resolve(key);
257        }
258
259        let ipv6_lookup = self
260            .0
261            .ipv6_lookup(Name::from_str_relaxed::<&str>(key)?)
262            .await?;
263        let ips: Arc<[Ipv6Addr]> = ipv6_lookup
264            .as_lookup()
265            .record_iter()
266            .filter_map(|r| r.data().as_aaaa()?.0.into())
267            .collect::<Vec<Ipv6Addr>>()
268            .into();
269
270        Ok(DnsEntry {
271            entry: ips,
272            expires: ipv6_lookup.valid_until(),
273        })
274    }
275
276    pub async fn ip_lookup(
277        &self,
278        key: &str,
279        mut strategy: IpLookupStrategy,
280        max_results: usize,
281        cache_ipv4: Option<&impl ResolverCache<Box<str>, Arc<[Ipv4Addr]>>>,
282        cache_ipv6: Option<&impl ResolverCache<Box<str>, Arc<[Ipv6Addr]>>>,
283    ) -> crate::Result<Vec<IpAddr>> {
284        loop {
285            match strategy {
286                IpLookupStrategy::Ipv4Only | IpLookupStrategy::Ipv4thenIpv6 => {
287                    match (self.ipv4_lookup(key, cache_ipv4).await, strategy) {
288                        (Ok(result), _) => {
289                            return Ok(result
290                                .iter()
291                                .take(max_results)
292                                .copied()
293                                .map(IpAddr::from)
294                                .collect());
295                        }
296                        (Err(err), IpLookupStrategy::Ipv4Only) => return Err(err),
297                        _ => {
298                            strategy = IpLookupStrategy::Ipv6Only;
299                        }
300                    }
301                }
302                IpLookupStrategy::Ipv6Only | IpLookupStrategy::Ipv6thenIpv4 => {
303                    match (self.ipv6_lookup(key, cache_ipv6).await, strategy) {
304                        (Ok(result), _) => {
305                            return Ok(result
306                                .iter()
307                                .take(max_results)
308                                .copied()
309                                .map(IpAddr::from)
310                                .collect());
311                        }
312                        (Err(err), IpLookupStrategy::Ipv6Only) => return Err(err),
313                        _ => {
314                            strategy = IpLookupStrategy::Ipv4Only;
315                        }
316                    }
317                }
318            }
319        }
320    }
321
322    pub async fn ptr_lookup(
323        &self,
324        addr: IpAddr,
325        cache: Option<&impl ResolverCache<IpAddr, Arc<[Box<str>]>>>,
326    ) -> crate::Result<Arc<[Box<str>]>> {
327        if let Some(value) = cache.as_ref().and_then(|c| c.get(&addr)) {
328            return Ok(value);
329        }
330
331        #[cfg(any(test, feature = "test"))]
332        if true {
333            return mock_resolve(&addr.to_string());
334        }
335
336        let ptr_lookup = self.0.reverse_lookup(addr).await?;
337        let ptr: Arc<[Box<str>]> = ptr_lookup
338            .as_lookup()
339            .record_iter()
340            .filter_map(|r| {
341                let r = r.data().as_ptr()?;
342                if !r.is_empty() {
343                    r.to_lowercase().to_string().into_boxed_str().into()
344                } else {
345                    None
346                }
347            })
348            .collect::<Arc<[Box<str>]>>();
349
350        if let Some(cache) = cache {
351            cache.insert(addr, ptr.clone(), ptr_lookup.valid_until());
352        }
353
354        Ok(ptr)
355    }
356
357    #[cfg(any(test, feature = "test"))]
358    pub async fn exists(
359        &self,
360        key: impl ToFqdn,
361        cache_ipv4: Option<&impl ResolverCache<Box<str>, Arc<[Ipv4Addr]>>>,
362        cache_ipv6: Option<&impl ResolverCache<Box<str>, Arc<[Ipv6Addr]>>>,
363    ) -> crate::Result<bool> {
364        let key = key.to_fqdn();
365        match self.ipv4_lookup(key.as_ref(), cache_ipv4).await {
366            Ok(_) => Ok(true),
367            Err(Error::DnsRecordNotFound(_)) => {
368                match self.ipv6_lookup(key.as_ref(), cache_ipv6).await {
369                    Ok(_) => Ok(true),
370                    Err(Error::DnsRecordNotFound(_)) => Ok(false),
371                    Err(err) => Err(err),
372                }
373            }
374            Err(err) => Err(err),
375        }
376    }
377
378    #[cfg(not(any(test, feature = "test")))]
379    pub async fn exists(
380        &self,
381        key: impl ToFqdn,
382        cache_ipv4: Option<&impl ResolverCache<Box<str>, Arc<[Ipv4Addr]>>>,
383        cache_ipv6: Option<&impl ResolverCache<Box<str>, Arc<[Ipv6Addr]>>>,
384    ) -> crate::Result<bool> {
385        let key = key.to_fqdn();
386
387        if cache_ipv4.is_some_and(|c| c.get::<str>(key.as_ref()).is_some())
388            || cache_ipv6.is_some_and(|c| c.get::<str>(key.as_ref()).is_some())
389        {
390            return Ok(true);
391        }
392
393        match self
394            .0
395            .lookup_ip(Name::from_str_relaxed::<&str>(key.as_ref())?)
396            .await
397        {
398            Ok(result) => Ok(result.as_lookup().record_iter().any(|r| {
399                matches!(
400                    r.data().record_type(),
401                    hickory_resolver::proto::rr::RecordType::A
402                        | hickory_resolver::proto::rr::RecordType::AAAA
403                )
404            })),
405            Err(err) => match err.kind() {
406                ProtoErrorKind::NoRecordsFound { .. } => Ok(false),
407                _ => Err(err.into()),
408            },
409        }
410    }
411}
412
413impl From<ProtoError> for Error {
414    fn from(err: ProtoError) -> Self {
415        match err.kind() {
416            ProtoErrorKind::NoRecordsFound(response_code) => {
417                Error::DnsRecordNotFound(response_code.response_code)
418            }
419            _ => Error::DnsError(err.to_string()),
420        }
421    }
422}
423
424impl From<DomainKey> for Txt {
425    fn from(v: DomainKey) -> Self {
426        Txt::DomainKey(v.into())
427    }
428}
429
430impl From<DomainKeyReport> for Txt {
431    fn from(v: DomainKeyReport) -> Self {
432        Txt::DomainKeyReport(v.into())
433    }
434}
435
436impl From<Atps> for Txt {
437    fn from(v: Atps) -> Self {
438        Txt::Atps(v.into())
439    }
440}
441
442impl From<Spf> for Txt {
443    fn from(v: Spf) -> Self {
444        Txt::Spf(v.into())
445    }
446}
447
448impl From<Macro> for Txt {
449    fn from(v: Macro) -> Self {
450        Txt::SpfMacro(v.into())
451    }
452}
453
454impl From<Dmarc> for Txt {
455    fn from(v: Dmarc) -> Self {
456        Txt::Dmarc(v.into())
457    }
458}
459
460impl From<MtaSts> for Txt {
461    fn from(v: MtaSts) -> Self {
462        Txt::MtaSts(v.into())
463    }
464}
465
466impl From<TlsRpt> for Txt {
467    fn from(v: TlsRpt) -> Self {
468        Txt::TlsRpt(v.into())
469    }
470}
471
472impl<T: Into<Txt>> From<crate::Result<T>> for Txt {
473    fn from(v: crate::Result<T>) -> Self {
474        match v {
475            Ok(v) => v.into(),
476            Err(err) => Txt::Error(err),
477        }
478    }
479}
480
481pub trait UnwrapTxtRecord: Sized {
482    fn unwrap_txt(txt: Txt) -> crate::Result<Arc<Self>>;
483}
484
485impl UnwrapTxtRecord for DomainKey {
486    fn unwrap_txt(txt: Txt) -> crate::Result<Arc<Self>> {
487        match txt {
488            Txt::DomainKey(a) => Ok(a),
489            Txt::Error(err) => Err(err),
490            _ => Err(Error::Io("Invalid record type".to_string())),
491        }
492    }
493}
494
495impl UnwrapTxtRecord for DomainKeyReport {
496    fn unwrap_txt(txt: Txt) -> crate::Result<Arc<Self>> {
497        match txt {
498            Txt::DomainKeyReport(a) => Ok(a),
499            Txt::Error(err) => Err(err),
500            _ => Err(Error::Io("Invalid record type".to_string())),
501        }
502    }
503}
504
505impl UnwrapTxtRecord for Atps {
506    fn unwrap_txt(txt: Txt) -> crate::Result<Arc<Self>> {
507        match txt {
508            Txt::Atps(a) => Ok(a),
509            Txt::Error(err) => Err(err),
510            _ => Err(Error::Io("Invalid record type".to_string())),
511        }
512    }
513}
514
515impl UnwrapTxtRecord for Spf {
516    fn unwrap_txt(txt: Txt) -> crate::Result<Arc<Self>> {
517        match txt {
518            Txt::Spf(a) => Ok(a),
519            Txt::Error(err) => Err(err),
520            _ => Err(Error::Io("Invalid record type".to_string())),
521        }
522    }
523}
524
525impl UnwrapTxtRecord for Macro {
526    fn unwrap_txt(txt: Txt) -> crate::Result<Arc<Self>> {
527        match txt {
528            Txt::SpfMacro(a) => Ok(a),
529            Txt::Error(err) => Err(err),
530            _ => Err(Error::Io("Invalid record type".to_string())),
531        }
532    }
533}
534
535impl UnwrapTxtRecord for Dmarc {
536    fn unwrap_txt(txt: Txt) -> crate::Result<Arc<Self>> {
537        match txt {
538            Txt::Dmarc(a) => Ok(a),
539            Txt::Error(err) => Err(err),
540            _ => Err(Error::Io("Invalid record type".to_string())),
541        }
542    }
543}
544
545impl UnwrapTxtRecord for MtaSts {
546    fn unwrap_txt(txt: Txt) -> crate::Result<Arc<Self>> {
547        match txt {
548            Txt::MtaSts(a) => Ok(a),
549            Txt::Error(err) => Err(err),
550            _ => Err(Error::Io("Invalid record type".to_string())),
551        }
552    }
553}
554
555impl UnwrapTxtRecord for TlsRpt {
556    fn unwrap_txt(txt: Txt) -> crate::Result<Arc<Self>> {
557        match txt {
558            Txt::TlsRpt(a) => Ok(a),
559            Txt::Error(err) => Err(err),
560            _ => Err(Error::Io("Invalid record type".to_string())),
561        }
562    }
563}
564
/// Converts a host name into the lowercase, fully-qualified form used as a lookup
/// and cache key.
pub trait ToFqdn {
    /// Returns the name lowercased and guaranteed to end with a single trailing `.`
    /// (appended only when missing).
    fn to_fqdn(&self) -> Box<str>;
}

impl<T: AsRef<str>> ToFqdn for T {
    fn to_fqdn(&self) -> Box<str> {
        let mut fqdn = self.as_ref().to_lowercase();
        // Lowercasing never adds or removes a trailing '.', so checking afterwards
        // gives the same answer as checking the input.
        if !fqdn.ends_with('.') {
            fqdn.push('.');
        }
        fqdn.into_boxed_str()
    }
}
579
/// Produces the label sequence used for reverse (PTR-style) DNS names.
pub trait ToReverseName {
    /// Returns the address's components in reverse order, joined by `.`.
    /// IPv4 uses decimal octets. IPv6 uses lowercase hex nibbles. No zone suffix
    /// such as `in-addr.arpa` is appended.
    fn to_reverse_name(&self) -> String;
}

impl ToReverseName for IpAddr {
    fn to_reverse_name(&self) -> String {
        match self {
            IpAddr::V4(ip) => {
                let [a, b, c, d] = ip.octets();
                format!("{d}.{c}.{b}.{a}")
            }
            IpAddr::V6(ip) => {
                const HEX_DIGITS: &[u8; 16] = b"0123456789abcdef";
                // 32 nibbles plus 31 separating dots.
                let mut name = String::with_capacity(63);
                // Walk the bytes from last to first. Within each byte the low nibble
                // comes before the high nibble, which fully reverses the address's
                // hex digits.
                for nibble in ip
                    .octets()
                    .iter()
                    .rev()
                    .flat_map(|&byte| [byte & 0x0f, byte >> 4])
                {
                    if !name.is_empty() {
                        name.push('.');
                    }
                    name.push(char::from(HEX_DIGITS[usize::from(nibble)]));
                }
                name
            }
        }
    }
}
614
/// Test-only stand-in for a DNS query that always fails.
///
/// The error is chosen from markers embedded in `domain`:
/// - `_parse_error.` gives a parse error.
/// - `_invalid_record.` gives an invalid record type.
/// - `_dns_error.` gives a generic DNS error with an empty message.
/// - Anything else gives "record not found" with `NXDomain`.
#[cfg(any(test, feature = "test"))]
pub fn mock_resolve<T>(domain: &str) -> crate::Result<T> {
    let err = match domain {
        d if d.contains("_parse_error.") => Error::ParseError,
        d if d.contains("_invalid_record.") => Error::InvalidRecordType,
        d if d.contains("_dns_error.") => Error::DnsError(String::new()),
        _ => Error::DnsRecordNotFound(hickory_resolver::proto::op::ResponseCode::NXDomain),
    };
    Err(err)
}
627
#[cfg(test)]
mod test {
    use std::net::IpAddr;

    use crate::common::resolver::ToReverseName;

    /// Reverse-name generation for representative IPv4 and IPv6 addresses,
    /// including all-zero and high-value octets.
    #[test]
    fn reverse_lookup_addr() {
        for (addr, expected) in [
            ("1.2.3.4", "4.3.2.1"),
            ("0.0.0.0", "0.0.0.0"),
            ("255.254.253.252", "252.253.254.255"),
            (
                "2001:db8::cb01",
                "1.0.b.c.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2",
            ),
            (
                "2a01:4f9:c011:b43c::1",
                "1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.c.3.4.b.1.1.0.c.9.f.4.0.1.0.a.2",
            ),
        ] {
            assert_eq!(
                addr.parse::<IpAddr>().unwrap().to_reverse_name(),
                expected,
                "reverse name of {addr}"
            );
        }
    }

    /// IPv6 edge cases: every nibble must be emitted (32 labels), including
    /// leading and trailing zero runs and `f` nibbles.
    #[test]
    fn reverse_lookup_addr_ipv6_edges() {
        let all_zero = vec!["0"; 32].join(".");
        let loopback = format!("1.{}", vec!["0"; 31].join("."));
        let high_nibbles = format!("{}.f.f.f.f", vec!["0"; 28].join("."));

        for (addr, expected) in [
            ("::", all_zero),
            ("::1", loopback),
            ("ffff::", high_nibbles),
        ] {
            let reverse = addr.parse::<IpAddr>().unwrap().to_reverse_name();
            assert_eq!(reverse, expected, "reverse name of {addr}");
            assert_eq!(reverse.split('.').count(), 32, "label count for {addr}");
        }
    }
}