1use sqlparser::ast as sp;
4use sqlparser::dialect::GenericDialect;
5use sqlparser::parser::Parser;
6
7use crate::error::{Result, SqlError};
8use crate::types::{DataType, Value};
9
/// A single SQL statement in this module's own AST.
///
/// Produced from a `sqlparser` statement by `convert_statement`; placeholders inside it
/// can be counted with `count_params` and substituted with `bind_params`.
#[derive(Debug, Clone)]
pub enum Statement {
    /// `CREATE TABLE ...`.
    CreateTable(CreateTableStmt),
    /// `DROP TABLE ...` (exactly one table name).
    DropTable(DropTableStmt),
    /// `CREATE [UNIQUE] INDEX ...`.
    CreateIndex(CreateIndexStmt),
    /// `DROP INDEX ...` (exactly one index name).
    DropIndex(DropIndexStmt),
    /// `INSERT ... VALUES ...`.
    Insert(InsertStmt),
    /// A `SELECT` query.
    Select(SelectStmt),
    /// `UPDATE ...`.
    Update(UpdateStmt),
    /// `DELETE ...`.
    Delete(DeleteStmt),
    /// `BEGIN` / `START TRANSACTION`.
    Begin,
    /// `COMMIT`.
    Commit,
    /// `ROLLBACK`.
    Rollback,
    /// `EXPLAIN <statement>`; `EXPLAIN ANALYZE` is rejected during conversion.
    Explain(Box<Statement>),
}
27
/// A `CREATE TABLE` statement.
#[derive(Debug, Clone)]
pub struct CreateTableStmt {
    /// Table name, as rendered by `object_name_to_string`.
    pub name: String,
    /// Column definitions in declaration order.
    pub columns: Vec<ColumnSpec>,
    /// Primary-key column names, gathered from inline `PRIMARY KEY` column options and
    /// table-level `PRIMARY KEY (...)` constraints; empty when none is declared.
    pub primary_key: Vec<String>,
    /// `true` for `CREATE TABLE IF NOT EXISTS`.
    pub if_not_exists: bool,
}
35
/// One column of a [`CreateTableStmt`].
#[derive(Debug, Clone)]
pub struct ColumnSpec {
    /// Column name.
    pub name: String,
    /// Declared column type, converted from the SQL type by `convert_data_type`.
    pub data_type: DataType,
    /// Whether the column may hold NULL. Starts as `true`; `NOT NULL` and primary-key
    /// declarations clear it.
    pub nullable: bool,
    /// `true` if the column belongs to the table's primary key.
    pub is_primary_key: bool,
}
43
/// A `DROP TABLE` statement naming exactly one table.
#[derive(Debug, Clone)]
pub struct DropTableStmt {
    /// Name of the table to drop.
    pub name: String,
    /// `true` for `DROP TABLE IF EXISTS`.
    pub if_exists: bool,
}
49
/// A `CREATE [UNIQUE] INDEX name ON table (col, ...)` statement.
#[derive(Debug, Clone)]
pub struct CreateIndexStmt {
    /// Index name; unnamed indexes are rejected during conversion.
    pub index_name: String,
    /// Table the index is built on.
    pub table_name: String,
    /// Indexed column names in key order (never empty; expression keys are rejected).
    pub columns: Vec<String>,
    /// `true` for `CREATE UNIQUE INDEX`.
    pub unique: bool,
    /// `true` for `CREATE INDEX IF NOT EXISTS`.
    pub if_not_exists: bool,
}
58
/// A `DROP INDEX` statement naming exactly one index.
#[derive(Debug, Clone)]
pub struct DropIndexStmt {
    /// Name of the index to drop.
    pub index_name: String,
    /// `true` for `DROP INDEX IF EXISTS`.
    pub if_exists: bool,
}
64
/// An `INSERT INTO table [(columns)] VALUES (...), ...` statement.
#[derive(Debug, Clone)]
pub struct InsertStmt {
    /// Target table name.
    pub table: String,
    /// Column names from the statement's column list, in the order written.
    pub columns: Vec<String>,
    /// One inner vector per `VALUES` row, holding that row's value expressions in order.
    pub values: Vec<Vec<Expr>>,
}
71
/// A table referenced by a join, with its optional alias.
#[derive(Debug, Clone)]
pub struct TableRef {
    /// Table name.
    pub name: String,
    /// Alias given with `AS alias` (or a bare alias), if any.
    pub alias: Option<String>,
}
77
/// Kind of join recorded in a [`JoinClause`].
///
/// The enum carries no data, so equality is total: `Eq` and `Hash` are derived alongside
/// `PartialEq`, which also lets join kinds be used as set/map keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum JoinType {
    /// Inner join (`JOIN` / `INNER JOIN`).
    Inner,
    /// Cross join (`CROSS JOIN`).
    Cross,
    /// Left join.
    Left,
    /// Right join.
    Right,
}
85
/// One join attached to the FROM table of a [`SelectStmt`].
#[derive(Debug, Clone)]
pub struct JoinClause {
    /// Join kind.
    pub join_type: JoinType,
    /// The joined table and its alias.
    pub table: TableRef,
    /// The `ON` condition; `None` when the join has no constraint.
    pub on_clause: Option<Expr>,
}
92
/// A `SELECT` query: single FROM table plus explicit joins (no set operations).
#[derive(Debug, Clone)]
pub struct SelectStmt {
    /// Projection list in order.
    pub columns: Vec<SelectColumn>,
    /// FROM table name; an empty string when the query has no FROM clause (e.g. `SELECT 1`).
    pub from: String,
    /// Alias of the FROM table, if any.
    pub from_alias: Option<String>,
    /// Joins applied to the FROM table, in written order.
    pub joins: Vec<JoinClause>,
    /// `true` for `SELECT DISTINCT`.
    pub distinct: bool,
    /// WHERE predicate.
    pub where_clause: Option<Expr>,
    /// ORDER BY keys in priority order.
    pub order_by: Vec<OrderByItem>,
    /// LIMIT, kept as an expression so it may be a `$N` placeholder until bound.
    pub limit: Option<Expr>,
    /// OFFSET, kept as an expression so it may be a `$N` placeholder until bound.
    pub offset: Option<Expr>,
    /// GROUP BY expressions (empty when there is no GROUP BY).
    pub group_by: Vec<Expr>,
    /// HAVING predicate.
    pub having: Option<Expr>,
}
107
/// An `UPDATE table SET col = expr, ... [WHERE ...]` statement.
#[derive(Debug, Clone)]
pub struct UpdateStmt {
    /// Target table name.
    pub table: String,
    /// `(column, new value)` pairs in SET-clause order.
    pub assignments: Vec<(String, Expr)>,
    /// WHERE predicate; `None` when the statement has no WHERE clause.
    pub where_clause: Option<Expr>,
}
114
/// A single-table `DELETE FROM table [WHERE ...]` statement.
#[derive(Debug, Clone)]
pub struct DeleteStmt {
    /// Target table name.
    pub table: String,
    /// WHERE predicate; `None` when the statement has no WHERE clause.
    pub where_clause: Option<Expr>,
}
120
/// One item of a SELECT projection list.
#[derive(Debug, Clone)]
pub enum SelectColumn {
    /// Every column of the input (presumably produced for `SELECT *`).
    AllColumns,
    /// A projected expression with an optional output alias.
    Expr { expr: Expr, alias: Option<String> },
}
126
/// One ORDER BY key.
#[derive(Debug, Clone)]
pub struct OrderByItem {
    /// Sort key expression.
    pub expr: Expr,
    /// `true` for descending order.
    pub descending: bool,
    /// Explicit null placement: `Some(true)` for NULLS FIRST, `Some(false)` for NULLS LAST.
    /// Presumably `None` when unspecified; confirm against `convert_order_by_expr`.
    pub nulls_first: Option<bool>,
}
133
/// A scalar SQL expression.
#[derive(Debug, Clone)]
pub enum Expr {
    /// A constant: a SQL literal, or a placeholder value substituted by `bind_params`.
    Literal(Value),
    /// An unqualified column reference.
    Column(String),
    /// A `table.column` reference (two-part identifier).
    QualifiedColumn { table: String, column: String },
    /// `left op right`.
    BinaryOp { left: Box<Expr>, op: BinOp, right: Box<Expr> },
    /// Unary minus or `NOT`.
    UnaryOp { op: UnaryOp, expr: Box<Expr> },
    /// `expr IS NULL`.
    IsNull(Box<Expr>),
    /// `expr IS NOT NULL`.
    IsNotNull(Box<Expr>),
    /// A function call; `name` is upper-cased. Also used for lowered special forms such as
    /// SUBSTRING (`SUBSTR`), TRIM/LTRIM/RTRIM, CEIL, FLOOR and POSITION (`INSTR`).
    Function { name: String, args: Vec<Expr> },
    /// `COUNT(*)` (also produced for `COUNT()` with no arguments).
    CountStar,
    /// `expr [NOT] IN (subquery)`.
    InSubquery { expr: Box<Expr>, subquery: Box<SelectStmt>, negated: bool },
    /// `expr [NOT] IN (item, ...)`.
    InList { expr: Box<Expr>, list: Vec<Expr>, negated: bool },
    /// `[NOT] EXISTS (subquery)`.
    Exists { subquery: Box<SelectStmt>, negated: bool },
    /// A subquery used as a scalar value.
    ScalarSubquery(Box<SelectStmt>),
    /// Membership test against an already-materialised set of literal values. Never produced
    /// by the parser in this module; `has_null` presumably records whether the source list
    /// contained NULL (TODO confirm with the producer of this variant).
    InSet { expr: Box<Expr>, values: std::collections::HashSet<Value>, has_null: bool, negated: bool },
    /// `expr [NOT] BETWEEN low AND high`.
    Between { expr: Box<Expr>, low: Box<Expr>, high: Box<Expr>, negated: bool },
    /// `expr [NOT] LIKE pattern [ESCAPE escape]`; ILIKE is converted to this variant too.
    Like { expr: Box<Expr>, pattern: Box<Expr>, escape: Option<Box<Expr>>, negated: bool },
    /// `CASE [operand] WHEN cond THEN result ... [ELSE else_result] END`.
    Case { operand: Option<Box<Expr>>, conditions: Vec<(Expr, Expr)>, else_result: Option<Box<Expr>> },
    /// `COALESCE(arg, ...)` with at least one argument.
    Coalesce(Vec<Expr>),
    /// `CAST(expr AS data_type)`.
    Cast { expr: Box<Expr>, data_type: DataType },
    /// A `$N` placeholder; the index is 1-based and replaced by `bind_params`.
    Parameter(usize),
}
157
/// Binary operators usable in [`Expr::BinaryOp`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BinOp {
    // Arithmetic: + - * / %
    Add, Sub, Mul, Div, Mod,
    // Comparison: = != < > <= >=
    Eq, NotEq, Lt, Gt, LtEq, GtEq,
    // Logical AND / OR and string concatenation (`||`).
    And, Or, Concat,
}
164
/// Unary operators usable in [`Expr::UnaryOp`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum UnaryOp {
    /// Arithmetic negation (`-expr`).
    Neg,
    /// Logical negation (`NOT expr`).
    Not,
}
170
171pub fn parse_sql(sql: &str) -> Result<Statement> {
174 let dialect = GenericDialect {};
175 let stmts = Parser::parse_sql(&dialect, sql)
176 .map_err(|e| SqlError::Parse(e.to_string()))?;
177
178 if stmts.is_empty() {
179 return Err(SqlError::Parse("empty SQL".into()));
180 }
181 if stmts.len() > 1 {
182 return Err(SqlError::Unsupported("multiple statements".into()));
183 }
184
185 convert_statement(stmts.into_iter().next().unwrap())
186}
187
188pub fn count_params(stmt: &Statement) -> usize {
192 let mut max_idx = 0usize;
193 visit_exprs_stmt(stmt, &mut |e| {
194 if let Expr::Parameter(n) = e {
195 max_idx = max_idx.max(*n);
196 }
197 });
198 max_idx
199}
200
201pub fn bind_params(stmt: &Statement, params: &[crate::types::Value]) -> crate::error::Result<Statement> {
203 bind_stmt(stmt, params)
204}
205
206fn bind_stmt(stmt: &Statement, params: &[crate::types::Value]) -> crate::error::Result<Statement> {
207 match stmt {
208 Statement::Select(sel) => Ok(Statement::Select(bind_select(sel, params)?)),
209 Statement::Insert(ins) => {
210 let values = ins.values.iter()
211 .map(|row| row.iter().map(|e| bind_expr(e, params)).collect::<crate::error::Result<Vec<_>>>())
212 .collect::<crate::error::Result<Vec<_>>>()?;
213 Ok(Statement::Insert(InsertStmt {
214 table: ins.table.clone(),
215 columns: ins.columns.clone(),
216 values,
217 }))
218 }
219 Statement::Update(upd) => {
220 let assignments = upd.assignments.iter()
221 .map(|(col, e)| Ok((col.clone(), bind_expr(e, params)?)))
222 .collect::<crate::error::Result<Vec<_>>>()?;
223 let where_clause = upd.where_clause.as_ref()
224 .map(|e| bind_expr(e, params)).transpose()?;
225 Ok(Statement::Update(UpdateStmt {
226 table: upd.table.clone(),
227 assignments,
228 where_clause,
229 }))
230 }
231 Statement::Delete(del) => {
232 let where_clause = del.where_clause.as_ref()
233 .map(|e| bind_expr(e, params)).transpose()?;
234 Ok(Statement::Delete(DeleteStmt {
235 table: del.table.clone(),
236 where_clause,
237 }))
238 }
239 Statement::Explain(inner) => Ok(Statement::Explain(Box::new(bind_stmt(inner, params)?))),
240 other => Ok(other.clone()),
241 }
242}
243
244fn bind_select(sel: &SelectStmt, params: &[crate::types::Value]) -> crate::error::Result<SelectStmt> {
245 let columns = sel.columns.iter().map(|c| match c {
246 SelectColumn::AllColumns => Ok(SelectColumn::AllColumns),
247 SelectColumn::Expr { expr, alias } => Ok(SelectColumn::Expr {
248 expr: bind_expr(expr, params)?,
249 alias: alias.clone(),
250 }),
251 }).collect::<crate::error::Result<Vec<_>>>()?;
252 let joins = sel.joins.iter().map(|j| {
253 let on_clause = j.on_clause.as_ref()
254 .map(|e| bind_expr(e, params)).transpose()?;
255 Ok(JoinClause {
256 join_type: j.join_type,
257 table: j.table.clone(),
258 on_clause,
259 })
260 }).collect::<crate::error::Result<Vec<_>>>()?;
261 let where_clause = sel.where_clause.as_ref()
262 .map(|e| bind_expr(e, params)).transpose()?;
263 let order_by = sel.order_by.iter().map(|o| {
264 Ok(OrderByItem {
265 expr: bind_expr(&o.expr, params)?,
266 descending: o.descending,
267 nulls_first: o.nulls_first,
268 })
269 }).collect::<crate::error::Result<Vec<_>>>()?;
270 let limit = sel.limit.as_ref().map(|e| bind_expr(e, params)).transpose()?;
271 let offset = sel.offset.as_ref().map(|e| bind_expr(e, params)).transpose()?;
272 let group_by = sel.group_by.iter()
273 .map(|e| bind_expr(e, params))
274 .collect::<crate::error::Result<Vec<_>>>()?;
275 let having = sel.having.as_ref().map(|e| bind_expr(e, params)).transpose()?;
276
277 Ok(SelectStmt {
278 columns,
279 from: sel.from.clone(),
280 from_alias: sel.from_alias.clone(),
281 joins,
282 distinct: sel.distinct,
283 where_clause,
284 order_by,
285 limit,
286 offset,
287 group_by,
288 having,
289 })
290}
291
292fn bind_expr(expr: &Expr, params: &[crate::types::Value]) -> crate::error::Result<Expr> {
293 match expr {
294 Expr::Parameter(n) => {
295 if *n == 0 || *n > params.len() {
296 return Err(SqlError::ParameterCountMismatch {
297 expected: *n,
298 got: params.len(),
299 });
300 }
301 Ok(Expr::Literal(params[*n - 1].clone()))
302 }
303 Expr::Literal(_) | Expr::Column(_) | Expr::QualifiedColumn { .. }
304 | Expr::CountStar => Ok(expr.clone()),
305 Expr::BinaryOp { left, op, right } => Ok(Expr::BinaryOp {
306 left: Box::new(bind_expr(left, params)?),
307 op: *op,
308 right: Box::new(bind_expr(right, params)?),
309 }),
310 Expr::UnaryOp { op, expr: e } => Ok(Expr::UnaryOp {
311 op: *op,
312 expr: Box::new(bind_expr(e, params)?),
313 }),
314 Expr::IsNull(e) => Ok(Expr::IsNull(Box::new(bind_expr(e, params)?))),
315 Expr::IsNotNull(e) => Ok(Expr::IsNotNull(Box::new(bind_expr(e, params)?))),
316 Expr::Function { name, args } => {
317 let args = args.iter().map(|a| bind_expr(a, params))
318 .collect::<crate::error::Result<Vec<_>>>()?;
319 Ok(Expr::Function { name: name.clone(), args })
320 }
321 Expr::InSubquery { expr: e, subquery, negated } => Ok(Expr::InSubquery {
322 expr: Box::new(bind_expr(e, params)?),
323 subquery: Box::new(bind_select(subquery, params)?),
324 negated: *negated,
325 }),
326 Expr::InList { expr: e, list, negated } => {
327 let list = list.iter().map(|l| bind_expr(l, params))
328 .collect::<crate::error::Result<Vec<_>>>()?;
329 Ok(Expr::InList {
330 expr: Box::new(bind_expr(e, params)?),
331 list,
332 negated: *negated,
333 })
334 }
335 Expr::Exists { subquery, negated } => Ok(Expr::Exists {
336 subquery: Box::new(bind_select(subquery, params)?),
337 negated: *negated,
338 }),
339 Expr::ScalarSubquery(sq) => Ok(Expr::ScalarSubquery(
340 Box::new(bind_select(sq, params)?),
341 )),
342 Expr::InSet { expr: e, values, has_null, negated } => Ok(Expr::InSet {
343 expr: Box::new(bind_expr(e, params)?),
344 values: values.clone(),
345 has_null: *has_null,
346 negated: *negated,
347 }),
348 Expr::Between { expr: e, low, high, negated } => Ok(Expr::Between {
349 expr: Box::new(bind_expr(e, params)?),
350 low: Box::new(bind_expr(low, params)?),
351 high: Box::new(bind_expr(high, params)?),
352 negated: *negated,
353 }),
354 Expr::Like { expr: e, pattern, escape, negated } => Ok(Expr::Like {
355 expr: Box::new(bind_expr(e, params)?),
356 pattern: Box::new(bind_expr(pattern, params)?),
357 escape: escape.as_ref().map(|esc| bind_expr(esc, params).map(Box::new)).transpose()?,
358 negated: *negated,
359 }),
360 Expr::Case { operand, conditions, else_result } => {
361 let operand = operand.as_ref()
362 .map(|e| bind_expr(e, params).map(Box::new)).transpose()?;
363 let conditions = conditions.iter()
364 .map(|(cond, then)| Ok((bind_expr(cond, params)?, bind_expr(then, params)?)))
365 .collect::<crate::error::Result<Vec<_>>>()?;
366 let else_result = else_result.as_ref()
367 .map(|e| bind_expr(e, params).map(Box::new)).transpose()?;
368 Ok(Expr::Case { operand, conditions, else_result })
369 }
370 Expr::Coalesce(args) => {
371 let args = args.iter().map(|a| bind_expr(a, params))
372 .collect::<crate::error::Result<Vec<_>>>()?;
373 Ok(Expr::Coalesce(args))
374 }
375 Expr::Cast { expr: e, data_type } => Ok(Expr::Cast {
376 expr: Box::new(bind_expr(e, params)?),
377 data_type: data_type.clone(),
378 }),
379 }
380}
381
382fn visit_exprs_stmt(stmt: &Statement, visitor: &mut impl FnMut(&Expr)) {
383 match stmt {
384 Statement::Select(sel) => visit_exprs_select(sel, visitor),
385 Statement::Insert(ins) => {
386 for row in &ins.values {
387 for e in row { visit_expr(e, visitor); }
388 }
389 }
390 Statement::Update(upd) => {
391 for (_, e) in &upd.assignments { visit_expr(e, visitor); }
392 if let Some(w) = &upd.where_clause { visit_expr(w, visitor); }
393 }
394 Statement::Delete(del) => {
395 if let Some(w) = &del.where_clause { visit_expr(w, visitor); }
396 }
397 Statement::Explain(inner) => visit_exprs_stmt(inner, visitor),
398 _ => {}
399 }
400}
401
402fn visit_exprs_select(sel: &SelectStmt, visitor: &mut impl FnMut(&Expr)) {
403 for col in &sel.columns {
404 if let SelectColumn::Expr { expr, .. } = col { visit_expr(expr, visitor); }
405 }
406 for j in &sel.joins {
407 if let Some(on) = &j.on_clause { visit_expr(on, visitor); }
408 }
409 if let Some(w) = &sel.where_clause { visit_expr(w, visitor); }
410 for o in &sel.order_by { visit_expr(&o.expr, visitor); }
411 if let Some(l) = &sel.limit { visit_expr(l, visitor); }
412 if let Some(o) = &sel.offset { visit_expr(o, visitor); }
413 for g in &sel.group_by { visit_expr(g, visitor); }
414 if let Some(h) = &sel.having { visit_expr(h, visitor); }
415}
416
417fn visit_expr(expr: &Expr, visitor: &mut impl FnMut(&Expr)) {
418 visitor(expr);
419 match expr {
420 Expr::BinaryOp { left, right, .. } => {
421 visit_expr(left, visitor);
422 visit_expr(right, visitor);
423 }
424 Expr::UnaryOp { expr: e, .. } | Expr::IsNull(e) | Expr::IsNotNull(e) => {
425 visit_expr(e, visitor);
426 }
427 Expr::Function { args, .. } | Expr::Coalesce(args) => {
428 for a in args { visit_expr(a, visitor); }
429 }
430 Expr::InSubquery { expr: e, subquery, .. } => {
431 visit_expr(e, visitor);
432 visit_exprs_select(subquery, visitor);
433 }
434 Expr::InList { expr: e, list, .. } => {
435 visit_expr(e, visitor);
436 for l in list { visit_expr(l, visitor); }
437 }
438 Expr::Exists { subquery, .. } => visit_exprs_select(subquery, visitor),
439 Expr::ScalarSubquery(sq) => visit_exprs_select(sq, visitor),
440 Expr::InSet { expr: e, .. } => visit_expr(e, visitor),
441 Expr::Between { expr: e, low, high, .. } => {
442 visit_expr(e, visitor);
443 visit_expr(low, visitor);
444 visit_expr(high, visitor);
445 }
446 Expr::Like { expr: e, pattern, escape, .. } => {
447 visit_expr(e, visitor);
448 visit_expr(pattern, visitor);
449 if let Some(esc) = escape { visit_expr(esc, visitor); }
450 }
451 Expr::Case { operand, conditions, else_result } => {
452 if let Some(op) = operand { visit_expr(op, visitor); }
453 for (cond, then) in conditions {
454 visit_expr(cond, visitor);
455 visit_expr(then, visitor);
456 }
457 if let Some(el) = else_result { visit_expr(el, visitor); }
458 }
459 Expr::Cast { expr: e, .. } => visit_expr(e, visitor),
460 Expr::Literal(_) | Expr::Column(_) | Expr::QualifiedColumn { .. }
461 | Expr::CountStar | Expr::Parameter(_) => {}
462 }
463}
464
465fn convert_statement(stmt: sp::Statement) -> Result<Statement> {
468 match stmt {
469 sp::Statement::CreateTable(ct) => convert_create_table(ct),
470 sp::Statement::CreateIndex(ci) => convert_create_index(ci),
471 sp::Statement::Drop {
472 object_type: sp::ObjectType::Table,
473 if_exists,
474 names,
475 ..
476 } => {
477 if names.len() != 1 {
478 return Err(SqlError::Unsupported("multi-table DROP".into()));
479 }
480 Ok(Statement::DropTable(DropTableStmt {
481 name: object_name_to_string(&names[0]),
482 if_exists,
483 }))
484 }
485 sp::Statement::Drop {
486 object_type: sp::ObjectType::Index,
487 if_exists,
488 names,
489 ..
490 } => {
491 if names.len() != 1 {
492 return Err(SqlError::Unsupported("multi-index DROP".into()));
493 }
494 Ok(Statement::DropIndex(DropIndexStmt {
495 index_name: object_name_to_string(&names[0]),
496 if_exists,
497 }))
498 }
499 sp::Statement::Insert(insert) => convert_insert(insert),
500 sp::Statement::Query(query) => convert_query(*query),
501 sp::Statement::Update(update) => convert_update(update),
502 sp::Statement::Delete(delete) => convert_delete(delete),
503 sp::Statement::StartTransaction { .. } => Ok(Statement::Begin),
504 sp::Statement::Commit { .. } => Ok(Statement::Commit),
505 sp::Statement::Rollback { .. } => Ok(Statement::Rollback),
506 sp::Statement::Explain { statement, analyze, .. } => {
507 if analyze {
508 return Err(SqlError::Unsupported("EXPLAIN ANALYZE".into()));
509 }
510 let inner = convert_statement(*statement)?;
511 Ok(Statement::Explain(Box::new(inner)))
512 }
513 _ => Err(SqlError::Unsupported(format!(
514 "statement type: {}",
515 stmt
516 ))),
517 }
518}
519
520fn convert_create_table(ct: sp::CreateTable) -> Result<Statement> {
521 let name = object_name_to_string(&ct.name);
522 let if_not_exists = ct.if_not_exists;
523
524 let mut columns = Vec::new();
525 let mut inline_pk: Vec<String> = Vec::new();
526
527 for col_def in &ct.columns {
528 let col_name = col_def.name.value.clone();
529 let data_type = convert_data_type(&col_def.data_type)?;
530 let mut nullable = true;
531 let mut is_primary_key = false;
532
533 for opt in &col_def.options {
534 match &opt.option {
535 sp::ColumnOption::NotNull => nullable = false,
536 sp::ColumnOption::Null => nullable = true,
537 sp::ColumnOption::PrimaryKey(_) => {
538 is_primary_key = true;
539 nullable = false;
540 inline_pk.push(col_name.clone());
541 }
542 _ => {}
543 }
544 }
545
546 columns.push(ColumnSpec {
547 name: col_name,
548 data_type,
549 nullable,
550 is_primary_key,
551 });
552 }
553
554 for constraint in &ct.constraints {
556 if let sp::TableConstraint::PrimaryKey(pk_constraint) = constraint {
557 for idx_col in &pk_constraint.columns {
558 let col_name = match &idx_col.column.expr {
560 sp::Expr::Identifier(ident) => ident.value.clone(),
561 _ => continue,
562 };
563 if !inline_pk.contains(&col_name) {
564 inline_pk.push(col_name.clone());
565 }
566 if let Some(col) = columns.iter_mut().find(|c| c.name == col_name) {
567 col.nullable = false;
568 col.is_primary_key = true;
569 }
570 }
571 }
572 }
573
574 Ok(Statement::CreateTable(CreateTableStmt {
575 name,
576 columns,
577 primary_key: inline_pk,
578 if_not_exists,
579 }))
580}
581
582fn convert_create_index(ci: sp::CreateIndex) -> Result<Statement> {
583 let index_name = ci.name
584 .as_ref()
585 .map(object_name_to_string)
586 .ok_or_else(|| SqlError::Parse("index name required".into()))?;
587
588 let table_name = object_name_to_string(&ci.table_name);
589
590 let columns: Vec<String> = ci.columns.iter().map(|idx_col| {
591 match &idx_col.column.expr {
592 sp::Expr::Identifier(ident) => Ok(ident.value.clone()),
593 other => Err(SqlError::Unsupported(format!("expression index: {other}"))),
594 }
595 }).collect::<Result<_>>()?;
596
597 if columns.is_empty() {
598 return Err(SqlError::Parse("index must have at least one column".into()));
599 }
600
601 Ok(Statement::CreateIndex(CreateIndexStmt {
602 index_name,
603 table_name,
604 columns,
605 unique: ci.unique,
606 if_not_exists: ci.if_not_exists,
607 }))
608}
609
610fn convert_insert(insert: sp::Insert) -> Result<Statement> {
611 let table = match &insert.table {
612 sp::TableObject::TableName(name) => object_name_to_string(name),
613 _ => return Err(SqlError::Unsupported("INSERT into non-table object".into())),
614 };
615
616 let columns: Vec<String> = insert.columns.iter().map(|c| c.value.clone()).collect();
617
618 let source = insert.source.ok_or_else(|| {
619 SqlError::Parse("INSERT requires VALUES".into())
620 })?;
621
622 let values = match *source.body {
623 sp::SetExpr::Values(sp::Values { rows, .. }) => {
624 let mut result = Vec::new();
625 for row in rows {
626 let mut exprs = Vec::new();
627 for expr in row {
628 exprs.push(convert_expr(&expr)?);
629 }
630 result.push(exprs);
631 }
632 result
633 }
634 _ => return Err(SqlError::Unsupported("INSERT ... SELECT".into())),
635 };
636
637 Ok(Statement::Insert(InsertStmt {
638 table,
639 columns,
640 values,
641 }))
642}
643
644fn convert_subquery(query: &sp::Query) -> Result<SelectStmt> {
645 match convert_query(query.clone())? {
646 Statement::Select(s) => Ok(s),
647 _ => Err(SqlError::Unsupported("non-SELECT subquery".into())),
648 }
649}
650
651fn convert_query(query: sp::Query) -> Result<Statement> {
652 let select = match *query.body {
653 sp::SetExpr::Select(sel) => *sel,
654 _ => return Err(SqlError::Unsupported("UNION/INTERSECT/EXCEPT".into())),
655 };
656
657 let distinct = match &select.distinct {
658 Some(sp::Distinct::Distinct) => true,
659 Some(sp::Distinct::On(_)) => {
660 return Err(SqlError::Unsupported("DISTINCT ON".into()));
661 }
662 _ => false,
663 };
664
665 let (from, from_alias, joins) = if select.from.is_empty() {
667 (String::new(), None, vec![])
668 } else if select.from.len() == 1 {
669 let table_with_joins = &select.from[0];
670 let (name, alias) = match &table_with_joins.relation {
671 sp::TableFactor::Table { name, alias, .. } => {
672 let table_name = object_name_to_string(name);
673 let alias_str = alias.as_ref().map(|a| a.name.value.clone());
674 (table_name, alias_str)
675 }
676 _ => return Err(SqlError::Unsupported("non-table FROM source".into())),
677 };
678 let j = table_with_joins.joins.iter()
679 .map(|j| convert_join(j))
680 .collect::<Result<Vec<_>>>()?;
681 (name, alias, j)
682 } else {
683 return Err(SqlError::Unsupported("comma-separated FROM tables".into()));
684 };
685
686 let columns: Vec<SelectColumn> = select.projection.iter()
688 .map(convert_select_item)
689 .collect::<Result<_>>()?;
690
691 let where_clause = select.selection.as_ref()
693 .map(convert_expr)
694 .transpose()?;
695
696 let order_by = if let Some(ref ob) = query.order_by {
698 match &ob.kind {
699 sp::OrderByKind::Expressions(exprs) => {
700 exprs.iter().map(convert_order_by_expr).collect::<Result<_>>()?
701 }
702 sp::OrderByKind::All { .. } => {
703 return Err(SqlError::Unsupported("ORDER BY ALL".into()));
704 }
705 }
706 } else {
707 vec![]
708 };
709
710 let (limit, offset) = match &query.limit_clause {
712 Some(sp::LimitClause::LimitOffset { limit, offset, .. }) => {
713 let l = limit.as_ref().map(convert_expr).transpose()?;
714 let o = offset.as_ref().map(|o| convert_expr(&o.value)).transpose()?;
715 (l, o)
716 }
717 Some(sp::LimitClause::OffsetCommaLimit { limit, offset }) => {
718 let l = Some(convert_expr(limit)?);
719 let o = Some(convert_expr(offset)?);
720 (l, o)
721 }
722 None => (None, None),
723 };
724
725 let group_by = match &select.group_by {
727 sp::GroupByExpr::Expressions(exprs, _) => {
728 exprs.iter().map(convert_expr).collect::<Result<_>>()?
729 }
730 sp::GroupByExpr::All(_) => {
731 return Err(SqlError::Unsupported("GROUP BY ALL".into()));
732 }
733 };
734
735 let having = select.having.as_ref().map(convert_expr).transpose()?;
737
738 Ok(Statement::Select(SelectStmt {
739 columns,
740 from,
741 from_alias,
742 joins,
743 distinct,
744 where_clause,
745 order_by,
746 limit,
747 offset,
748 group_by,
749 having,
750 }))
751}
752
753fn convert_join(join: &sp::Join) -> Result<JoinClause> {
754 let (join_type, constraint) = match &join.join_operator {
755 sp::JoinOperator::Inner(c) => (JoinType::Inner, Some(c)),
756 sp::JoinOperator::Join(c) => (JoinType::Inner, Some(c)),
757 sp::JoinOperator::CrossJoin(c) => (JoinType::Cross, Some(c)),
758 sp::JoinOperator::Left(c) => (JoinType::Left, Some(c)),
759 sp::JoinOperator::LeftSemi(c) => (JoinType::Left, Some(c)),
760 sp::JoinOperator::LeftAnti(c) => (JoinType::Left, Some(c)),
761 sp::JoinOperator::Right(c) => (JoinType::Right, Some(c)),
762 sp::JoinOperator::RightSemi(c) => (JoinType::Right, Some(c)),
763 sp::JoinOperator::RightAnti(c) => (JoinType::Right, Some(c)),
764 other => return Err(SqlError::Unsupported(format!("join type: {other:?}"))),
765 };
766
767 let (name, alias) = match &join.relation {
768 sp::TableFactor::Table { name, alias, .. } => {
769 let table_name = object_name_to_string(name);
770 let alias_str = alias.as_ref().map(|a| a.name.value.clone());
771 (table_name, alias_str)
772 }
773 _ => return Err(SqlError::Unsupported("non-table JOIN source".into())),
774 };
775
776 let on_clause = match constraint {
777 Some(sp::JoinConstraint::On(expr)) => Some(convert_expr(expr)?),
778 Some(sp::JoinConstraint::None) | None => None,
779 Some(other) => return Err(SqlError::Unsupported(format!("join constraint: {other:?}"))),
780 };
781
782 Ok(JoinClause {
783 join_type,
784 table: TableRef { name, alias },
785 on_clause,
786 })
787}
788
789fn convert_update(update: sp::Update) -> Result<Statement> {
790 let table = match &update.table.relation {
791 sp::TableFactor::Table { name, .. } => object_name_to_string(name),
792 _ => return Err(SqlError::Unsupported("non-table UPDATE target".into())),
793 };
794
795 let assignments = update.assignments.iter()
796 .map(|a| {
797 let col = match &a.target {
798 sp::AssignmentTarget::ColumnName(name) => object_name_to_string(name),
799 _ => return Err(SqlError::Unsupported("tuple assignment".into())),
800 };
801 let expr = convert_expr(&a.value)?;
802 Ok((col, expr))
803 })
804 .collect::<Result<_>>()?;
805
806 let where_clause = update.selection.as_ref()
807 .map(convert_expr)
808 .transpose()?;
809
810 Ok(Statement::Update(UpdateStmt {
811 table,
812 assignments,
813 where_clause,
814 }))
815}
816
817fn convert_delete(delete: sp::Delete) -> Result<Statement> {
818 let table_name = match &delete.from {
819 sp::FromTable::WithFromKeyword(tables) => {
820 if tables.len() != 1 {
821 return Err(SqlError::Unsupported("multi-table DELETE".into()));
822 }
823 match &tables[0].relation {
824 sp::TableFactor::Table { name, .. } => object_name_to_string(name),
825 _ => return Err(SqlError::Unsupported("non-table DELETE target".into())),
826 }
827 }
828 sp::FromTable::WithoutKeyword(tables) => {
829 if tables.len() != 1 {
830 return Err(SqlError::Unsupported("multi-table DELETE".into()));
831 }
832 match &tables[0].relation {
833 sp::TableFactor::Table { name, .. } => object_name_to_string(name),
834 _ => return Err(SqlError::Unsupported("non-table DELETE target".into())),
835 }
836 }
837 };
838
839 let where_clause = delete.selection.as_ref()
840 .map(convert_expr)
841 .transpose()?;
842
843 Ok(Statement::Delete(DeleteStmt {
844 table: table_name,
845 where_clause,
846 }))
847}
848
849fn convert_expr(expr: &sp::Expr) -> Result<Expr> {
852 match expr {
853 sp::Expr::Value(v) => convert_value(&v.value),
854 sp::Expr::Identifier(ident) => Ok(Expr::Column(ident.value.clone())),
855 sp::Expr::CompoundIdentifier(parts) => {
856 if parts.len() == 2 {
857 Ok(Expr::QualifiedColumn {
858 table: parts[0].value.clone(),
859 column: parts[1].value.clone(),
860 })
861 } else {
862 Ok(Expr::Column(parts.last().unwrap().value.clone()))
863 }
864 }
865 sp::Expr::BinaryOp { left, op, right } => {
866 let bin_op = convert_bin_op(op)?;
867 Ok(Expr::BinaryOp {
868 left: Box::new(convert_expr(left)?),
869 op: bin_op,
870 right: Box::new(convert_expr(right)?),
871 })
872 }
873 sp::Expr::UnaryOp { op, expr } => {
874 let unary_op = match op {
875 sp::UnaryOperator::Minus => UnaryOp::Neg,
876 sp::UnaryOperator::Not => UnaryOp::Not,
877 _ => return Err(SqlError::Unsupported(format!("unary op: {op}"))),
878 };
879 Ok(Expr::UnaryOp {
880 op: unary_op,
881 expr: Box::new(convert_expr(expr)?),
882 })
883 }
884 sp::Expr::IsNull(e) => Ok(Expr::IsNull(Box::new(convert_expr(e)?))),
885 sp::Expr::IsNotNull(e) => Ok(Expr::IsNotNull(Box::new(convert_expr(e)?))),
886 sp::Expr::Nested(e) => convert_expr(e),
887 sp::Expr::Function(func) => convert_function(func),
888 sp::Expr::InSubquery { expr: e, subquery, negated } => {
889 let inner_expr = convert_expr(e)?;
890 let stmt = convert_subquery(subquery)?;
891 Ok(Expr::InSubquery {
892 expr: Box::new(inner_expr),
893 subquery: Box::new(stmt),
894 negated: *negated,
895 })
896 }
897 sp::Expr::InList { expr: e, list, negated } => {
898 let inner_expr = convert_expr(e)?;
899 let items = list.iter().map(convert_expr).collect::<Result<Vec<_>>>()?;
900 Ok(Expr::InList {
901 expr: Box::new(inner_expr),
902 list: items,
903 negated: *negated,
904 })
905 }
906 sp::Expr::Exists { subquery, negated } => {
907 let stmt = convert_subquery(subquery)?;
908 Ok(Expr::Exists {
909 subquery: Box::new(stmt),
910 negated: *negated,
911 })
912 }
913 sp::Expr::Subquery(query) => {
914 let stmt = convert_subquery(query)?;
915 Ok(Expr::ScalarSubquery(Box::new(stmt)))
916 }
917 sp::Expr::Between { expr: e, negated, low, high } => {
918 Ok(Expr::Between {
919 expr: Box::new(convert_expr(e)?),
920 low: Box::new(convert_expr(low)?),
921 high: Box::new(convert_expr(high)?),
922 negated: *negated,
923 })
924 }
925 sp::Expr::Like { expr: e, negated, pattern, escape_char, .. } => {
926 let esc = escape_char.as_ref()
927 .map(|v| convert_escape_value(v))
928 .transpose()?
929 .map(Box::new);
930 Ok(Expr::Like {
931 expr: Box::new(convert_expr(e)?),
932 pattern: Box::new(convert_expr(pattern)?),
933 escape: esc,
934 negated: *negated,
935 })
936 }
937 sp::Expr::ILike { expr: e, negated, pattern, escape_char, .. } => {
938 let esc = escape_char.as_ref()
939 .map(|v| convert_escape_value(v))
940 .transpose()?
941 .map(Box::new);
942 Ok(Expr::Like {
943 expr: Box::new(convert_expr(e)?),
944 pattern: Box::new(convert_expr(pattern)?),
945 escape: esc,
946 negated: *negated,
947 })
948 }
949 sp::Expr::Case { operand, conditions, else_result, .. } => {
950 let op = operand.as_ref()
951 .map(|e| convert_expr(e))
952 .transpose()?
953 .map(Box::new);
954 let conds: Vec<(Expr, Expr)> = conditions.iter()
955 .map(|cw| Ok((convert_expr(&cw.condition)?, convert_expr(&cw.result)?)))
956 .collect::<Result<_>>()?;
957 let else_r = else_result.as_ref()
958 .map(|e| convert_expr(e))
959 .transpose()?
960 .map(Box::new);
961 Ok(Expr::Case { operand: op, conditions: conds, else_result: else_r })
962 }
963 sp::Expr::Cast { expr: e, data_type: dt, .. } => {
964 let target = convert_data_type(dt)?;
965 Ok(Expr::Cast {
966 expr: Box::new(convert_expr(e)?),
967 data_type: target,
968 })
969 }
970 sp::Expr::Substring { expr: e, substring_from, substring_for, .. } => {
971 let mut args = vec![convert_expr(e)?];
972 if let Some(from) = substring_from {
973 args.push(convert_expr(from)?);
974 }
975 if let Some(f) = substring_for {
976 args.push(convert_expr(f)?);
977 }
978 Ok(Expr::Function { name: "SUBSTR".into(), args })
979 }
980 sp::Expr::Trim { expr: e, trim_where, trim_what, trim_characters } => {
981 let fn_name = match trim_where {
982 Some(sp::TrimWhereField::Leading) => "LTRIM",
983 Some(sp::TrimWhereField::Trailing) => "RTRIM",
984 _ => "TRIM",
985 };
986 let mut args = vec![convert_expr(e)?];
987 if let Some(what) = trim_what {
988 args.push(convert_expr(what)?);
989 } else if let Some(chars) = trim_characters {
990 if let Some(first) = chars.first() {
991 args.push(convert_expr(first)?);
992 }
993 }
994 Ok(Expr::Function { name: fn_name.into(), args })
995 }
996 sp::Expr::Ceil { expr: e, .. } => {
997 Ok(Expr::Function { name: "CEIL".into(), args: vec![convert_expr(e)?] })
998 }
999 sp::Expr::Floor { expr: e, .. } => {
1000 Ok(Expr::Function { name: "FLOOR".into(), args: vec![convert_expr(e)?] })
1001 }
1002 sp::Expr::Position { expr: e, r#in } => {
1003 Ok(Expr::Function { name: "INSTR".into(), args: vec![convert_expr(r#in)?, convert_expr(e)?] })
1004 }
1005 _ => Err(SqlError::Unsupported(format!("expression: {expr}"))),
1006 }
1007}
1008
1009fn convert_value(val: &sp::Value) -> Result<Expr> {
1010 match val {
1011 sp::Value::Number(n, _) => {
1012 if let Ok(i) = n.parse::<i64>() {
1013 Ok(Expr::Literal(Value::Integer(i)))
1014 } else if let Ok(f) = n.parse::<f64>() {
1015 Ok(Expr::Literal(Value::Real(f)))
1016 } else {
1017 Err(SqlError::InvalidValue(format!("cannot parse number: {n}")))
1018 }
1019 }
1020 sp::Value::SingleQuotedString(s) => Ok(Expr::Literal(Value::Text(s.clone()))),
1021 sp::Value::Boolean(b) => Ok(Expr::Literal(Value::Boolean(*b))),
1022 sp::Value::Null => Ok(Expr::Literal(Value::Null)),
1023 sp::Value::Placeholder(s) => {
1024 let idx_str = s.strip_prefix('$')
1025 .ok_or_else(|| SqlError::Parse(format!("invalid placeholder: {s}")))?;
1026 let idx: usize = idx_str.parse()
1027 .map_err(|_| SqlError::Parse(format!("invalid placeholder index: {s}")))?;
1028 if idx == 0 {
1029 return Err(SqlError::Parse("placeholder index must be >= 1".into()));
1030 }
1031 Ok(Expr::Parameter(idx))
1032 }
1033 _ => Err(SqlError::Unsupported(format!("value type: {val}"))),
1034 }
1035}
1036
1037fn convert_escape_value(val: &sp::Value) -> Result<Expr> {
1038 match val {
1039 sp::Value::SingleQuotedString(s) => Ok(Expr::Literal(Value::Text(s.clone()))),
1040 _ => Err(SqlError::Unsupported(format!("ESCAPE value: {val}"))),
1041 }
1042}
1043
1044fn convert_bin_op(op: &sp::BinaryOperator) -> Result<BinOp> {
1045 match op {
1046 sp::BinaryOperator::Plus => Ok(BinOp::Add),
1047 sp::BinaryOperator::Minus => Ok(BinOp::Sub),
1048 sp::BinaryOperator::Multiply => Ok(BinOp::Mul),
1049 sp::BinaryOperator::Divide => Ok(BinOp::Div),
1050 sp::BinaryOperator::Modulo => Ok(BinOp::Mod),
1051 sp::BinaryOperator::Eq => Ok(BinOp::Eq),
1052 sp::BinaryOperator::NotEq => Ok(BinOp::NotEq),
1053 sp::BinaryOperator::Lt => Ok(BinOp::Lt),
1054 sp::BinaryOperator::Gt => Ok(BinOp::Gt),
1055 sp::BinaryOperator::LtEq => Ok(BinOp::LtEq),
1056 sp::BinaryOperator::GtEq => Ok(BinOp::GtEq),
1057 sp::BinaryOperator::And => Ok(BinOp::And),
1058 sp::BinaryOperator::Or => Ok(BinOp::Or),
1059 sp::BinaryOperator::StringConcat => Ok(BinOp::Concat),
1060 _ => Err(SqlError::Unsupported(format!("binary op: {op}"))),
1061 }
1062}
1063
1064fn convert_function(func: &sp::Function) -> Result<Expr> {
1065 let name = object_name_to_string(&func.name).to_ascii_uppercase();
1066
1067 match &func.args {
1069 sp::FunctionArguments::List(list) => {
1070 if list.args.is_empty() && name == "COUNT" {
1071 return Ok(Expr::CountStar);
1072 }
1073 let args = list.args.iter()
1074 .map(|arg| match arg {
1075 sp::FunctionArg::Unnamed(sp::FunctionArgExpr::Expr(e)) => convert_expr(e),
1076 sp::FunctionArg::Unnamed(sp::FunctionArgExpr::Wildcard) => {
1077 if name == "COUNT" {
1078 Ok(Expr::CountStar)
1079 } else {
1080 Err(SqlError::Unsupported(format!("{name}(*)")))
1081 }
1082 }
1083 _ => Err(SqlError::Unsupported(format!("function arg type in {name}"))),
1084 })
1085 .collect::<Result<Vec<_>>>()?;
1086
1087 if name == "COUNT" && args.len() == 1 && matches!(args[0], Expr::CountStar) {
1088 return Ok(Expr::CountStar);
1089 }
1090
1091 if name == "COALESCE" {
1092 if args.is_empty() {
1093 return Err(SqlError::Parse("COALESCE requires at least one argument".into()));
1094 }
1095 return Ok(Expr::Coalesce(args));
1096 }
1097
1098 if name == "NULLIF" {
1099 if args.len() != 2 {
1100 return Err(SqlError::Parse("NULLIF requires exactly two arguments".into()));
1101 }
1102 return Ok(Expr::Case {
1103 operand: None,
1104 conditions: vec![
1105 (Expr::BinaryOp {
1106 left: Box::new(args[0].clone()),
1107 op: BinOp::Eq,
1108 right: Box::new(args[1].clone()),
1109 }, Expr::Literal(Value::Null)),
1110 ],
1111 else_result: Some(Box::new(args[0].clone())),
1112 });
1113 }
1114
1115 if name == "IIF" {
1116 if args.len() != 3 {
1117 return Err(SqlError::Parse("IIF requires exactly three arguments".into()));
1118 }
1119 return Ok(Expr::Case {
1120 operand: None,
1121 conditions: vec![(args[0].clone(), args[1].clone())],
1122 else_result: Some(Box::new(args[2].clone())),
1123 });
1124 }
1125
1126 Ok(Expr::Function { name, args })
1127 }
1128 sp::FunctionArguments::None => {
1129 if name == "COUNT" {
1130 Ok(Expr::CountStar)
1131 } else {
1132 Ok(Expr::Function { name, args: vec![] })
1133 }
1134 }
1135 sp::FunctionArguments::Subquery(_) => {
1136 Err(SqlError::Unsupported("subquery in function".into()))
1137 }
1138 }
1139}
1140
1141fn convert_select_item(item: &sp::SelectItem) -> Result<SelectColumn> {
1142 match item {
1143 sp::SelectItem::Wildcard(_) => Ok(SelectColumn::AllColumns),
1144 sp::SelectItem::UnnamedExpr(e) => {
1145 let expr = convert_expr(e)?;
1146 Ok(SelectColumn::Expr { expr, alias: None })
1147 }
1148 sp::SelectItem::ExprWithAlias { expr, alias } => {
1149 let expr = convert_expr(expr)?;
1150 Ok(SelectColumn::Expr {
1151 expr,
1152 alias: Some(alias.value.clone()),
1153 })
1154 }
1155 sp::SelectItem::QualifiedWildcard(_, _) => {
1156 Err(SqlError::Unsupported("qualified wildcard (table.*)".into()))
1157 }
1158 }
1159}
1160
1161fn convert_order_by_expr(expr: &sp::OrderByExpr) -> Result<OrderByItem> {
1162 let e = convert_expr(&expr.expr)?;
1163 let descending = expr.options.asc.map(|asc| !asc).unwrap_or(false);
1164 let nulls_first = expr.options.nulls_first;
1165
1166 Ok(OrderByItem {
1167 expr: e,
1168 descending,
1169 nulls_first,
1170 })
1171}
1172
1173fn convert_data_type(dt: &sp::DataType) -> Result<DataType> {
1176 match dt {
1177 sp::DataType::Int(_)
1178 | sp::DataType::Integer(_)
1179 | sp::DataType::BigInt(_)
1180 | sp::DataType::SmallInt(_)
1181 | sp::DataType::TinyInt(_)
1182 | sp::DataType::Int2(_)
1183 | sp::DataType::Int4(_)
1184 | sp::DataType::Int8(_) => Ok(DataType::Integer),
1185
1186 sp::DataType::Real
1187 | sp::DataType::Double(..)
1188 | sp::DataType::DoublePrecision
1189 | sp::DataType::Float(_)
1190 | sp::DataType::Float4
1191 | sp::DataType::Float64 => Ok(DataType::Real),
1192
1193 sp::DataType::Varchar(_)
1194 | sp::DataType::Text
1195 | sp::DataType::Char(_)
1196 | sp::DataType::Character(_)
1197 | sp::DataType::String(_) => Ok(DataType::Text),
1198
1199 sp::DataType::Blob(_)
1200 | sp::DataType::Bytea => Ok(DataType::Blob),
1201
1202 sp::DataType::Boolean | sp::DataType::Bool => Ok(DataType::Boolean),
1203
1204 _ => Err(SqlError::Unsupported(format!("data type: {dt}"))),
1205 }
1206}
1207
1208fn object_name_to_string(name: &sp::ObjectName) -> String {
1211 name.0
1212 .iter()
1213 .filter_map(|p| match p {
1214 sp::ObjectNamePart::Identifier(ident) => Some(ident.value.clone()),
1215 _ => None,
1216 })
1217 .collect::<Vec<_>>()
1218 .join(".")
1219}
1220
#[cfg(test)]
mod tests {
    //! Parser tests: each test runs `parse_sql` on a small statement and checks
    //! the resulting `Statement` shape.

    use super::*;

    #[test]
    fn parse_create_table() {
        let stmt = parse_sql(
            "CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT NOT NULL, age INTEGER)"
        ).unwrap();

        match stmt {
            Statement::CreateTable(ct) => {
                assert_eq!(ct.name, "users");
                assert_eq!(ct.columns.len(), 3);
                assert_eq!(ct.columns[0].name, "id");
                assert_eq!(ct.columns[0].data_type, DataType::Integer);
                assert!(ct.columns[0].is_primary_key);
                assert!(!ct.columns[0].nullable);
                assert_eq!(ct.columns[1].name, "name");
                assert_eq!(ct.columns[1].data_type, DataType::Text);
                assert!(!ct.columns[1].nullable);
                assert_eq!(ct.columns[2].name, "age");
                assert!(ct.columns[2].nullable);
                assert_eq!(ct.primary_key, vec!["id"]);
            }
            _ => panic!("expected CreateTable"),
        }
    }

    #[test]
    fn parse_create_table_if_not_exists() {
        let stmt = parse_sql("CREATE TABLE IF NOT EXISTS t (id INT PRIMARY KEY)").unwrap();
        match stmt {
            Statement::CreateTable(ct) => assert!(ct.if_not_exists),
            _ => panic!("expected CreateTable"),
        }
    }

    #[test]
    fn parse_drop_table() {
        let stmt = parse_sql("DROP TABLE users").unwrap();
        match stmt {
            Statement::DropTable(dt) => {
                assert_eq!(dt.name, "users");
                assert!(!dt.if_exists);
            }
            _ => panic!("expected DropTable"),
        }
    }

    #[test]
    fn parse_drop_table_if_exists() {
        let stmt = parse_sql("DROP TABLE IF EXISTS users").unwrap();
        match stmt {
            Statement::DropTable(dt) => assert!(dt.if_exists),
            _ => panic!("expected DropTable"),
        }
    }

    #[test]
    fn parse_insert() {
        let stmt = parse_sql(
            "INSERT INTO users (id, name) VALUES (1, 'Alice'), (2, 'Bob')"
        ).unwrap();

        match stmt {
            Statement::Insert(ins) => {
                assert_eq!(ins.table, "users");
                assert_eq!(ins.columns, vec!["id", "name"]);
                assert_eq!(ins.values.len(), 2);
                assert!(matches!(ins.values[0][0], Expr::Literal(Value::Integer(1))));
                assert!(matches!(&ins.values[0][1], Expr::Literal(Value::Text(s)) if s == "Alice"));
            }
            _ => panic!("expected Insert"),
        }
    }

    #[test]
    fn parse_real_literal() {
        // Numbers that do not fit an i64 parse fall back to a Real literal.
        let stmt = parse_sql("INSERT INTO t (a) VALUES (1.5)").unwrap();
        match stmt {
            Statement::Insert(ins) => {
                assert!(matches!(ins.values[0][0], Expr::Literal(Value::Real(f)) if f == 1.5));
            }
            _ => panic!("expected Insert"),
        }
    }

    #[test]
    fn parse_select_all() {
        let stmt = parse_sql("SELECT * FROM users").unwrap();
        match stmt {
            Statement::Select(sel) => {
                assert_eq!(sel.from, "users");
                assert!(matches!(sel.columns[0], SelectColumn::AllColumns));
                assert!(sel.where_clause.is_none());
            }
            _ => panic!("expected Select"),
        }
    }

    #[test]
    fn parse_select_where() {
        let stmt = parse_sql("SELECT id, name FROM users WHERE age > 18").unwrap();
        match stmt {
            Statement::Select(sel) => {
                assert_eq!(sel.columns.len(), 2);
                assert!(sel.where_clause.is_some());
            }
            _ => panic!("expected Select"),
        }
    }

    #[test]
    fn parse_select_order_limit() {
        let stmt = parse_sql(
            "SELECT * FROM users ORDER BY name ASC LIMIT 10 OFFSET 5"
        ).unwrap();
        match stmt {
            Statement::Select(sel) => {
                assert_eq!(sel.order_by.len(), 1);
                assert!(!sel.order_by[0].descending);
                assert!(sel.limit.is_some());
                assert!(sel.offset.is_some());
            }
            _ => panic!("expected Select"),
        }
    }

    #[test]
    fn parse_order_by_desc() {
        let stmt = parse_sql("SELECT * FROM users ORDER BY name DESC").unwrap();
        match stmt {
            Statement::Select(sel) => {
                assert_eq!(sel.order_by.len(), 1);
                assert!(sel.order_by[0].descending);
            }
            _ => panic!("expected Select"),
        }
    }

    #[test]
    fn parse_update() {
        let stmt = parse_sql("UPDATE users SET name = 'Bob' WHERE id = 1").unwrap();
        match stmt {
            Statement::Update(upd) => {
                assert_eq!(upd.table, "users");
                assert_eq!(upd.assignments.len(), 1);
                assert_eq!(upd.assignments[0].0, "name");
                assert!(upd.where_clause.is_some());
            }
            _ => panic!("expected Update"),
        }
    }

    #[test]
    fn parse_delete() {
        let stmt = parse_sql("DELETE FROM users WHERE id = 1").unwrap();
        match stmt {
            Statement::Delete(del) => {
                assert_eq!(del.table, "users");
                assert!(del.where_clause.is_some());
            }
            _ => panic!("expected Delete"),
        }
    }

    #[test]
    fn parse_aggregate() {
        let stmt = parse_sql("SELECT COUNT(*), SUM(age) FROM users").unwrap();
        match stmt {
            Statement::Select(sel) => {
                assert_eq!(sel.columns.len(), 2);
                match &sel.columns[0] {
                    SelectColumn::Expr { expr: Expr::CountStar, .. } => {}
                    other => panic!("expected CountStar, got {other:?}"),
                }
            }
            _ => panic!("expected Select"),
        }
    }

    #[test]
    fn parse_coalesce() {
        let stmt = parse_sql("SELECT COALESCE(a, 0) FROM t").unwrap();
        match stmt {
            Statement::Select(sel) => match &sel.columns[0] {
                SelectColumn::Expr { expr: Expr::Coalesce(args), .. } => {
                    assert_eq!(args.len(), 2);
                }
                other => panic!("expected Coalesce, got {other:?}"),
            },
            _ => panic!("expected Select"),
        }
    }

    #[test]
    fn parse_nullif_desugars_to_case() {
        let stmt = parse_sql("SELECT NULLIF(a, b) FROM t").unwrap();
        match stmt {
            Statement::Select(sel) => match &sel.columns[0] {
                SelectColumn::Expr {
                    expr: Expr::Case { operand: None, conditions, else_result: Some(_) },
                    ..
                } => {
                    assert_eq!(conditions.len(), 1);
                }
                other => panic!("expected Case, got {other:?}"),
            },
            _ => panic!("expected Select"),
        }
    }

    #[test]
    fn reject_nullif_wrong_arity() {
        assert!(parse_sql("SELECT NULLIF(a) FROM t").is_err());
    }

    #[test]
    fn parse_iif_desugars_to_case() {
        let stmt = parse_sql("SELECT IIF(a > 1, 'big', 'small') FROM t").unwrap();
        match stmt {
            Statement::Select(sel) => match &sel.columns[0] {
                SelectColumn::Expr {
                    expr: Expr::Case { operand: None, conditions, else_result: Some(_) },
                    ..
                } => {
                    assert_eq!(conditions.len(), 1);
                }
                other => panic!("expected Case, got {other:?}"),
            },
            _ => panic!("expected Select"),
        }
    }

    #[test]
    fn parse_group_by_having() {
        let stmt = parse_sql(
            "SELECT department, COUNT(*) FROM employees GROUP BY department HAVING COUNT(*) > 5"
        ).unwrap();
        match stmt {
            Statement::Select(sel) => {
                assert_eq!(sel.group_by.len(), 1);
                assert!(sel.having.is_some());
            }
            _ => panic!("expected Select"),
        }
    }

    #[test]
    fn parse_expressions() {
        let stmt = parse_sql("SELECT id + 1, -price, NOT active FROM items").unwrap();
        match stmt {
            Statement::Select(sel) => {
                assert_eq!(sel.columns.len(), 3);
                match &sel.columns[0] {
                    SelectColumn::Expr { expr: Expr::BinaryOp { op: BinOp::Add, .. }, .. } => {}
                    other => panic!("expected BinaryOp Add, got {other:?}"),
                }
                match &sel.columns[1] {
                    SelectColumn::Expr { expr: Expr::UnaryOp { op: UnaryOp::Neg, .. }, .. } => {}
                    other => panic!("expected UnaryOp Neg, got {other:?}"),
                }
                match &sel.columns[2] {
                    SelectColumn::Expr { expr: Expr::UnaryOp { op: UnaryOp::Not, .. }, .. } => {}
                    other => panic!("expected UnaryOp Not, got {other:?}"),
                }
            }
            _ => panic!("expected Select"),
        }
    }

    #[test]
    fn parse_is_null() {
        let stmt = parse_sql("SELECT * FROM t WHERE x IS NULL").unwrap();
        match stmt {
            Statement::Select(sel) => {
                assert!(matches!(sel.where_clause, Some(Expr::IsNull(_))));
            }
            _ => panic!("expected Select"),
        }
    }

    #[test]
    fn parse_inner_join() {
        let stmt = parse_sql("SELECT * FROM a JOIN b ON a.id = b.id").unwrap();
        match stmt {
            Statement::Select(sel) => {
                assert_eq!(sel.from, "a");
                assert_eq!(sel.joins.len(), 1);
                assert_eq!(sel.joins[0].join_type, JoinType::Inner);
                assert_eq!(sel.joins[0].table.name, "b");
                assert!(sel.joins[0].on_clause.is_some());
            }
            _ => panic!("expected Select"),
        }
    }

    #[test]
    fn parse_inner_join_explicit() {
        let stmt = parse_sql("SELECT * FROM a INNER JOIN b ON a.id = b.a_id").unwrap();
        match stmt {
            Statement::Select(sel) => {
                assert_eq!(sel.joins.len(), 1);
                assert_eq!(sel.joins[0].join_type, JoinType::Inner);
            }
            _ => panic!("expected Select"),
        }
    }

    #[test]
    fn parse_cross_join() {
        let stmt = parse_sql("SELECT * FROM a CROSS JOIN b").unwrap();
        match stmt {
            Statement::Select(sel) => {
                assert_eq!(sel.joins.len(), 1);
                assert_eq!(sel.joins[0].join_type, JoinType::Cross);
                assert!(sel.joins[0].on_clause.is_none());
            }
            _ => panic!("expected Select"),
        }
    }

    #[test]
    fn parse_left_join() {
        let stmt = parse_sql("SELECT * FROM a LEFT JOIN b ON a.id = b.a_id").unwrap();
        match stmt {
            Statement::Select(sel) => {
                assert_eq!(sel.joins.len(), 1);
                assert_eq!(sel.joins[0].join_type, JoinType::Left);
            }
            _ => panic!("expected Select"),
        }
    }

    #[test]
    fn parse_table_alias() {
        let stmt = parse_sql("SELECT u.id FROM users u JOIN orders o ON u.id = o.user_id").unwrap();
        match stmt {
            Statement::Select(sel) => {
                assert_eq!(sel.from, "users");
                assert_eq!(sel.from_alias.as_deref(), Some("u"));
                assert_eq!(sel.joins[0].table.name, "orders");
                assert_eq!(sel.joins[0].table.alias.as_deref(), Some("o"));
            }
            _ => panic!("expected Select"),
        }
    }

    #[test]
    fn parse_multi_join() {
        let stmt = parse_sql(
            "SELECT * FROM a JOIN b ON a.id = b.a_id JOIN c ON b.id = c.b_id"
        ).unwrap();
        match stmt {
            Statement::Select(sel) => {
                assert_eq!(sel.joins.len(), 2);
            }
            _ => panic!("expected Select"),
        }
    }

    #[test]
    fn parse_qualified_column() {
        let stmt = parse_sql("SELECT u.id, u.name FROM users u").unwrap();
        match stmt {
            Statement::Select(sel) => {
                match &sel.columns[0] {
                    SelectColumn::Expr { expr: Expr::QualifiedColumn { table, column }, .. } => {
                        assert_eq!(table, "u");
                        assert_eq!(column, "id");
                    }
                    other => panic!("expected QualifiedColumn, got {other:?}"),
                }
            }
            _ => panic!("expected Select"),
        }
    }

    #[test]
    fn reject_qualified_wildcard() {
        assert!(parse_sql("SELECT t.* FROM t").is_err());
    }

    #[test]
    fn reject_subquery() {
        assert!(parse_sql("SELECT * FROM (SELECT 1)").is_err());
    }

    #[test]
    fn parse_type_mapping() {
        let stmt = parse_sql(
            "CREATE TABLE t (a INT PRIMARY KEY, b BIGINT, c SMALLINT, d REAL, e DOUBLE PRECISION, f VARCHAR(255), g BOOLEAN, h BLOB, i BYTEA)"
        ).unwrap();
        match stmt {
            Statement::CreateTable(ct) => {
                assert_eq!(ct.columns[0].data_type, DataType::Integer);
                assert_eq!(ct.columns[1].data_type, DataType::Integer);
                assert_eq!(ct.columns[2].data_type, DataType::Integer);
                assert_eq!(ct.columns[3].data_type, DataType::Real);
                assert_eq!(ct.columns[4].data_type, DataType::Real);
                assert_eq!(ct.columns[5].data_type, DataType::Text);
                assert_eq!(ct.columns[6].data_type, DataType::Boolean);
                assert_eq!(ct.columns[7].data_type, DataType::Blob);
                assert_eq!(ct.columns[8].data_type, DataType::Blob);
            }
            _ => panic!("expected CreateTable"),
        }
    }

    #[test]
    fn parse_boolean_literals() {
        let stmt = parse_sql("INSERT INTO t (a, b) VALUES (true, false)").unwrap();
        match stmt {
            Statement::Insert(ins) => {
                assert!(matches!(ins.values[0][0], Expr::Literal(Value::Boolean(true))));
                assert!(matches!(ins.values[0][1], Expr::Literal(Value::Boolean(false))));
            }
            _ => panic!("expected Insert"),
        }
    }

    #[test]
    fn parse_null_literal() {
        let stmt = parse_sql("INSERT INTO t (a) VALUES (NULL)").unwrap();
        match stmt {
            Statement::Insert(ins) => {
                assert!(matches!(ins.values[0][0], Expr::Literal(Value::Null)));
            }
            _ => panic!("expected Insert"),
        }
    }

    #[test]
    fn parse_alias() {
        let stmt = parse_sql("SELECT id AS user_id FROM users").unwrap();
        match stmt {
            Statement::Select(sel) => {
                match &sel.columns[0] {
                    SelectColumn::Expr { alias: Some(a), .. } => assert_eq!(a, "user_id"),
                    other => panic!("expected alias, got {other:?}"),
                }
            }
            _ => panic!("expected Select"),
        }
    }

    #[test]
    fn parse_begin() {
        let stmt = parse_sql("BEGIN").unwrap();
        assert!(matches!(stmt, Statement::Begin));
    }

    #[test]
    fn parse_begin_transaction() {
        let stmt = parse_sql("BEGIN TRANSACTION").unwrap();
        assert!(matches!(stmt, Statement::Begin));
    }

    #[test]
    fn parse_commit() {
        let stmt = parse_sql("COMMIT").unwrap();
        assert!(matches!(stmt, Statement::Commit));
    }

    #[test]
    fn parse_rollback() {
        let stmt = parse_sql("ROLLBACK").unwrap();
        assert!(matches!(stmt, Statement::Rollback));
    }

    #[test]
    fn parse_select_distinct() {
        let stmt = parse_sql("SELECT DISTINCT name FROM users").unwrap();
        match stmt {
            Statement::Select(sel) => {
                assert!(sel.distinct);
                assert_eq!(sel.columns.len(), 1);
            }
            _ => panic!("expected Select"),
        }
    }

    #[test]
    fn parse_select_without_distinct() {
        let stmt = parse_sql("SELECT name FROM users").unwrap();
        match stmt {
            Statement::Select(sel) => {
                assert!(!sel.distinct);
            }
            _ => panic!("expected Select"),
        }
    }

    #[test]
    fn parse_select_distinct_all_columns() {
        let stmt = parse_sql("SELECT DISTINCT * FROM users").unwrap();
        match stmt {
            Statement::Select(sel) => {
                assert!(sel.distinct);
                assert!(matches!(sel.columns[0], SelectColumn::AllColumns));
            }
            _ => panic!("expected Select"),
        }
    }

    #[test]
    fn reject_distinct_on() {
        assert!(parse_sql("SELECT DISTINCT ON (id) * FROM users").is_err());
    }

    #[test]
    fn parse_create_index() {
        let stmt = parse_sql("CREATE INDEX idx_name ON users (name)").unwrap();
        match stmt {
            Statement::CreateIndex(ci) => {
                assert_eq!(ci.index_name, "idx_name");
                assert_eq!(ci.table_name, "users");
                assert_eq!(ci.columns, vec!["name"]);
                assert!(!ci.unique);
                assert!(!ci.if_not_exists);
            }
            _ => panic!("expected CreateIndex"),
        }
    }

    #[test]
    fn parse_create_unique_index() {
        let stmt = parse_sql("CREATE UNIQUE INDEX idx_email ON users (email)").unwrap();
        match stmt {
            Statement::CreateIndex(ci) => {
                assert!(ci.unique);
                assert_eq!(ci.columns, vec!["email"]);
            }
            _ => panic!("expected CreateIndex"),
        }
    }

    #[test]
    fn parse_create_index_if_not_exists() {
        let stmt = parse_sql("CREATE INDEX IF NOT EXISTS idx_x ON t (a)").unwrap();
        match stmt {
            Statement::CreateIndex(ci) => assert!(ci.if_not_exists),
            _ => panic!("expected CreateIndex"),
        }
    }

    #[test]
    fn parse_create_index_multi_column() {
        let stmt = parse_sql("CREATE INDEX idx_multi ON t (a, b, c)").unwrap();
        match stmt {
            Statement::CreateIndex(ci) => {
                assert_eq!(ci.columns, vec!["a", "b", "c"]);
            }
            _ => panic!("expected CreateIndex"),
        }
    }

    #[test]
    fn parse_drop_index() {
        let stmt = parse_sql("DROP INDEX idx_name").unwrap();
        match stmt {
            Statement::DropIndex(di) => {
                assert_eq!(di.index_name, "idx_name");
                assert!(!di.if_exists);
            }
            _ => panic!("expected DropIndex"),
        }
    }

    #[test]
    fn parse_drop_index_if_exists() {
        let stmt = parse_sql("DROP INDEX IF EXISTS idx_name").unwrap();
        match stmt {
            Statement::DropIndex(di) => {
                assert!(di.if_exists);
            }
            _ => panic!("expected DropIndex"),
        }
    }

    #[test]
    fn parse_explain_select() {
        let stmt = parse_sql("EXPLAIN SELECT * FROM users WHERE id = 1").unwrap();
        match stmt {
            Statement::Explain(inner) => {
                assert!(matches!(*inner, Statement::Select(_)));
            }
            _ => panic!("expected Explain"),
        }
    }

    #[test]
    fn parse_explain_insert() {
        let stmt = parse_sql("EXPLAIN INSERT INTO t (a) VALUES (1)").unwrap();
        assert!(matches!(stmt, Statement::Explain(_)));
    }

    #[test]
    fn reject_explain_analyze() {
        assert!(parse_sql("EXPLAIN ANALYZE SELECT * FROM t").is_err());
    }

    #[test]
    fn parse_parameter_placeholder() {
        let stmt = parse_sql("SELECT * FROM t WHERE id = $1").unwrap();
        match stmt {
            Statement::Select(sel) => {
                match &sel.where_clause {
                    Some(Expr::BinaryOp { right, .. }) => {
                        assert!(matches!(right.as_ref(), Expr::Parameter(1)));
                    }
                    other => panic!("expected BinaryOp with Parameter, got {other:?}"),
                }
            }
            _ => panic!("expected Select"),
        }
    }

    #[test]
    fn parse_multiple_parameters() {
        let stmt = parse_sql("INSERT INTO t (a, b) VALUES ($1, $2)").unwrap();
        match stmt {
            Statement::Insert(ins) => {
                assert!(matches!(ins.values[0][0], Expr::Parameter(1)));
                assert!(matches!(ins.values[0][1], Expr::Parameter(2)));
            }
            _ => panic!("expected Insert"),
        }
    }

    #[test]
    fn reject_zero_parameter() {
        assert!(parse_sql("SELECT $0 FROM t").is_err());
    }

    #[test]
    fn count_params_basic() {
        let stmt = parse_sql("SELECT * FROM t WHERE a = $1 AND b = $3").unwrap();
        assert_eq!(count_params(&stmt), 3);
    }

    #[test]
    fn count_params_none() {
        let stmt = parse_sql("SELECT * FROM t WHERE a = 1").unwrap();
        assert_eq!(count_params(&stmt), 0);
    }

    #[test]
    fn bind_params_basic() {
        let stmt = parse_sql("SELECT * FROM t WHERE id = $1").unwrap();
        let bound = bind_params(&stmt, &[Value::Integer(42)]).unwrap();
        match bound {
            Statement::Select(sel) => {
                match &sel.where_clause {
                    Some(Expr::BinaryOp { right, .. }) => {
                        assert!(matches!(right.as_ref(), Expr::Literal(Value::Integer(42))));
                    }
                    other => panic!("expected BinaryOp with Literal, got {other:?}"),
                }
            }
            _ => panic!("expected Select"),
        }
    }

    #[test]
    fn bind_params_out_of_range() {
        let stmt = parse_sql("SELECT * FROM t WHERE id = $2").unwrap();
        let result = bind_params(&stmt, &[Value::Integer(1)]);
        assert!(result.is_err());
    }

    #[test]
    fn parse_table_constraint_pk() {
        let stmt = parse_sql(
            "CREATE TABLE t (a INTEGER, b TEXT, PRIMARY KEY (a))"
        ).unwrap();
        match stmt {
            Statement::CreateTable(ct) => {
                assert_eq!(ct.primary_key, vec!["a"]);
                assert!(ct.columns[0].is_primary_key);
                assert!(!ct.columns[0].nullable);
            }
            _ => panic!("expected CreateTable"),
        }
    }
}