// pg2sqlite_core/pg/parser.rs

//! PostgreSQL DDL parser using sqlparser-rs.
//!
//! Converts the sqlparser AST into our internal representation (IR).
4use sqlparser::ast::{
5    self, AlterColumnOperation, AlterTableOperation, Array, ArrayElemTypeDef, BinaryOperator,
6    ColumnDef, ColumnOption, CreateIndex, DataType, Expr as SqlExpr, ObjectName, ObjectNamePart,
7    ReferentialAction, Statement, TableConstraint as SqlConstraint, UserDefinedTypeRepresentation,
8    ValueWithSpan,
9};
10use sqlparser::dialect::PostgreSqlDialect;
11use sqlparser::parser::Parser;
12
13use crate::diagnostics::warning::{self, Severity, Warning};
14use crate::ir::{
15    AlterConstraint, AlterIdentity, Column, EnumDef, Expr, FkAction, ForeignKeyRef, Ident, Index,
16    IndexColumn, IndexMethod, PgType, QualifiedName, SchemaModel, Sequence, Table, TableConstraint,
17};
18
19/// Strip the parenthesized sequence-options block from `AS IDENTITY (...)` statements.
20///
21/// pg_dump emits sequence options (SEQUENCE NAME, START WITH, INCREMENT BY, etc.)
22/// inside the identity block in an order that sqlparser cannot parse.
23/// Since we only need to know the column has identity (not the sequence details),
24/// we strip the entire parenthesized block.
25fn strip_identity_options(input: &str) -> String {
26    let mut result = String::with_capacity(input.len());
27    let mut pos = 0;
28    let bytes = input.as_bytes();
29    let needle = b"AS IDENTITY";
30
31    while pos < bytes.len() {
32        if let Some(idx) = bytes[pos..]
33            .windows(needle.len())
34            .position(|w| w.eq_ignore_ascii_case(needle))
35        {
36            let identity_end = pos + idx + needle.len();
37            result.push_str(&input[pos..identity_end]);
38            pos = identity_end;
39
40            // Skip whitespace then look for '('
41            let rest = &input[pos..];
42            let trimmed = rest.trim_start();
43            if trimmed.starts_with('(') {
44                let ws_len = rest.len() - trimmed.len();
45                let paren_start = pos + ws_len;
46                // Find matching closing paren
47                if let Some(close) = find_matching_paren(input, paren_start) {
48                    pos = close + 1; // skip past ')'
49                    continue;
50                }
51            }
52        } else {
53            result.push_str(&input[pos..]);
54            break;
55        }
56    }
57
58    result
59}
60
/// Find the byte offset of the `)` that closes the `(` at byte offset `start`.
///
/// Parentheses inside single-quoted string literals (`'...'`) and double-quoted
/// identifiers (`"..."`) are ignored. A doubled quote (`''` / `""`) inside either
/// form closes and immediately reopens the quote, so it needs no special case.
/// Quoted names such as `public."odd)seq"` therefore do not end the block early.
///
/// Returns `None` if `start` is out of range, does not point at `(`, or the
/// parenthesis is never closed.
fn find_matching_paren(input: &str, start: usize) -> Option<usize> {
    let bytes = input.as_bytes();
    if bytes.get(start) != Some(&b'(') {
        return None;
    }

    let mut depth = 0usize;
    // The quote byte we are currently inside, if any.
    let mut quote: Option<u8> = None;
    for (i, &b) in bytes.iter().enumerate().skip(start) {
        match quote {
            Some(q) => {
                if b == q {
                    quote = None;
                }
            }
            None => match b {
                b'\'' | b'"' => quote = Some(b),
                b'(' => depth += 1,
                b')' => {
                    // depth >= 1 here: the first byte visited is the opening '('.
                    depth -= 1;
                    if depth == 0 {
                        return Some(i);
                    }
                }
                _ => {}
            },
        }
    }
    None
}
84
85/// Parse PostgreSQL DDL text into an IR SchemaModel.
86pub fn parse(input: &str) -> (SchemaModel, Vec<Warning>) {
87    let dialect = PostgreSqlDialect {};
88    let mut model = SchemaModel::default();
89    let mut warnings = Vec::new();
90
91    let cleaned = strip_identity_options(input);
92    let statements = match Parser::parse_sql(&dialect, &cleaned) {
93        Ok(stmts) => stmts,
94        Err(e) => {
95            warnings.push(Warning::new(
96                warning::PARSE_SKIPPED,
97                Severity::Error,
98                format!("Failed to parse DDL: {e}"),
99            ));
100            return (model, warnings);
101        }
102    };
103
104    for stmt in statements {
105        match stmt {
106            Statement::CreateTable(ct) => {
107                if let Some(table) = parse_create_table(&ct, &mut warnings) {
108                    model.tables.push(table);
109                }
110            }
111            Statement::CreateIndex(ci) => {
112                if let Some(idx) = parse_create_index(&ci, &mut warnings) {
113                    model.indexes.push(idx);
114                }
115            }
116            Statement::CreateSequence { name, .. } => {
117                model.sequences.push(Sequence {
118                    name: convert_object_name(&name),
119                    owned_by: None,
120                });
121            }
122            Statement::AlterTable(alter_table) => {
123                let table_name = convert_object_name(&alter_table.name);
124                for op in &alter_table.operations {
125                    match parse_alter_table_op(&table_name, op, &mut warnings) {
126                        AlterResult::Constraint(c) => model.alter_constraints.push(c),
127                        AlterResult::Identity(id) => model.identity_columns.push(id),
128                        AlterResult::None => {}
129                    }
130                }
131            }
132            Statement::CreateType {
133                name,
134                representation: Some(UserDefinedTypeRepresentation::Enum { labels }),
135                ..
136            } => {
137                let values: Vec<String> = labels.into_iter().map(|v| v.to_string()).collect();
138                model.enums.push(EnumDef {
139                    name: convert_object_name(&name),
140                    values,
141                });
142            }
143            // Skip non-DDL statements silently
144            _ => {}
145        }
146    }
147
148    (model, warnings)
149}
150
151fn parse_create_table(ct: &ast::CreateTable, warnings: &mut [Warning]) -> Option<Table> {
152    let name = convert_object_name(&ct.name);
153    let mut columns = Vec::new();
154    let mut constraints = Vec::new();
155
156    for element in &ct.columns {
157        columns.push(parse_column(element));
158    }
159
160    for constraint in &ct.constraints {
161        if let Some(tc) = parse_table_constraint(constraint, warnings) {
162            constraints.push(tc);
163        }
164    }
165
166    Some(Table {
167        name,
168        columns,
169        constraints,
170    })
171}
172
173fn parse_column(col_def: &ColumnDef) -> Column {
174    let name = Ident::new(&col_def.name.value);
175    let pg_type = convert_data_type(&col_def.data_type);
176    let mut not_null = false;
177    let mut default = None;
178    let mut is_primary_key = false;
179    let mut is_unique = false;
180    let mut references = None;
181    let mut check = None;
182
183    for opt in &col_def.options {
184        match &opt.option {
185            ColumnOption::NotNull => not_null = true,
186            ColumnOption::Null => not_null = false,
187            ColumnOption::Default(expr) => {
188                default = Some(convert_sql_expr(expr));
189            }
190            ColumnOption::PrimaryKey(_) => {
191                is_primary_key = true;
192            }
193            ColumnOption::Unique(_) => {
194                is_unique = true;
195            }
196            ColumnOption::ForeignKey(fk) => {
197                let ref_col = fk.referred_columns.first().map(|c| Ident::new(&c.value));
198                references = Some(ForeignKeyRef {
199                    table: convert_object_name(&fk.foreign_table),
200                    column: ref_col,
201                    on_delete: fk.on_delete.as_ref().and_then(convert_referential_action),
202                    on_update: fk.on_update.as_ref().and_then(convert_referential_action),
203                });
204            }
205            ColumnOption::Check(ck) => {
206                check = Some(convert_sql_expr(&ck.expr));
207            }
208            _ => {}
209        }
210    }
211
212    Column {
213        name,
214        pg_type,
215        sqlite_type: None,
216        not_null,
217        default,
218        is_primary_key,
219        is_unique,
220        autoincrement: false,
221        references,
222        check,
223    }
224}
225
226fn parse_table_constraint(
227    constraint: &SqlConstraint,
228    _warnings: &mut [Warning],
229) -> Option<TableConstraint> {
230    match constraint {
231        SqlConstraint::PrimaryKey(pk) => {
232            let cols: Vec<Ident> = pk
233                .columns
234                .iter()
235                .map(|c| Ident::new(&c.column.expr.to_string()))
236                .collect();
237            Some(TableConstraint::PrimaryKey {
238                name: pk.name.as_ref().map(|n| Ident::new(&n.value)),
239                columns: cols,
240            })
241        }
242        SqlConstraint::Unique(uq) => {
243            let cols: Vec<Ident> = uq
244                .columns
245                .iter()
246                .map(|c| Ident::new(&c.column.expr.to_string()))
247                .collect();
248            Some(TableConstraint::Unique {
249                name: uq.name.as_ref().map(|n| Ident::new(&n.value)),
250                columns: cols,
251            })
252        }
253        SqlConstraint::ForeignKey(fk) => Some(TableConstraint::ForeignKey {
254            name: fk.name.as_ref().map(|n| Ident::new(&n.value)),
255            columns: fk.columns.iter().map(|c| Ident::new(&c.value)).collect(),
256            ref_table: convert_object_name(&fk.foreign_table),
257            ref_columns: fk
258                .referred_columns
259                .iter()
260                .map(|c| Ident::new(&c.value))
261                .collect(),
262            on_delete: fk.on_delete.as_ref().and_then(convert_referential_action),
263            on_update: fk.on_update.as_ref().and_then(convert_referential_action),
264            deferrable: false,
265        }),
266        SqlConstraint::Check(ck) => Some(TableConstraint::Check {
267            name: ck.name.as_ref().map(|n| Ident::new(&n.value)),
268            expr: convert_sql_expr(&ck.expr),
269        }),
270        _ => None,
271    }
272}
273
274fn parse_create_index(ci: &CreateIndex, _warnings: &mut [Warning]) -> Option<Index> {
275    let index_name = ci.name.as_ref()?;
276    let name = Ident::new(&index_name.to_string());
277    let table = convert_object_name(&ci.table_name);
278
279    let mut columns = Vec::new();
280    for col in &ci.columns {
281        let col_name = col.column.expr.to_string();
282        // Check if this looks like a function call / expression
283        if col_name.contains('(') {
284            columns.push(IndexColumn::Expression(Expr::Raw(col_name)));
285        } else {
286            columns.push(IndexColumn::Column(Ident::new(&col_name)));
287        }
288    }
289
290    let method = ci.using.as_ref().and_then(|m| match m {
291        ast::IndexType::BTree => Some(IndexMethod::Btree),
292        ast::IndexType::Hash => Some(IndexMethod::Hash),
293        ast::IndexType::GIN => Some(IndexMethod::Gin),
294        ast::IndexType::GiST => Some(IndexMethod::Gist),
295        ast::IndexType::SPGiST => Some(IndexMethod::SpGist),
296        ast::IndexType::BRIN => Some(IndexMethod::Brin),
297        _ => None,
298    });
299
300    let where_clause = ci.predicate.as_ref().map(convert_sql_expr);
301
302    Some(Index {
303        name,
304        table,
305        columns,
306        unique: ci.unique,
307        method,
308        where_clause,
309    })
310}
311
312enum AlterResult {
313    Constraint(AlterConstraint),
314    Identity(AlterIdentity),
315    None,
316}
317
318fn parse_alter_table_op(
319    table: &QualifiedName,
320    op: &AlterTableOperation,
321    warnings: &mut [Warning],
322) -> AlterResult {
323    match op {
324        AlterTableOperation::AddConstraint { constraint, .. } => {
325            match parse_table_constraint(constraint, warnings) {
326                Some(c) => AlterResult::Constraint(AlterConstraint {
327                    table: table.clone(),
328                    constraint: c,
329                }),
330                None => AlterResult::None,
331            }
332        }
333        AlterTableOperation::AlterColumn {
334            column_name,
335            op: AlterColumnOperation::AddGenerated { .. },
336            ..
337        } => AlterResult::Identity(AlterIdentity {
338            table: table.clone(),
339            column: Ident::new(&column_name.value),
340        }),
341        _ => AlterResult::None,
342    }
343}
344
345/// Convert sqlparser ObjectName to our QualifiedName.
346fn convert_object_name(name: &ObjectName) -> QualifiedName {
347    let parts: Vec<&str> = name
348        .0
349        .iter()
350        .filter_map(|part| match part {
351            ObjectNamePart::Identifier(ident) => Some(ident.value.as_str()),
352            _ => None,
353        })
354        .collect();
355    match parts.len() {
356        1 => QualifiedName::new(Ident::new(parts[0])),
357        2 => QualifiedName::with_schema(Ident::new(parts[0]), Ident::new(parts[1])),
358        _ => {
359            // Take the last two parts as schema.table
360            let len = parts.len();
361            QualifiedName::with_schema(Ident::new(parts[len - 2]), Ident::new(parts[len - 1]))
362        }
363    }
364}
365
366/// Convert sqlparser DataType to our PgType.
367fn convert_data_type(dt: &DataType) -> PgType {
368    match dt {
369        DataType::SmallInt(_) | DataType::Int2(_) => PgType::SmallInt,
370        DataType::Integer(_) | DataType::Int(_) | DataType::Int4(_) => PgType::Integer,
371        DataType::BigInt(_) | DataType::Int8(_) => PgType::BigInt,
372        DataType::Real | DataType::Float4 => PgType::Real,
373        DataType::Double(_) | DataType::DoublePrecision | DataType::Float8 => {
374            PgType::DoublePrecision
375        }
376        DataType::Numeric(info) | DataType::Decimal(info) => {
377            let (precision, scale) = extract_numeric_info(info);
378            PgType::Numeric { precision, scale }
379        }
380        DataType::Boolean => PgType::Boolean,
381        DataType::Text => PgType::Text,
382        DataType::Varchar(len) | DataType::CharacterVarying(len) => PgType::Varchar {
383            length: extract_char_length(len),
384        },
385        DataType::Char(len) | DataType::Character(len) => PgType::Char {
386            length: extract_char_length(len),
387        },
388        DataType::Date => PgType::Date,
389        DataType::Time(_, tz) => PgType::Time {
390            with_tz: matches!(tz, ast::TimezoneInfo::WithTimeZone),
391        },
392        DataType::Timestamp(_, tz) => PgType::Timestamp {
393            with_tz: matches!(tz, ast::TimezoneInfo::WithTimeZone),
394        },
395        DataType::Interval { .. } => PgType::Interval,
396        DataType::Bytea => PgType::Bytea,
397        DataType::Uuid => PgType::Uuid,
398        DataType::JSON => PgType::Json,
399        DataType::JSONB => PgType::Jsonb,
400        DataType::Blob(_) => PgType::Bytea,
401        DataType::Array(
402            ArrayElemTypeDef::SquareBracket(inner, _) | ArrayElemTypeDef::AngleBracket(inner),
403        ) => PgType::Array {
404            element: Box::new(convert_data_type(inner)),
405        },
406        DataType::Array(_) => PgType::Other {
407            name: dt.to_string(),
408        },
409        DataType::Custom(name, _) => {
410            // Use the last part of the name to handle schema-qualified types (e.g., pg_catalog.serial)
411            let type_name = name
412                .0
413                .iter()
414                .filter_map(|part| match part {
415                    ObjectNamePart::Identifier(ident) => Some(ident.value.to_lowercase()),
416                    _ => None,
417                })
418                .next_back()
419                .unwrap_or_default();
420            match type_name.as_str() {
421                "serial" => PgType::Serial,
422                "bigserial" => PgType::BigSerial,
423                "smallserial" => PgType::SmallSerial,
424                "inet" => PgType::Inet,
425                "cidr" => PgType::Cidr,
426                "macaddr" | "macaddr8" => PgType::MacAddr,
427                "money" => PgType::Money,
428                "xml" => PgType::Xml,
429                "point" => PgType::Point,
430                "line" => PgType::Line,
431                "lseg" => PgType::Lseg,
432                "box" => PgType::Box,
433                "path" => PgType::Path,
434                "polygon" => PgType::Polygon,
435                "circle" => PgType::Circle,
436                "int4range" => PgType::Int4Range,
437                "int8range" => PgType::Int8Range,
438                "numrange" => PgType::NumRange,
439                "tsrange" => PgType::TsRange,
440                "tstzrange" => PgType::TsTzRange,
441                "daterange" => PgType::DateRange,
442                _ => PgType::Other { name: type_name },
443            }
444        }
445        _ => PgType::Other {
446            name: dt.to_string(),
447        },
448    }
449}
450
451/// Convert sqlparser Expr to our Expr.
452fn convert_sql_expr(expr: &SqlExpr) -> Expr {
453    match expr {
454        SqlExpr::Value(val) => convert_value_with_span(val),
455        SqlExpr::Identifier(ident) => Expr::ColumnRef(ident.value.clone()),
456        SqlExpr::CompoundIdentifier(idents) => {
457            let name: Vec<&str> = idents.iter().map(|i| i.value.as_str()).collect();
458            Expr::ColumnRef(name.join("."))
459        }
460        SqlExpr::Function(func) => {
461            let func_name = func.name.to_string().to_lowercase();
462            let args: Vec<Expr> = match &func.args {
463                ast::FunctionArguments::List(arg_list) => arg_list
464                    .args
465                    .iter()
466                    .filter_map(|arg| match arg {
467                        ast::FunctionArg::Unnamed(ast::FunctionArgExpr::Expr(e)) => {
468                            Some(convert_sql_expr(e))
469                        }
470                        _ => None,
471                    })
472                    .collect(),
473                _ => Vec::new(),
474            };
475
476            // Detect nextval('sequence_name')
477            if func_name == "nextval"
478                && let Some(Expr::StringLiteral(seq)) = args.first()
479            {
480                return Expr::NextVal(seq.clone());
481            }
482
483            Expr::FunctionCall {
484                name: func_name,
485                args,
486            }
487        }
488        SqlExpr::Cast {
489            expr, data_type, ..
490        } => Expr::Cast {
491            expr: Box::new(convert_sql_expr(expr)),
492            type_name: data_type.to_string(),
493        },
494        SqlExpr::BinaryOp { left, op, right } => Expr::BinaryOp {
495            left: Box::new(convert_sql_expr(left)),
496            op: op.to_string(),
497            right: Box::new(convert_sql_expr(right)),
498        },
499        SqlExpr::UnaryOp { op, expr } => Expr::UnaryOp {
500            op: op.to_string(),
501            expr: Box::new(convert_sql_expr(expr)),
502        },
503        SqlExpr::IsNull(expr) => Expr::IsNull {
504            expr: Box::new(convert_sql_expr(expr)),
505            negated: false,
506        },
507        SqlExpr::IsNotNull(expr) => Expr::IsNull {
508            expr: Box::new(convert_sql_expr(expr)),
509            negated: true,
510        },
511        SqlExpr::InList {
512            expr,
513            list,
514            negated,
515        } => Expr::InList {
516            expr: Box::new(convert_sql_expr(expr)),
517            list: list.iter().map(convert_sql_expr).collect(),
518            negated: *negated,
519        },
520        SqlExpr::Between {
521            expr,
522            low,
523            high,
524            negated,
525        } => Expr::Between {
526            expr: Box::new(convert_sql_expr(expr)),
527            low: Box::new(convert_sql_expr(low)),
528            high: Box::new(convert_sql_expr(high)),
529            negated: *negated,
530        },
531        SqlExpr::Nested(inner) => Expr::Nested(Box::new(convert_sql_expr(inner))),
532        // col = ANY(ARRAY['a', 'b']) → col IN ('a', 'b')
533        // Only convert when the right-hand side is an ARRAY literal.
534        // Non-array forms (e.g., subqueries) fall through to Raw to avoid
535        // producing a semantically incorrect single-element InList.
536        SqlExpr::AnyOp {
537            left,
538            compare_op: BinaryOperator::Eq,
539            right,
540            ..
541        } => match extract_array_elements(right) {
542            Some(list) => {
543                let left_expr = convert_sql_expr(left);
544                Expr::InList {
545                    expr: Box::new(left_expr),
546                    list,
547                    negated: false,
548                }
549            }
550            None => Expr::Raw(expr.to_string()),
551        },
552        // Note: `expr != ANY(ARRAY[...])` is NOT equivalent to `NOT IN (...)`.
553        // `!= ANY` is true if expr differs from *at least one* element,
554        // whereas `NOT IN` requires expr to differ from *all* elements.
555        // The correct equivalent of `NOT IN` is `!= ALL(...)`, not `!= ANY(...)`.
556        // We let `!= ANY` fall through to the Raw fallback below.
557        //
558        // Fallback: render back to SQL string
559        _ => Expr::Raw(expr.to_string()),
560    }
561}
562
563/// Extract elements from an ARRAY literal expression.
564/// Returns `None` for non-array expressions (e.g., subqueries) so callers
565/// can fall back to `Expr::Raw` instead of producing incorrect results.
566fn extract_array_elements(expr: &SqlExpr) -> Option<Vec<Expr>> {
567    match expr {
568        SqlExpr::Array(Array { elem, .. }) => Some(elem.iter().map(convert_sql_expr).collect()),
569        _ => None,
570    }
571}
572
573fn convert_value_with_span(val: &ValueWithSpan) -> Expr {
574    convert_value(&val.value)
575}
576
577fn convert_value(val: &ast::Value) -> Expr {
578    match val {
579        ast::Value::Number(n, _) => {
580            if let Ok(i) = n.parse::<i64>() {
581                Expr::IntegerLiteral(i)
582            } else if let Ok(f) = n.parse::<f64>() {
583                Expr::FloatLiteral(f)
584            } else {
585                Expr::Raw(n.clone())
586            }
587        }
588        ast::Value::SingleQuotedString(s) => Expr::StringLiteral(s.clone()),
589        ast::Value::Boolean(b) => Expr::BooleanLiteral(*b),
590        ast::Value::Null => Expr::Null,
591        _ => Expr::Raw(val.to_string()),
592    }
593}
594
595fn convert_referential_action(action: &ReferentialAction) -> Option<FkAction> {
596    match action {
597        ReferentialAction::Cascade => Some(FkAction::Cascade),
598        ReferentialAction::SetNull => Some(FkAction::SetNull),
599        ReferentialAction::SetDefault => Some(FkAction::SetDefault),
600        ReferentialAction::Restrict => Some(FkAction::Restrict),
601        ReferentialAction::NoAction => Some(FkAction::NoAction),
602    }
603}
604
605fn extract_numeric_info(info: &ast::ExactNumberInfo) -> (Option<u32>, Option<u32>) {
606    match info {
607        ast::ExactNumberInfo::PrecisionAndScale(p, s) => (Some(*p as u32), Some(*s as u32)),
608        ast::ExactNumberInfo::Precision(p) => (Some(*p as u32), None),
609        ast::ExactNumberInfo::None => (None, None),
610    }
611}
612
613fn extract_char_length(len: &Option<ast::CharacterLength>) -> Option<u32> {
614    len.as_ref().map(|cl| match cl {
615        ast::CharacterLength::IntegerLength { length, .. } => *length as u32,
616        ast::CharacterLength::Max => u32::MAX,
617    })
618}
619
#[cfg(test)]
mod tests {
    use super::*;

    /// Parse `sql` and require that no warnings were produced.
    fn parse_clean(sql: &str) -> SchemaModel {
        let (model, warnings) = parse(sql);
        assert!(warnings.is_empty(), "warnings: {warnings:?}");
        model
    }

    /// A pg_dump-style identity clause with a full sequence-options block.
    fn identity_sql(table: &str, sequence: &str) -> String {
        format!(
            "ALTER TABLE {table} ALTER COLUMN id ADD GENERATED BY DEFAULT AS IDENTITY (\n\
             SEQUENCE NAME {sequence}\n\
             START WITH 1\n\
             INCREMENT BY 1\n\
             NO MINVALUE\n\
             NO MAXVALUE\n\
             CACHE 1\n\
             );"
        )
    }

    #[test]
    fn test_parse_simple_table() {
        let model =
            parse_clean("CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT NOT NULL);");
        assert_eq!(model.tables.len(), 1);

        let users = &model.tables[0];
        assert_eq!(users.name.name.normalized, "users");
        assert_eq!(users.columns.len(), 2);
        let (id, name) = (&users.columns[0], &users.columns[1]);
        assert!(id.is_primary_key);
        assert!(name.not_null);
    }

    #[test]
    fn test_parse_schema_qualified_table() {
        let (model, _) = parse("CREATE TABLE public.users (id INTEGER);");
        let qualified = &model.tables[0].name;
        let schema = qualified.schema.as_ref().expect("schema should be present");
        assert_eq!(schema.normalized, "public");
        assert_eq!(qualified.name.normalized, "users");
    }

    #[test]
    fn test_parse_create_index() {
        let (model, _) = parse("CREATE INDEX idx_name ON users (name);");
        assert_eq!(model.indexes.len(), 1);
        let index = &model.indexes[0];
        assert_eq!(index.name.normalized, "idx_name");
        assert!(!index.unique);
    }

    #[test]
    fn test_parse_unique_index() {
        let (model, _) = parse("CREATE UNIQUE INDEX idx_email ON users (email);");
        assert!(model.indexes[0].unique);
    }

    #[test]
    fn test_parse_alter_table_add_constraint() {
        let sql = "CREATE TABLE orders (id INTEGER, user_id INTEGER);\n\
                   ALTER TABLE orders ADD CONSTRAINT fk_user \
                   FOREIGN KEY (user_id) REFERENCES users (id);";
        let (model, _) = parse(sql);
        assert_eq!(model.tables.len(), 1);
        assert_eq!(model.alter_constraints.len(), 1);
    }

    #[test]
    fn test_parse_create_type_enum() {
        let (model, _) = parse("CREATE TYPE mood AS ENUM ('sad', 'ok', 'happy');");
        assert_eq!(model.enums.len(), 1);
        assert_eq!(model.enums[0].values.len(), 3);
    }

    #[test]
    fn test_parse_column_default() {
        let (model, _) = parse("CREATE TABLE t (created_at TIMESTAMP DEFAULT now());");
        assert!(model.tables[0].columns[0].default.is_some());
    }

    #[test]
    fn test_non_ddl_ignored() {
        let model = parse_clean("SELECT 1; CREATE TABLE t (id INTEGER);");
        assert_eq!(model.tables.len(), 1);
    }

    #[test]
    fn test_parse_foreign_key_with_actions() {
        let sql = "CREATE TABLE orders (\n\
                   id INTEGER PRIMARY KEY,\n\
                   user_id INTEGER REFERENCES users(id) ON DELETE CASCADE ON UPDATE SET NULL\n\
                   );";
        let (model, _) = parse(sql);
        let fk = model.tables[0].columns[1]
            .references
            .as_ref()
            .expect("user_id should reference users");
        assert_eq!(fk.on_delete, Some(FkAction::Cascade));
        assert_eq!(fk.on_update, Some(FkAction::SetNull));
    }

    #[test]
    fn test_parse_check_constraint() {
        let (model, _) = parse("CREATE TABLE t (age INTEGER CHECK (age >= 0));");
        assert!(model.tables[0].columns[0].check.is_some());
    }

    #[test]
    fn test_parse_any_array_to_in_list() {
        let sql = "CREATE TABLE t (\n\
                   status TEXT NOT NULL,\n\
                   CONSTRAINT status_check CHECK \
                   ((status = ANY (ARRAY['active'::text, 'inactive'::text])))\n\
                   );";
        let (model, _) = parse(sql);
        let constraints = &model.tables[0].constraints;
        assert_eq!(constraints.len(), 1);

        let TableConstraint::Check { name, expr } = &constraints[0] else {
            panic!("Expected Check constraint");
        };
        assert_eq!(name.as_ref().unwrap().normalized, "status_check");

        // The extra parentheses around the CHECK body survive as `Expr::Nested`.
        let Expr::Nested(inner) = expr else {
            panic!("Expected Nested, got: {expr:?}");
        };
        let Expr::InList {
            expr: col,
            list,
            negated,
        } = inner.as_ref()
        else {
            panic!("Expected InList, got: {inner:?}");
        };
        assert!(!negated);
        assert!(matches!(col.as_ref(), Expr::ColumnRef(c) if c == "status"));
        assert_eq!(list.len(), 2);
        // Casts are preserved at parse level (they are stripped during transform).
        let first_is_cast_of_active = matches!(
            &list[0],
            Expr::Cast { expr, .. }
                if matches!(expr.as_ref(), Expr::StringLiteral(s) if s == "active")
        );
        assert!(first_is_cast_of_active);
    }

    #[test]
    fn test_parse_identity_native() {
        let model = parse_clean(&identity_sql("address", "address_id_seq"));
        assert_eq!(model.identity_columns.len(), 1);
        let identity = &model.identity_columns[0];
        assert_eq!(identity.table.name.normalized, "address");
        assert_eq!(identity.column.normalized, "id");
    }

    #[test]
    fn test_parse_identity_with_schema() {
        let model = parse_clean(&identity_sql("public.seed", "public.seed_id_seq"));
        assert_eq!(model.identity_columns.len(), 1);
        let identity = &model.identity_columns[0];
        assert_eq!(identity.table.schema.as_ref().unwrap().normalized, "public");
        assert_eq!(identity.table.name.normalized, "seed");
        assert_eq!(identity.column.normalized, "id");
    }
}