Skip to main content

premix_core/
schema.rs

1use std::collections::{BTreeMap, BTreeSet};
2
3#[cfg(feature = "postgres")]
4use sqlx::PgPool;
5#[cfg(feature = "sqlite")]
6use sqlx::SqlitePool;
7
/// Description of a single table column, either declared by a model or read from a database.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct SchemaColumn {
    /// Column name.
    pub name: String,
    /// Declared SQL type as written (for example `"INTEGER"` or `"TEXT"`).
    pub sql_type: String,
    /// `true` when the column accepts `NULL`.
    pub nullable: bool,
    /// `true` when the column belongs to the table's primary key.
    pub primary_key: bool,
}
20
21impl SchemaColumn {
22    fn normalized_type(&self) -> String {
23        normalize_sql_type(&self.sql_type)
24    }
25}
26
/// Description of an index defined on a table.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct SchemaIndex {
    /// Index name. An empty name makes the `CREATE INDEX` generators derive
    /// `idx_<table>_<columns>` instead.
    pub name: String,
    /// Indexed column names, in the order they appear in the index.
    pub columns: Vec<String>,
    /// `true` for a `UNIQUE` index.
    pub unique: bool,
}
37
/// Description of a single-column foreign key from this table to another table.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct SchemaForeignKey {
    /// Referencing column in the owning table.
    pub column: String,
    /// Name of the referenced (parent) table.
    pub ref_table: String,
    /// Referenced column in the parent table.
    pub ref_column: String,
}
48
49/// Metadata about a database table.
50#[derive(Debug, Clone, PartialEq, Eq)]
51pub struct SchemaTable {
52    /// The name of the table.
53    pub name: String,
54    /// The columns in the table.
55    pub columns: Vec<SchemaColumn>,
56    /// The indexes on the table.
57    pub indexes: Vec<SchemaIndex>,
58    /// The foreign keys in the table.
59    pub foreign_keys: Vec<SchemaForeignKey>,
60    /// The original CREATE TABLE SQL (if available).
61    pub create_sql: Option<String>,
62}
63
64impl SchemaTable {
65    /// Returns a column by name if it exists in the table.
66    pub fn column(&self, name: &str) -> Option<&SchemaColumn> {
67        self.columns.iter().find(|c| c.name == name)
68    }
69
70    /// Generates a `CREATE TABLE` SQL statement for this table.
71    pub fn to_create_sql(&self) -> String {
72        if let Some(sql) = &self.create_sql {
73            return sql.clone();
74        }
75
76        let mut cols = Vec::new();
77        for col in &self.columns {
78            if col.primary_key {
79                cols.push(format!("{} INTEGER PRIMARY KEY", col.name));
80                continue;
81            }
82            let mut def = format!("{} {}", col.name, col.sql_type);
83            if !col.nullable {
84                def.push_str(" NOT NULL");
85            }
86            cols.push(def);
87        }
88
89        format!(
90            "CREATE TABLE IF NOT EXISTS {} ({})",
91            self.name,
92            cols.join(", ")
93        )
94    }
95}
96
97/// A trait for models that can provide their own schema metadata.
98pub trait ModelSchema {
99    /// Returns the schema metadata for this model.
100    fn schema() -> SchemaTable;
101}
102
/// Collects the `ModelSchema` metadata of several model types into a `Vec<SchemaTable>`.
///
/// Accepts one or more types (a trailing comma is allowed). It expands to a vector holding
/// `<T as ModelSchema>::schema()` for each type, in the order given.
///
/// ```ignore
/// let expected = schema_models![User, Post];
/// ```
#[macro_export]
macro_rules! schema_models {
    ($($model:ty),+ $(,)?) => {
        // Fully qualified so a `vec!` macro defined at the call site cannot shadow std's.
        ::std::vec![$(<$model as $crate::schema::ModelSchema>::schema()),+]
    };
}
110
/// A column present on only one side of a schema comparison (missing or extra).
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ColumnDiff {
    /// Table that owns the column.
    pub table: String,
    /// Column name.
    pub column: String,
    /// Column type, when known; `diff_schema` records the normalized form.
    pub sql_type: Option<String>,
}
121
/// A column whose type in the database differs from the type the model declares.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ColumnTypeDiff {
    /// Table that owns the column.
    pub table: String,
    /// Column name.
    pub column: String,
    /// Type declared by the model.
    pub expected: String,
    /// Type found in the database.
    pub actual: String,
}
134
/// A column whose nullability in the database differs from the model's declaration.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ColumnNullabilityDiff {
    /// Table that owns the column.
    pub table: String,
    /// Column name.
    pub column: String,
    /// Nullability declared by the model.
    pub expected_nullable: bool,
    /// Nullability found in the database.
    pub actual_nullable: bool,
}
147
/// A column whose primary-key membership in the database differs from the model's.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ColumnPrimaryKeyDiff {
    /// Table that owns the column.
    pub table: String,
    /// Column name.
    pub column: String,
    /// Whether the model marks the column as part of the primary key.
    pub expected_primary_key: bool,
    /// Whether the database has the column in its primary key.
    pub actual_primary_key: bool,
}
160
161/// Represents the differences between two database schemas.
162#[derive(Debug, Clone, PartialEq, Eq, Default)]
163pub struct SchemaDiff {
164    /// Tables expected but missing in the actual database.
165    pub missing_tables: Vec<String>,
166    /// Tables present in the database but not in the models.
167    pub extra_tables: Vec<String>,
168    /// Columns missing in existing tables.
169    pub missing_columns: Vec<ColumnDiff>,
170    /// Columns present in the database but not in the models.
171    pub extra_columns: Vec<ColumnDiff>,
172    /// Columns with different types than expected.
173    pub type_mismatches: Vec<ColumnTypeDiff>,
174    /// Columns with different nullability than expected.
175    pub nullability_mismatches: Vec<ColumnNullabilityDiff>,
176    /// Columns with different Primary Key status than expected.
177    pub primary_key_mismatches: Vec<ColumnPrimaryKeyDiff>,
178    /// Indexes missing in the actual database.
179    pub missing_indexes: Vec<(String, SchemaIndex)>,
180    /// Indexes present in the database but not in the models.
181    pub extra_indexes: Vec<(String, SchemaIndex)>,
182    /// Foreign keys missing in the actual database.
183    pub missing_foreign_keys: Vec<(String, SchemaForeignKey)>,
184    /// Foreign keys present in the database but not in the models.
185    pub extra_foreign_keys: Vec<(String, SchemaForeignKey)>,
186}
187
188impl SchemaDiff {
189    /// Returns true if there are no differences.
190    pub fn is_empty(&self) -> bool {
191        self.missing_tables.is_empty()
192            && self.extra_tables.is_empty()
193            && self.missing_columns.is_empty()
194            && self.extra_columns.is_empty()
195            && self.type_mismatches.is_empty()
196            && self.nullability_mismatches.is_empty()
197            && self.primary_key_mismatches.is_empty()
198            && self.missing_indexes.is_empty()
199            && self.extra_indexes.is_empty()
200            && self.missing_foreign_keys.is_empty()
201            && self.extra_foreign_keys.is_empty()
202    }
203}
204
205/// Formats a [`SchemaDiff`] into a human-readable summary.
206pub fn format_schema_diff_summary(diff: &SchemaDiff) -> String {
207    if diff.is_empty() {
208        return "Schema diff: no changes".to_string();
209    }
210
211    let mut lines = Vec::new();
212    lines.push("Schema diff summary:".to_string());
213    lines.push(format!("  missing tables: {}", diff.missing_tables.len()));
214    lines.push(format!("  extra tables: {}", diff.extra_tables.len()));
215    lines.push(format!("  missing columns: {}", diff.missing_columns.len()));
216    lines.push(format!("  extra columns: {}", diff.extra_columns.len()));
217    lines.push(format!("  type mismatches: {}", diff.type_mismatches.len()));
218    lines.push(format!(
219        "  nullability mismatches: {}",
220        diff.nullability_mismatches.len()
221    ));
222    lines.push(format!(
223        "  primary key mismatches: {}",
224        diff.primary_key_mismatches.len()
225    ));
226    lines.push(format!("  missing indexes: {}", diff.missing_indexes.len()));
227    lines.push(format!("  extra indexes: {}", diff.extra_indexes.len()));
228    lines.push(format!(
229        "  missing foreign keys: {}",
230        diff.missing_foreign_keys.len()
231    ));
232    lines.push(format!(
233        "  extra foreign keys: {}",
234        diff.extra_foreign_keys.len()
235    ));
236
237    if !diff.missing_tables.is_empty() {
238        lines.push(format!(
239            "  missing tables list: {}",
240            diff.missing_tables.join(", ")
241        ));
242    }
243    if !diff.extra_tables.is_empty() {
244        lines.push(format!(
245            "  extra tables list: {}",
246            diff.extra_tables.join(", ")
247        ));
248    }
249
250    lines.join("\n")
251}
252
/// Introspects the schema of a SQLite database.
///
/// Returns one [`SchemaTable`] per table listed in `sqlite_master`, ordered by name. SQLite's
/// reserved `sqlite_*` tables and the `_premix_migrations` table are excluded. Columns come
/// from `PRAGMA table_info`; a column is marked non-nullable when it is declared `NOT NULL` or
/// belongs to the primary key. Indexes and foreign keys are attached, and `create_sql` is
/// always `None`. Tables for which `table_info` returns no columns are skipped.
///
/// # Errors
///
/// Returns any [`sqlx::Error`] raised while querying `sqlite_master` or running the pragmas.
#[cfg(feature = "sqlite")]
pub async fn introspect_sqlite_schema(pool: &SqlitePool) -> Result<Vec<SchemaTable>, sqlx::Error> {
    // `_` is a LIKE wildcard, so it is escaped to exclude only the reserved `sqlite_` prefix;
    // unescaped, a user table such as `sqliteX` would be silently dropped.
    let table_names: Vec<String> = sqlx::query_scalar(
        r"SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite\_%' ESCAPE '\' AND name != '_premix_migrations' ORDER BY name",
    )
    .fetch_all(pool)
    .await?;

    let mut tables = Vec::new();
    for name in table_names {
        // PRAGMA arguments cannot be bound, so quote the identifier (doubling any embedded
        // quotes). Names containing spaces or SQL keywords (e.g. `order`) then work.
        let pragma_sql = format!("PRAGMA table_info(\"{}\")", name.replace('"', "\"\""));
        let rows: Vec<(i64, String, String, i64, Option<String>, i64)> =
            sqlx::query_as(&pragma_sql).fetch_all(pool).await?;

        if rows.is_empty() {
            continue;
        }

        let columns = rows
            .into_iter()
            .map(|(_cid, col_name, col_type, notnull, _default, pk)| {
                // `pk` is the column's 1-based position in the primary key, 0 if not a key.
                let is_pk = pk > 0;
                SchemaColumn {
                    name: col_name,
                    sql_type: col_type,
                    nullable: !is_pk && notnull == 0,
                    primary_key: is_pk,
                }
            })
            .collect();

        let indexes = introspect_sqlite_indexes(pool, &name).await?;
        let foreign_keys = introspect_sqlite_foreign_keys(pool, &name).await?;

        tables.push(SchemaTable {
            name,
            columns,
            indexes,
            foreign_keys,
            create_sql: None,
        });
    }

    Ok(tables)
}
299
/// Introspects the schema of a PostgreSQL database.
///
/// Reads every base table in the `public` schema except `_premix_migrations`, ordered by name.
/// For each table it reads:
///
/// * the columns, from `information_schema.columns` in ordinal order;
/// * primary-key membership, from `pg_index`;
/// * the indexes and foreign keys.
///
/// `create_sql` is always `None`. Array columns are reported as `<element>[]` (e.g. `int4[]`),
/// and user-defined types such as enums are reported by their type name rather than
/// `USER-DEFINED`. Primary-key columns are always reported as non-nullable.
///
/// # Errors
///
/// Returns any [`sqlx::Error`] raised by the catalog queries.
#[cfg(feature = "postgres")]
pub async fn introspect_postgres_schema(pool: &PgPool) -> Result<Vec<SchemaTable>, sqlx::Error> {
    let table_names: Vec<String> = sqlx::query_scalar(
        "SELECT table_name FROM information_schema.tables WHERE table_schema='public' AND table_type='BASE TABLE' AND table_name != '_premix_migrations' ORDER BY table_name",
    )
    .fetch_all(pool)
    .await?;

    let mut tables = Vec::new();
    for name in table_names {
        // Resolve the relation explicitly as `public.<quoted name>`. A bare `$1::regclass`
        // lower-cases unquoted names and searches `search_path`, so mixed-case table names
        // (or a search_path without `public`) would fail or hit the wrong relation.
        let pk_cols: Vec<String> = sqlx::query_scalar(
            "SELECT a.attname FROM pg_index i JOIN pg_attribute a ON a.attrelid=i.indrelid AND a.attnum = ANY(i.indkey) WHERE i.indrelid=('public.' || quote_ident($1))::regclass AND i.indisprimary",
        )
        .bind(&name)
        .fetch_all(pool)
        .await?;
        let pk_set: BTreeSet<String> = pk_cols.into_iter().collect();

        let rows: Vec<(String, String, String, String)> = sqlx::query_as(
            "SELECT column_name, data_type, udt_name, is_nullable FROM information_schema.columns WHERE table_schema='public' AND table_name=$1 ORDER BY ordinal_position",
        )
        .bind(&name)
        .fetch_all(pool)
        .await?;

        if rows.is_empty() {
            continue;
        }

        let columns = rows
            .into_iter()
            .map(|(col_name, data_type, udt_name, is_nullable)| {
                let is_pk = pk_set.contains(&col_name);
                let sql_type = if data_type.eq_ignore_ascii_case("ARRAY") {
                    // Array udt names carry a leading underscore (`_int4` -> `int4[]`).
                    let base = udt_name.trim_start_matches('_');
                    format!("{}[]", base)
                } else if data_type.eq_ignore_ascii_case("USER-DEFINED") {
                    udt_name
                } else {
                    data_type
                };
                SchemaColumn {
                    name: col_name,
                    sql_type,
                    nullable: !is_pk && is_nullable.eq_ignore_ascii_case("YES"),
                    primary_key: is_pk,
                }
            })
            .collect();

        let indexes = introspect_postgres_indexes(pool, &name).await?;
        let foreign_keys = introspect_postgres_foreign_keys(pool, &name).await?;

        tables.push(SchemaTable {
            name,
            columns,
            indexes,
            foreign_keys,
            create_sql: None,
        });
    }

    Ok(tables)
}
365
/// Introspects the SQLite database behind `pool` and diffs it against `expected`.
///
/// # Errors
///
/// Propagates any [`sqlx::Error`] from [`introspect_sqlite_schema`].
#[cfg(feature = "sqlite")]
pub async fn diff_sqlite_schema(
    pool: &SqlitePool,
    expected: &[SchemaTable],
) -> Result<SchemaDiff, sqlx::Error> {
    introspect_sqlite_schema(pool)
        .await
        .map(|actual| diff_schema(expected, &actual))
}
375
/// Introspects the PostgreSQL database behind `pool` and diffs it against `expected`.
///
/// # Errors
///
/// Propagates any [`sqlx::Error`] from [`introspect_postgres_schema`].
#[cfg(feature = "postgres")]
pub async fn diff_postgres_schema(
    pool: &PgPool,
    expected: &[SchemaTable],
) -> Result<SchemaDiff, sqlx::Error> {
    introspect_postgres_schema(pool)
        .await
        .map(|actual| diff_schema(expected, &actual))
}
385
386/// Calculates the difference between two sets of table metadata.
387pub fn diff_schema(expected: &[SchemaTable], actual: &[SchemaTable]) -> SchemaDiff {
388    let mut diff = SchemaDiff::default();
389
390    let expected_map: BTreeMap<_, _> = expected.iter().map(|t| (&t.name, t)).collect();
391    let actual_map: BTreeMap<_, _> = actual.iter().map(|t| (&t.name, t)).collect();
392
393    for name in expected_map.keys() {
394        if !actual_map.contains_key(name) {
395            diff.missing_tables.push((*name).to_string());
396        }
397    }
398    for name in actual_map.keys() {
399        if !expected_map.contains_key(name) {
400            diff.extra_tables.push((*name).to_string());
401        }
402    }
403
404    for (name, expected_table) in &expected_map {
405        let Some(actual_table) = actual_map.get(name) else {
406            continue;
407        };
408
409        let expected_cols: BTreeMap<_, _> = expected_table
410            .columns
411            .iter()
412            .map(|c| (&c.name, c))
413            .collect();
414        let actual_cols: BTreeMap<_, _> =
415            actual_table.columns.iter().map(|c| (&c.name, c)).collect();
416
417        for col in expected_cols.keys() {
418            if !actual_cols.contains_key(col) {
419                let sql_type = expected_cols.get(col).map(|c| c.normalized_type());
420                diff.missing_columns.push(ColumnDiff {
421                    table: (*name).to_string(),
422                    column: (*col).to_string(),
423                    sql_type,
424                });
425            }
426        }
427        for col in actual_cols.keys() {
428            if !expected_cols.contains_key(col) {
429                let sql_type = actual_cols.get(col).map(|c| c.normalized_type());
430                diff.extra_columns.push(ColumnDiff {
431                    table: (*name).to_string(),
432                    column: (*col).to_string(),
433                    sql_type,
434                });
435            }
436        }
437
438        for (col_name, expected_col) in &expected_cols {
439            let Some(actual_col) = actual_cols.get(col_name) else {
440                continue;
441            };
442
443            let expected_type = expected_col.normalized_type();
444            let actual_type = actual_col.normalized_type();
445            if expected_type != actual_type {
446                diff.type_mismatches.push(ColumnTypeDiff {
447                    table: (*name).to_string(),
448                    column: (*col_name).to_string(),
449                    expected: expected_col.sql_type.clone(),
450                    actual: actual_col.sql_type.clone(),
451                });
452            }
453
454            if expected_col.nullable != actual_col.nullable {
455                diff.nullability_mismatches.push(ColumnNullabilityDiff {
456                    table: (*name).to_string(),
457                    column: (*col_name).to_string(),
458                    expected_nullable: expected_col.nullable,
459                    actual_nullable: actual_col.nullable,
460                });
461            }
462
463            if expected_col.primary_key != actual_col.primary_key {
464                diff.primary_key_mismatches.push(ColumnPrimaryKeyDiff {
465                    table: (*name).to_string(),
466                    column: (*col_name).to_string(),
467                    expected_primary_key: expected_col.primary_key,
468                    actual_primary_key: actual_col.primary_key,
469                });
470            }
471        }
472
473        let expected_indexes = index_map(&expected_table.indexes);
474        let actual_indexes = index_map(&actual_table.indexes);
475        for key in expected_indexes.keys() {
476            if !actual_indexes.contains_key(key) {
477                if let Some(index) = expected_indexes.get(key) {
478                    diff.missing_indexes
479                        .push(((*name).to_string(), (*index).clone()));
480                }
481            }
482        }
483        for key in actual_indexes.keys() {
484            if !expected_indexes.contains_key(key) {
485                if let Some(index) = actual_indexes.get(key) {
486                    diff.extra_indexes
487                        .push(((*name).to_string(), (*index).clone()));
488                }
489            }
490        }
491
492        let expected_fks = foreign_key_map(&expected_table.foreign_keys);
493        let actual_fks = foreign_key_map(&actual_table.foreign_keys);
494        for key in expected_fks.keys() {
495            if !actual_fks.contains_key(key) {
496                if let Some(fk) = expected_fks.get(key) {
497                    diff.missing_foreign_keys
498                        .push(((*name).to_string(), (*fk).clone()));
499                }
500            }
501        }
502        for key in actual_fks.keys() {
503            if !expected_fks.contains_key(key) {
504                if let Some(fk) = actual_fks.get(key) {
505                    diff.extra_foreign_keys
506                        .push(((*name).to_string(), (*fk).clone()));
507                }
508            }
509        }
510    }
511
512    diff.missing_tables.sort();
513    diff.extra_tables.sort();
514
515    diff
516}
517
518/// Generates SQLite migration SQL based on the provided schema differences.
519pub fn sqlite_migration_sql(expected: &[SchemaTable], diff: &SchemaDiff) -> Vec<String> {
520    let expected_map: BTreeMap<String, &SchemaTable> =
521        expected.iter().map(|t| (t.name.clone(), t)).collect();
522    let mut statements = Vec::new();
523
524    for table in &diff.missing_tables {
525        if let Some(schema) = expected_map.get(table) {
526            statements.push(schema.to_create_sql());
527            for index in &schema.indexes {
528                statements.push(sqlite_create_index_sql(&schema.name, index));
529            }
530        } else {
531            statements.push(format!("-- Missing schema for table {}", table));
532        }
533    }
534
535    let mut missing_by_table: BTreeMap<String, BTreeSet<String>> = BTreeMap::new();
536    for col in &diff.missing_columns {
537        missing_by_table
538            .entry(col.table.clone())
539            .or_default()
540            .insert(col.column.clone());
541    }
542
543    for (table, columns) in missing_by_table {
544        let Some(schema) = expected_map.get(&table) else {
545            continue;
546        };
547        for col_name in columns {
548            let Some(col) = schema.column(&col_name) else {
549                continue;
550            };
551
552            // Heuristic: Check for potential rename from extra columns
553            // Simple check: Look for an extra column in the same table with the same type
554            // that isn't already "claimed" (though here we just suggest)
555            let potential_rename = diff.extra_columns.iter().find(|e| {
556                e.table == table && e.sql_type.as_deref() == Some(&col.normalized_type())
557            });
558
559            if let Some(old) = potential_rename {
560                statements.push(format!(
561                    "-- SUGESTION: Potential rename from column '{}' to '{}'?",
562                    old.column, col.name
563                ));
564            }
565
566            if col.primary_key {
567                statements.push(format!(
568                    "-- TODO: add primary key column {}.{} manually",
569                    table, col_name
570                ));
571                continue;
572            }
573
574            if !col.nullable {
575                statements.push(format!(
576                    "-- WARNING: Adding NOT NULL column '{}.{}' without a default value will fail if table contains rows.",
577                    table, col.name
578                ));
579            }
580
581            let mut stmt = format!(
582                "ALTER TABLE {} ADD COLUMN {} {}",
583                table, col.name, col.sql_type
584            );
585            if !col.nullable {
586                stmt.push_str(" NOT NULL");
587            }
588            statements.push(stmt);
589        }
590    }
591
592    for mismatch in &diff.type_mismatches {
593        statements.push(format!(
594            "-- TODO: column type mismatch {}.{} (expected {}, actual {})",
595            mismatch.table, mismatch.column, mismatch.expected, mismatch.actual
596        ));
597    }
598    for mismatch in &diff.nullability_mismatches {
599        statements.push(format!(
600            "-- TODO: column nullability mismatch {}.{} (expected nullable {}, actual nullable {})",
601            mismatch.table, mismatch.column, mismatch.expected_nullable, mismatch.actual_nullable
602        ));
603    }
604    for mismatch in &diff.primary_key_mismatches {
605        statements.push(format!(
606            "-- TODO: column primary key mismatch {}.{} (expected pk {}, actual pk {})",
607            mismatch.table,
608            mismatch.column,
609            mismatch.expected_primary_key,
610            mismatch.actual_primary_key
611        ));
612    }
613    for (table, index) in &diff.missing_indexes {
614        statements.push(sqlite_create_index_sql(table, index));
615    }
616    for (table, index) in &diff.extra_indexes {
617        statements.push(format!(
618            "-- TODO: extra index {}.{} ({})",
619            table,
620            index.name,
621            index.columns.join(", ")
622        ));
623    }
624    for (table, fk) in &diff.missing_foreign_keys {
625        statements.push(format!(
626            "-- TODO: add foreign key {}.{} -> {}({}) (requires table rebuild)",
627            table, fk.column, fk.ref_table, fk.ref_column
628        ));
629    }
630    for (table, fk) in &diff.extra_foreign_keys {
631        statements.push(format!(
632            "-- TODO: extra foreign key {}.{} -> {}({})",
633            table, fk.column, fk.ref_table, fk.ref_column
634        ));
635    }
636    for extra in &diff.extra_columns {
637        statements.push(format!(
638            "-- TODO: extra column {}.{} not in models",
639            extra.table, extra.column
640        ));
641    }
642    for table in &diff.extra_tables {
643        statements.push(format!("-- TODO: extra table {} not in models", table));
644    }
645
646    statements
647}
648
/// Reduces a SQL type name to a coarse canonical form used to compare column types.
///
/// The input is trimmed and lower-cased, then classified by substring, in this order:
///
/// * `interval` → `"interval"` (checked first so it is not mistaken for an `int` type);
/// * `int` / `serial` → `"integer"`;
/// * `char` / `text` / `clob` → `"text"`;
/// * `real` / `floa` / `doub` / `numeric` / `decimal` → `"real"`;
/// * `bool` → `"boolean"`;
/// * `time` / `date` / `uuid` / `json` → `"text"`.
///
/// Anything else is returned lower-cased; an empty or all-whitespace input yields `""`.
///
/// Array types keep their array-ness. Everything from the first `[` is replaced by `[]`, and
/// the element type is normalized on its own. So `int4[]`, `INTEGER[]` and `bigint[3]` all
/// become `"integer[]"`, which no longer compares equal to the scalar `"integer"`.
fn normalize_sql_type(sql_type: &str) -> String {
    let t = sql_type.trim().to_lowercase();
    if t.is_empty() {
        return t;
    }
    if let Some(bracket) = t.find('[') {
        return format!("{}[]", normalize_sql_type(&t[..bracket]));
    }
    if t.contains("interval") {
        return "interval".to_string();
    }
    if t.contains("int") || t.contains("serial") {
        return "integer".to_string();
    }
    if t.contains("char") || t.contains("text") || t.contains("clob") {
        return "text".to_string();
    }
    if t.contains("real")
        || t.contains("floa")
        || t.contains("doub")
        || t.contains("numeric")
        || t.contains("decimal")
    {
        return "real".to_string();
    }
    if t.contains("bool") {
        return "boolean".to_string();
    }
    if t.contains("time") || t.contains("date") || t.contains("uuid") || t.contains("json") {
        return "text".to_string();
    }
    t
}
676
/// Generates PostgreSQL migration SQL for a given schema difference.
///
/// `expected` supplies the model definitions used to render statements for the items in
/// `diff`. Output is in execution order:
///
/// 1. `CREATE TABLE` (plus `CREATE INDEX`) for each missing table;
/// 2. `ALTER TABLE ... ADD COLUMN` for missing columns, grouped by table;
/// 3. `CREATE INDEX` for missing indexes;
/// 4. `ALTER TABLE ... ADD CONSTRAINT fk_<table>_<column> FOREIGN KEY ...` for missing
///    foreign keys.
///
/// Type, nullability and primary-key mismatches, extra indexes/foreign keys/columns/tables,
/// and missing primary-key columns are emitted as `-- TODO:` SQL comments for manual review.
///
/// When a missing column has the same normalized type as an extra column of the same table,
/// a `-- SUGGESTION:` comment proposes a rename. Each extra column is suggested at most once.
#[cfg(feature = "postgres")]
pub fn postgres_migration_sql(expected: &[SchemaTable], diff: &SchemaDiff) -> Vec<String> {
    let expected_map: BTreeMap<String, &SchemaTable> =
        expected.iter().map(|t| (t.name.clone(), t)).collect();
    let mut statements = Vec::new();

    for table in &diff.missing_tables {
        if let Some(schema) = expected_map.get(table) {
            statements.push(schema.to_create_sql());
            for index in &schema.indexes {
                statements.push(postgres_create_index_sql(&schema.name, index));
            }
        } else {
            statements.push(format!("-- Missing schema for table {}", table));
        }
    }

    let mut missing_by_table: BTreeMap<String, BTreeSet<String>> = BTreeMap::new();
    for col in &diff.missing_columns {
        missing_by_table
            .entry(col.table.clone())
            .or_default()
            .insert(col.column.clone());
    }

    // Extra columns already offered as a rename source, indexed like `diff.extra_columns`.
    let mut claimed_extra = vec![false; diff.extra_columns.len()];

    for (table, columns) in missing_by_table {
        let Some(schema) = expected_map.get(&table) else {
            continue;
        };
        for col_name in columns {
            let Some(col) = schema.column(&col_name) else {
                continue;
            };

            // Heuristic: an unclaimed extra column in the same table with the same
            // normalized type may be this column under its old name (only suggested).
            let normalized = col.normalized_type();
            let rename_source = diff
                .extra_columns
                .iter()
                .enumerate()
                .find(|(i, extra)| {
                    !claimed_extra[*i]
                        && extra.table == table
                        && extra.sql_type.as_deref() == Some(normalized.as_str())
                });

            if let Some((i, old)) = rename_source {
                claimed_extra[i] = true;
                statements.push(format!(
                    "-- SUGGESTION: Potential rename from column '{}' to '{}'?",
                    old.column, col.name
                ));
            }

            if col.primary_key {
                statements.push(format!(
                    "-- TODO: add primary key column {}.{} manually",
                    table, col_name
                ));
                continue;
            }

            if !col.nullable {
                statements.push(format!(
                    "-- WARNING: Adding NOT NULL column '{}.{}' without a default value will fail if table contains rows.",
                    table, col.name
                ));
            }

            let mut stmt = format!(
                "ALTER TABLE {} ADD COLUMN {} {}",
                table, col.name, col.sql_type
            );
            if !col.nullable {
                stmt.push_str(" NOT NULL");
            }
            statements.push(stmt);
        }
    }

    for mismatch in &diff.type_mismatches {
        statements.push(format!(
            "-- TODO: column type mismatch {}.{} (expected {}, actual {})",
            mismatch.table, mismatch.column, mismatch.expected, mismatch.actual
        ));
    }
    for mismatch in &diff.nullability_mismatches {
        statements.push(format!(
            "-- TODO: column nullability mismatch {}.{} (expected nullable {}, actual nullable {})",
            mismatch.table, mismatch.column, mismatch.expected_nullable, mismatch.actual_nullable
        ));
    }
    for mismatch in &diff.primary_key_mismatches {
        statements.push(format!(
            "-- TODO: column primary key mismatch {}.{} (expected pk {}, actual pk {})",
            mismatch.table,
            mismatch.column,
            mismatch.expected_primary_key,
            mismatch.actual_primary_key
        ));
    }
    for (table, index) in &diff.missing_indexes {
        statements.push(postgres_create_index_sql(table, index));
    }
    for (table, index) in &diff.extra_indexes {
        statements.push(format!(
            "-- TODO: extra index {}.{} ({})",
            table,
            index.name,
            index.columns.join(", ")
        ));
    }
    for (table, fk) in &diff.missing_foreign_keys {
        let fk_name = format!("fk_{}_{}", table, fk.column);
        statements.push(format!(
            "ALTER TABLE {} ADD CONSTRAINT {} FOREIGN KEY ({}) REFERENCES {}({})",
            table, fk_name, fk.column, fk.ref_table, fk.ref_column
        ));
    }
    for (table, fk) in &diff.extra_foreign_keys {
        statements.push(format!(
            "-- TODO: extra foreign key {}.{} -> {}({})",
            table, fk.column, fk.ref_table, fk.ref_column
        ));
    }
    for extra in &diff.extra_columns {
        statements.push(format!(
            "-- TODO: extra column {}.{} not in models",
            extra.table, extra.column
        ));
    }
    for table in &diff.extra_tables {
        statements.push(format!("-- TODO: extra table {} not in models", table));
    }

    statements
}
807
808fn index_key(index: &SchemaIndex) -> (String, String, bool) {
809    let name = index.name.clone();
810    let cols = index.columns.join(",");
811    (name, cols, index.unique)
812}
813
814fn index_map(indexes: &[SchemaIndex]) -> BTreeMap<(String, String, bool), &SchemaIndex> {
815    indexes.iter().map(|i| (index_key(i), i)).collect()
816}
817
818fn foreign_key_key(fk: &SchemaForeignKey) -> (String, String, String) {
819    (
820        fk.column.clone(),
821        fk.ref_table.clone(),
822        fk.ref_column.clone(),
823    )
824}
825
826fn foreign_key_map(
827    fks: &[SchemaForeignKey],
828) -> BTreeMap<(String, String, String), &SchemaForeignKey> {
829    fks.iter().map(|f| (foreign_key_key(f), f)).collect()
830}
831
/// Lists the explicitly created indexes of `table` using `PRAGMA index_list` and
/// `PRAGMA index_info`.
///
/// The primary-key index (origin `pk`) and SQLite's automatic indexes (names starting with
/// `sqlite_autoindex`) are skipped. Index entries without a column name are recorded with a
/// placeholder: `rowid` for the rowid (cid -1) and `<expression>` for expression entries.
///
/// # Errors
///
/// Returns any [`sqlx::Error`] raised by the pragmas.
#[cfg(feature = "sqlite")]
async fn introspect_sqlite_indexes(
    pool: &SqlitePool,
    table: &str,
) -> Result<Vec<SchemaIndex>, sqlx::Error> {
    // PRAGMA arguments cannot be bound; quote identifiers so unusual names are accepted.
    let sql = format!("PRAGMA index_list(\"{}\")", table.replace('"', "\"\""));
    let rows: Vec<(i64, String, i64, String, i64)> = sqlx::query_as(&sql).fetch_all(pool).await?;

    let mut indexes = Vec::new();
    for (_seq, name, unique, origin, _partial) in rows {
        if origin == "pk" || name.starts_with("sqlite_autoindex") {
            continue;
        }
        let info_sql = format!("PRAGMA index_info(\"{}\")", name.replace('"', "\"\""));
        // The column name is NULL for rowid and expression entries, so decode it as optional;
        // decoding it as `String` would fail the whole introspection.
        let info_rows: Vec<(i64, i64, Option<String>)> =
            sqlx::query_as(&info_sql).fetch_all(pool).await?;
        let columns = info_rows
            .into_iter()
            .map(|(_seqno, cid, col)| {
                col.unwrap_or_else(|| {
                    if cid == -1 {
                        "rowid".to_string()
                    } else {
                        "<expression>".to_string()
                    }
                })
            })
            .collect();
        indexes.push(SchemaIndex {
            name,
            columns,
            unique: unique != 0,
        });
    }
    Ok(indexes)
}
856
/// Lists the foreign keys of `table` using `PRAGMA foreign_key_list`.
///
/// Each row becomes one [`SchemaForeignKey`]; the `ON UPDATE`/`ON DELETE`/`MATCH` details are
/// read but not recorded. SQLite reports the referenced column as NULL when the constraint
/// was written without a column list (`REFERENCES parent`). In that case the column is
/// resolved to the parent's primary-key column at the same position, or to an empty string
/// if the parent has no such column.
///
/// # Errors
///
/// Returns any [`sqlx::Error`] raised by the pragmas.
#[cfg(feature = "sqlite")]
async fn introspect_sqlite_foreign_keys(
    pool: &SqlitePool,
    table: &str,
) -> Result<Vec<SchemaForeignKey>, sqlx::Error> {
    // PRAGMA arguments cannot be bound; quote the identifier so unusual names are accepted.
    let sql = format!("PRAGMA foreign_key_list(\"{}\")", table.replace('"', "\"\""));
    // The tuple mirrors the eight columns returned by the pragma.
    #[allow(clippy::type_complexity)]
    let rows: Vec<(i64, i64, String, String, Option<String>, String, String, String)> =
        sqlx::query_as(&sql).fetch_all(pool).await?;

    let mut fks = Vec::with_capacity(rows.len());
    for (_id, seq, ref_table, from, to, _on_update, _on_delete, _match) in rows {
        let ref_column = match to {
            Some(column) => column,
            None => {
                // Implicit reference to the parent's primary key: `seq` is the position within
                // this foreign key, and `table_info.pk` is the 1-based position in the parent key.
                let parent_sql =
                    format!("PRAGMA table_info(\"{}\")", ref_table.replace('"', "\"\""));
                let parent_cols: Vec<(i64, String, String, i64, Option<String>, i64)> =
                    sqlx::query_as(&parent_sql).fetch_all(pool).await?;
                parent_cols
                    .into_iter()
                    .find(|col| col.5 == seq + 1)
                    .map(|col| col.1)
                    .unwrap_or_default()
            }
        };
        fks.push(SchemaForeignKey {
            column: from,
            ref_table,
            ref_column,
        });
    }
    Ok(fks)
}
877
878fn sqlite_create_index_sql(table: &str, index: &SchemaIndex) -> String {
879    let unique = if index.unique { "UNIQUE " } else { "" };
880    let name = if index.name.is_empty() {
881        format!("idx_{}_{}", table, index.columns.join("_"))
882    } else {
883        index.name.clone()
884    };
885    format!(
886        "CREATE {}INDEX IF NOT EXISTS {} ON {} ({})",
887        unique,
888        name,
889        table,
890        index.columns.join(", ")
891    )
892}
893
894#[cfg(feature = "postgres")]
895fn postgres_create_index_sql(table: &str, index: &SchemaIndex) -> String {
896    let unique = if index.unique { "UNIQUE " } else { "" };
897    let name = if index.name.is_empty() {
898        format!("idx_{}_{}", table, index.columns.join("_"))
899    } else {
900        index.name.clone()
901    };
902    format!(
903        "CREATE {}INDEX IF NOT EXISTS {} ON {} ({})",
904        unique,
905        name,
906        table,
907        index.columns.join(", ")
908    )
909}
910
911#[cfg(feature = "postgres")]
912async fn introspect_postgres_indexes(
913    pool: &PgPool,
914    table: &str,
915) -> Result<Vec<SchemaIndex>, sqlx::Error> {
916    let rows: Vec<(String, bool, Vec<String>)> = sqlx::query_as(
917        "SELECT i.relname AS index_name, ix.indisunique, array_agg(a.attname ORDER BY x.n) AS columns
918         FROM pg_class t
919         JOIN pg_index ix ON t.oid = ix.indrelid
920         JOIN pg_class i ON i.oid = ix.indexrelid
921         JOIN LATERAL unnest(ix.indkey) WITH ORDINALITY AS x(attnum, n) ON true
922         JOIN pg_attribute a ON a.attrelid = t.oid AND a.attnum = x.attnum
923         WHERE t.relname = $1 AND t.relkind = 'r' AND NOT ix.indisprimary
924         GROUP BY i.relname, ix.indisunique
925         ORDER BY i.relname",
926    )
927    .bind(table)
928    .fetch_all(pool)
929    .await?;
930
931    let indexes = rows
932        .into_iter()
933        .map(|(name, unique, columns)| SchemaIndex {
934            name,
935            columns,
936            unique,
937        })
938        .collect();
939    Ok(indexes)
940}
941
942#[cfg(feature = "postgres")]
943async fn introspect_postgres_foreign_keys(
944    pool: &PgPool,
945    table: &str,
946) -> Result<Vec<SchemaForeignKey>, sqlx::Error> {
947    let rows: Vec<(String, String, String)> = sqlx::query_as(
948        "SELECT kcu.column_name, ccu.table_name, ccu.column_name
949         FROM information_schema.table_constraints tc
950         JOIN information_schema.key_column_usage kcu
951           ON tc.constraint_name = kcu.constraint_name AND tc.table_schema = kcu.table_schema
952         JOIN information_schema.constraint_column_usage ccu
953           ON ccu.constraint_name = tc.constraint_name AND ccu.table_schema = tc.table_schema
954         WHERE tc.constraint_type = 'FOREIGN KEY'
955           AND tc.table_schema = 'public'
956           AND tc.table_name = $1
957         ORDER BY kcu.ordinal_position",
958    )
959    .bind(table)
960    .fetch_all(pool)
961    .await?;
962
963    let fks = rows
964        .into_iter()
965        .map(|(column, ref_table, ref_column)| SchemaForeignKey {
966            column,
967            ref_table,
968            ref_column,
969        })
970        .collect();
971
972    Ok(fks)
973}
974
#[cfg(test)]
mod tests {
    use super::*;

    // Introspecting a table that matches the expected model exactly yields an empty diff.
    #[cfg(feature = "sqlite")]
    #[tokio::test]
    async fn sqlite_introspect_and_diff_empty() {
        let pool = SqlitePool::connect("sqlite::memory:").await.unwrap();
        sqlx::query(
            "CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT NOT NULL, deleted_at TEXT)",
        )
        .execute(&pool)
        .await
        .unwrap();

        let expected = vec![SchemaTable {
            name: "users".to_string(),
            columns: vec![
                SchemaColumn {
                    name: "id".to_string(),
                    sql_type: "INTEGER".to_string(),
                    nullable: false,
                    primary_key: true,
                },
                SchemaColumn {
                    name: "name".to_string(),
                    sql_type: "TEXT".to_string(),
                    nullable: false,
                    primary_key: false,
                },
                SchemaColumn {
                    name: "deleted_at".to_string(),
                    sql_type: "TEXT".to_string(),
                    nullable: true,
                    primary_key: false,
                },
            ],
            indexes: Vec::new(),
            foreign_keys: Vec::new(),
            create_sql: None,
        }];

        let actual = introspect_sqlite_schema(&pool).await.unwrap();
        let diff = diff_schema(&expected, &actual);
        assert!(diff.is_empty());
    }

    // A column present in the model but absent from the database is reported, summarized,
    // and turned into an `ALTER TABLE ... ADD COLUMN` statement.
    #[cfg(feature = "sqlite")]
    #[tokio::test]
    async fn sqlite_diff_reports_missing_column() {
        let pool = SqlitePool::connect("sqlite::memory:").await.unwrap();
        sqlx::query("CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT)")
            .execute(&pool)
            .await
            .unwrap();

        let expected = vec![SchemaTable {
            name: "users".to_string(),
            columns: vec![
                SchemaColumn {
                    name: "id".to_string(),
                    sql_type: "INTEGER".to_string(),
                    nullable: false,
                    primary_key: true,
                },
                SchemaColumn {
                    name: "name".to_string(),
                    sql_type: "TEXT".to_string(),
                    nullable: false,
                    primary_key: false,
                },
                SchemaColumn {
                    name: "status".to_string(),
                    sql_type: "TEXT".to_string(),
                    nullable: true,
                    primary_key: false,
                },
            ],
            indexes: Vec::new(),
            foreign_keys: Vec::new(),
            create_sql: None,
        }];

        let actual = introspect_sqlite_schema(&pool).await.unwrap();
        let diff = diff_schema(&expected, &actual);
        assert_eq!(diff.missing_columns.len(), 1);

        let summary = format_schema_diff_summary(&diff);
        assert!(summary.contains("missing columns: 1"));

        let sql = sqlite_migration_sql(&expected, &diff);
        assert!(
            sql.iter()
                .any(|stmt| stmt.contains("ALTER TABLE users ADD COLUMN status"))
        );
    }

    // An index declared in the model but absent from the database is reported and turned
    // into a `CREATE INDEX IF NOT EXISTS` statement.
    #[cfg(feature = "sqlite")]
    #[tokio::test]
    async fn sqlite_diff_reports_missing_index() {
        let pool = SqlitePool::connect("sqlite::memory:").await.unwrap();
        sqlx::query("CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT)")
            .execute(&pool)
            .await
            .unwrap();

        let expected = vec![SchemaTable {
            name: "users".to_string(),
            columns: vec![
                SchemaColumn {
                    name: "id".to_string(),
                    sql_type: "INTEGER".to_string(),
                    nullable: false,
                    primary_key: true,
                },
                SchemaColumn {
                    name: "name".to_string(),
                    sql_type: "TEXT".to_string(),
                    nullable: false,
                    primary_key: false,
                },
            ],
            indexes: vec![SchemaIndex {
                name: "idx_users_name".to_string(),
                columns: vec!["name".to_string()],
                unique: false,
            }],
            foreign_keys: Vec::new(),
            create_sql: None,
        }];

        let actual = introspect_sqlite_schema(&pool).await.unwrap();
        let diff = diff_schema(&expected, &actual);
        assert_eq!(diff.missing_indexes.len(), 1);

        let sql = sqlite_migration_sql(&expected, &diff);
        assert!(
            sql.iter()
                .any(|stmt| stmt.contains("CREATE INDEX IF NOT EXISTS idx_users_name"))
        );
    }

    /// Connection string for the Postgres integration test: `DATABASE_URL` if set,
    /// otherwise a local development default.
    // NOTE(review): the fallback embeds a plaintext development password; prefer setting
    // DATABASE_URL in CI rather than relying on this default.
    #[cfg(feature = "postgres")]
    fn pg_url() -> String {
        std::env::var("DATABASE_URL").unwrap_or_else(|_| {
            "postgres://postgres:admin123@localhost:5432/premix_bench".to_string()
        })
    }

    // Round-trips a two-table schema (with an index and a foreign key) through Postgres
    // introspection and expects no diff. Skipped, with a message on stderr, when no
    // database is reachable.
    #[cfg(feature = "postgres")]
    #[tokio::test]
    async fn postgres_introspect_and_diff() {
        let url = pg_url();
        let pool = match PgPool::connect(&url).await {
            Ok(pool) => pool,
            Err(err) => {
                // Skip rather than fail on machines without Postgres, but say so: a silent
                // early return makes a run that checked nothing look like a pass. The URL
                // is not printed because it may contain credentials.
                eprintln!("skipping postgres_introspect_and_diff: cannot connect: {err}");
                return;
            }
        };

        // Drop in dependency order (posts references users) so reruns start clean.
        sqlx::query("DROP TABLE IF EXISTS schema_posts")
            .execute(&pool)
            .await
            .unwrap();
        sqlx::query("DROP TABLE IF EXISTS schema_users")
            .execute(&pool)
            .await
            .unwrap();

        sqlx::query("CREATE TABLE schema_users (id SERIAL PRIMARY KEY, name TEXT NOT NULL)")
            .execute(&pool)
            .await
            .unwrap();
        sqlx::query(
            "CREATE TABLE schema_posts (id SERIAL PRIMARY KEY, user_id INTEGER NOT NULL, title TEXT NOT NULL, CONSTRAINT fk_schema_posts_user_id FOREIGN KEY (user_id) REFERENCES schema_users(id))",
        )
        .execute(&pool)
        .await
        .unwrap();
        sqlx::query("CREATE INDEX idx_schema_posts_user_id ON schema_posts (user_id)")
            .execute(&pool)
            .await
            .unwrap();

        let expected = vec![
            SchemaTable {
                name: "schema_posts".to_string(),
                columns: vec![
                    SchemaColumn {
                        name: "id".to_string(),
                        sql_type: "INTEGER".to_string(),
                        nullable: false,
                        primary_key: true,
                    },
                    SchemaColumn {
                        name: "user_id".to_string(),
                        sql_type: "INTEGER".to_string(),
                        nullable: false,
                        primary_key: false,
                    },
                    SchemaColumn {
                        name: "title".to_string(),
                        sql_type: "TEXT".to_string(),
                        nullable: false,
                        primary_key: false,
                    },
                ],
                indexes: vec![SchemaIndex {
                    name: "idx_schema_posts_user_id".to_string(),
                    columns: vec!["user_id".to_string()],
                    unique: false,
                }],
                foreign_keys: vec![SchemaForeignKey {
                    column: "user_id".to_string(),
                    ref_table: "schema_users".to_string(),
                    ref_column: "id".to_string(),
                }],
                create_sql: None,
            },
            SchemaTable {
                name: "schema_users".to_string(),
                columns: vec![
                    SchemaColumn {
                        name: "id".to_string(),
                        sql_type: "INTEGER".to_string(),
                        nullable: false,
                        primary_key: true,
                    },
                    SchemaColumn {
                        name: "name".to_string(),
                        sql_type: "TEXT".to_string(),
                        nullable: false,
                        primary_key: false,
                    },
                ],
                indexes: Vec::new(),
                foreign_keys: Vec::new(),
                create_sql: None,
            },
        ];

        // The shared database may contain unrelated tables; compare only the ones this
        // test created.
        let actual = introspect_postgres_schema(&pool).await.unwrap();
        let expected_names: BTreeSet<String> =
            expected.iter().map(|table| table.name.clone()).collect();
        let actual = actual
            .into_iter()
            .filter(|table| expected_names.contains(&table.name))
            .collect::<Vec<_>>();
        let diff = diff_schema(&expected, &actual);
        assert!(diff.is_empty(), "postgres schema diff: {diff:?}");

        sqlx::query("DROP TABLE IF EXISTS schema_posts")
            .execute(&pool)
            .await
            .unwrap();
        sqlx::query("DROP TABLE IF EXISTS schema_users")
            .execute(&pool)
            .await
            .unwrap();
    }
}