1use crate::checker::WordCandidate;
2use crate::queries::{LANGUAGE_SETTINGS, LanguageType, get_language_setting};
3use crate::splitter;
4use regex::Regex;
5use std::collections::{HashMap, HashSet};
6use std::str::FromStr;
7use std::sync::{LazyLock, Mutex};
8use streaming_iterator::StreamingIterator;
9use tree_sitter::{Parser, Query, QueryCursor};
10use unicode_segmentation::UnicodeSegmentation;
11
12static PARSER_CACHE: LazyLock<Mutex<HashMap<LanguageType, Parser>>> =
16 LazyLock::new(|| Mutex::new(HashMap::new()));
17
18struct CompiledQuery {
20 query: Query,
21 capture_names: Vec<String>,
22}
23
24static COMPILED_QUERIES: LazyLock<HashMap<LanguageType, CompiledQuery>> = LazyLock::new(|| {
29 let mut map = HashMap::new();
30 for setting in LANGUAGE_SETTINGS {
31 let Some(lang) = setting.language() else {
32 continue;
33 };
34 if setting.query.is_empty() {
35 continue;
36 }
37 let query = Query::new(&lang, setting.query)
38 .unwrap_or_else(|e| panic!("Failed to compile query for {:?}: {e}", setting.type_));
39 let capture_names = query
40 .capture_names()
41 .iter()
42 .map(|s| s.to_string())
43 .collect();
44 map.insert(
45 setting.type_,
46 CompiledQuery {
47 query,
48 capture_names,
49 },
50 );
51 }
52 map
53});
54
/// A span of bytes inside a document, stored as start and end byte offsets.
///
/// The derived `PartialOrd`/`Ord` compare `start_byte` first and then `end_byte`
/// (field declaration order), so sorting ranges orders them by position.
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct TextRange {
    /// Byte offset where the span begins.
    pub start_byte: usize,
    /// Byte offset where the span ends (presumably exclusive — confirm with producers).
    pub end_byte: usize,
}
62
/// A byte span of the document that matched one of the caller's skip patterns.
///
/// Words lying wholly inside such a span are not reported. The span is half-open
/// (`start_byte..end_byte`), mirroring the regex match offsets it is built from.
#[derive(Clone, Copy, Debug, PartialEq)]
struct SkipRange {
    /// First byte of the skipped span.
    start_byte: usize,
    /// One past the last byte of the skipped span.
    end_byte: usize,
}
68
69fn is_within_skip_range(start: usize, end: usize, skip_ranges: &[SkipRange]) -> bool {
70 skip_ranges
71 .iter()
72 .any(|r| start >= r.start_byte && end <= r.end_byte)
73}
74
75fn find_skip_ranges(text: &str, patterns: &[Regex]) -> Vec<SkipRange> {
76 if patterns.is_empty() {
77 return Vec::new();
78 }
79 let mut ranges = Vec::new();
80 for pattern in patterns {
81 for regex_match in pattern.find_iter(text) {
82 ranges.push(SkipRange {
83 start_byte: regex_match.start(),
84 end_byte: regex_match.end(),
85 });
86 }
87 }
88 ranges.sort_by_key(|r| r.start_byte);
89 merge_overlapping_ranges(ranges)
90}
91
92fn merge_overlapping_ranges(ranges: Vec<SkipRange>) -> Vec<SkipRange> {
93 if ranges.is_empty() {
94 return ranges;
95 }
96 let mut merged = Vec::new();
97 let mut current = ranges[0];
98 for range in ranges.into_iter().skip(1) {
99 if range.start_byte <= current.end_byte {
100 current.end_byte = current.end_byte.max(range.end_byte);
101 } else {
102 merged.push(current);
103 current = range;
104 }
105 }
106 merged.push(current);
107 merged
108}
109
110#[derive(Debug, Clone, PartialEq)]
111pub struct WordLocation {
112 pub word: String,
113 pub locations: Vec<TextRange>,
114}
115
116impl WordLocation {
117 pub fn new(word: String, locations: Vec<TextRange>) -> Self {
118 Self { word, locations }
119 }
120}
121
122pub fn extract_all_words<'a>(
132 document_text: &'a str,
133 language: LanguageType,
134 tag_filter: &dyn Fn(&str) -> bool,
135 skip_patterns: &[Regex],
136) -> (Vec<WordCandidate<'a>>, HashSet<LanguageType>) {
137 let skip_ranges = find_skip_ranges(document_text, skip_patterns);
138 let mut result = ExtractionResult {
139 candidates: Vec::new(),
140 languages: HashSet::from([language]),
141 };
142
143 extract_recursive(
144 document_text,
145 0,
146 document_text.len(),
147 language,
148 tag_filter,
149 &skip_ranges,
150 &mut result,
151 );
152
153 (result.candidates, result.languages)
154}
155
156struct ExtractionResult<'a> {
158 candidates: Vec<WordCandidate<'a>>,
159 languages: HashSet<LanguageType>,
160}
161
162fn extract_recursive<'a>(
172 document_text: &'a str,
173 start_byte: usize,
174 end_byte: usize,
175 language: LanguageType,
176 tag_filter: &dyn Fn(&str) -> bool,
177 skip_ranges: &[SkipRange],
178 result: &mut ExtractionResult<'a>,
179) {
180 let language_setting = match get_language_setting(language) {
181 Some(s) => s,
182 None => {
183 let text = &document_text[start_byte..end_byte];
185 extract_words_from_text(text, start_byte, skip_ranges, &mut result.candidates);
186 return;
187 }
188 };
189
190 let region_text = &document_text[start_byte..end_byte];
191
192 let tree = {
194 let mut cache = PARSER_CACHE.lock().unwrap();
195 let parser = cache.entry(language).or_insert_with(|| {
196 let mut parser = Parser::new();
197 let lang = language_setting.language().unwrap();
198 parser.set_language(&lang).unwrap();
199 parser
200 });
201 parser.parse(region_text, None).unwrap()
202 };
203
204 let root_node = tree.root_node();
205 let compiled = COMPILED_QUERIES
206 .get(&language)
207 .expect("Language has a LanguageSetting but no compiled query; this should not happen");
208 let mut cursor = QueryCursor::new();
209 let provider = region_text.as_bytes();
210 let mut matches_query = cursor.matches(&compiled.query, root_node, provider);
211
212 while let Some(match_) = matches_query.next() {
213 let mut injection_content: Option<tree_sitter::Node> = None;
215 let mut injection_language_text: Option<&str> = None;
216
217 for capture in match_.captures {
218 let tag = &compiled.capture_names[capture.index as usize];
219 if tag == "injection.content" {
220 injection_content = Some(capture.node);
221 } else if tag == "injection.language" {
222 injection_language_text = Some(capture.node.utf8_text(provider).unwrap_or(""));
223 }
224 }
225
226 if let Some(content_node) = injection_content {
228 if let Some(lang_text) = injection_language_text {
229 let lowered = lang_text.trim().to_lowercase();
230 let child_lang = LanguageType::from_str(&lowered);
231 if let Ok(child_lang) = child_lang
232 && child_lang != LanguageType::Text
233 {
234 let child_start = content_node.start_byte() + start_byte;
235 let child_end = content_node.end_byte() + start_byte;
236 if child_start < child_end {
237 result.languages.insert(child_lang);
238 extract_recursive(
239 document_text,
240 child_start,
241 child_end,
242 child_lang,
243 tag_filter,
244 skip_ranges,
245 result,
246 );
247 }
248 }
249 }
250 continue;
251 }
252
253 for capture in match_.captures {
255 let tag = &compiled.capture_names[capture.index as usize];
256 let node = capture.node;
257 let node_start = node.start_byte() + start_byte;
258 let node_end = node.end_byte() + start_byte;
259
260 if node_start >= node_end {
261 continue;
262 }
263
264 if tag == "language" || tag == "injection.language" {
265 continue;
266 }
267
268 if let Some(lang_name) = tag.strip_prefix("injection.") {
269 if let Ok(child_lang) = LanguageType::from_str(lang_name)
271 && child_lang != LanguageType::Text
272 {
273 result.languages.insert(child_lang);
274 extract_recursive(
275 document_text,
276 node_start,
277 node_end,
278 child_lang,
279 tag_filter,
280 skip_ranges,
281 result,
282 );
283 }
284 continue;
285 }
286
287 if !tag_filter(tag) {
289 continue;
290 }
291
292 let node_text = node.utf8_text(provider).unwrap();
293 extract_words_from_text(node_text, node_start, skip_ranges, &mut result.candidates);
294 }
295 }
296}
297
298fn extract_words_from_text<'a>(
303 text: &'a str,
304 base_offset: usize,
305 skip_ranges: &[SkipRange],
306 candidates: &mut Vec<WordCandidate<'a>>,
307) {
308 let mut split_buf = Vec::new();
309 for (offset, word) in text.split_word_bound_indices() {
310 if !is_alphabetic(word) {
311 continue;
312 }
313 let global_offset = base_offset + offset;
314 if is_within_skip_range(global_offset, global_offset + word.len(), skip_ranges) {
315 continue;
316 }
317 splitter::split_into(word, &mut split_buf);
318 for split_word in &split_buf {
319 if is_numeric(split_word.word) {
320 continue;
321 }
322 let word_start = global_offset + split_word.start_byte;
323 let word_end = word_start + split_word.word.len();
324 if is_within_skip_range(word_start, word_end, skip_ranges) {
325 continue;
326 }
327 candidates.push(WordCandidate {
328 word: split_word.word,
329 start_byte: word_start,
330 end_byte: word_end,
331 });
332 }
333 }
334}
335
/// Returns `true` if `s` contains *at least one* numeric character (per
/// `char::is_numeric`), not only when it is entirely numeric. Extraction uses this to
/// drop any word piece that contains a digit.
fn is_numeric(s: &str) -> bool {
    s.chars().any(char::is_numeric)
}
339
/// Returns `true` if `c` contains *at least one* alphabetic character (per
/// `char::is_alphabetic`). Word-boundary segments failing this are punctuation,
/// whitespace or numbers and are not spell-checked.
fn is_alphabetic(c: &str) -> bool {
    c.chars().any(char::is_alphabetic)
}
343
/// Returns the substring of `text` between two UTF-16 code-unit offsets.
///
/// `start_utf16..end_utf16` is half-open and measured in UTF-16 code units, not bytes
/// or chars. These are presumably editor/LSP positions; confirm at call sites.
///
/// * Offsets past the end of `text` are clamped.
/// * An inverted range (`end_utf16 < start_utf16`) yields an empty string. Previously it
///   underflowed: a panic in debug builds, and the whole tail of `text` in release.
/// * A range that splits a surrogate pair yields U+FFFD for the orphaned half.
pub fn get_word_from_string(start_utf16: usize, end_utf16: usize, text: &str) -> String {
    let len = end_utf16.saturating_sub(start_utf16);
    let utf16_slice: Vec<u16> = text.encode_utf16().skip(start_utf16).take(len).collect();
    String::from_utf16_lossy(&utf16_slice)
}
353
#[cfg(test)]
mod tests {
    use super::*;

    /// Collects the text of each extracted candidate, in extraction order.
    fn word_strings<'a>(words: &[WordCandidate<'a>]) -> Vec<&'a str> {
        words.iter().map(|w| w.word).collect()
    }

    #[test]
    fn test_extract_words_plain_text() {
        let text = "HelloWorld calc_wrld";
        let (words, langs) = extract_all_words(text, LanguageType::Text, &|_| true, &[]);
        let found = word_strings(&words);
        for expected in ["Hello", "World", "calc", "wrld"] {
            assert!(found.contains(&expected), "Expected word '{expected}' not found");
        }
        assert_eq!(words.len(), 4);
        assert!(langs.contains(&LanguageType::Text));
    }

    #[test]
    fn test_extract_words_contraction() {
        let text = "I'm a contraction, wouldn't you agree'?";
        let (words, _) = extract_all_words(text, LanguageType::Text, &|_| true, &[]);
        let found = word_strings(&words);
        let expected = ["I'm", "a", "contraction", "wouldn't", "you", "agree"];
        for e in &expected {
            assert!(found.contains(e), "Expected word '{e}' not found");
        }
    }

    #[test]
    fn test_extract_words_code() {
        let text = "// a comment\nfn main() {}";
        let (words, langs) = extract_all_words(text, LanguageType::Rust, &|_| true, &[]);
        assert!(!words.is_empty());
        let found = word_strings(&words);
        assert!(
            found.contains(&"comment"),
            "Should find 'comment' in Rust comment"
        );
        assert!(langs.contains(&LanguageType::Rust));
    }

    #[test]
    fn test_extract_words_tag_filter() {
        let text = "// comment\nlet x = \"string value\";";
        let (words, _) = extract_all_words(
            text,
            LanguageType::Rust,
            &|tag| tag.starts_with("comment"),
            &[],
        );
        let found = word_strings(&words);
        assert!(found.contains(&"comment"));
        assert!(!found.contains(&"string"));
        assert!(!found.contains(&"value"));
    }

    #[test]
    fn test_extract_words_with_skip_patterns() {
        let text = "check https://example.com this";
        let url_pattern = Regex::new(r"https?://[^\s]+").unwrap();
        let (words, _) = extract_all_words(text, LanguageType::Text, &|_| true, &[url_pattern]);
        let found = word_strings(&words);
        assert!(found.contains(&"check"));
        assert!(found.contains(&"this"));
        assert!(!found.contains(&"https"));
        assert!(!found.contains(&"example"));
    }

    #[test]
    fn test_extract_words_code_duplicates() {
        let text = "// wrld foo wrld";
        let (words, _) = extract_all_words(text, LanguageType::Rust, &|_| true, &[]);
        let wrld_count = words.iter().filter(|w| w.word == "wrld").count();
        assert_eq!(wrld_count, 2, "Expected two occurrences of 'wrld'");
    }

    #[test]
    fn test_markdown_injection_discovers_languages() {
        let text =
            "# Hello\n\nSome text.\n\n```python\ndef foo(): pass\n```\n\n```bash\necho hi\n```\n";
        let (_, langs) = extract_all_words(text, LanguageType::Markdown, &|_| true, &[]);
        assert!(langs.contains(&LanguageType::Markdown));
        assert!(langs.contains(&LanguageType::Python));
        assert!(langs.contains(&LanguageType::Bash));
    }

    #[test]
    fn test_markdown_injection_extracts_code_words() {
        let text = "# Hello\n\n```python\ndef some_functin(): pass\n```\n";
        let (words, _) = extract_all_words(text, LanguageType::Markdown, &|_| true, &[]);
        let found = word_strings(&words);
        assert!(found.contains(&"functin"));
        assert!(found.contains(&"Hello"));
    }

    #[test]
    fn test_markdown_unknown_language_skipped() {
        let text = "# Hello\n\n```unknownlang\nbadwwword\n```\n";
        let (words, _) = extract_all_words(text, LanguageType::Markdown, &|_| true, &[]);
        let found = word_strings(&words);
        assert!(!found.contains(&"badwwword"));
    }

    #[test]
    fn test_markdown_html_block_injection() {
        let text = "# Hello\n\n<div>\n <p>A misspeled word</p>\n</div>\n\nMore text.\n";
        let (words, langs) = extract_all_words(text, LanguageType::Markdown, &|_| true, &[]);
        let found = word_strings(&words);
        assert!(langs.contains(&LanguageType::HTML));
        assert!(found.contains(&"misspeled"));
        assert!(!found.contains(&"div"));
    }

    #[test]
    fn test_get_word_from_string() {
        let text = "Hello World";
        assert_eq!(get_word_from_string(0, 5, text), "Hello");
        assert_eq!(get_word_from_string(6, 11, text), "World");

        let unicode_text = "こんにちは世界";
        assert_eq!(get_word_from_string(0, 5, unicode_text), "こんにちは");
        assert_eq!(get_word_from_string(5, 7, unicode_text), "世界");

        // Family emoji: four astral-plane code points (2 UTF-16 units each) joined by
        // three zero-width joiners (1 unit each) = 11 units. Written with escapes
        // because the invisible ZWJs are easily lost when the literal is copied.
        let family = "\u{1F468}\u{200D}\u{1F469}\u{200D}\u{1F467}\u{200D}\u{1F466}";
        let emoji_text =
            "Hello \u{1F468}\u{200D}\u{1F469}\u{200D}\u{1F467}\u{200D}\u{1F466} World";
        assert_eq!(get_word_from_string(6, 17, emoji_text), family);
    }

    #[test]
    fn test_unicode_character_handling() {
        crate::logging::init_test_logging();
        let text = "©<div>badword</div>";
        let (words, _) = extract_all_words(text, LanguageType::Text, &|_| true, &[]);
        let bad_word = words.iter().find(|w| w.word == "badword");
        assert!(bad_word.is_some(), "Expected 'badword' to be found");
        let bw = bad_word.unwrap();
        assert_eq!(bw.start_byte, 7);
        assert_eq!(bw.end_byte, 14);
    }

    #[test]
    fn test_merge_overlapping_ranges() {
        let range = |start_byte, end_byte| SkipRange {
            start_byte,
            end_byte,
        };
        assert!(merge_overlapping_ranges(Vec::new()).is_empty());
        // Overlapping (0..3, 2..5) and touching (..5, 5..) ranges merge; 9..10 stays apart.
        assert_eq!(
            merge_overlapping_ranges(vec![range(0, 3), range(2, 5), range(5, 7), range(9, 10)]),
            vec![range(0, 7), range(9, 10)]
        );
        // A range fully inside the previous one does not shrink it.
        assert_eq!(
            merge_overlapping_ranges(vec![range(0, 10), range(2, 4)]),
            vec![range(0, 10)]
        );
    }

    #[test]
    fn test_is_within_skip_range() {
        // Sorted and disjoint, as produced by `find_skip_ranges`.
        let ranges = [
            SkipRange {
                start_byte: 2,
                end_byte: 5,
            },
            SkipRange {
                start_byte: 10,
                end_byte: 20,
            },
        ];
        assert!(is_within_skip_range(2, 5, &ranges));
        assert!(is_within_skip_range(3, 4, &ranges));
        assert!(is_within_skip_range(10, 20, &ranges));
        assert!(!is_within_skip_range(4, 6, &ranges));
        assert!(!is_within_skip_range(0, 1, &ranges));
        assert!(!is_within_skip_range(6, 9, &ranges));
        assert!(!is_within_skip_range(19, 21, &ranges));
        assert!(!is_within_skip_range(0, 1, &[]));
    }

    #[test]
    fn test_find_skip_ranges_merges_patterns() {
        let text = "see https://example.com/x now";
        let url = Regex::new(r"https?://\S+").unwrap();
        let host = Regex::new(r"example\.com").unwrap();
        // The host match (12..23) lies inside the URL match (4..25) and is merged into it.
        assert_eq!(
            find_skip_ranges(text, &[url, host]),
            vec![SkipRange {
                start_byte: 4,
                end_byte: 25,
            }]
        );
        assert!(find_skip_ranges(text, &[]).is_empty());
    }
}
492}