Skip to main content

mdql_core/
query_parser.rs

1//! Hand-written recursive descent parser for the MDQL SQL subset.
2
3use regex::Regex;
4use std::sync::LazyLock;
5
6use crate::errors::MdqlError;
7pub use crate::query_ast::*;
8
9// ── Tokenizer ──────────────────────────────────────────────────────────────
10
/// Reserved words of the MDQL SQL subset.
///
/// `tokenize` upper-cases every bare word and looks it up in this list: a hit
/// produces a `keyword` token, anything else an `ident` token.
static KEYWORDS: &[&str] = &[
    // Query clauses and predicates
    "SELECT", "FROM", "WHERE", "AND", "OR", "ORDER", "BY",
    "ASC", "DESC", "LIMIT", "LIKE", "IN", "IS", "NOT", "NULL",
    "JOIN", "ON", "AS", "GROUP", "HAVING",
    // Data modification
    "INSERT", "INTO", "VALUES", "UPDATE", "SET", "DELETE",
    // Schema changes
    "ALTER", "TABLE", "RENAME", "FIELD", "TO", "DROP", "MERGE", "FIELDS",
    // CASE expressions
    "CASE", "WHEN", "THEN", "ELSE", "END",
    // Date handling
    "INTERVAL", "DAY", "DAYS", "CURRENT_DATE", "CURRENT_TIMESTAMP", "DATEDIFF",
    // Views
    "CREATE", "VIEW", "CASCADE", "RESTRICT",
];

/// Names treated as aggregate functions when directly followed by `(`.
///
/// These are deliberately not in [`KEYWORDS`], so they tokenize as plain
/// identifiers.
static AGG_FUNCS: &[&str] = &["COUNT", "SUM", "AVG", "MIN", "MAX"];
23
24static TOKEN_RE: LazyLock<Regex> = LazyLock::new(|| {
25    Regex::new(
26        r#"(?x)
27        \s*(?:
28            (?P<backtick>`[^`]+`)
29            | (?P<string>'(?:[^'\\]|\\.)*')
30            | (?P<number>\d+(?:\.\d+)?)
31            | (?P<op><=|>=|!=|[=<>,*()+\-/%])
32            | (?P<word>[A-Za-z_][A-Za-z0-9_./-]*)
33        )"#,
34    )
35    .unwrap()
36});
37
/// One lexical token produced by `tokenize`.
#[derive(Debug, Clone)]
struct Token {
    /// Token category: `"ident"`, `"string"`, `"number"`, `"op"` or `"keyword"`.
    token_type: String,
    /// Normalised text: surrounding quotes/backticks removed, keywords upper-cased.
    value: String,
    /// The token text exactly as it was matched in the query.
    raw: String,
}
44
45fn tokenize(sql: &str) -> Vec<Token> {
46    let mut tokens = Vec::new();
47    for caps in TOKEN_RE.captures_iter(sql) {
48        if let Some(m) = caps.name("backtick") {
49            let raw = m.as_str();
50            tokens.push(Token {
51                token_type: "ident".into(),
52                value: raw[1..raw.len() - 1].into(),
53                raw: raw.into(),
54            });
55        } else if let Some(m) = caps.name("string") {
56            let raw = m.as_str();
57            tokens.push(Token {
58                token_type: "string".into(),
59                value: raw[1..raw.len() - 1].into(),
60                raw: raw.into(),
61            });
62        } else if let Some(m) = caps.name("number") {
63            let raw = m.as_str();
64            tokens.push(Token {
65                token_type: "number".into(),
66                value: raw.into(),
67                raw: raw.into(),
68            });
69        } else if let Some(m) = caps.name("op") {
70            let raw = m.as_str();
71            tokens.push(Token {
72                token_type: "op".into(),
73                value: raw.into(),
74                raw: raw.into(),
75            });
76        } else if let Some(m) = caps.name("word") {
77            let raw = m.as_str();
78            if KEYWORDS.contains(&raw.to_uppercase().as_str()) {
79                tokens.push(Token {
80                    token_type: "keyword".into(),
81                    value: raw.to_uppercase(),
82                    raw: raw.into(),
83                });
84            } else {
85                tokens.push(Token {
86                    token_type: "ident".into(),
87                    value: raw.into(),
88                    raw: raw.into(),
89                });
90            }
91        }
92    }
93    tokens
94}
95
96// ── Parser ─────────────────────────────────────────────────────────────────
97
98struct Parser {
99    tokens: Vec<Token>,
100    pos: usize,
101}
102
103impl Parser {
104    fn new(tokens: Vec<Token>) -> Self {
105        Parser { tokens, pos: 0 }
106    }
107
108    fn peek(&self) -> Option<&Token> {
109        self.tokens.get(self.pos)
110    }
111
112    fn advance(&mut self) -> Token {
113        let t = self.tokens[self.pos].clone();
114        self.pos += 1;
115        t
116    }
117
118    fn expect(&mut self, type_: &str, value: Option<&str>) -> Result<Token, MdqlError> {
119        let t = self.peek().ok_or_else(|| {
120            MdqlError::QueryParse(format!(
121                "Unexpected end of query, expected {}",
122                value.unwrap_or(type_)
123            ))
124        })?;
125        let matches_type = t.token_type == type_;
126        let matches_value = value.map_or(true, |v| t.value == v);
127        if !matches_type || !matches_value {
128            return Err(MdqlError::QueryParse(format!(
129                "Expected {}, got '{}' at position {}",
130                value.unwrap_or(type_),
131                t.raw,
132                self.pos
133            )));
134        }
135        Ok(self.advance())
136    }
137
138    fn match_keyword(&mut self, kw: &str) -> bool {
139        if let Some(t) = self.peek() {
140            if t.token_type == "keyword" && t.value == kw {
141                self.advance();
142                return true;
143            }
144        }
145        false
146    }
147
148    fn parse_statement(&mut self) -> Result<Statement, MdqlError> {
149        let t = self.peek().ok_or_else(|| MdqlError::QueryParse("Empty query".into()))?;
150        match (t.token_type.as_str(), t.value.as_str()) {
151            ("keyword", "SELECT") => {
152                let q = self.parse_select()?;
153                self.expect_end()?;
154                Ok(Statement::Select(q))
155            }
156            ("keyword", "INSERT") => Ok(Statement::Insert(self.parse_insert()?)),
157            ("keyword", "UPDATE") => Ok(Statement::Update(self.parse_update()?)),
158            ("keyword", "DELETE") => Ok(Statement::Delete(self.parse_delete()?)),
159            ("keyword", "ALTER") => self.parse_alter(),
160            ("keyword", "CREATE") => self.parse_create_view(),
161            ("keyword", "DROP") => self.parse_drop_view(),
162            _ => Err(MdqlError::QueryParse(format!(
163                "Expected SELECT, INSERT, UPDATE, DELETE, ALTER, CREATE, or DROP, got '{}'",
164                t.raw
165            ))),
166        }
167    }
168
169    fn parse_select(&mut self) -> Result<SelectQuery, MdqlError> {
170        self.expect("keyword", Some("SELECT"))?;
171        let columns = self.parse_columns()?;
172        self.expect("keyword", Some("FROM"))?;
173
174        // Subquery: FROM (SELECT ...)
175        let mut subquery = None;
176        let (table, mut table_alias) = if self.peek().map_or(false, |t| t.token_type == "op" && t.value == "(") {
177            self.advance();
178            let inner = self.parse_select()?;
179            self.expect("op", Some(")"))?;
180            subquery = Some(Box::new(inner));
181            let alias = if let Some(t) = self.peek() {
182                if t.token_type == "ident" && !self.is_clause_keyword(t) {
183                    Some(self.advance().value)
184                } else {
185                    None
186                }
187            } else {
188                None
189            };
190            ("_subquery".to_string(), alias)
191        } else {
192            let t = self.parse_ident()?;
193            (t, None)
194        };
195
196        // Optional table alias (for non-subquery)
197        if subquery.is_none() {
198            if let Some(t) = self.peek() {
199                if t.token_type == "ident" && !self.is_clause_keyword(t) {
200                    table_alias = Some(self.advance().value);
201                }
202            }
203        }
204
205        // Optional JOIN(s)
206        let mut joins = Vec::new();
207        while self.match_keyword("JOIN") {
208            let join_table = self.parse_ident()?;
209            let mut join_alias = None;
210            if let Some(t) = self.peek() {
211                if t.token_type == "ident" && !self.is_clause_keyword(t) {
212                    join_alias = Some(self.advance().value);
213                }
214            }
215            self.expect("keyword", Some("ON"))?;
216            let left_col = self.parse_ident()?;
217            self.expect("op", Some("="))?;
218            let right_col = self.parse_ident()?;
219            joins.push(JoinClause {
220                table: join_table,
221                alias: join_alias,
222                left_col,
223                right_col,
224            });
225        }
226
227        let mut where_clause = None;
228        if self.match_keyword("WHERE") {
229            where_clause = Some(self.parse_or_expr()?);
230        }
231
232        let mut group_by = None;
233        if self.match_keyword("GROUP") {
234            self.expect("keyword", Some("BY"))?;
235            let mut cols = vec![self.parse_ident()?];
236            while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
237                self.advance();
238                cols.push(self.parse_ident()?);
239            }
240            group_by = Some(cols);
241        }
242
243        let mut having = None;
244        if self.match_keyword("HAVING") {
245            having = Some(self.parse_or_expr()?);
246        }
247
248        let mut order_by = None;
249        if self.match_keyword("ORDER") {
250            self.expect("keyword", Some("BY"))?;
251            order_by = Some(self.parse_order_by()?);
252        }
253
254        let mut limit = None;
255        if self.match_keyword("LIMIT") {
256            let t = self.expect("number", None)?;
257            limit = Some(t.value.parse::<i64>().map_err(|_| {
258                MdqlError::QueryParse(format!("Invalid LIMIT value: {}", t.value))
259            })?);
260        }
261
262        Ok(SelectQuery {
263            columns,
264            table,
265            table_alias,
266            subquery,
267            joins,
268            where_clause,
269            group_by,
270            having,
271            order_by,
272            limit,
273        })
274    }
275
276    fn parse_insert(&mut self) -> Result<InsertQuery, MdqlError> {
277        self.expect("keyword", Some("INSERT"))?;
278        self.expect("keyword", Some("INTO"))?;
279        let table = self.parse_ident()?;
280
281        self.expect("op", Some("("))?;
282        let mut columns = vec![self.parse_ident()?];
283        while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
284            self.advance();
285            columns.push(self.parse_ident()?);
286        }
287        self.expect("op", Some(")"))?;
288
289        self.expect("keyword", Some("VALUES"))?;
290
291        self.expect("op", Some("("))?;
292        let mut values = vec![self.parse_value()?];
293        while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
294            self.advance();
295            values.push(self.parse_value()?);
296        }
297        self.expect("op", Some(")"))?;
298
299        if columns.len() != values.len() {
300            return Err(MdqlError::QueryParse(format!(
301                "Column count ({}) does not match value count ({})",
302                columns.len(),
303                values.len()
304            )));
305        }
306
307        self.expect_end()?;
308        Ok(InsertQuery {
309            table,
310            columns,
311            values,
312        })
313    }
314
315    fn parse_update(&mut self) -> Result<UpdateQuery, MdqlError> {
316        self.expect("keyword", Some("UPDATE"))?;
317        let table = self.parse_ident()?;
318        self.expect("keyword", Some("SET"))?;
319
320        let mut assignments = Vec::new();
321        let col = self.parse_ident()?;
322        self.expect("op", Some("="))?;
323        let val = self.parse_value()?;
324        assignments.push((col, val));
325
326        while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
327            self.advance();
328            let col = self.parse_ident()?;
329            self.expect("op", Some("="))?;
330            let val = self.parse_value()?;
331            assignments.push((col, val));
332        }
333
334        let mut where_clause = None;
335        if self.match_keyword("WHERE") {
336            where_clause = Some(self.parse_or_expr()?);
337        }
338
339        self.expect_end()?;
340        Ok(UpdateQuery {
341            table,
342            assignments,
343            where_clause,
344        })
345    }
346
347    fn parse_delete(&mut self) -> Result<DeleteQuery, MdqlError> {
348        self.expect("keyword", Some("DELETE"))?;
349        self.expect("keyword", Some("FROM"))?;
350        let table = self.parse_ident()?;
351
352        let mut where_clause = None;
353        if self.match_keyword("WHERE") {
354            where_clause = Some(self.parse_or_expr()?);
355        }
356
357        self.expect_end()?;
358        Ok(DeleteQuery {
359            table,
360            where_clause,
361        })
362    }
363
364    fn parse_alter(&mut self) -> Result<Statement, MdqlError> {
365        self.expect("keyword", Some("ALTER"))?;
366        self.expect("keyword", Some("TABLE"))?;
367        let table = self.parse_ident()?;
368
369        let t = self.peek().ok_or_else(|| {
370            MdqlError::QueryParse("Expected RENAME, DROP, or MERGE after table name".into())
371        })?;
372
373        match (t.token_type.as_str(), t.value.as_str()) {
374            ("keyword", "RENAME") => {
375                self.advance();
376                self.expect("keyword", Some("FIELD"))?;
377                let old_name = self.parse_string_or_ident()?;
378                self.expect("keyword", Some("TO"))?;
379                let new_name = self.parse_string_or_ident()?;
380                self.expect_end()?;
381                Ok(Statement::AlterRename(AlterRenameFieldQuery {
382                    table,
383                    old_name,
384                    new_name,
385                }))
386            }
387            ("keyword", "DROP") => {
388                self.advance();
389                self.expect("keyword", Some("FIELD"))?;
390                let field_name = self.parse_string_or_ident()?;
391                self.expect_end()?;
392                Ok(Statement::AlterDrop(AlterDropFieldQuery {
393                    table,
394                    field_name,
395                }))
396            }
397            ("keyword", "MERGE") => {
398                self.advance();
399                self.expect("keyword", Some("FIELDS"))?;
400                let mut sources = vec![self.parse_string_or_ident()?];
401                while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
402                    self.advance();
403                    sources.push(self.parse_string_or_ident()?);
404                }
405                self.expect("keyword", Some("INTO"))?;
406                let target = self.parse_string_or_ident()?;
407                self.expect_end()?;
408                Ok(Statement::AlterMerge(AlterMergeFieldsQuery {
409                    table,
410                    sources,
411                    into: target,
412                }))
413            }
414            _ => Err(MdqlError::QueryParse(format!(
415                "Expected RENAME, DROP, or MERGE, got '{}'",
416                t.raw
417            ))),
418        }
419    }
420
421    fn parse_create_view(&mut self) -> Result<Statement, MdqlError> {
422        self.expect("keyword", Some("CREATE"))?;
423        self.expect("keyword", Some("VIEW"))?;
424        let view_name = self.parse_ident()?;
425
426        let columns = if self.peek().map_or(false, |t| t.token_type == "op" && t.value == "(") {
427            self.advance();
428            let mut cols = vec![self.parse_ident()?];
429            while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
430                self.advance();
431                cols.push(self.parse_ident()?);
432            }
433            self.expect("op", Some(")"))?;
434            Some(cols)
435        } else {
436            None
437        };
438
439        self.expect("keyword", Some("AS"))?;
440        let query = Box::new(self.parse_select()?);
441        self.expect_end()?;
442
443        Ok(Statement::CreateView(CreateViewQuery {
444            view_name,
445            columns,
446            query,
447        }))
448    }
449
450    fn parse_drop_view(&mut self) -> Result<Statement, MdqlError> {
451        self.expect("keyword", Some("DROP"))?;
452        self.expect("keyword", Some("VIEW"))?;
453        let view_name = self.parse_ident()?;
454        self.expect_end()?;
455        Ok(Statement::DropView(DropViewQuery { view_name }))
456    }
457
458    fn parse_string_or_ident(&mut self) -> Result<String, MdqlError> {
459        let t = self.peek().ok_or_else(|| {
460            MdqlError::QueryParse("Expected field name, got end of query".into())
461        })?;
462        match t.token_type.as_str() {
463            "string" => {
464                let v = self.advance().value;
465                Ok(v)
466            }
467            "ident" | "keyword" => {
468                let v = self.advance().value;
469                Ok(v)
470            }
471            _ => Err(MdqlError::QueryParse(format!(
472                "Expected field name, got '{}'",
473                t.raw
474            ))),
475        }
476    }
477
478    fn parse_columns(&mut self) -> Result<ColumnList, MdqlError> {
479        if let Some(t) = self.peek() {
480            if t.token_type == "op" && t.value == "*" {
481                self.advance();
482                return Ok(ColumnList::All);
483            }
484        }
485
486        let mut exprs = vec![self.parse_select_expr()?];
487        while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
488            self.advance();
489            exprs.push(self.parse_select_expr()?);
490        }
491        Ok(ColumnList::Named(exprs))
492    }
493
494    fn peek_is_agg_func(&self) -> bool {
495        let t = match self.peek() {
496            Some(t) => t,
497            None => return false,
498        };
499        let name_upper = t.value.to_uppercase();
500        if !AGG_FUNCS.contains(&name_upper.as_str()) {
501            return false;
502        }
503        // Only treat as aggregate if followed by (
504        self.tokens
505            .get(self.pos + 1)
506            .map_or(false, |next| next.token_type == "op" && next.value == "(")
507    }
508
509    fn parse_select_expr(&mut self) -> Result<SelectExpr, MdqlError> {
510        let _t = self.peek().ok_or_else(|| {
511            MdqlError::QueryParse("Expected column or aggregate, got end of query".into())
512        })?;
513
514        let expr = self.parse_additive()?;
515
516        let alias = if self.match_keyword("AS") {
517            Some(self.parse_ident()?)
518        } else if self.peek().map_or(false, |t| {
519            t.token_type == "ident" && !self.is_clause_keyword(t)
520        }) {
521            Some(self.advance().value)
522        } else {
523            None
524        };
525
526        // Bare aggregate → SelectExpr::Aggregate for backward compat
527        if let Expr::Aggregate { func, arg, arg_expr } = expr {
528            return Ok(SelectExpr::Aggregate {
529                func,
530                arg,
531                arg_expr: arg_expr.map(|e| *e),
532                alias,
533            });
534        }
535
536        if alias.is_none() {
537            if let Expr::Column(name) = &expr {
538                return Ok(SelectExpr::Column(name.clone()));
539            }
540        }
541
542        Ok(SelectExpr::Expr { expr, alias })
543    }
544
545    // ── Expression parser (precedence climbing) ───────────────────────
546
547    fn peek_is_additive_op(&self) -> bool {
548        self.peek().map_or(false, |t| {
549            t.token_type == "op" && (t.value == "+" || t.value == "-")
550        })
551    }
552
553    fn peek_is_multiplicative_op(&self) -> bool {
554        self.peek().map_or(false, |t| {
555            t.token_type == "op" && (t.value == "*" || t.value == "/" || t.value == "%")
556        })
557    }
558
559    fn parse_additive(&mut self) -> Result<Expr, MdqlError> {
560        let mut left = self.parse_multiplicative()?;
561        while self.peek_is_additive_op() {
562            let op_tok = self.advance();
563            let is_sub = op_tok.value == "-";
564
565            // Check for INTERVAL keyword: expr +/- INTERVAL n DAY
566            if self.peek().map_or(false, |t| t.token_type == "keyword" && t.value == "INTERVAL") {
567                self.advance(); // consume INTERVAL
568                let days_expr = self.parse_multiplicative()?;
569                // Expect DAY or DAYS
570                if !self.match_keyword("DAY") && !self.match_keyword("DAYS") {
571                    return Err(MdqlError::QueryParse("Expected DAY after INTERVAL value".into()));
572                }
573                let days = if is_sub {
574                    Expr::UnaryMinus(Box::new(days_expr))
575                } else {
576                    days_expr
577                };
578                left = Expr::DateAdd {
579                    date: Box::new(left),
580                    days: Box::new(days),
581                };
582                continue;
583            }
584
585            let op = match op_tok.value.as_str() {
586                "+" => ArithOp::Add,
587                "-" => ArithOp::Sub,
588                _ => unreachable!(),
589            };
590            let right = self.parse_multiplicative()?;
591            left = Expr::BinaryOp {
592                left: Box::new(left),
593                op,
594                right: Box::new(right),
595            };
596        }
597        Ok(left)
598    }
599
600    fn parse_multiplicative(&mut self) -> Result<Expr, MdqlError> {
601        let mut left = self.parse_unary()?;
602        while self.peek_is_multiplicative_op() {
603            let op_tok = self.advance();
604            let op = match op_tok.value.as_str() {
605                "*" => ArithOp::Mul,
606                "/" => ArithOp::Div,
607                "%" => ArithOp::Mod,
608                _ => unreachable!(),
609            };
610            let right = self.parse_unary()?;
611            left = Expr::BinaryOp {
612                left: Box::new(left),
613                op,
614                right: Box::new(right),
615            };
616        }
617        Ok(left)
618    }
619
620    fn parse_unary(&mut self) -> Result<Expr, MdqlError> {
621        if self.peek().map_or(false, |t| t.token_type == "op" && t.value == "-") {
622            self.advance();
623            let inner = self.parse_atom()?;
624            // Fold unary minus on literals
625            match inner {
626                Expr::Literal(SqlValue::Int(n)) => Ok(Expr::Literal(SqlValue::Int(-n))),
627                Expr::Literal(SqlValue::Float(f)) => Ok(Expr::Literal(SqlValue::Float(-f))),
628                _ => Ok(Expr::UnaryMinus(Box::new(inner))),
629            }
630        } else {
631            self.parse_atom()
632        }
633    }
634
635    fn parse_atom(&mut self) -> Result<Expr, MdqlError> {
636        if self.peek_is_agg_func() {
637            return self.parse_agg_expr();
638        }
639
640        let t = self.peek().ok_or_else(|| {
641            MdqlError::QueryParse("Expected expression, got end of query".into())
642        })?;
643
644        match t.token_type.as_str() {
645            "number" => {
646                let v = self.advance().value;
647                if v.contains('.') {
648                    let f: f64 = v.parse().map_err(|_| {
649                        MdqlError::QueryParse(format!("Invalid float: {}", v))
650                    })?;
651                    Ok(Expr::Literal(SqlValue::Float(f)))
652                } else {
653                    let n: i64 = v.parse().map_err(|_| {
654                        MdqlError::QueryParse(format!("Invalid int: {}", v))
655                    })?;
656                    Ok(Expr::Literal(SqlValue::Int(n)))
657                }
658            }
659            "string" => {
660                let v = self.advance().value;
661                Ok(Expr::Literal(SqlValue::String(v)))
662            }
663            "keyword" if t.value == "NULL" => {
664                self.advance();
665                Ok(Expr::Literal(SqlValue::Null))
666            }
667            "keyword" if t.value == "CASE" => {
668                self.parse_case_expr()
669            }
670            "keyword" if t.value == "CURRENT_DATE" => {
671                self.advance();
672                Ok(Expr::CurrentDate)
673            }
674            "keyword" if t.value == "CURRENT_TIMESTAMP" => {
675                self.advance();
676                Ok(Expr::CurrentTimestamp)
677            }
678            "keyword" if t.value == "DATEDIFF" => {
679                self.advance();
680                self.expect("op", Some("("))?;
681                let left = self.parse_additive()?;
682                self.expect("op", Some(","))?;
683                let right = self.parse_additive()?;
684                self.expect("op", Some(")"))?;
685                Ok(Expr::DateDiff { left: Box::new(left), right: Box::new(right) })
686            }
687            "op" if t.value == "(" => {
688                self.advance();
689                let expr = self.parse_additive()?;
690                self.expect("op", Some(")"))?;
691                Ok(expr)
692            }
693            "ident" => {
694                let name = self.advance().value;
695                Ok(Expr::Column(name))
696            }
697            "keyword" if !Self::is_reserved_keyword(&t.value) => {
698                let name = self.advance().value;
699                Ok(Expr::Column(name))
700            }
701            _ => Err(MdqlError::QueryParse(format!(
702                "Expected expression, got '{}'",
703                t.raw
704            ))),
705        }
706    }
707
708    fn parse_case_expr(&mut self) -> Result<Expr, MdqlError> {
709        self.expect("keyword", Some("CASE"))?;
710        let mut whens = Vec::new();
711        while self.match_keyword("WHEN") {
712            let condition = self.parse_or_expr()?;
713            self.expect("keyword", Some("THEN"))?;
714            let result = self.parse_additive()?;
715            whens.push((condition, Box::new(result)));
716        }
717        if whens.is_empty() {
718            return Err(MdqlError::QueryParse("CASE requires at least one WHEN clause".into()));
719        }
720        let else_expr = if self.match_keyword("ELSE") {
721            Some(Box::new(self.parse_additive()?))
722        } else {
723            None
724        };
725        self.expect("keyword", Some("END"))?;
726        Ok(Expr::Case { whens, else_expr })
727    }
728
729    fn parse_agg_expr(&mut self) -> Result<Expr, MdqlError> {
730        let func_name = self.advance().value.to_uppercase();
731        let func = match func_name.as_str() {
732            "COUNT" => AggFunc::Count,
733            "SUM" => AggFunc::Sum,
734            "AVG" => AggFunc::Avg,
735            "MIN" => AggFunc::Min,
736            "MAX" => AggFunc::Max,
737            _ => unreachable!(),
738        };
739        self.expect("op", Some("("))?;
740        let (arg, arg_expr) = if self.peek().map_or(false, |t| t.token_type == "op" && t.value == "*") {
741            self.advance();
742            ("*".to_string(), None)
743        } else {
744            let expr = self.parse_additive()?;
745            if let Expr::Column(name) = &expr {
746                (name.clone(), None)
747            } else {
748                (expr.display_name(), Some(Box::new(expr)))
749            }
750        };
751        self.expect("op", Some(")"))?;
752        Ok(Expr::Aggregate { func, arg, arg_expr })
753    }
754
755    fn parse_ident(&mut self) -> Result<String, MdqlError> {
756        let t = self.peek().ok_or_else(|| {
757            MdqlError::QueryParse("Expected identifier, got end of query".into())
758        })?;
759        match t.token_type.as_str() {
760            "ident" | "keyword" => {
761                let v = self.advance().value;
762                Ok(v)
763            }
764            _ => Err(MdqlError::QueryParse(format!(
765                "Expected identifier, got '{}'",
766                t.raw
767            ))),
768        }
769    }
770
771    fn parse_or_expr(&mut self) -> Result<WhereClause, MdqlError> {
772        let mut left = self.parse_and_expr()?;
773        while self.match_keyword("OR") {
774            let right = self.parse_and_expr()?;
775            left = WhereClause::BoolOp(BoolOp {
776                op: "OR".into(),
777                left: Box::new(left),
778                right: Box::new(right),
779            });
780        }
781        Ok(left)
782    }
783
784    fn parse_and_expr(&mut self) -> Result<WhereClause, MdqlError> {
785        let mut left = self.parse_comparison()?;
786        while self.match_keyword("AND") {
787            let right = self.parse_comparison()?;
788            left = WhereClause::BoolOp(BoolOp {
789                op: "AND".into(),
790                left: Box::new(left),
791                right: Box::new(right),
792            });
793        }
794        Ok(left)
795    }
796
797    fn parse_comparison(&mut self) -> Result<WhereClause, MdqlError> {
798        // Handle parenthesized boolean expressions
799        if self.peek().map_or(false, |t| t.token_type == "op" && t.value == "(") {
800            // Save position — might be arithmetic parens, not boolean
801            let saved_pos = self.pos;
802            self.advance();
803            // Try parsing as boolean (OR/AND) expression
804            let result = self.parse_or_expr();
805            if result.is_ok() && self.peek().map_or(false, |t| t.token_type == "op" && t.value == ")") {
806                self.advance();
807                return result;
808            }
809            // Not a boolean paren — rewind and parse as arithmetic expression
810            self.pos = saved_pos;
811        }
812
813        // Parse the left side as a full expression (column, literal, or arithmetic)
814        let left_expr = self.parse_additive()?;
815
816        // Extract column name for backward compat (simple column on left side)
817        let col = left_expr.as_column().unwrap_or("").to_string();
818
819        // IS NULL / IS NOT NULL (only valid with simple column)
820        if self.match_keyword("IS") {
821            if self.match_keyword("NOT") {
822                self.expect("keyword", Some("NULL"))?;
823                return Ok(WhereClause::Comparison(Comparison {
824                    column: col,
825                    op: "IS NOT NULL".into(),
826                    value: None,
827                    left_expr: Some(left_expr),
828                    right_expr: None,
829                }));
830            }
831            self.expect("keyword", Some("NULL"))?;
832            return Ok(WhereClause::Comparison(Comparison {
833                column: col,
834                op: "IS NULL".into(),
835                value: None,
836                left_expr: Some(left_expr),
837                right_expr: None,
838            }));
839        }
840
841        // IN (val, val, ...)
842        if self.match_keyword("IN") {
843            self.expect("op", Some("("))?;
844            let mut values = vec![self.parse_value()?];
845            while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
846                self.advance();
847                values.push(self.parse_value()?);
848            }
849            self.expect("op", Some(")"))?;
850            return Ok(WhereClause::Comparison(Comparison {
851                column: col,
852                op: "IN".into(),
853                value: Some(SqlValue::List(values)),
854                left_expr: Some(left_expr),
855                right_expr: None,
856            }));
857        }
858
859        // LIKE
860        if self.match_keyword("LIKE") {
861            let val = self.parse_value()?;
862            return Ok(WhereClause::Comparison(Comparison {
863                column: col,
864                op: "LIKE".into(),
865                value: Some(val),
866                left_expr: Some(left_expr),
867                right_expr: None,
868            }));
869        }
870
871        // NOT LIKE
872        if self.match_keyword("NOT") {
873            if self.match_keyword("LIKE") {
874                let val = self.parse_value()?;
875                return Ok(WhereClause::Comparison(Comparison {
876                    column: col,
877                    op: "NOT LIKE".into(),
878                    value: Some(val),
879                    left_expr: Some(left_expr),
880                    right_expr: None,
881                }));
882            }
883            return Err(MdqlError::QueryParse("Expected LIKE after NOT".into()));
884        }
885
886        // Standard comparison operators
887        if let Some(t) = self.peek() {
888            if t.token_type == "op" && ["=", "!=", "<", ">", "<=", ">="].contains(&t.value.as_str())
889            {
890                let op = self.advance().value;
891                // Parse right side as expression
892                let right_expr = self.parse_additive()?;
893                // Extract SqlValue for backward compat (simple literal on right side)
894                let value = match &right_expr {
895                    Expr::Literal(v) => Some(v.clone()),
896                    _ => None,
897                };
898                return Ok(WhereClause::Comparison(Comparison {
899                    column: col,
900                    op,
901                    value,
902                    left_expr: Some(left_expr),
903                    right_expr: Some(right_expr),
904                }));
905            }
906        }
907
908        let got = self.peek().map_or("end".to_string(), |t| t.raw.clone());
909        Err(MdqlError::QueryParse(format!(
910            "Expected operator after '{}', got '{}'",
911            left_expr.display_name(), got
912        )))
913    }
914
915    fn parse_value(&mut self) -> Result<SqlValue, MdqlError> {
916        let t = self.peek().ok_or_else(|| {
917            MdqlError::QueryParse("Expected value, got end of query".into())
918        })?;
919        match t.token_type.as_str() {
920            "string" => {
921                let v = self.advance().value;
922                Ok(SqlValue::String(v))
923            }
924            "number" => {
925                let v = self.advance().value;
926                if v.contains('.') {
927                    Ok(SqlValue::Float(v.parse().map_err(|_| {
928                        MdqlError::QueryParse(format!("Invalid float: {}", v))
929                    })?))
930                } else {
931                    Ok(SqlValue::Int(v.parse().map_err(|_| {
932                        MdqlError::QueryParse(format!("Invalid int: {}", v))
933                    })?))
934                }
935            }
936            "keyword" if t.value == "NULL" => {
937                self.advance();
938                Ok(SqlValue::Null)
939            }
940            _ => Err(MdqlError::QueryParse(format!(
941                "Expected value, got '{}'",
942                t.raw
943            ))),
944        }
945    }
946
947    fn parse_order_by(&mut self) -> Result<Vec<OrderSpec>, MdqlError> {
948        let mut specs = vec![self.parse_order_spec()?];
949        while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
950            self.advance();
951            specs.push(self.parse_order_spec()?);
952        }
953        Ok(specs)
954    }
955
956    fn parse_order_spec(&mut self) -> Result<OrderSpec, MdqlError> {
957        let expr = self.parse_additive()?;
958        let col = expr.as_column().unwrap_or("").to_string();
959        let descending = if self.match_keyword("DESC") {
960            true
961        } else {
962            self.match_keyword("ASC");
963            false
964        };
965        Ok(OrderSpec {
966            column: col,
967            expr: Some(expr),
968            descending,
969        })
970    }
971
972    fn is_clause_keyword(&self, t: &Token) -> bool {
973        t.token_type == "keyword"
974            && ["WHERE", "ORDER", "LIMIT", "JOIN", "ON", "GROUP"].contains(&t.value.as_str())
975    }
976
977    /// Keywords that should never be consumed as column names inside expressions.
978    fn is_reserved_keyword(kw: &str) -> bool {
979        matches!(kw,
980            "AS" | "FROM" | "WHERE" | "AND" | "OR" | "ORDER" | "BY"
981            | "ASC" | "DESC" | "LIMIT" | "JOIN" | "ON" | "GROUP"
982            | "SELECT" | "INSERT" | "INTO" | "VALUES" | "UPDATE" | "SET"
983            | "DELETE" | "ALTER" | "TABLE" | "IS" | "NOT" | "IN" | "LIKE"
984            | "RENAME" | "FIELD" | "TO" | "DROP" | "MERGE" | "FIELDS"
985            | "CASE" | "WHEN" | "THEN" | "ELSE" | "END"
986            | "HAVING" | "INTERVAL" | "DAY" | "DAYS"
987            | "CURRENT_DATE" | "CURRENT_TIMESTAMP" | "DATEDIFF"
988            | "CREATE" | "VIEW" | "CASCADE" | "RESTRICT"
989        )
990    }
991
992    fn expect_end(&self) -> Result<(), MdqlError> {
993        if let Some(t) = self.peek() {
994            return Err(MdqlError::QueryParse(format!(
995                "Unexpected token '{}' at position {}",
996                t.raw, self.pos
997            )));
998        }
999        Ok(())
1000    }
1001}
1002
1003pub fn parse_query(sql: &str) -> crate::errors::Result<Statement> {
1004    let tokens = tokenize(sql);
1005    if tokens.is_empty() {
1006        return Err(MdqlError::QueryParse("Empty query".into()));
1007    }
1008    let mut parser = Parser::new(tokens);
1009    parser.parse_statement()
1010}
1011
1012#[cfg(test)]
1013mod tests {
1014    use super::*;
1015
1016    #[test]
1017    fn test_simple_select() {
1018        let stmt = parse_query("SELECT title, status FROM strategies").unwrap();
1019        if let Statement::Select(q) = stmt {
1020            assert_eq!(q.columns, ColumnList::Named(vec![SelectExpr::Column("title".into()), SelectExpr::Column("status".into())]));
1021            assert_eq!(q.table, "strategies");
1022        } else {
1023            panic!("Expected Select");
1024        }
1025    }
1026
1027    #[test]
1028    fn test_select_star() {
1029        let stmt = parse_query("SELECT * FROM test").unwrap();
1030        if let Statement::Select(q) = stmt {
1031            assert_eq!(q.columns, ColumnList::All);
1032        } else {
1033            panic!("Expected Select");
1034        }
1035    }
1036
1037    #[test]
1038    fn test_where_clause() {
1039        let stmt = parse_query("SELECT title FROM test WHERE count > 5").unwrap();
1040        if let Statement::Select(q) = stmt {
1041            assert!(q.where_clause.is_some());
1042        } else {
1043            panic!("Expected Select");
1044        }
1045    }
1046
1047    #[test]
1048    fn test_order_by() {
1049        let stmt =
1050            parse_query("SELECT title FROM test ORDER BY composite DESC, title ASC").unwrap();
1051        if let Statement::Select(q) = stmt {
1052            let ob = q.order_by.unwrap();
1053            assert_eq!(ob.len(), 2);
1054            assert!(ob[0].descending);
1055            assert!(!ob[1].descending);
1056        } else {
1057            panic!("Expected Select");
1058        }
1059    }
1060
1061    #[test]
1062    fn test_limit() {
1063        let stmt = parse_query("SELECT * FROM test LIMIT 10").unwrap();
1064        if let Statement::Select(q) = stmt {
1065            assert_eq!(q.limit, Some(10));
1066        } else {
1067            panic!("Expected Select");
1068        }
1069    }
1070
1071    #[test]
1072    fn test_insert() {
1073        let stmt = parse_query(
1074            "INSERT INTO test (title, count) VALUES ('Hello', 42)",
1075        )
1076        .unwrap();
1077        if let Statement::Insert(q) = stmt {
1078            assert_eq!(q.table, "test");
1079            assert_eq!(q.columns, vec!["title", "count"]);
1080            assert_eq!(q.values[0], SqlValue::String("Hello".into()));
1081            assert_eq!(q.values[1], SqlValue::Int(42));
1082        } else {
1083            panic!("Expected Insert");
1084        }
1085    }
1086
1087    #[test]
1088    fn test_update() {
1089        let stmt = parse_query("UPDATE test SET status = 'KILLED' WHERE path = 'a.md'").unwrap();
1090        if let Statement::Update(q) = stmt {
1091            assert_eq!(q.table, "test");
1092            assert_eq!(q.assignments.len(), 1);
1093            assert!(q.where_clause.is_some());
1094        } else {
1095            panic!("Expected Update");
1096        }
1097    }
1098
1099    #[test]
1100    fn test_delete() {
1101        let stmt = parse_query("DELETE FROM test WHERE status = 'draft'").unwrap();
1102        if let Statement::Delete(q) = stmt {
1103            assert_eq!(q.table, "test");
1104            assert!(q.where_clause.is_some());
1105        } else {
1106            panic!("Expected Delete");
1107        }
1108    }
1109
1110    #[test]
1111    fn test_alter_rename() {
1112        let stmt =
1113            parse_query("ALTER TABLE test RENAME FIELD 'Summary' TO 'Overview'").unwrap();
1114        if let Statement::AlterRename(q) = stmt {
1115            assert_eq!(q.old_name, "Summary");
1116            assert_eq!(q.new_name, "Overview");
1117        } else {
1118            panic!("Expected AlterRename");
1119        }
1120    }
1121
1122    #[test]
1123    fn test_alter_drop() {
1124        let stmt = parse_query("ALTER TABLE test DROP FIELD 'Details'").unwrap();
1125        if let Statement::AlterDrop(q) = stmt {
1126            assert_eq!(q.field_name, "Details");
1127        } else {
1128            panic!("Expected AlterDrop");
1129        }
1130    }
1131
1132    #[test]
1133    fn test_alter_merge() {
1134        let stmt = parse_query(
1135            "ALTER TABLE test MERGE FIELDS 'Entry Rules', 'Exit Rules' INTO 'Trading Rules'",
1136        )
1137        .unwrap();
1138        if let Statement::AlterMerge(q) = stmt {
1139            assert_eq!(q.sources, vec!["Entry Rules", "Exit Rules"]);
1140            assert_eq!(q.into, "Trading Rules");
1141        } else {
1142            panic!("Expected AlterMerge");
1143        }
1144    }
1145
1146    #[test]
1147    fn test_backtick_ident() {
1148        let stmt = parse_query("SELECT `Structural Mechanism` FROM test").unwrap();
1149        if let Statement::Select(q) = stmt {
1150            assert_eq!(
1151                q.columns,
1152                ColumnList::Named(vec![SelectExpr::Column("Structural Mechanism".into())])
1153            );
1154        } else {
1155            panic!("Expected Select");
1156        }
1157    }
1158
1159    #[test]
1160    fn test_like_operator() {
1161        let stmt = parse_query("SELECT title FROM test WHERE categories LIKE '%defi%'").unwrap();
1162        if let Statement::Select(q) = stmt {
1163            if let Some(WhereClause::Comparison(c)) = q.where_clause {
1164                assert_eq!(c.op, "LIKE");
1165                assert_eq!(c.value, Some(SqlValue::String("%defi%".into())));
1166            } else {
1167                panic!("Expected LIKE comparison");
1168            }
1169        } else {
1170            panic!("Expected Select");
1171        }
1172    }
1173
1174    #[test]
1175    fn test_in_operator() {
1176        let stmt =
1177            parse_query("SELECT * FROM test WHERE status IN ('ACTIVE', 'LIVE')").unwrap();
1178        if let Statement::Select(q) = stmt {
1179            if let Some(WhereClause::Comparison(c)) = q.where_clause {
1180                assert_eq!(c.op, "IN");
1181            } else {
1182                panic!("Expected IN comparison");
1183            }
1184        } else {
1185            panic!("Expected Select");
1186        }
1187    }
1188
1189    #[test]
1190    fn test_is_null() {
1191        let stmt = parse_query("SELECT * FROM test WHERE title IS NULL").unwrap();
1192        if let Statement::Select(q) = stmt {
1193            if let Some(WhereClause::Comparison(c)) = q.where_clause {
1194                assert_eq!(c.op, "IS NULL");
1195            } else {
1196                panic!("Expected IS NULL comparison");
1197            }
1198        } else {
1199            panic!("Expected Select");
1200        }
1201    }
1202
1203    #[test]
1204    fn test_and_or() {
1205        let stmt = parse_query(
1206            "SELECT * FROM test WHERE status = 'ACTIVE' AND count > 5 OR title LIKE '%test%'",
1207        )
1208        .unwrap();
1209        if let Statement::Select(q) = stmt {
1210            assert!(q.where_clause.is_some());
1211        } else {
1212            panic!("Expected Select");
1213        }
1214    }
1215
1216    #[test]
1217    fn test_join() {
1218        let stmt = parse_query(
1219            "SELECT s.title, b.sharpe FROM strategies s JOIN backtests b ON b.strategy = s.path",
1220        )
1221        .unwrap();
1222        if let Statement::Select(q) = stmt {
1223            assert_eq!(q.table, "strategies");
1224            assert_eq!(q.table_alias, Some("s".into()));
1225            assert_eq!(q.joins.len(), 1);
1226            let join = &q.joins[0];
1227            assert_eq!(join.table, "backtests");
1228            assert_eq!(join.alias, Some("b".into()));
1229        } else {
1230            panic!("Expected Select");
1231        }
1232    }
1233
1234    #[test]
1235    fn test_multi_join() {
1236        let stmt = parse_query(
1237            "SELECT s.title, b.sharpe, c.verdict FROM strategies s JOIN backtests b ON b.strategy = s.path JOIN critiques c ON c.strategy = s.path",
1238        )
1239        .unwrap();
1240        if let Statement::Select(q) = stmt {
1241            assert_eq!(q.table, "strategies");
1242            assert_eq!(q.table_alias, Some("s".into()));
1243            assert_eq!(q.joins.len(), 2);
1244            assert_eq!(q.joins[0].table, "backtests");
1245            assert_eq!(q.joins[0].alias, Some("b".into()));
1246            assert_eq!(q.joins[0].left_col, "b.strategy");
1247            assert_eq!(q.joins[0].right_col, "s.path");
1248            assert_eq!(q.joins[1].table, "critiques");
1249            assert_eq!(q.joins[1].alias, Some("c".into()));
1250            assert_eq!(q.joins[1].left_col, "c.strategy");
1251            assert_eq!(q.joins[1].right_col, "s.path");
1252        } else {
1253            panic!("Expected Select");
1254        }
1255    }
1256
1257    #[test]
1258    fn test_empty_query() {
1259        assert!(parse_query("").is_err());
1260    }
1261
1262    #[test]
1263    fn test_count_star() {
1264        let stmt = parse_query("SELECT status, COUNT(*) AS cnt FROM strategies GROUP BY status").unwrap();
1265        if let Statement::Select(q) = stmt {
1266            if let ColumnList::Named(exprs) = &q.columns {
1267                assert_eq!(exprs.len(), 2);
1268                assert_eq!(exprs[0], SelectExpr::Column("status".into()));
1269                assert!(matches!(&exprs[1], SelectExpr::Aggregate {
1270                    func: AggFunc::Count,
1271                    arg,
1272                    alias: Some(a),
1273                    ..
1274                } if arg == "*" && a == "cnt"));
1275            } else {
1276                panic!("Expected Named columns");
1277            }
1278            assert_eq!(q.group_by, Some(vec!["status".into()]));
1279        } else {
1280            panic!("Expected Select");
1281        }
1282    }
1283
1284    #[test]
1285    fn test_count_column_as_ident() {
1286        // "count" as a column name should NOT be parsed as the COUNT aggregate
1287        let stmt = parse_query("INSERT INTO test (title, count) VALUES ('Hello', 42)").unwrap();
1288        if let Statement::Insert(q) = stmt {
1289            assert_eq!(q.columns, vec!["title", "count"]);
1290        } else {
1291            panic!("Expected Insert");
1292        }
1293    }
1294
1295    #[test]
1296    fn test_multiple_aggregates() {
1297        let stmt = parse_query("SELECT MIN(composite), MAX(composite), AVG(composite) FROM strategies").unwrap();
1298        if let Statement::Select(q) = stmt {
1299            if let ColumnList::Named(exprs) = &q.columns {
1300                assert_eq!(exprs.len(), 3);
1301                assert!(matches!(&exprs[0], SelectExpr::Aggregate { func: AggFunc::Min, .. }));
1302                assert!(matches!(&exprs[1], SelectExpr::Aggregate { func: AggFunc::Max, .. }));
1303                assert!(matches!(&exprs[2], SelectExpr::Aggregate { func: AggFunc::Avg, .. }));
1304            } else {
1305                panic!("Expected Named columns");
1306            }
1307            assert_eq!(q.group_by, None);
1308        } else {
1309            panic!("Expected Select");
1310        }
1311    }
1312
1313    // ── Expression tests ──────────────────────────────────────────
1314
1315    #[test]
1316    fn test_select_arithmetic_expr() {
1317        let stmt = parse_query("SELECT a + b FROM test").unwrap();
1318        if let Statement::Select(q) = stmt {
1319            if let ColumnList::Named(exprs) = &q.columns {
1320                assert_eq!(exprs.len(), 1);
1321                assert!(matches!(&exprs[0], SelectExpr::Expr {
1322                    expr: Expr::BinaryOp { op: ArithOp::Add, .. },
1323                    alias: None,
1324                }));
1325            } else {
1326                panic!("Expected Named columns");
1327            }
1328        } else {
1329            panic!("Expected Select");
1330        }
1331    }
1332
1333    #[test]
1334    fn test_select_arithmetic_with_alias() {
1335        let stmt = parse_query("SELECT a + b AS total FROM test").unwrap();
1336        if let Statement::Select(q) = stmt {
1337            if let ColumnList::Named(exprs) = &q.columns {
1338                assert_eq!(exprs.len(), 1);
1339                assert!(matches!(&exprs[0], SelectExpr::Expr {
1340                    alias: Some(a),
1341                    ..
1342                } if a == "total"));
1343                assert_eq!(exprs[0].output_name(), "total");
1344            } else {
1345                panic!("Expected Named columns");
1346            }
1347        } else {
1348            panic!("Expected Select");
1349        }
1350    }
1351
1352    #[test]
1353    fn test_select_precedence() {
1354        // a + b * c should parse as a + (b * c)
1355        let stmt = parse_query("SELECT a + b * c FROM test").unwrap();
1356        if let Statement::Select(q) = stmt {
1357            if let ColumnList::Named(exprs) = &q.columns {
1358                if let SelectExpr::Expr { expr, .. } = &exprs[0] {
1359                    if let Expr::BinaryOp { left, op, right } = expr {
1360                        assert_eq!(*op, ArithOp::Add);
1361                        assert!(matches!(left.as_ref(), Expr::Column(n) if n == "a"));
1362                        assert!(matches!(right.as_ref(), Expr::BinaryOp { op: ArithOp::Mul, .. }));
1363                    } else {
1364                        panic!("Expected BinaryOp");
1365                    }
1366                } else {
1367                    panic!("Expected Expr variant");
1368                }
1369            } else {
1370                panic!("Expected Named columns");
1371            }
1372        } else {
1373            panic!("Expected Select");
1374        }
1375    }
1376
1377    #[test]
1378    fn test_select_parenthesized_expr() {
1379        // (a + b) * c should override default precedence
1380        let stmt = parse_query("SELECT (a + b) * c FROM test").unwrap();
1381        if let Statement::Select(q) = stmt {
1382            if let ColumnList::Named(exprs) = &q.columns {
1383                if let SelectExpr::Expr { expr, .. } = &exprs[0] {
1384                    if let Expr::BinaryOp { left, op, .. } = expr {
1385                        assert_eq!(*op, ArithOp::Mul);
1386                        assert!(matches!(left.as_ref(), Expr::BinaryOp { op: ArithOp::Add, .. }));
1387                    } else {
1388                        panic!("Expected BinaryOp");
1389                    }
1390                } else {
1391                    panic!("Expected Expr variant");
1392                }
1393            } else {
1394                panic!("Expected Named columns");
1395            }
1396        } else {
1397            panic!("Expected Select");
1398        }
1399    }
1400
1401    #[test]
1402    fn test_select_unary_minus() {
1403        let stmt = parse_query("SELECT -count FROM test").unwrap();
1404        if let Statement::Select(q) = stmt {
1405            if let ColumnList::Named(exprs) = &q.columns {
1406                assert!(matches!(&exprs[0], SelectExpr::Expr {
1407                    expr: Expr::UnaryMinus(_),
1408                    ..
1409                }));
1410            } else {
1411                panic!("Expected Named columns");
1412            }
1413        } else {
1414            panic!("Expected Select");
1415        }
1416    }
1417
1418    #[test]
1419    fn test_select_negative_literal() {
1420        let stmt = parse_query("SELECT -42 FROM test").unwrap();
1421        if let Statement::Select(q) = stmt {
1422            if let ColumnList::Named(exprs) = &q.columns {
1423                // Unary minus folds into the literal
1424                assert!(matches!(&exprs[0], SelectExpr::Expr {
1425                    expr: Expr::Literal(SqlValue::Int(-42)),
1426                    ..
1427                }));
1428            } else {
1429                panic!("Expected Named columns");
1430            }
1431        } else {
1432            panic!("Expected Select");
1433        }
1434    }
1435
1436    #[test]
1437    fn test_where_arithmetic_expr() {
1438        let stmt = parse_query("SELECT * FROM test WHERE a + b > 10").unwrap();
1439        if let Statement::Select(q) = stmt {
1440            if let Some(WhereClause::Comparison(c)) = q.where_clause {
1441                assert_eq!(c.op, ">");
1442                assert!(matches!(&c.left_expr, Some(Expr::BinaryOp { op: ArithOp::Add, .. })));
1443                assert!(matches!(&c.right_expr, Some(Expr::Literal(SqlValue::Int(10)))));
1444            } else {
1445                panic!("Expected comparison");
1446            }
1447        } else {
1448            panic!("Expected Select");
1449        }
1450    }
1451
1452    #[test]
1453    fn test_where_both_sides_expr() {
1454        let stmt = parse_query("SELECT * FROM test WHERE a * 2 > b + 1").unwrap();
1455        if let Statement::Select(q) = stmt {
1456            if let Some(WhereClause::Comparison(c)) = q.where_clause {
1457                assert_eq!(c.op, ">");
1458                assert!(matches!(&c.left_expr, Some(Expr::BinaryOp { op: ArithOp::Mul, .. })));
1459                assert!(matches!(&c.right_expr, Some(Expr::BinaryOp { op: ArithOp::Add, .. })));
1460            } else {
1461                panic!("Expected comparison");
1462            }
1463        } else {
1464            panic!("Expected Select");
1465        }
1466    }
1467
1468    #[test]
1469    fn test_order_by_expr() {
1470        let stmt = parse_query("SELECT * FROM test ORDER BY a + b DESC").unwrap();
1471        if let Statement::Select(q) = stmt {
1472            let ob = q.order_by.unwrap();
1473            assert_eq!(ob.len(), 1);
1474            assert!(ob[0].descending);
1475            assert!(matches!(&ob[0].expr, Some(Expr::BinaryOp { op: ArithOp::Add, .. })));
1476        } else {
1477            panic!("Expected Select");
1478        }
1479    }
1480
1481    #[test]
1482    fn test_all_arithmetic_ops() {
1483        let stmt = parse_query("SELECT a + b, a - b, a * b, a / b, a % b FROM test").unwrap();
1484        if let Statement::Select(q) = stmt {
1485            if let ColumnList::Named(exprs) = &q.columns {
1486                assert_eq!(exprs.len(), 5);
1487                assert!(matches!(&exprs[0], SelectExpr::Expr { expr: Expr::BinaryOp { op: ArithOp::Add, .. }, .. }));
1488                assert!(matches!(&exprs[1], SelectExpr::Expr { expr: Expr::BinaryOp { op: ArithOp::Sub, .. }, .. }));
1489                assert!(matches!(&exprs[2], SelectExpr::Expr { expr: Expr::BinaryOp { op: ArithOp::Mul, .. }, .. }));
1490                assert!(matches!(&exprs[3], SelectExpr::Expr { expr: Expr::BinaryOp { op: ArithOp::Div, .. }, .. }));
1491                assert!(matches!(&exprs[4], SelectExpr::Expr { expr: Expr::BinaryOp { op: ArithOp::Mod, .. }, .. }));
1492            } else {
1493                panic!("Expected Named columns");
1494            }
1495        } else {
1496            panic!("Expected Select");
1497        }
1498    }
1499
1500    #[test]
1501    fn test_column_with_literal_arithmetic() {
1502        let stmt = parse_query("SELECT count * 2 + 1 FROM test").unwrap();
1503        if let Statement::Select(q) = stmt {
1504            if let ColumnList::Named(exprs) = &q.columns {
1505                // Should be (count * 2) + 1
1506                if let SelectExpr::Expr { expr, .. } = &exprs[0] {
1507                    if let Expr::BinaryOp { left, op, right } = expr {
1508                        assert_eq!(*op, ArithOp::Add);
1509                        assert!(matches!(right.as_ref(), Expr::Literal(SqlValue::Int(1))));
1510                        assert!(matches!(left.as_ref(), Expr::BinaryOp { op: ArithOp::Mul, .. }));
1511                    } else {
1512                        panic!("Expected BinaryOp");
1513                    }
1514                } else {
1515                    panic!("Expected Expr");
1516                }
1517            } else {
1518                panic!("Expected Named columns");
1519            }
1520        } else {
1521            panic!("Expected Select");
1522        }
1523    }
1524
1525    #[test]
1526    fn test_mixed_columns_and_exprs() {
1527        let stmt = parse_query("SELECT title, a + b AS sum, count FROM test").unwrap();
1528        if let Statement::Select(q) = stmt {
1529            if let ColumnList::Named(exprs) = &q.columns {
1530                assert_eq!(exprs.len(), 3);
1531                assert_eq!(exprs[0], SelectExpr::Column("title".into()));
1532                assert!(matches!(&exprs[1], SelectExpr::Expr { alias: Some(a), .. } if a == "sum"));
1533                assert_eq!(exprs[2], SelectExpr::Column("count".into()));
1534            } else {
1535                panic!("Expected Named columns");
1536            }
1537        } else {
1538            panic!("Expected Select");
1539        }
1540    }
1541
1542    // ── CASE WHEN tests ──────────────────────────────────────────
1543
1544    #[test]
1545    fn test_case_when_basic() {
1546        let stmt = parse_query(
1547            "SELECT CASE WHEN status = 'ACTIVE' THEN 1 ELSE 0 END FROM test"
1548        ).unwrap();
1549        if let Statement::Select(q) = stmt {
1550            if let ColumnList::Named(exprs) = &q.columns {
1551                assert_eq!(exprs.len(), 1);
1552                assert!(matches!(&exprs[0], SelectExpr::Expr {
1553                    expr: Expr::Case { .. },
1554                    ..
1555                }));
1556            } else {
1557                panic!("Expected Named columns");
1558            }
1559        } else {
1560            panic!("Expected Select");
1561        }
1562    }
1563
1564    #[test]
1565    fn test_case_when_multiple_branches() {
1566        let stmt = parse_query(
1567            "SELECT CASE WHEN x > 10 THEN 'high' WHEN x > 5 THEN 'mid' ELSE 'low' END FROM test"
1568        ).unwrap();
1569        if let Statement::Select(q) = stmt {
1570            if let ColumnList::Named(exprs) = &q.columns {
1571                if let SelectExpr::Expr { expr: Expr::Case { whens, else_expr }, .. } = &exprs[0] {
1572                    assert_eq!(whens.len(), 2);
1573                    assert!(else_expr.is_some());
1574                } else {
1575                    panic!("Expected Case expression");
1576                }
1577            } else {
1578                panic!("Expected Named columns");
1579            }
1580        } else {
1581            panic!("Expected Select");
1582        }
1583    }
1584
1585    #[test]
1586    fn test_case_when_no_else() {
1587        let stmt = parse_query(
1588            "SELECT CASE WHEN x = 1 THEN 'one' END FROM test"
1589        ).unwrap();
1590        if let Statement::Select(q) = stmt {
1591            if let ColumnList::Named(exprs) = &q.columns {
1592                if let SelectExpr::Expr { expr: Expr::Case { whens, else_expr }, .. } = &exprs[0] {
1593                    assert_eq!(whens.len(), 1);
1594                    assert!(else_expr.is_none());
1595                } else {
1596                    panic!("Expected Case expression");
1597                }
1598            } else {
1599                panic!("Expected Named columns");
1600            }
1601        } else {
1602            panic!("Expected Select");
1603        }
1604    }
1605
1606    #[test]
1607    fn test_case_when_in_aggregate() {
1608        let stmt = parse_query(
1609            "SELECT SUM(CASE WHEN side = 'BUY' THEN size ELSE -size END) AS net FROM orders GROUP BY token"
1610        ).unwrap();
1611        if let Statement::Select(q) = stmt {
1612            if let ColumnList::Named(exprs) = &q.columns {
1613                assert_eq!(exprs.len(), 1);
1614                assert!(matches!(&exprs[0], SelectExpr::Aggregate {
1615                    func: AggFunc::Sum,
1616                    arg_expr: Some(Expr::Case { .. }),
1617                    alias: Some(a),
1618                    ..
1619                } if a == "net"));
1620            } else {
1621                panic!("Expected Named columns");
1622            }
1623        } else {
1624            panic!("Expected Select");
1625        }
1626    }
1627
1628    #[test]
1629    fn test_case_when_with_alias() {
1630        let stmt = parse_query(
1631            "SELECT CASE WHEN x > 0 THEN 'pos' ELSE 'neg' END AS sign FROM test"
1632        ).unwrap();
1633        if let Statement::Select(q) = stmt {
1634            if let ColumnList::Named(exprs) = &q.columns {
1635                assert!(matches!(&exprs[0], SelectExpr::Expr {
1636                    expr: Expr::Case { .. },
1637                    alias: Some(a),
1638                } if a == "sign"));
1639            } else {
1640                panic!("Expected Named columns");
1641            }
1642        } else {
1643            panic!("Expected Select");
1644        }
1645    }
1646
1647    #[test]
1648    fn test_create_view() {
1649        let stmt = parse_query("CREATE VIEW live AS SELECT * FROM strategies WHERE status = 'LIVE'").unwrap();
1650        if let Statement::CreateView(cv) = stmt {
1651            assert_eq!(cv.view_name, "live");
1652            assert!(cv.columns.is_none());
1653            assert_eq!(cv.query.table, "strategies");
1654            assert!(cv.query.where_clause.is_some());
1655        } else {
1656            panic!("Expected CreateView, got {:?}", stmt);
1657        }
1658    }
1659
1660    #[test]
1661    fn test_create_view_with_columns() {
1662        let stmt = parse_query("CREATE VIEW v1 (a, b) AS SELECT title, status FROM t").unwrap();
1663        if let Statement::CreateView(cv) = stmt {
1664            assert_eq!(cv.view_name, "v1");
1665            assert_eq!(cv.columns, Some(vec!["a".into(), "b".into()]));
1666        } else {
1667            panic!("Expected CreateView");
1668        }
1669    }
1670
/// `DROP VIEW <name>` must produce a `Statement::DropView` carrying the name.
#[test]
fn test_drop_view() {
    match parse_query("DROP VIEW live").unwrap() {
        Statement::DropView(dropped) => assert_eq!(dropped.view_name, "live"),
        other => panic!("Expected DropView, got {:?}", other),
    }
}
1680
/// Lower-case keywords are accepted, and the view name keeps its original
/// casing (`My_View`).
#[test]
fn test_create_view_case_insensitive() {
    let Statement::CreateView(view) = parse_query("create view My_View as select * from t").unwrap()
    else {
        panic!("Expected CreateView");
    };
    assert_eq!(view.view_name, "My_View");
}
1690
1691    // ── Issue #42: Arithmetic between aggregates in column expressions ──
1692
/// `SUM(sell) / SUM(buy)` in the select list: the query keeps its GROUP BY,
/// has two named columns, and the second one is reported as an aggregate.
#[test]
fn test_aggregate_division() {
    let sql = "SELECT token, SUM(sell) / SUM(buy) as ratio FROM orders GROUP BY token";
    let Statement::Select(query) = parse_query(sql).unwrap() else {
        panic!("Expected Select");
    };
    assert_eq!(query.group_by, Some(vec!["token".into()]));

    let ColumnList::Named(exprs) = &query.columns else {
        panic!("Expected Named columns");
    };
    assert_eq!(exprs.len(), 2);
    assert!(exprs[1].is_aggregate());
}
1710
/// `SUM(sell) - SUM(buy) as net` in the select list must parse into a second
/// named column whose output name is `net` and which counts as an aggregate.
///
/// The original test silently passed when the column list was not
/// `ColumnList::Named`; that case now fails loudly, and the column count is
/// checked before indexing so a short list yields a clear assertion failure
/// rather than an out-of-bounds panic.
#[test]
fn test_aggregate_subtraction() {
    let stmt = parse_query(
        "SELECT token, SUM(sell) - SUM(buy) as net FROM orders GROUP BY token"
    ).unwrap();
    if let Statement::Select(q) = stmt {
        if let ColumnList::Named(exprs) = &q.columns {
            // Same expectations as the CREATE VIEW variant of this query.
            assert_eq!(exprs.len(), 2);
            assert_eq!(exprs[1].output_name(), "net");
            assert!(exprs[1].is_aggregate());
        } else {
            panic!("Expected Named columns");
        }
    } else {
        panic!("Expected Select");
    }
}
1724
/// A view whose SELECT contains arithmetic between aggregates must still
/// parse as `Statement::CreateView` with the given name.
#[test]
fn test_create_view_with_arithmetic() {
    let sql = "CREATE VIEW positions AS SELECT token, SUM(sell) / SUM(buy) as ratio FROM orders GROUP BY token";
    match parse_query(sql).unwrap() {
        Statement::CreateView(view) => assert_eq!(view.view_name, "positions"),
        other => panic!("Expected CreateView, got {:?}", other),
    }
}
1736
1737    // ── Issue #43: Subqueries in FROM ──
1738
/// A parenthesised SELECT in the FROM position is stored in `subquery`; the
/// outer LIMIT stays on the outer query and the inner query keeps its table
/// and GROUP BY.
#[test]
fn test_subquery_in_from() {
    let sql = "SELECT token, sell_size FROM (SELECT token, SUM(size) as sell_size FROM orders GROUP BY token) LIMIT 5";
    let Statement::Select(outer) = parse_query(sql).unwrap() else {
        panic!("Expected Select");
    };
    assert!(outer.subquery.is_some());
    assert_eq!(outer.limit, Some(5));

    let inner = outer.subquery.unwrap();
    assert_eq!(inner.table, "orders");
    assert!(inner.group_by.is_some());
}
1754
1755    // ── Issue #44: HAVING in CREATE VIEW ──
1756
/// A HAVING clause inside the SELECT of a `CREATE VIEW` must be kept on the
/// view's query.
#[test]
fn test_create_view_with_having() {
    let sql = "CREATE VIEW positions AS SELECT token, SUM(sell) as sell_size, SUM(buy) as buy_size FROM orders GROUP BY token HAVING sell_size > buy_size";
    match parse_query(sql).unwrap() {
        Statement::CreateView(view) => {
            assert_eq!(view.view_name, "positions");
            assert!(view.query.having.is_some());
        }
        other => panic!("Expected CreateView, got {:?}", other),
    }
}
1769
1770    // ── Issue #42: Aggregate multiplication ──
1771
/// `SUM(a) * 2 as doubled`: a single named column that is an aggregate and
/// is output under its alias.
#[test]
fn test_aggregate_multiplication() {
    let Statement::Select(query) = parse_query("SELECT SUM(a) * 2 as doubled FROM test").unwrap()
    else {
        panic!("Expected Select");
    };
    let ColumnList::Named(exprs) = &query.columns else {
        panic!("Expected Named columns");
    };
    assert_eq!(exprs.len(), 1);

    let only = &exprs[0];
    assert!(only.is_aggregate());
    assert_eq!(only.output_name(), "doubled");
}
1789
/// Division between two `SUM(CASE ... END)` aggregates: one named column,
/// flagged as aggregate, aliased `ratio`, with the GROUP BY preserved.
#[test]
fn test_complex_aggregate_arithmetic() {
    let sql = "SELECT SUM(CASE WHEN side = 'SELL' THEN size ELSE 0 END) / SUM(CASE WHEN side = 'BUY' THEN size ELSE 0 END) as ratio FROM orders GROUP BY token";
    let Statement::Select(query) = parse_query(sql).unwrap() else {
        panic!("Expected Select");
    };
    let ColumnList::Named(exprs) = &query.columns else {
        panic!("Expected Named columns");
    };
    assert_eq!(exprs.len(), 1);
    assert!(exprs[0].is_aggregate());
    assert_eq!(exprs[0].output_name(), "ratio");

    // Checked last, after the column assertions, as in the original ordering.
    assert_eq!(query.group_by, Some(vec!["token".into()]));
}
1808
1809    // ── Issue #43: Subquery with alias and WHERE ──
1810
/// A FROM-subquery followed by a bare alias (`... ) sub`) must parse: the
/// inner query reads from `t` and the outer query has the single column `x`.
#[test]
fn test_subquery_with_alias() {
    let Statement::Select(outer) = parse_query("SELECT x FROM (SELECT x FROM t) sub").unwrap()
    else {
        panic!("Expected Select");
    };
    assert!(outer.subquery.is_some());
    let inner = outer.subquery.unwrap();
    assert_eq!(inner.table, "t");

    let ColumnList::Named(exprs) = &outer.columns else {
        panic!("Expected Named columns");
    };
    assert_eq!(exprs.len(), 1);
    assert_eq!(exprs[0].output_name(), "x");
}
1830
/// A FROM-subquery with its own WHERE: the WHERE stays on the inner query
/// while the trailing LIMIT belongs to the outer query.
#[test]
fn test_subquery_with_where() {
    let sql = "SELECT x FROM (SELECT x FROM t WHERE y > 0) LIMIT 5";
    let Statement::Select(outer) = parse_query(sql).unwrap() else {
        panic!("Expected Select");
    };
    assert!(outer.subquery.is_some());
    assert_eq!(outer.limit, Some(5));

    let inner = outer.subquery.unwrap();
    assert_eq!(inner.table, "t");
    assert!(inner.where_clause.is_some());
}
1846
1847    // ── Issue #42 + CREATE VIEW: aggregate subtraction in view ──
1848
/// `CREATE VIEW` over a SELECT with `SUM(sell) - SUM(buy) as net`: the view
/// name, GROUP BY, column count, alias and aggregate flag are all checked.
#[test]
fn test_create_view_aggregate_subtraction() {
    let sql = "CREATE VIEW v AS SELECT token, SUM(sell) - SUM(buy) as net FROM orders GROUP BY token";
    match parse_query(sql).unwrap() {
        Statement::CreateView(view) => {
            assert_eq!(view.view_name, "v");
            assert_eq!(view.query.group_by, Some(vec!["token".into()]));

            let ColumnList::Named(exprs) = &view.query.columns else {
                panic!("Expected Named columns");
            };
            assert_eq!(exprs.len(), 2);
            assert_eq!(exprs[1].output_name(), "net");
            assert!(exprs[1].is_aggregate());
        }
        other => panic!("Expected CreateView, got {:?}", other),
    }
}
1868}