1#[cfg(feature = "debug")]
25mod prettyprint;
26
27use std::collections::HashMap;
28
/// Suffix that marks a route key as a wildcard (e.g. `"/api/*"`).
///
/// `Trie::insert` and `Trie::remove` strip this suffix from the key and flag the entry as a
/// wildcard; `Trie::get` passes its path through unchanged and does not interpret it.
const WILDCARD_SUFFIX: &str = "/*";
31
/// A node of a compressed (radix) trie mapping string paths to values.
///
/// A node's full key is the concatenation of the `prefix` fragments on the way down from the
/// root. Each node can hold two independent values: one for lookups that end exactly at its
/// key, and a wildcard that also serves as a fallback for lookups continuing below it.
#[derive(Debug, Clone)]
struct RadixNode<T> {
    /// Fragment of the full key contributed by this node.
    prefix: String,
    /// Children keyed by the first `char` of their own `prefix`.
    children: HashMap<char, RadixNode<T>>,
    /// Value for a lookup that ends exactly at this node.
    exact_value: Option<T>,
    /// Value for a lookup that ends at this node when there is no exact value, and the
    /// fallback for lookups that continue into a new path segment below this node.
    wildcard_value: Option<T>,
}

impl<T> RadixNode<T> {
    /// Path-segment separator. A node's wildcard is only inherited by lookups whose remainder
    /// below the node starts with this character, so a wildcard keyed `"/api"` does not match
    /// `"/apix"`. It should agree with the first character of `WILDCARD_SUFFIX`.
    const SEGMENT_SEPARATOR: char = '/';

    /// Creates a node with the given prefix and no values or children.
    fn new(prefix: String) -> Self {
        Self {
            prefix,
            children: HashMap::new(),
            exact_value: None,
            wildcard_value: None,
        }
    }

    /// Stores `value` under `path`, interpreted relative to this node.
    ///
    /// An empty `path` stores on this node itself. If `path` diverges part-way through this
    /// node's prefix, the node is split so that it ends exactly at the divergence point. An
    /// existing value of the same kind (exact or wildcard) is overwritten.
    fn insert(&mut self, path: &str, value: T, is_wildcard: bool) {
        if path.is_empty() {
            self.store_value(value, is_wildcard);
            return;
        }

        let common_length = self.common_prefix_len(path);

        if common_length < self.prefix.len() {
            // Push the unmatched tail of our prefix down into a new child.
            self.split_at(common_length);
        }

        if common_length < path.len() {
            self.insert_in_child(&path[common_length..], value, is_wildcard);
        } else {
            self.store_value(value, is_wildcard);
        }
    }

    /// Looks up `path` relative to this node.
    ///
    /// Returns the exact value if one is stored for `path`. Otherwise it returns the most
    /// specific applicable wildcard, or `None`.
    fn get(&self, path: &str) -> Option<&T> {
        self.get_with_fallback(path, None)
    }

    /// Removes and returns the exact or wildcard value stored for `path`, relative to this
    /// node. Returns `None` if there is no such value.
    ///
    /// After a successful removal, any child left with no values and no children is pruned.
    /// A value-less child with a single child is merged into that child, so the tree stays
    /// compressed. This node itself is never pruned or merged; that is its parent's job.
    fn remove(&mut self, path: &str, is_wildcard: bool) -> Option<T> {
        if path.is_empty() {
            return self.take_value(is_wildcard);
        }

        let common_length = self.common_prefix_len(path);
        if common_length != self.prefix.len() {
            // `path` diverges inside (or stops short of) our prefix: nothing is stored for it.
            return None;
        }

        let remaining_path = &path[common_length..];
        if remaining_path.is_empty() {
            self.take_value(is_wildcard)
        } else {
            self.remove_from_child(remaining_path, is_wildcard)
        }
    }

    /// Writes `value` into the wildcard or exact slot, replacing what was there.
    fn store_value(&mut self, value: T, is_wildcard: bool) {
        if is_wildcard {
            self.wildcard_value = Some(value);
        } else {
            self.exact_value = Some(value);
        }
    }

    /// Moves the value out of the wildcard or exact slot, leaving `None`.
    fn take_value(&mut self, is_wildcard: bool) -> Option<T> {
        if is_wildcard {
            self.wildcard_value.take()
        } else {
            self.exact_value.take()
        }
    }

    /// Returns `true` if this node stores no values and has no children.
    fn is_vacant(&self) -> bool {
        self.exact_value.is_none() && self.wildcard_value.is_none() && self.children.is_empty()
    }

    /// Returns the length **in bytes** of the longest common prefix of `self.prefix` and
    /// `path`.
    ///
    /// The result always lies on a `char` boundary of both strings, so it can be used to
    /// slice either one. Counting `char`s instead would mis-slice, or panic on, multi-byte
    /// UTF-8 input.
    fn common_prefix_len(&self, path: &str) -> usize {
        self.prefix
            .chars()
            .zip(path.chars())
            .take_while(|(a, b)| a == b)
            .map(|(c, _)| c.len_utf8())
            .sum()
    }

    /// Returns the result for a lookup ending exactly at this node.
    ///
    /// That is the exact value, else this node's wildcard, else the inherited `fallback`.
    fn value_here<'a>(&'a self, fallback: Option<&'a T>) -> Option<&'a T> {
        self.exact_value
            .as_ref()
            .or(self.wildcard_value.as_ref())
            .or(fallback)
    }

    /// Recursive worker for [`get`](Self::get).
    ///
    /// `fallback` is the most specific wildcard inherited from ancestors. If `path` diverges
    /// inside this node's prefix, only `fallback` applies. Otherwise the lookup either ends
    /// here or continues into the matching child. In the latter case, this node's wildcard
    /// replaces `fallback`, but only when the rest of the path starts a new segment.
    fn get_with_fallback<'a>(&'a self, path: &str, fallback: Option<&'a T>) -> Option<&'a T> {
        if path.is_empty() {
            return self.value_here(fallback);
        }

        let common_length = self.common_prefix_len(path);
        if common_length != self.prefix.len() {
            return fallback;
        }

        let remaining_path = &path[common_length..];
        if remaining_path.is_empty() {
            return self.value_here(fallback);
        }

        // A wildcard registered as "/api/*" lives at key "/api"; it must cover "/api/..." but
        // not "/apix", so it is only handed down when a new path segment begins here.
        let inherited = if remaining_path.starts_with(Self::SEGMENT_SEPARATOR) {
            self.wildcard_value.as_ref().or(fallback)
        } else {
            fallback
        };
        self.search_in_child(remaining_path, inherited)
    }

    /// Inserts into the child selected by the first character of the non-empty
    /// `remaining_path`, creating that child if needed.
    fn insert_in_child(&mut self, remaining_path: &str, value: T, is_wildcard: bool) {
        let first_char = remaining_path
            .chars()
            .next()
            .expect("insert_in_child is only called with a non-empty path");
        self.children
            .entry(first_char)
            .or_insert_with(|| RadixNode::new(remaining_path.to_string()))
            .insert(remaining_path, value, is_wildcard);
    }

    /// Continues a lookup in the child selected by the first character of `remaining_path`.
    /// Returns `fallback` when there is no such child.
    fn search_in_child<'a>(
        &'a self,
        remaining_path: &str,
        fallback: Option<&'a T>,
    ) -> Option<&'a T> {
        let first_char = remaining_path
            .chars()
            .next()
            .expect("search_in_child is only called with a non-empty path");
        if let Some(child) = self.children.get(&first_char) {
            child.get_with_fallback(remaining_path, fallback)
        } else {
            fallback
        }
    }

    /// Removes from the child selected by the first character of `remaining_path`.
    ///
    /// On success, the child is dropped if it became vacant. Otherwise it is re-compressed
    /// if it now holds no values and has a single child.
    fn remove_from_child(&mut self, remaining_path: &str, is_wildcard: bool) -> Option<T> {
        let first_char = remaining_path.chars().next()?;
        let child = self.children.get_mut(&first_char)?;
        let removed = child.remove(remaining_path, is_wildcard)?;
        if child.is_vacant() {
            self.children.remove(&first_char);
        } else {
            child.merge_with_only_child();
        }
        Some(removed)
    }

    /// Absorbs this node's only child, concatenating the prefixes, when this node stores no
    /// values. Otherwise does nothing.
    ///
    /// The merged node keeps this node's (non-empty) prefix start, so its key in the parent's
    /// `children` map stays valid. Nodes with an empty prefix are left alone for that reason.
    fn merge_with_only_child(&mut self) {
        if self.prefix.is_empty()
            || self.exact_value.is_some()
            || self.wildcard_value.is_some()
            || self.children.len() != 1
        {
            return;
        }

        if let Some((_, child)) = std::mem::take(&mut self.children).into_iter().next() {
            let RadixNode {
                prefix,
                children,
                exact_value,
                wildcard_value,
            } = child;
            self.prefix.push_str(&prefix);
            self.children = children;
            self.exact_value = exact_value;
            self.wildcard_value = wildcard_value;
        }
    }

    /// Splits this node at byte offset `split_position` of its prefix.
    ///
    /// The node keeps `prefix[..split_position]`. A new child receives the rest of the prefix
    /// together with all of this node's values and children. `split_position` must be on a
    /// `char` boundary, as `String::split_off` panics otherwise; `common_prefix_len`
    /// guarantees this. Does nothing if the split point is at or past the end of the prefix.
    fn split_at(&mut self, split_position: usize) {
        if split_position >= self.prefix.len() {
            return;
        }

        let suffix = self.prefix.split_off(split_position);
        let first_char = suffix
            .chars()
            .next()
            .expect("suffix is non-empty because split_position < prefix.len()");

        let mut new_child = RadixNode::new(suffix);
        new_child.children = std::mem::take(&mut self.children);
        new_child.exact_value = self.exact_value.take();
        new_child.wildcard_value = self.wildcard_value.take();

        self.children.insert(first_char, new_child);
    }
}
216
217#[derive(Debug)]
219pub struct Trie<T>(RadixNode<T>);
220
221impl<T> Default for Trie<T> {
222 fn default() -> Self {
223 Self(RadixNode::new(String::new()))
224 }
225}
226
227impl<T> Trie<T> {
228 pub fn new() -> Self {
230 Self::default()
231 }
232
233 pub fn insert(&mut self, path: &str, value: T) {
245 let (clean_path, is_wildcard) = Self::parse_path(path);
246 self.0.insert(clean_path, value, is_wildcard);
247 }
248
249 pub fn get<'a>(&'a self, path: &str) -> Option<&'a T> {
261 self.0.get(path)
262 }
263
264 pub fn remove(&mut self, path: &str) -> Option<T> {
266 let (clean_path, is_wildcard) = Self::parse_path(path);
267 self.0.remove(clean_path, is_wildcard)
268 }
269
270 fn parse_path(path: &str) -> (&str, bool) {
272 if let Some(prefix) = path.strip_suffix(WILDCARD_SUFFIX) {
273 (prefix, true)
274 } else {
275 (path, false)
276 }
277 }
278
279 fn is_empty(&self) -> bool {
281 self.0.children.is_empty()
282 && self.0.exact_value.is_none()
283 && self.0.wildcard_value.is_none()
284 }
285}
286
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a trie by inserting each `(path, handler)` pair in order.
    fn build(routes: &[(&str, &'static str)]) -> Trie<&'static str> {
        let mut trie = Trie::new();
        for &(path, handler) in routes {
            trie.insert(path, handler);
        }
        trie
    }

    /// Asserts that looking up each `path` in `trie` yields the paired `expected` handler.
    fn assert_lookups(trie: &Trie<&'static str>, cases: &[(&str, Option<&'static str>)]) {
        for &(path, expected) in cases {
            assert_eq!(trie.get(path), expected.as_ref(), "lookup of {:?}", path);
        }
    }

    #[test]
    fn test_exact_path_matching() {
        let trie = build(&[
            ("/api/users", "users_handler"),
            ("/api/posts", "posts_handler"),
        ]);
        assert_lookups(
            &trie,
            &[
                ("/api/users", Some("users_handler")),
                ("/api/posts", Some("posts_handler")),
                ("/api/other", None),
            ],
        );
    }

    #[test]
    fn test_wildcard_matching() {
        let trie = build(&[("/api/*", "api_handler")]);
        assert_lookups(
            &trie,
            &[
                ("/api/users", Some("api_handler")),
                ("/api/posts/123", Some("api_handler")),
                ("/auth/login", None),
            ],
        );
    }

    #[test]
    fn test_exact_takes_precedence_over_wildcard() {
        // The wildcard is registered first; the later exact route must still win on its path.
        let trie = build(&[
            ("/api/*", "wildcard_handler"),
            ("/api/users", "exact_handler"),
        ]);
        assert_lookups(
            &trie,
            &[
                ("/api/users", Some("exact_handler")),
                ("/api/posts", Some("wildcard_handler")),
            ],
        );
    }

    #[test]
    fn test_path_compression() {
        let routes = [
            ("/api/v1/users", "v1_users"),
            ("/api/v1/posts", "v1_posts"),
            ("/api/v2/users", "v2_users"),
        ];
        let trie = build(&routes);
        for &(path, handler) in &routes {
            assert_eq!(trie.get(path), Some(&handler));
        }
    }

    #[test]
    fn test_removal() {
        let mut trie = build(&[("/api/users", "handler")]);

        assert_eq!(trie.get("/api/users"), Some(&"handler"));
        assert_eq!(trie.remove("/api/users"), Some("handler"));
        assert_eq!(trie.get("/api/users"), None);
    }

    #[test]
    fn test_wildcard_removal() {
        let mut trie = build(&[("/api/*", "handler")]);

        assert_eq!(trie.get("/api/users"), Some(&"handler"));
        assert_eq!(trie.remove("/api/*"), Some("handler"));
        assert_eq!(trie.get("/api/users"), None);
    }

    #[test]
    fn test_root_path() {
        assert_lookups(&build(&[("/", "root_handler")]), &[("/", Some("root_handler"))]);
    }

    #[test]
    fn test_root_wildcard() {
        assert_lookups(&build(&[("/*", "root_handler")]), &[("/", Some("root_handler"))]);
    }

    #[test]
    fn test_empty_path() {
        assert_lookups(&build(&[("", "empty_handler")]), &[("", Some("empty_handler"))]);
    }

    #[test]
    fn test_common_prefix() {
        let routes = [
            ("long_prefix_one", "one"),
            ("long_prefix_two", "two"),
            ("long_prefix_three", "three"),
        ];
        let trie = build(&routes);
        for &(path, handler) in &routes {
            assert_eq!(trie.get(path), Some(&handler));
        }
    }
}