Skip to main content

lang_check/prose/
mod.rs

1mod bibtex;
2mod forester;
3pub mod latex;
4mod org;
5mod query;
6mod rst;
7mod shared;
8mod sweave;
9mod tinylang;
10
11use anyhow::{Result, anyhow};
12use std::ops::Range;
13use std::path::Path;
14use tree_sitter::{Language, Parser};
15
16use crate::ignore_rules::{DirectiveRegion, IgnoreParser};
17
18use crate::sls::SchemaRegistry;
19
20pub struct ProseExtractor {
21    parser: Parser,
22    language: Language,
23}
24
25impl ProseExtractor {
26    pub fn new(language: Language) -> Result<Self> {
27        let mut parser = Parser::new();
28        parser.set_language(&language)?;
29        Ok(Self { parser, language })
30    }
31
32    pub fn extract(
33        &mut self,
34        text: &str,
35        lang_id: &str,
36        latex_extras: &latex::LatexExtras,
37    ) -> Result<Vec<ProseRange>> {
38        let tree = self
39            .parser
40            .parse(text, None)
41            .ok_or_else(|| anyhow!("Failed to parse text"))?;
42
43        let root = tree.root_node();
44
45        match lang_id {
46            "latex" => Ok(latex::extract(text, root, latex_extras)),
47            "sweave" => Ok(sweave::extract(text, root, latex_extras)),
48            "forester" => Ok(forester::extract(text, root)),
49            "tinylang" => Ok(tinylang::extract(text, root)),
50            "rst" => Ok(rst::extract(text, root)),
51            "bibtex" => Ok(bibtex::extract(text, root)),
52            "org" => Ok(org::extract(text, root)),
53            lang => query::extract(text, root, &self.language, lang),
54        }
55    }
56}
57
58/// Extract prose using a built-in tree-sitter extractor or an SLS fallback.
59///
60/// When the file extension matches a loaded SLS schema and that extension has
61/// no built-in tree-sitter extractor, the schema takes over. Built-in
62/// extensions always keep precedence.
63pub fn extract_with_fallback(
64    text: &str,
65    lang_id: &str,
66    path: Option<&Path>,
67    schema_registry: Option<&SchemaRegistry>,
68    latex_extras: &latex::LatexExtras,
69) -> Result<Vec<ProseRange>> {
70    if let Some(ext) = path
71        .and_then(|value| value.extension())
72        .and_then(|value| value.to_str())
73        && crate::languages::builtin_language_for_extension(ext).is_none()
74        && let Some(schema) = schema_registry.and_then(|registry| registry.find_by_extension(ext))
75    {
76        return Ok(schema.extract(text));
77    }
78
79    let canonical_lang = crate::languages::resolve_language_id(lang_id);
80    let language = crate::languages::resolve_ts_language(canonical_lang);
81    let mut extractor = ProseExtractor::new(language)?;
82    let mut ranges = extractor.extract(text, canonical_lang, latex_extras)?;
83
84    let directives = IgnoreParser::parse_directives(text);
85    let resolved = IgnoreParser::resolve_all(text, &directives);
86    let type_regions: Vec<_> = resolved
87        .regions
88        .iter()
89        .filter(|r| r.options.doc_type.is_some())
90        .collect();
91    if !type_regions.is_empty() {
92        ranges = apply_type_overrides(text, ranges, &type_regions, latex_extras)?;
93    }
94
95    Ok(ranges)
96}
97
98/// Re-extract prose for regions tagged with `type:FORMAT`.
99///
100/// For each type-override region, slices the document text, runs the specified
101/// format's extractor, and rebases the resulting ranges to document-level
102/// offsets. Base ranges whose `start_byte` falls inside a type-override region
103/// are removed and replaced with the re-extracted ranges.
104fn apply_type_overrides(
105    text: &str,
106    base_ranges: Vec<ProseRange>,
107    type_regions: &[&DirectiveRegion],
108    latex_extras: &latex::LatexExtras,
109) -> Result<Vec<ProseRange>> {
110    let override_spans: Vec<&Range<usize>> = type_regions.iter().map(|r| &r.byte_range).collect();
111
112    // Keep base ranges that don't start inside any type-override region.
113    let mut result: Vec<ProseRange> = base_ranges
114        .into_iter()
115        .filter(|r| {
116            !override_spans
117                .iter()
118                .any(|span| span.contains(&r.start_byte))
119        })
120        .collect();
121
122    for region in type_regions {
123        let doc_type = region.options.doc_type.as_deref().unwrap();
124        let canonical = crate::languages::resolve_language_id(doc_type);
125
126        if !crate::languages::SUPPORTED_LANGUAGE_IDS.contains(&canonical) {
127            eprintln!("lang-check: `type:{doc_type}` is not a supported language; skipping region");
128            continue;
129        }
130
131        let slice = &text[region.byte_range.clone()];
132        let ts_lang = crate::languages::resolve_ts_language(canonical);
133        let mut ext = ProseExtractor::new(ts_lang)?;
134        let sub_ranges = ext.extract(slice, canonical, latex_extras)?;
135
136        let offset = region.byte_range.start;
137        for mut r in sub_ranges {
138            r.start_byte += offset;
139            r.end_byte += offset;
140            r.exclusions = r
141                .exclusions
142                .into_iter()
143                .map(|(s, e)| (s + offset, e + offset))
144                .collect();
145            result.push(r);
146        }
147    }
148
149    result.sort_by_key(|r| r.start_byte);
150    Ok(result)
151}
152
153#[derive(Debug, Clone, PartialEq, Eq)]
154pub struct ProseRange {
155    pub start_byte: usize,
156    pub end_byte: usize,
157    /// Byte ranges (document-level) within this prose range that should be
158    /// excluded from grammar checking (e.g. display math). These regions are
159    /// replaced with spaces when extracting text, preserving byte offsets.
160    pub exclusions: Vec<(usize, usize)>,
161}
162
163impl ProseRange {
164    /// Extract the prose text from the full document, replacing any excluded
165    /// regions with spaces so that byte offsets remain stable.
166    #[must_use]
167    pub fn extract_text<'a>(&self, text: &'a str) -> std::borrow::Cow<'a, str> {
168        let slice = &text[self.start_byte..self.end_byte];
169        if self.exclusions.is_empty() {
170            return std::borrow::Cow::Borrowed(slice);
171        }
172        let mut buf = slice.to_string();
173        // SAFETY: we only replace valid UTF-8 ranges with ASCII spaces
174        let bytes = unsafe { buf.as_bytes_mut() };
175        for &(exc_start, exc_end) in &self.exclusions {
176            // Convert document-level offsets to slice-local offsets
177            let local_start = exc_start.saturating_sub(self.start_byte);
178            let local_end = exc_end.saturating_sub(self.start_byte).min(bytes.len());
179            for b in &mut bytes[local_start..local_end] {
180                *b = b' ';
181            }
182        }
183        strip_unmatched_brackets(bytes);
184        std::borrow::Cow::Owned(buf)
185    }
186
187    /// Check whether a local byte range (relative to this prose range)
188    /// overlaps with any exclusion zone.
189    #[must_use]
190    #[allow(clippy::cast_possible_truncation)]
191    pub fn overlaps_exclusion(&self, local_start: u32, local_end: u32) -> bool {
192        let doc_start = self.start_byte as u32 + local_start;
193        let doc_end = self.start_byte as u32 + local_end;
194        self.exclusions.iter().any(|&(exc_start, exc_end)| {
195            let es = exc_start as u32;
196            let ee = exc_end as u32;
197            doc_start < ee && doc_end > es
198        })
199    }
200}
201
202/// Replace provably-unmatched brackets `()[]{}` with spaces.
203///
204/// Uses a single O(n) pass with per-type stacks. Only brackets that have no
205/// matching partner anywhere in the text are replaced — correctly paired
206/// brackets (even across exclusion gaps) are left untouched.
207fn strip_unmatched_brackets(bytes: &mut [u8]) {
208    let mut paren_stack: Vec<usize> = Vec::new();
209    let mut bracket_stack: Vec<usize> = Vec::new();
210    let mut brace_stack: Vec<usize> = Vec::new();
211    let mut unmatched: Vec<usize> = Vec::new();
212
213    for (i, &b) in bytes.iter().enumerate() {
214        match b {
215            b'(' => paren_stack.push(i),
216            b')' => {
217                if paren_stack.pop().is_none() {
218                    unmatched.push(i);
219                }
220            }
221            b'[' => bracket_stack.push(i),
222            b']' => {
223                if bracket_stack.pop().is_none() {
224                    unmatched.push(i);
225                }
226            }
227            b'{' => brace_stack.push(i),
228            b'}' => {
229                if brace_stack.pop().is_none() {
230                    unmatched.push(i);
231                }
232            }
233            _ => {}
234        }
235    }
236
237    unmatched.extend(paren_stack);
238    unmatched.extend(bracket_stack);
239    unmatched.extend(brace_stack);
240
241    for idx in unmatched {
242        bytes[idx] = b' ';
243    }
244}
245
246#[cfg(test)]
247mod tests {
248    use super::*;
249    use latex::LatexExtras;
250
251    #[test]
252    fn test_markdown_extraction() -> Result<()> {
253        let language: tree_sitter::Language = tree_sitter_md::LANGUAGE.into();
254        let mut extractor = ProseExtractor::new(language)?;
255
256        let text =
257            "# Header\n\nThis is a paragraph.\n\n```rust\nfn main() {}\n```\n\nAnother paragraph.";
258        let ranges = extractor.extract(text, "markdown", &LatexExtras::default())?;
259
260        assert!(ranges.len() >= 3);
261
262        let extracted_texts: Vec<&str> = ranges
263            .iter()
264            .map(|r| &text[r.start_byte..r.end_byte])
265            .collect();
266        assert!(extracted_texts.iter().any(|t| t.contains("Header")));
267        assert!(
268            extracted_texts
269                .iter()
270                .any(|t| t.contains("This is a paragraph"))
271        );
272        assert!(
273            extracted_texts
274                .iter()
275                .any(|t| t.contains("Another paragraph"))
276        );
277
278        Ok(())
279    }
280
281    #[test]
282    fn test_overlaps_exclusion() {
283        let range = ProseRange {
284            start_byte: 100,
285            end_byte: 300,
286            exclusions: vec![(150, 200)],
287        };
288
289        // Diagnostic entirely inside exclusion
290        assert!(range.overlaps_exclusion(50, 100)); // local 50..100 = doc 150..200
291        // Diagnostic partially overlapping exclusion
292        assert!(range.overlaps_exclusion(40, 60)); // doc 140..160 overlaps 150..200
293        assert!(range.overlaps_exclusion(90, 110)); // doc 190..210 overlaps 150..200
294        // Diagnostic entirely outside exclusion
295        assert!(!range.overlaps_exclusion(0, 40)); // doc 100..140, before exclusion
296        assert!(!range.overlaps_exclusion(110, 130)); // doc 210..230, after exclusion
297    }
298
299    #[test]
300    fn type_override_latex_in_markdown() -> Result<()> {
301        let text = "\
302# Title
303
304Some intro text.
305
306<!-- lang-check-begin type:latex -->
307\\emph{Hello} world and \\textbf{bold} text.
308<!-- lang-check-end -->
309
310Final paragraph.";
311
312        let ranges = extract_with_fallback(text, "markdown", None, None, &LatexExtras::default())?;
313
314        let texts: Vec<&str> = ranges
315            .iter()
316            .map(|r| &text[r.start_byte..r.end_byte])
317            .collect();
318
319        // Surrounding markdown prose is preserved.
320        assert!(texts.iter().any(|t| t.contains("Title")));
321        assert!(texts.iter().any(|t| t.contains("intro text")));
322        assert!(texts.iter().any(|t| t.contains("Final paragraph")));
323
324        // The LaTeX region was re-extracted: the prose content from
325        // \emph{Hello} and \textbf{bold} should appear in ranges.
326        assert!(
327            texts.iter().any(|t| t.contains("Hello")),
328            "expected LaTeX extractor to produce range containing 'Hello', got: {texts:?}"
329        );
330
331        Ok(())
332    }
333
334    #[test]
335    fn type_override_unknown_skipped() -> Result<()> {
336        let text = "\
337# Title
338
339<!-- lang-check-begin type:foobar -->
340Some content here.
341<!-- lang-check-end -->
342
343Trailing text.";
344
345        let ranges = extract_with_fallback(text, "markdown", None, None, &LatexExtras::default())?;
346
347        let texts: Vec<&str> = ranges
348            .iter()
349            .map(|r| &text[r.start_byte..r.end_byte])
350            .collect();
351
352        // Surrounding ranges preserved.
353        assert!(texts.iter().any(|t| t.contains("Title")));
354        assert!(texts.iter().any(|t| t.contains("Trailing text")));
355
356        // The unknown-type region's base ranges were filtered out, and no
357        // re-extraction happened, so "Some content" should be absent.
358        assert!(
359            !texts.iter().any(|t| t.contains("Some content")),
360            "expected unknown type region to be skipped, got: {texts:?}"
361        );
362
363        Ok(())
364    }
365
366    #[test]
367    fn type_override_preserves_surrounding() -> Result<()> {
368        let text = "\
369First paragraph before.
370
371<!-- lang-check-begin type:latex -->
372\\section{Test}
373Some LaTeX prose.
374<!-- lang-check-end -->
375
376Last paragraph after.";
377
378        let ranges = extract_with_fallback(text, "markdown", None, None, &LatexExtras::default())?;
379
380        let texts: Vec<&str> = ranges
381            .iter()
382            .map(|r| &text[r.start_byte..r.end_byte])
383            .collect();
384
385        // Both surrounding paragraphs must be present and unmodified.
386        assert!(
387            texts.iter().any(|t| t.contains("First paragraph before")),
388            "pre-region range missing: {texts:?}"
389        );
390        assert!(
391            texts.iter().any(|t| t.contains("Last paragraph after")),
392            "post-region range missing: {texts:?}"
393        );
394
395        Ok(())
396    }
397
398    #[test]
399    fn strip_unmatched_orphan_close() {
400        let mut bytes = b"hello } world".to_vec();
401        strip_unmatched_brackets(&mut bytes);
402        assert_eq!(&bytes, b"hello   world");
403    }
404
405    #[test]
406    fn strip_unmatched_orphan_open() {
407        let mut bytes = b"hello ( world".to_vec();
408        strip_unmatched_brackets(&mut bytes);
409        assert_eq!(&bytes, b"hello   world");
410    }
411
412    #[test]
413    fn strip_unmatched_preserves_matched() {
414        let mut bytes = b"f(x) and [y]".to_vec();
415        strip_unmatched_brackets(&mut bytes);
416        assert_eq!(&bytes, b"f(x) and [y]");
417    }
418
419    #[test]
420    fn strip_unmatched_mixed() {
421        // '}' is unmatched, '(x)' is matched
422        let mut bytes = b"value } is f(x)".to_vec();
423        strip_unmatched_brackets(&mut bytes);
424        assert_eq!(&bytes, b"value   is f(x)");
425    }
426
427    #[test]
428    fn strip_unmatched_via_extract_text() {
429        let range = ProseRange {
430            start_byte: 0,
431            end_byte: 20,
432            exclusions: vec![(5, 10)],
433        };
434        // "text } rest" after blanking exclusion [5,10) -> "text      rest"
435        // but if original is "text #{x+y} rest", after blanking the #{x+y}
436        // region we get "text        rest" with no unmatched brackets.
437        let text = "text #{x+y} rest____";
438        let clean = range.extract_text(text);
439        // The #{x+y} was blanked, no unmatched brackets remain
440        assert!(!clean.contains('#'));
441        assert!(!clean.contains('{'));
442        assert!(!clean.contains('}'));
443    }
444}