arbor_cli/common/
trie.rs

1use std::collections::HashMap;
2
3use crate::util::errors::TrieError;
4
5#[derive(Default, Debug)]
6pub struct TrieNode {
7    children: HashMap<char, Option<Box<TrieNode>>>,
8    word_ends: bool,
9}
10
11#[derive(Default)]
12pub struct Trie {
13    pub root: TrieNode,
14}
15
16impl Trie {
17    pub fn new() -> Self {
18        Self {
19            root: TrieNode::default(),
20        }
21    }
22
23    pub fn insert(word: String, node: &mut TrieNode, position: usize) -> Result<(), TrieError> {
24        if position == word.len() {
25            node.word_ends = true;
26            return Ok(());
27        }
28
29        let c = word.as_bytes()[position] as char;
30
31        if !c.is_ascii_alphabetic() {
32            return Err(TrieError::InvalidCharacter);
33        }
34
35        let child = node
36            .children
37            .entry(c)
38            .or_insert_with(|| Some(Box::new(TrieNode::default())));
39
40        if let Some(child) = child {
41            Self::insert(word, child.as_mut(), position + 1)?;
42        }
43
44        Ok(())
45    }
46
47    pub fn search(word: String, node: &TrieNode, position: usize) -> Result<bool, TrieError> {
48        if position == word.len() {
49            return Ok(node.word_ends);
50        }
51
52        let c = word.as_bytes()[position] as char;
53
54        if !c.is_ascii_alphabetic() {
55            return Err(TrieError::InvalidCharacter);
56        }
57
58        match node.children.get(&c) {
59            Some(child) => Self::search(word, child.as_deref().unwrap(), position + 1),
60            None => Ok(false),
61        }
62    }
63
64    pub fn suggest(&self, prefix: &str) -> Result<Vec<String>, TrieError> {
65        let mut node = &self.root;
66        let mut current_letters = String::new();
67
68        for i in 0..prefix.len() {
69            let letter = prefix.as_bytes()[i] as char;
70
71            if let Some(child) = node.children.get(&letter) {
72                node = child.as_deref().unwrap();
73                current_letters.push(letter);
74            } else {
75                return Ok(Vec::new());
76            }
77        }
78
79        let mut suggestion_list: Vec<String> = Vec::new();
80
81        Self::consume_words(node, &mut suggestion_list, &mut current_letters);
82
83        Ok(suggestion_list)
84    }
85
86    fn consume_words(node: &TrieNode, word_list: &mut Vec<String>, current_letters: &mut String) {
87        if node.word_ends {
88            word_list.push(current_letters.clone());
89        }
90
91        for (child_char, child_node) in &node.children {
92            current_letters.push(*child_char);
93            Self::consume_words(child_node.as_deref().unwrap(), word_list, current_letters);
94            current_letters.pop();
95        }
96    }
97}
98
#[cfg(test)]
mod tests {
    use std::error::Error;

    use super::*;

    #[test]
    fn it_initializes_trie_node() {
        let trie_node = TrieNode::default();

        assert!(trie_node.children.is_empty());
        assert!(!trie_node.word_ends);
    }

    #[test]
    fn it_initializes_trie() {
        let trie = Trie::new();

        assert!(trie.root.children.is_empty());
        assert!(!trie.root.word_ends);
    }

    #[test]
    fn it_inserts_word_to_trie() -> Result<(), Box<dyn Error>> {
        let mut trie = Trie::new();
        let word = "test".to_string();

        Trie::insert(word.clone(), &mut trie.root, 0)?;

        assert!(Trie::search(word, &trie.root, 0)?);

        Ok(())
    }

    #[test]
    fn it_does_not_report_prefix_as_word() -> Result<(), Box<dyn Error>> {
        let mut trie = Trie::new();

        Trie::insert("hello".to_string(), &mut trie.root, 0)?;

        assert!(!Trie::search("hel".to_string(), &trie.root, 0)?);

        Ok(())
    }

    #[test]
    fn it_suggests_words() -> Result<(), Box<dyn Error>> {
        let mut trie = Trie::new();

        for word in ["hello", "helicopter", "helium", "hall", "hundred"] {
            Trie::insert(word.to_string(), &mut trie.root, 0)?;
        }

        let mut result = trie.suggest("hel")?;
        // Normalise order so the assertion checks membership and count exactly.
        result.sort();

        assert_eq!(result, ["helicopter", "helium", "hello"]);

        Ok(())
    }

    #[test]
    fn it_rejects_word_with_non_alphabetic_character() {
        let mut trie = Trie::new();

        assert_eq!(
            Trie::insert("~~~".to_string(), &mut trie.root, 0).unwrap_err(),
            TrieError::InvalidCharacter
        );
    }

    #[test]
    fn it_rejects_word_with_non_ascii_character() {
        let mut trie = Trie::new();

        assert_eq!(
            Trie::insert("h\u{e9}llo".to_string(), &mut trie.root, 0).unwrap_err(),
            TrieError::InvalidCharacter
        );
    }
}
172}