1use sqlparser::ast as sp;
4use sqlparser::dialect::GenericDialect;
5use sqlparser::parser::Parser;
6
7use crate::error::{Result, SqlError};
8use crate::types::{DataType, Value};
9
/// A single SQL statement in this crate's simplified AST.
///
/// Produced by [`parse_sql`] from the `sqlparser` AST.
#[derive(Debug, Clone)]
pub enum Statement {
    /// `CREATE TABLE`.
    CreateTable(CreateTableStmt),
    /// `DROP TABLE`.
    DropTable(DropTableStmt),
    /// `CREATE [UNIQUE] INDEX`.
    CreateIndex(CreateIndexStmt),
    /// `DROP INDEX`.
    DropIndex(DropIndexStmt),
    /// `INSERT ... VALUES`.
    Insert(InsertStmt),
    /// A `SELECT` query (boxed, presumably to keep this enum small).
    Select(Box<SelectStmt>),
    /// `UPDATE`.
    Update(UpdateStmt),
    /// `DELETE`.
    Delete(DeleteStmt),
    /// `BEGIN` / `START TRANSACTION`.
    Begin,
    /// `COMMIT`.
    Commit,
    /// `ROLLBACK`.
    Rollback,
    /// `EXPLAIN <statement>` (`EXPLAIN ANALYZE` is rejected by the converter).
    Explain(Box<Statement>),
}
27
/// A `CREATE TABLE` statement.
#[derive(Debug, Clone)]
pub struct CreateTableStmt {
    /// Table name as rendered by `object_name_to_string`.
    pub name: String,
    /// Column definitions in declaration order.
    pub columns: Vec<ColumnSpec>,
    /// Primary-key column names: inline `PRIMARY KEY` columns first, followed by any
    /// additional columns named in a table-level `PRIMARY KEY (...)` constraint.
    pub primary_key: Vec<String>,
    /// `true` when `IF NOT EXISTS` was given.
    pub if_not_exists: bool,
}
35
/// One column definition inside a [`CreateTableStmt`].
#[derive(Debug, Clone)]
pub struct ColumnSpec {
    /// Column name as written.
    pub name: String,
    /// Declared type, converted by `convert_data_type`.
    pub data_type: DataType,
    /// Whether NULL is allowed. Starts as `true`; `NOT NULL` clears it, and
    /// `convert_create_table` also clears it for primary-key columns.
    pub nullable: bool,
    /// Whether the column is part of the primary key (inline or table-level).
    pub is_primary_key: bool,
}
43
/// A single-table `DROP TABLE` statement.
#[derive(Debug, Clone)]
pub struct DropTableStmt {
    /// Table name as rendered by `object_name_to_string`.
    pub name: String,
    /// `true` when `IF EXISTS` was given.
    pub if_exists: bool,
}
49
/// A `CREATE [UNIQUE] INDEX` statement.
#[derive(Debug, Clone)]
pub struct CreateIndexStmt {
    /// Index name (required by the converter).
    pub index_name: String,
    /// Indexed table name.
    pub table_name: String,
    /// Indexed column names in key order; never empty (expression keys are rejected).
    pub columns: Vec<String>,
    /// `true` for `CREATE UNIQUE INDEX`.
    pub unique: bool,
    /// `true` when `IF NOT EXISTS` was given.
    pub if_not_exists: bool,
}
58
/// A single-index `DROP INDEX` statement.
#[derive(Debug, Clone)]
pub struct DropIndexStmt {
    /// Index name as rendered by `object_name_to_string`.
    pub index_name: String,
    /// `true` when `IF EXISTS` was given.
    pub if_exists: bool,
}
64
/// An `INSERT INTO ... VALUES` statement.
#[derive(Debug, Clone)]
pub struct InsertStmt {
    /// Target table name.
    pub table: String,
    /// Explicit column list; empty when the statement did not name columns.
    pub columns: Vec<String>,
    /// One expression list per `VALUES` row, in source order.
    pub values: Vec<Vec<Expr>>,
}
71
/// A table referenced by a `JOIN`, with its optional alias.
#[derive(Debug, Clone)]
pub struct TableRef {
    /// Table name as rendered by `object_name_to_string`.
    pub name: String,
    /// Alias from `AS alias` (or a bare alias), if any.
    pub alias: Option<String>,
}
77
/// The kind of a [`JoinClause`].
///
/// Derives `Eq` alongside `PartialEq`, consistent with [`BinOp`] and [`UnaryOp`]: the enum
/// is fieldless, so equality is total and it can be used wherever `Eq` is required.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum JoinType {
    /// Inner join (`JOIN` / `INNER JOIN`).
    Inner,
    /// Cartesian product (`CROSS JOIN`).
    Cross,
    /// Left outer join.
    Left,
    /// Right outer join.
    Right,
}
85
/// One `JOIN` attached to the `FROM` table of a [`SelectStmt`].
#[derive(Debug, Clone)]
pub struct JoinClause {
    /// Inner / cross / left / right.
    pub join_type: JoinType,
    /// The joined table.
    pub table: TableRef,
    /// The `ON` condition; `None` when the join had no constraint.
    pub on_clause: Option<Expr>,
}
92
/// A single-block `SELECT` query.
#[derive(Debug, Clone)]
pub struct SelectStmt {
    /// Projection list in source order.
    pub columns: Vec<SelectColumn>,
    /// Base table name; the empty string when the query has no `FROM` clause.
    pub from: String,
    /// Alias of the base table, if any.
    pub from_alias: Option<String>,
    /// Explicit joins on the base table, in source order.
    pub joins: Vec<JoinClause>,
    /// `true` for `SELECT DISTINCT` (`DISTINCT ON` is rejected by the converter).
    pub distinct: bool,
    /// `WHERE` predicate.
    pub where_clause: Option<Expr>,
    /// `ORDER BY` items in source order.
    pub order_by: Vec<OrderByItem>,
    /// `LIMIT` expression.
    pub limit: Option<Expr>,
    /// `OFFSET` expression.
    pub offset: Option<Expr>,
    /// `GROUP BY` expressions; empty when there is no `GROUP BY`.
    pub group_by: Vec<Expr>,
    /// `HAVING` predicate.
    pub having: Option<Expr>,
}
107
/// A single-table `UPDATE` statement.
#[derive(Debug, Clone)]
pub struct UpdateStmt {
    /// Target table name.
    pub table: String,
    /// `(column, new value)` pairs from the `SET` list, in source order.
    pub assignments: Vec<(String, Expr)>,
    /// `WHERE` predicate; `None` updates every row.
    pub where_clause: Option<Expr>,
}
114
/// A single-table `DELETE` statement.
#[derive(Debug, Clone)]
pub struct DeleteStmt {
    /// Target table name.
    pub table: String,
    /// `WHERE` predicate; `None` deletes every row.
    pub where_clause: Option<Expr>,
}
120
/// One item of a `SELECT` projection list.
#[derive(Debug, Clone)]
pub enum SelectColumn {
    /// The `*` wildcard (see `convert_select_item` for exactly which wildcard forms map here).
    AllColumns,
    /// An expression with an optional `AS alias`.
    Expr { expr: Expr, alias: Option<String> },
}
126
/// One `ORDER BY` key.
#[derive(Debug, Clone)]
pub struct OrderByItem {
    /// Sort key expression.
    pub expr: Expr,
    /// `true` for `DESC`.
    pub descending: bool,
    /// Presumably `Some(true)` for `NULLS FIRST`, `Some(false)` for `NULLS LAST`, `None` when
    /// unspecified — set by `convert_order_by_expr`; TODO confirm.
    pub nulls_first: Option<bool>,
}
133
/// A scalar expression in this crate's AST.
#[derive(Debug, Clone)]
pub enum Expr {
    /// A constant value.
    Literal(Value),
    /// An unqualified column reference, e.g. `name`.
    Column(String),
    /// A qualified column reference `qualifier.column`; `table` is the qualifier exactly as
    /// written (a table name or an alias — it is not resolved here).
    QualifiedColumn {
        table: String,
        column: String,
    },
    /// A binary operation such as `a + b` or `a AND b`.
    BinaryOp {
        left: Box<Expr>,
        op: BinOp,
        right: Box<Expr>,
    },
    /// A prefix operation: `-x` or `NOT x`.
    UnaryOp {
        op: UnaryOp,
        expr: Box<Expr>,
    },
    /// `expr IS NULL`.
    IsNull(Box<Expr>),
    /// `expr IS NOT NULL`.
    IsNotNull(Box<Expr>),
    /// A function call. SQL special forms such as `SUBSTRING`, `TRIM` and `POSITION` are
    /// lowered by `convert_expr` into calls like `SUBSTR`, `LTRIM`/`RTRIM`/`TRIM`, `INSTR`.
    Function {
        name: String,
        args: Vec<Expr>,
    },
    /// `COUNT(*)` — presumably produced by `convert_function` (not shown here).
    CountStar,
    /// `expr [NOT] IN (SELECT ...)`.
    InSubquery {
        expr: Box<Expr>,
        subquery: Box<SelectStmt>,
        negated: bool,
    },
    /// `expr [NOT] IN (a, b, ...)`.
    InList {
        expr: Box<Expr>,
        list: Vec<Expr>,
        negated: bool,
    },
    /// `[NOT] EXISTS (SELECT ...)`.
    Exists {
        subquery: Box<SelectStmt>,
        negated: bool,
    },
    /// A parenthesised `SELECT` used as a scalar value.
    ScalarSubquery(Box<SelectStmt>),
    /// Membership test against a precomputed set of constant values. Not produced by the
    /// converters in this file; presumably built from `InList` by a later rewrite — TODO confirm.
    InSet {
        expr: Box<Expr>,
        values: std::collections::HashSet<Value>,
        // Presumably records whether the original list contained NULL (needed for SQL's
        // three-valued IN result) — TODO confirm.
        has_null: bool,
        negated: bool,
    },
    /// `expr [NOT] BETWEEN low AND high`.
    Between {
        expr: Box<Expr>,
        low: Box<Expr>,
        high: Box<Expr>,
        negated: bool,
    },
    /// `expr [NOT] LIKE pattern [ESCAPE escape]`; `ILIKE` is converted to this variant too.
    Like {
        expr: Box<Expr>,
        pattern: Box<Expr>,
        escape: Option<Box<Expr>>,
        negated: bool,
    },
    /// `CASE [operand] WHEN ... THEN ... [ELSE ...] END`; `conditions` holds the
    /// `(WHEN, THEN)` pairs in source order.
    Case {
        operand: Option<Box<Expr>>,
        conditions: Vec<(Expr, Expr)>,
        else_result: Option<Box<Expr>>,
    },
    /// `COALESCE(a, b, ...)` — presumably produced by `convert_function` (not shown here).
    Coalesce(Vec<Expr>),
    /// `CAST(expr AS data_type)`.
    Cast {
        expr: Box<Expr>,
        data_type: DataType,
    },
    /// A 1-based positional placeholder (`$1` → `Parameter(1)`), replaced by a literal in
    /// `bind_params`.
    Parameter(usize),
}
203
/// Binary operators supported in [`Expr::BinaryOp`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BinOp {
    /// `+`
    Add,
    /// `-`
    Sub,
    /// `*`
    Mul,
    /// `/`
    Div,
    /// `%`
    Mod,
    /// `=`
    Eq,
    /// `<>` / `!=`
    NotEq,
    /// `<`
    Lt,
    /// `>`
    Gt,
    /// `<=`
    LtEq,
    /// `>=`
    GtEq,
    /// `AND`
    And,
    /// `OR`
    Or,
    /// String concatenation (presumably `||`; mapped in `convert_bin_op`).
    Concat,
}
221
/// Prefix operators supported in [`Expr::UnaryOp`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum UnaryOp {
    /// Arithmetic negation (`-x`).
    Neg,
    /// Logical negation (`NOT x`).
    Not,
}
227
228pub fn parse_sql(sql: &str) -> Result<Statement> {
231 let dialect = GenericDialect {};
232 let stmts = Parser::parse_sql(&dialect, sql).map_err(|e| SqlError::Parse(e.to_string()))?;
233
234 if stmts.is_empty() {
235 return Err(SqlError::Parse("empty SQL".into()));
236 }
237 if stmts.len() > 1 {
238 return Err(SqlError::Unsupported("multiple statements".into()));
239 }
240
241 convert_statement(stmts.into_iter().next().unwrap())
242}
243
244pub fn count_params(stmt: &Statement) -> usize {
248 let mut max_idx = 0usize;
249 visit_exprs_stmt(stmt, &mut |e| {
250 if let Expr::Parameter(n) = e {
251 max_idx = max_idx.max(*n);
252 }
253 });
254 max_idx
255}
256
/// Returns a copy of `stmt` with every placeholder replaced by the literal at the matching
/// position of `params` (placeholder `N` takes `params[N - 1]`).
///
/// Statements that carry no expressions (DDL, transaction control) are returned unchanged,
/// and values beyond the highest placeholder index are ignored.
///
/// # Errors
/// [`SqlError::ParameterCountMismatch`] if any placeholder index exceeds `params.len()`.
pub fn bind_params(
    stmt: &Statement,
    params: &[crate::types::Value],
) -> crate::error::Result<Statement> {
    bind_stmt(stmt, params)
}
264
265fn bind_stmt(stmt: &Statement, params: &[crate::types::Value]) -> crate::error::Result<Statement> {
266 match stmt {
267 Statement::Select(sel) => Ok(Statement::Select(Box::new(bind_select(sel, params)?))),
268 Statement::Insert(ins) => {
269 let values = ins
270 .values
271 .iter()
272 .map(|row| {
273 row.iter()
274 .map(|e| bind_expr(e, params))
275 .collect::<crate::error::Result<Vec<_>>>()
276 })
277 .collect::<crate::error::Result<Vec<_>>>()?;
278 Ok(Statement::Insert(InsertStmt {
279 table: ins.table.clone(),
280 columns: ins.columns.clone(),
281 values,
282 }))
283 }
284 Statement::Update(upd) => {
285 let assignments = upd
286 .assignments
287 .iter()
288 .map(|(col, e)| Ok((col.clone(), bind_expr(e, params)?)))
289 .collect::<crate::error::Result<Vec<_>>>()?;
290 let where_clause = upd
291 .where_clause
292 .as_ref()
293 .map(|e| bind_expr(e, params))
294 .transpose()?;
295 Ok(Statement::Update(UpdateStmt {
296 table: upd.table.clone(),
297 assignments,
298 where_clause,
299 }))
300 }
301 Statement::Delete(del) => {
302 let where_clause = del
303 .where_clause
304 .as_ref()
305 .map(|e| bind_expr(e, params))
306 .transpose()?;
307 Ok(Statement::Delete(DeleteStmt {
308 table: del.table.clone(),
309 where_clause,
310 }))
311 }
312 Statement::Explain(inner) => Ok(Statement::Explain(Box::new(bind_stmt(inner, params)?))),
313 other => Ok(other.clone()),
314 }
315}
316
317fn bind_select(
318 sel: &SelectStmt,
319 params: &[crate::types::Value],
320) -> crate::error::Result<SelectStmt> {
321 let columns = sel
322 .columns
323 .iter()
324 .map(|c| match c {
325 SelectColumn::AllColumns => Ok(SelectColumn::AllColumns),
326 SelectColumn::Expr { expr, alias } => Ok(SelectColumn::Expr {
327 expr: bind_expr(expr, params)?,
328 alias: alias.clone(),
329 }),
330 })
331 .collect::<crate::error::Result<Vec<_>>>()?;
332 let joins = sel
333 .joins
334 .iter()
335 .map(|j| {
336 let on_clause = j
337 .on_clause
338 .as_ref()
339 .map(|e| bind_expr(e, params))
340 .transpose()?;
341 Ok(JoinClause {
342 join_type: j.join_type,
343 table: j.table.clone(),
344 on_clause,
345 })
346 })
347 .collect::<crate::error::Result<Vec<_>>>()?;
348 let where_clause = sel
349 .where_clause
350 .as_ref()
351 .map(|e| bind_expr(e, params))
352 .transpose()?;
353 let order_by = sel
354 .order_by
355 .iter()
356 .map(|o| {
357 Ok(OrderByItem {
358 expr: bind_expr(&o.expr, params)?,
359 descending: o.descending,
360 nulls_first: o.nulls_first,
361 })
362 })
363 .collect::<crate::error::Result<Vec<_>>>()?;
364 let limit = sel
365 .limit
366 .as_ref()
367 .map(|e| bind_expr(e, params))
368 .transpose()?;
369 let offset = sel
370 .offset
371 .as_ref()
372 .map(|e| bind_expr(e, params))
373 .transpose()?;
374 let group_by = sel
375 .group_by
376 .iter()
377 .map(|e| bind_expr(e, params))
378 .collect::<crate::error::Result<Vec<_>>>()?;
379 let having = sel
380 .having
381 .as_ref()
382 .map(|e| bind_expr(e, params))
383 .transpose()?;
384
385 Ok(SelectStmt {
386 columns,
387 from: sel.from.clone(),
388 from_alias: sel.from_alias.clone(),
389 joins,
390 distinct: sel.distinct,
391 where_clause,
392 order_by,
393 limit,
394 offset,
395 group_by,
396 having,
397 })
398}
399
400fn bind_expr(expr: &Expr, params: &[crate::types::Value]) -> crate::error::Result<Expr> {
401 match expr {
402 Expr::Parameter(n) => {
403 if *n == 0 || *n > params.len() {
404 return Err(SqlError::ParameterCountMismatch {
405 expected: *n,
406 got: params.len(),
407 });
408 }
409 Ok(Expr::Literal(params[*n - 1].clone()))
410 }
411 Expr::Literal(_) | Expr::Column(_) | Expr::QualifiedColumn { .. } | Expr::CountStar => {
412 Ok(expr.clone())
413 }
414 Expr::BinaryOp { left, op, right } => Ok(Expr::BinaryOp {
415 left: Box::new(bind_expr(left, params)?),
416 op: *op,
417 right: Box::new(bind_expr(right, params)?),
418 }),
419 Expr::UnaryOp { op, expr: e } => Ok(Expr::UnaryOp {
420 op: *op,
421 expr: Box::new(bind_expr(e, params)?),
422 }),
423 Expr::IsNull(e) => Ok(Expr::IsNull(Box::new(bind_expr(e, params)?))),
424 Expr::IsNotNull(e) => Ok(Expr::IsNotNull(Box::new(bind_expr(e, params)?))),
425 Expr::Function { name, args } => {
426 let args = args
427 .iter()
428 .map(|a| bind_expr(a, params))
429 .collect::<crate::error::Result<Vec<_>>>()?;
430 Ok(Expr::Function {
431 name: name.clone(),
432 args,
433 })
434 }
435 Expr::InSubquery {
436 expr: e,
437 subquery,
438 negated,
439 } => Ok(Expr::InSubquery {
440 expr: Box::new(bind_expr(e, params)?),
441 subquery: Box::new(bind_select(subquery, params)?),
442 negated: *negated,
443 }),
444 Expr::InList {
445 expr: e,
446 list,
447 negated,
448 } => {
449 let list = list
450 .iter()
451 .map(|l| bind_expr(l, params))
452 .collect::<crate::error::Result<Vec<_>>>()?;
453 Ok(Expr::InList {
454 expr: Box::new(bind_expr(e, params)?),
455 list,
456 negated: *negated,
457 })
458 }
459 Expr::Exists { subquery, negated } => Ok(Expr::Exists {
460 subquery: Box::new(bind_select(subquery, params)?),
461 negated: *negated,
462 }),
463 Expr::ScalarSubquery(sq) => Ok(Expr::ScalarSubquery(Box::new(bind_select(sq, params)?))),
464 Expr::InSet {
465 expr: e,
466 values,
467 has_null,
468 negated,
469 } => Ok(Expr::InSet {
470 expr: Box::new(bind_expr(e, params)?),
471 values: values.clone(),
472 has_null: *has_null,
473 negated: *negated,
474 }),
475 Expr::Between {
476 expr: e,
477 low,
478 high,
479 negated,
480 } => Ok(Expr::Between {
481 expr: Box::new(bind_expr(e, params)?),
482 low: Box::new(bind_expr(low, params)?),
483 high: Box::new(bind_expr(high, params)?),
484 negated: *negated,
485 }),
486 Expr::Like {
487 expr: e,
488 pattern,
489 escape,
490 negated,
491 } => Ok(Expr::Like {
492 expr: Box::new(bind_expr(e, params)?),
493 pattern: Box::new(bind_expr(pattern, params)?),
494 escape: escape
495 .as_ref()
496 .map(|esc| bind_expr(esc, params).map(Box::new))
497 .transpose()?,
498 negated: *negated,
499 }),
500 Expr::Case {
501 operand,
502 conditions,
503 else_result,
504 } => {
505 let operand = operand
506 .as_ref()
507 .map(|e| bind_expr(e, params).map(Box::new))
508 .transpose()?;
509 let conditions = conditions
510 .iter()
511 .map(|(cond, then)| Ok((bind_expr(cond, params)?, bind_expr(then, params)?)))
512 .collect::<crate::error::Result<Vec<_>>>()?;
513 let else_result = else_result
514 .as_ref()
515 .map(|e| bind_expr(e, params).map(Box::new))
516 .transpose()?;
517 Ok(Expr::Case {
518 operand,
519 conditions,
520 else_result,
521 })
522 }
523 Expr::Coalesce(args) => {
524 let args = args
525 .iter()
526 .map(|a| bind_expr(a, params))
527 .collect::<crate::error::Result<Vec<_>>>()?;
528 Ok(Expr::Coalesce(args))
529 }
530 Expr::Cast { expr: e, data_type } => Ok(Expr::Cast {
531 expr: Box::new(bind_expr(e, params)?),
532 data_type: *data_type,
533 }),
534 }
535}
536
537fn visit_exprs_stmt(stmt: &Statement, visitor: &mut impl FnMut(&Expr)) {
538 match stmt {
539 Statement::Select(sel) => visit_exprs_select(sel, visitor),
540 Statement::Insert(ins) => {
541 for row in &ins.values {
542 for e in row {
543 visit_expr(e, visitor);
544 }
545 }
546 }
547 Statement::Update(upd) => {
548 for (_, e) in &upd.assignments {
549 visit_expr(e, visitor);
550 }
551 if let Some(w) = &upd.where_clause {
552 visit_expr(w, visitor);
553 }
554 }
555 Statement::Delete(del) => {
556 if let Some(w) = &del.where_clause {
557 visit_expr(w, visitor);
558 }
559 }
560 Statement::Explain(inner) => visit_exprs_stmt(inner, visitor),
561 _ => {}
562 }
563}
564
565fn visit_exprs_select(sel: &SelectStmt, visitor: &mut impl FnMut(&Expr)) {
566 for col in &sel.columns {
567 if let SelectColumn::Expr { expr, .. } = col {
568 visit_expr(expr, visitor);
569 }
570 }
571 for j in &sel.joins {
572 if let Some(on) = &j.on_clause {
573 visit_expr(on, visitor);
574 }
575 }
576 if let Some(w) = &sel.where_clause {
577 visit_expr(w, visitor);
578 }
579 for o in &sel.order_by {
580 visit_expr(&o.expr, visitor);
581 }
582 if let Some(l) = &sel.limit {
583 visit_expr(l, visitor);
584 }
585 if let Some(o) = &sel.offset {
586 visit_expr(o, visitor);
587 }
588 for g in &sel.group_by {
589 visit_expr(g, visitor);
590 }
591 if let Some(h) = &sel.having {
592 visit_expr(h, visitor);
593 }
594}
595
596fn visit_expr(expr: &Expr, visitor: &mut impl FnMut(&Expr)) {
597 visitor(expr);
598 match expr {
599 Expr::BinaryOp { left, right, .. } => {
600 visit_expr(left, visitor);
601 visit_expr(right, visitor);
602 }
603 Expr::UnaryOp { expr: e, .. } | Expr::IsNull(e) | Expr::IsNotNull(e) => {
604 visit_expr(e, visitor);
605 }
606 Expr::Function { args, .. } | Expr::Coalesce(args) => {
607 for a in args {
608 visit_expr(a, visitor);
609 }
610 }
611 Expr::InSubquery {
612 expr: e, subquery, ..
613 } => {
614 visit_expr(e, visitor);
615 visit_exprs_select(subquery, visitor);
616 }
617 Expr::InList { expr: e, list, .. } => {
618 visit_expr(e, visitor);
619 for l in list {
620 visit_expr(l, visitor);
621 }
622 }
623 Expr::Exists { subquery, .. } => visit_exprs_select(subquery, visitor),
624 Expr::ScalarSubquery(sq) => visit_exprs_select(sq, visitor),
625 Expr::InSet { expr: e, .. } => visit_expr(e, visitor),
626 Expr::Between {
627 expr: e, low, high, ..
628 } => {
629 visit_expr(e, visitor);
630 visit_expr(low, visitor);
631 visit_expr(high, visitor);
632 }
633 Expr::Like {
634 expr: e,
635 pattern,
636 escape,
637 ..
638 } => {
639 visit_expr(e, visitor);
640 visit_expr(pattern, visitor);
641 if let Some(esc) = escape {
642 visit_expr(esc, visitor);
643 }
644 }
645 Expr::Case {
646 operand,
647 conditions,
648 else_result,
649 } => {
650 if let Some(op) = operand {
651 visit_expr(op, visitor);
652 }
653 for (cond, then) in conditions {
654 visit_expr(cond, visitor);
655 visit_expr(then, visitor);
656 }
657 if let Some(el) = else_result {
658 visit_expr(el, visitor);
659 }
660 }
661 Expr::Cast { expr: e, .. } => visit_expr(e, visitor),
662 Expr::Literal(_)
663 | Expr::Column(_)
664 | Expr::QualifiedColumn { .. }
665 | Expr::CountStar
666 | Expr::Parameter(_) => {}
667 }
668}
669
670fn convert_statement(stmt: sp::Statement) -> Result<Statement> {
673 match stmt {
674 sp::Statement::CreateTable(ct) => convert_create_table(ct),
675 sp::Statement::CreateIndex(ci) => convert_create_index(ci),
676 sp::Statement::Drop {
677 object_type: sp::ObjectType::Table,
678 if_exists,
679 names,
680 ..
681 } => {
682 if names.len() != 1 {
683 return Err(SqlError::Unsupported("multi-table DROP".into()));
684 }
685 Ok(Statement::DropTable(DropTableStmt {
686 name: object_name_to_string(&names[0]),
687 if_exists,
688 }))
689 }
690 sp::Statement::Drop {
691 object_type: sp::ObjectType::Index,
692 if_exists,
693 names,
694 ..
695 } => {
696 if names.len() != 1 {
697 return Err(SqlError::Unsupported("multi-index DROP".into()));
698 }
699 Ok(Statement::DropIndex(DropIndexStmt {
700 index_name: object_name_to_string(&names[0]),
701 if_exists,
702 }))
703 }
704 sp::Statement::Insert(insert) => convert_insert(insert),
705 sp::Statement::Query(query) => convert_query(*query),
706 sp::Statement::Update(update) => convert_update(update),
707 sp::Statement::Delete(delete) => convert_delete(delete),
708 sp::Statement::StartTransaction { .. } => Ok(Statement::Begin),
709 sp::Statement::Commit { .. } => Ok(Statement::Commit),
710 sp::Statement::Rollback { .. } => Ok(Statement::Rollback),
711 sp::Statement::Explain {
712 statement, analyze, ..
713 } => {
714 if analyze {
715 return Err(SqlError::Unsupported("EXPLAIN ANALYZE".into()));
716 }
717 let inner = convert_statement(*statement)?;
718 Ok(Statement::Explain(Box::new(inner)))
719 }
720 _ => Err(SqlError::Unsupported(format!("statement type: {}", stmt))),
721 }
722}
723
724fn convert_create_table(ct: sp::CreateTable) -> Result<Statement> {
725 let name = object_name_to_string(&ct.name);
726 let if_not_exists = ct.if_not_exists;
727
728 let mut columns = Vec::new();
729 let mut inline_pk: Vec<String> = Vec::new();
730
731 for col_def in &ct.columns {
732 let col_name = col_def.name.value.clone();
733 let data_type = convert_data_type(&col_def.data_type)?;
734 let mut nullable = true;
735 let mut is_primary_key = false;
736
737 for opt in &col_def.options {
738 match &opt.option {
739 sp::ColumnOption::NotNull => nullable = false,
740 sp::ColumnOption::Null => nullable = true,
741 sp::ColumnOption::PrimaryKey(_) => {
742 is_primary_key = true;
743 nullable = false;
744 inline_pk.push(col_name.clone());
745 }
746 _ => {}
747 }
748 }
749
750 columns.push(ColumnSpec {
751 name: col_name,
752 data_type,
753 nullable,
754 is_primary_key,
755 });
756 }
757
758 for constraint in &ct.constraints {
760 if let sp::TableConstraint::PrimaryKey(pk_constraint) = constraint {
761 for idx_col in &pk_constraint.columns {
762 let col_name = match &idx_col.column.expr {
764 sp::Expr::Identifier(ident) => ident.value.clone(),
765 _ => continue,
766 };
767 if !inline_pk.contains(&col_name) {
768 inline_pk.push(col_name.clone());
769 }
770 if let Some(col) = columns.iter_mut().find(|c| c.name == col_name) {
771 col.nullable = false;
772 col.is_primary_key = true;
773 }
774 }
775 }
776 }
777
778 Ok(Statement::CreateTable(CreateTableStmt {
779 name,
780 columns,
781 primary_key: inline_pk,
782 if_not_exists,
783 }))
784}
785
786fn convert_create_index(ci: sp::CreateIndex) -> Result<Statement> {
787 let index_name = ci
788 .name
789 .as_ref()
790 .map(object_name_to_string)
791 .ok_or_else(|| SqlError::Parse("index name required".into()))?;
792
793 let table_name = object_name_to_string(&ci.table_name);
794
795 let columns: Vec<String> = ci
796 .columns
797 .iter()
798 .map(|idx_col| match &idx_col.column.expr {
799 sp::Expr::Identifier(ident) => Ok(ident.value.clone()),
800 other => Err(SqlError::Unsupported(format!("expression index: {other}"))),
801 })
802 .collect::<Result<_>>()?;
803
804 if columns.is_empty() {
805 return Err(SqlError::Parse(
806 "index must have at least one column".into(),
807 ));
808 }
809
810 Ok(Statement::CreateIndex(CreateIndexStmt {
811 index_name,
812 table_name,
813 columns,
814 unique: ci.unique,
815 if_not_exists: ci.if_not_exists,
816 }))
817}
818
819fn convert_insert(insert: sp::Insert) -> Result<Statement> {
820 let table = match &insert.table {
821 sp::TableObject::TableName(name) => object_name_to_string(name),
822 _ => return Err(SqlError::Unsupported("INSERT into non-table object".into())),
823 };
824
825 let columns: Vec<String> = insert.columns.iter().map(|c| c.value.clone()).collect();
826
827 let source = insert
828 .source
829 .ok_or_else(|| SqlError::Parse("INSERT requires VALUES".into()))?;
830
831 let values = match *source.body {
832 sp::SetExpr::Values(sp::Values { rows, .. }) => {
833 let mut result = Vec::new();
834 for row in rows {
835 let mut exprs = Vec::new();
836 for expr in row {
837 exprs.push(convert_expr(&expr)?);
838 }
839 result.push(exprs);
840 }
841 result
842 }
843 _ => return Err(SqlError::Unsupported("INSERT ... SELECT".into())),
844 };
845
846 Ok(Statement::Insert(InsertStmt {
847 table,
848 columns,
849 values,
850 }))
851}
852
853fn convert_subquery(query: &sp::Query) -> Result<SelectStmt> {
854 match convert_query(query.clone())? {
855 Statement::Select(s) => Ok(*s),
856 _ => Err(SqlError::Unsupported("non-SELECT subquery".into())),
857 }
858}
859
860fn convert_query(query: sp::Query) -> Result<Statement> {
861 let select = match *query.body {
862 sp::SetExpr::Select(sel) => *sel,
863 _ => return Err(SqlError::Unsupported("UNION/INTERSECT/EXCEPT".into())),
864 };
865
866 let distinct = match &select.distinct {
867 Some(sp::Distinct::Distinct) => true,
868 Some(sp::Distinct::On(_)) => {
869 return Err(SqlError::Unsupported("DISTINCT ON".into()));
870 }
871 _ => false,
872 };
873
874 let (from, from_alias, joins) = if select.from.is_empty() {
876 (String::new(), None, vec![])
877 } else if select.from.len() == 1 {
878 let table_with_joins = &select.from[0];
879 let (name, alias) = match &table_with_joins.relation {
880 sp::TableFactor::Table { name, alias, .. } => {
881 let table_name = object_name_to_string(name);
882 let alias_str = alias.as_ref().map(|a| a.name.value.clone());
883 (table_name, alias_str)
884 }
885 _ => return Err(SqlError::Unsupported("non-table FROM source".into())),
886 };
887 let j = table_with_joins
888 .joins
889 .iter()
890 .map(convert_join)
891 .collect::<Result<Vec<_>>>()?;
892 (name, alias, j)
893 } else {
894 return Err(SqlError::Unsupported("comma-separated FROM tables".into()));
895 };
896
897 let columns: Vec<SelectColumn> = select
899 .projection
900 .iter()
901 .map(convert_select_item)
902 .collect::<Result<_>>()?;
903
904 let where_clause = select.selection.as_ref().map(convert_expr).transpose()?;
906
907 let order_by = if let Some(ref ob) = query.order_by {
909 match &ob.kind {
910 sp::OrderByKind::Expressions(exprs) => exprs
911 .iter()
912 .map(convert_order_by_expr)
913 .collect::<Result<_>>()?,
914 sp::OrderByKind::All { .. } => {
915 return Err(SqlError::Unsupported("ORDER BY ALL".into()));
916 }
917 }
918 } else {
919 vec![]
920 };
921
922 let (limit, offset) = match &query.limit_clause {
924 Some(sp::LimitClause::LimitOffset { limit, offset, .. }) => {
925 let l = limit.as_ref().map(convert_expr).transpose()?;
926 let o = offset
927 .as_ref()
928 .map(|o| convert_expr(&o.value))
929 .transpose()?;
930 (l, o)
931 }
932 Some(sp::LimitClause::OffsetCommaLimit { limit, offset }) => {
933 let l = Some(convert_expr(limit)?);
934 let o = Some(convert_expr(offset)?);
935 (l, o)
936 }
937 None => (None, None),
938 };
939
940 let group_by = match &select.group_by {
942 sp::GroupByExpr::Expressions(exprs, _) => {
943 exprs.iter().map(convert_expr).collect::<Result<_>>()?
944 }
945 sp::GroupByExpr::All(_) => {
946 return Err(SqlError::Unsupported("GROUP BY ALL".into()));
947 }
948 };
949
950 let having = select.having.as_ref().map(convert_expr).transpose()?;
952
953 Ok(Statement::Select(Box::new(SelectStmt {
954 columns,
955 from,
956 from_alias,
957 joins,
958 distinct,
959 where_clause,
960 order_by,
961 limit,
962 offset,
963 group_by,
964 having,
965 })))
966}
967
968fn convert_join(join: &sp::Join) -> Result<JoinClause> {
969 let (join_type, constraint) = match &join.join_operator {
970 sp::JoinOperator::Inner(c) => (JoinType::Inner, Some(c)),
971 sp::JoinOperator::Join(c) => (JoinType::Inner, Some(c)),
972 sp::JoinOperator::CrossJoin(c) => (JoinType::Cross, Some(c)),
973 sp::JoinOperator::Left(c) => (JoinType::Left, Some(c)),
974 sp::JoinOperator::LeftSemi(c) => (JoinType::Left, Some(c)),
975 sp::JoinOperator::LeftAnti(c) => (JoinType::Left, Some(c)),
976 sp::JoinOperator::Right(c) => (JoinType::Right, Some(c)),
977 sp::JoinOperator::RightSemi(c) => (JoinType::Right, Some(c)),
978 sp::JoinOperator::RightAnti(c) => (JoinType::Right, Some(c)),
979 other => return Err(SqlError::Unsupported(format!("join type: {other:?}"))),
980 };
981
982 let (name, alias) = match &join.relation {
983 sp::TableFactor::Table { name, alias, .. } => {
984 let table_name = object_name_to_string(name);
985 let alias_str = alias.as_ref().map(|a| a.name.value.clone());
986 (table_name, alias_str)
987 }
988 _ => return Err(SqlError::Unsupported("non-table JOIN source".into())),
989 };
990
991 let on_clause = match constraint {
992 Some(sp::JoinConstraint::On(expr)) => Some(convert_expr(expr)?),
993 Some(sp::JoinConstraint::None) | None => None,
994 Some(other) => return Err(SqlError::Unsupported(format!("join constraint: {other:?}"))),
995 };
996
997 Ok(JoinClause {
998 join_type,
999 table: TableRef { name, alias },
1000 on_clause,
1001 })
1002}
1003
1004fn convert_update(update: sp::Update) -> Result<Statement> {
1005 let table = match &update.table.relation {
1006 sp::TableFactor::Table { name, .. } => object_name_to_string(name),
1007 _ => return Err(SqlError::Unsupported("non-table UPDATE target".into())),
1008 };
1009
1010 let assignments = update
1011 .assignments
1012 .iter()
1013 .map(|a| {
1014 let col = match &a.target {
1015 sp::AssignmentTarget::ColumnName(name) => object_name_to_string(name),
1016 _ => return Err(SqlError::Unsupported("tuple assignment".into())),
1017 };
1018 let expr = convert_expr(&a.value)?;
1019 Ok((col, expr))
1020 })
1021 .collect::<Result<_>>()?;
1022
1023 let where_clause = update.selection.as_ref().map(convert_expr).transpose()?;
1024
1025 Ok(Statement::Update(UpdateStmt {
1026 table,
1027 assignments,
1028 where_clause,
1029 }))
1030}
1031
1032fn convert_delete(delete: sp::Delete) -> Result<Statement> {
1033 let table_name = match &delete.from {
1034 sp::FromTable::WithFromKeyword(tables) => {
1035 if tables.len() != 1 {
1036 return Err(SqlError::Unsupported("multi-table DELETE".into()));
1037 }
1038 match &tables[0].relation {
1039 sp::TableFactor::Table { name, .. } => object_name_to_string(name),
1040 _ => return Err(SqlError::Unsupported("non-table DELETE target".into())),
1041 }
1042 }
1043 sp::FromTable::WithoutKeyword(tables) => {
1044 if tables.len() != 1 {
1045 return Err(SqlError::Unsupported("multi-table DELETE".into()));
1046 }
1047 match &tables[0].relation {
1048 sp::TableFactor::Table { name, .. } => object_name_to_string(name),
1049 _ => return Err(SqlError::Unsupported("non-table DELETE target".into())),
1050 }
1051 }
1052 };
1053
1054 let where_clause = delete.selection.as_ref().map(convert_expr).transpose()?;
1055
1056 Ok(Statement::Delete(DeleteStmt {
1057 table: table_name,
1058 where_clause,
1059 }))
1060}
1061
1062fn convert_expr(expr: &sp::Expr) -> Result<Expr> {
1065 match expr {
1066 sp::Expr::Value(v) => convert_value(&v.value),
1067 sp::Expr::Identifier(ident) => Ok(Expr::Column(ident.value.clone())),
1068 sp::Expr::CompoundIdentifier(parts) => {
1069 if parts.len() == 2 {
1070 Ok(Expr::QualifiedColumn {
1071 table: parts[0].value.clone(),
1072 column: parts[1].value.clone(),
1073 })
1074 } else {
1075 Ok(Expr::Column(parts.last().unwrap().value.clone()))
1076 }
1077 }
1078 sp::Expr::BinaryOp { left, op, right } => {
1079 let bin_op = convert_bin_op(op)?;
1080 Ok(Expr::BinaryOp {
1081 left: Box::new(convert_expr(left)?),
1082 op: bin_op,
1083 right: Box::new(convert_expr(right)?),
1084 })
1085 }
1086 sp::Expr::UnaryOp { op, expr } => {
1087 let unary_op = match op {
1088 sp::UnaryOperator::Minus => UnaryOp::Neg,
1089 sp::UnaryOperator::Not => UnaryOp::Not,
1090 _ => return Err(SqlError::Unsupported(format!("unary op: {op}"))),
1091 };
1092 Ok(Expr::UnaryOp {
1093 op: unary_op,
1094 expr: Box::new(convert_expr(expr)?),
1095 })
1096 }
1097 sp::Expr::IsNull(e) => Ok(Expr::IsNull(Box::new(convert_expr(e)?))),
1098 sp::Expr::IsNotNull(e) => Ok(Expr::IsNotNull(Box::new(convert_expr(e)?))),
1099 sp::Expr::Nested(e) => convert_expr(e),
1100 sp::Expr::Function(func) => convert_function(func),
1101 sp::Expr::InSubquery {
1102 expr: e,
1103 subquery,
1104 negated,
1105 } => {
1106 let inner_expr = convert_expr(e)?;
1107 let stmt = convert_subquery(subquery)?;
1108 Ok(Expr::InSubquery {
1109 expr: Box::new(inner_expr),
1110 subquery: Box::new(stmt),
1111 negated: *negated,
1112 })
1113 }
1114 sp::Expr::InList {
1115 expr: e,
1116 list,
1117 negated,
1118 } => {
1119 let inner_expr = convert_expr(e)?;
1120 let items = list.iter().map(convert_expr).collect::<Result<Vec<_>>>()?;
1121 Ok(Expr::InList {
1122 expr: Box::new(inner_expr),
1123 list: items,
1124 negated: *negated,
1125 })
1126 }
1127 sp::Expr::Exists { subquery, negated } => {
1128 let stmt = convert_subquery(subquery)?;
1129 Ok(Expr::Exists {
1130 subquery: Box::new(stmt),
1131 negated: *negated,
1132 })
1133 }
1134 sp::Expr::Subquery(query) => {
1135 let stmt = convert_subquery(query)?;
1136 Ok(Expr::ScalarSubquery(Box::new(stmt)))
1137 }
1138 sp::Expr::Between {
1139 expr: e,
1140 negated,
1141 low,
1142 high,
1143 } => Ok(Expr::Between {
1144 expr: Box::new(convert_expr(e)?),
1145 low: Box::new(convert_expr(low)?),
1146 high: Box::new(convert_expr(high)?),
1147 negated: *negated,
1148 }),
1149 sp::Expr::Like {
1150 expr: e,
1151 negated,
1152 pattern,
1153 escape_char,
1154 ..
1155 } => {
1156 let esc = escape_char
1157 .as_ref()
1158 .map(convert_escape_value)
1159 .transpose()?
1160 .map(Box::new);
1161 Ok(Expr::Like {
1162 expr: Box::new(convert_expr(e)?),
1163 pattern: Box::new(convert_expr(pattern)?),
1164 escape: esc,
1165 negated: *negated,
1166 })
1167 }
1168 sp::Expr::ILike {
1169 expr: e,
1170 negated,
1171 pattern,
1172 escape_char,
1173 ..
1174 } => {
1175 let esc = escape_char
1176 .as_ref()
1177 .map(convert_escape_value)
1178 .transpose()?
1179 .map(Box::new);
1180 Ok(Expr::Like {
1181 expr: Box::new(convert_expr(e)?),
1182 pattern: Box::new(convert_expr(pattern)?),
1183 escape: esc,
1184 negated: *negated,
1185 })
1186 }
1187 sp::Expr::Case {
1188 operand,
1189 conditions,
1190 else_result,
1191 ..
1192 } => {
1193 let op = operand
1194 .as_ref()
1195 .map(|e| convert_expr(e))
1196 .transpose()?
1197 .map(Box::new);
1198 let conds: Vec<(Expr, Expr)> = conditions
1199 .iter()
1200 .map(|cw| Ok((convert_expr(&cw.condition)?, convert_expr(&cw.result)?)))
1201 .collect::<Result<_>>()?;
1202 let else_r = else_result
1203 .as_ref()
1204 .map(|e| convert_expr(e))
1205 .transpose()?
1206 .map(Box::new);
1207 Ok(Expr::Case {
1208 operand: op,
1209 conditions: conds,
1210 else_result: else_r,
1211 })
1212 }
1213 sp::Expr::Cast {
1214 expr: e,
1215 data_type: dt,
1216 ..
1217 } => {
1218 let target = convert_data_type(dt)?;
1219 Ok(Expr::Cast {
1220 expr: Box::new(convert_expr(e)?),
1221 data_type: target,
1222 })
1223 }
1224 sp::Expr::Substring {
1225 expr: e,
1226 substring_from,
1227 substring_for,
1228 ..
1229 } => {
1230 let mut args = vec![convert_expr(e)?];
1231 if let Some(from) = substring_from {
1232 args.push(convert_expr(from)?);
1233 }
1234 if let Some(f) = substring_for {
1235 args.push(convert_expr(f)?);
1236 }
1237 Ok(Expr::Function {
1238 name: "SUBSTR".into(),
1239 args,
1240 })
1241 }
1242 sp::Expr::Trim {
1243 expr: e,
1244 trim_where,
1245 trim_what,
1246 trim_characters,
1247 } => {
1248 let fn_name = match trim_where {
1249 Some(sp::TrimWhereField::Leading) => "LTRIM",
1250 Some(sp::TrimWhereField::Trailing) => "RTRIM",
1251 _ => "TRIM",
1252 };
1253 let mut args = vec![convert_expr(e)?];
1254 if let Some(what) = trim_what {
1255 args.push(convert_expr(what)?);
1256 } else if let Some(chars) = trim_characters {
1257 if let Some(first) = chars.first() {
1258 args.push(convert_expr(first)?);
1259 }
1260 }
1261 Ok(Expr::Function {
1262 name: fn_name.into(),
1263 args,
1264 })
1265 }
1266 sp::Expr::Ceil { expr: e, .. } => Ok(Expr::Function {
1267 name: "CEIL".into(),
1268 args: vec![convert_expr(e)?],
1269 }),
1270 sp::Expr::Floor { expr: e, .. } => Ok(Expr::Function {
1271 name: "FLOOR".into(),
1272 args: vec![convert_expr(e)?],
1273 }),
1274 sp::Expr::Position { expr: e, r#in } => Ok(Expr::Function {
1275 name: "INSTR".into(),
1276 args: vec![convert_expr(r#in)?, convert_expr(e)?],
1277 }),
1278 _ => Err(SqlError::Unsupported(format!("expression: {expr}"))),
1279 }
1280}
1281
1282fn convert_value(val: &sp::Value) -> Result<Expr> {
1283 match val {
1284 sp::Value::Number(n, _) => {
1285 if let Ok(i) = n.parse::<i64>() {
1286 Ok(Expr::Literal(Value::Integer(i)))
1287 } else if let Ok(f) = n.parse::<f64>() {
1288 Ok(Expr::Literal(Value::Real(f)))
1289 } else {
1290 Err(SqlError::InvalidValue(format!("cannot parse number: {n}")))
1291 }
1292 }
1293 sp::Value::SingleQuotedString(s) => Ok(Expr::Literal(Value::Text(s.clone()))),
1294 sp::Value::Boolean(b) => Ok(Expr::Literal(Value::Boolean(*b))),
1295 sp::Value::Null => Ok(Expr::Literal(Value::Null)),
1296 sp::Value::Placeholder(s) => {
1297 let idx_str = s
1298 .strip_prefix('$')
1299 .ok_or_else(|| SqlError::Parse(format!("invalid placeholder: {s}")))?;
1300 let idx: usize = idx_str
1301 .parse()
1302 .map_err(|_| SqlError::Parse(format!("invalid placeholder index: {s}")))?;
1303 if idx == 0 {
1304 return Err(SqlError::Parse("placeholder index must be >= 1".into()));
1305 }
1306 Ok(Expr::Parameter(idx))
1307 }
1308 _ => Err(SqlError::Unsupported(format!("value type: {val}"))),
1309 }
1310}
1311
1312fn convert_escape_value(val: &sp::Value) -> Result<Expr> {
1313 match val {
1314 sp::Value::SingleQuotedString(s) => Ok(Expr::Literal(Value::Text(s.clone()))),
1315 _ => Err(SqlError::Unsupported(format!("ESCAPE value: {val}"))),
1316 }
1317}
1318
1319fn convert_bin_op(op: &sp::BinaryOperator) -> Result<BinOp> {
1320 match op {
1321 sp::BinaryOperator::Plus => Ok(BinOp::Add),
1322 sp::BinaryOperator::Minus => Ok(BinOp::Sub),
1323 sp::BinaryOperator::Multiply => Ok(BinOp::Mul),
1324 sp::BinaryOperator::Divide => Ok(BinOp::Div),
1325 sp::BinaryOperator::Modulo => Ok(BinOp::Mod),
1326 sp::BinaryOperator::Eq => Ok(BinOp::Eq),
1327 sp::BinaryOperator::NotEq => Ok(BinOp::NotEq),
1328 sp::BinaryOperator::Lt => Ok(BinOp::Lt),
1329 sp::BinaryOperator::Gt => Ok(BinOp::Gt),
1330 sp::BinaryOperator::LtEq => Ok(BinOp::LtEq),
1331 sp::BinaryOperator::GtEq => Ok(BinOp::GtEq),
1332 sp::BinaryOperator::And => Ok(BinOp::And),
1333 sp::BinaryOperator::Or => Ok(BinOp::Or),
1334 sp::BinaryOperator::StringConcat => Ok(BinOp::Concat),
1335 _ => Err(SqlError::Unsupported(format!("binary op: {op}"))),
1336 }
1337}
1338
1339fn convert_function(func: &sp::Function) -> Result<Expr> {
1340 let name = object_name_to_string(&func.name).to_ascii_uppercase();
1341
1342 match &func.args {
1344 sp::FunctionArguments::List(list) => {
1345 if list.args.is_empty() && name == "COUNT" {
1346 return Ok(Expr::CountStar);
1347 }
1348 let args = list
1349 .args
1350 .iter()
1351 .map(|arg| match arg {
1352 sp::FunctionArg::Unnamed(sp::FunctionArgExpr::Expr(e)) => convert_expr(e),
1353 sp::FunctionArg::Unnamed(sp::FunctionArgExpr::Wildcard) => {
1354 if name == "COUNT" {
1355 Ok(Expr::CountStar)
1356 } else {
1357 Err(SqlError::Unsupported(format!("{name}(*)")))
1358 }
1359 }
1360 _ => Err(SqlError::Unsupported(format!(
1361 "function arg type in {name}"
1362 ))),
1363 })
1364 .collect::<Result<Vec<_>>>()?;
1365
1366 if name == "COUNT" && args.len() == 1 && matches!(args[0], Expr::CountStar) {
1367 return Ok(Expr::CountStar);
1368 }
1369
1370 if name == "COALESCE" {
1371 if args.is_empty() {
1372 return Err(SqlError::Parse(
1373 "COALESCE requires at least one argument".into(),
1374 ));
1375 }
1376 return Ok(Expr::Coalesce(args));
1377 }
1378
1379 if name == "NULLIF" {
1380 if args.len() != 2 {
1381 return Err(SqlError::Parse(
1382 "NULLIF requires exactly two arguments".into(),
1383 ));
1384 }
1385 return Ok(Expr::Case {
1386 operand: None,
1387 conditions: vec![(
1388 Expr::BinaryOp {
1389 left: Box::new(args[0].clone()),
1390 op: BinOp::Eq,
1391 right: Box::new(args[1].clone()),
1392 },
1393 Expr::Literal(Value::Null),
1394 )],
1395 else_result: Some(Box::new(args[0].clone())),
1396 });
1397 }
1398
1399 if name == "IIF" {
1400 if args.len() != 3 {
1401 return Err(SqlError::Parse(
1402 "IIF requires exactly three arguments".into(),
1403 ));
1404 }
1405 return Ok(Expr::Case {
1406 operand: None,
1407 conditions: vec![(args[0].clone(), args[1].clone())],
1408 else_result: Some(Box::new(args[2].clone())),
1409 });
1410 }
1411
1412 Ok(Expr::Function { name, args })
1413 }
1414 sp::FunctionArguments::None => {
1415 if name == "COUNT" {
1416 Ok(Expr::CountStar)
1417 } else {
1418 Ok(Expr::Function { name, args: vec![] })
1419 }
1420 }
1421 sp::FunctionArguments::Subquery(_) => {
1422 Err(SqlError::Unsupported("subquery in function".into()))
1423 }
1424 }
1425}
1426
1427fn convert_select_item(item: &sp::SelectItem) -> Result<SelectColumn> {
1428 match item {
1429 sp::SelectItem::Wildcard(_) => Ok(SelectColumn::AllColumns),
1430 sp::SelectItem::UnnamedExpr(e) => {
1431 let expr = convert_expr(e)?;
1432 Ok(SelectColumn::Expr { expr, alias: None })
1433 }
1434 sp::SelectItem::ExprWithAlias { expr, alias } => {
1435 let expr = convert_expr(expr)?;
1436 Ok(SelectColumn::Expr {
1437 expr,
1438 alias: Some(alias.value.clone()),
1439 })
1440 }
1441 sp::SelectItem::QualifiedWildcard(_, _) => {
1442 Err(SqlError::Unsupported("qualified wildcard (table.*)".into()))
1443 }
1444 }
1445}
1446
1447fn convert_order_by_expr(expr: &sp::OrderByExpr) -> Result<OrderByItem> {
1448 let e = convert_expr(&expr.expr)?;
1449 let descending = expr.options.asc.map(|asc| !asc).unwrap_or(false);
1450 let nulls_first = expr.options.nulls_first;
1451
1452 Ok(OrderByItem {
1453 expr: e,
1454 descending,
1455 nulls_first,
1456 })
1457}
1458
1459fn convert_data_type(dt: &sp::DataType) -> Result<DataType> {
1462 match dt {
1463 sp::DataType::Int(_)
1464 | sp::DataType::Integer(_)
1465 | sp::DataType::BigInt(_)
1466 | sp::DataType::SmallInt(_)
1467 | sp::DataType::TinyInt(_)
1468 | sp::DataType::Int2(_)
1469 | sp::DataType::Int4(_)
1470 | sp::DataType::Int8(_) => Ok(DataType::Integer),
1471
1472 sp::DataType::Real
1473 | sp::DataType::Double(..)
1474 | sp::DataType::DoublePrecision
1475 | sp::DataType::Float(_)
1476 | sp::DataType::Float4
1477 | sp::DataType::Float64 => Ok(DataType::Real),
1478
1479 sp::DataType::Varchar(_)
1480 | sp::DataType::Text
1481 | sp::DataType::Char(_)
1482 | sp::DataType::Character(_)
1483 | sp::DataType::String(_) => Ok(DataType::Text),
1484
1485 sp::DataType::Blob(_) | sp::DataType::Bytea => Ok(DataType::Blob),
1486
1487 sp::DataType::Boolean | sp::DataType::Bool => Ok(DataType::Boolean),
1488
1489 _ => Err(SqlError::Unsupported(format!("data type: {dt}"))),
1490 }
1491}
1492
1493fn object_name_to_string(name: &sp::ObjectName) -> String {
1496 name.0
1497 .iter()
1498 .filter_map(|p| match p {
1499 sp::ObjectNamePart::Identifier(ident) => Some(ident.value.clone()),
1500 _ => None,
1501 })
1502 .collect::<Vec<_>>()
1503 .join(".")
1504}
1505
#[cfg(test)]
mod tests {
    // Unit tests for turning SQL text into this module's `Statement` AST.
    // Each test parses one statement with `parse_sql` and inspects the result.
    use super::*;

    #[test]
    fn parse_create_table() {
        let stmt = parse_sql(
            "CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT NOT NULL, age INTEGER)",
        )
        .unwrap();

        match stmt {
            Statement::CreateTable(ct) => {
                assert_eq!(ct.name, "users");
                assert_eq!(ct.columns.len(), 3);
                assert_eq!(ct.columns[0].name, "id");
                assert_eq!(ct.columns[0].data_type, DataType::Integer);
                assert!(ct.columns[0].is_primary_key);
                assert!(!ct.columns[0].nullable);
                assert_eq!(ct.columns[1].name, "name");
                assert_eq!(ct.columns[1].data_type, DataType::Text);
                assert!(!ct.columns[1].nullable);
                assert_eq!(ct.columns[2].name, "age");
                assert!(ct.columns[2].nullable);
                assert_eq!(ct.primary_key, vec!["id"]);
            }
            _ => panic!("expected CreateTable"),
        }
    }

    #[test]
    fn parse_create_table_if_not_exists() {
        let stmt = parse_sql("CREATE TABLE IF NOT EXISTS t (id INT PRIMARY KEY)").unwrap();
        match stmt {
            Statement::CreateTable(ct) => assert!(ct.if_not_exists),
            _ => panic!("expected CreateTable"),
        }
    }

    #[test]
    fn parse_drop_table() {
        let stmt = parse_sql("DROP TABLE users").unwrap();
        match stmt {
            Statement::DropTable(dt) => {
                assert_eq!(dt.name, "users");
                assert!(!dt.if_exists);
            }
            _ => panic!("expected DropTable"),
        }
    }

    #[test]
    fn parse_drop_table_if_exists() {
        let stmt = parse_sql("DROP TABLE IF EXISTS users").unwrap();
        match stmt {
            Statement::DropTable(dt) => assert!(dt.if_exists),
            _ => panic!("expected DropTable"),
        }
    }

    #[test]
    fn parse_insert() {
        let stmt =
            parse_sql("INSERT INTO users (id, name) VALUES (1, 'Alice'), (2, 'Bob')").unwrap();

        match stmt {
            Statement::Insert(ins) => {
                assert_eq!(ins.table, "users");
                assert_eq!(ins.columns, vec!["id", "name"]);
                assert_eq!(ins.values.len(), 2);
                assert!(matches!(ins.values[0][0], Expr::Literal(Value::Integer(1))));
                assert!(matches!(&ins.values[0][1], Expr::Literal(Value::Text(s)) if s == "Alice"));
            }
            _ => panic!("expected Insert"),
        }
    }

    #[test]
    fn parse_select_all() {
        let stmt = parse_sql("SELECT * FROM users").unwrap();
        match stmt {
            Statement::Select(sel) => {
                assert_eq!(sel.from, "users");
                assert!(matches!(sel.columns[0], SelectColumn::AllColumns));
                assert!(sel.where_clause.is_none());
            }
            _ => panic!("expected Select"),
        }
    }

    #[test]
    fn parse_select_where() {
        let stmt = parse_sql("SELECT id, name FROM users WHERE age > 18").unwrap();
        match stmt {
            Statement::Select(sel) => {
                assert_eq!(sel.columns.len(), 2);
                assert!(sel.where_clause.is_some());
            }
            _ => panic!("expected Select"),
        }
    }

    #[test]
    fn parse_select_order_limit() {
        let stmt = parse_sql("SELECT * FROM users ORDER BY name ASC LIMIT 10 OFFSET 5").unwrap();
        match stmt {
            Statement::Select(sel) => {
                assert_eq!(sel.order_by.len(), 1);
                assert!(!sel.order_by[0].descending);
                assert!(sel.limit.is_some());
                assert!(sel.offset.is_some());
            }
            _ => panic!("expected Select"),
        }
    }

    #[test]
    fn parse_order_by_desc() {
        let stmt = parse_sql("SELECT * FROM t ORDER BY a DESC").unwrap();
        match stmt {
            Statement::Select(sel) => {
                assert_eq!(sel.order_by.len(), 1);
                assert!(sel.order_by[0].descending);
            }
            _ => panic!("expected Select"),
        }
    }

    #[test]
    fn parse_update() {
        let stmt = parse_sql("UPDATE users SET name = 'Bob' WHERE id = 1").unwrap();
        match stmt {
            Statement::Update(upd) => {
                assert_eq!(upd.table, "users");
                assert_eq!(upd.assignments.len(), 1);
                assert_eq!(upd.assignments[0].0, "name");
                assert!(upd.where_clause.is_some());
            }
            _ => panic!("expected Update"),
        }
    }

    #[test]
    fn parse_delete() {
        let stmt = parse_sql("DELETE FROM users WHERE id = 1").unwrap();
        match stmt {
            Statement::Delete(del) => {
                assert_eq!(del.table, "users");
                assert!(del.where_clause.is_some());
            }
            _ => panic!("expected Delete"),
        }
    }

    #[test]
    fn parse_aggregate() {
        let stmt = parse_sql("SELECT COUNT(*), SUM(age) FROM users").unwrap();
        match stmt {
            Statement::Select(sel) => {
                assert_eq!(sel.columns.len(), 2);
                match &sel.columns[0] {
                    SelectColumn::Expr {
                        expr: Expr::CountStar,
                        ..
                    } => {}
                    other => panic!("expected CountStar, got {other:?}"),
                }
            }
            _ => panic!("expected Select"),
        }
    }

    #[test]
    fn parse_group_by_having() {
        let stmt = parse_sql(
            "SELECT department, COUNT(*) FROM employees GROUP BY department HAVING COUNT(*) > 5",
        )
        .unwrap();
        match stmt {
            Statement::Select(sel) => {
                assert_eq!(sel.group_by.len(), 1);
                assert!(sel.having.is_some());
            }
            _ => panic!("expected Select"),
        }
    }

    #[test]
    fn parse_expressions() {
        let stmt = parse_sql("SELECT id + 1, -price, NOT active FROM items").unwrap();
        match stmt {
            Statement::Select(sel) => {
                assert_eq!(sel.columns.len(), 3);
                match &sel.columns[0] {
                    SelectColumn::Expr {
                        expr: Expr::BinaryOp { op: BinOp::Add, .. },
                        ..
                    } => {}
                    other => panic!("expected BinaryOp Add, got {other:?}"),
                }
                match &sel.columns[1] {
                    SelectColumn::Expr {
                        expr:
                            Expr::UnaryOp {
                                op: UnaryOp::Neg, ..
                            },
                        ..
                    } => {}
                    other => panic!("expected UnaryOp Neg, got {other:?}"),
                }
                match &sel.columns[2] {
                    SelectColumn::Expr {
                        expr:
                            Expr::UnaryOp {
                                op: UnaryOp::Not, ..
                            },
                        ..
                    } => {}
                    other => panic!("expected UnaryOp Not, got {other:?}"),
                }
            }
            _ => panic!("expected Select"),
        }
    }

    #[test]
    fn parse_is_null() {
        let stmt = parse_sql("SELECT * FROM t WHERE x IS NULL").unwrap();
        match stmt {
            Statement::Select(sel) => {
                assert!(matches!(sel.where_clause, Some(Expr::IsNull(_))));
            }
            _ => panic!("expected Select"),
        }
    }

    #[test]
    fn parse_cast_target_type() {
        let stmt = parse_sql("SELECT CAST(a AS TEXT) FROM t").unwrap();
        match stmt {
            Statement::Select(sel) => match &sel.columns[0] {
                SelectColumn::Expr {
                    expr: Expr::Cast { data_type, .. },
                    ..
                } => assert_eq!(*data_type, DataType::Text),
                other => panic!("expected Cast, got {other:?}"),
            },
            _ => panic!("expected Select"),
        }
    }

    #[test]
    fn parse_trim_leading_maps_to_ltrim() {
        let stmt = parse_sql("SELECT TRIM(LEADING 'x' FROM s) FROM t").unwrap();
        match stmt {
            Statement::Select(sel) => match &sel.columns[0] {
                SelectColumn::Expr {
                    expr: Expr::Function { name, args },
                    ..
                } => {
                    assert_eq!(name, "LTRIM");
                    // The trimmed expression first, then the characters to strip.
                    assert_eq!(args.len(), 2);
                }
                other => panic!("expected Function, got {other:?}"),
            },
            _ => panic!("expected Select"),
        }
    }

    #[test]
    fn parse_nullif_desugars_to_case() {
        let stmt = parse_sql("SELECT NULLIF(a, 0) FROM t").unwrap();
        match stmt {
            Statement::Select(sel) => match &sel.columns[0] {
                SelectColumn::Expr {
                    expr:
                        Expr::Case {
                            operand: None,
                            conditions,
                            else_result: Some(_),
                        },
                    ..
                } => {
                    assert_eq!(conditions.len(), 1);
                    assert!(matches!(conditions[0].1, Expr::Literal(Value::Null)));
                }
                other => panic!("expected Case, got {other:?}"),
            },
            _ => panic!("expected Select"),
        }
    }

    #[test]
    fn reject_nullif_wrong_arity() {
        assert!(parse_sql("SELECT NULLIF(a) FROM t").is_err());
    }

    #[test]
    fn parse_iif_desugars_to_case() {
        let stmt = parse_sql("SELECT IIF(a > 1, 'x', 'y') FROM t").unwrap();
        match stmt {
            Statement::Select(sel) => match &sel.columns[0] {
                SelectColumn::Expr {
                    expr:
                        Expr::Case {
                            operand: None,
                            conditions,
                            else_result: Some(_),
                        },
                    ..
                } => assert_eq!(conditions.len(), 1),
                other => panic!("expected Case, got {other:?}"),
            },
            _ => panic!("expected Select"),
        }
    }

    #[test]
    fn reject_coalesce_without_arguments() {
        assert!(parse_sql("SELECT COALESCE() FROM t").is_err());
    }

    #[test]
    fn parse_inner_join() {
        let stmt = parse_sql("SELECT * FROM a JOIN b ON a.id = b.id").unwrap();
        match stmt {
            Statement::Select(sel) => {
                assert_eq!(sel.from, "a");
                assert_eq!(sel.joins.len(), 1);
                assert_eq!(sel.joins[0].join_type, JoinType::Inner);
                assert_eq!(sel.joins[0].table.name, "b");
                assert!(sel.joins[0].on_clause.is_some());
            }
            _ => panic!("expected Select"),
        }
    }

    #[test]
    fn parse_inner_join_explicit() {
        let stmt = parse_sql("SELECT * FROM a INNER JOIN b ON a.id = b.a_id").unwrap();
        match stmt {
            Statement::Select(sel) => {
                assert_eq!(sel.joins.len(), 1);
                assert_eq!(sel.joins[0].join_type, JoinType::Inner);
            }
            _ => panic!("expected Select"),
        }
    }

    #[test]
    fn parse_cross_join() {
        let stmt = parse_sql("SELECT * FROM a CROSS JOIN b").unwrap();
        match stmt {
            Statement::Select(sel) => {
                assert_eq!(sel.joins.len(), 1);
                assert_eq!(sel.joins[0].join_type, JoinType::Cross);
                assert!(sel.joins[0].on_clause.is_none());
            }
            _ => panic!("expected Select"),
        }
    }

    #[test]
    fn parse_left_join() {
        let stmt = parse_sql("SELECT * FROM a LEFT JOIN b ON a.id = b.a_id").unwrap();
        match stmt {
            Statement::Select(sel) => {
                assert_eq!(sel.joins.len(), 1);
                assert_eq!(sel.joins[0].join_type, JoinType::Left);
            }
            _ => panic!("expected Select"),
        }
    }

    #[test]
    fn parse_table_alias() {
        let stmt = parse_sql("SELECT u.id FROM users u JOIN orders o ON u.id = o.user_id").unwrap();
        match stmt {
            Statement::Select(sel) => {
                assert_eq!(sel.from, "users");
                assert_eq!(sel.from_alias.as_deref(), Some("u"));
                assert_eq!(sel.joins[0].table.name, "orders");
                assert_eq!(sel.joins[0].table.alias.as_deref(), Some("o"));
            }
            _ => panic!("expected Select"),
        }
    }

    #[test]
    fn parse_multi_join() {
        let stmt =
            parse_sql("SELECT * FROM a JOIN b ON a.id = b.a_id JOIN c ON b.id = c.b_id").unwrap();
        match stmt {
            Statement::Select(sel) => {
                assert_eq!(sel.joins.len(), 2);
            }
            _ => panic!("expected Select"),
        }
    }

    #[test]
    fn parse_qualified_column() {
        let stmt = parse_sql("SELECT u.id, u.name FROM users u").unwrap();
        match stmt {
            Statement::Select(sel) => match &sel.columns[0] {
                SelectColumn::Expr {
                    expr: Expr::QualifiedColumn { table, column },
                    ..
                } => {
                    assert_eq!(table, "u");
                    assert_eq!(column, "id");
                }
                other => panic!("expected QualifiedColumn, got {other:?}"),
            },
            _ => panic!("expected Select"),
        }
    }

    #[test]
    fn reject_subquery() {
        assert!(parse_sql("SELECT * FROM (SELECT 1)").is_err());
    }

    #[test]
    fn parse_type_mapping() {
        let stmt = parse_sql(
            "CREATE TABLE t (a INT PRIMARY KEY, b BIGINT, c SMALLINT, d REAL, e DOUBLE PRECISION, f VARCHAR(255), g BOOLEAN, h BLOB, i BYTEA)"
        ).unwrap();
        match stmt {
            Statement::CreateTable(ct) => {
                // One entry per declared column, in declaration order.
                let expected = [
                    DataType::Integer, // INT
                    DataType::Integer, // BIGINT
                    DataType::Integer, // SMALLINT
                    DataType::Real,    // REAL
                    DataType::Real,    // DOUBLE PRECISION
                    DataType::Text,    // VARCHAR(255)
                    DataType::Boolean, // BOOLEAN
                    DataType::Blob,    // BLOB
                    DataType::Blob,    // BYTEA
                ];
                assert_eq!(ct.columns.len(), expected.len());
                for (col, want) in ct.columns.iter().zip(expected.iter()) {
                    assert_eq!(&col.data_type, want, "column {}", col.name);
                }
            }
            _ => panic!("expected CreateTable"),
        }
    }

    #[test]
    fn parse_boolean_literals() {
        let stmt = parse_sql("INSERT INTO t (a, b) VALUES (true, false)").unwrap();
        match stmt {
            Statement::Insert(ins) => {
                assert!(matches!(
                    ins.values[0][0],
                    Expr::Literal(Value::Boolean(true))
                ));
                assert!(matches!(
                    ins.values[0][1],
                    Expr::Literal(Value::Boolean(false))
                ));
            }
            _ => panic!("expected Insert"),
        }
    }

    #[test]
    fn parse_null_literal() {
        let stmt = parse_sql("INSERT INTO t (a) VALUES (NULL)").unwrap();
        match stmt {
            Statement::Insert(ins) => {
                assert!(matches!(ins.values[0][0], Expr::Literal(Value::Null)));
            }
            _ => panic!("expected Insert"),
        }
    }

    #[test]
    fn parse_alias() {
        let stmt = parse_sql("SELECT id AS user_id FROM users").unwrap();
        match stmt {
            Statement::Select(sel) => match &sel.columns[0] {
                SelectColumn::Expr { alias: Some(a), .. } => assert_eq!(a, "user_id"),
                other => panic!("expected alias, got {other:?}"),
            },
            _ => panic!("expected Select"),
        }
    }

    #[test]
    fn parse_begin() {
        let stmt = parse_sql("BEGIN").unwrap();
        assert!(matches!(stmt, Statement::Begin));
    }

    #[test]
    fn parse_begin_transaction() {
        let stmt = parse_sql("BEGIN TRANSACTION").unwrap();
        assert!(matches!(stmt, Statement::Begin));
    }

    #[test]
    fn parse_commit() {
        let stmt = parse_sql("COMMIT").unwrap();
        assert!(matches!(stmt, Statement::Commit));
    }

    #[test]
    fn parse_rollback() {
        let stmt = parse_sql("ROLLBACK").unwrap();
        assert!(matches!(stmt, Statement::Rollback));
    }

    #[test]
    fn parse_select_distinct() {
        let stmt = parse_sql("SELECT DISTINCT name FROM users").unwrap();
        match stmt {
            Statement::Select(sel) => {
                assert!(sel.distinct);
                assert_eq!(sel.columns.len(), 1);
            }
            _ => panic!("expected Select"),
        }
    }

    #[test]
    fn parse_select_without_distinct() {
        let stmt = parse_sql("SELECT name FROM users").unwrap();
        match stmt {
            Statement::Select(sel) => {
                assert!(!sel.distinct);
            }
            _ => panic!("expected Select"),
        }
    }

    #[test]
    fn parse_select_distinct_all_columns() {
        let stmt = parse_sql("SELECT DISTINCT * FROM users").unwrap();
        match stmt {
            Statement::Select(sel) => {
                assert!(sel.distinct);
                assert!(matches!(sel.columns[0], SelectColumn::AllColumns));
            }
            _ => panic!("expected Select"),
        }
    }

    #[test]
    fn reject_distinct_on() {
        assert!(parse_sql("SELECT DISTINCT ON (id) * FROM users").is_err());
    }

    #[test]
    fn parse_create_index() {
        let stmt = parse_sql("CREATE INDEX idx_name ON users (name)").unwrap();
        match stmt {
            Statement::CreateIndex(ci) => {
                assert_eq!(ci.index_name, "idx_name");
                assert_eq!(ci.table_name, "users");
                assert_eq!(ci.columns, vec!["name"]);
                assert!(!ci.unique);
                assert!(!ci.if_not_exists);
            }
            _ => panic!("expected CreateIndex"),
        }
    }

    #[test]
    fn parse_create_unique_index() {
        let stmt = parse_sql("CREATE UNIQUE INDEX idx_email ON users (email)").unwrap();
        match stmt {
            Statement::CreateIndex(ci) => {
                assert!(ci.unique);
                assert_eq!(ci.columns, vec!["email"]);
            }
            _ => panic!("expected CreateIndex"),
        }
    }

    #[test]
    fn parse_create_index_if_not_exists() {
        let stmt = parse_sql("CREATE INDEX IF NOT EXISTS idx_x ON t (a)").unwrap();
        match stmt {
            Statement::CreateIndex(ci) => assert!(ci.if_not_exists),
            _ => panic!("expected CreateIndex"),
        }
    }

    #[test]
    fn parse_create_index_multi_column() {
        let stmt = parse_sql("CREATE INDEX idx_multi ON t (a, b, c)").unwrap();
        match stmt {
            Statement::CreateIndex(ci) => {
                assert_eq!(ci.columns, vec!["a", "b", "c"]);
            }
            _ => panic!("expected CreateIndex"),
        }
    }

    #[test]
    fn parse_drop_index() {
        let stmt = parse_sql("DROP INDEX idx_name").unwrap();
        match stmt {
            Statement::DropIndex(di) => {
                assert_eq!(di.index_name, "idx_name");
                assert!(!di.if_exists);
            }
            _ => panic!("expected DropIndex"),
        }
    }

    #[test]
    fn parse_drop_index_if_exists() {
        let stmt = parse_sql("DROP INDEX IF EXISTS idx_name").unwrap();
        match stmt {
            Statement::DropIndex(di) => {
                assert!(di.if_exists);
            }
            _ => panic!("expected DropIndex"),
        }
    }

    #[test]
    fn parse_explain_select() {
        let stmt = parse_sql("EXPLAIN SELECT * FROM users WHERE id = 1").unwrap();
        match stmt {
            Statement::Explain(inner) => {
                assert!(matches!(*inner, Statement::Select(_)));
            }
            _ => panic!("expected Explain"),
        }
    }

    #[test]
    fn parse_explain_insert() {
        let stmt = parse_sql("EXPLAIN INSERT INTO t (a) VALUES (1)").unwrap();
        assert!(matches!(stmt, Statement::Explain(_)));
    }

    #[test]
    fn reject_explain_analyze() {
        assert!(parse_sql("EXPLAIN ANALYZE SELECT * FROM t").is_err());
    }

    #[test]
    fn parse_parameter_placeholder() {
        let stmt = parse_sql("SELECT * FROM t WHERE id = $1").unwrap();
        match stmt {
            Statement::Select(sel) => match &sel.where_clause {
                Some(Expr::BinaryOp { right, .. }) => {
                    assert!(matches!(right.as_ref(), Expr::Parameter(1)));
                }
                other => panic!("expected BinaryOp with Parameter, got {other:?}"),
            },
            _ => panic!("expected Select"),
        }
    }

    #[test]
    fn parse_multiple_parameters() {
        let stmt = parse_sql("INSERT INTO t (a, b) VALUES ($1, $2)").unwrap();
        match stmt {
            Statement::Insert(ins) => {
                assert!(matches!(ins.values[0][0], Expr::Parameter(1)));
                assert!(matches!(ins.values[0][1], Expr::Parameter(2)));
            }
            _ => panic!("expected Insert"),
        }
    }

    #[test]
    fn reject_zero_parameter() {
        assert!(parse_sql("SELECT $0 FROM t").is_err());
    }

    #[test]
    fn reject_question_mark_placeholder() {
        // Only `$N` placeholders are supported.
        assert!(parse_sql("SELECT * FROM t WHERE a = ?").is_err());
    }

    #[test]
    fn count_params_basic() {
        let stmt = parse_sql("SELECT * FROM t WHERE a = $1 AND b = $3").unwrap();
        assert_eq!(count_params(&stmt), 3);
    }

    #[test]
    fn count_params_none() {
        let stmt = parse_sql("SELECT * FROM t WHERE a = 1").unwrap();
        assert_eq!(count_params(&stmt), 0);
    }

    #[test]
    fn bind_params_basic() {
        let stmt = parse_sql("SELECT * FROM t WHERE id = $1").unwrap();
        let bound = bind_params(&stmt, &[Value::Integer(42)]).unwrap();
        match bound {
            Statement::Select(sel) => match &sel.where_clause {
                Some(Expr::BinaryOp { right, .. }) => {
                    assert!(matches!(right.as_ref(), Expr::Literal(Value::Integer(42))));
                }
                other => panic!("expected BinaryOp with Literal, got {other:?}"),
            },
            _ => panic!("expected Select"),
        }
    }

    #[test]
    fn bind_params_out_of_range() {
        let stmt = parse_sql("SELECT * FROM t WHERE id = $2").unwrap();
        let result = bind_params(&stmt, &[Value::Integer(1)]);
        assert!(result.is_err());
    }

    #[test]
    fn parse_table_constraint_pk() {
        let stmt = parse_sql("CREATE TABLE t (a INTEGER, b TEXT, PRIMARY KEY (a))").unwrap();
        match stmt {
            Statement::CreateTable(ct) => {
                assert_eq!(ct.primary_key, vec!["a"]);
                assert!(ct.columns[0].is_primary_key);
                assert!(!ct.columns[0].nullable);
            }
            _ => panic!("expected CreateTable"),
        }
    }
}