reinhardt_admin_cli/
migrate_v2.rs1use std::path::PathBuf;
7
8use clap::Args;
9
10pub mod rewriter;
11pub mod rules;
12pub mod walker;
13
14#[derive(Args, Debug)]
16pub struct MigrateV2Args {
17 #[arg(default_value = ".")]
19 pub path: PathBuf,
20
21 #[arg(long)]
23 pub dry_run: bool,
24
25 #[arg(long, value_delimiter = ',')]
27 pub skip: Vec<String>,
28}
29
30pub fn run(args: MigrateV2Args) -> anyhow::Result<()> {
40 let all_rules = rules::all();
41 let known_rule_names: std::collections::BTreeSet<&'static str> =
42 all_rules.iter().map(|r| r.name()).collect();
43 let unknown: Vec<&str> = args
44 .skip
45 .iter()
46 .map(String::as_str)
47 .filter(|name| !known_rule_names.contains(name))
48 .collect();
49 if !unknown.is_empty() {
50 anyhow::bail!("unknown --skip rule(s): {}", unknown.join(", "));
51 }
52 let rules: Vec<_> = all_rules
53 .into_iter()
54 .filter(|r| !args.skip.iter().any(|s| s == r.name()))
55 .collect();
56
57 let files = walker::find_rs_files(&args.path)?;
58 let mut changed = 0_usize;
59
60 for path in files {
61 let src = read_developer_file(&path)?;
62 let parsed: syn::File = match syn::parse_file(&src) {
63 Ok(f) => f,
64 Err(_) => continue,
66 };
67 let mut out_ast = parsed.clone();
68 for r in &rules {
69 out_ast = r.rewrite(out_ast);
70 }
71
72 let out = apply_changes_preserving_formatting(&src, &parsed, &out_ast);
73 if out != src {
74 changed += 1;
75 if args.dry_run {
76 println!("would rewrite: {}", path.display());
77 } else {
78 write_developer_file(&path, &out)?;
79 println!("rewrote: {}", path.display());
80 }
81 }
82 }
83
84 println!(
85 "\nDone. {} file(s) {}.",
86 changed,
87 if args.dry_run {
88 "would change"
89 } else {
90 "changed"
91 }
92 );
93 Ok(())
94}
95
96fn apply_changes_preserving_formatting(
104 src: &str,
105 parsed: &syn::File,
106 out_ast: &syn::File,
107) -> String {
108 let mut result = String::with_capacity(src.len() + 1024);
109 let mut last_pos: usize = 0;
110
111 let item_count = std::cmp::min(parsed.items.len(), out_ast.items.len());
112 for i in 0..item_count {
113 let orig_item = &parsed.items[i];
114 let new_item = &out_ast.items[i];
115
116 let formatted_orig = format_single_item(orig_item);
117 let formatted_new = format_single_item(new_item);
118
119 let (start_byte, end_byte) = find_item_in_source(src, last_pos, &formatted_orig);
120
121 result.push_str(&src[last_pos..start_byte]);
123
124 if formatted_orig == formatted_new {
125 result.push_str(&src[start_byte..end_byte]);
127 } else {
128 result.push_str(&format_single_item(new_item));
130 }
131
132 last_pos = end_byte;
133 }
134
135 if last_pos < src.len() {
137 result.push_str(&src[last_pos..]);
138 }
139
140 result
141}
142
143fn format_single_item(item: &syn::Item) -> String {
148 let file = syn::File {
149 shebang: None,
150 attrs: vec![],
151 items: vec![item.clone()],
152 };
153 let formatted = prettyplease::unparse(&file);
154 formatted.trim_end().to_string()
155}
156
157fn find_item_in_source(src: &str, search_from: usize, item_tokens: &str) -> (usize, usize) {
162 let anchor = item_tokens
163 .lines()
164 .find(|l| {
165 let trimmed = l.trim();
166 !trimmed.is_empty()
167 && !trimmed.starts_with("//")
168 && !trimmed.starts_with("#[")
169 && !trimmed.starts_with("///")
170 })
171 .unwrap_or("");
172
173 if anchor.is_empty() {
174 return (search_from, search_from);
175 }
176
177 let rest = &src[search_from..];
178 let start = match rest.find(anchor) {
179 Some(pos) => search_from + pos,
180 None => return (search_from, search_from),
181 };
182
183 let after_start = &src[start..];
184 let end_offset = find_item_end_offset(after_start);
185 (start, start + end_offset)
186}
187
/// Returns the byte offset just past the end of the item that starts at the
/// beginning of `src`.
///
/// The text is scanned byte by byte. Comments, string literals, raw strings,
/// byte/char literals and lifetimes are skipped so that their contents cannot
/// be mistaken for delimiters. An item ends at
///
/// * the first `;` found outside every `{}`, `()` and `[]`, or
/// * the `}` that closes the first top-level `{ ... }` block, plus a directly
///   following `;` if there is one (`use a::{b};`, `static X: T = T { .. };`).
///
/// Returns `src.len()` when no terminator is found.
///
/// Known limitation: this is a heuristic, not a parser. An item whose
/// expression continues after a top-level block (`const X: u8 = { 1 } + 2;`)
/// is still cut at the closing brace.
fn find_item_end_offset(src: &str) -> usize {
    let bytes = src.as_bytes();
    let len = bytes.len();
    let mut brace_depth: i32 = 0;
    let mut paren_depth: usize = 0;
    let mut bracket_depth: usize = 0;
    let mut has_block = false;
    let mut i = 0;

    while i < len {
        let ch = bytes[i];
        let next = bytes.get(i + 1).copied();

        // Line comment: skip to the end of the line.
        if ch == b'/' && next == Some(b'/') {
            while i < len && bytes[i] != b'\n' {
                i += 1;
            }
            continue;
        }

        // Block comment (these nest in Rust).
        if ch == b'/' && next == Some(b'*') {
            i = skip_block_comment(bytes, i + 2);
            continue;
        }

        // Raw string `r"..."` / `r#"..."#`, or a raw identifier `r#name`.
        if ch == b'r' && matches!(next, Some(b'"') | Some(b'#')) {
            let mut quote = i + 1;
            while quote < len && bytes[quote] == b'#' {
                quote += 1;
            }
            if quote < len && bytes[quote] == b'"' {
                let hash_count = quote - (i + 1);
                // Start scanning *after* the opening quote.
                i = skip_raw_string_body(bytes, quote + 1, hash_count);
            } else {
                // Raw identifier: just step over the `r`.
                i += 1;
            }
            continue;
        }

        // Ordinary (or byte) string literal.
        if ch == b'"' {
            i = skip_escaped_literal(bytes, i + 1, b'"');
            continue;
        }

        // Byte literal `b'x'`.
        if ch == b'b' && next == Some(b'\'') {
            i = skip_escaped_literal(bytes, i + 2, b'\'');
            continue;
        }

        // Char literal or lifetime.
        if ch == b'\'' {
            i = skip_char_or_lifetime(bytes, i);
            continue;
        }

        match ch {
            b'(' => paren_depth += 1,
            b')' => paren_depth = paren_depth.saturating_sub(1),
            b'[' => bracket_depth += 1,
            b']' => bracket_depth = bracket_depth.saturating_sub(1),
            b'{' => {
                brace_depth += 1;
                has_block = true;
            }
            b'}' => {
                brace_depth -= 1;
                if brace_depth == 0 && has_block && paren_depth == 0 && bracket_depth == 0 {
                    return end_with_trailing_semicolon(bytes, i + 1);
                }
            }
            // `;` inside `[T; N]` or an argument list does not end the item.
            b';' if brace_depth == 0 && paren_depth == 0 && bracket_depth == 0 => {
                return i + 1;
            }
            _ => {}
        }
        i += 1;
    }

    src.len()
}

/// Skips a (possibly nested) block comment whose body starts at `i`, just
/// after the opening `/*`. Returns the offset after the matching `*/`, or the
/// input length when the comment is unterminated.
fn skip_block_comment(bytes: &[u8], mut i: usize) -> usize {
    let len = bytes.len();
    let mut depth = 1_usize;
    while i < len && depth > 0 {
        if bytes[i] == b'/' && i + 1 < len && bytes[i + 1] == b'*' {
            depth += 1;
            i += 2;
        } else if bytes[i] == b'*' && i + 1 < len && bytes[i + 1] == b'/' {
            depth -= 1;
            i += 2;
        } else {
            i += 1;
        }
    }
    i
}

/// Skips a raw-string body starting at `i` (just after the opening quote).
/// The string ends at the first `"` followed by `hash_count` `#` bytes.
/// Backslashes are ordinary bytes inside raw strings.
fn skip_raw_string_body(bytes: &[u8], mut i: usize, hash_count: usize) -> usize {
    let len = bytes.len();
    while i < len {
        if bytes[i] == b'"' {
            let close = i + 1 + hash_count;
            if close <= len && bytes[i + 1..close].iter().all(|&b| b == b'#') {
                return close;
            }
        }
        i += 1;
    }
    len
}

/// Skips a literal with backslash escapes, starting at `i` (just after the
/// opening delimiter). Returns the offset after the closing `close` byte, or
/// the input length when the literal is unterminated.
fn skip_escaped_literal(bytes: &[u8], mut i: usize, close: u8) -> usize {
    let len = bytes.len();
    while i < len {
        if bytes[i] == close {
            return i + 1;
        }
        if bytes[i] == b'\\' && i + 1 < len {
            i += 2;
        } else {
            i += 1;
        }
    }
    i
}

/// Skips a char literal or a lifetime. `start` is the index of the `'`.
fn skip_char_or_lifetime(bytes: &[u8], start: usize) -> usize {
    let len = bytes.len();
    let mut i = start + 1;
    if i >= len {
        return i;
    }

    let first = bytes[i];
    if first == b'\\' {
        // Escaped char literal ('\n', '\'', '\x41', '\u{1F600}'): step over
        // the backslash and the escaped byte, then run to the closing quote
        // so the braces of `\u{..}` never reach the delimiter counting.
        i = (i + 2).min(len);
        while i < len && bytes[i] != b'\'' && bytes[i] != b'\n' {
            i += 1;
        }
    } else if first.is_ascii_alphabetic() || first == b'_' {
        // Either a lifetime ('a, 'static) or a one-letter char literal ('a').
        while i < len && (bytes[i].is_ascii_alphanumeric() || bytes[i] == b'_') {
            i += 1;
        }
    } else {
        // Any other single character; step over UTF-8 continuation bytes.
        i += 1;
        while i < len && (bytes[i] & 0xC0) == 0x80 {
            i += 1;
        }
    }

    // Consume the closing quote when this was a char literal.
    if i < len && bytes[i] == b'\'' {
        i += 1;
    }
    i
}

/// Given the offset just past a closing `}`, also includes a `;` that follows
/// it (ignoring whitespace in between), so `};` stays part of the item.
fn end_with_trailing_semicolon(bytes: &[u8], end: usize) -> usize {
    let mut j = end;
    while j < bytes.len() && bytes[j].is_ascii_whitespace() {
        j += 1;
    }
    if j < bytes.len() && bytes[j] == b';' {
        j + 1
    } else {
        end
    }
}
349
350fn read_developer_file(path: &std::path::Path) -> anyhow::Result<String> {
356 let canonical = path.canonicalize()?;
357 let mut file = std::fs::File::open(canonical)?; let mut buf = String::new();
359 std::io::Read::read_to_string(&mut file, &mut buf)?;
360 Ok(buf)
361}
362
363fn write_developer_file(path: &std::path::Path, content: &str) -> anyhow::Result<()> {
370 let canonical = path.canonicalize()?;
371 let parent = canonical
372 .parent()
373 .ok_or_else(|| anyhow::anyhow!("no parent directory for {}", canonical.display()))?;
374 let file_name = canonical
375 .file_name()
376 .and_then(|n| n.to_str())
377 .unwrap_or("rewrite");
378 let random_suffix: u32 = {
379 use std::time::{SystemTime, UNIX_EPOCH};
380 let nanos = SystemTime::now()
381 .duration_since(UNIX_EPOCH)
382 .unwrap_or_default()
383 .subsec_nanos();
384 nanos ^ std::process::id()
385 };
386 let tmp = parent.join(format!(".{file_name}.{random_suffix:x}.tmp")); if let Err(e) = std::fs::write(&tmp, content) {
388 let _ = std::fs::remove_file(&tmp);
389 return Err(e.into());
390 }
391 if let Err(e) = std::fs::rename(&tmp, canonical) {
392 let _ = std::fs::remove_file(&tmp);
393 return Err(e.into());
394 }
395 Ok(())
396}
397
#[cfg(test)]
mod tests {
    use super::*;
    use rstest::rstest;

    /// Parses `src`, lets `mutate` edit a clone of the AST, and returns the
    /// text produced by `apply_changes_preserving_formatting`.
    fn rewrite_with(src: &str, mutate: impl FnOnce(&mut syn::File)) -> String {
        let parsed: syn::File = syn::parse_file(src).unwrap();
        let mut out_ast = parsed.clone();
        mutate(&mut out_ast);
        apply_changes_preserving_formatting(src, &parsed, &out_ast)
    }

    /// Asserts, in order, that `haystack` contains each needle; the paired
    /// string is the failure message.
    fn assert_contains_all(haystack: &str, expectations: &[(&str, &str)]) {
        for (needle, message) in expectations {
            assert!(haystack.contains(needle), "{}", message);
        }
    }

    /// An untouched AST must reproduce the source byte for byte.
    #[rstest]
    fn no_changes_output_identical() {
        let src = "//! Module doc comment.\n\nuse std::collections::HashMap;\n\n/// A struct.\npub struct Foo {\n x: i32,\n}\n";

        let result = rewrite_with(src, |_| {});

        assert_eq!(result, src);
    }

    /// Renaming the first struct must keep every comment around the items.
    #[rstest]
    fn comments_between_items_preserved() {
        let src = "//! Module doc.\n\n// Comment before struct\npub struct Foo {\n x: i32,\n}\n\n// Comment between items\npub struct Bar {\n y: String,\n}\n";

        let result = rewrite_with(src, |file| {
            if let syn::Item::Struct(s) = &mut file.items[0] {
                s.ident = syn::Ident::new("Foo2", s.ident.span());
            }
        });

        assert_contains_all(
            &result,
            &[
                ("pub struct Foo2", "changed item not updated"),
                ("//! Module doc.", "module doc lost"),
                ("// Comment before struct", "comment before struct lost"),
                ("// Comment between items", "inter-item comment lost"),
                ("pub struct Bar", "unchanged item lost"),
            ],
        );
    }

    /// Replacing the middle `use` must not disturb the blank lines around it.
    #[rstest]
    fn blank_lines_between_items_preserved() {
        let src = "use std::io;\n\n\nuse std::fs;\n\n\n\nuse std::path;\n";

        let result = rewrite_with(src, |file| {
            if let syn::Item::Use(u) = &mut file.items[1] {
                *u = syn::parse_quote!(
                    use std::fs::File;
                );
            }
        });

        assert_contains_all(
            &result,
            &[
                ("use std::io;", "first use lost"),
                ("use std::fs::File;", "changed use not updated"),
                ("use std::path;", "third use lost"),
                ("use std::io;\n\n\n", "blank lines after first use altered"),
                (
                    "\n\n\nuse std::path;",
                    "blank lines before third use altered",
                ),
            ],
        );
    }

    /// Inner `//!` docs at the top of the file survive a change to an item.
    #[rstest]
    fn module_doc_comment_preserved() {
        let src = "//! Crate-level documentation.\n//! Second line.\n\npub fn foo() {}\n";

        let result = rewrite_with(src, |file| {
            if let syn::Item::Fn(f) = &mut file.items[0] {
                f.sig.ident = syn::Ident::new("bar", f.sig.ident.span());
            }
        });

        assert_contains_all(
            &result,
            &[
                ("//! Crate-level documentation.", "module doc lost"),
                ("//! Second line.", "second doc line lost"),
                ("pub fn bar", "renamed function missing"),
            ],
        );
    }

    /// Only the changed constant is rewritten, and item order is kept.
    #[rstest]
    fn only_changed_item_replaced() {
        let src = "pub const A: i32 = 1;\npub const B: i32 = 2;\npub const C: i32 = 3;\n";

        let result = rewrite_with(src, |file| {
            if let syn::Item::Const(c) = &mut file.items[1] {
                c.ident = syn::Ident::new("B_CHANGED", c.ident.span());
            }
        });

        assert_contains_all(
            &result,
            &[
                ("pub const A: i32 = 1;", "first item altered"),
                ("pub const B_CHANGED", "changed item not updated"),
                ("pub const C: i32 = 3;", "third item altered"),
            ],
        );

        let position_of = |needle: &str| result.find(needle).unwrap();
        let a_idx = position_of("pub const A");
        let b_idx = position_of("pub const B_CHANGED");
        let c_idx = position_of("pub const C");
        assert!(a_idx < b_idx && b_idx < c_idx, "item order changed");
    }
}