Skip to main content

mailsis_utils/
router.rs

1//! Recipient-based message routing for the SMTP pipeline.
2//!
3//! When an email arrives, the router decides which delivery handler should
4//! process it by matching the recipient address against a prioritized list
5//! of rules (exact address, domain, wildcard domain, catch-all). Unmatched
6//! recipients fall through to a configurable default handler.
7
8use std::sync::Arc;
9
10use crate::{
11    handler::{HandlerResult, MessageHandler},
12    transformer::MessageTransformer,
13    EmailMessage,
14};
15
/// How a [`RoutingRule`]'s pattern is compared against a recipient address.
///
/// Variants are listed from most to least specific, which is also the order in
/// which [`MessageRouter::new`] arranges rules for evaluation.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MatchType {
    /// ASCII case-insensitive match on the whole address (e.g. "admin@example.com").
    ExactAddress,
    /// ASCII case-insensitive match on the domain after the last `@`, for exactly
    /// that domain (e.g. "example.com" matches `user@example.com` but not
    /// `user@sub.example.com`).
    Domain,
    /// Matches a domain *and* all of its subdomains, ASCII case-insensitively
    /// (e.g. "*.example.com" matches `user@example.com` and `user@a.b.example.com`).
    WildcardDomain,
    /// Matches any recipient string at all (e.g. "*").
    CatchAll,
}
28
29/// A single routing rule that maps a pattern to a handler.
30pub struct RoutingRule {
31    pub match_type: MatchType,
32    pub pattern: String,
33    pub handler: Arc<dyn MessageHandler>,
34    pub transformers: Vec<Box<dyn MessageTransformer>>,
35    pub auth_required: Option<bool>,
36}
37
38impl RoutingRule {
39    /// Tests if this rule matches the given recipient address.
40    pub fn matches(&self, address: &str) -> bool {
41        match self.match_type {
42            MatchType::ExactAddress => address.eq_ignore_ascii_case(&self.pattern),
43            MatchType::Domain => {
44                if let Some(domain) = address.rsplit('@').next() {
45                    domain.eq_ignore_ascii_case(&self.pattern)
46                } else {
47                    false
48                }
49            }
50            MatchType::WildcardDomain => {
51                let wildcard = self.pattern.strip_prefix("*.").unwrap_or(&self.pattern);
52                if let Some(domain) = address.rsplit('@').next() {
53                    // Match the domain itself or any subdomain
54                    domain.eq_ignore_ascii_case(wildcard)
55                        || domain
56                            .to_ascii_lowercase()
57                            .ends_with(&format!(".{}", wildcard.to_ascii_lowercase()))
58                } else {
59                    false
60                }
61            }
62            MatchType::CatchAll => true,
63        }
64    }
65}
66
67/// Routes incoming email messages to handlers based on recipient address rules.
68///
69/// Rules are evaluated in specificity order: exact address > domain > wildcard domain > catch-all.
70/// If no rule matches, the default handler is used.
71pub struct MessageRouter {
72    rules: Vec<RoutingRule>,
73    default_handler: Arc<dyn MessageHandler>,
74    default_transformers: Vec<Box<dyn MessageTransformer>>,
75}
76
77impl std::fmt::Debug for MessageRouter {
78    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
79        f.debug_struct("MessageRouter")
80            .field("rules", &self.rules.len())
81            .finish()
82    }
83}
84
85impl MessageRouter {
86    /// Creates a new [`MessageRouter`] with the given rules, default handler, and
87    /// default transformers.
88    ///
89    /// Rules are automatically sorted by specificity (exact > domain > wildcard > catch-all).
90    pub fn new(
91        mut rules: Vec<RoutingRule>,
92        default_handler: Arc<dyn MessageHandler>,
93        default_transformers: Vec<Box<dyn MessageTransformer>>,
94    ) -> Self {
95        // Sort by specificity: ExactAddress first, then Domain, then WildcardDomain, then CatchAll
96        rules.sort_by_key(|r| match r.match_type {
97            MatchType::ExactAddress => 0,
98            MatchType::Domain => 1,
99            MatchType::WildcardDomain => 2,
100            MatchType::CatchAll => 3,
101        });
102        Self {
103            rules,
104            default_handler,
105            default_transformers,
106        }
107    }
108
109    /// Resolves which handler should process a message for the given recipient.
110    pub fn resolve(&self, recipient: &str) -> &Arc<dyn MessageHandler> {
111        for rule in &self.rules {
112            if rule.matches(recipient) {
113                return &rule.handler;
114            }
115        }
116        &self.default_handler
117    }
118
119    /// Resolves the transformers for a given recipient.
120    ///
121    /// Returns rule-specific transformers if the matching rule defines them,
122    /// otherwise falls back to the default transformers.
123    fn resolve_transformers(&self, recipient: &str) -> &[Box<dyn MessageTransformer>] {
124        for rule in &self.rules {
125            if rule.matches(recipient) {
126                if !rule.transformers.is_empty() {
127                    return &rule.transformers;
128                }
129                return &self.default_transformers;
130            }
131        }
132        &self.default_transformers
133    }
134
135    /// Routes a message to the appropriate handler based on the recipient address.
136    ///
137    /// Applies transformers before dispatching to the handler.
138    pub async fn route(&self, message: &mut EmailMessage) -> HandlerResult<()> {
139        let transformers = self.resolve_transformers(&message.to);
140        <crate::MessageIdTransformer as MessageTransformer>::apply(transformers, message).await;
141
142        let handler = self.resolve(&message.to);
143        handler.handle(message).await
144    }
145
146    /// Resolves whether authentication is required for a recipient.
147    ///
148    /// Returns the rule's [`RoutingRule::auth_required`] if the matching rule defines it,
149    /// otherwise returns the provided global default.
150    pub fn resolve_auth_required(&self, recipient: &str, global_default: bool) -> bool {
151        for rule in &self.rules {
152            if rule.matches(recipient) {
153                if let Some(auth_req) = rule.auth_required {
154                    return auth_req;
155                }
156                break;
157            }
158        }
159        global_default
160    }
161
162    /// Probes the handler that would process `recipient` and, if that handler
163    /// refuses every delivery, returns the SMTP reply `(code, message)` it
164    /// would emit.
165    ///
166    /// Used by SMTP front-ends to short-circuit at `RCPT TO` time so the
167    /// client receives a proper rejection instead of a post-`DATA` bounce.
168    pub fn rejection_for(&self, recipient: &str) -> Option<(u16, String)> {
169        self.resolve(recipient).reject_reply()
170    }
171
172    /// Returns a reference to the default handler.
173    pub fn default_handler(&self) -> &Arc<dyn MessageHandler> {
174        &self.default_handler
175    }
176}
177
178/// Determines the match type from a routing rule configuration.
179///
180/// Returns [`ExactAddress`](MatchType::ExactAddress) if an address field is present,
181/// [`CatchAll`](MatchType::CatchAll) if the domain is exactly `"*"`,
182/// [`WildcardDomain`](MatchType::WildcardDomain) if the domain starts with `*.`,
183/// or [`Domain`](MatchType::Domain) otherwise.
184pub fn determine_match_type(address: &Option<String>, domain: &Option<String>) -> MatchType {
185    if address.is_some() {
186        MatchType::ExactAddress
187    } else if let Some(d) = domain {
188        if d == "*" {
189            MatchType::CatchAll
190        } else if d.starts_with("*.") {
191            MatchType::WildcardDomain
192        } else {
193            MatchType::Domain
194        }
195    } else {
196        MatchType::Domain
197    }
198}
199
/// Extracts the pattern string from a routing rule configuration.
///
/// The address takes precedence over the domain. When neither is present, an
/// empty string is returned.
pub fn extract_pattern(address: &Option<String>, domain: &Option<String>) -> String {
    match (address, domain) {
        (Some(addr), _) => addr.clone(),
        (None, Some(dom)) => dom.clone(),
        (None, None) => String::new(),
    }
}
208
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicUsize, Ordering};

    use super::*;
    use crate::handler::HandlerFuture;

    /// Test handler that records how many messages it was asked to deliver.
    struct CountingHandler {
        name: &'static str,
        count: AtomicUsize,
    }

    impl CountingHandler {
        fn new(name: &'static str) -> Self {
            Self {
                name,
                count: AtomicUsize::new(0),
            }
        }

        fn count(&self) -> usize {
            self.count.load(Ordering::SeqCst)
        }
    }

    impl MessageHandler for CountingHandler {
        fn handle<'a>(&'a self, _message: &'a EmailMessage) -> HandlerFuture<'a> {
            Box::pin(async move {
                self.count.fetch_add(1, Ordering::SeqCst);
                Ok(())
            })
        }

        fn name(&self) -> &str {
            self.name
        }
    }

    /// Builds a transformer-less rule with the given handler and auth override.
    fn rule(
        match_type: MatchType,
        pattern: &str,
        handler: Arc<dyn MessageHandler>,
        auth_required: Option<bool>,
    ) -> RoutingRule {
        RoutingRule {
            match_type,
            pattern: pattern.to_string(),
            handler,
            transformers: vec![],
            auth_required,
        }
    }

    /// Builds a rule for `matches` tests, where the handler is irrelevant.
    fn matcher(match_type: MatchType, pattern: &str) -> RoutingRule {
        rule(match_type, pattern, Arc::new(CountingHandler::new("test")), None)
    }

    #[test]
    fn test_exact_address_match() {
        let exact = matcher(MatchType::ExactAddress, "admin@example.com");

        assert!(exact.matches("admin@example.com"));
        assert!(exact.matches("ADMIN@EXAMPLE.COM"));
        assert!(!exact.matches("user@example.com"));
        assert!(!exact.matches("admin@other.com"));
    }

    #[test]
    fn test_domain_match() {
        let domain = matcher(MatchType::Domain, "example.com");

        assert!(domain.matches("user@example.com"));
        assert!(domain.matches("admin@example.com"));
        assert!(domain.matches("user@EXAMPLE.COM"));
        assert!(!domain.matches("user@other.com"));
        assert!(!domain.matches("user@sub.example.com"));
    }

    #[test]
    fn test_wildcard_domain_match() {
        let wildcard = matcher(MatchType::WildcardDomain, "*.example.com");

        assert!(wildcard.matches("user@sub.example.com"));
        assert!(wildcard.matches("user@deep.sub.example.com"));
        assert!(wildcard.matches("user@example.com"));
        assert!(!wildcard.matches("user@other.com"));
    }

    #[test]
    fn test_catch_all_match() {
        let catch_all = matcher(MatchType::CatchAll, "*");

        for recipient in [
            "user@example.com",
            "user@other.com",
            "admin@deep.sub.example.com",
            "malformed-no-at-sign",
            "",
        ] {
            assert!(catch_all.matches(recipient), "catch-all rejected {recipient:?}");
        }
    }

    #[tokio::test]
    async fn test_router_specificity_order() {
        let exact_handler = Arc::new(CountingHandler::new("exact"));
        let domain_handler = Arc::new(CountingHandler::new("domain"));
        let default_handler = Arc::new(CountingHandler::new("default"));

        // Supplied least-specific first; the router must reorder them.
        let rules = vec![
            rule(MatchType::Domain, "example.com", domain_handler.clone(), None),
            rule(MatchType::ExactAddress, "admin@example.com", exact_handler.clone(), None),
        ];
        let router = MessageRouter::new(rules, default_handler.clone(), vec![]);

        // The exact rule beats the domain rule for the same recipient.
        let mut msg = EmailMessage::from_raw("sender@test.com", "admin@example.com", "test");
        router.route(&mut msg).await.unwrap();
        assert_eq!(exact_handler.count(), 1);
        assert_eq!(domain_handler.count(), 0);

        // Other users at the domain fall to the domain rule.
        let mut msg = EmailMessage::from_raw("sender@test.com", "user@example.com", "test");
        router.route(&mut msg).await.unwrap();
        assert_eq!(domain_handler.count(), 1);

        // Unmatched domains reach the default handler.
        let mut msg = EmailMessage::from_raw("sender@test.com", "user@other.com", "test");
        router.route(&mut msg).await.unwrap();
        assert_eq!(default_handler.count(), 1);
    }

    #[tokio::test]
    async fn test_router_catch_all_fallback() {
        let domain_handler = Arc::new(CountingHandler::new("domain"));
        let catch_all_handler = Arc::new(CountingHandler::new("catch_all"));
        let default_handler = Arc::new(CountingHandler::new("default"));

        let rules = vec![
            rule(MatchType::CatchAll, "*", catch_all_handler.clone(), None),
            rule(MatchType::Domain, "example.com", domain_handler.clone(), None),
        ];
        let router = MessageRouter::new(rules, default_handler.clone(), vec![]);

        // A specific rule wins over the catch-all.
        let mut msg = EmailMessage::from_raw("sender@test.com", "user@example.com", "test");
        router.route(&mut msg).await.unwrap();
        assert_eq!(domain_handler.count(), 1);
        assert_eq!(catch_all_handler.count(), 0);

        // The catch-all absorbs everything else, so the default never fires.
        let mut msg = EmailMessage::from_raw("sender@test.com", "user@other.com", "test");
        router.route(&mut msg).await.unwrap();
        assert_eq!(catch_all_handler.count(), 1);
        assert_eq!(default_handler.count(), 0);
    }

    #[test]
    fn test_resolve_auth_required() {
        let handler = Arc::new(CountingHandler::new("test"));
        let rules = vec![
            rule(MatchType::ExactAddress, "secure@example.com", handler.clone(), Some(true)),
            rule(MatchType::Domain, "open.com", handler.clone(), Some(false)),
            rule(MatchType::Domain, "default.com", handler.clone(), None),
        ];
        let router = MessageRouter::new(rules, handler, vec![]);

        // An explicit `true` overrides a permissive global default.
        assert!(router.resolve_auth_required("secure@example.com", false));

        // An explicit `false` overrides a strict global default.
        assert!(!router.resolve_auth_required("user@open.com", true));

        // `None` on the matching rule defers to the global default.
        assert!(router.resolve_auth_required("user@default.com", true));
        assert!(!router.resolve_auth_required("user@default.com", false));

        // No matching rule also defers to the global default.
        assert!(router.resolve_auth_required("user@unknown.com", true));
        assert!(!router.resolve_auth_required("user@unknown.com", false));
    }

    #[test]
    fn test_catch_all_auth_required_override() {
        let handler = Arc::new(CountingHandler::new("test"));
        let rules = vec![rule(MatchType::CatchAll, "*", handler.clone(), Some(true))];
        let router = MessageRouter::new(rules, handler, vec![]);

        // The catch-all forces auth even though the global default is off.
        assert!(router.resolve_auth_required("user@whatever.com", false));
        assert!(router.resolve_auth_required("someone@else.net", false));
    }

    #[test]
    fn test_determine_match_type() {
        let cases = [
            (Some("user@test.com"), None, MatchType::ExactAddress),
            (None, Some("example.com"), MatchType::Domain),
            (None, Some("*.example.com"), MatchType::WildcardDomain),
            (None, Some("*"), MatchType::CatchAll),
        ];

        for (address, domain, expected) in cases {
            let address = address.map(String::from);
            let domain = domain.map(String::from);
            assert_eq!(determine_match_type(&address, &domain), expected);
        }
    }

    #[test]
    fn test_extract_pattern() {
        let from_address = extract_pattern(&Some("user@test.com".to_string()), &None);
        assert_eq!(from_address, "user@test.com");

        let from_domain = extract_pattern(&None, &Some("example.com".to_string()));
        assert_eq!(from_domain, "example.com");
    }
}