pgmold/diff/
mod.rs

1pub mod planner;
2
3use crate::model::{
4    Column, EnumType, ForeignKey, Function, Index, PgType, Policy, PrimaryKey, Table,
5};
6
7#[derive(Debug, Clone, PartialEq, Eq)]
8pub enum MigrationOp {
9    CreateEnum(EnumType),
10    DropEnum(String),
11    CreateTable(Table),
12    DropTable(String),
13    AddColumn {
14        table: String,
15        column: Column,
16    },
17    DropColumn {
18        table: String,
19        column: String,
20    },
21    AlterColumn {
22        table: String,
23        column: String,
24        changes: ColumnChanges,
25    },
26    AddPrimaryKey {
27        table: String,
28        primary_key: PrimaryKey,
29    },
30    DropPrimaryKey {
31        table: String,
32    },
33    AddIndex {
34        table: String,
35        index: Index,
36    },
37    DropIndex {
38        table: String,
39        index_name: String,
40    },
41    AddForeignKey {
42        table: String,
43        foreign_key: ForeignKey,
44    },
45    DropForeignKey {
46        table: String,
47        foreign_key_name: String,
48    },
49    EnableRls {
50        table: String,
51    },
52    DisableRls {
53        table: String,
54    },
55    CreatePolicy(Policy),
56    DropPolicy {
57        table: String,
58        name: String,
59    },
60    AlterPolicy {
61        table: String,
62        name: String,
63        changes: PolicyChanges,
64    },
65    CreateFunction(Function),
66    DropFunction {
67        name: String,
68        args: String,
69    },
70    AlterFunction {
71        name: String,
72        args: String,
73        new_function: Function,
74    },
75}
76
77#[derive(Debug, Clone, PartialEq, Eq)]
78pub struct PolicyChanges {
79    pub roles: Option<Vec<String>>,
80    pub using_expr: Option<Option<String>>,
81    pub check_expr: Option<Option<String>>,
82}
83
84#[derive(Debug, Clone, PartialEq, Eq)]
85pub struct ColumnChanges {
86    pub data_type: Option<PgType>,
87    pub nullable: Option<bool>,
88    pub default: Option<Option<String>>,
89}
90
91use crate::model::Schema;
92
93pub fn compute_diff(from: &Schema, to: &Schema) -> Vec<MigrationOp> {
94    let mut ops = Vec::new();
95
96    ops.extend(diff_enums(from, to));
97    ops.extend(diff_tables(from, to));
98    ops.extend(diff_functions(from, to));
99
100    for (name, to_table) in &to.tables {
101        if let Some(from_table) = from.tables.get(name) {
102            ops.extend(diff_columns(from_table, to_table));
103            ops.extend(diff_primary_keys(from_table, to_table));
104            ops.extend(diff_indexes(from_table, to_table));
105            ops.extend(diff_foreign_keys(from_table, to_table));
106            ops.extend(diff_rls(from_table, to_table));
107            ops.extend(diff_policies(from_table, to_table));
108        }
109    }
110
111    ops
112}
113
114fn diff_enums(from: &Schema, to: &Schema) -> Vec<MigrationOp> {
115    let mut ops = Vec::new();
116
117    for (name, enum_type) in &to.enums {
118        if !from.enums.contains_key(name) {
119            ops.push(MigrationOp::CreateEnum(enum_type.clone()));
120        }
121    }
122
123    for name in from.enums.keys() {
124        if !to.enums.contains_key(name) {
125            ops.push(MigrationOp::DropEnum(name.clone()));
126        }
127    }
128
129    ops
130}
131
132fn diff_tables(from: &Schema, to: &Schema) -> Vec<MigrationOp> {
133    let mut ops = Vec::new();
134
135    for (name, table) in &to.tables {
136        if !from.tables.contains_key(name) {
137            ops.push(MigrationOp::CreateTable(table.clone()));
138        }
139    }
140
141    for name in from.tables.keys() {
142        if !to.tables.contains_key(name) {
143            ops.push(MigrationOp::DropTable(name.clone()));
144        }
145    }
146
147    ops
148}
149
150fn diff_functions(from: &Schema, to: &Schema) -> Vec<MigrationOp> {
151    let mut ops = Vec::new();
152
153    for (sig, func) in &to.functions {
154        if let Some(from_func) = from.functions.get(sig) {
155            if from_func != func {
156                ops.push(MigrationOp::AlterFunction {
157                    name: func.name.clone(),
158                    args: func
159                        .arguments
160                        .iter()
161                        .map(|a| a.data_type.clone())
162                        .collect::<Vec<_>>()
163                        .join(", "),
164                    new_function: func.clone(),
165                });
166            }
167        } else {
168            ops.push(MigrationOp::CreateFunction(func.clone()));
169        }
170    }
171
172    for (sig, func) in &from.functions {
173        if !to.functions.contains_key(sig) {
174            ops.push(MigrationOp::DropFunction {
175                name: func.name.clone(),
176                args: func
177                    .arguments
178                    .iter()
179                    .map(|a| a.data_type.clone())
180                    .collect::<Vec<_>>()
181                    .join(", "),
182            });
183        }
184    }
185
186    ops
187}
188
189fn diff_columns(from_table: &Table, to_table: &Table) -> Vec<MigrationOp> {
190    let mut ops = Vec::new();
191
192    for (name, column) in &to_table.columns {
193        if let Some(from_column) = from_table.columns.get(name) {
194            let changes = compute_column_changes(from_column, column);
195            if changes.data_type.is_some()
196                || changes.nullable.is_some()
197                || changes.default.is_some()
198            {
199                ops.push(MigrationOp::AlterColumn {
200                    table: to_table.name.clone(),
201                    column: name.clone(),
202                    changes,
203                });
204            }
205        } else {
206            ops.push(MigrationOp::AddColumn {
207                table: to_table.name.clone(),
208                column: column.clone(),
209            });
210        }
211    }
212
213    for name in from_table.columns.keys() {
214        if !to_table.columns.contains_key(name) {
215            ops.push(MigrationOp::DropColumn {
216                table: from_table.name.clone(),
217                column: name.clone(),
218            });
219        }
220    }
221
222    ops
223}
224
225fn compute_column_changes(from: &Column, to: &Column) -> ColumnChanges {
226    ColumnChanges {
227        data_type: if from.data_type != to.data_type {
228            Some(to.data_type.clone())
229        } else {
230            None
231        },
232        nullable: if from.nullable != to.nullable {
233            Some(to.nullable)
234        } else {
235            None
236        },
237        default: if from.default != to.default {
238            Some(to.default.clone())
239        } else {
240            None
241        },
242    }
243}
244
245fn diff_primary_keys(from_table: &Table, to_table: &Table) -> Vec<MigrationOp> {
246    let mut ops = Vec::new();
247
248    match (&from_table.primary_key, &to_table.primary_key) {
249        (None, Some(pk)) => {
250            ops.push(MigrationOp::AddPrimaryKey {
251                table: to_table.name.clone(),
252                primary_key: pk.clone(),
253            });
254        }
255        (Some(_), None) => {
256            ops.push(MigrationOp::DropPrimaryKey {
257                table: from_table.name.clone(),
258            });
259        }
260        (Some(from_pk), Some(to_pk)) if from_pk != to_pk => {
261            ops.push(MigrationOp::DropPrimaryKey {
262                table: from_table.name.clone(),
263            });
264            ops.push(MigrationOp::AddPrimaryKey {
265                table: to_table.name.clone(),
266                primary_key: to_pk.clone(),
267            });
268        }
269        _ => {}
270    }
271
272    ops
273}
274
275fn diff_indexes(from_table: &Table, to_table: &Table) -> Vec<MigrationOp> {
276    let mut ops = Vec::new();
277
278    for index in &to_table.indexes {
279        if !from_table.indexes.iter().any(|i| i.name == index.name) {
280            ops.push(MigrationOp::AddIndex {
281                table: to_table.name.clone(),
282                index: index.clone(),
283            });
284        }
285    }
286
287    for index in &from_table.indexes {
288        if !to_table.indexes.iter().any(|i| i.name == index.name) {
289            ops.push(MigrationOp::DropIndex {
290                table: from_table.name.clone(),
291                index_name: index.name.clone(),
292            });
293        }
294    }
295
296    ops
297}
298
299fn diff_foreign_keys(from_table: &Table, to_table: &Table) -> Vec<MigrationOp> {
300    let mut ops = Vec::new();
301
302    for foreign_key in &to_table.foreign_keys {
303        if !from_table
304            .foreign_keys
305            .iter()
306            .any(|fk| fk.name == foreign_key.name)
307        {
308            ops.push(MigrationOp::AddForeignKey {
309                table: to_table.name.clone(),
310                foreign_key: foreign_key.clone(),
311            });
312        }
313    }
314
315    for foreign_key in &from_table.foreign_keys {
316        if !to_table
317            .foreign_keys
318            .iter()
319            .any(|fk| fk.name == foreign_key.name)
320        {
321            ops.push(MigrationOp::DropForeignKey {
322                table: from_table.name.clone(),
323                foreign_key_name: foreign_key.name.clone(),
324            });
325        }
326    }
327
328    ops
329}
330
331fn diff_rls(from_table: &Table, to_table: &Table) -> Vec<MigrationOp> {
332    let mut ops = Vec::new();
333
334    if !from_table.row_level_security && to_table.row_level_security {
335        ops.push(MigrationOp::EnableRls {
336            table: to_table.name.clone(),
337        });
338    } else if from_table.row_level_security && !to_table.row_level_security {
339        ops.push(MigrationOp::DisableRls {
340            table: to_table.name.clone(),
341        });
342    }
343
344    ops
345}
346
347fn diff_policies(from_table: &Table, to_table: &Table) -> Vec<MigrationOp> {
348    let mut ops = Vec::new();
349
350    for policy in &to_table.policies {
351        if let Some(from_policy) = from_table.policies.iter().find(|p| p.name == policy.name) {
352            let changes = compute_policy_changes(from_policy, policy);
353            if changes.roles.is_some()
354                || changes.using_expr.is_some()
355                || changes.check_expr.is_some()
356            {
357                ops.push(MigrationOp::AlterPolicy {
358                    table: to_table.name.clone(),
359                    name: policy.name.clone(),
360                    changes,
361                });
362            }
363        } else {
364            ops.push(MigrationOp::CreatePolicy(policy.clone()));
365        }
366    }
367
368    for policy in &from_table.policies {
369        if !to_table.policies.iter().any(|p| p.name == policy.name) {
370            ops.push(MigrationOp::DropPolicy {
371                table: from_table.name.clone(),
372                name: policy.name.clone(),
373            });
374        }
375    }
376
377    ops
378}
379
380fn compute_policy_changes(from: &Policy, to: &Policy) -> PolicyChanges {
381    PolicyChanges {
382        roles: if from.roles != to.roles {
383            Some(to.roles.clone())
384        } else {
385            None
386        },
387        using_expr: if from.using_expr != to.using_expr {
388            Some(to.using_expr.clone())
389        } else {
390            None
391        },
392        check_expr: if from.check_expr != to.check_expr {
393            Some(to.check_expr.clone())
394        } else {
395            None
396        },
397    }
398}
399
400#[cfg(test)]
401mod tests {
402    use super::*;
403    use crate::model::{IndexType, ReferentialAction, SecurityType, Volatility};
404    use std::collections::BTreeMap;
405
406    fn empty_schema() -> Schema {
407        Schema::new()
408    }
409
410    fn simple_table(name: &str) -> Table {
411        Table {
412            name: name.to_string(),
413            columns: BTreeMap::new(),
414            indexes: Vec::new(),
415            primary_key: None,
416            foreign_keys: Vec::new(),
417            comment: None,
418            row_level_security: false,
419            policies: Vec::new(),
420        }
421    }
422
423    fn simple_column(name: &str, data_type: PgType) -> Column {
424        Column {
425            name: name.to_string(),
426            data_type,
427            nullable: true,
428            default: None,
429            comment: None,
430        }
431    }
432
433    #[test]
434    fn detects_added_enum() {
435        let from = empty_schema();
436        let mut to = empty_schema();
437        to.enums.insert(
438            "status".to_string(),
439            EnumType {
440                name: "status".to_string(),
441                values: vec!["active".to_string(), "inactive".to_string()],
442            },
443        );
444
445        let ops = compute_diff(&from, &to);
446        assert_eq!(ops.len(), 1);
447        assert!(matches!(&ops[0], MigrationOp::CreateEnum(e) if e.name == "status"));
448    }
449
450    #[test]
451    fn detects_removed_enum() {
452        let mut from = empty_schema();
453        from.enums.insert(
454            "status".to_string(),
455            EnumType {
456                name: "status".to_string(),
457                values: vec!["active".to_string()],
458            },
459        );
460        let to = empty_schema();
461
462        let ops = compute_diff(&from, &to);
463        assert_eq!(ops.len(), 1);
464        assert!(matches!(&ops[0], MigrationOp::DropEnum(name) if name == "status"));
465    }
466
467    #[test]
468    fn detects_added_table() {
469        let from = empty_schema();
470        let mut to = empty_schema();
471        to.tables.insert("users".to_string(), simple_table("users"));
472
473        let ops = compute_diff(&from, &to);
474        assert_eq!(ops.len(), 1);
475        assert!(matches!(&ops[0], MigrationOp::CreateTable(t) if t.name == "users"));
476    }
477
478    #[test]
479    fn detects_removed_table() {
480        let mut from = empty_schema();
481        from.tables
482            .insert("users".to_string(), simple_table("users"));
483        let to = empty_schema();
484
485        let ops = compute_diff(&from, &to);
486        assert_eq!(ops.len(), 1);
487        assert!(matches!(&ops[0], MigrationOp::DropTable(name) if name == "users"));
488    }
489
490    #[test]
491    fn detects_added_column() {
492        let mut from = empty_schema();
493        from.tables
494            .insert("users".to_string(), simple_table("users"));
495
496        let mut to = empty_schema();
497        let mut table = simple_table("users");
498        table
499            .columns
500            .insert("email".to_string(), simple_column("email", PgType::Text));
501        to.tables.insert("users".to_string(), table);
502
503        let ops = compute_diff(&from, &to);
504        assert_eq!(ops.len(), 1);
505        assert!(
506            matches!(&ops[0], MigrationOp::AddColumn { table, column } if table == "users" && column.name == "email")
507        );
508    }
509
510    #[test]
511    fn detects_removed_column() {
512        let mut from = empty_schema();
513        let mut table = simple_table("users");
514        table
515            .columns
516            .insert("email".to_string(), simple_column("email", PgType::Text));
517        from.tables.insert("users".to_string(), table);
518
519        let mut to = empty_schema();
520        to.tables.insert("users".to_string(), simple_table("users"));
521
522        let ops = compute_diff(&from, &to);
523        assert_eq!(ops.len(), 1);
524        assert!(
525            matches!(&ops[0], MigrationOp::DropColumn { table, column } if table == "users" && column == "email")
526        );
527    }
528
529    #[test]
530    fn detects_altered_column_type() {
531        let mut from = empty_schema();
532        let mut from_table = simple_table("users");
533        from_table
534            .columns
535            .insert("age".to_string(), simple_column("age", PgType::Integer));
536        from.tables.insert("users".to_string(), from_table);
537
538        let mut to = empty_schema();
539        let mut to_table = simple_table("users");
540        to_table
541            .columns
542            .insert("age".to_string(), simple_column("age", PgType::BigInt));
543        to.tables.insert("users".to_string(), to_table);
544
545        let ops = compute_diff(&from, &to);
546        assert_eq!(ops.len(), 1);
547        assert!(matches!(
548            &ops[0],
549            MigrationOp::AlterColumn { table, column, changes }
550            if table == "users" && column == "age" && changes.data_type == Some(PgType::BigInt)
551        ));
552    }
553
554    #[test]
555    fn detects_added_index() {
556        let mut from = empty_schema();
557        from.tables
558            .insert("users".to_string(), simple_table("users"));
559
560        let mut to = empty_schema();
561        let mut table = simple_table("users");
562        table.indexes.push(Index {
563            name: "users_email_idx".to_string(),
564            columns: vec!["email".to_string()],
565            unique: true,
566            index_type: IndexType::BTree,
567        });
568        to.tables.insert("users".to_string(), table);
569
570        let ops = compute_diff(&from, &to);
571        assert_eq!(ops.len(), 1);
572        assert!(
573            matches!(&ops[0], MigrationOp::AddIndex { table, index } if table == "users" && index.name == "users_email_idx")
574        );
575    }
576
577    #[test]
578    fn detects_removed_index() {
579        let mut from = empty_schema();
580        let mut from_table = simple_table("users");
581        from_table.indexes.push(Index {
582            name: "users_email_idx".to_string(),
583            columns: vec!["email".to_string()],
584            unique: true,
585            index_type: IndexType::BTree,
586        });
587        from.tables.insert("users".to_string(), from_table);
588
589        let mut to = empty_schema();
590        to.tables.insert("users".to_string(), simple_table("users"));
591
592        let ops = compute_diff(&from, &to);
593        assert_eq!(ops.len(), 1);
594        assert!(
595            matches!(&ops[0], MigrationOp::DropIndex { table, index_name } if table == "users" && index_name == "users_email_idx")
596        );
597    }
598
599    #[test]
600    fn detects_added_foreign_key() {
601        let mut from = empty_schema();
602        from.tables
603            .insert("posts".to_string(), simple_table("posts"));
604
605        let mut to = empty_schema();
606        let mut table = simple_table("posts");
607        table.foreign_keys.push(ForeignKey {
608            name: "posts_user_id_fkey".to_string(),
609            columns: vec!["user_id".to_string()],
610            referenced_table: "users".to_string(),
611            referenced_columns: vec!["id".to_string()],
612            on_delete: ReferentialAction::Cascade,
613            on_update: ReferentialAction::NoAction,
614        });
615        to.tables.insert("posts".to_string(), table);
616
617        let ops = compute_diff(&from, &to);
618        assert_eq!(ops.len(), 1);
619        assert!(
620            matches!(&ops[0], MigrationOp::AddForeignKey { table, foreign_key } if table == "posts" && foreign_key.name == "posts_user_id_fkey")
621        );
622    }
623
624    #[test]
625    fn detects_removed_foreign_key() {
626        let mut from = empty_schema();
627        let mut from_table = simple_table("posts");
628        from_table.foreign_keys.push(ForeignKey {
629            name: "posts_user_id_fkey".to_string(),
630            columns: vec!["user_id".to_string()],
631            referenced_table: "users".to_string(),
632            referenced_columns: vec!["id".to_string()],
633            on_delete: ReferentialAction::Cascade,
634            on_update: ReferentialAction::NoAction,
635        });
636        from.tables.insert("posts".to_string(), from_table);
637
638        let mut to = empty_schema();
639        to.tables.insert("posts".to_string(), simple_table("posts"));
640
641        let ops = compute_diff(&from, &to);
642        assert_eq!(ops.len(), 1);
643        assert!(
644            matches!(&ops[0], MigrationOp::DropForeignKey { table, foreign_key_name } if table == "posts" && foreign_key_name == "posts_user_id_fkey")
645        );
646    }
647
648    #[test]
649    fn detects_added_function() {
650        let from = empty_schema();
651        let mut to = empty_schema();
652        let func = Function {
653            name: "add_numbers".to_string(),
654            schema: "public".to_string(),
655            arguments: vec![],
656            return_type: "integer".to_string(),
657            language: "sql".to_string(),
658            body: "SELECT 1 + 1".to_string(),
659            volatility: Volatility::Immutable,
660            security: SecurityType::Invoker,
661        };
662        to.functions.insert(func.signature(), func);
663
664        let ops = compute_diff(&from, &to);
665        assert_eq!(ops.len(), 1);
666        assert!(matches!(&ops[0], MigrationOp::CreateFunction(f) if f.name == "add_numbers"));
667    }
668
669    #[test]
670    fn detects_removed_function() {
671        let mut from = empty_schema();
672        let func = Function {
673            name: "add_numbers".to_string(),
674            schema: "public".to_string(),
675            arguments: vec![],
676            return_type: "integer".to_string(),
677            language: "sql".to_string(),
678            body: "SELECT 1 + 1".to_string(),
679            volatility: Volatility::Immutable,
680            security: SecurityType::Invoker,
681        };
682        from.functions.insert(func.signature(), func);
683        let to = empty_schema();
684
685        let ops = compute_diff(&from, &to);
686        assert_eq!(ops.len(), 1);
687        assert!(matches!(&ops[0], MigrationOp::DropFunction { name, .. } if name == "add_numbers"));
688    }
689}