mail_auth/spf/macros.rs

/*
 * SPDX-FileCopyrightText: 2020 Stalwart Labs LLC <hello@stalw.art>
 *
 * SPDX-License-Identifier: Apache-2.0 OR MIT
 */
6
7use super::{Macro, Variable, Variables};
8use std::{borrow::Cow, net::IpAddr, time::SystemTime};
9
10impl Macro {
11    pub fn eval<'z, 'x: 'z>(
12        &'z self,
13        vars: &'x Variables<'x>,
14        default: &'x str,
15        fqdn: bool,
16    ) -> Cow<'z, str> {
17        match self {
18            Macro::Literal(literal) => std::str::from_utf8(literal).unwrap_or_default().into(),
19            Macro::Variable {
20                letter,
21                num_parts,
22                reverse,
23                escape,
24                delimiters,
25            } => match vars.get(*letter, *num_parts, *reverse, *escape, fqdn, *delimiters) {
26                Cow::Borrowed(bytes) => std::str::from_utf8(bytes).unwrap_or_default().into(),
27                Cow::Owned(bytes) => String::from_utf8(bytes).unwrap_or_default().into(),
28            },
29            Macro::List(list) => {
30                let mut result = Vec::with_capacity(32);
31                for item in list {
32                    match item {
33                        Macro::Literal(literal) => {
34                            result.extend_from_slice(literal);
35                        }
36                        Macro::Variable {
37                            letter,
38                            num_parts,
39                            reverse,
40                            escape,
41                            delimiters,
42                        } => {
43                            result.extend_from_slice(
44                                vars.get(
45                                    *letter,
46                                    *num_parts,
47                                    *reverse,
48                                    *escape,
49                                    false,
50                                    *delimiters,
51                                )
52                                .as_ref(),
53                            );
54                        }
55                        Macro::List(_) | Macro::None => unreachable!(),
56                    }
57                }
58                if fqdn && !result.is_empty() && result.last().unwrap() != &b'.' {
59                    result.push(b'.');
60                }
61                String::from_utf8(result).unwrap_or_default().into()
62            }
63            Macro::None => default.into(),
64        }
65    }
66
67    pub fn needs_ptr(&self) -> bool {
68        match self {
69            Macro::Variable { letter, .. } => *letter == Variable::ValidatedDomain,
70            Macro::List(list) => list.iter().any(|m| matches!(m, Macro::Variable { letter, .. } if *letter == Variable::ValidatedDomain)),
71            _ => false,
72        }
73    }
74}
75
76impl<'x> Variables<'x> {
77    pub fn new() -> Self {
78        let mut vars = Variables::default();
79        vars.vars[Variable::CurrentTime as usize] = SystemTime::now()
80            .duration_since(SystemTime::UNIX_EPOCH)
81            .map(|d| d.as_secs())
82            .unwrap_or(0)
83            .to_string()
84            .into_bytes()
85            .into();
86        vars
87    }
88
89    pub fn set_ip(&mut self, value: &IpAddr) {
90        let (v, i, c) = match value {
91            IpAddr::V4(ip) => (
92                "in-addr".as_bytes(),
93                ip.to_string().into_bytes(),
94                ip.to_string(),
95            ),
96            IpAddr::V6(ip) => {
97                let mut segments = Vec::with_capacity(63);
98                for segment in ip.segments() {
99                    for &p in format!("{segment:04x}").as_bytes() {
100                        if !segments.is_empty() {
101                            segments.push(b'.');
102                        }
103                        segments.push(p);
104                    }
105                }
106                ("ip6".as_bytes(), segments, ip.to_string())
107            }
108        };
109        self.vars[Variable::IpVersion as usize] = v.into();
110        self.vars[Variable::Ip as usize] = i.into();
111        self.vars[Variable::SmtpIp as usize] = c.into_bytes().into();
112    }
113
114    pub fn set_sender(&mut self, value: impl Into<Cow<'x, [u8]>>) {
115        let value = value.into();
116        for (pos, ch) in value.as_ref().iter().enumerate() {
117            if ch == &b'@' {
118                if pos > 0 {
119                    self.vars[Variable::SenderLocalPart as usize] = match &value {
120                        Cow::Borrowed(value) => (&value[..pos]).into(),
121                        Cow::Owned(value) => value[..pos].to_vec().into(),
122                    };
123                }
124                self.vars[Variable::SenderDomainPart as usize] = match &value {
125                    Cow::Borrowed(value) => (value.get(pos + 1..).unwrap_or_default()).into(),
126                    Cow::Owned(value) => (value.get(pos + 1..).unwrap_or_default()).to_vec().into(),
127                };
128                break;
129            }
130        }
131
132        self.vars[Variable::Sender as usize] = value;
133    }
134
135    pub fn set_helo_domain(&mut self, value: impl Into<Cow<'x, [u8]>>) {
136        self.vars[Variable::HeloDomain as usize] = value.into();
137    }
138
139    pub fn set_host_domain(&mut self, value: impl Into<Cow<'x, [u8]>>) {
140        self.vars[Variable::HostDomain as usize] = value.into();
141    }
142
143    pub fn set_validated_domain(&mut self, value: impl Into<Cow<'x, [u8]>>) {
144        self.vars[Variable::ValidatedDomain as usize] = value.into();
145    }
146
147    pub fn set_domain(&mut self, value: impl Into<Cow<'x, [u8]>>) {
148        self.vars[Variable::Domain as usize] = value.into();
149    }
150
151    pub fn get(
152        &self,
153        name: Variable,
154        num_parts: u32,
155        reverse: bool,
156        escape: bool,
157        fqdn: bool,
158        delimiters: u64,
159    ) -> Cow<'_, [u8]> {
160        let var = self.vars[name as usize].as_ref();
161        if var.is_empty()
162            || (num_parts == 0 && !reverse && !escape && delimiters == 1u64 << (b'.' - b'+'))
163        {
164            return var.into();
165        }
166        let mut parts = Vec::new();
167        let mut parts_len = 0;
168        let mut start_pos = 0;
169
170        for (pos, ch) in var.iter().enumerate() {
171            if (b'+'..=b'_').contains(ch) && (delimiters & (1u64 << (*ch - b'+'))) != 0 {
172                parts_len += pos - start_pos + 1;
173                parts.push(&var[start_pos..pos]);
174                start_pos = pos + 1;
175            }
176        }
177        parts.push(&var[start_pos..var.len()]);
178
179        let num_parts = if num_parts == 0 {
180            parts.len()
181        } else {
182            std::cmp::min(parts.len(), num_parts as usize)
183        };
184
185        let mut result = Vec::with_capacity(parts_len + var.len() - start_pos);
186        if !reverse {
187            for (pos, part) in parts.iter().skip(parts.len() - num_parts).enumerate() {
188                add_part(&mut result, part, pos, escape);
189            }
190        } else {
191            for (pos, part) in parts.iter().rev().skip(parts.len() - num_parts).enumerate() {
192                add_part(&mut result, part, pos, escape);
193            }
194        }
195        if fqdn && result.last().unwrap_or(&0) != &b'.' {
196            result.push(b'.');
197        }
198        result.into()
199    }
200}
201
/// Appends one selected part of a macro variable to `result`.
///
/// Every part after the first (`pos > 0`) is preceded by a `.` separator.
/// When `escape` is set, each byte other than ASCII letters, digits, `-`, `.`,
/// `_` and `~` is written as `%` followed by two lowercase hex digits.
#[inline]
fn add_part(result: &mut Vec<u8>, part: &[u8], pos: usize, escape: bool) {
    const HEX_DIGITS: &[u8; 16] = b"0123456789abcdef";

    if pos > 0 {
        result.push(b'.');
    }
    if !escape {
        result.extend_from_slice(part);
        return;
    }
    for &ch in part {
        if ch.is_ascii_alphanumeric() || matches!(ch, b'-' | b'.' | b'_' | b'~') {
            result.push(ch);
        } else {
            // Table lookup instead of `format!` avoids a heap allocation per
            // escaped byte; output is identical to `format!("%{ch:02x}")`.
            result.extend_from_slice(&[
                b'%',
                HEX_DIGITS[usize::from(ch >> 4)],
                HEX_DIGITS[usize::from(ch & 0x0f)],
            ]);
        }
    }
}
219
#[cfg(test)]
mod test {
    use std::net::IpAddr;

    use crate::spf::{Variables, parse::SPFParser};

    /// Parses every macro string with `macro_string(parse_flag)`, evaluates it
    /// against `vars` (empty default, `fqdn = false`) and checks the result.
    fn assert_expansions<'x>(vars: &'x Variables<'x>, parse_flag: bool, cases: &[(&str, &str)]) {
        for &(macro_string, expected) in cases {
            let (parsed, _) = macro_string
                .as_bytes()
                .iter()
                .macro_string(parse_flag)
                .unwrap();
            assert_eq!(parsed.eval(vars, "", false), expected, "{macro_string:?}");
        }
    }

    #[test]
    fn expand_macro() {
        // IPv4 client with a degenerate HELO domain made only of dots.
        let mut ipv4_vars = Variables::new();
        ipv4_vars.set_sender("strong-bad@email.example.com".as_bytes());
        ipv4_vars.set_ip(&"192.0.2.3".parse::<IpAddr>().unwrap());
        ipv4_vars.set_validated_domain("mx.example.org".as_bytes());
        ipv4_vars.set_domain("email.example.com".as_bytes());
        ipv4_vars.set_helo_domain("....".as_bytes());

        assert_expansions(
            &ipv4_vars,
            false,
            &[
                ("%{s}", "strong-bad@email.example.com"),
                ("%{o}", "email.example.com"),
                ("%{d}", "email.example.com"),
                ("%{d4}", "email.example.com"),
                ("%{d3}", "email.example.com"),
                ("%{d2}", "example.com"),
                ("%{d1}", "com"),
                ("%{dr}", "com.example.email"),
                ("%{d2r}", "example.email"),
                ("%{l}", "strong-bad"),
                ("%{l-}", "strong.bad"),
                ("%{lr}", "strong-bad"),
                ("%{lr-}", "bad.strong"),
                ("%{l1r-}", "strong"),
                ("%{p1r}", "mx"),
                ("%{h3r}", ".."),
                (
                    "%{ir}.%{v}._spf.%{d2}",
                    "3.2.0.192.in-addr._spf.example.com",
                ),
                ("%{lr-}.lp._spf.%{d2}", "bad.strong.lp._spf.example.com"),
                (
                    "%{lr-}.lp.%{ir}.%{v}._spf.%{d2}",
                    "bad.strong.lp.3.2.0.192.in-addr._spf.example.com",
                ),
                (
                    "%{ir}.%{v}.%{l1r-}.lp._spf.%{d2}",
                    "3.2.0.192.in-addr.strong.lp._spf.example.com",
                ),
                (
                    "%{d2}.trusted-domains.example.net",
                    "example.com.trusted-domains.example.net",
                ),
            ],
        );

        // IPv6 client; these strings are parsed with `macro_string(true)`.
        // NOTE(review): presumably the explanation-string mode (the only place
        // `%{c}` and free text are allowed) — confirm.
        let mut ipv6_vars = Variables::new();
        ipv6_vars.set_sender("strong-bad@email.example.com".as_bytes());
        ipv6_vars.set_ip(&"2001:db8::cb01".parse::<IpAddr>().unwrap());
        ipv6_vars.set_validated_domain("mx.example.org".as_bytes());
        ipv6_vars.set_domain("email.example.com".as_bytes());

        assert_expansions(
            &ipv6_vars,
            true,
            &[
                (
                    "%{ir}.%{v}._spf.%{d2}",
                    concat!(
                        "1.0.b.c.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.",
                        "0.0.0.0.0.8.b.d.0.1.0.0.2.ip6._spf.example.com"
                    ),
                ),
                ("%{c}", "2001:db8::cb01"),
                (
                    "%{c} is not one of %{d}'s designated mail servers.",
                    "2001:db8::cb01 is not one of email.example.com's designated mail servers.",
                ),
                (
                    "See http://%{d}/why.html?s=%{S}&i=%{C}",
                    concat!(
                        "See http://email.example.com/why.html?",
                        "s=strong-bad%40email.example.com&i=2001%3adb8%3a%3acb01"
                    ),
                ),
            ],
        );
    }
}