mail_auth/spf/macros.rs

1/*
2 * SPDX-FileCopyrightText: 2020 Stalwart Labs LLC <hello@stalw.art>
3 *
4 * SPDX-License-Identifier: Apache-2.0 OR MIT
5 */
6
7use std::{borrow::Cow, net::IpAddr, time::SystemTime};
8
9use super::{Macro, Variable, Variables};
10
11impl Macro {
12    pub fn eval<'z, 'x: 'z>(
13        &'z self,
14        vars: &'x Variables<'x>,
15        default: &'x str,
16        fqdn: bool,
17    ) -> Cow<'z, str> {
18        match self {
19            Macro::Literal(literal) => std::str::from_utf8(literal).unwrap_or_default().into(),
20            Macro::Variable {
21                letter,
22                num_parts,
23                reverse,
24                escape,
25                delimiters,
26            } => match vars.get(*letter, *num_parts, *reverse, *escape, fqdn, *delimiters) {
27                Cow::Borrowed(bytes) => std::str::from_utf8(bytes).unwrap_or_default().into(),
28                Cow::Owned(bytes) => String::from_utf8(bytes).unwrap_or_default().into(),
29            },
30            Macro::List(list) => {
31                let mut result = Vec::with_capacity(32);
32                for item in list {
33                    match item {
34                        Macro::Literal(literal) => {
35                            result.extend_from_slice(literal);
36                        }
37                        Macro::Variable {
38                            letter,
39                            num_parts,
40                            reverse,
41                            escape,
42                            delimiters,
43                        } => {
44                            result.extend_from_slice(
45                                vars.get(
46                                    *letter,
47                                    *num_parts,
48                                    *reverse,
49                                    *escape,
50                                    false,
51                                    *delimiters,
52                                )
53                                .as_ref(),
54                            );
55                        }
56                        Macro::List(_) | Macro::None => unreachable!(),
57                    }
58                }
59                if fqdn && !result.is_empty() && result.last().unwrap() != &b'.' {
60                    result.push(b'.');
61                }
62                String::from_utf8(result).unwrap_or_default().into()
63            }
64            Macro::None => default.into(),
65        }
66    }
67
68    pub fn needs_ptr(&self) -> bool {
69        match self {
70            Macro::Variable { letter, .. } => *letter == Variable::ValidatedDomain,
71            Macro::List(list) => list.iter().any(|m| matches!(m, Macro::Variable { letter, .. } if *letter == Variable::ValidatedDomain)),
72            _ => false,
73        }
74    }
75}
76
77impl<'x> Variables<'x> {
78    pub fn new() -> Self {
79        let mut vars = Variables::default();
80        vars.vars[Variable::CurrentTime as usize] = SystemTime::now()
81            .duration_since(SystemTime::UNIX_EPOCH)
82            .map(|d| d.as_secs())
83            .unwrap_or(0)
84            .to_string()
85            .into_bytes()
86            .into();
87        vars
88    }
89
90    pub fn set_ip(&mut self, value: &IpAddr) {
91        let (v, i, c) = match value {
92            IpAddr::V4(ip) => (
93                "in-addr".as_bytes(),
94                ip.to_string().into_bytes(),
95                ip.to_string(),
96            ),
97            IpAddr::V6(ip) => {
98                let mut segments = Vec::with_capacity(63);
99                for segment in ip.segments() {
100                    for &p in format!("{segment:04x}").as_bytes() {
101                        if !segments.is_empty() {
102                            segments.push(b'.');
103                        }
104                        segments.push(p);
105                    }
106                }
107                ("ip6".as_bytes(), segments, ip.to_string())
108            }
109        };
110        self.vars[Variable::IpVersion as usize] = v.into();
111        self.vars[Variable::Ip as usize] = i.into();
112        self.vars[Variable::SmtpIp as usize] = c.into_bytes().into();
113    }
114
115    pub fn set_sender(&mut self, value: impl Into<Cow<'x, [u8]>>) {
116        let value = value.into();
117        for (pos, ch) in value.as_ref().iter().enumerate() {
118            if ch == &b'@' {
119                if pos > 0 {
120                    self.vars[Variable::SenderLocalPart as usize] = match &value {
121                        Cow::Borrowed(value) => (&value[..pos]).into(),
122                        Cow::Owned(value) => value[..pos].to_vec().into(),
123                    };
124                }
125                self.vars[Variable::SenderDomainPart as usize] = match &value {
126                    Cow::Borrowed(value) => (value.get(pos + 1..).unwrap_or_default()).into(),
127                    Cow::Owned(value) => (value.get(pos + 1..).unwrap_or_default()).to_vec().into(),
128                };
129                break;
130            }
131        }
132
133        self.vars[Variable::Sender as usize] = value;
134    }
135
136    pub fn set_helo_domain(&mut self, value: impl Into<Cow<'x, [u8]>>) {
137        self.vars[Variable::HeloDomain as usize] = value.into();
138    }
139
140    pub fn set_host_domain(&mut self, value: impl Into<Cow<'x, [u8]>>) {
141        self.vars[Variable::HostDomain as usize] = value.into();
142    }
143
144    pub fn set_validated_domain(&mut self, value: impl Into<Cow<'x, [u8]>>) {
145        self.vars[Variable::ValidatedDomain as usize] = value.into();
146    }
147
148    pub fn set_domain(&mut self, value: impl Into<Cow<'x, [u8]>>) {
149        self.vars[Variable::Domain as usize] = value.into();
150    }
151
152    pub fn get(
153        &self,
154        name: Variable,
155        num_parts: u32,
156        reverse: bool,
157        escape: bool,
158        fqdn: bool,
159        delimiters: u64,
160    ) -> Cow<'_, [u8]> {
161        let var = self.vars[name as usize].as_ref();
162        if var.is_empty()
163            || (num_parts == 0 && !reverse && !escape && delimiters == 1u64 << (b'.' - b'+'))
164        {
165            return var.into();
166        }
167        let mut parts = Vec::new();
168        let mut parts_len = 0;
169        let mut start_pos = 0;
170
171        for (pos, ch) in var.iter().enumerate() {
172            if (b'+'..=b'_').contains(ch) && (delimiters & (1u64 << (*ch - b'+'))) != 0 {
173                parts_len += pos - start_pos + 1;
174                parts.push(&var[start_pos..pos]);
175                start_pos = pos + 1;
176            }
177        }
178        parts.push(&var[start_pos..var.len()]);
179
180        let num_parts = if num_parts == 0 {
181            parts.len()
182        } else {
183            std::cmp::min(parts.len(), num_parts as usize)
184        };
185
186        let mut result = Vec::with_capacity(parts_len + var.len() - start_pos);
187        if !reverse {
188            for (pos, part) in parts.iter().skip(parts.len() - num_parts).enumerate() {
189                add_part(&mut result, part, pos, escape);
190            }
191        } else {
192            for (pos, part) in parts.iter().rev().skip(parts.len() - num_parts).enumerate() {
193                add_part(&mut result, part, pos, escape);
194            }
195        }
196        if fqdn && result.last().unwrap_or(&0) != &b'.' {
197            result.push(b'.');
198        }
199        result.into()
200    }
201}
202
/// Appends one macro part to `result`.
///
/// A '.' separator is written first unless this is the first part
/// (`pos == 0`). With `escape` set, every byte outside ALPHA / DIGIT / "-" /
/// "." / "_" / "~" is written as `%xx` using two lowercase hex digits.
/// Otherwise the part is copied verbatim.
#[inline]
fn add_part(result: &mut Vec<u8>, part: &[u8], pos: usize, escape: bool) {
    const HEX_DIGITS: &[u8; 16] = b"0123456789abcdef";

    if pos > 0 {
        result.push(b'.');
    }
    if !escape {
        result.extend_from_slice(part);
        return;
    }

    result.reserve(part.len());
    for &ch in part {
        if ch.is_ascii_alphanumeric() || matches!(ch, b'-' | b'.' | b'_' | b'~') {
            result.push(ch);
        } else {
            // Encode in place rather than via `format!`, which would allocate
            // a temporary String for every escaped byte.
            result.extend_from_slice(&[
                b'%',
                HEX_DIGITS[usize::from(ch >> 4)],
                HEX_DIGITS[usize::from(ch & 0x0f)],
            ]);
        }
    }
}
220
#[cfg(test)]
mod test {
    use std::net::IpAddr;

    use crate::spf::{parse::SPFParser, Variables};

    /// Parses every macro-string in `cases` and checks its expansion.
    ///
    /// `is_exp` is passed straight through to `macro_string`; it is presumably
    /// the explanation-string flag. Each parsed macro is evaluated against
    /// `vars` with an empty default and no FQDN suffix, and the result must
    /// equal the paired expansion.
    fn assert_expansions<'v>(vars: &'v Variables<'v>, is_exp: bool, cases: &[(&str, &str)]) {
        for &(macro_string, expansion) in cases {
            let (parsed, _) = macro_string.as_bytes().iter().macro_string(is_exp).unwrap();
            assert_eq!(parsed.eval(vars, "", false), expansion, "{macro_string:?}");
        }
    }

    #[test]
    fn expand_macro() {
        // IPv4 client; the values follow the RFC 7208 expansion examples, plus
        // a degenerate HELO name made only of dots.
        let mut v4 = Variables::new();
        v4.set_sender("strong-bad@email.example.com".as_bytes());
        v4.set_ip(&"192.0.2.3".parse::<IpAddr>().unwrap());
        v4.set_validated_domain("mx.example.org".as_bytes());
        v4.set_domain("email.example.com".as_bytes());
        v4.set_helo_domain("....".as_bytes());

        assert_expansions(
            &v4,
            false,
            &[
                ("%{s}", "strong-bad@email.example.com"),
                ("%{o}", "email.example.com"),
                ("%{d}", "email.example.com"),
                ("%{d4}", "email.example.com"),
                ("%{d3}", "email.example.com"),
                ("%{d2}", "example.com"),
                ("%{d1}", "com"),
                ("%{dr}", "com.example.email"),
                ("%{d2r}", "example.email"),
                ("%{l}", "strong-bad"),
                ("%{l-}", "strong.bad"),
                ("%{lr}", "strong-bad"),
                ("%{lr-}", "bad.strong"),
                ("%{l1r-}", "strong"),
                ("%{p1r}", "mx"),
                ("%{h3r}", ".."),
                (
                    "%{ir}.%{v}._spf.%{d2}",
                    "3.2.0.192.in-addr._spf.example.com",
                ),
                ("%{lr-}.lp._spf.%{d2}", "bad.strong.lp._spf.example.com"),
                (
                    "%{lr-}.lp.%{ir}.%{v}._spf.%{d2}",
                    "bad.strong.lp.3.2.0.192.in-addr._spf.example.com",
                ),
                (
                    "%{ir}.%{v}.%{l1r-}.lp._spf.%{d2}",
                    "3.2.0.192.in-addr.strong.lp._spf.example.com",
                ),
                (
                    "%{d2}.trusted-domains.example.net",
                    "example.com.trusted-domains.example.net",
                ),
            ],
        );

        // IPv6 client, including URL-escaped (upper-case letter) macros.
        let mut v6 = Variables::new();
        v6.set_sender("strong-bad@email.example.com".as_bytes());
        v6.set_ip(&"2001:db8::cb01".parse::<IpAddr>().unwrap());
        v6.set_validated_domain("mx.example.org".as_bytes());
        v6.set_domain("email.example.com".as_bytes());

        assert_expansions(
            &v6,
            true,
            &[
                (
                    "%{ir}.%{v}._spf.%{d2}",
                    concat!(
                        "1.0.b.c.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.",
                        "0.0.0.0.0.8.b.d.0.1.0.0.2.ip6._spf.example.com"
                    ),
                ),
                ("%{c}", "2001:db8::cb01"),
                (
                    "%{c} is not one of %{d}'s designated mail servers.",
                    "2001:db8::cb01 is not one of email.example.com's designated mail servers.",
                ),
                (
                    "See http://%{d}/why.html?s=%{S}&i=%{C}",
                    concat!(
                        "See http://email.example.com/why.html?",
                        "s=strong-bad%40email.example.com&i=2001%3adb8%3a%3acb01"
                    ),
                ),
            ],
        );
    }
}