1use crate::splitter::{self};
2
3use crate::queries::{LanguageType, get_language_setting};
4use regex::Regex;
5use std::collections::{HashMap, HashSet};
6use std::sync::{LazyLock, Mutex};
7use streaming_iterator::StreamingIterator;
8use tree_sitter::{Parser, Query, QueryCursor};
9use unicode_segmentation::UnicodeSegmentation;
10
11static PARSER_CACHE: LazyLock<Mutex<HashMap<LanguageType, Parser>>> =
15 LazyLock::new(|| Mutex::new(HashMap::new()));
16
/// A span of bytes inside a text buffer.
///
/// Ranges built in this module are half-open: `end_byte` is one past the
/// last byte of the word. The derived ordering compares `start_byte` first
/// and then `end_byte`, following the field declaration order.
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct TextRange {
    /// Offset of the first byte of the span.
    pub start_byte: usize,
    /// Offset one past the last byte of the span.
    pub end_byte: usize,
}
24
/// A byte span of the input that matched one of the caller's skip patterns.
///
/// Words lying entirely inside such a span are left out of the results.
#[derive(Clone, Copy, Debug, PartialEq)]
struct SkipRange {
    /// Offset of the first matched byte.
    start_byte: usize,
    /// Offset one past the last matched byte.
    end_byte: usize,
}
32
33fn is_within_skip_range(start: usize, end: usize, skip_ranges: &[SkipRange]) -> bool {
35 skip_ranges
36 .iter()
37 .any(|r| start >= r.start_byte && end <= r.end_byte)
38}
39
40fn find_skip_ranges(text: &str, patterns: &[Regex]) -> Vec<SkipRange> {
42 if patterns.is_empty() {
43 return Vec::new();
44 }
45
46 let mut ranges = Vec::new();
47
48 for pattern in patterns {
49 for regex_match in pattern.find_iter(text) {
50 ranges.push(SkipRange {
51 start_byte: regex_match.start(),
52 end_byte: regex_match.end(),
53 });
54 }
55 }
56
57 ranges.sort_by_key(|r| r.start_byte);
58 merge_overlapping_ranges(ranges)
59}
60
61fn merge_overlapping_ranges(ranges: Vec<SkipRange>) -> Vec<SkipRange> {
63 if ranges.is_empty() {
64 return ranges;
65 }
66
67 let mut merged = Vec::new();
68 let mut current = ranges[0];
69
70 for range in ranges.into_iter().skip(1) {
71 if range.start_byte <= current.end_byte {
72 current.end_byte = current.end_byte.max(range.end_byte);
73 } else {
74 merged.push(current);
75 current = range;
76 }
77 }
78 merged.push(current);
79 merged
80}
81
82struct TextProcessor {
84 text: String,
85 skip_ranges: Vec<SkipRange>,
86}
87
88impl TextProcessor {
89 fn new(text: &str, skip_patterns: &[Regex]) -> Self {
90 let skip_ranges = find_skip_ranges(text, skip_patterns);
91 Self {
92 text: text.to_string(),
93 skip_ranges,
94 }
95 }
96
97 fn should_skip(&self, start_byte: usize, word_len: usize) -> bool {
98 is_within_skip_range(start_byte, start_byte + word_len, &self.skip_ranges)
99 }
100
101 fn process_words_with_check<F>(&self, mut check_function: F) -> Vec<WordLocation>
102 where
103 F: FnMut(&str) -> bool,
104 {
105 let estimated_words = (self.text.len() as f64 / 6.0).ceil() as usize;
107 let mut word_positions: HashMap<&str, Vec<TextRange>> =
108 HashMap::with_capacity(estimated_words);
109
110 for (offset, word) in self.text.split_word_bound_indices() {
111 if is_alphabetic(word) && !self.should_skip(offset, word.len()) {
112 self.collect_split_words(word, offset, &mut word_positions);
113 }
114 }
115
116 let mut result_locations: HashMap<String, Vec<TextRange>> = HashMap::new();
118 for (word_text, positions) in word_positions {
119 if !check_function(word_text) {
120 result_locations.insert(word_text.to_string(), positions);
121 }
122 }
123
124 result_locations
125 .into_iter()
126 .map(|(word, locations)| WordLocation::new(word, locations))
127 .collect()
128 }
129
130 fn extract_words(&self) -> Vec<WordLocation> {
131 self.process_words_with_check(|_| false)
133 }
134
135 fn collect_split_words<'a>(
136 &self,
137 word: &'a str,
138 offset: usize,
139 word_positions: &mut HashMap<&'a str, Vec<TextRange>>,
140 ) {
141 if !word.is_empty() {
142 let split = splitter::split(word);
143 for split_word in split {
144 if !is_numeric(split_word.word) {
145 let word_start_byte = offset + split_word.start_byte;
146 let location = TextRange {
147 start_byte: word_start_byte,
148 end_byte: word_start_byte + split_word.word.len(),
149 };
150 let word_text = split_word.word;
151 word_positions.entry(word_text).or_default().push(location);
152 }
153 }
154 }
155 }
156}
157
/// A borrowed word paired with a position.
///
/// NOTE(review): this file does not establish what `position` means (line
/// and column? UTF-16 offsets?); confirm against the code that uses it.
#[derive(Clone, Debug, PartialEq)]
pub struct WordRef<'a> {
    /// The word text, borrowed.
    pub word: &'a str,
    /// The word's position as a pair of `u32` values; see the note above.
    pub position: (u32, u32),
}
163
164#[derive(Debug, Clone, PartialEq)]
165pub struct WordLocation {
166 pub word: String,
167 pub locations: Vec<TextRange>,
168}
169
170impl WordLocation {
171 pub fn new(word: String, locations: Vec<TextRange>) -> Self {
172 Self { word, locations }
173 }
174}
175
176pub fn find_locations(
177 text: &str,
178 language: LanguageType,
179 check_function: impl Fn(&str) -> bool,
180 skip_patterns: &[Regex],
181) -> Vec<WordLocation> {
182 match language {
183 LanguageType::Text => {
184 let processor = TextProcessor::new(text, skip_patterns);
185 processor.process_words_with_check(|word| check_function(word))
186 }
187 _ => find_locations_code(text, language, |word| check_function(word), skip_patterns),
188 }
189}
190
191fn find_locations_code(
192 text: &str,
193 language: LanguageType,
194 check_function: impl Fn(&str) -> bool,
195 skip_patterns: &[Regex],
196) -> Vec<WordLocation> {
197 let language_setting =
198 get_language_setting(language).expect("This _should_ never happen. Famous last words.");
199
200 let tree = {
203 let mut cache = PARSER_CACHE.lock().unwrap();
204 let parser = cache.entry(language).or_insert_with(|| {
205 let mut parser = Parser::new();
206 let lang = language_setting.language().unwrap();
207 parser.set_language(&lang).unwrap();
208 parser
209 });
210 parser.parse(text, None).unwrap()
211 };
212
213 let root_node = tree.root_node();
214 let lang = language_setting.language().unwrap();
215 let query = Query::new(&lang, language_setting.query).unwrap();
216 let mut cursor = QueryCursor::new();
217 let mut word_locations: HashMap<String, HashSet<TextRange>> = HashMap::new();
218 let provider = text.as_bytes();
219 let mut matches_query = cursor.matches(&query, root_node, provider);
220
221 let all_skip_ranges = find_skip_ranges(text, skip_patterns);
223
224 while let Some(match_) = matches_query.next() {
225 for capture in match_.captures {
226 let node = capture.node;
227 let node_start_byte = node.start_byte();
228
229 let node_text = node.utf8_text(provider).unwrap();
230 let processor = TextProcessor::new(node_text, &[]);
231 let words = processor.extract_words();
232
233 for word_pos in words {
235 if !check_function(&word_pos.word) {
236 for range in word_pos.locations {
237 let global_start = range.start_byte + node_start_byte;
238 let global_end = range.end_byte + node_start_byte;
239
240 if is_within_skip_range(global_start, global_end, &all_skip_ranges) {
242 continue;
243 }
244
245 let location = TextRange {
246 start_byte: global_start,
247 end_byte: global_end,
248 };
249 if let Some(existing_result) = word_locations.get_mut(&word_pos.word) {
250 let added = existing_result.insert(location);
251 debug_assert!(
252 added,
253 "Two of the same locations found. Make a better query. Word: {}, Location: {:?}",
254 word_pos.word, location
255 );
256 } else {
257 let mut set = HashSet::new();
258 set.insert(location);
259 word_locations.insert(word_pos.word.clone(), set);
260 }
261 }
262 }
263 }
264 }
265 }
266
267 word_locations
268 .keys()
269 .map(|word| WordLocation {
270 word: word.clone(),
271 locations: word_locations
272 .get(word)
273 .cloned()
274 .unwrap_or_default()
275 .into_iter()
276 .collect(),
277 })
278 .collect()
279}
280
/// Returns `true` if `s` contains at least one numeric character.
///
/// Despite the name, this does not require `s` to be entirely numeric. A
/// word such as `abc1` counts as numeric and is dropped by callers.
fn is_numeric(s: &str) -> bool {
    s.chars().any(char::is_numeric)
}
284
/// Returns `true` if `c` contains at least one alphabetic character.
///
/// Used to discard word-boundary segments that are pure punctuation,
/// whitespace or digits.
fn is_alphabetic(c: &str) -> bool {
    c.chars().any(char::is_alphabetic)
}
288
/// Returns the part of `text` between UTF-16 code-unit offsets `start_utf16`
/// (inclusive) and `end_utf16` (exclusive).
///
/// - Offsets past the end of `text` are clamped.
/// - An inverted range (`end_utf16 < start_utf16`) yields an empty string.
///   Before, the subtraction panicked in debug builds and silently returned
///   the whole tail of `text` in release builds.
/// - A range that cuts through a surrogate pair yields U+FFFD for the
///   orphaned half.
pub fn get_word_from_string(start_utf16: usize, end_utf16: usize, text: &str) -> String {
    let unit_count = end_utf16.saturating_sub(start_utf16);
    let utf16_slice: Vec<u16> = text
        .encode_utf16()
        .skip(start_utf16)
        .take(unit_count)
        .collect();
    String::from_utf16_lossy(&utf16_slice)
}
298
#[cfg(test)]
mod parser_tests {
    use super::*;

    // Plain-text mode: compound tokens are split into sub-words, so
    // "HelloWorld calc_wrld" yields four distinct words.
    #[test]
    fn test_spell_checking() {
        let text = "HelloWorld calc_wrld";
        let results = find_locations(text, LanguageType::Text, |_| false, &[]);
        println!("{results:?}");
        assert_eq!(results.len(), 4);
    }

    // Checks extracted words and their byte offsets. The offsets include the
    // leading newline and the indentation inside the raw string literal, so
    // that whitespace must not be reformatted. `3rd` contributes only `rd`
    // because the numeric piece is dropped. Only each word's first location
    // is compared against `expected`.
    #[test]
    fn test_get_words_from_text() {
        let text = r#"
            HelloWorld calc_wrld
            I'm a contraction, don't ignore me
            this is a 3rd line.
        "#;
        let expected = vec![
            ("Hello", (13, 18)),
            ("World", (18, 23)),
            ("calc", (24, 28)),
            ("wrld", (29, 33)),
            ("I'm", (46, 49)),
            ("a", (50, 51)),
            ("contraction", (52, 63)),
            ("don't", (65, 70)),
            ("ignore", (71, 77)),
            ("me", (78, 80)),
            ("this", (93, 97)),
            ("is", (98, 100)),
            ("a", (101, 102)),
            ("rd", (104, 106)),
            ("line", (107, 111)),
        ];
        let processor = TextProcessor::new(text, &[]);
        let words = processor.extract_words();
        println!("{words:?}");
        for word in words {
            let loc = word.locations.first().unwrap();
            let pos = (loc.start_byte, loc.end_byte);
            assert!(
                expected.contains(&(word.word.as_str(), pos)),
                "Expected word '{}' to be at position {:?}",
                word.word,
                pos
            );
        }
    }

    // Contractions must stay whole. Every extracted word must be one of
    // `expected`; in particular the apostrophe after "agree" is not kept.
    #[test]
    fn test_contraction() {
        let text = "I'm a contraction, wouldn't you agree'?";
        let processor = TextProcessor::new(text, &[]);
        let words = processor.extract_words();
        println!("{words:?}");
        let expected = ["I'm", "a", "contraction", "wouldn't", "you", "agree"];
        for word in words {
            assert!(expected.contains(&word.word.as_str()));
        }
    }

    // Offsets are UTF-16 code units. Each CJK character here is one unit,
    // while the family emoji spans 11 units (6..17), as the last assertion
    // expects.
    #[test]
    fn test_get_word_from_string() {
        let text = "Hello World";
        assert_eq!(get_word_from_string(0, 5, text), "Hello");
        assert_eq!(get_word_from_string(6, 11, text), "World");

        assert_eq!(get_word_from_string(2, 5, text), "llo");

        let unicode_text = "こんにちは世界";
        assert_eq!(get_word_from_string(0, 5, unicode_text), "こんにちは");
        assert_eq!(get_word_from_string(5, 7, unicode_text), "世界");

        let emoji_text = "Hello 👨‍👩‍👧‍👦 World";
        assert_eq!(get_word_from_string(6, 17, emoji_text), "👨‍👩‍👧‍👦");
    }

    // Positions are byte offsets: "©" takes two bytes in UTF-8, so "badword"
    // starts at byte 7 (it is the 7th character counting from 0 only by
    // coincidence of "<div>" being 5 bytes).
    #[test]
    fn test_unicode_character_handling() {
        crate::logging::init_test_logging();
        let text = "©<div>badword</div>";
        let processor = TextProcessor::new(text, &[]);
        let words = processor.extract_words();
        println!("{words:?}");

        assert!(words.iter().any(|word| word.word == "badword"));

        if let Some(pos) = words.iter().find(|word| word.word == "badword") {
            let start_byte = pos.locations.first().unwrap().start_byte;
            let end_byte = pos.locations.first().unwrap().end_byte;
            assert_eq!(
                start_byte, 7,
                "Expected 'badword' to start at character position 7"
            );
            assert_eq!(end_byte, 14, "Expected 'badword' to be on end_byte 14");
        } else {
            panic!("Word 'badword' not found in the text");
        }
    }

    // Code mode (Rust): a word repeated in the same source must be reported
    // once, with both of its locations.
    #[test]
    fn test_duplicate_word_locations() {
        let text = "// wrld foo wrld";
        let results = find_locations(text, LanguageType::Rust, |_| false, &[]);
        let wrld = results.iter().find(|loc| loc.word == "wrld").unwrap();
        assert_eq!(
            wrld.locations.len(),
            2,
            "Expected two locations for repeated word 'wrld'"
        );
    }

}