use super::schema::{MigrationHint, Schema};
use crate::ast::{Action, Constraint, Expr, IndexDef, Qail};
use std::collections::BTreeSet;
/// Lists the schema object families present in `schema` that the
/// state-based differ does not handle.
///
/// Each non-empty collection on `schema` contributes its label (e.g.
/// `"views"`). The result is a `BTreeSet`, so labels come out sorted and
/// deduplicated.
fn unsupported_state_diff_features(schema: &Schema) -> BTreeSet<&'static str> {
    let families: [(&'static str, bool); 10] = [
        ("extensions", !schema.extensions.is_empty()),
        ("comments", !schema.comments.is_empty()),
        ("sequences", !schema.sequences.is_empty()),
        ("enums", !schema.enums.is_empty()),
        ("views", !schema.views.is_empty()),
        ("functions", !schema.functions.is_empty()),
        ("triggers", !schema.triggers.is_empty()),
        ("grants", !schema.grants.is_empty()),
        ("policies", !schema.policies.is_empty()),
        ("resources", !schema.resources.is_empty()),
    ];
    families
        .into_iter()
        .filter_map(|(label, present)| present.then_some(label))
        .collect()
}
pub fn validate_state_diff_support(old: &Schema, new: &Schema) -> Result<(), String> {
let mut unsupported = unsupported_state_diff_features(old);
unsupported.extend(unsupported_state_diff_features(new));
if unsupported.is_empty() {
return Ok(());
}
let detail = unsupported.into_iter().collect::<Vec<_>>().join(", ");
Err(format!(
"State-based diff currently supports tables, columns, indexes, and migration hints only. \
Unsupported schema object families present: {}. \
Use folder-based strict migrations for these objects.",
detail
))
}
/// Diffs `old` against `new` after confirming both schemas only use
/// supported object families.
///
/// # Errors
///
/// Returns the message from `validate_state_diff_support` when either schema
/// contains unsupported object families. No diff is computed in that case.
pub fn diff_schemas_checked(old: &Schema, new: &Schema) -> Result<Vec<Qail>, String> {
    validate_state_diff_support(old, new).map(|()| diff_schemas(old, new))
}
/// Computes the commands that migrate a database from schema `old` to `new`.
///
/// Only tables, columns, indexes, and migration hints are diffed. Use
/// [`diff_schemas_checked`] to reject schemas that carry other object families.
///
/// Commands are emitted in this order:
/// 1. `DropIndex` for every index name present only in `old`. These come first
///    because PostgreSQL removes an index together with its table or column, so
///    a `DROP INDEX` issued after those drops would fail.
/// 2. Commands requested by `new.migrations` hints: same-table renames (`Mod`),
///    transforms (`Set`), and confirmed drops (`AlterDrop` / `Drop`).
/// 3. `Make` for tables present only in `new`. Each one comes after the other
///    new tables it references through a foreign key (self-references ignored).
/// 4. `Drop` for tables present only in `old`, unless a confirmed drop hint
///    already dropped them. Tables with more FK columns are dropped first.
/// 5. Column and RLS changes for tables present in both schemas.
/// 6. `Index` for every index name present only in `new`.
///
/// Indexes are matched by name only; changing an index definition without
/// renaming it produces no command.
pub fn diff_schemas(old: &Schema, new: &Schema) -> Vec<Qail> {
    let hints = HintNames::from_schema(new);
    let mut cmds = Vec::new();
    push_index_drops(old, new, &mut cmds);
    push_hint_commands(new, &mut cmds);
    push_new_tables(old, new, &mut cmds);
    push_dropped_tables(old, new, &hints, &mut cmds);
    push_table_changes(old, new, &hints, &mut cmds);
    push_index_creates(old, new, &mut cmds);
    cmds
}

/// Set of borrowed names such as `table` or `table.column`.
type NameSet<'a> = std::collections::HashSet<&'a str>;

/// Names referenced by the migration hints of the target schema.
///
/// Columns are matched on their exact `table.column` name. A hint on
/// `users.name` must not affect `orders.name`.
#[derive(Default)]
struct HintNames<'a> {
    /// `from` of every `Rename` hint.
    rename_sources: NameSet<'a>,
    /// `to` of every `Rename` hint.
    rename_targets: NameSet<'a>,
    /// `target` of every `Drop` hint with `confirmed: true`.
    confirmed_drops: NameSet<'a>,
}

impl<'a> HintNames<'a> {
    /// Indexes every hint in `schema.migrations`.
    fn from_schema(schema: &'a Schema) -> Self {
        let mut names = Self::default();
        for hint in &schema.migrations {
            match hint {
                MigrationHint::Rename { from, to } => {
                    names.rename_sources.insert(from.as_str());
                    names.rename_targets.insert(to.as_str());
                }
                MigrationHint::Drop {
                    target,
                    confirmed: true,
                } => {
                    names.confirmed_drops.insert(target.as_str());
                }
                _ => {}
            }
        }
        names
    }
}

/// Formats `table.column`, the notation migration hints use for columns.
fn qualified(table: &str, column: &str) -> String {
    format!("{table}.{column}")
}

/// Column type used in `ALTER TABLE` commands.
///
/// `ColumnType::Serial` / `BigSerial` map to `INTEGER` / `BIGINT`; every other
/// type uses `to_pg_type()`. Serial types are not real column types, so
/// `ALTER ... TYPE` cannot target them.
fn alter_pg_type(ty: &super::types::ColumnType) -> String {
    match ty {
        super::types::ColumnType::Serial => "INTEGER".to_string(),
        super::types::ColumnType::BigSerial => "BIGINT".to_string(),
        other => other.to_pg_type(),
    }
}

/// Emits `DropIndex` for every index of `old` whose name is absent from `new`.
fn push_index_drops(old: &Schema, new: &Schema, cmds: &mut Vec<Qail>) {
    for old_idx in &old.indexes {
        if !new.indexes.iter().any(|i| i.name == old_idx.name) {
            cmds.push(Qail {
                action: Action::DropIndex,
                table: old_idx.name.clone(),
                ..Default::default()
            });
        }
    }
}

/// Emits commands requested explicitly by migration hints.
///
/// - Same-table renames become `Mod`. Cross-table renames emit nothing.
/// - Transforms become `Set`.
/// - Confirmed drops become `AlterDrop` (for `table.column`) or `Drop` (for a bare table).
fn push_hint_commands(new: &Schema, cmds: &mut Vec<Qail>) {
    for hint in &new.migrations {
        match hint {
            MigrationHint::Rename { from, to } => {
                if let (Some((from_table, from_col)), Some((to_table, to_col))) =
                    (parse_table_col(from), parse_table_col(to))
                    && from_table == to_table
                {
                    cmds.push(Qail {
                        action: Action::Mod,
                        table: from_table.to_string(),
                        columns: vec![Expr::Named(format!("{} -> {}", from_col, to_col))],
                        ..Default::default()
                    });
                }
            }
            MigrationHint::Transform { expression, target } => {
                if let Some((table, _col)) = parse_table_col(target) {
                    cmds.push(Qail {
                        action: Action::Set,
                        table: table.to_string(),
                        columns: vec![Expr::Named(format!("/* TRANSFORM: {} */", expression))],
                        ..Default::default()
                    });
                }
            }
            MigrationHint::Drop {
                target,
                confirmed: true,
            } => {
                let cmd = match parse_table_col(target) {
                    Some((table, col)) => Qail {
                        action: Action::AlterDrop,
                        table: table.to_string(),
                        columns: vec![Expr::Named(col.to_string())],
                        ..Default::default()
                    },
                    None => Qail {
                        action: Action::Drop,
                        table: target.clone(),
                        ..Default::default()
                    },
                };
                cmds.push(cmd);
            }
            _ => {}
        }
    }
}

/// Returns the tables present only in `new`, ordered so that every table comes
/// after the other new tables its foreign keys reference.
///
/// References to tables outside that set, and self-references, impose no
/// ordering constraint. Tables left in a dependency cycle are appended at the
/// end in their original relative order.
fn new_tables_in_fk_order<'a>(old: &Schema, new: &'a Schema) -> Vec<&'a String> {
    let mut remaining: Vec<&'a String> = new
        .tables
        .keys()
        .filter(|name| !old.tables.contains_key(*name))
        .collect();
    let pending: NameSet<'a> = remaining.iter().map(|name| name.as_str()).collect();
    let mut emitted: NameSet<'a> = NameSet::default();
    let mut sorted: Vec<&'a String> = Vec::with_capacity(remaining.len());
    loop {
        let before = sorted.len();
        remaining.retain(|&name| {
            let ready = new.tables.get(name).is_none_or(|table| {
                table.columns.iter().all(|col| {
                    col.foreign_key.as_ref().is_none_or(|fk| {
                        let parent = fk.table.as_str();
                        // A self-reference is satisfied by the table's own CREATE.
                        parent == name.as_str()
                            || !pending.contains(parent)
                            || emitted.contains(parent)
                    })
                })
            });
            if ready {
                emitted.insert(name.as_str());
                sorted.push(name);
            }
            !ready
        });
        // Stop when done, or when a pass made no progress (dependency cycle).
        if remaining.is_empty() || sorted.len() == before {
            sorted.extend(remaining);
            return sorted;
        }
    }
}

/// Emits `Make` for every table present only in `new`, in foreign-key order.
fn push_new_tables(old: &Schema, new: &Schema, cmds: &mut Vec<Qail>) {
    for name in new_tables_in_fk_order(old, new) {
        let table = &new.tables[name];
        let columns: Vec<Expr> = table
            .columns
            .iter()
            .map(|col| {
                let mut constraints = Vec::new();
                if col.primary_key {
                    constraints.push(Constraint::PrimaryKey);
                }
                if col.nullable {
                    constraints.push(Constraint::Nullable);
                }
                if col.unique {
                    constraints.push(Constraint::Unique);
                }
                if let Some(def) = &col.default {
                    constraints.push(Constraint::Default(def.clone()));
                }
                if let Some(fk) = &col.foreign_key {
                    constraints.push(Constraint::References(format!(
                        "{}({})",
                        fk.table, fk.column
                    )));
                }
                Expr::Def {
                    name: col.name.clone(),
                    data_type: col.data_type.to_pg_type(),
                    constraints,
                }
            })
            .collect();
        cmds.push(Qail {
            action: Action::Make,
            table: name.clone(),
            columns,
            ..Default::default()
        });
    }
}

/// Emits `Drop` for tables present only in `old`.
///
/// Tables already dropped by a confirmed drop hint are skipped. Tables with
/// more foreign-key columns are dropped first (a heuristic so referencing
/// tables tend to go before the tables they reference).
fn push_dropped_tables(old: &Schema, new: &Schema, hints: &HintNames<'_>, cmds: &mut Vec<Qail>) {
    let mut dropped: Vec<&String> = old
        .tables
        .keys()
        .filter(|name| {
            !new.tables.contains_key(*name) && !hints.confirmed_drops.contains(name.as_str())
        })
        .collect();
    dropped.sort_by_key(|name| {
        std::cmp::Reverse(
            old.tables
                .get(*name)
                .map(|t| t.columns.iter().filter(|c| c.foreign_key.is_some()).count())
                .unwrap_or(0),
        )
    });
    for name in dropped {
        cmds.push(Qail {
            action: Action::Drop,
            table: name.clone(),
            ..Default::default()
        });
    }
}

/// Emits column and RLS changes for every table present in both schemas.
fn push_table_changes(old: &Schema, new: &Schema, hints: &HintNames<'_>, cmds: &mut Vec<Qail>) {
    for (name, new_table) in &new.tables {
        let Some(old_table) = old.tables.get(name) else {
            continue;
        };
        let old_cols: NameSet<'_> = old_table.columns.iter().map(|c| c.name.as_str()).collect();
        let new_cols: NameSet<'_> = new_table.columns.iter().map(|c| c.name.as_str()).collect();

        // Added columns. A rename target is produced by the hint's `Mod` instead.
        for col in &new_table.columns {
            if old_cols.contains(col.name.as_str())
                || hints.rename_targets.contains(qualified(name, &col.name).as_str())
            {
                continue;
            }
            // NOTE(review): PRIMARY KEY is not emitted for columns added to an
            // existing table; confirm this is intended.
            let mut constraints = Vec::new();
            if col.nullable {
                constraints.push(Constraint::Nullable);
            }
            if col.unique {
                constraints.push(Constraint::Unique);
            }
            if let Some(def) = &col.default {
                constraints.push(Constraint::Default(def.clone()));
            }
            // Keep the FK, as `Make` does, so an added FK column is still constrained.
            if let Some(fk) = &col.foreign_key {
                constraints.push(Constraint::References(format!(
                    "{}({})",
                    fk.table, fk.column
                )));
            }
            cmds.push(Qail {
                action: Action::Alter,
                table: name.clone(),
                columns: vec![Expr::Def {
                    name: col.name.clone(),
                    data_type: alter_pg_type(&col.data_type),
                    constraints,
                }],
                ..Default::default()
            });
        }

        // Removed columns, unless a rename hint or confirmed drop hint covers them.
        for col in &old_table.columns {
            if new_cols.contains(col.name.as_str()) {
                continue;
            }
            let full_name = qualified(name, &col.name);
            if hints.rename_sources.contains(full_name.as_str())
                || hints.confirmed_drops.contains(full_name.as_str())
            {
                continue;
            }
            cmds.push(Qail {
                action: Action::AlterDrop,
                table: name.clone(),
                columns: vec![Expr::Named(col.name.clone())],
                ..Default::default()
            });
        }

        // Columns present in both: type, nullability, and default changes.
        for new_col in &new_table.columns {
            let Some(old_col) = old_table.columns.iter().find(|c| c.name == new_col.name) else {
                continue;
            };
            if old_col.data_type.to_pg_type() != new_col.data_type.to_pg_type() {
                cmds.push(Qail {
                    action: Action::AlterType,
                    table: name.clone(),
                    columns: vec![Expr::Def {
                        name: new_col.name.clone(),
                        data_type: alter_pg_type(&new_col.data_type),
                        constraints: vec![],
                    }],
                    ..Default::default()
                });
            }
            // Primary-key columns are skipped: a PK column is never nullable.
            if old_col.nullable && !new_col.nullable && !new_col.primary_key {
                cmds.push(Qail {
                    action: Action::AlterSetNotNull,
                    table: name.clone(),
                    columns: vec![Expr::Named(new_col.name.clone())],
                    ..Default::default()
                });
            } else if !old_col.nullable && new_col.nullable && !old_col.primary_key {
                cmds.push(Qail {
                    action: Action::AlterDropNotNull,
                    table: name.clone(),
                    columns: vec![Expr::Named(new_col.name.clone())],
                    ..Default::default()
                });
            }
            match (&old_col.default, &new_col.default) {
                (None, Some(new_default)) => {
                    cmds.push(Qail {
                        action: Action::AlterSetDefault,
                        table: name.clone(),
                        columns: vec![Expr::Named(new_col.name.clone())],
                        payload: Some(new_default.clone()),
                        ..Default::default()
                    });
                }
                (Some(_), None) => {
                    cmds.push(Qail {
                        action: Action::AlterDropDefault,
                        table: name.clone(),
                        columns: vec![Expr::Named(new_col.name.clone())],
                        ..Default::default()
                    });
                }
                (Some(old_default), Some(new_default)) if old_default != new_default => {
                    cmds.push(Qail {
                        action: Action::AlterSetDefault,
                        table: name.clone(),
                        columns: vec![Expr::Named(new_col.name.clone())],
                        payload: Some(new_default.clone()),
                        ..Default::default()
                    });
                }
                _ => {}
            }
        }

        // Row-level security flags.
        if !old_table.enable_rls && new_table.enable_rls {
            cmds.push(Qail {
                action: Action::AlterEnableRls,
                table: name.clone(),
                ..Default::default()
            });
        } else if old_table.enable_rls && !new_table.enable_rls {
            cmds.push(Qail {
                action: Action::AlterDisableRls,
                table: name.clone(),
                ..Default::default()
            });
        }
        if !old_table.force_rls && new_table.force_rls {
            cmds.push(Qail {
                action: Action::AlterForceRls,
                table: name.clone(),
                ..Default::default()
            });
        } else if old_table.force_rls && !new_table.force_rls {
            cmds.push(Qail {
                action: Action::AlterNoForceRls,
                table: name.clone(),
                ..Default::default()
            });
        }
    }
}

/// Emits `Index` for every index of `new` whose name is absent from `old`.
///
/// `index_type` and `where_clause` are always left as `None`.
fn push_index_creates(old: &Schema, new: &Schema, cmds: &mut Vec<Qail>) {
    for new_idx in &new.indexes {
        if !old.indexes.iter().any(|i| i.name == new_idx.name) {
            cmds.push(Qail {
                action: Action::Index,
                table: String::new(),
                index_def: Some(IndexDef {
                    name: new_idx.name.clone(),
                    table: new_idx.table.clone(),
                    columns: new_idx.columns.clone(),
                    unique: new_idx.unique,
                    index_type: None,
                    where_clause: None,
                }),
                ..Default::default()
            });
        }
    }
}
/// Splits a `table.column` reference at its first `.`.
///
/// Returns `None` when `s` contains no `.`. Everything after the first dot,
/// including any further dots, becomes the column part.
fn parse_table_col(s: &str) -> Option<(&str, &str)> {
    s.split_once('.')
}
#[cfg(test)]
mod tests {
    use super::super::schema::{Column, Table, ViewDef};
    use super::super::types::ColumnType;
    use super::*;

    /// Table names of the `Make` commands in `cmds`, in emission order.
    fn made_tables(cmds: &[Qail]) -> Vec<&str> {
        cmds.iter()
            .filter(|cmd| matches!(cmd.action, Action::Make))
            .map(|cmd| cmd.table.as_str())
            .collect()
    }

    /// Position of `table` in `order`; panics when the table was never created.
    fn index_of(order: &[&str], table: &str) -> usize {
        order
            .iter()
            .position(|name| *name == table)
            .unwrap_or_else(|| panic!("no Make command emitted for `{table}`"))
    }

    #[test]
    fn test_diff_new_table() {
        let old = Schema::default();
        let mut new = Schema::default();
        let users = Table::new("users")
            .column(Column::new("id", ColumnType::Serial).primary_key())
            .column(Column::new("name", ColumnType::Text).not_null());
        new.add_table(users);

        let cmds = diff_schemas(&old, &new);

        assert_eq!(cmds.len(), 1);
        assert!(matches!(cmds[0].action, Action::Make));
    }

    #[test]
    fn state_diff_support_rejects_non_table_object_families() {
        let old = Schema::default();
        let mut new = Schema::default();
        new.add_view(ViewDef::new("active_users", "SELECT 1"));

        let outcome = validate_state_diff_support(&old, &new);
        let err = outcome.expect_err("state-based diff should reject unsupported view objects");

        assert!(
            err.contains("views"),
            "error should include unsupported family name"
        );
    }

    #[test]
    fn state_diff_checked_passes_for_table_index_only_schema() {
        let old = Schema::default();
        let mut new = Schema::default();
        new.add_table(Table::new("users").column(Column::new("id", ColumnType::Serial)));

        let cmds = diff_schemas_checked(&old, &new).expect("table/index-only schema should pass");

        let has_make = cmds.iter().any(|cmd| matches!(cmd.action, Action::Make));
        assert!(
            has_make,
            "checked diff should still produce normal table commands"
        );
    }

    #[test]
    fn test_diff_rename_with_hint() {
        let mut old = Schema::default();
        old.add_table(Table::new("users").column(Column::new("username", ColumnType::Text)));
        let mut new = Schema::default();
        new.add_table(Table::new("users").column(Column::new("name", ColumnType::Text)));
        new.add_hint(MigrationHint::Rename {
            from: "users.username".into(),
            to: "users.name".into(),
        });

        let cmds = diff_schemas(&old, &new);

        let renamed = cmds.iter().any(|cmd| matches!(cmd.action, Action::Mod));
        let dropped = cmds.iter().any(|cmd| matches!(cmd.action, Action::AlterDrop));
        assert!(renamed);
        assert!(!dropped);
    }

    #[test]
    fn test_fk_ordering_parent_before_child() {
        let old = Schema::default();
        let mut new = Schema::default();
        new.add_table(
            Table::new("child")
                .column(Column::new("id", ColumnType::Serial).primary_key())
                .column(Column::new("parent_id", ColumnType::Int).references("parent", "id")),
        );
        new.add_table(
            Table::new("parent")
                .column(Column::new("id", ColumnType::Serial).primary_key())
                .column(Column::new("name", ColumnType::Text)),
        );

        let cmds = diff_schemas(&old, &new);
        let order = made_tables(&cmds);

        assert_eq!(order.len(), 2);
        let parent_pos = index_of(&order, "parent");
        let child_pos = index_of(&order, "child");
        assert!(
            parent_pos < child_pos,
            "parent table should be created before child with FK"
        );
    }

    #[test]
    fn test_fk_ordering_multiple_dependencies() {
        let old = Schema::default();
        let mut new = Schema::default();
        new.add_table(
            Table::new("order_items")
                .column(Column::new("id", ColumnType::Serial).primary_key())
                .column(Column::new("order_id", ColumnType::Int).references("orders", "id"))
                .column(Column::new("product_id", ColumnType::Int).references("products", "id")),
        );
        new.add_table(
            Table::new("orders")
                .column(Column::new("id", ColumnType::Serial).primary_key())
                .column(Column::new("user_id", ColumnType::Int).references("users", "id")),
        );
        new.add_table(
            Table::new("users").column(Column::new("id", ColumnType::Serial).primary_key()),
        );
        new.add_table(
            Table::new("products").column(Column::new("id", ColumnType::Serial).primary_key()),
        );

        let cmds = diff_schemas(&old, &new);
        let order = made_tables(&cmds);

        assert_eq!(order.len(), 4);
        let users_pos = index_of(&order, "users");
        let products_pos = index_of(&order, "products");
        let orders_pos = index_of(&order, "orders");
        let items_pos = index_of(&order, "order_items");
        assert!(users_pos < orders_pos, "users (0 FK) before orders (1 FK)");
        assert!(
            products_pos < items_pos,
            "products (0 FK) before order_items (2 FK)"
        );
        assert!(
            orders_pos < items_pos,
            "orders (1 FK) before order_items (2 FK)"
        );
    }
}