#[cfg(any(feature = "postgres", feature = "mysql"))]
use super::query_helper;
use clap::ArgMatches;
use diesel::connection::InstrumentationEvent;
use diesel::dsl::sql;
use diesel::sql_types::Bool;
use diesel::*;
use diesel_migrations::FileBasedMigrations;
use std::env;
#[cfg(feature = "postgres")]
use std::fs::{self};
use std::path::Path;
/// The database backend a given URL refers to.
///
/// Variants exist only for the backends whose cargo feature is enabled,
/// so the CLI never references a driver it was not compiled with.
#[derive(Copy, Clone, Eq, PartialEq)]
pub enum Backend {
    #[cfg(feature = "postgres")]
    Pg,
    #[cfg(feature = "sqlite")]
    Sqlite,
    #[cfg(feature = "mysql")]
    Mysql,
}
impl Backend {
    /// Infer the backend from the scheme of `database_url`.
    ///
    /// # Panics
    ///
    /// Panics if the URL requires a backend whose cargo feature is
    /// disabled, or (when the `sqlite` feature is off) if the URL matches
    /// no supported scheme. With the `sqlite` feature enabled, any URL
    /// that is not postgres/mysql falls through to `Backend::Sqlite`.
    pub fn for_url(database_url: &str) -> Self {
        match database_url {
            _ if database_url.starts_with("postgres://")
                || database_url.starts_with("postgresql://") =>
            {
                #[cfg(feature = "postgres")]
                {
                    Backend::Pg
                }
                #[cfg(not(feature = "postgres"))]
                {
                    panic!(
                        "Database url `{database_url}` requires the `postgres` feature but it's not enabled."
                    );
                }
            }
            _ if database_url.starts_with("mysql://") =>
            {
                #[cfg(feature = "mysql")]
                {
                    Backend::Mysql
                }
                #[cfg(not(feature = "mysql"))]
                {
                    panic!(
                        "Database url `{database_url}` requires the `mysql` feature but it's not enabled."
                    );
                }
            }
            // SQLite accepts bare file paths, so every remaining URL is
            // treated as a SQLite database when the feature is on.
            #[cfg(feature = "sqlite")]
            _ => Backend::Sqlite,
            #[cfg(not(feature = "sqlite"))]
            _ => {
                if database_url.starts_with("sqlite://") {
                    panic!(
                        "Database url `{database_url}` requires the `sqlite` feature but it's not enabled."
                    );
                }
                // Collect the schemes usable with the current feature set
                // so the error message can suggest valid alternatives.
                let mut available_schemes: Vec<&str> = Vec::new();
                if cfg!(feature = "postgres") {
                    available_schemes.push("`postgres://`");
                }
                if cfg!(feature = "mysql") {
                    available_schemes.push("`mysql://`");
                }
                panic!(
                    "`{}` is not a valid database URL. It should start with {}, or maybe you meant to use the `sqlite` feature which is not enabled.",
                    database_url,
                    available_schemes.join(" or ")
                );
            }
            // Building with no backend feature at all is a compile error.
            #[cfg(not(any(feature = "mysql", feature = "sqlite", feature = "postgres")))]
            _ => compile_error!(
                "At least one backend must be specified for use with this crate. \
                You may omit the unneeded dependencies in the following command. \n\n \
                ex. `cargo install diesel_cli --no-default-features --features mysql postgres sqlite` \n"
            ),
        }
    }

    /// Return the backend of an already-established connection.
    pub(crate) fn for_connection(connection: &InferConnection) -> Backend {
        match connection {
            #[cfg(feature = "postgres")]
            InferConnection::Pg(_) => Self::Pg,
            #[cfg(feature = "sqlite")]
            InferConnection::Sqlite(_) => Self::Sqlite,
            #[cfg(feature = "mysql")]
            InferConnection::Mysql(_) => Self::Mysql,
        }
    }
}
/// A connection to whichever supported backend the database URL selects.
///
/// `#[derive(diesel::MultiConnection)]` generates the glue that lets this
/// enum be used as a single connection type dispatching to the active
/// variant (see `InferConnection::from_url`).
#[derive(diesel::MultiConnection)]
pub enum InferConnection {
    #[cfg(feature = "postgres")]
    Pg(PgConnection),
    #[cfg(feature = "sqlite")]
    Sqlite(SqliteConnection),
    #[cfg(feature = "mysql")]
    Mysql(MysqlConnection),
}
impl InferConnection {
    /// Establish a connection using the database URL from the CLI
    /// arguments (or the `DATABASE_URL` environment variable).
    pub fn from_matches(matches: &ArgMatches) -> Result<Self, crate::errors::Error> {
        let database_url = database_url(matches)?;
        Self::from_url(database_url)
    }

    /// Establish a connection to `database_url`, choosing the variant
    /// that matches the URL scheme.
    fn from_url(database_url: String) -> Result<InferConnection, crate::errors::Error> {
        let result = match Backend::for_url(&database_url) {
            #[cfg(feature = "postgres")]
            Backend::Pg => PgConnection::establish(&database_url).map(Self::Pg),
            #[cfg(feature = "mysql")]
            Backend::Mysql => MysqlConnection::establish(&database_url).map(Self::Mysql),
            #[cfg(feature = "sqlite")]
            Backend::Sqlite => SqliteConnection::establish(&database_url).map(Self::Sqlite),
        };
        let mut conn = result.map_err(|err| crate::errors::Error::ConnectionError {
            error: err,
            url: database_url,
        })?;
        // Log every finished query via `tracing`: failures at error level
        // (including the error), successes at debug level.
        conn.set_instrumentation(|event: InstrumentationEvent<'_>| {
            if let InstrumentationEvent::FinishQuery { query, error, .. } = event {
                if let Some(err) = error {
                    tracing::error!(?query, ?err, "Failed to execute query");
                } else {
                    tracing::debug!(?query);
                }
            }
        });
        Ok(conn)
    }
}
/// Drop the target database (if present), then re-create and migrate it —
/// the equivalent of `drop` followed by `setup`.
pub fn reset_database(
    args: &ArgMatches,
    migrations_dir: &Path,
) -> Result<(), crate::errors::Error> {
    let url = database_url(args)?;
    drop_database(&url)?;
    setup_database(args, migrations_dir)
}
/// Create the database if necessary, optionally write the default initial
/// migration, and run any pending migrations.
pub fn setup_database(
    args: &ArgMatches,
    migrations_dir: &Path,
) -> Result<(), crate::errors::Error> {
    let database_url = database_url(args)?;
    create_database_if_needed(&database_url)?;
    // `NO_DEFAULT_MIGRATION` suppresses the generated initial migration.
    if !args.get_flag("NO_DEFAULT_MIGRATION") {
        create_default_migration_if_needed(&database_url, migrations_dir)?;
    }
    create_schema_table_and_run_migrations_if_needed(&database_url, migrations_dir)
}
/// CLI entry point for `database drop`: resolve the URL and drop it.
pub fn drop_database_command(args: &ArgMatches) -> Result<(), crate::errors::Error> {
    let url = database_url(args)?;
    drop_database(&url)
}
/// Create the database behind `database_url` if it does not exist yet.
///
/// For PostgreSQL and MySQL this connects to a maintenance database
/// (`postgres` / `information_schema`) and issues `CREATE DATABASE`; for
/// SQLite the database file is created by establishing a connection.
fn create_database_if_needed(database_url: &str) -> Result<(), crate::errors::Error> {
    match Backend::for_url(database_url) {
        #[cfg(feature = "postgres")]
        Backend::Pg => {
            // A failed establish is taken to mean the database is missing.
            if PgConnection::establish(database_url).is_err() {
                let (database, postgres_url) = change_database_of_url(database_url, "postgres")?;
                println!("Creating database: {database}");
                let mut conn = PgConnection::establish(&postgres_url).map_err(|error| {
                    crate::errors::Error::ConnectionError {
                        error,
                        url: postgres_url,
                    }
                })?;
                query_helper::create_database(&database).execute(&mut conn)?;
            }
        }
        #[cfg(feature = "sqlite")]
        Backend::Sqlite => {
            let path = path_from_sqlite_url(database_url)?;
            if !path.exists() {
                println!("Creating database: {database_url}");
                // Establishing a SQLite connection creates the file.
                SqliteConnection::establish(database_url).map_err(|error| {
                    crate::errors::Error::ConnectionError {
                        error,
                        url: database_url.to_owned(),
                    }
                })?;
            }
        }
        #[cfg(feature = "mysql")]
        Backend::Mysql => {
            if MysqlConnection::establish(database_url).is_err() {
                let (database, mysql_url) =
                    change_database_of_url(database_url, "information_schema")?;
                println!("Creating database: {database}");
                let mut conn = MysqlConnection::establish(&mysql_url).map_err(|error| {
                    crate::errors::Error::ConnectionError {
                        error,
                        url: mysql_url,
                    }
                })?;
                query_helper::create_database(&database).execute(&mut conn)?;
            }
        }
    }
    Ok(())
}
/// Write the built-in `00000000000000_diesel_initial_setup` migration into
/// `migrations_dir` unless a directory of that name already exists.
///
/// Only the PostgreSQL backend ships a default migration; for the other
/// backends this is a no-op.
fn create_default_migration_if_needed(
    database_url: &str,
    migrations_dir: &Path,
) -> Result<(), crate::errors::Error> {
    let initial_migration_path = migrations_dir.join("00000000000000_diesel_initial_setup");
    if initial_migration_path.exists() {
        return Ok(());
    }
    #[allow(unreachable_patterns, clippy::single_match)]
    match Backend::for_url(database_url) {
        #[cfg(feature = "postgres")]
        Backend::Pg => {
            fs::create_dir_all(&initial_migration_path).map_err(|e| {
                crate::errors::Error::IoError(e, Some(initial_migration_path.clone()))
            })?;
            // Both migration files come from templates embedded at compile
            // time; write them with one shared error-handling path.
            let migration_files: [(&str, &[u8]); 2] = [
                ("up.sql", include_bytes!("setup_sql/postgres/initial_setup/up.sql")),
                (
                    "down.sql",
                    include_bytes!("setup_sql/postgres/initial_setup/down.sql"),
                ),
            ];
            for (file_name, contents) in migration_files {
                let file_path = initial_migration_path.join(file_name);
                std::fs::write(&file_path, contents)
                    .map_err(|e| crate::errors::Error::IoError(e, Some(file_path)))?;
            }
        }
        _ => {}
    }
    Ok(())
}
/// Run all pending migrations from `migrations_dir` unless the
/// `__diesel_schema_migrations` table already exists.
fn create_schema_table_and_run_migrations_if_needed(
    database_url: &str,
    migrations_dir: &Path,
) -> Result<(), crate::errors::Error> {
    // An existing schema table means the database was already set up.
    if schema_table_exists(database_url)? {
        return Ok(());
    }
    let migrations = FileBasedMigrations::from_path(migrations_dir)
        .map_err(|e| crate::errors::Error::from_migration_error(e, Some(migrations_dir)))?;
    let mut conn = InferConnection::from_url(database_url.to_owned())?;
    super::run_migrations_with_output(&mut conn, migrations)
        .map_err(crate::errors::Error::MigrationError)?;
    Ok(())
}
/// Drop the database behind `database_url` if it exists.
///
/// PostgreSQL and MySQL connect to a maintenance database first, since a
/// database cannot be dropped over a connection to itself; SQLite simply
/// removes the file.
fn drop_database(database_url: &str) -> Result<(), crate::errors::Error> {
    match Backend::for_url(database_url) {
        #[cfg(feature = "postgres")]
        Backend::Pg => {
            let (current_database, _) = get_database_and_url(database_url)?;
            // When the target *is* `postgres`, fall back to `template1`
            // as the maintenance database to connect through.
            let default_database = if current_database.eq("postgres") {
                "template1"
            } else {
                "postgres"
            };
            let (database, postgres_url) = change_database_of_url(database_url, default_database)?;
            let mut conn = PgConnection::establish(&postgres_url).map_err(|e| {
                crate::errors::Error::ConnectionError {
                    error: e,
                    url: postgres_url,
                }
            })?;
            if pg_database_exists(&mut conn, &database)? {
                println!("Dropping database: {database}");
                query_helper::drop_database(&database)
                    .if_exists()
                    .execute(&mut conn)?;
            }
        }
        #[cfg(feature = "sqlite")]
        Backend::Sqlite => {
            if Path::new(database_url).exists() {
                println!("Dropping database: {database_url}");
                std::fs::remove_file(database_url).map_err(|e| {
                    crate::errors::Error::IoError(e, Some(std::path::PathBuf::from(database_url)))
                })?;
            }
        }
        #[cfg(feature = "mysql")]
        Backend::Mysql => {
            let (database, mysql_url) = change_database_of_url(database_url, "information_schema")?;
            let mut conn = MysqlConnection::establish(&mysql_url).map_err(|e| {
                crate::errors::Error::ConnectionError {
                    error: e,
                    url: mysql_url,
                }
            })?;
            if mysql_database_exists(&mut conn, &database)? {
                println!("Dropping database: {database}");
                query_helper::drop_database(&database)
                    .if_exists()
                    .execute(&mut conn)?;
            }
        }
    }
    Ok(())
}
// Minimal hand-written mapping of the `pg_database` system catalog;
// only the columns needed by `pg_database_exists` are declared.
#[cfg(feature = "postgres")]
table! {
    pg_database (datname) {
        datname -> Text,
        datistemplate -> Bool,
    }
}
/// Check whether a non-template database named `database_name` exists,
/// by querying the `pg_database` system catalog.
#[cfg(feature = "postgres")]
fn pg_database_exists(conn: &mut PgConnection, database_name: &str) -> QueryResult<bool> {
    use self::pg_database::dsl::*;
    // `optional()` turns "no row" into `None` instead of an error.
    let matching_row = pg_database
        .select(datname)
        .filter(datname.eq(database_name))
        .filter(datistemplate.eq(false))
        .get_result::<String>(conn)
        .optional()?;
    Ok(matching_row.is_some())
}
// Minimal hand-written mapping of `information_schema.schemata`;
// only the column needed by `mysql_database_exists` is declared.
#[cfg(feature = "mysql")]
table! {
    information_schema.schemata (schema_name) {
        schema_name -> Text,
    }
}
/// Check whether a schema (database) named `database_name` exists, by
/// querying `information_schema.schemata`.
#[cfg(feature = "mysql")]
fn mysql_database_exists(conn: &mut MysqlConnection, database_name: &str) -> QueryResult<bool> {
    use self::schemata::dsl::*;
    // `optional()` turns "no row" into `None` instead of an error.
    let matching_row = schemata
        .select(schema_name)
        .filter(schema_name.eq(database_name))
        .get_result::<String>(conn)
        .optional()?;
    Ok(matching_row.is_some())
}
/// Check whether the `__diesel_schema_migrations` bookkeeping table
/// already exists in the database behind `database_url`.
pub fn schema_table_exists(database_url: &str) -> Result<bool, crate::errors::Error> {
    match InferConnection::from_url(database_url.to_owned())? {
        // NOTE(review): unlike the MySQL query below, this one does not
        // restrict `table_schema`, so it matches the table in any schema
        // of the current database — confirm this is intentional.
        #[cfg(feature = "postgres")]
        InferConnection::Pg(mut conn) => select(sql::<Bool>(
            "EXISTS \
                    (SELECT 1 \
                    FROM information_schema.tables \
                    WHERE table_name = '__diesel_schema_migrations')",
        ))
        .get_result(&mut conn),
        #[cfg(feature = "sqlite")]
        InferConnection::Sqlite(mut conn) => select(sql::<Bool>(
            "EXISTS \
                    (SELECT 1 \
                    FROM sqlite_master \
                    WHERE type = 'table' \
                    AND name = '__diesel_schema_migrations')",
        ))
        .get_result(&mut conn),
        // Scoped to the current schema via DATABASE().
        #[cfg(feature = "mysql")]
        InferConnection::Mysql(mut conn) => select(sql::<Bool>(
            "EXISTS \
                    (SELECT 1 \
                    FROM information_schema.tables \
                    WHERE table_name = '__diesel_schema_migrations'
                AND table_schema = DATABASE())",
        ))
        .get_result(&mut conn),
    }
    .map_err(Into::into)
}
pub fn database_url(matches: &ArgMatches) -> Result<String, crate::errors::Error> {
matches
.get_one::<String>("DATABASE_URL")
.cloned()
.or_else(|| env::var("DATABASE_URL").ok())
.ok_or(crate::errors::Error::DatabaseUrlMissing)
}
/// Replace the database name in `database_url` with `default_database`,
/// returning the original database name and the rewritten URL.
#[cfg(any(feature = "postgres", feature = "mysql"))]
fn change_database_of_url(
    database_url: &str,
    default_database: &str,
) -> Result<(String, String), crate::errors::Error> {
    let (database, base) = get_database_and_url(database_url)?;
    // Joining a relative segment replaces the last path segment, i.e. the
    // database name (see the unit tests at the bottom of this file).
    let mut new_url = base
        .join(default_database)
        .expect("The provided database is always valid");
    // `join` drops the query string (e.g. `?sslmode=...`); restore it.
    new_url.set_query(base.query());
    Ok((database, new_url.into()))
}
/// Parse `database_url` and split off the database name (the last path
/// segment), returning both the name and the parsed URL.
///
/// # Panics
///
/// Panics if the URL has no path segments (cannot-be-a-base URLs).
#[cfg(any(feature = "postgres", feature = "mysql"))]
fn get_database_and_url(database_url: &str) -> Result<(String, url::Url), crate::errors::Error> {
    let base = url::Url::parse(database_url)?;
    let database = base
        .path_segments()
        .expect("The database url has at least one path segment")
        .next_back()
        .expect("The database url has at least one path segment")
        .to_owned();
    Ok((database, base))
}
/// Resolve a SQLite database URL to a filesystem path.
///
/// Plain paths are returned unchanged; strings starting with `file:/` are
/// parsed as `file:` URLs and converted via `Url::to_file_path`, and are
/// rejected as invalid connection URLs when that fails.
#[cfg(feature = "sqlite")]
fn path_from_sqlite_url(database_url: &str) -> Result<std::path::PathBuf, crate::errors::Error> {
    if !database_url.starts_with("file:/") {
        // Not URL-shaped: treat the whole string as a plain path.
        return Ok(std::path::PathBuf::from(database_url));
    }
    // Shared error constructor for every rejection path below.
    let invalid_url_error = || crate::errors::Error::ConnectionError {
        error: result::ConnectionError::InvalidConnectionUrl(String::from(database_url)),
        url: database_url.into(),
    };
    match ::url::Url::parse(database_url) {
        Ok(url) if url.scheme() == "file" => {
            url.to_file_path().map_err(|_| invalid_url_error())
        }
        _ => Err(invalid_url_error()),
    }
}
// Unit tests for `change_database_of_url`; they exercise only URL
// rewriting, so no live database is required.
#[cfg(all(test, any(feature = "postgres", feature = "mysql")))]
mod tests {
    use super::change_database_of_url;

    // Plain URL: the trailing database name is swapped for `postgres`.
    #[test]
    fn split_pg_connection_string_returns_postgres_url_and_database() {
        let database = "database".to_owned();
        let base_url = "postgresql://localhost:5432".to_owned();
        let database_url = format!("{base_url}/{database}");
        let postgres_url = format!("{}/{}", base_url, "postgres");
        assert_eq!(
            (database, postgres_url),
            change_database_of_url(&database_url, "postgres").unwrap()
        );
    }

    // Credentials in the authority part must be preserved.
    #[test]
    fn split_pg_connection_string_handles_user_and_password() {
        let database = "database".to_owned();
        let base_url = "postgresql://user:password@localhost:5432".to_owned();
        let database_url = format!("{base_url}/{database}");
        let postgres_url = format!("{}/{}", base_url, "postgres");
        assert_eq!(
            (database, postgres_url),
            change_database_of_url(&database_url, "postgres").unwrap()
        );
    }

    // The query string (e.g. `?sslmode=true`) must survive the rewrite.
    #[test]
    fn split_pg_connection_string_handles_query_string() {
        let database = "database".to_owned();
        let query = "?sslmode=true".to_owned();
        let base_url = "postgresql://user:password@localhost:5432".to_owned();
        let database_url = format!("{base_url}/{database}{query}");
        let postgres_url = format!("{}/{}{}", base_url, "postgres", query);
        assert_eq!(
            (database, postgres_url),
            change_database_of_url(&database_url, "postgres").unwrap()
        );
    }
}