1use std::collections::HashMap;
8
/// One node of the prefix tree.
///
/// The path of characters leading from the root to a node spells out a
/// prefix; the node itself only records what lies below it and whether an
/// inserted word stops exactly here.
#[derive(Default)]
struct TrieNode {
    /// How many times a word ending at this node has been inserted.
    count: u32,
    /// `true` once at least one inserted word ends at this node.
    is_end: bool,
    /// Outgoing edges, keyed by the next character of the word.
    children: HashMap<char, TrieNode>,
}
16
17pub struct Trie {
19 root: TrieNode,
20}
21
22impl Trie {
23 pub fn new() -> Self {
24 Self {
25 root: TrieNode::default(),
26 }
27 }
28
29 pub fn insert(&mut self, word: &str) {
31 let mut node = &mut self.root;
32
33 for ch in word.chars() {
34 node = node.children.entry(ch).or_default();
35 }
36
37 node.is_end = true;
38 node.count += 1;
39 }
40
41 pub fn find_prefix(&self, prefix: &str, limit: usize) -> Vec<String> {
47 let mut node = &self.root;
49
50 for ch in prefix.chars() {
51 match node.children.get(&ch) {
52 Some(n) => node = n,
53 None => return Vec::new(),
54 }
55 }
56
57 let mut results = Vec::with_capacity(limit.min(100));
59
60 let mut buffer = String::with_capacity(prefix.len() + 64);
62 buffer.push_str(prefix);
63
64 self.collect_words_optimized(node, &mut buffer, &mut results, limit);
65
66 results.sort_unstable_by(|a, b| b.1.cmp(&a.1));
68
69 let take_count = limit.min(results.len());
71 let mut output = Vec::with_capacity(take_count);
72 for (s, _) in results.into_iter().take(take_count) {
73 output.push(s);
74 }
75 output
76 }
77
78 fn collect_words_optimized(
80 &self,
81 node: &TrieNode,
82 buffer: &mut String,
83 results: &mut Vec<(String, u32)>,
84 limit: usize,
85 ) {
86 if node.is_end {
87 results.push((buffer.clone(), node.count));
88 }
89
90 if results.len() >= limit.min(100) {
92 return;
93 }
94
95 for (ch, child) in &node.children {
96 buffer.push(*ch);
98 self.collect_words_optimized(child, buffer, results, limit);
99 buffer.pop();
100 }
101 }
102
103 #[allow(dead_code)]
105 fn collect_words(&self, node: &TrieNode, current: String, results: &mut Vec<(String, u32)>) {
106 if node.is_end {
107 results.push((current.clone(), node.count));
108 }
109
110 if results.len() >= 100 {
112 return;
113 }
114
115 for (ch, child) in &node.children {
116 let mut next = current.clone();
117 next.push(*ch);
118 self.collect_words(child, next, results);
119 }
120 }
121}
122
123impl Default for Trie {
124 fn default() -> Self {
125 Self::new()
126 }
127}
128
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a trie by inserting every entry of `words` in order.
    fn trie_of(words: &[&str]) -> Trie {
        let mut trie = Trie::new();
        for word in words {
            trie.insert(word);
        }
        trie
    }

    #[test]
    fn test_insert_and_find() {
        let trie = trie_of(&["git status", "git commit", "git push", "grep pattern"]);

        // Three stored commands share the "git " prefix, one starts with "grep".
        assert_eq!(trie.find_prefix("git ", 10).len(), 3);
        assert_eq!(trie.find_prefix("grep", 10).len(), 1);
    }

    #[test]
    fn test_frequency_ordering() {
        let trie = trie_of(&["git status", "git status", "git status", "git commit"]);

        // The command inserted most often must be ranked first.
        let results = trie.find_prefix("git ", 10);
        assert_eq!(results[0], "git status");
    }

    #[test]
    fn test_no_match() {
        let trie = trie_of(&["git status"]);

        assert!(trie.find_prefix("docker ", 10).is_empty());
    }
}