1use super::{Macro, Variable, Variables};
8use std::{borrow::Cow, net::IpAddr, time::SystemTime};
9
10impl Macro {
11 pub fn eval<'z, 'x: 'z>(
12 &'z self,
13 vars: &'x Variables<'x>,
14 default: &'x str,
15 fqdn: bool,
16 ) -> Cow<'z, str> {
17 match self {
18 Macro::Literal(literal) => std::str::from_utf8(literal).unwrap_or_default().into(),
19 Macro::Variable {
20 letter,
21 num_parts,
22 reverse,
23 escape,
24 delimiters,
25 } => match vars.get(*letter, *num_parts, *reverse, *escape, fqdn, *delimiters) {
26 Cow::Borrowed(bytes) => std::str::from_utf8(bytes).unwrap_or_default().into(),
27 Cow::Owned(bytes) => String::from_utf8(bytes).unwrap_or_default().into(),
28 },
29 Macro::List(list) => {
30 let mut result = Vec::with_capacity(32);
31 for item in list {
32 match item {
33 Macro::Literal(literal) => {
34 result.extend_from_slice(literal);
35 }
36 Macro::Variable {
37 letter,
38 num_parts,
39 reverse,
40 escape,
41 delimiters,
42 } => {
43 result.extend_from_slice(
44 vars.get(
45 *letter,
46 *num_parts,
47 *reverse,
48 *escape,
49 false,
50 *delimiters,
51 )
52 .as_ref(),
53 );
54 }
55 Macro::List(_) | Macro::None => unreachable!(),
56 }
57 }
58 if fqdn && !result.is_empty() && result.last().unwrap() != &b'.' {
59 result.push(b'.');
60 }
61 String::from_utf8(result).unwrap_or_default().into()
62 }
63 Macro::None => default.into(),
64 }
65 }
66
67 pub fn needs_ptr(&self) -> bool {
68 match self {
69 Macro::Variable { letter, .. } => *letter == Variable::ValidatedDomain,
70 Macro::List(list) => list.iter().any(|m| matches!(m, Macro::Variable { letter, .. } if *letter == Variable::ValidatedDomain)),
71 _ => false,
72 }
73 }
74}
75
76impl<'x> Variables<'x> {
77 pub fn new() -> Self {
78 let mut vars = Variables::default();
79 vars.vars[Variable::CurrentTime as usize] = SystemTime::now()
80 .duration_since(SystemTime::UNIX_EPOCH)
81 .map(|d| d.as_secs())
82 .unwrap_or(0)
83 .to_string()
84 .into_bytes()
85 .into();
86 vars
87 }
88
89 pub fn set_ip(&mut self, value: &IpAddr) {
90 let (v, i, c) = match value {
91 IpAddr::V4(ip) => (
92 "in-addr".as_bytes(),
93 ip.to_string().into_bytes(),
94 ip.to_string(),
95 ),
96 IpAddr::V6(ip) => {
97 let mut segments = Vec::with_capacity(63);
98 for segment in ip.segments() {
99 for &p in format!("{segment:04x}").as_bytes() {
100 if !segments.is_empty() {
101 segments.push(b'.');
102 }
103 segments.push(p);
104 }
105 }
106 ("ip6".as_bytes(), segments, ip.to_string())
107 }
108 };
109 self.vars[Variable::IpVersion as usize] = v.into();
110 self.vars[Variable::Ip as usize] = i.into();
111 self.vars[Variable::SmtpIp as usize] = c.into_bytes().into();
112 }
113
114 pub fn set_sender(&mut self, value: impl Into<Cow<'x, [u8]>>) {
115 let value = value.into();
116 for (pos, ch) in value.as_ref().iter().enumerate() {
117 if ch == &b'@' {
118 if pos > 0 {
119 self.vars[Variable::SenderLocalPart as usize] = match &value {
120 Cow::Borrowed(value) => (&value[..pos]).into(),
121 Cow::Owned(value) => value[..pos].to_vec().into(),
122 };
123 }
124 self.vars[Variable::SenderDomainPart as usize] = match &value {
125 Cow::Borrowed(value) => (value.get(pos + 1..).unwrap_or_default()).into(),
126 Cow::Owned(value) => (value.get(pos + 1..).unwrap_or_default()).to_vec().into(),
127 };
128 break;
129 }
130 }
131
132 self.vars[Variable::Sender as usize] = value;
133 }
134
135 pub fn set_helo_domain(&mut self, value: impl Into<Cow<'x, [u8]>>) {
136 self.vars[Variable::HeloDomain as usize] = value.into();
137 }
138
139 pub fn set_host_domain(&mut self, value: impl Into<Cow<'x, [u8]>>) {
140 self.vars[Variable::HostDomain as usize] = value.into();
141 }
142
143 pub fn set_validated_domain(&mut self, value: impl Into<Cow<'x, [u8]>>) {
144 self.vars[Variable::ValidatedDomain as usize] = value.into();
145 }
146
147 pub fn set_domain(&mut self, value: impl Into<Cow<'x, [u8]>>) {
148 self.vars[Variable::Domain as usize] = value.into();
149 }
150
151 pub fn get(
152 &self,
153 name: Variable,
154 num_parts: u32,
155 reverse: bool,
156 escape: bool,
157 fqdn: bool,
158 delimiters: u64,
159 ) -> Cow<'_, [u8]> {
160 let var = self.vars[name as usize].as_ref();
161 if var.is_empty()
162 || (num_parts == 0 && !reverse && !escape && delimiters == 1u64 << (b'.' - b'+'))
163 {
164 return var.into();
165 }
166 let mut parts = Vec::new();
167 let mut parts_len = 0;
168 let mut start_pos = 0;
169
170 for (pos, ch) in var.iter().enumerate() {
171 if (b'+'..=b'_').contains(ch) && (delimiters & (1u64 << (*ch - b'+'))) != 0 {
172 parts_len += pos - start_pos + 1;
173 parts.push(&var[start_pos..pos]);
174 start_pos = pos + 1;
175 }
176 }
177 parts.push(&var[start_pos..var.len()]);
178
179 let num_parts = if num_parts == 0 {
180 parts.len()
181 } else {
182 std::cmp::min(parts.len(), num_parts as usize)
183 };
184
185 let mut result = Vec::with_capacity(parts_len + var.len() - start_pos);
186 if !reverse {
187 for (pos, part) in parts.iter().skip(parts.len() - num_parts).enumerate() {
188 add_part(&mut result, part, pos, escape);
189 }
190 } else {
191 for (pos, part) in parts.iter().rev().skip(parts.len() - num_parts).enumerate() {
192 add_part(&mut result, part, pos, escape);
193 }
194 }
195 if fqdn && result.last().unwrap_or(&0) != &b'.' {
196 result.push(b'.');
197 }
198 result.into()
199 }
200}
201
/// Appends `part` to `result`, preceded by a `.` separator unless it is the
/// first part (`pos == 0`).
///
/// When `escape` is set, every byte that is not ASCII alphanumeric or one of
/// `-`, `.`, `_`, `~` is written as `%` followed by its two-digit lowercase
/// hexadecimal value; all other bytes are copied unchanged.
#[inline(always)]
fn add_part(result: &mut Vec<u8>, part: &[u8], pos: usize, escape: bool) {
    // Lookup table so escaping does not allocate a temporary string per byte.
    const HEX: &[u8; 16] = b"0123456789abcdef";

    if pos > 0 {
        result.push(b'.');
    }
    if !escape {
        result.extend_from_slice(part);
        return;
    }
    for &ch in part {
        if ch.is_ascii_alphanumeric() || matches!(ch, b'-' | b'.' | b'_' | b'~') {
            result.push(ch);
        } else {
            result.extend_from_slice(&[
                b'%',
                HEX[usize::from(ch >> 4)],
                HEX[usize::from(ch & 0x0f)],
            ]);
        }
    }
}
219
#[cfg(test)]
mod test {
    use std::net::IpAddr;

    use crate::spf::{Variables, parse::SPFParser};

    /// Expands a table of macro strings against an IPv4 client and then an
    /// IPv6 client, comparing each result with its expected expansion.
    #[test]
    fn expand_macro() {
        // --- IPv4 client; macro strings parsed with `macro_string(false)` ---
        let mut vars = Variables::new();
        vars.set_sender("strong-bad@email.example.com".as_bytes());
        vars.set_ip(&"192.0.2.3".parse::<IpAddr>().unwrap());
        vars.set_validated_domain("mx.example.org".as_bytes());
        vars.set_domain("email.example.com".as_bytes());
        vars.set_helo_domain("....".as_bytes());

        let ipv4_cases = [
            // Whole values.
            ("%{s}", "strong-bad@email.example.com"),
            ("%{o}", "email.example.com"),
            ("%{d}", "email.example.com"),
            // Part counts and reversal on the domain.
            ("%{d4}", "email.example.com"),
            ("%{d3}", "email.example.com"),
            ("%{d2}", "example.com"),
            ("%{d1}", "com"),
            ("%{dr}", "com.example.email"),
            ("%{d2r}", "example.email"),
            // Custom delimiter on the local part.
            ("%{l}", "strong-bad"),
            ("%{l-}", "strong.bad"),
            ("%{lr}", "strong-bad"),
            ("%{lr-}", "bad.strong"),
            ("%{l1r-}", "strong"),
            ("%{p1r}", "mx"),
            // A value consisting only of delimiters.
            ("%{h3r}", ".."),
            // Mixed literal/variable lists.
            (
                "%{ir}.%{v}._spf.%{d2}",
                "3.2.0.192.in-addr._spf.example.com",
            ),
            ("%{lr-}.lp._spf.%{d2}", "bad.strong.lp._spf.example.com"),
            (
                "%{lr-}.lp.%{ir}.%{v}._spf.%{d2}",
                "bad.strong.lp.3.2.0.192.in-addr._spf.example.com",
            ),
            (
                "%{ir}.%{v}.%{l1r-}.lp._spf.%{d2}",
                "3.2.0.192.in-addr.strong.lp._spf.example.com",
            ),
            (
                "%{d2}.trusted-domains.example.net",
                "example.com.trusted-domains.example.net",
            ),
        ];
        for (macro_string, expansion) in ipv4_cases {
            let (parsed, _) = macro_string.as_bytes().iter().macro_string(false).unwrap();
            let actual = parsed.eval(&vars, "", false);
            assert_eq!(actual, expansion, "{macro_string:?}");
        }

        // --- IPv6 client; macro strings parsed with `macro_string(true)` ---
        let mut vars = Variables::new();
        vars.set_sender("strong-bad@email.example.com".as_bytes());
        vars.set_ip(&"2001:db8::cb01".parse::<IpAddr>().unwrap());
        vars.set_validated_domain("mx.example.org".as_bytes());
        vars.set_domain("email.example.com".as_bytes());

        let ipv6_cases = [
            (
                "%{ir}.%{v}._spf.%{d2}",
                concat!(
                    "1.0.b.c.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.",
                    "0.0.0.0.0.8.b.d.0.1.0.0.2.ip6._spf.example.com"
                ),
            ),
            ("%{c}", "2001:db8::cb01"),
            (
                "%{c} is not one of %{d}'s designated mail servers.",
                "2001:db8::cb01 is not one of email.example.com's designated mail servers.",
            ),
            (
                "See http://%{d}/why.html?s=%{S}&i=%{C}",
                concat!(
                    "See http://email.example.com/why.html?",
                    "s=strong-bad%40email.example.com&i=2001%3adb8%3a%3acb01"
                ),
            ),
        ];
        for (macro_string, expansion) in ipv6_cases {
            let (parsed, _) = macro_string.as_bytes().iter().macro_string(true).unwrap();
            let actual = parsed.eval(&vars, "", false);
            assert_eq!(actual, expansion, "{macro_string:?}");
        }
    }
}