Skip to main content

mdql_core/
query_parser.rs

//! Hand-written recursive descent parser for the MDQL SQL subset.
2
3use regex::Regex;
4use std::sync::LazyLock;
5
6use crate::errors::MdqlError;
7pub use crate::query_ast::*;
8
// ── Tokenizer ──────────────────────────────────────────────────────────────
10
/// Reserved words of the MDQL dialect, stored in upper case.
///
/// `tokenize` upper-cases each bare word and looks it up here: a hit becomes a
/// `keyword` token, everything else an `ident` token.
static KEYWORDS: &[&str] = &[
    // SELECT queries and predicates
    "SELECT", "FROM", "WHERE", "AND", "OR", "ORDER", "BY",
    "ASC", "DESC", "LIMIT", "LIKE", "IN", "IS", "NOT", "NULL",
    "JOIN", "ON", "AS", "GROUP", "HAVING",
    // Data modification
    "INSERT", "INTO", "VALUES", "UPDATE", "SET", "DELETE",
    // Schema changes
    "ALTER", "TABLE", "RENAME", "FIELD", "TO", "DROP", "MERGE", "FIELDS",
    // CASE expressions
    "CASE", "WHEN", "THEN", "ELSE", "END",
    // Date handling
    "INTERVAL", "DAY", "DAYS", "CURRENT_DATE", "CURRENT_TIMESTAMP", "DATEDIFF",
    // Views
    "CREATE", "VIEW", "CASCADE", "RESTRICT",
];
21
/// Upper-case names of the aggregate functions the parser recognises.
static AGG_FUNCS: &[&str] = &[
    "COUNT",
    "SUM",
    "AVG",
    "MIN",
    "MAX",
];
23
24static TOKEN_RE: LazyLock<Regex> = LazyLock::new(|| {
25    Regex::new(
26        r#"(?x)
27        \s*(?:
28            (?P<backtick>`[^`]+`)
29            | (?P<string>'(?:[^'\\]|\\.)*')
30            | (?P<number>\d+(?:\.\d+)?)
31            | (?P<op><=|>=|!=|[=<>,*()+\-/%])
32            | (?P<word>[A-Za-z_][A-Za-z0-9_./-]*)
33        )"#,
34    )
35    .unwrap()
36});
37
/// A single lexical unit produced by `tokenize`.
#[derive(Clone, Debug)]
struct Token {
    /// Token class: `"ident"`, `"string"`, `"number"`, `"op"` or `"keyword"`.
    token_type: String,
    /// Normalised text: quotes/backticks stripped, keywords upper-cased.
    value: String,
    /// The lexeme exactly as it appeared in the input (used in error messages).
    raw: String,
}
44
45fn tokenize(sql: &str) -> Vec<Token> {
46    let mut tokens = Vec::new();
47    for caps in TOKEN_RE.captures_iter(sql) {
48        if let Some(m) = caps.name("backtick") {
49            let raw = m.as_str();
50            tokens.push(Token {
51                token_type: "ident".into(),
52                value: raw[1..raw.len() - 1].into(),
53                raw: raw.into(),
54            });
55        } else if let Some(m) = caps.name("string") {
56            let raw = m.as_str();
57            tokens.push(Token {
58                token_type: "string".into(),
59                value: raw[1..raw.len() - 1].into(),
60                raw: raw.into(),
61            });
62        } else if let Some(m) = caps.name("number") {
63            let raw = m.as_str();
64            tokens.push(Token {
65                token_type: "number".into(),
66                value: raw.into(),
67                raw: raw.into(),
68            });
69        } else if let Some(m) = caps.name("op") {
70            let raw = m.as_str();
71            tokens.push(Token {
72                token_type: "op".into(),
73                value: raw.into(),
74                raw: raw.into(),
75            });
76        } else if let Some(m) = caps.name("word") {
77            let raw = m.as_str();
78            if KEYWORDS.contains(&raw.to_uppercase().as_str()) {
79                tokens.push(Token {
80                    token_type: "keyword".into(),
81                    value: raw.to_uppercase(),
82                    raw: raw.into(),
83                });
84            } else {
85                tokens.push(Token {
86                    token_type: "ident".into(),
87                    value: raw.into(),
88                    raw: raw.into(),
89                });
90            }
91        }
92    }
93    tokens
94}
95
// ── Parser ─────────────────────────────────────────────────────────────────
97
98struct Parser {
99    tokens: Vec<Token>,
100    pos: usize,
101}
102
103impl Parser {
104    fn new(tokens: Vec<Token>) -> Self {
105        Parser { tokens, pos: 0 }
106    }
107
108    fn peek(&self) -> Option<&Token> {
109        self.tokens.get(self.pos)
110    }
111
112    fn advance(&mut self) -> Token {
113        let t = self.tokens[self.pos].clone();
114        self.pos += 1;
115        t
116    }
117
118    fn expect(&mut self, type_: &str, value: Option<&str>) -> Result<Token, MdqlError> {
119        let t = self.peek().ok_or_else(|| {
120            MdqlError::QueryParse(format!(
121                "Unexpected end of query, expected {}",
122                value.unwrap_or(type_)
123            ))
124        })?;
125        let matches_type = t.token_type == type_;
126        let matches_value = value.map_or(true, |v| t.value == v);
127        if !matches_type || !matches_value {
128            return Err(MdqlError::QueryParse(format!(
129                "Expected {}, got '{}' at position {}",
130                value.unwrap_or(type_),
131                t.raw,
132                self.pos
133            )));
134        }
135        Ok(self.advance())
136    }
137
138    fn match_keyword(&mut self, kw: &str) -> bool {
139        if let Some(t) = self.peek() {
140            if t.token_type == "keyword" && t.value == kw {
141                self.advance();
142                return true;
143            }
144        }
145        false
146    }
147
148    fn parse_statement(&mut self) -> Result<Statement, MdqlError> {
149        let t = self.peek().ok_or_else(|| MdqlError::QueryParse("Empty query".into()))?;
150        match (t.token_type.as_str(), t.value.as_str()) {
151            ("keyword", "SELECT") => {
152                let q = self.parse_select()?;
153                self.expect_end()?;
154                Ok(Statement::Select(q))
155            }
156            ("keyword", "INSERT") => Ok(Statement::Insert(self.parse_insert()?)),
157            ("keyword", "UPDATE") => Ok(Statement::Update(self.parse_update()?)),
158            ("keyword", "DELETE") => Ok(Statement::Delete(self.parse_delete()?)),
159            ("keyword", "ALTER") => self.parse_alter(),
160            ("keyword", "CREATE") => self.parse_create_view(),
161            ("keyword", "DROP") => self.parse_drop_view(),
162            _ => Err(MdqlError::QueryParse(format!(
163                "Expected SELECT, INSERT, UPDATE, DELETE, ALTER, CREATE, or DROP, got '{}'",
164                t.raw
165            ))),
166        }
167    }
168
169    fn parse_select(&mut self) -> Result<SelectQuery, MdqlError> {
170        self.expect("keyword", Some("SELECT"))?;
171        let columns = self.parse_columns()?;
172        self.expect("keyword", Some("FROM"))?;
173
174        // Subquery: FROM (SELECT ...)
175        let mut subquery = None;
176        let (table, mut table_alias) = if self.peek().map_or(false, |t| t.token_type == "op" && t.value == "(") {
177            self.advance();
178            let inner = self.parse_select()?;
179            self.expect("op", Some(")"))?;
180            subquery = Some(Box::new(inner));
181            let alias = if let Some(t) = self.peek() {
182                if t.token_type == "ident" && !self.is_clause_keyword(t) {
183                    Some(self.advance().value)
184                } else {
185                    None
186                }
187            } else {
188                None
189            };
190            ("_subquery".to_string(), alias)
191        } else {
192            let t = self.parse_ident()?;
193            (t, None)
194        };
195
196        // Optional table alias (for non-subquery)
197        if subquery.is_none() {
198            if let Some(t) = self.peek() {
199                if t.token_type == "ident" && !self.is_clause_keyword(t) {
200                    table_alias = Some(self.advance().value);
201                }
202            }
203        }
204
205        // Optional JOIN(s)
206        let mut joins = Vec::new();
207        while self.match_keyword("JOIN") {
208            let join_table = self.parse_ident()?;
209            let mut join_alias = None;
210            if let Some(t) = self.peek() {
211                if t.token_type == "ident" && !self.is_clause_keyword(t) {
212                    join_alias = Some(self.advance().value);
213                }
214            }
215            self.expect("keyword", Some("ON"))?;
216            let left_col = self.parse_ident()?;
217            self.expect("op", Some("="))?;
218            let right_col = self.parse_ident()?;
219            joins.push(JoinClause {
220                table: join_table,
221                alias: join_alias,
222                left_col,
223                right_col,
224            });
225        }
226
227        let mut where_clause = None;
228        if self.match_keyword("WHERE") {
229            where_clause = Some(self.parse_or_expr()?);
230        }
231
232        let mut group_by = None;
233        if self.match_keyword("GROUP") {
234            self.expect("keyword", Some("BY"))?;
235            let mut cols = vec![self.parse_ident()?];
236            while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
237                self.advance();
238                cols.push(self.parse_ident()?);
239            }
240            group_by = Some(cols);
241        }
242
243        let mut having = None;
244        if self.match_keyword("HAVING") {
245            having = Some(self.parse_or_expr()?);
246        }
247
248        let mut order_by = None;
249        if self.match_keyword("ORDER") {
250            self.expect("keyword", Some("BY"))?;
251            order_by = Some(self.parse_order_by()?);
252        }
253
254        let mut limit = None;
255        if self.match_keyword("LIMIT") {
256            let t = self.expect("number", None)?;
257            limit = Some(t.value.parse::<i64>().map_err(|_| {
258                MdqlError::QueryParse(format!("Invalid LIMIT value: {}", t.value))
259            })?);
260        }
261
262        Ok(SelectQuery {
263            columns,
264            table,
265            table_alias,
266            subquery,
267            joins,
268            where_clause,
269            group_by,
270            having,
271            order_by,
272            limit,
273        })
274    }
275
276    fn parse_insert(&mut self) -> Result<InsertQuery, MdqlError> {
277        self.expect("keyword", Some("INSERT"))?;
278        self.expect("keyword", Some("INTO"))?;
279        let table = self.parse_ident()?;
280
281        self.expect("op", Some("("))?;
282        let mut columns = vec![self.parse_ident()?];
283        while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
284            self.advance();
285            columns.push(self.parse_ident()?);
286        }
287        self.expect("op", Some(")"))?;
288
289        self.expect("keyword", Some("VALUES"))?;
290
291        self.expect("op", Some("("))?;
292        let mut values = vec![self.parse_value()?];
293        while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
294            self.advance();
295            values.push(self.parse_value()?);
296        }
297        self.expect("op", Some(")"))?;
298
299        if columns.len() != values.len() {
300            return Err(MdqlError::QueryParse(format!(
301                "Column count ({}) does not match value count ({})",
302                columns.len(),
303                values.len()
304            )));
305        }
306
307        self.expect_end()?;
308        Ok(InsertQuery {
309            table,
310            columns,
311            values,
312        })
313    }
314
315    fn parse_update(&mut self) -> Result<UpdateQuery, MdqlError> {
316        self.expect("keyword", Some("UPDATE"))?;
317        let table = self.parse_ident()?;
318        self.expect("keyword", Some("SET"))?;
319
320        let mut assignments = Vec::new();
321        let col = self.parse_ident()?;
322        self.expect("op", Some("="))?;
323        let val = self.parse_value()?;
324        assignments.push((col, val));
325
326        while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
327            self.advance();
328            let col = self.parse_ident()?;
329            self.expect("op", Some("="))?;
330            let val = self.parse_value()?;
331            assignments.push((col, val));
332        }
333
334        let mut where_clause = None;
335        if self.match_keyword("WHERE") {
336            where_clause = Some(self.parse_or_expr()?);
337        }
338
339        self.expect_end()?;
340        Ok(UpdateQuery {
341            table,
342            assignments,
343            where_clause,
344        })
345    }
346
347    fn parse_delete(&mut self) -> Result<DeleteQuery, MdqlError> {
348        self.expect("keyword", Some("DELETE"))?;
349        self.expect("keyword", Some("FROM"))?;
350        let table = self.parse_ident()?;
351
352        let mut where_clause = None;
353        if self.match_keyword("WHERE") {
354            where_clause = Some(self.parse_or_expr()?);
355        }
356
357        self.expect_end()?;
358        Ok(DeleteQuery {
359            table,
360            where_clause,
361        })
362    }
363
364    fn parse_alter(&mut self) -> Result<Statement, MdqlError> {
365        self.expect("keyword", Some("ALTER"))?;
366        self.expect("keyword", Some("TABLE"))?;
367        let table = self.parse_ident()?;
368
369        let t = self.peek().ok_or_else(|| {
370            MdqlError::QueryParse("Expected RENAME, DROP, or MERGE after table name".into())
371        })?;
372
373        match (t.token_type.as_str(), t.value.as_str()) {
374            ("keyword", "RENAME") => {
375                self.advance();
376                self.expect("keyword", Some("FIELD"))?;
377                let old_name = self.parse_string_or_ident()?;
378                self.expect("keyword", Some("TO"))?;
379                let new_name = self.parse_string_or_ident()?;
380                self.expect_end()?;
381                Ok(Statement::AlterRename(AlterRenameFieldQuery {
382                    table,
383                    old_name,
384                    new_name,
385                }))
386            }
387            ("keyword", "DROP") => {
388                self.advance();
389                self.expect("keyword", Some("FIELD"))?;
390                let field_name = self.parse_string_or_ident()?;
391                self.expect_end()?;
392                Ok(Statement::AlterDrop(AlterDropFieldQuery {
393                    table,
394                    field_name,
395                }))
396            }
397            ("keyword", "MERGE") => {
398                self.advance();
399                self.expect("keyword", Some("FIELDS"))?;
400                let mut sources = vec![self.parse_string_or_ident()?];
401                while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
402                    self.advance();
403                    sources.push(self.parse_string_or_ident()?);
404                }
405                self.expect("keyword", Some("INTO"))?;
406                let target = self.parse_string_or_ident()?;
407                self.expect_end()?;
408                Ok(Statement::AlterMerge(AlterMergeFieldsQuery {
409                    table,
410                    sources,
411                    into: target,
412                }))
413            }
414            _ => Err(MdqlError::QueryParse(format!(
415                "Expected RENAME, DROP, or MERGE, got '{}'",
416                t.raw
417            ))),
418        }
419    }
420
421    fn parse_create_view(&mut self) -> Result<Statement, MdqlError> {
422        self.expect("keyword", Some("CREATE"))?;
423        self.expect("keyword", Some("VIEW"))?;
424        let view_name = self.parse_ident()?;
425
426        let columns = if self.peek().map_or(false, |t| t.token_type == "op" && t.value == "(") {
427            self.advance();
428            let mut cols = vec![self.parse_ident()?];
429            while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
430                self.advance();
431                cols.push(self.parse_ident()?);
432            }
433            self.expect("op", Some(")"))?;
434            Some(cols)
435        } else {
436            None
437        };
438
439        self.expect("keyword", Some("AS"))?;
440        let query = Box::new(self.parse_select()?);
441        self.expect_end()?;
442
443        Ok(Statement::CreateView(CreateViewQuery {
444            view_name,
445            columns,
446            query,
447        }))
448    }
449
450    fn parse_drop_view(&mut self) -> Result<Statement, MdqlError> {
451        self.expect("keyword", Some("DROP"))?;
452        self.expect("keyword", Some("VIEW"))?;
453        let view_name = self.parse_ident()?;
454        self.expect_end()?;
455        Ok(Statement::DropView(DropViewQuery { view_name }))
456    }
457
458    fn parse_string_or_ident(&mut self) -> Result<String, MdqlError> {
459        let t = self.peek().ok_or_else(|| {
460            MdqlError::QueryParse("Expected field name, got end of query".into())
461        })?;
462        match t.token_type.as_str() {
463            "string" => {
464                let v = self.advance().value;
465                Ok(v)
466            }
467            "ident" | "keyword" => {
468                let v = self.advance().value;
469                Ok(v)
470            }
471            _ => Err(MdqlError::QueryParse(format!(
472                "Expected field name, got '{}'",
473                t.raw
474            ))),
475        }
476    }
477
478    fn parse_columns(&mut self) -> Result<ColumnList, MdqlError> {
479        if let Some(t) = self.peek() {
480            if t.token_type == "op" && t.value == "*" {
481                self.advance();
482                return Ok(ColumnList::All);
483            }
484        }
485
486        let mut exprs = vec![self.parse_select_expr()?];
487        while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
488            self.advance();
489            exprs.push(self.parse_select_expr()?);
490        }
491        Ok(ColumnList::Named(exprs))
492    }
493
494    fn peek_is_agg_func(&self) -> bool {
495        let t = match self.peek() {
496            Some(t) => t,
497            None => return false,
498        };
499        let name_upper = t.value.to_uppercase();
500        if !AGG_FUNCS.contains(&name_upper.as_str()) {
501            return false;
502        }
503        // Only treat as aggregate if followed by (
504        self.tokens
505            .get(self.pos + 1)
506            .map_or(false, |next| next.token_type == "op" && next.value == "(")
507    }
508
509    fn parse_select_expr(&mut self) -> Result<SelectExpr, MdqlError> {
510        let _t = self.peek().ok_or_else(|| {
511            MdqlError::QueryParse("Expected column or aggregate, got end of query".into())
512        })?;
513
514        let expr = self.parse_additive()?;
515
516        let alias = if self.match_keyword("AS") {
517            Some(self.parse_ident()?)
518        } else if self.peek().map_or(false, |t| {
519            t.token_type == "ident" && !self.is_clause_keyword(t)
520        }) {
521            Some(self.advance().value)
522        } else {
523            None
524        };
525
526        // Bare aggregate → SelectExpr::Aggregate for backward compat
527        if let Expr::Aggregate { func, arg, arg_expr } = expr {
528            return Ok(SelectExpr::Aggregate {
529                func,
530                arg,
531                arg_expr: arg_expr.map(|e| *e),
532                alias,
533            });
534        }
535
536        if alias.is_none() {
537            if let Expr::Column(name) = &expr {
538                return Ok(SelectExpr::Column(name.clone()));
539            }
540        }
541
542        Ok(SelectExpr::Expr { expr, alias })
543    }
544
    // ── Expression parser (precedence climbing) ───────────────────────
546
547    fn peek_is_additive_op(&self) -> bool {
548        self.peek().map_or(false, |t| {
549            t.token_type == "op" && (t.value == "+" || t.value == "-")
550        })
551    }
552
553    fn peek_is_multiplicative_op(&self) -> bool {
554        self.peek().map_or(false, |t| {
555            t.token_type == "op" && (t.value == "*" || t.value == "/" || t.value == "%")
556        })
557    }
558
559    fn parse_additive(&mut self) -> Result<Expr, MdqlError> {
560        let mut left = self.parse_multiplicative()?;
561        while self.peek_is_additive_op() {
562            let op_tok = self.advance();
563            let is_sub = op_tok.value == "-";
564
565            // Check for INTERVAL keyword: expr +/- INTERVAL n DAY
566            if self.peek().map_or(false, |t| t.token_type == "keyword" && t.value == "INTERVAL") {
567                self.advance(); // consume INTERVAL
568                let days_expr = self.parse_multiplicative()?;
569                // Expect DAY or DAYS
570                if !self.match_keyword("DAY") && !self.match_keyword("DAYS") {
571                    return Err(MdqlError::QueryParse("Expected DAY after INTERVAL value".into()));
572                }
573                let days = if is_sub {
574                    Expr::UnaryMinus(Box::new(days_expr))
575                } else {
576                    days_expr
577                };
578                left = Expr::DateAdd {
579                    date: Box::new(left),
580                    days: Box::new(days),
581                };
582                continue;
583            }
584
585            let op = match op_tok.value.as_str() {
586                "+" => ArithOp::Add,
587                "-" => ArithOp::Sub,
588                _ => unreachable!(),
589            };
590            let right = self.parse_multiplicative()?;
591            left = Expr::BinaryOp {
592                left: Box::new(left),
593                op,
594                right: Box::new(right),
595            };
596        }
597        Ok(left)
598    }
599
600    fn parse_multiplicative(&mut self) -> Result<Expr, MdqlError> {
601        let mut left = self.parse_unary()?;
602        while self.peek_is_multiplicative_op() {
603            let op_tok = self.advance();
604            let op = match op_tok.value.as_str() {
605                "*" => ArithOp::Mul,
606                "/" => ArithOp::Div,
607                "%" => ArithOp::Mod,
608                _ => unreachable!(),
609            };
610            let right = self.parse_unary()?;
611            left = Expr::BinaryOp {
612                left: Box::new(left),
613                op,
614                right: Box::new(right),
615            };
616        }
617        Ok(left)
618    }
619
620    fn parse_unary(&mut self) -> Result<Expr, MdqlError> {
621        if self.peek().map_or(false, |t| t.token_type == "op" && t.value == "-") {
622            self.advance();
623            let inner = self.parse_atom()?;
624            // Fold unary minus on literals
625            match inner {
626                Expr::Literal(SqlValue::Int(n)) => Ok(Expr::Literal(SqlValue::Int(-n))),
627                Expr::Literal(SqlValue::Float(f)) => Ok(Expr::Literal(SqlValue::Float(-f))),
628                _ => Ok(Expr::UnaryMinus(Box::new(inner))),
629            }
630        } else {
631            self.parse_atom()
632        }
633    }
634
635    fn parse_atom(&mut self) -> Result<Expr, MdqlError> {
636        if self.peek_is_agg_func() {
637            return self.parse_agg_expr();
638        }
639
640        let t = self.peek().ok_or_else(|| {
641            MdqlError::QueryParse("Expected expression, got end of query".into())
642        })?;
643
644        match t.token_type.as_str() {
645            "number" => {
646                let v = self.advance().value;
647                if v.contains('.') {
648                    let f: f64 = v.parse().map_err(|_| {
649                        MdqlError::QueryParse(format!("Invalid float: {}", v))
650                    })?;
651                    Ok(Expr::Literal(SqlValue::Float(f)))
652                } else {
653                    let n: i64 = v.parse().map_err(|_| {
654                        MdqlError::QueryParse(format!("Invalid int: {}", v))
655                    })?;
656                    Ok(Expr::Literal(SqlValue::Int(n)))
657                }
658            }
659            "string" => {
660                let v = self.advance().value;
661                Ok(Expr::Literal(SqlValue::String(v)))
662            }
663            "keyword" if t.value == "NULL" => {
664                self.advance();
665                Ok(Expr::Literal(SqlValue::Null))
666            }
667            "keyword" if t.value == "CASE" => {
668                self.parse_case_expr()
669            }
670            "keyword" if t.value == "CURRENT_DATE" => {
671                self.advance();
672                Ok(Expr::CurrentDate)
673            }
674            "keyword" if t.value == "CURRENT_TIMESTAMP" => {
675                self.advance();
676                Ok(Expr::CurrentTimestamp)
677            }
678            "keyword" if t.value == "DATEDIFF" => {
679                self.advance();
680                self.expect("op", Some("("))?;
681                let left = self.parse_additive()?;
682                self.expect("op", Some(","))?;
683                let right = self.parse_additive()?;
684                self.expect("op", Some(")"))?;
685                Ok(Expr::DateDiff { left: Box::new(left), right: Box::new(right) })
686            }
687            "op" if t.value == "(" => {
688                self.advance();
689                let expr = self.parse_additive()?;
690                self.expect("op", Some(")"))?;
691                Ok(expr)
692            }
693            "ident" => {
694                let name = self.advance().value;
695                Ok(Expr::Column(name))
696            }
697            "keyword" if !Self::is_reserved_keyword(&t.value) => {
698                let name = self.advance().value;
699                Ok(Expr::Column(name))
700            }
701            _ => Err(MdqlError::QueryParse(format!(
702                "Expected expression, got '{}'",
703                t.raw
704            ))),
705        }
706    }
707
708    fn parse_case_expr(&mut self) -> Result<Expr, MdqlError> {
709        self.expect("keyword", Some("CASE"))?;
710        let mut whens = Vec::new();
711        while self.match_keyword("WHEN") {
712            let condition = self.parse_or_expr()?;
713            self.expect("keyword", Some("THEN"))?;
714            let result = self.parse_additive()?;
715            whens.push((condition, Box::new(result)));
716        }
717        if whens.is_empty() {
718            return Err(MdqlError::QueryParse("CASE requires at least one WHEN clause".into()));
719        }
720        let else_expr = if self.match_keyword("ELSE") {
721            Some(Box::new(self.parse_additive()?))
722        } else {
723            None
724        };
725        self.expect("keyword", Some("END"))?;
726        Ok(Expr::Case { whens, else_expr })
727    }
728
729    fn parse_agg_expr(&mut self) -> Result<Expr, MdqlError> {
730        let func_name = self.advance().value.to_uppercase();
731        let func = match func_name.as_str() {
732            "COUNT" => AggFunc::Count,
733            "SUM" => AggFunc::Sum,
734            "AVG" => AggFunc::Avg,
735            "MIN" => AggFunc::Min,
736            "MAX" => AggFunc::Max,
737            _ => unreachable!(),
738        };
739        self.expect("op", Some("("))?;
740        let (arg, arg_expr) = if self.peek().map_or(false, |t| t.token_type == "op" && t.value == "*") {
741            self.advance();
742            ("*".to_string(), None)
743        } else {
744            let expr = self.parse_additive()?;
745            if let Expr::Column(name) = &expr {
746                (name.clone(), None)
747            } else {
748                (expr.display_name(), Some(Box::new(expr)))
749            }
750        };
751        self.expect("op", Some(")"))?;
752        Ok(Expr::Aggregate { func, arg, arg_expr })
753    }
754
755    fn parse_ident(&mut self) -> Result<String, MdqlError> {
756        let t = self.peek().ok_or_else(|| {
757            MdqlError::QueryParse("Expected identifier, got end of query".into())
758        })?;
759        match t.token_type.as_str() {
760            "ident" | "keyword" => {
761                let v = self.advance().value;
762                Ok(v)
763            }
764            _ => Err(MdqlError::QueryParse(format!(
765                "Expected identifier, got '{}'",
766                t.raw
767            ))),
768        }
769    }
770
771    fn parse_or_expr(&mut self) -> Result<WhereClause, MdqlError> {
772        let mut left = self.parse_and_expr()?;
773        while self.match_keyword("OR") {
774            let right = self.parse_and_expr()?;
775            left = WhereClause::BoolOp(BoolOp {
776                op: BoolOpKind::Or,
777                left: Box::new(left),
778                right: Box::new(right),
779            });
780        }
781        Ok(left)
782    }
783
784    fn parse_and_expr(&mut self) -> Result<WhereClause, MdqlError> {
785        let mut left = self.parse_comparison()?;
786        while self.match_keyword("AND") {
787            let right = self.parse_comparison()?;
788            left = WhereClause::BoolOp(BoolOp {
789                op: BoolOpKind::And,
790                left: Box::new(left),
791                right: Box::new(right),
792            });
793        }
794        Ok(left)
795    }
796
797    fn parse_comparison(&mut self) -> Result<WhereClause, MdqlError> {
798        // Handle parenthesized boolean expressions
799        if self.peek().map_or(false, |t| t.token_type == "op" && t.value == "(") {
800            // Save position — might be arithmetic parens, not boolean
801            let saved_pos = self.pos;
802            self.advance();
803            // Try parsing as boolean (OR/AND) expression
804            let result = self.parse_or_expr();
805            if result.is_ok() && self.peek().map_or(false, |t| t.token_type == "op" && t.value == ")") {
806                self.advance();
807                return result;
808            }
809            // Not a boolean paren — rewind and parse as arithmetic expression
810            self.pos = saved_pos;
811        }
812
813        // Parse the left side as a full expression (column, literal, or arithmetic)
814        let left_expr = self.parse_additive()?;
815
816        // Extract column name for backward compat (simple column on left side)
817        let col = left_expr.as_column().unwrap_or("").to_string();
818
819        // IS NULL / IS NOT NULL (only valid with simple column)
820        if self.match_keyword("IS") {
821            if self.match_keyword("NOT") {
822                self.expect("keyword", Some("NULL"))?;
823                return Ok(WhereClause::Comparison(Comparison {
824                    column: col,
825                    op: CmpOp::IsNotNull,
826                    value: None,
827                    left_expr: Some(left_expr),
828                    right_expr: None,
829                }));
830            }
831            self.expect("keyword", Some("NULL"))?;
832            return Ok(WhereClause::Comparison(Comparison {
833                column: col,
834                op: CmpOp::IsNull,
835                value: None,
836                left_expr: Some(left_expr),
837                right_expr: None,
838            }));
839        }
840
841        // IN (val, val, ...)
842        if self.match_keyword("IN") {
843            self.expect("op", Some("("))?;
844            let mut values = vec![self.parse_value()?];
845            while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
846                self.advance();
847                values.push(self.parse_value()?);
848            }
849            self.expect("op", Some(")"))?;
850            return Ok(WhereClause::Comparison(Comparison {
851                column: col,
852                op: CmpOp::In,
853                value: Some(SqlValue::List(values)),
854                left_expr: Some(left_expr),
855                right_expr: None,
856            }));
857        }
858
859        // LIKE
860        if self.match_keyword("LIKE") {
861            let val = self.parse_value()?;
862            return Ok(WhereClause::Comparison(Comparison {
863                column: col,
864                op: CmpOp::Like,
865                value: Some(val),
866                left_expr: Some(left_expr),
867                right_expr: None,
868            }));
869        }
870
871        // NOT LIKE
872        if self.match_keyword("NOT") {
873            if self.match_keyword("LIKE") {
874                let val = self.parse_value()?;
875                return Ok(WhereClause::Comparison(Comparison {
876                    column: col,
877                    op: CmpOp::NotLike,
878                    value: Some(val),
879                    left_expr: Some(left_expr),
880                    right_expr: None,
881                }));
882            }
883            return Err(MdqlError::QueryParse("Expected LIKE after NOT".into()));
884        }
885
886        // Standard comparison operators
887        if let Some(t) = self.peek() {
888            if t.token_type == "op" && ["=", "!=", "<", ">", "<=", ">="].contains(&t.value.as_str())
889            {
890                let op_str = self.advance().value;
891                let op = match op_str.as_str() {
892                    "=" => CmpOp::Eq,
893                    "!=" => CmpOp::Ne,
894                    "<" => CmpOp::Lt,
895                    ">" => CmpOp::Gt,
896                    "<=" => CmpOp::Le,
897                    ">=" => CmpOp::Ge,
898                    _ => unreachable!(),
899                };
900                // Parse right side as expression
901                let right_expr = self.parse_additive()?;
902                // Extract SqlValue for backward compat (simple literal on right side)
903                let value = match &right_expr {
904                    Expr::Literal(v) => Some(v.clone()),
905                    _ => None,
906                };
907                return Ok(WhereClause::Comparison(Comparison {
908                    column: col,
909                    op,
910                    value,
911                    left_expr: Some(left_expr),
912                    right_expr: Some(right_expr),
913                }));
914            }
915        }
916
917        let got = self.peek().map_or("end".to_string(), |t| t.raw.clone());
918        Err(MdqlError::QueryParse(format!(
919            "Expected operator after '{}', got '{}'",
920            left_expr.display_name(), got
921        )))
922    }
923
924    fn parse_value(&mut self) -> Result<SqlValue, MdqlError> {
925        let t = self.peek().ok_or_else(|| {
926            MdqlError::QueryParse("Expected value, got end of query".into())
927        })?;
928        match t.token_type.as_str() {
929            "string" => {
930                let v = self.advance().value;
931                Ok(SqlValue::String(v))
932            }
933            "number" => {
934                let v = self.advance().value;
935                if v.contains('.') {
936                    Ok(SqlValue::Float(v.parse().map_err(|_| {
937                        MdqlError::QueryParse(format!("Invalid float: {}", v))
938                    })?))
939                } else {
940                    Ok(SqlValue::Int(v.parse().map_err(|_| {
941                        MdqlError::QueryParse(format!("Invalid int: {}", v))
942                    })?))
943                }
944            }
945            "keyword" if t.value == "NULL" => {
946                self.advance();
947                Ok(SqlValue::Null)
948            }
949            _ => Err(MdqlError::QueryParse(format!(
950                "Expected value, got '{}'",
951                t.raw
952            ))),
953        }
954    }
955
956    fn parse_order_by(&mut self) -> Result<Vec<OrderSpec>, MdqlError> {
957        let mut specs = vec![self.parse_order_spec()?];
958        while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
959            self.advance();
960            specs.push(self.parse_order_spec()?);
961        }
962        Ok(specs)
963    }
964
965    fn parse_order_spec(&mut self) -> Result<OrderSpec, MdqlError> {
966        let expr = self.parse_additive()?;
967        let col = expr.as_column().unwrap_or("").to_string();
968        let descending = if self.match_keyword("DESC") {
969            true
970        } else {
971            self.match_keyword("ASC");
972            false
973        };
974        Ok(OrderSpec {
975            column: col,
976            expr: Some(expr),
977            descending,
978        })
979    }
980
981    fn is_clause_keyword(&self, t: &Token) -> bool {
982        t.token_type == "keyword"
983            && ["WHERE", "ORDER", "LIMIT", "JOIN", "ON", "GROUP"].contains(&t.value.as_str())
984    }
985
986    /// Keywords that should never be consumed as column names inside expressions.
987    fn is_reserved_keyword(kw: &str) -> bool {
988        matches!(kw,
989            "AS" | "FROM" | "WHERE" | "AND" | "OR" | "ORDER" | "BY"
990            | "ASC" | "DESC" | "LIMIT" | "JOIN" | "ON" | "GROUP"
991            | "SELECT" | "INSERT" | "INTO" | "VALUES" | "UPDATE" | "SET"
992            | "DELETE" | "ALTER" | "TABLE" | "IS" | "NOT" | "IN" | "LIKE"
993            | "RENAME" | "FIELD" | "TO" | "DROP" | "MERGE" | "FIELDS"
994            | "CASE" | "WHEN" | "THEN" | "ELSE" | "END"
995            | "HAVING" | "INTERVAL" | "DAY" | "DAYS"
996            | "CURRENT_DATE" | "CURRENT_TIMESTAMP" | "DATEDIFF"
997            | "CREATE" | "VIEW" | "CASCADE" | "RESTRICT"
998        )
999    }
1000
1001    fn expect_end(&self) -> Result<(), MdqlError> {
1002        if let Some(t) = self.peek() {
1003            return Err(MdqlError::QueryParse(format!(
1004                "Unexpected token '{}' at position {}",
1005                t.raw, self.pos
1006            )));
1007        }
1008        Ok(())
1009    }
1010}
1011
1012pub fn parse_query(sql: &str) -> crate::errors::Result<Statement> {
1013    let tokens = tokenize(sql);
1014    if tokens.is_empty() {
1015        return Err(MdqlError::QueryParse("Empty query".into()));
1016    }
1017    let mut parser = Parser::new(tokens);
1018    parser.parse_statement()
1019}
1020
1021#[cfg(test)]
1022mod tests {
1023    use super::*;
1024
1025    #[test]
1026    fn test_simple_select() {
1027        let stmt = parse_query("SELECT title, status FROM strategies").unwrap();
1028        if let Statement::Select(q) = stmt {
1029            assert_eq!(q.columns, ColumnList::Named(vec![SelectExpr::Column("title".into()), SelectExpr::Column("status".into())]));
1030            assert_eq!(q.table, "strategies");
1031        } else {
1032            panic!("Expected Select");
1033        }
1034    }
1035
1036    #[test]
1037    fn test_select_star() {
1038        let stmt = parse_query("SELECT * FROM test").unwrap();
1039        if let Statement::Select(q) = stmt {
1040            assert_eq!(q.columns, ColumnList::All);
1041        } else {
1042            panic!("Expected Select");
1043        }
1044    }
1045
1046    #[test]
1047    fn test_where_clause() {
1048        let stmt = parse_query("SELECT title FROM test WHERE count > 5").unwrap();
1049        if let Statement::Select(q) = stmt {
1050            assert!(q.where_clause.is_some());
1051        } else {
1052            panic!("Expected Select");
1053        }
1054    }
1055
1056    #[test]
1057    fn test_order_by() {
1058        let stmt =
1059            parse_query("SELECT title FROM test ORDER BY composite DESC, title ASC").unwrap();
1060        if let Statement::Select(q) = stmt {
1061            let ob = q.order_by.unwrap();
1062            assert_eq!(ob.len(), 2);
1063            assert!(ob[0].descending);
1064            assert!(!ob[1].descending);
1065        } else {
1066            panic!("Expected Select");
1067        }
1068    }
1069
1070    #[test]
1071    fn test_limit() {
1072        let stmt = parse_query("SELECT * FROM test LIMIT 10").unwrap();
1073        if let Statement::Select(q) = stmt {
1074            assert_eq!(q.limit, Some(10));
1075        } else {
1076            panic!("Expected Select");
1077        }
1078    }
1079
1080    #[test]
1081    fn test_insert() {
1082        let stmt = parse_query(
1083            "INSERT INTO test (title, count) VALUES ('Hello', 42)",
1084        )
1085        .unwrap();
1086        if let Statement::Insert(q) = stmt {
1087            assert_eq!(q.table, "test");
1088            assert_eq!(q.columns, vec!["title", "count"]);
1089            assert_eq!(q.values[0], SqlValue::String("Hello".into()));
1090            assert_eq!(q.values[1], SqlValue::Int(42));
1091        } else {
1092            panic!("Expected Insert");
1093        }
1094    }
1095
1096    #[test]
1097    fn test_update() {
1098        let stmt = parse_query("UPDATE test SET status = 'KILLED' WHERE path = 'a.md'").unwrap();
1099        if let Statement::Update(q) = stmt {
1100            assert_eq!(q.table, "test");
1101            assert_eq!(q.assignments.len(), 1);
1102            assert!(q.where_clause.is_some());
1103        } else {
1104            panic!("Expected Update");
1105        }
1106    }
1107
1108    #[test]
1109    fn test_delete() {
1110        let stmt = parse_query("DELETE FROM test WHERE status = 'draft'").unwrap();
1111        if let Statement::Delete(q) = stmt {
1112            assert_eq!(q.table, "test");
1113            assert!(q.where_clause.is_some());
1114        } else {
1115            panic!("Expected Delete");
1116        }
1117    }
1118
1119    #[test]
1120    fn test_alter_rename() {
1121        let stmt =
1122            parse_query("ALTER TABLE test RENAME FIELD 'Summary' TO 'Overview'").unwrap();
1123        if let Statement::AlterRename(q) = stmt {
1124            assert_eq!(q.old_name, "Summary");
1125            assert_eq!(q.new_name, "Overview");
1126        } else {
1127            panic!("Expected AlterRename");
1128        }
1129    }
1130
1131    #[test]
1132    fn test_alter_drop() {
1133        let stmt = parse_query("ALTER TABLE test DROP FIELD 'Details'").unwrap();
1134        if let Statement::AlterDrop(q) = stmt {
1135            assert_eq!(q.field_name, "Details");
1136        } else {
1137            panic!("Expected AlterDrop");
1138        }
1139    }
1140
1141    #[test]
1142    fn test_alter_merge() {
1143        let stmt = parse_query(
1144            "ALTER TABLE test MERGE FIELDS 'Entry Rules', 'Exit Rules' INTO 'Trading Rules'",
1145        )
1146        .unwrap();
1147        if let Statement::AlterMerge(q) = stmt {
1148            assert_eq!(q.sources, vec!["Entry Rules", "Exit Rules"]);
1149            assert_eq!(q.into, "Trading Rules");
1150        } else {
1151            panic!("Expected AlterMerge");
1152        }
1153    }
1154
1155    #[test]
1156    fn test_backtick_ident() {
1157        let stmt = parse_query("SELECT `Structural Mechanism` FROM test").unwrap();
1158        if let Statement::Select(q) = stmt {
1159            assert_eq!(
1160                q.columns,
1161                ColumnList::Named(vec![SelectExpr::Column("Structural Mechanism".into())])
1162            );
1163        } else {
1164            panic!("Expected Select");
1165        }
1166    }
1167
1168    #[test]
1169    fn test_like_operator() {
1170        let stmt = parse_query("SELECT title FROM test WHERE categories LIKE '%defi%'").unwrap();
1171        if let Statement::Select(q) = stmt {
1172            if let Some(WhereClause::Comparison(c)) = q.where_clause {
1173                assert_eq!(c.op, CmpOp::Like);
1174                assert_eq!(c.value, Some(SqlValue::String("%defi%".into())));
1175            } else {
1176                panic!("Expected LIKE comparison");
1177            }
1178        } else {
1179            panic!("Expected Select");
1180        }
1181    }
1182
1183    #[test]
1184    fn test_in_operator() {
1185        let stmt =
1186            parse_query("SELECT * FROM test WHERE status IN ('ACTIVE', 'LIVE')").unwrap();
1187        if let Statement::Select(q) = stmt {
1188            if let Some(WhereClause::Comparison(c)) = q.where_clause {
1189                assert_eq!(c.op, CmpOp::In);
1190            } else {
1191                panic!("Expected IN comparison");
1192            }
1193        } else {
1194            panic!("Expected Select");
1195        }
1196    }
1197
1198    #[test]
1199    fn test_is_null() {
1200        let stmt = parse_query("SELECT * FROM test WHERE title IS NULL").unwrap();
1201        if let Statement::Select(q) = stmt {
1202            if let Some(WhereClause::Comparison(c)) = q.where_clause {
1203                assert_eq!(c.op, CmpOp::IsNull);
1204            } else {
1205                panic!("Expected IS NULL comparison");
1206            }
1207        } else {
1208            panic!("Expected Select");
1209        }
1210    }
1211
1212    #[test]
1213    fn test_and_or() {
1214        let stmt = parse_query(
1215            "SELECT * FROM test WHERE status = 'ACTIVE' AND count > 5 OR title LIKE '%test%'",
1216        )
1217        .unwrap();
1218        if let Statement::Select(q) = stmt {
1219            assert!(q.where_clause.is_some());
1220        } else {
1221            panic!("Expected Select");
1222        }
1223    }
1224
1225    #[test]
1226    fn test_join() {
1227        let stmt = parse_query(
1228            "SELECT s.title, b.sharpe FROM strategies s JOIN backtests b ON b.strategy = s.path",
1229        )
1230        .unwrap();
1231        if let Statement::Select(q) = stmt {
1232            assert_eq!(q.table, "strategies");
1233            assert_eq!(q.table_alias, Some("s".into()));
1234            assert_eq!(q.joins.len(), 1);
1235            let join = &q.joins[0];
1236            assert_eq!(join.table, "backtests");
1237            assert_eq!(join.alias, Some("b".into()));
1238        } else {
1239            panic!("Expected Select");
1240        }
1241    }
1242
1243    #[test]
1244    fn test_multi_join() {
1245        let stmt = parse_query(
1246            "SELECT s.title, b.sharpe, c.verdict FROM strategies s JOIN backtests b ON b.strategy = s.path JOIN critiques c ON c.strategy = s.path",
1247        )
1248        .unwrap();
1249        if let Statement::Select(q) = stmt {
1250            assert_eq!(q.table, "strategies");
1251            assert_eq!(q.table_alias, Some("s".into()));
1252            assert_eq!(q.joins.len(), 2);
1253            assert_eq!(q.joins[0].table, "backtests");
1254            assert_eq!(q.joins[0].alias, Some("b".into()));
1255            assert_eq!(q.joins[0].left_col, "b.strategy");
1256            assert_eq!(q.joins[0].right_col, "s.path");
1257            assert_eq!(q.joins[1].table, "critiques");
1258            assert_eq!(q.joins[1].alias, Some("c".into()));
1259            assert_eq!(q.joins[1].left_col, "c.strategy");
1260            assert_eq!(q.joins[1].right_col, "s.path");
1261        } else {
1262            panic!("Expected Select");
1263        }
1264    }
1265
1266    #[test]
1267    fn test_empty_query() {
1268        assert!(parse_query("").is_err());
1269    }
1270
1271    #[test]
1272    fn test_count_star() {
1273        let stmt = parse_query("SELECT status, COUNT(*) AS cnt FROM strategies GROUP BY status").unwrap();
1274        if let Statement::Select(q) = stmt {
1275            if let ColumnList::Named(exprs) = &q.columns {
1276                assert_eq!(exprs.len(), 2);
1277                assert_eq!(exprs[0], SelectExpr::Column("status".into()));
1278                assert!(matches!(&exprs[1], SelectExpr::Aggregate {
1279                    func: AggFunc::Count,
1280                    arg,
1281                    alias: Some(a),
1282                    ..
1283                } if arg == "*" && a == "cnt"));
1284            } else {
1285                panic!("Expected Named columns");
1286            }
1287            assert_eq!(q.group_by, Some(vec!["status".into()]));
1288        } else {
1289            panic!("Expected Select");
1290        }
1291    }
1292
1293    #[test]
1294    fn test_count_column_as_ident() {
1295        // "count" as a column name should NOT be parsed as the COUNT aggregate
1296        let stmt = parse_query("INSERT INTO test (title, count) VALUES ('Hello', 42)").unwrap();
1297        if let Statement::Insert(q) = stmt {
1298            assert_eq!(q.columns, vec!["title", "count"]);
1299        } else {
1300            panic!("Expected Insert");
1301        }
1302    }
1303
1304    #[test]
1305    fn test_multiple_aggregates() {
1306        let stmt = parse_query("SELECT MIN(composite), MAX(composite), AVG(composite) FROM strategies").unwrap();
1307        if let Statement::Select(q) = stmt {
1308            if let ColumnList::Named(exprs) = &q.columns {
1309                assert_eq!(exprs.len(), 3);
1310                assert!(matches!(&exprs[0], SelectExpr::Aggregate { func: AggFunc::Min, .. }));
1311                assert!(matches!(&exprs[1], SelectExpr::Aggregate { func: AggFunc::Max, .. }));
1312                assert!(matches!(&exprs[2], SelectExpr::Aggregate { func: AggFunc::Avg, .. }));
1313            } else {
1314                panic!("Expected Named columns");
1315            }
1316            assert_eq!(q.group_by, None);
1317        } else {
1318            panic!("Expected Select");
1319        }
1320    }
1321
1322    // ── Expression tests ──────────────────────────────────────────
1323
1324    #[test]
1325    fn test_select_arithmetic_expr() {
1326        let stmt = parse_query("SELECT a + b FROM test").unwrap();
1327        if let Statement::Select(q) = stmt {
1328            if let ColumnList::Named(exprs) = &q.columns {
1329                assert_eq!(exprs.len(), 1);
1330                assert!(matches!(&exprs[0], SelectExpr::Expr {
1331                    expr: Expr::BinaryOp { op: ArithOp::Add, .. },
1332                    alias: None,
1333                }));
1334            } else {
1335                panic!("Expected Named columns");
1336            }
1337        } else {
1338            panic!("Expected Select");
1339        }
1340    }
1341
1342    #[test]
1343    fn test_select_arithmetic_with_alias() {
1344        let stmt = parse_query("SELECT a + b AS total FROM test").unwrap();
1345        if let Statement::Select(q) = stmt {
1346            if let ColumnList::Named(exprs) = &q.columns {
1347                assert_eq!(exprs.len(), 1);
1348                assert!(matches!(&exprs[0], SelectExpr::Expr {
1349                    alias: Some(a),
1350                    ..
1351                } if a == "total"));
1352                assert_eq!(exprs[0].output_name(), "total");
1353            } else {
1354                panic!("Expected Named columns");
1355            }
1356        } else {
1357            panic!("Expected Select");
1358        }
1359    }
1360
1361    #[test]
1362    fn test_select_precedence() {
1363        // a + b * c should parse as a + (b * c)
1364        let stmt = parse_query("SELECT a + b * c FROM test").unwrap();
1365        if let Statement::Select(q) = stmt {
1366            if let ColumnList::Named(exprs) = &q.columns {
1367                if let SelectExpr::Expr { expr, .. } = &exprs[0] {
1368                    if let Expr::BinaryOp { left, op, right } = expr {
1369                        assert_eq!(*op, ArithOp::Add);
1370                        assert!(matches!(left.as_ref(), Expr::Column(n) if n == "a"));
1371                        assert!(matches!(right.as_ref(), Expr::BinaryOp { op: ArithOp::Mul, .. }));
1372                    } else {
1373                        panic!("Expected BinaryOp");
1374                    }
1375                } else {
1376                    panic!("Expected Expr variant");
1377                }
1378            } else {
1379                panic!("Expected Named columns");
1380            }
1381        } else {
1382            panic!("Expected Select");
1383        }
1384    }
1385
1386    #[test]
1387    fn test_select_parenthesized_expr() {
1388        // (a + b) * c should override default precedence
1389        let stmt = parse_query("SELECT (a + b) * c FROM test").unwrap();
1390        if let Statement::Select(q) = stmt {
1391            if let ColumnList::Named(exprs) = &q.columns {
1392                if let SelectExpr::Expr { expr, .. } = &exprs[0] {
1393                    if let Expr::BinaryOp { left, op, .. } = expr {
1394                        assert_eq!(*op, ArithOp::Mul);
1395                        assert!(matches!(left.as_ref(), Expr::BinaryOp { op: ArithOp::Add, .. }));
1396                    } else {
1397                        panic!("Expected BinaryOp");
1398                    }
1399                } else {
1400                    panic!("Expected Expr variant");
1401                }
1402            } else {
1403                panic!("Expected Named columns");
1404            }
1405        } else {
1406            panic!("Expected Select");
1407        }
1408    }
1409
1410    #[test]
1411    fn test_select_unary_minus() {
1412        let stmt = parse_query("SELECT -count FROM test").unwrap();
1413        if let Statement::Select(q) = stmt {
1414            if let ColumnList::Named(exprs) = &q.columns {
1415                assert!(matches!(&exprs[0], SelectExpr::Expr {
1416                    expr: Expr::UnaryMinus(_),
1417                    ..
1418                }));
1419            } else {
1420                panic!("Expected Named columns");
1421            }
1422        } else {
1423            panic!("Expected Select");
1424        }
1425    }
1426
1427    #[test]
1428    fn test_select_negative_literal() {
1429        let stmt = parse_query("SELECT -42 FROM test").unwrap();
1430        if let Statement::Select(q) = stmt {
1431            if let ColumnList::Named(exprs) = &q.columns {
1432                // Unary minus folds into the literal
1433                assert!(matches!(&exprs[0], SelectExpr::Expr {
1434                    expr: Expr::Literal(SqlValue::Int(-42)),
1435                    ..
1436                }));
1437            } else {
1438                panic!("Expected Named columns");
1439            }
1440        } else {
1441            panic!("Expected Select");
1442        }
1443    }
1444
1445    #[test]
1446    fn test_where_arithmetic_expr() {
1447        let stmt = parse_query("SELECT * FROM test WHERE a + b > 10").unwrap();
1448        if let Statement::Select(q) = stmt {
1449            if let Some(WhereClause::Comparison(c)) = q.where_clause {
1450                assert_eq!(c.op, CmpOp::Gt);
1451                assert!(matches!(&c.left_expr, Some(Expr::BinaryOp { op: ArithOp::Add, .. })));
1452                assert!(matches!(&c.right_expr, Some(Expr::Literal(SqlValue::Int(10)))));
1453            } else {
1454                panic!("Expected comparison");
1455            }
1456        } else {
1457            panic!("Expected Select");
1458        }
1459    }
1460
1461    #[test]
1462    fn test_where_both_sides_expr() {
1463        let stmt = parse_query("SELECT * FROM test WHERE a * 2 > b + 1").unwrap();
1464        if let Statement::Select(q) = stmt {
1465            if let Some(WhereClause::Comparison(c)) = q.where_clause {
1466                assert_eq!(c.op, CmpOp::Gt);
1467                assert!(matches!(&c.left_expr, Some(Expr::BinaryOp { op: ArithOp::Mul, .. })));
1468                assert!(matches!(&c.right_expr, Some(Expr::BinaryOp { op: ArithOp::Add, .. })));
1469            } else {
1470                panic!("Expected comparison");
1471            }
1472        } else {
1473            panic!("Expected Select");
1474        }
1475    }
1476
1477    #[test]
1478    fn test_order_by_expr() {
1479        let stmt = parse_query("SELECT * FROM test ORDER BY a + b DESC").unwrap();
1480        if let Statement::Select(q) = stmt {
1481            let ob = q.order_by.unwrap();
1482            assert_eq!(ob.len(), 1);
1483            assert!(ob[0].descending);
1484            assert!(matches!(&ob[0].expr, Some(Expr::BinaryOp { op: ArithOp::Add, .. })));
1485        } else {
1486            panic!("Expected Select");
1487        }
1488    }
1489
1490    #[test]
1491    fn test_all_arithmetic_ops() {
1492        let stmt = parse_query("SELECT a + b, a - b, a * b, a / b, a % b FROM test").unwrap();
1493        if let Statement::Select(q) = stmt {
1494            if let ColumnList::Named(exprs) = &q.columns {
1495                assert_eq!(exprs.len(), 5);
1496                assert!(matches!(&exprs[0], SelectExpr::Expr { expr: Expr::BinaryOp { op: ArithOp::Add, .. }, .. }));
1497                assert!(matches!(&exprs[1], SelectExpr::Expr { expr: Expr::BinaryOp { op: ArithOp::Sub, .. }, .. }));
1498                assert!(matches!(&exprs[2], SelectExpr::Expr { expr: Expr::BinaryOp { op: ArithOp::Mul, .. }, .. }));
1499                assert!(matches!(&exprs[3], SelectExpr::Expr { expr: Expr::BinaryOp { op: ArithOp::Div, .. }, .. }));
1500                assert!(matches!(&exprs[4], SelectExpr::Expr { expr: Expr::BinaryOp { op: ArithOp::Mod, .. }, .. }));
1501            } else {
1502                panic!("Expected Named columns");
1503            }
1504        } else {
1505            panic!("Expected Select");
1506        }
1507    }
1508
1509    #[test]
1510    fn test_column_with_literal_arithmetic() {
1511        let stmt = parse_query("SELECT count * 2 + 1 FROM test").unwrap();
1512        if let Statement::Select(q) = stmt {
1513            if let ColumnList::Named(exprs) = &q.columns {
1514                // Should be (count * 2) + 1
1515                if let SelectExpr::Expr { expr, .. } = &exprs[0] {
1516                    if let Expr::BinaryOp { left, op, right } = expr {
1517                        assert_eq!(*op, ArithOp::Add);
1518                        assert!(matches!(right.as_ref(), Expr::Literal(SqlValue::Int(1))));
1519                        assert!(matches!(left.as_ref(), Expr::BinaryOp { op: ArithOp::Mul, .. }));
1520                    } else {
1521                        panic!("Expected BinaryOp");
1522                    }
1523                } else {
1524                    panic!("Expected Expr");
1525                }
1526            } else {
1527                panic!("Expected Named columns");
1528            }
1529        } else {
1530            panic!("Expected Select");
1531        }
1532    }
1533
1534    #[test]
1535    fn test_mixed_columns_and_exprs() {
1536        let stmt = parse_query("SELECT title, a + b AS sum, count FROM test").unwrap();
1537        if let Statement::Select(q) = stmt {
1538            if let ColumnList::Named(exprs) = &q.columns {
1539                assert_eq!(exprs.len(), 3);
1540                assert_eq!(exprs[0], SelectExpr::Column("title".into()));
1541                assert!(matches!(&exprs[1], SelectExpr::Expr { alias: Some(a), .. } if a == "sum"));
1542                assert_eq!(exprs[2], SelectExpr::Column("count".into()));
1543            } else {
1544                panic!("Expected Named columns");
1545            }
1546        } else {
1547            panic!("Expected Select");
1548        }
1549    }
1550
1551    // ── CASE WHEN tests ──────────────────────────────────────────
1552
1553    #[test]
1554    fn test_case_when_basic() {
1555        let stmt = parse_query(
1556            "SELECT CASE WHEN status = 'ACTIVE' THEN 1 ELSE 0 END FROM test"
1557        ).unwrap();
1558        if let Statement::Select(q) = stmt {
1559            if let ColumnList::Named(exprs) = &q.columns {
1560                assert_eq!(exprs.len(), 1);
1561                assert!(matches!(&exprs[0], SelectExpr::Expr {
1562                    expr: Expr::Case { .. },
1563                    ..
1564                }));
1565            } else {
1566                panic!("Expected Named columns");
1567            }
1568        } else {
1569            panic!("Expected Select");
1570        }
1571    }
1572
1573    #[test]
1574    fn test_case_when_multiple_branches() {
1575        let stmt = parse_query(
1576            "SELECT CASE WHEN x > 10 THEN 'high' WHEN x > 5 THEN 'mid' ELSE 'low' END FROM test"
1577        ).unwrap();
1578        if let Statement::Select(q) = stmt {
1579            if let ColumnList::Named(exprs) = &q.columns {
1580                if let SelectExpr::Expr { expr: Expr::Case { whens, else_expr }, .. } = &exprs[0] {
1581                    assert_eq!(whens.len(), 2);
1582                    assert!(else_expr.is_some());
1583                } else {
1584                    panic!("Expected Case expression");
1585                }
1586            } else {
1587                panic!("Expected Named columns");
1588            }
1589        } else {
1590            panic!("Expected Select");
1591        }
1592    }
1593
1594    #[test]
1595    fn test_case_when_no_else() {
1596        let stmt = parse_query(
1597            "SELECT CASE WHEN x = 1 THEN 'one' END FROM test"
1598        ).unwrap();
1599        if let Statement::Select(q) = stmt {
1600            if let ColumnList::Named(exprs) = &q.columns {
1601                if let SelectExpr::Expr { expr: Expr::Case { whens, else_expr }, .. } = &exprs[0] {
1602                    assert_eq!(whens.len(), 1);
1603                    assert!(else_expr.is_none());
1604                } else {
1605                    panic!("Expected Case expression");
1606                }
1607            } else {
1608                panic!("Expected Named columns");
1609            }
1610        } else {
1611            panic!("Expected Select");
1612        }
1613    }
1614
1615    #[test]
1616    fn test_case_when_in_aggregate() {
1617        let stmt = parse_query(
1618            "SELECT SUM(CASE WHEN side = 'BUY' THEN size ELSE -size END) AS net FROM orders GROUP BY token"
1619        ).unwrap();
1620        if let Statement::Select(q) = stmt {
1621            if let ColumnList::Named(exprs) = &q.columns {
1622                assert_eq!(exprs.len(), 1);
1623                assert!(matches!(&exprs[0], SelectExpr::Aggregate {
1624                    func: AggFunc::Sum,
1625                    arg_expr: Some(Expr::Case { .. }),
1626                    alias: Some(a),
1627                    ..
1628                } if a == "net"));
1629            } else {
1630                panic!("Expected Named columns");
1631            }
1632        } else {
1633            panic!("Expected Select");
1634        }
1635    }
1636
1637    #[test]
1638    fn test_case_when_with_alias() {
1639        let stmt = parse_query(
1640            "SELECT CASE WHEN x > 0 THEN 'pos' ELSE 'neg' END AS sign FROM test"
1641        ).unwrap();
1642        if let Statement::Select(q) = stmt {
1643            if let ColumnList::Named(exprs) = &q.columns {
1644                assert!(matches!(&exprs[0], SelectExpr::Expr {
1645                    expr: Expr::Case { .. },
1646                    alias: Some(a),
1647                } if a == "sign"));
1648            } else {
1649                panic!("Expected Named columns");
1650            }
1651        } else {
1652            panic!("Expected Select");
1653        }
1654    }
1655
1656    #[test]
1657    fn test_create_view() {
1658        let stmt = parse_query("CREATE VIEW live AS SELECT * FROM strategies WHERE status = 'LIVE'").unwrap();
1659        if let Statement::CreateView(cv) = stmt {
1660            assert_eq!(cv.view_name, "live");
1661            assert!(cv.columns.is_none());
1662            assert_eq!(cv.query.table, "strategies");
1663            assert!(cv.query.where_clause.is_some());
1664        } else {
1665            panic!("Expected CreateView, got {:?}", stmt);
1666        }
1667    }
1668
/// `CREATE VIEW v1 (a, b) AS ...`: the parenthesised names must be captured
/// as the view's explicit column list, in order.
#[test]
fn test_create_view_with_columns() {
    let sql = "CREATE VIEW v1 (a, b) AS SELECT title, status FROM t";
    let Statement::CreateView(view) = parse_query(sql).unwrap() else {
        panic!("Expected CreateView");
    };
    assert_eq!(view.view_name, "v1");
    assert_eq!(view.columns, Some(vec!["a".into(), "b".into()]));
}
1679
/// `DROP VIEW <name>` must parse to a `DropView` statement naming the view.
#[test]
fn test_drop_view() {
    let parsed = parse_query("DROP VIEW live").unwrap();
    match parsed {
        Statement::DropView(drop) => assert_eq!(drop.view_name, "live"),
        other => panic!("Expected DropView, got {:?}", other),
    }
}
1689
/// Lower-case keywords must still be recognised, while the view name keeps
/// the exact casing the user wrote.
#[test]
fn test_create_view_case_insensitive() {
    let sql = "create view My_View as select * from t";
    let Statement::CreateView(view) = parse_query(sql).unwrap() else {
        panic!("Expected CreateView");
    };
    assert_eq!(view.view_name, "My_View");
}
1699
1700    // ── Issue #42: Arithmetic between aggregates in column expressions ──
1701
/// Division of two aggregates in the select list: the query must keep its
/// GROUP BY, expose two columns, and report the second one as an aggregate.
#[test]
fn test_aggregate_division() {
    let sql = "SELECT token, SUM(sell) / SUM(buy) as ratio FROM orders GROUP BY token";
    let Statement::Select(select) = parse_query(sql).unwrap() else {
        panic!("Expected Select");
    };
    assert_eq!(select.group_by, Some(vec!["token".into()]));

    let ColumnList::Named(cols) = &select.columns else {
        panic!("Expected Named columns");
    };
    assert_eq!(cols.len(), 2);
    assert!(cols[1].is_aggregate());
}
1719
/// Subtraction of two aggregates in the select list must parse, and the
/// resulting column must be named by its `as net` alias.
///
/// Every pattern match has an explicit failure branch: previously a
/// non-`Named` column list fell through without running any assertion, so
/// the test could pass while verifying nothing.
#[test]
fn test_aggregate_subtraction() {
    let stmt = parse_query(
        "SELECT token, SUM(sell) - SUM(buy) as net FROM orders GROUP BY token"
    ).unwrap();
    let Statement::Select(q) = stmt else {
        panic!("Expected Select");
    };
    let ColumnList::Named(exprs) = &q.columns else {
        panic!("Expected Named columns");
    };
    // Check the length first so a short list fails with a clear message
    // rather than an index-out-of-bounds panic on `exprs[1]`.
    assert_eq!(exprs.len(), 2);
    assert_eq!(exprs[1].output_name(), "net");
    assert!(exprs[1].is_aggregate());
}
1733
/// A view whose SELECT contains arithmetic between aggregates must still
/// parse as `CreateView` with the right name.
#[test]
fn test_create_view_with_arithmetic() {
    let sql = "CREATE VIEW positions AS SELECT token, SUM(sell) / SUM(buy) as ratio FROM orders GROUP BY token";
    match parse_query(sql).unwrap() {
        Statement::CreateView(view) => assert_eq!(view.view_name, "positions"),
        other => panic!("Expected CreateView, got {:?}", other),
    }
}
1745
1746    // ── Issue #43: Subqueries in FROM ──
1747
/// A parenthesised SELECT in FROM position: the outer query keeps its LIMIT
/// and holds the inner query, which in turn keeps its table and GROUP BY.
#[test]
fn test_subquery_in_from() {
    let sql = "SELECT token, sell_size FROM (SELECT token, SUM(size) as sell_size FROM orders GROUP BY token) LIMIT 5";
    let Statement::Select(outer) = parse_query(sql).unwrap() else {
        panic!("Expected Select");
    };
    assert!(outer.subquery.is_some());
    assert_eq!(outer.limit, Some(5));

    let inner = outer.subquery.unwrap();
    assert_eq!(inner.table, "orders");
    assert!(inner.group_by.is_some());
}
1763
1764    // ── Issue #44: HAVING in CREATE VIEW ──
1765
/// A HAVING clause inside a view definition must survive into the view's
/// stored query.
#[test]
fn test_create_view_with_having() {
    let sql = "CREATE VIEW positions AS SELECT token, SUM(sell) as sell_size, SUM(buy) as buy_size FROM orders GROUP BY token HAVING sell_size > buy_size";
    let view = match parse_query(sql).unwrap() {
        Statement::CreateView(view) => view,
        other => panic!("Expected CreateView, got {:?}", other),
    };
    assert_eq!(view.view_name, "positions");
    assert!(view.query.having.is_some());
}
1778
1779    // ── Issue #42: Aggregate multiplication ──
1780
/// An aggregate multiplied by a literal is a single select column that is
/// reported as an aggregate and named by its alias.
#[test]
fn test_aggregate_multiplication() {
    let sql = "SELECT SUM(a) * 2 as doubled FROM test";
    let Statement::Select(select) = parse_query(sql).unwrap() else {
        panic!("Expected Select");
    };
    let ColumnList::Named(cols) = &select.columns else {
        panic!("Expected Named columns");
    };
    assert_eq!(cols.len(), 1);
    assert!(cols[0].is_aggregate());
    assert_eq!(cols[0].output_name(), "doubled");
}
1798
/// Division of two `SUM(CASE ...)` aggregates: a single aggregate column
/// named `ratio`, with the trailing GROUP BY still attached to the query.
#[test]
fn test_complex_aggregate_arithmetic() {
    let sql = "SELECT SUM(CASE WHEN side = 'SELL' THEN size ELSE 0 END) / SUM(CASE WHEN side = 'BUY' THEN size ELSE 0 END) as ratio FROM orders GROUP BY token";
    let Statement::Select(select) = parse_query(sql).unwrap() else {
        panic!("Expected Select");
    };

    // Column checks run first; GROUP BY is verified afterwards.
    let ColumnList::Named(cols) = &select.columns else {
        panic!("Expected Named columns");
    };
    assert_eq!(cols.len(), 1);
    assert!(cols[0].is_aggregate());
    assert_eq!(cols[0].output_name(), "ratio");

    assert_eq!(select.group_by, Some(vec!["token".into()]));
}
1817
1818    // ── Issue #43: Subquery with alias and WHERE ──
1819
/// A FROM-subquery followed by a bare alias (`... ) sub`) must parse: the
/// inner query targets `t` and the outer select list has the single column
/// `x`. The alias value itself is not inspected here.
#[test]
fn test_subquery_with_alias() {
    let sql = "SELECT x FROM (SELECT x FROM t) sub";
    let Statement::Select(outer) = parse_query(sql).unwrap() else {
        panic!("Expected Select");
    };
    assert!(outer.subquery.is_some());
    let inner = outer.subquery.unwrap();
    assert_eq!(inner.table, "t");

    let ColumnList::Named(cols) = &outer.columns else {
        panic!("Expected Named columns");
    };
    assert_eq!(cols.len(), 1);
    assert_eq!(cols[0].output_name(), "x");
}
1839
/// A WHERE clause inside a FROM-subquery belongs to the inner query, while
/// the trailing LIMIT belongs to the outer one.
#[test]
fn test_subquery_with_where() {
    let sql = "SELECT x FROM (SELECT x FROM t WHERE y > 0) LIMIT 5";
    let Statement::Select(outer) = parse_query(sql).unwrap() else {
        panic!("Expected Select");
    };
    assert!(outer.subquery.is_some());
    assert_eq!(outer.limit, Some(5));

    let inner = outer.subquery.unwrap();
    assert_eq!(inner.table, "t");
    assert!(inner.where_clause.is_some());
}
1855
1856    // ── Issue #42 + CREATE VIEW: aggregate subtraction in view ──
1857
/// Aggregate subtraction inside a view definition: checks the view name, the
/// stored query's GROUP BY, and that the second column is an aggregate named
/// `net`.
#[test]
fn test_create_view_aggregate_subtraction() {
    let sql = "CREATE VIEW v AS SELECT token, SUM(sell) - SUM(buy) as net FROM orders GROUP BY token";
    let view = match parse_query(sql).unwrap() {
        Statement::CreateView(view) => view,
        other => panic!("Expected CreateView, got {:?}", other),
    };
    assert_eq!(view.view_name, "v");
    assert_eq!(view.query.group_by, Some(vec!["token".into()]));

    let ColumnList::Named(cols) = &view.query.columns else {
        panic!("Expected Named columns");
    };
    assert_eq!(cols.len(), 2);
    assert_eq!(cols[1].output_name(), "net");
    assert!(cols[1].is_aggregate());
}
1877}