1use std::collections::HashMap;
4
5use prax_schema::Schema;
6use prax_schema::ast::{Field, Model, View};
7
8use crate::error::MigrateResult;
9
10#[derive(Debug, Clone, Default)]
12pub struct SchemaDiff {
13 pub create_models: Vec<ModelDiff>,
15 pub drop_models: Vec<String>,
17 pub alter_models: Vec<ModelAlterDiff>,
19 pub create_enums: Vec<EnumDiff>,
21 pub drop_enums: Vec<String>,
23 pub alter_enums: Vec<EnumAlterDiff>,
25 pub create_views: Vec<ViewDiff>,
27 pub drop_views: Vec<String>,
29 pub alter_views: Vec<ViewDiff>,
31 pub create_indexes: Vec<IndexDiff>,
33 pub drop_indexes: Vec<IndexDiff>,
35}
36
37impl SchemaDiff {
38 pub fn is_empty(&self) -> bool {
40 self.create_models.is_empty()
41 && self.drop_models.is_empty()
42 && self.alter_models.is_empty()
43 && self.create_enums.is_empty()
44 && self.drop_enums.is_empty()
45 && self.alter_enums.is_empty()
46 && self.create_views.is_empty()
47 && self.drop_views.is_empty()
48 && self.alter_views.is_empty()
49 && self.create_indexes.is_empty()
50 && self.drop_indexes.is_empty()
51 }
52
53 pub fn summary(&self) -> String {
55 let mut parts = Vec::new();
56
57 if !self.create_models.is_empty() {
58 parts.push(format!("Create {} models", self.create_models.len()));
59 }
60 if !self.drop_models.is_empty() {
61 parts.push(format!("Drop {} models", self.drop_models.len()));
62 }
63 if !self.alter_models.is_empty() {
64 parts.push(format!("Alter {} models", self.alter_models.len()));
65 }
66 if !self.create_enums.is_empty() {
67 parts.push(format!("Create {} enums", self.create_enums.len()));
68 }
69 if !self.drop_enums.is_empty() {
70 parts.push(format!("Drop {} enums", self.drop_enums.len()));
71 }
72 if !self.create_views.is_empty() {
73 parts.push(format!("Create {} views", self.create_views.len()));
74 }
75 if !self.drop_views.is_empty() {
76 parts.push(format!("Drop {} views", self.drop_views.len()));
77 }
78 if !self.alter_views.is_empty() {
79 parts.push(format!("Alter {} views", self.alter_views.len()));
80 }
81 if !self.create_indexes.is_empty() {
82 parts.push(format!("Create {} indexes", self.create_indexes.len()));
83 }
84 if !self.drop_indexes.is_empty() {
85 parts.push(format!("Drop {} indexes", self.drop_indexes.len()));
86 }
87
88 if parts.is_empty() {
89 "No changes".to_string()
90 } else {
91 parts.join(", ")
92 }
93 }
94}
95
96#[derive(Debug, Clone)]
98pub struct ModelDiff {
99 pub name: String,
101 pub table_name: String,
103 pub fields: Vec<FieldDiff>,
105 pub primary_key: Vec<String>,
107 pub indexes: Vec<IndexDiff>,
109 pub unique_constraints: Vec<UniqueConstraint>,
111}
112
113#[derive(Debug, Clone)]
115pub struct ModelAlterDiff {
116 pub name: String,
118 pub table_name: String,
120 pub add_fields: Vec<FieldDiff>,
122 pub drop_fields: Vec<String>,
124 pub alter_fields: Vec<FieldAlterDiff>,
126 pub add_indexes: Vec<IndexDiff>,
128 pub drop_indexes: Vec<String>,
130}
131
/// Definition of a single column that has to be created.
#[derive(Clone, Debug)]
pub struct FieldDiff {
    /// Field name as declared in the schema.
    pub name: String,
    /// Name of the database column backing the field.
    pub column_name: String,
    /// SQL type of the column.
    pub sql_type: String,
    /// Whether the column accepts `NULL`.
    pub nullable: bool,
    /// Default value of the column, if one is declared.
    pub default: Option<String>,
    /// Whether the field is marked as a primary key.
    pub is_primary_key: bool,
    /// Whether the field is marked as auto-incrementing.
    pub is_auto_increment: bool,
    /// Whether the field is marked as unique.
    pub is_unique: bool,
}
152
/// Changes to apply to an existing column.
///
/// Each `old_*` / `new_*` pair describes one aspect of the column; a pair is
/// left as `None` when that aspect is not recorded as changing.
#[derive(Clone, Debug)]
pub struct FieldAlterDiff {
    /// Field name as declared in the schema.
    pub name: String,
    /// Name of the database column backing the field.
    pub column_name: String,
    /// SQL type before the change.
    pub old_type: Option<String>,
    /// SQL type after the change.
    pub new_type: Option<String>,
    /// Nullability before the change.
    pub old_nullable: Option<bool>,
    /// Nullability after the change.
    pub new_nullable: Option<bool>,
    /// Default value before the change.
    pub old_default: Option<String>,
    /// Default value after the change.
    pub new_default: Option<String>,
}
173
/// Description of an enum that has to be created.
#[derive(Clone, Debug)]
pub struct EnumDiff {
    /// Enum name as declared in the schema.
    pub name: String,
    /// Names of the enum's values.
    pub values: Vec<String>,
}
182
/// Changes to apply to the values of an existing enum.
#[derive(Clone, Debug)]
pub struct EnumAlterDiff {
    /// Enum name as declared in the schema.
    pub name: String,
    /// Values that have to be added.
    pub add_values: Vec<String>,
    /// Values that have to be removed.
    pub remove_values: Vec<String>,
}
193
/// Description of a single index.
#[derive(Clone, Debug)]
pub struct IndexDiff {
    /// Name of the index.
    pub name: String,
    /// Name of the table the index belongs to.
    pub table_name: String,
    /// Columns covered by the index.
    pub columns: Vec<String>,
    /// Whether the index enforces uniqueness.
    pub unique: bool,
}
206
/// A uniqueness constraint spanning one or more columns.
#[derive(Clone, Debug)]
pub struct UniqueConstraint {
    /// Name of the constraint, when it has one.
    pub name: Option<String>,
    /// Columns covered by the constraint.
    pub columns: Vec<String>,
}
215
216#[derive(Debug, Clone)]
218pub struct ViewDiff {
219 pub name: String,
221 pub view_name: String,
223 pub sql_query: String,
225 pub is_materialized: bool,
227 pub refresh_interval: Option<String>,
229 pub fields: Vec<ViewFieldDiff>,
231}
232
/// Description of a single column exposed by a view.
#[derive(Clone, Debug)]
pub struct ViewFieldDiff {
    /// Field name as declared in the schema.
    pub name: String,
    /// Name of the column in the view.
    pub column_name: String,
    /// SQL type of the column.
    pub sql_type: String,
    /// Whether the column may be `NULL`.
    pub nullable: bool,
}
245
246pub struct SchemaDiffer {
248 source: Option<Schema>,
250 target: Schema,
252}
253
254impl SchemaDiffer {
255 pub fn new(target: Schema) -> Self {
257 Self {
258 source: None,
259 target,
260 }
261 }
262
263 pub fn with_source(mut self, source: Schema) -> Self {
265 self.source = Some(source);
266 self
267 }
268
269 pub fn diff(&self) -> MigrateResult<SchemaDiff> {
271 let mut result = SchemaDiff::default();
272
273 let source_models: HashMap<&str, &Model> = self
274 .source
275 .as_ref()
276 .map(|s| s.models.values().map(|m| (m.name(), m)).collect())
277 .unwrap_or_default();
278
279 let target_models: HashMap<&str, &Model> =
280 self.target.models.values().map(|m| (m.name(), m)).collect();
281
282 for (name, model) in &target_models {
284 if !source_models.contains_key(name) {
285 result.create_models.push(model_to_diff(model));
286 }
287 }
288
289 for name in source_models.keys() {
291 if !target_models.contains_key(name) {
292 result.drop_models.push((*name).to_string());
293 }
294 }
295
296 for (name, target_model) in &target_models {
298 if let Some(source_model) = source_models.get(name)
299 && let Some(alter) = diff_models(source_model, target_model)
300 {
301 result.alter_models.push(alter);
302 }
303 }
304
305 let source_enums: HashMap<&str, _> = self
307 .source
308 .as_ref()
309 .map(|s| s.enums.values().map(|e| (e.name(), e)).collect())
310 .unwrap_or_default();
311
312 let target_enums: HashMap<&str, _> =
313 self.target.enums.values().map(|e| (e.name(), e)).collect();
314
315 for (name, enum_def) in &target_enums {
316 if !source_enums.contains_key(name) {
317 result.create_enums.push(EnumDiff {
318 name: (*name).to_string(),
319 values: enum_def
320 .variants
321 .iter()
322 .map(|v| v.name.to_string())
323 .collect(),
324 });
325 }
326 }
327
328 for name in source_enums.keys() {
329 if !target_enums.contains_key(name) {
330 result.drop_enums.push((*name).to_string());
331 }
332 }
333
334 let source_views: HashMap<&str, &View> = self
336 .source
337 .as_ref()
338 .map(|s| s.views.values().map(|v| (v.name(), v)).collect())
339 .unwrap_or_default();
340
341 let target_views: HashMap<&str, &View> =
342 self.target.views.values().map(|v| (v.name(), v)).collect();
343
344 for (name, view) in &target_views {
346 if !source_views.contains_key(name) {
347 if let Some(view_diff) = view_to_diff(view) {
348 result.create_views.push(view_diff);
349 }
350 }
351 }
352
353 for name in source_views.keys() {
355 if !target_views.contains_key(name) {
356 result.drop_views.push((*name).to_string());
357 }
358 }
359
360 for (name, target_view) in &target_views {
362 if let Some(source_view) = source_views.get(name) {
363 let source_sql = source_view.sql_query();
365 let target_sql = target_view.sql_query();
366
367 let sql_changed = source_sql != target_sql;
369 let materialized_changed =
370 source_view.is_materialized() != target_view.is_materialized();
371
372 if sql_changed || materialized_changed {
373 if let Some(view_diff) = view_to_diff(target_view) {
374 result.alter_views.push(view_diff);
375 }
376 }
377 }
378 }
379
380 Ok(result)
381 }
382}
383
384fn model_to_diff(model: &Model) -> ModelDiff {
386 let fields: Vec<FieldDiff> = model
387 .fields
388 .values()
389 .filter(|f| !f.is_relation())
390 .map(field_to_diff)
391 .collect();
392
393 let primary_key: Vec<String> = model
394 .fields
395 .values()
396 .filter(|f| f.has_attribute("id"))
397 .map(|f| f.name().to_string())
398 .collect();
399
400 ModelDiff {
401 name: model.name().to_string(),
402 table_name: model.table_name().to_string(),
403 fields,
404 primary_key,
405 indexes: Vec::new(),
406 unique_constraints: Vec::new(),
407 }
408}
409
410fn field_to_diff(field: &Field) -> FieldDiff {
412 let sql_type = field_type_to_sql(&field.field_type);
413 let nullable = field.is_optional();
414 let is_primary_key = field.has_attribute("id");
415 let is_auto_increment = field.has_attribute("auto");
416 let is_unique = field.has_attribute("unique");
417
418 let default = field
419 .get_attribute("default")
420 .and_then(|attr| attr.first_arg())
421 .map(|arg| format!("{:?}", arg));
422
423 let column_name = field
425 .get_attribute("map")
426 .and_then(|attr| attr.first_arg())
427 .and_then(|v| v.as_string())
428 .unwrap_or_else(|| field.name())
429 .to_string();
430
431 FieldDiff {
432 name: field.name().to_string(),
433 column_name,
434 sql_type,
435 nullable,
436 default,
437 is_primary_key,
438 is_auto_increment,
439 is_unique,
440 }
441}
442
443fn field_type_to_sql(field_type: &prax_schema::ast::FieldType) -> String {
445 use prax_schema::ast::{FieldType, ScalarType};
446
447 match field_type {
448 FieldType::Scalar(scalar) => match scalar {
449 ScalarType::Int => "INTEGER".to_string(),
450 ScalarType::BigInt => "BIGINT".to_string(),
451 ScalarType::Float => "DOUBLE PRECISION".to_string(),
452 ScalarType::Decimal => "DECIMAL".to_string(),
453 ScalarType::String => "TEXT".to_string(),
454 ScalarType::Boolean => "BOOLEAN".to_string(),
455 ScalarType::DateTime => "TIMESTAMP WITH TIME ZONE".to_string(),
456 ScalarType::Date => "DATE".to_string(),
457 ScalarType::Time => "TIME".to_string(),
458 ScalarType::Json => "JSONB".to_string(),
459 ScalarType::Bytes => "BYTEA".to_string(),
460 ScalarType::Uuid => "UUID".to_string(),
461 ScalarType::Cuid | ScalarType::Cuid2 | ScalarType::NanoId | ScalarType::Ulid => {
463 "TEXT".to_string()
464 }
465 },
466 FieldType::Model(name) => name.to_string(),
467 FieldType::Enum(name) => format!("\"{}\"", name),
468 FieldType::Composite(name) => name.to_string(),
469 FieldType::Unsupported(name) => name.to_string(),
470 }
471}
472
473fn diff_models(source: &Model, target: &Model) -> Option<ModelAlterDiff> {
475 let source_fields: HashMap<&str, &Field> = source
476 .fields
477 .values()
478 .filter(|f| !f.is_relation())
479 .map(|f| (f.name(), f))
480 .collect();
481
482 let target_fields: HashMap<&str, &Field> = target
483 .fields
484 .values()
485 .filter(|f| !f.is_relation())
486 .map(|f| (f.name(), f))
487 .collect();
488
489 let mut add_fields = Vec::new();
490 let mut drop_fields = Vec::new();
491 let mut alter_fields = Vec::new();
492
493 for (name, field) in &target_fields {
495 if !source_fields.contains_key(name) {
496 add_fields.push(field_to_diff(field));
497 }
498 }
499
500 for name in source_fields.keys() {
502 if !target_fields.contains_key(name) {
503 drop_fields.push((*name).to_string());
504 }
505 }
506
507 for (name, target_field) in &target_fields {
509 if let Some(source_field) = source_fields.get(name)
510 && let Some(alter) = diff_fields(source_field, target_field)
511 {
512 alter_fields.push(alter);
513 }
514 }
515
516 if add_fields.is_empty() && drop_fields.is_empty() && alter_fields.is_empty() {
517 None
518 } else {
519 Some(ModelAlterDiff {
520 name: target.name().to_string(),
521 table_name: target.table_name().to_string(),
522 add_fields,
523 drop_fields,
524 alter_fields,
525 add_indexes: Vec::new(),
526 drop_indexes: Vec::new(),
527 })
528 }
529}
530
531fn view_to_diff(view: &View) -> Option<ViewDiff> {
533 let sql_query = view.sql_query()?.to_string();
535
536 let fields: Vec<ViewFieldDiff> = view
537 .fields
538 .values()
539 .map(|field| {
540 let column_name = field
541 .get_attribute("map")
542 .and_then(|attr| attr.first_arg())
543 .and_then(|v| v.as_string())
544 .unwrap_or_else(|| field.name())
545 .to_string();
546
547 ViewFieldDiff {
548 name: field.name().to_string(),
549 column_name,
550 sql_type: field_type_to_sql(&field.field_type),
551 nullable: field.is_optional(),
552 }
553 })
554 .collect();
555
556 Some(ViewDiff {
557 name: view.name().to_string(),
558 view_name: view.view_name().to_string(),
559 sql_query,
560 is_materialized: view.is_materialized(),
561 refresh_interval: view.refresh_interval().map(|s| s.to_string()),
562 fields,
563 })
564}
565
566fn diff_fields(source: &Field, target: &Field) -> Option<FieldAlterDiff> {
568 let source_type = field_type_to_sql(&source.field_type);
569 let target_type = field_type_to_sql(&target.field_type);
570
571 let source_nullable = source.is_optional();
572 let target_nullable = target.is_optional();
573
574 let type_changed = source_type != target_type;
575 let nullable_changed = source_nullable != target_nullable;
576
577 if !type_changed && !nullable_changed {
578 return None;
579 }
580
581 let column_name = target
583 .get_attribute("map")
584 .and_then(|attr| attr.first_arg())
585 .and_then(|v| v.as_string())
586 .unwrap_or_else(|| target.name())
587 .to_string();
588
589 Some(FieldAlterDiff {
590 name: target.name().to_string(),
591 column_name,
592 old_type: if type_changed {
593 Some(source_type)
594 } else {
595 None
596 },
597 new_type: if type_changed {
598 Some(target_type)
599 } else {
600 None
601 },
602 old_nullable: if nullable_changed {
603 Some(source_nullable)
604 } else {
605 None
606 },
607 new_nullable: if nullable_changed {
608 Some(target_nullable)
609 } else {
610 None
611 },
612 old_default: None,
613 new_default: None,
614 })
615}
616
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a non-materialized view diff without fields or refresh interval.
    fn plain_view(name: &str, view_name: &str, sql: &str) -> ViewDiff {
        ViewDiff {
            name: name.to_string(),
            view_name: view_name.to_string(),
            sql_query: sql.to_string(),
            is_materialized: false,
            refresh_interval: None,
            fields: vec![],
        }
    }

    /// Builds a view field description from string slices.
    fn view_field(name: &str, column_name: &str, sql_type: &str, nullable: bool) -> ViewFieldDiff {
        ViewFieldDiff {
            name: name.to_string(),
            column_name: column_name.to_string(),
            sql_type: sql_type.to_string(),
            nullable,
        }
    }

    // A default diff holds no changes at all.
    #[test]
    fn test_schema_diff_empty() {
        assert!(SchemaDiff::default().is_empty());
    }

    // A created model shows up in the summary.
    #[test]
    fn test_schema_diff_summary() {
        let diff = SchemaDiff {
            create_models: vec![ModelDiff {
                name: "User".to_string(),
                table_name: "users".to_string(),
                fields: Vec::new(),
                primary_key: Vec::new(),
                indexes: Vec::new(),
                unique_constraints: Vec::new(),
            }],
            ..SchemaDiff::default()
        };

        assert!(diff.summary().contains("Create 1 models"));
    }

    // A created view makes the diff non-empty and is summarised.
    #[test]
    fn test_schema_diff_with_views() {
        let diff = SchemaDiff {
            create_views: vec![plain_view(
                "UserStats",
                "user_stats",
                "SELECT id, COUNT(*) FROM users GROUP BY id",
            )],
            ..SchemaDiff::default()
        };

        assert!(!diff.is_empty());
        assert!(diff.summary().contains("Create 1 views"));
    }

    // Created, dropped and altered views are each reported.
    #[test]
    fn test_schema_diff_summary_with_multiple() {
        let materialized = ViewDiff {
            is_materialized: true,
            refresh_interval: Some("1h".to_string()),
            ..plain_view("View2", "view2", "SELECT 2")
        };
        let diff = SchemaDiff {
            create_views: vec![plain_view("View1", "view1", "SELECT 1")],
            drop_views: vec!["old_view".to_string()],
            alter_views: vec![materialized],
            ..SchemaDiff::default()
        };

        let summary = diff.summary();
        for expected in ["Create 1 views", "Drop 1 views", "Alter 1 views"] {
            assert!(summary.contains(expected));
        }
    }

    // View fields keep their order and their mapped column names.
    #[test]
    fn test_view_diff_fields() {
        let view_diff = ViewDiff {
            fields: vec![
                view_field("id", "id", "INTEGER", false),
                view_field("name", "user_name", "TEXT", true),
            ],
            ..plain_view("UserStats", "user_stats", "SELECT id, name FROM users")
        };

        assert_eq!(view_diff.fields.len(), 2);
        assert_eq!(view_diff.fields[0].name, "id");
        assert_eq!(view_diff.fields[1].column_name, "user_name");
    }
}