awesome_trie/
lib.rs

1
/// A single node of the byte-level trie.
///
/// Every byte of a key selects exactly one child: the high nibble of the byte
/// indexes the outer array and the low nibble indexes the inner array.
#[derive(Default, Debug)]
struct AwesomeTrieNode {
    /// Child nodes, addressed as `children[byte >> 4][byte & 15]`.
    children: [[Option<Box<AwesomeTrieNode>>; 16]; 16],
    /// Id of the key that ends at this node; `0` means "no key ends here".
    id: u16,
}

impl AwesomeTrieNode {
    /// Creates a node with no children and no id.
    fn new() -> Self {
        // The derived `Default` already produces all-`None` children and
        // `id == 0`, so no further initialisation is required.
        Self::default()
    }

    /// Splits a byte into its (high nibble, low nibble) child coordinates.
    fn slot(byte: u8) -> (usize, usize) {
        (usize::from(byte >> 4), usize::from(byte & 15))
    }

    /// Returns the child reached by `byte`, if it exists.
    fn child(&self, byte: u8) -> Option<&AwesomeTrieNode> {
        let (hi, lo) = Self::slot(byte);
        self.children[hi][lo].as_deref()
    }

    /// Returns the child reached by `byte`, creating an empty one if needed.
    fn child_or_insert(&mut self, byte: u8) -> &mut AwesomeTrieNode {
        let (hi, lo) = Self::slot(byte);
        self.children[hi][lo].get_or_insert_with(|| Box::new(AwesomeTrieNode::new()))
    }
}

/// A trie over the UTF-8 bytes of its keys, mapping each key to a `u16` id.
///
/// The id `0` is reserved as the "no key" marker: a key inserted with id `0`
/// is indistinguishable from a key that was never inserted.
#[derive(Debug, Default)]
pub struct AwesomeTrie {
    root: AwesomeTrieNode,
}

impl AwesomeTrie {
    /// Creates an empty trie.
    pub fn new() -> Self {
        Self::default()
    }

    /// Associates `text` with `id`, overwriting any id previously stored for
    /// the same key.
    ///
    /// `id` should be non-zero; `0` is the internal "no key" marker, so a key
    /// stored with id `0` is reported as absent by [`AwesomeTrie::contains`]
    /// and is never produced by [`AwesomeTrie::tokenize`].
    pub fn insert(&mut self, text: &str, id: u16) {
        let mut node = &mut self.root;
        for &byte in text.as_bytes() {
            node = node.child_or_insert(byte);
        }
        node.id = id;
    }

    /// Finds the longest stored key that is a prefix of `text`.
    ///
    /// Returns `(length_in_bytes, id)` of that key. When no non-empty key
    /// matches, the returned length is `0`.
    fn find_longest(&self, text: &[u8]) -> (usize, u16) {
        let mut node = &self.root;
        // Longest match seen so far, as (length, id).
        let mut best = (0usize, 0u16);
        for (depth, &byte) in text.iter().enumerate() {
            // `node` is the node reached after consuming `depth` bytes.
            if node.id != 0 {
                best = (depth, node.id);
            }
            match node.child(byte) {
                Some(next) => node = next,
                None => return best,
            }
        }
        // The whole input was consumed; the final node may itself end a key.
        if node.id != 0 {
            (text.len(), node.id)
        } else {
            best
        }
    }

    /// Greedily splits `text` into stored keys and returns their ids in order.
    ///
    /// At each position the longest matching key is taken. Tokenization stops
    /// at the first position where no key matches, and the ids collected up to
    /// that point are returned.
    pub fn tokenize(&self, text: &str) -> Vec<u16> {
        let bytes = text.as_bytes();
        let mut ids = Vec::new();
        let mut pos = 0;
        while pos < bytes.len() {
            let (len, id) = self.find_longest(&bytes[pos..]);
            if len == 0 {
                break;
            }
            ids.push(id);
            pos += len;
        }
        ids
    }

    /// Returns `true` if `text` was inserted as a complete key with a
    /// non-zero id.
    pub fn contains(&self, text: &str) -> bool {
        let mut node = &self.root;
        for &byte in text.as_bytes() {
            match node.child(byte) {
                Some(next) => node = next,
                None => return false,
            }
        }
        node.id != 0
    }
}
122
#[cfg(test)]
mod tests {

    /// Exercises `insert`, `tokenize` and `contains` on a small vocabulary
    /// whose keys share prefixes ("hello"/"hello guys", "cat"/"catty").
    #[test]
    fn it_contains() {
        let mut trie = crate::AwesomeTrie::new();
        trie.insert("hello", 1);
        trie.insert("fish", 2);
        trie.insert("hello guys", 3);
        trie.insert("cat", 4);
        trie.insert("catty", 5);

        // Tokenization yields the ids of the longest matching keys, in order.
        assert_eq!(trie.tokenize("fish"), vec![2]);
        assert_eq!(trie.tokenize("fishcat"), vec![2, 4]);
        assert_eq!(trie.tokenize("fishcathello"), vec![2, 4, 1]);

        // Every inserted key is reported as present.
        assert!(trie.contains("catty"));
        assert!(trie.contains("cat"));
        assert!(trie.contains("fish"));
        assert!(trie.contains("hello"));
        assert!(trie.contains("hello guys"));

        // Concatenations and near-misses are not keys themselves.
        assert!(!trie.contains("fishcat"));
        assert!(!trie.contains("fishcathello"));
        assert!(!trie.contains("hallo"));
    }
}