Skip to main content

mailsis_utils/
router.rs

//! Recipient-based message routing for the SMTP pipeline.
//!
//! When an email arrives, the router decides which delivery handler should
//! process it by matching the recipient address against a prioritized list
//! of rules (exact address, domain, wildcard domain). Unmatched recipients
//! fall through to a configurable default handler.
7
8use std::sync::Arc;
9
10use crate::{
11    handler::{HandlerResult, MessageHandler},
12    transformer::MessageTransformer,
13    EmailMessage,
14};
15
/// The type of match for a routing rule.
///
/// Variants are listed from most to least specific, which is also the order in
/// which [`MessageRouter`] evaluates rules of each kind. All comparisons are
/// ASCII case-insensitive.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MatchType {
    /// Matches an exact email address (e.g. "admin@example.com").
    ExactAddress,
    /// Matches all users at exactly one domain (e.g. "example.com").
    Domain,
    /// Matches a domain and all of its subdomains (e.g. "*.example.com").
    WildcardDomain,
}
26
27/// A single routing rule that maps a pattern to a handler.
28pub struct RoutingRule {
29    pub match_type: MatchType,
30    pub pattern: String,
31    pub handler: Arc<dyn MessageHandler>,
32    pub transformers: Vec<Box<dyn MessageTransformer>>,
33    pub auth_required: Option<bool>,
34}
35
36impl RoutingRule {
37    /// Tests if this rule matches the given recipient address.
38    pub fn matches(&self, address: &str) -> bool {
39        match self.match_type {
40            MatchType::ExactAddress => address.eq_ignore_ascii_case(&self.pattern),
41            MatchType::Domain => {
42                if let Some(domain) = address.rsplit('@').next() {
43                    domain.eq_ignore_ascii_case(&self.pattern)
44                } else {
45                    false
46                }
47            }
48            MatchType::WildcardDomain => {
49                let wildcard = self.pattern.strip_prefix("*.").unwrap_or(&self.pattern);
50                if let Some(domain) = address.rsplit('@').next() {
51                    // Match the domain itself or any subdomain
52                    domain.eq_ignore_ascii_case(wildcard)
53                        || domain
54                            .to_ascii_lowercase()
55                            .ends_with(&format!(".{}", wildcard.to_ascii_lowercase()))
56                } else {
57                    false
58                }
59            }
60        }
61    }
62}
63
64/// Routes incoming email messages to handlers based on recipient address rules.
65///
66/// Rules are evaluated in specificity order: exact address > domain > wildcard domain.
67/// If no rule matches, the default handler is used.
68pub struct MessageRouter {
69    rules: Vec<RoutingRule>,
70    default_handler: Arc<dyn MessageHandler>,
71    default_transformers: Vec<Box<dyn MessageTransformer>>,
72}
73
74impl std::fmt::Debug for MessageRouter {
75    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
76        f.debug_struct("MessageRouter")
77            .field("rules", &self.rules.len())
78            .finish()
79    }
80}
81
82impl MessageRouter {
83    /// Creates a new [`MessageRouter`] with the given rules, default handler, and
84    /// default transformers.
85    ///
86    /// Rules are automatically sorted by specificity (exact > domain > wildcard).
87    pub fn new(
88        mut rules: Vec<RoutingRule>,
89        default_handler: Arc<dyn MessageHandler>,
90        default_transformers: Vec<Box<dyn MessageTransformer>>,
91    ) -> Self {
92        // Sort by specificity: ExactAddress first, then Domain, then WildcardDomain
93        rules.sort_by_key(|r| match r.match_type {
94            MatchType::ExactAddress => 0,
95            MatchType::Domain => 1,
96            MatchType::WildcardDomain => 2,
97        });
98        Self {
99            rules,
100            default_handler,
101            default_transformers,
102        }
103    }
104
105    /// Resolves which handler should process a message for the given recipient.
106    pub fn resolve(&self, recipient: &str) -> &Arc<dyn MessageHandler> {
107        for rule in &self.rules {
108            if rule.matches(recipient) {
109                return &rule.handler;
110            }
111        }
112        &self.default_handler
113    }
114
115    /// Resolves the transformers for a given recipient.
116    ///
117    /// Returns rule-specific transformers if the matching rule defines them,
118    /// otherwise falls back to the default transformers.
119    fn resolve_transformers(&self, recipient: &str) -> &[Box<dyn MessageTransformer>] {
120        for rule in &self.rules {
121            if rule.matches(recipient) {
122                if !rule.transformers.is_empty() {
123                    return &rule.transformers;
124                }
125                return &self.default_transformers;
126            }
127        }
128        &self.default_transformers
129    }
130
131    /// Routes a message to the appropriate handler based on the recipient address.
132    ///
133    /// Applies transformers before dispatching to the handler.
134    pub async fn route(&self, message: &mut EmailMessage) -> HandlerResult<()> {
135        let transformers = self.resolve_transformers(&message.to);
136        <crate::MessageIdTransformer as MessageTransformer>::apply(transformers, message).await;
137
138        let handler = self.resolve(&message.to);
139        handler.handle(message).await
140    }
141
142    /// Resolves whether authentication is required for a recipient.
143    ///
144    /// Returns the rule's [`RoutingRule::auth_required`] if the matching rule defines it,
145    /// otherwise returns the provided global default.
146    pub fn resolve_auth_required(&self, recipient: &str, global_default: bool) -> bool {
147        for rule in &self.rules {
148            if rule.matches(recipient) {
149                if let Some(auth_req) = rule.auth_required {
150                    return auth_req;
151                }
152                break;
153            }
154        }
155        global_default
156    }
157
158    /// Returns a reference to the default handler.
159    pub fn default_handler(&self) -> &Arc<dyn MessageHandler> {
160        &self.default_handler
161    }
162}
163
164/// Determines the match type from a routing rule configuration.
165///
166/// Returns [`ExactAddress`](MatchType::ExactAddress) if an address field is present,
167/// [`WildcardDomain`](MatchType::WildcardDomain) if the domain starts with `*.`,
168/// or [`Domain`](MatchType::Domain) otherwise.
169pub fn determine_match_type(address: &Option<String>, domain: &Option<String>) -> MatchType {
170    if address.is_some() {
171        MatchType::ExactAddress
172    } else if let Some(d) = domain {
173        if d.starts_with("*.") {
174            MatchType::WildcardDomain
175        } else {
176            MatchType::Domain
177        }
178    } else {
179        MatchType::Domain
180    }
181}
182
/// Extracts the pattern string from a routing rule configuration.
///
/// An explicit `address` takes precedence over `domain`. With neither present,
/// the pattern is the empty string.
pub fn extract_pattern(address: &Option<String>, domain: &Option<String>) -> String {
    match (address, domain) {
        (Some(exact), _) => exact.clone(),
        (None, Some(dom)) => dom.clone(),
        (None, None) => String::new(),
    }
}
191
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicUsize, Ordering};

    use super::*;
    use crate::handler::HandlerFuture;

    /// Handler double that records how many messages it was asked to handle.
    struct CountingHandler {
        name: &'static str,
        count: AtomicUsize,
    }

    impl CountingHandler {
        fn new(name: &'static str) -> Self {
            let count = AtomicUsize::new(0);
            Self { name, count }
        }

        /// Number of messages handled so far.
        fn count(&self) -> usize {
            self.count.load(Ordering::SeqCst)
        }
    }

    impl MessageHandler for CountingHandler {
        fn handle<'a>(&'a self, _message: &'a EmailMessage) -> HandlerFuture<'a> {
            Box::pin(async move {
                self.count.fetch_add(1, Ordering::SeqCst);
                Ok(())
            })
        }

        fn name(&self) -> &str {
            self.name
        }
    }

    /// Builds a transformer-free rule routed to `handler`.
    fn make_rule(
        match_type: MatchType,
        pattern: &str,
        handler: Arc<dyn MessageHandler>,
        auth_required: Option<bool>,
    ) -> RoutingRule {
        RoutingRule {
            match_type,
            pattern: pattern.to_string(),
            handler,
            transformers: vec![],
            auth_required,
        }
    }

    /// Builds a rule whose handler is never invoked; used by pure matching tests.
    fn matcher(match_type: MatchType, pattern: &str) -> RoutingRule {
        make_rule(match_type, pattern, Arc::new(CountingHandler::new("test")), None)
    }

    /// Asserts `rule.matches(address) == expected` for every row of `cases`.
    fn check_matches(rule: &RoutingRule, cases: &[(&str, bool)]) {
        for &(address, expected) in cases {
            assert_eq!(rule.matches(address), expected, "address: {}", address);
        }
    }

    /// A message from a fixed sender to `to` with a trivial body.
    fn inbound(to: &'static str) -> EmailMessage {
        EmailMessage::from_raw("sender@test.com", to, "test")
    }

    #[test]
    fn test_exact_address_match() {
        let exact = matcher(MatchType::ExactAddress, "admin@example.com");
        check_matches(
            &exact,
            &[
                ("admin@example.com", true),
                ("ADMIN@EXAMPLE.COM", true),
                ("user@example.com", false),
                ("admin@other.com", false),
            ],
        );
    }

    #[test]
    fn test_domain_match() {
        let domain = matcher(MatchType::Domain, "example.com");
        check_matches(
            &domain,
            &[
                ("user@example.com", true),
                ("admin@example.com", true),
                ("user@EXAMPLE.COM", true),
                ("user@other.com", false),
                ("user@sub.example.com", false),
            ],
        );
    }

    #[test]
    fn test_wildcard_domain_match() {
        let wildcard = matcher(MatchType::WildcardDomain, "*.example.com");
        check_matches(
            &wildcard,
            &[
                ("user@sub.example.com", true),
                ("user@deep.sub.example.com", true),
                ("user@example.com", true),
                ("user@other.com", false),
            ],
        );
    }

    #[tokio::test]
    async fn test_router_specificity_order() {
        let exact_handler = Arc::new(CountingHandler::new("exact"));
        let domain_handler = Arc::new(CountingHandler::new("domain"));
        let default_handler = Arc::new(CountingHandler::new("default"));

        // The domain rule is listed first on purpose: the router must still
        // prefer the more specific exact-address rule.
        let rules = vec![
            make_rule(MatchType::Domain, "example.com", domain_handler.clone(), None),
            make_rule(
                MatchType::ExactAddress,
                "admin@example.com",
                exact_handler.clone(),
                None,
            ),
        ];
        let router = MessageRouter::new(rules, default_handler.clone(), vec![]);

        // The exact rule wins over the domain rule for the same address.
        let mut exact_msg = inbound("admin@example.com");
        router.route(&mut exact_msg).await.unwrap();
        assert_eq!(exact_handler.count(), 1);
        assert_eq!(domain_handler.count(), 0);

        // Other users at the domain fall to the domain rule.
        let mut domain_msg = inbound("user@example.com");
        router.route(&mut domain_msg).await.unwrap();
        assert_eq!(domain_handler.count(), 1);

        // Unknown domains end up at the default handler.
        let mut fallback_msg = inbound("user@other.com");
        router.route(&mut fallback_msg).await.unwrap();
        assert_eq!(default_handler.count(), 1);
    }

    #[test]
    fn test_resolve_auth_required() {
        let handler = Arc::new(CountingHandler::new("test"));
        let rules = vec![
            make_rule(
                MatchType::ExactAddress,
                "secure@example.com",
                handler.clone(),
                Some(true),
            ),
            make_rule(MatchType::Domain, "open.com", handler.clone(), Some(false)),
            make_rule(MatchType::Domain, "default.com", handler.clone(), None),
        ];
        let router = MessageRouter::new(rules, handler, vec![]);

        // Explicit per-rule settings override the global default.
        assert!(router.resolve_auth_required("secure@example.com", false));
        assert!(!router.resolve_auth_required("user@open.com", true));

        // A matching rule without a setting defers to the global default...
        assert!(router.resolve_auth_required("user@default.com", true));
        assert!(!router.resolve_auth_required("user@default.com", false));

        // ...as does a recipient that matches no rule at all.
        assert!(router.resolve_auth_required("user@unknown.com", true));
        assert!(!router.resolve_auth_required("user@unknown.com", false));
    }

    #[test]
    fn test_determine_match_type() {
        let address = Some("user@test.com".to_string());
        let plain = Some("example.com".to_string());
        let wildcard = Some("*.example.com".to_string());

        assert_eq!(determine_match_type(&address, &None), MatchType::ExactAddress);
        assert_eq!(determine_match_type(&None, &plain), MatchType::Domain);
        assert_eq!(determine_match_type(&None, &wildcard), MatchType::WildcardDomain);
    }

    #[test]
    fn test_extract_pattern() {
        let address = Some("user@test.com".to_string());
        let domain = Some("example.com".to_string());

        assert_eq!(extract_pattern(&address, &None), "user@test.com");
        assert_eq!(extract_pattern(&None, &domain), "example.com");
    }
}