Skip to main content

lang_check/
sls.rs

use anyhow::{Context, Result};
use regex::Regex;
use serde::Deserialize;
use std::collections::BTreeSet;
use std::path::Path;

use crate::prose::ProseRange;
8
/// Workspace-relative directory from which `SchemaRegistry::from_workspace`
/// loads `.yaml`/`.yml` schema files.
pub const DEFAULT_SCHEMA_DIR: &str = ".langcheck/schemas";
10
/// A Simplified Language Schema (SLS) definition, loaded from YAML.
///
/// Defines how to extract prose regions from a file format using regex patterns,
/// for languages that don't have tree-sitter grammars (e.g. RST, `AsciiDoc`, TOML).
///
/// Only `name` is required; every list field defaults to empty when omitted.
#[derive(Debug, Deserialize, Clone)]
pub struct LanguageSchema {
    /// Schema name (e.g. "restructuredtext").
    pub name: String,
    /// File extensions this schema handles, without the leading dot
    /// (e.g. `["rst", "rest"]`).
    #[serde(default)]
    pub extensions: Vec<String>,
    /// Patterns that match lines containing prose text. When empty, every
    /// non-blank line not excluded by `skip_patterns`/`skip_blocks` is prose.
    #[serde(default)]
    pub prose_patterns: Vec<PatternRule>,
    /// Patterns that match lines to skip (comments, directives, code, etc.).
    /// A line matching any skip pattern is never prose, even if it also
    /// matches a prose pattern.
    #[serde(default)]
    pub skip_patterns: Vec<PatternRule>,
    /// Block delimiters for multi-line regions to skip entirely.
    #[serde(default)]
    pub skip_blocks: Vec<BlockRule>,
}
32
/// A single-line regex pattern rule.
#[derive(Debug, Deserialize, Clone)]
pub struct PatternRule {
    /// The regex pattern to match against each line (without its trailing
    /// newline). Matching uses `Regex::is_match`, so the pattern can match
    /// anywhere in the line unless anchored with `^`/`$`.
    pub pattern: String,
}
39
/// A block delimiter pair for regions to skip.
///
/// A block opens on any line matching `start`. The closing delimiter is
/// searched for starting on the following line, so `start` and `end` may be
/// the same pattern (e.g. code fences), and the skipped region includes the
/// closing line. If no later line matches `end`, the block extends to the end
/// of the text.
#[derive(Debug, Deserialize, Clone)]
pub struct BlockRule {
    /// Regex matching the start of the block.
    pub start: String,
    /// Regex matching the end of the block.
    pub end: String,
}
48
/// Compiled version of a `LanguageSchema`, ready for fast matching.
///
/// Built by `CompiledSchema::compile`, which copies `name` and `extensions`
/// verbatim and compiles every pattern string into a `Regex`.
#[derive(Debug)]
pub struct CompiledSchema {
    /// Schema name, copied from the source `LanguageSchema`.
    pub name: String,
    /// File extensions (as listed in the schema, e.g. `rst`) this schema handles.
    pub extensions: Vec<String>,
    /// Compiled `prose_patterns`; when empty, every non-skipped, non-blank line is prose.
    prose_patterns: Vec<Regex>,
    /// Compiled `skip_patterns`; a line matching any of these is never prose.
    skip_patterns: Vec<Regex>,
    /// Compiled `(start, end)` delimiter pairs from `skip_blocks`.
    skip_blocks: Vec<(Regex, Regex)>,
}
58
59impl CompiledSchema {
60    /// Compile a schema from its YAML definition.
61    pub fn compile(schema: &LanguageSchema) -> Result<Self> {
62        let prose_patterns: Result<Vec<_>> = schema
63            .prose_patterns
64            .iter()
65            .map(|p| Regex::new(&p.pattern).map_err(Into::into))
66            .collect();
67
68        let skip_patterns: Result<Vec<_>> = schema
69            .skip_patterns
70            .iter()
71            .map(|p| Regex::new(&p.pattern).map_err(Into::into))
72            .collect();
73
74        let skip_blocks: Result<Vec<_>> = schema
75            .skip_blocks
76            .iter()
77            .map(|b| Ok((Regex::new(&b.start)?, Regex::new(&b.end)?)))
78            .collect();
79
80        Ok(Self {
81            name: schema.name.clone(),
82            extensions: schema.extensions.clone(),
83            prose_patterns: prose_patterns?,
84            skip_patterns: skip_patterns?,
85            skip_blocks: skip_blocks?,
86        })
87    }
88
89    /// Extract prose ranges from the given text.
90    ///
91    /// Strategy:
92    /// 1. First, identify skip-block regions and mark them as excluded.
93    /// 2. For each line, check if it matches a skip pattern (excluded).
94    /// 3. For remaining lines, check if they match a prose pattern (included).
95    /// 4. If no prose patterns are defined, all non-skipped lines are prose.
96    /// 5. Merge adjacent prose ranges.
97    #[must_use]
98    pub fn extract(&self, text: &str) -> Vec<ProseRange> {
99        let skip_regions = self.find_skip_blocks(text);
100        let mut prose_lines: Vec<(usize, usize)> = Vec::new();
101
102        let mut offset = 0;
103        for line in text.split('\n') {
104            let line_start = offset;
105            let line_end = offset + line.len();
106            offset = line_end + 1; // +1 for newline
107
108            // Skip if inside a skip block
109            if skip_regions
110                .iter()
111                .any(|(s, e)| line_start >= *s && line_start < *e)
112            {
113                continue;
114            }
115
116            // Skip if matches a skip pattern
117            if self.skip_patterns.iter().any(|re| re.is_match(line)) {
118                continue;
119            }
120
121            // Skip empty lines
122            if line.trim().is_empty() {
123                continue;
124            }
125
126            // If prose patterns are defined, line must match at least one
127            if !self.prose_patterns.is_empty()
128                && !self.prose_patterns.iter().any(|re| re.is_match(line))
129            {
130                continue;
131            }
132
133            prose_lines.push((line_start, line_end));
134        }
135
136        // Merge adjacent/contiguous ranges
137        merge_ranges(prose_lines)
138    }
139
140    /// Find byte ranges of skip blocks in the text.
141    fn find_skip_blocks(&self, text: &str) -> Vec<(usize, usize)> {
142        let mut regions = Vec::new();
143
144        for (start_re, end_re) in &self.skip_blocks {
145            let lines: Vec<(usize, &str)> = text
146                .split('\n')
147                .scan(0usize, |offset, line| {
148                    let start = *offset;
149                    *offset += line.len() + 1;
150                    Some((start, line))
151                })
152                .collect();
153
154            let mut i = 0;
155            while i < lines.len() {
156                let (line_start, line) = lines[i];
157                if start_re.is_match(line) {
158                    // Find the matching end, starting from the NEXT line
159                    let mut block_end = text.len();
160                    for &(_, inner_line) in &lines[i + 1..] {
161                        if end_re.is_match(inner_line) {
162                            // End includes the closing delimiter line
163                            let inner_end = inner_line.as_ptr() as usize - text.as_ptr() as usize
164                                + inner_line.len();
165                            block_end = inner_end;
166                            // Skip past the end delimiter
167                            i = lines
168                                .iter()
169                                .position(|&(s, _)| s >= block_end)
170                                .unwrap_or(lines.len());
171                            break;
172                        }
173                    }
174                    regions.push((line_start, block_end));
175                    continue;
176                }
177                i += 1;
178            }
179        }
180
181        regions
182    }
183}
184
185/// Merge contiguous or overlapping byte ranges into larger ones.
186fn merge_ranges(mut ranges: Vec<(usize, usize)>) -> Vec<ProseRange> {
187    if ranges.is_empty() {
188        return Vec::new();
189    }
190
191    ranges.sort_by_key(|(s, _)| *s);
192    let mut merged = Vec::new();
193    let (mut cur_start, mut cur_end) = ranges[0];
194
195    for &(start, end) in &ranges[1..] {
196        // If this range is adjacent (within 1 byte for newline) or overlapping, extend
197        if start <= cur_end + 2 {
198            cur_end = cur_end.max(end);
199        } else {
200            merged.push(ProseRange {
201                start_byte: cur_start,
202                end_byte: cur_end,
203                exclusions: vec![],
204            });
205            cur_start = start;
206            cur_end = end;
207        }
208    }
209    merged.push(ProseRange {
210        start_byte: cur_start,
211        end_byte: cur_end,
212        exclusions: vec![],
213    });
214
215    merged
216}
217
/// Registry of compiled schemas for looking up by file extension.
///
/// Schemas are kept in registration order; `find_by_extension` returns the
/// first registered schema that lists the requested extension.
#[derive(Debug, Default)]
pub struct SchemaRegistry {
    /// Compiled schemas, in the order they were loaded.
    schemas: Vec<CompiledSchema>,
}
223
224impl SchemaRegistry {
225    #[must_use]
226    pub fn new() -> Self {
227        Self::default()
228    }
229
230    /// Load and compile a schema from a YAML string.
231    pub fn load_yaml(&mut self, yaml: &str) -> Result<()> {
232        let schema: LanguageSchema = serde_yaml::from_str(yaml)?;
233        let compiled = CompiledSchema::compile(&schema)?;
234        self.schemas.push(compiled);
235        Ok(())
236    }
237
238    /// Load and compile a schema from a YAML file.
239    pub fn load_file(&mut self, path: &std::path::Path) -> Result<()> {
240        let content = std::fs::read_to_string(path)?;
241        self.load_yaml(&content)
242    }
243
244    /// Load all `.yaml`/`.yml` schemas from a directory.
245    pub fn load_dir(&mut self, dir: &std::path::Path) -> Result<usize> {
246        let mut count = 0;
247        if !dir.exists() {
248            return Ok(0);
249        }
250        for entry in std::fs::read_dir(dir)? {
251            let entry = entry?;
252            let path = entry.path();
253            if let Some(ext) = path.extension().and_then(|e| e.to_str())
254                && (ext == "yaml" || ext == "yml")
255            {
256                self.load_file(&path)?;
257                count += 1;
258            }
259        }
260        Ok(count)
261    }
262
263    /// Load all workspace schemas from the default config directory.
264    pub fn from_workspace(workspace_root: &Path) -> Result<Self> {
265        let mut registry = Self::new();
266        registry.load_dir(&workspace_root.join(DEFAULT_SCHEMA_DIR))?;
267        Ok(registry)
268    }
269
270    /// Find a compiled schema by file extension.
271    #[must_use]
272    pub fn find_by_extension(&self, ext: &str) -> Option<&CompiledSchema> {
273        self.schemas
274            .iter()
275            .find(|s| s.extensions.iter().any(|e| e == ext))
276    }
277
278    /// Number of loaded schemas.
279    #[must_use]
280    pub const fn len(&self) -> usize {
281        self.schemas.len()
282    }
283
284    /// Whether the registry is empty.
285    #[must_use]
286    pub const fn is_empty(&self) -> bool {
287        self.schemas.is_empty()
288    }
289
290    /// Glob patterns for extensions handled only by SLS, preserving built-in precedence.
291    #[must_use]
292    pub fn fallback_file_patterns(&self) -> Vec<(String, String)> {
293        let mut patterns = BTreeSet::new();
294
295        for schema in &self.schemas {
296            for ext in &schema.extensions {
297                if crate::languages::builtin_language_for_extension(ext).is_none() {
298                    patterns.insert((format!("**/*.{ext}"), schema.name.clone()));
299                }
300            }
301        }
302
303        patterns.into_iter().collect()
304    }
305}
306
#[cfg(test)]
mod tests {
    // Unit tests for schema compilation, prose extraction, range merging, and
    // the schema registry.
    use super::*;

    const RST_SCHEMA: &str = r#"
name: restructuredtext
extensions:
  - rst
  - rest
prose_patterns:
  - pattern: "^[^\\s\\.\\:].*\\S"
skip_patterns:
  - pattern: "^\\.\\."
  - pattern: "^\\s*$"
  - pattern: "^[=\\-~`:'\"^_*+#]{3,}$"
skip_blocks:
  - start: "^::\\s*$"
    end: "^\\S"
"#;

    const TOML_SCHEMA: &str = r#"
name: toml
extensions:
  - toml
prose_patterns: []
skip_patterns:
  - pattern: "^\\s*#"
  - pattern: "^\\s*\\["
  - pattern: "^\\s*\\w+\\s*="
skip_blocks: []
"#;

    #[test]
    fn compile_rst_schema() {
        let schema: LanguageSchema = serde_yaml::from_str(RST_SCHEMA).unwrap();
        let compiled = CompiledSchema::compile(&schema).unwrap();
        assert_eq!(compiled.name, "restructuredtext");
        assert_eq!(compiled.extensions, vec!["rst", "rest"]);
    }

    #[test]
    fn rst_extract_prose() {
        let schema: LanguageSchema = serde_yaml::from_str(RST_SCHEMA).unwrap();
        let compiled = CompiledSchema::compile(&schema).unwrap();

        let text = "Title\n=====\n\nThis is a paragraph.\n\n.. note::\n\n   This is a directive.\n\nAnother paragraph here.";
        let ranges = compiled.extract(text);

        let extracted: Vec<&str> = ranges
            .iter()
            .map(|r| &text[r.start_byte..r.end_byte])
            .collect();
        assert!(extracted.iter().any(|t| t.contains("This is a paragraph")));
        assert!(extracted.iter().any(|t| t.contains("Another paragraph")));
        // Directive content should be excluded via skip pattern
        assert!(!extracted.iter().any(|t| t.contains(".. note")));
    }

    #[test]
    fn toml_no_prose_patterns_means_all_non_skipped() {
        let schema: LanguageSchema = serde_yaml::from_str(TOML_SCHEMA).unwrap();
        let compiled = CompiledSchema::compile(&schema).unwrap();

        // TOML with no prose_patterns and all lines matching skip patterns
        let text = "# Comment\n[section]\nkey = \"value\"";
        let ranges = compiled.extract(text);
        // All lines match skip patterns, so no prose
        assert!(ranges.is_empty());

        // With no prose_patterns, a non-blank line that escapes every skip
        // pattern is reported as prose.
        let text = "# Comment\nPlain prose line\nkey = 1";
        let ranges = compiled.extract(text);
        assert_eq!(ranges.len(), 1);
        assert_eq!(
            &text[ranges[0].start_byte..ranges[0].end_byte],
            "Plain prose line"
        );
    }

    #[test]
    fn skip_blocks() {
        let yaml = r#"
name: test
extensions: [test]
prose_patterns: []
skip_patterns: []
skip_blocks:
  - start: "^```"
    end: "^```"
"#;
        let schema: LanguageSchema = serde_yaml::from_str(yaml).unwrap();
        let compiled = CompiledSchema::compile(&schema).unwrap();

        let text = "Prose line one\n```\ncode here\nmore code\n```\nProse line two";
        let ranges = compiled.extract(text);

        let extracted: Vec<&str> = ranges
            .iter()
            .map(|r| &text[r.start_byte..r.end_byte])
            .collect();
        assert!(extracted.iter().any(|t| t.contains("Prose line one")));
        assert!(extracted.iter().any(|t| t.contains("Prose line two")));
        assert!(!extracted.iter().any(|t| t.contains("code here")));
    }

    #[test]
    fn schema_registry_lookup() {
        let mut registry = SchemaRegistry::new();
        registry.load_yaml(RST_SCHEMA).unwrap();
        registry.load_yaml(TOML_SCHEMA).unwrap();
        assert_eq!(registry.len(), 2);

        let rst = registry.find_by_extension("rst");
        assert!(rst.is_some());
        assert_eq!(rst.unwrap().name, "restructuredtext");

        let toml = registry.find_by_extension("toml");
        assert!(toml.is_some());
        assert_eq!(toml.unwrap().name, "toml");

        assert!(registry.find_by_extension("py").is_none());
    }

    #[test]
    fn merge_adjacent_ranges() {
        let ranges = vec![(0, 5), (6, 10), (11, 15)];
        let merged = merge_ranges(ranges);
        // All within 2 bytes of each other, should merge to one
        assert_eq!(merged.len(), 1);
        assert_eq!(merged[0].start_byte, 0);
        assert_eq!(merged[0].end_byte, 15);
    }

    #[test]
    fn no_merge_for_distant_ranges() {
        let ranges = vec![(0, 5), (20, 25)];
        let merged = merge_ranges(ranges);
        assert_eq!(merged.len(), 2);
    }

    #[test]
    fn empty_text() {
        let schema: LanguageSchema = serde_yaml::from_str(RST_SCHEMA).unwrap();
        let compiled = CompiledSchema::compile(&schema).unwrap();
        let ranges = compiled.extract("");
        assert!(ranges.is_empty());
    }

    #[test]
    fn invalid_regex_returns_error() {
        let yaml = r#"
name: bad
extensions: [bad]
prose_patterns:
  - pattern: "[invalid"
"#;
        let schema: LanguageSchema = serde_yaml::from_str(yaml).unwrap();
        assert!(CompiledSchema::compile(&schema).is_err());
    }

    #[test]
    fn fallback_file_patterns_skip_builtins() {
        let mut registry = SchemaRegistry::new();
        registry.load_yaml(RST_SCHEMA).unwrap();
        registry
            .load_yaml(
                r#"
name: asciidoc
extensions: [adoc, asciidoc]
prose_patterns: []
skip_patterns: []
skip_blocks: []
"#,
            )
            .unwrap();

        let patterns = registry.fallback_file_patterns();

        assert!(!patterns.iter().any(|(pattern, _)| pattern == "**/*.rst"));
        assert!(
            patterns
                .iter()
                .any(|(pattern, lang)| pattern == "**/*.adoc" && lang == "asciidoc")
        );
        assert!(
            patterns
                .iter()
                .any(|(pattern, lang)| pattern == "**/*.asciidoc" && lang == "asciidoc")
        );
    }
}