1use crate::{
18 eval::{
19 query::{Query, Resolver},
20 EvalError, EvalResult, Evaluate, EvaluateMatch, EvaluateToString, MatchResult,
21 },
22 lookup::{self, LookupError, LookupResult, Name},
23 record::{
24 DomainSpec, DualCidrLength, Exists, ExplainString, Explanation, Include, Ip4, Ip6,
25 Mechanism, Mx, Ptr, Redirect, A,
26 },
27 result::{ErrorCause, SpfResult},
28 trace::Tracepoint,
29};
30use async_trait::async_trait;
31use std::{
32 error::Error,
33 fmt::{self, Display, Formatter},
34 net::IpAddr,
35};
36
37#[async_trait]
38impl EvaluateMatch for Mechanism {
39 async fn evaluate_match(
40 &self,
41 query: &mut Query<'_>,
42 resolver: &Resolver<'_>,
43 ) -> EvalResult<MatchResult> {
44 trace!(query, Tracepoint::EvaluateMechanism(self.clone()));
45
46 let mechanism: &(dyn EvaluateMatch + Send + Sync) = match self {
47 Self::All => return Ok(MatchResult::Match),
48 Self::Include(include) => include,
49 Self::A(a) => a,
50 Self::Mx(mx) => mx,
51 Self::Ptr(ptr) => ptr,
52 Self::Ip4(ip4) => ip4,
53 Self::Ip6(ip6) => ip6,
54 Self::Exists(exists) => exists,
55 };
56
57 mechanism.evaluate_match(query, resolver).await
58 }
59}
60
61#[async_trait]
62impl EvaluateMatch for Include {
63 async fn evaluate_match(
64 &self,
65 query: &mut Query<'_>,
66 resolver: &Resolver<'_>,
67 ) -> EvalResult<MatchResult> {
68 increment_lookup_count(query)?;
69
70 let target_name = get_target_name(&self.domain_spec, query, resolver).await?;
71 trace!(query, Tracepoint::TargetName(target_name.clone()));
72
73 let result = execute_recursive_query(query, resolver, target_name, true).await;
74
75 use SpfResult::*;
77 match result {
78 Pass => Ok(MatchResult::Match),
79 Fail(_) | Softfail | Neutral => Ok(MatchResult::NoMatch),
80 Temperror => Err(EvalError::RecursiveTemperror),
81 Permerror => Err(EvalError::RecursivePermerror),
82 None => Err(EvalError::IncludeNoSpfRecord),
83 }
84 }
85}
86
87#[async_trait]
88impl EvaluateMatch for A {
89 async fn evaluate_match(
90 &self,
91 query: &mut Query<'_>,
92 resolver: &Resolver<'_>,
93 ) -> EvalResult<MatchResult> {
94 increment_lookup_count(query)?;
95
96 let target_name =
97 get_target_name_or_domain(self.domain_spec.as_ref(), query, resolver).await?;
98 trace!(query, Tracepoint::TargetName(target_name.clone()));
99
100 let ip = query.params.ip();
101
102 let addrs = to_eval_result(resolver.lookup_a_or_aaaa(query, &target_name, ip).await)?;
106 increment_void_lookup_count_if_void(query, addrs.len())?;
107
108 for addr in addrs {
109 trace!(query, Tracepoint::TryIpAddr(addr));
110
111 if is_in_network(addr, self.prefix_len, ip) {
114 return Ok(MatchResult::Match);
115 }
116 }
117
118 Ok(MatchResult::NoMatch)
119 }
120}
121
122#[async_trait]
123impl EvaluateMatch for Mx {
124 async fn evaluate_match(
125 &self,
126 query: &mut Query<'_>,
127 resolver: &Resolver<'_>,
128 ) -> EvalResult<MatchResult> {
129 increment_lookup_count(query)?;
130
131 let target_name =
132 get_target_name_or_domain(self.domain_spec.as_ref(), query, resolver).await?;
133 trace!(query, Tracepoint::TargetName(target_name.clone()));
134
135 let mxs = to_eval_result(resolver.lookup_mx(query, &target_name).await)?;
136 increment_void_lookup_count_if_void(query, mxs.len())?;
137
138 let ip = query.params.ip();
139
140 let mut i = 0;
141
142 for mx in mxs {
143 trace!(query, Tracepoint::TryMxName(mx.clone()));
144
145 increment_per_mechanism_lookup_count(query, &mut i)?;
150
151 let addrs = to_eval_result(resolver.lookup_a_or_aaaa(query, &mx, ip).await)?;
152 increment_void_lookup_count_if_void(query, addrs.len())?;
153
154 for addr in addrs {
155 trace!(query, Tracepoint::TryIpAddr(addr));
156
157 if is_in_network(addr, self.prefix_len, ip) {
158 return Ok(MatchResult::Match);
159 }
160 }
161 }
162
163 Ok(MatchResult::NoMatch)
164 }
165}
166
167#[async_trait]
168impl EvaluateMatch for Ptr {
169 async fn evaluate_match(
170 &self,
171 query: &mut Query<'_>,
172 resolver: &Resolver<'_>,
173 ) -> EvalResult<MatchResult> {
174 increment_lookup_count(query)?;
175
176 let target_name =
177 get_target_name_or_domain(self.domain_spec.as_ref(), query, resolver).await?;
178 trace!(query, Tracepoint::TargetName(target_name.clone()));
179
180 let ip = query.params.ip();
181
182 let ptrs = match to_eval_result(resolver.lookup_ptr(query, ip).await) {
183 Ok(ptrs) => ptrs,
184 Err(e) => {
187 trace!(query, Tracepoint::ReverseLookupError(e));
188 return Ok(MatchResult::NoMatch);
189 }
190 };
191 increment_void_lookup_count_if_void(query, ptrs.len())?;
192
193 let validated_names = get_validated_domain_names(query, resolver, ip, ptrs).await?;
194
195 for name in &validated_names {
199 trace!(query, Tracepoint::TryValidatedName(name.clone()));
200 if name == &target_name || name.is_subdomain_of(&target_name) {
201 return Ok(MatchResult::Match);
202 }
203 }
204
205 Ok(MatchResult::NoMatch)
209 }
210}
211
212pub async fn get_validated_domain_names(
213 query: &mut Query<'_>,
214 resolver: &Resolver<'_>,
215 ip: IpAddr,
216 names: Vec<Name>,
217) -> EvalResult<Vec<Name>> {
218 let mut validated_names = Vec::new();
219
220 let mut i = 0;
221
222 for name in names {
225 trace!(query, Tracepoint::ValidatePtrName(name.clone()));
226
227 if increment_per_mechanism_lookup_count(query, &mut i).is_err() {
232 trace!(query, Tracepoint::PtrAddressLookupLimitExceeded);
233 break;
234 }
235
236 let addrs = match to_eval_result(resolver.lookup_a_or_aaaa(query, &name, ip).await) {
237 Ok(addrs) => addrs,
238 Err(e) => {
241 trace!(query, Tracepoint::PtrAddressLookupError(e));
242 continue;
243 }
244 };
245 increment_void_lookup_count_if_void(query, addrs.len())?;
246
247 for addr in addrs {
248 trace!(query, Tracepoint::TryIpAddr(addr));
249
250 if addr == ip {
253 trace!(query, Tracepoint::PtrNameValidated);
254 validated_names.push(name);
255 break;
256 }
257 }
258 }
259
260 Ok(validated_names)
261}
262
263#[async_trait]
264impl EvaluateMatch for Ip4 {
265 async fn evaluate_match(
266 &self,
267 query: &mut Query<'_>,
268 _: &Resolver<'_>,
269 ) -> EvalResult<MatchResult> {
270 Ok(if is_in_network(self.addr, self.prefix_len, query.params.ip()) {
271 MatchResult::Match
272 } else {
273 MatchResult::NoMatch
274 })
275 }
276}
277
278#[async_trait]
279impl EvaluateMatch for Ip6 {
280 async fn evaluate_match(
281 &self,
282 query: &mut Query<'_>,
283 _: &Resolver<'_>,
284 ) -> EvalResult<MatchResult> {
285 Ok(if is_in_network(self.addr, self.prefix_len, query.params.ip()) {
286 MatchResult::Match
287 } else {
288 MatchResult::NoMatch
289 })
290 }
291}
292
293fn is_in_network<A, L>(network_addr: A, prefix_len: Option<L>, ip: IpAddr) -> bool
294where
295 A: Into<IpAddr>,
296 L: Into<DualCidrLength>,
297{
298 match (network_addr.into(), ip) {
299 (IpAddr::V4(network_addr), IpAddr::V4(ip)) => {
300 match prefix_len.and_then(|l| l.into().ip4()) {
301 None => network_addr == ip,
304 Some(len) => {
308 let mask = u32::MAX << (32 - len.get());
309 (u32::from(network_addr) & mask) == (u32::from(ip) & mask)
310 }
311 }
312 }
313 (IpAddr::V6(network_addr), IpAddr::V6(ip)) => {
314 match prefix_len.and_then(|l| l.into().ip6()) {
315 None => network_addr == ip,
316 Some(len) => {
317 let mask = u128::MAX << (128 - len.get());
318 (u128::from(network_addr) & mask) == (u128::from(ip) & mask)
319 }
320 }
321 }
322 _ => false,
323 }
324}
325
326#[async_trait]
327impl EvaluateMatch for Exists {
328 async fn evaluate_match(
329 &self,
330 query: &mut Query<'_>,
331 resolver: &Resolver<'_>,
332 ) -> EvalResult<MatchResult> {
333 increment_lookup_count(query)?;
334
335 let target_name = get_target_name(&self.domain_spec, query, resolver).await?;
336 trace!(query, Tracepoint::TargetName(target_name.clone()));
337
338 let addrs = to_eval_result(resolver.lookup_a(query, &target_name).await)?;
341 increment_void_lookup_count_if_void(query, addrs.len())?;
342
343 Ok(if addrs.is_empty() {
345 MatchResult::NoMatch
346 } else {
347 MatchResult::Match
348 })
349 }
350}
351
352#[async_trait]
353impl Evaluate for Redirect {
354 async fn evaluate(&self, query: &mut Query<'_>, resolver: &Resolver<'_>) -> SpfResult {
355 trace!(query, Tracepoint::EvaluateRedirect(self.clone()));
356
357 if let Err(e) = increment_lookup_count(query) {
358 trace!(query, Tracepoint::RedirectLookupLimitExceeded);
359 query.result_cause = e.to_error_cause().map(From::from);
360 return e.to_spf_result();
361 }
362
363 let target_name = match get_target_name(&self.domain_spec, query, resolver).await {
366 Ok(n) => n,
367 Err(e) => {
368 trace!(query, Tracepoint::InvalidRedirectTargetName);
369 query.result_cause = e.to_error_cause().map(From::from);
370 return e.to_spf_result();
371 }
372 };
373 trace!(query, Tracepoint::TargetName(target_name.clone()));
374
375 let result = execute_recursive_query(query, resolver, target_name, false).await;
376
377 match result {
382 SpfResult::None => {
383 trace!(query, Tracepoint::RedirectNoSpfRecord);
384 query.result_cause = Some(ErrorCause::NoSpfRecord.into());
385 SpfResult::Permerror
386 }
387 result => result,
388 }
389 }
390}
391
392async fn execute_recursive_query(
393 query: &mut Query<'_>,
394 resolver: &Resolver<'_>,
395 target_name: Name,
396 included: bool,
397) -> SpfResult {
398 let prev_name = query.params.replace_domain(target_name);
402 let prev_included = query.state.is_included_query();
403 query.state.set_included_query(prev_included || included);
404
405 let result = query.execute(resolver).await;
406
407 query.params.replace_domain(prev_name);
408 query.state.set_included_query(prev_included);
409
410 result
411}
412
413#[async_trait]
414impl EvaluateToString for Explanation {
415 async fn evaluate_to_string(
416 &self,
417 query: &mut Query<'_>,
418 resolver: &Resolver<'_>,
419 ) -> EvalResult<String> {
420 trace!(query, Tracepoint::EvaluateExplanation(self.clone()));
421
422 let target_name = get_target_name(&self.domain_spec, query, resolver).await?;
423 trace!(query, Tracepoint::TargetName(target_name.clone()));
424
425 let mut explain_string = match lookup_explain_string(resolver, query, &target_name).await {
429 Ok(e) => e,
430 Err(e) => {
431 use ExplainStringLookupError::*;
436 trace!(
437 query,
438 match e {
439 DnsLookup(e) => Tracepoint::ExplainStringLookupError(e),
440 NoExplainString => Tracepoint::NoExplainString,
441 MultipleExplainStrings(s) => Tracepoint::MultipleExplainStrings(s),
442 Syntax(s) => Tracepoint::InvalidExplainStringSyntax(s),
443 }
444 );
445
446 return Err(EvalError::Dns(None));
448 }
449 };
450
451 if let Some(f) = query.config.modify_exp_fn() {
452 trace!(query, Tracepoint::ModifyExplainString(explain_string.clone()));
453 f(&mut explain_string);
454 }
455
456 explain_string.evaluate_to_string(query, resolver).await
457 }
458}
459
460#[derive(Debug)]
461enum ExplainStringLookupError {
462 DnsLookup(LookupError),
463 NoExplainString,
464 MultipleExplainStrings(Vec<String>),
465 Syntax(String),
466}
467
468impl Error for ExplainStringLookupError {}
469
470impl Display for ExplainStringLookupError {
471 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
472 write!(f, "failed to obtain explain string")
473 }
474}
475
476impl From<LookupError> for ExplainStringLookupError {
477 fn from(error: LookupError) -> Self {
478 match error {
479 LookupError::NoRecords => Self::NoExplainString,
480 _ => Self::DnsLookup(error),
481 }
482 }
483}
484
485async fn lookup_explain_string(
486 resolver: &Resolver<'_>,
487 query: &mut Query<'_>,
488 name: &Name,
489) -> Result<ExplainString, ExplainStringLookupError> {
490 let mut exps = resolver.lookup_txt(query, name).await?.into_iter();
491
492 use ExplainStringLookupError::*;
493 match exps.next() {
494 None => Err(NoExplainString),
495 Some(exp) => {
496 let mut rest = exps.collect::<Vec<_>>();
497 match *rest {
498 [] => exp.parse().map_err(|_| Syntax(exp)),
499 [..] => {
500 rest.insert(0, exp);
501 Err(MultipleExplainStrings(rest))
502 }
503 }
504 }
505 }
506}
507
508async fn get_target_name_or_domain(
509 domain_spec: Option<&DomainSpec>,
510 query: &mut Query<'_>,
511 resolver: &Resolver<'_>,
512) -> EvalResult<Name> {
513 match domain_spec {
517 None => Ok(query.params.domain().clone()),
518 Some(domain_spec) => get_target_name(domain_spec, query, resolver).await,
519 }
520}
521
522async fn get_target_name(
523 domain_spec: &DomainSpec,
524 query: &mut Query<'_>,
525 resolver: &Resolver<'_>,
526) -> EvalResult<Name> {
527 let mut name = domain_spec.evaluate_to_string(query, resolver).await?;
531 truncate_target_name_string(&mut name, lookup::MAX_DOMAIN_LENGTH);
532 Name::new(&name).map_err(|_| EvalError::InvalidName(name))
533}
534
/// Shortens a domain-name string to at most `max` bytes by removing whole
/// labels from the left.
///
/// A single trailing dot is always removed first. If the string is still
/// longer than `max`, everything up to and including the leftmost dot whose
/// remaining suffix fits within `max` is removed. If no such dot exists (for
/// example, the last label alone is too long), the string is left unchanged.
fn truncate_target_name_string(s: &mut String, max: usize) {
    if let Some(kept) = s.strip_suffix('.').map(str::len) {
        s.truncate(kept);
    }

    let total = s.len();
    if total <= max {
        return;
    }

    // Dots are visited right to left; the last one still fitting is the
    // leftmost admissible cut point, which keeps the longest valid suffix.
    let cut = s
        .rmatch_indices('.')
        .map(|(pos, _)| pos)
        .take_while(|&pos| total - pos - 1 <= max)
        .last();

    if let Some(pos) = cut {
        s.replace_range(..=pos, "");
    }
}
555
556fn increment_lookup_count(query: &mut Query) -> EvalResult<()> {
561 trace!(query, Tracepoint::IncrementLookupCount);
562 query.state.increment_lookup_count(query.config.max_lookups())
563}
564
565pub fn increment_void_lookup_count_if_void(query: &mut Query, count: usize) -> EvalResult<()> {
570 if count == 0 {
571 trace!(query, Tracepoint::IncrementVoidLookupCount);
572 query.state.increment_void_lookup_count(query.config.max_void_lookups())
573 } else {
574 Ok(())
575 }
576}
577
578fn increment_per_mechanism_lookup_count(query: &mut Query, i: &mut usize) -> EvalResult<()> {
579 trace!(query, Tracepoint::IncrementPerMechanismLookupCount);
580 if *i < query.config.max_lookups() {
581 *i += 1;
582 Ok(())
583 } else {
584 Err(EvalError::PerMechanismLookupLimitExceeded)
585 }
586}
587
588pub fn to_eval_result<T>(result: LookupResult<Vec<T>>) -> EvalResult<Vec<T>> {
589 match result {
590 Ok(r) => Ok(r),
591 Err(e) => {
592 match e {
593 LookupError::Timeout => Err(EvalError::Timeout),
594 LookupError::NoRecords => Ok(Vec::new()),
598 LookupError::Dns(e) => Err(EvalError::Dns(e)),
599 }
600 }
601 }
602}
603
#[cfg(test)]
mod tests {
    use super::*;
    use crate::record::Ip4CidrLength;

    #[test]
    fn is_in_network_ok() {
        let network = IpAddr::from([123, 12, 12, 12]);
        let client = IpAddr::from([123, 12, 12, 98]);
        let prefix = Ip4CidrLength::new(24).unwrap();

        assert!(is_in_network(network, Some(prefix), client));
    }

    #[test]
    fn truncate_target_name_string_ok() {
        // (input, max length, expected output)
        let cases: &[(&str, usize, &str)] = &[
            ("ab.cd.ef", 1, "ab.cd.ef"),
            ("ab.cd.ef.", 1, "ab.cd.ef"),
            ("ab.cd.ef", 2, "ef"),
            ("ab.cd.ef.", 2, "ef"),
            ("ab.cd.ef", 3, "ef"),
            ("ab.cd.ef", 4, "ef"),
            ("ab.cd.ef", 5, "cd.ef"),
            ("ab.cd.ef", 6, "cd.ef"),
            ("ab.cd.ef", 7, "cd.ef"),
            ("ab.cd.ef.", 7, "cd.ef"),
            ("ab.cd.ef", 8, "ab.cd.ef"),
            ("ab.cd.ef.", 8, "ab.cd.ef"),
            ("ab.cd.ef", 9, "ab.cd.ef"),
        ];

        for &(input, max, expected) in cases {
            let mut s = input.to_owned();
            truncate_target_name_string(&mut s, max);
            assert_eq!(s, expected, "input {:?}, max {}", input, max);
        }
    }
}