1use crate::node::{BaseNode, DnsNode, hash_u64};
5use std::collections::HashSet;
6
/// Matching policy applied by [`DnsRadixTree`].
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum ScopeMode {
    /// A lookup returns the deepest stored host found along the queried
    /// name's label path, so subdomains resolve to an inserted parent.
    Normal,
    /// A lookup only considers the node reached by the full queried name;
    /// hosts stored on intermediate (parent) nodes are ignored.
    Strict,
    /// Lookups behave like `Normal`; `insert` additionally skips hostnames
    /// that already resolve and calls `clear()` on each newly inserted node.
    Acl,
}
16
17#[derive(Debug, Clone)]
18pub struct DnsRadixTree {
19 pub root: DnsNode,
20 pub scope_mode: ScopeMode,
21}
22
23impl DnsRadixTree {
24 pub fn new(scope_mode: ScopeMode) -> Self {
25 DnsRadixTree {
26 root: DnsNode::new(),
27 scope_mode,
28 }
29 }
30
31 pub fn insert(&mut self, hostname: &str) -> Option<String> {
34 if self.scope_mode == ScopeMode::Acl && self.get(hostname).is_some() {
36 return None; }
38
39 let parts: Vec<&str> = hostname.split('.').collect();
40 let mut node = &mut self.root;
41 for part in parts.iter().rev() {
42 node = node
43 .children
44 .entry(hash_u64(part))
45 .or_insert_with(|| Box::new(DnsNode::new()));
46 }
47 node.host = Some(hostname.to_string());
48
49 if self.scope_mode == ScopeMode::Acl {
51 node.clear();
52 }
53
54 Some(hostname.to_string())
55 }
56
57 pub fn get(&self, hostname: &str) -> Option<String> {
61 let parts: Vec<&str> = hostname.split('.').collect();
62 let mut node = &self.root;
63 let mut matched: Option<&String> = None;
64 for (i, part) in parts.iter().rev().enumerate() {
65 if let Some(child) = node.children.get(&hash_u64(part)) {
66 node = child;
67 if self.scope_mode == ScopeMode::Strict && i + 1 < parts.len() {
68 continue;
69 }
70 if let Some(host) = &node.host {
71 matched = Some(host);
72 }
73 } else {
74 break;
75 }
76 }
77 matched.cloned()
78 }
79
80 pub fn delete(&mut self, hostname: &str) -> bool {
83 let parts: Vec<&str> = hostname.split('.').collect();
84 Self::delete_rec(&mut self.root, &parts, 0)
85 }
86
87 fn delete_rec(node: &mut DnsNode, parts: &[&str], depth: usize) -> bool {
89 if depth == parts.len() {
90 if node.host.is_some() {
91 node.host = None;
92 return true;
93 }
94 return false;
95 }
96 let part = parts[parts.len() - 1 - depth];
97 if let Some(child) = node.children.get_mut(&hash_u64(part)) {
98 let deleted = Self::delete_rec(child, parts, depth + 1);
99 if child.children.is_empty() && child.host.is_none() {
100 node.children.remove(&hash_u64(part));
101 }
102 return deleted;
103 }
104 false
105 }
106
107 pub fn prune(&mut self) -> usize {
108 self.root.prune()
109 }
110
111 pub fn hosts(&self) -> HashSet<String> {
113 self.root.all_hosts()
114 }
115}
116
117#[cfg(test)]
118mod tests {
119 use super::*;
120
121 fn expected_canonical(host: &str) -> String {
122 host.to_lowercase()
123 }
124
125 #[test]
126 fn test_insert_and_get_basic() {
127 let mut tree = DnsRadixTree::new(ScopeMode::Normal);
128 let canonical1 = tree.insert("example.com").unwrap();
129 assert_eq!(
130 canonical1,
131 expected_canonical("example.com"),
132 "insert(example.com) canonical"
133 );
134 let canonical2 = tree.insert("api.test.www.example.com").unwrap();
135 assert_eq!(
136 canonical2,
137 expected_canonical("api.test.www.example.com"),
138 "insert(api.test.www.example.com) canonical"
139 );
140 assert_eq!(
141 tree.get("example.com"),
142 Some(expected_canonical("example.com"))
143 );
144 assert_eq!(
145 tree.get("api.test.www.example.com"),
146 Some(expected_canonical("api.test.www.example.com"))
147 );
148 assert_eq!(
150 tree.get("wat.hm.api.test.www.example.com"),
151 Some(expected_canonical("api.test.www.example.com"))
152 );
153 assert_eq!(tree.get("notfound.com"), None);
155 }
156
157 #[test]
158 fn test_strict_scope() {
159 let mut tree = DnsRadixTree::new(ScopeMode::Strict);
160 let canonical1 = tree.insert("example.com").unwrap();
161 assert_eq!(
162 canonical1,
163 expected_canonical("example.com"),
164 "insert(example.com) canonical"
165 );
166 let canonical2 = tree.insert("api.test.www.example.com").unwrap();
167 assert_eq!(
168 canonical2,
169 expected_canonical("api.test.www.example.com"),
170 "insert(api.test.www.example.com) canonical"
171 );
172 assert_eq!(
174 tree.get("example.com"),
175 Some(expected_canonical("example.com"))
176 );
177 assert_eq!(
178 tree.get("api.test.www.example.com"),
179 Some(expected_canonical("api.test.www.example.com"))
180 );
181 assert_eq!(tree.get("wat.hm.api.test.www.example.com"), None);
182 assert_eq!(tree.get("notfound.com"), None);
183 }
184
185 #[test]
186 fn test_delete() {
187 let mut tree = DnsRadixTree::new(ScopeMode::Normal);
188 let canonical1 = tree.insert("example.com").unwrap();
189 assert_eq!(
190 canonical1,
191 expected_canonical("example.com"),
192 "insert(example.com) canonical"
193 );
194 let canonical2 = tree.insert("api.test.www.example.com").unwrap();
195 assert_eq!(
196 canonical2,
197 expected_canonical("api.test.www.example.com"),
198 "insert(api.test.www.example.com) canonical"
199 );
200 assert_eq!(
201 tree.get("example.com"),
202 Some(expected_canonical("example.com"))
203 );
204 assert!(tree.delete("example.com"));
205 assert_eq!(tree.get("example.com"), None);
206 assert!(!tree.delete("example.com"));
208 assert_eq!(
210 tree.get("wat.hm.api.test.www.example.com"),
211 Some(expected_canonical("api.test.www.example.com"))
212 );
213 assert!(tree.delete("api.test.www.example.com"));
214 assert_eq!(tree.get("wat.hm.api.test.www.example.com"), None);
215 }
216
217 #[test]
218 fn test_subdomain_matching() {
219 let mut tree = DnsRadixTree::new(ScopeMode::Normal);
220 let canonical1 = tree.insert("evilcorp.com").unwrap();
221 assert_eq!(
222 canonical1,
223 expected_canonical("evilcorp.com"),
224 "insert(evilcorp.com) canonical"
225 );
226 let canonical2 = tree.insert("www.evilcorp.com").unwrap();
227 assert_eq!(
228 canonical2,
229 expected_canonical("www.evilcorp.com"),
230 "insert(www.evilcorp.com) canonical"
231 );
232 let canonical3 = tree.insert("test.www.evilcorp.com").unwrap();
233 assert_eq!(
234 canonical3,
235 expected_canonical("test.www.evilcorp.com"),
236 "insert(test.www.evilcorp.com) canonical"
237 );
238 let canonical4 = tree.insert("api.test.www.evilcorp.com").unwrap();
239 assert_eq!(
240 canonical4,
241 expected_canonical("api.test.www.evilcorp.com"),
242 "insert(api.test.www.evilcorp.com) canonical"
243 );
244 assert_eq!(
245 tree.get("api.test.www.evilcorp.com"),
246 Some(expected_canonical("api.test.www.evilcorp.com"))
247 );
248 assert_eq!(
249 tree.get("test.www.evilcorp.com"),
250 Some(expected_canonical("test.www.evilcorp.com"))
251 );
252 assert_eq!(
253 tree.get("www.evilcorp.com"),
254 Some(expected_canonical("www.evilcorp.com"))
255 );
256 assert_eq!(
257 tree.get("evilcorp.com"),
258 Some(expected_canonical("evilcorp.com"))
259 );
260 assert_eq!(
262 tree.get("wat.hm.api.test.www.evilcorp.com"),
263 Some(expected_canonical("api.test.www.evilcorp.com"))
264 );
265 assert_eq!(
266 tree.get("asdf.test.www.evilcorp.com"),
267 Some(expected_canonical("test.www.evilcorp.com"))
268 );
269 assert_eq!(
270 tree.get("asdf.evilcorp.com"),
271 Some(expected_canonical("evilcorp.com"))
272 );
273 }
274
275 #[test]
276 fn test_no_match() {
277 let mut tree = DnsRadixTree::new(ScopeMode::Normal);
278 let canonical = tree.insert("example.com").unwrap();
279 assert_eq!(
280 canonical,
281 expected_canonical("example.com"),
282 "insert(example.com) canonical"
283 );
284 assert_eq!(tree.get("notfound.com"), None);
285 assert_eq!(tree.get("com"), None);
286 }
287
288 #[test]
289 fn test_top_level_domain() {
290 let mut tree = DnsRadixTree::new(ScopeMode::Normal);
291 let canonical = tree.insert("com").unwrap();
293 assert_eq!(
294 canonical,
295 expected_canonical("com"),
296 "insert(com) canonical"
297 );
298 assert_eq!(tree.get("www.example.com"), Some(expected_canonical("com")));
300 assert_eq!(tree.get("example.com"), Some(expected_canonical("com")));
301 assert_eq!(tree.get("com"), Some(expected_canonical("com")));
303 assert_eq!(tree.get(""), None);
305 }
306
307 #[test]
308 fn test_clear_method() {
309 use crate::node::BaseNode;
310
311 let mut tree = DnsRadixTree::new(ScopeMode::Normal);
312
313 let mut hosts = vec![
315 "example.com",
316 "www.example.com",
317 "api.example.com",
318 "mail.example.com",
319 "secure.api.example.com",
320 "dev.api.example.com",
321 "test.dev.api.example.com",
322 "staging.dev.api.example.com",
323 "other.com",
324 "sub.other.com",
325 ];
326
327 use rand::seq::SliceRandom;
329 use rand::thread_rng;
330 hosts.shuffle(&mut thread_rng());
331
332 for host in &hosts {
333 tree.insert(host);
334 }
335
336 for host in &hosts {
338 assert!(tree.get(host).is_some(), "Host {} should be present", host);
339 }
340
341 let parts: Vec<&str> = "api.example.com".split('.').collect();
343 let mut node = &mut tree.root;
344 for part in parts.iter().rev() {
345 node = node
346 .children
347 .get_mut(&hash_u64(part))
348 .expect("Node should exist");
349 }
350
351 let cleared_hosts = node.clear();
353
354 let expected_cleared = vec![
357 "secure.api.example.com",
358 "dev.api.example.com",
359 "test.dev.api.example.com",
360 "staging.dev.api.example.com",
361 ];
362
363 assert_eq!(
364 cleared_hosts.len(),
365 expected_cleared.len(),
366 "Should have cleared {} hosts, got {}: {:?}",
367 expected_cleared.len(),
368 cleared_hosts.len(),
369 cleared_hosts
370 );
371
372 for expected in &expected_cleared {
374 assert!(
375 cleared_hosts.contains(&expected.to_string()),
376 "Should have cleared {}",
377 expected
378 );
379 }
380
381 for cleared in &expected_cleared {
383 assert!(
384 tree.get(cleared).is_none()
385 || tree.get(cleared) == Some("api.example.com".to_string()),
386 "Cleared host {} should not be accessible or should fall back to parent",
387 cleared
388 );
389 }
390
391 assert_eq!(
393 tree.get("api.example.com"),
394 Some("api.example.com".to_string())
395 );
396
397 assert_eq!(tree.get("example.com"), Some("example.com".to_string()));
399 assert_eq!(
400 tree.get("www.example.com"),
401 Some("www.example.com".to_string())
402 );
403 assert_eq!(
404 tree.get("mail.example.com"),
405 Some("mail.example.com".to_string())
406 );
407 assert_eq!(tree.get("other.com"), Some("other.com".to_string()));
408 assert_eq!(tree.get("sub.other.com"), Some("sub.other.com".to_string()));
409 }
410
411 #[test]
412 fn test_acl_mode_skip_existing() {
413 let mut tree = DnsRadixTree::new(ScopeMode::Acl);
414
415 let result1 = tree.insert("example.com");
417 assert_eq!(result1, Some("example.com".to_string()));
418
419 let result2 = tree.insert("example.com");
421 assert_eq!(result2, None);
422
423 let result3 = tree.insert("other.com");
425 assert_eq!(result3, Some("other.com".to_string()));
426
427 assert_eq!(tree.get("example.com"), Some("example.com".to_string()));
429 assert_eq!(tree.get("other.com"), Some("other.com".to_string()));
430 }
431
432 #[test]
433 fn test_acl_mode_skip_children() {
434 let mut tree = DnsRadixTree::new(ScopeMode::Acl);
435
436 assert_eq!(tree.insert("example.com"), Some("example.com".to_string()));
438 assert_eq!(tree.get("example.com"), Some("example.com".to_string()));
439
440 assert_eq!(tree.insert("api.example.com"), None);
442
443 assert_eq!(tree.get("api.example.com"), Some("example.com".to_string()));
445 }
446
447 #[test]
448 fn test_acl_mode_clear_children() {
449 let mut tree = DnsRadixTree::new(ScopeMode::Acl);
450
451 tree.insert("api.example.com");
453 tree.insert("www.example.com");
454 tree.insert("mail.example.com");
455
456 assert_eq!(
458 tree.get("api.example.com"),
459 Some("api.example.com".to_string())
460 );
461 assert_eq!(
462 tree.get("www.example.com"),
463 Some("www.example.com".to_string())
464 );
465 assert_eq!(
466 tree.get("mail.example.com"),
467 Some("mail.example.com".to_string())
468 );
469
470 let result = tree.insert("example.com");
472 assert_eq!(result, Some("example.com".to_string()));
473
474 assert_eq!(tree.get("example.com"), Some("example.com".to_string()));
476
477 assert_eq!(tree.get("api.example.com"), Some("example.com".to_string()));
479 assert_eq!(tree.get("www.example.com"), Some("example.com".to_string()));
480 assert_eq!(
481 tree.get("mail.example.com"),
482 Some("example.com".to_string())
483 );
484 }
485}