1mod bibtex;
2mod forester;
3pub mod latex;
4mod org;
5mod query;
6mod rst;
7mod shared;
8mod sweave;
9mod tinylang;
10
11use anyhow::{Result, anyhow};
12use std::ops::Range;
13use std::path::Path;
14use tree_sitter::{Language, Parser};
15
16use crate::ignore_rules::{DirectiveRegion, IgnoreParser};
17
18use crate::sls::SchemaRegistry;
19
20pub struct ProseExtractor {
21 parser: Parser,
22 language: Language,
23}
24
25impl ProseExtractor {
26 pub fn new(language: Language) -> Result<Self> {
27 let mut parser = Parser::new();
28 parser.set_language(&language)?;
29 Ok(Self { parser, language })
30 }
31
32 pub fn extract(
33 &mut self,
34 text: &str,
35 lang_id: &str,
36 latex_extras: &latex::LatexExtras,
37 ) -> Result<Vec<ProseRange>> {
38 let tree = self
39 .parser
40 .parse(text, None)
41 .ok_or_else(|| anyhow!("Failed to parse text"))?;
42
43 let root = tree.root_node();
44
45 match lang_id {
46 "latex" => Ok(latex::extract(text, root, latex_extras)),
47 "sweave" => Ok(sweave::extract(text, root, latex_extras)),
48 "forester" => Ok(forester::extract(text, root)),
49 "tinylang" => Ok(tinylang::extract(text, root)),
50 "rst" => Ok(rst::extract(text, root)),
51 "bibtex" => Ok(bibtex::extract(text, root)),
52 "org" => Ok(org::extract(text, root)),
53 lang => query::extract(text, root, &self.language, lang),
54 }
55 }
56}
57
58pub fn extract_with_fallback(
64 text: &str,
65 lang_id: &str,
66 path: Option<&Path>,
67 schema_registry: Option<&SchemaRegistry>,
68 latex_extras: &latex::LatexExtras,
69) -> Result<Vec<ProseRange>> {
70 if let Some(ext) = path
71 .and_then(|value| value.extension())
72 .and_then(|value| value.to_str())
73 && crate::languages::builtin_language_for_extension(ext).is_none()
74 && let Some(schema) = schema_registry.and_then(|registry| registry.find_by_extension(ext))
75 {
76 return Ok(schema.extract(text));
77 }
78
79 let canonical_lang = crate::languages::resolve_language_id(lang_id);
80 let language = crate::languages::resolve_ts_language(canonical_lang);
81 let mut extractor = ProseExtractor::new(language)?;
82 let mut ranges = extractor.extract(text, canonical_lang, latex_extras)?;
83
84 let directives = IgnoreParser::parse_directives(text);
85 let resolved = IgnoreParser::resolve_all(text, &directives);
86 let type_regions: Vec<_> = resolved
87 .regions
88 .iter()
89 .filter(|r| r.options.doc_type.is_some())
90 .collect();
91 if !type_regions.is_empty() {
92 ranges = apply_type_overrides(text, ranges, &type_regions, latex_extras)?;
93 }
94
95 Ok(ranges)
96}
97
98fn apply_type_overrides(
105 text: &str,
106 base_ranges: Vec<ProseRange>,
107 type_regions: &[&DirectiveRegion],
108 latex_extras: &latex::LatexExtras,
109) -> Result<Vec<ProseRange>> {
110 let override_spans: Vec<&Range<usize>> = type_regions.iter().map(|r| &r.byte_range).collect();
111
112 let mut result: Vec<ProseRange> = base_ranges
114 .into_iter()
115 .filter(|r| {
116 !override_spans
117 .iter()
118 .any(|span| span.contains(&r.start_byte))
119 })
120 .collect();
121
122 for region in type_regions {
123 let doc_type = region.options.doc_type.as_deref().unwrap();
124 let canonical = crate::languages::resolve_language_id(doc_type);
125
126 if !crate::languages::SUPPORTED_LANGUAGE_IDS.contains(&canonical) {
127 eprintln!("lang-check: `type:{doc_type}` is not a supported language; skipping region");
128 continue;
129 }
130
131 let slice = &text[region.byte_range.clone()];
132 let ts_lang = crate::languages::resolve_ts_language(canonical);
133 let mut ext = ProseExtractor::new(ts_lang)?;
134 let sub_ranges = ext.extract(slice, canonical, latex_extras)?;
135
136 let offset = region.byte_range.start;
137 for mut r in sub_ranges {
138 r.start_byte += offset;
139 r.end_byte += offset;
140 r.exclusions = r
141 .exclusions
142 .into_iter()
143 .map(|(s, e)| (s + offset, e + offset))
144 .collect();
145 result.push(r);
146 }
147 }
148
149 result.sort_by_key(|r| r.start_byte);
150 Ok(result)
151}
152
153#[derive(Debug, Clone, PartialEq, Eq)]
154pub struct ProseRange {
155 pub start_byte: usize,
156 pub end_byte: usize,
157 pub exclusions: Vec<(usize, usize)>,
161}
162
163impl ProseRange {
164 #[must_use]
167 pub fn extract_text<'a>(&self, text: &'a str) -> std::borrow::Cow<'a, str> {
168 let slice = &text[self.start_byte..self.end_byte];
169 if self.exclusions.is_empty() {
170 return std::borrow::Cow::Borrowed(slice);
171 }
172 let mut buf = slice.to_string();
173 let bytes = unsafe { buf.as_bytes_mut() };
175 for &(exc_start, exc_end) in &self.exclusions {
176 let local_start = exc_start.saturating_sub(self.start_byte);
178 let local_end = exc_end.saturating_sub(self.start_byte).min(bytes.len());
179 for b in &mut bytes[local_start..local_end] {
180 *b = b' ';
181 }
182 }
183 strip_unmatched_brackets(bytes);
184 std::borrow::Cow::Owned(buf)
185 }
186
187 #[must_use]
190 #[allow(clippy::cast_possible_truncation)]
191 pub fn overlaps_exclusion(&self, local_start: u32, local_end: u32) -> bool {
192 let doc_start = self.start_byte as u32 + local_start;
193 let doc_end = self.start_byte as u32 + local_end;
194 self.exclusions.iter().any(|&(exc_start, exc_end)| {
195 let es = exc_start as u32;
196 let ee = exc_end as u32;
197 doc_start < ee && doc_end > es
198 })
199 }
200}
201
/// Overwrites every bracket byte that has no partner with an ASCII space.
///
/// The families `()`, `[]` and `{}` are balanced independently of one another, so
/// interleaved pairs such as `([)]` count as matched. A closer with no earlier unclosed
/// opener of its family, and an opener that is never closed, are both blanked. The buffer
/// length never changes.
fn strip_unmatched_brackets(bytes: &mut [u8]) {
    // One stack of still-open positions per family: 0 = `()`, 1 = `[]`, 2 = `{}`.
    let mut open_positions: [Vec<usize>; 3] = [Vec::new(), Vec::new(), Vec::new()];
    let mut orphan_closers: Vec<usize> = Vec::new();

    for (pos, byte) in bytes.iter().copied().enumerate() {
        let (family, opens) = match byte {
            b'(' => (0, true),
            b')' => (0, false),
            b'[' => (1, true),
            b']' => (1, false),
            b'{' => (2, true),
            b'}' => (2, false),
            _ => continue,
        };
        let stack = &mut open_positions[family];
        if opens {
            stack.push(pos);
        } else if stack.pop().is_none() {
            orphan_closers.push(pos);
        }
    }

    // Whatever is still on a stack was opened but never closed.
    let leftover_openers = open_positions.iter().flatten();
    for &pos in orphan_closers.iter().chain(leftover_openers) {
        bytes[pos] = b' ';
    }
}
245
#[cfg(test)]
mod tests {
    use super::*;
    use latex::LatexExtras;

    /// Returns the raw document slice covered by each range (exclusions are not applied).
    fn covered_slices<'a>(text: &'a str, ranges: &[ProseRange]) -> Vec<&'a str> {
        ranges
            .iter()
            .map(|r| &text[r.start_byte..r.end_byte])
            .collect()
    }

    /// The heading and both paragraphs of a Markdown document are reported as prose.
    #[test]
    fn test_markdown_extraction() -> Result<()> {
        let language: tree_sitter::Language = tree_sitter_md::LANGUAGE.into();
        let mut extractor = ProseExtractor::new(language)?;

        let text =
            "# Header\n\nThis is a paragraph.\n\n```rust\nfn main() {}\n```\n\nAnother paragraph.";
        let ranges = extractor.extract(text, "markdown", &LatexExtras::default())?;

        assert!(ranges.len() >= 3);

        let extracted_texts = covered_slices(text, &ranges);
        assert!(extracted_texts.iter().any(|t| t.contains("Header")));
        assert!(
            extracted_texts
                .iter()
                .any(|t| t.contains("This is a paragraph"))
        );
        assert!(
            extracted_texts
                .iter()
                .any(|t| t.contains("Another paragraph"))
        );

        Ok(())
    }

    /// `overlaps_exclusion` takes offsets relative to `start_byte`; with the range starting
    /// at 100 the exclusion `(150, 200)` occupies local offsets `50..100`.
    #[test]
    fn test_overlaps_exclusion() {
        let range = ProseRange {
            start_byte: 100,
            end_byte: 300,
            exclusions: vec![(150, 200)],
        };

        // Exactly the exclusion, then straddling its start, then straddling its end.
        assert!(range.overlaps_exclusion(50, 100));
        assert!(range.overlaps_exclusion(40, 60));
        assert!(range.overlaps_exclusion(90, 110));
        // Entirely before and entirely after the exclusion.
        assert!(!range.overlaps_exclusion(0, 40));
        assert!(!range.overlaps_exclusion(110, 130));
    }

    /// A `type:latex` region inside Markdown is extracted with the LaTeX extractor while the
    /// surrounding Markdown is still reported.
    #[test]
    fn type_override_latex_in_markdown() -> Result<()> {
        let text = "\
# Title

Some intro text.

<!-- lang-check-begin type:latex -->
\\emph{Hello} world and \\textbf{bold} text.
<!-- lang-check-end -->

Final paragraph.";

        let ranges = extract_with_fallback(text, "markdown", None, None, &LatexExtras::default())?;
        let texts = covered_slices(text, &ranges);

        // Markdown outside the region is untouched.
        assert!(texts.iter().any(|t| t.contains("Title")));
        assert!(texts.iter().any(|t| t.contains("intro text")));
        assert!(texts.iter().any(|t| t.contains("Final paragraph")));

        // The region itself must have gone through the LaTeX extractor.
        assert!(
            texts.iter().any(|t| t.contains("Hello")),
            "expected LaTeX extractor to produce range containing 'Hello', got: {texts:?}"
        );

        Ok(())
    }

    /// A region naming an unsupported type yields no prose at all for its content.
    #[test]
    fn type_override_unknown_skipped() -> Result<()> {
        let text = "\
# Title

<!-- lang-check-begin type:foobar -->
Some content here.
<!-- lang-check-end -->

Trailing text.";

        let ranges = extract_with_fallback(text, "markdown", None, None, &LatexExtras::default())?;
        let texts = covered_slices(text, &ranges);

        assert!(texts.iter().any(|t| t.contains("Title")));
        assert!(texts.iter().any(|t| t.contains("Trailing text")));

        // The base ranges inside the region are dropped and nothing replaces them.
        assert!(
            !texts.iter().any(|t| t.contains("Some content")),
            "expected unknown type region to be skipped, got: {texts:?}"
        );

        Ok(())
    }

    /// Prose before and after an overridden region survives the override pass.
    #[test]
    fn type_override_preserves_surrounding() -> Result<()> {
        let text = "\
First paragraph before.

<!-- lang-check-begin type:latex -->
\\section{Test}
Some LaTeX prose.
<!-- lang-check-end -->

Last paragraph after.";

        let ranges = extract_with_fallback(text, "markdown", None, None, &LatexExtras::default())?;
        let texts = covered_slices(text, &ranges);

        assert!(
            texts.iter().any(|t| t.contains("First paragraph before")),
            "pre-region range missing: {texts:?}"
        );
        assert!(
            texts.iter().any(|t| t.contains("Last paragraph after")),
            "post-region range missing: {texts:?}"
        );

        Ok(())
    }

    /// A closer without an opener is blanked in place: the bracket becomes a third space
    /// between its two neighbouring spaces and the length is unchanged.
    #[test]
    fn strip_unmatched_orphan_close() {
        let mut bytes = b"hello } world".to_vec();
        strip_unmatched_brackets(&mut bytes);
        assert_eq!(&bytes, b"hello   world");
    }

    /// An opener that is never closed is blanked in place (three spaces, same length).
    #[test]
    fn strip_unmatched_orphan_open() {
        let mut bytes = b"hello ( world".to_vec();
        strip_unmatched_brackets(&mut bytes);
        assert_eq!(&bytes, b"hello   world");
    }

    /// Balanced brackets are left exactly as they were.
    #[test]
    fn strip_unmatched_preserves_matched() {
        let mut bytes = b"f(x) and [y]".to_vec();
        strip_unmatched_brackets(&mut bytes);
        assert_eq!(&bytes, b"f(x) and [y]");
    }

    /// Only the orphan brace is blanked; the balanced call parentheses stay.
    #[test]
    fn strip_unmatched_mixed() {
        let mut bytes = b"value } is f(x)".to_vec();
        strip_unmatched_brackets(&mut bytes);
        assert_eq!(&bytes, b"value   is f(x)");
    }

    /// The exclusion `5..10` covers `#{x+y` but not the closing brace, which is therefore
    /// left unmatched and must be removed by the bracket pass inside `extract_text`.
    #[test]
    fn strip_unmatched_via_extract_text() {
        let range = ProseRange {
            start_byte: 0,
            end_byte: 20,
            exclusions: vec![(5, 10)],
        };
        let text = "text #{x+y} rest____";
        let clean = range.extract_text(text);
        assert!(!clean.contains('#'));
        assert!(!clean.contains('{'));
        assert!(!clean.contains('}'));
    }
}