1use std::sync::Arc;
9
10use crate::{
11 handler::{HandlerResult, MessageHandler},
12 transformer::MessageTransformer,
13 EmailMessage,
14};
15
/// Kind of comparison a routing rule performs against a recipient address.
///
/// Variants are declared in the order `MessageRouter::new` tries them,
/// from most to least specific.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum MatchType {
    /// `pattern` is a full address compared with the whole recipient
    /// (ASCII case-insensitive).
    ExactAddress,
    /// `pattern` is a domain compared with the recipient's domain part.
    Domain,
    /// `pattern` looks like `*.example.com`; the bare domain and any of its
    /// subdomains match.
    WildcardDomain,
    /// Matches every recipient; the pattern is not consulted.
    CatchAll,
}
28
29pub struct RoutingRule {
31 pub match_type: MatchType,
32 pub pattern: String,
33 pub handler: Arc<dyn MessageHandler>,
34 pub transformers: Vec<Box<dyn MessageTransformer>>,
35 pub auth_required: Option<bool>,
36}
37
38impl RoutingRule {
39 pub fn matches(&self, address: &str) -> bool {
41 match self.match_type {
42 MatchType::ExactAddress => address.eq_ignore_ascii_case(&self.pattern),
43 MatchType::Domain => {
44 if let Some(domain) = address.rsplit('@').next() {
45 domain.eq_ignore_ascii_case(&self.pattern)
46 } else {
47 false
48 }
49 }
50 MatchType::WildcardDomain => {
51 let wildcard = self.pattern.strip_prefix("*.").unwrap_or(&self.pattern);
52 if let Some(domain) = address.rsplit('@').next() {
53 domain.eq_ignore_ascii_case(wildcard)
55 || domain
56 .to_ascii_lowercase()
57 .ends_with(&format!(".{}", wildcard.to_ascii_lowercase()))
58 } else {
59 false
60 }
61 }
62 MatchType::CatchAll => true,
63 }
64 }
65}
66
67pub struct MessageRouter {
72 rules: Vec<RoutingRule>,
73 default_handler: Arc<dyn MessageHandler>,
74 default_transformers: Vec<Box<dyn MessageTransformer>>,
75}
76
77impl std::fmt::Debug for MessageRouter {
78 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
79 f.debug_struct("MessageRouter")
80 .field("rules", &self.rules.len())
81 .finish()
82 }
83}
84
85impl MessageRouter {
86 pub fn new(
91 mut rules: Vec<RoutingRule>,
92 default_handler: Arc<dyn MessageHandler>,
93 default_transformers: Vec<Box<dyn MessageTransformer>>,
94 ) -> Self {
95 rules.sort_by_key(|r| match r.match_type {
97 MatchType::ExactAddress => 0,
98 MatchType::Domain => 1,
99 MatchType::WildcardDomain => 2,
100 MatchType::CatchAll => 3,
101 });
102 Self {
103 rules,
104 default_handler,
105 default_transformers,
106 }
107 }
108
109 pub fn resolve(&self, recipient: &str) -> &Arc<dyn MessageHandler> {
111 for rule in &self.rules {
112 if rule.matches(recipient) {
113 return &rule.handler;
114 }
115 }
116 &self.default_handler
117 }
118
119 fn resolve_transformers(&self, recipient: &str) -> &[Box<dyn MessageTransformer>] {
124 for rule in &self.rules {
125 if rule.matches(recipient) {
126 if !rule.transformers.is_empty() {
127 return &rule.transformers;
128 }
129 return &self.default_transformers;
130 }
131 }
132 &self.default_transformers
133 }
134
135 pub async fn route(&self, message: &mut EmailMessage) -> HandlerResult<()> {
139 let transformers = self.resolve_transformers(&message.to);
140 <crate::MessageIdTransformer as MessageTransformer>::apply(transformers, message).await;
141
142 let handler = self.resolve(&message.to);
143 handler.handle(message).await
144 }
145
146 pub fn resolve_auth_required(&self, recipient: &str, global_default: bool) -> bool {
151 for rule in &self.rules {
152 if rule.matches(recipient) {
153 if let Some(auth_req) = rule.auth_required {
154 return auth_req;
155 }
156 break;
157 }
158 }
159 global_default
160 }
161
162 pub fn rejection_for(&self, recipient: &str) -> Option<(u16, String)> {
169 self.resolve(recipient).reject_reply()
170 }
171
172 pub fn default_handler(&self) -> &Arc<dyn MessageHandler> {
174 &self.default_handler
175 }
176}
177
178pub fn determine_match_type(address: &Option<String>, domain: &Option<String>) -> MatchType {
185 if address.is_some() {
186 MatchType::ExactAddress
187 } else if let Some(d) = domain {
188 if d == "*" {
189 MatchType::CatchAll
190 } else if d.starts_with("*.") {
191 MatchType::WildcardDomain
192 } else {
193 MatchType::Domain
194 }
195 } else {
196 MatchType::Domain
197 }
198}
199
/// Returns the pattern text for a rule.
///
/// This is the configured address when present, otherwise the configured
/// domain, otherwise an empty string.
pub fn extract_pattern(address: &Option<String>, domain: &Option<String>) -> String {
    match (address, domain) {
        (Some(addr), _) => addr.clone(),
        (None, Some(dom)) => dom.clone(),
        (None, None) => String::new(),
    }
}
208
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicUsize, Ordering};

    use super::*;
    use crate::handler::HandlerFuture;

    /// Handler stub that counts how many messages were dispatched to it.
    struct CountingHandler {
        name: &'static str,
        count: AtomicUsize,
    }

    impl CountingHandler {
        fn new(name: &'static str) -> Self {
            let count = AtomicUsize::new(0);
            Self { name, count }
        }

        fn count(&self) -> usize {
            self.count.load(Ordering::SeqCst)
        }
    }

    impl MessageHandler for CountingHandler {
        fn handle<'a>(&'a self, _message: &'a EmailMessage) -> HandlerFuture<'a> {
            Box::pin(async move {
                self.count.fetch_add(1, Ordering::SeqCst);
                Ok(())
            })
        }

        fn name(&self) -> &str {
            self.name
        }
    }

    /// A fresh counting handler behind an `Arc`.
    fn counter(name: &'static str) -> Arc<CountingHandler> {
        Arc::new(CountingHandler::new(name))
    }

    /// A rule with no transformers of its own.
    fn rule(
        match_type: MatchType,
        pattern: &str,
        handler: Arc<CountingHandler>,
        auth_required: Option<bool>,
    ) -> RoutingRule {
        RoutingRule {
            match_type,
            pattern: pattern.to_string(),
            handler,
            transformers: vec![],
            auth_required,
        }
    }

    /// A router with no default transformers.
    fn build_router(rules: Vec<RoutingRule>, default_handler: Arc<CountingHandler>) -> MessageRouter {
        MessageRouter::new(rules, default_handler, vec![])
    }

    /// Routes a minimal message addressed to `to`, panicking if routing fails.
    async fn deliver(router: &MessageRouter, to: &str) {
        let mut msg = EmailMessage::from_raw("sender@test.com", to, "test");
        router.route(&mut msg).await.unwrap();
    }

    #[test]
    fn test_exact_address_match() {
        let r = rule(MatchType::ExactAddress, "admin@example.com", counter("test"), None);

        for hit in ["admin@example.com", "ADMIN@EXAMPLE.COM"] {
            assert!(r.matches(hit), "expected {} to match", hit);
        }
        for miss in ["user@example.com", "admin@other.com"] {
            assert!(!r.matches(miss), "expected {} not to match", miss);
        }
    }

    #[test]
    fn test_domain_match() {
        let r = rule(MatchType::Domain, "example.com", counter("test"), None);

        for hit in ["user@example.com", "admin@example.com", "user@EXAMPLE.COM"] {
            assert!(r.matches(hit), "expected {} to match", hit);
        }
        // A plain domain rule does not cover subdomains.
        for miss in ["user@other.com", "user@sub.example.com"] {
            assert!(!r.matches(miss), "expected {} not to match", miss);
        }
    }

    #[test]
    fn test_wildcard_domain_match() {
        let r = rule(MatchType::WildcardDomain, "*.example.com", counter("test"), None);

        // The wildcard covers nested subdomains and the bare domain itself.
        for hit in ["user@sub.example.com", "user@deep.sub.example.com", "user@example.com"] {
            assert!(r.matches(hit), "expected {} to match", hit);
        }
        assert!(!r.matches("user@other.com"));
    }

    #[test]
    fn test_catch_all_match() {
        let r = rule(MatchType::CatchAll, "*", counter("test"), None);

        for any in [
            "user@example.com",
            "user@other.com",
            "admin@deep.sub.example.com",
            "malformed-no-at-sign",
            "",
        ] {
            assert!(r.matches(any), "expected {:?} to match", any);
        }
    }

    #[tokio::test]
    async fn test_router_specificity_order() {
        let exact = counter("exact");
        let domain = counter("domain");
        let fallback = counter("default");

        // Supplied least-specific first; the router must still try the exact rule first.
        let r = build_router(
            vec![
                rule(MatchType::Domain, "example.com", Arc::clone(&domain), None),
                rule(MatchType::ExactAddress, "admin@example.com", Arc::clone(&exact), None),
            ],
            Arc::clone(&fallback),
        );

        deliver(&r, "admin@example.com").await;
        assert_eq!(exact.count(), 1);
        assert_eq!(domain.count(), 0);

        deliver(&r, "user@example.com").await;
        assert_eq!(domain.count(), 1);

        deliver(&r, "user@other.com").await;
        assert_eq!(fallback.count(), 1);
    }

    #[tokio::test]
    async fn test_router_catch_all_fallback() {
        let domain = counter("domain");
        let catch_all = counter("catch_all");
        let fallback = counter("default");

        let r = build_router(
            vec![
                rule(MatchType::CatchAll, "*", Arc::clone(&catch_all), None),
                rule(MatchType::Domain, "example.com", Arc::clone(&domain), None),
            ],
            Arc::clone(&fallback),
        );

        deliver(&r, "user@example.com").await;
        assert_eq!(domain.count(), 1);
        assert_eq!(catch_all.count(), 0);

        // The catch-all rule pre-empts the router's default handler.
        deliver(&r, "user@other.com").await;
        assert_eq!(catch_all.count(), 1);
        assert_eq!(fallback.count(), 0);
    }

    #[test]
    fn test_resolve_auth_required() {
        let handler = counter("test");
        let r = build_router(
            vec![
                rule(MatchType::ExactAddress, "secure@example.com", Arc::clone(&handler), Some(true)),
                rule(MatchType::Domain, "open.com", Arc::clone(&handler), Some(false)),
                rule(MatchType::Domain, "default.com", Arc::clone(&handler), None),
            ],
            handler,
        );

        // An explicit per-rule setting overrides the global default either way.
        assert!(r.resolve_auth_required("secure@example.com", false));
        assert!(!r.resolve_auth_required("user@open.com", true));

        // A matching rule without a setting, or no matching rule, defers to the global default.
        for global in [true, false] {
            assert_eq!(r.resolve_auth_required("user@default.com", global), global);
            assert_eq!(r.resolve_auth_required("user@unknown.com", global), global);
        }
    }

    #[test]
    fn test_catch_all_auth_required_override() {
        let handler = counter("test");
        let r = build_router(
            vec![rule(MatchType::CatchAll, "*", Arc::clone(&handler), Some(true))],
            handler,
        );

        for recipient in ["user@whatever.com", "someone@else.net"] {
            assert!(r.resolve_auth_required(recipient, false));
        }
    }

    #[test]
    fn test_determine_match_type() {
        let cases = [
            (Some("user@test.com"), None, MatchType::ExactAddress),
            (None, Some("example.com"), MatchType::Domain),
            (None, Some("*.example.com"), MatchType::WildcardDomain),
            (None, Some("*"), MatchType::CatchAll),
        ];

        for (address, domain, expected) in cases {
            let address = address.map(String::from);
            let domain = domain.map(String::from);
            assert_eq!(determine_match_type(&address, &domain), expected);
        }
    }

    #[test]
    fn test_extract_pattern() {
        let address = Some("user@test.com".to_string());
        let domain = Some("example.com".to_string());

        assert_eq!(extract_pattern(&address, &None), "user@test.com");
        assert_eq!(extract_pattern(&None, &domain), "example.com");
    }
}