Skip to main content

mdql_core/
query_parser.rs

1//! Hand-written recursive descent parser for the MDQL SQL subset.
2
3use regex::Regex;
4use std::sync::LazyLock;
5
6use crate::errors::MdqlError;
7
8// ── AST nodes ──────────────────────────────────────────────────────────────
9
/// Binary arithmetic operators usable in SELECT and WHERE expressions.
#[derive(Clone, Debug, PartialEq)]
pub enum ArithOp {
    /// `+`
    Add,
    /// `-`
    Sub,
    /// `*`
    Mul,
    /// `/`
    Div,
    /// `%`
    Mod,
}
18
19#[derive(Debug, Clone, PartialEq)]
20pub enum Expr {
21    Literal(SqlValue),
22    Column(String),
23    BinaryOp { left: Box<Expr>, op: ArithOp, right: Box<Expr> },
24    UnaryMinus(Box<Expr>),
25    Case { whens: Vec<(WhereClause, Box<Expr>)>, else_expr: Option<Box<Expr>> },
26}
27
28impl Expr {
29    /// If the expression is a simple column reference, return the name.
30    pub fn as_column(&self) -> Option<&str> {
31        if let Expr::Column(name) = self { Some(name) } else { None }
32    }
33
34    /// A display name for this expression (used as output column name).
35    pub fn display_name(&self) -> String {
36        match self {
37            Expr::Literal(SqlValue::Int(n)) => n.to_string(),
38            Expr::Literal(SqlValue::Float(f)) => f.to_string(),
39            Expr::Literal(SqlValue::String(s)) => format!("'{}'", s),
40            Expr::Literal(SqlValue::Null) => "NULL".to_string(),
41            Expr::Literal(SqlValue::List(_)) => "list".to_string(),
42            Expr::Column(name) => name.clone(),
43            Expr::BinaryOp { left, op, right } => {
44                let op_str = match op {
45                    ArithOp::Add => "+",
46                    ArithOp::Sub => "-",
47                    ArithOp::Mul => "*",
48                    ArithOp::Div => "/",
49                    ArithOp::Mod => "%",
50                };
51                format!("{} {} {}", left.display_name(), op_str, right.display_name())
52            }
53            Expr::UnaryMinus(inner) => format!("-{}", inner.display_name()),
54            Expr::Case { .. } => "CASE".to_string(),
55        }
56    }
57}
58
59#[derive(Debug, Clone, PartialEq)]
60pub struct OrderSpec {
61    pub column: String,
62    pub expr: Option<Expr>,
63    pub descending: bool,
64}
65
66#[derive(Debug, Clone, PartialEq)]
67pub struct Comparison {
68    pub column: String,
69    pub op: String,
70    pub value: Option<SqlValue>,
71    pub left_expr: Option<Expr>,
72    pub right_expr: Option<Expr>,
73}
74
75#[derive(Debug, Clone, PartialEq)]
76pub struct BoolOp {
77    pub op: String, // "AND" or "OR"
78    pub left: Box<WhereClause>,
79    pub right: Box<WhereClause>,
80}
81
82#[derive(Debug, Clone, PartialEq)]
83pub enum WhereClause {
84    Comparison(Comparison),
85    BoolOp(BoolOp),
86}
87
/// A constant value appearing in a query.
#[derive(Clone, Debug, PartialEq)]
pub enum SqlValue {
    /// A quoted string literal, with the quotes removed.
    String(String),
    /// An integer literal.
    Int(i64),
    /// A decimal literal.
    Float(f64),
    /// `NULL`.
    Null,
    /// A parenthesised list of values, as used by `IN (...)`.
    List(Vec<SqlValue>),
}
96
/// One `JOIN table [alias] ON left_col = right_col` clause.
#[derive(Clone, Debug, PartialEq)]
pub struct JoinClause {
    /// Joined table name.
    pub table: String,
    /// Optional alias for the joined table.
    pub alias: Option<String>,
    /// Identifier on the left of `=` in the ON condition.
    pub left_col: String,
    /// Identifier on the right of `=` in the ON condition.
    pub right_col: String,
}
104
/// Aggregate functions recognised in a SELECT list.
#[derive(Clone, Debug, PartialEq)]
pub enum AggFunc {
    /// `COUNT(...)`
    Count,
    /// `SUM(...)`
    Sum,
    /// `AVG(...)`
    Avg,
    /// `MIN(...)`
    Min,
    /// `MAX(...)`
    Max,
}
113
114#[derive(Debug, Clone, PartialEq)]
115pub enum SelectExpr {
116    Column(String),
117    Aggregate { func: AggFunc, arg: String, arg_expr: Option<Expr>, alias: Option<String> },
118    Expr { expr: Expr, alias: Option<String> },
119}
120
121impl SelectExpr {
122    pub fn output_name(&self) -> String {
123        match self {
124            SelectExpr::Column(name) => name.clone(),
125            SelectExpr::Aggregate { func, arg, alias, .. } => {
126                if let Some(a) = alias {
127                    a.clone()
128                } else {
129                    let func_name = match func {
130                        AggFunc::Count => "COUNT",
131                        AggFunc::Sum => "SUM",
132                        AggFunc::Avg => "AVG",
133                        AggFunc::Min => "MIN",
134                        AggFunc::Max => "MAX",
135                    };
136                    format!("{}({})", func_name, arg)
137                }
138            }
139            SelectExpr::Expr { expr, alias } => {
140                alias.clone().unwrap_or_else(|| expr.display_name())
141            }
142        }
143    }
144
145    pub fn is_aggregate(&self) -> bool {
146        matches!(self, SelectExpr::Aggregate { .. })
147    }
148}
149
150#[derive(Debug, Clone, PartialEq)]
151pub struct SelectQuery {
152    pub columns: ColumnList,
153    pub table: String,
154    pub table_alias: Option<String>,
155    pub joins: Vec<JoinClause>,
156    pub where_clause: Option<WhereClause>,
157    pub group_by: Option<Vec<String>>,
158    pub order_by: Option<Vec<OrderSpec>>,
159    pub limit: Option<i64>,
160}
161
162#[derive(Debug, Clone, PartialEq)]
163pub enum ColumnList {
164    All,
165    Named(Vec<SelectExpr>),
166}
167
168#[derive(Debug, Clone, PartialEq)]
169pub struct InsertQuery {
170    pub table: String,
171    pub columns: Vec<String>,
172    pub values: Vec<SqlValue>,
173}
174
175#[derive(Debug, Clone, PartialEq)]
176pub struct UpdateQuery {
177    pub table: String,
178    pub assignments: Vec<(String, SqlValue)>,
179    pub where_clause: Option<WhereClause>,
180}
181
182#[derive(Debug, Clone, PartialEq)]
183pub struct DeleteQuery {
184    pub table: String,
185    pub where_clause: Option<WhereClause>,
186}
187
/// `ALTER TABLE table RENAME FIELD old_name TO new_name`.
#[derive(Clone, Debug, PartialEq)]
pub struct AlterRenameFieldQuery {
    pub table: String,
    pub old_name: String,
    pub new_name: String,
}

/// `ALTER TABLE table DROP FIELD field_name`.
#[derive(Clone, Debug, PartialEq)]
pub struct AlterDropFieldQuery {
    pub table: String,
    pub field_name: String,
}

/// `ALTER TABLE table MERGE FIELDS a, b, ... INTO into`.
#[derive(Clone, Debug, PartialEq)]
pub struct AlterMergeFieldsQuery {
    pub table: String,
    /// Source field names in source order.
    pub sources: Vec<String>,
    /// Target field name.
    pub into: String,
}
207
208#[derive(Debug, Clone, PartialEq)]
209pub enum Statement {
210    Select(SelectQuery),
211    Insert(InsertQuery),
212    Update(UpdateQuery),
213    Delete(DeleteQuery),
214    AlterRename(AlterRenameFieldQuery),
215    AlterDrop(AlterDropFieldQuery),
216    AlterMerge(AlterMergeFieldsQuery),
217}
218
219// ── Tokenizer ──────────────────────────────────────────────────────────────
220
/// Reserved words. The tokenizer upper-cases each bare word before looking it
/// up here; words not in this list become identifiers.
static KEYWORDS: &[&str] = &[
    // SELECT clauses and boolean connectives
    "SELECT", "FROM", "WHERE", "AND", "OR",
    // ordering and limiting
    "ORDER", "BY", "ASC", "DESC", "LIMIT",
    // predicates
    "LIKE", "IN", "IS", "NOT", "NULL",
    // joins, aliases, grouping
    "JOIN", "ON", "AS", "GROUP",
    // data modification
    "INSERT", "INTO", "VALUES", "UPDATE", "SET", "DELETE",
    // schema changes
    "ALTER", "TABLE", "RENAME", "FIELD", "TO", "DROP", "MERGE", "FIELDS",
    // CASE expressions
    "CASE", "WHEN", "THEN", "ELSE", "END",
];
229
/// Aggregate function names, upper-case. A word is treated as an aggregate
/// call only when it is immediately followed by `(`.
static AGG_FUNCS: &[&str] = &[
    "COUNT",
    "SUM",
    "AVG",
    "MIN",
    "MAX",
];
231
232static TOKEN_RE: LazyLock<Regex> = LazyLock::new(|| {
233    Regex::new(
234        r#"(?x)
235        \s*(?:
236            (?P<backtick>`[^`]+`)
237            | (?P<string>'(?:[^'\\]|\\.)*')
238            | (?P<number>\d+(?:\.\d+)?)
239            | (?P<op><=|>=|!=|[=<>,*()+\-/%])
240            | (?P<word>[A-Za-z_][A-Za-z0-9_./-]*)
241        )"#,
242    )
243    .unwrap()
244});
245
/// A single lexical token.
#[derive(Clone, Debug)]
struct Token {
    /// One of `"ident"`, `"keyword"`, `"string"`, `"number"`, `"op"`.
    token_type: String,
    /// Normalised text: upper-cased for keywords, quotes stripped for strings
    /// and backtick identifiers.
    value: String,
    /// The exact source text of the token.
    raw: String,
}
252
253fn tokenize(sql: &str) -> Vec<Token> {
254    let mut tokens = Vec::new();
255    for caps in TOKEN_RE.captures_iter(sql) {
256        if let Some(m) = caps.name("backtick") {
257            let raw = m.as_str();
258            tokens.push(Token {
259                token_type: "ident".into(),
260                value: raw[1..raw.len() - 1].into(),
261                raw: raw.into(),
262            });
263        } else if let Some(m) = caps.name("string") {
264            let raw = m.as_str();
265            tokens.push(Token {
266                token_type: "string".into(),
267                value: raw[1..raw.len() - 1].into(),
268                raw: raw.into(),
269            });
270        } else if let Some(m) = caps.name("number") {
271            let raw = m.as_str();
272            tokens.push(Token {
273                token_type: "number".into(),
274                value: raw.into(),
275                raw: raw.into(),
276            });
277        } else if let Some(m) = caps.name("op") {
278            let raw = m.as_str();
279            tokens.push(Token {
280                token_type: "op".into(),
281                value: raw.into(),
282                raw: raw.into(),
283            });
284        } else if let Some(m) = caps.name("word") {
285            let raw = m.as_str();
286            if KEYWORDS.contains(&raw.to_uppercase().as_str()) {
287                tokens.push(Token {
288                    token_type: "keyword".into(),
289                    value: raw.to_uppercase(),
290                    raw: raw.into(),
291                });
292            } else {
293                tokens.push(Token {
294                    token_type: "ident".into(),
295                    value: raw.into(),
296                    raw: raw.into(),
297                });
298            }
299        }
300    }
301    tokens
302}
303
304// ── Parser ─────────────────────────────────────────────────────────────────
305
306struct Parser {
307    tokens: Vec<Token>,
308    pos: usize,
309}
310
311impl Parser {
312    fn new(tokens: Vec<Token>) -> Self {
313        Parser { tokens, pos: 0 }
314    }
315
316    fn peek(&self) -> Option<&Token> {
317        self.tokens.get(self.pos)
318    }
319
320    fn advance(&mut self) -> Token {
321        let t = self.tokens[self.pos].clone();
322        self.pos += 1;
323        t
324    }
325
326    fn expect(&mut self, type_: &str, value: Option<&str>) -> Result<Token, MdqlError> {
327        let t = self.peek().ok_or_else(|| {
328            MdqlError::QueryParse(format!(
329                "Unexpected end of query, expected {}",
330                value.unwrap_or(type_)
331            ))
332        })?;
333        let matches_type = t.token_type == type_;
334        let matches_value = value.map_or(true, |v| t.value == v);
335        if !matches_type || !matches_value {
336            return Err(MdqlError::QueryParse(format!(
337                "Expected {}, got '{}' at position {}",
338                value.unwrap_or(type_),
339                t.raw,
340                self.pos
341            )));
342        }
343        Ok(self.advance())
344    }
345
346    fn match_keyword(&mut self, kw: &str) -> bool {
347        if let Some(t) = self.peek() {
348            if t.token_type == "keyword" && t.value == kw {
349                self.advance();
350                return true;
351            }
352        }
353        false
354    }
355
356    fn parse_statement(&mut self) -> Result<Statement, MdqlError> {
357        let t = self.peek().ok_or_else(|| MdqlError::QueryParse("Empty query".into()))?;
358        match (t.token_type.as_str(), t.value.as_str()) {
359            ("keyword", "SELECT") => Ok(Statement::Select(self.parse_select()?)),
360            ("keyword", "INSERT") => Ok(Statement::Insert(self.parse_insert()?)),
361            ("keyword", "UPDATE") => Ok(Statement::Update(self.parse_update()?)),
362            ("keyword", "DELETE") => Ok(Statement::Delete(self.parse_delete()?)),
363            ("keyword", "ALTER") => self.parse_alter(),
364            _ => Err(MdqlError::QueryParse(format!(
365                "Expected SELECT, INSERT, UPDATE, DELETE, or ALTER, got '{}'",
366                t.raw
367            ))),
368        }
369    }
370
371    fn parse_select(&mut self) -> Result<SelectQuery, MdqlError> {
372        self.expect("keyword", Some("SELECT"))?;
373        let columns = self.parse_columns()?;
374        self.expect("keyword", Some("FROM"))?;
375        let table = self.parse_ident()?;
376
377        // Optional table alias
378        let mut table_alias = None;
379        if let Some(t) = self.peek() {
380            if t.token_type == "ident" && !self.is_clause_keyword(t) {
381                table_alias = Some(self.advance().value);
382            }
383        }
384
385        // Optional JOIN(s)
386        let mut joins = Vec::new();
387        while self.match_keyword("JOIN") {
388            let join_table = self.parse_ident()?;
389            let mut join_alias = None;
390            if let Some(t) = self.peek() {
391                if t.token_type == "ident" && !self.is_clause_keyword(t) {
392                    join_alias = Some(self.advance().value);
393                }
394            }
395            self.expect("keyword", Some("ON"))?;
396            let left_col = self.parse_ident()?;
397            self.expect("op", Some("="))?;
398            let right_col = self.parse_ident()?;
399            joins.push(JoinClause {
400                table: join_table,
401                alias: join_alias,
402                left_col,
403                right_col,
404            });
405        }
406
407        let mut where_clause = None;
408        if self.match_keyword("WHERE") {
409            where_clause = Some(self.parse_or_expr()?);
410        }
411
412        let mut group_by = None;
413        if self.match_keyword("GROUP") {
414            self.expect("keyword", Some("BY"))?;
415            let mut cols = vec![self.parse_ident()?];
416            while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
417                self.advance();
418                cols.push(self.parse_ident()?);
419            }
420            group_by = Some(cols);
421        }
422
423        let mut order_by = None;
424        if self.match_keyword("ORDER") {
425            self.expect("keyword", Some("BY"))?;
426            order_by = Some(self.parse_order_by()?);
427        }
428
429        let mut limit = None;
430        if self.match_keyword("LIMIT") {
431            let t = self.expect("number", None)?;
432            limit = Some(t.value.parse::<i64>().map_err(|_| {
433                MdqlError::QueryParse(format!("Invalid LIMIT value: {}", t.value))
434            })?);
435        }
436
437        self.expect_end()?;
438
439        Ok(SelectQuery {
440            columns,
441            table,
442            table_alias,
443            joins,
444            where_clause,
445            group_by,
446            order_by,
447            limit,
448        })
449    }
450
451    fn parse_insert(&mut self) -> Result<InsertQuery, MdqlError> {
452        self.expect("keyword", Some("INSERT"))?;
453        self.expect("keyword", Some("INTO"))?;
454        let table = self.parse_ident()?;
455
456        self.expect("op", Some("("))?;
457        let mut columns = vec![self.parse_ident()?];
458        while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
459            self.advance();
460            columns.push(self.parse_ident()?);
461        }
462        self.expect("op", Some(")"))?;
463
464        self.expect("keyword", Some("VALUES"))?;
465
466        self.expect("op", Some("("))?;
467        let mut values = vec![self.parse_value()?];
468        while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
469            self.advance();
470            values.push(self.parse_value()?);
471        }
472        self.expect("op", Some(")"))?;
473
474        if columns.len() != values.len() {
475            return Err(MdqlError::QueryParse(format!(
476                "Column count ({}) does not match value count ({})",
477                columns.len(),
478                values.len()
479            )));
480        }
481
482        self.expect_end()?;
483        Ok(InsertQuery {
484            table,
485            columns,
486            values,
487        })
488    }
489
490    fn parse_update(&mut self) -> Result<UpdateQuery, MdqlError> {
491        self.expect("keyword", Some("UPDATE"))?;
492        let table = self.parse_ident()?;
493        self.expect("keyword", Some("SET"))?;
494
495        let mut assignments = Vec::new();
496        let col = self.parse_ident()?;
497        self.expect("op", Some("="))?;
498        let val = self.parse_value()?;
499        assignments.push((col, val));
500
501        while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
502            self.advance();
503            let col = self.parse_ident()?;
504            self.expect("op", Some("="))?;
505            let val = self.parse_value()?;
506            assignments.push((col, val));
507        }
508
509        let mut where_clause = None;
510        if self.match_keyword("WHERE") {
511            where_clause = Some(self.parse_or_expr()?);
512        }
513
514        self.expect_end()?;
515        Ok(UpdateQuery {
516            table,
517            assignments,
518            where_clause,
519        })
520    }
521
522    fn parse_delete(&mut self) -> Result<DeleteQuery, MdqlError> {
523        self.expect("keyword", Some("DELETE"))?;
524        self.expect("keyword", Some("FROM"))?;
525        let table = self.parse_ident()?;
526
527        let mut where_clause = None;
528        if self.match_keyword("WHERE") {
529            where_clause = Some(self.parse_or_expr()?);
530        }
531
532        self.expect_end()?;
533        Ok(DeleteQuery {
534            table,
535            where_clause,
536        })
537    }
538
539    fn parse_alter(&mut self) -> Result<Statement, MdqlError> {
540        self.expect("keyword", Some("ALTER"))?;
541        self.expect("keyword", Some("TABLE"))?;
542        let table = self.parse_ident()?;
543
544        let t = self.peek().ok_or_else(|| {
545            MdqlError::QueryParse("Expected RENAME, DROP, or MERGE after table name".into())
546        })?;
547
548        match (t.token_type.as_str(), t.value.as_str()) {
549            ("keyword", "RENAME") => {
550                self.advance();
551                self.expect("keyword", Some("FIELD"))?;
552                let old_name = self.parse_string_or_ident()?;
553                self.expect("keyword", Some("TO"))?;
554                let new_name = self.parse_string_or_ident()?;
555                self.expect_end()?;
556                Ok(Statement::AlterRename(AlterRenameFieldQuery {
557                    table,
558                    old_name,
559                    new_name,
560                }))
561            }
562            ("keyword", "DROP") => {
563                self.advance();
564                self.expect("keyword", Some("FIELD"))?;
565                let field_name = self.parse_string_or_ident()?;
566                self.expect_end()?;
567                Ok(Statement::AlterDrop(AlterDropFieldQuery {
568                    table,
569                    field_name,
570                }))
571            }
572            ("keyword", "MERGE") => {
573                self.advance();
574                self.expect("keyword", Some("FIELDS"))?;
575                let mut sources = vec![self.parse_string_or_ident()?];
576                while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
577                    self.advance();
578                    sources.push(self.parse_string_or_ident()?);
579                }
580                self.expect("keyword", Some("INTO"))?;
581                let target = self.parse_string_or_ident()?;
582                self.expect_end()?;
583                Ok(Statement::AlterMerge(AlterMergeFieldsQuery {
584                    table,
585                    sources,
586                    into: target,
587                }))
588            }
589            _ => Err(MdqlError::QueryParse(format!(
590                "Expected RENAME, DROP, or MERGE, got '{}'",
591                t.raw
592            ))),
593        }
594    }
595
596    fn parse_string_or_ident(&mut self) -> Result<String, MdqlError> {
597        let t = self.peek().ok_or_else(|| {
598            MdqlError::QueryParse("Expected field name, got end of query".into())
599        })?;
600        match t.token_type.as_str() {
601            "string" => {
602                let v = self.advance().value;
603                Ok(v)
604            }
605            "ident" | "keyword" => {
606                let v = self.advance().value;
607                Ok(v)
608            }
609            _ => Err(MdqlError::QueryParse(format!(
610                "Expected field name, got '{}'",
611                t.raw
612            ))),
613        }
614    }
615
616    fn parse_columns(&mut self) -> Result<ColumnList, MdqlError> {
617        if let Some(t) = self.peek() {
618            if t.token_type == "op" && t.value == "*" {
619                self.advance();
620                return Ok(ColumnList::All);
621            }
622        }
623
624        let mut exprs = vec![self.parse_select_expr()?];
625        while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
626            self.advance();
627            exprs.push(self.parse_select_expr()?);
628        }
629        Ok(ColumnList::Named(exprs))
630    }
631
632    fn peek_is_agg_func(&self) -> bool {
633        let t = match self.peek() {
634            Some(t) => t,
635            None => return false,
636        };
637        let name_upper = t.value.to_uppercase();
638        if !AGG_FUNCS.contains(&name_upper.as_str()) {
639            return false;
640        }
641        // Only treat as aggregate if followed by (
642        self.tokens
643            .get(self.pos + 1)
644            .map_or(false, |next| next.token_type == "op" && next.value == "(")
645    }
646
647    fn parse_select_expr(&mut self) -> Result<SelectExpr, MdqlError> {
648        let _t = self.peek().ok_or_else(|| {
649            MdqlError::QueryParse("Expected column or aggregate, got end of query".into())
650        })?;
651
652        if self.peek_is_agg_func() {
653            let func_name = self.advance().value.to_uppercase();
654            let func = match func_name.as_str() {
655                "COUNT" => AggFunc::Count,
656                "SUM" => AggFunc::Sum,
657                "AVG" => AggFunc::Avg,
658                "MIN" => AggFunc::Min,
659                "MAX" => AggFunc::Max,
660                _ => unreachable!(),
661            };
662            self.expect("op", Some("("))?;
663            let (arg, arg_expr) = if self.peek().map_or(false, |t| t.token_type == "op" && t.value == "*") {
664                self.advance();
665                ("*".to_string(), None)
666            } else {
667                // Parse a full expression inside the aggregate (supports CASE WHEN, arithmetic, etc.)
668                let expr = self.parse_additive()?;
669                if let Expr::Column(name) = &expr {
670                    (name.clone(), None)
671                } else {
672                    (expr.display_name(), Some(expr))
673                }
674            };
675            self.expect("op", Some(")"))?;
676
677            let alias = if self.match_keyword("AS") {
678                Some(self.parse_ident()?)
679            } else if self.peek().map_or(false, |t| {
680                t.token_type == "ident" && !self.is_clause_keyword(t)
681            }) {
682                Some(self.advance().value)
683            } else {
684                None
685            };
686
687            Ok(SelectExpr::Aggregate { func, arg, arg_expr, alias })
688        } else {
689            // Parse a general expression (could be a column, literal, or arithmetic)
690            let expr = self.parse_additive()?;
691
692            // Optional alias: explicit (AS alias) or implicit (just an ident)
693            let alias = if self.match_keyword("AS") {
694                Some(self.parse_ident()?)
695            } else if self.peek().map_or(false, |t| {
696                t.token_type == "ident" && !self.is_clause_keyword(t)
697            }) {
698                Some(self.advance().value)
699            } else {
700                None
701            };
702
703            // If it's a simple column reference with no alias, return Column variant
704            // for backward compatibility
705            if alias.is_none() {
706                if let Expr::Column(name) = &expr {
707                    return Ok(SelectExpr::Column(name.clone()));
708                }
709            }
710
711            Ok(SelectExpr::Expr { expr, alias })
712        }
713    }
714
715    // ── Expression parser (precedence climbing) ───────────────────────
716
717    fn peek_is_additive_op(&self) -> bool {
718        self.peek().map_or(false, |t| {
719            t.token_type == "op" && (t.value == "+" || t.value == "-")
720        })
721    }
722
723    fn peek_is_multiplicative_op(&self) -> bool {
724        self.peek().map_or(false, |t| {
725            t.token_type == "op" && (t.value == "*" || t.value == "/" || t.value == "%")
726        })
727    }
728
729    fn parse_additive(&mut self) -> Result<Expr, MdqlError> {
730        let mut left = self.parse_multiplicative()?;
731        while self.peek_is_additive_op() {
732            let op_tok = self.advance();
733            let op = match op_tok.value.as_str() {
734                "+" => ArithOp::Add,
735                "-" => ArithOp::Sub,
736                _ => unreachable!(),
737            };
738            let right = self.parse_multiplicative()?;
739            left = Expr::BinaryOp {
740                left: Box::new(left),
741                op,
742                right: Box::new(right),
743            };
744        }
745        Ok(left)
746    }
747
748    fn parse_multiplicative(&mut self) -> Result<Expr, MdqlError> {
749        let mut left = self.parse_unary()?;
750        while self.peek_is_multiplicative_op() {
751            let op_tok = self.advance();
752            let op = match op_tok.value.as_str() {
753                "*" => ArithOp::Mul,
754                "/" => ArithOp::Div,
755                "%" => ArithOp::Mod,
756                _ => unreachable!(),
757            };
758            let right = self.parse_unary()?;
759            left = Expr::BinaryOp {
760                left: Box::new(left),
761                op,
762                right: Box::new(right),
763            };
764        }
765        Ok(left)
766    }
767
768    fn parse_unary(&mut self) -> Result<Expr, MdqlError> {
769        if self.peek().map_or(false, |t| t.token_type == "op" && t.value == "-") {
770            self.advance();
771            let inner = self.parse_atom()?;
772            // Fold unary minus on literals
773            match inner {
774                Expr::Literal(SqlValue::Int(n)) => Ok(Expr::Literal(SqlValue::Int(-n))),
775                Expr::Literal(SqlValue::Float(f)) => Ok(Expr::Literal(SqlValue::Float(-f))),
776                _ => Ok(Expr::UnaryMinus(Box::new(inner))),
777            }
778        } else {
779            self.parse_atom()
780        }
781    }
782
783    fn parse_atom(&mut self) -> Result<Expr, MdqlError> {
784        let t = self.peek().ok_or_else(|| {
785            MdqlError::QueryParse("Expected expression, got end of query".into())
786        })?;
787
788        match t.token_type.as_str() {
789            "number" => {
790                let v = self.advance().value;
791                if v.contains('.') {
792                    let f: f64 = v.parse().map_err(|_| {
793                        MdqlError::QueryParse(format!("Invalid float: {}", v))
794                    })?;
795                    Ok(Expr::Literal(SqlValue::Float(f)))
796                } else {
797                    let n: i64 = v.parse().map_err(|_| {
798                        MdqlError::QueryParse(format!("Invalid int: {}", v))
799                    })?;
800                    Ok(Expr::Literal(SqlValue::Int(n)))
801                }
802            }
803            "string" => {
804                let v = self.advance().value;
805                Ok(Expr::Literal(SqlValue::String(v)))
806            }
807            "keyword" if t.value == "NULL" => {
808                self.advance();
809                Ok(Expr::Literal(SqlValue::Null))
810            }
811            "keyword" if t.value == "CASE" => {
812                self.parse_case_expr()
813            }
814            "op" if t.value == "(" => {
815                self.advance();
816                let expr = self.parse_additive()?;
817                self.expect("op", Some(")"))?;
818                Ok(expr)
819            }
820            "ident" => {
821                let name = self.advance().value;
822                Ok(Expr::Column(name))
823            }
824            "keyword" if !Self::is_reserved_keyword(&t.value) => {
825                let name = self.advance().value;
826                Ok(Expr::Column(name))
827            }
828            _ => Err(MdqlError::QueryParse(format!(
829                "Expected expression, got '{}'",
830                t.raw
831            ))),
832        }
833    }
834
835    fn parse_case_expr(&mut self) -> Result<Expr, MdqlError> {
836        self.expect("keyword", Some("CASE"))?;
837        let mut whens = Vec::new();
838        while self.match_keyword("WHEN") {
839            let condition = self.parse_or_expr()?;
840            self.expect("keyword", Some("THEN"))?;
841            let result = self.parse_additive()?;
842            whens.push((condition, Box::new(result)));
843        }
844        if whens.is_empty() {
845            return Err(MdqlError::QueryParse("CASE requires at least one WHEN clause".into()));
846        }
847        let else_expr = if self.match_keyword("ELSE") {
848            Some(Box::new(self.parse_additive()?))
849        } else {
850            None
851        };
852        self.expect("keyword", Some("END"))?;
853        Ok(Expr::Case { whens, else_expr })
854    }
855
856    fn parse_ident(&mut self) -> Result<String, MdqlError> {
857        let t = self.peek().ok_or_else(|| {
858            MdqlError::QueryParse("Expected identifier, got end of query".into())
859        })?;
860        match t.token_type.as_str() {
861            "ident" | "keyword" => {
862                let v = self.advance().value;
863                Ok(v)
864            }
865            _ => Err(MdqlError::QueryParse(format!(
866                "Expected identifier, got '{}'",
867                t.raw
868            ))),
869        }
870    }
871
872    fn parse_or_expr(&mut self) -> Result<WhereClause, MdqlError> {
873        let mut left = self.parse_and_expr()?;
874        while self.match_keyword("OR") {
875            let right = self.parse_and_expr()?;
876            left = WhereClause::BoolOp(BoolOp {
877                op: "OR".into(),
878                left: Box::new(left),
879                right: Box::new(right),
880            });
881        }
882        Ok(left)
883    }
884
885    fn parse_and_expr(&mut self) -> Result<WhereClause, MdqlError> {
886        let mut left = self.parse_comparison()?;
887        while self.match_keyword("AND") {
888            let right = self.parse_comparison()?;
889            left = WhereClause::BoolOp(BoolOp {
890                op: "AND".into(),
891                left: Box::new(left),
892                right: Box::new(right),
893            });
894        }
895        Ok(left)
896    }
897
898    fn parse_comparison(&mut self) -> Result<WhereClause, MdqlError> {
899        // Handle parenthesized boolean expressions
900        if self.peek().map_or(false, |t| t.token_type == "op" && t.value == "(") {
901            // Save position — might be arithmetic parens, not boolean
902            let saved_pos = self.pos;
903            self.advance();
904            // Try parsing as boolean (OR/AND) expression
905            let result = self.parse_or_expr();
906            if result.is_ok() && self.peek().map_or(false, |t| t.token_type == "op" && t.value == ")") {
907                self.advance();
908                return result;
909            }
910            // Not a boolean paren — rewind and parse as arithmetic expression
911            self.pos = saved_pos;
912        }
913
914        // Parse the left side as a full expression (column, literal, or arithmetic)
915        let left_expr = self.parse_additive()?;
916
917        // Extract column name for backward compat (simple column on left side)
918        let col = left_expr.as_column().unwrap_or("").to_string();
919
920        // IS NULL / IS NOT NULL (only valid with simple column)
921        if self.match_keyword("IS") {
922            if self.match_keyword("NOT") {
923                self.expect("keyword", Some("NULL"))?;
924                return Ok(WhereClause::Comparison(Comparison {
925                    column: col,
926                    op: "IS NOT NULL".into(),
927                    value: None,
928                    left_expr: Some(left_expr),
929                    right_expr: None,
930                }));
931            }
932            self.expect("keyword", Some("NULL"))?;
933            return Ok(WhereClause::Comparison(Comparison {
934                column: col,
935                op: "IS NULL".into(),
936                value: None,
937                left_expr: Some(left_expr),
938                right_expr: None,
939            }));
940        }
941
942        // IN (val, val, ...)
943        if self.match_keyword("IN") {
944            self.expect("op", Some("("))?;
945            let mut values = vec![self.parse_value()?];
946            while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
947                self.advance();
948                values.push(self.parse_value()?);
949            }
950            self.expect("op", Some(")"))?;
951            return Ok(WhereClause::Comparison(Comparison {
952                column: col,
953                op: "IN".into(),
954                value: Some(SqlValue::List(values)),
955                left_expr: Some(left_expr),
956                right_expr: None,
957            }));
958        }
959
960        // LIKE
961        if self.match_keyword("LIKE") {
962            let val = self.parse_value()?;
963            return Ok(WhereClause::Comparison(Comparison {
964                column: col,
965                op: "LIKE".into(),
966                value: Some(val),
967                left_expr: Some(left_expr),
968                right_expr: None,
969            }));
970        }
971
972        // NOT LIKE
973        if self.match_keyword("NOT") {
974            if self.match_keyword("LIKE") {
975                let val = self.parse_value()?;
976                return Ok(WhereClause::Comparison(Comparison {
977                    column: col,
978                    op: "NOT LIKE".into(),
979                    value: Some(val),
980                    left_expr: Some(left_expr),
981                    right_expr: None,
982                }));
983            }
984            return Err(MdqlError::QueryParse("Expected LIKE after NOT".into()));
985        }
986
987        // Standard comparison operators
988        if let Some(t) = self.peek() {
989            if t.token_type == "op" && ["=", "!=", "<", ">", "<=", ">="].contains(&t.value.as_str())
990            {
991                let op = self.advance().value;
992                // Parse right side as expression
993                let right_expr = self.parse_additive()?;
994                // Extract SqlValue for backward compat (simple literal on right side)
995                let value = match &right_expr {
996                    Expr::Literal(v) => Some(v.clone()),
997                    _ => None,
998                };
999                return Ok(WhereClause::Comparison(Comparison {
1000                    column: col,
1001                    op,
1002                    value,
1003                    left_expr: Some(left_expr),
1004                    right_expr: Some(right_expr),
1005                }));
1006            }
1007        }
1008
1009        let got = self.peek().map_or("end".to_string(), |t| t.raw.clone());
1010        Err(MdqlError::QueryParse(format!(
1011            "Expected operator after '{}', got '{}'",
1012            left_expr.display_name(), got
1013        )))
1014    }
1015
1016    fn parse_value(&mut self) -> Result<SqlValue, MdqlError> {
1017        let t = self.peek().ok_or_else(|| {
1018            MdqlError::QueryParse("Expected value, got end of query".into())
1019        })?;
1020        match t.token_type.as_str() {
1021            "string" => {
1022                let v = self.advance().value;
1023                Ok(SqlValue::String(v))
1024            }
1025            "number" => {
1026                let v = self.advance().value;
1027                if v.contains('.') {
1028                    Ok(SqlValue::Float(v.parse().map_err(|_| {
1029                        MdqlError::QueryParse(format!("Invalid float: {}", v))
1030                    })?))
1031                } else {
1032                    Ok(SqlValue::Int(v.parse().map_err(|_| {
1033                        MdqlError::QueryParse(format!("Invalid int: {}", v))
1034                    })?))
1035                }
1036            }
1037            "keyword" if t.value == "NULL" => {
1038                self.advance();
1039                Ok(SqlValue::Null)
1040            }
1041            _ => Err(MdqlError::QueryParse(format!(
1042                "Expected value, got '{}'",
1043                t.raw
1044            ))),
1045        }
1046    }
1047
1048    fn parse_order_by(&mut self) -> Result<Vec<OrderSpec>, MdqlError> {
1049        let mut specs = vec![self.parse_order_spec()?];
1050        while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
1051            self.advance();
1052            specs.push(self.parse_order_spec()?);
1053        }
1054        Ok(specs)
1055    }
1056
1057    fn parse_order_spec(&mut self) -> Result<OrderSpec, MdqlError> {
1058        let expr = self.parse_additive()?;
1059        let col = expr.as_column().unwrap_or("").to_string();
1060        let descending = if self.match_keyword("DESC") {
1061            true
1062        } else {
1063            self.match_keyword("ASC");
1064            false
1065        };
1066        Ok(OrderSpec {
1067            column: col,
1068            expr: Some(expr),
1069            descending,
1070        })
1071    }
1072
1073    fn is_clause_keyword(&self, t: &Token) -> bool {
1074        t.token_type == "keyword"
1075            && ["WHERE", "ORDER", "LIMIT", "JOIN", "ON", "GROUP"].contains(&t.value.as_str())
1076    }
1077
1078    /// Keywords that should never be consumed as column names inside expressions.
1079    fn is_reserved_keyword(kw: &str) -> bool {
1080        matches!(kw,
1081            "AS" | "FROM" | "WHERE" | "AND" | "OR" | "ORDER" | "BY"
1082            | "ASC" | "DESC" | "LIMIT" | "JOIN" | "ON" | "GROUP"
1083            | "SELECT" | "INSERT" | "INTO" | "VALUES" | "UPDATE" | "SET"
1084            | "DELETE" | "ALTER" | "TABLE" | "IS" | "NOT" | "IN" | "LIKE"
1085            | "RENAME" | "FIELD" | "TO" | "DROP" | "MERGE" | "FIELDS"
1086            | "CASE" | "WHEN" | "THEN" | "ELSE" | "END"
1087        )
1088    }
1089
1090    fn expect_end(&self) -> Result<(), MdqlError> {
1091        if let Some(t) = self.peek() {
1092            return Err(MdqlError::QueryParse(format!(
1093                "Unexpected token '{}' at position {}",
1094                t.raw, self.pos
1095            )));
1096        }
1097        Ok(())
1098    }
1099}
1100
1101pub fn parse_query(sql: &str) -> crate::errors::Result<Statement> {
1102    let tokens = tokenize(sql);
1103    if tokens.is_empty() {
1104        return Err(MdqlError::QueryParse("Empty query".into()));
1105    }
1106    let mut parser = Parser::new(tokens);
1107    parser.parse_statement()
1108}
1109
#[cfg(test)]
mod tests {
    // Parser unit tests. Coverage:
    // - statement shapes (SELECT/INSERT/UPDATE/DELETE/ALTER);
    // - WHERE predicates, joins and aggregates;
    // - arithmetic expressions and CASE WHEN.
    use super::*;

    // Parse `$sql`, unwrap the result, and destructure it as a SELECT,
    // panicking with "Expected Select" for any other statement kind.
    macro_rules! select {
        ($sql:expr) => {{
            let Statement::Select(q) = parse_query($sql).unwrap() else {
                panic!("Expected Select");
            };
            q
        }};
    }

    #[test]
    fn test_simple_select() {
        let q = select!("SELECT title, status FROM strategies");
        assert_eq!(
            q.columns,
            ColumnList::Named(vec![
                SelectExpr::Column("title".into()),
                SelectExpr::Column("status".into()),
            ])
        );
        assert_eq!(q.table, "strategies");
    }

    #[test]
    fn test_select_star() {
        let q = select!("SELECT * FROM test");
        assert_eq!(q.columns, ColumnList::All);
    }

    #[test]
    fn test_where_clause() {
        let q = select!("SELECT title FROM test WHERE count > 5");
        assert!(q.where_clause.is_some());
    }

    #[test]
    fn test_order_by() {
        let q = select!("SELECT title FROM test ORDER BY composite DESC, title ASC");
        let ob = q.order_by.unwrap();
        assert_eq!(ob.len(), 2);
        assert!(ob[0].descending);
        assert!(!ob[1].descending);
    }

    #[test]
    fn test_limit() {
        let q = select!("SELECT * FROM test LIMIT 10");
        assert_eq!(q.limit, Some(10));
    }

    #[test]
    fn test_insert() {
        let stmt = parse_query("INSERT INTO test (title, count) VALUES ('Hello', 42)").unwrap();
        let Statement::Insert(q) = stmt else {
            panic!("Expected Insert");
        };
        assert_eq!(q.table, "test");
        assert_eq!(q.columns, vec!["title", "count"]);
        assert_eq!(q.values[0], SqlValue::String("Hello".into()));
        assert_eq!(q.values[1], SqlValue::Int(42));
    }

    #[test]
    fn test_update() {
        let stmt = parse_query("UPDATE test SET status = 'KILLED' WHERE path = 'a.md'").unwrap();
        let Statement::Update(q) = stmt else {
            panic!("Expected Update");
        };
        assert_eq!(q.table, "test");
        assert_eq!(q.assignments.len(), 1);
        assert!(q.where_clause.is_some());
    }

    #[test]
    fn test_delete() {
        let stmt = parse_query("DELETE FROM test WHERE status = 'draft'").unwrap();
        let Statement::Delete(q) = stmt else {
            panic!("Expected Delete");
        };
        assert_eq!(q.table, "test");
        assert!(q.where_clause.is_some());
    }

    #[test]
    fn test_alter_rename() {
        let stmt =
            parse_query("ALTER TABLE test RENAME FIELD 'Summary' TO 'Overview'").unwrap();
        let Statement::AlterRename(q) = stmt else {
            panic!("Expected AlterRename");
        };
        assert_eq!(q.old_name, "Summary");
        assert_eq!(q.new_name, "Overview");
    }

    #[test]
    fn test_alter_drop() {
        let stmt = parse_query("ALTER TABLE test DROP FIELD 'Details'").unwrap();
        let Statement::AlterDrop(q) = stmt else {
            panic!("Expected AlterDrop");
        };
        assert_eq!(q.field_name, "Details");
    }

    #[test]
    fn test_alter_merge() {
        let stmt = parse_query(
            "ALTER TABLE test MERGE FIELDS 'Entry Rules', 'Exit Rules' INTO 'Trading Rules'",
        )
        .unwrap();
        let Statement::AlterMerge(q) = stmt else {
            panic!("Expected AlterMerge");
        };
        assert_eq!(q.sources, vec!["Entry Rules", "Exit Rules"]);
        assert_eq!(q.into, "Trading Rules");
    }

    #[test]
    fn test_backtick_ident() {
        let q = select!("SELECT `Structural Mechanism` FROM test");
        assert_eq!(
            q.columns,
            ColumnList::Named(vec![SelectExpr::Column("Structural Mechanism".into())])
        );
    }

    #[test]
    fn test_like_operator() {
        let q = select!("SELECT title FROM test WHERE categories LIKE '%defi%'");
        let Some(WhereClause::Comparison(c)) = q.where_clause else {
            panic!("Expected LIKE comparison");
        };
        assert_eq!(c.op, "LIKE");
        assert_eq!(c.value, Some(SqlValue::String("%defi%".into())));
    }

    #[test]
    fn test_in_operator() {
        let q = select!("SELECT * FROM test WHERE status IN ('ACTIVE', 'LIVE')");
        let Some(WhereClause::Comparison(c)) = q.where_clause else {
            panic!("Expected IN comparison");
        };
        assert_eq!(c.op, "IN");
    }

    #[test]
    fn test_is_null() {
        let q = select!("SELECT * FROM test WHERE title IS NULL");
        let Some(WhereClause::Comparison(c)) = q.where_clause else {
            panic!("Expected IS NULL comparison");
        };
        assert_eq!(c.op, "IS NULL");
    }

    #[test]
    fn test_and_or() {
        let q = select!(
            "SELECT * FROM test WHERE status = 'ACTIVE' AND count > 5 OR title LIKE '%test%'"
        );
        assert!(q.where_clause.is_some());
    }

    #[test]
    fn test_join() {
        let q = select!(
            "SELECT s.title, b.sharpe FROM strategies s JOIN backtests b ON b.strategy = s.path"
        );
        assert_eq!(q.table, "strategies");
        assert_eq!(q.table_alias, Some("s".into()));
        assert_eq!(q.joins.len(), 1);
        let join = &q.joins[0];
        assert_eq!(join.table, "backtests");
        assert_eq!(join.alias, Some("b".into()));
    }

    #[test]
    fn test_multi_join() {
        let q = select!(
            "SELECT s.title, b.sharpe, c.verdict FROM strategies s JOIN backtests b ON b.strategy = s.path JOIN critiques c ON c.strategy = s.path"
        );
        assert_eq!(q.table, "strategies");
        assert_eq!(q.table_alias, Some("s".into()));
        assert_eq!(q.joins.len(), 2);
        assert_eq!(q.joins[0].table, "backtests");
        assert_eq!(q.joins[0].alias, Some("b".into()));
        assert_eq!(q.joins[0].left_col, "b.strategy");
        assert_eq!(q.joins[0].right_col, "s.path");
        assert_eq!(q.joins[1].table, "critiques");
        assert_eq!(q.joins[1].alias, Some("c".into()));
        assert_eq!(q.joins[1].left_col, "c.strategy");
        assert_eq!(q.joins[1].right_col, "s.path");
    }

    #[test]
    fn test_empty_query() {
        assert!(parse_query("").is_err());
    }

    #[test]
    fn test_count_star() {
        let q = select!("SELECT status, COUNT(*) AS cnt FROM strategies GROUP BY status");
        let ColumnList::Named(exprs) = &q.columns else {
            panic!("Expected Named columns");
        };
        assert_eq!(exprs.len(), 2);
        assert_eq!(exprs[0], SelectExpr::Column("status".into()));
        assert!(matches!(&exprs[1], SelectExpr::Aggregate {
            func: AggFunc::Count,
            arg,
            alias: Some(a),
            ..
        } if arg == "*" && a == "cnt"));
        assert_eq!(q.group_by, Some(vec!["status".into()]));
    }

    #[test]
    fn test_count_column_as_ident() {
        // "count" as a column name should NOT be parsed as the COUNT aggregate
        let stmt = parse_query("INSERT INTO test (title, count) VALUES ('Hello', 42)").unwrap();
        let Statement::Insert(q) = stmt else {
            panic!("Expected Insert");
        };
        assert_eq!(q.columns, vec!["title", "count"]);
    }

    #[test]
    fn test_multiple_aggregates() {
        let q = select!("SELECT MIN(composite), MAX(composite), AVG(composite) FROM strategies");
        let ColumnList::Named(exprs) = &q.columns else {
            panic!("Expected Named columns");
        };
        assert_eq!(exprs.len(), 3);
        assert!(matches!(&exprs[0], SelectExpr::Aggregate { func: AggFunc::Min, .. }));
        assert!(matches!(&exprs[1], SelectExpr::Aggregate { func: AggFunc::Max, .. }));
        assert!(matches!(&exprs[2], SelectExpr::Aggregate { func: AggFunc::Avg, .. }));
        assert_eq!(q.group_by, None);
    }

    // ── Expression tests ──────────────────────────────────────────

    #[test]
    fn test_select_arithmetic_expr() {
        let q = select!("SELECT a + b FROM test");
        let ColumnList::Named(exprs) = &q.columns else {
            panic!("Expected Named columns");
        };
        assert_eq!(exprs.len(), 1);
        assert!(matches!(&exprs[0], SelectExpr::Expr {
            expr: Expr::BinaryOp { op: ArithOp::Add, .. },
            alias: None,
        }));
    }

    #[test]
    fn test_select_arithmetic_with_alias() {
        let q = select!("SELECT a + b AS total FROM test");
        let ColumnList::Named(exprs) = &q.columns else {
            panic!("Expected Named columns");
        };
        assert_eq!(exprs.len(), 1);
        assert!(matches!(&exprs[0], SelectExpr::Expr {
            alias: Some(a),
            ..
        } if a == "total"));
        assert_eq!(exprs[0].output_name(), "total");
    }

    #[test]
    fn test_select_precedence() {
        // a + b * c should parse as a + (b * c)
        let q = select!("SELECT a + b * c FROM test");
        let ColumnList::Named(exprs) = &q.columns else {
            panic!("Expected Named columns");
        };
        let SelectExpr::Expr { expr, .. } = &exprs[0] else {
            panic!("Expected Expr variant");
        };
        let Expr::BinaryOp { left, op, right } = expr else {
            panic!("Expected BinaryOp");
        };
        assert_eq!(*op, ArithOp::Add);
        assert!(matches!(left.as_ref(), Expr::Column(n) if n == "a"));
        assert!(matches!(right.as_ref(), Expr::BinaryOp { op: ArithOp::Mul, .. }));
    }

    #[test]
    fn test_select_parenthesized_expr() {
        // (a + b) * c should override default precedence
        let q = select!("SELECT (a + b) * c FROM test");
        let ColumnList::Named(exprs) = &q.columns else {
            panic!("Expected Named columns");
        };
        let SelectExpr::Expr { expr, .. } = &exprs[0] else {
            panic!("Expected Expr variant");
        };
        let Expr::BinaryOp { left, op, .. } = expr else {
            panic!("Expected BinaryOp");
        };
        assert_eq!(*op, ArithOp::Mul);
        assert!(matches!(left.as_ref(), Expr::BinaryOp { op: ArithOp::Add, .. }));
    }

    #[test]
    fn test_select_unary_minus() {
        let q = select!("SELECT -count FROM test");
        let ColumnList::Named(exprs) = &q.columns else {
            panic!("Expected Named columns");
        };
        assert!(matches!(&exprs[0], SelectExpr::Expr {
            expr: Expr::UnaryMinus(_),
            ..
        }));
    }

    #[test]
    fn test_select_negative_literal() {
        let q = select!("SELECT -42 FROM test");
        let ColumnList::Named(exprs) = &q.columns else {
            panic!("Expected Named columns");
        };
        // Unary minus folds into the literal
        assert!(matches!(&exprs[0], SelectExpr::Expr {
            expr: Expr::Literal(SqlValue::Int(-42)),
            ..
        }));
    }

    #[test]
    fn test_where_arithmetic_expr() {
        let q = select!("SELECT * FROM test WHERE a + b > 10");
        let Some(WhereClause::Comparison(c)) = q.where_clause else {
            panic!("Expected comparison");
        };
        assert_eq!(c.op, ">");
        assert!(matches!(&c.left_expr, Some(Expr::BinaryOp { op: ArithOp::Add, .. })));
        assert!(matches!(&c.right_expr, Some(Expr::Literal(SqlValue::Int(10)))));
    }

    #[test]
    fn test_where_both_sides_expr() {
        let q = select!("SELECT * FROM test WHERE a * 2 > b + 1");
        let Some(WhereClause::Comparison(c)) = q.where_clause else {
            panic!("Expected comparison");
        };
        assert_eq!(c.op, ">");
        assert!(matches!(&c.left_expr, Some(Expr::BinaryOp { op: ArithOp::Mul, .. })));
        assert!(matches!(&c.right_expr, Some(Expr::BinaryOp { op: ArithOp::Add, .. })));
    }

    #[test]
    fn test_order_by_expr() {
        let q = select!("SELECT * FROM test ORDER BY a + b DESC");
        let ob = q.order_by.unwrap();
        assert_eq!(ob.len(), 1);
        assert!(ob[0].descending);
        assert!(matches!(&ob[0].expr, Some(Expr::BinaryOp { op: ArithOp::Add, .. })));
    }

    #[test]
    fn test_all_arithmetic_ops() {
        let q = select!("SELECT a + b, a - b, a * b, a / b, a % b FROM test");
        let ColumnList::Named(exprs) = &q.columns else {
            panic!("Expected Named columns");
        };
        assert_eq!(exprs.len(), 5);
        assert!(matches!(&exprs[0], SelectExpr::Expr { expr: Expr::BinaryOp { op: ArithOp::Add, .. }, .. }));
        assert!(matches!(&exprs[1], SelectExpr::Expr { expr: Expr::BinaryOp { op: ArithOp::Sub, .. }, .. }));
        assert!(matches!(&exprs[2], SelectExpr::Expr { expr: Expr::BinaryOp { op: ArithOp::Mul, .. }, .. }));
        assert!(matches!(&exprs[3], SelectExpr::Expr { expr: Expr::BinaryOp { op: ArithOp::Div, .. }, .. }));
        assert!(matches!(&exprs[4], SelectExpr::Expr { expr: Expr::BinaryOp { op: ArithOp::Mod, .. }, .. }));
    }

    #[test]
    fn test_column_with_literal_arithmetic() {
        let q = select!("SELECT count * 2 + 1 FROM test");
        let ColumnList::Named(exprs) = &q.columns else {
            panic!("Expected Named columns");
        };
        // Should be (count * 2) + 1
        let SelectExpr::Expr { expr, .. } = &exprs[0] else {
            panic!("Expected Expr");
        };
        let Expr::BinaryOp { left, op, right } = expr else {
            panic!("Expected BinaryOp");
        };
        assert_eq!(*op, ArithOp::Add);
        assert!(matches!(right.as_ref(), Expr::Literal(SqlValue::Int(1))));
        assert!(matches!(left.as_ref(), Expr::BinaryOp { op: ArithOp::Mul, .. }));
    }

    #[test]
    fn test_mixed_columns_and_exprs() {
        let q = select!("SELECT title, a + b AS sum, count FROM test");
        let ColumnList::Named(exprs) = &q.columns else {
            panic!("Expected Named columns");
        };
        assert_eq!(exprs.len(), 3);
        assert_eq!(exprs[0], SelectExpr::Column("title".into()));
        assert!(matches!(&exprs[1], SelectExpr::Expr { alias: Some(a), .. } if a == "sum"));
        assert_eq!(exprs[2], SelectExpr::Column("count".into()));
    }

    // ── CASE WHEN tests ──────────────────────────────────────────

    #[test]
    fn test_case_when_basic() {
        let q = select!("SELECT CASE WHEN status = 'ACTIVE' THEN 1 ELSE 0 END FROM test");
        let ColumnList::Named(exprs) = &q.columns else {
            panic!("Expected Named columns");
        };
        assert_eq!(exprs.len(), 1);
        assert!(matches!(&exprs[0], SelectExpr::Expr {
            expr: Expr::Case { .. },
            ..
        }));
    }

    #[test]
    fn test_case_when_multiple_branches() {
        let q = select!(
            "SELECT CASE WHEN x > 10 THEN 'high' WHEN x > 5 THEN 'mid' ELSE 'low' END FROM test"
        );
        let ColumnList::Named(exprs) = &q.columns else {
            panic!("Expected Named columns");
        };
        let SelectExpr::Expr { expr: Expr::Case { whens, else_expr }, .. } = &exprs[0] else {
            panic!("Expected Case expression");
        };
        assert_eq!(whens.len(), 2);
        assert!(else_expr.is_some());
    }

    #[test]
    fn test_case_when_no_else() {
        let q = select!("SELECT CASE WHEN x = 1 THEN 'one' END FROM test");
        let ColumnList::Named(exprs) = &q.columns else {
            panic!("Expected Named columns");
        };
        let SelectExpr::Expr { expr: Expr::Case { whens, else_expr }, .. } = &exprs[0] else {
            panic!("Expected Case expression");
        };
        assert_eq!(whens.len(), 1);
        assert!(else_expr.is_none());
    }

    #[test]
    fn test_case_when_in_aggregate() {
        let q = select!(
            "SELECT SUM(CASE WHEN side = 'BUY' THEN size ELSE -size END) AS net FROM orders GROUP BY token"
        );
        let ColumnList::Named(exprs) = &q.columns else {
            panic!("Expected Named columns");
        };
        assert_eq!(exprs.len(), 1);
        assert!(matches!(&exprs[0], SelectExpr::Aggregate {
            func: AggFunc::Sum,
            arg_expr: Some(Expr::Case { .. }),
            alias: Some(a),
            ..
        } if a == "net"));
    }

    #[test]
    fn test_case_when_with_alias() {
        let q = select!("SELECT CASE WHEN x > 0 THEN 'pos' ELSE 'neg' END AS sign FROM test");
        let ColumnList::Named(exprs) = &q.columns else {
            panic!("Expected Named columns");
        };
        assert!(matches!(&exprs[0], SelectExpr::Expr {
            expr: Expr::Case { .. },
            alias: Some(a),
        } if a == "sign"));
    }
}