Skip to main content

mdql_core/
query_parser.rs

1//! Hand-written recursive descent parser for the MDQL SQL subset.
2
3use regex::Regex;
4use std::sync::LazyLock;
5
6use crate::errors::MdqlError;
7pub use crate::query_ast::*;
8
9// ── Tokenizer ──────────────────────────────────────────────────────────────
10
/// Words the tokenizer classifies as `keyword` tokens.
///
/// `tokenize` upper-cases every bare word before looking it up here, so
/// keyword recognition is case-insensitive.
static KEYWORDS: &[&str] = &[
    // Query clauses and boolean connectives
    "SELECT", "FROM", "WHERE", "AND", "OR", "ORDER", "BY",
    // Ordering, limits and predicates
    "ASC", "DESC", "LIMIT", "LIKE", "IN", "IS", "NOT", "NULL",
    // Joins, aliases and grouping
    "JOIN", "LEFT", "ON", "AS", "GROUP", "HAVING",
    // Data modification
    "INSERT", "INTO", "VALUES", "UPDATE", "SET", "DELETE",
    // Schema changes
    "ALTER", "TABLE", "RENAME", "FIELD", "TO", "DROP", "MERGE", "FIELDS",
    // CASE expressions
    "CASE", "WHEN", "THEN", "ELSE", "END",
    // Date arithmetic
    "INTERVAL", "DAY", "DAYS", "CURRENT_DATE", "CURRENT_TIMESTAMP", "DATEDIFF",
    // Views and delete modes
    "CREATE", "VIEW", "CASCADE", "RESTRICT",
];
21
/// Upper-case names of the supported aggregate functions.
///
/// These are deliberately not listed in `KEYWORDS`, so they tokenize as
/// identifiers; the parser treats one as an aggregate call only when its
/// upper-cased value appears here and the following token is `(`.
static AGG_FUNCS: &[&str] = &[
    "COUNT",
    "SUM",
    "AVG",
    "MIN",
    "MAX",
];
23
24static TOKEN_RE: LazyLock<Regex> = LazyLock::new(|| {
25    Regex::new(
26        r#"(?x)
27        \s*(?:
28            (?P<backtick>`[^`]+`)
29            | (?P<string>'(?:[^'\\]|\\.)*')
30            | (?P<number>\d+(?:\.\d+)?)
31            | (?P<op><=|>=|!=|[=<>,*()+\-/%])
32            | (?P<word>[A-Za-z_][A-Za-z0-9_./-]*)
33        )"#,
34    )
35    .unwrap()
36});
37
/// A single lexical unit produced by `tokenize`.
#[derive(Debug, Clone)]
struct Token {
    /// Token category: `"ident"`, `"string"`, `"number"`, `"op"` or `"keyword"`.
    token_type: String,
    /// Normalized text: surrounding quotes/backticks removed, keywords upper-cased.
    value: String,
    /// The text exactly as it was matched in the query (used in error messages).
    raw: String,
}
44
45fn tokenize(sql: &str) -> Vec<Token> {
46    let mut tokens = Vec::new();
47    for caps in TOKEN_RE.captures_iter(sql) {
48        if let Some(m) = caps.name("backtick") {
49            let raw = m.as_str();
50            tokens.push(Token {
51                token_type: "ident".into(),
52                value: raw[1..raw.len() - 1].into(),
53                raw: raw.into(),
54            });
55        } else if let Some(m) = caps.name("string") {
56            let raw = m.as_str();
57            tokens.push(Token {
58                token_type: "string".into(),
59                value: raw[1..raw.len() - 1].into(),
60                raw: raw.into(),
61            });
62        } else if let Some(m) = caps.name("number") {
63            let raw = m.as_str();
64            tokens.push(Token {
65                token_type: "number".into(),
66                value: raw.into(),
67                raw: raw.into(),
68            });
69        } else if let Some(m) = caps.name("op") {
70            let raw = m.as_str();
71            tokens.push(Token {
72                token_type: "op".into(),
73                value: raw.into(),
74                raw: raw.into(),
75            });
76        } else if let Some(m) = caps.name("word") {
77            let raw = m.as_str();
78            if KEYWORDS.contains(&raw.to_uppercase().as_str()) {
79                tokens.push(Token {
80                    token_type: "keyword".into(),
81                    value: raw.to_uppercase(),
82                    raw: raw.into(),
83                });
84            } else {
85                tokens.push(Token {
86                    token_type: "ident".into(),
87                    value: raw.into(),
88                    raw: raw.into(),
89                });
90            }
91        }
92    }
93    tokens
94}
95
96// ── Parser ─────────────────────────────────────────────────────────────────
97
98struct Parser {
99    tokens: Vec<Token>,
100    pos: usize,
101}
102
103impl Parser {
104    fn new(tokens: Vec<Token>) -> Self {
105        Parser { tokens, pos: 0 }
106    }
107
108    fn peek(&self) -> Option<&Token> {
109        self.tokens.get(self.pos)
110    }
111
112    fn advance(&mut self) -> Token {
113        let t = self.tokens[self.pos].clone();
114        self.pos += 1;
115        t
116    }
117
118    fn expect(&mut self, type_: &str, value: Option<&str>) -> Result<Token, MdqlError> {
119        let t = self.peek().ok_or_else(|| {
120            MdqlError::QueryParse(format!(
121                "Unexpected end of query, expected {}",
122                value.unwrap_or(type_)
123            ))
124        })?;
125        let matches_type = t.token_type == type_;
126        let matches_value = value.map_or(true, |v| t.value == v);
127        if !matches_type || !matches_value {
128            return Err(MdqlError::QueryParse(format!(
129                "Expected {}, got '{}' at position {}",
130                value.unwrap_or(type_),
131                t.raw,
132                self.pos
133            )));
134        }
135        Ok(self.advance())
136    }
137
138    fn match_keyword(&mut self, kw: &str) -> bool {
139        if let Some(t) = self.peek() {
140            if t.token_type == "keyword" && t.value == kw {
141                self.advance();
142                return true;
143            }
144        }
145        false
146    }
147
148    fn parse_statement(&mut self) -> Result<Statement, MdqlError> {
149        let t = self.peek().ok_or_else(|| MdqlError::QueryParse("Empty query".into()))?;
150        match (t.token_type.as_str(), t.value.as_str()) {
151            ("keyword", "SELECT") => {
152                let q = self.parse_select()?;
153                self.expect_end()?;
154                Ok(Statement::Select(q))
155            }
156            ("keyword", "INSERT") => Ok(Statement::Insert(self.parse_insert()?)),
157            ("keyword", "UPDATE") => Ok(Statement::Update(self.parse_update()?)),
158            ("keyword", "DELETE") => Ok(Statement::Delete(self.parse_delete()?)),
159            ("keyword", "ALTER") => self.parse_alter(),
160            ("keyword", "CREATE") => self.parse_create_view(),
161            ("keyword", "DROP") => self.parse_drop_view(),
162            _ => Err(MdqlError::QueryParse(format!(
163                "Expected SELECT, INSERT, UPDATE, DELETE, ALTER, CREATE, or DROP, got '{}'",
164                t.raw
165            ))),
166        }
167    }
168
169    fn parse_select(&mut self) -> Result<SelectQuery, MdqlError> {
170        self.expect("keyword", Some("SELECT"))?;
171        let columns = self.parse_columns()?;
172        self.expect("keyword", Some("FROM"))?;
173
174        // Subquery: FROM (SELECT ...)
175        let mut subquery = None;
176        let (table, mut table_alias) = if self.peek().map_or(false, |t| t.token_type == "op" && t.value == "(") {
177            self.advance();
178            let inner = self.parse_select()?;
179            self.expect("op", Some(")"))?;
180            subquery = Some(Box::new(inner));
181            let alias = if let Some(t) = self.peek() {
182                if t.token_type == "ident" && !self.is_clause_keyword(t) {
183                    Some(self.advance().value)
184                } else {
185                    None
186                }
187            } else {
188                None
189            };
190            ("_subquery".to_string(), alias)
191        } else {
192            let t = self.parse_ident()?;
193            (t, None)
194        };
195
196        // Optional table alias (for non-subquery)
197        if subquery.is_none() {
198            if let Some(t) = self.peek() {
199                if t.token_type == "ident" && !self.is_clause_keyword(t) {
200                    table_alias = Some(self.advance().value);
201                }
202            }
203        }
204
205        // Optional JOIN(s)
206        let mut joins = Vec::new();
207        loop {
208            let jt = if self.match_keyword("LEFT") {
209                self.expect("keyword", Some("JOIN"))?;
210                JoinType::Left
211            } else if self.match_keyword("JOIN") {
212                JoinType::Inner
213            } else {
214                break;
215            };
216            let join_table = self.parse_ident()?;
217            let mut join_alias = None;
218            if let Some(t) = self.peek() {
219                if t.token_type == "ident" && !self.is_clause_keyword(t) {
220                    join_alias = Some(self.advance().value);
221                }
222            }
223            self.expect("keyword", Some("ON"))?;
224            let condition = self.parse_or_expr()?;
225            joins.push(JoinClause {
226                join_type: jt,
227                table: join_table,
228                alias: join_alias,
229                condition,
230            });
231        }
232
233        let mut where_clause = None;
234        if self.match_keyword("WHERE") {
235            where_clause = Some(self.parse_or_expr()?);
236        }
237
238        let mut group_by = None;
239        if self.match_keyword("GROUP") {
240            self.expect("keyword", Some("BY"))?;
241            let mut cols = vec![self.parse_ident()?];
242            while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
243                self.advance();
244                cols.push(self.parse_ident()?);
245            }
246            group_by = Some(cols);
247        }
248
249        let mut having = None;
250        if self.match_keyword("HAVING") {
251            having = Some(self.parse_or_expr()?);
252        }
253
254        let mut order_by = None;
255        if self.match_keyword("ORDER") {
256            self.expect("keyword", Some("BY"))?;
257            order_by = Some(self.parse_order_by()?);
258        }
259
260        let mut limit = None;
261        if self.match_keyword("LIMIT") {
262            let t = self.expect("number", None)?;
263            limit = Some(t.value.parse::<i64>().map_err(|_| {
264                MdqlError::QueryParse(format!("Invalid LIMIT value: {}", t.value))
265            })?);
266        }
267
268        Ok(SelectQuery {
269            columns,
270            table,
271            table_alias,
272            subquery,
273            joins,
274            where_clause,
275            group_by,
276            having,
277            order_by,
278            limit,
279        })
280    }
281
282    fn parse_insert(&mut self) -> Result<InsertQuery, MdqlError> {
283        self.expect("keyword", Some("INSERT"))?;
284        self.expect("keyword", Some("INTO"))?;
285        let table = self.parse_ident()?;
286
287        self.expect("op", Some("("))?;
288        let mut columns = vec![self.parse_ident()?];
289        while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
290            self.advance();
291            columns.push(self.parse_ident()?);
292        }
293        self.expect("op", Some(")"))?;
294
295        self.expect("keyword", Some("VALUES"))?;
296
297        self.expect("op", Some("("))?;
298        let mut values = vec![self.parse_value()?];
299        while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
300            self.advance();
301            values.push(self.parse_value()?);
302        }
303        self.expect("op", Some(")"))?;
304
305        if columns.len() != values.len() {
306            return Err(MdqlError::QueryParse(format!(
307                "Column count ({}) does not match value count ({})",
308                columns.len(),
309                values.len()
310            )));
311        }
312
313        self.expect_end()?;
314        Ok(InsertQuery {
315            table,
316            columns,
317            values,
318        })
319    }
320
321    fn parse_update(&mut self) -> Result<UpdateQuery, MdqlError> {
322        self.expect("keyword", Some("UPDATE"))?;
323        let table = self.parse_ident()?;
324        self.expect("keyword", Some("SET"))?;
325
326        let mut assignments = Vec::new();
327        let col = self.parse_ident()?;
328        self.expect("op", Some("="))?;
329        let val = self.parse_value()?;
330        assignments.push((col, val));
331
332        while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
333            self.advance();
334            let col = self.parse_ident()?;
335            self.expect("op", Some("="))?;
336            let val = self.parse_value()?;
337            assignments.push((col, val));
338        }
339
340        let mut where_clause = None;
341        if self.match_keyword("WHERE") {
342            where_clause = Some(self.parse_or_expr()?);
343        }
344
345        self.expect_end()?;
346        Ok(UpdateQuery {
347            table,
348            assignments,
349            where_clause,
350        })
351    }
352
353    fn parse_delete(&mut self) -> Result<DeleteQuery, MdqlError> {
354        self.expect("keyword", Some("DELETE"))?;
355        self.expect("keyword", Some("FROM"))?;
356        let table = self.parse_ident()?;
357
358        let mut where_clause = None;
359        if self.match_keyword("WHERE") {
360            where_clause = Some(self.parse_or_expr()?);
361        }
362
363        let mode = if self.match_keyword("CASCADE") {
364            DeleteMode::Cascade
365        } else if self.match_keyword("RESTRICT") {
366            DeleteMode::Restrict
367        } else {
368            DeleteMode::Default
369        };
370
371        self.expect_end()?;
372        Ok(DeleteQuery {
373            table,
374            where_clause,
375            mode,
376        })
377    }
378
379    fn parse_alter(&mut self) -> Result<Statement, MdqlError> {
380        self.expect("keyword", Some("ALTER"))?;
381        self.expect("keyword", Some("TABLE"))?;
382        let table = self.parse_ident()?;
383
384        let t = self.peek().ok_or_else(|| {
385            MdqlError::QueryParse("Expected RENAME, DROP, or MERGE after table name".into())
386        })?;
387
388        match (t.token_type.as_str(), t.value.as_str()) {
389            ("keyword", "RENAME") => {
390                self.advance();
391                self.expect("keyword", Some("FIELD"))?;
392                let old_name = self.parse_string_or_ident()?;
393                self.expect("keyword", Some("TO"))?;
394                let new_name = self.parse_string_or_ident()?;
395                self.expect_end()?;
396                Ok(Statement::AlterRename(AlterRenameFieldQuery {
397                    table,
398                    old_name,
399                    new_name,
400                }))
401            }
402            ("keyword", "DROP") => {
403                self.advance();
404                self.expect("keyword", Some("FIELD"))?;
405                let field_name = self.parse_string_or_ident()?;
406                self.expect_end()?;
407                Ok(Statement::AlterDrop(AlterDropFieldQuery {
408                    table,
409                    field_name,
410                }))
411            }
412            ("keyword", "MERGE") => {
413                self.advance();
414                self.expect("keyword", Some("FIELDS"))?;
415                let mut sources = vec![self.parse_string_or_ident()?];
416                while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
417                    self.advance();
418                    sources.push(self.parse_string_or_ident()?);
419                }
420                self.expect("keyword", Some("INTO"))?;
421                let target = self.parse_string_or_ident()?;
422                self.expect_end()?;
423                Ok(Statement::AlterMerge(AlterMergeFieldsQuery {
424                    table,
425                    sources,
426                    into: target,
427                }))
428            }
429            _ => Err(MdqlError::QueryParse(format!(
430                "Expected RENAME, DROP, or MERGE, got '{}'",
431                t.raw
432            ))),
433        }
434    }
435
436    fn parse_create_view(&mut self) -> Result<Statement, MdqlError> {
437        self.expect("keyword", Some("CREATE"))?;
438        self.expect("keyword", Some("VIEW"))?;
439        let view_name = self.parse_ident()?;
440
441        let columns = if self.peek().map_or(false, |t| t.token_type == "op" && t.value == "(") {
442            self.advance();
443            let mut cols = vec![self.parse_ident()?];
444            while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
445                self.advance();
446                cols.push(self.parse_ident()?);
447            }
448            self.expect("op", Some(")"))?;
449            Some(cols)
450        } else {
451            None
452        };
453
454        self.expect("keyword", Some("AS"))?;
455        let query = Box::new(self.parse_select()?);
456        self.expect_end()?;
457
458        Ok(Statement::CreateView(CreateViewQuery {
459            view_name,
460            columns,
461            query,
462        }))
463    }
464
465    fn parse_drop_view(&mut self) -> Result<Statement, MdqlError> {
466        self.expect("keyword", Some("DROP"))?;
467        self.expect("keyword", Some("VIEW"))?;
468        let view_name = self.parse_ident()?;
469        self.expect_end()?;
470        Ok(Statement::DropView(DropViewQuery { view_name }))
471    }
472
473    fn parse_string_or_ident(&mut self) -> Result<String, MdqlError> {
474        let t = self.peek().ok_or_else(|| {
475            MdqlError::QueryParse("Expected field name, got end of query".into())
476        })?;
477        match t.token_type.as_str() {
478            "string" => {
479                let v = self.advance().value;
480                Ok(v)
481            }
482            "ident" | "keyword" => {
483                let v = self.advance().value;
484                Ok(v)
485            }
486            _ => Err(MdqlError::QueryParse(format!(
487                "Expected field name, got '{}'",
488                t.raw
489            ))),
490        }
491    }
492
493    fn parse_columns(&mut self) -> Result<ColumnList, MdqlError> {
494        if let Some(t) = self.peek() {
495            if t.token_type == "op" && t.value == "*" {
496                self.advance();
497                return Ok(ColumnList::All);
498            }
499        }
500
501        let mut exprs = vec![self.parse_select_expr()?];
502        while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
503            self.advance();
504            exprs.push(self.parse_select_expr()?);
505        }
506        Ok(ColumnList::Named(exprs))
507    }
508
509    fn peek_is_agg_func(&self) -> bool {
510        let t = match self.peek() {
511            Some(t) => t,
512            None => return false,
513        };
514        let name_upper = t.value.to_uppercase();
515        if !AGG_FUNCS.contains(&name_upper.as_str()) {
516            return false;
517        }
518        // Only treat as aggregate if followed by (
519        self.tokens
520            .get(self.pos + 1)
521            .map_or(false, |next| next.token_type == "op" && next.value == "(")
522    }
523
524    fn parse_select_expr(&mut self) -> Result<SelectExpr, MdqlError> {
525        let _t = self.peek().ok_or_else(|| {
526            MdqlError::QueryParse("Expected column or aggregate, got end of query".into())
527        })?;
528
529        let expr = self.parse_additive()?;
530
531        let alias = if self.match_keyword("AS") {
532            Some(self.parse_ident()?)
533        } else if self.peek().map_or(false, |t| {
534            t.token_type == "ident" && !self.is_clause_keyword(t)
535        }) {
536            Some(self.advance().value)
537        } else {
538            None
539        };
540
541        // Bare aggregate → SelectExpr::Aggregate for backward compat
542        if let Expr::Aggregate { func, arg, arg_expr } = expr {
543            return Ok(SelectExpr::Aggregate {
544                func,
545                arg,
546                arg_expr: arg_expr.map(|e| *e),
547                alias,
548            });
549        }
550
551        if alias.is_none() {
552            if let Expr::Column(name) = &expr {
553                return Ok(SelectExpr::Column(name.clone()));
554            }
555        }
556
557        Ok(SelectExpr::Expr { expr, alias })
558    }
559
560    // ── Expression parser (precedence climbing) ───────────────────────
561
562    fn peek_is_additive_op(&self) -> bool {
563        self.peek().map_or(false, |t| {
564            t.token_type == "op" && (t.value == "+" || t.value == "-")
565        })
566    }
567
568    fn peek_is_multiplicative_op(&self) -> bool {
569        self.peek().map_or(false, |t| {
570            t.token_type == "op" && (t.value == "*" || t.value == "/" || t.value == "%")
571        })
572    }
573
574    fn parse_additive(&mut self) -> Result<Expr, MdqlError> {
575        let mut left = self.parse_multiplicative()?;
576        while self.peek_is_additive_op() {
577            let op_tok = self.advance();
578            let is_sub = op_tok.value == "-";
579
580            // Check for INTERVAL keyword: expr +/- INTERVAL n DAY
581            if self.peek().map_or(false, |t| t.token_type == "keyword" && t.value == "INTERVAL") {
582                self.advance(); // consume INTERVAL
583                let days_expr = self.parse_multiplicative()?;
584                // Expect DAY or DAYS
585                if !self.match_keyword("DAY") && !self.match_keyword("DAYS") {
586                    return Err(MdqlError::QueryParse("Expected DAY after INTERVAL value".into()));
587                }
588                let days = if is_sub {
589                    Expr::UnaryMinus(Box::new(days_expr))
590                } else {
591                    days_expr
592                };
593                left = Expr::DateAdd {
594                    date: Box::new(left),
595                    days: Box::new(days),
596                };
597                continue;
598            }
599
600            let op = match op_tok.value.as_str() {
601                "+" => ArithOp::Add,
602                "-" => ArithOp::Sub,
603                _ => unreachable!(),
604            };
605            let right = self.parse_multiplicative()?;
606            left = Expr::BinaryOp {
607                left: Box::new(left),
608                op,
609                right: Box::new(right),
610            };
611        }
612        Ok(left)
613    }
614
615    fn parse_multiplicative(&mut self) -> Result<Expr, MdqlError> {
616        let mut left = self.parse_unary()?;
617        while self.peek_is_multiplicative_op() {
618            let op_tok = self.advance();
619            let op = match op_tok.value.as_str() {
620                "*" => ArithOp::Mul,
621                "/" => ArithOp::Div,
622                "%" => ArithOp::Mod,
623                _ => unreachable!(),
624            };
625            let right = self.parse_unary()?;
626            left = Expr::BinaryOp {
627                left: Box::new(left),
628                op,
629                right: Box::new(right),
630            };
631        }
632        Ok(left)
633    }
634
635    fn parse_unary(&mut self) -> Result<Expr, MdqlError> {
636        if self.peek().map_or(false, |t| t.token_type == "op" && t.value == "-") {
637            self.advance();
638            let inner = self.parse_atom()?;
639            // Fold unary minus on literals
640            match inner {
641                Expr::Literal(SqlValue::Int(n)) => Ok(Expr::Literal(SqlValue::Int(-n))),
642                Expr::Literal(SqlValue::Float(f)) => Ok(Expr::Literal(SqlValue::Float(-f))),
643                _ => Ok(Expr::UnaryMinus(Box::new(inner))),
644            }
645        } else {
646            self.parse_atom()
647        }
648    }
649
650    fn parse_atom(&mut self) -> Result<Expr, MdqlError> {
651        if self.peek_is_agg_func() {
652            return self.parse_agg_expr();
653        }
654
655        let t = self.peek().ok_or_else(|| {
656            MdqlError::QueryParse("Expected expression, got end of query".into())
657        })?;
658
659        match t.token_type.as_str() {
660            "number" => {
661                let v = self.advance().value;
662                if v.contains('.') {
663                    let f: f64 = v.parse().map_err(|_| {
664                        MdqlError::QueryParse(format!("Invalid float: {}", v))
665                    })?;
666                    Ok(Expr::Literal(SqlValue::Float(f)))
667                } else {
668                    let n: i64 = v.parse().map_err(|_| {
669                        MdqlError::QueryParse(format!("Invalid int: {}", v))
670                    })?;
671                    Ok(Expr::Literal(SqlValue::Int(n)))
672                }
673            }
674            "string" => {
675                let v = self.advance().value;
676                Ok(Expr::Literal(SqlValue::String(v)))
677            }
678            "keyword" if t.value == "NULL" => {
679                self.advance();
680                Ok(Expr::Literal(SqlValue::Null))
681            }
682            "keyword" if t.value == "CASE" => {
683                self.parse_case_expr()
684            }
685            "keyword" if t.value == "CURRENT_DATE" => {
686                self.advance();
687                Ok(Expr::CurrentDate)
688            }
689            "keyword" if t.value == "CURRENT_TIMESTAMP" => {
690                self.advance();
691                Ok(Expr::CurrentTimestamp)
692            }
693            "keyword" if t.value == "DATEDIFF" => {
694                self.advance();
695                self.expect("op", Some("("))?;
696                let left = self.parse_additive()?;
697                self.expect("op", Some(","))?;
698                let right = self.parse_additive()?;
699                self.expect("op", Some(")"))?;
700                Ok(Expr::DateDiff { left: Box::new(left), right: Box::new(right) })
701            }
702            "op" if t.value == "(" => {
703                self.advance();
704                let expr = self.parse_additive()?;
705                self.expect("op", Some(")"))?;
706                Ok(expr)
707            }
708            "ident" => {
709                let name = self.advance().value;
710                Ok(Expr::Column(name))
711            }
712            "keyword" if !Self::is_reserved_keyword(&t.value) => {
713                let name = self.advance().value;
714                Ok(Expr::Column(name))
715            }
716            _ => Err(MdqlError::QueryParse(format!(
717                "Expected expression, got '{}'",
718                t.raw
719            ))),
720        }
721    }
722
723    fn parse_case_expr(&mut self) -> Result<Expr, MdqlError> {
724        self.expect("keyword", Some("CASE"))?;
725        let mut whens = Vec::new();
726        while self.match_keyword("WHEN") {
727            let condition = self.parse_or_expr()?;
728            self.expect("keyword", Some("THEN"))?;
729            let result = self.parse_additive()?;
730            whens.push((condition, Box::new(result)));
731        }
732        if whens.is_empty() {
733            return Err(MdqlError::QueryParse("CASE requires at least one WHEN clause".into()));
734        }
735        let else_expr = if self.match_keyword("ELSE") {
736            Some(Box::new(self.parse_additive()?))
737        } else {
738            None
739        };
740        self.expect("keyword", Some("END"))?;
741        Ok(Expr::Case { whens, else_expr })
742    }
743
744    fn parse_agg_expr(&mut self) -> Result<Expr, MdqlError> {
745        let func_name = self.advance().value.to_uppercase();
746        let func = match func_name.as_str() {
747            "COUNT" => AggFunc::Count,
748            "SUM" => AggFunc::Sum,
749            "AVG" => AggFunc::Avg,
750            "MIN" => AggFunc::Min,
751            "MAX" => AggFunc::Max,
752            _ => unreachable!(),
753        };
754        self.expect("op", Some("("))?;
755        let (arg, arg_expr) = if self.peek().map_or(false, |t| t.token_type == "op" && t.value == "*") {
756            self.advance();
757            ("*".to_string(), None)
758        } else {
759            let expr = self.parse_additive()?;
760            if let Expr::Column(name) = &expr {
761                (name.clone(), None)
762            } else {
763                (expr.display_name(), Some(Box::new(expr)))
764            }
765        };
766        self.expect("op", Some(")"))?;
767        Ok(Expr::Aggregate { func, arg, arg_expr })
768    }
769
770    fn parse_ident(&mut self) -> Result<String, MdqlError> {
771        let t = self.peek().ok_or_else(|| {
772            MdqlError::QueryParse("Expected identifier, got end of query".into())
773        })?;
774        match t.token_type.as_str() {
775            "ident" | "keyword" => {
776                let v = self.advance().value;
777                Ok(v)
778            }
779            _ => Err(MdqlError::QueryParse(format!(
780                "Expected identifier, got '{}'",
781                t.raw
782            ))),
783        }
784    }
785
786    fn parse_or_expr(&mut self) -> Result<WhereClause, MdqlError> {
787        let mut left = self.parse_and_expr()?;
788        while self.match_keyword("OR") {
789            let right = self.parse_and_expr()?;
790            left = WhereClause::BoolOp(BoolOp {
791                op: BoolOpKind::Or,
792                left: Box::new(left),
793                right: Box::new(right),
794            });
795        }
796        Ok(left)
797    }
798
799    fn parse_and_expr(&mut self) -> Result<WhereClause, MdqlError> {
800        let mut left = self.parse_comparison()?;
801        while self.match_keyword("AND") {
802            let right = self.parse_comparison()?;
803            left = WhereClause::BoolOp(BoolOp {
804                op: BoolOpKind::And,
805                left: Box::new(left),
806                right: Box::new(right),
807            });
808        }
809        Ok(left)
810    }
811
812    fn parse_comparison(&mut self) -> Result<WhereClause, MdqlError> {
813        // Handle parenthesized boolean expressions
814        if self.peek().map_or(false, |t| t.token_type == "op" && t.value == "(") {
815            // Save position — might be arithmetic parens, not boolean
816            let saved_pos = self.pos;
817            self.advance();
818            // Try parsing as boolean (OR/AND) expression
819            let result = self.parse_or_expr();
820            if result.is_ok() && self.peek().map_or(false, |t| t.token_type == "op" && t.value == ")") {
821                self.advance();
822                return result;
823            }
824            // Not a boolean paren — rewind and parse as arithmetic expression
825            self.pos = saved_pos;
826        }
827
828        // Parse the left side as a full expression (column, literal, or arithmetic)
829        let left_expr = self.parse_additive()?;
830
831        // Extract column name for backward compat (simple column on left side)
832        let col = left_expr.as_column().unwrap_or("").to_string();
833
834        // IS NULL / IS NOT NULL (only valid with simple column)
835        if self.match_keyword("IS") {
836            if self.match_keyword("NOT") {
837                self.expect("keyword", Some("NULL"))?;
838                return Ok(WhereClause::Comparison(Comparison {
839                    column: col,
840                    op: CmpOp::IsNotNull,
841                    value: None,
842                    left_expr: Some(left_expr),
843                    right_expr: None,
844                }));
845            }
846            self.expect("keyword", Some("NULL"))?;
847            return Ok(WhereClause::Comparison(Comparison {
848                column: col,
849                op: CmpOp::IsNull,
850                value: None,
851                left_expr: Some(left_expr),
852                right_expr: None,
853            }));
854        }
855
856        // IN (val, val, ...)
857        if self.match_keyword("IN") {
858            self.expect("op", Some("("))?;
859            let mut values = vec![self.parse_value()?];
860            while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
861                self.advance();
862                values.push(self.parse_value()?);
863            }
864            self.expect("op", Some(")"))?;
865            return Ok(WhereClause::Comparison(Comparison {
866                column: col,
867                op: CmpOp::In,
868                value: Some(SqlValue::List(values)),
869                left_expr: Some(left_expr),
870                right_expr: None,
871            }));
872        }
873
874        // LIKE
875        if self.match_keyword("LIKE") {
876            let val = self.parse_value()?;
877            return Ok(WhereClause::Comparison(Comparison {
878                column: col,
879                op: CmpOp::Like,
880                value: Some(val),
881                left_expr: Some(left_expr),
882                right_expr: None,
883            }));
884        }
885
886        // NOT LIKE
887        if self.match_keyword("NOT") {
888            if self.match_keyword("LIKE") {
889                let val = self.parse_value()?;
890                return Ok(WhereClause::Comparison(Comparison {
891                    column: col,
892                    op: CmpOp::NotLike,
893                    value: Some(val),
894                    left_expr: Some(left_expr),
895                    right_expr: None,
896                }));
897            }
898            return Err(MdqlError::QueryParse("Expected LIKE after NOT".into()));
899        }
900
901        // Standard comparison operators
902        if let Some(t) = self.peek() {
903            if t.token_type == "op" && ["=", "!=", "<", ">", "<=", ">="].contains(&t.value.as_str())
904            {
905                let op_str = self.advance().value;
906                let op = match op_str.as_str() {
907                    "=" => CmpOp::Eq,
908                    "!=" => CmpOp::Ne,
909                    "<" => CmpOp::Lt,
910                    ">" => CmpOp::Gt,
911                    "<=" => CmpOp::Le,
912                    ">=" => CmpOp::Ge,
913                    _ => unreachable!(),
914                };
915                // Parse right side as expression
916                let right_expr = self.parse_additive()?;
917                // Extract SqlValue for backward compat (simple literal on right side)
918                let value = match &right_expr {
919                    Expr::Literal(v) => Some(v.clone()),
920                    _ => None,
921                };
922                return Ok(WhereClause::Comparison(Comparison {
923                    column: col,
924                    op,
925                    value,
926                    left_expr: Some(left_expr),
927                    right_expr: Some(right_expr),
928                }));
929            }
930        }
931
932        let got = self.peek().map_or("end".to_string(), |t| t.raw.clone());
933        Err(MdqlError::QueryParse(format!(
934            "Expected operator after '{}', got '{}'",
935            left_expr.display_name(), got
936        )))
937    }
938
939    fn parse_value(&mut self) -> Result<SqlValue, MdqlError> {
940        let t = self.peek().ok_or_else(|| {
941            MdqlError::QueryParse("Expected value, got end of query".into())
942        })?;
943        match t.token_type.as_str() {
944            "string" => {
945                let v = self.advance().value;
946                Ok(SqlValue::String(v))
947            }
948            "number" => {
949                let v = self.advance().value;
950                if v.contains('.') {
951                    Ok(SqlValue::Float(v.parse().map_err(|_| {
952                        MdqlError::QueryParse(format!("Invalid float: {}", v))
953                    })?))
954                } else {
955                    Ok(SqlValue::Int(v.parse().map_err(|_| {
956                        MdqlError::QueryParse(format!("Invalid int: {}", v))
957                    })?))
958                }
959            }
960            "keyword" if t.value == "NULL" => {
961                self.advance();
962                Ok(SqlValue::Null)
963            }
964            _ => Err(MdqlError::QueryParse(format!(
965                "Expected value, got '{}'",
966                t.raw
967            ))),
968        }
969    }
970
971    fn parse_order_by(&mut self) -> Result<Vec<OrderSpec>, MdqlError> {
972        let mut specs = vec![self.parse_order_spec()?];
973        while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
974            self.advance();
975            specs.push(self.parse_order_spec()?);
976        }
977        Ok(specs)
978    }
979
980    fn parse_order_spec(&mut self) -> Result<OrderSpec, MdqlError> {
981        let expr = self.parse_additive()?;
982        let col = expr.as_column().unwrap_or("").to_string();
983        let descending = if self.match_keyword("DESC") {
984            true
985        } else {
986            self.match_keyword("ASC");
987            false
988        };
989        Ok(OrderSpec {
990            column: col,
991            expr: Some(expr),
992            descending,
993        })
994    }
995
996    fn is_clause_keyword(&self, t: &Token) -> bool {
997        t.token_type == "keyword"
998            && ["WHERE", "ORDER", "LIMIT", "JOIN", "LEFT", "ON", "GROUP"].contains(&t.value.as_str())
999    }
1000
1001    /// Keywords that should never be consumed as column names inside expressions.
1002    fn is_reserved_keyword(kw: &str) -> bool {
1003        matches!(kw,
1004            "AS" | "FROM" | "WHERE" | "AND" | "OR" | "ORDER" | "BY"
1005            | "ASC" | "DESC" | "LIMIT" | "JOIN" | "ON" | "GROUP"
1006            | "SELECT" | "INSERT" | "INTO" | "VALUES" | "UPDATE" | "SET"
1007            | "DELETE" | "ALTER" | "TABLE" | "IS" | "NOT" | "IN" | "LIKE"
1008            | "RENAME" | "FIELD" | "TO" | "DROP" | "MERGE" | "FIELDS"
1009            | "CASE" | "WHEN" | "THEN" | "ELSE" | "END"
1010            | "HAVING" | "INTERVAL" | "DAY" | "DAYS"
1011            | "CURRENT_DATE" | "CURRENT_TIMESTAMP" | "DATEDIFF"
1012            | "CREATE" | "VIEW" | "CASCADE" | "RESTRICT"
1013        )
1014    }
1015
1016    fn expect_end(&self) -> Result<(), MdqlError> {
1017        if let Some(t) = self.peek() {
1018            return Err(MdqlError::QueryParse(format!(
1019                "Unexpected token '{}' at position {}",
1020                t.raw, self.pos
1021            )));
1022        }
1023        Ok(())
1024    }
1025}
1026
1027pub fn parse_query(sql: &str) -> crate::errors::Result<Statement> {
1028    let tokens = tokenize(sql);
1029    if tokens.is_empty() {
1030        return Err(MdqlError::QueryParse("Empty query".into()));
1031    }
1032    let mut parser = Parser::new(tokens);
1033    parser.parse_statement()
1034}
1035
1036#[cfg(test)]
1037mod tests {
1038    use super::*;
1039
1040    #[test]
1041    fn test_simple_select() {
1042        let stmt = parse_query("SELECT title, status FROM strategies").unwrap();
1043        if let Statement::Select(q) = stmt {
1044            assert_eq!(q.columns, ColumnList::Named(vec![SelectExpr::Column("title".into()), SelectExpr::Column("status".into())]));
1045            assert_eq!(q.table, "strategies");
1046        } else {
1047            panic!("Expected Select");
1048        }
1049    }
1050
1051    #[test]
1052    fn test_select_star() {
1053        let stmt = parse_query("SELECT * FROM test").unwrap();
1054        if let Statement::Select(q) = stmt {
1055            assert_eq!(q.columns, ColumnList::All);
1056        } else {
1057            panic!("Expected Select");
1058        }
1059    }
1060
1061    #[test]
1062    fn test_where_clause() {
1063        let stmt = parse_query("SELECT title FROM test WHERE count > 5").unwrap();
1064        if let Statement::Select(q) = stmt {
1065            assert!(q.where_clause.is_some());
1066        } else {
1067            panic!("Expected Select");
1068        }
1069    }
1070
1071    #[test]
1072    fn test_order_by() {
1073        let stmt =
1074            parse_query("SELECT title FROM test ORDER BY composite DESC, title ASC").unwrap();
1075        if let Statement::Select(q) = stmt {
1076            let ob = q.order_by.unwrap();
1077            assert_eq!(ob.len(), 2);
1078            assert!(ob[0].descending);
1079            assert!(!ob[1].descending);
1080        } else {
1081            panic!("Expected Select");
1082        }
1083    }
1084
1085    #[test]
1086    fn test_limit() {
1087        let stmt = parse_query("SELECT * FROM test LIMIT 10").unwrap();
1088        if let Statement::Select(q) = stmt {
1089            assert_eq!(q.limit, Some(10));
1090        } else {
1091            panic!("Expected Select");
1092        }
1093    }
1094
1095    #[test]
1096    fn test_insert() {
1097        let stmt = parse_query(
1098            "INSERT INTO test (title, count) VALUES ('Hello', 42)",
1099        )
1100        .unwrap();
1101        if let Statement::Insert(q) = stmt {
1102            assert_eq!(q.table, "test");
1103            assert_eq!(q.columns, vec!["title", "count"]);
1104            assert_eq!(q.values[0], SqlValue::String("Hello".into()));
1105            assert_eq!(q.values[1], SqlValue::Int(42));
1106        } else {
1107            panic!("Expected Insert");
1108        }
1109    }
1110
1111    #[test]
1112    fn test_update() {
1113        let stmt = parse_query("UPDATE test SET status = 'KILLED' WHERE path = 'a.md'").unwrap();
1114        if let Statement::Update(q) = stmt {
1115            assert_eq!(q.table, "test");
1116            assert_eq!(q.assignments.len(), 1);
1117            assert!(q.where_clause.is_some());
1118        } else {
1119            panic!("Expected Update");
1120        }
1121    }
1122
1123    #[test]
1124    fn test_delete() {
1125        let stmt = parse_query("DELETE FROM test WHERE status = 'draft'").unwrap();
1126        if let Statement::Delete(q) = stmt {
1127            assert_eq!(q.table, "test");
1128            assert!(q.where_clause.is_some());
1129        } else {
1130            panic!("Expected Delete");
1131        }
1132    }
1133
1134    #[test]
1135    fn test_alter_rename() {
1136        let stmt =
1137            parse_query("ALTER TABLE test RENAME FIELD 'Summary' TO 'Overview'").unwrap();
1138        if let Statement::AlterRename(q) = stmt {
1139            assert_eq!(q.old_name, "Summary");
1140            assert_eq!(q.new_name, "Overview");
1141        } else {
1142            panic!("Expected AlterRename");
1143        }
1144    }
1145
1146    #[test]
1147    fn test_alter_drop() {
1148        let stmt = parse_query("ALTER TABLE test DROP FIELD 'Details'").unwrap();
1149        if let Statement::AlterDrop(q) = stmt {
1150            assert_eq!(q.field_name, "Details");
1151        } else {
1152            panic!("Expected AlterDrop");
1153        }
1154    }
1155
1156    #[test]
1157    fn test_alter_merge() {
1158        let stmt = parse_query(
1159            "ALTER TABLE test MERGE FIELDS 'Entry Rules', 'Exit Rules' INTO 'Trading Rules'",
1160        )
1161        .unwrap();
1162        if let Statement::AlterMerge(q) = stmt {
1163            assert_eq!(q.sources, vec!["Entry Rules", "Exit Rules"]);
1164            assert_eq!(q.into, "Trading Rules");
1165        } else {
1166            panic!("Expected AlterMerge");
1167        }
1168    }
1169
1170    #[test]
1171    fn test_backtick_ident() {
1172        let stmt = parse_query("SELECT `Structural Mechanism` FROM test").unwrap();
1173        if let Statement::Select(q) = stmt {
1174            assert_eq!(
1175                q.columns,
1176                ColumnList::Named(vec![SelectExpr::Column("Structural Mechanism".into())])
1177            );
1178        } else {
1179            panic!("Expected Select");
1180        }
1181    }
1182
1183    #[test]
1184    fn test_like_operator() {
1185        let stmt = parse_query("SELECT title FROM test WHERE categories LIKE '%defi%'").unwrap();
1186        if let Statement::Select(q) = stmt {
1187            if let Some(WhereClause::Comparison(c)) = q.where_clause {
1188                assert_eq!(c.op, CmpOp::Like);
1189                assert_eq!(c.value, Some(SqlValue::String("%defi%".into())));
1190            } else {
1191                panic!("Expected LIKE comparison");
1192            }
1193        } else {
1194            panic!("Expected Select");
1195        }
1196    }
1197
1198    #[test]
1199    fn test_in_operator() {
1200        let stmt =
1201            parse_query("SELECT * FROM test WHERE status IN ('ACTIVE', 'LIVE')").unwrap();
1202        if let Statement::Select(q) = stmt {
1203            if let Some(WhereClause::Comparison(c)) = q.where_clause {
1204                assert_eq!(c.op, CmpOp::In);
1205            } else {
1206                panic!("Expected IN comparison");
1207            }
1208        } else {
1209            panic!("Expected Select");
1210        }
1211    }
1212
1213    #[test]
1214    fn test_is_null() {
1215        let stmt = parse_query("SELECT * FROM test WHERE title IS NULL").unwrap();
1216        if let Statement::Select(q) = stmt {
1217            if let Some(WhereClause::Comparison(c)) = q.where_clause {
1218                assert_eq!(c.op, CmpOp::IsNull);
1219            } else {
1220                panic!("Expected IS NULL comparison");
1221            }
1222        } else {
1223            panic!("Expected Select");
1224        }
1225    }
1226
1227    #[test]
1228    fn test_and_or() {
1229        let stmt = parse_query(
1230            "SELECT * FROM test WHERE status = 'ACTIVE' AND count > 5 OR title LIKE '%test%'",
1231        )
1232        .unwrap();
1233        if let Statement::Select(q) = stmt {
1234            assert!(q.where_clause.is_some());
1235        } else {
1236            panic!("Expected Select");
1237        }
1238    }
1239
1240    #[test]
1241    fn test_join() {
1242        let stmt = parse_query(
1243            "SELECT s.title, b.sharpe FROM strategies s JOIN backtests b ON b.strategy = s.path",
1244        )
1245        .unwrap();
1246        if let Statement::Select(q) = stmt {
1247            assert_eq!(q.table, "strategies");
1248            assert_eq!(q.table_alias, Some("s".into()));
1249            assert_eq!(q.joins.len(), 1);
1250            let join = &q.joins[0];
1251            assert_eq!(join.table, "backtests");
1252            assert_eq!(join.alias, Some("b".into()));
1253        } else {
1254            panic!("Expected Select");
1255        }
1256    }
1257
1258    #[test]
1259    fn test_multi_join() {
1260        let stmt = parse_query(
1261            "SELECT s.title, b.sharpe, c.verdict FROM strategies s JOIN backtests b ON b.strategy = s.path JOIN critiques c ON c.strategy = s.path",
1262        )
1263        .unwrap();
1264        if let Statement::Select(q) = stmt {
1265            assert_eq!(q.table, "strategies");
1266            assert_eq!(q.table_alias, Some("s".into()));
1267            assert_eq!(q.joins.len(), 2);
1268            assert_eq!(q.joins[0].table, "backtests");
1269            assert_eq!(q.joins[0].alias, Some("b".into()));
1270            assert_eq!(where_clause_to_sql(&q.joins[0].condition), "b.strategy = s.path");
1271            assert_eq!(q.joins[1].table, "critiques");
1272            assert_eq!(q.joins[1].alias, Some("c".into()));
1273            assert_eq!(where_clause_to_sql(&q.joins[1].condition), "c.strategy = s.path");
1274        } else {
1275            panic!("Expected Select");
1276        }
1277    }
1278
1279    #[test]
1280    fn test_left_join() {
1281        let stmt = parse_query(
1282            "SELECT s.title, b.sharpe FROM strategies s LEFT JOIN backtests b ON b.strategy = s.path",
1283        )
1284        .unwrap();
1285        if let Statement::Select(q) = stmt {
1286            assert_eq!(q.joins.len(), 1);
1287            assert_eq!(q.joins[0].join_type, JoinType::Left);
1288            assert_eq!(q.joins[0].table, "backtests");
1289        } else {
1290            panic!("Expected Select");
1291        }
1292    }
1293
1294    #[test]
1295    fn test_mixed_join_types() {
1296        let stmt = parse_query(
1297            "SELECT s.title FROM strategies s JOIN backtests b ON b.strategy = s.path LEFT JOIN allocations a ON a.strategy = s.path",
1298        )
1299        .unwrap();
1300        if let Statement::Select(q) = stmt {
1301            assert_eq!(q.joins.len(), 2);
1302            assert_eq!(q.joins[0].join_type, JoinType::Inner);
1303            assert_eq!(q.joins[1].join_type, JoinType::Left);
1304        } else {
1305            panic!("Expected Select");
1306        }
1307    }
1308
1309    #[test]
1310    fn test_join_compound_and() {
1311        let stmt = parse_query(
1312            "SELECT s.title FROM strategies s LEFT JOIN backtests b ON b.strategy = s.path AND b.mode = 'PAPER'",
1313        )
1314        .unwrap();
1315        if let Statement::Select(q) = stmt {
1316            assert_eq!(q.joins.len(), 1);
1317            assert_eq!(q.joins[0].join_type, JoinType::Left);
1318            let sql = where_clause_to_sql(&q.joins[0].condition);
1319            assert!(sql.contains("b.strategy = s.path"));
1320            assert!(sql.contains("AND"));
1321            assert!(sql.contains("b.mode = 'PAPER'"));
1322        } else {
1323            panic!("Expected Select");
1324        }
1325    }
1326
1327    #[test]
1328    fn test_join_compound_or() {
1329        let stmt = parse_query(
1330            "SELECT * FROM a JOIN b ON a.id = b.id OR a.alt = b.id",
1331        )
1332        .unwrap();
1333        if let Statement::Select(q) = stmt {
1334            let sql = where_clause_to_sql(&q.joins[0].condition);
1335            assert!(sql.contains("OR"));
1336        } else {
1337            panic!("Expected Select");
1338        }
1339    }
1340
1341    #[test]
1342    fn test_join_compound_with_where() {
1343        let stmt = parse_query(
1344            "SELECT s.title FROM strategies s JOIN backtests b ON b.strategy = s.path AND b.mode = 'PAPER' WHERE s.title = 'Alpha'",
1345        )
1346        .unwrap();
1347        if let Statement::Select(q) = stmt {
1348            assert_eq!(q.joins.len(), 1);
1349            assert!(q.where_clause.is_some());
1350            let join_sql = where_clause_to_sql(&q.joins[0].condition);
1351            assert!(join_sql.contains("AND"));
1352        } else {
1353            panic!("Expected Select");
1354        }
1355    }
1356
1357    #[test]
1358    fn test_empty_query() {
1359        assert!(parse_query("").is_err());
1360    }
1361
1362    #[test]
1363    fn test_count_star() {
1364        let stmt = parse_query("SELECT status, COUNT(*) AS cnt FROM strategies GROUP BY status").unwrap();
1365        if let Statement::Select(q) = stmt {
1366            if let ColumnList::Named(exprs) = &q.columns {
1367                assert_eq!(exprs.len(), 2);
1368                assert_eq!(exprs[0], SelectExpr::Column("status".into()));
1369                assert!(matches!(&exprs[1], SelectExpr::Aggregate {
1370                    func: AggFunc::Count,
1371                    arg,
1372                    alias: Some(a),
1373                    ..
1374                } if arg == "*" && a == "cnt"));
1375            } else {
1376                panic!("Expected Named columns");
1377            }
1378            assert_eq!(q.group_by, Some(vec!["status".into()]));
1379        } else {
1380            panic!("Expected Select");
1381        }
1382    }
1383
1384    #[test]
1385    fn test_count_column_as_ident() {
1386        // "count" as a column name should NOT be parsed as the COUNT aggregate
1387        let stmt = parse_query("INSERT INTO test (title, count) VALUES ('Hello', 42)").unwrap();
1388        if let Statement::Insert(q) = stmt {
1389            assert_eq!(q.columns, vec!["title", "count"]);
1390        } else {
1391            panic!("Expected Insert");
1392        }
1393    }
1394
1395    #[test]
1396    fn test_multiple_aggregates() {
1397        let stmt = parse_query("SELECT MIN(composite), MAX(composite), AVG(composite) FROM strategies").unwrap();
1398        if let Statement::Select(q) = stmt {
1399            if let ColumnList::Named(exprs) = &q.columns {
1400                assert_eq!(exprs.len(), 3);
1401                assert!(matches!(&exprs[0], SelectExpr::Aggregate { func: AggFunc::Min, .. }));
1402                assert!(matches!(&exprs[1], SelectExpr::Aggregate { func: AggFunc::Max, .. }));
1403                assert!(matches!(&exprs[2], SelectExpr::Aggregate { func: AggFunc::Avg, .. }));
1404            } else {
1405                panic!("Expected Named columns");
1406            }
1407            assert_eq!(q.group_by, None);
1408        } else {
1409            panic!("Expected Select");
1410        }
1411    }
1412
1413    // ── Expression tests ──────────────────────────────────────────
1414
1415    #[test]
1416    fn test_select_arithmetic_expr() {
1417        let stmt = parse_query("SELECT a + b FROM test").unwrap();
1418        if let Statement::Select(q) = stmt {
1419            if let ColumnList::Named(exprs) = &q.columns {
1420                assert_eq!(exprs.len(), 1);
1421                assert!(matches!(&exprs[0], SelectExpr::Expr {
1422                    expr: Expr::BinaryOp { op: ArithOp::Add, .. },
1423                    alias: None,
1424                }));
1425            } else {
1426                panic!("Expected Named columns");
1427            }
1428        } else {
1429            panic!("Expected Select");
1430        }
1431    }
1432
1433    #[test]
1434    fn test_select_arithmetic_with_alias() {
1435        let stmt = parse_query("SELECT a + b AS total FROM test").unwrap();
1436        if let Statement::Select(q) = stmt {
1437            if let ColumnList::Named(exprs) = &q.columns {
1438                assert_eq!(exprs.len(), 1);
1439                assert!(matches!(&exprs[0], SelectExpr::Expr {
1440                    alias: Some(a),
1441                    ..
1442                } if a == "total"));
1443                assert_eq!(exprs[0].output_name(), "total");
1444            } else {
1445                panic!("Expected Named columns");
1446            }
1447        } else {
1448            panic!("Expected Select");
1449        }
1450    }
1451
1452    #[test]
1453    fn test_select_precedence() {
1454        // a + b * c should parse as a + (b * c)
1455        let stmt = parse_query("SELECT a + b * c FROM test").unwrap();
1456        if let Statement::Select(q) = stmt {
1457            if let ColumnList::Named(exprs) = &q.columns {
1458                if let SelectExpr::Expr { expr, .. } = &exprs[0] {
1459                    if let Expr::BinaryOp { left, op, right } = expr {
1460                        assert_eq!(*op, ArithOp::Add);
1461                        assert!(matches!(left.as_ref(), Expr::Column(n) if n == "a"));
1462                        assert!(matches!(right.as_ref(), Expr::BinaryOp { op: ArithOp::Mul, .. }));
1463                    } else {
1464                        panic!("Expected BinaryOp");
1465                    }
1466                } else {
1467                    panic!("Expected Expr variant");
1468                }
1469            } else {
1470                panic!("Expected Named columns");
1471            }
1472        } else {
1473            panic!("Expected Select");
1474        }
1475    }
1476
1477    #[test]
1478    fn test_select_parenthesized_expr() {
1479        // (a + b) * c should override default precedence
1480        let stmt = parse_query("SELECT (a + b) * c FROM test").unwrap();
1481        if let Statement::Select(q) = stmt {
1482            if let ColumnList::Named(exprs) = &q.columns {
1483                if let SelectExpr::Expr { expr, .. } = &exprs[0] {
1484                    if let Expr::BinaryOp { left, op, .. } = expr {
1485                        assert_eq!(*op, ArithOp::Mul);
1486                        assert!(matches!(left.as_ref(), Expr::BinaryOp { op: ArithOp::Add, .. }));
1487                    } else {
1488                        panic!("Expected BinaryOp");
1489                    }
1490                } else {
1491                    panic!("Expected Expr variant");
1492                }
1493            } else {
1494                panic!("Expected Named columns");
1495            }
1496        } else {
1497            panic!("Expected Select");
1498        }
1499    }
1500
1501    #[test]
1502    fn test_select_unary_minus() {
1503        let stmt = parse_query("SELECT -count FROM test").unwrap();
1504        if let Statement::Select(q) = stmt {
1505            if let ColumnList::Named(exprs) = &q.columns {
1506                assert!(matches!(&exprs[0], SelectExpr::Expr {
1507                    expr: Expr::UnaryMinus(_),
1508                    ..
1509                }));
1510            } else {
1511                panic!("Expected Named columns");
1512            }
1513        } else {
1514            panic!("Expected Select");
1515        }
1516    }
1517
1518    #[test]
1519    fn test_select_negative_literal() {
1520        let stmt = parse_query("SELECT -42 FROM test").unwrap();
1521        if let Statement::Select(q) = stmt {
1522            if let ColumnList::Named(exprs) = &q.columns {
1523                // Unary minus folds into the literal
1524                assert!(matches!(&exprs[0], SelectExpr::Expr {
1525                    expr: Expr::Literal(SqlValue::Int(-42)),
1526                    ..
1527                }));
1528            } else {
1529                panic!("Expected Named columns");
1530            }
1531        } else {
1532            panic!("Expected Select");
1533        }
1534    }
1535
1536    #[test]
1537    fn test_where_arithmetic_expr() {
1538        let stmt = parse_query("SELECT * FROM test WHERE a + b > 10").unwrap();
1539        if let Statement::Select(q) = stmt {
1540            if let Some(WhereClause::Comparison(c)) = q.where_clause {
1541                assert_eq!(c.op, CmpOp::Gt);
1542                assert!(matches!(&c.left_expr, Some(Expr::BinaryOp { op: ArithOp::Add, .. })));
1543                assert!(matches!(&c.right_expr, Some(Expr::Literal(SqlValue::Int(10)))));
1544            } else {
1545                panic!("Expected comparison");
1546            }
1547        } else {
1548            panic!("Expected Select");
1549        }
1550    }
1551
1552    #[test]
1553    fn test_where_both_sides_expr() {
1554        let stmt = parse_query("SELECT * FROM test WHERE a * 2 > b + 1").unwrap();
1555        if let Statement::Select(q) = stmt {
1556            if let Some(WhereClause::Comparison(c)) = q.where_clause {
1557                assert_eq!(c.op, CmpOp::Gt);
1558                assert!(matches!(&c.left_expr, Some(Expr::BinaryOp { op: ArithOp::Mul, .. })));
1559                assert!(matches!(&c.right_expr, Some(Expr::BinaryOp { op: ArithOp::Add, .. })));
1560            } else {
1561                panic!("Expected comparison");
1562            }
1563        } else {
1564            panic!("Expected Select");
1565        }
1566    }
1567
1568    #[test]
1569    fn test_order_by_expr() {
1570        let stmt = parse_query("SELECT * FROM test ORDER BY a + b DESC").unwrap();
1571        if let Statement::Select(q) = stmt {
1572            let ob = q.order_by.unwrap();
1573            assert_eq!(ob.len(), 1);
1574            assert!(ob[0].descending);
1575            assert!(matches!(&ob[0].expr, Some(Expr::BinaryOp { op: ArithOp::Add, .. })));
1576        } else {
1577            panic!("Expected Select");
1578        }
1579    }
1580
1581    #[test]
1582    fn test_all_arithmetic_ops() {
1583        let stmt = parse_query("SELECT a + b, a - b, a * b, a / b, a % b FROM test").unwrap();
1584        if let Statement::Select(q) = stmt {
1585            if let ColumnList::Named(exprs) = &q.columns {
1586                assert_eq!(exprs.len(), 5);
1587                assert!(matches!(&exprs[0], SelectExpr::Expr { expr: Expr::BinaryOp { op: ArithOp::Add, .. }, .. }));
1588                assert!(matches!(&exprs[1], SelectExpr::Expr { expr: Expr::BinaryOp { op: ArithOp::Sub, .. }, .. }));
1589                assert!(matches!(&exprs[2], SelectExpr::Expr { expr: Expr::BinaryOp { op: ArithOp::Mul, .. }, .. }));
1590                assert!(matches!(&exprs[3], SelectExpr::Expr { expr: Expr::BinaryOp { op: ArithOp::Div, .. }, .. }));
1591                assert!(matches!(&exprs[4], SelectExpr::Expr { expr: Expr::BinaryOp { op: ArithOp::Mod, .. }, .. }));
1592            } else {
1593                panic!("Expected Named columns");
1594            }
1595        } else {
1596            panic!("Expected Select");
1597        }
1598    }
1599
1600    #[test]
1601    fn test_column_with_literal_arithmetic() {
1602        let stmt = parse_query("SELECT count * 2 + 1 FROM test").unwrap();
1603        if let Statement::Select(q) = stmt {
1604            if let ColumnList::Named(exprs) = &q.columns {
1605                // Should be (count * 2) + 1
1606                if let SelectExpr::Expr { expr, .. } = &exprs[0] {
1607                    if let Expr::BinaryOp { left, op, right } = expr {
1608                        assert_eq!(*op, ArithOp::Add);
1609                        assert!(matches!(right.as_ref(), Expr::Literal(SqlValue::Int(1))));
1610                        assert!(matches!(left.as_ref(), Expr::BinaryOp { op: ArithOp::Mul, .. }));
1611                    } else {
1612                        panic!("Expected BinaryOp");
1613                    }
1614                } else {
1615                    panic!("Expected Expr");
1616                }
1617            } else {
1618                panic!("Expected Named columns");
1619            }
1620        } else {
1621            panic!("Expected Select");
1622        }
1623    }
1624
1625    #[test]
1626    fn test_mixed_columns_and_exprs() {
1627        let stmt = parse_query("SELECT title, a + b AS sum, count FROM test").unwrap();
1628        if let Statement::Select(q) = stmt {
1629            if let ColumnList::Named(exprs) = &q.columns {
1630                assert_eq!(exprs.len(), 3);
1631                assert_eq!(exprs[0], SelectExpr::Column("title".into()));
1632                assert!(matches!(&exprs[1], SelectExpr::Expr { alias: Some(a), .. } if a == "sum"));
1633                assert_eq!(exprs[2], SelectExpr::Column("count".into()));
1634            } else {
1635                panic!("Expected Named columns");
1636            }
1637        } else {
1638            panic!("Expected Select");
1639        }
1640    }
1641
1642    // ── CASE WHEN tests ──────────────────────────────────────────
1643
1644    #[test]
1645    fn test_case_when_basic() {
1646        let stmt = parse_query(
1647            "SELECT CASE WHEN status = 'ACTIVE' THEN 1 ELSE 0 END FROM test"
1648        ).unwrap();
1649        if let Statement::Select(q) = stmt {
1650            if let ColumnList::Named(exprs) = &q.columns {
1651                assert_eq!(exprs.len(), 1);
1652                assert!(matches!(&exprs[0], SelectExpr::Expr {
1653                    expr: Expr::Case { .. },
1654                    ..
1655                }));
1656            } else {
1657                panic!("Expected Named columns");
1658            }
1659        } else {
1660            panic!("Expected Select");
1661        }
1662    }
1663
1664    #[test]
1665    fn test_case_when_multiple_branches() {
1666        let stmt = parse_query(
1667            "SELECT CASE WHEN x > 10 THEN 'high' WHEN x > 5 THEN 'mid' ELSE 'low' END FROM test"
1668        ).unwrap();
1669        if let Statement::Select(q) = stmt {
1670            if let ColumnList::Named(exprs) = &q.columns {
1671                if let SelectExpr::Expr { expr: Expr::Case { whens, else_expr }, .. } = &exprs[0] {
1672                    assert_eq!(whens.len(), 2);
1673                    assert!(else_expr.is_some());
1674                } else {
1675                    panic!("Expected Case expression");
1676                }
1677            } else {
1678                panic!("Expected Named columns");
1679            }
1680        } else {
1681            panic!("Expected Select");
1682        }
1683    }
1684
1685    #[test]
1686    fn test_case_when_no_else() {
1687        let stmt = parse_query(
1688            "SELECT CASE WHEN x = 1 THEN 'one' END FROM test"
1689        ).unwrap();
1690        if let Statement::Select(q) = stmt {
1691            if let ColumnList::Named(exprs) = &q.columns {
1692                if let SelectExpr::Expr { expr: Expr::Case { whens, else_expr }, .. } = &exprs[0] {
1693                    assert_eq!(whens.len(), 1);
1694                    assert!(else_expr.is_none());
1695                } else {
1696                    panic!("Expected Case expression");
1697                }
1698            } else {
1699                panic!("Expected Named columns");
1700            }
1701        } else {
1702            panic!("Expected Select");
1703        }
1704    }
1705
1706    #[test]
1707    fn test_case_when_in_aggregate() {
1708        let stmt = parse_query(
1709            "SELECT SUM(CASE WHEN side = 'BUY' THEN size ELSE -size END) AS net FROM orders GROUP BY token"
1710        ).unwrap();
1711        if let Statement::Select(q) = stmt {
1712            if let ColumnList::Named(exprs) = &q.columns {
1713                assert_eq!(exprs.len(), 1);
1714                assert!(matches!(&exprs[0], SelectExpr::Aggregate {
1715                    func: AggFunc::Sum,
1716                    arg_expr: Some(Expr::Case { .. }),
1717                    alias: Some(a),
1718                    ..
1719                } if a == "net"));
1720            } else {
1721                panic!("Expected Named columns");
1722            }
1723        } else {
1724            panic!("Expected Select");
1725        }
1726    }
1727
1728    #[test]
1729    fn test_case_when_with_alias() {
1730        let stmt = parse_query(
1731            "SELECT CASE WHEN x > 0 THEN 'pos' ELSE 'neg' END AS sign FROM test"
1732        ).unwrap();
1733        if let Statement::Select(q) = stmt {
1734            if let ColumnList::Named(exprs) = &q.columns {
1735                assert!(matches!(&exprs[0], SelectExpr::Expr {
1736                    expr: Expr::Case { .. },
1737                    alias: Some(a),
1738                } if a == "sign"));
1739            } else {
1740                panic!("Expected Named columns");
1741            }
1742        } else {
1743            panic!("Expected Select");
1744        }
1745    }
1746
1747    #[test]
1748    fn test_create_view() {
1749        let stmt = parse_query("CREATE VIEW live AS SELECT * FROM strategies WHERE status = 'LIVE'").unwrap();
1750        if let Statement::CreateView(cv) = stmt {
1751            assert_eq!(cv.view_name, "live");
1752            assert!(cv.columns.is_none());
1753            assert_eq!(cv.query.table, "strategies");
1754            assert!(cv.query.where_clause.is_some());
1755        } else {
1756            panic!("Expected CreateView, got {:?}", stmt);
1757        }
1758    }
1759
1760    #[test]
1761    fn test_create_view_with_columns() {
1762        let stmt = parse_query("CREATE VIEW v1 (a, b) AS SELECT title, status FROM t").unwrap();
1763        if let Statement::CreateView(cv) = stmt {
1764            assert_eq!(cv.view_name, "v1");
1765            assert_eq!(cv.columns, Some(vec!["a".into(), "b".into()]));
1766        } else {
1767            panic!("Expected CreateView");
1768        }
1769    }
1770
1771    #[test]
1772    fn test_drop_view() {
1773        let stmt = parse_query("DROP VIEW live").unwrap();
1774        if let Statement::DropView(dv) = stmt {
1775            assert_eq!(dv.view_name, "live");
1776        } else {
1777            panic!("Expected DropView, got {:?}", stmt);
1778        }
1779    }
1780
1781    #[test]
1782    fn test_create_view_case_insensitive() {
1783        let stmt = parse_query("create view My_View as select * from t").unwrap();
1784        if let Statement::CreateView(cv) = stmt {
1785            assert_eq!(cv.view_name, "My_View");
1786        } else {
1787            panic!("Expected CreateView");
1788        }
1789    }
1790
1791    // ── Issue #42: Arithmetic between aggregates in column expressions ──
1792
1793    #[test]
1794    fn test_aggregate_division() {
1795        let stmt = parse_query(
1796            "SELECT token, SUM(sell) / SUM(buy) as ratio FROM orders GROUP BY token"
1797        ).unwrap();
1798        if let Statement::Select(q) = stmt {
1799            assert_eq!(q.group_by, Some(vec!["token".into()]));
1800            if let ColumnList::Named(exprs) = &q.columns {
1801                assert_eq!(exprs.len(), 2);
1802                assert!(exprs[1].is_aggregate());
1803            } else {
1804                panic!("Expected Named columns");
1805            }
1806        } else {
1807            panic!("Expected Select");
1808        }
1809    }
1810
1811    #[test]
1812    fn test_aggregate_subtraction() {
1813        let stmt = parse_query(
1814            "SELECT token, SUM(sell) - SUM(buy) as net FROM orders GROUP BY token"
1815        ).unwrap();
1816        if let Statement::Select(q) = stmt {
1817            if let ColumnList::Named(exprs) = &q.columns {
1818                assert_eq!(exprs[1].output_name(), "net");
1819            }
1820        } else {
1821            panic!("Expected Select");
1822        }
1823    }
1824
1825    #[test]
1826    fn test_create_view_with_arithmetic() {
1827        let stmt = parse_query(
1828            "CREATE VIEW positions AS SELECT token, SUM(sell) / SUM(buy) as ratio FROM orders GROUP BY token"
1829        ).unwrap();
1830        if let Statement::CreateView(cv) = stmt {
1831            assert_eq!(cv.view_name, "positions");
1832        } else {
1833            panic!("Expected CreateView, got {:?}", stmt);
1834        }
1835    }
1836
1837    // ── Issue #43: Subqueries in FROM ──
1838
1839    #[test]
1840    fn test_subquery_in_from() {
1841        let stmt = parse_query(
1842            "SELECT token, sell_size FROM (SELECT token, SUM(size) as sell_size FROM orders GROUP BY token) LIMIT 5"
1843        ).unwrap();
1844        if let Statement::Select(q) = stmt {
1845            assert!(q.subquery.is_some());
1846            assert_eq!(q.limit, Some(5));
1847            let sub = q.subquery.unwrap();
1848            assert_eq!(sub.table, "orders");
1849            assert!(sub.group_by.is_some());
1850        } else {
1851            panic!("Expected Select");
1852        }
1853    }
1854
1855    // ── Issue #44: HAVING in CREATE VIEW ──
1856
1857    #[test]
1858    fn test_create_view_with_having() {
1859        let stmt = parse_query(
1860            "CREATE VIEW positions AS SELECT token, SUM(sell) as sell_size, SUM(buy) as buy_size FROM orders GROUP BY token HAVING sell_size > buy_size"
1861        ).unwrap();
1862        if let Statement::CreateView(cv) = stmt {
1863            assert_eq!(cv.view_name, "positions");
1864            assert!(cv.query.having.is_some());
1865        } else {
1866            panic!("Expected CreateView, got {:?}", stmt);
1867        }
1868    }
1869
1870    // ── Issue #42: Aggregate multiplication ──
1871
1872    #[test]
1873    fn test_aggregate_multiplication() {
1874        let stmt = parse_query(
1875            "SELECT SUM(a) * 2 as doubled FROM test"
1876        ).unwrap();
1877        if let Statement::Select(q) = stmt {
1878            if let ColumnList::Named(exprs) = &q.columns {
1879                assert_eq!(exprs.len(), 1);
1880                assert!(exprs[0].is_aggregate());
1881                assert_eq!(exprs[0].output_name(), "doubled");
1882            } else {
1883                panic!("Expected Named columns");
1884            }
1885        } else {
1886            panic!("Expected Select");
1887        }
1888    }
1889
1890    #[test]
1891    fn test_complex_aggregate_arithmetic() {
1892        let stmt = parse_query(
1893            "SELECT SUM(CASE WHEN side = 'SELL' THEN size ELSE 0 END) / SUM(CASE WHEN side = 'BUY' THEN size ELSE 0 END) as ratio FROM orders GROUP BY token"
1894        ).unwrap();
1895        if let Statement::Select(q) = stmt {
1896            if let ColumnList::Named(exprs) = &q.columns {
1897                assert_eq!(exprs.len(), 1);
1898                assert!(exprs[0].is_aggregate());
1899                assert_eq!(exprs[0].output_name(), "ratio");
1900            } else {
1901                panic!("Expected Named columns");
1902            }
1903            assert_eq!(q.group_by, Some(vec!["token".into()]));
1904        } else {
1905            panic!("Expected Select");
1906        }
1907    }
1908
1909    // ── Issue #43: Subquery with alias and WHERE ──
1910
1911    #[test]
1912    fn test_subquery_with_alias() {
1913        let stmt = parse_query(
1914            "SELECT x FROM (SELECT x FROM t) sub"
1915        ).unwrap();
1916        if let Statement::Select(q) = stmt {
1917            assert!(q.subquery.is_some());
1918            let sub = q.subquery.unwrap();
1919            assert_eq!(sub.table, "t");
1920            if let ColumnList::Named(exprs) = &q.columns {
1921                assert_eq!(exprs.len(), 1);
1922                assert_eq!(exprs[0].output_name(), "x");
1923            } else {
1924                panic!("Expected Named columns");
1925            }
1926        } else {
1927            panic!("Expected Select");
1928        }
1929    }
1930
1931    #[test]
1932    fn test_subquery_with_where() {
1933        let stmt = parse_query(
1934            "SELECT x FROM (SELECT x FROM t WHERE y > 0) LIMIT 5"
1935        ).unwrap();
1936        if let Statement::Select(q) = stmt {
1937            assert!(q.subquery.is_some());
1938            assert_eq!(q.limit, Some(5));
1939            let sub = q.subquery.unwrap();
1940            assert_eq!(sub.table, "t");
1941            assert!(sub.where_clause.is_some());
1942        } else {
1943            panic!("Expected Select");
1944        }
1945    }
1946
1947    // ── Issue #42 + CREATE VIEW: aggregate subtraction in view ──
1948
1949    #[test]
1950    fn test_create_view_aggregate_subtraction() {
1951        let stmt = parse_query(
1952            "CREATE VIEW v AS SELECT token, SUM(sell) - SUM(buy) as net FROM orders GROUP BY token"
1953        ).unwrap();
1954        if let Statement::CreateView(cv) = stmt {
1955            assert_eq!(cv.view_name, "v");
1956            assert_eq!(cv.query.group_by, Some(vec!["token".into()]));
1957            if let ColumnList::Named(exprs) = &cv.query.columns {
1958                assert_eq!(exprs.len(), 2);
1959                assert_eq!(exprs[1].output_name(), "net");
1960                assert!(exprs[1].is_aggregate());
1961            } else {
1962                panic!("Expected Named columns");
1963            }
1964        } else {
1965            panic!("Expected CreateView, got {:?}", stmt);
1966        }
1967    }
1968
1969    #[test]
1970    fn test_delete_cascade() {
1971        let stmt = parse_query("DELETE FROM strategies WHERE status = 'KILLED' CASCADE").unwrap();
1972        if let Statement::Delete(q) = stmt {
1973            assert_eq!(q.table, "strategies");
1974            assert!(q.where_clause.is_some());
1975            assert_eq!(q.mode, DeleteMode::Cascade);
1976        } else {
1977            panic!("Expected Delete");
1978        }
1979    }
1980
1981    #[test]
1982    fn test_delete_restrict() {
1983        let stmt = parse_query("DELETE FROM strategies WHERE path = 'alpha.md' RESTRICT").unwrap();
1984        if let Statement::Delete(q) = stmt {
1985            assert_eq!(q.table, "strategies");
1986            assert_eq!(q.mode, DeleteMode::Restrict);
1987        } else {
1988            panic!("Expected Delete");
1989        }
1990    }
1991
1992    #[test]
1993    fn test_delete_default_unchanged() {
1994        let stmt = parse_query("DELETE FROM strategies WHERE status = 'KILLED'").unwrap();
1995        if let Statement::Delete(q) = stmt {
1996            assert_eq!(q.mode, DeleteMode::Default);
1997        } else {
1998            panic!("Expected Delete");
1999        }
2000    }
2001
2002    #[test]
2003    fn test_delete_cascade_no_where() {
2004        let stmt = parse_query("DELETE FROM strategies CASCADE").unwrap();
2005        if let Statement::Delete(q) = stmt {
2006            assert_eq!(q.table, "strategies");
2007            assert!(q.where_clause.is_none());
2008            assert_eq!(q.mode, DeleteMode::Cascade);
2009        } else {
2010            panic!("Expected Delete");
2011        }
2012    }
2013}