rk_utils/
trie.rs

1use std::collections::HashMap;
2
/// A single node of a [`Trie`].
///
/// A node links each outgoing path segment to the child reached through that
/// segment, and optionally holds a value of type `T`.
pub struct TrieNode<'a, T> {
    /// Child nodes, keyed by the path segment that leads to them.
    children: HashMap<&'a str, TrieNode<'a, T>>,
    /// Value attached to this node, or `None` when nothing is stored here.
    data: Option<T>,
}

// Written by hand instead of `#[derive(Default)]`: the derive would add a
// `T: Default` bound, which an empty node does not need.
impl<'a, T> Default for TrieNode<'a, T> {
    /// Builds a node that has no children and carries no data.
    fn default() -> Self {
        let children = HashMap::default();
        TrieNode {
            children,
            data: None,
        }
    }
}
17
18/// Trie is a data structure that stores a set of strings.
19/// It is used to find the longest match of a string.
20///
21/// # Example
22///
23/// ```rust
24/// use rk_utils::Trie;
25///
26/// let mut trie = Trie::new();
27/// trie.insert(vec!["a", "b", "c"], 1);
28/// trie.insert(vec!["a", "b", "d"], 2);
29///
30/// assert_eq!(trie.find_longest_match(vec!["a", "b", "c", "d"]), Some(&1));
31/// assert_eq!(trie.find_longest_match(vec!["a", "b", "d", "e"]), Some(&2));
32/// ```
33pub struct Trie<'a, T> {
34    root: TrieNode<'a, T>,
35}
36
37impl<'a, T> Default for Trie<'a, T> {
38    fn default() -> Self {
39        Self::new()
40    }
41}
42
43impl<'a, T> Trie<'a, T> {
44    /// Create a new Trie.
45    pub fn new() -> Self {
46        Self {
47            root: TrieNode::default(),
48        }
49    }
50
51    /// Insert a path of nodes with data.
52    pub fn insert(&mut self, path: Vec<&'a str>, data: T) {
53        let mut node = &mut self.root;
54        for key in path.iter() {
55            node = node.children.entry(key).or_default();
56        }
57        node.data = Some(data);
58    }
59
60    /// Find the longest match of a path of nodes.
61    /// It returns the data associated with the longest matched path.
62    pub fn find_longest_match(&'a self, request_path: Vec<&'a str>) -> Option<&T> {
63        let mut node = &self.root;
64        let mut last_matched_data: Option<&T> = None;
65
66        for key in request_path.iter() {
67            if let Some(next_node) = node.children.get(key) {
68                node = next_node;
69                if node.data.is_some() {
70                    last_matched_data = node.data.as_ref();
71                }
72            } else {
73                break;
74            }
75        }
76
77        last_matched_data
78    }
79}