1use super::schema::{MigrationHint, Schema};
7use crate::ast::{Action, Constraint, Expr, IndexDef, Qail};
8
9pub fn diff_schemas(old: &Schema, new: &Schema) -> Vec<Qail> {
13 let mut cmds = Vec::new();
14
15 for hint in &new.migrations {
17 match hint {
18 MigrationHint::Rename { from, to } => {
19 if let (Some((from_table, from_col)), Some((to_table, to_col))) =
20 (parse_table_col(from), parse_table_col(to))
21 && from_table == to_table
22 {
23 cmds.push(Qail {
25 action: Action::Mod,
26 table: from_table.to_string(),
27 columns: vec![Expr::Named(format!("{} -> {}", from_col, to_col))],
28 ..Default::default()
29 });
30 }
31 }
32 MigrationHint::Transform { expression, target } => {
33 if let Some((table, _col)) = parse_table_col(target) {
34 cmds.push(Qail {
35 action: Action::Set,
36 table: table.to_string(),
37 columns: vec![Expr::Named(format!("/* TRANSFORM: {} */", expression))],
38 ..Default::default()
39 });
40 }
41 }
42 MigrationHint::Drop {
43 target,
44 confirmed: true,
45 } => {
46 if target.contains('.') {
47 if let Some((table, col)) = parse_table_col(target) {
49 cmds.push(Qail {
50 action: Action::AlterDrop,
51 table: table.to_string(),
52 columns: vec![Expr::Named(col.to_string())],
53 ..Default::default()
54 });
55 }
56 } else {
57 cmds.push(Qail {
59 action: Action::Drop,
60 table: target.clone(),
61 ..Default::default()
62 });
63 }
64 }
65 _ => {}
66 }
67 }
68
69 let mut new_table_names: Vec<&String> = new
71 .tables
72 .keys()
73 .filter(|name| !old.tables.contains_key(*name))
74 .collect();
75
76 new_table_names.sort_by_key(|name| {
79 new.tables
80 .get(*name)
81 .map(|t| t.columns.iter().filter(|c| c.foreign_key.is_some()).count())
82 .unwrap_or(0)
83 });
84
85 for name in new_table_names {
87 let table = &new.tables[name];
88 let columns: Vec<Expr> = table
89 .columns
90 .iter()
91 .map(|col| {
92 let mut constraints = Vec::new();
93 if col.primary_key {
94 constraints.push(Constraint::PrimaryKey);
95 }
96 if col.nullable {
97 constraints.push(Constraint::Nullable);
98 }
99 if col.unique {
100 constraints.push(Constraint::Unique);
101 }
102 if let Some(def) = &col.default {
103 constraints.push(Constraint::Default(def.clone()));
104 }
105 if let Some(ref fk) = col.foreign_key {
106 constraints.push(Constraint::References(format!(
107 "{}({})",
108 fk.table, fk.column
109 )));
110 }
111
112 Expr::Def {
113 name: col.name.clone(),
114 data_type: col.data_type.to_pg_type(),
115 constraints,
116 }
117 })
118 .collect();
119
120 cmds.push(Qail {
121 action: Action::Make,
122 table: name.clone(),
123 columns,
124 ..Default::default()
125 });
126 }
127
128 let mut dropped_tables: Vec<&String> = old.tables.keys()
132 .filter(|name| {
133 !new.tables.contains_key(*name) && !new.migrations.iter().any(
134 |h| matches!(h, MigrationHint::Drop { target, confirmed: true } if target == *name),
135 )
136 })
137 .collect();
138
139 dropped_tables.sort_by_key(|name| {
141 std::cmp::Reverse(
142 old.tables
143 .get(*name)
144 .map(|t| t.columns.iter().filter(|c| c.foreign_key.is_some()).count())
145 .unwrap_or(0)
146 )
147 });
148
149 for name in dropped_tables {
150 cmds.push(Qail {
151 action: Action::Drop,
152 table: name.clone(),
153 ..Default::default()
154 });
155 }
156
157 for (name, new_table) in &new.tables {
159 if let Some(old_table) = old.tables.get(name) {
160 let old_cols: std::collections::HashSet<_> =
161 old_table.columns.iter().map(|c| &c.name).collect();
162 let new_cols: std::collections::HashSet<_> =
163 new_table.columns.iter().map(|c| &c.name).collect();
164
165 for col in &new_table.columns {
167 if !old_cols.contains(&col.name) {
168 let is_rename_target = new.migrations.iter().any(|h| {
169 matches!(h, MigrationHint::Rename { to, .. } if to.ends_with(&format!(".{}", col.name)))
170 });
171
172 if !is_rename_target {
173 let mut constraints = Vec::new();
174 if col.nullable {
175 constraints.push(Constraint::Nullable);
176 }
177 if col.unique {
178 constraints.push(Constraint::Unique);
179 }
180 if let Some(def) = &col.default {
181 constraints.push(Constraint::Default(def.clone()));
182 }
183 let data_type = match &col.data_type {
186 super::types::ColumnType::Serial => "INTEGER".to_string(),
187 super::types::ColumnType::BigSerial => "BIGINT".to_string(),
188 other => other.to_pg_type(),
189 };
190
191 cmds.push(Qail {
192 action: Action::Alter,
193 table: name.clone(),
194 columns: vec![Expr::Def {
195 name: col.name.clone(),
196 data_type,
197 constraints,
198 }],
199 ..Default::default()
200 });
201 }
202 }
203 }
204
205 for col in &old_table.columns {
207 if !new_cols.contains(&col.name) {
208 let is_rename_source = new.migrations.iter().any(|h| {
209 matches!(h, MigrationHint::Rename { from, .. } if from.ends_with(&format!(".{}", col.name)))
210 });
211
212 let is_drop_hinted = new.migrations.iter().any(|h| {
213 matches!(h, MigrationHint::Drop { target, confirmed: true } if target == &format!("{}.{}", name, col.name))
214 });
215
216 if !is_rename_source && !is_drop_hinted {
217 cmds.push(Qail {
218 action: Action::AlterDrop,
219 table: name.clone(),
220 columns: vec![Expr::Named(col.name.clone())],
221 ..Default::default()
222 });
223 }
224 }
225 }
226
227 for new_col in &new_table.columns {
229 if let Some(old_col) = old_table.columns.iter().find(|c| c.name == new_col.name) {
230 let old_type = old_col.data_type.to_pg_type();
231 let new_type = new_col.data_type.to_pg_type();
232
233 if old_type != new_type {
234 let safe_new_type = match &new_col.data_type {
237 super::types::ColumnType::Serial => "INTEGER".to_string(),
238 super::types::ColumnType::BigSerial => "BIGINT".to_string(),
239 _ => new_type,
240 };
241
242 cmds.push(Qail {
243 action: Action::AlterType,
244 table: name.clone(),
245 columns: vec![Expr::Def {
246 name: new_col.name.clone(),
247 data_type: safe_new_type,
248 constraints: vec![],
249 }],
250 ..Default::default()
251 });
252 }
253
254 if old_col.nullable && !new_col.nullable && !new_col.primary_key {
256 cmds.push(Qail {
258 action: Action::AlterSetNotNull,
259 table: name.clone(),
260 columns: vec![Expr::Named(new_col.name.clone())],
261 ..Default::default()
262 });
263 } else if !old_col.nullable && new_col.nullable && !old_col.primary_key {
264 cmds.push(Qail {
266 action: Action::AlterDropNotNull,
267 table: name.clone(),
268 columns: vec![Expr::Named(new_col.name.clone())],
269 ..Default::default()
270 });
271 }
272
273 match (&old_col.default, &new_col.default) {
275 (None, Some(new_default)) => {
276 cmds.push(Qail {
278 action: Action::AlterSetDefault,
279 table: name.clone(),
280 columns: vec![Expr::Named(new_col.name.clone())],
281 payload: Some(new_default.clone()),
282 ..Default::default()
283 });
284 }
285 (Some(_), None) => {
286 cmds.push(Qail {
288 action: Action::AlterDropDefault,
289 table: name.clone(),
290 columns: vec![Expr::Named(new_col.name.clone())],
291 ..Default::default()
292 });
293 }
294 (Some(old_default), Some(new_default)) if old_default != new_default => {
295 cmds.push(Qail {
297 action: Action::AlterSetDefault,
298 table: name.clone(),
299 columns: vec![Expr::Named(new_col.name.clone())],
300 payload: Some(new_default.clone()),
301 ..Default::default()
302 });
303 }
304 _ => {} }
306 }
307 }
308
309 if !old_table.enable_rls && new_table.enable_rls {
311 cmds.push(Qail {
312 action: Action::AlterEnableRls,
313 table: name.clone(),
314 ..Default::default()
315 });
316 } else if old_table.enable_rls && !new_table.enable_rls {
317 cmds.push(Qail {
318 action: Action::AlterDisableRls,
319 table: name.clone(),
320 ..Default::default()
321 });
322 }
323
324 if !old_table.force_rls && new_table.force_rls {
325 cmds.push(Qail {
326 action: Action::AlterForceRls,
327 table: name.clone(),
328 ..Default::default()
329 });
330 } else if old_table.force_rls && !new_table.force_rls {
331 cmds.push(Qail {
332 action: Action::AlterNoForceRls,
333 table: name.clone(),
334 ..Default::default()
335 });
336 }
337 }
338 }
339
340 for new_idx in &new.indexes {
342 let exists = old.indexes.iter().any(|i| i.name == new_idx.name);
343 if !exists {
344 cmds.push(Qail {
345 action: Action::Index,
346 table: String::new(),
347 index_def: Some(IndexDef {
348 name: new_idx.name.clone(),
349 table: new_idx.table.clone(),
350 columns: new_idx.columns.clone(),
351 unique: new_idx.unique,
352 index_type: None,
353 }),
354 ..Default::default()
355 });
356 }
357 }
358
359 for old_idx in &old.indexes {
361 let exists = new.indexes.iter().any(|i| i.name == old_idx.name);
362 if !exists {
363 cmds.push(Qail {
364 action: Action::DropIndex,
365 table: old_idx.name.clone(),
366 ..Default::default()
367 });
368 }
369 }
370
371 cmds
372}
373
/// Splits a `table.column` reference at its first `.`.
///
/// Returns `Some((table, column))` when `s` contains a dot; everything after
/// the first dot (including any further dots) is returned as the column part.
/// Returns `None` when `s` contains no dot.
fn parse_table_col(s: &str) -> Option<(&str, &str)> {
    s.split_once('.')
}
383
#[cfg(test)]
mod tests {
    use super::super::schema::{Column, Table};
    use super::super::types::ColumnType;
    use super::*;

    /// Table names targeted by the `Make` commands in `cmds`, in emission order.
    fn created_tables(cmds: &[Qail]) -> Vec<&str> {
        cmds.iter()
            .filter(|cmd| matches!(cmd.action, Action::Make))
            .map(|cmd| cmd.table.as_str())
            .collect()
    }

    /// Position of `table` within `order`; panics when the table is absent.
    fn rank(order: &[&str], table: &str) -> usize {
        order.iter().position(|name| *name == table).unwrap()
    }

    /// A table that exists only in the new schema yields a single `Make`.
    #[test]
    fn test_diff_new_table() {
        let old = Schema::default();
        let mut new = Schema::default();
        let users = Table::new("users")
            .column(Column::new("id", ColumnType::Serial).primary_key())
            .column(Column::new("name", ColumnType::Text).not_null());
        new.add_table(users);

        let cmds = diff_schemas(&old, &new);

        assert_eq!(cmds.len(), 1);
        assert!(matches!(cmds[0].action, Action::Make));
    }

    /// A rename hint produces a `Mod` and suppresses the column drop.
    #[test]
    fn test_diff_rename_with_hint() {
        let mut old = Schema::default();
        old.add_table(Table::new("users").column(Column::new("username", ColumnType::Text)));

        let mut new = Schema::default();
        new.add_table(Table::new("users").column(Column::new("name", ColumnType::Text)));
        new.add_hint(MigrationHint::Rename {
            from: "users.username".into(),
            to: "users.name".into(),
        });

        let cmds = diff_schemas(&old, &new);

        assert!(cmds.iter().any(|c| matches!(c.action, Action::Mod)));
        assert!(cmds.iter().all(|c| !matches!(c.action, Action::AlterDrop)));
    }

    /// The referenced table is created before the table holding the FK,
    /// even when the child is registered first.
    #[test]
    fn test_fk_ordering_parent_before_child() {
        let old = Schema::default();
        let mut new = Schema::default();

        let child = Table::new("child")
            .column(Column::new("id", ColumnType::Serial).primary_key())
            .column(Column::new("parent_id", ColumnType::Int).references("parent", "id"));
        let parent = Table::new("parent")
            .column(Column::new("id", ColumnType::Serial).primary_key())
            .column(Column::new("name", ColumnType::Text));
        new.add_table(child);
        new.add_table(parent);

        let cmds = diff_schemas(&old, &new);
        let order = created_tables(&cmds);

        assert_eq!(order.len(), 2);
        assert!(
            rank(&order, "parent") < rank(&order, "child"),
            "parent table should be created before child with FK"
        );
    }

    /// Tables with more foreign keys are created after the ones they reference.
    #[test]
    fn test_fk_ordering_multiple_dependencies() {
        let old = Schema::default();
        let mut new = Schema::default();

        let order_items = Table::new("order_items")
            .column(Column::new("id", ColumnType::Serial).primary_key())
            .column(Column::new("order_id", ColumnType::Int).references("orders", "id"))
            .column(Column::new("product_id", ColumnType::Int).references("products", "id"));
        let orders = Table::new("orders")
            .column(Column::new("id", ColumnType::Serial).primary_key())
            .column(Column::new("user_id", ColumnType::Int).references("users", "id"));
        let users = Table::new("users").column(Column::new("id", ColumnType::Serial).primary_key());
        let products =
            Table::new("products").column(Column::new("id", ColumnType::Serial).primary_key());
        new.add_table(order_items);
        new.add_table(orders);
        new.add_table(users);
        new.add_table(products);

        let cmds = diff_schemas(&old, &new);
        let order = created_tables(&cmds);

        assert_eq!(order.len(), 4);

        let users_idx = rank(&order, "users");
        let products_idx = rank(&order, "products");
        let orders_idx = rank(&order, "orders");
        let items_idx = rank(&order, "order_items");

        assert!(users_idx < orders_idx, "users (0 FK) before orders (1 FK)");
        assert!(products_idx < items_idx, "products (0 FK) before order_items (2 FK)");
        assert!(orders_idx < items_idx, "orders (1 FK) before order_items (2 FK)");
    }
}
507