Skip to main content

mdql_core/
query_parser.rs

1//! Hand-written recursive descent parser for the MDQL SQL subset.
2
3use regex::Regex;
4use std::sync::LazyLock;
5
6use crate::errors::MdqlError;
7
8// ── AST nodes ──────────────────────────────────────────────────────────────
9
/// Binary arithmetic operators usable in expressions.
///
/// `Add`/`Sub` are parsed at additive precedence and `Mul`/`Div`/`Mod` at the
/// tighter multiplicative precedence (`Parser::parse_additive` /
/// `Parser::parse_multiplicative`).
#[derive(Debug, Clone, PartialEq)]
pub enum ArithOp {
    /// `+`
    Add,
    /// `-`
    Sub,
    /// `*`
    Mul,
    /// `/`
    Div,
    /// `%`
    Mod,
}
18
/// Scalar expression AST, produced by the expression parser for SELECT
/// items, aggregate arguments and CASE branches.
#[derive(Debug, Clone, PartialEq)]
pub enum Expr {
    /// A constant value.
    Literal(SqlValue),
    /// A column (field) referenced by name.
    Column(String),
    /// `left op right` arithmetic.
    BinaryOp { left: Box<Expr>, op: ArithOp, right: Box<Expr> },
    /// Arithmetic negation. The parser folds a minus applied directly to a
    /// numeric literal into the literal, so this normally wraps a
    /// non-literal operand; `- INTERVAL n DAY` also wraps its day count here.
    UnaryMinus(Box<Expr>),
    /// `CASE WHEN cond THEN expr ... [ELSE expr] END`; each WHEN condition is
    /// a boolean `WhereClause`, each result an `Expr`.
    Case { whens: Vec<(WhereClause, Box<Expr>)>, else_expr: Option<Box<Expr>> },
    /// date_expr + INTERVAL n DAY  (negative for subtraction)
    DateAdd { date: Box<Expr>, days: Box<Expr> },
    /// DATEDIFF(date1, date2) → integer days
    DateDiff { left: Box<Expr>, right: Box<Expr> },
    /// `CURRENT_DATE`
    CurrentDate,
    /// `CURRENT_TIMESTAMP`
    CurrentTimestamp,
}
34
35impl Expr {
36    /// If the expression is a simple column reference, return the name.
37    pub fn as_column(&self) -> Option<&str> {
38        if let Expr::Column(name) = self { Some(name) } else { None }
39    }
40
41    /// A display name for this expression (used as output column name).
42    pub fn display_name(&self) -> String {
43        match self {
44            Expr::Literal(SqlValue::Int(n)) => n.to_string(),
45            Expr::Literal(SqlValue::Float(f)) => f.to_string(),
46            Expr::Literal(SqlValue::String(s)) => format!("'{}'", s),
47            Expr::Literal(SqlValue::Null) => "NULL".to_string(),
48            Expr::Literal(SqlValue::List(_)) => "list".to_string(),
49            Expr::Column(name) => name.clone(),
50            Expr::BinaryOp { left, op, right } => {
51                let op_str = match op {
52                    ArithOp::Add => "+",
53                    ArithOp::Sub => "-",
54                    ArithOp::Mul => "*",
55                    ArithOp::Div => "/",
56                    ArithOp::Mod => "%",
57                };
58                format!("{} {} {}", left.display_name(), op_str, right.display_name())
59            }
60            Expr::UnaryMinus(inner) => format!("-{}", inner.display_name()),
61            Expr::Case { .. } => "CASE".to_string(),
62            Expr::DateAdd { date, days } => format!("DATE_ADD({}, {})", date.display_name(), days.display_name()),
63            Expr::DateDiff { left, right } => format!("DATEDIFF({}, {})", left.display_name(), right.display_name()),
64            Expr::CurrentDate => "CURRENT_DATE".to_string(),
65            Expr::CurrentTimestamp => "CURRENT_TIMESTAMP".to_string(),
66        }
67    }
68}
69
/// One `ORDER BY` sort key.
///
/// NOTE(review): this is built by `parse_order_by`, which is outside this
/// view; the field descriptions below are inferred from names — confirm.
#[derive(Debug, Clone, PartialEq)]
pub struct OrderSpec {
    /// Name (or textual form) of the sort key.
    pub column: String,
    /// Parsed sort expression, presumably present only when the key is not a
    /// bare column — confirm.
    pub expr: Option<Expr>,
    /// Sort descending when `true` (presumably set by `DESC`).
    pub descending: bool,
}
76
/// A single predicate: the leaf of a `WhereClause` tree, so it appears in
/// WHERE, HAVING and CASE WHEN conditions.
///
/// NOTE(review): the fields are filled in by `parse_comparison`, which is
/// outside this view. From the names it looks like simple `column op value`
/// predicates use `column`/`value`, while expression operands go in
/// `left_expr`/`right_expr` — confirm before relying on this.
#[derive(Debug, Clone, PartialEq)]
pub struct Comparison {
    /// Left-hand column name.
    pub column: String,
    /// Comparison operator as text (presumably `=`, `!=`, `<`, `LIKE`, `IN`,
    /// `IS`, ... — confirm the exact set).
    pub op: String,
    /// Right-hand literal, when the right side is a plain value.
    pub value: Option<SqlValue>,
    /// Left-hand expression, when the left side is not a bare column.
    pub left_expr: Option<Expr>,
    /// Right-hand expression, when the right side is not a plain value.
    pub right_expr: Option<Expr>,
}
85
/// Binary boolean combination of two conditions.
///
/// The parser builds chains of the same operator left-associatively, and
/// `AND` binds tighter than `OR` (`parse_or_expr` parses its operands with
/// `parse_and_expr`).
#[derive(Debug, Clone, PartialEq)]
pub struct BoolOp {
    pub op: String, // "AND" or "OR"
    /// Left operand.
    pub left: Box<WhereClause>,
    /// Right operand.
    pub right: Box<WhereClause>,
}
92
/// Boolean condition tree used by WHERE, HAVING and CASE WHEN.
#[derive(Debug, Clone, PartialEq)]
pub enum WhereClause {
    /// A single predicate.
    Comparison(Comparison),
    /// `AND` / `OR` of two sub-conditions.
    BoolOp(BoolOp),
}
98
/// A literal SQL value.
#[derive(Debug, Clone, PartialEq)]
pub enum SqlValue {
    /// Text from a single-quoted literal (quotes removed).
    String(String),
    /// Number literal without a decimal point.
    Int(i64),
    /// Number literal with a decimal point.
    Float(f64),
    /// `NULL`.
    Null,
    /// A list of values — presumably the operand of `IN (...)`; confirm
    /// against `parse_comparison` / `parse_value`.
    List(Vec<SqlValue>),
}
107
/// One `JOIN table [alias] ON left_col = right_col` clause.
///
/// Only a single equality between two column names is supported as the join
/// condition.
#[derive(Debug, Clone, PartialEq)]
pub struct JoinClause {
    /// Joined table name.
    pub table: String,
    /// Optional alias for the joined table.
    pub alias: Option<String>,
    /// Column on the left of the `=` in the ON condition.
    pub left_col: String,
    /// Column on the right of the `=` in the ON condition.
    pub right_col: String,
}
115
/// Supported aggregate functions (see `AGG_FUNCS`).
#[derive(Debug, Clone, PartialEq)]
pub enum AggFunc {
    /// `COUNT(...)`
    Count,
    /// `SUM(...)`
    Sum,
    /// `AVG(...)`
    Avg,
    /// `MIN(...)`
    Min,
    /// `MAX(...)`
    Max,
}
124
/// One item of a SELECT list.
#[derive(Debug, Clone, PartialEq)]
pub enum SelectExpr {
    /// A bare column reference with no alias.
    Column(String),
    /// An aggregate call.
    ///
    /// `arg` is `"*"` for `COUNT(*)`, the column name for a bare-column
    /// argument, or the expression's display name otherwise. `arg_expr` holds
    /// the parsed argument only in that last case. `alias` comes from `AS x`
    /// or an implicit trailing identifier.
    Aggregate { func: AggFunc, arg: String, arg_expr: Option<Expr>, alias: Option<String> },
    /// Any other expression, or an aliased bare column.
    Expr { expr: Expr, alias: Option<String> },
}
131
132impl SelectExpr {
133    pub fn output_name(&self) -> String {
134        match self {
135            SelectExpr::Column(name) => name.clone(),
136            SelectExpr::Aggregate { func, arg, alias, .. } => {
137                if let Some(a) = alias {
138                    a.clone()
139                } else {
140                    let func_name = match func {
141                        AggFunc::Count => "COUNT",
142                        AggFunc::Sum => "SUM",
143                        AggFunc::Avg => "AVG",
144                        AggFunc::Min => "MIN",
145                        AggFunc::Max => "MAX",
146                    };
147                    format!("{}({})", func_name, arg)
148                }
149            }
150            SelectExpr::Expr { expr, alias } => {
151                alias.clone().unwrap_or_else(|| expr.display_name())
152            }
153        }
154    }
155
156    pub fn is_aggregate(&self) -> bool {
157        matches!(self, SelectExpr::Aggregate { .. })
158    }
159}
160
/// A parsed `SELECT` statement.
#[derive(Debug, Clone, PartialEq)]
pub struct SelectQuery {
    /// `*` or the list of select items.
    pub columns: ColumnList,
    /// Table named after `FROM`.
    pub table: String,
    /// Optional alias of the FROM table.
    pub table_alias: Option<String>,
    /// JOIN clauses in source order (possibly empty).
    pub joins: Vec<JoinClause>,
    /// Condition after `WHERE`, if any.
    pub where_clause: Option<WhereClause>,
    /// Column names after `GROUP BY`, if present.
    pub group_by: Option<Vec<String>>,
    /// Condition after `HAVING`, if any.
    pub having: Option<WhereClause>,
    /// Sort keys after `ORDER BY`, if present.
    pub order_by: Option<Vec<OrderSpec>>,
    /// Integer after `LIMIT`, if present.
    pub limit: Option<i64>,
}
173
/// The SELECT list.
#[derive(Debug, Clone, PartialEq)]
pub enum ColumnList {
    /// `SELECT *`
    All,
    /// An explicit, comma-separated list of select items.
    Named(Vec<SelectExpr>),
}
179
/// A parsed `INSERT INTO table (cols...) VALUES (vals...)` statement.
///
/// `columns` and `values` are positionally paired; the parser rejects
/// statements where their lengths differ.
#[derive(Debug, Clone, PartialEq)]
pub struct InsertQuery {
    /// Target table.
    pub table: String,
    /// Column names in source order.
    pub columns: Vec<String>,
    /// Values in source order, one per column.
    pub values: Vec<SqlValue>,
}
186
/// A parsed `UPDATE table SET col = val, ... [WHERE cond]` statement.
#[derive(Debug, Clone, PartialEq)]
pub struct UpdateQuery {
    /// Target table.
    pub table: String,
    /// `(column, value)` pairs in the order written after `SET`.
    pub assignments: Vec<(String, SqlValue)>,
    /// Optional row filter.
    pub where_clause: Option<WhereClause>,
}
193
/// A parsed `DELETE FROM table [WHERE cond]` statement.
#[derive(Debug, Clone, PartialEq)]
pub struct DeleteQuery {
    /// Target table.
    pub table: String,
    /// Optional row filter.
    pub where_clause: Option<WhereClause>,
}
199
/// `ALTER TABLE t RENAME FIELD old TO new`.
///
/// Field names may be written as identifiers or as quoted strings.
#[derive(Debug, Clone, PartialEq)]
pub struct AlterRenameFieldQuery {
    /// Target table.
    pub table: String,
    /// Current field name.
    pub old_name: String,
    /// Replacement field name.
    pub new_name: String,
}
206
/// `ALTER TABLE t DROP FIELD name`.
///
/// The field name may be written as an identifier or as a quoted string.
#[derive(Debug, Clone, PartialEq)]
pub struct AlterDropFieldQuery {
    /// Target table.
    pub table: String,
    /// Field to remove.
    pub field_name: String,
}
212
/// `ALTER TABLE t MERGE FIELDS a, b, ... INTO target`.
///
/// Field names may be written as identifiers or as quoted strings.
#[derive(Debug, Clone, PartialEq)]
pub struct AlterMergeFieldsQuery {
    /// Target table.
    pub table: String,
    /// Source field names, in source order (at least one).
    pub sources: Vec<String>,
    /// Destination field name.
    pub into: String,
}
219
/// Any statement the parser can produce; dispatched on the leading keyword.
#[derive(Debug, Clone, PartialEq)]
pub enum Statement {
    /// `SELECT ...`
    Select(SelectQuery),
    /// `INSERT INTO ...`
    Insert(InsertQuery),
    /// `UPDATE ...`
    Update(UpdateQuery),
    /// `DELETE FROM ...`
    Delete(DeleteQuery),
    /// `ALTER TABLE ... RENAME FIELD ...`
    AlterRename(AlterRenameFieldQuery),
    /// `ALTER TABLE ... DROP FIELD ...`
    AlterDrop(AlterDropFieldQuery),
    /// `ALTER TABLE ... MERGE FIELDS ...`
    AlterMerge(AlterMergeFieldsQuery),
}
230
231// ── Tokenizer ──────────────────────────────────────────────────────────────
232
/// Words the tokenizer classifies as keywords.
///
/// A bare word whose upper-cased form appears here becomes a `"keyword"`
/// token whose value is the upper-cased text; every other word becomes an
/// `"ident"`. Aggregate function names are not listed here; they are
/// recognised separately via `AGG_FUNCS`.
static KEYWORDS: &[&str] = &[
    "SELECT", "FROM", "WHERE", "AND", "OR", "ORDER", "BY",
    "ASC", "DESC", "LIMIT", "LIKE", "IN", "IS", "NOT", "NULL",
    "JOIN", "ON", "AS", "GROUP", "HAVING",
    "INSERT", "INTO", "VALUES", "UPDATE", "SET", "DELETE",
    "ALTER", "TABLE", "RENAME", "FIELD", "TO", "DROP", "MERGE", "FIELDS",
    "CASE", "WHEN", "THEN", "ELSE", "END",
    "INTERVAL", "DAY", "DAYS", "CURRENT_DATE", "CURRENT_TIMESTAMP", "DATEDIFF",
];
242
/// Aggregate function names.
///
/// These are not keywords: `peek_is_agg_func` treats a token as an aggregate
/// call only when its upper-cased value is in this list and the next token is
/// `(`, so the same words stay usable as plain identifiers elsewhere.
static AGG_FUNCS: &[&str] = &["COUNT", "SUM", "AVG", "MIN", "MAX"];
244
245static TOKEN_RE: LazyLock<Regex> = LazyLock::new(|| {
246    Regex::new(
247        r#"(?x)
248        \s*(?:
249            (?P<backtick>`[^`]+`)
250            | (?P<string>'(?:[^'\\]|\\.)*')
251            | (?P<number>\d+(?:\.\d+)?)
252            | (?P<op><=|>=|!=|[=<>,*()+\-/%])
253            | (?P<word>[A-Za-z_][A-Za-z0-9_./-]*)
254        )"#,
255    )
256    .unwrap()
257});
258
/// A lexical token produced by `tokenize`.
#[derive(Debug, Clone)]
struct Token {
    /// One of `"ident"`, `"keyword"`, `"string"`, `"number"`, or `"op"`.
    token_type: String,
    /// Normalised text: backticks / surrounding quotes stripped, keywords
    /// upper-cased; otherwise the matched text.
    value: String,
    /// The exact matched source text; used when reporting parse errors.
    raw: String,
}
265
266fn tokenize(sql: &str) -> Vec<Token> {
267    let mut tokens = Vec::new();
268    for caps in TOKEN_RE.captures_iter(sql) {
269        if let Some(m) = caps.name("backtick") {
270            let raw = m.as_str();
271            tokens.push(Token {
272                token_type: "ident".into(),
273                value: raw[1..raw.len() - 1].into(),
274                raw: raw.into(),
275            });
276        } else if let Some(m) = caps.name("string") {
277            let raw = m.as_str();
278            tokens.push(Token {
279                token_type: "string".into(),
280                value: raw[1..raw.len() - 1].into(),
281                raw: raw.into(),
282            });
283        } else if let Some(m) = caps.name("number") {
284            let raw = m.as_str();
285            tokens.push(Token {
286                token_type: "number".into(),
287                value: raw.into(),
288                raw: raw.into(),
289            });
290        } else if let Some(m) = caps.name("op") {
291            let raw = m.as_str();
292            tokens.push(Token {
293                token_type: "op".into(),
294                value: raw.into(),
295                raw: raw.into(),
296            });
297        } else if let Some(m) = caps.name("word") {
298            let raw = m.as_str();
299            if KEYWORDS.contains(&raw.to_uppercase().as_str()) {
300                tokens.push(Token {
301                    token_type: "keyword".into(),
302                    value: raw.to_uppercase(),
303                    raw: raw.into(),
304                });
305            } else {
306                tokens.push(Token {
307                    token_type: "ident".into(),
308                    value: raw.into(),
309                    raw: raw.into(),
310                });
311            }
312        }
313    }
314    tokens
315}
316
317// ── Parser ─────────────────────────────────────────────────────────────────
318
/// Recursive-descent parser over a pre-tokenized query.
struct Parser {
    /// The full token stream (presumably produced by `tokenize`).
    tokens: Vec<Token>,
    /// Index of the next unconsumed token in `tokens`.
    pos: usize,
}
323
324impl Parser {
325    fn new(tokens: Vec<Token>) -> Self {
326        Parser { tokens, pos: 0 }
327    }
328
329    fn peek(&self) -> Option<&Token> {
330        self.tokens.get(self.pos)
331    }
332
333    fn advance(&mut self) -> Token {
334        let t = self.tokens[self.pos].clone();
335        self.pos += 1;
336        t
337    }
338
339    fn expect(&mut self, type_: &str, value: Option<&str>) -> Result<Token, MdqlError> {
340        let t = self.peek().ok_or_else(|| {
341            MdqlError::QueryParse(format!(
342                "Unexpected end of query, expected {}",
343                value.unwrap_or(type_)
344            ))
345        })?;
346        let matches_type = t.token_type == type_;
347        let matches_value = value.map_or(true, |v| t.value == v);
348        if !matches_type || !matches_value {
349            return Err(MdqlError::QueryParse(format!(
350                "Expected {}, got '{}' at position {}",
351                value.unwrap_or(type_),
352                t.raw,
353                self.pos
354            )));
355        }
356        Ok(self.advance())
357    }
358
359    fn match_keyword(&mut self, kw: &str) -> bool {
360        if let Some(t) = self.peek() {
361            if t.token_type == "keyword" && t.value == kw {
362                self.advance();
363                return true;
364            }
365        }
366        false
367    }
368
369    fn parse_statement(&mut self) -> Result<Statement, MdqlError> {
370        let t = self.peek().ok_or_else(|| MdqlError::QueryParse("Empty query".into()))?;
371        match (t.token_type.as_str(), t.value.as_str()) {
372            ("keyword", "SELECT") => Ok(Statement::Select(self.parse_select()?)),
373            ("keyword", "INSERT") => Ok(Statement::Insert(self.parse_insert()?)),
374            ("keyword", "UPDATE") => Ok(Statement::Update(self.parse_update()?)),
375            ("keyword", "DELETE") => Ok(Statement::Delete(self.parse_delete()?)),
376            ("keyword", "ALTER") => self.parse_alter(),
377            _ => Err(MdqlError::QueryParse(format!(
378                "Expected SELECT, INSERT, UPDATE, DELETE, or ALTER, got '{}'",
379                t.raw
380            ))),
381        }
382    }
383
384    fn parse_select(&mut self) -> Result<SelectQuery, MdqlError> {
385        self.expect("keyword", Some("SELECT"))?;
386        let columns = self.parse_columns()?;
387        self.expect("keyword", Some("FROM"))?;
388        let table = self.parse_ident()?;
389
390        // Optional table alias
391        let mut table_alias = None;
392        if let Some(t) = self.peek() {
393            if t.token_type == "ident" && !self.is_clause_keyword(t) {
394                table_alias = Some(self.advance().value);
395            }
396        }
397
398        // Optional JOIN(s)
399        let mut joins = Vec::new();
400        while self.match_keyword("JOIN") {
401            let join_table = self.parse_ident()?;
402            let mut join_alias = None;
403            if let Some(t) = self.peek() {
404                if t.token_type == "ident" && !self.is_clause_keyword(t) {
405                    join_alias = Some(self.advance().value);
406                }
407            }
408            self.expect("keyword", Some("ON"))?;
409            let left_col = self.parse_ident()?;
410            self.expect("op", Some("="))?;
411            let right_col = self.parse_ident()?;
412            joins.push(JoinClause {
413                table: join_table,
414                alias: join_alias,
415                left_col,
416                right_col,
417            });
418        }
419
420        let mut where_clause = None;
421        if self.match_keyword("WHERE") {
422            where_clause = Some(self.parse_or_expr()?);
423        }
424
425        let mut group_by = None;
426        if self.match_keyword("GROUP") {
427            self.expect("keyword", Some("BY"))?;
428            let mut cols = vec![self.parse_ident()?];
429            while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
430                self.advance();
431                cols.push(self.parse_ident()?);
432            }
433            group_by = Some(cols);
434        }
435
436        let mut having = None;
437        if self.match_keyword("HAVING") {
438            having = Some(self.parse_or_expr()?);
439        }
440
441        let mut order_by = None;
442        if self.match_keyword("ORDER") {
443            self.expect("keyword", Some("BY"))?;
444            order_by = Some(self.parse_order_by()?);
445        }
446
447        let mut limit = None;
448        if self.match_keyword("LIMIT") {
449            let t = self.expect("number", None)?;
450            limit = Some(t.value.parse::<i64>().map_err(|_| {
451                MdqlError::QueryParse(format!("Invalid LIMIT value: {}", t.value))
452            })?);
453        }
454
455        self.expect_end()?;
456
457        Ok(SelectQuery {
458            columns,
459            table,
460            table_alias,
461            joins,
462            where_clause,
463            group_by,
464            having,
465            order_by,
466            limit,
467        })
468    }
469
470    fn parse_insert(&mut self) -> Result<InsertQuery, MdqlError> {
471        self.expect("keyword", Some("INSERT"))?;
472        self.expect("keyword", Some("INTO"))?;
473        let table = self.parse_ident()?;
474
475        self.expect("op", Some("("))?;
476        let mut columns = vec![self.parse_ident()?];
477        while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
478            self.advance();
479            columns.push(self.parse_ident()?);
480        }
481        self.expect("op", Some(")"))?;
482
483        self.expect("keyword", Some("VALUES"))?;
484
485        self.expect("op", Some("("))?;
486        let mut values = vec![self.parse_value()?];
487        while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
488            self.advance();
489            values.push(self.parse_value()?);
490        }
491        self.expect("op", Some(")"))?;
492
493        if columns.len() != values.len() {
494            return Err(MdqlError::QueryParse(format!(
495                "Column count ({}) does not match value count ({})",
496                columns.len(),
497                values.len()
498            )));
499        }
500
501        self.expect_end()?;
502        Ok(InsertQuery {
503            table,
504            columns,
505            values,
506        })
507    }
508
509    fn parse_update(&mut self) -> Result<UpdateQuery, MdqlError> {
510        self.expect("keyword", Some("UPDATE"))?;
511        let table = self.parse_ident()?;
512        self.expect("keyword", Some("SET"))?;
513
514        let mut assignments = Vec::new();
515        let col = self.parse_ident()?;
516        self.expect("op", Some("="))?;
517        let val = self.parse_value()?;
518        assignments.push((col, val));
519
520        while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
521            self.advance();
522            let col = self.parse_ident()?;
523            self.expect("op", Some("="))?;
524            let val = self.parse_value()?;
525            assignments.push((col, val));
526        }
527
528        let mut where_clause = None;
529        if self.match_keyword("WHERE") {
530            where_clause = Some(self.parse_or_expr()?);
531        }
532
533        self.expect_end()?;
534        Ok(UpdateQuery {
535            table,
536            assignments,
537            where_clause,
538        })
539    }
540
541    fn parse_delete(&mut self) -> Result<DeleteQuery, MdqlError> {
542        self.expect("keyword", Some("DELETE"))?;
543        self.expect("keyword", Some("FROM"))?;
544        let table = self.parse_ident()?;
545
546        let mut where_clause = None;
547        if self.match_keyword("WHERE") {
548            where_clause = Some(self.parse_or_expr()?);
549        }
550
551        self.expect_end()?;
552        Ok(DeleteQuery {
553            table,
554            where_clause,
555        })
556    }
557
558    fn parse_alter(&mut self) -> Result<Statement, MdqlError> {
559        self.expect("keyword", Some("ALTER"))?;
560        self.expect("keyword", Some("TABLE"))?;
561        let table = self.parse_ident()?;
562
563        let t = self.peek().ok_or_else(|| {
564            MdqlError::QueryParse("Expected RENAME, DROP, or MERGE after table name".into())
565        })?;
566
567        match (t.token_type.as_str(), t.value.as_str()) {
568            ("keyword", "RENAME") => {
569                self.advance();
570                self.expect("keyword", Some("FIELD"))?;
571                let old_name = self.parse_string_or_ident()?;
572                self.expect("keyword", Some("TO"))?;
573                let new_name = self.parse_string_or_ident()?;
574                self.expect_end()?;
575                Ok(Statement::AlterRename(AlterRenameFieldQuery {
576                    table,
577                    old_name,
578                    new_name,
579                }))
580            }
581            ("keyword", "DROP") => {
582                self.advance();
583                self.expect("keyword", Some("FIELD"))?;
584                let field_name = self.parse_string_or_ident()?;
585                self.expect_end()?;
586                Ok(Statement::AlterDrop(AlterDropFieldQuery {
587                    table,
588                    field_name,
589                }))
590            }
591            ("keyword", "MERGE") => {
592                self.advance();
593                self.expect("keyword", Some("FIELDS"))?;
594                let mut sources = vec![self.parse_string_or_ident()?];
595                while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
596                    self.advance();
597                    sources.push(self.parse_string_or_ident()?);
598                }
599                self.expect("keyword", Some("INTO"))?;
600                let target = self.parse_string_or_ident()?;
601                self.expect_end()?;
602                Ok(Statement::AlterMerge(AlterMergeFieldsQuery {
603                    table,
604                    sources,
605                    into: target,
606                }))
607            }
608            _ => Err(MdqlError::QueryParse(format!(
609                "Expected RENAME, DROP, or MERGE, got '{}'",
610                t.raw
611            ))),
612        }
613    }
614
615    fn parse_string_or_ident(&mut self) -> Result<String, MdqlError> {
616        let t = self.peek().ok_or_else(|| {
617            MdqlError::QueryParse("Expected field name, got end of query".into())
618        })?;
619        match t.token_type.as_str() {
620            "string" => {
621                let v = self.advance().value;
622                Ok(v)
623            }
624            "ident" | "keyword" => {
625                let v = self.advance().value;
626                Ok(v)
627            }
628            _ => Err(MdqlError::QueryParse(format!(
629                "Expected field name, got '{}'",
630                t.raw
631            ))),
632        }
633    }
634
635    fn parse_columns(&mut self) -> Result<ColumnList, MdqlError> {
636        if let Some(t) = self.peek() {
637            if t.token_type == "op" && t.value == "*" {
638                self.advance();
639                return Ok(ColumnList::All);
640            }
641        }
642
643        let mut exprs = vec![self.parse_select_expr()?];
644        while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
645            self.advance();
646            exprs.push(self.parse_select_expr()?);
647        }
648        Ok(ColumnList::Named(exprs))
649    }
650
651    fn peek_is_agg_func(&self) -> bool {
652        let t = match self.peek() {
653            Some(t) => t,
654            None => return false,
655        };
656        let name_upper = t.value.to_uppercase();
657        if !AGG_FUNCS.contains(&name_upper.as_str()) {
658            return false;
659        }
660        // Only treat as aggregate if followed by (
661        self.tokens
662            .get(self.pos + 1)
663            .map_or(false, |next| next.token_type == "op" && next.value == "(")
664    }
665
666    fn parse_select_expr(&mut self) -> Result<SelectExpr, MdqlError> {
667        let _t = self.peek().ok_or_else(|| {
668            MdqlError::QueryParse("Expected column or aggregate, got end of query".into())
669        })?;
670
671        if self.peek_is_agg_func() {
672            let func_name = self.advance().value.to_uppercase();
673            let func = match func_name.as_str() {
674                "COUNT" => AggFunc::Count,
675                "SUM" => AggFunc::Sum,
676                "AVG" => AggFunc::Avg,
677                "MIN" => AggFunc::Min,
678                "MAX" => AggFunc::Max,
679                _ => unreachable!(),
680            };
681            self.expect("op", Some("("))?;
682            let (arg, arg_expr) = if self.peek().map_or(false, |t| t.token_type == "op" && t.value == "*") {
683                self.advance();
684                ("*".to_string(), None)
685            } else {
686                // Parse a full expression inside the aggregate (supports CASE WHEN, arithmetic, etc.)
687                let expr = self.parse_additive()?;
688                if let Expr::Column(name) = &expr {
689                    (name.clone(), None)
690                } else {
691                    (expr.display_name(), Some(expr))
692                }
693            };
694            self.expect("op", Some(")"))?;
695
696            let alias = if self.match_keyword("AS") {
697                Some(self.parse_ident()?)
698            } else if self.peek().map_or(false, |t| {
699                t.token_type == "ident" && !self.is_clause_keyword(t)
700            }) {
701                Some(self.advance().value)
702            } else {
703                None
704            };
705
706            Ok(SelectExpr::Aggregate { func, arg, arg_expr, alias })
707        } else {
708            // Parse a general expression (could be a column, literal, or arithmetic)
709            let expr = self.parse_additive()?;
710
711            // Optional alias: explicit (AS alias) or implicit (just an ident)
712            let alias = if self.match_keyword("AS") {
713                Some(self.parse_ident()?)
714            } else if self.peek().map_or(false, |t| {
715                t.token_type == "ident" && !self.is_clause_keyword(t)
716            }) {
717                Some(self.advance().value)
718            } else {
719                None
720            };
721
722            // If it's a simple column reference with no alias, return Column variant
723            // for backward compatibility
724            if alias.is_none() {
725                if let Expr::Column(name) = &expr {
726                    return Ok(SelectExpr::Column(name.clone()));
727                }
728            }
729
730            Ok(SelectExpr::Expr { expr, alias })
731        }
732    }
733
734    // ── Expression parser (precedence climbing) ───────────────────────
735
736    fn peek_is_additive_op(&self) -> bool {
737        self.peek().map_or(false, |t| {
738            t.token_type == "op" && (t.value == "+" || t.value == "-")
739        })
740    }
741
742    fn peek_is_multiplicative_op(&self) -> bool {
743        self.peek().map_or(false, |t| {
744            t.token_type == "op" && (t.value == "*" || t.value == "/" || t.value == "%")
745        })
746    }
747
748    fn parse_additive(&mut self) -> Result<Expr, MdqlError> {
749        let mut left = self.parse_multiplicative()?;
750        while self.peek_is_additive_op() {
751            let op_tok = self.advance();
752            let is_sub = op_tok.value == "-";
753
754            // Check for INTERVAL keyword: expr +/- INTERVAL n DAY
755            if self.peek().map_or(false, |t| t.token_type == "keyword" && t.value == "INTERVAL") {
756                self.advance(); // consume INTERVAL
757                let days_expr = self.parse_multiplicative()?;
758                // Expect DAY or DAYS
759                if !self.match_keyword("DAY") && !self.match_keyword("DAYS") {
760                    return Err(MdqlError::QueryParse("Expected DAY after INTERVAL value".into()));
761                }
762                let days = if is_sub {
763                    Expr::UnaryMinus(Box::new(days_expr))
764                } else {
765                    days_expr
766                };
767                left = Expr::DateAdd {
768                    date: Box::new(left),
769                    days: Box::new(days),
770                };
771                continue;
772            }
773
774            let op = match op_tok.value.as_str() {
775                "+" => ArithOp::Add,
776                "-" => ArithOp::Sub,
777                _ => unreachable!(),
778            };
779            let right = self.parse_multiplicative()?;
780            left = Expr::BinaryOp {
781                left: Box::new(left),
782                op,
783                right: Box::new(right),
784            };
785        }
786        Ok(left)
787    }
788
789    fn parse_multiplicative(&mut self) -> Result<Expr, MdqlError> {
790        let mut left = self.parse_unary()?;
791        while self.peek_is_multiplicative_op() {
792            let op_tok = self.advance();
793            let op = match op_tok.value.as_str() {
794                "*" => ArithOp::Mul,
795                "/" => ArithOp::Div,
796                "%" => ArithOp::Mod,
797                _ => unreachable!(),
798            };
799            let right = self.parse_unary()?;
800            left = Expr::BinaryOp {
801                left: Box::new(left),
802                op,
803                right: Box::new(right),
804            };
805        }
806        Ok(left)
807    }
808
809    fn parse_unary(&mut self) -> Result<Expr, MdqlError> {
810        if self.peek().map_or(false, |t| t.token_type == "op" && t.value == "-") {
811            self.advance();
812            let inner = self.parse_atom()?;
813            // Fold unary minus on literals
814            match inner {
815                Expr::Literal(SqlValue::Int(n)) => Ok(Expr::Literal(SqlValue::Int(-n))),
816                Expr::Literal(SqlValue::Float(f)) => Ok(Expr::Literal(SqlValue::Float(-f))),
817                _ => Ok(Expr::UnaryMinus(Box::new(inner))),
818            }
819        } else {
820            self.parse_atom()
821        }
822    }
823
824    fn parse_atom(&mut self) -> Result<Expr, MdqlError> {
825        let t = self.peek().ok_or_else(|| {
826            MdqlError::QueryParse("Expected expression, got end of query".into())
827        })?;
828
829        match t.token_type.as_str() {
830            "number" => {
831                let v = self.advance().value;
832                if v.contains('.') {
833                    let f: f64 = v.parse().map_err(|_| {
834                        MdqlError::QueryParse(format!("Invalid float: {}", v))
835                    })?;
836                    Ok(Expr::Literal(SqlValue::Float(f)))
837                } else {
838                    let n: i64 = v.parse().map_err(|_| {
839                        MdqlError::QueryParse(format!("Invalid int: {}", v))
840                    })?;
841                    Ok(Expr::Literal(SqlValue::Int(n)))
842                }
843            }
844            "string" => {
845                let v = self.advance().value;
846                Ok(Expr::Literal(SqlValue::String(v)))
847            }
848            "keyword" if t.value == "NULL" => {
849                self.advance();
850                Ok(Expr::Literal(SqlValue::Null))
851            }
852            "keyword" if t.value == "CASE" => {
853                self.parse_case_expr()
854            }
855            "keyword" if t.value == "CURRENT_DATE" => {
856                self.advance();
857                Ok(Expr::CurrentDate)
858            }
859            "keyword" if t.value == "CURRENT_TIMESTAMP" => {
860                self.advance();
861                Ok(Expr::CurrentTimestamp)
862            }
863            "keyword" if t.value == "DATEDIFF" => {
864                self.advance();
865                self.expect("op", Some("("))?;
866                let left = self.parse_additive()?;
867                self.expect("op", Some(","))?;
868                let right = self.parse_additive()?;
869                self.expect("op", Some(")"))?;
870                Ok(Expr::DateDiff { left: Box::new(left), right: Box::new(right) })
871            }
872            "op" if t.value == "(" => {
873                self.advance();
874                let expr = self.parse_additive()?;
875                self.expect("op", Some(")"))?;
876                Ok(expr)
877            }
878            "ident" => {
879                let name = self.advance().value;
880                Ok(Expr::Column(name))
881            }
882            "keyword" if !Self::is_reserved_keyword(&t.value) => {
883                let name = self.advance().value;
884                Ok(Expr::Column(name))
885            }
886            _ => Err(MdqlError::QueryParse(format!(
887                "Expected expression, got '{}'",
888                t.raw
889            ))),
890        }
891    }
892
893    fn parse_case_expr(&mut self) -> Result<Expr, MdqlError> {
894        self.expect("keyword", Some("CASE"))?;
895        let mut whens = Vec::new();
896        while self.match_keyword("WHEN") {
897            let condition = self.parse_or_expr()?;
898            self.expect("keyword", Some("THEN"))?;
899            let result = self.parse_additive()?;
900            whens.push((condition, Box::new(result)));
901        }
902        if whens.is_empty() {
903            return Err(MdqlError::QueryParse("CASE requires at least one WHEN clause".into()));
904        }
905        let else_expr = if self.match_keyword("ELSE") {
906            Some(Box::new(self.parse_additive()?))
907        } else {
908            None
909        };
910        self.expect("keyword", Some("END"))?;
911        Ok(Expr::Case { whens, else_expr })
912    }
913
914    fn parse_ident(&mut self) -> Result<String, MdqlError> {
915        let t = self.peek().ok_or_else(|| {
916            MdqlError::QueryParse("Expected identifier, got end of query".into())
917        })?;
918        match t.token_type.as_str() {
919            "ident" | "keyword" => {
920                let v = self.advance().value;
921                Ok(v)
922            }
923            _ => Err(MdqlError::QueryParse(format!(
924                "Expected identifier, got '{}'",
925                t.raw
926            ))),
927        }
928    }
929
930    fn parse_or_expr(&mut self) -> Result<WhereClause, MdqlError> {
931        let mut left = self.parse_and_expr()?;
932        while self.match_keyword("OR") {
933            let right = self.parse_and_expr()?;
934            left = WhereClause::BoolOp(BoolOp {
935                op: "OR".into(),
936                left: Box::new(left),
937                right: Box::new(right),
938            });
939        }
940        Ok(left)
941    }
942
943    fn parse_and_expr(&mut self) -> Result<WhereClause, MdqlError> {
944        let mut left = self.parse_comparison()?;
945        while self.match_keyword("AND") {
946            let right = self.parse_comparison()?;
947            left = WhereClause::BoolOp(BoolOp {
948                op: "AND".into(),
949                left: Box::new(left),
950                right: Box::new(right),
951            });
952        }
953        Ok(left)
954    }
955
    /// Parses one predicate. It is reached through `parse_and_expr` /
    /// `parse_or_expr`, e.g. for WHERE conditions and CASE WHEN branches.
    ///
    /// Accepted forms (the left side may be any arithmetic expression):
    /// - `( <or-expr> )` — a parenthesised boolean group
    /// - `expr IS NULL` / `expr IS NOT NULL`
    /// - `expr IN (v, v, ...)` — list items are literal values only
    /// - `expr LIKE v` / `expr NOT LIKE v` — the pattern is a literal value
    /// - `expr <op> expr` for `=`, `!=`, `<`, `>`, `<=`, `>=`
    ///
    /// Every returned `Comparison` carries `left_expr`. For backward
    /// compatibility, `column` holds the bare column name when the left side is
    /// a simple column reference (empty string otherwise). `value` holds the
    /// literal for `IN`/`LIKE`, and for binary operators only when the right
    /// side is a plain literal.
    ///
    /// # Errors
    /// Returns `MdqlError::QueryParse` when no recognised operator follows the
    /// left expression, when `NOT` is not followed by `LIKE`, or when a
    /// sub-parse fails.
    fn parse_comparison(&mut self) -> Result<WhereClause, MdqlError> {
        // Handle parenthesized boolean expressions.
        // A leading '(' is ambiguous: `(a = 1 OR b = 2)` is a boolean group,
        // while `(a + b) > 3` is arithmetic. Try the boolean reading first.
        if self.peek().map_or(false, |t| t.token_type == "op" && t.value == "(") {
            // Save position — might be arithmetic parens, not boolean
            let saved_pos = self.pos;
            self.advance();
            // Try parsing as boolean (OR/AND) expression
            let result = self.parse_or_expr();
            // Accept only a successful parse that is closed by ')'.
            // Any error from this speculative attempt is discarded.
            if result.is_ok() && self.peek().map_or(false, |t| t.token_type == "op" && t.value == ")") {
                self.advance();
                return result;
            }
            // Not a boolean paren — rewind and parse as arithmetic expression
            self.pos = saved_pos;
        }

        // Parse the left side as a full expression (column, literal, or arithmetic)
        let left_expr = self.parse_additive()?;

        // Extract column name for backward compat (simple column on left side);
        // empty string when the left side is anything other than a bare column.
        let col = left_expr.as_column().unwrap_or("").to_string();

        // IS NULL / IS NOT NULL. The parser accepts any left expression here.
        // NOTE(review): earlier wording said this is "only valid with simple
        // column" — presumably an evaluator restriction; confirm against the
        // executor.
        if self.match_keyword("IS") {
            if self.match_keyword("NOT") {
                self.expect("keyword", Some("NULL"))?;
                return Ok(WhereClause::Comparison(Comparison {
                    column: col,
                    op: "IS NOT NULL".into(),
                    value: None,
                    left_expr: Some(left_expr),
                    right_expr: None,
                }));
            }
            self.expect("keyword", Some("NULL"))?;
            return Ok(WhereClause::Comparison(Comparison {
                column: col,
                op: "IS NULL".into(),
                value: None,
                left_expr: Some(left_expr),
                right_expr: None,
            }));
        }

        // IN (val, val, ...) — items go through parse_value, so they must be
        // literals (string / number / NULL), not expressions.
        if self.match_keyword("IN") {
            self.expect("op", Some("("))?;
            let mut values = vec![self.parse_value()?];
            while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
                self.advance();
                values.push(self.parse_value()?);
            }
            self.expect("op", Some(")"))?;
            return Ok(WhereClause::Comparison(Comparison {
                column: col,
                op: "IN".into(),
                value: Some(SqlValue::List(values)),
                left_expr: Some(left_expr),
                right_expr: None,
            }));
        }

        // LIKE — the pattern is a single literal value.
        if self.match_keyword("LIKE") {
            let val = self.parse_value()?;
            return Ok(WhereClause::Comparison(Comparison {
                column: col,
                op: "LIKE".into(),
                value: Some(val),
                left_expr: Some(left_expr),
                right_expr: None,
            }));
        }

        // NOT LIKE — after a left expression, NOT is only accepted as the
        // start of NOT LIKE (NOT IN is not supported here).
        if self.match_keyword("NOT") {
            if self.match_keyword("LIKE") {
                let val = self.parse_value()?;
                return Ok(WhereClause::Comparison(Comparison {
                    column: col,
                    op: "NOT LIKE".into(),
                    value: Some(val),
                    left_expr: Some(left_expr),
                    right_expr: None,
                }));
            }
            return Err(MdqlError::QueryParse("Expected LIKE after NOT".into()));
        }

        // Standard comparison operators
        if let Some(t) = self.peek() {
            if t.token_type == "op" && ["=", "!=", "<", ">", "<=", ">="].contains(&t.value.as_str())
            {
                let op = self.advance().value;
                // Parse right side as expression
                let right_expr = self.parse_additive()?;
                // Extract SqlValue for backward compat (simple literal on right side)
                let value = match &right_expr {
                    Expr::Literal(v) => Some(v.clone()),
                    _ => None,
                };
                return Ok(WhereClause::Comparison(Comparison {
                    column: col,
                    op,
                    value,
                    left_expr: Some(left_expr),
                    right_expr: Some(right_expr),
                }));
            }
        }

        // Nothing recognised after the left expression: report what was found
        // (or "end" if the input is exhausted).
        let got = self.peek().map_or("end".to_string(), |t| t.raw.clone());
        Err(MdqlError::QueryParse(format!(
            "Expected operator after '{}', got '{}'",
            left_expr.display_name(), got
        )))
    }
1073
1074    fn parse_value(&mut self) -> Result<SqlValue, MdqlError> {
1075        let t = self.peek().ok_or_else(|| {
1076            MdqlError::QueryParse("Expected value, got end of query".into())
1077        })?;
1078        match t.token_type.as_str() {
1079            "string" => {
1080                let v = self.advance().value;
1081                Ok(SqlValue::String(v))
1082            }
1083            "number" => {
1084                let v = self.advance().value;
1085                if v.contains('.') {
1086                    Ok(SqlValue::Float(v.parse().map_err(|_| {
1087                        MdqlError::QueryParse(format!("Invalid float: {}", v))
1088                    })?))
1089                } else {
1090                    Ok(SqlValue::Int(v.parse().map_err(|_| {
1091                        MdqlError::QueryParse(format!("Invalid int: {}", v))
1092                    })?))
1093                }
1094            }
1095            "keyword" if t.value == "NULL" => {
1096                self.advance();
1097                Ok(SqlValue::Null)
1098            }
1099            _ => Err(MdqlError::QueryParse(format!(
1100                "Expected value, got '{}'",
1101                t.raw
1102            ))),
1103        }
1104    }
1105
1106    fn parse_order_by(&mut self) -> Result<Vec<OrderSpec>, MdqlError> {
1107        let mut specs = vec![self.parse_order_spec()?];
1108        while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
1109            self.advance();
1110            specs.push(self.parse_order_spec()?);
1111        }
1112        Ok(specs)
1113    }
1114
1115    fn parse_order_spec(&mut self) -> Result<OrderSpec, MdqlError> {
1116        let expr = self.parse_additive()?;
1117        let col = expr.as_column().unwrap_or("").to_string();
1118        let descending = if self.match_keyword("DESC") {
1119            true
1120        } else {
1121            self.match_keyword("ASC");
1122            false
1123        };
1124        Ok(OrderSpec {
1125            column: col,
1126            expr: Some(expr),
1127            descending,
1128        })
1129    }
1130
1131    fn is_clause_keyword(&self, t: &Token) -> bool {
1132        t.token_type == "keyword"
1133            && ["WHERE", "ORDER", "LIMIT", "JOIN", "ON", "GROUP"].contains(&t.value.as_str())
1134    }
1135
1136    /// Keywords that should never be consumed as column names inside expressions.
1137    fn is_reserved_keyword(kw: &str) -> bool {
1138        matches!(kw,
1139            "AS" | "FROM" | "WHERE" | "AND" | "OR" | "ORDER" | "BY"
1140            | "ASC" | "DESC" | "LIMIT" | "JOIN" | "ON" | "GROUP"
1141            | "SELECT" | "INSERT" | "INTO" | "VALUES" | "UPDATE" | "SET"
1142            | "DELETE" | "ALTER" | "TABLE" | "IS" | "NOT" | "IN" | "LIKE"
1143            | "RENAME" | "FIELD" | "TO" | "DROP" | "MERGE" | "FIELDS"
1144            | "CASE" | "WHEN" | "THEN" | "ELSE" | "END"
1145            | "HAVING" | "INTERVAL" | "DAY" | "DAYS"
1146            | "CURRENT_DATE" | "CURRENT_TIMESTAMP" | "DATEDIFF"
1147        )
1148    }
1149
1150    fn expect_end(&self) -> Result<(), MdqlError> {
1151        if let Some(t) = self.peek() {
1152            return Err(MdqlError::QueryParse(format!(
1153                "Unexpected token '{}' at position {}",
1154                t.raw, self.pos
1155            )));
1156        }
1157        Ok(())
1158    }
1159}
1160
1161pub fn parse_query(sql: &str) -> crate::errors::Result<Statement> {
1162    let tokens = tokenize(sql);
1163    if tokens.is_empty() {
1164        return Err(MdqlError::QueryParse("Empty query".into()));
1165    }
1166    let mut parser = Parser::new(tokens);
1167    parser.parse_statement()
1168}
1169
1170#[cfg(test)]
1171mod tests {
1172    use super::*;
1173
1174    #[test]
1175    fn test_simple_select() {
1176        let stmt = parse_query("SELECT title, status FROM strategies").unwrap();
1177        if let Statement::Select(q) = stmt {
1178            assert_eq!(q.columns, ColumnList::Named(vec![SelectExpr::Column("title".into()), SelectExpr::Column("status".into())]));
1179            assert_eq!(q.table, "strategies");
1180        } else {
1181            panic!("Expected Select");
1182        }
1183    }
1184
1185    #[test]
1186    fn test_select_star() {
1187        let stmt = parse_query("SELECT * FROM test").unwrap();
1188        if let Statement::Select(q) = stmt {
1189            assert_eq!(q.columns, ColumnList::All);
1190        } else {
1191            panic!("Expected Select");
1192        }
1193    }
1194
1195    #[test]
1196    fn test_where_clause() {
1197        let stmt = parse_query("SELECT title FROM test WHERE count > 5").unwrap();
1198        if let Statement::Select(q) = stmt {
1199            assert!(q.where_clause.is_some());
1200        } else {
1201            panic!("Expected Select");
1202        }
1203    }
1204
1205    #[test]
1206    fn test_order_by() {
1207        let stmt =
1208            parse_query("SELECT title FROM test ORDER BY composite DESC, title ASC").unwrap();
1209        if let Statement::Select(q) = stmt {
1210            let ob = q.order_by.unwrap();
1211            assert_eq!(ob.len(), 2);
1212            assert!(ob[0].descending);
1213            assert!(!ob[1].descending);
1214        } else {
1215            panic!("Expected Select");
1216        }
1217    }
1218
1219    #[test]
1220    fn test_limit() {
1221        let stmt = parse_query("SELECT * FROM test LIMIT 10").unwrap();
1222        if let Statement::Select(q) = stmt {
1223            assert_eq!(q.limit, Some(10));
1224        } else {
1225            panic!("Expected Select");
1226        }
1227    }
1228
1229    #[test]
1230    fn test_insert() {
1231        let stmt = parse_query(
1232            "INSERT INTO test (title, count) VALUES ('Hello', 42)",
1233        )
1234        .unwrap();
1235        if let Statement::Insert(q) = stmt {
1236            assert_eq!(q.table, "test");
1237            assert_eq!(q.columns, vec!["title", "count"]);
1238            assert_eq!(q.values[0], SqlValue::String("Hello".into()));
1239            assert_eq!(q.values[1], SqlValue::Int(42));
1240        } else {
1241            panic!("Expected Insert");
1242        }
1243    }
1244
1245    #[test]
1246    fn test_update() {
1247        let stmt = parse_query("UPDATE test SET status = 'KILLED' WHERE path = 'a.md'").unwrap();
1248        if let Statement::Update(q) = stmt {
1249            assert_eq!(q.table, "test");
1250            assert_eq!(q.assignments.len(), 1);
1251            assert!(q.where_clause.is_some());
1252        } else {
1253            panic!("Expected Update");
1254        }
1255    }
1256
1257    #[test]
1258    fn test_delete() {
1259        let stmt = parse_query("DELETE FROM test WHERE status = 'draft'").unwrap();
1260        if let Statement::Delete(q) = stmt {
1261            assert_eq!(q.table, "test");
1262            assert!(q.where_clause.is_some());
1263        } else {
1264            panic!("Expected Delete");
1265        }
1266    }
1267
1268    #[test]
1269    fn test_alter_rename() {
1270        let stmt =
1271            parse_query("ALTER TABLE test RENAME FIELD 'Summary' TO 'Overview'").unwrap();
1272        if let Statement::AlterRename(q) = stmt {
1273            assert_eq!(q.old_name, "Summary");
1274            assert_eq!(q.new_name, "Overview");
1275        } else {
1276            panic!("Expected AlterRename");
1277        }
1278    }
1279
1280    #[test]
1281    fn test_alter_drop() {
1282        let stmt = parse_query("ALTER TABLE test DROP FIELD 'Details'").unwrap();
1283        if let Statement::AlterDrop(q) = stmt {
1284            assert_eq!(q.field_name, "Details");
1285        } else {
1286            panic!("Expected AlterDrop");
1287        }
1288    }
1289
1290    #[test]
1291    fn test_alter_merge() {
1292        let stmt = parse_query(
1293            "ALTER TABLE test MERGE FIELDS 'Entry Rules', 'Exit Rules' INTO 'Trading Rules'",
1294        )
1295        .unwrap();
1296        if let Statement::AlterMerge(q) = stmt {
1297            assert_eq!(q.sources, vec!["Entry Rules", "Exit Rules"]);
1298            assert_eq!(q.into, "Trading Rules");
1299        } else {
1300            panic!("Expected AlterMerge");
1301        }
1302    }
1303
1304    #[test]
1305    fn test_backtick_ident() {
1306        let stmt = parse_query("SELECT `Structural Mechanism` FROM test").unwrap();
1307        if let Statement::Select(q) = stmt {
1308            assert_eq!(
1309                q.columns,
1310                ColumnList::Named(vec![SelectExpr::Column("Structural Mechanism".into())])
1311            );
1312        } else {
1313            panic!("Expected Select");
1314        }
1315    }
1316
1317    #[test]
1318    fn test_like_operator() {
1319        let stmt = parse_query("SELECT title FROM test WHERE categories LIKE '%defi%'").unwrap();
1320        if let Statement::Select(q) = stmt {
1321            if let Some(WhereClause::Comparison(c)) = q.where_clause {
1322                assert_eq!(c.op, "LIKE");
1323                assert_eq!(c.value, Some(SqlValue::String("%defi%".into())));
1324            } else {
1325                panic!("Expected LIKE comparison");
1326            }
1327        } else {
1328            panic!("Expected Select");
1329        }
1330    }
1331
1332    #[test]
1333    fn test_in_operator() {
1334        let stmt =
1335            parse_query("SELECT * FROM test WHERE status IN ('ACTIVE', 'LIVE')").unwrap();
1336        if let Statement::Select(q) = stmt {
1337            if let Some(WhereClause::Comparison(c)) = q.where_clause {
1338                assert_eq!(c.op, "IN");
1339            } else {
1340                panic!("Expected IN comparison");
1341            }
1342        } else {
1343            panic!("Expected Select");
1344        }
1345    }
1346
1347    #[test]
1348    fn test_is_null() {
1349        let stmt = parse_query("SELECT * FROM test WHERE title IS NULL").unwrap();
1350        if let Statement::Select(q) = stmt {
1351            if let Some(WhereClause::Comparison(c)) = q.where_clause {
1352                assert_eq!(c.op, "IS NULL");
1353            } else {
1354                panic!("Expected IS NULL comparison");
1355            }
1356        } else {
1357            panic!("Expected Select");
1358        }
1359    }
1360
1361    #[test]
1362    fn test_and_or() {
1363        let stmt = parse_query(
1364            "SELECT * FROM test WHERE status = 'ACTIVE' AND count > 5 OR title LIKE '%test%'",
1365        )
1366        .unwrap();
1367        if let Statement::Select(q) = stmt {
1368            assert!(q.where_clause.is_some());
1369        } else {
1370            panic!("Expected Select");
1371        }
1372    }
1373
1374    #[test]
1375    fn test_join() {
1376        let stmt = parse_query(
1377            "SELECT s.title, b.sharpe FROM strategies s JOIN backtests b ON b.strategy = s.path",
1378        )
1379        .unwrap();
1380        if let Statement::Select(q) = stmt {
1381            assert_eq!(q.table, "strategies");
1382            assert_eq!(q.table_alias, Some("s".into()));
1383            assert_eq!(q.joins.len(), 1);
1384            let join = &q.joins[0];
1385            assert_eq!(join.table, "backtests");
1386            assert_eq!(join.alias, Some("b".into()));
1387        } else {
1388            panic!("Expected Select");
1389        }
1390    }
1391
1392    #[test]
1393    fn test_multi_join() {
1394        let stmt = parse_query(
1395            "SELECT s.title, b.sharpe, c.verdict FROM strategies s JOIN backtests b ON b.strategy = s.path JOIN critiques c ON c.strategy = s.path",
1396        )
1397        .unwrap();
1398        if let Statement::Select(q) = stmt {
1399            assert_eq!(q.table, "strategies");
1400            assert_eq!(q.table_alias, Some("s".into()));
1401            assert_eq!(q.joins.len(), 2);
1402            assert_eq!(q.joins[0].table, "backtests");
1403            assert_eq!(q.joins[0].alias, Some("b".into()));
1404            assert_eq!(q.joins[0].left_col, "b.strategy");
1405            assert_eq!(q.joins[0].right_col, "s.path");
1406            assert_eq!(q.joins[1].table, "critiques");
1407            assert_eq!(q.joins[1].alias, Some("c".into()));
1408            assert_eq!(q.joins[1].left_col, "c.strategy");
1409            assert_eq!(q.joins[1].right_col, "s.path");
1410        } else {
1411            panic!("Expected Select");
1412        }
1413    }
1414
1415    #[test]
1416    fn test_empty_query() {
1417        assert!(parse_query("").is_err());
1418    }
1419
1420    #[test]
1421    fn test_count_star() {
1422        let stmt = parse_query("SELECT status, COUNT(*) AS cnt FROM strategies GROUP BY status").unwrap();
1423        if let Statement::Select(q) = stmt {
1424            if let ColumnList::Named(exprs) = &q.columns {
1425                assert_eq!(exprs.len(), 2);
1426                assert_eq!(exprs[0], SelectExpr::Column("status".into()));
1427                assert!(matches!(&exprs[1], SelectExpr::Aggregate {
1428                    func: AggFunc::Count,
1429                    arg,
1430                    alias: Some(a),
1431                    ..
1432                } if arg == "*" && a == "cnt"));
1433            } else {
1434                panic!("Expected Named columns");
1435            }
1436            assert_eq!(q.group_by, Some(vec!["status".into()]));
1437        } else {
1438            panic!("Expected Select");
1439        }
1440    }
1441
1442    #[test]
1443    fn test_count_column_as_ident() {
1444        // "count" as a column name should NOT be parsed as the COUNT aggregate
1445        let stmt = parse_query("INSERT INTO test (title, count) VALUES ('Hello', 42)").unwrap();
1446        if let Statement::Insert(q) = stmt {
1447            assert_eq!(q.columns, vec!["title", "count"]);
1448        } else {
1449            panic!("Expected Insert");
1450        }
1451    }
1452
1453    #[test]
1454    fn test_multiple_aggregates() {
1455        let stmt = parse_query("SELECT MIN(composite), MAX(composite), AVG(composite) FROM strategies").unwrap();
1456        if let Statement::Select(q) = stmt {
1457            if let ColumnList::Named(exprs) = &q.columns {
1458                assert_eq!(exprs.len(), 3);
1459                assert!(matches!(&exprs[0], SelectExpr::Aggregate { func: AggFunc::Min, .. }));
1460                assert!(matches!(&exprs[1], SelectExpr::Aggregate { func: AggFunc::Max, .. }));
1461                assert!(matches!(&exprs[2], SelectExpr::Aggregate { func: AggFunc::Avg, .. }));
1462            } else {
1463                panic!("Expected Named columns");
1464            }
1465            assert_eq!(q.group_by, None);
1466        } else {
1467            panic!("Expected Select");
1468        }
1469    }
1470
1471    // ── Expression tests ──────────────────────────────────────────
1472
1473    #[test]
1474    fn test_select_arithmetic_expr() {
1475        let stmt = parse_query("SELECT a + b FROM test").unwrap();
1476        if let Statement::Select(q) = stmt {
1477            if let ColumnList::Named(exprs) = &q.columns {
1478                assert_eq!(exprs.len(), 1);
1479                assert!(matches!(&exprs[0], SelectExpr::Expr {
1480                    expr: Expr::BinaryOp { op: ArithOp::Add, .. },
1481                    alias: None,
1482                }));
1483            } else {
1484                panic!("Expected Named columns");
1485            }
1486        } else {
1487            panic!("Expected Select");
1488        }
1489    }
1490
1491    #[test]
1492    fn test_select_arithmetic_with_alias() {
1493        let stmt = parse_query("SELECT a + b AS total FROM test").unwrap();
1494        if let Statement::Select(q) = stmt {
1495            if let ColumnList::Named(exprs) = &q.columns {
1496                assert_eq!(exprs.len(), 1);
1497                assert!(matches!(&exprs[0], SelectExpr::Expr {
1498                    alias: Some(a),
1499                    ..
1500                } if a == "total"));
1501                assert_eq!(exprs[0].output_name(), "total");
1502            } else {
1503                panic!("Expected Named columns");
1504            }
1505        } else {
1506            panic!("Expected Select");
1507        }
1508    }
1509
1510    #[test]
1511    fn test_select_precedence() {
1512        // a + b * c should parse as a + (b * c)
1513        let stmt = parse_query("SELECT a + b * c FROM test").unwrap();
1514        if let Statement::Select(q) = stmt {
1515            if let ColumnList::Named(exprs) = &q.columns {
1516                if let SelectExpr::Expr { expr, .. } = &exprs[0] {
1517                    if let Expr::BinaryOp { left, op, right } = expr {
1518                        assert_eq!(*op, ArithOp::Add);
1519                        assert!(matches!(left.as_ref(), Expr::Column(n) if n == "a"));
1520                        assert!(matches!(right.as_ref(), Expr::BinaryOp { op: ArithOp::Mul, .. }));
1521                    } else {
1522                        panic!("Expected BinaryOp");
1523                    }
1524                } else {
1525                    panic!("Expected Expr variant");
1526                }
1527            } else {
1528                panic!("Expected Named columns");
1529            }
1530        } else {
1531            panic!("Expected Select");
1532        }
1533    }
1534
1535    #[test]
1536    fn test_select_parenthesized_expr() {
1537        // (a + b) * c should override default precedence
1538        let stmt = parse_query("SELECT (a + b) * c FROM test").unwrap();
1539        if let Statement::Select(q) = stmt {
1540            if let ColumnList::Named(exprs) = &q.columns {
1541                if let SelectExpr::Expr { expr, .. } = &exprs[0] {
1542                    if let Expr::BinaryOp { left, op, .. } = expr {
1543                        assert_eq!(*op, ArithOp::Mul);
1544                        assert!(matches!(left.as_ref(), Expr::BinaryOp { op: ArithOp::Add, .. }));
1545                    } else {
1546                        panic!("Expected BinaryOp");
1547                    }
1548                } else {
1549                    panic!("Expected Expr variant");
1550                }
1551            } else {
1552                panic!("Expected Named columns");
1553            }
1554        } else {
1555            panic!("Expected Select");
1556        }
1557    }
1558
1559    #[test]
1560    fn test_select_unary_minus() {
1561        let stmt = parse_query("SELECT -count FROM test").unwrap();
1562        if let Statement::Select(q) = stmt {
1563            if let ColumnList::Named(exprs) = &q.columns {
1564                assert!(matches!(&exprs[0], SelectExpr::Expr {
1565                    expr: Expr::UnaryMinus(_),
1566                    ..
1567                }));
1568            } else {
1569                panic!("Expected Named columns");
1570            }
1571        } else {
1572            panic!("Expected Select");
1573        }
1574    }
1575
1576    #[test]
1577    fn test_select_negative_literal() {
1578        let stmt = parse_query("SELECT -42 FROM test").unwrap();
1579        if let Statement::Select(q) = stmt {
1580            if let ColumnList::Named(exprs) = &q.columns {
1581                // Unary minus folds into the literal
1582                assert!(matches!(&exprs[0], SelectExpr::Expr {
1583                    expr: Expr::Literal(SqlValue::Int(-42)),
1584                    ..
1585                }));
1586            } else {
1587                panic!("Expected Named columns");
1588            }
1589        } else {
1590            panic!("Expected Select");
1591        }
1592    }
1593
1594    #[test]
1595    fn test_where_arithmetic_expr() {
1596        let stmt = parse_query("SELECT * FROM test WHERE a + b > 10").unwrap();
1597        if let Statement::Select(q) = stmt {
1598            if let Some(WhereClause::Comparison(c)) = q.where_clause {
1599                assert_eq!(c.op, ">");
1600                assert!(matches!(&c.left_expr, Some(Expr::BinaryOp { op: ArithOp::Add, .. })));
1601                assert!(matches!(&c.right_expr, Some(Expr::Literal(SqlValue::Int(10)))));
1602            } else {
1603                panic!("Expected comparison");
1604            }
1605        } else {
1606            panic!("Expected Select");
1607        }
1608    }
1609
1610    #[test]
1611    fn test_where_both_sides_expr() {
1612        let stmt = parse_query("SELECT * FROM test WHERE a * 2 > b + 1").unwrap();
1613        if let Statement::Select(q) = stmt {
1614            if let Some(WhereClause::Comparison(c)) = q.where_clause {
1615                assert_eq!(c.op, ">");
1616                assert!(matches!(&c.left_expr, Some(Expr::BinaryOp { op: ArithOp::Mul, .. })));
1617                assert!(matches!(&c.right_expr, Some(Expr::BinaryOp { op: ArithOp::Add, .. })));
1618            } else {
1619                panic!("Expected comparison");
1620            }
1621        } else {
1622            panic!("Expected Select");
1623        }
1624    }
1625
1626    #[test]
1627    fn test_order_by_expr() {
1628        let stmt = parse_query("SELECT * FROM test ORDER BY a + b DESC").unwrap();
1629        if let Statement::Select(q) = stmt {
1630            let ob = q.order_by.unwrap();
1631            assert_eq!(ob.len(), 1);
1632            assert!(ob[0].descending);
1633            assert!(matches!(&ob[0].expr, Some(Expr::BinaryOp { op: ArithOp::Add, .. })));
1634        } else {
1635            panic!("Expected Select");
1636        }
1637    }
1638
1639    #[test]
1640    fn test_all_arithmetic_ops() {
1641        let stmt = parse_query("SELECT a + b, a - b, a * b, a / b, a % b FROM test").unwrap();
1642        if let Statement::Select(q) = stmt {
1643            if let ColumnList::Named(exprs) = &q.columns {
1644                assert_eq!(exprs.len(), 5);
1645                assert!(matches!(&exprs[0], SelectExpr::Expr { expr: Expr::BinaryOp { op: ArithOp::Add, .. }, .. }));
1646                assert!(matches!(&exprs[1], SelectExpr::Expr { expr: Expr::BinaryOp { op: ArithOp::Sub, .. }, .. }));
1647                assert!(matches!(&exprs[2], SelectExpr::Expr { expr: Expr::BinaryOp { op: ArithOp::Mul, .. }, .. }));
1648                assert!(matches!(&exprs[3], SelectExpr::Expr { expr: Expr::BinaryOp { op: ArithOp::Div, .. }, .. }));
1649                assert!(matches!(&exprs[4], SelectExpr::Expr { expr: Expr::BinaryOp { op: ArithOp::Mod, .. }, .. }));
1650            } else {
1651                panic!("Expected Named columns");
1652            }
1653        } else {
1654            panic!("Expected Select");
1655        }
1656    }
1657
1658    #[test]
1659    fn test_column_with_literal_arithmetic() {
1660        let stmt = parse_query("SELECT count * 2 + 1 FROM test").unwrap();
1661        if let Statement::Select(q) = stmt {
1662            if let ColumnList::Named(exprs) = &q.columns {
1663                // Should be (count * 2) + 1
1664                if let SelectExpr::Expr { expr, .. } = &exprs[0] {
1665                    if let Expr::BinaryOp { left, op, right } = expr {
1666                        assert_eq!(*op, ArithOp::Add);
1667                        assert!(matches!(right.as_ref(), Expr::Literal(SqlValue::Int(1))));
1668                        assert!(matches!(left.as_ref(), Expr::BinaryOp { op: ArithOp::Mul, .. }));
1669                    } else {
1670                        panic!("Expected BinaryOp");
1671                    }
1672                } else {
1673                    panic!("Expected Expr");
1674                }
1675            } else {
1676                panic!("Expected Named columns");
1677            }
1678        } else {
1679            panic!("Expected Select");
1680        }
1681    }
1682
1683    #[test]
1684    fn test_mixed_columns_and_exprs() {
1685        let stmt = parse_query("SELECT title, a + b AS sum, count FROM test").unwrap();
1686        if let Statement::Select(q) = stmt {
1687            if let ColumnList::Named(exprs) = &q.columns {
1688                assert_eq!(exprs.len(), 3);
1689                assert_eq!(exprs[0], SelectExpr::Column("title".into()));
1690                assert!(matches!(&exprs[1], SelectExpr::Expr { alias: Some(a), .. } if a == "sum"));
1691                assert_eq!(exprs[2], SelectExpr::Column("count".into()));
1692            } else {
1693                panic!("Expected Named columns");
1694            }
1695        } else {
1696            panic!("Expected Select");
1697        }
1698    }
1699
1700    // ── CASE WHEN tests ──────────────────────────────────────────
1701
1702    #[test]
1703    fn test_case_when_basic() {
1704        let stmt = parse_query(
1705            "SELECT CASE WHEN status = 'ACTIVE' THEN 1 ELSE 0 END FROM test"
1706        ).unwrap();
1707        if let Statement::Select(q) = stmt {
1708            if let ColumnList::Named(exprs) = &q.columns {
1709                assert_eq!(exprs.len(), 1);
1710                assert!(matches!(&exprs[0], SelectExpr::Expr {
1711                    expr: Expr::Case { .. },
1712                    ..
1713                }));
1714            } else {
1715                panic!("Expected Named columns");
1716            }
1717        } else {
1718            panic!("Expected Select");
1719        }
1720    }
1721
1722    #[test]
1723    fn test_case_when_multiple_branches() {
1724        let stmt = parse_query(
1725            "SELECT CASE WHEN x > 10 THEN 'high' WHEN x > 5 THEN 'mid' ELSE 'low' END FROM test"
1726        ).unwrap();
1727        if let Statement::Select(q) = stmt {
1728            if let ColumnList::Named(exprs) = &q.columns {
1729                if let SelectExpr::Expr { expr: Expr::Case { whens, else_expr }, .. } = &exprs[0] {
1730                    assert_eq!(whens.len(), 2);
1731                    assert!(else_expr.is_some());
1732                } else {
1733                    panic!("Expected Case expression");
1734                }
1735            } else {
1736                panic!("Expected Named columns");
1737            }
1738        } else {
1739            panic!("Expected Select");
1740        }
1741    }
1742
1743    #[test]
1744    fn test_case_when_no_else() {
1745        let stmt = parse_query(
1746            "SELECT CASE WHEN x = 1 THEN 'one' END FROM test"
1747        ).unwrap();
1748        if let Statement::Select(q) = stmt {
1749            if let ColumnList::Named(exprs) = &q.columns {
1750                if let SelectExpr::Expr { expr: Expr::Case { whens, else_expr }, .. } = &exprs[0] {
1751                    assert_eq!(whens.len(), 1);
1752                    assert!(else_expr.is_none());
1753                } else {
1754                    panic!("Expected Case expression");
1755                }
1756            } else {
1757                panic!("Expected Named columns");
1758            }
1759        } else {
1760            panic!("Expected Select");
1761        }
1762    }
1763
1764    #[test]
1765    fn test_case_when_in_aggregate() {
1766        let stmt = parse_query(
1767            "SELECT SUM(CASE WHEN side = 'BUY' THEN size ELSE -size END) AS net FROM orders GROUP BY token"
1768        ).unwrap();
1769        if let Statement::Select(q) = stmt {
1770            if let ColumnList::Named(exprs) = &q.columns {
1771                assert_eq!(exprs.len(), 1);
1772                assert!(matches!(&exprs[0], SelectExpr::Aggregate {
1773                    func: AggFunc::Sum,
1774                    arg_expr: Some(Expr::Case { .. }),
1775                    alias: Some(a),
1776                    ..
1777                } if a == "net"));
1778            } else {
1779                panic!("Expected Named columns");
1780            }
1781        } else {
1782            panic!("Expected Select");
1783        }
1784    }
1785
1786    #[test]
1787    fn test_case_when_with_alias() {
1788        let stmt = parse_query(
1789            "SELECT CASE WHEN x > 0 THEN 'pos' ELSE 'neg' END AS sign FROM test"
1790        ).unwrap();
1791        if let Statement::Select(q) = stmt {
1792            if let ColumnList::Named(exprs) = &q.columns {
1793                assert!(matches!(&exprs[0], SelectExpr::Expr {
1794                    expr: Expr::Case { .. },
1795                    alias: Some(a),
1796                } if a == "sign"));
1797            } else {
1798                panic!("Expected Named columns");
1799            }
1800        } else {
1801            panic!("Expected Select");
1802        }
1803    }
1804}