prax_migrate/
diff.rs

1//! Schema diffing for generating migrations.
2
3use std::collections::HashMap;
4
5use prax_schema::Schema;
6use prax_schema::ast::{Field, Model, View};
7
8use crate::error::MigrateResult;
9
10/// A diff between two schemas.
11#[derive(Debug, Clone, Default)]
12pub struct SchemaDiff {
13    /// Models to create.
14    pub create_models: Vec<ModelDiff>,
15    /// Models to drop.
16    pub drop_models: Vec<String>,
17    /// Models to alter.
18    pub alter_models: Vec<ModelAlterDiff>,
19    /// Enums to create.
20    pub create_enums: Vec<EnumDiff>,
21    /// Enums to drop.
22    pub drop_enums: Vec<String>,
23    /// Enums to alter.
24    pub alter_enums: Vec<EnumAlterDiff>,
25    /// Views to create.
26    pub create_views: Vec<ViewDiff>,
27    /// Views to drop.
28    pub drop_views: Vec<String>,
29    /// Views to alter (recreate with new definition).
30    pub alter_views: Vec<ViewDiff>,
31    /// Indexes to create.
32    pub create_indexes: Vec<IndexDiff>,
33    /// Indexes to drop.
34    pub drop_indexes: Vec<IndexDiff>,
35}
36
37impl SchemaDiff {
38    /// Check if there are any differences.
39    pub fn is_empty(&self) -> bool {
40        self.create_models.is_empty()
41            && self.drop_models.is_empty()
42            && self.alter_models.is_empty()
43            && self.create_enums.is_empty()
44            && self.drop_enums.is_empty()
45            && self.alter_enums.is_empty()
46            && self.create_views.is_empty()
47            && self.drop_views.is_empty()
48            && self.alter_views.is_empty()
49            && self.create_indexes.is_empty()
50            && self.drop_indexes.is_empty()
51    }
52
53    /// Get a human-readable summary of the diff.
54    pub fn summary(&self) -> String {
55        let mut parts = Vec::new();
56
57        if !self.create_models.is_empty() {
58            parts.push(format!("Create {} models", self.create_models.len()));
59        }
60        if !self.drop_models.is_empty() {
61            parts.push(format!("Drop {} models", self.drop_models.len()));
62        }
63        if !self.alter_models.is_empty() {
64            parts.push(format!("Alter {} models", self.alter_models.len()));
65        }
66        if !self.create_enums.is_empty() {
67            parts.push(format!("Create {} enums", self.create_enums.len()));
68        }
69        if !self.drop_enums.is_empty() {
70            parts.push(format!("Drop {} enums", self.drop_enums.len()));
71        }
72        if !self.create_views.is_empty() {
73            parts.push(format!("Create {} views", self.create_views.len()));
74        }
75        if !self.drop_views.is_empty() {
76            parts.push(format!("Drop {} views", self.drop_views.len()));
77        }
78        if !self.alter_views.is_empty() {
79            parts.push(format!("Alter {} views", self.alter_views.len()));
80        }
81        if !self.create_indexes.is_empty() {
82            parts.push(format!("Create {} indexes", self.create_indexes.len()));
83        }
84        if !self.drop_indexes.is_empty() {
85            parts.push(format!("Drop {} indexes", self.drop_indexes.len()));
86        }
87
88        if parts.is_empty() {
89            "No changes".to_string()
90        } else {
91            parts.join(", ")
92        }
93    }
94}
95
96/// Diff for creating a model.
97#[derive(Debug, Clone)]
98pub struct ModelDiff {
99    /// Model name.
100    pub name: String,
101    /// Table name.
102    pub table_name: String,
103    /// Fields to create.
104    pub fields: Vec<FieldDiff>,
105    /// Primary key columns.
106    pub primary_key: Vec<String>,
107    /// Indexes.
108    pub indexes: Vec<IndexDiff>,
109    /// Unique constraints.
110    pub unique_constraints: Vec<UniqueConstraint>,
111}
112
113/// Diff for altering a model.
114#[derive(Debug, Clone)]
115pub struct ModelAlterDiff {
116    /// Model name.
117    pub name: String,
118    /// Table name.
119    pub table_name: String,
120    /// Fields to add.
121    pub add_fields: Vec<FieldDiff>,
122    /// Fields to drop.
123    pub drop_fields: Vec<String>,
124    /// Fields to alter.
125    pub alter_fields: Vec<FieldAlterDiff>,
126    /// Indexes to add.
127    pub add_indexes: Vec<IndexDiff>,
128    /// Indexes to drop.
129    pub drop_indexes: Vec<String>,
130}
131
/// Description of a single column to create.
#[derive(Clone, Debug)]
pub struct FieldDiff {
    /// Name of the field in the schema.
    pub name: String,
    /// Name of the column in the database.
    pub column_name: String,
    /// SQL type of the column.
    pub sql_type: String,
    /// `true` when the column accepts NULL.
    pub nullable: bool,
    /// Default value expression, if one is declared.
    pub default: Option<String>,
    /// `true` when the field is a primary key.
    pub is_primary_key: bool,
    /// `true` when the field auto-increments.
    pub is_auto_increment: bool,
    /// `true` when the field is unique.
    pub is_unique: bool,
}
152
/// Changes to apply to a column that exists in both schemas.
///
/// Each `old_*` / `new_*` pair is only filled in when that aspect changed.
#[derive(Clone, Debug)]
pub struct FieldAlterDiff {
    /// Name of the field in the schema.
    pub name: String,
    /// Name of the column in the database.
    pub column_name: String,
    /// Previous SQL type (if changed).
    pub old_type: Option<String>,
    /// Desired SQL type (if changed).
    pub new_type: Option<String>,
    /// Previous nullability (if changed).
    pub old_nullable: Option<bool>,
    /// Desired nullability (if changed).
    pub new_nullable: Option<bool>,
    /// Previous default (if changed).
    pub old_default: Option<String>,
    /// Desired default (if changed).
    pub new_default: Option<String>,
}
173
/// Description of an enum to create.
#[derive(Clone, Debug)]
pub struct EnumDiff {
    /// Name of the enum.
    pub name: String,
    /// Names of the enum's values.
    pub values: Vec<String>,
}
182
/// Changes to apply to an enum that exists in both schemas.
#[derive(Clone, Debug)]
pub struct EnumAlterDiff {
    /// Name of the enum.
    pub name: String,
    /// Values that must be added.
    pub add_values: Vec<String>,
    /// Values that must be removed.
    pub remove_values: Vec<String>,
}
193
/// Description of a single index.
#[derive(Clone, Debug)]
pub struct IndexDiff {
    /// Name of the index.
    pub name: String,
    /// Name of the table the index belongs to.
    pub table_name: String,
    /// Columns covered by the index.
    pub columns: Vec<String>,
    /// `true` when the index enforces uniqueness.
    pub unique: bool,
}
206
/// A uniqueness constraint over one or more columns.
#[derive(Clone, Debug)]
pub struct UniqueConstraint {
    /// Name of the constraint, when it has one.
    pub name: Option<String>,
    /// Columns covered by the constraint.
    pub columns: Vec<String>,
}
215
216/// Diff for creating or altering a view.
217#[derive(Debug, Clone)]
218pub struct ViewDiff {
219    /// View name.
220    pub name: String,
221    /// Database view name.
222    pub view_name: String,
223    /// SQL query that defines the view.
224    pub sql_query: String,
225    /// Whether the view is materialized.
226    pub is_materialized: bool,
227    /// Refresh interval for materialized views (if any).
228    pub refresh_interval: Option<String>,
229    /// Fields in the view (for documentation/validation).
230    pub fields: Vec<ViewFieldDiff>,
231}
232
/// One field of a view, recorded for documentation purposes.
#[derive(Clone, Debug)]
pub struct ViewFieldDiff {
    /// Name of the field in the schema.
    pub name: String,
    /// Name of the column in the view.
    pub column_name: String,
    /// SQL type of the column.
    pub sql_type: String,
    /// `true` when the column accepts NULL.
    pub nullable: bool,
}
245
246/// Schema differ for comparing schemas.
247pub struct SchemaDiffer {
248    /// Source schema (current database state).
249    source: Option<Schema>,
250    /// Target schema (desired state).
251    target: Schema,
252}
253
254impl SchemaDiffer {
255    /// Create a new differ with only the target schema.
256    pub fn new(target: Schema) -> Self {
257        Self {
258            source: None,
259            target,
260        }
261    }
262
263    /// Set the source schema.
264    pub fn with_source(mut self, source: Schema) -> Self {
265        self.source = Some(source);
266        self
267    }
268
269    /// Compute the diff between schemas.
270    pub fn diff(&self) -> MigrateResult<SchemaDiff> {
271        let mut result = SchemaDiff::default();
272
273        let source_models: HashMap<&str, &Model> = self
274            .source
275            .as_ref()
276            .map(|s| s.models.values().map(|m| (m.name(), m)).collect())
277            .unwrap_or_default();
278
279        let target_models: HashMap<&str, &Model> =
280            self.target.models.values().map(|m| (m.name(), m)).collect();
281
282        // Find models to create
283        for (name, model) in &target_models {
284            if !source_models.contains_key(name) {
285                result.create_models.push(model_to_diff(model));
286            }
287        }
288
289        // Find models to drop
290        for name in source_models.keys() {
291            if !target_models.contains_key(name) {
292                result.drop_models.push((*name).to_string());
293            }
294        }
295
296        // Find models to alter
297        for (name, target_model) in &target_models {
298            if let Some(source_model) = source_models.get(name)
299                && let Some(alter) = diff_models(source_model, target_model)
300            {
301                result.alter_models.push(alter);
302            }
303        }
304
305        // Diff enums similarly
306        let source_enums: HashMap<&str, _> = self
307            .source
308            .as_ref()
309            .map(|s| s.enums.values().map(|e| (e.name(), e)).collect())
310            .unwrap_or_default();
311
312        let target_enums: HashMap<&str, _> =
313            self.target.enums.values().map(|e| (e.name(), e)).collect();
314
315        for (name, enum_def) in &target_enums {
316            if !source_enums.contains_key(name) {
317                result.create_enums.push(EnumDiff {
318                    name: (*name).to_string(),
319                    values: enum_def
320                        .variants
321                        .iter()
322                        .map(|v| v.name.to_string())
323                        .collect(),
324                });
325            }
326        }
327
328        for name in source_enums.keys() {
329            if !target_enums.contains_key(name) {
330                result.drop_enums.push((*name).to_string());
331            }
332        }
333
334        // Diff views
335        let source_views: HashMap<&str, &View> = self
336            .source
337            .as_ref()
338            .map(|s| s.views.values().map(|v| (v.name(), v)).collect())
339            .unwrap_or_default();
340
341        let target_views: HashMap<&str, &View> =
342            self.target.views.values().map(|v| (v.name(), v)).collect();
343
344        // Find views to create
345        for (name, view) in &target_views {
346            if !source_views.contains_key(name) {
347                if let Some(view_diff) = view_to_diff(view) {
348                    result.create_views.push(view_diff);
349                }
350            }
351        }
352
353        // Find views to drop
354        for name in source_views.keys() {
355            if !target_views.contains_key(name) {
356                result.drop_views.push((*name).to_string());
357            }
358        }
359
360        // Find views to alter (if SQL changed)
361        for (name, target_view) in &target_views {
362            if let Some(source_view) = source_views.get(name) {
363                // Views are altered by dropping and recreating
364                let source_sql = source_view.sql_query();
365                let target_sql = target_view.sql_query();
366
367                // Check if SQL or materialized status changed
368                let sql_changed = source_sql != target_sql;
369                let materialized_changed =
370                    source_view.is_materialized() != target_view.is_materialized();
371
372                if sql_changed || materialized_changed {
373                    if let Some(view_diff) = view_to_diff(target_view) {
374                        result.alter_views.push(view_diff);
375                    }
376                }
377            }
378        }
379
380        Ok(result)
381    }
382}
383
384/// Convert a model to a diff for creation.
385fn model_to_diff(model: &Model) -> ModelDiff {
386    let fields: Vec<FieldDiff> = model
387        .fields
388        .values()
389        .filter(|f| !f.is_relation())
390        .map(field_to_diff)
391        .collect();
392
393    let primary_key: Vec<String> = model
394        .fields
395        .values()
396        .filter(|f| f.has_attribute("id"))
397        .map(|f| f.name().to_string())
398        .collect();
399
400    ModelDiff {
401        name: model.name().to_string(),
402        table_name: model.table_name().to_string(),
403        fields,
404        primary_key,
405        indexes: Vec::new(),
406        unique_constraints: Vec::new(),
407    }
408}
409
410/// Convert a field to a diff.
411fn field_to_diff(field: &Field) -> FieldDiff {
412    let sql_type = field_type_to_sql(&field.field_type);
413    let nullable = field.is_optional();
414    let is_primary_key = field.has_attribute("id");
415    let is_auto_increment = field.has_attribute("auto");
416    let is_unique = field.has_attribute("unique");
417
418    let default = field
419        .get_attribute("default")
420        .and_then(|attr| attr.first_arg())
421        .map(|arg| format!("{:?}", arg));
422
423    // Get column name from @map attribute or use field name
424    let column_name = field
425        .get_attribute("map")
426        .and_then(|attr| attr.first_arg())
427        .and_then(|v| v.as_string())
428        .unwrap_or_else(|| field.name())
429        .to_string();
430
431    FieldDiff {
432        name: field.name().to_string(),
433        column_name,
434        sql_type,
435        nullable,
436        default,
437        is_primary_key,
438        is_auto_increment,
439        is_unique,
440    }
441}
442
443/// Convert a field type to SQL.
444fn field_type_to_sql(field_type: &prax_schema::ast::FieldType) -> String {
445    use prax_schema::ast::{FieldType, ScalarType};
446
447    match field_type {
448        FieldType::Scalar(scalar) => match scalar {
449            ScalarType::Int => "INTEGER".to_string(),
450            ScalarType::BigInt => "BIGINT".to_string(),
451            ScalarType::Float => "DOUBLE PRECISION".to_string(),
452            ScalarType::Decimal => "DECIMAL".to_string(),
453            ScalarType::String => "TEXT".to_string(),
454            ScalarType::Boolean => "BOOLEAN".to_string(),
455            ScalarType::DateTime => "TIMESTAMP WITH TIME ZONE".to_string(),
456            ScalarType::Date => "DATE".to_string(),
457            ScalarType::Time => "TIME".to_string(),
458            ScalarType::Json => "JSONB".to_string(),
459            ScalarType::Bytes => "BYTEA".to_string(),
460            ScalarType::Uuid => "UUID".to_string(),
461            // String-based ID types stored as TEXT
462            ScalarType::Cuid | ScalarType::Cuid2 | ScalarType::NanoId | ScalarType::Ulid => {
463                "TEXT".to_string()
464            }
465        },
466        FieldType::Model(name) => name.to_string(),
467        FieldType::Enum(name) => format!("\"{}\"", name),
468        FieldType::Composite(name) => name.to_string(),
469        FieldType::Unsupported(name) => name.to_string(),
470    }
471}
472
473/// Diff two models and return alterations if any.
474fn diff_models(source: &Model, target: &Model) -> Option<ModelAlterDiff> {
475    let source_fields: HashMap<&str, &Field> = source
476        .fields
477        .values()
478        .filter(|f| !f.is_relation())
479        .map(|f| (f.name(), f))
480        .collect();
481
482    let target_fields: HashMap<&str, &Field> = target
483        .fields
484        .values()
485        .filter(|f| !f.is_relation())
486        .map(|f| (f.name(), f))
487        .collect();
488
489    let mut add_fields = Vec::new();
490    let mut drop_fields = Vec::new();
491    let mut alter_fields = Vec::new();
492
493    // Find fields to add
494    for (name, field) in &target_fields {
495        if !source_fields.contains_key(name) {
496            add_fields.push(field_to_diff(field));
497        }
498    }
499
500    // Find fields to drop
501    for name in source_fields.keys() {
502        if !target_fields.contains_key(name) {
503            drop_fields.push((*name).to_string());
504        }
505    }
506
507    // Find fields to alter
508    for (name, target_field) in &target_fields {
509        if let Some(source_field) = source_fields.get(name)
510            && let Some(alter) = diff_fields(source_field, target_field)
511        {
512            alter_fields.push(alter);
513        }
514    }
515
516    if add_fields.is_empty() && drop_fields.is_empty() && alter_fields.is_empty() {
517        None
518    } else {
519        Some(ModelAlterDiff {
520            name: target.name().to_string(),
521            table_name: target.table_name().to_string(),
522            add_fields,
523            drop_fields,
524            alter_fields,
525            add_indexes: Vec::new(),
526            drop_indexes: Vec::new(),
527        })
528    }
529}
530
531/// Convert a view to a diff for creation.
532fn view_to_diff(view: &View) -> Option<ViewDiff> {
533    // Views require a @@sql attribute to be migrated
534    let sql_query = view.sql_query()?.to_string();
535
536    let fields: Vec<ViewFieldDiff> = view
537        .fields
538        .values()
539        .map(|field| {
540            let column_name = field
541                .get_attribute("map")
542                .and_then(|attr| attr.first_arg())
543                .and_then(|v| v.as_string())
544                .unwrap_or_else(|| field.name())
545                .to_string();
546
547            ViewFieldDiff {
548                name: field.name().to_string(),
549                column_name,
550                sql_type: field_type_to_sql(&field.field_type),
551                nullable: field.is_optional(),
552            }
553        })
554        .collect();
555
556    Some(ViewDiff {
557        name: view.name().to_string(),
558        view_name: view.view_name().to_string(),
559        sql_query,
560        is_materialized: view.is_materialized(),
561        refresh_interval: view.refresh_interval().map(|s| s.to_string()),
562        fields,
563    })
564}
565
566/// Diff two fields and return alterations if any.
567fn diff_fields(source: &Field, target: &Field) -> Option<FieldAlterDiff> {
568    let source_type = field_type_to_sql(&source.field_type);
569    let target_type = field_type_to_sql(&target.field_type);
570
571    let source_nullable = source.is_optional();
572    let target_nullable = target.is_optional();
573
574    let type_changed = source_type != target_type;
575    let nullable_changed = source_nullable != target_nullable;
576
577    if !type_changed && !nullable_changed {
578        return None;
579    }
580
581    // Get column name from @map attribute or use field name
582    let column_name = target
583        .get_attribute("map")
584        .and_then(|attr| attr.first_arg())
585        .and_then(|v| v.as_string())
586        .unwrap_or_else(|| target.name())
587        .to_string();
588
589    Some(FieldAlterDiff {
590        name: target.name().to_string(),
591        column_name,
592        old_type: if type_changed {
593            Some(source_type)
594        } else {
595            None
596        },
597        new_type: if type_changed {
598            Some(target_type)
599        } else {
600            None
601        },
602        old_nullable: if nullable_changed {
603            Some(source_nullable)
604        } else {
605            None
606        },
607        new_nullable: if nullable_changed {
608            Some(target_nullable)
609        } else {
610            None
611        },
612        old_default: None,
613        new_default: None,
614    })
615}
616
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a field-less `ViewDiff` fixture.
    fn view(
        name: &str,
        view_name: &str,
        sql: &str,
        is_materialized: bool,
        refresh_interval: Option<&str>,
    ) -> ViewDiff {
        ViewDiff {
            name: name.to_string(),
            view_name: view_name.to_string(),
            sql_query: sql.to_string(),
            is_materialized,
            refresh_interval: refresh_interval.map(str::to_string),
            fields: vec![],
        }
    }

    /// Build a `ViewFieldDiff` fixture.
    fn view_field(name: &str, column_name: &str, sql_type: &str, nullable: bool) -> ViewFieldDiff {
        ViewFieldDiff {
            name: name.to_string(),
            column_name: column_name.to_string(),
            sql_type: sql_type.to_string(),
            nullable,
        }
    }

    #[test]
    fn test_schema_diff_empty() {
        assert!(SchemaDiff::default().is_empty());
    }

    #[test]
    fn test_schema_diff_summary() {
        let diff = SchemaDiff {
            create_models: vec![ModelDiff {
                name: "User".to_string(),
                table_name: "users".to_string(),
                fields: Vec::new(),
                primary_key: Vec::new(),
                indexes: Vec::new(),
                unique_constraints: Vec::new(),
            }],
            ..SchemaDiff::default()
        };

        assert!(diff.summary().contains("Create 1 models"));
    }

    #[test]
    fn test_schema_diff_with_views() {
        let diff = SchemaDiff {
            create_views: vec![view(
                "UserStats",
                "user_stats",
                "SELECT id, COUNT(*) FROM users GROUP BY id",
                false,
                None,
            )],
            ..SchemaDiff::default()
        };

        assert!(!diff.is_empty());
        assert!(diff.summary().contains("Create 1 views"));
    }

    #[test]
    fn test_schema_diff_summary_with_multiple() {
        let diff = SchemaDiff {
            create_views: vec![view("View1", "view1", "SELECT 1", false, None)],
            drop_views: vec!["old_view".to_string()],
            alter_views: vec![view("View2", "view2", "SELECT 2", true, Some("1h"))],
            ..SchemaDiff::default()
        };

        let summary = diff.summary();
        for expected in ["Create 1 views", "Drop 1 views", "Alter 1 views"] {
            assert!(summary.contains(expected));
        }
    }

    #[test]
    fn test_view_diff_fields() {
        let mut view_diff = view(
            "UserStats",
            "user_stats",
            "SELECT id, name FROM users",
            false,
            None,
        );
        view_diff.fields = vec![
            view_field("id", "id", "INTEGER", false),
            view_field("name", "user_name", "TEXT", true),
        ];

        assert_eq!(view_diff.fields.len(), 2);
        assert_eq!(view_diff.fields[0].name, "id");
        assert_eq!(view_diff.fields[1].column_name, "user_name");
    }
}