aprender_shell/
trie.rs

1//! Trie data structure for fast prefix matching
2//!
3//! Optimized for minimal allocations (Issue #93):
4//! - Pre-allocated result vectors
5//! - Single mutable string buffer for traversal
6
7use std::collections::HashMap;
8
/// A single node of the trie; every edge to a child is labelled with one `char`.
#[derive(Default)]
struct TrieNode {
    /// Child nodes, keyed by the next character of the stored commands.
    children: HashMap<char, TrieNode>,
    /// `true` when at least one inserted command ends exactly at this node.
    is_end: bool,
    /// How many times the command ending at this node has been inserted.
    count: u32,
}

/// Trie for fast prefix-based command lookup
pub struct Trie {
    /// Node for the empty prefix; every stored command hangs below it.
    root: TrieNode,
}

impl Trie {
    /// Create an empty trie.
    pub fn new() -> Self {
        Self {
            root: TrieNode::default(),
        }
    }

    /// Insert a command into the trie.
    ///
    /// Inserting the same command again raises its frequency count, which
    /// [`Trie::find_prefix`] uses for ranking. The count saturates at
    /// `u32::MAX` instead of overflowing.
    pub fn insert(&mut self, word: &str) {
        let mut node = &mut self.root;

        for ch in word.chars() {
            node = node.children.entry(ch).or_default();
        }

        node.is_end = true;
        // Saturate: a plain `+= 1` would panic in debug builds and wrap to 0
        // in release builds once the counter is full.
        node.count = node.count.saturating_add(1);
    }

    /// Find the commands that start with `prefix`, most frequent first.
    ///
    /// Returns at most `limit` commands ordered by descending insert count;
    /// commands with equal counts are ordered alphabetically so the result is
    /// deterministic. `prefix` itself is included when it was inserted as a
    /// command. The result is empty when nothing matches or `limit` is 0.
    ///
    /// Every stored command below `prefix` is visited, because stopping the
    /// walk early would rank an arbitrary subset (child order in a `HashMap`
    /// is unspecified) rather than the real top `limit`.
    ///
    /// Allocation notes:
    /// - The candidate vector is pre-allocated.
    /// - One mutable buffer is shared by the whole traversal.
    pub fn find_prefix(&self, prefix: &str, limit: usize) -> Vec<String> {
        if limit == 0 {
            return Vec::new();
        }

        // Navigate to the node that represents `prefix`.
        let mut node = &self.root;

        for ch in prefix.chars() {
            match node.children.get(&ch) {
                Some(next) => node = next,
                None => return Vec::new(),
            }
        }

        // Capacity is only a hint (bounded so a huge `limit` cannot trigger a
        // huge up-front allocation); the vector grows if more commands match.
        let mut results: Vec<(String, u32)> = Vec::with_capacity(limit.min(100));

        // Single buffer holding the path from the root to the current node.
        let mut buffer = String::with_capacity(prefix.len() + 64);
        buffer.push_str(prefix);

        Self::collect_words_optimized(node, &mut buffer, &mut results);

        // Highest count first; alphabetical order breaks ties.
        results.sort_unstable_by(|a, b| b.1.cmp(&a.1).then_with(|| a.0.cmp(&b.0)));

        results
            .into_iter()
            .take(limit)
            .map(|(word, _)| word)
            .collect()
    }

    /// Append every command stored at or below `node` to `results`.
    ///
    /// `buffer` must hold the text that leads to `node`; it is extended and
    /// restored in place during the walk, so a `String` is only allocated for
    /// commands that are actually emitted.
    fn collect_words_optimized(
        node: &TrieNode,
        buffer: &mut String,
        results: &mut Vec<(String, u32)>,
    ) {
        if node.is_end {
            results.push((buffer.clone(), node.count));
        }

        for (ch, child) in &node.children {
            // Push character, recurse, then pop (avoids cloning per edge).
            buffer.push(*ch);
            Self::collect_words_optimized(child, buffer, results);
            buffer.pop();
        }
    }

    /// Legacy method for compatibility (unused but kept for reference).
    ///
    /// Clones the accumulated string for every child edge and stops descending
    /// once 100 results have been gathered. Because children are visited in
    /// unspecified `HashMap` order, that cut-off keeps an arbitrary subset of
    /// the matches; `find_prefix` does not use this method.
    #[allow(dead_code)]
    fn collect_words(&self, node: &TrieNode, current: String, results: &mut Vec<(String, u32)>) {
        if node.is_end {
            results.push((current.clone(), node.count));
        }

        // Limit search depth for performance
        if results.len() >= 100 {
            return;
        }

        for (ch, child) in &node.children {
            let mut next = current.clone();
            next.push(*ch);
            self.collect_words(child, next, results);
        }
    }
}

impl Default for Trie {
    /// Equivalent to [`Trie::new`].
    fn default() -> Self {
        Self::new()
    }
}
128
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a trie by inserting every entry of `commands` in order.
    fn trie_of(commands: &[&str]) -> Trie {
        let mut trie = Trie::new();
        for command in commands {
            trie.insert(command);
        }
        trie
    }

    /// Lookups only return the commands that share the requested prefix.
    #[test]
    fn test_insert_and_find() {
        let trie = trie_of(&["git status", "git commit", "git push", "grep pattern"]);

        let git_matches = trie.find_prefix("git ", 10);
        assert_eq!(git_matches.len(), 3);

        let grep_matches = trie.find_prefix("grep", 10);
        assert_eq!(grep_matches.len(), 1);
    }

    /// A command inserted more often is listed before a rarer one.
    #[test]
    fn test_frequency_ordering() {
        let trie = trie_of(&["git status", "git status", "git status", "git commit"]);

        let ranked = trie.find_prefix("git ", 10);
        assert_eq!(ranked[0], "git status"); // Most frequent first
    }

    /// A prefix that no stored command starts with yields no results.
    #[test]
    fn test_no_match() {
        let trie = trie_of(&["git status"]);

        let matches = trie.find_prefix("docker ", 10);
        assert!(matches.is_empty());
    }
}