Skip to main content

codebook/
parser.rs

1use crate::checker::WordCandidate;
2use crate::queries::{LANGUAGE_SETTINGS, LanguageType, get_language_setting};
3use crate::splitter;
4use regex::Regex;
5use std::collections::{HashMap, HashSet};
6use std::str::FromStr;
7use std::sync::{LazyLock, Mutex};
8use streaming_iterator::StreamingIterator;
9use tree_sitter::{Parser, Query, QueryCursor};
10use unicode_segmentation::UnicodeSegmentation;
11
12/// Global parser cache protected by a mutex. Serializes all tree-sitter
13/// operations (create, parse, destroy) to protect external scanners that
14/// use global mutable C state (e.g. tree-sitter-vhdl's static TokenTree).
15static PARSER_CACHE: LazyLock<Mutex<HashMap<LanguageType, Parser>>> =
16    LazyLock::new(|| Mutex::new(HashMap::new()));
17
18/// Pre-compiled query for a language, with its capture names.
19struct CompiledQuery {
20    query: Query,
21    capture_names: Vec<String>,
22}
23
24/// All tree-sitter queries compiled eagerly at startup. Since queries come
25/// from static `include_str!` data, they never change at runtime. Compiling
26/// them once here means bad queries panic immediately rather than hiding
27/// until a user opens that file type.
28static COMPILED_QUERIES: LazyLock<HashMap<LanguageType, CompiledQuery>> = LazyLock::new(|| {
29    let mut map = HashMap::new();
30    for setting in LANGUAGE_SETTINGS {
31        let Some(lang) = setting.language() else {
32            continue;
33        };
34        if setting.query.is_empty() {
35            continue;
36        }
37        let query = Query::new(&lang, setting.query)
38            .unwrap_or_else(|e| panic!("Failed to compile query for {:?}: {e}", setting.type_));
39        let capture_names = query
40            .capture_names()
41            .iter()
42            .map(|s| s.to_string())
43            .collect();
44        map.insert(
45            setting.type_,
46            CompiledQuery {
47                query,
48                capture_names,
49            },
50        );
51    }
52    map
53});
54
/// A span of UTF-8 byte offsets within a document.
///
/// The derived ordering compares `start_byte` first and `end_byte` second.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct TextRange {
    /// UTF-8 byte offset at which the range begins.
    pub start_byte: usize,
    /// UTF-8 byte offset at which the range ends.
    pub end_byte: usize,
}
62
/// A byte span of the document matched by a skip pattern; words lying
/// entirely inside it are not reported as candidates.
#[derive(Debug, Clone, Copy, PartialEq)]
struct SkipRange {
    /// Byte offset where the matched span starts.
    start_byte: usize,
    /// Byte offset just past the end of the matched span.
    end_byte: usize,
}
68
69fn is_within_skip_range(start: usize, end: usize, skip_ranges: &[SkipRange]) -> bool {
70    skip_ranges
71        .iter()
72        .any(|r| start >= r.start_byte && end <= r.end_byte)
73}
74
75fn find_skip_ranges(text: &str, patterns: &[Regex]) -> Vec<SkipRange> {
76    if patterns.is_empty() {
77        return Vec::new();
78    }
79    let mut ranges = Vec::new();
80    for pattern in patterns {
81        for regex_match in pattern.find_iter(text) {
82            ranges.push(SkipRange {
83                start_byte: regex_match.start(),
84                end_byte: regex_match.end(),
85            });
86        }
87    }
88    ranges.sort_by_key(|r| r.start_byte);
89    merge_overlapping_ranges(ranges)
90}
91
92fn merge_overlapping_ranges(ranges: Vec<SkipRange>) -> Vec<SkipRange> {
93    if ranges.is_empty() {
94        return ranges;
95    }
96    let mut merged = Vec::new();
97    let mut current = ranges[0];
98    for range in ranges.into_iter().skip(1) {
99        if range.start_byte <= current.end_byte {
100            current.end_byte = current.end_byte.max(range.end_byte);
101        } else {
102            merged.push(current);
103            current = range;
104        }
105    }
106    merged.push(current);
107    merged
108}
109
110#[derive(Debug, Clone, PartialEq)]
111pub struct WordLocation {
112    pub word: String,
113    pub locations: Vec<TextRange>,
114}
115
116impl WordLocation {
117    pub fn new(word: String, locations: Vec<TextRange>) -> Self {
118        Self { word, locations }
119    }
120}
121
// =============================================================================
// Main entry point: recursive word extraction with injection support
// =============================================================================
125
126/// Extract all candidate words from a document, recursively following
127/// `@injection.*` captures in .scm query files to handle multi-language files.
128///
129/// Returns the candidates and the set of all languages encountered (for
130/// dictionary loading).
131pub fn extract_all_words<'a>(
132    document_text: &'a str,
133    language: LanguageType,
134    tag_filter: &dyn Fn(&str) -> bool,
135    skip_patterns: &[Regex],
136) -> (Vec<WordCandidate<'a>>, HashSet<LanguageType>) {
137    let skip_ranges = find_skip_ranges(document_text, skip_patterns);
138    let mut result = ExtractionResult {
139        candidates: Vec::new(),
140        languages: HashSet::from([language]),
141    };
142
143    extract_recursive(
144        document_text,
145        0,
146        document_text.len(),
147        language,
148        tag_filter,
149        &skip_ranges,
150        &mut result,
151    );
152
153    (result.candidates, result.languages)
154}
155
156/// Accumulated output from recursive word extraction.
157struct ExtractionResult<'a> {
158    candidates: Vec<WordCandidate<'a>>,
159    languages: HashSet<LanguageType>,
160}
161
162/// Recursively extract words from a byte range of the document.
163///
164/// For languages with a tree-sitter grammar and .scm query:
165///   - Text captures (`@string`, `@comment`, `@identifier.*`) → word-split
166///   - Static injections (`@injection.{lang}`) → recurse with that language
167///   - Dynamic injections (`@injection.content` + `@injection.language`) → read
168///     the language name from the sibling capture, then recurse
169///
170/// For LanguageType::Text (no grammar): word-split the entire range.
171fn extract_recursive<'a>(
172    document_text: &'a str,
173    start_byte: usize,
174    end_byte: usize,
175    language: LanguageType,
176    tag_filter: &dyn Fn(&str) -> bool,
177    skip_ranges: &[SkipRange],
178    result: &mut ExtractionResult<'a>,
179) {
180    let language_setting = match get_language_setting(language) {
181        Some(s) => s,
182        None => {
183            // No grammar (e.g. Text): word-split the whole range
184            let text = &document_text[start_byte..end_byte];
185            extract_words_from_text(text, start_byte, skip_ranges, &mut result.candidates);
186            return;
187        }
188    };
189
190    let region_text = &document_text[start_byte..end_byte];
191
192    // Parse under global lock
193    let tree = {
194        let mut cache = PARSER_CACHE.lock().unwrap();
195        let parser = cache.entry(language).or_insert_with(|| {
196            let mut parser = Parser::new();
197            let lang = language_setting.language().unwrap();
198            parser.set_language(&lang).unwrap();
199            parser
200        });
201        parser.parse(region_text, None).unwrap()
202    };
203
204    let root_node = tree.root_node();
205    let compiled = COMPILED_QUERIES
206        .get(&language)
207        .expect("Language has a LanguageSetting but no compiled query; this should not happen");
208    let mut cursor = QueryCursor::new();
209    let provider = region_text.as_bytes();
210    let mut matches_query = cursor.matches(&compiled.query, root_node, provider);
211
212    while let Some(match_) = matches_query.next() {
213        // First pass: look for dynamic injection pairs in this match
214        let mut injection_content: Option<tree_sitter::Node> = None;
215        let mut injection_language_text: Option<&str> = None;
216
217        for capture in match_.captures {
218            let tag = &compiled.capture_names[capture.index as usize];
219            if tag == "injection.content" {
220                injection_content = Some(capture.node);
221            } else if tag == "injection.language" {
222                injection_language_text = Some(capture.node.utf8_text(provider).unwrap_or(""));
223            }
224        }
225
226        // Handle dynamic injection pair
227        if let Some(content_node) = injection_content {
228            if let Some(lang_text) = injection_language_text {
229                let lowered = lang_text.trim().to_lowercase();
230                let child_lang = LanguageType::from_str(&lowered);
231                if let Ok(child_lang) = child_lang
232                    && child_lang != LanguageType::Text
233                {
234                    let child_start = content_node.start_byte() + start_byte;
235                    let child_end = content_node.end_byte() + start_byte;
236                    if child_start < child_end {
237                        result.languages.insert(child_lang);
238                        extract_recursive(
239                            document_text,
240                            child_start,
241                            child_end,
242                            child_lang,
243                            tag_filter,
244                            skip_ranges,
245                            result,
246                        );
247                    }
248                }
249            }
250            continue;
251        }
252
253        // Second pass: handle text captures and static injections
254        for capture in match_.captures {
255            let tag = &compiled.capture_names[capture.index as usize];
256            let node = capture.node;
257            let node_start = node.start_byte() + start_byte;
258            let node_end = node.end_byte() + start_byte;
259
260            if node_start >= node_end {
261                continue;
262            }
263
264            if tag == "language" || tag == "injection.language" {
265                continue;
266            }
267
268            if let Some(lang_name) = tag.strip_prefix("injection.") {
269                // Static injection: @injection.html, @injection.javascript, etc.
270                if let Ok(child_lang) = LanguageType::from_str(lang_name)
271                    && child_lang != LanguageType::Text
272                {
273                    result.languages.insert(child_lang);
274                    extract_recursive(
275                        document_text,
276                        node_start,
277                        node_end,
278                        child_lang,
279                        tag_filter,
280                        skip_ranges,
281                        result,
282                    );
283                }
284                continue;
285            }
286
287            // Normal text capture: extract words if tag passes filter
288            if !tag_filter(tag) {
289                continue;
290            }
291
292            let node_text = node.utf8_text(provider).unwrap();
293            extract_words_from_text(node_text, node_start, skip_ranges, &mut result.candidates);
294        }
295    }
296}
297
// =============================================================================
// Word extraction from plain text
// =============================================================================
301
302fn extract_words_from_text<'a>(
303    text: &'a str,
304    base_offset: usize,
305    skip_ranges: &[SkipRange],
306    candidates: &mut Vec<WordCandidate<'a>>,
307) {
308    let mut split_buf = Vec::new();
309    for (offset, word) in text.split_word_bound_indices() {
310        if !is_alphabetic(word) {
311            continue;
312        }
313        let global_offset = base_offset + offset;
314        if is_within_skip_range(global_offset, global_offset + word.len(), skip_ranges) {
315            continue;
316        }
317        splitter::split_into(word, &mut split_buf);
318        for split_word in &split_buf {
319            if is_numeric(split_word.word) {
320                continue;
321            }
322            let word_start = global_offset + split_word.start_byte;
323            let word_end = word_start + split_word.word.len();
324            if is_within_skip_range(word_start, word_end, skip_ranges) {
325                continue;
326            }
327            candidates.push(WordCandidate {
328                word: split_word.word,
329                start_byte: word_start,
330                end_byte: word_end,
331            });
332        }
333    }
334}
335
/// Returns `true` if `s` contains at least one numeric character (as defined
/// by `char::is_numeric`). It does not require every character to be
/// numeric. Used to drop split pieces that contain a number.
fn is_numeric(s: &str) -> bool {
    s.chars().any(char::is_numeric)
}
339
/// Returns `true` if `c` contains at least one alphabetic character. Used
/// to skip word-boundary segments made only of punctuation, whitespace, or
/// digits.
fn is_alphabetic(c: &str) -> bool {
    c.chars().any(char::is_alphabetic)
}
343
/// Extract the part of `text` between UTF-16 code-unit offsets `start_utf16`
/// (inclusive) and `end_utf16` (exclusive), returned as a UTF-8 `String`.
///
/// Edge cases:
/// - Offsets past the end of `text` are clamped.
/// - An inverted range (`end_utf16 < start_utf16`) yields an empty string.
///   The original subtraction panicked here in debug builds and returned
///   the rest of the string in release builds.
/// - If a bound splits a surrogate pair, the orphaned half becomes U+FFFD
///   REPLACEMENT CHARACTER.
pub fn get_word_from_string(start_utf16: usize, end_utf16: usize, text: &str) -> String {
    let len_utf16 = end_utf16.saturating_sub(start_utf16);
    let units = text.encode_utf16().skip(start_utf16).take(len_utf16);
    // Decode straight into the output instead of collecting a Vec<u16> first.
    // Mapping errors to U+FFFD matches `String::from_utf16_lossy`.
    char::decode_utf16(units)
        .map(|decoded| decoded.unwrap_or(char::REPLACEMENT_CHARACTER))
        .collect()
}
353
#[cfg(test)]
mod tests {
    use super::*;

    /// Tag filter that accepts every capture.
    fn accept_all(_tag: &str) -> bool {
        true
    }

    /// Run extraction with no tag filtering and no skip patterns.
    fn extract(
        text: &str,
        language: LanguageType,
    ) -> (Vec<WordCandidate<'_>>, HashSet<LanguageType>) {
        extract_all_words(text, language, &accept_all, &[])
    }

    /// The word strings of `words`, in extraction order.
    fn word_texts<'a>(words: &[WordCandidate<'a>]) -> Vec<&'a str> {
        words.iter().map(|w| w.word).collect()
    }

    #[test]
    fn test_extract_words_plain_text() {
        let (words, langs) = extract("HelloWorld calc_wrld", LanguageType::Text);
        let found = word_texts(&words);
        for expected in ["Hello", "World", "calc", "wrld"] {
            assert!(found.contains(&expected));
        }
        assert_eq!(words.len(), 4);
        assert!(langs.contains(&LanguageType::Text));
    }

    #[test]
    fn test_extract_words_contraction() {
        let (words, _) = extract("I'm a contraction, wouldn't you agree'?", LanguageType::Text);
        let found = word_texts(&words);
        for e in &["I'm", "a", "contraction", "wouldn't", "you", "agree"] {
            assert!(found.contains(e), "Expected word '{e}' not found");
        }
    }

    #[test]
    fn test_extract_words_code() {
        let (words, langs) = extract("// a comment\nfn main() {}", LanguageType::Rust);
        assert!(!words.is_empty());
        assert!(
            word_texts(&words).contains(&"comment"),
            "Should find 'comment' in Rust comment"
        );
        assert!(langs.contains(&LanguageType::Rust));
    }

    #[test]
    fn test_extract_words_tag_filter() {
        let text = "// comment\nlet x = \"string value\";";
        let (words, _) = extract_all_words(
            text,
            LanguageType::Rust,
            &|tag| tag.starts_with("comment"),
            &[],
        );
        let found = word_texts(&words);
        assert!(found.contains(&"comment"));
        assert!(!found.contains(&"string"));
        assert!(!found.contains(&"value"));
    }

    #[test]
    fn test_extract_words_with_skip_patterns() {
        let url_pattern = Regex::new(r"https?://[^\s]+").unwrap();
        let (words, _) = extract_all_words(
            "check https://example.com this",
            LanguageType::Text,
            &accept_all,
            &[url_pattern],
        );
        let found = word_texts(&words);
        assert!(found.contains(&"check"));
        assert!(found.contains(&"this"));
        assert!(!found.contains(&"https"));
        assert!(!found.contains(&"example"));
    }

    #[test]
    fn test_extract_words_code_duplicates() {
        let (words, _) = extract("// wrld foo wrld", LanguageType::Rust);
        let wrld_count = words.iter().filter(|w| w.word == "wrld").count();
        assert_eq!(wrld_count, 2, "Expected two occurrences of 'wrld'");
    }

    #[test]
    fn test_markdown_injection_discovers_languages() {
        let text =
            "# Hello\n\nSome text.\n\n```python\ndef foo(): pass\n```\n\n```bash\necho hi\n```\n";
        let (_, langs) = extract(text, LanguageType::Markdown);
        for lang in [LanguageType::Markdown, LanguageType::Python, LanguageType::Bash] {
            assert!(langs.contains(&lang));
        }
    }

    #[test]
    fn test_markdown_injection_extracts_code_words() {
        let text = "# Hello\n\n```python\ndef some_functin(): pass\n```\n";
        let (words, _) = extract(text, LanguageType::Markdown);
        let found = word_texts(&words);
        assert!(found.contains(&"functin"));
        assert!(found.contains(&"Hello"));
    }

    #[test]
    fn test_markdown_unknown_language_skipped() {
        let text = "# Hello\n\n```unknownlang\nbadwwword\n```\n";
        let (words, _) = extract(text, LanguageType::Markdown);
        assert!(!word_texts(&words).contains(&"badwwword"));
    }

    #[test]
    fn test_markdown_html_block_injection() {
        let text = "# Hello\n\n<div>\n  <p>A misspeled word</p>\n</div>\n\nMore text.\n";
        let (words, langs) = extract(text, LanguageType::Markdown);
        let found = word_texts(&words);
        assert!(langs.contains(&LanguageType::HTML));
        assert!(found.contains(&"misspeled"));
        assert!(!found.contains(&"div"));
    }

    #[test]
    fn test_get_word_from_string() {
        let cases = [
            ("Hello World", 0, 5, "Hello"),
            ("Hello World", 6, 11, "World"),
            ("こんにちは世界", 0, 5, "こんにちは"),
            ("こんにちは世界", 5, 7, "世界"),
            ("Hello 👨‍👩‍👧‍👦 World", 6, 17, "👨‍👩‍👧‍👦"),
        ];
        for (text, start, end, expected) in cases {
            assert_eq!(get_word_from_string(start, end, text), expected);
        }
    }

    #[test]
    fn test_unicode_character_handling() {
        crate::logging::init_test_logging();
        let (words, _) = extract("©<div>badword</div>", LanguageType::Text);
        let bw = words
            .iter()
            .find(|w| w.word == "badword")
            .expect("Expected 'badword' to be found");
        assert_eq!(bw.start_byte, 7);
        assert_eq!(bw.end_byte, 14);
    }
}