viaspf/eval/
terms.rs

1// viaspf – implementation of the SPF specification
2// Copyright © 2020–2023 David Bürgin <dbuergin@gluet.ch>
3//
4// This program is free software: you can redistribute it and/or modify it under
5// the terms of the GNU General Public License as published by the Free Software
6// Foundation, either version 3 of the License, or (at your option) any later
7// version.
8//
9// This program is distributed in the hope that it will be useful, but WITHOUT
10// ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
11// FOR A PARTICULAR PURPOSE. See the GNU General Public License for more
12// details.
13//
14// You should have received a copy of the GNU General Public License along with
15// this program. If not, see <https://www.gnu.org/licenses/>.
16
17use crate::{
18    eval::{
19        query::{Query, Resolver},
20        EvalError, EvalResult, Evaluate, EvaluateMatch, EvaluateToString, MatchResult,
21    },
22    lookup::{self, LookupError, LookupResult, Name},
23    record::{
24        DomainSpec, DualCidrLength, Exists, ExplainString, Explanation, Include, Ip4, Ip6,
25        Mechanism, Mx, Ptr, Redirect, A,
26    },
27    result::{ErrorCause, SpfResult},
28    trace::Tracepoint,
29};
30use async_trait::async_trait;
31use std::{
32    error::Error,
33    fmt::{self, Display, Formatter},
34    net::IpAddr,
35};
36
37#[async_trait]
38impl EvaluateMatch for Mechanism {
39    async fn evaluate_match(
40        &self,
41        query: &mut Query<'_>,
42        resolver: &Resolver<'_>,
43    ) -> EvalResult<MatchResult> {
44        trace!(query, Tracepoint::EvaluateMechanism(self.clone()));
45
46        let mechanism: &(dyn EvaluateMatch + Send + Sync) = match self {
47            Self::All => return Ok(MatchResult::Match),
48            Self::Include(include) => include,
49            Self::A(a) => a,
50            Self::Mx(mx) => mx,
51            Self::Ptr(ptr) => ptr,
52            Self::Ip4(ip4) => ip4,
53            Self::Ip6(ip6) => ip6,
54            Self::Exists(exists) => exists,
55        };
56
57        mechanism.evaluate_match(query, resolver).await
58    }
59}
60
61#[async_trait]
62impl EvaluateMatch for Include {
63    async fn evaluate_match(
64        &self,
65        query: &mut Query<'_>,
66        resolver: &Resolver<'_>,
67    ) -> EvalResult<MatchResult> {
68        increment_lookup_count(query)?;
69
70        let target_name = get_target_name(&self.domain_spec, query, resolver).await?;
71        trace!(query, Tracepoint::TargetName(target_name.clone()));
72
73        let result = execute_recursive_query(query, resolver, target_name, true).await;
74
75        // See the table in §5.2.
76        use SpfResult::*;
77        match result {
78            Pass => Ok(MatchResult::Match),
79            Fail(_) | Softfail | Neutral => Ok(MatchResult::NoMatch),
80            Temperror => Err(EvalError::RecursiveTemperror),
81            Permerror => Err(EvalError::RecursivePermerror),
82            None => Err(EvalError::IncludeNoSpfRecord),
83        }
84    }
85}
86
87#[async_trait]
88impl EvaluateMatch for A {
89    async fn evaluate_match(
90        &self,
91        query: &mut Query<'_>,
92        resolver: &Resolver<'_>,
93    ) -> EvalResult<MatchResult> {
94        increment_lookup_count(query)?;
95
96        let target_name =
97            get_target_name_or_domain(self.domain_spec.as_ref(), query, resolver).await?;
98        trace!(query, Tracepoint::TargetName(target_name.clone()));
99
100        let ip = query.params.ip();
101
102        // §5.3: ‘An address lookup is done on the <target-name> using the type
103        // of lookup (A or AAAA) appropriate for the connection type (IPv4 or
104        // IPv6).’
105        let addrs = to_eval_result(resolver.lookup_a_or_aaaa(query, &target_name, ip).await)?;
106        increment_void_lookup_count_if_void(query, addrs.len())?;
107
108        for addr in addrs {
109            trace!(query, Tracepoint::TryIpAddr(addr));
110
111            // ‘The <ip> is compared to the returned address(es). If any address
112            // matches, the mechanism matches.’
113            if is_in_network(addr, self.prefix_len, ip) {
114                return Ok(MatchResult::Match);
115            }
116        }
117
118        Ok(MatchResult::NoMatch)
119    }
120}
121
122#[async_trait]
123impl EvaluateMatch for Mx {
124    async fn evaluate_match(
125        &self,
126        query: &mut Query<'_>,
127        resolver: &Resolver<'_>,
128    ) -> EvalResult<MatchResult> {
129        increment_lookup_count(query)?;
130
131        let target_name =
132            get_target_name_or_domain(self.domain_spec.as_ref(), query, resolver).await?;
133        trace!(query, Tracepoint::TargetName(target_name.clone()));
134
135        let mxs = to_eval_result(resolver.lookup_mx(query, &target_name).await)?;
136        increment_void_lookup_count_if_void(query, mxs.len())?;
137
138        let ip = query.params.ip();
139
140        let mut i = 0;
141
142        for mx in mxs {
143            trace!(query, Tracepoint::TryMxName(mx.clone()));
144
145            // §4.6.4: ‘the evaluation of each "MX" record MUST NOT result in
146            // querying more than 10 address records -- either "A" or "AAAA"
147            // resource records. If this limit is exceeded, the "mx" mechanism
148            // MUST produce a "permerror" result.’
149            increment_per_mechanism_lookup_count(query, &mut i)?;
150
151            let addrs = to_eval_result(resolver.lookup_a_or_aaaa(query, &mx, ip).await)?;
152            increment_void_lookup_count_if_void(query, addrs.len())?;
153
154            for addr in addrs {
155                trace!(query, Tracepoint::TryIpAddr(addr));
156
157                if is_in_network(addr, self.prefix_len, ip) {
158                    return Ok(MatchResult::Match);
159                }
160            }
161        }
162
163        Ok(MatchResult::NoMatch)
164    }
165}
166
167#[async_trait]
168impl EvaluateMatch for Ptr {
169    async fn evaluate_match(
170        &self,
171        query: &mut Query<'_>,
172        resolver: &Resolver<'_>,
173    ) -> EvalResult<MatchResult> {
174        increment_lookup_count(query)?;
175
176        let target_name =
177            get_target_name_or_domain(self.domain_spec.as_ref(), query, resolver).await?;
178        trace!(query, Tracepoint::TargetName(target_name.clone()));
179
180        let ip = query.params.ip();
181
182        let ptrs = match to_eval_result(resolver.lookup_ptr(query, ip).await) {
183            Ok(ptrs) => ptrs,
184            // §5.5: ‘If a DNS error occurs while doing the PTR RR lookup, then
185            // this mechanism fails to match.’
186            Err(e) => {
187                trace!(query, Tracepoint::ReverseLookupError(e));
188                return Ok(MatchResult::NoMatch);
189            }
190        };
191        increment_void_lookup_count_if_void(query, ptrs.len())?;
192
193        let validated_names = get_validated_domain_names(query, resolver, ip, ptrs).await?;
194
195        // ‘Check all validated domain names to see if they either match the
196        // <target-name> domain or are a subdomain of the <target-name> domain.
197        // If any do, this mechanism matches.’
198        for name in &validated_names {
199            trace!(query, Tracepoint::TryValidatedName(name.clone()));
200            if name == &target_name || name.is_subdomain_of(&target_name) {
201                return Ok(MatchResult::Match);
202            }
203        }
204
205        // ‘If no validated domain name can be found, or if none of the
206        // validated domain names match or are a subdomain of the <target-name>,
207        // this mechanism fails to match.’
208        Ok(MatchResult::NoMatch)
209    }
210}
211
212pub async fn get_validated_domain_names(
213    query: &mut Query<'_>,
214    resolver: &Resolver<'_>,
215    ip: IpAddr,
216    names: Vec<Name>,
217) -> EvalResult<Vec<Name>> {
218    let mut validated_names = Vec::new();
219
220    let mut i = 0;
221
222    // §5.5: ‘For each record returned, validate the domain name by looking up
223    // its IP addresses.’
224    for name in names {
225        trace!(query, Tracepoint::ValidatePtrName(name.clone()));
226
227        // §4.6.4: ‘the evaluation of each "PTR" record MUST NOT result in
228        // querying more than 10 address records -- either "A" or "AAAA"
229        // resource records. If this limit is exceeded, all records other than
230        // the first 10 MUST be ignored.’
231        if increment_per_mechanism_lookup_count(query, &mut i).is_err() {
232            trace!(query, Tracepoint::PtrAddressLookupLimitExceeded);
233            break;
234        }
235
236        let addrs = match to_eval_result(resolver.lookup_a_or_aaaa(query, &name, ip).await) {
237            Ok(addrs) => addrs,
238            // §5.5: ‘If a DNS error occurs while doing an A RR lookup, then
239            // that domain name is skipped and the search continues.’
240            Err(e) => {
241                trace!(query, Tracepoint::PtrAddressLookupError(e));
242                continue;
243            }
244        };
245        increment_void_lookup_count_if_void(query, addrs.len())?;
246
247        for addr in addrs {
248            trace!(query, Tracepoint::TryIpAddr(addr));
249
250            // §5.5: ‘If <ip> is among the returned IP addresses, then that
251            // domain name is validated.’
252            if addr == ip {
253                trace!(query, Tracepoint::PtrNameValidated);
254                validated_names.push(name);
255                break;
256            }
257        }
258    }
259
260    Ok(validated_names)
261}
262
263#[async_trait]
264impl EvaluateMatch for Ip4 {
265    async fn evaluate_match(
266        &self,
267        query: &mut Query<'_>,
268        _: &Resolver<'_>,
269    ) -> EvalResult<MatchResult> {
270        Ok(if is_in_network(self.addr, self.prefix_len, query.params.ip()) {
271            MatchResult::Match
272        } else {
273            MatchResult::NoMatch
274        })
275    }
276}
277
278#[async_trait]
279impl EvaluateMatch for Ip6 {
280    async fn evaluate_match(
281        &self,
282        query: &mut Query<'_>,
283        _: &Resolver<'_>,
284    ) -> EvalResult<MatchResult> {
285        Ok(if is_in_network(self.addr, self.prefix_len, query.params.ip()) {
286            MatchResult::Match
287        } else {
288            MatchResult::NoMatch
289        })
290    }
291}
292
293fn is_in_network<A, L>(network_addr: A, prefix_len: Option<L>, ip: IpAddr) -> bool
294where
295    A: Into<IpAddr>,
296    L: Into<DualCidrLength>,
297{
298    match (network_addr.into(), ip) {
299        (IpAddr::V4(network_addr), IpAddr::V4(ip)) => {
300            match prefix_len.and_then(|l| l.into().ip4()) {
301                // §5: ‘If no CIDR prefix length is given in the directive, then
302                // <ip> and the IP address are compared for equality.’
303                None => network_addr == ip,
304                // ‘If a CIDR prefix length is specified, then only the
305                // specified number of high-order bits of <ip> and the IP
306                // address are compared for equality.’
307                Some(len) => {
308                    let mask = u32::MAX << (32 - len.get());
309                    (u32::from(network_addr) & mask) == (u32::from(ip) & mask)
310                }
311            }
312        }
313        (IpAddr::V6(network_addr), IpAddr::V6(ip)) => {
314            match prefix_len.and_then(|l| l.into().ip6()) {
315                None => network_addr == ip,
316                Some(len) => {
317                    let mask = u128::MAX << (128 - len.get());
318                    (u128::from(network_addr) & mask) == (u128::from(ip) & mask)
319                }
320            }
321        }
322        _ => false,
323    }
324}
325
326#[async_trait]
327impl EvaluateMatch for Exists {
328    async fn evaluate_match(
329        &self,
330        query: &mut Query<'_>,
331        resolver: &Resolver<'_>,
332    ) -> EvalResult<MatchResult> {
333        increment_lookup_count(query)?;
334
335        let target_name = get_target_name(&self.domain_spec, query, resolver).await?;
336        trace!(query, Tracepoint::TargetName(target_name.clone()));
337
338        // §5.7: ‘The resulting domain name is used for a DNS A RR lookup (even
339        // when the connection type is IPv6).’
340        let addrs = to_eval_result(resolver.lookup_a(query, &target_name).await)?;
341        increment_void_lookup_count_if_void(query, addrs.len())?;
342
343        // ‘If any A record is returned, this mechanism matches.’
344        Ok(if addrs.is_empty() {
345            MatchResult::NoMatch
346        } else {
347            MatchResult::Match
348        })
349    }
350}
351
352#[async_trait]
353impl Evaluate for Redirect {
354    async fn evaluate(&self, query: &mut Query<'_>, resolver: &Resolver<'_>) -> SpfResult {
355        trace!(query, Tracepoint::EvaluateRedirect(self.clone()));
356
357        if let Err(e) = increment_lookup_count(query) {
358            trace!(query, Tracepoint::RedirectLookupLimitExceeded);
359            query.result_cause = e.to_error_cause().map(From::from);
360            return e.to_spf_result();
361        }
362
363        // §6.1: ‘if the <target-name> is malformed, the result is a "permerror"
364        // rather than "none"’
365        let target_name = match get_target_name(&self.domain_spec, query, resolver).await {
366            Ok(n) => n,
367            Err(e) => {
368                trace!(query, Tracepoint::InvalidRedirectTargetName);
369                query.result_cause = e.to_error_cause().map(From::from);
370                return e.to_spf_result();
371            }
372        };
373        trace!(query, Tracepoint::TargetName(target_name.clone()));
374
375        let result = execute_recursive_query(query, resolver, target_name, false).await;
376
377        // ‘The result of this new evaluation of check_host() is then considered
378        // the result of the current evaluation with the exception that if no
379        // SPF record is found, […] the result is a "permerror" rather than
380        // "none".’
381        match result {
382            SpfResult::None => {
383                trace!(query, Tracepoint::RedirectNoSpfRecord);
384                query.result_cause = Some(ErrorCause::NoSpfRecord.into());
385                SpfResult::Permerror
386            }
387            result => result,
388        }
389    }
390}
391
392async fn execute_recursive_query(
393    query: &mut Query<'_>,
394    resolver: &Resolver<'_>,
395    target_name: Name,
396    included: bool,
397) -> SpfResult {
398    // For recursive queries, adjust the target domain and included query flag
399    // before execution, and restore them afterwards. Included redirections keep
400    // their included flag set.
401    let prev_name = query.params.replace_domain(target_name);
402    let prev_included = query.state.is_included_query();
403    query.state.set_included_query(prev_included || included);
404
405    let result = query.execute(resolver).await;
406
407    query.params.replace_domain(prev_name);
408    query.state.set_included_query(prev_included);
409
410    result
411}
412
413#[async_trait]
414impl EvaluateToString for Explanation {
415    async fn evaluate_to_string(
416        &self,
417        query: &mut Query<'_>,
418        resolver: &Resolver<'_>,
419    ) -> EvalResult<String> {
420        trace!(query, Tracepoint::EvaluateExplanation(self.clone()));
421
422        let target_name = get_target_name(&self.domain_spec, query, resolver).await?;
423        trace!(query, Tracepoint::TargetName(target_name.clone()));
424
425        // §6.2: ‘The fetched TXT record's strings are concatenated with no
426        // spaces, and then treated as an explain-string, which is
427        // macro-expanded.’
428        let mut explain_string = match lookup_explain_string(resolver, query, &target_name).await {
429            Ok(e) => e,
430            Err(e) => {
431                // ‘If there are any DNS processing errors (any RCODE other than
432                // 0), or if no records are returned, or if more than one record
433                // is returned, or if there are syntax errors in the explanation
434                // string, then proceed as if no "exp" modifier was given.’
435                use ExplainStringLookupError::*;
436                trace!(
437                    query,
438                    match e {
439                        DnsLookup(e) => Tracepoint::ExplainStringLookupError(e),
440                        NoExplainString => Tracepoint::NoExplainString,
441                        MultipleExplainStrings(s) => Tracepoint::MultipleExplainStrings(s),
442                        Syntax(s) => Tracepoint::InvalidExplainStringSyntax(s),
443                    }
444                );
445
446                // After the tracing above, may now conflate the error causes:
447                return Err(EvalError::Dns(None));
448            }
449        };
450
451        if let Some(f) = query.config.modify_exp_fn() {
452            trace!(query, Tracepoint::ModifyExplainString(explain_string.clone()));
453            f(&mut explain_string);
454        }
455
456        explain_string.evaluate_to_string(query, resolver).await
457    }
458}
459
460#[derive(Debug)]
461enum ExplainStringLookupError {
462    DnsLookup(LookupError),
463    NoExplainString,
464    MultipleExplainStrings(Vec<String>),
465    Syntax(String),
466}
467
468impl Error for ExplainStringLookupError {}
469
470impl Display for ExplainStringLookupError {
471    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
472        write!(f, "failed to obtain explain string")
473    }
474}
475
476impl From<LookupError> for ExplainStringLookupError {
477    fn from(error: LookupError) -> Self {
478        match error {
479            LookupError::NoRecords => Self::NoExplainString,
480            _ => Self::DnsLookup(error),
481        }
482    }
483}
484
485async fn lookup_explain_string(
486    resolver: &Resolver<'_>,
487    query: &mut Query<'_>,
488    name: &Name,
489) -> Result<ExplainString, ExplainStringLookupError> {
490    let mut exps = resolver.lookup_txt(query, name).await?.into_iter();
491
492    use ExplainStringLookupError::*;
493    match exps.next() {
494        None => Err(NoExplainString),
495        Some(exp) => {
496            let mut rest = exps.collect::<Vec<_>>();
497            match *rest {
498                [] => exp.parse().map_err(|_| Syntax(exp)),
499                [..] => {
500                    rest.insert(0, exp);
501                    Err(MultipleExplainStrings(rest))
502                }
503            }
504        }
505    }
506}
507
508async fn get_target_name_or_domain(
509    domain_spec: Option<&DomainSpec>,
510    query: &mut Query<'_>,
511    resolver: &Resolver<'_>,
512) -> EvalResult<Name> {
513    // §4.8: ‘For several mechanisms, the <domain-spec> is optional. If it is
514    // not provided, the <domain> from the check_host() arguments is used as the
515    // <target-name>.’
516    match domain_spec {
517        None => Ok(query.params.domain().clone()),
518        Some(domain_spec) => get_target_name(domain_spec, query, resolver).await,
519    }
520}
521
522async fn get_target_name(
523    domain_spec: &DomainSpec,
524    query: &mut Query<'_>,
525    resolver: &Resolver<'_>,
526) -> EvalResult<Name> {
527    // §4.8: ‘The <domain-spec> string is subject to macro expansion […]. The
528    // resulting string is the common presentation form of a fully qualified DNS
529    // name’
530    let mut name = domain_spec.evaluate_to_string(query, resolver).await?;
531    truncate_target_name_string(&mut name, lookup::MAX_DOMAIN_LENGTH);
532    Name::new(&name).map_err(|_| EvalError::InvalidName(name))
533}
534
// §7.3: ‘When the result of macro expansion is used in a domain name query, if
// the expanded domain name exceeds 253 characters (the maximum length of a
// domain name in this format), the left side is truncated to fit, by removing
// successive domain labels (and their following dots) until the total length
// does not exceed 253 characters.’
//
// A single trailing dot is always removed first. If no cut after a dot brings
// the string within `max` bytes, the string is not shortened any further.
fn truncate_target_name_string(s: &mut String, max: usize) {
    if s.ends_with('.') {
        s.pop();
    }

    let total = s.len();
    if total <= max {
        return;
    }

    // Cutting after the dot at byte index `i` leaves `total - i - 1` bytes, so
    // only dots at or after `total - max - 1` are admissible. Taking the first
    // of those keeps as many labels as possible.
    let earliest = total - max - 1;
    let first_dot = s.as_bytes()[earliest..]
        .iter()
        .position(|&byte| byte == b'.');

    if let Some(offset) = first_dot {
        s.drain(..=earliest + offset);
    }
}
555
556// §4.6.4: ‘The following terms cause DNS queries: the "include", "a", "mx",
557// "ptr", and "exists" mechanisms, and the "redirect" modifier. SPF
558// implementations MUST limit the total number of those terms to 10 during SPF
559// evaluation’
560fn increment_lookup_count(query: &mut Query) -> EvalResult<()> {
561    trace!(query, Tracepoint::IncrementLookupCount);
562    query.state.increment_lookup_count(query.config.max_lookups())
563}
564
565// §4.6.4: ‘there may be cases where it is useful to limit the number of "terms"
566// for which DNS queries return either a positive answer (RCODE 0) with an
567// answer count of 0, or a "Name Error" (RCODE 3) answer. These are sometimes
568// collectively referred to as "void lookups".’
569pub fn increment_void_lookup_count_if_void(query: &mut Query, count: usize) -> EvalResult<()> {
570    if count == 0 {
571        trace!(query, Tracepoint::IncrementVoidLookupCount);
572        query.state.increment_void_lookup_count(query.config.max_void_lookups())
573    } else {
574        Ok(())
575    }
576}
577
578fn increment_per_mechanism_lookup_count(query: &mut Query, i: &mut usize) -> EvalResult<()> {
579    trace!(query, Tracepoint::IncrementPerMechanismLookupCount);
580    if *i < query.config.max_lookups() {
581        *i += 1;
582        Ok(())
583    } else {
584        Err(EvalError::PerMechanismLookupLimitExceeded)
585    }
586}
587
588pub fn to_eval_result<T>(result: LookupResult<Vec<T>>) -> EvalResult<Vec<T>> {
589    match result {
590        Ok(r) => Ok(r),
591        Err(e) => {
592            match e {
593                LookupError::Timeout => Err(EvalError::Timeout),
594                // §5: ‘If the server returns "Name Error" (RCODE 3), then
595                // evaluation of the mechanism continues as if the server
596                // returned no error (RCODE 0) and zero answer records.’
597                LookupError::NoRecords => Ok(Vec::new()),
598                LookupError::Dns(e) => Err(EvalError::Dns(e)),
599            }
600        }
601    }
602}
603
#[cfg(test)]
mod tests {
    use super::*;
    use crate::record::Ip4CidrLength;

    #[test]
    fn is_in_network_ok() {
        let network = IpAddr::from([123, 12, 12, 12]);
        let prefix_len = Ip4CidrLength::new(24).unwrap();
        let client = IpAddr::from([123, 12, 12, 98]);

        assert!(is_in_network(network, Some(prefix_len), client));
    }

    #[test]
    fn truncate_target_name_string_ok() {
        // Each case is (input, maximum length, expected result).
        let cases: &[(&str, usize, &str)] = &[
            // Pathological case where final label longer than limit → no-op:
            ("ab.cd.ef", 1, "ab.cd.ef"),
            ("ab.cd.ef.", 1, "ab.cd.ef"),
            // Truncating:
            ("ab.cd.ef", 2, "ef"),
            ("ab.cd.ef.", 2, "ef"),
            ("ab.cd.ef", 3, "ef"),
            ("ab.cd.ef", 4, "ef"),
            ("ab.cd.ef", 5, "cd.ef"),
            ("ab.cd.ef", 6, "cd.ef"),
            ("ab.cd.ef", 7, "cd.ef"),
            ("ab.cd.ef.", 7, "cd.ef"),
            // Not longer than limit → no-op:
            ("ab.cd.ef", 8, "ab.cd.ef"),
            ("ab.cd.ef.", 8, "ab.cd.ef"),
            ("ab.cd.ef", 9, "ab.cd.ef"),
        ];

        for &(input, max, expected) in cases {
            let mut s = String::from(input);
            truncate_target_name_string(&mut s, max);
            assert_eq!(s, expected);
        }
    }
}