#![allow(clippy::map_unwrap_or)]
#![warn(
clippy::if_not_else,
clippy::items_after_statements,
clippy::mut_mut,
clippy::non_ascii_literal,
clippy::similar_names,
clippy::unicode_not_nfc,
clippy::used_underscore_binding,
missing_copy_implementations
)]
#![cfg_attr(test, allow(clippy::unwrap_used))]
mod config;
mod database_error;
#[macro_use]
mod database;
mod cli;
mod infer_schema_internals;
mod print_schema;
#[cfg(any(feature = "postgres", feature = "mysql"))]
mod query_helper;
use chrono::*;
use clap::ArgMatches;
use clap_complete::{generate, Shell};
use diesel::backend::Backend;
use diesel::migration::MigrationSource;
use diesel::Connection;
use diesel_migrations::{FileBasedMigrations, HarnessWithOutput, MigrationError, MigrationHarness};
use regex::Regex;
use std::any::Any;
use std::collections::{HashMap, HashSet};
use std::error::Error;
use std::fmt::Display;
use std::io::stdout;
use std::path::{Path, PathBuf};
use std::{env, fs};
use self::config::Config;
use self::database_error::{DatabaseError, DatabaseResult};
/// `chrono` format string used to prefix generated migration directory names
/// (e.g. `2024-01-31-123456_create_users`); sortable lexicographically.
pub static TIMESTAMP_FORMAT: &str = "%Y-%m-%d-%H%M%S";
/// Entry point for the diesel CLI binary.
///
/// Loads environment variables from a `.env` file when present, parses the
/// command line, and dispatches to the matching subcommand handler. Handlers
/// returning `Result` report failures through `handle_error`, which prints
/// the error and exits with status 1.
fn main() {
    use dotenvy::dotenv;
    // Best effort: a missing `.env` file is not an error.
    dotenv().ok();
    let matches = cli::build_cli().get_matches();
    // `subcommand_required` is expected to be set in `cli::build_cli`,
    // so `subcommand()` returning `None` would be a CLI-definition bug.
    match matches.subcommand().unwrap() {
        ("migration", matches) => run_migration_command(matches).unwrap_or_else(handle_error),
        ("setup", matches) => run_setup_command(matches),
        ("database", matches) => run_database_command(matches).unwrap_or_else(handle_error),
        ("completions", matches) => generate_completions_command(matches),
        ("print-schema", matches) => run_infer_schema(matches).unwrap_or_else(handle_error),
        _ => unreachable!("The cli parser should prevent reaching here"),
    }
}
/// Dispatches the `diesel migration <subcommand>` family of commands.
///
/// Supported subcommands: `run`, `revert`, `redo`, `list`, `pending`, and
/// `generate`. Database-touching subcommands open a connection through the
/// `call_with_conn!` macro and operate on the migrations directory resolved
/// by [`migrations_dir`] (CLI flag, environment, config file, or search).
///
/// # Errors
/// Returns any error raised while running or reverting migrations, or while
/// regenerating the schema file afterwards.
fn run_migration_command(
    matches: &ArgMatches,
) -> Result<(), Box<dyn Error + Send + Sync + 'static>> {
    match matches.subcommand().unwrap() {
        ("run", _) => {
            // Apply all pending migrations, printing each as it runs.
            let database_url = database::database_url(matches);
            let dir = migrations_dir(matches).unwrap_or_else(handle_error);
            let dir = FileBasedMigrations::from_path(dir).unwrap_or_else(handle_error);
            call_with_conn!(database_url, run_migrations_with_output(dir))?;
            regenerate_schema_if_file_specified(matches)?;
        }
        ("revert", args) => {
            let database_url = database::database_url(matches);
            let dir = migrations_dir(matches).unwrap_or_else(handle_error);
            let dir = FileBasedMigrations::from_path(dir).unwrap_or_else(handle_error);
            if args.contains_id("REVERT_ALL") {
                call_with_conn!(database_url, revert_all_migrations_with_output(dir))?;
            } else {
                // Revert the last REVERT_NUMBER migrations one at a time,
                // stopping early once nothing is left to revert.
                let number = args.get_one::<u64>("REVERT_NUMBER").unwrap();
                for _ in 0..*number {
                    match call_with_conn!(database_url, revert_migration_with_output(dir.clone())) {
                        Ok(_) => {}
                        Err(e) if e.is::<MigrationError>() => {
                            match e.downcast_ref::<MigrationError>() {
                                // Fewer applied migrations than requested:
                                // treated as success, not an error.
                                Some(MigrationError::NoMigrationRun) => break,
                                _ => return Err(e),
                            }
                        }
                        Err(e) => return Err(e),
                    }
                }
            }
            regenerate_schema_if_file_specified(matches)?;
        }
        ("redo", args) => {
            let database_url = database::database_url(matches);
            let dir = migrations_dir(matches).unwrap_or_else(handle_error);
            let dir = FileBasedMigrations::from_path(dir).unwrap_or_else(handle_error);
            call_with_conn!(database_url, redo_migrations(dir, args));
            regenerate_schema_if_file_specified(matches)?;
        }
        ("list", _) => {
            let database_url = database::database_url(matches);
            let dir = migrations_dir(matches).unwrap_or_else(handle_error);
            let dir = FileBasedMigrations::from_path(dir).unwrap_or_else(handle_error);
            call_with_conn!(database_url, list_migrations(dir))?;
        }
        ("pending", _) => {
            let database_url = database::database_url(matches);
            let dir = migrations_dir(matches).unwrap_or_else(handle_error);
            let dir = FileBasedMigrations::from_path(dir).unwrap_or_else(handle_error);
            // Prints `true`/`false` depending on whether unapplied
            // migrations exist.
            let result =
                call_with_conn!(database_url, MigrationHarness::has_pending_migration(dir))?;
            println!("{:?}", result);
        }
        ("generate", args) => {
            // Create `<version>_<name>/` under the migrations directory and
            // populate it according to MIGRATION_FORMAT.
            let migration_name = args.get_one::<String>("MIGRATION_NAME").unwrap();
            let version = migration_version(args);
            let versioned_name = format!("{}_{}", version, migration_name);
            let migration_dir = migrations_dir(matches)
                .unwrap_or_else(handle_error)
                .join(versioned_name);
            fs::create_dir(&migration_dir).unwrap();
            match args
                .get_one::<String>("MIGRATION_FORMAT")
                .map(|s| s as &str)
            {
                Some("sql") => generate_sql_migration(
                    &migration_dir,
                    !args.contains_id("MIGRATION_NO_DOWN_FILE"),
                ),
                Some(x) => return Err(format!("Unrecognized migration format `{}`", x).into()),
                None => unreachable!("MIGRATION_FORMAT has a default value"),
            }
        }
        _ => unreachable!("The cli parser should prevent reaching here"),
    };
    Ok(())
}
/// Writes the `up.sql` (and optionally `down.sql`) skeleton files into a
/// freshly created migration directory.
///
/// `path` must point at an existing directory. `with_down` controls whether
/// a `down.sql` stub is generated alongside `up.sql`. The created paths are
/// echoed relative to the current working directory for nicer CLI output.
///
/// # Panics
/// Panics (with a message naming the offending path) if the current working
/// directory cannot be determined or a skeleton file cannot be written.
/// Previously these were bare `unwrap()`s, which produced context-free
/// panic messages on I/O failure.
fn generate_sql_migration(path: &Path, with_down: bool) {
    use std::io::Write;
    let current_dir = env::current_dir().expect("Failed to get the current working directory");
    let migration_dir_relative = convert_absolute_path_to_relative(path, &current_dir);

    let up_path = path.join("up.sql");
    println!(
        "Creating {}",
        migration_dir_relative.join("up.sql").display()
    );
    let mut up = fs::File::create(&up_path)
        .unwrap_or_else(|e| panic!("Failed to create {}: {}", up_path.display(), e));
    up.write_all(b"-- Your SQL goes here")
        .unwrap_or_else(|e| panic!("Failed to write {}: {}", up_path.display(), e));

    if with_down {
        let down_path = path.join("down.sql");
        println!(
            "Creating {}",
            migration_dir_relative.join("down.sql").display()
        );
        let mut down = fs::File::create(&down_path)
            .unwrap_or_else(|e| panic!("Failed to create {}: {}", down_path.display(), e));
        down.write_all(b"-- This file should undo anything in `up.sql`")
            .unwrap_or_else(|e| panic!("Failed to write {}: {}", down_path.display(), e));
    }
}
/// Returns the version prefix for a newly generated migration: the value of
/// the `MIGRATION_VERSION` argument when given, otherwise the current UTC
/// time formatted with [`TIMESTAMP_FORMAT`].
fn migration_version<'a>(matches: &'a ArgMatches) -> Box<dyn Display + 'a> {
    match matches.get_one::<String>("MIGRATION_VERSION") {
        Some(version) => Box::new(version),
        None => Box::new(Utc::now().format(TIMESTAMP_FORMAT)),
    }
}
/// Resolves `--migration-dir` from the command line, searching the current
/// argument level first and then recursing into the active subcommand.
fn migrations_dir_from_cli(matches: &ArgMatches) -> Option<PathBuf> {
    if let Some(dir) = matches.get_one::<PathBuf>("MIGRATION_DIRECTORY") {
        return Some(dir.clone());
    }
    // Not set here; the flag may have been given on a nested subcommand.
    let (_, sub_matches) = matches.subcommand()?;
    migrations_dir_from_cli(sub_matches)
}
/// Runs every pending migration on `conn`, echoing progress to stdout.
///
/// # Errors
/// Propagates any error produced while applying a migration.
fn run_migrations_with_output<Conn, DB>(
    conn: &mut Conn,
    migrations: FileBasedMigrations,
) -> Result<(), Box<dyn Error + Send + Sync + 'static>>
where
    Conn: MigrationHarness<DB> + Connection<Backend = DB> + 'static,
    DB: Backend,
{
    // Wrap the connection so each migration's name is printed as it runs;
    // the list of applied versions is not needed here.
    let mut harness = HarnessWithOutput::write_to_stdout(conn);
    harness.run_pending_migrations(migrations)?;
    Ok(())
}
/// Reverts every applied migration on `conn`, echoing progress to stdout.
///
/// # Errors
/// Propagates any error produced while reverting a migration.
fn revert_all_migrations_with_output<Conn, DB>(
    conn: &mut Conn,
    migrations: FileBasedMigrations,
) -> Result<(), Box<dyn Error + Send + Sync + 'static>>
where
    Conn: MigrationHarness<DB> + Connection<Backend = DB> + 'static,
    DB: Backend,
{
    // The returned list of reverted versions is intentionally discarded.
    let mut harness = HarnessWithOutput::write_to_stdout(conn);
    harness.revert_all_migrations(migrations)?;
    Ok(())
}
/// Reverts only the most recently applied migration on `conn`, echoing
/// progress to stdout.
///
/// # Errors
/// Propagates any error produced while reverting, including the case where
/// no migration has been applied yet (handled by callers).
fn revert_migration_with_output<Conn, DB>(
    conn: &mut Conn,
    migrations: FileBasedMigrations,
) -> Result<(), Box<dyn Error + Send + Sync + 'static>>
where
    Conn: MigrationHarness<DB> + Connection<Backend = DB> + 'static,
    DB: Backend,
{
    let mut harness = HarnessWithOutput::write_to_stdout(conn);
    harness.revert_last_migration(migrations)?;
    Ok(())
}
/// Prints every known migration in version order, marking applied ones with
/// an `X`.
///
/// # Errors
/// Fails if the applied-migrations table cannot be read or the migration
/// sources on disk cannot be loaded.
fn list_migrations<Conn, DB>(
    conn: &mut Conn,
    migrations: FileBasedMigrations,
) -> Result<(), Box<dyn Error + Send + Sync + 'static>>
where
    Conn: MigrationHarness<DB> + Connection<Backend = DB> + 'static,
    DB: Backend,
{
    // Versions already recorded in the database's migrations table.
    let applied: HashSet<_> = conn.applied_migrations()?.into_iter().collect();
    let mut all_migrations = MigrationSource::<DB>::migrations(&migrations)?;
    all_migrations.sort_unstable_by(|a, b| a.name().version().cmp(&b.name().version()));
    println!("Migrations:");
    for migration in all_migrations {
        let name = migration.name();
        let marker = if applied.contains(&name.version()) {
            'X'
        } else {
            ' '
        };
        println!("  [{}] {}", marker, name);
    }
    Ok(())
}
/// Resolves the migrations directory to operate on.
///
/// Sources are consulted lazily, in priority order:
/// 1. the `--migration-dir` CLI flag (including nested subcommands),
/// 2. the `MIGRATION_DIRECTORY` environment variable,
/// 3. the `migrations_directory` entry of the config file,
/// 4. an upward filesystem search performed by `FileBasedMigrations`.
///
/// # Errors
/// Returns the error from the filesystem search when all other sources are
/// absent and no directory can be found.
fn migrations_dir(matches: &ArgMatches) -> Result<PathBuf, MigrationError> {
    if let Some(dir) = migrations_dir_from_cli(matches) {
        return Ok(dir);
    }
    if let Ok(dir) = env::var("MIGRATION_DIRECTORY") {
        return Ok(PathBuf::from(dir));
    }
    // A malformed config file aborts the process via `handle_error`.
    if let Some(migrations_directory) = Config::read(matches)
        .unwrap_or_else(handle_error)
        .migrations_directory
    {
        return Ok(migrations_directory.dir);
    }
    FileBasedMigrations::find_migrations_directory().map(|p| p.path().to_path_buf())
}
/// Implements `diesel setup`: creates the config file and migrations
/// directory if needed, then sets up the database.
///
/// Any failure aborts the process via `handle_error`. Order matters: the
/// config file and directory must exist before the database setup step runs
/// migrations from it.
fn run_setup_command(matches: &ArgMatches) {
    create_config_file(matches).unwrap_or_else(handle_error);
    let dir = create_migrations_dir(matches).unwrap_or_else(handle_error);
    database::setup_database(matches, &dir).unwrap_or_else(handle_error);
}
/// Ensures a migrations directory exists and returns its path.
///
/// Uses the directory resolved by [`migrations_dir`] when possible, falling
/// back to `<project root>/migrations`. A newly created directory is
/// initialized by [`create_migrations_directory`]; if the directory already
/// exists, a stray `.gitkeep` file inside it is removed (best effort — a
/// failure only prints a warning).
///
/// # Errors
/// Fails if the directory needs to be created and creation fails.
fn create_migrations_dir(matches: &ArgMatches) -> DatabaseResult<PathBuf> {
    let dir = migrations_dir(matches).unwrap_or_else(|_| {
        find_project_root()
            .unwrap_or_else(handle_error)
            .join("migrations")
    });

    if dir.exists() {
        if let Ok(entries) = fs::read_dir(&dir) {
            // Locate a plain file literally named `.gitkeep`, if any.
            let gitkeep = entries.filter_map(Result::ok).find(|entry| {
                entry
                    .file_type()
                    .map(|t| t.is_file() && entry.file_name() == ".gitkeep")
                    .unwrap_or(false)
            });
            if let Some(entry) = gitkeep {
                fs::remove_file(entry.path()).unwrap_or_else(|err| {
                    eprintln!(
                        "WARNING: Unable to delete existing `migrations/.gitkeep`:\n{}",
                        err
                    )
                });
            }
        }
    } else {
        create_migrations_directory(&dir)?;
    }

    Ok(dir)
}
/// Writes the default `diesel.toml` at the configured location, unless a
/// config file already exists there (an existing file is never overwritten).
///
/// # Errors
/// Fails if the file cannot be created or written.
fn create_config_file(matches: &ArgMatches) -> DatabaseResult<()> {
    use std::io::Write;
    let path = Config::file_path(matches);
    if path.exists() {
        return Ok(());
    }
    // The template is embedded in the binary at compile time.
    fs::File::create(path)?.write_all(include_bytes!("default_files/diesel.toml"))?;
    Ok(())
}
/// Dispatches the `diesel database <subcommand>` commands: `setup`, `reset`,
/// and `drop`.
///
/// `setup` and `reset` run migrations and therefore regenerate the schema
/// file afterwards when one is configured; `drop` only drops the database.
///
/// # Errors
/// Propagates any error from the underlying database operation or from the
/// schema regeneration step.
fn run_database_command(
    matches: &ArgMatches,
) -> Result<(), Box<dyn Error + Send + Sync + 'static>> {
    match matches.subcommand().unwrap() {
        ("setup", args) => {
            let migrations_dir = migrations_dir(matches).unwrap_or_else(handle_error);
            database::setup_database(args, &migrations_dir)?;
            regenerate_schema_if_file_specified(matches)?;
        }
        ("reset", args) => {
            let migrations_dir = migrations_dir(matches).unwrap_or_else(handle_error);
            database::reset_database(args, &migrations_dir)?;
            regenerate_schema_if_file_specified(matches)?;
        }
        ("drop", args) => database::drop_database_command(args)?,
        _ => unreachable!("The cli parser should prevent reaching here"),
    };
    Ok(())
}
/// Implements `diesel completions <shell>`: writes a shell-completion script
/// for the requested shell to stdout.
fn generate_completions_command(matches: &ArgMatches) {
    // SHELL is a required, typed argument, so the lookup cannot fail.
    let shell = *matches.get_one::<Shell>("SHELL").expect("Shell is set here?");
    let mut app = cli::build_cli();
    let bin_name = app.get_name().to_string();
    generate(shell, &mut app, bin_name, &mut stdout());
}
/// Creates `path` as a fresh migrations directory, seeding it with an empty
/// `.keep` file so the directory survives in version control, and returns
/// the created path.
///
/// # Errors
/// Fails if the directory or the `.keep` file cannot be created.
fn create_migrations_directory(path: &Path) -> DatabaseResult<PathBuf> {
    println!("Creating migrations directory at: {}", path.display());
    fs::create_dir(path)?;
    fs::File::create(path.join(".keep"))?;
    Ok(path.to_path_buf())
}
/// Locates the project root by walking up from the current working
/// directory, preferring a directory containing `diesel.toml` and falling
/// back to one containing `Cargo.toml`.
///
/// # Errors
/// Returns `ProjectRootNotFound` when neither marker file is found in any
/// ancestor, or an I/O error if the current directory cannot be read.
///
/// Fix: the original text was mojibake-corrupted — `&current_dir` had been
/// mangled into `¤t_dir` (the `&curren` prefix interpreted as the HTML
/// entity `&curren;` = `¤`), which does not compile. Restored the intended
/// borrow.
fn find_project_root() -> DatabaseResult<PathBuf> {
    let current_dir = env::current_dir()?;
    search_for_directory_containing_file(&current_dir, "diesel.toml")
        .or_else(|_| search_for_directory_containing_file(&current_dir, "Cargo.toml"))
}
/// Recursively walks up from `path` looking for a directory that directly
/// contains `file`, returning the first such directory.
///
/// # Errors
/// Returns `ProjectRootNotFound` carrying the *original* starting `path`
/// when no ancestor contains the file (each recursion level rewraps the
/// error with its own path, so the outermost caller sees its own argument).
fn search_for_directory_containing_file(path: &Path, file: &str) -> DatabaseResult<PathBuf> {
    if path.join(file).is_file() {
        return Ok(path.to_owned());
    }
    match path.parent() {
        Some(parent) => search_for_directory_containing_file(parent, file)
            .map_err(|_| DatabaseError::ProjectRootNotFound(path.into())),
        None => Err(DatabaseError::ProjectRootNotFound(path.into())),
    }
}
/// Implements `diesel migration redo`: reverts the most recent migration(s)
/// and immediately re-runs them, echoing progress to stdout.
///
/// With `--all`, every applied migration is reverted and re-run; otherwise
/// the last `REDO_NUMBER` migrations are. Failures abort the process via
/// `handle_error`.
fn redo_migrations<Conn, DB>(
    conn: &mut Conn,
    migrations_dir: FileBasedMigrations,
    args: &ArgMatches,
) where
    DB: Backend,
    Conn: MigrationHarness<DB> + Connection<Backend = DB> + 'static,
{
    // The whole revert+rerun sequence is a closure so it can be executed
    // either inside or outside a transaction (see the bottom of this fn).
    let migrations_inner = |harness: &mut HarnessWithOutput<Conn, _>| -> Result<(), Box<dyn Error + Send + Sync + 'static>> {
        // Step 1: revert, collecting the reverted versions.
        let reverted_versions = if args.contains_id("REDO_ALL") {
            harness.revert_all_migrations(migrations_dir.clone())?
        } else {
            let number = args.get_one::<u64>("REDO_NUMBER").unwrap();
            (0..*number)
                .filter_map(|_|{
                    match harness.revert_last_migration(migrations_dir.clone()) {
                        Ok(v) => {
                            Some(Ok(v))
                        }
                        Err(e) if e.is::<MigrationError>() => {
                            match e.downcast_ref::<MigrationError>() {
                                // Nothing left to revert: skip this attempt
                                // (remaining attempts are skipped the same
                                // way) instead of failing.
                                Some(MigrationError::NoMigrationRun) => None,
                                _ => Some(Err(e)),
                            }
                        }
                        Err(e) => {
                            Some(Err(e))
                        }
                    }
                })
                .collect::<Result<Vec<_>, _>>()?
        };
        // Step 2: index the on-disk migrations by version so each reverted
        // version can be mapped back to its source.
        let mut migrations = MigrationSource::<DB>::migrations(&migrations_dir)?
            .into_iter()
            .map(|m| (m.name().version().as_owned(), m))
            .collect::<HashMap<_, _>>();
        // A reverted version with no matching source is an error (the file
        // was removed between revert and rerun).
        let mut migrations = reverted_versions
            .into_iter()
            .map(|v| {
                migrations
                    .remove(&v)
                    .ok_or_else(|| MigrationError::UnknownMigrationVersion(v.as_owned()))
            })
            .collect::<Result<Vec<_>, _>>()?;
        // Step 3: re-run the reverted migrations in ascending version order.
        migrations.sort_by_key(|m| m.name().version().as_owned());
        harness.run_migrations(&migrations)?;
        Ok(())
    };
    // Run inside a transaction only where that is appropriate for the
    // backend (see `should_redo_migration_in_transaction`).
    if should_redo_migration_in_transaction(conn) {
        conn.transaction(|conn| migrations_inner(&mut HarnessWithOutput::write_to_stdout(conn)))
            .unwrap_or_else(handle_error);
    } else {
        migrations_inner(&mut HarnessWithOutput::write_to_stdout(conn))
            .unwrap_or_else(handle_error);
    }
}
/// Decides whether `redo_migrations` should wrap its work in a transaction.
///
/// When the `mysql` feature is enabled, returns `false` for a
/// `MysqlConnection` and `true` for every other connection type
/// (presumably because MySQL does not support transactional DDL —
/// TODO(review): confirm rationale).
#[cfg(feature = "mysql")]
fn should_redo_migration_in_transaction(t: &dyn Any) -> bool {
    !t.is::<::diesel::mysql::MysqlConnection>()
}
/// Without the `mysql` feature there is no special case: always use a
/// transaction.
#[cfg(not(feature = "mysql"))]
fn should_redo_migration_in_transaction(_t: &dyn Any) -> bool {
    true
}
/// Prints `error` to stderr and terminates the process with exit status 1.
///
/// The free return type `T` lets this slot into `Result::unwrap_or_else`
/// for any success type; the function never actually returns.
#[allow(clippy::needless_pass_by_value)]
fn handle_error<E: Display, T>(error: E) -> T {
    eprintln!("{}", error);
    std::process::exit(1);
}
/// Expresses `target_path` relative to `current_path`.
///
/// Walks up from `current_path` one component at a time, accumulating `..`
/// segments, until an ancestor of `target_path` is reached; the remainder of
/// `target_path` is then appended. If the walk exhausts all ancestors
/// without finding a common prefix, `target_path` is returned unchanged.
fn convert_absolute_path_to_relative(target_path: &Path, mut current_path: &Path) -> PathBuf {
    let mut ups = PathBuf::new();
    loop {
        // Once `current_path` is a prefix of the target, splice the leftover
        // components onto the accumulated `..` chain.
        if let Ok(remainder) = target_path.strip_prefix(current_path) {
            return ups.join(remainder);
        }
        ups.push("..");
        match current_path.parent() {
            Some(parent) => current_path = parent,
            // No common ancestor: fall back to the target as given.
            None => return target_path.into(),
        }
    }
}
/// Implements `diesel print-schema`: reads the `print_schema` section of the
/// config file, overlays CLI flags on top of it, and prints the inferred
/// schema for the configured database to stdout.
///
/// # Errors
/// Fails on invalid table-name regexes, an unknown column sorting mode, or
/// any error raised while inspecting the database / printing the schema.
fn run_infer_schema(matches: &ArgMatches) -> Result<(), Box<dyn Error + Send + Sync + 'static>> {
    use crate::print_schema::*;
    let database_url = database::database_url(matches);
    let mut config = Config::read(matches)?.print_schema;
    // CLI flags take precedence over values from `diesel.toml`.
    if let Some(schema_name) = matches.get_one::<String>("schema") {
        config.schema = Some(schema_name.clone())
    }
    // Table-name filters are regular expressions; an invalid pattern is
    // reported only if a filter flag is actually used (lazy `filter?` below).
    let filter = matches
        .get_many::<String>("table-name")
        .unwrap_or_default()
        .map(|table_name_regex| Regex::new(table_name_regex).map(Into::into))
        .collect::<Result<_, _>>()
        .map_err(|e| format!("invalid argument for table filtering regex: {}", e));
    if matches.contains_id("only-tables") {
        config.filter = Filtering::OnlyTables(filter?)
    } else if matches.contains_id("except-tables") {
        config.filter = Filtering::ExceptTables(filter?)
    }
    if matches.contains_id("with-docs") {
        config.with_docs = true;
    }
    if let Some(sorting) = matches.get_one::<String>("column-sorting") {
        match sorting as &str {
            "ordinal_position" => config.column_sorting = ColumnSorting::OrdinalPosition,
            "name" => config.column_sorting = ColumnSorting::Name,
            _ => return Err(format!("Invalid column sorting mode: {}", sorting).into()),
        }
    }
    if let Some(path) = matches.get_one::<PathBuf>("patch-file") {
        config.patch_file = Some(path.clone());
    }
    if let Some(types) = matches.get_many("import-types") {
        let types = types.cloned().collect();
        config.import_types = Some(types);
    }
    if matches.contains_id("generate-custom-type-definitions") {
        // NOTE: presence of the flag *disables* generation of missing SQL
        // type definitions.
        config.generate_missing_sql_type_definitions = Some(false);
    }
    run_print_schema(&database_url, &config, &mut stdout())?;
    Ok(())
}
/// After a command that may change the database schema, regenerates the
/// schema file configured under `print_schema.file` (no-op when none is
/// configured).
///
/// With `LOCKED_SCHEMA` set, the file is *not* rewritten: the freshly
/// printed schema is compared byte-for-byte with the file on disk and an
/// error is returned when they differ (useful for CI checks).
///
/// # Errors
/// Fails on config/IO errors, on schema-printing errors, or — in locked
/// mode — when the generated schema differs from the committed file.
fn regenerate_schema_if_file_specified(
    matches: &ArgMatches,
) -> Result<(), Box<dyn Error + Send + Sync + 'static>> {
    use std::io::Read;
    let config = Config::read(matches)?;
    if let Some(ref path) = config.print_schema.file {
        // Make sure the target directory exists before writing/reading.
        if let Some(parent) = path.parent() {
            fs::create_dir_all(parent)?;
        }
        let database_url = database::database_url(matches);
        if matches.contains_id("LOCKED_SCHEMA") {
            // Render into memory and diff against the existing file.
            let mut buf = Vec::new();
            print_schema::run_print_schema(&database_url, &config.print_schema, &mut buf)?;
            let mut old_buf = Vec::new();
            let mut file = fs::File::open(path)?;
            file.read_to_end(&mut old_buf)?;
            if buf != old_buf {
                return Err(format!(
                    "Command would result in changes to {}. \
                     Rerun the command locally, and commit the changes.",
                    path.display()
                )
                .into());
            }
        } else {
            use std::io::Write;
            // Overwrite the schema file with the freshly generated output.
            let mut file = fs::File::create(path)?;
            let schema = print_schema::output_schema(&database_url, &config.print_schema)?;
            file.write_all(schema.as_bytes())?;
        }
    }
    Ok(())
}
#[cfg(test)]
mod tests {
    extern crate tempfile;
    use crate::database_error::DatabaseError;
    use self::tempfile::Builder;
    use std::fs;
    use std::path::PathBuf;
    use super::convert_absolute_path_to_relative;
    use super::search_for_directory_containing_file;
    /// The search succeeds when `Cargo.toml` sits in the starting directory
    /// itself.
    #[test]
    fn toml_directory_find_cargo_toml() {
        let dir = Builder::new().prefix("diesel").tempdir().unwrap();
        // Canonicalize so the expected and returned paths compare equal on
        // platforms where temp dirs are symlinked (e.g. macOS `/tmp`).
        let temp_path = dir.path().canonicalize().unwrap();
        let toml_path = temp_path.join("Cargo.toml");
        fs::File::create(&toml_path).unwrap();
        assert_eq!(
            Ok(temp_path.clone()),
            search_for_directory_containing_file(&temp_path, "Cargo.toml")
        );
    }
    /// The search reports `ProjectRootNotFound` with the starting path when
    /// no ancestor contains the marker file.
    #[test]
    fn cargo_toml_not_found_if_no_cargo_toml() {
        let dir = Builder::new().prefix("diesel").tempdir().unwrap();
        let temp_path = dir.path().canonicalize().unwrap();
        assert_eq!(
            Err(DatabaseError::ProjectRootNotFound(temp_path.clone())),
            search_for_directory_containing_file(&temp_path, "Cargo.toml")
        );
    }
    /// Covers descent (plain suffix), one- and multi-level ascent via `..`,
    /// the identity case, and a sibling directory.
    #[test]
    fn convert_absolute_path_to_relative_works() {
        assert_eq!(
            PathBuf::from("migrations/12345_create_user"),
            convert_absolute_path_to_relative(
                &PathBuf::from("projects/foo/migrations/12345_create_user"),
                &PathBuf::from("projects/foo")
            )
        );
        assert_eq!(
            PathBuf::from("../migrations/12345_create_user"),
            convert_absolute_path_to_relative(
                &PathBuf::from("projects/foo/migrations/12345_create_user"),
                &PathBuf::from("projects/foo/src")
            )
        );
        assert_eq!(
            PathBuf::from("../../../migrations/12345_create_user"),
            convert_absolute_path_to_relative(
                &PathBuf::from("projects/foo/migrations/12345_create_user"),
                &PathBuf::from("projects/foo/src/controllers/errors")
            )
        );
        assert_eq!(
            PathBuf::from("12345_create_user"),
            convert_absolute_path_to_relative(
                &PathBuf::from("projects/foo/migrations/12345_create_user"),
                &PathBuf::from("projects/foo/migrations")
            )
        );
        assert_eq!(
            PathBuf::from("../12345_create_user"),
            convert_absolute_path_to_relative(
                &PathBuf::from("projects/foo/migrations/12345_create_user"),
                &PathBuf::from("projects/foo/migrations/67890_create_post")
            )
        );
    }
}