Skip to main content

mail_auth/common/
resolver.rs

/*
 * SPDX-FileCopyrightText: 2020 Stalwart Labs LLC <hello@stalw.art>
 *
 * SPDX-License-Identifier: Apache-2.0 OR MIT
 */
6
7use super::{parse::TxtRecordParser, verify::DomainKey};
8use crate::{
9    Error, IpLookupStrategy, MX, MessageAuthenticator, ResolverCache, Txt,
10    dkim::{Atps, DomainKeyReport},
11    dmarc::Dmarc,
12    mta_sts::{MtaSts, TlsRpt},
13    spf::{Macro, Spf},
14};
15use hickory_resolver::{
16    TokioResolver,
17    config::{CLOUDFLARE, GOOGLE, QUAD9, ResolverConfig, ResolverOpts},
18    net::{DnsError, NetError, runtime::TokioRuntimeProvider},
19    proto::{
20        ProtoError,
21        rr::{Name, RData},
22    },
23    system_conf::read_system_conf,
24};
25use std::{
26    borrow::Cow,
27    net::{IpAddr, Ipv4Addr, Ipv6Addr},
28    sync::Arc,
29    time::Instant,
30};
31
/// A resolved DNS value paired with the instant at which it stops being valid.
///
/// Returned by the `*_lookup_raw` methods so that callers can store the value
/// (for example in a cache) until `expires`.
#[derive(Debug, Clone)]
pub struct DnsEntry<T> {
    /// The resolved data, e.g. the list of addresses found.
    pub entry: T,
    /// When the underlying DNS answer expires.
    pub expires: Instant,
}
36
37impl MessageAuthenticator {
38    pub fn new_cloudflare_tls() -> Result<Self, NetError> {
39        Self::new(ResolverConfig::tls(&CLOUDFLARE), ResolverOpts::default())
40    }
41
42    pub fn new_cloudflare() -> Result<Self, NetError> {
43        Self::new(
44            ResolverConfig::udp_and_tcp(&CLOUDFLARE),
45            ResolverOpts::default(),
46        )
47    }
48
49    pub fn new_google() -> Result<Self, NetError> {
50        Self::new(
51            ResolverConfig::udp_and_tcp(&GOOGLE),
52            ResolverOpts::default(),
53        )
54    }
55
56    pub fn new_quad9() -> Result<Self, NetError> {
57        Self::new(ResolverConfig::udp_and_tcp(&QUAD9), ResolverOpts::default())
58    }
59
60    pub fn new_quad9_tls() -> Result<Self, NetError> {
61        Self::new(ResolverConfig::tls(&QUAD9), ResolverOpts::default())
62    }
63
64    pub fn new_system_conf() -> Result<Self, NetError> {
65        let (config, options) = read_system_conf()?;
66        Self::new(config, options)
67    }
68
69    pub fn new(config: ResolverConfig, options: ResolverOpts) -> Result<Self, NetError> {
70        Ok(MessageAuthenticator(
71            TokioResolver::builder_with_config(config, TokioRuntimeProvider::default())
72                .with_options(options)
73                .build()?,
74        ))
75    }
76
77    pub fn resolver(&self) -> &TokioResolver {
78        &self.0
79    }
80
81    pub async fn txt_raw_lookup(&self, key: impl ToFqdn) -> crate::Result<Vec<u8>> {
82        let mut result = vec![];
83        for record in self
84            .0
85            .txt_lookup(Name::from_str_relaxed::<&str>(key.to_fqdn().as_ref())?)
86            .await?
87            .answers()
88        {
89            if let RData::TXT(txt) = &record.data {
90                for item in &txt.txt_data {
91                    result.extend_from_slice(item);
92                }
93            }
94        }
95
96        Ok(result)
97    }
98
99    pub async fn txt_lookup<T: TxtRecordParser + Into<Txt> + UnwrapTxtRecord>(
100        &self,
101        key: impl ToFqdn,
102        cache: Option<&impl ResolverCache<Box<str>, Txt>>,
103    ) -> crate::Result<Arc<T>> {
104        let key = key.to_fqdn();
105        if let Some(value) = cache.as_ref().and_then(|c| c.get::<str>(key.as_ref())) {
106            return T::unwrap_txt(value);
107        }
108
109        #[cfg(any(test, feature = "test"))]
110        if true {
111            return mock_resolve(key.as_ref());
112        }
113
114        let txt_lookup = self
115            .0
116            .txt_lookup(Name::from_str_relaxed::<&str>(key.as_ref())?)
117            .await?;
118        let mut result = Err(Error::InvalidRecordType);
119        let records = txt_lookup.answers().iter().filter_map(|r| {
120            let RData::TXT(txt) = &r.data else {
121                return None;
122            };
123            let txt_data = &txt.txt_data;
124            match txt_data.len() {
125                1 => Some(Cow::from(txt_data[0].as_ref())),
126                0 => None,
127                _ => {
128                    let mut entry = Vec::with_capacity(255 * txt_data.len());
129                    for data in txt_data {
130                        entry.extend_from_slice(data);
131                    }
132                    Some(Cow::from(entry))
133                }
134            }
135        });
136
137        for record in records {
138            result = T::parse(record.as_ref());
139            if result.is_ok() {
140                break;
141            }
142        }
143
144        let result: Txt = result.into();
145
146        if let Some(cache) = cache {
147            cache.insert(key, result.clone(), txt_lookup.valid_until());
148        }
149
150        T::unwrap_txt(result)
151    }
152
153    pub async fn mx_lookup(
154        &self,
155        key: impl ToFqdn,
156        cache: Option<&impl ResolverCache<Box<str>, Arc<[MX]>>>,
157    ) -> crate::Result<Arc<[MX]>> {
158        let key = key.to_fqdn();
159        if let Some(value) = cache.as_ref().and_then(|c| c.get::<str>(key.as_ref())) {
160            return Ok(value);
161        }
162
163        #[cfg(any(test, feature = "test"))]
164        if true {
165            return mock_resolve(key.as_ref());
166        }
167
168        let mx_lookup = self
169            .0
170            .mx_lookup(Name::from_str_relaxed::<&str>(key.as_ref())?)
171            .await?;
172        let mx_records = mx_lookup.answers();
173        let mut records: Vec<(u16, Vec<Box<str>>)> = Vec::with_capacity(mx_records.len());
174        for mx_record in mx_records {
175            if let RData::MX(mx) = &mx_record.data {
176                let preference = mx.preference;
177                let exchange = mx.exchange.to_lowercase().to_string().into_boxed_str();
178
179                if let Some(record) = records.iter_mut().find(|r| r.0 == preference) {
180                    record.1.push(exchange);
181                } else {
182                    records.push((preference, vec![exchange]));
183                }
184            }
185        }
186
187        records.sort_unstable_by_key(|a| a.0);
188        let records: Arc<[MX]> = records
189            .into_iter()
190            .map(|(preference, exchanges)| MX {
191                preference,
192                exchanges: exchanges.into_boxed_slice(),
193            })
194            .collect::<Arc<[MX]>>();
195
196        if let Some(cache) = cache {
197            cache.insert(key, records.clone(), mx_lookup.valid_until());
198        }
199
200        Ok(records)
201    }
202
203    pub async fn ipv4_lookup(
204        &self,
205        key: impl ToFqdn,
206        cache: Option<&impl ResolverCache<Box<str>, Arc<[Ipv4Addr]>>>,
207    ) -> crate::Result<Arc<[Ipv4Addr]>> {
208        let key = key.to_fqdn();
209        if let Some(value) = cache.as_ref().and_then(|c| c.get::<str>(key.as_ref())) {
210            return Ok(value);
211        }
212
213        let ipv4_lookup = self.ipv4_lookup_raw(key.as_ref()).await?;
214
215        if let Some(cache) = cache {
216            cache.insert(key, ipv4_lookup.entry.clone(), ipv4_lookup.expires);
217        }
218
219        Ok(ipv4_lookup.entry)
220    }
221
222    pub async fn ipv4_lookup_raw(&self, key: &str) -> crate::Result<DnsEntry<Arc<[Ipv4Addr]>>> {
223        #[cfg(any(test, feature = "test"))]
224        if true {
225            return mock_resolve(key);
226        }
227
228        let ipv4_lookup = self
229            .0
230            .ipv4_lookup(Name::from_str_relaxed::<&str>(key)?)
231            .await?;
232        let ips: Arc<[Ipv4Addr]> = ipv4_lookup
233            .answers()
234            .iter()
235            .filter_map(|r| {
236                if let RData::A(a) = &r.data {
237                    Some(a.0)
238                } else {
239                    None
240                }
241            })
242            .collect::<Vec<Ipv4Addr>>()
243            .into();
244
245        Ok(DnsEntry {
246            entry: ips,
247            expires: ipv4_lookup.valid_until(),
248        })
249    }
250
251    pub async fn ipv6_lookup(
252        &self,
253        key: impl ToFqdn,
254        cache: Option<&impl ResolverCache<Box<str>, Arc<[Ipv6Addr]>>>,
255    ) -> crate::Result<Arc<[Ipv6Addr]>> {
256        let key = key.to_fqdn();
257        if let Some(value) = cache.as_ref().and_then(|c| c.get::<str>(key.as_ref())) {
258            return Ok(value);
259        }
260
261        let ipv6_lookup = self.ipv6_lookup_raw(key.as_ref()).await?;
262
263        if let Some(cache) = cache {
264            cache.insert(key, ipv6_lookup.entry.clone(), ipv6_lookup.expires);
265        }
266
267        Ok(ipv6_lookup.entry)
268    }
269
270    pub async fn ipv6_lookup_raw(&self, key: &str) -> crate::Result<DnsEntry<Arc<[Ipv6Addr]>>> {
271        #[cfg(any(test, feature = "test"))]
272        if true {
273            return mock_resolve(key);
274        }
275
276        let ipv6_lookup = self
277            .0
278            .ipv6_lookup(Name::from_str_relaxed::<&str>(key)?)
279            .await?;
280        let ips: Arc<[Ipv6Addr]> = ipv6_lookup
281            .answers()
282            .iter()
283            .filter_map(|r| {
284                if let RData::AAAA(aaaa) = &r.data {
285                    Some(aaaa.0)
286                } else {
287                    None
288                }
289            })
290            .collect::<Vec<Ipv6Addr>>()
291            .into();
292
293        Ok(DnsEntry {
294            entry: ips,
295            expires: ipv6_lookup.valid_until(),
296        })
297    }
298
299    pub async fn ip_lookup(
300        &self,
301        key: &str,
302        mut strategy: IpLookupStrategy,
303        max_results: usize,
304        cache_ipv4: Option<&impl ResolverCache<Box<str>, Arc<[Ipv4Addr]>>>,
305        cache_ipv6: Option<&impl ResolverCache<Box<str>, Arc<[Ipv6Addr]>>>,
306    ) -> crate::Result<Vec<IpAddr>> {
307        loop {
308            match strategy {
309                IpLookupStrategy::Ipv4Only | IpLookupStrategy::Ipv4thenIpv6 => {
310                    match (self.ipv4_lookup(key, cache_ipv4).await, strategy) {
311                        (Ok(result), _) => {
312                            return Ok(result
313                                .iter()
314                                .take(max_results)
315                                .copied()
316                                .map(IpAddr::from)
317                                .collect());
318                        }
319                        (Err(err), IpLookupStrategy::Ipv4Only) => return Err(err),
320                        _ => {
321                            strategy = IpLookupStrategy::Ipv6Only;
322                        }
323                    }
324                }
325                IpLookupStrategy::Ipv6Only | IpLookupStrategy::Ipv6thenIpv4 => {
326                    match (self.ipv6_lookup(key, cache_ipv6).await, strategy) {
327                        (Ok(result), _) => {
328                            return Ok(result
329                                .iter()
330                                .take(max_results)
331                                .copied()
332                                .map(IpAddr::from)
333                                .collect());
334                        }
335                        (Err(err), IpLookupStrategy::Ipv6Only) => return Err(err),
336                        _ => {
337                            strategy = IpLookupStrategy::Ipv4Only;
338                        }
339                    }
340                }
341            }
342        }
343    }
344
345    pub async fn ptr_lookup(
346        &self,
347        addr: IpAddr,
348        cache: Option<&impl ResolverCache<IpAddr, Arc<[Box<str>]>>>,
349    ) -> crate::Result<Arc<[Box<str>]>> {
350        if let Some(value) = cache.as_ref().and_then(|c| c.get(&addr)) {
351            return Ok(value);
352        }
353
354        #[cfg(any(test, feature = "test"))]
355        if true {
356            return mock_resolve(&addr.to_string());
357        }
358
359        let ptr_lookup = self.0.reverse_lookup(addr).await?;
360        let ptr: Arc<[Box<str>]> = ptr_lookup
361            .answers()
362            .iter()
363            .filter_map(|r| {
364                let RData::PTR(ptr) = &r.data else {
365                    return None;
366                };
367                if !ptr.is_empty() {
368                    Some(ptr.to_lowercase().to_string().into_boxed_str())
369                } else {
370                    None
371                }
372            })
373            .collect::<Arc<[Box<str>]>>();
374
375        if let Some(cache) = cache {
376            cache.insert(addr, ptr.clone(), ptr_lookup.valid_until());
377        }
378
379        Ok(ptr)
380    }
381
382    #[cfg(any(test, feature = "test"))]
383    pub async fn exists(
384        &self,
385        key: impl ToFqdn,
386        cache_ipv4: Option<&impl ResolverCache<Box<str>, Arc<[Ipv4Addr]>>>,
387        cache_ipv6: Option<&impl ResolverCache<Box<str>, Arc<[Ipv6Addr]>>>,
388    ) -> crate::Result<bool> {
389        let key = key.to_fqdn();
390        match self.ipv4_lookup(key.as_ref(), cache_ipv4).await {
391            Ok(_) => Ok(true),
392            Err(Error::DnsRecordNotFound(_)) => {
393                match self.ipv6_lookup(key.as_ref(), cache_ipv6).await {
394                    Ok(_) => Ok(true),
395                    Err(Error::DnsRecordNotFound(_)) => Ok(false),
396                    Err(err) => Err(err),
397                }
398            }
399            Err(err) => Err(err),
400        }
401    }
402
403    #[cfg(not(any(test, feature = "test")))]
404    pub async fn exists(
405        &self,
406        key: impl ToFqdn,
407        cache_ipv4: Option<&impl ResolverCache<Box<str>, Arc<[Ipv4Addr]>>>,
408        cache_ipv6: Option<&impl ResolverCache<Box<str>, Arc<[Ipv6Addr]>>>,
409    ) -> crate::Result<bool> {
410        let key = key.to_fqdn();
411
412        if cache_ipv4.is_some_and(|c| c.get::<str>(key.as_ref()).is_some())
413            || cache_ipv6.is_some_and(|c| c.get::<str>(key.as_ref()).is_some())
414        {
415            return Ok(true);
416        }
417
418        match self
419            .0
420            .lookup_ip(Name::from_str_relaxed::<&str>(key.as_ref())?)
421            .await
422        {
423            Ok(result) => Ok(result.as_lookup().answers().iter().any(|r| {
424                matches!(
425                    &r.data.record_type(),
426                    hickory_resolver::proto::rr::RecordType::A
427                        | hickory_resolver::proto::rr::RecordType::AAAA
428                )
429            })),
430            Err(err) if err.is_no_records_found() => Ok(false),
431            Err(err) => Err(err.into()),
432        }
433    }
434}
435
436impl From<ProtoError> for Error {
437    fn from(err: ProtoError) -> Self {
438        Error::DnsError(err.to_string())
439    }
440}
441
442impl From<NetError> for Error {
443    fn from(err: NetError) -> Self {
444        match &err {
445            NetError::Dns(DnsError::NoRecordsFound(no_records)) => {
446                Error::DnsRecordNotFound(no_records.response_code)
447            }
448            _ => Error::DnsError(err.to_string()),
449        }
450    }
451}
452
453impl From<DomainKey> for Txt {
454    fn from(v: DomainKey) -> Self {
455        Txt::DomainKey(v.into())
456    }
457}
458
459impl From<DomainKeyReport> for Txt {
460    fn from(v: DomainKeyReport) -> Self {
461        Txt::DomainKeyReport(v.into())
462    }
463}
464
465impl From<Atps> for Txt {
466    fn from(v: Atps) -> Self {
467        Txt::Atps(v.into())
468    }
469}
470
471impl From<Spf> for Txt {
472    fn from(v: Spf) -> Self {
473        Txt::Spf(v.into())
474    }
475}
476
477impl From<Macro> for Txt {
478    fn from(v: Macro) -> Self {
479        Txt::SpfMacro(v.into())
480    }
481}
482
483impl From<Dmarc> for Txt {
484    fn from(v: Dmarc) -> Self {
485        Txt::Dmarc(v.into())
486    }
487}
488
489impl From<MtaSts> for Txt {
490    fn from(v: MtaSts) -> Self {
491        Txt::MtaSts(v.into())
492    }
493}
494
495impl From<TlsRpt> for Txt {
496    fn from(v: TlsRpt) -> Self {
497        Txt::TlsRpt(v.into())
498    }
499}
500
501impl<T: Into<Txt>> From<crate::Result<T>> for Txt {
502    fn from(v: crate::Result<T>) -> Self {
503        match v {
504            Ok(v) => v.into(),
505            Err(err) => Txt::Error(err),
506        }
507    }
508}
509
510pub trait UnwrapTxtRecord: Sized {
511    fn unwrap_txt(txt: Txt) -> crate::Result<Arc<Self>>;
512}
513
514impl UnwrapTxtRecord for DomainKey {
515    fn unwrap_txt(txt: Txt) -> crate::Result<Arc<Self>> {
516        match txt {
517            Txt::DomainKey(a) => Ok(a),
518            Txt::Error(err) => Err(err),
519            _ => Err(Error::Io("Invalid record type".to_string())),
520        }
521    }
522}
523
524impl UnwrapTxtRecord for DomainKeyReport {
525    fn unwrap_txt(txt: Txt) -> crate::Result<Arc<Self>> {
526        match txt {
527            Txt::DomainKeyReport(a) => Ok(a),
528            Txt::Error(err) => Err(err),
529            _ => Err(Error::Io("Invalid record type".to_string())),
530        }
531    }
532}
533
534impl UnwrapTxtRecord for Atps {
535    fn unwrap_txt(txt: Txt) -> crate::Result<Arc<Self>> {
536        match txt {
537            Txt::Atps(a) => Ok(a),
538            Txt::Error(err) => Err(err),
539            _ => Err(Error::Io("Invalid record type".to_string())),
540        }
541    }
542}
543
544impl UnwrapTxtRecord for Spf {
545    fn unwrap_txt(txt: Txt) -> crate::Result<Arc<Self>> {
546        match txt {
547            Txt::Spf(a) => Ok(a),
548            Txt::Error(err) => Err(err),
549            _ => Err(Error::Io("Invalid record type".to_string())),
550        }
551    }
552}
553
554impl UnwrapTxtRecord for Macro {
555    fn unwrap_txt(txt: Txt) -> crate::Result<Arc<Self>> {
556        match txt {
557            Txt::SpfMacro(a) => Ok(a),
558            Txt::Error(err) => Err(err),
559            _ => Err(Error::Io("Invalid record type".to_string())),
560        }
561    }
562}
563
564impl UnwrapTxtRecord for Dmarc {
565    fn unwrap_txt(txt: Txt) -> crate::Result<Arc<Self>> {
566        match txt {
567            Txt::Dmarc(a) => Ok(a),
568            Txt::Error(err) => Err(err),
569            _ => Err(Error::Io("Invalid record type".to_string())),
570        }
571    }
572}
573
574impl UnwrapTxtRecord for MtaSts {
575    fn unwrap_txt(txt: Txt) -> crate::Result<Arc<Self>> {
576        match txt {
577            Txt::MtaSts(a) => Ok(a),
578            Txt::Error(err) => Err(err),
579            _ => Err(Error::Io("Invalid record type".to_string())),
580        }
581    }
582}
583
584impl UnwrapTxtRecord for TlsRpt {
585    fn unwrap_txt(txt: Txt) -> crate::Result<Arc<Self>> {
586        match txt {
587            Txt::TlsRpt(a) => Ok(a),
588            Txt::Error(err) => Err(err),
589            _ => Err(Error::Io("Invalid record type".to_string())),
590        }
591    }
592}
593
/// Normalizes a domain name into a lowercase, fully-qualified form.
pub trait ToFqdn {
    /// Returns the name lowercased and guaranteed to end with a single root `.`.
    fn to_fqdn(&self) -> Box<str>;
}

impl<T: AsRef<str>> ToFqdn for T {
    fn to_fqdn(&self) -> Box<str> {
        let mut fqdn = self.as_ref().to_lowercase();
        // Lowercasing never creates or removes a trailing '.', so checking
        // the lowercased text is equivalent to checking the input.
        if !fqdn.ends_with('.') {
            fqdn.push('.');
        }
        fqdn.into_boxed_str()
    }
}
608
/// Produces the reversed label sequence used for reverse-DNS style queries.
pub trait ToReverseName {
    /// Returns the address's labels in reverse order, joined by `.`, without
    /// any zone suffix.
    fn to_reverse_name(&self) -> String;
}

impl ToReverseName for IpAddr {
    fn to_reverse_name(&self) -> String {
        match self {
            // IPv4: the four decimal octets, last octet first.
            IpAddr::V4(ip) => {
                let [a, b, c, d] = ip.octets();
                format!("{d}.{c}.{b}.{a}")
            }
            // IPv6: all 32 lowercase hex nibbles (zeros included), least
            // significant nibble first.
            IpAddr::V6(ip) => {
                let mut name = String::with_capacity(63);
                for &segment in ip.segments().iter().rev() {
                    for shift in [0u32, 4, 8, 12] {
                        if !name.is_empty() {
                            name.push('.');
                        }
                        let nibble = u32::from((segment >> shift) & 0xf);
                        name.push(char::from_digit(nibble, 16).expect("nibble is always < 16"));
                    }
                }
                name
            }
        }
    }
}
643
/// Test-only stand-in for a DNS query that always fails.
///
/// The error is chosen from marker labels in `domain`:
/// - `_parse_error.` gives a parse error;
/// - `_invalid_record.` gives an invalid record type;
/// - `_dns_error.` gives a generic DNS error;
/// - anything else gives `NXDomain`.
#[cfg(any(test, feature = "test"))]
pub fn mock_resolve<T>(domain: &str) -> crate::Result<T> {
    if domain.contains("_parse_error.") {
        return Err(Error::ParseError);
    }
    if domain.contains("_invalid_record.") {
        return Err(Error::InvalidRecordType);
    }
    if domain.contains("_dns_error.") {
        return Err(Error::DnsError(String::new()));
    }
    Err(Error::DnsRecordNotFound(
        hickory_resolver::proto::op::ResponseCode::NXDomain,
    ))
}
656
#[cfg(test)]
mod test {
    use std::net::IpAddr;

    use crate::common::resolver::ToReverseName;

    /// IPv4 octets and IPv6 nibbles must come out in reverse order.
    #[test]
    fn reverse_lookup_addr() {
        const CASES: &[(&str, &str)] = &[
            ("1.2.3.4", "4.3.2.1"),
            (
                "2001:db8::cb01",
                "1.0.b.c.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2",
            ),
            (
                "2a01:4f9:c011:b43c::1",
                "1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.c.3.4.b.1.1.0.c.9.f.4.0.1.0.a.2",
            ),
        ];

        for &(input, expected) in CASES {
            let ip: IpAddr = input.parse().expect("test address must parse");
            assert_eq!(ip.to_reverse_name(), expected, "reverse name of {input}");
        }
    }
}