1use std::collections::{BTreeMap, BTreeSet};
2
3#[cfg(feature = "postgres")]
4use sqlx::PgPool;
5#[cfg(feature = "sqlite")]
6use sqlx::SqlitePool;
7
/// A single column of a table schema, as declared by a model or
/// discovered by database introspection.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SchemaColumn {
    /// Column name.
    pub name: String,
    /// Declared SQL type as written (e.g. `TEXT`, `INTEGER`); compared
    /// across dialects via `normalized_type`.
    pub sql_type: String,
    /// Whether the column accepts NULL values.
    pub nullable: bool,
    /// Whether the column is (part of) the primary key.
    pub primary_key: bool,
}
20
21impl SchemaColumn {
22 fn normalized_type(&self) -> String {
23 normalize_sql_type(&self.sql_type)
24 }
25}
26
/// A (non-primary-key) index on a table.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SchemaIndex {
    /// Index name; may be empty, in which case a deterministic name is
    /// generated when rendering `CREATE INDEX`.
    pub name: String,
    /// Indexed columns, in index order.
    pub columns: Vec<String>,
    /// Whether the index enforces uniqueness.
    pub unique: bool,
}
37
/// A single-column foreign-key constraint.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SchemaForeignKey {
    /// Referencing column on this table.
    pub column: String,
    /// Referenced table.
    pub ref_table: String,
    /// Referenced column on `ref_table`.
    pub ref_column: String,
}
48
/// A complete table schema: columns, indexes, and foreign keys.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SchemaTable {
    /// Table name.
    pub name: String,
    /// Columns in declaration order.
    pub columns: Vec<SchemaColumn>,
    /// Secondary indexes (primary keys are tracked on the columns).
    pub indexes: Vec<SchemaIndex>,
    /// Foreign-key constraints.
    pub foreign_keys: Vec<SchemaForeignKey>,
    /// Optional verbatim `CREATE TABLE` SQL; when set it overrides the
    /// generated statement in `to_create_sql`.
    pub create_sql: Option<String>,
}
63
64impl SchemaTable {
65 pub fn column(&self, name: &str) -> Option<&SchemaColumn> {
67 self.columns.iter().find(|c| c.name == name)
68 }
69
70 pub fn to_create_sql(&self) -> String {
72 if let Some(sql) = &self.create_sql {
73 return sql.clone();
74 }
75
76 let mut cols = Vec::new();
77 for col in &self.columns {
78 if col.primary_key {
79 cols.push(format!("{} INTEGER PRIMARY KEY", col.name));
80 continue;
81 }
82 let mut def = format!("{} {}", col.name, col.sql_type);
83 if !col.nullable {
84 def.push_str(" NOT NULL");
85 }
86 cols.push(def);
87 }
88
89 format!(
90 "CREATE TABLE IF NOT EXISTS {} ({})",
91 self.name,
92 cols.join(", ")
93 )
94 }
95}
96
/// Implemented by model types that can describe their own table schema.
pub trait ModelSchema {
    /// Returns the expected [`SchemaTable`] for this model.
    fn schema() -> SchemaTable;
}
102
/// Collects the schemas of the listed model types into a `Vec<SchemaTable>`.
///
/// Every argument must implement `ModelSchema`; a trailing comma is allowed.
#[macro_export]
macro_rules! schema_models {
    ($($model:ty),+ $(,)?) => {
        vec![$(<$model as $crate::schema::ModelSchema>::schema()),+]
    };
}
110
/// A column that exists on one side of the diff but not the other.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ColumnDiff {
    /// Owning table.
    pub table: String,
    /// Column name.
    pub column: String,
    /// Normalized type family of the column, when known; used by the
    /// migration generators for rename heuristics.
    pub sql_type: Option<String>,
}
121
/// A column whose normalized types differ between expected and actual.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ColumnTypeDiff {
    /// Owning table.
    pub table: String,
    /// Column name.
    pub column: String,
    /// Type declared by the model (as written, not normalized).
    pub expected: String,
    /// Type reported by the database (as written, not normalized).
    pub actual: String,
}
134
/// A column whose NULL-ability differs between expected and actual.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ColumnNullabilityDiff {
    /// Owning table.
    pub table: String,
    /// Column name.
    pub column: String,
    /// NULL-ability declared by the model.
    pub expected_nullable: bool,
    /// NULL-ability reported by the database.
    pub actual_nullable: bool,
}
147
/// A column whose primary-key membership differs between expected and actual.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ColumnPrimaryKeyDiff {
    /// Owning table.
    pub table: String,
    /// Column name.
    pub column: String,
    /// Primary-key flag declared by the model.
    pub expected_primary_key: bool,
    /// Primary-key flag reported by the database.
    pub actual_primary_key: bool,
}
160
/// The complete set of discrepancies between an expected schema (models)
/// and an actual schema (database introspection).
///
/// Produced by [`diff_schema`]; consumed by the `*_migration_sql` helpers.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct SchemaDiff {
    /// Tables defined by models but absent from the database.
    pub missing_tables: Vec<String>,
    /// Tables present in the database but not defined by models.
    pub extra_tables: Vec<String>,
    /// Columns defined by models but absent from the database.
    pub missing_columns: Vec<ColumnDiff>,
    /// Columns present in the database but not defined by models.
    pub extra_columns: Vec<ColumnDiff>,
    /// Columns whose normalized SQL types disagree.
    pub type_mismatches: Vec<ColumnTypeDiff>,
    /// Columns whose NULL-ability disagrees.
    pub nullability_mismatches: Vec<ColumnNullabilityDiff>,
    /// Columns whose primary-key membership disagrees.
    pub primary_key_mismatches: Vec<ColumnPrimaryKeyDiff>,
    /// Indexes defined by models but absent, as `(table, index)`.
    pub missing_indexes: Vec<(String, SchemaIndex)>,
    /// Indexes present but not defined by models, as `(table, index)`.
    pub extra_indexes: Vec<(String, SchemaIndex)>,
    /// Foreign keys defined by models but absent, as `(table, fk)`.
    pub missing_foreign_keys: Vec<(String, SchemaForeignKey)>,
    /// Foreign keys present but not defined by models, as `(table, fk)`.
    pub extra_foreign_keys: Vec<(String, SchemaForeignKey)>,
}
187
188impl SchemaDiff {
189 pub fn is_empty(&self) -> bool {
191 self.missing_tables.is_empty()
192 && self.extra_tables.is_empty()
193 && self.missing_columns.is_empty()
194 && self.extra_columns.is_empty()
195 && self.type_mismatches.is_empty()
196 && self.nullability_mismatches.is_empty()
197 && self.primary_key_mismatches.is_empty()
198 && self.missing_indexes.is_empty()
199 && self.extra_indexes.is_empty()
200 && self.missing_foreign_keys.is_empty()
201 && self.extra_foreign_keys.is_empty()
202 }
203}
204
205pub fn format_schema_diff_summary(diff: &SchemaDiff) -> String {
207 if diff.is_empty() {
208 return "Schema diff: no changes".to_string();
209 }
210
211 let mut lines = Vec::new();
212 lines.push("Schema diff summary:".to_string());
213 lines.push(format!(" missing tables: {}", diff.missing_tables.len()));
214 lines.push(format!(" extra tables: {}", diff.extra_tables.len()));
215 lines.push(format!(" missing columns: {}", diff.missing_columns.len()));
216 lines.push(format!(" extra columns: {}", diff.extra_columns.len()));
217 lines.push(format!(" type mismatches: {}", diff.type_mismatches.len()));
218 lines.push(format!(
219 " nullability mismatches: {}",
220 diff.nullability_mismatches.len()
221 ));
222 lines.push(format!(
223 " primary key mismatches: {}",
224 diff.primary_key_mismatches.len()
225 ));
226 lines.push(format!(" missing indexes: {}", diff.missing_indexes.len()));
227 lines.push(format!(" extra indexes: {}", diff.extra_indexes.len()));
228 lines.push(format!(
229 " missing foreign keys: {}",
230 diff.missing_foreign_keys.len()
231 ));
232 lines.push(format!(
233 " extra foreign keys: {}",
234 diff.extra_foreign_keys.len()
235 ));
236
237 if !diff.missing_tables.is_empty() {
238 lines.push(format!(
239 " missing tables list: {}",
240 diff.missing_tables.join(", ")
241 ));
242 }
243 if !diff.extra_tables.is_empty() {
244 lines.push(format!(
245 " extra tables list: {}",
246 diff.extra_tables.join(", ")
247 ));
248 }
249
250 lines.join("\n")
251}
252
/// Introspects the live SQLite database into a list of [`SchemaTable`]s.
///
/// Skips SQLite internal tables (`sqlite_%`) and the migration bookkeeping
/// table `_premix_migrations`. Indexes and foreign keys are resolved per
/// table via separate PRAGMA queries.
#[cfg(feature = "sqlite")]
pub async fn introspect_sqlite_schema(pool: &SqlitePool) -> Result<Vec<SchemaTable>, sqlx::Error> {
    let table_names: Vec<String> = sqlx::query_scalar(
        "SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%' AND name != '_premix_migrations' ORDER BY name",
    )
    .fetch_all(pool)
    .await?;

    let mut tables = Vec::new();
    for name in table_names {
        // PRAGMA cannot take bind parameters; the name is interpolated, but
        // it comes from sqlite_master above, not from user input.
        let pragma_sql = format!("PRAGMA table_info({})", name);
        // Tuple layout of PRAGMA table_info rows:
        // (cid, name, type, notnull, dflt_value, pk)
        let rows: Vec<(i64, String, String, i64, Option<String>, i64)> =
            sqlx::query_as(&pragma_sql).fetch_all(pool).await?;

        // Defensive: skip a name that yields no column info.
        if rows.is_empty() {
            continue;
        }

        let columns = rows
            .into_iter()
            .map(|(_cid, col_name, col_type, notnull, _default, pk)| {
                // pk > 0 means the column is part of the primary key
                // (value is its 1-based position within the key).
                let is_pk = pk > 0;
                SchemaColumn {
                    name: col_name,
                    sql_type: col_type,
                    // Primary-key columns are treated as NOT NULL here.
                    nullable: !is_pk && notnull == 0,
                    primary_key: is_pk,
                }
            })
            .collect();

        let indexes = introspect_sqlite_indexes(pool, &name).await?;
        let foreign_keys = introspect_sqlite_foreign_keys(pool, &name).await?;

        tables.push(SchemaTable {
            name,
            columns,
            indexes,
            foreign_keys,
            create_sql: None,
        });
    }

    Ok(tables)
}
299
/// Introspects the live PostgreSQL database (schema `public`) into a list
/// of [`SchemaTable`]s.
///
/// Skips the migration bookkeeping table `_premix_migrations`. Primary-key
/// membership is read from `pg_index`; column metadata from
/// `information_schema.columns`.
#[cfg(feature = "postgres")]
pub async fn introspect_postgres_schema(pool: &PgPool) -> Result<Vec<SchemaTable>, sqlx::Error> {
    let table_names: Vec<String> = sqlx::query_scalar(
        "SELECT table_name FROM information_schema.tables WHERE table_schema='public' AND table_type='BASE TABLE' AND table_name != '_premix_migrations' ORDER BY table_name",
    )
    .fetch_all(pool)
    .await?;

    let mut tables = Vec::new();
    for name in table_names {
        // Columns participating in the table's primary key.
        let pk_cols: Vec<String> = sqlx::query_scalar(
            "SELECT a.attname FROM pg_index i JOIN pg_attribute a ON a.attrelid=i.indrelid AND a.attnum = ANY(i.indkey) WHERE i.indrelid=$1::regclass AND i.indisprimary",
        )
        .bind(&name)
        .fetch_all(pool)
        .await?;
        let pk_set: BTreeSet<String> = pk_cols.into_iter().collect();

        // (column_name, data_type, udt_name, is_nullable) per column.
        let rows: Vec<(String, String, String, String)> = sqlx::query_as(
            "SELECT column_name, data_type, udt_name, is_nullable FROM information_schema.columns WHERE table_schema='public' AND table_name=$1 ORDER BY ordinal_position",
        )
        .bind(&name)
        .fetch_all(pool)
        .await?;

        if rows.is_empty() {
            continue;
        }

        let columns = rows
            .into_iter()
            .map(|(col_name, data_type, udt_name, is_nullable)| {
                let is_pk = pk_set.contains(&col_name);
                // information_schema reports generic placeholders for some
                // types; recover the concrete type from udt_name:
                //  - arrays: data_type='ARRAY', udt_name='_int4' etc.
                //  - enums/domains: data_type='USER-DEFINED'.
                let sql_type = if data_type.eq_ignore_ascii_case("ARRAY") {
                    let base = udt_name.trim_start_matches('_');
                    format!("{}[]", base)
                } else if data_type.eq_ignore_ascii_case("USER-DEFINED") {
                    udt_name
                } else {
                    data_type
                };
                SchemaColumn {
                    name: col_name,
                    sql_type,
                    // Primary-key columns are treated as NOT NULL here.
                    nullable: !is_pk && is_nullable.eq_ignore_ascii_case("YES"),
                    primary_key: is_pk,
                }
            })
            .collect();

        let indexes = introspect_postgres_indexes(pool, &name).await?;
        let foreign_keys = introspect_postgres_foreign_keys(pool, &name).await?;

        tables.push(SchemaTable {
            name,
            columns,
            indexes,
            foreign_keys,
            create_sql: None,
        });
    }

    Ok(tables)
}
365
/// Introspects the SQLite database and diffs it against `expected`.
#[cfg(feature = "sqlite")]
pub async fn diff_sqlite_schema(
    pool: &SqlitePool,
    expected: &[SchemaTable],
) -> Result<SchemaDiff, sqlx::Error> {
    introspect_sqlite_schema(pool)
        .await
        .map(|actual| diff_schema(expected, &actual))
}
375
/// Introspects the PostgreSQL database and diffs it against `expected`.
#[cfg(feature = "postgres")]
pub async fn diff_postgres_schema(
    pool: &PgPool,
    expected: &[SchemaTable],
) -> Result<SchemaDiff, sqlx::Error> {
    introspect_postgres_schema(pool)
        .await
        .map(|actual| diff_schema(expected, &actual))
}
385
386pub fn diff_schema(expected: &[SchemaTable], actual: &[SchemaTable]) -> SchemaDiff {
388 let mut diff = SchemaDiff::default();
389
390 let expected_map: BTreeMap<_, _> = expected.iter().map(|t| (&t.name, t)).collect();
391 let actual_map: BTreeMap<_, _> = actual.iter().map(|t| (&t.name, t)).collect();
392
393 for name in expected_map.keys() {
394 if !actual_map.contains_key(name) {
395 diff.missing_tables.push((*name).to_string());
396 }
397 }
398 for name in actual_map.keys() {
399 if !expected_map.contains_key(name) {
400 diff.extra_tables.push((*name).to_string());
401 }
402 }
403
404 for (name, expected_table) in &expected_map {
405 let Some(actual_table) = actual_map.get(name) else {
406 continue;
407 };
408
409 let expected_cols: BTreeMap<_, _> = expected_table
410 .columns
411 .iter()
412 .map(|c| (&c.name, c))
413 .collect();
414 let actual_cols: BTreeMap<_, _> =
415 actual_table.columns.iter().map(|c| (&c.name, c)).collect();
416
417 for col in expected_cols.keys() {
418 if !actual_cols.contains_key(col) {
419 let sql_type = expected_cols.get(col).map(|c| c.normalized_type());
420 diff.missing_columns.push(ColumnDiff {
421 table: (*name).to_string(),
422 column: (*col).to_string(),
423 sql_type,
424 });
425 }
426 }
427 for col in actual_cols.keys() {
428 if !expected_cols.contains_key(col) {
429 let sql_type = actual_cols.get(col).map(|c| c.normalized_type());
430 diff.extra_columns.push(ColumnDiff {
431 table: (*name).to_string(),
432 column: (*col).to_string(),
433 sql_type,
434 });
435 }
436 }
437
438 for (col_name, expected_col) in &expected_cols {
439 let Some(actual_col) = actual_cols.get(col_name) else {
440 continue;
441 };
442
443 let expected_type = expected_col.normalized_type();
444 let actual_type = actual_col.normalized_type();
445 if expected_type != actual_type {
446 diff.type_mismatches.push(ColumnTypeDiff {
447 table: (*name).to_string(),
448 column: (*col_name).to_string(),
449 expected: expected_col.sql_type.clone(),
450 actual: actual_col.sql_type.clone(),
451 });
452 }
453
454 if expected_col.nullable != actual_col.nullable {
455 diff.nullability_mismatches.push(ColumnNullabilityDiff {
456 table: (*name).to_string(),
457 column: (*col_name).to_string(),
458 expected_nullable: expected_col.nullable,
459 actual_nullable: actual_col.nullable,
460 });
461 }
462
463 if expected_col.primary_key != actual_col.primary_key {
464 diff.primary_key_mismatches.push(ColumnPrimaryKeyDiff {
465 table: (*name).to_string(),
466 column: (*col_name).to_string(),
467 expected_primary_key: expected_col.primary_key,
468 actual_primary_key: actual_col.primary_key,
469 });
470 }
471 }
472
473 let expected_indexes = index_map(&expected_table.indexes);
474 let actual_indexes = index_map(&actual_table.indexes);
475 for key in expected_indexes.keys() {
476 if !actual_indexes.contains_key(key) {
477 if let Some(index) = expected_indexes.get(key) {
478 diff.missing_indexes
479 .push(((*name).to_string(), (*index).clone()));
480 }
481 }
482 }
483 for key in actual_indexes.keys() {
484 if !expected_indexes.contains_key(key) {
485 if let Some(index) = actual_indexes.get(key) {
486 diff.extra_indexes
487 .push(((*name).to_string(), (*index).clone()));
488 }
489 }
490 }
491
492 let expected_fks = foreign_key_map(&expected_table.foreign_keys);
493 let actual_fks = foreign_key_map(&actual_table.foreign_keys);
494 for key in expected_fks.keys() {
495 if !actual_fks.contains_key(key) {
496 if let Some(fk) = expected_fks.get(key) {
497 diff.missing_foreign_keys
498 .push(((*name).to_string(), (*fk).clone()));
499 }
500 }
501 }
502 for key in actual_fks.keys() {
503 if !expected_fks.contains_key(key) {
504 if let Some(fk) = actual_fks.get(key) {
505 diff.extra_foreign_keys
506 .push(((*name).to_string(), (*fk).clone()));
507 }
508 }
509 }
510 }
511
512 diff.missing_tables.sort();
513 diff.extra_tables.sort();
514
515 diff
516}
517
518pub fn sqlite_migration_sql(expected: &[SchemaTable], diff: &SchemaDiff) -> Vec<String> {
520 let expected_map: BTreeMap<String, &SchemaTable> =
521 expected.iter().map(|t| (t.name.clone(), t)).collect();
522 let mut statements = Vec::new();
523
524 for table in &diff.missing_tables {
525 if let Some(schema) = expected_map.get(table) {
526 statements.push(schema.to_create_sql());
527 for index in &schema.indexes {
528 statements.push(sqlite_create_index_sql(&schema.name, index));
529 }
530 } else {
531 statements.push(format!("-- Missing schema for table {}", table));
532 }
533 }
534
535 let mut missing_by_table: BTreeMap<String, BTreeSet<String>> = BTreeMap::new();
536 for col in &diff.missing_columns {
537 missing_by_table
538 .entry(col.table.clone())
539 .or_default()
540 .insert(col.column.clone());
541 }
542
543 for (table, columns) in missing_by_table {
544 let Some(schema) = expected_map.get(&table) else {
545 continue;
546 };
547 for col_name in columns {
548 let Some(col) = schema.column(&col_name) else {
549 continue;
550 };
551
552 let potential_rename = diff.extra_columns.iter().find(|e| {
556 e.table == table && e.sql_type.as_deref() == Some(&col.normalized_type())
557 });
558
559 if let Some(old) = potential_rename {
560 statements.push(format!(
561 "-- SUGESTION: Potential rename from column '{}' to '{}'?",
562 old.column, col.name
563 ));
564 }
565
566 if col.primary_key {
567 statements.push(format!(
568 "-- TODO: add primary key column {}.{} manually",
569 table, col_name
570 ));
571 continue;
572 }
573
574 if !col.nullable {
575 statements.push(format!(
576 "-- WARNING: Adding NOT NULL column '{}.{}' without a default value will fail if table contains rows.",
577 table, col.name
578 ));
579 }
580
581 let mut stmt = format!(
582 "ALTER TABLE {} ADD COLUMN {} {}",
583 table, col.name, col.sql_type
584 );
585 if !col.nullable {
586 stmt.push_str(" NOT NULL");
587 }
588 statements.push(stmt);
589 }
590 }
591
592 for mismatch in &diff.type_mismatches {
593 statements.push(format!(
594 "-- TODO: column type mismatch {}.{} (expected {}, actual {})",
595 mismatch.table, mismatch.column, mismatch.expected, mismatch.actual
596 ));
597 }
598 for mismatch in &diff.nullability_mismatches {
599 statements.push(format!(
600 "-- TODO: column nullability mismatch {}.{} (expected nullable {}, actual nullable {})",
601 mismatch.table, mismatch.column, mismatch.expected_nullable, mismatch.actual_nullable
602 ));
603 }
604 for mismatch in &diff.primary_key_mismatches {
605 statements.push(format!(
606 "-- TODO: column primary key mismatch {}.{} (expected pk {}, actual pk {})",
607 mismatch.table,
608 mismatch.column,
609 mismatch.expected_primary_key,
610 mismatch.actual_primary_key
611 ));
612 }
613 for (table, index) in &diff.missing_indexes {
614 statements.push(sqlite_create_index_sql(table, index));
615 }
616 for (table, index) in &diff.extra_indexes {
617 statements.push(format!(
618 "-- TODO: extra index {}.{} ({})",
619 table,
620 index.name,
621 index.columns.join(", ")
622 ));
623 }
624 for (table, fk) in &diff.missing_foreign_keys {
625 statements.push(format!(
626 "-- TODO: add foreign key {}.{} -> {}({}) (requires table rebuild)",
627 table, fk.column, fk.ref_table, fk.ref_column
628 ));
629 }
630 for (table, fk) in &diff.extra_foreign_keys {
631 statements.push(format!(
632 "-- TODO: extra foreign key {}.{} -> {}({})",
633 table, fk.column, fk.ref_table, fk.ref_column
634 ));
635 }
636 for extra in &diff.extra_columns {
637 statements.push(format!(
638 "-- TODO: extra column {}.{} not in models",
639 extra.table, extra.column
640 ));
641 }
642 for table in &diff.extra_tables {
643 statements.push(format!("-- TODO: extra table {} not in models", table));
644 }
645
646 statements
647}
648
/// Maps a dialect-specific SQL type spelling onto a canonical family name
/// (integer / text / real / boolean), so e.g. `BIGINT` and `serial` compare
/// equal. Unrecognized types are returned trimmed and lowercased.
fn normalize_sql_type(sql_type: &str) -> String {
    let t = sql_type.trim().to_lowercase();
    if t.is_empty() {
        return t;
    }

    // Substring families, checked in order — "int"/"serial" first so
    // e.g. "bigint" never falls through to a later family.
    const FAMILIES: &[(&[&str], &str)] = &[
        (&["int", "serial"], "integer"),
        (&["char", "text", "clob"], "text"),
        (&["real", "floa", "doub", "numeric", "decimal"], "real"),
        (&["bool"], "boolean"),
        // Temporal, UUID and JSON types are stored as text in SQLite.
        (&["time", "date", "uuid", "json"], "text"),
    ];

    for (needles, canonical) in FAMILIES {
        if needles.iter().any(|needle| t.contains(needle)) {
            return (*canonical).to_string();
        }
    }

    t
}
676
/// Generates PostgreSQL migration statements for `diff`.
///
/// Additive changes (`CREATE TABLE`, `ADD COLUMN`, `CREATE INDEX`,
/// `ADD CONSTRAINT ... FOREIGN KEY`) are emitted as executable SQL;
/// everything else is emitted as a `--` comment for review.
#[cfg(feature = "postgres")]
pub fn postgres_migration_sql(expected: &[SchemaTable], diff: &SchemaDiff) -> Vec<String> {
    let expected_map: BTreeMap<String, &SchemaTable> =
        expected.iter().map(|t| (t.name.clone(), t)).collect();
    let mut statements = Vec::new();

    // Missing tables: create the table and its indexes.
    for table in &diff.missing_tables {
        if let Some(schema) = expected_map.get(table) {
            statements.push(schema.to_create_sql());
            for index in &schema.indexes {
                statements.push(postgres_create_index_sql(&schema.name, index));
            }
        } else {
            statements.push(format!("-- Missing schema for table {}", table));
        }
    }

    // Group missing columns by table; BTree* keeps output deterministic.
    let mut missing_by_table: BTreeMap<String, BTreeSet<String>> = BTreeMap::new();
    for col in &diff.missing_columns {
        missing_by_table
            .entry(col.table.clone())
            .or_default()
            .insert(col.column.clone());
    }

    for (table, columns) in missing_by_table {
        let Some(schema) = expected_map.get(&table) else {
            continue;
        };
        for col_name in columns {
            let Some(col) = schema.column(&col_name) else {
                continue;
            };

            // Hoisted out of the `find` closure: the normalized type of the
            // column being added is loop-invariant across candidates.
            let normalized = col.normalized_type();
            // Heuristic: an extra column of the same normalized type in the
            // same table may actually be this column, renamed.
            let potential_rename = diff
                .extra_columns
                .iter()
                .find(|e| e.table == table && e.sql_type.as_deref() == Some(normalized.as_str()));

            if let Some(old) = potential_rename {
                // Fixed typo in the emitted message ("SUGESTION").
                statements.push(format!(
                    "-- SUGGESTION: Potential rename from column '{}' to '{}'?",
                    old.column, col.name
                ));
            }

            if col.primary_key {
                statements.push(format!(
                    "-- TODO: add primary key column {}.{} manually",
                    table, col_name
                ));
                continue;
            }

            if !col.nullable {
                statements.push(format!(
                    "-- WARNING: Adding NOT NULL column '{}.{}' without a default value will fail if table contains rows.",
                    table, col.name
                ));
            }

            let mut stmt = format!(
                "ALTER TABLE {} ADD COLUMN {} {}",
                table, col.name, col.sql_type
            );
            if !col.nullable {
                stmt.push_str(" NOT NULL");
            }
            statements.push(stmt);
        }
    }

    for mismatch in &diff.type_mismatches {
        statements.push(format!(
            "-- TODO: column type mismatch {}.{} (expected {}, actual {})",
            mismatch.table, mismatch.column, mismatch.expected, mismatch.actual
        ));
    }
    for mismatch in &diff.nullability_mismatches {
        statements.push(format!(
            "-- TODO: column nullability mismatch {}.{} (expected nullable {}, actual nullable {})",
            mismatch.table, mismatch.column, mismatch.expected_nullable, mismatch.actual_nullable
        ));
    }
    for mismatch in &diff.primary_key_mismatches {
        statements.push(format!(
            "-- TODO: column primary key mismatch {}.{} (expected pk {}, actual pk {})",
            mismatch.table,
            mismatch.column,
            mismatch.expected_primary_key,
            mismatch.actual_primary_key
        ));
    }
    for (table, index) in &diff.missing_indexes {
        statements.push(postgres_create_index_sql(table, index));
    }
    for (table, index) in &diff.extra_indexes {
        statements.push(format!(
            "-- TODO: extra index {}.{} ({})",
            table,
            index.name,
            index.columns.join(", ")
        ));
    }
    // Unlike SQLite, Postgres can add foreign keys in place.
    for (table, fk) in &diff.missing_foreign_keys {
        let fk_name = format!("fk_{}_{}", table, fk.column);
        statements.push(format!(
            "ALTER TABLE {} ADD CONSTRAINT {} FOREIGN KEY ({}) REFERENCES {}({})",
            table, fk_name, fk.column, fk.ref_table, fk.ref_column
        ));
    }
    for (table, fk) in &diff.extra_foreign_keys {
        statements.push(format!(
            "-- TODO: extra foreign key {}.{} -> {}({})",
            table, fk.column, fk.ref_table, fk.ref_column
        ));
    }
    for extra in &diff.extra_columns {
        statements.push(format!(
            "-- TODO: extra column {}.{} not in models",
            extra.table, extra.column
        ));
    }
    for table in &diff.extra_tables {
        statements.push(format!("-- TODO: extra table {} not in models", table));
    }

    statements
}
807
808fn index_key(index: &SchemaIndex) -> (String, String, bool) {
809 let name = index.name.clone();
810 let cols = index.columns.join(",");
811 (name, cols, index.unique)
812}
813
814fn index_map(indexes: &[SchemaIndex]) -> BTreeMap<(String, String, bool), &SchemaIndex> {
815 indexes.iter().map(|i| (index_key(i), i)).collect()
816}
817
818fn foreign_key_key(fk: &SchemaForeignKey) -> (String, String, String) {
819 (
820 fk.column.clone(),
821 fk.ref_table.clone(),
822 fk.ref_column.clone(),
823 )
824}
825
826fn foreign_key_map(
827 fks: &[SchemaForeignKey],
828) -> BTreeMap<(String, String, String), &SchemaForeignKey> {
829 fks.iter().map(|f| (foreign_key_key(f), f)).collect()
830}
831
/// Lists the secondary indexes of `table` via PRAGMA queries.
///
/// Primary-key indexes and SQLite's implicit `sqlite_autoindex_*` entries
/// are excluded so only user-created indexes are reported.
#[cfg(feature = "sqlite")]
async fn introspect_sqlite_indexes(
    pool: &SqlitePool,
    table: &str,
) -> Result<Vec<SchemaIndex>, sqlx::Error> {
    let sql = format!("PRAGMA index_list({})", table);
    // Tuple layout of PRAGMA index_list rows: (seq, name, unique, origin, partial)
    let rows: Vec<(i64, String, i64, String, i64)> = sqlx::query_as(&sql).fetch_all(pool).await?;

    let mut indexes = Vec::new();
    for (_seq, name, unique, origin, _partial) in rows {
        // origin "pk" = implicit primary-key index; autoindexes back
        // UNIQUE/PK constraints and are not user-created.
        if origin == "pk" || name.starts_with("sqlite_autoindex") {
            continue;
        }
        let info_sql = format!("PRAGMA index_info({})", name);
        // Tuple layout of PRAGMA index_info rows: (seqno, cid, name)
        let info_rows: Vec<(i64, i64, String)> = sqlx::query_as(&info_sql).fetch_all(pool).await?;
        let columns = info_rows.into_iter().map(|(_seq, _cid, col)| col).collect();
        indexes.push(SchemaIndex {
            name,
            columns,
            unique: unique != 0,
        });
    }
    Ok(indexes)
}
856
/// Lists the foreign keys of `table` via `PRAGMA foreign_key_list`.
///
/// Tuple layout of the PRAGMA rows:
/// (id, seq, table, from, to, on_update, on_delete, match).
#[cfg(feature = "sqlite")]
async fn introspect_sqlite_foreign_keys(
    pool: &SqlitePool,
    table: &str,
) -> Result<Vec<SchemaForeignKey>, sqlx::Error> {
    let sql = format!("PRAGMA foreign_key_list({})", table);
    #[allow(clippy::type_complexity)]
    let rows: Vec<(i64, i64, String, String, String, String, String, String)> =
        sqlx::query_as(&sql).fetch_all(pool).await?;

    let fks = rows
        .into_iter()
        .map(
            |(_id, _seq, ref_table, from, to, _on_update, _on_delete, _match)| SchemaForeignKey {
                column: from,
                ref_table,
                ref_column: to,
            },
        )
        .collect();
    Ok(fks)
}
877
878fn sqlite_create_index_sql(table: &str, index: &SchemaIndex) -> String {
879 let unique = if index.unique { "UNIQUE " } else { "" };
880 let name = if index.name.is_empty() {
881 format!("idx_{}_{}", table, index.columns.join("_"))
882 } else {
883 index.name.clone()
884 };
885 format!(
886 "CREATE {}INDEX IF NOT EXISTS {} ON {} ({})",
887 unique,
888 name,
889 table,
890 index.columns.join(", ")
891 )
892}
893
/// Renders a `CREATE [UNIQUE] INDEX IF NOT EXISTS` statement for Postgres.
///
/// Falls back to a deterministic `idx_<table>_<cols>` name when the index
/// has no name of its own.
#[cfg(feature = "postgres")]
fn postgres_create_index_sql(table: &str, index: &SchemaIndex) -> String {
    let name = if index.name.is_empty() {
        format!("idx_{}_{}", table, index.columns.join("_"))
    } else {
        index.name.clone()
    };
    let keyword = if index.unique {
        "CREATE UNIQUE INDEX"
    } else {
        "CREATE INDEX"
    };
    format!(
        "{} IF NOT EXISTS {} ON {} ({})",
        keyword,
        name,
        table,
        index.columns.join(", ")
    )
}
910
/// Lists the secondary (non-primary-key) indexes of `table` from the
/// Postgres system catalogs, with columns in index order.
#[cfg(feature = "postgres")]
async fn introspect_postgres_indexes(
    pool: &PgPool,
    table: &str,
) -> Result<Vec<SchemaIndex>, sqlx::Error> {
    // unnest(indkey) WITH ORDINALITY preserves the index's column order in
    // the aggregated array; indisprimary filters out the PK index.
    let rows: Vec<(String, bool, Vec<String>)> = sqlx::query_as(
        "SELECT i.relname AS index_name, ix.indisunique, array_agg(a.attname ORDER BY x.n) AS columns
         FROM pg_class t
         JOIN pg_index ix ON t.oid = ix.indrelid
         JOIN pg_class i ON i.oid = ix.indexrelid
         JOIN LATERAL unnest(ix.indkey) WITH ORDINALITY AS x(attnum, n) ON true
         JOIN pg_attribute a ON a.attrelid = t.oid AND a.attnum = x.attnum
         WHERE t.relname = $1 AND t.relkind = 'r' AND NOT ix.indisprimary
         GROUP BY i.relname, ix.indisunique
         ORDER BY i.relname",
    )
    .bind(table)
    .fetch_all(pool)
    .await?;

    let indexes = rows
        .into_iter()
        .map(|(name, unique, columns)| SchemaIndex {
            name,
            columns,
            unique,
        })
        .collect();
    Ok(indexes)
}
941
/// Lists the foreign keys of `table` (schema `public`) from
/// `information_schema`, yielding (column, referenced table, referenced
/// column) triples.
#[cfg(feature = "postgres")]
async fn introspect_postgres_foreign_keys(
    pool: &PgPool,
    table: &str,
) -> Result<Vec<SchemaForeignKey>, sqlx::Error> {
    // key_column_usage gives the referencing column;
    // constraint_column_usage gives the referenced table/column.
    let rows: Vec<(String, String, String)> = sqlx::query_as(
        "SELECT kcu.column_name, ccu.table_name, ccu.column_name
         FROM information_schema.table_constraints tc
         JOIN information_schema.key_column_usage kcu
           ON tc.constraint_name = kcu.constraint_name AND tc.table_schema = kcu.table_schema
         JOIN information_schema.constraint_column_usage ccu
           ON ccu.constraint_name = tc.constraint_name AND ccu.table_schema = tc.table_schema
         WHERE tc.constraint_type = 'FOREIGN KEY'
           AND tc.table_schema = 'public'
           AND tc.table_name = $1
         ORDER BY kcu.ordinal_position",
    )
    .bind(table)
    .fetch_all(pool)
    .await?;

    let fks = rows
        .into_iter()
        .map(|(column, ref_table, ref_column)| SchemaForeignKey {
            column,
            ref_table,
            ref_column,
        })
        .collect();

    Ok(fks)
}
974
#[cfg(test)]
mod tests {
    use super::*;

    // Round trip: introspecting a freshly created SQLite table should match
    // the hand-written expected schema exactly (empty diff).
    #[cfg(feature = "sqlite")]
    #[tokio::test]
    async fn sqlite_introspect_and_diff_empty() {
        let pool = SqlitePool::connect("sqlite::memory:").await.unwrap();
        sqlx::query(
            "CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT NOT NULL, deleted_at TEXT)",
        )
        .execute(&pool)
        .await
        .unwrap();

        let expected = vec![SchemaTable {
            name: "users".to_string(),
            columns: vec![
                SchemaColumn {
                    name: "id".to_string(),
                    sql_type: "INTEGER".to_string(),
                    nullable: false,
                    primary_key: true,
                },
                SchemaColumn {
                    name: "name".to_string(),
                    sql_type: "TEXT".to_string(),
                    nullable: false,
                    primary_key: false,
                },
                SchemaColumn {
                    name: "deleted_at".to_string(),
                    sql_type: "TEXT".to_string(),
                    nullable: true,
                    primary_key: false,
                },
            ],
            indexes: Vec::new(),
            foreign_keys: Vec::new(),
            create_sql: None,
        }];

        let actual = introspect_sqlite_schema(&pool).await.unwrap();
        let diff = diff_schema(&expected, &actual);
        assert!(diff.is_empty());
    }

    // A column present in the models but not in the database must surface
    // as a missing column, in the summary, and as an ALTER TABLE statement.
    #[cfg(feature = "sqlite")]
    #[tokio::test]
    async fn sqlite_diff_reports_missing_column() {
        let pool = SqlitePool::connect("sqlite::memory:").await.unwrap();
        sqlx::query("CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT)")
            .execute(&pool)
            .await
            .unwrap();

        let expected = vec![SchemaTable {
            name: "users".to_string(),
            columns: vec![
                SchemaColumn {
                    name: "id".to_string(),
                    sql_type: "INTEGER".to_string(),
                    nullable: false,
                    primary_key: true,
                },
                SchemaColumn {
                    name: "name".to_string(),
                    sql_type: "TEXT".to_string(),
                    nullable: false,
                    primary_key: false,
                },
                SchemaColumn {
                    name: "status".to_string(),
                    sql_type: "TEXT".to_string(),
                    nullable: true,
                    primary_key: false,
                },
            ],
            indexes: Vec::new(),
            foreign_keys: Vec::new(),
            create_sql: None,
        }];

        let actual = introspect_sqlite_schema(&pool).await.unwrap();
        let diff = diff_schema(&expected, &actual);
        assert_eq!(diff.missing_columns.len(), 1);

        let summary = format_schema_diff_summary(&diff);
        assert!(summary.contains("missing columns: 1"));

        let sql = sqlite_migration_sql(&expected, &diff);
        assert!(
            sql.iter()
                .any(|stmt| stmt.contains("ALTER TABLE users ADD COLUMN status"))
        );
    }

    // An index present in the models but not in the database must surface
    // as a missing index and generate a CREATE INDEX statement.
    #[cfg(feature = "sqlite")]
    #[tokio::test]
    async fn sqlite_diff_reports_missing_index() {
        let pool = SqlitePool::connect("sqlite::memory:").await.unwrap();
        sqlx::query("CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT)")
            .execute(&pool)
            .await
            .unwrap();

        let expected = vec![SchemaTable {
            name: "users".to_string(),
            columns: vec![
                SchemaColumn {
                    name: "id".to_string(),
                    sql_type: "INTEGER".to_string(),
                    nullable: false,
                    primary_key: true,
                },
                SchemaColumn {
                    name: "name".to_string(),
                    sql_type: "TEXT".to_string(),
                    nullable: false,
                    primary_key: false,
                },
            ],
            indexes: vec![SchemaIndex {
                name: "idx_users_name".to_string(),
                columns: vec!["name".to_string()],
                unique: false,
            }],
            foreign_keys: Vec::new(),
            create_sql: None,
        }];

        let actual = introspect_sqlite_schema(&pool).await.unwrap();
        let diff = diff_schema(&expected, &actual);
        assert_eq!(diff.missing_indexes.len(), 1);

        let sql = sqlite_migration_sql(&expected, &diff);
        assert!(
            sql.iter()
                .any(|stmt| stmt.contains("CREATE INDEX IF NOT EXISTS idx_users_name"))
        );
    }

    // Connection string for the Postgres integration test.
    // NOTE(review): the fallback hardcodes dev credentials for a local
    // bench database — confirm this is intentional for CI.
    #[cfg(feature = "postgres")]
    fn pg_url() -> String {
        std::env::var("DATABASE_URL").unwrap_or_else(|_| {
            "postgres://postgres:admin123@localhost:5432/premix_bench".to_string()
        })
    }

    // Full Postgres round trip over two tables with an index and a foreign
    // key. Skips silently when no server is reachable.
    #[cfg(feature = "postgres")]
    #[tokio::test]
    async fn postgres_introspect_and_diff() {
        let url = pg_url();
        let pool = match PgPool::connect(&url).await {
            Ok(pool) => pool,
            // No database available: treat as skipped rather than failed.
            Err(_) => return,
        };

        // Drop children before parents to satisfy FK ordering.
        sqlx::query("DROP TABLE IF EXISTS schema_posts")
            .execute(&pool)
            .await
            .unwrap();
        sqlx::query("DROP TABLE IF EXISTS schema_users")
            .execute(&pool)
            .await
            .unwrap();

        sqlx::query("CREATE TABLE schema_users (id SERIAL PRIMARY KEY, name TEXT NOT NULL)")
            .execute(&pool)
            .await
            .unwrap();
        sqlx::query(
            "CREATE TABLE schema_posts (id SERIAL PRIMARY KEY, user_id INTEGER NOT NULL, title TEXT NOT NULL, CONSTRAINT fk_schema_posts_user_id FOREIGN KEY (user_id) REFERENCES schema_users(id))",
        )
        .execute(&pool)
        .await
        .unwrap();
        sqlx::query("CREATE INDEX idx_schema_posts_user_id ON schema_posts (user_id)")
            .execute(&pool)
            .await
            .unwrap();

        let expected = vec![
            SchemaTable {
                name: "schema_posts".to_string(),
                columns: vec![
                    SchemaColumn {
                        name: "id".to_string(),
                        sql_type: "INTEGER".to_string(),
                        nullable: false,
                        primary_key: true,
                    },
                    SchemaColumn {
                        name: "user_id".to_string(),
                        sql_type: "INTEGER".to_string(),
                        nullable: false,
                        primary_key: false,
                    },
                    SchemaColumn {
                        name: "title".to_string(),
                        sql_type: "TEXT".to_string(),
                        nullable: false,
                        primary_key: false,
                    },
                ],
                indexes: vec![SchemaIndex {
                    name: "idx_schema_posts_user_id".to_string(),
                    columns: vec!["user_id".to_string()],
                    unique: false,
                }],
                foreign_keys: vec![SchemaForeignKey {
                    column: "user_id".to_string(),
                    ref_table: "schema_users".to_string(),
                    ref_column: "id".to_string(),
                }],
                create_sql: None,
            },
            SchemaTable {
                name: "schema_users".to_string(),
                columns: vec![
                    SchemaColumn {
                        name: "id".to_string(),
                        sql_type: "INTEGER".to_string(),
                        nullable: false,
                        primary_key: true,
                    },
                    SchemaColumn {
                        name: "name".to_string(),
                        sql_type: "TEXT".to_string(),
                        nullable: false,
                        primary_key: false,
                    },
                ],
                indexes: Vec::new(),
                foreign_keys: Vec::new(),
                create_sql: None,
            },
        ];

        let actual = introspect_postgres_schema(&pool).await.unwrap();
        // Ignore unrelated tables that may exist in the shared database.
        let expected_names: BTreeSet<String> =
            expected.iter().map(|table| table.name.clone()).collect();
        let actual = actual
            .into_iter()
            .filter(|table| expected_names.contains(&table.name))
            .collect::<Vec<_>>();
        let diff = diff_schema(&expected, &actual);
        assert!(diff.is_empty(), "postgres schema diff: {diff:?}");

        // Clean up so reruns start from a fresh state.
        sqlx::query("DROP TABLE IF EXISTS schema_posts")
            .execute(&pool)
            .await
            .unwrap();
        sqlx::query("DROP TABLE IF EXISTS schema_users")
            .execute(&pool)
            .await
            .unwrap();
    }
}