//! diesel_cli 0.5.3
//!
//! Provides the CLI for the Diesel crate.
extern crate chrono;
extern crate clap;
extern crate diesel;
extern crate dotenv;

mod database_error;
#[macro_use]
mod database;

use chrono::*;
use clap::{App, AppSettings, Arg, ArgMatches, SubCommand};
#[cfg(feature = "postgres")]
use diesel::pg::PgConnection;
#[cfg(feature = "sqlite")]
use diesel::sqlite::SqliteConnection;
use diesel::migrations::schema::*;
use diesel::types::{FromSql, VarChar};
use diesel::{migrations, Connection, Insertable};
use std::error::Error;
use std::io::stdout;
use std::path::{PathBuf, Path};
use std::{env, fs};

use self::database_error::{DatabaseError, DatabaseResult};

fn main() {
    use self::dotenv::dotenv;
    // Load a `.env` file into the environment before parsing arguments; the
    // result is deliberately discarded, so a missing file is not an error.
    dotenv().ok();

    // Global, so every subcommand accepts `--database-url`.
    let database_url_arg = Arg::with_name("DATABASE_URL")
        .long("database-url")
        .help("Specifies the database URL to connect to. Falls back to \
               the DATABASE_URL environment variable if unspecified.")
        .global(true)
        .takes_value(true);

    // Global within `migration`, so each of its subcommands accepts it.
    let migration_dir_arg = Arg::with_name("MIGRATION_DIRECTORY")
        .long("migration-dir")
        .help("The location of your migration directory. By default this \
               will look for a directory called `migrations` in the \
               current directory and its parents.")
        .takes_value(true)
        .global(true);

    let generate_command = SubCommand::with_name("generate")
        .about("Generate a new migration with the given name, and \
                the current timestamp as the version")
        .arg(Arg::with_name("MIGRATION_NAME")
            .help("The name of the migration to create")
            .required(true))
        .arg(Arg::with_name("MIGRATION_VERSION")
            .long("version")
            .help("The version number to use when generating the migration. \
                   Defaults to the current timestamp, which should suffice \
                   for most use cases.")
            .takes_value(true));

    let migration_command = SubCommand::with_name("migration")
        .about("A group of commands for generating, running, and reverting \
                migrations.")
        .setting(AppSettings::VersionlessSubcommands)
        .arg(migration_dir_arg)
        .subcommand(SubCommand::with_name("run")
            .about("Runs all pending migrations"))
        .subcommand(SubCommand::with_name("revert")
            .about("Reverts the latest run migration"))
        .subcommand(SubCommand::with_name("redo")
            .about("Reverts and re-runs the latest migration. Useful \
                    for testing that a migration can in fact be reverted."))
        .subcommand(generate_command)
        .setting(AppSettings::SubcommandRequiredElseHelp);

    let setup_command = SubCommand::with_name("setup")
        .about("Creates the migrations directory, creates the database \
                specified in your DATABASE_URL, and runs existing migrations.");

    // `drop` is marked hidden, so it is omitted from the help listing.
    let database_command = SubCommand::with_name("database")
        .about("A group of commands for setting up and resetting your database.")
        .setting(AppSettings::VersionlessSubcommands)
        .subcommand(SubCommand::with_name("setup")
            .about("Creates the database specified in your DATABASE_URL, \
                    and then runs any existing migrations."))
        .subcommand(SubCommand::with_name("reset")
            .about("Resets your database by dropping the database specified \
                    in your DATABASE_URL and then running `diesel database setup`."))
        .subcommand(SubCommand::with_name("drop")
            .about("Drops the database specified in your DATABASE_URL.")
            .setting(AppSettings::Hidden))
        .setting(AppSettings::SubcommandRequiredElseHelp);

    let cli = App::new("diesel")
        .version(env!("CARGO_PKG_VERSION"))
        .setting(AppSettings::VersionlessSubcommands)
        .after_help("You can also run `diesel SUBCOMMAND -h` to get more information about that subcommand.")
        .arg(database_url_arg)
        .subcommand(migration_command)
        .subcommand(setup_command)
        .subcommand(database_command)
        .setting(AppSettings::SubcommandRequiredElseHelp);
    let matches = cli.get_matches();

    // SubcommandRequiredElseHelp guarantees one of the known subcommands.
    match matches.subcommand() {
        ("migration", Some(args)) => run_migration_command(args),
        ("setup", Some(args)) => run_setup_command(args),
        ("database", Some(args)) => run_database_command(args),
        _ => unreachable!("The cli parser should prevent reaching here"),
    }
}

/// Dispatches the `diesel migration <subcommand>` family.
///
/// `run`, `revert` and `redo` pass the URL returned by
/// `database::database_url(matches)` to `call_with_conn!`. `generate` only
/// touches the filesystem (see `generate_migration`). Failures are reported
/// through `handle_error`.
fn run_migration_command(matches: &ArgMatches) {
    match matches.subcommand() {
        ("run", Some(_)) => {
            let database_url = database::database_url(matches);
            call_with_conn!(database_url, migrations::run_pending_migrations)
                .unwrap_or_else(handle_error);
        }
        ("revert", Some(_)) => {
            let database_url = database::database_url(matches);
            call_with_conn!(database_url, migrations::revert_latest_migration)
                .unwrap_or_else(handle_error);
        }
        ("redo", Some(_)) => {
            let database_url = database::database_url(matches);
            // `redo_latest_migration` handles its own errors and returns (),
            // so there is nothing to unwrap here.
            call_with_conn!(database_url, redo_latest_migration);
        }
        ("generate", Some(args)) => generate_migration(args),
        _ => unreachable!("The cli parser should prevent reaching here"),
    }
}

/// Creates `<version>_<name>/` inside the migrations directory, containing
/// empty `up.sql` and `down.sql` files. Each file path is printed relative to
/// the current working directory.
///
/// I/O failures (e.g. the migration directory already exists) go through
/// `handle_error`, like every other command, instead of a bare `unwrap`.
fn generate_migration(args: &ArgMatches) {
    // MIGRATION_NAME is declared `.required(true)`, so clap guarantees it.
    let migration_name = args.value_of("MIGRATION_NAME")
        .expect("MIGRATION_NAME is a required argument");
    let version = migration_version(args);
    let versioned_name = format!("{}_{}", version, migration_name);
    let migration_dir = migrations_dir(args).join(versioned_name);
    fs::create_dir(&migration_dir).unwrap_or_else(handle_error);

    let current_dir = env::current_dir().unwrap_or_else(handle_error);
    let migration_dir_relative =
        convert_absolute_path_to_relative(&migration_dir, &current_dir);

    for file_name in ["up.sql", "down.sql"].iter() {
        println!("Creating {}", migration_dir_relative.join(file_name).display());
        fs::File::create(migration_dir.join(file_name)).unwrap_or_else(handle_error);
    }
}

use std::fmt::Display;
fn migration_version<'a>(matches: &'a ArgMatches) -> Box<Display + 'a> {
    // An explicit `--version` wins; it borrows from `matches`, hence `'a`.
    if let Some(requested) = matches.value_of("MIGRATION_VERSION") {
        return Box::new(requested);
    }
    // Otherwise use the local time formatted as YYYYMMDDHHMMSS.
    let timestamp = Local::now().format("%Y%m%d%H%M%S");
    Box::new(timestamp)
}

/// Resolves the migrations directory. Precedence:
/// 1. the `--migration-dir` flag,
/// 2. the `MIGRATION_DIRECTORY` environment variable,
/// 3. `migrations::find_migrations_directory()`. A failure here goes to `handle_error`.
fn migrations_dir(matches: &ArgMatches) -> PathBuf {
    if let Some(dir) = matches.value_of("MIGRATION_DIRECTORY") {
        return PathBuf::from(dir);
    }
    match env::var("MIGRATION_DIRECTORY") {
        Ok(dir) => PathBuf::from(dir),
        Err(_) => migrations::find_migrations_directory().unwrap_or_else(handle_error),
    }
}

/// Implements `diesel setup`.
///
/// If `migrations::find_migrations_directory` fails, a migrations directory is
/// created via `create_migrations_directory`. The database is then set up with
/// `database::setup_database`. Errors from either step go to `handle_error`.
fn run_setup_command(matches: &ArgMatches) {
    let has_migrations_dir = migrations::find_migrations_directory().is_ok();
    if !has_migrations_dir {
        create_migrations_directory().unwrap_or_else(handle_error);
    }

    database::setup_database(matches).unwrap_or_else(handle_error);
}

/// Dispatches `diesel database <setup|reset|drop>` to the matching function
/// in the `database` module, sending any error to `handle_error`.
fn run_database_command(matches: &ArgMatches) {
    let (name, sub_matches) = matches.subcommand();
    match (name, sub_matches) {
        ("setup", Some(args)) => {
            database::setup_database(args).unwrap_or_else(handle_error);
        }
        ("reset", Some(args)) => {
            database::reset_database(args).unwrap_or_else(handle_error);
        }
        ("drop", Some(args)) => {
            database::drop_database_command(args).unwrap_or_else(handle_error);
        }
        _ => unreachable!("The cli parser should prevent reaching here"),
    }
}

/// Creates a `migrations` directory next to the project's Cargo.toml. It also
/// creates an empty `.gitkeep` inside so git will pick the directory up.
///
/// The project root comes from `find_project_root`, which searches the current
/// directory and its parents for a Cargo.toml.
///
/// Returns the path of the newly created migrations directory. This makes it
/// interchangeable with `migrations::find_migrations_directory`, which also
/// yields the migrations directory itself.
///
/// # Errors
/// - `DatabaseError::CargoTomlNotFound` if no Cargo.toml is found.
/// - An error converted from the underlying `io::Error` if the directory or
///   `.gitkeep` cannot be created (for example, if the directory already exists).
fn create_migrations_directory() -> DatabaseResult<PathBuf> {
    let project_root = try!(find_project_root());
    let migrations_dir = project_root.join("migrations");
    println!("Creating migrations/ directory at: {}", migrations_dir.display());
    try!(fs::create_dir(&migrations_dir));
    try!(fs::File::create(migrations_dir.join(".gitkeep")));
    Ok(migrations_dir)
}

/// Returns the directory holding the project's Cargo.toml. The search starts
/// at the process's current working directory and moves up through its parents.
fn find_project_root() -> DatabaseResult<PathBuf> {
    let working_dir = try!(env::current_dir());
    search_for_cargo_toml_directory(&working_dir)
}

/// Walks from `path` up through its ancestors and returns the first directory
/// that contains a `Cargo.toml` file. Returns
/// `DatabaseError::CargoTomlNotFound` once there are no parents left to check.
fn search_for_cargo_toml_directory(path: &Path) -> DatabaseResult<PathBuf> {
    let mut candidate = path;
    loop {
        if candidate.join("Cargo.toml").is_file() {
            return Ok(candidate.to_owned());
        }
        match candidate.parent() {
            Some(parent) => candidate = parent,
            None => return Err(DatabaseError::CargoTomlNotFound),
        }
    }
}

/// Reverts the most recent migration, and then runs it again, all in a
/// transaction. If either part fails, the transaction is not committed.
fn redo_latest_migration<Conn>(conn: &Conn) where
        Conn: Connection,
        String: FromSql<VarChar, Conn::Backend>,
        for<'a> &'a NewMigration<'a>:
            Insertable<__diesel_schema_migrations::table, Conn::Backend>,
{
    conn.transaction(|| {
        let reverted_version = try!(migrations::revert_latest_migration(conn));
        migrations::run_migration_with_version(conn, &reverted_version, &mut stdout())
    }).unwrap_or_else(handle_error);
}

/// Reports `error` on stderr and terminates the process with exit status 1.
///
/// The errors routed here are user-facing (bad database URL, missing
/// Cargo.toml, filesystem failures), not program bugs. They are therefore
/// printed plainly instead of panicking, which would add a
/// "thread 'main' panicked" banner and exit with status 101.
///
/// The generic `T` lets this be passed to `Result::unwrap_or_else` whatever
/// the `Ok` type; the function never returns.
fn handle_error<E: Error, T>(error: E) -> T {
    use std::io::Write;
    use std::process;

    // If stderr itself is unwritable there is nothing better to do: we are
    // exiting with a failure status either way.
    let _ = writeln!(::std::io::stderr(), "{}", error);
    process::exit(1)
}

// Expresses `target_path` relative to `current_path`.
//
// Walks up from `current_path` one parent at a time, adding a `..` for each
// step, until it reaches an ancestor that `target_path` starts with. The rest
// of `target_path` below that ancestor is then appended. Targets below,
// beside, or above `current_path` are all handled.
//
// If the walk runs out of parents before finding such an ancestor,
// `target_path` is returned unchanged.
fn convert_absolute_path_to_relative(target_path: &Path, current_path: &Path) -> PathBuf {
    let mut ancestor = current_path;
    let mut levels_up = 0;

    while !target_path.starts_with(ancestor) {
        ancestor = match ancestor.parent() {
            Some(parent) => parent,
            None => return target_path.to_path_buf(),
        };
        levels_up += 1;
    }

    let mut relative: PathBuf = (0..levels_up).map(|_| "..").collect();
    // The loop only exits once `target_path` starts with `ancestor`.
    relative.push(target_path.strip_prefix(ancestor)
        .expect("target_path starts with ancestor"));
    relative
}

#[cfg(test)]
mod tests {
    extern crate tempdir;

    use database_error::DatabaseError;

    use self::tempdir::TempDir;

    use std::fs;
    use std::path::PathBuf;

    use super::convert_absolute_path_to_relative;
    use super::search_for_cargo_toml_directory;

    /// Creates a fresh `diesel`-prefixed temporary directory. Returns its
    /// `TempDir` handle together with the canonicalized path. Callers keep the
    /// handle bound for the whole test.
    fn canonical_temp_dir() -> (TempDir, PathBuf) {
        let handle = TempDir::new("diesel").unwrap();
        let path = handle.path().canonicalize().unwrap();
        (handle, path)
    }

    #[test]
    fn toml_directory_find_cargo_toml() {
        let (_handle, temp_path) = canonical_temp_dir();
        fs::File::create(temp_path.join("Cargo.toml")).unwrap();

        assert_eq!(Ok(temp_path.clone()), search_for_cargo_toml_directory(&temp_path));
    }

    #[test]
    fn cargo_toml_not_found_if_no_cargo_toml() {
        let (_handle, temp_path) = canonical_temp_dir();

        // NOTE(review): the search walks up through parent directories, so this
        // assumes no ancestor of the temp directory holds a Cargo.toml.
        assert_eq!(Err(DatabaseError::CargoTomlNotFound),
                   search_for_cargo_toml_directory(&temp_path));
    }

    #[test]
    fn convert_absolute_path_to_relative_works() {
        let target = PathBuf::from("projects/foo/migrations/12345_create_user");
        // (current directory, expected relative path to `target`)
        let cases = [
            ("projects/foo", "migrations/12345_create_user"),
            ("projects/foo/src", "../migrations/12345_create_user"),
            ("projects/foo/src/controllers/errors", "../../../migrations/12345_create_user"),
            ("projects/foo/migrations", "12345_create_user"),
            ("projects/foo/migrations/67890_create_post", "../12345_create_user"),
        ];

        for &(current, expected) in cases.iter() {
            assert_eq!(PathBuf::from(expected),
                       convert_absolute_path_to_relative(&target, &PathBuf::from(current)));
        }
    }
}