Skip to main content

codex_patcher/
validate.rs

//! Validation module for ensuring edit safety.
//!
//! This module provides:
//! - Parse validation (tree-sitter ERROR node detection)
//! - syn validation for generated snippets
//! - Selector uniqueness checks
//!
//! # Hard Rules (Never Violate)
//!
//! 1. **Parse validation**: After editing, re-parse with tree-sitter.
//!    If the file has ERROR nodes that weren't there before, roll back.
//! 2. **Selector uniqueness**: If a structural query matches 0 or >1
//!    locations, refuse to edit. No guessing.
14
15use crate::ts::{ParsedSource, RustParser, TreeSitterError};
16use std::path::Path;
17use thiserror::Error;
18
19/// Validation errors.
20#[derive(Error, Debug)]
21pub enum ValidationError {
22    #[error("Parse error introduced: found {count} new ERROR nodes")]
23    ParseErrorIntroduced {
24        count: usize,
25        errors: Vec<ErrorLocation>,
26    },
27
28    #[error("Selector matched {count} locations, expected exactly 1")]
29    SelectorNotUnique { count: usize, pattern: String },
30
31    #[error("Selector matched 0 locations")]
32    NoMatch { pattern: String },
33
34    #[error("syn validation failed: {message}")]
35    SynValidationFailed { message: String, code: String },
36
37    #[error("Tree-sitter error: {0}")]
38    TreeSitter(#[from] TreeSitterError),
39
40    #[error("IO error: {0}")]
41    Io(#[from] std::io::Error),
42}
43
/// Where a tree-sitter ERROR or MISSING node sits in the source text.
#[derive(Debug, Clone)]
pub struct ErrorLocation {
    /// Byte offset at which the node starts.
    pub byte_start: usize,
    /// Byte offset at which the node ends.
    pub byte_end: usize,
    /// One-based line number of the node's start position.
    pub line: usize,
    /// One-based column of the node's start position.
    pub column: usize,
    /// Snippet of source text around the node, with newlines escaped as `\n`.
    pub context: String,
}
53
54/// Parse validator using tree-sitter.
55pub struct ParseValidator {
56    parser: RustParser,
57}
58
59impl ParseValidator {
60    /// Create a new parse validator.
61    pub fn new() -> Result<Self, TreeSitterError> {
62        Ok(Self {
63            parser: RustParser::new()?,
64        })
65    }
66
67    /// Validate that source has no parse errors.
68    pub fn validate(&mut self, source: &str) -> Result<(), ValidationError> {
69        let parsed = self.parser.parse_with_source(source)?;
70        let errors = collect_errors(&parsed, source);
71
72        if !errors.is_empty() {
73            return Err(ValidationError::ParseErrorIntroduced {
74                count: errors.len(),
75                errors,
76            });
77        }
78
79        Ok(())
80    }
81
82    /// Validate a file path.
83    pub fn validate_file(&mut self, path: impl AsRef<Path>) -> Result<(), ValidationError> {
84        let source = std::fs::read_to_string(path)?;
85        self.validate(&source)
86    }
87
88    /// Compare two sources and check if new errors were introduced.
89    ///
90    /// Returns Ok if the edited source doesn't introduce new parse errors
91    /// that weren't in the original.
92    pub fn validate_edit(&mut self, original: &str, edited: &str) -> Result<(), ValidationError> {
93        let original_parsed = self.parser.parse_with_source(original)?;
94        let edited_parsed = self.parser.parse_with_source(edited)?;
95
96        let original_errors = collect_error_positions(&original_parsed);
97        let edited_errors = collect_errors(&edited_parsed, edited);
98
99        // Filter to only new errors (not present in original)
100        let new_errors: Vec<_> = edited_errors
101            .into_iter()
102            .filter(|e| !original_errors.contains(&(e.byte_start, e.byte_end)))
103            .collect();
104
105        if !new_errors.is_empty() {
106            return Err(ValidationError::ParseErrorIntroduced {
107                count: new_errors.len(),
108                errors: new_errors,
109            });
110        }
111
112        Ok(())
113    }
114}
115
116impl Default for ParseValidator {
117    /// # Panics
118    ///
119    /// Panics if tree-sitter parser initialization fails (e.g., out of memory).
120    fn default() -> Self {
121        Self::new().expect("tree-sitter parser initialization failed")
122    }
123}
124
125/// Pooled validation functions that reuse parsers from thread-local pool.
126///
127/// These functions provide significant performance improvements for multi-patch
128/// workloads by avoiding redundant parser allocation and initialization.
129pub mod pooled {
130    use super::*;
131    use crate::pool;
132
133    /// Validate source code using pooled parser.
134    pub fn validate(source: &str) -> Result<(), ValidationError> {
135        pool::with_parser(|parser| {
136            let parsed = parser.parse_with_source(source)?;
137            let errors = collect_errors(&parsed, source);
138
139            if !errors.is_empty() {
140                return Err(ValidationError::ParseErrorIntroduced {
141                    count: errors.len(),
142                    errors,
143                });
144            }
145
146            Ok(())
147        })?
148    }
149
150    /// Compare two sources and check if new errors were introduced using pooled parser.
151    pub fn validate_edit(original: &str, edited: &str) -> Result<(), ValidationError> {
152        pool::with_parser(|parser| {
153            let original_parsed = parser.parse_with_source(original)?;
154            let original_errors = collect_error_positions(&original_parsed);
155
156            let edited_parsed = parser.parse_with_source(edited)?;
157            let edited_errors = collect_error_positions(&edited_parsed);
158
159            // Check if new errors were introduced
160            let new_errors: Vec<_> = edited_errors
161                .difference(&original_errors)
162                .copied()
163                .collect();
164
165            if !new_errors.is_empty() {
166                let error_details = collect_errors(&edited_parsed, edited);
167                return Err(ValidationError::ParseErrorIntroduced {
168                    count: error_details.len(),
169                    errors: error_details,
170                });
171            }
172
173            Ok(())
174        })?
175    }
176}
177
178/// Collect all error nodes from a parsed source.
179fn collect_errors(parsed: &ParsedSource<'_>, source: &str) -> Vec<ErrorLocation> {
180    let mut errors = Vec::new();
181    collect_errors_recursive(parsed.root_node(), source, &mut errors);
182    errors
183}
184
185fn collect_errors_recursive(
186    node: tree_sitter::Node<'_>,
187    source: &str,
188    errors: &mut Vec<ErrorLocation>,
189) {
190    if node.is_error() || node.is_missing() {
191        let start = node.start_position();
192        let byte_start = node.start_byte();
193        let byte_end = node.end_byte();
194
195        // Extract context (up to 50 chars around the error)
196        let context_start = byte_start.saturating_sub(20);
197        let context_end = (byte_end + 20).min(source.len());
198        let context = source
199            .get(context_start..context_end)
200            .unwrap_or("")
201            .replace('\n', "\\n");
202
203        errors.push(ErrorLocation {
204            byte_start,
205            byte_end,
206            line: start.row + 1,
207            column: start.column + 1,
208            context,
209        });
210    }
211
212    let mut cursor = node.walk();
213    for child in node.children(&mut cursor) {
214        collect_errors_recursive(child, source, errors);
215    }
216}
217
218/// Collect error positions (for comparison).
219fn collect_error_positions(parsed: &ParsedSource<'_>) -> std::collections::HashSet<(usize, usize)> {
220    let mut positions = std::collections::HashSet::new();
221    collect_error_positions_recursive(parsed.root_node(), &mut positions);
222    positions
223}
224
225fn collect_error_positions_recursive(
226    node: tree_sitter::Node<'_>,
227    positions: &mut std::collections::HashSet<(usize, usize)>,
228) {
229    if node.is_error() || node.is_missing() {
230        positions.insert((node.start_byte(), node.end_byte()));
231    }
232
233    let mut cursor = node.walk();
234    for child in node.children(&mut cursor) {
235        collect_error_positions_recursive(child, positions);
236    }
237}
238
239/// syn-based validation for generated Rust code snippets.
240pub mod syn_validate {
241    use super::ValidationError;
242
243    /// Validate that code parses as a valid Rust item (fn, struct, impl, etc.).
244    pub fn validate_item(code: &str) -> Result<(), ValidationError> {
245        syn::parse_str::<syn::Item>(code).map_err(|e| ValidationError::SynValidationFailed {
246            message: e.to_string(),
247            code: code.to_string(),
248        })?;
249        Ok(())
250    }
251
252    /// Validate that code parses as a valid Rust expression.
253    pub fn validate_expr(code: &str) -> Result<(), ValidationError> {
254        syn::parse_str::<syn::Expr>(code).map_err(|e| ValidationError::SynValidationFailed {
255            message: e.to_string(),
256            code: code.to_string(),
257        })?;
258        Ok(())
259    }
260
261    /// Validate that code parses as a valid Rust statement.
262    pub fn validate_stmt(code: &str) -> Result<(), ValidationError> {
263        syn::parse_str::<syn::Stmt>(code).map_err(|e| ValidationError::SynValidationFailed {
264            message: e.to_string(),
265            code: code.to_string(),
266        })?;
267        Ok(())
268    }
269
270    /// Validate that code parses as a valid Rust type.
271    pub fn validate_type(code: &str) -> Result<(), ValidationError> {
272        syn::parse_str::<syn::Type>(code).map_err(|e| ValidationError::SynValidationFailed {
273            message: e.to_string(),
274            code: code.to_string(),
275        })?;
276        Ok(())
277    }
278
279    /// Validate that code parses as a complete Rust file.
280    pub fn validate_file(code: &str) -> Result<(), ValidationError> {
281        syn::parse_file(code).map_err(|e| ValidationError::SynValidationFailed {
282            message: e.to_string(),
283            code: code.to_string(),
284        })?;
285        Ok(())
286    }
287
288    /// Validate match arm body (expression).
289    pub fn validate_match_arm_body(code: &str) -> Result<(), ValidationError> {
290        // Match arm bodies are expressions, possibly with a trailing comma
291        let trimmed = code.trim().trim_end_matches(',');
292        validate_expr(trimmed)
293    }
294
295    /// Validate function body (block contents).
296    pub fn validate_block(code: &str) -> Result<(), ValidationError> {
297        // Try parsing as a block
298        let block_code = format!("{{ {} }}", code);
299        syn::parse_str::<syn::Block>(&block_code).map_err(|e| {
300            ValidationError::SynValidationFailed {
301                message: e.to_string(),
302                code: code.to_string(),
303            }
304        })?;
305        Ok(())
306    }
307}
308
/// Namespace for checks on how many locations a selector matched.
pub struct SelectorValidator;
311
312impl SelectorValidator {
313    /// Check that a pattern match count is exactly 1.
314    pub fn check_unique(count: usize, pattern: &str) -> Result<(), ValidationError> {
315        match count {
316            0 => Err(ValidationError::NoMatch {
317                pattern: pattern.to_string(),
318            }),
319            1 => Ok(()),
320            n => Err(ValidationError::SelectorNotUnique {
321                count: n,
322                pattern: pattern.to_string(),
323            }),
324        }
325    }
326
327    /// Check that a pattern matched at least once.
328    pub fn check_found(count: usize, pattern: &str) -> Result<(), ValidationError> {
329        if count == 0 {
330            Err(ValidationError::NoMatch {
331                pattern: pattern.to_string(),
332            })
333        } else {
334            Ok(())
335        }
336    }
337}
338
339/// Validated edit - wraps Edit with automatic parse validation.
340///
341/// Ensures that:
342/// 1. The edit doesn't introduce new parse errors
343/// 2. Generated code snippets are valid according to syn
344pub struct ValidatedEdit {
345    edit: crate::edit::Edit,
346    validate_parse: bool,
347}
348
349impl ValidatedEdit {
350    /// Create a validated edit from an existing edit.
351    pub fn new(edit: crate::edit::Edit) -> Self {
352        Self {
353            edit,
354            validate_parse: true,
355        }
356    }
357
358    /// Disable parse validation (useful when intentionally editing broken code).
359    pub fn skip_parse_validation(mut self) -> Self {
360        self.validate_parse = false;
361        self
362    }
363
364    /// Apply the edit with validation.
365    ///
366    /// Returns an error if the edit would introduce parse errors.
367    pub fn apply(self) -> Result<crate::edit::EditResult, ValidationError> {
368        use std::fs;
369
370        if !self.validate_parse {
371            return Ok(self.edit.apply()?);
372        }
373
374        // Read original content
375        let original = fs::read_to_string(&self.edit.file)?;
376
377        // Compute what the edited content would be
378        let edited = {
379            let mut content = original.clone();
380            let before = &content[self.edit.byte_start..self.edit.byte_end];
381
382            // Check verification
383            if !self.edit.expected_before.matches(before) {
384                return Err(ValidationError::from(
385                    crate::edit::EditError::BeforeTextMismatch {
386                        file: self.edit.file.clone(),
387                        byte_start: self.edit.byte_start,
388                        byte_end: self.edit.byte_end,
389                        expected: format!("{:?}", self.edit.expected_before),
390                        found: before.to_string(),
391                    },
392                ));
393            }
394
395            // Simulate edit
396            content.replace_range(
397                self.edit.byte_start..self.edit.byte_end,
398                &self.edit.new_text,
399            );
400            content
401        };
402
403        // Validate the edited content
404        let mut validator = ParseValidator::new()?;
405        validator.validate_edit(&original, &edited)?;
406
407        // Now apply for real
408        Ok(self.edit.apply()?)
409    }
410}
411
412impl From<crate::edit::EditError> for ValidationError {
413    fn from(e: crate::edit::EditError) -> Self {
414        match e {
415            crate::edit::EditError::Io(io) => ValidationError::Io(io),
416            other => ValidationError::SynValidationFailed {
417                message: other.to_string(),
418                code: String::new(),
419            },
420        }
421    }
422}
423
#[cfg(test)]
mod tests {
    use super::*;

    /// Source used by the `ValidatedEdit` tests.
    const SIMPLE_MAIN: &str = "fn main() { let x = 1; }";

    /// Write `contents` to `test.rs` inside a new temporary directory.
    ///
    /// The directory guard is returned so the file lives as long as the test
    /// keeps it in scope.
    fn temp_rust_file(contents: &str) -> (tempfile::TempDir, std::path::PathBuf) {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("test.rs");
        std::fs::write(&path, contents).unwrap();
        (dir, path)
    }

    /// True when `result` is a `ParseErrorIntroduced` error.
    fn is_parse_error<T>(result: &Result<T, ValidationError>) -> bool {
        matches!(result, Err(ValidationError::ParseErrorIntroduced { .. }))
    }

    #[test]
    fn test_parse_validator_valid() {
        let mut validator = ParseValidator::new().unwrap();
        let outcome = validator.validate("fn main() { println!(\"hello\"); }");
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_parse_validator_invalid() {
        let mut validator = ParseValidator::new().unwrap();
        // Missing closing paren
        let outcome = validator.validate("fn main( { }");
        assert!(is_parse_error(&outcome));
    }

    #[test]
    fn test_parse_validator_edit_introduces_error() {
        let mut validator = ParseValidator::new().unwrap();
        let original = "fn main() { let x = 1; }";
        // Removed closing paren
        let edited = "fn main( { let x = 1; }";

        let outcome = validator.validate_edit(original, edited);
        assert!(is_parse_error(&outcome));
    }

    #[test]
    fn test_parse_validator_edit_preserves_existing_error() {
        let mut validator = ParseValidator::new().unwrap();
        // Both have the same error
        let original = "fn main( { }";
        let edited = "fn main( { let x = 1; }";

        // This should pass because we're not introducing NEW errors
        // (the error existed in the original)
        assert!(validator.validate_edit(original, edited).is_ok());
    }

    #[test]
    fn test_syn_validate_item() {
        for valid in ["fn foo() {}", "struct Foo { x: i32 }"] {
            assert!(syn_validate::validate_item(valid).is_ok());
        }
        assert!(syn_validate::validate_item("not valid rust").is_err());
    }

    #[test]
    fn test_syn_validate_expr() {
        for valid in ["1 + 2", "foo.bar()", "if x { 1 } else { 2 }"] {
            assert!(syn_validate::validate_expr(valid).is_ok());
        }
        // Not an expr
        assert!(syn_validate::validate_expr("fn foo() {}").is_err());
    }

    #[test]
    fn test_syn_validate_match_arm_body() {
        for valid in [
            "OtelExporter::None",
            "OtelExporter::None,",
            "{ do_something(); result }",
        ] {
            assert!(syn_validate::validate_match_arm_body(valid).is_ok());
        }
    }

    #[test]
    fn test_syn_validate_block() {
        for valid in ["let x = 1; x + 1", "println!(\"hello\");"] {
            assert!(syn_validate::validate_block(valid).is_ok());
        }
    }

    #[test]
    fn test_selector_validator_unique() {
        assert!(SelectorValidator::check_unique(1, "test").is_ok());

        let none = SelectorValidator::check_unique(0, "test");
        assert!(matches!(none, Err(ValidationError::NoMatch { .. })));

        let many = SelectorValidator::check_unique(2, "test");
        assert!(matches!(
            many,
            Err(ValidationError::SelectorNotUnique { count: 2, .. })
        ));
    }

    #[test]
    fn test_selector_validator_found() {
        for count in [1, 5] {
            assert!(SelectorValidator::check_found(count, "test").is_ok());
        }

        let none = SelectorValidator::check_found(0, "test");
        assert!(matches!(none, Err(ValidationError::NoMatch { .. })));
    }

    #[test]
    fn test_validated_edit_success() {
        use crate::edit::Edit;

        let (_dir, file_path) = temp_rust_file(SIMPLE_MAIN);

        let edit = Edit::new(&file_path, 12, 22, "let y = 2;", "let x = 1;");
        let outcome = ValidatedEdit::new(edit).apply();

        assert!(outcome.is_ok());
        let on_disk = std::fs::read_to_string(&file_path).unwrap();
        assert_eq!(on_disk, "fn main() { let y = 2; }");
    }

    #[test]
    fn test_validated_edit_rejects_parse_error() {
        use crate::edit::Edit;

        let (_dir, file_path) = temp_rust_file(SIMPLE_MAIN);

        // This edit breaks the syntax (removes closing brace)
        let edit = Edit::new(&file_path, 22, 24, "", " }");
        let outcome = ValidatedEdit::new(edit).apply();

        assert!(is_parse_error(&outcome));

        // File should be unchanged
        let on_disk = std::fs::read_to_string(&file_path).unwrap();
        assert_eq!(on_disk, "fn main() { let x = 1; }");
    }

    #[test]
    fn test_validated_edit_skip_validation() {
        use crate::edit::Edit;

        let (_dir, file_path) = temp_rust_file(SIMPLE_MAIN);

        // This edit breaks syntax, but we skip validation
        let edit = Edit::new(&file_path, 22, 24, "", " }");
        let outcome = ValidatedEdit::new(edit).skip_parse_validation().apply();

        assert!(outcome.is_ok());
        // File should be changed (even though it's now invalid)
        let on_disk = std::fs::read_to_string(&file_path).unwrap();
        assert_eq!(on_disk, "fn main() { let x = 1;");
    }
}