1use std::{borrow::Cow, net::IpAddr, time::SystemTime};
8
9use super::{Macro, Variable, Variables};
10
11impl Macro {
12 pub fn eval<'z, 'x: 'z>(
13 &'z self,
14 vars: &'x Variables<'x>,
15 default: &'x str,
16 fqdn: bool,
17 ) -> Cow<'z, str> {
18 match self {
19 Macro::Literal(literal) => std::str::from_utf8(literal).unwrap_or_default().into(),
20 Macro::Variable {
21 letter,
22 num_parts,
23 reverse,
24 escape,
25 delimiters,
26 } => match vars.get(*letter, *num_parts, *reverse, *escape, fqdn, *delimiters) {
27 Cow::Borrowed(bytes) => std::str::from_utf8(bytes).unwrap_or_default().into(),
28 Cow::Owned(bytes) => String::from_utf8(bytes).unwrap_or_default().into(),
29 },
30 Macro::List(list) => {
31 let mut result = Vec::with_capacity(32);
32 for item in list {
33 match item {
34 Macro::Literal(literal) => {
35 result.extend_from_slice(literal);
36 }
37 Macro::Variable {
38 letter,
39 num_parts,
40 reverse,
41 escape,
42 delimiters,
43 } => {
44 result.extend_from_slice(
45 vars.get(
46 *letter,
47 *num_parts,
48 *reverse,
49 *escape,
50 false,
51 *delimiters,
52 )
53 .as_ref(),
54 );
55 }
56 Macro::List(_) | Macro::None => unreachable!(),
57 }
58 }
59 if fqdn && !result.is_empty() && result.last().unwrap() != &b'.' {
60 result.push(b'.');
61 }
62 String::from_utf8(result).unwrap_or_default().into()
63 }
64 Macro::None => default.into(),
65 }
66 }
67
68 pub fn needs_ptr(&self) -> bool {
69 match self {
70 Macro::Variable { letter, .. } => *letter == Variable::ValidatedDomain,
71 Macro::List(list) => list.iter().any(|m| matches!(m, Macro::Variable { letter, .. } if *letter == Variable::ValidatedDomain)),
72 _ => false,
73 }
74 }
75}
76
77impl<'x> Variables<'x> {
78 pub fn new() -> Self {
79 let mut vars = Variables::default();
80 vars.vars[Variable::CurrentTime as usize] = SystemTime::now()
81 .duration_since(SystemTime::UNIX_EPOCH)
82 .map(|d| d.as_secs())
83 .unwrap_or(0)
84 .to_string()
85 .into_bytes()
86 .into();
87 vars
88 }
89
90 pub fn set_ip(&mut self, value: &IpAddr) {
91 let (v, i, c) = match value {
92 IpAddr::V4(ip) => (
93 "in-addr".as_bytes(),
94 ip.to_string().into_bytes(),
95 ip.to_string(),
96 ),
97 IpAddr::V6(ip) => {
98 let mut segments = Vec::with_capacity(63);
99 for segment in ip.segments() {
100 for &p in format!("{segment:04x}").as_bytes() {
101 if !segments.is_empty() {
102 segments.push(b'.');
103 }
104 segments.push(p);
105 }
106 }
107 ("ip6".as_bytes(), segments, ip.to_string())
108 }
109 };
110 self.vars[Variable::IpVersion as usize] = v.into();
111 self.vars[Variable::Ip as usize] = i.into();
112 self.vars[Variable::SmtpIp as usize] = c.into_bytes().into();
113 }
114
115 pub fn set_sender(&mut self, value: impl Into<Cow<'x, [u8]>>) {
116 let value = value.into();
117 for (pos, ch) in value.as_ref().iter().enumerate() {
118 if ch == &b'@' {
119 if pos > 0 {
120 self.vars[Variable::SenderLocalPart as usize] = match &value {
121 Cow::Borrowed(value) => (&value[..pos]).into(),
122 Cow::Owned(value) => value[..pos].to_vec().into(),
123 };
124 }
125 self.vars[Variable::SenderDomainPart as usize] = match &value {
126 Cow::Borrowed(value) => (value.get(pos + 1..).unwrap_or_default()).into(),
127 Cow::Owned(value) => (value.get(pos + 1..).unwrap_or_default()).to_vec().into(),
128 };
129 break;
130 }
131 }
132
133 self.vars[Variable::Sender as usize] = value;
134 }
135
136 pub fn set_helo_domain(&mut self, value: impl Into<Cow<'x, [u8]>>) {
137 self.vars[Variable::HeloDomain as usize] = value.into();
138 }
139
140 pub fn set_host_domain(&mut self, value: impl Into<Cow<'x, [u8]>>) {
141 self.vars[Variable::HostDomain as usize] = value.into();
142 }
143
144 pub fn set_validated_domain(&mut self, value: impl Into<Cow<'x, [u8]>>) {
145 self.vars[Variable::ValidatedDomain as usize] = value.into();
146 }
147
148 pub fn set_domain(&mut self, value: impl Into<Cow<'x, [u8]>>) {
149 self.vars[Variable::Domain as usize] = value.into();
150 }
151
152 pub fn get(
153 &self,
154 name: Variable,
155 num_parts: u32,
156 reverse: bool,
157 escape: bool,
158 fqdn: bool,
159 delimiters: u64,
160 ) -> Cow<'_, [u8]> {
161 let var = self.vars[name as usize].as_ref();
162 if var.is_empty()
163 || (num_parts == 0 && !reverse && !escape && delimiters == 1u64 << (b'.' - b'+'))
164 {
165 return var.into();
166 }
167 let mut parts = Vec::new();
168 let mut parts_len = 0;
169 let mut start_pos = 0;
170
171 for (pos, ch) in var.iter().enumerate() {
172 if (b'+'..=b'_').contains(ch) && (delimiters & (1u64 << (*ch - b'+'))) != 0 {
173 parts_len += pos - start_pos + 1;
174 parts.push(&var[start_pos..pos]);
175 start_pos = pos + 1;
176 }
177 }
178 parts.push(&var[start_pos..var.len()]);
179
180 let num_parts = if num_parts == 0 {
181 parts.len()
182 } else {
183 std::cmp::min(parts.len(), num_parts as usize)
184 };
185
186 let mut result = Vec::with_capacity(parts_len + var.len() - start_pos);
187 if !reverse {
188 for (pos, part) in parts.iter().skip(parts.len() - num_parts).enumerate() {
189 add_part(&mut result, part, pos, escape);
190 }
191 } else {
192 for (pos, part) in parts.iter().rev().skip(parts.len() - num_parts).enumerate() {
193 add_part(&mut result, part, pos, escape);
194 }
195 }
196 if fqdn && result.last().unwrap_or(&0) != &b'.' {
197 result.push(b'.');
198 }
199 result.into()
200 }
201}
202
/// Appends one macro part to `result`, percent-encoding it when `escape` is set.
///
/// Every part after the first (`pos > 0`) is preceded by a `.` separator. With
/// `escape`, each byte outside the unreserved set (`A-Z a-z 0-9 - . _ ~`) is written
/// as `%` followed by two lowercase hex digits.
#[inline(always)]
fn add_part(result: &mut Vec<u8>, part: &[u8], pos: usize, escape: bool) {
    const HEX_DIGITS: &[u8; 16] = b"0123456789abcdef";

    if pos > 0 {
        result.push(b'.');
    }
    if !escape {
        result.extend_from_slice(part);
        return;
    }
    for &byte in part {
        if byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'.' | b'_' | b'~') {
            result.push(byte);
        } else {
            // Encode straight into the buffer rather than allocating a String per byte.
            result.extend_from_slice(&[
                b'%',
                HEX_DIGITS[usize::from(byte >> 4)],
                HEX_DIGITS[usize::from(byte & 0x0f)],
            ]);
        }
    }
}
220
#[cfg(test)]
mod test {
    use std::net::IpAddr;

    use crate::spf::{parse::SPFParser, Variables};

    /// Parses each macro-string in `cases` with `macro_string(flag)` and checks that
    /// evaluating it against `vars` (empty default, `fqdn = false`) yields the paired text.
    fn check_expansions(vars: &Variables<'_>, flag: bool, cases: &[(&str, &str)]) {
        for &(macro_string, expansion) in cases {
            let (parsed, _) = macro_string.as_bytes().iter().macro_string(flag).unwrap();
            assert_eq!(parsed.eval(vars, "", false), expansion, "{macro_string:?}");
        }
    }

    #[test]
    fn expand_macro() {
        // IPv4 client; the HELO name is made only of dots so that empty parts are exercised.
        let mut vars = Variables::new();
        vars.set_sender("strong-bad@email.example.com".as_bytes());
        vars.set_ip(&"192.0.2.3".parse::<IpAddr>().unwrap());
        vars.set_validated_domain("mx.example.org".as_bytes());
        vars.set_domain("email.example.com".as_bytes());
        vars.set_helo_domain("....".as_bytes());
        check_expansions(
            &vars,
            false,
            &[
                ("%{s}", "strong-bad@email.example.com"),
                ("%{o}", "email.example.com"),
                ("%{d}", "email.example.com"),
                ("%{d4}", "email.example.com"),
                ("%{d3}", "email.example.com"),
                ("%{d2}", "example.com"),
                ("%{d1}", "com"),
                ("%{dr}", "com.example.email"),
                ("%{d2r}", "example.email"),
                ("%{l}", "strong-bad"),
                ("%{l-}", "strong.bad"),
                ("%{lr}", "strong-bad"),
                ("%{lr-}", "bad.strong"),
                ("%{l1r-}", "strong"),
                ("%{p1r}", "mx"),
                ("%{h3r}", ".."),
                (
                    "%{ir}.%{v}._spf.%{d2}",
                    "3.2.0.192.in-addr._spf.example.com",
                ),
                ("%{lr-}.lp._spf.%{d2}", "bad.strong.lp._spf.example.com"),
                (
                    "%{lr-}.lp.%{ir}.%{v}._spf.%{d2}",
                    "bad.strong.lp.3.2.0.192.in-addr._spf.example.com",
                ),
                (
                    "%{ir}.%{v}.%{l1r-}.lp._spf.%{d2}",
                    "3.2.0.192.in-addr.strong.lp._spf.example.com",
                ),
                (
                    "%{d2}.trusted-domains.example.net",
                    "example.com.trusted-domains.example.net",
                ),
            ],
        );

        // IPv6 client, including `%{c}` and upper-case (escaped) macro letters.
        let mut vars = Variables::new();
        vars.set_sender("strong-bad@email.example.com".as_bytes());
        vars.set_ip(&"2001:db8::cb01".parse::<IpAddr>().unwrap());
        vars.set_validated_domain("mx.example.org".as_bytes());
        vars.set_domain("email.example.com".as_bytes());
        check_expansions(
            &vars,
            true,
            &[
                (
                    "%{ir}.%{v}._spf.%{d2}",
                    concat!(
                        "1.0.b.c.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.",
                        "0.0.0.0.0.8.b.d.0.1.0.0.2.ip6._spf.example.com"
                    ),
                ),
                ("%{c}", "2001:db8::cb01"),
                (
                    "%{c} is not one of %{d}'s designated mail servers.",
                    "2001:db8::cb01 is not one of email.example.com's designated mail servers.",
                ),
                (
                    "See http://%{d}/why.html?s=%{S}&i=%{C}",
                    concat!(
                        "See http://email.example.com/why.html?",
                        "s=strong-bad%40email.example.com&i=2001%3adb8%3a%3acb01"
                    ),
                ),
            ],
        );
    }
}