1use sqlparser::ast::{
13 BinaryOperator, ColumnDef as SqlColumnDef, DataType as SqlDataType, Expr, Ident, ObjectName,
14 OrderByExpr, Query, Select, SelectItem, SetExpr, Statement, Value as SqlValue,
15};
16use sqlparser::dialect::GenericDialect;
17use sqlparser::parser::Parser;
18
19use crate::error::{QueryError, Result};
20use crate::schema::ColumnName;
21use crate::value::Value;
22
23#[derive(Debug, Clone)]
29pub enum ParsedStatement {
30 Select(ParsedSelect),
32 CreateTable(ParsedCreateTable),
34 DropTable(String),
36 CreateIndex(ParsedCreateIndex),
38 Insert(ParsedInsert),
40 Update(ParsedUpdate),
42 Delete(ParsedDelete),
44}
45
46#[derive(Debug, Clone)]
48pub struct ParsedSelect {
49 pub table: String,
51 pub columns: Option<Vec<ColumnName>>,
53 pub predicates: Vec<Predicate>,
55 pub order_by: Vec<OrderByClause>,
57 pub limit: Option<usize>,
59 pub aggregates: Vec<AggregateFunction>,
61 pub group_by: Vec<ColumnName>,
63 pub distinct: bool,
65}
66
67#[derive(Debug, Clone)]
69pub struct ParsedCreateTable {
70 pub table_name: String,
71 pub columns: Vec<ParsedColumn>,
72 pub primary_key: Vec<String>,
73}
74
/// One column definition from a `CREATE TABLE` statement.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParsedColumn {
    /// Column name as written (without quotes).
    pub name: String,
    /// Canonical type name, e.g. `INTEGER`, `TEXT` or `DECIMAL(10,2)`.
    pub data_type: String,
    /// `false` when the column was declared `NOT NULL`.
    pub nullable: bool,
}
82
/// A `CREATE INDEX` statement.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParsedCreateIndex {
    /// Name of the new index.
    pub index_name: String,
    /// Table the index is built on.
    pub table_name: String,
    /// Indexed column names, in key order.
    pub columns: Vec<String>,
}
90
91#[derive(Debug, Clone)]
93pub struct ParsedInsert {
94 pub table: String,
95 pub columns: Vec<String>,
96 pub values: Vec<Vec<Value>>, pub returning: Option<Vec<String>>, }
99
100#[derive(Debug, Clone)]
102pub struct ParsedUpdate {
103 pub table: String,
104 pub assignments: Vec<(String, Value)>, pub predicates: Vec<Predicate>,
106 pub returning: Option<Vec<String>>, }
108
109#[derive(Debug, Clone)]
111pub struct ParsedDelete {
112 pub table: String,
113 pub predicates: Vec<Predicate>,
114 pub returning: Option<Vec<String>>, }
116
117#[derive(Debug, Clone, PartialEq, Eq)]
119pub enum AggregateFunction {
120 CountStar,
122 Count(ColumnName),
124 Sum(ColumnName),
126 Avg(ColumnName),
128 Min(ColumnName),
130 Max(ColumnName),
132}
133
134#[derive(Debug, Clone)]
136pub enum Predicate {
137 Eq(ColumnName, PredicateValue),
139 Lt(ColumnName, PredicateValue),
141 Le(ColumnName, PredicateValue),
143 Gt(ColumnName, PredicateValue),
145 Ge(ColumnName, PredicateValue),
147 In(ColumnName, Vec<PredicateValue>),
149 Like(ColumnName, String),
151 IsNull(ColumnName),
153 IsNotNull(ColumnName),
155 Or(Vec<Predicate>, Vec<Predicate>),
157}
158
159impl Predicate {
160 #[allow(dead_code)]
164 pub fn column(&self) -> Option<&ColumnName> {
165 match self {
166 Predicate::Eq(col, _)
167 | Predicate::Lt(col, _)
168 | Predicate::Le(col, _)
169 | Predicate::Gt(col, _)
170 | Predicate::Ge(col, _)
171 | Predicate::In(col, _)
172 | Predicate::Like(col, _)
173 | Predicate::IsNull(col)
174 | Predicate::IsNotNull(col) => Some(col),
175 Predicate::Or(_, _) => None,
176 }
177 }
178}
179
180#[derive(Debug, Clone)]
182pub enum PredicateValue {
183 Int(i64),
185 String(String),
187 Bool(bool),
189 Null,
191 Param(usize),
193 Literal(Value),
195}
196
197#[derive(Debug, Clone)]
199pub struct OrderByClause {
200 pub column: ColumnName,
202 pub ascending: bool,
204}
205
206pub fn parse_statement(sql: &str) -> Result<ParsedStatement> {
212 let dialect = GenericDialect {};
213 let statements =
214 Parser::parse_sql(&dialect, sql).map_err(|e| QueryError::ParseError(e.to_string()))?;
215
216 if statements.len() != 1 {
217 return Err(QueryError::ParseError(format!(
218 "expected exactly 1 statement, got {}",
219 statements.len()
220 )));
221 }
222
223 match &statements[0] {
224 Statement::Query(query) => {
225 let select = parse_select_query(query)?;
226 Ok(ParsedStatement::Select(select))
227 }
228 Statement::CreateTable(create_table) => {
229 let parsed = parse_create_table(create_table)?;
230 Ok(ParsedStatement::CreateTable(parsed))
231 }
232 Statement::Drop {
233 object_type,
234 names,
235 if_exists: _,
236 ..
237 } => {
238 if !matches!(object_type, sqlparser::ast::ObjectType::Table) {
239 return Err(QueryError::UnsupportedFeature(
240 "only DROP TABLE is supported".to_string(),
241 ));
242 }
243 if names.len() != 1 {
244 return Err(QueryError::ParseError(
245 "expected exactly 1 table in DROP TABLE".to_string(),
246 ));
247 }
248 let table_name = object_name_to_string(&names[0]);
249 Ok(ParsedStatement::DropTable(table_name))
250 }
251 Statement::CreateIndex(create_index) => {
252 let parsed = parse_create_index(create_index)?;
253 Ok(ParsedStatement::CreateIndex(parsed))
254 }
255 Statement::Insert(insert) => {
256 let parsed = parse_insert(insert)?;
257 Ok(ParsedStatement::Insert(parsed))
258 }
259 Statement::Update {
260 table,
261 assignments,
262 selection,
263 returning,
264 ..
265 } => {
266 let parsed = parse_update(table, assignments, selection.as_ref(), returning.as_ref())?;
267 Ok(ParsedStatement::Update(parsed))
268 }
269 Statement::Delete(delete) => {
270 let parsed = parse_delete_stmt(delete)?;
271 Ok(ParsedStatement::Delete(parsed))
272 }
273 other => Err(QueryError::UnsupportedFeature(format!(
274 "statement type not supported: {other:?}"
275 ))),
276 }
277}
278
279pub fn parse_query(sql: &str) -> Result<ParsedSelect> {
281 match parse_statement(sql)? {
282 ParsedStatement::Select(select) => Ok(select),
283 _ => Err(QueryError::UnsupportedFeature(
284 "only SELECT queries are supported in parse_query()".to_string(),
285 )),
286 }
287}
288
289fn parse_select_query(query: &Query) -> Result<ParsedSelect> {
290 if query.with.is_some() {
292 return Err(QueryError::UnsupportedFeature(
293 "WITH clauses (CTEs) are not supported".to_string(),
294 ));
295 }
296
297 let SetExpr::Select(select) = query.body.as_ref() else {
298 return Err(QueryError::UnsupportedFeature(
299 "only simple SELECT queries are supported".to_string(),
300 ));
301 };
302
303 let parsed_select = parse_select(select)?;
304
305 let order_by = match &query.order_by {
307 Some(ob) => parse_order_by(ob)?,
308 None => vec![],
309 };
310
311 let limit = parse_limit(query.limit.as_ref())?;
313
314 Ok(ParsedSelect {
315 table: parsed_select.table,
316 columns: parsed_select.columns,
317 predicates: parsed_select.predicates,
318 order_by,
319 limit,
320 aggregates: parsed_select.aggregates,
321 group_by: parsed_select.group_by,
322 distinct: parsed_select.distinct,
323 })
324}
325
326fn parse_select(select: &Select) -> Result<ParsedSelect> {
327 let distinct = select.distinct.is_some();
329
330 if select.from.len() != 1 {
332 return Err(QueryError::ParseError(format!(
333 "expected exactly 1 table in FROM clause, got {}",
334 select.from.len()
335 )));
336 }
337
338 let from = &select.from[0];
339
340 if !from.joins.is_empty() {
342 return Err(QueryError::UnsupportedFeature(
343 "JOINs are not supported".to_string(),
344 ));
345 }
346
347 let table = match &from.relation {
348 sqlparser::ast::TableFactor::Table { name, .. } => object_name_to_string(name),
349 other => {
350 return Err(QueryError::UnsupportedFeature(format!(
351 "unsupported FROM clause: {other:?}"
352 )));
353 }
354 };
355
356 let columns = parse_select_items(&select.projection)?;
358
359 let predicates = match &select.selection {
361 Some(expr) => parse_where_expr(expr)?,
362 None => vec![],
363 };
364
365 let group_by = match &select.group_by {
367 sqlparser::ast::GroupByExpr::Expressions(exprs, _) if !exprs.is_empty() => {
368 parse_group_by_expr(exprs)?
369 }
370 sqlparser::ast::GroupByExpr::All(_) => {
371 return Err(QueryError::UnsupportedFeature(
372 "GROUP BY ALL is not supported".to_string(),
373 ));
374 }
375 sqlparser::ast::GroupByExpr::Expressions(_, _) => vec![],
376 };
377
378 let aggregates = parse_aggregates_from_select_items(&select.projection)?;
380
381 if select.having.is_some() {
383 return Err(QueryError::UnsupportedFeature(
384 "HAVING is not supported yet".to_string(),
385 ));
386 }
387
388 Ok(ParsedSelect {
389 table,
390 columns,
391 predicates,
392 order_by: vec![],
393 limit: None,
394 aggregates,
395 group_by,
396 distinct,
397 })
398}
399
400fn parse_select_items(items: &[SelectItem]) -> Result<Option<Vec<ColumnName>>> {
401 let mut columns = Vec::new();
402
403 for item in items {
404 match item {
405 SelectItem::Wildcard(_) => {
406 return Ok(None);
408 }
409 SelectItem::UnnamedExpr(Expr::Identifier(ident)) => {
410 columns.push(ColumnName::new(ident.value.clone()));
411 }
412 SelectItem::ExprWithAlias {
413 expr: Expr::Identifier(ident),
414 alias,
415 } => {
416 let _ = alias;
418 columns.push(ColumnName::new(ident.value.clone()));
419 }
420 SelectItem::UnnamedExpr(Expr::Function(_))
421 | SelectItem::ExprWithAlias {
422 expr: Expr::Function(_),
423 ..
424 } => {
425 }
428 other => {
429 return Err(QueryError::UnsupportedFeature(format!(
430 "unsupported SELECT item: {other:?}"
431 )));
432 }
433 }
434 }
435
436 Ok(Some(columns))
437}
438
439fn parse_aggregates_from_select_items(items: &[SelectItem]) -> Result<Vec<AggregateFunction>> {
441 let mut aggregates = Vec::new();
442
443 for item in items {
444 match item {
445 SelectItem::UnnamedExpr(expr) | SelectItem::ExprWithAlias { expr, .. } => {
446 if let Some(agg) = try_parse_aggregate(expr)? {
447 aggregates.push(agg);
448 }
449 }
450 _ => {
451 }
453 }
454 }
455
456 Ok(aggregates)
457}
458
459fn try_parse_aggregate(expr: &Expr) -> Result<Option<AggregateFunction>> {
462 match expr {
463 Expr::Function(func) => {
464 let func_name = func.name.to_string().to_uppercase();
465
466 let args = match &func.args {
468 sqlparser::ast::FunctionArguments::List(list) => &list.args,
469 _ => {
470 return Err(QueryError::UnsupportedFeature(
471 "non-list function arguments not supported".to_string(),
472 ));
473 }
474 };
475
476 match func_name.as_str() {
477 "COUNT" => {
478 if args.len() == 1 {
480 match &args[0] {
481 sqlparser::ast::FunctionArg::Unnamed(arg_expr) => match arg_expr {
482 sqlparser::ast::FunctionArgExpr::Wildcard => {
483 Ok(Some(AggregateFunction::CountStar))
484 }
485 sqlparser::ast::FunctionArgExpr::Expr(Expr::Identifier(ident)) => {
486 Ok(Some(AggregateFunction::Count(ColumnName::new(
487 ident.value.clone(),
488 ))))
489 }
490 _ => Err(QueryError::UnsupportedFeature(
491 "COUNT with complex expression not supported".to_string(),
492 )),
493 },
494 _ => Err(QueryError::UnsupportedFeature(
495 "named function arguments not supported".to_string(),
496 )),
497 }
498 } else {
499 Err(QueryError::ParseError(format!(
500 "COUNT expects 1 argument, got {}",
501 args.len()
502 )))
503 }
504 }
505 "SUM" | "AVG" | "MIN" | "MAX" => {
506 if args.len() != 1 {
508 return Err(QueryError::ParseError(format!(
509 "{} expects 1 argument, got {}",
510 func_name,
511 args.len()
512 )));
513 }
514
515 match &args[0] {
516 sqlparser::ast::FunctionArg::Unnamed(arg_expr) => match arg_expr {
517 sqlparser::ast::FunctionArgExpr::Expr(Expr::Identifier(ident)) => {
518 let column = ColumnName::new(ident.value.clone());
519 match func_name.as_str() {
520 "SUM" => Ok(Some(AggregateFunction::Sum(column))),
521 "AVG" => Ok(Some(AggregateFunction::Avg(column))),
522 "MIN" => Ok(Some(AggregateFunction::Min(column))),
523 "MAX" => Ok(Some(AggregateFunction::Max(column))),
524 _ => unreachable!(),
525 }
526 }
527 _ => Err(QueryError::UnsupportedFeature(format!(
528 "{func_name} with complex expression not supported"
529 ))),
530 },
531 _ => Err(QueryError::UnsupportedFeature(
532 "named function arguments not supported".to_string(),
533 )),
534 }
535 }
536 _ => {
537 Ok(None)
539 }
540 }
541 }
542 _ => {
543 Ok(None)
545 }
546 }
547}
548
549fn parse_group_by_expr(exprs: &[Expr]) -> Result<Vec<ColumnName>> {
551 let mut columns = Vec::new();
552
553 for expr in exprs {
554 match expr {
555 Expr::Identifier(ident) => {
556 columns.push(ColumnName::new(ident.value.clone()));
557 }
558 _ => {
559 return Err(QueryError::UnsupportedFeature(
560 "complex GROUP BY expressions not supported".to_string(),
561 ));
562 }
563 }
564 }
565
566 Ok(columns)
567}
568
/// Maximum recursion depth allowed while walking a `WHERE` expression.
///
/// `parse_where_expr_inner` returns a `ParseError` once this depth is reached,
/// so deeply nested input is rejected instead of recursing without bound.
const MAX_WHERE_DEPTH: usize = 100;
577
578fn parse_where_expr(expr: &Expr) -> Result<Vec<Predicate>> {
579 parse_where_expr_inner(expr, 0)
580}
581
582fn parse_where_expr_inner(expr: &Expr, depth: usize) -> Result<Vec<Predicate>> {
583 if depth >= MAX_WHERE_DEPTH {
584 return Err(QueryError::ParseError(format!(
585 "WHERE clause nesting exceeds maximum depth of {MAX_WHERE_DEPTH}"
586 )));
587 }
588
589 match expr {
590 Expr::BinaryOp {
592 left,
593 op: BinaryOperator::And,
594 right,
595 } => {
596 let mut predicates = parse_where_expr_inner(left, depth + 1)?;
597 predicates.extend(parse_where_expr_inner(right, depth + 1)?);
598 Ok(predicates)
599 }
600
601 Expr::BinaryOp {
603 left,
604 op: BinaryOperator::Or,
605 right,
606 } => {
607 let left_preds = parse_where_expr_inner(left, depth + 1)?;
608 let right_preds = parse_where_expr_inner(right, depth + 1)?;
609 Ok(vec![Predicate::Or(left_preds, right_preds)])
610 }
611
612 Expr::Like {
614 expr,
615 pattern,
616 negated,
617 ..
618 } => {
619 if *negated {
620 return Err(QueryError::UnsupportedFeature(
621 "NOT LIKE is not supported".to_string(),
622 ));
623 }
624
625 let column = expr_to_column(expr)?;
626 let pattern_value = expr_to_predicate_value(pattern)?;
627
628 match pattern_value {
629 PredicateValue::String(pattern_str)
630 | PredicateValue::Literal(Value::Text(pattern_str)) => {
631 Ok(vec![Predicate::Like(column, pattern_str)])
632 }
633 _ => Err(QueryError::UnsupportedFeature(
634 "LIKE pattern must be a string literal".to_string(),
635 )),
636 }
637 }
638
639 Expr::IsNull(expr) => {
641 let column = expr_to_column(expr)?;
642 Ok(vec![Predicate::IsNull(column)])
643 }
644
645 Expr::IsNotNull(expr) => {
646 let column = expr_to_column(expr)?;
647 Ok(vec![Predicate::IsNotNull(column)])
648 }
649
650 Expr::BinaryOp { left, op, right } => {
652 let predicate = parse_comparison(left, op, right)?;
653 Ok(vec![predicate])
654 }
655
656 Expr::InList {
658 expr,
659 list,
660 negated,
661 } => {
662 if *negated {
663 return Err(QueryError::UnsupportedFeature(
664 "NOT IN is not supported".to_string(),
665 ));
666 }
667
668 let column = expr_to_column(expr)?;
669 let values: Result<Vec<_>> = list.iter().map(expr_to_predicate_value).collect();
670 Ok(vec![Predicate::In(column, values?)])
671 }
672
673 Expr::Nested(inner) => parse_where_expr_inner(inner, depth + 1),
675
676 other => Err(QueryError::UnsupportedFeature(format!(
677 "unsupported WHERE expression: {other:?}"
678 ))),
679 }
680}
681
682fn parse_comparison(left: &Expr, op: &BinaryOperator, right: &Expr) -> Result<Predicate> {
683 let column = expr_to_column(left)?;
684 let value = expr_to_predicate_value(right)?;
685
686 match op {
687 BinaryOperator::Eq => Ok(Predicate::Eq(column, value)),
688 BinaryOperator::Lt => Ok(Predicate::Lt(column, value)),
689 BinaryOperator::LtEq => Ok(Predicate::Le(column, value)),
690 BinaryOperator::Gt => Ok(Predicate::Gt(column, value)),
691 BinaryOperator::GtEq => Ok(Predicate::Ge(column, value)),
692 other => Err(QueryError::UnsupportedFeature(format!(
693 "unsupported operator: {other:?}"
694 ))),
695 }
696}
697
698fn expr_to_column(expr: &Expr) -> Result<ColumnName> {
699 match expr {
700 Expr::Identifier(ident) => Ok(ColumnName::new(ident.value.clone())),
701 Expr::CompoundIdentifier(idents) if idents.len() == 2 => {
702 Ok(ColumnName::new(idents[1].value.clone()))
704 }
705 other => Err(QueryError::UnsupportedFeature(format!(
706 "expected column name, got {other:?}"
707 ))),
708 }
709}
710
711fn expr_to_predicate_value(expr: &Expr) -> Result<PredicateValue> {
712 match expr {
713 Expr::Value(SqlValue::Number(n, _)) => {
714 let value = parse_number_literal(n)?;
715 match value {
716 Value::BigInt(v) => Ok(PredicateValue::Int(v)),
717 Value::Decimal(_, _) => Ok(PredicateValue::Literal(value)),
718 _ => unreachable!("parse_number_literal only returns BigInt or Decimal"),
719 }
720 }
721 Expr::Value(SqlValue::SingleQuotedString(s) | SqlValue::DoubleQuotedString(s)) => {
722 Ok(PredicateValue::String(s.clone()))
723 }
724 Expr::Value(SqlValue::Boolean(b)) => Ok(PredicateValue::Bool(*b)),
725 Expr::Value(SqlValue::Null) => Ok(PredicateValue::Null),
726 Expr::Value(SqlValue::Placeholder(p)) => {
727 if let Some(num_str) = p.strip_prefix('$') {
729 let idx: usize = num_str.parse().map_err(|_| {
730 QueryError::ParseError(format!("invalid parameter placeholder: {p}"))
731 })?;
732 if idx == 0 {
734 return Err(QueryError::ParseError(
735 "parameter indices start at $1, not $0".to_string(),
736 ));
737 }
738 Ok(PredicateValue::Param(idx))
739 } else {
740 Err(QueryError::ParseError(format!(
741 "unsupported placeholder format: {p}"
742 )))
743 }
744 }
745 Expr::UnaryOp {
746 op: sqlparser::ast::UnaryOperator::Minus,
747 expr,
748 } => {
749 if let Expr::Value(SqlValue::Number(n, _)) = expr.as_ref() {
751 let value = parse_number_literal(n)?;
752 match value {
753 Value::BigInt(v) => Ok(PredicateValue::Int(-v)),
754 Value::Decimal(v, scale) => {
755 Ok(PredicateValue::Literal(Value::Decimal(-v, scale)))
756 }
757 _ => unreachable!("parse_number_literal only returns BigInt or Decimal"),
758 }
759 } else {
760 Err(QueryError::UnsupportedFeature(format!(
761 "unsupported unary minus operand: {expr:?}"
762 )))
763 }
764 }
765 other => Err(QueryError::UnsupportedFeature(format!(
766 "unsupported value expression: {other:?}"
767 ))),
768 }
769}
770
771fn parse_order_by(order_by: &sqlparser::ast::OrderBy) -> Result<Vec<OrderByClause>> {
772 let mut clauses = Vec::new();
773
774 for expr in &order_by.exprs {
775 clauses.push(parse_order_by_expr(expr)?);
776 }
777
778 Ok(clauses)
779}
780
781fn parse_order_by_expr(expr: &OrderByExpr) -> Result<OrderByClause> {
782 let column = match &expr.expr {
783 Expr::Identifier(ident) => ColumnName::new(ident.value.clone()),
784 other => {
785 return Err(QueryError::UnsupportedFeature(format!(
786 "unsupported ORDER BY expression: {other:?}"
787 )));
788 }
789 };
790
791 let ascending = expr.asc.unwrap_or(true);
792
793 Ok(OrderByClause { column, ascending })
794}
795
796fn parse_limit(limit: Option<&Expr>) -> Result<Option<usize>> {
797 match limit {
798 None => Ok(None),
799 Some(Expr::Value(SqlValue::Number(n, _))) => {
800 let v: usize = n
801 .parse()
802 .map_err(|_| QueryError::ParseError(format!("invalid LIMIT value: {n}")))?;
803 Ok(Some(v))
804 }
805 Some(other) => Err(QueryError::UnsupportedFeature(format!(
806 "unsupported LIMIT expression: {other:?}"
807 ))),
808 }
809}
810
811fn object_name_to_string(name: &ObjectName) -> String {
812 name.0
813 .iter()
814 .map(|i: &Ident| i.value.clone())
815 .collect::<Vec<_>>()
816 .join(".")
817}
818
819fn parse_create_table(create_table: &sqlparser::ast::CreateTable) -> Result<ParsedCreateTable> {
824 let table_name = object_name_to_string(&create_table.name);
825
826 let mut columns = Vec::new();
828 for col_def in &create_table.columns {
829 let parsed_col = parse_column_def(col_def)?;
830 columns.push(parsed_col);
831 }
832
833 let mut primary_key = Vec::new();
835 for constraint in &create_table.constraints {
836 if let sqlparser::ast::TableConstraint::PrimaryKey {
837 columns: pk_cols, ..
838 } = constraint
839 {
840 for col in pk_cols {
841 primary_key.push(col.value.clone());
842 }
843 }
844 }
845
846 if primary_key.is_empty() {
848 for col_def in &create_table.columns {
849 for option in &col_def.options {
850 if matches!(
851 &option.option,
852 sqlparser::ast::ColumnOption::Unique { is_primary, .. } if *is_primary
853 ) {
854 primary_key.push(col_def.name.value.clone());
855 }
856 }
857 }
858 }
859
860 Ok(ParsedCreateTable {
861 table_name,
862 columns,
863 primary_key,
864 })
865}
866
867fn parse_column_def(col_def: &SqlColumnDef) -> Result<ParsedColumn> {
868 let name = col_def.name.value.clone();
869
870 let data_type = match &col_def.data_type {
873 SqlDataType::TinyInt(_) => "TINYINT".to_string(),
875 SqlDataType::SmallInt(_) => "SMALLINT".to_string(),
876 SqlDataType::Int(_) | SqlDataType::Integer(_) => "INTEGER".to_string(),
877 SqlDataType::BigInt(_) => "BIGINT".to_string(),
878
879 SqlDataType::Real | SqlDataType::Float(_) | SqlDataType::Double(_) => "REAL".to_string(),
881 SqlDataType::Decimal(precision_opt) => match precision_opt {
882 sqlparser::ast::ExactNumberInfo::PrecisionAndScale(p, s) => {
883 format!("DECIMAL({p},{s})")
884 }
885 sqlparser::ast::ExactNumberInfo::Precision(p) => {
886 format!("DECIMAL({p},0)")
887 }
888 sqlparser::ast::ExactNumberInfo::None => "DECIMAL(18,2)".to_string(),
889 },
890
891 SqlDataType::Text | SqlDataType::Varchar(_) | SqlDataType::String(_) => "TEXT".to_string(),
893
894 SqlDataType::Binary(_) | SqlDataType::Varbinary(_) | SqlDataType::Blob(_) => {
896 "BYTES".to_string()
897 }
898
899 SqlDataType::Boolean | SqlDataType::Bool => "BOOLEAN".to_string(),
901
902 SqlDataType::Date => "DATE".to_string(),
904 SqlDataType::Time(_, _) => "TIME".to_string(),
905 SqlDataType::Timestamp(_, _) => "TIMESTAMP".to_string(),
906
907 SqlDataType::Uuid => "UUID".to_string(),
909 SqlDataType::JSON => "JSON".to_string(),
910
911 other => {
912 return Err(QueryError::UnsupportedFeature(format!(
913 "unsupported data type: {other:?}"
914 )));
915 }
916 };
917
918 let mut nullable = true;
920 for option in &col_def.options {
921 if matches!(option.option, sqlparser::ast::ColumnOption::NotNull) {
922 nullable = false;
923 }
924 }
925
926 Ok(ParsedColumn {
927 name,
928 data_type,
929 nullable,
930 })
931}
932
933fn parse_create_index(create_index: &sqlparser::ast::CreateIndex) -> Result<ParsedCreateIndex> {
934 let index_name = match &create_index.name {
935 Some(name) => object_name_to_string(name),
936 None => {
937 return Err(QueryError::ParseError(
938 "CREATE INDEX requires an index name".to_string(),
939 ));
940 }
941 };
942
943 let table_name = object_name_to_string(&create_index.table_name);
944
945 let mut columns = Vec::new();
946 for col in &create_index.columns {
947 columns.push(col.expr.to_string());
948 }
949
950 Ok(ParsedCreateIndex {
951 index_name,
952 table_name,
953 columns,
954 })
955}
956
957fn parse_insert(insert: &sqlparser::ast::Insert) -> Result<ParsedInsert> {
962 let table = insert.table.to_string();
964
965 let columns: Vec<String> = insert.columns.iter().map(|c| c.value.clone()).collect();
967
968 let values = match insert.source.as_ref().map(|s| s.body.as_ref()) {
970 Some(SetExpr::Values(values)) => {
971 let mut all_rows = Vec::new();
972 for row in &values.rows {
973 let mut parsed_row = Vec::new();
974 for expr in row {
975 let val = expr_to_value(expr)?;
976 parsed_row.push(val);
977 }
978 all_rows.push(parsed_row);
979 }
980 all_rows
981 }
982 _ => {
983 return Err(QueryError::UnsupportedFeature(
984 "only VALUES clause is supported in INSERT".to_string(),
985 ));
986 }
987 };
988
989 let returning = parse_returning(insert.returning.as_ref())?;
991
992 Ok(ParsedInsert {
993 table,
994 columns,
995 values,
996 returning,
997 })
998}
999
1000fn parse_update(
1001 table: &sqlparser::ast::TableWithJoins,
1002 assignments: &[sqlparser::ast::Assignment],
1003 selection: Option<&Expr>,
1004 returning: Option<&Vec<SelectItem>>,
1005) -> Result<ParsedUpdate> {
1006 let table_name = match &table.relation {
1007 sqlparser::ast::TableFactor::Table { name, .. } => object_name_to_string(name),
1008 other => {
1009 return Err(QueryError::UnsupportedFeature(format!(
1010 "unsupported table in UPDATE: {other:?}"
1011 )));
1012 }
1013 };
1014
1015 let mut parsed_assignments = Vec::new();
1017 for assignment in assignments {
1018 let col_name = assignment.target.to_string();
1019 let value = expr_to_value(&assignment.value)?;
1020 parsed_assignments.push((col_name, value));
1021 }
1022
1023 let predicates = match selection {
1025 Some(expr) => parse_where_expr(expr)?,
1026 None => vec![],
1027 };
1028
1029 let returning_cols = parse_returning(returning)?;
1031
1032 Ok(ParsedUpdate {
1033 table: table_name,
1034 assignments: parsed_assignments,
1035 predicates,
1036 returning: returning_cols,
1037 })
1038}
1039
1040fn parse_delete_stmt(delete: &sqlparser::ast::Delete) -> Result<ParsedDelete> {
1041 use sqlparser::ast::FromTable;
1043
1044 let table_name = match &delete.from {
1045 FromTable::WithFromKeyword(tables) => {
1046 if tables.len() != 1 {
1047 return Err(QueryError::ParseError(
1048 "expected exactly 1 table in DELETE FROM".to_string(),
1049 ));
1050 }
1051
1052 match &tables[0].relation {
1053 sqlparser::ast::TableFactor::Table { name, .. } => object_name_to_string(name),
1054 _ => {
1055 return Err(QueryError::ParseError(
1056 "DELETE only supports simple table names".to_string(),
1057 ));
1058 }
1059 }
1060 }
1061 FromTable::WithoutKeyword(tables) => {
1062 if tables.len() != 1 {
1063 return Err(QueryError::ParseError(
1064 "expected exactly 1 table in DELETE".to_string(),
1065 ));
1066 }
1067
1068 match &tables[0].relation {
1069 sqlparser::ast::TableFactor::Table { name, .. } => object_name_to_string(name),
1070 _ => {
1071 return Err(QueryError::ParseError(
1072 "DELETE only supports simple table names".to_string(),
1073 ));
1074 }
1075 }
1076 }
1077 };
1078
1079 let predicates = match &delete.selection {
1081 Some(expr) => parse_where_expr(expr)?,
1082 None => vec![],
1083 };
1084
1085 let returning_cols = parse_returning(delete.returning.as_ref())?;
1087
1088 Ok(ParsedDelete {
1089 table: table_name,
1090 predicates,
1091 returning: returning_cols,
1092 })
1093}
1094
1095fn parse_returning(returning: Option<&Vec<SelectItem>>) -> Result<Option<Vec<String>>> {
1097 match returning {
1098 None => Ok(None),
1099 Some(items) => {
1100 let mut columns = Vec::new();
1101 for item in items {
1102 match item {
1103 SelectItem::UnnamedExpr(Expr::Identifier(ident)) => {
1104 columns.push(ident.value.clone());
1105 }
1106 SelectItem::UnnamedExpr(Expr::CompoundIdentifier(parts)) => {
1107 if let Some(last) = parts.last() {
1109 columns.push(last.value.clone());
1110 } else {
1111 return Err(QueryError::ParseError(
1112 "invalid column in RETURNING clause".to_string(),
1113 ));
1114 }
1115 }
1116 _ => {
1117 return Err(QueryError::UnsupportedFeature(
1118 "only simple column names supported in RETURNING clause".to_string(),
1119 ));
1120 }
1121 }
1122 }
1123 Ok(Some(columns))
1124 }
1125 }
1126}
1127
1128fn parse_number_literal(n: &str) -> Result<Value> {
1132 use rust_decimal::Decimal;
1133 use std::str::FromStr;
1134
1135 if n.contains('.') {
1136 let decimal = Decimal::from_str(n)
1138 .map_err(|e| QueryError::ParseError(format!("invalid decimal '{n}': {e}")))?;
1139
1140 let scale = decimal.scale() as u8;
1142
1143 if scale > 38 {
1144 return Err(QueryError::ParseError(format!(
1145 "decimal scale too large (max 38): {n}"
1146 )));
1147 }
1148
1149 let mantissa = decimal.mantissa();
1152
1153 Ok(Value::Decimal(mantissa, scale))
1154 } else {
1155 let v: i64 = n
1157 .parse()
1158 .map_err(|_| QueryError::ParseError(format!("invalid integer: {n}")))?;
1159 Ok(Value::BigInt(v))
1160 }
1161}
1162
1163fn expr_to_value(expr: &Expr) -> Result<Value> {
1165 match expr {
1166 Expr::Value(SqlValue::Number(n, _)) => parse_number_literal(n),
1167 Expr::Value(SqlValue::SingleQuotedString(s) | SqlValue::DoubleQuotedString(s)) => {
1168 Ok(Value::Text(s.clone()))
1169 }
1170 Expr::Value(SqlValue::Boolean(b)) => Ok(Value::Boolean(*b)),
1171 Expr::Value(SqlValue::Null) => Ok(Value::Null),
1172 Expr::Value(SqlValue::Placeholder(p)) => {
1173 if let Some(num_str) = p.strip_prefix('$') {
1175 let idx: usize = num_str.parse().map_err(|_| {
1176 QueryError::ParseError(format!("invalid parameter placeholder: {p}"))
1177 })?;
1178 if idx == 0 {
1180 return Err(QueryError::ParseError(
1181 "parameter indices start at $1, not $0".to_string(),
1182 ));
1183 }
1184 Ok(Value::Placeholder(idx))
1185 } else {
1186 Err(QueryError::ParseError(format!(
1187 "unsupported placeholder format: {p}"
1188 )))
1189 }
1190 }
1191 Expr::UnaryOp {
1192 op: sqlparser::ast::UnaryOperator::Minus,
1193 expr,
1194 } => {
1195 if let Expr::Value(SqlValue::Number(n, _)) = expr.as_ref() {
1197 let value = parse_number_literal(n)?;
1198 match value {
1199 Value::BigInt(v) => Ok(Value::BigInt(-v)),
1200 Value::Decimal(v, scale) => Ok(Value::Decimal(-v, scale)),
1201 _ => unreachable!("parse_number_literal only returns BigInt or Decimal"),
1202 }
1203 } else {
1204 Err(QueryError::UnsupportedFeature(format!(
1205 "unsupported unary minus operand: {expr:?}"
1206 )))
1207 }
1208 }
1209 other => Err(QueryError::UnsupportedFeature(format!(
1210 "unsupported value expression: {other:?}"
1211 ))),
1212 }
1213}
1214
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_simple_select() {
        let result = parse_query("SELECT id, name FROM users").unwrap();
        assert_eq!(result.table, "users");
        assert_eq!(
            result.columns,
            Some(vec![ColumnName::new("id"), ColumnName::new("name")])
        );
        assert!(result.predicates.is_empty());
    }

    #[test]
    fn test_parse_select_star() {
        let result = parse_query("SELECT * FROM users").unwrap();
        assert_eq!(result.table, "users");
        assert!(result.columns.is_none());
    }

    #[test]
    fn test_parse_where_eq() {
        let result = parse_query("SELECT * FROM users WHERE id = 42").unwrap();
        assert_eq!(result.predicates.len(), 1);
        match &result.predicates[0] {
            Predicate::Eq(col, PredicateValue::Int(42)) => {
                assert_eq!(col.as_str(), "id");
            }
            other => panic!("unexpected predicate: {other:?}"),
        }
    }

    #[test]
    fn test_parse_where_string() {
        let result = parse_query("SELECT * FROM users WHERE name = 'alice'").unwrap();
        match &result.predicates[0] {
            Predicate::Eq(col, PredicateValue::String(s)) => {
                assert_eq!(col.as_str(), "name");
                assert_eq!(s, "alice");
            }
            other => panic!("unexpected predicate: {other:?}"),
        }
    }

    #[test]
    fn test_parse_where_and() {
        let result = parse_query("SELECT * FROM users WHERE id = 1 AND name = 'bob'").unwrap();
        assert_eq!(result.predicates.len(), 2);
    }

    #[test]
    fn test_parse_where_in() {
        let result = parse_query("SELECT * FROM users WHERE id IN (1, 2, 3)").unwrap();
        match &result.predicates[0] {
            Predicate::In(col, values) => {
                assert_eq!(col.as_str(), "id");
                assert_eq!(values.len(), 3);
            }
            other => panic!("unexpected predicate: {other:?}"),
        }
    }

    #[test]
    fn test_parse_order_by() {
        let result = parse_query("SELECT * FROM users ORDER BY name ASC, id DESC").unwrap();
        assert_eq!(result.order_by.len(), 2);
        assert_eq!(result.order_by[0].column.as_str(), "name");
        assert!(result.order_by[0].ascending);
        assert_eq!(result.order_by[1].column.as_str(), "id");
        assert!(!result.order_by[1].ascending);
    }

    #[test]
    fn test_parse_limit() {
        let result = parse_query("SELECT * FROM users LIMIT 10").unwrap();
        assert_eq!(result.limit, Some(10));
    }

    #[test]
    fn test_parse_param() {
        let result = parse_query("SELECT * FROM users WHERE id = $1").unwrap();
        match &result.predicates[0] {
            Predicate::Eq(_, PredicateValue::Param(1)) => {}
            other => panic!("unexpected predicate: {other:?}"),
        }
    }

    #[test]
    fn test_reject_join() {
        let result = parse_query("SELECT * FROM users JOIN orders ON users.id = orders.user_id");
        assert!(result.is_err());
    }

    #[test]
    fn test_reject_subquery() {
        let result = parse_query("SELECT * FROM (SELECT * FROM users)");
        assert!(result.is_err());
    }

    #[test]
    fn test_where_depth_within_limit() {
        // Ten parenthesised comparisons joined by AND: moderate nesting.
        let mut sql = String::from("SELECT * FROM users WHERE ");
        for i in 0..10 {
            if i > 0 {
                sql.push_str(" AND ");
            }
            sql.push('(');
            sql.push_str("id = ");
            sql.push_str(&i.to_string());
            sql.push(')');
        }

        let result = parse_query(&sql);
        assert!(
            result.is_ok(),
            "Moderate nesting should succeed, but got: {result:?}"
        );
    }

    #[test]
    fn test_where_depth_nested_parens() {
        // 200 levels of parentheses around a single comparison.
        let mut sql = String::from("SELECT * FROM users WHERE ");
        for _ in 0..200 {
            sql.push('(');
        }
        sql.push_str("id = 1");
        for _ in 0..200 {
            sql.push(')');
        }

        let result = parse_query(&sql);
        assert!(
            result.is_err(),
            "Excessive parenthesis nesting should be rejected"
        );
    }

    #[test]
    fn test_where_depth_complex_and_or() {
        let sql = "SELECT * FROM users WHERE \
                   ((id = 1 AND name = 'a') OR (id = 2 AND name = 'b')) AND \
                   ((age > 10 AND age < 20) OR (age > 30 AND age < 40))";

        let result = parse_query(sql);
        assert!(result.is_ok(), "Complex AND/OR should succeed");
    }

    #[test]
    fn test_parse_negative_literal() {
        let result = parse_query("SELECT * FROM users WHERE id = -5").unwrap();
        match &result.predicates[0] {
            Predicate::Eq(col, PredicateValue::Int(-5)) => assert_eq!(col.as_str(), "id"),
            other => panic!("unexpected predicate: {other:?}"),
        }
    }

    #[test]
    fn test_parse_count_star() {
        let result = parse_query("SELECT COUNT(*) FROM users").unwrap();
        assert_eq!(result.aggregates, vec![AggregateFunction::CountStar]);
        assert_eq!(result.columns, Some(vec![]));
    }

    #[test]
    fn test_reject_not_like() {
        let result = parse_query("SELECT * FROM users WHERE name NOT LIKE 'a%'");
        assert!(result.is_err());
    }

    #[test]
    fn test_reject_multiple_statements() {
        let result = parse_statement("SELECT * FROM a; SELECT * FROM b");
        assert!(result.is_err());
    }

    #[test]
    fn test_parse_create_table() {
        match parse_statement("CREATE TABLE users (id BIGINT PRIMARY KEY, name TEXT NOT NULL)")
            .unwrap()
        {
            ParsedStatement::CreateTable(create) => {
                assert_eq!(create.table_name, "users");
                assert_eq!(create.columns.len(), 2);
                assert_eq!(create.columns[0].name, "id");
                assert_eq!(create.columns[0].data_type, "BIGINT");
                assert_eq!(create.columns[1].name, "name");
                assert_eq!(create.columns[1].data_type, "TEXT");
                assert!(!create.columns[1].nullable);
                assert_eq!(create.primary_key, vec!["id".to_string()]);
            }
            other => panic!("expected CREATE TABLE, got {other:?}"),
        }
    }

    #[test]
    fn test_parse_drop_table() {
        match parse_statement("DROP TABLE users").unwrap() {
            ParsedStatement::DropTable(name) => assert_eq!(name, "users"),
            other => panic!("expected DROP TABLE, got {other:?}"),
        }
    }

    #[test]
    fn test_parse_insert() {
        match parse_statement("INSERT INTO users (id, name) VALUES (1, 'a'), (2, 'b')").unwrap() {
            ParsedStatement::Insert(insert) => {
                assert_eq!(insert.table, "users");
                assert_eq!(insert.columns, vec!["id".to_string(), "name".to_string()]);
                assert_eq!(insert.values.len(), 2);
                assert!(matches!(insert.values[0][0], Value::BigInt(1)));
                assert!(matches!(&insert.values[1][1], Value::Text(s) if s == "b"));
                assert!(insert.returning.is_none());
            }
            other => panic!("expected INSERT, got {other:?}"),
        }
    }

    #[test]
    fn test_parse_update() {
        match parse_statement("UPDATE users SET name = 'x' WHERE id = 1").unwrap() {
            ParsedStatement::Update(update) => {
                assert_eq!(update.table, "users");
                assert_eq!(update.assignments.len(), 1);
                assert_eq!(update.assignments[0].0, "name");
                assert!(matches!(&update.assignments[0].1, Value::Text(s) if s == "x"));
                assert_eq!(update.predicates.len(), 1);
            }
            other => panic!("expected UPDATE, got {other:?}"),
        }
    }

    #[test]
    fn test_parse_delete() {
        match parse_statement("DELETE FROM users WHERE id = 1 RETURNING id").unwrap() {
            ParsedStatement::Delete(delete) => {
                assert_eq!(delete.table, "users");
                assert_eq!(delete.predicates.len(), 1);
                assert_eq!(delete.returning, Some(vec!["id".to_string()]));
            }
            other => panic!("expected DELETE, got {other:?}"),
        }
    }
}