ricecoder_refactoring/adapters/
rust.rs

1//! Rust-specific refactoring provider
2
3use crate::error::Result;
4use crate::providers::{RefactoringAnalysis, RefactoringProvider};
5use crate::types::{Refactoring, RefactoringType, ValidationResult};
6use regex::Regex;
7
/// Rust-specific refactoring provider.
///
/// Stateless: all behavior lives in the inherent helpers and the
/// `RefactoringProvider` implementation. `Default` is implemented by hand
/// (delegating to `new`), so it is intentionally not derived here.
#[derive(Debug, Clone, Copy)]
pub struct RustRefactoringProvider;
10
11impl RustRefactoringProvider {
12    /// Create a new Rust provider
13    pub fn new() -> Self {
14        Self
15    }
16
17    /// Check if code is valid Rust
18    fn is_valid_rust(code: &str) -> bool {
19        // Basic checks for Rust syntax
20        let open_braces = code.matches('{').count();
21        let close_braces = code.matches('}').count();
22        let open_parens = code.matches('(').count();
23        let close_parens = code.matches(')').count();
24        let open_brackets = code.matches('[').count();
25        let close_brackets = code.matches(']').count();
26
27        open_braces == close_braces
28            && open_parens == close_parens
29            && open_brackets == close_brackets
30    }
31
32    /// Apply a Rust-specific rename with word boundaries
33    pub fn apply_rust_rename(code: &str, old_name: &str, new_name: &str) -> Result<String> {
34        let pattern = format!(r"\b{}\b", regex::escape(old_name));
35        match Regex::new(&pattern) {
36            Ok(re) => Ok(re.replace_all(code, new_name).to_string()),
37            Err(_) => Ok(code.replace(old_name, new_name)),
38        }
39    }
40
41    /// Check for unsafe code
42    pub fn has_unsafe_code(code: &str) -> bool {
43        code.contains("unsafe")
44    }
45
46    /// Check for common Rust patterns
47    pub fn check_rust_patterns(code: &str) -> Vec<String> {
48        let mut issues = vec![];
49
50        if code.contains("unwrap()") {
51            issues.push("Code uses unwrap() which can panic".to_string());
52        }
53
54        if code.contains("panic!") {
55            issues.push("Code uses panic! macro".to_string());
56        }
57
58        if code.contains("todo!") {
59            issues.push("Code contains unimplemented todo!".to_string());
60        }
61
62        issues
63    }
64}
65
66impl Default for RustRefactoringProvider {
67    fn default() -> Self {
68        Self::new()
69    }
70}
71
72impl RefactoringProvider for RustRefactoringProvider {
73    fn analyze_refactoring(
74        &self,
75        _code: &str,
76        _language: &str,
77        refactoring_type: RefactoringType,
78    ) -> Result<RefactoringAnalysis> {
79        let complexity = match refactoring_type {
80            RefactoringType::Rename => 4,
81            RefactoringType::Extract => 7,
82            RefactoringType::Inline => 6,
83            RefactoringType::Move => 8,
84            RefactoringType::ChangeSignature => 9,
85            RefactoringType::RemoveUnused => 5,
86            RefactoringType::Simplify => 6,
87        };
88
89        Ok(RefactoringAnalysis {
90            applicable: true,
91            reason: None,
92            complexity,
93        })
94    }
95
96    fn apply_refactoring(
97        &self,
98        code: &str,
99        _language: &str,
100        refactoring: &Refactoring,
101    ) -> Result<String> {
102        // Apply Rust-specific refactoring
103        match refactoring.refactoring_type {
104            RefactoringType::Rename => {
105                // Rust rename: use word boundaries to avoid partial matches
106                Self::apply_rust_rename(code, &refactoring.target.symbol, &refactoring.target.symbol)
107            }
108            _ => Ok(code.to_string()),
109        }
110    }
111
112    fn validate_refactoring(
113        &self,
114        original: &str,
115        refactored: &str,
116        _language: &str,
117    ) -> Result<ValidationResult> {
118        let mut errors = vec![];
119        let mut warnings = vec![];
120
121        // Check if refactored code is not empty
122        if refactored.is_empty() {
123            errors.push("Refactored code cannot be empty".to_string());
124        }
125
126        // Check if content changed
127        if original == refactored {
128            warnings.push("No changes were made".to_string());
129        }
130
131        // Check Rust syntax validity
132        if !Self::is_valid_rust(refactored) {
133            errors.push("Refactored code has syntax errors (brace/paren mismatch)".to_string());
134        }
135
136        // Check for common Rust issues
137        if refactored.contains("unsafe") && !original.contains("unsafe") {
138            warnings.push("Refactoring introduced unsafe code".to_string());
139        }
140
141        Ok(ValidationResult {
142            passed: errors.is_empty(),
143            errors,
144            warnings,
145        })
146    }
147}
148
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_rust_provider_analyze() -> Result<()> {
        let provider = RustRefactoringProvider::new();
        let analysis = provider.analyze_refactoring("fn main() {}", "rust", RefactoringType::Rename)?;

        assert!(analysis.applicable);
        assert_eq!(analysis.complexity, 4);

        Ok(())
    }

    #[test]
    fn test_rust_provider_validate_valid() -> Result<()> {
        let provider = RustRefactoringProvider::new();
        let result = provider.validate_refactoring("fn main() {}", "fn main() { println!(); }", "rust")?;

        assert!(result.passed);

        Ok(())
    }

    #[test]
    fn test_rust_provider_validate_invalid_braces() -> Result<()> {
        let provider = RustRefactoringProvider::new();
        let result = provider.validate_refactoring("fn main() {}", "fn main() { ", "rust")?;

        assert!(!result.passed);

        Ok(())
    }

    #[test]
    fn test_rust_provider_validate_reports_no_changes() -> Result<()> {
        let provider = RustRefactoringProvider::new();
        let result = provider.validate_refactoring("fn main() {}", "fn main() {}", "rust")?;

        // Unchanged code is still valid, but the caller is told nothing happened.
        assert!(result.passed);
        assert_eq!(result.warnings, vec!["No changes were made".to_string()]);

        Ok(())
    }

    #[test]
    fn test_is_valid_rust() {
        assert!(RustRefactoringProvider::is_valid_rust("fn main() {}"));
        assert!(RustRefactoringProvider::is_valid_rust("let x = [1, 2, 3];"));
        assert!(!RustRefactoringProvider::is_valid_rust("fn main() {"));
        assert!(!RustRefactoringProvider::is_valid_rust("let x = [1, 2, 3;"));
    }

    #[test]
    fn test_apply_rust_rename_respects_word_boundaries() -> Result<()> {
        let renamed =
            RustRefactoringProvider::apply_rust_rename("let foo = foobar + foo;", "foo", "baz")?;

        // `foobar` must survive: only whole-word `foo` is renamed.
        assert_eq!(renamed, "let baz = foobar + baz;");

        Ok(())
    }

    #[test]
    fn test_has_unsafe_code() {
        assert!(RustRefactoringProvider::has_unsafe_code("unsafe { ptr.read() }"));
        assert!(!RustRefactoringProvider::has_unsafe_code("fn main() {}"));
    }

    #[test]
    fn test_check_rust_patterns() {
        let issues = RustRefactoringProvider::check_rust_patterns("let x = y.unwrap(); todo!()");
        assert_eq!(
            issues,
            vec![
                "Code uses unwrap() which can panic".to_string(),
                "Code contains unimplemented todo!".to_string(),
            ]
        );

        assert!(RustRefactoringProvider::check_rust_patterns("fn main() {}").is_empty());
    }
}