1use std::collections::HashMap;
2
3use crate::util::errors::TrieError;
4
/// One node of a [`Trie`].
///
/// A node knows whether the sequence of characters leading to it spells a
/// complete word, and owns one outgoing edge per following character.
#[derive(Debug, Default)]
pub struct TrieNode {
    /// Outgoing edges keyed by the next character. `Trie::insert` always
    /// leaves these entries as `Some(..)`.
    children: HashMap<char, Option<Box<TrieNode>>>,
    /// Set when a stored word terminates exactly at this node.
    word_ends: bool,
}
10
11#[derive(Default)]
12pub struct Trie {
13 pub root: TrieNode,
14}
15
16impl Trie {
17 pub fn new() -> Self {
18 Self {
19 root: TrieNode::default(),
20 }
21 }
22
23 pub fn insert(word: String, node: &mut TrieNode, position: usize) -> Result<(), TrieError> {
24 if position == word.len() {
25 node.word_ends = true;
26 return Ok(());
27 }
28
29 let c = word.as_bytes()[position] as char;
30
31 if !c.is_ascii_alphabetic() {
32 return Err(TrieError::InvalidCharacter);
33 }
34
35 let child = node
36 .children
37 .entry(c)
38 .or_insert_with(|| Some(Box::new(TrieNode::default())));
39
40 if let Some(child) = child {
41 Self::insert(word, child.as_mut(), position + 1)?;
42 }
43
44 Ok(())
45 }
46
47 pub fn search(word: String, node: &TrieNode, position: usize) -> Result<bool, TrieError> {
48 if position == word.len() {
49 return Ok(node.word_ends);
50 }
51
52 let c = word.as_bytes()[position] as char;
53
54 if !c.is_ascii_alphabetic() {
55 return Err(TrieError::InvalidCharacter);
56 }
57
58 match node.children.get(&c) {
59 Some(child) => Self::search(word, child.as_deref().unwrap(), position + 1),
60 None => Ok(false),
61 }
62 }
63
64 pub fn suggest(&self, prefix: &str) -> Result<Vec<String>, TrieError> {
65 let mut node = &self.root;
66 let mut current_letters = String::new();
67
68 for i in 0..prefix.len() {
69 let letter = prefix.as_bytes()[i] as char;
70
71 if let Some(child) = node.children.get(&letter) {
72 node = child.as_deref().unwrap();
73 current_letters.push(letter);
74 } else {
75 return Ok(Vec::new());
76 }
77 }
78
79 let mut suggestion_list: Vec<String> = Vec::new();
80
81 Self::consume_words(node, &mut suggestion_list, &mut current_letters);
82
83 Ok(suggestion_list)
84 }
85
86 fn consume_words(node: &TrieNode, word_list: &mut Vec<String>, current_letters: &mut String) {
87 if node.word_ends {
88 word_list.push(current_letters.clone());
89 }
90
91 for (child_char, child_node) in &node.children {
92 current_letters.push(*child_char);
93 Self::consume_words(child_node.as_deref().unwrap(), word_list, current_letters);
94 current_letters.pop();
95 }
96 }
97}
98
#[cfg(test)]
mod tests {
    use std::error::Error;

    use super::*;

    #[test]
    fn it_initializes_trie_node() {
        let trie_node = TrieNode::default();

        assert!(trie_node.children.is_empty());
        assert!(!trie_node.word_ends);
    }

    #[test]
    fn it_initializes_trie() {
        let trie = Trie::new();

        assert!(trie.root.children.is_empty());
        assert!(!trie.root.word_ends);
    }

    #[test]
    fn it_inserts_word_to_trie() -> Result<(), Box<dyn Error>> {
        let mut trie = Trie::new();
        let word = "test".to_string();

        Trie::insert(word.clone(), &mut trie.root, 0)?;

        assert!(Trie::search(word, &trie.root, 0)?);
        // A strict prefix of a stored word is not itself a word.
        assert!(!Trie::search("tes".to_string(), &trie.root, 0)?);
        // Neither is an extension of a stored word.
        assert!(!Trie::search("tests".to_string(), &trie.root, 0)?);

        Ok(())
    }

    #[test]
    fn it_suggests_words() -> Result<(), Box<dyn Error>> {
        let mut trie = Trie::new();

        for word in ["hello", "helicopter", "helium", "hall", "hundred"] {
            Trie::insert(word.to_string(), &mut trie.root, 0)?;
        }

        // Sort locally so the assertion does not depend on traversal order.
        let mut result = trie.suggest("hel")?;
        result.sort();

        assert_eq!(result, ["helicopter", "helium", "hello"]);

        Ok(())
    }

    #[test]
    fn it_returns_no_suggestions_for_unknown_prefix() -> Result<(), Box<dyn Error>> {
        let mut trie = Trie::new();
        Trie::insert("hello".to_string(), &mut trie.root, 0)?;

        assert!(trie.suggest("xyz")?.is_empty());

        Ok(())
    }

    #[test]
    fn it_rejects_word_with_non_alphabetic_character() {
        let mut trie = Trie::new();

        // ASCII punctuation is outside the trie's alphabet.
        assert_eq!(
            Trie::insert("~~~".to_string(), &mut trie.root, 0).unwrap_err(),
            TrieError::InvalidCharacter
        );
        // So are non-ASCII letters.
        assert_eq!(
            Trie::insert("héllo".to_string(), &mut trie.root, 0).unwrap_err(),
            TrieError::InvalidCharacter
        );
    }
}