use std::path::PathBuf;
use clap::Args;
pub mod rewriter;
pub mod rules;
pub mod walker;
// Command-line options consumed by `run`.
//
// NOTE(review): this header is deliberately a plain comment. clap's derive can
// surface a struct-level doc comment as the command's `about` text; confirm how
// this struct is mounted before converting it to `///`.
#[derive(Args, Debug)]
pub struct MigrateV2Args {
    /// Path that is searched for Rust source files to rewrite.
    #[arg(default_value = ".")]
    pub path: PathBuf,
    /// Report which files would be rewritten without modifying anything.
    #[arg(long)]
    pub dry_run: bool,
    /// Comma-separated names of rewrite rules to skip; unknown names are rejected.
    #[arg(long, value_delimiter = ',')]
    pub skip: Vec<String>,
}
/// Applies every rewrite rule that was not skipped to each file found under `args.path`.
///
/// Each file is parsed, passed through the selected rules in order, and written
/// back only when the resulting text differs from what was read. With
/// `args.dry_run` set, the files that would change are listed but not written.
///
/// Files that fail to parse are left untouched; each one is reported on stderr
/// and a total is printed at the end, so a partially migrated tree is visible.
///
/// # Errors
///
/// Returns an error when `--skip` names a rule that does not exist, when the
/// file walk fails, or when a file cannot be read or written (the error carries
/// the offending path).
pub fn run(args: MigrateV2Args) -> anyhow::Result<()> {
    let all_rules = rules::all();
    let known_rule_names: std::collections::BTreeSet<&'static str> =
        all_rules.iter().map(|r| r.name()).collect();

    // Reject typos up front instead of silently running a rule the user meant to skip.
    let unknown: Vec<&str> = args
        .skip
        .iter()
        .map(String::as_str)
        .filter(|name| !known_rule_names.contains(name))
        .collect();
    if !unknown.is_empty() {
        anyhow::bail!("unknown --skip rule(s): {}", unknown.join(", "));
    }

    let rules: Vec<_> = all_rules
        .into_iter()
        .filter(|r| !args.skip.iter().any(|s| s == r.name()))
        .collect();

    let files = walker::find_rs_files(&args.path)?;
    let mut changed = 0_usize;
    let mut unparsed = 0_usize;

    for path in files {
        let src = read_developer_file(&path)
            .map_err(|e| e.context(format!("failed to read {}", path.display())))?;

        let parsed: syn::File = match syn::parse_file(&src) {
            Ok(f) => f,
            Err(err) => {
                // Best effort: one unparsable file must not abort the whole run,
                // but the user needs to know it was not migrated.
                eprintln!("warning: skipping {}: {}", path.display(), err);
                unparsed += 1;
                continue;
            }
        };

        let out_ast = rules
            .iter()
            .fold(parsed.clone(), |ast, rule| rule.rewrite(ast));
        let out = apply_changes_preserving_formatting(&src, &parsed, &out_ast);
        if out == src {
            continue;
        }

        changed += 1;
        if args.dry_run {
            println!("would rewrite: {}", path.display());
        } else {
            write_developer_file(&path, &out)
                .map_err(|e| e.context(format!("failed to write {}", path.display())))?;
            println!("rewrote: {}", path.display());
        }
    }

    println!(
        "\nDone. {} file(s) {}.",
        changed,
        if args.dry_run {
            "would change"
        } else {
            "changed"
        }
    );
    if unparsed > 0 {
        eprintln!(
            "warning: {} file(s) could not be parsed and were left untouched.",
            unparsed
        );
    }
    Ok(())
}
/// Rebuilds `src` so that only items changed between `parsed` and `out_ast` are re-printed.
///
/// Items are paired by position. For each pair the original item is located in
/// `src`; the text between the previous item and this one (comments, blank
/// lines, and the item's own attribute/doc lines) is copied verbatim. An item
/// whose pretty-printed form is unchanged is copied from `src` byte for byte; a
/// changed item is replaced by its pretty-printed form.
///
/// The located span starts at the item's first non-attribute line, so the
/// source's own attribute and doc lines are already part of the copied gap.
/// The replacement therefore must not emit them a second time:
/// * attributes unchanged: only the part of the new text after them is emitted;
/// * attributes changed: the source's trailing attribute/doc lines are dropped
///   from the gap and the complete new text is emitted.
///
/// NOTE(review): pairing is positional and stops at the shorter item list, so
/// items a rule adds are not emitted and items a rule removes stay in the
/// output. Fixing that needs span information rather than index pairing.
///
/// NOTE(review): when an item cannot be located its replacement is still
/// inserted at the current position while the original text is kept. Reporting
/// this properly needs an error return, which would change the signature.
fn apply_changes_preserving_formatting(
    src: &str,
    parsed: &syn::File,
    out_ast: &syn::File,
) -> String {
    let mut result = String::with_capacity(src.len() + 1024);
    let mut last_pos: usize = 0;

    for (orig_item, new_item) in parsed.items.iter().zip(&out_ast.items) {
        let formatted_orig = format_single_item(orig_item);
        let formatted_new = format_single_item(new_item);
        let (start_byte, end_byte) = find_item_in_source(src, last_pos, &formatted_orig);
        let gap = &src[last_pos..start_byte];

        if formatted_orig == formatted_new {
            result.push_str(gap);
            result.push_str(&src[start_byte..end_byte]);
        } else if start_byte == end_byte {
            // An empty span means the item was not located (see NOTE above);
            // there is no source attribute text to reconcile with.
            result.push_str(gap);
            result.push_str(&formatted_new);
        } else {
            let orig_body = item_body_offset(&formatted_orig);
            let new_body = item_body_offset(&formatted_new);
            if formatted_orig[..orig_body] == formatted_new[..new_body] {
                // Attributes untouched: keep the source's copy, emit the item only.
                result.push_str(gap);
                result.push_str(&formatted_new[new_body..]);
            } else {
                // Attributes changed: the new text carries the authoritative set.
                result.push_str(strip_trailing_attr_lines(gap));
                result.push_str(&formatted_new);
            }
        }
        last_pos = end_byte;
    }

    if last_pos < src.len() {
        result.push_str(&src[last_pos..]);
    }
    result
}

/// Byte offset of the first line of `formatted` that is neither blank, a `//` comment
/// (which includes `///` docs), nor a line starting with `#[`.
///
/// This mirrors the anchor-line rule used by `find_item_in_source`; the two must
/// stay in sync. Returns `formatted.len()` when no such line exists.
fn item_body_offset(formatted: &str) -> usize {
    let mut offset = 0;
    for line in formatted.split_inclusive('\n') {
        let trimmed = line.trim();
        let is_prefix_line =
            trimmed.is_empty() || trimmed.starts_with("//") || trimmed.starts_with("#[");
        if !is_prefix_line {
            return offset;
        }
        offset += line.len();
    }
    formatted.len()
}

/// Returns `gap` without the run of attribute (`#[`) and doc-comment (`///`) lines at its end.
///
/// The walk stops at the first line that is neither, so an ordinary comment or
/// a blank line between the attributes and the item ends the run early.
fn strip_trailing_attr_lines(gap: &str) -> &str {
    let mut keep = gap.len();
    for line in gap.split_inclusive('\n').rev() {
        let trimmed = line.trim();
        // `////` and longer are ordinary comments, not doc comments.
        let is_doc = trimmed.starts_with("///") && !trimmed.starts_with("////");
        if trimmed.starts_with("#[") || is_doc {
            keep -= line.len();
        } else {
            break;
        }
    }
    &gap[..keep]
}
/// Pretty-prints a single item.
///
/// The item is cloned into an otherwise empty `syn::File` (no shebang, no
/// file-level attributes) so that `prettyplease::unparse` can render it;
/// trailing whitespace is removed from the rendered text.
fn format_single_item(item: &syn::Item) -> String {
    let wrapper = syn::File {
        shebang: None,
        attrs: Vec::new(),
        items: vec![item.clone()],
    };
    prettyplease::unparse(&wrapper).trim_end().to_owned()
}
/// Locates an item in `src`, searching from byte offset `search_from`.
///
/// The anchor is the first line of `item_tokens` that is not blank, not a `//`
/// comment (this also covers `///` docs) and does not start with `#[`. The
/// first occurrence of that exact line text at or after `search_from` is taken
/// as the item's start, and `find_item_end_offset` decides where it ends.
///
/// Returns `(start, end)` as byte offsets into `src`. When there is no anchor
/// line, or the anchor text does not occur in the remaining source, the empty
/// span `(search_from, search_from)` is returned.
fn find_item_in_source(src: &str, search_from: usize, item_tokens: &str) -> (usize, usize) {
    item_tokens
        .lines()
        .find(|line| {
            let text = line.trim();
            !(text.is_empty() || text.starts_with("//") || text.starts_with("#["))
        })
        // The source is only sliced once an anchor exists, as before.
        .and_then(|needle| src[search_from..].find(needle))
        .map(|relative| {
            let begin = search_from + relative;
            (begin, begin + find_item_end_offset(&src[begin..]))
        })
        .unwrap_or((search_from, search_from))
}
/// Returns the byte length of the item that begins at the start of `src`.
///
/// This is a lexical scan, not a parse. Comments, string literals, raw strings,
/// byte/char literals and lifetimes are skipped so that their contents never
/// count as structure. The item ends at
/// * the `}` that closes its first top-level `{ ... }` block, plus a `;` that
///   directly follows it (`use a::{b, c};`, `const X: T = T { .. };`), or
/// * the first `;` that is outside every `{}`, `()` and `[]`
///   (so the `;` in `[u8; 4]` does not end the item).
///
/// Falls back to `src.len()` when no terminator is found.
fn find_item_end_offset(src: &str) -> usize {
    let bytes = src.as_bytes();
    let len = bytes.len();
    let mut brace_depth: i32 = 0;
    // Combined nesting of `(...)` and `[...]`.
    let mut group_depth: i32 = 0;
    let mut has_block = false;
    let mut i = 0;

    while i < len {
        let ch = bytes[i];

        // Trivia and literals: jump past them in one step.
        let skipped_to = match (ch, bytes.get(i + 1).copied()) {
            (b'/', Some(b'/')) => Some(skip_line_comment(bytes, i)),
            (b'/', Some(b'*')) => Some(skip_block_comment(bytes, i)),
            // `r#ident` is a raw identifier, not a string: step over the `r` only.
            (b'r', Some(b'"' | b'#')) => Some(raw_string_end(bytes, i).unwrap_or(i + 1)),
            (b'"', _) => Some(skip_quoted(bytes, i + 1, b'"')),
            (b'b', Some(b'\'')) => Some(skip_quoted(bytes, i + 2, b'\'')),
            (b'\'', _) => Some(skip_char_or_lifetime(bytes, i)),
            _ => None,
        };
        if let Some(next_index) = skipped_to {
            i = next_index;
            continue;
        }

        match ch {
            b'(' | b'[' => group_depth += 1,
            b')' | b']' => group_depth -= 1,
            b'{' => {
                brace_depth += 1;
                has_block = true;
            }
            b'}' => {
                brace_depth -= 1;
                if brace_depth == 0 && group_depth == 0 && has_block {
                    let after = i + 1;
                    let blanks = bytes[after..]
                        .iter()
                        .take_while(|b| b.is_ascii_whitespace())
                        .count();
                    // A `;` right after the block belongs to this item.
                    return match bytes.get(after + blanks) {
                        Some(b';') => after + blanks + 1,
                        _ => after,
                    };
                }
            }
            b';' if brace_depth == 0 && group_depth == 0 => return i + 1,
            _ => {}
        }
        i += 1;
    }
    src.len()
}

/// Index of the newline that ends the `//` comment starting at `start`, or the input length.
fn skip_line_comment(bytes: &[u8], start: usize) -> usize {
    bytes[start..]
        .iter()
        .position(|&b| b == b'\n')
        .map_or(bytes.len(), |relative| start + relative)
}

/// Index just past the block comment whose `/*` is at `start`; nested comments are honoured.
///
/// An unterminated comment consumes the rest of the input.
fn skip_block_comment(bytes: &[u8], start: usize) -> usize {
    let len = bytes.len();
    let mut depth = 0_usize;
    let mut i = start;
    while i + 1 < len {
        match (bytes[i], bytes[i + 1]) {
            (b'/', b'*') => {
                depth += 1;
                i += 2;
            }
            (b'*', b'/') => {
                // `depth` is at least 1 here because the scan starts on `/*`.
                depth -= 1;
                i += 2;
                if depth == 0 {
                    return i;
                }
            }
            _ => i += 1,
        }
    }
    len
}

/// Index just past the raw string literal whose `r` is at `start`.
///
/// Returns `None` when the `r` and its `#`s are not followed by `"` (a raw
/// identifier such as `r#type`). Backslashes have no special meaning inside a
/// raw string. An unterminated literal consumes the rest of the input.
fn raw_string_end(bytes: &[u8], start: usize) -> Option<usize> {
    let len = bytes.len();
    let mut i = start + 1;
    let mut hashes = 0_usize;
    while i < len && bytes[i] == b'#' {
        hashes += 1;
        i += 1;
    }
    if i >= len || bytes[i] != b'"' {
        return None;
    }
    i += 1; // first byte of the contents
    while i < len {
        if bytes[i] == b'"' {
            let tail = &bytes[i + 1..];
            if tail.len() >= hashes && tail[..hashes].iter().all(|&b| b == b'#') {
                return Some(i + 1 + hashes);
            }
        }
        i += 1;
    }
    Some(len)
}

/// Index just past the closing `quote` of a literal whose contents start at `from`.
///
/// A backslash escapes the byte that follows it. An unterminated literal
/// consumes the rest of the input.
fn skip_quoted(bytes: &[u8], from: usize, quote: u8) -> usize {
    let len = bytes.len();
    let mut i = from;
    while i < len {
        match bytes[i] {
            b if b == quote => return i + 1,
            b'\\' if i + 1 < len => i += 2,
            _ => i += 1,
        }
    }
    i
}

/// Index just past the char literal, lifetime or label whose `'` is at `start`.
fn skip_char_or_lifetime(bytes: &[u8], start: usize) -> usize {
    let len = bytes.len();
    let mut i = start + 1;
    if i >= len {
        return i;
    }
    let first = bytes[i];
    if first.is_ascii_alphabetic() || first == b'_' {
        // Identifier characters: a lifetime/label, or a char literal such as 'a'.
        while i < len && (bytes[i].is_ascii_alphanumeric() || bytes[i] == b'_') {
            i += 1;
        }
    } else if first == b'\\' {
        // Escape: skip the backslash and the escaped byte, then run to the
        // closing quote so the braces of `\u{..}` are never seen as a block.
        i = (i + 2).min(len);
        while i < len && bytes[i] != b'\'' && bytes[i] != b'\n' {
            i += 1;
        }
    } else {
        // One character, which may be several UTF-8 bytes long.
        i += 1;
        while i < len && (bytes[i] & 0xC0) == 0x80 {
            i += 1;
        }
    }
    // A lifetime is never directly followed by `'`, so this is a closing quote.
    if i < len && bytes[i] == b'\'' {
        i += 1;
    }
    i
}
/// Reads the file at `path` into a `String`.
///
/// The path is canonicalized first, so a path that does not resolve fails
/// before any read is attempted. Content that is not valid UTF-8 is an error.
fn read_developer_file(path: &std::path::Path) -> anyhow::Result<String> {
    let resolved = path.canonicalize()?;
    let text = std::fs::read_to_string(resolved)?;
    Ok(text)
}
/// Replaces the contents of the existing file at `path` with `content`.
///
/// The new contents are written to a hidden temporary file in the same
/// directory, which is then renamed over the target. The temporary file is
/// given the target's permissions before the rename, so the swap does not
/// reset the file mode to the process default.
///
/// # Errors
///
/// Fails when `path` cannot be canonicalized (it must already exist), when the
/// temporary file cannot be created, or when writing, copying permissions or
/// renaming fails. The temporary file is removed on every failure after it was
/// created.
fn write_developer_file(path: &std::path::Path, content: &str) -> anyhow::Result<()> {
    let canonical = path.canonicalize()?;
    let parent = canonical
        .parent()
        .ok_or_else(|| anyhow::anyhow!("no parent directory for {}", canonical.display()))?;
    let file_name = canonical
        .file_name()
        .and_then(|n| n.to_str())
        .unwrap_or("rewrite");
    // Not cryptographically random: clock nanoseconds mixed with the process id,
    // which is enough to keep concurrent runs from sharing a temp name.
    let random_suffix: u32 = {
        use std::time::{SystemTime, UNIX_EPOCH};
        let nanos = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap_or_default()
            .subsec_nanos();
        nanos ^ (std::process::id() as u32)
    };
    let tmp = parent.join(format!(".{file_name}.{random_suffix:x}.tmp"));

    // `create_new` refuses to reuse an existing file, so a name collision is an
    // error rather than a truncation. Nothing of ours exists yet if this fails,
    // hence no cleanup on this path.
    let mut file = std::fs::OpenOptions::new()
        .write(true)
        .create_new(true)
        .open(&tmp)?;

    let staged = (|| -> std::io::Result<()> {
        std::io::Write::write_all(&mut file, content.as_bytes())?;
        file.set_permissions(std::fs::metadata(&canonical)?.permissions())?;
        Ok(())
    })();
    // Close the handle before the rename.
    drop(file);

    if let Err(e) = staged.and_then(|()| std::fs::rename(&tmp, &canonical)) {
        let _ = std::fs::remove_file(&tmp);
        return Err(e.into());
    }
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;
    use rstest::rstest;

    /// Parses `src`, lets `edit` mutate a clone of the AST, and returns the re-assembled text.
    fn rewrite_with(src: &str, edit: impl FnOnce(&mut syn::File)) -> String {
        let parsed: syn::File = syn::parse_file(src).unwrap();
        let mut out_ast = parsed.clone();
        edit(&mut out_ast);
        apply_changes_preserving_formatting(src, &parsed, &out_ast)
    }

    /// Checks each `(needle, message)` pair in order, failing with `message`
    /// for the first needle that `text` does not contain.
    fn assert_contains_all(text: &str, expectations: &[(&str, &str)]) {
        for &(needle, message) in expectations {
            assert!(text.contains(needle), "{}", message);
        }
    }

    /// An untouched AST must reproduce the source byte for byte.
    #[rstest]
    fn no_changes_output_identical() {
        let src = "//! Module doc comment.\n\nuse std::collections::HashMap;\n\n/// A struct.\npub struct Foo {\n x: i32,\n}\n";
        let result = rewrite_with(src, |_| {});
        assert_eq!(result, src);
    }

    /// Renaming the first struct keeps every comment around and between the items.
    #[rstest]
    fn comments_between_items_preserved() {
        let src = "//! Module doc.\n\n// Comment before struct\npub struct Foo {\n x: i32,\n}\n\n// Comment between items\npub struct Bar {\n y: String,\n}\n";
        let result = rewrite_with(src, |ast| {
            if let syn::Item::Struct(s) = &mut ast.items[0] {
                s.ident = syn::Ident::new("Foo2", s.ident.span());
            }
        });
        assert_contains_all(
            &result,
            &[
                ("pub struct Foo2", "changed item not updated"),
                ("//! Module doc.", "module doc lost"),
                ("// Comment before struct", "comment before struct lost"),
                ("// Comment between items", "inter-item comment lost"),
                ("pub struct Bar", "unchanged item lost"),
            ],
        );
    }

    /// Replacing the middle `use` leaves the blank-line runs around it intact.
    #[rstest]
    fn blank_lines_between_items_preserved() {
        let src = "use std::io;\n\n\nuse std::fs;\n\n\n\nuse std::path;\n";
        let result = rewrite_with(src, |ast| {
            if let syn::Item::Use(u) = &mut ast.items[1] {
                *u = syn::parse_quote!(use std::fs::File;);
            }
        });
        assert_contains_all(
            &result,
            &[
                ("use std::io;", "first use lost"),
                ("use std::fs::File;", "changed use not updated"),
                ("use std::path;", "third use lost"),
                ("use std::io;\n\n\n", "blank lines after first use altered"),
                ("\n\n\nuse std::path;", "blank lines before third use altered"),
            ],
        );
    }

    /// Renaming the only function keeps the `//!` lines at the top of the file.
    #[rstest]
    fn module_doc_comment_preserved() {
        let src = "//! Crate-level documentation.\n//! Second line.\n\npub fn foo() {}\n";
        let result = rewrite_with(src, |ast| {
            if let syn::Item::Fn(f) = &mut ast.items[0] {
                f.sig.ident = syn::Ident::new("bar", f.sig.ident.span());
            }
        });
        assert_contains_all(
            &result,
            &[
                ("//! Crate-level documentation.", "module doc lost"),
                ("//! Second line.", "second doc line lost"),
                ("pub fn bar", "renamed function missing"),
            ],
        );
    }

    /// Changing the middle constant leaves its neighbours and the item order alone.
    #[rstest]
    fn only_changed_item_replaced() {
        let src = "pub const A: i32 = 1;\npub const B: i32 = 2;\npub const C: i32 = 3;\n";
        let result = rewrite_with(src, |ast| {
            if let syn::Item::Const(c) = &mut ast.items[1] {
                c.ident = syn::Ident::new("B_CHANGED", c.ident.span());
            }
        });
        assert_contains_all(
            &result,
            &[
                ("pub const A: i32 = 1;", "first item altered"),
                ("pub const B_CHANGED", "changed item not updated"),
                ("pub const C: i32 = 3;", "third item altered"),
            ],
        );
        let a_idx = result.find("pub const A").unwrap();
        let b_idx = result.find("pub const B_CHANGED").unwrap();
        let c_idx = result.find("pub const C").unwrap();
        assert!(a_idx < b_idx && b_idx < c_idx, "item order changed");
    }
}