use std::fmt::Write as _;
use serde::{Deserialize, Serialize};
use super::snapshot::{FieldSnapshot, SchemaSnapshot, TableSnapshot};
/// A single detected difference between two schema snapshots.
///
/// Produced by `detect_changes` and rendered to SQL by `render_changes_split`.
/// Note that `detect_changes` never emits the `Rename*` variants; presumably
/// those come from user-authored migrations — TODO confirm against callers.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum SchemaChange {
/// Table exists in the current snapshot but not the previous one.
CreateTable(String ),
/// Table exists in the previous snapshot but not the current one.
DropTable(String ),
/// Column was added to a table that exists in both snapshots.
AddColumn {
table: String,
column: String,
},
/// Column was removed from a table that exists in both snapshots.
DropColumn {
table: String,
column: String,
},
/// Column's snapshot type name changed (`from` → `to`, e.g. "i32" → "i64").
AlterColumnType {
table: String,
column: String,
from: String,
to: String,
},
/// Column's nullability changed; `nullable` is the NEW value.
AlterColumnNullable {
table: String,
column: String,
nullable: bool,
},
/// Column's DEFAULT expression changed; `None` means no default.
AlterColumnDefault {
table: String,
column: String,
from: Option<String>,
to: Option<String>,
},
/// String column's max length changed; `None` renders as TEXT, `Some(n)` as VARCHAR(n).
AlterColumnMaxLength {
table: String,
column: String,
from: Option<u32>,
to: Option<u32>,
},
/// Table rename (not auto-detected; see enum-level note).
RenameTable {
old_name: String,
new_name: String,
},
/// Column rename (not auto-detected; see enum-level note).
RenameColumn {
table: String,
old_column: String,
new_column: String,
},
}
/// Computes the list of schema changes needed to migrate `prev` to `current`.
///
/// Changes are emitted in a fixed, migration-friendly order:
/// 1. `CreateTable` for every new table,
/// 2. `AddColumn` for new columns on pre-existing tables,
/// 3. `AlterColumn*` for changed columns on pre-existing tables,
/// 4. `DropColumn` for removed columns on surviving tables,
/// 5. `DropTable` for removed tables.
///
/// Columns of newly created / fully dropped tables are NOT reported
/// individually — `CreateTable` / `DropTable` covers them.
/// `Rename*` variants are never produced here.
#[must_use]
pub fn detect_changes(prev: &SchemaSnapshot, current: &SchemaSnapshot) -> Vec<SchemaChange> {
    let mut changes = Vec::new();
    push_created_tables(prev, current, &mut changes);
    push_added_columns(prev, current, &mut changes);
    push_altered_columns(prev, current, &mut changes);
    push_dropped_columns(prev, current, &mut changes);
    push_dropped_tables(prev, current, &mut changes);
    changes
}

/// Pass 1: `CreateTable` for tables present only in `current`.
fn push_created_tables(prev: &SchemaSnapshot, current: &SchemaSnapshot, out: &mut Vec<SchemaChange>) {
    for t in &current.tables {
        if prev.table(&t.name).is_none() {
            out.push(SchemaChange::CreateTable(t.name.clone()));
        }
    }
}

/// Pass 2: `AddColumn` for columns present only in `current` on shared tables.
fn push_added_columns(prev: &SchemaSnapshot, current: &SchemaSnapshot, out: &mut Vec<SchemaChange>) {
    for t in &current.tables {
        let Some(pt) = prev.table(&t.name) else {
            continue;
        };
        for f in &t.fields {
            if pt.field(&f.column).is_none() {
                out.push(SchemaChange::AddColumn {
                    table: t.name.clone(),
                    column: f.column.clone(),
                });
            }
        }
    }
}

/// Pass 3: `AlterColumn*` for columns present in both snapshots whose
/// attributes differ (delegates the per-field diff to `push_alter_changes`).
fn push_altered_columns(prev: &SchemaSnapshot, current: &SchemaSnapshot, out: &mut Vec<SchemaChange>) {
    for ct in &current.tables {
        let Some(pt) = prev.table(&ct.name) else {
            continue;
        };
        for cf in &ct.fields {
            let Some(pf) = pt.field(&cf.column) else {
                continue;
            };
            push_alter_changes(&ct.name, pf, cf, out);
        }
    }
}

/// Pass 4: `DropColumn` for columns present only in `prev` on tables that
/// still exist in `current`.
fn push_dropped_columns(prev: &SchemaSnapshot, current: &SchemaSnapshot, out: &mut Vec<SchemaChange>) {
    for pt in &prev.tables {
        let Some(ct) = current.table(&pt.name) else {
            continue;
        };
        for f in &pt.fields {
            if ct.field(&f.column).is_none() {
                out.push(SchemaChange::DropColumn {
                    table: pt.name.clone(),
                    column: f.column.clone(),
                });
            }
        }
    }
}

/// Pass 5: `DropTable` for tables present only in `prev`.
fn push_dropped_tables(prev: &SchemaSnapshot, current: &SchemaSnapshot, out: &mut Vec<SchemaChange>) {
    for pt in &prev.tables {
        if current.table(&pt.name).is_none() {
            out.push(SchemaChange::DropTable(pt.name.clone()));
        }
    }
}
/// Appends one `AlterColumn*` change for every attribute (type, nullability,
/// default, max length — in that order) that differs between the previous
/// field snapshot `pf` and the current one `cf` on `table`.
fn push_alter_changes(
    table: &str,
    pf: &FieldSnapshot,
    cf: &FieldSnapshot,
    out: &mut Vec<SchemaChange>,
) {
    // Small factories so each candidate below stays a one-liner.
    let owned_table = || table.to_owned();
    let owned_column = || cf.column.clone();

    // Candidate order mirrors the attribute order documented above.
    let candidates = [
        (pf.ty != cf.ty).then(|| SchemaChange::AlterColumnType {
            table: owned_table(),
            column: owned_column(),
            from: pf.ty.clone(),
            to: cf.ty.clone(),
        }),
        (pf.nullable != cf.nullable).then(|| SchemaChange::AlterColumnNullable {
            table: owned_table(),
            column: owned_column(),
            nullable: cf.nullable,
        }),
        (pf.default != cf.default).then(|| SchemaChange::AlterColumnDefault {
            table: owned_table(),
            column: owned_column(),
            from: pf.default.clone(),
            to: cf.default.clone(),
        }),
        (pf.max_length != cf.max_length).then(|| SchemaChange::AlterColumnMaxLength {
            table: owned_table(),
            column: owned_column(),
            from: pf.max_length,
            to: cf.max_length,
        }),
    ];
    out.extend(candidates.into_iter().flatten());
}
/// Collects human-readable descriptions of field attribute changes that the
/// migration engine does not support altering in place (primary key, min/max
/// bounds, foreign key, auto flag). Only columns present in both snapshots
/// are compared.
#[must_use]
pub fn detect_unsupported_field_changes(
    prev: &SchemaSnapshot,
    current: &SchemaSnapshot,
) -> Vec<String> {
    let mut diffs = Vec::new();
    for table in &current.tables {
        if let Some(prev_table) = prev.table(&table.name) {
            for field in &table.fields {
                if let Some(prev_field) = prev_table.field(&field.column) {
                    push_field_diffs(&table.name, prev_field, field, &mut diffs);
                }
            }
        }
    }
    diffs
}
/// Appends a "`table.column` attr changed: old → new" line to `out` for every
/// unsupported attribute difference between `pf` (previous) and `cf` (current).
fn push_field_diffs(table: &str, pf: &FieldSnapshot, cf: &FieldSnapshot, out: &mut Vec<String>) {
    let col = &cf.column;
    // All five messages share one shape; the caller-facing text is unchanged.
    let mut report = |attr: &str, before: String, after: String| {
        out.push(format!("`{table}.{col}` {attr} changed: {before} → {after}"));
    };
    if pf.primary_key != cf.primary_key {
        report(
            "primary_key",
            pf.primary_key.to_string(),
            cf.primary_key.to_string(),
        );
    }
    if pf.min != cf.min {
        report("min", format!("{:?}", pf.min), format!("{:?}", cf.min));
    }
    if pf.max != cf.max {
        report("max", format!("{:?}", pf.max), format!("{:?}", cf.max));
    }
    if pf.fk != cf.fk {
        report("fk", format!("{:?}", pf.fk), format!("{:?}", cf.fk));
    }
    if pf.auto != cf.auto {
        report("auto", pf.auto.to_string(), cf.auto.to_string());
    }
}
/// Renders `changes` to SQL as one flat statement list: all immediate
/// statements first, then the deferred foreign-key constraint statements.
///
/// # Errors
///
/// Propagates any error from [`render_changes_split`].
pub fn render_changes(
    changes: &[SchemaChange],
    current: &SchemaSnapshot,
) -> Result<Vec<String>, String> {
    let batch = render_changes_split(changes, current)?;
    let mut sql = batch.immediate;
    sql.extend(batch.deferred_fks);
    Ok(sql)
}
/// SQL statements rendered from a list of [`SchemaChange`]s, with foreign-key
/// constraint statements kept separate so they can be executed after all
/// referenced tables exist.
#[derive(Debug, Default)]
pub struct RenderedBatch {
/// CREATE/ALTER/DROP statements, in change order.
pub immediate: Vec<String>,
/// `ALTER TABLE … ADD CONSTRAINT … FOREIGN KEY` statements for newly
/// created tables, intended to run after `immediate`.
pub deferred_fks: Vec<String>,
}
/// Renders each [`SchemaChange`] to Postgres SQL, splitting foreign-key
/// constraint statements (from `CreateTable` changes) into a separate list.
///
/// `current` supplies the table/field definitions that `CreateTable` and
/// `AddColumn` changes are rendered from, so those changes must have matching
/// snapshot entries.
///
/// # Errors
///
/// Returns `Err` when a `CreateTable`/`AddColumn` change references a table or
/// column missing from `current`, or when an added column is NOT NULL without
/// a default (Postgres could not backfill existing rows).
pub fn render_changes_split(
changes: &[SchemaChange],
current: &SchemaSnapshot,
) -> Result<RenderedBatch, String> {
let mut out = RenderedBatch::default();
for change in changes {
match change {
// Full table definition comes from the snapshot; FK constraints are
// deferred so referenced tables can be created first.
SchemaChange::CreateTable(name) => {
let table = current.table(name).ok_or_else(|| {
format!("CreateTable for `{name}` but no snapshot entry for it")
})?;
out.immediate.push(create_table_sql_from_snapshot(table));
out.deferred_fks
.extend(constraints_sql_from_snapshot(table));
}
SchemaChange::DropColumn { table, column } => {
out.immediate
.push(format!(r#"ALTER TABLE "{table}" DROP COLUMN "{column}""#,));
}
// Column definition is looked up in the snapshot; reject NOT NULL
// columns without a default up front rather than failing in Postgres.
SchemaChange::AddColumn { table, column } => {
let t = current.table(table).ok_or_else(|| {
format!("AddColumn for `{table}.{column}` but table missing in snapshot")
})?;
let f = t.field(column).ok_or_else(|| {
format!("AddColumn for `{table}.{column}` but field missing in snapshot")
})?;
if !f.nullable && f.default.is_none() {
return Err(format!(
"AddColumn `{table}.{column}` is NOT NULL with no `default` — Postgres can't backfill existing rows. Make the field `Option<…>` or set `#[rustango(default = \"…\")]`.",
));
}
out.immediate.push(add_column_sql(table, f));
}
// CASCADE also removes dependent objects (FKs pointing at this table).
SchemaChange::DropTable(name) => {
out.immediate
.push(format!(r#"DROP TABLE "{name}" CASCADE"#));
}
// USING cast converts existing row values to the new type; note the
// target type ignores max_length (a separate AlterColumnMaxLength
// change handles VARCHAR sizing).
SchemaChange::AlterColumnType {
table,
column,
from: _,
to,
} => {
let pg_to = pg_type_for_ty_name(to);
out.immediate.push(format!(
r#"ALTER TABLE "{table}" ALTER COLUMN "{column}" TYPE {pg_to} USING "{column}"::{pg_to}"#,
));
}
SchemaChange::AlterColumnNullable {
table,
column,
nullable,
} => {
let action = if *nullable { "DROP NOT NULL" } else { "SET NOT NULL" };
out.immediate.push(format!(
r#"ALTER TABLE "{table}" ALTER COLUMN "{column}" {action}"#,
));
}
// `to` is the new default expression; None means remove the default.
SchemaChange::AlterColumnDefault {
table,
column,
from: _,
to,
} => match to {
Some(expr) => out.immediate.push(format!(
r#"ALTER TABLE "{table}" ALTER COLUMN "{column}" SET DEFAULT {expr}"#,
)),
None => out.immediate.push(format!(
r#"ALTER TABLE "{table}" ALTER COLUMN "{column}" DROP DEFAULT"#,
)),
},
// Some(n) → VARCHAR(n); None → unbounded TEXT.
SchemaChange::AlterColumnMaxLength {
table,
column,
from: _,
to,
} => {
let pg_to = match to {
Some(n) => format!("VARCHAR({n})"),
None => "TEXT".into(),
};
out.immediate.push(format!(
r#"ALTER TABLE "{table}" ALTER COLUMN "{column}" TYPE {pg_to} USING "{column}"::{pg_to}"#,
));
}
SchemaChange::RenameTable { old_name, new_name } => {
out.immediate.push(format!(
r#"ALTER TABLE "{old_name}" RENAME TO "{new_name}""#,
));
}
SchemaChange::RenameColumn {
table,
old_column,
new_column,
} => {
out.immediate.push(format!(
r#"ALTER TABLE "{table}" RENAME COLUMN "{old_column}" TO "{new_column}""#,
));
}
}
}
Ok(out)
}
/// Maps a snapshot type name (e.g. "i32", "string") to its Postgres column
/// type. Unknown names are passed through upper-cased so custom SQL types
/// keep working.
fn pg_type_for_ty_name(ty: &str) -> String {
    let known = match ty {
        "i32" => "INTEGER",
        "i64" => "BIGINT",
        "f32" => "REAL",
        "f64" => "DOUBLE PRECISION",
        "bool" => "BOOLEAN",
        "string" => "TEXT",
        "datetime" => "TIMESTAMPTZ",
        "date" => "DATE",
        "uuid" => "UUID",
        "json" => "JSONB",
        other => return other.to_uppercase(),
    };
    known.to_owned()
}
fn create_table_sql_from_snapshot(t: &TableSnapshot) -> String {
let mut sql = format!(r#"CREATE TABLE "{}" ("#, t.name);
let mut first = true;
for f in &t.fields {
if !first {
sql.push_str(", ");
}
first = false;
let _ = write!(sql, r#""{}" {}"#, f.column, sql_type(f));
if let Some(expr) = &f.default {
let _ = write!(sql, " DEFAULT {expr}");
}
if !f.nullable {
sql.push_str(" NOT NULL");
}
if f.primary_key {
sql.push_str(" PRIMARY KEY");
}
if f.min.is_some() || f.max.is_some() {
sql.push_str(" CHECK (");
let mut wrote = false;
if let Some(min) = f.min {
let _ = write!(sql, r#""{}" >= {}"#, f.column, min);
wrote = true;
}
if let Some(max) = f.max {
if wrote {
sql.push_str(" AND ");
}
let _ = write!(sql, r#""{}" <= {}"#, f.column, max);
}
sql.push(')');
}
}
sql.push(')');
sql
}
/// Builds one `ADD CONSTRAINT … FOREIGN KEY` statement per FK-bearing field
/// of `t`, named `<table>_<column>_fkey`.
fn constraints_sql_from_snapshot(t: &TableSnapshot) -> Vec<String> {
    let mut stmts = Vec::new();
    for f in &t.fields {
        if let Some(rel) = &f.fk {
            stmts.push(format!(
                r#"ALTER TABLE "{0}" ADD CONSTRAINT "{0}_{1}_fkey" FOREIGN KEY ("{1}") REFERENCES "{2}" ("{3}")"#,
                t.name, f.column, rel.to, rel.on,
            ));
        }
    }
    stmts
}
/// Builds the `ALTER TABLE … ADD COLUMN` statement for field `f` on `table`,
/// including DEFAULT, NOT NULL, and a min/max CHECK clause when present.
/// Callers are expected to have rejected NOT-NULL-without-default already.
fn add_column_sql(table: &str, f: &FieldSnapshot) -> String {
    let mut sql = format!(
        r#"ALTER TABLE "{table}" ADD COLUMN "{}" {}"#,
        f.column,
        sql_type(f)
    );
    if let Some(expr) = &f.default {
        let _ = write!(sql, " DEFAULT {expr}");
    }
    if !f.nullable {
        sql.push_str(" NOT NULL");
    }
    let mut bounds = Vec::new();
    if let Some(min) = f.min {
        bounds.push(format!(r#""{}" >= {min}"#, f.column));
    }
    if let Some(max) = f.max {
        bounds.push(format!(r#""{}" <= {max}"#, f.column));
    }
    if !bounds.is_empty() {
        let _ = write!(sql, " CHECK ({})", bounds.join(" AND "));
    }
    sql
}
/// Resolves the Postgres column type for a field snapshot.
///
/// Auto-increment integer fields become SERIAL/BIGSERIAL; strings honor
/// `max_length` (VARCHAR(n) vs TEXT); any unrecognized type name is passed
/// through upper-cased.
fn sql_type(f: &FieldSnapshot) -> String {
    let name = f.ty.as_str();
    if f.auto {
        // Auto columns: only i32/i64 have SERIAL forms; anything else falls
        // back to the upper-cased raw name (same as the non-auto fallback).
        return match name {
            "i32" => "SERIAL".to_owned(),
            "i64" => "BIGSERIAL".to_owned(),
            other => other.to_uppercase(),
        };
    }
    if name == "string" {
        // Strings are the one type whose rendering depends on max_length.
        return match f.max_length {
            Some(n) => format!("VARCHAR({n})"),
            None => "TEXT".to_owned(),
        };
    }
    match name {
        "i32" => "INTEGER",
        "i64" => "BIGINT",
        "f32" => "REAL",
        "f64" => "DOUBLE PRECISION",
        "bool" => "BOOLEAN",
        "datetime" => "TIMESTAMPTZ",
        "date" => "DATE",
        "uuid" => "UUID",
        "json" => "JSONB",
        other => return other.to_uppercase(),
    }
    .to_owned()
}