1mod bibtex;
2mod forester;
3pub mod latex;
4mod org;
5mod query;
6mod rst;
7mod shared;
8mod sweave;
9mod tinylang;
10mod typst;
11
12use anyhow::{Result, anyhow};
13use std::ops::Range;
14use std::path::Path;
15use tree_sitter::{Language, Parser};
16
17use crate::ignore_rules::{DirectiveRegion, IgnoreParser};
18
19use crate::sls::SchemaRegistry;
20
21pub struct ProseExtractor {
22 parser: Parser,
23 language: Language,
24}
25
26impl ProseExtractor {
27 pub fn new(language: Language) -> Result<Self> {
28 let mut parser = Parser::new();
29 parser.set_language(&language)?;
30 Ok(Self { parser, language })
31 }
32
33 pub fn extract(
34 &mut self,
35 text: &str,
36 lang_id: &str,
37 latex_extras: &latex::LatexExtras,
38 ) -> Result<Vec<ProseRange>> {
39 let tree = self
40 .parser
41 .parse(text, None)
42 .ok_or_else(|| anyhow!("Failed to parse text"))?;
43
44 let root = tree.root_node();
45
46 match lang_id {
47 "latex" => Ok(latex::extract(text, root, latex_extras)),
48 "sweave" => Ok(sweave::extract(text, root, latex_extras)),
49 "forester" => Ok(forester::extract(text, root)),
50 "tinylang" => Ok(tinylang::extract(text, root)),
51 "rst" => Ok(rst::extract(text, root)),
52 "bibtex" => Ok(bibtex::extract(text, root)),
53 "org" => Ok(org::extract(text, root)),
54 "typst" => Ok(typst::extract(text, root)),
55 lang => query::extract(text, root, &self.language, lang),
56 }
57 }
58}
59
60pub fn extract_with_fallback(
66 text: &str,
67 lang_id: &str,
68 path: Option<&Path>,
69 schema_registry: Option<&SchemaRegistry>,
70 latex_extras: &latex::LatexExtras,
71) -> Result<Vec<ProseRange>> {
72 if let Some(ext) = path
73 .and_then(|value| value.extension())
74 .and_then(|value| value.to_str())
75 && crate::languages::builtin_language_for_extension(ext).is_none()
76 && let Some(schema) = schema_registry.and_then(|registry| registry.find_by_extension(ext))
77 {
78 return Ok(schema.extract(text));
79 }
80
81 let canonical_lang = crate::languages::resolve_language_id(lang_id);
82 let language = crate::languages::resolve_ts_language(canonical_lang);
83 let mut extractor = ProseExtractor::new(language)?;
84 let mut ranges = extractor.extract(text, canonical_lang, latex_extras)?;
85
86 let directives = IgnoreParser::parse_directives(text);
87 let resolved = IgnoreParser::resolve_all(text, &directives);
88 let type_regions: Vec<_> = resolved
89 .regions
90 .iter()
91 .filter(|r| r.options.doc_type.is_some())
92 .collect();
93 if !type_regions.is_empty() {
94 ranges = apply_type_overrides(text, ranges, &type_regions, latex_extras)?;
95 }
96
97 Ok(ranges)
98}
99
100fn apply_type_overrides(
107 text: &str,
108 base_ranges: Vec<ProseRange>,
109 type_regions: &[&DirectiveRegion],
110 latex_extras: &latex::LatexExtras,
111) -> Result<Vec<ProseRange>> {
112 let override_spans: Vec<&Range<usize>> = type_regions.iter().map(|r| &r.byte_range).collect();
113
114 let mut result: Vec<ProseRange> = base_ranges
116 .into_iter()
117 .filter(|r| {
118 !override_spans
119 .iter()
120 .any(|span| span.contains(&r.start_byte))
121 })
122 .collect();
123
124 for region in type_regions {
125 let doc_type = region.options.doc_type.as_deref().unwrap();
126 let canonical = crate::languages::resolve_language_id(doc_type);
127
128 if !crate::languages::SUPPORTED_LANGUAGE_IDS.contains(&canonical) {
129 eprintln!("lang-check: `type:{doc_type}` is not a supported language; skipping region");
130 continue;
131 }
132
133 let slice = &text[region.byte_range.clone()];
134 let ts_lang = crate::languages::resolve_ts_language(canonical);
135 let mut ext = ProseExtractor::new(ts_lang)?;
136 let sub_ranges = ext.extract(slice, canonical, latex_extras)?;
137
138 let offset = region.byte_range.start;
139 for mut r in sub_ranges {
140 r.start_byte += offset;
141 r.end_byte += offset;
142 r.exclusions = r
143 .exclusions
144 .into_iter()
145 .map(|(s, e)| (s + offset, e + offset))
146 .collect();
147 result.push(r);
148 }
149 }
150
151 result.sort_by_key(|r| r.start_byte);
152 Ok(result)
153}
154
/// A span of prose inside a document, expressed in byte offsets.
///
/// `start_byte..end_byte` delimits the span within the full document text.
/// `exclusions` holds half-open `(start, end)` byte pairs in document
/// coordinates that mark parts of the span which are not prose.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ProseRange {
    pub start_byte: usize,
    pub end_byte: usize,
    pub exclusions: Vec<(usize, usize)>,
}

impl ProseRange {
    /// Returns the text of this range with every exclusion blanked out.
    ///
    /// When there are no exclusions the original slice is borrowed unchanged.
    /// Otherwise an owned copy is returned in which each excluded byte range
    /// is overwritten with ASCII spaces and any bracket left without a
    /// partner is replaced by a space as well. The output always has the same
    /// byte length as the input slice, so offsets into it stay aligned.
    ///
    /// Exclusions are clamped to the range. Empty or inverted exclusions are
    /// ignored. An exclusion edge that falls inside a multi-byte character is
    /// widened to cover the whole character so the result stays valid UTF-8.
    ///
    /// # Panics
    ///
    /// Panics if `start_byte..end_byte` is not a valid char-aligned range of
    /// `text`.
    #[must_use]
    pub fn extract_text<'a>(&self, text: &'a str) -> std::borrow::Cow<'a, str> {
        let slice = &text[self.start_byte..self.end_byte];
        if self.exclusions.is_empty() {
            return std::borrow::Cow::Borrowed(slice);
        }

        let mut bytes = slice.as_bytes().to_vec();
        for &(exc_start, exc_end) in &self.exclusions {
            // Translate to slice-local offsets and clamp BOTH ends; clamping
            // only the end would yield an inverted range for exclusions that
            // start past the slice.
            let mut local_start = exc_start.saturating_sub(self.start_byte).min(slice.len());
            let mut local_end = exc_end.saturating_sub(self.start_byte).min(slice.len());
            if local_start >= local_end {
                continue;
            }

            // Never blank part of a character: grow outwards to the nearest
            // boundaries. Offsets 0 and `slice.len()` are always boundaries,
            // so both loops terminate in range.
            while !slice.is_char_boundary(local_start) {
                local_start -= 1;
            }
            while !slice.is_char_boundary(local_end) {
                local_end += 1;
            }

            for b in &mut bytes[local_start..local_end] {
                *b = b' ';
            }
        }
        strip_unmatched_brackets(&mut bytes);

        let cleaned = String::from_utf8(bytes)
            .expect("only whole characters and ASCII brackets are replaced by ASCII spaces");
        std::borrow::Cow::Owned(cleaned)
    }

    /// Reports whether the span `local_start..local_end` touches an exclusion.
    ///
    /// Both arguments are offsets relative to `start_byte`; they are shifted
    /// into document coordinates and compared against each exclusion as
    /// half-open intervals. The arithmetic is done in `u64`, so large
    /// documents neither overflow nor get truncated.
    #[must_use]
    pub fn overlaps_exclusion(&self, local_start: u32, local_end: u32) -> bool {
        let base = self.start_byte as u64;
        let doc_start = base.saturating_add(u64::from(local_start));
        let doc_end = base.saturating_add(u64::from(local_end));
        self.exclusions
            .iter()
            .any(|&(exc_start, exc_end)| doc_start < exc_end as u64 && doc_end > exc_start as u64)
    }
}

/// Overwrites every bracket that has no partner with an ASCII space.
///
/// Parentheses, square brackets and braces are balanced independently of one
/// another: a closer with no pending opener of the same kind is unmatched,
/// and so is every opener still pending at the end of the input. Only ASCII
/// bracket bytes are modified, so the length is preserved and UTF-8 input
/// stays valid UTF-8.
fn strip_unmatched_brackets(bytes: &mut [u8]) {
    let mut paren_stack: Vec<usize> = Vec::new();
    let mut bracket_stack: Vec<usize> = Vec::new();
    let mut brace_stack: Vec<usize> = Vec::new();
    let mut unmatched: Vec<usize> = Vec::new();

    for (i, &b) in bytes.iter().enumerate() {
        // Openers are remembered; closers consume the most recent opener of
        // their own kind or are recorded as orphans.
        let orphan_close = match b {
            b'(' => {
                paren_stack.push(i);
                false
            }
            b'[' => {
                bracket_stack.push(i);
                false
            }
            b'{' => {
                brace_stack.push(i);
                false
            }
            b')' => paren_stack.pop().is_none(),
            b']' => bracket_stack.pop().is_none(),
            b'}' => brace_stack.pop().is_none(),
            _ => false,
        };
        if orphan_close {
            unmatched.push(i);
        }
    }

    // Whatever is still on a stack never found its closer.
    unmatched.extend(paren_stack);
    unmatched.extend(bracket_stack);
    unmatched.extend(brace_stack);

    for idx in unmatched {
        bytes[idx] = b' ';
    }
}
247
#[cfg(test)]
mod tests {
    use super::*;
    use latex::LatexExtras;

    /// Markdown goes through the query-based extractor; headings and
    /// paragraphs around a fenced code block must all be reported.
    #[test]
    fn test_markdown_extraction() -> Result<()> {
        let language: tree_sitter::Language = tree_sitter_md::LANGUAGE.into();
        let mut extractor = ProseExtractor::new(language)?;

        let text =
            "# Header\n\nThis is a paragraph.\n\n```rust\nfn main() {}\n```\n\nAnother paragraph.";
        let ranges = extractor.extract(text, "markdown", &LatexExtras::default())?;

        assert!(ranges.len() >= 3);

        let extracted_texts: Vec<&str> = ranges
            .iter()
            .map(|r| &text[r.start_byte..r.end_byte])
            .collect();
        assert!(extracted_texts.iter().any(|t| t.contains("Header")));
        assert!(
            extracted_texts
                .iter()
                .any(|t| t.contains("This is a paragraph"))
        );
        assert!(
            extracted_texts
                .iter()
                .any(|t| t.contains("Another paragraph"))
        );

        Ok(())
    }

    /// Local offsets are relative to `start_byte` (100 here), so the
    /// exclusion at document bytes 150..200 is local 50..100.
    #[test]
    fn test_overlaps_exclusion() {
        let range = ProseRange {
            start_byte: 100,
            end_byte: 300,
            exclusions: vec![(150, 200)],
        };

        // Exactly the exclusion.
        assert!(range.overlaps_exclusion(50, 100));
        // Straddles the exclusion start.
        assert!(range.overlaps_exclusion(40, 60));
        // Straddles the exclusion end.
        assert!(range.overlaps_exclusion(90, 110));
        // Entirely before the exclusion.
        assert!(!range.overlaps_exclusion(0, 40));
        // Entirely after the exclusion.
        assert!(!range.overlaps_exclusion(110, 130));
    }

    /// A `type:latex` region inside Markdown is handed to the LaTeX
    /// extractor while the surrounding Markdown is still extracted.
    #[test]
    fn type_override_latex_in_markdown() -> Result<()> {
        let text = "\
# Title

Some intro text.

<!-- lang-check-begin type:latex -->
\\emph{Hello} world and \\textbf{bold} text.
<!-- lang-check-end -->

Final paragraph.";

        let ranges = extract_with_fallback(text, "markdown", None, None, &LatexExtras::default())?;

        let texts: Vec<&str> = ranges
            .iter()
            .map(|r| &text[r.start_byte..r.end_byte])
            .collect();

        assert!(texts.iter().any(|t| t.contains("Title")));
        assert!(texts.iter().any(|t| t.contains("intro text")));
        assert!(texts.iter().any(|t| t.contains("Final paragraph")));

        assert!(
            texts.iter().any(|t| t.contains("Hello")),
            "expected LaTeX extractor to produce range containing 'Hello', got: {texts:?}"
        );

        Ok(())
    }

    /// A region with an unsupported type yields no ranges for its content,
    /// but the rest of the document is unaffected.
    #[test]
    fn type_override_unknown_skipped() -> Result<()> {
        let text = "\
# Title

<!-- lang-check-begin type:foobar -->
Some content here.
<!-- lang-check-end -->

Trailing text.";

        let ranges = extract_with_fallback(text, "markdown", None, None, &LatexExtras::default())?;

        let texts: Vec<&str> = ranges
            .iter()
            .map(|r| &text[r.start_byte..r.end_byte])
            .collect();

        assert!(texts.iter().any(|t| t.contains("Title")));
        assert!(texts.iter().any(|t| t.contains("Trailing text")));

        assert!(
            !texts.iter().any(|t| t.contains("Some content")),
            "expected unknown type region to be skipped, got: {texts:?}"
        );

        Ok(())
    }

    /// Ranges before and after a typed region survive the override pass.
    #[test]
    fn type_override_preserves_surrounding() -> Result<()> {
        let text = "\
First paragraph before.

<!-- lang-check-begin type:latex -->
\\section{Test}
Some LaTeX prose.
<!-- lang-check-end -->

Last paragraph after.";

        let ranges = extract_with_fallback(text, "markdown", None, None, &LatexExtras::default())?;

        let texts: Vec<&str> = ranges
            .iter()
            .map(|r| &text[r.start_byte..r.end_byte])
            .collect();

        assert!(
            texts.iter().any(|t| t.contains("First paragraph before")),
            "pre-region range missing: {texts:?}"
        );
        assert!(
            texts.iter().any(|t| t.contains("Last paragraph after")),
            "post-region range missing: {texts:?}"
        );

        Ok(())
    }

    // `strip_unmatched_brackets` swaps each unmatched bracket byte for one
    // space, so the expected output always has the same length as the input.

    #[test]
    fn strip_unmatched_orphan_close() {
        let mut bytes = b"hello } world".to_vec();
        strip_unmatched_brackets(&mut bytes);
        assert_eq!(&bytes, b"hello   world");
    }

    #[test]
    fn strip_unmatched_orphan_open() {
        let mut bytes = b"hello ( world".to_vec();
        strip_unmatched_brackets(&mut bytes);
        assert_eq!(&bytes, b"hello   world");
    }

    #[test]
    fn strip_unmatched_preserves_matched() {
        let mut bytes = b"f(x) and [y]".to_vec();
        strip_unmatched_brackets(&mut bytes);
        assert_eq!(&bytes, b"f(x) and [y]");
    }

    #[test]
    fn strip_unmatched_mixed() {
        let mut bytes = b"value } is f(x)".to_vec();
        strip_unmatched_brackets(&mut bytes);
        assert_eq!(&bytes, b"value   is f(x)");
    }

    /// The exclusion covers `#{x+y` (bytes 5..10); the `}` left behind at
    /// byte 10 has no opener any more and must be stripped too.
    #[test]
    fn strip_unmatched_via_extract_text() {
        let range = ProseRange {
            start_byte: 0,
            end_byte: 20,
            exclusions: vec![(5, 10)],
        };
        let text = "text #{x+y} rest____";
        let clean = range.extract_text(text);
        assert!(!clean.contains('#'));
        assert!(!clean.contains('{'));
        assert!(!clean.contains('}'));
    }
}