use std::{collections::HashMap, path::PathBuf};
use quex::{self, Driver, FromRow, Pool, Row};
use super::pool::MigrationPool;
use crate::{
AlterColumnBuilder, BlueprintExecutor, ColumnBuilder, MigrationContext, MigrationError,
MigrationReport, ResetReport,
files::{
AlterAction, AlterEnumAction, Field, FieldType, MigrationFile, MigrationOp, Nullability,
Reference, ReferenceAction, TableItem, default_migration_dir, load_migrations,
},
};
/// Migration command selected on the command line and executed by
/// `run_cli_command`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CliCommand {
    /// Apply all pending migrations (`Migrator::run`).
    Run,
    /// Roll back every batch, then re-apply all migrations (`Migrator::refresh`).
    Refresh,
    /// Roll back the most recent batch (`Migrator::rollback_last_batch`).
    RollbackLastBatch,
    /// Roll back this many of the most recently applied migrations
    /// (`Migrator::rollback_steps`).
    RollbackSteps(usize),
    /// Roll back every batch (`Migrator::reset`).
    Reset,
    /// Load and parse the migration files without changing the database.
    Validate,
}
/// Options for a single `run_cli_command` invocation.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CliOptions {
    /// Connection URL handed to `Migrator::connect`.
    pub database_url: String,
    /// Command to execute.
    pub command: CliCommand,
    /// Directory containing the migration files; `None` means
    /// `default_migration_dir()`.
    pub migration_dir: Option<PathBuf>,
}
pub async fn run_cli_command(options: &CliOptions) -> Result<MigrationReport, MigrationError> {
let migrator = Migrator::connect(&options.database_url).await?.path(
options
.migration_dir
.clone()
.unwrap_or_else(default_migration_dir),
);
match options.command {
CliCommand::Run => migrator.run().await,
CliCommand::Refresh => migrator.refresh().await,
CliCommand::RollbackLastBatch => migrator.rollback_last_batch().await,
CliCommand::RollbackSteps(steps) => migrator.rollback_steps(steps).await,
CliCommand::Reset => {
let report = migrator.reset().await?;
Ok(MigrationReport {
batch: None,
applied: Vec::new(),
rolled_back: report.rolled_back,
})
}
CliCommand::Validate => {
migrator.validate()?;
Ok(MigrationReport::default())
}
}
}
/// Row shape for `select max(batch) as batch ...`.
///
/// `batch` is `None` when the aggregate had no rows to work on.
struct BatchValue {
    batch: Option<i64>,
}

impl FromRow for BatchValue {
    fn from_row(row: &Row) -> quex::Result<Self> {
        let batch: Option<i64> = row.get("batch")?;
        Ok(Self { batch })
    }
}
/// Row shape for the `select 1 as value ...` table-existence probes.
struct ExistsValue {
    value: i64,
}

impl FromRow for ExistsValue {
    fn from_row(row: &Row) -> quex::Result<Self> {
        let value: i64 = row.get("value")?;
        Ok(Self { value })
    }
}
/// One bookkeeping row from the `migrations` table: id, recorded checksum and
/// status string.
struct AppliedMigrationRow {
    id: i64,
    checksum: String,
    status: String,
}

impl FromRow for AppliedMigrationRow {
    fn from_row(row: &Row) -> quex::Result<Self> {
        let id: i64 = row.get("id")?;
        let checksum: String = row.get("checksum")?;
        let status: String = row.get("status")?;
        Ok(Self {
            id,
            checksum,
            status,
        })
    }
}
/// Row shape for queries that select only the `id` column of `migrations`.
struct AppliedMigrationId {
    id: i64,
}

impl FromRow for AppliedMigrationId {
    fn from_row(row: &Row) -> quex::Result<Self> {
        let id: i64 = row.get("id")?;
        Ok(Self { id })
    }
}
/// Applies and rolls back file-based migrations, recording progress in a
/// `migrations` table.
pub struct Migrator {
    /// Backend-tagged pool used for bookkeeping queries and for the
    /// per-migration transactions.
    pub(crate) pool: MigrationPool,
    /// Directory migration files are loaded from.
    migration_dir: PathBuf,
}
impl Migrator {
    /// Builds a migrator from anything convertible into one, such as an owned
    /// or borrowed `quex::Pool`.
    ///
    /// # Panics
    ///
    /// The `quex::Pool` conversions panic when the pool's driver has no enabled
    /// backend feature.
    pub fn new(pool: impl Into<Self>) -> Self {
        pool.into()
    }

    /// Overrides the directory migration files are loaded from.
    pub fn path(mut self, path: impl Into<PathBuf>) -> Self {
        self.migration_dir = path.into();
        self
    }

    /// Connects to `database_url` with a single-connection pool and picks the
    /// backend from the pool's driver.
    ///
    /// The migration directory is `default_migration_dir()` until it is
    /// changed with [`Migrator::path`].
    ///
    /// # Errors
    ///
    /// Returns pool connect/build errors, or
    /// [`MigrationError::BackendNotEnabled`] when the driver's backend feature
    /// was not compiled in.
    pub async fn connect(database_url: &str) -> Result<Self, MigrationError> {
        let pool = Pool::connect(database_url)?.max_size(1).build().await?;
        let migrator = match pool.driver() {
            Driver::Sqlite => {
                #[cfg(feature = "sqlite")]
                {
                    Ok::<Self, MigrationError>(Self {
                        pool: MigrationPool::Sqlite(pool),
                        migration_dir: default_migration_dir(),
                    })
                }
                #[cfg(not(feature = "sqlite"))]
                {
                    Err(MigrationError::BackendNotEnabled("sqlite"))
                }
            }
            Driver::Pgsql => {
                #[cfg(feature = "postgres")]
                {
                    Ok::<Self, MigrationError>(Self {
                        pool: MigrationPool::Postgres(pool),
                        migration_dir: default_migration_dir(),
                    })
                }
                #[cfg(not(feature = "postgres"))]
                {
                    Err(MigrationError::BackendNotEnabled("postgres"))
                }
            }
            Driver::Mysql => {
                #[cfg(feature = "mariadb")]
                {
                    Ok::<Self, MigrationError>(Self {
                        pool: MigrationPool::Mariadb(pool),
                        migration_dir: default_migration_dir(),
                    })
                }
                #[cfg(not(feature = "mariadb"))]
                {
                    Err(MigrationError::BackendNotEnabled("mariadb"))
                }
            }
        }?;
        migration_trace!(
            backend = migrator.pool.backend_name(),
            "connecting migrator"
        );
        Ok(migrator)
    }

    /// Loads and parses every migration file without touching the database.
    ///
    /// # Errors
    ///
    /// Returns whatever error `load_migrations` reports for the migration
    /// directory.
    pub fn validate(&self) -> Result<(), MigrationError> {
        let _ = self.load_migrations()?;
        Ok(())
    }

    /// Rolls back batches newest-first until no applied migrations remain.
    ///
    /// The report lists rolled-back migration ids in rollback order, plus the
    /// batch numbers that were undone.
    pub async fn reset(&self) -> Result<ResetReport, MigrationError> {
        let mut rolled_back = Vec::new();
        let mut batches = Vec::new();
        loop {
            let batch_report = self.rollback_last_batch().await?;
            if batch_report.rolled_back.is_empty() {
                break;
            }
            rolled_back.extend(batch_report.rolled_back);
            if let Some(batch) = batch_report.batch {
                batches.push(batch);
            }
        }
        Ok(ResetReport {
            rolled_back,
            batches,
        })
    }

    /// Resets the database (see [`Migrator::reset`]) and then applies every
    /// migration again in a single new batch.
    pub async fn refresh(&self) -> Result<MigrationReport, MigrationError> {
        let reset_report = self.reset().await?;
        let up_report = self.run().await?;
        Ok(MigrationReport {
            batch: up_report.batch,
            applied: up_report.applied,
            rolled_back: reset_report.rolled_back,
        })
    }

    /// Applies every pending migration in one new batch.
    ///
    /// First, each migration that already has a row in `migrations` is checked:
    ///
    /// * `applied` rows must still match the file's checksum. Otherwise
    ///   [`MigrationError::ChecksumMismatch`] is returned.
    /// * `failed` rows are deleted and re-run when the backend supports a clean
    ///   retry. Otherwise [`MigrationError::FailedMigration`] is returned.
    /// * Any other status is reported as [`MigrationError::FailedMigration`].
    ///   The typical case is `running`, left behind when an earlier run stopped
    ///   between `mark_running` and recording the outcome.
    ///
    /// Each pending migration runs inside its own transaction. The first
    /// failure marks that migration `failed` and is returned; migrations
    /// applied earlier in the batch stay applied. Returns an empty report when
    /// nothing is pending.
    pub async fn run(&self) -> Result<MigrationReport, MigrationError> {
        migration_trace!(backend = self.pool.backend_name(), "starting migration run");
        self.ensure_table().await?;
        let migrations = self.load_migrations()?;
        let mut applied = self.applied_migrations().await?;
        let mut retryable_failed_ids = Vec::new();
        for migration in &migrations {
            let Some(row) = applied.get(&migration.id) else {
                continue;
            };
            match row.status.as_str() {
                "applied" => {
                    if row.checksum != migration.checksum {
                        return Err(MigrationError::ChecksumMismatch { id: migration.id });
                    }
                }
                "failed" if self.pool.supports_clean_failed_retry() => {
                    retryable_failed_ids.push(migration.id);
                }
                // Two cases land here: a `failed` row on a backend without clean
                // retry, and a row still marked `running` because an earlier run
                // stopped before recording the outcome. For the `running` case,
                // it is unknown whether the migration's changes were committed.
                // Skipping it as applied or re-running it could both be wrong,
                // so surface it instead.
                _ => return Err(MigrationError::FailedMigration { id: migration.id }),
            }
        }
        for id in retryable_failed_ids {
            self.delete_applied(id).await?;
            applied.remove(&id);
        }
        let pending: Vec<_> = migrations
            .into_iter()
            .filter(|migration| !applied.contains_key(&migration.id))
            .collect();
        if pending.is_empty() {
            return Ok(MigrationReport::default());
        }
        let batch = self.next_batch().await?;
        let mut report = MigrationReport {
            batch: Some(batch),
            applied: Vec::with_capacity(pending.len()),
            rolled_back: Vec::new(),
        };
        for migration in pending {
            self.mark_running(&migration, batch).await?;
            match self.apply(&migration, &migration.up).await {
                Ok(()) => {
                    self.mark_applied(migration.id).await?;
                    report.applied.push(migration.id);
                }
                Err(error) => {
                    self.mark_failed(migration.id).await?;
                    return Err(error);
                }
            }
        }
        Ok(report)
    }

    /// Rolls back every applied migration in the highest batch, highest id
    /// first.
    ///
    /// Returns an empty report when nothing is applied.
    pub async fn rollback_last_batch(&self) -> Result<MigrationReport, MigrationError> {
        self.ensure_table().await?;
        let Some(batch) = self.last_batch().await? else {
            return Ok(MigrationReport::default());
        };
        let ids = self.ids_for_batch(batch).await?;
        self.rollback_ids(ids, Some(batch)).await
    }

    /// Rolls back the `steps` most recently applied migrations, ordered by
    /// batch then id, both descending. The selection may span several batches.
    ///
    /// `steps == 0` is a no-op. The report's `batch` is always `None`.
    pub async fn rollback_steps(&self, steps: usize) -> Result<MigrationReport, MigrationError> {
        self.ensure_table().await?;
        if steps == 0 {
            return Ok(MigrationReport::default());
        }
        let ids = self.latest_ids(steps).await?;
        self.rollback_ids(ids, None).await
    }

    /// Loads the migration files from the configured directory.
    fn load_migrations(&self) -> Result<Vec<MigrationFile>, MigrationError> {
        load_migrations(&self.migration_dir)
    }

    /// Runs the `down` operations of each migration in `ids`, in the given
    /// order. Each migration's bookkeeping row is deleted after its rollback
    /// commits.
    ///
    /// # Errors
    ///
    /// [`MigrationError::MissingMigration`] when an id has no file on disk. Any
    /// error from a `down` step stops the rollback at that migration.
    async fn rollback_ids(
        &self,
        ids: Vec<u64>,
        batch: Option<u64>,
    ) -> Result<MigrationReport, MigrationError> {
        if ids.is_empty() {
            return Ok(MigrationReport::default());
        }
        let migrations = self
            .load_migrations()?
            .into_iter()
            .map(|migration| (migration.id, migration))
            .collect::<HashMap<_, _>>();
        let mut report = MigrationReport {
            batch,
            applied: Vec::new(),
            rolled_back: Vec::with_capacity(ids.len()),
        };
        for id in ids {
            let migration = migrations
                .get(&id)
                .ok_or(MigrationError::MissingMigration(id))?;
            self.apply(migration, &migration.down).await?;
            self.delete_applied(id).await?;
            report.rolled_back.push(id);
        }
        Ok(report)
    }

    /// Executes `ops` for `migration` inside one pool transaction.
    ///
    /// Commits on success and rolls back on failure.
    async fn apply(
        &self,
        migration: &MigrationFile,
        ops: &[MigrationOp],
    ) -> Result<(), MigrationError> {
        let mut tx = self.pool.pool().begin().await?;
        // Scope the context so its borrow of `tx` ends before commit/rollback.
        let result = {
            let mut ctx = self.new_transaction_context(&mut tx);
            apply_operations(&mut ctx, &migration.path, ops).await
        };
        match result {
            Ok(()) => tx.commit().await.map_err(Into::into),
            Err(error) => {
                // A rollback failure is ignored so the operation's own error is
                // what the caller sees.
                let _ = tx.rollback().await;
                Err(error)
            }
        }
    }

    /// Wraps `tx` in the backend-specific migration context for this pool.
    fn new_transaction_context<'a>(
        &self,
        tx: &'a mut quex::PoolTransaction,
    ) -> MigrationContext<'a> {
        match &self.pool {
            #[cfg(feature = "sqlite")]
            MigrationPool::Sqlite(_) => {
                MigrationContext::Sqlite(crate::SqliteMigrationContext::from_transaction(tx))
            }
            #[cfg(feature = "postgres")]
            MigrationPool::Postgres(_) => {
                MigrationContext::Postgres(crate::PostgresMigrationContext::from_transaction(tx))
            }
            #[cfg(feature = "mariadb")]
            MigrationPool::Mariadb(_) => {
                MigrationContext::Mariadb(crate::MariadbMigrationContext::from_transaction(tx))
            }
            #[cfg(not(any(feature = "sqlite", feature = "postgres", feature = "mariadb")))]
            MigrationPool::Disabled => MigrationContext::Disabled(std::marker::PhantomData),
        }
    }

    /// Creates the `migrations` bookkeeping table unless it already exists.
    async fn ensure_table(&self) -> Result<(), MigrationError> {
        // Probe the catalog first; `create table` is only issued when the
        // table is absent.
        let exists_sql = match &self.pool {
            #[cfg(feature = "sqlite")]
            MigrationPool::Sqlite(_) => {
                "select 1 as value from sqlite_master where type = 'table' and name = 'migrations' limit 1"
            }
            #[cfg(feature = "postgres")]
            MigrationPool::Postgres(_) => {
                "select 1 as value from pg_catalog.pg_class c join pg_catalog.pg_namespace n on n.oid = c.relnamespace where c.relkind = 'r' and c.relname = 'migrations' and n.nspname = current_schema() limit 1"
            }
            #[cfg(feature = "mariadb")]
            MigrationPool::Mariadb(_) => {
                "select 1 as value from information_schema.tables where table_schema = database() and table_name = 'migrations' limit 1"
            }
            #[cfg(not(any(feature = "sqlite", feature = "postgres", feature = "mariadb")))]
            MigrationPool::Disabled => return Err(MigrationError::BackendNotEnabled("no backend")),
        };
        let exists = quex::query(exists_sql)
            .optional::<ExistsValue>(self.pool.pool())
            .await?
            .map(|row| row.value != 0)
            .unwrap_or(false);
        if exists {
            return Ok(());
        }
        let sql = match &self.pool {
            #[cfg(feature = "sqlite")]
            MigrationPool::Sqlite(_) => {
                "create table if not exists migrations (id integer primary key not null, name text not null, checksum text not null, batch integer not null, status text not null, started_at text not null default current_timestamp, finished_at text null)"
            }
            #[cfg(feature = "postgres")]
            MigrationPool::Postgres(_) => {
                "create table if not exists migrations (id bigint primary key not null, name text not null, checksum text not null, batch bigint not null, status text not null, started_at timestamptz not null default current_timestamp, finished_at timestamptz null)"
            }
            #[cfg(feature = "mariadb")]
            MigrationPool::Mariadb(_) => {
                "create table if not exists migrations (id bigint primary key not null, name varchar(255) not null, checksum varchar(255) not null, batch bigint not null, status varchar(32) not null, started_at timestamp not null default current_timestamp, finished_at timestamp null)"
            }
            #[cfg(not(any(feature = "sqlite", feature = "postgres", feature = "mariadb")))]
            MigrationPool::Disabled => return Err(MigrationError::BackendNotEnabled("no backend")),
        };
        quex::query(sql).execute(self.pool.pool()).await?;
        Ok(())
    }

    /// Returns every bookkeeping row, whatever its status, keyed by migration
    /// id.
    async fn applied_migrations(
        &self,
    ) -> Result<HashMap<u64, AppliedMigrationRow>, MigrationError> {
        let rows = quex::query("select id, checksum, status from migrations order by id")
            .all::<AppliedMigrationRow>(self.pool.pool())
            .await?;
        Ok(rows
            .into_iter()
            .map(|row| (row.id as u64, row))
            .collect::<HashMap<_, _>>())
    }

    /// Batch number for a new run: one above the highest applied batch,
    /// starting at 1.
    async fn next_batch(&self) -> Result<u64, MigrationError> {
        Ok(self.max_batch().await?.unwrap_or(0) + 1)
    }

    /// Highest batch that still has applied migrations, if any.
    async fn last_batch(&self) -> Result<Option<u64>, MigrationError> {
        self.max_batch().await
    }

    /// `max(batch)` over rows whose status is `applied`; `None` when there are
    /// none.
    async fn max_batch(&self) -> Result<Option<u64>, MigrationError> {
        let row =
            quex::query("select max(batch) as batch from migrations where status = 'applied'")
                .one::<BatchValue>(self.pool.pool())
                .await?;
        Ok(row.batch.map(|value| value as u64))
    }

    /// Ids of the applied migrations in `batch`, highest id first.
    async fn ids_for_batch(&self, batch: u64) -> Result<Vec<u64>, MigrationError> {
        let ids = quex::query(
            "select id from migrations where batch = ? and status = 'applied' order by id desc",
        )
        .bind(batch as i64)
        .all::<AppliedMigrationId>(self.pool.pool())
        .await?;
        Ok(ids.into_iter().map(|value| value.id as u64).collect())
    }

    /// Ids of the `steps` most recently applied migrations, newest first.
    async fn latest_ids(&self, steps: usize) -> Result<Vec<u64>, MigrationError> {
        // `steps as i64` would wrap huge counts (e.g. `usize::MAX`, meaning
        // "everything") to a negative LIMIT, which PostgreSQL rejects.
        // Saturate instead.
        let limit = i64::try_from(steps).unwrap_or(i64::MAX);
        let ids = quex::query("select id from migrations where status = 'applied' order by batch desc, id desc limit ?")
            .bind(limit)
            .all::<AppliedMigrationId>(self.pool.pool())
            .await?;
        Ok(ids.into_iter().map(|value| value.id as u64).collect())
    }

    /// Records `migration` as `running` in `batch`, replacing any existing row
    /// for its id.
    async fn mark_running(
        &self,
        migration: &MigrationFile,
        batch: u64,
    ) -> Result<(), MigrationError> {
        quex::query("delete from migrations where id = ?")
            .bind(migration.id as i64)
            .execute(self.pool.pool())
            .await?;
        quex::query(
            "insert into migrations(id, name, checksum, batch, status) values(?, ?, ?, ?, ?)",
        )
        .bind(migration.id as i64)
        .bind(&migration.name)
        .bind(&migration.checksum)
        .bind(batch as i64)
        .bind("running")
        .execute(self.pool.pool())
        .await?;
        Ok(())
    }

    /// Marks migration `id` as `applied` and stamps `finished_at`.
    async fn mark_applied(&self, id: u64) -> Result<(), MigrationError> {
        quex::query(
            "update migrations set status = ?, finished_at = current_timestamp where id = ?",
        )
        .bind("applied")
        .bind(id as i64)
        .execute(self.pool.pool())
        .await?;
        Ok(())
    }

    /// Marks migration `id` as `failed` and stamps `finished_at`.
    async fn mark_failed(&self, id: u64) -> Result<(), MigrationError> {
        quex::query(
            "update migrations set status = ?, finished_at = current_timestamp where id = ?",
        )
        .bind("failed")
        .bind(id as i64)
        .execute(self.pool.pool())
        .await?;
        Ok(())
    }

    /// Deletes the bookkeeping row for migration `id`, whatever its status.
    ///
    /// This is also used to clear `failed` rows before a retry.
    async fn delete_applied(&self, id: u64) -> Result<(), MigrationError> {
        quex::query("delete from migrations where id = ?")
            .bind(id as i64)
            .execute(self.pool.pool())
            .await?;
        Ok(())
    }
}
trait ColumnAttrs: Sized {
fn nullable(self) -> Self;
fn unique(self) -> Self;
fn default_raw(self, value: &str) -> Self;
}
impl<'a> ColumnAttrs for ColumnBuilder<'a> {
fn nullable(self) -> Self {
Self::nullable(self)
}
fn unique(self) -> Self {
Self::unique(self)
}
fn default_raw(self, value: &str) -> Self {
Self::default_raw(self, value)
}
}
impl<'a> ColumnAttrs for AlterColumnBuilder<'a> {
fn nullable(self) -> Self {
Self::nullable(self)
}
fn unique(self) -> Self {
Self::unique(self)
}
fn default_raw(self, value: &str) -> Self {
Self::default_raw(self, value)
}
}
impl<'a> ColumnAttrs for crate::ForeignKeyBuilder<'a> {
fn nullable(self) -> Self {
Self::nullable(self)
}
fn unique(self) -> Self {
Self::unique(self)
}
fn default_raw(self, value: &str) -> Self {
Self::default_raw(self, value)
}
}
/// Applies the nullable, unique and default settings declared on `field` to a
/// column builder, in that order, and returns the updated builder.
fn apply_field_attrs<B>(builder: B, field: &Field) -> B
where
    B: ColumnAttrs,
{
    let builder = if matches!(field.nullable, Nullability::Nullable) {
        builder.nullable()
    } else {
        builder
    };
    let builder = if field.unique {
        builder.unique()
    } else {
        builder
    };
    match &field.default {
        Some(default) => builder.default_raw(default),
        None => builder,
    }
}
/// Executes `ops` against `ctx` in order, stopping at the first error.
///
/// `path` is the migration file the operations came from. It is only forwarded
/// to reference-column resolution, which uses it in error reports.
async fn apply_operations(
    ctx: &mut MigrationContext<'_>,
    path: &std::path::Path,
    ops: &[MigrationOp],
) -> Result<(), MigrationError> {
    for op in ops {
        match op {
            MigrationOp::Sql { sql } => {
                ctx.execute_raw(sql).await?;
            }
            // Executed exactly like `Sql`.
            MigrationOp::Backfill { sql } => {
                ctx.execute_raw(sql).await?;
            }
            MigrationOp::CreateTable { name, items } => {
                // Reference columns with an implicit type get their type looked
                // up from the referenced column before the blueprint is built.
                let items = resolve_table_items(ctx, path, items).await?;
                ctx.create(name, |table| {
                    for item in &items {
                        match item {
                            TableItem::Column(field) => apply_create_field(table, field),
                            // Unnamed indexes fall back to `<table>_<cols>_idx`.
                            TableItem::Index {
                                name: index_name,
                                columns,
                            } => table.index(
                                index_name
                                    .as_deref()
                                    .unwrap_or(&default_index_name(name, columns)),
                                columns.clone(),
                            ),
                            TableItem::Unique(columns) => table.unique(columns.clone()),
                            // `name` here is the constraint name; it shadows the
                            // table name inside this arm.
                            TableItem::ConstraintUnique { name, columns } => {
                                table.unique_named(name, columns.clone())
                            }
                            TableItem::Primary(columns) => table.primary(columns.clone()),
                            TableItem::Timestamps => table.timestamps(),
                            // Soft deletes become a nullable `deleted_at` timestamp.
                            TableItem::SoftDeletes => {
                                table.timestamp("deleted_at").nullable();
                            }
                        }
                    }
                })
                .await?;
            }
            MigrationOp::AlterTable { name, actions } => {
                let actions = resolve_alter_actions(ctx, path, actions).await?;
                // Phase 1: drop indexes before the table itself is altered.
                // Presumably this is so indexes on columns dropped in phase 2
                // are already gone. NOTE(review): confirm intent.
                for action in &actions {
                    if let AlterAction::DropIndex(index_name) = action {
                        let index = crate::IndexBlueprint::named(index_name);
                        ctx.execute_raw_blueprint(&index.drop_sql(ctx.dialect()))
                            .await?;
                    }
                }
                // Phase 2: changes expressed through the alter-table blueprint,
                // i.e. adding non-reference columns and dropping columns.
                ctx.alter_table(name, |table| {
                    for action in &actions {
                        match action {
                            AlterAction::AddColumn(field) if field.reference.is_none() => {
                                apply_alter_field(table, field)
                            }
                            // `name` is the column name here, shadowing the table name.
                            AlterAction::DropColumn(name) => table.drop_column(name),
                            AlterAction::RenameColumn { .. } => {}
                            AlterAction::AddIndex { .. } => {}
                            AlterAction::DropIndex(_) | AlterAction::AddColumn(_) => {}
                        }
                    }
                })
                .await?;
                // Phase 3: standalone statements issued after the blueprint:
                // column renames, index creation and reference columns.
                for action in &actions {
                    match action {
                        AlterAction::RenameColumn { from, to } => {
                            let sql = format!(
                                "alter table {} rename column {} to {};",
                                ctx.dialect().quote_ident(name),
                                ctx.dialect().quote_ident(from),
                                ctx.dialect().quote_ident(to)
                            );
                            ctx.execute_raw_blueprint(&sql).await?;
                        }
                        // Unnamed indexes fall back to `<table>_<cols>_idx`.
                        AlterAction::AddIndex {
                            name: index_name,
                            columns,
                        } => {
                            let index = crate::IndexBlueprint::new(
                                index_name
                                    .as_deref()
                                    .unwrap_or(&default_index_name(name, columns)),
                                name,
                                columns.clone(),
                            );
                            ctx.execute_raw_blueprint(&index.create_sql(ctx.dialect()))
                                .await?;
                        }
                        AlterAction::AddColumn(field) if field.reference.is_some() => {
                            add_relation_column(ctx, name, field).await?;
                        }
                        AlterAction::DropIndex(_) => {}
                        AlterAction::AddColumn(_) | AlterAction::DropColumn(_) => {}
                    }
                }
            }
            MigrationOp::DropTable { name } => {
                ctx.drop(name).await?;
            }
            MigrationOp::RenameTable { from, to } => {
                let sql = format!(
                    "alter table {} rename to {};",
                    ctx.dialect().quote_ident(from),
                    ctx.dialect().quote_ident(to)
                );
                ctx.execute_raw_blueprint(&sql).await?;
            }
            // The enum helpers only emit SQL for the PostgreSQL dialect and are
            // no-ops for other dialects.
            MigrationOp::CreateEnum { name, values } => {
                apply_create_enum(ctx, name, values).await?;
            }
            MigrationOp::AlterEnum { name, actions } => {
                apply_alter_enum(ctx, name, actions).await?;
            }
            MigrationOp::DropEnum { name } => {
                apply_drop_enum(ctx, name).await?;
            }
            MigrationOp::RenameEnum { from, to } => {
                apply_rename_enum(ctx, from, to).await?;
            }
        }
    }
    Ok(())
}
/// Builds the fallback index name `<table>_<col1>_<col2>..._idx`, used when a
/// migration does not name its index.
///
/// An empty column list yields `<table>__idx`.
fn default_index_name(table: &str, columns: &[String]) -> String {
    let joined = columns.join("_");
    let mut index_name = String::with_capacity(table.len() + joined.len() + 5);
    index_name.push_str(table);
    index_name.push('_');
    index_name.push_str(&joined);
    index_name.push_str("_idx");
    index_name
}
/// Returns a copy of `items` in which every reference column has been passed
/// through `resolve_reference_field`. All other items are cloned unchanged.
async fn resolve_table_items(
    ctx: &mut MigrationContext<'_>,
    path: &std::path::Path,
    items: &[TableItem],
) -> Result<Vec<TableItem>, MigrationError> {
    let mut resolved = Vec::with_capacity(items.len());
    for item in items {
        let next = match item {
            TableItem::Column(field) if field.reference.is_some() => {
                TableItem::Column(resolve_reference_field(ctx, path, field).await?)
            }
            other => other.clone(),
        };
        resolved.push(next);
    }
    Ok(resolved)
}
/// Returns a copy of `actions` in which every `AddColumn` carrying a reference
/// has been passed through `resolve_reference_field`. All other actions are
/// cloned unchanged.
async fn resolve_alter_actions(
    ctx: &mut MigrationContext<'_>,
    path: &std::path::Path,
    actions: &[AlterAction],
) -> Result<Vec<AlterAction>, MigrationError> {
    let mut resolved = Vec::with_capacity(actions.len());
    for action in actions {
        let next = match action {
            AlterAction::AddColumn(field) if field.reference.is_some() => {
                AlterAction::AddColumn(resolve_reference_field(ctx, path, field).await?)
            }
            other => other.clone(),
        };
        resolved.push(next);
    }
    Ok(resolved)
}
/// Returns a copy of the reference `field`. When its type is `Implicit`, the
/// type is taken from the referenced column, which defaults to `id`; the
/// column's type is read through `ctx.column_type`.
///
/// # Errors
///
/// `InvalidMigrationFile` when `field` has no reference metadata. Errors from
/// the column-type lookup are propagated.
async fn resolve_reference_field(
    ctx: &mut MigrationContext<'_>,
    path: &std::path::Path,
    field: &Field,
) -> Result<Field, MigrationError> {
    let Some(reference) = field.reference.as_ref() else {
        return Err(MigrationError::InvalidMigrationFile {
            path: path.to_path_buf(),
            message: format!("missing reference metadata for `{}`", field.name),
        });
    };
    let mut resolved = field.clone();
    if matches!(field.ty, FieldType::Implicit) {
        let target_column = reference.column.as_deref().unwrap_or("id");
        let schema_type = ctx.column_type(&reference.table, target_column).await?;
        resolved.ty = map_schema_type_to_field_type(schema_type, path, field)?;
    }
    Ok(resolved)
}
fn map_schema_type_to_field_type(
ty: crate::ColumnType,
_path: &std::path::Path,
_field: &Field,
) -> Result<FieldType, MigrationError> {
Ok(match ty {
crate::ColumnType::Integer => FieldType::Integer,
crate::ColumnType::BigInt => FieldType::BigInt,
crate::ColumnType::Bool => FieldType::Boolean,
crate::ColumnType::Char(len) => FieldType::Varchar(len),
crate::ColumnType::Varchar(len) => FieldType::Varchar(len),
crate::ColumnType::Text => FieldType::Text,
crate::ColumnType::Date => FieldType::Date,
crate::ColumnType::Time => FieldType::Time,
crate::ColumnType::DateTime => FieldType::DateTime,
crate::ColumnType::Timestamp => FieldType::Timestamp,
crate::ColumnType::Decimal(precision, scale) => FieldType::Decimal(precision, scale),
crate::ColumnType::Float => FieldType::Float,
crate::ColumnType::Double => FieldType::Double,
crate::ColumnType::Json => FieldType::Json,
crate::ColumnType::Uuid => FieldType::Uuid,
crate::ColumnType::Custom(name) => FieldType::Custom(name),
})
}
fn apply_create_field(table: &mut crate::TableBlueprint, field: &Field) {
if let Some(reference) = &field.reference {
apply_reference_field(table, field, reference);
return;
}
match &field.ty {
FieldType::Implicit if field.primary && field.name == "id" => table.id(),
FieldType::Id => table.id(),
FieldType::String => {
let _ = apply_field_attrs(table.string(&field.name), field);
}
FieldType::Varchar(length) => {
let _ = apply_field_attrs(table.varchar(&field.name, *length), field);
}
FieldType::Text => {
let _ = apply_field_attrs(table.text(&field.name), field);
}
FieldType::Integer => {
let _ = apply_field_attrs(table.integer(&field.name), field);
}
FieldType::BigInt => {
let _ = apply_field_attrs(table.bigint(&field.name), field);
}
FieldType::Boolean => {
let _ = apply_field_attrs(table.boolean(&field.name), field);
}
FieldType::Date => {
let _ = apply_field_attrs(table.date(&field.name), field);
}
FieldType::Time => {
let _ = apply_field_attrs(table.time(&field.name), field);
}
FieldType::DateTime => {
let _ = apply_field_attrs(table.datetime(&field.name), field);
}
FieldType::Timestamp | FieldType::TimestampTz => {
let _ = apply_field_attrs(table.timestamp(&field.name), field);
}
FieldType::Decimal(precision, scale) => {
let _ = apply_field_attrs(table.decimal(&field.name, *precision, *scale), field);
}
FieldType::Float => {
let _ = apply_field_attrs(table.float(&field.name), field);
}
FieldType::Double => {
let _ = apply_field_attrs(table.double(&field.name), field);
}
FieldType::Json => {
let _ = apply_field_attrs(table.json(&field.name), field);
}
FieldType::Uuid => {
let _ = apply_field_attrs(table.uuid(&field.name), field);
}
FieldType::RememberToken => {
if field.name == "remember_token" {
let _ = apply_field_attrs(table.remember_token(), field);
} else {
let _ = apply_field_attrs(table.string(&field.name), field);
}
}
FieldType::Custom(name) => {
let _ = apply_field_attrs(
table.custom(&field.name, crate::ColumnType::Custom(name.clone())),
field,
);
}
FieldType::Implicit => {}
}
}
fn apply_alter_field(table: &mut crate::AlterTableBlueprint, field: &Field) {
match &field.ty {
FieldType::Implicit | FieldType::Id => {}
FieldType::String => {
let _ = apply_field_attrs(table.string(&field.name), field);
}
FieldType::Varchar(length) => {
let _ = apply_field_attrs(table.varchar(&field.name, *length), field);
}
FieldType::Text => {
let _ = apply_field_attrs(table.text(&field.name), field);
}
FieldType::Integer => {
let _ = apply_field_attrs(table.integer(&field.name), field);
}
FieldType::BigInt => {
let _ = apply_field_attrs(table.bigint(&field.name), field);
}
FieldType::Boolean => {
let _ = apply_field_attrs(table.boolean(&field.name), field);
}
FieldType::Date => {
let _ = apply_field_attrs(table.date(&field.name), field);
}
FieldType::Time => {
let _ = apply_field_attrs(table.time(&field.name), field);
}
FieldType::DateTime => {
let _ = apply_field_attrs(table.datetime(&field.name), field);
}
FieldType::Timestamp | FieldType::TimestampTz => {
let _ = apply_field_attrs(table.timestamp(&field.name), field);
}
FieldType::Decimal(precision, scale) => {
let _ = apply_field_attrs(table.decimal(&field.name, *precision, *scale), field);
}
FieldType::Float => {
let _ = apply_field_attrs(table.float(&field.name), field);
}
FieldType::Double => {
let _ = apply_field_attrs(table.double(&field.name), field);
}
FieldType::Json => {
let _ = apply_field_attrs(table.json(&field.name), field);
}
FieldType::Uuid => {
let _ = apply_field_attrs(table.uuid(&field.name), field);
}
FieldType::RememberToken => {
let _ = apply_field_attrs(table.string(&field.name), field);
}
FieldType::Custom(name) => {
let _ = apply_field_attrs(
table.custom(&field.name, crate::ColumnType::Custom(name.clone())),
field,
);
}
}
}
/// Adds a foreign-key column for `field` to a `create table` blueprint.
///
/// The builder calls are made in this order:
/// 1. field attributes;
/// 2. the target (`reference.column`, or the builder's default target when
///    `None`);
/// 3. an optional index;
/// 4. the `on delete` action, which defaults to restrict;
/// 5. the `on update` action, which is only set when declared.
fn apply_reference_field(table: &mut crate::TableBlueprint, field: &Field, reference: &Reference) {
    let builder = apply_field_attrs(
        table.foreign(&field.name, relation_column_type(field)),
        field,
    );
    let builder = match &reference.column {
        Some(column) => builder.references_column(&reference.table, column),
        None => builder.references(&reference.table),
    };
    let builder = if field.index { builder.index() } else { builder };
    let builder = match reference.on_delete.unwrap_or(ReferenceAction::Restrict) {
        ReferenceAction::Cascade => builder.cascade_on_delete(),
        ReferenceAction::Restrict => builder.restrict_on_delete(),
        ReferenceAction::SetNull => builder.null_on_delete(),
        ReferenceAction::NoAction => builder.no_action_on_delete(),
    };
    let builder = match reference.on_update {
        Some(ReferenceAction::Cascade) => builder.cascade_on_update(),
        Some(ReferenceAction::Restrict) => builder.restrict_on_update(),
        Some(ReferenceAction::SetNull) => builder.null_on_update(),
        Some(ReferenceAction::NoAction) => builder.no_action_on_update(),
        None => builder,
    };
    let _ = builder;
}
fn relation_column_type(field: &Field) -> crate::ColumnType {
match &field.ty {
FieldType::String => crate::ColumnType::Varchar(255),
FieldType::Text => crate::ColumnType::Text,
FieldType::Integer => crate::ColumnType::Integer,
FieldType::BigInt => crate::ColumnType::BigInt,
FieldType::Boolean => crate::ColumnType::Bool,
FieldType::Date => crate::ColumnType::Date,
FieldType::Time => crate::ColumnType::Time,
FieldType::DateTime => crate::ColumnType::DateTime,
FieldType::Timestamp | FieldType::TimestampTz => crate::ColumnType::Timestamp,
FieldType::Decimal(precision, scale) => crate::ColumnType::Decimal(*precision, *scale),
FieldType::Float => crate::ColumnType::Float,
FieldType::Double => crate::ColumnType::Double,
FieldType::Json => crate::ColumnType::Json,
FieldType::Uuid => crate::ColumnType::Uuid,
FieldType::Varchar(length) => crate::ColumnType::Varchar(*length),
FieldType::Id => crate::ColumnType::BigInt,
FieldType::RememberToken => crate::ColumnType::Varchar(100),
FieldType::Custom(name) => crate::ColumnType::Custom(name.clone()),
FieldType::Implicit => crate::ColumnType::BigInt,
}
}
/// Adds a reference column to an existing table with a standalone
/// `alter table ... add column` statement carrying an inline `references`
/// clause.
///
/// When `field.index` is set, it also creates `<table>_<column>_idx`. Does
/// nothing when `field` has no reference.
async fn add_relation_column(
    ctx: &mut MigrationContext<'_>,
    table_name: &str,
    field: &Field,
) -> Result<(), MigrationError> {
    if let Some(reference) = &field.reference {
        let dialect = ctx.dialect();
        let alter_sql = format!(
            "alter table {} add column {};",
            dialect.quote_ident(table_name),
            render_inline_relation_column(dialect, field, reference)
        );
        ctx.execute_raw_blueprint(&alter_sql).await?;
        if field.index {
            let index_name = default_index_name(table_name, std::slice::from_ref(&field.name));
            let index = crate::IndexBlueprint::new(&index_name, table_name, [field.name.as_str()]);
            ctx.execute_raw_blueprint(&index.create_sql(ctx.dialect()))
                .await?;
        }
    }
    Ok(())
}
fn render_inline_relation_column(
dialect: crate::SchemaDialect,
field: &Field,
reference: &Reference,
) -> String {
let column = crate::schema::blueprint::ColumnDef {
name: field.name.clone(),
ty: relation_column_type(field),
nullable: matches!(field.nullable, Nullability::Nullable),
primary_key: false,
auto_increment: false,
unique: field.unique,
default_raw: field.default.clone(),
};
let mut out = crate::schema::render::render_column(dialect, &column);
out.push_str(" references ");
out.push_str(&dialect.quote_ident(&reference.table));
out.push('(');
out.push_str(&dialect.quote_ident(reference.column.as_deref().unwrap_or("id")));
out.push(')');
match reference.on_delete.unwrap_or(ReferenceAction::Restrict) {
ReferenceAction::Cascade => out.push_str(" on delete cascade"),
ReferenceAction::Restrict => out.push_str(" on delete restrict"),
ReferenceAction::SetNull => out.push_str(" on delete set null"),
ReferenceAction::NoAction => out.push_str(" on delete no action"),
}
if let Some(action) = reference.on_update {
match action {
ReferenceAction::Cascade => out.push_str(" on update cascade"),
ReferenceAction::Restrict => out.push_str(" on update restrict"),
ReferenceAction::SetNull => out.push_str(" on update set null"),
ReferenceAction::NoAction => out.push_str(" on update no action"),
}
}
out
}
/// Creates the enum type `name` with the given `values` (single quotes
/// escaped by doubling).
///
/// Only the PostgreSQL dialect emits SQL; every other dialect treats this as
/// a no-op.
async fn apply_create_enum(
    ctx: &mut MigrationContext<'_>,
    name: &str,
    values: &[String],
) -> Result<(), MigrationError> {
    if !matches!(ctx.dialect(), crate::SchemaDialect::Postgres) {
        return Ok(());
    }
    let mut quoted = Vec::with_capacity(values.len());
    for value in values {
        quoted.push(format!("'{}'", value.replace('\'', "''")));
    }
    let values = quoted.join(", ");
    let sql = format!(
        "create type {} as enum ({values});",
        ctx.dialect().quote_ident(name)
    );
    ctx.execute_raw_blueprint(&sql).await?;
    Ok(())
}
/// Applies `actions` to the enum type `name`, one statement per action.
///
/// Only the PostgreSQL dialect emits SQL; every other dialect treats this as
/// a no-op.
async fn apply_alter_enum(
    ctx: &mut MigrationContext<'_>,
    name: &str,
    actions: &[AlterEnumAction],
) -> Result<(), MigrationError> {
    if !matches!(ctx.dialect(), crate::SchemaDialect::Postgres) {
        return Ok(());
    }
    for action in actions {
        let sql = match action {
            AlterEnumAction::AddValue(value) => format!(
                "alter type {} add value if not exists '{}';",
                ctx.dialect().quote_ident(name),
                value.replace('\'', "''")
            ),
        };
        ctx.execute_raw_blueprint(&sql).await?;
    }
    Ok(())
}
/// Drops the enum type `name` if it exists.
///
/// Only the PostgreSQL dialect emits SQL; every other dialect treats this as
/// a no-op.
async fn apply_drop_enum(ctx: &mut MigrationContext<'_>, name: &str) -> Result<(), MigrationError> {
    let dialect = ctx.dialect();
    if !matches!(dialect, crate::SchemaDialect::Postgres) {
        return Ok(());
    }
    let statement = format!("drop type if exists {};", dialect.quote_ident(name));
    ctx.execute_raw_blueprint(&statement).await?;
    Ok(())
}
/// Renames the enum type `from` to `to`.
///
/// Only the PostgreSQL dialect emits SQL; every other dialect treats this as
/// a no-op.
async fn apply_rename_enum(
    ctx: &mut MigrationContext<'_>,
    from: &str,
    to: &str,
) -> Result<(), MigrationError> {
    if !matches!(ctx.dialect(), crate::SchemaDialect::Postgres) {
        return Ok(());
    }
    let old_name = ctx.dialect().quote_ident(from);
    let new_name = ctx.dialect().quote_ident(to);
    let statement = format!("alter type {} rename to {};", old_name, new_name);
    ctx.execute_raw_blueprint(&statement).await?;
    Ok(())
}
impl From<quex::Pool> for Migrator {
fn from(pool: quex::Pool) -> Self {
#[allow(unreachable_patterns)]
let pool = match pool.driver() {
#[cfg(feature = "sqlite")]
quex::Driver::Sqlite => MigrationPool::Sqlite(pool),
#[cfg(feature = "postgres")]
quex::Driver::Pgsql => MigrationPool::Postgres(pool),
#[cfg(feature = "mariadb")]
quex::Driver::Mysql => MigrationPool::Mariadb(pool),
other => panic!("unsupported pool driver for migrations: {other:?}"),
};
Self {
pool,
migration_dir: default_migration_dir(),
}
}
}
impl From<&quex::Pool> for Migrator {
    /// Clones the pool and converts the clone via `From<quex::Pool>`.
    ///
    /// # Panics
    ///
    /// Panics when the pool's driver has no enabled backend feature.
    fn from(pool: &quex::Pool) -> Self {
        Self::from(quex::Pool::clone(pool))
    }
}