1use std::net::IpAddr;
2
3use crate::common::dns::DnsResolver;
4use crate::common::domain;
5use crate::dkim::{DkimResult, DkimVerifier};
6use crate::dmarc::{DmarcEvaluator, DmarcResult};
7use crate::spf::{self, SpfResult};
8
9#[derive(Debug)]
11pub struct AuthenticationResult {
12 pub spf: SpfResult,
13 pub dkim: Vec<DkimResult>,
14 pub dmarc: DmarcResult,
15 pub from_domain: String,
16 pub spf_domain: String,
17}
18
/// Runs SPF, DKIM and DMARC checks for inbound messages.
///
/// The resolver type is only constrained (`R: DnsResolver`) on the `impl`
/// that needs it; the data structure itself carries no trait bound.
pub struct EmailAuthenticator<R> {
    /// DNS resolver shared by the SPF, DKIM and DMARC lookups.
    resolver: R,
    /// Clock-skew tolerance forwarded to `DkimVerifier::clock_skew`
    /// (defaults to 300 in `new`; presumably seconds — confirm against the
    /// DKIM module).
    clock_skew: u64,
    /// Receiving host name passed to `spf::check_host`.
    receiver: String,
}
25
26impl<R: DnsResolver> EmailAuthenticator<R> {
27 pub fn new(resolver: R, receiver: impl Into<String>) -> Self {
28 Self {
29 resolver,
30 clock_skew: 300,
31 receiver: receiver.into(),
32 }
33 }
34
35 pub fn with_clock_skew(mut self, skew: u64) -> Self {
36 self.clock_skew = skew;
37 self
38 }
39
40 pub async fn authenticate(
47 &self,
48 message: &[u8],
49 client_ip: IpAddr,
50 helo: &str,
51 mail_from: &str,
52 ) -> Result<AuthenticationResult, AuthError> {
53 let (headers, body) = split_message(message);
55 let parsed_headers = parse_headers(&headers);
56
57 let from_domain = extract_from_domain(&parsed_headers)
59 .ok_or(AuthError::NoFromDomain)?;
60
61 let spf_domain = if mail_from.is_empty() || !mail_from.contains('@') {
63 helo.to_string()
64 } else {
65 domain::domain_from_email(mail_from)
66 .unwrap_or(helo)
67 .to_string()
68 };
69
70 let spf_result = spf::check_host(
72 &self.resolver,
73 client_ip,
74 helo,
75 mail_from,
76 &spf_domain,
77 &self.receiver,
78 )
79 .await;
80
81 let header_pairs: Vec<(&str, &str)> = parsed_headers
83 .iter()
84 .map(|(n, v)| (n.as_str(), v.as_str()))
85 .collect();
86 let dkim_verifier = DkimVerifier::new(&self.resolver)
87 .clock_skew(self.clock_skew);
88 let dkim_results = dkim_verifier.verify_message(&header_pairs, body).await;
89
90 let dmarc_evaluator = DmarcEvaluator::new(&self.resolver);
92 let dmarc_result = dmarc_evaluator
93 .evaluate(&from_domain, &spf_result, &spf_domain, &dkim_results)
94 .await;
95
96 Ok(AuthenticationResult {
97 spf: spf_result,
98 dkim: dkim_results,
99 dmarc: dmarc_result,
100 from_domain,
101 spf_domain,
102 })
103 }
104}
105
/// Errors returned by [`EmailAuthenticator::authenticate`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AuthError {
    /// The message has no `From:` header, or no address/domain could be
    /// extracted from it.
    NoFromDomain,
}

impl std::fmt::Display for AuthError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let text = match self {
            Self::NoFromDomain => "no From domain found in message",
        };
        f.write_str(text)
    }
}

impl std::error::Error for AuthError {}
122
/// Splits a raw message into `(header_section, body)` at the first blank line.
///
/// The blank line may be CRLF CRLF or bare LF LF; whichever occurs first wins.
/// A bare-LF message whose body happens to contain CRLF CRLF is therefore
/// still split at its real header/body boundary.
///
/// The separator belongs to neither half. If there is no blank line, the
/// whole input is treated as headers and the body is empty.
fn split_message(message: &[u8]) -> (&[u8], &[u8]) {
    let crlf = find_bytes(message, b"\r\n\r\n");

    // A bare-LF separator only wins if it starts before the CRLF one, so it is
    // enough to search the prefix up to that point. No "\n\n" can straddle the
    // CRLF position, because the byte there is '\r'.
    let search_end = crlf.unwrap_or(message.len());
    if let Some(pos) = find_bytes(&message[..search_end], b"\n\n") {
        return (&message[..pos], &message[pos + 2..]);
    }
    if let Some(pos) = crlf {
        return (&message[..pos], &message[pos + 4..]);
    }
    (message, b"")
}

/// Returns the index of the first occurrence of `needle` in `haystack`.
///
/// An empty needle matches at index 0, mirroring `str::find("")`, instead of
/// panicking inside `slice::windows(0)`.
fn find_bytes(haystack: &[u8], needle: &[u8]) -> Option<usize> {
    if needle.is_empty() {
        return Some(0);
    }
    haystack
        .windows(needle.len())
        .position(|window| window == needle)
}
146
/// Parses a raw header section into `(name, value)` pairs, in message order.
///
/// Line handling:
/// - Each line may end in CRLF or a bare LF, and the two may be mixed.
/// - A line starting with SP or HTAB continues the previous field. It is
///   appended with a CRLF separator (the form `unfold` recognises).
/// - A continuation line that precedes any field is dropped.
/// - A logical line without a `:` is dropped.
///
/// The name is everything before the first `:` and the value everything after
/// it, both untrimmed.
///
/// NOTE(review): the bytes are decoded lossily (invalid UTF-8 becomes U+FFFD),
/// and folds are re-joined with CRLF even when the input used bare LF. If the
/// DKIM verifier needs byte-exact header text (e.g. for "simple"
/// canonicalization), confirm this normalization is acceptable.
fn parse_headers(header_bytes: &[u8]) -> Vec<(String, String)> {
    let text = String::from_utf8_lossy(header_bytes);
    let mut headers = Vec::new();
    // The logical (possibly folded) field currently being accumulated.
    let mut current = String::new();

    // `str::lines` splits each line on its own "\n" or "\r\n". This keeps an
    // earlier bare-LF line from being merged into the next CRLF-terminated one.
    for line in text.lines() {
        if line.starts_with(|c: char| c == ' ' || c == '\t') {
            if !current.is_empty() {
                current.push_str("\r\n");
                current.push_str(line);
            }
        } else {
            // A new field begins: emit the previous one (an empty or colon-less
            // line yields nothing).
            headers.extend(split_header(&current));
            current.clear();
            current.push_str(line);
        }
    }
    headers.extend(split_header(&current));

    headers
}

/// Splits `name:value` at the first colon; returns `None` when there is no colon.
fn split_header(line: &str) -> Option<(String, String)> {
    let (name, value) = line.split_once(':')?;
    Some((name.to_string(), value.to_string()))
}
209
210fn extract_from_domain(headers: &[(String, String)]) -> Option<String> {
213 let from_value = headers
215 .iter()
216 .find(|(name, _)| name.eq_ignore_ascii_case("from"))?;
217
218 let value = &from_value.1;
219
220 let unfolded = unfold(value);
222
223 let stripped = strip_comments(&unfolded);
225
226 let email = extract_email_address(&stripped)?;
228
229 let domain = domain::domain_from_email(&email)?;
231 Some(domain.to_lowercase())
232}
233
/// Unfolds a header value.
///
/// Every CRLF immediately followed by a space or tab is replaced, together
/// with that single whitespace character, by one space. A CRLF not followed by
/// whitespace, and any lone CR or LF, is left untouched.
fn unfold(value: &str) -> String {
    // The two patterns never overlap, and rewriting the first cannot create an
    // occurrence of the second. Two passes therefore equal one left-to-right scan.
    value.replace("\r\n ", " ").replace("\r\n\t", " ")
}
258
/// Removes RFC 5322 comments (parenthesised text, possibly nested) from a
/// header value.
///
/// Rules:
/// - A backslash escapes the following character. Outside comments both the
///   backslash and the escaped character are kept; inside comments both are
///   dropped.
/// - Text inside a quoted string is kept verbatim. Parentheses there are not
///   comments, so a display name such as `"Smith (Work" <j@example.com>`
///   cannot swallow the address that follows.
/// - An unmatched `)` outside any comment is kept as an ordinary character.
fn strip_comments(value: &str) -> String {
    let mut result = String::with_capacity(value.len());
    let mut depth = 0u32;
    let mut in_quotes = false;
    let mut escaped = false;

    for c in value.chars() {
        if escaped {
            escaped = false;
            if depth == 0 {
                result.push(c);
            }
            continue;
        }

        match c {
            '\\' => {
                escaped = true;
                if depth == 0 {
                    result.push(c);
                }
            }
            // Quotes only delimit a quoted-string outside comments.
            '"' if depth == 0 => {
                in_quotes = !in_quotes;
                result.push(c);
            }
            // Inside a quoted-string (always at depth 0) everything is literal.
            _ if in_quotes => result.push(c),
            '(' => depth = depth.saturating_add(1),
            ')' if depth > 0 => depth -= 1,
            _ if depth == 0 => result.push(c),
            _ => {}
        }
    }

    result
}
294
/// Extracts the mailbox address from a (comment-free) address header value.
///
/// The text between the last `<` and the first `>` after it is used, after
/// trimming, provided it contains an `@`. Otherwise the whole trimmed value is
/// returned if it contains an `@`; failing both, the result is `None`.
fn extract_email_address(value: &str) -> Option<String> {
    let value = value.trim();

    let bracketed = value
        .rsplit_once('<')
        .and_then(|(_, after_open)| after_open.split_once('>'))
        .map(|(inner, _)| inner.trim())
        .filter(|inner| inner.contains('@'));

    bracketed
        .or_else(|| Some(value).filter(|whole| whole.contains('@')))
        .map(String::from)
}
319
#[cfg(test)]
mod tests {
    use super::*;
    use crate::common::dns::{DnsError, MxRecord};
    use std::collections::HashMap;
    use std::net::{Ipv4Addr, Ipv6Addr};

    /// Receiving host name used by every end-to-end test.
    const RECEIVER: &str = "mx.receiver.com";
    /// HELO name used by every end-to-end test.
    const HELO: &str = "mail.example.com";

    /// In-memory resolver with the following behavior:
    /// - TXT and A answers come from maps.
    /// - `query_exists` checks the A map.
    /// - Every other query type reports NXDOMAIN.
    #[derive(Clone)]
    struct MockResolver {
        txt: HashMap<String, Vec<String>>,
        a: HashMap<String, Vec<Ipv4Addr>>,
    }

    impl MockResolver {
        fn new() -> Self {
            Self {
                txt: HashMap::new(),
                a: HashMap::new(),
            }
        }

        /// Builder helper: sets the TXT answer for `name` to the single `record`.
        fn with_txt(mut self, name: &str, record: &str) -> Self {
            self.txt.insert(name.to_string(), vec![record.to_string()]);
            self
        }
    }

    impl DnsResolver for MockResolver {
        async fn query_txt(&self, name: &str) -> Result<Vec<String>, DnsError> {
            self.txt.get(name).cloned().ok_or(DnsError::NxDomain)
        }
        async fn query_a(&self, name: &str) -> Result<Vec<Ipv4Addr>, DnsError> {
            self.a.get(name).cloned().ok_or(DnsError::NxDomain)
        }
        async fn query_aaaa(&self, _name: &str) -> Result<Vec<Ipv6Addr>, DnsError> {
            Err(DnsError::NxDomain)
        }
        async fn query_mx(&self, _name: &str) -> Result<Vec<MxRecord>, DnsError> {
            Err(DnsError::NxDomain)
        }
        async fn query_ptr(&self, _ip: &IpAddr) -> Result<Vec<String>, DnsError> {
            Err(DnsError::NxDomain)
        }
        async fn query_exists(&self, name: &str) -> Result<bool, DnsError> {
            Ok(self.a.contains_key(name))
        }
    }

    /// Runs the whole pipeline with the shared receiver and HELO names.
    async fn run(
        resolver: MockResolver,
        message: &[u8],
        client_ip: &str,
        mail_from: &str,
    ) -> Result<AuthenticationResult, AuthError> {
        let ip: IpAddr = client_ip.parse().unwrap();
        let auth = EmailAuthenticator::new(resolver, RECEIVER);
        auth.authenticate(message, ip, HELO, mail_from).await
    }

    /// Runs `extract_from_domain` over a single header.
    fn domain_of(name: &str, value: &str) -> Option<String> {
        extract_from_domain(&[(name.to_string(), value.to_string())])
    }

    #[test]
    fn split_message_crlf() {
        let (headers, body) =
            split_message(b"From: test@example.com\r\nSubject: hi\r\n\r\nBody here");
        assert_eq!(headers, b"From: test@example.com\r\nSubject: hi");
        assert_eq!(body, b"Body here");
    }

    #[test]
    fn split_message_lf_fallback() {
        let (headers, body) = split_message(b"From: test@example.com\nSubject: hi\n\nBody here");
        assert_eq!(headers, b"From: test@example.com\nSubject: hi");
        assert_eq!(body, b"Body here");
    }

    #[test]
    fn split_message_no_body() {
        let (headers, body) = split_message(b"From: test@example.com\r\nSubject: hi");
        assert_eq!(headers, b"From: test@example.com\r\nSubject: hi");
        assert_eq!(body, b"");
    }

    #[test]
    fn parse_headers_simple() {
        let headers = parse_headers(b"From: alice@example.com\r\nSubject: Hello");
        let pairs: Vec<(&str, &str)> = headers
            .iter()
            .map(|(name, value)| (name.as_str(), value.as_str()))
            .collect();
        assert_eq!(pairs, [("From", " alice@example.com"), ("Subject", " Hello")]);
    }

    #[test]
    fn parse_headers_folded() {
        let headers =
            parse_headers(b"Subject: This is a long\r\n subject line\r\nFrom: test@example.com");
        assert_eq!(headers.len(), 2);
        let (subject_name, subject_value) = &headers[0];
        assert_eq!(subject_name, "Subject");
        assert!(subject_value.contains("long"));
        assert!(subject_value.contains("subject line"));
        assert_eq!(headers[1].0, "From");
    }

    #[test]
    fn parse_headers_bare_lf() {
        assert_eq!(parse_headers(b"From: alice@example.com\nSubject: Hello").len(), 2);
    }

    #[test]
    fn extract_from_angle_brackets() {
        assert_eq!(
            domain_of("From", " \"John Smith\" <john@example.com>").as_deref(),
            Some("example.com")
        );
    }

    #[test]
    fn extract_from_bare_address() {
        assert_eq!(
            domain_of("From", " alice@example.org").as_deref(),
            Some("example.org")
        );
    }

    #[test]
    fn extract_from_with_comment() {
        assert_eq!(
            domain_of("From", " alice(comment)@example.com").as_deref(),
            Some("example.com")
        );
    }

    #[test]
    fn extract_from_nested_comments() {
        assert_eq!(
            domain_of("From", " alice(nested (comment))@example.com").as_deref(),
            Some("example.com")
        );
    }

    #[test]
    fn extract_from_display_name_with_comma() {
        // A comma inside a quoted display name must not split the mailbox list.
        assert_eq!(
            domain_of("From", " \"Smith, John\" <j@example.com>").as_deref(),
            Some("example.com")
        );
    }

    #[test]
    fn extract_from_case_insensitive() {
        // Both the field name and the resulting domain are case-normalised.
        assert_eq!(
            domain_of("from", " alice@EXAMPLE.COM").as_deref(),
            Some("example.com")
        );
    }

    #[test]
    fn extract_from_missing() {
        assert_eq!(domain_of("Subject", " Hello"), None);
    }

    #[test]
    fn extract_from_no_at() {
        assert_eq!(domain_of("From", " invalid-address"), None);
    }

    #[test]
    fn strip_comments_basic() {
        assert_eq!(strip_comments("alice(test)@example.com"), "alice@example.com");
    }

    #[test]
    fn strip_comments_nested() {
        assert_eq!(
            strip_comments("alice(a (b) c)@example.com"),
            "alice@example.com"
        );
    }

    #[test]
    fn strip_comments_escaped_paren() {
        assert_eq!(strip_comments("alice(\\))@example.com"), "alice@example.com");
    }

    #[test]
    fn unfold_crlf_sp() {
        assert_eq!(unfold("hello\r\n world"), "hello world");
    }

    #[test]
    fn unfold_crlf_tab() {
        assert_eq!(unfold("hello\r\n\tworld"), "hello world");
    }

    #[test]
    fn unfold_no_fold() {
        assert_eq!(unfold("hello world"), "hello world");
    }

    #[tokio::test]
    async fn authenticate_full_pipeline() {
        let resolver = MockResolver::new()
            .with_txt("example.com", "v=spf1 ip4:192.0.2.1 -all")
            .with_txt("_dmarc.example.com", "v=DMARC1; p=reject; adkim=r; aspf=r");
        let message = b"From: sender@example.com\r\nSubject: Test\r\n\r\nBody";

        let result = run(resolver, message, "192.0.2.1", "sender@example.com")
            .await
            .unwrap();

        assert_eq!(result.from_domain, "example.com");
        assert_eq!(result.spf_domain, "example.com");
        assert!(matches!(result.spf, SpfResult::Pass));
        // The message carries no DKIM-Signature; exactly one `None` result is expected.
        assert!(matches!(result.dkim.as_slice(), [DkimResult::None]));
    }

    #[tokio::test]
    async fn authenticate_no_from_header() {
        let result = run(
            MockResolver::new(),
            b"Subject: No from\r\n\r\nBody",
            "192.0.2.1",
            "sender@example.com",
        )
        .await;
        assert!(matches!(result, Err(AuthError::NoFromDomain)));
    }

    #[tokio::test]
    async fn authenticate_empty_mail_from_uses_helo() {
        let resolver = MockResolver::new()
            .with_txt("mail.example.com", "v=spf1 ip4:192.0.2.1 -all")
            .with_txt("_dmarc.example.com", "v=DMARC1; p=none");

        // Null reverse path: SPF must fall back to the HELO identity.
        let result = run(resolver, b"From: sender@example.com\r\n\r\nBody", "192.0.2.1", "")
            .await
            .unwrap();

        assert_eq!(result.spf_domain, "mail.example.com");
    }

    #[tokio::test]
    async fn authenticate_spf_fail_dmarc_reject() {
        let resolver = MockResolver::new()
            .with_txt("example.com", "v=spf1 ip4:10.0.0.1 -all")
            .with_txt("_dmarc.example.com", "v=DMARC1; p=reject");

        // 192.0.2.99 is not authorised by the SPF record above.
        let result = run(
            resolver,
            b"From: sender@example.com\r\n\r\nBody",
            "192.0.2.99",
            "sender@example.com",
        )
        .await
        .unwrap();

        assert!(matches!(result.spf, SpfResult::Fail { .. }));
        assert!(matches!(
            result.dmarc.disposition,
            crate::dmarc::Disposition::Reject
        ));
    }

    #[tokio::test]
    async fn authenticate_folded_from_header() {
        let resolver = MockResolver::new()
            .with_txt("example.com", "v=spf1 ip4:192.0.2.1 -all")
            .with_txt("_dmarc.example.com", "v=DMARC1; p=none");
        let message =
            b"From: \"Very Long\r\n Display Name\" <sender@example.com>\r\nSubject: Test\r\n\r\nBody";

        let result = run(resolver, message, "192.0.2.1", "sender@example.com")
            .await
            .unwrap();

        assert_eq!(result.from_domain, "example.com");
    }

    #[tokio::test]
    async fn authenticate_from_with_rfc5322_comment() {
        let resolver = MockResolver::new()
            .with_txt("example.com", "v=spf1 ip4:192.0.2.1 -all")
            .with_txt("_dmarc.example.com", "v=DMARC1; p=none");

        let result = run(
            resolver,
            b"From: sender(comment)@example.com\r\n\r\nBody",
            "192.0.2.1",
            "sender@example.com",
        )
        .await
        .unwrap();

        assert_eq!(result.from_domain, "example.com");
    }
}