1use sqlparser::ast as sp;
4use sqlparser::dialect::GenericDialect;
5use sqlparser::parser::Parser;
6
7use crate::error::{Result, SqlError};
8use crate::types::{DataType, Value};
9
/// A single SQL statement converted from the `sqlparser` AST into this crate's own
/// representation (produced by `parse_sql` / `parse_sql_multi`).
#[derive(Debug, Clone)]
pub enum Statement {
    /// `CREATE TABLE ...`
    CreateTable(CreateTableStmt),
    /// `DROP TABLE [IF EXISTS] name`; multi-table DROP is rejected during conversion.
    DropTable(DropTableStmt),
    /// `CREATE [UNIQUE] INDEX ...`
    CreateIndex(CreateIndexStmt),
    /// `DROP INDEX [IF EXISTS] name`; multi-index DROP is rejected during conversion.
    DropIndex(DropIndexStmt),
    /// `CREATE [OR REPLACE] VIEW ...`; materialized views are rejected during conversion.
    CreateView(CreateViewStmt),
    /// `DROP VIEW [IF EXISTS] name`; multi-view DROP is rejected during conversion.
    DropView(DropViewStmt),
    /// `ALTER TABLE` carrying exactly one operation.
    AlterTable(Box<AlterTableStmt>),
    /// `INSERT INTO table [(cols)] VALUES ... | SELECT ...`
    Insert(InsertStmt),
    /// A top-level query: SELECT / set operation with optional `WITH` clause.
    Select(Box<SelectQuery>),
    /// `UPDATE ...`
    Update(UpdateStmt),
    /// `DELETE FROM ...`
    Delete(DeleteStmt),
    /// Transaction start (converted from sqlparser's `StartTransaction`).
    Begin,
    /// `COMMIT`; `COMMIT AND CHAIN` is rejected during conversion.
    Commit,
    /// `ROLLBACK` without a savepoint target; `ROLLBACK AND CHAIN` is rejected.
    Rollback,
    /// `SAVEPOINT name`; the name is ASCII-lowercased during conversion.
    Savepoint(String),
    /// `RELEASE SAVEPOINT name`; the name is ASCII-lowercased during conversion.
    ReleaseSavepoint(String),
    /// `ROLLBACK TO [SAVEPOINT] name`; the name is ASCII-lowercased during conversion.
    RollbackTo(String),
    /// `SET TIME ZONE zone`; quoted string literals are stored without their quotes.
    SetTimezone(String),
    /// `EXPLAIN stmt`; `EXPLAIN ANALYZE` is rejected during conversion.
    Explain(Box<Statement>),
}
32
/// `ALTER TABLE <table> <op>`. Only statements with exactly one operation are produced.
#[derive(Debug, Clone)]
pub struct AlterTableStmt {
    /// Name of the table being altered.
    pub table: String,
    /// The single alteration to apply.
    pub op: AlterTableOp,
}

/// The `ALTER TABLE` operations this front end understands.
#[derive(Debug, Clone)]
pub enum AlterTableOp {
    /// `ADD [COLUMN] [IF NOT EXISTS] <column definition>`.
    AddColumn {
        /// Definition of the new column.
        column: Box<ColumnSpec>,
        /// Inline `REFERENCES ...` clause on the new column, if any.
        foreign_key: Option<ForeignKeyDef>,
        if_not_exists: bool,
    },
    /// `DROP [COLUMN] [IF EXISTS] name`; exactly one column.
    DropColumn {
        name: String,
        if_exists: bool,
    },
    /// `RENAME COLUMN old_name TO new_name`.
    RenameColumn {
        old_name: String,
        new_name: String,
    },
    /// `RENAME TO new_name` (or `RENAME AS new_name`).
    RenameTable {
        new_name: String,
    },
}
58
/// `CREATE TABLE` after inline and table-level constraints have been collected.
#[derive(Debug, Clone)]
pub struct CreateTableStmt {
    /// Table name.
    pub name: String,
    /// Column definitions, in declaration order.
    pub columns: Vec<ColumnSpec>,
    /// Primary-key column names, gathered from inline `PRIMARY KEY` options and from a
    /// table-level `PRIMARY KEY (...)` constraint.
    pub primary_key: Vec<String>,
    pub if_not_exists: bool,
    /// Table-level `CHECK (...)` constraints (column-level ones live on [`ColumnSpec`]).
    pub check_constraints: Vec<TableCheckConstraint>,
    /// Both inline `REFERENCES` clauses and table-level `FOREIGN KEY` constraints.
    pub foreign_keys: Vec<ForeignKeyDef>,
    /// Inline `UNIQUE` on non-primary-key columns plus table-level `UNIQUE (...)`.
    pub unique_indices: Vec<UniqueIndexDef>,
}

/// A uniqueness requirement declared in `CREATE TABLE`.
#[derive(Debug, Clone)]
pub struct UniqueIndexDef {
    /// Constraint name; `None` for inline `UNIQUE` and unnamed table constraints.
    pub name: Option<String>,
    /// Covered column names, ASCII-lowercased.
    pub columns: Vec<String>,
}

/// A table-level `[CONSTRAINT name] CHECK (expr)`.
#[derive(Debug, Clone)]
pub struct TableCheckConstraint {
    pub name: Option<String>,
    /// The converted check expression (subqueries are rejected during conversion).
    pub expr: Expr,
    /// sqlparser's textual rendering of the check expression; presumably kept so the
    /// constraint can be persisted and re-parsed later — confirm with callers.
    pub sql: String,
}

/// A foreign-key relationship, from either an inline `REFERENCES` or a table-level
/// `FOREIGN KEY`. All names are ASCII-lowercased by the converters.
#[derive(Debug, Clone)]
pub struct ForeignKeyDef {
    pub name: Option<String>,
    /// Referencing columns in this table.
    pub columns: Vec<String>,
    /// Referenced table.
    pub foreign_table: String,
    /// Referenced columns; may be empty if the clause lists none.
    pub referred_columns: Vec<String>,
}

/// One column of a `CREATE TABLE` / `ALTER TABLE ADD COLUMN`.
#[derive(Debug, Clone)]
pub struct ColumnSpec {
    /// Column name exactly as sqlparser reported it (not case-folded).
    pub name: String,
    pub data_type: DataType,
    /// `false` for `NOT NULL` and primary-key columns.
    pub nullable: bool,
    pub is_primary_key: bool,
    /// Converted `DEFAULT` expression.
    pub default_expr: Option<Expr>,
    /// sqlparser's textual rendering of the `DEFAULT` expression.
    pub default_sql: Option<String>,
    /// Converted column-level `CHECK` expression.
    pub check_expr: Option<Expr>,
    /// sqlparser's textual rendering of the column-level `CHECK` expression.
    pub check_sql: Option<String>,
    /// Name from `CONSTRAINT name CHECK (...)`, if given.
    pub check_name: Option<String>,
}
103
/// `DROP TABLE [IF EXISTS] name`.
#[derive(Debug, Clone)]
pub struct DropTableStmt {
    pub name: String,
    pub if_exists: bool,
}

/// `CREATE [UNIQUE] INDEX [IF NOT EXISTS] index_name ON table_name (columns)`.
#[derive(Debug, Clone)]
pub struct CreateIndexStmt {
    pub index_name: String,
    pub table_name: String,
    /// Plain column names (expression indexes are rejected); never empty.
    pub columns: Vec<String>,
    pub unique: bool,
    pub if_not_exists: bool,
}

/// `DROP INDEX [IF EXISTS] index_name`.
#[derive(Debug, Clone)]
pub struct DropIndexStmt {
    pub index_name: String,
    pub if_exists: bool,
}

/// `CREATE [OR REPLACE] VIEW [IF NOT EXISTS] name [(aliases)] AS query`.
#[derive(Debug, Clone)]
pub struct CreateViewStmt {
    pub name: String,
    /// The view's query as SQL text (sqlparser's rendering of the parsed query).
    pub sql: String,
    /// Optional output-column aliases, ASCII-lowercased.
    pub column_aliases: Vec<String>,
    pub or_replace: bool,
    pub if_not_exists: bool,
}

/// `DROP VIEW [IF EXISTS] name`.
#[derive(Debug, Clone)]
pub struct DropViewStmt {
    pub name: String,
    pub if_exists: bool,
}
139
/// Where the rows of an `INSERT` come from.
#[derive(Debug, Clone)]
pub enum InsertSource {
    /// `VALUES (...), (...)`: one inner `Vec` per row.
    Values(Vec<Vec<Expr>>),
    /// `INSERT ... SELECT ...` (may carry its own CTEs).
    Select(Box<SelectQuery>),
}

/// `INSERT INTO table [(columns)] <source>`.
#[derive(Debug, Clone)]
pub struct InsertStmt {
    /// Target table, ASCII-lowercased.
    pub table: String,
    /// Explicit target columns, ASCII-lowercased; empty when none were listed.
    pub columns: Vec<String>,
    pub source: InsertSource,
}

/// A table reference in a JOIN, with optional alias.
#[derive(Debug, Clone)]
pub struct TableRef {
    pub name: String,
    pub alias: Option<String>,
}

/// Join kinds supported by the query model.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum JoinType {
    Inner,
    Cross,
    Left,
    Right,
}

/// One `JOIN` following the first FROM table.
#[derive(Debug, Clone)]
pub struct JoinClause {
    pub join_type: JoinType,
    pub table: TableRef,
    /// The `ON` condition; presumably `None` for joins without one (e.g. CROSS JOIN) —
    /// `convert_join` is outside this view, confirm there.
    pub on_clause: Option<Expr>,
}
173
/// A single `SELECT` (one arm of a query; no set operations).
#[derive(Debug, Clone)]
pub struct SelectStmt {
    pub columns: Vec<SelectColumn>,
    /// The first FROM table; an empty string when the SELECT has no FROM clause.
    pub from: String,
    pub from_alias: Option<String>,
    /// JOINs attached to the FROM table (comma-separated FROM lists are rejected).
    pub joins: Vec<JoinClause>,
    /// Plain `SELECT DISTINCT` (DISTINCT ON is rejected).
    pub distinct: bool,
    pub where_clause: Option<Expr>,
    /// ORDER BY / LIMIT / OFFSET are filled by `convert_query_body` when this SELECT is the
    /// whole query body; they stay empty/`None` for arms of a set operation.
    pub order_by: Vec<OrderByItem>,
    pub limit: Option<Expr>,
    pub offset: Option<Expr>,
    pub group_by: Vec<Expr>,
    pub having: Option<Expr>,
}

/// Set operator of a compound query (`MINUS` is converted to `Except`).
#[derive(Debug, Clone)]
pub enum SetOp {
    Union,
    Intersect,
    Except,
}

/// `left <op> [ALL] right`, with ORDER BY / LIMIT / OFFSET applied to the combined result.
#[derive(Debug, Clone)]
pub struct CompoundSelect {
    pub op: SetOp,
    /// `true` for `ALL`; `false` for the default or explicit `DISTINCT`.
    pub all: bool,
    pub left: Box<QueryBody>,
    pub right: Box<QueryBody>,
    pub order_by: Vec<OrderByItem>,
    pub limit: Option<Expr>,
    pub offset: Option<Expr>,
}

/// The body of a query: a plain SELECT or a (possibly nested) set operation.
#[derive(Debug, Clone)]
pub enum QueryBody {
    Select(Box<SelectStmt>),
    Compound(Box<CompoundSelect>),
}

/// One `name [(aliases)] AS (query)` entry of a WITH clause.
#[derive(Debug, Clone)]
pub struct CteDefinition {
    /// CTE name, ASCII-lowercased.
    pub name: String,
    /// Optional column aliases, ASCII-lowercased.
    pub column_aliases: Vec<String>,
    pub body: QueryBody,
}

/// A complete query: optional WITH clause plus the main body.
#[derive(Debug, Clone)]
pub struct SelectQuery {
    pub ctes: Vec<CteDefinition>,
    /// `WITH RECURSIVE` was used.
    pub recursive: bool,
    pub body: QueryBody,
}
226
/// `UPDATE table SET col = expr, ... [WHERE ...]`.
#[derive(Debug, Clone)]
pub struct UpdateStmt {
    pub table: String,
    /// `(column, new value)` pairs in statement order.
    pub assignments: Vec<(String, Expr)>,
    pub where_clause: Option<Expr>,
}

/// `DELETE FROM table [WHERE ...]`.
#[derive(Debug, Clone)]
pub struct DeleteStmt {
    pub table: String,
    pub where_clause: Option<Expr>,
}

/// One entry of a SELECT projection list.
#[derive(Debug, Clone)]
pub enum SelectColumn {
    /// Presumably the `*` wildcard — `convert_select_item` is outside this view.
    AllColumns,
    /// An expression with optional `AS alias`.
    Expr { expr: Expr, alias: Option<String> },
}

/// One ORDER BY key.
#[derive(Debug, Clone)]
pub struct OrderByItem {
    pub expr: Expr,
    pub descending: bool,
    /// Presumably `Some(true)` = NULLS FIRST, `Some(false)` = NULLS LAST, `None` = not
    /// specified — confirm in `convert_order_by_expr`.
    pub nulls_first: Option<bool>,
}
252
/// Scalar / boolean expression tree used throughout statements.
#[derive(Debug, Clone)]
pub enum Expr {
    /// A constant value.
    Literal(Value),
    /// Unqualified column reference.
    Column(String),
    /// `table.column` reference.
    QualifiedColumn {
        table: String,
        column: String,
    },
    BinaryOp {
        left: Box<Expr>,
        op: BinOp,
        right: Box<Expr>,
    },
    UnaryOp {
        op: UnaryOp,
        expr: Box<Expr>,
    },
    /// `expr IS NULL`.
    IsNull(Box<Expr>),
    /// `expr IS NOT NULL`.
    IsNotNull(Box<Expr>),
    /// Function call; `distinct` presumably marks `f(DISTINCT ...)` aggregates.
    Function {
        name: String,
        args: Vec<Expr>,
        distinct: bool,
    },
    /// Presumably `COUNT(*)`.
    CountStar,
    /// `expr [NOT] IN (subquery)`; the subquery is a single SELECT.
    InSubquery {
        expr: Box<Expr>,
        subquery: Box<SelectStmt>,
        negated: bool,
    },
    /// `expr [NOT] IN (e1, e2, ...)`.
    InList {
        expr: Box<Expr>,
        list: Vec<Expr>,
        negated: bool,
    },
    /// `[NOT] EXISTS (subquery)`.
    Exists {
        subquery: Box<SelectStmt>,
        negated: bool,
    },
    /// `(SELECT ...)` used as a scalar value.
    ScalarSubquery(Box<SelectStmt>),
    /// `IN` test against a pre-built set of values, with `has_null` recording whether a
    /// NULL was among them. NOTE(review): not built in this section; presumably an
    /// optimized form of `InList` produced elsewhere — confirm.
    InSet {
        expr: Box<Expr>,
        values: std::collections::HashSet<Value>,
        has_null: bool,
        negated: bool,
    },
    /// `expr [NOT] BETWEEN low AND high`.
    Between {
        expr: Box<Expr>,
        low: Box<Expr>,
        high: Box<Expr>,
        negated: bool,
    },
    /// `expr [NOT] LIKE pattern [ESCAPE escape]`.
    Like {
        expr: Box<Expr>,
        pattern: Box<Expr>,
        escape: Option<Box<Expr>>,
        negated: bool,
    },
    /// `CASE [operand] WHEN cond THEN result ... [ELSE else_result] END`.
    Case {
        operand: Option<Box<Expr>>,
        conditions: Vec<(Expr, Expr)>,
        else_result: Option<Box<Expr>>,
    },
    /// `COALESCE(args...)`.
    Coalesce(Vec<Expr>),
    /// `CAST(expr AS data_type)`.
    Cast {
        expr: Box<Expr>,
        data_type: DataType,
    },
    /// Positional parameter placeholder. `count_params` reports the largest index as the
    /// parameter count, so indices are presumably 1-based — confirm with the binder.
    Parameter(usize),
    /// `name(args) OVER (spec)`.
    WindowFunction {
        name: String,
        args: Vec<Expr>,
        spec: WindowSpec,
    },
}
329
/// The `OVER (...)` clause of a window function.
#[derive(Debug, Clone)]
pub struct WindowSpec {
    pub partition_by: Vec<Expr>,
    pub order_by: Vec<OrderByItem>,
    /// Explicit frame clause, if any.
    pub frame: Option<WindowFrame>,
}

/// `{ROWS | RANGE | GROUPS} BETWEEN start AND end`.
#[derive(Debug, Clone)]
pub struct WindowFrame {
    pub units: WindowFrameUnits,
    pub start: WindowFrameBound,
    pub end: WindowFrameBound,
}

/// Unit in which a window frame is measured.
#[derive(Debug, Clone, Copy)]
pub enum WindowFrameUnits {
    Rows,
    Range,
    Groups,
}

/// One end of a window frame; `Preceding`/`Following` carry an offset expression.
#[derive(Debug, Clone)]
pub enum WindowFrameBound {
    UnboundedPreceding,
    Preceding(Box<Expr>),
    CurrentRow,
    Following(Box<Expr>),
    UnboundedFollowing,
}

/// Binary operators of [`Expr::BinaryOp`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BinOp {
    Add,
    Sub,
    Mul,
    Div,
    Mod,
    Eq,
    NotEq,
    Lt,
    Gt,
    LtEq,
    GtEq,
    And,
    Or,
    /// String concatenation (presumably SQL `||`).
    Concat,
}

/// Unary operators of [`Expr::UnaryOp`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum UnaryOp {
    /// Arithmetic negation.
    Neg,
    /// Logical NOT.
    Not,
}
383
384pub fn has_subquery(expr: &Expr) -> bool {
385 match expr {
386 Expr::InSubquery { .. } | Expr::Exists { .. } | Expr::ScalarSubquery(_) => true,
387 Expr::BinaryOp { left, right, .. } => has_subquery(left) || has_subquery(right),
388 Expr::UnaryOp { expr, .. } => has_subquery(expr),
389 Expr::IsNull(e) | Expr::IsNotNull(e) => has_subquery(e),
390 Expr::InList { expr, list, .. } => has_subquery(expr) || list.iter().any(has_subquery),
391 Expr::InSet { expr, .. } => has_subquery(expr),
392 Expr::Between {
393 expr, low, high, ..
394 } => has_subquery(expr) || has_subquery(low) || has_subquery(high),
395 Expr::Like {
396 expr,
397 pattern,
398 escape,
399 ..
400 } => {
401 has_subquery(expr)
402 || has_subquery(pattern)
403 || escape.as_ref().is_some_and(|e| has_subquery(e))
404 }
405 Expr::Case {
406 operand,
407 conditions,
408 else_result,
409 } => {
410 operand.as_ref().is_some_and(|e| has_subquery(e))
411 || conditions
412 .iter()
413 .any(|(c, r)| has_subquery(c) || has_subquery(r))
414 || else_result.as_ref().is_some_and(|e| has_subquery(e))
415 }
416 Expr::Coalesce(args) | Expr::Function { args, .. } => args.iter().any(has_subquery),
417 Expr::Cast { expr, .. } => has_subquery(expr),
418 _ => false,
419 }
420}
421
422pub fn parse_sql_expr(sql: &str) -> Result<Expr> {
425 let dialect = GenericDialect {};
426 let mut parser = Parser::new(&dialect)
427 .try_with_sql(sql)
428 .map_err(|e| SqlError::Parse(e.to_string()))?;
429 let sp_expr = parser
430 .parse_expr()
431 .map_err(|e| SqlError::Parse(e.to_string()))?;
432 convert_expr(&sp_expr)
433}
434
435pub fn parse_sql(sql: &str) -> Result<Statement> {
436 let dialect = GenericDialect {};
437 let stmts = Parser::parse_sql(&dialect, sql).map_err(|e| SqlError::Parse(e.to_string()))?;
438
439 if stmts.is_empty() {
440 return Err(SqlError::Parse("empty SQL".into()));
441 }
442 if stmts.len() > 1 {
443 return Err(SqlError::Unsupported("multiple statements".into()));
444 }
445
446 convert_statement(stmts.into_iter().next().unwrap())
447}
448
449pub fn parse_sql_multi(sql: &str) -> Result<Vec<Statement>> {
451 let dialect = GenericDialect {};
452 let stmts = Parser::parse_sql(&dialect, sql).map_err(|e| SqlError::Parse(e.to_string()))?;
453
454 if stmts.is_empty() {
455 return Err(SqlError::Parse("empty SQL".into()));
456 }
457
458 stmts.into_iter().map(convert_statement).collect()
459}
460
461pub fn count_params(stmt: &Statement) -> usize {
463 let mut max_idx = 0usize;
464 visit_exprs_stmt(stmt, &mut |e| {
465 if let Expr::Parameter(n) = e {
466 max_idx = max_idx.max(*n);
467 }
468 });
469 max_idx
470}
471
472fn visit_exprs_stmt(stmt: &Statement, visitor: &mut impl FnMut(&Expr)) {
473 match stmt {
474 Statement::Select(sq) => {
475 for cte in &sq.ctes {
476 visit_exprs_query_body(&cte.body, visitor);
477 }
478 visit_exprs_query_body(&sq.body, visitor);
479 }
480 Statement::Insert(ins) => match &ins.source {
481 InsertSource::Values(rows) => {
482 for row in rows {
483 for e in row {
484 visit_expr(e, visitor);
485 }
486 }
487 }
488 InsertSource::Select(sq) => {
489 for cte in &sq.ctes {
490 visit_exprs_query_body(&cte.body, visitor);
491 }
492 visit_exprs_query_body(&sq.body, visitor);
493 }
494 },
495 Statement::Update(upd) => {
496 for (_, e) in &upd.assignments {
497 visit_expr(e, visitor);
498 }
499 if let Some(w) = &upd.where_clause {
500 visit_expr(w, visitor);
501 }
502 }
503 Statement::Delete(del) => {
504 if let Some(w) = &del.where_clause {
505 visit_expr(w, visitor);
506 }
507 }
508 Statement::Explain(inner) => visit_exprs_stmt(inner, visitor),
509 _ => {}
510 }
511}
512
513fn visit_exprs_query_body(body: &QueryBody, visitor: &mut impl FnMut(&Expr)) {
514 match body {
515 QueryBody::Select(sel) => visit_exprs_select(sel, visitor),
516 QueryBody::Compound(comp) => {
517 visit_exprs_query_body(&comp.left, visitor);
518 visit_exprs_query_body(&comp.right, visitor);
519 for o in &comp.order_by {
520 visit_expr(&o.expr, visitor);
521 }
522 if let Some(l) = &comp.limit {
523 visit_expr(l, visitor);
524 }
525 if let Some(o) = &comp.offset {
526 visit_expr(o, visitor);
527 }
528 }
529 }
530}
531
532fn visit_exprs_select(sel: &SelectStmt, visitor: &mut impl FnMut(&Expr)) {
533 for col in &sel.columns {
534 if let SelectColumn::Expr { expr, .. } = col {
535 visit_expr(expr, visitor);
536 }
537 }
538 for j in &sel.joins {
539 if let Some(on) = &j.on_clause {
540 visit_expr(on, visitor);
541 }
542 }
543 if let Some(w) = &sel.where_clause {
544 visit_expr(w, visitor);
545 }
546 for o in &sel.order_by {
547 visit_expr(&o.expr, visitor);
548 }
549 if let Some(l) = &sel.limit {
550 visit_expr(l, visitor);
551 }
552 if let Some(o) = &sel.offset {
553 visit_expr(o, visitor);
554 }
555 for g in &sel.group_by {
556 visit_expr(g, visitor);
557 }
558 if let Some(h) = &sel.having {
559 visit_expr(h, visitor);
560 }
561}
562
563fn visit_expr(expr: &Expr, visitor: &mut impl FnMut(&Expr)) {
564 visitor(expr);
565 match expr {
566 Expr::BinaryOp { left, right, .. } => {
567 visit_expr(left, visitor);
568 visit_expr(right, visitor);
569 }
570 Expr::UnaryOp { expr: e, .. } | Expr::IsNull(e) | Expr::IsNotNull(e) => {
571 visit_expr(e, visitor);
572 }
573 Expr::Function { args, .. } | Expr::Coalesce(args) => {
574 for a in args {
575 visit_expr(a, visitor);
576 }
577 }
578 Expr::InSubquery {
579 expr: e, subquery, ..
580 } => {
581 visit_expr(e, visitor);
582 visit_exprs_select(subquery, visitor);
583 }
584 Expr::InList { expr: e, list, .. } => {
585 visit_expr(e, visitor);
586 for l in list {
587 visit_expr(l, visitor);
588 }
589 }
590 Expr::Exists { subquery, .. } => visit_exprs_select(subquery, visitor),
591 Expr::ScalarSubquery(sq) => visit_exprs_select(sq, visitor),
592 Expr::InSet { expr: e, .. } => visit_expr(e, visitor),
593 Expr::Between {
594 expr: e, low, high, ..
595 } => {
596 visit_expr(e, visitor);
597 visit_expr(low, visitor);
598 visit_expr(high, visitor);
599 }
600 Expr::Like {
601 expr: e,
602 pattern,
603 escape,
604 ..
605 } => {
606 visit_expr(e, visitor);
607 visit_expr(pattern, visitor);
608 if let Some(esc) = escape {
609 visit_expr(esc, visitor);
610 }
611 }
612 Expr::Case {
613 operand,
614 conditions,
615 else_result,
616 } => {
617 if let Some(op) = operand {
618 visit_expr(op, visitor);
619 }
620 for (cond, then) in conditions {
621 visit_expr(cond, visitor);
622 visit_expr(then, visitor);
623 }
624 if let Some(el) = else_result {
625 visit_expr(el, visitor);
626 }
627 }
628 Expr::Cast { expr: e, .. } => visit_expr(e, visitor),
629 Expr::WindowFunction { args, spec, .. } => {
630 for a in args {
631 visit_expr(a, visitor);
632 }
633 for p in &spec.partition_by {
634 visit_expr(p, visitor);
635 }
636 for o in &spec.order_by {
637 visit_expr(&o.expr, visitor);
638 }
639 if let Some(ref frame) = spec.frame {
640 if let WindowFrameBound::Preceding(e) | WindowFrameBound::Following(e) =
641 &frame.start
642 {
643 visit_expr(e, visitor);
644 }
645 if let WindowFrameBound::Preceding(e) | WindowFrameBound::Following(e) = &frame.end
646 {
647 visit_expr(e, visitor);
648 }
649 }
650 }
651 Expr::Literal(_)
652 | Expr::Column(_)
653 | Expr::QualifiedColumn { .. }
654 | Expr::CountStar
655 | Expr::Parameter(_) => {}
656 }
657}
658
659fn convert_statement(stmt: sp::Statement) -> Result<Statement> {
660 match stmt {
661 sp::Statement::CreateTable(ct) => convert_create_table(ct),
662 sp::Statement::CreateIndex(ci) => convert_create_index(ci),
663 sp::Statement::Drop {
664 object_type: sp::ObjectType::Table,
665 if_exists,
666 names,
667 ..
668 } => {
669 if names.len() != 1 {
670 return Err(SqlError::Unsupported("multi-table DROP".into()));
671 }
672 Ok(Statement::DropTable(DropTableStmt {
673 name: object_name_to_string(&names[0]),
674 if_exists,
675 }))
676 }
677 sp::Statement::Drop {
678 object_type: sp::ObjectType::Index,
679 if_exists,
680 names,
681 ..
682 } => {
683 if names.len() != 1 {
684 return Err(SqlError::Unsupported("multi-index DROP".into()));
685 }
686 Ok(Statement::DropIndex(DropIndexStmt {
687 index_name: object_name_to_string(&names[0]),
688 if_exists,
689 }))
690 }
691 sp::Statement::CreateView(cv) => convert_create_view(cv),
692 sp::Statement::Drop {
693 object_type: sp::ObjectType::View,
694 if_exists,
695 names,
696 ..
697 } => {
698 if names.len() != 1 {
699 return Err(SqlError::Unsupported("multi-view DROP".into()));
700 }
701 Ok(Statement::DropView(DropViewStmt {
702 name: object_name_to_string(&names[0]),
703 if_exists,
704 }))
705 }
706 sp::Statement::AlterTable(at) => convert_alter_table(at),
707 sp::Statement::Insert(insert) => convert_insert(insert),
708 sp::Statement::Query(query) => convert_query(*query),
709 sp::Statement::Update(update) => convert_update(update),
710 sp::Statement::Delete(delete) => convert_delete(delete),
711 sp::Statement::StartTransaction { .. } => Ok(Statement::Begin),
712 sp::Statement::Commit { chain: true, .. } => {
713 Err(SqlError::Unsupported("COMMIT AND CHAIN".into()))
714 }
715 sp::Statement::Commit { .. } => Ok(Statement::Commit),
716 sp::Statement::Rollback { chain: true, .. } => {
717 Err(SqlError::Unsupported("ROLLBACK AND CHAIN".into()))
718 }
719 sp::Statement::Rollback {
720 savepoint: Some(name),
721 ..
722 } => Ok(Statement::RollbackTo(name.value.to_ascii_lowercase())),
723 sp::Statement::Rollback { .. } => Ok(Statement::Rollback),
724 sp::Statement::Savepoint { name } => {
725 Ok(Statement::Savepoint(name.value.to_ascii_lowercase()))
726 }
727 sp::Statement::ReleaseSavepoint { name } => {
728 Ok(Statement::ReleaseSavepoint(name.value.to_ascii_lowercase()))
729 }
730 sp::Statement::Set(sp::Set::SetTimeZone { value, .. }) => {
731 let zone = match value {
733 sp::Expr::Value(v) => match &v.value {
734 sp::Value::SingleQuotedString(s) => s.clone(),
735 sp::Value::DoubleQuotedString(s) => s.clone(),
736 other => other.to_string(),
737 },
738 sp::Expr::Identifier(ident) => ident.value.clone(),
739 other => {
740 return Err(SqlError::Parse(format!(
741 "SET TIME ZONE expects a string literal or identifier, got: {other}"
742 )))
743 }
744 };
745 Ok(Statement::SetTimezone(zone))
746 }
747 sp::Statement::Explain {
748 statement, analyze, ..
749 } => {
750 if analyze {
751 return Err(SqlError::Unsupported("EXPLAIN ANALYZE".into()));
752 }
753 let inner = convert_statement(*statement)?;
754 Ok(Statement::Explain(Box::new(inner)))
755 }
756 _ => Err(SqlError::Unsupported(format!("statement type: {}", stmt))),
757 }
758}
759
760fn convert_column_def(
763 col_def: &sp::ColumnDef,
764) -> Result<(ColumnSpec, Option<ForeignKeyDef>, bool, bool)> {
765 let col_name = col_def.name.value.clone();
766 let data_type = convert_data_type(&col_def.data_type)?;
767 let mut nullable = true;
768 let mut is_primary_key = false;
769 let mut is_unique = false;
770 let mut default_expr = None;
771 let mut default_sql = None;
772 let mut check_expr = None;
773 let mut check_sql = None;
774 let mut check_name = None;
775 let mut fk_def = None;
776
777 for opt in &col_def.options {
778 match &opt.option {
779 sp::ColumnOption::NotNull => nullable = false,
780 sp::ColumnOption::Null => nullable = true,
781 sp::ColumnOption::PrimaryKey(_) => {
782 is_primary_key = true;
783 nullable = false;
784 }
785 sp::ColumnOption::Unique(_) => is_unique = true,
786 sp::ColumnOption::Default(expr) => {
787 default_sql = Some(expr.to_string());
788 default_expr = Some(convert_expr(expr)?);
789 }
790 sp::ColumnOption::Check(check) => {
791 check_sql = Some(check.expr.to_string());
792 let converted = convert_expr(&check.expr)?;
793 if has_subquery(&converted) {
794 return Err(SqlError::Unsupported("subquery in CHECK constraint".into()));
795 }
796 check_expr = Some(converted);
797 check_name = check.name.as_ref().map(|n| n.value.clone());
798 }
799 sp::ColumnOption::ForeignKey(fk) => {
800 convert_fk_actions(&fk.on_delete, &fk.on_update)?;
801 let ftable = object_name_to_string(&fk.foreign_table).to_ascii_lowercase();
802 let referred: Vec<String> = fk
803 .referred_columns
804 .iter()
805 .map(|i| i.value.to_ascii_lowercase())
806 .collect();
807 fk_def = Some(ForeignKeyDef {
808 name: fk.name.as_ref().map(|n| n.value.clone()),
809 columns: vec![col_name.to_ascii_lowercase()],
810 foreign_table: ftable,
811 referred_columns: referred,
812 });
813 }
814 _ => {}
815 }
816 }
817
818 let spec = ColumnSpec {
819 name: col_name,
820 data_type,
821 nullable,
822 is_primary_key,
823 default_expr,
824 default_sql,
825 check_expr,
826 check_sql,
827 check_name,
828 };
829 Ok((spec, fk_def, is_primary_key, is_unique))
830}
831
832fn convert_create_table(ct: sp::CreateTable) -> Result<Statement> {
833 let name = object_name_to_string(&ct.name);
834 let if_not_exists = ct.if_not_exists;
835
836 let mut columns = Vec::new();
837 let mut inline_pk: Vec<String> = Vec::new();
838 let mut foreign_keys: Vec<ForeignKeyDef> = Vec::new();
839 let mut unique_indices: Vec<UniqueIndexDef> = Vec::new();
840
841 for col_def in &ct.columns {
842 let (spec, fk_def, was_pk, was_unique) = convert_column_def(col_def)?;
843 if was_pk {
844 inline_pk.push(spec.name.clone());
845 }
846 if let Some(fk) = fk_def {
847 foreign_keys.push(fk);
848 }
849 if was_unique && !was_pk {
850 unique_indices.push(UniqueIndexDef {
851 name: None,
852 columns: vec![spec.name.to_ascii_lowercase()],
853 });
854 }
855 columns.push(spec);
856 }
857
858 let mut check_constraints: Vec<TableCheckConstraint> = Vec::new();
859
860 for constraint in &ct.constraints {
861 match constraint {
862 sp::TableConstraint::PrimaryKey(pk_constraint) => {
863 for idx_col in &pk_constraint.columns {
864 let col_name = match &idx_col.column.expr {
865 sp::Expr::Identifier(ident) => ident.value.clone(),
866 _ => continue,
867 };
868 if !inline_pk.contains(&col_name) {
869 inline_pk.push(col_name.clone());
870 }
871 if let Some(col) = columns.iter_mut().find(|c| c.name == col_name) {
872 col.nullable = false;
873 col.is_primary_key = true;
874 }
875 }
876 }
877 sp::TableConstraint::Check(check) => {
878 let sql = check.expr.to_string();
879 let converted = convert_expr(&check.expr)?;
880 if has_subquery(&converted) {
881 return Err(SqlError::Unsupported("subquery in CHECK constraint".into()));
882 }
883 check_constraints.push(TableCheckConstraint {
884 name: check.name.as_ref().map(|n| n.value.clone()),
885 expr: converted,
886 sql,
887 });
888 }
889 sp::TableConstraint::ForeignKey(fk) => {
890 convert_fk_actions(&fk.on_delete, &fk.on_update)?;
891 let cols: Vec<String> = fk
892 .columns
893 .iter()
894 .map(|i| i.value.to_ascii_lowercase())
895 .collect();
896 let ftable = object_name_to_string(&fk.foreign_table).to_ascii_lowercase();
897 let referred: Vec<String> = fk
898 .referred_columns
899 .iter()
900 .map(|i| i.value.to_ascii_lowercase())
901 .collect();
902 foreign_keys.push(ForeignKeyDef {
903 name: fk.name.as_ref().map(|n| n.value.clone()),
904 columns: cols,
905 foreign_table: ftable,
906 referred_columns: referred,
907 });
908 }
909 sp::TableConstraint::Unique(u) => {
910 let cols: Vec<String> = u
911 .columns
912 .iter()
913 .filter_map(|idx_col| match &idx_col.column.expr {
914 sp::Expr::Identifier(ident) => Some(ident.value.to_ascii_lowercase()),
915 _ => None,
916 })
917 .collect();
918 if !cols.is_empty() {
919 unique_indices.push(UniqueIndexDef {
920 name: u.name.as_ref().map(|n| n.value.clone()),
921 columns: cols,
922 });
923 }
924 }
925 _ => {}
926 }
927 }
928
929 Ok(Statement::CreateTable(CreateTableStmt {
930 name,
931 columns,
932 primary_key: inline_pk,
933 if_not_exists,
934 check_constraints,
935 foreign_keys,
936 unique_indices,
937 }))
938}
939
940fn convert_alter_table(at: sp::AlterTable) -> Result<Statement> {
941 let table = object_name_to_string(&at.name);
942 if at.operations.len() != 1 {
943 return Err(SqlError::Unsupported(
944 "ALTER TABLE with multiple operations".into(),
945 ));
946 }
947 let op = match at.operations.into_iter().next().unwrap() {
948 sp::AlterTableOperation::AddColumn {
949 column_def,
950 if_not_exists,
951 ..
952 } => {
953 let (spec, fk, _was_pk, _was_unique) = convert_column_def(&column_def)?;
954 AlterTableOp::AddColumn {
955 column: Box::new(spec),
956 foreign_key: fk,
957 if_not_exists,
958 }
959 }
960 sp::AlterTableOperation::DropColumn {
961 column_names,
962 if_exists,
963 ..
964 } => {
965 if column_names.len() != 1 {
966 return Err(SqlError::Unsupported(
967 "DROP COLUMN with multiple columns".into(),
968 ));
969 }
970 AlterTableOp::DropColumn {
971 name: column_names.into_iter().next().unwrap().value,
972 if_exists,
973 }
974 }
975 sp::AlterTableOperation::RenameColumn {
976 old_column_name,
977 new_column_name,
978 } => AlterTableOp::RenameColumn {
979 old_name: old_column_name.value,
980 new_name: new_column_name.value,
981 },
982 sp::AlterTableOperation::RenameTable { table_name } => {
983 let new_name = match table_name {
984 sp::RenameTableNameKind::To(name) | sp::RenameTableNameKind::As(name) => {
985 object_name_to_string(&name)
986 }
987 };
988 AlterTableOp::RenameTable { new_name }
989 }
990 other => {
991 return Err(SqlError::Unsupported(format!(
992 "ALTER TABLE operation: {other}"
993 )));
994 }
995 };
996 Ok(Statement::AlterTable(Box::new(AlterTableStmt {
997 table,
998 op,
999 })))
1000}
1001
1002fn convert_fk_actions(
1003 on_delete: &Option<sp::ReferentialAction>,
1004 on_update: &Option<sp::ReferentialAction>,
1005) -> Result<()> {
1006 for action in [on_delete, on_update] {
1007 match action {
1008 None
1009 | Some(sp::ReferentialAction::Restrict)
1010 | Some(sp::ReferentialAction::NoAction) => {}
1011 Some(other) => {
1012 return Err(SqlError::Unsupported(format!(
1013 "FOREIGN KEY action: {other}"
1014 )));
1015 }
1016 }
1017 }
1018 Ok(())
1019}
1020
1021fn convert_create_index(ci: sp::CreateIndex) -> Result<Statement> {
1022 let index_name = ci
1023 .name
1024 .as_ref()
1025 .map(object_name_to_string)
1026 .ok_or_else(|| SqlError::Parse("index name required".into()))?;
1027
1028 let table_name = object_name_to_string(&ci.table_name);
1029
1030 let columns: Vec<String> = ci
1031 .columns
1032 .iter()
1033 .map(|idx_col| match &idx_col.column.expr {
1034 sp::Expr::Identifier(ident) => Ok(ident.value.clone()),
1035 other => Err(SqlError::Unsupported(format!("expression index: {other}"))),
1036 })
1037 .collect::<Result<_>>()?;
1038
1039 if columns.is_empty() {
1040 return Err(SqlError::Parse(
1041 "index must have at least one column".into(),
1042 ));
1043 }
1044
1045 Ok(Statement::CreateIndex(CreateIndexStmt {
1046 index_name,
1047 table_name,
1048 columns,
1049 unique: ci.unique,
1050 if_not_exists: ci.if_not_exists,
1051 }))
1052}
1053
1054fn convert_create_view(cv: sp::CreateView) -> Result<Statement> {
1055 let name = object_name_to_string(&cv.name);
1056
1057 if cv.materialized {
1058 return Err(SqlError::Unsupported("MATERIALIZED VIEW".into()));
1059 }
1060
1061 let sql = cv.query.to_string();
1062
1063 let dialect = GenericDialect {};
1064 let test = Parser::parse_sql(&dialect, &sql).map_err(|e| SqlError::Parse(e.to_string()))?;
1065 if test.is_empty() {
1066 return Err(SqlError::Parse("empty view definition".into()));
1067 }
1068 match &test[0] {
1069 sp::Statement::Query(_) => {}
1070 _ => {
1071 return Err(SqlError::Parse(
1072 "view body must be a SELECT statement".into(),
1073 ))
1074 }
1075 }
1076
1077 let column_aliases: Vec<String> = cv
1078 .columns
1079 .iter()
1080 .map(|c| c.name.value.to_ascii_lowercase())
1081 .collect();
1082
1083 Ok(Statement::CreateView(CreateViewStmt {
1084 name,
1085 sql,
1086 column_aliases,
1087 or_replace: cv.or_replace,
1088 if_not_exists: cv.if_not_exists,
1089 }))
1090}
1091
1092fn convert_insert(insert: sp::Insert) -> Result<Statement> {
1093 let table = match &insert.table {
1094 sp::TableObject::TableName(name) => object_name_to_string(name).to_ascii_lowercase(),
1095 _ => return Err(SqlError::Unsupported("INSERT into non-table object".into())),
1096 };
1097
1098 let columns: Vec<String> = insert
1099 .columns
1100 .iter()
1101 .map(|c| c.value.to_ascii_lowercase())
1102 .collect();
1103
1104 let query = insert
1105 .source
1106 .ok_or_else(|| SqlError::Parse("INSERT requires VALUES or SELECT".into()))?;
1107
1108 let source = match *query.body {
1109 sp::SetExpr::Values(sp::Values { rows, .. }) => {
1110 let mut result = Vec::new();
1111 for row in rows {
1112 let mut exprs = Vec::new();
1113 for expr in row {
1114 exprs.push(convert_expr(&expr)?);
1115 }
1116 result.push(exprs);
1117 }
1118 InsertSource::Values(result)
1119 }
1120 _ => {
1121 let (ctes, recursive) = if let Some(ref with) = query.with {
1122 convert_with(with)?
1123 } else {
1124 (vec![], false)
1125 };
1126 let body = convert_query_body(&query)?;
1127 InsertSource::Select(Box::new(SelectQuery {
1128 ctes,
1129 recursive,
1130 body,
1131 }))
1132 }
1133 };
1134
1135 Ok(Statement::Insert(InsertStmt {
1136 table,
1137 columns,
1138 source,
1139 }))
1140}
1141
1142fn convert_select_body(select: &sp::Select) -> Result<SelectStmt> {
1143 let distinct = match &select.distinct {
1144 Some(sp::Distinct::Distinct) => true,
1145 Some(sp::Distinct::On(_)) => {
1146 return Err(SqlError::Unsupported("DISTINCT ON".into()));
1147 }
1148 _ => false,
1149 };
1150
1151 let (from, from_alias, joins) = if select.from.is_empty() {
1152 (String::new(), None, vec![])
1153 } else if select.from.len() == 1 {
1154 let table_with_joins = &select.from[0];
1155 let (name, alias) = match &table_with_joins.relation {
1156 sp::TableFactor::Table { name, alias, .. } => {
1157 let table_name = object_name_to_string(name);
1158 let alias_str = alias.as_ref().map(|a| a.name.value.clone());
1159 (table_name, alias_str)
1160 }
1161 _ => return Err(SqlError::Unsupported("non-table FROM source".into())),
1162 };
1163 let j = table_with_joins
1164 .joins
1165 .iter()
1166 .map(convert_join)
1167 .collect::<Result<Vec<_>>>()?;
1168 (name, alias, j)
1169 } else {
1170 return Err(SqlError::Unsupported("comma-separated FROM tables".into()));
1171 };
1172
1173 let columns: Vec<SelectColumn> = select
1174 .projection
1175 .iter()
1176 .map(convert_select_item)
1177 .collect::<Result<_>>()?;
1178
1179 let where_clause = select.selection.as_ref().map(convert_expr).transpose()?;
1180
1181 let group_by = match &select.group_by {
1182 sp::GroupByExpr::Expressions(exprs, _) => {
1183 exprs.iter().map(convert_expr).collect::<Result<_>>()?
1184 }
1185 sp::GroupByExpr::All(_) => {
1186 return Err(SqlError::Unsupported("GROUP BY ALL".into()));
1187 }
1188 };
1189
1190 let having = select.having.as_ref().map(convert_expr).transpose()?;
1191
1192 Ok(SelectStmt {
1193 columns,
1194 from,
1195 from_alias,
1196 joins,
1197 distinct,
1198 where_clause,
1199 order_by: vec![],
1200 limit: None,
1201 offset: None,
1202 group_by,
1203 having,
1204 })
1205}
1206
1207fn convert_set_expr(set_expr: &sp::SetExpr) -> Result<QueryBody> {
1208 match set_expr {
1209 sp::SetExpr::Select(sel) => Ok(QueryBody::Select(Box::new(convert_select_body(sel)?))),
1210 sp::SetExpr::SetOperation {
1211 op,
1212 set_quantifier,
1213 left,
1214 right,
1215 } => {
1216 let set_op = match op {
1217 sp::SetOperator::Union => SetOp::Union,
1218 sp::SetOperator::Intersect => SetOp::Intersect,
1219 sp::SetOperator::Except | sp::SetOperator::Minus => SetOp::Except,
1220 };
1221 let all = match set_quantifier {
1222 sp::SetQuantifier::All => true,
1223 sp::SetQuantifier::None | sp::SetQuantifier::Distinct => false,
1224 _ => {
1225 return Err(SqlError::Unsupported("BY NAME set operations".into()));
1226 }
1227 };
1228 Ok(QueryBody::Compound(Box::new(CompoundSelect {
1229 op: set_op,
1230 all,
1231 left: Box::new(convert_set_expr(left)?),
1232 right: Box::new(convert_set_expr(right)?),
1233 order_by: vec![],
1234 limit: None,
1235 offset: None,
1236 })))
1237 }
1238 _ => Err(SqlError::Unsupported("unsupported set expression".into())),
1239 }
1240}
1241
1242fn convert_query_body(query: &sp::Query) -> Result<QueryBody> {
1243 let mut body = convert_set_expr(&query.body)?;
1244
1245 let order_by = if let Some(ref ob) = query.order_by {
1246 match &ob.kind {
1247 sp::OrderByKind::Expressions(exprs) => exprs
1248 .iter()
1249 .map(convert_order_by_expr)
1250 .collect::<Result<_>>()?,
1251 sp::OrderByKind::All { .. } => {
1252 return Err(SqlError::Unsupported("ORDER BY ALL".into()));
1253 }
1254 }
1255 } else {
1256 vec![]
1257 };
1258
1259 let (limit, offset) = match &query.limit_clause {
1260 Some(sp::LimitClause::LimitOffset { limit, offset, .. }) => {
1261 let l = limit.as_ref().map(convert_expr).transpose()?;
1262 let o = offset
1263 .as_ref()
1264 .map(|o| convert_expr(&o.value))
1265 .transpose()?;
1266 (l, o)
1267 }
1268 Some(sp::LimitClause::OffsetCommaLimit { limit, offset }) => {
1269 let l = Some(convert_expr(limit)?);
1270 let o = Some(convert_expr(offset)?);
1271 (l, o)
1272 }
1273 None => (None, None),
1274 };
1275
1276 match &mut body {
1277 QueryBody::Select(sel) => {
1278 sel.order_by = order_by;
1279 sel.limit = limit;
1280 sel.offset = offset;
1281 }
1282 QueryBody::Compound(comp) => {
1283 comp.order_by = order_by;
1284 comp.limit = limit;
1285 comp.offset = offset;
1286 }
1287 }
1288
1289 Ok(body)
1290}
1291
1292fn convert_subquery(query: &sp::Query) -> Result<SelectStmt> {
1293 if query.with.is_some() {
1294 return Err(SqlError::Unsupported("CTEs in subqueries".into()));
1295 }
1296 match convert_query_body(query)? {
1297 QueryBody::Select(s) => Ok(*s),
1298 QueryBody::Compound(_) => Err(SqlError::Unsupported(
1299 "UNION/INTERSECT/EXCEPT in subqueries".into(),
1300 )),
1301 }
1302}
1303
1304fn convert_with(with: &sp::With) -> Result<(Vec<CteDefinition>, bool)> {
1305 let mut names = std::collections::HashSet::new();
1306 let mut ctes = Vec::new();
1307 for cte in &with.cte_tables {
1308 let name = cte.alias.name.value.to_ascii_lowercase();
1309 if !names.insert(name.clone()) {
1310 return Err(SqlError::DuplicateCteName(name));
1311 }
1312 let column_aliases: Vec<String> = cte
1313 .alias
1314 .columns
1315 .iter()
1316 .map(|c| c.name.value.to_ascii_lowercase())
1317 .collect();
1318 let body = convert_query_body(&cte.query)?;
1319 ctes.push(CteDefinition {
1320 name,
1321 column_aliases,
1322 body,
1323 });
1324 }
1325 Ok((ctes, with.recursive))
1326}
1327
1328fn convert_query(query: sp::Query) -> Result<Statement> {
1329 let (ctes, recursive) = if let Some(ref with) = query.with {
1330 convert_with(with)?
1331 } else {
1332 (vec![], false)
1333 };
1334 let body = convert_query_body(&query)?;
1335 Ok(Statement::Select(Box::new(SelectQuery {
1336 ctes,
1337 recursive,
1338 body,
1339 })))
1340}
1341
1342fn convert_join(join: &sp::Join) -> Result<JoinClause> {
1343 let (join_type, constraint) = match &join.join_operator {
1344 sp::JoinOperator::Inner(c) => (JoinType::Inner, Some(c)),
1345 sp::JoinOperator::Join(c) => (JoinType::Inner, Some(c)),
1346 sp::JoinOperator::CrossJoin(c) => (JoinType::Cross, Some(c)),
1347 sp::JoinOperator::Left(c) => (JoinType::Left, Some(c)),
1348 sp::JoinOperator::LeftSemi(c) => (JoinType::Left, Some(c)),
1349 sp::JoinOperator::LeftAnti(c) => (JoinType::Left, Some(c)),
1350 sp::JoinOperator::Right(c) => (JoinType::Right, Some(c)),
1351 sp::JoinOperator::RightSemi(c) => (JoinType::Right, Some(c)),
1352 sp::JoinOperator::RightAnti(c) => (JoinType::Right, Some(c)),
1353 other => return Err(SqlError::Unsupported(format!("join type: {other:?}"))),
1354 };
1355
1356 let (name, alias) = match &join.relation {
1357 sp::TableFactor::Table { name, alias, .. } => {
1358 let table_name = object_name_to_string(name);
1359 let alias_str = alias.as_ref().map(|a| a.name.value.clone());
1360 (table_name, alias_str)
1361 }
1362 _ => return Err(SqlError::Unsupported("non-table JOIN source".into())),
1363 };
1364
1365 let on_clause = match constraint {
1366 Some(sp::JoinConstraint::On(expr)) => Some(convert_expr(expr)?),
1367 Some(sp::JoinConstraint::None) | None => None,
1368 Some(other) => return Err(SqlError::Unsupported(format!("join constraint: {other:?}"))),
1369 };
1370
1371 Ok(JoinClause {
1372 join_type,
1373 table: TableRef { name, alias },
1374 on_clause,
1375 })
1376}
1377
1378fn convert_update(update: sp::Update) -> Result<Statement> {
1379 let table = match &update.table.relation {
1380 sp::TableFactor::Table { name, .. } => object_name_to_string(name),
1381 _ => return Err(SqlError::Unsupported("non-table UPDATE target".into())),
1382 };
1383
1384 let assignments = update
1385 .assignments
1386 .iter()
1387 .map(|a| {
1388 let col = match &a.target {
1389 sp::AssignmentTarget::ColumnName(name) => object_name_to_string(name),
1390 _ => return Err(SqlError::Unsupported("tuple assignment".into())),
1391 };
1392 let expr = convert_expr(&a.value)?;
1393 Ok((col, expr))
1394 })
1395 .collect::<Result<_>>()?;
1396
1397 let where_clause = update.selection.as_ref().map(convert_expr).transpose()?;
1398
1399 Ok(Statement::Update(UpdateStmt {
1400 table,
1401 assignments,
1402 where_clause,
1403 }))
1404}
1405
1406fn convert_delete(delete: sp::Delete) -> Result<Statement> {
1407 let table_name = match &delete.from {
1408 sp::FromTable::WithFromKeyword(tables) => {
1409 if tables.len() != 1 {
1410 return Err(SqlError::Unsupported("multi-table DELETE".into()));
1411 }
1412 match &tables[0].relation {
1413 sp::TableFactor::Table { name, .. } => object_name_to_string(name),
1414 _ => return Err(SqlError::Unsupported("non-table DELETE target".into())),
1415 }
1416 }
1417 sp::FromTable::WithoutKeyword(tables) => {
1418 if tables.len() != 1 {
1419 return Err(SqlError::Unsupported("multi-table DELETE".into()));
1420 }
1421 match &tables[0].relation {
1422 sp::TableFactor::Table { name, .. } => object_name_to_string(name),
1423 _ => return Err(SqlError::Unsupported("non-table DELETE target".into())),
1424 }
1425 }
1426 };
1427
1428 let where_clause = delete.selection.as_ref().map(convert_expr).transpose()?;
1429
1430 Ok(Statement::Delete(DeleteStmt {
1431 table: table_name,
1432 where_clause,
1433 }))
1434}
1435
1436fn convert_expr(expr: &sp::Expr) -> Result<Expr> {
1437 match expr {
1438 sp::Expr::Value(v) => convert_value(&v.value),
1439 sp::Expr::Identifier(ident) => Ok(Expr::Column(ident.value.to_ascii_lowercase())),
1440 sp::Expr::CompoundIdentifier(parts) => {
1441 if parts.len() == 2 {
1442 Ok(Expr::QualifiedColumn {
1443 table: parts[0].value.to_ascii_lowercase(),
1444 column: parts[1].value.to_ascii_lowercase(),
1445 })
1446 } else {
1447 Ok(Expr::Column(
1448 parts.last().unwrap().value.to_ascii_lowercase(),
1449 ))
1450 }
1451 }
1452 sp::Expr::BinaryOp { left, op, right } => {
1453 let bin_op = convert_bin_op(op)?;
1454 Ok(Expr::BinaryOp {
1455 left: Box::new(convert_expr(left)?),
1456 op: bin_op,
1457 right: Box::new(convert_expr(right)?),
1458 })
1459 }
1460 sp::Expr::UnaryOp { op, expr } => {
1461 let unary_op = match op {
1462 sp::UnaryOperator::Minus => UnaryOp::Neg,
1463 sp::UnaryOperator::Not => UnaryOp::Not,
1464 _ => return Err(SqlError::Unsupported(format!("unary op: {op}"))),
1465 };
1466 Ok(Expr::UnaryOp {
1467 op: unary_op,
1468 expr: Box::new(convert_expr(expr)?),
1469 })
1470 }
1471 sp::Expr::IsNull(e) => Ok(Expr::IsNull(Box::new(convert_expr(e)?))),
1472 sp::Expr::IsNotNull(e) => Ok(Expr::IsNotNull(Box::new(convert_expr(e)?))),
1473 sp::Expr::Nested(e) => convert_expr(e),
1474 sp::Expr::Function(func) => convert_function(func),
1475 sp::Expr::InSubquery {
1476 expr: e,
1477 subquery,
1478 negated,
1479 } => {
1480 let inner_expr = convert_expr(e)?;
1481 let stmt = convert_subquery(subquery)?;
1482 Ok(Expr::InSubquery {
1483 expr: Box::new(inner_expr),
1484 subquery: Box::new(stmt),
1485 negated: *negated,
1486 })
1487 }
1488 sp::Expr::InList {
1489 expr: e,
1490 list,
1491 negated,
1492 } => {
1493 let inner_expr = convert_expr(e)?;
1494 let items = list.iter().map(convert_expr).collect::<Result<Vec<_>>>()?;
1495 Ok(Expr::InList {
1496 expr: Box::new(inner_expr),
1497 list: items,
1498 negated: *negated,
1499 })
1500 }
1501 sp::Expr::Exists { subquery, negated } => {
1502 let stmt = convert_subquery(subquery)?;
1503 Ok(Expr::Exists {
1504 subquery: Box::new(stmt),
1505 negated: *negated,
1506 })
1507 }
1508 sp::Expr::Subquery(query) => {
1509 let stmt = convert_subquery(query)?;
1510 Ok(Expr::ScalarSubquery(Box::new(stmt)))
1511 }
1512 sp::Expr::Between {
1513 expr: e,
1514 negated,
1515 low,
1516 high,
1517 } => Ok(Expr::Between {
1518 expr: Box::new(convert_expr(e)?),
1519 low: Box::new(convert_expr(low)?),
1520 high: Box::new(convert_expr(high)?),
1521 negated: *negated,
1522 }),
1523 sp::Expr::Like {
1524 expr: e,
1525 negated,
1526 pattern,
1527 escape_char,
1528 ..
1529 } => {
1530 let esc = escape_char
1531 .as_ref()
1532 .map(convert_escape_value)
1533 .transpose()?
1534 .map(Box::new);
1535 Ok(Expr::Like {
1536 expr: Box::new(convert_expr(e)?),
1537 pattern: Box::new(convert_expr(pattern)?),
1538 escape: esc,
1539 negated: *negated,
1540 })
1541 }
1542 sp::Expr::ILike {
1543 expr: e,
1544 negated,
1545 pattern,
1546 escape_char,
1547 ..
1548 } => {
1549 let esc = escape_char
1550 .as_ref()
1551 .map(convert_escape_value)
1552 .transpose()?
1553 .map(Box::new);
1554 Ok(Expr::Like {
1555 expr: Box::new(convert_expr(e)?),
1556 pattern: Box::new(convert_expr(pattern)?),
1557 escape: esc,
1558 negated: *negated,
1559 })
1560 }
1561 sp::Expr::Case {
1562 operand,
1563 conditions,
1564 else_result,
1565 ..
1566 } => {
1567 let op = operand
1568 .as_ref()
1569 .map(|e| convert_expr(e))
1570 .transpose()?
1571 .map(Box::new);
1572 let conds: Vec<(Expr, Expr)> = conditions
1573 .iter()
1574 .map(|cw| Ok((convert_expr(&cw.condition)?, convert_expr(&cw.result)?)))
1575 .collect::<Result<_>>()?;
1576 let else_r = else_result
1577 .as_ref()
1578 .map(|e| convert_expr(e))
1579 .transpose()?
1580 .map(Box::new);
1581 Ok(Expr::Case {
1582 operand: op,
1583 conditions: conds,
1584 else_result: else_r,
1585 })
1586 }
1587 sp::Expr::Cast {
1588 expr: e,
1589 data_type: dt,
1590 ..
1591 } => {
1592 let target = convert_data_type(dt)?;
1593 Ok(Expr::Cast {
1594 expr: Box::new(convert_expr(e)?),
1595 data_type: target,
1596 })
1597 }
1598 sp::Expr::Substring {
1599 expr: e,
1600 substring_from,
1601 substring_for,
1602 ..
1603 } => {
1604 let mut args = vec![convert_expr(e)?];
1605 if let Some(from) = substring_from {
1606 args.push(convert_expr(from)?);
1607 }
1608 if let Some(f) = substring_for {
1609 args.push(convert_expr(f)?);
1610 }
1611 Ok(Expr::Function {
1612 name: "SUBSTR".into(),
1613 args,
1614 distinct: false,
1615 })
1616 }
1617 sp::Expr::Trim {
1618 expr: e,
1619 trim_where,
1620 trim_what,
1621 trim_characters,
1622 } => {
1623 let fn_name = match trim_where {
1624 Some(sp::TrimWhereField::Leading) => "LTRIM",
1625 Some(sp::TrimWhereField::Trailing) => "RTRIM",
1626 _ => "TRIM",
1627 };
1628 let mut args = vec![convert_expr(e)?];
1629 if let Some(what) = trim_what {
1630 args.push(convert_expr(what)?);
1631 } else if let Some(chars) = trim_characters {
1632 if let Some(first) = chars.first() {
1633 args.push(convert_expr(first)?);
1634 }
1635 }
1636 Ok(Expr::Function {
1637 name: fn_name.into(),
1638 args,
1639 distinct: false,
1640 })
1641 }
1642 sp::Expr::Ceil { expr: e, .. } => Ok(Expr::Function {
1643 name: "CEIL".into(),
1644 args: vec![convert_expr(e)?],
1645 distinct: false,
1646 }),
1647 sp::Expr::Floor { expr: e, .. } => Ok(Expr::Function {
1648 name: "FLOOR".into(),
1649 args: vec![convert_expr(e)?],
1650 distinct: false,
1651 }),
1652 sp::Expr::Position { expr: e, r#in } => Ok(Expr::Function {
1653 name: "INSTR".into(),
1654 args: vec![convert_expr(r#in)?, convert_expr(e)?],
1655 distinct: false,
1656 }),
1657 sp::Expr::TypedString(ts) => {
1659 let raw = match &ts.value.value {
1660 sp::Value::SingleQuotedString(s) => s.clone(),
1661 sp::Value::DoubleQuotedString(s) => s.clone(),
1662 other => other.to_string(),
1663 };
1664 convert_typed_string(&ts.data_type, &raw)
1665 }
1666 sp::Expr::Interval(iv) => convert_interval_expr(iv),
1668 sp::Expr::Extract { field, expr: e, .. } => {
1670 let field_name = match field {
1671 sp::DateTimeField::Year => "year",
1672 sp::DateTimeField::Month => "month",
1673 sp::DateTimeField::Week(_) => "week",
1674 sp::DateTimeField::Day => "day",
1675 sp::DateTimeField::Date => "day",
1676 sp::DateTimeField::Hour => "hour",
1677 sp::DateTimeField::Minute => "minute",
1678 sp::DateTimeField::Second => "second",
1679 sp::DateTimeField::Millisecond => "milliseconds",
1680 sp::DateTimeField::Microsecond => "microseconds",
1681 sp::DateTimeField::Microseconds => "microseconds",
1682 sp::DateTimeField::Milliseconds => "milliseconds",
1683 sp::DateTimeField::Dow => "dow",
1684 sp::DateTimeField::Isodow => "isodow",
1685 sp::DateTimeField::Doy => "doy",
1686 sp::DateTimeField::Epoch => "epoch",
1687 sp::DateTimeField::Quarter => "quarter",
1688 sp::DateTimeField::Decade => "decade",
1689 sp::DateTimeField::Century => "century",
1690 sp::DateTimeField::Millennium => "millennium",
1691 sp::DateTimeField::Isoyear => "isoyear",
1692 sp::DateTimeField::Julian => "julian",
1693 other => {
1694 return Err(SqlError::InvalidExtractField(format!("{other:?}")));
1695 }
1696 };
1697 Ok(Expr::Function {
1698 name: "EXTRACT".into(),
1699 args: vec![
1700 Expr::Literal(Value::Text(field_name.into())),
1701 convert_expr(e)?,
1702 ],
1703 distinct: false,
1704 })
1705 }
1706 sp::Expr::AtTimeZone {
1708 timestamp,
1709 time_zone,
1710 } => Ok(Expr::Function {
1711 name: "AT_TIMEZONE".into(),
1712 args: vec![convert_expr(timestamp)?, convert_expr(time_zone)?],
1713 distinct: false,
1714 }),
1715 _ => Err(SqlError::Unsupported(format!("expression: {expr}"))),
1716 }
1717}
1718
1719fn convert_value(val: &sp::Value) -> Result<Expr> {
1720 match val {
1721 sp::Value::Number(n, _) => {
1722 if let Ok(i) = n.parse::<i64>() {
1723 Ok(Expr::Literal(Value::Integer(i)))
1724 } else if let Ok(f) = n.parse::<f64>() {
1725 Ok(Expr::Literal(Value::Real(f)))
1726 } else {
1727 Err(SqlError::InvalidValue(format!("cannot parse number: {n}")))
1728 }
1729 }
1730 sp::Value::SingleQuotedString(s) => Ok(Expr::Literal(Value::Text(s.as_str().into()))),
1731 sp::Value::Boolean(b) => Ok(Expr::Literal(Value::Boolean(*b))),
1732 sp::Value::Null => Ok(Expr::Literal(Value::Null)),
1733 sp::Value::Placeholder(s) => {
1734 let idx_str = s
1735 .strip_prefix('$')
1736 .ok_or_else(|| SqlError::Parse(format!("invalid placeholder: {s}")))?;
1737 let idx: usize = idx_str
1738 .parse()
1739 .map_err(|_| SqlError::Parse(format!("invalid placeholder index: {s}")))?;
1740 if idx == 0 {
1741 return Err(SqlError::Parse("placeholder index must be >= 1".into()));
1742 }
1743 Ok(Expr::Parameter(idx))
1744 }
1745 _ => Err(SqlError::Unsupported(format!("value type: {val}"))),
1746 }
1747}
1748
1749fn convert_typed_string(dt: &sp::DataType, value: &str) -> Result<Expr> {
1750 let s = value.trim_matches('\'');
1751 match dt {
1752 sp::DataType::Date => {
1753 let d = crate::datetime::parse_date(s)?;
1754 Ok(Expr::Literal(Value::Date(d)))
1755 }
1756 sp::DataType::Time(_, _) => {
1757 let t = crate::datetime::parse_time(s)?;
1758 Ok(Expr::Literal(Value::Time(t)))
1759 }
1760 sp::DataType::Timestamp(_, _) => {
1761 let t = crate::datetime::parse_timestamp(s)?;
1762 Ok(Expr::Literal(Value::Timestamp(t)))
1763 }
1764 sp::DataType::Interval { .. } => {
1765 let (months, days, micros) = crate::datetime::parse_interval(s)?;
1766 Ok(Expr::Literal(Value::Interval {
1767 months,
1768 days,
1769 micros,
1770 }))
1771 }
1772 _ => {
1773 let target = convert_data_type(dt)?;
1774 Ok(Expr::Cast {
1775 expr: Box::new(Expr::Literal(Value::Text(s.into()))),
1776 data_type: target,
1777 })
1778 }
1779 }
1780}
1781
1782fn convert_interval_expr(iv: &sp::Interval) -> Result<Expr> {
1783 let raw = match iv.value.as_ref() {
1784 sp::Expr::Value(v) => match &v.value {
1785 sp::Value::SingleQuotedString(s) => s.clone(),
1786 sp::Value::Number(n, _) => n.clone(),
1787 other => {
1788 return Err(SqlError::InvalidIntervalLiteral(format!(
1789 "unsupported inner value: {other}"
1790 )))
1791 }
1792 },
1793 other => {
1794 return Err(SqlError::InvalidIntervalLiteral(format!(
1795 "unsupported inner expr: {other}"
1796 )))
1797 }
1798 };
1799
1800 let with_unit = if let Some(field) = &iv.leading_field {
1802 let unit_name = match field {
1803 sp::DateTimeField::Year => "years",
1804 sp::DateTimeField::Month => "months",
1805 sp::DateTimeField::Week(_) => "weeks",
1806 sp::DateTimeField::Day => "days",
1807 sp::DateTimeField::Hour => "hours",
1808 sp::DateTimeField::Minute => "minutes",
1809 sp::DateTimeField::Second => "seconds",
1810 _ => {
1811 return Err(SqlError::InvalidIntervalLiteral(format!(
1812 "unsupported leading field: {field:?}"
1813 )))
1814 }
1815 };
1816 format!("{raw} {unit_name}")
1817 } else {
1818 raw
1819 };
1820
1821 let (months, days, micros) = crate::datetime::parse_interval(&with_unit)?;
1822 Ok(Expr::Literal(Value::Interval {
1823 months,
1824 days,
1825 micros,
1826 }))
1827}
1828
1829fn convert_escape_value(val: &sp::Value) -> Result<Expr> {
1830 match val {
1831 sp::Value::SingleQuotedString(s) => Ok(Expr::Literal(Value::Text(s.as_str().into()))),
1832 _ => Err(SqlError::Unsupported(format!("ESCAPE value: {val}"))),
1833 }
1834}
1835
1836fn convert_bin_op(op: &sp::BinaryOperator) -> Result<BinOp> {
1837 match op {
1838 sp::BinaryOperator::Plus => Ok(BinOp::Add),
1839 sp::BinaryOperator::Minus => Ok(BinOp::Sub),
1840 sp::BinaryOperator::Multiply => Ok(BinOp::Mul),
1841 sp::BinaryOperator::Divide => Ok(BinOp::Div),
1842 sp::BinaryOperator::Modulo => Ok(BinOp::Mod),
1843 sp::BinaryOperator::Eq => Ok(BinOp::Eq),
1844 sp::BinaryOperator::NotEq => Ok(BinOp::NotEq),
1845 sp::BinaryOperator::Lt => Ok(BinOp::Lt),
1846 sp::BinaryOperator::Gt => Ok(BinOp::Gt),
1847 sp::BinaryOperator::LtEq => Ok(BinOp::LtEq),
1848 sp::BinaryOperator::GtEq => Ok(BinOp::GtEq),
1849 sp::BinaryOperator::And => Ok(BinOp::And),
1850 sp::BinaryOperator::Or => Ok(BinOp::Or),
1851 sp::BinaryOperator::StringConcat => Ok(BinOp::Concat),
1852 _ => Err(SqlError::Unsupported(format!("binary op: {op}"))),
1853 }
1854}
1855
1856fn convert_function(func: &sp::Function) -> Result<Expr> {
1857 let name = object_name_to_string(&func.name).to_ascii_uppercase();
1858
1859 let (args, is_count_star, distinct) = match &func.args {
1860 sp::FunctionArguments::List(list) => {
1861 let distinct = matches!(
1862 list.duplicate_treatment,
1863 Some(sp::DuplicateTreatment::Distinct)
1864 );
1865 if list.args.is_empty() && name == "COUNT" {
1866 (vec![], true, distinct)
1867 } else {
1868 let mut count_star = false;
1869 let args = list
1870 .args
1871 .iter()
1872 .map(|arg| match arg {
1873 sp::FunctionArg::Unnamed(sp::FunctionArgExpr::Expr(e)) => convert_expr(e),
1874 sp::FunctionArg::Unnamed(sp::FunctionArgExpr::Wildcard) => {
1875 if name == "COUNT" {
1876 count_star = true;
1877 Ok(Expr::CountStar)
1878 } else {
1879 Err(SqlError::Unsupported(format!("{name}(*)")))
1880 }
1881 }
1882 _ => Err(SqlError::Unsupported(format!(
1883 "function arg type in {name}"
1884 ))),
1885 })
1886 .collect::<Result<Vec<_>>>()?;
1887 if name == "COUNT" && args.len() == 1 && count_star {
1888 (vec![], true, distinct)
1889 } else {
1890 (args, false, distinct)
1891 }
1892 }
1893 }
1894 sp::FunctionArguments::None => {
1895 if name == "COUNT" {
1896 (vec![], true, false)
1897 } else {
1898 (vec![], false, false)
1899 }
1900 }
1901 sp::FunctionArguments::Subquery(_) => {
1902 return Err(SqlError::Unsupported("subquery in function".into()));
1903 }
1904 };
1905
1906 if let Some(over) = &func.over {
1908 let spec = match over {
1909 sp::WindowType::WindowSpec(ws) => convert_window_spec(ws)?,
1910 sp::WindowType::NamedWindow(_) => {
1911 return Err(SqlError::Unsupported("named windows".into()));
1912 }
1913 };
1914 return Ok(Expr::WindowFunction { name, args, spec });
1915 }
1916
1917 if is_count_star {
1919 return Ok(Expr::CountStar);
1920 }
1921
1922 if name == "COALESCE" {
1923 if args.is_empty() {
1924 return Err(SqlError::Parse(
1925 "COALESCE requires at least one argument".into(),
1926 ));
1927 }
1928 return Ok(Expr::Coalesce(args));
1929 }
1930
1931 if name == "NULLIF" {
1932 if args.len() != 2 {
1933 return Err(SqlError::Parse(
1934 "NULLIF requires exactly two arguments".into(),
1935 ));
1936 }
1937 return Ok(Expr::Case {
1938 operand: None,
1939 conditions: vec![(
1940 Expr::BinaryOp {
1941 left: Box::new(args[0].clone()),
1942 op: BinOp::Eq,
1943 right: Box::new(args[1].clone()),
1944 },
1945 Expr::Literal(Value::Null),
1946 )],
1947 else_result: Some(Box::new(args[0].clone())),
1948 });
1949 }
1950
1951 if name == "IIF" {
1952 if args.len() != 3 {
1953 return Err(SqlError::Parse(
1954 "IIF requires exactly three arguments".into(),
1955 ));
1956 }
1957 return Ok(Expr::Case {
1958 operand: None,
1959 conditions: vec![(args[0].clone(), args[1].clone())],
1960 else_result: Some(Box::new(args[2].clone())),
1961 });
1962 }
1963
1964 Ok(Expr::Function {
1965 name,
1966 args,
1967 distinct,
1968 })
1969}
1970
1971fn convert_window_spec(ws: &sp::WindowSpec) -> Result<WindowSpec> {
1972 let partition_by = ws
1973 .partition_by
1974 .iter()
1975 .map(convert_expr)
1976 .collect::<Result<Vec<_>>>()?;
1977 let order_by = ws
1978 .order_by
1979 .iter()
1980 .map(convert_order_by_expr)
1981 .collect::<Result<Vec<_>>>()?;
1982 let frame = ws
1983 .window_frame
1984 .as_ref()
1985 .map(convert_window_frame)
1986 .transpose()?;
1987 Ok(WindowSpec {
1988 partition_by,
1989 order_by,
1990 frame,
1991 })
1992}
1993
1994fn convert_window_frame(wf: &sp::WindowFrame) -> Result<WindowFrame> {
1995 let units = match wf.units {
1996 sp::WindowFrameUnits::Rows => WindowFrameUnits::Rows,
1997 sp::WindowFrameUnits::Range => WindowFrameUnits::Range,
1998 sp::WindowFrameUnits::Groups => {
1999 return Err(SqlError::Unsupported("GROUPS window frame".into()));
2000 }
2001 };
2002 let start = convert_window_frame_bound(&wf.start_bound)?;
2003 let end = match &wf.end_bound {
2004 Some(b) => convert_window_frame_bound(b)?,
2005 None => WindowFrameBound::CurrentRow,
2006 };
2007 Ok(WindowFrame { units, start, end })
2008}
2009
2010fn convert_window_frame_bound(b: &sp::WindowFrameBound) -> Result<WindowFrameBound> {
2011 match b {
2012 sp::WindowFrameBound::CurrentRow => Ok(WindowFrameBound::CurrentRow),
2013 sp::WindowFrameBound::Preceding(None) => Ok(WindowFrameBound::UnboundedPreceding),
2014 sp::WindowFrameBound::Preceding(Some(e)) => {
2015 Ok(WindowFrameBound::Preceding(Box::new(convert_expr(e)?)))
2016 }
2017 sp::WindowFrameBound::Following(None) => Ok(WindowFrameBound::UnboundedFollowing),
2018 sp::WindowFrameBound::Following(Some(e)) => {
2019 Ok(WindowFrameBound::Following(Box::new(convert_expr(e)?)))
2020 }
2021 }
2022}
2023
2024fn convert_select_item(item: &sp::SelectItem) -> Result<SelectColumn> {
2025 match item {
2026 sp::SelectItem::Wildcard(_) => Ok(SelectColumn::AllColumns),
2027 sp::SelectItem::UnnamedExpr(e) => {
2028 let expr = convert_expr(e)?;
2029 Ok(SelectColumn::Expr { expr, alias: None })
2030 }
2031 sp::SelectItem::ExprWithAlias { expr, alias } => {
2032 let expr = convert_expr(expr)?;
2033 Ok(SelectColumn::Expr {
2034 expr,
2035 alias: Some(alias.value.clone()),
2036 })
2037 }
2038 sp::SelectItem::QualifiedWildcard(_, _) => {
2039 Err(SqlError::Unsupported("qualified wildcard (table.*)".into()))
2040 }
2041 }
2042}
2043
2044fn convert_order_by_expr(expr: &sp::OrderByExpr) -> Result<OrderByItem> {
2045 let e = convert_expr(&expr.expr)?;
2046 let descending = expr.options.asc.map(|asc| !asc).unwrap_or(false);
2047 let nulls_first = expr.options.nulls_first;
2048
2049 Ok(OrderByItem {
2050 expr: e,
2051 descending,
2052 nulls_first,
2053 })
2054}
2055
2056fn convert_data_type(dt: &sp::DataType) -> Result<DataType> {
2057 match dt {
2058 sp::DataType::Int(_)
2059 | sp::DataType::Integer(_)
2060 | sp::DataType::BigInt(_)
2061 | sp::DataType::SmallInt(_)
2062 | sp::DataType::TinyInt(_)
2063 | sp::DataType::Int2(_)
2064 | sp::DataType::Int4(_)
2065 | sp::DataType::Int8(_) => Ok(DataType::Integer),
2066
2067 sp::DataType::Real
2068 | sp::DataType::Double(..)
2069 | sp::DataType::DoublePrecision
2070 | sp::DataType::Float(_)
2071 | sp::DataType::Float4
2072 | sp::DataType::Float64 => Ok(DataType::Real),
2073
2074 sp::DataType::Varchar(_)
2075 | sp::DataType::Text
2076 | sp::DataType::Char(_)
2077 | sp::DataType::Character(_)
2078 | sp::DataType::String(_) => Ok(DataType::Text),
2079
2080 sp::DataType::Blob(_) | sp::DataType::Bytea => Ok(DataType::Blob),
2081
2082 sp::DataType::Boolean | sp::DataType::Bool => Ok(DataType::Boolean),
2083
2084 sp::DataType::Date => Ok(DataType::Date),
2085 sp::DataType::Time(_, _) => Ok(DataType::Time),
2086 sp::DataType::Timestamp(_, _) => Ok(DataType::Timestamp),
2087 sp::DataType::Interval { .. } => Ok(DataType::Interval),
2088
2089 _ => Err(SqlError::Unsupported(format!("data type: {dt}"))),
2090 }
2091}
2092
2093fn object_name_to_string(name: &sp::ObjectName) -> String {
2094 name.0
2095 .iter()
2096 .filter_map(|p| match p {
2097 sp::ObjectNamePart::Identifier(ident) => Some(ident.value.clone()),
2098 _ => None,
2099 })
2100 .collect::<Vec<_>>()
2101 .join(".")
2102}
2103
2104#[cfg(test)]
2105mod tests {
2106 use super::*;
2107
2108 #[test]
2109 fn parse_create_table() {
2110 let stmt = parse_sql(
2111 "CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT NOT NULL, age INTEGER)",
2112 )
2113 .unwrap();
2114
2115 match stmt {
2116 Statement::CreateTable(ct) => {
2117 assert_eq!(ct.name, "users");
2118 assert_eq!(ct.columns.len(), 3);
2119 assert_eq!(ct.columns[0].name, "id");
2120 assert_eq!(ct.columns[0].data_type, DataType::Integer);
2121 assert!(ct.columns[0].is_primary_key);
2122 assert!(!ct.columns[0].nullable);
2123 assert_eq!(ct.columns[1].name, "name");
2124 assert_eq!(ct.columns[1].data_type, DataType::Text);
2125 assert!(!ct.columns[1].nullable);
2126 assert_eq!(ct.columns[2].name, "age");
2127 assert!(ct.columns[2].nullable);
2128 assert_eq!(ct.primary_key, vec!["id"]);
2129 }
2130 _ => panic!("expected CreateTable"),
2131 }
2132 }
2133
2134 #[test]
2135 fn parse_create_table_if_not_exists() {
2136 let stmt = parse_sql("CREATE TABLE IF NOT EXISTS t (id INT PRIMARY KEY)").unwrap();
2137 match stmt {
2138 Statement::CreateTable(ct) => assert!(ct.if_not_exists),
2139 _ => panic!("expected CreateTable"),
2140 }
2141 }
2142
2143 #[test]
2144 fn parse_drop_table() {
2145 let stmt = parse_sql("DROP TABLE users").unwrap();
2146 match stmt {
2147 Statement::DropTable(dt) => {
2148 assert_eq!(dt.name, "users");
2149 assert!(!dt.if_exists);
2150 }
2151 _ => panic!("expected DropTable"),
2152 }
2153 }
2154
2155 #[test]
2156 fn parse_drop_table_if_exists() {
2157 let stmt = parse_sql("DROP TABLE IF EXISTS users").unwrap();
2158 match stmt {
2159 Statement::DropTable(dt) => assert!(dt.if_exists),
2160 _ => panic!("expected DropTable"),
2161 }
2162 }
2163
2164 #[test]
2165 fn parse_insert() {
2166 let stmt =
2167 parse_sql("INSERT INTO users (id, name) VALUES (1, 'Alice'), (2, 'Bob')").unwrap();
2168
2169 match stmt {
2170 Statement::Insert(ins) => {
2171 assert_eq!(ins.table, "users");
2172 assert_eq!(ins.columns, vec!["id", "name"]);
2173 let values = match &ins.source {
2174 InsertSource::Values(v) => v,
2175 _ => panic!("expected Values"),
2176 };
2177 assert_eq!(values.len(), 2);
2178 assert!(matches!(values[0][0], Expr::Literal(Value::Integer(1))));
2179 assert!(matches!(&values[0][1], Expr::Literal(Value::Text(s)) if s == "Alice"));
2180 }
2181 _ => panic!("expected Insert"),
2182 }
2183 }
2184
2185 #[test]
2186 fn parse_select_all() {
2187 let stmt = parse_sql("SELECT * FROM users").unwrap();
2188 match stmt {
2189 Statement::Select(sq) => match sq.body {
2190 QueryBody::Select(sel) => {
2191 assert_eq!(sel.from, "users");
2192 assert!(matches!(sel.columns[0], SelectColumn::AllColumns));
2193 assert!(sel.where_clause.is_none());
2194 }
2195 _ => panic!("expected QueryBody::Select"),
2196 },
2197 _ => panic!("expected Select"),
2198 }
2199 }
2200
2201 #[test]
2202 fn parse_select_where() {
2203 let stmt = parse_sql("SELECT id, name FROM users WHERE age > 18").unwrap();
2204 match stmt {
2205 Statement::Select(sq) => match sq.body {
2206 QueryBody::Select(sel) => {
2207 assert_eq!(sel.columns.len(), 2);
2208 assert!(sel.where_clause.is_some());
2209 }
2210 _ => panic!("expected QueryBody::Select"),
2211 },
2212 _ => panic!("expected Select"),
2213 }
2214 }
2215
2216 #[test]
2217 fn parse_select_order_limit() {
2218 let stmt = parse_sql("SELECT * FROM users ORDER BY name ASC LIMIT 10 OFFSET 5").unwrap();
2219 match stmt {
2220 Statement::Select(sq) => match sq.body {
2221 QueryBody::Select(sel) => {
2222 assert_eq!(sel.order_by.len(), 1);
2223 assert!(!sel.order_by[0].descending);
2224 assert!(sel.limit.is_some());
2225 assert!(sel.offset.is_some());
2226 }
2227 _ => panic!("expected QueryBody::Select"),
2228 },
2229 _ => panic!("expected Select"),
2230 }
2231 }
2232
2233 #[test]
2234 fn parse_update() {
2235 let stmt = parse_sql("UPDATE users SET name = 'Bob' WHERE id = 1").unwrap();
2236 match stmt {
2237 Statement::Update(upd) => {
2238 assert_eq!(upd.table, "users");
2239 assert_eq!(upd.assignments.len(), 1);
2240 assert_eq!(upd.assignments[0].0, "name");
2241 assert!(upd.where_clause.is_some());
2242 }
2243 _ => panic!("expected Update"),
2244 }
2245 }
2246
2247 #[test]
2248 fn parse_delete() {
2249 let stmt = parse_sql("DELETE FROM users WHERE id = 1").unwrap();
2250 match stmt {
2251 Statement::Delete(del) => {
2252 assert_eq!(del.table, "users");
2253 assert!(del.where_clause.is_some());
2254 }
2255 _ => panic!("expected Delete"),
2256 }
2257 }
2258
2259 #[test]
2260 fn parse_aggregate() {
2261 let stmt = parse_sql("SELECT COUNT(*), SUM(age) FROM users").unwrap();
2262 match stmt {
2263 Statement::Select(sq) => match sq.body {
2264 QueryBody::Select(sel) => {
2265 assert_eq!(sel.columns.len(), 2);
2266 match &sel.columns[0] {
2267 SelectColumn::Expr {
2268 expr: Expr::CountStar,
2269 ..
2270 } => {}
2271 other => panic!("expected CountStar, got {other:?}"),
2272 }
2273 }
2274 _ => panic!("expected QueryBody::Select"),
2275 },
2276 _ => panic!("expected Select"),
2277 }
2278 }
2279
2280 #[test]
2281 fn parse_group_by_having() {
2282 let stmt = parse_sql(
2283 "SELECT department, COUNT(*) FROM employees GROUP BY department HAVING COUNT(*) > 5",
2284 )
2285 .unwrap();
2286 match stmt {
2287 Statement::Select(sq) => match sq.body {
2288 QueryBody::Select(sel) => {
2289 assert_eq!(sel.group_by.len(), 1);
2290 assert!(sel.having.is_some());
2291 }
2292 _ => panic!("expected QueryBody::Select"),
2293 },
2294 _ => panic!("expected Select"),
2295 }
2296 }
2297
2298 #[test]
2299 fn parse_expressions() {
2300 let stmt = parse_sql("SELECT id + 1, -price, NOT active FROM items").unwrap();
2301 match stmt {
2302 Statement::Select(sq) => match sq.body {
2303 QueryBody::Select(sel) => {
2304 assert_eq!(sel.columns.len(), 3);
2305 match &sel.columns[0] {
2307 SelectColumn::Expr {
2308 expr: Expr::BinaryOp { op: BinOp::Add, .. },
2309 ..
2310 } => {}
2311 other => panic!("expected BinaryOp Add, got {other:?}"),
2312 }
2313 match &sel.columns[1] {
2315 SelectColumn::Expr {
2316 expr:
2317 Expr::UnaryOp {
2318 op: UnaryOp::Neg, ..
2319 },
2320 ..
2321 } => {}
2322 other => panic!("expected UnaryOp Neg, got {other:?}"),
2323 }
2324 match &sel.columns[2] {
2326 SelectColumn::Expr {
2327 expr:
2328 Expr::UnaryOp {
2329 op: UnaryOp::Not, ..
2330 },
2331 ..
2332 } => {}
2333 other => panic!("expected UnaryOp Not, got {other:?}"),
2334 }
2335 }
2336 _ => panic!("expected QueryBody::Select"),
2337 },
2338 _ => panic!("expected Select"),
2339 }
2340 }
2341
2342 #[test]
2343 fn parse_is_null() {
2344 let stmt = parse_sql("SELECT * FROM t WHERE x IS NULL").unwrap();
2345 match stmt {
2346 Statement::Select(sq) => match sq.body {
2347 QueryBody::Select(sel) => {
2348 assert!(matches!(sel.where_clause, Some(Expr::IsNull(_))));
2349 }
2350 _ => panic!("expected QueryBody::Select"),
2351 },
2352 _ => panic!("expected Select"),
2353 }
2354 }
2355
/// A bare `JOIN ... ON` is an inner join that records the joined table name
/// and keeps its ON condition.
#[test]
fn parse_inner_join() {
    let stmt = parse_sql("SELECT * FROM a JOIN b ON a.id = b.id").unwrap();
    let Statement::Select(sq) = stmt else {
        panic!("expected Select");
    };
    let QueryBody::Select(sel) = sq.body else {
        panic!("expected QueryBody::Select");
    };
    assert_eq!(sel.from, "a");
    assert_eq!(sel.joins.len(), 1);

    let join = &sel.joins[0];
    assert_eq!(join.join_type, JoinType::Inner);
    assert_eq!(join.table.name, "b");
    assert!(join.on_clause.is_some());
}
2373
/// The explicit `INNER JOIN` spelling produces the same `JoinType::Inner`
/// as a bare `JOIN`.
#[test]
fn parse_inner_join_explicit() {
    let Statement::Select(sq) =
        parse_sql("SELECT * FROM a INNER JOIN b ON a.id = b.a_id").unwrap()
    else {
        panic!("expected Select");
    };
    let QueryBody::Select(sel) = sq.body else {
        panic!("expected QueryBody::Select");
    };
    assert_eq!(sel.joins.len(), 1);
    assert_eq!(sel.joins[0].join_type, JoinType::Inner);
}
2388
/// `CROSS JOIN` yields a `JoinType::Cross` join with no ON condition.
#[test]
fn parse_cross_join() {
    let stmt = parse_sql("SELECT * FROM a CROSS JOIN b").unwrap();
    let Statement::Select(sq) = stmt else {
        panic!("expected Select");
    };
    let QueryBody::Select(sel) = sq.body else {
        panic!("expected QueryBody::Select");
    };
    assert_eq!(sel.joins.len(), 1);

    let cross = &sel.joins[0];
    assert_eq!(cross.join_type, JoinType::Cross);
    assert!(cross.on_clause.is_none());
}
2404
/// `LEFT JOIN ... ON` is parsed as a single `JoinType::Left` join.
#[test]
fn parse_left_join() {
    let Statement::Select(sq) =
        parse_sql("SELECT * FROM a LEFT JOIN b ON a.id = b.a_id").unwrap()
    else {
        panic!("expected Select");
    };
    let QueryBody::Select(sel) = sq.body else {
        panic!("expected QueryBody::Select");
    };
    assert_eq!(sel.joins.len(), 1);
    assert_eq!(sel.joins[0].join_type, JoinType::Left);
}
2419
/// Aliases on the FROM table and on a joined table are stored separately from
/// the table names (`from`/`from_alias` and `table.name`/`table.alias`).
#[test]
fn parse_table_alias() {
    let stmt =
        parse_sql("SELECT u.id FROM users u JOIN orders o ON u.id = o.user_id").unwrap();
    let Statement::Select(sq) = stmt else {
        panic!("expected Select");
    };
    let QueryBody::Select(sel) = sq.body else {
        panic!("expected QueryBody::Select");
    };
    assert_eq!(sel.from, "users");
    assert_eq!(sel.from_alias.as_deref(), Some("u"));

    let joined = &sel.joins[0].table;
    assert_eq!(joined.name, "orders");
    assert_eq!(joined.alias.as_deref(), Some("o"));
}
2436
/// Two chained `JOIN ... ON` clauses produce two inner joins, in the order
/// they appear in the SQL text, each keeping its own ON condition.
#[test]
fn parse_multi_join() {
    let stmt =
        parse_sql("SELECT * FROM a JOIN b ON a.id = b.a_id JOIN c ON b.id = c.b_id").unwrap();
    let Statement::Select(sq) = stmt else {
        panic!("expected Select");
    };
    let QueryBody::Select(sel) = sq.body else {
        panic!("expected QueryBody::Select");
    };
    assert_eq!(sel.from, "a");
    assert_eq!(sel.joins.len(), 2);

    // The join list must preserve source order: `b` is joined before `c`.
    assert_eq!(sel.joins[0].table.name, "b");
    assert_eq!(sel.joins[1].table.name, "c");

    // A bare `JOIN` is an inner join; each one must carry its own ON clause.
    for join in &sel.joins {
        assert_eq!(join.join_type, JoinType::Inner);
        assert!(join.on_clause.is_some());
    }
}
2451
/// `alias.column` references in the projection list become
/// `Expr::QualifiedColumn { table, column }`. Every projected column is
/// checked, not just the first.
#[test]
fn parse_qualified_column() {
    let stmt = parse_sql("SELECT u.id, u.name FROM users u").unwrap();
    let Statement::Select(sq) = stmt else {
        panic!("expected Select");
    };
    let QueryBody::Select(sel) = sq.body else {
        panic!("expected QueryBody::Select");
    };
    assert_eq!(sel.columns.len(), 2);

    for (col, expected_column) in sel.columns.iter().zip(["id", "name"]) {
        match col {
            SelectColumn::Expr {
                expr: Expr::QualifiedColumn { table, column },
                ..
            } => {
                assert_eq!(table, "u");
                assert_eq!(column, expected_column);
            }
            other => panic!("expected QualifiedColumn, got {other:?}"),
        }
    }
}
2472
/// A subquery used as the FROM source must be rejected with an error.
#[test]
fn reject_subquery() {
    let result = parse_sql("SELECT * FROM (SELECT 1)");
    assert!(result.is_err());
}
2477
/// SQL type names map onto the engine's storage types:
/// INT/BIGINT/SMALLINT -> Integer, REAL/DOUBLE PRECISION -> Real,
/// VARCHAR(n) -> Text, BOOLEAN -> Boolean, BLOB/BYTEA -> Blob.
///
/// The test is table-driven. It asserts the exact column count, so a parser
/// that drops or invents columns fails here rather than passing silently.
#[test]
fn parse_type_mapping() {
    let stmt = parse_sql(
        "CREATE TABLE t (a INT PRIMARY KEY, b BIGINT, c SMALLINT, d REAL, e DOUBLE PRECISION, f VARCHAR(255), g BOOLEAN, h BLOB, i BYTEA)"
    ).unwrap();
    let Statement::CreateTable(ct) = stmt else {
        panic!("expected CreateTable");
    };

    let expected = [
        ("a", DataType::Integer),
        ("b", DataType::Integer),
        ("c", DataType::Integer),
        ("d", DataType::Real),
        ("e", DataType::Real),
        ("f", DataType::Text),
        ("g", DataType::Boolean),
        ("h", DataType::Blob),
        ("i", DataType::Blob),
    ];
    assert_eq!(ct.columns.len(), expected.len());

    for (col, (name, data_type)) in ct.columns.iter().zip(expected) {
        assert_eq!(col.name, name);
        assert_eq!(col.data_type, data_type);
    }
}
2498
/// `true`/`false` in a VALUES row become `Value::Boolean` literals.
#[test]
fn parse_boolean_literals() {
    let Statement::Insert(ins) = parse_sql("INSERT INTO t (a, b) VALUES (true, false)").unwrap()
    else {
        panic!("expected Insert");
    };
    let InsertSource::Values(rows) = &ins.source else {
        panic!("expected Values");
    };
    let row = &rows[0];
    assert!(matches!(row[0], Expr::Literal(Value::Boolean(true))));
    assert!(matches!(row[1], Expr::Literal(Value::Boolean(false))));
}
2514
/// `NULL` in a VALUES row becomes `Expr::Literal(Value::Null)`.
#[test]
fn parse_null_literal() {
    let stmt = parse_sql("INSERT INTO t (a) VALUES (NULL)").unwrap();
    let Statement::Insert(ins) = stmt else {
        panic!("expected Insert");
    };
    let InsertSource::Values(rows) = &ins.source else {
        panic!("expected Values");
    };
    assert!(matches!(rows[0][0], Expr::Literal(Value::Null)));
}
2529
/// `expr AS name` in the projection list records the alias on the column.
#[test]
fn parse_alias() {
    let Statement::Select(sq) = parse_sql("SELECT id AS user_id FROM users").unwrap() else {
        panic!("expected Select");
    };
    let QueryBody::Select(sel) = sq.body else {
        panic!("expected QueryBody::Select");
    };
    let first = &sel.columns[0];
    if let SelectColumn::Expr { alias: Some(alias), .. } = first {
        assert_eq!(alias, "user_id");
    } else {
        panic!("expected alias, got {first:?}");
    }
}
2544
/// A bare `BEGIN` parses to `Statement::Begin`.
#[test]
fn parse_begin() {
    assert!(matches!(parse_sql("BEGIN").unwrap(), Statement::Begin));
}
2550
/// The optional `TRANSACTION` keyword after `BEGIN` is accepted and yields
/// the same `Statement::Begin`.
#[test]
fn parse_begin_transaction() {
    assert!(matches!(
        parse_sql("BEGIN TRANSACTION").unwrap(),
        Statement::Begin
    ));
}
2556
/// `COMMIT` parses to `Statement::Commit`.
#[test]
fn parse_commit() {
    assert!(matches!(parse_sql("COMMIT").unwrap(), Statement::Commit));
}
2562
/// A bare `ROLLBACK` (no `TO`) parses to `Statement::Rollback`.
#[test]
fn parse_rollback() {
    assert!(matches!(parse_sql("ROLLBACK").unwrap(), Statement::Rollback));
}
2568
/// `SAVEPOINT <name>` yields a `Savepoint` statement carrying the name.
#[test]
fn parse_savepoint() {
    let stmt = parse_sql("SAVEPOINT sp1").unwrap();
    if let Statement::Savepoint(name) = &stmt {
        assert_eq!(name, "sp1");
    } else {
        panic!("expected Savepoint, got {stmt:?}");
    }
}
2577
/// Savepoint names are folded to lowercase, so `My_SP` is stored as `my_sp`.
#[test]
fn parse_savepoint_case_insensitive() {
    let stmt = parse_sql("SAVEPOINT My_SP").unwrap();
    if let Statement::Savepoint(name) = &stmt {
        assert_eq!(name, "my_sp");
    } else {
        panic!("expected Savepoint, got {stmt:?}");
    }
}
2586
/// `RELEASE SAVEPOINT <name>` yields a `ReleaseSavepoint` with that name.
#[test]
fn parse_release_savepoint() {
    let stmt = parse_sql("RELEASE SAVEPOINT sp1").unwrap();
    if let Statement::ReleaseSavepoint(name) = &stmt {
        assert_eq!(name, "sp1");
    } else {
        panic!("expected ReleaseSavepoint, got {stmt:?}");
    }
}
2595
/// The `SAVEPOINT` keyword is optional: `RELEASE <name>` is equivalent.
#[test]
fn parse_release_without_savepoint_keyword() {
    let stmt = parse_sql("RELEASE sp1").unwrap();
    if let Statement::ReleaseSavepoint(name) = &stmt {
        assert_eq!(name, "sp1");
    } else {
        panic!("expected ReleaseSavepoint, got {stmt:?}");
    }
}
2604
/// `ROLLBACK TO SAVEPOINT <name>` yields a `RollbackTo` with that name.
#[test]
fn parse_rollback_to_savepoint() {
    let stmt = parse_sql("ROLLBACK TO SAVEPOINT sp1").unwrap();
    if let Statement::RollbackTo(name) = &stmt {
        assert_eq!(name, "sp1");
    } else {
        panic!("expected RollbackTo, got {stmt:?}");
    }
}
2613
/// The `SAVEPOINT` keyword is optional: `ROLLBACK TO <name>` is equivalent.
#[test]
fn parse_rollback_to_without_savepoint_keyword() {
    let stmt = parse_sql("ROLLBACK TO sp1").unwrap();
    if let Statement::RollbackTo(name) = &stmt {
        assert_eq!(name, "sp1");
    } else {
        panic!("expected RollbackTo, got {stmt:?}");
    }
}
2622
/// The savepoint name in `ROLLBACK TO` is folded to lowercase, matching
/// how `SAVEPOINT` stores names.
#[test]
fn parse_rollback_to_case_insensitive() {
    let stmt = parse_sql("ROLLBACK TO My_SP").unwrap();
    if let Statement::RollbackTo(name) = &stmt {
        assert_eq!(name, "my_sp");
    } else {
        panic!("expected RollbackTo, got {stmt:?}");
    }
}
2631
/// `COMMIT AND CHAIN` must fail with `SqlError::Unsupported`.
#[test]
fn parse_commit_and_chain_rejected() {
    let is_unsupported = matches!(
        parse_sql("COMMIT AND CHAIN").unwrap_err(),
        SqlError::Unsupported(_)
    );
    assert!(is_unsupported);
}
2637
/// `ROLLBACK AND CHAIN` must fail with `SqlError::Unsupported`.
#[test]
fn parse_rollback_and_chain_rejected() {
    let is_unsupported = matches!(
        parse_sql("ROLLBACK AND CHAIN").unwrap_err(),
        SqlError::Unsupported(_)
    );
    assert!(is_unsupported);
}
2643
/// `SELECT DISTINCT` sets the `distinct` flag and keeps the projection list.
#[test]
fn parse_select_distinct() {
    let Statement::Select(sq) = parse_sql("SELECT DISTINCT name FROM users").unwrap() else {
        panic!("expected Select");
    };
    let QueryBody::Select(sel) = sq.body else {
        panic!("expected QueryBody::Select");
    };
    assert!(sel.distinct);
    assert_eq!(sel.columns.len(), 1);
}
2658
/// A plain `SELECT` leaves the `distinct` flag cleared.
#[test]
fn parse_select_without_distinct() {
    let Statement::Select(sq) = parse_sql("SELECT name FROM users").unwrap() else {
        panic!("expected Select");
    };
    let QueryBody::Select(sel) = sq.body else {
        panic!("expected QueryBody::Select");
    };
    assert!(!sel.distinct);
}
2672
/// `SELECT DISTINCT *` sets `distinct` and yields an `AllColumns` projection.
#[test]
fn parse_select_distinct_all_columns() {
    let Statement::Select(sq) = parse_sql("SELECT DISTINCT * FROM users").unwrap() else {
        panic!("expected Select");
    };
    let QueryBody::Select(sel) = sq.body else {
        panic!("expected QueryBody::Select");
    };
    assert!(sel.distinct);
    assert!(matches!(sel.columns[0], SelectColumn::AllColumns));
}
2687
/// `DISTINCT ON (...)` must produce a parse error.
#[test]
fn reject_distinct_on() {
    let result = parse_sql("SELECT DISTINCT ON (id) * FROM users");
    assert!(result.is_err());
}
2692
/// A plain `CREATE INDEX` records the index name, table and column list,
/// with both `unique` and `if_not_exists` cleared.
#[test]
fn parse_create_index() {
    let stmt = parse_sql("CREATE INDEX idx_name ON users (name)").unwrap();
    let Statement::CreateIndex(ci) = stmt else {
        panic!("expected CreateIndex");
    };
    assert_eq!(ci.index_name, "idx_name");
    assert_eq!(ci.table_name, "users");
    assert_eq!(ci.columns, ["name"]);
    assert!(!ci.unique);
    assert!(!ci.if_not_exists);
}
2707
/// `CREATE UNIQUE INDEX` sets the `unique` flag.
#[test]
fn parse_create_unique_index() {
    let Statement::CreateIndex(ci) =
        parse_sql("CREATE UNIQUE INDEX idx_email ON users (email)").unwrap()
    else {
        panic!("expected CreateIndex");
    };
    assert!(ci.unique);
    assert_eq!(ci.columns, ["email"]);
}
2719
/// `CREATE INDEX IF NOT EXISTS` sets the `if_not_exists` flag.
#[test]
fn parse_create_index_if_not_exists() {
    let stmt = parse_sql("CREATE INDEX IF NOT EXISTS idx_x ON t (a)").unwrap();
    let Statement::CreateIndex(ci) = stmt else {
        panic!("expected CreateIndex");
    };
    assert!(ci.if_not_exists);
}
2728
/// A composite index keeps its columns in declaration order.
#[test]
fn parse_create_index_multi_column() {
    let Statement::CreateIndex(ci) = parse_sql("CREATE INDEX idx_multi ON t (a, b, c)").unwrap()
    else {
        panic!("expected CreateIndex");
    };
    assert_eq!(ci.columns, ["a", "b", "c"]);
}
2739
/// `DROP INDEX <name>` records the name and leaves `if_exists` cleared.
#[test]
fn parse_drop_index() {
    let Statement::DropIndex(di) = parse_sql("DROP INDEX idx_name").unwrap() else {
        panic!("expected DropIndex");
    };
    assert_eq!(di.index_name, "idx_name");
    assert!(!di.if_exists);
}
2751
/// `DROP INDEX IF EXISTS` sets the `if_exists` flag.
#[test]
fn parse_drop_index_if_exists() {
    let stmt = parse_sql("DROP INDEX IF EXISTS idx_name").unwrap();
    let Statement::DropIndex(di) = stmt else {
        panic!("expected DropIndex");
    };
    assert!(di.if_exists);
}
2762
/// `EXPLAIN SELECT ...` wraps the parsed SELECT in `Statement::Explain`.
#[test]
fn parse_explain_select() {
    let stmt = parse_sql("EXPLAIN SELECT * FROM users WHERE id = 1").unwrap();
    let Statement::Explain(inner) = stmt else {
        panic!("expected Explain");
    };
    assert!(matches!(inner.as_ref(), Statement::Select(_)));
}
2773
/// `EXPLAIN INSERT ...` wraps the parsed INSERT in `Statement::Explain`.
///
/// The wrapped statement is checked too, matching `parse_explain_select`.
/// Checking only the wrapper would let a parser that mis-parses the inner
/// statement pass.
#[test]
fn parse_explain_insert() {
    let stmt = parse_sql("EXPLAIN INSERT INTO t (a) VALUES (1)").unwrap();
    let Statement::Explain(inner) = stmt else {
        panic!("expected Explain");
    };
    assert!(matches!(inner.as_ref(), Statement::Insert(_)));
}
2779
/// `EXPLAIN ANALYZE` must produce a parse error.
#[test]
fn reject_explain_analyze() {
    let result = parse_sql("EXPLAIN ANALYZE SELECT * FROM t");
    assert!(result.is_err());
}
2784
/// A `$1` placeholder on the right of a comparison becomes `Expr::Parameter(1)`.
#[test]
fn parse_parameter_placeholder() {
    let Statement::Select(sq) = parse_sql("SELECT * FROM t WHERE id = $1").unwrap() else {
        panic!("expected Select");
    };
    let QueryBody::Select(sel) = sq.body else {
        panic!("expected QueryBody::Select");
    };
    if let Some(Expr::BinaryOp { right, .. }) = &sel.where_clause {
        assert!(matches!(**right, Expr::Parameter(1)));
    } else {
        panic!(
            "expected BinaryOp with Parameter, got {:?}",
            sel.where_clause
        );
    }
}
2801
/// Each `$n` placeholder in a VALUES row becomes `Expr::Parameter(n)`.
#[test]
fn parse_multiple_parameters() {
    let Statement::Insert(ins) = parse_sql("INSERT INTO t (a, b) VALUES ($1, $2)").unwrap() else {
        panic!("expected Insert");
    };
    let InsertSource::Values(rows) = &ins.source else {
        panic!("expected Values");
    };
    let row = &rows[0];
    assert!(matches!(row[0], Expr::Parameter(1)));
    assert!(matches!(row[1], Expr::Parameter(2)));
}
2817
/// `INSERT INTO ... (cols) SELECT ...` records the target table and column
/// list, and keeps the full SELECT (source table, projection, WHERE) as the
/// insert source.
#[test]
fn parse_insert_select() {
    let stmt =
        parse_sql("INSERT INTO t2 (id, name) SELECT id, name FROM t1 WHERE id > 5").unwrap();
    let Statement::Insert(ins) = stmt else {
        panic!("expected Insert");
    };
    assert_eq!(ins.table, "t2");
    assert_eq!(ins.columns, ["id", "name"]);

    let InsertSource::Select(sq) = &ins.source else {
        panic!("expected InsertSource::Select");
    };
    let QueryBody::Select(sel) = &sq.body else {
        panic!("expected QueryBody::Select");
    };
    assert_eq!(sel.from, "t1");
    assert_eq!(sel.columns.len(), 2);
    assert!(sel.where_clause.is_some());
}
2841
/// `INSERT INTO t SELECT ...` without a column list leaves `columns` empty
/// and still uses the SELECT as the source.
#[test]
fn parse_insert_select_no_columns() {
    let Statement::Insert(ins) = parse_sql("INSERT INTO t2 SELECT * FROM t1").unwrap() else {
        panic!("expected Insert");
    };
    assert_eq!(ins.table, "t2");
    assert!(ins.columns.is_empty());
    assert!(matches!(&ins.source, InsertSource::Select(_)));
}
2854
/// The placeholder `$0` must produce a parse error (`$1` is accepted elsewhere
/// in these tests).
#[test]
fn reject_zero_parameter() {
    let result = parse_sql("SELECT $0 FROM t");
    assert!(result.is_err());
}
2859
/// `count_params` reports 3 for a statement using `$1` and `$3`. The result
/// follows the highest placeholder index, not the number of placeholders.
#[test]
fn count_params_basic() {
    let parsed = parse_sql("SELECT * FROM t WHERE a = $1 AND b = $3").unwrap();
    let n = count_params(&parsed);
    assert_eq!(n, 3);
}
2865
/// A statement without placeholders has a parameter count of zero.
#[test]
fn count_params_none() {
    let parsed = parse_sql("SELECT * FROM t WHERE a = 1").unwrap();
    let n = count_params(&parsed);
    assert_eq!(n, 0);
}
2871
/// A table-level `PRIMARY KEY (a)` constraint has three effects:
/// - it populates `primary_key`;
/// - it marks column `a` as the primary key and makes it non-nullable;
/// - it leaves every other column unflagged.
#[test]
fn parse_table_constraint_pk() {
    let stmt = parse_sql("CREATE TABLE t (a INTEGER, b TEXT, PRIMARY KEY (a))").unwrap();
    let Statement::CreateTable(ct) = stmt else {
        panic!("expected CreateTable");
    };
    assert_eq!(ct.primary_key, ["a"]);
    assert_eq!(ct.columns.len(), 2);

    assert!(ct.columns[0].is_primary_key);
    assert!(!ct.columns[0].nullable);

    // The constraint names only `a`; `b` must not be promoted to a key column.
    assert!(!ct.columns[1].is_primary_key);
}
2884}