1use anyhow::Result;
2use regex::Regex;
3use serde::Deserialize;
4use std::collections::BTreeSet;
5use std::path::Path;
6
7use crate::prose::ProseRange;
8
/// Directory, relative to a workspace root, that
/// [`SchemaRegistry::from_workspace`] scans for schema files.
pub const DEFAULT_SCHEMA_DIR: &str = ".langcheck/schemas";
10
11#[derive(Debug, Deserialize, Clone)]
16pub struct LanguageSchema {
17 pub name: String,
19 #[serde(default)]
21 pub extensions: Vec<String>,
22 #[serde(default)]
24 pub prose_patterns: Vec<PatternRule>,
25 #[serde(default)]
27 pub skip_patterns: Vec<PatternRule>,
28 #[serde(default)]
30 pub skip_blocks: Vec<BlockRule>,
31}
32
33#[derive(Debug, Deserialize, Clone)]
35pub struct PatternRule {
36 pub pattern: String,
38}
39
40#[derive(Debug, Deserialize, Clone)]
42pub struct BlockRule {
43 pub start: String,
45 pub end: String,
47}
48
49#[derive(Debug)]
51pub struct CompiledSchema {
52 pub name: String,
53 pub extensions: Vec<String>,
54 prose_patterns: Vec<Regex>,
55 skip_patterns: Vec<Regex>,
56 skip_blocks: Vec<(Regex, Regex)>,
57}
58
59impl CompiledSchema {
60 pub fn compile(schema: &LanguageSchema) -> Result<Self> {
62 let prose_patterns: Result<Vec<_>> = schema
63 .prose_patterns
64 .iter()
65 .map(|p| Regex::new(&p.pattern).map_err(Into::into))
66 .collect();
67
68 let skip_patterns: Result<Vec<_>> = schema
69 .skip_patterns
70 .iter()
71 .map(|p| Regex::new(&p.pattern).map_err(Into::into))
72 .collect();
73
74 let skip_blocks: Result<Vec<_>> = schema
75 .skip_blocks
76 .iter()
77 .map(|b| Ok((Regex::new(&b.start)?, Regex::new(&b.end)?)))
78 .collect();
79
80 Ok(Self {
81 name: schema.name.clone(),
82 extensions: schema.extensions.clone(),
83 prose_patterns: prose_patterns?,
84 skip_patterns: skip_patterns?,
85 skip_blocks: skip_blocks?,
86 })
87 }
88
89 #[must_use]
98 pub fn extract(&self, text: &str) -> Vec<ProseRange> {
99 let skip_regions = self.find_skip_blocks(text);
100 let mut prose_lines: Vec<(usize, usize)> = Vec::new();
101
102 let mut offset = 0;
103 for line in text.split('\n') {
104 let line_start = offset;
105 let line_end = offset + line.len();
106 offset = line_end + 1; if skip_regions
110 .iter()
111 .any(|(s, e)| line_start >= *s && line_start < *e)
112 {
113 continue;
114 }
115
116 if self.skip_patterns.iter().any(|re| re.is_match(line)) {
118 continue;
119 }
120
121 if line.trim().is_empty() {
123 continue;
124 }
125
126 if !self.prose_patterns.is_empty()
128 && !self.prose_patterns.iter().any(|re| re.is_match(line))
129 {
130 continue;
131 }
132
133 prose_lines.push((line_start, line_end));
134 }
135
136 merge_ranges(prose_lines)
138 }
139
140 fn find_skip_blocks(&self, text: &str) -> Vec<(usize, usize)> {
142 let mut regions = Vec::new();
143
144 for (start_re, end_re) in &self.skip_blocks {
145 let lines: Vec<(usize, &str)> = text
146 .split('\n')
147 .scan(0usize, |offset, line| {
148 let start = *offset;
149 *offset += line.len() + 1;
150 Some((start, line))
151 })
152 .collect();
153
154 let mut i = 0;
155 while i < lines.len() {
156 let (line_start, line) = lines[i];
157 if start_re.is_match(line) {
158 let mut block_end = text.len();
160 for &(_, inner_line) in &lines[i + 1..] {
161 if end_re.is_match(inner_line) {
162 let inner_end = inner_line.as_ptr() as usize - text.as_ptr() as usize
164 + inner_line.len();
165 block_end = inner_end;
166 i = lines
168 .iter()
169 .position(|&(s, _)| s >= block_end)
170 .unwrap_or(lines.len());
171 break;
172 }
173 }
174 regions.push((line_start, block_end));
175 continue;
176 }
177 i += 1;
178 }
179 }
180
181 regions
182 }
183}
184
185fn merge_ranges(mut ranges: Vec<(usize, usize)>) -> Vec<ProseRange> {
187 if ranges.is_empty() {
188 return Vec::new();
189 }
190
191 ranges.sort_by_key(|(s, _)| *s);
192 let mut merged = Vec::new();
193 let (mut cur_start, mut cur_end) = ranges[0];
194
195 for &(start, end) in &ranges[1..] {
196 if start <= cur_end + 2 {
198 cur_end = cur_end.max(end);
199 } else {
200 merged.push(ProseRange {
201 start_byte: cur_start,
202 end_byte: cur_end,
203 exclusions: vec![],
204 });
205 cur_start = start;
206 cur_end = end;
207 }
208 }
209 merged.push(ProseRange {
210 start_byte: cur_start,
211 end_byte: cur_end,
212 exclusions: vec![],
213 });
214
215 merged
216}
217
218#[derive(Debug, Default)]
220pub struct SchemaRegistry {
221 schemas: Vec<CompiledSchema>,
222}
223
224impl SchemaRegistry {
225 #[must_use]
226 pub fn new() -> Self {
227 Self::default()
228 }
229
230 pub fn load_yaml(&mut self, yaml: &str) -> Result<()> {
232 let schema: LanguageSchema = serde_yaml::from_str(yaml)?;
233 let compiled = CompiledSchema::compile(&schema)?;
234 self.schemas.push(compiled);
235 Ok(())
236 }
237
238 pub fn load_file(&mut self, path: &std::path::Path) -> Result<()> {
240 let content = std::fs::read_to_string(path)?;
241 self.load_yaml(&content)
242 }
243
244 pub fn load_dir(&mut self, dir: &std::path::Path) -> Result<usize> {
246 let mut count = 0;
247 if !dir.exists() {
248 return Ok(0);
249 }
250 for entry in std::fs::read_dir(dir)? {
251 let entry = entry?;
252 let path = entry.path();
253 if let Some(ext) = path.extension().and_then(|e| e.to_str())
254 && (ext == "yaml" || ext == "yml")
255 {
256 self.load_file(&path)?;
257 count += 1;
258 }
259 }
260 Ok(count)
261 }
262
263 pub fn from_workspace(workspace_root: &Path) -> Result<Self> {
265 let mut registry = Self::new();
266 registry.load_dir(&workspace_root.join(DEFAULT_SCHEMA_DIR))?;
267 Ok(registry)
268 }
269
270 #[must_use]
272 pub fn find_by_extension(&self, ext: &str) -> Option<&CompiledSchema> {
273 self.schemas
274 .iter()
275 .find(|s| s.extensions.iter().any(|e| e == ext))
276 }
277
278 #[must_use]
280 pub const fn len(&self) -> usize {
281 self.schemas.len()
282 }
283
284 #[must_use]
286 pub const fn is_empty(&self) -> bool {
287 self.schemas.is_empty()
288 }
289
290 #[must_use]
292 pub fn fallback_file_patterns(&self) -> Vec<(String, String)> {
293 let mut patterns = BTreeSet::new();
294
295 for schema in &self.schemas {
296 for ext in &schema.extensions {
297 if crate::languages::builtin_language_for_extension(ext).is_none() {
298 patterns.insert((format!("**/*.{ext}"), schema.name.clone()));
299 }
300 }
301 }
302
303 patterns.into_iter().collect()
304 }
305}
306
#[cfg(test)]
mod tests {
    use super::*;

    /// Schema for reStructuredText documents.
    const RST_SCHEMA: &str = r#"
name: restructuredtext
extensions:
  - rst
  - rest
prose_patterns:
  - pattern: "^[^\\s\\.\\:].*\\S"
skip_patterns:
  - pattern: "^\\.\\."
  - pattern: "^\\s*$"
  - pattern: "^[=\\-~`:'\"^_*+#]{3,}$"
skip_blocks:
  - start: "^::\\s*$"
    end: "^\\S"
"#;

    /// Schema whose skip patterns cover comments, tables and key/value lines.
    const TOML_SCHEMA: &str = r#"
name: toml
extensions:
  - toml
prose_patterns: []
skip_patterns:
  - pattern: "^\\s*#"
  - pattern: "^\\s*\\["
  - pattern: "^\\s*\\w+\\s*="
skip_blocks: []
"#;

    /// Deserializes `yaml` and compiles it, panicking on any failure.
    fn compiled(yaml: &str) -> CompiledSchema {
        let schema: LanguageSchema = serde_yaml::from_str(yaml).unwrap();
        CompiledSchema::compile(&schema).unwrap()
    }

    /// Runs `extract` and resolves every returned range to its slice of `text`.
    fn prose_slices<'t>(schema: &CompiledSchema, text: &'t str) -> Vec<&'t str> {
        schema
            .extract(text)
            .iter()
            .map(|range| &text[range.start_byte..range.end_byte])
            .collect()
    }

    /// Whether any of `slices` contains `needle`.
    fn any_contains(slices: &[&str], needle: &str) -> bool {
        slices.iter().any(|slice| slice.contains(needle))
    }

    #[test]
    fn compile_rst_schema() {
        let schema = compiled(RST_SCHEMA);
        assert_eq!(schema.name, "restructuredtext");
        assert_eq!(schema.extensions, vec!["rst", "rest"]);
    }

    #[test]
    fn rst_extract_prose() {
        let schema = compiled(RST_SCHEMA);
        let text = "Title\n=====\n\nThis is a paragraph.\n\n.. note::\n\n This is a directive.\n\nAnother paragraph here.";

        let extracted = prose_slices(&schema, text);

        assert!(any_contains(&extracted, "This is a paragraph"));
        assert!(any_contains(&extracted, "Another paragraph"));
        assert!(!any_contains(&extracted, ".. note"));
    }

    #[test]
    fn toml_no_prose_patterns_means_all_non_skipped() {
        let schema = compiled(TOML_SCHEMA);
        let text = "# Comment\n[section]\nkey = \"value\"";

        // Every line is hit by one of the skip patterns.
        assert!(schema.extract(text).is_empty());
    }

    #[test]
    fn skip_blocks() {
        let schema = compiled(
            r#"
name: test
extensions: [test]
prose_patterns: []
skip_patterns: []
skip_blocks:
  - start: "^```"
    end: "^```"
"#,
        );
        let text = "Prose line one\n```\ncode here\nmore code\n```\nProse line two";

        let extracted = prose_slices(&schema, text);

        assert!(any_contains(&extracted, "Prose line one"));
        assert!(any_contains(&extracted, "Prose line two"));
        assert!(!any_contains(&extracted, "code here"));
    }

    #[test]
    fn schema_registry_lookup() {
        let mut registry = SchemaRegistry::new();
        for yaml in [RST_SCHEMA, TOML_SCHEMA] {
            registry.load_yaml(yaml).unwrap();
        }
        assert_eq!(registry.len(), 2);

        let rst = registry.find_by_extension("rst");
        assert!(rst.is_some());
        assert_eq!(rst.unwrap().name, "restructuredtext");

        let toml = registry.find_by_extension("toml");
        assert!(toml.is_some());
        assert_eq!(toml.unwrap().name, "toml");

        assert!(registry.find_by_extension("py").is_none());
    }

    #[test]
    fn merge_adjacent_ranges() {
        let merged = merge_ranges(vec![(0, 5), (6, 10), (11, 15)]);

        assert_eq!(merged.len(), 1);
        assert_eq!(merged[0].start_byte, 0);
        assert_eq!(merged[0].end_byte, 15);
    }

    #[test]
    fn no_merge_for_distant_ranges() {
        let merged = merge_ranges(vec![(0, 5), (20, 25)]);
        assert_eq!(merged.len(), 2);
    }

    #[test]
    fn empty_text() {
        assert!(compiled(RST_SCHEMA).extract("").is_empty());
    }

    #[test]
    fn invalid_regex_returns_error() {
        let yaml = r#"
name: bad
extensions: [bad]
prose_patterns:
  - pattern: "[invalid"
"#;
        // Deserialization succeeds; only the regex compilation must fail.
        let schema: LanguageSchema = serde_yaml::from_str(yaml).unwrap();
        assert!(CompiledSchema::compile(&schema).is_err());
    }

    #[test]
    fn fallback_file_patterns_skip_builtins() {
        let mut registry = SchemaRegistry::new();
        registry.load_yaml(RST_SCHEMA).unwrap();
        registry
            .load_yaml(
                r#"
name: asciidoc
extensions: [adoc, asciidoc]
prose_patterns: []
skip_patterns: []
skip_blocks: []
"#,
            )
            .unwrap();

        let patterns = registry.fallback_file_patterns();
        let has = |glob: &str, language: &str| {
            patterns
                .iter()
                .any(|(pattern, lang)| pattern == glob && lang == language)
        };

        assert!(!patterns.iter().any(|(pattern, _)| pattern == "**/*.rst"));
        assert!(has("**/*.adoc", "asciidoc"));
        assert!(has("**/*.asciidoc", "asciidoc"));
    }
}