codemod_core/transform/
applier.rs1use similar::TextDiff;
12
13use crate::error::CodemodError;
14use crate::pattern::matcher::Match;
15use crate::pattern::Pattern;
16
/// Stateless namespace for applying a pattern's replacement template to
/// matched regions of source text and for rendering the result as a unified
/// diff. All functionality is exposed as associated functions.
pub struct TransformApplier;
19
20impl TransformApplier {
21 pub fn apply(source: &str, pattern: &Pattern, matches: &[Match]) -> crate::Result<String> {
30 if matches.is_empty() {
31 return Ok(source.to_string());
32 }
33
34 let mut sorted: Vec<&Match> = matches.iter().collect();
37 sorted.sort_by(|a, b| b.byte_range.start.cmp(&a.byte_range.start));
38
39 let mut result = source.to_string();
40
41 for m in &sorted {
42 let replacement = Self::render_replacement(pattern, m)?;
43 let indented = Self::preserve_indentation(source, m, &replacement);
45 result.replace_range(m.byte_range.clone(), &indented);
46 }
47
48 Ok(result)
49 }
50
51 pub fn generate_diff(file_path: &str, original: &str, transformed: &str) -> String {
56 let diff = TextDiff::from_lines(original, transformed);
57 let mut output = String::new();
58
59 output.push_str(&format!("--- a/{file_path}\n"));
60 output.push_str(&format!("+++ b/{file_path}\n"));
61
62 for hunk in diff.unified_diff().context_radius(3).iter_hunks() {
63 output.push_str(&format!("{hunk}"));
64 }
65
66 output
67 }
68
69 fn render_replacement(pattern: &Pattern, m: &Match) -> crate::Result<String> {
76 let mut result = pattern.after_template.clone();
77
78 for var in &pattern.variables {
79 if let Some(value) = m.bindings.get(&var.name) {
80 result = result.replace(&var.name, value);
81 } else {
82 if result.contains(&var.name) {
86 return Err(CodemodError::Transform(format!(
87 "Variable '{}' has no binding for match at byte offset {}",
88 var.name, m.byte_range.start
89 )));
90 }
91 }
92 }
93
94 Ok(result)
95 }
96
97 fn preserve_indentation(source: &str, m: &Match, replacement: &str) -> String {
100 let line_start = source[..m.byte_range.start]
102 .rfind('\n')
103 .map(|p| p + 1)
104 .unwrap_or(0);
105 let indent: String = source[line_start..m.byte_range.start]
106 .chars()
107 .take_while(|c| c.is_whitespace())
108 .collect();
109
110 if indent.is_empty() || !replacement.contains('\n') {
111 return replacement.to_string();
112 }
113
114 let mut lines = replacement.lines();
116 let mut result = String::new();
117 if let Some(first) = lines.next() {
118 result.push_str(first);
119 }
120 for line in lines {
121 result.push('\n');
122 if !line.is_empty() {
123 result.push_str(&indent);
124 }
125 result.push_str(line);
126 }
127 if replacement.ends_with('\n') {
129 result.push('\n');
130 }
131 result
132 }
133}
134
135#[cfg(test)]
136mod tests {
137 use super::*;
138 use crate::pattern::matcher::Position;
139 use crate::pattern::PatternVar;
140 use std::collections::HashMap;
141
142 fn make_match(
143 start: usize,
144 end: usize,
145 text: &str,
146 bindings: HashMap<String, String>,
147 ) -> Match {
148 Match {
149 byte_range: start..end,
150 start_position: Position {
151 line: 0,
152 column: start,
153 },
154 end_position: Position {
155 line: 0,
156 column: end,
157 },
158 matched_text: text.to_string(),
159 bindings,
160 }
161 }
162
/// A single match spanning the whole input is replaced by the rendered
/// template, with `$var1` substituted by its bound value.
#[test]
fn test_apply_single_replacement() {
    let input = "println!(x);";
    let variables = vec![PatternVar {
        name: "$var1".into(),
        node_type: None,
    }];
    let rule = Pattern::new(
        "println!($var1)".into(),
        "log::info!($var1)".into(),
        variables,
        "rust".into(),
        0.9,
    );

    // The match covers all 12 bytes of the input.
    let bindings = HashMap::from([("$var1".to_string(), "x".to_string())]);
    let whole_input = make_match(0, 12, "println!(x);", bindings);

    let rewritten = TransformApplier::apply(input, &rule, &[whole_input]).unwrap();
    assert_eq!(rewritten, "log::info!(x)");
}
183
/// The rendered diff carries both file headers and the removed/added lines.
#[test]
fn test_generate_diff() {
    let before = "line1\nline2\nline3\n";
    let after = "line1\nchanged\nline3\n";

    let rendered = TransformApplier::generate_diff("test.rs", before, after);
    let has = |needle: &str| rendered.contains(needle);

    // File headers.
    assert!(has("--- a/test.rs"));
    assert!(has("+++ b/test.rs"));
    // The one changed line, as a removal plus an addition.
    assert!(has("-line2"));
    assert!(has("+changed"));
}
194
/// Continuation lines of a multi-line replacement pick up the indentation of
/// the line on which the match starts; the first line is left untouched.
#[test]
fn test_preserve_indentation() {
    // `old_call()` occupies bytes 16..26, behind a four-space indent.
    let source = "fn main() {\n    old_call();\n}";
    let m = make_match(16, 26, "old_call()", HashMap::new());
    let replacement = "new_call(\n    arg\n)";

    let result = TransformApplier::preserve_indentation(source, &m, replacement);

    // Compare the whole string: a substring check for the argument line
    // would also pass on the untouched replacement text.
    assert_eq!(result, "new_call(\n        arg\n    )");
}
203
/// With no matches the source comes back unchanged.
#[test]
fn test_empty_matches() {
    let input = "hello world";
    let unused_pattern = Pattern::new("a".into(), "b".into(), Vec::new(), "rust".into(), 0.9);

    let output = TransformApplier::apply(input, &unused_pattern, &[]).unwrap();

    assert_eq!(output, input);
}
211}