rama_http/matcher/
subdomain_trie.rs

1use crate::Request;
2use radix_trie::Trie;
3use rama_core::{Context, context::Extensions, matcher::Matcher};
4use rama_net::address::Host;
5use rama_net::http::RequestContext;
6
7#[derive(Debug, Clone)]
8pub struct SubdomainTrieMatcher {
9    trie: Trie<String, ()>,
10}
11
12impl SubdomainTrieMatcher {
13    pub fn new<I, S>(domains: I) -> Self
14    where
15        I: IntoIterator<Item = S>,
16        S: AsRef<str>,
17    {
18        let mut trie = Trie::new();
19        for domain in domains {
20            let reversed = reverse_domain(domain.as_ref());
21            trie.insert(reversed, ());
22        }
23        Self { trie }
24    }
25
26    // Checks if the reversed domain has an ancestor in the trie.
27    //
28    // The domain is reversed to match the way Radix Tries store domains. `get_ancestor` is used
29    // to check if any prefix of the reversed domain exists in the trie, indicating a match.
30    pub fn is_match(&self, domain: impl AsRef<str>) -> bool {
31        let reversed = reverse_domain(domain.as_ref());
32        self.trie.get_ancestor(&reversed).is_some()
33    }
34}
35
/// Converts `domain` into the label-reversed key form stored in the trie.
///
/// One leading `.` and one trailing `.` (the DNS root of a fully qualified name) are
/// dropped, the remaining labels are emitted in reverse order with each label followed
/// by a `.`, and ASCII letters are lowercased because DNS names compare
/// case-insensitively. For example `"Sub.Example.com."` becomes `"com.example.sub."`.
///
/// Terminating every label with `.` keeps prefix comparisons on label boundaries:
/// `"com.example."` is a prefix of `"com.example.sub."` but not of `"com.examples."`.
///
/// An empty input yields `"."`.
fn reverse_domain(domain: &str) -> String {
    let name = domain.strip_prefix('.').unwrap_or(domain);
    // A trailing dot only marks the DNS root; it must not produce an empty leading label.
    let name = name.strip_suffix('.').unwrap_or(name);

    // Every byte of `name` is copied once, plus one extra terminating dot.
    let mut reversed = String::with_capacity(name.len() + 1);
    for label in name.rsplit('.') {
        reversed.push_str(label);
        reversed.push('.');
    }
    // Normalize case once so stored keys and lookups agree regardless of input casing.
    reversed.make_ascii_lowercase();
    reversed
}
42
43impl<State, Body> Matcher<State, Request<Body>> for SubdomainTrieMatcher {
44    fn matches(
45        &self,
46        ext: Option<&mut Extensions>,
47        ctx: &Context<State>,
48        req: &Request<Body>,
49    ) -> bool {
50        let match_authority = |ctx: &RequestContext| match ctx.authority.host() {
51            Host::Name(domain) => {
52                let is_match = self.is_match(domain.as_str());
53                tracing::trace!(
54                    "SubdomainTrieMatcher: matching domain = {}, matched = {}",
55                    domain,
56                    is_match
57                );
58                is_match
59            }
60            Host::Address(address) => {
61                tracing::trace!(
62                    %address,
63                    "SubdomainTrieMatcher: ignoring numeric address",
64                );
65                false
66            }
67        };
68
69        match ctx.get() {
70            Some(req_ctx) => match_authority(req_ctx),
71            None => {
72                let req_ctx: RequestContext = match (ctx, req).try_into() {
73                    Ok(rc) => rc,
74                    Err(err) => {
75                        tracing::debug!(
76                            error = %err,
77                            "SubdomainTrieMatcher: failed to extract request context",
78                        );
79                        return false;
80                    }
81                };
82                let is_match = match_authority(&req_ctx);
83                if let Some(ext) = ext {
84                    ext.insert(req_ctx);
85                }
86                is_match
87            }
88        }
89    }
90}
91
92impl<S> FromIterator<S> for SubdomainTrieMatcher
93where
94    S: AsRef<str>,
95{
96    #[inline]
97    fn from_iter<I: IntoIterator<Item = S>>(iter: I) -> Self {
98        SubdomainTrieMatcher::new(iter)
99    }
100}
101
#[cfg(test)]
mod subdomain_trie_tests {
    use super::*;

    /// Collects `domains` into a matcher and runs it against a request whose URI is `uri`,
    /// using a default context and no extensions.
    fn request_matches(domains: Vec<String>, uri: &str) -> bool {
        let matcher: SubdomainTrieMatcher = domains.into_iter().collect();
        let request = Request::builder().uri(uri).body(()).unwrap();
        let ctx = Context::default();
        matcher.matches(None, &ctx, &request)
    }

    #[test]
    fn test_reverse_domain() {
        let cases = [
            ("example.com", "com.example."),
            (".example.com", "com.example."),
            ("sub.example.com", "com.example.sub."),
            ("localhost", "localhost."),
            ("", "."),
        ];
        for &(input, expected) in &cases {
            assert_eq!(reverse_domain(input), expected, "input: {:?}", input);
        }
    }

    #[test]
    fn test_trie_matching() {
        let matcher = SubdomainTrieMatcher::new(["example.com", "sub.domain.org"]);

        // Exact entries, a leading-dot spelling, and a subdomain of an entry all match.
        let accepted = [
            "example.com",
            ".example.com",
            "sub.domain.org",
            "sub.example.com",
        ];
        for domain in accepted {
            assert!(matcher.is_match(domain), "expected a match for {:?}", domain);
        }

        // A parent of an entry, unrelated names, and the empty string do not match.
        let rejected = ["domain.org", "other.com", "", "localhost"];
        for domain in rejected {
            assert!(!matcher.is_match(domain), "expected no match for {:?}", domain);
        }
    }

    #[test]
    fn test_path_matching_with_trie() {
        let domains = vec!["example.com".to_owned(), "sub.domain.org".to_owned()];
        assert!(request_matches(domains, "sub.example.com"));
    }

    #[test]
    fn test_non_matching_path() {
        let domains = vec!["example.com".to_owned()];
        assert!(!request_matches(domains, "nonmatching.com"));
    }
}