Skip to main content

codebook/
parser.rs

1use crate::splitter::{self};
2
3use crate::queries::{LanguageType, get_language_setting};
4use regex::Regex;
5use std::collections::{HashMap, HashSet};
6use std::sync::{LazyLock, Mutex};
7use streaming_iterator::StreamingIterator;
8use tree_sitter::{Parser, Query, QueryCursor};
9use unicode_segmentation::UnicodeSegmentation;
10
11/// Global parser cache protected by a mutex. Serializes all tree-sitter
12/// operations (create, parse, destroy) to protect external scanners that
13/// use global mutable C state (e.g. tree-sitter-vhdl's static TokenTree).
14static PARSER_CACHE: LazyLock<Mutex<HashMap<LanguageType, Parser>>> =
15    LazyLock::new(|| Mutex::new(HashMap::new()));
16
/// A half-open span `[start_byte, end_byte)` expressed in UTF-8 byte offsets.
///
/// The derived ordering compares `start_byte` first and then `end_byte`.
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct TextRange {
    /// Inclusive start, as a UTF-8 byte offset.
    pub start_byte: usize,
    /// Exclusive end, as a UTF-8 byte offset.
    pub end_byte: usize,
}
24
/// A half-open span `[start_byte, end_byte)` of UTF-8 byte offsets covered by a
/// skip-pattern match; words lying entirely inside it are not spell-checked.
#[derive(Clone, Copy, Debug, PartialEq)]
struct SkipRange {
    /// Inclusive start, as a UTF-8 byte offset.
    start_byte: usize,
    /// Exclusive end, as a UTF-8 byte offset.
    end_byte: usize,
}
32
33/// Check if a word at [start, end) is entirely within any skip range
34fn is_within_skip_range(start: usize, end: usize, skip_ranges: &[SkipRange]) -> bool {
35    skip_ranges
36        .iter()
37        .any(|r| start >= r.start_byte && end <= r.end_byte)
38}
39
40/// Find skip ranges from pattern matches in text.
41fn find_skip_ranges(text: &str, patterns: &[Regex]) -> Vec<SkipRange> {
42    if patterns.is_empty() {
43        return Vec::new();
44    }
45
46    let mut ranges = Vec::new();
47
48    for pattern in patterns {
49        for regex_match in pattern.find_iter(text) {
50            ranges.push(SkipRange {
51                start_byte: regex_match.start(),
52                end_byte: regex_match.end(),
53            });
54        }
55    }
56
57    ranges.sort_by_key(|r| r.start_byte);
58    merge_overlapping_ranges(ranges)
59}
60
61/// Merge overlapping or adjacent ranges
62fn merge_overlapping_ranges(ranges: Vec<SkipRange>) -> Vec<SkipRange> {
63    if ranges.is_empty() {
64        return ranges;
65    }
66
67    let mut merged = Vec::new();
68    let mut current = ranges[0];
69
70    for range in ranges.into_iter().skip(1) {
71        if range.start_byte <= current.end_byte {
72            current.end_byte = current.end_byte.max(range.end_byte);
73        } else {
74            merged.push(current);
75            current = range;
76        }
77    }
78    merged.push(current);
79    merged
80}
81
82/// Helper struct to handle text position tracking and word extraction
83struct TextProcessor {
84    text: String,
85    skip_ranges: Vec<SkipRange>,
86}
87
88impl TextProcessor {
89    fn new(text: &str, skip_patterns: &[Regex]) -> Self {
90        let skip_ranges = find_skip_ranges(text, skip_patterns);
91        Self {
92            text: text.to_string(),
93            skip_ranges,
94        }
95    }
96
97    fn should_skip(&self, start_byte: usize, word_len: usize) -> bool {
98        is_within_skip_range(start_byte, start_byte + word_len, &self.skip_ranges)
99    }
100
101    fn process_words_with_check<F>(&self, mut check_function: F) -> Vec<WordLocation>
102    where
103        F: FnMut(&str) -> bool,
104    {
105        // First pass: collect all unique words with their positions
106        let estimated_words = (self.text.len() as f64 / 6.0).ceil() as usize;
107        let mut word_positions: HashMap<&str, Vec<TextRange>> =
108            HashMap::with_capacity(estimated_words);
109
110        for (offset, word) in self.text.split_word_bound_indices() {
111            if is_alphabetic(word) && !self.should_skip(offset, word.len()) {
112                self.collect_split_words(word, offset, &mut word_positions);
113            }
114        }
115
116        // Second pass: batch check unique words and filter
117        let mut result_locations: HashMap<String, Vec<TextRange>> = HashMap::new();
118        for (word_text, positions) in word_positions {
119            if !check_function(word_text) {
120                result_locations.insert(word_text.to_string(), positions);
121            }
122        }
123
124        result_locations
125            .into_iter()
126            .map(|(word, locations)| WordLocation::new(word, locations))
127            .collect()
128    }
129
130    fn extract_words(&self) -> Vec<WordLocation> {
131        // Reuse the word collection logic by collecting all words (check always returns false)
132        self.process_words_with_check(|_| false)
133    }
134
135    fn collect_split_words<'a>(
136        &self,
137        word: &'a str,
138        offset: usize,
139        word_positions: &mut HashMap<&'a str, Vec<TextRange>>,
140    ) {
141        if !word.is_empty() {
142            let split = splitter::split(word);
143            for split_word in split {
144                if !is_numeric(split_word.word) {
145                    let word_start_byte = offset + split_word.start_byte;
146                    let location = TextRange {
147                        start_byte: word_start_byte,
148                        end_byte: word_start_byte + split_word.word.len(),
149                    };
150                    let word_text = split_word.word;
151                    word_positions.entry(word_text).or_default().push(location);
152                }
153            }
154        }
155    }
156}
157
/// A word borrowed from a document, together with its position.
#[derive(Clone, Debug, PartialEq)]
pub struct WordRef<'a> {
    /// The word text, borrowed from the source document.
    pub word: &'a str,
    /// `(start_char, line)`: the column comes first, then the line.
    pub position: (u32, u32),
}
163
164#[derive(Debug, Clone, PartialEq)]
165pub struct WordLocation {
166    pub word: String,
167    pub locations: Vec<TextRange>,
168}
169
170impl WordLocation {
171    pub fn new(word: String, locations: Vec<TextRange>) -> Self {
172        Self { word, locations }
173    }
174}
175
176pub fn find_locations(
177    text: &str,
178    language: LanguageType,
179    check_function: impl Fn(&str) -> bool,
180    skip_patterns: &[Regex],
181) -> Vec<WordLocation> {
182    match language {
183        LanguageType::Text => {
184            let processor = TextProcessor::new(text, skip_patterns);
185            processor.process_words_with_check(|word| check_function(word))
186        }
187        _ => find_locations_code(text, language, |word| check_function(word), skip_patterns),
188    }
189}
190
191fn find_locations_code(
192    text: &str,
193    language: LanguageType,
194    check_function: impl Fn(&str) -> bool,
195    skip_patterns: &[Regex],
196) -> Vec<WordLocation> {
197    let language_setting =
198        get_language_setting(language).expect("This _should_ never happen. Famous last words.");
199
200    // Parse under global lock to protect external scanners with global C state.
201    // The lock covers create + parse; Tree is fully owned after parse returns.
202    let tree = {
203        let mut cache = PARSER_CACHE.lock().unwrap();
204        let parser = cache.entry(language).or_insert_with(|| {
205            let mut parser = Parser::new();
206            let lang = language_setting.language().unwrap();
207            parser.set_language(&lang).unwrap();
208            parser
209        });
210        parser.parse(text, None).unwrap()
211    };
212
213    let root_node = tree.root_node();
214    let lang = language_setting.language().unwrap();
215    let query = Query::new(&lang, language_setting.query).unwrap();
216    let mut cursor = QueryCursor::new();
217    let mut word_locations: HashMap<String, HashSet<TextRange>> = HashMap::new();
218    let provider = text.as_bytes();
219    let mut matches_query = cursor.matches(&query, root_node, provider);
220
221    // Find all skip ranges from patterns matched against the full source text
222    let all_skip_ranges = find_skip_ranges(text, skip_patterns);
223
224    while let Some(match_) = matches_query.next() {
225        for capture in match_.captures {
226            let node = capture.node;
227            let node_start_byte = node.start_byte();
228
229            let node_text = node.utf8_text(provider).unwrap();
230            let processor = TextProcessor::new(node_text, &[]);
231            let words = processor.extract_words();
232
233            // Check words against global skip ranges and dictionary
234            for word_pos in words {
235                if !check_function(&word_pos.word) {
236                    for range in word_pos.locations {
237                        let global_start = range.start_byte + node_start_byte;
238                        let global_end = range.end_byte + node_start_byte;
239
240                        // Skip if word is entirely within a skip range
241                        if is_within_skip_range(global_start, global_end, &all_skip_ranges) {
242                            continue;
243                        }
244
245                        let location = TextRange {
246                            start_byte: global_start,
247                            end_byte: global_end,
248                        };
249                        if let Some(existing_result) = word_locations.get_mut(&word_pos.word) {
250                            let added = existing_result.insert(location);
251                            debug_assert!(
252                                added,
253                                "Two of the same locations found. Make a better query. Word: {}, Location: {:?}",
254                                word_pos.word, location
255                            );
256                        } else {
257                            let mut set = HashSet::new();
258                            set.insert(location);
259                            word_locations.insert(word_pos.word.clone(), set);
260                        }
261                    }
262                }
263            }
264        }
265    }
266
267    word_locations
268        .keys()
269        .map(|word| WordLocation {
270            word: word.clone(),
271            locations: word_locations
272                .get(word)
273                .cloned()
274                .unwrap_or_default()
275                .into_iter()
276                .collect(),
277        })
278        .collect()
279}
280
/// Returns `true` if `s` contains at least one numeric character (per
/// [`char::is_numeric`]). The whole string does not need to be a number.
fn is_numeric(s: &str) -> bool {
    s.chars().any(char::is_numeric)
}
284
/// Returns `true` if `s` contains at least one alphabetic character (per
/// [`char::is_alphabetic`]), e.g. `"3rd"` qualifies but `"123"` does not.
fn is_alphabetic(s: &str) -> bool {
    s.chars().any(char::is_alphabetic)
}
288
/// Extract the substring of `text` between two UTF-16 code-unit offsets.
///
/// `start_utf16` is inclusive and `end_utf16` is exclusive; both count UTF-16 code
/// units, not bytes. Offsets past the end of `text` are clamped. An inverted range
/// (`end_utf16 < start_utf16`) yields an empty string instead of underflowing,
/// which previously panicked in debug builds and returned the rest of the text in
/// release builds. If a bound splits a surrogate pair, the orphaned half is
/// replaced with U+FFFD.
pub fn get_word_from_string(start_utf16: usize, end_utf16: usize, text: &str) -> String {
    let len = end_utf16.saturating_sub(start_utf16);
    let utf16_slice: Vec<u16> = text.encode_utf16().skip(start_utf16).take(len).collect();
    String::from_utf16_lossy(&utf16_slice)
}
298
/// Unit tests for word extraction and location finding. All positions asserted
/// here are UTF-8 byte offsets.
#[cfg(test)]
mod parser_tests {
    use super::*;

    #[test]
    fn test_spell_checking() {
        let text = "HelloWorld calc_wrld";
        let results = find_locations(text, LanguageType::Text, |_| false, &[]);
        println!("{results:?}");
        assert_eq!(results.len(), 4);
    }

    #[test]
    fn test_get_words_from_text() {
        // Expected positions include the raw string's leading newline and its
        // 12-space indentation.
        let text = r#"
            HelloWorld calc_wrld
            I'm a contraction, don't ignore me
            this is a 3rd line.
            "#;
        let expected = vec![
            ("Hello", (13, 18)),
            ("World", (18, 23)),
            ("calc", (24, 28)),
            ("wrld", (29, 33)),
            ("I'm", (46, 49)),
            ("a", (50, 51)),
            ("contraction", (52, 63)),
            ("don't", (65, 70)),
            ("ignore", (71, 77)),
            ("me", (78, 80)),
            ("this", (93, 97)),
            ("is", (98, 100)),
            ("a", (101, 102)),
            ("rd", (104, 106)),
            ("line", (107, 111)),
        ];
        let processor = TextProcessor::new(text, &[]);
        let words = processor.extract_words();
        println!("{words:?}");
        for word in words {
            let loc = word.locations.first().unwrap();
            let pos = (loc.start_byte, loc.end_byte);
            assert!(
                expected.contains(&(word.word.as_str(), pos)),
                "Expected word '{}' to be at position {:?}",
                word.word,
                pos
            );
        }
    }

    #[test]
    fn test_contraction() {
        let text = "I'm a contraction, wouldn't you agree'?";
        let processor = TextProcessor::new(text, &[]);
        let words = processor.extract_words();
        println!("{words:?}");
        let expected = ["I'm", "a", "contraction", "wouldn't", "you", "agree"];
        for word in words {
            assert!(expected.contains(&word.word.as_str()));
        }
    }

    #[test]
    fn test_get_word_from_string() {
        // Test with ASCII characters
        let text = "Hello World";
        assert_eq!(get_word_from_string(0, 5, text), "Hello");
        assert_eq!(get_word_from_string(6, 11, text), "World");

        // Test with partial words
        assert_eq!(get_word_from_string(2, 5, text), "llo");

        // Test with Unicode characters
        let unicode_text = "こんにちは世界";
        assert_eq!(get_word_from_string(0, 5, unicode_text), "こんにちは");
        assert_eq!(get_word_from_string(5, 7, unicode_text), "世界");

        // Test with emoji (which can be multi-codepoint)
        let emoji_text = "Hello 👨‍👩‍👧‍👦 World";
        assert_eq!(get_word_from_string(6, 17, emoji_text), "👨‍👩‍👧‍👦");
    }

    #[test]
    fn test_unicode_character_handling() {
        crate::logging::init_test_logging();
        let text = "©<div>badword</div>";
        let processor = TextProcessor::new(text, &[]);
        let words = processor.extract_words();
        println!("{words:?}");

        // Offsets are UTF-8 bytes: "©" takes 2 bytes, so "©<div>" is 6 characters
        // but 7 bytes, and "badword" occupies bytes [7, 14).
        let Some(pos) = words.iter().find(|word| word.word == "badword") else {
            panic!("Word 'badword' not found in the text");
        };
        let location = pos.locations.first().unwrap();
        assert_eq!(
            location.start_byte, 7,
            "Expected 'badword' to start at byte offset 7"
        );
        assert_eq!(
            location.end_byte, 14,
            "Expected 'badword' to end at byte offset 14"
        );
    }

    #[test]
    fn test_duplicate_word_locations() {
        // Use a code language to exercise find_locations_code path
        let text = "// wrld foo wrld";
        let results = find_locations(text, LanguageType::Rust, |_| false, &[]);
        let wrld = results.iter().find(|loc| loc.word == "wrld").unwrap();
        assert_eq!(
            wrld.locations.len(),
            2,
            "Expected two locations for repeated word 'wrld'"
        );
    }

    // Disabled: something is up with the HTML tree-sitter package.
    // The body below has been updated to the current API (`crate::logging`, the
    // `skip_patterns` argument, byte-offset `TextRange` fields) so it can be
    // re-enabled as-is. TODO: confirm `LanguageType::Html` is still the right variant.
    // #[test]
    // fn test_spell_checking_with_unicode() {
    //     crate::logging::init_test_logging();
    //     let text = "©<div>badword</div>";
    //
    //     // Mock spell check function that flags "badword"
    //     let results = find_locations(text, LanguageType::Html, |word| word != "badword", &[]);
    //
    //     println!("{:?}", results);
    //
    //     // Ensure "badword" is flagged
    //     let badword_result = results.iter().find(|loc| loc.word == "badword");
    //     assert!(badword_result.is_some(), "Expected 'badword' to be flagged");
    //
    //     // Check if the location is correct
    //     if let Some(location) = badword_result {
    //         assert_eq!(
    //             location.locations.len(),
    //             1,
    //             "Expected exactly one location for 'badword'"
    //         );
    //         let range = &location.locations[0];
    //
    //         // "©<div>" is 6 characters but 7 UTF-8 bytes ("©" takes 2 bytes)
    //         assert_eq!(range.start_byte, 7, "Wrong start position for 'badword'");
    //
    //         // "badword" is 7 bytes long, so it ends at byte 14
    //         assert_eq!(range.end_byte, 14, "Wrong end position for 'badword'");
    //     }
    // }
}