Skip to main content

kimberlite_query/
parser.rs

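//! SQL parsing for the query engine.
//!
//! Wraps `sqlparser` to parse a deliberately small SQL subset:
//! - SELECT with a column list or `*`, optional DISTINCT, and the aggregate
//!   functions COUNT, SUM, AVG, MIN and MAX
//! - FROM a single table (JOINs and CTEs are rejected)
//! - WHERE with comparison, IN, LIKE and IS [NOT] NULL predicates,
//!   combined with AND / OR
//! - GROUP BY on plain column names
//! - ORDER BY
//! - LIMIT
//! - CREATE TABLE, DROP TABLE, CREATE INDEX (DDL)
//! - INSERT, UPDATE, DELETE (DML)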
11
12use sqlparser::ast::{
13    BinaryOperator, ColumnDef as SqlColumnDef, DataType as SqlDataType, Expr, Ident, ObjectName,
14    OrderByExpr, Query, Select, SelectItem, SetExpr, Statement, Value as SqlValue,
15};
16use sqlparser::dialect::GenericDialect;
17use sqlparser::parser::Parser;
18
19use crate::error::{QueryError, Result};
20use crate::schema::ColumnName;
21use crate::value::Value;
22
23// ============================================================================
24// Parsed Statement Types
25// ============================================================================
26
27/// Top-level parsed SQL statement.
28#[derive(Debug, Clone)]
29pub enum ParsedStatement {
30    /// SELECT query
31    Select(ParsedSelect),
32    /// CREATE TABLE DDL
33    CreateTable(ParsedCreateTable),
34    /// DROP TABLE DDL
35    DropTable(String),
36    /// CREATE INDEX DDL
37    CreateIndex(ParsedCreateIndex),
38    /// INSERT DML
39    Insert(ParsedInsert),
40    /// UPDATE DML
41    Update(ParsedUpdate),
42    /// DELETE DML
43    Delete(ParsedDelete),
44}
45
46/// Parsed SELECT statement.
47#[derive(Debug, Clone)]
48pub struct ParsedSelect {
49    /// Table name from FROM clause.
50    pub table: String,
51    /// Selected columns (None = SELECT *).
52    pub columns: Option<Vec<ColumnName>>,
53    /// WHERE predicates.
54    pub predicates: Vec<Predicate>,
55    /// ORDER BY clauses.
56    pub order_by: Vec<OrderByClause>,
57    /// LIMIT value.
58    pub limit: Option<usize>,
59    /// Aggregate functions in SELECT clause.
60    pub aggregates: Vec<AggregateFunction>,
61    /// GROUP BY columns.
62    pub group_by: Vec<ColumnName>,
63    /// Whether DISTINCT is specified.
64    pub distinct: bool,
65}
66
67/// Parsed CREATE TABLE statement.
68#[derive(Debug, Clone)]
69pub struct ParsedCreateTable {
70    pub table_name: String,
71    pub columns: Vec<ParsedColumn>,
72    pub primary_key: Vec<String>,
73}
74
/// One column definition taken from a `CREATE TABLE` statement.
#[derive(Debug, Clone)]
pub struct ParsedColumn {
    /// Column name.
    pub name: String,
    /// Normalised type name, e.g. "BIGINT", "TEXT", "BOOLEAN", "TIMESTAMP",
    /// "BYTES" or "DECIMAL(p,s)".
    pub data_type: String,
    /// `false` when the column was declared `NOT NULL`, `true` otherwise.
    pub nullable: bool,
}
82
/// The result of parsing a `CREATE INDEX` statement.
#[derive(Debug, Clone)]
pub struct ParsedCreateIndex {
    /// Name of the index being created.
    pub index_name: String,
    /// Name of the table the index is created on.
    pub table_name: String,
    /// Names of the indexed columns.
    pub columns: Vec<String>,
}
90
91/// Parsed INSERT statement.
92#[derive(Debug, Clone)]
93pub struct ParsedInsert {
94    pub table: String,
95    pub columns: Vec<String>,
96    pub values: Vec<Vec<Value>>,        // Each Vec<Value> is one row
97    pub returning: Option<Vec<String>>, // Columns to return after insert
98}
99
100/// Parsed UPDATE statement.
101#[derive(Debug, Clone)]
102pub struct ParsedUpdate {
103    pub table: String,
104    pub assignments: Vec<(String, Value)>, // column = value pairs
105    pub predicates: Vec<Predicate>,
106    pub returning: Option<Vec<String>>, // Columns to return after update
107}
108
109/// Parsed DELETE statement.
110#[derive(Debug, Clone)]
111pub struct ParsedDelete {
112    pub table: String,
113    pub predicates: Vec<Predicate>,
114    pub returning: Option<Vec<String>>, // Columns to return after delete
115}
116
117/// Aggregate function in SELECT clause.
118#[derive(Debug, Clone, PartialEq, Eq)]
119pub enum AggregateFunction {
120    /// COUNT(*) - count all rows
121    CountStar,
122    /// COUNT(column) - count non-NULL values in column
123    Count(ColumnName),
124    /// SUM(column) - sum values in column
125    Sum(ColumnName),
126    /// AVG(column) - average values in column
127    Avg(ColumnName),
128    /// MIN(column) - minimum value in column
129    Min(ColumnName),
130    /// MAX(column) - maximum value in column
131    Max(ColumnName),
132}
133
134/// A comparison predicate from the WHERE clause.
135#[derive(Debug, Clone)]
136pub enum Predicate {
137    /// column = value or column = $N
138    Eq(ColumnName, PredicateValue),
139    /// column < value
140    Lt(ColumnName, PredicateValue),
141    /// column <= value
142    Le(ColumnName, PredicateValue),
143    /// column > value
144    Gt(ColumnName, PredicateValue),
145    /// column >= value
146    Ge(ColumnName, PredicateValue),
147    /// column IN (value, value, ...)
148    In(ColumnName, Vec<PredicateValue>),
149    /// column LIKE 'pattern'
150    Like(ColumnName, String),
151    /// column IS NULL
152    IsNull(ColumnName),
153    /// column IS NOT NULL
154    IsNotNull(ColumnName),
155    /// OR of multiple predicates
156    Or(Vec<Predicate>, Vec<Predicate>),
157}
158
159impl Predicate {
160    /// Returns the column name this predicate operates on.
161    ///
162    /// Returns None for OR predicates which may reference multiple columns.
163    #[allow(dead_code)]
164    pub fn column(&self) -> Option<&ColumnName> {
165        match self {
166            Predicate::Eq(col, _)
167            | Predicate::Lt(col, _)
168            | Predicate::Le(col, _)
169            | Predicate::Gt(col, _)
170            | Predicate::Ge(col, _)
171            | Predicate::In(col, _)
172            | Predicate::Like(col, _)
173            | Predicate::IsNull(col)
174            | Predicate::IsNotNull(col) => Some(col),
175            Predicate::Or(_, _) => None,
176        }
177    }
178}
179
180/// A value in a predicate (literal or parameter reference).
181#[derive(Debug, Clone)]
182pub enum PredicateValue {
183    /// Literal integer.
184    Int(i64),
185    /// Literal string.
186    String(String),
187    /// Literal boolean.
188    Bool(bool),
189    /// NULL literal.
190    Null,
191    /// Parameter placeholder ($1, $2, etc.) - 1-indexed.
192    Param(usize),
193    /// Literal value (for any type).
194    Literal(Value),
195}
196
197/// ORDER BY clause.
198#[derive(Debug, Clone)]
199pub struct OrderByClause {
200    /// Column to order by.
201    pub column: ColumnName,
202    /// Ascending (true) or descending (false).
203    pub ascending: bool,
204}
205
206// ============================================================================
207// Parser
208// ============================================================================
209
210/// Parses a SQL statement string into a `ParsedStatement`.
211pub fn parse_statement(sql: &str) -> Result<ParsedStatement> {
212    let dialect = GenericDialect {};
213    let statements =
214        Parser::parse_sql(&dialect, sql).map_err(|e| QueryError::ParseError(e.to_string()))?;
215
216    if statements.len() != 1 {
217        return Err(QueryError::ParseError(format!(
218            "expected exactly 1 statement, got {}",
219            statements.len()
220        )));
221    }
222
223    match &statements[0] {
224        Statement::Query(query) => {
225            let select = parse_select_query(query)?;
226            Ok(ParsedStatement::Select(select))
227        }
228        Statement::CreateTable(create_table) => {
229            let parsed = parse_create_table(create_table)?;
230            Ok(ParsedStatement::CreateTable(parsed))
231        }
232        Statement::Drop {
233            object_type,
234            names,
235            if_exists: _,
236            ..
237        } => {
238            if !matches!(object_type, sqlparser::ast::ObjectType::Table) {
239                return Err(QueryError::UnsupportedFeature(
240                    "only DROP TABLE is supported".to_string(),
241                ));
242            }
243            if names.len() != 1 {
244                return Err(QueryError::ParseError(
245                    "expected exactly 1 table in DROP TABLE".to_string(),
246                ));
247            }
248            let table_name = object_name_to_string(&names[0]);
249            Ok(ParsedStatement::DropTable(table_name))
250        }
251        Statement::CreateIndex(create_index) => {
252            let parsed = parse_create_index(create_index)?;
253            Ok(ParsedStatement::CreateIndex(parsed))
254        }
255        Statement::Insert(insert) => {
256            let parsed = parse_insert(insert)?;
257            Ok(ParsedStatement::Insert(parsed))
258        }
259        Statement::Update {
260            table,
261            assignments,
262            selection,
263            returning,
264            ..
265        } => {
266            let parsed = parse_update(table, assignments, selection.as_ref(), returning.as_ref())?;
267            Ok(ParsedStatement::Update(parsed))
268        }
269        Statement::Delete(delete) => {
270            let parsed = parse_delete_stmt(delete)?;
271            Ok(ParsedStatement::Delete(parsed))
272        }
273        other => Err(QueryError::UnsupportedFeature(format!(
274            "statement type not supported: {other:?}"
275        ))),
276    }
277}
278
279/// Legacy function for backward compatibility (queries only).
280pub fn parse_query(sql: &str) -> Result<ParsedSelect> {
281    match parse_statement(sql)? {
282        ParsedStatement::Select(select) => Ok(select),
283        _ => Err(QueryError::UnsupportedFeature(
284            "only SELECT queries are supported in parse_query()".to_string(),
285        )),
286    }
287}
288
289fn parse_select_query(query: &Query) -> Result<ParsedSelect> {
290    // Reject CTEs
291    if query.with.is_some() {
292        return Err(QueryError::UnsupportedFeature(
293            "WITH clauses (CTEs) are not supported".to_string(),
294        ));
295    }
296
297    let SetExpr::Select(select) = query.body.as_ref() else {
298        return Err(QueryError::UnsupportedFeature(
299            "only simple SELECT queries are supported".to_string(),
300        ));
301    };
302
303    let parsed_select = parse_select(select)?;
304
305    // Parse ORDER BY from query (not select)
306    let order_by = match &query.order_by {
307        Some(ob) => parse_order_by(ob)?,
308        None => vec![],
309    };
310
311    // Parse LIMIT from query
312    let limit = parse_limit(query.limit.as_ref())?;
313
314    Ok(ParsedSelect {
315        table: parsed_select.table,
316        columns: parsed_select.columns,
317        predicates: parsed_select.predicates,
318        order_by,
319        limit,
320        aggregates: parsed_select.aggregates,
321        group_by: parsed_select.group_by,
322        distinct: parsed_select.distinct,
323    })
324}
325
326fn parse_select(select: &Select) -> Result<ParsedSelect> {
327    // Parse DISTINCT flag
328    let distinct = select.distinct.is_some();
329
330    // Parse FROM - must be exactly one table
331    if select.from.len() != 1 {
332        return Err(QueryError::ParseError(format!(
333            "expected exactly 1 table in FROM clause, got {}",
334            select.from.len()
335        )));
336    }
337
338    let from = &select.from[0];
339
340    // Reject JOINs
341    if !from.joins.is_empty() {
342        return Err(QueryError::UnsupportedFeature(
343            "JOINs are not supported".to_string(),
344        ));
345    }
346
347    let table = match &from.relation {
348        sqlparser::ast::TableFactor::Table { name, .. } => object_name_to_string(name),
349        other => {
350            return Err(QueryError::UnsupportedFeature(format!(
351                "unsupported FROM clause: {other:?}"
352            )));
353        }
354    };
355
356    // Parse SELECT columns
357    let columns = parse_select_items(&select.projection)?;
358
359    // Parse WHERE predicates
360    let predicates = match &select.selection {
361        Some(expr) => parse_where_expr(expr)?,
362        None => vec![],
363    };
364
365    // Parse GROUP BY clause
366    let group_by = match &select.group_by {
367        sqlparser::ast::GroupByExpr::Expressions(exprs, _) if !exprs.is_empty() => {
368            parse_group_by_expr(exprs)?
369        }
370        sqlparser::ast::GroupByExpr::All(_) => {
371            return Err(QueryError::UnsupportedFeature(
372                "GROUP BY ALL is not supported".to_string(),
373            ));
374        }
375        sqlparser::ast::GroupByExpr::Expressions(_, _) => vec![],
376    };
377
378    // Parse aggregates from SELECT clause
379    let aggregates = parse_aggregates_from_select_items(&select.projection)?;
380
381    // Reject HAVING for now
382    if select.having.is_some() {
383        return Err(QueryError::UnsupportedFeature(
384            "HAVING is not supported yet".to_string(),
385        ));
386    }
387
388    Ok(ParsedSelect {
389        table,
390        columns,
391        predicates,
392        order_by: vec![],
393        limit: None,
394        aggregates,
395        group_by,
396        distinct,
397    })
398}
399
400fn parse_select_items(items: &[SelectItem]) -> Result<Option<Vec<ColumnName>>> {
401    let mut columns = Vec::new();
402
403    for item in items {
404        match item {
405            SelectItem::Wildcard(_) => {
406                // SELECT * - return None to indicate all columns
407                return Ok(None);
408            }
409            SelectItem::UnnamedExpr(Expr::Identifier(ident)) => {
410                columns.push(ColumnName::new(ident.value.clone()));
411            }
412            SelectItem::ExprWithAlias {
413                expr: Expr::Identifier(ident),
414                alias,
415            } => {
416                // For now, we ignore aliases and just use the column name
417                let _ = alias;
418                columns.push(ColumnName::new(ident.value.clone()));
419            }
420            SelectItem::UnnamedExpr(Expr::Function(_))
421            | SelectItem::ExprWithAlias {
422                expr: Expr::Function(_),
423                ..
424            } => {
425                // Aggregate functions are handled separately by parse_aggregates_from_select_items
426                // Skip them here
427            }
428            other => {
429                return Err(QueryError::UnsupportedFeature(format!(
430                    "unsupported SELECT item: {other:?}"
431                )));
432            }
433        }
434    }
435
436    Ok(Some(columns))
437}
438
439/// Parses aggregate functions from SELECT items.
440fn parse_aggregates_from_select_items(items: &[SelectItem]) -> Result<Vec<AggregateFunction>> {
441    let mut aggregates = Vec::new();
442
443    for item in items {
444        match item {
445            SelectItem::UnnamedExpr(expr) | SelectItem::ExprWithAlias { expr, .. } => {
446                if let Some(agg) = try_parse_aggregate(expr)? {
447                    aggregates.push(agg);
448                }
449            }
450            _ => {
451                // SELECT * has no aggregates; ignore other select items (Wildcard, QualifiedWildcard, etc.)
452            }
453        }
454    }
455
456    Ok(aggregates)
457}
458
459/// Tries to parse an expression as an aggregate function.
460/// Returns None if the expression is not an aggregate function.
461fn try_parse_aggregate(expr: &Expr) -> Result<Option<AggregateFunction>> {
462    match expr {
463        Expr::Function(func) => {
464            let func_name = func.name.to_string().to_uppercase();
465
466            // Extract function arguments from the FunctionArguments enum
467            let args = match &func.args {
468                sqlparser::ast::FunctionArguments::List(list) => &list.args,
469                _ => {
470                    return Err(QueryError::UnsupportedFeature(
471                        "non-list function arguments not supported".to_string(),
472                    ));
473                }
474            };
475
476            match func_name.as_str() {
477                "COUNT" => {
478                    // COUNT(*) or COUNT(column)
479                    if args.len() == 1 {
480                        match &args[0] {
481                            sqlparser::ast::FunctionArg::Unnamed(arg_expr) => match arg_expr {
482                                sqlparser::ast::FunctionArgExpr::Wildcard => {
483                                    Ok(Some(AggregateFunction::CountStar))
484                                }
485                                sqlparser::ast::FunctionArgExpr::Expr(Expr::Identifier(ident)) => {
486                                    Ok(Some(AggregateFunction::Count(ColumnName::new(
487                                        ident.value.clone(),
488                                    ))))
489                                }
490                                _ => Err(QueryError::UnsupportedFeature(
491                                    "COUNT with complex expression not supported".to_string(),
492                                )),
493                            },
494                            _ => Err(QueryError::UnsupportedFeature(
495                                "named function arguments not supported".to_string(),
496                            )),
497                        }
498                    } else {
499                        Err(QueryError::ParseError(format!(
500                            "COUNT expects 1 argument, got {}",
501                            args.len()
502                        )))
503                    }
504                }
505                "SUM" | "AVG" | "MIN" | "MAX" => {
506                    // SUM/AVG/MIN/MAX(column)
507                    if args.len() != 1 {
508                        return Err(QueryError::ParseError(format!(
509                            "{} expects 1 argument, got {}",
510                            func_name,
511                            args.len()
512                        )));
513                    }
514
515                    match &args[0] {
516                        sqlparser::ast::FunctionArg::Unnamed(arg_expr) => match arg_expr {
517                            sqlparser::ast::FunctionArgExpr::Expr(Expr::Identifier(ident)) => {
518                                let column = ColumnName::new(ident.value.clone());
519                                match func_name.as_str() {
520                                    "SUM" => Ok(Some(AggregateFunction::Sum(column))),
521                                    "AVG" => Ok(Some(AggregateFunction::Avg(column))),
522                                    "MIN" => Ok(Some(AggregateFunction::Min(column))),
523                                    "MAX" => Ok(Some(AggregateFunction::Max(column))),
524                                    _ => unreachable!(),
525                                }
526                            }
527                            _ => Err(QueryError::UnsupportedFeature(format!(
528                                "{func_name} with complex expression not supported"
529                            ))),
530                        },
531                        _ => Err(QueryError::UnsupportedFeature(
532                            "named function arguments not supported".to_string(),
533                        )),
534                    }
535                }
536                _ => {
537                    // Not an aggregate function
538                    Ok(None)
539                }
540            }
541        }
542        _ => {
543            // Not a function call
544            Ok(None)
545        }
546    }
547}
548
549/// Parses GROUP BY expressions into column names.
550fn parse_group_by_expr(exprs: &[Expr]) -> Result<Vec<ColumnName>> {
551    let mut columns = Vec::new();
552
553    for expr in exprs {
554        match expr {
555            Expr::Identifier(ident) => {
556                columns.push(ColumnName::new(ident.value.clone()));
557            }
558            _ => {
559                return Err(QueryError::UnsupportedFeature(
560                    "complex GROUP BY expressions not supported".to_string(),
561                ));
562            }
563        }
564    }
565
566    Ok(columns)
567}
568
569/// Maximum nesting depth for WHERE clause expressions.
570///
571/// Prevents stack overflow from deeply nested queries like:
572/// `WHERE ((((...(a = 1)...))))`
573///
574/// 100 levels is sufficient for all practical queries while preventing
575/// malicious or pathological input from exhausting the stack.
576const MAX_WHERE_DEPTH: usize = 100;
577
578fn parse_where_expr(expr: &Expr) -> Result<Vec<Predicate>> {
579    parse_where_expr_inner(expr, 0)
580}
581
582fn parse_where_expr_inner(expr: &Expr, depth: usize) -> Result<Vec<Predicate>> {
583    if depth >= MAX_WHERE_DEPTH {
584        return Err(QueryError::ParseError(format!(
585            "WHERE clause nesting exceeds maximum depth of {MAX_WHERE_DEPTH}"
586        )));
587    }
588
589    match expr {
590        // AND combines multiple predicates
591        Expr::BinaryOp {
592            left,
593            op: BinaryOperator::And,
594            right,
595        } => {
596            let mut predicates = parse_where_expr_inner(left, depth + 1)?;
597            predicates.extend(parse_where_expr_inner(right, depth + 1)?);
598            Ok(predicates)
599        }
600
601        // OR creates a disjunction
602        Expr::BinaryOp {
603            left,
604            op: BinaryOperator::Or,
605            right,
606        } => {
607            let left_preds = parse_where_expr_inner(left, depth + 1)?;
608            let right_preds = parse_where_expr_inner(right, depth + 1)?;
609            Ok(vec![Predicate::Or(left_preds, right_preds)])
610        }
611
612        // LIKE pattern matching
613        Expr::Like {
614            expr,
615            pattern,
616            negated,
617            ..
618        } => {
619            if *negated {
620                return Err(QueryError::UnsupportedFeature(
621                    "NOT LIKE is not supported".to_string(),
622                ));
623            }
624
625            let column = expr_to_column(expr)?;
626            let pattern_value = expr_to_predicate_value(pattern)?;
627
628            match pattern_value {
629                PredicateValue::String(pattern_str)
630                | PredicateValue::Literal(Value::Text(pattern_str)) => {
631                    Ok(vec![Predicate::Like(column, pattern_str)])
632                }
633                _ => Err(QueryError::UnsupportedFeature(
634                    "LIKE pattern must be a string literal".to_string(),
635                )),
636            }
637        }
638
639        // IS NULL / IS NOT NULL
640        Expr::IsNull(expr) => {
641            let column = expr_to_column(expr)?;
642            Ok(vec![Predicate::IsNull(column)])
643        }
644
645        Expr::IsNotNull(expr) => {
646            let column = expr_to_column(expr)?;
647            Ok(vec![Predicate::IsNotNull(column)])
648        }
649
650        // Comparison operators
651        Expr::BinaryOp { left, op, right } => {
652            let predicate = parse_comparison(left, op, right)?;
653            Ok(vec![predicate])
654        }
655
656        // IN list
657        Expr::InList {
658            expr,
659            list,
660            negated,
661        } => {
662            if *negated {
663                return Err(QueryError::UnsupportedFeature(
664                    "NOT IN is not supported".to_string(),
665                ));
666            }
667
668            let column = expr_to_column(expr)?;
669            let values: Result<Vec<_>> = list.iter().map(expr_to_predicate_value).collect();
670            Ok(vec![Predicate::In(column, values?)])
671        }
672
673        // Parenthesized expression
674        Expr::Nested(inner) => parse_where_expr_inner(inner, depth + 1),
675
676        other => Err(QueryError::UnsupportedFeature(format!(
677            "unsupported WHERE expression: {other:?}"
678        ))),
679    }
680}
681
682fn parse_comparison(left: &Expr, op: &BinaryOperator, right: &Expr) -> Result<Predicate> {
683    let column = expr_to_column(left)?;
684    let value = expr_to_predicate_value(right)?;
685
686    match op {
687        BinaryOperator::Eq => Ok(Predicate::Eq(column, value)),
688        BinaryOperator::Lt => Ok(Predicate::Lt(column, value)),
689        BinaryOperator::LtEq => Ok(Predicate::Le(column, value)),
690        BinaryOperator::Gt => Ok(Predicate::Gt(column, value)),
691        BinaryOperator::GtEq => Ok(Predicate::Ge(column, value)),
692        other => Err(QueryError::UnsupportedFeature(format!(
693            "unsupported operator: {other:?}"
694        ))),
695    }
696}
697
698fn expr_to_column(expr: &Expr) -> Result<ColumnName> {
699    match expr {
700        Expr::Identifier(ident) => Ok(ColumnName::new(ident.value.clone())),
701        Expr::CompoundIdentifier(idents) if idents.len() == 2 => {
702            // table.column - ignore table for now
703            Ok(ColumnName::new(idents[1].value.clone()))
704        }
705        other => Err(QueryError::UnsupportedFeature(format!(
706            "expected column name, got {other:?}"
707        ))),
708    }
709}
710
711fn expr_to_predicate_value(expr: &Expr) -> Result<PredicateValue> {
712    match expr {
713        Expr::Value(SqlValue::Number(n, _)) => {
714            let value = parse_number_literal(n)?;
715            match value {
716                Value::BigInt(v) => Ok(PredicateValue::Int(v)),
717                Value::Decimal(_, _) => Ok(PredicateValue::Literal(value)),
718                _ => unreachable!("parse_number_literal only returns BigInt or Decimal"),
719            }
720        }
721        Expr::Value(SqlValue::SingleQuotedString(s) | SqlValue::DoubleQuotedString(s)) => {
722            Ok(PredicateValue::String(s.clone()))
723        }
724        Expr::Value(SqlValue::Boolean(b)) => Ok(PredicateValue::Bool(*b)),
725        Expr::Value(SqlValue::Null) => Ok(PredicateValue::Null),
726        Expr::Value(SqlValue::Placeholder(p)) => {
727            // Parse $1, $2, etc.
728            if let Some(num_str) = p.strip_prefix('$') {
729                let idx: usize = num_str.parse().map_err(|_| {
730                    QueryError::ParseError(format!("invalid parameter placeholder: {p}"))
731                })?;
732                // SQL parameters are 1-indexed, reject $0
733                if idx == 0 {
734                    return Err(QueryError::ParseError(
735                        "parameter indices start at $1, not $0".to_string(),
736                    ));
737                }
738                Ok(PredicateValue::Param(idx))
739            } else {
740                Err(QueryError::ParseError(format!(
741                    "unsupported placeholder format: {p}"
742                )))
743            }
744        }
745        Expr::UnaryOp {
746            op: sqlparser::ast::UnaryOperator::Minus,
747            expr,
748        } => {
749            // Handle negative numbers
750            if let Expr::Value(SqlValue::Number(n, _)) = expr.as_ref() {
751                let value = parse_number_literal(n)?;
752                match value {
753                    Value::BigInt(v) => Ok(PredicateValue::Int(-v)),
754                    Value::Decimal(v, scale) => {
755                        Ok(PredicateValue::Literal(Value::Decimal(-v, scale)))
756                    }
757                    _ => unreachable!("parse_number_literal only returns BigInt or Decimal"),
758                }
759            } else {
760                Err(QueryError::UnsupportedFeature(format!(
761                    "unsupported unary minus operand: {expr:?}"
762                )))
763            }
764        }
765        other => Err(QueryError::UnsupportedFeature(format!(
766            "unsupported value expression: {other:?}"
767        ))),
768    }
769}
770
771fn parse_order_by(order_by: &sqlparser::ast::OrderBy) -> Result<Vec<OrderByClause>> {
772    let mut clauses = Vec::new();
773
774    for expr in &order_by.exprs {
775        clauses.push(parse_order_by_expr(expr)?);
776    }
777
778    Ok(clauses)
779}
780
781fn parse_order_by_expr(expr: &OrderByExpr) -> Result<OrderByClause> {
782    let column = match &expr.expr {
783        Expr::Identifier(ident) => ColumnName::new(ident.value.clone()),
784        other => {
785            return Err(QueryError::UnsupportedFeature(format!(
786                "unsupported ORDER BY expression: {other:?}"
787            )));
788        }
789    };
790
791    let ascending = expr.asc.unwrap_or(true);
792
793    Ok(OrderByClause { column, ascending })
794}
795
796fn parse_limit(limit: Option<&Expr>) -> Result<Option<usize>> {
797    match limit {
798        None => Ok(None),
799        Some(Expr::Value(SqlValue::Number(n, _))) => {
800            let v: usize = n
801                .parse()
802                .map_err(|_| QueryError::ParseError(format!("invalid LIMIT value: {n}")))?;
803            Ok(Some(v))
804        }
805        Some(other) => Err(QueryError::UnsupportedFeature(format!(
806            "unsupported LIMIT expression: {other:?}"
807        ))),
808    }
809}
810
811fn object_name_to_string(name: &ObjectName) -> String {
812    name.0
813        .iter()
814        .map(|i: &Ident| i.value.clone())
815        .collect::<Vec<_>>()
816        .join(".")
817}
818
819// ============================================================================
820// DDL Parsers
821// ============================================================================
822
823fn parse_create_table(create_table: &sqlparser::ast::CreateTable) -> Result<ParsedCreateTable> {
824    let table_name = object_name_to_string(&create_table.name);
825
826    // Extract column definitions
827    let mut columns = Vec::new();
828    for col_def in &create_table.columns {
829        let parsed_col = parse_column_def(col_def)?;
830        columns.push(parsed_col);
831    }
832
833    // Extract primary key from constraints
834    let mut primary_key = Vec::new();
835    for constraint in &create_table.constraints {
836        if let sqlparser::ast::TableConstraint::PrimaryKey {
837            columns: pk_cols, ..
838        } = constraint
839        {
840            for col in pk_cols {
841                primary_key.push(col.value.clone());
842            }
843        }
844    }
845
846    // If no explicit PRIMARY KEY constraint, check for PRIMARY KEY in column definitions
847    if primary_key.is_empty() {
848        for col_def in &create_table.columns {
849            for option in &col_def.options {
850                if matches!(
851                    &option.option,
852                    sqlparser::ast::ColumnOption::Unique { is_primary, .. } if *is_primary
853                ) {
854                    primary_key.push(col_def.name.value.clone());
855                }
856            }
857        }
858    }
859
860    Ok(ParsedCreateTable {
861        table_name,
862        columns,
863        primary_key,
864    })
865}
866
867fn parse_column_def(col_def: &SqlColumnDef) -> Result<ParsedColumn> {
868    let name = col_def.name.value.clone();
869
870    // Map SQL data type to string
871    // For DECIMAL, we need to handle precision/scale specially
872    let data_type = match &col_def.data_type {
873        // Integer types
874        SqlDataType::TinyInt(_) => "TINYINT".to_string(),
875        SqlDataType::SmallInt(_) => "SMALLINT".to_string(),
876        SqlDataType::Int(_) | SqlDataType::Integer(_) => "INTEGER".to_string(),
877        SqlDataType::BigInt(_) => "BIGINT".to_string(),
878
879        // Numeric types
880        SqlDataType::Real | SqlDataType::Float(_) | SqlDataType::Double(_) => "REAL".to_string(),
881        SqlDataType::Decimal(precision_opt) => match precision_opt {
882            sqlparser::ast::ExactNumberInfo::PrecisionAndScale(p, s) => {
883                format!("DECIMAL({p},{s})")
884            }
885            sqlparser::ast::ExactNumberInfo::Precision(p) => {
886                format!("DECIMAL({p},0)")
887            }
888            sqlparser::ast::ExactNumberInfo::None => "DECIMAL(18,2)".to_string(),
889        },
890
891        // String types
892        SqlDataType::Text | SqlDataType::Varchar(_) | SqlDataType::String(_) => "TEXT".to_string(),
893
894        // Binary types
895        SqlDataType::Binary(_) | SqlDataType::Varbinary(_) | SqlDataType::Blob(_) => {
896            "BYTES".to_string()
897        }
898
899        // Boolean type
900        SqlDataType::Boolean | SqlDataType::Bool => "BOOLEAN".to_string(),
901
902        // Date/Time types
903        SqlDataType::Date => "DATE".to_string(),
904        SqlDataType::Time(_, _) => "TIME".to_string(),
905        SqlDataType::Timestamp(_, _) => "TIMESTAMP".to_string(),
906
907        // Structured types
908        SqlDataType::Uuid => "UUID".to_string(),
909        SqlDataType::JSON => "JSON".to_string(),
910
911        other => {
912            return Err(QueryError::UnsupportedFeature(format!(
913                "unsupported data type: {other:?}"
914            )));
915        }
916    };
917
918    // Check for NOT NULL constraint
919    let mut nullable = true;
920    for option in &col_def.options {
921        if matches!(option.option, sqlparser::ast::ColumnOption::NotNull) {
922            nullable = false;
923        }
924    }
925
926    Ok(ParsedColumn {
927        name,
928        data_type,
929        nullable,
930    })
931}
932
933fn parse_create_index(create_index: &sqlparser::ast::CreateIndex) -> Result<ParsedCreateIndex> {
934    let index_name = match &create_index.name {
935        Some(name) => object_name_to_string(name),
936        None => {
937            return Err(QueryError::ParseError(
938                "CREATE INDEX requires an index name".to_string(),
939            ));
940        }
941    };
942
943    let table_name = object_name_to_string(&create_index.table_name);
944
945    let mut columns = Vec::new();
946    for col in &create_index.columns {
947        columns.push(col.expr.to_string());
948    }
949
950    Ok(ParsedCreateIndex {
951        index_name,
952        table_name,
953        columns,
954    })
955}
956
957// ============================================================================
958// DML Parsers
959// ============================================================================
960
961fn parse_insert(insert: &sqlparser::ast::Insert) -> Result<ParsedInsert> {
962    // TableObject might be ObjectName directly - convert to string
963    let table = insert.table.to_string();
964
965    // Extract column names
966    let columns: Vec<String> = insert.columns.iter().map(|c| c.value.clone()).collect();
967
968    // Extract values from all rows
969    let values = match insert.source.as_ref().map(|s| s.body.as_ref()) {
970        Some(SetExpr::Values(values)) => {
971            let mut all_rows = Vec::new();
972            for row in &values.rows {
973                let mut parsed_row = Vec::new();
974                for expr in row {
975                    let val = expr_to_value(expr)?;
976                    parsed_row.push(val);
977                }
978                all_rows.push(parsed_row);
979            }
980            all_rows
981        }
982        _ => {
983            return Err(QueryError::UnsupportedFeature(
984                "only VALUES clause is supported in INSERT".to_string(),
985            ));
986        }
987    };
988
989    // Parse RETURNING clause
990    let returning = parse_returning(insert.returning.as_ref())?;
991
992    Ok(ParsedInsert {
993        table,
994        columns,
995        values,
996        returning,
997    })
998}
999
1000fn parse_update(
1001    table: &sqlparser::ast::TableWithJoins,
1002    assignments: &[sqlparser::ast::Assignment],
1003    selection: Option<&Expr>,
1004    returning: Option<&Vec<SelectItem>>,
1005) -> Result<ParsedUpdate> {
1006    let table_name = match &table.relation {
1007        sqlparser::ast::TableFactor::Table { name, .. } => object_name_to_string(name),
1008        other => {
1009            return Err(QueryError::UnsupportedFeature(format!(
1010                "unsupported table in UPDATE: {other:?}"
1011            )));
1012        }
1013    };
1014
1015    // Parse assignments (SET clauses)
1016    let mut parsed_assignments = Vec::new();
1017    for assignment in assignments {
1018        let col_name = assignment.target.to_string();
1019        let value = expr_to_value(&assignment.value)?;
1020        parsed_assignments.push((col_name, value));
1021    }
1022
1023    // Parse WHERE clause
1024    let predicates = match selection {
1025        Some(expr) => parse_where_expr(expr)?,
1026        None => vec![],
1027    };
1028
1029    // Parse RETURNING clause
1030    let returning_cols = parse_returning(returning)?;
1031
1032    Ok(ParsedUpdate {
1033        table: table_name,
1034        assignments: parsed_assignments,
1035        predicates,
1036        returning: returning_cols,
1037    })
1038}
1039
1040fn parse_delete_stmt(delete: &sqlparser::ast::Delete) -> Result<ParsedDelete> {
1041    // In sqlparser 0.54, DELETE FROM uses a single `from` table
1042    use sqlparser::ast::FromTable;
1043
1044    let table_name = match &delete.from {
1045        FromTable::WithFromKeyword(tables) => {
1046            if tables.len() != 1 {
1047                return Err(QueryError::ParseError(
1048                    "expected exactly 1 table in DELETE FROM".to_string(),
1049                ));
1050            }
1051
1052            match &tables[0].relation {
1053                sqlparser::ast::TableFactor::Table { name, .. } => object_name_to_string(name),
1054                _ => {
1055                    return Err(QueryError::ParseError(
1056                        "DELETE only supports simple table names".to_string(),
1057                    ));
1058                }
1059            }
1060        }
1061        FromTable::WithoutKeyword(tables) => {
1062            if tables.len() != 1 {
1063                return Err(QueryError::ParseError(
1064                    "expected exactly 1 table in DELETE".to_string(),
1065                ));
1066            }
1067
1068            match &tables[0].relation {
1069                sqlparser::ast::TableFactor::Table { name, .. } => object_name_to_string(name),
1070                _ => {
1071                    return Err(QueryError::ParseError(
1072                        "DELETE only supports simple table names".to_string(),
1073                    ));
1074                }
1075            }
1076        }
1077    };
1078
1079    // Parse WHERE clause
1080    let predicates = match &delete.selection {
1081        Some(expr) => parse_where_expr(expr)?,
1082        None => vec![],
1083    };
1084
1085    // Parse RETURNING clause
1086    let returning_cols = parse_returning(delete.returning.as_ref())?;
1087
1088    Ok(ParsedDelete {
1089        table: table_name,
1090        predicates,
1091        returning: returning_cols,
1092    })
1093}
1094
1095/// Parses a RETURNING clause into a list of column names.
1096fn parse_returning(returning: Option<&Vec<SelectItem>>) -> Result<Option<Vec<String>>> {
1097    match returning {
1098        None => Ok(None),
1099        Some(items) => {
1100            let mut columns = Vec::new();
1101            for item in items {
1102                match item {
1103                    SelectItem::UnnamedExpr(Expr::Identifier(ident)) => {
1104                        columns.push(ident.value.clone());
1105                    }
1106                    SelectItem::UnnamedExpr(Expr::CompoundIdentifier(parts)) => {
1107                        // Handle table.column format - just take the column name
1108                        if let Some(last) = parts.last() {
1109                            columns.push(last.value.clone());
1110                        } else {
1111                            return Err(QueryError::ParseError(
1112                                "invalid column in RETURNING clause".to_string(),
1113                            ));
1114                        }
1115                    }
1116                    _ => {
1117                        return Err(QueryError::UnsupportedFeature(
1118                            "only simple column names supported in RETURNING clause".to_string(),
1119                        ));
1120                    }
1121                }
1122            }
1123            Ok(Some(columns))
1124        }
1125    }
1126}
1127
1128/// Parses a number literal as either an integer or decimal.
1129///
1130/// Uses `rust_decimal` for robust decimal parsing (handles all edge cases correctly).
1131fn parse_number_literal(n: &str) -> Result<Value> {
1132    use rust_decimal::Decimal;
1133    use std::str::FromStr;
1134
1135    if n.contains('.') {
1136        // Parse as DECIMAL using rust_decimal for correct handling
1137        let decimal = Decimal::from_str(n)
1138            .map_err(|e| QueryError::ParseError(format!("invalid decimal '{n}': {e}")))?;
1139
1140        // Get scale (number of decimal places)
1141        let scale = decimal.scale() as u8;
1142
1143        if scale > 38 {
1144            return Err(QueryError::ParseError(format!(
1145                "decimal scale too large (max 38): {n}"
1146            )));
1147        }
1148
1149        // Convert to i128 representation: mantissa * 10^scale
1150        // rust_decimal stores internally as i128 mantissa with scale
1151        let mantissa = decimal.mantissa();
1152
1153        Ok(Value::Decimal(mantissa, scale))
1154    } else {
1155        // Parse as integer (BigInt)
1156        let v: i64 = n
1157            .parse()
1158            .map_err(|_| QueryError::ParseError(format!("invalid integer: {n}")))?;
1159        Ok(Value::BigInt(v))
1160    }
1161}
1162
1163/// Converts a SQL expression to a Value.
1164fn expr_to_value(expr: &Expr) -> Result<Value> {
1165    match expr {
1166        Expr::Value(SqlValue::Number(n, _)) => parse_number_literal(n),
1167        Expr::Value(SqlValue::SingleQuotedString(s) | SqlValue::DoubleQuotedString(s)) => {
1168            Ok(Value::Text(s.clone()))
1169        }
1170        Expr::Value(SqlValue::Boolean(b)) => Ok(Value::Boolean(*b)),
1171        Expr::Value(SqlValue::Null) => Ok(Value::Null),
1172        Expr::Value(SqlValue::Placeholder(p)) => {
1173            // Parse $1, $2, etc.
1174            if let Some(num_str) = p.strip_prefix('$') {
1175                let idx: usize = num_str.parse().map_err(|_| {
1176                    QueryError::ParseError(format!("invalid parameter placeholder: {p}"))
1177                })?;
1178                // SQL parameters are 1-indexed, reject $0
1179                if idx == 0 {
1180                    return Err(QueryError::ParseError(
1181                        "parameter indices start at $1, not $0".to_string(),
1182                    ));
1183                }
1184                Ok(Value::Placeholder(idx))
1185            } else {
1186                Err(QueryError::ParseError(format!(
1187                    "unsupported placeholder format: {p}"
1188                )))
1189            }
1190        }
1191        Expr::UnaryOp {
1192            op: sqlparser::ast::UnaryOperator::Minus,
1193            expr,
1194        } => {
1195            // Handle negative numbers
1196            if let Expr::Value(SqlValue::Number(n, _)) = expr.as_ref() {
1197                let value = parse_number_literal(n)?;
1198                match value {
1199                    Value::BigInt(v) => Ok(Value::BigInt(-v)),
1200                    Value::Decimal(v, scale) => Ok(Value::Decimal(-v, scale)),
1201                    _ => unreachable!("parse_number_literal only returns BigInt or Decimal"),
1202                }
1203            } else {
1204                Err(QueryError::UnsupportedFeature(format!(
1205                    "unsupported unary minus operand: {expr:?}"
1206                )))
1207            }
1208        }
1209        other => Err(QueryError::UnsupportedFeature(format!(
1210            "unsupported value expression: {other:?}"
1211        ))),
1212    }
1213}
1214
#[cfg(test)]
mod tests {
    use super::*;

    // ---- Projection -------------------------------------------------------

    #[test]
    fn test_parse_simple_select() {
        let parsed = parse_query("SELECT id, name FROM users").unwrap();
        let expected_columns = vec![ColumnName::new("id"), ColumnName::new("name")];

        assert_eq!(parsed.table, "users");
        assert_eq!(parsed.columns, Some(expected_columns));
        assert!(parsed.predicates.is_empty());
    }

    #[test]
    fn test_parse_select_star() {
        let parsed = parse_query("SELECT * FROM users").unwrap();

        assert_eq!(parsed.table, "users");
        // `SELECT *` is represented by the absence of an explicit column list.
        assert!(parsed.columns.is_none());
    }

    // ---- WHERE predicates -------------------------------------------------

    #[test]
    fn test_parse_where_eq() {
        let parsed = parse_query("SELECT * FROM users WHERE id = 42").unwrap();
        assert_eq!(parsed.predicates.len(), 1);

        let first = &parsed.predicates[0];
        match first {
            Predicate::Eq(column, PredicateValue::Int(42)) => assert_eq!(column.as_str(), "id"),
            other => panic!("unexpected predicate: {other:?}"),
        }
    }

    #[test]
    fn test_parse_where_string() {
        let parsed = parse_query("SELECT * FROM users WHERE name = 'alice'").unwrap();

        let first = &parsed.predicates[0];
        match first {
            Predicate::Eq(column, PredicateValue::String(text)) => {
                assert_eq!(column.as_str(), "name");
                assert_eq!(text, "alice");
            }
            other => panic!("unexpected predicate: {other:?}"),
        }
    }

    #[test]
    fn test_parse_where_and() {
        let parsed = parse_query("SELECT * FROM users WHERE id = 1 AND name = 'bob'").unwrap();

        // Each side of the AND contributes one predicate.
        assert_eq!(parsed.predicates.len(), 2);
    }

    #[test]
    fn test_parse_where_in() {
        let parsed = parse_query("SELECT * FROM users WHERE id IN (1, 2, 3)").unwrap();

        let first = &parsed.predicates[0];
        match first {
            Predicate::In(column, candidates) => {
                assert_eq!(column.as_str(), "id");
                assert_eq!(candidates.len(), 3);
            }
            other => panic!("unexpected predicate: {other:?}"),
        }
    }

    // ---- ORDER BY / LIMIT / parameters -----------------------------------

    #[test]
    fn test_parse_order_by() {
        let parsed = parse_query("SELECT * FROM users ORDER BY name ASC, id DESC").unwrap();
        assert_eq!(parsed.order_by.len(), 2);

        let (primary, secondary) = (&parsed.order_by[0], &parsed.order_by[1]);
        assert_eq!(primary.column.as_str(), "name");
        assert!(primary.ascending);
        assert_eq!(secondary.column.as_str(), "id");
        assert!(!secondary.ascending);
    }

    #[test]
    fn test_parse_limit() {
        let parsed = parse_query("SELECT * FROM users LIMIT 10").unwrap();
        assert_eq!(parsed.limit, Some(10));
    }

    #[test]
    fn test_parse_param() {
        let parsed = parse_query("SELECT * FROM users WHERE id = $1").unwrap();

        let first = &parsed.predicates[0];
        match first {
            Predicate::Eq(_, PredicateValue::Param(1)) => {}
            other => panic!("unexpected predicate: {other:?}"),
        }
    }

    // ---- Rejected constructs ----------------------------------------------

    #[test]
    fn test_reject_join() {
        let outcome = parse_query("SELECT * FROM users JOIN orders ON users.id = orders.user_id");
        assert!(outcome.is_err());
    }

    #[test]
    fn test_reject_subquery() {
        let outcome = parse_query("SELECT * FROM (SELECT * FROM users)");
        assert!(outcome.is_err());
    }

    // ---- WHERE nesting depth ----------------------------------------------

    #[test]
    fn test_where_depth_within_limit() {
        // Ten parenthesised equality checks chained with AND: a moderate
        // nesting depth that is expected to parse successfully.
        let conditions: Vec<String> = (0..10).map(|i| format!("(id = {i})")).collect();
        let sql = format!("SELECT * FROM users WHERE {}", conditions.join(" AND "));

        let result = parse_query(&sql);
        assert!(
            result.is_ok(),
            "Moderate nesting should succeed, but got: {result:?}"
        );
    }

    #[test]
    fn test_where_depth_nested_parens() {
        // 200 levels of redundant parentheses around a single comparison.
        // The test only requires that some depth limit rejects the query.
        let sql = format!(
            "SELECT * FROM users WHERE {}id = 1{}",
            "(".repeat(200),
            ")".repeat(200)
        );

        let result = parse_query(&sql);
        assert!(
            result.is_err(),
            "Excessive parenthesis nesting should be rejected"
        );
    }

    #[test]
    fn test_where_depth_complex_and_or() {
        // Two OR-groups of AND-pairs, joined by a top-level AND.
        let sql = concat!(
            "SELECT * FROM users WHERE ",
            "((id = 1 AND name = 'a') OR (id = 2 AND name = 'b')) AND ",
            "((age > 10 AND age < 20) OR (age > 30 AND age < 40))",
        );

        let result = parse_query(sql);
        assert!(result.is_ok(), "Complex AND/OR should succeed");
    }
}