radixtarget_rust/
dns.rs

// DnsRadixTree: a radix tree for efficient DNS hostname lookups.
// Hostnames are stored label by label in reverse order (TLD first, subdomains last)
// so that lookups can match a name against its parent domains hierarchically.
// Inspired by the Python implementation in dns.py.
4use crate::node::{BaseNode, DnsNode, hash_u64};
5use std::collections::HashSet;
6
/// Matching policy applied by a `DnsRadixTree`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum ScopeMode {
    /// Default behavior: a lookup may be satisfied by the queried name itself
    /// or by any stored parent domain.
    Normal,
    /// Restrictive matching: a lookup only succeeds when the queried name
    /// itself was stored; parent domains do not match their subdomains.
    Strict,
    /// Access-control-list behavior: inserting a name that an existing entry
    /// already covers is skipped, and inserting a parent domain clears the
    /// children of the node it lands on.
    Acl,
}
16
17#[derive(Debug, Clone)]
18pub struct DnsRadixTree {
19    pub root: DnsNode,
20    pub scope_mode: ScopeMode,
21}
22
23impl DnsRadixTree {
24    pub fn new(scope_mode: ScopeMode) -> Self {
25        DnsRadixTree {
26            root: DnsNode::new(),
27            scope_mode,
28        }
29    }
30
31    /// Insert a hostname into the tree, storing parts in reverse order for hierarchy.
32    /// Returns the canonicalized hostname after insertion, or None if already exists in ACL mode.
33    pub fn insert(&mut self, hostname: &str) -> Option<String> {
34        // If ACL mode is enabled, check if the host is already covered by the tree
35        if self.scope_mode == ScopeMode::Acl && self.get(hostname).is_some() {
36            return None; // Skip insertion if already covered
37        }
38
39        let parts: Vec<&str> = hostname.split('.').collect();
40        let mut node = &mut self.root;
41        for part in parts.iter().rev() {
42            node = node
43                .children
44                .entry(hash_u64(part))
45                .or_insert_with(|| Box::new(DnsNode::new()));
46        }
47        node.host = Some(hostname.to_string());
48
49        // If ACL mode is enabled, clear children of the inserted node
50        if self.scope_mode == ScopeMode::Acl {
51            node.clear();
52        }
53
54        Some(hostname.to_string())
55    }
56
57    /// Find the most specific matching entry for a given hostname.
58    /// If strict_scope is true, only exact matches are allowed.
59    /// Returns the canonicalized hostname if found.
60    pub fn get(&self, hostname: &str) -> Option<String> {
61        let parts: Vec<&str> = hostname.split('.').collect();
62        let mut node = &self.root;
63        let mut matched: Option<&String> = None;
64        for (i, part) in parts.iter().rev().enumerate() {
65            if let Some(child) = node.children.get(&hash_u64(part)) {
66                node = child;
67                if self.scope_mode == ScopeMode::Strict && i + 1 < parts.len() {
68                    continue;
69                }
70                if let Some(host) = &node.host {
71                    matched = Some(host);
72                }
73            } else {
74                break;
75            }
76        }
77        matched.cloned()
78    }
79
80    /// Delete a hostname from the tree.
81    /// Returns true if the hostname was found and deleted.
82    pub fn delete(&mut self, hostname: &str) -> bool {
83        let parts: Vec<&str> = hostname.split('.').collect();
84        Self::delete_rec(&mut self.root, &parts, 0)
85    }
86
87    /// Recursive helper for deletion.
88    fn delete_rec(node: &mut DnsNode, parts: &[&str], depth: usize) -> bool {
89        if depth == parts.len() {
90            if node.host.is_some() {
91                node.host = None;
92                return true;
93            }
94            return false;
95        }
96        let part = parts[parts.len() - 1 - depth];
97        if let Some(child) = node.children.get_mut(&hash_u64(part)) {
98            let deleted = Self::delete_rec(child, parts, depth + 1);
99            if child.children.is_empty() && child.host.is_none() {
100                node.children.remove(&hash_u64(part));
101            }
102            return deleted;
103        }
104        false
105    }
106
107    pub fn prune(&mut self) -> usize {
108        self.root.prune()
109    }
110
111    /// Get all hostnames stored in the tree
112    pub fn hosts(&self) -> HashSet<String> {
113        self.root.all_hosts()
114    }
115}
116
#[cfg(test)]
mod tests {
    use super::*;

    /// Canonical form the tests expect the tree to report for `host`.
    fn expected_canonical(host: &str) -> String {
        host.to_lowercase()
    }

    /// Inserts `host`, panicking if the tree declines it, and checks that the
    /// returned name is the expected canonical form.
    #[track_caller]
    fn insert_checked(tree: &mut DnsRadixTree, host: &str) {
        let canonical = tree.insert(host).unwrap();
        assert_eq!(
            canonical,
            expected_canonical(host),
            "insert({}) canonical",
            host
        );
    }

    /// Asserts that looking up `query` yields the stored entry `expected`.
    #[track_caller]
    fn assert_resolves(tree: &DnsRadixTree, query: &str, expected: &str) {
        assert_eq!(tree.get(query), Some(expected_canonical(expected)));
    }

    #[test]
    fn test_insert_and_get_basic() {
        let mut tree = DnsRadixTree::new(ScopeMode::Normal);
        insert_checked(&mut tree, "example.com");
        insert_checked(&mut tree, "api.test.www.example.com");

        assert_resolves(&tree, "example.com", "example.com");
        assert_resolves(&tree, "api.test.www.example.com", "api.test.www.example.com");
        // Subdomain matching
        assert_resolves(&tree, "wat.hm.api.test.www.example.com", "api.test.www.example.com");
        // No match
        assert_eq!(tree.get("notfound.com"), None);
    }

    #[test]
    fn test_strict_scope() {
        let mut tree = DnsRadixTree::new(ScopeMode::Strict);
        insert_checked(&mut tree, "example.com");
        insert_checked(&mut tree, "api.test.www.example.com");

        // Only exact matches
        assert_resolves(&tree, "example.com", "example.com");
        assert_resolves(&tree, "api.test.www.example.com", "api.test.www.example.com");
        assert_eq!(tree.get("wat.hm.api.test.www.example.com"), None);
        assert_eq!(tree.get("notfound.com"), None);
    }

    #[test]
    fn test_delete() {
        let mut tree = DnsRadixTree::new(ScopeMode::Normal);
        insert_checked(&mut tree, "example.com");
        insert_checked(&mut tree, "api.test.www.example.com");
        assert_resolves(&tree, "example.com", "example.com");

        assert!(tree.delete("example.com"));
        assert_eq!(tree.get("example.com"), None);
        // Deleting again should fail
        assert!(!tree.delete("example.com"));
        // Subdomain should still match the more specific entry
        assert_resolves(&tree, "wat.hm.api.test.www.example.com", "api.test.www.example.com");

        assert!(tree.delete("api.test.www.example.com"));
        assert_eq!(tree.get("wat.hm.api.test.www.example.com"), None);
    }

    #[test]
    fn test_subdomain_matching() {
        let mut tree = DnsRadixTree::new(ScopeMode::Normal);
        let stored = [
            "evilcorp.com",
            "www.evilcorp.com",
            "test.www.evilcorp.com",
            "api.test.www.evilcorp.com",
        ];
        for host in stored.iter() {
            insert_checked(&mut tree, host);
        }

        // Every stored name resolves to itself, most specific first.
        for host in stored.iter().rev() {
            assert_resolves(&tree, host, host);
        }

        // Subdomain matching
        assert_resolves(&tree, "wat.hm.api.test.www.evilcorp.com", "api.test.www.evilcorp.com");
        assert_resolves(&tree, "asdf.test.www.evilcorp.com", "test.www.evilcorp.com");
        assert_resolves(&tree, "asdf.evilcorp.com", "evilcorp.com");
    }

    #[test]
    fn test_no_match() {
        let mut tree = DnsRadixTree::new(ScopeMode::Normal);
        insert_checked(&mut tree, "example.com");
        assert_eq!(tree.get("notfound.com"), None);
        assert_eq!(tree.get("com"), None);
    }

    #[test]
    fn test_top_level_domain() {
        let mut tree = DnsRadixTree::new(ScopeMode::Normal);
        // insert a top level domain
        insert_checked(&mut tree, "com");
        // get subdomains
        assert_resolves(&tree, "www.example.com", "com");
        assert_resolves(&tree, "example.com", "com");
        // get the top level domain
        assert_resolves(&tree, "com", "com");
        // empty string should not match
        assert_eq!(tree.get(""), None);
    }

    #[test]
    fn test_clear_method() {
        use crate::node::BaseNode;
        use rand::seq::SliceRandom;
        use rand::thread_rng;

        let mut tree = DnsRadixTree::new(ScopeMode::Normal);

        let mut hosts = vec![
            "example.com",
            "www.example.com",
            "api.example.com",
            "mail.example.com",
            "secure.api.example.com",
            "dev.api.example.com",
            "test.dev.api.example.com",
            "staging.dev.api.example.com",
            "other.com",
            "sub.other.com",
        ];

        // Insert in a random order so the result cannot depend on insertion order.
        hosts.shuffle(&mut thread_rng());
        for host in &hosts {
            tree.insert(host);
        }

        // Verify all hosts are present
        for host in &hosts {
            assert!(tree.get(host).is_some(), "Host {} should be present", host);
        }

        // Walk down to the node for "api.example.com" (labels TLD-first).
        let mut node = &mut tree.root;
        for label in "api.example.com".rsplit('.') {
            node = node
                .children
                .get_mut(&hash_u64(label))
                .expect("Node should exist");
        }

        // Clearing that node should drop everything stored beneath it.
        let cleared_hosts = node.clear();

        let expected_cleared = vec![
            "secure.api.example.com",
            "dev.api.example.com",
            "test.dev.api.example.com",
            "staging.dev.api.example.com",
        ];

        assert_eq!(
            cleared_hosts.len(),
            expected_cleared.len(),
            "Should have cleared {} hosts, got {}: {:?}",
            expected_cleared.len(),
            cleared_hosts.len(),
            cleared_hosts
        );

        // Check that all expected hosts were cleared
        for expected in &expected_cleared {
            assert!(
                cleared_hosts.contains(&expected.to_string()),
                "Should have cleared {}",
                expected
            );
        }

        // Cleared hosts either stop resolving or fall back to the parent entry.
        for cleared in &expected_cleared {
            assert!(
                matches!(
                    tree.get(cleared).as_deref(),
                    None | Some("api.example.com")
                ),
                "Cleared host {} should not be accessible or should fall back to parent",
                cleared
            );
        }

        // The cleared node's own entry is still accessible.
        assert_resolves(&tree, "api.example.com", "api.example.com");

        // Unrelated hosts are untouched.
        for host in [
            "example.com",
            "www.example.com",
            "mail.example.com",
            "other.com",
            "sub.other.com",
        ]
        .iter()
        {
            assert_resolves(&tree, host, host);
        }
    }

    #[test]
    fn test_acl_mode_skip_existing() {
        let mut tree = DnsRadixTree::new(ScopeMode::Acl);

        // First insertion should succeed
        assert_eq!(tree.insert("example.com"), Some("example.com".to_string()));

        // Second insertion of same host should return None
        assert_eq!(tree.insert("example.com"), None);

        // Different host should still work
        assert_eq!(tree.insert("other.com"), Some("other.com".to_string()));

        // Verify both hosts are accessible
        assert_resolves(&tree, "example.com", "example.com");
        assert_resolves(&tree, "other.com", "other.com");
    }

    #[test]
    fn test_acl_mode_skip_children() {
        let mut tree = DnsRadixTree::new(ScopeMode::Acl);

        // Insert parent domain first
        assert_eq!(tree.insert("example.com"), Some("example.com".to_string()));
        assert_resolves(&tree, "example.com", "example.com");

        // Insert child domain should return None (already covered by parent)
        assert_eq!(tree.insert("api.example.com"), None);

        // Get child domain should return parent
        assert_resolves(&tree, "api.example.com", "example.com");
    }

    #[test]
    fn test_acl_mode_clear_children() {
        let mut tree = DnsRadixTree::new(ScopeMode::Acl);
        let children = ["api.example.com", "www.example.com", "mail.example.com"];

        // Insert child domains first
        for child in children.iter() {
            tree.insert(child);
        }

        // Verify children are accessible
        for child in children.iter() {
            assert_resolves(&tree, child, child);
        }

        // Insert parent domain - should clear children
        assert_eq!(tree.insert("example.com"), Some("example.com".to_string()));

        // Parent should be accessible
        assert_resolves(&tree, "example.com", "example.com");

        // Children should now fall back to parent
        for child in children.iter() {
            assert_resolves(&tree, child, "example.com");
        }
    }
}