Skip to main content

citadeldb_sql/
parser.rs

1//! SQL parser: converts SQL strings into our internal AST.
2
3use sqlparser::ast as sp;
4use sqlparser::dialect::GenericDialect;
5use sqlparser::parser::Parser;
6
7use crate::error::{Result, SqlError};
8use crate::types::{DataType, Value};
9
10// ── Internal AST ────────────────────────────────────────────────────
11
/// One parsed SQL statement in the engine's internal representation.
///
/// Produced by [`parse_sql`] from a `sqlparser` AST; each variant carries the
/// statement-specific payload.
#[derive(Debug, Clone)]
pub enum Statement {
    /// `CREATE TABLE`.
    CreateTable(CreateTableStmt),
    /// `DROP TABLE` of a single table.
    DropTable(DropTableStmt),
    /// `CREATE [UNIQUE] INDEX`.
    CreateIndex(CreateIndexStmt),
    /// `DROP INDEX` of a single index.
    DropIndex(DropIndexStmt),
    /// `INSERT INTO ... VALUES ...`.
    Insert(InsertStmt),
    /// `SELECT` query.
    Select(SelectStmt),
    /// `UPDATE`.
    Update(UpdateStmt),
    /// `DELETE`.
    Delete(DeleteStmt),
    /// Transaction start (`sqlparser`'s `StartTransaction`).
    Begin,
    /// `COMMIT`.
    Commit,
    /// `ROLLBACK`.
    Rollback,
    /// `EXPLAIN <statement>` (`EXPLAIN ANALYZE` is rejected by the parser).
    Explain(Box<Statement>),
}
27
/// Parsed `CREATE TABLE` statement.
#[derive(Debug, Clone)]
pub struct CreateTableStmt {
    /// Table name.
    pub name: String,
    /// Column definitions in declaration order.
    pub columns: Vec<ColumnSpec>,
    /// Primary-key column names: columns declared with a column-level
    /// `PRIMARY KEY` first (declaration order), followed by any further
    /// columns named in a table-level `PRIMARY KEY (...)` constraint.
    pub primary_key: Vec<String>,
    /// `true` when `IF NOT EXISTS` was specified.
    pub if_not_exists: bool,
}
35
/// One column definition inside `CREATE TABLE`.
#[derive(Debug, Clone)]
pub struct ColumnSpec {
    /// Column name.
    pub name: String,
    /// Declared type, converted from the SQL type name.
    pub data_type: DataType,
    /// Defaults to `true`; cleared by `NOT NULL` and by `PRIMARY KEY`.
    pub nullable: bool,
    /// `true` when the column is part of the primary key (column- or table-level).
    pub is_primary_key: bool,
}
43
/// Parsed `DROP TABLE` of exactly one table.
#[derive(Debug, Clone)]
pub struct DropTableStmt {
    /// Name of the table to drop.
    pub name: String,
    /// `true` when `IF EXISTS` was specified.
    pub if_exists: bool,
}
49
/// Parsed `CREATE [UNIQUE] INDEX name ON table (cols...)` statement.
#[derive(Debug, Clone)]
pub struct CreateIndexStmt {
    /// Index name (the parser requires one).
    pub index_name: String,
    /// Table the index is built on.
    pub table_name: String,
    /// Indexed column names in key order; non-empty when produced by the parser.
    pub columns: Vec<String>,
    /// `true` for `CREATE UNIQUE INDEX`.
    pub unique: bool,
    /// `true` when `IF NOT EXISTS` was specified.
    pub if_not_exists: bool,
}
58
/// Parsed `DROP INDEX` of exactly one index.
#[derive(Debug, Clone)]
pub struct DropIndexStmt {
    /// Name of the index to drop.
    pub index_name: String,
    /// `true` when `IF EXISTS` was specified.
    pub if_exists: bool,
}
64
/// Parsed `INSERT INTO table [(cols...)] VALUES (...), ...` statement.
#[derive(Debug, Clone)]
pub struct InsertStmt {
    /// Target table name.
    pub table: String,
    /// Explicit target column list; empty when the statement omitted it.
    pub columns: Vec<String>,
    /// One inner vector per `VALUES` row, in source order.
    pub values: Vec<Vec<Expr>>,
}
71
/// A table reference with an optional alias (used for `JOIN` targets).
#[derive(Debug, Clone)]
pub struct TableRef {
    /// Table name.
    pub name: String,
    /// Alias from `table AS alias` / `table alias`, if present.
    pub alias: Option<String>,
}
77
/// Kind of join recorded in a [`JoinClause`].
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum JoinType {
    /// `[INNER] JOIN`.
    Inner,
    /// `CROSS JOIN`.
    Cross,
    /// `LEFT JOIN`.
    Left,
    /// `RIGHT JOIN`.
    Right,
}
85
/// One `JOIN` attached to the base table of a `SELECT`.
#[derive(Debug, Clone)]
pub struct JoinClause {
    /// Kind of join.
    pub join_type: JoinType,
    /// Joined table and its optional alias.
    pub table: TableRef,
    /// `ON` condition; `None` when the join has no constraint.
    pub on_clause: Option<Expr>,
}
92
/// Parsed `SELECT` query (also used for subqueries).
#[derive(Debug, Clone)]
pub struct SelectStmt {
    /// Projection list.
    pub columns: Vec<SelectColumn>,
    /// Base table of the single `FROM` item; empty string when there is no `FROM`.
    pub from: String,
    /// Alias of the base table, if any.
    pub from_alias: Option<String>,
    /// Joins attached to the base table, in source order.
    pub joins: Vec<JoinClause>,
    /// `true` for `SELECT DISTINCT`.
    pub distinct: bool,
    /// `WHERE` predicate.
    pub where_clause: Option<Expr>,
    /// `ORDER BY` items in source order.
    pub order_by: Vec<OrderByItem>,
    /// `LIMIT` as an unevaluated expression (may be a parameter).
    pub limit: Option<Expr>,
    /// `OFFSET` as an unevaluated expression (may be a parameter).
    pub offset: Option<Expr>,
    /// `GROUP BY` expressions; empty when there is no grouping.
    pub group_by: Vec<Expr>,
    /// `HAVING` predicate.
    pub having: Option<Expr>,
}
107
/// Parsed `UPDATE table SET ... [WHERE ...]` statement.
#[derive(Debug, Clone)]
pub struct UpdateStmt {
    /// Target table name.
    pub table: String,
    /// `(column, new value)` pairs in `SET` order.
    pub assignments: Vec<(String, Expr)>,
    /// Row filter; `None` updates every row.
    pub where_clause: Option<Expr>,
}
114
/// Parsed `DELETE FROM table [WHERE ...]` statement.
#[derive(Debug, Clone)]
pub struct DeleteStmt {
    /// Target table name.
    pub table: String,
    /// Row filter; `None` deletes every row.
    pub where_clause: Option<Expr>,
}
120
/// One item of a `SELECT` projection list.
#[derive(Debug, Clone)]
pub enum SelectColumn {
    /// Wildcard projection (presumably the bare `*`; built by
    /// `convert_select_item`, which is not visible here).
    AllColumns,
    /// A projected expression with an optional `AS alias`.
    Expr { expr: Expr, alias: Option<String> },
}
126
/// One `ORDER BY` key.
#[derive(Debug, Clone)]
pub struct OrderByItem {
    /// Sort key expression.
    pub expr: Expr,
    /// Sort direction: `true` for descending.
    pub descending: bool,
    /// NULL placement. NOTE(review): `None` presumably means "not specified"
    /// (filled in by `convert_order_by_expr`, not visible here) — confirm.
    pub nulls_first: Option<bool>,
}
133
/// Scalar expression tree used throughout the internal AST.
#[derive(Debug, Clone)]
pub enum Expr {
    /// Constant value.
    Literal(Value),
    /// Unqualified column reference (`col`).
    Column(String),
    /// Table-qualified column reference (`tbl.col`).
    QualifiedColumn { table: String, column: String },
    /// Binary operator application.
    BinaryOp { left: Box<Expr>, op: BinOp, right: Box<Expr> },
    /// Unary operator application.
    UnaryOp { op: UnaryOp, expr: Box<Expr> },
    /// `expr IS NULL`.
    IsNull(Box<Expr>),
    /// `expr IS NOT NULL`.
    IsNotNull(Box<Expr>),
    /// Function call by name.
    Function { name: String, args: Vec<Expr> },
    /// Presumably `COUNT(*)`; produced by the function converter (not visible here).
    CountStar,
    /// `expr [NOT] IN (SELECT ...)`.
    InSubquery { expr: Box<Expr>, subquery: Box<SelectStmt>, negated: bool },
    /// `expr [NOT] IN (e1, e2, ...)`.
    InList { expr: Box<Expr>, list: Vec<Expr>, negated: bool },
    /// `[NOT] EXISTS (SELECT ...)`.
    Exists { subquery: Box<SelectStmt>, negated: bool },
    /// Subquery used as a single value.
    ScalarSubquery(Box<SelectStmt>),
    /// Membership test against a precomputed value set.
    /// NOTE(review): not built by the parser code visible here; presumably a later
    /// rewrite of an `IN` list, with `has_null` recording whether the list held a
    /// NULL — confirm.
    InSet { expr: Box<Expr>, values: std::collections::HashSet<Value>, has_null: bool, negated: bool },
    /// Range test `expr [NOT] BETWEEN low AND high`.
    Between { expr: Box<Expr>, low: Box<Expr>, high: Box<Expr>, negated: bool },
    /// Pattern match `expr [NOT] LIKE pattern [ESCAPE escape]`.
    Like { expr: Box<Expr>, pattern: Box<Expr>, escape: Option<Box<Expr>>, negated: bool },
    /// `CASE [operand] WHEN cond THEN result ... [ELSE else_result] END`.
    Case { operand: Option<Box<Expr>>, conditions: Vec<(Expr, Expr)>, else_result: Option<Box<Expr>> },
    /// `COALESCE(args...)`.
    Coalesce(Vec<Expr>),
    /// `CAST(expr AS data_type)`.
    Cast { expr: Box<Expr>, data_type: DataType },
    /// 1-based positional parameter (`$n`); `bind_params` replaces it with the
    /// literal `params[n - 1]`.
    Parameter(usize),
}
157
/// Binary operators supported in [`Expr::BinaryOp`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BinOp {
    // Arithmetic.
    Add, Sub, Mul, Div, Mod,
    // Comparison.
    Eq, NotEq, Lt, Gt, LtEq, GtEq,
    // Logical connectives and string concatenation (presumably `||`).
    And, Or, Concat,
}
164
/// Unary operators supported in [`Expr::UnaryOp`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum UnaryOp {
    /// Arithmetic negation (`-expr`).
    Neg,
    /// Logical negation (`NOT expr`).
    Not,
}
170
171// ── Parser entry point ──────────────────────────────────────────────
172
173pub fn parse_sql(sql: &str) -> Result<Statement> {
174    let dialect = GenericDialect {};
175    let stmts = Parser::parse_sql(&dialect, sql)
176        .map_err(|e| SqlError::Parse(e.to_string()))?;
177
178    if stmts.is_empty() {
179        return Err(SqlError::Parse("empty SQL".into()));
180    }
181    if stmts.len() > 1 {
182        return Err(SqlError::Unsupported("multiple statements".into()));
183    }
184
185    convert_statement(stmts.into_iter().next().unwrap())
186}
187
188// ── Parameter utilities ─────────────────────────────────────────────
189
190/// Returns the number of distinct parameters in a statement (max $N found).
191pub fn count_params(stmt: &Statement) -> usize {
192    let mut max_idx = 0usize;
193    visit_exprs_stmt(stmt, &mut |e| {
194        if let Expr::Parameter(n) = e {
195            max_idx = max_idx.max(*n);
196        }
197    });
198    max_idx
199}
200
201/// Replace all `Expr::Parameter(n)` with `Expr::Literal(params[n-1])`.
202pub fn bind_params(stmt: &Statement, params: &[crate::types::Value]) -> crate::error::Result<Statement> {
203    bind_stmt(stmt, params)
204}
205
206fn bind_stmt(stmt: &Statement, params: &[crate::types::Value]) -> crate::error::Result<Statement> {
207    match stmt {
208        Statement::Select(sel) => Ok(Statement::Select(bind_select(sel, params)?)),
209        Statement::Insert(ins) => {
210            let values = ins.values.iter()
211                .map(|row| row.iter().map(|e| bind_expr(e, params)).collect::<crate::error::Result<Vec<_>>>())
212                .collect::<crate::error::Result<Vec<_>>>()?;
213            Ok(Statement::Insert(InsertStmt {
214                table: ins.table.clone(),
215                columns: ins.columns.clone(),
216                values,
217            }))
218        }
219        Statement::Update(upd) => {
220            let assignments = upd.assignments.iter()
221                .map(|(col, e)| Ok((col.clone(), bind_expr(e, params)?)))
222                .collect::<crate::error::Result<Vec<_>>>()?;
223            let where_clause = upd.where_clause.as_ref()
224                .map(|e| bind_expr(e, params)).transpose()?;
225            Ok(Statement::Update(UpdateStmt {
226                table: upd.table.clone(),
227                assignments,
228                where_clause,
229            }))
230        }
231        Statement::Delete(del) => {
232            let where_clause = del.where_clause.as_ref()
233                .map(|e| bind_expr(e, params)).transpose()?;
234            Ok(Statement::Delete(DeleteStmt {
235                table: del.table.clone(),
236                where_clause,
237            }))
238        }
239        Statement::Explain(inner) => Ok(Statement::Explain(Box::new(bind_stmt(inner, params)?))),
240        other => Ok(other.clone()),
241    }
242}
243
244fn bind_select(sel: &SelectStmt, params: &[crate::types::Value]) -> crate::error::Result<SelectStmt> {
245    let columns = sel.columns.iter().map(|c| match c {
246        SelectColumn::AllColumns => Ok(SelectColumn::AllColumns),
247        SelectColumn::Expr { expr, alias } => Ok(SelectColumn::Expr {
248            expr: bind_expr(expr, params)?,
249            alias: alias.clone(),
250        }),
251    }).collect::<crate::error::Result<Vec<_>>>()?;
252    let joins = sel.joins.iter().map(|j| {
253        let on_clause = j.on_clause.as_ref()
254            .map(|e| bind_expr(e, params)).transpose()?;
255        Ok(JoinClause {
256            join_type: j.join_type,
257            table: j.table.clone(),
258            on_clause,
259        })
260    }).collect::<crate::error::Result<Vec<_>>>()?;
261    let where_clause = sel.where_clause.as_ref()
262        .map(|e| bind_expr(e, params)).transpose()?;
263    let order_by = sel.order_by.iter().map(|o| {
264        Ok(OrderByItem {
265            expr: bind_expr(&o.expr, params)?,
266            descending: o.descending,
267            nulls_first: o.nulls_first,
268        })
269    }).collect::<crate::error::Result<Vec<_>>>()?;
270    let limit = sel.limit.as_ref().map(|e| bind_expr(e, params)).transpose()?;
271    let offset = sel.offset.as_ref().map(|e| bind_expr(e, params)).transpose()?;
272    let group_by = sel.group_by.iter()
273        .map(|e| bind_expr(e, params))
274        .collect::<crate::error::Result<Vec<_>>>()?;
275    let having = sel.having.as_ref().map(|e| bind_expr(e, params)).transpose()?;
276
277    Ok(SelectStmt {
278        columns,
279        from: sel.from.clone(),
280        from_alias: sel.from_alias.clone(),
281        joins,
282        distinct: sel.distinct,
283        where_clause,
284        order_by,
285        limit,
286        offset,
287        group_by,
288        having,
289    })
290}
291
292fn bind_expr(expr: &Expr, params: &[crate::types::Value]) -> crate::error::Result<Expr> {
293    match expr {
294        Expr::Parameter(n) => {
295            if *n == 0 || *n > params.len() {
296                return Err(SqlError::ParameterCountMismatch {
297                    expected: *n,
298                    got: params.len(),
299                });
300            }
301            Ok(Expr::Literal(params[*n - 1].clone()))
302        }
303        Expr::Literal(_) | Expr::Column(_) | Expr::QualifiedColumn { .. }
304        | Expr::CountStar => Ok(expr.clone()),
305        Expr::BinaryOp { left, op, right } => Ok(Expr::BinaryOp {
306            left: Box::new(bind_expr(left, params)?),
307            op: *op,
308            right: Box::new(bind_expr(right, params)?),
309        }),
310        Expr::UnaryOp { op, expr: e } => Ok(Expr::UnaryOp {
311            op: *op,
312            expr: Box::new(bind_expr(e, params)?),
313        }),
314        Expr::IsNull(e) => Ok(Expr::IsNull(Box::new(bind_expr(e, params)?))),
315        Expr::IsNotNull(e) => Ok(Expr::IsNotNull(Box::new(bind_expr(e, params)?))),
316        Expr::Function { name, args } => {
317            let args = args.iter().map(|a| bind_expr(a, params))
318                .collect::<crate::error::Result<Vec<_>>>()?;
319            Ok(Expr::Function { name: name.clone(), args })
320        }
321        Expr::InSubquery { expr: e, subquery, negated } => Ok(Expr::InSubquery {
322            expr: Box::new(bind_expr(e, params)?),
323            subquery: Box::new(bind_select(subquery, params)?),
324            negated: *negated,
325        }),
326        Expr::InList { expr: e, list, negated } => {
327            let list = list.iter().map(|l| bind_expr(l, params))
328                .collect::<crate::error::Result<Vec<_>>>()?;
329            Ok(Expr::InList {
330                expr: Box::new(bind_expr(e, params)?),
331                list,
332                negated: *negated,
333            })
334        }
335        Expr::Exists { subquery, negated } => Ok(Expr::Exists {
336            subquery: Box::new(bind_select(subquery, params)?),
337            negated: *negated,
338        }),
339        Expr::ScalarSubquery(sq) => Ok(Expr::ScalarSubquery(
340            Box::new(bind_select(sq, params)?),
341        )),
342        Expr::InSet { expr: e, values, has_null, negated } => Ok(Expr::InSet {
343            expr: Box::new(bind_expr(e, params)?),
344            values: values.clone(),
345            has_null: *has_null,
346            negated: *negated,
347        }),
348        Expr::Between { expr: e, low, high, negated } => Ok(Expr::Between {
349            expr: Box::new(bind_expr(e, params)?),
350            low: Box::new(bind_expr(low, params)?),
351            high: Box::new(bind_expr(high, params)?),
352            negated: *negated,
353        }),
354        Expr::Like { expr: e, pattern, escape, negated } => Ok(Expr::Like {
355            expr: Box::new(bind_expr(e, params)?),
356            pattern: Box::new(bind_expr(pattern, params)?),
357            escape: escape.as_ref().map(|esc| bind_expr(esc, params).map(Box::new)).transpose()?,
358            negated: *negated,
359        }),
360        Expr::Case { operand, conditions, else_result } => {
361            let operand = operand.as_ref()
362                .map(|e| bind_expr(e, params).map(Box::new)).transpose()?;
363            let conditions = conditions.iter()
364                .map(|(cond, then)| Ok((bind_expr(cond, params)?, bind_expr(then, params)?)))
365                .collect::<crate::error::Result<Vec<_>>>()?;
366            let else_result = else_result.as_ref()
367                .map(|e| bind_expr(e, params).map(Box::new)).transpose()?;
368            Ok(Expr::Case { operand, conditions, else_result })
369        }
370        Expr::Coalesce(args) => {
371            let args = args.iter().map(|a| bind_expr(a, params))
372                .collect::<crate::error::Result<Vec<_>>>()?;
373            Ok(Expr::Coalesce(args))
374        }
375        Expr::Cast { expr: e, data_type } => Ok(Expr::Cast {
376            expr: Box::new(bind_expr(e, params)?),
377            data_type: data_type.clone(),
378        }),
379    }
380}
381
382fn visit_exprs_stmt(stmt: &Statement, visitor: &mut impl FnMut(&Expr)) {
383    match stmt {
384        Statement::Select(sel) => visit_exprs_select(sel, visitor),
385        Statement::Insert(ins) => {
386            for row in &ins.values {
387                for e in row { visit_expr(e, visitor); }
388            }
389        }
390        Statement::Update(upd) => {
391            for (_, e) in &upd.assignments { visit_expr(e, visitor); }
392            if let Some(w) = &upd.where_clause { visit_expr(w, visitor); }
393        }
394        Statement::Delete(del) => {
395            if let Some(w) = &del.where_clause { visit_expr(w, visitor); }
396        }
397        Statement::Explain(inner) => visit_exprs_stmt(inner, visitor),
398        _ => {}
399    }
400}
401
402fn visit_exprs_select(sel: &SelectStmt, visitor: &mut impl FnMut(&Expr)) {
403    for col in &sel.columns {
404        if let SelectColumn::Expr { expr, .. } = col { visit_expr(expr, visitor); }
405    }
406    for j in &sel.joins {
407        if let Some(on) = &j.on_clause { visit_expr(on, visitor); }
408    }
409    if let Some(w) = &sel.where_clause { visit_expr(w, visitor); }
410    for o in &sel.order_by { visit_expr(&o.expr, visitor); }
411    if let Some(l) = &sel.limit { visit_expr(l, visitor); }
412    if let Some(o) = &sel.offset { visit_expr(o, visitor); }
413    for g in &sel.group_by { visit_expr(g, visitor); }
414    if let Some(h) = &sel.having { visit_expr(h, visitor); }
415}
416
417fn visit_expr(expr: &Expr, visitor: &mut impl FnMut(&Expr)) {
418    visitor(expr);
419    match expr {
420        Expr::BinaryOp { left, right, .. } => {
421            visit_expr(left, visitor);
422            visit_expr(right, visitor);
423        }
424        Expr::UnaryOp { expr: e, .. } | Expr::IsNull(e) | Expr::IsNotNull(e) => {
425            visit_expr(e, visitor);
426        }
427        Expr::Function { args, .. } | Expr::Coalesce(args) => {
428            for a in args { visit_expr(a, visitor); }
429        }
430        Expr::InSubquery { expr: e, subquery, .. } => {
431            visit_expr(e, visitor);
432            visit_exprs_select(subquery, visitor);
433        }
434        Expr::InList { expr: e, list, .. } => {
435            visit_expr(e, visitor);
436            for l in list { visit_expr(l, visitor); }
437        }
438        Expr::Exists { subquery, .. } => visit_exprs_select(subquery, visitor),
439        Expr::ScalarSubquery(sq) => visit_exprs_select(sq, visitor),
440        Expr::InSet { expr: e, .. } => visit_expr(e, visitor),
441        Expr::Between { expr: e, low, high, .. } => {
442            visit_expr(e, visitor);
443            visit_expr(low, visitor);
444            visit_expr(high, visitor);
445        }
446        Expr::Like { expr: e, pattern, escape, .. } => {
447            visit_expr(e, visitor);
448            visit_expr(pattern, visitor);
449            if let Some(esc) = escape { visit_expr(esc, visitor); }
450        }
451        Expr::Case { operand, conditions, else_result } => {
452            if let Some(op) = operand { visit_expr(op, visitor); }
453            for (cond, then) in conditions {
454                visit_expr(cond, visitor);
455                visit_expr(then, visitor);
456            }
457            if let Some(el) = else_result { visit_expr(el, visitor); }
458        }
459        Expr::Cast { expr: e, .. } => visit_expr(e, visitor),
460        Expr::Literal(_) | Expr::Column(_) | Expr::QualifiedColumn { .. }
461        | Expr::CountStar | Expr::Parameter(_) => {}
462    }
463}
464
465// ── Statement conversion ────────────────────────────────────────────
466
467fn convert_statement(stmt: sp::Statement) -> Result<Statement> {
468    match stmt {
469        sp::Statement::CreateTable(ct) => convert_create_table(ct),
470        sp::Statement::CreateIndex(ci) => convert_create_index(ci),
471        sp::Statement::Drop {
472            object_type: sp::ObjectType::Table,
473            if_exists,
474            names,
475            ..
476        } => {
477            if names.len() != 1 {
478                return Err(SqlError::Unsupported("multi-table DROP".into()));
479            }
480            Ok(Statement::DropTable(DropTableStmt {
481                name: object_name_to_string(&names[0]),
482                if_exists,
483            }))
484        }
485        sp::Statement::Drop {
486            object_type: sp::ObjectType::Index,
487            if_exists,
488            names,
489            ..
490        } => {
491            if names.len() != 1 {
492                return Err(SqlError::Unsupported("multi-index DROP".into()));
493            }
494            Ok(Statement::DropIndex(DropIndexStmt {
495                index_name: object_name_to_string(&names[0]),
496                if_exists,
497            }))
498        }
499        sp::Statement::Insert(insert) => convert_insert(insert),
500        sp::Statement::Query(query) => convert_query(*query),
501        sp::Statement::Update(update) => convert_update(update),
502        sp::Statement::Delete(delete) => convert_delete(delete),
503        sp::Statement::StartTransaction { .. } => Ok(Statement::Begin),
504        sp::Statement::Commit { .. } => Ok(Statement::Commit),
505        sp::Statement::Rollback { .. } => Ok(Statement::Rollback),
506        sp::Statement::Explain { statement, analyze, .. } => {
507            if analyze {
508                return Err(SqlError::Unsupported("EXPLAIN ANALYZE".into()));
509            }
510            let inner = convert_statement(*statement)?;
511            Ok(Statement::Explain(Box::new(inner)))
512        }
513        _ => Err(SqlError::Unsupported(format!(
514            "statement type: {}",
515            stmt
516        ))),
517    }
518}
519
520fn convert_create_table(ct: sp::CreateTable) -> Result<Statement> {
521    let name = object_name_to_string(&ct.name);
522    let if_not_exists = ct.if_not_exists;
523
524    let mut columns = Vec::new();
525    let mut inline_pk: Vec<String> = Vec::new();
526
527    for col_def in &ct.columns {
528        let col_name = col_def.name.value.clone();
529        let data_type = convert_data_type(&col_def.data_type)?;
530        let mut nullable = true;
531        let mut is_primary_key = false;
532
533        for opt in &col_def.options {
534            match &opt.option {
535                sp::ColumnOption::NotNull => nullable = false,
536                sp::ColumnOption::Null => nullable = true,
537                sp::ColumnOption::PrimaryKey(_) => {
538                    is_primary_key = true;
539                    nullable = false;
540                    inline_pk.push(col_name.clone());
541                }
542                _ => {}
543            }
544        }
545
546        columns.push(ColumnSpec {
547            name: col_name,
548            data_type,
549            nullable,
550            is_primary_key,
551        });
552    }
553
554    // Check table-level constraints for PRIMARY KEY
555    for constraint in &ct.constraints {
556        if let sp::TableConstraint::PrimaryKey(pk_constraint) = constraint {
557            for idx_col in &pk_constraint.columns {
558                // IndexColumn has a `column: OrderByExpr` field; extract ident from the expr
559                let col_name = match &idx_col.column.expr {
560                    sp::Expr::Identifier(ident) => ident.value.clone(),
561                    _ => continue,
562                };
563                if !inline_pk.contains(&col_name) {
564                    inline_pk.push(col_name.clone());
565                }
566                if let Some(col) = columns.iter_mut().find(|c| c.name == col_name) {
567                    col.nullable = false;
568                    col.is_primary_key = true;
569                }
570            }
571        }
572    }
573
574    Ok(Statement::CreateTable(CreateTableStmt {
575        name,
576        columns,
577        primary_key: inline_pk,
578        if_not_exists,
579    }))
580}
581
582fn convert_create_index(ci: sp::CreateIndex) -> Result<Statement> {
583    let index_name = ci.name
584        .as_ref()
585        .map(object_name_to_string)
586        .ok_or_else(|| SqlError::Parse("index name required".into()))?;
587
588    let table_name = object_name_to_string(&ci.table_name);
589
590    let columns: Vec<String> = ci.columns.iter().map(|idx_col| {
591        match &idx_col.column.expr {
592            sp::Expr::Identifier(ident) => Ok(ident.value.clone()),
593            other => Err(SqlError::Unsupported(format!("expression index: {other}"))),
594        }
595    }).collect::<Result<_>>()?;
596
597    if columns.is_empty() {
598        return Err(SqlError::Parse("index must have at least one column".into()));
599    }
600
601    Ok(Statement::CreateIndex(CreateIndexStmt {
602        index_name,
603        table_name,
604        columns,
605        unique: ci.unique,
606        if_not_exists: ci.if_not_exists,
607    }))
608}
609
610fn convert_insert(insert: sp::Insert) -> Result<Statement> {
611    let table = match &insert.table {
612        sp::TableObject::TableName(name) => object_name_to_string(name),
613        _ => return Err(SqlError::Unsupported("INSERT into non-table object".into())),
614    };
615
616    let columns: Vec<String> = insert.columns.iter().map(|c| c.value.clone()).collect();
617
618    let source = insert.source.ok_or_else(|| {
619        SqlError::Parse("INSERT requires VALUES".into())
620    })?;
621
622    let values = match *source.body {
623        sp::SetExpr::Values(sp::Values { rows, .. }) => {
624            let mut result = Vec::new();
625            for row in rows {
626                let mut exprs = Vec::new();
627                for expr in row {
628                    exprs.push(convert_expr(&expr)?);
629                }
630                result.push(exprs);
631            }
632            result
633        }
634        _ => return Err(SqlError::Unsupported("INSERT ... SELECT".into())),
635    };
636
637    Ok(Statement::Insert(InsertStmt {
638        table,
639        columns,
640        values,
641    }))
642}
643
644fn convert_subquery(query: &sp::Query) -> Result<SelectStmt> {
645    match convert_query(query.clone())? {
646        Statement::Select(s) => Ok(s),
647        _ => Err(SqlError::Unsupported("non-SELECT subquery".into())),
648    }
649}
650
651fn convert_query(query: sp::Query) -> Result<Statement> {
652    let select = match *query.body {
653        sp::SetExpr::Select(sel) => *sel,
654        _ => return Err(SqlError::Unsupported("UNION/INTERSECT/EXCEPT".into())),
655    };
656
657    let distinct = match &select.distinct {
658        Some(sp::Distinct::Distinct) => true,
659        Some(sp::Distinct::On(_)) => {
660            return Err(SqlError::Unsupported("DISTINCT ON".into()));
661        }
662        _ => false,
663    };
664
665    // FROM clause
666    let (from, from_alias, joins) = if select.from.is_empty() {
667        (String::new(), None, vec![])
668    } else if select.from.len() == 1 {
669        let table_with_joins = &select.from[0];
670        let (name, alias) = match &table_with_joins.relation {
671            sp::TableFactor::Table { name, alias, .. } => {
672                let table_name = object_name_to_string(name);
673                let alias_str = alias.as_ref().map(|a| a.name.value.clone());
674                (table_name, alias_str)
675            }
676            _ => return Err(SqlError::Unsupported("non-table FROM source".into())),
677        };
678        let j = table_with_joins.joins.iter()
679            .map(|j| convert_join(j))
680            .collect::<Result<Vec<_>>>()?;
681        (name, alias, j)
682    } else {
683        return Err(SqlError::Unsupported("comma-separated FROM tables".into()));
684    };
685
686    // Projection
687    let columns: Vec<SelectColumn> = select.projection.iter()
688        .map(convert_select_item)
689        .collect::<Result<_>>()?;
690
691    // WHERE
692    let where_clause = select.selection.as_ref()
693        .map(convert_expr)
694        .transpose()?;
695
696    // ORDER BY
697    let order_by = if let Some(ref ob) = query.order_by {
698        match &ob.kind {
699            sp::OrderByKind::Expressions(exprs) => {
700                exprs.iter().map(convert_order_by_expr).collect::<Result<_>>()?
701            }
702            sp::OrderByKind::All { .. } => {
703                return Err(SqlError::Unsupported("ORDER BY ALL".into()));
704            }
705        }
706    } else {
707        vec![]
708    };
709
710    // LIMIT / OFFSET
711    let (limit, offset) = match &query.limit_clause {
712        Some(sp::LimitClause::LimitOffset { limit, offset, .. }) => {
713            let l = limit.as_ref().map(convert_expr).transpose()?;
714            let o = offset.as_ref().map(|o| convert_expr(&o.value)).transpose()?;
715            (l, o)
716        }
717        Some(sp::LimitClause::OffsetCommaLimit { limit, offset }) => {
718            let l = Some(convert_expr(limit)?);
719            let o = Some(convert_expr(offset)?);
720            (l, o)
721        }
722        None => (None, None),
723    };
724
725    // GROUP BY
726    let group_by = match &select.group_by {
727        sp::GroupByExpr::Expressions(exprs, _) => {
728            exprs.iter().map(convert_expr).collect::<Result<_>>()?
729        }
730        sp::GroupByExpr::All(_) => {
731            return Err(SqlError::Unsupported("GROUP BY ALL".into()));
732        }
733    };
734
735    // HAVING
736    let having = select.having.as_ref().map(convert_expr).transpose()?;
737
738    Ok(Statement::Select(SelectStmt {
739        columns,
740        from,
741        from_alias,
742        joins,
743        distinct,
744        where_clause,
745        order_by,
746        limit,
747        offset,
748        group_by,
749        having,
750    }))
751}
752
753fn convert_join(join: &sp::Join) -> Result<JoinClause> {
754    let (join_type, constraint) = match &join.join_operator {
755        sp::JoinOperator::Inner(c) => (JoinType::Inner, Some(c)),
756        sp::JoinOperator::Join(c) => (JoinType::Inner, Some(c)),
757        sp::JoinOperator::CrossJoin(c) => (JoinType::Cross, Some(c)),
758        sp::JoinOperator::Left(c) => (JoinType::Left, Some(c)),
759        sp::JoinOperator::LeftSemi(c) => (JoinType::Left, Some(c)),
760        sp::JoinOperator::LeftAnti(c) => (JoinType::Left, Some(c)),
761        sp::JoinOperator::Right(c) => (JoinType::Right, Some(c)),
762        sp::JoinOperator::RightSemi(c) => (JoinType::Right, Some(c)),
763        sp::JoinOperator::RightAnti(c) => (JoinType::Right, Some(c)),
764        other => return Err(SqlError::Unsupported(format!("join type: {other:?}"))),
765    };
766
767    let (name, alias) = match &join.relation {
768        sp::TableFactor::Table { name, alias, .. } => {
769            let table_name = object_name_to_string(name);
770            let alias_str = alias.as_ref().map(|a| a.name.value.clone());
771            (table_name, alias_str)
772        }
773        _ => return Err(SqlError::Unsupported("non-table JOIN source".into())),
774    };
775
776    let on_clause = match constraint {
777        Some(sp::JoinConstraint::On(expr)) => Some(convert_expr(expr)?),
778        Some(sp::JoinConstraint::None) | None => None,
779        Some(other) => return Err(SqlError::Unsupported(format!("join constraint: {other:?}"))),
780    };
781
782    Ok(JoinClause {
783        join_type,
784        table: TableRef { name, alias },
785        on_clause,
786    })
787}
788
789fn convert_update(update: sp::Update) -> Result<Statement> {
790    let table = match &update.table.relation {
791        sp::TableFactor::Table { name, .. } => object_name_to_string(name),
792        _ => return Err(SqlError::Unsupported("non-table UPDATE target".into())),
793    };
794
795    let assignments = update.assignments.iter()
796        .map(|a| {
797            let col = match &a.target {
798                sp::AssignmentTarget::ColumnName(name) => object_name_to_string(name),
799                _ => return Err(SqlError::Unsupported("tuple assignment".into())),
800            };
801            let expr = convert_expr(&a.value)?;
802            Ok((col, expr))
803        })
804        .collect::<Result<_>>()?;
805
806    let where_clause = update.selection.as_ref()
807        .map(convert_expr)
808        .transpose()?;
809
810    Ok(Statement::Update(UpdateStmt {
811        table,
812        assignments,
813        where_clause,
814    }))
815}
816
817fn convert_delete(delete: sp::Delete) -> Result<Statement> {
818    let table_name = match &delete.from {
819        sp::FromTable::WithFromKeyword(tables) => {
820            if tables.len() != 1 {
821                return Err(SqlError::Unsupported("multi-table DELETE".into()));
822            }
823            match &tables[0].relation {
824                sp::TableFactor::Table { name, .. } => object_name_to_string(name),
825                _ => return Err(SqlError::Unsupported("non-table DELETE target".into())),
826            }
827        }
828        sp::FromTable::WithoutKeyword(tables) => {
829            if tables.len() != 1 {
830                return Err(SqlError::Unsupported("multi-table DELETE".into()));
831            }
832            match &tables[0].relation {
833                sp::TableFactor::Table { name, .. } => object_name_to_string(name),
834                _ => return Err(SqlError::Unsupported("non-table DELETE target".into())),
835            }
836        }
837    };
838
839    let where_clause = delete.selection.as_ref()
840        .map(convert_expr)
841        .transpose()?;
842
843    Ok(Statement::Delete(DeleteStmt {
844        table: table_name,
845        where_clause,
846    }))
847}
848
849// ── Expression conversion ───────────────────────────────────────────
850
851fn convert_expr(expr: &sp::Expr) -> Result<Expr> {
852    match expr {
853        sp::Expr::Value(v) => convert_value(&v.value),
854        sp::Expr::Identifier(ident) => Ok(Expr::Column(ident.value.clone())),
855        sp::Expr::CompoundIdentifier(parts) => {
856            if parts.len() == 2 {
857                Ok(Expr::QualifiedColumn {
858                    table: parts[0].value.clone(),
859                    column: parts[1].value.clone(),
860                })
861            } else {
862                Ok(Expr::Column(parts.last().unwrap().value.clone()))
863            }
864        }
865        sp::Expr::BinaryOp { left, op, right } => {
866            let bin_op = convert_bin_op(op)?;
867            Ok(Expr::BinaryOp {
868                left: Box::new(convert_expr(left)?),
869                op: bin_op,
870                right: Box::new(convert_expr(right)?),
871            })
872        }
873        sp::Expr::UnaryOp { op, expr } => {
874            let unary_op = match op {
875                sp::UnaryOperator::Minus => UnaryOp::Neg,
876                sp::UnaryOperator::Not => UnaryOp::Not,
877                _ => return Err(SqlError::Unsupported(format!("unary op: {op}"))),
878            };
879            Ok(Expr::UnaryOp {
880                op: unary_op,
881                expr: Box::new(convert_expr(expr)?),
882            })
883        }
884        sp::Expr::IsNull(e) => Ok(Expr::IsNull(Box::new(convert_expr(e)?))),
885        sp::Expr::IsNotNull(e) => Ok(Expr::IsNotNull(Box::new(convert_expr(e)?))),
886        sp::Expr::Nested(e) => convert_expr(e),
887        sp::Expr::Function(func) => convert_function(func),
888        sp::Expr::InSubquery { expr: e, subquery, negated } => {
889            let inner_expr = convert_expr(e)?;
890            let stmt = convert_subquery(subquery)?;
891            Ok(Expr::InSubquery {
892                expr: Box::new(inner_expr),
893                subquery: Box::new(stmt),
894                negated: *negated,
895            })
896        }
897        sp::Expr::InList { expr: e, list, negated } => {
898            let inner_expr = convert_expr(e)?;
899            let items = list.iter().map(convert_expr).collect::<Result<Vec<_>>>()?;
900            Ok(Expr::InList {
901                expr: Box::new(inner_expr),
902                list: items,
903                negated: *negated,
904            })
905        }
906        sp::Expr::Exists { subquery, negated } => {
907            let stmt = convert_subquery(subquery)?;
908            Ok(Expr::Exists {
909                subquery: Box::new(stmt),
910                negated: *negated,
911            })
912        }
913        sp::Expr::Subquery(query) => {
914            let stmt = convert_subquery(query)?;
915            Ok(Expr::ScalarSubquery(Box::new(stmt)))
916        }
917        sp::Expr::Between { expr: e, negated, low, high } => {
918            Ok(Expr::Between {
919                expr: Box::new(convert_expr(e)?),
920                low: Box::new(convert_expr(low)?),
921                high: Box::new(convert_expr(high)?),
922                negated: *negated,
923            })
924        }
925        sp::Expr::Like { expr: e, negated, pattern, escape_char, .. } => {
926            let esc = escape_char.as_ref()
927                .map(|v| convert_escape_value(v))
928                .transpose()?
929                .map(Box::new);
930            Ok(Expr::Like {
931                expr: Box::new(convert_expr(e)?),
932                pattern: Box::new(convert_expr(pattern)?),
933                escape: esc,
934                negated: *negated,
935            })
936        }
937        sp::Expr::ILike { expr: e, negated, pattern, escape_char, .. } => {
938            let esc = escape_char.as_ref()
939                .map(|v| convert_escape_value(v))
940                .transpose()?
941                .map(Box::new);
942            Ok(Expr::Like {
943                expr: Box::new(convert_expr(e)?),
944                pattern: Box::new(convert_expr(pattern)?),
945                escape: esc,
946                negated: *negated,
947            })
948        }
949        sp::Expr::Case { operand, conditions, else_result, .. } => {
950            let op = operand.as_ref()
951                .map(|e| convert_expr(e))
952                .transpose()?
953                .map(Box::new);
954            let conds: Vec<(Expr, Expr)> = conditions.iter()
955                .map(|cw| Ok((convert_expr(&cw.condition)?, convert_expr(&cw.result)?)))
956                .collect::<Result<_>>()?;
957            let else_r = else_result.as_ref()
958                .map(|e| convert_expr(e))
959                .transpose()?
960                .map(Box::new);
961            Ok(Expr::Case { operand: op, conditions: conds, else_result: else_r })
962        }
963        sp::Expr::Cast { expr: e, data_type: dt, .. } => {
964            let target = convert_data_type(dt)?;
965            Ok(Expr::Cast {
966                expr: Box::new(convert_expr(e)?),
967                data_type: target,
968            })
969        }
970        sp::Expr::Substring { expr: e, substring_from, substring_for, .. } => {
971            let mut args = vec![convert_expr(e)?];
972            if let Some(from) = substring_from {
973                args.push(convert_expr(from)?);
974            }
975            if let Some(f) = substring_for {
976                args.push(convert_expr(f)?);
977            }
978            Ok(Expr::Function { name: "SUBSTR".into(), args })
979        }
980        sp::Expr::Trim { expr: e, trim_where, trim_what, trim_characters } => {
981            let fn_name = match trim_where {
982                Some(sp::TrimWhereField::Leading) => "LTRIM",
983                Some(sp::TrimWhereField::Trailing) => "RTRIM",
984                _ => "TRIM",
985            };
986            let mut args = vec![convert_expr(e)?];
987            if let Some(what) = trim_what {
988                args.push(convert_expr(what)?);
989            } else if let Some(chars) = trim_characters {
990                if let Some(first) = chars.first() {
991                    args.push(convert_expr(first)?);
992                }
993            }
994            Ok(Expr::Function { name: fn_name.into(), args })
995        }
996        sp::Expr::Ceil { expr: e, .. } => {
997            Ok(Expr::Function { name: "CEIL".into(), args: vec![convert_expr(e)?] })
998        }
999        sp::Expr::Floor { expr: e, .. } => {
1000            Ok(Expr::Function { name: "FLOOR".into(), args: vec![convert_expr(e)?] })
1001        }
1002        sp::Expr::Position { expr: e, r#in } => {
1003            Ok(Expr::Function { name: "INSTR".into(), args: vec![convert_expr(r#in)?, convert_expr(e)?] })
1004        }
1005        _ => Err(SqlError::Unsupported(format!("expression: {expr}"))),
1006    }
1007}
1008
1009fn convert_value(val: &sp::Value) -> Result<Expr> {
1010    match val {
1011        sp::Value::Number(n, _) => {
1012            if let Ok(i) = n.parse::<i64>() {
1013                Ok(Expr::Literal(Value::Integer(i)))
1014            } else if let Ok(f) = n.parse::<f64>() {
1015                Ok(Expr::Literal(Value::Real(f)))
1016            } else {
1017                Err(SqlError::InvalidValue(format!("cannot parse number: {n}")))
1018            }
1019        }
1020        sp::Value::SingleQuotedString(s) => Ok(Expr::Literal(Value::Text(s.clone()))),
1021        sp::Value::Boolean(b) => Ok(Expr::Literal(Value::Boolean(*b))),
1022        sp::Value::Null => Ok(Expr::Literal(Value::Null)),
1023        sp::Value::Placeholder(s) => {
1024            let idx_str = s.strip_prefix('$')
1025                .ok_or_else(|| SqlError::Parse(format!("invalid placeholder: {s}")))?;
1026            let idx: usize = idx_str.parse()
1027                .map_err(|_| SqlError::Parse(format!("invalid placeholder index: {s}")))?;
1028            if idx == 0 {
1029                return Err(SqlError::Parse("placeholder index must be >= 1".into()));
1030            }
1031            Ok(Expr::Parameter(idx))
1032        }
1033        _ => Err(SqlError::Unsupported(format!("value type: {val}"))),
1034    }
1035}
1036
1037fn convert_escape_value(val: &sp::Value) -> Result<Expr> {
1038    match val {
1039        sp::Value::SingleQuotedString(s) => Ok(Expr::Literal(Value::Text(s.clone()))),
1040        _ => Err(SqlError::Unsupported(format!("ESCAPE value: {val}"))),
1041    }
1042}
1043
1044fn convert_bin_op(op: &sp::BinaryOperator) -> Result<BinOp> {
1045    match op {
1046        sp::BinaryOperator::Plus => Ok(BinOp::Add),
1047        sp::BinaryOperator::Minus => Ok(BinOp::Sub),
1048        sp::BinaryOperator::Multiply => Ok(BinOp::Mul),
1049        sp::BinaryOperator::Divide => Ok(BinOp::Div),
1050        sp::BinaryOperator::Modulo => Ok(BinOp::Mod),
1051        sp::BinaryOperator::Eq => Ok(BinOp::Eq),
1052        sp::BinaryOperator::NotEq => Ok(BinOp::NotEq),
1053        sp::BinaryOperator::Lt => Ok(BinOp::Lt),
1054        sp::BinaryOperator::Gt => Ok(BinOp::Gt),
1055        sp::BinaryOperator::LtEq => Ok(BinOp::LtEq),
1056        sp::BinaryOperator::GtEq => Ok(BinOp::GtEq),
1057        sp::BinaryOperator::And => Ok(BinOp::And),
1058        sp::BinaryOperator::Or => Ok(BinOp::Or),
1059        sp::BinaryOperator::StringConcat => Ok(BinOp::Concat),
1060        _ => Err(SqlError::Unsupported(format!("binary op: {op}"))),
1061    }
1062}
1063
1064fn convert_function(func: &sp::Function) -> Result<Expr> {
1065    let name = object_name_to_string(&func.name).to_ascii_uppercase();
1066
1067    // COUNT(*)
1068    match &func.args {
1069        sp::FunctionArguments::List(list) => {
1070            if list.args.is_empty() && name == "COUNT" {
1071                return Ok(Expr::CountStar);
1072            }
1073            let args = list.args.iter()
1074                .map(|arg| match arg {
1075                    sp::FunctionArg::Unnamed(sp::FunctionArgExpr::Expr(e)) => convert_expr(e),
1076                    sp::FunctionArg::Unnamed(sp::FunctionArgExpr::Wildcard) => {
1077                        if name == "COUNT" {
1078                            Ok(Expr::CountStar)
1079                        } else {
1080                            Err(SqlError::Unsupported(format!("{name}(*)")))
1081                        }
1082                    }
1083                    _ => Err(SqlError::Unsupported(format!("function arg type in {name}"))),
1084                })
1085                .collect::<Result<Vec<_>>>()?;
1086
1087            if name == "COUNT" && args.len() == 1 && matches!(args[0], Expr::CountStar) {
1088                return Ok(Expr::CountStar);
1089            }
1090
1091            if name == "COALESCE" {
1092                if args.is_empty() {
1093                    return Err(SqlError::Parse("COALESCE requires at least one argument".into()));
1094                }
1095                return Ok(Expr::Coalesce(args));
1096            }
1097
1098            if name == "NULLIF" {
1099                if args.len() != 2 {
1100                    return Err(SqlError::Parse("NULLIF requires exactly two arguments".into()));
1101                }
1102                return Ok(Expr::Case {
1103                    operand: None,
1104                    conditions: vec![
1105                        (Expr::BinaryOp {
1106                            left: Box::new(args[0].clone()),
1107                            op: BinOp::Eq,
1108                            right: Box::new(args[1].clone()),
1109                        }, Expr::Literal(Value::Null)),
1110                    ],
1111                    else_result: Some(Box::new(args[0].clone())),
1112                });
1113            }
1114
1115            if name == "IIF" {
1116                if args.len() != 3 {
1117                    return Err(SqlError::Parse("IIF requires exactly three arguments".into()));
1118                }
1119                return Ok(Expr::Case {
1120                    operand: None,
1121                    conditions: vec![(args[0].clone(), args[1].clone())],
1122                    else_result: Some(Box::new(args[2].clone())),
1123                });
1124            }
1125
1126            Ok(Expr::Function { name, args })
1127        }
1128        sp::FunctionArguments::None => {
1129            if name == "COUNT" {
1130                Ok(Expr::CountStar)
1131            } else {
1132                Ok(Expr::Function { name, args: vec![] })
1133            }
1134        }
1135        sp::FunctionArguments::Subquery(_) => {
1136            Err(SqlError::Unsupported("subquery in function".into()))
1137        }
1138    }
1139}
1140
1141fn convert_select_item(item: &sp::SelectItem) -> Result<SelectColumn> {
1142    match item {
1143        sp::SelectItem::Wildcard(_) => Ok(SelectColumn::AllColumns),
1144        sp::SelectItem::UnnamedExpr(e) => {
1145            let expr = convert_expr(e)?;
1146            Ok(SelectColumn::Expr { expr, alias: None })
1147        }
1148        sp::SelectItem::ExprWithAlias { expr, alias } => {
1149            let expr = convert_expr(expr)?;
1150            Ok(SelectColumn::Expr {
1151                expr,
1152                alias: Some(alias.value.clone()),
1153            })
1154        }
1155        sp::SelectItem::QualifiedWildcard(_, _) => {
1156            Err(SqlError::Unsupported("qualified wildcard (table.*)".into()))
1157        }
1158    }
1159}
1160
1161fn convert_order_by_expr(expr: &sp::OrderByExpr) -> Result<OrderByItem> {
1162    let e = convert_expr(&expr.expr)?;
1163    let descending = expr.options.asc.map(|asc| !asc).unwrap_or(false);
1164    let nulls_first = expr.options.nulls_first;
1165
1166    Ok(OrderByItem {
1167        expr: e,
1168        descending,
1169        nulls_first,
1170    })
1171}
1172
1173// ── Data type conversion ────────────────────────────────────────────
1174
1175fn convert_data_type(dt: &sp::DataType) -> Result<DataType> {
1176    match dt {
1177        sp::DataType::Int(_)
1178        | sp::DataType::Integer(_)
1179        | sp::DataType::BigInt(_)
1180        | sp::DataType::SmallInt(_)
1181        | sp::DataType::TinyInt(_)
1182        | sp::DataType::Int2(_)
1183        | sp::DataType::Int4(_)
1184        | sp::DataType::Int8(_) => Ok(DataType::Integer),
1185
1186        sp::DataType::Real
1187        | sp::DataType::Double(..)
1188        | sp::DataType::DoublePrecision
1189        | sp::DataType::Float(_)
1190        | sp::DataType::Float4
1191        | sp::DataType::Float64 => Ok(DataType::Real),
1192
1193        sp::DataType::Varchar(_)
1194        | sp::DataType::Text
1195        | sp::DataType::Char(_)
1196        | sp::DataType::Character(_)
1197        | sp::DataType::String(_) => Ok(DataType::Text),
1198
1199        sp::DataType::Blob(_)
1200        | sp::DataType::Bytea => Ok(DataType::Blob),
1201
1202        sp::DataType::Boolean | sp::DataType::Bool => Ok(DataType::Boolean),
1203
1204        _ => Err(SqlError::Unsupported(format!("data type: {dt}"))),
1205    }
1206}
1207
1208// ── Helpers ─────────────────────────────────────────────────────────
1209
1210fn object_name_to_string(name: &sp::ObjectName) -> String {
1211    name.0
1212        .iter()
1213        .filter_map(|p| match p {
1214            sp::ObjectNamePart::Identifier(ident) => Some(ident.value.clone()),
1215            _ => None,
1216        })
1217        .collect::<Vec<_>>()
1218        .join(".")
1219}
1220
1221#[cfg(test)]
1222mod tests {
1223    use super::*;
1224
1225    #[test]
1226    fn parse_create_table() {
1227        let stmt = parse_sql(
1228            "CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT NOT NULL, age INTEGER)"
1229        ).unwrap();
1230
1231        match stmt {
1232            Statement::CreateTable(ct) => {
1233                assert_eq!(ct.name, "users");
1234                assert_eq!(ct.columns.len(), 3);
1235                assert_eq!(ct.columns[0].name, "id");
1236                assert_eq!(ct.columns[0].data_type, DataType::Integer);
1237                assert!(ct.columns[0].is_primary_key);
1238                assert!(!ct.columns[0].nullable);
1239                assert_eq!(ct.columns[1].name, "name");
1240                assert_eq!(ct.columns[1].data_type, DataType::Text);
1241                assert!(!ct.columns[1].nullable);
1242                assert_eq!(ct.columns[2].name, "age");
1243                assert!(ct.columns[2].nullable);
1244                assert_eq!(ct.primary_key, vec!["id"]);
1245            }
1246            _ => panic!("expected CreateTable"),
1247        }
1248    }
1249
1250    #[test]
1251    fn parse_create_table_if_not_exists() {
1252        let stmt = parse_sql("CREATE TABLE IF NOT EXISTS t (id INT PRIMARY KEY)").unwrap();
1253        match stmt {
1254            Statement::CreateTable(ct) => assert!(ct.if_not_exists),
1255            _ => panic!("expected CreateTable"),
1256        }
1257    }
1258
1259    #[test]
1260    fn parse_drop_table() {
1261        let stmt = parse_sql("DROP TABLE users").unwrap();
1262        match stmt {
1263            Statement::DropTable(dt) => {
1264                assert_eq!(dt.name, "users");
1265                assert!(!dt.if_exists);
1266            }
1267            _ => panic!("expected DropTable"),
1268        }
1269    }
1270
1271    #[test]
1272    fn parse_drop_table_if_exists() {
1273        let stmt = parse_sql("DROP TABLE IF EXISTS users").unwrap();
1274        match stmt {
1275            Statement::DropTable(dt) => assert!(dt.if_exists),
1276            _ => panic!("expected DropTable"),
1277        }
1278    }
1279
1280    #[test]
1281    fn parse_insert() {
1282        let stmt = parse_sql(
1283            "INSERT INTO users (id, name) VALUES (1, 'Alice'), (2, 'Bob')"
1284        ).unwrap();
1285
1286        match stmt {
1287            Statement::Insert(ins) => {
1288                assert_eq!(ins.table, "users");
1289                assert_eq!(ins.columns, vec!["id", "name"]);
1290                assert_eq!(ins.values.len(), 2);
1291                assert!(matches!(ins.values[0][0], Expr::Literal(Value::Integer(1))));
1292                assert!(matches!(&ins.values[0][1], Expr::Literal(Value::Text(s)) if s == "Alice"));
1293            }
1294            _ => panic!("expected Insert"),
1295        }
1296    }
1297
1298    #[test]
1299    fn parse_select_all() {
1300        let stmt = parse_sql("SELECT * FROM users").unwrap();
1301        match stmt {
1302            Statement::Select(sel) => {
1303                assert_eq!(sel.from, "users");
1304                assert!(matches!(sel.columns[0], SelectColumn::AllColumns));
1305                assert!(sel.where_clause.is_none());
1306            }
1307            _ => panic!("expected Select"),
1308        }
1309    }
1310
1311    #[test]
1312    fn parse_select_where() {
1313        let stmt = parse_sql("SELECT id, name FROM users WHERE age > 18").unwrap();
1314        match stmt {
1315            Statement::Select(sel) => {
1316                assert_eq!(sel.columns.len(), 2);
1317                assert!(sel.where_clause.is_some());
1318            }
1319            _ => panic!("expected Select"),
1320        }
1321    }
1322
1323    #[test]
1324    fn parse_select_order_limit() {
1325        let stmt = parse_sql(
1326            "SELECT * FROM users ORDER BY name ASC LIMIT 10 OFFSET 5"
1327        ).unwrap();
1328        match stmt {
1329            Statement::Select(sel) => {
1330                assert_eq!(sel.order_by.len(), 1);
1331                assert!(!sel.order_by[0].descending);
1332                assert!(sel.limit.is_some());
1333                assert!(sel.offset.is_some());
1334            }
1335            _ => panic!("expected Select"),
1336        }
1337    }
1338
1339    #[test]
1340    fn parse_update() {
1341        let stmt = parse_sql("UPDATE users SET name = 'Bob' WHERE id = 1").unwrap();
1342        match stmt {
1343            Statement::Update(upd) => {
1344                assert_eq!(upd.table, "users");
1345                assert_eq!(upd.assignments.len(), 1);
1346                assert_eq!(upd.assignments[0].0, "name");
1347                assert!(upd.where_clause.is_some());
1348            }
1349            _ => panic!("expected Update"),
1350        }
1351    }
1352
1353    #[test]
1354    fn parse_delete() {
1355        let stmt = parse_sql("DELETE FROM users WHERE id = 1").unwrap();
1356        match stmt {
1357            Statement::Delete(del) => {
1358                assert_eq!(del.table, "users");
1359                assert!(del.where_clause.is_some());
1360            }
1361            _ => panic!("expected Delete"),
1362        }
1363    }
1364
1365    #[test]
1366    fn parse_aggregate() {
1367        let stmt = parse_sql("SELECT COUNT(*), SUM(age) FROM users").unwrap();
1368        match stmt {
1369            Statement::Select(sel) => {
1370                assert_eq!(sel.columns.len(), 2);
1371                match &sel.columns[0] {
1372                    SelectColumn::Expr { expr: Expr::CountStar, .. } => {}
1373                    other => panic!("expected CountStar, got {other:?}"),
1374                }
1375            }
1376            _ => panic!("expected Select"),
1377        }
1378    }
1379
1380    #[test]
1381    fn parse_group_by_having() {
1382        let stmt = parse_sql(
1383            "SELECT department, COUNT(*) FROM employees GROUP BY department HAVING COUNT(*) > 5"
1384        ).unwrap();
1385        match stmt {
1386            Statement::Select(sel) => {
1387                assert_eq!(sel.group_by.len(), 1);
1388                assert!(sel.having.is_some());
1389            }
1390            _ => panic!("expected Select"),
1391        }
1392    }
1393
1394    #[test]
1395    fn parse_expressions() {
1396        let stmt = parse_sql("SELECT id + 1, -price, NOT active FROM items").unwrap();
1397        match stmt {
1398            Statement::Select(sel) => {
1399                assert_eq!(sel.columns.len(), 3);
1400                // id + 1
1401                match &sel.columns[0] {
1402                    SelectColumn::Expr { expr: Expr::BinaryOp { op: BinOp::Add, .. }, .. } => {}
1403                    other => panic!("expected BinaryOp Add, got {other:?}"),
1404                }
1405                // -price
1406                match &sel.columns[1] {
1407                    SelectColumn::Expr { expr: Expr::UnaryOp { op: UnaryOp::Neg, .. }, .. } => {}
1408                    other => panic!("expected UnaryOp Neg, got {other:?}"),
1409                }
1410                // NOT active
1411                match &sel.columns[2] {
1412                    SelectColumn::Expr { expr: Expr::UnaryOp { op: UnaryOp::Not, .. }, .. } => {}
1413                    other => panic!("expected UnaryOp Not, got {other:?}"),
1414                }
1415            }
1416            _ => panic!("expected Select"),
1417        }
1418    }
1419
1420    #[test]
1421    fn parse_is_null() {
1422        let stmt = parse_sql("SELECT * FROM t WHERE x IS NULL").unwrap();
1423        match stmt {
1424            Statement::Select(sel) => {
1425                assert!(matches!(sel.where_clause, Some(Expr::IsNull(_))));
1426            }
1427            _ => panic!("expected Select"),
1428        }
1429    }
1430
1431    #[test]
1432    fn parse_inner_join() {
1433        let stmt = parse_sql("SELECT * FROM a JOIN b ON a.id = b.id").unwrap();
1434        match stmt {
1435            Statement::Select(sel) => {
1436                assert_eq!(sel.from, "a");
1437                assert_eq!(sel.joins.len(), 1);
1438                assert_eq!(sel.joins[0].join_type, JoinType::Inner);
1439                assert_eq!(sel.joins[0].table.name, "b");
1440                assert!(sel.joins[0].on_clause.is_some());
1441            }
1442            _ => panic!("expected Select"),
1443        }
1444    }
1445
1446    #[test]
1447    fn parse_inner_join_explicit() {
1448        let stmt = parse_sql("SELECT * FROM a INNER JOIN b ON a.id = b.a_id").unwrap();
1449        match stmt {
1450            Statement::Select(sel) => {
1451                assert_eq!(sel.joins.len(), 1);
1452                assert_eq!(sel.joins[0].join_type, JoinType::Inner);
1453            }
1454            _ => panic!("expected Select"),
1455        }
1456    }
1457
1458    #[test]
1459    fn parse_cross_join() {
1460        let stmt = parse_sql("SELECT * FROM a CROSS JOIN b").unwrap();
1461        match stmt {
1462            Statement::Select(sel) => {
1463                assert_eq!(sel.joins.len(), 1);
1464                assert_eq!(sel.joins[0].join_type, JoinType::Cross);
1465                assert!(sel.joins[0].on_clause.is_none());
1466            }
1467            _ => panic!("expected Select"),
1468        }
1469    }
1470
1471    #[test]
1472    fn parse_left_join() {
1473        let stmt = parse_sql("SELECT * FROM a LEFT JOIN b ON a.id = b.a_id").unwrap();
1474        match stmt {
1475            Statement::Select(sel) => {
1476                assert_eq!(sel.joins.len(), 1);
1477                assert_eq!(sel.joins[0].join_type, JoinType::Left);
1478            }
1479            _ => panic!("expected Select"),
1480        }
1481    }
1482
1483    #[test]
1484    fn parse_table_alias() {
1485        let stmt = parse_sql("SELECT u.id FROM users u JOIN orders o ON u.id = o.user_id").unwrap();
1486        match stmt {
1487            Statement::Select(sel) => {
1488                assert_eq!(sel.from, "users");
1489                assert_eq!(sel.from_alias.as_deref(), Some("u"));
1490                assert_eq!(sel.joins[0].table.name, "orders");
1491                assert_eq!(sel.joins[0].table.alias.as_deref(), Some("o"));
1492            }
1493            _ => panic!("expected Select"),
1494        }
1495    }
1496
1497    #[test]
1498    fn parse_multi_join() {
1499        let stmt = parse_sql(
1500            "SELECT * FROM a JOIN b ON a.id = b.a_id JOIN c ON b.id = c.b_id"
1501        ).unwrap();
1502        match stmt {
1503            Statement::Select(sel) => {
1504                assert_eq!(sel.joins.len(), 2);
1505            }
1506            _ => panic!("expected Select"),
1507        }
1508    }
1509
1510    #[test]
1511    fn parse_qualified_column() {
1512        let stmt = parse_sql("SELECT u.id, u.name FROM users u").unwrap();
1513        match stmt {
1514            Statement::Select(sel) => {
1515                match &sel.columns[0] {
1516                    SelectColumn::Expr { expr: Expr::QualifiedColumn { table, column }, .. } => {
1517                        assert_eq!(table, "u");
1518                        assert_eq!(column, "id");
1519                    }
1520                    other => panic!("expected QualifiedColumn, got {other:?}"),
1521                }
1522            }
1523            _ => panic!("expected Select"),
1524        }
1525    }
1526
1527    #[test]
1528    fn reject_subquery() {
1529        assert!(parse_sql("SELECT * FROM (SELECT 1)").is_err());
1530    }
1531
1532    #[test]
1533    fn parse_type_mapping() {
1534        let stmt = parse_sql(
1535            "CREATE TABLE t (a INT PRIMARY KEY, b BIGINT, c SMALLINT, d REAL, e DOUBLE PRECISION, f VARCHAR(255), g BOOLEAN, h BLOB, i BYTEA)"
1536        ).unwrap();
1537        match stmt {
1538            Statement::CreateTable(ct) => {
1539                assert_eq!(ct.columns[0].data_type, DataType::Integer); // INT
1540                assert_eq!(ct.columns[1].data_type, DataType::Integer); // BIGINT
1541                assert_eq!(ct.columns[2].data_type, DataType::Integer); // SMALLINT
1542                assert_eq!(ct.columns[3].data_type, DataType::Real);    // REAL
1543                assert_eq!(ct.columns[4].data_type, DataType::Real);    // DOUBLE
1544                assert_eq!(ct.columns[5].data_type, DataType::Text);    // VARCHAR
1545                assert_eq!(ct.columns[6].data_type, DataType::Boolean); // BOOLEAN
1546                assert_eq!(ct.columns[7].data_type, DataType::Blob);    // BLOB
1547                assert_eq!(ct.columns[8].data_type, DataType::Blob);    // BYTEA
1548            }
1549            _ => panic!("expected CreateTable"),
1550        }
1551    }
1552
1553    #[test]
1554    fn parse_boolean_literals() {
1555        let stmt = parse_sql("INSERT INTO t (a, b) VALUES (true, false)").unwrap();
1556        match stmt {
1557            Statement::Insert(ins) => {
1558                assert!(matches!(ins.values[0][0], Expr::Literal(Value::Boolean(true))));
1559                assert!(matches!(ins.values[0][1], Expr::Literal(Value::Boolean(false))));
1560            }
1561            _ => panic!("expected Insert"),
1562        }
1563    }
1564
1565    #[test]
1566    fn parse_null_literal() {
1567        let stmt = parse_sql("INSERT INTO t (a) VALUES (NULL)").unwrap();
1568        match stmt {
1569            Statement::Insert(ins) => {
1570                assert!(matches!(ins.values[0][0], Expr::Literal(Value::Null)));
1571            }
1572            _ => panic!("expected Insert"),
1573        }
1574    }
1575
1576    #[test]
1577    fn parse_alias() {
1578        let stmt = parse_sql("SELECT id AS user_id FROM users").unwrap();
1579        match stmt {
1580            Statement::Select(sel) => {
1581                match &sel.columns[0] {
1582                    SelectColumn::Expr { alias: Some(a), .. } => assert_eq!(a, "user_id"),
1583                    other => panic!("expected alias, got {other:?}"),
1584                }
1585            }
1586            _ => panic!("expected Select"),
1587        }
1588    }
1589
1590    #[test]
1591    fn parse_begin() {
1592        let stmt = parse_sql("BEGIN").unwrap();
1593        assert!(matches!(stmt, Statement::Begin));
1594    }
1595
1596    #[test]
1597    fn parse_begin_transaction() {
1598        let stmt = parse_sql("BEGIN TRANSACTION").unwrap();
1599        assert!(matches!(stmt, Statement::Begin));
1600    }
1601
1602    #[test]
1603    fn parse_commit() {
1604        let stmt = parse_sql("COMMIT").unwrap();
1605        assert!(matches!(stmt, Statement::Commit));
1606    }
1607
1608    #[test]
1609    fn parse_rollback() {
1610        let stmt = parse_sql("ROLLBACK").unwrap();
1611        assert!(matches!(stmt, Statement::Rollback));
1612    }
1613
1614    #[test]
1615    fn parse_select_distinct() {
1616        let stmt = parse_sql("SELECT DISTINCT name FROM users").unwrap();
1617        match stmt {
1618            Statement::Select(sel) => {
1619                assert!(sel.distinct);
1620                assert_eq!(sel.columns.len(), 1);
1621            }
1622            _ => panic!("expected Select"),
1623        }
1624    }
1625
1626    #[test]
1627    fn parse_select_without_distinct() {
1628        let stmt = parse_sql("SELECT name FROM users").unwrap();
1629        match stmt {
1630            Statement::Select(sel) => {
1631                assert!(!sel.distinct);
1632            }
1633            _ => panic!("expected Select"),
1634        }
1635    }
1636
1637    #[test]
1638    fn parse_select_distinct_all_columns() {
1639        let stmt = parse_sql("SELECT DISTINCT * FROM users").unwrap();
1640        match stmt {
1641            Statement::Select(sel) => {
1642                assert!(sel.distinct);
1643                assert!(matches!(sel.columns[0], SelectColumn::AllColumns));
1644            }
1645            _ => panic!("expected Select"),
1646        }
1647    }
1648
1649    #[test]
1650    fn reject_distinct_on() {
1651        assert!(parse_sql("SELECT DISTINCT ON (id) * FROM users").is_err());
1652    }
1653
1654    #[test]
1655    fn parse_create_index() {
1656        let stmt = parse_sql("CREATE INDEX idx_name ON users (name)").unwrap();
1657        match stmt {
1658            Statement::CreateIndex(ci) => {
1659                assert_eq!(ci.index_name, "idx_name");
1660                assert_eq!(ci.table_name, "users");
1661                assert_eq!(ci.columns, vec!["name"]);
1662                assert!(!ci.unique);
1663                assert!(!ci.if_not_exists);
1664            }
1665            _ => panic!("expected CreateIndex"),
1666        }
1667    }
1668
1669    #[test]
1670    fn parse_create_unique_index() {
1671        let stmt = parse_sql("CREATE UNIQUE INDEX idx_email ON users (email)").unwrap();
1672        match stmt {
1673            Statement::CreateIndex(ci) => {
1674                assert!(ci.unique);
1675                assert_eq!(ci.columns, vec!["email"]);
1676            }
1677            _ => panic!("expected CreateIndex"),
1678        }
1679    }
1680
1681    #[test]
1682    fn parse_create_index_if_not_exists() {
1683        let stmt = parse_sql("CREATE INDEX IF NOT EXISTS idx_x ON t (a)").unwrap();
1684        match stmt {
1685            Statement::CreateIndex(ci) => assert!(ci.if_not_exists),
1686            _ => panic!("expected CreateIndex"),
1687        }
1688    }
1689
1690    #[test]
1691    fn parse_create_index_multi_column() {
1692        let stmt = parse_sql("CREATE INDEX idx_multi ON t (a, b, c)").unwrap();
1693        match stmt {
1694            Statement::CreateIndex(ci) => {
1695                assert_eq!(ci.columns, vec!["a", "b", "c"]);
1696            }
1697            _ => panic!("expected CreateIndex"),
1698        }
1699    }
1700
1701    #[test]
1702    fn parse_drop_index() {
1703        let stmt = parse_sql("DROP INDEX idx_name").unwrap();
1704        match stmt {
1705            Statement::DropIndex(di) => {
1706                assert_eq!(di.index_name, "idx_name");
1707                assert!(!di.if_exists);
1708            }
1709            _ => panic!("expected DropIndex"),
1710        }
1711    }
1712
1713    #[test]
1714    fn parse_drop_index_if_exists() {
1715        let stmt = parse_sql("DROP INDEX IF EXISTS idx_name").unwrap();
1716        match stmt {
1717            Statement::DropIndex(di) => {
1718                assert!(di.if_exists);
1719            }
1720            _ => panic!("expected DropIndex"),
1721        }
1722    }
1723
1724    #[test]
1725    fn parse_explain_select() {
1726        let stmt = parse_sql("EXPLAIN SELECT * FROM users WHERE id = 1").unwrap();
1727        match stmt {
1728            Statement::Explain(inner) => {
1729                assert!(matches!(*inner, Statement::Select(_)));
1730            }
1731            _ => panic!("expected Explain"),
1732        }
1733    }
1734
1735    #[test]
1736    fn parse_explain_insert() {
1737        let stmt = parse_sql("EXPLAIN INSERT INTO t (a) VALUES (1)").unwrap();
1738        assert!(matches!(stmt, Statement::Explain(_)));
1739    }
1740
1741    #[test]
1742    fn reject_explain_analyze() {
1743        assert!(parse_sql("EXPLAIN ANALYZE SELECT * FROM t").is_err());
1744    }
1745
1746    #[test]
1747    fn parse_parameter_placeholder() {
1748        let stmt = parse_sql("SELECT * FROM t WHERE id = $1").unwrap();
1749        match stmt {
1750            Statement::Select(sel) => {
1751                match &sel.where_clause {
1752                    Some(Expr::BinaryOp { right, .. }) => {
1753                        assert!(matches!(right.as_ref(), Expr::Parameter(1)));
1754                    }
1755                    other => panic!("expected BinaryOp with Parameter, got {other:?}"),
1756                }
1757            }
1758            _ => panic!("expected Select"),
1759        }
1760    }
1761
1762    #[test]
1763    fn parse_multiple_parameters() {
1764        let stmt = parse_sql("INSERT INTO t (a, b) VALUES ($1, $2)").unwrap();
1765        match stmt {
1766            Statement::Insert(ins) => {
1767                assert!(matches!(ins.values[0][0], Expr::Parameter(1)));
1768                assert!(matches!(ins.values[0][1], Expr::Parameter(2)));
1769            }
1770            _ => panic!("expected Insert"),
1771        }
1772    }
1773
1774    #[test]
1775    fn reject_zero_parameter() {
1776        assert!(parse_sql("SELECT $0 FROM t").is_err());
1777    }
1778
1779    #[test]
1780    fn count_params_basic() {
1781        let stmt = parse_sql("SELECT * FROM t WHERE a = $1 AND b = $3").unwrap();
1782        assert_eq!(count_params(&stmt), 3);
1783    }
1784
1785    #[test]
1786    fn count_params_none() {
1787        let stmt = parse_sql("SELECT * FROM t WHERE a = 1").unwrap();
1788        assert_eq!(count_params(&stmt), 0);
1789    }
1790
1791    #[test]
1792    fn bind_params_basic() {
1793        let stmt = parse_sql("SELECT * FROM t WHERE id = $1").unwrap();
1794        let bound = bind_params(&stmt, &[Value::Integer(42)]).unwrap();
1795        match bound {
1796            Statement::Select(sel) => {
1797                match &sel.where_clause {
1798                    Some(Expr::BinaryOp { right, .. }) => {
1799                        assert!(matches!(right.as_ref(), Expr::Literal(Value::Integer(42))));
1800                    }
1801                    other => panic!("expected BinaryOp with Literal, got {other:?}"),
1802                }
1803            }
1804            _ => panic!("expected Select"),
1805        }
1806    }
1807
1808    #[test]
1809    fn bind_params_out_of_range() {
1810        let stmt = parse_sql("SELECT * FROM t WHERE id = $2").unwrap();
1811        let result = bind_params(&stmt, &[Value::Integer(1)]);
1812        assert!(result.is_err());
1813    }
1814
1815    #[test]
1816    fn parse_table_constraint_pk() {
1817        let stmt = parse_sql(
1818            "CREATE TABLE t (a INTEGER, b TEXT, PRIMARY KEY (a))"
1819        ).unwrap();
1820        match stmt {
1821            Statement::CreateTable(ct) => {
1822                assert_eq!(ct.primary_key, vec!["a"]);
1823                assert!(ct.columns[0].is_primary_key);
1824                assert!(!ct.columns[0].nullable);
1825            }
1826            _ => panic!("expected CreateTable"),
1827        }
1828    }
1829}