mail_auth/common/
headers.rs

1/*
2 * SPDX-FileCopyrightText: 2020 Stalwart Labs LLC <hello@stalw.art>
3 *
4 * SPDX-License-Identifier: Apache-2.0 OR MIT
5 */
6
7use std::{
8    iter::{Enumerate, Peekable},
9    slice::Iter,
10};
11
12impl<'x, T> Header<'x, T> {
13    pub fn new(name: &'x [u8], value: &'x [u8], header: T) -> Self {
14        Header {
15            name,
16            value,
17            header,
18        }
19    }
20}
21
/// A source of raw header fields followed by a message body.
pub trait HeaderStream<'x> {
    /// Returns the next raw `(name, value)` header pair, or `None` when no more headers are
    /// available.
    fn next_header(&mut self) -> Option<(&'x [u8], &'x [u8])>;
    /// Returns the body bytes. Presumably meant to be called after `next_header` has returned
    /// `None`; the implementations in this file return the bytes from the current read
    /// position onwards.
    fn body(&mut self) -> &'x [u8];
}
26
/// Reads headers from a sequence of byte slices as if they were one message, moving on to
/// the next slice when the current one stops yielding headers.
pub(crate) struct ChainedHeaderIterator<'x, T: Iterator<Item = &'x [u8]>> {
    // Slices that have not been scanned yet.
    parts: T,
    // Scanner over the slice currently being read.
    iter: HeaderIterator<'x>,
}
31
/// Splits the header section of a single message buffer into raw `(name, value)` pairs.
pub(crate) struct HeaderIterator<'x> {
    // The whole buffer; every yielded slice borrows from it.
    message: &'x [u8],
    // Byte cursor over `message`, each byte paired with its offset.
    iter: Peekable<Enumerate<Iter<'x, u8>>>,
    // Offset at which the header currently being scanned starts.
    start_pos: usize,
}
37
/// Header scanner that, in addition to splitting `(name, value)` pairs, classifies
/// authentication-related header names and records a few statistics about what it has seen.
pub(crate) struct HeaderParser<'x> {
    // The whole buffer; every yielded slice borrows from it.
    message: &'x [u8],
    // Byte cursor over `message`, each byte paired with its offset.
    iter: Peekable<Enumerate<Iter<'x, u8>>>,
    // Offset at which the header currently being scanned starts.
    start_pos: usize,
    /// Number of `Received` headers yielded so far.
    pub num_received: usize,
    /// Set once a `Message-Id` header has been yielded.
    pub has_message_id: bool,
    /// Set once a `Date` header has been yielded.
    pub has_date: bool,
}
46
/// Classification of a header name produced by `HeaderParser`; each variant carries the raw
/// header-name bytes exactly as they appear in the message.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum AuthenticatedHeader<'x> {
    /// `DKIM-Signature`.
    Ds(&'x [u8]),
    /// `ARC-Authentication-Results`.
    Aar(&'x [u8]),
    /// `ARC-Message-Signature`.
    Ams(&'x [u8]),
    /// `ARC-Seal`.
    As(&'x [u8]),
    /// `From`.
    From(&'x [u8]),
    /// Any other header, including malformed lines without a colon.
    Other(&'x [u8]),
}
56
/// A header field as raw byte slices plus extra data of a caller-chosen type `T`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Header<'x, T> {
    /// Raw header-name bytes.
    pub name: &'x [u8],
    /// Raw header-value bytes.
    pub value: &'x [u8],
    /// Caller-defined payload attached to this header.
    pub header: T,
}
63
64impl<'x> HeaderParser<'x> {
65    pub fn new(message: &'x [u8]) -> Self {
66        HeaderParser {
67            message,
68            iter: message.iter().enumerate().peekable(),
69            start_pos: 0,
70            num_received: 0,
71            has_message_id: false,
72            has_date: false,
73        }
74    }
75
76    pub fn body_offset(&mut self) -> Option<usize> {
77        self.iter.peek().map(|(pos, _)| *pos)
78    }
79}
80
81impl<'x> HeaderIterator<'x> {
82    pub fn new(message: &'x [u8]) -> Self {
83        HeaderIterator {
84            message,
85            iter: message.iter().enumerate().peekable(),
86            start_pos: 0,
87        }
88    }
89
90    pub fn seek_start(&mut self) {
91        while let Some((_, ch)) = self.iter.peek() {
92            if !ch.is_ascii_whitespace() {
93                break;
94            } else {
95                self.iter.next();
96            }
97        }
98    }
99
100    pub fn body_offset(&mut self) -> Option<usize> {
101        self.iter.peek().map(|(pos, _)| *pos)
102    }
103}
104
105impl<'x> HeaderStream<'x> for HeaderIterator<'x> {
106    fn next_header(&mut self) -> Option<(&'x [u8], &'x [u8])> {
107        self.next()
108    }
109
110    fn body(&mut self) -> &'x [u8] {
111        self.body_offset()
112            .and_then(|offset| self.message.get(offset..))
113            .unwrap_or_default()
114    }
115}
116
impl<'x> Iterator for HeaderIterator<'x> {
    type Item = (&'x [u8], &'x [u8]);

    /// Scans the next header field starting at `self.start_pos` and returns `(name, value)`.
    ///
    /// `name` is the bytes between the start of the header and the first `:`. `value` is
    /// everything after that colon up to and including the newline of the last folded line.
    /// A line is folded when the next line starts with SP or HTAB.
    ///
    /// A line without a `:` that is not followed by a folded line is yielded as
    /// `(the line(s) including the trailing newline, b"")`. The method returns `None` at the
    /// blank line that ends the header section. It also returns `None` when the input runs
    /// out before a header's terminating newline.
    ///
    /// NOTE(review): a colon-less line terminated by CRLF (e.g. `c\r\n`) satisfies
    /// `last_ch == b'\r'` and is treated as the end of the headers. The same line with a bare
    /// LF is yielded as an invalid header instead. Confirm this asymmetry is intended.
    ///
    /// NOTE(review): the iterator is not fused. After returning `None` at the blank line, the
    /// cursor sits at the body, so further calls scan body bytes.
    fn next(&mut self) -> Option<Self::Item> {
        // Offset of the first ':' of the current header; `usize::MAX` while still in the name.
        let mut colon_pos = usize::MAX;
        // Previous byte, used to recognise the CRLF blank line that ends the header section.
        let mut last_ch = 0;

        while let Some((pos, &ch)) = self.iter.next() {
            if colon_pos == usize::MAX {
                match ch {
                    b':' => {
                        colon_pos = pos;
                    }
                    b'\n' => {
                        if last_ch == b'\r' || self.start_pos == pos {
                            // End of headers
                            return None;
                        } else if self
                            .iter
                            .peek()
                            .is_none_or(|(_, next_byte)| ![b' ', b'\t'].contains(next_byte))
                        {
                            // Invalid header, return anyway.
                            let header_name = self
                                .message
                                .get(self.start_pos..pos + 1)
                                .unwrap_or_default();
                            self.start_pos = pos + 1;
                            return Some((header_name, b""));
                        }
                    }
                    _ => (),
                }
            } else if ch == b'\n'
                && self
                    .iter
                    .peek()
                    .is_none_or(|(_, next_byte)| ![b' ', b'\t'].contains(next_byte))
            {
                // A newline not followed by SP/HTAB ends the (possibly folded) header.
                let header_name = self
                    .message
                    .get(self.start_pos..colon_pos)
                    .unwrap_or_default();
                let header_value = self.message.get(colon_pos + 1..pos + 1).unwrap_or_default();

                self.start_pos = pos + 1;

                return Some((header_name, header_value));
            }

            last_ch = ch;
        }

        None
    }
}
173
174impl<'x, T: Iterator<Item = &'x [u8]>> ChainedHeaderIterator<'x, T> {
175    pub fn new(mut parts: T) -> Self {
176        ChainedHeaderIterator {
177            iter: HeaderIterator::new(parts.next().unwrap()),
178            parts,
179        }
180    }
181}
182
183impl<'x, T: Iterator<Item = &'x [u8]>> HeaderStream<'x> for ChainedHeaderIterator<'x, T> {
184    fn next_header(&mut self) -> Option<(&'x [u8], &'x [u8])> {
185        if let Some(header) = self.iter.next_header() {
186            Some(header)
187        } else {
188            self.iter = HeaderIterator::new(self.parts.next()?);
189            self.iter.next_header()
190        }
191    }
192
193    fn body(&mut self) -> &'x [u8] {
194        self.iter.body()
195    }
196}
197
198impl<'x> Iterator for HeaderParser<'x> {
199    type Item = (AuthenticatedHeader<'x>, &'x [u8]);
200
201    fn next(&mut self) -> Option<Self::Item> {
202        let mut colon_pos = usize::MAX;
203        let mut last_ch = 0;
204
205        let mut token_start = usize::MAX;
206        let mut token_end = usize::MAX;
207
208        let mut hash: u64 = 0;
209        let mut hash_shift = 0;
210
211        while let Some((pos, &ch)) = self.iter.next() {
212            if colon_pos == usize::MAX {
213                match ch {
214                    b':' => {
215                        colon_pos = pos;
216                    }
217                    b'\n' => {
218                        if last_ch == b'\r' || self.start_pos == pos {
219                            // End of headers
220                            return None;
221                        } else if self
222                            .iter
223                            .peek()
224                            .is_none_or(|(_, next_byte)| ![b' ', b'\t'].contains(next_byte))
225                        {
226                            // Invalid header, return anyway.
227                            let header_name = self
228                                .message
229                                .get(self.start_pos..pos + 1)
230                                .unwrap_or_default();
231                            self.start_pos = pos + 1;
232                            return Some((AuthenticatedHeader::Other(header_name), b""));
233                        }
234                    }
235                    b' ' | b'\t' | b'\r' => (),
236                    b'A'..=b'Z' => {
237                        if hash_shift < 64 {
238                            hash |= ((ch - b'A' + b'a') as u64) << hash_shift;
239                            hash_shift += 8;
240
241                            if token_start == usize::MAX {
242                                token_start = pos;
243                            }
244                        }
245                        token_end = pos;
246                    }
247                    b'a'..=b'z' | b'-' => {
248                        if hash_shift < 64 {
249                            hash |= (ch as u64) << hash_shift;
250                            hash_shift += 8;
251
252                            if token_start == usize::MAX {
253                                token_start = pos;
254                            }
255                        }
256                        token_end = pos;
257                    }
258                    _ => {
259                        hash = u64::MAX;
260                    }
261                }
262            } else if ch == b'\n'
263                && self
264                    .iter
265                    .peek()
266                    .is_none_or(|(_, next_byte)| ![b' ', b'\t'].contains(next_byte))
267            {
268                let header_name = self
269                    .message
270                    .get(self.start_pos..colon_pos)
271                    .unwrap_or_default();
272                let header_value = self.message.get(colon_pos + 1..pos + 1).unwrap_or_default();
273                let header_name = match hash {
274                    RECEIVED if token_start + 8 == token_end + 1 => {
275                        self.num_received += 1;
276                        AuthenticatedHeader::Other(header_name)
277                    }
278                    FROM => AuthenticatedHeader::From(header_name),
279                    AS => AuthenticatedHeader::As(header_name),
280                    AAR if self
281                        .message
282                        .get(token_start + 8..token_end + 1)
283                        .unwrap_or_default()
284                        .eq_ignore_ascii_case(b"entication-Results") =>
285                    {
286                        AuthenticatedHeader::Aar(header_name)
287                    }
288                    AMS if self
289                        .message
290                        .get(token_start + 8..token_end + 1)
291                        .unwrap_or_default()
292                        .eq_ignore_ascii_case(b"age-Signature") =>
293                    {
294                        AuthenticatedHeader::Ams(header_name)
295                    }
296                    DKIM if self
297                        .message
298                        .get(token_start + 8..token_end + 1)
299                        .unwrap_or_default()
300                        .eq_ignore_ascii_case(b"nature") =>
301                    {
302                        AuthenticatedHeader::Ds(header_name)
303                    }
304                    MSGID
305                        if self
306                            .message
307                            .get(token_start + 8..token_end + 1)
308                            .unwrap_or_default()
309                            .eq_ignore_ascii_case(b"id") =>
310                    {
311                        self.has_message_id = true;
312                        AuthenticatedHeader::Other(header_name)
313                    }
314                    DATE => {
315                        self.has_date = true;
316                        AuthenticatedHeader::Other(header_name)
317                    }
318                    _ => AuthenticatedHeader::Other(header_name),
319                };
320
321                self.start_pos = pos + 1;
322
323                return Some((header_name, header_value));
324            }
325
326            last_ch = ch;
327        }
328
329        None
330    }
331}
332
333pub trait HeaderWriter: Sized {
334    fn write_header(&self, writer: &mut impl Writer);
335    fn to_header(&self) -> String {
336        let mut buf = Vec::new();
337        self.write_header(&mut buf);
338        String::from_utf8(buf).unwrap()
339    }
340}
341
/// A value that serialises itself into a [`Writer`], consuming itself in the process.
pub trait Writable {
    fn write(self, writer: &mut impl Writer);
}

impl Writable for &[u8] {
    /// Copies the byte slice into `writer` verbatim.
    fn write(self, writer: &mut impl Writer) {
        Writer::write(writer, self)
    }
}

/// A byte sink used when serialising headers.
pub trait Writer {
    /// Appends `buf` to the output.
    fn write(&mut self, buf: &[u8]);

    /// Appends `buf` and adds its length to the running byte counter `len`.
    fn write_len(&mut self, buf: &[u8], len: &mut usize) {
        let written = buf.len();
        self.write(buf);
        *len += written;
    }
}

impl Writer for Vec<u8> {
    /// Appends `buf` to the end of the vector.
    fn write(&mut self, buf: &[u8]) {
        self.extend_from_slice(buf);
    }
}
366
// Header-name fingerprints matched by `HeaderParser`'s `Iterator::next`. Each is the first
// (up to) eight lower-cased name bytes packed little-endian into a `u64`, zero-padded when
// the name is shorter than eight bytes.

/// `From`.
const FROM: u64 = u64::from_le_bytes(*b"from\0\0\0\0");
/// Prefix of `DKIM-Signature`.
const DKIM: u64 = u64::from_le_bytes(*b"dkim-sig");
/// Prefix of `ARC-Authentication-Results`.
const AAR: u64 = u64::from_le_bytes(*b"arc-auth");
/// Prefix of `ARC-Message-Signature`.
const AMS: u64 = u64::from_le_bytes(*b"arc-mess");
/// `ARC-Seal`.
const AS: u64 = u64::from_le_bytes(*b"arc-seal");
/// `Received`.
const RECEIVED: u64 = u64::from_le_bytes(*b"received");
/// `Date`.
const DATE: u64 = u64::from_le_bytes(*b"date\0\0\0\0");
/// Prefix of `Message-Id`.
const MSGID: u64 = u64::from_le_bytes(*b"message-");
419
#[cfg(test)]
mod test {
    use crate::common::headers::{AuthenticatedHeader, HeaderParser};

    use super::{ChainedHeaderIterator, HeaderIterator, HeaderStream};

    /// `HeaderIterator` and `HeaderParser` must split the same inputs into identical
    /// `(name, value)` pairs. The inputs cover folded lines, colon-less lines, empty names,
    /// whitespace before the colon and CRLF line endings. Both must stop at the blank line.
    #[test]
    fn header_iterator() {
        for (message, headers) in [
            // Plain LF headers, an empty value and a folded value.
            (
                "From: a\nTo: b\nEmpty:\nMulti: 1\n 2\nSubject: c\n\nNot-header: ignore\n",
                vec![
                    ("From", " a\n"),
                    ("To", " b\n"),
                    ("Empty", "\n"),
                    ("Multi", " 1\n 2\n"),
                    ("Subject", " c\n"),
                ],
            ),
            // Malformed input: empty names, whitespace-only folds and a colon-less line.
            (
                ": a\nTo: b\n \n \nc\n:\nFrom : d\nSubject: e\n\nNot-header: ignore\n",
                vec![
                    ("", " a\n"),
                    ("To", " b\n \n \n"),
                    ("c\n", ""),
                    ("", "\n"),
                    ("From ", " d\n"),
                    ("Subject", " e\n"),
                ],
            ),
            // CRLF line endings with a tab-folded value; nothing after the blank line.
            (
                concat!(
                    "A: X\r\n",
                    "B : Y\t\r\n",
                    "\tZ  \r\n",
                    "\r\n",
                    " C \r\n",
                    "D \t E\r\n"
                ),
                vec![("A", " X\r\n"), ("B ", " Y\t\r\n\tZ  \r\n")],
            ),
        ] {
            assert_eq!(
                HeaderIterator::new(message.as_bytes())
                    .map(|(h, v)| {
                        (
                            std::str::from_utf8(h).unwrap(),
                            std::str::from_utf8(v).unwrap(),
                        )
                    })
                    .collect::<Vec<_>>(),
                headers
            );

            assert_eq!(
                HeaderParser::new(message.as_bytes())
                    .map(|(h, v)| {
                        (
                            std::str::from_utf8(match h {
                                AuthenticatedHeader::Ds(v)
                                | AuthenticatedHeader::Aar(v)
                                | AuthenticatedHeader::Ams(v)
                                | AuthenticatedHeader::As(v)
                                | AuthenticatedHeader::From(v)
                                | AuthenticatedHeader::Other(v) => v,
                            })
                            .unwrap(),
                            std::str::from_utf8(v).unwrap(),
                        )
                    })
                    .collect::<Vec<_>>(),
                headers
            );
        }
    }

    /// Checks how `HeaderParser` classifies authentication-related header names, including
    /// near-misses that must fall through to `Other`. Also checks the `Received`, `Date`
    /// and `Message-Id` bookkeeping.
    #[test]
    fn header_parser() {
        let message = concat!(
            "ARC-Message-Signature: i=1; a=rsa-sha256;\n",
            "ARC-Authentication-Results: i=1;\n",
            "ARC-Seal: i=1; a=rsa-sha256;\n",
            "DKIM-Signature: v=1; a=rsa-sha256; c=relaxed/simple;\n",
            "From: jdoe@domain\n",
            "F r o m : jane@domain.com\n",
            "ARC-Authentication: i=1;\n",
            "Received: r1\n",
            "Received: r2\n",
            "Received: r3\n",
            "Received-From: test\n",
            "Date: date\n",
            "Message-Id: myid\n",
            "\nhey",
        );
        let mut parser = HeaderParser::new(message.as_bytes());
        assert_eq!(
            (&mut parser).map(|(h, _)| { h }).collect::<Vec<_>>(),
            vec![
                AuthenticatedHeader::Ams(b"ARC-Message-Signature"),
                AuthenticatedHeader::Aar(b"ARC-Authentication-Results"),
                AuthenticatedHeader::As(b"ARC-Seal"),
                AuthenticatedHeader::Ds(b"DKIM-Signature"),
                AuthenticatedHeader::From(b"From"),
                AuthenticatedHeader::From(b"F r o m "),
                AuthenticatedHeader::Other(b"ARC-Authentication"),
                AuthenticatedHeader::Other(b"Received"),
                AuthenticatedHeader::Other(b"Received"),
                AuthenticatedHeader::Other(b"Received"),
                AuthenticatedHeader::Other(b"Received-From"),
                AuthenticatedHeader::Other(b"Date"),
                AuthenticatedHeader::Other(b"Message-Id"),
            ]
        );
        assert!(parser.has_date);
        assert!(parser.has_message_id);
        assert_eq!(parser.num_received, 3);
    }

    /// Headers split across two slices are read in order, and the body comes from the last
    /// slice.
    #[test]
    fn chained_header_iterator() {
        let parts = [
            &b"From: a\nTo: b\nEmpty:\nMulti: 1\n 2\n"[..],
            &b"Subject: c\nReceived: d\n\nhey"[..],
        ];
        let mut headers = vec![
            ("From", " a\n"),
            ("To", " b\n"),
            ("Empty", "\n"),
            ("Multi", " 1\n 2\n"),
            ("Subject", " c\n"),
            ("Received", " d\n"),
        ]
        .into_iter();
        let mut it = ChainedHeaderIterator::new(parts.iter().copied());

        while let Some((k, v)) = it.next_header() {
            assert_eq!(
                (
                    std::str::from_utf8(k).unwrap(),
                    std::str::from_utf8(v).unwrap()
                ),
                headers.next().unwrap()
            );
        }
        assert_eq!(it.body(), b"hey");
    }
}