Skip to main content

mailsis_utils/
router.rs

//! Recipient-based message routing for the SMTP pipeline.
//!
//! When an email arrives, the router decides which delivery handler should
//! process it by matching the recipient address against a prioritized list
//! of rules (exact address, domain, wildcard domain). Unmatched recipients
//! fall through to a configurable default handler.
7
8use std::sync::Arc;
9
10use crate::{
11    handler::{HandlerResult, MessageHandler},
12    transformer::MessageTransformer,
13    EmailMessage,
14};
15
/// How a [`RoutingRule`] pattern is compared against a recipient address.
///
/// The router evaluates rules from most to least specific:
/// `ExactAddress`, then `Domain`, then `WildcardDomain`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum MatchType {
    /// The pattern is a complete mailbox, e.g. `"admin@example.com"`.
    ExactAddress,
    /// The pattern is a domain and matches every mailbox at exactly that
    /// domain, e.g. `"example.com"`.
    Domain,
    /// The pattern is a domain prefixed with `*.`, e.g. `"*.example.com"`.
    /// It matches the named domain itself as well as any of its subdomains.
    WildcardDomain,
}
26
27/// A single routing rule that maps a pattern to a handler.
28pub struct RoutingRule {
29    pub match_type: MatchType,
30    pub pattern: String,
31    pub handler: Arc<dyn MessageHandler>,
32    pub transformers: Vec<Box<dyn MessageTransformer>>,
33    pub auth_required: Option<bool>,
34}
35
36impl RoutingRule {
37    /// Tests if this rule matches the given recipient address.
38    pub fn matches(&self, address: &str) -> bool {
39        match self.match_type {
40            MatchType::ExactAddress => address.eq_ignore_ascii_case(&self.pattern),
41            MatchType::Domain => {
42                if let Some(domain) = address.rsplit('@').next() {
43                    domain.eq_ignore_ascii_case(&self.pattern)
44                } else {
45                    false
46                }
47            }
48            MatchType::WildcardDomain => {
49                let wildcard = self.pattern.strip_prefix("*.").unwrap_or(&self.pattern);
50                if let Some(domain) = address.rsplit('@').next() {
51                    // Match the domain itself or any subdomain
52                    domain.eq_ignore_ascii_case(wildcard)
53                        || domain
54                            .to_ascii_lowercase()
55                            .ends_with(&format!(".{}", wildcard.to_ascii_lowercase()))
56                } else {
57                    false
58                }
59            }
60        }
61    }
62}
63
64/// Routes incoming email messages to handlers based on recipient address rules.
65///
66/// Rules are evaluated in specificity order: exact address > domain > wildcard domain.
67/// If no rule matches, the default handler is used.
68pub struct MessageRouter {
69    rules: Vec<RoutingRule>,
70    default_handler: Arc<dyn MessageHandler>,
71    default_transformers: Vec<Box<dyn MessageTransformer>>,
72}
73
74impl std::fmt::Debug for MessageRouter {
75    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
76        f.debug_struct("MessageRouter")
77            .field("rules", &self.rules.len())
78            .finish()
79    }
80}
81
82impl MessageRouter {
83    /// Creates a new [`MessageRouter`] with the given rules, default handler, and
84    /// default transformers.
85    ///
86    /// Rules are automatically sorted by specificity (exact > domain > wildcard).
87    pub fn new(
88        mut rules: Vec<RoutingRule>,
89        default_handler: Arc<dyn MessageHandler>,
90        default_transformers: Vec<Box<dyn MessageTransformer>>,
91    ) -> Self {
92        // Sort by specificity: ExactAddress first, then Domain, then WildcardDomain
93        rules.sort_by_key(|r| match r.match_type {
94            MatchType::ExactAddress => 0,
95            MatchType::Domain => 1,
96            MatchType::WildcardDomain => 2,
97        });
98        Self {
99            rules,
100            default_handler,
101            default_transformers,
102        }
103    }
104
105    /// Resolves which handler should process a message for the given recipient.
106    pub fn resolve(&self, recipient: &str) -> &Arc<dyn MessageHandler> {
107        for rule in &self.rules {
108            if rule.matches(recipient) {
109                return &rule.handler;
110            }
111        }
112        &self.default_handler
113    }
114
115    /// Resolves the transformers for a given recipient.
116    ///
117    /// Returns rule-specific transformers if the matching rule defines them,
118    /// otherwise falls back to the default transformers.
119    fn resolve_transformers(&self, recipient: &str) -> &[Box<dyn MessageTransformer>] {
120        for rule in &self.rules {
121            if rule.matches(recipient) {
122                if !rule.transformers.is_empty() {
123                    return &rule.transformers;
124                }
125                return &self.default_transformers;
126            }
127        }
128        &self.default_transformers
129    }
130
131    /// Routes a message to the appropriate handler based on the recipient address.
132    ///
133    /// Applies transformers before dispatching to the handler.
134    pub async fn route(&self, message: &mut EmailMessage) -> HandlerResult<()> {
135        let transformers = self.resolve_transformers(&message.to);
136        <crate::MessageIdTransformer as MessageTransformer>::apply(transformers, message).await;
137
138        let handler = self.resolve(&message.to);
139        handler.handle(message).await
140    }
141
142    /// Resolves whether authentication is required for a recipient.
143    ///
144    /// Returns the rule's [`RoutingRule::auth_required`] if the matching rule defines it,
145    /// otherwise returns the provided global default.
146    pub fn resolve_auth_required(&self, recipient: &str, global_default: bool) -> bool {
147        for rule in &self.rules {
148            if rule.matches(recipient) {
149                if let Some(auth_req) = rule.auth_required {
150                    return auth_req;
151                }
152                break;
153            }
154        }
155        global_default
156    }
157
158    /// Returns a reference to the default handler.
159    pub fn default_handler(&self) -> &Arc<dyn MessageHandler> {
160        &self.default_handler
161    }
162
163    /// Probes the handler that would process `recipient` and, if that handler
164    /// refuses every delivery, returns the SMTP reply `(code, message)` it
165    /// would emit.
166    ///
167    /// Used by SMTP front-ends to short-circuit at `RCPT TO` time so the
168    /// client receives a proper rejection instead of a post-`DATA` bounce.
169    pub fn rejection_for(&self, recipient: &str) -> Option<(u16, String)> {
170        self.resolve(recipient).reject_reply()
171    }
172}
173
174/// Determines the match type from a routing rule configuration.
175///
176/// Returns [`ExactAddress`](MatchType::ExactAddress) if an address field is present,
177/// [`WildcardDomain`](MatchType::WildcardDomain) if the domain starts with `*.`,
178/// or [`Domain`](MatchType::Domain) otherwise.
179pub fn determine_match_type(address: &Option<String>, domain: &Option<String>) -> MatchType {
180    if address.is_some() {
181        MatchType::ExactAddress
182    } else if let Some(d) = domain {
183        if d.starts_with("*.") {
184            MatchType::WildcardDomain
185        } else {
186            MatchType::Domain
187        }
188    } else {
189        MatchType::Domain
190    }
191}
192
/// Picks the pattern string for a routing rule from its configuration fields.
///
/// The `address` is preferred when present, then the `domain`. When both are
/// absent an empty string is returned.
pub fn extract_pattern(address: &Option<String>, domain: &Option<String>) -> String {
    match (address, domain) {
        (Some(addr), _) => addr.clone(),
        (None, Some(name)) => name.clone(),
        (None, None) => String::new(),
    }
}
201
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicUsize, Ordering};

    use super::*;
    use crate::handler::HandlerFuture;

    /// Test double that counts how many messages it was asked to handle.
    struct CountingHandler {
        name: &'static str,
        count: AtomicUsize,
    }

    impl CountingHandler {
        fn new(name: &'static str) -> Self {
            let count = AtomicUsize::new(0);
            Self { name, count }
        }

        /// Number of `handle` calls observed so far.
        fn count(&self) -> usize {
            self.count.load(Ordering::SeqCst)
        }
    }

    impl MessageHandler for CountingHandler {
        fn handle<'a>(&'a self, _message: &'a EmailMessage) -> HandlerFuture<'a> {
            Box::pin(async move {
                self.count.fetch_add(1, Ordering::SeqCst);
                Ok(())
            })
        }

        fn name(&self) -> &str {
            self.name
        }
    }

    /// Builds a rule with no transformers.
    fn rule_with(
        match_type: MatchType,
        pattern: &str,
        handler: Arc<dyn MessageHandler>,
        auth_required: Option<bool>,
    ) -> RoutingRule {
        RoutingRule {
            match_type,
            pattern: pattern.to_string(),
            handler,
            transformers: vec![],
            auth_required,
        }
    }

    /// Builds a rule backed by a throwaway handler, for pure matching tests.
    fn standalone_rule(match_type: MatchType, pattern: &str) -> RoutingRule {
        let handler = Arc::new(CountingHandler::new("test"));
        rule_with(match_type, pattern, handler, None)
    }

    /// Routes a minimal message addressed to `recipient` through `router`.
    async fn deliver(router: &MessageRouter, recipient: &'static str) {
        let mut msg = EmailMessage::from_raw("sender@test.com", recipient, "test");
        router.route(&mut msg).await.unwrap();
    }

    #[test]
    fn test_exact_address_match() {
        let rule = standalone_rule(MatchType::ExactAddress, "admin@example.com");

        assert!(rule.matches("admin@example.com"));
        assert!(rule.matches("ADMIN@EXAMPLE.COM"));
        assert!(!rule.matches("user@example.com"));
        assert!(!rule.matches("admin@other.com"));
    }

    #[test]
    fn test_domain_match() {
        let rule = standalone_rule(MatchType::Domain, "example.com");

        assert!(rule.matches("user@example.com"));
        assert!(rule.matches("admin@example.com"));
        assert!(rule.matches("user@EXAMPLE.COM"));
        assert!(!rule.matches("user@other.com"));
        assert!(!rule.matches("user@sub.example.com"));
    }

    #[test]
    fn test_wildcard_domain_match() {
        let rule = standalone_rule(MatchType::WildcardDomain, "*.example.com");

        assert!(rule.matches("user@sub.example.com"));
        assert!(rule.matches("user@deep.sub.example.com"));
        assert!(rule.matches("user@example.com"));
        assert!(!rule.matches("user@other.com"));
    }

    #[tokio::test]
    async fn test_router_specificity_order() {
        let exact_handler = Arc::new(CountingHandler::new("exact"));
        let domain_handler = Arc::new(CountingHandler::new("domain"));
        let default_handler = Arc::new(CountingHandler::new("default"));

        // Deliberately listed least-specific first: the router must reorder.
        let rules = vec![
            rule_with(
                MatchType::Domain,
                "example.com",
                domain_handler.clone(),
                None,
            ),
            rule_with(
                MatchType::ExactAddress,
                "admin@example.com",
                exact_handler.clone(),
                None,
            ),
        ];

        let router = MessageRouter::new(rules, default_handler.clone(), vec![]);

        // The exact-address rule wins over the domain rule.
        deliver(&router, "admin@example.com").await;
        assert_eq!(exact_handler.count(), 1);
        assert_eq!(domain_handler.count(), 0);

        // Other mailboxes at the domain go to the domain rule.
        deliver(&router, "user@example.com").await;
        assert_eq!(domain_handler.count(), 1);

        // Anything else lands on the default handler.
        deliver(&router, "user@other.com").await;
        assert_eq!(default_handler.count(), 1);
    }

    #[test]
    fn test_resolve_auth_required() {
        let handler = Arc::new(CountingHandler::new("test"));
        let rules = vec![
            rule_with(
                MatchType::ExactAddress,
                "secure@example.com",
                handler.clone(),
                Some(true),
            ),
            rule_with(MatchType::Domain, "open.com", handler.clone(), Some(false)),
            rule_with(MatchType::Domain, "default.com", handler.clone(), None),
        ];

        let router = MessageRouter::new(rules, handler, vec![]);

        // An explicit `true` on the rule beats the global setting.
        assert!(router.resolve_auth_required("secure@example.com", false));

        // An explicit `false` on the rule beats the global setting.
        assert!(!router.resolve_auth_required("user@open.com", true));

        // A matching rule without an override defers to the global setting.
        assert!(router.resolve_auth_required("user@default.com", true));
        assert!(!router.resolve_auth_required("user@default.com", false));

        // With no matching rule the global setting is used as well.
        assert!(router.resolve_auth_required("user@unknown.com", true));
        assert!(!router.resolve_auth_required("user@unknown.com", false));
    }

    #[test]
    fn test_determine_match_type() {
        let address = Some("user@test.com".to_string());
        assert_eq!(
            determine_match_type(&address, &None),
            MatchType::ExactAddress
        );

        let plain_domain = Some("example.com".to_string());
        assert_eq!(
            determine_match_type(&None, &plain_domain),
            MatchType::Domain
        );

        let wildcard_domain = Some("*.example.com".to_string());
        assert_eq!(
            determine_match_type(&None, &wildcard_domain),
            MatchType::WildcardDomain
        );
    }

    #[test]
    fn test_extract_pattern() {
        let address = Some("user@test.com".to_string());
        assert_eq!(extract_pattern(&address, &None), "user@test.com");

        let domain = Some("example.com".to_string());
        assert_eq!(extract_pattern(&None, &domain), "example.com");
    }
}
405}