use std::collections::HashSet;

use super::schema::{Column, MigrationHint, Schema, Table};
use super::types::ColumnType;
use crate::ast::{Action, Constraint, Expr, IndexDef, Qail};
8
9pub fn diff_schemas(old: &Schema, new: &Schema) -> Vec<Qail> {
13 let mut cmds = Vec::new();
14
15 for hint in &new.migrations {
17 match hint {
18 MigrationHint::Rename { from, to } => {
19 if let (Some((from_table, from_col)), Some((to_table, to_col))) =
20 (parse_table_col(from), parse_table_col(to))
21 && from_table == to_table
22 {
23 cmds.push(Qail {
25 action: Action::Mod,
26 table: from_table.to_string(),
27 columns: vec![Expr::Named(format!("{} -> {}", from_col, to_col))],
28 ..Default::default()
29 });
30 }
31 }
32 MigrationHint::Transform { expression, target } => {
33 if let Some((table, _col)) = parse_table_col(target) {
34 cmds.push(Qail {
35 action: Action::Set,
36 table: table.to_string(),
37 columns: vec![Expr::Named(format!("/* TRANSFORM: {} */", expression))],
38 ..Default::default()
39 });
40 }
41 }
42 MigrationHint::Drop {
43 target,
44 confirmed: true,
45 } => {
46 if target.contains('.') {
47 if let Some((table, col)) = parse_table_col(target) {
49 cmds.push(Qail {
50 action: Action::AlterDrop,
51 table: table.to_string(),
52 columns: vec![Expr::Named(col.to_string())],
53 ..Default::default()
54 });
55 }
56 } else {
57 cmds.push(Qail {
59 action: Action::Drop,
60 table: target.clone(),
61 ..Default::default()
62 });
63 }
64 }
65 _ => {}
66 }
67 }
68
69 let new_table_names: Vec<&String> = new
71 .tables
72 .keys()
73 .filter(|name| !old.tables.contains_key(*name))
74 .collect();
75
76 let new_set: std::collections::HashSet<&str> =
81 new_table_names.iter().map(|n| n.as_str()).collect();
82 let mut emitted: std::collections::HashSet<&str> = std::collections::HashSet::new();
83 let mut sorted: Vec<&String> = Vec::with_capacity(new_table_names.len());
84 let mut remaining = new_table_names;
85
86 loop {
87 let before = sorted.len();
88 remaining.retain(|name| {
89 let deps_satisfied = new.tables.get(*name).is_none_or(|t| {
90 t.columns.iter().all(|c| {
91 c.foreign_key.as_ref().is_none_or(|fk| {
92 !new_set.contains(fk.table.as_str()) || emitted.contains(fk.table.as_str())
93 })
94 })
95 });
96 if deps_satisfied {
97 emitted.insert(name.as_str());
98 sorted.push(name);
99 false } else {
101 true }
103 });
104 if remaining.is_empty() || sorted.len() == before {
105 sorted.extend(remaining);
107 break;
108 }
109 }
110
111 let new_table_names = sorted;
112
113 for name in new_table_names {
115 let table = &new.tables[name];
116 let columns: Vec<Expr> = table
117 .columns
118 .iter()
119 .map(|col| {
120 let mut constraints = Vec::new();
121 if col.primary_key {
122 constraints.push(Constraint::PrimaryKey);
123 }
124 if col.nullable {
125 constraints.push(Constraint::Nullable);
126 }
127 if col.unique {
128 constraints.push(Constraint::Unique);
129 }
130 if let Some(def) = &col.default {
131 constraints.push(Constraint::Default(def.clone()));
132 }
133 if let Some(ref fk) = col.foreign_key {
134 constraints.push(Constraint::References(format!(
135 "{}({})",
136 fk.table, fk.column
137 )));
138 }
139
140 Expr::Def {
141 name: col.name.clone(),
142 data_type: col.data_type.to_pg_type(),
143 constraints,
144 }
145 })
146 .collect();
147
148 cmds.push(Qail {
149 action: Action::Make,
150 table: name.clone(),
151 columns,
152 ..Default::default()
153 });
154 }
155
156 let mut dropped_tables: Vec<&String> = old
160 .tables
161 .keys()
162 .filter(|name| {
163 !new.tables.contains_key(*name) && !new.migrations.iter().any(
164 |h| matches!(h, MigrationHint::Drop { target, confirmed: true } if target == *name),
165 )
166 })
167 .collect();
168
169 dropped_tables.sort_by_key(|name| {
171 std::cmp::Reverse(
172 old.tables
173 .get(*name)
174 .map(|t| t.columns.iter().filter(|c| c.foreign_key.is_some()).count())
175 .unwrap_or(0),
176 )
177 });
178
179 for name in dropped_tables {
180 cmds.push(Qail {
181 action: Action::Drop,
182 table: name.clone(),
183 ..Default::default()
184 });
185 }
186
187 for (name, new_table) in &new.tables {
189 if let Some(old_table) = old.tables.get(name) {
190 let old_cols: std::collections::HashSet<_> =
191 old_table.columns.iter().map(|c| &c.name).collect();
192 let new_cols: std::collections::HashSet<_> =
193 new_table.columns.iter().map(|c| &c.name).collect();
194
195 for col in &new_table.columns {
197 if !old_cols.contains(&col.name) {
198 let is_rename_target = new.migrations.iter().any(|h| {
199 matches!(h, MigrationHint::Rename { to, .. } if to.ends_with(&format!(".{}", col.name)))
200 });
201
202 if !is_rename_target {
203 let mut constraints = Vec::new();
204 if col.nullable {
205 constraints.push(Constraint::Nullable);
206 }
207 if col.unique {
208 constraints.push(Constraint::Unique);
209 }
210 if let Some(def) = &col.default {
211 constraints.push(Constraint::Default(def.clone()));
212 }
213 let data_type = match &col.data_type {
216 super::types::ColumnType::Serial => "INTEGER".to_string(),
217 super::types::ColumnType::BigSerial => "BIGINT".to_string(),
218 other => other.to_pg_type(),
219 };
220
221 cmds.push(Qail {
222 action: Action::Alter,
223 table: name.clone(),
224 columns: vec![Expr::Def {
225 name: col.name.clone(),
226 data_type,
227 constraints,
228 }],
229 ..Default::default()
230 });
231 }
232 }
233 }
234
235 for col in &old_table.columns {
237 if !new_cols.contains(&col.name) {
238 let is_rename_source = new.migrations.iter().any(|h| {
239 matches!(h, MigrationHint::Rename { from, .. } if from.ends_with(&format!(".{}", col.name)))
240 });
241
242 let is_drop_hinted = new.migrations.iter().any(|h| {
243 matches!(h, MigrationHint::Drop { target, confirmed: true } if target == &format!("{}.{}", name, col.name))
244 });
245
246 if !is_rename_source && !is_drop_hinted {
247 cmds.push(Qail {
248 action: Action::AlterDrop,
249 table: name.clone(),
250 columns: vec![Expr::Named(col.name.clone())],
251 ..Default::default()
252 });
253 }
254 }
255 }
256
257 for new_col in &new_table.columns {
259 if let Some(old_col) = old_table.columns.iter().find(|c| c.name == new_col.name) {
260 let old_type = old_col.data_type.to_pg_type();
261 let new_type = new_col.data_type.to_pg_type();
262
263 if old_type != new_type {
264 let safe_new_type = match &new_col.data_type {
267 super::types::ColumnType::Serial => "INTEGER".to_string(),
268 super::types::ColumnType::BigSerial => "BIGINT".to_string(),
269 _ => new_type,
270 };
271
272 cmds.push(Qail {
273 action: Action::AlterType,
274 table: name.clone(),
275 columns: vec![Expr::Def {
276 name: new_col.name.clone(),
277 data_type: safe_new_type,
278 constraints: vec![],
279 }],
280 ..Default::default()
281 });
282 }
283
284 if old_col.nullable && !new_col.nullable && !new_col.primary_key {
286 cmds.push(Qail {
288 action: Action::AlterSetNotNull,
289 table: name.clone(),
290 columns: vec![Expr::Named(new_col.name.clone())],
291 ..Default::default()
292 });
293 } else if !old_col.nullable && new_col.nullable && !old_col.primary_key {
294 cmds.push(Qail {
296 action: Action::AlterDropNotNull,
297 table: name.clone(),
298 columns: vec![Expr::Named(new_col.name.clone())],
299 ..Default::default()
300 });
301 }
302
303 match (&old_col.default, &new_col.default) {
305 (None, Some(new_default)) => {
306 cmds.push(Qail {
308 action: Action::AlterSetDefault,
309 table: name.clone(),
310 columns: vec![Expr::Named(new_col.name.clone())],
311 payload: Some(new_default.clone()),
312 ..Default::default()
313 });
314 }
315 (Some(_), None) => {
316 cmds.push(Qail {
318 action: Action::AlterDropDefault,
319 table: name.clone(),
320 columns: vec![Expr::Named(new_col.name.clone())],
321 ..Default::default()
322 });
323 }
324 (Some(old_default), Some(new_default)) if old_default != new_default => {
325 cmds.push(Qail {
327 action: Action::AlterSetDefault,
328 table: name.clone(),
329 columns: vec![Expr::Named(new_col.name.clone())],
330 payload: Some(new_default.clone()),
331 ..Default::default()
332 });
333 }
334 _ => {} }
336 }
337 }
338
339 if !old_table.enable_rls && new_table.enable_rls {
341 cmds.push(Qail {
342 action: Action::AlterEnableRls,
343 table: name.clone(),
344 ..Default::default()
345 });
346 } else if old_table.enable_rls && !new_table.enable_rls {
347 cmds.push(Qail {
348 action: Action::AlterDisableRls,
349 table: name.clone(),
350 ..Default::default()
351 });
352 }
353
354 if !old_table.force_rls && new_table.force_rls {
355 cmds.push(Qail {
356 action: Action::AlterForceRls,
357 table: name.clone(),
358 ..Default::default()
359 });
360 } else if old_table.force_rls && !new_table.force_rls {
361 cmds.push(Qail {
362 action: Action::AlterNoForceRls,
363 table: name.clone(),
364 ..Default::default()
365 });
366 }
367 }
368 }
369
370 for new_idx in &new.indexes {
372 let exists = old.indexes.iter().any(|i| i.name == new_idx.name);
373 if !exists {
374 cmds.push(Qail {
375 action: Action::Index,
376 table: String::new(),
377 index_def: Some(IndexDef {
378 name: new_idx.name.clone(),
379 table: new_idx.table.clone(),
380 columns: new_idx.columns.clone(),
381 unique: new_idx.unique,
382 index_type: None,
383 }),
384 ..Default::default()
385 });
386 }
387 }
388
389 for old_idx in &old.indexes {
391 let exists = new.indexes.iter().any(|i| i.name == old_idx.name);
392 if !exists {
393 cmds.push(Qail {
394 action: Action::DropIndex,
395 table: old_idx.name.clone(),
396 ..Default::default()
397 });
398 }
399 }
400
401 cmds
402}
403
/// Splits a `table.column` reference at its first `.`.
///
/// Everything after the first `.` is the column part, so `"a.b.c"` yields
/// `("a", "b.c")`. Returns `None` when there is no `.` or when either side of
/// it is empty (`"users."`, `".id"`), so callers never build commands against
/// an empty table or column name.
fn parse_table_col(s: &str) -> Option<(&str, &str)> {
    s.split_once('.')
        .filter(|(table, col)| !table.is_empty() && !col.is_empty())
}
413
#[cfg(test)]
mod tests {
    use super::super::schema::{Column, Table};
    use super::super::types::ColumnType;
    use super::*;

    /// Table names of the `Make` commands in `cmds`, in emission order.
    fn created_tables(cmds: &[Qail]) -> Vec<&str> {
        cmds.iter()
            .filter(|cmd| matches!(cmd.action, Action::Make))
            .map(|cmd| cmd.table.as_str())
            .collect()
    }

    /// Position of `table` within `order`; panics if it was never created.
    fn creation_index(order: &[&str], table: &str) -> usize {
        order.iter().position(|name| *name == table).unwrap()
    }

    #[test]
    fn test_diff_new_table() {
        let mut new = Schema::default();
        new.add_table(
            Table::new("users")
                .column(Column::new("id", ColumnType::Serial).primary_key())
                .column(Column::new("name", ColumnType::Text).not_null()),
        );

        let cmds = diff_schemas(&Schema::default(), &new);

        assert_eq!(cmds.len(), 1);
        assert!(matches!(cmds[0].action, Action::Make));
    }

    #[test]
    fn test_diff_rename_with_hint() {
        let mut old = Schema::default();
        old.add_table(Table::new("users").column(Column::new("username", ColumnType::Text)));

        let mut new = Schema::default();
        new.add_table(Table::new("users").column(Column::new("name", ColumnType::Text)));
        new.add_hint(MigrationHint::Rename {
            from: "users.username".into(),
            to: "users.name".into(),
        });

        let cmds = diff_schemas(&old, &new);
        let renamed = cmds.iter().any(|cmd| matches!(cmd.action, Action::Mod));
        let dropped = cmds.iter().any(|cmd| matches!(cmd.action, Action::AlterDrop));

        // The hinted rename must not degrade into drop-and-add.
        assert!(renamed);
        assert!(!dropped);
    }

    #[test]
    fn test_fk_ordering_parent_before_child() {
        let mut new = Schema::default();
        // The child is registered first, so emitting in insertion order would be wrong.
        new.add_table(
            Table::new("child")
                .column(Column::new("id", ColumnType::Serial).primary_key())
                .column(Column::new("parent_id", ColumnType::Int).references("parent", "id")),
        );
        new.add_table(
            Table::new("parent")
                .column(Column::new("id", ColumnType::Serial).primary_key())
                .column(Column::new("name", ColumnType::Text)),
        );

        let cmds = diff_schemas(&Schema::default(), &new);
        let order = created_tables(&cmds);

        assert_eq!(order.len(), 2);
        assert!(
            creation_index(&order, "parent") < creation_index(&order, "child"),
            "parent table should be created before child with FK"
        );
    }

    #[test]
    fn test_fk_ordering_multiple_dependencies() {
        let mut new = Schema::default();
        new.add_table(
            Table::new("order_items")
                .column(Column::new("id", ColumnType::Serial).primary_key())
                .column(Column::new("order_id", ColumnType::Int).references("orders", "id"))
                .column(Column::new("product_id", ColumnType::Int).references("products", "id")),
        );
        new.add_table(
            Table::new("orders")
                .column(Column::new("id", ColumnType::Serial).primary_key())
                .column(Column::new("user_id", ColumnType::Int).references("users", "id")),
        );
        new.add_table(
            Table::new("users").column(Column::new("id", ColumnType::Serial).primary_key()),
        );
        new.add_table(
            Table::new("products").column(Column::new("id", ColumnType::Serial).primary_key()),
        );

        let cmds = diff_schemas(&Schema::default(), &new);
        let order = created_tables(&cmds);
        assert_eq!(order.len(), 4);

        let [users, products, orders, items] =
            ["users", "products", "orders", "order_items"].map(|t| creation_index(&order, t));

        assert!(users < orders, "users (0 FK) before orders (1 FK)");
        assert!(
            products < items,
            "products (0 FK) before order_items (2 FK)"
        );
        assert!(
            orders < items,
            "orders (1 FK) before order_items (2 FK)"
        );
    }
}