Skip to main content

prax_migrate/
introspect.rs

//! Database introspection for reverse-engineering schemas.
//!
//! This module introspects an existing database and generates a Prax schema
//! from its structure.
5
6use std::collections::HashMap;
7
8use prax_schema::Schema;
9use prax_schema::ast::{
10    Attribute, AttributeArg, AttributeValue, Enum, EnumVariant, Field, FieldType, Ident, Model,
11    ScalarType, Span, TypeModifier,
12};
13
14use crate::error::{MigrateResult, MigrationError};
15
16/// Result of introspecting a database.
17#[derive(Debug, Clone)]
18pub struct IntrospectionResult {
19    /// The generated schema.
20    pub schema: Schema,
21    /// Tables that were skipped.
22    pub skipped_tables: Vec<SkippedTable>,
23    /// Warnings generated during introspection.
24    pub warnings: Vec<String>,
25}
26
/// A table that introspection did not turn into a model.
#[derive(Clone, Debug)]
pub struct SkippedTable {
    /// Name of the skipped table.
    pub name: String,
    /// Human-readable explanation of why the table was skipped.
    pub reason: String,
}
35
/// Configuration for introspection.
///
/// Controls which database schema is read and which tables, views and enums
/// end up in the generated Prax schema. Start from [`IntrospectionConfig::new`]
/// (or `Default`) and adjust it with the chained setters.
#[derive(Debug, Clone)]
pub struct IntrospectionConfig {
    /// Database schema to introspect (default: `"public"`).
    pub database_schema: String,
    /// Tables to include; an empty list means "every table that is not excluded".
    pub include_tables: Vec<String>,
    /// Tables to exclude; exclusion wins over inclusion.
    ///
    /// Defaults to `_prax_migrations`, `_prisma_migrations` and `schema_migrations`.
    pub exclude_tables: Vec<String>,
    /// Whether views are turned into models (default: `true`).
    pub include_views: bool,
    /// Whether database enums are added to the schema (default: `true`).
    pub include_enums: bool,
}

impl Default for IntrospectionConfig {
    fn default() -> Self {
        Self {
            database_schema: "public".to_string(),
            include_tables: Vec::new(),
            exclude_tables: vec![
                "_prax_migrations".to_string(),
                "_prisma_migrations".to_string(),
                "schema_migrations".to_string(),
            ],
            include_views: true,
            include_enums: true,
        }
    }
}

impl IntrospectionConfig {
    /// Create a new introspection config with the default settings.
    pub fn new() -> Self {
        Self::default()
    }

    /// Set the database schema to introspect.
    pub fn database_schema(mut self, schema: impl Into<String>) -> Self {
        self.database_schema = schema.into();
        self
    }

    /// Include only these tables (an empty list includes every table).
    pub fn include_tables(mut self, tables: Vec<String>) -> Self {
        self.include_tables = tables;
        self
    }

    /// Exclude these tables.
    ///
    /// This replaces the whole exclusion list, including the default entries.
    pub fn exclude_tables(mut self, tables: Vec<String>) -> Self {
        self.exclude_tables = tables;
        self
    }

    /// Whether to include views.
    pub fn include_views(mut self, include: bool) -> Self {
        self.include_views = include;
        self
    }

    /// Whether to include enums.
    pub fn include_enums(mut self, include: bool) -> Self {
        self.include_enums = include;
        self
    }

    /// Check if a table should be included.
    ///
    /// A table is rejected when it appears in `exclude_tables`. Otherwise it is
    /// accepted when `include_tables` is empty or lists it. Names are compared
    /// exactly (case-sensitive).
    pub fn should_include_table(&self, name: &str) -> bool {
        // Compare borrowed strings; no temporary `String` per lookup.
        if self.exclude_tables.iter().any(|t| t == name) {
            return false;
        }
        self.include_tables.is_empty() || self.include_tables.iter().any(|t| t == name)
    }
}
114
/// A table (or view) as reported by the database, before any conversion.
#[derive(Clone, Debug)]
pub struct TableInfo {
    /// Name of the table.
    pub name: String,
    /// Database schema that owns the table (for example `"public"`).
    pub schema: String,
    /// Kind of relation, such as `"BASE TABLE"` or `"VIEW"`.
    pub table_type: String,
    /// Comment attached to the table, if any.
    pub comment: Option<String>,
}
127
/// A column as reported by the database, before any conversion.
#[derive(Clone, Debug)]
pub struct ColumnInfo {
    /// Name of the column.
    pub name: String,
    /// SQL data type name (for example `"integer"` or `"character varying"`).
    pub data_type: String,
    /// Underlying type name (for example `"int4"` or `"varchar"`).
    pub udt_name: String,
    /// Maximum character length for length-limited text types.
    pub character_maximum_length: Option<i32>,
    /// Numeric precision, when the database reports one.
    pub numeric_precision: Option<i32>,
    /// `true` when the column accepts NULL.
    pub is_nullable: bool,
    /// Default value expression, exactly as the database reports it.
    pub column_default: Option<String>,
    /// Position of the column within its table.
    pub ordinal_position: i32,
    /// Comment attached to the column, if any.
    pub comment: Option<String>,
}
150
/// A table constraint as reported by the database, before any conversion.
#[derive(Clone, Debug)]
pub struct ConstraintInfo {
    /// Name of the constraint.
    pub name: String,
    /// Kind of constraint: `PRIMARY KEY`, `UNIQUE`, `FOREIGN KEY` or `CHECK`.
    pub constraint_type: String,
    /// Table the constraint is defined on.
    pub table_name: String,
    /// Columns covered by the constraint.
    pub columns: Vec<String>,
    /// Table a foreign key points at; `None` for other constraint kinds.
    pub referenced_table: Option<String>,
    /// Columns a foreign key points at; `None` for other constraint kinds.
    pub referenced_columns: Option<Vec<String>>,
    /// Foreign-key action on delete, when applicable.
    pub on_delete: Option<String>,
    /// Foreign-key action on update, when applicable.
    pub on_update: Option<String>,
}
171
/// A database enum type as reported by the database, before any conversion.
#[derive(Clone, Debug)]
pub struct EnumInfo {
    /// Name of the enum type.
    pub name: String,
    /// Labels of the enum, in the order they were reported.
    pub values: Vec<String>,
    /// Database schema that owns the enum type.
    pub schema: String,
}
182
/// An index as reported by the database, before any conversion.
#[derive(Clone, Debug)]
pub struct IndexInfo {
    /// Name of the index.
    pub name: String,
    /// Table the index belongs to.
    pub table_name: String,
    /// Indexed columns.
    pub columns: Vec<String>,
    /// `true` for a unique index.
    pub is_unique: bool,
    /// `true` when the index backs the primary key.
    pub is_primary: bool,
    /// Access method of the index (for example `btree` or `hash`).
    pub index_method: String,
}
199
200/// Trait for database introspection.
201#[async_trait::async_trait]
202pub trait Introspector: Send + Sync {
203    /// Get all tables in the database.
204    async fn get_tables(&self, config: &IntrospectionConfig) -> MigrateResult<Vec<TableInfo>>;
205
206    /// Get columns for a table.
207    async fn get_columns(&self, table: &str, schema: &str) -> MigrateResult<Vec<ColumnInfo>>;
208
209    /// Get constraints for a table.
210    async fn get_constraints(
211        &self,
212        table: &str,
213        schema: &str,
214    ) -> MigrateResult<Vec<ConstraintInfo>>;
215
216    /// Get indexes for a table.
217    async fn get_indexes(&self, table: &str, schema: &str) -> MigrateResult<Vec<IndexInfo>>;
218
219    /// Get all enums in the database.
220    async fn get_enums(&self, schema: &str) -> MigrateResult<Vec<EnumInfo>>;
221}
222
223/// Build a Prax schema from introspection data.
224pub struct SchemaBuilder {
225    config: IntrospectionConfig,
226    tables: Vec<TableInfo>,
227    columns: HashMap<String, Vec<ColumnInfo>>,
228    constraints: HashMap<String, Vec<ConstraintInfo>>,
229    indexes: HashMap<String, Vec<IndexInfo>>,
230    enums: Vec<EnumInfo>,
231}
232
233impl SchemaBuilder {
234    /// Create a new schema builder.
235    pub fn new(config: IntrospectionConfig) -> Self {
236        Self {
237            config,
238            tables: Vec::new(),
239            columns: HashMap::new(),
240            constraints: HashMap::new(),
241            indexes: HashMap::new(),
242            enums: Vec::new(),
243        }
244    }
245
246    /// Add table information.
247    pub fn with_tables(mut self, tables: Vec<TableInfo>) -> Self {
248        self.tables = tables;
249        self
250    }
251
252    /// Add column information for a table.
253    pub fn with_columns(mut self, table: &str, columns: Vec<ColumnInfo>) -> Self {
254        self.columns.insert(table.to_string(), columns);
255        self
256    }
257
258    /// Add constraint information for a table.
259    pub fn with_constraints(mut self, table: &str, constraints: Vec<ConstraintInfo>) -> Self {
260        self.constraints.insert(table.to_string(), constraints);
261        self
262    }
263
264    /// Add index information for a table.
265    pub fn with_indexes(mut self, table: &str, indexes: Vec<IndexInfo>) -> Self {
266        self.indexes.insert(table.to_string(), indexes);
267        self
268    }
269
270    /// Add enum information.
271    pub fn with_enums(mut self, enums: Vec<EnumInfo>) -> Self {
272        self.enums = enums;
273        self
274    }
275
276    /// Build the schema from the collected information.
277    pub fn build(self) -> MigrateResult<IntrospectionResult> {
278        let mut schema = Schema::new();
279        let mut skipped_tables = Vec::new();
280        let mut warnings = Vec::new();
281
282        // Add enums first (they may be referenced by columns)
283        if self.config.include_enums {
284            for enum_info in &self.enums {
285                let prax_enum = self.build_enum(enum_info);
286                schema.add_enum(prax_enum);
287            }
288        }
289
290        // Build models from tables
291        for table in &self.tables {
292            if !self.config.should_include_table(&table.name) {
293                skipped_tables.push(SkippedTable {
294                    name: table.name.clone(),
295                    reason: "Excluded by configuration".to_string(),
296                });
297                continue;
298            }
299
300            // Skip views if not configured
301            if table.table_type == "VIEW" && !self.config.include_views {
302                continue;
303            }
304
305            match self.build_model(table) {
306                Ok(model) => {
307                    schema.add_model(model);
308                }
309                Err(e) => {
310                    warnings.push(format!("Failed to build model for '{}': {}", table.name, e));
311                    skipped_tables.push(SkippedTable {
312                        name: table.name.clone(),
313                        reason: e.to_string(),
314                    });
315                }
316            }
317        }
318
319        Ok(IntrospectionResult {
320            schema,
321            skipped_tables,
322            warnings,
323        })
324    }
325
326    /// Build an enum from database enum info.
327    fn build_enum(&self, info: &EnumInfo) -> Enum {
328        let span = Span::new(0, 0);
329        let name = Ident::new(to_pascal_case(&info.name), span);
330        let mut prax_enum = Enum::new(name, span);
331
332        for value in &info.values {
333            prax_enum.add_variant(EnumVariant::new(Ident::new(value.clone(), span), span));
334        }
335
336        prax_enum
337    }
338
339    /// Build a model from table info.
340    fn build_model(&self, table: &TableInfo) -> MigrateResult<Model> {
341        let span = Span::new(0, 0);
342        let name = Ident::new(to_pascal_case(&table.name), span);
343        let mut model = Model::new(name, span);
344
345        // Add @@map attribute if table name differs from model name
346        let model_name = to_pascal_case(&table.name);
347        if table.name != model_name && table.name != to_snake_case(&model_name) {
348            model.attributes.push(Attribute::new(
349                Ident::new("map", span),
350                vec![AttributeArg::positional(
351                    AttributeValue::String(table.name.clone()),
352                    span,
353                )],
354                span,
355            ));
356        }
357
358        // Get columns for this table
359        let columns = self.columns.get(&table.name).cloned().unwrap_or_default();
360
361        // Get constraints for this table
362        let constraints = self
363            .constraints
364            .get(&table.name)
365            .cloned()
366            .unwrap_or_default();
367
368        // Find primary key columns
369        let pk_columns: Vec<&str> = constraints
370            .iter()
371            .filter(|c| c.constraint_type == "PRIMARY KEY")
372            .flat_map(|c| c.columns.iter().map(|s| s.as_str()))
373            .collect();
374
375        // Find unique columns
376        let unique_columns: Vec<&str> = constraints
377            .iter()
378            .filter(|c| c.constraint_type == "UNIQUE")
379            .filter(|c| c.columns.len() == 1)
380            .flat_map(|c| c.columns.iter().map(|s| s.as_str()))
381            .collect();
382
383        // Build fields from columns
384        for column in columns {
385            let field = self.build_field(&column, &pk_columns, &unique_columns)?;
386            model.add_field(field);
387        }
388
389        Ok(model)
390    }
391
392    /// Build a field from column info.
393    fn build_field(
394        &self,
395        column: &ColumnInfo,
396        pk_columns: &[&str],
397        unique_columns: &[&str],
398    ) -> MigrateResult<Field> {
399        let span = Span::new(0, 0);
400        let name = Ident::new(&column.name, span);
401
402        // Map SQL type to Prax type
403        let (field_type, needs_map) = self.sql_type_to_prax(&column.udt_name, &column.data_type)?;
404
405        // Determine modifier
406        let modifier = if column.is_nullable {
407            TypeModifier::Optional
408        } else {
409            TypeModifier::Required
410        };
411
412        let mut attributes = Vec::new();
413
414        // Add @id if this is a primary key
415        if pk_columns.contains(&column.name.as_str()) {
416            attributes.push(Attribute::simple(Ident::new("id", span), span));
417
418            // Check for auto-increment
419            if let Some(default) = &column.column_default
420                && (default.contains("nextval") || default.contains("SERIAL"))
421            {
422                attributes.push(Attribute::simple(Ident::new("auto", span), span));
423            }
424        }
425
426        // Add @unique if this is a unique column
427        if unique_columns.contains(&column.name.as_str()) {
428            attributes.push(Attribute::simple(Ident::new("unique", span), span));
429        }
430
431        // Add @default if there's a default value (skip auto-increment defaults)
432        if let Some(default) = &column.column_default
433            && !default.contains("nextval")
434            && let Some(value) = parse_default_value(default)
435        {
436            attributes.push(Attribute::new(
437                Ident::new("default", span),
438                vec![AttributeArg::positional(value, span)],
439                span,
440            ));
441        }
442
443        // Add @map if column name differs from field name
444        if needs_map {
445            attributes.push(Attribute::new(
446                Ident::new("map", span),
447                vec![AttributeArg::positional(
448                    AttributeValue::String(column.name.clone()),
449                    span,
450                )],
451                span,
452            ));
453        }
454
455        Ok(Field::new(name, field_type, modifier, attributes, span))
456    }
457
458    /// Convert SQL type to Prax field type.
459    fn sql_type_to_prax(
460        &self,
461        udt_name: &str,
462        data_type: &str,
463    ) -> MigrateResult<(FieldType, bool)> {
464        // Check if this is a known enum
465        let enum_names: Vec<&str> = self.enums.iter().map(|e| e.name.as_str()).collect();
466        if enum_names.contains(&udt_name) {
467            return Ok((FieldType::Enum(to_pascal_case(udt_name).into()), false));
468        }
469
470        let scalar = match udt_name {
471            "int2" | "int4" | "integer" | "smallint" => ScalarType::Int,
472            "int8" | "bigint" => ScalarType::BigInt,
473            "float4" | "float8" | "real" | "double precision" => ScalarType::Float,
474            "numeric" | "decimal" | "money" => ScalarType::Decimal,
475            "text" | "varchar" | "char" | "character varying" | "character" | "bpchar" => {
476                ScalarType::String
477            }
478            "bool" | "boolean" => ScalarType::Boolean,
479            "timestamp"
480            | "timestamptz"
481            | "timestamp with time zone"
482            | "timestamp without time zone" => ScalarType::DateTime,
483            "date" => ScalarType::Date,
484            "time" | "timetz" | "time with time zone" | "time without time zone" => {
485                ScalarType::Time
486            }
487            "json" | "jsonb" => ScalarType::Json,
488            "bytea" => ScalarType::Bytes,
489            "uuid" => ScalarType::Uuid,
490            _ => {
491                // Try to match by data_type as fallback
492                match data_type {
493                    "integer" | "smallint" => ScalarType::Int,
494                    "bigint" => ScalarType::BigInt,
495                    "real" | "double precision" => ScalarType::Float,
496                    "numeric" => ScalarType::Decimal,
497                    "character varying" | "character" | "text" => ScalarType::String,
498                    "boolean" => ScalarType::Boolean,
499                    "timestamp with time zone" | "timestamp without time zone" => {
500                        ScalarType::DateTime
501                    }
502                    "date" => ScalarType::Date,
503                    "time with time zone" | "time without time zone" => ScalarType::Time,
504                    "json" | "jsonb" => ScalarType::Json,
505                    "bytea" => ScalarType::Bytes,
506                    "uuid" => ScalarType::Uuid,
507                    "ARRAY" => {
508                        // Arrays are complex - for now, treat as Json
509                        ScalarType::Json
510                    }
511                    "USER-DEFINED" => {
512                        // This might be an enum we haven't seen
513                        return Err(MigrationError::InvalidMigration(format!(
514                            "Unknown user-defined type: {}",
515                            udt_name
516                        )));
517                    }
518                    _ => {
519                        return Err(MigrationError::InvalidMigration(format!(
520                            "Unknown SQL type: {} ({})",
521                            udt_name, data_type
522                        )));
523                    }
524                }
525            }
526        };
527
528        Ok((FieldType::Scalar(scalar), false))
529    }
530}
531
532/// Parse a default value expression into an AttributeValue.
533fn parse_default_value(default: &str) -> Option<AttributeValue> {
534    let trimmed = default.trim();
535
536    // Handle booleans
537    if trimmed == "true" || trimmed == "TRUE" {
538        return Some(AttributeValue::Boolean(true));
539    }
540    if trimmed == "false" || trimmed == "FALSE" {
541        return Some(AttributeValue::Boolean(false));
542    }
543
544    // Handle NULL
545    if trimmed.to_uppercase() == "NULL" {
546        return None;
547    }
548
549    // Handle integers
550    if let Ok(int) = trimmed.parse::<i64>() {
551        return Some(AttributeValue::Int(int));
552    }
553
554    // Handle floats
555    if let Ok(float) = trimmed.parse::<f64>() {
556        return Some(AttributeValue::Float(float));
557    }
558
559    // Handle strings (enclosed in quotes)
560    if (trimmed.starts_with('\'') && trimmed.ends_with('\''))
561        || (trimmed.starts_with('"') && trimmed.ends_with('"'))
562    {
563        let inner = &trimmed[1..trimmed.len() - 1];
564        return Some(AttributeValue::String(inner.to_string()));
565    }
566
567    // Handle PostgreSQL type casts (e.g., 'value'::type)
568    if let Some(pos) = trimmed.find("::") {
569        return parse_default_value(&trimmed[..pos]);
570    }
571
572    // Handle function calls (e.g., now(), uuid_generate_v4())
573    if trimmed.ends_with("()") || trimmed.contains('(') {
574        let func_name = if let Some(paren_pos) = trimmed.find('(') {
575            &trimmed[..paren_pos]
576        } else {
577            &trimmed[..trimmed.len() - 2]
578        };
579        return Some(AttributeValue::Function(
580            func_name.to_string().into(),
581            vec![],
582        ));
583    }
584
585    // Unknown default - return as string
586    Some(AttributeValue::String(trimmed.to_string()))
587}
588
/// Convert snake_case to PascalCase.
///
/// Each non-empty `_`-separated segment has its first character upper-cased,
/// using the full Unicode mapping. Empty segments produced by leading,
/// trailing or repeated underscores are dropped.
fn to_pascal_case(s: &str) -> String {
    let mut out = String::with_capacity(s.len());
    for segment in s.split('_').filter(|seg| !seg.is_empty()) {
        let mut chars = segment.chars();
        if let Some(head) = chars.next() {
            out.extend(head.to_uppercase());
            out.push_str(chars.as_str());
        }
    }
    out
}
602
/// Convert PascalCase to snake_case.
///
/// Every uppercase character, except one at the very start, is preceded by
/// `_`, and each uppercase character is replaced by its full lowercase
/// mapping. Runs of capitals are therefore split letter by letter
/// (`HTTPResponse` → `h_t_t_p_response`).
fn to_snake_case(s: &str) -> String {
    let mut result = String::with_capacity(s.len());
    for (i, ch) in s.chars().enumerate() {
        if ch.is_uppercase() {
            if i > 0 {
                result.push('_');
            }
            // Some characters lowercase to more than one `char`
            // (e.g. 'İ' → "i\u{307}"). Keep the whole mapping, as
            // `to_pascal_case` does for upper-casing.
            result.extend(ch.to_lowercase());
        } else {
            result.push(ch);
        }
    }
    result
}
618
/// SQL queries for PostgreSQL introspection.
///
/// Every query takes the database schema name as `$1`. The per-table queries
/// (`COLUMNS`, `CONSTRAINTS`, `INDEXES`) also take the table name as `$2`.
pub mod postgres_queries {
    /// Query to get all tables and views.
    ///
    /// Returns `table_name`, `table_schema` and `table_type`, ordered by name.
    pub const TABLES: &str = r#"
        SELECT
            table_name,
            table_schema,
            table_type
        FROM information_schema.tables
        WHERE table_schema = $1
        ORDER BY table_name
    "#;

    /// Query to get columns for a table, in ordinal order.
    ///
    /// `is_nullable` is returned as a boolean.
    pub const COLUMNS: &str = r#"
        SELECT
            column_name,
            data_type,
            udt_name,
            character_maximum_length,
            numeric_precision,
            is_nullable = 'YES' as is_nullable,
            column_default,
            ordinal_position
        FROM information_schema.columns
        WHERE table_schema = $1 AND table_name = $2
        ORDER BY ordinal_position
    "#;

    /// Query to get constraints.
    ///
    /// Returns one row per constrained column. For foreign keys,
    /// `referenced_table`/`referenced_column` name the column that the row's
    /// `column_name` points at. The pairing goes through
    /// `position_in_unique_constraint` into the referenced key, so multi-column
    /// foreign keys are not cross-joined. Both are NULL for other constraint
    /// kinds.
    pub const CONSTRAINTS: &str = r#"
        SELECT
            tc.constraint_name,
            tc.constraint_type,
            tc.table_name,
            kcu.column_name,
            rkcu.table_name AS referenced_table,
            rkcu.column_name AS referenced_column,
            rc.delete_rule,
            rc.update_rule
        FROM information_schema.table_constraints tc
        LEFT JOIN information_schema.key_column_usage kcu
            ON tc.constraint_name = kcu.constraint_name
            AND tc.table_schema = kcu.table_schema
            AND tc.table_name = kcu.table_name
        LEFT JOIN information_schema.referential_constraints rc
            ON tc.constraint_name = rc.constraint_name
            AND tc.table_schema = rc.constraint_schema
        LEFT JOIN information_schema.key_column_usage rkcu
            ON rc.unique_constraint_schema = rkcu.constraint_schema
            AND rc.unique_constraint_name = rkcu.constraint_name
            AND kcu.position_in_unique_constraint = rkcu.ordinal_position
        WHERE tc.table_schema = $1 AND tc.table_name = $2
        ORDER BY tc.constraint_name, kcu.ordinal_position
    "#;

    /// Query to get indexes.
    ///
    /// NOTE(review): index expression columns have no matching `pg_attribute`
    /// row, so expression indexes report only their plain columns — confirm
    /// whether that is acceptable.
    pub const INDEXES: &str = r#"
        SELECT
            i.relname AS index_name,
            t.relname AS table_name,
            array_agg(a.attname ORDER BY array_position(ix.indkey, a.attnum)) AS columns,
            ix.indisunique AS is_unique,
            ix.indisprimary AS is_primary,
            am.amname AS index_method
        FROM pg_index ix
        JOIN pg_class i ON ix.indexrelid = i.oid
        JOIN pg_class t ON ix.indrelid = t.oid
        JOIN pg_namespace n ON t.relnamespace = n.oid
        JOIN pg_am am ON i.relam = am.oid
        JOIN pg_attribute a ON a.attrelid = t.oid AND a.attnum = ANY(ix.indkey)
        WHERE n.nspname = $1 AND t.relname = $2
        GROUP BY i.relname, t.relname, ix.indisunique, ix.indisprimary, am.amname
    "#;

    /// Query to get enums, with labels in their declared sort order.
    pub const ENUMS: &str = r#"
        SELECT
            t.typname AS enum_name,
            n.nspname AS schema_name,
            array_agg(e.enumlabel ORDER BY e.enumsortorder) AS enum_values
        FROM pg_type t
        JOIN pg_namespace n ON t.typnamespace = n.oid
        JOIN pg_enum e ON t.oid = e.enumtypid
        WHERE n.nspname = $1
        GROUP BY t.typname, n.nspname
    "#;
}
706
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_to_pascal_case() {
        for (input, expected) in [
            ("user", "User"),
            ("user_profile", "UserProfile"),
            ("user_profile_settings", "UserProfileSettings"),
            // Leading and trailing underscores are dropped.
            ("_user_", "User"),
        ] {
            assert_eq!(to_pascal_case(input), expected);
        }
    }

    #[test]
    fn test_to_snake_case() {
        for (input, expected) in [
            ("User", "user"),
            ("UserProfile", "user_profile"),
            // Runs of capitals are split letter by letter.
            ("HTTPResponse", "h_t_t_p_response"),
        ] {
            assert_eq!(to_snake_case(input), expected);
        }
    }

    #[test]
    fn test_parse_default_value_boolean() {
        let parsed_true = parse_default_value("true");
        let parsed_false = parse_default_value("false");
        assert!(matches!(parsed_true, Some(AttributeValue::Boolean(true))));
        assert!(matches!(parsed_false, Some(AttributeValue::Boolean(false))));
    }

    #[test]
    fn test_parse_default_value_int() {
        let positive = parse_default_value("42");
        let negative = parse_default_value("-5");
        assert!(matches!(positive, Some(AttributeValue::Int(42))));
        assert!(matches!(negative, Some(AttributeValue::Int(-5))));
    }

    #[test]
    // 3.14 is plain test data here, not an approximation of PI.
    #[allow(clippy::approx_constant)]
    fn test_parse_default_value_float() {
        match parse_default_value("3.14") {
            Some(AttributeValue::Float(f)) => assert!((f - 3.14).abs() < 0.001),
            _ => panic!("Expected Float"),
        }
    }

    #[test]
    fn test_parse_default_value_string() {
        match parse_default_value("'hello'") {
            Some(AttributeValue::String(s)) => assert_eq!(s.as_str(), "hello"),
            _ => panic!("Expected String"),
        }
    }

    #[test]
    fn test_parse_default_value_function() {
        match parse_default_value("now()") {
            Some(AttributeValue::Function(name, args)) => {
                assert_eq!(name.as_str(), "now");
                assert!(args.is_empty());
            }
            _ => panic!("Expected Function"),
        }
    }

    #[test]
    fn test_parse_default_value_with_cast() {
        match parse_default_value("'active'::status_type") {
            Some(AttributeValue::String(s)) => assert_eq!(s.as_str(), "active"),
            _ => panic!("Expected String"),
        }
    }

    #[test]
    fn test_config_should_include_table() {
        let config = IntrospectionConfig::default();
        assert!(config.should_include_table("users"));
        // Migration tables are excluded by default.
        assert!(!config.should_include_table("_prax_migrations"));
    }

    #[test]
    fn test_config_include_specific_tables() {
        let only_users = vec!["users".to_string()];
        let config = IntrospectionConfig::new().include_tables(only_users);
        assert!(config.should_include_table("users"));
        assert!(!config.should_include_table("posts"));
    }

    #[test]
    fn test_sql_type_mapping() {
        let builder = SchemaBuilder::new(IntrospectionConfig::default());
        let mapped = |udt: &str, data_type: &str| builder.sql_type_to_prax(udt, data_type).unwrap().0;

        assert!(matches!(mapped("int4", "integer"), FieldType::Scalar(ScalarType::Int)));
        assert!(matches!(mapped("text", "text"), FieldType::Scalar(ScalarType::String)));
        assert!(matches!(mapped("bool", "boolean"), FieldType::Scalar(ScalarType::Boolean)));
        assert!(matches!(
            mapped("timestamptz", "timestamp with time zone"),
            FieldType::Scalar(ScalarType::DateTime)
        ));
        assert!(matches!(mapped("uuid", "uuid"), FieldType::Scalar(ScalarType::Uuid)));
    }
}