1use super::schema::{MigrationHint, Schema};
7use crate::ast::{Action, Constraint, Expr, IndexDef, Qail};
8use std::collections::BTreeSet;
9
10fn unsupported_state_diff_features(schema: &Schema) -> BTreeSet<&'static str> {
14 let mut out = BTreeSet::new();
15 if !schema.extensions.is_empty() {
16 out.insert("extensions");
17 }
18 if !schema.comments.is_empty() {
19 out.insert("comments");
20 }
21 if !schema.sequences.is_empty() {
22 out.insert("sequences");
23 }
24 if !schema.enums.is_empty() {
25 out.insert("enums");
26 }
27 if !schema.views.is_empty() {
28 out.insert("views");
29 }
30 if !schema.functions.is_empty() {
31 out.insert("functions");
32 }
33 if !schema.triggers.is_empty() {
34 out.insert("triggers");
35 }
36 if !schema.grants.is_empty() {
37 out.insert("grants");
38 }
39 if !schema.policies.is_empty() {
40 out.insert("policies");
41 }
42 if !schema.resources.is_empty() {
43 out.insert("resources");
44 }
45 out
46}
47
48pub fn validate_state_diff_support(old: &Schema, new: &Schema) -> Result<(), String> {
52 let mut unsupported = unsupported_state_diff_features(old);
53 unsupported.extend(unsupported_state_diff_features(new));
54
55 if unsupported.is_empty() {
56 return Ok(());
57 }
58
59 let detail = unsupported.into_iter().collect::<Vec<_>>().join(", ");
60 Err(format!(
61 "State-based diff currently supports tables, columns, indexes, and migration hints only. \
62 Unsupported schema object families present: {}. \
63 Use folder-based strict migrations for these objects.",
64 detail
65 ))
66}
67
68pub fn diff_schemas_checked(old: &Schema, new: &Schema) -> Result<Vec<Qail>, String> {
70 validate_state_diff_support(old, new)?;
71 Ok(diff_schemas(old, new))
72}
73
74pub fn diff_schemas(old: &Schema, new: &Schema) -> Vec<Qail> {
78 let mut cmds = Vec::new();
79
80 for hint in &new.migrations {
82 match hint {
83 MigrationHint::Rename { from, to } => {
84 if let (Some((from_table, from_col)), Some((to_table, to_col))) =
85 (parse_table_col(from), parse_table_col(to))
86 && from_table == to_table
87 {
88 cmds.push(Qail {
90 action: Action::Mod,
91 table: from_table.to_string(),
92 columns: vec![Expr::Named(format!("{} -> {}", from_col, to_col))],
93 ..Default::default()
94 });
95 }
96 }
97 MigrationHint::Transform { expression, target } => {
98 if let Some((table, _col)) = parse_table_col(target) {
99 cmds.push(Qail {
100 action: Action::Set,
101 table: table.to_string(),
102 columns: vec![Expr::Named(format!("/* TRANSFORM: {} */", expression))],
103 ..Default::default()
104 });
105 }
106 }
107 MigrationHint::Drop {
108 target,
109 confirmed: true,
110 } => {
111 if target.contains('.') {
112 if let Some((table, col)) = parse_table_col(target) {
114 cmds.push(Qail {
115 action: Action::AlterDrop,
116 table: table.to_string(),
117 columns: vec![Expr::Named(col.to_string())],
118 ..Default::default()
119 });
120 }
121 } else {
122 cmds.push(Qail {
124 action: Action::Drop,
125 table: target.clone(),
126 ..Default::default()
127 });
128 }
129 }
130 _ => {}
131 }
132 }
133
134 let new_table_names: Vec<&String> = new
136 .tables
137 .keys()
138 .filter(|name| !old.tables.contains_key(*name))
139 .collect();
140
141 let new_set: std::collections::HashSet<&str> =
146 new_table_names.iter().map(|n| n.as_str()).collect();
147 let mut emitted: std::collections::HashSet<&str> = std::collections::HashSet::new();
148 let mut sorted: Vec<&String> = Vec::with_capacity(new_table_names.len());
149 let mut remaining = new_table_names;
150
151 loop {
152 let before = sorted.len();
153 remaining.retain(|name| {
154 let deps_satisfied = new.tables.get(*name).is_none_or(|t| {
155 t.columns.iter().all(|c| {
156 c.foreign_key.as_ref().is_none_or(|fk| {
157 !new_set.contains(fk.table.as_str()) || emitted.contains(fk.table.as_str())
158 })
159 })
160 });
161 if deps_satisfied {
162 emitted.insert(name.as_str());
163 sorted.push(name);
164 false } else {
166 true }
168 });
169 if remaining.is_empty() || sorted.len() == before {
170 sorted.extend(remaining);
172 break;
173 }
174 }
175
176 let new_table_names = sorted;
177
178 for name in new_table_names {
180 let table = &new.tables[name];
181 let columns: Vec<Expr> = table
182 .columns
183 .iter()
184 .map(|col| {
185 let mut constraints = Vec::new();
186 if col.primary_key {
187 constraints.push(Constraint::PrimaryKey);
188 }
189 if col.nullable {
190 constraints.push(Constraint::Nullable);
191 }
192 if col.unique {
193 constraints.push(Constraint::Unique);
194 }
195 if let Some(def) = &col.default {
196 constraints.push(Constraint::Default(def.clone()));
197 }
198 if let Some(ref fk) = col.foreign_key {
199 constraints.push(Constraint::References(format!(
200 "{}({})",
201 fk.table, fk.column
202 )));
203 }
204
205 Expr::Def {
206 name: col.name.clone(),
207 data_type: col.data_type.to_pg_type(),
208 constraints,
209 }
210 })
211 .collect();
212
213 cmds.push(Qail {
214 action: Action::Make,
215 table: name.clone(),
216 columns,
217 ..Default::default()
218 });
219 }
220
221 let mut dropped_tables: Vec<&String> = old
225 .tables
226 .keys()
227 .filter(|name| {
228 !new.tables.contains_key(*name) && !new.migrations.iter().any(
229 |h| matches!(h, MigrationHint::Drop { target, confirmed: true } if target == *name),
230 )
231 })
232 .collect();
233
234 dropped_tables.sort_by_key(|name| {
236 std::cmp::Reverse(
237 old.tables
238 .get(*name)
239 .map(|t| t.columns.iter().filter(|c| c.foreign_key.is_some()).count())
240 .unwrap_or(0),
241 )
242 });
243
244 for name in dropped_tables {
245 cmds.push(Qail {
246 action: Action::Drop,
247 table: name.clone(),
248 ..Default::default()
249 });
250 }
251
252 for (name, new_table) in &new.tables {
254 if let Some(old_table) = old.tables.get(name) {
255 let old_cols: std::collections::HashSet<_> =
256 old_table.columns.iter().map(|c| &c.name).collect();
257 let new_cols: std::collections::HashSet<_> =
258 new_table.columns.iter().map(|c| &c.name).collect();
259
260 for col in &new_table.columns {
262 if !old_cols.contains(&col.name) {
263 let is_rename_target = new.migrations.iter().any(|h| {
264 matches!(h, MigrationHint::Rename { to, .. } if to.ends_with(&format!(".{}", col.name)))
265 });
266
267 if !is_rename_target {
268 let mut constraints = Vec::new();
269 if col.nullable {
270 constraints.push(Constraint::Nullable);
271 }
272 if col.unique {
273 constraints.push(Constraint::Unique);
274 }
275 if let Some(def) = &col.default {
276 constraints.push(Constraint::Default(def.clone()));
277 }
278 let data_type = match &col.data_type {
281 super::types::ColumnType::Serial => "INTEGER".to_string(),
282 super::types::ColumnType::BigSerial => "BIGINT".to_string(),
283 other => other.to_pg_type(),
284 };
285
286 cmds.push(Qail {
287 action: Action::Alter,
288 table: name.clone(),
289 columns: vec![Expr::Def {
290 name: col.name.clone(),
291 data_type,
292 constraints,
293 }],
294 ..Default::default()
295 });
296 }
297 }
298 }
299
300 for col in &old_table.columns {
302 if !new_cols.contains(&col.name) {
303 let is_rename_source = new.migrations.iter().any(|h| {
304 matches!(h, MigrationHint::Rename { from, .. } if from.ends_with(&format!(".{}", col.name)))
305 });
306
307 let is_drop_hinted = new.migrations.iter().any(|h| {
308 matches!(h, MigrationHint::Drop { target, confirmed: true } if target == &format!("{}.{}", name, col.name))
309 });
310
311 if !is_rename_source && !is_drop_hinted {
312 cmds.push(Qail {
313 action: Action::AlterDrop,
314 table: name.clone(),
315 columns: vec![Expr::Named(col.name.clone())],
316 ..Default::default()
317 });
318 }
319 }
320 }
321
322 for new_col in &new_table.columns {
324 if let Some(old_col) = old_table.columns.iter().find(|c| c.name == new_col.name) {
325 let old_type = old_col.data_type.to_pg_type();
326 let new_type = new_col.data_type.to_pg_type();
327
328 if old_type != new_type {
329 let safe_new_type = match &new_col.data_type {
332 super::types::ColumnType::Serial => "INTEGER".to_string(),
333 super::types::ColumnType::BigSerial => "BIGINT".to_string(),
334 _ => new_type,
335 };
336
337 cmds.push(Qail {
338 action: Action::AlterType,
339 table: name.clone(),
340 columns: vec![Expr::Def {
341 name: new_col.name.clone(),
342 data_type: safe_new_type,
343 constraints: vec![],
344 }],
345 ..Default::default()
346 });
347 }
348
349 if old_col.nullable && !new_col.nullable && !new_col.primary_key {
351 cmds.push(Qail {
353 action: Action::AlterSetNotNull,
354 table: name.clone(),
355 columns: vec![Expr::Named(new_col.name.clone())],
356 ..Default::default()
357 });
358 } else if !old_col.nullable && new_col.nullable && !old_col.primary_key {
359 cmds.push(Qail {
361 action: Action::AlterDropNotNull,
362 table: name.clone(),
363 columns: vec![Expr::Named(new_col.name.clone())],
364 ..Default::default()
365 });
366 }
367
368 match (&old_col.default, &new_col.default) {
370 (None, Some(new_default)) => {
371 cmds.push(Qail {
373 action: Action::AlterSetDefault,
374 table: name.clone(),
375 columns: vec![Expr::Named(new_col.name.clone())],
376 payload: Some(new_default.clone()),
377 ..Default::default()
378 });
379 }
380 (Some(_), None) => {
381 cmds.push(Qail {
383 action: Action::AlterDropDefault,
384 table: name.clone(),
385 columns: vec![Expr::Named(new_col.name.clone())],
386 ..Default::default()
387 });
388 }
389 (Some(old_default), Some(new_default)) if old_default != new_default => {
390 cmds.push(Qail {
392 action: Action::AlterSetDefault,
393 table: name.clone(),
394 columns: vec![Expr::Named(new_col.name.clone())],
395 payload: Some(new_default.clone()),
396 ..Default::default()
397 });
398 }
399 _ => {} }
401 }
402 }
403
404 if !old_table.enable_rls && new_table.enable_rls {
406 cmds.push(Qail {
407 action: Action::AlterEnableRls,
408 table: name.clone(),
409 ..Default::default()
410 });
411 } else if old_table.enable_rls && !new_table.enable_rls {
412 cmds.push(Qail {
413 action: Action::AlterDisableRls,
414 table: name.clone(),
415 ..Default::default()
416 });
417 }
418
419 if !old_table.force_rls && new_table.force_rls {
420 cmds.push(Qail {
421 action: Action::AlterForceRls,
422 table: name.clone(),
423 ..Default::default()
424 });
425 } else if old_table.force_rls && !new_table.force_rls {
426 cmds.push(Qail {
427 action: Action::AlterNoForceRls,
428 table: name.clone(),
429 ..Default::default()
430 });
431 }
432 }
433 }
434
435 for new_idx in &new.indexes {
437 let exists = old.indexes.iter().any(|i| i.name == new_idx.name);
438 if !exists {
439 cmds.push(Qail {
440 action: Action::Index,
441 table: String::new(),
442 index_def: Some(IndexDef {
443 name: new_idx.name.clone(),
444 table: new_idx.table.clone(),
445 columns: new_idx.columns.clone(),
446 unique: new_idx.unique,
447 index_type: None,
448 where_clause: None,
449 }),
450 ..Default::default()
451 });
452 }
453 }
454
455 for old_idx in &old.indexes {
457 let exists = new.indexes.iter().any(|i| i.name == old_idx.name);
458 if !exists {
459 cmds.push(Qail {
460 action: Action::DropIndex,
461 table: old_idx.name.clone(),
462 ..Default::default()
463 });
464 }
465 }
466
467 cmds
468}
469
/// Splits a `table.column` path at its first `.`.
///
/// Everything after the first `.` belongs to the column part, so `"a.b.c"` yields
/// `("a", "b.c")`. Returns `None` when `s` has no `.` or when either side of the
/// split is empty (e.g. `"users."` or `".name"`), since an empty identifier cannot
/// name a table or a column.
fn parse_table_col(s: &str) -> Option<(&str, &str)> {
    s.split_once('.')
        .filter(|(table, col)| !table.is_empty() && !col.is_empty())
}
479
#[cfg(test)]
mod tests {
    use super::super::schema::{Column, Table, ViewDef};
    use super::super::types::ColumnType;
    use super::*;

    /// Table names of every `Make` command in `cmds`, in emission order.
    fn created_tables(cmds: &[Qail]) -> Vec<&str> {
        cmds.iter()
            .filter(|cmd| matches!(cmd.action, Action::Make))
            .map(|cmd| cmd.table.as_str())
            .collect()
    }

    /// Position of `table` within `order`; panics if the table was never created.
    fn creation_index(order: &[&str], table: &str) -> usize {
        order.iter().position(|created| *created == table).unwrap()
    }

    #[test]
    fn test_diff_new_table() {
        let mut new = Schema::default();
        new.add_table(
            Table::new("users")
                .column(Column::new("id", ColumnType::Serial).primary_key())
                .column(Column::new("name", ColumnType::Text).not_null()),
        );

        let cmds = diff_schemas(&Schema::default(), &new);
        assert_eq!(cmds.len(), 1);
        assert!(matches!(cmds[0].action, Action::Make));
    }

    #[test]
    fn state_diff_support_rejects_non_table_object_families() {
        let mut with_view = Schema::default();
        with_view.add_view(ViewDef::new("active_users", "SELECT 1"));

        let message = validate_state_diff_support(&Schema::default(), &with_view)
            .expect_err("state-based diff should reject unsupported view objects");
        assert!(
            message.contains("views"),
            "error should include unsupported family name"
        );
    }

    #[test]
    fn state_diff_checked_passes_for_table_index_only_schema() {
        let mut new = Schema::default();
        new.add_table(Table::new("users").column(Column::new("id", ColumnType::Serial)));

        let cmds = diff_schemas_checked(&Schema::default(), &new)
            .expect("table/index-only schema should pass");
        let produced_make = cmds.iter().any(|c| matches!(c.action, Action::Make));
        assert!(
            produced_make,
            "checked diff should still produce normal table commands"
        );
    }

    #[test]
    fn test_diff_rename_with_hint() {
        let mut old = Schema::default();
        old.add_table(Table::new("users").column(Column::new("username", ColumnType::Text)));

        let mut new = Schema::default();
        new.add_table(Table::new("users").column(Column::new("name", ColumnType::Text)));
        new.add_hint(MigrationHint::Rename {
            from: "users.username".into(),
            to: "users.name".into(),
        });

        let cmds = diff_schemas(&old, &new);
        let renamed = cmds.iter().any(|c| matches!(c.action, Action::Mod));
        let dropped = cmds.iter().any(|c| matches!(c.action, Action::AlterDrop));
        assert!(renamed);
        // The hint must replace the implicit drop of the old column.
        assert!(!dropped);
    }

    #[test]
    fn test_fk_ordering_parent_before_child() {
        let mut new = Schema::default();
        // The child is declared before its parent.
        new.add_table(
            Table::new("child")
                .column(Column::new("id", ColumnType::Serial).primary_key())
                .column(Column::new("parent_id", ColumnType::Int).references("parent", "id")),
        );
        new.add_table(
            Table::new("parent")
                .column(Column::new("id", ColumnType::Serial).primary_key())
                .column(Column::new("name", ColumnType::Text)),
        );

        let cmds = diff_schemas(&Schema::default(), &new);
        let order = created_tables(&cmds);
        assert_eq!(order.len(), 2);

        assert!(
            creation_index(&order, "parent") < creation_index(&order, "child"),
            "parent table should be created before child with FK"
        );
    }

    #[test]
    fn test_fk_ordering_multiple_dependencies() {
        let mut new = Schema::default();
        // Declared from most-dependent to least-dependent.
        new.add_table(
            Table::new("order_items")
                .column(Column::new("id", ColumnType::Serial).primary_key())
                .column(Column::new("order_id", ColumnType::Int).references("orders", "id"))
                .column(Column::new("product_id", ColumnType::Int).references("products", "id")),
        );
        new.add_table(
            Table::new("orders")
                .column(Column::new("id", ColumnType::Serial).primary_key())
                .column(Column::new("user_id", ColumnType::Int).references("users", "id")),
        );
        new.add_table(
            Table::new("users").column(Column::new("id", ColumnType::Serial).primary_key()),
        );
        new.add_table(
            Table::new("products").column(Column::new("id", ColumnType::Serial).primary_key()),
        );

        let cmds = diff_schemas(&Schema::default(), &new);
        let order = created_tables(&cmds);
        assert_eq!(order.len(), 4);

        let [users, products, orders, items] = ["users", "products", "orders", "order_items"]
            .map(|table| creation_index(&order, table));

        assert!(users < orders, "users (0 FK) before orders (1 FK)");
        assert!(
            products < items,
            "products (0 FK) before order_items (2 FK)"
        );
        assert!(
            orders < items,
            "orders (1 FK) before order_items (2 FK)"
        );
    }
}