//! lang_check/prose/mod.rs — prose extraction: dispatches to per-format
//! tree-sitter extractors (or an SLS schema fallback) and post-processes
//! the resulting ranges.

1mod bibtex;
2mod forester;
3pub mod latex;
4mod org;
5mod query;
6mod rst;
7mod shared;
8mod sweave;
9mod tinylang;
10mod typst;
11
12use anyhow::{Result, anyhow};
13use std::ops::Range;
14use std::path::Path;
15use tree_sitter::{Language, Parser};
16
17use crate::ignore_rules::{DirectiveRegion, IgnoreParser};
18
19use crate::sls::SchemaRegistry;
20
/// Extracts prose ranges from a document using a tree-sitter parser.
///
/// Holds both the configured [`Parser`] and the [`Language`] handle; the
/// latter is needed again by the generic query-based extractor fallback.
pub struct ProseExtractor {
    // Parser pre-configured via `set_language` in `new`.
    parser: Parser,
    // Kept so `extract` can pass it to `query::extract` for non-built-in ids.
    language: Language,
}
25
26impl ProseExtractor {
27    pub fn new(language: Language) -> Result<Self> {
28        let mut parser = Parser::new();
29        parser.set_language(&language)?;
30        Ok(Self { parser, language })
31    }
32
33    pub fn extract(
34        &mut self,
35        text: &str,
36        lang_id: &str,
37        latex_extras: &latex::LatexExtras,
38    ) -> Result<Vec<ProseRange>> {
39        let tree = self
40            .parser
41            .parse(text, None)
42            .ok_or_else(|| anyhow!("Failed to parse text"))?;
43
44        let root = tree.root_node();
45
46        match lang_id {
47            "latex" => Ok(latex::extract(text, root, latex_extras)),
48            "sweave" => Ok(sweave::extract(text, root, latex_extras)),
49            "forester" => Ok(forester::extract(text, root)),
50            "tinylang" => Ok(tinylang::extract(text, root)),
51            "rst" => Ok(rst::extract(text, root)),
52            "bibtex" => Ok(bibtex::extract(text, root)),
53            "org" => Ok(org::extract(text, root)),
54            "typst" => Ok(typst::extract(text, root)),
55            lang => query::extract(text, root, &self.language, lang),
56        }
57    }
58}
59
60/// Extract prose using a built-in tree-sitter extractor or an SLS fallback.
61///
62/// When the file extension matches a loaded SLS schema and that extension has
63/// no built-in tree-sitter extractor, the schema takes over. Built-in
64/// extensions always keep precedence.
65pub fn extract_with_fallback(
66    text: &str,
67    lang_id: &str,
68    path: Option<&Path>,
69    schema_registry: Option<&SchemaRegistry>,
70    latex_extras: &latex::LatexExtras,
71) -> Result<Vec<ProseRange>> {
72    if let Some(ext) = path
73        .and_then(|value| value.extension())
74        .and_then(|value| value.to_str())
75        && crate::languages::builtin_language_for_extension(ext).is_none()
76        && let Some(schema) = schema_registry.and_then(|registry| registry.find_by_extension(ext))
77    {
78        return Ok(schema.extract(text));
79    }
80
81    let canonical_lang = crate::languages::resolve_language_id(lang_id);
82    let language = crate::languages::resolve_ts_language(canonical_lang);
83    let mut extractor = ProseExtractor::new(language)?;
84    let mut ranges = extractor.extract(text, canonical_lang, latex_extras)?;
85
86    let directives = IgnoreParser::parse_directives(text);
87    let resolved = IgnoreParser::resolve_all(text, &directives);
88    let type_regions: Vec<_> = resolved
89        .regions
90        .iter()
91        .filter(|r| r.options.doc_type.is_some())
92        .collect();
93    if !type_regions.is_empty() {
94        ranges = apply_type_overrides(text, ranges, &type_regions, latex_extras)?;
95    }
96
97    Ok(ranges)
98}
99
100/// Re-extract prose for regions tagged with `type:FORMAT`.
101///
102/// For each type-override region, slices the document text, runs the specified
103/// format's extractor, and rebases the resulting ranges to document-level
104/// offsets. Base ranges whose `start_byte` falls inside a type-override region
105/// are removed and replaced with the re-extracted ranges.
106fn apply_type_overrides(
107    text: &str,
108    base_ranges: Vec<ProseRange>,
109    type_regions: &[&DirectiveRegion],
110    latex_extras: &latex::LatexExtras,
111) -> Result<Vec<ProseRange>> {
112    let override_spans: Vec<&Range<usize>> = type_regions.iter().map(|r| &r.byte_range).collect();
113
114    // Keep base ranges that don't start inside any type-override region.
115    let mut result: Vec<ProseRange> = base_ranges
116        .into_iter()
117        .filter(|r| {
118            !override_spans
119                .iter()
120                .any(|span| span.contains(&r.start_byte))
121        })
122        .collect();
123
124    for region in type_regions {
125        let doc_type = region.options.doc_type.as_deref().unwrap();
126        let canonical = crate::languages::resolve_language_id(doc_type);
127
128        if !crate::languages::SUPPORTED_LANGUAGE_IDS.contains(&canonical) {
129            eprintln!("lang-check: `type:{doc_type}` is not a supported language; skipping region");
130            continue;
131        }
132
133        let slice = &text[region.byte_range.clone()];
134        let ts_lang = crate::languages::resolve_ts_language(canonical);
135        let mut ext = ProseExtractor::new(ts_lang)?;
136        let sub_ranges = ext.extract(slice, canonical, latex_extras)?;
137
138        let offset = region.byte_range.start;
139        for mut r in sub_ranges {
140            r.start_byte += offset;
141            r.end_byte += offset;
142            r.exclusions = r
143                .exclusions
144                .into_iter()
145                .map(|(s, e)| (s + offset, e + offset))
146                .collect();
147            result.push(r);
148        }
149    }
150
151    result.sort_by_key(|r| r.start_byte);
152    Ok(result)
153}
154
/// A half-open span `[start_byte, end_byte)` of prose within a document,
/// expressed in document-level byte offsets.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ProseRange {
    // Inclusive start of the span, as a byte offset into the document.
    pub start_byte: usize,
    // Exclusive end of the span, as a byte offset into the document.
    pub end_byte: usize,
    /// Byte ranges (document-level) within this prose range that should be
    /// excluded from grammar checking (e.g. display math). These regions are
    /// replaced with spaces when extracting text, preserving byte offsets.
    pub exclusions: Vec<(usize, usize)>,
}
164
165impl ProseRange {
166    /// Extract the prose text from the full document, replacing any excluded
167    /// regions with spaces so that byte offsets remain stable.
168    #[must_use]
169    pub fn extract_text<'a>(&self, text: &'a str) -> std::borrow::Cow<'a, str> {
170        let slice = &text[self.start_byte..self.end_byte];
171        if self.exclusions.is_empty() {
172            return std::borrow::Cow::Borrowed(slice);
173        }
174        let mut buf = slice.to_string();
175        // SAFETY: we only replace valid UTF-8 ranges with ASCII spaces
176        let bytes = unsafe { buf.as_bytes_mut() };
177        for &(exc_start, exc_end) in &self.exclusions {
178            // Convert document-level offsets to slice-local offsets
179            let local_start = exc_start.saturating_sub(self.start_byte);
180            let local_end = exc_end.saturating_sub(self.start_byte).min(bytes.len());
181            for b in &mut bytes[local_start..local_end] {
182                *b = b' ';
183            }
184        }
185        strip_unmatched_brackets(bytes);
186        std::borrow::Cow::Owned(buf)
187    }
188
189    /// Check whether a local byte range (relative to this prose range)
190    /// overlaps with any exclusion zone.
191    #[must_use]
192    #[allow(clippy::cast_possible_truncation)]
193    pub fn overlaps_exclusion(&self, local_start: u32, local_end: u32) -> bool {
194        let doc_start = self.start_byte as u32 + local_start;
195        let doc_end = self.start_byte as u32 + local_end;
196        self.exclusions.iter().any(|&(exc_start, exc_end)| {
197            let es = exc_start as u32;
198            let ee = exc_end as u32;
199            doc_start < ee && doc_end > es
200        })
201    }
202}
203
/// Replace provably-unmatched brackets `()[]{}` with spaces.
///
/// Uses a single O(n) pass with per-type stacks. Only brackets that have no
/// matching partner anywhere in the text are replaced — correctly paired
/// brackets (even across exclusion gaps) are left untouched.
fn strip_unmatched_brackets(bytes: &mut [u8]) {
    // One stack per bracket family (paren, square, brace), holding the byte
    // positions of openers that are still awaiting a closer.
    let mut open_stacks: [Vec<usize>; 3] = [Vec::new(), Vec::new(), Vec::new()];
    // Positions of closers that arrived with no matching opener.
    let mut orphans: Vec<usize> = Vec::new();

    for (pos, &byte) in bytes.iter().enumerate() {
        // Classify the byte: which family, opener or closer?
        let (family, is_open) = match byte {
            b'(' => (0, true),
            b')' => (0, false),
            b'[' => (1, true),
            b']' => (1, false),
            b'{' => (2, true),
            b'}' => (2, false),
            _ => continue,
        };
        if is_open {
            open_stacks[family].push(pos);
        } else if open_stacks[family].pop().is_none() {
            orphans.push(pos);
        }
    }

    // Openers left on a stack never found a partner either.
    for stack in &open_stacks {
        orphans.extend_from_slice(stack);
    }

    for pos in orphans {
        bytes[pos] = b' ';
    }
}
247
#[cfg(test)]
mod tests {
    use super::*;
    use latex::LatexExtras;

    // Smoke test: the markdown extractor yields the header and both
    // paragraphs; the fenced code block must not swallow them.
    #[test]
    fn test_markdown_extraction() -> Result<()> {
        let language: tree_sitter::Language = tree_sitter_md::LANGUAGE.into();
        let mut extractor = ProseExtractor::new(language)?;

        let text =
            "# Header\n\nThis is a paragraph.\n\n```rust\nfn main() {}\n```\n\nAnother paragraph.";
        let ranges = extractor.extract(text, "markdown", &LatexExtras::default())?;

        assert!(ranges.len() >= 3);

        let extracted_texts: Vec<&str> = ranges
            .iter()
            .map(|r| &text[r.start_byte..r.end_byte])
            .collect();
        assert!(extracted_texts.iter().any(|t| t.contains("Header")));
        assert!(
            extracted_texts
                .iter()
                .any(|t| t.contains("This is a paragraph"))
        );
        assert!(
            extracted_texts
                .iter()
                .any(|t| t.contains("Another paragraph"))
        );

        Ok(())
    }

    // Overlap semantics: half-open intervals, local offsets rebased onto
    // the range's document-level start.
    #[test]
    fn test_overlaps_exclusion() {
        let range = ProseRange {
            start_byte: 100,
            end_byte: 300,
            exclusions: vec![(150, 200)],
        };

        // Diagnostic entirely inside exclusion
        assert!(range.overlaps_exclusion(50, 100)); // local 50..100 = doc 150..200
        // Diagnostic partially overlapping exclusion
        assert!(range.overlaps_exclusion(40, 60)); // doc 140..160 overlaps 150..200
        assert!(range.overlaps_exclusion(90, 110)); // doc 190..210 overlaps 150..200
        // Diagnostic entirely outside exclusion
        assert!(!range.overlaps_exclusion(0, 40)); // doc 100..140, before exclusion
        assert!(!range.overlaps_exclusion(110, 130)); // doc 210..230, after exclusion
    }

    // A `type:latex` region inside markdown is re-extracted with the LaTeX
    // extractor while the surrounding markdown ranges survive.
    #[test]
    fn type_override_latex_in_markdown() -> Result<()> {
        let text = "\
# Title

Some intro text.

<!-- lang-check-begin type:latex -->
\\emph{Hello} world and \\textbf{bold} text.
<!-- lang-check-end -->

Final paragraph.";

        let ranges = extract_with_fallback(text, "markdown", None, None, &LatexExtras::default())?;

        let texts: Vec<&str> = ranges
            .iter()
            .map(|r| &text[r.start_byte..r.end_byte])
            .collect();

        // Surrounding markdown prose is preserved.
        assert!(texts.iter().any(|t| t.contains("Title")));
        assert!(texts.iter().any(|t| t.contains("intro text")));
        assert!(texts.iter().any(|t| t.contains("Final paragraph")));

        // The LaTeX region was re-extracted: the prose content from
        // \emph{Hello} and \textbf{bold} should appear in ranges.
        assert!(
            texts.iter().any(|t| t.contains("Hello")),
            "expected LaTeX extractor to produce range containing 'Hello', got: {texts:?}"
        );

        Ok(())
    }

    // An unrecognized `type:` value removes the region's base ranges but
    // performs no re-extraction (see SUPPORTED_LANGUAGE_IDS check).
    #[test]
    fn type_override_unknown_skipped() -> Result<()> {
        let text = "\
# Title

<!-- lang-check-begin type:foobar -->
Some content here.
<!-- lang-check-end -->

Trailing text.";

        let ranges = extract_with_fallback(text, "markdown", None, None, &LatexExtras::default())?;

        let texts: Vec<&str> = ranges
            .iter()
            .map(|r| &text[r.start_byte..r.end_byte])
            .collect();

        // Surrounding ranges preserved.
        assert!(texts.iter().any(|t| t.contains("Title")));
        assert!(texts.iter().any(|t| t.contains("Trailing text")));

        // The unknown-type region's base ranges were filtered out, and no
        // re-extraction happened, so "Some content" should be absent.
        assert!(
            !texts.iter().any(|t| t.contains("Some content")),
            "expected unknown type region to be skipped, got: {texts:?}"
        );

        Ok(())
    }

    // Ranges before and after a type-override region must be untouched.
    #[test]
    fn type_override_preserves_surrounding() -> Result<()> {
        let text = "\
First paragraph before.

<!-- lang-check-begin type:latex -->
\\section{Test}
Some LaTeX prose.
<!-- lang-check-end -->

Last paragraph after.";

        let ranges = extract_with_fallback(text, "markdown", None, None, &LatexExtras::default())?;

        let texts: Vec<&str> = ranges
            .iter()
            .map(|r| &text[r.start_byte..r.end_byte])
            .collect();

        // Both surrounding paragraphs must be present and unmodified.
        assert!(
            texts.iter().any(|t| t.contains("First paragraph before")),
            "pre-region range missing: {texts:?}"
        );
        assert!(
            texts.iter().any(|t| t.contains("Last paragraph after")),
            "post-region range missing: {texts:?}"
        );

        Ok(())
    }

    #[test]
    fn strip_unmatched_orphan_close() {
        let mut bytes = b"hello } world".to_vec();
        strip_unmatched_brackets(&mut bytes);
        assert_eq!(&bytes, b"hello   world");
    }

    #[test]
    fn strip_unmatched_orphan_open() {
        let mut bytes = b"hello ( world".to_vec();
        strip_unmatched_brackets(&mut bytes);
        assert_eq!(&bytes, b"hello   world");
    }

    #[test]
    fn strip_unmatched_preserves_matched() {
        let mut bytes = b"f(x) and [y]".to_vec();
        strip_unmatched_brackets(&mut bytes);
        assert_eq!(&bytes, b"f(x) and [y]");
    }

    #[test]
    fn strip_unmatched_mixed() {
        // '}' is unmatched, '(x)' is matched
        let mut bytes = b"value } is f(x)".to_vec();
        strip_unmatched_brackets(&mut bytes);
        assert_eq!(&bytes, b"value   is f(x)");
    }

    #[test]
    fn strip_unmatched_via_extract_text() {
        let range = ProseRange {
            start_byte: 0,
            end_byte: 20,
            exclusions: vec![(5, 10)],
        };
        // Blanking exclusion [5,10) of "text #{x+y} rest____" wipes "#{x+y"
        // but leaves the orphan '}' at index 10; strip_unmatched_brackets
        // then replaces that unmatched brace with a space.
        let text = "text #{x+y} rest____";
        let clean = range.extract_text(text);
        // The #{x+y was blanked and the surviving '}' stripped as unmatched.
        assert!(!clean.contains('#'));
        assert!(!clean.contains('{'));
        assert!(!clean.contains('}'));
    }
}