1use crate::splitter::{self};
2
3use crate::queries::{LanguageType, get_language_setting};
4use regex::Regex;
5use std::collections::{HashMap, HashSet};
6use streaming_iterator::StreamingIterator;
7use tree_sitter::{Parser, Query, QueryCursor};
8use unicode_segmentation::UnicodeSegmentation;
9
/// A span of text expressed as byte offsets.
///
/// Ranges produced in this file are half-open: `end_byte` is computed as
/// `start_byte + word.len()`, so it points one past the last byte of the word.
#[derive(Debug, Clone, Copy, PartialEq, Ord, Eq, PartialOrd, Hash)]
pub struct TextRange {
    /// Byte offset at which the span begins.
    pub start_byte: usize,
    /// Byte offset one past the final byte of the span.
    pub end_byte: usize,
}
17
/// A byte span of the input that matched one of the skip patterns.
///
/// Words lying entirely inside such a span are not reported.
#[derive(Debug, Clone, Copy, PartialEq)]
struct SkipRange {
    /// Start offset of the regex match, in bytes.
    start_byte: usize,
    /// End offset of the regex match, in bytes (as reported by `Match::end`).
    end_byte: usize,
}
25
26
27fn is_within_skip_range(start: usize, end: usize, skip_ranges: &[SkipRange]) -> bool {
29 skip_ranges
30 .iter()
31 .any(|r| start >= r.start_byte && end <= r.end_byte)
32}
33
34fn find_skip_ranges(text: &str, patterns: &[Regex]) -> Vec<SkipRange> {
36 if patterns.is_empty() {
37 return Vec::new();
38 }
39
40 let mut ranges = Vec::new();
41
42 for pattern in patterns {
43 for regex_match in pattern.find_iter(text) {
44 ranges.push(SkipRange {
45 start_byte: regex_match.start(),
46 end_byte: regex_match.end(),
47 });
48 }
49 }
50
51 ranges.sort_by_key(|r| r.start_byte);
52 merge_overlapping_ranges(ranges)
53}
54
55fn merge_overlapping_ranges(ranges: Vec<SkipRange>) -> Vec<SkipRange> {
57 if ranges.is_empty() {
58 return ranges;
59 }
60
61 let mut merged = Vec::new();
62 let mut current = ranges[0];
63
64 for range in ranges.into_iter().skip(1) {
65 if range.start_byte <= current.end_byte {
66 current.end_byte = current.end_byte.max(range.end_byte);
67 } else {
68 merged.push(current);
69 current = range;
70 }
71 }
72 merged.push(current);
73 merged
74}
75
/// Extracts words (and their byte locations) from a piece of text while
/// ignoring words that fall inside skip ranges.
struct TextProcessor {
    /// Owned copy of the text being processed.
    text: String,
    /// Byte spans of `text` matched by the skip patterns given at construction.
    skip_ranges: Vec<SkipRange>,
}
81
82impl TextProcessor {
83 fn new(text: &str, skip_patterns: &[Regex]) -> Self {
84 let skip_ranges = find_skip_ranges(text, skip_patterns);
85 Self {
86 text: text.to_string(),
87 skip_ranges,
88 }
89 }
90
91 fn should_skip(&self, start_byte: usize, word_len: usize) -> bool {
92 is_within_skip_range(start_byte, start_byte + word_len, &self.skip_ranges)
93 }
94
95 fn process_words_with_check<F>(&self, mut check_function: F) -> Vec<WordLocation>
96 where
97 F: FnMut(&str) -> bool,
98 {
99 let estimated_words = (self.text.len() as f64 / 6.0).ceil() as usize;
101 let mut word_positions: HashMap<&str, Vec<TextRange>> =
102 HashMap::with_capacity(estimated_words);
103
104 for (offset, word) in self.text.split_word_bound_indices() {
105 if is_alphabetic(word) && !self.should_skip(offset, word.len()) {
106 self.collect_split_words(word, offset, &mut word_positions);
107 }
108 }
109
110 let mut result_locations: HashMap<String, Vec<TextRange>> = HashMap::new();
112 for (word_text, positions) in word_positions {
113 if !check_function(word_text) {
114 result_locations.insert(word_text.to_string(), positions);
115 }
116 }
117
118 result_locations
119 .into_iter()
120 .map(|(word, locations)| WordLocation::new(word, locations))
121 .collect()
122 }
123
124 fn extract_words(&self) -> Vec<WordLocation> {
125 self.process_words_with_check(|_| false)
127 }
128
129 fn collect_split_words<'a>(
130 &self,
131 word: &'a str,
132 offset: usize,
133 word_positions: &mut HashMap<&'a str, Vec<TextRange>>,
134 ) {
135 if !word.is_empty() {
136 let split = splitter::split(word);
137 for split_word in split {
138 if !is_numeric(split_word.word) {
139 let word_start_byte = offset + split_word.start_byte;
140 let location = TextRange {
141 start_byte: word_start_byte,
142 end_byte: word_start_byte + split_word.word.len(),
143 };
144 let word_text = split_word.word;
145 word_positions.entry(word_text).or_default().push(location);
146 }
147 }
148 }
149 }
150}
151
/// A borrowed word paired with a position.
///
/// NOTE(review): nothing in the visible part of this file constructs or reads
/// `WordRef`, so the unit and meaning of `position` cannot be confirmed here.
#[derive(Debug, Clone, PartialEq)]
pub struct WordRef<'a> {
    /// The word text, borrowed from its source.
    pub word: &'a str,
    /// Position pair for the word — presumably start/end or line/column; verify against callers.
    pub position: (u32, u32),
}
157
/// A word together with every byte range at which it was found.
#[derive(Debug, Clone, PartialEq)]
pub struct WordLocation {
    /// The word text.
    pub word: String,
    /// Byte ranges of each occurrence of `word`.
    pub locations: Vec<TextRange>,
}
163
164impl WordLocation {
165 pub fn new(word: String, locations: Vec<TextRange>) -> Self {
166 Self { word, locations }
167 }
168}
169
170pub fn find_locations(
171 text: &str,
172 language: LanguageType,
173 check_function: impl Fn(&str) -> bool,
174 skip_patterns: &[Regex],
175) -> Vec<WordLocation> {
176 match language {
177 LanguageType::Text => {
178 let processor = TextProcessor::new(text, skip_patterns);
179 processor.process_words_with_check(|word| check_function(word))
180 }
181 _ => find_locations_code(text, language, |word| check_function(word), skip_patterns),
182 }
183}
184
185fn find_locations_code(
186 text: &str,
187 language: LanguageType,
188 check_function: impl Fn(&str) -> bool,
189 skip_patterns: &[Regex],
190) -> Vec<WordLocation> {
191 let language_setting =
192 get_language_setting(language).expect("This _should_ never happen. Famous last words.");
193 let mut parser = Parser::new();
194 let language = language_setting.language().unwrap();
195 parser.set_language(&language).unwrap();
196
197 let tree = parser.parse(text, None).unwrap();
198 let root_node = tree.root_node();
199
200 let query = Query::new(&language, language_setting.query).unwrap();
201 let mut cursor = QueryCursor::new();
202 let mut word_locations: HashMap<String, HashSet<TextRange>> = HashMap::new();
203 let provider = text.as_bytes();
204 let mut matches_query = cursor.matches(&query, root_node, provider);
205
206 let all_skip_ranges = find_skip_ranges(text, skip_patterns);
208
209 while let Some(match_) = matches_query.next() {
210 for capture in match_.captures {
211 let node = capture.node;
212 let node_start_byte = node.start_byte();
213
214 let node_text = node.utf8_text(provider).unwrap();
215 let processor = TextProcessor::new(node_text, &[]);
216 let words = processor.extract_words();
217
218 for word_pos in words {
220 if !check_function(&word_pos.word) {
221 for range in word_pos.locations {
222 let global_start = range.start_byte + node_start_byte;
223 let global_end = range.end_byte + node_start_byte;
224
225 if is_within_skip_range(global_start, global_end, &all_skip_ranges) {
227 continue;
228 }
229
230 let location = TextRange {
231 start_byte: global_start,
232 end_byte: global_end,
233 };
234 if let Some(existing_result) = word_locations.get_mut(&word_pos.word) {
235 #[cfg(debug_assertions)]
236 {
237 let added = existing_result.insert(location);
238 if !added {
239 let word = word_pos.word.clone();
240 panic!(
241 "Two of the same locations found. Make a better query. Word: {word}, Location: {location:?}"
242 )
243 }
244 }
245 } else {
246 let mut set = HashSet::new();
247 set.insert(location);
248 word_locations.insert(word_pos.word.clone(), set);
249 }
250 }
251 }
252 }
253 }
254 }
255
256 word_locations
257 .keys()
258 .map(|word| WordLocation {
259 word: word.clone(),
260 locations: word_locations
261 .get(word)
262 .cloned()
263 .unwrap_or_default()
264 .into_iter()
265 .collect(),
266 })
267 .collect()
268}
269
/// Returns `true` if `s` contains at least one numeric character.
///
/// Note that this is a "contains" check, not an "all characters" check; an
/// empty string yields `false`.
fn is_numeric(s: &str) -> bool {
    s.contains(char::is_numeric)
}
273
/// Returns `true` if `c` contains at least one alphabetic character.
///
/// Note that this is a "contains" check, not an "all characters" check; an
/// empty string yields `false`.
fn is_alphabetic(c: &str) -> bool {
    c.contains(char::is_alphabetic)
}
277
/// Returns the part of `text` between two UTF-16 code-unit offsets.
///
/// `start_utf16` is inclusive and `end_utf16` is exclusive; both count UTF-16
/// code units, not bytes or `char`s. Offsets past the end of the text are
/// clamped, and an inverted range (`end_utf16 < start_utf16`) yields an empty
/// string. If a boundary falls in the middle of a surrogate pair, the
/// orphaned half is replaced with U+FFFD.
pub fn get_word_from_string(start_utf16: usize, end_utf16: usize, text: &str) -> String {
    // Saturate so an inverted range cannot underflow (a panic in debug
    // builds, a wrapped huge length in release builds).
    let length = end_utf16.saturating_sub(start_utf16);
    let utf16_slice: Vec<u16> = text
        .encode_utf16()
        .skip(start_utf16)
        .take(length)
        .collect();
    String::from_utf16_lossy(&utf16_slice)
}
287
#[cfg(test)]
mod parser_tests {
    use super::*;

    /// Plain-text mode splits camelCase and snake_case words into their parts.
    #[test]
    fn test_spell_checking() {
        let text = "HelloWorld calc_wrld";
        let results = find_locations(text, LanguageType::Text, |_| false, &[]);
        println!("{results:?}");
        assert_eq!(results.len(), 4);
    }

    /// Every extracted word must sit at one of the expected byte ranges.
    #[test]
    fn test_get_words_from_text() {
        // The expected byte offsets below depend on each line of this raw
        // string being indented by exactly 12 spaces.
        let text = r#"
            HelloWorld calc_wrld
            I'm a contraction, don't ignore me
            this is a 3rd line.
            "#;
        let expected = vec![
            ("Hello", (13, 18)),
            ("World", (18, 23)),
            ("calc", (24, 28)),
            ("wrld", (29, 33)),
            ("I'm", (46, 49)),
            ("a", (50, 51)),
            ("contraction", (52, 63)),
            ("don't", (65, 70)),
            ("ignore", (71, 77)),
            ("me", (78, 80)),
            ("this", (93, 97)),
            ("is", (98, 100)),
            ("a", (101, 102)),
            ("rd", (104, 106)),
            ("line", (107, 111)),
        ];
        let processor = TextProcessor::new(text, &[]);
        let words = processor.extract_words();
        println!("{words:?}");
        // Guard against the loop below passing vacuously.
        assert!(!words.is_empty(), "Expected at least one word to be extracted");
        for word in words {
            let loc = word.locations.first().unwrap();
            let pos = (loc.start_byte, loc.end_byte);
            assert!(
                expected.contains(&(word.word.as_str(), pos)),
                "Expected word '{}' to be at position {:?}",
                word.word,
                pos
            );
        }
    }

    /// Contractions stay whole; a trailing apostrophe is not part of a word.
    #[test]
    fn test_contraction() {
        let text = "I'm a contraction, wouldn't you agree'?";
        let processor = TextProcessor::new(text, &[]);
        let words = processor.extract_words();
        println!("{words:?}");
        let expected = ["I'm", "a", "contraction", "wouldn't", "you", "agree"];
        // Guard against the loop below passing vacuously.
        assert!(!words.is_empty(), "Expected at least one word to be extracted");
        for word in words {
            assert!(expected.contains(&word.word.as_str()));
        }
    }

    /// Offsets passed to `get_word_from_string` count UTF-16 code units.
    #[test]
    fn test_get_word_from_string() {
        let text = "Hello World";
        assert_eq!(get_word_from_string(0, 5, text), "Hello");
        assert_eq!(get_word_from_string(6, 11, text), "World");

        assert_eq!(get_word_from_string(2, 5, text), "llo");

        // Each of these characters is a single UTF-16 code unit.
        let unicode_text = "こんにちは世界";
        assert_eq!(get_word_from_string(0, 5, unicode_text), "こんにちは");
        assert_eq!(get_word_from_string(5, 7, unicode_text), "世界");

        // Family emoji: four 2-unit emoji joined by three zero-width joiners,
        // 11 UTF-16 code units in total. The joiners are written as escapes
        // because they are invisible and easily lost when the file is edited.
        let emoji_text = "Hello 👨\u{200D}👩\u{200D}👧\u{200D}👦 World";
        assert_eq!(
            get_word_from_string(6, 17, emoji_text),
            "👨\u{200D}👩\u{200D}👧\u{200D}👦"
        );
    }

    /// Locations are byte offsets, so a multi-byte character before a word
    /// shifts the word's start by its encoded length.
    #[test]
    fn test_unicode_character_handling() {
        crate::logging::init_test_logging();
        let text = "©<div>badword</div>";
        let processor = TextProcessor::new(text, &[]);
        let words = processor.extract_words();
        println!("{words:?}");

        let badword = words
            .iter()
            .find(|word| word.word == "badword")
            .expect("Word 'badword' not found in the text");
        let location = badword.locations.first().unwrap();

        // "©" is two bytes and "<div>" is five, so the word starts at byte 7.
        assert_eq!(
            location.start_byte, 7,
            "Expected 'badword' to start at byte offset 7"
        );
        assert_eq!(
            location.end_byte, 14,
            "Expected 'badword' to end at byte offset 14"
        );
    }
}