1use std::collections::HashMap;
4
5use prax_schema::Schema;
6use prax_schema::ast::{Field, IndexType, Model, VectorOps, View};
7
8use crate::error::MigrateResult;
9
10#[derive(Debug, Clone, Default)]
12pub struct SchemaDiff {
13 pub create_extensions: Vec<ExtensionDiff>,
15 pub drop_extensions: Vec<String>,
17 pub create_models: Vec<ModelDiff>,
19 pub drop_models: Vec<String>,
21 pub alter_models: Vec<ModelAlterDiff>,
23 pub create_enums: Vec<EnumDiff>,
25 pub drop_enums: Vec<String>,
27 pub alter_enums: Vec<EnumAlterDiff>,
29 pub create_views: Vec<ViewDiff>,
31 pub drop_views: Vec<String>,
33 pub alter_views: Vec<ViewDiff>,
35 pub create_indexes: Vec<IndexDiff>,
37 pub drop_indexes: Vec<IndexDiff>,
39}
40
/// Describes a database extension to be created.
#[derive(Clone, Debug)]
pub struct ExtensionDiff {
    /// Extension name.
    pub name: String,
    /// Optional schema for the extension.
    pub schema: Option<String>,
    /// Optional extension version.
    pub version: Option<String>,
}
51
52impl SchemaDiff {
53 pub fn is_empty(&self) -> bool {
55 self.create_extensions.is_empty()
56 && self.drop_extensions.is_empty()
57 && self.create_models.is_empty()
58 && self.drop_models.is_empty()
59 && self.alter_models.is_empty()
60 && self.create_enums.is_empty()
61 && self.drop_enums.is_empty()
62 && self.alter_enums.is_empty()
63 && self.create_views.is_empty()
64 && self.drop_views.is_empty()
65 && self.alter_views.is_empty()
66 && self.create_indexes.is_empty()
67 && self.drop_indexes.is_empty()
68 }
69
70 pub fn summary(&self) -> String {
72 let mut parts = Vec::new();
73
74 if !self.create_extensions.is_empty() {
75 parts.push(format!(
76 "Create {} extensions",
77 self.create_extensions.len()
78 ));
79 }
80 if !self.drop_extensions.is_empty() {
81 parts.push(format!("Drop {} extensions", self.drop_extensions.len()));
82 }
83 if !self.create_models.is_empty() {
84 parts.push(format!("Create {} models", self.create_models.len()));
85 }
86 if !self.drop_models.is_empty() {
87 parts.push(format!("Drop {} models", self.drop_models.len()));
88 }
89 if !self.alter_models.is_empty() {
90 parts.push(format!("Alter {} models", self.alter_models.len()));
91 }
92 if !self.create_enums.is_empty() {
93 parts.push(format!("Create {} enums", self.create_enums.len()));
94 }
95 if !self.drop_enums.is_empty() {
96 parts.push(format!("Drop {} enums", self.drop_enums.len()));
97 }
98 if !self.create_views.is_empty() {
99 parts.push(format!("Create {} views", self.create_views.len()));
100 }
101 if !self.drop_views.is_empty() {
102 parts.push(format!("Drop {} views", self.drop_views.len()));
103 }
104 if !self.alter_views.is_empty() {
105 parts.push(format!("Alter {} views", self.alter_views.len()));
106 }
107 if !self.create_indexes.is_empty() {
108 parts.push(format!("Create {} indexes", self.create_indexes.len()));
109 }
110 if !self.drop_indexes.is_empty() {
111 parts.push(format!("Drop {} indexes", self.drop_indexes.len()));
112 }
113
114 if parts.is_empty() {
115 "No changes".to_string()
116 } else {
117 parts.join(", ")
118 }
119 }
120}
121
122#[derive(Debug, Clone)]
124pub struct ModelDiff {
125 pub name: String,
127 pub table_name: String,
129 pub fields: Vec<FieldDiff>,
131 pub primary_key: Vec<String>,
133 pub indexes: Vec<IndexDiff>,
135 pub unique_constraints: Vec<UniqueConstraint>,
137}
138
139#[derive(Debug, Clone)]
141pub struct ModelAlterDiff {
142 pub name: String,
144 pub table_name: String,
146 pub add_fields: Vec<FieldDiff>,
148 pub drop_fields: Vec<String>,
150 pub alter_fields: Vec<FieldAlterDiff>,
152 pub add_indexes: Vec<IndexDiff>,
154 pub drop_indexes: Vec<String>,
156}
157
/// Describes a single column to be created.
#[derive(Clone, Debug)]
pub struct FieldDiff {
    /// Field name as declared in the schema.
    pub name: String,
    /// Name of the database column.
    pub column_name: String,
    /// SQL type of the column.
    pub sql_type: String,
    /// Whether the column accepts NULL.
    pub nullable: bool,
    /// Default value expression, if any.
    pub default: Option<String>,
    /// Whether the column is part of the primary key.
    pub is_primary_key: bool,
    /// Whether the column auto-increments.
    pub is_auto_increment: bool,
    /// Whether the column carries a unique constraint.
    pub is_unique: bool,
}
178
/// Describes the changes to apply to an existing column.
///
/// Each `old_*`/`new_*` pair is `Some` only when that aspect changed.
#[derive(Clone, Debug)]
pub struct FieldAlterDiff {
    /// Field name as declared in the schema.
    pub name: String,
    /// Name of the database column.
    pub column_name: String,
    /// Previous SQL type, when the type changed.
    pub old_type: Option<String>,
    /// New SQL type, when the type changed.
    pub new_type: Option<String>,
    /// Previous nullability, when it changed.
    pub old_nullable: Option<bool>,
    /// New nullability, when it changed.
    pub new_nullable: Option<bool>,
    /// Previous default value, when it changed.
    pub old_default: Option<String>,
    /// New default value, when it changed.
    pub new_default: Option<String>,
}
199
/// Describes an enum type to be created.
#[derive(Clone, Debug)]
pub struct EnumDiff {
    /// Enum name.
    pub name: String,
    /// Names of the enum's variants.
    pub values: Vec<String>,
}
208
/// Describes the changes to apply to an existing enum type.
#[derive(Clone, Debug)]
pub struct EnumAlterDiff {
    /// Enum name.
    pub name: String,
    /// Values to add to the enum.
    pub add_values: Vec<String>,
    /// Values to remove from the enum.
    pub remove_values: Vec<String>,
}
219
220#[derive(Debug, Clone)]
222pub struct IndexDiff {
223 pub name: String,
225 pub table_name: String,
227 pub columns: Vec<String>,
229 pub unique: bool,
231 pub index_type: Option<IndexType>,
233 pub vector_ops: Option<VectorOps>,
235 pub hnsw_m: Option<u32>,
237 pub hnsw_ef_construction: Option<u32>,
239 pub ivfflat_lists: Option<u32>,
241}
242
243impl IndexDiff {
244 pub fn new(
246 name: impl Into<String>,
247 table_name: impl Into<String>,
248 columns: Vec<String>,
249 ) -> Self {
250 Self {
251 name: name.into(),
252 table_name: table_name.into(),
253 columns,
254 unique: false,
255 index_type: None,
256 vector_ops: None,
257 hnsw_m: None,
258 hnsw_ef_construction: None,
259 ivfflat_lists: None,
260 }
261 }
262
263 pub fn unique(mut self) -> Self {
265 self.unique = true;
266 self
267 }
268
269 pub fn with_type(mut self, index_type: IndexType) -> Self {
271 self.index_type = Some(index_type);
272 self
273 }
274
275 pub fn with_vector_ops(mut self, ops: VectorOps) -> Self {
277 self.vector_ops = Some(ops);
278 self
279 }
280
281 pub fn with_hnsw_m(mut self, m: u32) -> Self {
283 self.hnsw_m = Some(m);
284 self
285 }
286
287 pub fn with_hnsw_ef_construction(mut self, ef: u32) -> Self {
289 self.hnsw_ef_construction = Some(ef);
290 self
291 }
292
293 pub fn with_ivfflat_lists(mut self, lists: u32) -> Self {
295 self.ivfflat_lists = Some(lists);
296 self
297 }
298
299 pub fn is_vector_index(&self) -> bool {
301 self.index_type
302 .as_ref()
303 .is_some_and(|t| t.is_vector_index())
304 }
305}
306
/// Describes a unique constraint over one or more columns.
#[derive(Clone, Debug)]
pub struct UniqueConstraint {
    /// Optional constraint name.
    pub name: Option<String>,
    /// Columns covered by the constraint.
    pub columns: Vec<String>,
}
315
316#[derive(Debug, Clone)]
318pub struct ViewDiff {
319 pub name: String,
321 pub view_name: String,
323 pub sql_query: String,
325 pub is_materialized: bool,
327 pub refresh_interval: Option<String>,
329 pub fields: Vec<ViewFieldDiff>,
331}
332
/// Describes a single column exposed by a view.
#[derive(Clone, Debug)]
pub struct ViewFieldDiff {
    /// Field name as declared in the schema.
    pub name: String,
    /// Name of the column in the view.
    pub column_name: String,
    /// SQL type of the column.
    pub sql_type: String,
    /// Whether the column may be NULL.
    pub nullable: bool,
}
345
346pub struct SchemaDiffer {
348 source: Option<Schema>,
350 target: Schema,
352}
353
354impl SchemaDiffer {
355 pub fn new(target: Schema) -> Self {
357 Self {
358 source: None,
359 target,
360 }
361 }
362
363 pub fn with_source(mut self, source: Schema) -> Self {
365 self.source = Some(source);
366 self
367 }
368
369 pub fn diff(&self) -> MigrateResult<SchemaDiff> {
371 let mut result = SchemaDiff::default();
372
373 let source_models: HashMap<&str, &Model> = self
374 .source
375 .as_ref()
376 .map(|s| s.models.values().map(|m| (m.name(), m)).collect())
377 .unwrap_or_default();
378
379 let target_models: HashMap<&str, &Model> =
380 self.target.models.values().map(|m| (m.name(), m)).collect();
381
382 for (name, model) in &target_models {
384 if !source_models.contains_key(name) {
385 result.create_models.push(model_to_diff(model));
386 }
387 }
388
389 for name in source_models.keys() {
391 if !target_models.contains_key(name) {
392 result.drop_models.push((*name).to_string());
393 }
394 }
395
396 for (name, target_model) in &target_models {
398 if let Some(source_model) = source_models.get(name)
399 && let Some(alter) = diff_models(source_model, target_model)
400 {
401 result.alter_models.push(alter);
402 }
403 }
404
405 let source_enums: HashMap<&str, _> = self
407 .source
408 .as_ref()
409 .map(|s| s.enums.values().map(|e| (e.name(), e)).collect())
410 .unwrap_or_default();
411
412 let target_enums: HashMap<&str, _> =
413 self.target.enums.values().map(|e| (e.name(), e)).collect();
414
415 for (name, enum_def) in &target_enums {
416 if !source_enums.contains_key(name) {
417 result.create_enums.push(EnumDiff {
418 name: (*name).to_string(),
419 values: enum_def
420 .variants
421 .iter()
422 .map(|v| v.name.to_string())
423 .collect(),
424 });
425 }
426 }
427
428 for name in source_enums.keys() {
429 if !target_enums.contains_key(name) {
430 result.drop_enums.push((*name).to_string());
431 }
432 }
433
434 let source_views: HashMap<&str, &View> = self
436 .source
437 .as_ref()
438 .map(|s| s.views.values().map(|v| (v.name(), v)).collect())
439 .unwrap_or_default();
440
441 let target_views: HashMap<&str, &View> =
442 self.target.views.values().map(|v| (v.name(), v)).collect();
443
444 for (name, view) in &target_views {
446 if !source_views.contains_key(name)
447 && let Some(view_diff) = view_to_diff(view)
448 {
449 result.create_views.push(view_diff);
450 }
451 }
452
453 for name in source_views.keys() {
455 if !target_views.contains_key(name) {
456 result.drop_views.push((*name).to_string());
457 }
458 }
459
460 for (name, target_view) in &target_views {
462 if let Some(source_view) = source_views.get(name) {
463 let source_sql = source_view.sql_query();
465 let target_sql = target_view.sql_query();
466
467 let sql_changed = source_sql != target_sql;
469 let materialized_changed =
470 source_view.is_materialized() != target_view.is_materialized();
471
472 if (sql_changed || materialized_changed)
473 && let Some(view_diff) = view_to_diff(target_view)
474 {
475 result.alter_views.push(view_diff);
476 }
477 }
478 }
479
480 Ok(result)
481 }
482}
483
484fn model_to_diff(model: &Model) -> ModelDiff {
486 let fields: Vec<FieldDiff> = model
487 .fields
488 .values()
489 .filter(|f| !f.is_relation())
490 .map(field_to_diff)
491 .collect();
492
493 let primary_key: Vec<String> = model
494 .fields
495 .values()
496 .filter(|f| f.has_attribute("id"))
497 .map(|f| f.name().to_string())
498 .collect();
499
500 ModelDiff {
501 name: model.name().to_string(),
502 table_name: model.table_name().to_string(),
503 fields,
504 primary_key,
505 indexes: Vec::new(),
506 unique_constraints: Vec::new(),
507 }
508}
509
510fn field_to_diff(field: &Field) -> FieldDiff {
512 let sql_type = field_type_to_sql(&field.field_type);
513 let nullable = field.is_optional();
514 let is_primary_key = field.has_attribute("id");
515 let is_auto_increment = field.has_attribute("auto");
516 let is_unique = field.has_attribute("unique");
517
518 let default = field
519 .get_attribute("default")
520 .and_then(|attr| attr.first_arg())
521 .map(|arg| format!("{:?}", arg));
522
523 let column_name = field
525 .get_attribute("map")
526 .and_then(|attr| attr.first_arg())
527 .and_then(|v| v.as_string())
528 .unwrap_or_else(|| field.name())
529 .to_string();
530
531 FieldDiff {
532 name: field.name().to_string(),
533 column_name,
534 sql_type,
535 nullable,
536 default,
537 is_primary_key,
538 is_auto_increment,
539 is_unique,
540 }
541}
542
543fn field_type_to_sql(field_type: &prax_schema::ast::FieldType) -> String {
545 use prax_schema::ast::{FieldType, ScalarType};
546
547 match field_type {
548 FieldType::Scalar(scalar) => match scalar {
549 ScalarType::Int => "INTEGER".to_string(),
550 ScalarType::BigInt => "BIGINT".to_string(),
551 ScalarType::Float => "DOUBLE PRECISION".to_string(),
552 ScalarType::Decimal => "DECIMAL".to_string(),
553 ScalarType::String => "TEXT".to_string(),
554 ScalarType::Boolean => "BOOLEAN".to_string(),
555 ScalarType::DateTime => "TIMESTAMP WITH TIME ZONE".to_string(),
556 ScalarType::Date => "DATE".to_string(),
557 ScalarType::Time => "TIME".to_string(),
558 ScalarType::Json => "JSONB".to_string(),
559 ScalarType::Bytes => "BYTEA".to_string(),
560 ScalarType::Uuid => "UUID".to_string(),
561 ScalarType::Cuid | ScalarType::Cuid2 | ScalarType::NanoId | ScalarType::Ulid => {
563 "TEXT".to_string()
564 }
565 ScalarType::Vector(dim) => match dim {
567 Some(d) => format!("vector({})", d),
568 None => "vector".to_string(),
569 },
570 ScalarType::HalfVector(dim) => match dim {
571 Some(d) => format!("halfvec({})", d),
572 None => "halfvec".to_string(),
573 },
574 ScalarType::SparseVector(dim) => match dim {
575 Some(d) => format!("sparsevec({})", d),
576 None => "sparsevec".to_string(),
577 },
578 ScalarType::Bit(dim) => match dim {
579 Some(d) => format!("bit({})", d),
580 None => "bit".to_string(),
581 },
582 },
583 FieldType::Model(name) => name.to_string(),
584 FieldType::Enum(name) => format!("\"{}\"", name),
585 FieldType::Composite(name) => name.to_string(),
586 FieldType::Unsupported(name) => name.to_string(),
587 }
588}
589
590fn diff_models(source: &Model, target: &Model) -> Option<ModelAlterDiff> {
592 let source_fields: HashMap<&str, &Field> = source
593 .fields
594 .values()
595 .filter(|f| !f.is_relation())
596 .map(|f| (f.name(), f))
597 .collect();
598
599 let target_fields: HashMap<&str, &Field> = target
600 .fields
601 .values()
602 .filter(|f| !f.is_relation())
603 .map(|f| (f.name(), f))
604 .collect();
605
606 let mut add_fields = Vec::new();
607 let mut drop_fields = Vec::new();
608 let mut alter_fields = Vec::new();
609
610 for (name, field) in &target_fields {
612 if !source_fields.contains_key(name) {
613 add_fields.push(field_to_diff(field));
614 }
615 }
616
617 for name in source_fields.keys() {
619 if !target_fields.contains_key(name) {
620 drop_fields.push((*name).to_string());
621 }
622 }
623
624 for (name, target_field) in &target_fields {
626 if let Some(source_field) = source_fields.get(name)
627 && let Some(alter) = diff_fields(source_field, target_field)
628 {
629 alter_fields.push(alter);
630 }
631 }
632
633 if add_fields.is_empty() && drop_fields.is_empty() && alter_fields.is_empty() {
634 None
635 } else {
636 Some(ModelAlterDiff {
637 name: target.name().to_string(),
638 table_name: target.table_name().to_string(),
639 add_fields,
640 drop_fields,
641 alter_fields,
642 add_indexes: Vec::new(),
643 drop_indexes: Vec::new(),
644 })
645 }
646}
647
648fn view_to_diff(view: &View) -> Option<ViewDiff> {
650 let sql_query = view.sql_query()?.to_string();
652
653 let fields: Vec<ViewFieldDiff> = view
654 .fields
655 .values()
656 .map(|field| {
657 let column_name = field
658 .get_attribute("map")
659 .and_then(|attr| attr.first_arg())
660 .and_then(|v| v.as_string())
661 .unwrap_or_else(|| field.name())
662 .to_string();
663
664 ViewFieldDiff {
665 name: field.name().to_string(),
666 column_name,
667 sql_type: field_type_to_sql(&field.field_type),
668 nullable: field.is_optional(),
669 }
670 })
671 .collect();
672
673 Some(ViewDiff {
674 name: view.name().to_string(),
675 view_name: view.view_name().to_string(),
676 sql_query,
677 is_materialized: view.is_materialized(),
678 refresh_interval: view.refresh_interval().map(|s| s.to_string()),
679 fields,
680 })
681}
682
683fn diff_fields(source: &Field, target: &Field) -> Option<FieldAlterDiff> {
685 let source_type = field_type_to_sql(&source.field_type);
686 let target_type = field_type_to_sql(&target.field_type);
687
688 let source_nullable = source.is_optional();
689 let target_nullable = target.is_optional();
690
691 let type_changed = source_type != target_type;
692 let nullable_changed = source_nullable != target_nullable;
693
694 if !type_changed && !nullable_changed {
695 return None;
696 }
697
698 let column_name = target
700 .get_attribute("map")
701 .and_then(|attr| attr.first_arg())
702 .and_then(|v| v.as_string())
703 .unwrap_or_else(|| target.name())
704 .to_string();
705
706 Some(FieldAlterDiff {
707 name: target.name().to_string(),
708 column_name,
709 old_type: if type_changed {
710 Some(source_type)
711 } else {
712 None
713 },
714 new_type: if type_changed {
715 Some(target_type)
716 } else {
717 None
718 },
719 old_nullable: if nullable_changed {
720 Some(source_nullable)
721 } else {
722 None
723 },
724 new_nullable: if nullable_changed {
725 Some(target_nullable)
726 } else {
727 None
728 },
729 old_default: None,
730 new_default: None,
731 })
732}
733
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a non-materialized view diff with no refresh interval and no fields.
    fn plain_view(name: &str, view_name: &str, sql: &str) -> ViewDiff {
        ViewDiff {
            name: name.to_string(),
            view_name: view_name.to_string(),
            sql_query: sql.to_string(),
            is_materialized: false,
            refresh_interval: None,
            fields: vec![],
        }
    }

    /// Builds a view column description.
    fn view_field(name: &str, column_name: &str, sql_type: &str, nullable: bool) -> ViewFieldDiff {
        ViewFieldDiff {
            name: name.to_string(),
            column_name: column_name.to_string(),
            sql_type: sql_type.to_string(),
            nullable,
        }
    }

    // A default diff carries no changes.
    #[test]
    fn test_schema_diff_empty() {
        assert!(SchemaDiff::default().is_empty());
    }

    // A created model shows up in the summary.
    #[test]
    fn test_schema_diff_summary() {
        let model = ModelDiff {
            name: "User".to_string(),
            table_name: "users".to_string(),
            fields: Vec::new(),
            primary_key: Vec::new(),
            indexes: Vec::new(),
            unique_constraints: Vec::new(),
        };
        let diff = SchemaDiff {
            create_models: vec![model],
            ..SchemaDiff::default()
        };

        assert!(diff.summary().contains("Create 1 models"));
    }

    // A created view makes the diff non-empty and shows up in the summary.
    #[test]
    fn test_schema_diff_with_views() {
        let diff = SchemaDiff {
            create_views: vec![plain_view(
                "UserStats",
                "user_stats",
                "SELECT id, COUNT(*) FROM users GROUP BY id",
            )],
            ..SchemaDiff::default()
        };

        assert!(!diff.is_empty());
        assert!(diff.summary().contains("Create 1 views"));
    }

    // Created, dropped and altered views are each reported.
    #[test]
    fn test_schema_diff_summary_with_multiple() {
        let altered = ViewDiff {
            is_materialized: true,
            refresh_interval: Some("1h".to_string()),
            ..plain_view("View2", "view2", "SELECT 2")
        };
        let diff = SchemaDiff {
            create_views: vec![plain_view("View1", "view1", "SELECT 1")],
            drop_views: vec!["old_view".to_string()],
            alter_views: vec![altered],
            ..SchemaDiff::default()
        };

        let summary = diff.summary();
        for expected in ["Create 1 views", "Drop 1 views", "Alter 1 views"].iter() {
            assert!(summary.contains(expected));
        }
    }

    // View fields keep their order and their mapped column names.
    #[test]
    fn test_view_diff_fields() {
        let view_diff = ViewDiff {
            fields: vec![
                view_field("id", "id", "INTEGER", false),
                view_field("name", "user_name", "TEXT", true),
            ],
            ..plain_view("UserStats", "user_stats", "SELECT id, name FROM users")
        };

        assert_eq!(view_diff.fields.len(), 2);
        assert_eq!(view_diff.fields[0].name, "id");
        assert_eq!(view_diff.fields[1].column_name, "user_name");
    }
}