rama_http/matcher/
subdomain_trie.rs1use crate::Request;
2use radix_trie::Trie;
3use rama_core::{Context, context::Extensions, matcher::Matcher};
4use rama_net::address::Host;
5use rama_net::http::RequestContext;
6
7#[derive(Debug, Clone)]
8pub struct SubdomainTrieMatcher {
9 trie: Trie<String, ()>,
10}
11
12impl SubdomainTrieMatcher {
13 pub fn new<I, S>(domains: I) -> Self
14 where
15 I: IntoIterator<Item = S>,
16 S: AsRef<str>,
17 {
18 let mut trie = Trie::new();
19 for domain in domains {
20 let reversed = reverse_domain(domain.as_ref());
21 trie.insert(reversed, ());
22 }
23 Self { trie }
24 }
25
26 pub fn is_match(&self, domain: impl AsRef<str>) -> bool {
31 let reversed = reverse_domain(domain.as_ref());
32 self.trie.get_ancestor(&reversed).is_some()
33 }
34}
35
/// Converts a domain into the reversed, dot-terminated key stored in the trie.
///
/// Labels are reversed so that a parent domain's key is a string *prefix* of
/// the keys of all its subdomains (`sub.example.com` -> `com.example.sub.`).
/// A trailing `.` is appended so that prefix matching only succeeds on whole
/// labels: `com.example.` is not a prefix of `com.examples.`.
///
/// Normalisation applied before reversing:
/// - one leading `.` is removed (`.example.com` is treated as `example.com`);
/// - one trailing `.` is removed, so the fully-qualified form `example.com.`
///   is treated like `example.com`;
/// - ASCII letters are lower-cased, because DNS names compare
///   case-insensitively (RFC 4343).
fn reverse_domain(domain: &str) -> String {
    let domain = domain.strip_prefix('.').unwrap_or(domain);
    let domain = domain.strip_suffix('.').unwrap_or(domain);

    // Build the key directly instead of collecting labels into a Vec first.
    let mut reversed = String::with_capacity(domain.len() + 1);
    for (index, label) in domain.rsplit('.').enumerate() {
        if index > 0 {
            reversed.push('.');
        }
        reversed.push_str(label);
    }
    reversed.push('.');
    reversed.make_ascii_lowercase();
    reversed
}

43impl<State, Body> Matcher<State, Request<Body>> for SubdomainTrieMatcher {
44 fn matches(
45 &self,
46 ext: Option<&mut Extensions>,
47 ctx: &Context<State>,
48 req: &Request<Body>,
49 ) -> bool {
50 let match_authority = |ctx: &RequestContext| match ctx.authority.host() {
51 Host::Name(domain) => {
52 let is_match = self.is_match(domain.as_str());
53 tracing::trace!(
54 "SubdomainTrieMatcher: matching domain = {}, matched = {}",
55 domain,
56 is_match
57 );
58 is_match
59 }
60 Host::Address(address) => {
61 tracing::trace!(
62 %address,
63 "SubdomainTrieMatcher: ignoring numeric address",
64 );
65 false
66 }
67 };
68
69 match ctx.get() {
70 Some(req_ctx) => match_authority(req_ctx),
71 None => {
72 let req_ctx: RequestContext = match (ctx, req).try_into() {
73 Ok(rc) => rc,
74 Err(err) => {
75 tracing::debug!(
76 error = %err,
77 "SubdomainTrieMatcher: failed to extract request context",
78 );
79 return false;
80 }
81 };
82 let is_match = match_authority(&req_ctx);
83 if let Some(ext) = ext {
84 ext.insert(req_ctx);
85 }
86 is_match
87 }
88 }
89 }
90}
91
92impl<S> FromIterator<S> for SubdomainTrieMatcher
93where
94 S: AsRef<str>,
95{
96 #[inline]
97 fn from_iter<I: IntoIterator<Item = S>>(iter: I) -> Self {
98 SubdomainTrieMatcher::new(iter)
99 }
100}
101
#[cfg(test)]
mod subdomain_trie_tests {
    use super::*;

    #[test]
    fn test_reverse_domain() {
        assert_eq!(reverse_domain("example.com"), "com.example.");
        assert_eq!(reverse_domain(".example.com"), "com.example.");
        assert_eq!(reverse_domain("sub.example.com"), "com.example.sub.");
        assert_eq!(reverse_domain("localhost"), "localhost.");
        assert_eq!(reverse_domain(""), ".");
    }

    #[test]
    fn test_trie_matching() {
        let matcher = SubdomainTrieMatcher::new(vec!["example.com", "sub.domain.org"]);
        assert!(matcher.is_match("example.com"));
        assert!(matcher.is_match(".example.com"));
        assert!(matcher.is_match("sub.domain.org"));
        assert!(matcher.is_match("sub.example.com"));
        assert!(!matcher.is_match("domain.org"));
        assert!(!matcher.is_match("other.com"));
        assert!(!matcher.is_match(""));
        assert!(!matcher.is_match("localhost"));
    }

    /// Matching must respect whole-label boundaries: a configured domain must
    /// not match hosts that merely share a textual prefix or suffix with it.
    #[test]
    fn test_label_boundary() {
        let matcher = SubdomainTrieMatcher::new(["example.com"]);
        assert!(matcher.is_match("a.b.example.com"));
        assert!(!matcher.is_match("notexample.com"));
        assert!(!matcher.is_match("example.community"));
        assert!(!matcher.is_match("com"));
    }

    #[test]
    fn test_path_matching_with_trie() {
        let domains: Vec<String> = vec!["example.com".to_owned(), "sub.domain.org".to_owned()];
        let matcher: SubdomainTrieMatcher = domains.into_iter().collect();

        let path = "sub.example.com";

        let request = Request::builder().uri(path).body(()).unwrap();
        let ctx = Context::default();

        assert!(matcher.matches(None, &ctx, &request));
    }

    #[test]
    fn test_non_matching_path() {
        let domains: Vec<String> = vec!["example.com".to_owned()];
        let matcher: SubdomainTrieMatcher = domains.into_iter().collect();

        let path = "nonmatching.com";

        let request = Request::builder().uri(path).body(()).unwrap();
        let ctx = Context::default();

        assert!(!matcher.matches(None, &ctx, &request));
    }

    /// Numeric hosts take the `Host::Address` branch and never match, even if
    /// the same textual address was configured as a "domain".
    #[test]
    fn test_numeric_address_never_matches() {
        let matcher = SubdomainTrieMatcher::new(["127.0.0.1"]);

        let request = Request::builder()
            .uri("http://127.0.0.1/")
            .body(())
            .unwrap();
        let ctx = Context::default();

        assert!(!matcher.matches(None, &ctx, &request));
    }
}