1use sqlparser::ast as sp;
4use sqlparser::dialect::GenericDialect;
5use sqlparser::parser::Parser;
6
7use crate::error::{Result, SqlError};
8use crate::types::{DataType, Value};
9
/// A single SQL statement in this crate's own AST, produced from the `sqlparser` AST by
/// [`parse_sql`].
#[derive(Debug, Clone)]
pub enum Statement {
    /// `CREATE TABLE`.
    CreateTable(CreateTableStmt),
    /// `DROP TABLE` naming exactly one table.
    DropTable(DropTableStmt),
    /// `CREATE [UNIQUE] INDEX` over plain columns.
    CreateIndex(CreateIndexStmt),
    /// `DROP INDEX` naming exactly one index.
    DropIndex(DropIndexStmt),
    /// `ALTER TABLE` with exactly one operation (boxed, presumably to keep this enum small).
    AlterTable(Box<AlterTableStmt>),
    /// `INSERT ... VALUES` or `INSERT ... SELECT`.
    Insert(InsertStmt),
    /// A query, possibly with CTEs and set operations.
    Select(Box<SelectQuery>),
    /// `UPDATE`.
    Update(UpdateStmt),
    /// `DELETE`.
    Delete(DeleteStmt),
    /// Produced from `sqlparser`'s `StartTransaction` statement (e.g. `BEGIN`).
    Begin,
    /// `COMMIT`.
    Commit,
    /// `ROLLBACK`.
    Rollback,
    /// `EXPLAIN <statement>`; `EXPLAIN ANALYZE` is rejected during conversion.
    Explain(Box<Statement>),
}
28
/// `ALTER TABLE <table> <op>`; only one operation per statement is accepted.
#[derive(Debug, Clone)]
pub struct AlterTableStmt {
    /// Target table name, not case-normalized during conversion.
    pub table: String,
    /// The single operation to apply.
    pub op: AlterTableOp,
}

/// One `ALTER TABLE` operation supported by this crate.
#[derive(Debug, Clone)]
pub enum AlterTableOp {
    /// `ADD [COLUMN] [IF NOT EXISTS] <def>`; an inline `REFERENCES` clause on the column
    /// definition becomes `foreign_key`.
    AddColumn {
        column: Box<ColumnSpec>,
        foreign_key: Option<ForeignKeyDef>,
        if_not_exists: bool,
    },
    /// `DROP [COLUMN] [IF EXISTS] <name>` for exactly one column.
    DropColumn {
        name: String,
        if_exists: bool,
    },
    /// `RENAME COLUMN <old_name> TO <new_name>`.
    RenameColumn {
        old_name: String,
        new_name: String,
    },
    /// `RENAME TO <new_name>` (the `RENAME AS` spelling is accepted as well).
    RenameTable {
        new_name: String,
    },
}
54
/// `CREATE TABLE [IF NOT EXISTS] <name> (...)`.
#[derive(Debug, Clone)]
pub struct CreateTableStmt {
    /// Table name, not case-normalized during conversion.
    pub name: String,
    /// Column definitions in declaration order.
    pub columns: Vec<ColumnSpec>,
    /// Primary-key column names: inline `PRIMARY KEY` columns first, then any columns
    /// from a table-level `PRIMARY KEY (...)` that were not already listed.
    pub primary_key: Vec<String>,
    pub if_not_exists: bool,
    /// Table-level `CHECK` constraints (column-level checks live on [`ColumnSpec`]).
    pub check_constraints: Vec<TableCheckConstraint>,
    /// Foreign keys from both inline `REFERENCES` clauses and table-level
    /// `FOREIGN KEY` constraints.
    pub foreign_keys: Vec<ForeignKeyDef>,
}

/// A table-level `CHECK` constraint.
#[derive(Debug, Clone)]
pub struct TableCheckConstraint {
    /// Optional `CONSTRAINT <name>`.
    pub name: Option<String>,
    /// Converted check expression; guaranteed free of subqueries.
    pub expr: Expr,
    /// The check expression rendered back to SQL text by `sqlparser`.
    pub sql: String,
}

/// A foreign-key reference. All names are ASCII-lower-cased during conversion, and only
/// `RESTRICT` / `NO ACTION` referential actions are accepted.
#[derive(Debug, Clone)]
pub struct ForeignKeyDef {
    /// Optional constraint name (not lower-cased).
    pub name: Option<String>,
    /// Referencing columns in this table.
    pub columns: Vec<String>,
    /// Referenced table.
    pub foreign_table: String,
    /// Referenced columns; may be empty if the SQL omitted the column list
    /// (NOTE(review): confirm how callers resolve that case).
    pub referred_columns: Vec<String>,
}

/// One column definition from `CREATE TABLE` or `ALTER TABLE ADD COLUMN`.
#[derive(Debug, Clone)]
pub struct ColumnSpec {
    /// Column name, not case-normalized during conversion.
    pub name: String,
    pub data_type: DataType,
    /// `true` unless `NOT NULL` or `PRIMARY KEY` (inline or table-level) applies.
    pub nullable: bool,
    pub is_primary_key: bool,
    /// Converted `DEFAULT` expression.
    pub default_expr: Option<Expr>,
    /// `DEFAULT` expression rendered back to SQL text.
    pub default_sql: Option<String>,
    /// Converted column-level `CHECK` expression (never contains a subquery).
    pub check_expr: Option<Expr>,
    /// Column-level `CHECK` expression rendered back to SQL text.
    pub check_sql: Option<String>,
    /// Optional `CONSTRAINT <name>` of the column-level check.
    pub check_name: Option<String>,
}
92
/// `DROP TABLE [IF EXISTS] <name>` for a single table.
#[derive(Debug, Clone)]
pub struct DropTableStmt {
    pub name: String,
    pub if_exists: bool,
}

/// `CREATE [UNIQUE] INDEX [IF NOT EXISTS] <index_name> ON <table_name> (<columns>)`.
#[derive(Debug, Clone)]
pub struct CreateIndexStmt {
    pub index_name: String,
    pub table_name: String,
    /// Indexed column names in order; never empty, and expression indexes are rejected.
    pub columns: Vec<String>,
    pub unique: bool,
    pub if_not_exists: bool,
}

/// `DROP INDEX [IF EXISTS] <index_name>` for a single index.
#[derive(Debug, Clone)]
pub struct DropIndexStmt {
    pub index_name: String,
    pub if_exists: bool,
}
113
/// Where the rows of an `INSERT` come from.
#[derive(Debug, Clone)]
pub enum InsertSource {
    /// `VALUES (...), (...)`: one inner `Vec` per row.
    Values(Vec<Vec<Expr>>),
    /// `INSERT ... SELECT ...` (including any `WITH` clause).
    Select(Box<SelectQuery>),
}

/// `INSERT INTO <table> [(<columns>)] <source>`.
#[derive(Debug, Clone)]
pub struct InsertStmt {
    /// Target table, ASCII-lower-cased during conversion.
    pub table: String,
    /// Explicit column list, ASCII-lower-cased; empty when the SQL gave none.
    pub columns: Vec<String>,
    pub source: InsertSource,
}
126
/// A table reference with an optional alias (used by joins).
#[derive(Debug, Clone)]
pub struct TableRef {
    pub name: String,
    pub alias: Option<String>,
}

/// Join flavour of a [`JoinClause`].
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum JoinType {
    Inner,
    Cross,
    Left,
    Right,
}

/// One `JOIN` attached to the FROM table.
#[derive(Debug, Clone)]
pub struct JoinClause {
    pub join_type: JoinType,
    pub table: TableRef,
    /// `ON` condition, if any (presumably `None` for cross joins — confirm in `convert_join`).
    pub on_clause: Option<Expr>,
}

/// A single `SELECT` block.
#[derive(Debug, Clone)]
pub struct SelectStmt {
    pub columns: Vec<SelectColumn>,
    /// FROM table name; an empty string when the query has no FROM clause.
    pub from: String,
    pub from_alias: Option<String>,
    pub joins: Vec<JoinClause>,
    pub distinct: bool,
    pub where_clause: Option<Expr>,
    /// Left empty by `convert_select_body`; presumably populated by the query-level
    /// converter when this block is the outermost SELECT.
    pub order_by: Vec<OrderByItem>,
    pub limit: Option<Expr>,
    pub offset: Option<Expr>,
    pub group_by: Vec<Expr>,
    pub having: Option<Expr>,
}
162
/// Set operator of a [`CompoundSelect`].
#[derive(Debug, Clone)]
pub enum SetOp {
    Union,
    Intersect,
    Except,
}

/// `left <op> [ALL] right`, with ORDER BY / LIMIT / OFFSET applied to the combined result.
#[derive(Debug, Clone)]
pub struct CompoundSelect {
    pub op: SetOp,
    /// Presumably `true` for the `ALL` quantifier (e.g. `UNION ALL`).
    pub all: bool,
    pub left: Box<QueryBody>,
    pub right: Box<QueryBody>,
    pub order_by: Vec<OrderByItem>,
    pub limit: Option<Expr>,
    pub offset: Option<Expr>,
}

/// Body of a query: a plain SELECT or a set operation over two bodies.
#[derive(Debug, Clone)]
pub enum QueryBody {
    Select(Box<SelectStmt>),
    Compound(CompoundSelect),
}

/// One `WITH` entry: `name [(column_aliases)] AS (body)`.
#[derive(Debug, Clone)]
pub struct CteDefinition {
    pub name: String,
    /// Optional column-name list after the CTE name (presumably empty when omitted).
    pub column_aliases: Vec<String>,
    pub body: QueryBody,
}

/// A complete query: optional CTEs followed by the main body.
#[derive(Debug, Clone)]
pub struct SelectQuery {
    pub ctes: Vec<CteDefinition>,
    /// Whether the `WITH` clause was `WITH RECURSIVE`.
    pub recursive: bool,
    pub body: QueryBody,
}
200
/// `UPDATE <table> SET col = expr, ... [WHERE ...]`.
#[derive(Debug, Clone)]
pub struct UpdateStmt {
    pub table: String,
    /// `(column, new value)` pairs in SQL order.
    pub assignments: Vec<(String, Expr)>,
    pub where_clause: Option<Expr>,
}

/// `DELETE FROM <table> [WHERE ...]`.
#[derive(Debug, Clone)]
pub struct DeleteStmt {
    pub table: String,
    pub where_clause: Option<Expr>,
}

/// One item of a SELECT projection list.
#[derive(Debug, Clone)]
pub enum SelectColumn {
    /// `*`.
    AllColumns,
    /// `expr [AS alias]`.
    Expr { expr: Expr, alias: Option<String> },
}

/// One ORDER BY key.
#[derive(Debug, Clone)]
pub struct OrderByItem {
    pub expr: Expr,
    pub descending: bool,
    /// `Some(true)` = NULLS FIRST, `Some(false)` = NULLS LAST; `None` presumably means
    /// not specified (engine default) — confirm in the converter.
    pub nulls_first: Option<bool>,
}
226
/// Scalar expression tree used throughout the engine.
#[derive(Debug, Clone)]
pub enum Expr {
    /// Constant value (also what [`bind_params`] substitutes for parameters).
    Literal(Value),
    /// Unqualified column reference.
    Column(String),
    /// `table.column` reference.
    QualifiedColumn {
        table: String,
        column: String,
    },
    BinaryOp {
        left: Box<Expr>,
        op: BinOp,
        right: Box<Expr>,
    },
    UnaryOp {
        op: UnaryOp,
        expr: Box<Expr>,
    },
    IsNull(Box<Expr>),
    IsNotNull(Box<Expr>),
    /// Function call by name.
    Function {
        name: String,
        args: Vec<Expr>,
    },
    /// `COUNT(*)`.
    CountStar,
    /// `expr [NOT] IN (SELECT ...)`; the subquery is a single SELECT block.
    InSubquery {
        expr: Box<Expr>,
        subquery: Box<SelectStmt>,
        negated: bool,
    },
    /// `expr [NOT] IN (a, b, ...)`.
    InList {
        expr: Box<Expr>,
        list: Vec<Expr>,
        negated: bool,
    },
    /// `[NOT] EXISTS (SELECT ...)`.
    Exists {
        subquery: Box<SelectStmt>,
        negated: bool,
    },
    /// `(SELECT ...)` used as a scalar value.
    ScalarSubquery(Box<SelectStmt>),
    /// Membership test against a pre-materialized set of values. Not produced by the
    /// conversion code in this file; presumably built by a later rewrite of `InList` /
    /// `InSubquery`, with `has_null` recording whether a NULL was among the values.
    InSet {
        expr: Box<Expr>,
        values: std::collections::HashSet<Value>,
        has_null: bool,
        negated: bool,
    },
    /// `expr [NOT] BETWEEN low AND high`.
    Between {
        expr: Box<Expr>,
        low: Box<Expr>,
        high: Box<Expr>,
        negated: bool,
    },
    /// `expr [NOT] LIKE pattern [ESCAPE escape]`.
    Like {
        expr: Box<Expr>,
        pattern: Box<Expr>,
        escape: Option<Box<Expr>>,
        negated: bool,
    },
    /// `CASE [operand] WHEN cond THEN result ... [ELSE else_result] END`;
    /// `operand` is `Some` for the simple (`CASE x WHEN ...`) form.
    Case {
        operand: Option<Box<Expr>>,
        conditions: Vec<(Expr, Expr)>,
        else_result: Option<Box<Expr>>,
    },
    /// `COALESCE(a, b, ...)`.
    Coalesce(Vec<Expr>),
    /// `CAST(expr AS data_type)`.
    Cast {
        expr: Box<Expr>,
        data_type: DataType,
    },
    /// Positional placeholder with a 1-based index; replaced by `params[n - 1]` in
    /// [`bind_params`].
    Parameter(usize),
}
296
/// Binary operators of [`Expr::BinaryOp`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BinOp {
    Add,
    Sub,
    Mul,
    Div,
    Mod,
    Eq,
    NotEq,
    Lt,
    Gt,
    LtEq,
    GtEq,
    And,
    Or,
    /// String concatenation (presumably SQL `||`).
    Concat,
}

/// Unary operators of [`Expr::UnaryOp`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum UnaryOp {
    /// Arithmetic negation (`-x`).
    Neg,
    /// Logical `NOT`.
    Not,
}
320
321pub fn has_subquery(expr: &Expr) -> bool {
324 match expr {
325 Expr::InSubquery { .. } | Expr::Exists { .. } | Expr::ScalarSubquery(_) => true,
326 Expr::BinaryOp { left, right, .. } => has_subquery(left) || has_subquery(right),
327 Expr::UnaryOp { expr, .. } => has_subquery(expr),
328 Expr::IsNull(e) | Expr::IsNotNull(e) => has_subquery(e),
329 Expr::InList { expr, list, .. } => has_subquery(expr) || list.iter().any(has_subquery),
330 Expr::InSet { expr, .. } => has_subquery(expr),
331 Expr::Between {
332 expr, low, high, ..
333 } => has_subquery(expr) || has_subquery(low) || has_subquery(high),
334 Expr::Like {
335 expr,
336 pattern,
337 escape,
338 ..
339 } => {
340 has_subquery(expr)
341 || has_subquery(pattern)
342 || escape.as_ref().is_some_and(|e| has_subquery(e))
343 }
344 Expr::Case {
345 operand,
346 conditions,
347 else_result,
348 } => {
349 operand.as_ref().is_some_and(|e| has_subquery(e))
350 || conditions
351 .iter()
352 .any(|(c, r)| has_subquery(c) || has_subquery(r))
353 || else_result.as_ref().is_some_and(|e| has_subquery(e))
354 }
355 Expr::Coalesce(args) | Expr::Function { args, .. } => args.iter().any(has_subquery),
356 Expr::Cast { expr, .. } => has_subquery(expr),
357 _ => false,
358 }
359}
360
361pub fn parse_sql_expr(sql: &str) -> Result<Expr> {
364 let dialect = GenericDialect {};
365 let mut parser = Parser::new(&dialect)
366 .try_with_sql(sql)
367 .map_err(|e| SqlError::Parse(e.to_string()))?;
368 let sp_expr = parser
369 .parse_expr()
370 .map_err(|e| SqlError::Parse(e.to_string()))?;
371 convert_expr(&sp_expr)
372}
373
374pub fn parse_sql(sql: &str) -> Result<Statement> {
377 let dialect = GenericDialect {};
378 let stmts = Parser::parse_sql(&dialect, sql).map_err(|e| SqlError::Parse(e.to_string()))?;
379
380 if stmts.is_empty() {
381 return Err(SqlError::Parse("empty SQL".into()));
382 }
383 if stmts.len() > 1 {
384 return Err(SqlError::Unsupported("multiple statements".into()));
385 }
386
387 convert_statement(stmts.into_iter().next().unwrap())
388}
389
390pub fn count_params(stmt: &Statement) -> usize {
394 let mut max_idx = 0usize;
395 visit_exprs_stmt(stmt, &mut |e| {
396 if let Expr::Parameter(n) = e {
397 max_idx = max_idx.max(*n);
398 }
399 });
400 max_idx
401}
402
403pub fn bind_params(
405 stmt: &Statement,
406 params: &[crate::types::Value],
407) -> crate::error::Result<Statement> {
408 bind_stmt(stmt, params)
409}
410
411fn bind_stmt(stmt: &Statement, params: &[crate::types::Value]) -> crate::error::Result<Statement> {
412 match stmt {
413 Statement::Select(sq) => Ok(Statement::Select(Box::new(bind_select_query(sq, params)?))),
414 Statement::Insert(ins) => {
415 let source = match &ins.source {
416 InsertSource::Values(rows) => {
417 let bound = rows
418 .iter()
419 .map(|row| {
420 row.iter()
421 .map(|e| bind_expr(e, params))
422 .collect::<crate::error::Result<Vec<_>>>()
423 })
424 .collect::<crate::error::Result<Vec<_>>>()?;
425 InsertSource::Values(bound)
426 }
427 InsertSource::Select(sq) => {
428 InsertSource::Select(Box::new(bind_select_query(sq, params)?))
429 }
430 };
431 Ok(Statement::Insert(InsertStmt {
432 table: ins.table.clone(),
433 columns: ins.columns.clone(),
434 source,
435 }))
436 }
437 Statement::Update(upd) => {
438 let assignments = upd
439 .assignments
440 .iter()
441 .map(|(col, e)| Ok((col.clone(), bind_expr(e, params)?)))
442 .collect::<crate::error::Result<Vec<_>>>()?;
443 let where_clause = upd
444 .where_clause
445 .as_ref()
446 .map(|e| bind_expr(e, params))
447 .transpose()?;
448 Ok(Statement::Update(UpdateStmt {
449 table: upd.table.clone(),
450 assignments,
451 where_clause,
452 }))
453 }
454 Statement::Delete(del) => {
455 let where_clause = del
456 .where_clause
457 .as_ref()
458 .map(|e| bind_expr(e, params))
459 .transpose()?;
460 Ok(Statement::Delete(DeleteStmt {
461 table: del.table.clone(),
462 where_clause,
463 }))
464 }
465 Statement::Explain(inner) => Ok(Statement::Explain(Box::new(bind_stmt(inner, params)?))),
466 other => Ok(other.clone()),
467 }
468}
469
470fn bind_select(
471 sel: &SelectStmt,
472 params: &[crate::types::Value],
473) -> crate::error::Result<SelectStmt> {
474 let columns = sel
475 .columns
476 .iter()
477 .map(|c| match c {
478 SelectColumn::AllColumns => Ok(SelectColumn::AllColumns),
479 SelectColumn::Expr { expr, alias } => Ok(SelectColumn::Expr {
480 expr: bind_expr(expr, params)?,
481 alias: alias.clone(),
482 }),
483 })
484 .collect::<crate::error::Result<Vec<_>>>()?;
485 let joins = sel
486 .joins
487 .iter()
488 .map(|j| {
489 let on_clause = j
490 .on_clause
491 .as_ref()
492 .map(|e| bind_expr(e, params))
493 .transpose()?;
494 Ok(JoinClause {
495 join_type: j.join_type,
496 table: j.table.clone(),
497 on_clause,
498 })
499 })
500 .collect::<crate::error::Result<Vec<_>>>()?;
501 let where_clause = sel
502 .where_clause
503 .as_ref()
504 .map(|e| bind_expr(e, params))
505 .transpose()?;
506 let order_by = sel
507 .order_by
508 .iter()
509 .map(|o| {
510 Ok(OrderByItem {
511 expr: bind_expr(&o.expr, params)?,
512 descending: o.descending,
513 nulls_first: o.nulls_first,
514 })
515 })
516 .collect::<crate::error::Result<Vec<_>>>()?;
517 let limit = sel
518 .limit
519 .as_ref()
520 .map(|e| bind_expr(e, params))
521 .transpose()?;
522 let offset = sel
523 .offset
524 .as_ref()
525 .map(|e| bind_expr(e, params))
526 .transpose()?;
527 let group_by = sel
528 .group_by
529 .iter()
530 .map(|e| bind_expr(e, params))
531 .collect::<crate::error::Result<Vec<_>>>()?;
532 let having = sel
533 .having
534 .as_ref()
535 .map(|e| bind_expr(e, params))
536 .transpose()?;
537
538 Ok(SelectStmt {
539 columns,
540 from: sel.from.clone(),
541 from_alias: sel.from_alias.clone(),
542 joins,
543 distinct: sel.distinct,
544 where_clause,
545 order_by,
546 limit,
547 offset,
548 group_by,
549 having,
550 })
551}
552
553fn bind_query_body(
554 body: &QueryBody,
555 params: &[crate::types::Value],
556) -> crate::error::Result<QueryBody> {
557 match body {
558 QueryBody::Select(sel) => Ok(QueryBody::Select(Box::new(bind_select(sel, params)?))),
559 QueryBody::Compound(comp) => {
560 let order_by = comp
561 .order_by
562 .iter()
563 .map(|o| {
564 Ok(OrderByItem {
565 expr: bind_expr(&o.expr, params)?,
566 descending: o.descending,
567 nulls_first: o.nulls_first,
568 })
569 })
570 .collect::<crate::error::Result<Vec<_>>>()?;
571 let limit = comp
572 .limit
573 .as_ref()
574 .map(|e| bind_expr(e, params))
575 .transpose()?;
576 let offset = comp
577 .offset
578 .as_ref()
579 .map(|e| bind_expr(e, params))
580 .transpose()?;
581 Ok(QueryBody::Compound(CompoundSelect {
582 op: comp.op.clone(),
583 all: comp.all,
584 left: Box::new(bind_query_body(&comp.left, params)?),
585 right: Box::new(bind_query_body(&comp.right, params)?),
586 order_by,
587 limit,
588 offset,
589 }))
590 }
591 }
592}
593
594fn bind_select_query(
595 sq: &SelectQuery,
596 params: &[crate::types::Value],
597) -> crate::error::Result<SelectQuery> {
598 let ctes = sq
599 .ctes
600 .iter()
601 .map(|cte| {
602 Ok(CteDefinition {
603 name: cte.name.clone(),
604 column_aliases: cte.column_aliases.clone(),
605 body: bind_query_body(&cte.body, params)?,
606 })
607 })
608 .collect::<crate::error::Result<Vec<_>>>()?;
609 let body = bind_query_body(&sq.body, params)?;
610 Ok(SelectQuery {
611 ctes,
612 recursive: sq.recursive,
613 body,
614 })
615}
616
617fn bind_expr(expr: &Expr, params: &[crate::types::Value]) -> crate::error::Result<Expr> {
618 match expr {
619 Expr::Parameter(n) => {
620 if *n == 0 || *n > params.len() {
621 return Err(SqlError::ParameterCountMismatch {
622 expected: *n,
623 got: params.len(),
624 });
625 }
626 Ok(Expr::Literal(params[*n - 1].clone()))
627 }
628 Expr::Literal(_) | Expr::Column(_) | Expr::QualifiedColumn { .. } | Expr::CountStar => {
629 Ok(expr.clone())
630 }
631 Expr::BinaryOp { left, op, right } => Ok(Expr::BinaryOp {
632 left: Box::new(bind_expr(left, params)?),
633 op: *op,
634 right: Box::new(bind_expr(right, params)?),
635 }),
636 Expr::UnaryOp { op, expr: e } => Ok(Expr::UnaryOp {
637 op: *op,
638 expr: Box::new(bind_expr(e, params)?),
639 }),
640 Expr::IsNull(e) => Ok(Expr::IsNull(Box::new(bind_expr(e, params)?))),
641 Expr::IsNotNull(e) => Ok(Expr::IsNotNull(Box::new(bind_expr(e, params)?))),
642 Expr::Function { name, args } => {
643 let args = args
644 .iter()
645 .map(|a| bind_expr(a, params))
646 .collect::<crate::error::Result<Vec<_>>>()?;
647 Ok(Expr::Function {
648 name: name.clone(),
649 args,
650 })
651 }
652 Expr::InSubquery {
653 expr: e,
654 subquery,
655 negated,
656 } => Ok(Expr::InSubquery {
657 expr: Box::new(bind_expr(e, params)?),
658 subquery: Box::new(bind_select(subquery, params)?),
659 negated: *negated,
660 }),
661 Expr::InList {
662 expr: e,
663 list,
664 negated,
665 } => {
666 let list = list
667 .iter()
668 .map(|l| bind_expr(l, params))
669 .collect::<crate::error::Result<Vec<_>>>()?;
670 Ok(Expr::InList {
671 expr: Box::new(bind_expr(e, params)?),
672 list,
673 negated: *negated,
674 })
675 }
676 Expr::Exists { subquery, negated } => Ok(Expr::Exists {
677 subquery: Box::new(bind_select(subquery, params)?),
678 negated: *negated,
679 }),
680 Expr::ScalarSubquery(sq) => Ok(Expr::ScalarSubquery(Box::new(bind_select(sq, params)?))),
681 Expr::InSet {
682 expr: e,
683 values,
684 has_null,
685 negated,
686 } => Ok(Expr::InSet {
687 expr: Box::new(bind_expr(e, params)?),
688 values: values.clone(),
689 has_null: *has_null,
690 negated: *negated,
691 }),
692 Expr::Between {
693 expr: e,
694 low,
695 high,
696 negated,
697 } => Ok(Expr::Between {
698 expr: Box::new(bind_expr(e, params)?),
699 low: Box::new(bind_expr(low, params)?),
700 high: Box::new(bind_expr(high, params)?),
701 negated: *negated,
702 }),
703 Expr::Like {
704 expr: e,
705 pattern,
706 escape,
707 negated,
708 } => Ok(Expr::Like {
709 expr: Box::new(bind_expr(e, params)?),
710 pattern: Box::new(bind_expr(pattern, params)?),
711 escape: escape
712 .as_ref()
713 .map(|esc| bind_expr(esc, params).map(Box::new))
714 .transpose()?,
715 negated: *negated,
716 }),
717 Expr::Case {
718 operand,
719 conditions,
720 else_result,
721 } => {
722 let operand = operand
723 .as_ref()
724 .map(|e| bind_expr(e, params).map(Box::new))
725 .transpose()?;
726 let conditions = conditions
727 .iter()
728 .map(|(cond, then)| Ok((bind_expr(cond, params)?, bind_expr(then, params)?)))
729 .collect::<crate::error::Result<Vec<_>>>()?;
730 let else_result = else_result
731 .as_ref()
732 .map(|e| bind_expr(e, params).map(Box::new))
733 .transpose()?;
734 Ok(Expr::Case {
735 operand,
736 conditions,
737 else_result,
738 })
739 }
740 Expr::Coalesce(args) => {
741 let args = args
742 .iter()
743 .map(|a| bind_expr(a, params))
744 .collect::<crate::error::Result<Vec<_>>>()?;
745 Ok(Expr::Coalesce(args))
746 }
747 Expr::Cast { expr: e, data_type } => Ok(Expr::Cast {
748 expr: Box::new(bind_expr(e, params)?),
749 data_type: *data_type,
750 }),
751 }
752}
753
754fn visit_exprs_stmt(stmt: &Statement, visitor: &mut impl FnMut(&Expr)) {
755 match stmt {
756 Statement::Select(sq) => {
757 for cte in &sq.ctes {
758 visit_exprs_query_body(&cte.body, visitor);
759 }
760 visit_exprs_query_body(&sq.body, visitor);
761 }
762 Statement::Insert(ins) => match &ins.source {
763 InsertSource::Values(rows) => {
764 for row in rows {
765 for e in row {
766 visit_expr(e, visitor);
767 }
768 }
769 }
770 InsertSource::Select(sq) => {
771 for cte in &sq.ctes {
772 visit_exprs_query_body(&cte.body, visitor);
773 }
774 visit_exprs_query_body(&sq.body, visitor);
775 }
776 },
777 Statement::Update(upd) => {
778 for (_, e) in &upd.assignments {
779 visit_expr(e, visitor);
780 }
781 if let Some(w) = &upd.where_clause {
782 visit_expr(w, visitor);
783 }
784 }
785 Statement::Delete(del) => {
786 if let Some(w) = &del.where_clause {
787 visit_expr(w, visitor);
788 }
789 }
790 Statement::Explain(inner) => visit_exprs_stmt(inner, visitor),
791 _ => {}
792 }
793}
794
795fn visit_exprs_query_body(body: &QueryBody, visitor: &mut impl FnMut(&Expr)) {
796 match body {
797 QueryBody::Select(sel) => visit_exprs_select(sel, visitor),
798 QueryBody::Compound(comp) => {
799 visit_exprs_query_body(&comp.left, visitor);
800 visit_exprs_query_body(&comp.right, visitor);
801 for o in &comp.order_by {
802 visit_expr(&o.expr, visitor);
803 }
804 if let Some(l) = &comp.limit {
805 visit_expr(l, visitor);
806 }
807 if let Some(o) = &comp.offset {
808 visit_expr(o, visitor);
809 }
810 }
811 }
812}
813
814fn visit_exprs_select(sel: &SelectStmt, visitor: &mut impl FnMut(&Expr)) {
815 for col in &sel.columns {
816 if let SelectColumn::Expr { expr, .. } = col {
817 visit_expr(expr, visitor);
818 }
819 }
820 for j in &sel.joins {
821 if let Some(on) = &j.on_clause {
822 visit_expr(on, visitor);
823 }
824 }
825 if let Some(w) = &sel.where_clause {
826 visit_expr(w, visitor);
827 }
828 for o in &sel.order_by {
829 visit_expr(&o.expr, visitor);
830 }
831 if let Some(l) = &sel.limit {
832 visit_expr(l, visitor);
833 }
834 if let Some(o) = &sel.offset {
835 visit_expr(o, visitor);
836 }
837 for g in &sel.group_by {
838 visit_expr(g, visitor);
839 }
840 if let Some(h) = &sel.having {
841 visit_expr(h, visitor);
842 }
843}
844
845fn visit_expr(expr: &Expr, visitor: &mut impl FnMut(&Expr)) {
846 visitor(expr);
847 match expr {
848 Expr::BinaryOp { left, right, .. } => {
849 visit_expr(left, visitor);
850 visit_expr(right, visitor);
851 }
852 Expr::UnaryOp { expr: e, .. } | Expr::IsNull(e) | Expr::IsNotNull(e) => {
853 visit_expr(e, visitor);
854 }
855 Expr::Function { args, .. } | Expr::Coalesce(args) => {
856 for a in args {
857 visit_expr(a, visitor);
858 }
859 }
860 Expr::InSubquery {
861 expr: e, subquery, ..
862 } => {
863 visit_expr(e, visitor);
864 visit_exprs_select(subquery, visitor);
865 }
866 Expr::InList { expr: e, list, .. } => {
867 visit_expr(e, visitor);
868 for l in list {
869 visit_expr(l, visitor);
870 }
871 }
872 Expr::Exists { subquery, .. } => visit_exprs_select(subquery, visitor),
873 Expr::ScalarSubquery(sq) => visit_exprs_select(sq, visitor),
874 Expr::InSet { expr: e, .. } => visit_expr(e, visitor),
875 Expr::Between {
876 expr: e, low, high, ..
877 } => {
878 visit_expr(e, visitor);
879 visit_expr(low, visitor);
880 visit_expr(high, visitor);
881 }
882 Expr::Like {
883 expr: e,
884 pattern,
885 escape,
886 ..
887 } => {
888 visit_expr(e, visitor);
889 visit_expr(pattern, visitor);
890 if let Some(esc) = escape {
891 visit_expr(esc, visitor);
892 }
893 }
894 Expr::Case {
895 operand,
896 conditions,
897 else_result,
898 } => {
899 if let Some(op) = operand {
900 visit_expr(op, visitor);
901 }
902 for (cond, then) in conditions {
903 visit_expr(cond, visitor);
904 visit_expr(then, visitor);
905 }
906 if let Some(el) = else_result {
907 visit_expr(el, visitor);
908 }
909 }
910 Expr::Cast { expr: e, .. } => visit_expr(e, visitor),
911 Expr::Literal(_)
912 | Expr::Column(_)
913 | Expr::QualifiedColumn { .. }
914 | Expr::CountStar
915 | Expr::Parameter(_) => {}
916 }
917}
918
919fn convert_statement(stmt: sp::Statement) -> Result<Statement> {
922 match stmt {
923 sp::Statement::CreateTable(ct) => convert_create_table(ct),
924 sp::Statement::CreateIndex(ci) => convert_create_index(ci),
925 sp::Statement::Drop {
926 object_type: sp::ObjectType::Table,
927 if_exists,
928 names,
929 ..
930 } => {
931 if names.len() != 1 {
932 return Err(SqlError::Unsupported("multi-table DROP".into()));
933 }
934 Ok(Statement::DropTable(DropTableStmt {
935 name: object_name_to_string(&names[0]),
936 if_exists,
937 }))
938 }
939 sp::Statement::Drop {
940 object_type: sp::ObjectType::Index,
941 if_exists,
942 names,
943 ..
944 } => {
945 if names.len() != 1 {
946 return Err(SqlError::Unsupported("multi-index DROP".into()));
947 }
948 Ok(Statement::DropIndex(DropIndexStmt {
949 index_name: object_name_to_string(&names[0]),
950 if_exists,
951 }))
952 }
953 sp::Statement::AlterTable(at) => convert_alter_table(at),
954 sp::Statement::Insert(insert) => convert_insert(insert),
955 sp::Statement::Query(query) => convert_query(*query),
956 sp::Statement::Update(update) => convert_update(update),
957 sp::Statement::Delete(delete) => convert_delete(delete),
958 sp::Statement::StartTransaction { .. } => Ok(Statement::Begin),
959 sp::Statement::Commit { .. } => Ok(Statement::Commit),
960 sp::Statement::Rollback { .. } => Ok(Statement::Rollback),
961 sp::Statement::Explain {
962 statement, analyze, ..
963 } => {
964 if analyze {
965 return Err(SqlError::Unsupported("EXPLAIN ANALYZE".into()));
966 }
967 let inner = convert_statement(*statement)?;
968 Ok(Statement::Explain(Box::new(inner)))
969 }
970 _ => Err(SqlError::Unsupported(format!("statement type: {}", stmt))),
971 }
972}
973
974fn convert_column_def(
977 col_def: &sp::ColumnDef,
978) -> Result<(ColumnSpec, Option<ForeignKeyDef>, bool)> {
979 let col_name = col_def.name.value.clone();
980 let data_type = convert_data_type(&col_def.data_type)?;
981 let mut nullable = true;
982 let mut is_primary_key = false;
983 let mut default_expr = None;
984 let mut default_sql = None;
985 let mut check_expr = None;
986 let mut check_sql = None;
987 let mut check_name = None;
988 let mut fk_def = None;
989
990 for opt in &col_def.options {
991 match &opt.option {
992 sp::ColumnOption::NotNull => nullable = false,
993 sp::ColumnOption::Null => nullable = true,
994 sp::ColumnOption::PrimaryKey(_) => {
995 is_primary_key = true;
996 nullable = false;
997 }
998 sp::ColumnOption::Default(expr) => {
999 default_sql = Some(expr.to_string());
1000 default_expr = Some(convert_expr(expr)?);
1001 }
1002 sp::ColumnOption::Check(check) => {
1003 check_sql = Some(check.expr.to_string());
1004 let converted = convert_expr(&check.expr)?;
1005 if has_subquery(&converted) {
1006 return Err(SqlError::Unsupported("subquery in CHECK constraint".into()));
1007 }
1008 check_expr = Some(converted);
1009 check_name = check.name.as_ref().map(|n| n.value.clone());
1010 }
1011 sp::ColumnOption::ForeignKey(fk) => {
1012 convert_fk_actions(&fk.on_delete, &fk.on_update)?;
1013 let ftable = object_name_to_string(&fk.foreign_table).to_ascii_lowercase();
1014 let referred: Vec<String> = fk
1015 .referred_columns
1016 .iter()
1017 .map(|i| i.value.to_ascii_lowercase())
1018 .collect();
1019 fk_def = Some(ForeignKeyDef {
1020 name: fk.name.as_ref().map(|n| n.value.clone()),
1021 columns: vec![col_name.to_ascii_lowercase()],
1022 foreign_table: ftable,
1023 referred_columns: referred,
1024 });
1025 }
1026 _ => {}
1027 }
1028 }
1029
1030 let spec = ColumnSpec {
1031 name: col_name,
1032 data_type,
1033 nullable,
1034 is_primary_key,
1035 default_expr,
1036 default_sql,
1037 check_expr,
1038 check_sql,
1039 check_name,
1040 };
1041 Ok((spec, fk_def, is_primary_key))
1042}
1043
1044fn convert_create_table(ct: sp::CreateTable) -> Result<Statement> {
1045 let name = object_name_to_string(&ct.name);
1046 let if_not_exists = ct.if_not_exists;
1047
1048 let mut columns = Vec::new();
1049 let mut inline_pk: Vec<String> = Vec::new();
1050 let mut foreign_keys: Vec<ForeignKeyDef> = Vec::new();
1051
1052 for col_def in &ct.columns {
1053 let (spec, fk_def, was_pk) = convert_column_def(col_def)?;
1054 if was_pk {
1055 inline_pk.push(spec.name.clone());
1056 }
1057 if let Some(fk) = fk_def {
1058 foreign_keys.push(fk);
1059 }
1060 columns.push(spec);
1061 }
1062
1063 let mut check_constraints: Vec<TableCheckConstraint> = Vec::new();
1065
1066 for constraint in &ct.constraints {
1067 match constraint {
1068 sp::TableConstraint::PrimaryKey(pk_constraint) => {
1069 for idx_col in &pk_constraint.columns {
1070 let col_name = match &idx_col.column.expr {
1071 sp::Expr::Identifier(ident) => ident.value.clone(),
1072 _ => continue,
1073 };
1074 if !inline_pk.contains(&col_name) {
1075 inline_pk.push(col_name.clone());
1076 }
1077 if let Some(col) = columns.iter_mut().find(|c| c.name == col_name) {
1078 col.nullable = false;
1079 col.is_primary_key = true;
1080 }
1081 }
1082 }
1083 sp::TableConstraint::Check(check) => {
1084 let sql = check.expr.to_string();
1085 let converted = convert_expr(&check.expr)?;
1086 if has_subquery(&converted) {
1087 return Err(SqlError::Unsupported("subquery in CHECK constraint".into()));
1088 }
1089 check_constraints.push(TableCheckConstraint {
1090 name: check.name.as_ref().map(|n| n.value.clone()),
1091 expr: converted,
1092 sql,
1093 });
1094 }
1095 sp::TableConstraint::ForeignKey(fk) => {
1096 convert_fk_actions(&fk.on_delete, &fk.on_update)?;
1097 let cols: Vec<String> = fk
1098 .columns
1099 .iter()
1100 .map(|i| i.value.to_ascii_lowercase())
1101 .collect();
1102 let ftable = object_name_to_string(&fk.foreign_table).to_ascii_lowercase();
1103 let referred: Vec<String> = fk
1104 .referred_columns
1105 .iter()
1106 .map(|i| i.value.to_ascii_lowercase())
1107 .collect();
1108 foreign_keys.push(ForeignKeyDef {
1109 name: fk.name.as_ref().map(|n| n.value.clone()),
1110 columns: cols,
1111 foreign_table: ftable,
1112 referred_columns: referred,
1113 });
1114 }
1115 _ => {}
1116 }
1117 }
1118
1119 Ok(Statement::CreateTable(CreateTableStmt {
1120 name,
1121 columns,
1122 primary_key: inline_pk,
1123 if_not_exists,
1124 check_constraints,
1125 foreign_keys,
1126 }))
1127}
1128
1129fn convert_alter_table(at: sp::AlterTable) -> Result<Statement> {
1130 let table = object_name_to_string(&at.name);
1131 if at.operations.len() != 1 {
1132 return Err(SqlError::Unsupported(
1133 "ALTER TABLE with multiple operations".into(),
1134 ));
1135 }
1136 let op = match at.operations.into_iter().next().unwrap() {
1137 sp::AlterTableOperation::AddColumn {
1138 column_def,
1139 if_not_exists,
1140 ..
1141 } => {
1142 let (spec, fk, _was_pk) = convert_column_def(&column_def)?;
1143 AlterTableOp::AddColumn {
1144 column: Box::new(spec),
1145 foreign_key: fk,
1146 if_not_exists,
1147 }
1148 }
1149 sp::AlterTableOperation::DropColumn {
1150 column_names,
1151 if_exists,
1152 ..
1153 } => {
1154 if column_names.len() != 1 {
1155 return Err(SqlError::Unsupported(
1156 "DROP COLUMN with multiple columns".into(),
1157 ));
1158 }
1159 AlterTableOp::DropColumn {
1160 name: column_names.into_iter().next().unwrap().value,
1161 if_exists,
1162 }
1163 }
1164 sp::AlterTableOperation::RenameColumn {
1165 old_column_name,
1166 new_column_name,
1167 } => AlterTableOp::RenameColumn {
1168 old_name: old_column_name.value,
1169 new_name: new_column_name.value,
1170 },
1171 sp::AlterTableOperation::RenameTable { table_name } => {
1172 let new_name = match table_name {
1173 sp::RenameTableNameKind::To(name) | sp::RenameTableNameKind::As(name) => {
1174 object_name_to_string(&name)
1175 }
1176 };
1177 AlterTableOp::RenameTable { new_name }
1178 }
1179 other => {
1180 return Err(SqlError::Unsupported(format!(
1181 "ALTER TABLE operation: {other}"
1182 )));
1183 }
1184 };
1185 Ok(Statement::AlterTable(Box::new(AlterTableStmt {
1186 table,
1187 op,
1188 })))
1189}
1190
1191fn convert_fk_actions(
1192 on_delete: &Option<sp::ReferentialAction>,
1193 on_update: &Option<sp::ReferentialAction>,
1194) -> Result<()> {
1195 for action in [on_delete, on_update] {
1196 match action {
1197 None
1198 | Some(sp::ReferentialAction::Restrict)
1199 | Some(sp::ReferentialAction::NoAction) => {}
1200 Some(other) => {
1201 return Err(SqlError::Unsupported(format!(
1202 "FOREIGN KEY action: {other}"
1203 )));
1204 }
1205 }
1206 }
1207 Ok(())
1208}
1209
1210fn convert_create_index(ci: sp::CreateIndex) -> Result<Statement> {
1211 let index_name = ci
1212 .name
1213 .as_ref()
1214 .map(object_name_to_string)
1215 .ok_or_else(|| SqlError::Parse("index name required".into()))?;
1216
1217 let table_name = object_name_to_string(&ci.table_name);
1218
1219 let columns: Vec<String> = ci
1220 .columns
1221 .iter()
1222 .map(|idx_col| match &idx_col.column.expr {
1223 sp::Expr::Identifier(ident) => Ok(ident.value.clone()),
1224 other => Err(SqlError::Unsupported(format!("expression index: {other}"))),
1225 })
1226 .collect::<Result<_>>()?;
1227
1228 if columns.is_empty() {
1229 return Err(SqlError::Parse(
1230 "index must have at least one column".into(),
1231 ));
1232 }
1233
1234 Ok(Statement::CreateIndex(CreateIndexStmt {
1235 index_name,
1236 table_name,
1237 columns,
1238 unique: ci.unique,
1239 if_not_exists: ci.if_not_exists,
1240 }))
1241}
1242
1243fn convert_insert(insert: sp::Insert) -> Result<Statement> {
1244 let table = match &insert.table {
1245 sp::TableObject::TableName(name) => object_name_to_string(name).to_ascii_lowercase(),
1246 _ => return Err(SqlError::Unsupported("INSERT into non-table object".into())),
1247 };
1248
1249 let columns: Vec<String> = insert
1250 .columns
1251 .iter()
1252 .map(|c| c.value.to_ascii_lowercase())
1253 .collect();
1254
1255 let query = insert
1256 .source
1257 .ok_or_else(|| SqlError::Parse("INSERT requires VALUES or SELECT".into()))?;
1258
1259 let source = match *query.body {
1260 sp::SetExpr::Values(sp::Values { rows, .. }) => {
1261 let mut result = Vec::new();
1262 for row in rows {
1263 let mut exprs = Vec::new();
1264 for expr in row {
1265 exprs.push(convert_expr(&expr)?);
1266 }
1267 result.push(exprs);
1268 }
1269 InsertSource::Values(result)
1270 }
1271 _ => {
1272 let (ctes, recursive) = if let Some(ref with) = query.with {
1273 convert_with(with)?
1274 } else {
1275 (vec![], false)
1276 };
1277 let body = convert_query_body(&query)?;
1278 InsertSource::Select(Box::new(SelectQuery {
1279 ctes,
1280 recursive,
1281 body,
1282 }))
1283 }
1284 };
1285
1286 Ok(Statement::Insert(InsertStmt {
1287 table,
1288 columns,
1289 source,
1290 }))
1291}
1292
1293fn convert_select_body(select: &sp::Select) -> Result<SelectStmt> {
1294 let distinct = match &select.distinct {
1295 Some(sp::Distinct::Distinct) => true,
1296 Some(sp::Distinct::On(_)) => {
1297 return Err(SqlError::Unsupported("DISTINCT ON".into()));
1298 }
1299 _ => false,
1300 };
1301
1302 let (from, from_alias, joins) = if select.from.is_empty() {
1304 (String::new(), None, vec![])
1305 } else if select.from.len() == 1 {
1306 let table_with_joins = &select.from[0];
1307 let (name, alias) = match &table_with_joins.relation {
1308 sp::TableFactor::Table { name, alias, .. } => {
1309 let table_name = object_name_to_string(name);
1310 let alias_str = alias.as_ref().map(|a| a.name.value.clone());
1311 (table_name, alias_str)
1312 }
1313 _ => return Err(SqlError::Unsupported("non-table FROM source".into())),
1314 };
1315 let j = table_with_joins
1316 .joins
1317 .iter()
1318 .map(convert_join)
1319 .collect::<Result<Vec<_>>>()?;
1320 (name, alias, j)
1321 } else {
1322 return Err(SqlError::Unsupported("comma-separated FROM tables".into()));
1323 };
1324
1325 let columns: Vec<SelectColumn> = select
1327 .projection
1328 .iter()
1329 .map(convert_select_item)
1330 .collect::<Result<_>>()?;
1331
1332 let where_clause = select.selection.as_ref().map(convert_expr).transpose()?;
1334
1335 let group_by = match &select.group_by {
1337 sp::GroupByExpr::Expressions(exprs, _) => {
1338 exprs.iter().map(convert_expr).collect::<Result<_>>()?
1339 }
1340 sp::GroupByExpr::All(_) => {
1341 return Err(SqlError::Unsupported("GROUP BY ALL".into()));
1342 }
1343 };
1344
1345 let having = select.having.as_ref().map(convert_expr).transpose()?;
1347
1348 Ok(SelectStmt {
1349 columns,
1350 from,
1351 from_alias,
1352 joins,
1353 distinct,
1354 where_clause,
1355 order_by: vec![],
1356 limit: None,
1357 offset: None,
1358 group_by,
1359 having,
1360 })
1361}
1362
1363fn convert_set_expr(set_expr: &sp::SetExpr) -> Result<QueryBody> {
1364 match set_expr {
1365 sp::SetExpr::Select(sel) => Ok(QueryBody::Select(Box::new(convert_select_body(sel)?))),
1366 sp::SetExpr::SetOperation {
1367 op,
1368 set_quantifier,
1369 left,
1370 right,
1371 } => {
1372 let set_op = match op {
1373 sp::SetOperator::Union => SetOp::Union,
1374 sp::SetOperator::Intersect => SetOp::Intersect,
1375 sp::SetOperator::Except | sp::SetOperator::Minus => SetOp::Except,
1376 };
1377 let all = match set_quantifier {
1378 sp::SetQuantifier::All => true,
1379 sp::SetQuantifier::None | sp::SetQuantifier::Distinct => false,
1380 _ => {
1381 return Err(SqlError::Unsupported("BY NAME set operations".into()));
1382 }
1383 };
1384 Ok(QueryBody::Compound(CompoundSelect {
1385 op: set_op,
1386 all,
1387 left: Box::new(convert_set_expr(left)?),
1388 right: Box::new(convert_set_expr(right)?),
1389 order_by: vec![],
1390 limit: None,
1391 offset: None,
1392 }))
1393 }
1394 _ => Err(SqlError::Unsupported("unsupported set expression".into())),
1395 }
1396}
1397
1398fn convert_query_body(query: &sp::Query) -> Result<QueryBody> {
1399 let mut body = convert_set_expr(&query.body)?;
1400
1401 let order_by = if let Some(ref ob) = query.order_by {
1403 match &ob.kind {
1404 sp::OrderByKind::Expressions(exprs) => exprs
1405 .iter()
1406 .map(convert_order_by_expr)
1407 .collect::<Result<_>>()?,
1408 sp::OrderByKind::All { .. } => {
1409 return Err(SqlError::Unsupported("ORDER BY ALL".into()));
1410 }
1411 }
1412 } else {
1413 vec![]
1414 };
1415
1416 let (limit, offset) = match &query.limit_clause {
1418 Some(sp::LimitClause::LimitOffset { limit, offset, .. }) => {
1419 let l = limit.as_ref().map(convert_expr).transpose()?;
1420 let o = offset
1421 .as_ref()
1422 .map(|o| convert_expr(&o.value))
1423 .transpose()?;
1424 (l, o)
1425 }
1426 Some(sp::LimitClause::OffsetCommaLimit { limit, offset }) => {
1427 let l = Some(convert_expr(limit)?);
1428 let o = Some(convert_expr(offset)?);
1429 (l, o)
1430 }
1431 None => (None, None),
1432 };
1433
1434 match &mut body {
1435 QueryBody::Select(sel) => {
1436 sel.order_by = order_by;
1437 sel.limit = limit;
1438 sel.offset = offset;
1439 }
1440 QueryBody::Compound(comp) => {
1441 comp.order_by = order_by;
1442 comp.limit = limit;
1443 comp.offset = offset;
1444 }
1445 }
1446
1447 Ok(body)
1448}
1449
1450fn convert_subquery(query: &sp::Query) -> Result<SelectStmt> {
1451 if query.with.is_some() {
1452 return Err(SqlError::Unsupported("CTEs in subqueries".into()));
1453 }
1454 match convert_query_body(query)? {
1455 QueryBody::Select(s) => Ok(*s),
1456 QueryBody::Compound(_) => Err(SqlError::Unsupported(
1457 "UNION/INTERSECT/EXCEPT in subqueries".into(),
1458 )),
1459 }
1460}
1461
1462fn convert_with(with: &sp::With) -> Result<(Vec<CteDefinition>, bool)> {
1463 let mut names = std::collections::HashSet::new();
1464 let mut ctes = Vec::new();
1465 for cte in &with.cte_tables {
1466 let name = cte.alias.name.value.to_ascii_lowercase();
1467 if !names.insert(name.clone()) {
1468 return Err(SqlError::DuplicateCteName(name));
1469 }
1470 let column_aliases: Vec<String> = cte
1471 .alias
1472 .columns
1473 .iter()
1474 .map(|c| c.name.value.to_ascii_lowercase())
1475 .collect();
1476 let body = convert_query_body(&cte.query)?;
1477 ctes.push(CteDefinition {
1478 name,
1479 column_aliases,
1480 body,
1481 });
1482 }
1483 Ok((ctes, with.recursive))
1484}
1485
1486fn convert_query(query: sp::Query) -> Result<Statement> {
1487 let (ctes, recursive) = if let Some(ref with) = query.with {
1488 convert_with(with)?
1489 } else {
1490 (vec![], false)
1491 };
1492 let body = convert_query_body(&query)?;
1493 Ok(Statement::Select(Box::new(SelectQuery {
1494 ctes,
1495 recursive,
1496 body,
1497 })))
1498}
1499
1500fn convert_join(join: &sp::Join) -> Result<JoinClause> {
1501 let (join_type, constraint) = match &join.join_operator {
1502 sp::JoinOperator::Inner(c) => (JoinType::Inner, Some(c)),
1503 sp::JoinOperator::Join(c) => (JoinType::Inner, Some(c)),
1504 sp::JoinOperator::CrossJoin(c) => (JoinType::Cross, Some(c)),
1505 sp::JoinOperator::Left(c) => (JoinType::Left, Some(c)),
1506 sp::JoinOperator::LeftSemi(c) => (JoinType::Left, Some(c)),
1507 sp::JoinOperator::LeftAnti(c) => (JoinType::Left, Some(c)),
1508 sp::JoinOperator::Right(c) => (JoinType::Right, Some(c)),
1509 sp::JoinOperator::RightSemi(c) => (JoinType::Right, Some(c)),
1510 sp::JoinOperator::RightAnti(c) => (JoinType::Right, Some(c)),
1511 other => return Err(SqlError::Unsupported(format!("join type: {other:?}"))),
1512 };
1513
1514 let (name, alias) = match &join.relation {
1515 sp::TableFactor::Table { name, alias, .. } => {
1516 let table_name = object_name_to_string(name);
1517 let alias_str = alias.as_ref().map(|a| a.name.value.clone());
1518 (table_name, alias_str)
1519 }
1520 _ => return Err(SqlError::Unsupported("non-table JOIN source".into())),
1521 };
1522
1523 let on_clause = match constraint {
1524 Some(sp::JoinConstraint::On(expr)) => Some(convert_expr(expr)?),
1525 Some(sp::JoinConstraint::None) | None => None,
1526 Some(other) => return Err(SqlError::Unsupported(format!("join constraint: {other:?}"))),
1527 };
1528
1529 Ok(JoinClause {
1530 join_type,
1531 table: TableRef { name, alias },
1532 on_clause,
1533 })
1534}
1535
1536fn convert_update(update: sp::Update) -> Result<Statement> {
1537 let table = match &update.table.relation {
1538 sp::TableFactor::Table { name, .. } => object_name_to_string(name),
1539 _ => return Err(SqlError::Unsupported("non-table UPDATE target".into())),
1540 };
1541
1542 let assignments = update
1543 .assignments
1544 .iter()
1545 .map(|a| {
1546 let col = match &a.target {
1547 sp::AssignmentTarget::ColumnName(name) => object_name_to_string(name),
1548 _ => return Err(SqlError::Unsupported("tuple assignment".into())),
1549 };
1550 let expr = convert_expr(&a.value)?;
1551 Ok((col, expr))
1552 })
1553 .collect::<Result<_>>()?;
1554
1555 let where_clause = update.selection.as_ref().map(convert_expr).transpose()?;
1556
1557 Ok(Statement::Update(UpdateStmt {
1558 table,
1559 assignments,
1560 where_clause,
1561 }))
1562}
1563
1564fn convert_delete(delete: sp::Delete) -> Result<Statement> {
1565 let table_name = match &delete.from {
1566 sp::FromTable::WithFromKeyword(tables) => {
1567 if tables.len() != 1 {
1568 return Err(SqlError::Unsupported("multi-table DELETE".into()));
1569 }
1570 match &tables[0].relation {
1571 sp::TableFactor::Table { name, .. } => object_name_to_string(name),
1572 _ => return Err(SqlError::Unsupported("non-table DELETE target".into())),
1573 }
1574 }
1575 sp::FromTable::WithoutKeyword(tables) => {
1576 if tables.len() != 1 {
1577 return Err(SqlError::Unsupported("multi-table DELETE".into()));
1578 }
1579 match &tables[0].relation {
1580 sp::TableFactor::Table { name, .. } => object_name_to_string(name),
1581 _ => return Err(SqlError::Unsupported("non-table DELETE target".into())),
1582 }
1583 }
1584 };
1585
1586 let where_clause = delete.selection.as_ref().map(convert_expr).transpose()?;
1587
1588 Ok(Statement::Delete(DeleteStmt {
1589 table: table_name,
1590 where_clause,
1591 }))
1592}
1593
1594fn convert_expr(expr: &sp::Expr) -> Result<Expr> {
1597 match expr {
1598 sp::Expr::Value(v) => convert_value(&v.value),
1599 sp::Expr::Identifier(ident) => Ok(Expr::Column(ident.value.to_ascii_lowercase())),
1600 sp::Expr::CompoundIdentifier(parts) => {
1601 if parts.len() == 2 {
1602 Ok(Expr::QualifiedColumn {
1603 table: parts[0].value.to_ascii_lowercase(),
1604 column: parts[1].value.to_ascii_lowercase(),
1605 })
1606 } else {
1607 Ok(Expr::Column(
1608 parts.last().unwrap().value.to_ascii_lowercase(),
1609 ))
1610 }
1611 }
1612 sp::Expr::BinaryOp { left, op, right } => {
1613 let bin_op = convert_bin_op(op)?;
1614 Ok(Expr::BinaryOp {
1615 left: Box::new(convert_expr(left)?),
1616 op: bin_op,
1617 right: Box::new(convert_expr(right)?),
1618 })
1619 }
1620 sp::Expr::UnaryOp { op, expr } => {
1621 let unary_op = match op {
1622 sp::UnaryOperator::Minus => UnaryOp::Neg,
1623 sp::UnaryOperator::Not => UnaryOp::Not,
1624 _ => return Err(SqlError::Unsupported(format!("unary op: {op}"))),
1625 };
1626 Ok(Expr::UnaryOp {
1627 op: unary_op,
1628 expr: Box::new(convert_expr(expr)?),
1629 })
1630 }
1631 sp::Expr::IsNull(e) => Ok(Expr::IsNull(Box::new(convert_expr(e)?))),
1632 sp::Expr::IsNotNull(e) => Ok(Expr::IsNotNull(Box::new(convert_expr(e)?))),
1633 sp::Expr::Nested(e) => convert_expr(e),
1634 sp::Expr::Function(func) => convert_function(func),
1635 sp::Expr::InSubquery {
1636 expr: e,
1637 subquery,
1638 negated,
1639 } => {
1640 let inner_expr = convert_expr(e)?;
1641 let stmt = convert_subquery(subquery)?;
1642 Ok(Expr::InSubquery {
1643 expr: Box::new(inner_expr),
1644 subquery: Box::new(stmt),
1645 negated: *negated,
1646 })
1647 }
1648 sp::Expr::InList {
1649 expr: e,
1650 list,
1651 negated,
1652 } => {
1653 let inner_expr = convert_expr(e)?;
1654 let items = list.iter().map(convert_expr).collect::<Result<Vec<_>>>()?;
1655 Ok(Expr::InList {
1656 expr: Box::new(inner_expr),
1657 list: items,
1658 negated: *negated,
1659 })
1660 }
1661 sp::Expr::Exists { subquery, negated } => {
1662 let stmt = convert_subquery(subquery)?;
1663 Ok(Expr::Exists {
1664 subquery: Box::new(stmt),
1665 negated: *negated,
1666 })
1667 }
1668 sp::Expr::Subquery(query) => {
1669 let stmt = convert_subquery(query)?;
1670 Ok(Expr::ScalarSubquery(Box::new(stmt)))
1671 }
1672 sp::Expr::Between {
1673 expr: e,
1674 negated,
1675 low,
1676 high,
1677 } => Ok(Expr::Between {
1678 expr: Box::new(convert_expr(e)?),
1679 low: Box::new(convert_expr(low)?),
1680 high: Box::new(convert_expr(high)?),
1681 negated: *negated,
1682 }),
1683 sp::Expr::Like {
1684 expr: e,
1685 negated,
1686 pattern,
1687 escape_char,
1688 ..
1689 } => {
1690 let esc = escape_char
1691 .as_ref()
1692 .map(convert_escape_value)
1693 .transpose()?
1694 .map(Box::new);
1695 Ok(Expr::Like {
1696 expr: Box::new(convert_expr(e)?),
1697 pattern: Box::new(convert_expr(pattern)?),
1698 escape: esc,
1699 negated: *negated,
1700 })
1701 }
1702 sp::Expr::ILike {
1703 expr: e,
1704 negated,
1705 pattern,
1706 escape_char,
1707 ..
1708 } => {
1709 let esc = escape_char
1710 .as_ref()
1711 .map(convert_escape_value)
1712 .transpose()?
1713 .map(Box::new);
1714 Ok(Expr::Like {
1715 expr: Box::new(convert_expr(e)?),
1716 pattern: Box::new(convert_expr(pattern)?),
1717 escape: esc,
1718 negated: *negated,
1719 })
1720 }
1721 sp::Expr::Case {
1722 operand,
1723 conditions,
1724 else_result,
1725 ..
1726 } => {
1727 let op = operand
1728 .as_ref()
1729 .map(|e| convert_expr(e))
1730 .transpose()?
1731 .map(Box::new);
1732 let conds: Vec<(Expr, Expr)> = conditions
1733 .iter()
1734 .map(|cw| Ok((convert_expr(&cw.condition)?, convert_expr(&cw.result)?)))
1735 .collect::<Result<_>>()?;
1736 let else_r = else_result
1737 .as_ref()
1738 .map(|e| convert_expr(e))
1739 .transpose()?
1740 .map(Box::new);
1741 Ok(Expr::Case {
1742 operand: op,
1743 conditions: conds,
1744 else_result: else_r,
1745 })
1746 }
1747 sp::Expr::Cast {
1748 expr: e,
1749 data_type: dt,
1750 ..
1751 } => {
1752 let target = convert_data_type(dt)?;
1753 Ok(Expr::Cast {
1754 expr: Box::new(convert_expr(e)?),
1755 data_type: target,
1756 })
1757 }
1758 sp::Expr::Substring {
1759 expr: e,
1760 substring_from,
1761 substring_for,
1762 ..
1763 } => {
1764 let mut args = vec![convert_expr(e)?];
1765 if let Some(from) = substring_from {
1766 args.push(convert_expr(from)?);
1767 }
1768 if let Some(f) = substring_for {
1769 args.push(convert_expr(f)?);
1770 }
1771 Ok(Expr::Function {
1772 name: "SUBSTR".into(),
1773 args,
1774 })
1775 }
1776 sp::Expr::Trim {
1777 expr: e,
1778 trim_where,
1779 trim_what,
1780 trim_characters,
1781 } => {
1782 let fn_name = match trim_where {
1783 Some(sp::TrimWhereField::Leading) => "LTRIM",
1784 Some(sp::TrimWhereField::Trailing) => "RTRIM",
1785 _ => "TRIM",
1786 };
1787 let mut args = vec![convert_expr(e)?];
1788 if let Some(what) = trim_what {
1789 args.push(convert_expr(what)?);
1790 } else if let Some(chars) = trim_characters {
1791 if let Some(first) = chars.first() {
1792 args.push(convert_expr(first)?);
1793 }
1794 }
1795 Ok(Expr::Function {
1796 name: fn_name.into(),
1797 args,
1798 })
1799 }
1800 sp::Expr::Ceil { expr: e, .. } => Ok(Expr::Function {
1801 name: "CEIL".into(),
1802 args: vec![convert_expr(e)?],
1803 }),
1804 sp::Expr::Floor { expr: e, .. } => Ok(Expr::Function {
1805 name: "FLOOR".into(),
1806 args: vec![convert_expr(e)?],
1807 }),
1808 sp::Expr::Position { expr: e, r#in } => Ok(Expr::Function {
1809 name: "INSTR".into(),
1810 args: vec![convert_expr(r#in)?, convert_expr(e)?],
1811 }),
1812 _ => Err(SqlError::Unsupported(format!("expression: {expr}"))),
1813 }
1814}
1815
1816fn convert_value(val: &sp::Value) -> Result<Expr> {
1817 match val {
1818 sp::Value::Number(n, _) => {
1819 if let Ok(i) = n.parse::<i64>() {
1820 Ok(Expr::Literal(Value::Integer(i)))
1821 } else if let Ok(f) = n.parse::<f64>() {
1822 Ok(Expr::Literal(Value::Real(f)))
1823 } else {
1824 Err(SqlError::InvalidValue(format!("cannot parse number: {n}")))
1825 }
1826 }
1827 sp::Value::SingleQuotedString(s) => Ok(Expr::Literal(Value::Text(s.as_str().into()))),
1828 sp::Value::Boolean(b) => Ok(Expr::Literal(Value::Boolean(*b))),
1829 sp::Value::Null => Ok(Expr::Literal(Value::Null)),
1830 sp::Value::Placeholder(s) => {
1831 let idx_str = s
1832 .strip_prefix('$')
1833 .ok_or_else(|| SqlError::Parse(format!("invalid placeholder: {s}")))?;
1834 let idx: usize = idx_str
1835 .parse()
1836 .map_err(|_| SqlError::Parse(format!("invalid placeholder index: {s}")))?;
1837 if idx == 0 {
1838 return Err(SqlError::Parse("placeholder index must be >= 1".into()));
1839 }
1840 Ok(Expr::Parameter(idx))
1841 }
1842 _ => Err(SqlError::Unsupported(format!("value type: {val}"))),
1843 }
1844}
1845
1846fn convert_escape_value(val: &sp::Value) -> Result<Expr> {
1847 match val {
1848 sp::Value::SingleQuotedString(s) => Ok(Expr::Literal(Value::Text(s.as_str().into()))),
1849 _ => Err(SqlError::Unsupported(format!("ESCAPE value: {val}"))),
1850 }
1851}
1852
1853fn convert_bin_op(op: &sp::BinaryOperator) -> Result<BinOp> {
1854 match op {
1855 sp::BinaryOperator::Plus => Ok(BinOp::Add),
1856 sp::BinaryOperator::Minus => Ok(BinOp::Sub),
1857 sp::BinaryOperator::Multiply => Ok(BinOp::Mul),
1858 sp::BinaryOperator::Divide => Ok(BinOp::Div),
1859 sp::BinaryOperator::Modulo => Ok(BinOp::Mod),
1860 sp::BinaryOperator::Eq => Ok(BinOp::Eq),
1861 sp::BinaryOperator::NotEq => Ok(BinOp::NotEq),
1862 sp::BinaryOperator::Lt => Ok(BinOp::Lt),
1863 sp::BinaryOperator::Gt => Ok(BinOp::Gt),
1864 sp::BinaryOperator::LtEq => Ok(BinOp::LtEq),
1865 sp::BinaryOperator::GtEq => Ok(BinOp::GtEq),
1866 sp::BinaryOperator::And => Ok(BinOp::And),
1867 sp::BinaryOperator::Or => Ok(BinOp::Or),
1868 sp::BinaryOperator::StringConcat => Ok(BinOp::Concat),
1869 _ => Err(SqlError::Unsupported(format!("binary op: {op}"))),
1870 }
1871}
1872
1873fn convert_function(func: &sp::Function) -> Result<Expr> {
1874 let name = object_name_to_string(&func.name).to_ascii_uppercase();
1875
1876 match &func.args {
1878 sp::FunctionArguments::List(list) => {
1879 if list.args.is_empty() && name == "COUNT" {
1880 return Ok(Expr::CountStar);
1881 }
1882 let args = list
1883 .args
1884 .iter()
1885 .map(|arg| match arg {
1886 sp::FunctionArg::Unnamed(sp::FunctionArgExpr::Expr(e)) => convert_expr(e),
1887 sp::FunctionArg::Unnamed(sp::FunctionArgExpr::Wildcard) => {
1888 if name == "COUNT" {
1889 Ok(Expr::CountStar)
1890 } else {
1891 Err(SqlError::Unsupported(format!("{name}(*)")))
1892 }
1893 }
1894 _ => Err(SqlError::Unsupported(format!(
1895 "function arg type in {name}"
1896 ))),
1897 })
1898 .collect::<Result<Vec<_>>>()?;
1899
1900 if name == "COUNT" && args.len() == 1 && matches!(args[0], Expr::CountStar) {
1901 return Ok(Expr::CountStar);
1902 }
1903
1904 if name == "COALESCE" {
1905 if args.is_empty() {
1906 return Err(SqlError::Parse(
1907 "COALESCE requires at least one argument".into(),
1908 ));
1909 }
1910 return Ok(Expr::Coalesce(args));
1911 }
1912
1913 if name == "NULLIF" {
1914 if args.len() != 2 {
1915 return Err(SqlError::Parse(
1916 "NULLIF requires exactly two arguments".into(),
1917 ));
1918 }
1919 return Ok(Expr::Case {
1920 operand: None,
1921 conditions: vec![(
1922 Expr::BinaryOp {
1923 left: Box::new(args[0].clone()),
1924 op: BinOp::Eq,
1925 right: Box::new(args[1].clone()),
1926 },
1927 Expr::Literal(Value::Null),
1928 )],
1929 else_result: Some(Box::new(args[0].clone())),
1930 });
1931 }
1932
1933 if name == "IIF" {
1934 if args.len() != 3 {
1935 return Err(SqlError::Parse(
1936 "IIF requires exactly three arguments".into(),
1937 ));
1938 }
1939 return Ok(Expr::Case {
1940 operand: None,
1941 conditions: vec![(args[0].clone(), args[1].clone())],
1942 else_result: Some(Box::new(args[2].clone())),
1943 });
1944 }
1945
1946 Ok(Expr::Function { name, args })
1947 }
1948 sp::FunctionArguments::None => {
1949 if name == "COUNT" {
1950 Ok(Expr::CountStar)
1951 } else {
1952 Ok(Expr::Function { name, args: vec![] })
1953 }
1954 }
1955 sp::FunctionArguments::Subquery(_) => {
1956 Err(SqlError::Unsupported("subquery in function".into()))
1957 }
1958 }
1959}
1960
1961fn convert_select_item(item: &sp::SelectItem) -> Result<SelectColumn> {
1962 match item {
1963 sp::SelectItem::Wildcard(_) => Ok(SelectColumn::AllColumns),
1964 sp::SelectItem::UnnamedExpr(e) => {
1965 let expr = convert_expr(e)?;
1966 Ok(SelectColumn::Expr { expr, alias: None })
1967 }
1968 sp::SelectItem::ExprWithAlias { expr, alias } => {
1969 let expr = convert_expr(expr)?;
1970 Ok(SelectColumn::Expr {
1971 expr,
1972 alias: Some(alias.value.clone()),
1973 })
1974 }
1975 sp::SelectItem::QualifiedWildcard(_, _) => {
1976 Err(SqlError::Unsupported("qualified wildcard (table.*)".into()))
1977 }
1978 }
1979}
1980
1981fn convert_order_by_expr(expr: &sp::OrderByExpr) -> Result<OrderByItem> {
1982 let e = convert_expr(&expr.expr)?;
1983 let descending = expr.options.asc.map(|asc| !asc).unwrap_or(false);
1984 let nulls_first = expr.options.nulls_first;
1985
1986 Ok(OrderByItem {
1987 expr: e,
1988 descending,
1989 nulls_first,
1990 })
1991}
1992
1993fn convert_data_type(dt: &sp::DataType) -> Result<DataType> {
1996 match dt {
1997 sp::DataType::Int(_)
1998 | sp::DataType::Integer(_)
1999 | sp::DataType::BigInt(_)
2000 | sp::DataType::SmallInt(_)
2001 | sp::DataType::TinyInt(_)
2002 | sp::DataType::Int2(_)
2003 | sp::DataType::Int4(_)
2004 | sp::DataType::Int8(_) => Ok(DataType::Integer),
2005
2006 sp::DataType::Real
2007 | sp::DataType::Double(..)
2008 | sp::DataType::DoublePrecision
2009 | sp::DataType::Float(_)
2010 | sp::DataType::Float4
2011 | sp::DataType::Float64 => Ok(DataType::Real),
2012
2013 sp::DataType::Varchar(_)
2014 | sp::DataType::Text
2015 | sp::DataType::Char(_)
2016 | sp::DataType::Character(_)
2017 | sp::DataType::String(_) => Ok(DataType::Text),
2018
2019 sp::DataType::Blob(_) | sp::DataType::Bytea => Ok(DataType::Blob),
2020
2021 sp::DataType::Boolean | sp::DataType::Bool => Ok(DataType::Boolean),
2022
2023 _ => Err(SqlError::Unsupported(format!("data type: {dt}"))),
2024 }
2025}
2026
2027fn object_name_to_string(name: &sp::ObjectName) -> String {
2030 name.0
2031 .iter()
2032 .filter_map(|p| match p {
2033 sp::ObjectNamePart::Identifier(ident) => Some(ident.value.clone()),
2034 _ => None,
2035 })
2036 .collect::<Vec<_>>()
2037 .join(".")
2038}
2039
2040#[cfg(test)]
2041mod tests {
2042 use super::*;
2043
2044 #[test]
2045 fn parse_create_table() {
2046 let stmt = parse_sql(
2047 "CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT NOT NULL, age INTEGER)",
2048 )
2049 .unwrap();
2050
2051 match stmt {
2052 Statement::CreateTable(ct) => {
2053 assert_eq!(ct.name, "users");
2054 assert_eq!(ct.columns.len(), 3);
2055 assert_eq!(ct.columns[0].name, "id");
2056 assert_eq!(ct.columns[0].data_type, DataType::Integer);
2057 assert!(ct.columns[0].is_primary_key);
2058 assert!(!ct.columns[0].nullable);
2059 assert_eq!(ct.columns[1].name, "name");
2060 assert_eq!(ct.columns[1].data_type, DataType::Text);
2061 assert!(!ct.columns[1].nullable);
2062 assert_eq!(ct.columns[2].name, "age");
2063 assert!(ct.columns[2].nullable);
2064 assert_eq!(ct.primary_key, vec!["id"]);
2065 }
2066 _ => panic!("expected CreateTable"),
2067 }
2068 }
2069
2070 #[test]
2071 fn parse_create_table_if_not_exists() {
2072 let stmt = parse_sql("CREATE TABLE IF NOT EXISTS t (id INT PRIMARY KEY)").unwrap();
2073 match stmt {
2074 Statement::CreateTable(ct) => assert!(ct.if_not_exists),
2075 _ => panic!("expected CreateTable"),
2076 }
2077 }
2078
2079 #[test]
2080 fn parse_drop_table() {
2081 let stmt = parse_sql("DROP TABLE users").unwrap();
2082 match stmt {
2083 Statement::DropTable(dt) => {
2084 assert_eq!(dt.name, "users");
2085 assert!(!dt.if_exists);
2086 }
2087 _ => panic!("expected DropTable"),
2088 }
2089 }
2090
2091 #[test]
2092 fn parse_drop_table_if_exists() {
2093 let stmt = parse_sql("DROP TABLE IF EXISTS users").unwrap();
2094 match stmt {
2095 Statement::DropTable(dt) => assert!(dt.if_exists),
2096 _ => panic!("expected DropTable"),
2097 }
2098 }
2099
2100 #[test]
2101 fn parse_insert() {
2102 let stmt =
2103 parse_sql("INSERT INTO users (id, name) VALUES (1, 'Alice'), (2, 'Bob')").unwrap();
2104
2105 match stmt {
2106 Statement::Insert(ins) => {
2107 assert_eq!(ins.table, "users");
2108 assert_eq!(ins.columns, vec!["id", "name"]);
2109 let values = match &ins.source {
2110 InsertSource::Values(v) => v,
2111 _ => panic!("expected Values"),
2112 };
2113 assert_eq!(values.len(), 2);
2114 assert!(matches!(values[0][0], Expr::Literal(Value::Integer(1))));
2115 assert!(matches!(&values[0][1], Expr::Literal(Value::Text(s)) if s == "Alice"));
2116 }
2117 _ => panic!("expected Insert"),
2118 }
2119 }
2120
2121 #[test]
2122 fn parse_select_all() {
2123 let stmt = parse_sql("SELECT * FROM users").unwrap();
2124 match stmt {
2125 Statement::Select(sq) => match sq.body {
2126 QueryBody::Select(sel) => {
2127 assert_eq!(sel.from, "users");
2128 assert!(matches!(sel.columns[0], SelectColumn::AllColumns));
2129 assert!(sel.where_clause.is_none());
2130 }
2131 _ => panic!("expected QueryBody::Select"),
2132 },
2133 _ => panic!("expected Select"),
2134 }
2135 }
2136
2137 #[test]
2138 fn parse_select_where() {
2139 let stmt = parse_sql("SELECT id, name FROM users WHERE age > 18").unwrap();
2140 match stmt {
2141 Statement::Select(sq) => match sq.body {
2142 QueryBody::Select(sel) => {
2143 assert_eq!(sel.columns.len(), 2);
2144 assert!(sel.where_clause.is_some());
2145 }
2146 _ => panic!("expected QueryBody::Select"),
2147 },
2148 _ => panic!("expected Select"),
2149 }
2150 }
2151
2152 #[test]
2153 fn parse_select_order_limit() {
2154 let stmt = parse_sql("SELECT * FROM users ORDER BY name ASC LIMIT 10 OFFSET 5").unwrap();
2155 match stmt {
2156 Statement::Select(sq) => match sq.body {
2157 QueryBody::Select(sel) => {
2158 assert_eq!(sel.order_by.len(), 1);
2159 assert!(!sel.order_by[0].descending);
2160 assert!(sel.limit.is_some());
2161 assert!(sel.offset.is_some());
2162 }
2163 _ => panic!("expected QueryBody::Select"),
2164 },
2165 _ => panic!("expected Select"),
2166 }
2167 }
2168
2169 #[test]
2170 fn parse_update() {
2171 let stmt = parse_sql("UPDATE users SET name = 'Bob' WHERE id = 1").unwrap();
2172 match stmt {
2173 Statement::Update(upd) => {
2174 assert_eq!(upd.table, "users");
2175 assert_eq!(upd.assignments.len(), 1);
2176 assert_eq!(upd.assignments[0].0, "name");
2177 assert!(upd.where_clause.is_some());
2178 }
2179 _ => panic!("expected Update"),
2180 }
2181 }
2182
2183 #[test]
2184 fn parse_delete() {
2185 let stmt = parse_sql("DELETE FROM users WHERE id = 1").unwrap();
2186 match stmt {
2187 Statement::Delete(del) => {
2188 assert_eq!(del.table, "users");
2189 assert!(del.where_clause.is_some());
2190 }
2191 _ => panic!("expected Delete"),
2192 }
2193 }
2194
2195 #[test]
2196 fn parse_aggregate() {
2197 let stmt = parse_sql("SELECT COUNT(*), SUM(age) FROM users").unwrap();
2198 match stmt {
2199 Statement::Select(sq) => match sq.body {
2200 QueryBody::Select(sel) => {
2201 assert_eq!(sel.columns.len(), 2);
2202 match &sel.columns[0] {
2203 SelectColumn::Expr {
2204 expr: Expr::CountStar,
2205 ..
2206 } => {}
2207 other => panic!("expected CountStar, got {other:?}"),
2208 }
2209 }
2210 _ => panic!("expected QueryBody::Select"),
2211 },
2212 _ => panic!("expected Select"),
2213 }
2214 }
2215
2216 #[test]
2217 fn parse_group_by_having() {
2218 let stmt = parse_sql(
2219 "SELECT department, COUNT(*) FROM employees GROUP BY department HAVING COUNT(*) > 5",
2220 )
2221 .unwrap();
2222 match stmt {
2223 Statement::Select(sq) => match sq.body {
2224 QueryBody::Select(sel) => {
2225 assert_eq!(sel.group_by.len(), 1);
2226 assert!(sel.having.is_some());
2227 }
2228 _ => panic!("expected QueryBody::Select"),
2229 },
2230 _ => panic!("expected Select"),
2231 }
2232 }
2233
2234 #[test]
2235 fn parse_expressions() {
2236 let stmt = parse_sql("SELECT id + 1, -price, NOT active FROM items").unwrap();
2237 match stmt {
2238 Statement::Select(sq) => match sq.body {
2239 QueryBody::Select(sel) => {
2240 assert_eq!(sel.columns.len(), 3);
2241 match &sel.columns[0] {
2243 SelectColumn::Expr {
2244 expr: Expr::BinaryOp { op: BinOp::Add, .. },
2245 ..
2246 } => {}
2247 other => panic!("expected BinaryOp Add, got {other:?}"),
2248 }
2249 match &sel.columns[1] {
2251 SelectColumn::Expr {
2252 expr:
2253 Expr::UnaryOp {
2254 op: UnaryOp::Neg, ..
2255 },
2256 ..
2257 } => {}
2258 other => panic!("expected UnaryOp Neg, got {other:?}"),
2259 }
2260 match &sel.columns[2] {
2262 SelectColumn::Expr {
2263 expr:
2264 Expr::UnaryOp {
2265 op: UnaryOp::Not, ..
2266 },
2267 ..
2268 } => {}
2269 other => panic!("expected UnaryOp Not, got {other:?}"),
2270 }
2271 }
2272 _ => panic!("expected QueryBody::Select"),
2273 },
2274 _ => panic!("expected Select"),
2275 }
2276 }
2277
2278 #[test]
2279 fn parse_is_null() {
2280 let stmt = parse_sql("SELECT * FROM t WHERE x IS NULL").unwrap();
2281 match stmt {
2282 Statement::Select(sq) => match sq.body {
2283 QueryBody::Select(sel) => {
2284 assert!(matches!(sel.where_clause, Some(Expr::IsNull(_))));
2285 }
2286 _ => panic!("expected QueryBody::Select"),
2287 },
2288 _ => panic!("expected Select"),
2289 }
2290 }
2291
2292 #[test]
2293 fn parse_inner_join() {
2294 let stmt = parse_sql("SELECT * FROM a JOIN b ON a.id = b.id").unwrap();
2295 match stmt {
2296 Statement::Select(sq) => match sq.body {
2297 QueryBody::Select(sel) => {
2298 assert_eq!(sel.from, "a");
2299 assert_eq!(sel.joins.len(), 1);
2300 assert_eq!(sel.joins[0].join_type, JoinType::Inner);
2301 assert_eq!(sel.joins[0].table.name, "b");
2302 assert!(sel.joins[0].on_clause.is_some());
2303 }
2304 _ => panic!("expected QueryBody::Select"),
2305 },
2306 _ => panic!("expected Select"),
2307 }
2308 }
2309
2310 #[test]
2311 fn parse_inner_join_explicit() {
2312 let stmt = parse_sql("SELECT * FROM a INNER JOIN b ON a.id = b.a_id").unwrap();
2313 match stmt {
2314 Statement::Select(sq) => match sq.body {
2315 QueryBody::Select(sel) => {
2316 assert_eq!(sel.joins.len(), 1);
2317 assert_eq!(sel.joins[0].join_type, JoinType::Inner);
2318 }
2319 _ => panic!("expected QueryBody::Select"),
2320 },
2321 _ => panic!("expected Select"),
2322 }
2323 }
2324
2325 #[test]
2326 fn parse_cross_join() {
2327 let stmt = parse_sql("SELECT * FROM a CROSS JOIN b").unwrap();
2328 match stmt {
2329 Statement::Select(sq) => match sq.body {
2330 QueryBody::Select(sel) => {
2331 assert_eq!(sel.joins.len(), 1);
2332 assert_eq!(sel.joins[0].join_type, JoinType::Cross);
2333 assert!(sel.joins[0].on_clause.is_none());
2334 }
2335 _ => panic!("expected QueryBody::Select"),
2336 },
2337 _ => panic!("expected Select"),
2338 }
2339 }
2340
2341 #[test]
2342 fn parse_left_join() {
2343 let stmt = parse_sql("SELECT * FROM a LEFT JOIN b ON a.id = b.a_id").unwrap();
2344 match stmt {
2345 Statement::Select(sq) => match sq.body {
2346 QueryBody::Select(sel) => {
2347 assert_eq!(sel.joins.len(), 1);
2348 assert_eq!(sel.joins[0].join_type, JoinType::Left);
2349 }
2350 _ => panic!("expected QueryBody::Select"),
2351 },
2352 _ => panic!("expected Select"),
2353 }
2354 }
2355
2356 #[test]
2357 fn parse_table_alias() {
2358 let stmt = parse_sql("SELECT u.id FROM users u JOIN orders o ON u.id = o.user_id").unwrap();
2359 match stmt {
2360 Statement::Select(sq) => match sq.body {
2361 QueryBody::Select(sel) => {
2362 assert_eq!(sel.from, "users");
2363 assert_eq!(sel.from_alias.as_deref(), Some("u"));
2364 assert_eq!(sel.joins[0].table.name, "orders");
2365 assert_eq!(sel.joins[0].table.alias.as_deref(), Some("o"));
2366 }
2367 _ => panic!("expected QueryBody::Select"),
2368 },
2369 _ => panic!("expected Select"),
2370 }
2371 }
2372
2373 #[test]
2374 fn parse_multi_join() {
2375 let stmt =
2376 parse_sql("SELECT * FROM a JOIN b ON a.id = b.a_id JOIN c ON b.id = c.b_id").unwrap();
2377 match stmt {
2378 Statement::Select(sq) => match sq.body {
2379 QueryBody::Select(sel) => {
2380 assert_eq!(sel.joins.len(), 2);
2381 }
2382 _ => panic!("expected QueryBody::Select"),
2383 },
2384 _ => panic!("expected Select"),
2385 }
2386 }
2387
2388 #[test]
2389 fn parse_qualified_column() {
2390 let stmt = parse_sql("SELECT u.id, u.name FROM users u").unwrap();
2391 match stmt {
2392 Statement::Select(sq) => match sq.body {
2393 QueryBody::Select(sel) => match &sel.columns[0] {
2394 SelectColumn::Expr {
2395 expr: Expr::QualifiedColumn { table, column },
2396 ..
2397 } => {
2398 assert_eq!(table, "u");
2399 assert_eq!(column, "id");
2400 }
2401 other => panic!("expected QualifiedColumn, got {other:?}"),
2402 },
2403 _ => panic!("expected QueryBody::Select"),
2404 },
2405 _ => panic!("expected Select"),
2406 }
2407 }
2408
2409 #[test]
2410 fn reject_subquery() {
2411 assert!(parse_sql("SELECT * FROM (SELECT 1)").is_err());
2412 }
2413
2414 #[test]
2415 fn parse_type_mapping() {
2416 let stmt = parse_sql(
2417 "CREATE TABLE t (a INT PRIMARY KEY, b BIGINT, c SMALLINT, d REAL, e DOUBLE PRECISION, f VARCHAR(255), g BOOLEAN, h BLOB, i BYTEA)"
2418 ).unwrap();
2419 match stmt {
2420 Statement::CreateTable(ct) => {
2421 assert_eq!(ct.columns[0].data_type, DataType::Integer); assert_eq!(ct.columns[1].data_type, DataType::Integer); assert_eq!(ct.columns[2].data_type, DataType::Integer); assert_eq!(ct.columns[3].data_type, DataType::Real); assert_eq!(ct.columns[4].data_type, DataType::Real); assert_eq!(ct.columns[5].data_type, DataType::Text); assert_eq!(ct.columns[6].data_type, DataType::Boolean); assert_eq!(ct.columns[7].data_type, DataType::Blob); assert_eq!(ct.columns[8].data_type, DataType::Blob); }
2431 _ => panic!("expected CreateTable"),
2432 }
2433 }
2434
2435 #[test]
2436 fn parse_boolean_literals() {
2437 let stmt = parse_sql("INSERT INTO t (a, b) VALUES (true, false)").unwrap();
2438 match stmt {
2439 Statement::Insert(ins) => {
2440 let values = match &ins.source {
2441 InsertSource::Values(v) => v,
2442 _ => panic!("expected Values"),
2443 };
2444 assert!(matches!(values[0][0], Expr::Literal(Value::Boolean(true))));
2445 assert!(matches!(values[0][1], Expr::Literal(Value::Boolean(false))));
2446 }
2447 _ => panic!("expected Insert"),
2448 }
2449 }
2450
2451 #[test]
2452 fn parse_null_literal() {
2453 let stmt = parse_sql("INSERT INTO t (a) VALUES (NULL)").unwrap();
2454 match stmt {
2455 Statement::Insert(ins) => {
2456 let values = match &ins.source {
2457 InsertSource::Values(v) => v,
2458 _ => panic!("expected Values"),
2459 };
2460 assert!(matches!(values[0][0], Expr::Literal(Value::Null)));
2461 }
2462 _ => panic!("expected Insert"),
2463 }
2464 }
2465
2466 #[test]
2467 fn parse_alias() {
2468 let stmt = parse_sql("SELECT id AS user_id FROM users").unwrap();
2469 match stmt {
2470 Statement::Select(sq) => match sq.body {
2471 QueryBody::Select(sel) => match &sel.columns[0] {
2472 SelectColumn::Expr { alias: Some(a), .. } => assert_eq!(a, "user_id"),
2473 other => panic!("expected alias, got {other:?}"),
2474 },
2475 _ => panic!("expected QueryBody::Select"),
2476 },
2477 _ => panic!("expected Select"),
2478 }
2479 }
2480
2481 #[test]
2482 fn parse_begin() {
2483 let stmt = parse_sql("BEGIN").unwrap();
2484 assert!(matches!(stmt, Statement::Begin));
2485 }
2486
2487 #[test]
2488 fn parse_begin_transaction() {
2489 let stmt = parse_sql("BEGIN TRANSACTION").unwrap();
2490 assert!(matches!(stmt, Statement::Begin));
2491 }
2492
2493 #[test]
2494 fn parse_commit() {
2495 let stmt = parse_sql("COMMIT").unwrap();
2496 assert!(matches!(stmt, Statement::Commit));
2497 }
2498
2499 #[test]
2500 fn parse_rollback() {
2501 let stmt = parse_sql("ROLLBACK").unwrap();
2502 assert!(matches!(stmt, Statement::Rollback));
2503 }
2504
2505 #[test]
2506 fn parse_select_distinct() {
2507 let stmt = parse_sql("SELECT DISTINCT name FROM users").unwrap();
2508 match stmt {
2509 Statement::Select(sq) => match sq.body {
2510 QueryBody::Select(sel) => {
2511 assert!(sel.distinct);
2512 assert_eq!(sel.columns.len(), 1);
2513 }
2514 _ => panic!("expected QueryBody::Select"),
2515 },
2516 _ => panic!("expected Select"),
2517 }
2518 }
2519
2520 #[test]
2521 fn parse_select_without_distinct() {
2522 let stmt = parse_sql("SELECT name FROM users").unwrap();
2523 match stmt {
2524 Statement::Select(sq) => match sq.body {
2525 QueryBody::Select(sel) => {
2526 assert!(!sel.distinct);
2527 }
2528 _ => panic!("expected QueryBody::Select"),
2529 },
2530 _ => panic!("expected Select"),
2531 }
2532 }
2533
2534 #[test]
2535 fn parse_select_distinct_all_columns() {
2536 let stmt = parse_sql("SELECT DISTINCT * FROM users").unwrap();
2537 match stmt {
2538 Statement::Select(sq) => match sq.body {
2539 QueryBody::Select(sel) => {
2540 assert!(sel.distinct);
2541 assert!(matches!(sel.columns[0], SelectColumn::AllColumns));
2542 }
2543 _ => panic!("expected QueryBody::Select"),
2544 },
2545 _ => panic!("expected Select"),
2546 }
2547 }
2548
2549 #[test]
2550 fn reject_distinct_on() {
2551 assert!(parse_sql("SELECT DISTINCT ON (id) * FROM users").is_err());
2552 }
2553
2554 #[test]
2555 fn parse_create_index() {
2556 let stmt = parse_sql("CREATE INDEX idx_name ON users (name)").unwrap();
2557 match stmt {
2558 Statement::CreateIndex(ci) => {
2559 assert_eq!(ci.index_name, "idx_name");
2560 assert_eq!(ci.table_name, "users");
2561 assert_eq!(ci.columns, vec!["name"]);
2562 assert!(!ci.unique);
2563 assert!(!ci.if_not_exists);
2564 }
2565 _ => panic!("expected CreateIndex"),
2566 }
2567 }
2568
2569 #[test]
2570 fn parse_create_unique_index() {
2571 let stmt = parse_sql("CREATE UNIQUE INDEX idx_email ON users (email)").unwrap();
2572 match stmt {
2573 Statement::CreateIndex(ci) => {
2574 assert!(ci.unique);
2575 assert_eq!(ci.columns, vec!["email"]);
2576 }
2577 _ => panic!("expected CreateIndex"),
2578 }
2579 }
2580
2581 #[test]
2582 fn parse_create_index_if_not_exists() {
2583 let stmt = parse_sql("CREATE INDEX IF NOT EXISTS idx_x ON t (a)").unwrap();
2584 match stmt {
2585 Statement::CreateIndex(ci) => assert!(ci.if_not_exists),
2586 _ => panic!("expected CreateIndex"),
2587 }
2588 }
2589
2590 #[test]
2591 fn parse_create_index_multi_column() {
2592 let stmt = parse_sql("CREATE INDEX idx_multi ON t (a, b, c)").unwrap();
2593 match stmt {
2594 Statement::CreateIndex(ci) => {
2595 assert_eq!(ci.columns, vec!["a", "b", "c"]);
2596 }
2597 _ => panic!("expected CreateIndex"),
2598 }
2599 }
2600
2601 #[test]
2602 fn parse_drop_index() {
2603 let stmt = parse_sql("DROP INDEX idx_name").unwrap();
2604 match stmt {
2605 Statement::DropIndex(di) => {
2606 assert_eq!(di.index_name, "idx_name");
2607 assert!(!di.if_exists);
2608 }
2609 _ => panic!("expected DropIndex"),
2610 }
2611 }
2612
2613 #[test]
2614 fn parse_drop_index_if_exists() {
2615 let stmt = parse_sql("DROP INDEX IF EXISTS idx_name").unwrap();
2616 match stmt {
2617 Statement::DropIndex(di) => {
2618 assert!(di.if_exists);
2619 }
2620 _ => panic!("expected DropIndex"),
2621 }
2622 }
2623
2624 #[test]
2625 fn parse_explain_select() {
2626 let stmt = parse_sql("EXPLAIN SELECT * FROM users WHERE id = 1").unwrap();
2627 match stmt {
2628 Statement::Explain(inner) => {
2629 assert!(matches!(*inner, Statement::Select(_)));
2630 }
2631 _ => panic!("expected Explain"),
2632 }
2633 }
2634
2635 #[test]
2636 fn parse_explain_insert() {
2637 let stmt = parse_sql("EXPLAIN INSERT INTO t (a) VALUES (1)").unwrap();
2638 assert!(matches!(stmt, Statement::Explain(_)));
2639 }
2640
2641 #[test]
2642 fn reject_explain_analyze() {
2643 assert!(parse_sql("EXPLAIN ANALYZE SELECT * FROM t").is_err());
2644 }
2645
2646 #[test]
2647 fn parse_parameter_placeholder() {
2648 let stmt = parse_sql("SELECT * FROM t WHERE id = $1").unwrap();
2649 match stmt {
2650 Statement::Select(sq) => match sq.body {
2651 QueryBody::Select(sel) => match &sel.where_clause {
2652 Some(Expr::BinaryOp { right, .. }) => {
2653 assert!(matches!(right.as_ref(), Expr::Parameter(1)));
2654 }
2655 other => panic!("expected BinaryOp with Parameter, got {other:?}"),
2656 },
2657 _ => panic!("expected QueryBody::Select"),
2658 },
2659 _ => panic!("expected Select"),
2660 }
2661 }
2662
2663 #[test]
2664 fn parse_multiple_parameters() {
2665 let stmt = parse_sql("INSERT INTO t (a, b) VALUES ($1, $2)").unwrap();
2666 match stmt {
2667 Statement::Insert(ins) => {
2668 let values = match &ins.source {
2669 InsertSource::Values(v) => v,
2670 _ => panic!("expected Values"),
2671 };
2672 assert!(matches!(values[0][0], Expr::Parameter(1)));
2673 assert!(matches!(values[0][1], Expr::Parameter(2)));
2674 }
2675 _ => panic!("expected Insert"),
2676 }
2677 }
2678
2679 #[test]
2680 fn parse_insert_select() {
2681 let stmt =
2682 parse_sql("INSERT INTO t2 (id, name) SELECT id, name FROM t1 WHERE id > 5").unwrap();
2683 match stmt {
2684 Statement::Insert(ins) => {
2685 assert_eq!(ins.table, "t2");
2686 assert_eq!(ins.columns, vec!["id", "name"]);
2687 match &ins.source {
2688 InsertSource::Select(sq) => match &sq.body {
2689 QueryBody::Select(sel) => {
2690 assert_eq!(sel.from, "t1");
2691 assert_eq!(sel.columns.len(), 2);
2692 assert!(sel.where_clause.is_some());
2693 }
2694 _ => panic!("expected QueryBody::Select"),
2695 },
2696 _ => panic!("expected InsertSource::Select"),
2697 }
2698 }
2699 _ => panic!("expected Insert"),
2700 }
2701 }
2702
2703 #[test]
2704 fn parse_insert_select_no_columns() {
2705 let stmt = parse_sql("INSERT INTO t2 SELECT * FROM t1").unwrap();
2706 match stmt {
2707 Statement::Insert(ins) => {
2708 assert_eq!(ins.table, "t2");
2709 assert!(ins.columns.is_empty());
2710 assert!(matches!(&ins.source, InsertSource::Select(_)));
2711 }
2712 _ => panic!("expected Insert"),
2713 }
2714 }
2715
2716 #[test]
2717 fn reject_zero_parameter() {
2718 assert!(parse_sql("SELECT $0 FROM t").is_err());
2719 }
2720
2721 #[test]
2722 fn count_params_basic() {
2723 let stmt = parse_sql("SELECT * FROM t WHERE a = $1 AND b = $3").unwrap();
2724 assert_eq!(count_params(&stmt), 3);
2725 }
2726
2727 #[test]
2728 fn count_params_none() {
2729 let stmt = parse_sql("SELECT * FROM t WHERE a = 1").unwrap();
2730 assert_eq!(count_params(&stmt), 0);
2731 }
2732
2733 #[test]
2734 fn bind_params_basic() {
2735 let stmt = parse_sql("SELECT * FROM t WHERE id = $1").unwrap();
2736 let bound = bind_params(&stmt, &[Value::Integer(42)]).unwrap();
2737 match bound {
2738 Statement::Select(sq) => match sq.body {
2739 QueryBody::Select(sel) => match &sel.where_clause {
2740 Some(Expr::BinaryOp { right, .. }) => {
2741 assert!(matches!(right.as_ref(), Expr::Literal(Value::Integer(42))));
2742 }
2743 other => panic!("expected BinaryOp with Literal, got {other:?}"),
2744 },
2745 _ => panic!("expected QueryBody::Select"),
2746 },
2747 _ => panic!("expected Select"),
2748 }
2749 }
2750
2751 #[test]
2752 fn bind_params_out_of_range() {
2753 let stmt = parse_sql("SELECT * FROM t WHERE id = $2").unwrap();
2754 let result = bind_params(&stmt, &[Value::Integer(1)]);
2755 assert!(result.is_err());
2756 }
2757
2758 #[test]
2759 fn parse_table_constraint_pk() {
2760 let stmt = parse_sql("CREATE TABLE t (a INTEGER, b TEXT, PRIMARY KEY (a))").unwrap();
2761 match stmt {
2762 Statement::CreateTable(ct) => {
2763 assert_eq!(ct.primary_key, vec!["a"]);
2764 assert!(ct.columns[0].is_primary_key);
2765 assert!(!ct.columns[0].nullable);
2766 }
2767 _ => panic!("expected CreateTable"),
2768 }
2769 }
2770}