Skip to main content

citadel_sql/
parser.rs

//! SQL parser: converts SQL strings into the internal AST.
2
3use sqlparser::ast as sp;
4use sqlparser::dialect::GenericDialect;
5use sqlparser::parser::Parser;
6
7use crate::error::{Result, SqlError};
8use crate::types::{DataType, Value};
9
/// One SQL statement in the engine's internal representation.
///
/// Built from `sqlparser` output by [`parse_sql`] / [`parse_sql_multi`].
#[derive(Debug, Clone)]
pub enum Statement {
    /// `CREATE TABLE`.
    CreateTable(CreateTableStmt),
    /// `DROP TABLE` (exactly one table; multi-table DROP is rejected at parse time).
    DropTable(DropTableStmt),
    /// `CREATE [UNIQUE] INDEX`.
    CreateIndex(CreateIndexStmt),
    /// `DROP INDEX` (exactly one index).
    DropIndex(DropIndexStmt),
    /// `CREATE [OR REPLACE] VIEW`.
    CreateView(CreateViewStmt),
    /// `DROP VIEW` (exactly one view).
    DropView(DropViewStmt),
    /// `ALTER TABLE`.
    AlterTable(Box<AlterTableStmt>),
    /// `INSERT`.
    Insert(InsertStmt),
    /// A query: plain `SELECT`, a compound (`UNION` / `INTERSECT` / `EXCEPT`) or a `WITH` query.
    Select(Box<SelectQuery>),
    /// `UPDATE`.
    Update(UpdateStmt),
    /// `DELETE`.
    Delete(DeleteStmt),
    /// `TRUNCATE`.
    Truncate(TruncateStmt),
    /// `BEGIN` / `START TRANSACTION`.
    Begin,
    /// `COMMIT` (`COMMIT AND CHAIN` is rejected).
    Commit,
    /// `ROLLBACK` of the whole transaction (`ROLLBACK AND CHAIN` is rejected).
    Rollback,
    /// `SAVEPOINT name`; the name is ASCII-lowercased by the parser.
    Savepoint(String),
    /// `RELEASE [SAVEPOINT] name`; the name is ASCII-lowercased.
    ReleaseSavepoint(String),
    /// `ROLLBACK` targeting a savepoint; the name is ASCII-lowercased.
    RollbackTo(String),
    /// `SET TIME ZONE <zone>`, holding the zone text as written.
    SetTimezone(String),
    /// `EXPLAIN <statement>` (`EXPLAIN ANALYZE` is rejected).
    Explain(Box<Statement>),
}
33
/// `ALTER TABLE` statement: a single operation applied to `table`.
#[derive(Debug, Clone)]
pub struct AlterTableStmt {
    /// Target table name.
    pub table: String,
    /// The one operation this statement performs.
    pub op: AlterTableOp,
}
39
/// A single `ALTER TABLE` operation.
#[derive(Debug, Clone)]
pub enum AlterTableOp {
    /// `ADD COLUMN`.
    AddColumn {
        /// Definition of the new column.
        column: Box<ColumnSpec>,
        /// Foreign key declared inline on the new column, if any (presumably from a
        /// `REFERENCES` option — confirm in `convert_alter_table`).
        foreign_key: Option<ForeignKeyDef>,
        /// `IF NOT EXISTS` was given.
        if_not_exists: bool,
    },
    /// `DROP COLUMN`.
    DropColumn {
        /// Column to drop.
        name: String,
        /// `IF EXISTS` was given.
        if_exists: bool,
    },
    /// `RENAME COLUMN old_name TO new_name`.
    RenameColumn {
        old_name: String,
        new_name: String,
    },
    /// `RENAME TO new_name` (renames the table itself).
    RenameTable {
        new_name: String,
    },
}
59
/// `CREATE TABLE` statement.
#[derive(Debug, Clone)]
pub struct CreateTableStmt {
    /// Table name.
    pub name: String,
    /// Column definitions, in declaration order.
    pub columns: Vec<ColumnSpec>,
    /// Primary-key column names. Presumably the union of inline `PRIMARY KEY` columns and a
    /// table-level `PRIMARY KEY (...)` constraint — confirm in `convert_create_table`.
    pub primary_key: Vec<String>,
    /// `IF NOT EXISTS` was given.
    pub if_not_exists: bool,
    /// Table-level `CHECK` constraints (column-level checks are stored on [`ColumnSpec`]).
    pub check_constraints: Vec<TableCheckConstraint>,
    /// Foreign keys; `convert_create_table` also collects column-level `REFERENCES` here.
    pub foreign_keys: Vec<ForeignKeyDef>,
    /// `UNIQUE` constraints. An inline `UNIQUE` on a non-primary-key column yields an unnamed
    /// single-column entry.
    pub unique_indices: Vec<UniqueIndexDef>,
}
70
/// A `UNIQUE` constraint over one or more columns (presumably enforced through a unique index,
/// as the type name suggests).
#[derive(Debug, Clone)]
pub struct UniqueIndexDef {
    /// Constraint name, if one was given.
    pub name: Option<String>,
    /// Constrained columns; inline column-level `UNIQUE` produces ASCII-lowercased names.
    pub columns: Vec<String>,
}
76
/// A table-level `CHECK (...)` constraint.
#[derive(Debug, Clone)]
pub struct TableCheckConstraint {
    /// Constraint name, if one was given.
    pub name: Option<String>,
    /// Parsed check expression.
    pub expr: Expr,
    /// The expression rendered back to SQL by `sqlparser`.
    pub sql: String,
}
83
/// A foreign-key constraint (column-level `REFERENCES` or table-level `FOREIGN KEY`).
#[derive(Debug, Clone)]
pub struct ForeignKeyDef {
    /// Constraint name, if one was given.
    pub name: Option<String>,
    /// Referencing columns in the owning table (ASCII-lowercased for inline `REFERENCES`).
    pub columns: Vec<String>,
    /// Referenced table (ASCII-lowercased for inline `REFERENCES`).
    pub foreign_table: String,
    /// Referenced columns. For inline `REFERENCES` this is whatever the DDL listed
    /// (ASCII-lowercased), so it may be empty — presumably meaning "the referenced table's
    /// primary key"; confirm in the executor.
    pub referred_columns: Vec<String>,
    /// `ON DELETE` action.
    pub on_delete: ReferentialAction,
    /// `ON UPDATE` action.
    pub on_update: ReferentialAction,
}
93
/// Foreign-key `ON DELETE` / `ON UPDATE` action.
///
/// `#[repr(u8)]` with explicit discriminants so a value can be turned into a one-byte tag
/// (`action as u8`) and decoded again with [`ReferentialAction::from_tag`]. Do not renumber
/// existing variants: `from_tag` relies on these exact values.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum ReferentialAction {
    NoAction = 0,
    Restrict = 1,
    Cascade = 2,
    SetNull = 3,
    SetDefault = 4,
}
103
104impl ReferentialAction {
105    pub fn from_tag(tag: u8) -> Option<Self> {
106        match tag {
107            0 => Some(Self::NoAction),
108            1 => Some(Self::Restrict),
109            2 => Some(Self::Cascade),
110            3 => Some(Self::SetNull),
111            4 => Some(Self::SetDefault),
112            _ => None,
113        }
114    }
115}
116
/// Storage mode of a `GENERATED ALWAYS AS (...)` column.
///
/// When the DDL names neither `STORED` nor `VIRTUAL`, the parser picks [`GeneratedKind::Virtual`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum GeneratedKind {
    /// `STORED`.
    Stored,
    /// `VIRTUAL` (also the default when unspecified).
    Virtual,
}
122
/// A column definition after its inline options have been parsed.
#[derive(Debug, Clone)]
pub struct ColumnSpec {
    /// Column name as written (not case-folded here).
    pub name: String,
    /// Declared data type.
    pub data_type: DataType,
    /// Whether NULL values are permitted.
    pub nullable: bool,
    /// The column is (part of) the primary key.
    pub is_primary_key: bool,
    /// Parsed `DEFAULT` expression.
    pub default_expr: Option<Expr>,
    /// `DEFAULT` expression rendered back to SQL by `sqlparser`.
    pub default_sql: Option<String>,
    /// Parsed column-level `CHECK` expression.
    pub check_expr: Option<Expr>,
    /// Column-level `CHECK` expression rendered back to SQL.
    pub check_sql: Option<String>,
    /// Name of the column-level `CHECK` constraint, if one was given.
    pub check_name: Option<String>,
    /// Parsed `GENERATED ALWAYS AS (...)` expression.
    pub generated_expr: Option<Expr>,
    /// Generated expression rendered back to SQL.
    pub generated_sql: Option<String>,
    /// Storage mode of the generated column; `convert_column_def` sets this together with
    /// `generated_expr` / `generated_sql`.
    pub generated_kind: Option<GeneratedKind>,
}
138
/// `DROP TABLE [IF EXISTS] name`.
#[derive(Debug, Clone)]
pub struct DropTableStmt {
    /// Table to drop.
    pub name: String,
    /// `IF EXISTS` was given.
    pub if_exists: bool,
}
144
/// `TRUNCATE` of one or more tables.
#[derive(Debug, Clone)]
pub struct TruncateStmt {
    /// Tables to truncate.
    pub tables: Vec<String>,
}
149
/// `CREATE [UNIQUE] INDEX [IF NOT EXISTS] index_name ON table_name (columns...)`.
#[derive(Debug, Clone)]
pub struct CreateIndexStmt {
    /// Name of the new index.
    pub index_name: String,
    /// Indexed table.
    pub table_name: String,
    /// Indexed columns, in key order.
    pub columns: Vec<String>,
    /// `UNIQUE` was given.
    pub unique: bool,
    /// `IF NOT EXISTS` was given.
    pub if_not_exists: bool,
    /// Presumably the partial-index `WHERE` predicate as SQL text — confirm in
    /// `convert_create_index`.
    pub predicate_sql: Option<String>,
    /// Parsed form of `predicate_sql`.
    pub predicate_expr: Option<Expr>,
}
160
/// `DROP INDEX [IF EXISTS] index_name`.
#[derive(Debug, Clone)]
pub struct DropIndexStmt {
    /// Index to drop.
    pub index_name: String,
    /// `IF EXISTS` was given.
    pub if_exists: bool,
}
166
/// `CREATE [OR REPLACE] VIEW [IF NOT EXISTS] name [(aliases)] AS ...`.
#[derive(Debug, Clone)]
pub struct CreateViewStmt {
    /// View name.
    pub name: String,
    /// Presumably the view's defining query as SQL text — confirm in `convert_create_view`.
    pub sql: String,
    /// Explicit output column names; presumably empty when none were listed.
    pub column_aliases: Vec<String>,
    /// `OR REPLACE` was given.
    pub or_replace: bool,
    /// `IF NOT EXISTS` was given.
    pub if_not_exists: bool,
}
175
/// `DROP VIEW [IF EXISTS] name`.
#[derive(Debug, Clone)]
pub struct DropViewStmt {
    /// View to drop.
    pub name: String,
    /// `IF EXISTS` was given.
    pub if_exists: bool,
}
181
/// Where the rows of an `INSERT` come from.
#[derive(Debug, Clone)]
pub enum InsertSource {
    /// `VALUES (...), (...)`: one inner vector per row.
    Values(Vec<Vec<Expr>>),
    /// `INSERT ... SELECT` (or any query body).
    Select(Box<SelectQuery>),
}
187
/// `INSERT INTO table [(columns)] <source> [ON CONFLICT ...] [RETURNING ...]`.
#[derive(Debug, Clone)]
pub struct InsertStmt {
    /// Target table.
    pub table: String,
    /// Explicit target columns; presumably empty when the column list is omitted.
    pub columns: Vec<String>,
    /// Row source.
    pub source: InsertSource,
    /// `ON CONFLICT` clause, if any.
    pub on_conflict: Option<OnConflictClause>,
    /// `RETURNING` list, if any.
    pub returning: Option<Vec<SelectColumn>>,
}
196
/// `ON CONFLICT [target] DO ...` clause of an `INSERT`.
#[derive(Debug, Clone)]
pub struct OnConflictClause {
    /// Conflict target; `None` for a bare `ON CONFLICT`.
    pub target: Option<ConflictTarget>,
    /// What to do when a conflict occurs.
    pub action: OnConflictAction,
}
202
/// What an `ON CONFLICT` clause is matched against.
#[derive(Debug, Clone)]
pub enum ConflictTarget {
    /// `ON CONFLICT (col, ...)`.
    Columns(Vec<String>),
    /// `ON CONFLICT ON CONSTRAINT name`.
    Constraint(String),
}
208
/// Action taken by `ON CONFLICT`.
#[derive(Debug, Clone)]
pub enum OnConflictAction {
    /// `DO NOTHING`.
    DoNothing,
    /// `DO UPDATE SET col = expr, ... [WHERE ...]`.
    DoUpdate {
        /// `(column, value)` pairs from the `SET` list.
        assignments: Vec<(String, Expr)>,
        /// Optional `WHERE` filter on the update.
        where_clause: Option<Expr>,
    },
}
217
/// A table in a `FROM` / `JOIN` list, with its optional alias.
#[derive(Debug, Clone)]
pub struct TableRef {
    /// Table (or view / CTE) name.
    pub name: String,
    /// `AS alias`, if given.
    pub alias: Option<String>,
}
223
/// Kind of join between the `FROM` table and a [`JoinClause`] table.
///
/// Derives `Eq` in addition to `PartialEq` (a fieldless enum has a total equality), matching the
/// other operator enums in this module such as `BinOp` and `UnaryOp`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum JoinType {
    /// `[INNER] JOIN`.
    Inner,
    /// `CROSS JOIN` / comma join.
    Cross,
    /// `LEFT [OUTER] JOIN`.
    Left,
    /// `RIGHT [OUTER] JOIN`.
    Right,
}
231
/// One `JOIN` appended to a `SELECT`'s `FROM` table.
#[derive(Debug, Clone)]
pub struct JoinClause {
    /// Join kind.
    pub join_type: JoinType,
    /// Joined table.
    pub table: TableRef,
    /// `ON` condition; presumably `None` for cross joins — confirm in `convert_query`.
    pub on_clause: Option<Expr>,
}
238
/// A single (non-compound) `SELECT`.
#[derive(Debug, Clone)]
pub struct SelectStmt {
    /// Projection list.
    pub columns: Vec<SelectColumn>,
    /// `FROM` table name. How a `SELECT` with no `FROM` is represented is not visible here —
    /// TODO confirm (possibly an empty string).
    pub from: String,
    /// Alias of the `FROM` table, if any.
    pub from_alias: Option<String>,
    /// Joins, in source order.
    pub joins: Vec<JoinClause>,
    /// `SELECT DISTINCT`.
    pub distinct: bool,
    /// `WHERE` predicate.
    pub where_clause: Option<Expr>,
    /// `ORDER BY` items.
    pub order_by: Vec<OrderByItem>,
    /// `LIMIT` expression.
    pub limit: Option<Expr>,
    /// `OFFSET` expression.
    pub offset: Option<Expr>,
    /// `GROUP BY` expressions.
    pub group_by: Vec<Expr>,
    /// `HAVING` predicate.
    pub having: Option<Expr>,
}
253
/// Set operator combining two query bodies in a [`CompoundSelect`].
///
/// A fieldless enum, so it derives `Copy`, `PartialEq` and `Eq` like the other operator enums
/// in this module (`BinOp`, `UnaryOp`, `JoinType`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SetOp {
    /// `UNION`.
    Union,
    /// `INTERSECT`.
    Intersect,
    /// `EXCEPT`.
    Except,
}
260
/// `left <op> [ALL] right`, with the trailing `ORDER BY` / `LIMIT` / `OFFSET` that applies to
/// the combined result.
#[derive(Debug, Clone)]
pub struct CompoundSelect {
    /// Set operator.
    pub op: SetOp,
    /// `ALL` was given (keep duplicates).
    pub all: bool,
    /// Left operand.
    pub left: Box<QueryBody>,
    /// Right operand.
    pub right: Box<QueryBody>,
    /// `ORDER BY` over the combined result.
    pub order_by: Vec<OrderByItem>,
    /// `LIMIT` over the combined result.
    pub limit: Option<Expr>,
    /// `OFFSET` over the combined result.
    pub offset: Option<Expr>,
}
271
/// Body of a query or CTE: a plain `SELECT`, a set-operation compound, or a data-modifying
/// statement (the DML variants presumably come from `WITH x AS (INSERT ... RETURNING ...)` CTEs —
/// confirm in `convert_query`).
#[derive(Debug, Clone)]
pub enum QueryBody {
    Select(Box<SelectStmt>),
    Compound(Box<CompoundSelect>),
    Insert(Box<InsertStmt>),
    Update(Box<UpdateStmt>),
    Delete(Box<DeleteStmt>),
}
280
/// One `name [(aliases)] AS (body)` entry of a `WITH` clause.
#[derive(Debug, Clone)]
pub struct CteDefinition {
    /// CTE name.
    pub name: String,
    /// Explicit column names; presumably empty when none were listed.
    pub column_aliases: Vec<String>,
    /// The CTE's query body.
    pub body: QueryBody,
}
287
/// A complete query: optional `WITH [RECURSIVE]` CTEs followed by the main body.
#[derive(Debug, Clone)]
pub struct SelectQuery {
    /// CTEs in declaration order; empty when there is no `WITH` clause.
    pub ctes: Vec<CteDefinition>,
    /// `WITH RECURSIVE` was given.
    pub recursive: bool,
    /// Main query body.
    pub body: QueryBody,
}
294
/// `UPDATE table SET col = expr, ... [WHERE ...] [RETURNING ...]`.
#[derive(Debug, Clone)]
pub struct UpdateStmt {
    /// Target table.
    pub table: String,
    /// `(column, value)` pairs from the `SET` list.
    pub assignments: Vec<(String, Expr)>,
    /// `WHERE` filter.
    pub where_clause: Option<Expr>,
    /// `RETURNING` list, if any.
    pub returning: Option<Vec<SelectColumn>>,
}
302
/// `DELETE FROM table [WHERE ...] [RETURNING ...]`.
#[derive(Debug, Clone)]
pub struct DeleteStmt {
    /// Target table.
    pub table: String,
    /// `WHERE` filter.
    pub where_clause: Option<Expr>,
    /// `RETURNING` list, if any.
    pub returning: Option<Vec<SelectColumn>>,
}
309
/// One entry of a projection or `RETURNING` list.
#[derive(Debug, Clone)]
pub enum SelectColumn {
    /// `*`.
    AllColumns,
    /// Presumably `OLD.*` in a `RETURNING` list — confirm in the converter.
    AllFromOld,
    /// Presumably `NEW.*` in a `RETURNING` list — confirm in the converter.
    AllFromNew,
    /// `expr [AS alias]`.
    Expr { expr: Expr, alias: Option<String> },
}
317
/// One `ORDER BY` key.
#[derive(Debug, Clone)]
pub struct OrderByItem {
    /// Sort key.
    pub expr: Expr,
    /// `DESC` was given.
    pub descending: bool,
    /// `Some(true)` for `NULLS FIRST`, `Some(false)` for `NULLS LAST`, `None` when unspecified.
    pub nulls_first: Option<bool>,
}
324
/// Scalar expression tree used for projections, predicates, defaults, CHECK and GENERATED
/// expressions.
#[derive(Debug, Clone)]
pub enum Expr {
    /// Constant value.
    Literal(Value),
    /// Unqualified column reference.
    Column(String),
    /// `table.column` reference; `table` is the qualifier as written.
    QualifiedColumn {
        table: String,
        column: String,
    },
    /// Binary operator application.
    BinaryOp {
        left: Box<Expr>,
        op: BinOp,
        right: Box<Expr>,
    },
    /// Unary operator application.
    UnaryOp {
        op: UnaryOp,
        expr: Box<Expr>,
    },
    /// `expr IS NULL`.
    IsNull(Box<Expr>),
    /// `expr IS NOT NULL`.
    IsNotNull(Box<Expr>),
    /// Scalar or aggregate function call.
    Function {
        name: String,
        args: Vec<Expr>,
        /// True for aggregate forms like `COUNT(DISTINCT x)`.
        distinct: bool,
    },
    /// `COUNT(*)`.
    CountStar,
    /// `expr [NOT] IN (SELECT ...)`.
    InSubquery {
        expr: Box<Expr>,
        subquery: Box<SelectStmt>,
        negated: bool,
    },
    /// `expr [NOT] IN (e1, e2, ...)`.
    InList {
        expr: Box<Expr>,
        list: Vec<Expr>,
        negated: bool,
    },
    /// `[NOT] EXISTS (SELECT ...)`.
    Exists {
        subquery: Box<SelectStmt>,
        negated: bool,
    },
    /// `(SELECT ...)` used as a scalar value.
    ScalarSubquery(Box<SelectStmt>),
    /// Membership test against a prebuilt hash set of constant values. Presumably an optimized
    /// form of `InList` where `has_null` records a NULL among the list items — confirm where it
    /// is constructed.
    InSet {
        expr: Box<Expr>,
        values: rustc_hash::FxHashSet<Value>,
        has_null: bool,
        negated: bool,
    },
    /// `expr [NOT] BETWEEN low AND high`.
    Between {
        expr: Box<Expr>,
        low: Box<Expr>,
        high: Box<Expr>,
        negated: bool,
    },
    /// `expr [NOT] LIKE pattern [ESCAPE escape]`.
    Like {
        expr: Box<Expr>,
        pattern: Box<Expr>,
        escape: Option<Box<Expr>>,
        negated: bool,
    },
    /// `CASE [operand] WHEN cond THEN result ... [ELSE else_result] END`.
    Case {
        operand: Option<Box<Expr>>,
        /// `(WHEN, THEN)` pairs in source order.
        conditions: Vec<(Expr, Expr)>,
        else_result: Option<Box<Expr>>,
    },
    /// `COALESCE(a, b, ...)`.
    Coalesce(Vec<Expr>),
    /// `CAST(expr AS data_type)`.
    Cast {
        expr: Box<Expr>,
        data_type: DataType,
    },
    /// Positional bind parameter `$N` (see `count_params`).
    Parameter(usize),
    /// Window function call `name(args) OVER (spec)`.
    WindowFunction {
        name: String,
        args: Vec<Expr>,
        spec: WindowSpec,
    },
}
401
/// The `OVER (...)` part of a window function call.
#[derive(Debug, Clone)]
pub struct WindowSpec {
    /// `PARTITION BY` expressions.
    pub partition_by: Vec<Expr>,
    /// `ORDER BY` items within each partition.
    pub order_by: Vec<OrderByItem>,
    /// Explicit frame clause, if any.
    pub frame: Option<WindowFrame>,
}
408
/// Window frame: `{ROWS | RANGE | GROUPS} BETWEEN start AND end`.
#[derive(Debug, Clone)]
pub struct WindowFrame {
    /// Frame units.
    pub units: WindowFrameUnits,
    /// Start bound.
    pub start: WindowFrameBound,
    /// End bound.
    pub end: WindowFrameBound,
}
415
/// Units in which a [`WindowFrame`] is measured.
///
/// Derives `PartialEq`/`Eq` like the other fieldless enums in this module, so callers can
/// compare units directly instead of pattern-matching.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WindowFrameUnits {
    /// `ROWS`.
    Rows,
    /// `RANGE`.
    Range,
    /// `GROUPS`.
    Groups,
}
422
/// One end of a window frame.
#[derive(Debug, Clone)]
pub enum WindowFrameBound {
    /// `UNBOUNDED PRECEDING`.
    UnboundedPreceding,
    /// `<offset> PRECEDING`.
    Preceding(Box<Expr>),
    /// `CURRENT ROW`.
    CurrentRow,
    /// `<offset> FOLLOWING`.
    Following(Box<Expr>),
    /// `UNBOUNDED FOLLOWING`.
    UnboundedFollowing,
}
431
/// Binary operators.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BinOp {
    /// `+`
    Add,
    /// `-`
    Sub,
    /// `*`
    Mul,
    /// `/`
    Div,
    /// `%`
    Mod,
    /// `=`
    Eq,
    /// `<>` / `!=`
    NotEq,
    /// `<`
    Lt,
    /// `>`
    Gt,
    /// `<=`
    LtEq,
    /// `>=`
    GtEq,
    /// `AND`
    And,
    /// `OR`
    Or,
    /// `||` string concatenation.
    Concat,
}
449
/// Unary operators.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum UnaryOp {
    /// Arithmetic negation (`-x`).
    Neg,
    /// Logical `NOT`.
    Not,
}
455
456pub fn has_subquery(expr: &Expr) -> bool {
457    match expr {
458        Expr::InSubquery { .. } | Expr::Exists { .. } | Expr::ScalarSubquery(_) => true,
459        Expr::BinaryOp { left, right, .. } => has_subquery(left) || has_subquery(right),
460        Expr::UnaryOp { expr, .. } => has_subquery(expr),
461        Expr::IsNull(e) | Expr::IsNotNull(e) => has_subquery(e),
462        Expr::InList { expr, list, .. } => has_subquery(expr) || list.iter().any(has_subquery),
463        Expr::InSet { expr, .. } => has_subquery(expr),
464        Expr::Between {
465            expr, low, high, ..
466        } => has_subquery(expr) || has_subquery(low) || has_subquery(high),
467        Expr::Like {
468            expr,
469            pattern,
470            escape,
471            ..
472        } => {
473            has_subquery(expr)
474                || has_subquery(pattern)
475                || escape.as_ref().is_some_and(|e| has_subquery(e))
476        }
477        Expr::Case {
478            operand,
479            conditions,
480            else_result,
481        } => {
482            operand.as_ref().is_some_and(|e| has_subquery(e))
483                || conditions
484                    .iter()
485                    .any(|(c, r)| has_subquery(c) || has_subquery(r))
486                || else_result.as_ref().is_some_and(|e| has_subquery(e))
487        }
488        Expr::Coalesce(args) | Expr::Function { args, .. } => args.iter().any(has_subquery),
489        Expr::Cast { expr, .. } => has_subquery(expr),
490        _ => false,
491    }
492}
493
494pub fn parse_sql_expr(sql: &str) -> Result<Expr> {
495    let dialect = GenericDialect {};
496    let mut parser = Parser::new(&dialect)
497        .try_with_sql(sql)
498        .map_err(|e| SqlError::Parse(e.to_string()))?;
499    let sp_expr = parser
500        .parse_expr()
501        .map_err(|e| SqlError::Parse(e.to_string()))?;
502    convert_expr(&sp_expr)
503}
504
505pub fn parse_sql(sql: &str) -> Result<Statement> {
506    let dialect = GenericDialect {};
507    let stmts = Parser::parse_sql(&dialect, sql).map_err(|e| SqlError::Parse(e.to_string()))?;
508
509    if stmts.is_empty() {
510        return Err(SqlError::Parse("empty SQL".into()));
511    }
512    if stmts.len() > 1 {
513        return Err(SqlError::Unsupported("multiple statements".into()));
514    }
515
516    convert_statement(stmts.into_iter().next().unwrap())
517}
518
519/// Parse one or more `;`-separated SQL statements.
520pub fn parse_sql_multi(sql: &str) -> Result<Vec<Statement>> {
521    let dialect = GenericDialect {};
522    let stmts = Parser::parse_sql(&dialect, sql).map_err(|e| SqlError::Parse(e.to_string()))?;
523
524    if stmts.is_empty() {
525        return Err(SqlError::Parse("empty SQL".into()));
526    }
527
528    stmts.into_iter().map(convert_statement).collect()
529}
530
531/// Returns the number of distinct parameters in a statement (max $N found).
532pub fn count_params(stmt: &Statement) -> usize {
533    let mut max_idx = 0usize;
534    visit_exprs_stmt(stmt, &mut |e| {
535        if let Expr::Parameter(n) = e {
536            max_idx = max_idx.max(*n);
537        }
538    });
539    max_idx
540}
541
542fn visit_exprs_stmt(stmt: &Statement, visitor: &mut impl FnMut(&Expr)) {
543    match stmt {
544        Statement::Select(sq) => {
545            for cte in &sq.ctes {
546                visit_exprs_query_body(&cte.body, visitor);
547            }
548            visit_exprs_query_body(&sq.body, visitor);
549        }
550        Statement::Insert(ins) => match &ins.source {
551            InsertSource::Values(rows) => {
552                for row in rows {
553                    for e in row {
554                        visit_expr(e, visitor);
555                    }
556                }
557            }
558            InsertSource::Select(sq) => {
559                for cte in &sq.ctes {
560                    visit_exprs_query_body(&cte.body, visitor);
561                }
562                visit_exprs_query_body(&sq.body, visitor);
563            }
564        },
565        Statement::Update(upd) => {
566            for (_, e) in &upd.assignments {
567                visit_expr(e, visitor);
568            }
569            if let Some(w) = &upd.where_clause {
570                visit_expr(w, visitor);
571            }
572        }
573        Statement::Delete(del) => {
574            if let Some(w) = &del.where_clause {
575                visit_expr(w, visitor);
576            }
577        }
578        Statement::Explain(inner) => visit_exprs_stmt(inner, visitor),
579        _ => {}
580    }
581}
582
583fn visit_exprs_query_body(body: &QueryBody, visitor: &mut impl FnMut(&Expr)) {
584    match body {
585        QueryBody::Select(sel) => visit_exprs_select(sel, visitor),
586        QueryBody::Compound(comp) => {
587            visit_exprs_query_body(&comp.left, visitor);
588            visit_exprs_query_body(&comp.right, visitor);
589            for o in &comp.order_by {
590                visit_expr(&o.expr, visitor);
591            }
592            if let Some(l) = &comp.limit {
593                visit_expr(l, visitor);
594            }
595            if let Some(o) = &comp.offset {
596                visit_expr(o, visitor);
597            }
598        }
599        QueryBody::Insert(ins) => visit_exprs_stmt(&Statement::Insert((**ins).clone()), visitor),
600        QueryBody::Update(upd) => visit_exprs_stmt(&Statement::Update((**upd).clone()), visitor),
601        QueryBody::Delete(del) => visit_exprs_stmt(&Statement::Delete((**del).clone()), visitor),
602    }
603}
604
605fn visit_exprs_select(sel: &SelectStmt, visitor: &mut impl FnMut(&Expr)) {
606    for col in &sel.columns {
607        if let SelectColumn::Expr { expr, .. } = col {
608            visit_expr(expr, visitor);
609        }
610    }
611    for j in &sel.joins {
612        if let Some(on) = &j.on_clause {
613            visit_expr(on, visitor);
614        }
615    }
616    if let Some(w) = &sel.where_clause {
617        visit_expr(w, visitor);
618    }
619    for o in &sel.order_by {
620        visit_expr(&o.expr, visitor);
621    }
622    if let Some(l) = &sel.limit {
623        visit_expr(l, visitor);
624    }
625    if let Some(o) = &sel.offset {
626        visit_expr(o, visitor);
627    }
628    for g in &sel.group_by {
629        visit_expr(g, visitor);
630    }
631    if let Some(h) = &sel.having {
632        visit_expr(h, visitor);
633    }
634}
635
636fn visit_expr(expr: &Expr, visitor: &mut impl FnMut(&Expr)) {
637    visitor(expr);
638    match expr {
639        Expr::BinaryOp { left, right, .. } => {
640            visit_expr(left, visitor);
641            visit_expr(right, visitor);
642        }
643        Expr::UnaryOp { expr: e, .. } | Expr::IsNull(e) | Expr::IsNotNull(e) => {
644            visit_expr(e, visitor);
645        }
646        Expr::Function { args, .. } | Expr::Coalesce(args) => {
647            for a in args {
648                visit_expr(a, visitor);
649            }
650        }
651        Expr::InSubquery {
652            expr: e, subquery, ..
653        } => {
654            visit_expr(e, visitor);
655            visit_exprs_select(subquery, visitor);
656        }
657        Expr::InList { expr: e, list, .. } => {
658            visit_expr(e, visitor);
659            for l in list {
660                visit_expr(l, visitor);
661            }
662        }
663        Expr::Exists { subquery, .. } => visit_exprs_select(subquery, visitor),
664        Expr::ScalarSubquery(sq) => visit_exprs_select(sq, visitor),
665        Expr::InSet { expr: e, .. } => visit_expr(e, visitor),
666        Expr::Between {
667            expr: e, low, high, ..
668        } => {
669            visit_expr(e, visitor);
670            visit_expr(low, visitor);
671            visit_expr(high, visitor);
672        }
673        Expr::Like {
674            expr: e,
675            pattern,
676            escape,
677            ..
678        } => {
679            visit_expr(e, visitor);
680            visit_expr(pattern, visitor);
681            if let Some(esc) = escape {
682                visit_expr(esc, visitor);
683            }
684        }
685        Expr::Case {
686            operand,
687            conditions,
688            else_result,
689        } => {
690            if let Some(op) = operand {
691                visit_expr(op, visitor);
692            }
693            for (cond, then) in conditions {
694                visit_expr(cond, visitor);
695                visit_expr(then, visitor);
696            }
697            if let Some(el) = else_result {
698                visit_expr(el, visitor);
699            }
700        }
701        Expr::Cast { expr: e, .. } => visit_expr(e, visitor),
702        Expr::WindowFunction { args, spec, .. } => {
703            for a in args {
704                visit_expr(a, visitor);
705            }
706            for p in &spec.partition_by {
707                visit_expr(p, visitor);
708            }
709            for o in &spec.order_by {
710                visit_expr(&o.expr, visitor);
711            }
712            if let Some(ref frame) = spec.frame {
713                if let WindowFrameBound::Preceding(e) | WindowFrameBound::Following(e) =
714                    &frame.start
715                {
716                    visit_expr(e, visitor);
717                }
718                if let WindowFrameBound::Preceding(e) | WindowFrameBound::Following(e) = &frame.end
719                {
720                    visit_expr(e, visitor);
721                }
722            }
723        }
724        Expr::Literal(_)
725        | Expr::Column(_)
726        | Expr::QualifiedColumn { .. }
727        | Expr::CountStar
728        | Expr::Parameter(_) => {}
729    }
730}
731
732fn convert_statement(stmt: sp::Statement) -> Result<Statement> {
733    match stmt {
734        sp::Statement::CreateTable(ct) => convert_create_table(ct),
735        sp::Statement::CreateIndex(ci) => convert_create_index(ci),
736        sp::Statement::Drop {
737            object_type: sp::ObjectType::Table,
738            if_exists,
739            names,
740            ..
741        } => {
742            if names.len() != 1 {
743                return Err(SqlError::Unsupported("multi-table DROP".into()));
744            }
745            Ok(Statement::DropTable(DropTableStmt {
746                name: object_name_to_string(&names[0]),
747                if_exists,
748            }))
749        }
750        sp::Statement::Drop {
751            object_type: sp::ObjectType::Index,
752            if_exists,
753            names,
754            ..
755        } => {
756            if names.len() != 1 {
757                return Err(SqlError::Unsupported("multi-index DROP".into()));
758            }
759            Ok(Statement::DropIndex(DropIndexStmt {
760                index_name: object_name_to_string(&names[0]),
761                if_exists,
762            }))
763        }
764        sp::Statement::CreateView(cv) => convert_create_view(cv),
765        sp::Statement::Drop {
766            object_type: sp::ObjectType::View,
767            if_exists,
768            names,
769            ..
770        } => {
771            if names.len() != 1 {
772                return Err(SqlError::Unsupported("multi-view DROP".into()));
773            }
774            Ok(Statement::DropView(DropViewStmt {
775                name: object_name_to_string(&names[0]),
776                if_exists,
777            }))
778        }
779        sp::Statement::AlterTable(at) => convert_alter_table(at),
780        sp::Statement::Insert(insert) => convert_insert(insert),
781        sp::Statement::Query(query) => convert_query(*query),
782        sp::Statement::Update(update) => convert_update(update),
783        sp::Statement::Delete(delete) => convert_delete(delete),
784        sp::Statement::Truncate(t) => convert_truncate(t),
785        sp::Statement::StartTransaction { .. } => Ok(Statement::Begin),
786        sp::Statement::Commit { chain: true, .. } => {
787            Err(SqlError::Unsupported("COMMIT AND CHAIN".into()))
788        }
789        sp::Statement::Commit { .. } => Ok(Statement::Commit),
790        sp::Statement::Rollback { chain: true, .. } => {
791            Err(SqlError::Unsupported("ROLLBACK AND CHAIN".into()))
792        }
793        sp::Statement::Rollback {
794            savepoint: Some(name),
795            ..
796        } => Ok(Statement::RollbackTo(name.value.to_ascii_lowercase())),
797        sp::Statement::Rollback { .. } => Ok(Statement::Rollback),
798        sp::Statement::Savepoint { name } => {
799            Ok(Statement::Savepoint(name.value.to_ascii_lowercase()))
800        }
801        sp::Statement::ReleaseSavepoint { name } => {
802            Ok(Statement::ReleaseSavepoint(name.value.to_ascii_lowercase()))
803        }
804        sp::Statement::Set(sp::Set::SetTimeZone { value, .. }) => {
805            // Accept a string literal or bare identifier (PG allows `SET TIME ZONE UTC`).
806            let zone = match value {
807                sp::Expr::Value(v) => match &v.value {
808                    sp::Value::SingleQuotedString(s) => s.clone(),
809                    sp::Value::DoubleQuotedString(s) => s.clone(),
810                    other => other.to_string(),
811                },
812                sp::Expr::Identifier(ident) => ident.value.clone(),
813                other => {
814                    return Err(SqlError::Parse(format!(
815                        "SET TIME ZONE expects a string literal or identifier, got: {other}"
816                    )))
817                }
818            };
819            Ok(Statement::SetTimezone(zone))
820        }
821        sp::Statement::Explain {
822            statement, analyze, ..
823        } => {
824            if analyze {
825                return Err(SqlError::Unsupported("EXPLAIN ANALYZE".into()));
826            }
827            let inner = convert_statement(*statement)?;
828            Ok(Statement::Explain(Box::new(inner)))
829        }
830        _ => Err(SqlError::Unsupported(format!("statement type: {}", stmt))),
831    }
832}
833
834/// Parse column options (NOT NULL, DEFAULT, CHECK, FK, UNIQUE) from a sqlparser ColumnDef.
835/// Returns (ColumnSpec, Option<ForeignKeyDef>, was_inline_pk, was_unique).
836fn convert_column_def(
837    col_def: &sp::ColumnDef,
838) -> Result<(ColumnSpec, Option<ForeignKeyDef>, bool, bool)> {
839    let col_name = col_def.name.value.clone();
840    let data_type = convert_data_type(&col_def.data_type)?;
841    let mut nullable = true;
842    let mut is_primary_key = false;
843    let mut is_unique = false;
844    let mut default_expr = None;
845    let mut default_sql = None;
846    let mut check_expr = None;
847    let mut check_sql = None;
848    let mut check_name = None;
849    let mut generated_expr = None;
850    let mut generated_sql = None;
851    let mut generated_kind = None;
852    let mut fk_def = None;
853
854    for opt in &col_def.options {
855        match &opt.option {
856            sp::ColumnOption::NotNull => nullable = false,
857            sp::ColumnOption::Null => nullable = true,
858            sp::ColumnOption::PrimaryKey(_) => {
859                is_primary_key = true;
860                nullable = false;
861            }
862            sp::ColumnOption::Unique(_) => is_unique = true,
863            sp::ColumnOption::Default(expr) => {
864                default_sql = Some(expr.to_string());
865                default_expr = Some(convert_expr(expr)?);
866            }
867            sp::ColumnOption::Check(check) => {
868                check_sql = Some(check.expr.to_string());
869                let converted = convert_expr(&check.expr)?;
870                if has_subquery(&converted) {
871                    return Err(SqlError::Unsupported("subquery in CHECK constraint".into()));
872                }
873                check_expr = Some(converted);
874                check_name = check.name.as_ref().map(|n| n.value.clone());
875            }
876            sp::ColumnOption::ForeignKey(fk) => {
877                let (on_delete, on_update) = convert_fk_actions(&fk.on_delete, &fk.on_update)?;
878                let ftable = object_name_to_string(&fk.foreign_table).to_ascii_lowercase();
879                let referred: Vec<String> = fk
880                    .referred_columns
881                    .iter()
882                    .map(|i| i.value.to_ascii_lowercase())
883                    .collect();
884                fk_def = Some(ForeignKeyDef {
885                    name: fk.name.as_ref().map(|n| n.value.clone()),
886                    columns: vec![col_name.to_ascii_lowercase()],
887                    foreign_table: ftable,
888                    referred_columns: referred,
889                    on_delete,
890                    on_update,
891                });
892            }
893            sp::ColumnOption::Generated {
894                generation_expr,
895                generation_expr_mode,
896                sequence_options: _,
897                ..
898            } => {
899                let Some(expr) = generation_expr else {
900                    return Err(SqlError::Unsupported(
901                        "identity columns not yet supported; use INTEGER PRIMARY KEY for autoincrement".into(),
902                    ));
903                };
904                let mode = generation_expr_mode.unwrap_or(sp::GeneratedExpressionMode::Virtual);
905                let converted = convert_expr(expr)?;
906                reject_aggregate_or_window(expr, "GENERATED")?;
907                if has_subquery(&converted) {
908                    return Err(SqlError::Unsupported(
909                        "subquery in GENERATED expression".into(),
910                    ));
911                }
912                reject_volatile_in_generated(&converted)?;
913                generated_sql = Some(expr.to_string());
914                generated_expr = Some(converted);
915                generated_kind = Some(match mode {
916                    sp::GeneratedExpressionMode::Stored => GeneratedKind::Stored,
917                    sp::GeneratedExpressionMode::Virtual => GeneratedKind::Virtual,
918                });
919            }
920            _ => {}
921        }
922    }
923
924    if generated_kind.is_some() {
925        if default_expr.is_some() {
926            return Err(SqlError::Unsupported(
927                "DEFAULT and GENERATED cannot be combined".into(),
928            ));
929        }
930        if is_primary_key {
931            return Err(SqlError::Unsupported(
932                "GENERATED column cannot be PRIMARY KEY".into(),
933            ));
934        }
935    }
936
937    let spec = ColumnSpec {
938        name: col_name,
939        data_type,
940        nullable,
941        is_primary_key,
942        default_expr,
943        default_sql,
944        check_expr,
945        check_sql,
946        check_name,
947        generated_expr,
948        generated_sql,
949        generated_kind,
950    };
951    Ok((spec, fk_def, is_primary_key, is_unique))
952}
953
954fn reject_volatile_in_generated(expr: &Expr) -> Result<()> {
955    fn walk(e: &Expr) -> Result<()> {
956        match e {
957            Expr::Function { name, args, .. } => {
958                let upper = name.to_ascii_uppercase();
959                if matches!(
960                    upper.as_str(),
961                    "RANDOM"
962                        | "NOW"
963                        | "CURRENT_TIMESTAMP"
964                        | "CURRENT_DATE"
965                        | "CURRENT_TIME"
966                        | "CLOCK_TIMESTAMP"
967                        | "STATEMENT_TIMESTAMP"
968                        | "TRANSACTION_TIMESTAMP"
969                        | "LOCALTIMESTAMP"
970                        | "LOCALTIME"
971                ) {
972                    return Err(SqlError::Unsupported(format!(
973                        "volatile function {name}() not allowed in GENERATED expression"
974                    )));
975                }
976                for a in args {
977                    walk(a)?;
978                }
979                Ok(())
980            }
981            Expr::BinaryOp { left, right, .. } => {
982                walk(left)?;
983                walk(right)
984            }
985            Expr::UnaryOp { expr, .. } => walk(expr),
986            Expr::Cast { expr, .. } => walk(expr),
987            Expr::Case {
988                operand,
989                conditions,
990                else_result,
991            } => {
992                if let Some(o) = operand {
993                    walk(o)?;
994                }
995                for (cond, res) in conditions {
996                    walk(cond)?;
997                    walk(res)?;
998                }
999                if let Some(e) = else_result {
1000                    walk(e)?;
1001                }
1002                Ok(())
1003            }
1004            Expr::Coalesce(items) => items.iter().try_for_each(walk),
1005            _ => Ok(()),
1006        }
1007    }
1008    walk(expr)
1009}
1010
1011fn convert_create_table(ct: sp::CreateTable) -> Result<Statement> {
1012    let name = object_name_to_string(&ct.name);
1013    let if_not_exists = ct.if_not_exists;
1014
1015    let mut columns = Vec::new();
1016    let mut inline_pk: Vec<String> = Vec::new();
1017    let mut foreign_keys: Vec<ForeignKeyDef> = Vec::new();
1018    let mut unique_indices: Vec<UniqueIndexDef> = Vec::new();
1019
1020    for col_def in &ct.columns {
1021        let (spec, fk_def, was_pk, was_unique) = convert_column_def(col_def)?;
1022        if was_pk {
1023            inline_pk.push(spec.name.clone());
1024        }
1025        if let Some(fk) = fk_def {
1026            foreign_keys.push(fk);
1027        }
1028        if was_unique && !was_pk {
1029            unique_indices.push(UniqueIndexDef {
1030                name: None,
1031                columns: vec![spec.name.to_ascii_lowercase()],
1032            });
1033        }
1034        columns.push(spec);
1035    }
1036
1037    let mut check_constraints: Vec<TableCheckConstraint> = Vec::new();
1038
1039    for constraint in &ct.constraints {
1040        match constraint {
1041            sp::TableConstraint::PrimaryKey(pk_constraint) => {
1042                for idx_col in &pk_constraint.columns {
1043                    let col_name = match &idx_col.column.expr {
1044                        sp::Expr::Identifier(ident) => ident.value.clone(),
1045                        _ => continue,
1046                    };
1047                    if !inline_pk.contains(&col_name) {
1048                        inline_pk.push(col_name.clone());
1049                    }
1050                    if let Some(col) = columns.iter_mut().find(|c| c.name == col_name) {
1051                        col.nullable = false;
1052                        col.is_primary_key = true;
1053                    }
1054                }
1055            }
1056            sp::TableConstraint::Check(check) => {
1057                let sql = check.expr.to_string();
1058                let converted = convert_expr(&check.expr)?;
1059                if has_subquery(&converted) {
1060                    return Err(SqlError::Unsupported("subquery in CHECK constraint".into()));
1061                }
1062                check_constraints.push(TableCheckConstraint {
1063                    name: check.name.as_ref().map(|n| n.value.clone()),
1064                    expr: converted,
1065                    sql,
1066                });
1067            }
1068            sp::TableConstraint::ForeignKey(fk) => {
1069                let (on_delete, on_update) = convert_fk_actions(&fk.on_delete, &fk.on_update)?;
1070                let cols: Vec<String> = fk
1071                    .columns
1072                    .iter()
1073                    .map(|i| i.value.to_ascii_lowercase())
1074                    .collect();
1075                let ftable = object_name_to_string(&fk.foreign_table).to_ascii_lowercase();
1076                let referred: Vec<String> = fk
1077                    .referred_columns
1078                    .iter()
1079                    .map(|i| i.value.to_ascii_lowercase())
1080                    .collect();
1081                foreign_keys.push(ForeignKeyDef {
1082                    name: fk.name.as_ref().map(|n| n.value.clone()),
1083                    columns: cols,
1084                    foreign_table: ftable,
1085                    referred_columns: referred,
1086                    on_delete,
1087                    on_update,
1088                });
1089            }
1090            sp::TableConstraint::Unique(u) => {
1091                let cols: Vec<String> = u
1092                    .columns
1093                    .iter()
1094                    .filter_map(|idx_col| match &idx_col.column.expr {
1095                        sp::Expr::Identifier(ident) => Some(ident.value.to_ascii_lowercase()),
1096                        _ => None,
1097                    })
1098                    .collect();
1099                if !cols.is_empty() {
1100                    unique_indices.push(UniqueIndexDef {
1101                        name: u.name.as_ref().map(|n| n.value.clone()),
1102                        columns: cols,
1103                    });
1104                }
1105            }
1106            _ => {}
1107        }
1108    }
1109
1110    Ok(Statement::CreateTable(CreateTableStmt {
1111        name,
1112        columns,
1113        primary_key: inline_pk,
1114        if_not_exists,
1115        check_constraints,
1116        foreign_keys,
1117        unique_indices,
1118    }))
1119}
1120
1121fn convert_alter_table(at: sp::AlterTable) -> Result<Statement> {
1122    let table = object_name_to_string(&at.name);
1123    if at.operations.len() != 1 {
1124        return Err(SqlError::Unsupported(
1125            "ALTER TABLE with multiple operations".into(),
1126        ));
1127    }
1128    let op = match at.operations.into_iter().next().unwrap() {
1129        sp::AlterTableOperation::AddColumn {
1130            column_def,
1131            if_not_exists,
1132            ..
1133        } => {
1134            let (spec, fk, _was_pk, _was_unique) = convert_column_def(&column_def)?;
1135            AlterTableOp::AddColumn {
1136                column: Box::new(spec),
1137                foreign_key: fk,
1138                if_not_exists,
1139            }
1140        }
1141        sp::AlterTableOperation::DropColumn {
1142            column_names,
1143            if_exists,
1144            ..
1145        } => {
1146            if column_names.len() != 1 {
1147                return Err(SqlError::Unsupported(
1148                    "DROP COLUMN with multiple columns".into(),
1149                ));
1150            }
1151            AlterTableOp::DropColumn {
1152                name: column_names.into_iter().next().unwrap().value,
1153                if_exists,
1154            }
1155        }
1156        sp::AlterTableOperation::RenameColumn {
1157            old_column_name,
1158            new_column_name,
1159        } => AlterTableOp::RenameColumn {
1160            old_name: old_column_name.value,
1161            new_name: new_column_name.value,
1162        },
1163        sp::AlterTableOperation::RenameTable { table_name } => {
1164            let new_name = match table_name {
1165                sp::RenameTableNameKind::To(name) | sp::RenameTableNameKind::As(name) => {
1166                    object_name_to_string(&name)
1167                }
1168            };
1169            AlterTableOp::RenameTable { new_name }
1170        }
1171        other => {
1172            return Err(SqlError::Unsupported(format!(
1173                "ALTER TABLE operation: {other}"
1174            )));
1175        }
1176    };
1177    Ok(Statement::AlterTable(Box::new(AlterTableStmt {
1178        table,
1179        op,
1180    })))
1181}
1182
1183fn convert_fk_actions(
1184    on_delete: &Option<sp::ReferentialAction>,
1185    on_update: &Option<sp::ReferentialAction>,
1186) -> Result<(ReferentialAction, ReferentialAction)> {
1187    Ok((convert_fk_action(on_delete)?, convert_fk_action(on_update)?))
1188}
1189
1190fn convert_fk_action(action: &Option<sp::ReferentialAction>) -> Result<ReferentialAction> {
1191    match action {
1192        None | Some(sp::ReferentialAction::NoAction) => Ok(ReferentialAction::NoAction),
1193        Some(sp::ReferentialAction::Restrict) => Ok(ReferentialAction::Restrict),
1194        Some(sp::ReferentialAction::Cascade) => Ok(ReferentialAction::Cascade),
1195        Some(sp::ReferentialAction::SetNull) => Ok(ReferentialAction::SetNull),
1196        Some(sp::ReferentialAction::SetDefault) => Ok(ReferentialAction::SetDefault),
1197    }
1198}
1199
1200fn convert_create_index(ci: sp::CreateIndex) -> Result<Statement> {
1201    let index_name = ci
1202        .name
1203        .as_ref()
1204        .map(object_name_to_string)
1205        .ok_or_else(|| SqlError::Parse("index name required".into()))?;
1206
1207    let table_name = object_name_to_string(&ci.table_name);
1208
1209    let columns: Vec<String> = ci
1210        .columns
1211        .iter()
1212        .map(|idx_col| match &idx_col.column.expr {
1213            sp::Expr::Identifier(ident) => Ok(ident.value.clone()),
1214            other => Err(SqlError::Unsupported(format!("expression index: {other}"))),
1215        })
1216        .collect::<Result<_>>()?;
1217
1218    if columns.is_empty() {
1219        return Err(SqlError::Parse(
1220            "index must have at least one column".into(),
1221        ));
1222    }
1223
1224    let (predicate_sql, predicate_expr) = match &ci.predicate {
1225        Some(sp_expr) => {
1226            let expr = convert_expr(sp_expr)?;
1227            validate_partial_index_predicate(&expr)?;
1228            (Some(sp_expr.to_string()), Some(expr))
1229        }
1230        None => (None, None),
1231    };
1232
1233    Ok(Statement::CreateIndex(CreateIndexStmt {
1234        index_name,
1235        table_name,
1236        columns,
1237        unique: ci.unique,
1238        if_not_exists: ci.if_not_exists,
1239        predicate_sql,
1240        predicate_expr,
1241    }))
1242}
1243
1244fn validate_partial_index_predicate(expr: &Expr) -> Result<()> {
1245    let mut bad: Option<&'static str> = None;
1246    visit_expr(expr, &mut |e| {
1247        if bad.is_some() {
1248            return;
1249        }
1250        match e {
1251            Expr::ScalarSubquery(_) | Expr::Exists { .. } | Expr::InSubquery { .. } => {
1252                bad = Some("subqueries");
1253            }
1254            Expr::CountStar => bad = Some("aggregates"),
1255            Expr::WindowFunction { .. } => bad = Some("window functions"),
1256            Expr::Parameter(_) => bad = Some("bound parameters"),
1257            Expr::QualifiedColumn { .. } => bad = Some("cross-table references"),
1258            Expr::Function { name, .. } => {
1259                if is_aggregate_function(name) {
1260                    bad = Some("aggregates");
1261                } else if !is_immutable_function(name) {
1262                    bad = Some("non-deterministic functions");
1263                }
1264            }
1265            _ => {}
1266        }
1267    });
1268    if let Some(reason) = bad {
1269        return Err(SqlError::Unsupported(format!(
1270            "partial index predicate cannot contain {reason}"
1271        )));
1272    }
1273    Ok(())
1274}
1275
/// Returns `true` when `name` (compared ASCII case-insensitively) is one of
/// the aggregate function names recognised by the parser.
fn is_aggregate_function(name: &str) -> bool {
    const AGGREGATES: [&str; 8] = [
        "count",
        "sum",
        "avg",
        "min",
        "max",
        "total",
        "group_concat",
        "string_agg",
    ];
    AGGREGATES
        .iter()
        .any(|candidate| candidate.eq_ignore_ascii_case(name))
}
1282
/// Returns `false` for functions whose result can differ between calls with
/// the same arguments (clock readers and random generators), and `true` for
/// everything else.
///
/// This is a denylist: any function not listed is assumed deterministic.
/// Names are compared ASCII case-insensitively. The list includes the
/// `clock_timestamp` / `statement_timestamp` / `transaction_timestamp` family,
/// matching the volatile functions the GENERATED-column check in this file rejects.
fn is_immutable_function(name: &str) -> bool {
    const VOLATILE: [&str; 11] = [
        "now",
        "current_timestamp",
        "current_date",
        "current_time",
        "localtimestamp",
        "localtime",
        "clock_timestamp",
        "statement_timestamp",
        "transaction_timestamp",
        "random",
        "rand",
    ];
    !VOLATILE
        .iter()
        .any(|volatile| volatile.eq_ignore_ascii_case(name))
}
1296
1297fn convert_create_view(cv: sp::CreateView) -> Result<Statement> {
1298    let name = object_name_to_string(&cv.name);
1299
1300    if cv.materialized {
1301        return Err(SqlError::Unsupported("MATERIALIZED VIEW".into()));
1302    }
1303
1304    let sql = cv.query.to_string();
1305
1306    let dialect = GenericDialect {};
1307    let test = Parser::parse_sql(&dialect, &sql).map_err(|e| SqlError::Parse(e.to_string()))?;
1308    if test.is_empty() {
1309        return Err(SqlError::Parse("empty view definition".into()));
1310    }
1311    match &test[0] {
1312        sp::Statement::Query(_) => {}
1313        _ => {
1314            return Err(SqlError::Parse(
1315                "view body must be a SELECT statement".into(),
1316            ))
1317        }
1318    }
1319
1320    let column_aliases: Vec<String> = cv
1321        .columns
1322        .iter()
1323        .map(|c| c.name.value.to_ascii_lowercase())
1324        .collect();
1325
1326    Ok(Statement::CreateView(CreateViewStmt {
1327        name,
1328        sql,
1329        column_aliases,
1330        or_replace: cv.or_replace,
1331        if_not_exists: cv.if_not_exists,
1332    }))
1333}
1334
1335fn convert_insert(insert: sp::Insert) -> Result<Statement> {
1336    let table = match &insert.table {
1337        sp::TableObject::TableName(name) => object_name_to_string(name).to_ascii_lowercase(),
1338        _ => return Err(SqlError::Unsupported("INSERT into non-table object".into())),
1339    };
1340
1341    let columns: Vec<String> = insert
1342        .columns
1343        .iter()
1344        .map(|c| c.value.to_ascii_lowercase())
1345        .collect();
1346
1347    let query = insert
1348        .source
1349        .ok_or_else(|| SqlError::Parse("INSERT requires VALUES or SELECT".into()))?;
1350
1351    let source = match *query.body {
1352        sp::SetExpr::Values(sp::Values { rows, .. }) => {
1353            let mut result = Vec::new();
1354            for row in rows {
1355                let mut exprs = Vec::new();
1356                for expr in row {
1357                    exprs.push(convert_expr(&expr)?);
1358                }
1359                result.push(exprs);
1360            }
1361            InsertSource::Values(result)
1362        }
1363        _ => {
1364            let (ctes, recursive) = if let Some(ref with) = query.with {
1365                convert_with(with)?
1366            } else {
1367                (vec![], false)
1368            };
1369            let body = convert_query_body(&query)?;
1370            InsertSource::Select(Box::new(SelectQuery {
1371                ctes,
1372                recursive,
1373                body,
1374            }))
1375        }
1376    };
1377
1378    let on_conflict = insert.on.as_ref().map(convert_on_insert).transpose()?;
1379    let returning = convert_returning(insert.returning.as_deref())?;
1380
1381    Ok(Statement::Insert(InsertStmt {
1382        table,
1383        columns,
1384        source,
1385        on_conflict,
1386        returning,
1387    }))
1388}
1389
1390fn convert_on_insert(on: &sp::OnInsert) -> Result<OnConflictClause> {
1391    match on {
1392        sp::OnInsert::OnConflict(oc) => {
1393            let target = oc
1394                .conflict_target
1395                .as_ref()
1396                .map(convert_conflict_target)
1397                .transpose()?;
1398            let action = convert_on_conflict_action(&oc.action)?;
1399            Ok(OnConflictClause { target, action })
1400        }
1401        sp::OnInsert::DuplicateKeyUpdate(_) => Err(SqlError::Parse(
1402            "ON DUPLICATE KEY UPDATE is MySQL-specific; use ON CONFLICT".into(),
1403        )),
1404        _ => Err(SqlError::Parse("unsupported ON INSERT clause".into())),
1405    }
1406}
1407
1408fn convert_conflict_target(target: &sp::ConflictTarget) -> Result<ConflictTarget> {
1409    match target {
1410        sp::ConflictTarget::Columns(cols) => Ok(ConflictTarget::Columns(
1411            cols.iter().map(|c| c.value.to_ascii_lowercase()).collect(),
1412        )),
1413        sp::ConflictTarget::OnConstraint(name) => {
1414            if name.0.len() > 1 {
1415                return Err(SqlError::Parse(
1416                    "qualified constraint names not supported".into(),
1417                ));
1418            }
1419            Ok(ConflictTarget::Constraint(
1420                object_name_to_string(name).to_ascii_lowercase(),
1421            ))
1422        }
1423    }
1424}
1425
1426fn convert_on_conflict_action(action: &sp::OnConflictAction) -> Result<OnConflictAction> {
1427    match action {
1428        sp::OnConflictAction::DoNothing => Ok(OnConflictAction::DoNothing),
1429        sp::OnConflictAction::DoUpdate(du) => {
1430            let assignments = du
1431                .assignments
1432                .iter()
1433                .map(|a| {
1434                    let col = match &a.target {
1435                        sp::AssignmentTarget::ColumnName(name) => {
1436                            object_name_to_string(name).to_ascii_lowercase()
1437                        }
1438                        _ => {
1439                            return Err(SqlError::Unsupported(
1440                                "tuple assignment in ON CONFLICT".into(),
1441                            ))
1442                        }
1443                    };
1444                    let expr = convert_expr(&a.value)?;
1445                    Ok((col, expr))
1446                })
1447                .collect::<Result<_>>()?;
1448            let where_clause = du.selection.as_ref().map(convert_expr).transpose()?;
1449            Ok(OnConflictAction::DoUpdate {
1450                assignments,
1451                where_clause,
1452            })
1453        }
1454    }
1455}
1456
1457fn convert_select_body(select: &sp::Select) -> Result<SelectStmt> {
1458    let distinct = match &select.distinct {
1459        Some(sp::Distinct::Distinct) => true,
1460        Some(sp::Distinct::On(_)) => {
1461            return Err(SqlError::Unsupported("DISTINCT ON".into()));
1462        }
1463        _ => false,
1464    };
1465
1466    let (from, from_alias, joins) = if select.from.is_empty() {
1467        (String::new(), None, vec![])
1468    } else if select.from.len() == 1 {
1469        let table_with_joins = &select.from[0];
1470        let (name, alias) = match &table_with_joins.relation {
1471            sp::TableFactor::Table { name, alias, .. } => {
1472                let table_name = object_name_to_string(name);
1473                let alias_str = alias.as_ref().map(|a| a.name.value.clone());
1474                (table_name, alias_str)
1475            }
1476            _ => return Err(SqlError::Unsupported("non-table FROM source".into())),
1477        };
1478        let j = table_with_joins
1479            .joins
1480            .iter()
1481            .map(convert_join)
1482            .collect::<Result<Vec<_>>>()?;
1483        (name, alias, j)
1484    } else {
1485        return Err(SqlError::Unsupported("comma-separated FROM tables".into()));
1486    };
1487
1488    let columns: Vec<SelectColumn> = select
1489        .projection
1490        .iter()
1491        .map(convert_select_item)
1492        .collect::<Result<_>>()?;
1493
1494    let where_clause = select.selection.as_ref().map(convert_expr).transpose()?;
1495
1496    let group_by = match &select.group_by {
1497        sp::GroupByExpr::Expressions(exprs, _) => {
1498            exprs.iter().map(convert_expr).collect::<Result<_>>()?
1499        }
1500        sp::GroupByExpr::All(_) => {
1501            return Err(SqlError::Unsupported("GROUP BY ALL".into()));
1502        }
1503    };
1504
1505    let having = select.having.as_ref().map(convert_expr).transpose()?;
1506
1507    Ok(SelectStmt {
1508        columns,
1509        from,
1510        from_alias,
1511        joins,
1512        distinct,
1513        where_clause,
1514        order_by: vec![],
1515        limit: None,
1516        offset: None,
1517        group_by,
1518        having,
1519    })
1520}
1521
1522fn convert_set_expr(set_expr: &sp::SetExpr) -> Result<QueryBody> {
1523    match set_expr {
1524        sp::SetExpr::Select(sel) => Ok(QueryBody::Select(Box::new(convert_select_body(sel)?))),
1525        sp::SetExpr::Insert(stmt) => match convert_statement(stmt.clone())? {
1526            Statement::Insert(ins) => Ok(QueryBody::Insert(Box::new(ins))),
1527            _ => Err(SqlError::Parse("expected INSERT in WITH-DML body".into())),
1528        },
1529        sp::SetExpr::Update(stmt) => match convert_statement(stmt.clone())? {
1530            Statement::Update(upd) => Ok(QueryBody::Update(Box::new(upd))),
1531            _ => Err(SqlError::Parse("expected UPDATE in WITH-DML body".into())),
1532        },
1533        sp::SetExpr::Delete(stmt) => match convert_statement(stmt.clone())? {
1534            Statement::Delete(del) => Ok(QueryBody::Delete(Box::new(del))),
1535            _ => Err(SqlError::Parse("expected DELETE in WITH-DML body".into())),
1536        },
1537        sp::SetExpr::SetOperation {
1538            op,
1539            set_quantifier,
1540            left,
1541            right,
1542        } => {
1543            let set_op = match op {
1544                sp::SetOperator::Union => SetOp::Union,
1545                sp::SetOperator::Intersect => SetOp::Intersect,
1546                sp::SetOperator::Except | sp::SetOperator::Minus => SetOp::Except,
1547            };
1548            let all = match set_quantifier {
1549                sp::SetQuantifier::All => true,
1550                sp::SetQuantifier::None | sp::SetQuantifier::Distinct => false,
1551                _ => {
1552                    return Err(SqlError::Unsupported("BY NAME set operations".into()));
1553                }
1554            };
1555            Ok(QueryBody::Compound(Box::new(CompoundSelect {
1556                op: set_op,
1557                all,
1558                left: Box::new(convert_set_expr(left)?),
1559                right: Box::new(convert_set_expr(right)?),
1560                order_by: vec![],
1561                limit: None,
1562                offset: None,
1563            })))
1564        }
1565        _ => Err(SqlError::Unsupported("unsupported set expression".into())),
1566    }
1567}
1568
1569fn convert_query_body(query: &sp::Query) -> Result<QueryBody> {
1570    let mut body = convert_set_expr(&query.body)?;
1571
1572    let order_by = if let Some(ref ob) = query.order_by {
1573        match &ob.kind {
1574            sp::OrderByKind::Expressions(exprs) => exprs
1575                .iter()
1576                .map(convert_order_by_expr)
1577                .collect::<Result<_>>()?,
1578            sp::OrderByKind::All { .. } => {
1579                return Err(SqlError::Unsupported("ORDER BY ALL".into()));
1580            }
1581        }
1582    } else {
1583        vec![]
1584    };
1585
1586    let (limit, offset) = match &query.limit_clause {
1587        Some(sp::LimitClause::LimitOffset { limit, offset, .. }) => {
1588            let l = limit.as_ref().map(convert_expr).transpose()?;
1589            let o = offset
1590                .as_ref()
1591                .map(|o| convert_expr(&o.value))
1592                .transpose()?;
1593            (l, o)
1594        }
1595        Some(sp::LimitClause::OffsetCommaLimit { limit, offset }) => {
1596            let l = Some(convert_expr(limit)?);
1597            let o = Some(convert_expr(offset)?);
1598            (l, o)
1599        }
1600        None => (None, None),
1601    };
1602
1603    match &mut body {
1604        QueryBody::Select(sel) => {
1605            sel.order_by = order_by;
1606            sel.limit = limit;
1607            sel.offset = offset;
1608        }
1609        QueryBody::Compound(comp) => {
1610            comp.order_by = order_by;
1611            comp.limit = limit;
1612            comp.offset = offset;
1613        }
1614        QueryBody::Insert(_) | QueryBody::Update(_) | QueryBody::Delete(_) => {
1615            if !order_by.is_empty() || limit.is_some() || offset.is_some() {
1616                return Err(SqlError::Parse(
1617                    "ORDER BY / LIMIT / OFFSET not allowed on DML CTE body".into(),
1618                ));
1619            }
1620        }
1621    }
1622
1623    Ok(body)
1624}
1625
1626fn convert_subquery(query: &sp::Query) -> Result<SelectStmt> {
1627    if query.with.is_some() {
1628        return Err(SqlError::Unsupported("CTEs in subqueries".into()));
1629    }
1630    match convert_query_body(query)? {
1631        QueryBody::Select(s) => Ok(*s),
1632        QueryBody::Compound(_) => Err(SqlError::Unsupported(
1633            "UNION/INTERSECT/EXCEPT in subqueries".into(),
1634        )),
1635        QueryBody::Insert(_) | QueryBody::Update(_) | QueryBody::Delete(_) => Err(
1636            SqlError::Unsupported("WITH-DML inside subqueries (PG forbids)".into()),
1637        ),
1638    }
1639}
1640
1641fn convert_with(with: &sp::With) -> Result<(Vec<CteDefinition>, bool)> {
1642    let mut names = rustc_hash::FxHashSet::default();
1643    let mut ctes = Vec::new();
1644    for cte in &with.cte_tables {
1645        let name = cte.alias.name.value.to_ascii_lowercase();
1646        if !names.insert(name.clone()) {
1647            return Err(SqlError::DuplicateCteName(name));
1648        }
1649        let column_aliases: Vec<String> = cte
1650            .alias
1651            .columns
1652            .iter()
1653            .map(|c| c.name.value.to_ascii_lowercase())
1654            .collect();
1655        let body = convert_query_body(&cte.query)?;
1656        ctes.push(CteDefinition {
1657            name,
1658            column_aliases,
1659            body,
1660        });
1661    }
1662    Ok((ctes, with.recursive))
1663}
1664
1665fn convert_query(query: sp::Query) -> Result<Statement> {
1666    let (ctes, recursive) = if let Some(ref with) = query.with {
1667        convert_with(with)?
1668    } else {
1669        (vec![], false)
1670    };
1671    let body = convert_query_body(&query)?;
1672    Ok(Statement::Select(Box::new(SelectQuery {
1673        ctes,
1674        recursive,
1675        body,
1676    })))
1677}
1678
1679fn convert_join(join: &sp::Join) -> Result<JoinClause> {
1680    let (join_type, constraint) = match &join.join_operator {
1681        sp::JoinOperator::Inner(c) => (JoinType::Inner, Some(c)),
1682        sp::JoinOperator::Join(c) => (JoinType::Inner, Some(c)),
1683        sp::JoinOperator::CrossJoin(c) => (JoinType::Cross, Some(c)),
1684        sp::JoinOperator::Left(c) => (JoinType::Left, Some(c)),
1685        sp::JoinOperator::LeftSemi(c) => (JoinType::Left, Some(c)),
1686        sp::JoinOperator::LeftAnti(c) => (JoinType::Left, Some(c)),
1687        sp::JoinOperator::Right(c) => (JoinType::Right, Some(c)),
1688        sp::JoinOperator::RightSemi(c) => (JoinType::Right, Some(c)),
1689        sp::JoinOperator::RightAnti(c) => (JoinType::Right, Some(c)),
1690        other => return Err(SqlError::Unsupported(format!("join type: {other:?}"))),
1691    };
1692
1693    let (name, alias) = match &join.relation {
1694        sp::TableFactor::Table { name, alias, .. } => {
1695            let table_name = object_name_to_string(name);
1696            let alias_str = alias.as_ref().map(|a| a.name.value.clone());
1697            (table_name, alias_str)
1698        }
1699        _ => return Err(SqlError::Unsupported("non-table JOIN source".into())),
1700    };
1701
1702    let on_clause = match constraint {
1703        Some(sp::JoinConstraint::On(expr)) => Some(convert_expr(expr)?),
1704        Some(sp::JoinConstraint::None) | None => None,
1705        Some(other) => return Err(SqlError::Unsupported(format!("join constraint: {other:?}"))),
1706    };
1707
1708    Ok(JoinClause {
1709        join_type,
1710        table: TableRef { name, alias },
1711        on_clause,
1712    })
1713}
1714
1715fn convert_update(update: sp::Update) -> Result<Statement> {
1716    let table = match &update.table.relation {
1717        sp::TableFactor::Table { name, .. } => object_name_to_string(name),
1718        _ => return Err(SqlError::Unsupported("non-table UPDATE target".into())),
1719    };
1720
1721    let assignments = update
1722        .assignments
1723        .iter()
1724        .map(|a| {
1725            let col = match &a.target {
1726                sp::AssignmentTarget::ColumnName(name) => object_name_to_string(name),
1727                _ => return Err(SqlError::Unsupported("tuple assignment".into())),
1728            };
1729            let expr = convert_expr(&a.value)?;
1730            Ok((col, expr))
1731        })
1732        .collect::<Result<_>>()?;
1733
1734    let where_clause = update.selection.as_ref().map(convert_expr).transpose()?;
1735    let returning = convert_returning(update.returning.as_deref())?;
1736
1737    Ok(Statement::Update(UpdateStmt {
1738        table,
1739        assignments,
1740        where_clause,
1741        returning,
1742    }))
1743}
1744
1745fn convert_truncate(t: sp::Truncate) -> Result<Statement> {
1746    if matches!(t.cascade, Some(sp::CascadeOption::Cascade)) {
1747        return Err(SqlError::Unsupported(
1748            "TRUNCATE CASCADE is planned for v0.13".into(),
1749        ));
1750    }
1751    if t.if_exists {
1752        return Err(SqlError::Unsupported("TRUNCATE IF EXISTS".into()));
1753    }
1754    if t.partitions.is_some() {
1755        return Err(SqlError::Unsupported("TRUNCATE PARTITION".into()));
1756    }
1757    if t.on_cluster.is_some() {
1758        return Err(SqlError::Unsupported("TRUNCATE ON CLUSTER".into()));
1759    }
1760    if t.table_names.is_empty() {
1761        return Err(SqlError::Parse(
1762            "TRUNCATE requires at least one table".into(),
1763        ));
1764    }
1765
1766    let tables: Vec<String> = t
1767        .table_names
1768        .iter()
1769        .map(|tt| object_name_to_string(&tt.name))
1770        .collect();
1771
1772    Ok(Statement::Truncate(TruncateStmt { tables }))
1773}
1774
1775fn convert_delete(delete: sp::Delete) -> Result<Statement> {
1776    let table_name = match &delete.from {
1777        sp::FromTable::WithFromKeyword(tables) => {
1778            if tables.len() != 1 {
1779                return Err(SqlError::Unsupported("multi-table DELETE".into()));
1780            }
1781            match &tables[0].relation {
1782                sp::TableFactor::Table { name, .. } => object_name_to_string(name),
1783                _ => return Err(SqlError::Unsupported("non-table DELETE target".into())),
1784            }
1785        }
1786        sp::FromTable::WithoutKeyword(tables) => {
1787            if tables.len() != 1 {
1788                return Err(SqlError::Unsupported("multi-table DELETE".into()));
1789            }
1790            match &tables[0].relation {
1791                sp::TableFactor::Table { name, .. } => object_name_to_string(name),
1792                _ => return Err(SqlError::Unsupported("non-table DELETE target".into())),
1793            }
1794        }
1795    };
1796
1797    let where_clause = delete.selection.as_ref().map(convert_expr).transpose()?;
1798    let returning = convert_returning(delete.returning.as_deref())?;
1799
1800    Ok(Statement::Delete(DeleteStmt {
1801        table: table_name,
1802        where_clause,
1803        returning,
1804    }))
1805}
1806
1807fn convert_expr(expr: &sp::Expr) -> Result<Expr> {
1808    match expr {
1809        sp::Expr::Value(v) => convert_value(&v.value),
1810        sp::Expr::Identifier(ident) => Ok(Expr::Column(ident.value.to_ascii_lowercase())),
1811        sp::Expr::CompoundIdentifier(parts) => {
1812            if parts.len() == 2 {
1813                Ok(Expr::QualifiedColumn {
1814                    table: parts[0].value.to_ascii_lowercase(),
1815                    column: parts[1].value.to_ascii_lowercase(),
1816                })
1817            } else {
1818                Ok(Expr::Column(
1819                    parts.last().unwrap().value.to_ascii_lowercase(),
1820                ))
1821            }
1822        }
1823        sp::Expr::BinaryOp { left, op, right } => {
1824            let bin_op = convert_bin_op(op)?;
1825            Ok(Expr::BinaryOp {
1826                left: Box::new(convert_expr(left)?),
1827                op: bin_op,
1828                right: Box::new(convert_expr(right)?),
1829            })
1830        }
1831        sp::Expr::UnaryOp { op, expr } => {
1832            let unary_op = match op {
1833                sp::UnaryOperator::Minus => UnaryOp::Neg,
1834                sp::UnaryOperator::Not => UnaryOp::Not,
1835                _ => return Err(SqlError::Unsupported(format!("unary op: {op}"))),
1836            };
1837            Ok(Expr::UnaryOp {
1838                op: unary_op,
1839                expr: Box::new(convert_expr(expr)?),
1840            })
1841        }
1842        sp::Expr::IsNull(e) => Ok(Expr::IsNull(Box::new(convert_expr(e)?))),
1843        sp::Expr::IsNotNull(e) => Ok(Expr::IsNotNull(Box::new(convert_expr(e)?))),
1844        sp::Expr::Nested(e) => convert_expr(e),
1845        sp::Expr::Function(func) => convert_function(func),
1846        sp::Expr::InSubquery {
1847            expr: e,
1848            subquery,
1849            negated,
1850        } => {
1851            let inner_expr = convert_expr(e)?;
1852            let stmt = convert_subquery(subquery)?;
1853            Ok(Expr::InSubquery {
1854                expr: Box::new(inner_expr),
1855                subquery: Box::new(stmt),
1856                negated: *negated,
1857            })
1858        }
1859        sp::Expr::InList {
1860            expr: e,
1861            list,
1862            negated,
1863        } => {
1864            let inner_expr = convert_expr(e)?;
1865            let items = list.iter().map(convert_expr).collect::<Result<Vec<_>>>()?;
1866            Ok(Expr::InList {
1867                expr: Box::new(inner_expr),
1868                list: items,
1869                negated: *negated,
1870            })
1871        }
1872        sp::Expr::Exists { subquery, negated } => {
1873            let stmt = convert_subquery(subquery)?;
1874            Ok(Expr::Exists {
1875                subquery: Box::new(stmt),
1876                negated: *negated,
1877            })
1878        }
1879        sp::Expr::Subquery(query) => {
1880            let stmt = convert_subquery(query)?;
1881            Ok(Expr::ScalarSubquery(Box::new(stmt)))
1882        }
1883        sp::Expr::Between {
1884            expr: e,
1885            negated,
1886            low,
1887            high,
1888        } => Ok(Expr::Between {
1889            expr: Box::new(convert_expr(e)?),
1890            low: Box::new(convert_expr(low)?),
1891            high: Box::new(convert_expr(high)?),
1892            negated: *negated,
1893        }),
1894        sp::Expr::Like {
1895            expr: e,
1896            negated,
1897            pattern,
1898            escape_char,
1899            ..
1900        } => {
1901            let esc = escape_char
1902                .as_ref()
1903                .map(convert_escape_value)
1904                .transpose()?
1905                .map(Box::new);
1906            Ok(Expr::Like {
1907                expr: Box::new(convert_expr(e)?),
1908                pattern: Box::new(convert_expr(pattern)?),
1909                escape: esc,
1910                negated: *negated,
1911            })
1912        }
1913        sp::Expr::ILike {
1914            expr: e,
1915            negated,
1916            pattern,
1917            escape_char,
1918            ..
1919        } => {
1920            let esc = escape_char
1921                .as_ref()
1922                .map(convert_escape_value)
1923                .transpose()?
1924                .map(Box::new);
1925            Ok(Expr::Like {
1926                expr: Box::new(convert_expr(e)?),
1927                pattern: Box::new(convert_expr(pattern)?),
1928                escape: esc,
1929                negated: *negated,
1930            })
1931        }
1932        sp::Expr::Case {
1933            operand,
1934            conditions,
1935            else_result,
1936            ..
1937        } => {
1938            let op = operand
1939                .as_ref()
1940                .map(|e| convert_expr(e))
1941                .transpose()?
1942                .map(Box::new);
1943            let conds: Vec<(Expr, Expr)> = conditions
1944                .iter()
1945                .map(|cw| Ok((convert_expr(&cw.condition)?, convert_expr(&cw.result)?)))
1946                .collect::<Result<_>>()?;
1947            let else_r = else_result
1948                .as_ref()
1949                .map(|e| convert_expr(e))
1950                .transpose()?
1951                .map(Box::new);
1952            Ok(Expr::Case {
1953                operand: op,
1954                conditions: conds,
1955                else_result: else_r,
1956            })
1957        }
1958        sp::Expr::Cast {
1959            expr: e,
1960            data_type: dt,
1961            ..
1962        } => {
1963            let target = convert_data_type(dt)?;
1964            Ok(Expr::Cast {
1965                expr: Box::new(convert_expr(e)?),
1966                data_type: target,
1967            })
1968        }
1969        sp::Expr::Substring {
1970            expr: e,
1971            substring_from,
1972            substring_for,
1973            ..
1974        } => {
1975            let mut args = vec![convert_expr(e)?];
1976            if let Some(from) = substring_from {
1977                args.push(convert_expr(from)?);
1978            }
1979            if let Some(f) = substring_for {
1980                args.push(convert_expr(f)?);
1981            }
1982            Ok(Expr::Function {
1983                name: "SUBSTR".into(),
1984                args,
1985                distinct: false,
1986            })
1987        }
1988        sp::Expr::Trim {
1989            expr: e,
1990            trim_where,
1991            trim_what,
1992            trim_characters,
1993        } => {
1994            let fn_name = match trim_where {
1995                Some(sp::TrimWhereField::Leading) => "LTRIM",
1996                Some(sp::TrimWhereField::Trailing) => "RTRIM",
1997                _ => "TRIM",
1998            };
1999            let mut args = vec![convert_expr(e)?];
2000            if let Some(what) = trim_what {
2001                args.push(convert_expr(what)?);
2002            } else if let Some(chars) = trim_characters {
2003                if let Some(first) = chars.first() {
2004                    args.push(convert_expr(first)?);
2005                }
2006            }
2007            Ok(Expr::Function {
2008                name: fn_name.into(),
2009                args,
2010                distinct: false,
2011            })
2012        }
2013        sp::Expr::Ceil { expr: e, .. } => Ok(Expr::Function {
2014            name: "CEIL".into(),
2015            args: vec![convert_expr(e)?],
2016            distinct: false,
2017        }),
2018        sp::Expr::Floor { expr: e, .. } => Ok(Expr::Function {
2019            name: "FLOOR".into(),
2020            args: vec![convert_expr(e)?],
2021            distinct: false,
2022        }),
2023        sp::Expr::Position { expr: e, r#in } => Ok(Expr::Function {
2024            name: "INSTR".into(),
2025            args: vec![convert_expr(r#in)?, convert_expr(e)?],
2026            distinct: false,
2027        }),
2028        // Typed literal: `DATE '2024-01-15'`, `TIMESTAMP '...'`, etc.
2029        sp::Expr::TypedString(ts) => {
2030            let raw = match &ts.value.value {
2031                sp::Value::SingleQuotedString(s) => s.clone(),
2032                sp::Value::DoubleQuotedString(s) => s.clone(),
2033                other => other.to_string(),
2034            };
2035            convert_typed_string(&ts.data_type, &raw)
2036        }
2037        // INTERVAL '...' — sqlparser emits Expr::Interval with a boxed Expr value + field qualifiers
2038        sp::Expr::Interval(iv) => convert_interval_expr(iv),
2039        // EXTRACT(field FROM src)
2040        sp::Expr::Extract { field, expr: e, .. } => {
2041            let field_name = match field {
2042                sp::DateTimeField::Year => "year",
2043                sp::DateTimeField::Month => "month",
2044                sp::DateTimeField::Week(_) => "week",
2045                sp::DateTimeField::Day => "day",
2046                sp::DateTimeField::Date => "day",
2047                sp::DateTimeField::Hour => "hour",
2048                sp::DateTimeField::Minute => "minute",
2049                sp::DateTimeField::Second => "second",
2050                sp::DateTimeField::Millisecond => "milliseconds",
2051                sp::DateTimeField::Microsecond => "microseconds",
2052                sp::DateTimeField::Microseconds => "microseconds",
2053                sp::DateTimeField::Milliseconds => "milliseconds",
2054                sp::DateTimeField::Dow => "dow",
2055                sp::DateTimeField::Isodow => "isodow",
2056                sp::DateTimeField::Doy => "doy",
2057                sp::DateTimeField::Epoch => "epoch",
2058                sp::DateTimeField::Quarter => "quarter",
2059                sp::DateTimeField::Decade => "decade",
2060                sp::DateTimeField::Century => "century",
2061                sp::DateTimeField::Millennium => "millennium",
2062                sp::DateTimeField::Isoyear => "isoyear",
2063                sp::DateTimeField::Julian => "julian",
2064                other => {
2065                    return Err(SqlError::InvalidExtractField(format!("{other:?}")));
2066                }
2067            };
2068            Ok(Expr::Function {
2069                name: "EXTRACT".into(),
2070                args: vec![
2071                    Expr::Literal(Value::Text(field_name.into())),
2072                    convert_expr(e)?,
2073                ],
2074                distinct: false,
2075            })
2076        }
2077        // `AT TIME ZONE 'zone'` operator — desugars to AT_TIMEZONE(ts, zone) scalar function.
2078        sp::Expr::AtTimeZone {
2079            timestamp,
2080            time_zone,
2081        } => Ok(Expr::Function {
2082            name: "AT_TIMEZONE".into(),
2083            args: vec![convert_expr(timestamp)?, convert_expr(time_zone)?],
2084            distinct: false,
2085        }),
2086        _ => Err(SqlError::Unsupported(format!("expression: {expr}"))),
2087    }
2088}
2089
2090fn convert_value(val: &sp::Value) -> Result<Expr> {
2091    match val {
2092        sp::Value::Number(n, _) => {
2093            if let Ok(i) = n.parse::<i64>() {
2094                Ok(Expr::Literal(Value::Integer(i)))
2095            } else if let Ok(f) = n.parse::<f64>() {
2096                Ok(Expr::Literal(Value::Real(f)))
2097            } else {
2098                Err(SqlError::InvalidValue(format!("cannot parse number: {n}")))
2099            }
2100        }
2101        sp::Value::SingleQuotedString(s) => Ok(Expr::Literal(Value::Text(s.as_str().into()))),
2102        sp::Value::Boolean(b) => Ok(Expr::Literal(Value::Boolean(*b))),
2103        sp::Value::Null => Ok(Expr::Literal(Value::Null)),
2104        sp::Value::Placeholder(s) => {
2105            let idx_str = s
2106                .strip_prefix('$')
2107                .ok_or_else(|| SqlError::Parse(format!("invalid placeholder: {s}")))?;
2108            let idx: usize = idx_str
2109                .parse()
2110                .map_err(|_| SqlError::Parse(format!("invalid placeholder index: {s}")))?;
2111            if idx == 0 {
2112                return Err(SqlError::Parse("placeholder index must be >= 1".into()));
2113            }
2114            Ok(Expr::Parameter(idx))
2115        }
2116        _ => Err(SqlError::Unsupported(format!("value type: {val}"))),
2117    }
2118}
2119
2120fn convert_typed_string(dt: &sp::DataType, value: &str) -> Result<Expr> {
2121    let s = value.trim_matches('\'');
2122    match dt {
2123        sp::DataType::Date => {
2124            let d = crate::datetime::parse_date(s)?;
2125            Ok(Expr::Literal(Value::Date(d)))
2126        }
2127        sp::DataType::Time(_, _) => {
2128            let t = crate::datetime::parse_time(s)?;
2129            Ok(Expr::Literal(Value::Time(t)))
2130        }
2131        sp::DataType::Timestamp(_, _) => {
2132            let t = crate::datetime::parse_timestamp(s)?;
2133            Ok(Expr::Literal(Value::Timestamp(t)))
2134        }
2135        sp::DataType::Interval { .. } => {
2136            let (months, days, micros) = crate::datetime::parse_interval(s)?;
2137            Ok(Expr::Literal(Value::Interval {
2138                months,
2139                days,
2140                micros,
2141            }))
2142        }
2143        _ => {
2144            let target = convert_data_type(dt)?;
2145            Ok(Expr::Cast {
2146                expr: Box::new(Expr::Literal(Value::Text(s.into()))),
2147                data_type: target,
2148            })
2149        }
2150    }
2151}
2152
2153fn convert_interval_expr(iv: &sp::Interval) -> Result<Expr> {
2154    let raw = match iv.value.as_ref() {
2155        sp::Expr::Value(v) => match &v.value {
2156            sp::Value::SingleQuotedString(s) => s.clone(),
2157            sp::Value::Number(n, _) => n.clone(),
2158            other => {
2159                return Err(SqlError::InvalidIntervalLiteral(format!(
2160                    "unsupported inner value: {other}"
2161                )))
2162            }
2163        },
2164        other => {
2165            return Err(SqlError::InvalidIntervalLiteral(format!(
2166                "unsupported inner expr: {other}"
2167            )))
2168        }
2169    };
2170
2171    // SQL-standard form `INTERVAL '5' DAY` — append the unit to the literal.
2172    let with_unit = if let Some(field) = &iv.leading_field {
2173        let unit_name = match field {
2174            sp::DateTimeField::Year => "years",
2175            sp::DateTimeField::Month => "months",
2176            sp::DateTimeField::Week(_) => "weeks",
2177            sp::DateTimeField::Day => "days",
2178            sp::DateTimeField::Hour => "hours",
2179            sp::DateTimeField::Minute => "minutes",
2180            sp::DateTimeField::Second => "seconds",
2181            _ => {
2182                return Err(SqlError::InvalidIntervalLiteral(format!(
2183                    "unsupported leading field: {field:?}"
2184                )))
2185            }
2186        };
2187        format!("{raw} {unit_name}")
2188    } else {
2189        raw
2190    };
2191
2192    let (months, days, micros) = crate::datetime::parse_interval(&with_unit)?;
2193    Ok(Expr::Literal(Value::Interval {
2194        months,
2195        days,
2196        micros,
2197    }))
2198}
2199
2200fn convert_escape_value(val: &sp::Value) -> Result<Expr> {
2201    match val {
2202        sp::Value::SingleQuotedString(s) => Ok(Expr::Literal(Value::Text(s.as_str().into()))),
2203        _ => Err(SqlError::Unsupported(format!("ESCAPE value: {val}"))),
2204    }
2205}
2206
2207fn convert_bin_op(op: &sp::BinaryOperator) -> Result<BinOp> {
2208    match op {
2209        sp::BinaryOperator::Plus => Ok(BinOp::Add),
2210        sp::BinaryOperator::Minus => Ok(BinOp::Sub),
2211        sp::BinaryOperator::Multiply => Ok(BinOp::Mul),
2212        sp::BinaryOperator::Divide => Ok(BinOp::Div),
2213        sp::BinaryOperator::Modulo => Ok(BinOp::Mod),
2214        sp::BinaryOperator::Eq => Ok(BinOp::Eq),
2215        sp::BinaryOperator::NotEq => Ok(BinOp::NotEq),
2216        sp::BinaryOperator::Lt => Ok(BinOp::Lt),
2217        sp::BinaryOperator::Gt => Ok(BinOp::Gt),
2218        sp::BinaryOperator::LtEq => Ok(BinOp::LtEq),
2219        sp::BinaryOperator::GtEq => Ok(BinOp::GtEq),
2220        sp::BinaryOperator::And => Ok(BinOp::And),
2221        sp::BinaryOperator::Or => Ok(BinOp::Or),
2222        sp::BinaryOperator::StringConcat => Ok(BinOp::Concat),
2223        _ => Err(SqlError::Unsupported(format!("binary op: {op}"))),
2224    }
2225}
2226
2227fn convert_function(func: &sp::Function) -> Result<Expr> {
2228    let name = object_name_to_string(&func.name).to_ascii_uppercase();
2229
2230    let (args, is_count_star, distinct) = match &func.args {
2231        sp::FunctionArguments::List(list) => {
2232            let distinct = matches!(
2233                list.duplicate_treatment,
2234                Some(sp::DuplicateTreatment::Distinct)
2235            );
2236            if list.args.is_empty() && name == "COUNT" {
2237                (vec![], true, distinct)
2238            } else {
2239                let mut count_star = false;
2240                let args = list
2241                    .args
2242                    .iter()
2243                    .map(|arg| match arg {
2244                        sp::FunctionArg::Unnamed(sp::FunctionArgExpr::Expr(e)) => convert_expr(e),
2245                        sp::FunctionArg::Unnamed(sp::FunctionArgExpr::Wildcard) => {
2246                            if name == "COUNT" {
2247                                count_star = true;
2248                                Ok(Expr::CountStar)
2249                            } else {
2250                                Err(SqlError::Unsupported(format!("{name}(*)")))
2251                            }
2252                        }
2253                        _ => Err(SqlError::Unsupported(format!(
2254                            "function arg type in {name}"
2255                        ))),
2256                    })
2257                    .collect::<Result<Vec<_>>>()?;
2258                if name == "COUNT" && args.len() == 1 && count_star {
2259                    (vec![], true, distinct)
2260                } else {
2261                    (args, false, distinct)
2262                }
2263            }
2264        }
2265        sp::FunctionArguments::None => {
2266            if name == "COUNT" {
2267                (vec![], true, false)
2268            } else {
2269                (vec![], false, false)
2270            }
2271        }
2272        sp::FunctionArguments::Subquery(_) => {
2273            return Err(SqlError::Unsupported("subquery in function".into()));
2274        }
2275    };
2276
2277    // Window function: check OVER before any other special handling
2278    if let Some(over) = &func.over {
2279        let spec = match over {
2280            sp::WindowType::WindowSpec(ws) => convert_window_spec(ws)?,
2281            sp::WindowType::NamedWindow(_) => {
2282                return Err(SqlError::Unsupported("named windows".into()));
2283            }
2284        };
2285        return Ok(Expr::WindowFunction { name, args, spec });
2286    }
2287
2288    // Non-window special forms
2289    if is_count_star {
2290        return Ok(Expr::CountStar);
2291    }
2292
2293    if name == "COALESCE" {
2294        if args.is_empty() {
2295            return Err(SqlError::Parse(
2296                "COALESCE requires at least one argument".into(),
2297            ));
2298        }
2299        return Ok(Expr::Coalesce(args));
2300    }
2301
2302    if name == "NULLIF" {
2303        if args.len() != 2 {
2304            return Err(SqlError::Parse(
2305                "NULLIF requires exactly two arguments".into(),
2306            ));
2307        }
2308        return Ok(Expr::Case {
2309            operand: None,
2310            conditions: vec![(
2311                Expr::BinaryOp {
2312                    left: Box::new(args[0].clone()),
2313                    op: BinOp::Eq,
2314                    right: Box::new(args[1].clone()),
2315                },
2316                Expr::Literal(Value::Null),
2317            )],
2318            else_result: Some(Box::new(args[0].clone())),
2319        });
2320    }
2321
2322    if name == "IIF" {
2323        if args.len() != 3 {
2324            return Err(SqlError::Parse(
2325                "IIF requires exactly three arguments".into(),
2326            ));
2327        }
2328        return Ok(Expr::Case {
2329            operand: None,
2330            conditions: vec![(args[0].clone(), args[1].clone())],
2331            else_result: Some(Box::new(args[2].clone())),
2332        });
2333    }
2334
2335    Ok(Expr::Function {
2336        name,
2337        args,
2338        distinct,
2339    })
2340}
2341
2342fn convert_window_spec(ws: &sp::WindowSpec) -> Result<WindowSpec> {
2343    let partition_by = ws
2344        .partition_by
2345        .iter()
2346        .map(convert_expr)
2347        .collect::<Result<Vec<_>>>()?;
2348    let order_by = ws
2349        .order_by
2350        .iter()
2351        .map(convert_order_by_expr)
2352        .collect::<Result<Vec<_>>>()?;
2353    let frame = ws
2354        .window_frame
2355        .as_ref()
2356        .map(convert_window_frame)
2357        .transpose()?;
2358    Ok(WindowSpec {
2359        partition_by,
2360        order_by,
2361        frame,
2362    })
2363}
2364
2365fn convert_window_frame(wf: &sp::WindowFrame) -> Result<WindowFrame> {
2366    let units = match wf.units {
2367        sp::WindowFrameUnits::Rows => WindowFrameUnits::Rows,
2368        sp::WindowFrameUnits::Range => WindowFrameUnits::Range,
2369        sp::WindowFrameUnits::Groups => {
2370            return Err(SqlError::Unsupported("GROUPS window frame".into()));
2371        }
2372    };
2373    let start = convert_window_frame_bound(&wf.start_bound)?;
2374    let end = match &wf.end_bound {
2375        Some(b) => convert_window_frame_bound(b)?,
2376        None => WindowFrameBound::CurrentRow,
2377    };
2378    Ok(WindowFrame { units, start, end })
2379}
2380
2381fn convert_window_frame_bound(b: &sp::WindowFrameBound) -> Result<WindowFrameBound> {
2382    match b {
2383        sp::WindowFrameBound::CurrentRow => Ok(WindowFrameBound::CurrentRow),
2384        sp::WindowFrameBound::Preceding(None) => Ok(WindowFrameBound::UnboundedPreceding),
2385        sp::WindowFrameBound::Preceding(Some(e)) => {
2386            Ok(WindowFrameBound::Preceding(Box::new(convert_expr(e)?)))
2387        }
2388        sp::WindowFrameBound::Following(None) => Ok(WindowFrameBound::UnboundedFollowing),
2389        sp::WindowFrameBound::Following(Some(e)) => {
2390            Ok(WindowFrameBound::Following(Box::new(convert_expr(e)?)))
2391        }
2392    }
2393}
2394
2395fn convert_returning(items: Option<&[sp::SelectItem]>) -> Result<Option<Vec<SelectColumn>>> {
2396    match items {
2397        None => Ok(None),
2398        Some(items) => {
2399            let cols = items
2400                .iter()
2401                .map(convert_returning_item)
2402                .collect::<Result<Vec<_>>>()?;
2403            Ok(Some(cols))
2404        }
2405    }
2406}
2407
2408fn convert_returning_item(item: &sp::SelectItem) -> Result<SelectColumn> {
2409    match item {
2410        sp::SelectItem::Wildcard(_) => Ok(SelectColumn::AllColumns),
2411        sp::SelectItem::UnnamedExpr(e) => {
2412            reject_aggregate_or_window(e, "RETURNING")?;
2413            Ok(SelectColumn::Expr {
2414                expr: convert_expr(e)?,
2415                alias: None,
2416            })
2417        }
2418        sp::SelectItem::ExprWithAlias { expr, alias } => {
2419            reject_aggregate_or_window(expr, "RETURNING")?;
2420            Ok(SelectColumn::Expr {
2421                expr: convert_expr(expr)?,
2422                alias: Some(alias.value.clone()),
2423            })
2424        }
2425        sp::SelectItem::QualifiedWildcard(kind, _) => match kind {
2426            sp::SelectItemQualifiedWildcardKind::ObjectName(name) => {
2427                let s = object_name_to_string(name);
2428                if s.eq_ignore_ascii_case("old") {
2429                    Ok(SelectColumn::AllFromOld)
2430                } else if s.eq_ignore_ascii_case("new") {
2431                    Ok(SelectColumn::AllFromNew)
2432                } else {
2433                    Err(SqlError::Unsupported(format!(
2434                        "RETURNING {s}.* — only old.* and new.* qualified wildcards allowed"
2435                    )))
2436                }
2437            }
2438            sp::SelectItemQualifiedWildcardKind::Expr(_) => {
2439                Err(SqlError::Unsupported("expression.* in RETURNING".into()))
2440            }
2441        },
2442    }
2443}
2444
2445fn reject_aggregate_or_window(expr: &sp::Expr, ctx: &str) -> Result<()> {
2446    use sp::Expr as E;
2447    match expr {
2448        E::Function(f) => {
2449            if f.over.is_some() {
2450                return Err(SqlError::Unsupported(format!(
2451                    "window functions are not allowed in {ctx}"
2452                )));
2453            }
2454            let name = f
2455                .name
2456                .0
2457                .last()
2458                .map(|p| match p {
2459                    sp::ObjectNamePart::Identifier(id) => id.value.to_ascii_uppercase(),
2460                    _ => String::new(),
2461                })
2462                .unwrap_or_default();
2463            if matches!(
2464                name.as_str(),
2465                "COUNT"
2466                    | "SUM"
2467                    | "AVG"
2468                    | "MIN"
2469                    | "MAX"
2470                    | "GROUP_CONCAT"
2471                    | "STRING_AGG"
2472                    | "ARRAY_AGG"
2473                    | "BIT_AND"
2474                    | "BIT_OR"
2475                    | "BOOL_AND"
2476                    | "BOOL_OR"
2477                    | "EVERY"
2478                    | "STDDEV"
2479                    | "STDDEV_POP"
2480                    | "STDDEV_SAMP"
2481                    | "VARIANCE"
2482                    | "VAR_POP"
2483                    | "VAR_SAMP"
2484            ) {
2485                return Err(SqlError::Unsupported(format!(
2486                    "aggregate functions are not allowed in {ctx}"
2487                )));
2488            }
2489            for arg in walk_function_args(f) {
2490                reject_aggregate_or_window(arg, ctx)?;
2491            }
2492            Ok(())
2493        }
2494        E::BinaryOp { left, right, .. } => {
2495            reject_aggregate_or_window(left, ctx)?;
2496            reject_aggregate_or_window(right, ctx)
2497        }
2498        E::UnaryOp { expr, .. } => reject_aggregate_or_window(expr, ctx),
2499        E::Cast { expr, .. } => reject_aggregate_or_window(expr, ctx),
2500        E::Nested(e) => reject_aggregate_or_window(e, ctx),
2501        E::Case {
2502            conditions,
2503            else_result,
2504            ..
2505        } => {
2506            for cwt in conditions {
2507                reject_aggregate_or_window(&cwt.condition, ctx)?;
2508                reject_aggregate_or_window(&cwt.result, ctx)?;
2509            }
2510            if let Some(e) = else_result {
2511                reject_aggregate_or_window(e, ctx)?;
2512            }
2513            Ok(())
2514        }
2515        _ => Ok(()),
2516    }
2517}
2518
2519fn walk_function_args(f: &sp::Function) -> Vec<&sp::Expr> {
2520    use sp::FunctionArguments as FA;
2521    let mut out = Vec::new();
2522    if let FA::List(args) = &f.args {
2523        for a in &args.args {
2524            if let sp::FunctionArg::Unnamed(sp::FunctionArgExpr::Expr(e)) = a {
2525                out.push(e);
2526            }
2527        }
2528    }
2529    out
2530}
2531
2532fn convert_select_item(item: &sp::SelectItem) -> Result<SelectColumn> {
2533    match item {
2534        sp::SelectItem::Wildcard(_) => Ok(SelectColumn::AllColumns),
2535        sp::SelectItem::UnnamedExpr(e) => {
2536            let expr = convert_expr(e)?;
2537            Ok(SelectColumn::Expr { expr, alias: None })
2538        }
2539        sp::SelectItem::ExprWithAlias { expr, alias } => {
2540            let expr = convert_expr(expr)?;
2541            Ok(SelectColumn::Expr {
2542                expr,
2543                alias: Some(alias.value.clone()),
2544            })
2545        }
2546        sp::SelectItem::QualifiedWildcard(_, _) => {
2547            Err(SqlError::Unsupported("qualified wildcard (table.*)".into()))
2548        }
2549    }
2550}
2551
2552fn convert_order_by_expr(expr: &sp::OrderByExpr) -> Result<OrderByItem> {
2553    let e = convert_expr(&expr.expr)?;
2554    let descending = expr.options.asc.map(|asc| !asc).unwrap_or(false);
2555    let nulls_first = expr.options.nulls_first;
2556
2557    Ok(OrderByItem {
2558        expr: e,
2559        descending,
2560        nulls_first,
2561    })
2562}
2563
2564fn convert_data_type(dt: &sp::DataType) -> Result<DataType> {
2565    match dt {
2566        sp::DataType::Int(_)
2567        | sp::DataType::Integer(_)
2568        | sp::DataType::BigInt(_)
2569        | sp::DataType::SmallInt(_)
2570        | sp::DataType::TinyInt(_)
2571        | sp::DataType::Int2(_)
2572        | sp::DataType::Int4(_)
2573        | sp::DataType::Int8(_) => Ok(DataType::Integer),
2574
2575        sp::DataType::Real
2576        | sp::DataType::Double(..)
2577        | sp::DataType::DoublePrecision
2578        | sp::DataType::Float(_)
2579        | sp::DataType::Float4
2580        | sp::DataType::Float64 => Ok(DataType::Real),
2581
2582        sp::DataType::Varchar(_)
2583        | sp::DataType::Text
2584        | sp::DataType::Char(_)
2585        | sp::DataType::Character(_)
2586        | sp::DataType::String(_) => Ok(DataType::Text),
2587
2588        sp::DataType::Blob(_) | sp::DataType::Bytea => Ok(DataType::Blob),
2589
2590        sp::DataType::Boolean | sp::DataType::Bool => Ok(DataType::Boolean),
2591
2592        sp::DataType::Date => Ok(DataType::Date),
2593        sp::DataType::Time(_, _) => Ok(DataType::Time),
2594        sp::DataType::Timestamp(_, _) => Ok(DataType::Timestamp),
2595        sp::DataType::Interval { .. } => Ok(DataType::Interval),
2596
2597        _ => Err(SqlError::Unsupported(format!("data type: {dt}"))),
2598    }
2599}
2600
2601fn object_name_to_string(name: &sp::ObjectName) -> String {
2602    name.0
2603        .iter()
2604        .filter_map(|p| match p {
2605            sp::ObjectNamePart::Identifier(ident) => Some(ident.value.clone()),
2606            _ => None,
2607        })
2608        .collect::<Vec<_>>()
2609        .join(".")
2610}
2611
2612#[cfg(test)]
2613#[path = "parser_tests.rs"]
2614mod tests;