use std::str::FromStr;
use std::time::Duration;
use std::time::Instant;
use futures_core::future::BoxFuture;
pub(crate) use sqlx_core::migrate::*;
use sqlx_core::sql_str::AssertSqlSafe;
use crate::connection::{ConnectOptions, Connection};
use crate::error::Error;
use crate::executor::Executor;
use crate::query::query;
use crate::query_as::query_as;
use crate::query_scalar::query_scalar;
use crate::{MySql, MySqlConnectOptions, MySqlConnection};
/// Splits a MySQL connection URL into "maintenance" connection options and the
/// name of the database the URL points at.
///
/// The returned options have their database cleared, so a connection opened with
/// them is not tied to that database. This lets callers create, inspect or drop it.
///
/// # Errors
/// Propagates any error from parsing `url` as [`MySqlConnectOptions`], and returns
/// [`Error::Configuration`] when the URL does not name a database.
fn parse_for_maintenance(url: &str) -> Result<(MySqlConnectOptions, String), Error> {
    let mut options = MySqlConnectOptions::from_str(url)?;

    // Move the name out of the options; this leaves `options.database == None`.
    let Some(database) = options.database.take() else {
        return Err(Error::Configuration(
            "DATABASE_URL does not specify a database".into(),
        ));
    };

    Ok((options, database))
}
impl MigrateDatabase for MySql {
    /// Creates the database named in `url`.
    ///
    /// Connects to the server with the database cleared from the options (via
    /// `parse_for_maintenance`), then issues `CREATE DATABASE` with the name as a
    /// backtick-quoted identifier. No `IF NOT EXISTS` clause is used.
    async fn create_database(url: &str) -> Result<(), Error> {
        let (options, database) = parse_for_maintenance(url)?;
        let mut conn = options.connect().await?;

        // Inside a backtick-quoted MySQL identifier, a literal backtick must be
        // doubled. Left as-is, it would close the identifier early and the rest
        // of the name would be parsed as SQL.
        let quoted = database.replace('`', "``");

        let _ = conn
            .execute(AssertSqlSafe(format!("CREATE DATABASE `{quoted}`")))
            .await?;

        Ok(())
    }

    /// Returns whether the database named in `url` exists.
    ///
    /// The name is bound as a parameter and looked up in
    /// `INFORMATION_SCHEMA.SCHEMATA`. The connection has its database cleared.
    async fn database_exists(url: &str) -> Result<bool, Error> {
        let (options, database) = parse_for_maintenance(url)?;
        let mut conn = options.connect().await?;

        let exists: bool = query_scalar(
            "select exists(SELECT 1 from INFORMATION_SCHEMA.SCHEMATA WHERE SCHEMA_NAME = ?)",
        )
        .bind(database)
        .fetch_one(&mut conn)
        .await?;

        Ok(exists)
    }

    /// Drops the database named in `url`, doing nothing if it does not exist
    /// (`DROP DATABASE IF EXISTS`).
    async fn drop_database(url: &str) -> Result<(), Error> {
        let (options, database) = parse_for_maintenance(url)?;
        let mut conn = options.connect().await?;

        // Double any backtick so the name cannot escape its quoted identifier.
        let quoted = database.replace('`', "``");

        let _ = conn
            .execute(AssertSqlSafe(format!(
                "DROP DATABASE IF EXISTS `{quoted}`"
            )))
            .await?;

        Ok(())
    }
}
impl Migrate for MySqlConnection {
    /// Creates schema `schema_name` unless it already exists.
    ///
    /// `schema_name` is inserted into the statement verbatim (unquoted), so it
    /// must be a trusted, already-valid MySQL identifier.
    fn create_schema_if_not_exists<'e>(
        &'e mut self,
        schema_name: &'e str,
    ) -> BoxFuture<'e, Result<(), MigrateError>> {
        Box::pin(async move {
            self.execute(AssertSqlSafe(format!(
                r#"CREATE SCHEMA IF NOT EXISTS {schema_name};"#
            )))
            .await?;

            Ok(())
        })
    }

    /// Creates the migrations bookkeeping table `table_name` if it is missing.
    ///
    /// Each row records a migration `version`, its `description`, when it was
    /// installed, whether it completed (`success`), its `checksum` and its
    /// `execution_time`. `apply` writes `execution_time` in nanoseconds; `skip`
    /// writes `-1`.
    fn ensure_migrations_table<'e>(
        &'e mut self,
        table_name: &'e str,
    ) -> BoxFuture<'e, Result<(), MigrateError>> {
        Box::pin(async move {
            self.execute(AssertSqlSafe(format!(
                r#"
CREATE TABLE IF NOT EXISTS {table_name} (
version BIGINT PRIMARY KEY,
description TEXT NOT NULL,
installed_on TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
success BOOLEAN NOT NULL,
checksum BLOB NOT NULL,
execution_time BIGINT NOT NULL
);
"#
            )))
            .await?;

            Ok(())
        })
    }

    /// Returns the lowest version whose row still has `success = false`, if any.
    ///
    /// Such a row is left behind when `apply` fails after inserting its
    /// bookkeeping row.
    fn dirty_version<'e>(
        &'e mut self,
        table_name: &'e str,
    ) -> BoxFuture<'e, Result<Option<i64>, MigrateError>> {
        Box::pin(async move {
            let row: Option<(i64,)> = query_as(AssertSqlSafe(format!(
                "SELECT version FROM {table_name} WHERE success = false ORDER BY version LIMIT 1"
            )))
            .fetch_optional(self)
            .await?;

            Ok(row.map(|r| r.0))
        })
    }

    /// Lists every recorded migration (version and checksum), ordered by version.
    ///
    /// The `success` column is not filtered on.
    fn list_applied_migrations<'e>(
        &'e mut self,
        table_name: &'e str,
    ) -> BoxFuture<'e, Result<Vec<AppliedMigration>, MigrateError>> {
        Box::pin(async move {
            let rows: Vec<(i64, Vec<u8>)> = query_as(AssertSqlSafe(format!(
                "SELECT version, checksum FROM {table_name} ORDER BY version"
            )))
            .fetch_all(self)
            .await?;

            let migrations = rows
                .into_iter()
                .map(|(version, checksum)| AppliedMigration {
                    version,
                    checksum: checksum.into(),
                })
                .collect();

            Ok(migrations)
        })
    }

    /// Takes a MySQL named lock whose name is derived from the current database.
    ///
    /// Calls `GET_LOCK(name, -1)`; a negative timeout waits indefinitely.
    /// NOTE(review): the value `GET_LOCK` returns (1 / 0 / NULL) is not inspected.
    fn lock(&mut self) -> BoxFuture<'_, Result<(), MigrateError>> {
        Box::pin(async move {
            let database_name = current_database(self).await?;
            let lock_id = generate_lock_id(&database_name);

            let _ = query("SELECT GET_LOCK(?, -1)")
                .bind(lock_id)
                .execute(self)
                .await?;

            Ok(())
        })
    }

    /// Releases the named lock taken by `lock`, via `RELEASE_LOCK`.
    ///
    /// The value `RELEASE_LOCK` returns is ignored.
    fn unlock(&mut self) -> BoxFuture<'_, Result<(), MigrateError>> {
        Box::pin(async move {
            let database_name = current_database(self).await?;
            let lock_id = generate_lock_id(&database_name);

            let _ = query("SELECT RELEASE_LOCK(?)")
                .bind(lock_id)
                .execute(self)
                .await?;

            Ok(())
        })
    }

    /// Runs `migration` and records it in `table_name`, returning the elapsed time.
    ///
    /// # Errors
    /// A failure in the migration's own SQL is reported as
    /// [`MigrateError::ExecuteMigration`] carrying the migration version. Other
    /// database errors are propagated with `?`.
    fn apply<'e>(
        &'e mut self,
        table_name: &'e str,
        migration: &'e Migration,
    ) -> BoxFuture<'e, Result<Duration, MigrateError>> {
        Box::pin(async move {
            let mut tx = self.begin().await?;
            let start = Instant::now();

            // MySQL DDL statements cause an implicit commit, so this transaction
            // may not be able to roll back a partly applied migration. Recording
            // the row as `success = FALSE` before running the SQL means such a
            // failure leaves a row that `dirty_version` reports.
            let _ = query(AssertSqlSafe(format!(
                r#"
INSERT INTO {table_name} ( version, description, success, checksum, execution_time )
VALUES ( ?, ?, FALSE, ?, -1 )
"#
            )))
            .bind(migration.version)
            .bind(&*migration.description)
            .bind(&*migration.checksum)
            .execute(&mut *tx)
            .await?;

            let _ = tx
                .execute(migration.sql.clone())
                .await
                .map_err(|e| MigrateError::ExecuteMigration(e, migration.version))?;

            let _ = query(AssertSqlSafe(format!(
                r#"
UPDATE {table_name}
SET success = TRUE
WHERE version = ?
"#
            )))
            .bind(migration.version)
            .execute(&mut *tx)
            .await?;

            tx.commit().await?;

            // The duration is only known after the commit, so this update runs
            // outside the transaction, directly on the connection.
            let elapsed = start.elapsed();

            // `as_nanos` returns a u128. Saturate instead of silently truncating;
            // only a duration of roughly 292 years or more would hit the limit.
            let execution_nanos = i64::try_from(elapsed.as_nanos()).unwrap_or(i64::MAX);

            let _ = query(AssertSqlSafe(format!(
                r#"
UPDATE {table_name}
SET execution_time = ?
WHERE version = ?
"#
            )))
            .bind(execution_nanos)
            .bind(migration.version)
            .execute(self)
            .await?;

            Ok(elapsed)
        })
    }

    /// Runs the down-migration `migration` and deletes its row from `table_name`.
    ///
    /// Before the SQL runs, the row is flagged `success = FALSE` inside the same
    /// transaction. Returns the elapsed time.
    ///
    /// # Errors
    /// Like `apply`, a failure in the migration's own SQL is reported as
    /// [`MigrateError::ExecuteMigration`] with the migration version.
    fn revert<'e>(
        &'e mut self,
        table_name: &'e str,
        migration: &'e Migration,
    ) -> BoxFuture<'e, Result<Duration, MigrateError>> {
        Box::pin(async move {
            let mut tx = self.begin().await?;
            let start = Instant::now();

            let _ = query(AssertSqlSafe(format!(
                r#"
UPDATE {table_name}
SET success = FALSE
WHERE version = ?
"#
            )))
            .bind(migration.version)
            .execute(&mut *tx)
            .await?;

            // Attach the version to failures, as `apply` does, so callers can
            // tell which migration failed.
            let _ = tx
                .execute(migration.sql.clone())
                .await
                .map_err(|e| MigrateError::ExecuteMigration(e, migration.version))?;

            let _ = query(AssertSqlSafe(format!(
                r#"DELETE FROM {table_name} WHERE version = ?"#
            )))
            .bind(migration.version)
            .execute(&mut *tx)
            .await?;

            tx.commit().await?;

            let elapsed = start.elapsed();

            Ok(elapsed)
        })
    }

    /// Records `migration` as successfully applied without running its SQL.
    ///
    /// `execution_time` is stored as `-1`.
    fn skip<'e>(
        &'e mut self,
        table_name: &'e str,
        migration: &'e Migration,
    ) -> BoxFuture<'e, Result<(), MigrateError>> {
        Box::pin(async move {
            let _ = query(AssertSqlSafe(format!(
                r#"
INSERT INTO {table_name} ( version, description, success, checksum, execution_time )
VALUES ( ?, ?, TRUE, ?, -1 )
"#
            )))
            .bind(migration.version)
            .bind(&*migration.description)
            .bind(&*migration.checksum)
            .execute(self)
            .await?;

            Ok(())
        })
    }
}
/// Returns the connection's current database name, as reported by `SELECT DATABASE()`.
///
/// NOTE(review): `DATABASE()` yields NULL when no database is selected. Decoding
/// that into `String` presumably errors rather than returning an empty name; confirm.
async fn current_database(conn: &mut MySqlConnection) -> Result<String, MigrateError> {
    let name: String = query_scalar("SELECT DATABASE()").fetch_one(conn).await?;
    Ok(name)
}
/// Derives the named-lock key used by `lock`/`unlock` for `database_name`.
///
/// Computes the CRC-32 (ISO-HDLC polynomial) of the name's bytes, widens it to
/// `i64`, multiplies it by a fixed constant and formats the product as lowercase
/// hex. The product cannot overflow: `u32::MAX * 0x3d32ad9e` is about 4.4e18, below
/// `i64::MAX`. Both `lock` and `unlock` key the lock on this string, so the formula
/// must not change.
fn generate_lock_id(database_name: &str) -> String {
    const CRC_IEEE: crc::Crc<u32> = crc::Crc::<u32>::new(&crc::CRC_32_ISO_HDLC);
    const MULTIPLIER: i64 = 0x3d32_ad9e;

    let checksum = i64::from(CRC_IEEE.checksum(database_name.as_bytes()));
    let product = MULTIPLIER * checksum;

    format!("{product:x}")
}