Skip to main content

citadel_sql/
parser.rs

1//! SQL parser: converts SQL strings into our internal AST.
2
3use sqlparser::ast as sp;
4use sqlparser::dialect::GenericDialect;
5use sqlparser::parser::Parser;
6
7use crate::error::{Result, SqlError};
8use crate::types::{DataType, Value};
9
10// ── Internal AST ────────────────────────────────────────────────────
11
/// A single parsed SQL statement in the engine's internal representation,
/// as produced by [`parse_sql`].
#[derive(Debug, Clone)]
pub enum Statement {
    /// `CREATE TABLE`.
    CreateTable(CreateTableStmt),
    /// `DROP TABLE` of exactly one table (multi-table DROP is rejected).
    DropTable(DropTableStmt),
    /// `CREATE INDEX`.
    CreateIndex(CreateIndexStmt),
    /// `DROP INDEX` of exactly one index (multi-index DROP is rejected).
    DropIndex(DropIndexStmt),
    /// `ALTER TABLE` carrying a single operation.
    AlterTable(Box<AlterTableStmt>),
    /// `INSERT` from a VALUES list or from a query.
    Insert(InsertStmt),
    /// A query, optionally with CTEs and set operations.
    Select(Box<SelectQuery>),
    /// `UPDATE`.
    Update(UpdateStmt),
    /// `DELETE`.
    Delete(DeleteStmt),
    /// Start of a transaction (converted from sqlparser's `StartTransaction`).
    Begin,
    /// `COMMIT`.
    Commit,
    /// `ROLLBACK`.
    Rollback,
    /// `EXPLAIN <statement>`; `EXPLAIN ANALYZE` is rejected during conversion.
    Explain(Box<Statement>),
}
28
/// `ALTER TABLE <table> <op>` with exactly one operation.
#[derive(Debug, Clone)]
pub struct AlterTableStmt {
    /// Name of the table being altered.
    pub table: String,
    /// The single alteration to apply.
    pub op: AlterTableOp,
}

/// One `ALTER TABLE` operation.
#[derive(Debug, Clone)]
pub enum AlterTableOp {
    /// `ADD COLUMN [IF NOT EXISTS]`, optionally with an inline foreign key
    /// declared on the new column.
    AddColumn {
        column: Box<ColumnSpec>,
        foreign_key: Option<ForeignKeyDef>,
        if_not_exists: bool,
    },
    /// `DROP COLUMN [IF EXISTS] name`.
    DropColumn {
        name: String,
        if_exists: bool,
    },
    /// `RENAME COLUMN old_name TO new_name`.
    RenameColumn {
        old_name: String,
        new_name: String,
    },
    /// `RENAME TO new_name` (renames the table itself).
    RenameTable {
        new_name: String,
    },
}
54
/// `CREATE TABLE [IF NOT EXISTS]` definition.
#[derive(Debug, Clone)]
pub struct CreateTableStmt {
    /// Table name.
    pub name: String,
    /// Column definitions, in declaration order.
    pub columns: Vec<ColumnSpec>,
    /// Names of the primary-key columns (inline `PRIMARY KEY` columns are
    /// collected here by `convert_create_table`).
    pub primary_key: Vec<String>,
    pub if_not_exists: bool,
    /// Table-level CHECK constraints (column-level ones live on `ColumnSpec`).
    pub check_constraints: Vec<TableCheckConstraint>,
    /// Foreign keys, including those declared inline on a column.
    pub foreign_keys: Vec<ForeignKeyDef>,
}

/// A table-level `[CONSTRAINT name] CHECK (expr)`.
#[derive(Debug, Clone)]
pub struct TableCheckConstraint {
    /// Constraint name, if one was given.
    pub name: Option<String>,
    /// Converted check expression.
    pub expr: Expr,
    /// SQL text of the expression. NOTE(review): presumably stored in the schema
    /// and re-parsed with `parse_sql_expr`, like column-level checks — confirm.
    pub sql: String,
}

/// A foreign-key reference from `columns` to `referred_columns` of `foreign_table`.
///
/// Inline (column-level) definitions built by `convert_column_def` store all
/// three name fields ASCII-lowercased.
#[derive(Debug, Clone)]
pub struct ForeignKeyDef {
    /// Constraint name, if one was given.
    pub name: Option<String>,
    /// Referencing columns in this table.
    pub columns: Vec<String>,
    /// Referenced table.
    pub foreign_table: String,
    /// Referenced columns. NOTE(review): may be empty when `REFERENCES t` lists
    /// no columns — how that is resolved is not visible here.
    pub referred_columns: Vec<String>,
}

/// A single column definition with its inline options.
#[derive(Debug, Clone)]
pub struct ColumnSpec {
    /// Column name as written (not case-normalized by `convert_column_def`).
    pub name: String,
    pub data_type: DataType,
    /// `false` for `NOT NULL` and for `PRIMARY KEY` columns.
    pub nullable: bool,
    /// Declared inline as `PRIMARY KEY`.
    pub is_primary_key: bool,
    /// Converted DEFAULT expression.
    pub default_expr: Option<Expr>,
    /// SQL text of the DEFAULT expression (sqlparser's rendering), kept so the
    /// schema can store it and re-parse it with `parse_sql_expr`.
    pub default_sql: Option<String>,
    /// Converted column-level CHECK expression (subqueries are rejected).
    pub check_expr: Option<Expr>,
    /// SQL text of the CHECK expression (sqlparser's rendering).
    pub check_sql: Option<String>,
    /// Name of the CHECK constraint, if one was given.
    pub check_name: Option<String>,
}
92
/// `DROP TABLE [IF EXISTS] name`.
#[derive(Debug, Clone)]
pub struct DropTableStmt {
    pub name: String,
    pub if_exists: bool,
}

/// `CREATE [UNIQUE] INDEX [IF NOT EXISTS] index_name ON table_name (columns...)`.
#[derive(Debug, Clone)]
pub struct CreateIndexStmt {
    pub index_name: String,
    pub table_name: String,
    /// Indexed column names.
    pub columns: Vec<String>,
    /// `CREATE UNIQUE INDEX`.
    pub unique: bool,
    pub if_not_exists: bool,
}

/// `DROP INDEX [IF EXISTS] index_name`.
#[derive(Debug, Clone)]
pub struct DropIndexStmt {
    pub index_name: String,
    pub if_exists: bool,
}
113
/// Where the rows of an `INSERT` come from.
#[derive(Debug, Clone)]
pub enum InsertSource {
    /// `VALUES (...), (...)`: one inner `Vec` per row.
    Values(Vec<Vec<Expr>>),
    /// `INSERT ... SELECT ...`.
    Select(Box<SelectQuery>),
}

/// `INSERT INTO table [(columns)] <source>`.
#[derive(Debug, Clone)]
pub struct InsertStmt {
    /// Target table.
    pub table: String,
    /// Explicit target column list. NOTE(review): presumably empty when the
    /// column list is omitted — confirm in `convert_insert`.
    pub columns: Vec<String>,
    pub source: InsertSource,
}
126
/// A table reference with an optional alias.
#[derive(Debug, Clone)]
pub struct TableRef {
    pub name: String,
    pub alias: Option<String>,
}

/// Supported join kinds.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum JoinType {
    Inner,
    Cross,
    Left,
    Right,
}

/// One `JOIN` in a SELECT's FROM clause.
#[derive(Debug, Clone)]
pub struct JoinClause {
    pub join_type: JoinType,
    /// The joined table.
    pub table: TableRef,
    /// `ON` condition; `None` when there is none (presumably CROSS joins).
    pub on_clause: Option<Expr>,
}
147
/// A single SELECT block (set operations are represented by [`CompoundSelect`]).
#[derive(Debug, Clone)]
pub struct SelectStmt {
    /// Projection list.
    pub columns: Vec<SelectColumn>,
    /// Table named in the FROM clause. NOTE(review): how a FROM-less SELECT is
    /// represented is not visible here — confirm in the converter.
    pub from: String,
    /// Alias of the FROM table, if given.
    pub from_alias: Option<String>,
    /// Joined tables, in order.
    pub joins: Vec<JoinClause>,
    /// `SELECT DISTINCT`.
    pub distinct: bool,
    pub where_clause: Option<Expr>,
    pub order_by: Vec<OrderByItem>,
    /// LIMIT expression; may contain placeholders until `bind_params` runs.
    pub limit: Option<Expr>,
    /// OFFSET expression; may contain placeholders until `bind_params` runs.
    pub offset: Option<Expr>,
    pub group_by: Vec<Expr>,
    pub having: Option<Expr>,
}

/// Set operator joining two query bodies.
#[derive(Debug, Clone)]
pub enum SetOp {
    Union,
    Intersect,
    Except,
}

/// `left <op> [ALL] right` plus the ORDER BY / LIMIT / OFFSET written after the
/// compound. These trailing clauses are stored separately from the operands' own.
#[derive(Debug, Clone)]
pub struct CompoundSelect {
    pub op: SetOp,
    /// `UNION ALL` / `INTERSECT ALL` / `EXCEPT ALL`.
    pub all: bool,
    pub left: Box<QueryBody>,
    pub right: Box<QueryBody>,
    pub order_by: Vec<OrderByItem>,
    pub limit: Option<Expr>,
    pub offset: Option<Expr>,
}

/// Either a plain SELECT or a set-operation tree of them.
#[derive(Debug, Clone)]
pub enum QueryBody {
    Select(Box<SelectStmt>),
    Compound(CompoundSelect),
}

/// One `name [(column_aliases)] AS (body)` entry of a WITH clause.
#[derive(Debug, Clone)]
pub struct CteDefinition {
    pub name: String,
    /// Optional column rename list; presumably empty when omitted.
    pub column_aliases: Vec<String>,
    pub body: QueryBody,
}

/// A complete query: optional WITH clause followed by the main body.
#[derive(Debug, Clone)]
pub struct SelectQuery {
    /// CTEs in declaration order.
    pub ctes: Vec<CteDefinition>,
    /// `WITH RECURSIVE`.
    pub recursive: bool,
    pub body: QueryBody,
}
200
/// `UPDATE table SET col = expr, ... [WHERE ...]`.
#[derive(Debug, Clone)]
pub struct UpdateStmt {
    pub table: String,
    /// `(column, new value)` pairs from the SET list, in order.
    pub assignments: Vec<(String, Expr)>,
    pub where_clause: Option<Expr>,
}

/// `DELETE FROM table [WHERE ...]`.
#[derive(Debug, Clone)]
pub struct DeleteStmt {
    pub table: String,
    pub where_clause: Option<Expr>,
}

/// One item of a SELECT projection list.
#[derive(Debug, Clone)]
pub enum SelectColumn {
    /// Wildcard projection (`*`).
    AllColumns,
    /// An expression with an optional `AS alias`.
    Expr { expr: Expr, alias: Option<String> },
}

/// One ORDER BY key.
#[derive(Debug, Clone)]
pub struct OrderByItem {
    pub expr: Expr,
    /// `DESC` when true.
    pub descending: bool,
    /// Explicit NULLS FIRST / NULLS LAST. NOTE(review): presumably `None` when
    /// unspecified — the conversion is not visible here.
    pub nulls_first: Option<bool>,
}
226
/// Scalar expression in the internal AST.
///
/// Placeholders appear as [`Expr::Parameter`] until [`bind_params`] replaces
/// them with literals.
#[derive(Debug, Clone)]
pub enum Expr {
    /// Constant value.
    Literal(Value),
    /// Unqualified column reference.
    Column(String),
    /// `table.column` reference.
    QualifiedColumn {
        table: String,
        column: String,
    },
    /// `left op right`.
    BinaryOp {
        left: Box<Expr>,
        op: BinOp,
        right: Box<Expr>,
    },
    /// Prefix operator applied to `expr`.
    UnaryOp {
        op: UnaryOp,
        expr: Box<Expr>,
    },
    /// `expr IS NULL`.
    IsNull(Box<Expr>),
    /// `expr IS NOT NULL`.
    IsNotNull(Box<Expr>),
    /// Named function call with positional arguments.
    Function {
        name: String,
        args: Vec<Expr>,
    },
    /// Presumably `COUNT(*)`; confirm in `convert_expr`.
    CountStar,
    /// `expr [NOT] IN (SELECT ...)`.
    InSubquery {
        expr: Box<Expr>,
        subquery: Box<SelectStmt>,
        negated: bool,
    },
    /// `expr [NOT] IN (e1, e2, ...)` with arbitrary element expressions.
    InList {
        expr: Box<Expr>,
        list: Vec<Expr>,
        negated: bool,
    },
    /// `[NOT] EXISTS (SELECT ...)`.
    Exists {
        subquery: Box<SelectStmt>,
        negated: bool,
    },
    /// `(SELECT ...)` used as a single value.
    ScalarSubquery(Box<SelectStmt>),
    /// Membership test against a precomputed value set. NOTE(review): not built by
    /// any code visible here — presumably derived from a literal-only `InList`,
    /// with `has_null` flagging a NULL element; confirm with the producer.
    InSet {
        expr: Box<Expr>,
        values: std::collections::HashSet<Value>,
        has_null: bool,
        negated: bool,
    },
    /// `expr [NOT] BETWEEN low AND high`.
    Between {
        expr: Box<Expr>,
        low: Box<Expr>,
        high: Box<Expr>,
        negated: bool,
    },
    /// `expr [NOT] LIKE pattern [ESCAPE escape]`.
    Like {
        expr: Box<Expr>,
        pattern: Box<Expr>,
        escape: Option<Box<Expr>>,
        negated: bool,
    },
    /// `CASE [operand] WHEN .. THEN .. [ELSE else_result] END`;
    /// `conditions` holds the (WHEN, THEN) pairs in order.
    Case {
        operand: Option<Box<Expr>>,
        conditions: Vec<(Expr, Expr)>,
        else_result: Option<Box<Expr>>,
    },
    /// `COALESCE(args...)`.
    Coalesce(Vec<Expr>),
    /// `CAST(expr AS data_type)`.
    Cast {
        expr: Box<Expr>,
        data_type: DataType,
    },
    /// Positional placeholder with a 1-based index: `bind_params` substitutes
    /// `params[n - 1]` and rejects 0 or an index beyond the supplied values.
    Parameter(usize),
}
296
/// Binary operators usable in [`Expr::BinaryOp`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BinOp {
    // Arithmetic.
    Add,
    Sub,
    Mul,
    Div,
    Mod,
    // Comparison.
    Eq,
    NotEq,
    Lt,
    Gt,
    LtEq,
    GtEq,
    // Logical.
    And,
    Or,
    /// String concatenation.
    Concat,
}

/// Prefix operators usable in [`Expr::UnaryOp`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum UnaryOp {
    /// Arithmetic negation.
    Neg,
    /// Logical NOT.
    Not,
}
320
321// ── Expression utilities ────────────────────────────────────────────
322
323pub fn has_subquery(expr: &Expr) -> bool {
324    match expr {
325        Expr::InSubquery { .. } | Expr::Exists { .. } | Expr::ScalarSubquery(_) => true,
326        Expr::BinaryOp { left, right, .. } => has_subquery(left) || has_subquery(right),
327        Expr::UnaryOp { expr, .. } => has_subquery(expr),
328        Expr::IsNull(e) | Expr::IsNotNull(e) => has_subquery(e),
329        Expr::InList { expr, list, .. } => has_subquery(expr) || list.iter().any(has_subquery),
330        Expr::InSet { expr, .. } => has_subquery(expr),
331        Expr::Between {
332            expr, low, high, ..
333        } => has_subquery(expr) || has_subquery(low) || has_subquery(high),
334        Expr::Like {
335            expr,
336            pattern,
337            escape,
338            ..
339        } => {
340            has_subquery(expr)
341                || has_subquery(pattern)
342                || escape.as_ref().is_some_and(|e| has_subquery(e))
343        }
344        Expr::Case {
345            operand,
346            conditions,
347            else_result,
348        } => {
349            operand.as_ref().is_some_and(|e| has_subquery(e))
350                || conditions
351                    .iter()
352                    .any(|(c, r)| has_subquery(c) || has_subquery(r))
353                || else_result.as_ref().is_some_and(|e| has_subquery(e))
354        }
355        Expr::Coalesce(args) | Expr::Function { args, .. } => args.iter().any(has_subquery),
356        Expr::Cast { expr, .. } => has_subquery(expr),
357        _ => false,
358    }
359}
360
361/// Parse a SQL expression string back into an internal Expr.
362/// Used for deserializing stored DEFAULT/CHECK expressions from schema.
363pub fn parse_sql_expr(sql: &str) -> Result<Expr> {
364    let dialect = GenericDialect {};
365    let mut parser = Parser::new(&dialect)
366        .try_with_sql(sql)
367        .map_err(|e| SqlError::Parse(e.to_string()))?;
368    let sp_expr = parser
369        .parse_expr()
370        .map_err(|e| SqlError::Parse(e.to_string()))?;
371    convert_expr(&sp_expr)
372}
373
374// ── Parser entry point ──────────────────────────────────────────────
375
376pub fn parse_sql(sql: &str) -> Result<Statement> {
377    let dialect = GenericDialect {};
378    let stmts = Parser::parse_sql(&dialect, sql).map_err(|e| SqlError::Parse(e.to_string()))?;
379
380    if stmts.is_empty() {
381        return Err(SqlError::Parse("empty SQL".into()));
382    }
383    if stmts.len() > 1 {
384        return Err(SqlError::Unsupported("multiple statements".into()));
385    }
386
387    convert_statement(stmts.into_iter().next().unwrap())
388}
389
390// ── Parameter utilities ─────────────────────────────────────────────
391
392/// Returns the number of distinct parameters in a statement (max $N found).
393pub fn count_params(stmt: &Statement) -> usize {
394    let mut max_idx = 0usize;
395    visit_exprs_stmt(stmt, &mut |e| {
396        if let Expr::Parameter(n) = e {
397            max_idx = max_idx.max(*n);
398        }
399    });
400    max_idx
401}
402
403/// Replace all `Expr::Parameter(n)` with `Expr::Literal(params[n-1])`.
404pub fn bind_params(
405    stmt: &Statement,
406    params: &[crate::types::Value],
407) -> crate::error::Result<Statement> {
408    bind_stmt(stmt, params)
409}
410
411fn bind_stmt(stmt: &Statement, params: &[crate::types::Value]) -> crate::error::Result<Statement> {
412    match stmt {
413        Statement::Select(sq) => Ok(Statement::Select(Box::new(bind_select_query(sq, params)?))),
414        Statement::Insert(ins) => {
415            let source = match &ins.source {
416                InsertSource::Values(rows) => {
417                    let bound = rows
418                        .iter()
419                        .map(|row| {
420                            row.iter()
421                                .map(|e| bind_expr(e, params))
422                                .collect::<crate::error::Result<Vec<_>>>()
423                        })
424                        .collect::<crate::error::Result<Vec<_>>>()?;
425                    InsertSource::Values(bound)
426                }
427                InsertSource::Select(sq) => {
428                    InsertSource::Select(Box::new(bind_select_query(sq, params)?))
429                }
430            };
431            Ok(Statement::Insert(InsertStmt {
432                table: ins.table.clone(),
433                columns: ins.columns.clone(),
434                source,
435            }))
436        }
437        Statement::Update(upd) => {
438            let assignments = upd
439                .assignments
440                .iter()
441                .map(|(col, e)| Ok((col.clone(), bind_expr(e, params)?)))
442                .collect::<crate::error::Result<Vec<_>>>()?;
443            let where_clause = upd
444                .where_clause
445                .as_ref()
446                .map(|e| bind_expr(e, params))
447                .transpose()?;
448            Ok(Statement::Update(UpdateStmt {
449                table: upd.table.clone(),
450                assignments,
451                where_clause,
452            }))
453        }
454        Statement::Delete(del) => {
455            let where_clause = del
456                .where_clause
457                .as_ref()
458                .map(|e| bind_expr(e, params))
459                .transpose()?;
460            Ok(Statement::Delete(DeleteStmt {
461                table: del.table.clone(),
462                where_clause,
463            }))
464        }
465        Statement::Explain(inner) => Ok(Statement::Explain(Box::new(bind_stmt(inner, params)?))),
466        other => Ok(other.clone()),
467    }
468}
469
470fn bind_select(
471    sel: &SelectStmt,
472    params: &[crate::types::Value],
473) -> crate::error::Result<SelectStmt> {
474    let columns = sel
475        .columns
476        .iter()
477        .map(|c| match c {
478            SelectColumn::AllColumns => Ok(SelectColumn::AllColumns),
479            SelectColumn::Expr { expr, alias } => Ok(SelectColumn::Expr {
480                expr: bind_expr(expr, params)?,
481                alias: alias.clone(),
482            }),
483        })
484        .collect::<crate::error::Result<Vec<_>>>()?;
485    let joins = sel
486        .joins
487        .iter()
488        .map(|j| {
489            let on_clause = j
490                .on_clause
491                .as_ref()
492                .map(|e| bind_expr(e, params))
493                .transpose()?;
494            Ok(JoinClause {
495                join_type: j.join_type,
496                table: j.table.clone(),
497                on_clause,
498            })
499        })
500        .collect::<crate::error::Result<Vec<_>>>()?;
501    let where_clause = sel
502        .where_clause
503        .as_ref()
504        .map(|e| bind_expr(e, params))
505        .transpose()?;
506    let order_by = sel
507        .order_by
508        .iter()
509        .map(|o| {
510            Ok(OrderByItem {
511                expr: bind_expr(&o.expr, params)?,
512                descending: o.descending,
513                nulls_first: o.nulls_first,
514            })
515        })
516        .collect::<crate::error::Result<Vec<_>>>()?;
517    let limit = sel
518        .limit
519        .as_ref()
520        .map(|e| bind_expr(e, params))
521        .transpose()?;
522    let offset = sel
523        .offset
524        .as_ref()
525        .map(|e| bind_expr(e, params))
526        .transpose()?;
527    let group_by = sel
528        .group_by
529        .iter()
530        .map(|e| bind_expr(e, params))
531        .collect::<crate::error::Result<Vec<_>>>()?;
532    let having = sel
533        .having
534        .as_ref()
535        .map(|e| bind_expr(e, params))
536        .transpose()?;
537
538    Ok(SelectStmt {
539        columns,
540        from: sel.from.clone(),
541        from_alias: sel.from_alias.clone(),
542        joins,
543        distinct: sel.distinct,
544        where_clause,
545        order_by,
546        limit,
547        offset,
548        group_by,
549        having,
550    })
551}
552
553fn bind_query_body(
554    body: &QueryBody,
555    params: &[crate::types::Value],
556) -> crate::error::Result<QueryBody> {
557    match body {
558        QueryBody::Select(sel) => Ok(QueryBody::Select(Box::new(bind_select(sel, params)?))),
559        QueryBody::Compound(comp) => {
560            let order_by = comp
561                .order_by
562                .iter()
563                .map(|o| {
564                    Ok(OrderByItem {
565                        expr: bind_expr(&o.expr, params)?,
566                        descending: o.descending,
567                        nulls_first: o.nulls_first,
568                    })
569                })
570                .collect::<crate::error::Result<Vec<_>>>()?;
571            let limit = comp
572                .limit
573                .as_ref()
574                .map(|e| bind_expr(e, params))
575                .transpose()?;
576            let offset = comp
577                .offset
578                .as_ref()
579                .map(|e| bind_expr(e, params))
580                .transpose()?;
581            Ok(QueryBody::Compound(CompoundSelect {
582                op: comp.op.clone(),
583                all: comp.all,
584                left: Box::new(bind_query_body(&comp.left, params)?),
585                right: Box::new(bind_query_body(&comp.right, params)?),
586                order_by,
587                limit,
588                offset,
589            }))
590        }
591    }
592}
593
594fn bind_select_query(
595    sq: &SelectQuery,
596    params: &[crate::types::Value],
597) -> crate::error::Result<SelectQuery> {
598    let ctes = sq
599        .ctes
600        .iter()
601        .map(|cte| {
602            Ok(CteDefinition {
603                name: cte.name.clone(),
604                column_aliases: cte.column_aliases.clone(),
605                body: bind_query_body(&cte.body, params)?,
606            })
607        })
608        .collect::<crate::error::Result<Vec<_>>>()?;
609    let body = bind_query_body(&sq.body, params)?;
610    Ok(SelectQuery {
611        ctes,
612        recursive: sq.recursive,
613        body,
614    })
615}
616
617fn bind_expr(expr: &Expr, params: &[crate::types::Value]) -> crate::error::Result<Expr> {
618    match expr {
619        Expr::Parameter(n) => {
620            if *n == 0 || *n > params.len() {
621                return Err(SqlError::ParameterCountMismatch {
622                    expected: *n,
623                    got: params.len(),
624                });
625            }
626            Ok(Expr::Literal(params[*n - 1].clone()))
627        }
628        Expr::Literal(_) | Expr::Column(_) | Expr::QualifiedColumn { .. } | Expr::CountStar => {
629            Ok(expr.clone())
630        }
631        Expr::BinaryOp { left, op, right } => Ok(Expr::BinaryOp {
632            left: Box::new(bind_expr(left, params)?),
633            op: *op,
634            right: Box::new(bind_expr(right, params)?),
635        }),
636        Expr::UnaryOp { op, expr: e } => Ok(Expr::UnaryOp {
637            op: *op,
638            expr: Box::new(bind_expr(e, params)?),
639        }),
640        Expr::IsNull(e) => Ok(Expr::IsNull(Box::new(bind_expr(e, params)?))),
641        Expr::IsNotNull(e) => Ok(Expr::IsNotNull(Box::new(bind_expr(e, params)?))),
642        Expr::Function { name, args } => {
643            let args = args
644                .iter()
645                .map(|a| bind_expr(a, params))
646                .collect::<crate::error::Result<Vec<_>>>()?;
647            Ok(Expr::Function {
648                name: name.clone(),
649                args,
650            })
651        }
652        Expr::InSubquery {
653            expr: e,
654            subquery,
655            negated,
656        } => Ok(Expr::InSubquery {
657            expr: Box::new(bind_expr(e, params)?),
658            subquery: Box::new(bind_select(subquery, params)?),
659            negated: *negated,
660        }),
661        Expr::InList {
662            expr: e,
663            list,
664            negated,
665        } => {
666            let list = list
667                .iter()
668                .map(|l| bind_expr(l, params))
669                .collect::<crate::error::Result<Vec<_>>>()?;
670            Ok(Expr::InList {
671                expr: Box::new(bind_expr(e, params)?),
672                list,
673                negated: *negated,
674            })
675        }
676        Expr::Exists { subquery, negated } => Ok(Expr::Exists {
677            subquery: Box::new(bind_select(subquery, params)?),
678            negated: *negated,
679        }),
680        Expr::ScalarSubquery(sq) => Ok(Expr::ScalarSubquery(Box::new(bind_select(sq, params)?))),
681        Expr::InSet {
682            expr: e,
683            values,
684            has_null,
685            negated,
686        } => Ok(Expr::InSet {
687            expr: Box::new(bind_expr(e, params)?),
688            values: values.clone(),
689            has_null: *has_null,
690            negated: *negated,
691        }),
692        Expr::Between {
693            expr: e,
694            low,
695            high,
696            negated,
697        } => Ok(Expr::Between {
698            expr: Box::new(bind_expr(e, params)?),
699            low: Box::new(bind_expr(low, params)?),
700            high: Box::new(bind_expr(high, params)?),
701            negated: *negated,
702        }),
703        Expr::Like {
704            expr: e,
705            pattern,
706            escape,
707            negated,
708        } => Ok(Expr::Like {
709            expr: Box::new(bind_expr(e, params)?),
710            pattern: Box::new(bind_expr(pattern, params)?),
711            escape: escape
712                .as_ref()
713                .map(|esc| bind_expr(esc, params).map(Box::new))
714                .transpose()?,
715            negated: *negated,
716        }),
717        Expr::Case {
718            operand,
719            conditions,
720            else_result,
721        } => {
722            let operand = operand
723                .as_ref()
724                .map(|e| bind_expr(e, params).map(Box::new))
725                .transpose()?;
726            let conditions = conditions
727                .iter()
728                .map(|(cond, then)| Ok((bind_expr(cond, params)?, bind_expr(then, params)?)))
729                .collect::<crate::error::Result<Vec<_>>>()?;
730            let else_result = else_result
731                .as_ref()
732                .map(|e| bind_expr(e, params).map(Box::new))
733                .transpose()?;
734            Ok(Expr::Case {
735                operand,
736                conditions,
737                else_result,
738            })
739        }
740        Expr::Coalesce(args) => {
741            let args = args
742                .iter()
743                .map(|a| bind_expr(a, params))
744                .collect::<crate::error::Result<Vec<_>>>()?;
745            Ok(Expr::Coalesce(args))
746        }
747        Expr::Cast { expr: e, data_type } => Ok(Expr::Cast {
748            expr: Box::new(bind_expr(e, params)?),
749            data_type: *data_type,
750        }),
751    }
752}
753
754fn visit_exprs_stmt(stmt: &Statement, visitor: &mut impl FnMut(&Expr)) {
755    match stmt {
756        Statement::Select(sq) => {
757            for cte in &sq.ctes {
758                visit_exprs_query_body(&cte.body, visitor);
759            }
760            visit_exprs_query_body(&sq.body, visitor);
761        }
762        Statement::Insert(ins) => match &ins.source {
763            InsertSource::Values(rows) => {
764                for row in rows {
765                    for e in row {
766                        visit_expr(e, visitor);
767                    }
768                }
769            }
770            InsertSource::Select(sq) => {
771                for cte in &sq.ctes {
772                    visit_exprs_query_body(&cte.body, visitor);
773                }
774                visit_exprs_query_body(&sq.body, visitor);
775            }
776        },
777        Statement::Update(upd) => {
778            for (_, e) in &upd.assignments {
779                visit_expr(e, visitor);
780            }
781            if let Some(w) = &upd.where_clause {
782                visit_expr(w, visitor);
783            }
784        }
785        Statement::Delete(del) => {
786            if let Some(w) = &del.where_clause {
787                visit_expr(w, visitor);
788            }
789        }
790        Statement::Explain(inner) => visit_exprs_stmt(inner, visitor),
791        _ => {}
792    }
793}
794
795fn visit_exprs_query_body(body: &QueryBody, visitor: &mut impl FnMut(&Expr)) {
796    match body {
797        QueryBody::Select(sel) => visit_exprs_select(sel, visitor),
798        QueryBody::Compound(comp) => {
799            visit_exprs_query_body(&comp.left, visitor);
800            visit_exprs_query_body(&comp.right, visitor);
801            for o in &comp.order_by {
802                visit_expr(&o.expr, visitor);
803            }
804            if let Some(l) = &comp.limit {
805                visit_expr(l, visitor);
806            }
807            if let Some(o) = &comp.offset {
808                visit_expr(o, visitor);
809            }
810        }
811    }
812}
813
814fn visit_exprs_select(sel: &SelectStmt, visitor: &mut impl FnMut(&Expr)) {
815    for col in &sel.columns {
816        if let SelectColumn::Expr { expr, .. } = col {
817            visit_expr(expr, visitor);
818        }
819    }
820    for j in &sel.joins {
821        if let Some(on) = &j.on_clause {
822            visit_expr(on, visitor);
823        }
824    }
825    if let Some(w) = &sel.where_clause {
826        visit_expr(w, visitor);
827    }
828    for o in &sel.order_by {
829        visit_expr(&o.expr, visitor);
830    }
831    if let Some(l) = &sel.limit {
832        visit_expr(l, visitor);
833    }
834    if let Some(o) = &sel.offset {
835        visit_expr(o, visitor);
836    }
837    for g in &sel.group_by {
838        visit_expr(g, visitor);
839    }
840    if let Some(h) = &sel.having {
841        visit_expr(h, visitor);
842    }
843}
844
845fn visit_expr(expr: &Expr, visitor: &mut impl FnMut(&Expr)) {
846    visitor(expr);
847    match expr {
848        Expr::BinaryOp { left, right, .. } => {
849            visit_expr(left, visitor);
850            visit_expr(right, visitor);
851        }
852        Expr::UnaryOp { expr: e, .. } | Expr::IsNull(e) | Expr::IsNotNull(e) => {
853            visit_expr(e, visitor);
854        }
855        Expr::Function { args, .. } | Expr::Coalesce(args) => {
856            for a in args {
857                visit_expr(a, visitor);
858            }
859        }
860        Expr::InSubquery {
861            expr: e, subquery, ..
862        } => {
863            visit_expr(e, visitor);
864            visit_exprs_select(subquery, visitor);
865        }
866        Expr::InList { expr: e, list, .. } => {
867            visit_expr(e, visitor);
868            for l in list {
869                visit_expr(l, visitor);
870            }
871        }
872        Expr::Exists { subquery, .. } => visit_exprs_select(subquery, visitor),
873        Expr::ScalarSubquery(sq) => visit_exprs_select(sq, visitor),
874        Expr::InSet { expr: e, .. } => visit_expr(e, visitor),
875        Expr::Between {
876            expr: e, low, high, ..
877        } => {
878            visit_expr(e, visitor);
879            visit_expr(low, visitor);
880            visit_expr(high, visitor);
881        }
882        Expr::Like {
883            expr: e,
884            pattern,
885            escape,
886            ..
887        } => {
888            visit_expr(e, visitor);
889            visit_expr(pattern, visitor);
890            if let Some(esc) = escape {
891                visit_expr(esc, visitor);
892            }
893        }
894        Expr::Case {
895            operand,
896            conditions,
897            else_result,
898        } => {
899            if let Some(op) = operand {
900                visit_expr(op, visitor);
901            }
902            for (cond, then) in conditions {
903                visit_expr(cond, visitor);
904                visit_expr(then, visitor);
905            }
906            if let Some(el) = else_result {
907                visit_expr(el, visitor);
908            }
909        }
910        Expr::Cast { expr: e, .. } => visit_expr(e, visitor),
911        Expr::Literal(_)
912        | Expr::Column(_)
913        | Expr::QualifiedColumn { .. }
914        | Expr::CountStar
915        | Expr::Parameter(_) => {}
916    }
917}
918
919// ── Statement conversion ────────────────────────────────────────────
920
921fn convert_statement(stmt: sp::Statement) -> Result<Statement> {
922    match stmt {
923        sp::Statement::CreateTable(ct) => convert_create_table(ct),
924        sp::Statement::CreateIndex(ci) => convert_create_index(ci),
925        sp::Statement::Drop {
926            object_type: sp::ObjectType::Table,
927            if_exists,
928            names,
929            ..
930        } => {
931            if names.len() != 1 {
932                return Err(SqlError::Unsupported("multi-table DROP".into()));
933            }
934            Ok(Statement::DropTable(DropTableStmt {
935                name: object_name_to_string(&names[0]),
936                if_exists,
937            }))
938        }
939        sp::Statement::Drop {
940            object_type: sp::ObjectType::Index,
941            if_exists,
942            names,
943            ..
944        } => {
945            if names.len() != 1 {
946                return Err(SqlError::Unsupported("multi-index DROP".into()));
947            }
948            Ok(Statement::DropIndex(DropIndexStmt {
949                index_name: object_name_to_string(&names[0]),
950                if_exists,
951            }))
952        }
953        sp::Statement::AlterTable(at) => convert_alter_table(at),
954        sp::Statement::Insert(insert) => convert_insert(insert),
955        sp::Statement::Query(query) => convert_query(*query),
956        sp::Statement::Update(update) => convert_update(update),
957        sp::Statement::Delete(delete) => convert_delete(delete),
958        sp::Statement::StartTransaction { .. } => Ok(Statement::Begin),
959        sp::Statement::Commit { .. } => Ok(Statement::Commit),
960        sp::Statement::Rollback { .. } => Ok(Statement::Rollback),
961        sp::Statement::Explain {
962            statement, analyze, ..
963        } => {
964            if analyze {
965                return Err(SqlError::Unsupported("EXPLAIN ANALYZE".into()));
966            }
967            let inner = convert_statement(*statement)?;
968            Ok(Statement::Explain(Box::new(inner)))
969        }
970        _ => Err(SqlError::Unsupported(format!("statement type: {}", stmt))),
971    }
972}
973
974/// Parse column options (NOT NULL, DEFAULT, CHECK, FK) from a sqlparser ColumnDef.
975/// Returns (ColumnSpec, Option<ForeignKeyDef>, was_inline_pk).
976fn convert_column_def(
977    col_def: &sp::ColumnDef,
978) -> Result<(ColumnSpec, Option<ForeignKeyDef>, bool)> {
979    let col_name = col_def.name.value.clone();
980    let data_type = convert_data_type(&col_def.data_type)?;
981    let mut nullable = true;
982    let mut is_primary_key = false;
983    let mut default_expr = None;
984    let mut default_sql = None;
985    let mut check_expr = None;
986    let mut check_sql = None;
987    let mut check_name = None;
988    let mut fk_def = None;
989
990    for opt in &col_def.options {
991        match &opt.option {
992            sp::ColumnOption::NotNull => nullable = false,
993            sp::ColumnOption::Null => nullable = true,
994            sp::ColumnOption::PrimaryKey(_) => {
995                is_primary_key = true;
996                nullable = false;
997            }
998            sp::ColumnOption::Default(expr) => {
999                default_sql = Some(expr.to_string());
1000                default_expr = Some(convert_expr(expr)?);
1001            }
1002            sp::ColumnOption::Check(check) => {
1003                check_sql = Some(check.expr.to_string());
1004                let converted = convert_expr(&check.expr)?;
1005                if has_subquery(&converted) {
1006                    return Err(SqlError::Unsupported("subquery in CHECK constraint".into()));
1007                }
1008                check_expr = Some(converted);
1009                check_name = check.name.as_ref().map(|n| n.value.clone());
1010            }
1011            sp::ColumnOption::ForeignKey(fk) => {
1012                convert_fk_actions(&fk.on_delete, &fk.on_update)?;
1013                let ftable = object_name_to_string(&fk.foreign_table).to_ascii_lowercase();
1014                let referred: Vec<String> = fk
1015                    .referred_columns
1016                    .iter()
1017                    .map(|i| i.value.to_ascii_lowercase())
1018                    .collect();
1019                fk_def = Some(ForeignKeyDef {
1020                    name: fk.name.as_ref().map(|n| n.value.clone()),
1021                    columns: vec![col_name.to_ascii_lowercase()],
1022                    foreign_table: ftable,
1023                    referred_columns: referred,
1024                });
1025            }
1026            _ => {}
1027        }
1028    }
1029
1030    let spec = ColumnSpec {
1031        name: col_name,
1032        data_type,
1033        nullable,
1034        is_primary_key,
1035        default_expr,
1036        default_sql,
1037        check_expr,
1038        check_sql,
1039        check_name,
1040    };
1041    Ok((spec, fk_def, is_primary_key))
1042}
1043
1044fn convert_create_table(ct: sp::CreateTable) -> Result<Statement> {
1045    let name = object_name_to_string(&ct.name);
1046    let if_not_exists = ct.if_not_exists;
1047
1048    let mut columns = Vec::new();
1049    let mut inline_pk: Vec<String> = Vec::new();
1050    let mut foreign_keys: Vec<ForeignKeyDef> = Vec::new();
1051
1052    for col_def in &ct.columns {
1053        let (spec, fk_def, was_pk) = convert_column_def(col_def)?;
1054        if was_pk {
1055            inline_pk.push(spec.name.clone());
1056        }
1057        if let Some(fk) = fk_def {
1058            foreign_keys.push(fk);
1059        }
1060        columns.push(spec);
1061    }
1062
1063    // Check table-level constraints
1064    let mut check_constraints: Vec<TableCheckConstraint> = Vec::new();
1065
1066    for constraint in &ct.constraints {
1067        match constraint {
1068            sp::TableConstraint::PrimaryKey(pk_constraint) => {
1069                for idx_col in &pk_constraint.columns {
1070                    let col_name = match &idx_col.column.expr {
1071                        sp::Expr::Identifier(ident) => ident.value.clone(),
1072                        _ => continue,
1073                    };
1074                    if !inline_pk.contains(&col_name) {
1075                        inline_pk.push(col_name.clone());
1076                    }
1077                    if let Some(col) = columns.iter_mut().find(|c| c.name == col_name) {
1078                        col.nullable = false;
1079                        col.is_primary_key = true;
1080                    }
1081                }
1082            }
1083            sp::TableConstraint::Check(check) => {
1084                let sql = check.expr.to_string();
1085                let converted = convert_expr(&check.expr)?;
1086                if has_subquery(&converted) {
1087                    return Err(SqlError::Unsupported("subquery in CHECK constraint".into()));
1088                }
1089                check_constraints.push(TableCheckConstraint {
1090                    name: check.name.as_ref().map(|n| n.value.clone()),
1091                    expr: converted,
1092                    sql,
1093                });
1094            }
1095            sp::TableConstraint::ForeignKey(fk) => {
1096                convert_fk_actions(&fk.on_delete, &fk.on_update)?;
1097                let cols: Vec<String> = fk
1098                    .columns
1099                    .iter()
1100                    .map(|i| i.value.to_ascii_lowercase())
1101                    .collect();
1102                let ftable = object_name_to_string(&fk.foreign_table).to_ascii_lowercase();
1103                let referred: Vec<String> = fk
1104                    .referred_columns
1105                    .iter()
1106                    .map(|i| i.value.to_ascii_lowercase())
1107                    .collect();
1108                foreign_keys.push(ForeignKeyDef {
1109                    name: fk.name.as_ref().map(|n| n.value.clone()),
1110                    columns: cols,
1111                    foreign_table: ftable,
1112                    referred_columns: referred,
1113                });
1114            }
1115            _ => {}
1116        }
1117    }
1118
1119    Ok(Statement::CreateTable(CreateTableStmt {
1120        name,
1121        columns,
1122        primary_key: inline_pk,
1123        if_not_exists,
1124        check_constraints,
1125        foreign_keys,
1126    }))
1127}
1128
1129fn convert_alter_table(at: sp::AlterTable) -> Result<Statement> {
1130    let table = object_name_to_string(&at.name);
1131    if at.operations.len() != 1 {
1132        return Err(SqlError::Unsupported(
1133            "ALTER TABLE with multiple operations".into(),
1134        ));
1135    }
1136    let op = match at.operations.into_iter().next().unwrap() {
1137        sp::AlterTableOperation::AddColumn {
1138            column_def,
1139            if_not_exists,
1140            ..
1141        } => {
1142            let (spec, fk, _was_pk) = convert_column_def(&column_def)?;
1143            AlterTableOp::AddColumn {
1144                column: Box::new(spec),
1145                foreign_key: fk,
1146                if_not_exists,
1147            }
1148        }
1149        sp::AlterTableOperation::DropColumn {
1150            column_names,
1151            if_exists,
1152            ..
1153        } => {
1154            if column_names.len() != 1 {
1155                return Err(SqlError::Unsupported(
1156                    "DROP COLUMN with multiple columns".into(),
1157                ));
1158            }
1159            AlterTableOp::DropColumn {
1160                name: column_names.into_iter().next().unwrap().value,
1161                if_exists,
1162            }
1163        }
1164        sp::AlterTableOperation::RenameColumn {
1165            old_column_name,
1166            new_column_name,
1167        } => AlterTableOp::RenameColumn {
1168            old_name: old_column_name.value,
1169            new_name: new_column_name.value,
1170        },
1171        sp::AlterTableOperation::RenameTable { table_name } => {
1172            let new_name = match table_name {
1173                sp::RenameTableNameKind::To(name) | sp::RenameTableNameKind::As(name) => {
1174                    object_name_to_string(&name)
1175                }
1176            };
1177            AlterTableOp::RenameTable { new_name }
1178        }
1179        other => {
1180            return Err(SqlError::Unsupported(format!(
1181                "ALTER TABLE operation: {other}"
1182            )));
1183        }
1184    };
1185    Ok(Statement::AlterTable(Box::new(AlterTableStmt {
1186        table,
1187        op,
1188    })))
1189}
1190
1191fn convert_fk_actions(
1192    on_delete: &Option<sp::ReferentialAction>,
1193    on_update: &Option<sp::ReferentialAction>,
1194) -> Result<()> {
1195    for action in [on_delete, on_update] {
1196        match action {
1197            None
1198            | Some(sp::ReferentialAction::Restrict)
1199            | Some(sp::ReferentialAction::NoAction) => {}
1200            Some(other) => {
1201                return Err(SqlError::Unsupported(format!(
1202                    "FOREIGN KEY action: {other}"
1203                )));
1204            }
1205        }
1206    }
1207    Ok(())
1208}
1209
1210fn convert_create_index(ci: sp::CreateIndex) -> Result<Statement> {
1211    let index_name = ci
1212        .name
1213        .as_ref()
1214        .map(object_name_to_string)
1215        .ok_or_else(|| SqlError::Parse("index name required".into()))?;
1216
1217    let table_name = object_name_to_string(&ci.table_name);
1218
1219    let columns: Vec<String> = ci
1220        .columns
1221        .iter()
1222        .map(|idx_col| match &idx_col.column.expr {
1223            sp::Expr::Identifier(ident) => Ok(ident.value.clone()),
1224            other => Err(SqlError::Unsupported(format!("expression index: {other}"))),
1225        })
1226        .collect::<Result<_>>()?;
1227
1228    if columns.is_empty() {
1229        return Err(SqlError::Parse(
1230            "index must have at least one column".into(),
1231        ));
1232    }
1233
1234    Ok(Statement::CreateIndex(CreateIndexStmt {
1235        index_name,
1236        table_name,
1237        columns,
1238        unique: ci.unique,
1239        if_not_exists: ci.if_not_exists,
1240    }))
1241}
1242
1243fn convert_insert(insert: sp::Insert) -> Result<Statement> {
1244    let table = match &insert.table {
1245        sp::TableObject::TableName(name) => object_name_to_string(name).to_ascii_lowercase(),
1246        _ => return Err(SqlError::Unsupported("INSERT into non-table object".into())),
1247    };
1248
1249    let columns: Vec<String> = insert
1250        .columns
1251        .iter()
1252        .map(|c| c.value.to_ascii_lowercase())
1253        .collect();
1254
1255    let query = insert
1256        .source
1257        .ok_or_else(|| SqlError::Parse("INSERT requires VALUES or SELECT".into()))?;
1258
1259    let source = match *query.body {
1260        sp::SetExpr::Values(sp::Values { rows, .. }) => {
1261            let mut result = Vec::new();
1262            for row in rows {
1263                let mut exprs = Vec::new();
1264                for expr in row {
1265                    exprs.push(convert_expr(&expr)?);
1266                }
1267                result.push(exprs);
1268            }
1269            InsertSource::Values(result)
1270        }
1271        _ => {
1272            let (ctes, recursive) = if let Some(ref with) = query.with {
1273                convert_with(with)?
1274            } else {
1275                (vec![], false)
1276            };
1277            let body = convert_query_body(&query)?;
1278            InsertSource::Select(Box::new(SelectQuery {
1279                ctes,
1280                recursive,
1281                body,
1282            }))
1283        }
1284    };
1285
1286    Ok(Statement::Insert(InsertStmt {
1287        table,
1288        columns,
1289        source,
1290    }))
1291}
1292
1293fn convert_select_body(select: &sp::Select) -> Result<SelectStmt> {
1294    let distinct = match &select.distinct {
1295        Some(sp::Distinct::Distinct) => true,
1296        Some(sp::Distinct::On(_)) => {
1297            return Err(SqlError::Unsupported("DISTINCT ON".into()));
1298        }
1299        _ => false,
1300    };
1301
1302    // FROM clause
1303    let (from, from_alias, joins) = if select.from.is_empty() {
1304        (String::new(), None, vec![])
1305    } else if select.from.len() == 1 {
1306        let table_with_joins = &select.from[0];
1307        let (name, alias) = match &table_with_joins.relation {
1308            sp::TableFactor::Table { name, alias, .. } => {
1309                let table_name = object_name_to_string(name);
1310                let alias_str = alias.as_ref().map(|a| a.name.value.clone());
1311                (table_name, alias_str)
1312            }
1313            _ => return Err(SqlError::Unsupported("non-table FROM source".into())),
1314        };
1315        let j = table_with_joins
1316            .joins
1317            .iter()
1318            .map(convert_join)
1319            .collect::<Result<Vec<_>>>()?;
1320        (name, alias, j)
1321    } else {
1322        return Err(SqlError::Unsupported("comma-separated FROM tables".into()));
1323    };
1324
1325    // Projection
1326    let columns: Vec<SelectColumn> = select
1327        .projection
1328        .iter()
1329        .map(convert_select_item)
1330        .collect::<Result<_>>()?;
1331
1332    // WHERE
1333    let where_clause = select.selection.as_ref().map(convert_expr).transpose()?;
1334
1335    // GROUP BY
1336    let group_by = match &select.group_by {
1337        sp::GroupByExpr::Expressions(exprs, _) => {
1338            exprs.iter().map(convert_expr).collect::<Result<_>>()?
1339        }
1340        sp::GroupByExpr::All(_) => {
1341            return Err(SqlError::Unsupported("GROUP BY ALL".into()));
1342        }
1343    };
1344
1345    // HAVING
1346    let having = select.having.as_ref().map(convert_expr).transpose()?;
1347
1348    Ok(SelectStmt {
1349        columns,
1350        from,
1351        from_alias,
1352        joins,
1353        distinct,
1354        where_clause,
1355        order_by: vec![],
1356        limit: None,
1357        offset: None,
1358        group_by,
1359        having,
1360    })
1361}
1362
1363fn convert_set_expr(set_expr: &sp::SetExpr) -> Result<QueryBody> {
1364    match set_expr {
1365        sp::SetExpr::Select(sel) => Ok(QueryBody::Select(Box::new(convert_select_body(sel)?))),
1366        sp::SetExpr::SetOperation {
1367            op,
1368            set_quantifier,
1369            left,
1370            right,
1371        } => {
1372            let set_op = match op {
1373                sp::SetOperator::Union => SetOp::Union,
1374                sp::SetOperator::Intersect => SetOp::Intersect,
1375                sp::SetOperator::Except | sp::SetOperator::Minus => SetOp::Except,
1376            };
1377            let all = match set_quantifier {
1378                sp::SetQuantifier::All => true,
1379                sp::SetQuantifier::None | sp::SetQuantifier::Distinct => false,
1380                _ => {
1381                    return Err(SqlError::Unsupported("BY NAME set operations".into()));
1382                }
1383            };
1384            Ok(QueryBody::Compound(CompoundSelect {
1385                op: set_op,
1386                all,
1387                left: Box::new(convert_set_expr(left)?),
1388                right: Box::new(convert_set_expr(right)?),
1389                order_by: vec![],
1390                limit: None,
1391                offset: None,
1392            }))
1393        }
1394        _ => Err(SqlError::Unsupported("unsupported set expression".into())),
1395    }
1396}
1397
1398fn convert_query_body(query: &sp::Query) -> Result<QueryBody> {
1399    let mut body = convert_set_expr(&query.body)?;
1400
1401    // ORDER BY
1402    let order_by = if let Some(ref ob) = query.order_by {
1403        match &ob.kind {
1404            sp::OrderByKind::Expressions(exprs) => exprs
1405                .iter()
1406                .map(convert_order_by_expr)
1407                .collect::<Result<_>>()?,
1408            sp::OrderByKind::All { .. } => {
1409                return Err(SqlError::Unsupported("ORDER BY ALL".into()));
1410            }
1411        }
1412    } else {
1413        vec![]
1414    };
1415
1416    // LIMIT / OFFSET
1417    let (limit, offset) = match &query.limit_clause {
1418        Some(sp::LimitClause::LimitOffset { limit, offset, .. }) => {
1419            let l = limit.as_ref().map(convert_expr).transpose()?;
1420            let o = offset
1421                .as_ref()
1422                .map(|o| convert_expr(&o.value))
1423                .transpose()?;
1424            (l, o)
1425        }
1426        Some(sp::LimitClause::OffsetCommaLimit { limit, offset }) => {
1427            let l = Some(convert_expr(limit)?);
1428            let o = Some(convert_expr(offset)?);
1429            (l, o)
1430        }
1431        None => (None, None),
1432    };
1433
1434    match &mut body {
1435        QueryBody::Select(sel) => {
1436            sel.order_by = order_by;
1437            sel.limit = limit;
1438            sel.offset = offset;
1439        }
1440        QueryBody::Compound(comp) => {
1441            comp.order_by = order_by;
1442            comp.limit = limit;
1443            comp.offset = offset;
1444        }
1445    }
1446
1447    Ok(body)
1448}
1449
1450fn convert_subquery(query: &sp::Query) -> Result<SelectStmt> {
1451    if query.with.is_some() {
1452        return Err(SqlError::Unsupported("CTEs in subqueries".into()));
1453    }
1454    match convert_query_body(query)? {
1455        QueryBody::Select(s) => Ok(*s),
1456        QueryBody::Compound(_) => Err(SqlError::Unsupported(
1457            "UNION/INTERSECT/EXCEPT in subqueries".into(),
1458        )),
1459    }
1460}
1461
1462fn convert_with(with: &sp::With) -> Result<(Vec<CteDefinition>, bool)> {
1463    let mut names = std::collections::HashSet::new();
1464    let mut ctes = Vec::new();
1465    for cte in &with.cte_tables {
1466        let name = cte.alias.name.value.to_ascii_lowercase();
1467        if !names.insert(name.clone()) {
1468            return Err(SqlError::DuplicateCteName(name));
1469        }
1470        let column_aliases: Vec<String> = cte
1471            .alias
1472            .columns
1473            .iter()
1474            .map(|c| c.name.value.to_ascii_lowercase())
1475            .collect();
1476        let body = convert_query_body(&cte.query)?;
1477        ctes.push(CteDefinition {
1478            name,
1479            column_aliases,
1480            body,
1481        });
1482    }
1483    Ok((ctes, with.recursive))
1484}
1485
1486fn convert_query(query: sp::Query) -> Result<Statement> {
1487    let (ctes, recursive) = if let Some(ref with) = query.with {
1488        convert_with(with)?
1489    } else {
1490        (vec![], false)
1491    };
1492    let body = convert_query_body(&query)?;
1493    Ok(Statement::Select(Box::new(SelectQuery {
1494        ctes,
1495        recursive,
1496        body,
1497    })))
1498}
1499
1500fn convert_join(join: &sp::Join) -> Result<JoinClause> {
1501    let (join_type, constraint) = match &join.join_operator {
1502        sp::JoinOperator::Inner(c) => (JoinType::Inner, Some(c)),
1503        sp::JoinOperator::Join(c) => (JoinType::Inner, Some(c)),
1504        sp::JoinOperator::CrossJoin(c) => (JoinType::Cross, Some(c)),
1505        sp::JoinOperator::Left(c) => (JoinType::Left, Some(c)),
1506        sp::JoinOperator::LeftSemi(c) => (JoinType::Left, Some(c)),
1507        sp::JoinOperator::LeftAnti(c) => (JoinType::Left, Some(c)),
1508        sp::JoinOperator::Right(c) => (JoinType::Right, Some(c)),
1509        sp::JoinOperator::RightSemi(c) => (JoinType::Right, Some(c)),
1510        sp::JoinOperator::RightAnti(c) => (JoinType::Right, Some(c)),
1511        other => return Err(SqlError::Unsupported(format!("join type: {other:?}"))),
1512    };
1513
1514    let (name, alias) = match &join.relation {
1515        sp::TableFactor::Table { name, alias, .. } => {
1516            let table_name = object_name_to_string(name);
1517            let alias_str = alias.as_ref().map(|a| a.name.value.clone());
1518            (table_name, alias_str)
1519        }
1520        _ => return Err(SqlError::Unsupported("non-table JOIN source".into())),
1521    };
1522
1523    let on_clause = match constraint {
1524        Some(sp::JoinConstraint::On(expr)) => Some(convert_expr(expr)?),
1525        Some(sp::JoinConstraint::None) | None => None,
1526        Some(other) => return Err(SqlError::Unsupported(format!("join constraint: {other:?}"))),
1527    };
1528
1529    Ok(JoinClause {
1530        join_type,
1531        table: TableRef { name, alias },
1532        on_clause,
1533    })
1534}
1535
1536fn convert_update(update: sp::Update) -> Result<Statement> {
1537    let table = match &update.table.relation {
1538        sp::TableFactor::Table { name, .. } => object_name_to_string(name),
1539        _ => return Err(SqlError::Unsupported("non-table UPDATE target".into())),
1540    };
1541
1542    let assignments = update
1543        .assignments
1544        .iter()
1545        .map(|a| {
1546            let col = match &a.target {
1547                sp::AssignmentTarget::ColumnName(name) => object_name_to_string(name),
1548                _ => return Err(SqlError::Unsupported("tuple assignment".into())),
1549            };
1550            let expr = convert_expr(&a.value)?;
1551            Ok((col, expr))
1552        })
1553        .collect::<Result<_>>()?;
1554
1555    let where_clause = update.selection.as_ref().map(convert_expr).transpose()?;
1556
1557    Ok(Statement::Update(UpdateStmt {
1558        table,
1559        assignments,
1560        where_clause,
1561    }))
1562}
1563
1564fn convert_delete(delete: sp::Delete) -> Result<Statement> {
1565    let table_name = match &delete.from {
1566        sp::FromTable::WithFromKeyword(tables) => {
1567            if tables.len() != 1 {
1568                return Err(SqlError::Unsupported("multi-table DELETE".into()));
1569            }
1570            match &tables[0].relation {
1571                sp::TableFactor::Table { name, .. } => object_name_to_string(name),
1572                _ => return Err(SqlError::Unsupported("non-table DELETE target".into())),
1573            }
1574        }
1575        sp::FromTable::WithoutKeyword(tables) => {
1576            if tables.len() != 1 {
1577                return Err(SqlError::Unsupported("multi-table DELETE".into()));
1578            }
1579            match &tables[0].relation {
1580                sp::TableFactor::Table { name, .. } => object_name_to_string(name),
1581                _ => return Err(SqlError::Unsupported("non-table DELETE target".into())),
1582            }
1583        }
1584    };
1585
1586    let where_clause = delete.selection.as_ref().map(convert_expr).transpose()?;
1587
1588    Ok(Statement::Delete(DeleteStmt {
1589        table: table_name,
1590        where_clause,
1591    }))
1592}
1593
1594// ── Expression conversion ───────────────────────────────────────────
1595
1596fn convert_expr(expr: &sp::Expr) -> Result<Expr> {
1597    match expr {
1598        sp::Expr::Value(v) => convert_value(&v.value),
1599        sp::Expr::Identifier(ident) => Ok(Expr::Column(ident.value.to_ascii_lowercase())),
1600        sp::Expr::CompoundIdentifier(parts) => {
1601            if parts.len() == 2 {
1602                Ok(Expr::QualifiedColumn {
1603                    table: parts[0].value.to_ascii_lowercase(),
1604                    column: parts[1].value.to_ascii_lowercase(),
1605                })
1606            } else {
1607                Ok(Expr::Column(
1608                    parts.last().unwrap().value.to_ascii_lowercase(),
1609                ))
1610            }
1611        }
1612        sp::Expr::BinaryOp { left, op, right } => {
1613            let bin_op = convert_bin_op(op)?;
1614            Ok(Expr::BinaryOp {
1615                left: Box::new(convert_expr(left)?),
1616                op: bin_op,
1617                right: Box::new(convert_expr(right)?),
1618            })
1619        }
1620        sp::Expr::UnaryOp { op, expr } => {
1621            let unary_op = match op {
1622                sp::UnaryOperator::Minus => UnaryOp::Neg,
1623                sp::UnaryOperator::Not => UnaryOp::Not,
1624                _ => return Err(SqlError::Unsupported(format!("unary op: {op}"))),
1625            };
1626            Ok(Expr::UnaryOp {
1627                op: unary_op,
1628                expr: Box::new(convert_expr(expr)?),
1629            })
1630        }
1631        sp::Expr::IsNull(e) => Ok(Expr::IsNull(Box::new(convert_expr(e)?))),
1632        sp::Expr::IsNotNull(e) => Ok(Expr::IsNotNull(Box::new(convert_expr(e)?))),
1633        sp::Expr::Nested(e) => convert_expr(e),
1634        sp::Expr::Function(func) => convert_function(func),
1635        sp::Expr::InSubquery {
1636            expr: e,
1637            subquery,
1638            negated,
1639        } => {
1640            let inner_expr = convert_expr(e)?;
1641            let stmt = convert_subquery(subquery)?;
1642            Ok(Expr::InSubquery {
1643                expr: Box::new(inner_expr),
1644                subquery: Box::new(stmt),
1645                negated: *negated,
1646            })
1647        }
1648        sp::Expr::InList {
1649            expr: e,
1650            list,
1651            negated,
1652        } => {
1653            let inner_expr = convert_expr(e)?;
1654            let items = list.iter().map(convert_expr).collect::<Result<Vec<_>>>()?;
1655            Ok(Expr::InList {
1656                expr: Box::new(inner_expr),
1657                list: items,
1658                negated: *negated,
1659            })
1660        }
1661        sp::Expr::Exists { subquery, negated } => {
1662            let stmt = convert_subquery(subquery)?;
1663            Ok(Expr::Exists {
1664                subquery: Box::new(stmt),
1665                negated: *negated,
1666            })
1667        }
1668        sp::Expr::Subquery(query) => {
1669            let stmt = convert_subquery(query)?;
1670            Ok(Expr::ScalarSubquery(Box::new(stmt)))
1671        }
1672        sp::Expr::Between {
1673            expr: e,
1674            negated,
1675            low,
1676            high,
1677        } => Ok(Expr::Between {
1678            expr: Box::new(convert_expr(e)?),
1679            low: Box::new(convert_expr(low)?),
1680            high: Box::new(convert_expr(high)?),
1681            negated: *negated,
1682        }),
1683        sp::Expr::Like {
1684            expr: e,
1685            negated,
1686            pattern,
1687            escape_char,
1688            ..
1689        } => {
1690            let esc = escape_char
1691                .as_ref()
1692                .map(convert_escape_value)
1693                .transpose()?
1694                .map(Box::new);
1695            Ok(Expr::Like {
1696                expr: Box::new(convert_expr(e)?),
1697                pattern: Box::new(convert_expr(pattern)?),
1698                escape: esc,
1699                negated: *negated,
1700            })
1701        }
1702        sp::Expr::ILike {
1703            expr: e,
1704            negated,
1705            pattern,
1706            escape_char,
1707            ..
1708        } => {
1709            let esc = escape_char
1710                .as_ref()
1711                .map(convert_escape_value)
1712                .transpose()?
1713                .map(Box::new);
1714            Ok(Expr::Like {
1715                expr: Box::new(convert_expr(e)?),
1716                pattern: Box::new(convert_expr(pattern)?),
1717                escape: esc,
1718                negated: *negated,
1719            })
1720        }
1721        sp::Expr::Case {
1722            operand,
1723            conditions,
1724            else_result,
1725            ..
1726        } => {
1727            let op = operand
1728                .as_ref()
1729                .map(|e| convert_expr(e))
1730                .transpose()?
1731                .map(Box::new);
1732            let conds: Vec<(Expr, Expr)> = conditions
1733                .iter()
1734                .map(|cw| Ok((convert_expr(&cw.condition)?, convert_expr(&cw.result)?)))
1735                .collect::<Result<_>>()?;
1736            let else_r = else_result
1737                .as_ref()
1738                .map(|e| convert_expr(e))
1739                .transpose()?
1740                .map(Box::new);
1741            Ok(Expr::Case {
1742                operand: op,
1743                conditions: conds,
1744                else_result: else_r,
1745            })
1746        }
1747        sp::Expr::Cast {
1748            expr: e,
1749            data_type: dt,
1750            ..
1751        } => {
1752            let target = convert_data_type(dt)?;
1753            Ok(Expr::Cast {
1754                expr: Box::new(convert_expr(e)?),
1755                data_type: target,
1756            })
1757        }
1758        sp::Expr::Substring {
1759            expr: e,
1760            substring_from,
1761            substring_for,
1762            ..
1763        } => {
1764            let mut args = vec![convert_expr(e)?];
1765            if let Some(from) = substring_from {
1766                args.push(convert_expr(from)?);
1767            }
1768            if let Some(f) = substring_for {
1769                args.push(convert_expr(f)?);
1770            }
1771            Ok(Expr::Function {
1772                name: "SUBSTR".into(),
1773                args,
1774            })
1775        }
1776        sp::Expr::Trim {
1777            expr: e,
1778            trim_where,
1779            trim_what,
1780            trim_characters,
1781        } => {
1782            let fn_name = match trim_where {
1783                Some(sp::TrimWhereField::Leading) => "LTRIM",
1784                Some(sp::TrimWhereField::Trailing) => "RTRIM",
1785                _ => "TRIM",
1786            };
1787            let mut args = vec![convert_expr(e)?];
1788            if let Some(what) = trim_what {
1789                args.push(convert_expr(what)?);
1790            } else if let Some(chars) = trim_characters {
1791                if let Some(first) = chars.first() {
1792                    args.push(convert_expr(first)?);
1793                }
1794            }
1795            Ok(Expr::Function {
1796                name: fn_name.into(),
1797                args,
1798            })
1799        }
1800        sp::Expr::Ceil { expr: e, .. } => Ok(Expr::Function {
1801            name: "CEIL".into(),
1802            args: vec![convert_expr(e)?],
1803        }),
1804        sp::Expr::Floor { expr: e, .. } => Ok(Expr::Function {
1805            name: "FLOOR".into(),
1806            args: vec![convert_expr(e)?],
1807        }),
1808        sp::Expr::Position { expr: e, r#in } => Ok(Expr::Function {
1809            name: "INSTR".into(),
1810            args: vec![convert_expr(r#in)?, convert_expr(e)?],
1811        }),
1812        _ => Err(SqlError::Unsupported(format!("expression: {expr}"))),
1813    }
1814}
1815
1816fn convert_value(val: &sp::Value) -> Result<Expr> {
1817    match val {
1818        sp::Value::Number(n, _) => {
1819            if let Ok(i) = n.parse::<i64>() {
1820                Ok(Expr::Literal(Value::Integer(i)))
1821            } else if let Ok(f) = n.parse::<f64>() {
1822                Ok(Expr::Literal(Value::Real(f)))
1823            } else {
1824                Err(SqlError::InvalidValue(format!("cannot parse number: {n}")))
1825            }
1826        }
1827        sp::Value::SingleQuotedString(s) => Ok(Expr::Literal(Value::Text(s.as_str().into()))),
1828        sp::Value::Boolean(b) => Ok(Expr::Literal(Value::Boolean(*b))),
1829        sp::Value::Null => Ok(Expr::Literal(Value::Null)),
1830        sp::Value::Placeholder(s) => {
1831            let idx_str = s
1832                .strip_prefix('$')
1833                .ok_or_else(|| SqlError::Parse(format!("invalid placeholder: {s}")))?;
1834            let idx: usize = idx_str
1835                .parse()
1836                .map_err(|_| SqlError::Parse(format!("invalid placeholder index: {s}")))?;
1837            if idx == 0 {
1838                return Err(SqlError::Parse("placeholder index must be >= 1".into()));
1839            }
1840            Ok(Expr::Parameter(idx))
1841        }
1842        _ => Err(SqlError::Unsupported(format!("value type: {val}"))),
1843    }
1844}
1845
1846fn convert_escape_value(val: &sp::Value) -> Result<Expr> {
1847    match val {
1848        sp::Value::SingleQuotedString(s) => Ok(Expr::Literal(Value::Text(s.as_str().into()))),
1849        _ => Err(SqlError::Unsupported(format!("ESCAPE value: {val}"))),
1850    }
1851}
1852
1853fn convert_bin_op(op: &sp::BinaryOperator) -> Result<BinOp> {
1854    match op {
1855        sp::BinaryOperator::Plus => Ok(BinOp::Add),
1856        sp::BinaryOperator::Minus => Ok(BinOp::Sub),
1857        sp::BinaryOperator::Multiply => Ok(BinOp::Mul),
1858        sp::BinaryOperator::Divide => Ok(BinOp::Div),
1859        sp::BinaryOperator::Modulo => Ok(BinOp::Mod),
1860        sp::BinaryOperator::Eq => Ok(BinOp::Eq),
1861        sp::BinaryOperator::NotEq => Ok(BinOp::NotEq),
1862        sp::BinaryOperator::Lt => Ok(BinOp::Lt),
1863        sp::BinaryOperator::Gt => Ok(BinOp::Gt),
1864        sp::BinaryOperator::LtEq => Ok(BinOp::LtEq),
1865        sp::BinaryOperator::GtEq => Ok(BinOp::GtEq),
1866        sp::BinaryOperator::And => Ok(BinOp::And),
1867        sp::BinaryOperator::Or => Ok(BinOp::Or),
1868        sp::BinaryOperator::StringConcat => Ok(BinOp::Concat),
1869        _ => Err(SqlError::Unsupported(format!("binary op: {op}"))),
1870    }
1871}
1872
1873fn convert_function(func: &sp::Function) -> Result<Expr> {
1874    let name = object_name_to_string(&func.name).to_ascii_uppercase();
1875
1876    // COUNT(*)
1877    match &func.args {
1878        sp::FunctionArguments::List(list) => {
1879            if list.args.is_empty() && name == "COUNT" {
1880                return Ok(Expr::CountStar);
1881            }
1882            let args = list
1883                .args
1884                .iter()
1885                .map(|arg| match arg {
1886                    sp::FunctionArg::Unnamed(sp::FunctionArgExpr::Expr(e)) => convert_expr(e),
1887                    sp::FunctionArg::Unnamed(sp::FunctionArgExpr::Wildcard) => {
1888                        if name == "COUNT" {
1889                            Ok(Expr::CountStar)
1890                        } else {
1891                            Err(SqlError::Unsupported(format!("{name}(*)")))
1892                        }
1893                    }
1894                    _ => Err(SqlError::Unsupported(format!(
1895                        "function arg type in {name}"
1896                    ))),
1897                })
1898                .collect::<Result<Vec<_>>>()?;
1899
1900            if name == "COUNT" && args.len() == 1 && matches!(args[0], Expr::CountStar) {
1901                return Ok(Expr::CountStar);
1902            }
1903
1904            if name == "COALESCE" {
1905                if args.is_empty() {
1906                    return Err(SqlError::Parse(
1907                        "COALESCE requires at least one argument".into(),
1908                    ));
1909                }
1910                return Ok(Expr::Coalesce(args));
1911            }
1912
1913            if name == "NULLIF" {
1914                if args.len() != 2 {
1915                    return Err(SqlError::Parse(
1916                        "NULLIF requires exactly two arguments".into(),
1917                    ));
1918                }
1919                return Ok(Expr::Case {
1920                    operand: None,
1921                    conditions: vec![(
1922                        Expr::BinaryOp {
1923                            left: Box::new(args[0].clone()),
1924                            op: BinOp::Eq,
1925                            right: Box::new(args[1].clone()),
1926                        },
1927                        Expr::Literal(Value::Null),
1928                    )],
1929                    else_result: Some(Box::new(args[0].clone())),
1930                });
1931            }
1932
1933            if name == "IIF" {
1934                if args.len() != 3 {
1935                    return Err(SqlError::Parse(
1936                        "IIF requires exactly three arguments".into(),
1937                    ));
1938                }
1939                return Ok(Expr::Case {
1940                    operand: None,
1941                    conditions: vec![(args[0].clone(), args[1].clone())],
1942                    else_result: Some(Box::new(args[2].clone())),
1943                });
1944            }
1945
1946            Ok(Expr::Function { name, args })
1947        }
1948        sp::FunctionArguments::None => {
1949            if name == "COUNT" {
1950                Ok(Expr::CountStar)
1951            } else {
1952                Ok(Expr::Function { name, args: vec![] })
1953            }
1954        }
1955        sp::FunctionArguments::Subquery(_) => {
1956            Err(SqlError::Unsupported("subquery in function".into()))
1957        }
1958    }
1959}
1960
1961fn convert_select_item(item: &sp::SelectItem) -> Result<SelectColumn> {
1962    match item {
1963        sp::SelectItem::Wildcard(_) => Ok(SelectColumn::AllColumns),
1964        sp::SelectItem::UnnamedExpr(e) => {
1965            let expr = convert_expr(e)?;
1966            Ok(SelectColumn::Expr { expr, alias: None })
1967        }
1968        sp::SelectItem::ExprWithAlias { expr, alias } => {
1969            let expr = convert_expr(expr)?;
1970            Ok(SelectColumn::Expr {
1971                expr,
1972                alias: Some(alias.value.clone()),
1973            })
1974        }
1975        sp::SelectItem::QualifiedWildcard(_, _) => {
1976            Err(SqlError::Unsupported("qualified wildcard (table.*)".into()))
1977        }
1978    }
1979}
1980
1981fn convert_order_by_expr(expr: &sp::OrderByExpr) -> Result<OrderByItem> {
1982    let e = convert_expr(&expr.expr)?;
1983    let descending = expr.options.asc.map(|asc| !asc).unwrap_or(false);
1984    let nulls_first = expr.options.nulls_first;
1985
1986    Ok(OrderByItem {
1987        expr: e,
1988        descending,
1989        nulls_first,
1990    })
1991}
1992
1993// ── Data type conversion ────────────────────────────────────────────
1994
1995fn convert_data_type(dt: &sp::DataType) -> Result<DataType> {
1996    match dt {
1997        sp::DataType::Int(_)
1998        | sp::DataType::Integer(_)
1999        | sp::DataType::BigInt(_)
2000        | sp::DataType::SmallInt(_)
2001        | sp::DataType::TinyInt(_)
2002        | sp::DataType::Int2(_)
2003        | sp::DataType::Int4(_)
2004        | sp::DataType::Int8(_) => Ok(DataType::Integer),
2005
2006        sp::DataType::Real
2007        | sp::DataType::Double(..)
2008        | sp::DataType::DoublePrecision
2009        | sp::DataType::Float(_)
2010        | sp::DataType::Float4
2011        | sp::DataType::Float64 => Ok(DataType::Real),
2012
2013        sp::DataType::Varchar(_)
2014        | sp::DataType::Text
2015        | sp::DataType::Char(_)
2016        | sp::DataType::Character(_)
2017        | sp::DataType::String(_) => Ok(DataType::Text),
2018
2019        sp::DataType::Blob(_) | sp::DataType::Bytea => Ok(DataType::Blob),
2020
2021        sp::DataType::Boolean | sp::DataType::Bool => Ok(DataType::Boolean),
2022
2023        _ => Err(SqlError::Unsupported(format!("data type: {dt}"))),
2024    }
2025}
2026
2027// ── Helpers ─────────────────────────────────────────────────────────
2028
2029fn object_name_to_string(name: &sp::ObjectName) -> String {
2030    name.0
2031        .iter()
2032        .filter_map(|p| match p {
2033            sp::ObjectNamePart::Identifier(ident) => Some(ident.value.clone()),
2034            _ => None,
2035        })
2036        .collect::<Vec<_>>()
2037        .join(".")
2038}
2039
2040#[cfg(test)]
2041mod tests {
2042    use super::*;
2043
2044    #[test]
2045    fn parse_create_table() {
2046        let stmt = parse_sql(
2047            "CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT NOT NULL, age INTEGER)",
2048        )
2049        .unwrap();
2050
2051        match stmt {
2052            Statement::CreateTable(ct) => {
2053                assert_eq!(ct.name, "users");
2054                assert_eq!(ct.columns.len(), 3);
2055                assert_eq!(ct.columns[0].name, "id");
2056                assert_eq!(ct.columns[0].data_type, DataType::Integer);
2057                assert!(ct.columns[0].is_primary_key);
2058                assert!(!ct.columns[0].nullable);
2059                assert_eq!(ct.columns[1].name, "name");
2060                assert_eq!(ct.columns[1].data_type, DataType::Text);
2061                assert!(!ct.columns[1].nullable);
2062                assert_eq!(ct.columns[2].name, "age");
2063                assert!(ct.columns[2].nullable);
2064                assert_eq!(ct.primary_key, vec!["id"]);
2065            }
2066            _ => panic!("expected CreateTable"),
2067        }
2068    }
2069
2070    #[test]
2071    fn parse_create_table_if_not_exists() {
2072        let stmt = parse_sql("CREATE TABLE IF NOT EXISTS t (id INT PRIMARY KEY)").unwrap();
2073        match stmt {
2074            Statement::CreateTable(ct) => assert!(ct.if_not_exists),
2075            _ => panic!("expected CreateTable"),
2076        }
2077    }
2078
2079    #[test]
2080    fn parse_drop_table() {
2081        let stmt = parse_sql("DROP TABLE users").unwrap();
2082        match stmt {
2083            Statement::DropTable(dt) => {
2084                assert_eq!(dt.name, "users");
2085                assert!(!dt.if_exists);
2086            }
2087            _ => panic!("expected DropTable"),
2088        }
2089    }
2090
2091    #[test]
2092    fn parse_drop_table_if_exists() {
2093        let stmt = parse_sql("DROP TABLE IF EXISTS users").unwrap();
2094        match stmt {
2095            Statement::DropTable(dt) => assert!(dt.if_exists),
2096            _ => panic!("expected DropTable"),
2097        }
2098    }
2099
2100    #[test]
2101    fn parse_insert() {
2102        let stmt =
2103            parse_sql("INSERT INTO users (id, name) VALUES (1, 'Alice'), (2, 'Bob')").unwrap();
2104
2105        match stmt {
2106            Statement::Insert(ins) => {
2107                assert_eq!(ins.table, "users");
2108                assert_eq!(ins.columns, vec!["id", "name"]);
2109                let values = match &ins.source {
2110                    InsertSource::Values(v) => v,
2111                    _ => panic!("expected Values"),
2112                };
2113                assert_eq!(values.len(), 2);
2114                assert!(matches!(values[0][0], Expr::Literal(Value::Integer(1))));
2115                assert!(matches!(&values[0][1], Expr::Literal(Value::Text(s)) if s == "Alice"));
2116            }
2117            _ => panic!("expected Insert"),
2118        }
2119    }
2120
2121    #[test]
2122    fn parse_select_all() {
2123        let stmt = parse_sql("SELECT * FROM users").unwrap();
2124        match stmt {
2125            Statement::Select(sq) => match sq.body {
2126                QueryBody::Select(sel) => {
2127                    assert_eq!(sel.from, "users");
2128                    assert!(matches!(sel.columns[0], SelectColumn::AllColumns));
2129                    assert!(sel.where_clause.is_none());
2130                }
2131                _ => panic!("expected QueryBody::Select"),
2132            },
2133            _ => panic!("expected Select"),
2134        }
2135    }
2136
2137    #[test]
2138    fn parse_select_where() {
2139        let stmt = parse_sql("SELECT id, name FROM users WHERE age > 18").unwrap();
2140        match stmt {
2141            Statement::Select(sq) => match sq.body {
2142                QueryBody::Select(sel) => {
2143                    assert_eq!(sel.columns.len(), 2);
2144                    assert!(sel.where_clause.is_some());
2145                }
2146                _ => panic!("expected QueryBody::Select"),
2147            },
2148            _ => panic!("expected Select"),
2149        }
2150    }
2151
2152    #[test]
2153    fn parse_select_order_limit() {
2154        let stmt = parse_sql("SELECT * FROM users ORDER BY name ASC LIMIT 10 OFFSET 5").unwrap();
2155        match stmt {
2156            Statement::Select(sq) => match sq.body {
2157                QueryBody::Select(sel) => {
2158                    assert_eq!(sel.order_by.len(), 1);
2159                    assert!(!sel.order_by[0].descending);
2160                    assert!(sel.limit.is_some());
2161                    assert!(sel.offset.is_some());
2162                }
2163                _ => panic!("expected QueryBody::Select"),
2164            },
2165            _ => panic!("expected Select"),
2166        }
2167    }
2168
2169    #[test]
2170    fn parse_update() {
2171        let stmt = parse_sql("UPDATE users SET name = 'Bob' WHERE id = 1").unwrap();
2172        match stmt {
2173            Statement::Update(upd) => {
2174                assert_eq!(upd.table, "users");
2175                assert_eq!(upd.assignments.len(), 1);
2176                assert_eq!(upd.assignments[0].0, "name");
2177                assert!(upd.where_clause.is_some());
2178            }
2179            _ => panic!("expected Update"),
2180        }
2181    }
2182
2183    #[test]
2184    fn parse_delete() {
2185        let stmt = parse_sql("DELETE FROM users WHERE id = 1").unwrap();
2186        match stmt {
2187            Statement::Delete(del) => {
2188                assert_eq!(del.table, "users");
2189                assert!(del.where_clause.is_some());
2190            }
2191            _ => panic!("expected Delete"),
2192        }
2193    }
2194
2195    #[test]
2196    fn parse_aggregate() {
2197        let stmt = parse_sql("SELECT COUNT(*), SUM(age) FROM users").unwrap();
2198        match stmt {
2199            Statement::Select(sq) => match sq.body {
2200                QueryBody::Select(sel) => {
2201                    assert_eq!(sel.columns.len(), 2);
2202                    match &sel.columns[0] {
2203                        SelectColumn::Expr {
2204                            expr: Expr::CountStar,
2205                            ..
2206                        } => {}
2207                        other => panic!("expected CountStar, got {other:?}"),
2208                    }
2209                }
2210                _ => panic!("expected QueryBody::Select"),
2211            },
2212            _ => panic!("expected Select"),
2213        }
2214    }
2215
2216    #[test]
2217    fn parse_group_by_having() {
2218        let stmt = parse_sql(
2219            "SELECT department, COUNT(*) FROM employees GROUP BY department HAVING COUNT(*) > 5",
2220        )
2221        .unwrap();
2222        match stmt {
2223            Statement::Select(sq) => match sq.body {
2224                QueryBody::Select(sel) => {
2225                    assert_eq!(sel.group_by.len(), 1);
2226                    assert!(sel.having.is_some());
2227                }
2228                _ => panic!("expected QueryBody::Select"),
2229            },
2230            _ => panic!("expected Select"),
2231        }
2232    }
2233
2234    #[test]
2235    fn parse_expressions() {
2236        let stmt = parse_sql("SELECT id + 1, -price, NOT active FROM items").unwrap();
2237        match stmt {
2238            Statement::Select(sq) => match sq.body {
2239                QueryBody::Select(sel) => {
2240                    assert_eq!(sel.columns.len(), 3);
2241                    // id + 1
2242                    match &sel.columns[0] {
2243                        SelectColumn::Expr {
2244                            expr: Expr::BinaryOp { op: BinOp::Add, .. },
2245                            ..
2246                        } => {}
2247                        other => panic!("expected BinaryOp Add, got {other:?}"),
2248                    }
2249                    // -price
2250                    match &sel.columns[1] {
2251                        SelectColumn::Expr {
2252                            expr:
2253                                Expr::UnaryOp {
2254                                    op: UnaryOp::Neg, ..
2255                                },
2256                            ..
2257                        } => {}
2258                        other => panic!("expected UnaryOp Neg, got {other:?}"),
2259                    }
2260                    // NOT active
2261                    match &sel.columns[2] {
2262                        SelectColumn::Expr {
2263                            expr:
2264                                Expr::UnaryOp {
2265                                    op: UnaryOp::Not, ..
2266                                },
2267                            ..
2268                        } => {}
2269                        other => panic!("expected UnaryOp Not, got {other:?}"),
2270                    }
2271                }
2272                _ => panic!("expected QueryBody::Select"),
2273            },
2274            _ => panic!("expected Select"),
2275        }
2276    }
2277
2278    #[test]
2279    fn parse_is_null() {
2280        let stmt = parse_sql("SELECT * FROM t WHERE x IS NULL").unwrap();
2281        match stmt {
2282            Statement::Select(sq) => match sq.body {
2283                QueryBody::Select(sel) => {
2284                    assert!(matches!(sel.where_clause, Some(Expr::IsNull(_))));
2285                }
2286                _ => panic!("expected QueryBody::Select"),
2287            },
2288            _ => panic!("expected Select"),
2289        }
2290    }
2291
2292    #[test]
2293    fn parse_inner_join() {
2294        let stmt = parse_sql("SELECT * FROM a JOIN b ON a.id = b.id").unwrap();
2295        match stmt {
2296            Statement::Select(sq) => match sq.body {
2297                QueryBody::Select(sel) => {
2298                    assert_eq!(sel.from, "a");
2299                    assert_eq!(sel.joins.len(), 1);
2300                    assert_eq!(sel.joins[0].join_type, JoinType::Inner);
2301                    assert_eq!(sel.joins[0].table.name, "b");
2302                    assert!(sel.joins[0].on_clause.is_some());
2303                }
2304                _ => panic!("expected QueryBody::Select"),
2305            },
2306            _ => panic!("expected Select"),
2307        }
2308    }
2309
2310    #[test]
2311    fn parse_inner_join_explicit() {
2312        let stmt = parse_sql("SELECT * FROM a INNER JOIN b ON a.id = b.a_id").unwrap();
2313        match stmt {
2314            Statement::Select(sq) => match sq.body {
2315                QueryBody::Select(sel) => {
2316                    assert_eq!(sel.joins.len(), 1);
2317                    assert_eq!(sel.joins[0].join_type, JoinType::Inner);
2318                }
2319                _ => panic!("expected QueryBody::Select"),
2320            },
2321            _ => panic!("expected Select"),
2322        }
2323    }
2324
2325    #[test]
2326    fn parse_cross_join() {
2327        let stmt = parse_sql("SELECT * FROM a CROSS JOIN b").unwrap();
2328        match stmt {
2329            Statement::Select(sq) => match sq.body {
2330                QueryBody::Select(sel) => {
2331                    assert_eq!(sel.joins.len(), 1);
2332                    assert_eq!(sel.joins[0].join_type, JoinType::Cross);
2333                    assert!(sel.joins[0].on_clause.is_none());
2334                }
2335                _ => panic!("expected QueryBody::Select"),
2336            },
2337            _ => panic!("expected Select"),
2338        }
2339    }
2340
2341    #[test]
2342    fn parse_left_join() {
2343        let stmt = parse_sql("SELECT * FROM a LEFT JOIN b ON a.id = b.a_id").unwrap();
2344        match stmt {
2345            Statement::Select(sq) => match sq.body {
2346                QueryBody::Select(sel) => {
2347                    assert_eq!(sel.joins.len(), 1);
2348                    assert_eq!(sel.joins[0].join_type, JoinType::Left);
2349                }
2350                _ => panic!("expected QueryBody::Select"),
2351            },
2352            _ => panic!("expected Select"),
2353        }
2354    }
2355
2356    #[test]
2357    fn parse_table_alias() {
2358        let stmt = parse_sql("SELECT u.id FROM users u JOIN orders o ON u.id = o.user_id").unwrap();
2359        match stmt {
2360            Statement::Select(sq) => match sq.body {
2361                QueryBody::Select(sel) => {
2362                    assert_eq!(sel.from, "users");
2363                    assert_eq!(sel.from_alias.as_deref(), Some("u"));
2364                    assert_eq!(sel.joins[0].table.name, "orders");
2365                    assert_eq!(sel.joins[0].table.alias.as_deref(), Some("o"));
2366                }
2367                _ => panic!("expected QueryBody::Select"),
2368            },
2369            _ => panic!("expected Select"),
2370        }
2371    }
2372
2373    #[test]
2374    fn parse_multi_join() {
2375        let stmt =
2376            parse_sql("SELECT * FROM a JOIN b ON a.id = b.a_id JOIN c ON b.id = c.b_id").unwrap();
2377        match stmt {
2378            Statement::Select(sq) => match sq.body {
2379                QueryBody::Select(sel) => {
2380                    assert_eq!(sel.joins.len(), 2);
2381                }
2382                _ => panic!("expected QueryBody::Select"),
2383            },
2384            _ => panic!("expected Select"),
2385        }
2386    }
2387
2388    #[test]
2389    fn parse_qualified_column() {
2390        let stmt = parse_sql("SELECT u.id, u.name FROM users u").unwrap();
2391        match stmt {
2392            Statement::Select(sq) => match sq.body {
2393                QueryBody::Select(sel) => match &sel.columns[0] {
2394                    SelectColumn::Expr {
2395                        expr: Expr::QualifiedColumn { table, column },
2396                        ..
2397                    } => {
2398                        assert_eq!(table, "u");
2399                        assert_eq!(column, "id");
2400                    }
2401                    other => panic!("expected QualifiedColumn, got {other:?}"),
2402                },
2403                _ => panic!("expected QueryBody::Select"),
2404            },
2405            _ => panic!("expected Select"),
2406        }
2407    }
2408
2409    #[test]
2410    fn reject_subquery() {
2411        assert!(parse_sql("SELECT * FROM (SELECT 1)").is_err());
2412    }
2413
2414    #[test]
2415    fn parse_type_mapping() {
2416        let stmt = parse_sql(
2417            "CREATE TABLE t (a INT PRIMARY KEY, b BIGINT, c SMALLINT, d REAL, e DOUBLE PRECISION, f VARCHAR(255), g BOOLEAN, h BLOB, i BYTEA)"
2418        ).unwrap();
2419        match stmt {
2420            Statement::CreateTable(ct) => {
2421                assert_eq!(ct.columns[0].data_type, DataType::Integer); // INT
2422                assert_eq!(ct.columns[1].data_type, DataType::Integer); // BIGINT
2423                assert_eq!(ct.columns[2].data_type, DataType::Integer); // SMALLINT
2424                assert_eq!(ct.columns[3].data_type, DataType::Real); // REAL
2425                assert_eq!(ct.columns[4].data_type, DataType::Real); // DOUBLE
2426                assert_eq!(ct.columns[5].data_type, DataType::Text); // VARCHAR
2427                assert_eq!(ct.columns[6].data_type, DataType::Boolean); // BOOLEAN
2428                assert_eq!(ct.columns[7].data_type, DataType::Blob); // BLOB
2429                assert_eq!(ct.columns[8].data_type, DataType::Blob); // BYTEA
2430            }
2431            _ => panic!("expected CreateTable"),
2432        }
2433    }
2434
2435    #[test]
2436    fn parse_boolean_literals() {
2437        let stmt = parse_sql("INSERT INTO t (a, b) VALUES (true, false)").unwrap();
2438        match stmt {
2439            Statement::Insert(ins) => {
2440                let values = match &ins.source {
2441                    InsertSource::Values(v) => v,
2442                    _ => panic!("expected Values"),
2443                };
2444                assert!(matches!(values[0][0], Expr::Literal(Value::Boolean(true))));
2445                assert!(matches!(values[0][1], Expr::Literal(Value::Boolean(false))));
2446            }
2447            _ => panic!("expected Insert"),
2448        }
2449    }
2450
2451    #[test]
2452    fn parse_null_literal() {
2453        let stmt = parse_sql("INSERT INTO t (a) VALUES (NULL)").unwrap();
2454        match stmt {
2455            Statement::Insert(ins) => {
2456                let values = match &ins.source {
2457                    InsertSource::Values(v) => v,
2458                    _ => panic!("expected Values"),
2459                };
2460                assert!(matches!(values[0][0], Expr::Literal(Value::Null)));
2461            }
2462            _ => panic!("expected Insert"),
2463        }
2464    }
2465
2466    #[test]
2467    fn parse_alias() {
2468        let stmt = parse_sql("SELECT id AS user_id FROM users").unwrap();
2469        match stmt {
2470            Statement::Select(sq) => match sq.body {
2471                QueryBody::Select(sel) => match &sel.columns[0] {
2472                    SelectColumn::Expr { alias: Some(a), .. } => assert_eq!(a, "user_id"),
2473                    other => panic!("expected alias, got {other:?}"),
2474                },
2475                _ => panic!("expected QueryBody::Select"),
2476            },
2477            _ => panic!("expected Select"),
2478        }
2479    }
2480
2481    #[test]
2482    fn parse_begin() {
2483        let stmt = parse_sql("BEGIN").unwrap();
2484        assert!(matches!(stmt, Statement::Begin));
2485    }
2486
2487    #[test]
2488    fn parse_begin_transaction() {
2489        let stmt = parse_sql("BEGIN TRANSACTION").unwrap();
2490        assert!(matches!(stmt, Statement::Begin));
2491    }
2492
2493    #[test]
2494    fn parse_commit() {
2495        let stmt = parse_sql("COMMIT").unwrap();
2496        assert!(matches!(stmt, Statement::Commit));
2497    }
2498
2499    #[test]
2500    fn parse_rollback() {
2501        let stmt = parse_sql("ROLLBACK").unwrap();
2502        assert!(matches!(stmt, Statement::Rollback));
2503    }
2504
2505    #[test]
2506    fn parse_select_distinct() {
2507        let stmt = parse_sql("SELECT DISTINCT name FROM users").unwrap();
2508        match stmt {
2509            Statement::Select(sq) => match sq.body {
2510                QueryBody::Select(sel) => {
2511                    assert!(sel.distinct);
2512                    assert_eq!(sel.columns.len(), 1);
2513                }
2514                _ => panic!("expected QueryBody::Select"),
2515            },
2516            _ => panic!("expected Select"),
2517        }
2518    }
2519
2520    #[test]
2521    fn parse_select_without_distinct() {
2522        let stmt = parse_sql("SELECT name FROM users").unwrap();
2523        match stmt {
2524            Statement::Select(sq) => match sq.body {
2525                QueryBody::Select(sel) => {
2526                    assert!(!sel.distinct);
2527                }
2528                _ => panic!("expected QueryBody::Select"),
2529            },
2530            _ => panic!("expected Select"),
2531        }
2532    }
2533
2534    #[test]
2535    fn parse_select_distinct_all_columns() {
2536        let stmt = parse_sql("SELECT DISTINCT * FROM users").unwrap();
2537        match stmt {
2538            Statement::Select(sq) => match sq.body {
2539                QueryBody::Select(sel) => {
2540                    assert!(sel.distinct);
2541                    assert!(matches!(sel.columns[0], SelectColumn::AllColumns));
2542                }
2543                _ => panic!("expected QueryBody::Select"),
2544            },
2545            _ => panic!("expected Select"),
2546        }
2547    }
2548
2549    #[test]
2550    fn reject_distinct_on() {
2551        assert!(parse_sql("SELECT DISTINCT ON (id) * FROM users").is_err());
2552    }
2553
2554    #[test]
2555    fn parse_create_index() {
2556        let stmt = parse_sql("CREATE INDEX idx_name ON users (name)").unwrap();
2557        match stmt {
2558            Statement::CreateIndex(ci) => {
2559                assert_eq!(ci.index_name, "idx_name");
2560                assert_eq!(ci.table_name, "users");
2561                assert_eq!(ci.columns, vec!["name"]);
2562                assert!(!ci.unique);
2563                assert!(!ci.if_not_exists);
2564            }
2565            _ => panic!("expected CreateIndex"),
2566        }
2567    }
2568
2569    #[test]
2570    fn parse_create_unique_index() {
2571        let stmt = parse_sql("CREATE UNIQUE INDEX idx_email ON users (email)").unwrap();
2572        match stmt {
2573            Statement::CreateIndex(ci) => {
2574                assert!(ci.unique);
2575                assert_eq!(ci.columns, vec!["email"]);
2576            }
2577            _ => panic!("expected CreateIndex"),
2578        }
2579    }
2580
2581    #[test]
2582    fn parse_create_index_if_not_exists() {
2583        let stmt = parse_sql("CREATE INDEX IF NOT EXISTS idx_x ON t (a)").unwrap();
2584        match stmt {
2585            Statement::CreateIndex(ci) => assert!(ci.if_not_exists),
2586            _ => panic!("expected CreateIndex"),
2587        }
2588    }
2589
2590    #[test]
2591    fn parse_create_index_multi_column() {
2592        let stmt = parse_sql("CREATE INDEX idx_multi ON t (a, b, c)").unwrap();
2593        match stmt {
2594            Statement::CreateIndex(ci) => {
2595                assert_eq!(ci.columns, vec!["a", "b", "c"]);
2596            }
2597            _ => panic!("expected CreateIndex"),
2598        }
2599    }
2600
2601    #[test]
2602    fn parse_drop_index() {
2603        let stmt = parse_sql("DROP INDEX idx_name").unwrap();
2604        match stmt {
2605            Statement::DropIndex(di) => {
2606                assert_eq!(di.index_name, "idx_name");
2607                assert!(!di.if_exists);
2608            }
2609            _ => panic!("expected DropIndex"),
2610        }
2611    }
2612
2613    #[test]
2614    fn parse_drop_index_if_exists() {
2615        let stmt = parse_sql("DROP INDEX IF EXISTS idx_name").unwrap();
2616        match stmt {
2617            Statement::DropIndex(di) => {
2618                assert!(di.if_exists);
2619            }
2620            _ => panic!("expected DropIndex"),
2621        }
2622    }
2623
2624    #[test]
2625    fn parse_explain_select() {
2626        let stmt = parse_sql("EXPLAIN SELECT * FROM users WHERE id = 1").unwrap();
2627        match stmt {
2628            Statement::Explain(inner) => {
2629                assert!(matches!(*inner, Statement::Select(_)));
2630            }
2631            _ => panic!("expected Explain"),
2632        }
2633    }
2634
2635    #[test]
2636    fn parse_explain_insert() {
2637        let stmt = parse_sql("EXPLAIN INSERT INTO t (a) VALUES (1)").unwrap();
2638        assert!(matches!(stmt, Statement::Explain(_)));
2639    }
2640
2641    #[test]
2642    fn reject_explain_analyze() {
2643        assert!(parse_sql("EXPLAIN ANALYZE SELECT * FROM t").is_err());
2644    }
2645
2646    #[test]
2647    fn parse_parameter_placeholder() {
2648        let stmt = parse_sql("SELECT * FROM t WHERE id = $1").unwrap();
2649        match stmt {
2650            Statement::Select(sq) => match sq.body {
2651                QueryBody::Select(sel) => match &sel.where_clause {
2652                    Some(Expr::BinaryOp { right, .. }) => {
2653                        assert!(matches!(right.as_ref(), Expr::Parameter(1)));
2654                    }
2655                    other => panic!("expected BinaryOp with Parameter, got {other:?}"),
2656                },
2657                _ => panic!("expected QueryBody::Select"),
2658            },
2659            _ => panic!("expected Select"),
2660        }
2661    }
2662
2663    #[test]
2664    fn parse_multiple_parameters() {
2665        let stmt = parse_sql("INSERT INTO t (a, b) VALUES ($1, $2)").unwrap();
2666        match stmt {
2667            Statement::Insert(ins) => {
2668                let values = match &ins.source {
2669                    InsertSource::Values(v) => v,
2670                    _ => panic!("expected Values"),
2671                };
2672                assert!(matches!(values[0][0], Expr::Parameter(1)));
2673                assert!(matches!(values[0][1], Expr::Parameter(2)));
2674            }
2675            _ => panic!("expected Insert"),
2676        }
2677    }
2678
2679    #[test]
2680    fn parse_insert_select() {
2681        let stmt =
2682            parse_sql("INSERT INTO t2 (id, name) SELECT id, name FROM t1 WHERE id > 5").unwrap();
2683        match stmt {
2684            Statement::Insert(ins) => {
2685                assert_eq!(ins.table, "t2");
2686                assert_eq!(ins.columns, vec!["id", "name"]);
2687                match &ins.source {
2688                    InsertSource::Select(sq) => match &sq.body {
2689                        QueryBody::Select(sel) => {
2690                            assert_eq!(sel.from, "t1");
2691                            assert_eq!(sel.columns.len(), 2);
2692                            assert!(sel.where_clause.is_some());
2693                        }
2694                        _ => panic!("expected QueryBody::Select"),
2695                    },
2696                    _ => panic!("expected InsertSource::Select"),
2697                }
2698            }
2699            _ => panic!("expected Insert"),
2700        }
2701    }
2702
2703    #[test]
2704    fn parse_insert_select_no_columns() {
2705        let stmt = parse_sql("INSERT INTO t2 SELECT * FROM t1").unwrap();
2706        match stmt {
2707            Statement::Insert(ins) => {
2708                assert_eq!(ins.table, "t2");
2709                assert!(ins.columns.is_empty());
2710                assert!(matches!(&ins.source, InsertSource::Select(_)));
2711            }
2712            _ => panic!("expected Insert"),
2713        }
2714    }
2715
2716    #[test]
2717    fn reject_zero_parameter() {
2718        assert!(parse_sql("SELECT $0 FROM t").is_err());
2719    }
2720
2721    #[test]
2722    fn count_params_basic() {
2723        let stmt = parse_sql("SELECT * FROM t WHERE a = $1 AND b = $3").unwrap();
2724        assert_eq!(count_params(&stmt), 3);
2725    }
2726
2727    #[test]
2728    fn count_params_none() {
2729        let stmt = parse_sql("SELECT * FROM t WHERE a = 1").unwrap();
2730        assert_eq!(count_params(&stmt), 0);
2731    }
2732
2733    #[test]
2734    fn bind_params_basic() {
2735        let stmt = parse_sql("SELECT * FROM t WHERE id = $1").unwrap();
2736        let bound = bind_params(&stmt, &[Value::Integer(42)]).unwrap();
2737        match bound {
2738            Statement::Select(sq) => match sq.body {
2739                QueryBody::Select(sel) => match &sel.where_clause {
2740                    Some(Expr::BinaryOp { right, .. }) => {
2741                        assert!(matches!(right.as_ref(), Expr::Literal(Value::Integer(42))));
2742                    }
2743                    other => panic!("expected BinaryOp with Literal, got {other:?}"),
2744                },
2745                _ => panic!("expected QueryBody::Select"),
2746            },
2747            _ => panic!("expected Select"),
2748        }
2749    }
2750
2751    #[test]
2752    fn bind_params_out_of_range() {
2753        let stmt = parse_sql("SELECT * FROM t WHERE id = $2").unwrap();
2754        let result = bind_params(&stmt, &[Value::Integer(1)]);
2755        assert!(result.is_err());
2756    }
2757
2758    #[test]
2759    fn parse_table_constraint_pk() {
2760        let stmt = parse_sql("CREATE TABLE t (a INTEGER, b TEXT, PRIMARY KEY (a))").unwrap();
2761        match stmt {
2762            Statement::CreateTable(ct) => {
2763                assert_eq!(ct.primary_key, vec!["a"]);
2764                assert!(ct.columns[0].is_primary_key);
2765                assert!(!ct.columns[0].nullable);
2766            }
2767            _ => panic!("expected CreateTable"),
2768        }
2769    }
2770}