1use sqlparser::ast as sp;
4use sqlparser::dialect::GenericDialect;
5use sqlparser::parser::Parser;
6
7use crate::error::{Result, SqlError};
8use crate::types::{DataType, Value};
9
/// A parsed SQL statement in this crate's own AST, produced by `parse_sql`.
///
/// Placeholders (`$1`, `$2`, ...) stay as [`Expr::Parameter`] nodes until
/// `bind_params` replaces them with literal values.
#[derive(Debug, Clone)]
pub enum Statement {
    /// `CREATE TABLE`.
    CreateTable(CreateTableStmt),
    /// `DROP TABLE` naming exactly one table.
    DropTable(DropTableStmt),
    /// `CREATE [UNIQUE] INDEX`.
    CreateIndex(CreateIndexStmt),
    /// `DROP INDEX` naming exactly one index.
    DropIndex(DropIndexStmt),
    /// `INSERT ... VALUES`.
    Insert(InsertStmt),
    /// A single `SELECT` query.
    Select(Box<SelectStmt>),
    /// `UPDATE`.
    Update(UpdateStmt),
    /// `DELETE`.
    Delete(DeleteStmt),
    /// Start of a transaction (sqlparser's `StartTransaction`).
    Begin,
    /// `COMMIT`.
    Commit,
    /// `ROLLBACK`.
    Rollback,
    /// `EXPLAIN <statement>`; `EXPLAIN ANALYZE` is rejected during conversion.
    Explain(Box<Statement>),
}
27
/// A converted `CREATE TABLE` statement.
#[derive(Debug, Clone)]
pub struct CreateTableStmt {
    /// Table name as rendered by `object_name_to_string`.
    pub name: String,
    /// Column definitions in declaration order.
    pub columns: Vec<ColumnSpec>,
    /// Primary-key column names: inline `PRIMARY KEY` columns first, then any
    /// columns from a table-level `PRIMARY KEY (...)` constraint not already listed.
    pub primary_key: Vec<String>,
    /// `true` when the statement said `IF NOT EXISTS`.
    pub if_not_exists: bool,
}
35
/// One column definition from `CREATE TABLE`.
#[derive(Debug, Clone)]
pub struct ColumnSpec {
    /// Column name as written in the DDL.
    pub name: String,
    /// Engine data type converted from the SQL type.
    pub data_type: DataType,
    /// Defaults to `true`; `NOT NULL` or `PRIMARY KEY` make it `false`.
    pub nullable: bool,
    /// Set when the column is part of the primary key (inline or table-level).
    pub is_primary_key: bool,
}
43
/// A converted `DROP TABLE` statement (single table).
#[derive(Debug, Clone)]
pub struct DropTableStmt {
    /// Name of the table to drop.
    pub name: String,
    /// `true` when the statement said `IF EXISTS`.
    pub if_exists: bool,
}
49
/// A converted `CREATE [UNIQUE] INDEX` statement.
#[derive(Debug, Clone)]
pub struct CreateIndexStmt {
    /// Name of the new index (an index name is required).
    pub index_name: String,
    /// Table the index is built on.
    pub table_name: String,
    /// Indexed column names in key order; expression indexes are rejected.
    pub columns: Vec<String>,
    /// `true` for `CREATE UNIQUE INDEX`.
    pub unique: bool,
    /// `true` when the statement said `IF NOT EXISTS`.
    pub if_not_exists: bool,
}
58
/// A converted `DROP INDEX` statement (single index).
#[derive(Debug, Clone)]
pub struct DropIndexStmt {
    /// Name of the index to drop.
    pub index_name: String,
    /// `true` when the statement said `IF EXISTS`.
    pub if_exists: bool,
}
64
/// A converted `INSERT ... VALUES` statement.
#[derive(Debug, Clone)]
pub struct InsertStmt {
    /// Target table name, ASCII-lowercased during conversion.
    pub table: String,
    /// Explicit column list, ASCII-lowercased; empty when the INSERT named no
    /// columns (presumably meaning "all columns in table order" — TODO confirm).
    pub columns: Vec<String>,
    /// One expression list per `VALUES` row.
    pub values: Vec<Vec<Expr>>,
}
71
/// A table named in a `JOIN`, with its optional alias.
#[derive(Debug, Clone)]
pub struct TableRef {
    /// Table name.
    pub name: String,
    /// Alias from `AS alias` / `tbl alias`, if any.
    pub alias: Option<String>,
}
77
/// The kind of a `JOIN` clause.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum JoinType {
    /// `JOIN` / `INNER JOIN`.
    Inner,
    /// `CROSS JOIN`.
    Cross,
    /// Left-side join variants.
    Left,
    /// Right-side join variants.
    Right,
}
85
/// One `JOIN` attached to the `FROM` table.
#[derive(Debug, Clone)]
pub struct JoinClause {
    /// Join kind.
    pub join_type: JoinType,
    /// The joined table.
    pub table: TableRef,
    /// The `ON` predicate; `None` when the join has no constraint (e.g. `CROSS JOIN`).
    pub on_clause: Option<Expr>,
}
92
/// A converted single-block `SELECT` (no set operations).
#[derive(Debug, Clone)]
pub struct SelectStmt {
    /// Projection list in source order.
    pub columns: Vec<SelectColumn>,
    /// The single `FROM` table, or an empty string when there is no `FROM` clause.
    pub from: String,
    /// Alias of the `FROM` table, if any.
    pub from_alias: Option<String>,
    /// Joins attached to the `FROM` table, in source order.
    pub joins: Vec<JoinClause>,
    /// `true` for `SELECT DISTINCT` (`DISTINCT ON` is rejected).
    pub distinct: bool,
    /// `WHERE` predicate.
    pub where_clause: Option<Expr>,
    /// `ORDER BY` items in source order.
    pub order_by: Vec<OrderByItem>,
    /// `LIMIT` as an unevaluated expression (may be a placeholder).
    pub limit: Option<Expr>,
    /// `OFFSET` as an unevaluated expression (may be a placeholder).
    pub offset: Option<Expr>,
    /// `GROUP BY` expressions.
    pub group_by: Vec<Expr>,
    /// `HAVING` predicate.
    pub having: Option<Expr>,
}
107
/// A converted `UPDATE` statement.
#[derive(Debug, Clone)]
pub struct UpdateStmt {
    /// Target table name.
    pub table: String,
    /// `(column, new value)` pairs in `SET` order.
    pub assignments: Vec<(String, Expr)>,
    /// `WHERE` predicate; `None` when the statement has no `WHERE` clause.
    pub where_clause: Option<Expr>,
}
114
/// A converted single-table `DELETE` statement.
#[derive(Debug, Clone)]
pub struct DeleteStmt {
    /// Target table name.
    pub table: String,
    /// `WHERE` predicate; `None` when the statement has no `WHERE` clause.
    pub where_clause: Option<Expr>,
}
120
/// One item of a `SELECT` projection list.
#[derive(Debug, Clone)]
pub enum SelectColumn {
    /// Wildcard projection (presumably `*`; built by `convert_select_item`, not shown — TODO confirm).
    AllColumns,
    /// An expression with an optional `AS` alias.
    Expr { expr: Expr, alias: Option<String> },
}
126
/// One `ORDER BY` item.
#[derive(Debug, Clone)]
pub struct OrderByItem {
    /// Sort key expression.
    pub expr: Expr,
    /// `true` for `DESC`.
    pub descending: bool,
    /// Explicit `NULLS FIRST` (`Some(true)`) / `NULLS LAST` (`Some(false)`);
    /// presumably `None` when unspecified (set by `convert_order_by_expr`, not shown).
    pub nulls_first: Option<bool>,
}
133
/// Scalar / predicate expression tree used by every statement kind.
#[derive(Debug, Clone)]
pub enum Expr {
    /// A constant value.
    Literal(Value),
    /// An unqualified column reference (name ASCII-lowercased by the converter).
    Column(String),
    /// A `table.column` reference (both parts ASCII-lowercased by the converter).
    QualifiedColumn {
        table: String,
        column: String,
    },
    /// A binary operator application.
    BinaryOp {
        left: Box<Expr>,
        op: BinOp,
        right: Box<Expr>,
    },
    /// A unary operator application (`-x`, `NOT x`).
    UnaryOp {
        op: UnaryOp,
        expr: Box<Expr>,
    },
    /// `expr IS NULL`.
    IsNull(Box<Expr>),
    /// `expr IS NOT NULL`.
    IsNotNull(Box<Expr>),
    /// A function call. Special SQL syntax is lowered to names such as `SUBSTR`,
    /// `TRIM`/`LTRIM`/`RTRIM`, `CEIL`, `FLOOR` and `INSTR`.
    Function {
        name: String,
        args: Vec<Expr>,
    },
    /// Presumably `COUNT(*)` (built by `convert_function`, not shown — TODO confirm).
    CountStar,
    /// `expr [NOT] IN (SELECT ...)`.
    InSubquery {
        expr: Box<Expr>,
        subquery: Box<SelectStmt>,
        negated: bool,
    },
    /// `expr [NOT] IN (e1, e2, ...)`.
    InList {
        expr: Box<Expr>,
        list: Vec<Expr>,
        negated: bool,
    },
    /// `[NOT] EXISTS (SELECT ...)`.
    Exists {
        subquery: Box<SelectStmt>,
        negated: bool,
    },
    /// A parenthesised subquery used as a scalar value.
    ScalarSubquery(Box<SelectStmt>),
    /// Set-membership test over pre-evaluated values. The parser does not produce
    /// this variant; presumably a later pass builds it from an `InList` of literals,
    /// with `has_null` recording whether the list contained NULL — TODO confirm.
    InSet {
        expr: Box<Expr>,
        values: std::collections::HashSet<Value>,
        has_null: bool,
        negated: bool,
    },
    /// `expr [NOT] BETWEEN low AND high`.
    Between {
        expr: Box<Expr>,
        low: Box<Expr>,
        high: Box<Expr>,
        negated: bool,
    },
    /// `expr [NOT] LIKE pattern [ESCAPE esc]` (ILIKE is converted to this too).
    Like {
        expr: Box<Expr>,
        pattern: Box<Expr>,
        escape: Option<Box<Expr>>,
        negated: bool,
    },
    /// `CASE [operand] WHEN cond THEN result ... [ELSE else_result] END`;
    /// `conditions` holds `(WHEN, THEN)` pairs in source order.
    Case {
        operand: Option<Box<Expr>>,
        conditions: Vec<(Expr, Expr)>,
        else_result: Option<Box<Expr>>,
    },
    /// `COALESCE(args...)` (not built in the visible conversion code — TODO confirm producer).
    Coalesce(Vec<Expr>),
    /// `CAST(expr AS data_type)`.
    Cast {
        expr: Box<Expr>,
        data_type: DataType,
    },
    /// A 1-based positional placeholder: `$1` is `Parameter(1)`.
    Parameter(usize),
}
203
/// Binary operators supported in expressions.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BinOp {
    /// `+`
    Add,
    /// `-`
    Sub,
    /// `*`
    Mul,
    /// `/`
    Div,
    /// `%`
    Mod,
    /// `=`
    Eq,
    /// `<>` / `!=`
    NotEq,
    /// `<`
    Lt,
    /// `>`
    Gt,
    /// `<=`
    LtEq,
    /// `>=`
    GtEq,
    /// `AND`
    And,
    /// `OR`
    Or,
    /// String concatenation (presumably `||` — mapping not visible here).
    Concat,
}
221
/// Unary operators supported in expressions.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum UnaryOp {
    /// Arithmetic negation (`-x`).
    Neg,
    /// Logical negation (`NOT x`).
    Not,
}
227
228pub fn parse_sql(sql: &str) -> Result<Statement> {
231 let dialect = GenericDialect {};
232 let stmts = Parser::parse_sql(&dialect, sql).map_err(|e| SqlError::Parse(e.to_string()))?;
233
234 if stmts.is_empty() {
235 return Err(SqlError::Parse("empty SQL".into()));
236 }
237 if stmts.len() > 1 {
238 return Err(SqlError::Unsupported("multiple statements".into()));
239 }
240
241 convert_statement(stmts.into_iter().next().unwrap())
242}
243
244pub fn count_params(stmt: &Statement) -> usize {
248 let mut max_idx = 0usize;
249 visit_exprs_stmt(stmt, &mut |e| {
250 if let Expr::Parameter(n) = e {
251 max_idx = max_idx.max(*n);
252 }
253 });
254 max_idx
255}
256
257pub fn bind_params(
259 stmt: &Statement,
260 params: &[crate::types::Value],
261) -> crate::error::Result<Statement> {
262 bind_stmt(stmt, params)
263}
264
265fn bind_stmt(stmt: &Statement, params: &[crate::types::Value]) -> crate::error::Result<Statement> {
266 match stmt {
267 Statement::Select(sel) => Ok(Statement::Select(Box::new(bind_select(sel, params)?))),
268 Statement::Insert(ins) => {
269 let values = ins
270 .values
271 .iter()
272 .map(|row| {
273 row.iter()
274 .map(|e| bind_expr(e, params))
275 .collect::<crate::error::Result<Vec<_>>>()
276 })
277 .collect::<crate::error::Result<Vec<_>>>()?;
278 Ok(Statement::Insert(InsertStmt {
279 table: ins.table.clone(),
280 columns: ins.columns.clone(),
281 values,
282 }))
283 }
284 Statement::Update(upd) => {
285 let assignments = upd
286 .assignments
287 .iter()
288 .map(|(col, e)| Ok((col.clone(), bind_expr(e, params)?)))
289 .collect::<crate::error::Result<Vec<_>>>()?;
290 let where_clause = upd
291 .where_clause
292 .as_ref()
293 .map(|e| bind_expr(e, params))
294 .transpose()?;
295 Ok(Statement::Update(UpdateStmt {
296 table: upd.table.clone(),
297 assignments,
298 where_clause,
299 }))
300 }
301 Statement::Delete(del) => {
302 let where_clause = del
303 .where_clause
304 .as_ref()
305 .map(|e| bind_expr(e, params))
306 .transpose()?;
307 Ok(Statement::Delete(DeleteStmt {
308 table: del.table.clone(),
309 where_clause,
310 }))
311 }
312 Statement::Explain(inner) => Ok(Statement::Explain(Box::new(bind_stmt(inner, params)?))),
313 other => Ok(other.clone()),
314 }
315}
316
317fn bind_select(
318 sel: &SelectStmt,
319 params: &[crate::types::Value],
320) -> crate::error::Result<SelectStmt> {
321 let columns = sel
322 .columns
323 .iter()
324 .map(|c| match c {
325 SelectColumn::AllColumns => Ok(SelectColumn::AllColumns),
326 SelectColumn::Expr { expr, alias } => Ok(SelectColumn::Expr {
327 expr: bind_expr(expr, params)?,
328 alias: alias.clone(),
329 }),
330 })
331 .collect::<crate::error::Result<Vec<_>>>()?;
332 let joins = sel
333 .joins
334 .iter()
335 .map(|j| {
336 let on_clause = j
337 .on_clause
338 .as_ref()
339 .map(|e| bind_expr(e, params))
340 .transpose()?;
341 Ok(JoinClause {
342 join_type: j.join_type,
343 table: j.table.clone(),
344 on_clause,
345 })
346 })
347 .collect::<crate::error::Result<Vec<_>>>()?;
348 let where_clause = sel
349 .where_clause
350 .as_ref()
351 .map(|e| bind_expr(e, params))
352 .transpose()?;
353 let order_by = sel
354 .order_by
355 .iter()
356 .map(|o| {
357 Ok(OrderByItem {
358 expr: bind_expr(&o.expr, params)?,
359 descending: o.descending,
360 nulls_first: o.nulls_first,
361 })
362 })
363 .collect::<crate::error::Result<Vec<_>>>()?;
364 let limit = sel
365 .limit
366 .as_ref()
367 .map(|e| bind_expr(e, params))
368 .transpose()?;
369 let offset = sel
370 .offset
371 .as_ref()
372 .map(|e| bind_expr(e, params))
373 .transpose()?;
374 let group_by = sel
375 .group_by
376 .iter()
377 .map(|e| bind_expr(e, params))
378 .collect::<crate::error::Result<Vec<_>>>()?;
379 let having = sel
380 .having
381 .as_ref()
382 .map(|e| bind_expr(e, params))
383 .transpose()?;
384
385 Ok(SelectStmt {
386 columns,
387 from: sel.from.clone(),
388 from_alias: sel.from_alias.clone(),
389 joins,
390 distinct: sel.distinct,
391 where_clause,
392 order_by,
393 limit,
394 offset,
395 group_by,
396 having,
397 })
398}
399
400fn bind_expr(expr: &Expr, params: &[crate::types::Value]) -> crate::error::Result<Expr> {
401 match expr {
402 Expr::Parameter(n) => {
403 if *n == 0 || *n > params.len() {
404 return Err(SqlError::ParameterCountMismatch {
405 expected: *n,
406 got: params.len(),
407 });
408 }
409 Ok(Expr::Literal(params[*n - 1].clone()))
410 }
411 Expr::Literal(_) | Expr::Column(_) | Expr::QualifiedColumn { .. } | Expr::CountStar => {
412 Ok(expr.clone())
413 }
414 Expr::BinaryOp { left, op, right } => Ok(Expr::BinaryOp {
415 left: Box::new(bind_expr(left, params)?),
416 op: *op,
417 right: Box::new(bind_expr(right, params)?),
418 }),
419 Expr::UnaryOp { op, expr: e } => Ok(Expr::UnaryOp {
420 op: *op,
421 expr: Box::new(bind_expr(e, params)?),
422 }),
423 Expr::IsNull(e) => Ok(Expr::IsNull(Box::new(bind_expr(e, params)?))),
424 Expr::IsNotNull(e) => Ok(Expr::IsNotNull(Box::new(bind_expr(e, params)?))),
425 Expr::Function { name, args } => {
426 let args = args
427 .iter()
428 .map(|a| bind_expr(a, params))
429 .collect::<crate::error::Result<Vec<_>>>()?;
430 Ok(Expr::Function {
431 name: name.clone(),
432 args,
433 })
434 }
435 Expr::InSubquery {
436 expr: e,
437 subquery,
438 negated,
439 } => Ok(Expr::InSubquery {
440 expr: Box::new(bind_expr(e, params)?),
441 subquery: Box::new(bind_select(subquery, params)?),
442 negated: *negated,
443 }),
444 Expr::InList {
445 expr: e,
446 list,
447 negated,
448 } => {
449 let list = list
450 .iter()
451 .map(|l| bind_expr(l, params))
452 .collect::<crate::error::Result<Vec<_>>>()?;
453 Ok(Expr::InList {
454 expr: Box::new(bind_expr(e, params)?),
455 list,
456 negated: *negated,
457 })
458 }
459 Expr::Exists { subquery, negated } => Ok(Expr::Exists {
460 subquery: Box::new(bind_select(subquery, params)?),
461 negated: *negated,
462 }),
463 Expr::ScalarSubquery(sq) => Ok(Expr::ScalarSubquery(Box::new(bind_select(sq, params)?))),
464 Expr::InSet {
465 expr: e,
466 values,
467 has_null,
468 negated,
469 } => Ok(Expr::InSet {
470 expr: Box::new(bind_expr(e, params)?),
471 values: values.clone(),
472 has_null: *has_null,
473 negated: *negated,
474 }),
475 Expr::Between {
476 expr: e,
477 low,
478 high,
479 negated,
480 } => Ok(Expr::Between {
481 expr: Box::new(bind_expr(e, params)?),
482 low: Box::new(bind_expr(low, params)?),
483 high: Box::new(bind_expr(high, params)?),
484 negated: *negated,
485 }),
486 Expr::Like {
487 expr: e,
488 pattern,
489 escape,
490 negated,
491 } => Ok(Expr::Like {
492 expr: Box::new(bind_expr(e, params)?),
493 pattern: Box::new(bind_expr(pattern, params)?),
494 escape: escape
495 .as_ref()
496 .map(|esc| bind_expr(esc, params).map(Box::new))
497 .transpose()?,
498 negated: *negated,
499 }),
500 Expr::Case {
501 operand,
502 conditions,
503 else_result,
504 } => {
505 let operand = operand
506 .as_ref()
507 .map(|e| bind_expr(e, params).map(Box::new))
508 .transpose()?;
509 let conditions = conditions
510 .iter()
511 .map(|(cond, then)| Ok((bind_expr(cond, params)?, bind_expr(then, params)?)))
512 .collect::<crate::error::Result<Vec<_>>>()?;
513 let else_result = else_result
514 .as_ref()
515 .map(|e| bind_expr(e, params).map(Box::new))
516 .transpose()?;
517 Ok(Expr::Case {
518 operand,
519 conditions,
520 else_result,
521 })
522 }
523 Expr::Coalesce(args) => {
524 let args = args
525 .iter()
526 .map(|a| bind_expr(a, params))
527 .collect::<crate::error::Result<Vec<_>>>()?;
528 Ok(Expr::Coalesce(args))
529 }
530 Expr::Cast { expr: e, data_type } => Ok(Expr::Cast {
531 expr: Box::new(bind_expr(e, params)?),
532 data_type: *data_type,
533 }),
534 }
535}
536
537fn visit_exprs_stmt(stmt: &Statement, visitor: &mut impl FnMut(&Expr)) {
538 match stmt {
539 Statement::Select(sel) => visit_exprs_select(sel, visitor),
540 Statement::Insert(ins) => {
541 for row in &ins.values {
542 for e in row {
543 visit_expr(e, visitor);
544 }
545 }
546 }
547 Statement::Update(upd) => {
548 for (_, e) in &upd.assignments {
549 visit_expr(e, visitor);
550 }
551 if let Some(w) = &upd.where_clause {
552 visit_expr(w, visitor);
553 }
554 }
555 Statement::Delete(del) => {
556 if let Some(w) = &del.where_clause {
557 visit_expr(w, visitor);
558 }
559 }
560 Statement::Explain(inner) => visit_exprs_stmt(inner, visitor),
561 _ => {}
562 }
563}
564
565fn visit_exprs_select(sel: &SelectStmt, visitor: &mut impl FnMut(&Expr)) {
566 for col in &sel.columns {
567 if let SelectColumn::Expr { expr, .. } = col {
568 visit_expr(expr, visitor);
569 }
570 }
571 for j in &sel.joins {
572 if let Some(on) = &j.on_clause {
573 visit_expr(on, visitor);
574 }
575 }
576 if let Some(w) = &sel.where_clause {
577 visit_expr(w, visitor);
578 }
579 for o in &sel.order_by {
580 visit_expr(&o.expr, visitor);
581 }
582 if let Some(l) = &sel.limit {
583 visit_expr(l, visitor);
584 }
585 if let Some(o) = &sel.offset {
586 visit_expr(o, visitor);
587 }
588 for g in &sel.group_by {
589 visit_expr(g, visitor);
590 }
591 if let Some(h) = &sel.having {
592 visit_expr(h, visitor);
593 }
594}
595
596fn visit_expr(expr: &Expr, visitor: &mut impl FnMut(&Expr)) {
597 visitor(expr);
598 match expr {
599 Expr::BinaryOp { left, right, .. } => {
600 visit_expr(left, visitor);
601 visit_expr(right, visitor);
602 }
603 Expr::UnaryOp { expr: e, .. } | Expr::IsNull(e) | Expr::IsNotNull(e) => {
604 visit_expr(e, visitor);
605 }
606 Expr::Function { args, .. } | Expr::Coalesce(args) => {
607 for a in args {
608 visit_expr(a, visitor);
609 }
610 }
611 Expr::InSubquery {
612 expr: e, subquery, ..
613 } => {
614 visit_expr(e, visitor);
615 visit_exprs_select(subquery, visitor);
616 }
617 Expr::InList { expr: e, list, .. } => {
618 visit_expr(e, visitor);
619 for l in list {
620 visit_expr(l, visitor);
621 }
622 }
623 Expr::Exists { subquery, .. } => visit_exprs_select(subquery, visitor),
624 Expr::ScalarSubquery(sq) => visit_exprs_select(sq, visitor),
625 Expr::InSet { expr: e, .. } => visit_expr(e, visitor),
626 Expr::Between {
627 expr: e, low, high, ..
628 } => {
629 visit_expr(e, visitor);
630 visit_expr(low, visitor);
631 visit_expr(high, visitor);
632 }
633 Expr::Like {
634 expr: e,
635 pattern,
636 escape,
637 ..
638 } => {
639 visit_expr(e, visitor);
640 visit_expr(pattern, visitor);
641 if let Some(esc) = escape {
642 visit_expr(esc, visitor);
643 }
644 }
645 Expr::Case {
646 operand,
647 conditions,
648 else_result,
649 } => {
650 if let Some(op) = operand {
651 visit_expr(op, visitor);
652 }
653 for (cond, then) in conditions {
654 visit_expr(cond, visitor);
655 visit_expr(then, visitor);
656 }
657 if let Some(el) = else_result {
658 visit_expr(el, visitor);
659 }
660 }
661 Expr::Cast { expr: e, .. } => visit_expr(e, visitor),
662 Expr::Literal(_)
663 | Expr::Column(_)
664 | Expr::QualifiedColumn { .. }
665 | Expr::CountStar
666 | Expr::Parameter(_) => {}
667 }
668}
669
670fn convert_statement(stmt: sp::Statement) -> Result<Statement> {
673 match stmt {
674 sp::Statement::CreateTable(ct) => convert_create_table(ct),
675 sp::Statement::CreateIndex(ci) => convert_create_index(ci),
676 sp::Statement::Drop {
677 object_type: sp::ObjectType::Table,
678 if_exists,
679 names,
680 ..
681 } => {
682 if names.len() != 1 {
683 return Err(SqlError::Unsupported("multi-table DROP".into()));
684 }
685 Ok(Statement::DropTable(DropTableStmt {
686 name: object_name_to_string(&names[0]),
687 if_exists,
688 }))
689 }
690 sp::Statement::Drop {
691 object_type: sp::ObjectType::Index,
692 if_exists,
693 names,
694 ..
695 } => {
696 if names.len() != 1 {
697 return Err(SqlError::Unsupported("multi-index DROP".into()));
698 }
699 Ok(Statement::DropIndex(DropIndexStmt {
700 index_name: object_name_to_string(&names[0]),
701 if_exists,
702 }))
703 }
704 sp::Statement::Insert(insert) => convert_insert(insert),
705 sp::Statement::Query(query) => convert_query(*query),
706 sp::Statement::Update(update) => convert_update(update),
707 sp::Statement::Delete(delete) => convert_delete(delete),
708 sp::Statement::StartTransaction { .. } => Ok(Statement::Begin),
709 sp::Statement::Commit { .. } => Ok(Statement::Commit),
710 sp::Statement::Rollback { .. } => Ok(Statement::Rollback),
711 sp::Statement::Explain {
712 statement, analyze, ..
713 } => {
714 if analyze {
715 return Err(SqlError::Unsupported("EXPLAIN ANALYZE".into()));
716 }
717 let inner = convert_statement(*statement)?;
718 Ok(Statement::Explain(Box::new(inner)))
719 }
720 _ => Err(SqlError::Unsupported(format!("statement type: {}", stmt))),
721 }
722}
723
724fn convert_create_table(ct: sp::CreateTable) -> Result<Statement> {
725 let name = object_name_to_string(&ct.name);
726 let if_not_exists = ct.if_not_exists;
727
728 let mut columns = Vec::new();
729 let mut inline_pk: Vec<String> = Vec::new();
730
731 for col_def in &ct.columns {
732 let col_name = col_def.name.value.clone();
733 let data_type = convert_data_type(&col_def.data_type)?;
734 let mut nullable = true;
735 let mut is_primary_key = false;
736
737 for opt in &col_def.options {
738 match &opt.option {
739 sp::ColumnOption::NotNull => nullable = false,
740 sp::ColumnOption::Null => nullable = true,
741 sp::ColumnOption::PrimaryKey(_) => {
742 is_primary_key = true;
743 nullable = false;
744 inline_pk.push(col_name.clone());
745 }
746 _ => {}
747 }
748 }
749
750 columns.push(ColumnSpec {
751 name: col_name,
752 data_type,
753 nullable,
754 is_primary_key,
755 });
756 }
757
758 for constraint in &ct.constraints {
760 if let sp::TableConstraint::PrimaryKey(pk_constraint) = constraint {
761 for idx_col in &pk_constraint.columns {
762 let col_name = match &idx_col.column.expr {
764 sp::Expr::Identifier(ident) => ident.value.clone(),
765 _ => continue,
766 };
767 if !inline_pk.contains(&col_name) {
768 inline_pk.push(col_name.clone());
769 }
770 if let Some(col) = columns.iter_mut().find(|c| c.name == col_name) {
771 col.nullable = false;
772 col.is_primary_key = true;
773 }
774 }
775 }
776 }
777
778 Ok(Statement::CreateTable(CreateTableStmt {
779 name,
780 columns,
781 primary_key: inline_pk,
782 if_not_exists,
783 }))
784}
785
786fn convert_create_index(ci: sp::CreateIndex) -> Result<Statement> {
787 let index_name = ci
788 .name
789 .as_ref()
790 .map(object_name_to_string)
791 .ok_or_else(|| SqlError::Parse("index name required".into()))?;
792
793 let table_name = object_name_to_string(&ci.table_name);
794
795 let columns: Vec<String> = ci
796 .columns
797 .iter()
798 .map(|idx_col| match &idx_col.column.expr {
799 sp::Expr::Identifier(ident) => Ok(ident.value.clone()),
800 other => Err(SqlError::Unsupported(format!("expression index: {other}"))),
801 })
802 .collect::<Result<_>>()?;
803
804 if columns.is_empty() {
805 return Err(SqlError::Parse(
806 "index must have at least one column".into(),
807 ));
808 }
809
810 Ok(Statement::CreateIndex(CreateIndexStmt {
811 index_name,
812 table_name,
813 columns,
814 unique: ci.unique,
815 if_not_exists: ci.if_not_exists,
816 }))
817}
818
819fn convert_insert(insert: sp::Insert) -> Result<Statement> {
820 let table = match &insert.table {
821 sp::TableObject::TableName(name) => object_name_to_string(name).to_ascii_lowercase(),
822 _ => return Err(SqlError::Unsupported("INSERT into non-table object".into())),
823 };
824
825 let columns: Vec<String> = insert
826 .columns
827 .iter()
828 .map(|c| c.value.to_ascii_lowercase())
829 .collect();
830
831 let source = insert
832 .source
833 .ok_or_else(|| SqlError::Parse("INSERT requires VALUES".into()))?;
834
835 let values = match *source.body {
836 sp::SetExpr::Values(sp::Values { rows, .. }) => {
837 let mut result = Vec::new();
838 for row in rows {
839 let mut exprs = Vec::new();
840 for expr in row {
841 exprs.push(convert_expr(&expr)?);
842 }
843 result.push(exprs);
844 }
845 result
846 }
847 _ => return Err(SqlError::Unsupported("INSERT ... SELECT".into())),
848 };
849
850 Ok(Statement::Insert(InsertStmt {
851 table,
852 columns,
853 values,
854 }))
855}
856
857fn convert_subquery(query: &sp::Query) -> Result<SelectStmt> {
858 match convert_query(query.clone())? {
859 Statement::Select(s) => Ok(*s),
860 _ => Err(SqlError::Unsupported("non-SELECT subquery".into())),
861 }
862}
863
864fn convert_query(query: sp::Query) -> Result<Statement> {
865 let select = match *query.body {
866 sp::SetExpr::Select(sel) => *sel,
867 _ => return Err(SqlError::Unsupported("UNION/INTERSECT/EXCEPT".into())),
868 };
869
870 let distinct = match &select.distinct {
871 Some(sp::Distinct::Distinct) => true,
872 Some(sp::Distinct::On(_)) => {
873 return Err(SqlError::Unsupported("DISTINCT ON".into()));
874 }
875 _ => false,
876 };
877
878 let (from, from_alias, joins) = if select.from.is_empty() {
880 (String::new(), None, vec![])
881 } else if select.from.len() == 1 {
882 let table_with_joins = &select.from[0];
883 let (name, alias) = match &table_with_joins.relation {
884 sp::TableFactor::Table { name, alias, .. } => {
885 let table_name = object_name_to_string(name);
886 let alias_str = alias.as_ref().map(|a| a.name.value.clone());
887 (table_name, alias_str)
888 }
889 _ => return Err(SqlError::Unsupported("non-table FROM source".into())),
890 };
891 let j = table_with_joins
892 .joins
893 .iter()
894 .map(convert_join)
895 .collect::<Result<Vec<_>>>()?;
896 (name, alias, j)
897 } else {
898 return Err(SqlError::Unsupported("comma-separated FROM tables".into()));
899 };
900
901 let columns: Vec<SelectColumn> = select
903 .projection
904 .iter()
905 .map(convert_select_item)
906 .collect::<Result<_>>()?;
907
908 let where_clause = select.selection.as_ref().map(convert_expr).transpose()?;
910
911 let order_by = if let Some(ref ob) = query.order_by {
913 match &ob.kind {
914 sp::OrderByKind::Expressions(exprs) => exprs
915 .iter()
916 .map(convert_order_by_expr)
917 .collect::<Result<_>>()?,
918 sp::OrderByKind::All { .. } => {
919 return Err(SqlError::Unsupported("ORDER BY ALL".into()));
920 }
921 }
922 } else {
923 vec![]
924 };
925
926 let (limit, offset) = match &query.limit_clause {
928 Some(sp::LimitClause::LimitOffset { limit, offset, .. }) => {
929 let l = limit.as_ref().map(convert_expr).transpose()?;
930 let o = offset
931 .as_ref()
932 .map(|o| convert_expr(&o.value))
933 .transpose()?;
934 (l, o)
935 }
936 Some(sp::LimitClause::OffsetCommaLimit { limit, offset }) => {
937 let l = Some(convert_expr(limit)?);
938 let o = Some(convert_expr(offset)?);
939 (l, o)
940 }
941 None => (None, None),
942 };
943
944 let group_by = match &select.group_by {
946 sp::GroupByExpr::Expressions(exprs, _) => {
947 exprs.iter().map(convert_expr).collect::<Result<_>>()?
948 }
949 sp::GroupByExpr::All(_) => {
950 return Err(SqlError::Unsupported("GROUP BY ALL".into()));
951 }
952 };
953
954 let having = select.having.as_ref().map(convert_expr).transpose()?;
956
957 Ok(Statement::Select(Box::new(SelectStmt {
958 columns,
959 from,
960 from_alias,
961 joins,
962 distinct,
963 where_clause,
964 order_by,
965 limit,
966 offset,
967 group_by,
968 having,
969 })))
970}
971
972fn convert_join(join: &sp::Join) -> Result<JoinClause> {
973 let (join_type, constraint) = match &join.join_operator {
974 sp::JoinOperator::Inner(c) => (JoinType::Inner, Some(c)),
975 sp::JoinOperator::Join(c) => (JoinType::Inner, Some(c)),
976 sp::JoinOperator::CrossJoin(c) => (JoinType::Cross, Some(c)),
977 sp::JoinOperator::Left(c) => (JoinType::Left, Some(c)),
978 sp::JoinOperator::LeftSemi(c) => (JoinType::Left, Some(c)),
979 sp::JoinOperator::LeftAnti(c) => (JoinType::Left, Some(c)),
980 sp::JoinOperator::Right(c) => (JoinType::Right, Some(c)),
981 sp::JoinOperator::RightSemi(c) => (JoinType::Right, Some(c)),
982 sp::JoinOperator::RightAnti(c) => (JoinType::Right, Some(c)),
983 other => return Err(SqlError::Unsupported(format!("join type: {other:?}"))),
984 };
985
986 let (name, alias) = match &join.relation {
987 sp::TableFactor::Table { name, alias, .. } => {
988 let table_name = object_name_to_string(name);
989 let alias_str = alias.as_ref().map(|a| a.name.value.clone());
990 (table_name, alias_str)
991 }
992 _ => return Err(SqlError::Unsupported("non-table JOIN source".into())),
993 };
994
995 let on_clause = match constraint {
996 Some(sp::JoinConstraint::On(expr)) => Some(convert_expr(expr)?),
997 Some(sp::JoinConstraint::None) | None => None,
998 Some(other) => return Err(SqlError::Unsupported(format!("join constraint: {other:?}"))),
999 };
1000
1001 Ok(JoinClause {
1002 join_type,
1003 table: TableRef { name, alias },
1004 on_clause,
1005 })
1006}
1007
1008fn convert_update(update: sp::Update) -> Result<Statement> {
1009 let table = match &update.table.relation {
1010 sp::TableFactor::Table { name, .. } => object_name_to_string(name),
1011 _ => return Err(SqlError::Unsupported("non-table UPDATE target".into())),
1012 };
1013
1014 let assignments = update
1015 .assignments
1016 .iter()
1017 .map(|a| {
1018 let col = match &a.target {
1019 sp::AssignmentTarget::ColumnName(name) => object_name_to_string(name),
1020 _ => return Err(SqlError::Unsupported("tuple assignment".into())),
1021 };
1022 let expr = convert_expr(&a.value)?;
1023 Ok((col, expr))
1024 })
1025 .collect::<Result<_>>()?;
1026
1027 let where_clause = update.selection.as_ref().map(convert_expr).transpose()?;
1028
1029 Ok(Statement::Update(UpdateStmt {
1030 table,
1031 assignments,
1032 where_clause,
1033 }))
1034}
1035
1036fn convert_delete(delete: sp::Delete) -> Result<Statement> {
1037 let table_name = match &delete.from {
1038 sp::FromTable::WithFromKeyword(tables) => {
1039 if tables.len() != 1 {
1040 return Err(SqlError::Unsupported("multi-table DELETE".into()));
1041 }
1042 match &tables[0].relation {
1043 sp::TableFactor::Table { name, .. } => object_name_to_string(name),
1044 _ => return Err(SqlError::Unsupported("non-table DELETE target".into())),
1045 }
1046 }
1047 sp::FromTable::WithoutKeyword(tables) => {
1048 if tables.len() != 1 {
1049 return Err(SqlError::Unsupported("multi-table DELETE".into()));
1050 }
1051 match &tables[0].relation {
1052 sp::TableFactor::Table { name, .. } => object_name_to_string(name),
1053 _ => return Err(SqlError::Unsupported("non-table DELETE target".into())),
1054 }
1055 }
1056 };
1057
1058 let where_clause = delete.selection.as_ref().map(convert_expr).transpose()?;
1059
1060 Ok(Statement::Delete(DeleteStmt {
1061 table: table_name,
1062 where_clause,
1063 }))
1064}
1065
1066fn convert_expr(expr: &sp::Expr) -> Result<Expr> {
1069 match expr {
1070 sp::Expr::Value(v) => convert_value(&v.value),
1071 sp::Expr::Identifier(ident) => Ok(Expr::Column(ident.value.to_ascii_lowercase())),
1072 sp::Expr::CompoundIdentifier(parts) => {
1073 if parts.len() == 2 {
1074 Ok(Expr::QualifiedColumn {
1075 table: parts[0].value.to_ascii_lowercase(),
1076 column: parts[1].value.to_ascii_lowercase(),
1077 })
1078 } else {
1079 Ok(Expr::Column(
1080 parts.last().unwrap().value.to_ascii_lowercase(),
1081 ))
1082 }
1083 }
1084 sp::Expr::BinaryOp { left, op, right } => {
1085 let bin_op = convert_bin_op(op)?;
1086 Ok(Expr::BinaryOp {
1087 left: Box::new(convert_expr(left)?),
1088 op: bin_op,
1089 right: Box::new(convert_expr(right)?),
1090 })
1091 }
1092 sp::Expr::UnaryOp { op, expr } => {
1093 let unary_op = match op {
1094 sp::UnaryOperator::Minus => UnaryOp::Neg,
1095 sp::UnaryOperator::Not => UnaryOp::Not,
1096 _ => return Err(SqlError::Unsupported(format!("unary op: {op}"))),
1097 };
1098 Ok(Expr::UnaryOp {
1099 op: unary_op,
1100 expr: Box::new(convert_expr(expr)?),
1101 })
1102 }
1103 sp::Expr::IsNull(e) => Ok(Expr::IsNull(Box::new(convert_expr(e)?))),
1104 sp::Expr::IsNotNull(e) => Ok(Expr::IsNotNull(Box::new(convert_expr(e)?))),
1105 sp::Expr::Nested(e) => convert_expr(e),
1106 sp::Expr::Function(func) => convert_function(func),
1107 sp::Expr::InSubquery {
1108 expr: e,
1109 subquery,
1110 negated,
1111 } => {
1112 let inner_expr = convert_expr(e)?;
1113 let stmt = convert_subquery(subquery)?;
1114 Ok(Expr::InSubquery {
1115 expr: Box::new(inner_expr),
1116 subquery: Box::new(stmt),
1117 negated: *negated,
1118 })
1119 }
1120 sp::Expr::InList {
1121 expr: e,
1122 list,
1123 negated,
1124 } => {
1125 let inner_expr = convert_expr(e)?;
1126 let items = list.iter().map(convert_expr).collect::<Result<Vec<_>>>()?;
1127 Ok(Expr::InList {
1128 expr: Box::new(inner_expr),
1129 list: items,
1130 negated: *negated,
1131 })
1132 }
1133 sp::Expr::Exists { subquery, negated } => {
1134 let stmt = convert_subquery(subquery)?;
1135 Ok(Expr::Exists {
1136 subquery: Box::new(stmt),
1137 negated: *negated,
1138 })
1139 }
1140 sp::Expr::Subquery(query) => {
1141 let stmt = convert_subquery(query)?;
1142 Ok(Expr::ScalarSubquery(Box::new(stmt)))
1143 }
1144 sp::Expr::Between {
1145 expr: e,
1146 negated,
1147 low,
1148 high,
1149 } => Ok(Expr::Between {
1150 expr: Box::new(convert_expr(e)?),
1151 low: Box::new(convert_expr(low)?),
1152 high: Box::new(convert_expr(high)?),
1153 negated: *negated,
1154 }),
1155 sp::Expr::Like {
1156 expr: e,
1157 negated,
1158 pattern,
1159 escape_char,
1160 ..
1161 } => {
1162 let esc = escape_char
1163 .as_ref()
1164 .map(convert_escape_value)
1165 .transpose()?
1166 .map(Box::new);
1167 Ok(Expr::Like {
1168 expr: Box::new(convert_expr(e)?),
1169 pattern: Box::new(convert_expr(pattern)?),
1170 escape: esc,
1171 negated: *negated,
1172 })
1173 }
1174 sp::Expr::ILike {
1175 expr: e,
1176 negated,
1177 pattern,
1178 escape_char,
1179 ..
1180 } => {
1181 let esc = escape_char
1182 .as_ref()
1183 .map(convert_escape_value)
1184 .transpose()?
1185 .map(Box::new);
1186 Ok(Expr::Like {
1187 expr: Box::new(convert_expr(e)?),
1188 pattern: Box::new(convert_expr(pattern)?),
1189 escape: esc,
1190 negated: *negated,
1191 })
1192 }
1193 sp::Expr::Case {
1194 operand,
1195 conditions,
1196 else_result,
1197 ..
1198 } => {
1199 let op = operand
1200 .as_ref()
1201 .map(|e| convert_expr(e))
1202 .transpose()?
1203 .map(Box::new);
1204 let conds: Vec<(Expr, Expr)> = conditions
1205 .iter()
1206 .map(|cw| Ok((convert_expr(&cw.condition)?, convert_expr(&cw.result)?)))
1207 .collect::<Result<_>>()?;
1208 let else_r = else_result
1209 .as_ref()
1210 .map(|e| convert_expr(e))
1211 .transpose()?
1212 .map(Box::new);
1213 Ok(Expr::Case {
1214 operand: op,
1215 conditions: conds,
1216 else_result: else_r,
1217 })
1218 }
1219 sp::Expr::Cast {
1220 expr: e,
1221 data_type: dt,
1222 ..
1223 } => {
1224 let target = convert_data_type(dt)?;
1225 Ok(Expr::Cast {
1226 expr: Box::new(convert_expr(e)?),
1227 data_type: target,
1228 })
1229 }
1230 sp::Expr::Substring {
1231 expr: e,
1232 substring_from,
1233 substring_for,
1234 ..
1235 } => {
1236 let mut args = vec![convert_expr(e)?];
1237 if let Some(from) = substring_from {
1238 args.push(convert_expr(from)?);
1239 }
1240 if let Some(f) = substring_for {
1241 args.push(convert_expr(f)?);
1242 }
1243 Ok(Expr::Function {
1244 name: "SUBSTR".into(),
1245 args,
1246 })
1247 }
1248 sp::Expr::Trim {
1249 expr: e,
1250 trim_where,
1251 trim_what,
1252 trim_characters,
1253 } => {
1254 let fn_name = match trim_where {
1255 Some(sp::TrimWhereField::Leading) => "LTRIM",
1256 Some(sp::TrimWhereField::Trailing) => "RTRIM",
1257 _ => "TRIM",
1258 };
1259 let mut args = vec![convert_expr(e)?];
1260 if let Some(what) = trim_what {
1261 args.push(convert_expr(what)?);
1262 } else if let Some(chars) = trim_characters {
1263 if let Some(first) = chars.first() {
1264 args.push(convert_expr(first)?);
1265 }
1266 }
1267 Ok(Expr::Function {
1268 name: fn_name.into(),
1269 args,
1270 })
1271 }
1272 sp::Expr::Ceil { expr: e, .. } => Ok(Expr::Function {
1273 name: "CEIL".into(),
1274 args: vec![convert_expr(e)?],
1275 }),
1276 sp::Expr::Floor { expr: e, .. } => Ok(Expr::Function {
1277 name: "FLOOR".into(),
1278 args: vec![convert_expr(e)?],
1279 }),
1280 sp::Expr::Position { expr: e, r#in } => Ok(Expr::Function {
1281 name: "INSTR".into(),
1282 args: vec![convert_expr(r#in)?, convert_expr(e)?],
1283 }),
1284 _ => Err(SqlError::Unsupported(format!("expression: {expr}"))),
1285 }
1286}
1287
1288fn convert_value(val: &sp::Value) -> Result<Expr> {
1289 match val {
1290 sp::Value::Number(n, _) => {
1291 if let Ok(i) = n.parse::<i64>() {
1292 Ok(Expr::Literal(Value::Integer(i)))
1293 } else if let Ok(f) = n.parse::<f64>() {
1294 Ok(Expr::Literal(Value::Real(f)))
1295 } else {
1296 Err(SqlError::InvalidValue(format!("cannot parse number: {n}")))
1297 }
1298 }
1299 sp::Value::SingleQuotedString(s) => Ok(Expr::Literal(Value::Text(s.as_str().into()))),
1300 sp::Value::Boolean(b) => Ok(Expr::Literal(Value::Boolean(*b))),
1301 sp::Value::Null => Ok(Expr::Literal(Value::Null)),
1302 sp::Value::Placeholder(s) => {
1303 let idx_str = s
1304 .strip_prefix('$')
1305 .ok_or_else(|| SqlError::Parse(format!("invalid placeholder: {s}")))?;
1306 let idx: usize = idx_str
1307 .parse()
1308 .map_err(|_| SqlError::Parse(format!("invalid placeholder index: {s}")))?;
1309 if idx == 0 {
1310 return Err(SqlError::Parse("placeholder index must be >= 1".into()));
1311 }
1312 Ok(Expr::Parameter(idx))
1313 }
1314 _ => Err(SqlError::Unsupported(format!("value type: {val}"))),
1315 }
1316}
1317
1318fn convert_escape_value(val: &sp::Value) -> Result<Expr> {
1319 match val {
1320 sp::Value::SingleQuotedString(s) => Ok(Expr::Literal(Value::Text(s.as_str().into()))),
1321 _ => Err(SqlError::Unsupported(format!("ESCAPE value: {val}"))),
1322 }
1323}
1324
1325fn convert_bin_op(op: &sp::BinaryOperator) -> Result<BinOp> {
1326 match op {
1327 sp::BinaryOperator::Plus => Ok(BinOp::Add),
1328 sp::BinaryOperator::Minus => Ok(BinOp::Sub),
1329 sp::BinaryOperator::Multiply => Ok(BinOp::Mul),
1330 sp::BinaryOperator::Divide => Ok(BinOp::Div),
1331 sp::BinaryOperator::Modulo => Ok(BinOp::Mod),
1332 sp::BinaryOperator::Eq => Ok(BinOp::Eq),
1333 sp::BinaryOperator::NotEq => Ok(BinOp::NotEq),
1334 sp::BinaryOperator::Lt => Ok(BinOp::Lt),
1335 sp::BinaryOperator::Gt => Ok(BinOp::Gt),
1336 sp::BinaryOperator::LtEq => Ok(BinOp::LtEq),
1337 sp::BinaryOperator::GtEq => Ok(BinOp::GtEq),
1338 sp::BinaryOperator::And => Ok(BinOp::And),
1339 sp::BinaryOperator::Or => Ok(BinOp::Or),
1340 sp::BinaryOperator::StringConcat => Ok(BinOp::Concat),
1341 _ => Err(SqlError::Unsupported(format!("binary op: {op}"))),
1342 }
1343}
1344
1345fn convert_function(func: &sp::Function) -> Result<Expr> {
1346 let name = object_name_to_string(&func.name).to_ascii_uppercase();
1347
1348 match &func.args {
1350 sp::FunctionArguments::List(list) => {
1351 if list.args.is_empty() && name == "COUNT" {
1352 return Ok(Expr::CountStar);
1353 }
1354 let args = list
1355 .args
1356 .iter()
1357 .map(|arg| match arg {
1358 sp::FunctionArg::Unnamed(sp::FunctionArgExpr::Expr(e)) => convert_expr(e),
1359 sp::FunctionArg::Unnamed(sp::FunctionArgExpr::Wildcard) => {
1360 if name == "COUNT" {
1361 Ok(Expr::CountStar)
1362 } else {
1363 Err(SqlError::Unsupported(format!("{name}(*)")))
1364 }
1365 }
1366 _ => Err(SqlError::Unsupported(format!(
1367 "function arg type in {name}"
1368 ))),
1369 })
1370 .collect::<Result<Vec<_>>>()?;
1371
1372 if name == "COUNT" && args.len() == 1 && matches!(args[0], Expr::CountStar) {
1373 return Ok(Expr::CountStar);
1374 }
1375
1376 if name == "COALESCE" {
1377 if args.is_empty() {
1378 return Err(SqlError::Parse(
1379 "COALESCE requires at least one argument".into(),
1380 ));
1381 }
1382 return Ok(Expr::Coalesce(args));
1383 }
1384
1385 if name == "NULLIF" {
1386 if args.len() != 2 {
1387 return Err(SqlError::Parse(
1388 "NULLIF requires exactly two arguments".into(),
1389 ));
1390 }
1391 return Ok(Expr::Case {
1392 operand: None,
1393 conditions: vec![(
1394 Expr::BinaryOp {
1395 left: Box::new(args[0].clone()),
1396 op: BinOp::Eq,
1397 right: Box::new(args[1].clone()),
1398 },
1399 Expr::Literal(Value::Null),
1400 )],
1401 else_result: Some(Box::new(args[0].clone())),
1402 });
1403 }
1404
1405 if name == "IIF" {
1406 if args.len() != 3 {
1407 return Err(SqlError::Parse(
1408 "IIF requires exactly three arguments".into(),
1409 ));
1410 }
1411 return Ok(Expr::Case {
1412 operand: None,
1413 conditions: vec![(args[0].clone(), args[1].clone())],
1414 else_result: Some(Box::new(args[2].clone())),
1415 });
1416 }
1417
1418 Ok(Expr::Function { name, args })
1419 }
1420 sp::FunctionArguments::None => {
1421 if name == "COUNT" {
1422 Ok(Expr::CountStar)
1423 } else {
1424 Ok(Expr::Function { name, args: vec![] })
1425 }
1426 }
1427 sp::FunctionArguments::Subquery(_) => {
1428 Err(SqlError::Unsupported("subquery in function".into()))
1429 }
1430 }
1431}
1432
1433fn convert_select_item(item: &sp::SelectItem) -> Result<SelectColumn> {
1434 match item {
1435 sp::SelectItem::Wildcard(_) => Ok(SelectColumn::AllColumns),
1436 sp::SelectItem::UnnamedExpr(e) => {
1437 let expr = convert_expr(e)?;
1438 Ok(SelectColumn::Expr { expr, alias: None })
1439 }
1440 sp::SelectItem::ExprWithAlias { expr, alias } => {
1441 let expr = convert_expr(expr)?;
1442 Ok(SelectColumn::Expr {
1443 expr,
1444 alias: Some(alias.value.clone()),
1445 })
1446 }
1447 sp::SelectItem::QualifiedWildcard(_, _) => {
1448 Err(SqlError::Unsupported("qualified wildcard (table.*)".into()))
1449 }
1450 }
1451}
1452
1453fn convert_order_by_expr(expr: &sp::OrderByExpr) -> Result<OrderByItem> {
1454 let e = convert_expr(&expr.expr)?;
1455 let descending = expr.options.asc.map(|asc| !asc).unwrap_or(false);
1456 let nulls_first = expr.options.nulls_first;
1457
1458 Ok(OrderByItem {
1459 expr: e,
1460 descending,
1461 nulls_first,
1462 })
1463}
1464
1465fn convert_data_type(dt: &sp::DataType) -> Result<DataType> {
1468 match dt {
1469 sp::DataType::Int(_)
1470 | sp::DataType::Integer(_)
1471 | sp::DataType::BigInt(_)
1472 | sp::DataType::SmallInt(_)
1473 | sp::DataType::TinyInt(_)
1474 | sp::DataType::Int2(_)
1475 | sp::DataType::Int4(_)
1476 | sp::DataType::Int8(_) => Ok(DataType::Integer),
1477
1478 sp::DataType::Real
1479 | sp::DataType::Double(..)
1480 | sp::DataType::DoublePrecision
1481 | sp::DataType::Float(_)
1482 | sp::DataType::Float4
1483 | sp::DataType::Float64 => Ok(DataType::Real),
1484
1485 sp::DataType::Varchar(_)
1486 | sp::DataType::Text
1487 | sp::DataType::Char(_)
1488 | sp::DataType::Character(_)
1489 | sp::DataType::String(_) => Ok(DataType::Text),
1490
1491 sp::DataType::Blob(_) | sp::DataType::Bytea => Ok(DataType::Blob),
1492
1493 sp::DataType::Boolean | sp::DataType::Bool => Ok(DataType::Boolean),
1494
1495 _ => Err(SqlError::Unsupported(format!("data type: {dt}"))),
1496 }
1497}
1498
1499fn object_name_to_string(name: &sp::ObjectName) -> String {
1502 name.0
1503 .iter()
1504 .filter_map(|p| match p {
1505 sp::ObjectNamePart::Identifier(ident) => Some(ident.value.clone()),
1506 _ => None,
1507 })
1508 .collect::<Vec<_>>()
1509 .join(".")
1510}
1511
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `sql`, expects a `SELECT`, and returns a clone of the expression
    /// in its first projected column.
    fn first_select_expr(sql: &str) -> Expr {
        match parse_sql(sql).unwrap() {
            Statement::Select(sel) => match sel.columns.first() {
                Some(SelectColumn::Expr { expr, .. }) => expr.clone(),
                other => panic!("expected expression column, got {other:?}"),
            },
            other => panic!("expected Select, got {other:?}"),
        }
    }

    #[test]
    fn parse_create_table() {
        let stmt = parse_sql(
            "CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT NOT NULL, age INTEGER)",
        )
        .unwrap();

        match stmt {
            Statement::CreateTable(ct) => {
                assert_eq!(ct.name, "users");
                assert_eq!(ct.columns.len(), 3);
                assert_eq!(ct.columns[0].name, "id");
                assert_eq!(ct.columns[0].data_type, DataType::Integer);
                assert!(ct.columns[0].is_primary_key);
                assert!(!ct.columns[0].nullable);
                assert_eq!(ct.columns[1].name, "name");
                assert_eq!(ct.columns[1].data_type, DataType::Text);
                assert!(!ct.columns[1].nullable);
                assert_eq!(ct.columns[2].name, "age");
                assert!(ct.columns[2].nullable);
                assert_eq!(ct.primary_key, vec!["id"]);
            }
            _ => panic!("expected CreateTable"),
        }
    }

    #[test]
    fn parse_create_table_if_not_exists() {
        let stmt = parse_sql("CREATE TABLE IF NOT EXISTS t (id INT PRIMARY KEY)").unwrap();
        match stmt {
            Statement::CreateTable(ct) => assert!(ct.if_not_exists),
            _ => panic!("expected CreateTable"),
        }
    }

    #[test]
    fn parse_drop_table() {
        let stmt = parse_sql("DROP TABLE users").unwrap();
        match stmt {
            Statement::DropTable(dt) => {
                assert_eq!(dt.name, "users");
                assert!(!dt.if_exists);
            }
            _ => panic!("expected DropTable"),
        }
    }

    #[test]
    fn parse_drop_table_if_exists() {
        let stmt = parse_sql("DROP TABLE IF EXISTS users").unwrap();
        match stmt {
            Statement::DropTable(dt) => assert!(dt.if_exists),
            _ => panic!("expected DropTable"),
        }
    }

    #[test]
    fn parse_insert() {
        let stmt =
            parse_sql("INSERT INTO users (id, name) VALUES (1, 'Alice'), (2, 'Bob')").unwrap();

        match stmt {
            Statement::Insert(ins) => {
                assert_eq!(ins.table, "users");
                assert_eq!(ins.columns, vec!["id", "name"]);
                assert_eq!(ins.values.len(), 2);
                assert!(matches!(ins.values[0][0], Expr::Literal(Value::Integer(1))));
                assert!(matches!(&ins.values[0][1], Expr::Literal(Value::Text(s)) if s == "Alice"));
            }
            _ => panic!("expected Insert"),
        }
    }

    #[test]
    fn parse_select_all() {
        let stmt = parse_sql("SELECT * FROM users").unwrap();
        match stmt {
            Statement::Select(sel) => {
                assert_eq!(sel.from, "users");
                assert!(matches!(sel.columns[0], SelectColumn::AllColumns));
                assert!(sel.where_clause.is_none());
            }
            _ => panic!("expected Select"),
        }
    }

    #[test]
    fn parse_select_where() {
        let stmt = parse_sql("SELECT id, name FROM users WHERE age > 18").unwrap();
        match stmt {
            Statement::Select(sel) => {
                assert_eq!(sel.columns.len(), 2);
                assert!(sel.where_clause.is_some());
            }
            _ => panic!("expected Select"),
        }
    }

    #[test]
    fn parse_select_order_limit() {
        let stmt = parse_sql("SELECT * FROM users ORDER BY name ASC LIMIT 10 OFFSET 5").unwrap();
        match stmt {
            Statement::Select(sel) => {
                assert_eq!(sel.order_by.len(), 1);
                assert!(!sel.order_by[0].descending);
                assert!(sel.limit.is_some());
                assert!(sel.offset.is_some());
            }
            _ => panic!("expected Select"),
        }
    }

    #[test]
    fn parse_order_by_desc_nulls_first() {
        let stmt = parse_sql("SELECT * FROM t ORDER BY a DESC NULLS FIRST").unwrap();
        match stmt {
            Statement::Select(sel) => {
                assert_eq!(sel.order_by.len(), 1);
                assert!(sel.order_by[0].descending);
                assert_eq!(sel.order_by[0].nulls_first, Some(true));
            }
            _ => panic!("expected Select"),
        }
    }

    #[test]
    fn parse_update() {
        let stmt = parse_sql("UPDATE users SET name = 'Bob' WHERE id = 1").unwrap();
        match stmt {
            Statement::Update(upd) => {
                assert_eq!(upd.table, "users");
                assert_eq!(upd.assignments.len(), 1);
                assert_eq!(upd.assignments[0].0, "name");
                assert!(upd.where_clause.is_some());
            }
            _ => panic!("expected Update"),
        }
    }

    #[test]
    fn parse_delete() {
        let stmt = parse_sql("DELETE FROM users WHERE id = 1").unwrap();
        match stmt {
            Statement::Delete(del) => {
                assert_eq!(del.table, "users");
                assert!(del.where_clause.is_some());
            }
            _ => panic!("expected Delete"),
        }
    }

    #[test]
    fn parse_aggregate() {
        let stmt = parse_sql("SELECT COUNT(*), SUM(age) FROM users").unwrap();
        match stmt {
            Statement::Select(sel) => {
                assert_eq!(sel.columns.len(), 2);
                match &sel.columns[0] {
                    SelectColumn::Expr {
                        expr: Expr::CountStar,
                        ..
                    } => {}
                    other => panic!("expected CountStar, got {other:?}"),
                }
            }
            _ => panic!("expected Select"),
        }
    }

    #[test]
    fn parse_count_empty_args() {
        assert!(matches!(
            first_select_expr("SELECT COUNT() FROM t"),
            Expr::CountStar
        ));
    }

    #[test]
    fn parse_group_by_having() {
        let stmt = parse_sql(
            "SELECT department, COUNT(*) FROM employees GROUP BY department HAVING COUNT(*) > 5",
        )
        .unwrap();
        match stmt {
            Statement::Select(sel) => {
                assert_eq!(sel.group_by.len(), 1);
                assert!(sel.having.is_some());
            }
            _ => panic!("expected Select"),
        }
    }

    #[test]
    fn parse_expressions() {
        let stmt = parse_sql("SELECT id + 1, -price, NOT active FROM items").unwrap();
        match stmt {
            Statement::Select(sel) => {
                assert_eq!(sel.columns.len(), 3);
                match &sel.columns[0] {
                    SelectColumn::Expr {
                        expr: Expr::BinaryOp { op: BinOp::Add, .. },
                        ..
                    } => {}
                    other => panic!("expected BinaryOp Add, got {other:?}"),
                }
                match &sel.columns[1] {
                    SelectColumn::Expr {
                        expr:
                            Expr::UnaryOp {
                                op: UnaryOp::Neg, ..
                            },
                        ..
                    } => {}
                    other => panic!("expected UnaryOp Neg, got {other:?}"),
                }
                match &sel.columns[2] {
                    SelectColumn::Expr {
                        expr:
                            Expr::UnaryOp {
                                op: UnaryOp::Not, ..
                            },
                        ..
                    } => {}
                    other => panic!("expected UnaryOp Not, got {other:?}"),
                }
            }
            _ => panic!("expected Select"),
        }
    }

    #[test]
    fn parse_is_null() {
        let stmt = parse_sql("SELECT * FROM t WHERE x IS NULL").unwrap();
        match stmt {
            Statement::Select(sel) => {
                assert!(matches!(sel.where_clause, Some(Expr::IsNull(_))));
            }
            _ => panic!("expected Select"),
        }
    }

    #[test]
    fn parse_coalesce() {
        match first_select_expr("SELECT COALESCE(a, b, 0) FROM t") {
            Expr::Coalesce(args) => assert_eq!(args.len(), 3),
            other => panic!("expected Coalesce, got {other:?}"),
        }
    }

    #[test]
    fn reject_empty_coalesce() {
        assert!(parse_sql("SELECT COALESCE() FROM t").is_err());
    }

    #[test]
    fn parse_nullif_as_case() {
        match first_select_expr("SELECT NULLIF(a, 0) FROM t") {
            Expr::Case {
                operand: None,
                conditions,
                else_result: Some(_),
            } => {
                assert_eq!(conditions.len(), 1);
                assert!(matches!(
                    conditions[0].0,
                    Expr::BinaryOp { op: BinOp::Eq, .. }
                ));
                assert!(matches!(conditions[0].1, Expr::Literal(Value::Null)));
            }
            other => panic!("expected Case, got {other:?}"),
        }
    }

    #[test]
    fn reject_nullif_wrong_arity() {
        assert!(parse_sql("SELECT NULLIF(a) FROM t").is_err());
    }

    #[test]
    fn parse_iif_as_case() {
        match first_select_expr("SELECT IIF(a > 1, 'big', 'small') FROM t") {
            Expr::Case {
                operand: None,
                conditions,
                else_result: Some(else_expr),
            } => {
                assert_eq!(conditions.len(), 1);
                assert!(matches!(
                    conditions[0].0,
                    Expr::BinaryOp { op: BinOp::Gt, .. }
                ));
                assert!(
                    matches!(else_expr.as_ref(), Expr::Literal(Value::Text(s)) if s == "small")
                );
            }
            other => panic!("expected Case, got {other:?}"),
        }
    }

    #[test]
    fn parse_cast() {
        match first_select_expr("SELECT CAST(a AS INTEGER) FROM t") {
            Expr::Cast { data_type, .. } => assert_eq!(data_type, DataType::Integer),
            other => panic!("expected Cast, got {other:?}"),
        }
    }

    #[test]
    fn parse_substring_as_substr() {
        match first_select_expr("SELECT SUBSTRING(s FROM 2 FOR 3) FROM t") {
            Expr::Function { name, args } => {
                assert_eq!(name, "SUBSTR");
                assert_eq!(args.len(), 3);
            }
            other => panic!("expected Function, got {other:?}"),
        }
    }

    #[test]
    fn parse_trim_leading_as_ltrim() {
        match first_select_expr("SELECT TRIM(LEADING 'x' FROM s) FROM t") {
            Expr::Function { name, args } => {
                assert_eq!(name, "LTRIM");
                assert_eq!(args.len(), 2);
                assert!(matches!(&args[1], Expr::Literal(Value::Text(c)) if c == "x"));
            }
            other => panic!("expected Function, got {other:?}"),
        }
    }

    #[test]
    fn parse_position_as_instr() {
        match first_select_expr("SELECT POSITION('x' IN s) FROM t") {
            Expr::Function { name, args } => {
                assert_eq!(name, "INSTR");
                assert_eq!(args.len(), 2);
                // INSTR takes the haystack first, so the needle is second.
                assert!(matches!(&args[1], Expr::Literal(Value::Text(c)) if c == "x"));
            }
            other => panic!("expected Function, got {other:?}"),
        }
    }

    #[test]
    fn parse_inner_join() {
        let stmt = parse_sql("SELECT * FROM a JOIN b ON a.id = b.id").unwrap();
        match stmt {
            Statement::Select(sel) => {
                assert_eq!(sel.from, "a");
                assert_eq!(sel.joins.len(), 1);
                assert_eq!(sel.joins[0].join_type, JoinType::Inner);
                assert_eq!(sel.joins[0].table.name, "b");
                assert!(sel.joins[0].on_clause.is_some());
            }
            _ => panic!("expected Select"),
        }
    }

    #[test]
    fn parse_inner_join_explicit() {
        let stmt = parse_sql("SELECT * FROM a INNER JOIN b ON a.id = b.a_id").unwrap();
        match stmt {
            Statement::Select(sel) => {
                assert_eq!(sel.joins.len(), 1);
                assert_eq!(sel.joins[0].join_type, JoinType::Inner);
            }
            _ => panic!("expected Select"),
        }
    }

    #[test]
    fn parse_cross_join() {
        let stmt = parse_sql("SELECT * FROM a CROSS JOIN b").unwrap();
        match stmt {
            Statement::Select(sel) => {
                assert_eq!(sel.joins.len(), 1);
                assert_eq!(sel.joins[0].join_type, JoinType::Cross);
                assert!(sel.joins[0].on_clause.is_none());
            }
            _ => panic!("expected Select"),
        }
    }

    #[test]
    fn parse_left_join() {
        let stmt = parse_sql("SELECT * FROM a LEFT JOIN b ON a.id = b.a_id").unwrap();
        match stmt {
            Statement::Select(sel) => {
                assert_eq!(sel.joins.len(), 1);
                assert_eq!(sel.joins[0].join_type, JoinType::Left);
            }
            _ => panic!("expected Select"),
        }
    }

    #[test]
    fn parse_table_alias() {
        let stmt = parse_sql("SELECT u.id FROM users u JOIN orders o ON u.id = o.user_id").unwrap();
        match stmt {
            Statement::Select(sel) => {
                assert_eq!(sel.from, "users");
                assert_eq!(sel.from_alias.as_deref(), Some("u"));
                assert_eq!(sel.joins[0].table.name, "orders");
                assert_eq!(sel.joins[0].table.alias.as_deref(), Some("o"));
            }
            _ => panic!("expected Select"),
        }
    }

    #[test]
    fn parse_multi_join() {
        let stmt =
            parse_sql("SELECT * FROM a JOIN b ON a.id = b.a_id JOIN c ON b.id = c.b_id").unwrap();
        match stmt {
            Statement::Select(sel) => {
                assert_eq!(sel.joins.len(), 2);
            }
            _ => panic!("expected Select"),
        }
    }

    #[test]
    fn parse_qualified_column() {
        let stmt = parse_sql("SELECT u.id, u.name FROM users u").unwrap();
        match stmt {
            Statement::Select(sel) => match &sel.columns[0] {
                SelectColumn::Expr {
                    expr: Expr::QualifiedColumn { table, column },
                    ..
                } => {
                    assert_eq!(table, "u");
                    assert_eq!(column, "id");
                }
                other => panic!("expected QualifiedColumn, got {other:?}"),
            },
            _ => panic!("expected Select"),
        }
    }

    #[test]
    fn reject_qualified_wildcard() {
        assert!(parse_sql("SELECT u.* FROM users u").is_err());
    }

    #[test]
    fn reject_subquery() {
        assert!(parse_sql("SELECT * FROM (SELECT 1)").is_err());
    }

    #[test]
    fn parse_type_mapping() {
        let stmt = parse_sql(
            "CREATE TABLE t (a INT PRIMARY KEY, b BIGINT, c SMALLINT, d REAL, e DOUBLE PRECISION, f VARCHAR(255), g BOOLEAN, h BLOB, i BYTEA)",
        )
        .unwrap();
        let expected = [
            DataType::Integer,
            DataType::Integer,
            DataType::Integer,
            DataType::Real,
            DataType::Real,
            DataType::Text,
            DataType::Boolean,
            DataType::Blob,
            DataType::Blob,
        ];
        match stmt {
            Statement::CreateTable(ct) => {
                assert_eq!(ct.columns.len(), expected.len());
                for (col, want) in ct.columns.iter().zip(&expected) {
                    assert_eq!(&col.data_type, want, "column {}", col.name);
                }
            }
            _ => panic!("expected CreateTable"),
        }
    }

    #[test]
    fn parse_boolean_literals() {
        let stmt = parse_sql("INSERT INTO t (a, b) VALUES (true, false)").unwrap();
        match stmt {
            Statement::Insert(ins) => {
                assert!(matches!(
                    ins.values[0][0],
                    Expr::Literal(Value::Boolean(true))
                ));
                assert!(matches!(
                    ins.values[0][1],
                    Expr::Literal(Value::Boolean(false))
                ));
            }
            _ => panic!("expected Insert"),
        }
    }

    #[test]
    fn parse_null_literal() {
        let stmt = parse_sql("INSERT INTO t (a) VALUES (NULL)").unwrap();
        match stmt {
            Statement::Insert(ins) => {
                assert!(matches!(ins.values[0][0], Expr::Literal(Value::Null)));
            }
            _ => panic!("expected Insert"),
        }
    }

    #[test]
    fn parse_real_literal() {
        let stmt = parse_sql("INSERT INTO t (a) VALUES (1.5)").unwrap();
        match stmt {
            Statement::Insert(ins) => {
                assert!(matches!(ins.values[0][0], Expr::Literal(Value::Real(f)) if f == 1.5));
            }
            _ => panic!("expected Insert"),
        }
    }

    #[test]
    fn parse_alias() {
        let stmt = parse_sql("SELECT id AS user_id FROM users").unwrap();
        match stmt {
            Statement::Select(sel) => match &sel.columns[0] {
                SelectColumn::Expr { alias: Some(a), .. } => assert_eq!(a, "user_id"),
                other => panic!("expected alias, got {other:?}"),
            },
            _ => panic!("expected Select"),
        }
    }

    #[test]
    fn parse_begin() {
        let stmt = parse_sql("BEGIN").unwrap();
        assert!(matches!(stmt, Statement::Begin));
    }

    #[test]
    fn parse_begin_transaction() {
        let stmt = parse_sql("BEGIN TRANSACTION").unwrap();
        assert!(matches!(stmt, Statement::Begin));
    }

    #[test]
    fn parse_commit() {
        let stmt = parse_sql("COMMIT").unwrap();
        assert!(matches!(stmt, Statement::Commit));
    }

    #[test]
    fn parse_rollback() {
        let stmt = parse_sql("ROLLBACK").unwrap();
        assert!(matches!(stmt, Statement::Rollback));
    }

    #[test]
    fn parse_select_distinct() {
        let stmt = parse_sql("SELECT DISTINCT name FROM users").unwrap();
        match stmt {
            Statement::Select(sel) => {
                assert!(sel.distinct);
                assert_eq!(sel.columns.len(), 1);
            }
            _ => panic!("expected Select"),
        }
    }

    #[test]
    fn parse_select_without_distinct() {
        let stmt = parse_sql("SELECT name FROM users").unwrap();
        match stmt {
            Statement::Select(sel) => {
                assert!(!sel.distinct);
            }
            _ => panic!("expected Select"),
        }
    }

    #[test]
    fn parse_select_distinct_all_columns() {
        let stmt = parse_sql("SELECT DISTINCT * FROM users").unwrap();
        match stmt {
            Statement::Select(sel) => {
                assert!(sel.distinct);
                assert!(matches!(sel.columns[0], SelectColumn::AllColumns));
            }
            _ => panic!("expected Select"),
        }
    }

    #[test]
    fn reject_distinct_on() {
        assert!(parse_sql("SELECT DISTINCT ON (id) * FROM users").is_err());
    }

    #[test]
    fn parse_create_index() {
        let stmt = parse_sql("CREATE INDEX idx_name ON users (name)").unwrap();
        match stmt {
            Statement::CreateIndex(ci) => {
                assert_eq!(ci.index_name, "idx_name");
                assert_eq!(ci.table_name, "users");
                assert_eq!(ci.columns, vec!["name"]);
                assert!(!ci.unique);
                assert!(!ci.if_not_exists);
            }
            _ => panic!("expected CreateIndex"),
        }
    }

    #[test]
    fn parse_create_unique_index() {
        let stmt = parse_sql("CREATE UNIQUE INDEX idx_email ON users (email)").unwrap();
        match stmt {
            Statement::CreateIndex(ci) => {
                assert!(ci.unique);
                assert_eq!(ci.columns, vec!["email"]);
            }
            _ => panic!("expected CreateIndex"),
        }
    }

    #[test]
    fn parse_create_index_if_not_exists() {
        let stmt = parse_sql("CREATE INDEX IF NOT EXISTS idx_x ON t (a)").unwrap();
        match stmt {
            Statement::CreateIndex(ci) => assert!(ci.if_not_exists),
            _ => panic!("expected CreateIndex"),
        }
    }

    #[test]
    fn parse_create_index_multi_column() {
        let stmt = parse_sql("CREATE INDEX idx_multi ON t (a, b, c)").unwrap();
        match stmt {
            Statement::CreateIndex(ci) => {
                assert_eq!(ci.columns, vec!["a", "b", "c"]);
            }
            _ => panic!("expected CreateIndex"),
        }
    }

    #[test]
    fn parse_drop_index() {
        let stmt = parse_sql("DROP INDEX idx_name").unwrap();
        match stmt {
            Statement::DropIndex(di) => {
                assert_eq!(di.index_name, "idx_name");
                assert!(!di.if_exists);
            }
            _ => panic!("expected DropIndex"),
        }
    }

    #[test]
    fn parse_drop_index_if_exists() {
        let stmt = parse_sql("DROP INDEX IF EXISTS idx_name").unwrap();
        match stmt {
            Statement::DropIndex(di) => {
                assert!(di.if_exists);
            }
            _ => panic!("expected DropIndex"),
        }
    }

    #[test]
    fn parse_explain_select() {
        let stmt = parse_sql("EXPLAIN SELECT * FROM users WHERE id = 1").unwrap();
        match stmt {
            Statement::Explain(inner) => {
                assert!(matches!(*inner, Statement::Select(_)));
            }
            _ => panic!("expected Explain"),
        }
    }

    #[test]
    fn parse_explain_insert() {
        let stmt = parse_sql("EXPLAIN INSERT INTO t (a) VALUES (1)").unwrap();
        assert!(matches!(stmt, Statement::Explain(_)));
    }

    #[test]
    fn reject_explain_analyze() {
        assert!(parse_sql("EXPLAIN ANALYZE SELECT * FROM t").is_err());
    }

    #[test]
    fn parse_parameter_placeholder() {
        let stmt = parse_sql("SELECT * FROM t WHERE id = $1").unwrap();
        match stmt {
            Statement::Select(sel) => match &sel.where_clause {
                Some(Expr::BinaryOp { right, .. }) => {
                    assert!(matches!(right.as_ref(), Expr::Parameter(1)));
                }
                other => panic!("expected BinaryOp with Parameter, got {other:?}"),
            },
            _ => panic!("expected Select"),
        }
    }

    #[test]
    fn parse_multiple_parameters() {
        let stmt = parse_sql("INSERT INTO t (a, b) VALUES ($1, $2)").unwrap();
        match stmt {
            Statement::Insert(ins) => {
                assert!(matches!(ins.values[0][0], Expr::Parameter(1)));
                assert!(matches!(ins.values[0][1], Expr::Parameter(2)));
            }
            _ => panic!("expected Insert"),
        }
    }

    #[test]
    fn reject_zero_parameter() {
        assert!(parse_sql("SELECT $0 FROM t").is_err());
    }

    #[test]
    fn count_params_basic() {
        let stmt = parse_sql("SELECT * FROM t WHERE a = $1 AND b = $3").unwrap();
        assert_eq!(count_params(&stmt), 3);
    }

    #[test]
    fn count_params_none() {
        let stmt = parse_sql("SELECT * FROM t WHERE a = 1").unwrap();
        assert_eq!(count_params(&stmt), 0);
    }

    #[test]
    fn bind_params_basic() {
        let stmt = parse_sql("SELECT * FROM t WHERE id = $1").unwrap();
        let bound = bind_params(&stmt, &[Value::Integer(42)]).unwrap();
        match bound {
            Statement::Select(sel) => match &sel.where_clause {
                Some(Expr::BinaryOp { right, .. }) => {
                    assert!(matches!(right.as_ref(), Expr::Literal(Value::Integer(42))));
                }
                other => panic!("expected BinaryOp with Literal, got {other:?}"),
            },
            _ => panic!("expected Select"),
        }
    }

    #[test]
    fn bind_params_out_of_range() {
        let stmt = parse_sql("SELECT * FROM t WHERE id = $2").unwrap();
        let result = bind_params(&stmt, &[Value::Integer(1)]);
        assert!(result.is_err());
    }

    #[test]
    fn parse_table_constraint_pk() {
        let stmt = parse_sql("CREATE TABLE t (a INTEGER, b TEXT, PRIMARY KEY (a))").unwrap();
        match stmt {
            Statement::CreateTable(ct) => {
                assert_eq!(ct.primary_key, vec!["a"]);
                assert!(ct.columns[0].is_primary_key);
                assert!(!ct.columns[0].nullable);
            }
            _ => panic!("expected CreateTable"),
        }
    }
}