Skip to main content

pg2sqlite_core/pg/
parser.rs

1/// PostgreSQL DDL parser using sqlparser-rs.
2///
3/// Converts sqlparser AST into our internal representation (IR).
4use sqlparser::ast::{
5    self, AlterTableOperation, Array, ArrayElemTypeDef, BinaryOperator, ColumnDef, ColumnOption,
6    CreateIndex, DataType, Expr as SqlExpr, ObjectName, ReferentialAction, Statement,
7    TableConstraint as SqlConstraint, UserDefinedTypeRepresentation,
8};
9use sqlparser::dialect::PostgreSqlDialect;
10use sqlparser::parser::Parser;
11
12use crate::diagnostics::warning::{self, Severity, Warning};
13use crate::ir::{
14    AlterConstraint, Column, EnumDef, Expr, FkAction, ForeignKeyRef, Ident, Index, IndexColumn,
15    IndexMethod, PgType, QualifiedName, SchemaModel, Sequence, Table, TableConstraint,
16};
17
18/// Parse PostgreSQL DDL text into an IR SchemaModel.
19pub fn parse(input: &str) -> (SchemaModel, Vec<Warning>) {
20    let dialect = PostgreSqlDialect {};
21    let mut model = SchemaModel::default();
22    let mut warnings = Vec::new();
23
24    let statements = match Parser::parse_sql(&dialect, input) {
25        Ok(stmts) => stmts,
26        Err(e) => {
27            warnings.push(Warning::new(
28                warning::PARSE_SKIPPED,
29                Severity::Error,
30                format!("Failed to parse DDL: {e}"),
31            ));
32            return (model, warnings);
33        }
34    };
35
36    for stmt in statements {
37        match stmt {
38            Statement::CreateTable(ct) => {
39                if let Some(table) = parse_create_table(&ct, &mut warnings) {
40                    model.tables.push(table);
41                }
42            }
43            Statement::CreateIndex(ci) => {
44                if let Some(idx) = parse_create_index(&ci, &mut warnings) {
45                    model.indexes.push(idx);
46                }
47            }
48            Statement::CreateSequence { name, .. } => {
49                model.sequences.push(Sequence {
50                    name: convert_object_name(&name),
51                    owned_by: None,
52                });
53            }
54            Statement::AlterTable {
55                name, operations, ..
56            } => {
57                let table_name = convert_object_name(&name);
58                for op in operations {
59                    if let Some(constraint) = parse_alter_table_op(&table_name, &op, &mut warnings)
60                    {
61                        model.alter_constraints.push(constraint);
62                    }
63                }
64            }
65            Statement::CreateType {
66                name,
67                representation: UserDefinedTypeRepresentation::Enum { labels },
68                ..
69            } => {
70                let values: Vec<String> = labels.into_iter().map(|v| v.to_string()).collect();
71                model.enums.push(EnumDef {
72                    name: convert_object_name(&name),
73                    values,
74                });
75            }
76            // Skip non-DDL statements silently
77            _ => {}
78        }
79    }
80
81    (model, warnings)
82}
83
84fn parse_create_table(ct: &ast::CreateTable, warnings: &mut [Warning]) -> Option<Table> {
85    let name = convert_object_name(&ct.name);
86    let mut columns = Vec::new();
87    let mut constraints = Vec::new();
88
89    for element in &ct.columns {
90        columns.push(parse_column(element));
91    }
92
93    for constraint in &ct.constraints {
94        if let Some(tc) = parse_table_constraint(constraint, warnings) {
95            constraints.push(tc);
96        }
97    }
98
99    Some(Table {
100        name,
101        columns,
102        constraints,
103    })
104}
105
106fn parse_column(col_def: &ColumnDef) -> Column {
107    let name = Ident::new(&col_def.name.value);
108    let pg_type = convert_data_type(&col_def.data_type);
109    let mut not_null = false;
110    let mut default = None;
111    let mut is_primary_key = false;
112    let mut is_unique = false;
113    let mut references = None;
114    let mut check = None;
115
116    for opt in &col_def.options {
117        match &opt.option {
118            ColumnOption::NotNull => not_null = true,
119            ColumnOption::Null => not_null = false,
120            ColumnOption::Default(expr) => {
121                default = Some(convert_sql_expr(expr));
122            }
123            ColumnOption::Unique { is_primary, .. } => {
124                if *is_primary {
125                    is_primary_key = true;
126                } else {
127                    is_unique = true;
128                }
129            }
130            ColumnOption::ForeignKey {
131                foreign_table,
132                referred_columns,
133                on_delete,
134                on_update,
135                ..
136            } => {
137                let ref_col = referred_columns.first().map(|c| Ident::new(&c.value));
138                references = Some(ForeignKeyRef {
139                    table: convert_object_name(foreign_table),
140                    column: ref_col,
141                    on_delete: on_delete.as_ref().and_then(convert_referential_action),
142                    on_update: on_update.as_ref().and_then(convert_referential_action),
143                });
144            }
145            ColumnOption::Check(expr) => {
146                check = Some(convert_sql_expr(expr));
147            }
148            _ => {}
149        }
150    }
151
152    Column {
153        name,
154        pg_type,
155        sqlite_type: None,
156        not_null,
157        default,
158        is_primary_key,
159        is_unique,
160        references,
161        check,
162    }
163}
164
165fn parse_table_constraint(
166    constraint: &SqlConstraint,
167    _warnings: &mut [Warning],
168) -> Option<TableConstraint> {
169    match constraint {
170        SqlConstraint::PrimaryKey { columns, name, .. } => {
171            let cols: Vec<Ident> = columns.iter().map(|c| Ident::new(&c.value)).collect();
172            Some(TableConstraint::PrimaryKey {
173                name: name.as_ref().map(|n| Ident::new(&n.value)),
174                columns: cols,
175            })
176        }
177        SqlConstraint::Unique { columns, name, .. } => {
178            let cols: Vec<Ident> = columns.iter().map(|c| Ident::new(&c.value)).collect();
179            Some(TableConstraint::Unique {
180                name: name.as_ref().map(|n| Ident::new(&n.value)),
181                columns: cols,
182            })
183        }
184        // Note:
185        // The parser hardcodes deferrable: false for foreign key constraints.
186        // PostgreSQL supports DEFERRABLE characteristics which should be extracted from the sqlparser AST
187        // to ensure accurate transformation and warning emission later in the pipeline.
188        SqlConstraint::ForeignKey {
189            name,
190            columns,
191            foreign_table,
192            referred_columns,
193            on_delete,
194            on_update,
195            ..
196        } => Some(TableConstraint::ForeignKey {
197            name: name.as_ref().map(|n| Ident::new(&n.value)),
198            columns: columns.iter().map(|c| Ident::new(&c.value)).collect(),
199            ref_table: convert_object_name(foreign_table),
200            ref_columns: referred_columns
201                .iter()
202                .map(|c| Ident::new(&c.value))
203                .collect(),
204            on_delete: on_delete.as_ref().and_then(convert_referential_action),
205            on_update: on_update.as_ref().and_then(convert_referential_action),
206            deferrable: false,
207        }),
208        SqlConstraint::Check { name, expr } => Some(TableConstraint::Check {
209            name: name.as_ref().map(|n| Ident::new(&n.value)),
210            expr: convert_sql_expr(expr),
211        }),
212        _ => None,
213    }
214}
215
216fn parse_create_index(ci: &CreateIndex, _warnings: &mut [Warning]) -> Option<Index> {
217    let index_name = ci.name.as_ref()?;
218    let name = Ident::new(&index_name.to_string());
219    let table = convert_object_name(&ci.table_name);
220
221    let mut columns = Vec::new();
222    for col in &ci.columns {
223        let col_name = col.expr.to_string();
224        // Check if this looks like a function call / expression
225        if col_name.contains('(') {
226            columns.push(IndexColumn::Expression(Expr::Raw(col_name)));
227        } else {
228            columns.push(IndexColumn::Column(Ident::new(&col_name)));
229        }
230    }
231
232    let method = ci
233        .using
234        .as_ref()
235        .and_then(|m| match m.value.to_lowercase().as_str() {
236            "btree" => Some(IndexMethod::Btree),
237            "hash" => Some(IndexMethod::Hash),
238            "gin" => Some(IndexMethod::Gin),
239            "gist" => Some(IndexMethod::Gist),
240            "spgist" => Some(IndexMethod::SpGist),
241            "brin" => Some(IndexMethod::Brin),
242            _ => None,
243        });
244
245    let where_clause = ci.predicate.as_ref().map(convert_sql_expr);
246
247    Some(Index {
248        name,
249        table,
250        columns,
251        unique: ci.unique,
252        method,
253        where_clause,
254    })
255}
256
257fn parse_alter_table_op(
258    table: &QualifiedName,
259    op: &AlterTableOperation,
260    warnings: &mut [Warning],
261) -> Option<AlterConstraint> {
262    match op {
263        AlterTableOperation::AddConstraint(constraint) => {
264            parse_table_constraint(constraint, warnings).map(|c| AlterConstraint {
265                table: table.clone(),
266                constraint: c,
267            })
268        }
269        _ => None,
270    }
271}
272
273/// Convert sqlparser ObjectName to our QualifiedName.
274fn convert_object_name(name: &ObjectName) -> QualifiedName {
275    let parts: Vec<&str> = name.0.iter().map(|ident| ident.value.as_str()).collect();
276    match parts.len() {
277        1 => QualifiedName::new(Ident::new(parts[0])),
278        2 => QualifiedName::with_schema(Ident::new(parts[0]), Ident::new(parts[1])),
279        _ => {
280            // Take the last two parts as schema.table
281            let len = parts.len();
282            QualifiedName::with_schema(Ident::new(parts[len - 2]), Ident::new(parts[len - 1]))
283        }
284    }
285}
286
287/// Convert sqlparser DataType to our PgType.
288fn convert_data_type(dt: &DataType) -> PgType {
289    match dt {
290        DataType::SmallInt(_) | DataType::Int2(_) => PgType::SmallInt,
291        DataType::Integer(_) | DataType::Int(_) | DataType::Int4(_) => PgType::Integer,
292        DataType::BigInt(_) | DataType::Int8(_) => PgType::BigInt,
293        DataType::Real | DataType::Float4 => PgType::Real,
294        DataType::Double | DataType::DoublePrecision | DataType::Float8 => PgType::DoublePrecision,
295        DataType::Numeric(info) | DataType::Decimal(info) => {
296            let (precision, scale) = extract_numeric_info(info);
297            PgType::Numeric { precision, scale }
298        }
299        DataType::Boolean => PgType::Boolean,
300        DataType::Text => PgType::Text,
301        DataType::Varchar(len) | DataType::CharacterVarying(len) => PgType::Varchar {
302            length: extract_char_length(len),
303        },
304        DataType::Char(len) | DataType::Character(len) => PgType::Char {
305            length: extract_char_length(len),
306        },
307        DataType::Date => PgType::Date,
308        DataType::Time(_, tz) => PgType::Time {
309            with_tz: matches!(tz, ast::TimezoneInfo::WithTimeZone),
310        },
311        DataType::Timestamp(_, tz) => PgType::Timestamp {
312            with_tz: matches!(tz, ast::TimezoneInfo::WithTimeZone),
313        },
314        DataType::Interval => PgType::Interval,
315        DataType::Bytea => PgType::Bytea,
316        DataType::Uuid => PgType::Uuid,
317        DataType::JSON => PgType::Json,
318        DataType::JSONB => PgType::Jsonb,
319        DataType::Blob(_) => PgType::Bytea,
320        DataType::Array(
321            ArrayElemTypeDef::SquareBracket(inner, _) | ArrayElemTypeDef::AngleBracket(inner),
322        ) => PgType::Array {
323            element: Box::new(convert_data_type(inner)),
324        },
325        DataType::Array(_) => PgType::Other {
326            name: dt.to_string(),
327        },
328        DataType::Custom(name, _) => {
329            // Use the last part of the name to handle schema-qualified types (e.g., pg_catalog.serial)
330            let type_name = name
331                .0
332                .last()
333                .map(|id| id.value.to_lowercase())
334                .unwrap_or_default();
335            match type_name.as_str() {
336                "serial" => PgType::Serial,
337                "bigserial" => PgType::BigSerial,
338                "smallserial" => PgType::SmallSerial,
339                "inet" => PgType::Inet,
340                "cidr" => PgType::Cidr,
341                "macaddr" | "macaddr8" => PgType::MacAddr,
342                "money" => PgType::Money,
343                "xml" => PgType::Xml,
344                "point" => PgType::Point,
345                "line" => PgType::Line,
346                "lseg" => PgType::Lseg,
347                "box" => PgType::Box,
348                "path" => PgType::Path,
349                "polygon" => PgType::Polygon,
350                "circle" => PgType::Circle,
351                "int4range" => PgType::Int4Range,
352                "int8range" => PgType::Int8Range,
353                "numrange" => PgType::NumRange,
354                "tsrange" => PgType::TsRange,
355                "tstzrange" => PgType::TsTzRange,
356                "daterange" => PgType::DateRange,
357                _ => PgType::Other { name: type_name },
358            }
359        }
360        _ => PgType::Other {
361            name: dt.to_string(),
362        },
363    }
364}
365
366/// Convert sqlparser Expr to our Expr.
367fn convert_sql_expr(expr: &SqlExpr) -> Expr {
368    match expr {
369        SqlExpr::Value(val) => convert_value(val),
370        SqlExpr::Identifier(ident) => Expr::ColumnRef(ident.value.clone()),
371        SqlExpr::CompoundIdentifier(idents) => {
372            let name: Vec<&str> = idents.iter().map(|i| i.value.as_str()).collect();
373            Expr::ColumnRef(name.join("."))
374        }
375        SqlExpr::Function(func) => {
376            let func_name = func.name.to_string().to_lowercase();
377            let args: Vec<Expr> = match &func.args {
378                ast::FunctionArguments::List(arg_list) => arg_list
379                    .args
380                    .iter()
381                    .filter_map(|arg| match arg {
382                        ast::FunctionArg::Unnamed(ast::FunctionArgExpr::Expr(e)) => {
383                            Some(convert_sql_expr(e))
384                        }
385                        _ => None,
386                    })
387                    .collect(),
388                _ => Vec::new(),
389            };
390
391            // Detect nextval('sequence_name')
392            if func_name == "nextval"
393                && let Some(Expr::StringLiteral(seq)) = args.first()
394            {
395                return Expr::NextVal(seq.clone());
396            }
397
398            Expr::FunctionCall {
399                name: func_name,
400                args,
401            }
402        }
403        SqlExpr::Cast {
404            expr, data_type, ..
405        } => Expr::Cast {
406            expr: Box::new(convert_sql_expr(expr)),
407            type_name: data_type.to_string(),
408        },
409        SqlExpr::BinaryOp { left, op, right } => Expr::BinaryOp {
410            left: Box::new(convert_sql_expr(left)),
411            op: op.to_string(),
412            right: Box::new(convert_sql_expr(right)),
413        },
414        SqlExpr::UnaryOp { op, expr } => Expr::UnaryOp {
415            op: op.to_string(),
416            expr: Box::new(convert_sql_expr(expr)),
417        },
418        SqlExpr::IsNull(expr) => Expr::IsNull {
419            expr: Box::new(convert_sql_expr(expr)),
420            negated: false,
421        },
422        SqlExpr::IsNotNull(expr) => Expr::IsNull {
423            expr: Box::new(convert_sql_expr(expr)),
424            negated: true,
425        },
426        SqlExpr::InList {
427            expr,
428            list,
429            negated,
430        } => Expr::InList {
431            expr: Box::new(convert_sql_expr(expr)),
432            list: list.iter().map(convert_sql_expr).collect(),
433            negated: *negated,
434        },
435        SqlExpr::Between {
436            expr,
437            low,
438            high,
439            negated,
440        } => Expr::Between {
441            expr: Box::new(convert_sql_expr(expr)),
442            low: Box::new(convert_sql_expr(low)),
443            high: Box::new(convert_sql_expr(high)),
444            negated: *negated,
445        },
446        SqlExpr::Nested(inner) => Expr::Nested(Box::new(convert_sql_expr(inner))),
447        // col = ANY(ARRAY['a', 'b']) → col IN ('a', 'b')
448        // Only convert when the right-hand side is an ARRAY literal.
449        // Non-array forms (e.g., subqueries) fall through to Raw to avoid
450        // producing a semantically incorrect single-element InList.
451        SqlExpr::AnyOp {
452            left,
453            compare_op: BinaryOperator::Eq,
454            right,
455            ..
456        } => match extract_array_elements(right) {
457            Some(list) => {
458                let left_expr = convert_sql_expr(left);
459                Expr::InList {
460                    expr: Box::new(left_expr),
461                    list,
462                    negated: false,
463                }
464            }
465            None => Expr::Raw(expr.to_string()),
466        },
467        // Note: `expr != ANY(ARRAY[...])` is NOT equivalent to `NOT IN (...)`.
468        // `!= ANY` is true if expr differs from *at least one* element,
469        // whereas `NOT IN` requires expr to differ from *all* elements.
470        // The correct equivalent of `NOT IN` is `!= ALL(...)`, not `!= ANY(...)`.
471        // We let `!= ANY` fall through to the Raw fallback below.
472        //
473        // Fallback: render back to SQL string
474        _ => Expr::Raw(expr.to_string()),
475    }
476}
477
478/// Extract elements from an ARRAY literal expression.
479/// Returns `None` for non-array expressions (e.g., subqueries) so callers
480/// can fall back to `Expr::Raw` instead of producing incorrect results.
481fn extract_array_elements(expr: &SqlExpr) -> Option<Vec<Expr>> {
482    match expr {
483        SqlExpr::Array(Array { elem, .. }) => Some(elem.iter().map(convert_sql_expr).collect()),
484        _ => None,
485    }
486}
487
488fn convert_value(val: &ast::Value) -> Expr {
489    match val {
490        ast::Value::Number(n, _) => {
491            if let Ok(i) = n.parse::<i64>() {
492                Expr::IntegerLiteral(i)
493            } else if let Ok(f) = n.parse::<f64>() {
494                Expr::FloatLiteral(f)
495            } else {
496                Expr::Raw(n.clone())
497            }
498        }
499        ast::Value::SingleQuotedString(s) => Expr::StringLiteral(s.clone()),
500        ast::Value::Boolean(b) => Expr::BooleanLiteral(*b),
501        ast::Value::Null => Expr::Null,
502        _ => Expr::Raw(val.to_string()),
503    }
504}
505
506fn convert_referential_action(action: &ReferentialAction) -> Option<FkAction> {
507    match action {
508        ReferentialAction::Cascade => Some(FkAction::Cascade),
509        ReferentialAction::SetNull => Some(FkAction::SetNull),
510        ReferentialAction::SetDefault => Some(FkAction::SetDefault),
511        ReferentialAction::Restrict => Some(FkAction::Restrict),
512        ReferentialAction::NoAction => Some(FkAction::NoAction),
513    }
514}
515
516fn extract_numeric_info(info: &ast::ExactNumberInfo) -> (Option<u32>, Option<u32>) {
517    match info {
518        ast::ExactNumberInfo::PrecisionAndScale(p, s) => (Some(*p as u32), Some(*s as u32)),
519        ast::ExactNumberInfo::Precision(p) => (Some(*p as u32), None),
520        ast::ExactNumberInfo::None => (None, None),
521    }
522}
523
524fn extract_char_length(len: &Option<ast::CharacterLength>) -> Option<u32> {
525    len.as_ref().map(|cl| match cl {
526        ast::CharacterLength::IntegerLength { length, .. } => *length as u32,
527        ast::CharacterLength::Max => u32::MAX,
528    })
529}
530
#[cfg(test)]
mod tests {
    use super::*;

    /// A minimal table yields one table, two columns and no warnings.
    #[test]
    fn test_parse_simple_table() {
        let (model, warnings) =
            parse("CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT NOT NULL);");
        assert!(warnings.is_empty());
        assert_eq!(model.tables.len(), 1);

        let users = &model.tables[0];
        assert_eq!(users.name.name.normalized, "users");
        assert_eq!(users.columns.len(), 2);
        assert!(users.columns[0].is_primary_key);
        assert!(users.columns[1].not_null);
    }

    /// `schema.table` is split into schema and table name.
    #[test]
    fn test_parse_schema_qualified_table() {
        let (model, _) = parse("CREATE TABLE public.users (id INTEGER);");
        let qualified = &model.tables[0].name;
        assert_eq!(qualified.schema.as_ref().unwrap().normalized, "public");
        assert_eq!(qualified.name.normalized, "users");
    }

    /// A plain index is recorded and is not unique.
    #[test]
    fn test_parse_create_index() {
        let (model, _) = parse("CREATE INDEX idx_name ON users (name);");
        assert_eq!(model.indexes.len(), 1);

        let index = &model.indexes[0];
        assert_eq!(index.name.normalized, "idx_name");
        assert!(!index.unique);
    }

    /// `CREATE UNIQUE INDEX` sets the unique flag.
    #[test]
    fn test_parse_unique_index() {
        let (model, _) = parse("CREATE UNIQUE INDEX idx_email ON users (email);");
        assert!(model.indexes[0].unique);
    }

    /// `ALTER TABLE ... ADD CONSTRAINT` is collected separately from tables.
    #[test]
    fn test_parse_alter_table_add_constraint() {
        let ddl = r#"
            CREATE TABLE orders (id INTEGER, user_id INTEGER);
            ALTER TABLE orders ADD CONSTRAINT fk_user FOREIGN KEY (user_id) REFERENCES users (id);
        "#;
        let (model, _) = parse(ddl);
        assert_eq!(model.tables.len(), 1);
        assert_eq!(model.alter_constraints.len(), 1);
    }

    /// An enum type is recorded with all of its labels.
    #[test]
    fn test_parse_create_type_enum() {
        let (model, _) = parse("CREATE TYPE mood AS ENUM ('sad', 'ok', 'happy');");
        assert_eq!(model.enums.len(), 1);
        assert_eq!(model.enums[0].values.len(), 3);
    }

    /// A column `DEFAULT` is captured.
    #[test]
    fn test_parse_column_default() {
        let (model, _) = parse("CREATE TABLE t (created_at TIMESTAMP DEFAULT now());");
        let created_at = &model.tables[0].columns[0];
        assert!(created_at.default.is_some());
    }

    /// Non-DDL statements are skipped without producing warnings.
    #[test]
    fn test_non_ddl_ignored() {
        let (model, warnings) = parse("SELECT 1; CREATE TABLE t (id INTEGER);");
        assert_eq!(model.tables.len(), 1);
        assert!(warnings.is_empty());
    }

    /// `ON DELETE` / `ON UPDATE` actions on an inline reference are kept.
    #[test]
    fn test_parse_foreign_key_with_actions() {
        let ddl = r#"
            CREATE TABLE orders (
                id INTEGER PRIMARY KEY,
                user_id INTEGER REFERENCES users(id) ON DELETE CASCADE ON UPDATE SET NULL
            );
        "#;
        let (model, _) = parse(ddl);
        let user_id = &model.tables[0].columns[1];
        let reference = user_id.references.as_ref().unwrap();
        assert_eq!(reference.on_delete, Some(FkAction::Cascade));
        assert_eq!(reference.on_update, Some(FkAction::SetNull));
    }

    /// A column-level `CHECK` is captured.
    #[test]
    fn test_parse_check_constraint() {
        let (model, _) = parse("CREATE TABLE t (age INTEGER CHECK (age >= 0));");
        assert!(model.tables[0].columns[0].check.is_some());
    }

    /// `col = ANY (ARRAY[...])` inside a CHECK becomes an `InList`.
    #[test]
    fn test_parse_any_array_to_in_list() {
        let ddl = r#"CREATE TABLE t (
            status TEXT NOT NULL,
            CONSTRAINT status_check CHECK ((status = ANY (ARRAY['active'::text, 'inactive'::text])))
        );"#;
        let (model, _) = parse(ddl);
        let table = &model.tables[0];
        assert_eq!(table.constraints.len(), 1);

        let TableConstraint::Check { name, expr } = &table.constraints[0] else {
            panic!("Expected Check constraint");
        };
        assert_eq!(name.as_ref().unwrap().normalized, "status_check");

        // Should be Nested(InList { ... })
        let Expr::Nested(inner) = expr else {
            panic!("Expected Nested, got: {expr:?}");
        };
        let Expr::InList {
            expr: col,
            list,
            negated,
        } = inner.as_ref()
        else {
            panic!("Expected InList, got: {inner:?}");
        };

        assert!(!negated);
        assert!(matches!(col.as_ref(), Expr::ColumnRef(name) if name == "status"));
        assert_eq!(list.len(), 2);
        // Casts should be preserved at parse level (stripped during transform)
        assert!(
            matches!(&list[0], Expr::Cast { expr, .. } if matches!(expr.as_ref(), Expr::StringLiteral(s) if s == "active"))
        );
    }
}