use sqlparser::ast as sp;
use sqlparser::dialect::GenericDialect;
use sqlparser::parser::Parser;
use sqlparser::tokenizer::Token;

use crate::error::{Result, SqlError};
use crate::types::{DataType, Value};

/// A parsed SQL statement in this crate's own AST, converted from the
/// `sqlparser` AST by [`parse_sql`] / [`parse_sql_multi`].
#[derive(Debug, Clone)]
pub enum Statement {
    CreateTable(CreateTableStmt),
    DropTable(DropTableStmt),
    CreateIndex(CreateIndexStmt),
    DropIndex(DropIndexStmt),
    CreateView(CreateViewStmt),
    DropView(DropViewStmt),
    AlterTable(Box<AlterTableStmt>),
    Insert(InsertStmt),
    Select(Box<SelectQuery>),
    Update(UpdateStmt),
    Delete(DeleteStmt),
    /// Produced from `sqlparser`'s `StartTransaction` (BEGIN / START TRANSACTION).
    Begin,
    /// `COMMIT` (`COMMIT AND CHAIN` is rejected during conversion).
    Commit,
    /// `ROLLBACK` without a savepoint target (`AND CHAIN` is rejected).
    Rollback,
    /// `SAVEPOINT name`; the name is ASCII-lowercased during conversion.
    Savepoint(String),
    /// `RELEASE SAVEPOINT name`; the name is ASCII-lowercased during conversion.
    ReleaseSavepoint(String),
    /// `ROLLBACK TO name`; the name is ASCII-lowercased during conversion.
    RollbackTo(String),
    /// `SET TIME ZONE ...`; quoted literals contribute their unquoted text,
    /// identifiers their name.
    SetTimezone(String),
    /// `EXPLAIN <stmt>` (EXPLAIN ANALYZE is rejected during conversion).
    Explain(Box<Statement>),
}
32
/// `ALTER TABLE <table> <op>`; conversion accepts exactly one operation per statement.
#[derive(Debug, Clone)]
pub struct AlterTableStmt {
    pub table: String,
    pub op: AlterTableOp,
}

/// The single operation carried by an [`AlterTableStmt`].
#[derive(Debug, Clone)]
pub enum AlterTableOp {
    /// `ADD [COLUMN] [IF NOT EXISTS] <column definition>`.
    AddColumn {
        column: Box<ColumnSpec>,
        /// Inline `REFERENCES` clause on the new column, if any.
        foreign_key: Option<ForeignKeyDef>,
        if_not_exists: bool,
    },
    /// `DROP [COLUMN] [IF EXISTS] <name>`; exactly one column, name as written.
    DropColumn {
        name: String,
        if_exists: bool,
    },
    /// `RENAME COLUMN <old_name> TO <new_name>`; names as written.
    RenameColumn {
        old_name: String,
        new_name: String,
    },
    /// `RENAME TO <new_name>` / `RENAME AS <new_name>`.
    RenameTable {
        new_name: String,
    },
}
58
/// A `CREATE TABLE` definition.
#[derive(Debug, Clone)]
pub struct CreateTableStmt {
    /// Table name as rendered by `object_name_to_string` (`convert_create_table`
    /// does not case-fold it).
    pub name: String,
    pub columns: Vec<ColumnSpec>,
    /// Primary-key column names: inline `PRIMARY KEY` columns first, then those
    /// added by a table-level `PRIMARY KEY (...)` constraint.
    pub primary_key: Vec<String>,
    pub if_not_exists: bool,
    /// Table-level `CHECK` constraints; column-level checks live on [`ColumnSpec`].
    pub check_constraints: Vec<TableCheckConstraint>,
    /// Inline `REFERENCES` clauses and table-level `FOREIGN KEY` constraints.
    pub foreign_keys: Vec<ForeignKeyDef>,
    /// Inline `UNIQUE` on non-PK columns plus table-level `UNIQUE (...)` constraints.
    pub unique_indices: Vec<UniqueIndexDef>,
}

/// A unique index implied by a `UNIQUE` column option or table constraint.
#[derive(Debug, Clone)]
pub struct UniqueIndexDef {
    /// Constraint name; `None` for inline `UNIQUE` and unnamed constraints.
    pub name: Option<String>,
    /// ASCII-lowercased column names, in the order written.
    pub columns: Vec<String>,
}

/// A table-level `CHECK (...)` constraint.
#[derive(Debug, Clone)]
pub struct TableCheckConstraint {
    pub name: Option<String>,
    /// Converted predicate; predicates containing a subquery are rejected.
    pub expr: Expr,
    /// The predicate's SQL text as rendered by `sqlparser`.
    pub sql: String,
}

/// A foreign-key constraint. The converters ASCII-lowercase `columns`,
/// `foreign_table` and `referred_columns`; `name` is kept as written.
#[derive(Debug, Clone)]
pub struct ForeignKeyDef {
    pub name: Option<String>,
    /// Referencing columns in this table.
    pub columns: Vec<String>,
    /// Referenced (parent) table.
    pub foreign_table: String,
    /// Referenced columns; presumably empty when only the table is named
    /// (`REFERENCES t`) — NOTE(review): confirm how that is resolved downstream.
    pub referred_columns: Vec<String>,
}

/// One column of a `CREATE TABLE` or `ALTER TABLE ... ADD COLUMN`.
#[derive(Debug, Clone)]
pub struct ColumnSpec {
    /// Column name as written (`convert_column_def` does not case-fold it).
    pub name: String,
    pub data_type: DataType,
    /// Starts `true`; cleared by `NOT NULL` and `PRIMARY KEY` options.
    pub nullable: bool,
    pub is_primary_key: bool,
    /// Converted `DEFAULT` expression; its SQL text is in `default_sql`.
    pub default_expr: Option<Expr>,
    pub default_sql: Option<String>,
    /// Converted column-level `CHECK` predicate, with its SQL text and optional
    /// constraint name alongside.
    pub check_expr: Option<Expr>,
    pub check_sql: Option<String>,
    pub check_name: Option<String>,
}
103
/// `DROP TABLE [IF EXISTS] <name>` (single table only).
#[derive(Debug, Clone)]
pub struct DropTableStmt {
    pub name: String,
    pub if_exists: bool,
}

/// `CREATE [UNIQUE] INDEX [IF NOT EXISTS] <index_name> ON <table_name> (...)`.
#[derive(Debug, Clone)]
pub struct CreateIndexStmt {
    pub index_name: String,
    pub table_name: String,
    /// Plain column identifiers as written; expression indexes are rejected and
    /// the list is never empty.
    pub columns: Vec<String>,
    pub unique: bool,
    pub if_not_exists: bool,
}

/// `DROP INDEX [IF EXISTS] <index_name>` (single index only).
#[derive(Debug, Clone)]
pub struct DropIndexStmt {
    pub index_name: String,
    pub if_exists: bool,
}

/// `CREATE [OR REPLACE] VIEW [IF NOT EXISTS] <name> [(aliases)] AS <query>`.
#[derive(Debug, Clone)]
pub struct CreateViewStmt {
    pub name: String,
    /// The view's query rendered back to SQL text by `sqlparser`; conversion
    /// checks that this text re-parses as a query.
    pub sql: String,
    /// ASCII-lowercased explicit column aliases (empty when none were given).
    pub column_aliases: Vec<String>,
    pub or_replace: bool,
    pub if_not_exists: bool,
}

/// `DROP VIEW [IF EXISTS] <name>` (single view only).
#[derive(Debug, Clone)]
pub struct DropViewStmt {
    pub name: String,
    pub if_exists: bool,
}
139
/// Where an INSERT's rows come from.
#[derive(Debug, Clone)]
pub enum InsertSource {
    /// `VALUES (...), (...)`: one inner `Vec` per row.
    Values(Vec<Vec<Expr>>),
    /// `INSERT ... SELECT` (optionally with CTEs).
    Select(Box<SelectQuery>),
}

/// An `INSERT` statement.
#[derive(Debug, Clone)]
pub struct InsertStmt {
    /// Target table, ASCII-lowercased.
    pub table: String,
    /// Explicit target columns, ASCII-lowercased; empty when none were listed.
    pub columns: Vec<String>,
    pub source: InsertSource,
    pub on_conflict: Option<OnConflictClause>,
}

/// `ON CONFLICT [target] DO ...`.
#[derive(Debug, Clone)]
pub struct OnConflictClause {
    /// `None` when no conflict target was written.
    pub target: Option<ConflictTarget>,
    pub action: OnConflictAction,
}

/// The conflict target of an `ON CONFLICT` clause (names ASCII-lowercased).
#[derive(Debug, Clone)]
pub enum ConflictTarget {
    /// `ON CONFLICT (col, ...)`.
    Columns(Vec<String>),
    /// `ON CONFLICT ON CONSTRAINT name` (unqualified names only).
    Constraint(String),
}

/// What to do when an `ON CONFLICT` target matches.
#[derive(Debug, Clone)]
pub enum OnConflictAction {
    DoNothing,
    /// `DO UPDATE SET col = expr, ... [WHERE ...]`; column names ASCII-lowercased.
    DoUpdate {
        assignments: Vec<(String, Expr)>,
        where_clause: Option<Expr>,
    },
}
174
/// A table reference with an optional alias (used as a JOIN target).
#[derive(Debug, Clone)]
pub struct TableRef {
    pub name: String,
    pub alias: Option<String>,
}

/// Supported join kinds.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum JoinType {
    Inner,
    Cross,
    Left,
    Right,
}

/// One `JOIN <table> [ON <expr>]` attached to the FROM table.
#[derive(Debug, Clone)]
pub struct JoinClause {
    pub join_type: JoinType,
    pub table: TableRef,
    /// `None` when the join has no `ON` condition (e.g. a cross join).
    pub on_clause: Option<Expr>,
}

/// A single (non-compound) SELECT.
#[derive(Debug, Clone)]
pub struct SelectStmt {
    pub columns: Vec<SelectColumn>,
    /// FROM table name; the empty string when the query has no FROM clause.
    pub from: String,
    pub from_alias: Option<String>,
    pub joins: Vec<JoinClause>,
    pub distinct: bool,
    pub where_clause: Option<Expr>,
    /// `convert_select_body` leaves `order_by`/`limit`/`offset` empty; presumably
    /// the enclosing query converter fills them in — NOTE(review): confirm.
    pub order_by: Vec<OrderByItem>,
    pub limit: Option<Expr>,
    pub offset: Option<Expr>,
    pub group_by: Vec<Expr>,
    pub having: Option<Expr>,
}

/// Set operator of a compound query (`MINUS` is converted to `Except`).
#[derive(Debug, Clone)]
pub enum SetOp {
    Union,
    Intersect,
    Except,
}

/// `<left> UNION|INTERSECT|EXCEPT [ALL] <right>` plus trailing ORDER BY/LIMIT/OFFSET.
#[derive(Debug, Clone)]
pub struct CompoundSelect {
    pub op: SetOp,
    /// `true` for `ALL`; `false` for the default / `DISTINCT`.
    pub all: bool,
    pub left: Box<QueryBody>,
    pub right: Box<QueryBody>,
    pub order_by: Vec<OrderByItem>,
    pub limit: Option<Expr>,
    pub offset: Option<Expr>,
}

/// Either a single SELECT or a compound of two query bodies.
#[derive(Debug, Clone)]
pub enum QueryBody {
    Select(Box<SelectStmt>),
    Compound(Box<CompoundSelect>),
}

/// One `name [(aliases)] AS (<query>)` entry of a WITH clause.
#[derive(Debug, Clone)]
pub struct CteDefinition {
    pub name: String,
    pub column_aliases: Vec<String>,
    pub body: QueryBody,
}

/// A full query: optional CTEs followed by the main body.
#[derive(Debug, Clone)]
pub struct SelectQuery {
    pub ctes: Vec<CteDefinition>,
    /// `WITH RECURSIVE`.
    pub recursive: bool,
    pub body: QueryBody,
}

/// An `UPDATE <table> SET col = expr, ... [WHERE ...]`.
#[derive(Debug, Clone)]
pub struct UpdateStmt {
    pub table: String,
    pub assignments: Vec<(String, Expr)>,
    pub where_clause: Option<Expr>,
}

/// A `DELETE FROM <table> [WHERE ...]`.
#[derive(Debug, Clone)]
pub struct DeleteStmt {
    pub table: String,
    pub where_clause: Option<Expr>,
}

/// One item of a SELECT projection list.
#[derive(Debug, Clone)]
pub enum SelectColumn {
    /// `*`.
    AllColumns,
    /// `<expr> [AS alias]`.
    Expr { expr: Expr, alias: Option<String> },
}

/// One ORDER BY key.
#[derive(Debug, Clone)]
pub struct OrderByItem {
    pub expr: Expr,
    pub descending: bool,
    /// Presumably `None` when neither NULLS FIRST nor NULLS LAST was written — confirm.
    pub nulls_first: Option<bool>,
}
274
/// A scalar SQL expression in this crate's AST.
#[derive(Debug, Clone)]
pub enum Expr {
    /// A constant value.
    Literal(Value),
    /// An unqualified column reference.
    Column(String),
    /// A `table.column` reference.
    QualifiedColumn {
        table: String,
        column: String,
    },
    BinaryOp {
        left: Box<Expr>,
        op: BinOp,
        right: Box<Expr>,
    },
    UnaryOp {
        op: UnaryOp,
        expr: Box<Expr>,
    },
    /// `expr IS NULL`.
    IsNull(Box<Expr>),
    /// `expr IS NOT NULL`.
    IsNotNull(Box<Expr>),
    /// A function call; `distinct` presumably marks `f(DISTINCT ...)` — confirm.
    Function {
        name: String,
        args: Vec<Expr>,
        distinct: bool,
    },
    /// `COUNT(*)`.
    CountStar,
    /// `expr [NOT] IN (subquery)`; the subquery is a single SELECT.
    InSubquery {
        expr: Box<Expr>,
        subquery: Box<SelectStmt>,
        negated: bool,
    },
    /// `expr [NOT] IN (e1, e2, ...)`.
    InList {
        expr: Box<Expr>,
        list: Vec<Expr>,
        negated: bool,
    },
    /// `[NOT] EXISTS (subquery)`.
    Exists {
        subquery: Box<SelectStmt>,
        negated: bool,
    },
    /// A subquery used as a scalar value.
    ScalarSubquery(Box<SelectStmt>),
    /// Membership test against a pre-built value set. Not produced by the code
    /// visible here; presumably built from an all-literal `InList`, with
    /// `has_null` recording whether the list contained NULL — NOTE(review): confirm.
    InSet {
        expr: Box<Expr>,
        values: std::collections::HashSet<Value>,
        has_null: bool,
        negated: bool,
    },
    /// `expr [NOT] BETWEEN low AND high`.
    Between {
        expr: Box<Expr>,
        low: Box<Expr>,
        high: Box<Expr>,
        negated: bool,
    },
    /// `expr [NOT] LIKE pattern [ESCAPE escape]`.
    Like {
        expr: Box<Expr>,
        pattern: Box<Expr>,
        escape: Option<Box<Expr>>,
        negated: bool,
    },
    /// `CASE [operand] WHEN cond THEN result ... [ELSE else_result] END`;
    /// `operand` is set for the "simple" CASE form.
    Case {
        operand: Option<Box<Expr>>,
        conditions: Vec<(Expr, Expr)>,
        else_result: Option<Box<Expr>>,
    },
    /// `COALESCE(args...)`.
    Coalesce(Vec<Expr>),
    /// `CAST(expr AS data_type)`.
    Cast {
        expr: Box<Expr>,
        data_type: DataType,
    },
    /// A positional bind parameter. `count_params` treats the largest index as the
    /// parameter count, which implies 1-based numbering — NOTE(review): confirm.
    Parameter(usize),
    /// `name(args) OVER (spec)`.
    WindowFunction {
        name: String,
        args: Vec<Expr>,
        spec: WindowSpec,
    },
}
351
/// The `OVER (...)` clause of a window function.
#[derive(Debug, Clone)]
pub struct WindowSpec {
    pub partition_by: Vec<Expr>,
    pub order_by: Vec<OrderByItem>,
    /// Explicit frame, if one was written.
    pub frame: Option<WindowFrame>,
}

/// `ROWS|RANGE|GROUPS BETWEEN <start> AND <end>`.
#[derive(Debug, Clone)]
pub struct WindowFrame {
    pub units: WindowFrameUnits,
    pub start: WindowFrameBound,
    pub end: WindowFrameBound,
}

/// The unit a window frame is measured in.
#[derive(Debug, Clone, Copy)]
pub enum WindowFrameUnits {
    Rows,
    Range,
    Groups,
}

/// One end of a window frame; `Preceding`/`Following` carry an offset expression.
#[derive(Debug, Clone)]
pub enum WindowFrameBound {
    UnboundedPreceding,
    Preceding(Box<Expr>),
    CurrentRow,
    Following(Box<Expr>),
    UnboundedFollowing,
}

/// Binary operators.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BinOp {
    Add,
    Sub,
    Mul,
    Div,
    Mod,
    Eq,
    NotEq,
    Lt,
    Gt,
    LtEq,
    GtEq,
    And,
    Or,
    /// String concatenation (presumably SQL `||`).
    Concat,
}

/// Unary operators: arithmetic negation and logical NOT.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum UnaryOp {
    Neg,
    Not,
}
405
406pub fn has_subquery(expr: &Expr) -> bool {
407 match expr {
408 Expr::InSubquery { .. } | Expr::Exists { .. } | Expr::ScalarSubquery(_) => true,
409 Expr::BinaryOp { left, right, .. } => has_subquery(left) || has_subquery(right),
410 Expr::UnaryOp { expr, .. } => has_subquery(expr),
411 Expr::IsNull(e) | Expr::IsNotNull(e) => has_subquery(e),
412 Expr::InList { expr, list, .. } => has_subquery(expr) || list.iter().any(has_subquery),
413 Expr::InSet { expr, .. } => has_subquery(expr),
414 Expr::Between {
415 expr, low, high, ..
416 } => has_subquery(expr) || has_subquery(low) || has_subquery(high),
417 Expr::Like {
418 expr,
419 pattern,
420 escape,
421 ..
422 } => {
423 has_subquery(expr)
424 || has_subquery(pattern)
425 || escape.as_ref().is_some_and(|e| has_subquery(e))
426 }
427 Expr::Case {
428 operand,
429 conditions,
430 else_result,
431 } => {
432 operand.as_ref().is_some_and(|e| has_subquery(e))
433 || conditions
434 .iter()
435 .any(|(c, r)| has_subquery(c) || has_subquery(r))
436 || else_result.as_ref().is_some_and(|e| has_subquery(e))
437 }
438 Expr::Coalesce(args) | Expr::Function { args, .. } => args.iter().any(has_subquery),
439 Expr::Cast { expr, .. } => has_subquery(expr),
440 _ => false,
441 }
442}
443
444pub fn parse_sql_expr(sql: &str) -> Result<Expr> {
447 let dialect = GenericDialect {};
448 let mut parser = Parser::new(&dialect)
449 .try_with_sql(sql)
450 .map_err(|e| SqlError::Parse(e.to_string()))?;
451 let sp_expr = parser
452 .parse_expr()
453 .map_err(|e| SqlError::Parse(e.to_string()))?;
454 convert_expr(&sp_expr)
455}
456
457pub fn parse_sql(sql: &str) -> Result<Statement> {
458 let dialect = GenericDialect {};
459 let stmts = Parser::parse_sql(&dialect, sql).map_err(|e| SqlError::Parse(e.to_string()))?;
460
461 if stmts.is_empty() {
462 return Err(SqlError::Parse("empty SQL".into()));
463 }
464 if stmts.len() > 1 {
465 return Err(SqlError::Unsupported("multiple statements".into()));
466 }
467
468 convert_statement(stmts.into_iter().next().unwrap())
469}
470
471pub fn parse_sql_multi(sql: &str) -> Result<Vec<Statement>> {
473 let dialect = GenericDialect {};
474 let stmts = Parser::parse_sql(&dialect, sql).map_err(|e| SqlError::Parse(e.to_string()))?;
475
476 if stmts.is_empty() {
477 return Err(SqlError::Parse("empty SQL".into()));
478 }
479
480 stmts.into_iter().map(convert_statement).collect()
481}
482
483pub fn count_params(stmt: &Statement) -> usize {
485 let mut max_idx = 0usize;
486 visit_exprs_stmt(stmt, &mut |e| {
487 if let Expr::Parameter(n) = e {
488 max_idx = max_idx.max(*n);
489 }
490 });
491 max_idx
492}
493
494fn visit_exprs_stmt(stmt: &Statement, visitor: &mut impl FnMut(&Expr)) {
495 match stmt {
496 Statement::Select(sq) => {
497 for cte in &sq.ctes {
498 visit_exprs_query_body(&cte.body, visitor);
499 }
500 visit_exprs_query_body(&sq.body, visitor);
501 }
502 Statement::Insert(ins) => match &ins.source {
503 InsertSource::Values(rows) => {
504 for row in rows {
505 for e in row {
506 visit_expr(e, visitor);
507 }
508 }
509 }
510 InsertSource::Select(sq) => {
511 for cte in &sq.ctes {
512 visit_exprs_query_body(&cte.body, visitor);
513 }
514 visit_exprs_query_body(&sq.body, visitor);
515 }
516 },
517 Statement::Update(upd) => {
518 for (_, e) in &upd.assignments {
519 visit_expr(e, visitor);
520 }
521 if let Some(w) = &upd.where_clause {
522 visit_expr(w, visitor);
523 }
524 }
525 Statement::Delete(del) => {
526 if let Some(w) = &del.where_clause {
527 visit_expr(w, visitor);
528 }
529 }
530 Statement::Explain(inner) => visit_exprs_stmt(inner, visitor),
531 _ => {}
532 }
533}
534
535fn visit_exprs_query_body(body: &QueryBody, visitor: &mut impl FnMut(&Expr)) {
536 match body {
537 QueryBody::Select(sel) => visit_exprs_select(sel, visitor),
538 QueryBody::Compound(comp) => {
539 visit_exprs_query_body(&comp.left, visitor);
540 visit_exprs_query_body(&comp.right, visitor);
541 for o in &comp.order_by {
542 visit_expr(&o.expr, visitor);
543 }
544 if let Some(l) = &comp.limit {
545 visit_expr(l, visitor);
546 }
547 if let Some(o) = &comp.offset {
548 visit_expr(o, visitor);
549 }
550 }
551 }
552}
553
554fn visit_exprs_select(sel: &SelectStmt, visitor: &mut impl FnMut(&Expr)) {
555 for col in &sel.columns {
556 if let SelectColumn::Expr { expr, .. } = col {
557 visit_expr(expr, visitor);
558 }
559 }
560 for j in &sel.joins {
561 if let Some(on) = &j.on_clause {
562 visit_expr(on, visitor);
563 }
564 }
565 if let Some(w) = &sel.where_clause {
566 visit_expr(w, visitor);
567 }
568 for o in &sel.order_by {
569 visit_expr(&o.expr, visitor);
570 }
571 if let Some(l) = &sel.limit {
572 visit_expr(l, visitor);
573 }
574 if let Some(o) = &sel.offset {
575 visit_expr(o, visitor);
576 }
577 for g in &sel.group_by {
578 visit_expr(g, visitor);
579 }
580 if let Some(h) = &sel.having {
581 visit_expr(h, visitor);
582 }
583}
584
/// Calls `visitor` on `expr` and then, recursively, on each sub-expression.
/// The walk is pre-order: parent before children, children in field order.
///
/// Subqueries (`IN (SELECT ..)`, `EXISTS`, scalar subqueries) are descended into
/// via `visit_exprs_select`. `InSet` visits only its probe expression, because
/// its value set holds `Value`s, not expressions. Window functions visit their
/// arguments, PARTITION BY keys, ORDER BY keys and any `n PRECEDING` /
/// `n FOLLOWING` frame offsets.
fn visit_expr(expr: &Expr, visitor: &mut impl FnMut(&Expr)) {
    visitor(expr);
    match expr {
        Expr::BinaryOp { left, right, .. } => {
            visit_expr(left, visitor);
            visit_expr(right, visitor);
        }
        Expr::UnaryOp { expr: e, .. } | Expr::IsNull(e) | Expr::IsNotNull(e) => {
            visit_expr(e, visitor);
        }
        Expr::Function { args, .. } | Expr::Coalesce(args) => {
            for a in args {
                visit_expr(a, visitor);
            }
        }
        Expr::InSubquery {
            expr: e, subquery, ..
        } => {
            visit_expr(e, visitor);
            visit_exprs_select(subquery, visitor);
        }
        Expr::InList { expr: e, list, .. } => {
            visit_expr(e, visitor);
            for l in list {
                visit_expr(l, visitor);
            }
        }
        Expr::Exists { subquery, .. } => visit_exprs_select(subquery, visitor),
        Expr::ScalarSubquery(sq) => visit_exprs_select(sq, visitor),
        Expr::InSet { expr: e, .. } => visit_expr(e, visitor),
        Expr::Between {
            expr: e, low, high, ..
        } => {
            visit_expr(e, visitor);
            visit_expr(low, visitor);
            visit_expr(high, visitor);
        }
        Expr::Like {
            expr: e,
            pattern,
            escape,
            ..
        } => {
            visit_expr(e, visitor);
            visit_expr(pattern, visitor);
            if let Some(esc) = escape {
                visit_expr(esc, visitor);
            }
        }
        Expr::Case {
            operand,
            conditions,
            else_result,
        } => {
            if let Some(op) = operand {
                visit_expr(op, visitor);
            }
            for (cond, then) in conditions {
                visit_expr(cond, visitor);
                visit_expr(then, visitor);
            }
            if let Some(el) = else_result {
                visit_expr(el, visitor);
            }
        }
        Expr::Cast { expr: e, .. } => visit_expr(e, visitor),
        Expr::WindowFunction { args, spec, .. } => {
            for a in args {
                visit_expr(a, visitor);
            }
            for p in &spec.partition_by {
                visit_expr(p, visitor);
            }
            for o in &spec.order_by {
                visit_expr(&o.expr, visitor);
            }
            // Only offset bounds carry an expression; UNBOUNDED / CURRENT ROW do not.
            if let Some(ref frame) = spec.frame {
                if let WindowFrameBound::Preceding(e) | WindowFrameBound::Following(e) =
                    &frame.start
                {
                    visit_expr(e, visitor);
                }
                if let WindowFrameBound::Preceding(e) | WindowFrameBound::Following(e) = &frame.end
                {
                    visit_expr(e, visitor);
                }
            }
        }
        // Leaves: no child expressions.
        Expr::Literal(_)
        | Expr::Column(_)
        | Expr::QualifiedColumn { .. }
        | Expr::CountStar
        | Expr::Parameter(_) => {}
    }
}
680
/// Converts one `sqlparser` statement into this crate's [`Statement`].
///
/// Arm order matters:
/// - The `chain: true` arms must precede the generic `Commit { .. }` /
///   `Rollback { .. }` arms.
/// - `Rollback { savepoint: Some(..) }` must precede plain `Rollback { .. }`.
///
/// # Errors
/// - [`SqlError::Unsupported`] for:
///   - multi-object DROP
///   - COMMIT/ROLLBACK AND CHAIN
///   - EXPLAIN ANALYZE
///   - any statement kind not matched here
/// - [`SqlError::Parse`] for a SET TIME ZONE value that is neither a literal nor
///   an identifier.
/// - Errors from the per-statement converters are propagated.
fn convert_statement(stmt: sp::Statement) -> Result<Statement> {
    match stmt {
        sp::Statement::CreateTable(ct) => convert_create_table(ct),
        sp::Statement::CreateIndex(ci) => convert_create_index(ci),
        // NOTE(review): the `..` in the three DROP arms ignores every other field of
        // `sp::Statement::Drop` (e.g. any CASCADE/RESTRICT flag) — confirm intended.
        sp::Statement::Drop {
            object_type: sp::ObjectType::Table,
            if_exists,
            names,
            ..
        } => {
            if names.len() != 1 {
                return Err(SqlError::Unsupported("multi-table DROP".into()));
            }
            Ok(Statement::DropTable(DropTableStmt {
                name: object_name_to_string(&names[0]),
                if_exists,
            }))
        }
        sp::Statement::Drop {
            object_type: sp::ObjectType::Index,
            if_exists,
            names,
            ..
        } => {
            if names.len() != 1 {
                return Err(SqlError::Unsupported("multi-index DROP".into()));
            }
            Ok(Statement::DropIndex(DropIndexStmt {
                index_name: object_name_to_string(&names[0]),
                if_exists,
            }))
        }
        sp::Statement::CreateView(cv) => convert_create_view(cv),
        sp::Statement::Drop {
            object_type: sp::ObjectType::View,
            if_exists,
            names,
            ..
        } => {
            if names.len() != 1 {
                return Err(SqlError::Unsupported("multi-view DROP".into()));
            }
            Ok(Statement::DropView(DropViewStmt {
                name: object_name_to_string(&names[0]),
                if_exists,
            }))
        }
        sp::Statement::AlterTable(at) => convert_alter_table(at),
        sp::Statement::Insert(insert) => convert_insert(insert),
        sp::Statement::Query(query) => convert_query(*query),
        sp::Statement::Update(update) => convert_update(update),
        sp::Statement::Delete(delete) => convert_delete(delete),
        // All fields of StartTransaction (e.g. transaction modes) are ignored by `..`.
        sp::Statement::StartTransaction { .. } => Ok(Statement::Begin),
        sp::Statement::Commit { chain: true, .. } => {
            Err(SqlError::Unsupported("COMMIT AND CHAIN".into()))
        }
        sp::Statement::Commit { .. } => Ok(Statement::Commit),
        sp::Statement::Rollback { chain: true, .. } => {
            Err(SqlError::Unsupported("ROLLBACK AND CHAIN".into()))
        }
        // Savepoint names are ASCII-lowercased in all three savepoint forms,
        // presumably so they match case-insensitively.
        sp::Statement::Rollback {
            savepoint: Some(name),
            ..
        } => Ok(Statement::RollbackTo(name.value.to_ascii_lowercase())),
        sp::Statement::Rollback { .. } => Ok(Statement::Rollback),
        sp::Statement::Savepoint { name } => {
            Ok(Statement::Savepoint(name.value.to_ascii_lowercase()))
        }
        sp::Statement::ReleaseSavepoint { name } => {
            Ok(Statement::ReleaseSavepoint(name.value.to_ascii_lowercase()))
        }
        sp::Statement::Set(sp::Set::SetTimeZone { value, .. }) => {
            let zone = match value {
                // Quoted literals contribute their unquoted text; any other literal
                // (e.g. a number) is kept as its SQL rendering.
                sp::Expr::Value(v) => match &v.value {
                    sp::Value::SingleQuotedString(s) => s.clone(),
                    sp::Value::DoubleQuotedString(s) => s.clone(),
                    other => other.to_string(),
                },
                sp::Expr::Identifier(ident) => ident.value.clone(),
                other => {
                    return Err(SqlError::Parse(format!(
                        "SET TIME ZONE expects a string literal or identifier, got: {other}"
                    )))
                }
            };
            Ok(Statement::SetTimezone(zone))
        }
        sp::Statement::Explain {
            statement, analyze, ..
        } => {
            if analyze {
                return Err(SqlError::Unsupported("EXPLAIN ANALYZE".into()));
            }
            let inner = convert_statement(*statement)?;
            Ok(Statement::Explain(Box::new(inner)))
        }
        _ => Err(SqlError::Unsupported(format!("statement type: {}", stmt))),
    }
}
781
782fn convert_column_def(
785 col_def: &sp::ColumnDef,
786) -> Result<(ColumnSpec, Option<ForeignKeyDef>, bool, bool)> {
787 let col_name = col_def.name.value.clone();
788 let data_type = convert_data_type(&col_def.data_type)?;
789 let mut nullable = true;
790 let mut is_primary_key = false;
791 let mut is_unique = false;
792 let mut default_expr = None;
793 let mut default_sql = None;
794 let mut check_expr = None;
795 let mut check_sql = None;
796 let mut check_name = None;
797 let mut fk_def = None;
798
799 for opt in &col_def.options {
800 match &opt.option {
801 sp::ColumnOption::NotNull => nullable = false,
802 sp::ColumnOption::Null => nullable = true,
803 sp::ColumnOption::PrimaryKey(_) => {
804 is_primary_key = true;
805 nullable = false;
806 }
807 sp::ColumnOption::Unique(_) => is_unique = true,
808 sp::ColumnOption::Default(expr) => {
809 default_sql = Some(expr.to_string());
810 default_expr = Some(convert_expr(expr)?);
811 }
812 sp::ColumnOption::Check(check) => {
813 check_sql = Some(check.expr.to_string());
814 let converted = convert_expr(&check.expr)?;
815 if has_subquery(&converted) {
816 return Err(SqlError::Unsupported("subquery in CHECK constraint".into()));
817 }
818 check_expr = Some(converted);
819 check_name = check.name.as_ref().map(|n| n.value.clone());
820 }
821 sp::ColumnOption::ForeignKey(fk) => {
822 convert_fk_actions(&fk.on_delete, &fk.on_update)?;
823 let ftable = object_name_to_string(&fk.foreign_table).to_ascii_lowercase();
824 let referred: Vec<String> = fk
825 .referred_columns
826 .iter()
827 .map(|i| i.value.to_ascii_lowercase())
828 .collect();
829 fk_def = Some(ForeignKeyDef {
830 name: fk.name.as_ref().map(|n| n.value.clone()),
831 columns: vec![col_name.to_ascii_lowercase()],
832 foreign_table: ftable,
833 referred_columns: referred,
834 });
835 }
836 _ => {}
837 }
838 }
839
840 let spec = ColumnSpec {
841 name: col_name,
842 data_type,
843 nullable,
844 is_primary_key,
845 default_expr,
846 default_sql,
847 check_expr,
848 check_sql,
849 check_name,
850 };
851 Ok((spec, fk_def, is_primary_key, is_unique))
852}
853
854fn convert_create_table(ct: sp::CreateTable) -> Result<Statement> {
855 let name = object_name_to_string(&ct.name);
856 let if_not_exists = ct.if_not_exists;
857
858 let mut columns = Vec::new();
859 let mut inline_pk: Vec<String> = Vec::new();
860 let mut foreign_keys: Vec<ForeignKeyDef> = Vec::new();
861 let mut unique_indices: Vec<UniqueIndexDef> = Vec::new();
862
863 for col_def in &ct.columns {
864 let (spec, fk_def, was_pk, was_unique) = convert_column_def(col_def)?;
865 if was_pk {
866 inline_pk.push(spec.name.clone());
867 }
868 if let Some(fk) = fk_def {
869 foreign_keys.push(fk);
870 }
871 if was_unique && !was_pk {
872 unique_indices.push(UniqueIndexDef {
873 name: None,
874 columns: vec![spec.name.to_ascii_lowercase()],
875 });
876 }
877 columns.push(spec);
878 }
879
880 let mut check_constraints: Vec<TableCheckConstraint> = Vec::new();
881
882 for constraint in &ct.constraints {
883 match constraint {
884 sp::TableConstraint::PrimaryKey(pk_constraint) => {
885 for idx_col in &pk_constraint.columns {
886 let col_name = match &idx_col.column.expr {
887 sp::Expr::Identifier(ident) => ident.value.clone(),
888 _ => continue,
889 };
890 if !inline_pk.contains(&col_name) {
891 inline_pk.push(col_name.clone());
892 }
893 if let Some(col) = columns.iter_mut().find(|c| c.name == col_name) {
894 col.nullable = false;
895 col.is_primary_key = true;
896 }
897 }
898 }
899 sp::TableConstraint::Check(check) => {
900 let sql = check.expr.to_string();
901 let converted = convert_expr(&check.expr)?;
902 if has_subquery(&converted) {
903 return Err(SqlError::Unsupported("subquery in CHECK constraint".into()));
904 }
905 check_constraints.push(TableCheckConstraint {
906 name: check.name.as_ref().map(|n| n.value.clone()),
907 expr: converted,
908 sql,
909 });
910 }
911 sp::TableConstraint::ForeignKey(fk) => {
912 convert_fk_actions(&fk.on_delete, &fk.on_update)?;
913 let cols: Vec<String> = fk
914 .columns
915 .iter()
916 .map(|i| i.value.to_ascii_lowercase())
917 .collect();
918 let ftable = object_name_to_string(&fk.foreign_table).to_ascii_lowercase();
919 let referred: Vec<String> = fk
920 .referred_columns
921 .iter()
922 .map(|i| i.value.to_ascii_lowercase())
923 .collect();
924 foreign_keys.push(ForeignKeyDef {
925 name: fk.name.as_ref().map(|n| n.value.clone()),
926 columns: cols,
927 foreign_table: ftable,
928 referred_columns: referred,
929 });
930 }
931 sp::TableConstraint::Unique(u) => {
932 let cols: Vec<String> = u
933 .columns
934 .iter()
935 .filter_map(|idx_col| match &idx_col.column.expr {
936 sp::Expr::Identifier(ident) => Some(ident.value.to_ascii_lowercase()),
937 _ => None,
938 })
939 .collect();
940 if !cols.is_empty() {
941 unique_indices.push(UniqueIndexDef {
942 name: u.name.as_ref().map(|n| n.value.clone()),
943 columns: cols,
944 });
945 }
946 }
947 _ => {}
948 }
949 }
950
951 Ok(Statement::CreateTable(CreateTableStmt {
952 name,
953 columns,
954 primary_key: inline_pk,
955 if_not_exists,
956 check_constraints,
957 foreign_keys,
958 unique_indices,
959 }))
960}
961
962fn convert_alter_table(at: sp::AlterTable) -> Result<Statement> {
963 let table = object_name_to_string(&at.name);
964 if at.operations.len() != 1 {
965 return Err(SqlError::Unsupported(
966 "ALTER TABLE with multiple operations".into(),
967 ));
968 }
969 let op = match at.operations.into_iter().next().unwrap() {
970 sp::AlterTableOperation::AddColumn {
971 column_def,
972 if_not_exists,
973 ..
974 } => {
975 let (spec, fk, _was_pk, _was_unique) = convert_column_def(&column_def)?;
976 AlterTableOp::AddColumn {
977 column: Box::new(spec),
978 foreign_key: fk,
979 if_not_exists,
980 }
981 }
982 sp::AlterTableOperation::DropColumn {
983 column_names,
984 if_exists,
985 ..
986 } => {
987 if column_names.len() != 1 {
988 return Err(SqlError::Unsupported(
989 "DROP COLUMN with multiple columns".into(),
990 ));
991 }
992 AlterTableOp::DropColumn {
993 name: column_names.into_iter().next().unwrap().value,
994 if_exists,
995 }
996 }
997 sp::AlterTableOperation::RenameColumn {
998 old_column_name,
999 new_column_name,
1000 } => AlterTableOp::RenameColumn {
1001 old_name: old_column_name.value,
1002 new_name: new_column_name.value,
1003 },
1004 sp::AlterTableOperation::RenameTable { table_name } => {
1005 let new_name = match table_name {
1006 sp::RenameTableNameKind::To(name) | sp::RenameTableNameKind::As(name) => {
1007 object_name_to_string(&name)
1008 }
1009 };
1010 AlterTableOp::RenameTable { new_name }
1011 }
1012 other => {
1013 return Err(SqlError::Unsupported(format!(
1014 "ALTER TABLE operation: {other}"
1015 )));
1016 }
1017 };
1018 Ok(Statement::AlterTable(Box::new(AlterTableStmt {
1019 table,
1020 op,
1021 })))
1022}
1023
1024fn convert_fk_actions(
1025 on_delete: &Option<sp::ReferentialAction>,
1026 on_update: &Option<sp::ReferentialAction>,
1027) -> Result<()> {
1028 for action in [on_delete, on_update] {
1029 match action {
1030 None
1031 | Some(sp::ReferentialAction::Restrict)
1032 | Some(sp::ReferentialAction::NoAction) => {}
1033 Some(other) => {
1034 return Err(SqlError::Unsupported(format!(
1035 "FOREIGN KEY action: {other}"
1036 )));
1037 }
1038 }
1039 }
1040 Ok(())
1041}
1042
1043fn convert_create_index(ci: sp::CreateIndex) -> Result<Statement> {
1044 let index_name = ci
1045 .name
1046 .as_ref()
1047 .map(object_name_to_string)
1048 .ok_or_else(|| SqlError::Parse("index name required".into()))?;
1049
1050 let table_name = object_name_to_string(&ci.table_name);
1051
1052 let columns: Vec<String> = ci
1053 .columns
1054 .iter()
1055 .map(|idx_col| match &idx_col.column.expr {
1056 sp::Expr::Identifier(ident) => Ok(ident.value.clone()),
1057 other => Err(SqlError::Unsupported(format!("expression index: {other}"))),
1058 })
1059 .collect::<Result<_>>()?;
1060
1061 if columns.is_empty() {
1062 return Err(SqlError::Parse(
1063 "index must have at least one column".into(),
1064 ));
1065 }
1066
1067 Ok(Statement::CreateIndex(CreateIndexStmt {
1068 index_name,
1069 table_name,
1070 columns,
1071 unique: ci.unique,
1072 if_not_exists: ci.if_not_exists,
1073 }))
1074}
1075
1076fn convert_create_view(cv: sp::CreateView) -> Result<Statement> {
1077 let name = object_name_to_string(&cv.name);
1078
1079 if cv.materialized {
1080 return Err(SqlError::Unsupported("MATERIALIZED VIEW".into()));
1081 }
1082
1083 let sql = cv.query.to_string();
1084
1085 let dialect = GenericDialect {};
1086 let test = Parser::parse_sql(&dialect, &sql).map_err(|e| SqlError::Parse(e.to_string()))?;
1087 if test.is_empty() {
1088 return Err(SqlError::Parse("empty view definition".into()));
1089 }
1090 match &test[0] {
1091 sp::Statement::Query(_) => {}
1092 _ => {
1093 return Err(SqlError::Parse(
1094 "view body must be a SELECT statement".into(),
1095 ))
1096 }
1097 }
1098
1099 let column_aliases: Vec<String> = cv
1100 .columns
1101 .iter()
1102 .map(|c| c.name.value.to_ascii_lowercase())
1103 .collect();
1104
1105 Ok(Statement::CreateView(CreateViewStmt {
1106 name,
1107 sql,
1108 column_aliases,
1109 or_replace: cv.or_replace,
1110 if_not_exists: cv.if_not_exists,
1111 }))
1112}
1113
1114fn convert_insert(insert: sp::Insert) -> Result<Statement> {
1115 let table = match &insert.table {
1116 sp::TableObject::TableName(name) => object_name_to_string(name).to_ascii_lowercase(),
1117 _ => return Err(SqlError::Unsupported("INSERT into non-table object".into())),
1118 };
1119
1120 let columns: Vec<String> = insert
1121 .columns
1122 .iter()
1123 .map(|c| c.value.to_ascii_lowercase())
1124 .collect();
1125
1126 let query = insert
1127 .source
1128 .ok_or_else(|| SqlError::Parse("INSERT requires VALUES or SELECT".into()))?;
1129
1130 let source = match *query.body {
1131 sp::SetExpr::Values(sp::Values { rows, .. }) => {
1132 let mut result = Vec::new();
1133 for row in rows {
1134 let mut exprs = Vec::new();
1135 for expr in row {
1136 exprs.push(convert_expr(&expr)?);
1137 }
1138 result.push(exprs);
1139 }
1140 InsertSource::Values(result)
1141 }
1142 _ => {
1143 let (ctes, recursive) = if let Some(ref with) = query.with {
1144 convert_with(with)?
1145 } else {
1146 (vec![], false)
1147 };
1148 let body = convert_query_body(&query)?;
1149 InsertSource::Select(Box::new(SelectQuery {
1150 ctes,
1151 recursive,
1152 body,
1153 }))
1154 }
1155 };
1156
1157 let on_conflict = insert.on.as_ref().map(convert_on_insert).transpose()?;
1158
1159 Ok(Statement::Insert(InsertStmt {
1160 table,
1161 columns,
1162 source,
1163 on_conflict,
1164 }))
1165}
1166
1167fn convert_on_insert(on: &sp::OnInsert) -> Result<OnConflictClause> {
1168 match on {
1169 sp::OnInsert::OnConflict(oc) => {
1170 let target = oc
1171 .conflict_target
1172 .as_ref()
1173 .map(convert_conflict_target)
1174 .transpose()?;
1175 let action = convert_on_conflict_action(&oc.action)?;
1176 Ok(OnConflictClause { target, action })
1177 }
1178 sp::OnInsert::DuplicateKeyUpdate(_) => Err(SqlError::Parse(
1179 "ON DUPLICATE KEY UPDATE is MySQL-specific; use ON CONFLICT".into(),
1180 )),
1181 _ => Err(SqlError::Parse("unsupported ON INSERT clause".into())),
1182 }
1183}
1184
1185fn convert_conflict_target(target: &sp::ConflictTarget) -> Result<ConflictTarget> {
1186 match target {
1187 sp::ConflictTarget::Columns(cols) => Ok(ConflictTarget::Columns(
1188 cols.iter().map(|c| c.value.to_ascii_lowercase()).collect(),
1189 )),
1190 sp::ConflictTarget::OnConstraint(name) => {
1191 if name.0.len() > 1 {
1192 return Err(SqlError::Parse(
1193 "qualified constraint names not supported".into(),
1194 ));
1195 }
1196 Ok(ConflictTarget::Constraint(
1197 object_name_to_string(name).to_ascii_lowercase(),
1198 ))
1199 }
1200 }
1201}
1202
1203fn convert_on_conflict_action(action: &sp::OnConflictAction) -> Result<OnConflictAction> {
1204 match action {
1205 sp::OnConflictAction::DoNothing => Ok(OnConflictAction::DoNothing),
1206 sp::OnConflictAction::DoUpdate(du) => {
1207 let assignments = du
1208 .assignments
1209 .iter()
1210 .map(|a| {
1211 let col = match &a.target {
1212 sp::AssignmentTarget::ColumnName(name) => {
1213 object_name_to_string(name).to_ascii_lowercase()
1214 }
1215 _ => {
1216 return Err(SqlError::Unsupported(
1217 "tuple assignment in ON CONFLICT".into(),
1218 ))
1219 }
1220 };
1221 let expr = convert_expr(&a.value)?;
1222 Ok((col, expr))
1223 })
1224 .collect::<Result<_>>()?;
1225 let where_clause = du.selection.as_ref().map(convert_expr).transpose()?;
1226 Ok(OnConflictAction::DoUpdate {
1227 assignments,
1228 where_clause,
1229 })
1230 }
1231 }
1232}
1233
1234fn convert_select_body(select: &sp::Select) -> Result<SelectStmt> {
1235 let distinct = match &select.distinct {
1236 Some(sp::Distinct::Distinct) => true,
1237 Some(sp::Distinct::On(_)) => {
1238 return Err(SqlError::Unsupported("DISTINCT ON".into()));
1239 }
1240 _ => false,
1241 };
1242
1243 let (from, from_alias, joins) = if select.from.is_empty() {
1244 (String::new(), None, vec![])
1245 } else if select.from.len() == 1 {
1246 let table_with_joins = &select.from[0];
1247 let (name, alias) = match &table_with_joins.relation {
1248 sp::TableFactor::Table { name, alias, .. } => {
1249 let table_name = object_name_to_string(name);
1250 let alias_str = alias.as_ref().map(|a| a.name.value.clone());
1251 (table_name, alias_str)
1252 }
1253 _ => return Err(SqlError::Unsupported("non-table FROM source".into())),
1254 };
1255 let j = table_with_joins
1256 .joins
1257 .iter()
1258 .map(convert_join)
1259 .collect::<Result<Vec<_>>>()?;
1260 (name, alias, j)
1261 } else {
1262 return Err(SqlError::Unsupported("comma-separated FROM tables".into()));
1263 };
1264
1265 let columns: Vec<SelectColumn> = select
1266 .projection
1267 .iter()
1268 .map(convert_select_item)
1269 .collect::<Result<_>>()?;
1270
1271 let where_clause = select.selection.as_ref().map(convert_expr).transpose()?;
1272
1273 let group_by = match &select.group_by {
1274 sp::GroupByExpr::Expressions(exprs, _) => {
1275 exprs.iter().map(convert_expr).collect::<Result<_>>()?
1276 }
1277 sp::GroupByExpr::All(_) => {
1278 return Err(SqlError::Unsupported("GROUP BY ALL".into()));
1279 }
1280 };
1281
1282 let having = select.having.as_ref().map(convert_expr).transpose()?;
1283
1284 Ok(SelectStmt {
1285 columns,
1286 from,
1287 from_alias,
1288 joins,
1289 distinct,
1290 where_clause,
1291 order_by: vec![],
1292 limit: None,
1293 offset: None,
1294 group_by,
1295 having,
1296 })
1297}
1298
1299fn convert_set_expr(set_expr: &sp::SetExpr) -> Result<QueryBody> {
1300 match set_expr {
1301 sp::SetExpr::Select(sel) => Ok(QueryBody::Select(Box::new(convert_select_body(sel)?))),
1302 sp::SetExpr::SetOperation {
1303 op,
1304 set_quantifier,
1305 left,
1306 right,
1307 } => {
1308 let set_op = match op {
1309 sp::SetOperator::Union => SetOp::Union,
1310 sp::SetOperator::Intersect => SetOp::Intersect,
1311 sp::SetOperator::Except | sp::SetOperator::Minus => SetOp::Except,
1312 };
1313 let all = match set_quantifier {
1314 sp::SetQuantifier::All => true,
1315 sp::SetQuantifier::None | sp::SetQuantifier::Distinct => false,
1316 _ => {
1317 return Err(SqlError::Unsupported("BY NAME set operations".into()));
1318 }
1319 };
1320 Ok(QueryBody::Compound(Box::new(CompoundSelect {
1321 op: set_op,
1322 all,
1323 left: Box::new(convert_set_expr(left)?),
1324 right: Box::new(convert_set_expr(right)?),
1325 order_by: vec![],
1326 limit: None,
1327 offset: None,
1328 })))
1329 }
1330 _ => Err(SqlError::Unsupported("unsupported set expression".into())),
1331 }
1332}
1333
1334fn convert_query_body(query: &sp::Query) -> Result<QueryBody> {
1335 let mut body = convert_set_expr(&query.body)?;
1336
1337 let order_by = if let Some(ref ob) = query.order_by {
1338 match &ob.kind {
1339 sp::OrderByKind::Expressions(exprs) => exprs
1340 .iter()
1341 .map(convert_order_by_expr)
1342 .collect::<Result<_>>()?,
1343 sp::OrderByKind::All { .. } => {
1344 return Err(SqlError::Unsupported("ORDER BY ALL".into()));
1345 }
1346 }
1347 } else {
1348 vec![]
1349 };
1350
1351 let (limit, offset) = match &query.limit_clause {
1352 Some(sp::LimitClause::LimitOffset { limit, offset, .. }) => {
1353 let l = limit.as_ref().map(convert_expr).transpose()?;
1354 let o = offset
1355 .as_ref()
1356 .map(|o| convert_expr(&o.value))
1357 .transpose()?;
1358 (l, o)
1359 }
1360 Some(sp::LimitClause::OffsetCommaLimit { limit, offset }) => {
1361 let l = Some(convert_expr(limit)?);
1362 let o = Some(convert_expr(offset)?);
1363 (l, o)
1364 }
1365 None => (None, None),
1366 };
1367
1368 match &mut body {
1369 QueryBody::Select(sel) => {
1370 sel.order_by = order_by;
1371 sel.limit = limit;
1372 sel.offset = offset;
1373 }
1374 QueryBody::Compound(comp) => {
1375 comp.order_by = order_by;
1376 comp.limit = limit;
1377 comp.offset = offset;
1378 }
1379 }
1380
1381 Ok(body)
1382}
1383
1384fn convert_subquery(query: &sp::Query) -> Result<SelectStmt> {
1385 if query.with.is_some() {
1386 return Err(SqlError::Unsupported("CTEs in subqueries".into()));
1387 }
1388 match convert_query_body(query)? {
1389 QueryBody::Select(s) => Ok(*s),
1390 QueryBody::Compound(_) => Err(SqlError::Unsupported(
1391 "UNION/INTERSECT/EXCEPT in subqueries".into(),
1392 )),
1393 }
1394}
1395
1396fn convert_with(with: &sp::With) -> Result<(Vec<CteDefinition>, bool)> {
1397 let mut names = std::collections::HashSet::new();
1398 let mut ctes = Vec::new();
1399 for cte in &with.cte_tables {
1400 let name = cte.alias.name.value.to_ascii_lowercase();
1401 if !names.insert(name.clone()) {
1402 return Err(SqlError::DuplicateCteName(name));
1403 }
1404 let column_aliases: Vec<String> = cte
1405 .alias
1406 .columns
1407 .iter()
1408 .map(|c| c.name.value.to_ascii_lowercase())
1409 .collect();
1410 let body = convert_query_body(&cte.query)?;
1411 ctes.push(CteDefinition {
1412 name,
1413 column_aliases,
1414 body,
1415 });
1416 }
1417 Ok((ctes, with.recursive))
1418}
1419
1420fn convert_query(query: sp::Query) -> Result<Statement> {
1421 let (ctes, recursive) = if let Some(ref with) = query.with {
1422 convert_with(with)?
1423 } else {
1424 (vec![], false)
1425 };
1426 let body = convert_query_body(&query)?;
1427 Ok(Statement::Select(Box::new(SelectQuery {
1428 ctes,
1429 recursive,
1430 body,
1431 })))
1432}
1433
1434fn convert_join(join: &sp::Join) -> Result<JoinClause> {
1435 let (join_type, constraint) = match &join.join_operator {
1436 sp::JoinOperator::Inner(c) => (JoinType::Inner, Some(c)),
1437 sp::JoinOperator::Join(c) => (JoinType::Inner, Some(c)),
1438 sp::JoinOperator::CrossJoin(c) => (JoinType::Cross, Some(c)),
1439 sp::JoinOperator::Left(c) => (JoinType::Left, Some(c)),
1440 sp::JoinOperator::LeftSemi(c) => (JoinType::Left, Some(c)),
1441 sp::JoinOperator::LeftAnti(c) => (JoinType::Left, Some(c)),
1442 sp::JoinOperator::Right(c) => (JoinType::Right, Some(c)),
1443 sp::JoinOperator::RightSemi(c) => (JoinType::Right, Some(c)),
1444 sp::JoinOperator::RightAnti(c) => (JoinType::Right, Some(c)),
1445 other => return Err(SqlError::Unsupported(format!("join type: {other:?}"))),
1446 };
1447
1448 let (name, alias) = match &join.relation {
1449 sp::TableFactor::Table { name, alias, .. } => {
1450 let table_name = object_name_to_string(name);
1451 let alias_str = alias.as_ref().map(|a| a.name.value.clone());
1452 (table_name, alias_str)
1453 }
1454 _ => return Err(SqlError::Unsupported("non-table JOIN source".into())),
1455 };
1456
1457 let on_clause = match constraint {
1458 Some(sp::JoinConstraint::On(expr)) => Some(convert_expr(expr)?),
1459 Some(sp::JoinConstraint::None) | None => None,
1460 Some(other) => return Err(SqlError::Unsupported(format!("join constraint: {other:?}"))),
1461 };
1462
1463 Ok(JoinClause {
1464 join_type,
1465 table: TableRef { name, alias },
1466 on_clause,
1467 })
1468}
1469
1470fn convert_update(update: sp::Update) -> Result<Statement> {
1471 let table = match &update.table.relation {
1472 sp::TableFactor::Table { name, .. } => object_name_to_string(name),
1473 _ => return Err(SqlError::Unsupported("non-table UPDATE target".into())),
1474 };
1475
1476 let assignments = update
1477 .assignments
1478 .iter()
1479 .map(|a| {
1480 let col = match &a.target {
1481 sp::AssignmentTarget::ColumnName(name) => object_name_to_string(name),
1482 _ => return Err(SqlError::Unsupported("tuple assignment".into())),
1483 };
1484 let expr = convert_expr(&a.value)?;
1485 Ok((col, expr))
1486 })
1487 .collect::<Result<_>>()?;
1488
1489 let where_clause = update.selection.as_ref().map(convert_expr).transpose()?;
1490
1491 Ok(Statement::Update(UpdateStmt {
1492 table,
1493 assignments,
1494 where_clause,
1495 }))
1496}
1497
1498fn convert_delete(delete: sp::Delete) -> Result<Statement> {
1499 let table_name = match &delete.from {
1500 sp::FromTable::WithFromKeyword(tables) => {
1501 if tables.len() != 1 {
1502 return Err(SqlError::Unsupported("multi-table DELETE".into()));
1503 }
1504 match &tables[0].relation {
1505 sp::TableFactor::Table { name, .. } => object_name_to_string(name),
1506 _ => return Err(SqlError::Unsupported("non-table DELETE target".into())),
1507 }
1508 }
1509 sp::FromTable::WithoutKeyword(tables) => {
1510 if tables.len() != 1 {
1511 return Err(SqlError::Unsupported("multi-table DELETE".into()));
1512 }
1513 match &tables[0].relation {
1514 sp::TableFactor::Table { name, .. } => object_name_to_string(name),
1515 _ => return Err(SqlError::Unsupported("non-table DELETE target".into())),
1516 }
1517 }
1518 };
1519
1520 let where_clause = delete.selection.as_ref().map(convert_expr).transpose()?;
1521
1522 Ok(Statement::Delete(DeleteStmt {
1523 table: table_name,
1524 where_clause,
1525 }))
1526}
1527
1528fn convert_expr(expr: &sp::Expr) -> Result<Expr> {
1529 match expr {
1530 sp::Expr::Value(v) => convert_value(&v.value),
1531 sp::Expr::Identifier(ident) => Ok(Expr::Column(ident.value.to_ascii_lowercase())),
1532 sp::Expr::CompoundIdentifier(parts) => {
1533 if parts.len() == 2 {
1534 Ok(Expr::QualifiedColumn {
1535 table: parts[0].value.to_ascii_lowercase(),
1536 column: parts[1].value.to_ascii_lowercase(),
1537 })
1538 } else {
1539 Ok(Expr::Column(
1540 parts.last().unwrap().value.to_ascii_lowercase(),
1541 ))
1542 }
1543 }
1544 sp::Expr::BinaryOp { left, op, right } => {
1545 let bin_op = convert_bin_op(op)?;
1546 Ok(Expr::BinaryOp {
1547 left: Box::new(convert_expr(left)?),
1548 op: bin_op,
1549 right: Box::new(convert_expr(right)?),
1550 })
1551 }
1552 sp::Expr::UnaryOp { op, expr } => {
1553 let unary_op = match op {
1554 sp::UnaryOperator::Minus => UnaryOp::Neg,
1555 sp::UnaryOperator::Not => UnaryOp::Not,
1556 _ => return Err(SqlError::Unsupported(format!("unary op: {op}"))),
1557 };
1558 Ok(Expr::UnaryOp {
1559 op: unary_op,
1560 expr: Box::new(convert_expr(expr)?),
1561 })
1562 }
1563 sp::Expr::IsNull(e) => Ok(Expr::IsNull(Box::new(convert_expr(e)?))),
1564 sp::Expr::IsNotNull(e) => Ok(Expr::IsNotNull(Box::new(convert_expr(e)?))),
1565 sp::Expr::Nested(e) => convert_expr(e),
1566 sp::Expr::Function(func) => convert_function(func),
1567 sp::Expr::InSubquery {
1568 expr: e,
1569 subquery,
1570 negated,
1571 } => {
1572 let inner_expr = convert_expr(e)?;
1573 let stmt = convert_subquery(subquery)?;
1574 Ok(Expr::InSubquery {
1575 expr: Box::new(inner_expr),
1576 subquery: Box::new(stmt),
1577 negated: *negated,
1578 })
1579 }
1580 sp::Expr::InList {
1581 expr: e,
1582 list,
1583 negated,
1584 } => {
1585 let inner_expr = convert_expr(e)?;
1586 let items = list.iter().map(convert_expr).collect::<Result<Vec<_>>>()?;
1587 Ok(Expr::InList {
1588 expr: Box::new(inner_expr),
1589 list: items,
1590 negated: *negated,
1591 })
1592 }
1593 sp::Expr::Exists { subquery, negated } => {
1594 let stmt = convert_subquery(subquery)?;
1595 Ok(Expr::Exists {
1596 subquery: Box::new(stmt),
1597 negated: *negated,
1598 })
1599 }
1600 sp::Expr::Subquery(query) => {
1601 let stmt = convert_subquery(query)?;
1602 Ok(Expr::ScalarSubquery(Box::new(stmt)))
1603 }
1604 sp::Expr::Between {
1605 expr: e,
1606 negated,
1607 low,
1608 high,
1609 } => Ok(Expr::Between {
1610 expr: Box::new(convert_expr(e)?),
1611 low: Box::new(convert_expr(low)?),
1612 high: Box::new(convert_expr(high)?),
1613 negated: *negated,
1614 }),
1615 sp::Expr::Like {
1616 expr: e,
1617 negated,
1618 pattern,
1619 escape_char,
1620 ..
1621 } => {
1622 let esc = escape_char
1623 .as_ref()
1624 .map(convert_escape_value)
1625 .transpose()?
1626 .map(Box::new);
1627 Ok(Expr::Like {
1628 expr: Box::new(convert_expr(e)?),
1629 pattern: Box::new(convert_expr(pattern)?),
1630 escape: esc,
1631 negated: *negated,
1632 })
1633 }
1634 sp::Expr::ILike {
1635 expr: e,
1636 negated,
1637 pattern,
1638 escape_char,
1639 ..
1640 } => {
1641 let esc = escape_char
1642 .as_ref()
1643 .map(convert_escape_value)
1644 .transpose()?
1645 .map(Box::new);
1646 Ok(Expr::Like {
1647 expr: Box::new(convert_expr(e)?),
1648 pattern: Box::new(convert_expr(pattern)?),
1649 escape: esc,
1650 negated: *negated,
1651 })
1652 }
1653 sp::Expr::Case {
1654 operand,
1655 conditions,
1656 else_result,
1657 ..
1658 } => {
1659 let op = operand
1660 .as_ref()
1661 .map(|e| convert_expr(e))
1662 .transpose()?
1663 .map(Box::new);
1664 let conds: Vec<(Expr, Expr)> = conditions
1665 .iter()
1666 .map(|cw| Ok((convert_expr(&cw.condition)?, convert_expr(&cw.result)?)))
1667 .collect::<Result<_>>()?;
1668 let else_r = else_result
1669 .as_ref()
1670 .map(|e| convert_expr(e))
1671 .transpose()?
1672 .map(Box::new);
1673 Ok(Expr::Case {
1674 operand: op,
1675 conditions: conds,
1676 else_result: else_r,
1677 })
1678 }
1679 sp::Expr::Cast {
1680 expr: e,
1681 data_type: dt,
1682 ..
1683 } => {
1684 let target = convert_data_type(dt)?;
1685 Ok(Expr::Cast {
1686 expr: Box::new(convert_expr(e)?),
1687 data_type: target,
1688 })
1689 }
1690 sp::Expr::Substring {
1691 expr: e,
1692 substring_from,
1693 substring_for,
1694 ..
1695 } => {
1696 let mut args = vec![convert_expr(e)?];
1697 if let Some(from) = substring_from {
1698 args.push(convert_expr(from)?);
1699 }
1700 if let Some(f) = substring_for {
1701 args.push(convert_expr(f)?);
1702 }
1703 Ok(Expr::Function {
1704 name: "SUBSTR".into(),
1705 args,
1706 distinct: false,
1707 })
1708 }
1709 sp::Expr::Trim {
1710 expr: e,
1711 trim_where,
1712 trim_what,
1713 trim_characters,
1714 } => {
1715 let fn_name = match trim_where {
1716 Some(sp::TrimWhereField::Leading) => "LTRIM",
1717 Some(sp::TrimWhereField::Trailing) => "RTRIM",
1718 _ => "TRIM",
1719 };
1720 let mut args = vec![convert_expr(e)?];
1721 if let Some(what) = trim_what {
1722 args.push(convert_expr(what)?);
1723 } else if let Some(chars) = trim_characters {
1724 if let Some(first) = chars.first() {
1725 args.push(convert_expr(first)?);
1726 }
1727 }
1728 Ok(Expr::Function {
1729 name: fn_name.into(),
1730 args,
1731 distinct: false,
1732 })
1733 }
1734 sp::Expr::Ceil { expr: e, .. } => Ok(Expr::Function {
1735 name: "CEIL".into(),
1736 args: vec![convert_expr(e)?],
1737 distinct: false,
1738 }),
1739 sp::Expr::Floor { expr: e, .. } => Ok(Expr::Function {
1740 name: "FLOOR".into(),
1741 args: vec![convert_expr(e)?],
1742 distinct: false,
1743 }),
1744 sp::Expr::Position { expr: e, r#in } => Ok(Expr::Function {
1745 name: "INSTR".into(),
1746 args: vec![convert_expr(r#in)?, convert_expr(e)?],
1747 distinct: false,
1748 }),
1749 sp::Expr::TypedString(ts) => {
1751 let raw = match &ts.value.value {
1752 sp::Value::SingleQuotedString(s) => s.clone(),
1753 sp::Value::DoubleQuotedString(s) => s.clone(),
1754 other => other.to_string(),
1755 };
1756 convert_typed_string(&ts.data_type, &raw)
1757 }
1758 sp::Expr::Interval(iv) => convert_interval_expr(iv),
1760 sp::Expr::Extract { field, expr: e, .. } => {
1762 let field_name = match field {
1763 sp::DateTimeField::Year => "year",
1764 sp::DateTimeField::Month => "month",
1765 sp::DateTimeField::Week(_) => "week",
1766 sp::DateTimeField::Day => "day",
1767 sp::DateTimeField::Date => "day",
1768 sp::DateTimeField::Hour => "hour",
1769 sp::DateTimeField::Minute => "minute",
1770 sp::DateTimeField::Second => "second",
1771 sp::DateTimeField::Millisecond => "milliseconds",
1772 sp::DateTimeField::Microsecond => "microseconds",
1773 sp::DateTimeField::Microseconds => "microseconds",
1774 sp::DateTimeField::Milliseconds => "milliseconds",
1775 sp::DateTimeField::Dow => "dow",
1776 sp::DateTimeField::Isodow => "isodow",
1777 sp::DateTimeField::Doy => "doy",
1778 sp::DateTimeField::Epoch => "epoch",
1779 sp::DateTimeField::Quarter => "quarter",
1780 sp::DateTimeField::Decade => "decade",
1781 sp::DateTimeField::Century => "century",
1782 sp::DateTimeField::Millennium => "millennium",
1783 sp::DateTimeField::Isoyear => "isoyear",
1784 sp::DateTimeField::Julian => "julian",
1785 other => {
1786 return Err(SqlError::InvalidExtractField(format!("{other:?}")));
1787 }
1788 };
1789 Ok(Expr::Function {
1790 name: "EXTRACT".into(),
1791 args: vec![
1792 Expr::Literal(Value::Text(field_name.into())),
1793 convert_expr(e)?,
1794 ],
1795 distinct: false,
1796 })
1797 }
1798 sp::Expr::AtTimeZone {
1800 timestamp,
1801 time_zone,
1802 } => Ok(Expr::Function {
1803 name: "AT_TIMEZONE".into(),
1804 args: vec![convert_expr(timestamp)?, convert_expr(time_zone)?],
1805 distinct: false,
1806 }),
1807 _ => Err(SqlError::Unsupported(format!("expression: {expr}"))),
1808 }
1809}
1810
1811fn convert_value(val: &sp::Value) -> Result<Expr> {
1812 match val {
1813 sp::Value::Number(n, _) => {
1814 if let Ok(i) = n.parse::<i64>() {
1815 Ok(Expr::Literal(Value::Integer(i)))
1816 } else if let Ok(f) = n.parse::<f64>() {
1817 Ok(Expr::Literal(Value::Real(f)))
1818 } else {
1819 Err(SqlError::InvalidValue(format!("cannot parse number: {n}")))
1820 }
1821 }
1822 sp::Value::SingleQuotedString(s) => Ok(Expr::Literal(Value::Text(s.as_str().into()))),
1823 sp::Value::Boolean(b) => Ok(Expr::Literal(Value::Boolean(*b))),
1824 sp::Value::Null => Ok(Expr::Literal(Value::Null)),
1825 sp::Value::Placeholder(s) => {
1826 let idx_str = s
1827 .strip_prefix('$')
1828 .ok_or_else(|| SqlError::Parse(format!("invalid placeholder: {s}")))?;
1829 let idx: usize = idx_str
1830 .parse()
1831 .map_err(|_| SqlError::Parse(format!("invalid placeholder index: {s}")))?;
1832 if idx == 0 {
1833 return Err(SqlError::Parse("placeholder index must be >= 1".into()));
1834 }
1835 Ok(Expr::Parameter(idx))
1836 }
1837 _ => Err(SqlError::Unsupported(format!("value type: {val}"))),
1838 }
1839}
1840
1841fn convert_typed_string(dt: &sp::DataType, value: &str) -> Result<Expr> {
1842 let s = value.trim_matches('\'');
1843 match dt {
1844 sp::DataType::Date => {
1845 let d = crate::datetime::parse_date(s)?;
1846 Ok(Expr::Literal(Value::Date(d)))
1847 }
1848 sp::DataType::Time(_, _) => {
1849 let t = crate::datetime::parse_time(s)?;
1850 Ok(Expr::Literal(Value::Time(t)))
1851 }
1852 sp::DataType::Timestamp(_, _) => {
1853 let t = crate::datetime::parse_timestamp(s)?;
1854 Ok(Expr::Literal(Value::Timestamp(t)))
1855 }
1856 sp::DataType::Interval { .. } => {
1857 let (months, days, micros) = crate::datetime::parse_interval(s)?;
1858 Ok(Expr::Literal(Value::Interval {
1859 months,
1860 days,
1861 micros,
1862 }))
1863 }
1864 _ => {
1865 let target = convert_data_type(dt)?;
1866 Ok(Expr::Cast {
1867 expr: Box::new(Expr::Literal(Value::Text(s.into()))),
1868 data_type: target,
1869 })
1870 }
1871 }
1872}
1873
1874fn convert_interval_expr(iv: &sp::Interval) -> Result<Expr> {
1875 let raw = match iv.value.as_ref() {
1876 sp::Expr::Value(v) => match &v.value {
1877 sp::Value::SingleQuotedString(s) => s.clone(),
1878 sp::Value::Number(n, _) => n.clone(),
1879 other => {
1880 return Err(SqlError::InvalidIntervalLiteral(format!(
1881 "unsupported inner value: {other}"
1882 )))
1883 }
1884 },
1885 other => {
1886 return Err(SqlError::InvalidIntervalLiteral(format!(
1887 "unsupported inner expr: {other}"
1888 )))
1889 }
1890 };
1891
1892 let with_unit = if let Some(field) = &iv.leading_field {
1894 let unit_name = match field {
1895 sp::DateTimeField::Year => "years",
1896 sp::DateTimeField::Month => "months",
1897 sp::DateTimeField::Week(_) => "weeks",
1898 sp::DateTimeField::Day => "days",
1899 sp::DateTimeField::Hour => "hours",
1900 sp::DateTimeField::Minute => "minutes",
1901 sp::DateTimeField::Second => "seconds",
1902 _ => {
1903 return Err(SqlError::InvalidIntervalLiteral(format!(
1904 "unsupported leading field: {field:?}"
1905 )))
1906 }
1907 };
1908 format!("{raw} {unit_name}")
1909 } else {
1910 raw
1911 };
1912
1913 let (months, days, micros) = crate::datetime::parse_interval(&with_unit)?;
1914 Ok(Expr::Literal(Value::Interval {
1915 months,
1916 days,
1917 micros,
1918 }))
1919}
1920
1921fn convert_escape_value(val: &sp::Value) -> Result<Expr> {
1922 match val {
1923 sp::Value::SingleQuotedString(s) => Ok(Expr::Literal(Value::Text(s.as_str().into()))),
1924 _ => Err(SqlError::Unsupported(format!("ESCAPE value: {val}"))),
1925 }
1926}
1927
1928fn convert_bin_op(op: &sp::BinaryOperator) -> Result<BinOp> {
1929 match op {
1930 sp::BinaryOperator::Plus => Ok(BinOp::Add),
1931 sp::BinaryOperator::Minus => Ok(BinOp::Sub),
1932 sp::BinaryOperator::Multiply => Ok(BinOp::Mul),
1933 sp::BinaryOperator::Divide => Ok(BinOp::Div),
1934 sp::BinaryOperator::Modulo => Ok(BinOp::Mod),
1935 sp::BinaryOperator::Eq => Ok(BinOp::Eq),
1936 sp::BinaryOperator::NotEq => Ok(BinOp::NotEq),
1937 sp::BinaryOperator::Lt => Ok(BinOp::Lt),
1938 sp::BinaryOperator::Gt => Ok(BinOp::Gt),
1939 sp::BinaryOperator::LtEq => Ok(BinOp::LtEq),
1940 sp::BinaryOperator::GtEq => Ok(BinOp::GtEq),
1941 sp::BinaryOperator::And => Ok(BinOp::And),
1942 sp::BinaryOperator::Or => Ok(BinOp::Or),
1943 sp::BinaryOperator::StringConcat => Ok(BinOp::Concat),
1944 _ => Err(SqlError::Unsupported(format!("binary op: {op}"))),
1945 }
1946}
1947
1948fn convert_function(func: &sp::Function) -> Result<Expr> {
1949 let name = object_name_to_string(&func.name).to_ascii_uppercase();
1950
1951 let (args, is_count_star, distinct) = match &func.args {
1952 sp::FunctionArguments::List(list) => {
1953 let distinct = matches!(
1954 list.duplicate_treatment,
1955 Some(sp::DuplicateTreatment::Distinct)
1956 );
1957 if list.args.is_empty() && name == "COUNT" {
1958 (vec![], true, distinct)
1959 } else {
1960 let mut count_star = false;
1961 let args = list
1962 .args
1963 .iter()
1964 .map(|arg| match arg {
1965 sp::FunctionArg::Unnamed(sp::FunctionArgExpr::Expr(e)) => convert_expr(e),
1966 sp::FunctionArg::Unnamed(sp::FunctionArgExpr::Wildcard) => {
1967 if name == "COUNT" {
1968 count_star = true;
1969 Ok(Expr::CountStar)
1970 } else {
1971 Err(SqlError::Unsupported(format!("{name}(*)")))
1972 }
1973 }
1974 _ => Err(SqlError::Unsupported(format!(
1975 "function arg type in {name}"
1976 ))),
1977 })
1978 .collect::<Result<Vec<_>>>()?;
1979 if name == "COUNT" && args.len() == 1 && count_star {
1980 (vec![], true, distinct)
1981 } else {
1982 (args, false, distinct)
1983 }
1984 }
1985 }
1986 sp::FunctionArguments::None => {
1987 if name == "COUNT" {
1988 (vec![], true, false)
1989 } else {
1990 (vec![], false, false)
1991 }
1992 }
1993 sp::FunctionArguments::Subquery(_) => {
1994 return Err(SqlError::Unsupported("subquery in function".into()));
1995 }
1996 };
1997
1998 if let Some(over) = &func.over {
2000 let spec = match over {
2001 sp::WindowType::WindowSpec(ws) => convert_window_spec(ws)?,
2002 sp::WindowType::NamedWindow(_) => {
2003 return Err(SqlError::Unsupported("named windows".into()));
2004 }
2005 };
2006 return Ok(Expr::WindowFunction { name, args, spec });
2007 }
2008
2009 if is_count_star {
2011 return Ok(Expr::CountStar);
2012 }
2013
2014 if name == "COALESCE" {
2015 if args.is_empty() {
2016 return Err(SqlError::Parse(
2017 "COALESCE requires at least one argument".into(),
2018 ));
2019 }
2020 return Ok(Expr::Coalesce(args));
2021 }
2022
2023 if name == "NULLIF" {
2024 if args.len() != 2 {
2025 return Err(SqlError::Parse(
2026 "NULLIF requires exactly two arguments".into(),
2027 ));
2028 }
2029 return Ok(Expr::Case {
2030 operand: None,
2031 conditions: vec![(
2032 Expr::BinaryOp {
2033 left: Box::new(args[0].clone()),
2034 op: BinOp::Eq,
2035 right: Box::new(args[1].clone()),
2036 },
2037 Expr::Literal(Value::Null),
2038 )],
2039 else_result: Some(Box::new(args[0].clone())),
2040 });
2041 }
2042
2043 if name == "IIF" {
2044 if args.len() != 3 {
2045 return Err(SqlError::Parse(
2046 "IIF requires exactly three arguments".into(),
2047 ));
2048 }
2049 return Ok(Expr::Case {
2050 operand: None,
2051 conditions: vec![(args[0].clone(), args[1].clone())],
2052 else_result: Some(Box::new(args[2].clone())),
2053 });
2054 }
2055
2056 Ok(Expr::Function {
2057 name,
2058 args,
2059 distinct,
2060 })
2061}
2062
2063fn convert_window_spec(ws: &sp::WindowSpec) -> Result<WindowSpec> {
2064 let partition_by = ws
2065 .partition_by
2066 .iter()
2067 .map(convert_expr)
2068 .collect::<Result<Vec<_>>>()?;
2069 let order_by = ws
2070 .order_by
2071 .iter()
2072 .map(convert_order_by_expr)
2073 .collect::<Result<Vec<_>>>()?;
2074 let frame = ws
2075 .window_frame
2076 .as_ref()
2077 .map(convert_window_frame)
2078 .transpose()?;
2079 Ok(WindowSpec {
2080 partition_by,
2081 order_by,
2082 frame,
2083 })
2084}
2085
2086fn convert_window_frame(wf: &sp::WindowFrame) -> Result<WindowFrame> {
2087 let units = match wf.units {
2088 sp::WindowFrameUnits::Rows => WindowFrameUnits::Rows,
2089 sp::WindowFrameUnits::Range => WindowFrameUnits::Range,
2090 sp::WindowFrameUnits::Groups => {
2091 return Err(SqlError::Unsupported("GROUPS window frame".into()));
2092 }
2093 };
2094 let start = convert_window_frame_bound(&wf.start_bound)?;
2095 let end = match &wf.end_bound {
2096 Some(b) => convert_window_frame_bound(b)?,
2097 None => WindowFrameBound::CurrentRow,
2098 };
2099 Ok(WindowFrame { units, start, end })
2100}
2101
2102fn convert_window_frame_bound(b: &sp::WindowFrameBound) -> Result<WindowFrameBound> {
2103 match b {
2104 sp::WindowFrameBound::CurrentRow => Ok(WindowFrameBound::CurrentRow),
2105 sp::WindowFrameBound::Preceding(None) => Ok(WindowFrameBound::UnboundedPreceding),
2106 sp::WindowFrameBound::Preceding(Some(e)) => {
2107 Ok(WindowFrameBound::Preceding(Box::new(convert_expr(e)?)))
2108 }
2109 sp::WindowFrameBound::Following(None) => Ok(WindowFrameBound::UnboundedFollowing),
2110 sp::WindowFrameBound::Following(Some(e)) => {
2111 Ok(WindowFrameBound::Following(Box::new(convert_expr(e)?)))
2112 }
2113 }
2114}
2115
2116fn convert_select_item(item: &sp::SelectItem) -> Result<SelectColumn> {
2117 match item {
2118 sp::SelectItem::Wildcard(_) => Ok(SelectColumn::AllColumns),
2119 sp::SelectItem::UnnamedExpr(e) => {
2120 let expr = convert_expr(e)?;
2121 Ok(SelectColumn::Expr { expr, alias: None })
2122 }
2123 sp::SelectItem::ExprWithAlias { expr, alias } => {
2124 let expr = convert_expr(expr)?;
2125 Ok(SelectColumn::Expr {
2126 expr,
2127 alias: Some(alias.value.clone()),
2128 })
2129 }
2130 sp::SelectItem::QualifiedWildcard(_, _) => {
2131 Err(SqlError::Unsupported("qualified wildcard (table.*)".into()))
2132 }
2133 }
2134}
2135
2136fn convert_order_by_expr(expr: &sp::OrderByExpr) -> Result<OrderByItem> {
2137 let e = convert_expr(&expr.expr)?;
2138 let descending = expr.options.asc.map(|asc| !asc).unwrap_or(false);
2139 let nulls_first = expr.options.nulls_first;
2140
2141 Ok(OrderByItem {
2142 expr: e,
2143 descending,
2144 nulls_first,
2145 })
2146}
2147
2148fn convert_data_type(dt: &sp::DataType) -> Result<DataType> {
2149 match dt {
2150 sp::DataType::Int(_)
2151 | sp::DataType::Integer(_)
2152 | sp::DataType::BigInt(_)
2153 | sp::DataType::SmallInt(_)
2154 | sp::DataType::TinyInt(_)
2155 | sp::DataType::Int2(_)
2156 | sp::DataType::Int4(_)
2157 | sp::DataType::Int8(_) => Ok(DataType::Integer),
2158
2159 sp::DataType::Real
2160 | sp::DataType::Double(..)
2161 | sp::DataType::DoublePrecision
2162 | sp::DataType::Float(_)
2163 | sp::DataType::Float4
2164 | sp::DataType::Float64 => Ok(DataType::Real),
2165
2166 sp::DataType::Varchar(_)
2167 | sp::DataType::Text
2168 | sp::DataType::Char(_)
2169 | sp::DataType::Character(_)
2170 | sp::DataType::String(_) => Ok(DataType::Text),
2171
2172 sp::DataType::Blob(_) | sp::DataType::Bytea => Ok(DataType::Blob),
2173
2174 sp::DataType::Boolean | sp::DataType::Bool => Ok(DataType::Boolean),
2175
2176 sp::DataType::Date => Ok(DataType::Date),
2177 sp::DataType::Time(_, _) => Ok(DataType::Time),
2178 sp::DataType::Timestamp(_, _) => Ok(DataType::Timestamp),
2179 sp::DataType::Interval { .. } => Ok(DataType::Interval),
2180
2181 _ => Err(SqlError::Unsupported(format!("data type: {dt}"))),
2182 }
2183}
2184
2185fn object_name_to_string(name: &sp::ObjectName) -> String {
2186 name.0
2187 .iter()
2188 .filter_map(|p| match p {
2189 sp::ObjectNamePart::Identifier(ident) => Some(ident.value.clone()),
2190 _ => None,
2191 })
2192 .collect::<Vec<_>>()
2193 .join(".")
2194}
2195
2196#[cfg(test)]
2197mod tests {
2198 use super::*;
2199
2200 #[test]
2201 fn parse_create_table() {
2202 let stmt = parse_sql(
2203 "CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT NOT NULL, age INTEGER)",
2204 )
2205 .unwrap();
2206
2207 match stmt {
2208 Statement::CreateTable(ct) => {
2209 assert_eq!(ct.name, "users");
2210 assert_eq!(ct.columns.len(), 3);
2211 assert_eq!(ct.columns[0].name, "id");
2212 assert_eq!(ct.columns[0].data_type, DataType::Integer);
2213 assert!(ct.columns[0].is_primary_key);
2214 assert!(!ct.columns[0].nullable);
2215 assert_eq!(ct.columns[1].name, "name");
2216 assert_eq!(ct.columns[1].data_type, DataType::Text);
2217 assert!(!ct.columns[1].nullable);
2218 assert_eq!(ct.columns[2].name, "age");
2219 assert!(ct.columns[2].nullable);
2220 assert_eq!(ct.primary_key, vec!["id"]);
2221 }
2222 _ => panic!("expected CreateTable"),
2223 }
2224 }
2225
2226 #[test]
2227 fn parse_create_table_if_not_exists() {
2228 let stmt = parse_sql("CREATE TABLE IF NOT EXISTS t (id INT PRIMARY KEY)").unwrap();
2229 match stmt {
2230 Statement::CreateTable(ct) => assert!(ct.if_not_exists),
2231 _ => panic!("expected CreateTable"),
2232 }
2233 }
2234
2235 #[test]
2236 fn parse_drop_table() {
2237 let stmt = parse_sql("DROP TABLE users").unwrap();
2238 match stmt {
2239 Statement::DropTable(dt) => {
2240 assert_eq!(dt.name, "users");
2241 assert!(!dt.if_exists);
2242 }
2243 _ => panic!("expected DropTable"),
2244 }
2245 }
2246
2247 #[test]
2248 fn parse_drop_table_if_exists() {
2249 let stmt = parse_sql("DROP TABLE IF EXISTS users").unwrap();
2250 match stmt {
2251 Statement::DropTable(dt) => assert!(dt.if_exists),
2252 _ => panic!("expected DropTable"),
2253 }
2254 }
2255
2256 #[test]
2257 fn parse_insert() {
2258 let stmt =
2259 parse_sql("INSERT INTO users (id, name) VALUES (1, 'Alice'), (2, 'Bob')").unwrap();
2260
2261 match stmt {
2262 Statement::Insert(ins) => {
2263 assert_eq!(ins.table, "users");
2264 assert_eq!(ins.columns, vec!["id", "name"]);
2265 let values = match &ins.source {
2266 InsertSource::Values(v) => v,
2267 _ => panic!("expected Values"),
2268 };
2269 assert_eq!(values.len(), 2);
2270 assert!(matches!(values[0][0], Expr::Literal(Value::Integer(1))));
2271 assert!(matches!(&values[0][1], Expr::Literal(Value::Text(s)) if s == "Alice"));
2272 assert!(ins.on_conflict.is_none());
2273 }
2274 _ => panic!("expected Insert"),
2275 }
2276 }
2277
2278 #[test]
2279 fn parse_upsert_do_nothing() {
2280 let stmt =
2281 parse_sql("INSERT INTO t (id, v) VALUES (1, 'a') ON CONFLICT (id) DO NOTHING").unwrap();
2282 match stmt {
2283 Statement::Insert(ins) => {
2284 let oc = ins.on_conflict.expect("expected on_conflict");
2285 match oc.target.expect("target") {
2286 ConflictTarget::Columns(cols) => assert_eq!(cols, vec!["id"]),
2287 _ => panic!("expected Columns target"),
2288 }
2289 assert!(matches!(oc.action, OnConflictAction::DoNothing));
2290 }
2291 _ => panic!("expected Insert"),
2292 }
2293 }
2294
2295 #[test]
2296 fn parse_upsert_do_nothing_no_target() {
2297 let stmt = parse_sql("INSERT INTO t VALUES (1, 'a') ON CONFLICT DO NOTHING").unwrap();
2298 match stmt {
2299 Statement::Insert(ins) => {
2300 let oc = ins.on_conflict.expect("expected on_conflict");
2301 assert!(oc.target.is_none());
2302 assert!(matches!(oc.action, OnConflictAction::DoNothing));
2303 }
2304 _ => panic!("expected Insert"),
2305 }
2306 }
2307
2308 #[test]
2309 fn parse_upsert_do_update_simple() {
2310 let stmt = parse_sql(
2311 "INSERT INTO t (id, v) VALUES (1, 'a') ON CONFLICT (id) DO UPDATE SET v = 'b'",
2312 )
2313 .unwrap();
2314 match stmt {
2315 Statement::Insert(ins) => {
2316 let oc = ins.on_conflict.expect("expected on_conflict");
2317 match oc.action {
2318 OnConflictAction::DoUpdate {
2319 assignments,
2320 where_clause,
2321 } => {
2322 assert_eq!(assignments.len(), 1);
2323 assert_eq!(assignments[0].0, "v");
2324 assert!(where_clause.is_none());
2325 }
2326 _ => panic!("expected DoUpdate"),
2327 }
2328 }
2329 _ => panic!("expected Insert"),
2330 }
2331 }
2332
/// `excluded.v` in an upsert assignment is parsed as a qualified column reference.
#[test]
fn parse_upsert_do_update_excluded() {
    let sql = "INSERT INTO t (id, v) VALUES (1, 'a') \
               ON CONFLICT (id) DO UPDATE SET v = excluded.v";
    let Statement::Insert(insert) = parse_sql(sql).unwrap() else {
        panic!("expected Insert");
    };
    let conflict = insert.on_conflict.expect("expected on_conflict");
    let OnConflictAction::DoUpdate { assignments, .. } = conflict.action else {
        panic!("expected DoUpdate");
    };
    let Expr::QualifiedColumn { table, column } = &assignments[0].1 else {
        panic!("expected QualifiedColumn");
    };
    assert_eq!(table, "excluded");
    assert_eq!(column, "v");
}
2358
/// A trailing `WHERE` on `DO UPDATE` is captured as the action's where clause.
#[test]
fn parse_upsert_do_update_where() {
    let sql = "INSERT INTO t (id, v) VALUES (1, 'a') \
               ON CONFLICT (id) DO UPDATE SET v = excluded.v WHERE t.v < 'z'";
    let Statement::Insert(insert) = parse_sql(sql).unwrap() else {
        panic!("expected Insert");
    };
    let conflict = insert.on_conflict.expect("expected on_conflict");
    let OnConflictAction::DoUpdate { where_clause, .. } = conflict.action else {
        panic!("expected DoUpdate");
    };
    assert!(where_clause.is_some());
}
2379
/// `ON CONFLICT ON CONSTRAINT <name>` yields a named-constraint conflict target.
#[test]
fn parse_upsert_on_constraint_named() {
    let sql = "INSERT INTO t (id, v) VALUES (1, 'a') \
               ON CONFLICT ON CONSTRAINT t_v_idx DO NOTHING";
    let Statement::Insert(insert) = parse_sql(sql).unwrap() else {
        panic!("expected Insert");
    };
    let conflict = insert.on_conflict.expect("expected on_conflict");
    let target = conflict.target.expect("target");
    let ConflictTarget::Constraint(name) = target else {
        panic!("expected Constraint target");
    };
    assert_eq!(name, "t_v_idx");
}
2398
/// MySQL-style `ON DUPLICATE KEY UPDATE` must be rejected, and the error text
/// must mention either the clause itself or "MySQL".
#[test]
fn parse_upsert_rejects_duplicate_key_update() {
    let err = parse_sql("INSERT INTO t (id) VALUES (1) ON DUPLICATE KEY UPDATE id = 2")
        .expect_err("should reject MySQL syntax");
    let msg = err.to_string();
    // Include the actual message so a wording regression is diagnosable from the failure.
    assert!(
        msg.contains("ON DUPLICATE KEY UPDATE") || msg.contains("MySQL"),
        "error should mention the rejected MySQL syntax, got: {msg}"
    );
}
2406
/// Conflict-target column names are normalized to lowercase.
#[test]
fn parse_upsert_lowercases_conflict_target() {
    let sql = "INSERT INTO t (id) VALUES (1) ON CONFLICT (ID) DO NOTHING";
    let Statement::Insert(insert) = parse_sql(sql).unwrap() else {
        panic!("expected Insert");
    };
    let conflict = insert.on_conflict.expect("expected on_conflict");
    let target = conflict.target.expect("target");
    let ConflictTarget::Columns(cols) = target else {
        panic!("expected Columns");
    };
    assert_eq!(cols, vec!["id"]);
}
2421
/// `SELECT *` produces an all-columns projection and no filter.
#[test]
fn parse_select_all() {
    let Statement::Select(query) = parse_sql("SELECT * FROM users").unwrap() else {
        panic!("expected Select");
    };
    let QueryBody::Select(sel) = query.body else {
        panic!("expected QueryBody::Select");
    };
    assert_eq!(sel.from, "users");
    assert!(matches!(sel.columns[0], SelectColumn::AllColumns));
    assert!(sel.where_clause.is_none());
}
2437
/// An explicit column list plus `WHERE` is parsed into projection and filter.
#[test]
fn parse_select_where() {
    let sql = "SELECT id, name FROM users WHERE age > 18";
    let Statement::Select(query) = parse_sql(sql).unwrap() else {
        panic!("expected Select");
    };
    let QueryBody::Select(sel) = query.body else {
        panic!("expected QueryBody::Select");
    };
    assert_eq!(sel.columns.len(), 2);
    assert!(sel.where_clause.is_some());
}
2452
/// `ORDER BY ... ASC LIMIT ... OFFSET ...` populates ordering, limit and offset.
#[test]
fn parse_select_order_limit() {
    let sql = "SELECT * FROM users ORDER BY name ASC LIMIT 10 OFFSET 5";
    let Statement::Select(query) = parse_sql(sql).unwrap() else {
        panic!("expected Select");
    };
    let QueryBody::Select(sel) = query.body else {
        panic!("expected QueryBody::Select");
    };
    assert_eq!(sel.order_by.len(), 1);
    let first_key = &sel.order_by[0];
    assert!(!first_key.descending);
    assert!(sel.limit.is_some());
    assert!(sel.offset.is_some());
}
2469
/// `UPDATE ... SET ... WHERE ...` captures table, assignments and filter.
#[test]
fn parse_update() {
    let sql = "UPDATE users SET name = 'Bob' WHERE id = 1";
    let Statement::Update(update) = parse_sql(sql).unwrap() else {
        panic!("expected Update");
    };
    assert_eq!(update.table, "users");
    assert_eq!(update.assignments.len(), 1);
    assert_eq!(update.assignments[0].0, "name");
    assert!(update.where_clause.is_some());
}
2483
/// `DELETE FROM ... WHERE ...` captures the target table and filter.
#[test]
fn parse_delete() {
    let Statement::Delete(delete) = parse_sql("DELETE FROM users WHERE id = 1").unwrap() else {
        panic!("expected Delete");
    };
    assert_eq!(delete.table, "users");
    assert!(delete.where_clause.is_some());
}
2495
/// `COUNT(*)` is parsed into the dedicated `Expr::CountStar` node.
#[test]
fn parse_aggregate() {
    let sql = "SELECT COUNT(*), SUM(age) FROM users";
    let Statement::Select(query) = parse_sql(sql).unwrap() else {
        panic!("expected Select");
    };
    let QueryBody::Select(sel) = query.body else {
        panic!("expected QueryBody::Select");
    };
    assert_eq!(sel.columns.len(), 2);
    let first = &sel.columns[0];
    assert!(
        matches!(
            first,
            SelectColumn::Expr {
                expr: Expr::CountStar,
                ..
            }
        ),
        "expected CountStar, got {first:?}"
    );
}
2516
/// `GROUP BY` keys and a `HAVING` predicate are both captured.
#[test]
fn parse_group_by_having() {
    let sql =
        "SELECT department, COUNT(*) FROM employees GROUP BY department HAVING COUNT(*) > 5";
    let Statement::Select(query) = parse_sql(sql).unwrap() else {
        panic!("expected Select");
    };
    let QueryBody::Select(sel) = query.body else {
        panic!("expected QueryBody::Select");
    };
    assert_eq!(sel.group_by.len(), 1);
    assert!(sel.having.is_some());
}
2534
/// Binary `+`, unary minus and `NOT` map to the expected operator nodes.
#[test]
fn parse_expressions() {
    let sql = "SELECT id + 1, -price, NOT active FROM items";
    let Statement::Select(query) = parse_sql(sql).unwrap() else {
        panic!("expected Select");
    };
    let QueryBody::Select(sel) = query.body else {
        panic!("expected QueryBody::Select");
    };
    assert_eq!(sel.columns.len(), 3);

    // `id + 1`
    let add = &sel.columns[0];
    assert!(
        matches!(
            add,
            SelectColumn::Expr {
                expr: Expr::BinaryOp { op: BinOp::Add, .. },
                ..
            }
        ),
        "expected BinaryOp Add, got {add:?}"
    );

    // `-price`
    let neg = &sel.columns[1];
    assert!(
        matches!(
            neg,
            SelectColumn::Expr {
                expr: Expr::UnaryOp {
                    op: UnaryOp::Neg,
                    ..
                },
                ..
            }
        ),
        "expected UnaryOp Neg, got {neg:?}"
    );

    // `NOT active`
    let not = &sel.columns[2];
    assert!(
        matches!(
            not,
            SelectColumn::Expr {
                expr: Expr::UnaryOp {
                    op: UnaryOp::Not,
                    ..
                },
                ..
            }
        ),
        "expected UnaryOp Not, got {not:?}"
    );
}
2578
/// `x IS NULL` in a filter becomes an `Expr::IsNull` node.
#[test]
fn parse_is_null() {
    let Statement::Select(query) = parse_sql("SELECT * FROM t WHERE x IS NULL").unwrap() else {
        panic!("expected Select");
    };
    let QueryBody::Select(sel) = query.body else {
        panic!("expected QueryBody::Select");
    };
    assert!(matches!(sel.where_clause, Some(Expr::IsNull(_))));
}
2592
/// A bare `JOIN ... ON` is treated as an inner join with an ON predicate.
#[test]
fn parse_inner_join() {
    let sql = "SELECT * FROM a JOIN b ON a.id = b.id";
    let Statement::Select(query) = parse_sql(sql).unwrap() else {
        panic!("expected Select");
    };
    let QueryBody::Select(sel) = query.body else {
        panic!("expected QueryBody::Select");
    };
    assert_eq!(sel.from, "a");
    assert_eq!(sel.joins.len(), 1);
    let join = &sel.joins[0];
    assert_eq!(join.join_type, JoinType::Inner);
    assert_eq!(join.table.name, "b");
    assert!(join.on_clause.is_some());
}
2610
/// The explicit `INNER JOIN` spelling yields the same join type as bare `JOIN`.
#[test]
fn parse_inner_join_explicit() {
    let sql = "SELECT * FROM a INNER JOIN b ON a.id = b.a_id";
    let Statement::Select(query) = parse_sql(sql).unwrap() else {
        panic!("expected Select");
    };
    let QueryBody::Select(sel) = query.body else {
        panic!("expected QueryBody::Select");
    };
    assert_eq!(sel.joins.len(), 1);
    assert_eq!(sel.joins[0].join_type, JoinType::Inner);
}
2625
/// `CROSS JOIN` has no ON predicate.
#[test]
fn parse_cross_join() {
    let Statement::Select(query) = parse_sql("SELECT * FROM a CROSS JOIN b").unwrap() else {
        panic!("expected Select");
    };
    let QueryBody::Select(sel) = query.body else {
        panic!("expected QueryBody::Select");
    };
    assert_eq!(sel.joins.len(), 1);
    let join = &sel.joins[0];
    assert_eq!(join.join_type, JoinType::Cross);
    assert!(join.on_clause.is_none());
}
2641
/// `LEFT JOIN` maps to `JoinType::Left`.
#[test]
fn parse_left_join() {
    let sql = "SELECT * FROM a LEFT JOIN b ON a.id = b.a_id";
    let Statement::Select(query) = parse_sql(sql).unwrap() else {
        panic!("expected Select");
    };
    let QueryBody::Select(sel) = query.body else {
        panic!("expected QueryBody::Select");
    };
    assert_eq!(sel.joins.len(), 1);
    assert_eq!(sel.joins[0].join_type, JoinType::Left);
}
2656
/// Aliases on both the FROM table and a joined table are preserved.
#[test]
fn parse_table_alias() {
    let sql = "SELECT u.id FROM users u JOIN orders o ON u.id = o.user_id";
    let Statement::Select(query) = parse_sql(sql).unwrap() else {
        panic!("expected Select");
    };
    let QueryBody::Select(sel) = query.body else {
        panic!("expected QueryBody::Select");
    };
    assert_eq!(sel.from, "users");
    assert_eq!(sel.from_alias.as_deref(), Some("u"));
    let joined = &sel.joins[0].table;
    assert_eq!(joined.name, "orders");
    assert_eq!(joined.alias.as_deref(), Some("o"));
}
2673
/// Chained joins are collected in order into `joins`.
#[test]
fn parse_multi_join() {
    let sql = "SELECT * FROM a JOIN b ON a.id = b.a_id JOIN c ON b.id = c.b_id";
    let Statement::Select(query) = parse_sql(sql).unwrap() else {
        panic!("expected Select");
    };
    let QueryBody::Select(sel) = query.body else {
        panic!("expected QueryBody::Select");
    };
    assert_eq!(sel.joins.len(), 2);
}
2688
/// `alias.column` in the projection becomes `Expr::QualifiedColumn`.
#[test]
fn parse_qualified_column() {
    let sql = "SELECT u.id, u.name FROM users u";
    let Statement::Select(query) = parse_sql(sql).unwrap() else {
        panic!("expected Select");
    };
    let QueryBody::Select(sel) = query.body else {
        panic!("expected QueryBody::Select");
    };
    let first = &sel.columns[0];
    let SelectColumn::Expr {
        expr: Expr::QualifiedColumn { table, column },
        ..
    } = first
    else {
        panic!("expected QualifiedColumn, got {first:?}");
    };
    assert_eq!(table, "u");
    assert_eq!(column, "id");
}
2709
/// Derived tables (subqueries in FROM) are not supported.
#[test]
fn reject_subquery() {
    let outcome = parse_sql("SELECT * FROM (SELECT 1)");
    assert!(outcome.is_err());
}
2714
/// Each SQL type spelling in the CREATE TABLE maps onto the engine's `DataType`.
#[test]
fn parse_type_mapping() {
    let stmt = parse_sql(
        "CREATE TABLE t (a INT PRIMARY KEY, b BIGINT, c SMALLINT, d REAL, \
         e DOUBLE PRECISION, f VARCHAR(255), g BOOLEAN, h BLOB, i BYTEA)",
    )
    .unwrap();
    let Statement::CreateTable(ct) = stmt else {
        panic!("expected CreateTable");
    };
    // (column name, expected type). Integer widths collapse to Integer, floating
    // types to Real, VARCHAR to Text, and BLOB/BYTEA to Blob.
    let expected = [
        ("a", DataType::Integer),
        ("b", DataType::Integer),
        ("c", DataType::Integer),
        ("d", DataType::Real),
        ("e", DataType::Real),
        ("f", DataType::Text),
        ("g", DataType::Boolean),
        ("h", DataType::Blob),
        ("i", DataType::Blob),
    ];
    // Guard the count so a dropped or extra column cannot slip past the zip below.
    assert_eq!(ct.columns.len(), expected.len());
    for (column, (name, data_type)) in ct.columns.iter().zip(expected) {
        assert_eq!(column.name, name);
        assert_eq!(column.data_type, data_type, "wrong type for column {name}");
    }
}
2735
/// `true` / `false` in a VALUES row become boolean literals.
#[test]
fn parse_boolean_literals() {
    let sql = "INSERT INTO t (a, b) VALUES (true, false)";
    let Statement::Insert(ins) = parse_sql(sql).unwrap() else {
        panic!("expected Insert");
    };
    let InsertSource::Values(rows) = &ins.source else {
        panic!("expected Values");
    };
    let row = &rows[0];
    assert!(matches!(row[0], Expr::Literal(Value::Boolean(true))));
    assert!(matches!(row[1], Expr::Literal(Value::Boolean(false))));
}
2751
/// `NULL` in a VALUES row becomes `Value::Null`.
#[test]
fn parse_null_literal() {
    let Statement::Insert(ins) = parse_sql("INSERT INTO t (a) VALUES (NULL)").unwrap() else {
        panic!("expected Insert");
    };
    let InsertSource::Values(rows) = &ins.source else {
        panic!("expected Values");
    };
    assert!(matches!(rows[0][0], Expr::Literal(Value::Null)));
}
2766
/// `expr AS name` records the projection alias.
#[test]
fn parse_alias() {
    let sql = "SELECT id AS user_id FROM users";
    let Statement::Select(query) = parse_sql(sql).unwrap() else {
        panic!("expected Select");
    };
    let QueryBody::Select(sel) = query.body else {
        panic!("expected QueryBody::Select");
    };
    let first = &sel.columns[0];
    let SelectColumn::Expr {
        alias: Some(alias), ..
    } = first
    else {
        panic!("expected alias, got {first:?}");
    };
    assert_eq!(alias, "user_id");
}
2781
/// Bare `BEGIN` starts a transaction.
#[test]
fn parse_begin() {
    assert!(matches!(parse_sql("BEGIN").unwrap(), Statement::Begin));
}
2787
/// `BEGIN TRANSACTION` is accepted as a synonym for `BEGIN`.
#[test]
fn parse_begin_transaction() {
    assert!(matches!(parse_sql("BEGIN TRANSACTION").unwrap(), Statement::Begin));
}
2793
/// `COMMIT` maps to `Statement::Commit`.
#[test]
fn parse_commit() {
    assert!(matches!(parse_sql("COMMIT").unwrap(), Statement::Commit));
}
2799
/// `ROLLBACK` maps to `Statement::Rollback`.
#[test]
fn parse_rollback() {
    assert!(matches!(parse_sql("ROLLBACK").unwrap(), Statement::Rollback));
}
2805
/// `SAVEPOINT name` records the savepoint name.
#[test]
fn parse_savepoint() {
    let name = match parse_sql("SAVEPOINT sp1").unwrap() {
        Statement::Savepoint(name) => name,
        other => panic!("expected Savepoint, got {other:?}"),
    };
    assert_eq!(name, "sp1");
}
2814
/// Savepoint names are normalized to lowercase.
#[test]
fn parse_savepoint_case_insensitive() {
    let name = match parse_sql("SAVEPOINT My_SP").unwrap() {
        Statement::Savepoint(name) => name,
        other => panic!("expected Savepoint, got {other:?}"),
    };
    assert_eq!(name, "my_sp");
}
2823
/// `RELEASE SAVEPOINT name` yields a release of that savepoint.
#[test]
fn parse_release_savepoint() {
    let name = match parse_sql("RELEASE SAVEPOINT sp1").unwrap() {
        Statement::ReleaseSavepoint(name) => name,
        other => panic!("expected ReleaseSavepoint, got {other:?}"),
    };
    assert_eq!(name, "sp1");
}
2832
/// The `SAVEPOINT` keyword after `RELEASE` is optional.
#[test]
fn parse_release_without_savepoint_keyword() {
    let name = match parse_sql("RELEASE sp1").unwrap() {
        Statement::ReleaseSavepoint(name) => name,
        other => panic!("expected ReleaseSavepoint, got {other:?}"),
    };
    assert_eq!(name, "sp1");
}
2841
/// `ROLLBACK TO SAVEPOINT name` yields a rollback to that savepoint.
#[test]
fn parse_rollback_to_savepoint() {
    let name = match parse_sql("ROLLBACK TO SAVEPOINT sp1").unwrap() {
        Statement::RollbackTo(name) => name,
        other => panic!("expected RollbackTo, got {other:?}"),
    };
    assert_eq!(name, "sp1");
}
2850
/// The `SAVEPOINT` keyword after `ROLLBACK TO` is optional.
#[test]
fn parse_rollback_to_without_savepoint_keyword() {
    let name = match parse_sql("ROLLBACK TO sp1").unwrap() {
        Statement::RollbackTo(name) => name,
        other => panic!("expected RollbackTo, got {other:?}"),
    };
    assert_eq!(name, "sp1");
}
2859
/// `ROLLBACK TO` normalizes the savepoint name to lowercase.
#[test]
fn parse_rollback_to_case_insensitive() {
    let name = match parse_sql("ROLLBACK TO My_SP").unwrap() {
        Statement::RollbackTo(name) => name,
        other => panic!("expected RollbackTo, got {other:?}"),
    };
    assert_eq!(name, "my_sp");
}
2868
/// `COMMIT AND CHAIN` is reported as unsupported rather than silently accepted.
#[test]
fn parse_commit_and_chain_rejected() {
    assert!(matches!(
        parse_sql("COMMIT AND CHAIN").unwrap_err(),
        SqlError::Unsupported(_)
    ));
}
2874
/// `ROLLBACK AND CHAIN` is reported as unsupported rather than silently accepted.
#[test]
fn parse_rollback_and_chain_rejected() {
    assert!(matches!(
        parse_sql("ROLLBACK AND CHAIN").unwrap_err(),
        SqlError::Unsupported(_)
    ));
}
2880
/// `SELECT DISTINCT col` sets the distinct flag and keeps the projection.
#[test]
fn parse_select_distinct() {
    let Statement::Select(query) = parse_sql("SELECT DISTINCT name FROM users").unwrap() else {
        panic!("expected Select");
    };
    let QueryBody::Select(sel) = query.body else {
        panic!("expected QueryBody::Select");
    };
    assert!(sel.distinct);
    assert_eq!(sel.columns.len(), 1);
}
2895
/// Without `DISTINCT` the distinct flag stays false.
#[test]
fn parse_select_without_distinct() {
    let Statement::Select(query) = parse_sql("SELECT name FROM users").unwrap() else {
        panic!("expected Select");
    };
    let QueryBody::Select(sel) = query.body else {
        panic!("expected QueryBody::Select");
    };
    assert!(!sel.distinct);
}
2909
/// `SELECT DISTINCT *` combines the distinct flag with an all-columns projection.
#[test]
fn parse_select_distinct_all_columns() {
    let Statement::Select(query) = parse_sql("SELECT DISTINCT * FROM users").unwrap() else {
        panic!("expected Select");
    };
    let QueryBody::Select(sel) = query.body else {
        panic!("expected QueryBody::Select");
    };
    assert!(sel.distinct);
    assert!(matches!(sel.columns[0], SelectColumn::AllColumns));
}
2924
/// PostgreSQL's `DISTINCT ON (...)` is not supported.
#[test]
fn reject_distinct_on() {
    let outcome = parse_sql("SELECT DISTINCT ON (id) * FROM users");
    assert!(outcome.is_err());
}
2929
/// A plain `CREATE INDEX` is non-unique and has no IF NOT EXISTS guard.
#[test]
fn parse_create_index() {
    let sql = "CREATE INDEX idx_name ON users (name)";
    let Statement::CreateIndex(index) = parse_sql(sql).unwrap() else {
        panic!("expected CreateIndex");
    };
    assert_eq!(index.index_name, "idx_name");
    assert_eq!(index.table_name, "users");
    assert_eq!(index.columns, vec!["name"]);
    assert!(!index.unique);
    assert!(!index.if_not_exists);
}
2944
/// `CREATE UNIQUE INDEX` sets the unique flag.
#[test]
fn parse_create_unique_index() {
    let sql = "CREATE UNIQUE INDEX idx_email ON users (email)";
    let Statement::CreateIndex(index) = parse_sql(sql).unwrap() else {
        panic!("expected CreateIndex");
    };
    assert!(index.unique);
    assert_eq!(index.columns, vec!["email"]);
}
2956
/// `CREATE INDEX IF NOT EXISTS` sets the if_not_exists flag.
#[test]
fn parse_create_index_if_not_exists() {
    let sql = "CREATE INDEX IF NOT EXISTS idx_x ON t (a)";
    let Statement::CreateIndex(index) = parse_sql(sql).unwrap() else {
        panic!("expected CreateIndex");
    };
    assert!(index.if_not_exists);
}
2965
/// Composite index columns are kept in declaration order.
#[test]
fn parse_create_index_multi_column() {
    let sql = "CREATE INDEX idx_multi ON t (a, b, c)";
    let Statement::CreateIndex(index) = parse_sql(sql).unwrap() else {
        panic!("expected CreateIndex");
    };
    assert_eq!(index.columns, vec!["a", "b", "c"]);
}
2976
/// `DROP INDEX name` records the index name with no IF EXISTS guard.
#[test]
fn parse_drop_index() {
    let Statement::DropIndex(drop) = parse_sql("DROP INDEX idx_name").unwrap() else {
        panic!("expected DropIndex");
    };
    assert_eq!(drop.index_name, "idx_name");
    assert!(!drop.if_exists);
}
2988
/// `DROP INDEX IF EXISTS` sets the if_exists flag.
#[test]
fn parse_drop_index_if_exists() {
    let Statement::DropIndex(drop) = parse_sql("DROP INDEX IF EXISTS idx_name").unwrap() else {
        panic!("expected DropIndex");
    };
    assert!(drop.if_exists);
}
2999
/// `EXPLAIN SELECT ...` wraps the inner SELECT statement.
#[test]
fn parse_explain_select() {
    let sql = "EXPLAIN SELECT * FROM users WHERE id = 1";
    let Statement::Explain(inner) = parse_sql(sql).unwrap() else {
        panic!("expected Explain");
    };
    assert!(matches!(*inner, Statement::Select(_)));
}
3010
/// `EXPLAIN` also accepts an INSERT statement.
#[test]
fn parse_explain_insert() {
    assert!(matches!(
        parse_sql("EXPLAIN INSERT INTO t (a) VALUES (1)").unwrap(),
        Statement::Explain(_)
    ));
}
3016
/// `EXPLAIN ANALYZE` is not supported.
#[test]
fn reject_explain_analyze() {
    let outcome = parse_sql("EXPLAIN ANALYZE SELECT * FROM t");
    assert!(outcome.is_err());
}
3021
/// A `$1` placeholder on the right of a comparison becomes `Expr::Parameter(1)`.
#[test]
fn parse_parameter_placeholder() {
    let Statement::Select(query) = parse_sql("SELECT * FROM t WHERE id = $1").unwrap() else {
        panic!("expected Select");
    };
    let QueryBody::Select(sel) = query.body else {
        panic!("expected QueryBody::Select");
    };
    let filter = &sel.where_clause;
    let Some(Expr::BinaryOp { right, .. }) = filter else {
        panic!("expected BinaryOp with Parameter, got {filter:?}");
    };
    assert!(matches!(right.as_ref(), Expr::Parameter(1)));
}
3038
/// Multiple `$N` placeholders in a VALUES row keep their numeric indices.
#[test]
fn parse_multiple_parameters() {
    let sql = "INSERT INTO t (a, b) VALUES ($1, $2)";
    let Statement::Insert(ins) = parse_sql(sql).unwrap() else {
        panic!("expected Insert");
    };
    let InsertSource::Values(rows) = &ins.source else {
        panic!("expected Values");
    };
    let row = &rows[0];
    assert!(matches!(row[0], Expr::Parameter(1)));
    assert!(matches!(row[1], Expr::Parameter(2)));
}
3054
/// `INSERT INTO ... (cols) SELECT ...` keeps the target columns and the source query.
#[test]
fn parse_insert_select() {
    let sql = "INSERT INTO t2 (id, name) SELECT id, name FROM t1 WHERE id > 5";
    let Statement::Insert(ins) = parse_sql(sql).unwrap() else {
        panic!("expected Insert");
    };
    assert_eq!(ins.table, "t2");
    assert_eq!(ins.columns, vec!["id", "name"]);
    let InsertSource::Select(query) = &ins.source else {
        panic!("expected InsertSource::Select");
    };
    let QueryBody::Select(sel) = &query.body else {
        panic!("expected QueryBody::Select");
    };
    assert_eq!(sel.from, "t1");
    assert_eq!(sel.columns.len(), 2);
    assert!(sel.where_clause.is_some());
}
3078
/// `INSERT INTO t SELECT ...` without a column list leaves `columns` empty.
#[test]
fn parse_insert_select_no_columns() {
    let Statement::Insert(ins) = parse_sql("INSERT INTO t2 SELECT * FROM t1").unwrap() else {
        panic!("expected Insert");
    };
    assert_eq!(ins.table, "t2");
    assert!(ins.columns.is_empty());
    assert!(matches!(&ins.source, InsertSource::Select(_)));
}
3091
/// Placeholders are 1-based, so `$0` is rejected.
#[test]
fn reject_zero_parameter() {
    let outcome = parse_sql("SELECT $0 FROM t");
    assert!(outcome.is_err());
}
3096
/// `count_params` reports the highest placeholder index (`$3` here), not the
/// number of placeholders present.
#[test]
fn count_params_basic() {
    let parsed = parse_sql("SELECT * FROM t WHERE a = $1 AND b = $3").unwrap();
    let highest = count_params(&parsed);
    assert_eq!(highest, 3);
}
3102
/// A statement without placeholders needs zero parameters.
#[test]
fn count_params_none() {
    let parsed = parse_sql("SELECT * FROM t WHERE a = 1").unwrap();
    let highest = count_params(&parsed);
    assert_eq!(highest, 0);
}
3108
/// A table-level `PRIMARY KEY (a)` marks column `a` as primary key and NOT NULL.
#[test]
fn parse_table_constraint_pk() {
    let sql = "CREATE TABLE t (a INTEGER, b TEXT, PRIMARY KEY (a))";
    let Statement::CreateTable(ct) = parse_sql(sql).unwrap() else {
        panic!("expected CreateTable");
    };
    assert_eq!(ct.primary_key, vec!["a"]);
    let pk_column = &ct.columns[0];
    assert!(pk_column.is_primary_key);
    assert!(!pk_column.nullable);
}
3121}