Skip to main content

codelens_core/analyzer/
trie.rs

1//! Token trie for fast multi-pattern matching.
2//!
3//! A 256-way trie that unifies all token types (comments, strings, keywords)
4//! into a single lookup structure. Combined with a process mask (bloom filter
5//! on first bytes), this skips ~90% of bytes in the hot loop.
6
/// Type of token matched by the trie.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TokenType {
    /// Start of a line comment; no close bytes, the token ends at newline.
    LineComment,
    /// Opening delimiter of a block comment; the closing byte sequence is
    /// carried in `TokenMatch::close`.
    BlockCommentStart,
    /// Plain string delimiter (`"` or `'`).
    StringDelimiter,
    /// Triple-quote doc-string delimiter (`"""` / `'''`), inserted only for
    /// Python and Ruby by `build_from_language`.
    DocStringDelimiter,
}
15
/// Result of a successful trie match.
///
/// Note: `advance` is set automatically by `TokenTrie::insert()` based on
/// the pattern length. Any value provided by the caller will be overwritten.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TokenMatch {
    /// Which kind of token the matched pattern starts.
    pub token_type: TokenType,
    /// Closing byte sequence (e.g., `*/` for block comments, `"` for strings).
    /// None for line comments (they end at newline).
    pub close: Option<Vec<u8>>,
    /// How many bytes to advance past the matched token (set by insert).
    pub advance: usize,
}
29
30struct TrieNode {
31    children: Vec<Option<Box<TrieNode>>>,
32    token_match: Option<TokenMatch>,
33}
34
35impl TrieNode {
36    fn new() -> Self {
37        Self {
38            children: (0..256).map(|_| None).collect(),
39            token_match: None,
40        }
41    }
42}
43
/// Fast token lookup trie. All token types for a language are inserted into one trie.
///
/// `Debug` prints only the process mask (the full trie is too large to dump).
pub struct TokenTrie {
    // Root node; `match_at` walks byte-by-byte from here.
    root: TrieNode,
    // Bitwise OR of the first byte of every inserted pattern. Used as a
    // bloom-filter-style pre-check (`should_process`); false positives allowed.
    mask: u8,
}
51
52impl TokenTrie {
53    pub fn new() -> Self {
54        Self {
55            root: TrieNode::new(),
56            mask: 0,
57        }
58    }
59
60    pub fn insert(&mut self, pattern: &[u8], mut token_match: TokenMatch) {
61        if pattern.is_empty() {
62            return;
63        }
64        self.mask |= pattern[0];
65        token_match.advance = pattern.len();
66
67        let mut node = &mut self.root;
68        for &byte in pattern {
69            let idx = byte as usize;
70            if node.children[idx].is_none() {
71                node.children[idx] = Some(Box::new(TrieNode::new()));
72            }
73            node = node.children[idx].as_mut().unwrap();
74        }
75        node.token_match = Some(token_match);
76    }
77
78    /// Try to match a token at position `pos`. Returns longest match (greedy).
79    pub fn match_at(&self, content: &[u8], pos: usize) -> Option<TokenMatch> {
80        let mut node = &self.root;
81        let mut last_match: Option<&TokenMatch> = None;
82
83        for &byte in &content[pos..] {
84            match &node.children[byte as usize] {
85                Some(child) => {
86                    node = child;
87                    if node.token_match.is_some() {
88                        last_match = node.token_match.as_ref();
89                    }
90                }
91                None => break,
92            }
93        }
94
95        last_match.cloned()
96    }
97
98    pub fn process_mask(&self) -> u8 {
99        self.mask
100    }
101}
102
103impl Default for TokenTrie {
104    fn default() -> Self {
105        Self::new()
106    }
107}
108
impl std::fmt::Debug for TokenTrie {
    // Manual impl: deriving `Debug` would dump all 256 child slots of every
    // node. Only the mask is shown; `finish_non_exhaustive` prints `..` to
    // signal the omitted fields.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("TokenTrie")
            .field("mask", &self.mask)
            .finish_non_exhaustive()
    }
}
116
117/// Build a `TokenTrie` from a `Language` definition.
118///
119/// Inserts line comments, block comment starts (with close bytes),
120/// string delimiters (`"` and `'`), and for Python/Ruby also
121/// triple-quote doc-string delimiters (`"""` and `'''`).
122pub fn build_from_language(lang: &crate::language::Language) -> (TokenTrie, u8) {
123    let mut trie = TokenTrie::new();
124
125    // Line comments
126    for lc in &lang.line_comments {
127        trie.insert(
128            lc.as_bytes(),
129            TokenMatch {
130                token_type: TokenType::LineComment,
131                close: None,
132                advance: 0,
133            },
134        );
135    }
136
137    // Block comments
138    for (open, close) in &lang.block_comments {
139        trie.insert(
140            open.as_bytes(),
141            TokenMatch {
142                token_type: TokenType::BlockCommentStart,
143                close: Some(close.as_bytes().to_vec()),
144                advance: 0,
145            },
146        );
147    }
148
149    // String delimiters: always " and '
150    trie.insert(
151        b"\"",
152        TokenMatch {
153            token_type: TokenType::StringDelimiter,
154            close: Some(b"\"".to_vec()),
155            advance: 0,
156        },
157    );
158    trie.insert(
159        b"'",
160        TokenMatch {
161            token_type: TokenType::StringDelimiter,
162            close: Some(b"'".to_vec()),
163            advance: 0,
164        },
165    );
166
167    // Doc-string delimiters for Python and Ruby (longest-match wins over single quotes)
168    let name = lang.name.as_str();
169    if name == "Python" || name == "Ruby" {
170        trie.insert(
171            b"\"\"\"",
172            TokenMatch {
173                token_type: TokenType::DocStringDelimiter,
174                close: Some(b"\"\"\"".to_vec()),
175                advance: 0,
176            },
177        );
178        trie.insert(
179            b"'''",
180            TokenMatch {
181                token_type: TokenType::DocStringDelimiter,
182                close: Some(b"'''".to_vec()),
183                advance: 0,
184            },
185        );
186    }
187
188    let mask = trie.process_mask();
189    (trie, mask)
190}
191
/// Fast check: could this byte possibly start a token?
///
/// This is a bloom-filter-style check using a bitwise OR mask of all first
/// bytes of inserted patterns. False positives are expected and acceptable
/// (they just trigger a `match_at` call that returns `None`). False negatives
/// are impossible by construction (`mask |= first_byte` for every pattern).
#[inline(always)]
pub fn should_process(byte: u8, mask: u8) -> bool {
    // True iff no bit of `byte` falls outside `mask` — i.e. every set bit of
    // `byte` is also set in the mask (equivalent to `byte & mask == byte`).
    byte & !mask == 0
}
202
#[cfg(test)]
mod tests {
    use super::*;

    // An empty trie has no children off the root, so every lookup fails.
    #[test]
    fn test_empty_trie_matches_nothing() {
        let trie = TokenTrie::new();
        assert_eq!(trie.match_at(b"hello", 0), None);
    }

    // `insert` must overwrite `advance` with the pattern length (2 for "//"),
    // and line comments carry no close bytes.
    #[test]
    fn test_single_line_comment() {
        let mut trie = TokenTrie::new();
        trie.insert(
            b"//",
            TokenMatch {
                token_type: TokenType::LineComment,
                close: None,
                advance: 0,
            },
        );
        let m = trie.match_at(b"// comment", 0).unwrap();
        assert_eq!(m.token_type, TokenType::LineComment);
        assert_eq!(m.advance, 2);
        assert!(m.close.is_none());
    }

    // Block comment matches must carry the closing byte sequence through.
    #[test]
    fn test_block_comment() {
        let mut trie = TokenTrie::new();
        trie.insert(
            b"/*",
            TokenMatch {
                token_type: TokenType::BlockCommentStart,
                close: Some(b"*/".to_vec()),
                advance: 0,
            },
        );
        let m = trie.match_at(b"/* block */", 0).unwrap();
        assert_eq!(m.token_type, TokenType::BlockCommentStart);
        assert_eq!(m.advance, 2);
        assert_eq!(m.close.as_deref(), Some(b"*/".as_slice()));
    }

    // Matching is anchored at `pos`: the same content fails at 0 but
    // succeeds at the offset where the pattern actually starts.
    #[test]
    fn test_no_match_at_wrong_position() {
        let mut trie = TokenTrie::new();
        trie.insert(
            b"//",
            TokenMatch {
                token_type: TokenType::LineComment,
                close: None,
                advance: 0,
            },
        );
        assert_eq!(trie.match_at(b"x // y", 0), None);
        let m = trie.match_at(b"x // y", 2).unwrap();
        assert_eq!(m.token_type, TokenType::LineComment);
    }

    // Single-byte patterns (string quotes) work like any other pattern.
    #[test]
    fn test_string_delimiter() {
        let mut trie = TokenTrie::new();
        trie.insert(
            b"\"",
            TokenMatch {
                token_type: TokenType::StringDelimiter,
                close: Some(b"\"".to_vec()),
                advance: 0,
            },
        );
        let m = trie.match_at(b"\"hello\"", 0).unwrap();
        assert_eq!(m.token_type, TokenType::StringDelimiter);
        assert_eq!(m.close.as_deref(), Some(b"\"".as_slice()));
    }

    // The mask must admit every inserted first byte and reject at least
    // some unrelated bytes (it is a bloom filter, not an exact set).
    #[test]
    fn test_process_mask_filters_correctly() {
        let mut trie = TokenTrie::new();
        trie.insert(
            b"//",
            TokenMatch {
                token_type: TokenType::LineComment,
                close: None,
                advance: 0,
            },
        );
        trie.insert(
            b"\"",
            TokenMatch {
                token_type: TokenType::StringDelimiter,
                close: Some(b"\"".to_vec()),
                advance: 0,
            },
        );
        let mask = trie.process_mask();
        assert!(should_process(b'/', mask));
        assert!(should_process(b'"', mask));
        // Letters should generally not pass (bloom filter allows some false positives)
        assert!(!should_process(b'a', mask));
    }

    // With both `"` and `"""` inserted, greedy matching must prefer the
    // longer doc-string delimiter.
    #[test]
    fn test_longer_match_wins() {
        let mut trie = TokenTrie::new();
        trie.insert(
            b"\"",
            TokenMatch {
                token_type: TokenType::StringDelimiter,
                close: Some(b"\"".to_vec()),
                advance: 0,
            },
        );
        trie.insert(
            b"\"\"\"",
            TokenMatch {
                token_type: TokenType::DocStringDelimiter,
                close: Some(b"\"\"\"".to_vec()),
                advance: 0,
            },
        );
        let m = trie.match_at(b"\"\"\"hello\"\"\"", 0).unwrap();
        assert_eq!(m.token_type, TokenType::DocStringDelimiter);
        assert_eq!(m.advance, 3);
    }

    // End-to-end: a Rust-like Language yields line comments, block comments
    // (with close bytes), string delimiters, and a non-zero mask.
    #[test]
    fn test_build_from_rust_language() {
        use crate::language::Language;

        let lang = Language {
            name: "Rust".to_string(),
            extensions: vec![".rs".to_string()],
            line_comments: vec!["//".to_string()],
            block_comments: vec![("/*".to_string(), "*/".to_string())],
            nested_comments: true,
            ..Default::default()
        };
        let (trie, mask) = build_from_language(&lang);

        let m = trie.match_at(b"// comment", 0).unwrap();
        assert_eq!(m.token_type, TokenType::LineComment);

        let m = trie.match_at(b"/* block */", 0).unwrap();
        assert_eq!(m.token_type, TokenType::BlockCommentStart);
        assert_eq!(m.close.as_deref(), Some(b"*/".as_slice()));

        let m = trie.match_at(b"\"hello\"", 0).unwrap();
        assert_eq!(m.token_type, TokenType::StringDelimiter);

        let m = trie.match_at(b"'''docstring'''", 0).unwrap();
        assert_eq!(m.token_type, TokenType::DocStringDelimiter);
        assert_eq!(m.close.as_deref(), Some(b"'''".as_slice()));
    }
}
379}