1use compact_str::CompactString;
2use hashbrown::HashMap;
3use smallvec::SmallVec;
4use std::cmp::Reverse;
5use std::mem::size_of;
6
/// Arena index of a node inside `WeightedTrie::nodes`.
type NodeIndex = u32;
/// Index of a distinct word inside `WeightedTrie::words`.
type WordIndex = u32;
/// A suggestion packed into one integer: weight in the high 32 bits,
/// word index in the low 32 bits (see `pack_suggestion`).
type PackedSuggestion = u64;

/// Default number of suggestions each trie node retains.
const MAX_SUGGESTIONS_PER_NODE: usize = 10;
/// Number of child edges kept inline before a node switches to a hash map.
const SMALL_CHILDREN_CAPACITY: usize = 4;
/// Default longest accepted word, compared against `str::len` (UTF-8 bytes).
const MAX_WORD_LENGTH: usize = 100;
14
15#[inline(always)]
16const fn pack_suggestion(weight: u32, word_idx: WordIndex) -> PackedSuggestion {
17 ((weight as u64) << 32) | (word_idx as u64)
18}
19
20#[inline(always)]
21const fn get_weight(packed: PackedSuggestion) -> u32 {
22 (packed >> 32) as u32
23}
24
25#[inline(always)]
26const fn get_word_idx(packed: PackedSuggestion) -> WordIndex {
27 packed as u32
28}
29
30#[derive(Clone)]
31enum Children {
32 Small(SmallVec<[(char, NodeIndex); SMALL_CHILDREN_CAPACITY]>),
33 Large(HashMap<char, NodeIndex>),
34}
35
36impl Children {
37 #[inline]
38 fn new() -> Self {
39 Self::Small(SmallVec::new())
40 }
41
42 #[inline]
43 fn get(&self, c: char) -> Option<NodeIndex> {
44 match self {
45 Self::Small(vec) => vec.iter().find_map(|&(ch, idx)| (ch == c).then_some(idx)),
46 Self::Large(map) => map.get(&c).copied(),
47 }
48 }
49
50 #[inline]
51 fn insert(&mut self, c: char, idx: NodeIndex) {
52 match self {
53 Self::Small(vec) if vec.len() < SMALL_CHILDREN_CAPACITY => {
54 if let Some(entry) = vec.iter_mut().find(|(ch, _)| *ch == c) {
55 entry.1 = idx;
56 } else {
57 vec.push((c, idx));
58 }
59 }
60 Self::Small(vec) => {
61 #[cold]
62 fn transition_to_large(
63 vec: &mut SmallVec<[(char, NodeIndex); SMALL_CHILDREN_CAPACITY]>,
64 c: char,
65 idx: NodeIndex,
66 ) -> Children {
67 let mut map: HashMap<_, _> = vec.drain(..).collect();
68 map.insert(c, idx);
69 Children::Large(map)
70 }
71 *self = transition_to_large(vec, c, idx);
72 }
73 Self::Large(map) => {
74 map.insert(c, idx);
75 }
76 }
77 }
78}
79
80#[derive(Default)]
81pub struct TrieNode {
82 children: Children,
83 suggestions: SmallVec<[PackedSuggestion; 2]>,
84}
85
86impl Default for Children {
87 fn default() -> Self {
88 Self::new()
89 }
90}
91
92pub struct WeightedTrie {
93 nodes: Vec<TrieNode>,
94 root: NodeIndex,
95 words: Vec<CompactString>,
96 word_map: HashMap<CompactString, WordIndex>,
97 max_word_length: usize,
98 max_suggestions: usize,
99}
100
/// A word paired with the weight used to rank it among completions.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct WeightedString {
    /// The word to insert.
    pub word: String,
    /// Ranking weight; higher weights are suggested first.
    pub weight: u32,
}
106
107impl WeightedString {
108 pub fn new(word: impl Into<String>, weight: u32) -> Self {
109 Self {
110 word: word.into(),
111 weight,
112 }
113 }
114}
115
/// Approximate memory breakdown of a `WeightedTrie`.
///
/// Hash-table figures are estimates, so treat the byte totals as ballpark
/// values rather than exact allocation counts.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct MemoryStats {
    /// Number of trie nodes, including the root.
    pub nodes_count: usize,
    /// Capacity of the node arena, in nodes.
    pub nodes_vec_capacity: usize,
    /// `nodes_count * size_of::<TrieNode>()`: inline size of the occupied nodes.
    pub nodes_struct_size: usize,
    /// Number of distinct stored words.
    pub words_count: usize,
    /// Sum of stored word lengths, in UTF-8 bytes.
    pub words_storage_bytes: usize,
    /// Sum of the reported capacities of the stored words.
    pub words_capacity_bytes: usize,
    /// Capacity of the word-to-index map, in entries.
    pub word_map_capacity: usize,
    /// Total number of suggestion entries across all nodes.
    pub suggestions_total: usize,
    /// Bytes held by suggestion lists that spilled out of inline storage.
    pub suggestions_heap_bytes: usize,
    /// Number of nodes whose children are stored inline.
    pub children_small_count: usize,
    /// Number of nodes whose children are stored in a hash map.
    pub children_large_count: usize,
    /// Estimated bytes held by hash-map child tables.
    pub children_heap_bytes: usize,
    /// Estimated total memory footprint, in bytes.
    pub total_bytes: usize,
}
131
132impl WeightedTrie {
133 pub fn new() -> Self {
134 Self::with_config(MAX_WORD_LENGTH, MAX_SUGGESTIONS_PER_NODE)
135 }
136
137 pub fn with_max_word_length(max_word_length: usize) -> Self {
138 Self::with_config(max_word_length, MAX_SUGGESTIONS_PER_NODE)
139 }
140
141 pub fn with_max_suggestions(max_suggestions: usize) -> Self {
142 Self::with_config(MAX_WORD_LENGTH, max_suggestions)
143 }
144
145 pub fn with_config(max_word_length: usize, max_suggestions: usize) -> Self {
146 Self {
147 nodes: vec![TrieNode::default()],
148 root: 0,
149 words: Vec::new(),
150 word_map: HashMap::new(),
151 max_word_length,
152 max_suggestions,
153 }
154 }
155
156 pub fn memory_stats(&self) -> MemoryStats {
157 let nodes_count = self.nodes.len();
158 let nodes_vec_capacity = self.nodes.capacity();
159 let nodes_struct_size = nodes_count * size_of::<TrieNode>();
160
161 let words_count = self.words.len();
162 let words_storage_bytes: usize = self.words.iter().map(|s| s.len()).sum();
163 let words_capacity_bytes: usize = self.words.iter().map(|s| s.capacity()).sum();
164 let word_map_capacity = self.word_map.capacity();
165
166 let (
167 suggestions_total,
168 suggestions_heap_bytes,
169 children_small_count,
170 children_large_count,
171 children_heap_bytes,
172 ) = self.nodes.iter().fold(
173 (0, 0, 0, 0, 0),
174 |(sugg_total, sugg_heap, small, large, child_heap), node| {
175 let sugg_heap_add = if node.suggestions.spilled() {
176 node.suggestions.capacity() * size_of::<PackedSuggestion>()
177 } else {
178 0
179 };
180
181 let (small_add, large_add, child_heap_add) = match &node.children {
182 Children::Small(_) => (1, 0, 0),
183 Children::Large(map) => (
184 0,
185 1,
186 map.capacity() * (size_of::<char>() + size_of::<u32>() + 8),
187 ),
188 };
189
190 (
191 sugg_total + node.suggestions.len(),
192 sugg_heap + sugg_heap_add,
193 small + small_add,
194 large + large_add,
195 child_heap + child_heap_add,
196 )
197 },
198 );
199
200 let total_bytes = nodes_struct_size
201 + nodes_vec_capacity * size_of::<TrieNode>()
202 + words_capacity_bytes
203 + word_map_capacity * (size_of::<CompactString>() + size_of::<u32>() + 8)
204 + suggestions_heap_bytes
205 + children_heap_bytes;
206
207 MemoryStats {
208 nodes_count,
209 nodes_vec_capacity,
210 nodes_struct_size,
211 words_count,
212 words_storage_bytes,
213 words_capacity_bytes,
214 word_map_capacity,
215 suggestions_total,
216 suggestions_heap_bytes,
217 children_small_count,
218 children_large_count,
219 children_heap_bytes,
220 total_bytes,
221 }
222 }
223
224 pub fn build(weighted_strings: Vec<WeightedString>) -> Self {
225 Self::build_with_config(weighted_strings, MAX_WORD_LENGTH, MAX_SUGGESTIONS_PER_NODE)
226 }
227
228 pub fn build_with_max_word_length(
229 weighted_strings: Vec<WeightedString>,
230 max_word_length: usize,
231 ) -> Self {
232 Self::build_with_config(weighted_strings, max_word_length, MAX_SUGGESTIONS_PER_NODE)
233 }
234
235 pub fn build_with_max_suggestions(
236 weighted_strings: Vec<WeightedString>,
237 max_suggestions: usize,
238 ) -> Self {
239 Self::build_with_config(weighted_strings, MAX_WORD_LENGTH, max_suggestions)
240 }
241
242 pub fn build_with_config(
243 weighted_strings: Vec<WeightedString>,
244 max_word_length: usize,
245 max_suggestions: usize,
246 ) -> Self {
247 let count = weighted_strings.len();
248 let mut trie = Self {
249 nodes: Vec::with_capacity((count * 2).max(1000)),
250 root: 0,
251 words: Vec::with_capacity(count),
252 word_map: HashMap::with_capacity(count),
253 max_word_length,
254 max_suggestions,
255 };
256 trie.nodes.push(TrieNode::default());
257
258 for WeightedString { word, weight } in weighted_strings {
259 trie.insert(word, weight);
260 }
261
262 trie.words.shrink_to_fit();
263 trie.word_map.shrink_to_fit();
264 trie.nodes.shrink_to_fit();
265
266 trie
267 }
268
269 pub fn insert(&mut self, word: impl Into<String>, weight: u32) -> bool {
270 let word = word.into();
271
272 if word.len() > self.max_word_length {
273 return false;
274 }
275
276 let word_compact = CompactString::from(&word);
277 let word_idx = *self
278 .word_map
279 .entry(word_compact.clone())
280 .or_insert_with(|| {
281 self.words.push(word_compact);
282 (self.words.len() - 1) as WordIndex
283 });
284
285 let packed = pack_suggestion(weight, word_idx);
286 let mut node_idx = self.root;
287
288 for c in word.chars() {
289 node_idx = self.get_or_create_child(node_idx, c);
290 self.insert_suggestion(node_idx, word_idx, packed, weight);
291 }
292
293 true
294 }
295
296 #[inline]
297 fn get_or_create_child(&mut self, node_idx: NodeIndex, c: char) -> NodeIndex {
298 if let Some(idx) = self.nodes[node_idx as usize].children.get(c) {
299 return idx;
300 }
301
302 let new_idx = self.nodes.len() as NodeIndex;
303 self.nodes.push(TrieNode::default());
304 self.nodes[node_idx as usize].children.insert(c, new_idx);
305 new_idx
306 }
307
308 #[inline]
309 fn insert_suggestion(
310 &mut self,
311 node_idx: NodeIndex,
312 word_idx: WordIndex,
313 packed: PackedSuggestion,
314 weight: u32,
315 ) {
316 let node = &mut self.nodes[node_idx as usize];
317
318 if let Some(pos) = node
319 .suggestions
320 .iter()
321 .position(|&p| get_word_idx(p) == word_idx)
322 {
323 if weight > get_weight(node.suggestions[pos]) {
324 node.suggestions.remove(pos);
325 } else {
326 return;
327 }
328 }
329
330 let pos = node
331 .suggestions
332 .binary_search_by_key(&Reverse(weight), |&p| Reverse(get_weight(p)))
333 .unwrap_or_else(|x| x);
334
335 node.suggestions.insert(pos, packed);
336
337 if node.suggestions.len() > self.max_suggestions {
338 node.suggestions.truncate(self.max_suggestions);
339 }
340 }
341
342 pub fn search(&self, prefix: &str) -> Vec<String> {
343 let mut node_idx = self.root;
344
345 for c in prefix.chars() {
346 node_idx = match self.nodes[node_idx as usize].children.get(c) {
347 Some(idx) => idx,
348 None => return Vec::new(),
349 };
350 }
351
352 self.nodes[node_idx as usize]
353 .suggestions
354 .iter()
355 .map(|&packed| self.words[get_word_idx(packed) as usize].to_string())
356 .collect()
357 }
358
359 pub fn max_word_length(&self) -> usize {
360 self.max_word_length
361 }
362
363 pub fn max_suggestions(&self) -> usize {
364 self.max_suggestions
365 }
366}
367
368impl Default for WeightedTrie {
369 fn default() -> Self {
370 Self::new()
371 }
372}