use std::{collections::HashMap, future::Future, pin::Pin};
#[cfg(any(feature = "sqlite", feature = "postgres", feature = "mariadb"))]
use quex::{self, FromRow, Row};
#[cfg(any(feature = "sqlite", feature = "postgres", feature = "mariadb"))]
use crate::context::execute_table_blueprint;
use crate::{
AlterTableBlueprint, BlueprintExecutor, ColumnType, IndexBlueprint, IntoSchemaColumns,
MigrationError, SchemaDialect, TableBlueprint,
};
/// Boxed future returned by the type-erased `up`/`down` function pointers stored in migration
/// entries. It carries no `Send` bound, so it may borrow the (non-`Send`) context for `'a`.
pub type MigrationFuture<'a> = Pin<Box<dyn Future<Output = Result<(), MigrationError>> + 'a>>;

/// Error reported by the `Disabled` context variant when the crate is built without any
/// database backend feature.
// Within this file the only callers are the `Disabled` match arms, which exist solely when no
// backend feature is enabled; in every other build the function is legitimately unused, so the
// lint is silenced only for those builds instead of unconditionally.
#[cfg_attr(
    any(feature = "sqlite", feature = "postgres", feature = "mariadb"),
    allow(dead_code)
)]
fn no_backend_error() -> MigrationError {
    MigrationError::BackendNotEnabled("no backend")
}
/// Backend-agnostic migration registration, collected through `inventory`.
///
/// The `up`/`down` pointers receive the shared [`MigrationContext`] enum, so a single migration
/// body can run against whichever backend context it is handed.
#[derive(Clone, Copy)]
pub struct MigrationEntry {
    /// Name of the migration.
    pub name: &'static str,
    /// Version number associated with the migration.
    pub version: u64,
    /// Forward step.
    pub up: for<'a> fn(&'a mut MigrationContext<'a>) -> MigrationFuture<'a>,
    /// Reverse step.
    pub down: for<'a> fn(&'a mut MigrationContext<'a>) -> MigrationFuture<'a>,
}

impl MigrationEntry {
    /// Assembles an entry from its parts; `const`, so it is usable in constant contexts such as
    /// static registrations.
    pub const fn new(
        name: &'static str,
        version: u64,
        up: for<'a> fn(&'a mut MigrationContext<'a>) -> MigrationFuture<'a>,
        down: for<'a> fn(&'a mut MigrationContext<'a>) -> MigrationFuture<'a>,
    ) -> Self {
        Self { name, version, up, down }
    }
}

inventory::collect!(MigrationEntry);
/// A migration written against the backend-erased [`MigrationContext`].
// `async fn` in a public trait trips `async_fn_in_trait` because implementors' futures cannot be
// constrained to `Send` by callers; `MigrationFuture` carries no `Send` bound either, so the lint
// is accepted here.
#[allow(async_fn_in_trait)]
pub trait Migration {
    /// Forward (`up`) step of the migration.
    async fn up(ctx: &mut MigrationContext<'_>) -> Result<(), MigrationError>;

    /// Reverse (`down`) step of the migration.
    async fn down(ctx: &mut MigrationContext<'_>) -> Result<(), MigrationError>;
}
/// Target for a backend context's SQL: the pool itself, or a transaction borrowed from the
/// caller.
#[cfg(any(feature = "sqlite", feature = "postgres", feature = "mariadb"))]
enum MigrationExecutor<'a> {
    Pool(&'a quex::Pool),
    Transaction(&'a mut quex::PoolTransaction),
}

/// Memo of resolved column types, keyed by `(table, column)`.
#[cfg(any(feature = "sqlite", feature = "postgres", feature = "mariadb"))]
type ColumnTypeCache = HashMap<(String, String), ColumnType>;

#[cfg(any(feature = "sqlite", feature = "postgres", feature = "mariadb"))]
impl MigrationExecutor<'_> {
    /// Executes `sql` on whichever executor is wrapped and yields `quex`'s affected-row count.
    async fn execute_raw(&mut self, sql: &str) -> Result<u64, MigrationError> {
        let query = quex::query(sql);
        let outcome = match self {
            Self::Pool(pool) => query.execute(*pool).await?,
            // Reborrow so the transaction stays owned by the caller.
            Self::Transaction(tx) => query.execute(&mut **tx).await?,
        };
        Ok(outcome.rows_affected)
    }
}
/// One `pragma_table_info` row, reduced to its `type` column.
#[cfg(any(feature = "sqlite", feature = "postgres", feature = "mariadb"))]
struct SqliteColumnTypeRow {
    /// Declared type text read from the pragma's `type` column.
    data_type: String,
}

#[cfg(any(feature = "sqlite", feature = "postgres", feature = "mariadb"))]
impl FromRow for SqliteColumnTypeRow {
    fn from_row(row: &Row) -> quex::Result<Self> {
        let data_type: String = row.get("type")?;
        Ok(Self { data_type })
    }
}
/// The subset of an `information_schema.columns` row needed to reconstruct a column type.
#[cfg(any(feature = "sqlite", feature = "postgres", feature = "mariadb"))]
struct InformationSchemaColumnRow {
    data_type: String,
    udt_name: Option<String>,
    character_maximum_length: Option<i64>,
    numeric_precision: Option<i64>,
    numeric_scale: Option<i64>,
}

#[cfg(any(feature = "sqlite", feature = "postgres", feature = "mariadb"))]
impl FromRow for InformationSchemaColumnRow {
    fn from_row(row: &Row) -> quex::Result<Self> {
        // Columns are decoded in the same order as they are selected; the first failure wins.
        let data_type: String = row.get("data_type")?;
        let udt_name: Option<String> = row.get("udt_name")?;
        let character_maximum_length: Option<i64> = row.get("character_maximum_length")?;
        let numeric_precision: Option<i64> = row.get("numeric_precision")?;
        let numeric_scale: Option<i64> = row.get("numeric_scale")?;
        Ok(Self {
            data_type,
            udt_name,
            character_maximum_length,
            numeric_precision,
            numeric_scale,
        })
    }
}
#[cfg(any(feature = "sqlite", feature = "postgres", feature = "mariadb"))]
/// Wraps `value` in single quotes as a SQL string literal, doubling every embedded `'`.
fn quote_string_literal(value: &str) -> String {
    let mut literal = String::with_capacity(value.len() + 2);
    literal.push('\'');
    for ch in value.chars() {
        if ch == '\'' {
            literal.push_str("''");
        } else {
            literal.push(ch);
        }
    }
    literal.push('\'');
    literal
}
/// Maps a SQLite declared column type (the `type` column of `pragma_table_info`) to a
/// [`ColumnType`].
#[cfg(any(feature = "sqlite", feature = "postgres", feature = "mariadb"))]
fn parse_sqlite_column_type(data_type: &str) -> Result<ColumnType, MigrationError> {
    parse_column_type_string(data_type)
}

/// Maps an `information_schema.columns` row (Postgres or MariaDB) to a [`ColumnType`].
///
/// Sized types read their length/precision/scale columns. When those columns are NULL the
/// fallbacks are: `varchar` → 255, `char` → 1, `numeric`/`decimal` → (10, 0).
///
/// # Errors
/// Returns `UnsupportedColumnType` when the data type is not recognised, or when a size column
/// does not fit in a `u32`. One example is the negative `numeric` scale PostgreSQL 15+ accepts.
/// Such values are rejected rather than wrapped into a huge number.
#[cfg(any(feature = "sqlite", feature = "postgres", feature = "mariadb"))]
fn parse_information_schema_column_type(
    row: &InformationSchemaColumnRow,
) -> Result<ColumnType, MigrationError> {
    // Checked before `data_type`; for MariaDB the lookup query aliases `data_type` as `udt_name`.
    if matches!(row.udt_name.as_deref(), Some("uuid")) {
        return Ok(ColumnType::Uuid);
    }
    // NULL size columns fall back to `default`. Values outside `u32` are rejected, not truncated.
    let size = |value: Option<i64>, default: u32| -> Result<u32, MigrationError> {
        match value {
            None => Ok(default),
            Some(raw) => u32::try_from(raw).map_err(|_| {
                MigrationError::UnsupportedColumnType(format!(
                    "{} (size {raw} out of range)",
                    row.data_type
                ))
            }),
        }
    };
    match row.data_type.as_str() {
        "character varying" | "varchar" => {
            Ok(ColumnType::Varchar(size(row.character_maximum_length, 255)?))
        }
        "character" | "char" => Ok(ColumnType::Char(size(row.character_maximum_length, 1)?)),
        "numeric" | "decimal" => Ok(ColumnType::Decimal(
            size(row.numeric_precision, 10)?,
            size(row.numeric_scale, 0)?,
        )),
        other => parse_column_type_string(other),
    }
}
/// Parses a textual column type into a [`ColumnType`]. Matching is case-insensitive and ignores
/// surrounding whitespace.
///
/// Accepted parameterised forms:
/// - `varchar(n)` / `character varying(n)`
/// - `char(n)` / `character(n)`
/// - `decimal(p[, s])` / `numeric(p[, s])`, where an omitted scale means 0 as in standard SQL
///
/// Bare names cover the remaining scalar types, including the long-form names that
/// `information_schema` reports.
///
/// # Errors
/// Returns `UnsupportedColumnType`, carrying the caller's original text, when:
/// - the name is unknown,
/// - an argument is non-numeric or out of range, or
/// - the number of arguments is wrong for the type.
#[cfg(any(feature = "sqlite", feature = "postgres", feature = "mariadb"))]
fn parse_column_type_string(data_type: &str) -> Result<ColumnType, MigrationError> {
    let unsupported = || MigrationError::UnsupportedColumnType(data_type.to_owned());
    let normalized = data_type.trim().to_ascii_lowercase();

    if let Some((name, args)) = split_type_arguments(&normalized) {
        let params = args
            .split(',')
            .map(|arg| arg.trim().parse::<u32>().map_err(|_| unsupported()))
            .collect::<Result<Vec<u32>, MigrationError>>()?;
        return match (name, params.as_slice()) {
            ("varchar" | "character varying", &[length]) => Ok(ColumnType::Varchar(length)),
            ("char" | "character", &[length]) => Ok(ColumnType::Char(length)),
            ("decimal" | "numeric", &[precision]) => Ok(ColumnType::Decimal(precision, 0)),
            ("decimal" | "numeric", &[precision, scale]) => {
                Ok(ColumnType::Decimal(precision, scale))
            }
            // Unknown parameterised type, or the wrong number of arguments for a known one.
            _ => Err(unsupported()),
        };
    }

    match normalized.as_str() {
        "integer" | "int" => Ok(ColumnType::Integer),
        "bigint" => Ok(ColumnType::BigInt),
        "boolean" | "bool" => Ok(ColumnType::Bool),
        "text" => Ok(ColumnType::Text),
        "date" => Ok(ColumnType::Date),
        "time" | "time without time zone" => Ok(ColumnType::Time),
        "timestamp" | "timestamp without time zone" | "timestamp with time zone" | "datetime" => {
            Ok(ColumnType::Timestamp)
        }
        "json" | "jsonb" => Ok(ColumnType::Json),
        "uuid" => Ok(ColumnType::Uuid),
        "real" | "float" => Ok(ColumnType::Float),
        "double precision" | "double" => Ok(ColumnType::Double),
        _ => Err(unsupported()),
    }
}

/// Splits a normalized `name(args)` type into `("name", "args")`. Returns `None` when the text
/// does not end in a parenthesised argument list.
#[cfg(any(feature = "sqlite", feature = "postgres", feature = "mariadb"))]
fn split_type_arguments(normalized: &str) -> Option<(&str, &str)> {
    let (name, args) = normalized.strip_suffix(')')?.split_once('(')?;
    Some((name.trim_end(), args))
}
/// Builds the SQLite lookup for a column's declared type.
///
/// The table name is the argument to `pragma_table_info` and is inlined as an escaped string
/// literal. The column name stays a bound `?` parameter.
#[cfg(any(feature = "sqlite", feature = "postgres", feature = "mariadb"))]
fn sqlite_column_type_sql(table: &str) -> String {
    format!(
        "select type from pragma_table_info({}) where name = ? limit 1",
        quote_string_literal(table)
    )
}

/// Postgres column-type lookup in the current schema. Binds the table, then the column.
#[cfg(any(feature = "sqlite", feature = "postgres", feature = "mariadb"))]
const POSTGRES_COLUMN_TYPE_SQL: &str =
    "select data_type, udt_name, character_maximum_length, numeric_precision, numeric_scale \
     from information_schema.columns \
     where table_schema = current_schema() and table_name = ? and column_name = ? \
     limit 1";

/// MariaDB column-type lookup in the current database. Binds the table, then the column.
///
/// `data_type` is selected a second time as `udt_name`, so the same
/// `InformationSchemaColumnRow` decoder applies.
#[cfg(any(feature = "sqlite", feature = "postgres", feature = "mariadb"))]
const MARIADB_COLUMN_TYPE_SQL: &str =
    "select data_type, data_type as udt_name, character_maximum_length, numeric_precision, numeric_scale \
     from information_schema.columns \
     where table_schema = database() and table_name = ? and column_name = ? \
     limit 1";

/// Resolves the type of `table.column` by querying schema metadata through `pool`.
///
/// # Errors
/// Propagates query errors, converted through `?`; this presumably includes a missing column
/// (TODO confirm `quex`'s `one` behaviour on zero rows). Also propagates type-mapping errors.
#[cfg(any(feature = "sqlite", feature = "postgres", feature = "mariadb"))]
async fn resolve_column_type_from_pool(
    dialect: SchemaDialect,
    pool: &quex::Pool,
    table: &str,
    column: &str,
) -> Result<ColumnType, MigrationError> {
    match dialect {
        SchemaDialect::Sqlite => {
            let sql = sqlite_column_type_sql(table);
            let row = quex::query(&sql)
                .bind(column)
                .one::<SqliteColumnTypeRow>(pool)
                .await?;
            parse_sqlite_column_type(&row.data_type)
        }
        SchemaDialect::Postgres => {
            let row = quex::query(POSTGRES_COLUMN_TYPE_SQL)
                .bind(table)
                .bind(column)
                .one::<InformationSchemaColumnRow>(pool)
                .await?;
            parse_information_schema_column_type(&row)
        }
        SchemaDialect::MariaDb => {
            let row = quex::query(MARIADB_COLUMN_TYPE_SQL)
                .bind(table)
                .bind(column)
                .one::<InformationSchemaColumnRow>(pool)
                .await?;
            parse_information_schema_column_type(&row)
        }
    }
}

/// Same lookup as `resolve_column_type_from_pool`, but runs inside the caller's transaction so
/// it sees that transaction's uncommitted schema changes.
#[cfg(any(feature = "sqlite", feature = "postgres", feature = "mariadb"))]
async fn resolve_column_type_from_tx(
    dialect: SchemaDialect,
    tx: &mut quex::PoolTransaction,
    table: &str,
    column: &str,
) -> Result<ColumnType, MigrationError> {
    match dialect {
        SchemaDialect::Sqlite => {
            let sql = sqlite_column_type_sql(table);
            let row = quex::query(&sql)
                .bind(column)
                .one::<SqliteColumnTypeRow>(&mut *tx)
                .await?;
            parse_sqlite_column_type(&row.data_type)
        }
        SchemaDialect::Postgres => {
            let row = quex::query(POSTGRES_COLUMN_TYPE_SQL)
                .bind(table)
                .bind(column)
                .one::<InformationSchemaColumnRow>(&mut *tx)
                .await?;
            parse_information_schema_column_type(&row)
        }
        SchemaDialect::MariaDb => {
            let row = quex::query(MARIADB_COLUMN_TYPE_SQL)
                .bind(table)
                .bind(column)
                .one::<InformationSchemaColumnRow>(&mut *tx)
                .await?;
            parse_information_schema_column_type(&row)
        }
    }
}
/// Generates the backend-specific migration API for one database feature.
///
/// Every generated item is gated on `feature`. The macro emits:
/// - `$context`: a migration context bound to one executor (a pool or a borrowed transaction),
///   with a per-context column-type cache,
/// - a `BlueprintExecutor` impl for `$context`,
/// - `$entry`: an `inventory`-collected registration holding `up`/`down` function pointers,
/// - `$entry_trait`: an `async` trait for migrations written directly against `$context`.
macro_rules! define_backend {
    (
        feature = $feature:literal,
        context = $context:ident,
        entry = $entry:ident,
        entry_trait = $entry_trait:ident,
        dialect = $dialect:expr
    ) => {
        /// Migration context for a single backend. SQL runs either on a pool or inside a
        /// transaction borrowed from the caller.
        #[cfg(feature = $feature)]
        pub struct $context<'a> {
            executor: MigrationExecutor<'a>,
            // Memoised `column_type` lookups. Cleared whenever SQL is executed through this
            // context, because any statement may alter the schema and leave entries stale.
            column_type_cache: ColumnTypeCache,
        }

        #[cfg(feature = $feature)]
        impl<'a> $context<'a> {
            const SCHEMA_DIALECT: SchemaDialect = $dialect;

            /// Creates a context that runs each statement directly on `executor`.
            pub fn new(executor: &'a quex::Pool) -> Self {
                Self {
                    executor: MigrationExecutor::Pool(executor),
                    column_type_cache: HashMap::new(),
                }
            }

            /// Creates a context that runs every statement inside the caller's transaction.
            pub fn from_transaction(executor: &'a mut quex::PoolTransaction) -> Self {
                Self {
                    executor: MigrationExecutor::Transaction(executor),
                    column_type_cache: HashMap::new(),
                }
            }

            /// Executes `sql` and returns the affected-row count.
            ///
            /// All cached column types are discarded first, because the statement may be DDL
            /// that changes them. The cache is cleared even if execution then fails. The
            /// `alter_table`, `drop` and index helpers call this directly, and
            /// `execute_raw_blueprint` delegates here too.
            pub async fn execute_raw(&mut self, sql: &str) -> Result<u64, MigrationError> {
                self.column_type_cache.clear();
                self.executor.execute_raw(sql).await
            }

            /// Resolves the type of `table.column` from the live schema. The result is cached
            /// until the next statement is executed through this context (see `execute_raw`).
            ///
            /// # Errors
            /// Propagates lookup query errors and `UnsupportedColumnType` when the reported
            /// type cannot be mapped.
            pub async fn column_type(
                &mut self,
                table: &str,
                column: &str,
            ) -> Result<ColumnType, MigrationError> {
                let cache_key = (table.to_owned(), column.to_owned());
                if let Some(cached) = self.column_type_cache.get(&cache_key) {
                    return Ok(cached.clone());
                }
                let resolved = match &mut self.executor {
                    MigrationExecutor::Pool(pool) => {
                        resolve_column_type_from_pool(Self::SCHEMA_DIALECT, pool, table, column)
                            .await
                    }
                    MigrationExecutor::Transaction(tx) => {
                        resolve_column_type_from_tx(Self::SCHEMA_DIALECT, tx, table, column).await
                    }
                }?;
                self.column_type_cache.insert(cache_key, resolved.clone());
                Ok(resolved)
            }

            /// Describes a new table with `build` and creates it.
            pub async fn create(
                &mut self,
                name: &str,
                build: impl FnOnce(&mut TableBlueprint),
            ) -> Result<(), MigrationError> {
                let mut table = TableBlueprint::new(name);
                build(&mut table);
                let outcome = execute_table_blueprint(&mut *self, table).await;
                // `execute_table_blueprint` lives in another module. Drop cached types here as
                // well, so no stale entry survives whichever path it uses to run SQL.
                self.column_type_cache.clear();
                outcome
            }

            /// Describes changes to table `name` with `build` and runs the resulting statements
            /// in order, stopping at the first failure.
            pub async fn alter_table(
                &mut self,
                name: &str,
                build: impl FnOnce(&mut AlterTableBlueprint),
            ) -> Result<(), MigrationError> {
                let mut table = AlterTableBlueprint::new(name);
                build(&mut table);
                for sql in table.sql_statements(Self::SCHEMA_DIALECT) {
                    self.execute_raw(&sql).await?;
                }
                Ok(())
            }

            /// Alias for `alter_table`.
            pub async fn table(
                &mut self,
                name: &str,
                build: impl FnOnce(&mut AlterTableBlueprint),
            ) -> Result<(), MigrationError> {
                self.alter_table(name, build).await
            }

            /// Drops table `name` using this backend's drop SQL.
            pub async fn drop(&mut self, name: &str) -> Result<(), MigrationError> {
                let table = TableBlueprint::new(name);
                self.execute_raw(&table.drop_sql(Self::SCHEMA_DIALECT))
                    .await?;
                Ok(())
            }

            /// Creates index `name` on `table` over `columns`.
            pub async fn create_index(
                &mut self,
                name: &str,
                table: &str,
                columns: impl IntoSchemaColumns,
            ) -> Result<(), MigrationError> {
                let index = IndexBlueprint::new(name, table, columns);
                self.execute_raw(&index.create_sql(Self::SCHEMA_DIALECT))
                    .await?;
                Ok(())
            }

            /// Creates unique index `name` on `table` over `columns`.
            pub async fn create_unique_index(
                &mut self,
                name: &str,
                table: &str,
                columns: impl IntoSchemaColumns,
            ) -> Result<(), MigrationError> {
                let index = IndexBlueprint::new_unique(name, table, columns);
                self.execute_raw(&index.create_sql(Self::SCHEMA_DIALECT))
                    .await?;
                Ok(())
            }

            /// Drops index `name`.
            pub async fn drop_index(&mut self, name: &str) -> Result<(), MigrationError> {
                let index = IndexBlueprint::named(name);
                self.execute_raw(&index.drop_sql(Self::SCHEMA_DIALECT))
                    .await?;
                Ok(())
            }
        }

        #[cfg(feature = $feature)]
        impl<'a> BlueprintExecutor for $context<'a> {
            fn dialect(&self) -> SchemaDialect {
                Self::SCHEMA_DIALECT
            }

            // Delegates to the inherent `execute_raw`, so blueprint SQL also invalidates the
            // column-type cache.
            async fn execute_raw_blueprint(&mut self, sql: &str) -> Result<u64, MigrationError> {
                Self::execute_raw(self, sql).await
            }
        }

        /// Registration for a migration written against this backend's context, collected
        /// through `inventory`.
        #[cfg(feature = $feature)]
        #[derive(Clone, Copy)]
        pub struct $entry {
            pub name: &'static str,
            pub version: u64,
            pub up: for<'a> fn(&'a mut $context<'a>) -> MigrationFuture<'a>,
            pub down: for<'a> fn(&'a mut $context<'a>) -> MigrationFuture<'a>,
        }

        #[cfg(feature = $feature)]
        impl $entry {
            /// Assembles an entry. It is `const`, so it is usable in constant contexts.
            pub const fn new(
                name: &'static str,
                version: u64,
                up: for<'a> fn(&'a mut $context<'a>) -> MigrationFuture<'a>,
                down: for<'a> fn(&'a mut $context<'a>) -> MigrationFuture<'a>,
            ) -> Self {
                Self {
                    name,
                    version,
                    up,
                    down,
                }
            }
        }

        #[cfg(feature = $feature)]
        inventory::collect!($entry);

        /// A migration written directly against this backend's context.
        #[cfg(feature = $feature)]
        #[allow(async_fn_in_trait)]
        pub trait $entry_trait {
            async fn up(ctx: &mut $context<'_>) -> Result<(), MigrationError>;
            async fn down(ctx: &mut $context<'_>) -> Result<(), MigrationError>;
        }
    };
}
// One migration API per database backend. Each expansion is gated on its own feature, so
// backends that are not enabled generate nothing.
define_backend! {
    feature = "sqlite",
    context = SqliteMigrationContext,
    entry = SqliteMigrationEntry,
    entry_trait = SqliteMigration,
    dialect = SchemaDialect::Sqlite
}

define_backend! {
    feature = "postgres",
    context = PostgresMigrationContext,
    entry = PostgresMigrationEntry,
    entry_trait = PostgresMigration,
    dialect = SchemaDialect::Postgres
}

define_backend! {
    feature = "mariadb",
    context = MariadbMigrationContext,
    entry = MariadbMigrationEntry,
    entry_trait = MariadbMigration,
    dialect = SchemaDialect::MariaDb
}
/// Backend-erased migration context that wraps whichever backend-specific context is compiled
/// in.
///
/// With no backend feature enabled, only `Disabled` exists and its executing operations fail
/// with `MigrationError::BackendNotEnabled`.
pub enum MigrationContext<'a> {
    /// SQLite-backed context.
    #[cfg(feature = "sqlite")]
    Sqlite(SqliteMigrationContext<'a>),
    /// Postgres-backed context.
    #[cfg(feature = "postgres")]
    Postgres(PostgresMigrationContext<'a>),
    /// MariaDB-backed context.
    #[cfg(feature = "mariadb")]
    Mariadb(MariadbMigrationContext<'a>),
    /// Placeholder when no backend is compiled in. The `PhantomData` keeps `'a` in use.
    #[cfg(not(any(feature = "sqlite", feature = "postgres", feature = "mariadb")))]
    Disabled(std::marker::PhantomData<&'a ()>),
}
impl<'a> MigrationContext<'a> {
    /// SQL dialect of the wrapped backend context.
    ///
    /// The `Disabled` variant reports `Sqlite` as a placeholder. Every executing operation on
    /// it fails with `BackendNotEnabled`.
    pub fn dialect(&self) -> SchemaDialect {
        match self {
            #[cfg(feature = "sqlite")]
            Self::Sqlite(_) => SchemaDialect::Sqlite,
            #[cfg(feature = "postgres")]
            Self::Postgres(_) => SchemaDialect::Postgres,
            #[cfg(feature = "mariadb")]
            Self::Mariadb(_) => SchemaDialect::MariaDb,
            #[cfg(not(any(feature = "sqlite", feature = "postgres", feature = "mariadb")))]
            Self::Disabled(_) => SchemaDialect::Sqlite,
        }
    }

    /// Runs raw SQL on the wrapped backend and returns the affected-row count.
    pub async fn execute_raw(&mut self, sql: &str) -> Result<u64, MigrationError> {
        match self {
            #[cfg(feature = "sqlite")]
            Self::Sqlite(ctx) => ctx.execute_raw(sql).await,
            #[cfg(feature = "postgres")]
            Self::Postgres(ctx) => ctx.execute_raw(sql).await,
            #[cfg(feature = "mariadb")]
            Self::Mariadb(ctx) => ctx.execute_raw(sql).await,
            #[cfg(not(any(feature = "sqlite", feature = "postgres", feature = "mariadb")))]
            Self::Disabled(_) => {
                // Consume the argument so the no-backend build compiles without warnings.
                let _ = sql;
                Err(no_backend_error())
            }
        }
    }

    /// Resolves the type of `table.column` through the wrapped backend context.
    pub async fn column_type(
        &mut self,
        table: &str,
        column: &str,
    ) -> Result<ColumnType, MigrationError> {
        match self {
            #[cfg(feature = "sqlite")]
            Self::Sqlite(ctx) => ctx.column_type(table, column).await,
            #[cfg(feature = "postgres")]
            Self::Postgres(ctx) => ctx.column_type(table, column).await,
            #[cfg(feature = "mariadb")]
            Self::Mariadb(ctx) => ctx.column_type(table, column).await,
            #[cfg(not(any(feature = "sqlite", feature = "postgres", feature = "mariadb")))]
            Self::Disabled(_) => {
                let _ = (table, column);
                Err(no_backend_error())
            }
        }
    }

    /// Describes a new table with `build` and creates it on the wrapped backend.
    pub async fn create(
        &mut self,
        name: &str,
        build: impl FnOnce(&mut TableBlueprint),
    ) -> Result<(), MigrationError> {
        // Match arms are mutually exclusive, so `build` moves straight into the active backend.
        match self {
            #[cfg(feature = "sqlite")]
            Self::Sqlite(ctx) => ctx.create(name, build).await,
            #[cfg(feature = "postgres")]
            Self::Postgres(ctx) => ctx.create(name, build).await,
            #[cfg(feature = "mariadb")]
            Self::Mariadb(ctx) => ctx.create(name, build).await,
            #[cfg(not(any(feature = "sqlite", feature = "postgres", feature = "mariadb")))]
            Self::Disabled(_) => {
                let _ = (name, build);
                Err(no_backend_error())
            }
        }
    }

    /// Describes changes to table `name` with `build` and applies them on the wrapped backend.
    pub async fn alter_table(
        &mut self,
        name: &str,
        build: impl FnOnce(&mut AlterTableBlueprint),
    ) -> Result<(), MigrationError> {
        // Match arms are mutually exclusive, so `build` moves straight into the active backend.
        match self {
            #[cfg(feature = "sqlite")]
            Self::Sqlite(ctx) => ctx.alter_table(name, build).await,
            #[cfg(feature = "postgres")]
            Self::Postgres(ctx) => ctx.alter_table(name, build).await,
            #[cfg(feature = "mariadb")]
            Self::Mariadb(ctx) => ctx.alter_table(name, build).await,
            #[cfg(not(any(feature = "sqlite", feature = "postgres", feature = "mariadb")))]
            Self::Disabled(_) => {
                let _ = (name, build);
                Err(no_backend_error())
            }
        }
    }

    /// Alias for `alter_table`.
    pub async fn table(
        &mut self,
        name: &str,
        build: impl FnOnce(&mut AlterTableBlueprint),
    ) -> Result<(), MigrationError> {
        self.alter_table(name, build).await
    }

    /// Drops table `name` on the wrapped backend.
    pub async fn drop(&mut self, name: &str) -> Result<(), MigrationError> {
        match self {
            #[cfg(feature = "sqlite")]
            Self::Sqlite(ctx) => ctx.drop(name).await,
            #[cfg(feature = "postgres")]
            Self::Postgres(ctx) => ctx.drop(name).await,
            #[cfg(feature = "mariadb")]
            Self::Mariadb(ctx) => ctx.drop(name).await,
            #[cfg(not(any(feature = "sqlite", feature = "postgres", feature = "mariadb")))]
            Self::Disabled(_) => {
                let _ = name;
                Err(no_backend_error())
            }
        }
    }

    /// Creates index `name` on `table` over `columns`, using SQL for the wrapped dialect.
    pub async fn create_index(
        &mut self,
        name: &str,
        table: &str,
        columns: impl IntoSchemaColumns,
    ) -> Result<(), MigrationError> {
        let index = IndexBlueprint::new(name, table, columns);
        self.execute_raw(&index.create_sql(self.dialect())).await?;
        Ok(())
    }

    /// Creates unique index `name` on `table` over `columns`.
    pub async fn create_unique_index(
        &mut self,
        name: &str,
        table: &str,
        columns: impl IntoSchemaColumns,
    ) -> Result<(), MigrationError> {
        let index = IndexBlueprint::new_unique(name, table, columns);
        self.execute_raw(&index.create_sql(self.dialect())).await?;
        Ok(())
    }

    /// Drops index `name`.
    pub async fn drop_index(&mut self, name: &str) -> Result<(), MigrationError> {
        let index = IndexBlueprint::named(name);
        self.execute_raw(&index.drop_sql(self.dialect())).await?;
        Ok(())
    }
}
impl<'a> BlueprintExecutor for MigrationContext<'a> {
fn dialect(&self) -> SchemaDialect {
Self::dialect(self)
}
async fn execute_raw_blueprint(&mut self, sql: &str) -> Result<u64, MigrationError> {
Self::execute_raw(self, sql).await
}
}