Skip to main content

sqlmodel_schema/
introspect.rs

1//! Database introspection.
2//!
3//! This module provides comprehensive schema introspection for SQLite, PostgreSQL, and MySQL.
4//! It extracts metadata about tables, columns, constraints, and indexes.
5
6use asupersync::{Cx, Outcome};
7use sqlmodel_core::{Connection, Error, sanitize_identifier};
8use std::collections::HashMap;
9
10// ============================================================================
11// Schema Types
12// ============================================================================
13
14/// Complete representation of a database schema.
15#[derive(Debug, Clone, Default)]
16pub struct DatabaseSchema {
17    /// All tables in the schema, keyed by table name
18    pub tables: HashMap<String, TableInfo>,
19    /// Database dialect
20    pub dialect: Dialect,
21}
22
23impl DatabaseSchema {
24    /// Create a new empty schema for the given dialect.
25    pub fn new(dialect: Dialect) -> Self {
26        Self {
27            tables: HashMap::new(),
28            dialect,
29        }
30    }
31
32    /// Get a table by name.
33    pub fn table(&self, name: &str) -> Option<&TableInfo> {
34        self.tables.get(name)
35    }
36
37    /// Get all table names.
38    pub fn table_names(&self) -> Vec<&str> {
39        self.tables.keys().map(|s| s.as_str()).collect()
40    }
41}
42
/// Parsed SQL type with extracted metadata.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ParsedSqlType {
    /// Base type name (e.g., VARCHAR, INTEGER, DECIMAL)
    pub base_type: String,
    /// Length for character types (e.g., VARCHAR(255) -> 255)
    pub length: Option<u32>,
    /// Precision for numeric types (e.g., DECIMAL(10,2) -> 10)
    pub precision: Option<u32>,
    /// Scale for numeric types (e.g., DECIMAL(10,2) -> 2)
    pub scale: Option<u32>,
    /// Whether the type is unsigned (MySQL)
    pub unsigned: bool,
    /// Whether this is an array type (PostgreSQL)
    pub array: bool,
}

impl ParsedSqlType {
    /// Parse a SQL type string into structured metadata.
    ///
    /// The input is trimmed and upper-cased first. Parsing never panics:
    /// malformed or non-numeric parameter lists (e.g. a missing `)` or an
    /// `ENUM('a','b')` value list) leave the numeric fields as `None`.
    /// Text after the closing `)` (such as `WITH TIME ZONE`) is not treated
    /// as part of the parameter list.
    ///
    /// # Examples
    /// - `VARCHAR(255)` -> base_type: "VARCHAR", length: 255
    /// - `DECIMAL(10,2)` -> base_type: "DECIMAL", precision: 10, scale: 2
    /// - `INT UNSIGNED` -> base_type: "INT", unsigned: true
    /// - `int(10) unsigned zerofill` -> base_type: "INT", length: 10, unsigned: true
    /// - `TEXT[]` -> base_type: "TEXT", array: true
    pub fn parse(type_str: &str) -> Self {
        let upper = type_str.trim().to_uppercase();

        // PostgreSQL array suffix (`TEXT[]`, also multi-dimensional `INT[][]`).
        let array = upper.ends_with("[]");
        let mut rest = upper.trim_end_matches("[]").trim_end();

        // MySQL trailing attributes, e.g. `INT(10) UNSIGNED ZEROFILL`. They can
        // appear in either order, so strip them repeatedly. MySQL implicitly
        // makes ZEROFILL columns UNSIGNED, so either keyword sets the flag.
        let mut unsigned = false;
        loop {
            if let Some(stripped) = rest.strip_suffix(" UNSIGNED") {
                rest = stripped.trim_end();
                unsigned = true;
            } else if let Some(stripped) = rest.strip_suffix(" ZEROFILL") {
                rest = stripped.trim_end();
                unsigned = true;
            } else {
                break;
            }
        }

        let Some(open) = rest.find('(') else {
            return Self {
                base_type: rest.to_string(),
                unsigned,
                array,
                ..Self::default()
            };
        };

        let base_type = rest[..open].trim().to_string();
        // Parameters end at the last `)`; a missing `)` is tolerated instead of
        // producing an out-of-range slice.
        let after_open = &rest[open + 1..];
        let params = match after_open.rfind(')') {
            Some(close) => &after_open[..close],
            None => after_open,
        };

        let mut parsed = Self {
            base_type,
            unsigned,
            array,
            ..Self::default()
        };
        if params.contains(',') {
            let mut parts = params.split(',');
            parsed.precision = parts.next().and_then(|s| s.trim().parse().ok());
            parsed.scale = parts.next().and_then(|s| s.trim().parse().ok());
        } else {
            parsed.length = params.trim().parse().ok();
        }
        parsed
    }

    /// Check if this is a text/string type.
    pub fn is_text(&self) -> bool {
        matches!(
            self.base_type.as_str(),
            "VARCHAR" | "CHAR" | "TEXT" | "CLOB" | "NVARCHAR" | "NCHAR" | "NTEXT"
        )
    }

    /// Check if this is a numeric type.
    pub fn is_numeric(&self) -> bool {
        matches!(
            self.base_type.as_str(),
            "INT"
                | "INTEGER"
                | "BIGINT"
                | "SMALLINT"
                | "TINYINT"
                | "MEDIUMINT"
                | "DECIMAL"
                | "NUMERIC"
                | "FLOAT"
                | "DOUBLE"
                | "REAL"
                | "DOUBLE PRECISION"
        )
    }

    /// Check if this is a date/time type.
    pub fn is_datetime(&self) -> bool {
        matches!(
            self.base_type.as_str(),
            "DATE" | "TIME" | "DATETIME" | "TIMESTAMP" | "TIMESTAMPTZ" | "TIMETZ"
        )
    }
}
161
/// A UNIQUE constraint covering one or more columns.
#[derive(Clone, Debug)]
pub struct UniqueConstraintInfo {
    /// Constraint name, when one is known.
    pub name: Option<String>,
    /// Names of the constrained columns.
    pub columns: Vec<String>,
}
170
/// A CHECK constraint and its boolean expression.
#[derive(Clone, Debug)]
pub struct CheckConstraintInfo {
    /// Constraint name, when one is known.
    pub name: Option<String>,
    /// SQL text of the checked expression.
    pub expression: String,
}
179
180/// Information about a database table.
181#[derive(Debug, Clone)]
182pub struct TableInfo {
183    /// Table name
184    pub name: String,
185    /// Columns in the table
186    pub columns: Vec<ColumnInfo>,
187    /// Primary key column names
188    pub primary_key: Vec<String>,
189    /// Foreign key constraints
190    pub foreign_keys: Vec<ForeignKeyInfo>,
191    /// Unique constraints
192    pub unique_constraints: Vec<UniqueConstraintInfo>,
193    /// Check constraints
194    pub check_constraints: Vec<CheckConstraintInfo>,
195    /// Indexes on the table
196    pub indexes: Vec<IndexInfo>,
197    /// Table comment (if any)
198    pub comment: Option<String>,
199}
200
201impl TableInfo {
202    /// Get a column by name.
203    pub fn column(&self, name: &str) -> Option<&ColumnInfo> {
204        self.columns.iter().find(|c| c.name == name)
205    }
206
207    /// Check if this table has a single-column auto-increment primary key.
208    pub fn has_auto_pk(&self) -> bool {
209        self.primary_key.len() == 1
210            && self
211                .column(&self.primary_key[0])
212                .is_some_and(|c| c.auto_increment)
213    }
214}
215
216/// Information about a table column.
217#[derive(Debug, Clone)]
218pub struct ColumnInfo {
219    /// Column name
220    pub name: String,
221    /// SQL type as raw string
222    pub sql_type: String,
223    /// Parsed SQL type with extracted metadata
224    pub parsed_type: ParsedSqlType,
225    /// Whether the column is nullable
226    pub nullable: bool,
227    /// Default value expression
228    pub default: Option<String>,
229    /// Whether this is part of the primary key
230    pub primary_key: bool,
231    /// Whether this column auto-increments
232    pub auto_increment: bool,
233    /// Column comment (if any)
234    pub comment: Option<String>,
235}
236
/// One column of a foreign-key relationship.
#[derive(Clone, Debug)]
pub struct ForeignKeyInfo {
    /// Constraint name, when one is known.
    pub name: Option<String>,
    /// Referencing (local) column.
    pub column: String,
    /// Table being referenced.
    pub foreign_table: String,
    /// Column being referenced in `foreign_table`.
    pub foreign_column: String,
    /// ON DELETE action, if any.
    pub on_delete: Option<String>,
    /// ON UPDATE action, if any.
    pub on_update: Option<String>,
}
253
/// Metadata for one index on a table.
#[derive(Clone, Debug)]
pub struct IndexInfo {
    /// Index name.
    pub name: String,
    /// Indexed column names.
    pub columns: Vec<String>,
    /// `true` for a UNIQUE index.
    pub unique: bool,
    /// Access method (BTREE, HASH, GIN, GIST, ...), when the backend reports it.
    pub index_type: Option<String>,
    /// `true` if the index backs the primary key.
    pub primary: bool,
}
268
269/// Database introspector.
270pub struct Introspector {
271    /// Database type for dialect-specific queries
272    dialect: Dialect,
273}
274
/// Database backends the introspector can query.
///
/// Defaults to [`Dialect::Sqlite`].
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum Dialect {
    /// SQLite.
    #[default]
    Sqlite,
    /// PostgreSQL.
    Postgres,
    /// MySQL or MariaDB.
    Mysql,
}
286
287impl Introspector {
288    /// Create a new introspector for the given dialect.
289    pub fn new(dialect: Dialect) -> Self {
290        Self { dialect }
291    }
292
293    /// List all table names in the database.
294    pub async fn table_names<C: Connection>(
295        &self,
296        cx: &Cx,
297        conn: &C,
298    ) -> Outcome<Vec<String>, Error> {
299        let sql = match self.dialect {
300            Dialect::Sqlite => {
301                "SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%'"
302            }
303            Dialect::Postgres => {
304                "SELECT table_name FROM information_schema.tables WHERE table_schema = 'public'"
305            }
306            Dialect::Mysql => "SHOW TABLES",
307        };
308
309        let rows = match conn.query(cx, sql, &[]).await {
310            Outcome::Ok(rows) => rows,
311            Outcome::Err(e) => return Outcome::Err(e),
312            Outcome::Cancelled(r) => return Outcome::Cancelled(r),
313            Outcome::Panicked(p) => return Outcome::Panicked(p),
314        };
315
316        let names: Vec<String> = rows
317            .iter()
318            .filter_map(|row| row.get(0).and_then(|v| v.as_str().map(String::from)))
319            .collect();
320
321        Outcome::Ok(names)
322    }
323
324    /// Get detailed information about a table.
325    pub async fn table_info<C: Connection>(
326        &self,
327        cx: &Cx,
328        conn: &C,
329        table_name: &str,
330    ) -> Outcome<TableInfo, Error> {
331        let columns = match self.columns(cx, conn, table_name).await {
332            Outcome::Ok(cols) => cols,
333            Outcome::Err(e) => return Outcome::Err(e),
334            Outcome::Cancelled(r) => return Outcome::Cancelled(r),
335            Outcome::Panicked(p) => return Outcome::Panicked(p),
336        };
337
338        let primary_key: Vec<String> = columns
339            .iter()
340            .filter(|c| c.primary_key)
341            .map(|c| c.name.clone())
342            .collect();
343
344        let foreign_keys = match self.foreign_keys(cx, conn, table_name).await {
345            Outcome::Ok(fks) => fks,
346            Outcome::Err(e) => return Outcome::Err(e),
347            Outcome::Cancelled(r) => return Outcome::Cancelled(r),
348            Outcome::Panicked(p) => return Outcome::Panicked(p),
349        };
350
351        let indexes = match self.indexes(cx, conn, table_name).await {
352            Outcome::Ok(idxs) => idxs,
353            Outcome::Err(e) => return Outcome::Err(e),
354            Outcome::Cancelled(r) => return Outcome::Cancelled(r),
355            Outcome::Panicked(p) => return Outcome::Panicked(p),
356        };
357
358        Outcome::Ok(TableInfo {
359            name: table_name.to_string(),
360            columns,
361            primary_key,
362            foreign_keys,
363            unique_constraints: Vec::new(), // Extracted from indexes with unique=true
364            check_constraints: Vec::new(),  // Requires additional queries per dialect
365            indexes,
366            comment: None, // Requires additional queries per dialect
367        })
368    }
369
370    /// Introspect the entire database schema.
371    pub async fn introspect_all<C: Connection>(
372        &self,
373        cx: &Cx,
374        conn: &C,
375    ) -> Outcome<DatabaseSchema, Error> {
376        let table_names = match self.table_names(cx, conn).await {
377            Outcome::Ok(names) => names,
378            Outcome::Err(e) => return Outcome::Err(e),
379            Outcome::Cancelled(r) => return Outcome::Cancelled(r),
380            Outcome::Panicked(p) => return Outcome::Panicked(p),
381        };
382
383        let mut schema = DatabaseSchema::new(self.dialect);
384
385        for name in table_names {
386            let info = match self.table_info(cx, conn, &name).await {
387                Outcome::Ok(info) => info,
388                Outcome::Err(e) => return Outcome::Err(e),
389                Outcome::Cancelled(r) => return Outcome::Cancelled(r),
390                Outcome::Panicked(p) => return Outcome::Panicked(p),
391            };
392            schema.tables.insert(name, info);
393        }
394
395        Outcome::Ok(schema)
396    }
397
398    /// Get column information for a table.
399    async fn columns<C: Connection>(
400        &self,
401        cx: &Cx,
402        conn: &C,
403        table_name: &str,
404    ) -> Outcome<Vec<ColumnInfo>, Error> {
405        match self.dialect {
406            Dialect::Sqlite => self.sqlite_columns(cx, conn, table_name).await,
407            Dialect::Postgres => self.postgres_columns(cx, conn, table_name).await,
408            Dialect::Mysql => self.mysql_columns(cx, conn, table_name).await,
409        }
410    }
411
412    async fn sqlite_columns<C: Connection>(
413        &self,
414        cx: &Cx,
415        conn: &C,
416        table_name: &str,
417    ) -> Outcome<Vec<ColumnInfo>, Error> {
418        let sql = format!("PRAGMA table_info({})", sanitize_identifier(table_name));
419        let rows = match conn.query(cx, &sql, &[]).await {
420            Outcome::Ok(rows) => rows,
421            Outcome::Err(e) => return Outcome::Err(e),
422            Outcome::Cancelled(r) => return Outcome::Cancelled(r),
423            Outcome::Panicked(p) => return Outcome::Panicked(p),
424        };
425
426        let columns: Vec<ColumnInfo> = rows
427            .iter()
428            .filter_map(|row| {
429                let name = row.get_named::<String>("name").ok()?;
430                let sql_type = row.get_named::<String>("type").ok()?;
431                let notnull = row.get_named::<i64>("notnull").ok().unwrap_or(0);
432                let dflt_value = row.get_named::<String>("dflt_value").ok();
433                let pk = row.get_named::<i64>("pk").ok().unwrap_or(0);
434                let parsed_type = ParsedSqlType::parse(&sql_type);
435
436                Some(ColumnInfo {
437                    name,
438                    sql_type,
439                    parsed_type,
440                    nullable: notnull == 0,
441                    default: dflt_value,
442                    primary_key: pk > 0,
443                    auto_increment: false, // SQLite doesn't report this via PRAGMA
444                    comment: None,         // SQLite doesn't support column comments
445                })
446            })
447            .collect();
448
449        Outcome::Ok(columns)
450    }
451
452    async fn postgres_columns<C: Connection>(
453        &self,
454        cx: &Cx,
455        conn: &C,
456        table_name: &str,
457    ) -> Outcome<Vec<ColumnInfo>, Error> {
458        // Use a more comprehensive query to get full type info
459        let sql = "SELECT
460                       c.column_name,
461                       c.data_type,
462                       c.udt_name,
463                       c.character_maximum_length,
464                       c.numeric_precision,
465                       c.numeric_scale,
466                       c.is_nullable,
467                       c.column_default,
468                       COALESCE(d.description, '') as column_comment
469                   FROM information_schema.columns c
470                   LEFT JOIN pg_catalog.pg_statio_all_tables st
471                       ON c.table_schema = st.schemaname AND c.table_name = st.relname
472                   LEFT JOIN pg_catalog.pg_description d
473                       ON d.objoid = st.relid AND d.objsubid = c.ordinal_position
474                   WHERE c.table_name = $1 AND c.table_schema = 'public'
475                   ORDER BY c.ordinal_position";
476
477        let rows = match conn
478            .query(
479                cx,
480                sql,
481                &[sqlmodel_core::Value::Text(table_name.to_string())],
482            )
483            .await
484        {
485            Outcome::Ok(rows) => rows,
486            Outcome::Err(e) => return Outcome::Err(e),
487            Outcome::Cancelled(r) => return Outcome::Cancelled(r),
488            Outcome::Panicked(p) => return Outcome::Panicked(p),
489        };
490
491        let columns: Vec<ColumnInfo> = rows
492            .iter()
493            .filter_map(|row| {
494                let name = row.get_named::<String>("column_name").ok()?;
495                let data_type = row.get_named::<String>("data_type").ok()?;
496                let udt_name = row.get_named::<String>("udt_name").ok().unwrap_or_default();
497                let char_len = row.get_named::<i64>("character_maximum_length").ok();
498                let precision = row.get_named::<i64>("numeric_precision").ok();
499                let scale = row.get_named::<i64>("numeric_scale").ok();
500                let nullable_str = row.get_named::<String>("is_nullable").ok()?;
501                let default = row.get_named::<String>("column_default").ok();
502                let comment = row.get_named::<String>("column_comment").ok();
503
504                // Build a complete SQL type string
505                let sql_type =
506                    build_postgres_type(&data_type, &udt_name, char_len, precision, scale);
507                let parsed_type = ParsedSqlType::parse(&sql_type);
508
509                // Check if auto-increment by looking at default (nextval)
510                let auto_increment = default.as_ref().is_some_and(|d| d.starts_with("nextval("));
511
512                Some(ColumnInfo {
513                    name,
514                    sql_type,
515                    parsed_type,
516                    nullable: nullable_str == "YES",
517                    default,
518                    primary_key: false, // Determined via separate index query
519                    auto_increment,
520                    comment: comment.filter(|s| !s.is_empty()),
521                })
522            })
523            .collect();
524
525        Outcome::Ok(columns)
526    }
527
528    async fn mysql_columns<C: Connection>(
529        &self,
530        cx: &Cx,
531        conn: &C,
532        table_name: &str,
533    ) -> Outcome<Vec<ColumnInfo>, Error> {
534        // Use SHOW FULL COLUMNS to get comments
535        let sql = format!(
536            "SHOW FULL COLUMNS FROM `{}`",
537            sanitize_identifier(table_name)
538        );
539        let rows = match conn.query(cx, &sql, &[]).await {
540            Outcome::Ok(rows) => rows,
541            Outcome::Err(e) => return Outcome::Err(e),
542            Outcome::Cancelled(r) => return Outcome::Cancelled(r),
543            Outcome::Panicked(p) => return Outcome::Panicked(p),
544        };
545
546        let columns: Vec<ColumnInfo> = rows
547            .iter()
548            .filter_map(|row| {
549                let name = row.get_named::<String>("Field").ok()?;
550                let sql_type = row.get_named::<String>("Type").ok()?;
551                let null = row.get_named::<String>("Null").ok()?;
552                let key = row.get_named::<String>("Key").ok()?;
553                let default = row.get_named::<String>("Default").ok();
554                let extra = row.get_named::<String>("Extra").ok().unwrap_or_default();
555                let comment = row.get_named::<String>("Comment").ok();
556                let parsed_type = ParsedSqlType::parse(&sql_type);
557
558                Some(ColumnInfo {
559                    name,
560                    sql_type,
561                    parsed_type,
562                    nullable: null == "YES",
563                    default,
564                    primary_key: key == "PRI",
565                    auto_increment: extra.contains("auto_increment"),
566                    comment: comment.filter(|s| !s.is_empty()),
567                })
568            })
569            .collect();
570
571        Outcome::Ok(columns)
572    }
573
574    // ========================================================================
575    // Foreign Key Introspection
576    // ========================================================================
577
578    /// Get foreign key constraints for a table.
579    async fn foreign_keys<C: Connection>(
580        &self,
581        cx: &Cx,
582        conn: &C,
583        table_name: &str,
584    ) -> Outcome<Vec<ForeignKeyInfo>, Error> {
585        match self.dialect {
586            Dialect::Sqlite => self.sqlite_foreign_keys(cx, conn, table_name).await,
587            Dialect::Postgres => self.postgres_foreign_keys(cx, conn, table_name).await,
588            Dialect::Mysql => self.mysql_foreign_keys(cx, conn, table_name).await,
589        }
590    }
591
592    async fn sqlite_foreign_keys<C: Connection>(
593        &self,
594        cx: &Cx,
595        conn: &C,
596        table_name: &str,
597    ) -> Outcome<Vec<ForeignKeyInfo>, Error> {
598        let sql = format!(
599            "PRAGMA foreign_key_list({})",
600            sanitize_identifier(table_name)
601        );
602        let rows = match conn.query(cx, &sql, &[]).await {
603            Outcome::Ok(rows) => rows,
604            Outcome::Err(e) => return Outcome::Err(e),
605            Outcome::Cancelled(r) => return Outcome::Cancelled(r),
606            Outcome::Panicked(p) => return Outcome::Panicked(p),
607        };
608
609        let fks: Vec<ForeignKeyInfo> = rows
610            .iter()
611            .filter_map(|row| {
612                let table = row.get_named::<String>("table").ok()?;
613                let from = row.get_named::<String>("from").ok()?;
614                let to = row.get_named::<String>("to").ok()?;
615                let on_update = row.get_named::<String>("on_update").ok();
616                let on_delete = row.get_named::<String>("on_delete").ok();
617
618                Some(ForeignKeyInfo {
619                    name: None, // SQLite doesn't name FK constraints in PRAGMA output
620                    column: from,
621                    foreign_table: table,
622                    foreign_column: to,
623                    on_delete: on_delete.filter(|s| s != "NO ACTION"),
624                    on_update: on_update.filter(|s| s != "NO ACTION"),
625                })
626            })
627            .collect();
628
629        Outcome::Ok(fks)
630    }
631
632    async fn postgres_foreign_keys<C: Connection>(
633        &self,
634        cx: &Cx,
635        conn: &C,
636        table_name: &str,
637    ) -> Outcome<Vec<ForeignKeyInfo>, Error> {
638        let sql = "SELECT
639                       tc.constraint_name,
640                       kcu.column_name,
641                       ccu.table_name AS foreign_table_name,
642                       ccu.column_name AS foreign_column_name,
643                       rc.delete_rule,
644                       rc.update_rule
645                   FROM information_schema.table_constraints AS tc
646                   JOIN information_schema.key_column_usage AS kcu
647                       ON tc.constraint_name = kcu.constraint_name
648                       AND tc.table_schema = kcu.table_schema
649                   JOIN information_schema.constraint_column_usage AS ccu
650                       ON ccu.constraint_name = tc.constraint_name
651                       AND ccu.table_schema = tc.table_schema
652                   JOIN information_schema.referential_constraints AS rc
653                       ON rc.constraint_name = tc.constraint_name
654                       AND rc.constraint_schema = tc.table_schema
655                   WHERE tc.constraint_type = 'FOREIGN KEY'
656                       AND tc.table_name = $1
657                       AND tc.table_schema = 'public'";
658
659        let rows = match conn
660            .query(
661                cx,
662                sql,
663                &[sqlmodel_core::Value::Text(table_name.to_string())],
664            )
665            .await
666        {
667            Outcome::Ok(rows) => rows,
668            Outcome::Err(e) => return Outcome::Err(e),
669            Outcome::Cancelled(r) => return Outcome::Cancelled(r),
670            Outcome::Panicked(p) => return Outcome::Panicked(p),
671        };
672
673        let fks: Vec<ForeignKeyInfo> = rows
674            .iter()
675            .filter_map(|row| {
676                let name = row.get_named::<String>("constraint_name").ok();
677                let column = row.get_named::<String>("column_name").ok()?;
678                let foreign_table = row.get_named::<String>("foreign_table_name").ok()?;
679                let foreign_column = row.get_named::<String>("foreign_column_name").ok()?;
680                let on_delete = row.get_named::<String>("delete_rule").ok();
681                let on_update = row.get_named::<String>("update_rule").ok();
682
683                Some(ForeignKeyInfo {
684                    name,
685                    column,
686                    foreign_table,
687                    foreign_column,
688                    on_delete: on_delete.filter(|s| s != "NO ACTION"),
689                    on_update: on_update.filter(|s| s != "NO ACTION"),
690                })
691            })
692            .collect();
693
694        Outcome::Ok(fks)
695    }
696
697    async fn mysql_foreign_keys<C: Connection>(
698        &self,
699        cx: &Cx,
700        conn: &C,
701        table_name: &str,
702    ) -> Outcome<Vec<ForeignKeyInfo>, Error> {
703        let sql = "SELECT
704                       constraint_name,
705                       column_name,
706                       referenced_table_name,
707                       referenced_column_name
708                   FROM information_schema.key_column_usage
709                   WHERE table_name = ?
710                       AND referenced_table_name IS NOT NULL";
711
712        let rows = match conn
713            .query(
714                cx,
715                sql,
716                &[sqlmodel_core::Value::Text(table_name.to_string())],
717            )
718            .await
719        {
720            Outcome::Ok(rows) => rows,
721            Outcome::Err(e) => return Outcome::Err(e),
722            Outcome::Cancelled(r) => return Outcome::Cancelled(r),
723            Outcome::Panicked(p) => return Outcome::Panicked(p),
724        };
725
726        let fks: Vec<ForeignKeyInfo> = rows
727            .iter()
728            .filter_map(|row| {
729                let name = row.get_named::<String>("constraint_name").ok();
730                let column = row.get_named::<String>("column_name").ok()?;
731                let foreign_table = row.get_named::<String>("referenced_table_name").ok()?;
732                let foreign_column = row.get_named::<String>("referenced_column_name").ok()?;
733
734                Some(ForeignKeyInfo {
735                    name,
736                    column,
737                    foreign_table,
738                    foreign_column,
739                    on_delete: None, // Would need additional query
740                    on_update: None,
741                })
742            })
743            .collect();
744
745        Outcome::Ok(fks)
746    }
747
748    // ========================================================================
749    // Index Introspection
750    // ========================================================================
751
752    /// Get indexes for a table.
753    async fn indexes<C: Connection>(
754        &self,
755        cx: &Cx,
756        conn: &C,
757        table_name: &str,
758    ) -> Outcome<Vec<IndexInfo>, Error> {
759        match self.dialect {
760            Dialect::Sqlite => self.sqlite_indexes(cx, conn, table_name).await,
761            Dialect::Postgres => self.postgres_indexes(cx, conn, table_name).await,
762            Dialect::Mysql => self.mysql_indexes(cx, conn, table_name).await,
763        }
764    }
765
766    async fn sqlite_indexes<C: Connection>(
767        &self,
768        cx: &Cx,
769        conn: &C,
770        table_name: &str,
771    ) -> Outcome<Vec<IndexInfo>, Error> {
772        let sql = format!("PRAGMA index_list({})", sanitize_identifier(table_name));
773        let rows = match conn.query(cx, &sql, &[]).await {
774            Outcome::Ok(rows) => rows,
775            Outcome::Err(e) => return Outcome::Err(e),
776            Outcome::Cancelled(r) => return Outcome::Cancelled(r),
777            Outcome::Panicked(p) => return Outcome::Panicked(p),
778        };
779
780        let mut indexes = Vec::new();
781
782        for row in &rows {
783            let Ok(name) = row.get_named::<String>("name") else {
784                continue;
785            };
786            let unique = row.get_named::<i64>("unique").ok().unwrap_or(0) == 1;
787            let origin = row.get_named::<String>("origin").ok().unwrap_or_default();
788            let primary = origin == "pk";
789
790            // Get column info for this index
791            let info_sql = format!("PRAGMA index_info({})", sanitize_identifier(&name));
792            let info_rows = match conn.query(cx, &info_sql, &[]).await {
793                Outcome::Ok(r) => r,
794                Outcome::Err(_) => continue,
795                Outcome::Cancelled(r) => return Outcome::Cancelled(r),
796                Outcome::Panicked(p) => return Outcome::Panicked(p),
797            };
798
799            let columns: Vec<String> = info_rows
800                .iter()
801                .filter_map(|r| r.get_named::<String>("name").ok())
802                .collect();
803
804            indexes.push(IndexInfo {
805                name,
806                columns,
807                unique,
808                index_type: None, // SQLite doesn't expose index type
809                primary,
810            });
811        }
812
813        Outcome::Ok(indexes)
814    }
815
816    async fn postgres_indexes<C: Connection>(
817        &self,
818        cx: &Cx,
819        conn: &C,
820        table_name: &str,
821    ) -> Outcome<Vec<IndexInfo>, Error> {
822        let sql = "SELECT
823                       i.relname AS index_name,
824                       a.attname AS column_name,
825                       ix.indisunique AS is_unique,
826                       ix.indisprimary AS is_primary,
827                       am.amname AS index_type
828                   FROM pg_class t
829                   JOIN pg_index ix ON t.oid = ix.indrelid
830                   JOIN pg_class i ON i.oid = ix.indexrelid
831                   JOIN pg_am am ON i.relam = am.oid
832                   JOIN pg_attribute a ON a.attrelid = t.oid AND a.attnum = ANY(ix.indkey)
833                   WHERE t.relname = $1
834                       AND t.relkind = 'r'
835                   ORDER BY i.relname, a.attnum";
836
837        let rows = match conn
838            .query(
839                cx,
840                sql,
841                &[sqlmodel_core::Value::Text(table_name.to_string())],
842            )
843            .await
844        {
845            Outcome::Ok(rows) => rows,
846            Outcome::Err(e) => return Outcome::Err(e),
847            Outcome::Cancelled(r) => return Outcome::Cancelled(r),
848            Outcome::Panicked(p) => return Outcome::Panicked(p),
849        };
850
851        // Group by index name
852        let mut index_map: HashMap<String, IndexInfo> = HashMap::new();
853
854        for row in &rows {
855            let Ok(name) = row.get_named::<String>("index_name") else {
856                continue;
857            };
858            let Ok(column) = row.get_named::<String>("column_name") else {
859                continue;
860            };
861            let unique = row.get_named::<bool>("is_unique").ok().unwrap_or(false);
862            let primary = row.get_named::<bool>("is_primary").ok().unwrap_or(false);
863            let index_type = row.get_named::<String>("index_type").ok();
864
865            index_map
866                .entry(name.clone())
867                .and_modify(|idx| idx.columns.push(column.clone()))
868                .or_insert_with(|| IndexInfo {
869                    name,
870                    columns: vec![column],
871                    unique,
872                    index_type,
873                    primary,
874                });
875        }
876
877        Outcome::Ok(index_map.into_values().collect())
878    }
879
880    async fn mysql_indexes<C: Connection>(
881        &self,
882        cx: &Cx,
883        conn: &C,
884        table_name: &str,
885    ) -> Outcome<Vec<IndexInfo>, Error> {
886        let sql = format!("SHOW INDEX FROM `{}`", sanitize_identifier(table_name));
887        let rows = match conn.query(cx, &sql, &[]).await {
888            Outcome::Ok(rows) => rows,
889            Outcome::Err(e) => return Outcome::Err(e),
890            Outcome::Cancelled(r) => return Outcome::Cancelled(r),
891            Outcome::Panicked(p) => return Outcome::Panicked(p),
892        };
893
894        // Group by index name
895        let mut index_map: HashMap<String, IndexInfo> = HashMap::new();
896
897        for row in &rows {
898            let Ok(name) = row.get_named::<String>("Key_name") else {
899                continue;
900            };
901            let Ok(column) = row.get_named::<String>("Column_name") else {
902                continue;
903            };
904            let non_unique = row.get_named::<i64>("Non_unique").ok().unwrap_or(1);
905            let index_type = row.get_named::<String>("Index_type").ok();
906            let primary = name == "PRIMARY";
907
908            index_map
909                .entry(name.clone())
910                .and_modify(|idx| idx.columns.push(column.clone()))
911                .or_insert_with(|| IndexInfo {
912                    name,
913                    columns: vec![column],
914                    unique: non_unique == 0,
915                    index_type,
916                    primary,
917                });
918        }
919
920        Outcome::Ok(index_map.into_values().collect())
921    }
922}
923
924// ============================================================================
925// Helper Functions
926// ============================================================================
927
/// Build a complete PostgreSQL type string from `information_schema` column data.
///
/// * `data_type` – the reported data type, e.g. `character varying`,
///   `numeric`, or `ARRAY`.
/// * `udt_name` – the underlying type name. For arrays this is the array
///   type's internal name, which PostgreSQL forms by prefixing the element
///   type name with one underscore (e.g. `_text` for `text[]`).
/// * `char_len` – the maximum character length, if the type has one.
/// * `precision` / `scale` – the numeric precision and scale, if any. They are
///   only rendered for `numeric`.
///
/// Arrays are rendered as `<element>[]` with the element name used as-is.
/// Every other type is rendered in uppercase, with `(len)` appended when a
/// length is given, or `(p,s)` for `numeric` with both precision and scale.
fn build_postgres_type(
    data_type: &str,
    udt_name: &str,
    char_len: Option<i64>,
    precision: Option<i64>,
    scale: Option<i64>,
) -> String {
    // Array types: strip exactly one leading underscore. An element type whose
    // own name starts with `_` (array type `__foo`) must keep its underscore,
    // which `trim_start_matches` would have removed.
    if data_type == "ARRAY" {
        let element = udt_name.strip_prefix('_').unwrap_or(udt_name);
        return format!("{}[]", element);
    }

    // Character (and other length-limited) types.
    if let Some(len) = char_len {
        return format!("{}({})", data_type.to_uppercase(), len);
    }

    // Numeric types with explicit precision/scale. Integer types also report a
    // precision, so this is restricted to `numeric`.
    if let (Some(p), Some(s)) = (precision, scale) {
        if data_type == "numeric" {
            return format!("NUMERIC({},{})", p, s);
        }
    }

    // Default: just the data type.
    data_type.to_uppercase()
}
956
957// ============================================================================
958// Unit Tests
959// ============================================================================
960
#[cfg(test)]
mod tests {
    use super::*;

    // ------------------------------------------------------------------
    // ParsedSqlType::parse
    // ------------------------------------------------------------------

    #[test]
    fn test_parsed_sql_type_varchar() {
        let parsed = ParsedSqlType::parse("VARCHAR(255)");

        assert_eq!(parsed.base_type, "VARCHAR");
        assert_eq!(parsed.length, Some(255));
        // A character type carries a length, not precision/scale.
        assert_eq!(parsed.precision, None);
        assert_eq!(parsed.scale, None);
        assert!(!parsed.unsigned, "VARCHAR must not be unsigned");
        assert!(!parsed.array, "VARCHAR must not be an array");
    }

    #[test]
    fn test_parsed_sql_type_decimal() {
        let parsed = ParsedSqlType::parse("DECIMAL(10,2)");

        assert_eq!(parsed.base_type, "DECIMAL");
        // Both numbers go to precision/scale, none to length.
        assert_eq!(parsed.length, None);
        assert_eq!(parsed.precision, Some(10));
        assert_eq!(parsed.scale, Some(2));
    }

    #[test]
    fn test_parsed_sql_type_unsigned() {
        let parsed = ParsedSqlType::parse("INT UNSIGNED");

        assert_eq!(parsed.base_type, "INT");
        assert!(parsed.unsigned, "UNSIGNED modifier should be detected");
    }

    #[test]
    fn test_parsed_sql_type_array() {
        let parsed = ParsedSqlType::parse("TEXT[]");

        assert_eq!(parsed.base_type, "TEXT");
        assert!(parsed.array, "[] suffix should mark an array");
    }

    #[test]
    fn test_parsed_sql_type_simple() {
        let parsed = ParsedSqlType::parse("INTEGER");

        assert_eq!(parsed.base_type, "INTEGER");
        assert_eq!(parsed.length, None);
        assert!(!parsed.unsigned, "plain INTEGER must not be unsigned");
        assert!(!parsed.array, "plain INTEGER must not be an array");
    }

    // ------------------------------------------------------------------
    // ParsedSqlType category predicates
    // ------------------------------------------------------------------

    #[test]
    fn test_parsed_sql_type_is_text() {
        let cases = [
            ("VARCHAR(100)", true),
            ("TEXT", true),
            ("CHAR(1)", true),
            ("INTEGER", false),
        ];
        for (sql, expected) in cases {
            assert_eq!(ParsedSqlType::parse(sql).is_text(), expected, "is_text({sql:?})");
        }
    }

    #[test]
    fn test_parsed_sql_type_is_numeric() {
        let cases = [
            ("INTEGER", true),
            ("BIGINT", true),
            ("DECIMAL(10,2)", true),
            ("TEXT", false),
        ];
        for (sql, expected) in cases {
            assert_eq!(ParsedSqlType::parse(sql).is_numeric(), expected, "is_numeric({sql:?})");
        }
    }

    #[test]
    fn test_parsed_sql_type_is_datetime() {
        let cases = [
            ("DATE", true),
            ("TIMESTAMP", true),
            ("TIMESTAMPTZ", true),
            ("TEXT", false),
        ];
        for (sql, expected) in cases {
            assert_eq!(
                ParsedSqlType::parse(sql).is_datetime(),
                expected,
                "is_datetime({sql:?})"
            );
        }
    }

    // ------------------------------------------------------------------
    // Schema containers
    // ------------------------------------------------------------------

    #[test]
    fn test_database_schema_new() {
        let schema = DatabaseSchema::new(Dialect::Postgres);

        assert_eq!(schema.dialect, Dialect::Postgres);
        assert!(schema.tables.is_empty(), "a new schema starts without tables");
    }

    #[test]
    fn test_table_info_column() {
        let id_column = ColumnInfo {
            name: String::from("id"),
            sql_type: String::from("INTEGER"),
            parsed_type: ParsedSqlType::parse("INTEGER"),
            nullable: false,
            default: None,
            primary_key: true,
            auto_increment: true,
            comment: None,
        };
        let table = TableInfo {
            name: String::from("test"),
            columns: vec![id_column],
            primary_key: vec![String::from("id")],
            foreign_keys: vec![],
            unique_constraints: vec![],
            check_constraints: vec![],
            indexes: vec![],
            comment: None,
        };

        assert!(table.column("id").is_some());
        assert!(table.column("nonexistent").is_none());
        assert!(table.has_auto_pk());
    }

    // ------------------------------------------------------------------
    // build_postgres_type
    // ------------------------------------------------------------------

    #[test]
    fn test_build_postgres_type_array() {
        assert_eq!(build_postgres_type("ARRAY", "_text", None, None, None), "text[]");
    }

    #[test]
    fn test_build_postgres_type_varchar() {
        assert_eq!(
            build_postgres_type("character varying", "", Some(100), None, None),
            "CHARACTER VARYING(100)"
        );
    }

    #[test]
    fn test_build_postgres_type_numeric() {
        assert_eq!(
            build_postgres_type("numeric", "", None, Some(10), Some(2)),
            "NUMERIC(10,2)"
        );
    }

    // ------------------------------------------------------------------
    // sanitize_identifier
    // ------------------------------------------------------------------

    #[test]
    fn test_sanitize_identifier_normal() {
        // Ordinary identifiers pass through unchanged.
        for name in ["users", "my_table", "Table123"] {
            assert_eq!(sanitize_identifier(name), name, "input: {name:?}");
        }
    }

    #[test]
    fn test_sanitize_identifier_sql_injection() {
        // Injection payloads are reduced to their identifier characters.
        let cases = [
            ("users; DROP TABLE--", "usersDROPTABLE"),
            ("table`; malicious", "tablemalicious"),
            ("users'--", "users"),
            ("table\"); DROP TABLE users;--", "tableDROPTABLEusers"),
        ];
        for (input, expected) in cases {
            assert_eq!(sanitize_identifier(input), expected, "input: {input:?}");
        }
    }

    #[test]
    fn test_sanitize_identifier_special_chars() {
        // Separators, punctuation and whitespace are all stripped.
        let cases = [
            ("table-name", "tablename"),
            ("table.name", "tablename"),
            ("table name", "tablename"),
            ("table\nname", "tablename"),
        ];
        for (input, expected) in cases {
            assert_eq!(sanitize_identifier(input), expected, "input: {input:?}");
        }
    }
}