1use sqlparser::ast as sp;
4use sqlparser::dialect::GenericDialect;
5use sqlparser::parser::Parser;
6
7use crate::error::{Result, SqlError};
8use crate::types::{DataType, Value};
9
/// A single SQL statement in this crate's own AST, produced by [`parse_sql`] from the
/// `sqlparser` AST (see `convert_statement`).
#[derive(Debug, Clone)]
pub enum Statement {
    /// `CREATE TABLE`.
    CreateTable(CreateTableStmt),
    /// `DROP TABLE` naming exactly one table.
    DropTable(DropTableStmt),
    /// `CREATE [UNIQUE] INDEX`.
    CreateIndex(CreateIndexStmt),
    /// `DROP INDEX` naming exactly one index.
    DropIndex(DropIndexStmt),
    /// `CREATE [OR REPLACE] VIEW` (materialized views are rejected during conversion).
    CreateView(CreateViewStmt),
    /// `DROP VIEW` naming exactly one view.
    DropView(DropViewStmt),
    /// `ALTER TABLE` with exactly one operation (boxed, presumably to keep the enum small).
    AlterTable(Box<AlterTableStmt>),
    /// `INSERT ... VALUES` or `INSERT ... SELECT`.
    Insert(InsertStmt),
    /// A query, possibly with CTEs and set operations.
    Select(Box<SelectQuery>),
    /// `UPDATE`.
    Update(UpdateStmt),
    /// `DELETE`.
    Delete(DeleteStmt),
    /// `BEGIN` / `START TRANSACTION`.
    Begin,
    /// `COMMIT` (`COMMIT AND CHAIN` is rejected).
    Commit,
    /// Plain `ROLLBACK` (`ROLLBACK AND CHAIN` is rejected).
    Rollback,
    /// `SAVEPOINT name`; the name is ASCII-lowercased during conversion.
    Savepoint(String),
    /// `RELEASE SAVEPOINT name`; the name is ASCII-lowercased during conversion.
    ReleaseSavepoint(String),
    /// `ROLLBACK TO [SAVEPOINT] name`; the name is ASCII-lowercased during conversion.
    RollbackTo(String),
    /// `SET TIME ZONE`: the contents of a quoted string, an identifier's name, or another
    /// literal rendered back to SQL text.
    SetTimezone(String),
    /// `EXPLAIN <statement>` (`EXPLAIN ANALYZE` is rejected).
    Explain(Box<Statement>),
}
32
/// `ALTER TABLE <table> <op>`; only a single operation per statement is supported.
#[derive(Debug, Clone)]
pub struct AlterTableStmt {
    /// Target table name as rendered by `object_name_to_string` (not lowercased here).
    pub table: String,
    /// The one operation to apply.
    pub op: AlterTableOp,
}
38
/// The supported `ALTER TABLE` operations.
#[derive(Debug, Clone)]
pub enum AlterTableOp {
    /// `ADD [COLUMN] [IF NOT EXISTS] <column definition>`.
    AddColumn {
        /// The new column, converted by `convert_column_def`.
        column: Box<ColumnSpec>,
        /// Inline `REFERENCES` clause on the new column, if any.
        foreign_key: Option<ForeignKeyDef>,
        if_not_exists: bool,
    },
    /// `DROP [COLUMN] [IF EXISTS] name` with exactly one column; name kept as written.
    DropColumn {
        name: String,
        if_exists: bool,
    },
    /// `RENAME COLUMN old TO new`; both names kept as written.
    RenameColumn {
        old_name: String,
        new_name: String,
    },
    /// `RENAME TO new_name` / `RENAME AS new_name`.
    RenameTable {
        new_name: String,
    },
}
58
/// `CREATE TABLE`, as produced by `convert_create_table`.
#[derive(Debug, Clone)]
pub struct CreateTableStmt {
    /// Table name as rendered by `object_name_to_string` (not lowercased here).
    pub name: String,
    /// Column definitions in declaration order.
    pub columns: Vec<ColumnSpec>,
    /// Primary-key column names: inline `PRIMARY KEY` columns first, then columns from a
    /// table-level `PRIMARY KEY (...)` that were not already listed.
    pub primary_key: Vec<String>,
    /// `IF NOT EXISTS` was given.
    pub if_not_exists: bool,
    /// Table-level `CHECK` constraints (subqueries are rejected during conversion).
    pub check_constraints: Vec<TableCheckConstraint>,
    /// Inline `REFERENCES` clauses followed by table-level `FOREIGN KEY` constraints.
    pub foreign_keys: Vec<ForeignKeyDef>,
    /// Inline `UNIQUE` on non-primary-key columns plus table-level `UNIQUE` constraints.
    pub unique_indices: Vec<UniqueIndexDef>,
}
69
/// A UNIQUE constraint declared as part of `CREATE TABLE`.
#[derive(Debug, Clone)]
pub struct UniqueIndexDef {
    /// Constraint name from `CONSTRAINT name UNIQUE (...)`; `None` for inline `UNIQUE`.
    pub name: Option<String>,
    /// Constrained column names, ASCII-lowercased.
    pub columns: Vec<String>,
}
75
/// A table-level `CHECK (...)` constraint.
#[derive(Debug, Clone)]
pub struct TableCheckConstraint {
    /// Optional constraint name.
    pub name: Option<String>,
    /// The converted predicate.
    pub expr: Expr,
    /// The predicate as rendered back to SQL text by `sqlparser`'s `Display`.
    pub sql: String,
}
82
/// A foreign-key constraint, either inline (`REFERENCES`) or table-level.
#[derive(Debug, Clone)]
pub struct ForeignKeyDef {
    /// Optional constraint name, kept as written.
    pub name: Option<String>,
    /// Referencing column names, ASCII-lowercased.
    pub columns: Vec<String>,
    /// Referenced table name, ASCII-lowercased.
    pub foreign_table: String,
    /// Referenced column names, ASCII-lowercased (may be empty if none were listed).
    pub referred_columns: Vec<String>,
}
90
/// A column definition from `CREATE TABLE` or `ALTER TABLE ... ADD COLUMN`.
#[derive(Debug, Clone)]
pub struct ColumnSpec {
    /// Column name as written (not lowercased here).
    pub name: String,
    pub data_type: DataType,
    /// `false` when declared `NOT NULL` or part of the primary key.
    pub nullable: bool,
    /// Declared primary key, inline or via a table-level constraint.
    pub is_primary_key: bool,
    /// Converted `DEFAULT` expression.
    pub default_expr: Option<Expr>,
    /// `DEFAULT` expression as rendered back to SQL text.
    pub default_sql: Option<String>,
    /// Converted inline `CHECK` predicate (subqueries rejected).
    pub check_expr: Option<Expr>,
    /// Inline `CHECK` predicate as rendered back to SQL text.
    pub check_sql: Option<String>,
    /// Optional name of the inline `CHECK` constraint.
    pub check_name: Option<String>,
}
103
/// `DROP TABLE [IF EXISTS] name`.
#[derive(Debug, Clone)]
pub struct DropTableStmt {
    /// Table name as rendered by `object_name_to_string`.
    pub name: String,
    pub if_exists: bool,
}
109
/// `CREATE [UNIQUE] INDEX [IF NOT EXISTS] name ON table (cols...)`.
#[derive(Debug, Clone)]
pub struct CreateIndexStmt {
    pub index_name: String,
    pub table_name: String,
    /// Indexed column names as written; expression indexes are rejected, and at least
    /// one column is required.
    pub columns: Vec<String>,
    pub unique: bool,
    pub if_not_exists: bool,
}
118
/// `DROP INDEX [IF EXISTS] name`.
#[derive(Debug, Clone)]
pub struct DropIndexStmt {
    pub index_name: String,
    pub if_exists: bool,
}
124
/// `CREATE [OR REPLACE] VIEW [IF NOT EXISTS] name [(cols)] AS query`.
#[derive(Debug, Clone)]
pub struct CreateViewStmt {
    pub name: String,
    /// The view's query rendered back to SQL text (validated by re-parsing at creation).
    pub sql: String,
    /// Explicit output column names, ASCII-lowercased; empty when none were given.
    pub column_aliases: Vec<String>,
    pub or_replace: bool,
    pub if_not_exists: bool,
}
133
/// `DROP VIEW [IF EXISTS] name`.
#[derive(Debug, Clone)]
pub struct DropViewStmt {
    pub name: String,
    pub if_exists: bool,
}
139
/// Where the rows of an `INSERT` come from.
#[derive(Debug, Clone)]
pub enum InsertSource {
    /// `VALUES (...), (...)`: one inner `Vec` per row.
    Values(Vec<Vec<Expr>>),
    /// `INSERT ... SELECT` (including any CTEs).
    Select(Box<SelectQuery>),
}
145
/// `INSERT INTO table [(cols)] <source>`.
#[derive(Debug, Clone)]
pub struct InsertStmt {
    /// Target table name, ASCII-lowercased.
    pub table: String,
    /// Explicit target column names, ASCII-lowercased; empty when no list was given.
    pub columns: Vec<String>,
    pub source: InsertSource,
}
152
/// A table reference with an optional alias (used by joins).
#[derive(Debug, Clone)]
pub struct TableRef {
    pub name: String,
    pub alias: Option<String>,
}
158
/// The supported join kinds.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum JoinType {
    /// `[INNER] JOIN`.
    Inner,
    /// `CROSS JOIN`.
    Cross,
    /// `LEFT [OUTER] JOIN`.
    Left,
    /// `RIGHT [OUTER] JOIN`.
    Right,
}
166
/// One join following the first table in a `FROM` clause.
#[derive(Debug, Clone)]
pub struct JoinClause {
    pub join_type: JoinType,
    /// The joined table.
    pub table: TableRef,
    /// The `ON` predicate; presumably `None` for joins without one (e.g. CROSS) —
    /// see `convert_join`.
    pub on_clause: Option<Expr>,
}
173
/// A single `SELECT` block (no CTEs, no set operations).
#[derive(Debug, Clone)]
pub struct SelectStmt {
    /// Projection list.
    pub columns: Vec<SelectColumn>,
    /// First `FROM` table name; an empty string when the query has no `FROM`.
    pub from: String,
    pub from_alias: Option<String>,
    /// Joins attached to the `FROM` table.
    pub joins: Vec<JoinClause>,
    /// `SELECT DISTINCT` (`DISTINCT ON` is rejected).
    pub distinct: bool,
    pub where_clause: Option<Expr>,
    /// Filled in by `convert_query_body` when this select is the whole query body;
    /// stays empty for operands of a compound select.
    pub order_by: Vec<OrderByItem>,
    /// See `order_by`.
    pub limit: Option<Expr>,
    /// See `order_by`.
    pub offset: Option<Expr>,
    pub group_by: Vec<Expr>,
    pub having: Option<Expr>,
}
188
/// Set operator joining two query bodies.
#[derive(Debug, Clone)]
pub enum SetOp {
    Union,
    Intersect,
    /// `EXCEPT`; `MINUS` is mapped here as well.
    Except,
}
195
/// `left <op> [ALL] right`, e.g. `SELECT ... UNION ALL SELECT ...`.
#[derive(Debug, Clone)]
pub struct CompoundSelect {
    pub op: SetOp,
    /// `true` for the `ALL` quantifier; `false` for none or `DISTINCT`.
    pub all: bool,
    pub left: Box<QueryBody>,
    pub right: Box<QueryBody>,
    /// `ORDER BY` applied to the combined result.
    pub order_by: Vec<OrderByItem>,
    /// `LIMIT` applied to the combined result.
    pub limit: Option<Expr>,
    /// `OFFSET` applied to the combined result.
    pub offset: Option<Expr>,
}
206
/// A query body: a single select or a tree of set operations.
#[derive(Debug, Clone)]
pub enum QueryBody {
    Select(Box<SelectStmt>),
    Compound(Box<CompoundSelect>),
}
212
/// One `name [(cols)] AS (query)` entry of a `WITH` clause.
#[derive(Debug, Clone)]
pub struct CteDefinition {
    /// CTE name, ASCII-lowercased; duplicates within one `WITH` are rejected.
    pub name: String,
    /// Explicit column names, ASCII-lowercased; empty when none were given.
    pub column_aliases: Vec<String>,
    pub body: QueryBody,
}
219
/// A complete query: optional `WITH` clause plus the main body.
#[derive(Debug, Clone)]
pub struct SelectQuery {
    /// CTEs in declaration order; empty when there is no `WITH`.
    pub ctes: Vec<CteDefinition>,
    /// `WITH RECURSIVE` was used.
    pub recursive: bool,
    pub body: QueryBody,
}
226
/// `UPDATE table SET col = expr, ... [WHERE ...]`.
#[derive(Debug, Clone)]
pub struct UpdateStmt {
    pub table: String,
    /// `(column, value)` pairs in the order written.
    pub assignments: Vec<(String, Expr)>,
    pub where_clause: Option<Expr>,
}
233
/// `DELETE FROM table [WHERE ...]`.
#[derive(Debug, Clone)]
pub struct DeleteStmt {
    pub table: String,
    pub where_clause: Option<Expr>,
}
239
/// One item of a `SELECT` projection list.
#[derive(Debug, Clone)]
pub enum SelectColumn {
    /// `*`.
    AllColumns,
    /// An expression with an optional `AS alias`.
    Expr { expr: Expr, alias: Option<String> },
}
245
/// One `ORDER BY` key.
#[derive(Debug, Clone)]
pub struct OrderByItem {
    pub expr: Expr,
    /// `DESC` was given.
    pub descending: bool,
    /// `Some(true)` for `NULLS FIRST`, `Some(false)` for `NULLS LAST`, `None` if unspecified.
    pub nulls_first: Option<bool>,
}
252
/// Scalar expression tree used throughout this crate's AST.
#[derive(Debug, Clone)]
pub enum Expr {
    /// A constant value.
    Literal(Value),
    /// An unqualified column reference, e.g. `price`.
    Column(String),
    /// A table-qualified column reference, e.g. `t.price`.
    QualifiedColumn {
        table: String,
        column: String,
    },
    /// `left <op> right`.
    BinaryOp {
        left: Box<Expr>,
        op: BinOp,
        right: Box<Expr>,
    },
    /// Prefix operator applied to `expr` (`-x`, `NOT x`).
    UnaryOp {
        op: UnaryOp,
        expr: Box<Expr>,
    },
    /// `expr IS NULL`.
    IsNull(Box<Expr>),
    /// `expr IS NOT NULL`.
    IsNotNull(Box<Expr>),
    /// Function call; `distinct` presumably marks `f(DISTINCT ...)` — set by
    /// `convert_expr`, which is not shown here.
    Function {
        name: String,
        args: Vec<Expr>,
        distinct: bool,
    },
    /// `COUNT(*)`.
    CountStar,
    /// `expr [NOT] IN (subquery)`; the subquery is a single select block (the type holds
    /// no CTEs or set operations).
    InSubquery {
        expr: Box<Expr>,
        subquery: Box<SelectStmt>,
        negated: bool,
    },
    /// `expr [NOT] IN (a, b, ...)`.
    InList {
        expr: Box<Expr>,
        list: Vec<Expr>,
        negated: bool,
    },
    /// `[NOT] EXISTS (subquery)`.
    Exists {
        subquery: Box<SelectStmt>,
        negated: bool,
    },
    /// A parenthesized subquery used as a scalar value.
    ScalarSubquery(Box<SelectStmt>),
    /// Membership test against a pre-built value set. Nothing in this section builds it;
    /// presumably a rewrite of a constant `InList`, with `has_null` recording a NULL in the
    /// original list — verify against its producer.
    InSet {
        expr: Box<Expr>,
        values: std::collections::HashSet<Value>,
        has_null: bool,
        negated: bool,
    },
    /// `expr [NOT] BETWEEN low AND high`.
    Between {
        expr: Box<Expr>,
        low: Box<Expr>,
        high: Box<Expr>,
        negated: bool,
    },
    /// `expr [NOT] LIKE pattern [ESCAPE escape]`.
    Like {
        expr: Box<Expr>,
        pattern: Box<Expr>,
        escape: Option<Box<Expr>>,
        negated: bool,
    },
    /// `CASE [operand] WHEN .. THEN .. [ELSE ..] END`; each `conditions` pair is
    /// `(WHEN expression, THEN result)`.
    Case {
        operand: Option<Box<Expr>>,
        conditions: Vec<(Expr, Expr)>,
        else_result: Option<Box<Expr>>,
    },
    /// `COALESCE(args...)`.
    Coalesce(Vec<Expr>),
    /// `CAST(expr AS data_type)`.
    Cast {
        expr: Box<Expr>,
        data_type: DataType,
    },
    /// Positional bind parameter; [`count_params`] reports the largest index seen
    /// (presumably 1-based, `$1` → 1 — confirm in `convert_expr`).
    Parameter(usize),
    /// Window function call `name(args) OVER (spec)`.
    WindowFunction {
        name: String,
        args: Vec<Expr>,
        spec: WindowSpec,
    },
}
329
/// The `OVER (...)` specification of a window function.
#[derive(Debug, Clone)]
pub struct WindowSpec {
    /// `PARTITION BY` keys.
    pub partition_by: Vec<Expr>,
    /// `ORDER BY` keys inside the window.
    pub order_by: Vec<OrderByItem>,
    /// Explicit frame clause, if any.
    pub frame: Option<WindowFrame>,
}
336
/// A window frame: `<units> BETWEEN <start> AND <end>`.
#[derive(Debug, Clone)]
pub struct WindowFrame {
    pub units: WindowFrameUnits,
    pub start: WindowFrameBound,
    pub end: WindowFrameBound,
}
343
/// Frame units: `ROWS`, `RANGE` or `GROUPS`.
#[derive(Debug, Clone, Copy)]
pub enum WindowFrameUnits {
    Rows,
    Range,
    Groups,
}
350
/// One end of a window frame.
#[derive(Debug, Clone)]
pub enum WindowFrameBound {
    /// `UNBOUNDED PRECEDING`.
    UnboundedPreceding,
    /// `<offset> PRECEDING`.
    Preceding(Box<Expr>),
    /// `CURRENT ROW`.
    CurrentRow,
    /// `<offset> FOLLOWING`.
    Following(Box<Expr>),
    /// `UNBOUNDED FOLLOWING`.
    UnboundedFollowing,
}
359
/// Binary operators: arithmetic, comparison, logical and string concatenation.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BinOp {
    Add,
    Sub,
    Mul,
    Div,
    Mod,
    Eq,
    NotEq,
    Lt,
    Gt,
    LtEq,
    GtEq,
    And,
    Or,
    /// String concatenation (presumably SQL `||`).
    Concat,
}
377
/// Prefix operators.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum UnaryOp {
    /// Arithmetic negation (`-x`).
    Neg,
    /// Logical negation (`NOT x`).
    Not,
}
383
384pub fn has_subquery(expr: &Expr) -> bool {
385 match expr {
386 Expr::InSubquery { .. } | Expr::Exists { .. } | Expr::ScalarSubquery(_) => true,
387 Expr::BinaryOp { left, right, .. } => has_subquery(left) || has_subquery(right),
388 Expr::UnaryOp { expr, .. } => has_subquery(expr),
389 Expr::IsNull(e) | Expr::IsNotNull(e) => has_subquery(e),
390 Expr::InList { expr, list, .. } => has_subquery(expr) || list.iter().any(has_subquery),
391 Expr::InSet { expr, .. } => has_subquery(expr),
392 Expr::Between {
393 expr, low, high, ..
394 } => has_subquery(expr) || has_subquery(low) || has_subquery(high),
395 Expr::Like {
396 expr,
397 pattern,
398 escape,
399 ..
400 } => {
401 has_subquery(expr)
402 || has_subquery(pattern)
403 || escape.as_ref().is_some_and(|e| has_subquery(e))
404 }
405 Expr::Case {
406 operand,
407 conditions,
408 else_result,
409 } => {
410 operand.as_ref().is_some_and(|e| has_subquery(e))
411 || conditions
412 .iter()
413 .any(|(c, r)| has_subquery(c) || has_subquery(r))
414 || else_result.as_ref().is_some_and(|e| has_subquery(e))
415 }
416 Expr::Coalesce(args) | Expr::Function { args, .. } => args.iter().any(has_subquery),
417 Expr::Cast { expr, .. } => has_subquery(expr),
418 _ => false,
419 }
420}
421
422pub fn parse_sql_expr(sql: &str) -> Result<Expr> {
425 let dialect = GenericDialect {};
426 let mut parser = Parser::new(&dialect)
427 .try_with_sql(sql)
428 .map_err(|e| SqlError::Parse(e.to_string()))?;
429 let sp_expr = parser
430 .parse_expr()
431 .map_err(|e| SqlError::Parse(e.to_string()))?;
432 convert_expr(&sp_expr)
433}
434
435pub fn parse_sql(sql: &str) -> Result<Statement> {
436 let dialect = GenericDialect {};
437 let stmts = Parser::parse_sql(&dialect, sql).map_err(|e| SqlError::Parse(e.to_string()))?;
438
439 if stmts.is_empty() {
440 return Err(SqlError::Parse("empty SQL".into()));
441 }
442 if stmts.len() > 1 {
443 return Err(SqlError::Unsupported("multiple statements".into()));
444 }
445
446 convert_statement(stmts.into_iter().next().unwrap())
447}
448
449pub fn count_params(stmt: &Statement) -> usize {
451 let mut max_idx = 0usize;
452 visit_exprs_stmt(stmt, &mut |e| {
453 if let Expr::Parameter(n) = e {
454 max_idx = max_idx.max(*n);
455 }
456 });
457 max_idx
458}
459
460fn visit_exprs_stmt(stmt: &Statement, visitor: &mut impl FnMut(&Expr)) {
461 match stmt {
462 Statement::Select(sq) => {
463 for cte in &sq.ctes {
464 visit_exprs_query_body(&cte.body, visitor);
465 }
466 visit_exprs_query_body(&sq.body, visitor);
467 }
468 Statement::Insert(ins) => match &ins.source {
469 InsertSource::Values(rows) => {
470 for row in rows {
471 for e in row {
472 visit_expr(e, visitor);
473 }
474 }
475 }
476 InsertSource::Select(sq) => {
477 for cte in &sq.ctes {
478 visit_exprs_query_body(&cte.body, visitor);
479 }
480 visit_exprs_query_body(&sq.body, visitor);
481 }
482 },
483 Statement::Update(upd) => {
484 for (_, e) in &upd.assignments {
485 visit_expr(e, visitor);
486 }
487 if let Some(w) = &upd.where_clause {
488 visit_expr(w, visitor);
489 }
490 }
491 Statement::Delete(del) => {
492 if let Some(w) = &del.where_clause {
493 visit_expr(w, visitor);
494 }
495 }
496 Statement::Explain(inner) => visit_exprs_stmt(inner, visitor),
497 _ => {}
498 }
499}
500
501fn visit_exprs_query_body(body: &QueryBody, visitor: &mut impl FnMut(&Expr)) {
502 match body {
503 QueryBody::Select(sel) => visit_exprs_select(sel, visitor),
504 QueryBody::Compound(comp) => {
505 visit_exprs_query_body(&comp.left, visitor);
506 visit_exprs_query_body(&comp.right, visitor);
507 for o in &comp.order_by {
508 visit_expr(&o.expr, visitor);
509 }
510 if let Some(l) = &comp.limit {
511 visit_expr(l, visitor);
512 }
513 if let Some(o) = &comp.offset {
514 visit_expr(o, visitor);
515 }
516 }
517 }
518}
519
520fn visit_exprs_select(sel: &SelectStmt, visitor: &mut impl FnMut(&Expr)) {
521 for col in &sel.columns {
522 if let SelectColumn::Expr { expr, .. } = col {
523 visit_expr(expr, visitor);
524 }
525 }
526 for j in &sel.joins {
527 if let Some(on) = &j.on_clause {
528 visit_expr(on, visitor);
529 }
530 }
531 if let Some(w) = &sel.where_clause {
532 visit_expr(w, visitor);
533 }
534 for o in &sel.order_by {
535 visit_expr(&o.expr, visitor);
536 }
537 if let Some(l) = &sel.limit {
538 visit_expr(l, visitor);
539 }
540 if let Some(o) = &sel.offset {
541 visit_expr(o, visitor);
542 }
543 for g in &sel.group_by {
544 visit_expr(g, visitor);
545 }
546 if let Some(h) = &sel.having {
547 visit_expr(h, visitor);
548 }
549}
550
551fn visit_expr(expr: &Expr, visitor: &mut impl FnMut(&Expr)) {
552 visitor(expr);
553 match expr {
554 Expr::BinaryOp { left, right, .. } => {
555 visit_expr(left, visitor);
556 visit_expr(right, visitor);
557 }
558 Expr::UnaryOp { expr: e, .. } | Expr::IsNull(e) | Expr::IsNotNull(e) => {
559 visit_expr(e, visitor);
560 }
561 Expr::Function { args, .. } | Expr::Coalesce(args) => {
562 for a in args {
563 visit_expr(a, visitor);
564 }
565 }
566 Expr::InSubquery {
567 expr: e, subquery, ..
568 } => {
569 visit_expr(e, visitor);
570 visit_exprs_select(subquery, visitor);
571 }
572 Expr::InList { expr: e, list, .. } => {
573 visit_expr(e, visitor);
574 for l in list {
575 visit_expr(l, visitor);
576 }
577 }
578 Expr::Exists { subquery, .. } => visit_exprs_select(subquery, visitor),
579 Expr::ScalarSubquery(sq) => visit_exprs_select(sq, visitor),
580 Expr::InSet { expr: e, .. } => visit_expr(e, visitor),
581 Expr::Between {
582 expr: e, low, high, ..
583 } => {
584 visit_expr(e, visitor);
585 visit_expr(low, visitor);
586 visit_expr(high, visitor);
587 }
588 Expr::Like {
589 expr: e,
590 pattern,
591 escape,
592 ..
593 } => {
594 visit_expr(e, visitor);
595 visit_expr(pattern, visitor);
596 if let Some(esc) = escape {
597 visit_expr(esc, visitor);
598 }
599 }
600 Expr::Case {
601 operand,
602 conditions,
603 else_result,
604 } => {
605 if let Some(op) = operand {
606 visit_expr(op, visitor);
607 }
608 for (cond, then) in conditions {
609 visit_expr(cond, visitor);
610 visit_expr(then, visitor);
611 }
612 if let Some(el) = else_result {
613 visit_expr(el, visitor);
614 }
615 }
616 Expr::Cast { expr: e, .. } => visit_expr(e, visitor),
617 Expr::WindowFunction { args, spec, .. } => {
618 for a in args {
619 visit_expr(a, visitor);
620 }
621 for p in &spec.partition_by {
622 visit_expr(p, visitor);
623 }
624 for o in &spec.order_by {
625 visit_expr(&o.expr, visitor);
626 }
627 if let Some(ref frame) = spec.frame {
628 if let WindowFrameBound::Preceding(e) | WindowFrameBound::Following(e) =
629 &frame.start
630 {
631 visit_expr(e, visitor);
632 }
633 if let WindowFrameBound::Preceding(e) | WindowFrameBound::Following(e) = &frame.end
634 {
635 visit_expr(e, visitor);
636 }
637 }
638 }
639 Expr::Literal(_)
640 | Expr::Column(_)
641 | Expr::QualifiedColumn { .. }
642 | Expr::CountStar
643 | Expr::Parameter(_) => {}
644 }
645}
646
647fn convert_statement(stmt: sp::Statement) -> Result<Statement> {
648 match stmt {
649 sp::Statement::CreateTable(ct) => convert_create_table(ct),
650 sp::Statement::CreateIndex(ci) => convert_create_index(ci),
651 sp::Statement::Drop {
652 object_type: sp::ObjectType::Table,
653 if_exists,
654 names,
655 ..
656 } => {
657 if names.len() != 1 {
658 return Err(SqlError::Unsupported("multi-table DROP".into()));
659 }
660 Ok(Statement::DropTable(DropTableStmt {
661 name: object_name_to_string(&names[0]),
662 if_exists,
663 }))
664 }
665 sp::Statement::Drop {
666 object_type: sp::ObjectType::Index,
667 if_exists,
668 names,
669 ..
670 } => {
671 if names.len() != 1 {
672 return Err(SqlError::Unsupported("multi-index DROP".into()));
673 }
674 Ok(Statement::DropIndex(DropIndexStmt {
675 index_name: object_name_to_string(&names[0]),
676 if_exists,
677 }))
678 }
679 sp::Statement::CreateView(cv) => convert_create_view(cv),
680 sp::Statement::Drop {
681 object_type: sp::ObjectType::View,
682 if_exists,
683 names,
684 ..
685 } => {
686 if names.len() != 1 {
687 return Err(SqlError::Unsupported("multi-view DROP".into()));
688 }
689 Ok(Statement::DropView(DropViewStmt {
690 name: object_name_to_string(&names[0]),
691 if_exists,
692 }))
693 }
694 sp::Statement::AlterTable(at) => convert_alter_table(at),
695 sp::Statement::Insert(insert) => convert_insert(insert),
696 sp::Statement::Query(query) => convert_query(*query),
697 sp::Statement::Update(update) => convert_update(update),
698 sp::Statement::Delete(delete) => convert_delete(delete),
699 sp::Statement::StartTransaction { .. } => Ok(Statement::Begin),
700 sp::Statement::Commit { chain: true, .. } => {
701 Err(SqlError::Unsupported("COMMIT AND CHAIN".into()))
702 }
703 sp::Statement::Commit { .. } => Ok(Statement::Commit),
704 sp::Statement::Rollback { chain: true, .. } => {
705 Err(SqlError::Unsupported("ROLLBACK AND CHAIN".into()))
706 }
707 sp::Statement::Rollback {
708 savepoint: Some(name),
709 ..
710 } => Ok(Statement::RollbackTo(name.value.to_ascii_lowercase())),
711 sp::Statement::Rollback { .. } => Ok(Statement::Rollback),
712 sp::Statement::Savepoint { name } => {
713 Ok(Statement::Savepoint(name.value.to_ascii_lowercase()))
714 }
715 sp::Statement::ReleaseSavepoint { name } => {
716 Ok(Statement::ReleaseSavepoint(name.value.to_ascii_lowercase()))
717 }
718 sp::Statement::Set(sp::Set::SetTimeZone { value, .. }) => {
719 let zone = match value {
721 sp::Expr::Value(v) => match &v.value {
722 sp::Value::SingleQuotedString(s) => s.clone(),
723 sp::Value::DoubleQuotedString(s) => s.clone(),
724 other => other.to_string(),
725 },
726 sp::Expr::Identifier(ident) => ident.value.clone(),
727 other => {
728 return Err(SqlError::Parse(format!(
729 "SET TIME ZONE expects a string literal or identifier, got: {other}"
730 )))
731 }
732 };
733 Ok(Statement::SetTimezone(zone))
734 }
735 sp::Statement::Explain {
736 statement, analyze, ..
737 } => {
738 if analyze {
739 return Err(SqlError::Unsupported("EXPLAIN ANALYZE".into()));
740 }
741 let inner = convert_statement(*statement)?;
742 Ok(Statement::Explain(Box::new(inner)))
743 }
744 _ => Err(SqlError::Unsupported(format!("statement type: {}", stmt))),
745 }
746}
747
748fn convert_column_def(
751 col_def: &sp::ColumnDef,
752) -> Result<(ColumnSpec, Option<ForeignKeyDef>, bool, bool)> {
753 let col_name = col_def.name.value.clone();
754 let data_type = convert_data_type(&col_def.data_type)?;
755 let mut nullable = true;
756 let mut is_primary_key = false;
757 let mut is_unique = false;
758 let mut default_expr = None;
759 let mut default_sql = None;
760 let mut check_expr = None;
761 let mut check_sql = None;
762 let mut check_name = None;
763 let mut fk_def = None;
764
765 for opt in &col_def.options {
766 match &opt.option {
767 sp::ColumnOption::NotNull => nullable = false,
768 sp::ColumnOption::Null => nullable = true,
769 sp::ColumnOption::PrimaryKey(_) => {
770 is_primary_key = true;
771 nullable = false;
772 }
773 sp::ColumnOption::Unique(_) => is_unique = true,
774 sp::ColumnOption::Default(expr) => {
775 default_sql = Some(expr.to_string());
776 default_expr = Some(convert_expr(expr)?);
777 }
778 sp::ColumnOption::Check(check) => {
779 check_sql = Some(check.expr.to_string());
780 let converted = convert_expr(&check.expr)?;
781 if has_subquery(&converted) {
782 return Err(SqlError::Unsupported("subquery in CHECK constraint".into()));
783 }
784 check_expr = Some(converted);
785 check_name = check.name.as_ref().map(|n| n.value.clone());
786 }
787 sp::ColumnOption::ForeignKey(fk) => {
788 convert_fk_actions(&fk.on_delete, &fk.on_update)?;
789 let ftable = object_name_to_string(&fk.foreign_table).to_ascii_lowercase();
790 let referred: Vec<String> = fk
791 .referred_columns
792 .iter()
793 .map(|i| i.value.to_ascii_lowercase())
794 .collect();
795 fk_def = Some(ForeignKeyDef {
796 name: fk.name.as_ref().map(|n| n.value.clone()),
797 columns: vec![col_name.to_ascii_lowercase()],
798 foreign_table: ftable,
799 referred_columns: referred,
800 });
801 }
802 _ => {}
803 }
804 }
805
806 let spec = ColumnSpec {
807 name: col_name,
808 data_type,
809 nullable,
810 is_primary_key,
811 default_expr,
812 default_sql,
813 check_expr,
814 check_sql,
815 check_name,
816 };
817 Ok((spec, fk_def, is_primary_key, is_unique))
818}
819
820fn convert_create_table(ct: sp::CreateTable) -> Result<Statement> {
821 let name = object_name_to_string(&ct.name);
822 let if_not_exists = ct.if_not_exists;
823
824 let mut columns = Vec::new();
825 let mut inline_pk: Vec<String> = Vec::new();
826 let mut foreign_keys: Vec<ForeignKeyDef> = Vec::new();
827 let mut unique_indices: Vec<UniqueIndexDef> = Vec::new();
828
829 for col_def in &ct.columns {
830 let (spec, fk_def, was_pk, was_unique) = convert_column_def(col_def)?;
831 if was_pk {
832 inline_pk.push(spec.name.clone());
833 }
834 if let Some(fk) = fk_def {
835 foreign_keys.push(fk);
836 }
837 if was_unique && !was_pk {
838 unique_indices.push(UniqueIndexDef {
839 name: None,
840 columns: vec![spec.name.to_ascii_lowercase()],
841 });
842 }
843 columns.push(spec);
844 }
845
846 let mut check_constraints: Vec<TableCheckConstraint> = Vec::new();
847
848 for constraint in &ct.constraints {
849 match constraint {
850 sp::TableConstraint::PrimaryKey(pk_constraint) => {
851 for idx_col in &pk_constraint.columns {
852 let col_name = match &idx_col.column.expr {
853 sp::Expr::Identifier(ident) => ident.value.clone(),
854 _ => continue,
855 };
856 if !inline_pk.contains(&col_name) {
857 inline_pk.push(col_name.clone());
858 }
859 if let Some(col) = columns.iter_mut().find(|c| c.name == col_name) {
860 col.nullable = false;
861 col.is_primary_key = true;
862 }
863 }
864 }
865 sp::TableConstraint::Check(check) => {
866 let sql = check.expr.to_string();
867 let converted = convert_expr(&check.expr)?;
868 if has_subquery(&converted) {
869 return Err(SqlError::Unsupported("subquery in CHECK constraint".into()));
870 }
871 check_constraints.push(TableCheckConstraint {
872 name: check.name.as_ref().map(|n| n.value.clone()),
873 expr: converted,
874 sql,
875 });
876 }
877 sp::TableConstraint::ForeignKey(fk) => {
878 convert_fk_actions(&fk.on_delete, &fk.on_update)?;
879 let cols: Vec<String> = fk
880 .columns
881 .iter()
882 .map(|i| i.value.to_ascii_lowercase())
883 .collect();
884 let ftable = object_name_to_string(&fk.foreign_table).to_ascii_lowercase();
885 let referred: Vec<String> = fk
886 .referred_columns
887 .iter()
888 .map(|i| i.value.to_ascii_lowercase())
889 .collect();
890 foreign_keys.push(ForeignKeyDef {
891 name: fk.name.as_ref().map(|n| n.value.clone()),
892 columns: cols,
893 foreign_table: ftable,
894 referred_columns: referred,
895 });
896 }
897 sp::TableConstraint::Unique(u) => {
898 let cols: Vec<String> = u
899 .columns
900 .iter()
901 .filter_map(|idx_col| match &idx_col.column.expr {
902 sp::Expr::Identifier(ident) => Some(ident.value.to_ascii_lowercase()),
903 _ => None,
904 })
905 .collect();
906 if !cols.is_empty() {
907 unique_indices.push(UniqueIndexDef {
908 name: u.name.as_ref().map(|n| n.value.clone()),
909 columns: cols,
910 });
911 }
912 }
913 _ => {}
914 }
915 }
916
917 Ok(Statement::CreateTable(CreateTableStmt {
918 name,
919 columns,
920 primary_key: inline_pk,
921 if_not_exists,
922 check_constraints,
923 foreign_keys,
924 unique_indices,
925 }))
926}
927
928fn convert_alter_table(at: sp::AlterTable) -> Result<Statement> {
929 let table = object_name_to_string(&at.name);
930 if at.operations.len() != 1 {
931 return Err(SqlError::Unsupported(
932 "ALTER TABLE with multiple operations".into(),
933 ));
934 }
935 let op = match at.operations.into_iter().next().unwrap() {
936 sp::AlterTableOperation::AddColumn {
937 column_def,
938 if_not_exists,
939 ..
940 } => {
941 let (spec, fk, _was_pk, _was_unique) = convert_column_def(&column_def)?;
942 AlterTableOp::AddColumn {
943 column: Box::new(spec),
944 foreign_key: fk,
945 if_not_exists,
946 }
947 }
948 sp::AlterTableOperation::DropColumn {
949 column_names,
950 if_exists,
951 ..
952 } => {
953 if column_names.len() != 1 {
954 return Err(SqlError::Unsupported(
955 "DROP COLUMN with multiple columns".into(),
956 ));
957 }
958 AlterTableOp::DropColumn {
959 name: column_names.into_iter().next().unwrap().value,
960 if_exists,
961 }
962 }
963 sp::AlterTableOperation::RenameColumn {
964 old_column_name,
965 new_column_name,
966 } => AlterTableOp::RenameColumn {
967 old_name: old_column_name.value,
968 new_name: new_column_name.value,
969 },
970 sp::AlterTableOperation::RenameTable { table_name } => {
971 let new_name = match table_name {
972 sp::RenameTableNameKind::To(name) | sp::RenameTableNameKind::As(name) => {
973 object_name_to_string(&name)
974 }
975 };
976 AlterTableOp::RenameTable { new_name }
977 }
978 other => {
979 return Err(SqlError::Unsupported(format!(
980 "ALTER TABLE operation: {other}"
981 )));
982 }
983 };
984 Ok(Statement::AlterTable(Box::new(AlterTableStmt {
985 table,
986 op,
987 })))
988}
989
990fn convert_fk_actions(
991 on_delete: &Option<sp::ReferentialAction>,
992 on_update: &Option<sp::ReferentialAction>,
993) -> Result<()> {
994 for action in [on_delete, on_update] {
995 match action {
996 None
997 | Some(sp::ReferentialAction::Restrict)
998 | Some(sp::ReferentialAction::NoAction) => {}
999 Some(other) => {
1000 return Err(SqlError::Unsupported(format!(
1001 "FOREIGN KEY action: {other}"
1002 )));
1003 }
1004 }
1005 }
1006 Ok(())
1007}
1008
1009fn convert_create_index(ci: sp::CreateIndex) -> Result<Statement> {
1010 let index_name = ci
1011 .name
1012 .as_ref()
1013 .map(object_name_to_string)
1014 .ok_or_else(|| SqlError::Parse("index name required".into()))?;
1015
1016 let table_name = object_name_to_string(&ci.table_name);
1017
1018 let columns: Vec<String> = ci
1019 .columns
1020 .iter()
1021 .map(|idx_col| match &idx_col.column.expr {
1022 sp::Expr::Identifier(ident) => Ok(ident.value.clone()),
1023 other => Err(SqlError::Unsupported(format!("expression index: {other}"))),
1024 })
1025 .collect::<Result<_>>()?;
1026
1027 if columns.is_empty() {
1028 return Err(SqlError::Parse(
1029 "index must have at least one column".into(),
1030 ));
1031 }
1032
1033 Ok(Statement::CreateIndex(CreateIndexStmt {
1034 index_name,
1035 table_name,
1036 columns,
1037 unique: ci.unique,
1038 if_not_exists: ci.if_not_exists,
1039 }))
1040}
1041
1042fn convert_create_view(cv: sp::CreateView) -> Result<Statement> {
1043 let name = object_name_to_string(&cv.name);
1044
1045 if cv.materialized {
1046 return Err(SqlError::Unsupported("MATERIALIZED VIEW".into()));
1047 }
1048
1049 let sql = cv.query.to_string();
1050
1051 let dialect = GenericDialect {};
1052 let test = Parser::parse_sql(&dialect, &sql).map_err(|e| SqlError::Parse(e.to_string()))?;
1053 if test.is_empty() {
1054 return Err(SqlError::Parse("empty view definition".into()));
1055 }
1056 match &test[0] {
1057 sp::Statement::Query(_) => {}
1058 _ => {
1059 return Err(SqlError::Parse(
1060 "view body must be a SELECT statement".into(),
1061 ))
1062 }
1063 }
1064
1065 let column_aliases: Vec<String> = cv
1066 .columns
1067 .iter()
1068 .map(|c| c.name.value.to_ascii_lowercase())
1069 .collect();
1070
1071 Ok(Statement::CreateView(CreateViewStmt {
1072 name,
1073 sql,
1074 column_aliases,
1075 or_replace: cv.or_replace,
1076 if_not_exists: cv.if_not_exists,
1077 }))
1078}
1079
1080fn convert_insert(insert: sp::Insert) -> Result<Statement> {
1081 let table = match &insert.table {
1082 sp::TableObject::TableName(name) => object_name_to_string(name).to_ascii_lowercase(),
1083 _ => return Err(SqlError::Unsupported("INSERT into non-table object".into())),
1084 };
1085
1086 let columns: Vec<String> = insert
1087 .columns
1088 .iter()
1089 .map(|c| c.value.to_ascii_lowercase())
1090 .collect();
1091
1092 let query = insert
1093 .source
1094 .ok_or_else(|| SqlError::Parse("INSERT requires VALUES or SELECT".into()))?;
1095
1096 let source = match *query.body {
1097 sp::SetExpr::Values(sp::Values { rows, .. }) => {
1098 let mut result = Vec::new();
1099 for row in rows {
1100 let mut exprs = Vec::new();
1101 for expr in row {
1102 exprs.push(convert_expr(&expr)?);
1103 }
1104 result.push(exprs);
1105 }
1106 InsertSource::Values(result)
1107 }
1108 _ => {
1109 let (ctes, recursive) = if let Some(ref with) = query.with {
1110 convert_with(with)?
1111 } else {
1112 (vec![], false)
1113 };
1114 let body = convert_query_body(&query)?;
1115 InsertSource::Select(Box::new(SelectQuery {
1116 ctes,
1117 recursive,
1118 body,
1119 }))
1120 }
1121 };
1122
1123 Ok(Statement::Insert(InsertStmt {
1124 table,
1125 columns,
1126 source,
1127 }))
1128}
1129
1130fn convert_select_body(select: &sp::Select) -> Result<SelectStmt> {
1131 let distinct = match &select.distinct {
1132 Some(sp::Distinct::Distinct) => true,
1133 Some(sp::Distinct::On(_)) => {
1134 return Err(SqlError::Unsupported("DISTINCT ON".into()));
1135 }
1136 _ => false,
1137 };
1138
1139 let (from, from_alias, joins) = if select.from.is_empty() {
1140 (String::new(), None, vec![])
1141 } else if select.from.len() == 1 {
1142 let table_with_joins = &select.from[0];
1143 let (name, alias) = match &table_with_joins.relation {
1144 sp::TableFactor::Table { name, alias, .. } => {
1145 let table_name = object_name_to_string(name);
1146 let alias_str = alias.as_ref().map(|a| a.name.value.clone());
1147 (table_name, alias_str)
1148 }
1149 _ => return Err(SqlError::Unsupported("non-table FROM source".into())),
1150 };
1151 let j = table_with_joins
1152 .joins
1153 .iter()
1154 .map(convert_join)
1155 .collect::<Result<Vec<_>>>()?;
1156 (name, alias, j)
1157 } else {
1158 return Err(SqlError::Unsupported("comma-separated FROM tables".into()));
1159 };
1160
1161 let columns: Vec<SelectColumn> = select
1162 .projection
1163 .iter()
1164 .map(convert_select_item)
1165 .collect::<Result<_>>()?;
1166
1167 let where_clause = select.selection.as_ref().map(convert_expr).transpose()?;
1168
1169 let group_by = match &select.group_by {
1170 sp::GroupByExpr::Expressions(exprs, _) => {
1171 exprs.iter().map(convert_expr).collect::<Result<_>>()?
1172 }
1173 sp::GroupByExpr::All(_) => {
1174 return Err(SqlError::Unsupported("GROUP BY ALL".into()));
1175 }
1176 };
1177
1178 let having = select.having.as_ref().map(convert_expr).transpose()?;
1179
1180 Ok(SelectStmt {
1181 columns,
1182 from,
1183 from_alias,
1184 joins,
1185 distinct,
1186 where_clause,
1187 order_by: vec![],
1188 limit: None,
1189 offset: None,
1190 group_by,
1191 having,
1192 })
1193}
1194
1195fn convert_set_expr(set_expr: &sp::SetExpr) -> Result<QueryBody> {
1196 match set_expr {
1197 sp::SetExpr::Select(sel) => Ok(QueryBody::Select(Box::new(convert_select_body(sel)?))),
1198 sp::SetExpr::SetOperation {
1199 op,
1200 set_quantifier,
1201 left,
1202 right,
1203 } => {
1204 let set_op = match op {
1205 sp::SetOperator::Union => SetOp::Union,
1206 sp::SetOperator::Intersect => SetOp::Intersect,
1207 sp::SetOperator::Except | sp::SetOperator::Minus => SetOp::Except,
1208 };
1209 let all = match set_quantifier {
1210 sp::SetQuantifier::All => true,
1211 sp::SetQuantifier::None | sp::SetQuantifier::Distinct => false,
1212 _ => {
1213 return Err(SqlError::Unsupported("BY NAME set operations".into()));
1214 }
1215 };
1216 Ok(QueryBody::Compound(Box::new(CompoundSelect {
1217 op: set_op,
1218 all,
1219 left: Box::new(convert_set_expr(left)?),
1220 right: Box::new(convert_set_expr(right)?),
1221 order_by: vec![],
1222 limit: None,
1223 offset: None,
1224 })))
1225 }
1226 _ => Err(SqlError::Unsupported("unsupported set expression".into())),
1227 }
1228}
1229
1230fn convert_query_body(query: &sp::Query) -> Result<QueryBody> {
1231 let mut body = convert_set_expr(&query.body)?;
1232
1233 let order_by = if let Some(ref ob) = query.order_by {
1234 match &ob.kind {
1235 sp::OrderByKind::Expressions(exprs) => exprs
1236 .iter()
1237 .map(convert_order_by_expr)
1238 .collect::<Result<_>>()?,
1239 sp::OrderByKind::All { .. } => {
1240 return Err(SqlError::Unsupported("ORDER BY ALL".into()));
1241 }
1242 }
1243 } else {
1244 vec![]
1245 };
1246
1247 let (limit, offset) = match &query.limit_clause {
1248 Some(sp::LimitClause::LimitOffset { limit, offset, .. }) => {
1249 let l = limit.as_ref().map(convert_expr).transpose()?;
1250 let o = offset
1251 .as_ref()
1252 .map(|o| convert_expr(&o.value))
1253 .transpose()?;
1254 (l, o)
1255 }
1256 Some(sp::LimitClause::OffsetCommaLimit { limit, offset }) => {
1257 let l = Some(convert_expr(limit)?);
1258 let o = Some(convert_expr(offset)?);
1259 (l, o)
1260 }
1261 None => (None, None),
1262 };
1263
1264 match &mut body {
1265 QueryBody::Select(sel) => {
1266 sel.order_by = order_by;
1267 sel.limit = limit;
1268 sel.offset = offset;
1269 }
1270 QueryBody::Compound(comp) => {
1271 comp.order_by = order_by;
1272 comp.limit = limit;
1273 comp.offset = offset;
1274 }
1275 }
1276
1277 Ok(body)
1278}
1279
1280fn convert_subquery(query: &sp::Query) -> Result<SelectStmt> {
1281 if query.with.is_some() {
1282 return Err(SqlError::Unsupported("CTEs in subqueries".into()));
1283 }
1284 match convert_query_body(query)? {
1285 QueryBody::Select(s) => Ok(*s),
1286 QueryBody::Compound(_) => Err(SqlError::Unsupported(
1287 "UNION/INTERSECT/EXCEPT in subqueries".into(),
1288 )),
1289 }
1290}
1291
1292fn convert_with(with: &sp::With) -> Result<(Vec<CteDefinition>, bool)> {
1293 let mut names = std::collections::HashSet::new();
1294 let mut ctes = Vec::new();
1295 for cte in &with.cte_tables {
1296 let name = cte.alias.name.value.to_ascii_lowercase();
1297 if !names.insert(name.clone()) {
1298 return Err(SqlError::DuplicateCteName(name));
1299 }
1300 let column_aliases: Vec<String> = cte
1301 .alias
1302 .columns
1303 .iter()
1304 .map(|c| c.name.value.to_ascii_lowercase())
1305 .collect();
1306 let body = convert_query_body(&cte.query)?;
1307 ctes.push(CteDefinition {
1308 name,
1309 column_aliases,
1310 body,
1311 });
1312 }
1313 Ok((ctes, with.recursive))
1314}
1315
1316fn convert_query(query: sp::Query) -> Result<Statement> {
1317 let (ctes, recursive) = if let Some(ref with) = query.with {
1318 convert_with(with)?
1319 } else {
1320 (vec![], false)
1321 };
1322 let body = convert_query_body(&query)?;
1323 Ok(Statement::Select(Box::new(SelectQuery {
1324 ctes,
1325 recursive,
1326 body,
1327 })))
1328}
1329
1330fn convert_join(join: &sp::Join) -> Result<JoinClause> {
1331 let (join_type, constraint) = match &join.join_operator {
1332 sp::JoinOperator::Inner(c) => (JoinType::Inner, Some(c)),
1333 sp::JoinOperator::Join(c) => (JoinType::Inner, Some(c)),
1334 sp::JoinOperator::CrossJoin(c) => (JoinType::Cross, Some(c)),
1335 sp::JoinOperator::Left(c) => (JoinType::Left, Some(c)),
1336 sp::JoinOperator::LeftSemi(c) => (JoinType::Left, Some(c)),
1337 sp::JoinOperator::LeftAnti(c) => (JoinType::Left, Some(c)),
1338 sp::JoinOperator::Right(c) => (JoinType::Right, Some(c)),
1339 sp::JoinOperator::RightSemi(c) => (JoinType::Right, Some(c)),
1340 sp::JoinOperator::RightAnti(c) => (JoinType::Right, Some(c)),
1341 other => return Err(SqlError::Unsupported(format!("join type: {other:?}"))),
1342 };
1343
1344 let (name, alias) = match &join.relation {
1345 sp::TableFactor::Table { name, alias, .. } => {
1346 let table_name = object_name_to_string(name);
1347 let alias_str = alias.as_ref().map(|a| a.name.value.clone());
1348 (table_name, alias_str)
1349 }
1350 _ => return Err(SqlError::Unsupported("non-table JOIN source".into())),
1351 };
1352
1353 let on_clause = match constraint {
1354 Some(sp::JoinConstraint::On(expr)) => Some(convert_expr(expr)?),
1355 Some(sp::JoinConstraint::None) | None => None,
1356 Some(other) => return Err(SqlError::Unsupported(format!("join constraint: {other:?}"))),
1357 };
1358
1359 Ok(JoinClause {
1360 join_type,
1361 table: TableRef { name, alias },
1362 on_clause,
1363 })
1364}
1365
1366fn convert_update(update: sp::Update) -> Result<Statement> {
1367 let table = match &update.table.relation {
1368 sp::TableFactor::Table { name, .. } => object_name_to_string(name),
1369 _ => return Err(SqlError::Unsupported("non-table UPDATE target".into())),
1370 };
1371
1372 let assignments = update
1373 .assignments
1374 .iter()
1375 .map(|a| {
1376 let col = match &a.target {
1377 sp::AssignmentTarget::ColumnName(name) => object_name_to_string(name),
1378 _ => return Err(SqlError::Unsupported("tuple assignment".into())),
1379 };
1380 let expr = convert_expr(&a.value)?;
1381 Ok((col, expr))
1382 })
1383 .collect::<Result<_>>()?;
1384
1385 let where_clause = update.selection.as_ref().map(convert_expr).transpose()?;
1386
1387 Ok(Statement::Update(UpdateStmt {
1388 table,
1389 assignments,
1390 where_clause,
1391 }))
1392}
1393
1394fn convert_delete(delete: sp::Delete) -> Result<Statement> {
1395 let table_name = match &delete.from {
1396 sp::FromTable::WithFromKeyword(tables) => {
1397 if tables.len() != 1 {
1398 return Err(SqlError::Unsupported("multi-table DELETE".into()));
1399 }
1400 match &tables[0].relation {
1401 sp::TableFactor::Table { name, .. } => object_name_to_string(name),
1402 _ => return Err(SqlError::Unsupported("non-table DELETE target".into())),
1403 }
1404 }
1405 sp::FromTable::WithoutKeyword(tables) => {
1406 if tables.len() != 1 {
1407 return Err(SqlError::Unsupported("multi-table DELETE".into()));
1408 }
1409 match &tables[0].relation {
1410 sp::TableFactor::Table { name, .. } => object_name_to_string(name),
1411 _ => return Err(SqlError::Unsupported("non-table DELETE target".into())),
1412 }
1413 }
1414 };
1415
1416 let where_clause = delete.selection.as_ref().map(convert_expr).transpose()?;
1417
1418 Ok(Statement::Delete(DeleteStmt {
1419 table: table_name,
1420 where_clause,
1421 }))
1422}
1423
/// Lowers a sqlparser expression into the engine's [`Expr`].
///
/// Column identifiers are ASCII-lowercased. Several SQL built-in syntaxes are
/// rewritten as generic [`Expr::Function`] calls:
/// - `SUBSTRING` becomes `SUBSTR`;
/// - `TRIM` becomes `TRIM` / `LTRIM` / `RTRIM`;
/// - `CEIL` and `FLOOR` keep their names;
/// - `POSITION` becomes `INSTR`;
/// - `EXTRACT` becomes `EXTRACT` with the field name as its first argument;
/// - `AT TIME ZONE` becomes `AT_TIMEZONE`.
///
/// Typed strings (`DATE '...'` etc.) and `INTERVAL` literals become literal
/// values.
///
/// # Errors
/// - `SqlError::Unsupported` for expression kinds, unary operators or binary
///   operators that are not handled.
/// - `SqlError::InvalidExtractField` for an EXTRACT field not listed below.
/// - Errors from the value, function, subquery, data-type and interval
///   helpers are propagated.
fn convert_expr(expr: &sp::Expr) -> Result<Expr> {
    match expr {
        sp::Expr::Value(v) => convert_value(&v.value),
        sp::Expr::Identifier(ident) => Ok(Expr::Column(ident.value.to_ascii_lowercase())),
        sp::Expr::CompoundIdentifier(parts) => {
            // Only `table.column` keeps its qualifier. Longer paths such as
            // `schema.table.column` keep just the final column name.
            // NOTE(review): assumes sqlparser never yields an empty compound
            // identifier, which the `unwrap` relies on.
            if parts.len() == 2 {
                Ok(Expr::QualifiedColumn {
                    table: parts[0].value.to_ascii_lowercase(),
                    column: parts[1].value.to_ascii_lowercase(),
                })
            } else {
                Ok(Expr::Column(
                    parts.last().unwrap().value.to_ascii_lowercase(),
                ))
            }
        }
        sp::Expr::BinaryOp { left, op, right } => {
            let bin_op = convert_bin_op(op)?;
            Ok(Expr::BinaryOp {
                left: Box::new(convert_expr(left)?),
                op: bin_op,
                right: Box::new(convert_expr(right)?),
            })
        }
        sp::Expr::UnaryOp { op, expr } => {
            // Only unary minus and NOT are supported; unary `+` and the
            // bitwise/PG-specific operators are rejected.
            let unary_op = match op {
                sp::UnaryOperator::Minus => UnaryOp::Neg,
                sp::UnaryOperator::Not => UnaryOp::Not,
                _ => return Err(SqlError::Unsupported(format!("unary op: {op}"))),
            };
            Ok(Expr::UnaryOp {
                op: unary_op,
                expr: Box::new(convert_expr(expr)?),
            })
        }
        sp::Expr::IsNull(e) => Ok(Expr::IsNull(Box::new(convert_expr(e)?))),
        sp::Expr::IsNotNull(e) => Ok(Expr::IsNotNull(Box::new(convert_expr(e)?))),
        // Parentheses carry no meaning once the tree is built.
        sp::Expr::Nested(e) => convert_expr(e),
        sp::Expr::Function(func) => convert_function(func),
        sp::Expr::InSubquery {
            expr: e,
            subquery,
            negated,
        } => {
            let inner_expr = convert_expr(e)?;
            let stmt = convert_subquery(subquery)?;
            Ok(Expr::InSubquery {
                expr: Box::new(inner_expr),
                subquery: Box::new(stmt),
                negated: *negated,
            })
        }
        sp::Expr::InList {
            expr: e,
            list,
            negated,
        } => {
            let inner_expr = convert_expr(e)?;
            let items = list.iter().map(convert_expr).collect::<Result<Vec<_>>>()?;
            Ok(Expr::InList {
                expr: Box::new(inner_expr),
                list: items,
                negated: *negated,
            })
        }
        sp::Expr::Exists { subquery, negated } => {
            let stmt = convert_subquery(subquery)?;
            Ok(Expr::Exists {
                subquery: Box::new(stmt),
                negated: *negated,
            })
        }
        sp::Expr::Subquery(query) => {
            let stmt = convert_subquery(query)?;
            Ok(Expr::ScalarSubquery(Box::new(stmt)))
        }
        sp::Expr::Between {
            expr: e,
            negated,
            low,
            high,
        } => Ok(Expr::Between {
            expr: Box::new(convert_expr(e)?),
            low: Box::new(convert_expr(low)?),
            high: Box::new(convert_expr(high)?),
            negated: *negated,
        }),
        sp::Expr::Like {
            expr: e,
            negated,
            pattern,
            escape_char,
            ..
        } => {
            // The ESCAPE value is converted before the operand and pattern, so
            // its error (if any) is reported first.
            let esc = escape_char
                .as_ref()
                .map(convert_escape_value)
                .transpose()?
                .map(Box::new);
            Ok(Expr::Like {
                expr: Box::new(convert_expr(e)?),
                pattern: Box::new(convert_expr(pattern)?),
                escape: esc,
                negated: *negated,
            })
        }
        sp::Expr::ILike {
            expr: e,
            negated,
            pattern,
            escape_char,
            ..
        } => {
            // ILIKE produces the same `Expr::Like` node as LIKE. No
            // case-insensitivity flag is carried over.
            // NOTE(review): correct only if the engine's LIKE is already
            // case-insensitive — confirm against the evaluator.
            let esc = escape_char
                .as_ref()
                .map(convert_escape_value)
                .transpose()?
                .map(Box::new);
            Ok(Expr::Like {
                expr: Box::new(convert_expr(e)?),
                pattern: Box::new(convert_expr(pattern)?),
                escape: esc,
                negated: *negated,
            })
        }
        sp::Expr::Case {
            operand,
            conditions,
            else_result,
            ..
        } => {
            // Handles both the simple form `CASE x WHEN ...` (operand present)
            // and the searched form `CASE WHEN cond ...`.
            let op = operand
                .as_ref()
                .map(|e| convert_expr(e))
                .transpose()?
                .map(Box::new);
            let conds: Vec<(Expr, Expr)> = conditions
                .iter()
                .map(|cw| Ok((convert_expr(&cw.condition)?, convert_expr(&cw.result)?)))
                .collect::<Result<_>>()?;
            let else_r = else_result
                .as_ref()
                .map(|e| convert_expr(e))
                .transpose()?
                .map(Box::new);
            Ok(Expr::Case {
                operand: op,
                conditions: conds,
                else_result: else_r,
            })
        }
        sp::Expr::Cast {
            expr: e,
            data_type: dt,
            ..
        } => {
            // Only the operand and target type are used. The other fields of
            // the sqlparser node are ignored via `..`.
            let target = convert_data_type(dt)?;
            Ok(Expr::Cast {
                expr: Box::new(convert_expr(e)?),
                data_type: target,
            })
        }
        sp::Expr::Substring {
            expr: e,
            substring_from,
            substring_for,
            ..
        } => {
            // SUBSTRING(x FROM a FOR b) becomes SUBSTR(x[, a[, b]]).
            let mut args = vec![convert_expr(e)?];
            if let Some(from) = substring_from {
                args.push(convert_expr(from)?);
            }
            if let Some(f) = substring_for {
                args.push(convert_expr(f)?);
            }
            Ok(Expr::Function {
                name: "SUBSTR".into(),
                args,
                distinct: false,
            })
        }
        sp::Expr::Trim {
            expr: e,
            trim_where,
            trim_what,
            trim_characters,
        } => {
            // LEADING maps to LTRIM, TRAILING to RTRIM, and BOTH or no
            // direction to TRIM. Only the first entry of `trim_characters`
            // is used.
            let fn_name = match trim_where {
                Some(sp::TrimWhereField::Leading) => "LTRIM",
                Some(sp::TrimWhereField::Trailing) => "RTRIM",
                _ => "TRIM",
            };
            let mut args = vec![convert_expr(e)?];
            if let Some(what) = trim_what {
                args.push(convert_expr(what)?);
            } else if let Some(chars) = trim_characters {
                if let Some(first) = chars.first() {
                    args.push(convert_expr(first)?);
                }
            }
            Ok(Expr::Function {
                name: fn_name.into(),
                args,
                distinct: false,
            })
        }
        sp::Expr::Ceil { expr: e, .. } => Ok(Expr::Function {
            name: "CEIL".into(),
            args: vec![convert_expr(e)?],
            distinct: false,
        }),
        sp::Expr::Floor { expr: e, .. } => Ok(Expr::Function {
            name: "FLOOR".into(),
            args: vec![convert_expr(e)?],
            distinct: false,
        }),
        // POSITION(needle IN haystack) becomes INSTR(haystack, needle).
        // The argument order is swapped on purpose.
        sp::Expr::Position { expr: e, r#in } => Ok(Expr::Function {
            name: "INSTR".into(),
            args: vec![convert_expr(r#in)?, convert_expr(e)?],
            distinct: false,
        }),
        sp::Expr::TypedString(ts) => {
            // Use the unquoted string content. Any other value kind falls back
            // to its Display form, which `convert_typed_string` trims of
            // surrounding single quotes.
            let raw = match &ts.value.value {
                sp::Value::SingleQuotedString(s) => s.clone(),
                sp::Value::DoubleQuotedString(s) => s.clone(),
                other => other.to_string(),
            };
            convert_typed_string(&ts.data_type, &raw)
        }
        sp::Expr::Interval(iv) => convert_interval_expr(iv),
        sp::Expr::Extract { field, expr: e, .. } => {
            // The field is passed as a lowercase text literal in the first
            // argument of EXTRACT. Singular and plural milli/microsecond fields
            // share one name.
            // NOTE(review): `DATE` maps to "day" here — confirm that is
            // intended rather than extracting the date part.
            let field_name = match field {
                sp::DateTimeField::Year => "year",
                sp::DateTimeField::Month => "month",
                sp::DateTimeField::Week(_) => "week",
                sp::DateTimeField::Day => "day",
                sp::DateTimeField::Date => "day",
                sp::DateTimeField::Hour => "hour",
                sp::DateTimeField::Minute => "minute",
                sp::DateTimeField::Second => "second",
                sp::DateTimeField::Millisecond => "milliseconds",
                sp::DateTimeField::Microsecond => "microseconds",
                sp::DateTimeField::Microseconds => "microseconds",
                sp::DateTimeField::Milliseconds => "milliseconds",
                sp::DateTimeField::Dow => "dow",
                sp::DateTimeField::Isodow => "isodow",
                sp::DateTimeField::Doy => "doy",
                sp::DateTimeField::Epoch => "epoch",
                sp::DateTimeField::Quarter => "quarter",
                sp::DateTimeField::Decade => "decade",
                sp::DateTimeField::Century => "century",
                sp::DateTimeField::Millennium => "millennium",
                sp::DateTimeField::Isoyear => "isoyear",
                sp::DateTimeField::Julian => "julian",
                other => {
                    return Err(SqlError::InvalidExtractField(format!("{other:?}")));
                }
            };
            Ok(Expr::Function {
                name: "EXTRACT".into(),
                args: vec![
                    Expr::Literal(Value::Text(field_name.into())),
                    convert_expr(e)?,
                ],
                distinct: false,
            })
        }
        // `ts AT TIME ZONE tz` becomes AT_TIMEZONE(ts, tz).
        sp::Expr::AtTimeZone {
            timestamp,
            time_zone,
        } => Ok(Expr::Function {
            name: "AT_TIMEZONE".into(),
            args: vec![convert_expr(timestamp)?, convert_expr(time_zone)?],
            distinct: false,
        }),
        _ => Err(SqlError::Unsupported(format!("expression: {expr}"))),
    }
}
1706
1707fn convert_value(val: &sp::Value) -> Result<Expr> {
1708 match val {
1709 sp::Value::Number(n, _) => {
1710 if let Ok(i) = n.parse::<i64>() {
1711 Ok(Expr::Literal(Value::Integer(i)))
1712 } else if let Ok(f) = n.parse::<f64>() {
1713 Ok(Expr::Literal(Value::Real(f)))
1714 } else {
1715 Err(SqlError::InvalidValue(format!("cannot parse number: {n}")))
1716 }
1717 }
1718 sp::Value::SingleQuotedString(s) => Ok(Expr::Literal(Value::Text(s.as_str().into()))),
1719 sp::Value::Boolean(b) => Ok(Expr::Literal(Value::Boolean(*b))),
1720 sp::Value::Null => Ok(Expr::Literal(Value::Null)),
1721 sp::Value::Placeholder(s) => {
1722 let idx_str = s
1723 .strip_prefix('$')
1724 .ok_or_else(|| SqlError::Parse(format!("invalid placeholder: {s}")))?;
1725 let idx: usize = idx_str
1726 .parse()
1727 .map_err(|_| SqlError::Parse(format!("invalid placeholder index: {s}")))?;
1728 if idx == 0 {
1729 return Err(SqlError::Parse("placeholder index must be >= 1".into()));
1730 }
1731 Ok(Expr::Parameter(idx))
1732 }
1733 _ => Err(SqlError::Unsupported(format!("value type: {val}"))),
1734 }
1735}
1736
1737fn convert_typed_string(dt: &sp::DataType, value: &str) -> Result<Expr> {
1738 let s = value.trim_matches('\'');
1739 match dt {
1740 sp::DataType::Date => {
1741 let d = crate::datetime::parse_date(s)?;
1742 Ok(Expr::Literal(Value::Date(d)))
1743 }
1744 sp::DataType::Time(_, _) => {
1745 let t = crate::datetime::parse_time(s)?;
1746 Ok(Expr::Literal(Value::Time(t)))
1747 }
1748 sp::DataType::Timestamp(_, _) => {
1749 let t = crate::datetime::parse_timestamp(s)?;
1750 Ok(Expr::Literal(Value::Timestamp(t)))
1751 }
1752 sp::DataType::Interval { .. } => {
1753 let (months, days, micros) = crate::datetime::parse_interval(s)?;
1754 Ok(Expr::Literal(Value::Interval {
1755 months,
1756 days,
1757 micros,
1758 }))
1759 }
1760 _ => {
1761 let target = convert_data_type(dt)?;
1762 Ok(Expr::Cast {
1763 expr: Box::new(Expr::Literal(Value::Text(s.into()))),
1764 data_type: target,
1765 })
1766 }
1767 }
1768}
1769
1770fn convert_interval_expr(iv: &sp::Interval) -> Result<Expr> {
1771 let raw = match iv.value.as_ref() {
1772 sp::Expr::Value(v) => match &v.value {
1773 sp::Value::SingleQuotedString(s) => s.clone(),
1774 sp::Value::Number(n, _) => n.clone(),
1775 other => {
1776 return Err(SqlError::InvalidIntervalLiteral(format!(
1777 "unsupported inner value: {other}"
1778 )))
1779 }
1780 },
1781 other => {
1782 return Err(SqlError::InvalidIntervalLiteral(format!(
1783 "unsupported inner expr: {other}"
1784 )))
1785 }
1786 };
1787
1788 let with_unit = if let Some(field) = &iv.leading_field {
1790 let unit_name = match field {
1791 sp::DateTimeField::Year => "years",
1792 sp::DateTimeField::Month => "months",
1793 sp::DateTimeField::Week(_) => "weeks",
1794 sp::DateTimeField::Day => "days",
1795 sp::DateTimeField::Hour => "hours",
1796 sp::DateTimeField::Minute => "minutes",
1797 sp::DateTimeField::Second => "seconds",
1798 _ => {
1799 return Err(SqlError::InvalidIntervalLiteral(format!(
1800 "unsupported leading field: {field:?}"
1801 )))
1802 }
1803 };
1804 format!("{raw} {unit_name}")
1805 } else {
1806 raw
1807 };
1808
1809 let (months, days, micros) = crate::datetime::parse_interval(&with_unit)?;
1810 Ok(Expr::Literal(Value::Interval {
1811 months,
1812 days,
1813 micros,
1814 }))
1815}
1816
1817fn convert_escape_value(val: &sp::Value) -> Result<Expr> {
1818 match val {
1819 sp::Value::SingleQuotedString(s) => Ok(Expr::Literal(Value::Text(s.as_str().into()))),
1820 _ => Err(SqlError::Unsupported(format!("ESCAPE value: {val}"))),
1821 }
1822}
1823
1824fn convert_bin_op(op: &sp::BinaryOperator) -> Result<BinOp> {
1825 match op {
1826 sp::BinaryOperator::Plus => Ok(BinOp::Add),
1827 sp::BinaryOperator::Minus => Ok(BinOp::Sub),
1828 sp::BinaryOperator::Multiply => Ok(BinOp::Mul),
1829 sp::BinaryOperator::Divide => Ok(BinOp::Div),
1830 sp::BinaryOperator::Modulo => Ok(BinOp::Mod),
1831 sp::BinaryOperator::Eq => Ok(BinOp::Eq),
1832 sp::BinaryOperator::NotEq => Ok(BinOp::NotEq),
1833 sp::BinaryOperator::Lt => Ok(BinOp::Lt),
1834 sp::BinaryOperator::Gt => Ok(BinOp::Gt),
1835 sp::BinaryOperator::LtEq => Ok(BinOp::LtEq),
1836 sp::BinaryOperator::GtEq => Ok(BinOp::GtEq),
1837 sp::BinaryOperator::And => Ok(BinOp::And),
1838 sp::BinaryOperator::Or => Ok(BinOp::Or),
1839 sp::BinaryOperator::StringConcat => Ok(BinOp::Concat),
1840 _ => Err(SqlError::Unsupported(format!("binary op: {op}"))),
1841 }
1842}
1843
1844fn convert_function(func: &sp::Function) -> Result<Expr> {
1845 let name = object_name_to_string(&func.name).to_ascii_uppercase();
1846
1847 let (args, is_count_star, distinct) = match &func.args {
1848 sp::FunctionArguments::List(list) => {
1849 let distinct = matches!(
1850 list.duplicate_treatment,
1851 Some(sp::DuplicateTreatment::Distinct)
1852 );
1853 if list.args.is_empty() && name == "COUNT" {
1854 (vec![], true, distinct)
1855 } else {
1856 let mut count_star = false;
1857 let args = list
1858 .args
1859 .iter()
1860 .map(|arg| match arg {
1861 sp::FunctionArg::Unnamed(sp::FunctionArgExpr::Expr(e)) => convert_expr(e),
1862 sp::FunctionArg::Unnamed(sp::FunctionArgExpr::Wildcard) => {
1863 if name == "COUNT" {
1864 count_star = true;
1865 Ok(Expr::CountStar)
1866 } else {
1867 Err(SqlError::Unsupported(format!("{name}(*)")))
1868 }
1869 }
1870 _ => Err(SqlError::Unsupported(format!(
1871 "function arg type in {name}"
1872 ))),
1873 })
1874 .collect::<Result<Vec<_>>>()?;
1875 if name == "COUNT" && args.len() == 1 && count_star {
1876 (vec![], true, distinct)
1877 } else {
1878 (args, false, distinct)
1879 }
1880 }
1881 }
1882 sp::FunctionArguments::None => {
1883 if name == "COUNT" {
1884 (vec![], true, false)
1885 } else {
1886 (vec![], false, false)
1887 }
1888 }
1889 sp::FunctionArguments::Subquery(_) => {
1890 return Err(SqlError::Unsupported("subquery in function".into()));
1891 }
1892 };
1893
1894 if let Some(over) = &func.over {
1896 let spec = match over {
1897 sp::WindowType::WindowSpec(ws) => convert_window_spec(ws)?,
1898 sp::WindowType::NamedWindow(_) => {
1899 return Err(SqlError::Unsupported("named windows".into()));
1900 }
1901 };
1902 return Ok(Expr::WindowFunction { name, args, spec });
1903 }
1904
1905 if is_count_star {
1907 return Ok(Expr::CountStar);
1908 }
1909
1910 if name == "COALESCE" {
1911 if args.is_empty() {
1912 return Err(SqlError::Parse(
1913 "COALESCE requires at least one argument".into(),
1914 ));
1915 }
1916 return Ok(Expr::Coalesce(args));
1917 }
1918
1919 if name == "NULLIF" {
1920 if args.len() != 2 {
1921 return Err(SqlError::Parse(
1922 "NULLIF requires exactly two arguments".into(),
1923 ));
1924 }
1925 return Ok(Expr::Case {
1926 operand: None,
1927 conditions: vec![(
1928 Expr::BinaryOp {
1929 left: Box::new(args[0].clone()),
1930 op: BinOp::Eq,
1931 right: Box::new(args[1].clone()),
1932 },
1933 Expr::Literal(Value::Null),
1934 )],
1935 else_result: Some(Box::new(args[0].clone())),
1936 });
1937 }
1938
1939 if name == "IIF" {
1940 if args.len() != 3 {
1941 return Err(SqlError::Parse(
1942 "IIF requires exactly three arguments".into(),
1943 ));
1944 }
1945 return Ok(Expr::Case {
1946 operand: None,
1947 conditions: vec![(args[0].clone(), args[1].clone())],
1948 else_result: Some(Box::new(args[2].clone())),
1949 });
1950 }
1951
1952 Ok(Expr::Function {
1953 name,
1954 args,
1955 distinct,
1956 })
1957}
1958
1959fn convert_window_spec(ws: &sp::WindowSpec) -> Result<WindowSpec> {
1960 let partition_by = ws
1961 .partition_by
1962 .iter()
1963 .map(convert_expr)
1964 .collect::<Result<Vec<_>>>()?;
1965 let order_by = ws
1966 .order_by
1967 .iter()
1968 .map(convert_order_by_expr)
1969 .collect::<Result<Vec<_>>>()?;
1970 let frame = ws
1971 .window_frame
1972 .as_ref()
1973 .map(convert_window_frame)
1974 .transpose()?;
1975 Ok(WindowSpec {
1976 partition_by,
1977 order_by,
1978 frame,
1979 })
1980}
1981
1982fn convert_window_frame(wf: &sp::WindowFrame) -> Result<WindowFrame> {
1983 let units = match wf.units {
1984 sp::WindowFrameUnits::Rows => WindowFrameUnits::Rows,
1985 sp::WindowFrameUnits::Range => WindowFrameUnits::Range,
1986 sp::WindowFrameUnits::Groups => {
1987 return Err(SqlError::Unsupported("GROUPS window frame".into()));
1988 }
1989 };
1990 let start = convert_window_frame_bound(&wf.start_bound)?;
1991 let end = match &wf.end_bound {
1992 Some(b) => convert_window_frame_bound(b)?,
1993 None => WindowFrameBound::CurrentRow,
1994 };
1995 Ok(WindowFrame { units, start, end })
1996}
1997
1998fn convert_window_frame_bound(b: &sp::WindowFrameBound) -> Result<WindowFrameBound> {
1999 match b {
2000 sp::WindowFrameBound::CurrentRow => Ok(WindowFrameBound::CurrentRow),
2001 sp::WindowFrameBound::Preceding(None) => Ok(WindowFrameBound::UnboundedPreceding),
2002 sp::WindowFrameBound::Preceding(Some(e)) => {
2003 Ok(WindowFrameBound::Preceding(Box::new(convert_expr(e)?)))
2004 }
2005 sp::WindowFrameBound::Following(None) => Ok(WindowFrameBound::UnboundedFollowing),
2006 sp::WindowFrameBound::Following(Some(e)) => {
2007 Ok(WindowFrameBound::Following(Box::new(convert_expr(e)?)))
2008 }
2009 }
2010}
2011
2012fn convert_select_item(item: &sp::SelectItem) -> Result<SelectColumn> {
2013 match item {
2014 sp::SelectItem::Wildcard(_) => Ok(SelectColumn::AllColumns),
2015 sp::SelectItem::UnnamedExpr(e) => {
2016 let expr = convert_expr(e)?;
2017 Ok(SelectColumn::Expr { expr, alias: None })
2018 }
2019 sp::SelectItem::ExprWithAlias { expr, alias } => {
2020 let expr = convert_expr(expr)?;
2021 Ok(SelectColumn::Expr {
2022 expr,
2023 alias: Some(alias.value.clone()),
2024 })
2025 }
2026 sp::SelectItem::QualifiedWildcard(_, _) => {
2027 Err(SqlError::Unsupported("qualified wildcard (table.*)".into()))
2028 }
2029 }
2030}
2031
2032fn convert_order_by_expr(expr: &sp::OrderByExpr) -> Result<OrderByItem> {
2033 let e = convert_expr(&expr.expr)?;
2034 let descending = expr.options.asc.map(|asc| !asc).unwrap_or(false);
2035 let nulls_first = expr.options.nulls_first;
2036
2037 Ok(OrderByItem {
2038 expr: e,
2039 descending,
2040 nulls_first,
2041 })
2042}
2043
2044fn convert_data_type(dt: &sp::DataType) -> Result<DataType> {
2045 match dt {
2046 sp::DataType::Int(_)
2047 | sp::DataType::Integer(_)
2048 | sp::DataType::BigInt(_)
2049 | sp::DataType::SmallInt(_)
2050 | sp::DataType::TinyInt(_)
2051 | sp::DataType::Int2(_)
2052 | sp::DataType::Int4(_)
2053 | sp::DataType::Int8(_) => Ok(DataType::Integer),
2054
2055 sp::DataType::Real
2056 | sp::DataType::Double(..)
2057 | sp::DataType::DoublePrecision
2058 | sp::DataType::Float(_)
2059 | sp::DataType::Float4
2060 | sp::DataType::Float64 => Ok(DataType::Real),
2061
2062 sp::DataType::Varchar(_)
2063 | sp::DataType::Text
2064 | sp::DataType::Char(_)
2065 | sp::DataType::Character(_)
2066 | sp::DataType::String(_) => Ok(DataType::Text),
2067
2068 sp::DataType::Blob(_) | sp::DataType::Bytea => Ok(DataType::Blob),
2069
2070 sp::DataType::Boolean | sp::DataType::Bool => Ok(DataType::Boolean),
2071
2072 sp::DataType::Date => Ok(DataType::Date),
2073 sp::DataType::Time(_, _) => Ok(DataType::Time),
2074 sp::DataType::Timestamp(_, _) => Ok(DataType::Timestamp),
2075 sp::DataType::Interval { .. } => Ok(DataType::Interval),
2076
2077 _ => Err(SqlError::Unsupported(format!("data type: {dt}"))),
2078 }
2079}
2080
2081fn object_name_to_string(name: &sp::ObjectName) -> String {
2082 name.0
2083 .iter()
2084 .filter_map(|p| match p {
2085 sp::ObjectNamePart::Identifier(ident) => Some(ident.value.clone()),
2086 _ => None,
2087 })
2088 .collect::<Vec<_>>()
2089 .join(".")
2090}
2091
2092#[cfg(test)]
2093mod tests {
2094 use super::*;
2095
2096 #[test]
2097 fn parse_create_table() {
2098 let stmt = parse_sql(
2099 "CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT NOT NULL, age INTEGER)",
2100 )
2101 .unwrap();
2102
2103 match stmt {
2104 Statement::CreateTable(ct) => {
2105 assert_eq!(ct.name, "users");
2106 assert_eq!(ct.columns.len(), 3);
2107 assert_eq!(ct.columns[0].name, "id");
2108 assert_eq!(ct.columns[0].data_type, DataType::Integer);
2109 assert!(ct.columns[0].is_primary_key);
2110 assert!(!ct.columns[0].nullable);
2111 assert_eq!(ct.columns[1].name, "name");
2112 assert_eq!(ct.columns[1].data_type, DataType::Text);
2113 assert!(!ct.columns[1].nullable);
2114 assert_eq!(ct.columns[2].name, "age");
2115 assert!(ct.columns[2].nullable);
2116 assert_eq!(ct.primary_key, vec!["id"]);
2117 }
2118 _ => panic!("expected CreateTable"),
2119 }
2120 }
2121
2122 #[test]
2123 fn parse_create_table_if_not_exists() {
2124 let stmt = parse_sql("CREATE TABLE IF NOT EXISTS t (id INT PRIMARY KEY)").unwrap();
2125 match stmt {
2126 Statement::CreateTable(ct) => assert!(ct.if_not_exists),
2127 _ => panic!("expected CreateTable"),
2128 }
2129 }
2130
2131 #[test]
2132 fn parse_drop_table() {
2133 let stmt = parse_sql("DROP TABLE users").unwrap();
2134 match stmt {
2135 Statement::DropTable(dt) => {
2136 assert_eq!(dt.name, "users");
2137 assert!(!dt.if_exists);
2138 }
2139 _ => panic!("expected DropTable"),
2140 }
2141 }
2142
2143 #[test]
2144 fn parse_drop_table_if_exists() {
2145 let stmt = parse_sql("DROP TABLE IF EXISTS users").unwrap();
2146 match stmt {
2147 Statement::DropTable(dt) => assert!(dt.if_exists),
2148 _ => panic!("expected DropTable"),
2149 }
2150 }
2151
2152 #[test]
2153 fn parse_insert() {
2154 let stmt =
2155 parse_sql("INSERT INTO users (id, name) VALUES (1, 'Alice'), (2, 'Bob')").unwrap();
2156
2157 match stmt {
2158 Statement::Insert(ins) => {
2159 assert_eq!(ins.table, "users");
2160 assert_eq!(ins.columns, vec!["id", "name"]);
2161 let values = match &ins.source {
2162 InsertSource::Values(v) => v,
2163 _ => panic!("expected Values"),
2164 };
2165 assert_eq!(values.len(), 2);
2166 assert!(matches!(values[0][0], Expr::Literal(Value::Integer(1))));
2167 assert!(matches!(&values[0][1], Expr::Literal(Value::Text(s)) if s == "Alice"));
2168 }
2169 _ => panic!("expected Insert"),
2170 }
2171 }
2172
2173 #[test]
2174 fn parse_select_all() {
2175 let stmt = parse_sql("SELECT * FROM users").unwrap();
2176 match stmt {
2177 Statement::Select(sq) => match sq.body {
2178 QueryBody::Select(sel) => {
2179 assert_eq!(sel.from, "users");
2180 assert!(matches!(sel.columns[0], SelectColumn::AllColumns));
2181 assert!(sel.where_clause.is_none());
2182 }
2183 _ => panic!("expected QueryBody::Select"),
2184 },
2185 _ => panic!("expected Select"),
2186 }
2187 }
2188
2189 #[test]
2190 fn parse_select_where() {
2191 let stmt = parse_sql("SELECT id, name FROM users WHERE age > 18").unwrap();
2192 match stmt {
2193 Statement::Select(sq) => match sq.body {
2194 QueryBody::Select(sel) => {
2195 assert_eq!(sel.columns.len(), 2);
2196 assert!(sel.where_clause.is_some());
2197 }
2198 _ => panic!("expected QueryBody::Select"),
2199 },
2200 _ => panic!("expected Select"),
2201 }
2202 }
2203
2204 #[test]
2205 fn parse_select_order_limit() {
2206 let stmt = parse_sql("SELECT * FROM users ORDER BY name ASC LIMIT 10 OFFSET 5").unwrap();
2207 match stmt {
2208 Statement::Select(sq) => match sq.body {
2209 QueryBody::Select(sel) => {
2210 assert_eq!(sel.order_by.len(), 1);
2211 assert!(!sel.order_by[0].descending);
2212 assert!(sel.limit.is_some());
2213 assert!(sel.offset.is_some());
2214 }
2215 _ => panic!("expected QueryBody::Select"),
2216 },
2217 _ => panic!("expected Select"),
2218 }
2219 }
2220
2221 #[test]
2222 fn parse_update() {
2223 let stmt = parse_sql("UPDATE users SET name = 'Bob' WHERE id = 1").unwrap();
2224 match stmt {
2225 Statement::Update(upd) => {
2226 assert_eq!(upd.table, "users");
2227 assert_eq!(upd.assignments.len(), 1);
2228 assert_eq!(upd.assignments[0].0, "name");
2229 assert!(upd.where_clause.is_some());
2230 }
2231 _ => panic!("expected Update"),
2232 }
2233 }
2234
2235 #[test]
2236 fn parse_delete() {
2237 let stmt = parse_sql("DELETE FROM users WHERE id = 1").unwrap();
2238 match stmt {
2239 Statement::Delete(del) => {
2240 assert_eq!(del.table, "users");
2241 assert!(del.where_clause.is_some());
2242 }
2243 _ => panic!("expected Delete"),
2244 }
2245 }
2246
2247 #[test]
2248 fn parse_aggregate() {
2249 let stmt = parse_sql("SELECT COUNT(*), SUM(age) FROM users").unwrap();
2250 match stmt {
2251 Statement::Select(sq) => match sq.body {
2252 QueryBody::Select(sel) => {
2253 assert_eq!(sel.columns.len(), 2);
2254 match &sel.columns[0] {
2255 SelectColumn::Expr {
2256 expr: Expr::CountStar,
2257 ..
2258 } => {}
2259 other => panic!("expected CountStar, got {other:?}"),
2260 }
2261 }
2262 _ => panic!("expected QueryBody::Select"),
2263 },
2264 _ => panic!("expected Select"),
2265 }
2266 }
2267
2268 #[test]
2269 fn parse_group_by_having() {
2270 let stmt = parse_sql(
2271 "SELECT department, COUNT(*) FROM employees GROUP BY department HAVING COUNT(*) > 5",
2272 )
2273 .unwrap();
2274 match stmt {
2275 Statement::Select(sq) => match sq.body {
2276 QueryBody::Select(sel) => {
2277 assert_eq!(sel.group_by.len(), 1);
2278 assert!(sel.having.is_some());
2279 }
2280 _ => panic!("expected QueryBody::Select"),
2281 },
2282 _ => panic!("expected Select"),
2283 }
2284 }
2285
2286 #[test]
2287 fn parse_expressions() {
2288 let stmt = parse_sql("SELECT id + 1, -price, NOT active FROM items").unwrap();
2289 match stmt {
2290 Statement::Select(sq) => match sq.body {
2291 QueryBody::Select(sel) => {
2292 assert_eq!(sel.columns.len(), 3);
2293 match &sel.columns[0] {
2295 SelectColumn::Expr {
2296 expr: Expr::BinaryOp { op: BinOp::Add, .. },
2297 ..
2298 } => {}
2299 other => panic!("expected BinaryOp Add, got {other:?}"),
2300 }
2301 match &sel.columns[1] {
2303 SelectColumn::Expr {
2304 expr:
2305 Expr::UnaryOp {
2306 op: UnaryOp::Neg, ..
2307 },
2308 ..
2309 } => {}
2310 other => panic!("expected UnaryOp Neg, got {other:?}"),
2311 }
2312 match &sel.columns[2] {
2314 SelectColumn::Expr {
2315 expr:
2316 Expr::UnaryOp {
2317 op: UnaryOp::Not, ..
2318 },
2319 ..
2320 } => {}
2321 other => panic!("expected UnaryOp Not, got {other:?}"),
2322 }
2323 }
2324 _ => panic!("expected QueryBody::Select"),
2325 },
2326 _ => panic!("expected Select"),
2327 }
2328 }
2329
2330 #[test]
2331 fn parse_is_null() {
2332 let stmt = parse_sql("SELECT * FROM t WHERE x IS NULL").unwrap();
2333 match stmt {
2334 Statement::Select(sq) => match sq.body {
2335 QueryBody::Select(sel) => {
2336 assert!(matches!(sel.where_clause, Some(Expr::IsNull(_))));
2337 }
2338 _ => panic!("expected QueryBody::Select"),
2339 },
2340 _ => panic!("expected Select"),
2341 }
2342 }
2343
2344 #[test]
2345 fn parse_inner_join() {
2346 let stmt = parse_sql("SELECT * FROM a JOIN b ON a.id = b.id").unwrap();
2347 match stmt {
2348 Statement::Select(sq) => match sq.body {
2349 QueryBody::Select(sel) => {
2350 assert_eq!(sel.from, "a");
2351 assert_eq!(sel.joins.len(), 1);
2352 assert_eq!(sel.joins[0].join_type, JoinType::Inner);
2353 assert_eq!(sel.joins[0].table.name, "b");
2354 assert!(sel.joins[0].on_clause.is_some());
2355 }
2356 _ => panic!("expected QueryBody::Select"),
2357 },
2358 _ => panic!("expected Select"),
2359 }
2360 }
2361
/// The explicit `INNER JOIN` spelling yields the same join type as a bare `JOIN`.
#[test]
fn parse_inner_join_explicit() {
    let parsed = parse_sql("SELECT * FROM a INNER JOIN b ON a.id = b.a_id").unwrap();
    let Statement::Select(query) = parsed else {
        panic!("expected Select");
    };
    let QueryBody::Select(select) = query.body else {
        panic!("expected QueryBody::Select");
    };
    assert_eq!(select.joins.len(), 1);
    assert_eq!(select.joins[0].join_type, JoinType::Inner);
}
2376
/// `CROSS JOIN` becomes a `JoinType::Cross` join with no ON clause.
#[test]
fn parse_cross_join() {
    let parsed = parse_sql("SELECT * FROM a CROSS JOIN b").unwrap();
    let Statement::Select(query) = parsed else {
        panic!("expected Select");
    };
    let QueryBody::Select(select) = query.body else {
        panic!("expected QueryBody::Select");
    };
    assert_eq!(select.joins.len(), 1);
    let join = &select.joins[0];
    assert_eq!(join.join_type, JoinType::Cross);
    assert!(join.on_clause.is_none());
}
2392
/// `LEFT JOIN ... ON` becomes a single `JoinType::Left` join.
#[test]
fn parse_left_join() {
    let parsed = parse_sql("SELECT * FROM a LEFT JOIN b ON a.id = b.a_id").unwrap();
    let Statement::Select(query) = parsed else {
        panic!("expected Select");
    };
    let QueryBody::Select(select) = query.body else {
        panic!("expected QueryBody::Select");
    };
    assert_eq!(select.joins.len(), 1);
    assert_eq!(select.joins[0].join_type, JoinType::Left);
}
2407
/// Implicit aliases (`users u`, `orders o`) are recorded for both the FROM
/// table and each joined table.
#[test]
fn parse_table_alias() {
    let parsed =
        parse_sql("SELECT u.id FROM users u JOIN orders o ON u.id = o.user_id").unwrap();
    let Statement::Select(query) = parsed else {
        panic!("expected Select");
    };
    let QueryBody::Select(select) = query.body else {
        panic!("expected QueryBody::Select");
    };
    assert_eq!(select.from, "users");
    assert_eq!(select.from_alias.as_deref(), Some("u"));
    let joined = &select.joins[0].table;
    assert_eq!(joined.name, "orders");
    assert_eq!(joined.alias.as_deref(), Some("o"));
}
2424
/// Two chained `JOIN ... ON` clauses produce two inner joins, kept in source
/// order and each carrying its own ON condition.
#[test]
fn parse_multi_join() {
    let stmt =
        parse_sql("SELECT * FROM a JOIN b ON a.id = b.a_id JOIN c ON b.id = c.b_id").unwrap();
    match stmt {
        Statement::Select(sq) => match sq.body {
            QueryBody::Select(sel) => {
                assert_eq!(sel.from, "a");
                assert_eq!(sel.joins.len(), 2);
                // A length-only check would not catch a parser that reorders joins
                // or drops an ON clause, so pin both down explicitly.
                assert_eq!(sel.joins[0].table.name, "b");
                assert_eq!(sel.joins[1].table.name, "c");
                for join in &sel.joins {
                    assert_eq!(join.join_type, JoinType::Inner);
                    assert!(join.on_clause.is_some());
                }
            }
            _ => panic!("expected QueryBody::Select"),
        },
        _ => panic!("expected Select"),
    }
}
2439
/// Every `alias.column` reference in the projection is lowered to
/// `Expr::QualifiedColumn`, not just the first one.
#[test]
fn parse_qualified_column() {
    let stmt = parse_sql("SELECT u.id, u.name FROM users u").unwrap();
    match stmt {
        Statement::Select(sq) => match sq.body {
            QueryBody::Select(sel) => {
                assert_eq!(sel.columns.len(), 2);
                // The query projects two qualified columns; check both.
                for (col, expected_column) in sel.columns.iter().zip(["id", "name"]) {
                    match col {
                        SelectColumn::Expr {
                            expr: Expr::QualifiedColumn { table, column },
                            ..
                        } => {
                            assert_eq!(table, "u");
                            assert_eq!(column, expected_column);
                        }
                        other => panic!("expected QualifiedColumn, got {other:?}"),
                    }
                }
            }
            _ => panic!("expected QueryBody::Select"),
        },
        _ => panic!("expected Select"),
    }
}
2460
/// A derived table (`FROM (SELECT ...)`) must fail to parse.
#[test]
fn reject_subquery() {
    let result = parse_sql("SELECT * FROM (SELECT 1)");
    assert!(result.is_err());
}
2465
/// SQL type names and their aliases map to the engine's `DataType`s:
/// INT/BIGINT/SMALLINT -> Integer, REAL/DOUBLE PRECISION -> Real,
/// VARCHAR(n) -> Text, BOOLEAN -> Boolean, BLOB/BYTEA -> Blob.
#[test]
fn parse_type_mapping() {
    let stmt = parse_sql(
        "CREATE TABLE t (a INT PRIMARY KEY, b BIGINT, c SMALLINT, d REAL, e DOUBLE PRECISION, f VARCHAR(255), g BOOLEAN, h BLOB, i BYTEA)"
    ).unwrap();
    // Expected type for each column, in declaration order.
    let expected = [
        DataType::Integer, // a INT
        DataType::Integer, // b BIGINT
        DataType::Integer, // c SMALLINT
        DataType::Real,    // d REAL
        DataType::Real,    // e DOUBLE PRECISION
        DataType::Text,    // f VARCHAR(255)
        DataType::Boolean, // g BOOLEAN
        DataType::Blob,    // h BLOB
        DataType::Blob,    // i BYTEA
    ];
    match stmt {
        Statement::CreateTable(ct) => {
            // Check the count first. Otherwise an extra parsed column would pass
            // silently, and a missing one would panic with a bare index error.
            assert_eq!(ct.columns.len(), expected.len());
            for (col, want) in ct.columns.iter().zip(&expected) {
                assert_eq!(col.data_type, *want, "unexpected type for column {}", col.name);
            }
        }
        _ => panic!("expected CreateTable"),
    }
}
2486
/// `true` / `false` in a VALUES row become boolean literals.
#[test]
fn parse_boolean_literals() {
    let parsed = parse_sql("INSERT INTO t (a, b) VALUES (true, false)").unwrap();
    let Statement::Insert(ins) = parsed else {
        panic!("expected Insert");
    };
    let InsertSource::Values(rows) = &ins.source else {
        panic!("expected Values");
    };
    let first_row = &rows[0];
    assert!(matches!(first_row[0], Expr::Literal(Value::Boolean(true))));
    assert!(matches!(first_row[1], Expr::Literal(Value::Boolean(false))));
}
2502
/// `NULL` in a VALUES row becomes `Value::Null`.
#[test]
fn parse_null_literal() {
    let parsed = parse_sql("INSERT INTO t (a) VALUES (NULL)").unwrap();
    let Statement::Insert(ins) = parsed else {
        panic!("expected Insert");
    };
    let InsertSource::Values(rows) = &ins.source else {
        panic!("expected Values");
    };
    let cell = &rows[0][0];
    assert!(matches!(cell, Expr::Literal(Value::Null)));
}
2517
/// `expr AS name` records the alias on the projected column.
#[test]
fn parse_alias() {
    let parsed = parse_sql("SELECT id AS user_id FROM users").unwrap();
    let Statement::Select(query) = parsed else {
        panic!("expected Select");
    };
    let QueryBody::Select(select) = query.body else {
        panic!("expected QueryBody::Select");
    };
    let first = &select.columns[0];
    let SelectColumn::Expr { alias: Some(alias_name), .. } = first else {
        panic!("expected alias, got {first:?}");
    };
    assert_eq!(alias_name, "user_id");
}
2532
/// Plain `BEGIN` starts a transaction.
#[test]
fn parse_begin() {
    assert!(matches!(parse_sql("BEGIN").unwrap(), Statement::Begin));
}
2538
/// The optional `TRANSACTION` keyword after `BEGIN` is accepted.
#[test]
fn parse_begin_transaction() {
    assert!(matches!(
        parse_sql("BEGIN TRANSACTION").unwrap(),
        Statement::Begin
    ));
}
2544
/// `COMMIT` maps to `Statement::Commit`.
#[test]
fn parse_commit() {
    assert!(matches!(parse_sql("COMMIT").unwrap(), Statement::Commit));
}
2550
/// A bare `ROLLBACK` (no `TO`) maps to `Statement::Rollback`.
#[test]
fn parse_rollback() {
    assert!(matches!(parse_sql("ROLLBACK").unwrap(), Statement::Rollback));
}
2556
/// `SAVEPOINT name` carries the savepoint name.
#[test]
fn parse_savepoint() {
    let savepoint = match parse_sql("SAVEPOINT sp1").unwrap() {
        Statement::Savepoint(name) => name,
        other => panic!("expected Savepoint, got {other:?}"),
    };
    assert_eq!(savepoint, "sp1");
}
2565
/// Unquoted savepoint names are folded to lowercase (`My_SP` -> `my_sp`).
#[test]
fn parse_savepoint_case_insensitive() {
    let savepoint = match parse_sql("SAVEPOINT My_SP").unwrap() {
        Statement::Savepoint(name) => name,
        other => panic!("expected Savepoint, got {other:?}"),
    };
    assert_eq!(savepoint, "my_sp");
}
2574
/// `RELEASE SAVEPOINT name` maps to `Statement::ReleaseSavepoint`.
#[test]
fn parse_release_savepoint() {
    let released = match parse_sql("RELEASE SAVEPOINT sp1").unwrap() {
        Statement::ReleaseSavepoint(name) => name,
        other => panic!("expected ReleaseSavepoint, got {other:?}"),
    };
    assert_eq!(released, "sp1");
}
2583
/// The `SAVEPOINT` keyword is optional after `RELEASE`.
#[test]
fn parse_release_without_savepoint_keyword() {
    let released = match parse_sql("RELEASE sp1").unwrap() {
        Statement::ReleaseSavepoint(name) => name,
        other => panic!("expected ReleaseSavepoint, got {other:?}"),
    };
    assert_eq!(released, "sp1");
}
2592
/// `ROLLBACK TO SAVEPOINT name` maps to `Statement::RollbackTo`.
#[test]
fn parse_rollback_to_savepoint() {
    let target = match parse_sql("ROLLBACK TO SAVEPOINT sp1").unwrap() {
        Statement::RollbackTo(name) => name,
        other => panic!("expected RollbackTo, got {other:?}"),
    };
    assert_eq!(target, "sp1");
}
2601
/// The `SAVEPOINT` keyword is optional after `ROLLBACK TO`.
#[test]
fn parse_rollback_to_without_savepoint_keyword() {
    let target = match parse_sql("ROLLBACK TO sp1").unwrap() {
        Statement::RollbackTo(name) => name,
        other => panic!("expected RollbackTo, got {other:?}"),
    };
    assert_eq!(target, "sp1");
}
2610
/// `ROLLBACK TO` folds an unquoted savepoint name to lowercase.
#[test]
fn parse_rollback_to_case_insensitive() {
    let target = match parse_sql("ROLLBACK TO My_SP").unwrap() {
        Statement::RollbackTo(name) => name,
        other => panic!("expected RollbackTo, got {other:?}"),
    };
    assert_eq!(target, "my_sp");
}
2619
/// `COMMIT AND CHAIN` is reported as `SqlError::Unsupported`, not as a generic failure.
#[test]
fn parse_commit_and_chain_rejected() {
    let result = parse_sql("COMMIT AND CHAIN");
    assert!(matches!(result.unwrap_err(), SqlError::Unsupported(_)));
}
2625
/// `ROLLBACK AND CHAIN` is reported as `SqlError::Unsupported`.
#[test]
fn parse_rollback_and_chain_rejected() {
    let result = parse_sql("ROLLBACK AND CHAIN");
    assert!(matches!(result.unwrap_err(), SqlError::Unsupported(_)));
}
2631
/// `SELECT DISTINCT col` sets the `distinct` flag and keeps the projection.
#[test]
fn parse_select_distinct() {
    let parsed = parse_sql("SELECT DISTINCT name FROM users").unwrap();
    let Statement::Select(query) = parsed else {
        panic!("expected Select");
    };
    let QueryBody::Select(select) = query.body else {
        panic!("expected QueryBody::Select");
    };
    assert!(select.distinct);
    assert_eq!(select.columns.len(), 1);
}
2646
/// Without `DISTINCT`, the `distinct` flag stays false.
#[test]
fn parse_select_without_distinct() {
    let parsed = parse_sql("SELECT name FROM users").unwrap();
    let Statement::Select(query) = parsed else {
        panic!("expected Select");
    };
    let QueryBody::Select(select) = query.body else {
        panic!("expected QueryBody::Select");
    };
    assert!(!select.distinct);
}
2660
/// `DISTINCT` combines with `*`, which becomes `SelectColumn::AllColumns`.
#[test]
fn parse_select_distinct_all_columns() {
    let parsed = parse_sql("SELECT DISTINCT * FROM users").unwrap();
    let Statement::Select(query) = parsed else {
        panic!("expected Select");
    };
    let QueryBody::Select(select) = query.body else {
        panic!("expected QueryBody::Select");
    };
    assert!(select.distinct);
    assert!(matches!(select.columns[0], SelectColumn::AllColumns));
}
2675
/// `DISTINCT ON (...)` must fail to parse.
#[test]
fn reject_distinct_on() {
    let result = parse_sql("SELECT DISTINCT ON (id) * FROM users");
    assert!(result.is_err());
}
2680
/// A plain `CREATE INDEX` records index name, table, and columns. Both the
/// `unique` and `if_not_exists` flags stay false.
#[test]
fn parse_create_index() {
    let parsed = parse_sql("CREATE INDEX idx_name ON users (name)").unwrap();
    let Statement::CreateIndex(index) = parsed else {
        panic!("expected CreateIndex");
    };
    assert_eq!(index.index_name, "idx_name");
    assert_eq!(index.table_name, "users");
    assert_eq!(index.columns, ["name"]);
    assert!(!index.unique);
    assert!(!index.if_not_exists);
}
2695
/// `CREATE UNIQUE INDEX` sets the `unique` flag.
#[test]
fn parse_create_unique_index() {
    let parsed = parse_sql("CREATE UNIQUE INDEX idx_email ON users (email)").unwrap();
    let Statement::CreateIndex(index) = parsed else {
        panic!("expected CreateIndex");
    };
    assert!(index.unique);
    assert_eq!(index.columns, ["email"]);
}
2707
/// `CREATE INDEX IF NOT EXISTS` sets the `if_not_exists` flag.
#[test]
fn parse_create_index_if_not_exists() {
    let parsed = parse_sql("CREATE INDEX IF NOT EXISTS idx_x ON t (a)").unwrap();
    let Statement::CreateIndex(index) = parsed else {
        panic!("expected CreateIndex");
    };
    assert!(index.if_not_exists);
}
2716
/// A composite index keeps its columns in declaration order.
#[test]
fn parse_create_index_multi_column() {
    let parsed = parse_sql("CREATE INDEX idx_multi ON t (a, b, c)").unwrap();
    let Statement::CreateIndex(index) = parsed else {
        panic!("expected CreateIndex");
    };
    assert_eq!(index.columns, ["a", "b", "c"]);
}
2727
/// `DROP INDEX name` records the name, and `if_exists` stays false.
#[test]
fn parse_drop_index() {
    let parsed = parse_sql("DROP INDEX idx_name").unwrap();
    let Statement::DropIndex(dropped) = parsed else {
        panic!("expected DropIndex");
    };
    assert_eq!(dropped.index_name, "idx_name");
    assert!(!dropped.if_exists);
}
2739
/// `DROP INDEX IF EXISTS` sets the `if_exists` flag.
#[test]
fn parse_drop_index_if_exists() {
    let parsed = parse_sql("DROP INDEX IF EXISTS idx_name").unwrap();
    let Statement::DropIndex(dropped) = parsed else {
        panic!("expected DropIndex");
    };
    assert!(dropped.if_exists);
}
2750
/// `EXPLAIN SELECT ...` wraps the parsed SELECT in `Statement::Explain`.
#[test]
fn parse_explain_select() {
    let parsed = parse_sql("EXPLAIN SELECT * FROM users WHERE id = 1").unwrap();
    let Statement::Explain(wrapped) = parsed else {
        panic!("expected Explain");
    };
    assert!(matches!(*wrapped, Statement::Select(_)));
}
2761
/// `EXPLAIN INSERT ...` wraps the parsed INSERT in `Statement::Explain`.
/// The wrapped statement is checked too, matching `parse_explain_select`.
#[test]
fn parse_explain_insert() {
    let stmt = parse_sql("EXPLAIN INSERT INTO t (a) VALUES (1)").unwrap();
    match stmt {
        Statement::Explain(inner) => {
            assert!(matches!(*inner, Statement::Insert(_)));
        }
        other => panic!("expected Explain, got {other:?}"),
    }
}
2767
/// `EXPLAIN ANALYZE` must fail to parse.
#[test]
fn reject_explain_analyze() {
    let result = parse_sql("EXPLAIN ANALYZE SELECT * FROM t");
    assert!(result.is_err());
}
2772
/// A `$1` placeholder on the right of a comparison becomes `Expr::Parameter(1)`.
#[test]
fn parse_parameter_placeholder() {
    let parsed = parse_sql("SELECT * FROM t WHERE id = $1").unwrap();
    let Statement::Select(query) = parsed else {
        panic!("expected Select");
    };
    let QueryBody::Select(select) = query.body else {
        panic!("expected QueryBody::Select");
    };
    let filter = &select.where_clause;
    let Some(Expr::BinaryOp { right, .. }) = filter else {
        panic!("expected BinaryOp with Parameter, got {filter:?}");
    };
    assert!(matches!(**right, Expr::Parameter(1)));
}
2789
/// Several placeholders in one VALUES row keep their own indices.
#[test]
fn parse_multiple_parameters() {
    let parsed = parse_sql("INSERT INTO t (a, b) VALUES ($1, $2)").unwrap();
    let Statement::Insert(ins) = parsed else {
        panic!("expected Insert");
    };
    let InsertSource::Values(rows) = &ins.source else {
        panic!("expected Values");
    };
    let first_row = &rows[0];
    assert!(matches!(first_row[0], Expr::Parameter(1)));
    assert!(matches!(first_row[1], Expr::Parameter(2)));
}
2805
/// `INSERT INTO ... (cols) SELECT ...` keeps the target columns and stores the
/// query as an `InsertSource::Select`.
#[test]
fn parse_insert_select() {
    let parsed =
        parse_sql("INSERT INTO t2 (id, name) SELECT id, name FROM t1 WHERE id > 5").unwrap();
    let Statement::Insert(ins) = parsed else {
        panic!("expected Insert");
    };
    assert_eq!(ins.table, "t2");
    assert_eq!(ins.columns, vec!["id", "name"]);
    let InsertSource::Select(query) = &ins.source else {
        panic!("expected InsertSource::Select");
    };
    let QueryBody::Select(select) = &query.body else {
        panic!("expected QueryBody::Select");
    };
    assert_eq!(select.from, "t1");
    assert_eq!(select.columns.len(), 2);
    assert!(select.where_clause.is_some());
}
2829
/// `INSERT INTO t SELECT ...` without a column list leaves `columns` empty.
#[test]
fn parse_insert_select_no_columns() {
    let parsed = parse_sql("INSERT INTO t2 SELECT * FROM t1").unwrap();
    let Statement::Insert(ins) = parsed else {
        panic!("expected Insert");
    };
    assert_eq!(ins.table, "t2");
    assert!(ins.columns.is_empty());
    let source_is_select = matches!(&ins.source, InsertSource::Select(_));
    assert!(source_is_select);
}
2842
/// Placeholders are 1-based, so `$0` must fail to parse.
#[test]
fn reject_zero_parameter() {
    let result = parse_sql("SELECT $0 FROM t");
    assert!(result.is_err());
}
2847
/// The parameter count follows the highest `$n` index (here `$3`), not the
/// number of placeholders that appear (two).
#[test]
fn count_params_basic() {
    let parsed = parse_sql("SELECT * FROM t WHERE a = $1 AND b = $3").unwrap();
    let total = count_params(&parsed);
    assert_eq!(total, 3);
}
2853
/// A statement with no placeholders reports zero parameters.
#[test]
fn count_params_none() {
    let parsed = parse_sql("SELECT * FROM t WHERE a = 1").unwrap();
    let total = count_params(&parsed);
    assert_eq!(total, 0);
}
2859
/// A table-level `PRIMARY KEY (a)` is recorded in `primary_key`. It is also
/// reflected on column `a`, which becomes a non-nullable primary-key column.
#[test]
fn parse_table_constraint_pk() {
    let parsed = parse_sql("CREATE TABLE t (a INTEGER, b TEXT, PRIMARY KEY (a))").unwrap();
    let Statement::CreateTable(table) = parsed else {
        panic!("expected CreateTable");
    };
    assert_eq!(table.primary_key, vec!["a"]);
    let first = &table.columns[0];
    assert!(first.is_primary_key);
    assert!(!first.nullable);
}
2872}