1use std::sync::Arc;
9
10use crate::{
11 handler::{HandlerResult, MessageHandler},
12 transformer::MessageTransformer,
13 EmailMessage,
14};
15
/// How a [`RoutingRule`]'s `pattern` is compared against a recipient address.
///
/// [`MessageRouter::new`] sorts rules by this kind, so when several rules could
/// match the same recipient, `ExactAddress` rules are tried before `Domain`
/// rules, which are tried before `WildcardDomain` rules.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum MatchType {
    /// The full recipient address must equal the pattern (ASCII case-insensitive).
    ExactAddress,
    /// The recipient's domain part must equal the pattern (ASCII case-insensitive).
    Domain,
    /// The recipient's domain must equal the pattern with an optional leading `*.`
    /// removed, or be a subdomain of it (ASCII case-insensitive).
    WildcardDomain,
}
26
/// A single routing rule: recipients matched by `pattern` (interpreted according
/// to `match_type`) are dispatched to `handler`.
pub struct RoutingRule {
    /// How `pattern` is compared against a recipient address.
    pub match_type: MatchType,
    /// The address, domain, or `*.`-prefixed domain this rule matches.
    pub pattern: String,
    /// Handler returned by `MessageRouter::resolve` when this is the first matching rule.
    pub handler: Arc<dyn MessageHandler>,
    /// Transformers run before the handler for recipients matched by this rule;
    /// when empty, the router's default transformers are used instead.
    pub transformers: Vec<Box<dyn MessageTransformer>>,
    /// Per-rule authentication override: `Some(b)` forces `b`; `None` falls back to
    /// the global default passed to `MessageRouter::resolve_auth_required`.
    pub auth_required: Option<bool>,
}
35
36impl RoutingRule {
37 pub fn matches(&self, address: &str) -> bool {
39 match self.match_type {
40 MatchType::ExactAddress => address.eq_ignore_ascii_case(&self.pattern),
41 MatchType::Domain => {
42 if let Some(domain) = address.rsplit('@').next() {
43 domain.eq_ignore_ascii_case(&self.pattern)
44 } else {
45 false
46 }
47 }
48 MatchType::WildcardDomain => {
49 let wildcard = self.pattern.strip_prefix("*.").unwrap_or(&self.pattern);
50 if let Some(domain) = address.rsplit('@').next() {
51 domain.eq_ignore_ascii_case(wildcard)
53 || domain
54 .to_ascii_lowercase()
55 .ends_with(&format!(".{}", wildcard.to_ascii_lowercase()))
56 } else {
57 false
58 }
59 }
60 }
61 }
62}
63
/// Dispatches messages to handlers based on the recipient address.
///
/// Rules are kept sorted by specificity (exact address, then domain, then
/// wildcard domain). The first matching rule wins, and recipients that match no
/// rule go to `default_handler`.
pub struct MessageRouter {
    /// Routing rules, sorted by specificity in `new`.
    rules: Vec<RoutingRule>,
    /// Handler for recipients that match no rule.
    default_handler: Arc<dyn MessageHandler>,
    /// Transformers used when no rule matches, or when the first matching rule has
    /// an empty transformer list.
    default_transformers: Vec<Box<dyn MessageTransformer>>,
}
73
74impl std::fmt::Debug for MessageRouter {
75 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
76 f.debug_struct("MessageRouter")
77 .field("rules", &self.rules.len())
78 .finish()
79 }
80}
81
82impl MessageRouter {
83 pub fn new(
88 mut rules: Vec<RoutingRule>,
89 default_handler: Arc<dyn MessageHandler>,
90 default_transformers: Vec<Box<dyn MessageTransformer>>,
91 ) -> Self {
92 rules.sort_by_key(|r| match r.match_type {
94 MatchType::ExactAddress => 0,
95 MatchType::Domain => 1,
96 MatchType::WildcardDomain => 2,
97 });
98 Self {
99 rules,
100 default_handler,
101 default_transformers,
102 }
103 }
104
105 pub fn resolve(&self, recipient: &str) -> &Arc<dyn MessageHandler> {
107 for rule in &self.rules {
108 if rule.matches(recipient) {
109 return &rule.handler;
110 }
111 }
112 &self.default_handler
113 }
114
115 fn resolve_transformers(&self, recipient: &str) -> &[Box<dyn MessageTransformer>] {
120 for rule in &self.rules {
121 if rule.matches(recipient) {
122 if !rule.transformers.is_empty() {
123 return &rule.transformers;
124 }
125 return &self.default_transformers;
126 }
127 }
128 &self.default_transformers
129 }
130
131 pub async fn route(&self, message: &mut EmailMessage) -> HandlerResult<()> {
135 let transformers = self.resolve_transformers(&message.to);
136 <crate::MessageIdTransformer as MessageTransformer>::apply(transformers, message).await;
137
138 let handler = self.resolve(&message.to);
139 handler.handle(message).await
140 }
141
142 pub fn resolve_auth_required(&self, recipient: &str, global_default: bool) -> bool {
147 for rule in &self.rules {
148 if rule.matches(recipient) {
149 if let Some(auth_req) = rule.auth_required {
150 return auth_req;
151 }
152 break;
153 }
154 }
155 global_default
156 }
157
158 pub fn default_handler(&self) -> &Arc<dyn MessageHandler> {
160 &self.default_handler
161 }
162
163 pub fn rejection_for(&self, recipient: &str) -> Option<(u16, String)> {
170 self.resolve(recipient).reject_reply()
171 }
172}
173
174pub fn determine_match_type(address: &Option<String>, domain: &Option<String>) -> MatchType {
180 if address.is_some() {
181 MatchType::ExactAddress
182 } else if let Some(d) = domain {
183 if d.starts_with("*.") {
184 MatchType::WildcardDomain
185 } else {
186 MatchType::Domain
187 }
188 } else {
189 MatchType::Domain
190 }
191}
192
/// Returns the pattern string for a rule: `address` if present, else `domain`,
/// else an empty string.
///
/// This uses the same address-over-domain precedence as [`determine_match_type`],
/// so the pattern always corresponds to the chosen match type.
pub fn extract_pattern(address: &Option<String>, domain: &Option<String>) -> String {
    match (address, domain) {
        (Some(addr), _) => addr.clone(),
        (None, Some(dom)) => dom.clone(),
        (None, None) => String::new(),
    }
}
201
// Unit tests for rule matching, router dispatch order, auth resolution, and the
// configuration helpers `determine_match_type` / `extract_pattern`.
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicUsize, Ordering};

    use super::*;
    use crate::handler::HandlerFuture;

    /// Test handler that records how many messages it has been asked to handle.
    struct CountingHandler {
        name: &'static str,
        count: AtomicUsize,
    }

    impl CountingHandler {
        fn new(name: &'static str) -> Self {
            Self {
                name,
                count: AtomicUsize::new(0),
            }
        }

        /// Number of `handle` calls so far.
        fn count(&self) -> usize {
            self.count.load(Ordering::SeqCst)
        }
    }

    impl MessageHandler for CountingHandler {
        // Increments the counter and always succeeds.
        fn handle<'a>(&'a self, _message: &'a EmailMessage) -> HandlerFuture<'a> {
            Box::pin(async move {
                self.count.fetch_add(1, Ordering::SeqCst);
                Ok(())
            })
        }

        fn name(&self) -> &str {
            self.name
        }
    }

    // Exact-address rules compare the whole address, ignoring ASCII case.
    #[test]
    fn test_exact_address_match() {
        let handler = Arc::new(CountingHandler::new("test"));
        let rule = RoutingRule {
            match_type: MatchType::ExactAddress,
            pattern: "admin@example.com".to_string(),
            handler,
            transformers: vec![],
            auth_required: None,
        };

        assert!(rule.matches("admin@example.com"));
        assert!(rule.matches("ADMIN@EXAMPLE.COM"));
        assert!(!rule.matches("user@example.com"));
        assert!(!rule.matches("admin@other.com"));
    }

    // Domain rules match any local part, but not subdomains.
    #[test]
    fn test_domain_match() {
        let handler = Arc::new(CountingHandler::new("test"));
        let rule = RoutingRule {
            match_type: MatchType::Domain,
            pattern: "example.com".to_string(),
            handler,
            transformers: vec![],
            auth_required: None,
        };

        assert!(rule.matches("user@example.com"));
        assert!(rule.matches("admin@example.com"));
        assert!(rule.matches("user@EXAMPLE.COM"));
        assert!(!rule.matches("user@other.com"));
        assert!(!rule.matches("user@sub.example.com"));
    }

    // Wildcard rules match subdomains at any depth, and also the bare base domain.
    #[test]
    fn test_wildcard_domain_match() {
        let handler = Arc::new(CountingHandler::new("test"));
        let rule = RoutingRule {
            match_type: MatchType::WildcardDomain,
            pattern: "*.example.com".to_string(),
            handler,
            transformers: vec![],
            auth_required: None,
        };

        assert!(rule.matches("user@sub.example.com"));
        assert!(rule.matches("user@deep.sub.example.com"));
        assert!(rule.matches("user@example.com"));
        assert!(!rule.matches("user@other.com"));
    }

    // The router must prefer the exact-address rule even though it is listed
    // after the domain rule, and must fall back to the default handler.
    #[tokio::test]
    async fn test_router_specificity_order() {
        let exact_handler = Arc::new(CountingHandler::new("exact"));
        let domain_handler = Arc::new(CountingHandler::new("domain"));
        let default_handler = Arc::new(CountingHandler::new("default"));

        // Deliberately listed least-specific first, so `MessageRouter::new` has to
        // reorder them.
        let rules = vec![
            RoutingRule {
                match_type: MatchType::Domain,
                pattern: "example.com".to_string(),
                handler: domain_handler.clone(),
                transformers: vec![],
                auth_required: None,
            },
            RoutingRule {
                match_type: MatchType::ExactAddress,
                pattern: "admin@example.com".to_string(),
                handler: exact_handler.clone(),
                transformers: vec![],
                auth_required: None,
            },
        ];

        let router = MessageRouter::new(rules, default_handler.clone(), vec![]);

        // Matches both rules; the exact-address rule must win.
        let mut msg = EmailMessage::from_raw("sender@test.com", "admin@example.com", "test");
        router.route(&mut msg).await.unwrap();
        assert_eq!(exact_handler.count(), 1);
        assert_eq!(domain_handler.count(), 0);

        // Only the domain rule matches.
        let mut msg = EmailMessage::from_raw("sender@test.com", "user@example.com", "test");
        router.route(&mut msg).await.unwrap();
        assert_eq!(domain_handler.count(), 1);

        // No rule matches, so the default handler receives the message.
        let mut msg = EmailMessage::from_raw("sender@test.com", "user@other.com", "test");
        router.route(&mut msg).await.unwrap();
        assert_eq!(default_handler.count(), 1);
    }

    // Per-rule `auth_required` overrides the global default; `None` or no match
    // defers to it.
    #[test]
    fn test_resolve_auth_required() {
        let handler = Arc::new(CountingHandler::new("test"));
        let rules = vec![
            RoutingRule {
                match_type: MatchType::ExactAddress,
                pattern: "secure@example.com".to_string(),
                handler: handler.clone(),
                transformers: vec![],
                auth_required: Some(true),
            },
            RoutingRule {
                match_type: MatchType::Domain,
                pattern: "open.com".to_string(),
                handler: handler.clone(),
                transformers: vec![],
                auth_required: Some(false),
            },
            RoutingRule {
                match_type: MatchType::Domain,
                pattern: "default.com".to_string(),
                handler: handler.clone(),
                transformers: vec![],
                auth_required: None,
            },
        ];

        let router = MessageRouter::new(rules, handler, vec![]);

        // `Some(true)` forces auth even when the global default is off.
        assert!(router.resolve_auth_required("secure@example.com", false));

        // `Some(false)` disables auth even when the global default is on.
        assert!(!router.resolve_auth_required("user@open.com", true));

        // A matching rule with `None` follows the global default either way.
        assert!(router.resolve_auth_required("user@default.com", true));
        assert!(!router.resolve_auth_required("user@default.com", false));

        // No matching rule: the global default applies.
        assert!(router.resolve_auth_required("user@unknown.com", true));
        assert!(!router.resolve_auth_required("user@unknown.com", false));
    }

    // An address selects an exact match; a domain selects wildcard or plain
    // domain matching depending on a leading `*.`.
    #[test]
    fn test_determine_match_type() {
        assert_eq!(
            determine_match_type(&Some("user@test.com".to_string()), &None),
            MatchType::ExactAddress
        );
        assert_eq!(
            determine_match_type(&None, &Some("example.com".to_string())),
            MatchType::Domain
        );
        assert_eq!(
            determine_match_type(&None, &Some("*.example.com".to_string())),
            MatchType::WildcardDomain
        );
    }

    // The pattern is taken from whichever of address/domain is present.
    #[test]
    fn test_extract_pattern() {
        assert_eq!(
            extract_pattern(&Some("user@test.com".to_string()), &None),
            "user@test.com"
        );
        assert_eq!(
            extract_pattern(&None, &Some("example.com".to_string())),
            "example.com"
        );
    }
}