// weighted_trie/trie.rs

1use compact_str::CompactString;
2use hashbrown::HashMap;
3use smallvec::SmallVec;
4use std::cmp::Reverse;
5use std::mem::size_of;
6
/// Default cap on the number of suggestions retained by each trie node.
const MAX_SUGGESTIONS_PER_NODE: usize = 10;
/// Number of children a node keeps in its inline list before it switches to
/// a hash map.
const SMALL_CHILDREN_CAPACITY: usize = 4;
/// Default upper bound on the length of an accepted word, in bytes (it is
/// compared against `String::len`).
const MAX_WORD_LENGTH: usize = 100;
10
/// Position of a node inside the trie's node vector.
type NodeIndex = u32;
/// Position of a word inside the trie's word table.
type WordIndex = u32;
/// A suggestion squeezed into one integer: the weight lives in the upper
/// 32 bits and the word index in the lower 32 bits.
type PackedSuggestion = u64;
14
15#[inline(always)]
16const fn pack_suggestion(weight: u32, word_idx: WordIndex) -> PackedSuggestion {
17    ((weight as u64) << 32) | (word_idx as u64)
18}
19
20#[inline(always)]
21const fn get_weight(packed: PackedSuggestion) -> u32 {
22    (packed >> 32) as u32
23}
24
25#[inline(always)]
26const fn get_word_idx(packed: PackedSuggestion) -> WordIndex {
27    packed as u32
28}
29
30#[derive(Clone)]
31enum Children {
32    Small(SmallVec<[(char, NodeIndex); SMALL_CHILDREN_CAPACITY]>),
33    Large(HashMap<char, NodeIndex>),
34}
35
36impl Children {
37    #[inline]
38    fn new() -> Self {
39        Self::Small(SmallVec::new())
40    }
41
42    #[inline]
43    fn get(&self, c: char) -> Option<NodeIndex> {
44        match self {
45            Self::Small(vec) => vec.iter().find_map(|&(ch, idx)| (ch == c).then_some(idx)),
46            Self::Large(map) => map.get(&c).copied(),
47        }
48    }
49
50    #[inline]
51    fn insert(&mut self, c: char, idx: NodeIndex) {
52        match self {
53            Self::Small(vec) if vec.len() < SMALL_CHILDREN_CAPACITY => {
54                if let Some(entry) = vec.iter_mut().find(|(ch, _)| *ch == c) {
55                    entry.1 = idx;
56                } else {
57                    vec.push((c, idx));
58                }
59            }
60            Self::Small(vec) => {
61                #[cold]
62                fn transition_to_large(
63                    vec: &mut SmallVec<[(char, NodeIndex); SMALL_CHILDREN_CAPACITY]>,
64                    c: char,
65                    idx: NodeIndex,
66                ) -> Children {
67                    let mut map: HashMap<_, _> = vec.drain(..).collect();
68                    map.insert(c, idx);
69                    Children::Large(map)
70                }
71                *self = transition_to_large(vec, c, idx);
72            }
73            Self::Large(map) => {
74                map.insert(c, idx);
75            }
76        }
77    }
78}
79
80#[derive(Default)]
81pub struct TrieNode {
82    children: Children,
83    suggestions: SmallVec<[PackedSuggestion; 2]>,
84}
85
86impl Default for Children {
87    fn default() -> Self {
88        Self::new()
89    }
90}
91
92pub struct WeightedTrie {
93    nodes: Vec<TrieNode>,
94    root: NodeIndex,
95    words: Vec<CompactString>,
96    word_map: HashMap<CompactString, WordIndex>,
97    max_word_length: usize,
98    max_suggestions: usize,
99}
100
/// A word paired with the weight used to rank it among suggestions.
#[derive(Clone)]
pub struct WeightedString {
    /// The word itself.
    pub word: String,
    /// Ranking weight; higher weights are suggested first.
    pub weight: u32,
}
106
107impl WeightedString {
108    pub fn new(word: impl Into<String>, weight: u32) -> Self {
109        Self {
110            word: word.into(),
111            weight,
112        }
113    }
114}
115
/// Snapshot of the trie's estimated memory usage, as reported by
/// `WeightedTrie::memory_stats`.
pub struct MemoryStats {
    /// Number of nodes currently in the trie.
    pub nodes_count: usize,
    /// Allocated capacity of the node vector, in nodes.
    pub nodes_vec_capacity: usize,
    /// `nodes_count` multiplied by the size of one `TrieNode`, in bytes.
    pub nodes_struct_size: usize,
    /// Number of distinct words stored.
    pub words_count: usize,
    /// Sum of the lengths of all stored words, in bytes.
    pub words_storage_bytes: usize,
    /// Sum of the capacities of all stored words, in bytes.
    pub words_capacity_bytes: usize,
    /// Capacity of the word-to-index map, in entries.
    pub word_map_capacity: usize,
    /// Total number of suggestions held across all nodes.
    pub suggestions_total: usize,
    /// Bytes used by suggestion lists that spilled to the heap.
    pub suggestions_heap_bytes: usize,
    /// Number of nodes whose children use the inline list.
    pub children_small_count: usize,
    /// Number of nodes whose children use a hash map.
    pub children_large_count: usize,
    /// Estimated bytes used by the children hash maps.
    pub children_heap_bytes: usize,
    /// Estimated overall footprint, in bytes.
    pub total_bytes: usize,
}
131
132impl WeightedTrie {
133    pub fn new() -> Self {
134        Self::with_config(MAX_WORD_LENGTH, MAX_SUGGESTIONS_PER_NODE)
135    }
136
137    pub fn with_max_word_length(max_word_length: usize) -> Self {
138        Self::with_config(max_word_length, MAX_SUGGESTIONS_PER_NODE)
139    }
140
141    pub fn with_max_suggestions(max_suggestions: usize) -> Self {
142        Self::with_config(MAX_WORD_LENGTH, max_suggestions)
143    }
144
145    pub fn with_config(max_word_length: usize, max_suggestions: usize) -> Self {
146        Self {
147            nodes: vec![TrieNode::default()],
148            root: 0,
149            words: Vec::new(),
150            word_map: HashMap::new(),
151            max_word_length,
152            max_suggestions,
153        }
154    }
155
156    pub fn memory_stats(&self) -> MemoryStats {
157        let nodes_count = self.nodes.len();
158        let nodes_vec_capacity = self.nodes.capacity();
159        let nodes_struct_size = nodes_count * size_of::<TrieNode>();
160
161        let words_count = self.words.len();
162        let words_storage_bytes: usize = self.words.iter().map(|s| s.len()).sum();
163        let words_capacity_bytes: usize = self.words.iter().map(|s| s.capacity()).sum();
164        let word_map_capacity = self.word_map.capacity();
165
166        let (
167            suggestions_total,
168            suggestions_heap_bytes,
169            children_small_count,
170            children_large_count,
171            children_heap_bytes,
172        ) = self.nodes.iter().fold(
173            (0, 0, 0, 0, 0),
174            |(sugg_total, sugg_heap, small, large, child_heap), node| {
175                let sugg_heap_add = if node.suggestions.spilled() {
176                    node.suggestions.capacity() * size_of::<PackedSuggestion>()
177                } else {
178                    0
179                };
180
181                let (small_add, large_add, child_heap_add) = match &node.children {
182                    Children::Small(_) => (1, 0, 0),
183                    Children::Large(map) => (
184                        0,
185                        1,
186                        map.capacity() * (size_of::<char>() + size_of::<u32>() + 8),
187                    ),
188                };
189
190                (
191                    sugg_total + node.suggestions.len(),
192                    sugg_heap + sugg_heap_add,
193                    small + small_add,
194                    large + large_add,
195                    child_heap + child_heap_add,
196                )
197            },
198        );
199
200        let total_bytes = nodes_struct_size
201            + nodes_vec_capacity * size_of::<TrieNode>()
202            + words_capacity_bytes
203            + word_map_capacity * (size_of::<CompactString>() + size_of::<u32>() + 8)
204            + suggestions_heap_bytes
205            + children_heap_bytes;
206
207        MemoryStats {
208            nodes_count,
209            nodes_vec_capacity,
210            nodes_struct_size,
211            words_count,
212            words_storage_bytes,
213            words_capacity_bytes,
214            word_map_capacity,
215            suggestions_total,
216            suggestions_heap_bytes,
217            children_small_count,
218            children_large_count,
219            children_heap_bytes,
220            total_bytes,
221        }
222    }
223
224    pub fn build(weighted_strings: Vec<WeightedString>) -> Self {
225        Self::build_with_config(weighted_strings, MAX_WORD_LENGTH, MAX_SUGGESTIONS_PER_NODE)
226    }
227
228    pub fn build_with_max_word_length(
229        weighted_strings: Vec<WeightedString>,
230        max_word_length: usize,
231    ) -> Self {
232        Self::build_with_config(weighted_strings, max_word_length, MAX_SUGGESTIONS_PER_NODE)
233    }
234
235    pub fn build_with_max_suggestions(
236        weighted_strings: Vec<WeightedString>,
237        max_suggestions: usize,
238    ) -> Self {
239        Self::build_with_config(weighted_strings, MAX_WORD_LENGTH, max_suggestions)
240    }
241
242    pub fn build_with_config(
243        weighted_strings: Vec<WeightedString>,
244        max_word_length: usize,
245        max_suggestions: usize,
246    ) -> Self {
247        let count = weighted_strings.len();
248        let mut trie = Self {
249            nodes: Vec::with_capacity((count * 2).max(1000)),
250            root: 0,
251            words: Vec::with_capacity(count),
252            word_map: HashMap::with_capacity(count),
253            max_word_length,
254            max_suggestions,
255        };
256        trie.nodes.push(TrieNode::default());
257
258        for WeightedString { word, weight } in weighted_strings {
259            trie.insert(word, weight);
260        }
261
262        trie.words.shrink_to_fit();
263        trie.word_map.shrink_to_fit();
264        trie.nodes.shrink_to_fit();
265
266        trie
267    }
268
269    pub fn insert(&mut self, word: impl Into<String>, weight: u32) -> bool {
270        let word = word.into();
271
272        if word.len() > self.max_word_length {
273            return false;
274        }
275
276        let word_compact = CompactString::from(&word);
277        let word_idx = *self
278            .word_map
279            .entry(word_compact.clone())
280            .or_insert_with(|| {
281                self.words.push(word_compact);
282                (self.words.len() - 1) as WordIndex
283            });
284
285        let packed = pack_suggestion(weight, word_idx);
286        let mut node_idx = self.root;
287
288        for c in word.chars() {
289            node_idx = self.get_or_create_child(node_idx, c);
290            self.insert_suggestion(node_idx, word_idx, packed, weight);
291        }
292
293        true
294    }
295
296    #[inline]
297    fn get_or_create_child(&mut self, node_idx: NodeIndex, c: char) -> NodeIndex {
298        if let Some(idx) = self.nodes[node_idx as usize].children.get(c) {
299            return idx;
300        }
301
302        let new_idx = self.nodes.len() as NodeIndex;
303        self.nodes.push(TrieNode::default());
304        self.nodes[node_idx as usize].children.insert(c, new_idx);
305        new_idx
306    }
307
308    #[inline]
309    fn insert_suggestion(
310        &mut self,
311        node_idx: NodeIndex,
312        word_idx: WordIndex,
313        packed: PackedSuggestion,
314        weight: u32,
315    ) {
316        let node = &mut self.nodes[node_idx as usize];
317
318        if let Some(pos) = node
319            .suggestions
320            .iter()
321            .position(|&p| get_word_idx(p) == word_idx)
322        {
323            if weight > get_weight(node.suggestions[pos]) {
324                node.suggestions.remove(pos);
325            } else {
326                return;
327            }
328        }
329
330        let pos = node
331            .suggestions
332            .binary_search_by_key(&Reverse(weight), |&p| Reverse(get_weight(p)))
333            .unwrap_or_else(|x| x);
334
335        node.suggestions.insert(pos, packed);
336
337        if node.suggestions.len() > self.max_suggestions {
338            node.suggestions.truncate(self.max_suggestions);
339        }
340    }
341
342    pub fn search(&self, prefix: &str) -> Vec<String> {
343        let mut node_idx = self.root;
344
345        for c in prefix.chars() {
346            node_idx = match self.nodes[node_idx as usize].children.get(c) {
347                Some(idx) => idx,
348                None => return Vec::new(),
349            };
350        }
351
352        self.nodes[node_idx as usize]
353            .suggestions
354            .iter()
355            .map(|&packed| self.words[get_word_idx(packed) as usize].to_string())
356            .collect()
357    }
358
359    pub fn max_word_length(&self) -> usize {
360        self.max_word_length
361    }
362
363    pub fn max_suggestions(&self) -> usize {
364        self.max_suggestions
365    }
366}
367
368impl Default for WeightedTrie {
369    fn default() -> Self {
370        Self::new()
371    }
372}