Skip to main content

citadel_sql/
parser.rs

1//! SQL parser: converts SQL strings into our internal AST.
2
3use sqlparser::ast as sp;
4use sqlparser::dialect::GenericDialect;
5use sqlparser::parser::Parser;
6
7use crate::error::{Result, SqlError};
8use crate::types::{DataType, Value};
9
/// A single SQL statement in this crate's internal representation.
///
/// Produced by [`parse_sql`] from the `sqlparser` AST.
#[derive(Debug, Clone)]
pub enum Statement {
    CreateTable(CreateTableStmt),
    DropTable(DropTableStmt),
    CreateIndex(CreateIndexStmt),
    DropIndex(DropIndexStmt),
    CreateView(CreateViewStmt),
    DropView(DropViewStmt),
    AlterTable(Box<AlterTableStmt>),
    Insert(InsertStmt),
    Select(Box<SelectQuery>),
    Update(UpdateStmt),
    Delete(DeleteStmt),
    /// Converted from sqlparser's `StartTransaction` (`BEGIN` / `START TRANSACTION`).
    Begin,
    /// `COMMIT`; the `AND CHAIN` form is rejected during conversion.
    Commit,
    /// `ROLLBACK` without a savepoint target; `AND CHAIN` is rejected.
    Rollback,
    /// `SAVEPOINT <name>`; the name is ASCII-lowercased.
    Savepoint(String),
    /// `RELEASE SAVEPOINT <name>`; the name is ASCII-lowercased.
    ReleaseSavepoint(String),
    /// `ROLLBACK TO <name>`; the name is ASCII-lowercased.
    RollbackTo(String),
    /// `SET TIME ZONE ...`: the string literal's contents or the bare identifier.
    SetTimezone(String),
    /// `EXPLAIN <statement>`; `EXPLAIN ANALYZE` is rejected during conversion.
    Explain(Box<Statement>),
}
32
/// `ALTER TABLE <table> <op>`. The parser accepts exactly one operation per statement.
#[derive(Debug, Clone)]
pub struct AlterTableStmt {
    /// Name of the table being altered.
    pub table: String,
    /// The single operation to apply.
    pub op: AlterTableOp,
}
38
/// One supported `ALTER TABLE` operation.
#[derive(Debug, Clone)]
pub enum AlterTableOp {
    /// `ADD [COLUMN] [IF NOT EXISTS] <column definition>`.
    AddColumn {
        column: Box<ColumnSpec>,
        /// Inline `REFERENCES` clause on the new column, if any.
        foreign_key: Option<ForeignKeyDef>,
        if_not_exists: bool,
    },
    /// `DROP [COLUMN] [IF EXISTS] <name>` (exactly one column).
    DropColumn {
        name: String,
        if_exists: bool,
    },
    /// `RENAME COLUMN <old_name> TO <new_name>`.
    RenameColumn {
        old_name: String,
        new_name: String,
    },
    /// `RENAME TO <new_name>` (the `RENAME AS` spelling is accepted too).
    RenameTable {
        new_name: String,
    },
}
58
/// `CREATE TABLE` with its column definitions and table-level constraints.
#[derive(Debug, Clone)]
pub struct CreateTableStmt {
    pub name: String,
    pub columns: Vec<ColumnSpec>,
    /// Primary-key column names: inline `PRIMARY KEY` columns in declaration
    /// order, followed by further columns named in a table-level
    /// `PRIMARY KEY (...)` constraint.
    pub primary_key: Vec<String>,
    pub if_not_exists: bool,
    /// Table-level `CHECK` constraints (column-level CHECKs live on [`ColumnSpec`]).
    pub check_constraints: Vec<TableCheckConstraint>,
    /// Inline `REFERENCES` clauses and table-level `FOREIGN KEY` constraints.
    pub foreign_keys: Vec<ForeignKeyDef>,
    /// Inline `UNIQUE` on non-primary-key columns, plus table-level `UNIQUE` constraints.
    pub unique_indices: Vec<UniqueIndexDef>,
}
69
/// A UNIQUE constraint declared in `CREATE TABLE`.
#[derive(Debug, Clone)]
pub struct UniqueIndexDef {
    /// Constraint name from `CONSTRAINT <name> UNIQUE (...)`; `None` for inline UNIQUE.
    pub name: Option<String>,
    /// Constrained column names, ASCII-lowercased by the parser.
    pub columns: Vec<String>,
}
75
/// A table-level `CHECK (...)` constraint.
#[derive(Debug, Clone)]
pub struct TableCheckConstraint {
    /// Optional `CONSTRAINT <name>`.
    pub name: Option<String>,
    /// Converted check expression; the parser rejects subqueries here.
    pub expr: Expr,
    /// SQL text of the expression as rendered by sqlparser, presumably what is
    /// stored in the schema and later re-read with `parse_sql_expr`.
    pub sql: String,
}
82
/// A foreign-key reference, either inline (`REFERENCES`) or table-level (`FOREIGN KEY`).
///
/// The parser only accepts RESTRICT / NO ACTION referential actions, so none are stored.
#[derive(Debug, Clone)]
pub struct ForeignKeyDef {
    /// Optional constraint name.
    pub name: Option<String>,
    /// Referencing column names, ASCII-lowercased.
    pub columns: Vec<String>,
    /// Referenced table name, ASCII-lowercased.
    pub foreign_table: String,
    /// Referenced column names, ASCII-lowercased.
    pub referred_columns: Vec<String>,
}
90
/// A column definition from `CREATE TABLE` or `ALTER TABLE ... ADD COLUMN`.
#[derive(Debug, Clone)]
pub struct ColumnSpec {
    /// Column name as written (not lowercased by the parser).
    pub name: String,
    pub data_type: DataType,
    /// Whether NULL is allowed; `NOT NULL` and `PRIMARY KEY` clear it.
    pub nullable: bool,
    /// Set by an inline `PRIMARY KEY` or a table-level `PRIMARY KEY (...)`.
    pub is_primary_key: bool,
    /// Converted `DEFAULT` expression.
    pub default_expr: Option<Expr>,
    /// SQL text of the `DEFAULT` expression as rendered by sqlparser.
    pub default_sql: Option<String>,
    /// Converted column-level `CHECK` expression (subqueries are rejected).
    pub check_expr: Option<Expr>,
    /// SQL text of the `CHECK` expression as rendered by sqlparser.
    pub check_sql: Option<String>,
    /// Optional `CONSTRAINT <name>` attached to the CHECK.
    pub check_name: Option<String>,
}
103
/// `DROP TABLE [IF EXISTS] <name>` (a single table).
#[derive(Debug, Clone)]
pub struct DropTableStmt {
    pub name: String,
    pub if_exists: bool,
}
109
/// `CREATE [UNIQUE] INDEX [IF NOT EXISTS] <index_name> ON <table_name> (...)`.
#[derive(Debug, Clone)]
pub struct CreateIndexStmt {
    pub index_name: String,
    pub table_name: String,
    /// Indexed column names (plain identifiers only; expression indexes are
    /// rejected). Never empty.
    pub columns: Vec<String>,
    pub unique: bool,
    pub if_not_exists: bool,
}
118
/// `DROP INDEX [IF EXISTS] <index_name>` (a single index).
#[derive(Debug, Clone)]
pub struct DropIndexStmt {
    pub index_name: String,
    pub if_exists: bool,
}
124
/// `CREATE [OR REPLACE] VIEW [IF NOT EXISTS] <name> [(aliases)] AS <query>`.
#[derive(Debug, Clone)]
pub struct CreateViewStmt {
    pub name: String,
    /// SQL text of the view's query as rendered by sqlparser.
    pub sql: String,
    /// Output column aliases; presumably empty when none are listed (the
    /// populating code is not shown here — confirm).
    pub column_aliases: Vec<String>,
    pub or_replace: bool,
    pub if_not_exists: bool,
}
133
/// `DROP VIEW [IF EXISTS] <name>` (a single view).
#[derive(Debug, Clone)]
pub struct DropViewStmt {
    pub name: String,
    pub if_exists: bool,
}
139
/// Where the rows of an `INSERT` come from.
#[derive(Debug, Clone)]
pub enum InsertSource {
    /// `VALUES (...), (...)`: one inner `Vec` per row.
    Values(Vec<Vec<Expr>>),
    /// `INSERT ... SELECT ...`.
    Select(Box<SelectQuery>),
}
145
/// `INSERT INTO <table> [(columns)] <source>`.
#[derive(Debug, Clone)]
pub struct InsertStmt {
    pub table: String,
    /// Explicit target column list; presumably empty when omitted (confirm in
    /// `convert_insert`, not shown here).
    pub columns: Vec<String>,
    pub source: InsertSource,
}
152
/// A table referenced in a FROM/JOIN clause, with an optional alias.
#[derive(Debug, Clone)]
pub struct TableRef {
    pub name: String,
    pub alias: Option<String>,
}
158
/// Kind of JOIN between two table references.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum JoinType {
    Inner,
    Cross,
    Left,
    Right,
}
166
/// One JOIN attached to a [`SelectStmt`].
#[derive(Debug, Clone)]
pub struct JoinClause {
    pub join_type: JoinType,
    pub table: TableRef,
    /// The `ON` condition; presumably `None` for CROSS joins (confirm in the
    /// query converter, not shown here).
    pub on_clause: Option<Expr>,
}
173
/// A single (non-compound) SELECT.
#[derive(Debug, Clone)]
pub struct SelectStmt {
    /// Projection list.
    pub columns: Vec<SelectColumn>,
    /// Primary table in FROM.
    pub from: String,
    pub from_alias: Option<String>,
    /// Joins applied after the primary table, in order.
    pub joins: Vec<JoinClause>,
    pub distinct: bool,
    pub where_clause: Option<Expr>,
    pub order_by: Vec<OrderByItem>,
    pub limit: Option<Expr>,
    pub offset: Option<Expr>,
    pub group_by: Vec<Expr>,
    pub having: Option<Expr>,
}
188
/// Set operator combining the two operands of a [`CompoundSelect`].
///
/// Fieldless, so it derives `Copy`, `PartialEq` and `Eq` like the other
/// operator enums in this module (`JoinType`, `BinOp`, `UnaryOp`), which
/// allows `==` comparisons and by-value use without `.clone()`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SetOp {
    /// `UNION`
    Union,
    /// `INTERSECT`
    Intersect,
    /// `EXCEPT`
    Except,
}
195
/// Two query bodies joined by a set operator (`UNION`, `INTERSECT`, `EXCEPT`).
#[derive(Debug, Clone)]
pub struct CompoundSelect {
    pub op: SetOp,
    /// `true` for the `ALL` variant (e.g. `UNION ALL`).
    pub all: bool,
    pub left: Box<QueryBody>,
    pub right: Box<QueryBody>,
    /// ORDER BY attached to the compound as a whole, not to either operand.
    pub order_by: Vec<OrderByItem>,
    /// LIMIT attached to the compound as a whole.
    pub limit: Option<Expr>,
    /// OFFSET attached to the compound as a whole.
    pub offset: Option<Expr>,
}
206
/// The body of a query: either a plain SELECT or a set-operation tree.
#[derive(Debug, Clone)]
pub enum QueryBody {
    Select(Box<SelectStmt>),
    Compound(Box<CompoundSelect>),
}
212
/// One `name [(aliases)] AS (query)` entry of a `WITH` clause.
#[derive(Debug, Clone)]
pub struct CteDefinition {
    pub name: String,
    /// Optional output column names for the CTE.
    pub column_aliases: Vec<String>,
    pub body: QueryBody,
}
219
/// A complete query: optional CTEs followed by the main body.
#[derive(Debug, Clone)]
pub struct SelectQuery {
    /// CTEs in declaration order.
    pub ctes: Vec<CteDefinition>,
    /// `true` for `WITH RECURSIVE`.
    pub recursive: bool,
    pub body: QueryBody,
}
226
/// `UPDATE <table> SET col = expr, ... [WHERE ...]`.
#[derive(Debug, Clone)]
pub struct UpdateStmt {
    pub table: String,
    /// `(column, new value)` pairs in the order written.
    pub assignments: Vec<(String, Expr)>,
    pub where_clause: Option<Expr>,
}
233
/// `DELETE FROM <table> [WHERE ...]`.
#[derive(Debug, Clone)]
pub struct DeleteStmt {
    pub table: String,
    pub where_clause: Option<Expr>,
}
239
/// One item of a SELECT projection list.
#[derive(Debug, Clone)]
pub enum SelectColumn {
    /// All columns (presumably `*`); carries no expression.
    AllColumns,
    /// An expression with an optional `AS` alias.
    Expr { expr: Expr, alias: Option<String> },
}
245
/// One ORDER BY key (also used in window specifications).
#[derive(Debug, Clone)]
pub struct OrderByItem {
    pub expr: Expr,
    /// `true` for `DESC`.
    pub descending: bool,
    /// `Some(true)` for `NULLS FIRST`, `Some(false)` for `NULLS LAST`,
    /// `None` when not specified.
    pub nulls_first: Option<bool>,
}
252
/// Internal scalar/boolean expression tree.
///
/// Subqueries are stored as a single [`SelectStmt`] (no CTEs or set operations).
#[derive(Debug, Clone)]
pub enum Expr {
    /// A constant value.
    Literal(Value),
    /// Unqualified column reference.
    Column(String),
    /// `table.column` reference (table may be an alias).
    QualifiedColumn {
        table: String,
        column: String,
    },
    BinaryOp {
        left: Box<Expr>,
        op: BinOp,
        right: Box<Expr>,
    },
    UnaryOp {
        op: UnaryOp,
        expr: Box<Expr>,
    },
    /// `expr IS NULL`.
    IsNull(Box<Expr>),
    /// `expr IS NOT NULL`.
    IsNotNull(Box<Expr>),
    /// Scalar or aggregate function call.
    Function {
        name: String,
        args: Vec<Expr>,
        /// True for aggregate forms like `COUNT(DISTINCT x)`.
        distinct: bool,
    },
    /// Presumably `COUNT(*)` (built by `convert_expr`, not shown here).
    CountStar,
    /// `expr [NOT] IN (SELECT ...)`.
    InSubquery {
        expr: Box<Expr>,
        subquery: Box<SelectStmt>,
        negated: bool,
    },
    /// `expr [NOT] IN (a, b, ...)`.
    InList {
        expr: Box<Expr>,
        list: Vec<Expr>,
        negated: bool,
    },
    /// `[NOT] EXISTS (SELECT ...)`.
    Exists {
        subquery: Box<SelectStmt>,
        negated: bool,
    },
    /// `(SELECT ...)` used as a scalar value.
    ScalarSubquery(Box<SelectStmt>),
    /// Membership test against a precomputed set of values. Presumably an
    /// optimized form of `InList` with literal items, where `has_null` records
    /// a NULL in the original list — confirm where this variant is built.
    InSet {
        expr: Box<Expr>,
        values: std::collections::HashSet<Value>,
        has_null: bool,
        negated: bool,
    },
    /// `expr [NOT] BETWEEN low AND high`.
    Between {
        expr: Box<Expr>,
        low: Box<Expr>,
        high: Box<Expr>,
        negated: bool,
    },
    /// `expr [NOT] LIKE pattern [ESCAPE escape]`.
    Like {
        expr: Box<Expr>,
        pattern: Box<Expr>,
        escape: Option<Box<Expr>>,
        negated: bool,
    },
    /// `CASE [operand] WHEN ... THEN ... [ELSE ...] END`; `conditions` holds
    /// the `(WHEN, THEN)` pairs in order.
    Case {
        operand: Option<Box<Expr>>,
        conditions: Vec<(Expr, Expr)>,
        else_result: Option<Box<Expr>>,
    },
    /// `COALESCE(a, b, ...)`.
    Coalesce(Vec<Expr>),
    /// `CAST(expr AS data_type)`.
    Cast {
        expr: Box<Expr>,
        data_type: DataType,
    },
    /// Positional parameter. `count_params` treats the value as the `$N`
    /// index and reports the maximum one found.
    Parameter(usize),
    /// Function call with an `OVER (...)` window.
    WindowFunction {
        name: String,
        args: Vec<Expr>,
        spec: WindowSpec,
    },
}
329
/// The contents of an `OVER (...)` clause.
#[derive(Debug, Clone)]
pub struct WindowSpec {
    pub partition_by: Vec<Expr>,
    pub order_by: Vec<OrderByItem>,
    /// Explicit frame clause, if one was written.
    pub frame: Option<WindowFrame>,
}
336
/// A window frame: `<units> BETWEEN <start> AND <end>`.
#[derive(Debug, Clone)]
pub struct WindowFrame {
    pub units: WindowFrameUnits,
    pub start: WindowFrameBound,
    pub end: WindowFrameBound,
}
343
/// Unit in which window-frame offsets are measured.
///
/// Derives `PartialEq`/`Eq` (in addition to `Copy`) so callers can compare
/// units with `==`, consistent with the other fieldless enums in this module.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WindowFrameUnits {
    /// `ROWS`
    Rows,
    /// `RANGE`
    Range,
    /// `GROUPS`
    Groups,
}
350
/// One end of a window frame.
#[derive(Debug, Clone)]
pub enum WindowFrameBound {
    UnboundedPreceding,
    /// `<offset> PRECEDING`.
    Preceding(Box<Expr>),
    CurrentRow,
    /// `<offset> FOLLOWING`.
    Following(Box<Expr>),
    UnboundedFollowing,
}
359
/// Binary operators: arithmetic, comparison, logical, and string concatenation.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BinOp {
    Add,
    Sub,
    Mul,
    Div,
    Mod,
    Eq,
    NotEq,
    Lt,
    Gt,
    LtEq,
    GtEq,
    And,
    Or,
    Concat,
}
377
/// Unary operators: arithmetic negation and logical NOT.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum UnaryOp {
    Neg,
    Not,
}
383
384pub fn has_subquery(expr: &Expr) -> bool {
385    match expr {
386        Expr::InSubquery { .. } | Expr::Exists { .. } | Expr::ScalarSubquery(_) => true,
387        Expr::BinaryOp { left, right, .. } => has_subquery(left) || has_subquery(right),
388        Expr::UnaryOp { expr, .. } => has_subquery(expr),
389        Expr::IsNull(e) | Expr::IsNotNull(e) => has_subquery(e),
390        Expr::InList { expr, list, .. } => has_subquery(expr) || list.iter().any(has_subquery),
391        Expr::InSet { expr, .. } => has_subquery(expr),
392        Expr::Between {
393            expr, low, high, ..
394        } => has_subquery(expr) || has_subquery(low) || has_subquery(high),
395        Expr::Like {
396            expr,
397            pattern,
398            escape,
399            ..
400        } => {
401            has_subquery(expr)
402                || has_subquery(pattern)
403                || escape.as_ref().is_some_and(|e| has_subquery(e))
404        }
405        Expr::Case {
406            operand,
407            conditions,
408            else_result,
409        } => {
410            operand.as_ref().is_some_and(|e| has_subquery(e))
411                || conditions
412                    .iter()
413                    .any(|(c, r)| has_subquery(c) || has_subquery(r))
414                || else_result.as_ref().is_some_and(|e| has_subquery(e))
415        }
416        Expr::Coalesce(args) | Expr::Function { args, .. } => args.iter().any(has_subquery),
417        Expr::Cast { expr, .. } => has_subquery(expr),
418        _ => false,
419    }
420}
421
422/// Parse a SQL expression string back into an internal Expr.
423/// Used for deserializing stored DEFAULT/CHECK expressions from schema.
424pub fn parse_sql_expr(sql: &str) -> Result<Expr> {
425    let dialect = GenericDialect {};
426    let mut parser = Parser::new(&dialect)
427        .try_with_sql(sql)
428        .map_err(|e| SqlError::Parse(e.to_string()))?;
429    let sp_expr = parser
430        .parse_expr()
431        .map_err(|e| SqlError::Parse(e.to_string()))?;
432    convert_expr(&sp_expr)
433}
434
435pub fn parse_sql(sql: &str) -> Result<Statement> {
436    let dialect = GenericDialect {};
437    let stmts = Parser::parse_sql(&dialect, sql).map_err(|e| SqlError::Parse(e.to_string()))?;
438
439    if stmts.is_empty() {
440        return Err(SqlError::Parse("empty SQL".into()));
441    }
442    if stmts.len() > 1 {
443        return Err(SqlError::Unsupported("multiple statements".into()));
444    }
445
446    convert_statement(stmts.into_iter().next().unwrap())
447}
448
449/// Returns the number of distinct parameters in a statement (max $N found).
450pub fn count_params(stmt: &Statement) -> usize {
451    let mut max_idx = 0usize;
452    visit_exprs_stmt(stmt, &mut |e| {
453        if let Expr::Parameter(n) = e {
454            max_idx = max_idx.max(*n);
455        }
456    });
457    max_idx
458}
459
460fn visit_exprs_stmt(stmt: &Statement, visitor: &mut impl FnMut(&Expr)) {
461    match stmt {
462        Statement::Select(sq) => {
463            for cte in &sq.ctes {
464                visit_exprs_query_body(&cte.body, visitor);
465            }
466            visit_exprs_query_body(&sq.body, visitor);
467        }
468        Statement::Insert(ins) => match &ins.source {
469            InsertSource::Values(rows) => {
470                for row in rows {
471                    for e in row {
472                        visit_expr(e, visitor);
473                    }
474                }
475            }
476            InsertSource::Select(sq) => {
477                for cte in &sq.ctes {
478                    visit_exprs_query_body(&cte.body, visitor);
479                }
480                visit_exprs_query_body(&sq.body, visitor);
481            }
482        },
483        Statement::Update(upd) => {
484            for (_, e) in &upd.assignments {
485                visit_expr(e, visitor);
486            }
487            if let Some(w) = &upd.where_clause {
488                visit_expr(w, visitor);
489            }
490        }
491        Statement::Delete(del) => {
492            if let Some(w) = &del.where_clause {
493                visit_expr(w, visitor);
494            }
495        }
496        Statement::Explain(inner) => visit_exprs_stmt(inner, visitor),
497        _ => {}
498    }
499}
500
501fn visit_exprs_query_body(body: &QueryBody, visitor: &mut impl FnMut(&Expr)) {
502    match body {
503        QueryBody::Select(sel) => visit_exprs_select(sel, visitor),
504        QueryBody::Compound(comp) => {
505            visit_exprs_query_body(&comp.left, visitor);
506            visit_exprs_query_body(&comp.right, visitor);
507            for o in &comp.order_by {
508                visit_expr(&o.expr, visitor);
509            }
510            if let Some(l) = &comp.limit {
511                visit_expr(l, visitor);
512            }
513            if let Some(o) = &comp.offset {
514                visit_expr(o, visitor);
515            }
516        }
517    }
518}
519
520fn visit_exprs_select(sel: &SelectStmt, visitor: &mut impl FnMut(&Expr)) {
521    for col in &sel.columns {
522        if let SelectColumn::Expr { expr, .. } = col {
523            visit_expr(expr, visitor);
524        }
525    }
526    for j in &sel.joins {
527        if let Some(on) = &j.on_clause {
528            visit_expr(on, visitor);
529        }
530    }
531    if let Some(w) = &sel.where_clause {
532        visit_expr(w, visitor);
533    }
534    for o in &sel.order_by {
535        visit_expr(&o.expr, visitor);
536    }
537    if let Some(l) = &sel.limit {
538        visit_expr(l, visitor);
539    }
540    if let Some(o) = &sel.offset {
541        visit_expr(o, visitor);
542    }
543    for g in &sel.group_by {
544        visit_expr(g, visitor);
545    }
546    if let Some(h) = &sel.having {
547        visit_expr(h, visitor);
548    }
549}
550
551fn visit_expr(expr: &Expr, visitor: &mut impl FnMut(&Expr)) {
552    visitor(expr);
553    match expr {
554        Expr::BinaryOp { left, right, .. } => {
555            visit_expr(left, visitor);
556            visit_expr(right, visitor);
557        }
558        Expr::UnaryOp { expr: e, .. } | Expr::IsNull(e) | Expr::IsNotNull(e) => {
559            visit_expr(e, visitor);
560        }
561        Expr::Function { args, .. } | Expr::Coalesce(args) => {
562            for a in args {
563                visit_expr(a, visitor);
564            }
565        }
566        Expr::InSubquery {
567            expr: e, subquery, ..
568        } => {
569            visit_expr(e, visitor);
570            visit_exprs_select(subquery, visitor);
571        }
572        Expr::InList { expr: e, list, .. } => {
573            visit_expr(e, visitor);
574            for l in list {
575                visit_expr(l, visitor);
576            }
577        }
578        Expr::Exists { subquery, .. } => visit_exprs_select(subquery, visitor),
579        Expr::ScalarSubquery(sq) => visit_exprs_select(sq, visitor),
580        Expr::InSet { expr: e, .. } => visit_expr(e, visitor),
581        Expr::Between {
582            expr: e, low, high, ..
583        } => {
584            visit_expr(e, visitor);
585            visit_expr(low, visitor);
586            visit_expr(high, visitor);
587        }
588        Expr::Like {
589            expr: e,
590            pattern,
591            escape,
592            ..
593        } => {
594            visit_expr(e, visitor);
595            visit_expr(pattern, visitor);
596            if let Some(esc) = escape {
597                visit_expr(esc, visitor);
598            }
599        }
600        Expr::Case {
601            operand,
602            conditions,
603            else_result,
604        } => {
605            if let Some(op) = operand {
606                visit_expr(op, visitor);
607            }
608            for (cond, then) in conditions {
609                visit_expr(cond, visitor);
610                visit_expr(then, visitor);
611            }
612            if let Some(el) = else_result {
613                visit_expr(el, visitor);
614            }
615        }
616        Expr::Cast { expr: e, .. } => visit_expr(e, visitor),
617        Expr::WindowFunction { args, spec, .. } => {
618            for a in args {
619                visit_expr(a, visitor);
620            }
621            for p in &spec.partition_by {
622                visit_expr(p, visitor);
623            }
624            for o in &spec.order_by {
625                visit_expr(&o.expr, visitor);
626            }
627            if let Some(ref frame) = spec.frame {
628                if let WindowFrameBound::Preceding(e) | WindowFrameBound::Following(e) =
629                    &frame.start
630                {
631                    visit_expr(e, visitor);
632                }
633                if let WindowFrameBound::Preceding(e) | WindowFrameBound::Following(e) = &frame.end
634                {
635                    visit_expr(e, visitor);
636                }
637            }
638        }
639        Expr::Literal(_)
640        | Expr::Column(_)
641        | Expr::QualifiedColumn { .. }
642        | Expr::CountStar
643        | Expr::Parameter(_) => {}
644    }
645}
646
647fn convert_statement(stmt: sp::Statement) -> Result<Statement> {
648    match stmt {
649        sp::Statement::CreateTable(ct) => convert_create_table(ct),
650        sp::Statement::CreateIndex(ci) => convert_create_index(ci),
651        sp::Statement::Drop {
652            object_type: sp::ObjectType::Table,
653            if_exists,
654            names,
655            ..
656        } => {
657            if names.len() != 1 {
658                return Err(SqlError::Unsupported("multi-table DROP".into()));
659            }
660            Ok(Statement::DropTable(DropTableStmt {
661                name: object_name_to_string(&names[0]),
662                if_exists,
663            }))
664        }
665        sp::Statement::Drop {
666            object_type: sp::ObjectType::Index,
667            if_exists,
668            names,
669            ..
670        } => {
671            if names.len() != 1 {
672                return Err(SqlError::Unsupported("multi-index DROP".into()));
673            }
674            Ok(Statement::DropIndex(DropIndexStmt {
675                index_name: object_name_to_string(&names[0]),
676                if_exists,
677            }))
678        }
679        sp::Statement::CreateView(cv) => convert_create_view(cv),
680        sp::Statement::Drop {
681            object_type: sp::ObjectType::View,
682            if_exists,
683            names,
684            ..
685        } => {
686            if names.len() != 1 {
687                return Err(SqlError::Unsupported("multi-view DROP".into()));
688            }
689            Ok(Statement::DropView(DropViewStmt {
690                name: object_name_to_string(&names[0]),
691                if_exists,
692            }))
693        }
694        sp::Statement::AlterTable(at) => convert_alter_table(at),
695        sp::Statement::Insert(insert) => convert_insert(insert),
696        sp::Statement::Query(query) => convert_query(*query),
697        sp::Statement::Update(update) => convert_update(update),
698        sp::Statement::Delete(delete) => convert_delete(delete),
699        sp::Statement::StartTransaction { .. } => Ok(Statement::Begin),
700        sp::Statement::Commit { chain: true, .. } => {
701            Err(SqlError::Unsupported("COMMIT AND CHAIN".into()))
702        }
703        sp::Statement::Commit { .. } => Ok(Statement::Commit),
704        sp::Statement::Rollback { chain: true, .. } => {
705            Err(SqlError::Unsupported("ROLLBACK AND CHAIN".into()))
706        }
707        sp::Statement::Rollback {
708            savepoint: Some(name),
709            ..
710        } => Ok(Statement::RollbackTo(name.value.to_ascii_lowercase())),
711        sp::Statement::Rollback { .. } => Ok(Statement::Rollback),
712        sp::Statement::Savepoint { name } => {
713            Ok(Statement::Savepoint(name.value.to_ascii_lowercase()))
714        }
715        sp::Statement::ReleaseSavepoint { name } => {
716            Ok(Statement::ReleaseSavepoint(name.value.to_ascii_lowercase()))
717        }
718        sp::Statement::Set(sp::Set::SetTimeZone { value, .. }) => {
719            // Accept a string literal or bare identifier (PG allows `SET TIME ZONE UTC`).
720            let zone = match value {
721                sp::Expr::Value(v) => match &v.value {
722                    sp::Value::SingleQuotedString(s) => s.clone(),
723                    sp::Value::DoubleQuotedString(s) => s.clone(),
724                    other => other.to_string(),
725                },
726                sp::Expr::Identifier(ident) => ident.value.clone(),
727                other => {
728                    return Err(SqlError::Parse(format!(
729                        "SET TIME ZONE expects a string literal or identifier, got: {other}"
730                    )))
731                }
732            };
733            Ok(Statement::SetTimezone(zone))
734        }
735        sp::Statement::Explain {
736            statement, analyze, ..
737        } => {
738            if analyze {
739                return Err(SqlError::Unsupported("EXPLAIN ANALYZE".into()));
740            }
741            let inner = convert_statement(*statement)?;
742            Ok(Statement::Explain(Box::new(inner)))
743        }
744        _ => Err(SqlError::Unsupported(format!("statement type: {}", stmt))),
745    }
746}
747
748/// Parse column options (NOT NULL, DEFAULT, CHECK, FK, UNIQUE) from a sqlparser ColumnDef.
749/// Returns (ColumnSpec, Option<ForeignKeyDef>, was_inline_pk, was_unique).
750fn convert_column_def(
751    col_def: &sp::ColumnDef,
752) -> Result<(ColumnSpec, Option<ForeignKeyDef>, bool, bool)> {
753    let col_name = col_def.name.value.clone();
754    let data_type = convert_data_type(&col_def.data_type)?;
755    let mut nullable = true;
756    let mut is_primary_key = false;
757    let mut is_unique = false;
758    let mut default_expr = None;
759    let mut default_sql = None;
760    let mut check_expr = None;
761    let mut check_sql = None;
762    let mut check_name = None;
763    let mut fk_def = None;
764
765    for opt in &col_def.options {
766        match &opt.option {
767            sp::ColumnOption::NotNull => nullable = false,
768            sp::ColumnOption::Null => nullable = true,
769            sp::ColumnOption::PrimaryKey(_) => {
770                is_primary_key = true;
771                nullable = false;
772            }
773            sp::ColumnOption::Unique(_) => is_unique = true,
774            sp::ColumnOption::Default(expr) => {
775                default_sql = Some(expr.to_string());
776                default_expr = Some(convert_expr(expr)?);
777            }
778            sp::ColumnOption::Check(check) => {
779                check_sql = Some(check.expr.to_string());
780                let converted = convert_expr(&check.expr)?;
781                if has_subquery(&converted) {
782                    return Err(SqlError::Unsupported("subquery in CHECK constraint".into()));
783                }
784                check_expr = Some(converted);
785                check_name = check.name.as_ref().map(|n| n.value.clone());
786            }
787            sp::ColumnOption::ForeignKey(fk) => {
788                convert_fk_actions(&fk.on_delete, &fk.on_update)?;
789                let ftable = object_name_to_string(&fk.foreign_table).to_ascii_lowercase();
790                let referred: Vec<String> = fk
791                    .referred_columns
792                    .iter()
793                    .map(|i| i.value.to_ascii_lowercase())
794                    .collect();
795                fk_def = Some(ForeignKeyDef {
796                    name: fk.name.as_ref().map(|n| n.value.clone()),
797                    columns: vec![col_name.to_ascii_lowercase()],
798                    foreign_table: ftable,
799                    referred_columns: referred,
800                });
801            }
802            _ => {}
803        }
804    }
805
806    let spec = ColumnSpec {
807        name: col_name,
808        data_type,
809        nullable,
810        is_primary_key,
811        default_expr,
812        default_sql,
813        check_expr,
814        check_sql,
815        check_name,
816    };
817    Ok((spec, fk_def, is_primary_key, is_unique))
818}
819
820fn convert_create_table(ct: sp::CreateTable) -> Result<Statement> {
821    let name = object_name_to_string(&ct.name);
822    let if_not_exists = ct.if_not_exists;
823
824    let mut columns = Vec::new();
825    let mut inline_pk: Vec<String> = Vec::new();
826    let mut foreign_keys: Vec<ForeignKeyDef> = Vec::new();
827    let mut unique_indices: Vec<UniqueIndexDef> = Vec::new();
828
829    for col_def in &ct.columns {
830        let (spec, fk_def, was_pk, was_unique) = convert_column_def(col_def)?;
831        if was_pk {
832            inline_pk.push(spec.name.clone());
833        }
834        if let Some(fk) = fk_def {
835            foreign_keys.push(fk);
836        }
837        if was_unique && !was_pk {
838            unique_indices.push(UniqueIndexDef {
839                name: None,
840                columns: vec![spec.name.to_ascii_lowercase()],
841            });
842        }
843        columns.push(spec);
844    }
845
846    let mut check_constraints: Vec<TableCheckConstraint> = Vec::new();
847
848    for constraint in &ct.constraints {
849        match constraint {
850            sp::TableConstraint::PrimaryKey(pk_constraint) => {
851                for idx_col in &pk_constraint.columns {
852                    let col_name = match &idx_col.column.expr {
853                        sp::Expr::Identifier(ident) => ident.value.clone(),
854                        _ => continue,
855                    };
856                    if !inline_pk.contains(&col_name) {
857                        inline_pk.push(col_name.clone());
858                    }
859                    if let Some(col) = columns.iter_mut().find(|c| c.name == col_name) {
860                        col.nullable = false;
861                        col.is_primary_key = true;
862                    }
863                }
864            }
865            sp::TableConstraint::Check(check) => {
866                let sql = check.expr.to_string();
867                let converted = convert_expr(&check.expr)?;
868                if has_subquery(&converted) {
869                    return Err(SqlError::Unsupported("subquery in CHECK constraint".into()));
870                }
871                check_constraints.push(TableCheckConstraint {
872                    name: check.name.as_ref().map(|n| n.value.clone()),
873                    expr: converted,
874                    sql,
875                });
876            }
877            sp::TableConstraint::ForeignKey(fk) => {
878                convert_fk_actions(&fk.on_delete, &fk.on_update)?;
879                let cols: Vec<String> = fk
880                    .columns
881                    .iter()
882                    .map(|i| i.value.to_ascii_lowercase())
883                    .collect();
884                let ftable = object_name_to_string(&fk.foreign_table).to_ascii_lowercase();
885                let referred: Vec<String> = fk
886                    .referred_columns
887                    .iter()
888                    .map(|i| i.value.to_ascii_lowercase())
889                    .collect();
890                foreign_keys.push(ForeignKeyDef {
891                    name: fk.name.as_ref().map(|n| n.value.clone()),
892                    columns: cols,
893                    foreign_table: ftable,
894                    referred_columns: referred,
895                });
896            }
897            sp::TableConstraint::Unique(u) => {
898                let cols: Vec<String> = u
899                    .columns
900                    .iter()
901                    .filter_map(|idx_col| match &idx_col.column.expr {
902                        sp::Expr::Identifier(ident) => Some(ident.value.to_ascii_lowercase()),
903                        _ => None,
904                    })
905                    .collect();
906                if !cols.is_empty() {
907                    unique_indices.push(UniqueIndexDef {
908                        name: u.name.as_ref().map(|n| n.value.clone()),
909                        columns: cols,
910                    });
911                }
912            }
913            _ => {}
914        }
915    }
916
917    Ok(Statement::CreateTable(CreateTableStmt {
918        name,
919        columns,
920        primary_key: inline_pk,
921        if_not_exists,
922        check_constraints,
923        foreign_keys,
924        unique_indices,
925    }))
926}
927
928fn convert_alter_table(at: sp::AlterTable) -> Result<Statement> {
929    let table = object_name_to_string(&at.name);
930    if at.operations.len() != 1 {
931        return Err(SqlError::Unsupported(
932            "ALTER TABLE with multiple operations".into(),
933        ));
934    }
935    let op = match at.operations.into_iter().next().unwrap() {
936        sp::AlterTableOperation::AddColumn {
937            column_def,
938            if_not_exists,
939            ..
940        } => {
941            let (spec, fk, _was_pk, _was_unique) = convert_column_def(&column_def)?;
942            AlterTableOp::AddColumn {
943                column: Box::new(spec),
944                foreign_key: fk,
945                if_not_exists,
946            }
947        }
948        sp::AlterTableOperation::DropColumn {
949            column_names,
950            if_exists,
951            ..
952        } => {
953            if column_names.len() != 1 {
954                return Err(SqlError::Unsupported(
955                    "DROP COLUMN with multiple columns".into(),
956                ));
957            }
958            AlterTableOp::DropColumn {
959                name: column_names.into_iter().next().unwrap().value,
960                if_exists,
961            }
962        }
963        sp::AlterTableOperation::RenameColumn {
964            old_column_name,
965            new_column_name,
966        } => AlterTableOp::RenameColumn {
967            old_name: old_column_name.value,
968            new_name: new_column_name.value,
969        },
970        sp::AlterTableOperation::RenameTable { table_name } => {
971            let new_name = match table_name {
972                sp::RenameTableNameKind::To(name) | sp::RenameTableNameKind::As(name) => {
973                    object_name_to_string(&name)
974                }
975            };
976            AlterTableOp::RenameTable { new_name }
977        }
978        other => {
979            return Err(SqlError::Unsupported(format!(
980                "ALTER TABLE operation: {other}"
981            )));
982        }
983    };
984    Ok(Statement::AlterTable(Box::new(AlterTableStmt {
985        table,
986        op,
987    })))
988}
989
990fn convert_fk_actions(
991    on_delete: &Option<sp::ReferentialAction>,
992    on_update: &Option<sp::ReferentialAction>,
993) -> Result<()> {
994    for action in [on_delete, on_update] {
995        match action {
996            None
997            | Some(sp::ReferentialAction::Restrict)
998            | Some(sp::ReferentialAction::NoAction) => {}
999            Some(other) => {
1000                return Err(SqlError::Unsupported(format!(
1001                    "FOREIGN KEY action: {other}"
1002                )));
1003            }
1004        }
1005    }
1006    Ok(())
1007}
1008
1009fn convert_create_index(ci: sp::CreateIndex) -> Result<Statement> {
1010    let index_name = ci
1011        .name
1012        .as_ref()
1013        .map(object_name_to_string)
1014        .ok_or_else(|| SqlError::Parse("index name required".into()))?;
1015
1016    let table_name = object_name_to_string(&ci.table_name);
1017
1018    let columns: Vec<String> = ci
1019        .columns
1020        .iter()
1021        .map(|idx_col| match &idx_col.column.expr {
1022            sp::Expr::Identifier(ident) => Ok(ident.value.clone()),
1023            other => Err(SqlError::Unsupported(format!("expression index: {other}"))),
1024        })
1025        .collect::<Result<_>>()?;
1026
1027    if columns.is_empty() {
1028        return Err(SqlError::Parse(
1029            "index must have at least one column".into(),
1030        ));
1031    }
1032
1033    Ok(Statement::CreateIndex(CreateIndexStmt {
1034        index_name,
1035        table_name,
1036        columns,
1037        unique: ci.unique,
1038        if_not_exists: ci.if_not_exists,
1039    }))
1040}
1041
1042fn convert_create_view(cv: sp::CreateView) -> Result<Statement> {
1043    let name = object_name_to_string(&cv.name);
1044
1045    if cv.materialized {
1046        return Err(SqlError::Unsupported("MATERIALIZED VIEW".into()));
1047    }
1048
1049    let sql = cv.query.to_string();
1050
1051    let dialect = GenericDialect {};
1052    let test = Parser::parse_sql(&dialect, &sql).map_err(|e| SqlError::Parse(e.to_string()))?;
1053    if test.is_empty() {
1054        return Err(SqlError::Parse("empty view definition".into()));
1055    }
1056    match &test[0] {
1057        sp::Statement::Query(_) => {}
1058        _ => {
1059            return Err(SqlError::Parse(
1060                "view body must be a SELECT statement".into(),
1061            ))
1062        }
1063    }
1064
1065    let column_aliases: Vec<String> = cv
1066        .columns
1067        .iter()
1068        .map(|c| c.name.value.to_ascii_lowercase())
1069        .collect();
1070
1071    Ok(Statement::CreateView(CreateViewStmt {
1072        name,
1073        sql,
1074        column_aliases,
1075        or_replace: cv.or_replace,
1076        if_not_exists: cv.if_not_exists,
1077    }))
1078}
1079
1080fn convert_insert(insert: sp::Insert) -> Result<Statement> {
1081    let table = match &insert.table {
1082        sp::TableObject::TableName(name) => object_name_to_string(name).to_ascii_lowercase(),
1083        _ => return Err(SqlError::Unsupported("INSERT into non-table object".into())),
1084    };
1085
1086    let columns: Vec<String> = insert
1087        .columns
1088        .iter()
1089        .map(|c| c.value.to_ascii_lowercase())
1090        .collect();
1091
1092    let query = insert
1093        .source
1094        .ok_or_else(|| SqlError::Parse("INSERT requires VALUES or SELECT".into()))?;
1095
1096    let source = match *query.body {
1097        sp::SetExpr::Values(sp::Values { rows, .. }) => {
1098            let mut result = Vec::new();
1099            for row in rows {
1100                let mut exprs = Vec::new();
1101                for expr in row {
1102                    exprs.push(convert_expr(&expr)?);
1103                }
1104                result.push(exprs);
1105            }
1106            InsertSource::Values(result)
1107        }
1108        _ => {
1109            let (ctes, recursive) = if let Some(ref with) = query.with {
1110                convert_with(with)?
1111            } else {
1112                (vec![], false)
1113            };
1114            let body = convert_query_body(&query)?;
1115            InsertSource::Select(Box::new(SelectQuery {
1116                ctes,
1117                recursive,
1118                body,
1119            }))
1120        }
1121    };
1122
1123    Ok(Statement::Insert(InsertStmt {
1124        table,
1125        columns,
1126        source,
1127    }))
1128}
1129
1130fn convert_select_body(select: &sp::Select) -> Result<SelectStmt> {
1131    let distinct = match &select.distinct {
1132        Some(sp::Distinct::Distinct) => true,
1133        Some(sp::Distinct::On(_)) => {
1134            return Err(SqlError::Unsupported("DISTINCT ON".into()));
1135        }
1136        _ => false,
1137    };
1138
1139    let (from, from_alias, joins) = if select.from.is_empty() {
1140        (String::new(), None, vec![])
1141    } else if select.from.len() == 1 {
1142        let table_with_joins = &select.from[0];
1143        let (name, alias) = match &table_with_joins.relation {
1144            sp::TableFactor::Table { name, alias, .. } => {
1145                let table_name = object_name_to_string(name);
1146                let alias_str = alias.as_ref().map(|a| a.name.value.clone());
1147                (table_name, alias_str)
1148            }
1149            _ => return Err(SqlError::Unsupported("non-table FROM source".into())),
1150        };
1151        let j = table_with_joins
1152            .joins
1153            .iter()
1154            .map(convert_join)
1155            .collect::<Result<Vec<_>>>()?;
1156        (name, alias, j)
1157    } else {
1158        return Err(SqlError::Unsupported("comma-separated FROM tables".into()));
1159    };
1160
1161    let columns: Vec<SelectColumn> = select
1162        .projection
1163        .iter()
1164        .map(convert_select_item)
1165        .collect::<Result<_>>()?;
1166
1167    let where_clause = select.selection.as_ref().map(convert_expr).transpose()?;
1168
1169    let group_by = match &select.group_by {
1170        sp::GroupByExpr::Expressions(exprs, _) => {
1171            exprs.iter().map(convert_expr).collect::<Result<_>>()?
1172        }
1173        sp::GroupByExpr::All(_) => {
1174            return Err(SqlError::Unsupported("GROUP BY ALL".into()));
1175        }
1176    };
1177
1178    let having = select.having.as_ref().map(convert_expr).transpose()?;
1179
1180    Ok(SelectStmt {
1181        columns,
1182        from,
1183        from_alias,
1184        joins,
1185        distinct,
1186        where_clause,
1187        order_by: vec![],
1188        limit: None,
1189        offset: None,
1190        group_by,
1191        having,
1192    })
1193}
1194
1195fn convert_set_expr(set_expr: &sp::SetExpr) -> Result<QueryBody> {
1196    match set_expr {
1197        sp::SetExpr::Select(sel) => Ok(QueryBody::Select(Box::new(convert_select_body(sel)?))),
1198        sp::SetExpr::SetOperation {
1199            op,
1200            set_quantifier,
1201            left,
1202            right,
1203        } => {
1204            let set_op = match op {
1205                sp::SetOperator::Union => SetOp::Union,
1206                sp::SetOperator::Intersect => SetOp::Intersect,
1207                sp::SetOperator::Except | sp::SetOperator::Minus => SetOp::Except,
1208            };
1209            let all = match set_quantifier {
1210                sp::SetQuantifier::All => true,
1211                sp::SetQuantifier::None | sp::SetQuantifier::Distinct => false,
1212                _ => {
1213                    return Err(SqlError::Unsupported("BY NAME set operations".into()));
1214                }
1215            };
1216            Ok(QueryBody::Compound(Box::new(CompoundSelect {
1217                op: set_op,
1218                all,
1219                left: Box::new(convert_set_expr(left)?),
1220                right: Box::new(convert_set_expr(right)?),
1221                order_by: vec![],
1222                limit: None,
1223                offset: None,
1224            })))
1225        }
1226        _ => Err(SqlError::Unsupported("unsupported set expression".into())),
1227    }
1228}
1229
1230fn convert_query_body(query: &sp::Query) -> Result<QueryBody> {
1231    let mut body = convert_set_expr(&query.body)?;
1232
1233    let order_by = if let Some(ref ob) = query.order_by {
1234        match &ob.kind {
1235            sp::OrderByKind::Expressions(exprs) => exprs
1236                .iter()
1237                .map(convert_order_by_expr)
1238                .collect::<Result<_>>()?,
1239            sp::OrderByKind::All { .. } => {
1240                return Err(SqlError::Unsupported("ORDER BY ALL".into()));
1241            }
1242        }
1243    } else {
1244        vec![]
1245    };
1246
1247    let (limit, offset) = match &query.limit_clause {
1248        Some(sp::LimitClause::LimitOffset { limit, offset, .. }) => {
1249            let l = limit.as_ref().map(convert_expr).transpose()?;
1250            let o = offset
1251                .as_ref()
1252                .map(|o| convert_expr(&o.value))
1253                .transpose()?;
1254            (l, o)
1255        }
1256        Some(sp::LimitClause::OffsetCommaLimit { limit, offset }) => {
1257            let l = Some(convert_expr(limit)?);
1258            let o = Some(convert_expr(offset)?);
1259            (l, o)
1260        }
1261        None => (None, None),
1262    };
1263
1264    match &mut body {
1265        QueryBody::Select(sel) => {
1266            sel.order_by = order_by;
1267            sel.limit = limit;
1268            sel.offset = offset;
1269        }
1270        QueryBody::Compound(comp) => {
1271            comp.order_by = order_by;
1272            comp.limit = limit;
1273            comp.offset = offset;
1274        }
1275    }
1276
1277    Ok(body)
1278}
1279
1280fn convert_subquery(query: &sp::Query) -> Result<SelectStmt> {
1281    if query.with.is_some() {
1282        return Err(SqlError::Unsupported("CTEs in subqueries".into()));
1283    }
1284    match convert_query_body(query)? {
1285        QueryBody::Select(s) => Ok(*s),
1286        QueryBody::Compound(_) => Err(SqlError::Unsupported(
1287            "UNION/INTERSECT/EXCEPT in subqueries".into(),
1288        )),
1289    }
1290}
1291
1292fn convert_with(with: &sp::With) -> Result<(Vec<CteDefinition>, bool)> {
1293    let mut names = std::collections::HashSet::new();
1294    let mut ctes = Vec::new();
1295    for cte in &with.cte_tables {
1296        let name = cte.alias.name.value.to_ascii_lowercase();
1297        if !names.insert(name.clone()) {
1298            return Err(SqlError::DuplicateCteName(name));
1299        }
1300        let column_aliases: Vec<String> = cte
1301            .alias
1302            .columns
1303            .iter()
1304            .map(|c| c.name.value.to_ascii_lowercase())
1305            .collect();
1306        let body = convert_query_body(&cte.query)?;
1307        ctes.push(CteDefinition {
1308            name,
1309            column_aliases,
1310            body,
1311        });
1312    }
1313    Ok((ctes, with.recursive))
1314}
1315
1316fn convert_query(query: sp::Query) -> Result<Statement> {
1317    let (ctes, recursive) = if let Some(ref with) = query.with {
1318        convert_with(with)?
1319    } else {
1320        (vec![], false)
1321    };
1322    let body = convert_query_body(&query)?;
1323    Ok(Statement::Select(Box::new(SelectQuery {
1324        ctes,
1325        recursive,
1326        body,
1327    })))
1328}
1329
1330fn convert_join(join: &sp::Join) -> Result<JoinClause> {
1331    let (join_type, constraint) = match &join.join_operator {
1332        sp::JoinOperator::Inner(c) => (JoinType::Inner, Some(c)),
1333        sp::JoinOperator::Join(c) => (JoinType::Inner, Some(c)),
1334        sp::JoinOperator::CrossJoin(c) => (JoinType::Cross, Some(c)),
1335        sp::JoinOperator::Left(c) => (JoinType::Left, Some(c)),
1336        sp::JoinOperator::LeftSemi(c) => (JoinType::Left, Some(c)),
1337        sp::JoinOperator::LeftAnti(c) => (JoinType::Left, Some(c)),
1338        sp::JoinOperator::Right(c) => (JoinType::Right, Some(c)),
1339        sp::JoinOperator::RightSemi(c) => (JoinType::Right, Some(c)),
1340        sp::JoinOperator::RightAnti(c) => (JoinType::Right, Some(c)),
1341        other => return Err(SqlError::Unsupported(format!("join type: {other:?}"))),
1342    };
1343
1344    let (name, alias) = match &join.relation {
1345        sp::TableFactor::Table { name, alias, .. } => {
1346            let table_name = object_name_to_string(name);
1347            let alias_str = alias.as_ref().map(|a| a.name.value.clone());
1348            (table_name, alias_str)
1349        }
1350        _ => return Err(SqlError::Unsupported("non-table JOIN source".into())),
1351    };
1352
1353    let on_clause = match constraint {
1354        Some(sp::JoinConstraint::On(expr)) => Some(convert_expr(expr)?),
1355        Some(sp::JoinConstraint::None) | None => None,
1356        Some(other) => return Err(SqlError::Unsupported(format!("join constraint: {other:?}"))),
1357    };
1358
1359    Ok(JoinClause {
1360        join_type,
1361        table: TableRef { name, alias },
1362        on_clause,
1363    })
1364}
1365
1366fn convert_update(update: sp::Update) -> Result<Statement> {
1367    let table = match &update.table.relation {
1368        sp::TableFactor::Table { name, .. } => object_name_to_string(name),
1369        _ => return Err(SqlError::Unsupported("non-table UPDATE target".into())),
1370    };
1371
1372    let assignments = update
1373        .assignments
1374        .iter()
1375        .map(|a| {
1376            let col = match &a.target {
1377                sp::AssignmentTarget::ColumnName(name) => object_name_to_string(name),
1378                _ => return Err(SqlError::Unsupported("tuple assignment".into())),
1379            };
1380            let expr = convert_expr(&a.value)?;
1381            Ok((col, expr))
1382        })
1383        .collect::<Result<_>>()?;
1384
1385    let where_clause = update.selection.as_ref().map(convert_expr).transpose()?;
1386
1387    Ok(Statement::Update(UpdateStmt {
1388        table,
1389        assignments,
1390        where_clause,
1391    }))
1392}
1393
1394fn convert_delete(delete: sp::Delete) -> Result<Statement> {
1395    let table_name = match &delete.from {
1396        sp::FromTable::WithFromKeyword(tables) => {
1397            if tables.len() != 1 {
1398                return Err(SqlError::Unsupported("multi-table DELETE".into()));
1399            }
1400            match &tables[0].relation {
1401                sp::TableFactor::Table { name, .. } => object_name_to_string(name),
1402                _ => return Err(SqlError::Unsupported("non-table DELETE target".into())),
1403            }
1404        }
1405        sp::FromTable::WithoutKeyword(tables) => {
1406            if tables.len() != 1 {
1407                return Err(SqlError::Unsupported("multi-table DELETE".into()));
1408            }
1409            match &tables[0].relation {
1410                sp::TableFactor::Table { name, .. } => object_name_to_string(name),
1411                _ => return Err(SqlError::Unsupported("non-table DELETE target".into())),
1412            }
1413        }
1414    };
1415
1416    let where_clause = delete.selection.as_ref().map(convert_expr).transpose()?;
1417
1418    Ok(Statement::Delete(DeleteStmt {
1419        table: table_name,
1420        where_clause,
1421    }))
1422}
1423
1424fn convert_expr(expr: &sp::Expr) -> Result<Expr> {
1425    match expr {
1426        sp::Expr::Value(v) => convert_value(&v.value),
1427        sp::Expr::Identifier(ident) => Ok(Expr::Column(ident.value.to_ascii_lowercase())),
1428        sp::Expr::CompoundIdentifier(parts) => {
1429            if parts.len() == 2 {
1430                Ok(Expr::QualifiedColumn {
1431                    table: parts[0].value.to_ascii_lowercase(),
1432                    column: parts[1].value.to_ascii_lowercase(),
1433                })
1434            } else {
1435                Ok(Expr::Column(
1436                    parts.last().unwrap().value.to_ascii_lowercase(),
1437                ))
1438            }
1439        }
1440        sp::Expr::BinaryOp { left, op, right } => {
1441            let bin_op = convert_bin_op(op)?;
1442            Ok(Expr::BinaryOp {
1443                left: Box::new(convert_expr(left)?),
1444                op: bin_op,
1445                right: Box::new(convert_expr(right)?),
1446            })
1447        }
1448        sp::Expr::UnaryOp { op, expr } => {
1449            let unary_op = match op {
1450                sp::UnaryOperator::Minus => UnaryOp::Neg,
1451                sp::UnaryOperator::Not => UnaryOp::Not,
1452                _ => return Err(SqlError::Unsupported(format!("unary op: {op}"))),
1453            };
1454            Ok(Expr::UnaryOp {
1455                op: unary_op,
1456                expr: Box::new(convert_expr(expr)?),
1457            })
1458        }
1459        sp::Expr::IsNull(e) => Ok(Expr::IsNull(Box::new(convert_expr(e)?))),
1460        sp::Expr::IsNotNull(e) => Ok(Expr::IsNotNull(Box::new(convert_expr(e)?))),
1461        sp::Expr::Nested(e) => convert_expr(e),
1462        sp::Expr::Function(func) => convert_function(func),
1463        sp::Expr::InSubquery {
1464            expr: e,
1465            subquery,
1466            negated,
1467        } => {
1468            let inner_expr = convert_expr(e)?;
1469            let stmt = convert_subquery(subquery)?;
1470            Ok(Expr::InSubquery {
1471                expr: Box::new(inner_expr),
1472                subquery: Box::new(stmt),
1473                negated: *negated,
1474            })
1475        }
1476        sp::Expr::InList {
1477            expr: e,
1478            list,
1479            negated,
1480        } => {
1481            let inner_expr = convert_expr(e)?;
1482            let items = list.iter().map(convert_expr).collect::<Result<Vec<_>>>()?;
1483            Ok(Expr::InList {
1484                expr: Box::new(inner_expr),
1485                list: items,
1486                negated: *negated,
1487            })
1488        }
1489        sp::Expr::Exists { subquery, negated } => {
1490            let stmt = convert_subquery(subquery)?;
1491            Ok(Expr::Exists {
1492                subquery: Box::new(stmt),
1493                negated: *negated,
1494            })
1495        }
1496        sp::Expr::Subquery(query) => {
1497            let stmt = convert_subquery(query)?;
1498            Ok(Expr::ScalarSubquery(Box::new(stmt)))
1499        }
1500        sp::Expr::Between {
1501            expr: e,
1502            negated,
1503            low,
1504            high,
1505        } => Ok(Expr::Between {
1506            expr: Box::new(convert_expr(e)?),
1507            low: Box::new(convert_expr(low)?),
1508            high: Box::new(convert_expr(high)?),
1509            negated: *negated,
1510        }),
1511        sp::Expr::Like {
1512            expr: e,
1513            negated,
1514            pattern,
1515            escape_char,
1516            ..
1517        } => {
1518            let esc = escape_char
1519                .as_ref()
1520                .map(convert_escape_value)
1521                .transpose()?
1522                .map(Box::new);
1523            Ok(Expr::Like {
1524                expr: Box::new(convert_expr(e)?),
1525                pattern: Box::new(convert_expr(pattern)?),
1526                escape: esc,
1527                negated: *negated,
1528            })
1529        }
1530        sp::Expr::ILike {
1531            expr: e,
1532            negated,
1533            pattern,
1534            escape_char,
1535            ..
1536        } => {
1537            let esc = escape_char
1538                .as_ref()
1539                .map(convert_escape_value)
1540                .transpose()?
1541                .map(Box::new);
1542            Ok(Expr::Like {
1543                expr: Box::new(convert_expr(e)?),
1544                pattern: Box::new(convert_expr(pattern)?),
1545                escape: esc,
1546                negated: *negated,
1547            })
1548        }
1549        sp::Expr::Case {
1550            operand,
1551            conditions,
1552            else_result,
1553            ..
1554        } => {
1555            let op = operand
1556                .as_ref()
1557                .map(|e| convert_expr(e))
1558                .transpose()?
1559                .map(Box::new);
1560            let conds: Vec<(Expr, Expr)> = conditions
1561                .iter()
1562                .map(|cw| Ok((convert_expr(&cw.condition)?, convert_expr(&cw.result)?)))
1563                .collect::<Result<_>>()?;
1564            let else_r = else_result
1565                .as_ref()
1566                .map(|e| convert_expr(e))
1567                .transpose()?
1568                .map(Box::new);
1569            Ok(Expr::Case {
1570                operand: op,
1571                conditions: conds,
1572                else_result: else_r,
1573            })
1574        }
1575        sp::Expr::Cast {
1576            expr: e,
1577            data_type: dt,
1578            ..
1579        } => {
1580            let target = convert_data_type(dt)?;
1581            Ok(Expr::Cast {
1582                expr: Box::new(convert_expr(e)?),
1583                data_type: target,
1584            })
1585        }
1586        sp::Expr::Substring {
1587            expr: e,
1588            substring_from,
1589            substring_for,
1590            ..
1591        } => {
1592            let mut args = vec![convert_expr(e)?];
1593            if let Some(from) = substring_from {
1594                args.push(convert_expr(from)?);
1595            }
1596            if let Some(f) = substring_for {
1597                args.push(convert_expr(f)?);
1598            }
1599            Ok(Expr::Function {
1600                name: "SUBSTR".into(),
1601                args,
1602                distinct: false,
1603            })
1604        }
1605        sp::Expr::Trim {
1606            expr: e,
1607            trim_where,
1608            trim_what,
1609            trim_characters,
1610        } => {
1611            let fn_name = match trim_where {
1612                Some(sp::TrimWhereField::Leading) => "LTRIM",
1613                Some(sp::TrimWhereField::Trailing) => "RTRIM",
1614                _ => "TRIM",
1615            };
1616            let mut args = vec![convert_expr(e)?];
1617            if let Some(what) = trim_what {
1618                args.push(convert_expr(what)?);
1619            } else if let Some(chars) = trim_characters {
1620                if let Some(first) = chars.first() {
1621                    args.push(convert_expr(first)?);
1622                }
1623            }
1624            Ok(Expr::Function {
1625                name: fn_name.into(),
1626                args,
1627                distinct: false,
1628            })
1629        }
1630        sp::Expr::Ceil { expr: e, .. } => Ok(Expr::Function {
1631            name: "CEIL".into(),
1632            args: vec![convert_expr(e)?],
1633            distinct: false,
1634        }),
1635        sp::Expr::Floor { expr: e, .. } => Ok(Expr::Function {
1636            name: "FLOOR".into(),
1637            args: vec![convert_expr(e)?],
1638            distinct: false,
1639        }),
1640        sp::Expr::Position { expr: e, r#in } => Ok(Expr::Function {
1641            name: "INSTR".into(),
1642            args: vec![convert_expr(r#in)?, convert_expr(e)?],
1643            distinct: false,
1644        }),
1645        // Typed literal: `DATE '2024-01-15'`, `TIMESTAMP '...'`, etc.
1646        sp::Expr::TypedString(ts) => {
1647            let raw = match &ts.value.value {
1648                sp::Value::SingleQuotedString(s) => s.clone(),
1649                sp::Value::DoubleQuotedString(s) => s.clone(),
1650                other => other.to_string(),
1651            };
1652            convert_typed_string(&ts.data_type, &raw)
1653        }
1654        // INTERVAL '...' — sqlparser emits Expr::Interval with a boxed Expr value + field qualifiers
1655        sp::Expr::Interval(iv) => convert_interval_expr(iv),
1656        // EXTRACT(field FROM src)
1657        sp::Expr::Extract { field, expr: e, .. } => {
1658            let field_name = match field {
1659                sp::DateTimeField::Year => "year",
1660                sp::DateTimeField::Month => "month",
1661                sp::DateTimeField::Week(_) => "week",
1662                sp::DateTimeField::Day => "day",
1663                sp::DateTimeField::Date => "day",
1664                sp::DateTimeField::Hour => "hour",
1665                sp::DateTimeField::Minute => "minute",
1666                sp::DateTimeField::Second => "second",
1667                sp::DateTimeField::Millisecond => "milliseconds",
1668                sp::DateTimeField::Microsecond => "microseconds",
1669                sp::DateTimeField::Microseconds => "microseconds",
1670                sp::DateTimeField::Milliseconds => "milliseconds",
1671                sp::DateTimeField::Dow => "dow",
1672                sp::DateTimeField::Isodow => "isodow",
1673                sp::DateTimeField::Doy => "doy",
1674                sp::DateTimeField::Epoch => "epoch",
1675                sp::DateTimeField::Quarter => "quarter",
1676                sp::DateTimeField::Decade => "decade",
1677                sp::DateTimeField::Century => "century",
1678                sp::DateTimeField::Millennium => "millennium",
1679                sp::DateTimeField::Isoyear => "isoyear",
1680                sp::DateTimeField::Julian => "julian",
1681                other => {
1682                    return Err(SqlError::InvalidExtractField(format!("{other:?}")));
1683                }
1684            };
1685            Ok(Expr::Function {
1686                name: "EXTRACT".into(),
1687                args: vec![
1688                    Expr::Literal(Value::Text(field_name.into())),
1689                    convert_expr(e)?,
1690                ],
1691                distinct: false,
1692            })
1693        }
1694        // `AT TIME ZONE 'zone'` operator — desugars to AT_TIMEZONE(ts, zone) scalar function.
1695        sp::Expr::AtTimeZone {
1696            timestamp,
1697            time_zone,
1698        } => Ok(Expr::Function {
1699            name: "AT_TIMEZONE".into(),
1700            args: vec![convert_expr(timestamp)?, convert_expr(time_zone)?],
1701            distinct: false,
1702        }),
1703        _ => Err(SqlError::Unsupported(format!("expression: {expr}"))),
1704    }
1705}
1706
1707fn convert_value(val: &sp::Value) -> Result<Expr> {
1708    match val {
1709        sp::Value::Number(n, _) => {
1710            if let Ok(i) = n.parse::<i64>() {
1711                Ok(Expr::Literal(Value::Integer(i)))
1712            } else if let Ok(f) = n.parse::<f64>() {
1713                Ok(Expr::Literal(Value::Real(f)))
1714            } else {
1715                Err(SqlError::InvalidValue(format!("cannot parse number: {n}")))
1716            }
1717        }
1718        sp::Value::SingleQuotedString(s) => Ok(Expr::Literal(Value::Text(s.as_str().into()))),
1719        sp::Value::Boolean(b) => Ok(Expr::Literal(Value::Boolean(*b))),
1720        sp::Value::Null => Ok(Expr::Literal(Value::Null)),
1721        sp::Value::Placeholder(s) => {
1722            let idx_str = s
1723                .strip_prefix('$')
1724                .ok_or_else(|| SqlError::Parse(format!("invalid placeholder: {s}")))?;
1725            let idx: usize = idx_str
1726                .parse()
1727                .map_err(|_| SqlError::Parse(format!("invalid placeholder index: {s}")))?;
1728            if idx == 0 {
1729                return Err(SqlError::Parse("placeholder index must be >= 1".into()));
1730            }
1731            Ok(Expr::Parameter(idx))
1732        }
1733        _ => Err(SqlError::Unsupported(format!("value type: {val}"))),
1734    }
1735}
1736
1737fn convert_typed_string(dt: &sp::DataType, value: &str) -> Result<Expr> {
1738    let s = value.trim_matches('\'');
1739    match dt {
1740        sp::DataType::Date => {
1741            let d = crate::datetime::parse_date(s)?;
1742            Ok(Expr::Literal(Value::Date(d)))
1743        }
1744        sp::DataType::Time(_, _) => {
1745            let t = crate::datetime::parse_time(s)?;
1746            Ok(Expr::Literal(Value::Time(t)))
1747        }
1748        sp::DataType::Timestamp(_, _) => {
1749            let t = crate::datetime::parse_timestamp(s)?;
1750            Ok(Expr::Literal(Value::Timestamp(t)))
1751        }
1752        sp::DataType::Interval { .. } => {
1753            let (months, days, micros) = crate::datetime::parse_interval(s)?;
1754            Ok(Expr::Literal(Value::Interval {
1755                months,
1756                days,
1757                micros,
1758            }))
1759        }
1760        _ => {
1761            let target = convert_data_type(dt)?;
1762            Ok(Expr::Cast {
1763                expr: Box::new(Expr::Literal(Value::Text(s.into()))),
1764                data_type: target,
1765            })
1766        }
1767    }
1768}
1769
1770fn convert_interval_expr(iv: &sp::Interval) -> Result<Expr> {
1771    let raw = match iv.value.as_ref() {
1772        sp::Expr::Value(v) => match &v.value {
1773            sp::Value::SingleQuotedString(s) => s.clone(),
1774            sp::Value::Number(n, _) => n.clone(),
1775            other => {
1776                return Err(SqlError::InvalidIntervalLiteral(format!(
1777                    "unsupported inner value: {other}"
1778                )))
1779            }
1780        },
1781        other => {
1782            return Err(SqlError::InvalidIntervalLiteral(format!(
1783                "unsupported inner expr: {other}"
1784            )))
1785        }
1786    };
1787
1788    // SQL-standard form `INTERVAL '5' DAY` — append the unit to the literal.
1789    let with_unit = if let Some(field) = &iv.leading_field {
1790        let unit_name = match field {
1791            sp::DateTimeField::Year => "years",
1792            sp::DateTimeField::Month => "months",
1793            sp::DateTimeField::Week(_) => "weeks",
1794            sp::DateTimeField::Day => "days",
1795            sp::DateTimeField::Hour => "hours",
1796            sp::DateTimeField::Minute => "minutes",
1797            sp::DateTimeField::Second => "seconds",
1798            _ => {
1799                return Err(SqlError::InvalidIntervalLiteral(format!(
1800                    "unsupported leading field: {field:?}"
1801                )))
1802            }
1803        };
1804        format!("{raw} {unit_name}")
1805    } else {
1806        raw
1807    };
1808
1809    let (months, days, micros) = crate::datetime::parse_interval(&with_unit)?;
1810    Ok(Expr::Literal(Value::Interval {
1811        months,
1812        days,
1813        micros,
1814    }))
1815}
1816
1817fn convert_escape_value(val: &sp::Value) -> Result<Expr> {
1818    match val {
1819        sp::Value::SingleQuotedString(s) => Ok(Expr::Literal(Value::Text(s.as_str().into()))),
1820        _ => Err(SqlError::Unsupported(format!("ESCAPE value: {val}"))),
1821    }
1822}
1823
1824fn convert_bin_op(op: &sp::BinaryOperator) -> Result<BinOp> {
1825    match op {
1826        sp::BinaryOperator::Plus => Ok(BinOp::Add),
1827        sp::BinaryOperator::Minus => Ok(BinOp::Sub),
1828        sp::BinaryOperator::Multiply => Ok(BinOp::Mul),
1829        sp::BinaryOperator::Divide => Ok(BinOp::Div),
1830        sp::BinaryOperator::Modulo => Ok(BinOp::Mod),
1831        sp::BinaryOperator::Eq => Ok(BinOp::Eq),
1832        sp::BinaryOperator::NotEq => Ok(BinOp::NotEq),
1833        sp::BinaryOperator::Lt => Ok(BinOp::Lt),
1834        sp::BinaryOperator::Gt => Ok(BinOp::Gt),
1835        sp::BinaryOperator::LtEq => Ok(BinOp::LtEq),
1836        sp::BinaryOperator::GtEq => Ok(BinOp::GtEq),
1837        sp::BinaryOperator::And => Ok(BinOp::And),
1838        sp::BinaryOperator::Or => Ok(BinOp::Or),
1839        sp::BinaryOperator::StringConcat => Ok(BinOp::Concat),
1840        _ => Err(SqlError::Unsupported(format!("binary op: {op}"))),
1841    }
1842}
1843
1844fn convert_function(func: &sp::Function) -> Result<Expr> {
1845    let name = object_name_to_string(&func.name).to_ascii_uppercase();
1846
1847    let (args, is_count_star, distinct) = match &func.args {
1848        sp::FunctionArguments::List(list) => {
1849            let distinct = matches!(
1850                list.duplicate_treatment,
1851                Some(sp::DuplicateTreatment::Distinct)
1852            );
1853            if list.args.is_empty() && name == "COUNT" {
1854                (vec![], true, distinct)
1855            } else {
1856                let mut count_star = false;
1857                let args = list
1858                    .args
1859                    .iter()
1860                    .map(|arg| match arg {
1861                        sp::FunctionArg::Unnamed(sp::FunctionArgExpr::Expr(e)) => convert_expr(e),
1862                        sp::FunctionArg::Unnamed(sp::FunctionArgExpr::Wildcard) => {
1863                            if name == "COUNT" {
1864                                count_star = true;
1865                                Ok(Expr::CountStar)
1866                            } else {
1867                                Err(SqlError::Unsupported(format!("{name}(*)")))
1868                            }
1869                        }
1870                        _ => Err(SqlError::Unsupported(format!(
1871                            "function arg type in {name}"
1872                        ))),
1873                    })
1874                    .collect::<Result<Vec<_>>>()?;
1875                if name == "COUNT" && args.len() == 1 && count_star {
1876                    (vec![], true, distinct)
1877                } else {
1878                    (args, false, distinct)
1879                }
1880            }
1881        }
1882        sp::FunctionArguments::None => {
1883            if name == "COUNT" {
1884                (vec![], true, false)
1885            } else {
1886                (vec![], false, false)
1887            }
1888        }
1889        sp::FunctionArguments::Subquery(_) => {
1890            return Err(SqlError::Unsupported("subquery in function".into()));
1891        }
1892    };
1893
1894    // Window function: check OVER before any other special handling
1895    if let Some(over) = &func.over {
1896        let spec = match over {
1897            sp::WindowType::WindowSpec(ws) => convert_window_spec(ws)?,
1898            sp::WindowType::NamedWindow(_) => {
1899                return Err(SqlError::Unsupported("named windows".into()));
1900            }
1901        };
1902        return Ok(Expr::WindowFunction { name, args, spec });
1903    }
1904
1905    // Non-window special forms
1906    if is_count_star {
1907        return Ok(Expr::CountStar);
1908    }
1909
1910    if name == "COALESCE" {
1911        if args.is_empty() {
1912            return Err(SqlError::Parse(
1913                "COALESCE requires at least one argument".into(),
1914            ));
1915        }
1916        return Ok(Expr::Coalesce(args));
1917    }
1918
1919    if name == "NULLIF" {
1920        if args.len() != 2 {
1921            return Err(SqlError::Parse(
1922                "NULLIF requires exactly two arguments".into(),
1923            ));
1924        }
1925        return Ok(Expr::Case {
1926            operand: None,
1927            conditions: vec![(
1928                Expr::BinaryOp {
1929                    left: Box::new(args[0].clone()),
1930                    op: BinOp::Eq,
1931                    right: Box::new(args[1].clone()),
1932                },
1933                Expr::Literal(Value::Null),
1934            )],
1935            else_result: Some(Box::new(args[0].clone())),
1936        });
1937    }
1938
1939    if name == "IIF" {
1940        if args.len() != 3 {
1941            return Err(SqlError::Parse(
1942                "IIF requires exactly three arguments".into(),
1943            ));
1944        }
1945        return Ok(Expr::Case {
1946            operand: None,
1947            conditions: vec![(args[0].clone(), args[1].clone())],
1948            else_result: Some(Box::new(args[2].clone())),
1949        });
1950    }
1951
1952    Ok(Expr::Function {
1953        name,
1954        args,
1955        distinct,
1956    })
1957}
1958
1959fn convert_window_spec(ws: &sp::WindowSpec) -> Result<WindowSpec> {
1960    let partition_by = ws
1961        .partition_by
1962        .iter()
1963        .map(convert_expr)
1964        .collect::<Result<Vec<_>>>()?;
1965    let order_by = ws
1966        .order_by
1967        .iter()
1968        .map(convert_order_by_expr)
1969        .collect::<Result<Vec<_>>>()?;
1970    let frame = ws
1971        .window_frame
1972        .as_ref()
1973        .map(convert_window_frame)
1974        .transpose()?;
1975    Ok(WindowSpec {
1976        partition_by,
1977        order_by,
1978        frame,
1979    })
1980}
1981
1982fn convert_window_frame(wf: &sp::WindowFrame) -> Result<WindowFrame> {
1983    let units = match wf.units {
1984        sp::WindowFrameUnits::Rows => WindowFrameUnits::Rows,
1985        sp::WindowFrameUnits::Range => WindowFrameUnits::Range,
1986        sp::WindowFrameUnits::Groups => {
1987            return Err(SqlError::Unsupported("GROUPS window frame".into()));
1988        }
1989    };
1990    let start = convert_window_frame_bound(&wf.start_bound)?;
1991    let end = match &wf.end_bound {
1992        Some(b) => convert_window_frame_bound(b)?,
1993        None => WindowFrameBound::CurrentRow,
1994    };
1995    Ok(WindowFrame { units, start, end })
1996}
1997
1998fn convert_window_frame_bound(b: &sp::WindowFrameBound) -> Result<WindowFrameBound> {
1999    match b {
2000        sp::WindowFrameBound::CurrentRow => Ok(WindowFrameBound::CurrentRow),
2001        sp::WindowFrameBound::Preceding(None) => Ok(WindowFrameBound::UnboundedPreceding),
2002        sp::WindowFrameBound::Preceding(Some(e)) => {
2003            Ok(WindowFrameBound::Preceding(Box::new(convert_expr(e)?)))
2004        }
2005        sp::WindowFrameBound::Following(None) => Ok(WindowFrameBound::UnboundedFollowing),
2006        sp::WindowFrameBound::Following(Some(e)) => {
2007            Ok(WindowFrameBound::Following(Box::new(convert_expr(e)?)))
2008        }
2009    }
2010}
2011
2012fn convert_select_item(item: &sp::SelectItem) -> Result<SelectColumn> {
2013    match item {
2014        sp::SelectItem::Wildcard(_) => Ok(SelectColumn::AllColumns),
2015        sp::SelectItem::UnnamedExpr(e) => {
2016            let expr = convert_expr(e)?;
2017            Ok(SelectColumn::Expr { expr, alias: None })
2018        }
2019        sp::SelectItem::ExprWithAlias { expr, alias } => {
2020            let expr = convert_expr(expr)?;
2021            Ok(SelectColumn::Expr {
2022                expr,
2023                alias: Some(alias.value.clone()),
2024            })
2025        }
2026        sp::SelectItem::QualifiedWildcard(_, _) => {
2027            Err(SqlError::Unsupported("qualified wildcard (table.*)".into()))
2028        }
2029    }
2030}
2031
2032fn convert_order_by_expr(expr: &sp::OrderByExpr) -> Result<OrderByItem> {
2033    let e = convert_expr(&expr.expr)?;
2034    let descending = expr.options.asc.map(|asc| !asc).unwrap_or(false);
2035    let nulls_first = expr.options.nulls_first;
2036
2037    Ok(OrderByItem {
2038        expr: e,
2039        descending,
2040        nulls_first,
2041    })
2042}
2043
2044fn convert_data_type(dt: &sp::DataType) -> Result<DataType> {
2045    match dt {
2046        sp::DataType::Int(_)
2047        | sp::DataType::Integer(_)
2048        | sp::DataType::BigInt(_)
2049        | sp::DataType::SmallInt(_)
2050        | sp::DataType::TinyInt(_)
2051        | sp::DataType::Int2(_)
2052        | sp::DataType::Int4(_)
2053        | sp::DataType::Int8(_) => Ok(DataType::Integer),
2054
2055        sp::DataType::Real
2056        | sp::DataType::Double(..)
2057        | sp::DataType::DoublePrecision
2058        | sp::DataType::Float(_)
2059        | sp::DataType::Float4
2060        | sp::DataType::Float64 => Ok(DataType::Real),
2061
2062        sp::DataType::Varchar(_)
2063        | sp::DataType::Text
2064        | sp::DataType::Char(_)
2065        | sp::DataType::Character(_)
2066        | sp::DataType::String(_) => Ok(DataType::Text),
2067
2068        sp::DataType::Blob(_) | sp::DataType::Bytea => Ok(DataType::Blob),
2069
2070        sp::DataType::Boolean | sp::DataType::Bool => Ok(DataType::Boolean),
2071
2072        sp::DataType::Date => Ok(DataType::Date),
2073        sp::DataType::Time(_, _) => Ok(DataType::Time),
2074        sp::DataType::Timestamp(_, _) => Ok(DataType::Timestamp),
2075        sp::DataType::Interval { .. } => Ok(DataType::Interval),
2076
2077        _ => Err(SqlError::Unsupported(format!("data type: {dt}"))),
2078    }
2079}
2080
2081fn object_name_to_string(name: &sp::ObjectName) -> String {
2082    name.0
2083        .iter()
2084        .filter_map(|p| match p {
2085            sp::ObjectNamePart::Identifier(ident) => Some(ident.value.clone()),
2086            _ => None,
2087        })
2088        .collect::<Vec<_>>()
2089        .join(".")
2090}
2091
2092#[cfg(test)]
2093mod tests {
2094    use super::*;
2095
2096    #[test]
2097    fn parse_create_table() {
2098        let stmt = parse_sql(
2099            "CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT NOT NULL, age INTEGER)",
2100        )
2101        .unwrap();
2102
2103        match stmt {
2104            Statement::CreateTable(ct) => {
2105                assert_eq!(ct.name, "users");
2106                assert_eq!(ct.columns.len(), 3);
2107                assert_eq!(ct.columns[0].name, "id");
2108                assert_eq!(ct.columns[0].data_type, DataType::Integer);
2109                assert!(ct.columns[0].is_primary_key);
2110                assert!(!ct.columns[0].nullable);
2111                assert_eq!(ct.columns[1].name, "name");
2112                assert_eq!(ct.columns[1].data_type, DataType::Text);
2113                assert!(!ct.columns[1].nullable);
2114                assert_eq!(ct.columns[2].name, "age");
2115                assert!(ct.columns[2].nullable);
2116                assert_eq!(ct.primary_key, vec!["id"]);
2117            }
2118            _ => panic!("expected CreateTable"),
2119        }
2120    }
2121
2122    #[test]
2123    fn parse_create_table_if_not_exists() {
2124        let stmt = parse_sql("CREATE TABLE IF NOT EXISTS t (id INT PRIMARY KEY)").unwrap();
2125        match stmt {
2126            Statement::CreateTable(ct) => assert!(ct.if_not_exists),
2127            _ => panic!("expected CreateTable"),
2128        }
2129    }
2130
2131    #[test]
2132    fn parse_drop_table() {
2133        let stmt = parse_sql("DROP TABLE users").unwrap();
2134        match stmt {
2135            Statement::DropTable(dt) => {
2136                assert_eq!(dt.name, "users");
2137                assert!(!dt.if_exists);
2138            }
2139            _ => panic!("expected DropTable"),
2140        }
2141    }
2142
2143    #[test]
2144    fn parse_drop_table_if_exists() {
2145        let stmt = parse_sql("DROP TABLE IF EXISTS users").unwrap();
2146        match stmt {
2147            Statement::DropTable(dt) => assert!(dt.if_exists),
2148            _ => panic!("expected DropTable"),
2149        }
2150    }
2151
2152    #[test]
2153    fn parse_insert() {
2154        let stmt =
2155            parse_sql("INSERT INTO users (id, name) VALUES (1, 'Alice'), (2, 'Bob')").unwrap();
2156
2157        match stmt {
2158            Statement::Insert(ins) => {
2159                assert_eq!(ins.table, "users");
2160                assert_eq!(ins.columns, vec!["id", "name"]);
2161                let values = match &ins.source {
2162                    InsertSource::Values(v) => v,
2163                    _ => panic!("expected Values"),
2164                };
2165                assert_eq!(values.len(), 2);
2166                assert!(matches!(values[0][0], Expr::Literal(Value::Integer(1))));
2167                assert!(matches!(&values[0][1], Expr::Literal(Value::Text(s)) if s == "Alice"));
2168            }
2169            _ => panic!("expected Insert"),
2170        }
2171    }
2172
2173    #[test]
2174    fn parse_select_all() {
2175        let stmt = parse_sql("SELECT * FROM users").unwrap();
2176        match stmt {
2177            Statement::Select(sq) => match sq.body {
2178                QueryBody::Select(sel) => {
2179                    assert_eq!(sel.from, "users");
2180                    assert!(matches!(sel.columns[0], SelectColumn::AllColumns));
2181                    assert!(sel.where_clause.is_none());
2182                }
2183                _ => panic!("expected QueryBody::Select"),
2184            },
2185            _ => panic!("expected Select"),
2186        }
2187    }
2188
2189    #[test]
2190    fn parse_select_where() {
2191        let stmt = parse_sql("SELECT id, name FROM users WHERE age > 18").unwrap();
2192        match stmt {
2193            Statement::Select(sq) => match sq.body {
2194                QueryBody::Select(sel) => {
2195                    assert_eq!(sel.columns.len(), 2);
2196                    assert!(sel.where_clause.is_some());
2197                }
2198                _ => panic!("expected QueryBody::Select"),
2199            },
2200            _ => panic!("expected Select"),
2201        }
2202    }
2203
2204    #[test]
2205    fn parse_select_order_limit() {
2206        let stmt = parse_sql("SELECT * FROM users ORDER BY name ASC LIMIT 10 OFFSET 5").unwrap();
2207        match stmt {
2208            Statement::Select(sq) => match sq.body {
2209                QueryBody::Select(sel) => {
2210                    assert_eq!(sel.order_by.len(), 1);
2211                    assert!(!sel.order_by[0].descending);
2212                    assert!(sel.limit.is_some());
2213                    assert!(sel.offset.is_some());
2214                }
2215                _ => panic!("expected QueryBody::Select"),
2216            },
2217            _ => panic!("expected Select"),
2218        }
2219    }
2220
2221    #[test]
2222    fn parse_update() {
2223        let stmt = parse_sql("UPDATE users SET name = 'Bob' WHERE id = 1").unwrap();
2224        match stmt {
2225            Statement::Update(upd) => {
2226                assert_eq!(upd.table, "users");
2227                assert_eq!(upd.assignments.len(), 1);
2228                assert_eq!(upd.assignments[0].0, "name");
2229                assert!(upd.where_clause.is_some());
2230            }
2231            _ => panic!("expected Update"),
2232        }
2233    }
2234
2235    #[test]
2236    fn parse_delete() {
2237        let stmt = parse_sql("DELETE FROM users WHERE id = 1").unwrap();
2238        match stmt {
2239            Statement::Delete(del) => {
2240                assert_eq!(del.table, "users");
2241                assert!(del.where_clause.is_some());
2242            }
2243            _ => panic!("expected Delete"),
2244        }
2245    }
2246
2247    #[test]
2248    fn parse_aggregate() {
2249        let stmt = parse_sql("SELECT COUNT(*), SUM(age) FROM users").unwrap();
2250        match stmt {
2251            Statement::Select(sq) => match sq.body {
2252                QueryBody::Select(sel) => {
2253                    assert_eq!(sel.columns.len(), 2);
2254                    match &sel.columns[0] {
2255                        SelectColumn::Expr {
2256                            expr: Expr::CountStar,
2257                            ..
2258                        } => {}
2259                        other => panic!("expected CountStar, got {other:?}"),
2260                    }
2261                }
2262                _ => panic!("expected QueryBody::Select"),
2263            },
2264            _ => panic!("expected Select"),
2265        }
2266    }
2267
2268    #[test]
2269    fn parse_group_by_having() {
2270        let stmt = parse_sql(
2271            "SELECT department, COUNT(*) FROM employees GROUP BY department HAVING COUNT(*) > 5",
2272        )
2273        .unwrap();
2274        match stmt {
2275            Statement::Select(sq) => match sq.body {
2276                QueryBody::Select(sel) => {
2277                    assert_eq!(sel.group_by.len(), 1);
2278                    assert!(sel.having.is_some());
2279                }
2280                _ => panic!("expected QueryBody::Select"),
2281            },
2282            _ => panic!("expected Select"),
2283        }
2284    }
2285
2286    #[test]
2287    fn parse_expressions() {
2288        let stmt = parse_sql("SELECT id + 1, -price, NOT active FROM items").unwrap();
2289        match stmt {
2290            Statement::Select(sq) => match sq.body {
2291                QueryBody::Select(sel) => {
2292                    assert_eq!(sel.columns.len(), 3);
2293                    // id + 1
2294                    match &sel.columns[0] {
2295                        SelectColumn::Expr {
2296                            expr: Expr::BinaryOp { op: BinOp::Add, .. },
2297                            ..
2298                        } => {}
2299                        other => panic!("expected BinaryOp Add, got {other:?}"),
2300                    }
2301                    // -price
2302                    match &sel.columns[1] {
2303                        SelectColumn::Expr {
2304                            expr:
2305                                Expr::UnaryOp {
2306                                    op: UnaryOp::Neg, ..
2307                                },
2308                            ..
2309                        } => {}
2310                        other => panic!("expected UnaryOp Neg, got {other:?}"),
2311                    }
2312                    // NOT active
2313                    match &sel.columns[2] {
2314                        SelectColumn::Expr {
2315                            expr:
2316                                Expr::UnaryOp {
2317                                    op: UnaryOp::Not, ..
2318                                },
2319                            ..
2320                        } => {}
2321                        other => panic!("expected UnaryOp Not, got {other:?}"),
2322                    }
2323                }
2324                _ => panic!("expected QueryBody::Select"),
2325            },
2326            _ => panic!("expected Select"),
2327        }
2328    }
2329
2330    #[test]
2331    fn parse_is_null() {
2332        let stmt = parse_sql("SELECT * FROM t WHERE x IS NULL").unwrap();
2333        match stmt {
2334            Statement::Select(sq) => match sq.body {
2335                QueryBody::Select(sel) => {
2336                    assert!(matches!(sel.where_clause, Some(Expr::IsNull(_))));
2337                }
2338                _ => panic!("expected QueryBody::Select"),
2339            },
2340            _ => panic!("expected Select"),
2341        }
2342    }
2343
2344    #[test]
2345    fn parse_inner_join() {
2346        let stmt = parse_sql("SELECT * FROM a JOIN b ON a.id = b.id").unwrap();
2347        match stmt {
2348            Statement::Select(sq) => match sq.body {
2349                QueryBody::Select(sel) => {
2350                    assert_eq!(sel.from, "a");
2351                    assert_eq!(sel.joins.len(), 1);
2352                    assert_eq!(sel.joins[0].join_type, JoinType::Inner);
2353                    assert_eq!(sel.joins[0].table.name, "b");
2354                    assert!(sel.joins[0].on_clause.is_some());
2355                }
2356                _ => panic!("expected QueryBody::Select"),
2357            },
2358            _ => panic!("expected Select"),
2359        }
2360    }
2361
2362    #[test]
2363    fn parse_inner_join_explicit() {
2364        let stmt = parse_sql("SELECT * FROM a INNER JOIN b ON a.id = b.a_id").unwrap();
2365        match stmt {
2366            Statement::Select(sq) => match sq.body {
2367                QueryBody::Select(sel) => {
2368                    assert_eq!(sel.joins.len(), 1);
2369                    assert_eq!(sel.joins[0].join_type, JoinType::Inner);
2370                }
2371                _ => panic!("expected QueryBody::Select"),
2372            },
2373            _ => panic!("expected Select"),
2374        }
2375    }
2376
2377    #[test]
2378    fn parse_cross_join() {
2379        let stmt = parse_sql("SELECT * FROM a CROSS JOIN b").unwrap();
2380        match stmt {
2381            Statement::Select(sq) => match sq.body {
2382                QueryBody::Select(sel) => {
2383                    assert_eq!(sel.joins.len(), 1);
2384                    assert_eq!(sel.joins[0].join_type, JoinType::Cross);
2385                    assert!(sel.joins[0].on_clause.is_none());
2386                }
2387                _ => panic!("expected QueryBody::Select"),
2388            },
2389            _ => panic!("expected Select"),
2390        }
2391    }
2392
2393    #[test]
2394    fn parse_left_join() {
2395        let stmt = parse_sql("SELECT * FROM a LEFT JOIN b ON a.id = b.a_id").unwrap();
2396        match stmt {
2397            Statement::Select(sq) => match sq.body {
2398                QueryBody::Select(sel) => {
2399                    assert_eq!(sel.joins.len(), 1);
2400                    assert_eq!(sel.joins[0].join_type, JoinType::Left);
2401                }
2402                _ => panic!("expected QueryBody::Select"),
2403            },
2404            _ => panic!("expected Select"),
2405        }
2406    }
2407
2408    #[test]
2409    fn parse_table_alias() {
2410        let stmt = parse_sql("SELECT u.id FROM users u JOIN orders o ON u.id = o.user_id").unwrap();
2411        match stmt {
2412            Statement::Select(sq) => match sq.body {
2413                QueryBody::Select(sel) => {
2414                    assert_eq!(sel.from, "users");
2415                    assert_eq!(sel.from_alias.as_deref(), Some("u"));
2416                    assert_eq!(sel.joins[0].table.name, "orders");
2417                    assert_eq!(sel.joins[0].table.alias.as_deref(), Some("o"));
2418                }
2419                _ => panic!("expected QueryBody::Select"),
2420            },
2421            _ => panic!("expected Select"),
2422        }
2423    }
2424
2425    #[test]
2426    fn parse_multi_join() {
2427        let stmt =
2428            parse_sql("SELECT * FROM a JOIN b ON a.id = b.a_id JOIN c ON b.id = c.b_id").unwrap();
2429        match stmt {
2430            Statement::Select(sq) => match sq.body {
2431                QueryBody::Select(sel) => {
2432                    assert_eq!(sel.joins.len(), 2);
2433                }
2434                _ => panic!("expected QueryBody::Select"),
2435            },
2436            _ => panic!("expected Select"),
2437        }
2438    }
2439
2440    #[test]
2441    fn parse_qualified_column() {
2442        let stmt = parse_sql("SELECT u.id, u.name FROM users u").unwrap();
2443        match stmt {
2444            Statement::Select(sq) => match sq.body {
2445                QueryBody::Select(sel) => match &sel.columns[0] {
2446                    SelectColumn::Expr {
2447                        expr: Expr::QualifiedColumn { table, column },
2448                        ..
2449                    } => {
2450                        assert_eq!(table, "u");
2451                        assert_eq!(column, "id");
2452                    }
2453                    other => panic!("expected QualifiedColumn, got {other:?}"),
2454                },
2455                _ => panic!("expected QueryBody::Select"),
2456            },
2457            _ => panic!("expected Select"),
2458        }
2459    }
2460
2461    #[test]
2462    fn reject_subquery() {
2463        assert!(parse_sql("SELECT * FROM (SELECT 1)").is_err());
2464    }
2465
2466    #[test]
2467    fn parse_type_mapping() {
2468        let stmt = parse_sql(
2469            "CREATE TABLE t (a INT PRIMARY KEY, b BIGINT, c SMALLINT, d REAL, e DOUBLE PRECISION, f VARCHAR(255), g BOOLEAN, h BLOB, i BYTEA)"
2470        ).unwrap();
2471        match stmt {
2472            Statement::CreateTable(ct) => {
2473                assert_eq!(ct.columns[0].data_type, DataType::Integer); // INT
2474                assert_eq!(ct.columns[1].data_type, DataType::Integer); // BIGINT
2475                assert_eq!(ct.columns[2].data_type, DataType::Integer); // SMALLINT
2476                assert_eq!(ct.columns[3].data_type, DataType::Real); // REAL
2477                assert_eq!(ct.columns[4].data_type, DataType::Real); // DOUBLE
2478                assert_eq!(ct.columns[5].data_type, DataType::Text); // VARCHAR
2479                assert_eq!(ct.columns[6].data_type, DataType::Boolean); // BOOLEAN
2480                assert_eq!(ct.columns[7].data_type, DataType::Blob); // BLOB
2481                assert_eq!(ct.columns[8].data_type, DataType::Blob); // BYTEA
2482            }
2483            _ => panic!("expected CreateTable"),
2484        }
2485    }
2486
2487    #[test]
2488    fn parse_boolean_literals() {
2489        let stmt = parse_sql("INSERT INTO t (a, b) VALUES (true, false)").unwrap();
2490        match stmt {
2491            Statement::Insert(ins) => {
2492                let values = match &ins.source {
2493                    InsertSource::Values(v) => v,
2494                    _ => panic!("expected Values"),
2495                };
2496                assert!(matches!(values[0][0], Expr::Literal(Value::Boolean(true))));
2497                assert!(matches!(values[0][1], Expr::Literal(Value::Boolean(false))));
2498            }
2499            _ => panic!("expected Insert"),
2500        }
2501    }
2502
2503    #[test]
2504    fn parse_null_literal() {
2505        let stmt = parse_sql("INSERT INTO t (a) VALUES (NULL)").unwrap();
2506        match stmt {
2507            Statement::Insert(ins) => {
2508                let values = match &ins.source {
2509                    InsertSource::Values(v) => v,
2510                    _ => panic!("expected Values"),
2511                };
2512                assert!(matches!(values[0][0], Expr::Literal(Value::Null)));
2513            }
2514            _ => panic!("expected Insert"),
2515        }
2516    }
2517
2518    #[test]
2519    fn parse_alias() {
2520        let stmt = parse_sql("SELECT id AS user_id FROM users").unwrap();
2521        match stmt {
2522            Statement::Select(sq) => match sq.body {
2523                QueryBody::Select(sel) => match &sel.columns[0] {
2524                    SelectColumn::Expr { alias: Some(a), .. } => assert_eq!(a, "user_id"),
2525                    other => panic!("expected alias, got {other:?}"),
2526                },
2527                _ => panic!("expected QueryBody::Select"),
2528            },
2529            _ => panic!("expected Select"),
2530        }
2531    }
2532
2533    #[test]
2534    fn parse_begin() {
2535        let stmt = parse_sql("BEGIN").unwrap();
2536        assert!(matches!(stmt, Statement::Begin));
2537    }
2538
2539    #[test]
2540    fn parse_begin_transaction() {
2541        let stmt = parse_sql("BEGIN TRANSACTION").unwrap();
2542        assert!(matches!(stmt, Statement::Begin));
2543    }
2544
2545    #[test]
2546    fn parse_commit() {
2547        let stmt = parse_sql("COMMIT").unwrap();
2548        assert!(matches!(stmt, Statement::Commit));
2549    }
2550
2551    #[test]
2552    fn parse_rollback() {
2553        let stmt = parse_sql("ROLLBACK").unwrap();
2554        assert!(matches!(stmt, Statement::Rollback));
2555    }
2556
2557    #[test]
2558    fn parse_savepoint() {
2559        let stmt = parse_sql("SAVEPOINT sp1").unwrap();
2560        match stmt {
2561            Statement::Savepoint(name) => assert_eq!(name, "sp1"),
2562            other => panic!("expected Savepoint, got {other:?}"),
2563        }
2564    }
2565
2566    #[test]
2567    fn parse_savepoint_case_insensitive() {
2568        let stmt = parse_sql("SAVEPOINT My_SP").unwrap();
2569        match stmt {
2570            Statement::Savepoint(name) => assert_eq!(name, "my_sp"),
2571            other => panic!("expected Savepoint, got {other:?}"),
2572        }
2573    }
2574
2575    #[test]
2576    fn parse_release_savepoint() {
2577        let stmt = parse_sql("RELEASE SAVEPOINT sp1").unwrap();
2578        match stmt {
2579            Statement::ReleaseSavepoint(name) => assert_eq!(name, "sp1"),
2580            other => panic!("expected ReleaseSavepoint, got {other:?}"),
2581        }
2582    }
2583
2584    #[test]
2585    fn parse_release_without_savepoint_keyword() {
2586        let stmt = parse_sql("RELEASE sp1").unwrap();
2587        match stmt {
2588            Statement::ReleaseSavepoint(name) => assert_eq!(name, "sp1"),
2589            other => panic!("expected ReleaseSavepoint, got {other:?}"),
2590        }
2591    }
2592
2593    #[test]
2594    fn parse_rollback_to_savepoint() {
2595        let stmt = parse_sql("ROLLBACK TO SAVEPOINT sp1").unwrap();
2596        match stmt {
2597            Statement::RollbackTo(name) => assert_eq!(name, "sp1"),
2598            other => panic!("expected RollbackTo, got {other:?}"),
2599        }
2600    }
2601
2602    #[test]
2603    fn parse_rollback_to_without_savepoint_keyword() {
2604        let stmt = parse_sql("ROLLBACK TO sp1").unwrap();
2605        match stmt {
2606            Statement::RollbackTo(name) => assert_eq!(name, "sp1"),
2607            other => panic!("expected RollbackTo, got {other:?}"),
2608        }
2609    }
2610
2611    #[test]
2612    fn parse_rollback_to_case_insensitive() {
2613        let stmt = parse_sql("ROLLBACK TO My_SP").unwrap();
2614        match stmt {
2615            Statement::RollbackTo(name) => assert_eq!(name, "my_sp"),
2616            other => panic!("expected RollbackTo, got {other:?}"),
2617        }
2618    }
2619
2620    #[test]
2621    fn parse_commit_and_chain_rejected() {
2622        let err = parse_sql("COMMIT AND CHAIN").unwrap_err();
2623        assert!(matches!(err, SqlError::Unsupported(_)));
2624    }
2625
2626    #[test]
2627    fn parse_rollback_and_chain_rejected() {
2628        let err = parse_sql("ROLLBACK AND CHAIN").unwrap_err();
2629        assert!(matches!(err, SqlError::Unsupported(_)));
2630    }
2631
2632    #[test]
2633    fn parse_select_distinct() {
2634        let stmt = parse_sql("SELECT DISTINCT name FROM users").unwrap();
2635        match stmt {
2636            Statement::Select(sq) => match sq.body {
2637                QueryBody::Select(sel) => {
2638                    assert!(sel.distinct);
2639                    assert_eq!(sel.columns.len(), 1);
2640                }
2641                _ => panic!("expected QueryBody::Select"),
2642            },
2643            _ => panic!("expected Select"),
2644        }
2645    }
2646
2647    #[test]
2648    fn parse_select_without_distinct() {
2649        let stmt = parse_sql("SELECT name FROM users").unwrap();
2650        match stmt {
2651            Statement::Select(sq) => match sq.body {
2652                QueryBody::Select(sel) => {
2653                    assert!(!sel.distinct);
2654                }
2655                _ => panic!("expected QueryBody::Select"),
2656            },
2657            _ => panic!("expected Select"),
2658        }
2659    }
2660
2661    #[test]
2662    fn parse_select_distinct_all_columns() {
2663        let stmt = parse_sql("SELECT DISTINCT * FROM users").unwrap();
2664        match stmt {
2665            Statement::Select(sq) => match sq.body {
2666                QueryBody::Select(sel) => {
2667                    assert!(sel.distinct);
2668                    assert!(matches!(sel.columns[0], SelectColumn::AllColumns));
2669                }
2670                _ => panic!("expected QueryBody::Select"),
2671            },
2672            _ => panic!("expected Select"),
2673        }
2674    }
2675
2676    #[test]
2677    fn reject_distinct_on() {
2678        assert!(parse_sql("SELECT DISTINCT ON (id) * FROM users").is_err());
2679    }
2680
2681    #[test]
2682    fn parse_create_index() {
2683        let stmt = parse_sql("CREATE INDEX idx_name ON users (name)").unwrap();
2684        match stmt {
2685            Statement::CreateIndex(ci) => {
2686                assert_eq!(ci.index_name, "idx_name");
2687                assert_eq!(ci.table_name, "users");
2688                assert_eq!(ci.columns, vec!["name"]);
2689                assert!(!ci.unique);
2690                assert!(!ci.if_not_exists);
2691            }
2692            _ => panic!("expected CreateIndex"),
2693        }
2694    }
2695
2696    #[test]
2697    fn parse_create_unique_index() {
2698        let stmt = parse_sql("CREATE UNIQUE INDEX idx_email ON users (email)").unwrap();
2699        match stmt {
2700            Statement::CreateIndex(ci) => {
2701                assert!(ci.unique);
2702                assert_eq!(ci.columns, vec!["email"]);
2703            }
2704            _ => panic!("expected CreateIndex"),
2705        }
2706    }
2707
2708    #[test]
2709    fn parse_create_index_if_not_exists() {
2710        let stmt = parse_sql("CREATE INDEX IF NOT EXISTS idx_x ON t (a)").unwrap();
2711        match stmt {
2712            Statement::CreateIndex(ci) => assert!(ci.if_not_exists),
2713            _ => panic!("expected CreateIndex"),
2714        }
2715    }
2716
2717    #[test]
2718    fn parse_create_index_multi_column() {
2719        let stmt = parse_sql("CREATE INDEX idx_multi ON t (a, b, c)").unwrap();
2720        match stmt {
2721            Statement::CreateIndex(ci) => {
2722                assert_eq!(ci.columns, vec!["a", "b", "c"]);
2723            }
2724            _ => panic!("expected CreateIndex"),
2725        }
2726    }
2727
2728    #[test]
2729    fn parse_drop_index() {
2730        let stmt = parse_sql("DROP INDEX idx_name").unwrap();
2731        match stmt {
2732            Statement::DropIndex(di) => {
2733                assert_eq!(di.index_name, "idx_name");
2734                assert!(!di.if_exists);
2735            }
2736            _ => panic!("expected DropIndex"),
2737        }
2738    }
2739
2740    #[test]
2741    fn parse_drop_index_if_exists() {
2742        let stmt = parse_sql("DROP INDEX IF EXISTS idx_name").unwrap();
2743        match stmt {
2744            Statement::DropIndex(di) => {
2745                assert!(di.if_exists);
2746            }
2747            _ => panic!("expected DropIndex"),
2748        }
2749    }
2750
2751    #[test]
2752    fn parse_explain_select() {
2753        let stmt = parse_sql("EXPLAIN SELECT * FROM users WHERE id = 1").unwrap();
2754        match stmt {
2755            Statement::Explain(inner) => {
2756                assert!(matches!(*inner, Statement::Select(_)));
2757            }
2758            _ => panic!("expected Explain"),
2759        }
2760    }
2761
2762    #[test]
2763    fn parse_explain_insert() {
2764        let stmt = parse_sql("EXPLAIN INSERT INTO t (a) VALUES (1)").unwrap();
2765        assert!(matches!(stmt, Statement::Explain(_)));
2766    }
2767
2768    #[test]
2769    fn reject_explain_analyze() {
2770        assert!(parse_sql("EXPLAIN ANALYZE SELECT * FROM t").is_err());
2771    }
2772
2773    #[test]
2774    fn parse_parameter_placeholder() {
2775        let stmt = parse_sql("SELECT * FROM t WHERE id = $1").unwrap();
2776        match stmt {
2777            Statement::Select(sq) => match sq.body {
2778                QueryBody::Select(sel) => match &sel.where_clause {
2779                    Some(Expr::BinaryOp { right, .. }) => {
2780                        assert!(matches!(right.as_ref(), Expr::Parameter(1)));
2781                    }
2782                    other => panic!("expected BinaryOp with Parameter, got {other:?}"),
2783                },
2784                _ => panic!("expected QueryBody::Select"),
2785            },
2786            _ => panic!("expected Select"),
2787        }
2788    }
2789
2790    #[test]
2791    fn parse_multiple_parameters() {
2792        let stmt = parse_sql("INSERT INTO t (a, b) VALUES ($1, $2)").unwrap();
2793        match stmt {
2794            Statement::Insert(ins) => {
2795                let values = match &ins.source {
2796                    InsertSource::Values(v) => v,
2797                    _ => panic!("expected Values"),
2798                };
2799                assert!(matches!(values[0][0], Expr::Parameter(1)));
2800                assert!(matches!(values[0][1], Expr::Parameter(2)));
2801            }
2802            _ => panic!("expected Insert"),
2803        }
2804    }
2805
2806    #[test]
2807    fn parse_insert_select() {
2808        let stmt =
2809            parse_sql("INSERT INTO t2 (id, name) SELECT id, name FROM t1 WHERE id > 5").unwrap();
2810        match stmt {
2811            Statement::Insert(ins) => {
2812                assert_eq!(ins.table, "t2");
2813                assert_eq!(ins.columns, vec!["id", "name"]);
2814                match &ins.source {
2815                    InsertSource::Select(sq) => match &sq.body {
2816                        QueryBody::Select(sel) => {
2817                            assert_eq!(sel.from, "t1");
2818                            assert_eq!(sel.columns.len(), 2);
2819                            assert!(sel.where_clause.is_some());
2820                        }
2821                        _ => panic!("expected QueryBody::Select"),
2822                    },
2823                    _ => panic!("expected InsertSource::Select"),
2824                }
2825            }
2826            _ => panic!("expected Insert"),
2827        }
2828    }
2829
2830    #[test]
2831    fn parse_insert_select_no_columns() {
2832        let stmt = parse_sql("INSERT INTO t2 SELECT * FROM t1").unwrap();
2833        match stmt {
2834            Statement::Insert(ins) => {
2835                assert_eq!(ins.table, "t2");
2836                assert!(ins.columns.is_empty());
2837                assert!(matches!(&ins.source, InsertSource::Select(_)));
2838            }
2839            _ => panic!("expected Insert"),
2840        }
2841    }
2842
2843    #[test]
2844    fn reject_zero_parameter() {
2845        assert!(parse_sql("SELECT $0 FROM t").is_err());
2846    }
2847
2848    #[test]
2849    fn count_params_basic() {
2850        let stmt = parse_sql("SELECT * FROM t WHERE a = $1 AND b = $3").unwrap();
2851        assert_eq!(count_params(&stmt), 3);
2852    }
2853
2854    #[test]
2855    fn count_params_none() {
2856        let stmt = parse_sql("SELECT * FROM t WHERE a = 1").unwrap();
2857        assert_eq!(count_params(&stmt), 0);
2858    }
2859
2860    #[test]
2861    fn parse_table_constraint_pk() {
2862        let stmt = parse_sql("CREATE TABLE t (a INTEGER, b TEXT, PRIMARY KEY (a))").unwrap();
2863        match stmt {
2864            Statement::CreateTable(ct) => {
2865                assert_eq!(ct.primary_key, vec!["a"]);
2866                assert!(ct.columns[0].is_primary_key);
2867                assert!(!ct.columns[0].nullable);
2868            }
2869            _ => panic!("expected CreateTable"),
2870        }
2871    }
2872}