Skip to main content

mdql_core/
query_parser.rs

1//! Hand-written recursive descent parser for the MDQL SQL subset.
2
3use regex::Regex;
4use std::sync::LazyLock;
5
6use crate::errors::MdqlError;
7pub use crate::query_ast::*;
8
9// ── Tokenizer ──────────────────────────────────────────────────────────────
10
/// Words the tokenizer classifies as keywords, listed in upper case.
///
/// `tokenize` upper-cases every bare word before looking it up here, so the
/// match is case-insensitive. Bare words that are not in this list are emitted
/// as identifiers instead.
static KEYWORDS: &[&str] = &[
    "SELECT", "FROM", "WHERE", "AND", "OR", "ORDER", "BY",
    "ASC", "DESC", "LIMIT", "LIKE", "IN", "IS", "NOT", "NULL",
    "JOIN", "LEFT", "ON", "AS", "GROUP", "HAVING",
    "INSERT", "INTO", "VALUES", "UPDATE", "SET", "DELETE",
    "ALTER", "TABLE", "RENAME", "FIELD", "TO", "DROP", "MERGE", "FIELDS",
    "CASE", "WHEN", "THEN", "ELSE", "END",
    "INTERVAL", "DAY", "DAYS", "CURRENT_DATE", "CURRENT_TIMESTAMP", "DATEDIFF",
    "CREATE", "VIEW", "CASCADE", "RESTRICT",
];
21
/// Aggregate function names, in upper case.
///
/// None of these appear in `KEYWORDS`, so the tokenizer emits them as ordinary
/// identifiers; `Parser::peek_is_agg_func` compares the upper-cased token value
/// against this list and additionally requires a following `(`.
static AGG_FUNCS: &[&str] = &["COUNT", "SUM", "AVG", "MIN", "MAX"];
23
/// Lazily compiled regex that matches one token, preceded by optional whitespace.
///
/// Exactly one named group participates in each match. The alternatives are
/// tried in this order:
///
/// * `backtick` — a non-empty backtick-quoted identifier,
/// * `string`   — a single-quoted string in which `\` escapes the next character,
/// * `number`   — unsigned digits with an optional fractional part,
/// * `op`       — `<=`, `>=`, `!=`, or one of `= < > , * ( ) + - / %`,
/// * `word`     — a letter or underscore followed by letters, digits, `_`, `.`,
///   `/` or `-` (so `a-b` and `dir/file.md` each lex as a single word).
///
/// The pattern is a constant, so the `unwrap` can only fail if the pattern
/// itself is edited into something invalid.
static TOKEN_RE: LazyLock<Regex> = LazyLock::new(|| {
    Regex::new(
        r#"(?x)
        \s*(?:
            (?P<backtick>`[^`]+`)
            | (?P<string>'(?:[^'\\]|\\.)*')
            | (?P<number>\d+(?:\.\d+)?)
            | (?P<op><=|>=|!=|[=<>,*()+\-/%])
            | (?P<word>[A-Za-z_][A-Za-z0-9_./-]*)
        )"#,
    )
    .unwrap()
});
37
/// A single lexical token produced by `tokenize`.
#[derive(Debug, Clone)]
struct Token {
    /// Category tag: `"ident"`, `"string"`, `"number"`, `"op"` or `"keyword"`.
    token_type: String,
    /// Normalised text: surrounding quotes/backticks removed, keywords upper-cased.
    value: String,
    /// The text exactly as it was matched in the query (without leading whitespace).
    raw: String,
}
44
45fn tokenize(sql: &str) -> Vec<Token> {
46    let mut tokens = Vec::new();
47    for caps in TOKEN_RE.captures_iter(sql) {
48        if let Some(m) = caps.name("backtick") {
49            let raw = m.as_str();
50            tokens.push(Token {
51                token_type: "ident".into(),
52                value: raw[1..raw.len() - 1].into(),
53                raw: raw.into(),
54            });
55        } else if let Some(m) = caps.name("string") {
56            let raw = m.as_str();
57            tokens.push(Token {
58                token_type: "string".into(),
59                value: raw[1..raw.len() - 1].into(),
60                raw: raw.into(),
61            });
62        } else if let Some(m) = caps.name("number") {
63            let raw = m.as_str();
64            tokens.push(Token {
65                token_type: "number".into(),
66                value: raw.into(),
67                raw: raw.into(),
68            });
69        } else if let Some(m) = caps.name("op") {
70            let raw = m.as_str();
71            tokens.push(Token {
72                token_type: "op".into(),
73                value: raw.into(),
74                raw: raw.into(),
75            });
76        } else if let Some(m) = caps.name("word") {
77            let raw = m.as_str();
78            if KEYWORDS.contains(&raw.to_uppercase().as_str()) {
79                tokens.push(Token {
80                    token_type: "keyword".into(),
81                    value: raw.to_uppercase(),
82                    raw: raw.into(),
83                });
84            } else {
85                tokens.push(Token {
86                    token_type: "ident".into(),
87                    value: raw.into(),
88                    raw: raw.into(),
89                });
90            }
91        }
92    }
93    tokens
94}
95
96// ── Parser ─────────────────────────────────────────────────────────────────
97
/// Recursive-descent parser state: the token list plus a cursor into it.
struct Parser {
    /// All tokens of the query, in source order.
    tokens: Vec<Token>,
    /// Index of the next token to be consumed.
    pos: usize,
}
102
103impl Parser {
104    fn new(tokens: Vec<Token>) -> Self {
105        Parser { tokens, pos: 0 }
106    }
107
108    fn peek(&self) -> Option<&Token> {
109        self.tokens.get(self.pos)
110    }
111
112    fn advance(&mut self) -> Token {
113        let t = self.tokens[self.pos].clone();
114        self.pos += 1;
115        t
116    }
117
118    fn expect(&mut self, type_: &str, value: Option<&str>) -> Result<Token, MdqlError> {
119        let t = self.peek().ok_or_else(|| {
120            MdqlError::QueryParse(format!(
121                "Unexpected end of query, expected {}",
122                value.unwrap_or(type_)
123            ))
124        })?;
125        let matches_type = t.token_type == type_;
126        let matches_value = value.map_or(true, |v| t.value == v);
127        if !matches_type || !matches_value {
128            return Err(MdqlError::QueryParse(format!(
129                "Expected {}, got '{}' at position {}",
130                value.unwrap_or(type_),
131                t.raw,
132                self.pos
133            )));
134        }
135        Ok(self.advance())
136    }
137
138    fn match_keyword(&mut self, kw: &str) -> bool {
139        if let Some(t) = self.peek() {
140            if t.token_type == "keyword" && t.value == kw {
141                self.advance();
142                return true;
143            }
144        }
145        false
146    }
147
148    fn parse_statement(&mut self) -> Result<Statement, MdqlError> {
149        let t = self.peek().ok_or_else(|| MdqlError::QueryParse("Empty query".into()))?;
150        match (t.token_type.as_str(), t.value.as_str()) {
151            ("keyword", "SELECT") => {
152                let q = self.parse_select()?;
153                self.expect_end()?;
154                Ok(Statement::Select(q))
155            }
156            ("keyword", "INSERT") => Ok(Statement::Insert(self.parse_insert()?)),
157            ("keyword", "UPDATE") => Ok(Statement::Update(self.parse_update()?)),
158            ("keyword", "DELETE") => Ok(Statement::Delete(self.parse_delete()?)),
159            ("keyword", "ALTER") => self.parse_alter(),
160            ("keyword", "CREATE") => self.parse_create_view(),
161            ("keyword", "DROP") => self.parse_drop_view(),
162            _ => Err(MdqlError::QueryParse(format!(
163                "Expected SELECT, INSERT, UPDATE, DELETE, ALTER, CREATE, or DROP, got '{}'",
164                t.raw
165            ))),
166        }
167    }
168
169    fn parse_select(&mut self) -> Result<SelectQuery, MdqlError> {
170        self.expect("keyword", Some("SELECT"))?;
171        let columns = self.parse_columns()?;
172        self.expect("keyword", Some("FROM"))?;
173
174        // Subquery: FROM (SELECT ...)
175        let mut subquery = None;
176        let (table, mut table_alias) = if self.peek().map_or(false, |t| t.token_type == "op" && t.value == "(") {
177            self.advance();
178            let inner = self.parse_select()?;
179            self.expect("op", Some(")"))?;
180            subquery = Some(Box::new(inner));
181            let alias = if let Some(t) = self.peek() {
182                if t.token_type == "ident" && !self.is_clause_keyword(t) {
183                    Some(self.advance().value)
184                } else {
185                    None
186                }
187            } else {
188                None
189            };
190            ("_subquery".to_string(), alias)
191        } else {
192            let t = self.parse_ident()?;
193            (t, None)
194        };
195
196        // Optional table alias (for non-subquery)
197        if subquery.is_none() {
198            if let Some(t) = self.peek() {
199                if t.token_type == "ident" && !self.is_clause_keyword(t) {
200                    table_alias = Some(self.advance().value);
201                }
202            }
203        }
204
205        // Optional JOIN(s)
206        let mut joins = Vec::new();
207        loop {
208            let jt = if self.match_keyword("LEFT") {
209                self.expect("keyword", Some("JOIN"))?;
210                JoinType::Left
211            } else if self.match_keyword("JOIN") {
212                JoinType::Inner
213            } else {
214                break;
215            };
216            let join_table = self.parse_ident()?;
217            let mut join_alias = None;
218            if let Some(t) = self.peek() {
219                if t.token_type == "ident" && !self.is_clause_keyword(t) {
220                    join_alias = Some(self.advance().value);
221                }
222            }
223            self.expect("keyword", Some("ON"))?;
224            let condition = self.parse_or_expr()?;
225            joins.push(JoinClause {
226                join_type: jt,
227                table: join_table,
228                alias: join_alias,
229                condition,
230            });
231        }
232
233        let mut where_clause = None;
234        if self.match_keyword("WHERE") {
235            where_clause = Some(self.parse_or_expr()?);
236        }
237
238        let mut group_by = None;
239        if self.match_keyword("GROUP") {
240            self.expect("keyword", Some("BY"))?;
241            let mut cols = vec![self.parse_ident()?];
242            while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
243                self.advance();
244                cols.push(self.parse_ident()?);
245            }
246            group_by = Some(cols);
247        }
248
249        let mut having = None;
250        if self.match_keyword("HAVING") {
251            having = Some(self.parse_or_expr()?);
252        }
253
254        let mut order_by = None;
255        if self.match_keyword("ORDER") {
256            self.expect("keyword", Some("BY"))?;
257            order_by = Some(self.parse_order_by()?);
258        }
259
260        let mut limit = None;
261        if self.match_keyword("LIMIT") {
262            let t = self.expect("number", None)?;
263            limit = Some(t.value.parse::<i64>().map_err(|_| {
264                MdqlError::QueryParse(format!("Invalid LIMIT value: {}", t.value))
265            })?);
266        }
267
268        Ok(SelectQuery {
269            columns,
270            table,
271            table_alias,
272            subquery,
273            joins,
274            where_clause,
275            group_by,
276            having,
277            order_by,
278            limit,
279        })
280    }
281
282    fn parse_insert(&mut self) -> Result<InsertQuery, MdqlError> {
283        self.expect("keyword", Some("INSERT"))?;
284        self.expect("keyword", Some("INTO"))?;
285        let table = self.parse_ident()?;
286
287        self.expect("op", Some("("))?;
288        let mut columns = vec![self.parse_ident()?];
289        while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
290            self.advance();
291            columns.push(self.parse_ident()?);
292        }
293        self.expect("op", Some(")"))?;
294
295        self.expect("keyword", Some("VALUES"))?;
296
297        self.expect("op", Some("("))?;
298        let mut values = vec![self.parse_value()?];
299        while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
300            self.advance();
301            values.push(self.parse_value()?);
302        }
303        self.expect("op", Some(")"))?;
304
305        if columns.len() != values.len() {
306            return Err(MdqlError::QueryParse(format!(
307                "Column count ({}) does not match value count ({})",
308                columns.len(),
309                values.len()
310            )));
311        }
312
313        self.expect_end()?;
314        Ok(InsertQuery {
315            table,
316            columns,
317            values,
318        })
319    }
320
321    fn parse_update(&mut self) -> Result<UpdateQuery, MdqlError> {
322        self.expect("keyword", Some("UPDATE"))?;
323        let table = self.parse_ident()?;
324        self.expect("keyword", Some("SET"))?;
325
326        let mut assignments = Vec::new();
327        let col = self.parse_ident()?;
328        self.expect("op", Some("="))?;
329        let val = self.parse_value()?;
330        assignments.push((col, val));
331
332        while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
333            self.advance();
334            let col = self.parse_ident()?;
335            self.expect("op", Some("="))?;
336            let val = self.parse_value()?;
337            assignments.push((col, val));
338        }
339
340        let mut where_clause = None;
341        if self.match_keyword("WHERE") {
342            where_clause = Some(self.parse_or_expr()?);
343        }
344
345        self.expect_end()?;
346        Ok(UpdateQuery {
347            table,
348            assignments,
349            where_clause,
350        })
351    }
352
353    fn parse_delete(&mut self) -> Result<DeleteQuery, MdqlError> {
354        self.expect("keyword", Some("DELETE"))?;
355        self.expect("keyword", Some("FROM"))?;
356        let table = self.parse_ident()?;
357
358        let mut where_clause = None;
359        if self.match_keyword("WHERE") {
360            where_clause = Some(self.parse_or_expr()?);
361        }
362
363        self.expect_end()?;
364        Ok(DeleteQuery {
365            table,
366            where_clause,
367        })
368    }
369
370    fn parse_alter(&mut self) -> Result<Statement, MdqlError> {
371        self.expect("keyword", Some("ALTER"))?;
372        self.expect("keyword", Some("TABLE"))?;
373        let table = self.parse_ident()?;
374
375        let t = self.peek().ok_or_else(|| {
376            MdqlError::QueryParse("Expected RENAME, DROP, or MERGE after table name".into())
377        })?;
378
379        match (t.token_type.as_str(), t.value.as_str()) {
380            ("keyword", "RENAME") => {
381                self.advance();
382                self.expect("keyword", Some("FIELD"))?;
383                let old_name = self.parse_string_or_ident()?;
384                self.expect("keyword", Some("TO"))?;
385                let new_name = self.parse_string_or_ident()?;
386                self.expect_end()?;
387                Ok(Statement::AlterRename(AlterRenameFieldQuery {
388                    table,
389                    old_name,
390                    new_name,
391                }))
392            }
393            ("keyword", "DROP") => {
394                self.advance();
395                self.expect("keyword", Some("FIELD"))?;
396                let field_name = self.parse_string_or_ident()?;
397                self.expect_end()?;
398                Ok(Statement::AlterDrop(AlterDropFieldQuery {
399                    table,
400                    field_name,
401                }))
402            }
403            ("keyword", "MERGE") => {
404                self.advance();
405                self.expect("keyword", Some("FIELDS"))?;
406                let mut sources = vec![self.parse_string_or_ident()?];
407                while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
408                    self.advance();
409                    sources.push(self.parse_string_or_ident()?);
410                }
411                self.expect("keyword", Some("INTO"))?;
412                let target = self.parse_string_or_ident()?;
413                self.expect_end()?;
414                Ok(Statement::AlterMerge(AlterMergeFieldsQuery {
415                    table,
416                    sources,
417                    into: target,
418                }))
419            }
420            _ => Err(MdqlError::QueryParse(format!(
421                "Expected RENAME, DROP, or MERGE, got '{}'",
422                t.raw
423            ))),
424        }
425    }
426
427    fn parse_create_view(&mut self) -> Result<Statement, MdqlError> {
428        self.expect("keyword", Some("CREATE"))?;
429        self.expect("keyword", Some("VIEW"))?;
430        let view_name = self.parse_ident()?;
431
432        let columns = if self.peek().map_or(false, |t| t.token_type == "op" && t.value == "(") {
433            self.advance();
434            let mut cols = vec![self.parse_ident()?];
435            while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
436                self.advance();
437                cols.push(self.parse_ident()?);
438            }
439            self.expect("op", Some(")"))?;
440            Some(cols)
441        } else {
442            None
443        };
444
445        self.expect("keyword", Some("AS"))?;
446        let query = Box::new(self.parse_select()?);
447        self.expect_end()?;
448
449        Ok(Statement::CreateView(CreateViewQuery {
450            view_name,
451            columns,
452            query,
453        }))
454    }
455
456    fn parse_drop_view(&mut self) -> Result<Statement, MdqlError> {
457        self.expect("keyword", Some("DROP"))?;
458        self.expect("keyword", Some("VIEW"))?;
459        let view_name = self.parse_ident()?;
460        self.expect_end()?;
461        Ok(Statement::DropView(DropViewQuery { view_name }))
462    }
463
464    fn parse_string_or_ident(&mut self) -> Result<String, MdqlError> {
465        let t = self.peek().ok_or_else(|| {
466            MdqlError::QueryParse("Expected field name, got end of query".into())
467        })?;
468        match t.token_type.as_str() {
469            "string" => {
470                let v = self.advance().value;
471                Ok(v)
472            }
473            "ident" | "keyword" => {
474                let v = self.advance().value;
475                Ok(v)
476            }
477            _ => Err(MdqlError::QueryParse(format!(
478                "Expected field name, got '{}'",
479                t.raw
480            ))),
481        }
482    }
483
484    fn parse_columns(&mut self) -> Result<ColumnList, MdqlError> {
485        if let Some(t) = self.peek() {
486            if t.token_type == "op" && t.value == "*" {
487                self.advance();
488                return Ok(ColumnList::All);
489            }
490        }
491
492        let mut exprs = vec![self.parse_select_expr()?];
493        while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
494            self.advance();
495            exprs.push(self.parse_select_expr()?);
496        }
497        Ok(ColumnList::Named(exprs))
498    }
499
500    fn peek_is_agg_func(&self) -> bool {
501        let t = match self.peek() {
502            Some(t) => t,
503            None => return false,
504        };
505        let name_upper = t.value.to_uppercase();
506        if !AGG_FUNCS.contains(&name_upper.as_str()) {
507            return false;
508        }
509        // Only treat as aggregate if followed by (
510        self.tokens
511            .get(self.pos + 1)
512            .map_or(false, |next| next.token_type == "op" && next.value == "(")
513    }
514
515    fn parse_select_expr(&mut self) -> Result<SelectExpr, MdqlError> {
516        let _t = self.peek().ok_or_else(|| {
517            MdqlError::QueryParse("Expected column or aggregate, got end of query".into())
518        })?;
519
520        let expr = self.parse_additive()?;
521
522        let alias = if self.match_keyword("AS") {
523            Some(self.parse_ident()?)
524        } else if self.peek().map_or(false, |t| {
525            t.token_type == "ident" && !self.is_clause_keyword(t)
526        }) {
527            Some(self.advance().value)
528        } else {
529            None
530        };
531
532        // Bare aggregate → SelectExpr::Aggregate for backward compat
533        if let Expr::Aggregate { func, arg, arg_expr } = expr {
534            return Ok(SelectExpr::Aggregate {
535                func,
536                arg,
537                arg_expr: arg_expr.map(|e| *e),
538                alias,
539            });
540        }
541
542        if alias.is_none() {
543            if let Expr::Column(name) = &expr {
544                return Ok(SelectExpr::Column(name.clone()));
545            }
546        }
547
548        Ok(SelectExpr::Expr { expr, alias })
549    }
550
551    // ── Expression parser (precedence climbing) ───────────────────────
552
553    fn peek_is_additive_op(&self) -> bool {
554        self.peek().map_or(false, |t| {
555            t.token_type == "op" && (t.value == "+" || t.value == "-")
556        })
557    }
558
559    fn peek_is_multiplicative_op(&self) -> bool {
560        self.peek().map_or(false, |t| {
561            t.token_type == "op" && (t.value == "*" || t.value == "/" || t.value == "%")
562        })
563    }
564
565    fn parse_additive(&mut self) -> Result<Expr, MdqlError> {
566        let mut left = self.parse_multiplicative()?;
567        while self.peek_is_additive_op() {
568            let op_tok = self.advance();
569            let is_sub = op_tok.value == "-";
570
571            // Check for INTERVAL keyword: expr +/- INTERVAL n DAY
572            if self.peek().map_or(false, |t| t.token_type == "keyword" && t.value == "INTERVAL") {
573                self.advance(); // consume INTERVAL
574                let days_expr = self.parse_multiplicative()?;
575                // Expect DAY or DAYS
576                if !self.match_keyword("DAY") && !self.match_keyword("DAYS") {
577                    return Err(MdqlError::QueryParse("Expected DAY after INTERVAL value".into()));
578                }
579                let days = if is_sub {
580                    Expr::UnaryMinus(Box::new(days_expr))
581                } else {
582                    days_expr
583                };
584                left = Expr::DateAdd {
585                    date: Box::new(left),
586                    days: Box::new(days),
587                };
588                continue;
589            }
590
591            let op = match op_tok.value.as_str() {
592                "+" => ArithOp::Add,
593                "-" => ArithOp::Sub,
594                _ => unreachable!(),
595            };
596            let right = self.parse_multiplicative()?;
597            left = Expr::BinaryOp {
598                left: Box::new(left),
599                op,
600                right: Box::new(right),
601            };
602        }
603        Ok(left)
604    }
605
606    fn parse_multiplicative(&mut self) -> Result<Expr, MdqlError> {
607        let mut left = self.parse_unary()?;
608        while self.peek_is_multiplicative_op() {
609            let op_tok = self.advance();
610            let op = match op_tok.value.as_str() {
611                "*" => ArithOp::Mul,
612                "/" => ArithOp::Div,
613                "%" => ArithOp::Mod,
614                _ => unreachable!(),
615            };
616            let right = self.parse_unary()?;
617            left = Expr::BinaryOp {
618                left: Box::new(left),
619                op,
620                right: Box::new(right),
621            };
622        }
623        Ok(left)
624    }
625
626    fn parse_unary(&mut self) -> Result<Expr, MdqlError> {
627        if self.peek().map_or(false, |t| t.token_type == "op" && t.value == "-") {
628            self.advance();
629            let inner = self.parse_atom()?;
630            // Fold unary minus on literals
631            match inner {
632                Expr::Literal(SqlValue::Int(n)) => Ok(Expr::Literal(SqlValue::Int(-n))),
633                Expr::Literal(SqlValue::Float(f)) => Ok(Expr::Literal(SqlValue::Float(-f))),
634                _ => Ok(Expr::UnaryMinus(Box::new(inner))),
635            }
636        } else {
637            self.parse_atom()
638        }
639    }
640
641    fn parse_atom(&mut self) -> Result<Expr, MdqlError> {
642        if self.peek_is_agg_func() {
643            return self.parse_agg_expr();
644        }
645
646        let t = self.peek().ok_or_else(|| {
647            MdqlError::QueryParse("Expected expression, got end of query".into())
648        })?;
649
650        match t.token_type.as_str() {
651            "number" => {
652                let v = self.advance().value;
653                if v.contains('.') {
654                    let f: f64 = v.parse().map_err(|_| {
655                        MdqlError::QueryParse(format!("Invalid float: {}", v))
656                    })?;
657                    Ok(Expr::Literal(SqlValue::Float(f)))
658                } else {
659                    let n: i64 = v.parse().map_err(|_| {
660                        MdqlError::QueryParse(format!("Invalid int: {}", v))
661                    })?;
662                    Ok(Expr::Literal(SqlValue::Int(n)))
663                }
664            }
665            "string" => {
666                let v = self.advance().value;
667                Ok(Expr::Literal(SqlValue::String(v)))
668            }
669            "keyword" if t.value == "NULL" => {
670                self.advance();
671                Ok(Expr::Literal(SqlValue::Null))
672            }
673            "keyword" if t.value == "CASE" => {
674                self.parse_case_expr()
675            }
676            "keyword" if t.value == "CURRENT_DATE" => {
677                self.advance();
678                Ok(Expr::CurrentDate)
679            }
680            "keyword" if t.value == "CURRENT_TIMESTAMP" => {
681                self.advance();
682                Ok(Expr::CurrentTimestamp)
683            }
684            "keyword" if t.value == "DATEDIFF" => {
685                self.advance();
686                self.expect("op", Some("("))?;
687                let left = self.parse_additive()?;
688                self.expect("op", Some(","))?;
689                let right = self.parse_additive()?;
690                self.expect("op", Some(")"))?;
691                Ok(Expr::DateDiff { left: Box::new(left), right: Box::new(right) })
692            }
693            "op" if t.value == "(" => {
694                self.advance();
695                let expr = self.parse_additive()?;
696                self.expect("op", Some(")"))?;
697                Ok(expr)
698            }
699            "ident" => {
700                let name = self.advance().value;
701                Ok(Expr::Column(name))
702            }
703            "keyword" if !Self::is_reserved_keyword(&t.value) => {
704                let name = self.advance().value;
705                Ok(Expr::Column(name))
706            }
707            _ => Err(MdqlError::QueryParse(format!(
708                "Expected expression, got '{}'",
709                t.raw
710            ))),
711        }
712    }
713
714    fn parse_case_expr(&mut self) -> Result<Expr, MdqlError> {
715        self.expect("keyword", Some("CASE"))?;
716        let mut whens = Vec::new();
717        while self.match_keyword("WHEN") {
718            let condition = self.parse_or_expr()?;
719            self.expect("keyword", Some("THEN"))?;
720            let result = self.parse_additive()?;
721            whens.push((condition, Box::new(result)));
722        }
723        if whens.is_empty() {
724            return Err(MdqlError::QueryParse("CASE requires at least one WHEN clause".into()));
725        }
726        let else_expr = if self.match_keyword("ELSE") {
727            Some(Box::new(self.parse_additive()?))
728        } else {
729            None
730        };
731        self.expect("keyword", Some("END"))?;
732        Ok(Expr::Case { whens, else_expr })
733    }
734
735    fn parse_agg_expr(&mut self) -> Result<Expr, MdqlError> {
736        let func_name = self.advance().value.to_uppercase();
737        let func = match func_name.as_str() {
738            "COUNT" => AggFunc::Count,
739            "SUM" => AggFunc::Sum,
740            "AVG" => AggFunc::Avg,
741            "MIN" => AggFunc::Min,
742            "MAX" => AggFunc::Max,
743            _ => unreachable!(),
744        };
745        self.expect("op", Some("("))?;
746        let (arg, arg_expr) = if self.peek().map_or(false, |t| t.token_type == "op" && t.value == "*") {
747            self.advance();
748            ("*".to_string(), None)
749        } else {
750            let expr = self.parse_additive()?;
751            if let Expr::Column(name) = &expr {
752                (name.clone(), None)
753            } else {
754                (expr.display_name(), Some(Box::new(expr)))
755            }
756        };
757        self.expect("op", Some(")"))?;
758        Ok(Expr::Aggregate { func, arg, arg_expr })
759    }
760
761    fn parse_ident(&mut self) -> Result<String, MdqlError> {
762        let t = self.peek().ok_or_else(|| {
763            MdqlError::QueryParse("Expected identifier, got end of query".into())
764        })?;
765        match t.token_type.as_str() {
766            "ident" | "keyword" => {
767                let v = self.advance().value;
768                Ok(v)
769            }
770            _ => Err(MdqlError::QueryParse(format!(
771                "Expected identifier, got '{}'",
772                t.raw
773            ))),
774        }
775    }
776
777    fn parse_or_expr(&mut self) -> Result<WhereClause, MdqlError> {
778        let mut left = self.parse_and_expr()?;
779        while self.match_keyword("OR") {
780            let right = self.parse_and_expr()?;
781            left = WhereClause::BoolOp(BoolOp {
782                op: BoolOpKind::Or,
783                left: Box::new(left),
784                right: Box::new(right),
785            });
786        }
787        Ok(left)
788    }
789
790    fn parse_and_expr(&mut self) -> Result<WhereClause, MdqlError> {
791        let mut left = self.parse_comparison()?;
792        while self.match_keyword("AND") {
793            let right = self.parse_comparison()?;
794            left = WhereClause::BoolOp(BoolOp {
795                op: BoolOpKind::And,
796                left: Box::new(left),
797                right: Box::new(right),
798            });
799        }
800        Ok(left)
801    }
802
803    fn parse_comparison(&mut self) -> Result<WhereClause, MdqlError> {
804        // Handle parenthesized boolean expressions
805        if self.peek().map_or(false, |t| t.token_type == "op" && t.value == "(") {
806            // Save position — might be arithmetic parens, not boolean
807            let saved_pos = self.pos;
808            self.advance();
809            // Try parsing as boolean (OR/AND) expression
810            let result = self.parse_or_expr();
811            if result.is_ok() && self.peek().map_or(false, |t| t.token_type == "op" && t.value == ")") {
812                self.advance();
813                return result;
814            }
815            // Not a boolean paren — rewind and parse as arithmetic expression
816            self.pos = saved_pos;
817        }
818
819        // Parse the left side as a full expression (column, literal, or arithmetic)
820        let left_expr = self.parse_additive()?;
821
822        // Extract column name for backward compat (simple column on left side)
823        let col = left_expr.as_column().unwrap_or("").to_string();
824
825        // IS NULL / IS NOT NULL (only valid with simple column)
826        if self.match_keyword("IS") {
827            if self.match_keyword("NOT") {
828                self.expect("keyword", Some("NULL"))?;
829                return Ok(WhereClause::Comparison(Comparison {
830                    column: col,
831                    op: CmpOp::IsNotNull,
832                    value: None,
833                    left_expr: Some(left_expr),
834                    right_expr: None,
835                }));
836            }
837            self.expect("keyword", Some("NULL"))?;
838            return Ok(WhereClause::Comparison(Comparison {
839                column: col,
840                op: CmpOp::IsNull,
841                value: None,
842                left_expr: Some(left_expr),
843                right_expr: None,
844            }));
845        }
846
847        // IN (val, val, ...)
848        if self.match_keyword("IN") {
849            self.expect("op", Some("("))?;
850            let mut values = vec![self.parse_value()?];
851            while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
852                self.advance();
853                values.push(self.parse_value()?);
854            }
855            self.expect("op", Some(")"))?;
856            return Ok(WhereClause::Comparison(Comparison {
857                column: col,
858                op: CmpOp::In,
859                value: Some(SqlValue::List(values)),
860                left_expr: Some(left_expr),
861                right_expr: None,
862            }));
863        }
864
865        // LIKE
866        if self.match_keyword("LIKE") {
867            let val = self.parse_value()?;
868            return Ok(WhereClause::Comparison(Comparison {
869                column: col,
870                op: CmpOp::Like,
871                value: Some(val),
872                left_expr: Some(left_expr),
873                right_expr: None,
874            }));
875        }
876
877        // NOT LIKE
878        if self.match_keyword("NOT") {
879            if self.match_keyword("LIKE") {
880                let val = self.parse_value()?;
881                return Ok(WhereClause::Comparison(Comparison {
882                    column: col,
883                    op: CmpOp::NotLike,
884                    value: Some(val),
885                    left_expr: Some(left_expr),
886                    right_expr: None,
887                }));
888            }
889            return Err(MdqlError::QueryParse("Expected LIKE after NOT".into()));
890        }
891
892        // Standard comparison operators
893        if let Some(t) = self.peek() {
894            if t.token_type == "op" && ["=", "!=", "<", ">", "<=", ">="].contains(&t.value.as_str())
895            {
896                let op_str = self.advance().value;
897                let op = match op_str.as_str() {
898                    "=" => CmpOp::Eq,
899                    "!=" => CmpOp::Ne,
900                    "<" => CmpOp::Lt,
901                    ">" => CmpOp::Gt,
902                    "<=" => CmpOp::Le,
903                    ">=" => CmpOp::Ge,
904                    _ => unreachable!(),
905                };
906                // Parse right side as expression
907                let right_expr = self.parse_additive()?;
908                // Extract SqlValue for backward compat (simple literal on right side)
909                let value = match &right_expr {
910                    Expr::Literal(v) => Some(v.clone()),
911                    _ => None,
912                };
913                return Ok(WhereClause::Comparison(Comparison {
914                    column: col,
915                    op,
916                    value,
917                    left_expr: Some(left_expr),
918                    right_expr: Some(right_expr),
919                }));
920            }
921        }
922
923        let got = self.peek().map_or("end".to_string(), |t| t.raw.clone());
924        Err(MdqlError::QueryParse(format!(
925            "Expected operator after '{}', got '{}'",
926            left_expr.display_name(), got
927        )))
928    }
929
930    fn parse_value(&mut self) -> Result<SqlValue, MdqlError> {
931        let t = self.peek().ok_or_else(|| {
932            MdqlError::QueryParse("Expected value, got end of query".into())
933        })?;
934        match t.token_type.as_str() {
935            "string" => {
936                let v = self.advance().value;
937                Ok(SqlValue::String(v))
938            }
939            "number" => {
940                let v = self.advance().value;
941                if v.contains('.') {
942                    Ok(SqlValue::Float(v.parse().map_err(|_| {
943                        MdqlError::QueryParse(format!("Invalid float: {}", v))
944                    })?))
945                } else {
946                    Ok(SqlValue::Int(v.parse().map_err(|_| {
947                        MdqlError::QueryParse(format!("Invalid int: {}", v))
948                    })?))
949                }
950            }
951            "keyword" if t.value == "NULL" => {
952                self.advance();
953                Ok(SqlValue::Null)
954            }
955            _ => Err(MdqlError::QueryParse(format!(
956                "Expected value, got '{}'",
957                t.raw
958            ))),
959        }
960    }
961
962    fn parse_order_by(&mut self) -> Result<Vec<OrderSpec>, MdqlError> {
963        let mut specs = vec![self.parse_order_spec()?];
964        while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
965            self.advance();
966            specs.push(self.parse_order_spec()?);
967        }
968        Ok(specs)
969    }
970
971    fn parse_order_spec(&mut self) -> Result<OrderSpec, MdqlError> {
972        let expr = self.parse_additive()?;
973        let col = expr.as_column().unwrap_or("").to_string();
974        let descending = if self.match_keyword("DESC") {
975            true
976        } else {
977            self.match_keyword("ASC");
978            false
979        };
980        Ok(OrderSpec {
981            column: col,
982            expr: Some(expr),
983            descending,
984        })
985    }
986
987    fn is_clause_keyword(&self, t: &Token) -> bool {
988        t.token_type == "keyword"
989            && ["WHERE", "ORDER", "LIMIT", "JOIN", "LEFT", "ON", "GROUP"].contains(&t.value.as_str())
990    }
991
992    /// Keywords that should never be consumed as column names inside expressions.
993    fn is_reserved_keyword(kw: &str) -> bool {
994        matches!(kw,
995            "AS" | "FROM" | "WHERE" | "AND" | "OR" | "ORDER" | "BY"
996            | "ASC" | "DESC" | "LIMIT" | "JOIN" | "ON" | "GROUP"
997            | "SELECT" | "INSERT" | "INTO" | "VALUES" | "UPDATE" | "SET"
998            | "DELETE" | "ALTER" | "TABLE" | "IS" | "NOT" | "IN" | "LIKE"
999            | "RENAME" | "FIELD" | "TO" | "DROP" | "MERGE" | "FIELDS"
1000            | "CASE" | "WHEN" | "THEN" | "ELSE" | "END"
1001            | "HAVING" | "INTERVAL" | "DAY" | "DAYS"
1002            | "CURRENT_DATE" | "CURRENT_TIMESTAMP" | "DATEDIFF"
1003            | "CREATE" | "VIEW" | "CASCADE" | "RESTRICT"
1004        )
1005    }
1006
1007    fn expect_end(&self) -> Result<(), MdqlError> {
1008        if let Some(t) = self.peek() {
1009            return Err(MdqlError::QueryParse(format!(
1010                "Unexpected token '{}' at position {}",
1011                t.raw, self.pos
1012            )));
1013        }
1014        Ok(())
1015    }
1016}
1017
1018pub fn parse_query(sql: &str) -> crate::errors::Result<Statement> {
1019    let tokens = tokenize(sql);
1020    if tokens.is_empty() {
1021        return Err(MdqlError::QueryParse("Empty query".into()));
1022    }
1023    let mut parser = Parser::new(tokens);
1024    parser.parse_statement()
1025}
1026
1027#[cfg(test)]
1028mod tests {
1029    use super::*;
1030
1031    #[test]
1032    fn test_simple_select() {
1033        let stmt = parse_query("SELECT title, status FROM strategies").unwrap();
1034        if let Statement::Select(q) = stmt {
1035            assert_eq!(q.columns, ColumnList::Named(vec![SelectExpr::Column("title".into()), SelectExpr::Column("status".into())]));
1036            assert_eq!(q.table, "strategies");
1037        } else {
1038            panic!("Expected Select");
1039        }
1040    }
1041
1042    #[test]
1043    fn test_select_star() {
1044        let stmt = parse_query("SELECT * FROM test").unwrap();
1045        if let Statement::Select(q) = stmt {
1046            assert_eq!(q.columns, ColumnList::All);
1047        } else {
1048            panic!("Expected Select");
1049        }
1050    }
1051
1052    #[test]
1053    fn test_where_clause() {
1054        let stmt = parse_query("SELECT title FROM test WHERE count > 5").unwrap();
1055        if let Statement::Select(q) = stmt {
1056            assert!(q.where_clause.is_some());
1057        } else {
1058            panic!("Expected Select");
1059        }
1060    }
1061
1062    #[test]
1063    fn test_order_by() {
1064        let stmt =
1065            parse_query("SELECT title FROM test ORDER BY composite DESC, title ASC").unwrap();
1066        if let Statement::Select(q) = stmt {
1067            let ob = q.order_by.unwrap();
1068            assert_eq!(ob.len(), 2);
1069            assert!(ob[0].descending);
1070            assert!(!ob[1].descending);
1071        } else {
1072            panic!("Expected Select");
1073        }
1074    }
1075
1076    #[test]
1077    fn test_limit() {
1078        let stmt = parse_query("SELECT * FROM test LIMIT 10").unwrap();
1079        if let Statement::Select(q) = stmt {
1080            assert_eq!(q.limit, Some(10));
1081        } else {
1082            panic!("Expected Select");
1083        }
1084    }
1085
1086    #[test]
1087    fn test_insert() {
1088        let stmt = parse_query(
1089            "INSERT INTO test (title, count) VALUES ('Hello', 42)",
1090        )
1091        .unwrap();
1092        if let Statement::Insert(q) = stmt {
1093            assert_eq!(q.table, "test");
1094            assert_eq!(q.columns, vec!["title", "count"]);
1095            assert_eq!(q.values[0], SqlValue::String("Hello".into()));
1096            assert_eq!(q.values[1], SqlValue::Int(42));
1097        } else {
1098            panic!("Expected Insert");
1099        }
1100    }
1101
1102    #[test]
1103    fn test_update() {
1104        let stmt = parse_query("UPDATE test SET status = 'KILLED' WHERE path = 'a.md'").unwrap();
1105        if let Statement::Update(q) = stmt {
1106            assert_eq!(q.table, "test");
1107            assert_eq!(q.assignments.len(), 1);
1108            assert!(q.where_clause.is_some());
1109        } else {
1110            panic!("Expected Update");
1111        }
1112    }
1113
1114    #[test]
1115    fn test_delete() {
1116        let stmt = parse_query("DELETE FROM test WHERE status = 'draft'").unwrap();
1117        if let Statement::Delete(q) = stmt {
1118            assert_eq!(q.table, "test");
1119            assert!(q.where_clause.is_some());
1120        } else {
1121            panic!("Expected Delete");
1122        }
1123    }
1124
1125    #[test]
1126    fn test_alter_rename() {
1127        let stmt =
1128            parse_query("ALTER TABLE test RENAME FIELD 'Summary' TO 'Overview'").unwrap();
1129        if let Statement::AlterRename(q) = stmt {
1130            assert_eq!(q.old_name, "Summary");
1131            assert_eq!(q.new_name, "Overview");
1132        } else {
1133            panic!("Expected AlterRename");
1134        }
1135    }
1136
1137    #[test]
1138    fn test_alter_drop() {
1139        let stmt = parse_query("ALTER TABLE test DROP FIELD 'Details'").unwrap();
1140        if let Statement::AlterDrop(q) = stmt {
1141            assert_eq!(q.field_name, "Details");
1142        } else {
1143            panic!("Expected AlterDrop");
1144        }
1145    }
1146
1147    #[test]
1148    fn test_alter_merge() {
1149        let stmt = parse_query(
1150            "ALTER TABLE test MERGE FIELDS 'Entry Rules', 'Exit Rules' INTO 'Trading Rules'",
1151        )
1152        .unwrap();
1153        if let Statement::AlterMerge(q) = stmt {
1154            assert_eq!(q.sources, vec!["Entry Rules", "Exit Rules"]);
1155            assert_eq!(q.into, "Trading Rules");
1156        } else {
1157            panic!("Expected AlterMerge");
1158        }
1159    }
1160
1161    #[test]
1162    fn test_backtick_ident() {
1163        let stmt = parse_query("SELECT `Structural Mechanism` FROM test").unwrap();
1164        if let Statement::Select(q) = stmt {
1165            assert_eq!(
1166                q.columns,
1167                ColumnList::Named(vec![SelectExpr::Column("Structural Mechanism".into())])
1168            );
1169        } else {
1170            panic!("Expected Select");
1171        }
1172    }
1173
1174    #[test]
1175    fn test_like_operator() {
1176        let stmt = parse_query("SELECT title FROM test WHERE categories LIKE '%defi%'").unwrap();
1177        if let Statement::Select(q) = stmt {
1178            if let Some(WhereClause::Comparison(c)) = q.where_clause {
1179                assert_eq!(c.op, CmpOp::Like);
1180                assert_eq!(c.value, Some(SqlValue::String("%defi%".into())));
1181            } else {
1182                panic!("Expected LIKE comparison");
1183            }
1184        } else {
1185            panic!("Expected Select");
1186        }
1187    }
1188
1189    #[test]
1190    fn test_in_operator() {
1191        let stmt =
1192            parse_query("SELECT * FROM test WHERE status IN ('ACTIVE', 'LIVE')").unwrap();
1193        if let Statement::Select(q) = stmt {
1194            if let Some(WhereClause::Comparison(c)) = q.where_clause {
1195                assert_eq!(c.op, CmpOp::In);
1196            } else {
1197                panic!("Expected IN comparison");
1198            }
1199        } else {
1200            panic!("Expected Select");
1201        }
1202    }
1203
1204    #[test]
1205    fn test_is_null() {
1206        let stmt = parse_query("SELECT * FROM test WHERE title IS NULL").unwrap();
1207        if let Statement::Select(q) = stmt {
1208            if let Some(WhereClause::Comparison(c)) = q.where_clause {
1209                assert_eq!(c.op, CmpOp::IsNull);
1210            } else {
1211                panic!("Expected IS NULL comparison");
1212            }
1213        } else {
1214            panic!("Expected Select");
1215        }
1216    }
1217
1218    #[test]
1219    fn test_and_or() {
1220        let stmt = parse_query(
1221            "SELECT * FROM test WHERE status = 'ACTIVE' AND count > 5 OR title LIKE '%test%'",
1222        )
1223        .unwrap();
1224        if let Statement::Select(q) = stmt {
1225            assert!(q.where_clause.is_some());
1226        } else {
1227            panic!("Expected Select");
1228        }
1229    }
1230
1231    #[test]
1232    fn test_join() {
1233        let stmt = parse_query(
1234            "SELECT s.title, b.sharpe FROM strategies s JOIN backtests b ON b.strategy = s.path",
1235        )
1236        .unwrap();
1237        if let Statement::Select(q) = stmt {
1238            assert_eq!(q.table, "strategies");
1239            assert_eq!(q.table_alias, Some("s".into()));
1240            assert_eq!(q.joins.len(), 1);
1241            let join = &q.joins[0];
1242            assert_eq!(join.table, "backtests");
1243            assert_eq!(join.alias, Some("b".into()));
1244        } else {
1245            panic!("Expected Select");
1246        }
1247    }
1248
1249    #[test]
1250    fn test_multi_join() {
1251        let stmt = parse_query(
1252            "SELECT s.title, b.sharpe, c.verdict FROM strategies s JOIN backtests b ON b.strategy = s.path JOIN critiques c ON c.strategy = s.path",
1253        )
1254        .unwrap();
1255        if let Statement::Select(q) = stmt {
1256            assert_eq!(q.table, "strategies");
1257            assert_eq!(q.table_alias, Some("s".into()));
1258            assert_eq!(q.joins.len(), 2);
1259            assert_eq!(q.joins[0].table, "backtests");
1260            assert_eq!(q.joins[0].alias, Some("b".into()));
1261            assert_eq!(where_clause_to_sql(&q.joins[0].condition), "b.strategy = s.path");
1262            assert_eq!(q.joins[1].table, "critiques");
1263            assert_eq!(q.joins[1].alias, Some("c".into()));
1264            assert_eq!(where_clause_to_sql(&q.joins[1].condition), "c.strategy = s.path");
1265        } else {
1266            panic!("Expected Select");
1267        }
1268    }
1269
1270    #[test]
1271    fn test_left_join() {
1272        let stmt = parse_query(
1273            "SELECT s.title, b.sharpe FROM strategies s LEFT JOIN backtests b ON b.strategy = s.path",
1274        )
1275        .unwrap();
1276        if let Statement::Select(q) = stmt {
1277            assert_eq!(q.joins.len(), 1);
1278            assert_eq!(q.joins[0].join_type, JoinType::Left);
1279            assert_eq!(q.joins[0].table, "backtests");
1280        } else {
1281            panic!("Expected Select");
1282        }
1283    }
1284
1285    #[test]
1286    fn test_mixed_join_types() {
1287        let stmt = parse_query(
1288            "SELECT s.title FROM strategies s JOIN backtests b ON b.strategy = s.path LEFT JOIN allocations a ON a.strategy = s.path",
1289        )
1290        .unwrap();
1291        if let Statement::Select(q) = stmt {
1292            assert_eq!(q.joins.len(), 2);
1293            assert_eq!(q.joins[0].join_type, JoinType::Inner);
1294            assert_eq!(q.joins[1].join_type, JoinType::Left);
1295        } else {
1296            panic!("Expected Select");
1297        }
1298    }
1299
1300    #[test]
1301    fn test_join_compound_and() {
1302        let stmt = parse_query(
1303            "SELECT s.title FROM strategies s LEFT JOIN backtests b ON b.strategy = s.path AND b.mode = 'PAPER'",
1304        )
1305        .unwrap();
1306        if let Statement::Select(q) = stmt {
1307            assert_eq!(q.joins.len(), 1);
1308            assert_eq!(q.joins[0].join_type, JoinType::Left);
1309            let sql = where_clause_to_sql(&q.joins[0].condition);
1310            assert!(sql.contains("b.strategy = s.path"));
1311            assert!(sql.contains("AND"));
1312            assert!(sql.contains("b.mode = 'PAPER'"));
1313        } else {
1314            panic!("Expected Select");
1315        }
1316    }
1317
1318    #[test]
1319    fn test_join_compound_or() {
1320        let stmt = parse_query(
1321            "SELECT * FROM a JOIN b ON a.id = b.id OR a.alt = b.id",
1322        )
1323        .unwrap();
1324        if let Statement::Select(q) = stmt {
1325            let sql = where_clause_to_sql(&q.joins[0].condition);
1326            assert!(sql.contains("OR"));
1327        } else {
1328            panic!("Expected Select");
1329        }
1330    }
1331
1332    #[test]
1333    fn test_join_compound_with_where() {
1334        let stmt = parse_query(
1335            "SELECT s.title FROM strategies s JOIN backtests b ON b.strategy = s.path AND b.mode = 'PAPER' WHERE s.title = 'Alpha'",
1336        )
1337        .unwrap();
1338        if let Statement::Select(q) = stmt {
1339            assert_eq!(q.joins.len(), 1);
1340            assert!(q.where_clause.is_some());
1341            let join_sql = where_clause_to_sql(&q.joins[0].condition);
1342            assert!(join_sql.contains("AND"));
1343        } else {
1344            panic!("Expected Select");
1345        }
1346    }
1347
1348    #[test]
1349    fn test_empty_query() {
1350        assert!(parse_query("").is_err());
1351    }
1352
1353    #[test]
1354    fn test_count_star() {
1355        let stmt = parse_query("SELECT status, COUNT(*) AS cnt FROM strategies GROUP BY status").unwrap();
1356        if let Statement::Select(q) = stmt {
1357            if let ColumnList::Named(exprs) = &q.columns {
1358                assert_eq!(exprs.len(), 2);
1359                assert_eq!(exprs[0], SelectExpr::Column("status".into()));
1360                assert!(matches!(&exprs[1], SelectExpr::Aggregate {
1361                    func: AggFunc::Count,
1362                    arg,
1363                    alias: Some(a),
1364                    ..
1365                } if arg == "*" && a == "cnt"));
1366            } else {
1367                panic!("Expected Named columns");
1368            }
1369            assert_eq!(q.group_by, Some(vec!["status".into()]));
1370        } else {
1371            panic!("Expected Select");
1372        }
1373    }
1374
1375    #[test]
1376    fn test_count_column_as_ident() {
1377        // "count" as a column name should NOT be parsed as the COUNT aggregate
1378        let stmt = parse_query("INSERT INTO test (title, count) VALUES ('Hello', 42)").unwrap();
1379        if let Statement::Insert(q) = stmt {
1380            assert_eq!(q.columns, vec!["title", "count"]);
1381        } else {
1382            panic!("Expected Insert");
1383        }
1384    }
1385
1386    #[test]
1387    fn test_multiple_aggregates() {
1388        let stmt = parse_query("SELECT MIN(composite), MAX(composite), AVG(composite) FROM strategies").unwrap();
1389        if let Statement::Select(q) = stmt {
1390            if let ColumnList::Named(exprs) = &q.columns {
1391                assert_eq!(exprs.len(), 3);
1392                assert!(matches!(&exprs[0], SelectExpr::Aggregate { func: AggFunc::Min, .. }));
1393                assert!(matches!(&exprs[1], SelectExpr::Aggregate { func: AggFunc::Max, .. }));
1394                assert!(matches!(&exprs[2], SelectExpr::Aggregate { func: AggFunc::Avg, .. }));
1395            } else {
1396                panic!("Expected Named columns");
1397            }
1398            assert_eq!(q.group_by, None);
1399        } else {
1400            panic!("Expected Select");
1401        }
1402    }
1403
1404    // ── Expression tests ──────────────────────────────────────────
1405
1406    #[test]
1407    fn test_select_arithmetic_expr() {
1408        let stmt = parse_query("SELECT a + b FROM test").unwrap();
1409        if let Statement::Select(q) = stmt {
1410            if let ColumnList::Named(exprs) = &q.columns {
1411                assert_eq!(exprs.len(), 1);
1412                assert!(matches!(&exprs[0], SelectExpr::Expr {
1413                    expr: Expr::BinaryOp { op: ArithOp::Add, .. },
1414                    alias: None,
1415                }));
1416            } else {
1417                panic!("Expected Named columns");
1418            }
1419        } else {
1420            panic!("Expected Select");
1421        }
1422    }
1423
1424    #[test]
1425    fn test_select_arithmetic_with_alias() {
1426        let stmt = parse_query("SELECT a + b AS total FROM test").unwrap();
1427        if let Statement::Select(q) = stmt {
1428            if let ColumnList::Named(exprs) = &q.columns {
1429                assert_eq!(exprs.len(), 1);
1430                assert!(matches!(&exprs[0], SelectExpr::Expr {
1431                    alias: Some(a),
1432                    ..
1433                } if a == "total"));
1434                assert_eq!(exprs[0].output_name(), "total");
1435            } else {
1436                panic!("Expected Named columns");
1437            }
1438        } else {
1439            panic!("Expected Select");
1440        }
1441    }
1442
1443    #[test]
1444    fn test_select_precedence() {
1445        // a + b * c should parse as a + (b * c)
1446        let stmt = parse_query("SELECT a + b * c FROM test").unwrap();
1447        if let Statement::Select(q) = stmt {
1448            if let ColumnList::Named(exprs) = &q.columns {
1449                if let SelectExpr::Expr { expr, .. } = &exprs[0] {
1450                    if let Expr::BinaryOp { left, op, right } = expr {
1451                        assert_eq!(*op, ArithOp::Add);
1452                        assert!(matches!(left.as_ref(), Expr::Column(n) if n == "a"));
1453                        assert!(matches!(right.as_ref(), Expr::BinaryOp { op: ArithOp::Mul, .. }));
1454                    } else {
1455                        panic!("Expected BinaryOp");
1456                    }
1457                } else {
1458                    panic!("Expected Expr variant");
1459                }
1460            } else {
1461                panic!("Expected Named columns");
1462            }
1463        } else {
1464            panic!("Expected Select");
1465        }
1466    }
1467
1468    #[test]
1469    fn test_select_parenthesized_expr() {
1470        // (a + b) * c should override default precedence
1471        let stmt = parse_query("SELECT (a + b) * c FROM test").unwrap();
1472        if let Statement::Select(q) = stmt {
1473            if let ColumnList::Named(exprs) = &q.columns {
1474                if let SelectExpr::Expr { expr, .. } = &exprs[0] {
1475                    if let Expr::BinaryOp { left, op, .. } = expr {
1476                        assert_eq!(*op, ArithOp::Mul);
1477                        assert!(matches!(left.as_ref(), Expr::BinaryOp { op: ArithOp::Add, .. }));
1478                    } else {
1479                        panic!("Expected BinaryOp");
1480                    }
1481                } else {
1482                    panic!("Expected Expr variant");
1483                }
1484            } else {
1485                panic!("Expected Named columns");
1486            }
1487        } else {
1488            panic!("Expected Select");
1489        }
1490    }
1491
1492    #[test]
1493    fn test_select_unary_minus() {
1494        let stmt = parse_query("SELECT -count FROM test").unwrap();
1495        if let Statement::Select(q) = stmt {
1496            if let ColumnList::Named(exprs) = &q.columns {
1497                assert!(matches!(&exprs[0], SelectExpr::Expr {
1498                    expr: Expr::UnaryMinus(_),
1499                    ..
1500                }));
1501            } else {
1502                panic!("Expected Named columns");
1503            }
1504        } else {
1505            panic!("Expected Select");
1506        }
1507    }
1508
1509    #[test]
1510    fn test_select_negative_literal() {
1511        let stmt = parse_query("SELECT -42 FROM test").unwrap();
1512        if let Statement::Select(q) = stmt {
1513            if let ColumnList::Named(exprs) = &q.columns {
1514                // Unary minus folds into the literal
1515                assert!(matches!(&exprs[0], SelectExpr::Expr {
1516                    expr: Expr::Literal(SqlValue::Int(-42)),
1517                    ..
1518                }));
1519            } else {
1520                panic!("Expected Named columns");
1521            }
1522        } else {
1523            panic!("Expected Select");
1524        }
1525    }
1526
1527    #[test]
1528    fn test_where_arithmetic_expr() {
1529        let stmt = parse_query("SELECT * FROM test WHERE a + b > 10").unwrap();
1530        if let Statement::Select(q) = stmt {
1531            if let Some(WhereClause::Comparison(c)) = q.where_clause {
1532                assert_eq!(c.op, CmpOp::Gt);
1533                assert!(matches!(&c.left_expr, Some(Expr::BinaryOp { op: ArithOp::Add, .. })));
1534                assert!(matches!(&c.right_expr, Some(Expr::Literal(SqlValue::Int(10)))));
1535            } else {
1536                panic!("Expected comparison");
1537            }
1538        } else {
1539            panic!("Expected Select");
1540        }
1541    }
1542
1543    #[test]
1544    fn test_where_both_sides_expr() {
1545        let stmt = parse_query("SELECT * FROM test WHERE a * 2 > b + 1").unwrap();
1546        if let Statement::Select(q) = stmt {
1547            if let Some(WhereClause::Comparison(c)) = q.where_clause {
1548                assert_eq!(c.op, CmpOp::Gt);
1549                assert!(matches!(&c.left_expr, Some(Expr::BinaryOp { op: ArithOp::Mul, .. })));
1550                assert!(matches!(&c.right_expr, Some(Expr::BinaryOp { op: ArithOp::Add, .. })));
1551            } else {
1552                panic!("Expected comparison");
1553            }
1554        } else {
1555            panic!("Expected Select");
1556        }
1557    }
1558
1559    #[test]
1560    fn test_order_by_expr() {
1561        let stmt = parse_query("SELECT * FROM test ORDER BY a + b DESC").unwrap();
1562        if let Statement::Select(q) = stmt {
1563            let ob = q.order_by.unwrap();
1564            assert_eq!(ob.len(), 1);
1565            assert!(ob[0].descending);
1566            assert!(matches!(&ob[0].expr, Some(Expr::BinaryOp { op: ArithOp::Add, .. })));
1567        } else {
1568            panic!("Expected Select");
1569        }
1570    }
1571
1572    #[test]
1573    fn test_all_arithmetic_ops() {
1574        let stmt = parse_query("SELECT a + b, a - b, a * b, a / b, a % b FROM test").unwrap();
1575        if let Statement::Select(q) = stmt {
1576            if let ColumnList::Named(exprs) = &q.columns {
1577                assert_eq!(exprs.len(), 5);
1578                assert!(matches!(&exprs[0], SelectExpr::Expr { expr: Expr::BinaryOp { op: ArithOp::Add, .. }, .. }));
1579                assert!(matches!(&exprs[1], SelectExpr::Expr { expr: Expr::BinaryOp { op: ArithOp::Sub, .. }, .. }));
1580                assert!(matches!(&exprs[2], SelectExpr::Expr { expr: Expr::BinaryOp { op: ArithOp::Mul, .. }, .. }));
1581                assert!(matches!(&exprs[3], SelectExpr::Expr { expr: Expr::BinaryOp { op: ArithOp::Div, .. }, .. }));
1582                assert!(matches!(&exprs[4], SelectExpr::Expr { expr: Expr::BinaryOp { op: ArithOp::Mod, .. }, .. }));
1583            } else {
1584                panic!("Expected Named columns");
1585            }
1586        } else {
1587            panic!("Expected Select");
1588        }
1589    }
1590
1591    #[test]
1592    fn test_column_with_literal_arithmetic() {
1593        let stmt = parse_query("SELECT count * 2 + 1 FROM test").unwrap();
1594        if let Statement::Select(q) = stmt {
1595            if let ColumnList::Named(exprs) = &q.columns {
1596                // Should be (count * 2) + 1
1597                if let SelectExpr::Expr { expr, .. } = &exprs[0] {
1598                    if let Expr::BinaryOp { left, op, right } = expr {
1599                        assert_eq!(*op, ArithOp::Add);
1600                        assert!(matches!(right.as_ref(), Expr::Literal(SqlValue::Int(1))));
1601                        assert!(matches!(left.as_ref(), Expr::BinaryOp { op: ArithOp::Mul, .. }));
1602                    } else {
1603                        panic!("Expected BinaryOp");
1604                    }
1605                } else {
1606                    panic!("Expected Expr");
1607                }
1608            } else {
1609                panic!("Expected Named columns");
1610            }
1611        } else {
1612            panic!("Expected Select");
1613        }
1614    }
1615
1616    #[test]
1617    fn test_mixed_columns_and_exprs() {
1618        let stmt = parse_query("SELECT title, a + b AS sum, count FROM test").unwrap();
1619        if let Statement::Select(q) = stmt {
1620            if let ColumnList::Named(exprs) = &q.columns {
1621                assert_eq!(exprs.len(), 3);
1622                assert_eq!(exprs[0], SelectExpr::Column("title".into()));
1623                assert!(matches!(&exprs[1], SelectExpr::Expr { alias: Some(a), .. } if a == "sum"));
1624                assert_eq!(exprs[2], SelectExpr::Column("count".into()));
1625            } else {
1626                panic!("Expected Named columns");
1627            }
1628        } else {
1629            panic!("Expected Select");
1630        }
1631    }
1632
1633    // ── CASE WHEN tests ──────────────────────────────────────────
1634
1635    #[test]
1636    fn test_case_when_basic() {
1637        let stmt = parse_query(
1638            "SELECT CASE WHEN status = 'ACTIVE' THEN 1 ELSE 0 END FROM test"
1639        ).unwrap();
1640        if let Statement::Select(q) = stmt {
1641            if let ColumnList::Named(exprs) = &q.columns {
1642                assert_eq!(exprs.len(), 1);
1643                assert!(matches!(&exprs[0], SelectExpr::Expr {
1644                    expr: Expr::Case { .. },
1645                    ..
1646                }));
1647            } else {
1648                panic!("Expected Named columns");
1649            }
1650        } else {
1651            panic!("Expected Select");
1652        }
1653    }
1654
1655    #[test]
1656    fn test_case_when_multiple_branches() {
1657        let stmt = parse_query(
1658            "SELECT CASE WHEN x > 10 THEN 'high' WHEN x > 5 THEN 'mid' ELSE 'low' END FROM test"
1659        ).unwrap();
1660        if let Statement::Select(q) = stmt {
1661            if let ColumnList::Named(exprs) = &q.columns {
1662                if let SelectExpr::Expr { expr: Expr::Case { whens, else_expr }, .. } = &exprs[0] {
1663                    assert_eq!(whens.len(), 2);
1664                    assert!(else_expr.is_some());
1665                } else {
1666                    panic!("Expected Case expression");
1667                }
1668            } else {
1669                panic!("Expected Named columns");
1670            }
1671        } else {
1672            panic!("Expected Select");
1673        }
1674    }
1675
1676    #[test]
1677    fn test_case_when_no_else() {
1678        let stmt = parse_query(
1679            "SELECT CASE WHEN x = 1 THEN 'one' END FROM test"
1680        ).unwrap();
1681        if let Statement::Select(q) = stmt {
1682            if let ColumnList::Named(exprs) = &q.columns {
1683                if let SelectExpr::Expr { expr: Expr::Case { whens, else_expr }, .. } = &exprs[0] {
1684                    assert_eq!(whens.len(), 1);
1685                    assert!(else_expr.is_none());
1686                } else {
1687                    panic!("Expected Case expression");
1688                }
1689            } else {
1690                panic!("Expected Named columns");
1691            }
1692        } else {
1693            panic!("Expected Select");
1694        }
1695    }
1696
1697    #[test]
1698    fn test_case_when_in_aggregate() {
1699        let stmt = parse_query(
1700            "SELECT SUM(CASE WHEN side = 'BUY' THEN size ELSE -size END) AS net FROM orders GROUP BY token"
1701        ).unwrap();
1702        if let Statement::Select(q) = stmt {
1703            if let ColumnList::Named(exprs) = &q.columns {
1704                assert_eq!(exprs.len(), 1);
1705                assert!(matches!(&exprs[0], SelectExpr::Aggregate {
1706                    func: AggFunc::Sum,
1707                    arg_expr: Some(Expr::Case { .. }),
1708                    alias: Some(a),
1709                    ..
1710                } if a == "net"));
1711            } else {
1712                panic!("Expected Named columns");
1713            }
1714        } else {
1715            panic!("Expected Select");
1716        }
1717    }
1718
1719    #[test]
1720    fn test_case_when_with_alias() {
1721        let stmt = parse_query(
1722            "SELECT CASE WHEN x > 0 THEN 'pos' ELSE 'neg' END AS sign FROM test"
1723        ).unwrap();
1724        if let Statement::Select(q) = stmt {
1725            if let ColumnList::Named(exprs) = &q.columns {
1726                assert!(matches!(&exprs[0], SelectExpr::Expr {
1727                    expr: Expr::Case { .. },
1728                    alias: Some(a),
1729                } if a == "sign"));
1730            } else {
1731                panic!("Expected Named columns");
1732            }
1733        } else {
1734            panic!("Expected Select");
1735        }
1736    }
1737
1738    #[test]
1739    fn test_create_view() {
1740        let stmt = parse_query("CREATE VIEW live AS SELECT * FROM strategies WHERE status = 'LIVE'").unwrap();
1741        if let Statement::CreateView(cv) = stmt {
1742            assert_eq!(cv.view_name, "live");
1743            assert!(cv.columns.is_none());
1744            assert_eq!(cv.query.table, "strategies");
1745            assert!(cv.query.where_clause.is_some());
1746        } else {
1747            panic!("Expected CreateView, got {:?}", stmt);
1748        }
1749    }
1750
1751    #[test]
1752    fn test_create_view_with_columns() {
1753        let stmt = parse_query("CREATE VIEW v1 (a, b) AS SELECT title, status FROM t").unwrap();
1754        if let Statement::CreateView(cv) = stmt {
1755            assert_eq!(cv.view_name, "v1");
1756            assert_eq!(cv.columns, Some(vec!["a".into(), "b".into()]));
1757        } else {
1758            panic!("Expected CreateView");
1759        }
1760    }
1761
1762    #[test]
1763    fn test_drop_view() {
1764        let stmt = parse_query("DROP VIEW live").unwrap();
1765        if let Statement::DropView(dv) = stmt {
1766            assert_eq!(dv.view_name, "live");
1767        } else {
1768            panic!("Expected DropView, got {:?}", stmt);
1769        }
1770    }
1771
1772    #[test]
1773    fn test_create_view_case_insensitive() {
1774        let stmt = parse_query("create view My_View as select * from t").unwrap();
1775        if let Statement::CreateView(cv) = stmt {
1776            assert_eq!(cv.view_name, "My_View");
1777        } else {
1778            panic!("Expected CreateView");
1779        }
1780    }
1781
1782    // ── Issue #42: Arithmetic between aggregates in column expressions ──
1783
1784    #[test]
1785    fn test_aggregate_division() {
1786        let stmt = parse_query(
1787            "SELECT token, SUM(sell) / SUM(buy) as ratio FROM orders GROUP BY token"
1788        ).unwrap();
1789        if let Statement::Select(q) = stmt {
1790            assert_eq!(q.group_by, Some(vec!["token".into()]));
1791            if let ColumnList::Named(exprs) = &q.columns {
1792                assert_eq!(exprs.len(), 2);
1793                assert!(exprs[1].is_aggregate());
1794            } else {
1795                panic!("Expected Named columns");
1796            }
1797        } else {
1798            panic!("Expected Select");
1799        }
1800    }
1801
1802    #[test]
1803    fn test_aggregate_subtraction() {
1804        let stmt = parse_query(
1805            "SELECT token, SUM(sell) - SUM(buy) as net FROM orders GROUP BY token"
1806        ).unwrap();
1807        if let Statement::Select(q) = stmt {
1808            if let ColumnList::Named(exprs) = &q.columns {
1809                assert_eq!(exprs[1].output_name(), "net");
1810            }
1811        } else {
1812            panic!("Expected Select");
1813        }
1814    }
1815
1816    #[test]
1817    fn test_create_view_with_arithmetic() {
1818        let stmt = parse_query(
1819            "CREATE VIEW positions AS SELECT token, SUM(sell) / SUM(buy) as ratio FROM orders GROUP BY token"
1820        ).unwrap();
1821        if let Statement::CreateView(cv) = stmt {
1822            assert_eq!(cv.view_name, "positions");
1823        } else {
1824            panic!("Expected CreateView, got {:?}", stmt);
1825        }
1826    }
1827
1828    // ── Issue #43: Subqueries in FROM ──
1829
1830    #[test]
1831    fn test_subquery_in_from() {
1832        let stmt = parse_query(
1833            "SELECT token, sell_size FROM (SELECT token, SUM(size) as sell_size FROM orders GROUP BY token) LIMIT 5"
1834        ).unwrap();
1835        if let Statement::Select(q) = stmt {
1836            assert!(q.subquery.is_some());
1837            assert_eq!(q.limit, Some(5));
1838            let sub = q.subquery.unwrap();
1839            assert_eq!(sub.table, "orders");
1840            assert!(sub.group_by.is_some());
1841        } else {
1842            panic!("Expected Select");
1843        }
1844    }
1845
1846    // ── Issue #44: HAVING in CREATE VIEW ──
1847
1848    #[test]
1849    fn test_create_view_with_having() {
1850        let stmt = parse_query(
1851            "CREATE VIEW positions AS SELECT token, SUM(sell) as sell_size, SUM(buy) as buy_size FROM orders GROUP BY token HAVING sell_size > buy_size"
1852        ).unwrap();
1853        if let Statement::CreateView(cv) = stmt {
1854            assert_eq!(cv.view_name, "positions");
1855            assert!(cv.query.having.is_some());
1856        } else {
1857            panic!("Expected CreateView, got {:?}", stmt);
1858        }
1859    }
1860
1861    // ── Issue #42: Aggregate multiplication ──
1862
1863    #[test]
1864    fn test_aggregate_multiplication() {
1865        let stmt = parse_query(
1866            "SELECT SUM(a) * 2 as doubled FROM test"
1867        ).unwrap();
1868        if let Statement::Select(q) = stmt {
1869            if let ColumnList::Named(exprs) = &q.columns {
1870                assert_eq!(exprs.len(), 1);
1871                assert!(exprs[0].is_aggregate());
1872                assert_eq!(exprs[0].output_name(), "doubled");
1873            } else {
1874                panic!("Expected Named columns");
1875            }
1876        } else {
1877            panic!("Expected Select");
1878        }
1879    }
1880
1881    #[test]
1882    fn test_complex_aggregate_arithmetic() {
1883        let stmt = parse_query(
1884            "SELECT SUM(CASE WHEN side = 'SELL' THEN size ELSE 0 END) / SUM(CASE WHEN side = 'BUY' THEN size ELSE 0 END) as ratio FROM orders GROUP BY token"
1885        ).unwrap();
1886        if let Statement::Select(q) = stmt {
1887            if let ColumnList::Named(exprs) = &q.columns {
1888                assert_eq!(exprs.len(), 1);
1889                assert!(exprs[0].is_aggregate());
1890                assert_eq!(exprs[0].output_name(), "ratio");
1891            } else {
1892                panic!("Expected Named columns");
1893            }
1894            assert_eq!(q.group_by, Some(vec!["token".into()]));
1895        } else {
1896            panic!("Expected Select");
1897        }
1898    }
1899
1900    // ── Issue #43: Subquery with alias and WHERE ──
1901
1902    #[test]
1903    fn test_subquery_with_alias() {
1904        let stmt = parse_query(
1905            "SELECT x FROM (SELECT x FROM t) sub"
1906        ).unwrap();
1907        if let Statement::Select(q) = stmt {
1908            assert!(q.subquery.is_some());
1909            let sub = q.subquery.unwrap();
1910            assert_eq!(sub.table, "t");
1911            if let ColumnList::Named(exprs) = &q.columns {
1912                assert_eq!(exprs.len(), 1);
1913                assert_eq!(exprs[0].output_name(), "x");
1914            } else {
1915                panic!("Expected Named columns");
1916            }
1917        } else {
1918            panic!("Expected Select");
1919        }
1920    }
1921
1922    #[test]
1923    fn test_subquery_with_where() {
1924        let stmt = parse_query(
1925            "SELECT x FROM (SELECT x FROM t WHERE y > 0) LIMIT 5"
1926        ).unwrap();
1927        if let Statement::Select(q) = stmt {
1928            assert!(q.subquery.is_some());
1929            assert_eq!(q.limit, Some(5));
1930            let sub = q.subquery.unwrap();
1931            assert_eq!(sub.table, "t");
1932            assert!(sub.where_clause.is_some());
1933        } else {
1934            panic!("Expected Select");
1935        }
1936    }
1937
1938    // ── Issue #42 + CREATE VIEW: aggregate subtraction in view ──
1939
1940    #[test]
1941    fn test_create_view_aggregate_subtraction() {
1942        let stmt = parse_query(
1943            "CREATE VIEW v AS SELECT token, SUM(sell) - SUM(buy) as net FROM orders GROUP BY token"
1944        ).unwrap();
1945        if let Statement::CreateView(cv) = stmt {
1946            assert_eq!(cv.view_name, "v");
1947            assert_eq!(cv.query.group_by, Some(vec!["token".into()]));
1948            if let ColumnList::Named(exprs) = &cv.query.columns {
1949                assert_eq!(exprs.len(), 2);
1950                assert_eq!(exprs[1].output_name(), "net");
1951                assert!(exprs[1].is_aggregate());
1952            } else {
1953                panic!("Expected Named columns");
1954            }
1955        } else {
1956            panic!("Expected CreateView, got {:?}", stmt);
1957        }
1958    }
1959}