mod bibtex;
mod forester;
pub mod latex;
mod org;
mod query;
mod rst;
mod shared;
mod sweave;
mod tinylang;
mod typst;
use anyhow::{Result, anyhow};
use std::ops::Range;
use std::path::Path;
use tree_sitter::{Language, Parser};
use crate::ignore_rules::{DirectiveRegion, IgnoreParser};
use crate::sls::SchemaRegistry;
pub struct ProseExtractor {
parser: Parser,
language: Language,
}
impl ProseExtractor {
pub fn new(language: Language) -> Result<Self> {
let mut parser = Parser::new();
parser.set_language(&language)?;
Ok(Self { parser, language })
}
pub fn extract(
&mut self,
text: &str,
lang_id: &str,
latex_extras: &latex::LatexExtras,
) -> Result<Vec<ProseRange>> {
let tree = self
.parser
.parse(text, None)
.ok_or_else(|| anyhow!("Failed to parse text"))?;
let root = tree.root_node();
let ranges = match lang_id {
"latex" => latex::extract(text, root, latex_extras),
"sweave" => sweave::extract(text, root, latex_extras),
"forester" => forester::extract(text, root),
"tinylang" => tinylang::extract(text, root),
"rst" => rst::extract(text, root),
"bibtex" => bibtex::extract(text, root),
"org" => org::extract(text, root),
"typst" => typst::extract(text, root),
lang => query::extract(text, root, &self.language, lang)?,
};
let force_regions = crate::ignore_rules::IgnoreParser::block_regions(text);
Ok(shared::merge_continuations(ranges, text, &force_regions))
}
}
/// Extracts prose ranges from `text`, honouring schema overrides and
/// `type:<lang>` directive regions.
///
/// Resolution order:
/// 1. If `path` has a UTF-8 extension that no built-in language claims, and
///    `schema_registry` has a schema for that extension, the schema's
///    extractor is used and nothing else runs.
/// 2. Otherwise `lang_id` is resolved to a canonical id and the matching
///    tree-sitter extractor runs over the whole text.
/// 3. Ignore-directive regions whose options carry a `doc_type` are then
///    re-extracted with that type's extractor via `apply_type_overrides`.
///
/// # Errors
/// Propagates failures from building or running a [`ProseExtractor`],
/// including the extractors used for override regions.
pub fn extract_with_fallback(
    text: &str,
    lang_id: &str,
    path: Option<&Path>,
    schema_registry: Option<&SchemaRegistry>,
    latex_extras: &latex::LatexExtras,
) -> Result<Vec<ProseRange>> {
    // Schemas only apply to extensions the built-in languages do not own.
    let schema_override = path
        .and_then(|p| p.extension())
        .and_then(|os_ext| os_ext.to_str())
        .filter(|ext| crate::languages::builtin_language_for_extension(*ext).is_none())
        .and_then(|ext| schema_registry.and_then(|reg| reg.find_by_extension(ext)));
    if let Some(schema) = schema_override {
        return Ok(schema.extract(text));
    }

    let canonical_lang = crate::languages::resolve_language_id(lang_id);
    let mut extractor =
        ProseExtractor::new(crate::languages::resolve_ts_language(canonical_lang))?;
    let host_ranges = extractor.extract(text, canonical_lang, latex_extras)?;

    let directives = IgnoreParser::parse_directives(text);
    let resolved = IgnoreParser::resolve_all(text, &directives);
    let typed_regions: Vec<_> = resolved
        .regions
        .iter()
        .filter(|region| region.options.doc_type.is_some())
        .collect();

    if typed_regions.is_empty() {
        Ok(host_ranges)
    } else {
        apply_type_overrides(text, host_ranges, &typed_regions, latex_extras)
    }
}
fn apply_type_overrides(
text: &str,
base_ranges: Vec<ProseRange>,
type_regions: &[&DirectiveRegion],
latex_extras: &latex::LatexExtras,
) -> Result<Vec<ProseRange>> {
let override_spans: Vec<&Range<usize>> = type_regions.iter().map(|r| &r.byte_range).collect();
let mut result: Vec<ProseRange> = base_ranges
.into_iter()
.filter(|r| {
!override_spans
.iter()
.any(|span| span.contains(&r.start_byte))
})
.collect();
for region in type_regions {
let doc_type = region.options.doc_type.as_deref().unwrap();
let canonical = crate::languages::resolve_language_id(doc_type);
if !crate::languages::SUPPORTED_LANGUAGE_IDS.contains(&canonical) {
eprintln!("lang-check: `type:{doc_type}` is not a supported language; skipping region");
continue;
}
let slice = &text[region.byte_range.clone()];
let ts_lang = crate::languages::resolve_ts_language(canonical);
let mut ext = ProseExtractor::new(ts_lang)?;
let sub_ranges = ext.extract(slice, canonical, latex_extras)?;
let offset = region.byte_range.start;
for mut r in sub_ranges {
r.start_byte += offset;
r.end_byte += offset;
r.exclusions = r
.exclusions
.into_iter()
.map(|(s, e)| (s + offset, e + offset))
.collect();
result.push(r);
}
}
result.sort_by_key(|r| r.start_byte);
Ok(result)
}
/// A span of a document that holds checkable prose.
///
/// All offsets, including those in `exclusions`, are absolute byte offsets
/// into the full document text rather than offsets relative to `start_byte`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ProseRange {
    /// Start of the prose span (inclusive).
    pub start_byte: usize,
    /// End of the prose span (exclusive).
    pub end_byte: usize,
    /// Non-prose sub-spans `(start, end)` of the range. `extract_text` blanks
    /// them, and diagnostics near them may be suppressed.
    pub exclusions: Vec<(usize, usize)>,
}
impl ProseRange {
    /// Returns the range's text with every exclusion replaced by ASCII spaces
    /// and unmatched brackets blanked out.
    ///
    /// The returned text always has exactly `end_byte - start_byte` bytes, so
    /// offsets into it map one-to-one back onto the document. With no
    /// exclusions the slice is borrowed unchanged, and brackets are left alone.
    ///
    /// Exclusion bounds are clamped to the range. In release builds, bounds
    /// that fall inside a multi-byte character are widened outward so the
    /// whole character is blanked. The result is therefore always valid UTF-8;
    /// the previous `unsafe` in-place write could produce invalid UTF-8 here.
    ///
    /// # Panics
    /// Panics if `start_byte..end_byte` is not a valid slice of `text`. In
    /// debug builds, it also panics when an exclusion bound is not on a
    /// character boundary, so upstream extractor bugs surface early.
    #[must_use]
    pub fn extract_text<'a>(&self, text: &'a str) -> std::borrow::Cow<'a, str> {
        let slice = &text[self.start_byte..self.end_byte];
        if self.exclusions.is_empty() {
            return std::borrow::Cow::Borrowed(slice);
        }
        #[cfg(debug_assertions)]
        for &(exc_start, exc_end) in &self.exclusions {
            let s = exc_start.saturating_sub(self.start_byte).min(slice.len());
            let e = exc_end.saturating_sub(self.start_byte).min(slice.len());
            debug_assert!(
                slice.is_char_boundary(s) && slice.is_char_boundary(e),
                "exclusion ({s}, {e}) is not on a char boundary in {slice:?}"
            );
        }
        let mut bytes = slice.as_bytes().to_vec();
        for &(exc_start, exc_end) in &self.exclusions {
            let mut lo = exc_start.saturating_sub(self.start_byte).min(slice.len());
            let mut hi = exc_end.saturating_sub(self.start_byte).min(slice.len());
            if lo >= hi {
                continue;
            }
            // Widen to whole characters. Index 0 and slice.len() are always
            // char boundaries, so both loops stay inside the slice.
            while !slice.is_char_boundary(lo) {
                lo -= 1;
            }
            while !slice.is_char_boundary(hi) {
                hi += 1;
            }
            bytes[lo..hi].fill(b' ');
        }
        strip_unmatched_brackets(&mut bytes);
        let blanked = String::from_utf8(bytes).expect(
            "only whole characters and ASCII brackets were replaced with ASCII spaces",
        );
        std::borrow::Cow::Owned(blanked)
    }

    /// Returns whether the span `local_start..local_end`, given relative to
    /// `start_byte`, shares at least one byte with any exclusion.
    ///
    /// The arithmetic is done in `usize`. Earlier code cast document offsets
    /// to `u32`, which truncated or overflowed on very large documents.
    #[must_use]
    pub fn overlaps_exclusion(&self, local_start: u32, local_end: u32) -> bool {
        let doc_start = self.start_byte + local_start as usize;
        let doc_end = self.start_byte + local_end as usize;
        self.exclusions
            .iter()
            .any(|&(exc_start, exc_end)| doc_start < exc_end && doc_end > exc_start)
    }

    /// Classifies how close the span `local_start..local_end` (relative to
    /// `start_byte`) sits to the exclusions, returning the most severe
    /// relation over all of them.
    ///
    /// For an exclusion before the span, the gap runs from the exclusion's end
    /// to the span. For one after it, the gap runs from the span's end to the
    /// exclusion. When the two touch, the exclusion's facing edge character
    /// decides between `Glued` and `WhitespaceAdjacent`.
    #[must_use]
    pub fn exclusion_adjacency(
        &self,
        text: &str,
        local_start: u32,
        local_end: u32,
    ) -> ExclusionAdjacency {
        if self.overlaps_exclusion(local_start, local_end) {
            return ExclusionAdjacency::Overlapping;
        }
        let doc_start = self.start_byte + local_start as usize;
        let doc_end = self.start_byte + local_end as usize;
        let mut best = ExclusionAdjacency::None;
        for &(es, ee) in &self.exclusions {
            let rel = if doc_start >= ee {
                classify_gap(text, ee, doc_start, byte_before_is_whitespace(text, ee))
            } else {
                classify_gap(text, doc_end, es, byte_at_is_whitespace(text, es))
            };
            best = best.max_severity(rel);
            // Overlapping was ruled out above, so Glued is the worst remaining case.
            if best == ExclusionAdjacency::Glued {
                break;
            }
        }
        best
    }

    /// Returns whether a diagnostic of category `unified_id` on the span
    /// `local_start..local_end` should be suppressed.
    ///
    /// Diagnostics overlapping or glued to an exclusion are always suppressed.
    /// Diagnostics separated from one only by whitespace are suppressed unless
    /// they are spelling diagnostics. All others are kept.
    #[must_use]
    pub fn suppresses_diagnostic(
        &self,
        text: &str,
        local_start: u32,
        local_end: u32,
        unified_id: &str,
    ) -> bool {
        match self.exclusion_adjacency(text, local_start, local_end) {
            ExclusionAdjacency::Overlapping | ExclusionAdjacency::Glued => true,
            ExclusionAdjacency::WhitespaceAdjacent => !is_spelling_category(unified_id),
            ExclusionAdjacency::None => false,
        }
    }
}
/// Relation between a diagnostic span and the nearest excluded (non-prose)
/// region of a [`ProseRange`]. Variants are listed from most to least severe.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ExclusionAdjacency {
    /// The span shares at least one byte with an exclusion.
    Overlapping,
    /// The span touches an exclusion whose facing edge is not whitespace.
    Glued,
    /// Only whitespace separates the span from an exclusion, or the
    /// exclusion's touching edge is itself whitespace.
    WhitespaceAdjacent,
    /// No exclusion is reachable across whitespace alone.
    None,
}
impl ExclusionAdjacency {
    /// Severity key; a larger number means a more severe relation.
    const fn rank(self) -> u8 {
        match self {
            Self::Overlapping => 3,
            Self::Glued => 2,
            Self::WhitespaceAdjacent => 1,
            Self::None => 0,
        }
    }

    /// Returns the more severe of `self` and `other`, keeping `self` on a tie.
    #[must_use]
    const fn max_severity(self, other: Self) -> Self {
        if self.rank() >= other.rank() {
            self
        } else {
            other
        }
    }
}
/// Classifies the stretch `lo..hi` of `text` that separates a diagnostic from
/// an exclusion.
///
/// - When the two touch (`lo == hi`), the caller-supplied
///   `skip_edge_is_whitespace` flag picks `WhitespaceAdjacent` over `Glued`.
/// - A non-empty gap made only of whitespace is `WhitespaceAdjacent`.
/// - Anything else is `None`, including a gap that is not a valid `str` slice.
fn classify_gap(
    text: &str,
    lo: usize,
    hi: usize,
    skip_edge_is_whitespace: bool,
) -> ExclusionAdjacency {
    let touching = lo == hi;
    if touching {
        if skip_edge_is_whitespace {
            return ExclusionAdjacency::WhitespaceAdjacent;
        }
        return ExclusionAdjacency::Glued;
    }
    let blank_gap = text
        .get(lo..hi)
        .is_some_and(|gap| gap.chars().all(char::is_whitespace));
    if blank_gap {
        ExclusionAdjacency::WhitespaceAdjacent
    } else {
        ExclusionAdjacency::None
    }
}
/// Returns whether the character ending at byte offset `pos` is whitespace.
///
/// Yields `false` when `pos` is 0, out of bounds, or not a char boundary.
fn byte_before_is_whitespace(text: &str, pos: usize) -> bool {
    let Some(prefix) = text.get(..pos) else {
        return false;
    };
    matches!(prefix.chars().next_back(), Some(last) if last.is_whitespace())
}
/// Returns whether the character starting at byte offset `pos` is whitespace.
///
/// Yields `false` at the end of `text`, out of bounds, or off a char boundary.
fn byte_at_is_whitespace(text: &str, pos: usize) -> bool {
    match text.get(pos..).and_then(|rest| rest.chars().next()) {
        Some(first) => first.is_whitespace(),
        None => false,
    }
}
/// Returns whether `unified_id` names a rule in the `spelling.` category.
#[must_use]
pub fn is_spelling_category(unified_id: &str) -> bool {
    unified_id.strip_prefix("spelling.").is_some()
}
/// Overwrites every unbalanced `()`, `[]` or `{}` byte with an ASCII space.
///
/// Each bracket kind is balanced independently: a closer with no pending
/// opener of its kind is an orphan, and so is every opener still pending at
/// the end. Only ASCII bytes are rewritten, so valid UTF-8 input stays valid
/// and its length is unchanged.
fn strip_unmatched_brackets(bytes: &mut [u8]) {
    const PAIRS: [(u8, u8); 3] = [(b'(', b')'), (b'[', b']'), (b'{', b'}')];
    let mut pending: [Vec<usize>; 3] = Default::default();
    let mut orphans: Vec<usize> = Vec::new();
    for (pos, &byte) in bytes.iter().enumerate() {
        if let Some(kind) = PAIRS.iter().position(|&(open, _)| open == byte) {
            pending[kind].push(pos);
        } else if let Some(kind) = PAIRS.iter().position(|&(_, close)| close == byte) {
            if pending[kind].pop().is_none() {
                orphans.push(pos);
            }
        }
    }
    orphans.extend(pending.iter().flatten().copied());
    for pos in orphans {
        bytes[pos] = b' ';
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use latex::LatexExtras;

    /// `n` ASCII spaces. Blanking preserves byte length, so expectations
    /// spell out exactly how many spaces replace excluded text.
    fn spaces(n: usize) -> String {
        " ".repeat(n)
    }

    #[test]
    fn extract_text_no_exclusions_is_borrowed() {
        let text = "café — touché";
        let range = ProseRange {
            start_byte: 0,
            end_byte: text.len(),
            exclusions: Vec::new(),
        };
        let out = range.extract_text(text);
        assert!(matches!(out, std::borrow::Cow::Borrowed(_)));
        assert_eq!(out, text);
    }

    #[test]
    fn extract_text_blanks_excluded_ascii_keeping_multibyte() {
        let text = "café X tea";
        let x = text.find('X').unwrap();
        let range = ProseRange {
            start_byte: 0,
            end_byte: text.len(),
            exclusions: vec![(x, x + 1)],
        };
        let out = range.extract_text(text);
        // "X" becomes a space, joining the single spaces on either side.
        assert_eq!(out, format!("café{}tea", spaces(3)));
        assert_eq!(out.len(), text.len());
        assert!(std::str::from_utf8(out.as_bytes()).is_ok());
    }

    #[test]
    fn extract_text_blanks_a_whole_multibyte_char() {
        let text = "a—b";
        let dash_start = text.find('—').unwrap();
        let dash_end = dash_start + '—'.len_utf8();
        let range = ProseRange {
            start_byte: 0,
            end_byte: text.len(),
            exclusions: vec![(dash_start, dash_end)],
        };
        let out = range.extract_text(text);
        // Every byte of the em dash becomes a space.
        assert_eq!(out, format!("a{}b", spaces('—'.len_utf8())));
    }

    #[test]
    fn extract_text_handles_document_level_offsets() {
        let text = "PREFIX café — done";
        let start = text.find("café").unwrap();
        let dash = text.find('—').unwrap();
        let range = ProseRange {
            start_byte: start,
            end_byte: text.len(),
            exclusions: vec![(dash, dash + '—'.len_utf8())],
        };
        let out = range.extract_text(text);
        // Two flanking spaces plus the blanked dash bytes.
        assert_eq!(out, format!("café{}done", spaces(2 + '—'.len_utf8())));
    }

    #[test]
    fn test_markdown_extraction() -> Result<()> {
        let language: tree_sitter::Language = tree_sitter_md::LANGUAGE.into();
        let mut extractor = ProseExtractor::new(language)?;
        let text =
            "# Header\n\nThis is a paragraph.\n\n```rust\nfn main() {}\n```\n\nAnother paragraph.";
        let ranges = extractor.extract(text, "markdown", &LatexExtras::default())?;
        assert!(ranges.len() >= 3);
        let extracted_texts: Vec<&str> = ranges
            .iter()
            .map(|r| &text[r.start_byte..r.end_byte])
            .collect();
        assert!(extracted_texts.iter().any(|t| t.contains("Header")));
        assert!(
            extracted_texts
                .iter()
                .any(|t| t.contains("This is a paragraph"))
        );
        assert!(
            extracted_texts
                .iter()
                .any(|t| t.contains("Another paragraph"))
        );
        Ok(())
    }

    #[test]
    fn test_overlaps_exclusion() {
        let range = ProseRange {
            start_byte: 100,
            end_byte: 300,
            exclusions: vec![(150, 200)],
        };
        // Local offsets are relative to start_byte (100), so the exclusion
        // occupies local 50..100.
        assert!(range.overlaps_exclusion(50, 100)); // exactly the exclusion
        assert!(range.overlaps_exclusion(40, 60)); // straddles its start
        assert!(range.overlaps_exclusion(90, 110)); // straddles its end
        assert!(!range.overlaps_exclusion(0, 40)); // entirely before
        assert!(!range.overlaps_exclusion(110, 130)); // entirely after
    }

    #[test]
    fn test_exclusion_adjacency_classifies_position() {
        let text = "a #{i} is b";
        let range = ProseRange {
            start_byte: 0,
            end_byte: text.len(),
            exclusions: vec![(2, 6)],
        };
        assert_eq!(
            range.exclusion_adjacency(text, 7, 9),
            ExclusionAdjacency::WhitespaceAdjacent
        );
        assert_eq!(
            range.exclusion_adjacency(text, 3, 5),
            ExclusionAdjacency::Overlapping
        );
        assert_eq!(
            range.exclusion_adjacency(text, 10, 11),
            ExclusionAdjacency::None
        );
    }

    #[test]
    fn test_exclusion_adjacency_detects_glued_fragment() {
        let text = "#{n}th word";
        let range = ProseRange {
            start_byte: 0,
            end_byte: text.len(),
            exclusions: vec![(0, 4)],
        };
        assert_eq!(
            range.exclusion_adjacency(text, 4, 6),
            ExclusionAdjacency::Glued
        );
    }

    #[test]
    fn test_exclusion_swallowing_flanking_space_is_not_glued() {
        // The exclusion begins with the space after "teh", so the word touches
        // it, but across whitespace.
        let text = "teh #{G} ok";
        let range = ProseRange {
            start_byte: 0,
            end_byte: text.len(),
            exclusions: vec![(3, 6)],
        };
        assert_eq!(
            range.exclusion_adjacency(text, 0, 3),
            ExclusionAdjacency::WhitespaceAdjacent
        );
        assert!(!range.suppresses_diagnostic(text, 0, 3, "spelling.typo"));
        assert!(range.suppresses_diagnostic(text, 0, 3, "typography.capitalization"));
    }

    #[test]
    fn test_suppresses_diagnostic_keeps_spelling_near_skip() {
        let text = "a #{i} wrd b";
        let range = ProseRange {
            start_byte: 0,
            end_byte: text.len(),
            exclusions: vec![(2, 6)],
        };
        assert!(range.suppresses_diagnostic(text, 7, 10, "typography.capitalization"));
        assert!(!range.suppresses_diagnostic(text, 7, 10, "spelling.typo"));
    }

    #[test]
    fn test_suppresses_diagnostic_drops_glued_fragment_spelling() {
        let text = "#{n}th word";
        let range = ProseRange {
            start_byte: 0,
            end_byte: text.len(),
            exclusions: vec![(0, 4)],
        };
        assert!(range.suppresses_diagnostic(text, 4, 6, "spelling.typo"));
        assert!(!range.suppresses_diagnostic(text, 7, 11, "spelling.typo"));
    }

    #[test]
    fn type_override_latex_in_markdown() -> Result<()> {
        let text = "\
# Title
Some intro text.
<!-- lang-check-begin type:latex -->
\\emph{Hello} world and \\textbf{bold} text.
<!-- lang-check-end -->
Final paragraph.";
        let ranges = extract_with_fallback(text, "markdown", None, None, &LatexExtras::default())?;
        let texts: Vec<&str> = ranges
            .iter()
            .map(|r| &text[r.start_byte..r.end_byte])
            .collect();
        assert!(texts.iter().any(|t| t.contains("Title")));
        assert!(texts.iter().any(|t| t.contains("intro text")));
        assert!(texts.iter().any(|t| t.contains("Final paragraph")));
        assert!(
            texts.iter().any(|t| t.contains("Hello")),
            "expected LaTeX extractor to produce range containing 'Hello', got: {texts:?}"
        );
        Ok(())
    }

    #[test]
    fn type_override_unknown_skipped() -> Result<()> {
        let text = "\
# Title
<!-- lang-check-begin type:foobar -->
Some content here.
<!-- lang-check-end -->
Trailing text.";
        let ranges = extract_with_fallback(text, "markdown", None, None, &LatexExtras::default())?;
        let texts: Vec<&str> = ranges
            .iter()
            .map(|r| &text[r.start_byte..r.end_byte])
            .collect();
        assert!(texts.iter().any(|t| t.contains("Title")));
        assert!(texts.iter().any(|t| t.contains("Trailing text")));
        assert!(
            !texts.iter().any(|t| t.contains("Some content")),
            "expected unknown type region to be skipped, got: {texts:?}"
        );
        Ok(())
    }

    #[test]
    fn type_override_preserves_surrounding() -> Result<()> {
        let text = "\
First paragraph before.
<!-- lang-check-begin type:latex -->
\\section{Test}
Some LaTeX prose.
<!-- lang-check-end -->
Last paragraph after.";
        let ranges = extract_with_fallback(text, "markdown", None, None, &LatexExtras::default())?;
        let texts: Vec<&str> = ranges
            .iter()
            .map(|r| &text[r.start_byte..r.end_byte])
            .collect();
        assert!(
            texts.iter().any(|t| t.contains("First paragraph before")),
            "pre-region range missing: {texts:?}"
        );
        assert!(
            texts.iter().any(|t| t.contains("Last paragraph after")),
            "post-region range missing: {texts:?}"
        );
        Ok(())
    }

    #[test]
    fn strip_unmatched_orphan_close() {
        let mut bytes = b"hello } world".to_vec();
        strip_unmatched_brackets(&mut bytes);
        assert_eq!(bytes, format!("hello{}world", spaces(3)).into_bytes());
    }

    #[test]
    fn strip_unmatched_orphan_open() {
        let mut bytes = b"hello ( world".to_vec();
        strip_unmatched_brackets(&mut bytes);
        assert_eq!(bytes, format!("hello{}world", spaces(3)).into_bytes());
    }

    #[test]
    fn strip_unmatched_preserves_matched() {
        let mut bytes = b"f(x) and [y]".to_vec();
        strip_unmatched_brackets(&mut bytes);
        assert_eq!(&bytes, b"f(x) and [y]");
    }

    #[test]
    fn strip_unmatched_mixed() {
        let mut bytes = b"value } is f(x)".to_vec();
        strip_unmatched_brackets(&mut bytes);
        assert_eq!(bytes, format!("value{}is f(x)", spaces(3)).into_bytes());
    }

    #[test]
    fn strip_unmatched_via_extract_text() {
        let text = "text #{x+y} rest____";
        let range = ProseRange {
            start_byte: 0,
            end_byte: 20,
            exclusions: vec![(5, 10)],
        };
        let clean = range.extract_text(text);
        assert_eq!(clean.len(), text.len());
        assert!(!clean.contains('#'));
        assert!(!clean.contains('{'));
        // The closing brace survives the exclusion but is left unmatched.
        assert!(!clean.contains('}'));
    }
}