1use sqlparser::ast::{
5 self, AlterTableOperation, Array, ArrayElemTypeDef, BinaryOperator, ColumnDef, ColumnOption,
6 CreateIndex, DataType, Expr as SqlExpr, ObjectName, ReferentialAction, Statement,
7 TableConstraint as SqlConstraint, UserDefinedTypeRepresentation,
8};
9use sqlparser::dialect::PostgreSqlDialect;
10use sqlparser::parser::Parser;
11
12use crate::diagnostics::warning::{self, Severity, Warning};
13use crate::ir::{
14 AlterConstraint, Column, EnumDef, Expr, FkAction, ForeignKeyRef, Ident, Index, IndexColumn,
15 IndexMethod, PgType, QualifiedName, SchemaModel, Sequence, Table, TableConstraint,
16};
17
18pub fn parse(input: &str) -> (SchemaModel, Vec<Warning>) {
20 let dialect = PostgreSqlDialect {};
21 let mut model = SchemaModel::default();
22 let mut warnings = Vec::new();
23
24 let statements = match Parser::parse_sql(&dialect, input) {
25 Ok(stmts) => stmts,
26 Err(e) => {
27 warnings.push(Warning::new(
28 warning::PARSE_SKIPPED,
29 Severity::Error,
30 format!("Failed to parse DDL: {e}"),
31 ));
32 return (model, warnings);
33 }
34 };
35
36 for stmt in statements {
37 match stmt {
38 Statement::CreateTable(ct) => {
39 if let Some(table) = parse_create_table(&ct, &mut warnings) {
40 model.tables.push(table);
41 }
42 }
43 Statement::CreateIndex(ci) => {
44 if let Some(idx) = parse_create_index(&ci, &mut warnings) {
45 model.indexes.push(idx);
46 }
47 }
48 Statement::CreateSequence { name, .. } => {
49 model.sequences.push(Sequence {
50 name: convert_object_name(&name),
51 owned_by: None,
52 });
53 }
54 Statement::AlterTable {
55 name, operations, ..
56 } => {
57 let table_name = convert_object_name(&name);
58 for op in operations {
59 if let Some(constraint) = parse_alter_table_op(&table_name, &op, &mut warnings)
60 {
61 model.alter_constraints.push(constraint);
62 }
63 }
64 }
65 Statement::CreateType {
66 name,
67 representation: UserDefinedTypeRepresentation::Enum { labels },
68 ..
69 } => {
70 let values: Vec<String> = labels.into_iter().map(|v| v.to_string()).collect();
71 model.enums.push(EnumDef {
72 name: convert_object_name(&name),
73 values,
74 });
75 }
76 _ => {}
78 }
79 }
80
81 (model, warnings)
82}
83
84fn parse_create_table(ct: &ast::CreateTable, warnings: &mut [Warning]) -> Option<Table> {
85 let name = convert_object_name(&ct.name);
86 let mut columns = Vec::new();
87 let mut constraints = Vec::new();
88
89 for element in &ct.columns {
90 columns.push(parse_column(element));
91 }
92
93 for constraint in &ct.constraints {
94 if let Some(tc) = parse_table_constraint(constraint, warnings) {
95 constraints.push(tc);
96 }
97 }
98
99 Some(Table {
100 name,
101 columns,
102 constraints,
103 })
104}
105
106fn parse_column(col_def: &ColumnDef) -> Column {
107 let name = Ident::new(&col_def.name.value);
108 let pg_type = convert_data_type(&col_def.data_type);
109 let mut not_null = false;
110 let mut default = None;
111 let mut is_primary_key = false;
112 let mut is_unique = false;
113 let mut references = None;
114 let mut check = None;
115
116 for opt in &col_def.options {
117 match &opt.option {
118 ColumnOption::NotNull => not_null = true,
119 ColumnOption::Null => not_null = false,
120 ColumnOption::Default(expr) => {
121 default = Some(convert_sql_expr(expr));
122 }
123 ColumnOption::Unique { is_primary, .. } => {
124 if *is_primary {
125 is_primary_key = true;
126 } else {
127 is_unique = true;
128 }
129 }
130 ColumnOption::ForeignKey {
131 foreign_table,
132 referred_columns,
133 on_delete,
134 on_update,
135 ..
136 } => {
137 let ref_col = referred_columns.first().map(|c| Ident::new(&c.value));
138 references = Some(ForeignKeyRef {
139 table: convert_object_name(foreign_table),
140 column: ref_col,
141 on_delete: on_delete.as_ref().and_then(convert_referential_action),
142 on_update: on_update.as_ref().and_then(convert_referential_action),
143 });
144 }
145 ColumnOption::Check(expr) => {
146 check = Some(convert_sql_expr(expr));
147 }
148 _ => {}
149 }
150 }
151
152 Column {
153 name,
154 pg_type,
155 sqlite_type: None,
156 not_null,
157 default,
158 is_primary_key,
159 is_unique,
160 references,
161 check,
162 }
163}
164
165fn parse_table_constraint(
166 constraint: &SqlConstraint,
167 _warnings: &mut [Warning],
168) -> Option<TableConstraint> {
169 match constraint {
170 SqlConstraint::PrimaryKey { columns, name, .. } => {
171 let cols: Vec<Ident> = columns.iter().map(|c| Ident::new(&c.value)).collect();
172 Some(TableConstraint::PrimaryKey {
173 name: name.as_ref().map(|n| Ident::new(&n.value)),
174 columns: cols,
175 })
176 }
177 SqlConstraint::Unique { columns, name, .. } => {
178 let cols: Vec<Ident> = columns.iter().map(|c| Ident::new(&c.value)).collect();
179 Some(TableConstraint::Unique {
180 name: name.as_ref().map(|n| Ident::new(&n.value)),
181 columns: cols,
182 })
183 }
184 SqlConstraint::ForeignKey {
189 name,
190 columns,
191 foreign_table,
192 referred_columns,
193 on_delete,
194 on_update,
195 ..
196 } => Some(TableConstraint::ForeignKey {
197 name: name.as_ref().map(|n| Ident::new(&n.value)),
198 columns: columns.iter().map(|c| Ident::new(&c.value)).collect(),
199 ref_table: convert_object_name(foreign_table),
200 ref_columns: referred_columns
201 .iter()
202 .map(|c| Ident::new(&c.value))
203 .collect(),
204 on_delete: on_delete.as_ref().and_then(convert_referential_action),
205 on_update: on_update.as_ref().and_then(convert_referential_action),
206 deferrable: false,
207 }),
208 SqlConstraint::Check { name, expr } => Some(TableConstraint::Check {
209 name: name.as_ref().map(|n| Ident::new(&n.value)),
210 expr: convert_sql_expr(expr),
211 }),
212 _ => None,
213 }
214}
215
216fn parse_create_index(ci: &CreateIndex, _warnings: &mut [Warning]) -> Option<Index> {
217 let index_name = ci.name.as_ref()?;
218 let name = Ident::new(&index_name.to_string());
219 let table = convert_object_name(&ci.table_name);
220
221 let mut columns = Vec::new();
222 for col in &ci.columns {
223 let col_name = col.expr.to_string();
224 if col_name.contains('(') {
226 columns.push(IndexColumn::Expression(Expr::Raw(col_name)));
227 } else {
228 columns.push(IndexColumn::Column(Ident::new(&col_name)));
229 }
230 }
231
232 let method = ci
233 .using
234 .as_ref()
235 .and_then(|m| match m.value.to_lowercase().as_str() {
236 "btree" => Some(IndexMethod::Btree),
237 "hash" => Some(IndexMethod::Hash),
238 "gin" => Some(IndexMethod::Gin),
239 "gist" => Some(IndexMethod::Gist),
240 "spgist" => Some(IndexMethod::SpGist),
241 "brin" => Some(IndexMethod::Brin),
242 _ => None,
243 });
244
245 let where_clause = ci.predicate.as_ref().map(convert_sql_expr);
246
247 Some(Index {
248 name,
249 table,
250 columns,
251 unique: ci.unique,
252 method,
253 where_clause,
254 })
255}
256
257fn parse_alter_table_op(
258 table: &QualifiedName,
259 op: &AlterTableOperation,
260 warnings: &mut [Warning],
261) -> Option<AlterConstraint> {
262 match op {
263 AlterTableOperation::AddConstraint(constraint) => {
264 parse_table_constraint(constraint, warnings).map(|c| AlterConstraint {
265 table: table.clone(),
266 constraint: c,
267 })
268 }
269 _ => None,
270 }
271}
272
273fn convert_object_name(name: &ObjectName) -> QualifiedName {
275 let parts: Vec<&str> = name.0.iter().map(|ident| ident.value.as_str()).collect();
276 match parts.len() {
277 1 => QualifiedName::new(Ident::new(parts[0])),
278 2 => QualifiedName::with_schema(Ident::new(parts[0]), Ident::new(parts[1])),
279 _ => {
280 let len = parts.len();
282 QualifiedName::with_schema(Ident::new(parts[len - 2]), Ident::new(parts[len - 1]))
283 }
284 }
285}
286
287fn convert_data_type(dt: &DataType) -> PgType {
289 match dt {
290 DataType::SmallInt(_) | DataType::Int2(_) => PgType::SmallInt,
291 DataType::Integer(_) | DataType::Int(_) | DataType::Int4(_) => PgType::Integer,
292 DataType::BigInt(_) | DataType::Int8(_) => PgType::BigInt,
293 DataType::Real | DataType::Float4 => PgType::Real,
294 DataType::Double | DataType::DoublePrecision | DataType::Float8 => PgType::DoublePrecision,
295 DataType::Numeric(info) | DataType::Decimal(info) => {
296 let (precision, scale) = extract_numeric_info(info);
297 PgType::Numeric { precision, scale }
298 }
299 DataType::Boolean => PgType::Boolean,
300 DataType::Text => PgType::Text,
301 DataType::Varchar(len) | DataType::CharacterVarying(len) => PgType::Varchar {
302 length: extract_char_length(len),
303 },
304 DataType::Char(len) | DataType::Character(len) => PgType::Char {
305 length: extract_char_length(len),
306 },
307 DataType::Date => PgType::Date,
308 DataType::Time(_, tz) => PgType::Time {
309 with_tz: matches!(tz, ast::TimezoneInfo::WithTimeZone),
310 },
311 DataType::Timestamp(_, tz) => PgType::Timestamp {
312 with_tz: matches!(tz, ast::TimezoneInfo::WithTimeZone),
313 },
314 DataType::Interval => PgType::Interval,
315 DataType::Bytea => PgType::Bytea,
316 DataType::Uuid => PgType::Uuid,
317 DataType::JSON => PgType::Json,
318 DataType::JSONB => PgType::Jsonb,
319 DataType::Blob(_) => PgType::Bytea,
320 DataType::Array(
321 ArrayElemTypeDef::SquareBracket(inner, _) | ArrayElemTypeDef::AngleBracket(inner),
322 ) => PgType::Array {
323 element: Box::new(convert_data_type(inner)),
324 },
325 DataType::Array(_) => PgType::Other {
326 name: dt.to_string(),
327 },
328 DataType::Custom(name, _) => {
329 let type_name = name
331 .0
332 .last()
333 .map(|id| id.value.to_lowercase())
334 .unwrap_or_default();
335 match type_name.as_str() {
336 "serial" => PgType::Serial,
337 "bigserial" => PgType::BigSerial,
338 "smallserial" => PgType::SmallSerial,
339 "inet" => PgType::Inet,
340 "cidr" => PgType::Cidr,
341 "macaddr" | "macaddr8" => PgType::MacAddr,
342 "money" => PgType::Money,
343 "xml" => PgType::Xml,
344 "point" => PgType::Point,
345 "line" => PgType::Line,
346 "lseg" => PgType::Lseg,
347 "box" => PgType::Box,
348 "path" => PgType::Path,
349 "polygon" => PgType::Polygon,
350 "circle" => PgType::Circle,
351 "int4range" => PgType::Int4Range,
352 "int8range" => PgType::Int8Range,
353 "numrange" => PgType::NumRange,
354 "tsrange" => PgType::TsRange,
355 "tstzrange" => PgType::TsTzRange,
356 "daterange" => PgType::DateRange,
357 _ => PgType::Other { name: type_name },
358 }
359 }
360 _ => PgType::Other {
361 name: dt.to_string(),
362 },
363 }
364}
365
366fn convert_sql_expr(expr: &SqlExpr) -> Expr {
368 match expr {
369 SqlExpr::Value(val) => convert_value(val),
370 SqlExpr::Identifier(ident) => Expr::ColumnRef(ident.value.clone()),
371 SqlExpr::CompoundIdentifier(idents) => {
372 let name: Vec<&str> = idents.iter().map(|i| i.value.as_str()).collect();
373 Expr::ColumnRef(name.join("."))
374 }
375 SqlExpr::Function(func) => {
376 let func_name = func.name.to_string().to_lowercase();
377 let args: Vec<Expr> = match &func.args {
378 ast::FunctionArguments::List(arg_list) => arg_list
379 .args
380 .iter()
381 .filter_map(|arg| match arg {
382 ast::FunctionArg::Unnamed(ast::FunctionArgExpr::Expr(e)) => {
383 Some(convert_sql_expr(e))
384 }
385 _ => None,
386 })
387 .collect(),
388 _ => Vec::new(),
389 };
390
391 if func_name == "nextval"
393 && let Some(Expr::StringLiteral(seq)) = args.first()
394 {
395 return Expr::NextVal(seq.clone());
396 }
397
398 Expr::FunctionCall {
399 name: func_name,
400 args,
401 }
402 }
403 SqlExpr::Cast {
404 expr, data_type, ..
405 } => Expr::Cast {
406 expr: Box::new(convert_sql_expr(expr)),
407 type_name: data_type.to_string(),
408 },
409 SqlExpr::BinaryOp { left, op, right } => Expr::BinaryOp {
410 left: Box::new(convert_sql_expr(left)),
411 op: op.to_string(),
412 right: Box::new(convert_sql_expr(right)),
413 },
414 SqlExpr::UnaryOp { op, expr } => Expr::UnaryOp {
415 op: op.to_string(),
416 expr: Box::new(convert_sql_expr(expr)),
417 },
418 SqlExpr::IsNull(expr) => Expr::IsNull {
419 expr: Box::new(convert_sql_expr(expr)),
420 negated: false,
421 },
422 SqlExpr::IsNotNull(expr) => Expr::IsNull {
423 expr: Box::new(convert_sql_expr(expr)),
424 negated: true,
425 },
426 SqlExpr::InList {
427 expr,
428 list,
429 negated,
430 } => Expr::InList {
431 expr: Box::new(convert_sql_expr(expr)),
432 list: list.iter().map(convert_sql_expr).collect(),
433 negated: *negated,
434 },
435 SqlExpr::Between {
436 expr,
437 low,
438 high,
439 negated,
440 } => Expr::Between {
441 expr: Box::new(convert_sql_expr(expr)),
442 low: Box::new(convert_sql_expr(low)),
443 high: Box::new(convert_sql_expr(high)),
444 negated: *negated,
445 },
446 SqlExpr::Nested(inner) => Expr::Nested(Box::new(convert_sql_expr(inner))),
447 SqlExpr::AnyOp {
452 left,
453 compare_op: BinaryOperator::Eq,
454 right,
455 ..
456 } => match extract_array_elements(right) {
457 Some(list) => {
458 let left_expr = convert_sql_expr(left);
459 Expr::InList {
460 expr: Box::new(left_expr),
461 list,
462 negated: false,
463 }
464 }
465 None => Expr::Raw(expr.to_string()),
466 },
467 _ => Expr::Raw(expr.to_string()),
475 }
476}
477
478fn extract_array_elements(expr: &SqlExpr) -> Option<Vec<Expr>> {
482 match expr {
483 SqlExpr::Array(Array { elem, .. }) => Some(elem.iter().map(convert_sql_expr).collect()),
484 _ => None,
485 }
486}
487
488fn convert_value(val: &ast::Value) -> Expr {
489 match val {
490 ast::Value::Number(n, _) => {
491 if let Ok(i) = n.parse::<i64>() {
492 Expr::IntegerLiteral(i)
493 } else if let Ok(f) = n.parse::<f64>() {
494 Expr::FloatLiteral(f)
495 } else {
496 Expr::Raw(n.clone())
497 }
498 }
499 ast::Value::SingleQuotedString(s) => Expr::StringLiteral(s.clone()),
500 ast::Value::Boolean(b) => Expr::BooleanLiteral(*b),
501 ast::Value::Null => Expr::Null,
502 _ => Expr::Raw(val.to_string()),
503 }
504}
505
506fn convert_referential_action(action: &ReferentialAction) -> Option<FkAction> {
507 match action {
508 ReferentialAction::Cascade => Some(FkAction::Cascade),
509 ReferentialAction::SetNull => Some(FkAction::SetNull),
510 ReferentialAction::SetDefault => Some(FkAction::SetDefault),
511 ReferentialAction::Restrict => Some(FkAction::Restrict),
512 ReferentialAction::NoAction => Some(FkAction::NoAction),
513 }
514}
515
516fn extract_numeric_info(info: &ast::ExactNumberInfo) -> (Option<u32>, Option<u32>) {
517 match info {
518 ast::ExactNumberInfo::PrecisionAndScale(p, s) => (Some(*p as u32), Some(*s as u32)),
519 ast::ExactNumberInfo::Precision(p) => (Some(*p as u32), None),
520 ast::ExactNumberInfo::None => (None, None),
521 }
522}
523
524fn extract_char_length(len: &Option<ast::CharacterLength>) -> Option<u32> {
525 len.as_ref().map(|cl| match cl {
526 ast::CharacterLength::IntegerLength { length, .. } => *length as u32,
527 ast::CharacterLength::Max => u32::MAX,
528 })
529}
530
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `sql` and returns only the model, discarding warnings.
    fn model_of(sql: &str) -> SchemaModel {
        parse(sql).0
    }

    #[test]
    fn test_parse_simple_table() {
        let (model, warnings) =
            parse("CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT NOT NULL);");
        assert!(warnings.is_empty());
        assert_eq!(model.tables.len(), 1);

        let users = &model.tables[0];
        assert_eq!(users.name.name.normalized, "users");
        assert_eq!(users.columns.len(), 2);
        let (id_col, name_col) = (&users.columns[0], &users.columns[1]);
        assert!(id_col.is_primary_key);
        assert!(name_col.not_null);
    }

    #[test]
    fn test_parse_schema_qualified_table() {
        let model = model_of("CREATE TABLE public.users (id INTEGER);");
        let qualified = &model.tables[0].name;
        let schema = qualified.schema.as_ref().expect("schema should be captured");
        assert_eq!(schema.normalized, "public");
        assert_eq!(qualified.name.normalized, "users");
    }

    #[test]
    fn test_parse_create_index() {
        let model = model_of("CREATE INDEX idx_name ON users (name);");
        assert_eq!(model.indexes.len(), 1);
        let index = &model.indexes[0];
        assert_eq!(index.name.normalized, "idx_name");
        assert!(!index.unique);
    }

    #[test]
    fn test_parse_unique_index() {
        let model = model_of("CREATE UNIQUE INDEX idx_email ON users (email);");
        assert!(model.indexes[0].unique);
    }

    #[test]
    fn test_parse_alter_table_add_constraint() {
        let model = model_of(
            r#"
            CREATE TABLE orders (id INTEGER, user_id INTEGER);
            ALTER TABLE orders ADD CONSTRAINT fk_user FOREIGN KEY (user_id) REFERENCES users (id);
            "#,
        );
        assert_eq!(model.tables.len(), 1);
        assert_eq!(model.alter_constraints.len(), 1);
    }

    #[test]
    fn test_parse_create_type_enum() {
        let model = model_of("CREATE TYPE mood AS ENUM ('sad', 'ok', 'happy');");
        assert_eq!(model.enums.len(), 1);
        assert_eq!(model.enums[0].values.len(), 3);
    }

    #[test]
    fn test_parse_column_default() {
        let model = model_of("CREATE TABLE t (created_at TIMESTAMP DEFAULT now());");
        assert!(model.tables[0].columns[0].default.is_some());
    }

    #[test]
    fn test_non_ddl_ignored() {
        let (model, warnings) = parse("SELECT 1; CREATE TABLE t (id INTEGER);");
        assert_eq!(model.tables.len(), 1);
        assert!(warnings.is_empty());
    }

    #[test]
    fn test_parse_foreign_key_with_actions() {
        let model = model_of(
            r#"
            CREATE TABLE orders (
                id INTEGER PRIMARY KEY,
                user_id INTEGER REFERENCES users(id) ON DELETE CASCADE ON UPDATE SET NULL
            );
            "#,
        );
        let fk = model.tables[0].columns[1]
            .references
            .as_ref()
            .expect("user_id should reference users");
        assert_eq!(fk.on_delete, Some(FkAction::Cascade));
        assert_eq!(fk.on_update, Some(FkAction::SetNull));
    }

    #[test]
    fn test_parse_check_constraint() {
        let model = model_of("CREATE TABLE t (age INTEGER CHECK (age >= 0));");
        assert!(model.tables[0].columns[0].check.is_some());
    }

    #[test]
    fn test_parse_any_array_to_in_list() {
        let model = model_of(
            r#"CREATE TABLE t (
            status TEXT NOT NULL,
            CONSTRAINT status_check CHECK ((status = ANY (ARRAY['active'::text, 'inactive'::text])))
        );"#,
        );
        let table = &model.tables[0];
        assert_eq!(table.constraints.len(), 1);

        let TableConstraint::Check { name, expr } = &table.constraints[0] else {
            panic!("Expected Check constraint");
        };
        assert_eq!(name.as_ref().unwrap().normalized, "status_check");

        // The outer parentheses survive as a Nested wrapper around the rewritten IN list.
        let Expr::Nested(inner) = expr else {
            panic!("Expected Nested, got: {expr:?}");
        };
        let Expr::InList {
            expr: col,
            list,
            negated,
        } = inner.as_ref()
        else {
            panic!("Expected InList, got: {inner:?}");
        };

        assert!(!negated);
        assert!(matches!(col.as_ref(), Expr::ColumnRef(name) if name == "status"));
        assert_eq!(list.len(), 2);
        let first_is_active = matches!(
            &list[0],
            Expr::Cast { expr, .. }
                if matches!(expr.as_ref(), Expr::StringLiteral(s) if s == "active")
        );
        assert!(first_is_active);
    }
}