codebook/
parser.rs

1use crate::splitter::{self};
2
3use crate::queries::{LanguageType, get_language_setting};
4use regex::Regex;
5use std::collections::{HashMap, HashSet};
6use streaming_iterator::StreamingIterator;
7use tree_sitter::{Parser, Query, QueryCursor};
8use unicode_segmentation::UnicodeSegmentation;
9
/// A span of text identified by UTF-8 byte offsets.
///
/// The derived ordering compares `start_byte` first and `end_byte` second.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct TextRange {
    /// UTF-8 byte offset at which the span begins.
    pub start_byte: usize,
    /// UTF-8 byte offset at which the span ends.
    pub end_byte: usize,
}
17
/// A region of text that must not be reported, identified by UTF-8 byte offsets.
#[derive(Clone, Copy, Debug, PartialEq)]
struct SkipRange {
    /// UTF-8 byte offset at which the skipped region begins.
    start_byte: usize,
    /// UTF-8 byte offset at which the skipped region ends.
    end_byte: usize,
}
25
26
27/// Check if a word at [start, end) is entirely within any skip range
28fn is_within_skip_range(start: usize, end: usize, skip_ranges: &[SkipRange]) -> bool {
29    skip_ranges
30        .iter()
31        .any(|r| start >= r.start_byte && end <= r.end_byte)
32}
33
34/// Find skip ranges from pattern matches in text.
35fn find_skip_ranges(text: &str, patterns: &[Regex]) -> Vec<SkipRange> {
36    if patterns.is_empty() {
37        return Vec::new();
38    }
39
40    let mut ranges = Vec::new();
41
42    for pattern in patterns {
43        for regex_match in pattern.find_iter(text) {
44            ranges.push(SkipRange {
45                start_byte: regex_match.start(),
46                end_byte: regex_match.end(),
47            });
48        }
49    }
50
51    ranges.sort_by_key(|r| r.start_byte);
52    merge_overlapping_ranges(ranges)
53}
54
55/// Merge overlapping or adjacent ranges
56fn merge_overlapping_ranges(ranges: Vec<SkipRange>) -> Vec<SkipRange> {
57    if ranges.is_empty() {
58        return ranges;
59    }
60
61    let mut merged = Vec::new();
62    let mut current = ranges[0];
63
64    for range in ranges.into_iter().skip(1) {
65        if range.start_byte <= current.end_byte {
66            current.end_byte = current.end_byte.max(range.end_byte);
67        } else {
68            merged.push(current);
69            current = range;
70        }
71    }
72    merged.push(current);
73    merged
74}
75
76/// Helper struct to handle text position tracking and word extraction
77struct TextProcessor {
78    text: String,
79    skip_ranges: Vec<SkipRange>,
80}
81
82impl TextProcessor {
83    fn new(text: &str, skip_patterns: &[Regex]) -> Self {
84        let skip_ranges = find_skip_ranges(text, skip_patterns);
85        Self {
86            text: text.to_string(),
87            skip_ranges,
88        }
89    }
90
91    fn should_skip(&self, start_byte: usize, word_len: usize) -> bool {
92        is_within_skip_range(start_byte, start_byte + word_len, &self.skip_ranges)
93    }
94
95    fn process_words_with_check<F>(&self, mut check_function: F) -> Vec<WordLocation>
96    where
97        F: FnMut(&str) -> bool,
98    {
99        // First pass: collect all unique words with their positions
100        let estimated_words = (self.text.len() as f64 / 6.0).ceil() as usize;
101        let mut word_positions: HashMap<&str, Vec<TextRange>> =
102            HashMap::with_capacity(estimated_words);
103
104        for (offset, word) in self.text.split_word_bound_indices() {
105            if is_alphabetic(word) && !self.should_skip(offset, word.len()) {
106                self.collect_split_words(word, offset, &mut word_positions);
107            }
108        }
109
110        // Second pass: batch check unique words and filter
111        let mut result_locations: HashMap<String, Vec<TextRange>> = HashMap::new();
112        for (word_text, positions) in word_positions {
113            if !check_function(word_text) {
114                result_locations.insert(word_text.to_string(), positions);
115            }
116        }
117
118        result_locations
119            .into_iter()
120            .map(|(word, locations)| WordLocation::new(word, locations))
121            .collect()
122    }
123
124    fn extract_words(&self) -> Vec<WordLocation> {
125        // Reuse the word collection logic by collecting all words (check always returns false)
126        self.process_words_with_check(|_| false)
127    }
128
129    fn collect_split_words<'a>(
130        &self,
131        word: &'a str,
132        offset: usize,
133        word_positions: &mut HashMap<&'a str, Vec<TextRange>>,
134    ) {
135        if !word.is_empty() {
136            let split = splitter::split(word);
137            for split_word in split {
138                if !is_numeric(split_word.word) {
139                    let word_start_byte = offset + split_word.start_byte;
140                    let location = TextRange {
141                        start_byte: word_start_byte,
142                        end_byte: word_start_byte + split_word.word.len(),
143                    };
144                    let word_text = split_word.word;
145                    word_positions.entry(word_text).or_default().push(location);
146                }
147            }
148        }
149    }
150}
151
/// A borrowed word paired with a position.
#[derive(Clone, Debug, PartialEq)]
pub struct WordRef<'a> {
    /// The word itself, borrowed from some longer-lived text.
    pub word: &'a str,
    /// Position of the word as `(start_char, line)`.
    pub position: (u32, u32),
}
157
158#[derive(Debug, Clone, PartialEq)]
159pub struct WordLocation {
160    pub word: String,
161    pub locations: Vec<TextRange>,
162}
163
164impl WordLocation {
165    pub fn new(word: String, locations: Vec<TextRange>) -> Self {
166        Self { word, locations }
167    }
168}
169
170pub fn find_locations(
171    text: &str,
172    language: LanguageType,
173    check_function: impl Fn(&str) -> bool,
174    skip_patterns: &[Regex],
175) -> Vec<WordLocation> {
176    match language {
177        LanguageType::Text => {
178            let processor = TextProcessor::new(text, skip_patterns);
179            processor.process_words_with_check(|word| check_function(word))
180        }
181        _ => find_locations_code(text, language, |word| check_function(word), skip_patterns),
182    }
183}
184
185fn find_locations_code(
186    text: &str,
187    language: LanguageType,
188    check_function: impl Fn(&str) -> bool,
189    skip_patterns: &[Regex],
190) -> Vec<WordLocation> {
191    let language_setting =
192        get_language_setting(language).expect("This _should_ never happen. Famous last words.");
193    let mut parser = Parser::new();
194    let language = language_setting.language().unwrap();
195    parser.set_language(&language).unwrap();
196
197    let tree = parser.parse(text, None).unwrap();
198    let root_node = tree.root_node();
199
200    let query = Query::new(&language, language_setting.query).unwrap();
201    let mut cursor = QueryCursor::new();
202    let mut word_locations: HashMap<String, HashSet<TextRange>> = HashMap::new();
203    let provider = text.as_bytes();
204    let mut matches_query = cursor.matches(&query, root_node, provider);
205
206    // Find all skip ranges from patterns matched against the full source text
207    let all_skip_ranges = find_skip_ranges(text, skip_patterns);
208
209    while let Some(match_) = matches_query.next() {
210        for capture in match_.captures {
211            let node = capture.node;
212            let node_start_byte = node.start_byte();
213
214            let node_text = node.utf8_text(provider).unwrap();
215            let processor = TextProcessor::new(node_text, &[]);
216            let words = processor.extract_words();
217
218            // Check words against global skip ranges and dictionary
219            for word_pos in words {
220                if !check_function(&word_pos.word) {
221                    for range in word_pos.locations {
222                        let global_start = range.start_byte + node_start_byte;
223                        let global_end = range.end_byte + node_start_byte;
224
225                        // Skip if word is entirely within a skip range
226                        if is_within_skip_range(global_start, global_end, &all_skip_ranges) {
227                            continue;
228                        }
229
230                        let location = TextRange {
231                            start_byte: global_start,
232                            end_byte: global_end,
233                        };
234                        if let Some(existing_result) = word_locations.get_mut(&word_pos.word) {
235                            #[cfg(debug_assertions)]
236                            {
237                                let added = existing_result.insert(location);
238                                if !added {
239                                    let word = word_pos.word.clone();
240                                    panic!(
241                                        "Two of the same locations found. Make a better query. Word: {word}, Location: {location:?}"
242                                    )
243                                }
244                            }
245                        } else {
246                            let mut set = HashSet::new();
247                            set.insert(location);
248                            word_locations.insert(word_pos.word.clone(), set);
249                        }
250                    }
251                }
252            }
253        }
254    }
255
256    word_locations
257        .keys()
258        .map(|word| WordLocation {
259            word: word.clone(),
260            locations: word_locations
261                .get(word)
262                .cloned()
263                .unwrap_or_default()
264                .into_iter()
265                .collect(),
266        })
267        .collect()
268}
269
/// Report whether `s` contains at least one numeric character.
fn is_numeric(s: &str) -> bool {
    s.contains(char::is_numeric)
}
273
/// Report whether `c` contains at least one alphabetic character.
fn is_alphabetic(c: &str) -> bool {
    c.contains(char::is_alphabetic)
}
277
/// Extract the part of `text` covered by the UTF-16 code units `[start_utf16, end_utf16)`.
///
/// Both positions count UTF-16 code units (not bytes and not `char`s). A
/// range that reaches past the end of `text` is truncated, and an empty or
/// reversed range (`end_utf16 <= start_utf16`) yields an empty string. If the
/// range cuts a surrogate pair in half, the unpaired half is replaced by
/// U+FFFD.
pub fn get_word_from_string(start_utf16: usize, end_utf16: usize, text: &str) -> String {
    // Saturate so a reversed range is treated as empty instead of underflowing.
    let unit_count = end_utf16.saturating_sub(start_utf16);
    let utf16_slice: Vec<u16> = text
        .encode_utf16()
        .skip(start_utf16)
        .take(unit_count)
        .collect();
    String::from_utf16_lossy(&utf16_slice)
}
287
#[cfg(test)]
mod parser_tests {
    use super::*;

    /// With a checker that rejects every word, each split word is reported.
    #[test]
    fn test_spell_checking() {
        let text = "HelloWorld calc_wrld";
        let results = find_locations(text, LanguageType::Text, |_| false, &[]);
        println!("{results:?}");
        assert_eq!(results.len(), 4);
    }

    /// Every extracted word must sit at one of the expected byte ranges.
    #[test]
    fn test_get_words_from_text() {
        let text = r#"
            HelloWorld calc_wrld
            I'm a contraction, don't ignore me
            this is a 3rd line.
            "#;
        let expected = vec![
            ("Hello", (13, 18)),
            ("World", (18, 23)),
            ("calc", (24, 28)),
            ("wrld", (29, 33)),
            ("I'm", (46, 49)),
            ("a", (50, 51)),
            ("contraction", (52, 63)),
            ("don't", (65, 70)),
            ("ignore", (71, 77)),
            ("me", (78, 80)),
            ("this", (93, 97)),
            ("is", (98, 100)),
            ("a", (101, 102)),
            ("rd", (104, 106)),
            ("line", (107, 111)),
        ];
        let processor = TextProcessor::new(text, &[]);
        let words = processor.extract_words();
        println!("{words:?}");
        // Guard against the loop below passing vacuously.
        assert!(!words.is_empty(), "Expected at least one word to be extracted");
        for word in words {
            let loc = word.locations.first().unwrap();
            let pos = (loc.start_byte, loc.end_byte);
            assert!(
                expected.contains(&(word.word.as_str(), pos)),
                "Expected word '{}' to be at position {:?}",
                word.word,
                pos
            );
        }
    }

    /// Contractions stay in one piece; a trailing apostrophe is not part of the word.
    #[test]
    fn test_contraction() {
        let text = "I'm a contraction, wouldn't you agree'?";
        let processor = TextProcessor::new(text, &[]);
        let words = processor.extract_words();
        println!("{words:?}");
        let expected = ["I'm", "a", "contraction", "wouldn't", "you", "agree"];
        // Guard against the loop below passing vacuously.
        assert!(!words.is_empty(), "Expected at least one word to be extracted");
        for word in words {
            assert!(
                expected.contains(&word.word.as_str()),
                "Unexpected word '{}'",
                word.word
            );
        }
    }

    #[test]
    fn test_get_word_from_string() {
        // Test with ASCII characters
        let text = "Hello World";
        assert_eq!(get_word_from_string(0, 5, text), "Hello");
        assert_eq!(get_word_from_string(6, 11, text), "World");

        // Test with partial words
        assert_eq!(get_word_from_string(2, 5, text), "llo");

        // Test with Unicode characters
        let unicode_text = "こんにちは世界";
        assert_eq!(get_word_from_string(0, 5, unicode_text), "こんにちは");
        assert_eq!(get_word_from_string(5, 7, unicode_text), "世界");

        // Test with a multi-codepoint emoji: four 2-unit emoji joined by three
        // zero-width joiners, i.e. 11 UTF-16 code units in total.
        let family = "\u{1F468}\u{200D}\u{1F469}\u{200D}\u{1F467}\u{200D}\u{1F466}";
        let emoji_text = format!("Hello {family} World");
        assert_eq!(get_word_from_string(6, 17, &emoji_text), family);
    }

    /// Offsets are UTF-8 byte offsets, so a multi-byte character before the
    /// word shifts its position by more than one.
    #[test]
    fn test_unicode_character_handling() {
        crate::logging::init_test_logging();
        let text = "©<div>badword</div>";
        let processor = TextProcessor::new(text, &[]);
        let words = processor.extract_words();
        println!("{words:?}");

        // Make sure "badword" is included and correctly positioned
        let badword = words
            .iter()
            .find(|word| word.word == "badword")
            .expect("Word 'badword' not found in the text");

        // "©" takes two bytes in UTF-8, so "©<div>" is 7 bytes long and
        // "badword" covers bytes 7..14.
        let location = badword.locations.first().unwrap();
        assert_eq!(
            location.start_byte, 7,
            "Expected 'badword' to start at byte offset 7"
        );
        assert_eq!(
            location.end_byte, 14,
            "Expected 'badword' to end at byte offset 14"
        );
    }

    // TODO: add a `find_locations` variant of the test above that parses
    // "©<div>badword</div>" as HTML with a checker that rejects only "badword",
    // and expects exactly one location for it. An earlier version of that test
    // was disabled with the note "Something is up with the HTML tree-sitter
    // package"; it predates byte-offset ranges and the `skip_patterns`
    // argument, so it needs rewriting against the current API when revived.
}