Skip to main content

citadel_sql/
parser.rs

1//! SQL parser: converts SQL strings into our internal AST.
2
3use sqlparser::ast as sp;
4use sqlparser::dialect::GenericDialect;
5use sqlparser::parser::Parser;
6
7use crate::error::{Result, SqlError};
8use crate::types::{DataType, Value};
9
/// One parsed SQL statement in this crate's internal form.
///
/// Built from sqlparser's AST by `convert_statement`; obtained through [`parse_sql`] or
/// [`parse_sql_multi`].
#[derive(Debug, Clone)]
pub enum Statement {
    /// `CREATE TABLE ...`.
    CreateTable(CreateTableStmt),
    /// `DROP TABLE` naming exactly one table.
    DropTable(DropTableStmt),
    /// `CREATE [UNIQUE] INDEX ...`.
    CreateIndex(CreateIndexStmt),
    /// `DROP INDEX` naming exactly one index.
    DropIndex(DropIndexStmt),
    /// `CREATE [OR REPLACE] VIEW ...`.
    CreateView(CreateViewStmt),
    /// `DROP VIEW` naming exactly one view.
    DropView(DropViewStmt),
    /// `ALTER TABLE` carrying exactly one operation.
    AlterTable(Box<AlterTableStmt>),
    /// `INSERT INTO ...` with a VALUES list or a SELECT source.
    Insert(InsertStmt),
    /// A query: a SELECT or set operation, optionally preceded by CTEs.
    Select(Box<SelectQuery>),
    /// `UPDATE ... SET ... [WHERE ...]`.
    Update(UpdateStmt),
    /// `DELETE FROM ... [WHERE ...]`.
    Delete(DeleteStmt),
    /// `BEGIN` / `START TRANSACTION`.
    Begin,
    /// `COMMIT` (AND CHAIN is rejected).
    Commit,
    /// `ROLLBACK` without a savepoint target (AND CHAIN is rejected).
    Rollback,
    /// `SAVEPOINT name`; the name is ASCII-lowercased.
    Savepoint(String),
    /// `RELEASE SAVEPOINT name`; the name is ASCII-lowercased.
    ReleaseSavepoint(String),
    /// `ROLLBACK TO SAVEPOINT name`; the name is ASCII-lowercased.
    RollbackTo(String),
    /// `SET TIME ZONE ...`; holds the zone as text (literal contents or identifier).
    SetTimezone(String),
    /// `EXPLAIN <statement>` (EXPLAIN ANALYZE is rejected).
    Explain(Box<Statement>),
}
32
/// `ALTER TABLE <table> <op>` with a single operation.
#[derive(Debug, Clone)]
pub struct AlterTableStmt {
    /// Name of the table being altered.
    pub table: String,
    /// The one operation applied to it.
    pub op: AlterTableOp,
}

/// The ALTER TABLE operations this parser accepts.
#[derive(Debug, Clone)]
pub enum AlterTableOp {
    /// `ADD [COLUMN] [IF NOT EXISTS] <column definition>`.
    AddColumn {
        /// The new column, parsed like a CREATE TABLE column.
        column: Box<ColumnSpec>,
        /// An inline `REFERENCES ...` on the new column, if any.
        foreign_key: Option<ForeignKeyDef>,
        if_not_exists: bool,
    },
    /// `DROP [COLUMN] [IF EXISTS] <name>` (single column only).
    DropColumn {
        name: String,
        if_exists: bool,
    },
    /// `RENAME COLUMN <old> TO <new>`.
    RenameColumn {
        old_name: String,
        new_name: String,
    },
    /// `RENAME TO <new_name>` (or `RENAME AS`).
    RenameTable {
        new_name: String,
    },
}
58
/// A parsed `CREATE TABLE` statement.
#[derive(Debug, Clone)]
pub struct CreateTableStmt {
    pub name: String,
    /// Columns in declaration order.
    pub columns: Vec<ColumnSpec>,
    /// Primary-key column names, gathered from inline `PRIMARY KEY` options and from a
    /// table-level `PRIMARY KEY (...)` constraint.
    pub primary_key: Vec<String>,
    pub if_not_exists: bool,
    /// Table-level `CHECK (...)` constraints (column-level checks live on [`ColumnSpec`]).
    pub check_constraints: Vec<TableCheckConstraint>,
    /// Foreign keys from both inline `REFERENCES` options and table-level constraints.
    pub foreign_keys: Vec<ForeignKeyDef>,
    /// UNIQUE constraints from inline `UNIQUE` (on non-PK columns) and table-level
    /// `UNIQUE (...)` constraints.
    pub unique_indices: Vec<UniqueIndexDef>,
}

/// A UNIQUE constraint declared as part of CREATE TABLE.
#[derive(Debug, Clone)]
pub struct UniqueIndexDef {
    /// Constraint name, if one was given with `CONSTRAINT <name>`.
    pub name: Option<String>,
    /// Constrained column names, ASCII-lowercased by the parser.
    pub columns: Vec<String>,
}

/// A table-level CHECK constraint.
#[derive(Debug, Clone)]
pub struct TableCheckConstraint {
    pub name: Option<String>,
    /// Converted expression; the parser rejects subqueries here.
    pub expr: Expr,
    /// Original expression text, re-parseable with [`parse_sql_expr`].
    pub sql: String,
}

/// A foreign-key reference from `columns` to `referred_columns` of `foreign_table`.
///
/// Column and table names are ASCII-lowercased by the parser. Only RESTRICT / NO ACTION
/// referential actions are accepted, so no action is stored here.
#[derive(Debug, Clone)]
pub struct ForeignKeyDef {
    pub name: Option<String>,
    pub columns: Vec<String>,
    pub foreign_table: String,
    pub referred_columns: Vec<String>,
}

/// One column definition from CREATE TABLE or ALTER TABLE ADD COLUMN.
#[derive(Debug, Clone)]
pub struct ColumnSpec {
    /// Column name as written (not case-folded).
    pub name: String,
    pub data_type: DataType,
    pub nullable: bool,
    pub is_primary_key: bool,
    /// Converted `DEFAULT` expression, if any.
    pub default_expr: Option<Expr>,
    /// Source text of the DEFAULT expression, re-parseable with [`parse_sql_expr`].
    pub default_sql: Option<String>,
    /// Converted column-level `CHECK` expression, if any.
    pub check_expr: Option<Expr>,
    /// Source text of the CHECK expression.
    pub check_sql: Option<String>,
    /// Name of the CHECK constraint, if one was given.
    pub check_name: Option<String>,
}
103
/// `DROP TABLE [IF EXISTS] <name>`.
#[derive(Debug, Clone)]
pub struct DropTableStmt {
    pub name: String,
    pub if_exists: bool,
}

/// `CREATE [UNIQUE] INDEX [IF NOT EXISTS] <index> ON <table> (<columns>)`.
#[derive(Debug, Clone)]
pub struct CreateIndexStmt {
    pub index_name: String,
    pub table_name: String,
    /// Plain column names; expression indexes are rejected by the parser.
    pub columns: Vec<String>,
    pub unique: bool,
    pub if_not_exists: bool,
}

/// `DROP INDEX [IF EXISTS] <name>`.
#[derive(Debug, Clone)]
pub struct DropIndexStmt {
    pub index_name: String,
    pub if_exists: bool,
}

/// `CREATE [OR REPLACE] VIEW [IF NOT EXISTS] <name> [(<aliases>)] AS <query>`.
#[derive(Debug, Clone)]
pub struct CreateViewStmt {
    pub name: String,
    /// NOTE(review): presumably the SQL text of the defining query — confirm in
    /// `convert_create_view`.
    pub sql: String,
    /// Optional explicit column names for the view.
    pub column_aliases: Vec<String>,
    pub or_replace: bool,
    pub if_not_exists: bool,
}

/// `DROP VIEW [IF EXISTS] <name>`.
#[derive(Debug, Clone)]
pub struct DropViewStmt {
    pub name: String,
    pub if_exists: bool,
}
139
/// Where the rows of an INSERT come from.
#[derive(Debug, Clone)]
pub enum InsertSource {
    /// `VALUES (..), (..)`: one inner `Vec` per row.
    Values(Vec<Vec<Expr>>),
    /// `INSERT ... SELECT ...`.
    Select(Box<SelectQuery>),
}

/// `INSERT INTO <table> [(<columns>)] <source> [ON CONFLICT ...]`.
#[derive(Debug, Clone)]
pub struct InsertStmt {
    pub table: String,
    /// Explicit target column list (presumably empty when omitted — confirm in `convert_insert`).
    pub columns: Vec<String>,
    pub source: InsertSource,
    pub on_conflict: Option<OnConflictClause>,
}

/// `ON CONFLICT [<target>] DO NOTHING | DO UPDATE ...`.
#[derive(Debug, Clone)]
pub struct OnConflictClause {
    /// Conflict target; `None` when the clause names no target.
    pub target: Option<ConflictTarget>,
    pub action: OnConflictAction,
}

/// What an ON CONFLICT clause matches against.
#[derive(Debug, Clone)]
pub enum ConflictTarget {
    /// `ON CONFLICT (<columns>)`.
    Columns(Vec<String>),
    /// `ON CONFLICT ON CONSTRAINT <name>`.
    Constraint(String),
}

/// What to do when an INSERT hits a conflict.
#[derive(Debug, Clone)]
pub enum OnConflictAction {
    DoNothing,
    /// `DO UPDATE SET col = expr, ... [WHERE ...]`.
    DoUpdate {
        assignments: Vec<(String, Expr)>,
        where_clause: Option<Expr>,
    },
}
174
/// A table reference with an optional alias (`<name> [AS] <alias>`).
#[derive(Debug, Clone)]
pub struct TableRef {
    pub name: String,
    pub alias: Option<String>,
}

/// Supported join kinds.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum JoinType {
    Inner,
    Cross,
    Left,
    Right,
}

/// One `JOIN <table> [ON <expr>]` entry following the FROM table.
#[derive(Debug, Clone)]
pub struct JoinClause {
    pub join_type: JoinType,
    pub table: TableRef,
    /// Join condition; presumably `None` for CROSS joins — confirm in the converter.
    pub on_clause: Option<Expr>,
}

/// A single (non-compound) SELECT.
#[derive(Debug, Clone)]
pub struct SelectStmt {
    /// Projection list.
    pub columns: Vec<SelectColumn>,
    /// Name of the primary FROM table.
    pub from: String,
    pub from_alias: Option<String>,
    /// Joins applied to the FROM table, in source order.
    pub joins: Vec<JoinClause>,
    /// `SELECT DISTINCT`.
    pub distinct: bool,
    pub where_clause: Option<Expr>,
    pub order_by: Vec<OrderByItem>,
    pub limit: Option<Expr>,
    pub offset: Option<Expr>,
    pub group_by: Vec<Expr>,
    pub having: Option<Expr>,
}

/// Set operators for compound queries.
#[derive(Debug, Clone)]
pub enum SetOp {
    Union,
    Intersect,
    Except,
}

/// `<left> UNION|INTERSECT|EXCEPT [ALL] <right>`.
///
/// ORDER BY / LIMIT / OFFSET are stored on the compound node, not on either operand.
#[derive(Debug, Clone)]
pub struct CompoundSelect {
    pub op: SetOp,
    /// `true` for the `ALL` form (duplicates kept).
    pub all: bool,
    pub left: Box<QueryBody>,
    pub right: Box<QueryBody>,
    pub order_by: Vec<OrderByItem>,
    pub limit: Option<Expr>,
    pub offset: Option<Expr>,
}

/// Either a plain SELECT or a set operation over two query bodies.
#[derive(Debug, Clone)]
pub enum QueryBody {
    Select(Box<SelectStmt>),
    Compound(Box<CompoundSelect>),
}

/// One `WITH <name> [(<aliases>)] AS (<body>)` entry.
#[derive(Debug, Clone)]
pub struct CteDefinition {
    pub name: String,
    pub column_aliases: Vec<String>,
    pub body: QueryBody,
}

/// A full query: optional CTEs followed by the main body.
#[derive(Debug, Clone)]
pub struct SelectQuery {
    pub ctes: Vec<CteDefinition>,
    /// `WITH RECURSIVE`.
    pub recursive: bool,
    pub body: QueryBody,
}

/// `UPDATE <table> SET col = expr, ... [WHERE ...]`.
#[derive(Debug, Clone)]
pub struct UpdateStmt {
    pub table: String,
    pub assignments: Vec<(String, Expr)>,
    pub where_clause: Option<Expr>,
}

/// `DELETE FROM <table> [WHERE ...]`.
#[derive(Debug, Clone)]
pub struct DeleteStmt {
    pub table: String,
    pub where_clause: Option<Expr>,
}
261
/// One item of a SELECT projection list.
#[derive(Debug, Clone)]
pub enum SelectColumn {
    /// `*`.
    AllColumns,
    /// `<expr> [AS <alias>]`.
    Expr { expr: Expr, alias: Option<String> },
}

/// One ORDER BY key.
#[derive(Debug, Clone)]
pub struct OrderByItem {
    pub expr: Expr,
    /// `true` for `DESC`.
    pub descending: bool,
    /// `Some(true)` for NULLS FIRST, `Some(false)` for NULLS LAST; presumably `None` when
    /// not specified — confirm in the converter.
    pub nulls_first: Option<bool>,
}
274
/// Scalar / boolean expression tree used throughout the internal AST.
#[derive(Debug, Clone)]
pub enum Expr {
    /// A constant value.
    Literal(Value),
    /// An unqualified column reference.
    Column(String),
    /// A `table.column` reference (`table` may be an alias).
    QualifiedColumn {
        table: String,
        column: String,
    },
    BinaryOp {
        left: Box<Expr>,
        op: BinOp,
        right: Box<Expr>,
    },
    UnaryOp {
        op: UnaryOp,
        expr: Box<Expr>,
    },
    /// `<expr> IS NULL`.
    IsNull(Box<Expr>),
    /// `<expr> IS NOT NULL`.
    IsNotNull(Box<Expr>),
    /// A scalar or aggregate function call.
    Function {
        name: String,
        args: Vec<Expr>,
        /// True for aggregate forms like `COUNT(DISTINCT x)`.
        distinct: bool,
    },
    /// `COUNT(*)`.
    CountStar,
    /// `<expr> [NOT] IN (SELECT ...)`; the subquery is a single SELECT (no CTEs or set ops).
    InSubquery {
        expr: Box<Expr>,
        subquery: Box<SelectStmt>,
        negated: bool,
    },
    /// `<expr> [NOT] IN (<e1>, <e2>, ...)`.
    InList {
        expr: Box<Expr>,
        list: Vec<Expr>,
        negated: bool,
    },
    /// `[NOT] EXISTS (SELECT ...)`.
    Exists {
        subquery: Box<SelectStmt>,
        negated: bool,
    },
    /// `(SELECT ...)` used as a single value.
    ScalarSubquery(Box<SelectStmt>),
    /// IN test against a set of constant values; `has_null` records whether the original
    /// list contained NULL. NOTE(review): presumably derived from an `InList` of constants
    /// by a later pass — confirm against its producer.
    InSet {
        expr: Box<Expr>,
        values: std::collections::HashSet<Value>,
        has_null: bool,
        negated: bool,
    },
    /// `<expr> [NOT] BETWEEN <low> AND <high>`.
    Between {
        expr: Box<Expr>,
        low: Box<Expr>,
        high: Box<Expr>,
        negated: bool,
    },
    /// `<expr> [NOT] LIKE <pattern> [ESCAPE <escape>]`.
    Like {
        expr: Box<Expr>,
        pattern: Box<Expr>,
        escape: Option<Box<Expr>>,
        negated: bool,
    },
    /// `CASE [<operand>] WHEN <cond> THEN <result> ... [ELSE <else_result>] END`.
    Case {
        operand: Option<Box<Expr>>,
        /// `(WHEN, THEN)` pairs in source order.
        conditions: Vec<(Expr, Expr)>,
        else_result: Option<Box<Expr>>,
    },
    /// `COALESCE(<args>)`.
    Coalesce(Vec<Expr>),
    /// `CAST(<expr> AS <data_type>)`.
    Cast {
        expr: Box<Expr>,
        data_type: DataType,
    },
    /// Positional placeholder `$N`; [`count_params`] reports the highest `N` as the
    /// parameter count.
    Parameter(usize),
    /// `<name>(<args>) OVER (<spec>)`.
    WindowFunction {
        name: String,
        args: Vec<Expr>,
        spec: WindowSpec,
    },
}
351
/// The `OVER (...)` clause of a window function.
#[derive(Debug, Clone)]
pub struct WindowSpec {
    pub partition_by: Vec<Expr>,
    pub order_by: Vec<OrderByItem>,
    /// Explicit frame clause, if one was written.
    pub frame: Option<WindowFrame>,
}

/// `{ROWS | RANGE | GROUPS} BETWEEN <start> AND <end>`.
#[derive(Debug, Clone)]
pub struct WindowFrame {
    pub units: WindowFrameUnits,
    pub start: WindowFrameBound,
    pub end: WindowFrameBound,
}

/// Unit in which frame offsets are measured.
#[derive(Debug, Clone, Copy)]
pub enum WindowFrameUnits {
    Rows,
    Range,
    Groups,
}

/// One end of a window frame; only the `Preceding` / `Following` bounds carry an offset.
#[derive(Debug, Clone)]
pub enum WindowFrameBound {
    UnboundedPreceding,
    /// `<offset> PRECEDING`.
    Preceding(Box<Expr>),
    CurrentRow,
    /// `<offset> FOLLOWING`.
    Following(Box<Expr>),
    UnboundedFollowing,
}
381
/// Binary operators: arithmetic, comparison, logical, and string concatenation.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BinOp {
    Add,
    Sub,
    Mul,
    Div,
    Mod,
    Eq,
    NotEq,
    Lt,
    Gt,
    LtEq,
    GtEq,
    And,
    Or,
    /// String concatenation (presumably SQL `||` — confirm in `convert_expr`).
    Concat,
}

/// Unary operators.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum UnaryOp {
    /// Arithmetic negation (`-x`).
    Neg,
    /// Logical `NOT`.
    Not,
}
405
406pub fn has_subquery(expr: &Expr) -> bool {
407    match expr {
408        Expr::InSubquery { .. } | Expr::Exists { .. } | Expr::ScalarSubquery(_) => true,
409        Expr::BinaryOp { left, right, .. } => has_subquery(left) || has_subquery(right),
410        Expr::UnaryOp { expr, .. } => has_subquery(expr),
411        Expr::IsNull(e) | Expr::IsNotNull(e) => has_subquery(e),
412        Expr::InList { expr, list, .. } => has_subquery(expr) || list.iter().any(has_subquery),
413        Expr::InSet { expr, .. } => has_subquery(expr),
414        Expr::Between {
415            expr, low, high, ..
416        } => has_subquery(expr) || has_subquery(low) || has_subquery(high),
417        Expr::Like {
418            expr,
419            pattern,
420            escape,
421            ..
422        } => {
423            has_subquery(expr)
424                || has_subquery(pattern)
425                || escape.as_ref().is_some_and(|e| has_subquery(e))
426        }
427        Expr::Case {
428            operand,
429            conditions,
430            else_result,
431        } => {
432            operand.as_ref().is_some_and(|e| has_subquery(e))
433                || conditions
434                    .iter()
435                    .any(|(c, r)| has_subquery(c) || has_subquery(r))
436                || else_result.as_ref().is_some_and(|e| has_subquery(e))
437        }
438        Expr::Coalesce(args) | Expr::Function { args, .. } => args.iter().any(has_subquery),
439        Expr::Cast { expr, .. } => has_subquery(expr),
440        _ => false,
441    }
442}
443
444/// Parse a SQL expression string back into an internal Expr.
445/// Used for deserializing stored DEFAULT/CHECK expressions from schema.
446pub fn parse_sql_expr(sql: &str) -> Result<Expr> {
447    let dialect = GenericDialect {};
448    let mut parser = Parser::new(&dialect)
449        .try_with_sql(sql)
450        .map_err(|e| SqlError::Parse(e.to_string()))?;
451    let sp_expr = parser
452        .parse_expr()
453        .map_err(|e| SqlError::Parse(e.to_string()))?;
454    convert_expr(&sp_expr)
455}
456
457pub fn parse_sql(sql: &str) -> Result<Statement> {
458    let dialect = GenericDialect {};
459    let stmts = Parser::parse_sql(&dialect, sql).map_err(|e| SqlError::Parse(e.to_string()))?;
460
461    if stmts.is_empty() {
462        return Err(SqlError::Parse("empty SQL".into()));
463    }
464    if stmts.len() > 1 {
465        return Err(SqlError::Unsupported("multiple statements".into()));
466    }
467
468    convert_statement(stmts.into_iter().next().unwrap())
469}
470
471/// Parse one or more `;`-separated SQL statements.
472pub fn parse_sql_multi(sql: &str) -> Result<Vec<Statement>> {
473    let dialect = GenericDialect {};
474    let stmts = Parser::parse_sql(&dialect, sql).map_err(|e| SqlError::Parse(e.to_string()))?;
475
476    if stmts.is_empty() {
477        return Err(SqlError::Parse("empty SQL".into()));
478    }
479
480    stmts.into_iter().map(convert_statement).collect()
481}
482
483/// Returns the number of distinct parameters in a statement (max $N found).
484pub fn count_params(stmt: &Statement) -> usize {
485    let mut max_idx = 0usize;
486    visit_exprs_stmt(stmt, &mut |e| {
487        if let Expr::Parameter(n) = e {
488            max_idx = max_idx.max(*n);
489        }
490    });
491    max_idx
492}
493
494fn visit_exprs_stmt(stmt: &Statement, visitor: &mut impl FnMut(&Expr)) {
495    match stmt {
496        Statement::Select(sq) => {
497            for cte in &sq.ctes {
498                visit_exprs_query_body(&cte.body, visitor);
499            }
500            visit_exprs_query_body(&sq.body, visitor);
501        }
502        Statement::Insert(ins) => match &ins.source {
503            InsertSource::Values(rows) => {
504                for row in rows {
505                    for e in row {
506                        visit_expr(e, visitor);
507                    }
508                }
509            }
510            InsertSource::Select(sq) => {
511                for cte in &sq.ctes {
512                    visit_exprs_query_body(&cte.body, visitor);
513                }
514                visit_exprs_query_body(&sq.body, visitor);
515            }
516        },
517        Statement::Update(upd) => {
518            for (_, e) in &upd.assignments {
519                visit_expr(e, visitor);
520            }
521            if let Some(w) = &upd.where_clause {
522                visit_expr(w, visitor);
523            }
524        }
525        Statement::Delete(del) => {
526            if let Some(w) = &del.where_clause {
527                visit_expr(w, visitor);
528            }
529        }
530        Statement::Explain(inner) => visit_exprs_stmt(inner, visitor),
531        _ => {}
532    }
533}
534
535fn visit_exprs_query_body(body: &QueryBody, visitor: &mut impl FnMut(&Expr)) {
536    match body {
537        QueryBody::Select(sel) => visit_exprs_select(sel, visitor),
538        QueryBody::Compound(comp) => {
539            visit_exprs_query_body(&comp.left, visitor);
540            visit_exprs_query_body(&comp.right, visitor);
541            for o in &comp.order_by {
542                visit_expr(&o.expr, visitor);
543            }
544            if let Some(l) = &comp.limit {
545                visit_expr(l, visitor);
546            }
547            if let Some(o) = &comp.offset {
548                visit_expr(o, visitor);
549            }
550        }
551    }
552}
553
554fn visit_exprs_select(sel: &SelectStmt, visitor: &mut impl FnMut(&Expr)) {
555    for col in &sel.columns {
556        if let SelectColumn::Expr { expr, .. } = col {
557            visit_expr(expr, visitor);
558        }
559    }
560    for j in &sel.joins {
561        if let Some(on) = &j.on_clause {
562            visit_expr(on, visitor);
563        }
564    }
565    if let Some(w) = &sel.where_clause {
566        visit_expr(w, visitor);
567    }
568    for o in &sel.order_by {
569        visit_expr(&o.expr, visitor);
570    }
571    if let Some(l) = &sel.limit {
572        visit_expr(l, visitor);
573    }
574    if let Some(o) = &sel.offset {
575        visit_expr(o, visitor);
576    }
577    for g in &sel.group_by {
578        visit_expr(g, visitor);
579    }
580    if let Some(h) = &sel.having {
581        visit_expr(h, visitor);
582    }
583}
584
/// Pre-order walk of one expression tree: `visitor` sees `expr` first, then each child.
///
/// Children are visited left to right in source order. Subqueries (`IN (SELECT ..)`,
/// `EXISTS`, scalar subqueries) are descended into via [`visit_exprs_select`]. The constant
/// values stored in `InSet` are not visited, only its probe expression. The match is
/// exhaustive, so a new `Expr` variant must be handled here.
fn visit_expr(expr: &Expr, visitor: &mut impl FnMut(&Expr)) {
    visitor(expr);
    match expr {
        Expr::BinaryOp { left, right, .. } => {
            visit_expr(left, visitor);
            visit_expr(right, visitor);
        }
        Expr::UnaryOp { expr: e, .. } | Expr::IsNull(e) | Expr::IsNotNull(e) => {
            visit_expr(e, visitor);
        }
        Expr::Function { args, .. } | Expr::Coalesce(args) => {
            for a in args {
                visit_expr(a, visitor);
            }
        }
        Expr::InSubquery {
            expr: e, subquery, ..
        } => {
            visit_expr(e, visitor);
            visit_exprs_select(subquery, visitor);
        }
        Expr::InList { expr: e, list, .. } => {
            visit_expr(e, visitor);
            for l in list {
                visit_expr(l, visitor);
            }
        }
        Expr::Exists { subquery, .. } => visit_exprs_select(subquery, visitor),
        Expr::ScalarSubquery(sq) => visit_exprs_select(sq, visitor),
        // The set members are plain values, not expressions, so only the probe is visited.
        Expr::InSet { expr: e, .. } => visit_expr(e, visitor),
        Expr::Between {
            expr: e, low, high, ..
        } => {
            visit_expr(e, visitor);
            visit_expr(low, visitor);
            visit_expr(high, visitor);
        }
        Expr::Like {
            expr: e,
            pattern,
            escape,
            ..
        } => {
            visit_expr(e, visitor);
            visit_expr(pattern, visitor);
            if let Some(esc) = escape {
                visit_expr(esc, visitor);
            }
        }
        Expr::Case {
            operand,
            conditions,
            else_result,
        } => {
            if let Some(op) = operand {
                visit_expr(op, visitor);
            }
            for (cond, then) in conditions {
                visit_expr(cond, visitor);
                visit_expr(then, visitor);
            }
            if let Some(el) = else_result {
                visit_expr(el, visitor);
            }
        }
        Expr::Cast { expr: e, .. } => visit_expr(e, visitor),
        Expr::WindowFunction { args, spec, .. } => {
            for a in args {
                visit_expr(a, visitor);
            }
            for p in &spec.partition_by {
                visit_expr(p, visitor);
            }
            for o in &spec.order_by {
                visit_expr(&o.expr, visitor);
            }
            // Only offset bounds (`n PRECEDING` / `n FOLLOWING`) carry an expression.
            if let Some(ref frame) = spec.frame {
                if let WindowFrameBound::Preceding(e) | WindowFrameBound::Following(e) =
                    &frame.start
                {
                    visit_expr(e, visitor);
                }
                if let WindowFrameBound::Preceding(e) | WindowFrameBound::Following(e) = &frame.end
                {
                    visit_expr(e, visitor);
                }
            }
        }
        // Leaves: nothing below them.
        Expr::Literal(_)
        | Expr::Column(_)
        | Expr::QualifiedColumn { .. }
        | Expr::CountStar
        | Expr::Parameter(_) => {}
    }
}
680
/// Converts one sqlparser statement into the internal [`Statement`].
///
/// Arm order matters: for COMMIT and ROLLBACK, the `chain: true` and `savepoint: Some(..)`
/// arms must stay ahead of the catch-all `{ .. }` arms.
///
/// # Errors
/// - `SqlError::Unsupported` for multi-object DROPs, `COMMIT/ROLLBACK AND CHAIN`,
///   `EXPLAIN ANALYZE`, and any statement kind not matched below.
/// - `SqlError::Parse` for a `SET TIME ZONE` value that is neither a literal nor an
///   identifier.
/// - Errors from the per-statement converters are propagated.
fn convert_statement(stmt: sp::Statement) -> Result<Statement> {
    match stmt {
        sp::Statement::CreateTable(ct) => convert_create_table(ct),
        sp::Statement::CreateIndex(ci) => convert_create_index(ci),
        // DROP TABLE / INDEX / VIEW each accept exactly one object name.
        sp::Statement::Drop {
            object_type: sp::ObjectType::Table,
            if_exists,
            names,
            ..
        } => {
            if names.len() != 1 {
                return Err(SqlError::Unsupported("multi-table DROP".into()));
            }
            Ok(Statement::DropTable(DropTableStmt {
                name: object_name_to_string(&names[0]),
                if_exists,
            }))
        }
        sp::Statement::Drop {
            object_type: sp::ObjectType::Index,
            if_exists,
            names,
            ..
        } => {
            if names.len() != 1 {
                return Err(SqlError::Unsupported("multi-index DROP".into()));
            }
            Ok(Statement::DropIndex(DropIndexStmt {
                index_name: object_name_to_string(&names[0]),
                if_exists,
            }))
        }
        sp::Statement::CreateView(cv) => convert_create_view(cv),
        sp::Statement::Drop {
            object_type: sp::ObjectType::View,
            if_exists,
            names,
            ..
        } => {
            if names.len() != 1 {
                return Err(SqlError::Unsupported("multi-view DROP".into()));
            }
            Ok(Statement::DropView(DropViewStmt {
                name: object_name_to_string(&names[0]),
                if_exists,
            }))
        }
        sp::Statement::AlterTable(at) => convert_alter_table(at),
        sp::Statement::Insert(insert) => convert_insert(insert),
        sp::Statement::Query(query) => convert_query(*query),
        sp::Statement::Update(update) => convert_update(update),
        sp::Statement::Delete(delete) => convert_delete(delete),
        // Transaction modes on BEGIN / START TRANSACTION are ignored.
        sp::Statement::StartTransaction { .. } => Ok(Statement::Begin),
        sp::Statement::Commit { chain: true, .. } => {
            Err(SqlError::Unsupported("COMMIT AND CHAIN".into()))
        }
        sp::Statement::Commit { .. } => Ok(Statement::Commit),
        sp::Statement::Rollback { chain: true, .. } => {
            Err(SqlError::Unsupported("ROLLBACK AND CHAIN".into()))
        }
        // Savepoint names are ASCII-lowercased in all three savepoint forms, presumably so
        // that SAVEPOINT / RELEASE / ROLLBACK TO match regardless of case.
        sp::Statement::Rollback {
            savepoint: Some(name),
            ..
        } => Ok(Statement::RollbackTo(name.value.to_ascii_lowercase())),
        sp::Statement::Rollback { .. } => Ok(Statement::Rollback),
        sp::Statement::Savepoint { name } => {
            Ok(Statement::Savepoint(name.value.to_ascii_lowercase()))
        }
        sp::Statement::ReleaseSavepoint { name } => {
            Ok(Statement::ReleaseSavepoint(name.value.to_ascii_lowercase()))
        }
        sp::Statement::Set(sp::Set::SetTimeZone { value, .. }) => {
            // Accept a string literal or bare identifier (PG allows `SET TIME ZONE UTC`).
            // Other literal kinds are kept in their textual rendering.
            let zone = match value {
                sp::Expr::Value(v) => match &v.value {
                    sp::Value::SingleQuotedString(s) => s.clone(),
                    sp::Value::DoubleQuotedString(s) => s.clone(),
                    other => other.to_string(),
                },
                sp::Expr::Identifier(ident) => ident.value.clone(),
                other => {
                    return Err(SqlError::Parse(format!(
                        "SET TIME ZONE expects a string literal or identifier, got: {other}"
                    )))
                }
            };
            Ok(Statement::SetTimezone(zone))
        }
        sp::Statement::Explain {
            statement, analyze, ..
        } => {
            if analyze {
                return Err(SqlError::Unsupported("EXPLAIN ANALYZE".into()));
            }
            let inner = convert_statement(*statement)?;
            Ok(Statement::Explain(Box::new(inner)))
        }
        // `stmt` is not bound by any arm that reaches here, so it is still available for the
        // error message.
        _ => Err(SqlError::Unsupported(format!("statement type: {}", stmt))),
    }
}
781
782/// Parse column options (NOT NULL, DEFAULT, CHECK, FK, UNIQUE) from a sqlparser ColumnDef.
783/// Returns (ColumnSpec, Option<ForeignKeyDef>, was_inline_pk, was_unique).
784fn convert_column_def(
785    col_def: &sp::ColumnDef,
786) -> Result<(ColumnSpec, Option<ForeignKeyDef>, bool, bool)> {
787    let col_name = col_def.name.value.clone();
788    let data_type = convert_data_type(&col_def.data_type)?;
789    let mut nullable = true;
790    let mut is_primary_key = false;
791    let mut is_unique = false;
792    let mut default_expr = None;
793    let mut default_sql = None;
794    let mut check_expr = None;
795    let mut check_sql = None;
796    let mut check_name = None;
797    let mut fk_def = None;
798
799    for opt in &col_def.options {
800        match &opt.option {
801            sp::ColumnOption::NotNull => nullable = false,
802            sp::ColumnOption::Null => nullable = true,
803            sp::ColumnOption::PrimaryKey(_) => {
804                is_primary_key = true;
805                nullable = false;
806            }
807            sp::ColumnOption::Unique(_) => is_unique = true,
808            sp::ColumnOption::Default(expr) => {
809                default_sql = Some(expr.to_string());
810                default_expr = Some(convert_expr(expr)?);
811            }
812            sp::ColumnOption::Check(check) => {
813                check_sql = Some(check.expr.to_string());
814                let converted = convert_expr(&check.expr)?;
815                if has_subquery(&converted) {
816                    return Err(SqlError::Unsupported("subquery in CHECK constraint".into()));
817                }
818                check_expr = Some(converted);
819                check_name = check.name.as_ref().map(|n| n.value.clone());
820            }
821            sp::ColumnOption::ForeignKey(fk) => {
822                convert_fk_actions(&fk.on_delete, &fk.on_update)?;
823                let ftable = object_name_to_string(&fk.foreign_table).to_ascii_lowercase();
824                let referred: Vec<String> = fk
825                    .referred_columns
826                    .iter()
827                    .map(|i| i.value.to_ascii_lowercase())
828                    .collect();
829                fk_def = Some(ForeignKeyDef {
830                    name: fk.name.as_ref().map(|n| n.value.clone()),
831                    columns: vec![col_name.to_ascii_lowercase()],
832                    foreign_table: ftable,
833                    referred_columns: referred,
834                });
835            }
836            _ => {}
837        }
838    }
839
840    let spec = ColumnSpec {
841        name: col_name,
842        data_type,
843        nullable,
844        is_primary_key,
845        default_expr,
846        default_sql,
847        check_expr,
848        check_sql,
849        check_name,
850    };
851    Ok((spec, fk_def, is_primary_key, is_unique))
852}
853
854fn convert_create_table(ct: sp::CreateTable) -> Result<Statement> {
855    let name = object_name_to_string(&ct.name);
856    let if_not_exists = ct.if_not_exists;
857
858    let mut columns = Vec::new();
859    let mut inline_pk: Vec<String> = Vec::new();
860    let mut foreign_keys: Vec<ForeignKeyDef> = Vec::new();
861    let mut unique_indices: Vec<UniqueIndexDef> = Vec::new();
862
863    for col_def in &ct.columns {
864        let (spec, fk_def, was_pk, was_unique) = convert_column_def(col_def)?;
865        if was_pk {
866            inline_pk.push(spec.name.clone());
867        }
868        if let Some(fk) = fk_def {
869            foreign_keys.push(fk);
870        }
871        if was_unique && !was_pk {
872            unique_indices.push(UniqueIndexDef {
873                name: None,
874                columns: vec![spec.name.to_ascii_lowercase()],
875            });
876        }
877        columns.push(spec);
878    }
879
880    let mut check_constraints: Vec<TableCheckConstraint> = Vec::new();
881
882    for constraint in &ct.constraints {
883        match constraint {
884            sp::TableConstraint::PrimaryKey(pk_constraint) => {
885                for idx_col in &pk_constraint.columns {
886                    let col_name = match &idx_col.column.expr {
887                        sp::Expr::Identifier(ident) => ident.value.clone(),
888                        _ => continue,
889                    };
890                    if !inline_pk.contains(&col_name) {
891                        inline_pk.push(col_name.clone());
892                    }
893                    if let Some(col) = columns.iter_mut().find(|c| c.name == col_name) {
894                        col.nullable = false;
895                        col.is_primary_key = true;
896                    }
897                }
898            }
899            sp::TableConstraint::Check(check) => {
900                let sql = check.expr.to_string();
901                let converted = convert_expr(&check.expr)?;
902                if has_subquery(&converted) {
903                    return Err(SqlError::Unsupported("subquery in CHECK constraint".into()));
904                }
905                check_constraints.push(TableCheckConstraint {
906                    name: check.name.as_ref().map(|n| n.value.clone()),
907                    expr: converted,
908                    sql,
909                });
910            }
911            sp::TableConstraint::ForeignKey(fk) => {
912                convert_fk_actions(&fk.on_delete, &fk.on_update)?;
913                let cols: Vec<String> = fk
914                    .columns
915                    .iter()
916                    .map(|i| i.value.to_ascii_lowercase())
917                    .collect();
918                let ftable = object_name_to_string(&fk.foreign_table).to_ascii_lowercase();
919                let referred: Vec<String> = fk
920                    .referred_columns
921                    .iter()
922                    .map(|i| i.value.to_ascii_lowercase())
923                    .collect();
924                foreign_keys.push(ForeignKeyDef {
925                    name: fk.name.as_ref().map(|n| n.value.clone()),
926                    columns: cols,
927                    foreign_table: ftable,
928                    referred_columns: referred,
929                });
930            }
931            sp::TableConstraint::Unique(u) => {
932                let cols: Vec<String> = u
933                    .columns
934                    .iter()
935                    .filter_map(|idx_col| match &idx_col.column.expr {
936                        sp::Expr::Identifier(ident) => Some(ident.value.to_ascii_lowercase()),
937                        _ => None,
938                    })
939                    .collect();
940                if !cols.is_empty() {
941                    unique_indices.push(UniqueIndexDef {
942                        name: u.name.as_ref().map(|n| n.value.clone()),
943                        columns: cols,
944                    });
945                }
946            }
947            _ => {}
948        }
949    }
950
951    Ok(Statement::CreateTable(CreateTableStmt {
952        name,
953        columns,
954        primary_key: inline_pk,
955        if_not_exists,
956        check_constraints,
957        foreign_keys,
958        unique_indices,
959    }))
960}
961
962fn convert_alter_table(at: sp::AlterTable) -> Result<Statement> {
963    let table = object_name_to_string(&at.name);
964    if at.operations.len() != 1 {
965        return Err(SqlError::Unsupported(
966            "ALTER TABLE with multiple operations".into(),
967        ));
968    }
969    let op = match at.operations.into_iter().next().unwrap() {
970        sp::AlterTableOperation::AddColumn {
971            column_def,
972            if_not_exists,
973            ..
974        } => {
975            let (spec, fk, _was_pk, _was_unique) = convert_column_def(&column_def)?;
976            AlterTableOp::AddColumn {
977                column: Box::new(spec),
978                foreign_key: fk,
979                if_not_exists,
980            }
981        }
982        sp::AlterTableOperation::DropColumn {
983            column_names,
984            if_exists,
985            ..
986        } => {
987            if column_names.len() != 1 {
988                return Err(SqlError::Unsupported(
989                    "DROP COLUMN with multiple columns".into(),
990                ));
991            }
992            AlterTableOp::DropColumn {
993                name: column_names.into_iter().next().unwrap().value,
994                if_exists,
995            }
996        }
997        sp::AlterTableOperation::RenameColumn {
998            old_column_name,
999            new_column_name,
1000        } => AlterTableOp::RenameColumn {
1001            old_name: old_column_name.value,
1002            new_name: new_column_name.value,
1003        },
1004        sp::AlterTableOperation::RenameTable { table_name } => {
1005            let new_name = match table_name {
1006                sp::RenameTableNameKind::To(name) | sp::RenameTableNameKind::As(name) => {
1007                    object_name_to_string(&name)
1008                }
1009            };
1010            AlterTableOp::RenameTable { new_name }
1011        }
1012        other => {
1013            return Err(SqlError::Unsupported(format!(
1014                "ALTER TABLE operation: {other}"
1015            )));
1016        }
1017    };
1018    Ok(Statement::AlterTable(Box::new(AlterTableStmt {
1019        table,
1020        op,
1021    })))
1022}
1023
1024fn convert_fk_actions(
1025    on_delete: &Option<sp::ReferentialAction>,
1026    on_update: &Option<sp::ReferentialAction>,
1027) -> Result<()> {
1028    for action in [on_delete, on_update] {
1029        match action {
1030            None
1031            | Some(sp::ReferentialAction::Restrict)
1032            | Some(sp::ReferentialAction::NoAction) => {}
1033            Some(other) => {
1034                return Err(SqlError::Unsupported(format!(
1035                    "FOREIGN KEY action: {other}"
1036                )));
1037            }
1038        }
1039    }
1040    Ok(())
1041}
1042
1043fn convert_create_index(ci: sp::CreateIndex) -> Result<Statement> {
1044    let index_name = ci
1045        .name
1046        .as_ref()
1047        .map(object_name_to_string)
1048        .ok_or_else(|| SqlError::Parse("index name required".into()))?;
1049
1050    let table_name = object_name_to_string(&ci.table_name);
1051
1052    let columns: Vec<String> = ci
1053        .columns
1054        .iter()
1055        .map(|idx_col| match &idx_col.column.expr {
1056            sp::Expr::Identifier(ident) => Ok(ident.value.clone()),
1057            other => Err(SqlError::Unsupported(format!("expression index: {other}"))),
1058        })
1059        .collect::<Result<_>>()?;
1060
1061    if columns.is_empty() {
1062        return Err(SqlError::Parse(
1063            "index must have at least one column".into(),
1064        ));
1065    }
1066
1067    Ok(Statement::CreateIndex(CreateIndexStmt {
1068        index_name,
1069        table_name,
1070        columns,
1071        unique: ci.unique,
1072        if_not_exists: ci.if_not_exists,
1073    }))
1074}
1075
1076fn convert_create_view(cv: sp::CreateView) -> Result<Statement> {
1077    let name = object_name_to_string(&cv.name);
1078
1079    if cv.materialized {
1080        return Err(SqlError::Unsupported("MATERIALIZED VIEW".into()));
1081    }
1082
1083    let sql = cv.query.to_string();
1084
1085    let dialect = GenericDialect {};
1086    let test = Parser::parse_sql(&dialect, &sql).map_err(|e| SqlError::Parse(e.to_string()))?;
1087    if test.is_empty() {
1088        return Err(SqlError::Parse("empty view definition".into()));
1089    }
1090    match &test[0] {
1091        sp::Statement::Query(_) => {}
1092        _ => {
1093            return Err(SqlError::Parse(
1094                "view body must be a SELECT statement".into(),
1095            ))
1096        }
1097    }
1098
1099    let column_aliases: Vec<String> = cv
1100        .columns
1101        .iter()
1102        .map(|c| c.name.value.to_ascii_lowercase())
1103        .collect();
1104
1105    Ok(Statement::CreateView(CreateViewStmt {
1106        name,
1107        sql,
1108        column_aliases,
1109        or_replace: cv.or_replace,
1110        if_not_exists: cv.if_not_exists,
1111    }))
1112}
1113
1114fn convert_insert(insert: sp::Insert) -> Result<Statement> {
1115    let table = match &insert.table {
1116        sp::TableObject::TableName(name) => object_name_to_string(name).to_ascii_lowercase(),
1117        _ => return Err(SqlError::Unsupported("INSERT into non-table object".into())),
1118    };
1119
1120    let columns: Vec<String> = insert
1121        .columns
1122        .iter()
1123        .map(|c| c.value.to_ascii_lowercase())
1124        .collect();
1125
1126    let query = insert
1127        .source
1128        .ok_or_else(|| SqlError::Parse("INSERT requires VALUES or SELECT".into()))?;
1129
1130    let source = match *query.body {
1131        sp::SetExpr::Values(sp::Values { rows, .. }) => {
1132            let mut result = Vec::new();
1133            for row in rows {
1134                let mut exprs = Vec::new();
1135                for expr in row {
1136                    exprs.push(convert_expr(&expr)?);
1137                }
1138                result.push(exprs);
1139            }
1140            InsertSource::Values(result)
1141        }
1142        _ => {
1143            let (ctes, recursive) = if let Some(ref with) = query.with {
1144                convert_with(with)?
1145            } else {
1146                (vec![], false)
1147            };
1148            let body = convert_query_body(&query)?;
1149            InsertSource::Select(Box::new(SelectQuery {
1150                ctes,
1151                recursive,
1152                body,
1153            }))
1154        }
1155    };
1156
1157    let on_conflict = insert.on.as_ref().map(convert_on_insert).transpose()?;
1158
1159    Ok(Statement::Insert(InsertStmt {
1160        table,
1161        columns,
1162        source,
1163        on_conflict,
1164    }))
1165}
1166
1167fn convert_on_insert(on: &sp::OnInsert) -> Result<OnConflictClause> {
1168    match on {
1169        sp::OnInsert::OnConflict(oc) => {
1170            let target = oc
1171                .conflict_target
1172                .as_ref()
1173                .map(convert_conflict_target)
1174                .transpose()?;
1175            let action = convert_on_conflict_action(&oc.action)?;
1176            Ok(OnConflictClause { target, action })
1177        }
1178        sp::OnInsert::DuplicateKeyUpdate(_) => Err(SqlError::Parse(
1179            "ON DUPLICATE KEY UPDATE is MySQL-specific; use ON CONFLICT".into(),
1180        )),
1181        _ => Err(SqlError::Parse("unsupported ON INSERT clause".into())),
1182    }
1183}
1184
1185fn convert_conflict_target(target: &sp::ConflictTarget) -> Result<ConflictTarget> {
1186    match target {
1187        sp::ConflictTarget::Columns(cols) => Ok(ConflictTarget::Columns(
1188            cols.iter().map(|c| c.value.to_ascii_lowercase()).collect(),
1189        )),
1190        sp::ConflictTarget::OnConstraint(name) => {
1191            if name.0.len() > 1 {
1192                return Err(SqlError::Parse(
1193                    "qualified constraint names not supported".into(),
1194                ));
1195            }
1196            Ok(ConflictTarget::Constraint(
1197                object_name_to_string(name).to_ascii_lowercase(),
1198            ))
1199        }
1200    }
1201}
1202
1203fn convert_on_conflict_action(action: &sp::OnConflictAction) -> Result<OnConflictAction> {
1204    match action {
1205        sp::OnConflictAction::DoNothing => Ok(OnConflictAction::DoNothing),
1206        sp::OnConflictAction::DoUpdate(du) => {
1207            let assignments = du
1208                .assignments
1209                .iter()
1210                .map(|a| {
1211                    let col = match &a.target {
1212                        sp::AssignmentTarget::ColumnName(name) => {
1213                            object_name_to_string(name).to_ascii_lowercase()
1214                        }
1215                        _ => {
1216                            return Err(SqlError::Unsupported(
1217                                "tuple assignment in ON CONFLICT".into(),
1218                            ))
1219                        }
1220                    };
1221                    let expr = convert_expr(&a.value)?;
1222                    Ok((col, expr))
1223                })
1224                .collect::<Result<_>>()?;
1225            let where_clause = du.selection.as_ref().map(convert_expr).transpose()?;
1226            Ok(OnConflictAction::DoUpdate {
1227                assignments,
1228                where_clause,
1229            })
1230        }
1231    }
1232}
1233
1234fn convert_select_body(select: &sp::Select) -> Result<SelectStmt> {
1235    let distinct = match &select.distinct {
1236        Some(sp::Distinct::Distinct) => true,
1237        Some(sp::Distinct::On(_)) => {
1238            return Err(SqlError::Unsupported("DISTINCT ON".into()));
1239        }
1240        _ => false,
1241    };
1242
1243    let (from, from_alias, joins) = if select.from.is_empty() {
1244        (String::new(), None, vec![])
1245    } else if select.from.len() == 1 {
1246        let table_with_joins = &select.from[0];
1247        let (name, alias) = match &table_with_joins.relation {
1248            sp::TableFactor::Table { name, alias, .. } => {
1249                let table_name = object_name_to_string(name);
1250                let alias_str = alias.as_ref().map(|a| a.name.value.clone());
1251                (table_name, alias_str)
1252            }
1253            _ => return Err(SqlError::Unsupported("non-table FROM source".into())),
1254        };
1255        let j = table_with_joins
1256            .joins
1257            .iter()
1258            .map(convert_join)
1259            .collect::<Result<Vec<_>>>()?;
1260        (name, alias, j)
1261    } else {
1262        return Err(SqlError::Unsupported("comma-separated FROM tables".into()));
1263    };
1264
1265    let columns: Vec<SelectColumn> = select
1266        .projection
1267        .iter()
1268        .map(convert_select_item)
1269        .collect::<Result<_>>()?;
1270
1271    let where_clause = select.selection.as_ref().map(convert_expr).transpose()?;
1272
1273    let group_by = match &select.group_by {
1274        sp::GroupByExpr::Expressions(exprs, _) => {
1275            exprs.iter().map(convert_expr).collect::<Result<_>>()?
1276        }
1277        sp::GroupByExpr::All(_) => {
1278            return Err(SqlError::Unsupported("GROUP BY ALL".into()));
1279        }
1280    };
1281
1282    let having = select.having.as_ref().map(convert_expr).transpose()?;
1283
1284    Ok(SelectStmt {
1285        columns,
1286        from,
1287        from_alias,
1288        joins,
1289        distinct,
1290        where_clause,
1291        order_by: vec![],
1292        limit: None,
1293        offset: None,
1294        group_by,
1295        having,
1296    })
1297}
1298
1299fn convert_set_expr(set_expr: &sp::SetExpr) -> Result<QueryBody> {
1300    match set_expr {
1301        sp::SetExpr::Select(sel) => Ok(QueryBody::Select(Box::new(convert_select_body(sel)?))),
1302        sp::SetExpr::SetOperation {
1303            op,
1304            set_quantifier,
1305            left,
1306            right,
1307        } => {
1308            let set_op = match op {
1309                sp::SetOperator::Union => SetOp::Union,
1310                sp::SetOperator::Intersect => SetOp::Intersect,
1311                sp::SetOperator::Except | sp::SetOperator::Minus => SetOp::Except,
1312            };
1313            let all = match set_quantifier {
1314                sp::SetQuantifier::All => true,
1315                sp::SetQuantifier::None | sp::SetQuantifier::Distinct => false,
1316                _ => {
1317                    return Err(SqlError::Unsupported("BY NAME set operations".into()));
1318                }
1319            };
1320            Ok(QueryBody::Compound(Box::new(CompoundSelect {
1321                op: set_op,
1322                all,
1323                left: Box::new(convert_set_expr(left)?),
1324                right: Box::new(convert_set_expr(right)?),
1325                order_by: vec![],
1326                limit: None,
1327                offset: None,
1328            })))
1329        }
1330        _ => Err(SqlError::Unsupported("unsupported set expression".into())),
1331    }
1332}
1333
1334fn convert_query_body(query: &sp::Query) -> Result<QueryBody> {
1335    let mut body = convert_set_expr(&query.body)?;
1336
1337    let order_by = if let Some(ref ob) = query.order_by {
1338        match &ob.kind {
1339            sp::OrderByKind::Expressions(exprs) => exprs
1340                .iter()
1341                .map(convert_order_by_expr)
1342                .collect::<Result<_>>()?,
1343            sp::OrderByKind::All { .. } => {
1344                return Err(SqlError::Unsupported("ORDER BY ALL".into()));
1345            }
1346        }
1347    } else {
1348        vec![]
1349    };
1350
1351    let (limit, offset) = match &query.limit_clause {
1352        Some(sp::LimitClause::LimitOffset { limit, offset, .. }) => {
1353            let l = limit.as_ref().map(convert_expr).transpose()?;
1354            let o = offset
1355                .as_ref()
1356                .map(|o| convert_expr(&o.value))
1357                .transpose()?;
1358            (l, o)
1359        }
1360        Some(sp::LimitClause::OffsetCommaLimit { limit, offset }) => {
1361            let l = Some(convert_expr(limit)?);
1362            let o = Some(convert_expr(offset)?);
1363            (l, o)
1364        }
1365        None => (None, None),
1366    };
1367
1368    match &mut body {
1369        QueryBody::Select(sel) => {
1370            sel.order_by = order_by;
1371            sel.limit = limit;
1372            sel.offset = offset;
1373        }
1374        QueryBody::Compound(comp) => {
1375            comp.order_by = order_by;
1376            comp.limit = limit;
1377            comp.offset = offset;
1378        }
1379    }
1380
1381    Ok(body)
1382}
1383
1384fn convert_subquery(query: &sp::Query) -> Result<SelectStmt> {
1385    if query.with.is_some() {
1386        return Err(SqlError::Unsupported("CTEs in subqueries".into()));
1387    }
1388    match convert_query_body(query)? {
1389        QueryBody::Select(s) => Ok(*s),
1390        QueryBody::Compound(_) => Err(SqlError::Unsupported(
1391            "UNION/INTERSECT/EXCEPT in subqueries".into(),
1392        )),
1393    }
1394}
1395
1396fn convert_with(with: &sp::With) -> Result<(Vec<CteDefinition>, bool)> {
1397    let mut names = std::collections::HashSet::new();
1398    let mut ctes = Vec::new();
1399    for cte in &with.cte_tables {
1400        let name = cte.alias.name.value.to_ascii_lowercase();
1401        if !names.insert(name.clone()) {
1402            return Err(SqlError::DuplicateCteName(name));
1403        }
1404        let column_aliases: Vec<String> = cte
1405            .alias
1406            .columns
1407            .iter()
1408            .map(|c| c.name.value.to_ascii_lowercase())
1409            .collect();
1410        let body = convert_query_body(&cte.query)?;
1411        ctes.push(CteDefinition {
1412            name,
1413            column_aliases,
1414            body,
1415        });
1416    }
1417    Ok((ctes, with.recursive))
1418}
1419
1420fn convert_query(query: sp::Query) -> Result<Statement> {
1421    let (ctes, recursive) = if let Some(ref with) = query.with {
1422        convert_with(with)?
1423    } else {
1424        (vec![], false)
1425    };
1426    let body = convert_query_body(&query)?;
1427    Ok(Statement::Select(Box::new(SelectQuery {
1428        ctes,
1429        recursive,
1430        body,
1431    })))
1432}
1433
1434fn convert_join(join: &sp::Join) -> Result<JoinClause> {
1435    let (join_type, constraint) = match &join.join_operator {
1436        sp::JoinOperator::Inner(c) => (JoinType::Inner, Some(c)),
1437        sp::JoinOperator::Join(c) => (JoinType::Inner, Some(c)),
1438        sp::JoinOperator::CrossJoin(c) => (JoinType::Cross, Some(c)),
1439        sp::JoinOperator::Left(c) => (JoinType::Left, Some(c)),
1440        sp::JoinOperator::LeftSemi(c) => (JoinType::Left, Some(c)),
1441        sp::JoinOperator::LeftAnti(c) => (JoinType::Left, Some(c)),
1442        sp::JoinOperator::Right(c) => (JoinType::Right, Some(c)),
1443        sp::JoinOperator::RightSemi(c) => (JoinType::Right, Some(c)),
1444        sp::JoinOperator::RightAnti(c) => (JoinType::Right, Some(c)),
1445        other => return Err(SqlError::Unsupported(format!("join type: {other:?}"))),
1446    };
1447
1448    let (name, alias) = match &join.relation {
1449        sp::TableFactor::Table { name, alias, .. } => {
1450            let table_name = object_name_to_string(name);
1451            let alias_str = alias.as_ref().map(|a| a.name.value.clone());
1452            (table_name, alias_str)
1453        }
1454        _ => return Err(SqlError::Unsupported("non-table JOIN source".into())),
1455    };
1456
1457    let on_clause = match constraint {
1458        Some(sp::JoinConstraint::On(expr)) => Some(convert_expr(expr)?),
1459        Some(sp::JoinConstraint::None) | None => None,
1460        Some(other) => return Err(SqlError::Unsupported(format!("join constraint: {other:?}"))),
1461    };
1462
1463    Ok(JoinClause {
1464        join_type,
1465        table: TableRef { name, alias },
1466        on_clause,
1467    })
1468}
1469
1470fn convert_update(update: sp::Update) -> Result<Statement> {
1471    let table = match &update.table.relation {
1472        sp::TableFactor::Table { name, .. } => object_name_to_string(name),
1473        _ => return Err(SqlError::Unsupported("non-table UPDATE target".into())),
1474    };
1475
1476    let assignments = update
1477        .assignments
1478        .iter()
1479        .map(|a| {
1480            let col = match &a.target {
1481                sp::AssignmentTarget::ColumnName(name) => object_name_to_string(name),
1482                _ => return Err(SqlError::Unsupported("tuple assignment".into())),
1483            };
1484            let expr = convert_expr(&a.value)?;
1485            Ok((col, expr))
1486        })
1487        .collect::<Result<_>>()?;
1488
1489    let where_clause = update.selection.as_ref().map(convert_expr).transpose()?;
1490
1491    Ok(Statement::Update(UpdateStmt {
1492        table,
1493        assignments,
1494        where_clause,
1495    }))
1496}
1497
1498fn convert_delete(delete: sp::Delete) -> Result<Statement> {
1499    let table_name = match &delete.from {
1500        sp::FromTable::WithFromKeyword(tables) => {
1501            if tables.len() != 1 {
1502                return Err(SqlError::Unsupported("multi-table DELETE".into()));
1503            }
1504            match &tables[0].relation {
1505                sp::TableFactor::Table { name, .. } => object_name_to_string(name),
1506                _ => return Err(SqlError::Unsupported("non-table DELETE target".into())),
1507            }
1508        }
1509        sp::FromTable::WithoutKeyword(tables) => {
1510            if tables.len() != 1 {
1511                return Err(SqlError::Unsupported("multi-table DELETE".into()));
1512            }
1513            match &tables[0].relation {
1514                sp::TableFactor::Table { name, .. } => object_name_to_string(name),
1515                _ => return Err(SqlError::Unsupported("non-table DELETE target".into())),
1516            }
1517        }
1518    };
1519
1520    let where_clause = delete.selection.as_ref().map(convert_expr).transpose()?;
1521
1522    Ok(Statement::Delete(DeleteStmt {
1523        table: table_name,
1524        where_clause,
1525    }))
1526}
1527
1528fn convert_expr(expr: &sp::Expr) -> Result<Expr> {
1529    match expr {
1530        sp::Expr::Value(v) => convert_value(&v.value),
1531        sp::Expr::Identifier(ident) => Ok(Expr::Column(ident.value.to_ascii_lowercase())),
1532        sp::Expr::CompoundIdentifier(parts) => {
1533            if parts.len() == 2 {
1534                Ok(Expr::QualifiedColumn {
1535                    table: parts[0].value.to_ascii_lowercase(),
1536                    column: parts[1].value.to_ascii_lowercase(),
1537                })
1538            } else {
1539                Ok(Expr::Column(
1540                    parts.last().unwrap().value.to_ascii_lowercase(),
1541                ))
1542            }
1543        }
1544        sp::Expr::BinaryOp { left, op, right } => {
1545            let bin_op = convert_bin_op(op)?;
1546            Ok(Expr::BinaryOp {
1547                left: Box::new(convert_expr(left)?),
1548                op: bin_op,
1549                right: Box::new(convert_expr(right)?),
1550            })
1551        }
1552        sp::Expr::UnaryOp { op, expr } => {
1553            let unary_op = match op {
1554                sp::UnaryOperator::Minus => UnaryOp::Neg,
1555                sp::UnaryOperator::Not => UnaryOp::Not,
1556                _ => return Err(SqlError::Unsupported(format!("unary op: {op}"))),
1557            };
1558            Ok(Expr::UnaryOp {
1559                op: unary_op,
1560                expr: Box::new(convert_expr(expr)?),
1561            })
1562        }
1563        sp::Expr::IsNull(e) => Ok(Expr::IsNull(Box::new(convert_expr(e)?))),
1564        sp::Expr::IsNotNull(e) => Ok(Expr::IsNotNull(Box::new(convert_expr(e)?))),
1565        sp::Expr::Nested(e) => convert_expr(e),
1566        sp::Expr::Function(func) => convert_function(func),
1567        sp::Expr::InSubquery {
1568            expr: e,
1569            subquery,
1570            negated,
1571        } => {
1572            let inner_expr = convert_expr(e)?;
1573            let stmt = convert_subquery(subquery)?;
1574            Ok(Expr::InSubquery {
1575                expr: Box::new(inner_expr),
1576                subquery: Box::new(stmt),
1577                negated: *negated,
1578            })
1579        }
1580        sp::Expr::InList {
1581            expr: e,
1582            list,
1583            negated,
1584        } => {
1585            let inner_expr = convert_expr(e)?;
1586            let items = list.iter().map(convert_expr).collect::<Result<Vec<_>>>()?;
1587            Ok(Expr::InList {
1588                expr: Box::new(inner_expr),
1589                list: items,
1590                negated: *negated,
1591            })
1592        }
1593        sp::Expr::Exists { subquery, negated } => {
1594            let stmt = convert_subquery(subquery)?;
1595            Ok(Expr::Exists {
1596                subquery: Box::new(stmt),
1597                negated: *negated,
1598            })
1599        }
1600        sp::Expr::Subquery(query) => {
1601            let stmt = convert_subquery(query)?;
1602            Ok(Expr::ScalarSubquery(Box::new(stmt)))
1603        }
1604        sp::Expr::Between {
1605            expr: e,
1606            negated,
1607            low,
1608            high,
1609        } => Ok(Expr::Between {
1610            expr: Box::new(convert_expr(e)?),
1611            low: Box::new(convert_expr(low)?),
1612            high: Box::new(convert_expr(high)?),
1613            negated: *negated,
1614        }),
1615        sp::Expr::Like {
1616            expr: e,
1617            negated,
1618            pattern,
1619            escape_char,
1620            ..
1621        } => {
1622            let esc = escape_char
1623                .as_ref()
1624                .map(convert_escape_value)
1625                .transpose()?
1626                .map(Box::new);
1627            Ok(Expr::Like {
1628                expr: Box::new(convert_expr(e)?),
1629                pattern: Box::new(convert_expr(pattern)?),
1630                escape: esc,
1631                negated: *negated,
1632            })
1633        }
1634        sp::Expr::ILike {
1635            expr: e,
1636            negated,
1637            pattern,
1638            escape_char,
1639            ..
1640        } => {
1641            let esc = escape_char
1642                .as_ref()
1643                .map(convert_escape_value)
1644                .transpose()?
1645                .map(Box::new);
1646            Ok(Expr::Like {
1647                expr: Box::new(convert_expr(e)?),
1648                pattern: Box::new(convert_expr(pattern)?),
1649                escape: esc,
1650                negated: *negated,
1651            })
1652        }
1653        sp::Expr::Case {
1654            operand,
1655            conditions,
1656            else_result,
1657            ..
1658        } => {
1659            let op = operand
1660                .as_ref()
1661                .map(|e| convert_expr(e))
1662                .transpose()?
1663                .map(Box::new);
1664            let conds: Vec<(Expr, Expr)> = conditions
1665                .iter()
1666                .map(|cw| Ok((convert_expr(&cw.condition)?, convert_expr(&cw.result)?)))
1667                .collect::<Result<_>>()?;
1668            let else_r = else_result
1669                .as_ref()
1670                .map(|e| convert_expr(e))
1671                .transpose()?
1672                .map(Box::new);
1673            Ok(Expr::Case {
1674                operand: op,
1675                conditions: conds,
1676                else_result: else_r,
1677            })
1678        }
1679        sp::Expr::Cast {
1680            expr: e,
1681            data_type: dt,
1682            ..
1683        } => {
1684            let target = convert_data_type(dt)?;
1685            Ok(Expr::Cast {
1686                expr: Box::new(convert_expr(e)?),
1687                data_type: target,
1688            })
1689        }
1690        sp::Expr::Substring {
1691            expr: e,
1692            substring_from,
1693            substring_for,
1694            ..
1695        } => {
1696            let mut args = vec![convert_expr(e)?];
1697            if let Some(from) = substring_from {
1698                args.push(convert_expr(from)?);
1699            }
1700            if let Some(f) = substring_for {
1701                args.push(convert_expr(f)?);
1702            }
1703            Ok(Expr::Function {
1704                name: "SUBSTR".into(),
1705                args,
1706                distinct: false,
1707            })
1708        }
1709        sp::Expr::Trim {
1710            expr: e,
1711            trim_where,
1712            trim_what,
1713            trim_characters,
1714        } => {
1715            let fn_name = match trim_where {
1716                Some(sp::TrimWhereField::Leading) => "LTRIM",
1717                Some(sp::TrimWhereField::Trailing) => "RTRIM",
1718                _ => "TRIM",
1719            };
1720            let mut args = vec![convert_expr(e)?];
1721            if let Some(what) = trim_what {
1722                args.push(convert_expr(what)?);
1723            } else if let Some(chars) = trim_characters {
1724                if let Some(first) = chars.first() {
1725                    args.push(convert_expr(first)?);
1726                }
1727            }
1728            Ok(Expr::Function {
1729                name: fn_name.into(),
1730                args,
1731                distinct: false,
1732            })
1733        }
1734        sp::Expr::Ceil { expr: e, .. } => Ok(Expr::Function {
1735            name: "CEIL".into(),
1736            args: vec![convert_expr(e)?],
1737            distinct: false,
1738        }),
1739        sp::Expr::Floor { expr: e, .. } => Ok(Expr::Function {
1740            name: "FLOOR".into(),
1741            args: vec![convert_expr(e)?],
1742            distinct: false,
1743        }),
1744        sp::Expr::Position { expr: e, r#in } => Ok(Expr::Function {
1745            name: "INSTR".into(),
1746            args: vec![convert_expr(r#in)?, convert_expr(e)?],
1747            distinct: false,
1748        }),
1749        // Typed literal: `DATE '2024-01-15'`, `TIMESTAMP '...'`, etc.
1750        sp::Expr::TypedString(ts) => {
1751            let raw = match &ts.value.value {
1752                sp::Value::SingleQuotedString(s) => s.clone(),
1753                sp::Value::DoubleQuotedString(s) => s.clone(),
1754                other => other.to_string(),
1755            };
1756            convert_typed_string(&ts.data_type, &raw)
1757        }
1758        // INTERVAL '...' — sqlparser emits Expr::Interval with a boxed Expr value + field qualifiers
1759        sp::Expr::Interval(iv) => convert_interval_expr(iv),
1760        // EXTRACT(field FROM src)
1761        sp::Expr::Extract { field, expr: e, .. } => {
1762            let field_name = match field {
1763                sp::DateTimeField::Year => "year",
1764                sp::DateTimeField::Month => "month",
1765                sp::DateTimeField::Week(_) => "week",
1766                sp::DateTimeField::Day => "day",
1767                sp::DateTimeField::Date => "day",
1768                sp::DateTimeField::Hour => "hour",
1769                sp::DateTimeField::Minute => "minute",
1770                sp::DateTimeField::Second => "second",
1771                sp::DateTimeField::Millisecond => "milliseconds",
1772                sp::DateTimeField::Microsecond => "microseconds",
1773                sp::DateTimeField::Microseconds => "microseconds",
1774                sp::DateTimeField::Milliseconds => "milliseconds",
1775                sp::DateTimeField::Dow => "dow",
1776                sp::DateTimeField::Isodow => "isodow",
1777                sp::DateTimeField::Doy => "doy",
1778                sp::DateTimeField::Epoch => "epoch",
1779                sp::DateTimeField::Quarter => "quarter",
1780                sp::DateTimeField::Decade => "decade",
1781                sp::DateTimeField::Century => "century",
1782                sp::DateTimeField::Millennium => "millennium",
1783                sp::DateTimeField::Isoyear => "isoyear",
1784                sp::DateTimeField::Julian => "julian",
1785                other => {
1786                    return Err(SqlError::InvalidExtractField(format!("{other:?}")));
1787                }
1788            };
1789            Ok(Expr::Function {
1790                name: "EXTRACT".into(),
1791                args: vec![
1792                    Expr::Literal(Value::Text(field_name.into())),
1793                    convert_expr(e)?,
1794                ],
1795                distinct: false,
1796            })
1797        }
1798        // `AT TIME ZONE 'zone'` operator — desugars to AT_TIMEZONE(ts, zone) scalar function.
1799        sp::Expr::AtTimeZone {
1800            timestamp,
1801            time_zone,
1802        } => Ok(Expr::Function {
1803            name: "AT_TIMEZONE".into(),
1804            args: vec![convert_expr(timestamp)?, convert_expr(time_zone)?],
1805            distinct: false,
1806        }),
1807        _ => Err(SqlError::Unsupported(format!("expression: {expr}"))),
1808    }
1809}
1810
1811fn convert_value(val: &sp::Value) -> Result<Expr> {
1812    match val {
1813        sp::Value::Number(n, _) => {
1814            if let Ok(i) = n.parse::<i64>() {
1815                Ok(Expr::Literal(Value::Integer(i)))
1816            } else if let Ok(f) = n.parse::<f64>() {
1817                Ok(Expr::Literal(Value::Real(f)))
1818            } else {
1819                Err(SqlError::InvalidValue(format!("cannot parse number: {n}")))
1820            }
1821        }
1822        sp::Value::SingleQuotedString(s) => Ok(Expr::Literal(Value::Text(s.as_str().into()))),
1823        sp::Value::Boolean(b) => Ok(Expr::Literal(Value::Boolean(*b))),
1824        sp::Value::Null => Ok(Expr::Literal(Value::Null)),
1825        sp::Value::Placeholder(s) => {
1826            let idx_str = s
1827                .strip_prefix('$')
1828                .ok_or_else(|| SqlError::Parse(format!("invalid placeholder: {s}")))?;
1829            let idx: usize = idx_str
1830                .parse()
1831                .map_err(|_| SqlError::Parse(format!("invalid placeholder index: {s}")))?;
1832            if idx == 0 {
1833                return Err(SqlError::Parse("placeholder index must be >= 1".into()));
1834            }
1835            Ok(Expr::Parameter(idx))
1836        }
1837        _ => Err(SqlError::Unsupported(format!("value type: {val}"))),
1838    }
1839}
1840
1841fn convert_typed_string(dt: &sp::DataType, value: &str) -> Result<Expr> {
1842    let s = value.trim_matches('\'');
1843    match dt {
1844        sp::DataType::Date => {
1845            let d = crate::datetime::parse_date(s)?;
1846            Ok(Expr::Literal(Value::Date(d)))
1847        }
1848        sp::DataType::Time(_, _) => {
1849            let t = crate::datetime::parse_time(s)?;
1850            Ok(Expr::Literal(Value::Time(t)))
1851        }
1852        sp::DataType::Timestamp(_, _) => {
1853            let t = crate::datetime::parse_timestamp(s)?;
1854            Ok(Expr::Literal(Value::Timestamp(t)))
1855        }
1856        sp::DataType::Interval { .. } => {
1857            let (months, days, micros) = crate::datetime::parse_interval(s)?;
1858            Ok(Expr::Literal(Value::Interval {
1859                months,
1860                days,
1861                micros,
1862            }))
1863        }
1864        _ => {
1865            let target = convert_data_type(dt)?;
1866            Ok(Expr::Cast {
1867                expr: Box::new(Expr::Literal(Value::Text(s.into()))),
1868                data_type: target,
1869            })
1870        }
1871    }
1872}
1873
1874fn convert_interval_expr(iv: &sp::Interval) -> Result<Expr> {
1875    let raw = match iv.value.as_ref() {
1876        sp::Expr::Value(v) => match &v.value {
1877            sp::Value::SingleQuotedString(s) => s.clone(),
1878            sp::Value::Number(n, _) => n.clone(),
1879            other => {
1880                return Err(SqlError::InvalidIntervalLiteral(format!(
1881                    "unsupported inner value: {other}"
1882                )))
1883            }
1884        },
1885        other => {
1886            return Err(SqlError::InvalidIntervalLiteral(format!(
1887                "unsupported inner expr: {other}"
1888            )))
1889        }
1890    };
1891
1892    // SQL-standard form `INTERVAL '5' DAY` — append the unit to the literal.
1893    let with_unit = if let Some(field) = &iv.leading_field {
1894        let unit_name = match field {
1895            sp::DateTimeField::Year => "years",
1896            sp::DateTimeField::Month => "months",
1897            sp::DateTimeField::Week(_) => "weeks",
1898            sp::DateTimeField::Day => "days",
1899            sp::DateTimeField::Hour => "hours",
1900            sp::DateTimeField::Minute => "minutes",
1901            sp::DateTimeField::Second => "seconds",
1902            _ => {
1903                return Err(SqlError::InvalidIntervalLiteral(format!(
1904                    "unsupported leading field: {field:?}"
1905                )))
1906            }
1907        };
1908        format!("{raw} {unit_name}")
1909    } else {
1910        raw
1911    };
1912
1913    let (months, days, micros) = crate::datetime::parse_interval(&with_unit)?;
1914    Ok(Expr::Literal(Value::Interval {
1915        months,
1916        days,
1917        micros,
1918    }))
1919}
1920
1921fn convert_escape_value(val: &sp::Value) -> Result<Expr> {
1922    match val {
1923        sp::Value::SingleQuotedString(s) => Ok(Expr::Literal(Value::Text(s.as_str().into()))),
1924        _ => Err(SqlError::Unsupported(format!("ESCAPE value: {val}"))),
1925    }
1926}
1927
1928fn convert_bin_op(op: &sp::BinaryOperator) -> Result<BinOp> {
1929    match op {
1930        sp::BinaryOperator::Plus => Ok(BinOp::Add),
1931        sp::BinaryOperator::Minus => Ok(BinOp::Sub),
1932        sp::BinaryOperator::Multiply => Ok(BinOp::Mul),
1933        sp::BinaryOperator::Divide => Ok(BinOp::Div),
1934        sp::BinaryOperator::Modulo => Ok(BinOp::Mod),
1935        sp::BinaryOperator::Eq => Ok(BinOp::Eq),
1936        sp::BinaryOperator::NotEq => Ok(BinOp::NotEq),
1937        sp::BinaryOperator::Lt => Ok(BinOp::Lt),
1938        sp::BinaryOperator::Gt => Ok(BinOp::Gt),
1939        sp::BinaryOperator::LtEq => Ok(BinOp::LtEq),
1940        sp::BinaryOperator::GtEq => Ok(BinOp::GtEq),
1941        sp::BinaryOperator::And => Ok(BinOp::And),
1942        sp::BinaryOperator::Or => Ok(BinOp::Or),
1943        sp::BinaryOperator::StringConcat => Ok(BinOp::Concat),
1944        _ => Err(SqlError::Unsupported(format!("binary op: {op}"))),
1945    }
1946}
1947
1948fn convert_function(func: &sp::Function) -> Result<Expr> {
1949    let name = object_name_to_string(&func.name).to_ascii_uppercase();
1950
1951    let (args, is_count_star, distinct) = match &func.args {
1952        sp::FunctionArguments::List(list) => {
1953            let distinct = matches!(
1954                list.duplicate_treatment,
1955                Some(sp::DuplicateTreatment::Distinct)
1956            );
1957            if list.args.is_empty() && name == "COUNT" {
1958                (vec![], true, distinct)
1959            } else {
1960                let mut count_star = false;
1961                let args = list
1962                    .args
1963                    .iter()
1964                    .map(|arg| match arg {
1965                        sp::FunctionArg::Unnamed(sp::FunctionArgExpr::Expr(e)) => convert_expr(e),
1966                        sp::FunctionArg::Unnamed(sp::FunctionArgExpr::Wildcard) => {
1967                            if name == "COUNT" {
1968                                count_star = true;
1969                                Ok(Expr::CountStar)
1970                            } else {
1971                                Err(SqlError::Unsupported(format!("{name}(*)")))
1972                            }
1973                        }
1974                        _ => Err(SqlError::Unsupported(format!(
1975                            "function arg type in {name}"
1976                        ))),
1977                    })
1978                    .collect::<Result<Vec<_>>>()?;
1979                if name == "COUNT" && args.len() == 1 && count_star {
1980                    (vec![], true, distinct)
1981                } else {
1982                    (args, false, distinct)
1983                }
1984            }
1985        }
1986        sp::FunctionArguments::None => {
1987            if name == "COUNT" {
1988                (vec![], true, false)
1989            } else {
1990                (vec![], false, false)
1991            }
1992        }
1993        sp::FunctionArguments::Subquery(_) => {
1994            return Err(SqlError::Unsupported("subquery in function".into()));
1995        }
1996    };
1997
1998    // Window function: check OVER before any other special handling
1999    if let Some(over) = &func.over {
2000        let spec = match over {
2001            sp::WindowType::WindowSpec(ws) => convert_window_spec(ws)?,
2002            sp::WindowType::NamedWindow(_) => {
2003                return Err(SqlError::Unsupported("named windows".into()));
2004            }
2005        };
2006        return Ok(Expr::WindowFunction { name, args, spec });
2007    }
2008
2009    // Non-window special forms
2010    if is_count_star {
2011        return Ok(Expr::CountStar);
2012    }
2013
2014    if name == "COALESCE" {
2015        if args.is_empty() {
2016            return Err(SqlError::Parse(
2017                "COALESCE requires at least one argument".into(),
2018            ));
2019        }
2020        return Ok(Expr::Coalesce(args));
2021    }
2022
2023    if name == "NULLIF" {
2024        if args.len() != 2 {
2025            return Err(SqlError::Parse(
2026                "NULLIF requires exactly two arguments".into(),
2027            ));
2028        }
2029        return Ok(Expr::Case {
2030            operand: None,
2031            conditions: vec![(
2032                Expr::BinaryOp {
2033                    left: Box::new(args[0].clone()),
2034                    op: BinOp::Eq,
2035                    right: Box::new(args[1].clone()),
2036                },
2037                Expr::Literal(Value::Null),
2038            )],
2039            else_result: Some(Box::new(args[0].clone())),
2040        });
2041    }
2042
2043    if name == "IIF" {
2044        if args.len() != 3 {
2045            return Err(SqlError::Parse(
2046                "IIF requires exactly three arguments".into(),
2047            ));
2048        }
2049        return Ok(Expr::Case {
2050            operand: None,
2051            conditions: vec![(args[0].clone(), args[1].clone())],
2052            else_result: Some(Box::new(args[2].clone())),
2053        });
2054    }
2055
2056    Ok(Expr::Function {
2057        name,
2058        args,
2059        distinct,
2060    })
2061}
2062
2063fn convert_window_spec(ws: &sp::WindowSpec) -> Result<WindowSpec> {
2064    let partition_by = ws
2065        .partition_by
2066        .iter()
2067        .map(convert_expr)
2068        .collect::<Result<Vec<_>>>()?;
2069    let order_by = ws
2070        .order_by
2071        .iter()
2072        .map(convert_order_by_expr)
2073        .collect::<Result<Vec<_>>>()?;
2074    let frame = ws
2075        .window_frame
2076        .as_ref()
2077        .map(convert_window_frame)
2078        .transpose()?;
2079    Ok(WindowSpec {
2080        partition_by,
2081        order_by,
2082        frame,
2083    })
2084}
2085
2086fn convert_window_frame(wf: &sp::WindowFrame) -> Result<WindowFrame> {
2087    let units = match wf.units {
2088        sp::WindowFrameUnits::Rows => WindowFrameUnits::Rows,
2089        sp::WindowFrameUnits::Range => WindowFrameUnits::Range,
2090        sp::WindowFrameUnits::Groups => {
2091            return Err(SqlError::Unsupported("GROUPS window frame".into()));
2092        }
2093    };
2094    let start = convert_window_frame_bound(&wf.start_bound)?;
2095    let end = match &wf.end_bound {
2096        Some(b) => convert_window_frame_bound(b)?,
2097        None => WindowFrameBound::CurrentRow,
2098    };
2099    Ok(WindowFrame { units, start, end })
2100}
2101
2102fn convert_window_frame_bound(b: &sp::WindowFrameBound) -> Result<WindowFrameBound> {
2103    match b {
2104        sp::WindowFrameBound::CurrentRow => Ok(WindowFrameBound::CurrentRow),
2105        sp::WindowFrameBound::Preceding(None) => Ok(WindowFrameBound::UnboundedPreceding),
2106        sp::WindowFrameBound::Preceding(Some(e)) => {
2107            Ok(WindowFrameBound::Preceding(Box::new(convert_expr(e)?)))
2108        }
2109        sp::WindowFrameBound::Following(None) => Ok(WindowFrameBound::UnboundedFollowing),
2110        sp::WindowFrameBound::Following(Some(e)) => {
2111            Ok(WindowFrameBound::Following(Box::new(convert_expr(e)?)))
2112        }
2113    }
2114}
2115
2116fn convert_select_item(item: &sp::SelectItem) -> Result<SelectColumn> {
2117    match item {
2118        sp::SelectItem::Wildcard(_) => Ok(SelectColumn::AllColumns),
2119        sp::SelectItem::UnnamedExpr(e) => {
2120            let expr = convert_expr(e)?;
2121            Ok(SelectColumn::Expr { expr, alias: None })
2122        }
2123        sp::SelectItem::ExprWithAlias { expr, alias } => {
2124            let expr = convert_expr(expr)?;
2125            Ok(SelectColumn::Expr {
2126                expr,
2127                alias: Some(alias.value.clone()),
2128            })
2129        }
2130        sp::SelectItem::QualifiedWildcard(_, _) => {
2131            Err(SqlError::Unsupported("qualified wildcard (table.*)".into()))
2132        }
2133    }
2134}
2135
2136fn convert_order_by_expr(expr: &sp::OrderByExpr) -> Result<OrderByItem> {
2137    let e = convert_expr(&expr.expr)?;
2138    let descending = expr.options.asc.map(|asc| !asc).unwrap_or(false);
2139    let nulls_first = expr.options.nulls_first;
2140
2141    Ok(OrderByItem {
2142        expr: e,
2143        descending,
2144        nulls_first,
2145    })
2146}
2147
2148fn convert_data_type(dt: &sp::DataType) -> Result<DataType> {
2149    match dt {
2150        sp::DataType::Int(_)
2151        | sp::DataType::Integer(_)
2152        | sp::DataType::BigInt(_)
2153        | sp::DataType::SmallInt(_)
2154        | sp::DataType::TinyInt(_)
2155        | sp::DataType::Int2(_)
2156        | sp::DataType::Int4(_)
2157        | sp::DataType::Int8(_) => Ok(DataType::Integer),
2158
2159        sp::DataType::Real
2160        | sp::DataType::Double(..)
2161        | sp::DataType::DoublePrecision
2162        | sp::DataType::Float(_)
2163        | sp::DataType::Float4
2164        | sp::DataType::Float64 => Ok(DataType::Real),
2165
2166        sp::DataType::Varchar(_)
2167        | sp::DataType::Text
2168        | sp::DataType::Char(_)
2169        | sp::DataType::Character(_)
2170        | sp::DataType::String(_) => Ok(DataType::Text),
2171
2172        sp::DataType::Blob(_) | sp::DataType::Bytea => Ok(DataType::Blob),
2173
2174        sp::DataType::Boolean | sp::DataType::Bool => Ok(DataType::Boolean),
2175
2176        sp::DataType::Date => Ok(DataType::Date),
2177        sp::DataType::Time(_, _) => Ok(DataType::Time),
2178        sp::DataType::Timestamp(_, _) => Ok(DataType::Timestamp),
2179        sp::DataType::Interval { .. } => Ok(DataType::Interval),
2180
2181        _ => Err(SqlError::Unsupported(format!("data type: {dt}"))),
2182    }
2183}
2184
2185fn object_name_to_string(name: &sp::ObjectName) -> String {
2186    name.0
2187        .iter()
2188        .filter_map(|p| match p {
2189            sp::ObjectNamePart::Identifier(ident) => Some(ident.value.clone()),
2190            _ => None,
2191        })
2192        .collect::<Vec<_>>()
2193        .join(".")
2194}
2195
2196#[cfg(test)]
2197mod tests {
2198    use super::*;
2199
2200    #[test]
2201    fn parse_create_table() {
2202        let stmt = parse_sql(
2203            "CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT NOT NULL, age INTEGER)",
2204        )
2205        .unwrap();
2206
2207        match stmt {
2208            Statement::CreateTable(ct) => {
2209                assert_eq!(ct.name, "users");
2210                assert_eq!(ct.columns.len(), 3);
2211                assert_eq!(ct.columns[0].name, "id");
2212                assert_eq!(ct.columns[0].data_type, DataType::Integer);
2213                assert!(ct.columns[0].is_primary_key);
2214                assert!(!ct.columns[0].nullable);
2215                assert_eq!(ct.columns[1].name, "name");
2216                assert_eq!(ct.columns[1].data_type, DataType::Text);
2217                assert!(!ct.columns[1].nullable);
2218                assert_eq!(ct.columns[2].name, "age");
2219                assert!(ct.columns[2].nullable);
2220                assert_eq!(ct.primary_key, vec!["id"]);
2221            }
2222            _ => panic!("expected CreateTable"),
2223        }
2224    }
2225
2226    #[test]
2227    fn parse_create_table_if_not_exists() {
2228        let stmt = parse_sql("CREATE TABLE IF NOT EXISTS t (id INT PRIMARY KEY)").unwrap();
2229        match stmt {
2230            Statement::CreateTable(ct) => assert!(ct.if_not_exists),
2231            _ => panic!("expected CreateTable"),
2232        }
2233    }
2234
2235    #[test]
2236    fn parse_drop_table() {
2237        let stmt = parse_sql("DROP TABLE users").unwrap();
2238        match stmt {
2239            Statement::DropTable(dt) => {
2240                assert_eq!(dt.name, "users");
2241                assert!(!dt.if_exists);
2242            }
2243            _ => panic!("expected DropTable"),
2244        }
2245    }
2246
2247    #[test]
2248    fn parse_drop_table_if_exists() {
2249        let stmt = parse_sql("DROP TABLE IF EXISTS users").unwrap();
2250        match stmt {
2251            Statement::DropTable(dt) => assert!(dt.if_exists),
2252            _ => panic!("expected DropTable"),
2253        }
2254    }
2255
2256    #[test]
2257    fn parse_insert() {
2258        let stmt =
2259            parse_sql("INSERT INTO users (id, name) VALUES (1, 'Alice'), (2, 'Bob')").unwrap();
2260
2261        match stmt {
2262            Statement::Insert(ins) => {
2263                assert_eq!(ins.table, "users");
2264                assert_eq!(ins.columns, vec!["id", "name"]);
2265                let values = match &ins.source {
2266                    InsertSource::Values(v) => v,
2267                    _ => panic!("expected Values"),
2268                };
2269                assert_eq!(values.len(), 2);
2270                assert!(matches!(values[0][0], Expr::Literal(Value::Integer(1))));
2271                assert!(matches!(&values[0][1], Expr::Literal(Value::Text(s)) if s == "Alice"));
2272                assert!(ins.on_conflict.is_none());
2273            }
2274            _ => panic!("expected Insert"),
2275        }
2276    }
2277
2278    #[test]
2279    fn parse_upsert_do_nothing() {
2280        let stmt =
2281            parse_sql("INSERT INTO t (id, v) VALUES (1, 'a') ON CONFLICT (id) DO NOTHING").unwrap();
2282        match stmt {
2283            Statement::Insert(ins) => {
2284                let oc = ins.on_conflict.expect("expected on_conflict");
2285                match oc.target.expect("target") {
2286                    ConflictTarget::Columns(cols) => assert_eq!(cols, vec!["id"]),
2287                    _ => panic!("expected Columns target"),
2288                }
2289                assert!(matches!(oc.action, OnConflictAction::DoNothing));
2290            }
2291            _ => panic!("expected Insert"),
2292        }
2293    }
2294
2295    #[test]
2296    fn parse_upsert_do_nothing_no_target() {
2297        let stmt = parse_sql("INSERT INTO t VALUES (1, 'a') ON CONFLICT DO NOTHING").unwrap();
2298        match stmt {
2299            Statement::Insert(ins) => {
2300                let oc = ins.on_conflict.expect("expected on_conflict");
2301                assert!(oc.target.is_none());
2302                assert!(matches!(oc.action, OnConflictAction::DoNothing));
2303            }
2304            _ => panic!("expected Insert"),
2305        }
2306    }
2307
2308    #[test]
2309    fn parse_upsert_do_update_simple() {
2310        let stmt = parse_sql(
2311            "INSERT INTO t (id, v) VALUES (1, 'a') ON CONFLICT (id) DO UPDATE SET v = 'b'",
2312        )
2313        .unwrap();
2314        match stmt {
2315            Statement::Insert(ins) => {
2316                let oc = ins.on_conflict.expect("expected on_conflict");
2317                match oc.action {
2318                    OnConflictAction::DoUpdate {
2319                        assignments,
2320                        where_clause,
2321                    } => {
2322                        assert_eq!(assignments.len(), 1);
2323                        assert_eq!(assignments[0].0, "v");
2324                        assert!(where_clause.is_none());
2325                    }
2326                    _ => panic!("expected DoUpdate"),
2327                }
2328            }
2329            _ => panic!("expected Insert"),
2330        }
2331    }
2332
2333    #[test]
2334    fn parse_upsert_do_update_excluded() {
2335        let stmt = parse_sql(
2336            "INSERT INTO t (id, v) VALUES (1, 'a') \
2337             ON CONFLICT (id) DO UPDATE SET v = excluded.v",
2338        )
2339        .unwrap();
2340        match stmt {
2341            Statement::Insert(ins) => {
2342                let oc = ins.on_conflict.expect("expected on_conflict");
2343                let assignments = match oc.action {
2344                    OnConflictAction::DoUpdate { assignments, .. } => assignments,
2345                    _ => panic!("expected DoUpdate"),
2346                };
2347                match &assignments[0].1 {
2348                    Expr::QualifiedColumn { table, column } => {
2349                        assert_eq!(table, "excluded");
2350                        assert_eq!(column, "v");
2351                    }
2352                    _ => panic!("expected QualifiedColumn"),
2353                }
2354            }
2355            _ => panic!("expected Insert"),
2356        }
2357    }
2358
2359    #[test]
2360    fn parse_upsert_do_update_where() {
2361        let stmt = parse_sql(
2362            "INSERT INTO t (id, v) VALUES (1, 'a') \
2363             ON CONFLICT (id) DO UPDATE SET v = excluded.v WHERE t.v < 'z'",
2364        )
2365        .unwrap();
2366        match stmt {
2367            Statement::Insert(ins) => {
2368                let oc = ins.on_conflict.expect("expected on_conflict");
2369                match oc.action {
2370                    OnConflictAction::DoUpdate { where_clause, .. } => {
2371                        assert!(where_clause.is_some());
2372                    }
2373                    _ => panic!("expected DoUpdate"),
2374                }
2375            }
2376            _ => panic!("expected Insert"),
2377        }
2378    }
2379
2380    #[test]
2381    fn parse_upsert_on_constraint_named() {
2382        let stmt = parse_sql(
2383            "INSERT INTO t (id, v) VALUES (1, 'a') \
2384             ON CONFLICT ON CONSTRAINT t_v_idx DO NOTHING",
2385        )
2386        .unwrap();
2387        match stmt {
2388            Statement::Insert(ins) => {
2389                let oc = ins.on_conflict.expect("expected on_conflict");
2390                match oc.target.expect("target") {
2391                    ConflictTarget::Constraint(name) => assert_eq!(name, "t_v_idx"),
2392                    _ => panic!("expected Constraint target"),
2393                }
2394            }
2395            _ => panic!("expected Insert"),
2396        }
2397    }
2398
2399    #[test]
2400    fn parse_upsert_rejects_duplicate_key_update() {
2401        let err = parse_sql("INSERT INTO t (id) VALUES (1) ON DUPLICATE KEY UPDATE id = 2")
2402            .expect_err("should reject MySQL syntax");
2403        let msg = format!("{err}");
2404        assert!(msg.contains("ON DUPLICATE KEY UPDATE") || msg.contains("MySQL"));
2405    }
2406
2407    #[test]
2408    fn parse_upsert_lowercases_conflict_target() {
2409        let stmt = parse_sql("INSERT INTO t (id) VALUES (1) ON CONFLICT (ID) DO NOTHING").unwrap();
2410        match stmt {
2411            Statement::Insert(ins) => {
2412                let oc = ins.on_conflict.expect("expected on_conflict");
2413                match oc.target.expect("target") {
2414                    ConflictTarget::Columns(cols) => assert_eq!(cols, vec!["id"]),
2415                    _ => panic!("expected Columns"),
2416                }
2417            }
2418            _ => panic!("expected Insert"),
2419        }
2420    }
2421
2422    #[test]
2423    fn parse_select_all() {
2424        let stmt = parse_sql("SELECT * FROM users").unwrap();
2425        match stmt {
2426            Statement::Select(sq) => match sq.body {
2427                QueryBody::Select(sel) => {
2428                    assert_eq!(sel.from, "users");
2429                    assert!(matches!(sel.columns[0], SelectColumn::AllColumns));
2430                    assert!(sel.where_clause.is_none());
2431                }
2432                _ => panic!("expected QueryBody::Select"),
2433            },
2434            _ => panic!("expected Select"),
2435        }
2436    }
2437
2438    #[test]
2439    fn parse_select_where() {
2440        let stmt = parse_sql("SELECT id, name FROM users WHERE age > 18").unwrap();
2441        match stmt {
2442            Statement::Select(sq) => match sq.body {
2443                QueryBody::Select(sel) => {
2444                    assert_eq!(sel.columns.len(), 2);
2445                    assert!(sel.where_clause.is_some());
2446                }
2447                _ => panic!("expected QueryBody::Select"),
2448            },
2449            _ => panic!("expected Select"),
2450        }
2451    }
2452
2453    #[test]
2454    fn parse_select_order_limit() {
2455        let stmt = parse_sql("SELECT * FROM users ORDER BY name ASC LIMIT 10 OFFSET 5").unwrap();
2456        match stmt {
2457            Statement::Select(sq) => match sq.body {
2458                QueryBody::Select(sel) => {
2459                    assert_eq!(sel.order_by.len(), 1);
2460                    assert!(!sel.order_by[0].descending);
2461                    assert!(sel.limit.is_some());
2462                    assert!(sel.offset.is_some());
2463                }
2464                _ => panic!("expected QueryBody::Select"),
2465            },
2466            _ => panic!("expected Select"),
2467        }
2468    }
2469
2470    #[test]
2471    fn parse_update() {
2472        let stmt = parse_sql("UPDATE users SET name = 'Bob' WHERE id = 1").unwrap();
2473        match stmt {
2474            Statement::Update(upd) => {
2475                assert_eq!(upd.table, "users");
2476                assert_eq!(upd.assignments.len(), 1);
2477                assert_eq!(upd.assignments[0].0, "name");
2478                assert!(upd.where_clause.is_some());
2479            }
2480            _ => panic!("expected Update"),
2481        }
2482    }
2483
2484    #[test]
2485    fn parse_delete() {
2486        let stmt = parse_sql("DELETE FROM users WHERE id = 1").unwrap();
2487        match stmt {
2488            Statement::Delete(del) => {
2489                assert_eq!(del.table, "users");
2490                assert!(del.where_clause.is_some());
2491            }
2492            _ => panic!("expected Delete"),
2493        }
2494    }
2495
2496    #[test]
2497    fn parse_aggregate() {
2498        let stmt = parse_sql("SELECT COUNT(*), SUM(age) FROM users").unwrap();
2499        match stmt {
2500            Statement::Select(sq) => match sq.body {
2501                QueryBody::Select(sel) => {
2502                    assert_eq!(sel.columns.len(), 2);
2503                    match &sel.columns[0] {
2504                        SelectColumn::Expr {
2505                            expr: Expr::CountStar,
2506                            ..
2507                        } => {}
2508                        other => panic!("expected CountStar, got {other:?}"),
2509                    }
2510                }
2511                _ => panic!("expected QueryBody::Select"),
2512            },
2513            _ => panic!("expected Select"),
2514        }
2515    }
2516
2517    #[test]
2518    fn parse_group_by_having() {
2519        let stmt = parse_sql(
2520            "SELECT department, COUNT(*) FROM employees GROUP BY department HAVING COUNT(*) > 5",
2521        )
2522        .unwrap();
2523        match stmt {
2524            Statement::Select(sq) => match sq.body {
2525                QueryBody::Select(sel) => {
2526                    assert_eq!(sel.group_by.len(), 1);
2527                    assert!(sel.having.is_some());
2528                }
2529                _ => panic!("expected QueryBody::Select"),
2530            },
2531            _ => panic!("expected Select"),
2532        }
2533    }
2534
2535    #[test]
2536    fn parse_expressions() {
2537        let stmt = parse_sql("SELECT id + 1, -price, NOT active FROM items").unwrap();
2538        match stmt {
2539            Statement::Select(sq) => match sq.body {
2540                QueryBody::Select(sel) => {
2541                    assert_eq!(sel.columns.len(), 3);
2542                    // id + 1
2543                    match &sel.columns[0] {
2544                        SelectColumn::Expr {
2545                            expr: Expr::BinaryOp { op: BinOp::Add, .. },
2546                            ..
2547                        } => {}
2548                        other => panic!("expected BinaryOp Add, got {other:?}"),
2549                    }
2550                    // -price
2551                    match &sel.columns[1] {
2552                        SelectColumn::Expr {
2553                            expr:
2554                                Expr::UnaryOp {
2555                                    op: UnaryOp::Neg, ..
2556                                },
2557                            ..
2558                        } => {}
2559                        other => panic!("expected UnaryOp Neg, got {other:?}"),
2560                    }
2561                    // NOT active
2562                    match &sel.columns[2] {
2563                        SelectColumn::Expr {
2564                            expr:
2565                                Expr::UnaryOp {
2566                                    op: UnaryOp::Not, ..
2567                                },
2568                            ..
2569                        } => {}
2570                        other => panic!("expected UnaryOp Not, got {other:?}"),
2571                    }
2572                }
2573                _ => panic!("expected QueryBody::Select"),
2574            },
2575            _ => panic!("expected Select"),
2576        }
2577    }
2578
2579    #[test]
2580    fn parse_is_null() {
2581        let stmt = parse_sql("SELECT * FROM t WHERE x IS NULL").unwrap();
2582        match stmt {
2583            Statement::Select(sq) => match sq.body {
2584                QueryBody::Select(sel) => {
2585                    assert!(matches!(sel.where_clause, Some(Expr::IsNull(_))));
2586                }
2587                _ => panic!("expected QueryBody::Select"),
2588            },
2589            _ => panic!("expected Select"),
2590        }
2591    }
2592
2593    #[test]
2594    fn parse_inner_join() {
2595        let stmt = parse_sql("SELECT * FROM a JOIN b ON a.id = b.id").unwrap();
2596        match stmt {
2597            Statement::Select(sq) => match sq.body {
2598                QueryBody::Select(sel) => {
2599                    assert_eq!(sel.from, "a");
2600                    assert_eq!(sel.joins.len(), 1);
2601                    assert_eq!(sel.joins[0].join_type, JoinType::Inner);
2602                    assert_eq!(sel.joins[0].table.name, "b");
2603                    assert!(sel.joins[0].on_clause.is_some());
2604                }
2605                _ => panic!("expected QueryBody::Select"),
2606            },
2607            _ => panic!("expected Select"),
2608        }
2609    }
2610
2611    #[test]
2612    fn parse_inner_join_explicit() {
2613        let stmt = parse_sql("SELECT * FROM a INNER JOIN b ON a.id = b.a_id").unwrap();
2614        match stmt {
2615            Statement::Select(sq) => match sq.body {
2616                QueryBody::Select(sel) => {
2617                    assert_eq!(sel.joins.len(), 1);
2618                    assert_eq!(sel.joins[0].join_type, JoinType::Inner);
2619                }
2620                _ => panic!("expected QueryBody::Select"),
2621            },
2622            _ => panic!("expected Select"),
2623        }
2624    }
2625
2626    #[test]
2627    fn parse_cross_join() {
2628        let stmt = parse_sql("SELECT * FROM a CROSS JOIN b").unwrap();
2629        match stmt {
2630            Statement::Select(sq) => match sq.body {
2631                QueryBody::Select(sel) => {
2632                    assert_eq!(sel.joins.len(), 1);
2633                    assert_eq!(sel.joins[0].join_type, JoinType::Cross);
2634                    assert!(sel.joins[0].on_clause.is_none());
2635                }
2636                _ => panic!("expected QueryBody::Select"),
2637            },
2638            _ => panic!("expected Select"),
2639        }
2640    }
2641
2642    #[test]
2643    fn parse_left_join() {
2644        let stmt = parse_sql("SELECT * FROM a LEFT JOIN b ON a.id = b.a_id").unwrap();
2645        match stmt {
2646            Statement::Select(sq) => match sq.body {
2647                QueryBody::Select(sel) => {
2648                    assert_eq!(sel.joins.len(), 1);
2649                    assert_eq!(sel.joins[0].join_type, JoinType::Left);
2650                }
2651                _ => panic!("expected QueryBody::Select"),
2652            },
2653            _ => panic!("expected Select"),
2654        }
2655    }
2656
2657    #[test]
2658    fn parse_table_alias() {
2659        let stmt = parse_sql("SELECT u.id FROM users u JOIN orders o ON u.id = o.user_id").unwrap();
2660        match stmt {
2661            Statement::Select(sq) => match sq.body {
2662                QueryBody::Select(sel) => {
2663                    assert_eq!(sel.from, "users");
2664                    assert_eq!(sel.from_alias.as_deref(), Some("u"));
2665                    assert_eq!(sel.joins[0].table.name, "orders");
2666                    assert_eq!(sel.joins[0].table.alias.as_deref(), Some("o"));
2667                }
2668                _ => panic!("expected QueryBody::Select"),
2669            },
2670            _ => panic!("expected Select"),
2671        }
2672    }
2673
2674    #[test]
2675    fn parse_multi_join() {
2676        let stmt =
2677            parse_sql("SELECT * FROM a JOIN b ON a.id = b.a_id JOIN c ON b.id = c.b_id").unwrap();
2678        match stmt {
2679            Statement::Select(sq) => match sq.body {
2680                QueryBody::Select(sel) => {
2681                    assert_eq!(sel.joins.len(), 2);
2682                }
2683                _ => panic!("expected QueryBody::Select"),
2684            },
2685            _ => panic!("expected Select"),
2686        }
2687    }
2688
2689    #[test]
2690    fn parse_qualified_column() {
2691        let stmt = parse_sql("SELECT u.id, u.name FROM users u").unwrap();
2692        match stmt {
2693            Statement::Select(sq) => match sq.body {
2694                QueryBody::Select(sel) => match &sel.columns[0] {
2695                    SelectColumn::Expr {
2696                        expr: Expr::QualifiedColumn { table, column },
2697                        ..
2698                    } => {
2699                        assert_eq!(table, "u");
2700                        assert_eq!(column, "id");
2701                    }
2702                    other => panic!("expected QualifiedColumn, got {other:?}"),
2703                },
2704                _ => panic!("expected QueryBody::Select"),
2705            },
2706            _ => panic!("expected Select"),
2707        }
2708    }
2709
2710    #[test]
2711    fn reject_subquery() {
2712        assert!(parse_sql("SELECT * FROM (SELECT 1)").is_err());
2713    }
2714
2715    #[test]
2716    fn parse_type_mapping() {
2717        let stmt = parse_sql(
2718            "CREATE TABLE t (a INT PRIMARY KEY, b BIGINT, c SMALLINT, d REAL, e DOUBLE PRECISION, f VARCHAR(255), g BOOLEAN, h BLOB, i BYTEA)"
2719        ).unwrap();
2720        match stmt {
2721            Statement::CreateTable(ct) => {
2722                assert_eq!(ct.columns[0].data_type, DataType::Integer); // INT
2723                assert_eq!(ct.columns[1].data_type, DataType::Integer); // BIGINT
2724                assert_eq!(ct.columns[2].data_type, DataType::Integer); // SMALLINT
2725                assert_eq!(ct.columns[3].data_type, DataType::Real); // REAL
2726                assert_eq!(ct.columns[4].data_type, DataType::Real); // DOUBLE
2727                assert_eq!(ct.columns[5].data_type, DataType::Text); // VARCHAR
2728                assert_eq!(ct.columns[6].data_type, DataType::Boolean); // BOOLEAN
2729                assert_eq!(ct.columns[7].data_type, DataType::Blob); // BLOB
2730                assert_eq!(ct.columns[8].data_type, DataType::Blob); // BYTEA
2731            }
2732            _ => panic!("expected CreateTable"),
2733        }
2734    }
2735
2736    #[test]
2737    fn parse_boolean_literals() {
2738        let stmt = parse_sql("INSERT INTO t (a, b) VALUES (true, false)").unwrap();
2739        match stmt {
2740            Statement::Insert(ins) => {
2741                let values = match &ins.source {
2742                    InsertSource::Values(v) => v,
2743                    _ => panic!("expected Values"),
2744                };
2745                assert!(matches!(values[0][0], Expr::Literal(Value::Boolean(true))));
2746                assert!(matches!(values[0][1], Expr::Literal(Value::Boolean(false))));
2747            }
2748            _ => panic!("expected Insert"),
2749        }
2750    }
2751
2752    #[test]
2753    fn parse_null_literal() {
2754        let stmt = parse_sql("INSERT INTO t (a) VALUES (NULL)").unwrap();
2755        match stmt {
2756            Statement::Insert(ins) => {
2757                let values = match &ins.source {
2758                    InsertSource::Values(v) => v,
2759                    _ => panic!("expected Values"),
2760                };
2761                assert!(matches!(values[0][0], Expr::Literal(Value::Null)));
2762            }
2763            _ => panic!("expected Insert"),
2764        }
2765    }
2766
2767    #[test]
2768    fn parse_alias() {
2769        let stmt = parse_sql("SELECT id AS user_id FROM users").unwrap();
2770        match stmt {
2771            Statement::Select(sq) => match sq.body {
2772                QueryBody::Select(sel) => match &sel.columns[0] {
2773                    SelectColumn::Expr { alias: Some(a), .. } => assert_eq!(a, "user_id"),
2774                    other => panic!("expected alias, got {other:?}"),
2775                },
2776                _ => panic!("expected QueryBody::Select"),
2777            },
2778            _ => panic!("expected Select"),
2779        }
2780    }
2781
2782    #[test]
2783    fn parse_begin() {
2784        let stmt = parse_sql("BEGIN").unwrap();
2785        assert!(matches!(stmt, Statement::Begin));
2786    }
2787
2788    #[test]
2789    fn parse_begin_transaction() {
2790        let stmt = parse_sql("BEGIN TRANSACTION").unwrap();
2791        assert!(matches!(stmt, Statement::Begin));
2792    }
2793
2794    #[test]
2795    fn parse_commit() {
2796        let stmt = parse_sql("COMMIT").unwrap();
2797        assert!(matches!(stmt, Statement::Commit));
2798    }
2799
2800    #[test]
2801    fn parse_rollback() {
2802        let stmt = parse_sql("ROLLBACK").unwrap();
2803        assert!(matches!(stmt, Statement::Rollback));
2804    }
2805
2806    #[test]
2807    fn parse_savepoint() {
2808        let stmt = parse_sql("SAVEPOINT sp1").unwrap();
2809        match stmt {
2810            Statement::Savepoint(name) => assert_eq!(name, "sp1"),
2811            other => panic!("expected Savepoint, got {other:?}"),
2812        }
2813    }
2814
2815    #[test]
2816    fn parse_savepoint_case_insensitive() {
2817        let stmt = parse_sql("SAVEPOINT My_SP").unwrap();
2818        match stmt {
2819            Statement::Savepoint(name) => assert_eq!(name, "my_sp"),
2820            other => panic!("expected Savepoint, got {other:?}"),
2821        }
2822    }
2823
2824    #[test]
2825    fn parse_release_savepoint() {
2826        let stmt = parse_sql("RELEASE SAVEPOINT sp1").unwrap();
2827        match stmt {
2828            Statement::ReleaseSavepoint(name) => assert_eq!(name, "sp1"),
2829            other => panic!("expected ReleaseSavepoint, got {other:?}"),
2830        }
2831    }
2832
2833    #[test]
2834    fn parse_release_without_savepoint_keyword() {
2835        let stmt = parse_sql("RELEASE sp1").unwrap();
2836        match stmt {
2837            Statement::ReleaseSavepoint(name) => assert_eq!(name, "sp1"),
2838            other => panic!("expected ReleaseSavepoint, got {other:?}"),
2839        }
2840    }
2841
2842    #[test]
2843    fn parse_rollback_to_savepoint() {
2844        let stmt = parse_sql("ROLLBACK TO SAVEPOINT sp1").unwrap();
2845        match stmt {
2846            Statement::RollbackTo(name) => assert_eq!(name, "sp1"),
2847            other => panic!("expected RollbackTo, got {other:?}"),
2848        }
2849    }
2850
2851    #[test]
2852    fn parse_rollback_to_without_savepoint_keyword() {
2853        let stmt = parse_sql("ROLLBACK TO sp1").unwrap();
2854        match stmt {
2855            Statement::RollbackTo(name) => assert_eq!(name, "sp1"),
2856            other => panic!("expected RollbackTo, got {other:?}"),
2857        }
2858    }
2859
2860    #[test]
2861    fn parse_rollback_to_case_insensitive() {
2862        let stmt = parse_sql("ROLLBACK TO My_SP").unwrap();
2863        match stmt {
2864            Statement::RollbackTo(name) => assert_eq!(name, "my_sp"),
2865            other => panic!("expected RollbackTo, got {other:?}"),
2866        }
2867    }
2868
2869    #[test]
2870    fn parse_commit_and_chain_rejected() {
2871        let err = parse_sql("COMMIT AND CHAIN").unwrap_err();
2872        assert!(matches!(err, SqlError::Unsupported(_)));
2873    }
2874
2875    #[test]
2876    fn parse_rollback_and_chain_rejected() {
2877        let err = parse_sql("ROLLBACK AND CHAIN").unwrap_err();
2878        assert!(matches!(err, SqlError::Unsupported(_)));
2879    }
2880
2881    #[test]
2882    fn parse_select_distinct() {
2883        let stmt = parse_sql("SELECT DISTINCT name FROM users").unwrap();
2884        match stmt {
2885            Statement::Select(sq) => match sq.body {
2886                QueryBody::Select(sel) => {
2887                    assert!(sel.distinct);
2888                    assert_eq!(sel.columns.len(), 1);
2889                }
2890                _ => panic!("expected QueryBody::Select"),
2891            },
2892            _ => panic!("expected Select"),
2893        }
2894    }
2895
2896    #[test]
2897    fn parse_select_without_distinct() {
2898        let stmt = parse_sql("SELECT name FROM users").unwrap();
2899        match stmt {
2900            Statement::Select(sq) => match sq.body {
2901                QueryBody::Select(sel) => {
2902                    assert!(!sel.distinct);
2903                }
2904                _ => panic!("expected QueryBody::Select"),
2905            },
2906            _ => panic!("expected Select"),
2907        }
2908    }
2909
2910    #[test]
2911    fn parse_select_distinct_all_columns() {
2912        let stmt = parse_sql("SELECT DISTINCT * FROM users").unwrap();
2913        match stmt {
2914            Statement::Select(sq) => match sq.body {
2915                QueryBody::Select(sel) => {
2916                    assert!(sel.distinct);
2917                    assert!(matches!(sel.columns[0], SelectColumn::AllColumns));
2918                }
2919                _ => panic!("expected QueryBody::Select"),
2920            },
2921            _ => panic!("expected Select"),
2922        }
2923    }
2924
2925    #[test]
2926    fn reject_distinct_on() {
2927        assert!(parse_sql("SELECT DISTINCT ON (id) * FROM users").is_err());
2928    }
2929
2930    #[test]
2931    fn parse_create_index() {
2932        let stmt = parse_sql("CREATE INDEX idx_name ON users (name)").unwrap();
2933        match stmt {
2934            Statement::CreateIndex(ci) => {
2935                assert_eq!(ci.index_name, "idx_name");
2936                assert_eq!(ci.table_name, "users");
2937                assert_eq!(ci.columns, vec!["name"]);
2938                assert!(!ci.unique);
2939                assert!(!ci.if_not_exists);
2940            }
2941            _ => panic!("expected CreateIndex"),
2942        }
2943    }
2944
2945    #[test]
2946    fn parse_create_unique_index() {
2947        let stmt = parse_sql("CREATE UNIQUE INDEX idx_email ON users (email)").unwrap();
2948        match stmt {
2949            Statement::CreateIndex(ci) => {
2950                assert!(ci.unique);
2951                assert_eq!(ci.columns, vec!["email"]);
2952            }
2953            _ => panic!("expected CreateIndex"),
2954        }
2955    }
2956
2957    #[test]
2958    fn parse_create_index_if_not_exists() {
2959        let stmt = parse_sql("CREATE INDEX IF NOT EXISTS idx_x ON t (a)").unwrap();
2960        match stmt {
2961            Statement::CreateIndex(ci) => assert!(ci.if_not_exists),
2962            _ => panic!("expected CreateIndex"),
2963        }
2964    }
2965
2966    #[test]
2967    fn parse_create_index_multi_column() {
2968        let stmt = parse_sql("CREATE INDEX idx_multi ON t (a, b, c)").unwrap();
2969        match stmt {
2970            Statement::CreateIndex(ci) => {
2971                assert_eq!(ci.columns, vec!["a", "b", "c"]);
2972            }
2973            _ => panic!("expected CreateIndex"),
2974        }
2975    }
2976
2977    #[test]
2978    fn parse_drop_index() {
2979        let stmt = parse_sql("DROP INDEX idx_name").unwrap();
2980        match stmt {
2981            Statement::DropIndex(di) => {
2982                assert_eq!(di.index_name, "idx_name");
2983                assert!(!di.if_exists);
2984            }
2985            _ => panic!("expected DropIndex"),
2986        }
2987    }
2988
2989    #[test]
2990    fn parse_drop_index_if_exists() {
2991        let stmt = parse_sql("DROP INDEX IF EXISTS idx_name").unwrap();
2992        match stmt {
2993            Statement::DropIndex(di) => {
2994                assert!(di.if_exists);
2995            }
2996            _ => panic!("expected DropIndex"),
2997        }
2998    }
2999
3000    #[test]
3001    fn parse_explain_select() {
3002        let stmt = parse_sql("EXPLAIN SELECT * FROM users WHERE id = 1").unwrap();
3003        match stmt {
3004            Statement::Explain(inner) => {
3005                assert!(matches!(*inner, Statement::Select(_)));
3006            }
3007            _ => panic!("expected Explain"),
3008        }
3009    }
3010
3011    #[test]
3012    fn parse_explain_insert() {
3013        let stmt = parse_sql("EXPLAIN INSERT INTO t (a) VALUES (1)").unwrap();
3014        assert!(matches!(stmt, Statement::Explain(_)));
3015    }
3016
3017    #[test]
3018    fn reject_explain_analyze() {
3019        assert!(parse_sql("EXPLAIN ANALYZE SELECT * FROM t").is_err());
3020    }
3021
3022    #[test]
3023    fn parse_parameter_placeholder() {
3024        let stmt = parse_sql("SELECT * FROM t WHERE id = $1").unwrap();
3025        match stmt {
3026            Statement::Select(sq) => match sq.body {
3027                QueryBody::Select(sel) => match &sel.where_clause {
3028                    Some(Expr::BinaryOp { right, .. }) => {
3029                        assert!(matches!(right.as_ref(), Expr::Parameter(1)));
3030                    }
3031                    other => panic!("expected BinaryOp with Parameter, got {other:?}"),
3032                },
3033                _ => panic!("expected QueryBody::Select"),
3034            },
3035            _ => panic!("expected Select"),
3036        }
3037    }
3038
3039    #[test]
3040    fn parse_multiple_parameters() {
3041        let stmt = parse_sql("INSERT INTO t (a, b) VALUES ($1, $2)").unwrap();
3042        match stmt {
3043            Statement::Insert(ins) => {
3044                let values = match &ins.source {
3045                    InsertSource::Values(v) => v,
3046                    _ => panic!("expected Values"),
3047                };
3048                assert!(matches!(values[0][0], Expr::Parameter(1)));
3049                assert!(matches!(values[0][1], Expr::Parameter(2)));
3050            }
3051            _ => panic!("expected Insert"),
3052        }
3053    }
3054
3055    #[test]
3056    fn parse_insert_select() {
3057        let stmt =
3058            parse_sql("INSERT INTO t2 (id, name) SELECT id, name FROM t1 WHERE id > 5").unwrap();
3059        match stmt {
3060            Statement::Insert(ins) => {
3061                assert_eq!(ins.table, "t2");
3062                assert_eq!(ins.columns, vec!["id", "name"]);
3063                match &ins.source {
3064                    InsertSource::Select(sq) => match &sq.body {
3065                        QueryBody::Select(sel) => {
3066                            assert_eq!(sel.from, "t1");
3067                            assert_eq!(sel.columns.len(), 2);
3068                            assert!(sel.where_clause.is_some());
3069                        }
3070                        _ => panic!("expected QueryBody::Select"),
3071                    },
3072                    _ => panic!("expected InsertSource::Select"),
3073                }
3074            }
3075            _ => panic!("expected Insert"),
3076        }
3077    }
3078
3079    #[test]
3080    fn parse_insert_select_no_columns() {
3081        let stmt = parse_sql("INSERT INTO t2 SELECT * FROM t1").unwrap();
3082        match stmt {
3083            Statement::Insert(ins) => {
3084                assert_eq!(ins.table, "t2");
3085                assert!(ins.columns.is_empty());
3086                assert!(matches!(&ins.source, InsertSource::Select(_)));
3087            }
3088            _ => panic!("expected Insert"),
3089        }
3090    }
3091
3092    #[test]
3093    fn reject_zero_parameter() {
3094        assert!(parse_sql("SELECT $0 FROM t").is_err());
3095    }
3096
3097    #[test]
3098    fn count_params_basic() {
3099        let stmt = parse_sql("SELECT * FROM t WHERE a = $1 AND b = $3").unwrap();
3100        assert_eq!(count_params(&stmt), 3);
3101    }
3102
3103    #[test]
3104    fn count_params_none() {
3105        let stmt = parse_sql("SELECT * FROM t WHERE a = 1").unwrap();
3106        assert_eq!(count_params(&stmt), 0);
3107    }
3108
3109    #[test]
3110    fn parse_table_constraint_pk() {
3111        let stmt = parse_sql("CREATE TABLE t (a INTEGER, b TEXT, PRIMARY KEY (a))").unwrap();
3112        match stmt {
3113            Statement::CreateTable(ct) => {
3114                assert_eq!(ct.primary_key, vec!["a"]);
3115                assert!(ct.columns[0].is_primary_key);
3116                assert!(!ct.columns[0].nullable);
3117            }
3118            _ => panic!("expected CreateTable"),
3119        }
3120    }
3121}