1use crate::ts::{ParsedSource, RustParser, TreeSitterError};
16use std::path::Path;
17use thiserror::Error;
18
19#[derive(Error, Debug)]
21pub enum ValidationError {
22 #[error("Parse error introduced: found {count} new ERROR nodes")]
23 ParseErrorIntroduced {
24 count: usize,
25 errors: Vec<ErrorLocation>,
26 },
27
28 #[error("Selector matched {count} locations, expected exactly 1")]
29 SelectorNotUnique { count: usize, pattern: String },
30
31 #[error("Selector matched 0 locations")]
32 NoMatch { pattern: String },
33
34 #[error("syn validation failed: {message}")]
35 SynValidationFailed { message: String, code: String },
36
37 #[error("Tree-sitter error: {0}")]
38 TreeSitter(#[from] TreeSitterError),
39
40 #[error("IO error: {0}")]
41 Io(#[from] std::io::Error),
42}
43
/// Position and surrounding text of one tree-sitter ERROR or MISSING node.
///
/// Derives `PartialEq`/`Eq`/`Hash` so locations can be compared, deduplicated,
/// or used as set/map keys.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ErrorLocation {
    /// Byte offset where the node starts (tree-sitter `start_byte`).
    pub byte_start: usize,
    /// Byte offset where the node ends (tree-sitter `end_byte`).
    pub byte_end: usize,
    /// 1-based line of the node's start.
    pub line: usize,
    /// 1-based column of the node's start; tree-sitter counts columns in bytes.
    pub column: usize,
    /// Roughly 20 bytes of source on either side of the node, with newlines
    /// escaped as `\n`.
    pub context: String,
}
53
54pub struct ParseValidator {
56 parser: RustParser,
57}
58
59impl ParseValidator {
60 pub fn new() -> Result<Self, TreeSitterError> {
62 Ok(Self {
63 parser: RustParser::new()?,
64 })
65 }
66
67 pub fn validate(&mut self, source: &str) -> Result<(), ValidationError> {
69 let parsed = self.parser.parse_with_source(source)?;
70 let errors = collect_errors(&parsed, source);
71
72 if !errors.is_empty() {
73 return Err(ValidationError::ParseErrorIntroduced {
74 count: errors.len(),
75 errors,
76 });
77 }
78
79 Ok(())
80 }
81
82 pub fn validate_file(&mut self, path: impl AsRef<Path>) -> Result<(), ValidationError> {
84 let source = std::fs::read_to_string(path)?;
85 self.validate(&source)
86 }
87
88 pub fn validate_edit(&mut self, original: &str, edited: &str) -> Result<(), ValidationError> {
93 let original_parsed = self.parser.parse_with_source(original)?;
94 let edited_parsed = self.parser.parse_with_source(edited)?;
95
96 let original_errors = collect_error_positions(&original_parsed);
97 let edited_errors = collect_errors(&edited_parsed, edited);
98
99 let new_errors: Vec<_> = edited_errors
101 .into_iter()
102 .filter(|e| !original_errors.contains(&(e.byte_start, e.byte_end)))
103 .collect();
104
105 if !new_errors.is_empty() {
106 return Err(ValidationError::ParseErrorIntroduced {
107 count: new_errors.len(),
108 errors: new_errors,
109 });
110 }
111
112 Ok(())
113 }
114}
115
116impl Default for ParseValidator {
117 fn default() -> Self {
121 Self::new().expect("tree-sitter parser initialization failed")
122 }
123}
124
125pub mod pooled {
130 use super::*;
131 use crate::pool;
132
133 pub fn validate(source: &str) -> Result<(), ValidationError> {
135 pool::with_parser(|parser| {
136 let parsed = parser.parse_with_source(source)?;
137 let errors = collect_errors(&parsed, source);
138
139 if !errors.is_empty() {
140 return Err(ValidationError::ParseErrorIntroduced {
141 count: errors.len(),
142 errors,
143 });
144 }
145
146 Ok(())
147 })?
148 }
149
150 pub fn validate_edit(original: &str, edited: &str) -> Result<(), ValidationError> {
152 pool::with_parser(|parser| {
153 let original_parsed = parser.parse_with_source(original)?;
154 let original_errors = collect_error_positions(&original_parsed);
155
156 let edited_parsed = parser.parse_with_source(edited)?;
157 let edited_errors = collect_error_positions(&edited_parsed);
158
159 let new_errors: Vec<_> = edited_errors
161 .difference(&original_errors)
162 .copied()
163 .collect();
164
165 if !new_errors.is_empty() {
166 let error_details = collect_errors(&edited_parsed, edited);
167 return Err(ValidationError::ParseErrorIntroduced {
168 count: error_details.len(),
169 errors: error_details,
170 });
171 }
172
173 Ok(())
174 })?
175 }
176}
177
178fn collect_errors(parsed: &ParsedSource<'_>, source: &str) -> Vec<ErrorLocation> {
180 let mut errors = Vec::new();
181 collect_errors_recursive(parsed.root_node(), source, &mut errors);
182 errors
183}
184
185fn collect_errors_recursive(
186 node: tree_sitter::Node<'_>,
187 source: &str,
188 errors: &mut Vec<ErrorLocation>,
189) {
190 if node.is_error() || node.is_missing() {
191 let start = node.start_position();
192 let byte_start = node.start_byte();
193 let byte_end = node.end_byte();
194
195 let context_start = byte_start.saturating_sub(20);
197 let context_end = (byte_end + 20).min(source.len());
198 let context = source
199 .get(context_start..context_end)
200 .unwrap_or("")
201 .replace('\n', "\\n");
202
203 errors.push(ErrorLocation {
204 byte_start,
205 byte_end,
206 line: start.row + 1,
207 column: start.column + 1,
208 context,
209 });
210 }
211
212 let mut cursor = node.walk();
213 for child in node.children(&mut cursor) {
214 collect_errors_recursive(child, source, errors);
215 }
216}
217
218fn collect_error_positions(parsed: &ParsedSource<'_>) -> std::collections::HashSet<(usize, usize)> {
220 let mut positions = std::collections::HashSet::new();
221 collect_error_positions_recursive(parsed.root_node(), &mut positions);
222 positions
223}
224
225fn collect_error_positions_recursive(
226 node: tree_sitter::Node<'_>,
227 positions: &mut std::collections::HashSet<(usize, usize)>,
228) {
229 if node.is_error() || node.is_missing() {
230 positions.insert((node.start_byte(), node.end_byte()));
231 }
232
233 let mut cursor = node.walk();
234 for child in node.children(&mut cursor) {
235 collect_error_positions_recursive(child, positions);
236 }
237}
238
239pub mod syn_validate {
241 use super::ValidationError;
242
243 pub fn validate_item(code: &str) -> Result<(), ValidationError> {
245 syn::parse_str::<syn::Item>(code).map_err(|e| ValidationError::SynValidationFailed {
246 message: e.to_string(),
247 code: code.to_string(),
248 })?;
249 Ok(())
250 }
251
252 pub fn validate_expr(code: &str) -> Result<(), ValidationError> {
254 syn::parse_str::<syn::Expr>(code).map_err(|e| ValidationError::SynValidationFailed {
255 message: e.to_string(),
256 code: code.to_string(),
257 })?;
258 Ok(())
259 }
260
261 pub fn validate_stmt(code: &str) -> Result<(), ValidationError> {
263 syn::parse_str::<syn::Stmt>(code).map_err(|e| ValidationError::SynValidationFailed {
264 message: e.to_string(),
265 code: code.to_string(),
266 })?;
267 Ok(())
268 }
269
270 pub fn validate_type(code: &str) -> Result<(), ValidationError> {
272 syn::parse_str::<syn::Type>(code).map_err(|e| ValidationError::SynValidationFailed {
273 message: e.to_string(),
274 code: code.to_string(),
275 })?;
276 Ok(())
277 }
278
279 pub fn validate_file(code: &str) -> Result<(), ValidationError> {
281 syn::parse_file(code).map_err(|e| ValidationError::SynValidationFailed {
282 message: e.to_string(),
283 code: code.to_string(),
284 })?;
285 Ok(())
286 }
287
288 pub fn validate_match_arm_body(code: &str) -> Result<(), ValidationError> {
290 let trimmed = code.trim().trim_end_matches(',');
292 validate_expr(trimmed)
293 }
294
295 pub fn validate_block(code: &str) -> Result<(), ValidationError> {
297 let block_code = format!("{{ {} }}", code);
299 syn::parse_str::<syn::Block>(&block_code).map_err(|e| {
300 ValidationError::SynValidationFailed {
301 message: e.to_string(),
302 code: code.to_string(),
303 }
304 })?;
305 Ok(())
306 }
307}
308
309pub struct SelectorValidator;
311
312impl SelectorValidator {
313 pub fn check_unique(count: usize, pattern: &str) -> Result<(), ValidationError> {
315 match count {
316 0 => Err(ValidationError::NoMatch {
317 pattern: pattern.to_string(),
318 }),
319 1 => Ok(()),
320 n => Err(ValidationError::SelectorNotUnique {
321 count: n,
322 pattern: pattern.to_string(),
323 }),
324 }
325 }
326
327 pub fn check_found(count: usize, pattern: &str) -> Result<(), ValidationError> {
329 if count == 0 {
330 Err(ValidationError::NoMatch {
331 pattern: pattern.to_string(),
332 })
333 } else {
334 Ok(())
335 }
336 }
337}
338
339pub struct ValidatedEdit {
345 edit: crate::edit::Edit,
346 validate_parse: bool,
347}
348
349impl ValidatedEdit {
350 pub fn new(edit: crate::edit::Edit) -> Self {
352 Self {
353 edit,
354 validate_parse: true,
355 }
356 }
357
358 pub fn skip_parse_validation(mut self) -> Self {
360 self.validate_parse = false;
361 self
362 }
363
364 pub fn apply(self) -> Result<crate::edit::EditResult, ValidationError> {
368 use std::fs;
369
370 if !self.validate_parse {
371 return Ok(self.edit.apply()?);
372 }
373
374 let original = fs::read_to_string(&self.edit.file)?;
376
377 let edited = {
379 let mut content = original.clone();
380 let before = &content[self.edit.byte_start..self.edit.byte_end];
381
382 if !self.edit.expected_before.matches(before) {
384 return Err(ValidationError::from(
385 crate::edit::EditError::BeforeTextMismatch {
386 file: self.edit.file.clone(),
387 byte_start: self.edit.byte_start,
388 byte_end: self.edit.byte_end,
389 expected: format!("{:?}", self.edit.expected_before),
390 found: before.to_string(),
391 },
392 ));
393 }
394
395 content.replace_range(
397 self.edit.byte_start..self.edit.byte_end,
398 &self.edit.new_text,
399 );
400 content
401 };
402
403 let mut validator = ParseValidator::new()?;
405 validator.validate_edit(&original, &edited)?;
406
407 Ok(self.edit.apply()?)
409 }
410}
411
412impl From<crate::edit::EditError> for ValidationError {
413 fn from(e: crate::edit::EditError) -> Self {
414 match e {
415 crate::edit::EditError::Io(io) => ValidationError::Io(io),
416 other => ValidationError::SynValidationFailed {
417 message: other.to_string(),
418 code: String::new(),
419 },
420 }
421 }
422}
423
#[cfg(test)]
mod tests {
    //! Unit tests for the parse, syn, selector and edit validators.

    use super::*;
    use crate::edit::Edit;
    use std::fs;
    use std::path::PathBuf;

    /// Source written to disk by the `ValidatedEdit` tests.
    const SAMPLE: &str = "fn main() { let x = 1; }";

    /// Fresh owned-parser validator for one test.
    fn validator() -> ParseValidator {
        ParseValidator::new().unwrap()
    }

    /// Writes `SAMPLE` to `test.rs` in a new temporary directory.
    ///
    /// The directory handle is returned too, because dropping it deletes the file.
    fn sample_file() -> (tempfile::TempDir, PathBuf) {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("test.rs");
        fs::write(&path, SAMPLE).unwrap();
        (dir, path)
    }

    #[test]
    fn test_parse_validator_valid() {
        let source = "fn main() { println!(\"hello\"); }";
        assert!(validator().validate(source).is_ok());
    }

    #[test]
    fn test_parse_validator_invalid() {
        // The parameter list is never closed.
        let result = validator().validate("fn main( { }");
        assert!(matches!(
            result,
            Err(ValidationError::ParseErrorIntroduced { .. })
        ));
    }

    #[test]
    fn test_parse_validator_edit_introduces_error() {
        let original = "fn main() { let x = 1; }";
        let edited = "fn main( { let x = 1; }";
        let result = validator().validate_edit(original, edited);
        assert!(matches!(
            result,
            Err(ValidationError::ParseErrorIntroduced { .. })
        ));
    }

    #[test]
    fn test_parse_validator_edit_preserves_existing_error() {
        // The original is already broken; an edit that leaves that error in
        // place must still be accepted.
        let original = "fn main( { }";
        let edited = "fn main( { let x = 1; }";
        assert!(validator().validate_edit(original, edited).is_ok());
    }

    #[test]
    fn test_syn_validate_item() {
        assert!(syn_validate::validate_item("fn foo() {}").is_ok());
        assert!(syn_validate::validate_item("struct Foo { x: i32 }").is_ok());
        assert!(syn_validate::validate_item("not valid rust").is_err());
    }

    #[test]
    fn test_syn_validate_expr() {
        assert!(syn_validate::validate_expr("1 + 2").is_ok());
        assert!(syn_validate::validate_expr("foo.bar()").is_ok());
        assert!(syn_validate::validate_expr("if x { 1 } else { 2 }").is_ok());
        // An item is not an expression.
        assert!(syn_validate::validate_expr("fn foo() {}").is_err());
    }

    #[test]
    fn test_syn_validate_match_arm_body() {
        assert!(syn_validate::validate_match_arm_body("OtelExporter::None").is_ok());
        assert!(syn_validate::validate_match_arm_body("OtelExporter::None,").is_ok());
        assert!(syn_validate::validate_match_arm_body("{ do_something(); result }").is_ok());
    }

    #[test]
    fn test_syn_validate_block() {
        assert!(syn_validate::validate_block("let x = 1; x + 1").is_ok());
        assert!(syn_validate::validate_block("println!(\"hello\");").is_ok());
    }

    #[test]
    fn test_selector_validator_unique() {
        assert!(SelectorValidator::check_unique(1, "test").is_ok());
        assert!(matches!(
            SelectorValidator::check_unique(0, "test"),
            Err(ValidationError::NoMatch { .. })
        ));
        assert!(matches!(
            SelectorValidator::check_unique(2, "test"),
            Err(ValidationError::SelectorNotUnique { count: 2, .. })
        ));
    }

    #[test]
    fn test_selector_validator_found() {
        assert!(SelectorValidator::check_found(1, "test").is_ok());
        assert!(SelectorValidator::check_found(5, "test").is_ok());
        assert!(matches!(
            SelectorValidator::check_found(0, "test"),
            Err(ValidationError::NoMatch { .. })
        ));
    }

    #[test]
    fn test_validated_edit_success() {
        let (_dir, path) = sample_file();

        // Bytes 12..22 hold `let x = 1;`.
        let edit = Edit::new(&path, 12, 22, "let y = 2;", "let x = 1;");
        assert!(ValidatedEdit::new(edit).apply().is_ok());

        let content = fs::read_to_string(&path).unwrap();
        assert_eq!(content, "fn main() { let y = 2; }");
    }

    #[test]
    fn test_validated_edit_rejects_parse_error() {
        let (_dir, path) = sample_file();

        // Deleting the closing ` }` leaves the function body unterminated.
        let edit = Edit::new(&path, 22, 24, "", " }");
        let result = ValidatedEdit::new(edit).apply();
        assert!(matches!(
            result,
            Err(ValidationError::ParseErrorIntroduced { .. })
        ));

        // A rejected edit must leave the file untouched.
        let content = fs::read_to_string(&path).unwrap();
        assert_eq!(content, SAMPLE);
    }

    #[test]
    fn test_validated_edit_skip_validation() {
        let (_dir, path) = sample_file();

        // Same breaking edit as above, but with the parse check disabled.
        let edit = Edit::new(&path, 22, 24, "", " }");
        let result = ValidatedEdit::new(edit).skip_parse_validation().apply();
        assert!(result.is_ok());

        let content = fs::read_to_string(&path).unwrap();
        assert_eq!(content, "fn main() { let x = 1;");
    }
}