use std::collections::{HashMap, HashSet};
use std::path::{Path, PathBuf};
use crate::error::OrmError;
use crate::executor::Executor;
use crate::registry::{registered_models, TableSchema};
use super::ddl::{
AlterAction, AlterTable, ColumnSpec, DefaultValue, ForeignKeySpec, TableDef,
};
use super::introspect::ExistingColumn;
use super::{files, introspect, render};
/// The result of diffing model schemas against a database.
///
/// Each entry of `up` / `down` is either a SQL statement or a `--` comment
/// line (a note for whoever reviews the migration).
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct SchemaChange {
    /// Entries that move the database toward the models.
    pub up: Vec<String>,
    /// Entries that revert `up`.
    pub down: Vec<String>,
}

impl SchemaChange {
    /// Returns `true` when `up` has no executable SQL, i.e. it is empty or
    /// every entry starts (after leading whitespace) with `--`.
    pub fn is_empty(&self) -> bool {
        self.up
            .iter()
            .all(|entry| entry.trim_start().starts_with("--"))
    }
}
/// Compute the SQL needed to bring the database behind `executor` in line with
/// `models`.
///
/// For every model table:
/// - a table missing from the database is created in full, and `down` drops it;
/// - indexes present in the database but absent from the model are dropped
///   (their definition is unknown, so `down` only gets a comment);
/// - columns present in the database but absent from the model are dropped and
///   re-added by `down`, except primary-key columns, which only get a `-- NOTE`;
/// - type or nullability differences on existing non-primary-key columns are
///   reported as `-- NOTE` comments and are not applied;
/// - columns present in the model but absent from the database are added as
///   plain, nullable columns, with a `-- NOTE` for every model constraint
///   (NOT NULL, primary key, foreign key) that was not applied;
/// - model indexes missing from the database are created.
///
/// `down` is reversed before returning so that it undoes `up` in reverse order.
///
/// # Errors
///
/// Propagates introspection errors and errors from rendering `CREATE TABLE`
/// and `CREATE INDEX` statements.
pub async fn generate<E: Executor + Sync>(
    executor: &E,
    models: &[TableSchema],
) -> crate::Result<SchemaChange> {
    let dialect = executor.dialect();
    let mut up: Vec<String> = Vec::new();
    let mut down: Vec<String> = Vec::new();
    for schema in models {
        if !introspect::table_exists(executor, schema.table).await? {
            let def = table_def_from_schema(schema);
            up.extend(render::create_table(dialect, &def)?);
            down.push(render::drop_table(dialect, schema.table, true));
            continue;
        }
        let existing_indexes = introspect::existing_indexes(executor, schema.table).await?;
        let existing_index_names: HashSet<&str> =
            existing_indexes.iter().map(|i| i.name.as_str()).collect();
        let model_index_names: HashSet<&str> =
            schema.indexes.iter().map(|i| i.name.as_str()).collect();
        let existing_cols = introspect::existing_columns(executor, schema.table).await?;
        let existing_col_map: HashMap<&str, &ExistingColumn> =
            existing_cols.iter().map(|c| (c.name.as_str(), c)).collect();
        let model_col_names: HashSet<&str> = schema.columns.iter().map(|c| c.name).collect();

        // Stale indexes are dropped before any column is dropped, presumably so
        // that columns they cover can be removed afterwards.
        for index in &existing_indexes {
            if !model_index_names.contains(index.name.as_str()) {
                up.push(render::drop_index(dialect, &index.name, true));
                down.push(format!(
                    "-- cannot recreate dropped index \"{}\" (its definition is unknown)",
                    index.name
                ));
            }
        }

        // Columns that exist in the database but were removed from the model.
        for col in &existing_cols {
            if model_col_names.contains(col.name.as_str()) {
                continue;
            }
            if col.is_pk {
                up.push(format!(
                    "-- NOTE: column \"{}\" was removed from the model but cannot be \
                     dropped automatically (primary key); rebuild the table manually",
                    col.name
                ));
            } else {
                let alter = AlterTable {
                    table: schema.table.to_string(),
                    actions: vec![AlterAction::DropColumn(col.name.clone())],
                };
                up.extend(render::alter_table(dialect, &alter));
                down.push(restore_column_sql(dialect, schema.table, col));
            }
        }

        // Columns on both sides: differences are only reported, never applied.
        // Primary-key columns are skipped entirely.
        for model_col in &schema.columns {
            let existing_col = match existing_col_map.get(model_col.name) {
                Some(existing_col) if !existing_col.is_pk => existing_col,
                _ => continue,
            };
            let model_type = render::column_type_str(dialect, model_col.sql_type);
            let db_type = existing_col.declared_type.trim().to_uppercase();
            // An empty declared type carries no information, so it never counts
            // as a type change.
            let type_changed = !db_type.is_empty() && model_type.to_uppercase() != db_type;
            let nullability_changed = model_col.nullable == existing_col.not_null;
            if type_changed || nullability_changed {
                up.push(format!(
                    "-- NOTE: column \"{}\" definition changed \
                     (model: {} {}, database: {} {}); \
                     rebuild the table to apply the change",
                    model_col.name,
                    model_type,
                    if model_col.nullable { "nullable" } else { "not null" },
                    existing_col.declared_type,
                    if existing_col.not_null { "not null" } else { "nullable" },
                ));
            }
        }

        // Columns that exist in the model but not yet in the database.
        for model_col in &schema.columns {
            if existing_col_map.contains_key(model_col.name) {
                continue;
            }
            // The column is always added as a plain, nullable column. Constraints
            // from the model that are not applied here each get a NOTE, so the
            // divergence from the model is visible in the migration.
            let mut spec = ColumnSpec::new(model_col.name, model_col.sql_type);
            spec.primary_key = false;
            spec.auto_increment = false;
            spec.nullable = true;
            spec.default = column_default_ddl(model_col.default);
            if !model_col.nullable {
                up.push(format!(
                    "-- NOTE: column \"{}\" added as nullable; NOT NULL \
                     requires a default value for existing rows",
                    model_col.name
                ));
            }
            if model_col.primary_key {
                up.push(format!(
                    "-- NOTE: column \"{}\" is a primary key in the model but was \
                     added as a regular column; rebuild the table manually",
                    model_col.name
                ));
            }
            if model_col.foreign_key.is_some() {
                // `ColumnSpec` does not carry foreign keys, so the constraint
                // cannot be part of this ALTER TABLE.
                up.push(format!(
                    "-- NOTE: column \"{}\" was added without its foreign key \
                     constraint; rebuild the table to apply it",
                    model_col.name
                ));
            }
            let add_alter = AlterTable {
                table: schema.table.to_string(),
                actions: vec![AlterAction::AddColumn(spec)],
            };
            up.extend(render::alter_table(dialect, &add_alter));
            let drop_alter = AlterTable {
                table: schema.table.to_string(),
                actions: vec![AlterAction::DropColumn(model_col.name.to_string())],
            };
            down.extend(render::alter_table(dialect, &drop_alter));
        }

        // Model indexes that do not exist yet.
        for index in &schema.indexes {
            if !existing_index_names.contains(index.name.as_str()) {
                up.push(render::create_index(dialect, schema.table, index, false)?);
                down.push(render::drop_index(dialect, &index.name, true));
            }
        }
    }
    // `down` was accumulated in `up` order; reverse it so it undoes the
    // changes last-to-first.
    down.reverse();
    Ok(SchemaChange { up, down })
}
pub async fn generate_from_registry<E: Executor + Sync>(
executor: &E,
) -> crate::Result<SchemaChange> {
generate(executor, ®istered_models()).await
}
/// Diff the registered models against the database and write the result as a
/// new migration file in `dir`.
///
/// Returns `Ok(None)` when there is nothing to migrate, otherwise the path of
/// the written file.
///
/// # Errors
///
/// Propagates errors from generating the diff and from writing the file.
pub async fn generate_and_write<E: Executor + Sync>(
    executor: &E,
    dir: &Path,
    name: &str,
) -> crate::Result<Option<PathBuf>> {
    write_migration(dir, name, &generate_from_registry(executor).await?)
}
/// Write `change` to a new migration file in `dir`.
///
/// The file is named `<revision>_<slug>.sql`, where the revision is freshly
/// generated and the slug is `name` in snake_case. The slug falls back to
/// `migration` when `name` contains no usable characters. The new file's
/// `down_revision` is the current head revision of `dir`.
///
/// Returns `Ok(None)` without touching the filesystem when `change` has no
/// executable SQL (see [`SchemaChange::is_empty`]).
///
/// # Errors
///
/// Propagates errors from reading the head revision, and returns a
/// configuration error if the file cannot be written.
pub fn write_migration(
    dir: &Path,
    name: &str,
    change: &SchemaChange,
) -> crate::Result<Option<PathBuf>> {
    if change.is_empty() {
        return Ok(None);
    }
    let down_revision = files::head_revision(dir)?;
    let revision = new_revision();
    let mut snake = snake_case(name);
    if snake.is_empty() {
        // Avoid file names like `abc123_.sql` and an empty `-- name:` header.
        snake.push_str("migration");
    }
    let contents = render_migration_file(&revision, down_revision.as_deref(), &snake, change);
    let path = dir.join(format!("{revision}_{snake}.sql"));
    std::fs::write(&path, contents).map_err(|error| {
        OrmError::configuration(format!(
            "could not write migration file {}: {error}",
            path.display()
        ))
    })?;
    Ok(Some(path))
}
/// Render the full text of a migration file.
///
/// The layout is a comment header (`revision`, `down_revision`, which is left
/// blank when `None`, and `name`), a blank line, a `-- migrate:up` section
/// holding `change.up`, then a `-- migrate:down` section holding
/// `change.down`. Each entry goes on its own line via `push_statement`.
pub fn render_migration_file(
    revision: &str,
    down_revision: Option<&str>,
    name: &str,
    change: &SchemaChange,
) -> String {
    let parent = down_revision.unwrap_or("");
    let mut text = format!(
        "-- revision: {revision}\n-- down_revision: {parent}\n-- name: {name}\n\n-- migrate:up\n"
    );
    change
        .up
        .iter()
        .for_each(|entry| push_statement(&mut text, entry));
    text.push_str("\n-- migrate:down\n");
    change
        .down
        .iter()
        .for_each(|entry| push_statement(&mut text, entry));
    text
}
/// Append one migration entry to `out` and end the line.
///
/// SQL statements are terminated with `;`. Entries whose first non-whitespace
/// characters are `--` are comments and are written as-is.
fn push_statement(out: &mut String, statement: &str) {
    let is_comment = statement.trim_start().starts_with("--");
    out.push_str(statement);
    out.push_str(if is_comment { "\n" } else { ";\n" });
}
/// Translate a model column default into its DDL representation.
///
/// `None` stays `None`; each `ColumnDefault` variant maps one-to-one onto the
/// `DefaultValue` variant of the same name, with `Raw` SQL copied into an owned
/// `String`.
fn column_default_ddl(default: Option<crate::ColumnDefault>) -> Option<DefaultValue> {
    default.map(|value| match value {
        crate::ColumnDefault::CurrentTimestamp => DefaultValue::CurrentTimestamp,
        crate::ColumnDefault::Uuid => DefaultValue::Uuid,
        crate::ColumnDefault::Raw(sql) => DefaultValue::Raw(sql.to_string()),
    })
}
/// Convert a model schema into a `TableDef` suitable for rendering as a
/// `CREATE TABLE`.
///
/// Each model column becomes a `ColumnSpec` (nullability, primary key,
/// auto-increment and default copied over). Every column with a foreign key
/// also contributes a single-column `ForeignKeySpec`, in column order. The
/// model's indexes and checks are copied as-is.
fn table_def_from_schema(schema: &TableSchema) -> TableDef {
    let mut table = TableDef::new(schema.table);
    table.columns.extend(schema.columns.iter().map(|column| {
        let mut spec = ColumnSpec::new(column.name, column.sql_type);
        spec.nullable = column.nullable;
        spec.primary_key = column.primary_key;
        spec.auto_increment = column.auto;
        spec.default = column_default_ddl(column.default);
        spec
    }));
    table
        .foreign_keys
        .extend(schema.columns.iter().filter_map(|column| {
            column.foreign_key.map(|reference| ForeignKeySpec {
                columns: vec![column.name.to_string()],
                ref_table: reference.table.to_string(),
                ref_columns: vec![reference.column.to_string()],
                on_delete: reference.on_delete,
                on_update: reference.on_update,
            })
        }));
    table.indexes = schema.indexes.clone();
    table.checks = schema.checks.iter().map(|check| check.to_string()).collect();
    table
}
/// Build an `ALTER TABLE <table> ADD COLUMN <name> <type>` statement that
/// re-adds a column previously dropped from `table`.
///
/// Only the column name and its declared type are restored; other recorded
/// attributes (such as `not_null`) are not re-applied. The declared type is
/// trimmed, and `TEXT` is used when it is blank. Blank here includes
/// whitespace-only, matching how `generate` treats declared types.
fn restore_column_sql(
    dialect: &dyn crate::dialect::Dialect,
    table: &str,
    col: &ExistingColumn,
) -> String {
    let declared = col.declared_type.trim();
    let column_type = if declared.is_empty() { "TEXT" } else { declared };
    let mut sql = String::from("ALTER TABLE ");
    dialect.quote_identifier(table, &mut sql);
    sql.push_str(" ADD COLUMN ");
    dialect.quote_identifier(&col.name, &mut sql);
    sql.push(' ');
    sql.push_str(column_type);
    sql
}
/// Generate a new revision id: the first 12 characters of the simple
/// (hyphen-less) form of a random v4 UUID.
fn new_revision() -> String {
    let mut id = uuid::Uuid::new_v4().simple().to_string();
    id.truncate(12);
    id
}
/// Convert a free-form migration name into a lowercase snake_case slug.
///
/// - ASCII letters and digits are kept, lowercased.
/// - Any run of other characters becomes a single `_`.
/// - A `_` is inserted at camelCase boundaries, i.e. an uppercase letter that
///   follows a lowercase letter or a digit (`AddUsers` → `add_users`).
/// - Leading and trailing underscores are removed.
///
/// The result may be empty if `name` contains no ASCII alphanumerics.
fn snake_case(name: &str) -> String {
    let mut out = String::with_capacity(name.len());
    let mut prev: Option<char> = None;
    for ch in name.chars() {
        if ch.is_ascii_alphanumeric() {
            let camel_boundary = ch.is_ascii_uppercase()
                && matches!(prev, Some(p) if p.is_ascii_lowercase() || p.is_ascii_digit());
            if camel_boundary && !out.is_empty() && !out.ends_with('_') {
                out.push('_');
            }
            out.push(ch.to_ascii_lowercase());
        } else if !out.is_empty() && !out.ends_with('_') {
            out.push('_');
        }
        prev = Some(ch);
    }
    out.trim_matches('_').to_string()
}