Skip to main content

mail_auth/common/
resolver.rs

1/*
2 * SPDX-FileCopyrightText: 2020 Stalwart Labs LLC <hello@stalw.art>
3 *
4 * SPDX-License-Identifier: Apache-2.0 OR MIT
5 */
6
7use super::{parse::TxtRecordParser, verify::DomainKey};
8use crate::{
9    Error, IpLookupStrategy, MX, MessageAuthenticator, ResolverCache, Txt,
10    dkim::{Atps, DomainKeyReport},
11    dmarc::Dmarc,
12    mta_sts::{MtaSts, TlsRpt},
13    spf::{Macro, Spf},
14};
15use hickory_resolver::{
16    TokioResolver,
17    config::{CLOUDFLARE, GOOGLE, QUAD9, ResolverConfig, ResolverOpts},
18    net::{
19        DnsError, NetError,
20        runtime::TokioRuntimeProvider,
21    },
22    proto::{
23        ProtoError,
24        rr::{Name, RData},
25    },
26    system_conf::read_system_conf,
27};
28use std::{
29    borrow::Cow,
30    net::{IpAddr, Ipv4Addr, Ipv6Addr},
31    sync::Arc,
32    time::Instant,
33};
34
/// A resolved DNS value paired with the moment it stops being valid.
pub struct DnsEntry<T> {
    /// The resolved payload.
    pub entry: T,
    /// Instant after which `entry` should no longer be used.
    pub expires: Instant,
}
39
40impl MessageAuthenticator {
41    pub fn new_cloudflare_tls() -> Result<Self, NetError> {
42        Self::new(ResolverConfig::tls(&CLOUDFLARE), ResolverOpts::default())
43    }
44
45    pub fn new_cloudflare() -> Result<Self, NetError> {
46        Self::new(
47            ResolverConfig::udp_and_tcp(&CLOUDFLARE),
48            ResolverOpts::default(),
49        )
50    }
51
52    pub fn new_google() -> Result<Self, NetError> {
53        Self::new(
54            ResolverConfig::udp_and_tcp(&GOOGLE),
55            ResolverOpts::default(),
56        )
57    }
58
59    pub fn new_quad9() -> Result<Self, NetError> {
60        Self::new(ResolverConfig::udp_and_tcp(&QUAD9), ResolverOpts::default())
61    }
62
63    pub fn new_quad9_tls() -> Result<Self, NetError> {
64        Self::new(ResolverConfig::tls(&QUAD9), ResolverOpts::default())
65    }
66
67    pub fn new_system_conf() -> Result<Self, NetError> {
68        let (config, options) = read_system_conf()?;
69        Self::new(config, options)
70    }
71
72    pub fn new(config: ResolverConfig, options: ResolverOpts) -> Result<Self, NetError> {
73        Ok(MessageAuthenticator(
74            TokioResolver::builder_with_config(config, TokioRuntimeProvider::default())
75                .with_options(options)
76                .build()?,
77        ))
78    }
79
80    pub fn resolver(&self) -> &TokioResolver {
81        &self.0
82    }
83
84    pub async fn txt_raw_lookup(&self, key: impl ToFqdn) -> crate::Result<Vec<u8>> {
85        let mut result = vec![];
86        for record in self
87            .0
88            .txt_lookup(Name::from_str_relaxed::<&str>(key.to_fqdn().as_ref())?)
89            .await?
90            .answers()
91        {
92            if let RData::TXT(txt) = &record.data {
93                for item in &txt.txt_data {
94                    result.extend_from_slice(item);
95                }
96            }
97        }
98
99        Ok(result)
100    }
101
102    pub async fn txt_lookup<T: TxtRecordParser + Into<Txt> + UnwrapTxtRecord>(
103        &self,
104        key: impl ToFqdn,
105        cache: Option<&impl ResolverCache<Box<str>, Txt>>,
106    ) -> crate::Result<Arc<T>> {
107        let key = key.to_fqdn();
108        if let Some(value) = cache.as_ref().and_then(|c| c.get::<str>(key.as_ref())) {
109            return T::unwrap_txt(value);
110        }
111
112        #[cfg(any(test, feature = "test"))]
113        if true {
114            return mock_resolve(key.as_ref());
115        }
116
117        let txt_lookup = self
118            .0
119            .txt_lookup(Name::from_str_relaxed::<&str>(key.as_ref())?)
120            .await?;
121        let mut result = Err(Error::InvalidRecordType);
122        let records = txt_lookup.answers().iter().filter_map(|r| {
123            let RData::TXT(txt) = &r.data else {
124                return None;
125            };
126            let txt_data = &txt.txt_data;
127            match txt_data.len() {
128                1 => Some(Cow::from(txt_data[0].as_ref())),
129                0 => None,
130                _ => {
131                    let mut entry = Vec::with_capacity(255 * txt_data.len());
132                    for data in txt_data {
133                        entry.extend_from_slice(data);
134                    }
135                    Some(Cow::from(entry))
136                }
137            }
138        });
139
140        for record in records {
141            result = T::parse(record.as_ref());
142            if result.is_ok() {
143                break;
144            }
145        }
146
147        let result: Txt = result.into();
148
149        if let Some(cache) = cache {
150            cache.insert(key, result.clone(), txt_lookup.valid_until());
151        }
152
153        T::unwrap_txt(result)
154    }
155
156    pub async fn mx_lookup(
157        &self,
158        key: impl ToFqdn,
159        cache: Option<&impl ResolverCache<Box<str>, Arc<[MX]>>>,
160    ) -> crate::Result<Arc<[MX]>> {
161        let key = key.to_fqdn();
162        if let Some(value) = cache.as_ref().and_then(|c| c.get::<str>(key.as_ref())) {
163            return Ok(value);
164        }
165
166        #[cfg(any(test, feature = "test"))]
167        if true {
168            return mock_resolve(key.as_ref());
169        }
170
171        let mx_lookup = self
172            .0
173            .mx_lookup(Name::from_str_relaxed::<&str>(key.as_ref())?)
174            .await?;
175        let mx_records = mx_lookup.answers();
176        let mut records: Vec<(u16, Vec<Box<str>>)> = Vec::with_capacity(mx_records.len());
177        for mx_record in mx_records {
178            if let RData::MX(mx) = &mx_record.data {
179                let preference = mx.preference;
180                let exchange = mx.exchange.to_lowercase().to_string().into_boxed_str();
181
182                if let Some(record) = records.iter_mut().find(|r| r.0 == preference) {
183                    record.1.push(exchange);
184                } else {
185                    records.push((preference, vec![exchange]));
186                }
187            }
188        }
189
190        records.sort_unstable_by_key(|a| a.0);
191        let records: Arc<[MX]> = records
192            .into_iter()
193            .map(|(preference, exchanges)| MX {
194                preference,
195                exchanges: exchanges.into_boxed_slice(),
196            })
197            .collect::<Arc<[MX]>>();
198
199        if let Some(cache) = cache {
200            cache.insert(key, records.clone(), mx_lookup.valid_until());
201        }
202
203        Ok(records)
204    }
205
206    pub async fn ipv4_lookup(
207        &self,
208        key: impl ToFqdn,
209        cache: Option<&impl ResolverCache<Box<str>, Arc<[Ipv4Addr]>>>,
210    ) -> crate::Result<Arc<[Ipv4Addr]>> {
211        let key = key.to_fqdn();
212        if let Some(value) = cache.as_ref().and_then(|c| c.get::<str>(key.as_ref())) {
213            return Ok(value);
214        }
215
216        let ipv4_lookup = self.ipv4_lookup_raw(key.as_ref()).await?;
217
218        if let Some(cache) = cache {
219            cache.insert(key, ipv4_lookup.entry.clone(), ipv4_lookup.expires);
220        }
221
222        Ok(ipv4_lookup.entry)
223    }
224
225    pub async fn ipv4_lookup_raw(&self, key: &str) -> crate::Result<DnsEntry<Arc<[Ipv4Addr]>>> {
226        #[cfg(any(test, feature = "test"))]
227        if true {
228            return mock_resolve(key);
229        }
230
231        let ipv4_lookup = self
232            .0
233            .ipv4_lookup(Name::from_str_relaxed::<&str>(key)?)
234            .await?;
235        let ips: Arc<[Ipv4Addr]> = ipv4_lookup
236            .answers()
237            .iter()
238            .filter_map(|r| {
239                if let RData::A(a) = &r.data {
240                    Some(a.0)
241                } else {
242                    None
243                }
244            })
245            .collect::<Vec<Ipv4Addr>>()
246            .into();
247
248        Ok(DnsEntry {
249            entry: ips,
250            expires: ipv4_lookup.valid_until(),
251        })
252    }
253
254    pub async fn ipv6_lookup(
255        &self,
256        key: impl ToFqdn,
257        cache: Option<&impl ResolverCache<Box<str>, Arc<[Ipv6Addr]>>>,
258    ) -> crate::Result<Arc<[Ipv6Addr]>> {
259        let key = key.to_fqdn();
260        if let Some(value) = cache.as_ref().and_then(|c| c.get::<str>(key.as_ref())) {
261            return Ok(value);
262        }
263
264        let ipv6_lookup = self.ipv6_lookup_raw(key.as_ref()).await?;
265
266        if let Some(cache) = cache {
267            cache.insert(key, ipv6_lookup.entry.clone(), ipv6_lookup.expires);
268        }
269
270        Ok(ipv6_lookup.entry)
271    }
272
273    pub async fn ipv6_lookup_raw(&self, key: &str) -> crate::Result<DnsEntry<Arc<[Ipv6Addr]>>> {
274        #[cfg(any(test, feature = "test"))]
275        if true {
276            return mock_resolve(key);
277        }
278
279        let ipv6_lookup = self
280            .0
281            .ipv6_lookup(Name::from_str_relaxed::<&str>(key)?)
282            .await?;
283        let ips: Arc<[Ipv6Addr]> = ipv6_lookup
284            .answers()
285            .iter()
286            .filter_map(|r| {
287                if let RData::AAAA(aaaa) = &r.data {
288                    Some(aaaa.0)
289                } else {
290                    None
291                }
292            })
293            .collect::<Vec<Ipv6Addr>>()
294            .into();
295
296        Ok(DnsEntry {
297            entry: ips,
298            expires: ipv6_lookup.valid_until(),
299        })
300    }
301
302    pub async fn ip_lookup(
303        &self,
304        key: &str,
305        mut strategy: IpLookupStrategy,
306        max_results: usize,
307        cache_ipv4: Option<&impl ResolverCache<Box<str>, Arc<[Ipv4Addr]>>>,
308        cache_ipv6: Option<&impl ResolverCache<Box<str>, Arc<[Ipv6Addr]>>>,
309    ) -> crate::Result<Vec<IpAddr>> {
310        loop {
311            match strategy {
312                IpLookupStrategy::Ipv4Only | IpLookupStrategy::Ipv4thenIpv6 => {
313                    match (self.ipv4_lookup(key, cache_ipv4).await, strategy) {
314                        (Ok(result), _) => {
315                            return Ok(result
316                                .iter()
317                                .take(max_results)
318                                .copied()
319                                .map(IpAddr::from)
320                                .collect());
321                        }
322                        (Err(err), IpLookupStrategy::Ipv4Only) => return Err(err),
323                        _ => {
324                            strategy = IpLookupStrategy::Ipv6Only;
325                        }
326                    }
327                }
328                IpLookupStrategy::Ipv6Only | IpLookupStrategy::Ipv6thenIpv4 => {
329                    match (self.ipv6_lookup(key, cache_ipv6).await, strategy) {
330                        (Ok(result), _) => {
331                            return Ok(result
332                                .iter()
333                                .take(max_results)
334                                .copied()
335                                .map(IpAddr::from)
336                                .collect());
337                        }
338                        (Err(err), IpLookupStrategy::Ipv6Only) => return Err(err),
339                        _ => {
340                            strategy = IpLookupStrategy::Ipv4Only;
341                        }
342                    }
343                }
344            }
345        }
346    }
347
348    pub async fn ptr_lookup(
349        &self,
350        addr: IpAddr,
351        cache: Option<&impl ResolverCache<IpAddr, Arc<[Box<str>]>>>,
352    ) -> crate::Result<Arc<[Box<str>]>> {
353        if let Some(value) = cache.as_ref().and_then(|c| c.get(&addr)) {
354            return Ok(value);
355        }
356
357        #[cfg(any(test, feature = "test"))]
358        if true {
359            return mock_resolve(&addr.to_string());
360        }
361
362        let ptr_lookup = self.0.reverse_lookup(addr).await?;
363        let ptr: Arc<[Box<str>]> = ptr_lookup
364            .answers()
365            .iter()
366            .filter_map(|r| {
367                let RData::PTR(ptr) = &r.data else {
368                    return None;
369                };
370                if !ptr.is_empty() {
371                    Some(ptr.to_lowercase().to_string().into_boxed_str())
372                } else {
373                    None
374                }
375            })
376            .collect::<Arc<[Box<str>]>>();
377
378        if let Some(cache) = cache {
379            cache.insert(addr, ptr.clone(), ptr_lookup.valid_until());
380        }
381
382        Ok(ptr)
383    }
384
385    #[cfg(any(test, feature = "test"))]
386    pub async fn exists(
387        &self,
388        key: impl ToFqdn,
389        cache_ipv4: Option<&impl ResolverCache<Box<str>, Arc<[Ipv4Addr]>>>,
390        cache_ipv6: Option<&impl ResolverCache<Box<str>, Arc<[Ipv6Addr]>>>,
391    ) -> crate::Result<bool> {
392        let key = key.to_fqdn();
393        match self.ipv4_lookup(key.as_ref(), cache_ipv4).await {
394            Ok(_) => Ok(true),
395            Err(Error::DnsRecordNotFound(_)) => {
396                match self.ipv6_lookup(key.as_ref(), cache_ipv6).await {
397                    Ok(_) => Ok(true),
398                    Err(Error::DnsRecordNotFound(_)) => Ok(false),
399                    Err(err) => Err(err),
400                }
401            }
402            Err(err) => Err(err),
403        }
404    }
405
406    #[cfg(not(any(test, feature = "test")))]
407    pub async fn exists(
408        &self,
409        key: impl ToFqdn,
410        cache_ipv4: Option<&impl ResolverCache<Box<str>, Arc<[Ipv4Addr]>>>,
411        cache_ipv6: Option<&impl ResolverCache<Box<str>, Arc<[Ipv6Addr]>>>,
412    ) -> crate::Result<bool> {
413        let key = key.to_fqdn();
414
415        if cache_ipv4.is_some_and(|c| c.get::<str>(key.as_ref()).is_some())
416            || cache_ipv6.is_some_and(|c| c.get::<str>(key.as_ref()).is_some())
417        {
418            return Ok(true);
419        }
420
421        match self
422            .0
423            .lookup_ip(Name::from_str_relaxed::<&str>(key.as_ref())?)
424            .await
425        {
426            Ok(result) => Ok(result.as_lookup().answers().iter().any(|r| {
427                matches!(
428                    &r.data.record_type(),
429                    hickory_resolver::proto::rr::RecordType::A
430                        | hickory_resolver::proto::rr::RecordType::AAAA
431                )
432            })),
433            Err(err) if err.is_no_records_found() => Ok(false),
434            Err(err) => Err(err.into()),
435        }
436    }
437}
438
439impl From<ProtoError> for Error {
440    fn from(err: ProtoError) -> Self {
441        Error::DnsError(err.to_string())
442    }
443}
444
445impl From<NetError> for Error {
446    fn from(err: NetError) -> Self {
447        match &err {
448            NetError::Dns(DnsError::NoRecordsFound(no_records)) => {
449                Error::DnsRecordNotFound(no_records.response_code)
450            }
451            _ => Error::DnsError(err.to_string()),
452        }
453    }
454}
455
456impl From<DomainKey> for Txt {
457    fn from(v: DomainKey) -> Self {
458        Txt::DomainKey(v.into())
459    }
460}
461
462impl From<DomainKeyReport> for Txt {
463    fn from(v: DomainKeyReport) -> Self {
464        Txt::DomainKeyReport(v.into())
465    }
466}
467
468impl From<Atps> for Txt {
469    fn from(v: Atps) -> Self {
470        Txt::Atps(v.into())
471    }
472}
473
474impl From<Spf> for Txt {
475    fn from(v: Spf) -> Self {
476        Txt::Spf(v.into())
477    }
478}
479
480impl From<Macro> for Txt {
481    fn from(v: Macro) -> Self {
482        Txt::SpfMacro(v.into())
483    }
484}
485
486impl From<Dmarc> for Txt {
487    fn from(v: Dmarc) -> Self {
488        Txt::Dmarc(v.into())
489    }
490}
491
492impl From<MtaSts> for Txt {
493    fn from(v: MtaSts) -> Self {
494        Txt::MtaSts(v.into())
495    }
496}
497
498impl From<TlsRpt> for Txt {
499    fn from(v: TlsRpt) -> Self {
500        Txt::TlsRpt(v.into())
501    }
502}
503
504impl<T: Into<Txt>> From<crate::Result<T>> for Txt {
505    fn from(v: crate::Result<T>) -> Self {
506        match v {
507            Ok(v) => v.into(),
508            Err(err) => Txt::Error(err),
509        }
510    }
511}
512
513pub trait UnwrapTxtRecord: Sized {
514    fn unwrap_txt(txt: Txt) -> crate::Result<Arc<Self>>;
515}
516
517impl UnwrapTxtRecord for DomainKey {
518    fn unwrap_txt(txt: Txt) -> crate::Result<Arc<Self>> {
519        match txt {
520            Txt::DomainKey(a) => Ok(a),
521            Txt::Error(err) => Err(err),
522            _ => Err(Error::Io("Invalid record type".to_string())),
523        }
524    }
525}
526
527impl UnwrapTxtRecord for DomainKeyReport {
528    fn unwrap_txt(txt: Txt) -> crate::Result<Arc<Self>> {
529        match txt {
530            Txt::DomainKeyReport(a) => Ok(a),
531            Txt::Error(err) => Err(err),
532            _ => Err(Error::Io("Invalid record type".to_string())),
533        }
534    }
535}
536
537impl UnwrapTxtRecord for Atps {
538    fn unwrap_txt(txt: Txt) -> crate::Result<Arc<Self>> {
539        match txt {
540            Txt::Atps(a) => Ok(a),
541            Txt::Error(err) => Err(err),
542            _ => Err(Error::Io("Invalid record type".to_string())),
543        }
544    }
545}
546
547impl UnwrapTxtRecord for Spf {
548    fn unwrap_txt(txt: Txt) -> crate::Result<Arc<Self>> {
549        match txt {
550            Txt::Spf(a) => Ok(a),
551            Txt::Error(err) => Err(err),
552            _ => Err(Error::Io("Invalid record type".to_string())),
553        }
554    }
555}
556
557impl UnwrapTxtRecord for Macro {
558    fn unwrap_txt(txt: Txt) -> crate::Result<Arc<Self>> {
559        match txt {
560            Txt::SpfMacro(a) => Ok(a),
561            Txt::Error(err) => Err(err),
562            _ => Err(Error::Io("Invalid record type".to_string())),
563        }
564    }
565}
566
567impl UnwrapTxtRecord for Dmarc {
568    fn unwrap_txt(txt: Txt) -> crate::Result<Arc<Self>> {
569        match txt {
570            Txt::Dmarc(a) => Ok(a),
571            Txt::Error(err) => Err(err),
572            _ => Err(Error::Io("Invalid record type".to_string())),
573        }
574    }
575}
576
577impl UnwrapTxtRecord for MtaSts {
578    fn unwrap_txt(txt: Txt) -> crate::Result<Arc<Self>> {
579        match txt {
580            Txt::MtaSts(a) => Ok(a),
581            Txt::Error(err) => Err(err),
582            _ => Err(Error::Io("Invalid record type".to_string())),
583        }
584    }
585}
586
587impl UnwrapTxtRecord for TlsRpt {
588    fn unwrap_txt(txt: Txt) -> crate::Result<Arc<Self>> {
589        match txt {
590            Txt::TlsRpt(a) => Ok(a),
591            Txt::Error(err) => Err(err),
592            _ => Err(Error::Io("Invalid record type".to_string())),
593        }
594    }
595}
596
/// Conversion of a domain-like string into a lowercase, dot-terminated name.
pub trait ToFqdn {
    /// Returns the value lowercased and ending in exactly the trailing `.`
    /// it already had, or with one appended when it had none.
    fn to_fqdn(&self) -> Box<str>;
}

impl<T: AsRef<str>> ToFqdn for T {
    fn to_fqdn(&self) -> Box<str> {
        let mut fqdn = self.as_ref().to_lowercase();
        // Lowercasing never adds or removes a trailing '.', so checking the
        // lowercased text is equivalent to checking the input.
        if !fqdn.ends_with('.') {
            fqdn.push('.');
        }
        fqdn.into_boxed_str()
    }
}
611
/// Rendering of an IP address in reverse-lookup label order.
pub trait ToReverseName {
    /// Returns the dot-separated labels of the address in reverse order,
    /// without any zone suffix.
    fn to_reverse_name(&self) -> String;
}

impl ToReverseName for IpAddr {
    /// IPv4 addresses yield their four decimal octets reversed
    /// (`1.2.3.4` -> `4.3.2.1`). IPv6 addresses yield their 32 lowercase hex
    /// nibbles reversed, one label per nibble.
    fn to_reverse_name(&self) -> String {
        match self {
            IpAddr::V4(ip) => {
                let [a, b, c, d] = ip.octets();
                format!("{d}.{c}.{b}.{a}")
            }
            IpAddr::V6(ip) => {
                const HEX: &[u8; 16] = b"0123456789abcdef";

                // 32 nibbles plus 31 separating dots.
                let mut name = String::with_capacity(63);
                for byte in ip.octets().iter().rev() {
                    // Reverse order within a byte: low nibble first, then high.
                    // Nibbles are read directly so no temporary string is
                    // allocated per segment.
                    for nibble in [byte & 0x0f, byte >> 4] {
                        if !name.is_empty() {
                            name.push('.');
                        }
                        name.push(char::from(HEX[usize::from(nibble)]));
                    }
                }
                name
            }
        }
    }
}
646
/// Test-only stand-in for a DNS query: always fails, choosing the error from
/// a marker embedded in `domain`.
///
/// `_parse_error.` yields `Error::ParseError`, `_invalid_record.` yields
/// `Error::InvalidRecordType`, `_dns_error.` yields an empty `Error::DnsError`,
/// and any other name yields `Error::DnsRecordNotFound` with `NXDomain`.
/// Markers are checked in that order.
#[cfg(any(test, feature = "test"))]
pub fn mock_resolve<T>(domain: &str) -> crate::Result<T> {
    if domain.contains("_parse_error.") {
        return Err(Error::ParseError);
    }
    if domain.contains("_invalid_record.") {
        return Err(Error::InvalidRecordType);
    }
    if domain.contains("_dns_error.") {
        return Err(Error::DnsError(String::new()));
    }
    Err(Error::DnsRecordNotFound(
        hickory_resolver::proto::op::ResponseCode::NXDomain,
    ))
}
659
#[cfg(test)]
mod test {
    use std::net::IpAddr;

    use crate::common::resolver::ToReverseName;

    /// Checks the reversed label form for one IPv4 and two IPv6 addresses.
    #[test]
    fn reverse_lookup_addr() {
        let cases = [
            ("1.2.3.4", "4.3.2.1"),
            (
                "2001:db8::cb01",
                "1.0.b.c.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2",
            ),
            (
                "2a01:4f9:c011:b43c::1",
                "1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.c.3.4.b.1.1.0.c.9.f.4.0.1.0.a.2",
            ),
        ];

        for (addr, expected) in cases {
            let ip: IpAddr = addr.parse().unwrap();
            assert_eq!(ip.to_reverse_name(), expected);
        }
    }
}