Skip to main content

premix_core/
schema.rs

1use std::collections::{BTreeMap, BTreeSet};
2
3#[cfg(feature = "mysql")]
4use sqlx::MySqlPool;
5#[cfg(feature = "postgres")]
6use sqlx::PgPool;
7#[cfg(feature = "sqlite")]
8use sqlx::SqlitePool;
9
10/// Metadata about a database column.
11#[derive(Debug, Clone, PartialEq, Eq)]
12pub struct SchemaColumn {
13    /// The name of the column.
14    pub name: String,
15    /// The SQL type of the column (e.g., "INTEGER", "TEXT").
16    pub sql_type: String,
17    /// Whether the column can contain NULL values.
18    pub nullable: bool,
19    /// Whether the column is part of the Primary Key.
20    pub primary_key: bool,
21}
22
23impl SchemaColumn {
24    fn normalized_type(&self) -> String {
25        normalize_sql_type(&self.sql_type)
26    }
27}
28
/// Description of an index defined on a table.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct SchemaIndex {
    /// Index name.
    pub name: String,
    /// Names of the columns covered by the index.
    pub columns: Vec<String>,
    /// `true` for a `UNIQUE` index.
    pub unique: bool,
}
39
/// A single-column foreign-key reference from one table to another.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct SchemaForeignKey {
    /// Referencing column in the owning table.
    pub column: String,
    /// Name of the referenced table.
    pub ref_table: String,
    /// Referenced column in `ref_table`.
    pub ref_column: String,
}
50
51/// Metadata about a database table.
52#[derive(Debug, Clone, PartialEq, Eq)]
53pub struct SchemaTable {
54    /// The name of the table.
55    pub name: String,
56    /// The columns in the table.
57    pub columns: Vec<SchemaColumn>,
58    /// The indexes on the table.
59    pub indexes: Vec<SchemaIndex>,
60    /// The foreign keys in the table.
61    pub foreign_keys: Vec<SchemaForeignKey>,
62    /// The original CREATE TABLE SQL (if available).
63    pub create_sql: Option<String>,
64}
65
66impl SchemaTable {
67    /// Returns a column by name if it exists in the table.
68    pub fn column(&self, name: &str) -> Option<&SchemaColumn> {
69        self.columns.iter().find(|c| c.name == name)
70    }
71
72    /// Generates a `CREATE TABLE` SQL statement for this table.
73    pub fn to_create_sql(&self) -> String {
74        if let Some(sql) = &self.create_sql {
75            return sql.clone();
76        }
77
78        let mut cols = Vec::new();
79        for col in &self.columns {
80            if col.primary_key {
81                cols.push(format!("{} INTEGER PRIMARY KEY", col.name));
82                continue;
83            }
84            let mut def = format!("{} {}", col.name, col.sql_type);
85            if !col.nullable {
86                def.push_str(" NOT NULL");
87            }
88            cols.push(def);
89        }
90
91        format!(
92            "CREATE TABLE IF NOT EXISTS {} ({})",
93            self.name,
94            cols.join(", ")
95        )
96    }
97}
98
99/// A trait for models that can provide their own schema metadata.
100pub trait ModelSchema {
101    /// Returns the schema metadata for this model.
102    fn schema() -> SchemaTable;
103}
104
/// Builds a `Vec` holding `ModelSchema::schema()` for each listed model type, in order.
///
/// Accepts one or more types separated by commas, with an optional trailing comma.
#[macro_export]
macro_rules! schema_models {
    ($($model:ty),+ $(,)?) => {
        vec![$(<$model as $crate::schema::ModelSchema>::schema(),)+]
    };
}
112
/// A column that exists on only one side of a schema comparison
/// (reported either as missing or as extra).
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ColumnDiff {
    /// Name of the table holding the column.
    pub table: String,
    /// Name of the column.
    pub column: String,
    /// Column type; `diff_schema` stores the normalized type here.
    pub sql_type: Option<String>,
}
123
/// A column whose type differs between the models and the database.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ColumnTypeDiff {
    /// Name of the table holding the column.
    pub table: String,
    /// Name of the column.
    pub column: String,
    /// Type declared by the model (the raw declaration, not the normalized form).
    pub expected: String,
    /// Type reported by the database (the raw declaration, not the normalized form).
    pub actual: String,
}
136
/// A column whose `NULL`-ability differs between the models and the database.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ColumnNullabilityDiff {
    /// Name of the table holding the column.
    pub table: String,
    /// Name of the column.
    pub column: String,
    /// `true` if the model declares the column nullable.
    pub expected_nullable: bool,
    /// `true` if the database reports the column as nullable.
    pub actual_nullable: bool,
}
149
/// A column whose primary-key membership differs between the models and the database.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ColumnPrimaryKeyDiff {
    /// Name of the table holding the column.
    pub table: String,
    /// Name of the column.
    pub column: String,
    /// `true` if the model marks the column as part of the primary key.
    pub expected_primary_key: bool,
    /// `true` if the database reports the column as part of the primary key.
    pub actual_primary_key: bool,
}
162
163/// Represents the differences between two database schemas.
164#[derive(Debug, Clone, PartialEq, Eq, Default)]
165pub struct SchemaDiff {
166    /// Tables expected but missing in the actual database.
167    pub missing_tables: Vec<String>,
168    /// Tables present in the database but not in the models.
169    pub extra_tables: Vec<String>,
170    /// Columns missing in existing tables.
171    pub missing_columns: Vec<ColumnDiff>,
172    /// Columns present in the database but not in the models.
173    pub extra_columns: Vec<ColumnDiff>,
174    /// Columns with different types than expected.
175    pub type_mismatches: Vec<ColumnTypeDiff>,
176    /// Columns with different nullability than expected.
177    pub nullability_mismatches: Vec<ColumnNullabilityDiff>,
178    /// Columns with different Primary Key status than expected.
179    pub primary_key_mismatches: Vec<ColumnPrimaryKeyDiff>,
180    /// Indexes missing in the actual database.
181    pub missing_indexes: Vec<(String, SchemaIndex)>,
182    /// Indexes present in the database but not in the models.
183    pub extra_indexes: Vec<(String, SchemaIndex)>,
184    /// Foreign keys missing in the actual database.
185    pub missing_foreign_keys: Vec<(String, SchemaForeignKey)>,
186    /// Foreign keys present in the database but not in the models.
187    pub extra_foreign_keys: Vec<(String, SchemaForeignKey)>,
188}
189
190impl SchemaDiff {
191    /// Returns true if there are no differences.
192    pub fn is_empty(&self) -> bool {
193        self.missing_tables.is_empty()
194            && self.extra_tables.is_empty()
195            && self.missing_columns.is_empty()
196            && self.extra_columns.is_empty()
197            && self.type_mismatches.is_empty()
198            && self.nullability_mismatches.is_empty()
199            && self.primary_key_mismatches.is_empty()
200            && self.missing_indexes.is_empty()
201            && self.extra_indexes.is_empty()
202            && self.missing_foreign_keys.is_empty()
203            && self.extra_foreign_keys.is_empty()
204    }
205}
206
207/// Formats a [`SchemaDiff`] into a human-readable summary.
208pub fn format_schema_diff_summary(diff: &SchemaDiff) -> String {
209    if diff.is_empty() {
210        return "Schema diff: no changes".to_string();
211    }
212
213    let mut lines = Vec::new();
214    lines.push("Schema diff summary:".to_string());
215    lines.push(format!("  missing tables: {}", diff.missing_tables.len()));
216    lines.push(format!("  extra tables: {}", diff.extra_tables.len()));
217    lines.push(format!("  missing columns: {}", diff.missing_columns.len()));
218    lines.push(format!("  extra columns: {}", diff.extra_columns.len()));
219    lines.push(format!("  type mismatches: {}", diff.type_mismatches.len()));
220    lines.push(format!(
221        "  nullability mismatches: {}",
222        diff.nullability_mismatches.len()
223    ));
224    lines.push(format!(
225        "  primary key mismatches: {}",
226        diff.primary_key_mismatches.len()
227    ));
228    lines.push(format!("  missing indexes: {}", diff.missing_indexes.len()));
229    lines.push(format!("  extra indexes: {}", diff.extra_indexes.len()));
230    lines.push(format!(
231        "  missing foreign keys: {}",
232        diff.missing_foreign_keys.len()
233    ));
234    lines.push(format!(
235        "  extra foreign keys: {}",
236        diff.extra_foreign_keys.len()
237    ));
238
239    if !diff.missing_tables.is_empty() {
240        lines.push(format!(
241            "  missing tables list: {}",
242            diff.missing_tables.join(", ")
243        ));
244    }
245    if !diff.extra_tables.is_empty() {
246        lines.push(format!(
247            "  extra tables list: {}",
248            diff.extra_tables.join(", ")
249        ));
250    }
251
252    lines.join("\n")
253}
254
/// Introspects the schema of a SQLite database.
///
/// Returns one [`SchemaTable`] per user table, ordered by name. SQLite's internal
/// `sqlite_*` tables and the `_premix_migrations` table are skipped, as are tables for
/// which `PRAGMA table_info` returns no rows.
///
/// Column metadata comes from `PRAGMA table_info`, and primary-key columns are always
/// reported as non-nullable. Indexes and foreign keys are read by
/// `introspect_sqlite_indexes` / `introspect_sqlite_foreign_keys`. `create_sql` is left
/// as `None`.
///
/// # Errors
/// Returns the first `sqlx::Error` raised by any metadata query.
#[cfg(feature = "sqlite")]
pub async fn introspect_sqlite_schema(pool: &SqlitePool) -> Result<Vec<SchemaTable>, sqlx::Error> {
    // `_` is a LIKE wildcard, so it is escaped here; an unescaped `'sqlite_%'` would also
    // hide legitimate user tables such as `sqlitefoo`.
    let table_names: Vec<String> = sqlx::query_scalar(
        "SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite\\_%' ESCAPE '\\' AND name != '_premix_migrations' ORDER BY name",
    )
    .fetch_all(pool)
    .await?;

    let mut tables = Vec::new();
    for name in table_names {
        // PRAGMA arguments cannot be bound as parameters, so quote the identifier
        // (doubling embedded `"`) to handle names that are keywords or contain
        // spaces/quotes.
        let pragma_sql = format!("PRAGMA table_info(\"{}\")", name.replace('"', "\"\""));
        // Row layout: (cid, name, type, notnull, dflt_value, pk).
        let rows: Vec<(i64, String, String, i64, Option<String>, i64)> =
            sqlx::query_as(&pragma_sql).fetch_all(pool).await?;

        if rows.is_empty() {
            continue;
        }

        let columns = rows
            .into_iter()
            .map(|(_cid, col_name, col_type, notnull, _default, pk)| {
                // `pk` is the 1-based position in the primary key, or 0 if not part of it.
                let is_pk = pk > 0;
                SchemaColumn {
                    name: col_name,
                    sql_type: col_type,
                    nullable: !is_pk && notnull == 0,
                    primary_key: is_pk,
                }
            })
            .collect();

        let indexes = introspect_sqlite_indexes(pool, &name).await?;
        let foreign_keys = introspect_sqlite_foreign_keys(pool, &name).await?;

        tables.push(SchemaTable {
            name,
            columns,
            indexes,
            foreign_keys,
            create_sql: None,
        });
    }

    Ok(tables)
}
301
/// Introspects the schema of a PostgreSQL database.
///
/// Only base tables in the `public` schema are read, ordered by name and excluding
/// `_premix_migrations`. Tables that report no columns are skipped.
///
/// Column types come from `information_schema.columns.data_type`, with two exceptions:
/// arrays are rendered as `<element>[]` (element name taken from `udt_name` without its
/// leading `_`), and user-defined types use `udt_name`. Primary-key columns are always
/// reported as non-nullable.
///
/// Indexes and foreign keys come from `introspect_postgres_indexes` /
/// `introspect_postgres_foreign_keys`. `create_sql` is left as `None`.
///
/// # Errors
/// Returns the first `sqlx::Error` raised by any metadata query.
#[cfg(feature = "postgres")]
pub async fn introspect_postgres_schema(pool: &PgPool) -> Result<Vec<SchemaTable>, sqlx::Error> {
    let table_names: Vec<String> = sqlx::query_scalar(
        "SELECT table_name FROM information_schema.tables WHERE table_schema='public' AND table_type='BASE TABLE' AND table_name != '_premix_migrations' ORDER BY table_name",
    )
    .fetch_all(pool)
    .await?;

    let mut tables = Vec::new();
    for name in table_names {
        // Resolve the table by (schema, relname) instead of `$1::regclass`. The regclass
        // cast parses its input as an SQL identifier, which lower-cases unquoted names
        // (breaking mixed-case tables). It also resolves through `search_path` instead of
        // the `public` schema used by the other queries here.
        let pk_cols: Vec<String> = sqlx::query_scalar(
            "SELECT a.attname FROM pg_index i JOIN pg_class c ON c.oid = i.indrelid JOIN pg_namespace n ON n.oid = c.relnamespace JOIN pg_attribute a ON a.attrelid = i.indrelid AND a.attnum = ANY(i.indkey) WHERE n.nspname = 'public' AND c.relname = $1 AND i.indisprimary",
        )
        .bind(&name)
        .fetch_all(pool)
        .await?;
        let pk_set: BTreeSet<String> = pk_cols.into_iter().collect();

        let rows: Vec<(String, String, String, String)> = sqlx::query_as(
            "SELECT column_name, data_type, udt_name, is_nullable FROM information_schema.columns WHERE table_schema='public' AND table_name=$1 ORDER BY ordinal_position",
        )
        .bind(&name)
        .fetch_all(pool)
        .await?;

        if rows.is_empty() {
            continue;
        }

        let columns = rows
            .into_iter()
            .map(|(col_name, data_type, udt_name, is_nullable)| {
                let is_pk = pk_set.contains(&col_name);
                // `data_type` is only "ARRAY" / "USER-DEFINED" for those kinds; the concrete
                // type name lives in `udt_name` (array element types are prefixed with `_`).
                let sql_type = if data_type.eq_ignore_ascii_case("ARRAY") {
                    let base = udt_name.trim_start_matches('_');
                    format!("{}[]", base)
                } else if data_type.eq_ignore_ascii_case("USER-DEFINED") {
                    udt_name
                } else {
                    data_type
                };
                SchemaColumn {
                    name: col_name,
                    sql_type,
                    nullable: !is_pk && is_nullable.eq_ignore_ascii_case("YES"),
                    primary_key: is_pk,
                }
            })
            .collect();

        let indexes = introspect_postgres_indexes(pool, &name).await?;
        let foreign_keys = introspect_postgres_foreign_keys(pool, &name).await?;

        tables.push(SchemaTable {
            name,
            columns,
            indexes,
            foreign_keys,
            create_sql: None,
        });
    }

    Ok(tables)
}
367
/// Introspects the schema of a MySQL database.
///
/// Reads every base table of the connection's current database (`DATABASE()`), ordered
/// by name and excluding `_premix_migrations`. For each table:
/// - primary-key columns come from `information_schema.table_constraints` joined with
///   `key_column_usage`;
/// - columns, with their full `column_type`, come from `information_schema.columns`;
/// - indexes and foreign keys come from `introspect_mysql_indexes` /
///   `introspect_mysql_foreign_keys`.
///
/// Tables reporting no columns are skipped. Primary-key columns are always reported as
/// non-nullable, and `create_sql` is `None`.
///
/// # Errors
/// Returns the first `sqlx::Error` raised by any metadata query.
#[cfg(feature = "mysql")]
pub async fn introspect_mysql_schema(pool: &MySqlPool) -> Result<Vec<SchemaTable>, sqlx::Error> {
    let names: Vec<String> = sqlx::query_scalar(
        "SELECT table_name FROM information_schema.tables WHERE table_schema = DATABASE() AND table_type = 'BASE TABLE' AND table_name != '_premix_migrations' ORDER BY table_name",
    )
    .fetch_all(pool)
    .await?;

    let mut result: Vec<SchemaTable> = Vec::with_capacity(names.len());
    for table_name in names {
        let key_columns: Vec<String> = sqlx::query_scalar(
            "SELECT k.column_name
             FROM information_schema.table_constraints tc
             JOIN information_schema.key_column_usage k
               ON tc.constraint_name = k.constraint_name
              AND tc.table_schema = k.table_schema
              AND tc.table_name = k.table_name
             WHERE tc.constraint_type = 'PRIMARY KEY'
               AND tc.table_schema = DATABASE()
               AND tc.table_name = ?
             ORDER BY k.ordinal_position",
        )
        .bind(&table_name)
        .fetch_all(pool)
        .await?;
        let primary_keys: BTreeSet<String> = key_columns.into_iter().collect();

        let raw_columns: Vec<(String, String, String)> = sqlx::query_as(
            "SELECT column_name, column_type, is_nullable
             FROM information_schema.columns
             WHERE table_schema = DATABASE() AND table_name = ?
             ORDER BY ordinal_position",
        )
        .bind(&table_name)
        .fetch_all(pool)
        .await?;

        if raw_columns.is_empty() {
            continue;
        }

        let mut columns: Vec<SchemaColumn> = Vec::with_capacity(raw_columns.len());
        for (column_name, column_type, is_nullable) in raw_columns {
            let primary_key = primary_keys.contains(&column_name);
            columns.push(SchemaColumn {
                nullable: !primary_key && is_nullable.eq_ignore_ascii_case("YES"),
                name: column_name,
                sql_type: column_type,
                primary_key,
            });
        }

        let indexes = introspect_mysql_indexes(pool, &table_name).await?;
        let foreign_keys = introspect_mysql_foreign_keys(pool, &table_name).await?;

        result.push(SchemaTable {
            name: table_name,
            columns,
            indexes,
            foreign_keys,
            create_sql: None,
        });
    }

    Ok(result)
}
437
/// Introspects the SQLite database behind `pool` and diffs it against `expected`
/// (via `introspect_sqlite_schema` and `diff_schema`).
///
/// # Errors
/// Propagates any `sqlx::Error` raised during introspection.
#[cfg(feature = "sqlite")]
pub async fn diff_sqlite_schema(
    pool: &SqlitePool,
    expected: &[SchemaTable],
) -> Result<SchemaDiff, sqlx::Error> {
    introspect_sqlite_schema(pool)
        .await
        .map(|actual| diff_schema(expected, &actual))
}
447
/// Introspects the PostgreSQL database behind `pool` and diffs it against `expected`
/// (via `introspect_postgres_schema` and `diff_schema`).
///
/// # Errors
/// Propagates any `sqlx::Error` raised during introspection.
#[cfg(feature = "postgres")]
pub async fn diff_postgres_schema(
    pool: &PgPool,
    expected: &[SchemaTable],
) -> Result<SchemaDiff, sqlx::Error> {
    introspect_postgres_schema(pool)
        .await
        .map(|actual| diff_schema(expected, &actual))
}
457
/// Introspects the MySQL database behind `pool` and diffs it against `expected`
/// (via `introspect_mysql_schema` and `diff_schema`).
///
/// # Errors
/// Propagates any `sqlx::Error` raised during introspection.
#[cfg(feature = "mysql")]
pub async fn diff_mysql_schema(
    pool: &MySqlPool,
    expected: &[SchemaTable],
) -> Result<SchemaDiff, sqlx::Error> {
    introspect_mysql_schema(pool)
        .await
        .map(|actual| diff_schema(expected, &actual))
}
467
468/// Calculates the difference between two sets of table metadata.
469pub fn diff_schema(expected: &[SchemaTable], actual: &[SchemaTable]) -> SchemaDiff {
470    let mut diff = SchemaDiff::default();
471
472    let expected_map: BTreeMap<_, _> = expected.iter().map(|t| (&t.name, t)).collect();
473    let actual_map: BTreeMap<_, _> = actual.iter().map(|t| (&t.name, t)).collect();
474
475    for name in expected_map.keys() {
476        if !actual_map.contains_key(name) {
477            diff.missing_tables.push((*name).to_string());
478        }
479    }
480    for name in actual_map.keys() {
481        if !expected_map.contains_key(name) {
482            diff.extra_tables.push((*name).to_string());
483        }
484    }
485
486    for (name, expected_table) in &expected_map {
487        let Some(actual_table) = actual_map.get(name) else {
488            continue;
489        };
490
491        let expected_cols: BTreeMap<_, _> = expected_table
492            .columns
493            .iter()
494            .map(|c| (&c.name, c))
495            .collect();
496        let actual_cols: BTreeMap<_, _> =
497            actual_table.columns.iter().map(|c| (&c.name, c)).collect();
498
499        for col in expected_cols.keys() {
500            if !actual_cols.contains_key(col) {
501                let sql_type = expected_cols.get(col).map(|c| c.normalized_type());
502                diff.missing_columns.push(ColumnDiff {
503                    table: (*name).to_string(),
504                    column: (*col).to_string(),
505                    sql_type,
506                });
507            }
508        }
509        for col in actual_cols.keys() {
510            if !expected_cols.contains_key(col) {
511                let sql_type = actual_cols.get(col).map(|c| c.normalized_type());
512                diff.extra_columns.push(ColumnDiff {
513                    table: (*name).to_string(),
514                    column: (*col).to_string(),
515                    sql_type,
516                });
517            }
518        }
519
520        for (col_name, expected_col) in &expected_cols {
521            let Some(actual_col) = actual_cols.get(col_name) else {
522                continue;
523            };
524
525            let expected_type = expected_col.normalized_type();
526            let actual_type = actual_col.normalized_type();
527            if expected_type != actual_type {
528                diff.type_mismatches.push(ColumnTypeDiff {
529                    table: (*name).to_string(),
530                    column: (*col_name).to_string(),
531                    expected: expected_col.sql_type.clone(),
532                    actual: actual_col.sql_type.clone(),
533                });
534            }
535
536            if expected_col.nullable != actual_col.nullable {
537                diff.nullability_mismatches.push(ColumnNullabilityDiff {
538                    table: (*name).to_string(),
539                    column: (*col_name).to_string(),
540                    expected_nullable: expected_col.nullable,
541                    actual_nullable: actual_col.nullable,
542                });
543            }
544
545            if expected_col.primary_key != actual_col.primary_key {
546                diff.primary_key_mismatches.push(ColumnPrimaryKeyDiff {
547                    table: (*name).to_string(),
548                    column: (*col_name).to_string(),
549                    expected_primary_key: expected_col.primary_key,
550                    actual_primary_key: actual_col.primary_key,
551                });
552            }
553        }
554
555        let expected_indexes = index_map(&expected_table.indexes);
556        let actual_indexes = index_map(&actual_table.indexes);
557        for key in expected_indexes.keys() {
558            if !actual_indexes.contains_key(key) {
559                if let Some(index) = expected_indexes.get(key) {
560                    diff.missing_indexes
561                        .push(((*name).to_string(), (*index).clone()));
562                }
563            }
564        }
565        for key in actual_indexes.keys() {
566            if !expected_indexes.contains_key(key) {
567                if let Some(index) = actual_indexes.get(key) {
568                    diff.extra_indexes
569                        .push(((*name).to_string(), (*index).clone()));
570                }
571            }
572        }
573
574        let expected_fks = foreign_key_map(&expected_table.foreign_keys);
575        let actual_fks = foreign_key_map(&actual_table.foreign_keys);
576        for key in expected_fks.keys() {
577            if !actual_fks.contains_key(key) {
578                if let Some(fk) = expected_fks.get(key) {
579                    diff.missing_foreign_keys
580                        .push(((*name).to_string(), (*fk).clone()));
581                }
582            }
583        }
584        for key in actual_fks.keys() {
585            if !expected_fks.contains_key(key) {
586                if let Some(fk) = actual_fks.get(key) {
587                    diff.extra_foreign_keys
588                        .push(((*name).to_string(), (*fk).clone()));
589                }
590            }
591        }
592    }
593
594    diff.missing_tables.sort();
595    diff.extra_tables.sort();
596
597    diff
598}
599
600/// Generates SQLite migration SQL based on the provided schema differences.
601pub fn sqlite_migration_sql(expected: &[SchemaTable], diff: &SchemaDiff) -> Vec<String> {
602    let expected_map: BTreeMap<String, &SchemaTable> =
603        expected.iter().map(|t| (t.name.clone(), t)).collect();
604    let mut statements = Vec::new();
605
606    for table in &diff.missing_tables {
607        if let Some(schema) = expected_map.get(table) {
608            statements.push(schema.to_create_sql());
609            for index in &schema.indexes {
610                statements.push(sqlite_create_index_sql(&schema.name, index));
611            }
612        } else {
613            statements.push(format!("-- Missing schema for table {}", table));
614        }
615    }
616
617    let mut missing_by_table: BTreeMap<String, BTreeSet<String>> = BTreeMap::new();
618    for col in &diff.missing_columns {
619        missing_by_table
620            .entry(col.table.clone())
621            .or_default()
622            .insert(col.column.clone());
623    }
624
625    for (table, columns) in missing_by_table {
626        let Some(schema) = expected_map.get(&table) else {
627            continue;
628        };
629        for col_name in columns {
630            let Some(col) = schema.column(&col_name) else {
631                continue;
632            };
633
634            // Heuristic: Check for potential rename from extra columns
635            // Simple check: Look for an extra column in the same table with the same type
636            // that isn't already "claimed" (though here we just suggest)
637            let potential_rename = diff.extra_columns.iter().find(|e| {
638                e.table == table && e.sql_type.as_deref() == Some(&col.normalized_type())
639            });
640
641            if let Some(old) = potential_rename {
642                statements.push(format!(
643                    "-- SUGESTION: Potential rename from column '{}' to '{}'?",
644                    old.column, col.name
645                ));
646            }
647
648            if col.primary_key {
649                statements.push(format!(
650                    "-- TODO: add primary key column {}.{} manually",
651                    table, col_name
652                ));
653                continue;
654            }
655
656            if !col.nullable {
657                statements.push(format!(
658                    "-- WARNING: Adding NOT NULL column '{}.{}' without a default value will fail if table contains rows.",
659                    table, col.name
660                ));
661            }
662
663            let mut stmt = format!(
664                "ALTER TABLE {} ADD COLUMN {} {}",
665                table, col.name, col.sql_type
666            );
667            if !col.nullable {
668                stmt.push_str(" NOT NULL");
669            }
670            statements.push(stmt);
671        }
672    }
673
674    for mismatch in &diff.type_mismatches {
675        statements.push(format!(
676            "-- TODO: column type mismatch {}.{} (expected {}, actual {})",
677            mismatch.table, mismatch.column, mismatch.expected, mismatch.actual
678        ));
679    }
680    for mismatch in &diff.nullability_mismatches {
681        statements.push(format!(
682            "-- TODO: column nullability mismatch {}.{} (expected nullable {}, actual nullable {})",
683            mismatch.table, mismatch.column, mismatch.expected_nullable, mismatch.actual_nullable
684        ));
685    }
686    for mismatch in &diff.primary_key_mismatches {
687        statements.push(format!(
688            "-- TODO: column primary key mismatch {}.{} (expected pk {}, actual pk {})",
689            mismatch.table,
690            mismatch.column,
691            mismatch.expected_primary_key,
692            mismatch.actual_primary_key
693        ));
694    }
695    for (table, index) in &diff.missing_indexes {
696        statements.push(sqlite_create_index_sql(table, index));
697    }
698    for (table, index) in &diff.extra_indexes {
699        statements.push(format!(
700            "-- TODO: extra index {}.{} ({})",
701            table,
702            index.name,
703            index.columns.join(", ")
704        ));
705    }
706    for (table, fk) in &diff.missing_foreign_keys {
707        statements.push(format!(
708            "-- TODO: add foreign key {}.{} -> {}({}) (requires table rebuild)",
709            table, fk.column, fk.ref_table, fk.ref_column
710        ));
711    }
712    for (table, fk) in &diff.extra_foreign_keys {
713        statements.push(format!(
714            "-- TODO: extra foreign key {}.{} -> {}({})",
715            table, fk.column, fk.ref_table, fk.ref_column
716        ));
717    }
718    for extra in &diff.extra_columns {
719        statements.push(format!(
720            "-- TODO: extra column {}.{} not in models",
721            extra.table, extra.column
722        ));
723    }
724    for table in &diff.extra_tables {
725        statements.push(format!("-- TODO: extra table {} not in models", table));
726    }
727
728    statements
729}
730
/// Reduces a SQL type name to a coarse, backend-independent family, so that equivalent
/// declarations (e.g. `BIGINT` vs `int8`, `VARCHAR(255)` vs `text`) compare equal.
///
/// The input is trimmed and lower-cased, then classified by substring, first match wins:
/// - `int` / `serial` → `"integer"` (but not `interval` or `point`, whose names merely
///   contain `int`);
/// - `char` / `text` / `clob` → `"text"`;
/// - `real` / `floa` / `doub` / `numeric` / `decimal` → `"real"`;
/// - `bool` → `"boolean"`;
/// - `time` / `date` / `uuid` / `json` → `"text"`;
/// - anything else is returned trimmed and lower-cased.
///
/// A trailing `[]` (array type) is preserved around the normalized element type, e.g.
/// `int4[]` → `"integer[]"`. This keeps array columns from comparing equal to scalar
/// columns of the same element type.
fn normalize_sql_type(sql_type: &str) -> String {
    let t = sql_type.trim().to_lowercase();
    if t.is_empty() {
        return t;
    }
    if let Some(element) = t.strip_suffix("[]") {
        return format!("{}[]", normalize_sql_type(element));
    }
    // `interval` and `point`/`multipoint` contain "int" but are not integer types.
    let int_lookalike = t.contains("interval") || t.contains("point");
    if (t.contains("int") && !int_lookalike) || t.contains("serial") {
        return "integer".to_string();
    }
    if t.contains("char") || t.contains("text") || t.contains("clob") {
        return "text".to_string();
    }
    if t.contains("real")
        || t.contains("floa")
        || t.contains("doub")
        || t.contains("numeric")
        || t.contains("decimal")
    {
        return "real".to_string();
    }
    if t.contains("bool") {
        return "boolean".to_string();
    }
    if t.contains("time") || t.contains("date") || t.contains("uuid") || t.contains("json") {
        return "text".to_string();
    }
    t
}
758
/// Generates PostgreSQL migration SQL for a given schema difference.
///
/// Statements are only built, never executed. They are produced in this order:
/// 1. for each missing table, its `CREATE TABLE` (from `SchemaTable::to_create_sql`) and
///    its `CREATE INDEX` statements, or a `-- Missing schema` comment when `expected`
///    has no table of that name;
/// 2. for missing columns, grouped and sorted by table then column name:
///    - an optional rename suggestion, naming the first extra column of the same table
///      whose normalized type matches;
///    - then a `-- TODO` for primary-key columns, or a NOT NULL warning (when
///      applicable) followed by `ALTER TABLE ... ADD COLUMN`;
/// 3. `-- TODO` comments for type, nullability and primary-key mismatches;
/// 4. `CREATE INDEX` for missing indexes and `-- TODO` comments for extra ones;
/// 5. `ALTER TABLE ... ADD CONSTRAINT fk_<table>_<column> FOREIGN KEY ...` for missing
///    foreign keys;
/// 6. `-- TODO` comments for extra foreign keys, extra columns and extra tables.
#[cfg(feature = "postgres")]
pub fn postgres_migration_sql(expected: &[SchemaTable], diff: &SchemaDiff) -> Vec<String> {
    let schemas: BTreeMap<String, &SchemaTable> = expected
        .iter()
        .map(|table| (table.name.clone(), table))
        .collect();
    let mut sql: Vec<String> = Vec::new();

    for name in &diff.missing_tables {
        let Some(schema) = schemas.get(name) else {
            sql.push(format!("-- Missing schema for table {}", name));
            continue;
        };
        sql.push(schema.to_create_sql());
        sql.extend(
            schema
                .indexes
                .iter()
                .map(|index| postgres_create_index_sql(&schema.name, index)),
        );
    }

    let mut pending: BTreeMap<String, BTreeSet<String>> = BTreeMap::new();
    for missing in &diff.missing_columns {
        pending
            .entry(missing.table.clone())
            .or_default()
            .insert(missing.column.clone());
    }

    for (table, names) in &pending {
        let Some(schema) = schemas.get(table) else {
            continue;
        };
        for col in names.iter().filter_map(|n| schema.column(n)) {
            // Heuristic only: an extra column of the same (normalized) type may be the old
            // name of this one. Several missing columns may point at the same candidate.
            let normalized = col.normalized_type();
            let rename_from = diff.extra_columns.iter().find(|extra| {
                extra.table == *table && extra.sql_type.as_deref() == Some(normalized.as_str())
            });
            if let Some(old) = rename_from {
                sql.push(format!(
                    "-- SUGESTION: Potential rename from column '{}' to '{}'?",
                    old.column, col.name
                ));
            }

            if col.primary_key {
                sql.push(format!(
                    "-- TODO: add primary key column {}.{} manually",
                    table, col.name
                ));
                continue;
            }

            let mut add_column = format!(
                "ALTER TABLE {} ADD COLUMN {} {}",
                table, col.name, col.sql_type
            );
            if !col.nullable {
                sql.push(format!(
                    "-- WARNING: Adding NOT NULL column '{}.{}' without a default value will fail if table contains rows.",
                    table, col.name
                ));
                add_column.push_str(" NOT NULL");
            }
            sql.push(add_column);
        }
    }

    sql.extend(diff.type_mismatches.iter().map(|m| {
        format!(
            "-- TODO: column type mismatch {}.{} (expected {}, actual {})",
            m.table, m.column, m.expected, m.actual
        )
    }));
    sql.extend(diff.nullability_mismatches.iter().map(|m| {
        format!(
            "-- TODO: column nullability mismatch {}.{} (expected nullable {}, actual nullable {})",
            m.table, m.column, m.expected_nullable, m.actual_nullable
        )
    }));
    sql.extend(diff.primary_key_mismatches.iter().map(|m| {
        format!(
            "-- TODO: column primary key mismatch {}.{} (expected pk {}, actual pk {})",
            m.table, m.column, m.expected_primary_key, m.actual_primary_key
        )
    }));
    sql.extend(
        diff.missing_indexes
            .iter()
            .map(|(table, index)| postgres_create_index_sql(table, index)),
    );
    sql.extend(diff.extra_indexes.iter().map(|(table, index)| {
        format!(
            "-- TODO: extra index {}.{} ({})",
            table,
            index.name,
            index.columns.join(", ")
        )
    }));
    sql.extend(diff.missing_foreign_keys.iter().map(|(table, fk)| {
        let constraint = format!("fk_{}_{}", table, fk.column);
        format!(
            "ALTER TABLE {} ADD CONSTRAINT {} FOREIGN KEY ({}) REFERENCES {}({})",
            table, constraint, fk.column, fk.ref_table, fk.ref_column
        )
    }));
    sql.extend(diff.extra_foreign_keys.iter().map(|(table, fk)| {
        format!(
            "-- TODO: extra foreign key {}.{} -> {}({})",
            table, fk.column, fk.ref_table, fk.ref_column
        )
    }));
    sql.extend(diff.extra_columns.iter().map(|extra| {
        format!(
            "-- TODO: extra column {}.{} not in models",
            extra.table, extra.column
        )
    }));
    sql.extend(
        diff.extra_tables
            .iter()
            .map(|table| format!("-- TODO: extra table {} not in models", table)),
    );

    sql
}
889
#[cfg(feature = "mysql")]
/// Generates MySQL migration SQL for a given schema difference.
///
/// `expected` supplies the model-side table definitions used to render DDL for
/// missing tables and columns; `diff` is the comparison of those models against
/// the live database. The result mixes executable statements with `-- ...`
/// comment lines for changes that need manual review.
///
/// Output order:
/// 1. `CREATE TABLE` (plus its indexes) for every missing table, or a comment
///    when no model definition exists for it;
/// 2. `ALTER TABLE ... ADD COLUMN` for missing columns, grouped per table and
///    sorted by name, preceded by a rename suggestion when an extra column of
///    the same normalized type exists in that table and by a warning for
///    `NOT NULL` columns; primary-key columns are only reported as TODOs;
/// 3. TODO comments for type, nullability and primary-key mismatches;
/// 4. `CREATE INDEX` for missing indexes, TODOs for extra indexes;
/// 5. `ADD CONSTRAINT ... FOREIGN KEY` for missing foreign keys (named
///    `fk_<table>_<column>`), TODOs for extra foreign keys;
/// 6. TODOs for extra columns and extra tables.
pub fn mysql_migration_sql(expected: &[SchemaTable], diff: &SchemaDiff) -> Vec<String> {
    let expected_map: BTreeMap<String, &SchemaTable> =
        expected.iter().map(|t| (t.name.clone(), t)).collect();
    let mut statements = Vec::new();

    for table in &diff.missing_tables {
        if let Some(schema) = expected_map.get(table) {
            statements.push(schema.to_create_sql());
            for index in &schema.indexes {
                statements.push(mysql_create_index_sql(&schema.name, index));
            }
        } else {
            statements.push(format!("-- Missing schema for table {}", table));
        }
    }

    // Group missing columns per table; the BTree containers give a stable,
    // sorted statement order.
    let mut missing_by_table: BTreeMap<String, BTreeSet<String>> = BTreeMap::new();
    for col in &diff.missing_columns {
        missing_by_table
            .entry(col.table.clone())
            .or_default()
            .insert(col.column.clone());
    }

    for (table, columns) in missing_by_table {
        // Tables/columns without a model definition cannot be rendered; skip them.
        let Some(schema) = expected_map.get(&table) else {
            continue;
        };
        for col_name in columns {
            let Some(col) = schema.column(&col_name) else {
                continue;
            };

            // Normalize once per missing column rather than once per extra
            // column compared inside the search below.
            let expected_type = col.normalized_type();
            let potential_rename = diff.extra_columns.iter().find(|e| {
                e.table == table && e.sql_type.as_deref() == Some(expected_type.as_str())
            });

            if let Some(old) = potential_rename {
                statements.push(format!(
                    "-- SUGESTION: Potential rename from column '{}' to '{}'?",
                    old.column, col.name
                ));
            }

            if col.primary_key {
                statements.push(format!(
                    "-- TODO: add primary key column {}.{} manually",
                    table, col_name
                ));
                continue;
            }

            if !col.nullable {
                statements.push(format!(
                    "-- WARNING: Adding NOT NULL column '{}.{}' without a default value will fail if table contains rows.",
                    table, col.name
                ));
            }

            let mut stmt = format!(
                "ALTER TABLE {} ADD COLUMN {} {}",
                table, col.name, col.sql_type
            );
            if !col.nullable {
                stmt.push_str(" NOT NULL");
            }
            statements.push(stmt);
        }
    }

    for mismatch in &diff.type_mismatches {
        statements.push(format!(
            "-- TODO: column type mismatch {}.{} (expected {}, actual {})",
            mismatch.table, mismatch.column, mismatch.expected, mismatch.actual
        ));
    }
    for mismatch in &diff.nullability_mismatches {
        statements.push(format!(
            "-- TODO: column nullability mismatch {}.{} (expected nullable {}, actual nullable {})",
            mismatch.table, mismatch.column, mismatch.expected_nullable, mismatch.actual_nullable
        ));
    }
    for mismatch in &diff.primary_key_mismatches {
        statements.push(format!(
            "-- TODO: column primary key mismatch {}.{} (expected pk {}, actual pk {})",
            mismatch.table,
            mismatch.column,
            mismatch.expected_primary_key,
            mismatch.actual_primary_key
        ));
    }
    for (table, index) in &diff.missing_indexes {
        statements.push(mysql_create_index_sql(table, index));
    }
    for (table, index) in &diff.extra_indexes {
        statements.push(format!(
            "-- TODO: extra index {}.{} ({})",
            table,
            index.name,
            index.columns.join(", ")
        ));
    }
    for (table, fk) in &diff.missing_foreign_keys {
        let fk_name = format!("fk_{}_{}", table, fk.column);
        statements.push(format!(
            "ALTER TABLE {} ADD CONSTRAINT {} FOREIGN KEY ({}) REFERENCES {}({})",
            table, fk_name, fk.column, fk.ref_table, fk.ref_column
        ));
    }
    for (table, fk) in &diff.extra_foreign_keys {
        statements.push(format!(
            "-- TODO: extra foreign key {}.{} -> {}({})",
            table, fk.column, fk.ref_table, fk.ref_column
        ));
    }
    for extra in &diff.extra_columns {
        statements.push(format!(
            "-- TODO: extra column {}.{} not in models",
            extra.table, extra.column
        ));
    }
    for table in &diff.extra_tables {
        statements.push(format!("-- TODO: extra table {} not in models", table));
    }

    statements
}
1019
1020fn index_key(index: &SchemaIndex) -> (String, String, bool) {
1021    let name = index.name.clone();
1022    let cols = index.columns.join(",");
1023    (name, cols, index.unique)
1024}
1025
1026fn index_map(indexes: &[SchemaIndex]) -> BTreeMap<(String, String, bool), &SchemaIndex> {
1027    indexes.iter().map(|i| (index_key(i), i)).collect()
1028}
1029
1030fn foreign_key_key(fk: &SchemaForeignKey) -> (String, String, String) {
1031    (
1032        fk.column.clone(),
1033        fk.ref_table.clone(),
1034        fk.ref_column.clone(),
1035    )
1036}
1037
1038fn foreign_key_map(
1039    fks: &[SchemaForeignKey],
1040) -> BTreeMap<(String, String, String), &SchemaForeignKey> {
1041    fks.iter().map(|f| (foreign_key_key(f), f)).collect()
1042}
1043
#[cfg(feature = "sqlite")]
/// Introspects the secondary indexes of a SQLite table.
///
/// Uses `PRAGMA index_list` and `PRAGMA index_info`. Skipped indexes:
/// - indexes backing the primary key (`origin == "pk"`);
/// - SQLite's internal `sqlite_autoindex_*` indexes;
/// - indexes with an expression key part. `index_info` reports a NULL column
///   name for those, and `SchemaIndex` can only describe plain column lists.
///
/// # Errors
/// Returns any error raised while running the pragmas or decoding their rows.
async fn introspect_sqlite_indexes(
    pool: &SqlitePool,
    table: &str,
) -> Result<Vec<SchemaIndex>, sqlx::Error> {
    /// Quotes an SQLite identifier, doubling embedded double quotes, so names
    /// that are keywords or contain special characters reach the pragma intact.
    fn quote_ident(ident: &str) -> String {
        format!("\"{}\"", ident.replace('"', "\"\""))
    }

    let sql = format!("PRAGMA index_list({})", quote_ident(table));
    // seq, name, unique, origin, partial
    let rows: Vec<(i64, String, i64, String, i64)> = sqlx::query_as(&sql).fetch_all(pool).await?;

    let mut indexes = Vec::with_capacity(rows.len());
    for (_seq, name, unique, origin, _partial) in rows {
        if origin == "pk" || name.starts_with("sqlite_autoindex") {
            continue;
        }
        let info_sql = format!("PRAGMA index_info({})", quote_ident(&name));
        // seqno, cid, name. The name is NULL for expression key parts, so it
        // must be decoded as an Option or the whole introspection fails.
        let info_rows: Vec<(i64, i64, Option<String>)> =
            sqlx::query_as(&info_sql).fetch_all(pool).await?;
        let Some(columns) = info_rows
            .into_iter()
            .map(|(_seqno, _cid, col)| col)
            .collect::<Option<Vec<String>>>()
        else {
            // Expression index: not representable as a column list.
            continue;
        };
        indexes.push(SchemaIndex {
            name,
            columns,
            unique: unique != 0,
        });
    }
    Ok(indexes)
}
1068
#[cfg(feature = "sqlite")]
/// Introspects the foreign keys declared on a SQLite table.
///
/// Reads `PRAGMA foreign_key_list`. Every returned row becomes one
/// `SchemaForeignKey`; there is one row per referencing column, so composite
/// keys yield several entries.
///
/// A constraint may omit the parent column list (`REFERENCES parent`). SQLite
/// then reports the target column as NULL, and it is resolved to the parent
/// table's primary-key column at the same position, read via
/// `PRAGMA table_info`. If that lookup finds nothing (for example, the parent
/// table does not exist), `ref_column` is left empty.
///
/// # Errors
/// Returns any error raised while running the pragmas or decoding their rows.
async fn introspect_sqlite_foreign_keys(
    pool: &SqlitePool,
    table: &str,
) -> Result<Vec<SchemaForeignKey>, sqlx::Error> {
    /// Quotes an SQLite identifier, doubling embedded double quotes.
    fn quote_ident(ident: &str) -> String {
        format!("\"{}\"", ident.replace('"', "\"\""))
    }

    let sql = format!("PRAGMA foreign_key_list({})", quote_ident(table));
    // Pragma output columns: id, seq, table, from, to, on_update, on_delete,
    // match. `to` is NULL when the parent key is implicit.
    #[allow(clippy::type_complexity)]
    let rows: Vec<(i64, i64, String, String, Option<String>, String, String, String)> =
        sqlx::query_as(&sql).fetch_all(pool).await?;

    // Parent table -> its primary-key columns in key order, fetched at most once.
    let mut parent_pks: BTreeMap<String, Vec<String>> = BTreeMap::new();
    let mut fks = Vec::with_capacity(rows.len());
    for (_id, seq, ref_table, from, to, _on_update, _on_delete, _match) in rows {
        let ref_column = match to {
            Some(column) => column,
            None => {
                if !parent_pks.contains_key(&ref_table) {
                    let info_sql = format!("PRAGMA table_info({})", quote_ident(&ref_table));
                    // Pragma output columns: cid, name, type, notnull, dflt_value, pk.
                    #[allow(clippy::type_complexity)]
                    let info: Vec<(i64, String, String, i64, Option<String>, i64)> =
                        sqlx::query_as(&info_sql).fetch_all(pool).await?;
                    // `pk` is the 1-based position within the primary key, 0 otherwise.
                    let mut pk_columns: Vec<(i64, String)> = info
                        .into_iter()
                        .filter(|(.., pk)| *pk > 0)
                        .map(|(_cid, name, .., pk)| (pk, name))
                        .collect();
                    pk_columns.sort_unstable();
                    parent_pks.insert(
                        ref_table.clone(),
                        pk_columns.into_iter().map(|(_, name)| name).collect(),
                    );
                }
                // `seq` is this column's 0-based position within the foreign key.
                usize::try_from(seq)
                    .ok()
                    .and_then(|pos| parent_pks.get(&ref_table)?.get(pos).cloned())
                    .unwrap_or_default()
            }
        };
        fks.push(SchemaForeignKey {
            column: from,
            ref_table,
            ref_column,
        });
    }
    Ok(fks)
}
1089
1090fn sqlite_create_index_sql(table: &str, index: &SchemaIndex) -> String {
1091    let unique = if index.unique { "UNIQUE " } else { "" };
1092    let name = if index.name.is_empty() {
1093        format!("idx_{}_{}", table, index.columns.join("_"))
1094    } else {
1095        index.name.clone()
1096    };
1097    format!(
1098        "CREATE {}INDEX IF NOT EXISTS {} ON {} ({})",
1099        unique,
1100        name,
1101        table,
1102        index.columns.join(", ")
1103    )
1104}
1105
#[cfg(feature = "postgres")]
/// Renders a PostgreSQL `CREATE [UNIQUE] INDEX IF NOT EXISTS` statement for
/// `index` on `table`.
///
/// An unnamed index gets the generated name `idx_<table>_<col1>_<col2>...`.
fn postgres_create_index_sql(table: &str, index: &SchemaIndex) -> String {
    let index_name = Some(index.name.clone())
        .filter(|name| !name.is_empty())
        .unwrap_or_else(|| format!("idx_{}_{}", table, index.columns.join("_")));
    let unique_kw = match index.unique {
        true => "UNIQUE ",
        false => "",
    };
    format!(
        "CREATE {unique_kw}INDEX IF NOT EXISTS {index_name} ON {table} ({columns})",
        columns = index.columns.join(", ")
    )
}
1122
#[cfg(feature = "mysql")]
/// Renders a MySQL `CREATE [UNIQUE] INDEX` statement for `index` on `table`.
///
/// No `IF NOT EXISTS` clause is emitted. An unnamed index gets the generated
/// name `idx_<table>_<col1>_<col2>...`.
fn mysql_create_index_sql(table: &str, index: &SchemaIndex) -> String {
    let mut sql = String::from("CREATE ");
    if index.unique {
        sql.push_str("UNIQUE ");
    }
    sql.push_str("INDEX ");
    if index.name.is_empty() {
        sql.push_str("idx_");
        sql.push_str(table);
        sql.push('_');
        sql.push_str(&index.columns.join("_"));
    } else {
        sql.push_str(&index.name);
    }
    sql.push_str(" ON ");
    sql.push_str(table);
    sql.push_str(" (");
    sql.push_str(&index.columns.join(", "));
    sql.push(')');
    sql
}
1139
#[cfg(feature = "postgres")]
/// Introspects the non-primary-key indexes of a table in the `public` schema.
///
/// Each index is returned with its columns in index-key order. Expression key
/// parts have no matching `pg_attribute` row, so they are not listed, and an
/// index made only of expressions produces no row at all.
///
/// # Errors
/// Returns any error from executing the catalog query or decoding its rows.
async fn introspect_postgres_indexes(
    pool: &PgPool,
    table: &str,
) -> Result<Vec<SchemaIndex>, sqlx::Error> {
    // Restricting to `public` (as the foreign-key introspection does) keeps
    // same-named tables in other schemas from contributing indexes. Grouping by
    // the index oid keeps distinct indexes apart. Names are cast to `text` so
    // they decode as `String` / `Vec<String>`.
    let rows: Vec<(String, bool, Vec<String>)> = sqlx::query_as(
        "SELECT i.relname::text AS index_name, ix.indisunique,
                array_agg(a.attname::text ORDER BY x.n) AS columns
         FROM pg_class t
         JOIN pg_namespace ns ON ns.oid = t.relnamespace
         JOIN pg_index ix ON t.oid = ix.indrelid
         JOIN pg_class i ON i.oid = ix.indexrelid
         JOIN LATERAL unnest(ix.indkey) WITH ORDINALITY AS x(attnum, n) ON true
         JOIN pg_attribute a ON a.attrelid = t.oid AND a.attnum = x.attnum
         WHERE t.relname = $1
           AND t.relkind = 'r'
           AND ns.nspname = 'public'
           AND NOT ix.indisprimary
         GROUP BY i.oid, i.relname, ix.indisunique
         ORDER BY i.relname",
    )
    .bind(table)
    .fetch_all(pool)
    .await?;

    let indexes = rows
        .into_iter()
        .map(|(name, unique, columns)| SchemaIndex {
            name,
            columns,
            unique,
        })
        .collect();
    Ok(indexes)
}
1170
#[cfg(feature = "postgres")]
/// Introspects the foreign keys declared on a table in the `public` schema.
///
/// Returns one `SchemaForeignKey` per referencing column. The local and
/// referenced columns of a composite key are paired by their position in the
/// constraint, ordered by constraint name and then key position.
///
/// # Errors
/// Returns any error from executing the catalog query or decoding its rows.
async fn introspect_postgres_foreign_keys(
    pool: &PgPool,
    table: &str,
) -> Result<Vec<SchemaForeignKey>, sqlx::Error> {
    // `pg_constraint` is read directly instead of joining the information_schema
    // views by constraint name. Those joins cross-multiply the column pairs of
    // composite keys, and can mix constraints from different tables, because
    // constraint names are only unique per table. `conkey` and `confkey` are
    // parallel arrays, so unnesting them together pairs each local column with
    // its target.
    let rows: Vec<(String, String, String)> = sqlx::query_as(
        "SELECT a.attname::text, rt.relname::text, ra.attname::text
         FROM pg_constraint c
         JOIN pg_class t ON t.oid = c.conrelid
         JOIN pg_namespace ns ON ns.oid = t.relnamespace
         JOIN pg_class rt ON rt.oid = c.confrelid
         JOIN LATERAL unnest(c.conkey, c.confkey) WITH ORDINALITY AS k(attnum, ref_attnum, ord) ON true
         JOIN pg_attribute a ON a.attrelid = c.conrelid AND a.attnum = k.attnum
         JOIN pg_attribute ra ON ra.attrelid = c.confrelid AND ra.attnum = k.ref_attnum
         WHERE c.contype = 'f'
           AND ns.nspname = 'public'
           AND t.relname = $1
         ORDER BY c.conname, k.ord",
    )
    .bind(table)
    .fetch_all(pool)
    .await?;

    let fks = rows
        .into_iter()
        .map(|(column, ref_table, ref_column)| SchemaForeignKey {
            column,
            ref_table,
            ref_column,
        })
        .collect();

    Ok(fks)
}
1203
#[cfg(feature = "mysql")]
/// Introspects the non-primary indexes of a MySQL table in the current database.
///
/// Column names come from `information_schema.statistics`. They are
/// concatenated per index in `seq_in_index` order and then split back on `,`.
/// Indexes for which no column name survives (e.g. the concatenation is NULL)
/// are skipped. An index is unique when `non_unique` is 0.
///
/// # Errors
/// Returns any error from running the query or decoding its rows.
async fn introspect_mysql_indexes(
    pool: &MySqlPool,
    table: &str,
) -> Result<Vec<SchemaIndex>, sqlx::Error> {
    const INDEX_SQL: &str = "SELECT index_name, non_unique, GROUP_CONCAT(column_name ORDER BY seq_in_index) AS columns
         FROM information_schema.statistics
         WHERE table_schema = DATABASE() AND table_name = ? AND index_name != 'PRIMARY'
         GROUP BY index_name, non_unique
         ORDER BY index_name";

    let rows: Vec<(String, i64, Option<String>)> = sqlx::query_as(INDEX_SQL)
        .bind(table)
        .fetch_all(pool)
        .await?;

    let indexes = rows
        .into_iter()
        .filter_map(|(name, non_unique, joined)| {
            let columns: Vec<String> = joined
                .as_deref()
                .unwrap_or_default()
                .split(',')
                .filter(|part| !part.is_empty())
                .map(str::to_owned)
                .collect();
            if columns.is_empty() {
                None
            } else {
                Some(SchemaIndex {
                    name,
                    columns,
                    unique: non_unique == 0,
                })
            }
        })
        .collect();

    Ok(indexes)
}
1240
#[cfg(feature = "mysql")]
/// Introspects the foreign keys of a MySQL table in the current database.
///
/// Every `information_schema.key_column_usage` row with a referenced table
/// becomes one `SchemaForeignKey`, in `ordinal_position` order.
///
/// # Errors
/// Returns any error from running the query or decoding its rows.
async fn introspect_mysql_foreign_keys(
    pool: &MySqlPool,
    table: &str,
) -> Result<Vec<SchemaForeignKey>, sqlx::Error> {
    const FK_SQL: &str = "SELECT column_name, referenced_table_name, referenced_column_name
         FROM information_schema.key_column_usage
         WHERE table_schema = DATABASE()
           AND table_name = ?
           AND referenced_table_name IS NOT NULL
         ORDER BY ordinal_position";

    let rows: Vec<(String, String, String)> = sqlx::query_as(FK_SQL)
        .bind(table)
        .fetch_all(pool)
        .await?;

    let mut fks = Vec::with_capacity(rows.len());
    for (column, ref_table, ref_column) in rows {
        fks.push(SchemaForeignKey {
            column,
            ref_table,
            ref_column,
        });
    }
    Ok(fks)
}
1269
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for building an expected `SchemaColumn` fixture.
    #[cfg(any(feature = "sqlite", feature = "postgres"))]
    fn column(name: &str, sql_type: &str, nullable: bool, primary_key: bool) -> SchemaColumn {
        SchemaColumn {
            name: name.to_string(),
            sql_type: sql_type.to_string(),
            nullable,
            primary_key,
        }
    }

    /// Shorthand for building an expected `SchemaTable` fixture (no `create_sql`).
    #[cfg(any(feature = "sqlite", feature = "postgres"))]
    fn table(
        name: &str,
        columns: Vec<SchemaColumn>,
        indexes: Vec<SchemaIndex>,
        foreign_keys: Vec<SchemaForeignKey>,
    ) -> SchemaTable {
        SchemaTable {
            name: name.to_string(),
            columns,
            indexes,
            foreign_keys,
            create_sql: None,
        }
    }

    /// Shorthand for a non-unique `SchemaIndex` fixture.
    #[cfg(any(feature = "sqlite", feature = "postgres"))]
    fn index(name: &str, columns: &[&str]) -> SchemaIndex {
        SchemaIndex {
            name: name.to_string(),
            columns: columns.iter().map(|c| c.to_string()).collect(),
            unique: false,
        }
    }

    /// Opens an in-memory SQLite pool and runs `ddl` against it.
    #[cfg(feature = "sqlite")]
    async fn sqlite_pool_with(ddl: &str) -> SqlitePool {
        let pool = SqlitePool::connect("sqlite::memory:").await.unwrap();
        sqlx::query(ddl).execute(&pool).await.unwrap();
        pool
    }

    #[cfg(feature = "sqlite")]
    #[tokio::test]
    async fn sqlite_introspect_and_diff_empty() {
        let pool = sqlite_pool_with(
            "CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT NOT NULL, deleted_at TEXT)",
        )
        .await;

        let expected = vec![table(
            "users",
            vec![
                column("id", "INTEGER", false, true),
                column("name", "TEXT", false, false),
                column("deleted_at", "TEXT", true, false),
            ],
            Vec::new(),
            Vec::new(),
        )];

        let actual = introspect_sqlite_schema(&pool).await.unwrap();
        let diff = diff_schema(&expected, &actual);
        assert!(diff.is_empty());
    }

    #[cfg(feature = "sqlite")]
    #[tokio::test]
    async fn sqlite_diff_reports_missing_column() {
        let pool = sqlite_pool_with("CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT)").await;

        let expected = vec![table(
            "users",
            vec![
                column("id", "INTEGER", false, true),
                column("name", "TEXT", false, false),
                column("status", "TEXT", true, false),
            ],
            Vec::new(),
            Vec::new(),
        )];

        let actual = introspect_sqlite_schema(&pool).await.unwrap();
        let diff = diff_schema(&expected, &actual);
        assert_eq!(diff.missing_columns.len(), 1);

        let summary = format_schema_diff_summary(&diff);
        assert!(summary.contains("missing columns: 1"));

        let sql = sqlite_migration_sql(&expected, &diff);
        assert!(
            sql.iter()
                .any(|stmt| stmt.contains("ALTER TABLE users ADD COLUMN status"))
        );
    }

    #[cfg(feature = "sqlite")]
    #[tokio::test]
    async fn sqlite_diff_reports_missing_index() {
        let pool = sqlite_pool_with("CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT)").await;

        let expected = vec![table(
            "users",
            vec![
                column("id", "INTEGER", false, true),
                column("name", "TEXT", false, false),
            ],
            vec![index("idx_users_name", &["name"])],
            Vec::new(),
        )];

        let actual = introspect_sqlite_schema(&pool).await.unwrap();
        let diff = diff_schema(&expected, &actual);
        assert_eq!(diff.missing_indexes.len(), 1);

        let sql = sqlite_migration_sql(&expected, &diff);
        assert!(
            sql.iter()
                .any(|stmt| stmt.contains("CREATE INDEX IF NOT EXISTS idx_users_name"))
        );
    }

    /// Connection string for the postgres test: `DATABASE_URL` if set,
    /// otherwise a local development default.
    #[cfg(feature = "postgres")]
    fn pg_url() -> String {
        std::env::var("DATABASE_URL").unwrap_or_else(|_| {
            "postgres://postgres:admin123@localhost:5432/premix_bench".to_string()
        })
    }

    /// Drops the postgres test tables, child first because of the foreign key.
    #[cfg(feature = "postgres")]
    async fn drop_pg_test_tables(pool: &PgPool) {
        for sql in [
            "DROP TABLE IF EXISTS schema_posts",
            "DROP TABLE IF EXISTS schema_users",
        ] {
            sqlx::query(sql).execute(pool).await.unwrap();
        }
    }

    #[cfg(feature = "postgres")]
    #[tokio::test]
    async fn postgres_introspect_and_diff() {
        let url = pg_url();
        let pool = match PgPool::connect(&url).await {
            Ok(pool) => pool,
            Err(err) => {
                // No database available: skip, but say so instead of passing silently.
                eprintln!("skipping postgres_introspect_and_diff: cannot connect ({err})");
                return;
            }
        };

        drop_pg_test_tables(&pool).await;
        for ddl in [
            "CREATE TABLE schema_users (id SERIAL PRIMARY KEY, name TEXT NOT NULL)",
            "CREATE TABLE schema_posts (id SERIAL PRIMARY KEY, user_id INTEGER NOT NULL, title TEXT NOT NULL, CONSTRAINT fk_schema_posts_user_id FOREIGN KEY (user_id) REFERENCES schema_users(id))",
            "CREATE INDEX idx_schema_posts_user_id ON schema_posts (user_id)",
        ] {
            sqlx::query(ddl).execute(&pool).await.unwrap();
        }

        let expected = vec![
            table(
                "schema_posts",
                vec![
                    column("id", "INTEGER", false, true),
                    column("user_id", "INTEGER", false, false),
                    column("title", "TEXT", false, false),
                ],
                vec![index("idx_schema_posts_user_id", &["user_id"])],
                vec![SchemaForeignKey {
                    column: "user_id".to_string(),
                    ref_table: "schema_users".to_string(),
                    ref_column: "id".to_string(),
                }],
            ),
            table(
                "schema_users",
                vec![
                    column("id", "INTEGER", false, true),
                    column("name", "TEXT", false, false),
                ],
                Vec::new(),
                Vec::new(),
            ),
        ];

        // The database may hold unrelated tables; compare only the ones under test.
        let expected_names: BTreeSet<String> =
            expected.iter().map(|table| table.name.clone()).collect();
        let actual = introspect_postgres_schema(&pool)
            .await
            .unwrap()
            .into_iter()
            .filter(|table| expected_names.contains(&table.name))
            .collect::<Vec<_>>();
        let diff = diff_schema(&expected, &actual);
        assert!(diff.is_empty(), "postgres schema diff: {diff:?}");

        drop_pg_test_tables(&pool).await;
    }
}