Skip to main content

email_auth/
auth.rs

1use std::net::IpAddr;
2
3use crate::common::dns::DnsResolver;
4use crate::common::domain;
5use crate::dkim::{DkimResult, DkimVerifier};
6use crate::dmarc::{DmarcEvaluator, DmarcResult};
7use crate::spf::{self, SpfResult};
8
9/// Combined authentication result from SPF, DKIM, and DMARC.
10#[derive(Debug)]
11pub struct AuthenticationResult {
12    pub spf: SpfResult,
13    pub dkim: Vec<DkimResult>,
14    pub dmarc: DmarcResult,
15    pub from_domain: String,
16    pub spf_domain: String,
17}
18
19/// Combined email authenticator running SPF, DKIM, and DMARC in sequence.
20pub struct EmailAuthenticator<R: DnsResolver> {
21    resolver: R,
22    clock_skew: u64,
23    receiver: String,
24}
25
26impl<R: DnsResolver> EmailAuthenticator<R> {
27    pub fn new(resolver: R, receiver: impl Into<String>) -> Self {
28        Self {
29            resolver,
30            clock_skew: 300,
31            receiver: receiver.into(),
32        }
33    }
34
35    pub fn with_clock_skew(mut self, skew: u64) -> Self {
36        self.clock_skew = skew;
37        self
38    }
39
40    /// Authenticate a raw RFC 5322 message.
41    ///
42    /// - `message`: raw RFC 5322 bytes
43    /// - `client_ip`: connecting client IP
44    /// - `helo`: EHLO/HELO identity
45    /// - `mail_from`: MAIL FROM address (envelope sender)
46    pub async fn authenticate(
47        &self,
48        message: &[u8],
49        client_ip: IpAddr,
50        helo: &str,
51        mail_from: &str,
52    ) -> Result<AuthenticationResult, AuthError> {
53        // 1. Parse message into headers + body
54        let (headers, body) = split_message(message);
55        let parsed_headers = parse_headers(&headers);
56
57        // 2. Extract From domain
58        let from_domain = extract_from_domain(&parsed_headers)
59            .ok_or(AuthError::NoFromDomain)?;
60
61        // 3. Determine SPF domain (MAIL FROM domain, or HELO if empty)
62        let spf_domain = if mail_from.is_empty() || !mail_from.contains('@') {
63            helo.to_string()
64        } else {
65            domain::domain_from_email(mail_from)
66                .unwrap_or(helo)
67                .to_string()
68        };
69
70        // 4. Run SPF
71        let spf_result = spf::check_host(
72            &self.resolver,
73            client_ip,
74            helo,
75            mail_from,
76            &spf_domain,
77            &self.receiver,
78        )
79        .await;
80
81        // 5. Run DKIM
82        let header_pairs: Vec<(&str, &str)> = parsed_headers
83            .iter()
84            .map(|(n, v)| (n.as_str(), v.as_str()))
85            .collect();
86        let dkim_verifier = DkimVerifier::new(&self.resolver)
87            .clock_skew(self.clock_skew);
88        let dkim_results = dkim_verifier.verify_message(&header_pairs, body).await;
89
90        // 6. Run DMARC
91        let dmarc_evaluator = DmarcEvaluator::new(&self.resolver);
92        let dmarc_result = dmarc_evaluator
93            .evaluate(&from_domain, &spf_result, &spf_domain, &dkim_results)
94            .await;
95
96        Ok(AuthenticationResult {
97            spf: spf_result,
98            dkim: dkim_results,
99            dmarc: dmarc_result,
100            from_domain,
101            spf_domain,
102        })
103    }
104}
105
106/// Error from authentication pipeline.
107#[derive(Debug, Clone, PartialEq, Eq)]
108pub enum AuthError {
109    /// No RFC 5322 From header with a valid domain found.
110    NoFromDomain,
111}
112
113impl std::fmt::Display for AuthError {
114    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
115        match self {
116            AuthError::NoFromDomain => write!(f, "no From domain found in message"),
117        }
118    }
119}
120
121impl std::error::Error for AuthError {}
122
123// --- Message parsing ---
124
125/// Split raw message bytes into (headers_bytes, body_bytes).
126/// Splits at `\r\n\r\n` (preferred) or `\n\n` (fallback).
127fn split_message(message: &[u8]) -> (&[u8], &[u8]) {
128    // Look for \r\n\r\n first
129    if let Some(pos) = find_bytes(message, b"\r\n\r\n") {
130        return (&message[..pos], &message[pos + 4..]);
131    }
132    // Fallback: \n\n
133    if let Some(pos) = find_bytes(message, b"\n\n") {
134        return (&message[..pos], &message[pos + 2..]);
135    }
136    // No body separator — entire message is headers
137    (message, b"")
138}
139
140/// Find a byte pattern in a byte slice.
141fn find_bytes(haystack: &[u8], needle: &[u8]) -> Option<usize> {
142    haystack
143        .windows(needle.len())
144        .position(|w| w == needle)
145}
146
147/// Parse raw header bytes into (name, value) pairs.
148/// Handles folded headers (lines starting with SP/HTAB are continuations).
149fn parse_headers(header_bytes: &[u8]) -> Vec<(String, String)> {
150    // Convert to string, handling non-UTF-8 gracefully at the byte level.
151    // Headers are ASCII per RFC 5322 — parse as bytes, find colons.
152    let mut headers: Vec<(String, String)> = Vec::new();
153
154    // Split into lines, preserving CRLF semantics
155    let text = String::from_utf8_lossy(header_bytes);
156    let mut lines: Vec<&str> = Vec::new();
157
158    // Split on \r\n first, then on bare \n for remaining
159    let mut remaining = text.as_ref();
160    while !remaining.is_empty() {
161        if let Some(pos) = remaining.find("\r\n") {
162            lines.push(&remaining[..pos]);
163            remaining = &remaining[pos + 2..];
164        } else if let Some(pos) = remaining.find('\n') {
165            lines.push(&remaining[..pos]);
166            remaining = &remaining[pos + 1..];
167        } else {
168            lines.push(remaining);
169            break;
170        }
171    }
172
173    // Group folded lines
174    let mut current_line = String::new();
175    for line in &lines {
176        if line.starts_with(' ') || line.starts_with('\t') {
177            // Continuation — append with original folding whitespace
178            if !current_line.is_empty() {
179                current_line.push_str("\r\n");
180                current_line.push_str(line);
181            }
182        } else {
183            // New header — flush previous
184            if !current_line.is_empty() {
185                if let Some((name, value)) = split_header(&current_line) {
186                    headers.push((name, value));
187                }
188            }
189            current_line = line.to_string();
190        }
191    }
192    // Flush last header
193    if !current_line.is_empty() {
194        if let Some((name, value)) = split_header(&current_line) {
195            headers.push((name, value));
196        }
197    }
198
199    headers
200}
201
202/// Split a header line into (name, value) at the first colon.
203fn split_header(line: &str) -> Option<(String, String)> {
204    let pos = line.find(':')?;
205    let name = line[..pos].to_string();
206    let value = line[pos + 1..].to_string();
207    Some((name, value))
208}
209
210/// Extract the From domain from parsed headers.
211/// Handles RFC 5322 comments, angle brackets, and display names.
212fn extract_from_domain(headers: &[(String, String)]) -> Option<String> {
213    // Find the From header
214    let from_value = headers
215        .iter()
216        .find(|(name, _)| name.eq_ignore_ascii_case("from"))?;
217
218    let value = &from_value.1;
219
220    // Unfold (replace \r\n + WSP with single space)
221    let unfolded = unfold(value);
222
223    // Strip RFC 5322 comments (parenthesized, possibly nested)
224    let stripped = strip_comments(&unfolded);
225
226    // Extract email address
227    let email = extract_email_address(&stripped)?;
228
229    // Get domain from email
230    let domain = domain::domain_from_email(&email)?;
231    Some(domain.to_lowercase())
232}
233
234/// Unfold header value: replace \r\n followed by SP/HTAB with single space.
235fn unfold(value: &str) -> String {
236    let mut result = String::with_capacity(value.len());
237    let mut chars = value.chars().peekable();
238    while let Some(c) = chars.next() {
239        if c == '\r' {
240            if chars.peek() == Some(&'\n') {
241                chars.next(); // consume \n
242                if matches!(chars.peek(), Some(' ' | '\t')) {
243                    result.push(' ');
244                    chars.next(); // consume the WSP
245                } else {
246                    result.push('\r');
247                    result.push('\n');
248                }
249            } else {
250                result.push(c);
251            }
252        } else {
253            result.push(c);
254        }
255    }
256    result
257}
258
259/// Strip RFC 5322 comments (parenthesized text with nesting support).
260fn strip_comments(value: &str) -> String {
261    let mut result = String::with_capacity(value.len());
262    let mut depth = 0u32;
263    let mut escaped = false;
264
265    for c in value.chars() {
266        if escaped {
267            if depth == 0 {
268                result.push(c);
269            }
270            escaped = false;
271            continue;
272        }
273        if c == '\\' {
274            escaped = true;
275            if depth == 0 {
276                result.push(c);
277            }
278            continue;
279        }
280        if c == '(' {
281            depth += 1;
282            continue;
283        }
284        if c == ')' && depth > 0 {
285            depth -= 1;
286            continue;
287        }
288        if depth == 0 {
289            result.push(c);
290        }
291    }
292    result
293}
294
295/// Extract an email address from a From header value.
296/// Checks for angle brackets first (handles display names, commas in quoted strings).
297/// Falls back to bare address.
298fn extract_email_address(value: &str) -> Option<String> {
299    let trimmed = value.trim();
300
301    // Check for angle-bracket form: ... <addr>
302    if let Some(start) = trimmed.rfind('<') {
303        if let Some(end) = trimmed[start..].find('>') {
304            let addr = trimmed[start + 1..start + end].trim();
305            if !addr.is_empty() && addr.contains('@') {
306                return Some(addr.to_string());
307            }
308        }
309    }
310
311    // Bare address form
312    let addr = trimmed.trim();
313    if addr.contains('@') {
314        return Some(addr.to_string());
315    }
316
317    None
318}
319
320#[cfg(test)]
321mod tests {
322    use super::*;
323    use crate::common::dns::{DnsError, MxRecord};
324    use std::collections::HashMap;
325    use std::net::{Ipv4Addr, Ipv6Addr};
326
327    // --- MockResolver for integration tests ---
328    #[derive(Clone)]
329    struct MockResolver {
330        txt: HashMap<String, Vec<String>>,
331        a: HashMap<String, Vec<Ipv4Addr>>,
332    }
333
334    impl MockResolver {
335        fn new() -> Self {
336            Self {
337                txt: HashMap::new(),
338                a: HashMap::new(),
339            }
340        }
341
342        fn add_txt(&mut self, name: &str, records: Vec<&str>) {
343            self.txt
344                .insert(name.to_string(), records.into_iter().map(String::from).collect());
345        }
346
347    }
348
349    impl DnsResolver for MockResolver {
350        async fn query_txt(&self, name: &str) -> Result<Vec<String>, DnsError> {
351            self.txt
352                .get(name)
353                .cloned()
354                .ok_or(DnsError::NxDomain)
355        }
356        async fn query_a(&self, name: &str) -> Result<Vec<Ipv4Addr>, DnsError> {
357            self.a
358                .get(name)
359                .cloned()
360                .ok_or(DnsError::NxDomain)
361        }
362        async fn query_aaaa(&self, _name: &str) -> Result<Vec<Ipv6Addr>, DnsError> {
363            Err(DnsError::NxDomain)
364        }
365        async fn query_mx(&self, _name: &str) -> Result<Vec<MxRecord>, DnsError> {
366            Err(DnsError::NxDomain)
367        }
368        async fn query_ptr(&self, _ip: &IpAddr) -> Result<Vec<String>, DnsError> {
369            Err(DnsError::NxDomain)
370        }
371        async fn query_exists(&self, name: &str) -> Result<bool, DnsError> {
372            Ok(self.a.contains_key(name))
373        }
374    }
375
376    // --- Message parsing unit tests ---
377
378    #[test]
379    fn split_message_crlf() {
380        let msg = b"From: test@example.com\r\nSubject: hi\r\n\r\nBody here";
381        let (headers, body) = split_message(msg);
382        assert_eq!(headers, b"From: test@example.com\r\nSubject: hi");
383        assert_eq!(body, b"Body here");
384    }
385
386    #[test]
387    fn split_message_lf_fallback() {
388        let msg = b"From: test@example.com\nSubject: hi\n\nBody here";
389        let (headers, body) = split_message(msg);
390        assert_eq!(headers, b"From: test@example.com\nSubject: hi");
391        assert_eq!(body, b"Body here");
392    }
393
394    #[test]
395    fn split_message_no_body() {
396        let msg = b"From: test@example.com\r\nSubject: hi";
397        let (headers, body) = split_message(msg);
398        assert_eq!(headers, b"From: test@example.com\r\nSubject: hi");
399        assert_eq!(body, b"");
400    }
401
402    #[test]
403    fn parse_headers_simple() {
404        let raw = b"From: alice@example.com\r\nSubject: Hello";
405        let headers = parse_headers(raw);
406        assert_eq!(headers.len(), 2);
407        assert_eq!(headers[0].0, "From");
408        assert_eq!(headers[0].1, " alice@example.com");
409        assert_eq!(headers[1].0, "Subject");
410        assert_eq!(headers[1].1, " Hello");
411    }
412
413    #[test]
414    fn parse_headers_folded() {
415        let raw = b"Subject: This is a long\r\n subject line\r\nFrom: test@example.com";
416        let headers = parse_headers(raw);
417        assert_eq!(headers.len(), 2);
418        assert_eq!(headers[0].0, "Subject");
419        assert!(headers[0].1.contains("long"));
420        assert!(headers[0].1.contains("subject line"));
421        assert_eq!(headers[1].0, "From");
422    }
423
424    #[test]
425    fn parse_headers_bare_lf() {
426        let raw = b"From: alice@example.com\nSubject: Hello";
427        let headers = parse_headers(raw);
428        assert_eq!(headers.len(), 2);
429    }
430
431    #[test]
432    fn extract_from_angle_brackets() {
433        let headers = vec![
434            ("From".to_string(), " \"John Smith\" <john@example.com>".to_string()),
435        ];
436        assert_eq!(
437            extract_from_domain(&headers),
438            Some("example.com".to_string())
439        );
440    }
441
442    #[test]
443    fn extract_from_bare_address() {
444        let headers = vec![
445            ("From".to_string(), " alice@example.org".to_string()),
446        ];
447        assert_eq!(
448            extract_from_domain(&headers),
449            Some("example.org".to_string())
450        );
451    }
452
453    #[test]
454    fn extract_from_with_comment() {
455        let headers = vec![
456            ("From".to_string(), " alice(comment)@example.com".to_string()),
457        ];
458        assert_eq!(
459            extract_from_domain(&headers),
460            Some("example.com".to_string())
461        );
462    }
463
464    #[test]
465    fn extract_from_nested_comments() {
466        let headers = vec![
467            ("From".to_string(), " alice(nested (comment))@example.com".to_string()),
468        ];
469        assert_eq!(
470            extract_from_domain(&headers),
471            Some("example.com".to_string())
472        );
473    }
474
475    #[test]
476    fn extract_from_display_name_with_comma() {
477        // "Smith, John" <j@example.com> — comma must not split
478        let headers = vec![
479            ("From".to_string(), " \"Smith, John\" <j@example.com>".to_string()),
480        ];
481        assert_eq!(
482            extract_from_domain(&headers),
483            Some("example.com".to_string())
484        );
485    }
486
487    #[test]
488    fn extract_from_case_insensitive() {
489        let headers = vec![
490            ("from".to_string(), " alice@EXAMPLE.COM".to_string()),
491        ];
492        assert_eq!(
493            extract_from_domain(&headers),
494            Some("example.com".to_string())
495        );
496    }
497
498    #[test]
499    fn extract_from_missing() {
500        let headers = vec![
501            ("Subject".to_string(), " Hello".to_string()),
502        ];
503        assert_eq!(extract_from_domain(&headers), None);
504    }
505
506    #[test]
507    fn extract_from_no_at() {
508        let headers = vec![
509            ("From".to_string(), " invalid-address".to_string()),
510        ];
511        assert_eq!(extract_from_domain(&headers), None);
512    }
513
514    #[test]
515    fn strip_comments_basic() {
516        assert_eq!(strip_comments("alice(test)@example.com"), "alice@example.com");
517    }
518
519    #[test]
520    fn strip_comments_nested() {
521        assert_eq!(
522            strip_comments("alice(a (b) c)@example.com"),
523            "alice@example.com"
524        );
525    }
526
527    #[test]
528    fn strip_comments_escaped_paren() {
529        assert_eq!(strip_comments("alice(\\))@example.com"), "alice@example.com");
530    }
531
532    #[test]
533    fn unfold_crlf_sp() {
534        assert_eq!(unfold("hello\r\n world"), "hello world");
535    }
536
537    #[test]
538    fn unfold_crlf_tab() {
539        assert_eq!(unfold("hello\r\n\tworld"), "hello world");
540    }
541
542    #[test]
543    fn unfold_no_fold() {
544        assert_eq!(unfold("hello world"), "hello world");
545    }
546
547    // --- Integration tests ---
548
549    #[tokio::test]
550    async fn authenticate_full_pipeline() {
551        let mut resolver = MockResolver::new();
552
553        // SPF record for example.com: allow 192.0.2.1
554        resolver.add_txt("example.com", vec!["v=spf1 ip4:192.0.2.1 -all"]);
555
556        // DMARC record
557        resolver.add_txt("_dmarc.example.com", vec!["v=DMARC1; p=reject; adkim=r; aspf=r"]);
558
559        let message = b"From: sender@example.com\r\nSubject: Test\r\n\r\nBody";
560
561        let auth = EmailAuthenticator::new(resolver, "mx.receiver.com");
562        let result = auth
563            .authenticate(
564                message,
565                "192.0.2.1".parse().unwrap(),
566                "mail.example.com",
567                "sender@example.com",
568            )
569            .await
570            .unwrap();
571
572        assert_eq!(result.from_domain, "example.com");
573        assert_eq!(result.spf_domain, "example.com");
574        assert!(matches!(result.spf, SpfResult::Pass));
575        // No DKIM signatures → DkimResult::None
576        assert_eq!(result.dkim.len(), 1);
577        assert!(matches!(result.dkim[0], DkimResult::None));
578    }
579
580    #[tokio::test]
581    async fn authenticate_no_from_header() {
582        let resolver = MockResolver::new();
583        let message = b"Subject: No from\r\n\r\nBody";
584        let auth = EmailAuthenticator::new(resolver, "mx.receiver.com");
585        let result = auth
586            .authenticate(
587                message,
588                "192.0.2.1".parse().unwrap(),
589                "mail.example.com",
590                "sender@example.com",
591            )
592            .await;
593        assert!(matches!(result, Err(AuthError::NoFromDomain)));
594    }
595
596    #[tokio::test]
597    async fn authenticate_empty_mail_from_uses_helo() {
598        let mut resolver = MockResolver::new();
599        resolver.add_txt("mail.example.com", vec!["v=spf1 ip4:192.0.2.1 -all"]);
600        resolver.add_txt("_dmarc.example.com", vec!["v=DMARC1; p=none"]);
601
602        let message = b"From: sender@example.com\r\n\r\nBody";
603
604        let auth = EmailAuthenticator::new(resolver, "mx.receiver.com");
605        let result = auth
606            .authenticate(
607                message,
608                "192.0.2.1".parse().unwrap(),
609                "mail.example.com",
610                "", // empty MAIL FROM
611            )
612            .await
613            .unwrap();
614
615        // SPF domain should be HELO domain
616        assert_eq!(result.spf_domain, "mail.example.com");
617    }
618
619    #[tokio::test]
620    async fn authenticate_spf_fail_dmarc_reject() {
621        let mut resolver = MockResolver::new();
622        // SPF: only allow 10.0.0.1
623        resolver.add_txt("example.com", vec!["v=spf1 ip4:10.0.0.1 -all"]);
624        // DMARC: reject policy
625        resolver.add_txt("_dmarc.example.com", vec!["v=DMARC1; p=reject"]);
626
627        let message = b"From: sender@example.com\r\n\r\nBody";
628
629        let auth = EmailAuthenticator::new(resolver, "mx.receiver.com");
630        let result = auth
631            .authenticate(
632                message,
633                "192.0.2.99".parse().unwrap(), // not authorized
634                "mail.example.com",
635                "sender@example.com",
636            )
637            .await
638            .unwrap();
639
640        assert!(matches!(result.spf, SpfResult::Fail { .. }));
641        // DMARC should apply reject since SPF failed and no DKIM
642        assert!(matches!(
643            result.dmarc.disposition,
644            crate::dmarc::Disposition::Reject
645        ));
646    }
647
648    #[tokio::test]
649    async fn authenticate_folded_from_header() {
650        let mut resolver = MockResolver::new();
651        resolver.add_txt("example.com", vec!["v=spf1 ip4:192.0.2.1 -all"]);
652        resolver.add_txt("_dmarc.example.com", vec!["v=DMARC1; p=none"]);
653
654        // From header is folded across lines
655        let message = b"From: \"Very Long\r\n Display Name\" <sender@example.com>\r\nSubject: Test\r\n\r\nBody";
656
657        let auth = EmailAuthenticator::new(resolver, "mx.receiver.com");
658        let result = auth
659            .authenticate(
660                message,
661                "192.0.2.1".parse().unwrap(),
662                "mail.example.com",
663                "sender@example.com",
664            )
665            .await
666            .unwrap();
667
668        assert_eq!(result.from_domain, "example.com");
669    }
670
671    #[tokio::test]
672    async fn authenticate_from_with_rfc5322_comment() {
673        let mut resolver = MockResolver::new();
674        resolver.add_txt("example.com", vec!["v=spf1 ip4:192.0.2.1 -all"]);
675        resolver.add_txt("_dmarc.example.com", vec!["v=DMARC1; p=none"]);
676
677        let message = b"From: sender(comment)@example.com\r\n\r\nBody";
678
679        let auth = EmailAuthenticator::new(resolver, "mx.receiver.com");
680        let result = auth
681            .authenticate(
682                message,
683                "192.0.2.1".parse().unwrap(),
684                "mail.example.com",
685                "sender@example.com",
686            )
687            .await
688            .unwrap();
689
690        assert_eq!(result.from_domain, "example.com");
691    }
692}