Skip to main content

citadel_sql/
parser.rs

1//! SQL parser: converts SQL strings into our internal AST.
2
3use sqlparser::ast as sp;
4use sqlparser::dialect::GenericDialect;
5use sqlparser::parser::Parser;
6
7use crate::error::{Result, SqlError};
8use crate::types::{DataType, Value};
9
/// A single SQL statement converted from `sqlparser`'s AST by `convert_statement`
/// (reached through [`parse_sql`] / [`parse_sql_multi`]).
#[derive(Debug, Clone)]
pub enum Statement {
    /// `CREATE TABLE`.
    CreateTable(CreateTableStmt),
    /// `DROP TABLE` naming exactly one table.
    DropTable(DropTableStmt),
    /// `CREATE [UNIQUE] INDEX` over plain columns.
    CreateIndex(CreateIndexStmt),
    /// `DROP INDEX` naming exactly one index.
    DropIndex(DropIndexStmt),
    /// `CREATE VIEW`.
    CreateView(CreateViewStmt),
    /// `DROP VIEW` naming exactly one view.
    DropView(DropViewStmt),
    /// `ALTER TABLE` carrying exactly one operation (boxed to keep the enum small).
    AlterTable(Box<AlterTableStmt>),
    /// `INSERT INTO ... VALUES` / `INSERT INTO ... SELECT`.
    Insert(InsertStmt),
    /// A query: optional CTEs plus a SELECT or set-operation body.
    Select(Box<SelectQuery>),
    /// `UPDATE`.
    Update(UpdateStmt),
    /// `DELETE`.
    Delete(DeleteStmt),
    /// `BEGIN` / `START TRANSACTION`.
    Begin,
    /// `COMMIT` (`COMMIT AND CHAIN` is rejected during conversion).
    Commit,
    /// `ROLLBACK` without a savepoint (`ROLLBACK AND CHAIN` is rejected).
    Rollback,
    /// `SAVEPOINT name`; the name is ASCII-lowercased during conversion.
    Savepoint(String),
    /// `RELEASE SAVEPOINT name`; the name is ASCII-lowercased during conversion.
    ReleaseSavepoint(String),
    /// `ROLLBACK TO SAVEPOINT name`; the name is ASCII-lowercased during conversion.
    RollbackTo(String),
    /// `SET TIME ZONE ...`: string literals are unquoted, bare identifiers kept
    /// verbatim, other literals rendered with `Display`.
    SetTimezone(String),
    /// `EXPLAIN <statement>` (`EXPLAIN ANALYZE` is rejected).
    Explain(Box<Statement>),
}
32
/// `ALTER TABLE <table> <op>`; only single-operation statements are represented.
#[derive(Debug, Clone)]
pub struct AlterTableStmt {
    /// Table being altered, as rendered by `object_name_to_string`.
    pub table: String,
    /// The one operation to apply.
    pub op: AlterTableOp,
}
38
/// The supported `ALTER TABLE` operations.
#[derive(Debug, Clone)]
pub enum AlterTableOp {
    /// `ADD [COLUMN] [IF NOT EXISTS] <column def>`.
    AddColumn {
        /// Parsed column definition (boxed to keep the enum small).
        column: Box<ColumnSpec>,
        /// Inline `REFERENCES` clause on the new column, if any.
        foreign_key: Option<ForeignKeyDef>,
        if_not_exists: bool,
    },
    /// `DROP [COLUMN] [IF EXISTS] <name>` for a single column.
    DropColumn {
        name: String,
        if_exists: bool,
    },
    /// `RENAME COLUMN <old_name> TO <new_name>`.
    RenameColumn {
        old_name: String,
        new_name: String,
    },
    /// `RENAME TO <new_name>` / `RENAME AS <new_name>`.
    RenameTable {
        new_name: String,
    },
}
58
/// A `CREATE TABLE` definition.
#[derive(Debug, Clone)]
pub struct CreateTableStmt {
    /// Table name, as rendered by `object_name_to_string`.
    pub name: String,
    /// Column definitions in declaration order.
    pub columns: Vec<ColumnSpec>,
    /// Primary-key column names: inline `PRIMARY KEY` columns first, then columns
    /// from a table-level `PRIMARY KEY (...)` that are not already listed.
    pub primary_key: Vec<String>,
    pub if_not_exists: bool,
    /// Table-level `CHECK` constraints (column-level checks live on [`ColumnSpec`]).
    pub check_constraints: Vec<TableCheckConstraint>,
    /// Foreign keys from both column-level `REFERENCES` and table-level `FOREIGN KEY`.
    pub foreign_keys: Vec<ForeignKeyDef>,
    /// Unique indices from column-level `UNIQUE` (skipped when the column is also the
    /// inline primary key) and from table-level `UNIQUE (...)`.
    pub unique_indices: Vec<UniqueIndexDef>,
}
69
/// A UNIQUE constraint declared in `CREATE TABLE`.
#[derive(Debug, Clone)]
pub struct UniqueIndexDef {
    /// Constraint name; `None` for column-level `UNIQUE` and for unnamed table constraints.
    pub name: Option<String>,
    /// Constrained column names, ASCII-lowercased.
    pub columns: Vec<String>,
}
75
/// A table-level `CHECK (...)` constraint.
#[derive(Debug, Clone)]
pub struct TableCheckConstraint {
    /// Optional `CONSTRAINT <name>`.
    pub name: Option<String>,
    /// Converted check expression (subqueries are rejected during conversion).
    pub expr: Expr,
    /// The expression's SQL text as rendered by sqlparser's `Display`.
    pub sql: String,
}
82
/// A foreign key, from either a column-level `REFERENCES` or a table-level `FOREIGN KEY`.
/// Only `RESTRICT` / `NO ACTION` referential actions are accepted during conversion.
#[derive(Debug, Clone)]
pub struct ForeignKeyDef {
    /// Optional constraint name.
    pub name: Option<String>,
    /// Referencing columns in this table, ASCII-lowercased.
    pub columns: Vec<String>,
    /// Referenced table name, ASCII-lowercased.
    pub foreign_table: String,
    /// Referenced columns, ASCII-lowercased, exactly as listed in the SQL
    /// (NOTE(review): may be empty for `REFERENCES t` with no column list — confirm).
    pub referred_columns: Vec<String>,
}
90
/// One column of a `CREATE TABLE` or `ALTER TABLE ADD COLUMN`.
#[derive(Debug, Clone)]
pub struct ColumnSpec {
    /// Column name as written (not lowercased).
    pub name: String,
    pub data_type: DataType,
    /// Whether NULL is allowed; starts `true`, cleared by `NOT NULL` or `PRIMARY KEY`.
    pub nullable: bool,
    /// Set by an inline `PRIMARY KEY` or by a table-level `PRIMARY KEY (...)` naming this column.
    pub is_primary_key: bool,
    /// Converted `DEFAULT` expression.
    pub default_expr: Option<Expr>,
    /// `DEFAULT` expression text as rendered by sqlparser (re-parsed with [`parse_sql_expr`]).
    pub default_sql: Option<String>,
    /// Converted column-level `CHECK` expression.
    pub check_expr: Option<Expr>,
    /// Column-level `CHECK` expression text as rendered by sqlparser.
    pub check_sql: Option<String>,
    /// Optional name of the column-level `CHECK` constraint.
    pub check_name: Option<String>,
}
103
/// `DROP TABLE [IF EXISTS] <name>` for a single table.
#[derive(Debug, Clone)]
pub struct DropTableStmt {
    pub name: String,
    pub if_exists: bool,
}
109
/// `CREATE [UNIQUE] INDEX [IF NOT EXISTS] <index_name> ON <table_name> (<columns>)`.
#[derive(Debug, Clone)]
pub struct CreateIndexStmt {
    /// Index name (required; anonymous indexes are rejected during conversion).
    pub index_name: String,
    pub table_name: String,
    /// Indexed column names as written; never empty, and expression indexes are rejected.
    pub columns: Vec<String>,
    pub unique: bool,
    pub if_not_exists: bool,
}
118
/// `DROP INDEX [IF EXISTS] <index_name>` for a single index.
#[derive(Debug, Clone)]
pub struct DropIndexStmt {
    pub index_name: String,
    pub if_exists: bool,
}
124
/// `CREATE [OR REPLACE] VIEW [IF NOT EXISTS] <name> [(<aliases>)] AS <query>`.
#[derive(Debug, Clone)]
pub struct CreateViewStmt {
    pub name: String,
    /// NOTE(review): presumably the view's query text, stored for re-parsing — confirm
    /// against `convert_create_view`.
    pub sql: String,
    /// Optional column alias list after the view name.
    pub column_aliases: Vec<String>,
    pub or_replace: bool,
    pub if_not_exists: bool,
}
133
/// `DROP VIEW [IF EXISTS] <name>` for a single view.
#[derive(Debug, Clone)]
pub struct DropViewStmt {
    pub name: String,
    pub if_exists: bool,
}
139
/// Where an `INSERT` takes its rows from.
#[derive(Debug, Clone)]
pub enum InsertSource {
    /// `VALUES (...), (...)`: one inner `Vec` per row.
    Values(Vec<Vec<Expr>>),
    /// `INSERT ... SELECT ...`.
    Select(Box<SelectQuery>),
}
145
/// `INSERT INTO <table> [(<columns>)] <source>`.
#[derive(Debug, Clone)]
pub struct InsertStmt {
    pub table: String,
    /// Explicit target column list (NOTE(review): presumably empty when omitted — confirm in `convert_insert`).
    pub columns: Vec<String>,
    pub source: InsertSource,
}
152
/// A table named in FROM/JOIN, with an optional alias.
#[derive(Debug, Clone)]
pub struct TableRef {
    pub name: String,
    pub alias: Option<String>,
}
158
/// Supported join kinds.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum JoinType {
    Inner,
    Cross,
    /// LEFT [OUTER] JOIN.
    Left,
    /// RIGHT [OUTER] JOIN.
    Right,
}
166
/// One `JOIN` attached to a [`SelectStmt`].
#[derive(Debug, Clone)]
pub struct JoinClause {
    pub join_type: JoinType,
    pub table: TableRef,
    /// `ON` condition (NOTE(review): presumably `None` for CROSS JOIN — confirm in `convert_query`).
    pub on_clause: Option<Expr>,
}
173
/// A single (non-compound) SELECT.
#[derive(Debug, Clone)]
pub struct SelectStmt {
    /// Projection list.
    pub columns: Vec<SelectColumn>,
    /// Name of the base FROM table.
    pub from: String,
    pub from_alias: Option<String>,
    /// Joins applied to the base table.
    pub joins: Vec<JoinClause>,
    /// `SELECT DISTINCT`.
    pub distinct: bool,
    pub where_clause: Option<Expr>,
    pub order_by: Vec<OrderByItem>,
    pub limit: Option<Expr>,
    pub offset: Option<Expr>,
    pub group_by: Vec<Expr>,
    pub having: Option<Expr>,
}
188
/// Set operator combining two query bodies.
#[derive(Debug, Clone)]
pub enum SetOp {
    Union,
    Intersect,
    Except,
}
195
/// `left <op> [ALL] right`, plus the ORDER BY / LIMIT / OFFSET attached to the
/// compound query as a whole.
#[derive(Debug, Clone)]
pub struct CompoundSelect {
    pub op: SetOp,
    /// `true` for the `ALL` form (e.g. `UNION ALL`).
    pub all: bool,
    pub left: Box<QueryBody>,
    pub right: Box<QueryBody>,
    pub order_by: Vec<OrderByItem>,
    pub limit: Option<Expr>,
    pub offset: Option<Expr>,
}
206
/// Body of a query: a plain SELECT or a set-operation tree.
#[derive(Debug, Clone)]
pub enum QueryBody {
    Select(Box<SelectStmt>),
    Compound(Box<CompoundSelect>),
}
212
/// One `name [(aliases)] AS (body)` entry of a WITH clause.
#[derive(Debug, Clone)]
pub struct CteDefinition {
    pub name: String,
    /// Optional column alias list.
    pub column_aliases: Vec<String>,
    pub body: QueryBody,
}
219
/// A complete query: optional WITH clause plus the main body.
#[derive(Debug, Clone)]
pub struct SelectQuery {
    /// CTEs in declaration order (empty when there is no WITH clause).
    pub ctes: Vec<CteDefinition>,
    /// `WITH RECURSIVE`.
    pub recursive: bool,
    pub body: QueryBody,
}
226
/// `UPDATE <table> SET col = expr, ... [WHERE ...]`.
#[derive(Debug, Clone)]
pub struct UpdateStmt {
    pub table: String,
    /// `(column, new value)` pairs in SET-clause order.
    pub assignments: Vec<(String, Expr)>,
    pub where_clause: Option<Expr>,
}
233
/// `DELETE FROM <table> [WHERE ...]`.
#[derive(Debug, Clone)]
pub struct DeleteStmt {
    pub table: String,
    pub where_clause: Option<Expr>,
}
239
/// One item of a SELECT projection list.
#[derive(Debug, Clone)]
pub enum SelectColumn {
    /// `*`.
    AllColumns,
    /// `expr [AS alias]`.
    Expr { expr: Expr, alias: Option<String> },
}
245
/// One ORDER BY key.
#[derive(Debug, Clone)]
pub struct OrderByItem {
    pub expr: Expr,
    /// `DESC` when `true`.
    pub descending: bool,
    /// Explicit `NULLS FIRST` (`Some(true)`) / `NULLS LAST` (`Some(false)`);
    /// presumably `None` when unspecified.
    pub nulls_first: Option<bool>,
}
252
/// Scalar expression tree produced by `convert_expr` and walked by [`has_subquery`]
/// and the `visit_expr` family.
#[derive(Debug, Clone)]
pub enum Expr {
    /// Constant value.
    Literal(Value),
    /// Unqualified column reference.
    Column(String),
    /// `table.column` reference.
    QualifiedColumn {
        table: String,
        column: String,
    },
    /// `left <op> right`.
    BinaryOp {
        left: Box<Expr>,
        op: BinOp,
        right: Box<Expr>,
    },
    /// Prefix operator (`-expr`, `NOT expr`).
    UnaryOp {
        op: UnaryOp,
        expr: Box<Expr>,
    },
    /// `expr IS NULL`.
    IsNull(Box<Expr>),
    /// `expr IS NOT NULL`.
    IsNotNull(Box<Expr>),
    /// Scalar or aggregate function call.
    Function {
        name: String,
        args: Vec<Expr>,
        /// True for aggregate forms like `COUNT(DISTINCT x)`.
        distinct: bool,
    },
    /// `COUNT(*)`.
    CountStar,
    /// `expr [NOT] IN (subquery)`; the subquery is a single SELECT (no CTEs / set ops).
    InSubquery {
        expr: Box<Expr>,
        subquery: Box<SelectStmt>,
        negated: bool,
    },
    /// `expr [NOT] IN (list...)`.
    InList {
        expr: Box<Expr>,
        list: Vec<Expr>,
        negated: bool,
    },
    /// `[NOT] EXISTS (subquery)`.
    Exists {
        subquery: Box<SelectStmt>,
        negated: bool,
    },
    /// Subquery used as a scalar value.
    ScalarSubquery(Box<SelectStmt>),
    /// Membership test against a prebuilt set of constant values. Not constructed in
    /// this chunk; presumably an optimized form of `InList`, with `has_null` recording
    /// whether the original list contained NULL — TODO confirm.
    InSet {
        expr: Box<Expr>,
        values: std::collections::HashSet<Value>,
        has_null: bool,
        negated: bool,
    },
    /// `expr [NOT] BETWEEN low AND high`.
    Between {
        expr: Box<Expr>,
        low: Box<Expr>,
        high: Box<Expr>,
        negated: bool,
    },
    /// `expr [NOT] LIKE pattern [ESCAPE escape]`.
    Like {
        expr: Box<Expr>,
        pattern: Box<Expr>,
        escape: Option<Box<Expr>>,
        negated: bool,
    },
    /// `CASE [operand] WHEN .. THEN .. [ELSE ..] END`; `conditions` holds `(WHEN, THEN)` pairs.
    Case {
        operand: Option<Box<Expr>>,
        conditions: Vec<(Expr, Expr)>,
        else_result: Option<Box<Expr>>,
    },
    /// `COALESCE(args...)`.
    Coalesce(Vec<Expr>),
    /// `CAST(expr AS data_type)`.
    Cast {
        expr: Box<Expr>,
        data_type: DataType,
    },
    /// Positional placeholder `$N`; [`count_params`] reports the largest `N` used.
    Parameter(usize),
    /// Function call with an `OVER (...)` window specification.
    WindowFunction {
        name: String,
        args: Vec<Expr>,
        spec: WindowSpec,
    },
}
329
/// Contents of an `OVER (...)` clause.
#[derive(Debug, Clone)]
pub struct WindowSpec {
    pub partition_by: Vec<Expr>,
    pub order_by: Vec<OrderByItem>,
    /// Explicit frame clause, if any.
    pub frame: Option<WindowFrame>,
}
336
/// `{ROWS | RANGE | GROUPS} BETWEEN start AND end`.
#[derive(Debug, Clone)]
pub struct WindowFrame {
    pub units: WindowFrameUnits,
    pub start: WindowFrameBound,
    pub end: WindowFrameBound,
}
343
/// Unit in which window-frame offsets are measured.
#[derive(Debug, Clone, Copy)]
pub enum WindowFrameUnits {
    Rows,
    Range,
    Groups,
}
350
/// One end of a window frame; the boxed expressions are the `N` in `N PRECEDING/FOLLOWING`.
#[derive(Debug, Clone)]
pub enum WindowFrameBound {
    UnboundedPreceding,
    Preceding(Box<Expr>),
    CurrentRow,
    Following(Box<Expr>),
    UnboundedFollowing,
}
359
/// Binary operators supported in [`Expr::BinaryOp`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BinOp {
    Add,
    Sub,
    Mul,
    Div,
    Mod,
    Eq,
    NotEq,
    Lt,
    Gt,
    LtEq,
    GtEq,
    And,
    Or,
    /// String concatenation (presumably SQL `||`).
    Concat,
}
377
/// Prefix operators supported in [`Expr::UnaryOp`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum UnaryOp {
    /// Arithmetic negation.
    Neg,
    /// Logical NOT.
    Not,
}
383
384pub fn has_subquery(expr: &Expr) -> bool {
385    match expr {
386        Expr::InSubquery { .. } | Expr::Exists { .. } | Expr::ScalarSubquery(_) => true,
387        Expr::BinaryOp { left, right, .. } => has_subquery(left) || has_subquery(right),
388        Expr::UnaryOp { expr, .. } => has_subquery(expr),
389        Expr::IsNull(e) | Expr::IsNotNull(e) => has_subquery(e),
390        Expr::InList { expr, list, .. } => has_subquery(expr) || list.iter().any(has_subquery),
391        Expr::InSet { expr, .. } => has_subquery(expr),
392        Expr::Between {
393            expr, low, high, ..
394        } => has_subquery(expr) || has_subquery(low) || has_subquery(high),
395        Expr::Like {
396            expr,
397            pattern,
398            escape,
399            ..
400        } => {
401            has_subquery(expr)
402                || has_subquery(pattern)
403                || escape.as_ref().is_some_and(|e| has_subquery(e))
404        }
405        Expr::Case {
406            operand,
407            conditions,
408            else_result,
409        } => {
410            operand.as_ref().is_some_and(|e| has_subquery(e))
411                || conditions
412                    .iter()
413                    .any(|(c, r)| has_subquery(c) || has_subquery(r))
414                || else_result.as_ref().is_some_and(|e| has_subquery(e))
415        }
416        Expr::Coalesce(args) | Expr::Function { args, .. } => args.iter().any(has_subquery),
417        Expr::Cast { expr, .. } => has_subquery(expr),
418        _ => false,
419    }
420}
421
422/// Parse a SQL expression string back into an internal Expr.
423/// Used for deserializing stored DEFAULT/CHECK expressions from schema.
424pub fn parse_sql_expr(sql: &str) -> Result<Expr> {
425    let dialect = GenericDialect {};
426    let mut parser = Parser::new(&dialect)
427        .try_with_sql(sql)
428        .map_err(|e| SqlError::Parse(e.to_string()))?;
429    let sp_expr = parser
430        .parse_expr()
431        .map_err(|e| SqlError::Parse(e.to_string()))?;
432    convert_expr(&sp_expr)
433}
434
435pub fn parse_sql(sql: &str) -> Result<Statement> {
436    let dialect = GenericDialect {};
437    let stmts = Parser::parse_sql(&dialect, sql).map_err(|e| SqlError::Parse(e.to_string()))?;
438
439    if stmts.is_empty() {
440        return Err(SqlError::Parse("empty SQL".into()));
441    }
442    if stmts.len() > 1 {
443        return Err(SqlError::Unsupported("multiple statements".into()));
444    }
445
446    convert_statement(stmts.into_iter().next().unwrap())
447}
448
449/// Parse one or more `;`-separated SQL statements.
450pub fn parse_sql_multi(sql: &str) -> Result<Vec<Statement>> {
451    let dialect = GenericDialect {};
452    let stmts = Parser::parse_sql(&dialect, sql).map_err(|e| SqlError::Parse(e.to_string()))?;
453
454    if stmts.is_empty() {
455        return Err(SqlError::Parse("empty SQL".into()));
456    }
457
458    stmts.into_iter().map(convert_statement).collect()
459}
460
461/// Returns the number of distinct parameters in a statement (max $N found).
462pub fn count_params(stmt: &Statement) -> usize {
463    let mut max_idx = 0usize;
464    visit_exprs_stmt(stmt, &mut |e| {
465        if let Expr::Parameter(n) = e {
466            max_idx = max_idx.max(*n);
467        }
468    });
469    max_idx
470}
471
472fn visit_exprs_stmt(stmt: &Statement, visitor: &mut impl FnMut(&Expr)) {
473    match stmt {
474        Statement::Select(sq) => {
475            for cte in &sq.ctes {
476                visit_exprs_query_body(&cte.body, visitor);
477            }
478            visit_exprs_query_body(&sq.body, visitor);
479        }
480        Statement::Insert(ins) => match &ins.source {
481            InsertSource::Values(rows) => {
482                for row in rows {
483                    for e in row {
484                        visit_expr(e, visitor);
485                    }
486                }
487            }
488            InsertSource::Select(sq) => {
489                for cte in &sq.ctes {
490                    visit_exprs_query_body(&cte.body, visitor);
491                }
492                visit_exprs_query_body(&sq.body, visitor);
493            }
494        },
495        Statement::Update(upd) => {
496            for (_, e) in &upd.assignments {
497                visit_expr(e, visitor);
498            }
499            if let Some(w) = &upd.where_clause {
500                visit_expr(w, visitor);
501            }
502        }
503        Statement::Delete(del) => {
504            if let Some(w) = &del.where_clause {
505                visit_expr(w, visitor);
506            }
507        }
508        Statement::Explain(inner) => visit_exprs_stmt(inner, visitor),
509        _ => {}
510    }
511}
512
513fn visit_exprs_query_body(body: &QueryBody, visitor: &mut impl FnMut(&Expr)) {
514    match body {
515        QueryBody::Select(sel) => visit_exprs_select(sel, visitor),
516        QueryBody::Compound(comp) => {
517            visit_exprs_query_body(&comp.left, visitor);
518            visit_exprs_query_body(&comp.right, visitor);
519            for o in &comp.order_by {
520                visit_expr(&o.expr, visitor);
521            }
522            if let Some(l) = &comp.limit {
523                visit_expr(l, visitor);
524            }
525            if let Some(o) = &comp.offset {
526                visit_expr(o, visitor);
527            }
528        }
529    }
530}
531
532fn visit_exprs_select(sel: &SelectStmt, visitor: &mut impl FnMut(&Expr)) {
533    for col in &sel.columns {
534        if let SelectColumn::Expr { expr, .. } = col {
535            visit_expr(expr, visitor);
536        }
537    }
538    for j in &sel.joins {
539        if let Some(on) = &j.on_clause {
540            visit_expr(on, visitor);
541        }
542    }
543    if let Some(w) = &sel.where_clause {
544        visit_expr(w, visitor);
545    }
546    for o in &sel.order_by {
547        visit_expr(&o.expr, visitor);
548    }
549    if let Some(l) = &sel.limit {
550        visit_expr(l, visitor);
551    }
552    if let Some(o) = &sel.offset {
553        visit_expr(o, visitor);
554    }
555    for g in &sel.group_by {
556        visit_expr(g, visitor);
557    }
558    if let Some(h) = &sel.having {
559        visit_expr(h, visitor);
560    }
561}
562
563fn visit_expr(expr: &Expr, visitor: &mut impl FnMut(&Expr)) {
564    visitor(expr);
565    match expr {
566        Expr::BinaryOp { left, right, .. } => {
567            visit_expr(left, visitor);
568            visit_expr(right, visitor);
569        }
570        Expr::UnaryOp { expr: e, .. } | Expr::IsNull(e) | Expr::IsNotNull(e) => {
571            visit_expr(e, visitor);
572        }
573        Expr::Function { args, .. } | Expr::Coalesce(args) => {
574            for a in args {
575                visit_expr(a, visitor);
576            }
577        }
578        Expr::InSubquery {
579            expr: e, subquery, ..
580        } => {
581            visit_expr(e, visitor);
582            visit_exprs_select(subquery, visitor);
583        }
584        Expr::InList { expr: e, list, .. } => {
585            visit_expr(e, visitor);
586            for l in list {
587                visit_expr(l, visitor);
588            }
589        }
590        Expr::Exists { subquery, .. } => visit_exprs_select(subquery, visitor),
591        Expr::ScalarSubquery(sq) => visit_exprs_select(sq, visitor),
592        Expr::InSet { expr: e, .. } => visit_expr(e, visitor),
593        Expr::Between {
594            expr: e, low, high, ..
595        } => {
596            visit_expr(e, visitor);
597            visit_expr(low, visitor);
598            visit_expr(high, visitor);
599        }
600        Expr::Like {
601            expr: e,
602            pattern,
603            escape,
604            ..
605        } => {
606            visit_expr(e, visitor);
607            visit_expr(pattern, visitor);
608            if let Some(esc) = escape {
609                visit_expr(esc, visitor);
610            }
611        }
612        Expr::Case {
613            operand,
614            conditions,
615            else_result,
616        } => {
617            if let Some(op) = operand {
618                visit_expr(op, visitor);
619            }
620            for (cond, then) in conditions {
621                visit_expr(cond, visitor);
622                visit_expr(then, visitor);
623            }
624            if let Some(el) = else_result {
625                visit_expr(el, visitor);
626            }
627        }
628        Expr::Cast { expr: e, .. } => visit_expr(e, visitor),
629        Expr::WindowFunction { args, spec, .. } => {
630            for a in args {
631                visit_expr(a, visitor);
632            }
633            for p in &spec.partition_by {
634                visit_expr(p, visitor);
635            }
636            for o in &spec.order_by {
637                visit_expr(&o.expr, visitor);
638            }
639            if let Some(ref frame) = spec.frame {
640                if let WindowFrameBound::Preceding(e) | WindowFrameBound::Following(e) =
641                    &frame.start
642                {
643                    visit_expr(e, visitor);
644                }
645                if let WindowFrameBound::Preceding(e) | WindowFrameBound::Following(e) = &frame.end
646                {
647                    visit_expr(e, visitor);
648                }
649            }
650        }
651        Expr::Literal(_)
652        | Expr::Column(_)
653        | Expr::QualifiedColumn { .. }
654        | Expr::CountStar
655        | Expr::Parameter(_) => {}
656    }
657}
658
659fn convert_statement(stmt: sp::Statement) -> Result<Statement> {
660    match stmt {
661        sp::Statement::CreateTable(ct) => convert_create_table(ct),
662        sp::Statement::CreateIndex(ci) => convert_create_index(ci),
663        sp::Statement::Drop {
664            object_type: sp::ObjectType::Table,
665            if_exists,
666            names,
667            ..
668        } => {
669            if names.len() != 1 {
670                return Err(SqlError::Unsupported("multi-table DROP".into()));
671            }
672            Ok(Statement::DropTable(DropTableStmt {
673                name: object_name_to_string(&names[0]),
674                if_exists,
675            }))
676        }
677        sp::Statement::Drop {
678            object_type: sp::ObjectType::Index,
679            if_exists,
680            names,
681            ..
682        } => {
683            if names.len() != 1 {
684                return Err(SqlError::Unsupported("multi-index DROP".into()));
685            }
686            Ok(Statement::DropIndex(DropIndexStmt {
687                index_name: object_name_to_string(&names[0]),
688                if_exists,
689            }))
690        }
691        sp::Statement::CreateView(cv) => convert_create_view(cv),
692        sp::Statement::Drop {
693            object_type: sp::ObjectType::View,
694            if_exists,
695            names,
696            ..
697        } => {
698            if names.len() != 1 {
699                return Err(SqlError::Unsupported("multi-view DROP".into()));
700            }
701            Ok(Statement::DropView(DropViewStmt {
702                name: object_name_to_string(&names[0]),
703                if_exists,
704            }))
705        }
706        sp::Statement::AlterTable(at) => convert_alter_table(at),
707        sp::Statement::Insert(insert) => convert_insert(insert),
708        sp::Statement::Query(query) => convert_query(*query),
709        sp::Statement::Update(update) => convert_update(update),
710        sp::Statement::Delete(delete) => convert_delete(delete),
711        sp::Statement::StartTransaction { .. } => Ok(Statement::Begin),
712        sp::Statement::Commit { chain: true, .. } => {
713            Err(SqlError::Unsupported("COMMIT AND CHAIN".into()))
714        }
715        sp::Statement::Commit { .. } => Ok(Statement::Commit),
716        sp::Statement::Rollback { chain: true, .. } => {
717            Err(SqlError::Unsupported("ROLLBACK AND CHAIN".into()))
718        }
719        sp::Statement::Rollback {
720            savepoint: Some(name),
721            ..
722        } => Ok(Statement::RollbackTo(name.value.to_ascii_lowercase())),
723        sp::Statement::Rollback { .. } => Ok(Statement::Rollback),
724        sp::Statement::Savepoint { name } => {
725            Ok(Statement::Savepoint(name.value.to_ascii_lowercase()))
726        }
727        sp::Statement::ReleaseSavepoint { name } => {
728            Ok(Statement::ReleaseSavepoint(name.value.to_ascii_lowercase()))
729        }
730        sp::Statement::Set(sp::Set::SetTimeZone { value, .. }) => {
731            // Accept a string literal or bare identifier (PG allows `SET TIME ZONE UTC`).
732            let zone = match value {
733                sp::Expr::Value(v) => match &v.value {
734                    sp::Value::SingleQuotedString(s) => s.clone(),
735                    sp::Value::DoubleQuotedString(s) => s.clone(),
736                    other => other.to_string(),
737                },
738                sp::Expr::Identifier(ident) => ident.value.clone(),
739                other => {
740                    return Err(SqlError::Parse(format!(
741                        "SET TIME ZONE expects a string literal or identifier, got: {other}"
742                    )))
743                }
744            };
745            Ok(Statement::SetTimezone(zone))
746        }
747        sp::Statement::Explain {
748            statement, analyze, ..
749        } => {
750            if analyze {
751                return Err(SqlError::Unsupported("EXPLAIN ANALYZE".into()));
752            }
753            let inner = convert_statement(*statement)?;
754            Ok(Statement::Explain(Box::new(inner)))
755        }
756        _ => Err(SqlError::Unsupported(format!("statement type: {}", stmt))),
757    }
758}
759
760/// Parse column options (NOT NULL, DEFAULT, CHECK, FK, UNIQUE) from a sqlparser ColumnDef.
761/// Returns (ColumnSpec, Option<ForeignKeyDef>, was_inline_pk, was_unique).
762fn convert_column_def(
763    col_def: &sp::ColumnDef,
764) -> Result<(ColumnSpec, Option<ForeignKeyDef>, bool, bool)> {
765    let col_name = col_def.name.value.clone();
766    let data_type = convert_data_type(&col_def.data_type)?;
767    let mut nullable = true;
768    let mut is_primary_key = false;
769    let mut is_unique = false;
770    let mut default_expr = None;
771    let mut default_sql = None;
772    let mut check_expr = None;
773    let mut check_sql = None;
774    let mut check_name = None;
775    let mut fk_def = None;
776
777    for opt in &col_def.options {
778        match &opt.option {
779            sp::ColumnOption::NotNull => nullable = false,
780            sp::ColumnOption::Null => nullable = true,
781            sp::ColumnOption::PrimaryKey(_) => {
782                is_primary_key = true;
783                nullable = false;
784            }
785            sp::ColumnOption::Unique(_) => is_unique = true,
786            sp::ColumnOption::Default(expr) => {
787                default_sql = Some(expr.to_string());
788                default_expr = Some(convert_expr(expr)?);
789            }
790            sp::ColumnOption::Check(check) => {
791                check_sql = Some(check.expr.to_string());
792                let converted = convert_expr(&check.expr)?;
793                if has_subquery(&converted) {
794                    return Err(SqlError::Unsupported("subquery in CHECK constraint".into()));
795                }
796                check_expr = Some(converted);
797                check_name = check.name.as_ref().map(|n| n.value.clone());
798            }
799            sp::ColumnOption::ForeignKey(fk) => {
800                convert_fk_actions(&fk.on_delete, &fk.on_update)?;
801                let ftable = object_name_to_string(&fk.foreign_table).to_ascii_lowercase();
802                let referred: Vec<String> = fk
803                    .referred_columns
804                    .iter()
805                    .map(|i| i.value.to_ascii_lowercase())
806                    .collect();
807                fk_def = Some(ForeignKeyDef {
808                    name: fk.name.as_ref().map(|n| n.value.clone()),
809                    columns: vec![col_name.to_ascii_lowercase()],
810                    foreign_table: ftable,
811                    referred_columns: referred,
812                });
813            }
814            _ => {}
815        }
816    }
817
818    let spec = ColumnSpec {
819        name: col_name,
820        data_type,
821        nullable,
822        is_primary_key,
823        default_expr,
824        default_sql,
825        check_expr,
826        check_sql,
827        check_name,
828    };
829    Ok((spec, fk_def, is_primary_key, is_unique))
830}
831
832fn convert_create_table(ct: sp::CreateTable) -> Result<Statement> {
833    let name = object_name_to_string(&ct.name);
834    let if_not_exists = ct.if_not_exists;
835
836    let mut columns = Vec::new();
837    let mut inline_pk: Vec<String> = Vec::new();
838    let mut foreign_keys: Vec<ForeignKeyDef> = Vec::new();
839    let mut unique_indices: Vec<UniqueIndexDef> = Vec::new();
840
841    for col_def in &ct.columns {
842        let (spec, fk_def, was_pk, was_unique) = convert_column_def(col_def)?;
843        if was_pk {
844            inline_pk.push(spec.name.clone());
845        }
846        if let Some(fk) = fk_def {
847            foreign_keys.push(fk);
848        }
849        if was_unique && !was_pk {
850            unique_indices.push(UniqueIndexDef {
851                name: None,
852                columns: vec![spec.name.to_ascii_lowercase()],
853            });
854        }
855        columns.push(spec);
856    }
857
858    let mut check_constraints: Vec<TableCheckConstraint> = Vec::new();
859
860    for constraint in &ct.constraints {
861        match constraint {
862            sp::TableConstraint::PrimaryKey(pk_constraint) => {
863                for idx_col in &pk_constraint.columns {
864                    let col_name = match &idx_col.column.expr {
865                        sp::Expr::Identifier(ident) => ident.value.clone(),
866                        _ => continue,
867                    };
868                    if !inline_pk.contains(&col_name) {
869                        inline_pk.push(col_name.clone());
870                    }
871                    if let Some(col) = columns.iter_mut().find(|c| c.name == col_name) {
872                        col.nullable = false;
873                        col.is_primary_key = true;
874                    }
875                }
876            }
877            sp::TableConstraint::Check(check) => {
878                let sql = check.expr.to_string();
879                let converted = convert_expr(&check.expr)?;
880                if has_subquery(&converted) {
881                    return Err(SqlError::Unsupported("subquery in CHECK constraint".into()));
882                }
883                check_constraints.push(TableCheckConstraint {
884                    name: check.name.as_ref().map(|n| n.value.clone()),
885                    expr: converted,
886                    sql,
887                });
888            }
889            sp::TableConstraint::ForeignKey(fk) => {
890                convert_fk_actions(&fk.on_delete, &fk.on_update)?;
891                let cols: Vec<String> = fk
892                    .columns
893                    .iter()
894                    .map(|i| i.value.to_ascii_lowercase())
895                    .collect();
896                let ftable = object_name_to_string(&fk.foreign_table).to_ascii_lowercase();
897                let referred: Vec<String> = fk
898                    .referred_columns
899                    .iter()
900                    .map(|i| i.value.to_ascii_lowercase())
901                    .collect();
902                foreign_keys.push(ForeignKeyDef {
903                    name: fk.name.as_ref().map(|n| n.value.clone()),
904                    columns: cols,
905                    foreign_table: ftable,
906                    referred_columns: referred,
907                });
908            }
909            sp::TableConstraint::Unique(u) => {
910                let cols: Vec<String> = u
911                    .columns
912                    .iter()
913                    .filter_map(|idx_col| match &idx_col.column.expr {
914                        sp::Expr::Identifier(ident) => Some(ident.value.to_ascii_lowercase()),
915                        _ => None,
916                    })
917                    .collect();
918                if !cols.is_empty() {
919                    unique_indices.push(UniqueIndexDef {
920                        name: u.name.as_ref().map(|n| n.value.clone()),
921                        columns: cols,
922                    });
923                }
924            }
925            _ => {}
926        }
927    }
928
929    Ok(Statement::CreateTable(CreateTableStmt {
930        name,
931        columns,
932        primary_key: inline_pk,
933        if_not_exists,
934        check_constraints,
935        foreign_keys,
936        unique_indices,
937    }))
938}
939
940fn convert_alter_table(at: sp::AlterTable) -> Result<Statement> {
941    let table = object_name_to_string(&at.name);
942    if at.operations.len() != 1 {
943        return Err(SqlError::Unsupported(
944            "ALTER TABLE with multiple operations".into(),
945        ));
946    }
947    let op = match at.operations.into_iter().next().unwrap() {
948        sp::AlterTableOperation::AddColumn {
949            column_def,
950            if_not_exists,
951            ..
952        } => {
953            let (spec, fk, _was_pk, _was_unique) = convert_column_def(&column_def)?;
954            AlterTableOp::AddColumn {
955                column: Box::new(spec),
956                foreign_key: fk,
957                if_not_exists,
958            }
959        }
960        sp::AlterTableOperation::DropColumn {
961            column_names,
962            if_exists,
963            ..
964        } => {
965            if column_names.len() != 1 {
966                return Err(SqlError::Unsupported(
967                    "DROP COLUMN with multiple columns".into(),
968                ));
969            }
970            AlterTableOp::DropColumn {
971                name: column_names.into_iter().next().unwrap().value,
972                if_exists,
973            }
974        }
975        sp::AlterTableOperation::RenameColumn {
976            old_column_name,
977            new_column_name,
978        } => AlterTableOp::RenameColumn {
979            old_name: old_column_name.value,
980            new_name: new_column_name.value,
981        },
982        sp::AlterTableOperation::RenameTable { table_name } => {
983            let new_name = match table_name {
984                sp::RenameTableNameKind::To(name) | sp::RenameTableNameKind::As(name) => {
985                    object_name_to_string(&name)
986                }
987            };
988            AlterTableOp::RenameTable { new_name }
989        }
990        other => {
991            return Err(SqlError::Unsupported(format!(
992                "ALTER TABLE operation: {other}"
993            )));
994        }
995    };
996    Ok(Statement::AlterTable(Box::new(AlterTableStmt {
997        table,
998        op,
999    })))
1000}
1001
1002fn convert_fk_actions(
1003    on_delete: &Option<sp::ReferentialAction>,
1004    on_update: &Option<sp::ReferentialAction>,
1005) -> Result<()> {
1006    for action in [on_delete, on_update] {
1007        match action {
1008            None
1009            | Some(sp::ReferentialAction::Restrict)
1010            | Some(sp::ReferentialAction::NoAction) => {}
1011            Some(other) => {
1012                return Err(SqlError::Unsupported(format!(
1013                    "FOREIGN KEY action: {other}"
1014                )));
1015            }
1016        }
1017    }
1018    Ok(())
1019}
1020
1021fn convert_create_index(ci: sp::CreateIndex) -> Result<Statement> {
1022    let index_name = ci
1023        .name
1024        .as_ref()
1025        .map(object_name_to_string)
1026        .ok_or_else(|| SqlError::Parse("index name required".into()))?;
1027
1028    let table_name = object_name_to_string(&ci.table_name);
1029
1030    let columns: Vec<String> = ci
1031        .columns
1032        .iter()
1033        .map(|idx_col| match &idx_col.column.expr {
1034            sp::Expr::Identifier(ident) => Ok(ident.value.clone()),
1035            other => Err(SqlError::Unsupported(format!("expression index: {other}"))),
1036        })
1037        .collect::<Result<_>>()?;
1038
1039    if columns.is_empty() {
1040        return Err(SqlError::Parse(
1041            "index must have at least one column".into(),
1042        ));
1043    }
1044
1045    Ok(Statement::CreateIndex(CreateIndexStmt {
1046        index_name,
1047        table_name,
1048        columns,
1049        unique: ci.unique,
1050        if_not_exists: ci.if_not_exists,
1051    }))
1052}
1053
1054fn convert_create_view(cv: sp::CreateView) -> Result<Statement> {
1055    let name = object_name_to_string(&cv.name);
1056
1057    if cv.materialized {
1058        return Err(SqlError::Unsupported("MATERIALIZED VIEW".into()));
1059    }
1060
1061    let sql = cv.query.to_string();
1062
1063    let dialect = GenericDialect {};
1064    let test = Parser::parse_sql(&dialect, &sql).map_err(|e| SqlError::Parse(e.to_string()))?;
1065    if test.is_empty() {
1066        return Err(SqlError::Parse("empty view definition".into()));
1067    }
1068    match &test[0] {
1069        sp::Statement::Query(_) => {}
1070        _ => {
1071            return Err(SqlError::Parse(
1072                "view body must be a SELECT statement".into(),
1073            ))
1074        }
1075    }
1076
1077    let column_aliases: Vec<String> = cv
1078        .columns
1079        .iter()
1080        .map(|c| c.name.value.to_ascii_lowercase())
1081        .collect();
1082
1083    Ok(Statement::CreateView(CreateViewStmt {
1084        name,
1085        sql,
1086        column_aliases,
1087        or_replace: cv.or_replace,
1088        if_not_exists: cv.if_not_exists,
1089    }))
1090}
1091
1092fn convert_insert(insert: sp::Insert) -> Result<Statement> {
1093    let table = match &insert.table {
1094        sp::TableObject::TableName(name) => object_name_to_string(name).to_ascii_lowercase(),
1095        _ => return Err(SqlError::Unsupported("INSERT into non-table object".into())),
1096    };
1097
1098    let columns: Vec<String> = insert
1099        .columns
1100        .iter()
1101        .map(|c| c.value.to_ascii_lowercase())
1102        .collect();
1103
1104    let query = insert
1105        .source
1106        .ok_or_else(|| SqlError::Parse("INSERT requires VALUES or SELECT".into()))?;
1107
1108    let source = match *query.body {
1109        sp::SetExpr::Values(sp::Values { rows, .. }) => {
1110            let mut result = Vec::new();
1111            for row in rows {
1112                let mut exprs = Vec::new();
1113                for expr in row {
1114                    exprs.push(convert_expr(&expr)?);
1115                }
1116                result.push(exprs);
1117            }
1118            InsertSource::Values(result)
1119        }
1120        _ => {
1121            let (ctes, recursive) = if let Some(ref with) = query.with {
1122                convert_with(with)?
1123            } else {
1124                (vec![], false)
1125            };
1126            let body = convert_query_body(&query)?;
1127            InsertSource::Select(Box::new(SelectQuery {
1128                ctes,
1129                recursive,
1130                body,
1131            }))
1132        }
1133    };
1134
1135    Ok(Statement::Insert(InsertStmt {
1136        table,
1137        columns,
1138        source,
1139    }))
1140}
1141
1142fn convert_select_body(select: &sp::Select) -> Result<SelectStmt> {
1143    let distinct = match &select.distinct {
1144        Some(sp::Distinct::Distinct) => true,
1145        Some(sp::Distinct::On(_)) => {
1146            return Err(SqlError::Unsupported("DISTINCT ON".into()));
1147        }
1148        _ => false,
1149    };
1150
1151    let (from, from_alias, joins) = if select.from.is_empty() {
1152        (String::new(), None, vec![])
1153    } else if select.from.len() == 1 {
1154        let table_with_joins = &select.from[0];
1155        let (name, alias) = match &table_with_joins.relation {
1156            sp::TableFactor::Table { name, alias, .. } => {
1157                let table_name = object_name_to_string(name);
1158                let alias_str = alias.as_ref().map(|a| a.name.value.clone());
1159                (table_name, alias_str)
1160            }
1161            _ => return Err(SqlError::Unsupported("non-table FROM source".into())),
1162        };
1163        let j = table_with_joins
1164            .joins
1165            .iter()
1166            .map(convert_join)
1167            .collect::<Result<Vec<_>>>()?;
1168        (name, alias, j)
1169    } else {
1170        return Err(SqlError::Unsupported("comma-separated FROM tables".into()));
1171    };
1172
1173    let columns: Vec<SelectColumn> = select
1174        .projection
1175        .iter()
1176        .map(convert_select_item)
1177        .collect::<Result<_>>()?;
1178
1179    let where_clause = select.selection.as_ref().map(convert_expr).transpose()?;
1180
1181    let group_by = match &select.group_by {
1182        sp::GroupByExpr::Expressions(exprs, _) => {
1183            exprs.iter().map(convert_expr).collect::<Result<_>>()?
1184        }
1185        sp::GroupByExpr::All(_) => {
1186            return Err(SqlError::Unsupported("GROUP BY ALL".into()));
1187        }
1188    };
1189
1190    let having = select.having.as_ref().map(convert_expr).transpose()?;
1191
1192    Ok(SelectStmt {
1193        columns,
1194        from,
1195        from_alias,
1196        joins,
1197        distinct,
1198        where_clause,
1199        order_by: vec![],
1200        limit: None,
1201        offset: None,
1202        group_by,
1203        having,
1204    })
1205}
1206
1207fn convert_set_expr(set_expr: &sp::SetExpr) -> Result<QueryBody> {
1208    match set_expr {
1209        sp::SetExpr::Select(sel) => Ok(QueryBody::Select(Box::new(convert_select_body(sel)?))),
1210        sp::SetExpr::SetOperation {
1211            op,
1212            set_quantifier,
1213            left,
1214            right,
1215        } => {
1216            let set_op = match op {
1217                sp::SetOperator::Union => SetOp::Union,
1218                sp::SetOperator::Intersect => SetOp::Intersect,
1219                sp::SetOperator::Except | sp::SetOperator::Minus => SetOp::Except,
1220            };
1221            let all = match set_quantifier {
1222                sp::SetQuantifier::All => true,
1223                sp::SetQuantifier::None | sp::SetQuantifier::Distinct => false,
1224                _ => {
1225                    return Err(SqlError::Unsupported("BY NAME set operations".into()));
1226                }
1227            };
1228            Ok(QueryBody::Compound(Box::new(CompoundSelect {
1229                op: set_op,
1230                all,
1231                left: Box::new(convert_set_expr(left)?),
1232                right: Box::new(convert_set_expr(right)?),
1233                order_by: vec![],
1234                limit: None,
1235                offset: None,
1236            })))
1237        }
1238        _ => Err(SqlError::Unsupported("unsupported set expression".into())),
1239    }
1240}
1241
1242fn convert_query_body(query: &sp::Query) -> Result<QueryBody> {
1243    let mut body = convert_set_expr(&query.body)?;
1244
1245    let order_by = if let Some(ref ob) = query.order_by {
1246        match &ob.kind {
1247            sp::OrderByKind::Expressions(exprs) => exprs
1248                .iter()
1249                .map(convert_order_by_expr)
1250                .collect::<Result<_>>()?,
1251            sp::OrderByKind::All { .. } => {
1252                return Err(SqlError::Unsupported("ORDER BY ALL".into()));
1253            }
1254        }
1255    } else {
1256        vec![]
1257    };
1258
1259    let (limit, offset) = match &query.limit_clause {
1260        Some(sp::LimitClause::LimitOffset { limit, offset, .. }) => {
1261            let l = limit.as_ref().map(convert_expr).transpose()?;
1262            let o = offset
1263                .as_ref()
1264                .map(|o| convert_expr(&o.value))
1265                .transpose()?;
1266            (l, o)
1267        }
1268        Some(sp::LimitClause::OffsetCommaLimit { limit, offset }) => {
1269            let l = Some(convert_expr(limit)?);
1270            let o = Some(convert_expr(offset)?);
1271            (l, o)
1272        }
1273        None => (None, None),
1274    };
1275
1276    match &mut body {
1277        QueryBody::Select(sel) => {
1278            sel.order_by = order_by;
1279            sel.limit = limit;
1280            sel.offset = offset;
1281        }
1282        QueryBody::Compound(comp) => {
1283            comp.order_by = order_by;
1284            comp.limit = limit;
1285            comp.offset = offset;
1286        }
1287    }
1288
1289    Ok(body)
1290}
1291
1292fn convert_subquery(query: &sp::Query) -> Result<SelectStmt> {
1293    if query.with.is_some() {
1294        return Err(SqlError::Unsupported("CTEs in subqueries".into()));
1295    }
1296    match convert_query_body(query)? {
1297        QueryBody::Select(s) => Ok(*s),
1298        QueryBody::Compound(_) => Err(SqlError::Unsupported(
1299            "UNION/INTERSECT/EXCEPT in subqueries".into(),
1300        )),
1301    }
1302}
1303
1304fn convert_with(with: &sp::With) -> Result<(Vec<CteDefinition>, bool)> {
1305    let mut names = std::collections::HashSet::new();
1306    let mut ctes = Vec::new();
1307    for cte in &with.cte_tables {
1308        let name = cte.alias.name.value.to_ascii_lowercase();
1309        if !names.insert(name.clone()) {
1310            return Err(SqlError::DuplicateCteName(name));
1311        }
1312        let column_aliases: Vec<String> = cte
1313            .alias
1314            .columns
1315            .iter()
1316            .map(|c| c.name.value.to_ascii_lowercase())
1317            .collect();
1318        let body = convert_query_body(&cte.query)?;
1319        ctes.push(CteDefinition {
1320            name,
1321            column_aliases,
1322            body,
1323        });
1324    }
1325    Ok((ctes, with.recursive))
1326}
1327
1328fn convert_query(query: sp::Query) -> Result<Statement> {
1329    let (ctes, recursive) = if let Some(ref with) = query.with {
1330        convert_with(with)?
1331    } else {
1332        (vec![], false)
1333    };
1334    let body = convert_query_body(&query)?;
1335    Ok(Statement::Select(Box::new(SelectQuery {
1336        ctes,
1337        recursive,
1338        body,
1339    })))
1340}
1341
1342fn convert_join(join: &sp::Join) -> Result<JoinClause> {
1343    let (join_type, constraint) = match &join.join_operator {
1344        sp::JoinOperator::Inner(c) => (JoinType::Inner, Some(c)),
1345        sp::JoinOperator::Join(c) => (JoinType::Inner, Some(c)),
1346        sp::JoinOperator::CrossJoin(c) => (JoinType::Cross, Some(c)),
1347        sp::JoinOperator::Left(c) => (JoinType::Left, Some(c)),
1348        sp::JoinOperator::LeftSemi(c) => (JoinType::Left, Some(c)),
1349        sp::JoinOperator::LeftAnti(c) => (JoinType::Left, Some(c)),
1350        sp::JoinOperator::Right(c) => (JoinType::Right, Some(c)),
1351        sp::JoinOperator::RightSemi(c) => (JoinType::Right, Some(c)),
1352        sp::JoinOperator::RightAnti(c) => (JoinType::Right, Some(c)),
1353        other => return Err(SqlError::Unsupported(format!("join type: {other:?}"))),
1354    };
1355
1356    let (name, alias) = match &join.relation {
1357        sp::TableFactor::Table { name, alias, .. } => {
1358            let table_name = object_name_to_string(name);
1359            let alias_str = alias.as_ref().map(|a| a.name.value.clone());
1360            (table_name, alias_str)
1361        }
1362        _ => return Err(SqlError::Unsupported("non-table JOIN source".into())),
1363    };
1364
1365    let on_clause = match constraint {
1366        Some(sp::JoinConstraint::On(expr)) => Some(convert_expr(expr)?),
1367        Some(sp::JoinConstraint::None) | None => None,
1368        Some(other) => return Err(SqlError::Unsupported(format!("join constraint: {other:?}"))),
1369    };
1370
1371    Ok(JoinClause {
1372        join_type,
1373        table: TableRef { name, alias },
1374        on_clause,
1375    })
1376}
1377
1378fn convert_update(update: sp::Update) -> Result<Statement> {
1379    let table = match &update.table.relation {
1380        sp::TableFactor::Table { name, .. } => object_name_to_string(name),
1381        _ => return Err(SqlError::Unsupported("non-table UPDATE target".into())),
1382    };
1383
1384    let assignments = update
1385        .assignments
1386        .iter()
1387        .map(|a| {
1388            let col = match &a.target {
1389                sp::AssignmentTarget::ColumnName(name) => object_name_to_string(name),
1390                _ => return Err(SqlError::Unsupported("tuple assignment".into())),
1391            };
1392            let expr = convert_expr(&a.value)?;
1393            Ok((col, expr))
1394        })
1395        .collect::<Result<_>>()?;
1396
1397    let where_clause = update.selection.as_ref().map(convert_expr).transpose()?;
1398
1399    Ok(Statement::Update(UpdateStmt {
1400        table,
1401        assignments,
1402        where_clause,
1403    }))
1404}
1405
1406fn convert_delete(delete: sp::Delete) -> Result<Statement> {
1407    let table_name = match &delete.from {
1408        sp::FromTable::WithFromKeyword(tables) => {
1409            if tables.len() != 1 {
1410                return Err(SqlError::Unsupported("multi-table DELETE".into()));
1411            }
1412            match &tables[0].relation {
1413                sp::TableFactor::Table { name, .. } => object_name_to_string(name),
1414                _ => return Err(SqlError::Unsupported("non-table DELETE target".into())),
1415            }
1416        }
1417        sp::FromTable::WithoutKeyword(tables) => {
1418            if tables.len() != 1 {
1419                return Err(SqlError::Unsupported("multi-table DELETE".into()));
1420            }
1421            match &tables[0].relation {
1422                sp::TableFactor::Table { name, .. } => object_name_to_string(name),
1423                _ => return Err(SqlError::Unsupported("non-table DELETE target".into())),
1424            }
1425        }
1426    };
1427
1428    let where_clause = delete.selection.as_ref().map(convert_expr).transpose()?;
1429
1430    Ok(Statement::Delete(DeleteStmt {
1431        table: table_name,
1432        where_clause,
1433    }))
1434}
1435
1436fn convert_expr(expr: &sp::Expr) -> Result<Expr> {
1437    match expr {
1438        sp::Expr::Value(v) => convert_value(&v.value),
1439        sp::Expr::Identifier(ident) => Ok(Expr::Column(ident.value.to_ascii_lowercase())),
1440        sp::Expr::CompoundIdentifier(parts) => {
1441            if parts.len() == 2 {
1442                Ok(Expr::QualifiedColumn {
1443                    table: parts[0].value.to_ascii_lowercase(),
1444                    column: parts[1].value.to_ascii_lowercase(),
1445                })
1446            } else {
1447                Ok(Expr::Column(
1448                    parts.last().unwrap().value.to_ascii_lowercase(),
1449                ))
1450            }
1451        }
1452        sp::Expr::BinaryOp { left, op, right } => {
1453            let bin_op = convert_bin_op(op)?;
1454            Ok(Expr::BinaryOp {
1455                left: Box::new(convert_expr(left)?),
1456                op: bin_op,
1457                right: Box::new(convert_expr(right)?),
1458            })
1459        }
1460        sp::Expr::UnaryOp { op, expr } => {
1461            let unary_op = match op {
1462                sp::UnaryOperator::Minus => UnaryOp::Neg,
1463                sp::UnaryOperator::Not => UnaryOp::Not,
1464                _ => return Err(SqlError::Unsupported(format!("unary op: {op}"))),
1465            };
1466            Ok(Expr::UnaryOp {
1467                op: unary_op,
1468                expr: Box::new(convert_expr(expr)?),
1469            })
1470        }
1471        sp::Expr::IsNull(e) => Ok(Expr::IsNull(Box::new(convert_expr(e)?))),
1472        sp::Expr::IsNotNull(e) => Ok(Expr::IsNotNull(Box::new(convert_expr(e)?))),
1473        sp::Expr::Nested(e) => convert_expr(e),
1474        sp::Expr::Function(func) => convert_function(func),
1475        sp::Expr::InSubquery {
1476            expr: e,
1477            subquery,
1478            negated,
1479        } => {
1480            let inner_expr = convert_expr(e)?;
1481            let stmt = convert_subquery(subquery)?;
1482            Ok(Expr::InSubquery {
1483                expr: Box::new(inner_expr),
1484                subquery: Box::new(stmt),
1485                negated: *negated,
1486            })
1487        }
1488        sp::Expr::InList {
1489            expr: e,
1490            list,
1491            negated,
1492        } => {
1493            let inner_expr = convert_expr(e)?;
1494            let items = list.iter().map(convert_expr).collect::<Result<Vec<_>>>()?;
1495            Ok(Expr::InList {
1496                expr: Box::new(inner_expr),
1497                list: items,
1498                negated: *negated,
1499            })
1500        }
1501        sp::Expr::Exists { subquery, negated } => {
1502            let stmt = convert_subquery(subquery)?;
1503            Ok(Expr::Exists {
1504                subquery: Box::new(stmt),
1505                negated: *negated,
1506            })
1507        }
1508        sp::Expr::Subquery(query) => {
1509            let stmt = convert_subquery(query)?;
1510            Ok(Expr::ScalarSubquery(Box::new(stmt)))
1511        }
1512        sp::Expr::Between {
1513            expr: e,
1514            negated,
1515            low,
1516            high,
1517        } => Ok(Expr::Between {
1518            expr: Box::new(convert_expr(e)?),
1519            low: Box::new(convert_expr(low)?),
1520            high: Box::new(convert_expr(high)?),
1521            negated: *negated,
1522        }),
1523        sp::Expr::Like {
1524            expr: e,
1525            negated,
1526            pattern,
1527            escape_char,
1528            ..
1529        } => {
1530            let esc = escape_char
1531                .as_ref()
1532                .map(convert_escape_value)
1533                .transpose()?
1534                .map(Box::new);
1535            Ok(Expr::Like {
1536                expr: Box::new(convert_expr(e)?),
1537                pattern: Box::new(convert_expr(pattern)?),
1538                escape: esc,
1539                negated: *negated,
1540            })
1541        }
1542        sp::Expr::ILike {
1543            expr: e,
1544            negated,
1545            pattern,
1546            escape_char,
1547            ..
1548        } => {
1549            let esc = escape_char
1550                .as_ref()
1551                .map(convert_escape_value)
1552                .transpose()?
1553                .map(Box::new);
1554            Ok(Expr::Like {
1555                expr: Box::new(convert_expr(e)?),
1556                pattern: Box::new(convert_expr(pattern)?),
1557                escape: esc,
1558                negated: *negated,
1559            })
1560        }
1561        sp::Expr::Case {
1562            operand,
1563            conditions,
1564            else_result,
1565            ..
1566        } => {
1567            let op = operand
1568                .as_ref()
1569                .map(|e| convert_expr(e))
1570                .transpose()?
1571                .map(Box::new);
1572            let conds: Vec<(Expr, Expr)> = conditions
1573                .iter()
1574                .map(|cw| Ok((convert_expr(&cw.condition)?, convert_expr(&cw.result)?)))
1575                .collect::<Result<_>>()?;
1576            let else_r = else_result
1577                .as_ref()
1578                .map(|e| convert_expr(e))
1579                .transpose()?
1580                .map(Box::new);
1581            Ok(Expr::Case {
1582                operand: op,
1583                conditions: conds,
1584                else_result: else_r,
1585            })
1586        }
1587        sp::Expr::Cast {
1588            expr: e,
1589            data_type: dt,
1590            ..
1591        } => {
1592            let target = convert_data_type(dt)?;
1593            Ok(Expr::Cast {
1594                expr: Box::new(convert_expr(e)?),
1595                data_type: target,
1596            })
1597        }
1598        sp::Expr::Substring {
1599            expr: e,
1600            substring_from,
1601            substring_for,
1602            ..
1603        } => {
1604            let mut args = vec![convert_expr(e)?];
1605            if let Some(from) = substring_from {
1606                args.push(convert_expr(from)?);
1607            }
1608            if let Some(f) = substring_for {
1609                args.push(convert_expr(f)?);
1610            }
1611            Ok(Expr::Function {
1612                name: "SUBSTR".into(),
1613                args,
1614                distinct: false,
1615            })
1616        }
1617        sp::Expr::Trim {
1618            expr: e,
1619            trim_where,
1620            trim_what,
1621            trim_characters,
1622        } => {
1623            let fn_name = match trim_where {
1624                Some(sp::TrimWhereField::Leading) => "LTRIM",
1625                Some(sp::TrimWhereField::Trailing) => "RTRIM",
1626                _ => "TRIM",
1627            };
1628            let mut args = vec![convert_expr(e)?];
1629            if let Some(what) = trim_what {
1630                args.push(convert_expr(what)?);
1631            } else if let Some(chars) = trim_characters {
1632                if let Some(first) = chars.first() {
1633                    args.push(convert_expr(first)?);
1634                }
1635            }
1636            Ok(Expr::Function {
1637                name: fn_name.into(),
1638                args,
1639                distinct: false,
1640            })
1641        }
1642        sp::Expr::Ceil { expr: e, .. } => Ok(Expr::Function {
1643            name: "CEIL".into(),
1644            args: vec![convert_expr(e)?],
1645            distinct: false,
1646        }),
1647        sp::Expr::Floor { expr: e, .. } => Ok(Expr::Function {
1648            name: "FLOOR".into(),
1649            args: vec![convert_expr(e)?],
1650            distinct: false,
1651        }),
1652        sp::Expr::Position { expr: e, r#in } => Ok(Expr::Function {
1653            name: "INSTR".into(),
1654            args: vec![convert_expr(r#in)?, convert_expr(e)?],
1655            distinct: false,
1656        }),
1657        // Typed literal: `DATE '2024-01-15'`, `TIMESTAMP '...'`, etc.
1658        sp::Expr::TypedString(ts) => {
1659            let raw = match &ts.value.value {
1660                sp::Value::SingleQuotedString(s) => s.clone(),
1661                sp::Value::DoubleQuotedString(s) => s.clone(),
1662                other => other.to_string(),
1663            };
1664            convert_typed_string(&ts.data_type, &raw)
1665        }
1666        // INTERVAL '...' — sqlparser emits Expr::Interval with a boxed Expr value + field qualifiers
1667        sp::Expr::Interval(iv) => convert_interval_expr(iv),
1668        // EXTRACT(field FROM src)
1669        sp::Expr::Extract { field, expr: e, .. } => {
1670            let field_name = match field {
1671                sp::DateTimeField::Year => "year",
1672                sp::DateTimeField::Month => "month",
1673                sp::DateTimeField::Week(_) => "week",
1674                sp::DateTimeField::Day => "day",
1675                sp::DateTimeField::Date => "day",
1676                sp::DateTimeField::Hour => "hour",
1677                sp::DateTimeField::Minute => "minute",
1678                sp::DateTimeField::Second => "second",
1679                sp::DateTimeField::Millisecond => "milliseconds",
1680                sp::DateTimeField::Microsecond => "microseconds",
1681                sp::DateTimeField::Microseconds => "microseconds",
1682                sp::DateTimeField::Milliseconds => "milliseconds",
1683                sp::DateTimeField::Dow => "dow",
1684                sp::DateTimeField::Isodow => "isodow",
1685                sp::DateTimeField::Doy => "doy",
1686                sp::DateTimeField::Epoch => "epoch",
1687                sp::DateTimeField::Quarter => "quarter",
1688                sp::DateTimeField::Decade => "decade",
1689                sp::DateTimeField::Century => "century",
1690                sp::DateTimeField::Millennium => "millennium",
1691                sp::DateTimeField::Isoyear => "isoyear",
1692                sp::DateTimeField::Julian => "julian",
1693                other => {
1694                    return Err(SqlError::InvalidExtractField(format!("{other:?}")));
1695                }
1696            };
1697            Ok(Expr::Function {
1698                name: "EXTRACT".into(),
1699                args: vec![
1700                    Expr::Literal(Value::Text(field_name.into())),
1701                    convert_expr(e)?,
1702                ],
1703                distinct: false,
1704            })
1705        }
1706        // `AT TIME ZONE 'zone'` operator — desugars to AT_TIMEZONE(ts, zone) scalar function.
1707        sp::Expr::AtTimeZone {
1708            timestamp,
1709            time_zone,
1710        } => Ok(Expr::Function {
1711            name: "AT_TIMEZONE".into(),
1712            args: vec![convert_expr(timestamp)?, convert_expr(time_zone)?],
1713            distinct: false,
1714        }),
1715        _ => Err(SqlError::Unsupported(format!("expression: {expr}"))),
1716    }
1717}
1718
1719fn convert_value(val: &sp::Value) -> Result<Expr> {
1720    match val {
1721        sp::Value::Number(n, _) => {
1722            if let Ok(i) = n.parse::<i64>() {
1723                Ok(Expr::Literal(Value::Integer(i)))
1724            } else if let Ok(f) = n.parse::<f64>() {
1725                Ok(Expr::Literal(Value::Real(f)))
1726            } else {
1727                Err(SqlError::InvalidValue(format!("cannot parse number: {n}")))
1728            }
1729        }
1730        sp::Value::SingleQuotedString(s) => Ok(Expr::Literal(Value::Text(s.as_str().into()))),
1731        sp::Value::Boolean(b) => Ok(Expr::Literal(Value::Boolean(*b))),
1732        sp::Value::Null => Ok(Expr::Literal(Value::Null)),
1733        sp::Value::Placeholder(s) => {
1734            let idx_str = s
1735                .strip_prefix('$')
1736                .ok_or_else(|| SqlError::Parse(format!("invalid placeholder: {s}")))?;
1737            let idx: usize = idx_str
1738                .parse()
1739                .map_err(|_| SqlError::Parse(format!("invalid placeholder index: {s}")))?;
1740            if idx == 0 {
1741                return Err(SqlError::Parse("placeholder index must be >= 1".into()));
1742            }
1743            Ok(Expr::Parameter(idx))
1744        }
1745        _ => Err(SqlError::Unsupported(format!("value type: {val}"))),
1746    }
1747}
1748
1749fn convert_typed_string(dt: &sp::DataType, value: &str) -> Result<Expr> {
1750    let s = value.trim_matches('\'');
1751    match dt {
1752        sp::DataType::Date => {
1753            let d = crate::datetime::parse_date(s)?;
1754            Ok(Expr::Literal(Value::Date(d)))
1755        }
1756        sp::DataType::Time(_, _) => {
1757            let t = crate::datetime::parse_time(s)?;
1758            Ok(Expr::Literal(Value::Time(t)))
1759        }
1760        sp::DataType::Timestamp(_, _) => {
1761            let t = crate::datetime::parse_timestamp(s)?;
1762            Ok(Expr::Literal(Value::Timestamp(t)))
1763        }
1764        sp::DataType::Interval { .. } => {
1765            let (months, days, micros) = crate::datetime::parse_interval(s)?;
1766            Ok(Expr::Literal(Value::Interval {
1767                months,
1768                days,
1769                micros,
1770            }))
1771        }
1772        _ => {
1773            let target = convert_data_type(dt)?;
1774            Ok(Expr::Cast {
1775                expr: Box::new(Expr::Literal(Value::Text(s.into()))),
1776                data_type: target,
1777            })
1778        }
1779    }
1780}
1781
1782fn convert_interval_expr(iv: &sp::Interval) -> Result<Expr> {
1783    let raw = match iv.value.as_ref() {
1784        sp::Expr::Value(v) => match &v.value {
1785            sp::Value::SingleQuotedString(s) => s.clone(),
1786            sp::Value::Number(n, _) => n.clone(),
1787            other => {
1788                return Err(SqlError::InvalidIntervalLiteral(format!(
1789                    "unsupported inner value: {other}"
1790                )))
1791            }
1792        },
1793        other => {
1794            return Err(SqlError::InvalidIntervalLiteral(format!(
1795                "unsupported inner expr: {other}"
1796            )))
1797        }
1798    };
1799
1800    // SQL-standard form `INTERVAL '5' DAY` — append the unit to the literal.
1801    let with_unit = if let Some(field) = &iv.leading_field {
1802        let unit_name = match field {
1803            sp::DateTimeField::Year => "years",
1804            sp::DateTimeField::Month => "months",
1805            sp::DateTimeField::Week(_) => "weeks",
1806            sp::DateTimeField::Day => "days",
1807            sp::DateTimeField::Hour => "hours",
1808            sp::DateTimeField::Minute => "minutes",
1809            sp::DateTimeField::Second => "seconds",
1810            _ => {
1811                return Err(SqlError::InvalidIntervalLiteral(format!(
1812                    "unsupported leading field: {field:?}"
1813                )))
1814            }
1815        };
1816        format!("{raw} {unit_name}")
1817    } else {
1818        raw
1819    };
1820
1821    let (months, days, micros) = crate::datetime::parse_interval(&with_unit)?;
1822    Ok(Expr::Literal(Value::Interval {
1823        months,
1824        days,
1825        micros,
1826    }))
1827}
1828
1829fn convert_escape_value(val: &sp::Value) -> Result<Expr> {
1830    match val {
1831        sp::Value::SingleQuotedString(s) => Ok(Expr::Literal(Value::Text(s.as_str().into()))),
1832        _ => Err(SqlError::Unsupported(format!("ESCAPE value: {val}"))),
1833    }
1834}
1835
1836fn convert_bin_op(op: &sp::BinaryOperator) -> Result<BinOp> {
1837    match op {
1838        sp::BinaryOperator::Plus => Ok(BinOp::Add),
1839        sp::BinaryOperator::Minus => Ok(BinOp::Sub),
1840        sp::BinaryOperator::Multiply => Ok(BinOp::Mul),
1841        sp::BinaryOperator::Divide => Ok(BinOp::Div),
1842        sp::BinaryOperator::Modulo => Ok(BinOp::Mod),
1843        sp::BinaryOperator::Eq => Ok(BinOp::Eq),
1844        sp::BinaryOperator::NotEq => Ok(BinOp::NotEq),
1845        sp::BinaryOperator::Lt => Ok(BinOp::Lt),
1846        sp::BinaryOperator::Gt => Ok(BinOp::Gt),
1847        sp::BinaryOperator::LtEq => Ok(BinOp::LtEq),
1848        sp::BinaryOperator::GtEq => Ok(BinOp::GtEq),
1849        sp::BinaryOperator::And => Ok(BinOp::And),
1850        sp::BinaryOperator::Or => Ok(BinOp::Or),
1851        sp::BinaryOperator::StringConcat => Ok(BinOp::Concat),
1852        _ => Err(SqlError::Unsupported(format!("binary op: {op}"))),
1853    }
1854}
1855
1856fn convert_function(func: &sp::Function) -> Result<Expr> {
1857    let name = object_name_to_string(&func.name).to_ascii_uppercase();
1858
1859    let (args, is_count_star, distinct) = match &func.args {
1860        sp::FunctionArguments::List(list) => {
1861            let distinct = matches!(
1862                list.duplicate_treatment,
1863                Some(sp::DuplicateTreatment::Distinct)
1864            );
1865            if list.args.is_empty() && name == "COUNT" {
1866                (vec![], true, distinct)
1867            } else {
1868                let mut count_star = false;
1869                let args = list
1870                    .args
1871                    .iter()
1872                    .map(|arg| match arg {
1873                        sp::FunctionArg::Unnamed(sp::FunctionArgExpr::Expr(e)) => convert_expr(e),
1874                        sp::FunctionArg::Unnamed(sp::FunctionArgExpr::Wildcard) => {
1875                            if name == "COUNT" {
1876                                count_star = true;
1877                                Ok(Expr::CountStar)
1878                            } else {
1879                                Err(SqlError::Unsupported(format!("{name}(*)")))
1880                            }
1881                        }
1882                        _ => Err(SqlError::Unsupported(format!(
1883                            "function arg type in {name}"
1884                        ))),
1885                    })
1886                    .collect::<Result<Vec<_>>>()?;
1887                if name == "COUNT" && args.len() == 1 && count_star {
1888                    (vec![], true, distinct)
1889                } else {
1890                    (args, false, distinct)
1891                }
1892            }
1893        }
1894        sp::FunctionArguments::None => {
1895            if name == "COUNT" {
1896                (vec![], true, false)
1897            } else {
1898                (vec![], false, false)
1899            }
1900        }
1901        sp::FunctionArguments::Subquery(_) => {
1902            return Err(SqlError::Unsupported("subquery in function".into()));
1903        }
1904    };
1905
1906    // Window function: check OVER before any other special handling
1907    if let Some(over) = &func.over {
1908        let spec = match over {
1909            sp::WindowType::WindowSpec(ws) => convert_window_spec(ws)?,
1910            sp::WindowType::NamedWindow(_) => {
1911                return Err(SqlError::Unsupported("named windows".into()));
1912            }
1913        };
1914        return Ok(Expr::WindowFunction { name, args, spec });
1915    }
1916
1917    // Non-window special forms
1918    if is_count_star {
1919        return Ok(Expr::CountStar);
1920    }
1921
1922    if name == "COALESCE" {
1923        if args.is_empty() {
1924            return Err(SqlError::Parse(
1925                "COALESCE requires at least one argument".into(),
1926            ));
1927        }
1928        return Ok(Expr::Coalesce(args));
1929    }
1930
1931    if name == "NULLIF" {
1932        if args.len() != 2 {
1933            return Err(SqlError::Parse(
1934                "NULLIF requires exactly two arguments".into(),
1935            ));
1936        }
1937        return Ok(Expr::Case {
1938            operand: None,
1939            conditions: vec![(
1940                Expr::BinaryOp {
1941                    left: Box::new(args[0].clone()),
1942                    op: BinOp::Eq,
1943                    right: Box::new(args[1].clone()),
1944                },
1945                Expr::Literal(Value::Null),
1946            )],
1947            else_result: Some(Box::new(args[0].clone())),
1948        });
1949    }
1950
1951    if name == "IIF" {
1952        if args.len() != 3 {
1953            return Err(SqlError::Parse(
1954                "IIF requires exactly three arguments".into(),
1955            ));
1956        }
1957        return Ok(Expr::Case {
1958            operand: None,
1959            conditions: vec![(args[0].clone(), args[1].clone())],
1960            else_result: Some(Box::new(args[2].clone())),
1961        });
1962    }
1963
1964    Ok(Expr::Function {
1965        name,
1966        args,
1967        distinct,
1968    })
1969}
1970
1971fn convert_window_spec(ws: &sp::WindowSpec) -> Result<WindowSpec> {
1972    let partition_by = ws
1973        .partition_by
1974        .iter()
1975        .map(convert_expr)
1976        .collect::<Result<Vec<_>>>()?;
1977    let order_by = ws
1978        .order_by
1979        .iter()
1980        .map(convert_order_by_expr)
1981        .collect::<Result<Vec<_>>>()?;
1982    let frame = ws
1983        .window_frame
1984        .as_ref()
1985        .map(convert_window_frame)
1986        .transpose()?;
1987    Ok(WindowSpec {
1988        partition_by,
1989        order_by,
1990        frame,
1991    })
1992}
1993
1994fn convert_window_frame(wf: &sp::WindowFrame) -> Result<WindowFrame> {
1995    let units = match wf.units {
1996        sp::WindowFrameUnits::Rows => WindowFrameUnits::Rows,
1997        sp::WindowFrameUnits::Range => WindowFrameUnits::Range,
1998        sp::WindowFrameUnits::Groups => {
1999            return Err(SqlError::Unsupported("GROUPS window frame".into()));
2000        }
2001    };
2002    let start = convert_window_frame_bound(&wf.start_bound)?;
2003    let end = match &wf.end_bound {
2004        Some(b) => convert_window_frame_bound(b)?,
2005        None => WindowFrameBound::CurrentRow,
2006    };
2007    Ok(WindowFrame { units, start, end })
2008}
2009
2010fn convert_window_frame_bound(b: &sp::WindowFrameBound) -> Result<WindowFrameBound> {
2011    match b {
2012        sp::WindowFrameBound::CurrentRow => Ok(WindowFrameBound::CurrentRow),
2013        sp::WindowFrameBound::Preceding(None) => Ok(WindowFrameBound::UnboundedPreceding),
2014        sp::WindowFrameBound::Preceding(Some(e)) => {
2015            Ok(WindowFrameBound::Preceding(Box::new(convert_expr(e)?)))
2016        }
2017        sp::WindowFrameBound::Following(None) => Ok(WindowFrameBound::UnboundedFollowing),
2018        sp::WindowFrameBound::Following(Some(e)) => {
2019            Ok(WindowFrameBound::Following(Box::new(convert_expr(e)?)))
2020        }
2021    }
2022}
2023
2024fn convert_select_item(item: &sp::SelectItem) -> Result<SelectColumn> {
2025    match item {
2026        sp::SelectItem::Wildcard(_) => Ok(SelectColumn::AllColumns),
2027        sp::SelectItem::UnnamedExpr(e) => {
2028            let expr = convert_expr(e)?;
2029            Ok(SelectColumn::Expr { expr, alias: None })
2030        }
2031        sp::SelectItem::ExprWithAlias { expr, alias } => {
2032            let expr = convert_expr(expr)?;
2033            Ok(SelectColumn::Expr {
2034                expr,
2035                alias: Some(alias.value.clone()),
2036            })
2037        }
2038        sp::SelectItem::QualifiedWildcard(_, _) => {
2039            Err(SqlError::Unsupported("qualified wildcard (table.*)".into()))
2040        }
2041    }
2042}
2043
2044fn convert_order_by_expr(expr: &sp::OrderByExpr) -> Result<OrderByItem> {
2045    let e = convert_expr(&expr.expr)?;
2046    let descending = expr.options.asc.map(|asc| !asc).unwrap_or(false);
2047    let nulls_first = expr.options.nulls_first;
2048
2049    Ok(OrderByItem {
2050        expr: e,
2051        descending,
2052        nulls_first,
2053    })
2054}
2055
2056fn convert_data_type(dt: &sp::DataType) -> Result<DataType> {
2057    match dt {
2058        sp::DataType::Int(_)
2059        | sp::DataType::Integer(_)
2060        | sp::DataType::BigInt(_)
2061        | sp::DataType::SmallInt(_)
2062        | sp::DataType::TinyInt(_)
2063        | sp::DataType::Int2(_)
2064        | sp::DataType::Int4(_)
2065        | sp::DataType::Int8(_) => Ok(DataType::Integer),
2066
2067        sp::DataType::Real
2068        | sp::DataType::Double(..)
2069        | sp::DataType::DoublePrecision
2070        | sp::DataType::Float(_)
2071        | sp::DataType::Float4
2072        | sp::DataType::Float64 => Ok(DataType::Real),
2073
2074        sp::DataType::Varchar(_)
2075        | sp::DataType::Text
2076        | sp::DataType::Char(_)
2077        | sp::DataType::Character(_)
2078        | sp::DataType::String(_) => Ok(DataType::Text),
2079
2080        sp::DataType::Blob(_) | sp::DataType::Bytea => Ok(DataType::Blob),
2081
2082        sp::DataType::Boolean | sp::DataType::Bool => Ok(DataType::Boolean),
2083
2084        sp::DataType::Date => Ok(DataType::Date),
2085        sp::DataType::Time(_, _) => Ok(DataType::Time),
2086        sp::DataType::Timestamp(_, _) => Ok(DataType::Timestamp),
2087        sp::DataType::Interval { .. } => Ok(DataType::Interval),
2088
2089        _ => Err(SqlError::Unsupported(format!("data type: {dt}"))),
2090    }
2091}
2092
2093fn object_name_to_string(name: &sp::ObjectName) -> String {
2094    name.0
2095        .iter()
2096        .filter_map(|p| match p {
2097            sp::ObjectNamePart::Identifier(ident) => Some(ident.value.clone()),
2098            _ => None,
2099        })
2100        .collect::<Vec<_>>()
2101        .join(".")
2102}
2103
2104#[cfg(test)]
2105mod tests {
2106    use super::*;
2107
2108    #[test]
2109    fn parse_create_table() {
2110        let stmt = parse_sql(
2111            "CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT NOT NULL, age INTEGER)",
2112        )
2113        .unwrap();
2114
2115        match stmt {
2116            Statement::CreateTable(ct) => {
2117                assert_eq!(ct.name, "users");
2118                assert_eq!(ct.columns.len(), 3);
2119                assert_eq!(ct.columns[0].name, "id");
2120                assert_eq!(ct.columns[0].data_type, DataType::Integer);
2121                assert!(ct.columns[0].is_primary_key);
2122                assert!(!ct.columns[0].nullable);
2123                assert_eq!(ct.columns[1].name, "name");
2124                assert_eq!(ct.columns[1].data_type, DataType::Text);
2125                assert!(!ct.columns[1].nullable);
2126                assert_eq!(ct.columns[2].name, "age");
2127                assert!(ct.columns[2].nullable);
2128                assert_eq!(ct.primary_key, vec!["id"]);
2129            }
2130            _ => panic!("expected CreateTable"),
2131        }
2132    }
2133
2134    #[test]
2135    fn parse_create_table_if_not_exists() {
2136        let stmt = parse_sql("CREATE TABLE IF NOT EXISTS t (id INT PRIMARY KEY)").unwrap();
2137        match stmt {
2138            Statement::CreateTable(ct) => assert!(ct.if_not_exists),
2139            _ => panic!("expected CreateTable"),
2140        }
2141    }
2142
2143    #[test]
2144    fn parse_drop_table() {
2145        let stmt = parse_sql("DROP TABLE users").unwrap();
2146        match stmt {
2147            Statement::DropTable(dt) => {
2148                assert_eq!(dt.name, "users");
2149                assert!(!dt.if_exists);
2150            }
2151            _ => panic!("expected DropTable"),
2152        }
2153    }
2154
2155    #[test]
2156    fn parse_drop_table_if_exists() {
2157        let stmt = parse_sql("DROP TABLE IF EXISTS users").unwrap();
2158        match stmt {
2159            Statement::DropTable(dt) => assert!(dt.if_exists),
2160            _ => panic!("expected DropTable"),
2161        }
2162    }
2163
2164    #[test]
2165    fn parse_insert() {
2166        let stmt =
2167            parse_sql("INSERT INTO users (id, name) VALUES (1, 'Alice'), (2, 'Bob')").unwrap();
2168
2169        match stmt {
2170            Statement::Insert(ins) => {
2171                assert_eq!(ins.table, "users");
2172                assert_eq!(ins.columns, vec!["id", "name"]);
2173                let values = match &ins.source {
2174                    InsertSource::Values(v) => v,
2175                    _ => panic!("expected Values"),
2176                };
2177                assert_eq!(values.len(), 2);
2178                assert!(matches!(values[0][0], Expr::Literal(Value::Integer(1))));
2179                assert!(matches!(&values[0][1], Expr::Literal(Value::Text(s)) if s == "Alice"));
2180            }
2181            _ => panic!("expected Insert"),
2182        }
2183    }
2184
2185    #[test]
2186    fn parse_select_all() {
2187        let stmt = parse_sql("SELECT * FROM users").unwrap();
2188        match stmt {
2189            Statement::Select(sq) => match sq.body {
2190                QueryBody::Select(sel) => {
2191                    assert_eq!(sel.from, "users");
2192                    assert!(matches!(sel.columns[0], SelectColumn::AllColumns));
2193                    assert!(sel.where_clause.is_none());
2194                }
2195                _ => panic!("expected QueryBody::Select"),
2196            },
2197            _ => panic!("expected Select"),
2198        }
2199    }
2200
2201    #[test]
2202    fn parse_select_where() {
2203        let stmt = parse_sql("SELECT id, name FROM users WHERE age > 18").unwrap();
2204        match stmt {
2205            Statement::Select(sq) => match sq.body {
2206                QueryBody::Select(sel) => {
2207                    assert_eq!(sel.columns.len(), 2);
2208                    assert!(sel.where_clause.is_some());
2209                }
2210                _ => panic!("expected QueryBody::Select"),
2211            },
2212            _ => panic!("expected Select"),
2213        }
2214    }
2215
2216    #[test]
2217    fn parse_select_order_limit() {
2218        let stmt = parse_sql("SELECT * FROM users ORDER BY name ASC LIMIT 10 OFFSET 5").unwrap();
2219        match stmt {
2220            Statement::Select(sq) => match sq.body {
2221                QueryBody::Select(sel) => {
2222                    assert_eq!(sel.order_by.len(), 1);
2223                    assert!(!sel.order_by[0].descending);
2224                    assert!(sel.limit.is_some());
2225                    assert!(sel.offset.is_some());
2226                }
2227                _ => panic!("expected QueryBody::Select"),
2228            },
2229            _ => panic!("expected Select"),
2230        }
2231    }
2232
2233    #[test]
2234    fn parse_update() {
2235        let stmt = parse_sql("UPDATE users SET name = 'Bob' WHERE id = 1").unwrap();
2236        match stmt {
2237            Statement::Update(upd) => {
2238                assert_eq!(upd.table, "users");
2239                assert_eq!(upd.assignments.len(), 1);
2240                assert_eq!(upd.assignments[0].0, "name");
2241                assert!(upd.where_clause.is_some());
2242            }
2243            _ => panic!("expected Update"),
2244        }
2245    }
2246
2247    #[test]
2248    fn parse_delete() {
2249        let stmt = parse_sql("DELETE FROM users WHERE id = 1").unwrap();
2250        match stmt {
2251            Statement::Delete(del) => {
2252                assert_eq!(del.table, "users");
2253                assert!(del.where_clause.is_some());
2254            }
2255            _ => panic!("expected Delete"),
2256        }
2257    }
2258
2259    #[test]
2260    fn parse_aggregate() {
2261        let stmt = parse_sql("SELECT COUNT(*), SUM(age) FROM users").unwrap();
2262        match stmt {
2263            Statement::Select(sq) => match sq.body {
2264                QueryBody::Select(sel) => {
2265                    assert_eq!(sel.columns.len(), 2);
2266                    match &sel.columns[0] {
2267                        SelectColumn::Expr {
2268                            expr: Expr::CountStar,
2269                            ..
2270                        } => {}
2271                        other => panic!("expected CountStar, got {other:?}"),
2272                    }
2273                }
2274                _ => panic!("expected QueryBody::Select"),
2275            },
2276            _ => panic!("expected Select"),
2277        }
2278    }
2279
2280    #[test]
2281    fn parse_group_by_having() {
2282        let stmt = parse_sql(
2283            "SELECT department, COUNT(*) FROM employees GROUP BY department HAVING COUNT(*) > 5",
2284        )
2285        .unwrap();
2286        match stmt {
2287            Statement::Select(sq) => match sq.body {
2288                QueryBody::Select(sel) => {
2289                    assert_eq!(sel.group_by.len(), 1);
2290                    assert!(sel.having.is_some());
2291                }
2292                _ => panic!("expected QueryBody::Select"),
2293            },
2294            _ => panic!("expected Select"),
2295        }
2296    }
2297
2298    #[test]
2299    fn parse_expressions() {
2300        let stmt = parse_sql("SELECT id + 1, -price, NOT active FROM items").unwrap();
2301        match stmt {
2302            Statement::Select(sq) => match sq.body {
2303                QueryBody::Select(sel) => {
2304                    assert_eq!(sel.columns.len(), 3);
2305                    // id + 1
2306                    match &sel.columns[0] {
2307                        SelectColumn::Expr {
2308                            expr: Expr::BinaryOp { op: BinOp::Add, .. },
2309                            ..
2310                        } => {}
2311                        other => panic!("expected BinaryOp Add, got {other:?}"),
2312                    }
2313                    // -price
2314                    match &sel.columns[1] {
2315                        SelectColumn::Expr {
2316                            expr:
2317                                Expr::UnaryOp {
2318                                    op: UnaryOp::Neg, ..
2319                                },
2320                            ..
2321                        } => {}
2322                        other => panic!("expected UnaryOp Neg, got {other:?}"),
2323                    }
2324                    // NOT active
2325                    match &sel.columns[2] {
2326                        SelectColumn::Expr {
2327                            expr:
2328                                Expr::UnaryOp {
2329                                    op: UnaryOp::Not, ..
2330                                },
2331                            ..
2332                        } => {}
2333                        other => panic!("expected UnaryOp Not, got {other:?}"),
2334                    }
2335                }
2336                _ => panic!("expected QueryBody::Select"),
2337            },
2338            _ => panic!("expected Select"),
2339        }
2340    }
2341
2342    #[test]
2343    fn parse_is_null() {
2344        let stmt = parse_sql("SELECT * FROM t WHERE x IS NULL").unwrap();
2345        match stmt {
2346            Statement::Select(sq) => match sq.body {
2347                QueryBody::Select(sel) => {
2348                    assert!(matches!(sel.where_clause, Some(Expr::IsNull(_))));
2349                }
2350                _ => panic!("expected QueryBody::Select"),
2351            },
2352            _ => panic!("expected Select"),
2353        }
2354    }
2355
2356    #[test]
2357    fn parse_inner_join() {
2358        let stmt = parse_sql("SELECT * FROM a JOIN b ON a.id = b.id").unwrap();
2359        match stmt {
2360            Statement::Select(sq) => match sq.body {
2361                QueryBody::Select(sel) => {
2362                    assert_eq!(sel.from, "a");
2363                    assert_eq!(sel.joins.len(), 1);
2364                    assert_eq!(sel.joins[0].join_type, JoinType::Inner);
2365                    assert_eq!(sel.joins[0].table.name, "b");
2366                    assert!(sel.joins[0].on_clause.is_some());
2367                }
2368                _ => panic!("expected QueryBody::Select"),
2369            },
2370            _ => panic!("expected Select"),
2371        }
2372    }
2373
2374    #[test]
2375    fn parse_inner_join_explicit() {
2376        let stmt = parse_sql("SELECT * FROM a INNER JOIN b ON a.id = b.a_id").unwrap();
2377        match stmt {
2378            Statement::Select(sq) => match sq.body {
2379                QueryBody::Select(sel) => {
2380                    assert_eq!(sel.joins.len(), 1);
2381                    assert_eq!(sel.joins[0].join_type, JoinType::Inner);
2382                }
2383                _ => panic!("expected QueryBody::Select"),
2384            },
2385            _ => panic!("expected Select"),
2386        }
2387    }
2388
2389    #[test]
2390    fn parse_cross_join() {
2391        let stmt = parse_sql("SELECT * FROM a CROSS JOIN b").unwrap();
2392        match stmt {
2393            Statement::Select(sq) => match sq.body {
2394                QueryBody::Select(sel) => {
2395                    assert_eq!(sel.joins.len(), 1);
2396                    assert_eq!(sel.joins[0].join_type, JoinType::Cross);
2397                    assert!(sel.joins[0].on_clause.is_none());
2398                }
2399                _ => panic!("expected QueryBody::Select"),
2400            },
2401            _ => panic!("expected Select"),
2402        }
2403    }
2404
2405    #[test]
2406    fn parse_left_join() {
2407        let stmt = parse_sql("SELECT * FROM a LEFT JOIN b ON a.id = b.a_id").unwrap();
2408        match stmt {
2409            Statement::Select(sq) => match sq.body {
2410                QueryBody::Select(sel) => {
2411                    assert_eq!(sel.joins.len(), 1);
2412                    assert_eq!(sel.joins[0].join_type, JoinType::Left);
2413                }
2414                _ => panic!("expected QueryBody::Select"),
2415            },
2416            _ => panic!("expected Select"),
2417        }
2418    }
2419
2420    #[test]
2421    fn parse_table_alias() {
2422        let stmt = parse_sql("SELECT u.id FROM users u JOIN orders o ON u.id = o.user_id").unwrap();
2423        match stmt {
2424            Statement::Select(sq) => match sq.body {
2425                QueryBody::Select(sel) => {
2426                    assert_eq!(sel.from, "users");
2427                    assert_eq!(sel.from_alias.as_deref(), Some("u"));
2428                    assert_eq!(sel.joins[0].table.name, "orders");
2429                    assert_eq!(sel.joins[0].table.alias.as_deref(), Some("o"));
2430                }
2431                _ => panic!("expected QueryBody::Select"),
2432            },
2433            _ => panic!("expected Select"),
2434        }
2435    }
2436
2437    #[test]
2438    fn parse_multi_join() {
2439        let stmt =
2440            parse_sql("SELECT * FROM a JOIN b ON a.id = b.a_id JOIN c ON b.id = c.b_id").unwrap();
2441        match stmt {
2442            Statement::Select(sq) => match sq.body {
2443                QueryBody::Select(sel) => {
2444                    assert_eq!(sel.joins.len(), 2);
2445                }
2446                _ => panic!("expected QueryBody::Select"),
2447            },
2448            _ => panic!("expected Select"),
2449        }
2450    }
2451
2452    #[test]
2453    fn parse_qualified_column() {
2454        let stmt = parse_sql("SELECT u.id, u.name FROM users u").unwrap();
2455        match stmt {
2456            Statement::Select(sq) => match sq.body {
2457                QueryBody::Select(sel) => match &sel.columns[0] {
2458                    SelectColumn::Expr {
2459                        expr: Expr::QualifiedColumn { table, column },
2460                        ..
2461                    } => {
2462                        assert_eq!(table, "u");
2463                        assert_eq!(column, "id");
2464                    }
2465                    other => panic!("expected QualifiedColumn, got {other:?}"),
2466                },
2467                _ => panic!("expected QueryBody::Select"),
2468            },
2469            _ => panic!("expected Select"),
2470        }
2471    }
2472
2473    #[test]
2474    fn reject_subquery() {
2475        assert!(parse_sql("SELECT * FROM (SELECT 1)").is_err());
2476    }
2477
2478    #[test]
2479    fn parse_type_mapping() {
2480        let stmt = parse_sql(
2481            "CREATE TABLE t (a INT PRIMARY KEY, b BIGINT, c SMALLINT, d REAL, e DOUBLE PRECISION, f VARCHAR(255), g BOOLEAN, h BLOB, i BYTEA)"
2482        ).unwrap();
2483        match stmt {
2484            Statement::CreateTable(ct) => {
2485                assert_eq!(ct.columns[0].data_type, DataType::Integer); // INT
2486                assert_eq!(ct.columns[1].data_type, DataType::Integer); // BIGINT
2487                assert_eq!(ct.columns[2].data_type, DataType::Integer); // SMALLINT
2488                assert_eq!(ct.columns[3].data_type, DataType::Real); // REAL
2489                assert_eq!(ct.columns[4].data_type, DataType::Real); // DOUBLE
2490                assert_eq!(ct.columns[5].data_type, DataType::Text); // VARCHAR
2491                assert_eq!(ct.columns[6].data_type, DataType::Boolean); // BOOLEAN
2492                assert_eq!(ct.columns[7].data_type, DataType::Blob); // BLOB
2493                assert_eq!(ct.columns[8].data_type, DataType::Blob); // BYTEA
2494            }
2495            _ => panic!("expected CreateTable"),
2496        }
2497    }
2498
2499    #[test]
2500    fn parse_boolean_literals() {
2501        let stmt = parse_sql("INSERT INTO t (a, b) VALUES (true, false)").unwrap();
2502        match stmt {
2503            Statement::Insert(ins) => {
2504                let values = match &ins.source {
2505                    InsertSource::Values(v) => v,
2506                    _ => panic!("expected Values"),
2507                };
2508                assert!(matches!(values[0][0], Expr::Literal(Value::Boolean(true))));
2509                assert!(matches!(values[0][1], Expr::Literal(Value::Boolean(false))));
2510            }
2511            _ => panic!("expected Insert"),
2512        }
2513    }
2514
2515    #[test]
2516    fn parse_null_literal() {
2517        let stmt = parse_sql("INSERT INTO t (a) VALUES (NULL)").unwrap();
2518        match stmt {
2519            Statement::Insert(ins) => {
2520                let values = match &ins.source {
2521                    InsertSource::Values(v) => v,
2522                    _ => panic!("expected Values"),
2523                };
2524                assert!(matches!(values[0][0], Expr::Literal(Value::Null)));
2525            }
2526            _ => panic!("expected Insert"),
2527        }
2528    }
2529
2530    #[test]
2531    fn parse_alias() {
2532        let stmt = parse_sql("SELECT id AS user_id FROM users").unwrap();
2533        match stmt {
2534            Statement::Select(sq) => match sq.body {
2535                QueryBody::Select(sel) => match &sel.columns[0] {
2536                    SelectColumn::Expr { alias: Some(a), .. } => assert_eq!(a, "user_id"),
2537                    other => panic!("expected alias, got {other:?}"),
2538                },
2539                _ => panic!("expected QueryBody::Select"),
2540            },
2541            _ => panic!("expected Select"),
2542        }
2543    }
2544
2545    #[test]
2546    fn parse_begin() {
2547        let stmt = parse_sql("BEGIN").unwrap();
2548        assert!(matches!(stmt, Statement::Begin));
2549    }
2550
2551    #[test]
2552    fn parse_begin_transaction() {
2553        let stmt = parse_sql("BEGIN TRANSACTION").unwrap();
2554        assert!(matches!(stmt, Statement::Begin));
2555    }
2556
2557    #[test]
2558    fn parse_commit() {
2559        let stmt = parse_sql("COMMIT").unwrap();
2560        assert!(matches!(stmt, Statement::Commit));
2561    }
2562
2563    #[test]
2564    fn parse_rollback() {
2565        let stmt = parse_sql("ROLLBACK").unwrap();
2566        assert!(matches!(stmt, Statement::Rollback));
2567    }
2568
2569    #[test]
2570    fn parse_savepoint() {
2571        let stmt = parse_sql("SAVEPOINT sp1").unwrap();
2572        match stmt {
2573            Statement::Savepoint(name) => assert_eq!(name, "sp1"),
2574            other => panic!("expected Savepoint, got {other:?}"),
2575        }
2576    }
2577
2578    #[test]
2579    fn parse_savepoint_case_insensitive() {
2580        let stmt = parse_sql("SAVEPOINT My_SP").unwrap();
2581        match stmt {
2582            Statement::Savepoint(name) => assert_eq!(name, "my_sp"),
2583            other => panic!("expected Savepoint, got {other:?}"),
2584        }
2585    }
2586
2587    #[test]
2588    fn parse_release_savepoint() {
2589        let stmt = parse_sql("RELEASE SAVEPOINT sp1").unwrap();
2590        match stmt {
2591            Statement::ReleaseSavepoint(name) => assert_eq!(name, "sp1"),
2592            other => panic!("expected ReleaseSavepoint, got {other:?}"),
2593        }
2594    }
2595
2596    #[test]
2597    fn parse_release_without_savepoint_keyword() {
2598        let stmt = parse_sql("RELEASE sp1").unwrap();
2599        match stmt {
2600            Statement::ReleaseSavepoint(name) => assert_eq!(name, "sp1"),
2601            other => panic!("expected ReleaseSavepoint, got {other:?}"),
2602        }
2603    }
2604
2605    #[test]
2606    fn parse_rollback_to_savepoint() {
2607        let stmt = parse_sql("ROLLBACK TO SAVEPOINT sp1").unwrap();
2608        match stmt {
2609            Statement::RollbackTo(name) => assert_eq!(name, "sp1"),
2610            other => panic!("expected RollbackTo, got {other:?}"),
2611        }
2612    }
2613
2614    #[test]
2615    fn parse_rollback_to_without_savepoint_keyword() {
2616        let stmt = parse_sql("ROLLBACK TO sp1").unwrap();
2617        match stmt {
2618            Statement::RollbackTo(name) => assert_eq!(name, "sp1"),
2619            other => panic!("expected RollbackTo, got {other:?}"),
2620        }
2621    }
2622
2623    #[test]
2624    fn parse_rollback_to_case_insensitive() {
2625        let stmt = parse_sql("ROLLBACK TO My_SP").unwrap();
2626        match stmt {
2627            Statement::RollbackTo(name) => assert_eq!(name, "my_sp"),
2628            other => panic!("expected RollbackTo, got {other:?}"),
2629        }
2630    }
2631
2632    #[test]
2633    fn parse_commit_and_chain_rejected() {
2634        let err = parse_sql("COMMIT AND CHAIN").unwrap_err();
2635        assert!(matches!(err, SqlError::Unsupported(_)));
2636    }
2637
2638    #[test]
2639    fn parse_rollback_and_chain_rejected() {
2640        let err = parse_sql("ROLLBACK AND CHAIN").unwrap_err();
2641        assert!(matches!(err, SqlError::Unsupported(_)));
2642    }
2643
2644    #[test]
2645    fn parse_select_distinct() {
2646        let stmt = parse_sql("SELECT DISTINCT name FROM users").unwrap();
2647        match stmt {
2648            Statement::Select(sq) => match sq.body {
2649                QueryBody::Select(sel) => {
2650                    assert!(sel.distinct);
2651                    assert_eq!(sel.columns.len(), 1);
2652                }
2653                _ => panic!("expected QueryBody::Select"),
2654            },
2655            _ => panic!("expected Select"),
2656        }
2657    }
2658
2659    #[test]
2660    fn parse_select_without_distinct() {
2661        let stmt = parse_sql("SELECT name FROM users").unwrap();
2662        match stmt {
2663            Statement::Select(sq) => match sq.body {
2664                QueryBody::Select(sel) => {
2665                    assert!(!sel.distinct);
2666                }
2667                _ => panic!("expected QueryBody::Select"),
2668            },
2669            _ => panic!("expected Select"),
2670        }
2671    }
2672
2673    #[test]
2674    fn parse_select_distinct_all_columns() {
2675        let stmt = parse_sql("SELECT DISTINCT * FROM users").unwrap();
2676        match stmt {
2677            Statement::Select(sq) => match sq.body {
2678                QueryBody::Select(sel) => {
2679                    assert!(sel.distinct);
2680                    assert!(matches!(sel.columns[0], SelectColumn::AllColumns));
2681                }
2682                _ => panic!("expected QueryBody::Select"),
2683            },
2684            _ => panic!("expected Select"),
2685        }
2686    }
2687
2688    #[test]
2689    fn reject_distinct_on() {
2690        assert!(parse_sql("SELECT DISTINCT ON (id) * FROM users").is_err());
2691    }
2692
2693    #[test]
2694    fn parse_create_index() {
2695        let stmt = parse_sql("CREATE INDEX idx_name ON users (name)").unwrap();
2696        match stmt {
2697            Statement::CreateIndex(ci) => {
2698                assert_eq!(ci.index_name, "idx_name");
2699                assert_eq!(ci.table_name, "users");
2700                assert_eq!(ci.columns, vec!["name"]);
2701                assert!(!ci.unique);
2702                assert!(!ci.if_not_exists);
2703            }
2704            _ => panic!("expected CreateIndex"),
2705        }
2706    }
2707
2708    #[test]
2709    fn parse_create_unique_index() {
2710        let stmt = parse_sql("CREATE UNIQUE INDEX idx_email ON users (email)").unwrap();
2711        match stmt {
2712            Statement::CreateIndex(ci) => {
2713                assert!(ci.unique);
2714                assert_eq!(ci.columns, vec!["email"]);
2715            }
2716            _ => panic!("expected CreateIndex"),
2717        }
2718    }
2719
2720    #[test]
2721    fn parse_create_index_if_not_exists() {
2722        let stmt = parse_sql("CREATE INDEX IF NOT EXISTS idx_x ON t (a)").unwrap();
2723        match stmt {
2724            Statement::CreateIndex(ci) => assert!(ci.if_not_exists),
2725            _ => panic!("expected CreateIndex"),
2726        }
2727    }
2728
2729    #[test]
2730    fn parse_create_index_multi_column() {
2731        let stmt = parse_sql("CREATE INDEX idx_multi ON t (a, b, c)").unwrap();
2732        match stmt {
2733            Statement::CreateIndex(ci) => {
2734                assert_eq!(ci.columns, vec!["a", "b", "c"]);
2735            }
2736            _ => panic!("expected CreateIndex"),
2737        }
2738    }
2739
2740    #[test]
2741    fn parse_drop_index() {
2742        let stmt = parse_sql("DROP INDEX idx_name").unwrap();
2743        match stmt {
2744            Statement::DropIndex(di) => {
2745                assert_eq!(di.index_name, "idx_name");
2746                assert!(!di.if_exists);
2747            }
2748            _ => panic!("expected DropIndex"),
2749        }
2750    }
2751
2752    #[test]
2753    fn parse_drop_index_if_exists() {
2754        let stmt = parse_sql("DROP INDEX IF EXISTS idx_name").unwrap();
2755        match stmt {
2756            Statement::DropIndex(di) => {
2757                assert!(di.if_exists);
2758            }
2759            _ => panic!("expected DropIndex"),
2760        }
2761    }
2762
2763    #[test]
2764    fn parse_explain_select() {
2765        let stmt = parse_sql("EXPLAIN SELECT * FROM users WHERE id = 1").unwrap();
2766        match stmt {
2767            Statement::Explain(inner) => {
2768                assert!(matches!(*inner, Statement::Select(_)));
2769            }
2770            _ => panic!("expected Explain"),
2771        }
2772    }
2773
2774    #[test]
2775    fn parse_explain_insert() {
2776        let stmt = parse_sql("EXPLAIN INSERT INTO t (a) VALUES (1)").unwrap();
2777        assert!(matches!(stmt, Statement::Explain(_)));
2778    }
2779
2780    #[test]
2781    fn reject_explain_analyze() {
2782        assert!(parse_sql("EXPLAIN ANALYZE SELECT * FROM t").is_err());
2783    }
2784
2785    #[test]
2786    fn parse_parameter_placeholder() {
2787        let stmt = parse_sql("SELECT * FROM t WHERE id = $1").unwrap();
2788        match stmt {
2789            Statement::Select(sq) => match sq.body {
2790                QueryBody::Select(sel) => match &sel.where_clause {
2791                    Some(Expr::BinaryOp { right, .. }) => {
2792                        assert!(matches!(right.as_ref(), Expr::Parameter(1)));
2793                    }
2794                    other => panic!("expected BinaryOp with Parameter, got {other:?}"),
2795                },
2796                _ => panic!("expected QueryBody::Select"),
2797            },
2798            _ => panic!("expected Select"),
2799        }
2800    }
2801
2802    #[test]
2803    fn parse_multiple_parameters() {
2804        let stmt = parse_sql("INSERT INTO t (a, b) VALUES ($1, $2)").unwrap();
2805        match stmt {
2806            Statement::Insert(ins) => {
2807                let values = match &ins.source {
2808                    InsertSource::Values(v) => v,
2809                    _ => panic!("expected Values"),
2810                };
2811                assert!(matches!(values[0][0], Expr::Parameter(1)));
2812                assert!(matches!(values[0][1], Expr::Parameter(2)));
2813            }
2814            _ => panic!("expected Insert"),
2815        }
2816    }
2817
2818    #[test]
2819    fn parse_insert_select() {
2820        let stmt =
2821            parse_sql("INSERT INTO t2 (id, name) SELECT id, name FROM t1 WHERE id > 5").unwrap();
2822        match stmt {
2823            Statement::Insert(ins) => {
2824                assert_eq!(ins.table, "t2");
2825                assert_eq!(ins.columns, vec!["id", "name"]);
2826                match &ins.source {
2827                    InsertSource::Select(sq) => match &sq.body {
2828                        QueryBody::Select(sel) => {
2829                            assert_eq!(sel.from, "t1");
2830                            assert_eq!(sel.columns.len(), 2);
2831                            assert!(sel.where_clause.is_some());
2832                        }
2833                        _ => panic!("expected QueryBody::Select"),
2834                    },
2835                    _ => panic!("expected InsertSource::Select"),
2836                }
2837            }
2838            _ => panic!("expected Insert"),
2839        }
2840    }
2841
2842    #[test]
2843    fn parse_insert_select_no_columns() {
2844        let stmt = parse_sql("INSERT INTO t2 SELECT * FROM t1").unwrap();
2845        match stmt {
2846            Statement::Insert(ins) => {
2847                assert_eq!(ins.table, "t2");
2848                assert!(ins.columns.is_empty());
2849                assert!(matches!(&ins.source, InsertSource::Select(_)));
2850            }
2851            _ => panic!("expected Insert"),
2852        }
2853    }
2854
2855    #[test]
2856    fn reject_zero_parameter() {
2857        assert!(parse_sql("SELECT $0 FROM t").is_err());
2858    }
2859
2860    #[test]
2861    fn count_params_basic() {
2862        let stmt = parse_sql("SELECT * FROM t WHERE a = $1 AND b = $3").unwrap();
2863        assert_eq!(count_params(&stmt), 3);
2864    }
2865
2866    #[test]
2867    fn count_params_none() {
2868        let stmt = parse_sql("SELECT * FROM t WHERE a = 1").unwrap();
2869        assert_eq!(count_params(&stmt), 0);
2870    }
2871
2872    #[test]
2873    fn parse_table_constraint_pk() {
2874        let stmt = parse_sql("CREATE TABLE t (a INTEGER, b TEXT, PRIMARY KEY (a))").unwrap();
2875        match stmt {
2876            Statement::CreateTable(ct) => {
2877                assert_eq!(ct.primary_key, vec!["a"]);
2878                assert!(ct.columns[0].is_primary_key);
2879                assert!(!ct.columns[0].nullable);
2880            }
2881            _ => panic!("expected CreateTable"),
2882        }
2883    }
2884}