use crate::{Mssql, MssqlConnection, MssqlConnectOptions};
use futures_core::future::BoxFuture;
use sqlx_core::error::Error;
use sqlx_core::migrate::{AppliedMigration, Migrate, MigrateDatabase, MigrateError, Migration};
use std::str::FromStr;
use std::time::Duration;
use url::Url;
fn extract_database_name(url: &str) -> std::result::Result<String, Error> {
let parsed = Url::parse(url).map_err(|e| {
Error::Protocol(format!("failed to parse migration URL: {e}"))
})?;
let database = parsed.path().trim_start_matches('/').to_owned();
if database.is_empty() {
return Err(Error::Configuration(
"migration URL does not contain a database name".into(),
));
}
Ok(database)
}
/// Escapes `value` for use inside a `[...]` delimited T-SQL identifier by
/// doubling every closing bracket.
fn escape_sql_bracket(value: &str) -> String {
    let mut escaped = String::with_capacity(value.len());
    for ch in value.chars() {
        escaped.push(ch);
        if ch == ']' {
            // A literal `]` inside a bracketed identifier is written `]]`.
            escaped.push(']');
        }
    }
    escaped
}
/// Escapes `value` for use inside a single-quoted T-SQL string literal by
/// doubling every single quote.
fn escape_sql_string(value: &str) -> String {
    // Splitting on the quote and re-joining with two quotes doubles each one.
    value.split('\'').collect::<Vec<_>>().join("''")
}
/// Renders `bytes` as a T-SQL binary literal: `0x` followed by two uppercase
/// hexadecimal digits per byte.
///
/// An empty slice yields `0x`.
fn format_hex(bytes: &[u8]) -> String {
    const DIGITS: &[u8; 16] = b"0123456789ABCDEF";

    let mut hex = String::with_capacity(2 + bytes.len() * 2);
    hex.push_str("0x");
    for &byte in bytes {
        // Push the two nibbles directly instead of allocating a temporary
        // `String` through `format!` for every byte.
        hex.push(char::from(DIGITS[usize::from(byte >> 4)]));
        hex.push(char::from(DIGITS[usize::from(byte & 0x0F)]));
    }
    hex
}
/// Splits a possibly schema-qualified table name at its first `.`.
///
/// Returns `(schema, table)`. When `table_name` contains no `.`, the schema is
/// the empty string and the table is the whole input. Anything after the first
/// `.` (including further dots) belongs to the table part.
fn split_table_name(table_name: &str) -> (&str, &str) {
    table_name.split_once('.').unwrap_or(("", table_name))
}
/// Produces the bracket-quoted form of a possibly schema-qualified table name.
///
/// `schema.table` becomes `[schema].[table]`; a name without a schema part
/// becomes `[table]`. Closing brackets inside either part are escaped.
fn quoted_table_name(table_name: &str) -> String {
    match split_table_name(table_name) {
        ("", table) => format!("[{}]", escape_sql_bracket(table)),
        (schema, table) => format!(
            "[{}].[{}]",
            escape_sql_bracket(schema),
            escape_sql_bracket(table),
        ),
    }
}
impl MigrateDatabase for Mssql {
fn create_database(url: &str) -> impl std::future::Future<Output = Result<(), Error>> + Send + '_ {
async move {
let options = MssqlConnectOptions::from_str(url)?;
let database = extract_database_name(url)?;
let master_options = options.with_database("master");
let conn = MssqlConnection::connect_blocking(&master_options)?;
conn.exec_sql_blocking(&format!(
"CREATE DATABASE [{}]",
escape_sql_bracket(&database),
))?;
drop(conn);
Ok(())
}
}
fn database_exists(url: &str) -> impl std::future::Future<Output = Result<bool, Error>> + Send + '_ {
async move {
let options = MssqlConnectOptions::from_str(url)?;
if MssqlConnection::connect_blocking(&options).is_ok() {
return Ok(true);
}
let database = extract_database_name(url)?;
let master_options = options.with_database("master");
let conn = match MssqlConnection::connect_blocking(&master_options) {
Ok(conn) => conn,
Err(_) => return Ok(false),
};
let sql = format!(
"SELECT COUNT(*) FROM sys.databases WHERE name = N'{}'",
escape_sql_string(&database),
);
let count = conn
.scalar_i64_blocking(&sql)?
.unwrap_or(0);
drop(conn);
Ok(count > 0)
}
}
fn drop_database(url: &str) -> impl std::future::Future<Output = Result<(), Error>> + Send + '_ {
async move {
let options = MssqlConnectOptions::from_str(url)?;
let database = extract_database_name(url)?;
let master_options = options.with_database("master");
let conn = MssqlConnection::connect_blocking(&master_options)?;
conn.exec_sql_blocking(&format!(
"DROP DATABASE IF EXISTS [{}]",
escape_sql_bracket(&database),
))?;
drop(conn);
Ok(())
}
}
}
/// Builds the `INSERT` statement that records `migration` in the migrations
/// table `quoted_table`, storing `sql_text` in the `sql` column.
fn migration_insert_sql(quoted_table: &str, migration: &Migration, sql_text: &str) -> String {
    format!(
        "INSERT INTO {quoted} \
         (version, description, migration_type, sql, checksum, no_tx) \
         VALUES ({version}, N'{desc}', N'{mt}', N'{sql_text}', {chk}, {ntx})",
        quoted = quoted_table,
        version = migration.version,
        desc = escape_sql_string(&migration.description),
        mt = escape_sql_string(&format!("{:?}", migration.migration_type)),
        sql_text = escape_sql_string(sql_text),
        chk = format_hex(&migration.checksum),
        ntx = if migration.no_tx { 1 } else { 0 },
    )
}

/// Migration bookkeeping for a SQL Server connection.
///
/// All statements are built as text (identifiers bracket-escaped, string
/// values quote-escaped) and executed through the connection's blocking
/// helpers.
impl Migrate for MssqlConnection {
    /// Creates the schema `schema_name` unless `sys.schemas` already lists it.
    fn create_schema_if_not_exists<'e>(
        &'e mut self,
        schema_name: &'e str,
    ) -> BoxFuture<'e, Result<(), MigrateError>> {
        // The CREATE SCHEMA statement is itself embedded in a string literal
        // passed to EXEC, so the bracket-escaped identifier must additionally
        // have its single quotes doubled.
        let sql = format!(
            "IF NOT EXISTS (SELECT * FROM sys.schemas WHERE name = N'{}') \
             EXEC('CREATE SCHEMA [{}]')",
            escape_sql_string(schema_name),
            escape_sql_string(&escape_sql_bracket(schema_name)),
        );
        Box::pin(async move {
            self.exec_sql_blocking(&sql).map_err(MigrateError::Execute)?;
            Ok(())
        })
    }

    /// Creates the migrations table `table_name` if no object of that name
    /// resolves yet.
    ///
    /// `table_name` may be schema-qualified (`schema.table`).
    fn ensure_migrations_table<'e>(
        &'e mut self,
        table_name: &'e str,
    ) -> BoxFuture<'e, Result<(), MigrateError>> {
        let quoted = quoted_table_name(table_name);
        // OBJECT_ID resolves the (possibly unqualified) name the same way the
        // later statements do, so the check stays correct when the caller's
        // default schema is not `dbo`.
        let create_sql = format!(
            "IF OBJECT_ID(N'{object_name}') IS NULL \
             CREATE TABLE {quoted} ( \
             version BIGINT NOT NULL PRIMARY KEY, \
             description NVARCHAR(MAX) NOT NULL, \
             migration_type NVARCHAR(20) NOT NULL, \
             sql NVARCHAR(MAX) NOT NULL, \
             checksum VARBINARY(8000) NOT NULL, \
             executed_at DATETIME2 NOT NULL DEFAULT GETUTCDATE(), \
             no_tx BIT NOT NULL DEFAULT 0 \
             )",
            object_name = escape_sql_string(&quoted),
            quoted = quoted,
        );
        Box::pin(async move {
            self.exec_sql_blocking(&create_sql).map_err(MigrateError::Execute)?;
            Ok(())
        })
    }

    /// Always reports that there is no dirty version.
    ///
    /// The migrations table created by `ensure_migrations_table` has no column
    /// recording whether a migration completed, so nothing is queried.
    fn dirty_version<'e>(
        &'e mut self,
        _table_name: &'e str,
    ) -> BoxFuture<'e, Result<Option<i64>, MigrateError>> {
        Box::pin(async move { Ok(None) })
    }

    /// Lists the version and checksum of every recorded migration, ordered by
    /// version.
    fn list_applied_migrations<'e>(
        &'e mut self,
        table_name: &'e str,
    ) -> BoxFuture<'e, Result<Vec<AppliedMigration>, MigrateError>> {
        let quoted = quoted_table_name(table_name);
        let sql = format!("SELECT version, checksum FROM {quoted} ORDER BY version");
        Box::pin(async move {
            let rows = self.list_migrations_blocking(&sql)?;
            let migrations = rows
                .into_iter()
                .map(|(version, checksum)| AppliedMigration {
                    version,
                    checksum: checksum.into(),
                })
                .collect();
            Ok(migrations)
        })
    }

    /// Acquires the session-owned exclusive application lock
    /// `sqlx_migration_lock`.
    ///
    /// # Errors
    ///
    /// Fails when the statement cannot be executed or when `sp_getapplock`
    /// reports, through a negative return code, that the lock was not granted.
    fn lock(&mut self) -> BoxFuture<'_, Result<(), MigrateError>> {
        Box::pin(async move {
            // sp_getapplock signals "not granted" (timeout, cancellation,
            // deadlock victim, ...) through a negative return code rather than
            // an error, so turn that into an error explicitly.
            self.exec_sql_blocking(
                "DECLARE @lock_result INT; \
                 EXEC @lock_result = sp_getapplock \
                 @Resource = N'sqlx_migration_lock', \
                 @LockMode = 'Exclusive', \
                 @LockOwner = 'Session'; \
                 IF @lock_result < 0 \
                 RAISERROR(N'sp_getapplock could not acquire sqlx_migration_lock (return code %d)', 16, 1, @lock_result)",
            )
            .map_err(MigrateError::Execute)?;
            Ok(())
        })
    }

    /// Releases the session-owned application lock `sqlx_migration_lock`.
    fn unlock(&mut self) -> BoxFuture<'_, Result<(), MigrateError>> {
        Box::pin(async move {
            self.exec_sql_blocking(
                "EXEC sp_releaseapplock \
                 @Resource = N'sqlx_migration_lock', \
                 @LockOwner = 'Session'",
            )
            .map_err(MigrateError::Execute)?;
            Ok(())
        })
    }

    /// Applies `migration` and records it in `table_name`.
    ///
    /// The migration SQL and the bookkeeping `INSERT` are handed to
    /// `apply_migration_blocking` together with the version and the `no_tx`
    /// flag; its result is returned.
    ///
    /// # Errors
    ///
    /// Any failure is reported as [`MigrateError::ExecuteMigration`] carrying
    /// the migration version.
    fn apply<'e>(
        &'e mut self,
        table_name: &'e str,
        migration: &'e Migration,
    ) -> BoxFuture<'e, Result<Duration, MigrateError>> {
        let insert_sql = migration_insert_sql(
            &quoted_table_name(table_name),
            migration,
            migration.sql.as_str(),
        );
        let sql = migration.sql.as_str().to_owned();
        let version = migration.version;
        let no_tx = migration.no_tx;
        Box::pin(async move {
            self.apply_migration_blocking(&sql, &insert_sql, version, no_tx)
                .map_err(|e| MigrateError::ExecuteMigration(e, version))
        })
    }

    /// Reverts `migration` and removes its record from `table_name`.
    ///
    /// The migration SQL and the bookkeeping `DELETE` are handed to
    /// `revert_migration_blocking` together with the version and the `no_tx`
    /// flag; its result is returned.
    ///
    /// # Errors
    ///
    /// Any failure is reported as [`MigrateError::ExecuteMigration`] carrying
    /// the migration version.
    fn revert<'e>(
        &'e mut self,
        table_name: &'e str,
        migration: &'e Migration,
    ) -> BoxFuture<'e, Result<Duration, MigrateError>> {
        let quoted = quoted_table_name(table_name);
        let sql = migration.sql.as_str().to_owned();
        let version = migration.version;
        let no_tx = migration.no_tx;
        let delete_sql = format!("DELETE FROM {quoted} WHERE version = {version}");
        Box::pin(async move {
            self.revert_migration_blocking(&sql, &delete_sql, version, no_tx)
                .map_err(|e| MigrateError::ExecuteMigration(e, version))
        })
    }

    /// Records `migration` in `table_name` without running its SQL.
    ///
    /// The stored `sql` column is left empty.
    ///
    /// # Errors
    ///
    /// A failed insert is reported as [`MigrateError::ExecuteMigration`]
    /// carrying the migration version.
    fn skip<'e>(
        &'e mut self,
        table_name: &'e str,
        migration: &'e Migration,
    ) -> BoxFuture<'e, Result<(), MigrateError>> {
        let insert_sql = migration_insert_sql(&quoted_table_name(table_name), migration, "");
        let version = migration.version;
        Box::pin(async move {
            self.exec_sql_blocking(&insert_sql)
                .map_err(|e| MigrateError::ExecuteMigration(e, version))
        })
    }
}