//! rustfmt_if_chain 0.1.4
//!
//! An `if_chain`-aware `rustfmt` (kind of).
use anyhow::{ensure, Result};
use if_chain::if_chain;
use proc_macro2::{Span, TokenStream, TokenTree};
use quote::{quote, ToTokens};
use sedregex::find_and_replace;
use std::{
    env,
    fs::{read_to_string, OpenOptions},
    io::Write,
    path::Path,
    process::{exit, Command},
};
use syn::{
    parse_file,
    spanned::Spanned,
    visit::{visit_item, Visit},
    Ident, Item, ItemMacro, Macro,
};

mod offset_based_rewriter;

mod offset_calculator;

mod backup;
use backup::Backup;

mod failed_to;
use failed_to::FailedTo;

mod rewriter;
use rewriter::Rewriter;

/// Entry point: formats every `.rs` file named on the command line.
///
/// When no source paths are given, the remaining arguments are passed to a
/// single `rustfmt` invocation. Otherwise each file goes through these steps:
///
/// 1. `rustfmt` is run on the unmodified file; a failure is either reported
///    as a warning (and the file skipped) or returned as an error, depending
///    on `--preformat-failure-is-warning`.
/// 2. A `Backup` of the file is created.
/// 3. `rewrite_if_chain` rewrites the file and returns the marker it used.
/// 4. `rustfmt` is run again on the rewritten file.
/// 5. `restore_if_chain` rewrites the file once more using that marker.
/// 6. The backup is disabled.
///
/// # Errors
///
/// Returns the first error produced by any of the steps above.
fn main() -> Result<()> {
    let (args, paths, preformat_failure_is_warning) = process_args();

    // No source files: act as a thin wrapper around `rustfmt`.
    if paths.is_empty() {
        return rustfmt(&args, None);
    }

    for raw_path in &paths {
        let source = Path::new(raw_path);

        // First pass, on the file exactly as the user wrote it.
        match rustfmt(&args, Some(source)) {
            Ok(()) => {}
            Err(error) if preformat_failure_is_warning => {
                eprintln!("Warning: {}", error);
                continue;
            }
            Err(error) => return Err(error),
        }

        // Keep the binding alive until the file has been restored; the backup
        // is only disabled after every later step has succeeded.
        let mut backup = Backup::new(source).failed_to(|| format!("backup {:?}", source))?;

        let marker = rewrite_if_chain(source)?;
        rustfmt(&args, Some(source))?;
        restore_if_chain(source, &marker)?;

        backup
            .disable()
            .failed_to(|| format!("disable {:?} backup", source))?;
    }

    Ok(())
}

/// Splits the command-line arguments (excluding the program name) into three
/// parts, returned in this order:
///
/// * arguments to forward to `rustfmt`,
/// * source paths, i.e. arguments whose lowercased form ends with `.rs`,
/// * whether `--preformat-failure-is-warning` was given.
///
/// `--help` and `-h` print the usage text and terminate the process.
// The extension test runs on a lowercased copy of the argument, so the
// comparison the lint complains about is intentional.
#[allow(clippy::case_sensitive_file_extension_comparisons)]
fn process_args() -> (Vec<String>, Vec<String>, bool) {
    let mut forwarded = Vec::new();
    let mut sources = Vec::new();
    let mut preformat_failure_is_warning = false;

    for arg in env::args().skip(1) {
        if matches!(arg.as_str(), "--help" | "-h") {
            // Never returns.
            usage();
        }
        if arg == "--preformat-failure-is-warning" {
            preformat_failure_is_warning = true;
            continue;
        }
        let bucket = if arg.to_lowercase().ends_with(".rs") {
            &mut sources
        } else {
            &mut forwarded
        };
        bucket.push(arg);
    }

    (forwarded, sources, preformat_failure_is_warning)
}

/// Help text printed by `usage`.
///
/// The exact wording and line breaks matter: the `usage_wrapping` test
/// re-wraps this text and compares the result with the constant, and the
/// `readme_contains_usage` test requires `README.md` to contain it verbatim.
const USAGE: &str = "\
Usage: rustfmt_if_chain [ARGS]

Arguments ending with `.rs` are considered source files and are
formatted. All other arguments are forwarded to `rustfmt`, with one
exception.

The one argument not forwarded to `rustfmt` is
`--preformat-failure-is-warning`. If this option is passed and `rustfmt`
fails on an unmodified source file, a warning results instead of an
error.\
";

/// Prints `USAGE` to standard output and terminates the process with exit
/// status 0.
fn usage() -> ! {
    println!("{}", USAGE);
    exit(0)
}

/// Rewrites the file at `path` in place, letting `RewriteVisitor` replace its
/// `if_chain!` invocations, and returns the marker identifier that was used.
///
/// The marker is chosen by `unused_ident` from the file's original contents
/// and must later be handed to `restore_if_chain`.
///
/// # Errors
///
/// Fails if the file cannot be read, does not parse as a Rust source file, or
/// cannot be opened or written. The file is only truncated after parsing has
/// succeeded.
fn rewrite_if_chain(path: &Path) -> Result<Ident> {
    let original = read_to_string(path).failed_to(|| format!("read from {:?}", path))?;
    let marker = unused_ident(&original);
    let syntax = parse_file(&original)?;

    let mut visitor = RewriteVisitor {
        rewriter: Rewriter::new(&original),
        marker: &marker,
    };
    visitor.visit_file(&syntax);

    // Overwrite the existing file with the rewritten text.
    let mut output = OpenOptions::new()
        .truncate(true)
        .write(true)
        .open(path)
        .failed_to(|| format!("open {:?}", path))?;
    output
        .write_all(visitor.rewriter.contents().as_bytes())
        .failed_to(|| format!("write to {:?}", path))?;

    Ok(marker)
}

/// Returns the first identifier of the form `x0`, `x1`, `x2`, ... whose text
/// does not occur anywhere in `contents` (as a plain substring).
fn unused_ident(contents: &str) -> Ident {
    let name = (0..)
        .map(|index| format!("x{}", index))
        .find(|candidate| !contents.contains(candidate.as_str()))
        .expect("an open-ended range never runs out of candidates");
    Ident::new(&name, Span::call_site())
}

/// `syn` visitor that records replacements for every `if_chain!` item it
/// encounters.
struct RewriteVisitor<'rewrite> {
    /// Collects the replacements; `rewrite_if_chain` writes its `contents()`
    /// back to the file once the visit is done.
    rewriter: Rewriter<'rewrite>,
    /// Identifier spliced into the stand-in syntax that replaces the
    /// `if_chain!` pieces.
    marker: &'rewrite Ident,
}

impl<'ast, 'rewrite> Visit<'ast> for RewriteVisitor<'rewrite> {
    /// Handles one item.
    ///
    /// For an `if_chain!` invocation (as recognised by `match_if_chain`), the
    /// span covering the macro path through the `!` is replaced with
    /// `fn <marker>()`, the macro's tokens are passed to `rewrite_tokens`, and
    /// the item is not descended into. Any other item is visited normally.
    fn visit_item(&mut self, item: &Item) {
        match match_if_chain(item) {
            Some((span, tokens)) => {
                let marker = self.marker;
                let header = quote! { fn #marker() }.to_string();
                self.rewrite(span, &header);
                self.rewrite_tokens(tokens);
            }
            None => visit_item(self, item),
        }
    }
}

impl<'rewrite> RewriteVisitor<'rewrite> {
    /// Rewrites the top-level tokens of one `if_chain!` body.
    ///
    /// Only a `;` that is immediately followed by the identifier `if`, `let`
    /// or `then` is acted on:
    ///
    /// * unless the statement it terminates started with `let`, the `;` is
    ///   replaced with `{ <marker>; }`;
    /// * if the following identifier is `then`, that identifier is replaced
    ///   with `if <marker>` and the method returns.
    ///
    /// # Panics
    ///
    /// Panics if the tokens run out before a `;` followed by `then` is seen.
    fn rewrite_tokens(&mut self, tokens: &TokenStream) {
        let marker = self.marker;
        let clause_body = quote! { { #marker; } }.to_token_stream().to_string();
        let then_header = quote! { if #marker }.to_token_stream().to_string();

        let mut trees = tokens.clone().into_iter().peekable();

        // Tracks whether the statement currently being scanned began with `let`.
        let mut curr_ends_let = match trees.peek() {
            Some(TokenTree::Ident(first)) => first == "let",
            _ => false,
        };

        while let Some(tree) = trees.next() {
            if !matches!(&tree, TokenTree::Punct(punct) if punct.as_char() == ';') {
                continue;
            }
            let keyword = match trees.peek() {
                Some(TokenTree::Ident(next)) => next,
                _ => continue,
            };
            let name = keyword.to_string();
            if !matches!(name.as_str(), "if" | "let" | "then") {
                continue;
            }

            if !curr_ends_let {
                self.rewrite(tree.span(), &clause_body);
            }
            if name == "then" {
                self.rewrite(keyword.span(), &then_header);
                return;
            }
            curr_ends_let = name == "let";
        }

        panic!("`if_chain!` without `then`");
    }

    /// Forwards a single replacement of `span` by `replacement` to the
    /// underlying `Rewriter`.
    fn rewrite(&mut self, span: Span, replacement: &str) {
        self.rewriter.rewrite(span, replacement);
    }
}

/// Rewrites the file at `path` in place, turning the stand-in syntax built
/// around `marker` back into `if_chain!` syntax.
///
/// Three substitutions are applied, in order:
///
/// 1. `fn <marker>()` becomes `if_chain!`;
/// 2. `{ <marker>; }`, together with any whitespace before it, becomes `;`;
/// 3. `if <marker>` becomes `then`.
///
/// # Errors
///
/// Fails if the file cannot be read, a substitution cannot be applied, or the
/// file cannot be opened or written.
fn restore_if_chain(path: &Path, marker: &Ident) -> Result<()> {
    let formatted = read_to_string(path).failed_to(|| format!("read from {:?}", path))?;

    let commands = [
        format!(r#"s/(?m)\bfn\s+{}\s*\(\)/if_chain!/g"#, marker),
        format!(r#"s/(?m)\s*\{{\s*{}\s*;\s*}}/;/g"#, marker),
        format!(r#"s/(?m)\bif\s+{}/then/g"#, marker),
    ];
    let restored = find_and_replace(&formatted, &commands)?;

    // Overwrite the existing file with the restored text.
    let mut output = OpenOptions::new()
        .truncate(true)
        .write(true)
        .open(path)
        .failed_to(|| format!("open {:?}", path))?;
    output
        .write_all(restored.as_bytes())
        .failed_to(|| format!("write to {:?}", path))?;

    Ok(())
}

/// Runs `rustfmt` with `args`, followed by `path` when one is given, and
/// waits for it to finish.
///
/// # Errors
///
/// Fails if the exit status cannot be obtained or if it is not a success.
fn rustfmt(args: &[String], path: Option<&Path>) -> Result<()> {
    let mut command = Command::new("rustfmt");
    // An `Option` iterates over zero or one item, so the path is appended
    // only when it is present.
    command.args(args).args(path);

    let status = command
        .status()
        .failed_to(|| format!("get status of {:?}", command))?;
    ensure!(status.success(), "failed to format {:?}", path);

    Ok(())
}

/// Recognises an `if_chain!` macro item.
///
/// Returns `None` unless `item` is a macro invocation whose path consists of
/// exactly one segment named `if_chain`. On a match, returns the span running
/// from the start of the macro path through the `!` token, together with the
/// macro's tokens.
///
/// # Panics
///
/// Panics if the spans of the path and of the `!` token cannot be joined.
fn match_if_chain(item: &Item) -> Option<(Span, &TokenStream)> {
    if_chain! {
        if let Item::Macro(ItemMacro { mac, .. }) = item;
        let Macro { path, bang_token, tokens, .. } = mac;
        // Inspect the single segment directly instead of collecting every
        // segment name into temporary vectors.
        if path.segments.len() == 1;
        if let Some(segment) = path.segments.first();
        if segment.ident == "if_chain";
        then {
            Some((
                path.span()
                    .join(bang_token.span())
                    .expect("`path` and `bang_token` should be from the same file"),
                tokens,
            ))
        } else {
            None
        }
    }
}

/// Checks that `USAGE` is wrapped exactly as the re-wrapping below would
/// wrap it.
///
/// The text is first unwrapped by replacing every single whitespace character
/// that sits between two non-whitespace characters with a space. It is then
/// re-wrapped by repeatedly replacing, on each line, the whitespace character
/// that follows the longest run of at most 72 characters with a newline,
/// until the text stops changing.
#[test]
fn usage_wrapping() {
    let mut rewrapped =
        find_and_replace(USAGE, [r#"s/(?P<left>\S)\s(?P<right>\S)/$left $right/g"#])
            .unwrap()
            .to_string();
    loop {
        let next = find_and_replace(
            &rewrapped,
            [r#"s/(?m)^(?P<line>.{0,72})\s/$line
/g"#],
        )
        .unwrap()
        .to_string();
        // Fixed point reached: another pass would change nothing.
        if next == rewrapped {
            break;
        }
        rewrapped = next;
    }
    assert_eq!(USAGE, rewrapped);
}

/// Checks that `README.md` in the crate's manifest directory contains the
/// `USAGE` text verbatim.
#[test]
fn readme_contains_usage() {
    let readme_path = Path::new(env!("CARGO_MANIFEST_DIR")).join("README.md");
    let readme = read_to_string(readme_path).unwrap();
    assert!(readme.contains(USAGE));
}