use std::collections::HashMap;
/// A single trie node; the characters on the path from the root spell its word.
#[derive(Debug, Default)]
struct TrieNode {
    /// Child nodes keyed by the next character of the word.
    children: HashMap<char, TrieNode>,
    /// `true` when the path to this node spells a word that was inserted.
    is_end: bool,
    /// How many times the word ending at this node was inserted (0 when `is_end` is false).
    count: u32,
}

/// Character trie storing words together with an insertion count, which is
/// used to rank prefix completions by frequency.
#[derive(Debug)]
pub struct Trie {
    root: TrieNode,
}
impl Trie {
pub fn new() -> Self {
Self {
root: TrieNode::default(),
}
}
pub fn insert(&mut self, word: &str) {
let mut node = &mut self.root;
for ch in word.chars() {
node = node.children.entry(ch).or_default();
}
node.is_end = true;
node.count += 1;
}
pub fn find_prefix(&self, prefix: &str, limit: usize) -> Vec<String> {
let mut node = &self.root;
for ch in prefix.chars() {
match node.children.get(&ch) {
Some(n) => node = n,
None => return Vec::new(),
}
}
let mut results = Vec::with_capacity(limit.min(100));
let mut buffer = String::with_capacity(prefix.len() + 64);
buffer.push_str(prefix);
self.collect_words_optimized(node, &mut buffer, &mut results, limit);
results.sort_unstable_by(|a, b| b.1.cmp(&a.1));
let take_count = limit.min(results.len());
let mut output = Vec::with_capacity(take_count);
for (s, _) in results.into_iter().take(take_count) {
output.push(s);
}
output
}
fn collect_words_optimized(
&self,
node: &TrieNode,
buffer: &mut String,
results: &mut Vec<(String, u32)>,
limit: usize,
) {
if node.is_end {
results.push((buffer.clone(), node.count));
}
if results.len() >= limit.min(100) {
return;
}
for (ch, child) in &node.children {
buffer.push(*ch);
self.collect_words_optimized(child, buffer, results, limit);
buffer.pop();
}
}
#[allow(dead_code)]
fn collect_words(&self, node: &TrieNode, current: String, results: &mut Vec<(String, u32)>) {
if node.is_end {
results.push((current.clone(), node.count));
}
if results.len() >= 100 {
return;
}
for (ch, child) in &node.children {
let mut next = current.clone();
next.push(*ch);
self.collect_words(child, next, results);
}
}
}
impl Default for Trie {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Sorts results so assertions do not depend on how words with equal
    /// frequency happen to be ordered.
    fn sorted(mut words: Vec<String>) -> Vec<String> {
        words.sort();
        words
    }

    #[test]
    fn test_insert_and_find() {
        let mut trie = Trie::new();
        trie.insert("git status");
        trie.insert("git commit");
        trie.insert("git push");
        trie.insert("grep pattern");

        let results = sorted(trie.find_prefix("git ", 10));
        assert_eq!(results, ["git commit", "git push", "git status"]);

        let results = trie.find_prefix("grep", 10);
        assert_eq!(results, ["grep pattern"]);
    }

    #[test]
    fn test_frequency_ordering() {
        let mut trie = Trie::new();
        for _ in 0..3 {
            trie.insert("git status");
        }
        trie.insert("git commit");

        let results = trie.find_prefix("git ", 10);
        assert_eq!(results, ["git status", "git commit"]);
    }

    #[test]
    fn test_no_match() {
        let mut trie = Trie::new();
        trie.insert("git status");
        let results = trie.find_prefix("docker ", 10);
        assert!(results.is_empty());
    }

    #[test]
    fn test_limit_truncates_results() {
        let mut trie = Trie::new();
        trie.insert("git status");
        trie.insert("git commit");
        trie.insert("git push");
        assert_eq!(trie.find_prefix("git ", 2).len(), 2);
        assert!(trie.find_prefix("git ", 0).is_empty());
    }

    #[test]
    fn test_prefix_that_is_itself_a_word() {
        let mut trie = Trie::new();
        trie.insert("git");
        trie.insert("git status");
        let results = sorted(trie.find_prefix("git", 10));
        assert_eq!(results, ["git", "git status"]);
    }

    #[test]
    fn test_empty_prefix_returns_everything() {
        let mut trie = Trie::new();
        trie.insert("a");
        trie.insert("ab");
        trie.insert("b");
        let results = sorted(trie.find_prefix("", 10));
        assert_eq!(results, ["a", "ab", "b"]);
    }
}