mail_auth/common/resolver.rs

1/*
2 * SPDX-FileCopyrightText: 2020 Stalwart Labs LLC <hello@stalw.art>
3 *
4 * SPDX-License-Identifier: Apache-2.0 OR MIT
5 */
6
7use std::{
8    borrow::Cow,
9    net::{IpAddr, Ipv4Addr, Ipv6Addr},
10    sync::Arc,
11    time::Instant,
12};
13
14use hickory_resolver::{
15    config::{ResolverConfig, ResolverOpts},
16    name_server::TokioConnectionProvider,
17    proto::{ProtoError, ProtoErrorKind},
18    system_conf::read_system_conf,
19    Name, TokioResolver,
20};
21
22use crate::{
23    dkim::{Atps, DomainKeyReport},
24    dmarc::Dmarc,
25    mta_sts::{MtaSts, TlsRpt},
26    spf::{Macro, Spf},
27    Error, IpLookupStrategy, MessageAuthenticator, ResolverCache, Txt, MX,
28};
29
30use super::{parse::TxtRecordParser, verify::DomainKey};
31
/// A resolved DNS value paired with the instant at which it expires.
pub struct DnsEntry<T> {
    /// The resolved data (the raw lookups in this file store an `Arc<Vec<…>>` of addresses).
    pub entry: T,
    /// Expiry instant; the lookups in this file fill it from the answer's `valid_until()`.
    pub expires: Instant,
}
36
37impl MessageAuthenticator {
38    pub fn new_cloudflare_tls() -> Result<Self, ProtoError> {
39        Self::new(ResolverConfig::cloudflare_tls(), ResolverOpts::default())
40    }
41
42    pub fn new_cloudflare() -> Result<Self, ProtoError> {
43        Self::new(ResolverConfig::cloudflare(), ResolverOpts::default())
44    }
45
46    pub fn new_google() -> Result<Self, ProtoError> {
47        Self::new(ResolverConfig::google(), ResolverOpts::default())
48    }
49
50    pub fn new_quad9() -> Result<Self, ProtoError> {
51        Self::new(ResolverConfig::quad9(), ResolverOpts::default())
52    }
53
54    pub fn new_quad9_tls() -> Result<Self, ProtoError> {
55        Self::new(ResolverConfig::quad9_tls(), ResolverOpts::default())
56    }
57
58    pub fn new_system_conf() -> Result<Self, ProtoError> {
59        let (config, options) = read_system_conf()?;
60        Self::new(config, options)
61    }
62
63    pub fn new(config: ResolverConfig, options: ResolverOpts) -> Result<Self, ProtoError> {
64        Ok(MessageAuthenticator(
65            TokioResolver::builder_with_config(config, TokioConnectionProvider::default())
66                .with_options(options)
67                .build(),
68        ))
69    }
70
71    pub fn resolver(&self) -> &TokioResolver {
72        &self.0
73    }
74
75    pub async fn txt_raw_lookup(&self, key: impl IntoFqdn<'_>) -> crate::Result<Vec<u8>> {
76        let mut result = vec![];
77        for record in self
78            .0
79            .txt_lookup(Name::from_str_relaxed(key.into_fqdn().as_ref())?)
80            .await?
81            .as_lookup()
82            .record_iter()
83        {
84            if let Some(txt_data) = record.data().as_txt() {
85                for item in txt_data.txt_data() {
86                    result.extend_from_slice(item);
87                }
88            }
89        }
90
91        Ok(result)
92    }
93
94    pub async fn txt_lookup<'x, T: TxtRecordParser + Into<Txt> + UnwrapTxtRecord>(
95        &self,
96        key: impl IntoFqdn<'x>,
97        cache: Option<&impl ResolverCache<String, Txt>>,
98    ) -> crate::Result<Arc<T>> {
99        let key = key.into_fqdn();
100        if let Some(value) = cache.as_ref().and_then(|c| c.get(key.as_ref())) {
101            return T::unwrap_txt(value);
102        }
103
104        #[cfg(any(test, feature = "test"))]
105        if true {
106            return mock_resolve(key.as_ref());
107        }
108
109        let txt_lookup = self
110            .0
111            .txt_lookup(Name::from_str_relaxed(key.as_ref())?)
112            .await?;
113        let mut result = Err(Error::InvalidRecordType);
114        let records = txt_lookup.as_lookup().record_iter().filter_map(|r| {
115            let txt_data = r.data().as_txt()?.txt_data();
116            match txt_data.len() {
117                1 => Some(Cow::from(txt_data[0].as_ref())),
118                0 => None,
119                _ => {
120                    let mut entry = Vec::with_capacity(255 * txt_data.len());
121                    for data in txt_data {
122                        entry.extend_from_slice(data);
123                    }
124                    Some(Cow::from(entry))
125                }
126            }
127        });
128
129        for record in records {
130            result = T::parse(record.as_ref());
131            if result.is_ok() {
132                break;
133            }
134        }
135
136        let result: Txt = result.into();
137
138        if let Some(cache) = cache {
139            cache.insert(key.into_owned(), result.clone(), txt_lookup.valid_until());
140        }
141
142        T::unwrap_txt(result)
143    }
144
145    pub async fn mx_lookup<'x>(
146        &self,
147        key: impl IntoFqdn<'x>,
148        cache: Option<&impl ResolverCache<String, Arc<Vec<MX>>>>,
149    ) -> crate::Result<Arc<Vec<MX>>> {
150        let key = key.into_fqdn();
151        if let Some(value) = cache.as_ref().and_then(|c| c.get(key.as_ref())) {
152            return Ok(value);
153        }
154
155        #[cfg(any(test, feature = "test"))]
156        if true {
157            return mock_resolve(key.as_ref());
158        }
159
160        let mx_lookup = self
161            .0
162            .mx_lookup(Name::from_str_relaxed(key.as_ref())?)
163            .await?;
164        let mx_records = mx_lookup.as_lookup().records();
165        let mut records: Vec<MX> = Vec::with_capacity(mx_records.len());
166        for mx_record in mx_records {
167            if let Some(mx) = mx_record.data().as_mx() {
168                let preference = mx.preference();
169                let exchange = mx.exchange().to_lowercase().to_string();
170
171                if let Some(record) = records.iter_mut().find(|r| r.preference == preference) {
172                    record.exchanges.push(exchange);
173                } else {
174                    records.push(MX {
175                        exchanges: vec![exchange],
176                        preference,
177                    });
178                }
179            }
180        }
181
182        records.sort_unstable_by(|a, b| a.preference.cmp(&b.preference));
183        let records = Arc::new(records);
184
185        if let Some(cache) = cache {
186            cache.insert(key.into_owned(), records.clone(), mx_lookup.valid_until());
187        }
188
189        Ok(records)
190    }
191
192    pub async fn ipv4_lookup<'x>(
193        &self,
194        key: impl IntoFqdn<'x>,
195        cache: Option<&impl ResolverCache<String, Arc<Vec<Ipv4Addr>>>>,
196    ) -> crate::Result<Arc<Vec<Ipv4Addr>>> {
197        let key = key.into_fqdn();
198        if let Some(value) = cache.as_ref().and_then(|c| c.get(key.as_ref())) {
199            return Ok(value);
200        }
201
202        let ipv4_lookup = self.ipv4_lookup_raw(key.as_ref()).await?;
203
204        if let Some(cache) = cache {
205            cache.insert(
206                key.into_owned(),
207                ipv4_lookup.entry.clone(),
208                ipv4_lookup.expires,
209            );
210        }
211
212        Ok(ipv4_lookup.entry)
213    }
214
215    pub async fn ipv4_lookup_raw(&self, key: &str) -> crate::Result<DnsEntry<Arc<Vec<Ipv4Addr>>>> {
216        #[cfg(any(test, feature = "test"))]
217        if true {
218            return mock_resolve(key);
219        }
220
221        let ipv4_lookup = self.0.ipv4_lookup(Name::from_str_relaxed(key)?).await?;
222        let ips: Arc<Vec<Ipv4Addr>> = ipv4_lookup
223            .as_lookup()
224            .record_iter()
225            .filter_map(|r| r.data().as_a()?.0.into())
226            .collect::<Vec<Ipv4Addr>>()
227            .into();
228
229        Ok(DnsEntry {
230            entry: ips,
231            expires: ipv4_lookup.valid_until(),
232        })
233    }
234
235    pub async fn ipv6_lookup<'x>(
236        &self,
237        key: impl IntoFqdn<'x>,
238        cache: Option<&impl ResolverCache<String, Arc<Vec<Ipv6Addr>>>>,
239    ) -> crate::Result<Arc<Vec<Ipv6Addr>>> {
240        let key = key.into_fqdn();
241        if let Some(value) = cache.as_ref().and_then(|c| c.get(key.as_ref())) {
242            return Ok(value);
243        }
244
245        let ipv6_lookup = self.ipv6_lookup_raw(key.as_ref()).await?;
246
247        if let Some(cache) = cache {
248            cache.insert(
249                key.into_owned(),
250                ipv6_lookup.entry.clone(),
251                ipv6_lookup.expires,
252            );
253        }
254
255        Ok(ipv6_lookup.entry)
256    }
257
258    pub async fn ipv6_lookup_raw(&self, key: &str) -> crate::Result<DnsEntry<Arc<Vec<Ipv6Addr>>>> {
259        #[cfg(any(test, feature = "test"))]
260        if true {
261            return mock_resolve(key);
262        }
263
264        let ipv6_lookup = self.0.ipv6_lookup(Name::from_str_relaxed(key)?).await?;
265        let ips: Arc<Vec<Ipv6Addr>> = ipv6_lookup
266            .as_lookup()
267            .record_iter()
268            .filter_map(|r| r.data().as_aaaa()?.0.into())
269            .collect::<Vec<Ipv6Addr>>()
270            .into();
271
272        Ok(DnsEntry {
273            entry: ips,
274            expires: ipv6_lookup.valid_until(),
275        })
276    }
277
278    pub async fn ip_lookup(
279        &self,
280        key: &str,
281        mut strategy: IpLookupStrategy,
282        max_results: usize,
283        cache_ipv4: Option<&impl ResolverCache<String, Arc<Vec<Ipv4Addr>>>>,
284        cache_ipv6: Option<&impl ResolverCache<String, Arc<Vec<Ipv6Addr>>>>,
285    ) -> crate::Result<Vec<IpAddr>> {
286        loop {
287            match strategy {
288                IpLookupStrategy::Ipv4Only | IpLookupStrategy::Ipv4thenIpv6 => {
289                    match (self.ipv4_lookup(key, cache_ipv4).await, strategy) {
290                        (Ok(result), _) => {
291                            return Ok(result
292                                .iter()
293                                .take(max_results)
294                                .copied()
295                                .map(IpAddr::from)
296                                .collect())
297                        }
298                        (Err(err), IpLookupStrategy::Ipv4Only) => return Err(err),
299                        _ => {
300                            strategy = IpLookupStrategy::Ipv6Only;
301                        }
302                    }
303                }
304                IpLookupStrategy::Ipv6Only | IpLookupStrategy::Ipv6thenIpv4 => {
305                    match (self.ipv6_lookup(key, cache_ipv6).await, strategy) {
306                        (Ok(result), _) => {
307                            return Ok(result
308                                .iter()
309                                .take(max_results)
310                                .copied()
311                                .map(IpAddr::from)
312                                .collect())
313                        }
314                        (Err(err), IpLookupStrategy::Ipv6Only) => return Err(err),
315                        _ => {
316                            strategy = IpLookupStrategy::Ipv4Only;
317                        }
318                    }
319                }
320            }
321        }
322    }
323
324    pub async fn ptr_lookup(
325        &self,
326        addr: IpAddr,
327        cache: Option<&impl ResolverCache<IpAddr, Arc<Vec<String>>>>,
328    ) -> crate::Result<Arc<Vec<String>>> {
329        if let Some(value) = cache.as_ref().and_then(|c| c.get(&addr)) {
330            return Ok(value);
331        }
332
333        #[cfg(any(test, feature = "test"))]
334        if true {
335            return mock_resolve(&addr.to_string());
336        }
337
338        let ptr_lookup = self.0.reverse_lookup(addr).await?;
339        let ptr: Arc<Vec<String>> = ptr_lookup
340            .as_lookup()
341            .record_iter()
342            .filter_map(|r| {
343                let r = r.data().as_ptr()?;
344                if !r.is_empty() {
345                    r.to_lowercase().to_string().into()
346                } else {
347                    None
348                }
349            })
350            .collect::<Vec<String>>()
351            .into();
352
353        if let Some(cache) = cache {
354            cache.insert(addr, ptr.clone(), ptr_lookup.valid_until());
355        }
356
357        Ok(ptr)
358    }
359
360    #[cfg(any(test, feature = "test"))]
361    pub async fn exists<'x>(
362        &self,
363        key: impl IntoFqdn<'x>,
364        cache_ipv4: Option<&impl ResolverCache<String, Arc<Vec<Ipv4Addr>>>>,
365        cache_ipv6: Option<&impl ResolverCache<String, Arc<Vec<Ipv6Addr>>>>,
366    ) -> crate::Result<bool> {
367        let key = key.into_fqdn().into_owned();
368        match self.ipv4_lookup(key.as_str(), cache_ipv4).await {
369            Ok(_) => Ok(true),
370            Err(Error::DnsRecordNotFound(_)) => {
371                match self.ipv6_lookup(key.as_str(), cache_ipv6).await {
372                    Ok(_) => Ok(true),
373                    Err(Error::DnsRecordNotFound(_)) => Ok(false),
374                    Err(err) => Err(err),
375                }
376            }
377            Err(err) => Err(err),
378        }
379    }
380
381    #[cfg(not(any(test, feature = "test")))]
382    pub async fn exists<'x>(
383        &self,
384        key: impl IntoFqdn<'x>,
385        cache_ipv4: Option<&impl ResolverCache<String, Arc<Vec<Ipv4Addr>>>>,
386        cache_ipv6: Option<&impl ResolverCache<String, Arc<Vec<Ipv6Addr>>>>,
387    ) -> crate::Result<bool> {
388        let key = key.into_fqdn();
389
390        if cache_ipv4.is_some_and(|c| c.get(key.as_ref()).is_some())
391            || cache_ipv6.is_some_and(|c| c.get(key.as_ref()).is_some())
392        {
393            return Ok(true);
394        }
395
396        match self
397            .0
398            .lookup_ip(Name::from_str_relaxed(key.as_ref())?)
399            .await
400        {
401            Ok(result) => Ok(result.as_lookup().record_iter().any(|r| {
402                matches!(
403                    r.data().record_type(),
404                    hickory_resolver::proto::rr::RecordType::A
405                        | hickory_resolver::proto::rr::RecordType::AAAA
406                )
407            })),
408            Err(err) => match err.kind() {
409                ProtoErrorKind::NoRecordsFound { .. } => Ok(false),
410                _ => Err(err.into()),
411            },
412        }
413    }
414}
415
416impl From<ProtoError> for Error {
417    fn from(err: ProtoError) -> Self {
418        match err.kind() {
419            ProtoErrorKind::NoRecordsFound(response_code) => {
420                Error::DnsRecordNotFound(response_code.response_code)
421            }
422            _ => Error::DnsError(err.to_string()),
423        }
424    }
425}
426
427impl From<DomainKey> for Txt {
428    fn from(v: DomainKey) -> Self {
429        Txt::DomainKey(v.into())
430    }
431}
432
433impl From<DomainKeyReport> for Txt {
434    fn from(v: DomainKeyReport) -> Self {
435        Txt::DomainKeyReport(v.into())
436    }
437}
438
439impl From<Atps> for Txt {
440    fn from(v: Atps) -> Self {
441        Txt::Atps(v.into())
442    }
443}
444
445impl From<Spf> for Txt {
446    fn from(v: Spf) -> Self {
447        Txt::Spf(v.into())
448    }
449}
450
451impl From<Macro> for Txt {
452    fn from(v: Macro) -> Self {
453        Txt::SpfMacro(v.into())
454    }
455}
456
457impl From<Dmarc> for Txt {
458    fn from(v: Dmarc) -> Self {
459        Txt::Dmarc(v.into())
460    }
461}
462
463impl From<MtaSts> for Txt {
464    fn from(v: MtaSts) -> Self {
465        Txt::MtaSts(v.into())
466    }
467}
468
469impl From<TlsRpt> for Txt {
470    fn from(v: TlsRpt) -> Self {
471        Txt::TlsRpt(v.into())
472    }
473}
474
475impl<T: Into<Txt>> From<crate::Result<T>> for Txt {
476    fn from(v: crate::Result<T>) -> Self {
477        match v {
478            Ok(v) => v.into(),
479            Err(err) => Txt::Error(err),
480        }
481    }
482}
483
484pub trait UnwrapTxtRecord: Sized {
485    fn unwrap_txt(txt: Txt) -> crate::Result<Arc<Self>>;
486}
487
488impl UnwrapTxtRecord for DomainKey {
489    fn unwrap_txt(txt: Txt) -> crate::Result<Arc<Self>> {
490        match txt {
491            Txt::DomainKey(a) => Ok(a),
492            Txt::Error(err) => Err(err),
493            _ => Err(Error::Io("Invalid record type".to_string())),
494        }
495    }
496}
497
498impl UnwrapTxtRecord for DomainKeyReport {
499    fn unwrap_txt(txt: Txt) -> crate::Result<Arc<Self>> {
500        match txt {
501            Txt::DomainKeyReport(a) => Ok(a),
502            Txt::Error(err) => Err(err),
503            _ => Err(Error::Io("Invalid record type".to_string())),
504        }
505    }
506}
507
508impl UnwrapTxtRecord for Atps {
509    fn unwrap_txt(txt: Txt) -> crate::Result<Arc<Self>> {
510        match txt {
511            Txt::Atps(a) => Ok(a),
512            Txt::Error(err) => Err(err),
513            _ => Err(Error::Io("Invalid record type".to_string())),
514        }
515    }
516}
517
518impl UnwrapTxtRecord for Spf {
519    fn unwrap_txt(txt: Txt) -> crate::Result<Arc<Self>> {
520        match txt {
521            Txt::Spf(a) => Ok(a),
522            Txt::Error(err) => Err(err),
523            _ => Err(Error::Io("Invalid record type".to_string())),
524        }
525    }
526}
527
528impl UnwrapTxtRecord for Macro {
529    fn unwrap_txt(txt: Txt) -> crate::Result<Arc<Self>> {
530        match txt {
531            Txt::SpfMacro(a) => Ok(a),
532            Txt::Error(err) => Err(err),
533            _ => Err(Error::Io("Invalid record type".to_string())),
534        }
535    }
536}
537
538impl UnwrapTxtRecord for Dmarc {
539    fn unwrap_txt(txt: Txt) -> crate::Result<Arc<Self>> {
540        match txt {
541            Txt::Dmarc(a) => Ok(a),
542            Txt::Error(err) => Err(err),
543            _ => Err(Error::Io("Invalid record type".to_string())),
544        }
545    }
546}
547
548impl UnwrapTxtRecord for MtaSts {
549    fn unwrap_txt(txt: Txt) -> crate::Result<Arc<Self>> {
550        match txt {
551            Txt::MtaSts(a) => Ok(a),
552            Txt::Error(err) => Err(err),
553            _ => Err(Error::Io("Invalid record type".to_string())),
554        }
555    }
556}
557
558impl UnwrapTxtRecord for TlsRpt {
559    fn unwrap_txt(txt: Txt) -> crate::Result<Arc<Self>> {
560        match txt {
561            Txt::TlsRpt(a) => Ok(a),
562            Txt::Error(err) => Err(err),
563            _ => Err(Error::Io("Invalid record type".to_string())),
564        }
565    }
566}
567
/// Converts a domain name into its lowercase, fully-qualified form (with a trailing `.`).
pub trait IntoFqdn<'x> {
    /// Returns the lowercased name, appending a trailing `.` when it is missing.
    fn into_fqdn(self) -> Cow<'x, str>;
}

/// Lowercases `name` and appends the trailing root dot if it is missing.
///
/// The dot is pushed into the lowercased buffer, so only one allocation is made.
/// Checking for the dot after lowercasing is equivalent to checking before, because no
/// character lowercases to or from `.`.
fn lowercase_fqdn(name: &str) -> String {
    let mut fqdn = name.to_lowercase();
    if !fqdn.ends_with('.') {
        fqdn.push('.');
    }
    fqdn
}

impl<'x> IntoFqdn<'x> for String {
    fn into_fqdn(self) -> Cow<'x, str> {
        if self.is_ascii() {
            // For ASCII input `to_lowercase` equals `make_ascii_lowercase`, which can
            // reuse this string's allocation.
            let mut fqdn = self;
            fqdn.make_ascii_lowercase();
            if !fqdn.ends_with('.') {
                fqdn.push('.');
            }
            Cow::Owned(fqdn)
        } else {
            Cow::Owned(lowercase_fqdn(&self))
        }
    }
}

impl<'x> IntoFqdn<'x> for &'x str {
    fn into_fqdn(self) -> Cow<'x, str> {
        // Names that are already normalized (ASCII, no uppercase, trailing dot) are
        // borrowed as-is instead of being copied.
        let normalized =
            self.ends_with('.') && self.bytes().all(|b| b.is_ascii() && !b.is_ascii_uppercase());
        if normalized {
            Cow::Borrowed(self)
        } else {
            Cow::Owned(lowercase_fqdn(self))
        }
    }
}

impl<'x> IntoFqdn<'x> for &String {
    fn into_fqdn(self) -> Cow<'x, str> {
        // The borrow is not tied to `'x`, so the result must be owned.
        Cow::Owned(lowercase_fqdn(self))
    }
}
601
/// Produces the reversed form of an address used to build reverse-DNS names.
/// No zone suffix is appended.
pub trait ToReverseName {
    /// Returns the address components in reverse order, separated by dots.
    fn to_reverse_name(&self) -> String;
}

impl ToReverseName for IpAddr {
    fn to_reverse_name(&self) -> String {
        const HEX_DIGITS: &[u8; 16] = b"0123456789abcdef";

        match self {
            // IPv4: the four decimal octets, last first ("1.2.3.4" -> "4.3.2.1").
            IpAddr::V4(ip) => {
                let [a, b, c, d] = ip.octets();
                format!("{d}.{c}.{b}.{a}")
            }
            // IPv6: one lowercase hex digit per nibble, least-significant nibble of the
            // last byte first.
            IpAddr::V6(ip) => {
                let mut name = String::with_capacity(63);
                for &byte in ip.octets().iter().rev() {
                    for nibble in [byte & 0x0f, byte >> 4] {
                        if !name.is_empty() {
                            name.push('.');
                        }
                        name.push(char::from(HEX_DIGITS[usize::from(nibble)]));
                    }
                }
                name
            }
        }
    }
}
636
/// Test-only stand-in for DNS resolution that always fails.
///
/// The error is chosen by markers in `domain`, checked in this order:
/// `_parse_error.`, `_invalid_record.`, `_dns_error.`. Any other domain yields an
/// NXDOMAIN "record not found" error.
#[cfg(any(test, feature = "test"))]
pub fn mock_resolve<T>(domain: &str) -> crate::Result<T> {
    let error = match domain {
        d if d.contains("_parse_error.") => Error::ParseError,
        d if d.contains("_invalid_record.") => Error::InvalidRecordType,
        d if d.contains("_dns_error.") => Error::DnsError(String::new()),
        _ => Error::DnsRecordNotFound(hickory_resolver::proto::op::ResponseCode::NXDomain),
    };
    Err(error)
}
649
#[cfg(test)]
mod test {
    use std::net::IpAddr;

    use crate::common::resolver::ToReverseName;

    /// `to_reverse_name` must emit IPv4 octets and IPv6 nibbles in reverse order.
    #[test]
    fn reverse_lookup_addr() {
        let cases = [
            ("1.2.3.4", "4.3.2.1"),
            (
                "2001:db8::cb01",
                "1.0.b.c.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2",
            ),
            (
                "2a01:4f9:c011:b43c::1",
                "1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.c.3.4.b.1.1.0.c.9.f.4.0.1.0.a.2",
            ),
        ];

        for (input, expected) in cases {
            let ip: IpAddr = input.parse().unwrap();
            assert_eq!(ip.to_reverse_name(), expected);
        }
    }
}