Skip to main content

mdql_core/
query_parser.rs

1//! Hand-written recursive descent parser for the MDQL SQL subset.
2
3use regex::Regex;
4use std::sync::LazyLock;
5
6use crate::errors::MdqlError;
7pub use crate::query_ast::*;
8
9// ── Tokenizer ──────────────────────────────────────────────────────────────
10
/// Upper-case spellings of every word the tokenizer classifies as a keyword.
///
/// Lookup is done on the upper-cased lexeme, so keywords are recognised
/// case-insensitively.
static KEYWORDS: &[&str] = &[
    // Query clauses
    "SELECT", "FROM", "WHERE", "AND", "OR", "ORDER", "BY",
    // Ordering, limiting and predicates
    "ASC", "DESC", "LIMIT", "LIKE", "IN", "IS", "NOT", "NULL",
    // Joins, aliases and grouping
    "JOIN", "LEFT", "ON", "AS", "GROUP", "HAVING",
    // Data modification
    "INSERT", "INTO", "VALUES", "UPDATE", "SET", "DELETE",
    // Schema changes
    "ALTER", "TABLE", "RENAME", "FIELD", "TO", "DROP", "MERGE", "FIELDS",
    // CASE expressions
    "CASE", "WHEN", "THEN", "ELSE", "END",
    // Date handling
    "INTERVAL", "DAY", "DAYS", "CURRENT_DATE", "CURRENT_TIMESTAMP", "DATEDIFF",
    // Views
    "CREATE", "VIEW", "CASCADE", "RESTRICT",
];
21
/// Upper-case names of the aggregate functions the parser recognises.
static AGG_FUNCS: &[&str] = &[
    "COUNT",
    "SUM",
    "AVG",
    "MIN",
    "MAX",
];
23
24static TOKEN_RE: LazyLock<Regex> = LazyLock::new(|| {
25    Regex::new(
26        r#"(?x)
27        \s*(?:
28            (?P<backtick>`[^`]+`)
29            | (?P<string>'(?:[^'\\]|\\.)*')
30            | (?P<number>\d+(?:\.\d+)?)
31            | (?P<op><=|>=|!=|[=<>,*()+\-/%])
32            | (?P<word>[A-Za-z_][A-Za-z0-9_./-]*)
33        )"#,
34    )
35    .unwrap()
36});
37
/// A single lexical token produced by `tokenize`.
#[derive(Clone, Debug)]
struct Token {
    /// Token class: `"ident"`, `"string"`, `"number"`, `"op"` or `"keyword"`.
    token_type: String,
    /// Normalised text: quotes/backticks removed, keywords upper-cased.
    value: String,
    /// The lexeme exactly as it appeared in the query (used in error messages).
    raw: String,
}
44
45fn tokenize(sql: &str) -> Vec<Token> {
46    let mut tokens = Vec::new();
47    for caps in TOKEN_RE.captures_iter(sql) {
48        if let Some(m) = caps.name("backtick") {
49            let raw = m.as_str();
50            tokens.push(Token {
51                token_type: "ident".into(),
52                value: raw[1..raw.len() - 1].into(),
53                raw: raw.into(),
54            });
55        } else if let Some(m) = caps.name("string") {
56            let raw = m.as_str();
57            tokens.push(Token {
58                token_type: "string".into(),
59                value: raw[1..raw.len() - 1].into(),
60                raw: raw.into(),
61            });
62        } else if let Some(m) = caps.name("number") {
63            let raw = m.as_str();
64            tokens.push(Token {
65                token_type: "number".into(),
66                value: raw.into(),
67                raw: raw.into(),
68            });
69        } else if let Some(m) = caps.name("op") {
70            let raw = m.as_str();
71            tokens.push(Token {
72                token_type: "op".into(),
73                value: raw.into(),
74                raw: raw.into(),
75            });
76        } else if let Some(m) = caps.name("word") {
77            let raw = m.as_str();
78            if KEYWORDS.contains(&raw.to_uppercase().as_str()) {
79                tokens.push(Token {
80                    token_type: "keyword".into(),
81                    value: raw.to_uppercase(),
82                    raw: raw.into(),
83                });
84            } else {
85                tokens.push(Token {
86                    token_type: "ident".into(),
87                    value: raw.into(),
88                    raw: raw.into(),
89                });
90            }
91        }
92    }
93    tokens
94}
95
96// ── Parser ─────────────────────────────────────────────────────────────────
97
98struct Parser {
99    tokens: Vec<Token>,
100    pos: usize,
101}
102
103impl Parser {
104    fn new(tokens: Vec<Token>) -> Self {
105        Parser { tokens, pos: 0 }
106    }
107
108    fn peek(&self) -> Option<&Token> {
109        self.tokens.get(self.pos)
110    }
111
112    fn advance(&mut self) -> Token {
113        let t = self.tokens[self.pos].clone();
114        self.pos += 1;
115        t
116    }
117
118    fn expect(&mut self, type_: &str, value: Option<&str>) -> Result<Token, MdqlError> {
119        let t = self.peek().ok_or_else(|| {
120            MdqlError::QueryParse(format!(
121                "Unexpected end of query, expected {}",
122                value.unwrap_or(type_)
123            ))
124        })?;
125        let matches_type = t.token_type == type_;
126        let matches_value = value.map_or(true, |v| t.value == v);
127        if !matches_type || !matches_value {
128            return Err(MdqlError::QueryParse(format!(
129                "Expected {}, got '{}' at position {}",
130                value.unwrap_or(type_),
131                t.raw,
132                self.pos
133            )));
134        }
135        Ok(self.advance())
136    }
137
138    fn match_keyword(&mut self, kw: &str) -> bool {
139        if let Some(t) = self.peek() {
140            if t.token_type == "keyword" && t.value == kw {
141                self.advance();
142                return true;
143            }
144        }
145        false
146    }
147
148    fn parse_statement(&mut self) -> Result<Statement, MdqlError> {
149        let t = self.peek().ok_or_else(|| MdqlError::QueryParse("Empty query".into()))?;
150        match (t.token_type.as_str(), t.value.as_str()) {
151            ("keyword", "SELECT") => {
152                let q = self.parse_select()?;
153                self.expect_end()?;
154                Ok(Statement::Select(q))
155            }
156            ("keyword", "INSERT") => Ok(Statement::Insert(self.parse_insert()?)),
157            ("keyword", "UPDATE") => Ok(Statement::Update(self.parse_update()?)),
158            ("keyword", "DELETE") => Ok(Statement::Delete(self.parse_delete()?)),
159            ("keyword", "ALTER") => self.parse_alter(),
160            ("keyword", "CREATE") => self.parse_create_view(),
161            ("keyword", "DROP") => self.parse_drop_view(),
162            _ => Err(MdqlError::QueryParse(format!(
163                "Expected SELECT, INSERT, UPDATE, DELETE, ALTER, CREATE, or DROP, got '{}'",
164                t.raw
165            ))),
166        }
167    }
168
169    fn parse_select(&mut self) -> Result<SelectQuery, MdqlError> {
170        self.expect("keyword", Some("SELECT"))?;
171        let columns = self.parse_columns()?;
172        self.expect("keyword", Some("FROM"))?;
173
174        // Subquery: FROM (SELECT ...)
175        let mut subquery = None;
176        let (table, mut table_alias) = if self.peek().map_or(false, |t| t.token_type == "op" && t.value == "(") {
177            self.advance();
178            let inner = self.parse_select()?;
179            self.expect("op", Some(")"))?;
180            subquery = Some(Box::new(inner));
181            let alias = if let Some(t) = self.peek() {
182                if t.token_type == "ident" && !self.is_clause_keyword(t) {
183                    Some(self.advance().value)
184                } else {
185                    None
186                }
187            } else {
188                None
189            };
190            ("_subquery".to_string(), alias)
191        } else {
192            let t = self.parse_ident()?;
193            (t, None)
194        };
195
196        // Optional table alias (for non-subquery)
197        if subquery.is_none() {
198            if let Some(t) = self.peek() {
199                if t.token_type == "ident" && !self.is_clause_keyword(t) {
200                    table_alias = Some(self.advance().value);
201                }
202            }
203        }
204
205        // Optional JOIN(s)
206        let mut joins = Vec::new();
207        loop {
208            let jt = if self.match_keyword("LEFT") {
209                self.expect("keyword", Some("JOIN"))?;
210                JoinType::Left
211            } else if self.match_keyword("JOIN") {
212                JoinType::Inner
213            } else {
214                break;
215            };
216            let join_table = self.parse_ident()?;
217            let mut join_alias = None;
218            if let Some(t) = self.peek() {
219                if t.token_type == "ident" && !self.is_clause_keyword(t) {
220                    join_alias = Some(self.advance().value);
221                }
222            }
223            self.expect("keyword", Some("ON"))?;
224            let left_col = self.parse_ident()?;
225            self.expect("op", Some("="))?;
226            let right_col = self.parse_ident()?;
227            joins.push(JoinClause {
228                join_type: jt,
229                table: join_table,
230                alias: join_alias,
231                left_col,
232                right_col,
233            });
234        }
235
236        let mut where_clause = None;
237        if self.match_keyword("WHERE") {
238            where_clause = Some(self.parse_or_expr()?);
239        }
240
241        let mut group_by = None;
242        if self.match_keyword("GROUP") {
243            self.expect("keyword", Some("BY"))?;
244            let mut cols = vec![self.parse_ident()?];
245            while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
246                self.advance();
247                cols.push(self.parse_ident()?);
248            }
249            group_by = Some(cols);
250        }
251
252        let mut having = None;
253        if self.match_keyword("HAVING") {
254            having = Some(self.parse_or_expr()?);
255        }
256
257        let mut order_by = None;
258        if self.match_keyword("ORDER") {
259            self.expect("keyword", Some("BY"))?;
260            order_by = Some(self.parse_order_by()?);
261        }
262
263        let mut limit = None;
264        if self.match_keyword("LIMIT") {
265            let t = self.expect("number", None)?;
266            limit = Some(t.value.parse::<i64>().map_err(|_| {
267                MdqlError::QueryParse(format!("Invalid LIMIT value: {}", t.value))
268            })?);
269        }
270
271        Ok(SelectQuery {
272            columns,
273            table,
274            table_alias,
275            subquery,
276            joins,
277            where_clause,
278            group_by,
279            having,
280            order_by,
281            limit,
282        })
283    }
284
285    fn parse_insert(&mut self) -> Result<InsertQuery, MdqlError> {
286        self.expect("keyword", Some("INSERT"))?;
287        self.expect("keyword", Some("INTO"))?;
288        let table = self.parse_ident()?;
289
290        self.expect("op", Some("("))?;
291        let mut columns = vec![self.parse_ident()?];
292        while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
293            self.advance();
294            columns.push(self.parse_ident()?);
295        }
296        self.expect("op", Some(")"))?;
297
298        self.expect("keyword", Some("VALUES"))?;
299
300        self.expect("op", Some("("))?;
301        let mut values = vec![self.parse_value()?];
302        while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
303            self.advance();
304            values.push(self.parse_value()?);
305        }
306        self.expect("op", Some(")"))?;
307
308        if columns.len() != values.len() {
309            return Err(MdqlError::QueryParse(format!(
310                "Column count ({}) does not match value count ({})",
311                columns.len(),
312                values.len()
313            )));
314        }
315
316        self.expect_end()?;
317        Ok(InsertQuery {
318            table,
319            columns,
320            values,
321        })
322    }
323
324    fn parse_update(&mut self) -> Result<UpdateQuery, MdqlError> {
325        self.expect("keyword", Some("UPDATE"))?;
326        let table = self.parse_ident()?;
327        self.expect("keyword", Some("SET"))?;
328
329        let mut assignments = Vec::new();
330        let col = self.parse_ident()?;
331        self.expect("op", Some("="))?;
332        let val = self.parse_value()?;
333        assignments.push((col, val));
334
335        while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
336            self.advance();
337            let col = self.parse_ident()?;
338            self.expect("op", Some("="))?;
339            let val = self.parse_value()?;
340            assignments.push((col, val));
341        }
342
343        let mut where_clause = None;
344        if self.match_keyword("WHERE") {
345            where_clause = Some(self.parse_or_expr()?);
346        }
347
348        self.expect_end()?;
349        Ok(UpdateQuery {
350            table,
351            assignments,
352            where_clause,
353        })
354    }
355
356    fn parse_delete(&mut self) -> Result<DeleteQuery, MdqlError> {
357        self.expect("keyword", Some("DELETE"))?;
358        self.expect("keyword", Some("FROM"))?;
359        let table = self.parse_ident()?;
360
361        let mut where_clause = None;
362        if self.match_keyword("WHERE") {
363            where_clause = Some(self.parse_or_expr()?);
364        }
365
366        self.expect_end()?;
367        Ok(DeleteQuery {
368            table,
369            where_clause,
370        })
371    }
372
373    fn parse_alter(&mut self) -> Result<Statement, MdqlError> {
374        self.expect("keyword", Some("ALTER"))?;
375        self.expect("keyword", Some("TABLE"))?;
376        let table = self.parse_ident()?;
377
378        let t = self.peek().ok_or_else(|| {
379            MdqlError::QueryParse("Expected RENAME, DROP, or MERGE after table name".into())
380        })?;
381
382        match (t.token_type.as_str(), t.value.as_str()) {
383            ("keyword", "RENAME") => {
384                self.advance();
385                self.expect("keyword", Some("FIELD"))?;
386                let old_name = self.parse_string_or_ident()?;
387                self.expect("keyword", Some("TO"))?;
388                let new_name = self.parse_string_or_ident()?;
389                self.expect_end()?;
390                Ok(Statement::AlterRename(AlterRenameFieldQuery {
391                    table,
392                    old_name,
393                    new_name,
394                }))
395            }
396            ("keyword", "DROP") => {
397                self.advance();
398                self.expect("keyword", Some("FIELD"))?;
399                let field_name = self.parse_string_or_ident()?;
400                self.expect_end()?;
401                Ok(Statement::AlterDrop(AlterDropFieldQuery {
402                    table,
403                    field_name,
404                }))
405            }
406            ("keyword", "MERGE") => {
407                self.advance();
408                self.expect("keyword", Some("FIELDS"))?;
409                let mut sources = vec![self.parse_string_or_ident()?];
410                while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
411                    self.advance();
412                    sources.push(self.parse_string_or_ident()?);
413                }
414                self.expect("keyword", Some("INTO"))?;
415                let target = self.parse_string_or_ident()?;
416                self.expect_end()?;
417                Ok(Statement::AlterMerge(AlterMergeFieldsQuery {
418                    table,
419                    sources,
420                    into: target,
421                }))
422            }
423            _ => Err(MdqlError::QueryParse(format!(
424                "Expected RENAME, DROP, or MERGE, got '{}'",
425                t.raw
426            ))),
427        }
428    }
429
430    fn parse_create_view(&mut self) -> Result<Statement, MdqlError> {
431        self.expect("keyword", Some("CREATE"))?;
432        self.expect("keyword", Some("VIEW"))?;
433        let view_name = self.parse_ident()?;
434
435        let columns = if self.peek().map_or(false, |t| t.token_type == "op" && t.value == "(") {
436            self.advance();
437            let mut cols = vec![self.parse_ident()?];
438            while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
439                self.advance();
440                cols.push(self.parse_ident()?);
441            }
442            self.expect("op", Some(")"))?;
443            Some(cols)
444        } else {
445            None
446        };
447
448        self.expect("keyword", Some("AS"))?;
449        let query = Box::new(self.parse_select()?);
450        self.expect_end()?;
451
452        Ok(Statement::CreateView(CreateViewQuery {
453            view_name,
454            columns,
455            query,
456        }))
457    }
458
459    fn parse_drop_view(&mut self) -> Result<Statement, MdqlError> {
460        self.expect("keyword", Some("DROP"))?;
461        self.expect("keyword", Some("VIEW"))?;
462        let view_name = self.parse_ident()?;
463        self.expect_end()?;
464        Ok(Statement::DropView(DropViewQuery { view_name }))
465    }
466
467    fn parse_string_or_ident(&mut self) -> Result<String, MdqlError> {
468        let t = self.peek().ok_or_else(|| {
469            MdqlError::QueryParse("Expected field name, got end of query".into())
470        })?;
471        match t.token_type.as_str() {
472            "string" => {
473                let v = self.advance().value;
474                Ok(v)
475            }
476            "ident" | "keyword" => {
477                let v = self.advance().value;
478                Ok(v)
479            }
480            _ => Err(MdqlError::QueryParse(format!(
481                "Expected field name, got '{}'",
482                t.raw
483            ))),
484        }
485    }
486
487    fn parse_columns(&mut self) -> Result<ColumnList, MdqlError> {
488        if let Some(t) = self.peek() {
489            if t.token_type == "op" && t.value == "*" {
490                self.advance();
491                return Ok(ColumnList::All);
492            }
493        }
494
495        let mut exprs = vec![self.parse_select_expr()?];
496        while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
497            self.advance();
498            exprs.push(self.parse_select_expr()?);
499        }
500        Ok(ColumnList::Named(exprs))
501    }
502
503    fn peek_is_agg_func(&self) -> bool {
504        let t = match self.peek() {
505            Some(t) => t,
506            None => return false,
507        };
508        let name_upper = t.value.to_uppercase();
509        if !AGG_FUNCS.contains(&name_upper.as_str()) {
510            return false;
511        }
512        // Only treat as aggregate if followed by (
513        self.tokens
514            .get(self.pos + 1)
515            .map_or(false, |next| next.token_type == "op" && next.value == "(")
516    }
517
518    fn parse_select_expr(&mut self) -> Result<SelectExpr, MdqlError> {
519        let _t = self.peek().ok_or_else(|| {
520            MdqlError::QueryParse("Expected column or aggregate, got end of query".into())
521        })?;
522
523        let expr = self.parse_additive()?;
524
525        let alias = if self.match_keyword("AS") {
526            Some(self.parse_ident()?)
527        } else if self.peek().map_or(false, |t| {
528            t.token_type == "ident" && !self.is_clause_keyword(t)
529        }) {
530            Some(self.advance().value)
531        } else {
532            None
533        };
534
535        // Bare aggregate → SelectExpr::Aggregate for backward compat
536        if let Expr::Aggregate { func, arg, arg_expr } = expr {
537            return Ok(SelectExpr::Aggregate {
538                func,
539                arg,
540                arg_expr: arg_expr.map(|e| *e),
541                alias,
542            });
543        }
544
545        if alias.is_none() {
546            if let Expr::Column(name) = &expr {
547                return Ok(SelectExpr::Column(name.clone()));
548            }
549        }
550
551        Ok(SelectExpr::Expr { expr, alias })
552    }
553
554    // ── Expression parser (precedence climbing) ───────────────────────
555
556    fn peek_is_additive_op(&self) -> bool {
557        self.peek().map_or(false, |t| {
558            t.token_type == "op" && (t.value == "+" || t.value == "-")
559        })
560    }
561
562    fn peek_is_multiplicative_op(&self) -> bool {
563        self.peek().map_or(false, |t| {
564            t.token_type == "op" && (t.value == "*" || t.value == "/" || t.value == "%")
565        })
566    }
567
568    fn parse_additive(&mut self) -> Result<Expr, MdqlError> {
569        let mut left = self.parse_multiplicative()?;
570        while self.peek_is_additive_op() {
571            let op_tok = self.advance();
572            let is_sub = op_tok.value == "-";
573
574            // Check for INTERVAL keyword: expr +/- INTERVAL n DAY
575            if self.peek().map_or(false, |t| t.token_type == "keyword" && t.value == "INTERVAL") {
576                self.advance(); // consume INTERVAL
577                let days_expr = self.parse_multiplicative()?;
578                // Expect DAY or DAYS
579                if !self.match_keyword("DAY") && !self.match_keyword("DAYS") {
580                    return Err(MdqlError::QueryParse("Expected DAY after INTERVAL value".into()));
581                }
582                let days = if is_sub {
583                    Expr::UnaryMinus(Box::new(days_expr))
584                } else {
585                    days_expr
586                };
587                left = Expr::DateAdd {
588                    date: Box::new(left),
589                    days: Box::new(days),
590                };
591                continue;
592            }
593
594            let op = match op_tok.value.as_str() {
595                "+" => ArithOp::Add,
596                "-" => ArithOp::Sub,
597                _ => unreachable!(),
598            };
599            let right = self.parse_multiplicative()?;
600            left = Expr::BinaryOp {
601                left: Box::new(left),
602                op,
603                right: Box::new(right),
604            };
605        }
606        Ok(left)
607    }
608
609    fn parse_multiplicative(&mut self) -> Result<Expr, MdqlError> {
610        let mut left = self.parse_unary()?;
611        while self.peek_is_multiplicative_op() {
612            let op_tok = self.advance();
613            let op = match op_tok.value.as_str() {
614                "*" => ArithOp::Mul,
615                "/" => ArithOp::Div,
616                "%" => ArithOp::Mod,
617                _ => unreachable!(),
618            };
619            let right = self.parse_unary()?;
620            left = Expr::BinaryOp {
621                left: Box::new(left),
622                op,
623                right: Box::new(right),
624            };
625        }
626        Ok(left)
627    }
628
629    fn parse_unary(&mut self) -> Result<Expr, MdqlError> {
630        if self.peek().map_or(false, |t| t.token_type == "op" && t.value == "-") {
631            self.advance();
632            let inner = self.parse_atom()?;
633            // Fold unary minus on literals
634            match inner {
635                Expr::Literal(SqlValue::Int(n)) => Ok(Expr::Literal(SqlValue::Int(-n))),
636                Expr::Literal(SqlValue::Float(f)) => Ok(Expr::Literal(SqlValue::Float(-f))),
637                _ => Ok(Expr::UnaryMinus(Box::new(inner))),
638            }
639        } else {
640            self.parse_atom()
641        }
642    }
643
644    fn parse_atom(&mut self) -> Result<Expr, MdqlError> {
645        if self.peek_is_agg_func() {
646            return self.parse_agg_expr();
647        }
648
649        let t = self.peek().ok_or_else(|| {
650            MdqlError::QueryParse("Expected expression, got end of query".into())
651        })?;
652
653        match t.token_type.as_str() {
654            "number" => {
655                let v = self.advance().value;
656                if v.contains('.') {
657                    let f: f64 = v.parse().map_err(|_| {
658                        MdqlError::QueryParse(format!("Invalid float: {}", v))
659                    })?;
660                    Ok(Expr::Literal(SqlValue::Float(f)))
661                } else {
662                    let n: i64 = v.parse().map_err(|_| {
663                        MdqlError::QueryParse(format!("Invalid int: {}", v))
664                    })?;
665                    Ok(Expr::Literal(SqlValue::Int(n)))
666                }
667            }
668            "string" => {
669                let v = self.advance().value;
670                Ok(Expr::Literal(SqlValue::String(v)))
671            }
672            "keyword" if t.value == "NULL" => {
673                self.advance();
674                Ok(Expr::Literal(SqlValue::Null))
675            }
676            "keyword" if t.value == "CASE" => {
677                self.parse_case_expr()
678            }
679            "keyword" if t.value == "CURRENT_DATE" => {
680                self.advance();
681                Ok(Expr::CurrentDate)
682            }
683            "keyword" if t.value == "CURRENT_TIMESTAMP" => {
684                self.advance();
685                Ok(Expr::CurrentTimestamp)
686            }
687            "keyword" if t.value == "DATEDIFF" => {
688                self.advance();
689                self.expect("op", Some("("))?;
690                let left = self.parse_additive()?;
691                self.expect("op", Some(","))?;
692                let right = self.parse_additive()?;
693                self.expect("op", Some(")"))?;
694                Ok(Expr::DateDiff { left: Box::new(left), right: Box::new(right) })
695            }
696            "op" if t.value == "(" => {
697                self.advance();
698                let expr = self.parse_additive()?;
699                self.expect("op", Some(")"))?;
700                Ok(expr)
701            }
702            "ident" => {
703                let name = self.advance().value;
704                Ok(Expr::Column(name))
705            }
706            "keyword" if !Self::is_reserved_keyword(&t.value) => {
707                let name = self.advance().value;
708                Ok(Expr::Column(name))
709            }
710            _ => Err(MdqlError::QueryParse(format!(
711                "Expected expression, got '{}'",
712                t.raw
713            ))),
714        }
715    }
716
717    fn parse_case_expr(&mut self) -> Result<Expr, MdqlError> {
718        self.expect("keyword", Some("CASE"))?;
719        let mut whens = Vec::new();
720        while self.match_keyword("WHEN") {
721            let condition = self.parse_or_expr()?;
722            self.expect("keyword", Some("THEN"))?;
723            let result = self.parse_additive()?;
724            whens.push((condition, Box::new(result)));
725        }
726        if whens.is_empty() {
727            return Err(MdqlError::QueryParse("CASE requires at least one WHEN clause".into()));
728        }
729        let else_expr = if self.match_keyword("ELSE") {
730            Some(Box::new(self.parse_additive()?))
731        } else {
732            None
733        };
734        self.expect("keyword", Some("END"))?;
735        Ok(Expr::Case { whens, else_expr })
736    }
737
738    fn parse_agg_expr(&mut self) -> Result<Expr, MdqlError> {
739        let func_name = self.advance().value.to_uppercase();
740        let func = match func_name.as_str() {
741            "COUNT" => AggFunc::Count,
742            "SUM" => AggFunc::Sum,
743            "AVG" => AggFunc::Avg,
744            "MIN" => AggFunc::Min,
745            "MAX" => AggFunc::Max,
746            _ => unreachable!(),
747        };
748        self.expect("op", Some("("))?;
749        let (arg, arg_expr) = if self.peek().map_or(false, |t| t.token_type == "op" && t.value == "*") {
750            self.advance();
751            ("*".to_string(), None)
752        } else {
753            let expr = self.parse_additive()?;
754            if let Expr::Column(name) = &expr {
755                (name.clone(), None)
756            } else {
757                (expr.display_name(), Some(Box::new(expr)))
758            }
759        };
760        self.expect("op", Some(")"))?;
761        Ok(Expr::Aggregate { func, arg, arg_expr })
762    }
763
764    fn parse_ident(&mut self) -> Result<String, MdqlError> {
765        let t = self.peek().ok_or_else(|| {
766            MdqlError::QueryParse("Expected identifier, got end of query".into())
767        })?;
768        match t.token_type.as_str() {
769            "ident" | "keyword" => {
770                let v = self.advance().value;
771                Ok(v)
772            }
773            _ => Err(MdqlError::QueryParse(format!(
774                "Expected identifier, got '{}'",
775                t.raw
776            ))),
777        }
778    }
779
780    fn parse_or_expr(&mut self) -> Result<WhereClause, MdqlError> {
781        let mut left = self.parse_and_expr()?;
782        while self.match_keyword("OR") {
783            let right = self.parse_and_expr()?;
784            left = WhereClause::BoolOp(BoolOp {
785                op: BoolOpKind::Or,
786                left: Box::new(left),
787                right: Box::new(right),
788            });
789        }
790        Ok(left)
791    }
792
793    fn parse_and_expr(&mut self) -> Result<WhereClause, MdqlError> {
794        let mut left = self.parse_comparison()?;
795        while self.match_keyword("AND") {
796            let right = self.parse_comparison()?;
797            left = WhereClause::BoolOp(BoolOp {
798                op: BoolOpKind::And,
799                left: Box::new(left),
800                right: Box::new(right),
801            });
802        }
803        Ok(left)
804    }
805
806    fn parse_comparison(&mut self) -> Result<WhereClause, MdqlError> {
807        // Handle parenthesized boolean expressions
808        if self.peek().map_or(false, |t| t.token_type == "op" && t.value == "(") {
809            // Save position — might be arithmetic parens, not boolean
810            let saved_pos = self.pos;
811            self.advance();
812            // Try parsing as boolean (OR/AND) expression
813            let result = self.parse_or_expr();
814            if result.is_ok() && self.peek().map_or(false, |t| t.token_type == "op" && t.value == ")") {
815                self.advance();
816                return result;
817            }
818            // Not a boolean paren — rewind and parse as arithmetic expression
819            self.pos = saved_pos;
820        }
821
822        // Parse the left side as a full expression (column, literal, or arithmetic)
823        let left_expr = self.parse_additive()?;
824
825        // Extract column name for backward compat (simple column on left side)
826        let col = left_expr.as_column().unwrap_or("").to_string();
827
828        // IS NULL / IS NOT NULL (only valid with simple column)
829        if self.match_keyword("IS") {
830            if self.match_keyword("NOT") {
831                self.expect("keyword", Some("NULL"))?;
832                return Ok(WhereClause::Comparison(Comparison {
833                    column: col,
834                    op: CmpOp::IsNotNull,
835                    value: None,
836                    left_expr: Some(left_expr),
837                    right_expr: None,
838                }));
839            }
840            self.expect("keyword", Some("NULL"))?;
841            return Ok(WhereClause::Comparison(Comparison {
842                column: col,
843                op: CmpOp::IsNull,
844                value: None,
845                left_expr: Some(left_expr),
846                right_expr: None,
847            }));
848        }
849
850        // IN (val, val, ...)
851        if self.match_keyword("IN") {
852            self.expect("op", Some("("))?;
853            let mut values = vec![self.parse_value()?];
854            while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
855                self.advance();
856                values.push(self.parse_value()?);
857            }
858            self.expect("op", Some(")"))?;
859            return Ok(WhereClause::Comparison(Comparison {
860                column: col,
861                op: CmpOp::In,
862                value: Some(SqlValue::List(values)),
863                left_expr: Some(left_expr),
864                right_expr: None,
865            }));
866        }
867
868        // LIKE
869        if self.match_keyword("LIKE") {
870            let val = self.parse_value()?;
871            return Ok(WhereClause::Comparison(Comparison {
872                column: col,
873                op: CmpOp::Like,
874                value: Some(val),
875                left_expr: Some(left_expr),
876                right_expr: None,
877            }));
878        }
879
880        // NOT LIKE
881        if self.match_keyword("NOT") {
882            if self.match_keyword("LIKE") {
883                let val = self.parse_value()?;
884                return Ok(WhereClause::Comparison(Comparison {
885                    column: col,
886                    op: CmpOp::NotLike,
887                    value: Some(val),
888                    left_expr: Some(left_expr),
889                    right_expr: None,
890                }));
891            }
892            return Err(MdqlError::QueryParse("Expected LIKE after NOT".into()));
893        }
894
895        // Standard comparison operators
896        if let Some(t) = self.peek() {
897            if t.token_type == "op" && ["=", "!=", "<", ">", "<=", ">="].contains(&t.value.as_str())
898            {
899                let op_str = self.advance().value;
900                let op = match op_str.as_str() {
901                    "=" => CmpOp::Eq,
902                    "!=" => CmpOp::Ne,
903                    "<" => CmpOp::Lt,
904                    ">" => CmpOp::Gt,
905                    "<=" => CmpOp::Le,
906                    ">=" => CmpOp::Ge,
907                    _ => unreachable!(),
908                };
909                // Parse right side as expression
910                let right_expr = self.parse_additive()?;
911                // Extract SqlValue for backward compat (simple literal on right side)
912                let value = match &right_expr {
913                    Expr::Literal(v) => Some(v.clone()),
914                    _ => None,
915                };
916                return Ok(WhereClause::Comparison(Comparison {
917                    column: col,
918                    op,
919                    value,
920                    left_expr: Some(left_expr),
921                    right_expr: Some(right_expr),
922                }));
923            }
924        }
925
926        let got = self.peek().map_or("end".to_string(), |t| t.raw.clone());
927        Err(MdqlError::QueryParse(format!(
928            "Expected operator after '{}', got '{}'",
929            left_expr.display_name(), got
930        )))
931    }
932
933    fn parse_value(&mut self) -> Result<SqlValue, MdqlError> {
934        let t = self.peek().ok_or_else(|| {
935            MdqlError::QueryParse("Expected value, got end of query".into())
936        })?;
937        match t.token_type.as_str() {
938            "string" => {
939                let v = self.advance().value;
940                Ok(SqlValue::String(v))
941            }
942            "number" => {
943                let v = self.advance().value;
944                if v.contains('.') {
945                    Ok(SqlValue::Float(v.parse().map_err(|_| {
946                        MdqlError::QueryParse(format!("Invalid float: {}", v))
947                    })?))
948                } else {
949                    Ok(SqlValue::Int(v.parse().map_err(|_| {
950                        MdqlError::QueryParse(format!("Invalid int: {}", v))
951                    })?))
952                }
953            }
954            "keyword" if t.value == "NULL" => {
955                self.advance();
956                Ok(SqlValue::Null)
957            }
958            _ => Err(MdqlError::QueryParse(format!(
959                "Expected value, got '{}'",
960                t.raw
961            ))),
962        }
963    }
964
965    fn parse_order_by(&mut self) -> Result<Vec<OrderSpec>, MdqlError> {
966        let mut specs = vec![self.parse_order_spec()?];
967        while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
968            self.advance();
969            specs.push(self.parse_order_spec()?);
970        }
971        Ok(specs)
972    }
973
974    fn parse_order_spec(&mut self) -> Result<OrderSpec, MdqlError> {
975        let expr = self.parse_additive()?;
976        let col = expr.as_column().unwrap_or("").to_string();
977        let descending = if self.match_keyword("DESC") {
978            true
979        } else {
980            self.match_keyword("ASC");
981            false
982        };
983        Ok(OrderSpec {
984            column: col,
985            expr: Some(expr),
986            descending,
987        })
988    }
989
990    fn is_clause_keyword(&self, t: &Token) -> bool {
991        t.token_type == "keyword"
992            && ["WHERE", "ORDER", "LIMIT", "JOIN", "LEFT", "ON", "GROUP"].contains(&t.value.as_str())
993    }
994
995    /// Keywords that should never be consumed as column names inside expressions.
996    fn is_reserved_keyword(kw: &str) -> bool {
997        matches!(kw,
998            "AS" | "FROM" | "WHERE" | "AND" | "OR" | "ORDER" | "BY"
999            | "ASC" | "DESC" | "LIMIT" | "JOIN" | "ON" | "GROUP"
1000            | "SELECT" | "INSERT" | "INTO" | "VALUES" | "UPDATE" | "SET"
1001            | "DELETE" | "ALTER" | "TABLE" | "IS" | "NOT" | "IN" | "LIKE"
1002            | "RENAME" | "FIELD" | "TO" | "DROP" | "MERGE" | "FIELDS"
1003            | "CASE" | "WHEN" | "THEN" | "ELSE" | "END"
1004            | "HAVING" | "INTERVAL" | "DAY" | "DAYS"
1005            | "CURRENT_DATE" | "CURRENT_TIMESTAMP" | "DATEDIFF"
1006            | "CREATE" | "VIEW" | "CASCADE" | "RESTRICT"
1007        )
1008    }
1009
1010    fn expect_end(&self) -> Result<(), MdqlError> {
1011        if let Some(t) = self.peek() {
1012            return Err(MdqlError::QueryParse(format!(
1013                "Unexpected token '{}' at position {}",
1014                t.raw, self.pos
1015            )));
1016        }
1017        Ok(())
1018    }
1019}
1020
1021pub fn parse_query(sql: &str) -> crate::errors::Result<Statement> {
1022    let tokens = tokenize(sql);
1023    if tokens.is_empty() {
1024        return Err(MdqlError::QueryParse("Empty query".into()));
1025    }
1026    let mut parser = Parser::new(tokens);
1027    parser.parse_statement()
1028}
1029
1030#[cfg(test)]
1031mod tests {
1032    use super::*;
1033
1034    #[test]
1035    fn test_simple_select() {
1036        let stmt = parse_query("SELECT title, status FROM strategies").unwrap();
1037        if let Statement::Select(q) = stmt {
1038            assert_eq!(q.columns, ColumnList::Named(vec![SelectExpr::Column("title".into()), SelectExpr::Column("status".into())]));
1039            assert_eq!(q.table, "strategies");
1040        } else {
1041            panic!("Expected Select");
1042        }
1043    }
1044
1045    #[test]
1046    fn test_select_star() {
1047        let stmt = parse_query("SELECT * FROM test").unwrap();
1048        if let Statement::Select(q) = stmt {
1049            assert_eq!(q.columns, ColumnList::All);
1050        } else {
1051            panic!("Expected Select");
1052        }
1053    }
1054
1055    #[test]
1056    fn test_where_clause() {
1057        let stmt = parse_query("SELECT title FROM test WHERE count > 5").unwrap();
1058        if let Statement::Select(q) = stmt {
1059            assert!(q.where_clause.is_some());
1060        } else {
1061            panic!("Expected Select");
1062        }
1063    }
1064
1065    #[test]
1066    fn test_order_by() {
1067        let stmt =
1068            parse_query("SELECT title FROM test ORDER BY composite DESC, title ASC").unwrap();
1069        if let Statement::Select(q) = stmt {
1070            let ob = q.order_by.unwrap();
1071            assert_eq!(ob.len(), 2);
1072            assert!(ob[0].descending);
1073            assert!(!ob[1].descending);
1074        } else {
1075            panic!("Expected Select");
1076        }
1077    }
1078
1079    #[test]
1080    fn test_limit() {
1081        let stmt = parse_query("SELECT * FROM test LIMIT 10").unwrap();
1082        if let Statement::Select(q) = stmt {
1083            assert_eq!(q.limit, Some(10));
1084        } else {
1085            panic!("Expected Select");
1086        }
1087    }
1088
1089    #[test]
1090    fn test_insert() {
1091        let stmt = parse_query(
1092            "INSERT INTO test (title, count) VALUES ('Hello', 42)",
1093        )
1094        .unwrap();
1095        if let Statement::Insert(q) = stmt {
1096            assert_eq!(q.table, "test");
1097            assert_eq!(q.columns, vec!["title", "count"]);
1098            assert_eq!(q.values[0], SqlValue::String("Hello".into()));
1099            assert_eq!(q.values[1], SqlValue::Int(42));
1100        } else {
1101            panic!("Expected Insert");
1102        }
1103    }
1104
1105    #[test]
1106    fn test_update() {
1107        let stmt = parse_query("UPDATE test SET status = 'KILLED' WHERE path = 'a.md'").unwrap();
1108        if let Statement::Update(q) = stmt {
1109            assert_eq!(q.table, "test");
1110            assert_eq!(q.assignments.len(), 1);
1111            assert!(q.where_clause.is_some());
1112        } else {
1113            panic!("Expected Update");
1114        }
1115    }
1116
1117    #[test]
1118    fn test_delete() {
1119        let stmt = parse_query("DELETE FROM test WHERE status = 'draft'").unwrap();
1120        if let Statement::Delete(q) = stmt {
1121            assert_eq!(q.table, "test");
1122            assert!(q.where_clause.is_some());
1123        } else {
1124            panic!("Expected Delete");
1125        }
1126    }
1127
1128    #[test]
1129    fn test_alter_rename() {
1130        let stmt =
1131            parse_query("ALTER TABLE test RENAME FIELD 'Summary' TO 'Overview'").unwrap();
1132        if let Statement::AlterRename(q) = stmt {
1133            assert_eq!(q.old_name, "Summary");
1134            assert_eq!(q.new_name, "Overview");
1135        } else {
1136            panic!("Expected AlterRename");
1137        }
1138    }
1139
1140    #[test]
1141    fn test_alter_drop() {
1142        let stmt = parse_query("ALTER TABLE test DROP FIELD 'Details'").unwrap();
1143        if let Statement::AlterDrop(q) = stmt {
1144            assert_eq!(q.field_name, "Details");
1145        } else {
1146            panic!("Expected AlterDrop");
1147        }
1148    }
1149
1150    #[test]
1151    fn test_alter_merge() {
1152        let stmt = parse_query(
1153            "ALTER TABLE test MERGE FIELDS 'Entry Rules', 'Exit Rules' INTO 'Trading Rules'",
1154        )
1155        .unwrap();
1156        if let Statement::AlterMerge(q) = stmt {
1157            assert_eq!(q.sources, vec!["Entry Rules", "Exit Rules"]);
1158            assert_eq!(q.into, "Trading Rules");
1159        } else {
1160            panic!("Expected AlterMerge");
1161        }
1162    }
1163
1164    #[test]
1165    fn test_backtick_ident() {
1166        let stmt = parse_query("SELECT `Structural Mechanism` FROM test").unwrap();
1167        if let Statement::Select(q) = stmt {
1168            assert_eq!(
1169                q.columns,
1170                ColumnList::Named(vec![SelectExpr::Column("Structural Mechanism".into())])
1171            );
1172        } else {
1173            panic!("Expected Select");
1174        }
1175    }
1176
1177    #[test]
1178    fn test_like_operator() {
1179        let stmt = parse_query("SELECT title FROM test WHERE categories LIKE '%defi%'").unwrap();
1180        if let Statement::Select(q) = stmt {
1181            if let Some(WhereClause::Comparison(c)) = q.where_clause {
1182                assert_eq!(c.op, CmpOp::Like);
1183                assert_eq!(c.value, Some(SqlValue::String("%defi%".into())));
1184            } else {
1185                panic!("Expected LIKE comparison");
1186            }
1187        } else {
1188            panic!("Expected Select");
1189        }
1190    }
1191
1192    #[test]
1193    fn test_in_operator() {
1194        let stmt =
1195            parse_query("SELECT * FROM test WHERE status IN ('ACTIVE', 'LIVE')").unwrap();
1196        if let Statement::Select(q) = stmt {
1197            if let Some(WhereClause::Comparison(c)) = q.where_clause {
1198                assert_eq!(c.op, CmpOp::In);
1199            } else {
1200                panic!("Expected IN comparison");
1201            }
1202        } else {
1203            panic!("Expected Select");
1204        }
1205    }
1206
1207    #[test]
1208    fn test_is_null() {
1209        let stmt = parse_query("SELECT * FROM test WHERE title IS NULL").unwrap();
1210        if let Statement::Select(q) = stmt {
1211            if let Some(WhereClause::Comparison(c)) = q.where_clause {
1212                assert_eq!(c.op, CmpOp::IsNull);
1213            } else {
1214                panic!("Expected IS NULL comparison");
1215            }
1216        } else {
1217            panic!("Expected Select");
1218        }
1219    }
1220
1221    #[test]
1222    fn test_and_or() {
1223        let stmt = parse_query(
1224            "SELECT * FROM test WHERE status = 'ACTIVE' AND count > 5 OR title LIKE '%test%'",
1225        )
1226        .unwrap();
1227        if let Statement::Select(q) = stmt {
1228            assert!(q.where_clause.is_some());
1229        } else {
1230            panic!("Expected Select");
1231        }
1232    }
1233
1234    #[test]
1235    fn test_join() {
1236        let stmt = parse_query(
1237            "SELECT s.title, b.sharpe FROM strategies s JOIN backtests b ON b.strategy = s.path",
1238        )
1239        .unwrap();
1240        if let Statement::Select(q) = stmt {
1241            assert_eq!(q.table, "strategies");
1242            assert_eq!(q.table_alias, Some("s".into()));
1243            assert_eq!(q.joins.len(), 1);
1244            let join = &q.joins[0];
1245            assert_eq!(join.table, "backtests");
1246            assert_eq!(join.alias, Some("b".into()));
1247        } else {
1248            panic!("Expected Select");
1249        }
1250    }
1251
1252    #[test]
1253    fn test_multi_join() {
1254        let stmt = parse_query(
1255            "SELECT s.title, b.sharpe, c.verdict FROM strategies s JOIN backtests b ON b.strategy = s.path JOIN critiques c ON c.strategy = s.path",
1256        )
1257        .unwrap();
1258        if let Statement::Select(q) = stmt {
1259            assert_eq!(q.table, "strategies");
1260            assert_eq!(q.table_alias, Some("s".into()));
1261            assert_eq!(q.joins.len(), 2);
1262            assert_eq!(q.joins[0].table, "backtests");
1263            assert_eq!(q.joins[0].alias, Some("b".into()));
1264            assert_eq!(q.joins[0].left_col, "b.strategy");
1265            assert_eq!(q.joins[0].right_col, "s.path");
1266            assert_eq!(q.joins[1].table, "critiques");
1267            assert_eq!(q.joins[1].alias, Some("c".into()));
1268            assert_eq!(q.joins[1].left_col, "c.strategy");
1269            assert_eq!(q.joins[1].right_col, "s.path");
1270        } else {
1271            panic!("Expected Select");
1272        }
1273    }
1274
1275    #[test]
1276    fn test_left_join() {
1277        let stmt = parse_query(
1278            "SELECT s.title, b.sharpe FROM strategies s LEFT JOIN backtests b ON b.strategy = s.path",
1279        )
1280        .unwrap();
1281        if let Statement::Select(q) = stmt {
1282            assert_eq!(q.joins.len(), 1);
1283            assert_eq!(q.joins[0].join_type, JoinType::Left);
1284            assert_eq!(q.joins[0].table, "backtests");
1285        } else {
1286            panic!("Expected Select");
1287        }
1288    }
1289
1290    #[test]
1291    fn test_mixed_join_types() {
1292        let stmt = parse_query(
1293            "SELECT s.title FROM strategies s JOIN backtests b ON b.strategy = s.path LEFT JOIN allocations a ON a.strategy = s.path",
1294        )
1295        .unwrap();
1296        if let Statement::Select(q) = stmt {
1297            assert_eq!(q.joins.len(), 2);
1298            assert_eq!(q.joins[0].join_type, JoinType::Inner);
1299            assert_eq!(q.joins[1].join_type, JoinType::Left);
1300        } else {
1301            panic!("Expected Select");
1302        }
1303    }
1304
1305    #[test]
1306    fn test_empty_query() {
1307        assert!(parse_query("").is_err());
1308    }
1309
1310    #[test]
1311    fn test_count_star() {
1312        let stmt = parse_query("SELECT status, COUNT(*) AS cnt FROM strategies GROUP BY status").unwrap();
1313        if let Statement::Select(q) = stmt {
1314            if let ColumnList::Named(exprs) = &q.columns {
1315                assert_eq!(exprs.len(), 2);
1316                assert_eq!(exprs[0], SelectExpr::Column("status".into()));
1317                assert!(matches!(&exprs[1], SelectExpr::Aggregate {
1318                    func: AggFunc::Count,
1319                    arg,
1320                    alias: Some(a),
1321                    ..
1322                } if arg == "*" && a == "cnt"));
1323            } else {
1324                panic!("Expected Named columns");
1325            }
1326            assert_eq!(q.group_by, Some(vec!["status".into()]));
1327        } else {
1328            panic!("Expected Select");
1329        }
1330    }
1331
1332    #[test]
1333    fn test_count_column_as_ident() {
1334        // "count" as a column name should NOT be parsed as the COUNT aggregate
1335        let stmt = parse_query("INSERT INTO test (title, count) VALUES ('Hello', 42)").unwrap();
1336        if let Statement::Insert(q) = stmt {
1337            assert_eq!(q.columns, vec!["title", "count"]);
1338        } else {
1339            panic!("Expected Insert");
1340        }
1341    }
1342
1343    #[test]
1344    fn test_multiple_aggregates() {
1345        let stmt = parse_query("SELECT MIN(composite), MAX(composite), AVG(composite) FROM strategies").unwrap();
1346        if let Statement::Select(q) = stmt {
1347            if let ColumnList::Named(exprs) = &q.columns {
1348                assert_eq!(exprs.len(), 3);
1349                assert!(matches!(&exprs[0], SelectExpr::Aggregate { func: AggFunc::Min, .. }));
1350                assert!(matches!(&exprs[1], SelectExpr::Aggregate { func: AggFunc::Max, .. }));
1351                assert!(matches!(&exprs[2], SelectExpr::Aggregate { func: AggFunc::Avg, .. }));
1352            } else {
1353                panic!("Expected Named columns");
1354            }
1355            assert_eq!(q.group_by, None);
1356        } else {
1357            panic!("Expected Select");
1358        }
1359    }
1360
1361    // ── Expression tests ──────────────────────────────────────────
1362
1363    #[test]
1364    fn test_select_arithmetic_expr() {
1365        let stmt = parse_query("SELECT a + b FROM test").unwrap();
1366        if let Statement::Select(q) = stmt {
1367            if let ColumnList::Named(exprs) = &q.columns {
1368                assert_eq!(exprs.len(), 1);
1369                assert!(matches!(&exprs[0], SelectExpr::Expr {
1370                    expr: Expr::BinaryOp { op: ArithOp::Add, .. },
1371                    alias: None,
1372                }));
1373            } else {
1374                panic!("Expected Named columns");
1375            }
1376        } else {
1377            panic!("Expected Select");
1378        }
1379    }
1380
1381    #[test]
1382    fn test_select_arithmetic_with_alias() {
1383        let stmt = parse_query("SELECT a + b AS total FROM test").unwrap();
1384        if let Statement::Select(q) = stmt {
1385            if let ColumnList::Named(exprs) = &q.columns {
1386                assert_eq!(exprs.len(), 1);
1387                assert!(matches!(&exprs[0], SelectExpr::Expr {
1388                    alias: Some(a),
1389                    ..
1390                } if a == "total"));
1391                assert_eq!(exprs[0].output_name(), "total");
1392            } else {
1393                panic!("Expected Named columns");
1394            }
1395        } else {
1396            panic!("Expected Select");
1397        }
1398    }
1399
1400    #[test]
1401    fn test_select_precedence() {
1402        // a + b * c should parse as a + (b * c)
1403        let stmt = parse_query("SELECT a + b * c FROM test").unwrap();
1404        if let Statement::Select(q) = stmt {
1405            if let ColumnList::Named(exprs) = &q.columns {
1406                if let SelectExpr::Expr { expr, .. } = &exprs[0] {
1407                    if let Expr::BinaryOp { left, op, right } = expr {
1408                        assert_eq!(*op, ArithOp::Add);
1409                        assert!(matches!(left.as_ref(), Expr::Column(n) if n == "a"));
1410                        assert!(matches!(right.as_ref(), Expr::BinaryOp { op: ArithOp::Mul, .. }));
1411                    } else {
1412                        panic!("Expected BinaryOp");
1413                    }
1414                } else {
1415                    panic!("Expected Expr variant");
1416                }
1417            } else {
1418                panic!("Expected Named columns");
1419            }
1420        } else {
1421            panic!("Expected Select");
1422        }
1423    }
1424
1425    #[test]
1426    fn test_select_parenthesized_expr() {
1427        // (a + b) * c should override default precedence
1428        let stmt = parse_query("SELECT (a + b) * c FROM test").unwrap();
1429        if let Statement::Select(q) = stmt {
1430            if let ColumnList::Named(exprs) = &q.columns {
1431                if let SelectExpr::Expr { expr, .. } = &exprs[0] {
1432                    if let Expr::BinaryOp { left, op, .. } = expr {
1433                        assert_eq!(*op, ArithOp::Mul);
1434                        assert!(matches!(left.as_ref(), Expr::BinaryOp { op: ArithOp::Add, .. }));
1435                    } else {
1436                        panic!("Expected BinaryOp");
1437                    }
1438                } else {
1439                    panic!("Expected Expr variant");
1440                }
1441            } else {
1442                panic!("Expected Named columns");
1443            }
1444        } else {
1445            panic!("Expected Select");
1446        }
1447    }
1448
1449    #[test]
1450    fn test_select_unary_minus() {
1451        let stmt = parse_query("SELECT -count FROM test").unwrap();
1452        if let Statement::Select(q) = stmt {
1453            if let ColumnList::Named(exprs) = &q.columns {
1454                assert!(matches!(&exprs[0], SelectExpr::Expr {
1455                    expr: Expr::UnaryMinus(_),
1456                    ..
1457                }));
1458            } else {
1459                panic!("Expected Named columns");
1460            }
1461        } else {
1462            panic!("Expected Select");
1463        }
1464    }
1465
1466    #[test]
1467    fn test_select_negative_literal() {
1468        let stmt = parse_query("SELECT -42 FROM test").unwrap();
1469        if let Statement::Select(q) = stmt {
1470            if let ColumnList::Named(exprs) = &q.columns {
1471                // Unary minus folds into the literal
1472                assert!(matches!(&exprs[0], SelectExpr::Expr {
1473                    expr: Expr::Literal(SqlValue::Int(-42)),
1474                    ..
1475                }));
1476            } else {
1477                panic!("Expected Named columns");
1478            }
1479        } else {
1480            panic!("Expected Select");
1481        }
1482    }
1483
1484    #[test]
1485    fn test_where_arithmetic_expr() {
1486        let stmt = parse_query("SELECT * FROM test WHERE a + b > 10").unwrap();
1487        if let Statement::Select(q) = stmt {
1488            if let Some(WhereClause::Comparison(c)) = q.where_clause {
1489                assert_eq!(c.op, CmpOp::Gt);
1490                assert!(matches!(&c.left_expr, Some(Expr::BinaryOp { op: ArithOp::Add, .. })));
1491                assert!(matches!(&c.right_expr, Some(Expr::Literal(SqlValue::Int(10)))));
1492            } else {
1493                panic!("Expected comparison");
1494            }
1495        } else {
1496            panic!("Expected Select");
1497        }
1498    }
1499
1500    #[test]
1501    fn test_where_both_sides_expr() {
1502        let stmt = parse_query("SELECT * FROM test WHERE a * 2 > b + 1").unwrap();
1503        if let Statement::Select(q) = stmt {
1504            if let Some(WhereClause::Comparison(c)) = q.where_clause {
1505                assert_eq!(c.op, CmpOp::Gt);
1506                assert!(matches!(&c.left_expr, Some(Expr::BinaryOp { op: ArithOp::Mul, .. })));
1507                assert!(matches!(&c.right_expr, Some(Expr::BinaryOp { op: ArithOp::Add, .. })));
1508            } else {
1509                panic!("Expected comparison");
1510            }
1511        } else {
1512            panic!("Expected Select");
1513        }
1514    }
1515
1516    #[test]
1517    fn test_order_by_expr() {
1518        let stmt = parse_query("SELECT * FROM test ORDER BY a + b DESC").unwrap();
1519        if let Statement::Select(q) = stmt {
1520            let ob = q.order_by.unwrap();
1521            assert_eq!(ob.len(), 1);
1522            assert!(ob[0].descending);
1523            assert!(matches!(&ob[0].expr, Some(Expr::BinaryOp { op: ArithOp::Add, .. })));
1524        } else {
1525            panic!("Expected Select");
1526        }
1527    }
1528
1529    #[test]
1530    fn test_all_arithmetic_ops() {
1531        let stmt = parse_query("SELECT a + b, a - b, a * b, a / b, a % b FROM test").unwrap();
1532        if let Statement::Select(q) = stmt {
1533            if let ColumnList::Named(exprs) = &q.columns {
1534                assert_eq!(exprs.len(), 5);
1535                assert!(matches!(&exprs[0], SelectExpr::Expr { expr: Expr::BinaryOp { op: ArithOp::Add, .. }, .. }));
1536                assert!(matches!(&exprs[1], SelectExpr::Expr { expr: Expr::BinaryOp { op: ArithOp::Sub, .. }, .. }));
1537                assert!(matches!(&exprs[2], SelectExpr::Expr { expr: Expr::BinaryOp { op: ArithOp::Mul, .. }, .. }));
1538                assert!(matches!(&exprs[3], SelectExpr::Expr { expr: Expr::BinaryOp { op: ArithOp::Div, .. }, .. }));
1539                assert!(matches!(&exprs[4], SelectExpr::Expr { expr: Expr::BinaryOp { op: ArithOp::Mod, .. }, .. }));
1540            } else {
1541                panic!("Expected Named columns");
1542            }
1543        } else {
1544            panic!("Expected Select");
1545        }
1546    }
1547
1548    #[test]
1549    fn test_column_with_literal_arithmetic() {
1550        let stmt = parse_query("SELECT count * 2 + 1 FROM test").unwrap();
1551        if let Statement::Select(q) = stmt {
1552            if let ColumnList::Named(exprs) = &q.columns {
1553                // Should be (count * 2) + 1
1554                if let SelectExpr::Expr { expr, .. } = &exprs[0] {
1555                    if let Expr::BinaryOp { left, op, right } = expr {
1556                        assert_eq!(*op, ArithOp::Add);
1557                        assert!(matches!(right.as_ref(), Expr::Literal(SqlValue::Int(1))));
1558                        assert!(matches!(left.as_ref(), Expr::BinaryOp { op: ArithOp::Mul, .. }));
1559                    } else {
1560                        panic!("Expected BinaryOp");
1561                    }
1562                } else {
1563                    panic!("Expected Expr");
1564                }
1565            } else {
1566                panic!("Expected Named columns");
1567            }
1568        } else {
1569            panic!("Expected Select");
1570        }
1571    }
1572
1573    #[test]
1574    fn test_mixed_columns_and_exprs() {
1575        let stmt = parse_query("SELECT title, a + b AS sum, count FROM test").unwrap();
1576        if let Statement::Select(q) = stmt {
1577            if let ColumnList::Named(exprs) = &q.columns {
1578                assert_eq!(exprs.len(), 3);
1579                assert_eq!(exprs[0], SelectExpr::Column("title".into()));
1580                assert!(matches!(&exprs[1], SelectExpr::Expr { alias: Some(a), .. } if a == "sum"));
1581                assert_eq!(exprs[2], SelectExpr::Column("count".into()));
1582            } else {
1583                panic!("Expected Named columns");
1584            }
1585        } else {
1586            panic!("Expected Select");
1587        }
1588    }
1589
1590    // ── CASE WHEN tests ──────────────────────────────────────────
1591
1592    #[test]
1593    fn test_case_when_basic() {
1594        let stmt = parse_query(
1595            "SELECT CASE WHEN status = 'ACTIVE' THEN 1 ELSE 0 END FROM test"
1596        ).unwrap();
1597        if let Statement::Select(q) = stmt {
1598            if let ColumnList::Named(exprs) = &q.columns {
1599                assert_eq!(exprs.len(), 1);
1600                assert!(matches!(&exprs[0], SelectExpr::Expr {
1601                    expr: Expr::Case { .. },
1602                    ..
1603                }));
1604            } else {
1605                panic!("Expected Named columns");
1606            }
1607        } else {
1608            panic!("Expected Select");
1609        }
1610    }
1611
1612    #[test]
1613    fn test_case_when_multiple_branches() {
1614        let stmt = parse_query(
1615            "SELECT CASE WHEN x > 10 THEN 'high' WHEN x > 5 THEN 'mid' ELSE 'low' END FROM test"
1616        ).unwrap();
1617        if let Statement::Select(q) = stmt {
1618            if let ColumnList::Named(exprs) = &q.columns {
1619                if let SelectExpr::Expr { expr: Expr::Case { whens, else_expr }, .. } = &exprs[0] {
1620                    assert_eq!(whens.len(), 2);
1621                    assert!(else_expr.is_some());
1622                } else {
1623                    panic!("Expected Case expression");
1624                }
1625            } else {
1626                panic!("Expected Named columns");
1627            }
1628        } else {
1629            panic!("Expected Select");
1630        }
1631    }
1632
1633    #[test]
1634    fn test_case_when_no_else() {
1635        let stmt = parse_query(
1636            "SELECT CASE WHEN x = 1 THEN 'one' END FROM test"
1637        ).unwrap();
1638        if let Statement::Select(q) = stmt {
1639            if let ColumnList::Named(exprs) = &q.columns {
1640                if let SelectExpr::Expr { expr: Expr::Case { whens, else_expr }, .. } = &exprs[0] {
1641                    assert_eq!(whens.len(), 1);
1642                    assert!(else_expr.is_none());
1643                } else {
1644                    panic!("Expected Case expression");
1645                }
1646            } else {
1647                panic!("Expected Named columns");
1648            }
1649        } else {
1650            panic!("Expected Select");
1651        }
1652    }
1653
1654    #[test]
1655    fn test_case_when_in_aggregate() {
1656        let stmt = parse_query(
1657            "SELECT SUM(CASE WHEN side = 'BUY' THEN size ELSE -size END) AS net FROM orders GROUP BY token"
1658        ).unwrap();
1659        if let Statement::Select(q) = stmt {
1660            if let ColumnList::Named(exprs) = &q.columns {
1661                assert_eq!(exprs.len(), 1);
1662                assert!(matches!(&exprs[0], SelectExpr::Aggregate {
1663                    func: AggFunc::Sum,
1664                    arg_expr: Some(Expr::Case { .. }),
1665                    alias: Some(a),
1666                    ..
1667                } if a == "net"));
1668            } else {
1669                panic!("Expected Named columns");
1670            }
1671        } else {
1672            panic!("Expected Select");
1673        }
1674    }
1675
1676    #[test]
1677    fn test_case_when_with_alias() {
1678        let stmt = parse_query(
1679            "SELECT CASE WHEN x > 0 THEN 'pos' ELSE 'neg' END AS sign FROM test"
1680        ).unwrap();
1681        if let Statement::Select(q) = stmt {
1682            if let ColumnList::Named(exprs) = &q.columns {
1683                assert!(matches!(&exprs[0], SelectExpr::Expr {
1684                    expr: Expr::Case { .. },
1685                    alias: Some(a),
1686                } if a == "sign"));
1687            } else {
1688                panic!("Expected Named columns");
1689            }
1690        } else {
1691            panic!("Expected Select");
1692        }
1693    }
1694
1695    #[test]
1696    fn test_create_view() {
1697        let stmt = parse_query("CREATE VIEW live AS SELECT * FROM strategies WHERE status = 'LIVE'").unwrap();
1698        if let Statement::CreateView(cv) = stmt {
1699            assert_eq!(cv.view_name, "live");
1700            assert!(cv.columns.is_none());
1701            assert_eq!(cv.query.table, "strategies");
1702            assert!(cv.query.where_clause.is_some());
1703        } else {
1704            panic!("Expected CreateView, got {:?}", stmt);
1705        }
1706    }
1707
1708    #[test]
1709    fn test_create_view_with_columns() {
1710        let stmt = parse_query("CREATE VIEW v1 (a, b) AS SELECT title, status FROM t").unwrap();
1711        if let Statement::CreateView(cv) = stmt {
1712            assert_eq!(cv.view_name, "v1");
1713            assert_eq!(cv.columns, Some(vec!["a".into(), "b".into()]));
1714        } else {
1715            panic!("Expected CreateView");
1716        }
1717    }
1718
1719    #[test]
1720    fn test_drop_view() {
1721        let stmt = parse_query("DROP VIEW live").unwrap();
1722        if let Statement::DropView(dv) = stmt {
1723            assert_eq!(dv.view_name, "live");
1724        } else {
1725            panic!("Expected DropView, got {:?}", stmt);
1726        }
1727    }
1728
1729    #[test]
1730    fn test_create_view_case_insensitive() {
1731        let stmt = parse_query("create view My_View as select * from t").unwrap();
1732        if let Statement::CreateView(cv) = stmt {
1733            assert_eq!(cv.view_name, "My_View");
1734        } else {
1735            panic!("Expected CreateView");
1736        }
1737    }
1738
1739    // ── Issue #42: Arithmetic between aggregates in column expressions ──
1740
1741    #[test]
1742    fn test_aggregate_division() {
1743        let stmt = parse_query(
1744            "SELECT token, SUM(sell) / SUM(buy) as ratio FROM orders GROUP BY token"
1745        ).unwrap();
1746        if let Statement::Select(q) = stmt {
1747            assert_eq!(q.group_by, Some(vec!["token".into()]));
1748            if let ColumnList::Named(exprs) = &q.columns {
1749                assert_eq!(exprs.len(), 2);
1750                assert!(exprs[1].is_aggregate());
1751            } else {
1752                panic!("Expected Named columns");
1753            }
1754        } else {
1755            panic!("Expected Select");
1756        }
1757    }
1758
1759    #[test]
1760    fn test_aggregate_subtraction() {
1761        let stmt = parse_query(
1762            "SELECT token, SUM(sell) - SUM(buy) as net FROM orders GROUP BY token"
1763        ).unwrap();
1764        if let Statement::Select(q) = stmt {
1765            if let ColumnList::Named(exprs) = &q.columns {
1766                assert_eq!(exprs[1].output_name(), "net");
1767            }
1768        } else {
1769            panic!("Expected Select");
1770        }
1771    }
1772
1773    #[test]
1774    fn test_create_view_with_arithmetic() {
1775        let stmt = parse_query(
1776            "CREATE VIEW positions AS SELECT token, SUM(sell) / SUM(buy) as ratio FROM orders GROUP BY token"
1777        ).unwrap();
1778        if let Statement::CreateView(cv) = stmt {
1779            assert_eq!(cv.view_name, "positions");
1780        } else {
1781            panic!("Expected CreateView, got {:?}", stmt);
1782        }
1783    }
1784
1785    // ── Issue #43: Subqueries in FROM ──
1786
1787    #[test]
1788    fn test_subquery_in_from() {
1789        let stmt = parse_query(
1790            "SELECT token, sell_size FROM (SELECT token, SUM(size) as sell_size FROM orders GROUP BY token) LIMIT 5"
1791        ).unwrap();
1792        if let Statement::Select(q) = stmt {
1793            assert!(q.subquery.is_some());
1794            assert_eq!(q.limit, Some(5));
1795            let sub = q.subquery.unwrap();
1796            assert_eq!(sub.table, "orders");
1797            assert!(sub.group_by.is_some());
1798        } else {
1799            panic!("Expected Select");
1800        }
1801    }
1802
1803    // ── Issue #44: HAVING in CREATE VIEW ──
1804
1805    #[test]
1806    fn test_create_view_with_having() {
1807        let stmt = parse_query(
1808            "CREATE VIEW positions AS SELECT token, SUM(sell) as sell_size, SUM(buy) as buy_size FROM orders GROUP BY token HAVING sell_size > buy_size"
1809        ).unwrap();
1810        if let Statement::CreateView(cv) = stmt {
1811            assert_eq!(cv.view_name, "positions");
1812            assert!(cv.query.having.is_some());
1813        } else {
1814            panic!("Expected CreateView, got {:?}", stmt);
1815        }
1816    }
1817
1818    // ── Issue #42: Aggregate multiplication ──
1819
1820    #[test]
1821    fn test_aggregate_multiplication() {
1822        let stmt = parse_query(
1823            "SELECT SUM(a) * 2 as doubled FROM test"
1824        ).unwrap();
1825        if let Statement::Select(q) = stmt {
1826            if let ColumnList::Named(exprs) = &q.columns {
1827                assert_eq!(exprs.len(), 1);
1828                assert!(exprs[0].is_aggregate());
1829                assert_eq!(exprs[0].output_name(), "doubled");
1830            } else {
1831                panic!("Expected Named columns");
1832            }
1833        } else {
1834            panic!("Expected Select");
1835        }
1836    }
1837
1838    #[test]
1839    fn test_complex_aggregate_arithmetic() {
1840        let stmt = parse_query(
1841            "SELECT SUM(CASE WHEN side = 'SELL' THEN size ELSE 0 END) / SUM(CASE WHEN side = 'BUY' THEN size ELSE 0 END) as ratio FROM orders GROUP BY token"
1842        ).unwrap();
1843        if let Statement::Select(q) = stmt {
1844            if let ColumnList::Named(exprs) = &q.columns {
1845                assert_eq!(exprs.len(), 1);
1846                assert!(exprs[0].is_aggregate());
1847                assert_eq!(exprs[0].output_name(), "ratio");
1848            } else {
1849                panic!("Expected Named columns");
1850            }
1851            assert_eq!(q.group_by, Some(vec!["token".into()]));
1852        } else {
1853            panic!("Expected Select");
1854        }
1855    }
1856
1857    // ── Issue #43: Subquery with alias and WHERE ──
1858
1859    #[test]
1860    fn test_subquery_with_alias() {
1861        let stmt = parse_query(
1862            "SELECT x FROM (SELECT x FROM t) sub"
1863        ).unwrap();
1864        if let Statement::Select(q) = stmt {
1865            assert!(q.subquery.is_some());
1866            let sub = q.subquery.unwrap();
1867            assert_eq!(sub.table, "t");
1868            if let ColumnList::Named(exprs) = &q.columns {
1869                assert_eq!(exprs.len(), 1);
1870                assert_eq!(exprs[0].output_name(), "x");
1871            } else {
1872                panic!("Expected Named columns");
1873            }
1874        } else {
1875            panic!("Expected Select");
1876        }
1877    }
1878
1879    #[test]
1880    fn test_subquery_with_where() {
1881        let stmt = parse_query(
1882            "SELECT x FROM (SELECT x FROM t WHERE y > 0) LIMIT 5"
1883        ).unwrap();
1884        if let Statement::Select(q) = stmt {
1885            assert!(q.subquery.is_some());
1886            assert_eq!(q.limit, Some(5));
1887            let sub = q.subquery.unwrap();
1888            assert_eq!(sub.table, "t");
1889            assert!(sub.where_clause.is_some());
1890        } else {
1891            panic!("Expected Select");
1892        }
1893    }
1894
1895    // ── Issue #42 + CREATE VIEW: aggregate subtraction in view ──
1896
1897    #[test]
1898    fn test_create_view_aggregate_subtraction() {
1899        let stmt = parse_query(
1900            "CREATE VIEW v AS SELECT token, SUM(sell) - SUM(buy) as net FROM orders GROUP BY token"
1901        ).unwrap();
1902        if let Statement::CreateView(cv) = stmt {
1903            assert_eq!(cv.view_name, "v");
1904            assert_eq!(cv.query.group_by, Some(vec!["token".into()]));
1905            if let ColumnList::Named(exprs) = &cv.query.columns {
1906                assert_eq!(exprs.len(), 2);
1907                assert_eq!(exprs[1].output_name(), "net");
1908                assert!(exprs[1].is_aggregate());
1909            } else {
1910                panic!("Expected Named columns");
1911            }
1912        } else {
1913            panic!("Expected CreateView, got {:?}", stmt);
1914        }
1915    }
1916}