1use std::sync::Arc;
9
10use crate::{
11 handler::{HandlerResult, MessageHandler},
12 transformer::MessageTransformer,
13 EmailMessage,
14};
15
/// How a [`RoutingRule`]'s pattern is compared against a recipient address.
///
/// Variants are declared from most to least specific, which is also the order
/// `MessageRouter::new` sorts rules into.
///
/// The enum is a plain tag, so it is `Copy` (no `.clone()` needed to pass it
/// around) and `Hash` (usable as a map or set key).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MatchType {
    /// The whole address must equal the pattern (ASCII case-insensitive).
    ExactAddress,
    /// The part after the last `@` must equal the pattern (ASCII case-insensitive).
    Domain,
    /// The part after the last `@` must equal the pattern's base domain (the
    /// pattern with a leading `*.` removed) or be a subdomain of it.
    WildcardDomain,
}
26
/// A single routing entry: a pattern, how to compare it, and what applies on a match.
pub struct RoutingRule {
    /// How `pattern` is compared against a recipient address.
    pub match_type: MatchType,
    /// Address or domain to compare against, interpreted according to `match_type`.
    pub pattern: String,
    /// Handler the router returns for recipients that match this rule.
    pub handler: Arc<dyn MessageHandler>,
    /// Transformers used for matching recipients; when empty, the router falls
    /// back to its default transformers.
    pub transformers: Vec<Box<dyn MessageTransformer>>,
    /// Per-rule authentication override; `None` defers to the global default
    /// supplied by the caller of `MessageRouter::resolve_auth_required`.
    pub auth_required: Option<bool>,
}
35
36impl RoutingRule {
37 pub fn matches(&self, address: &str) -> bool {
39 match self.match_type {
40 MatchType::ExactAddress => address.eq_ignore_ascii_case(&self.pattern),
41 MatchType::Domain => {
42 if let Some(domain) = address.rsplit('@').next() {
43 domain.eq_ignore_ascii_case(&self.pattern)
44 } else {
45 false
46 }
47 }
48 MatchType::WildcardDomain => {
49 let wildcard = self.pattern.strip_prefix("*.").unwrap_or(&self.pattern);
50 if let Some(domain) = address.rsplit('@').next() {
51 domain.eq_ignore_ascii_case(wildcard)
53 || domain
54 .to_ascii_lowercase()
55 .ends_with(&format!(".{}", wildcard.to_ascii_lowercase()))
56 } else {
57 false
58 }
59 }
60 }
61 }
62}
63
/// Picks a handler, transformers and auth policy for a recipient address.
pub struct MessageRouter {
    /// Routing rules, ordered by `MessageRouter::new` so that exact-address
    /// rules precede domain rules, which precede wildcard-domain rules.
    rules: Vec<RoutingRule>,
    /// Handler returned when no rule matches the recipient.
    default_handler: Arc<dyn MessageHandler>,
    /// Transformers used when no rule matches, or when the matching rule has
    /// no transformers of its own.
    default_transformers: Vec<Box<dyn MessageTransformer>>,
}
73
74impl std::fmt::Debug for MessageRouter {
75 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
76 f.debug_struct("MessageRouter")
77 .field("rules", &self.rules.len())
78 .finish()
79 }
80}
81
82impl MessageRouter {
83 pub fn new(
88 mut rules: Vec<RoutingRule>,
89 default_handler: Arc<dyn MessageHandler>,
90 default_transformers: Vec<Box<dyn MessageTransformer>>,
91 ) -> Self {
92 rules.sort_by_key(|r| match r.match_type {
94 MatchType::ExactAddress => 0,
95 MatchType::Domain => 1,
96 MatchType::WildcardDomain => 2,
97 });
98 Self {
99 rules,
100 default_handler,
101 default_transformers,
102 }
103 }
104
105 pub fn resolve(&self, recipient: &str) -> &Arc<dyn MessageHandler> {
107 for rule in &self.rules {
108 if rule.matches(recipient) {
109 return &rule.handler;
110 }
111 }
112 &self.default_handler
113 }
114
115 fn resolve_transformers(&self, recipient: &str) -> &[Box<dyn MessageTransformer>] {
120 for rule in &self.rules {
121 if rule.matches(recipient) {
122 if !rule.transformers.is_empty() {
123 return &rule.transformers;
124 }
125 return &self.default_transformers;
126 }
127 }
128 &self.default_transformers
129 }
130
131 pub async fn route(&self, message: &mut EmailMessage) -> HandlerResult<()> {
135 let transformers = self.resolve_transformers(&message.to);
136 <crate::MessageIdTransformer as MessageTransformer>::apply(transformers, message).await;
137
138 let handler = self.resolve(&message.to);
139 handler.handle(message).await
140 }
141
142 pub fn resolve_auth_required(&self, recipient: &str, global_default: bool) -> bool {
147 for rule in &self.rules {
148 if rule.matches(recipient) {
149 if let Some(auth_req) = rule.auth_required {
150 return auth_req;
151 }
152 break;
153 }
154 }
155 global_default
156 }
157
158 pub fn default_handler(&self) -> &Arc<dyn MessageHandler> {
160 &self.default_handler
161 }
162}
163
164pub fn determine_match_type(address: &Option<String>, domain: &Option<String>) -> MatchType {
170 if address.is_some() {
171 MatchType::ExactAddress
172 } else if let Some(d) = domain {
173 if d.starts_with("*.") {
174 MatchType::WildcardDomain
175 } else {
176 MatchType::Domain
177 }
178 } else {
179 MatchType::Domain
180 }
181}
182
/// Returns the pattern string for a rule: `address` when present, otherwise
/// `domain`, otherwise an empty string.
pub fn extract_pattern(address: &Option<String>, domain: &Option<String>) -> String {
    match (address, domain) {
        (Some(exact), _) => exact.clone(),
        (None, Some(host)) => host.clone(),
        (None, None) => String::new(),
    }
}
191
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicUsize, Ordering};

    use super::*;
    use crate::handler::HandlerFuture;

    /// Test handler that records how many messages it was asked to handle.
    struct CountingHandler {
        name: &'static str,
        count: AtomicUsize,
    }

    impl CountingHandler {
        /// Creates a handler called `name` with a zeroed counter.
        fn new(name: &'static str) -> Self {
            Self {
                name,
                count: AtomicUsize::new(0),
            }
        }

        /// Number of `handle` calls completed so far.
        fn count(&self) -> usize {
            self.count.load(Ordering::SeqCst)
        }
    }

    impl MessageHandler for CountingHandler {
        fn handle<'a>(&'a self, _message: &'a EmailMessage) -> HandlerFuture<'a> {
            Box::pin(async move {
                self.count.fetch_add(1, Ordering::SeqCst);
                Ok(())
            })
        }

        fn name(&self) -> &str {
            self.name
        }
    }

    /// Exact-address rules compare the whole address, ignoring ASCII case.
    #[test]
    fn test_exact_address_match() {
        let handler = Arc::new(CountingHandler::new("test"));
        let rule = RoutingRule {
            match_type: MatchType::ExactAddress,
            pattern: "admin@example.com".to_string(),
            handler,
            transformers: vec![],
            auth_required: None,
        };

        assert!(rule.matches("admin@example.com"));
        assert!(rule.matches("ADMIN@EXAMPLE.COM"));
        assert!(!rule.matches("user@example.com"));
        assert!(!rule.matches("admin@other.com"));
    }

    /// Domain rules match the domain itself but not its subdomains.
    #[test]
    fn test_domain_match() {
        let handler = Arc::new(CountingHandler::new("test"));
        let rule = RoutingRule {
            match_type: MatchType::Domain,
            pattern: "example.com".to_string(),
            handler,
            transformers: vec![],
            auth_required: None,
        };

        assert!(rule.matches("user@example.com"));
        assert!(rule.matches("admin@example.com"));
        assert!(rule.matches("user@EXAMPLE.COM"));
        assert!(!rule.matches("user@other.com"));
        assert!(!rule.matches("user@sub.example.com"));
    }

    /// Wildcard rules match the base domain and any depth of subdomain.
    #[test]
    fn test_wildcard_domain_match() {
        let handler = Arc::new(CountingHandler::new("test"));
        let rule = RoutingRule {
            match_type: MatchType::WildcardDomain,
            pattern: "*.example.com".to_string(),
            handler,
            transformers: vec![],
            auth_required: None,
        };

        assert!(rule.matches("user@sub.example.com"));
        assert!(rule.matches("user@deep.sub.example.com"));
        assert!(rule.matches("user@example.com"));
        assert!(!rule.matches("user@other.com"));

        // Subdomain matching ignores ASCII case.
        assert!(rule.matches("user@SUB.Example.COM"));
        // A look-alike suffix without the separating dot is not a subdomain.
        assert!(!rule.matches("user@notexample.com"));
    }

    /// An exact-address rule wins over a domain rule even when listed after it.
    #[tokio::test]
    async fn test_router_specificity_order() {
        let exact_handler = Arc::new(CountingHandler::new("exact"));
        let domain_handler = Arc::new(CountingHandler::new("domain"));
        let default_handler = Arc::new(CountingHandler::new("default"));

        let rules = vec![
            RoutingRule {
                match_type: MatchType::Domain,
                pattern: "example.com".to_string(),
                handler: domain_handler.clone(),
                transformers: vec![],
                auth_required: None,
            },
            RoutingRule {
                match_type: MatchType::ExactAddress,
                pattern: "admin@example.com".to_string(),
                handler: exact_handler.clone(),
                transformers: vec![],
                auth_required: None,
            },
        ];

        let router = MessageRouter::new(rules, default_handler.clone(), vec![]);

        // Exact address beats the domain rule.
        let mut msg = EmailMessage::from_raw("sender@test.com", "admin@example.com", "test");
        router.route(&mut msg).await.unwrap();
        assert_eq!(exact_handler.count(), 1);
        assert_eq!(domain_handler.count(), 0);

        // Other addresses on the domain use the domain rule.
        let mut msg = EmailMessage::from_raw("sender@test.com", "user@example.com", "test");
        router.route(&mut msg).await.unwrap();
        assert_eq!(domain_handler.count(), 1);

        // Unknown domains fall through to the default handler.
        let mut msg = EmailMessage::from_raw("sender@test.com", "user@other.com", "test");
        router.route(&mut msg).await.unwrap();
        assert_eq!(default_handler.count(), 1);
    }

    /// A domain rule wins over a wildcard rule for the base domain, while
    /// subdomains are only reachable through the wildcard rule.
    #[test]
    fn test_resolve_prefers_domain_over_wildcard() {
        let rules = vec![
            RoutingRule {
                match_type: MatchType::WildcardDomain,
                pattern: "*.example.com".to_string(),
                handler: Arc::new(CountingHandler::new("wildcard")),
                transformers: vec![],
                auth_required: None,
            },
            RoutingRule {
                match_type: MatchType::Domain,
                pattern: "example.com".to_string(),
                handler: Arc::new(CountingHandler::new("domain")),
                transformers: vec![],
                auth_required: None,
            },
        ];

        let router = MessageRouter::new(rules, Arc::new(CountingHandler::new("default")), vec![]);

        assert_eq!(router.resolve("user@example.com").name(), "domain");
        assert_eq!(router.resolve("user@sub.example.com").name(), "wildcard");
        assert_eq!(router.resolve("user@other.com").name(), "default");
        assert_eq!(router.default_handler().name(), "default");
    }

    /// Per-rule auth settings override the global default; unset or unmatched
    /// rules defer to it.
    #[test]
    fn test_resolve_auth_required() {
        let handler = Arc::new(CountingHandler::new("test"));
        let rules = vec![
            RoutingRule {
                match_type: MatchType::ExactAddress,
                pattern: "secure@example.com".to_string(),
                handler: handler.clone(),
                transformers: vec![],
                auth_required: Some(true),
            },
            RoutingRule {
                match_type: MatchType::Domain,
                pattern: "open.com".to_string(),
                handler: handler.clone(),
                transformers: vec![],
                auth_required: Some(false),
            },
            RoutingRule {
                match_type: MatchType::Domain,
                pattern: "default.com".to_string(),
                handler: handler.clone(),
                transformers: vec![],
                auth_required: None,
            },
        ];

        let router = MessageRouter::new(rules, handler, vec![]);

        // Rule forces auth on although the global default is off.
        assert!(router.resolve_auth_required("secure@example.com", false));

        // Rule forces auth off although the global default is on.
        assert!(!router.resolve_auth_required("user@open.com", true));

        // Matching rule without an override follows the global default.
        assert!(router.resolve_auth_required("user@default.com", true));
        assert!(!router.resolve_auth_required("user@default.com", false));

        // No matching rule follows the global default.
        assert!(router.resolve_auth_required("user@unknown.com", true));
        assert!(!router.resolve_auth_required("user@unknown.com", false));
    }

    /// Each single-input configuration maps to the expected match type.
    #[test]
    fn test_determine_match_type() {
        assert_eq!(
            determine_match_type(&Some("user@test.com".to_string()), &None),
            MatchType::ExactAddress
        );
        assert_eq!(
            determine_match_type(&None, &Some("example.com".to_string())),
            MatchType::Domain
        );
        assert_eq!(
            determine_match_type(&None, &Some("*.example.com".to_string())),
            MatchType::WildcardDomain
        );
    }

    /// An address takes precedence over any domain, and missing input falls
    /// back to `Domain`.
    #[test]
    fn test_determine_match_type_fallbacks() {
        assert_eq!(
            determine_match_type(
                &Some("user@test.com".to_string()),
                &Some("*.example.com".to_string())
            ),
            MatchType::ExactAddress
        );
        assert_eq!(determine_match_type(&None, &None), MatchType::Domain);
    }

    /// The pattern is taken from whichever single input is present.
    #[test]
    fn test_extract_pattern() {
        assert_eq!(
            extract_pattern(&Some("user@test.com".to_string()), &None),
            "user@test.com"
        );
        assert_eq!(
            extract_pattern(&None, &Some("example.com".to_string())),
            "example.com"
        );
    }

    /// The address wins when both inputs are present; no input yields an
    /// empty pattern.
    #[test]
    fn test_extract_pattern_fallbacks() {
        assert_eq!(
            extract_pattern(
                &Some("user@test.com".to_string()),
                &Some("example.com".to_string())
            ),
            "user@test.com"
        );
        assert_eq!(extract_pattern(&None, &None), "");
    }
}