1use super::schema::{MigrationHint, Schema};
7use crate::ast::{Action, Constraint, Expr, IndexDef, Qail};
8
9pub fn diff_schemas(old: &Schema, new: &Schema) -> Vec<Qail> {
13 let mut cmds = Vec::new();
14
15 for hint in &new.migrations {
17 match hint {
18 MigrationHint::Rename { from, to } => {
19 if let (Some((from_table, from_col)), Some((to_table, to_col))) =
20 (parse_table_col(from), parse_table_col(to))
21 && from_table == to_table
22 {
23 cmds.push(Qail {
25 action: Action::Mod,
26 table: from_table.to_string(),
27 columns: vec![Expr::Named(format!("{} -> {}", from_col, to_col))],
28 ..Default::default()
29 });
30 }
31 }
32 MigrationHint::Transform { expression, target } => {
33 if let Some((table, _col)) = parse_table_col(target) {
34 cmds.push(Qail {
35 action: Action::Set,
36 table: table.to_string(),
37 columns: vec![Expr::Named(format!("/* TRANSFORM: {} */", expression))],
38 ..Default::default()
39 });
40 }
41 }
42 MigrationHint::Drop {
43 target,
44 confirmed: true,
45 } => {
46 if target.contains('.') {
47 if let Some((table, col)) = parse_table_col(target) {
49 cmds.push(Qail {
50 action: Action::AlterDrop,
51 table: table.to_string(),
52 columns: vec![Expr::Named(col.to_string())],
53 ..Default::default()
54 });
55 }
56 } else {
57 cmds.push(Qail {
59 action: Action::Drop,
60 table: target.clone(),
61 ..Default::default()
62 });
63 }
64 }
65 _ => {}
66 }
67 }
68
69 let mut new_table_names: Vec<&String> = new
71 .tables
72 .keys()
73 .filter(|name| !old.tables.contains_key(*name))
74 .collect();
75
76 new_table_names.sort_by_key(|name| {
79 new.tables
80 .get(*name)
81 .map(|t| t.columns.iter().filter(|c| c.foreign_key.is_some()).count())
82 .unwrap_or(0)
83 });
84
85 for name in new_table_names {
87 let table = &new.tables[name];
88 let columns: Vec<Expr> = table
89 .columns
90 .iter()
91 .map(|col| {
92 let mut constraints = Vec::new();
93 if col.primary_key {
94 constraints.push(Constraint::PrimaryKey);
95 }
96 if col.nullable {
97 constraints.push(Constraint::Nullable);
98 }
99 if col.unique {
100 constraints.push(Constraint::Unique);
101 }
102 if let Some(def) = &col.default {
103 constraints.push(Constraint::Default(def.clone()));
104 }
105 if let Some(ref fk) = col.foreign_key {
106 constraints.push(Constraint::References(format!(
107 "{}({})",
108 fk.table, fk.column
109 )));
110 }
111
112 Expr::Def {
113 name: col.name.clone(),
114 data_type: col.data_type.to_pg_type(),
115 constraints,
116 }
117 })
118 .collect();
119
120 cmds.push(Qail {
121 action: Action::Make,
122 table: name.clone(),
123 columns,
124 ..Default::default()
125 });
126 }
127
128 let mut dropped_tables: Vec<&String> = old.tables.keys()
132 .filter(|name| {
133 !new.tables.contains_key(*name) && !new.migrations.iter().any(
134 |h| matches!(h, MigrationHint::Drop { target, confirmed: true } if target == *name),
135 )
136 })
137 .collect();
138
139 dropped_tables.sort_by_key(|name| {
141 std::cmp::Reverse(
142 old.tables
143 .get(*name)
144 .map(|t| t.columns.iter().filter(|c| c.foreign_key.is_some()).count())
145 .unwrap_or(0)
146 )
147 });
148
149 for name in dropped_tables {
150 cmds.push(Qail {
151 action: Action::Drop,
152 table: name.clone(),
153 ..Default::default()
154 });
155 }
156
157 for (name, new_table) in &new.tables {
159 if let Some(old_table) = old.tables.get(name) {
160 let old_cols: std::collections::HashSet<_> =
161 old_table.columns.iter().map(|c| &c.name).collect();
162 let new_cols: std::collections::HashSet<_> =
163 new_table.columns.iter().map(|c| &c.name).collect();
164
165 for col in &new_table.columns {
167 if !old_cols.contains(&col.name) {
168 let is_rename_target = new.migrations.iter().any(|h| {
169 matches!(h, MigrationHint::Rename { to, .. } if to.ends_with(&format!(".{}", col.name)))
170 });
171
172 if !is_rename_target {
173 let mut constraints = Vec::new();
174 if col.nullable {
175 constraints.push(Constraint::Nullable);
176 }
177 if col.unique {
178 constraints.push(Constraint::Unique);
179 }
180 if let Some(def) = &col.default {
181 constraints.push(Constraint::Default(def.clone()));
182 }
183 let data_type = match &col.data_type {
186 super::types::ColumnType::Serial => "INTEGER".to_string(),
187 super::types::ColumnType::BigSerial => "BIGINT".to_string(),
188 other => other.to_pg_type(),
189 };
190
191 cmds.push(Qail {
192 action: Action::Alter,
193 table: name.clone(),
194 columns: vec![Expr::Def {
195 name: col.name.clone(),
196 data_type,
197 constraints,
198 }],
199 ..Default::default()
200 });
201 }
202 }
203 }
204
205 for col in &old_table.columns {
207 if !new_cols.contains(&col.name) {
208 let is_rename_source = new.migrations.iter().any(|h| {
209 matches!(h, MigrationHint::Rename { from, .. } if from.ends_with(&format!(".{}", col.name)))
210 });
211
212 let is_drop_hinted = new.migrations.iter().any(|h| {
213 matches!(h, MigrationHint::Drop { target, confirmed: true } if target == &format!("{}.{}", name, col.name))
214 });
215
216 if !is_rename_source && !is_drop_hinted {
217 cmds.push(Qail {
218 action: Action::AlterDrop,
219 table: name.clone(),
220 columns: vec![Expr::Named(col.name.clone())],
221 ..Default::default()
222 });
223 }
224 }
225 }
226
227 for new_col in &new_table.columns {
229 if let Some(old_col) = old_table.columns.iter().find(|c| c.name == new_col.name) {
230 let old_type = old_col.data_type.to_pg_type();
231 let new_type = new_col.data_type.to_pg_type();
232
233 if old_type != new_type {
234 let safe_new_type = match &new_col.data_type {
237 super::types::ColumnType::Serial => "INTEGER".to_string(),
238 super::types::ColumnType::BigSerial => "BIGINT".to_string(),
239 _ => new_type,
240 };
241
242 cmds.push(Qail {
243 action: Action::AlterType,
244 table: name.clone(),
245 columns: vec![Expr::Def {
246 name: new_col.name.clone(),
247 data_type: safe_new_type,
248 constraints: vec![],
249 }],
250 ..Default::default()
251 });
252 }
253 }
254 }
255 }
256 }
257
258 for new_idx in &new.indexes {
260 let exists = old.indexes.iter().any(|i| i.name == new_idx.name);
261 if !exists {
262 cmds.push(Qail {
263 action: Action::Index,
264 table: String::new(),
265 index_def: Some(IndexDef {
266 name: new_idx.name.clone(),
267 table: new_idx.table.clone(),
268 columns: new_idx.columns.clone(),
269 unique: new_idx.unique,
270 index_type: None,
271 }),
272 ..Default::default()
273 });
274 }
275 }
276
277 for old_idx in &old.indexes {
279 let exists = new.indexes.iter().any(|i| i.name == old_idx.name);
280 if !exists {
281 cmds.push(Qail {
282 action: Action::DropIndex,
283 table: old_idx.name.clone(),
284 ..Default::default()
285 });
286 }
287 }
288
289 cmds
290}
291
/// Splits a `"table.column"` reference at its first `.`.
///
/// Returns `None` when `s` contains no `.`. Everything after the first dot is the
/// column part, so `"a.b.c"` yields `("a", "b.c")`. Either part may be empty.
fn parse_table_col(s: &str) -> Option<(&str, &str)> {
    s.split_once('.')
}
301
#[cfg(test)]
mod tests {
    use super::super::schema::{Column, Table};
    use super::super::types::ColumnType;
    use super::*;

    /// Names of the tables created by `Make` commands, in emission order.
    fn created_tables(cmds: &[Qail]) -> Vec<&str> {
        cmds.iter()
            .filter(|cmd| matches!(cmd.action, Action::Make))
            .map(|cmd| cmd.table.as_str())
            .collect()
    }

    /// Index of `table` within `order`; panics if the table was never created.
    fn position_of(order: &[&str], table: &str) -> usize {
        order.iter().position(|created| *created == table).unwrap()
    }

    #[test]
    fn test_diff_new_table() {
        let mut new = Schema::default();
        new.add_table(
            Table::new("users")
                .column(Column::new("id", ColumnType::Serial).primary_key())
                .column(Column::new("name", ColumnType::Text).not_null()),
        );

        let cmds = diff_schemas(&Schema::default(), &new);
        assert_eq!(cmds.len(), 1);
        assert!(matches!(cmds[0].action, Action::Make));
    }

    #[test]
    fn test_diff_rename_with_hint() {
        let mut old = Schema::default();
        old.add_table(Table::new("users").column(Column::new("username", ColumnType::Text)));

        let mut new = Schema::default();
        new.add_table(Table::new("users").column(Column::new("name", ColumnType::Text)));
        new.add_hint(MigrationHint::Rename {
            from: "users.username".into(),
            to: "users.name".into(),
        });

        let cmds = diff_schemas(&old, &new);
        let renamed = cmds.iter().any(|cmd| matches!(cmd.action, Action::Mod));
        let dropped = cmds.iter().any(|cmd| matches!(cmd.action, Action::AlterDrop));
        assert!(renamed);
        assert!(!dropped);
    }

    #[test]
    fn test_fk_ordering_parent_before_child() {
        let mut new = Schema::default();
        // The child is registered before its parent on purpose.
        new.add_table(
            Table::new("child")
                .column(Column::new("id", ColumnType::Serial).primary_key())
                .column(Column::new("parent_id", ColumnType::Int).references("parent", "id")),
        );
        new.add_table(
            Table::new("parent")
                .column(Column::new("id", ColumnType::Serial).primary_key())
                .column(Column::new("name", ColumnType::Text)),
        );

        let cmds = diff_schemas(&Schema::default(), &new);
        let order = created_tables(&cmds);
        assert_eq!(order.len(), 2);

        let parent = position_of(&order, "parent");
        let child = position_of(&order, "child");
        assert!(parent < child, "parent table should be created before child with FK");
    }

    #[test]
    fn test_fk_ordering_multiple_dependencies() {
        let mut new = Schema::default();
        new.add_table(
            Table::new("order_items")
                .column(Column::new("id", ColumnType::Serial).primary_key())
                .column(Column::new("order_id", ColumnType::Int).references("orders", "id"))
                .column(Column::new("product_id", ColumnType::Int).references("products", "id")),
        );
        new.add_table(
            Table::new("orders")
                .column(Column::new("id", ColumnType::Serial).primary_key())
                .column(Column::new("user_id", ColumnType::Int).references("users", "id")),
        );
        new.add_table(Table::new("users").column(Column::new("id", ColumnType::Serial).primary_key()));
        new.add_table(
            Table::new("products").column(Column::new("id", ColumnType::Serial).primary_key()),
        );

        let cmds = diff_schemas(&Schema::default(), &new);
        let order = created_tables(&cmds);
        assert_eq!(order.len(), 4);

        let [users, products, orders, items] = ["users", "products", "orders", "order_items"]
            .map(|table| position_of(&order, table));

        assert!(users < orders, "users (0 FK) before orders (1 FK)");
        assert!(products < items, "products (0 FK) before order_items (2 FK)");
        assert!(orders < items, "orders (1 FK) before order_items (2 FK)");
    }
}
425