Skip to main content

mdql_core/
query_parser.rs

1//! Hand-written recursive descent parser for the MDQL SQL subset.
2
3use regex::Regex;
4use std::sync::LazyLock;
5
6use crate::errors::MdqlError;
7pub use crate::query_ast::*;
8
9// ── Tokenizer ──────────────────────────────────────────────────────────────
10
/// Words the tokenizer classifies as `"keyword"` tokens.
///
/// Bare words are compared against this list after upper-casing, so matching
/// is case-insensitive; any word not listed here becomes an `"ident"` token.
static KEYWORDS: &[&str] = &[
    // SELECT skeleton
    "SELECT", "FROM", "WHERE", "AND", "OR", "ORDER", "BY",
    // Ordering, limits and predicates
    "ASC", "DESC", "LIMIT", "LIKE", "IN", "IS", "NOT", "NULL",
    // Joins, aliases and grouping
    "JOIN", "LEFT", "ON", "AS", "GROUP", "HAVING",
    // Data modification
    "INSERT", "INTO", "VALUES", "UPDATE", "SET", "DELETE",
    // Schema changes (ALTER TABLE ...)
    "ALTER", "TABLE", "RENAME", "FIELD", "TO", "DROP", "MERGE", "FIELDS",
    // CASE expressions
    "CASE", "WHEN", "THEN", "ELSE", "END",
    // Date arithmetic
    "INTERVAL", "DAY", "DAYS", "CURRENT_DATE", "CURRENT_TIMESTAMP", "DATEDIFF",
    // Views and DELETE modes
    "CREATE", "VIEW", "CASCADE", "RESTRICT",
    // Common table expressions
    "WITH",
];
22
/// Names of the supported aggregate functions.
///
/// These are not keywords: they tokenize as identifiers and are only treated
/// as aggregates when immediately followed by `(`.
static AGG_FUNCS: &[&str] = &[
    "COUNT",
    "SUM",
    "AVG",
    "MIN",
    "MAX",
];
24
25static TOKEN_RE: LazyLock<Regex> = LazyLock::new(|| {
26    Regex::new(
27        r#"(?x)
28        \s*(?:
29            (?P<backtick>`[^`]+`)
30            | (?P<string>'(?:[^'\\]|\\.)*')
31            | (?P<number>\d+(?:\.\d+)?)
32            | (?P<op><=|>=|!=|[=<>,*()+\-/%])
33            | (?P<word>[A-Za-z_][A-Za-z0-9_./-]*)
34        )"#,
35    )
36    .unwrap()
37});
38
/// One lexical token produced by `tokenize`.
///
/// The kind is carried as a string in `token_type`: one of `"ident"`,
/// `"keyword"`, `"string"`, `"number"` or `"op"`.
#[derive(Clone, Debug)]
struct Token {
    /// Token kind (see the type-level docs).
    token_type: String,
    /// Text the parser compares against: upper-cased for keywords, quotes or
    /// backticks stripped for strings and quoted identifiers.
    value: String,
    /// Exact matched source text, used when reporting parse errors.
    raw: String,
}
45
46fn tokenize(sql: &str) -> Vec<Token> {
47    let mut tokens = Vec::new();
48    for caps in TOKEN_RE.captures_iter(sql) {
49        if let Some(m) = caps.name("backtick") {
50            let raw = m.as_str();
51            tokens.push(Token {
52                token_type: "ident".into(),
53                value: raw[1..raw.len() - 1].into(),
54                raw: raw.into(),
55            });
56        } else if let Some(m) = caps.name("string") {
57            let raw = m.as_str();
58            tokens.push(Token {
59                token_type: "string".into(),
60                value: raw[1..raw.len() - 1].into(),
61                raw: raw.into(),
62            });
63        } else if let Some(m) = caps.name("number") {
64            let raw = m.as_str();
65            tokens.push(Token {
66                token_type: "number".into(),
67                value: raw.into(),
68                raw: raw.into(),
69            });
70        } else if let Some(m) = caps.name("op") {
71            let raw = m.as_str();
72            tokens.push(Token {
73                token_type: "op".into(),
74                value: raw.into(),
75                raw: raw.into(),
76            });
77        } else if let Some(m) = caps.name("word") {
78            let raw = m.as_str();
79            if KEYWORDS.contains(&raw.to_uppercase().as_str()) {
80                tokens.push(Token {
81                    token_type: "keyword".into(),
82                    value: raw.to_uppercase(),
83                    raw: raw.into(),
84                });
85            } else {
86                tokens.push(Token {
87                    token_type: "ident".into(),
88                    value: raw.into(),
89                    raw: raw.into(),
90                });
91            }
92        }
93    }
94    tokens
95}
96
97// ── Parser ─────────────────────────────────────────────────────────────────
98
99struct Parser {
100    tokens: Vec<Token>,
101    pos: usize,
102}
103
104impl Parser {
105    fn new(tokens: Vec<Token>) -> Self {
106        Parser { tokens, pos: 0 }
107    }
108
109    fn peek(&self) -> Option<&Token> {
110        self.tokens.get(self.pos)
111    }
112
113    fn advance(&mut self) -> Token {
114        let t = self.tokens[self.pos].clone();
115        self.pos += 1;
116        t
117    }
118
119    fn expect(&mut self, type_: &str, value: Option<&str>) -> Result<Token, MdqlError> {
120        let t = self.peek().ok_or_else(|| {
121            MdqlError::QueryParse(format!(
122                "Unexpected end of query, expected {}",
123                value.unwrap_or(type_)
124            ))
125        })?;
126        let matches_type = t.token_type == type_;
127        let matches_value = value.map_or(true, |v| t.value == v);
128        if !matches_type || !matches_value {
129            return Err(MdqlError::QueryParse(format!(
130                "Expected {}, got '{}' at position {}",
131                value.unwrap_or(type_),
132                t.raw,
133                self.pos
134            )));
135        }
136        Ok(self.advance())
137    }
138
139    fn match_keyword(&mut self, kw: &str) -> bool {
140        if let Some(t) = self.peek() {
141            if t.token_type == "keyword" && t.value == kw {
142                self.advance();
143                return true;
144            }
145        }
146        false
147    }
148
149    fn parse_statement(&mut self) -> Result<Statement, MdqlError> {
150        let t = self.peek().ok_or_else(|| MdqlError::QueryParse("Empty query".into()))?;
151        match (t.token_type.as_str(), t.value.as_str()) {
152            ("keyword", "WITH") => {
153                let ctes = self.parse_ctes()?;
154                let mut q = self.parse_select()?;
155                q.ctes = ctes;
156                self.expect_end()?;
157                Ok(Statement::Select(q))
158            }
159            ("keyword", "SELECT") => {
160                let q = self.parse_select()?;
161                self.expect_end()?;
162                Ok(Statement::Select(q))
163            }
164            ("keyword", "INSERT") => Ok(Statement::Insert(self.parse_insert()?)),
165            ("keyword", "UPDATE") => Ok(Statement::Update(self.parse_update()?)),
166            ("keyword", "DELETE") => Ok(Statement::Delete(self.parse_delete()?)),
167            ("keyword", "ALTER") => self.parse_alter(),
168            ("keyword", "CREATE") => self.parse_create_view(),
169            ("keyword", "DROP") => self.parse_drop_view(),
170            _ => Err(MdqlError::QueryParse(format!(
171                "Expected SELECT, INSERT, UPDATE, DELETE, ALTER, CREATE, or DROP, got '{}'",
172                t.raw
173            ))),
174        }
175    }
176
177    fn parse_ctes(&mut self) -> Result<Vec<CteClause>, MdqlError> {
178        self.expect("keyword", Some("WITH"))?;
179        let mut ctes = Vec::new();
180        loop {
181            let name = self.parse_ident()?;
182            self.expect("keyword", Some("AS"))?;
183            self.expect("op", Some("("))?;
184            let query = self.parse_select()?;
185            self.expect("op", Some(")"))?;
186            ctes.push(CteClause { name, query: Box::new(query) });
187            if !self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
188                break;
189            }
190            self.advance();
191        }
192        Ok(ctes)
193    }
194
195    fn parse_select(&mut self) -> Result<SelectQuery, MdqlError> {
196        self.expect("keyword", Some("SELECT"))?;
197        let columns = self.parse_columns()?;
198        self.expect("keyword", Some("FROM"))?;
199
200        // Subquery: FROM (SELECT ...)
201        let mut subquery = None;
202        let (table, mut table_alias) = if self.peek().map_or(false, |t| t.token_type == "op" && t.value == "(") {
203            self.advance();
204            let inner = self.parse_select()?;
205            self.expect("op", Some(")"))?;
206            subquery = Some(Box::new(inner));
207            let alias = if let Some(t) = self.peek() {
208                if t.token_type == "ident" && !self.is_clause_keyword(t) {
209                    Some(self.advance().value)
210                } else {
211                    None
212                }
213            } else {
214                None
215            };
216            ("_subquery".to_string(), alias)
217        } else {
218            let t = self.parse_ident()?;
219            (t, None)
220        };
221
222        // Optional table alias (for non-subquery)
223        if subquery.is_none() {
224            if let Some(t) = self.peek() {
225                if t.token_type == "ident" && !self.is_clause_keyword(t) {
226                    table_alias = Some(self.advance().value);
227                }
228            }
229        }
230
231        // Optional JOIN(s)
232        let mut joins = Vec::new();
233        loop {
234            let jt = if self.match_keyword("LEFT") {
235                self.expect("keyword", Some("JOIN"))?;
236                JoinType::Left
237            } else if self.match_keyword("JOIN") {
238                JoinType::Inner
239            } else {
240                break;
241            };
242            let join_table = self.parse_ident()?;
243            let mut join_alias = None;
244            if let Some(t) = self.peek() {
245                if t.token_type == "ident" && !self.is_clause_keyword(t) {
246                    join_alias = Some(self.advance().value);
247                }
248            }
249            self.expect("keyword", Some("ON"))?;
250            let condition = self.parse_or_expr()?;
251            joins.push(JoinClause {
252                join_type: jt,
253                table: join_table,
254                alias: join_alias,
255                condition,
256            });
257        }
258
259        let mut where_clause = None;
260        if self.match_keyword("WHERE") {
261            where_clause = Some(self.parse_or_expr()?);
262        }
263
264        let mut group_by = None;
265        if self.match_keyword("GROUP") {
266            self.expect("keyword", Some("BY"))?;
267            let mut cols = vec![self.parse_ident()?];
268            while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
269                self.advance();
270                cols.push(self.parse_ident()?);
271            }
272            group_by = Some(cols);
273        }
274
275        let mut having = None;
276        if self.match_keyword("HAVING") {
277            having = Some(self.parse_or_expr()?);
278        }
279
280        let mut order_by = None;
281        if self.match_keyword("ORDER") {
282            self.expect("keyword", Some("BY"))?;
283            order_by = Some(self.parse_order_by()?);
284        }
285
286        let mut limit = None;
287        if self.match_keyword("LIMIT") {
288            let t = self.expect("number", None)?;
289            limit = Some(t.value.parse::<i64>().map_err(|_| {
290                MdqlError::QueryParse(format!("Invalid LIMIT value: {}", t.value))
291            })?);
292        }
293
294        Ok(SelectQuery {
295            columns,
296            table,
297            table_alias,
298            subquery,
299            joins,
300            where_clause,
301            group_by,
302            having,
303            order_by,
304            limit,
305            ctes: vec![],
306        })
307    }
308
309    fn parse_insert(&mut self) -> Result<InsertQuery, MdqlError> {
310        self.expect("keyword", Some("INSERT"))?;
311        self.expect("keyword", Some("INTO"))?;
312        let table = self.parse_ident()?;
313
314        self.expect("op", Some("("))?;
315        let mut columns = vec![self.parse_ident()?];
316        while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
317            self.advance();
318            columns.push(self.parse_ident()?);
319        }
320        self.expect("op", Some(")"))?;
321
322        self.expect("keyword", Some("VALUES"))?;
323
324        self.expect("op", Some("("))?;
325        let mut values = vec![self.parse_value()?];
326        while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
327            self.advance();
328            values.push(self.parse_value()?);
329        }
330        self.expect("op", Some(")"))?;
331
332        if columns.len() != values.len() {
333            return Err(MdqlError::QueryParse(format!(
334                "Column count ({}) does not match value count ({})",
335                columns.len(),
336                values.len()
337            )));
338        }
339
340        self.expect_end()?;
341        Ok(InsertQuery {
342            table,
343            columns,
344            values,
345        })
346    }
347
348    fn parse_update(&mut self) -> Result<UpdateQuery, MdqlError> {
349        self.expect("keyword", Some("UPDATE"))?;
350        let table = self.parse_ident()?;
351        self.expect("keyword", Some("SET"))?;
352
353        let mut assignments = Vec::new();
354        let col = self.parse_ident()?;
355        self.expect("op", Some("="))?;
356        let val = self.parse_value()?;
357        assignments.push((col, val));
358
359        while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
360            self.advance();
361            let col = self.parse_ident()?;
362            self.expect("op", Some("="))?;
363            let val = self.parse_value()?;
364            assignments.push((col, val));
365        }
366
367        let mut where_clause = None;
368        if self.match_keyword("WHERE") {
369            where_clause = Some(self.parse_or_expr()?);
370        }
371
372        self.expect_end()?;
373        Ok(UpdateQuery {
374            table,
375            assignments,
376            where_clause,
377        })
378    }
379
380    fn parse_delete(&mut self) -> Result<DeleteQuery, MdqlError> {
381        self.expect("keyword", Some("DELETE"))?;
382        self.expect("keyword", Some("FROM"))?;
383        let table = self.parse_ident()?;
384
385        let mut where_clause = None;
386        if self.match_keyword("WHERE") {
387            where_clause = Some(self.parse_or_expr()?);
388        }
389
390        let mode = if self.match_keyword("CASCADE") {
391            DeleteMode::Cascade
392        } else if self.match_keyword("RESTRICT") {
393            DeleteMode::Restrict
394        } else {
395            DeleteMode::Default
396        };
397
398        self.expect_end()?;
399        Ok(DeleteQuery {
400            table,
401            where_clause,
402            mode,
403        })
404    }
405
406    fn parse_alter(&mut self) -> Result<Statement, MdqlError> {
407        self.expect("keyword", Some("ALTER"))?;
408        self.expect("keyword", Some("TABLE"))?;
409        let table = self.parse_ident()?;
410
411        let t = self.peek().ok_or_else(|| {
412            MdqlError::QueryParse("Expected RENAME, DROP, or MERGE after table name".into())
413        })?;
414
415        match (t.token_type.as_str(), t.value.as_str()) {
416            ("keyword", "RENAME") => {
417                self.advance();
418                self.expect("keyword", Some("FIELD"))?;
419                let old_name = self.parse_string_or_ident()?;
420                self.expect("keyword", Some("TO"))?;
421                let new_name = self.parse_string_or_ident()?;
422                self.expect_end()?;
423                Ok(Statement::AlterRename(AlterRenameFieldQuery {
424                    table,
425                    old_name,
426                    new_name,
427                }))
428            }
429            ("keyword", "DROP") => {
430                self.advance();
431                self.expect("keyword", Some("FIELD"))?;
432                let field_name = self.parse_string_or_ident()?;
433                self.expect_end()?;
434                Ok(Statement::AlterDrop(AlterDropFieldQuery {
435                    table,
436                    field_name,
437                }))
438            }
439            ("keyword", "MERGE") => {
440                self.advance();
441                self.expect("keyword", Some("FIELDS"))?;
442                let mut sources = vec![self.parse_string_or_ident()?];
443                while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
444                    self.advance();
445                    sources.push(self.parse_string_or_ident()?);
446                }
447                self.expect("keyword", Some("INTO"))?;
448                let target = self.parse_string_or_ident()?;
449                self.expect_end()?;
450                Ok(Statement::AlterMerge(AlterMergeFieldsQuery {
451                    table,
452                    sources,
453                    into: target,
454                }))
455            }
456            _ => Err(MdqlError::QueryParse(format!(
457                "Expected RENAME, DROP, or MERGE, got '{}'",
458                t.raw
459            ))),
460        }
461    }
462
463    fn parse_create_view(&mut self) -> Result<Statement, MdqlError> {
464        self.expect("keyword", Some("CREATE"))?;
465        self.expect("keyword", Some("VIEW"))?;
466        let view_name = self.parse_ident()?;
467
468        let columns = if self.peek().map_or(false, |t| t.token_type == "op" && t.value == "(") {
469            self.advance();
470            let mut cols = vec![self.parse_ident()?];
471            while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
472                self.advance();
473                cols.push(self.parse_ident()?);
474            }
475            self.expect("op", Some(")"))?;
476            Some(cols)
477        } else {
478            None
479        };
480
481        self.expect("keyword", Some("AS"))?;
482        let query = Box::new(self.parse_select()?);
483        self.expect_end()?;
484
485        Ok(Statement::CreateView(CreateViewQuery {
486            view_name,
487            columns,
488            query,
489        }))
490    }
491
492    fn parse_drop_view(&mut self) -> Result<Statement, MdqlError> {
493        self.expect("keyword", Some("DROP"))?;
494        self.expect("keyword", Some("VIEW"))?;
495        let view_name = self.parse_ident()?;
496        self.expect_end()?;
497        Ok(Statement::DropView(DropViewQuery { view_name }))
498    }
499
500    fn parse_string_or_ident(&mut self) -> Result<String, MdqlError> {
501        let t = self.peek().ok_or_else(|| {
502            MdqlError::QueryParse("Expected field name, got end of query".into())
503        })?;
504        match t.token_type.as_str() {
505            "string" => {
506                let v = self.advance().value;
507                Ok(v)
508            }
509            "ident" | "keyword" => {
510                let v = self.advance().value;
511                Ok(v)
512            }
513            _ => Err(MdqlError::QueryParse(format!(
514                "Expected field name, got '{}'",
515                t.raw
516            ))),
517        }
518    }
519
520    fn parse_columns(&mut self) -> Result<ColumnList, MdqlError> {
521        if let Some(t) = self.peek() {
522            if t.token_type == "op" && t.value == "*" {
523                self.advance();
524                return Ok(ColumnList::All);
525            }
526        }
527
528        let mut exprs = vec![self.parse_select_expr()?];
529        while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
530            self.advance();
531            exprs.push(self.parse_select_expr()?);
532        }
533        Ok(ColumnList::Named(exprs))
534    }
535
536    fn peek_is_agg_func(&self) -> bool {
537        let t = match self.peek() {
538            Some(t) => t,
539            None => return false,
540        };
541        let name_upper = t.value.to_uppercase();
542        if !AGG_FUNCS.contains(&name_upper.as_str()) {
543            return false;
544        }
545        // Only treat as aggregate if followed by (
546        self.tokens
547            .get(self.pos + 1)
548            .map_or(false, |next| next.token_type == "op" && next.value == "(")
549    }
550
551    fn parse_select_expr(&mut self) -> Result<SelectExpr, MdqlError> {
552        let _t = self.peek().ok_or_else(|| {
553            MdqlError::QueryParse("Expected column or aggregate, got end of query".into())
554        })?;
555
556        let expr = self.parse_additive()?;
557
558        let alias = if self.match_keyword("AS") {
559            Some(self.parse_ident()?)
560        } else if self.peek().map_or(false, |t| {
561            t.token_type == "ident" && !self.is_clause_keyword(t)
562        }) {
563            Some(self.advance().value)
564        } else {
565            None
566        };
567
568        // Bare aggregate → SelectExpr::Aggregate for backward compat
569        if let Expr::Aggregate { func, arg, arg_expr } = expr {
570            return Ok(SelectExpr::Aggregate {
571                func,
572                arg,
573                arg_expr: arg_expr.map(|e| *e),
574                alias,
575            });
576        }
577
578        if alias.is_none() {
579            if let Expr::Column(name) = &expr {
580                return Ok(SelectExpr::Column(name.clone()));
581            }
582        }
583
584        Ok(SelectExpr::Expr { expr, alias })
585    }
586
587    // ── Expression parser (precedence climbing) ───────────────────────
588
589    fn peek_is_additive_op(&self) -> bool {
590        self.peek().map_or(false, |t| {
591            t.token_type == "op" && (t.value == "+" || t.value == "-")
592        })
593    }
594
595    fn peek_is_multiplicative_op(&self) -> bool {
596        self.peek().map_or(false, |t| {
597            t.token_type == "op" && (t.value == "*" || t.value == "/" || t.value == "%")
598        })
599    }
600
601    fn parse_additive(&mut self) -> Result<Expr, MdqlError> {
602        let mut left = self.parse_multiplicative()?;
603        while self.peek_is_additive_op() {
604            let op_tok = self.advance();
605            let is_sub = op_tok.value == "-";
606
607            // Check for INTERVAL keyword: expr +/- INTERVAL n DAY
608            if self.peek().map_or(false, |t| t.token_type == "keyword" && t.value == "INTERVAL") {
609                self.advance(); // consume INTERVAL
610                let days_expr = self.parse_multiplicative()?;
611                // Expect DAY or DAYS
612                if !self.match_keyword("DAY") && !self.match_keyword("DAYS") {
613                    return Err(MdqlError::QueryParse("Expected DAY after INTERVAL value".into()));
614                }
615                let days = if is_sub {
616                    Expr::UnaryMinus(Box::new(days_expr))
617                } else {
618                    days_expr
619                };
620                left = Expr::DateAdd {
621                    date: Box::new(left),
622                    days: Box::new(days),
623                };
624                continue;
625            }
626
627            let op = match op_tok.value.as_str() {
628                "+" => ArithOp::Add,
629                "-" => ArithOp::Sub,
630                _ => unreachable!(),
631            };
632            let right = self.parse_multiplicative()?;
633            left = Expr::BinaryOp {
634                left: Box::new(left),
635                op,
636                right: Box::new(right),
637            };
638        }
639        Ok(left)
640    }
641
642    fn parse_multiplicative(&mut self) -> Result<Expr, MdqlError> {
643        let mut left = self.parse_unary()?;
644        while self.peek_is_multiplicative_op() {
645            let op_tok = self.advance();
646            let op = match op_tok.value.as_str() {
647                "*" => ArithOp::Mul,
648                "/" => ArithOp::Div,
649                "%" => ArithOp::Mod,
650                _ => unreachable!(),
651            };
652            let right = self.parse_unary()?;
653            left = Expr::BinaryOp {
654                left: Box::new(left),
655                op,
656                right: Box::new(right),
657            };
658        }
659        Ok(left)
660    }
661
662    fn parse_unary(&mut self) -> Result<Expr, MdqlError> {
663        if self.peek().map_or(false, |t| t.token_type == "op" && t.value == "-") {
664            self.advance();
665            let inner = self.parse_atom()?;
666            // Fold unary minus on literals
667            match inner {
668                Expr::Literal(SqlValue::Int(n)) => Ok(Expr::Literal(SqlValue::Int(-n))),
669                Expr::Literal(SqlValue::Float(f)) => Ok(Expr::Literal(SqlValue::Float(-f))),
670                _ => Ok(Expr::UnaryMinus(Box::new(inner))),
671            }
672        } else {
673            self.parse_atom()
674        }
675    }
676
677    fn parse_atom(&mut self) -> Result<Expr, MdqlError> {
678        if self.peek_is_agg_func() {
679            return self.parse_agg_expr();
680        }
681
682        let t = self.peek().ok_or_else(|| {
683            MdqlError::QueryParse("Expected expression, got end of query".into())
684        })?;
685
686        match t.token_type.as_str() {
687            "number" => {
688                let v = self.advance().value;
689                if v.contains('.') {
690                    let f: f64 = v.parse().map_err(|_| {
691                        MdqlError::QueryParse(format!("Invalid float: {}", v))
692                    })?;
693                    Ok(Expr::Literal(SqlValue::Float(f)))
694                } else {
695                    let n: i64 = v.parse().map_err(|_| {
696                        MdqlError::QueryParse(format!("Invalid int: {}", v))
697                    })?;
698                    Ok(Expr::Literal(SqlValue::Int(n)))
699                }
700            }
701            "string" => {
702                let v = self.advance().value;
703                Ok(Expr::Literal(SqlValue::String(v)))
704            }
705            "keyword" if t.value == "NULL" => {
706                self.advance();
707                Ok(Expr::Literal(SqlValue::Null))
708            }
709            "keyword" if t.value == "CASE" => {
710                self.parse_case_expr()
711            }
712            "keyword" if t.value == "CURRENT_DATE" => {
713                self.advance();
714                Ok(Expr::CurrentDate)
715            }
716            "keyword" if t.value == "CURRENT_TIMESTAMP" => {
717                self.advance();
718                Ok(Expr::CurrentTimestamp)
719            }
720            "keyword" if t.value == "DATEDIFF" => {
721                self.advance();
722                self.expect("op", Some("("))?;
723                let left = self.parse_additive()?;
724                self.expect("op", Some(","))?;
725                let right = self.parse_additive()?;
726                self.expect("op", Some(")"))?;
727                Ok(Expr::DateDiff { left: Box::new(left), right: Box::new(right) })
728            }
729            "op" if t.value == "(" => {
730                self.advance();
731                let expr = self.parse_additive()?;
732                self.expect("op", Some(")"))?;
733                Ok(expr)
734            }
735            "ident" => {
736                let name = self.advance().value;
737                Ok(Expr::Column(name))
738            }
739            "keyword" if !Self::is_reserved_keyword(&t.value) => {
740                let name = self.advance().value;
741                Ok(Expr::Column(name))
742            }
743            _ => Err(MdqlError::QueryParse(format!(
744                "Expected expression, got '{}'",
745                t.raw
746            ))),
747        }
748    }
749
750    fn parse_case_expr(&mut self) -> Result<Expr, MdqlError> {
751        self.expect("keyword", Some("CASE"))?;
752        let mut whens = Vec::new();
753        while self.match_keyword("WHEN") {
754            let condition = self.parse_or_expr()?;
755            self.expect("keyword", Some("THEN"))?;
756            let result = self.parse_additive()?;
757            whens.push((condition, Box::new(result)));
758        }
759        if whens.is_empty() {
760            return Err(MdqlError::QueryParse("CASE requires at least one WHEN clause".into()));
761        }
762        let else_expr = if self.match_keyword("ELSE") {
763            Some(Box::new(self.parse_additive()?))
764        } else {
765            None
766        };
767        self.expect("keyword", Some("END"))?;
768        Ok(Expr::Case { whens, else_expr })
769    }
770
771    fn parse_agg_expr(&mut self) -> Result<Expr, MdqlError> {
772        let func_name = self.advance().value.to_uppercase();
773        let func = match func_name.as_str() {
774            "COUNT" => AggFunc::Count,
775            "SUM" => AggFunc::Sum,
776            "AVG" => AggFunc::Avg,
777            "MIN" => AggFunc::Min,
778            "MAX" => AggFunc::Max,
779            _ => unreachable!(),
780        };
781        self.expect("op", Some("("))?;
782        let (arg, arg_expr) = if self.peek().map_or(false, |t| t.token_type == "op" && t.value == "*") {
783            self.advance();
784            ("*".to_string(), None)
785        } else {
786            let expr = self.parse_additive()?;
787            if let Expr::Column(name) = &expr {
788                (name.clone(), None)
789            } else {
790                (expr.display_name(), Some(Box::new(expr)))
791            }
792        };
793        self.expect("op", Some(")"))?;
794        Ok(Expr::Aggregate { func, arg, arg_expr })
795    }
796
797    fn parse_ident(&mut self) -> Result<String, MdqlError> {
798        let t = self.peek().ok_or_else(|| {
799            MdqlError::QueryParse("Expected identifier, got end of query".into())
800        })?;
801        match t.token_type.as_str() {
802            "ident" | "keyword" => {
803                let v = self.advance().value;
804                Ok(v)
805            }
806            _ => Err(MdqlError::QueryParse(format!(
807                "Expected identifier, got '{}'",
808                t.raw
809            ))),
810        }
811    }
812
813    fn parse_or_expr(&mut self) -> Result<WhereClause, MdqlError> {
814        let mut left = self.parse_and_expr()?;
815        while self.match_keyword("OR") {
816            let right = self.parse_and_expr()?;
817            left = WhereClause::BoolOp(BoolOp {
818                op: BoolOpKind::Or,
819                left: Box::new(left),
820                right: Box::new(right),
821            });
822        }
823        Ok(left)
824    }
825
826    fn parse_and_expr(&mut self) -> Result<WhereClause, MdqlError> {
827        let mut left = self.parse_comparison()?;
828        while self.match_keyword("AND") {
829            let right = self.parse_comparison()?;
830            left = WhereClause::BoolOp(BoolOp {
831                op: BoolOpKind::And,
832                left: Box::new(left),
833                right: Box::new(right),
834            });
835        }
836        Ok(left)
837    }
838
839    fn parse_comparison(&mut self) -> Result<WhereClause, MdqlError> {
840        // Handle parenthesized boolean expressions
841        if self.peek().map_or(false, |t| t.token_type == "op" && t.value == "(") {
842            // Save position — might be arithmetic parens, not boolean
843            let saved_pos = self.pos;
844            self.advance();
845            // Try parsing as boolean (OR/AND) expression
846            let result = self.parse_or_expr();
847            if result.is_ok() && self.peek().map_or(false, |t| t.token_type == "op" && t.value == ")") {
848                self.advance();
849                return result;
850            }
851            // Not a boolean paren — rewind and parse as arithmetic expression
852            self.pos = saved_pos;
853        }
854
855        // Parse the left side as a full expression (column, literal, or arithmetic)
856        let left_expr = self.parse_additive()?;
857
858        // Extract column name for backward compat (simple column on left side)
859        let col = left_expr.as_column().unwrap_or("").to_string();
860
861        // IS NULL / IS NOT NULL (only valid with simple column)
862        if self.match_keyword("IS") {
863            if self.match_keyword("NOT") {
864                self.expect("keyword", Some("NULL"))?;
865                return Ok(WhereClause::Comparison(Comparison {
866                    column: col,
867                    op: CmpOp::IsNotNull,
868                    value: None,
869                    left_expr: Some(left_expr),
870                    right_expr: None,
871                }));
872            }
873            self.expect("keyword", Some("NULL"))?;
874            return Ok(WhereClause::Comparison(Comparison {
875                column: col,
876                op: CmpOp::IsNull,
877                value: None,
878                left_expr: Some(left_expr),
879                right_expr: None,
880            }));
881        }
882
883        // IN (val, val, ...)
884        if self.match_keyword("IN") {
885            self.expect("op", Some("("))?;
886            let mut values = vec![self.parse_value()?];
887            while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
888                self.advance();
889                values.push(self.parse_value()?);
890            }
891            self.expect("op", Some(")"))?;
892            return Ok(WhereClause::Comparison(Comparison {
893                column: col,
894                op: CmpOp::In,
895                value: Some(SqlValue::List(values)),
896                left_expr: Some(left_expr),
897                right_expr: None,
898            }));
899        }
900
901        // LIKE
902        if self.match_keyword("LIKE") {
903            let val = self.parse_value()?;
904            return Ok(WhereClause::Comparison(Comparison {
905                column: col,
906                op: CmpOp::Like,
907                value: Some(val),
908                left_expr: Some(left_expr),
909                right_expr: None,
910            }));
911        }
912
913        // NOT LIKE
914        if self.match_keyword("NOT") {
915            if self.match_keyword("LIKE") {
916                let val = self.parse_value()?;
917                return Ok(WhereClause::Comparison(Comparison {
918                    column: col,
919                    op: CmpOp::NotLike,
920                    value: Some(val),
921                    left_expr: Some(left_expr),
922                    right_expr: None,
923                }));
924            }
925            return Err(MdqlError::QueryParse("Expected LIKE after NOT".into()));
926        }
927
928        // Standard comparison operators
929        if let Some(t) = self.peek() {
930            if t.token_type == "op" && ["=", "!=", "<", ">", "<=", ">="].contains(&t.value.as_str())
931            {
932                let op_str = self.advance().value;
933                let op = match op_str.as_str() {
934                    "=" => CmpOp::Eq,
935                    "!=" => CmpOp::Ne,
936                    "<" => CmpOp::Lt,
937                    ">" => CmpOp::Gt,
938                    "<=" => CmpOp::Le,
939                    ">=" => CmpOp::Ge,
940                    _ => unreachable!(),
941                };
942                // Parse right side as expression
943                let right_expr = self.parse_additive()?;
944                // Extract SqlValue for backward compat (simple literal on right side)
945                let value = match &right_expr {
946                    Expr::Literal(v) => Some(v.clone()),
947                    _ => None,
948                };
949                return Ok(WhereClause::Comparison(Comparison {
950                    column: col,
951                    op,
952                    value,
953                    left_expr: Some(left_expr),
954                    right_expr: Some(right_expr),
955                }));
956            }
957        }
958
959        let got = self.peek().map_or("end".to_string(), |t| t.raw.clone());
960        Err(MdqlError::QueryParse(format!(
961            "Expected operator after '{}', got '{}'",
962            left_expr.display_name(), got
963        )))
964    }
965
966    fn parse_value(&mut self) -> Result<SqlValue, MdqlError> {
967        let t = self.peek().ok_or_else(|| {
968            MdqlError::QueryParse("Expected value, got end of query".into())
969        })?;
970        match t.token_type.as_str() {
971            "string" => {
972                let v = self.advance().value;
973                Ok(SqlValue::String(v))
974            }
975            "number" => {
976                let v = self.advance().value;
977                if v.contains('.') {
978                    Ok(SqlValue::Float(v.parse().map_err(|_| {
979                        MdqlError::QueryParse(format!("Invalid float: {}", v))
980                    })?))
981                } else {
982                    Ok(SqlValue::Int(v.parse().map_err(|_| {
983                        MdqlError::QueryParse(format!("Invalid int: {}", v))
984                    })?))
985                }
986            }
987            "keyword" if t.value == "NULL" => {
988                self.advance();
989                Ok(SqlValue::Null)
990            }
991            _ => Err(MdqlError::QueryParse(format!(
992                "Expected value, got '{}'",
993                t.raw
994            ))),
995        }
996    }
997
998    fn parse_order_by(&mut self) -> Result<Vec<OrderSpec>, MdqlError> {
999        let mut specs = vec![self.parse_order_spec()?];
1000        while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
1001            self.advance();
1002            specs.push(self.parse_order_spec()?);
1003        }
1004        Ok(specs)
1005    }
1006
1007    fn parse_order_spec(&mut self) -> Result<OrderSpec, MdqlError> {
1008        let expr = self.parse_additive()?;
1009        let col = expr.as_column().unwrap_or("").to_string();
1010        let descending = if self.match_keyword("DESC") {
1011            true
1012        } else {
1013            self.match_keyword("ASC");
1014            false
1015        };
1016        Ok(OrderSpec {
1017            column: col,
1018            expr: Some(expr),
1019            descending,
1020        })
1021    }
1022
1023    fn is_clause_keyword(&self, t: &Token) -> bool {
1024        t.token_type == "keyword"
1025            && ["WHERE", "ORDER", "LIMIT", "JOIN", "LEFT", "ON", "GROUP"].contains(&t.value.as_str())
1026    }
1027
1028    /// Keywords that should never be consumed as column names inside expressions.
1029    fn is_reserved_keyword(kw: &str) -> bool {
1030        matches!(kw,
1031            "AS" | "FROM" | "WHERE" | "AND" | "OR" | "ORDER" | "BY"
1032            | "ASC" | "DESC" | "LIMIT" | "JOIN" | "ON" | "GROUP"
1033            | "SELECT" | "INSERT" | "INTO" | "VALUES" | "UPDATE" | "SET"
1034            | "DELETE" | "ALTER" | "TABLE" | "IS" | "NOT" | "IN" | "LIKE"
1035            | "RENAME" | "FIELD" | "TO" | "DROP" | "MERGE" | "FIELDS"
1036            | "CASE" | "WHEN" | "THEN" | "ELSE" | "END"
1037            | "HAVING" | "INTERVAL" | "DAY" | "DAYS"
1038            | "CURRENT_DATE" | "CURRENT_TIMESTAMP" | "DATEDIFF"
1039            | "CREATE" | "VIEW" | "CASCADE" | "RESTRICT"
1040            | "WITH"
1041        )
1042    }
1043
1044    fn expect_end(&self) -> Result<(), MdqlError> {
1045        if let Some(t) = self.peek() {
1046            return Err(MdqlError::QueryParse(format!(
1047                "Unexpected token '{}' at position {}",
1048                t.raw, self.pos
1049            )));
1050        }
1051        Ok(())
1052    }
1053}
1054
1055pub fn parse_query(sql: &str) -> crate::errors::Result<Statement> {
1056    let tokens = tokenize(sql);
1057    if tokens.is_empty() {
1058        return Err(MdqlError::QueryParse("Empty query".into()));
1059    }
1060    let mut parser = Parser::new(tokens);
1061    parser.parse_statement()
1062}
1063
1064#[cfg(test)]
1065mod tests {
1066    use super::*;
1067
1068    #[test]
1069    fn test_simple_select() {
1070        let stmt = parse_query("SELECT title, status FROM strategies").unwrap();
1071        if let Statement::Select(q) = stmt {
1072            assert_eq!(q.columns, ColumnList::Named(vec![SelectExpr::Column("title".into()), SelectExpr::Column("status".into())]));
1073            assert_eq!(q.table, "strategies");
1074        } else {
1075            panic!("Expected Select");
1076        }
1077    }
1078
1079    #[test]
1080    fn test_select_star() {
1081        let stmt = parse_query("SELECT * FROM test").unwrap();
1082        if let Statement::Select(q) = stmt {
1083            assert_eq!(q.columns, ColumnList::All);
1084        } else {
1085            panic!("Expected Select");
1086        }
1087    }
1088
1089    #[test]
1090    fn test_where_clause() {
1091        let stmt = parse_query("SELECT title FROM test WHERE count > 5").unwrap();
1092        if let Statement::Select(q) = stmt {
1093            assert!(q.where_clause.is_some());
1094        } else {
1095            panic!("Expected Select");
1096        }
1097    }
1098
1099    #[test]
1100    fn test_order_by() {
1101        let stmt =
1102            parse_query("SELECT title FROM test ORDER BY composite DESC, title ASC").unwrap();
1103        if let Statement::Select(q) = stmt {
1104            let ob = q.order_by.unwrap();
1105            assert_eq!(ob.len(), 2);
1106            assert!(ob[0].descending);
1107            assert!(!ob[1].descending);
1108        } else {
1109            panic!("Expected Select");
1110        }
1111    }
1112
1113    #[test]
1114    fn test_limit() {
1115        let stmt = parse_query("SELECT * FROM test LIMIT 10").unwrap();
1116        if let Statement::Select(q) = stmt {
1117            assert_eq!(q.limit, Some(10));
1118        } else {
1119            panic!("Expected Select");
1120        }
1121    }
1122
1123    #[test]
1124    fn test_insert() {
1125        let stmt = parse_query(
1126            "INSERT INTO test (title, count) VALUES ('Hello', 42)",
1127        )
1128        .unwrap();
1129        if let Statement::Insert(q) = stmt {
1130            assert_eq!(q.table, "test");
1131            assert_eq!(q.columns, vec!["title", "count"]);
1132            assert_eq!(q.values[0], SqlValue::String("Hello".into()));
1133            assert_eq!(q.values[1], SqlValue::Int(42));
1134        } else {
1135            panic!("Expected Insert");
1136        }
1137    }
1138
1139    #[test]
1140    fn test_update() {
1141        let stmt = parse_query("UPDATE test SET status = 'KILLED' WHERE path = 'a.md'").unwrap();
1142        if let Statement::Update(q) = stmt {
1143            assert_eq!(q.table, "test");
1144            assert_eq!(q.assignments.len(), 1);
1145            assert!(q.where_clause.is_some());
1146        } else {
1147            panic!("Expected Update");
1148        }
1149    }
1150
1151    #[test]
1152    fn test_delete() {
1153        let stmt = parse_query("DELETE FROM test WHERE status = 'draft'").unwrap();
1154        if let Statement::Delete(q) = stmt {
1155            assert_eq!(q.table, "test");
1156            assert!(q.where_clause.is_some());
1157        } else {
1158            panic!("Expected Delete");
1159        }
1160    }
1161
1162    #[test]
1163    fn test_alter_rename() {
1164        let stmt =
1165            parse_query("ALTER TABLE test RENAME FIELD 'Summary' TO 'Overview'").unwrap();
1166        if let Statement::AlterRename(q) = stmt {
1167            assert_eq!(q.old_name, "Summary");
1168            assert_eq!(q.new_name, "Overview");
1169        } else {
1170            panic!("Expected AlterRename");
1171        }
1172    }
1173
1174    #[test]
1175    fn test_alter_drop() {
1176        let stmt = parse_query("ALTER TABLE test DROP FIELD 'Details'").unwrap();
1177        if let Statement::AlterDrop(q) = stmt {
1178            assert_eq!(q.field_name, "Details");
1179        } else {
1180            panic!("Expected AlterDrop");
1181        }
1182    }
1183
1184    #[test]
1185    fn test_alter_merge() {
1186        let stmt = parse_query(
1187            "ALTER TABLE test MERGE FIELDS 'Entry Rules', 'Exit Rules' INTO 'Trading Rules'",
1188        )
1189        .unwrap();
1190        if let Statement::AlterMerge(q) = stmt {
1191            assert_eq!(q.sources, vec!["Entry Rules", "Exit Rules"]);
1192            assert_eq!(q.into, "Trading Rules");
1193        } else {
1194            panic!("Expected AlterMerge");
1195        }
1196    }
1197
1198    #[test]
1199    fn test_backtick_ident() {
1200        let stmt = parse_query("SELECT `Structural Mechanism` FROM test").unwrap();
1201        if let Statement::Select(q) = stmt {
1202            assert_eq!(
1203                q.columns,
1204                ColumnList::Named(vec![SelectExpr::Column("Structural Mechanism".into())])
1205            );
1206        } else {
1207            panic!("Expected Select");
1208        }
1209    }
1210
1211    #[test]
1212    fn test_like_operator() {
1213        let stmt = parse_query("SELECT title FROM test WHERE categories LIKE '%defi%'").unwrap();
1214        if let Statement::Select(q) = stmt {
1215            if let Some(WhereClause::Comparison(c)) = q.where_clause {
1216                assert_eq!(c.op, CmpOp::Like);
1217                assert_eq!(c.value, Some(SqlValue::String("%defi%".into())));
1218            } else {
1219                panic!("Expected LIKE comparison");
1220            }
1221        } else {
1222            panic!("Expected Select");
1223        }
1224    }
1225
1226    #[test]
1227    fn test_in_operator() {
1228        let stmt =
1229            parse_query("SELECT * FROM test WHERE status IN ('ACTIVE', 'LIVE')").unwrap();
1230        if let Statement::Select(q) = stmt {
1231            if let Some(WhereClause::Comparison(c)) = q.where_clause {
1232                assert_eq!(c.op, CmpOp::In);
1233            } else {
1234                panic!("Expected IN comparison");
1235            }
1236        } else {
1237            panic!("Expected Select");
1238        }
1239    }
1240
1241    #[test]
1242    fn test_is_null() {
1243        let stmt = parse_query("SELECT * FROM test WHERE title IS NULL").unwrap();
1244        if let Statement::Select(q) = stmt {
1245            if let Some(WhereClause::Comparison(c)) = q.where_clause {
1246                assert_eq!(c.op, CmpOp::IsNull);
1247            } else {
1248                panic!("Expected IS NULL comparison");
1249            }
1250        } else {
1251            panic!("Expected Select");
1252        }
1253    }
1254
1255    #[test]
1256    fn test_and_or() {
1257        let stmt = parse_query(
1258            "SELECT * FROM test WHERE status = 'ACTIVE' AND count > 5 OR title LIKE '%test%'",
1259        )
1260        .unwrap();
1261        if let Statement::Select(q) = stmt {
1262            assert!(q.where_clause.is_some());
1263        } else {
1264            panic!("Expected Select");
1265        }
1266    }
1267
1268    #[test]
1269    fn test_join() {
1270        let stmt = parse_query(
1271            "SELECT s.title, b.sharpe FROM strategies s JOIN backtests b ON b.strategy = s.path",
1272        )
1273        .unwrap();
1274        if let Statement::Select(q) = stmt {
1275            assert_eq!(q.table, "strategies");
1276            assert_eq!(q.table_alias, Some("s".into()));
1277            assert_eq!(q.joins.len(), 1);
1278            let join = &q.joins[0];
1279            assert_eq!(join.table, "backtests");
1280            assert_eq!(join.alias, Some("b".into()));
1281        } else {
1282            panic!("Expected Select");
1283        }
1284    }
1285
1286    #[test]
1287    fn test_multi_join() {
1288        let stmt = parse_query(
1289            "SELECT s.title, b.sharpe, c.verdict FROM strategies s JOIN backtests b ON b.strategy = s.path JOIN critiques c ON c.strategy = s.path",
1290        )
1291        .unwrap();
1292        if let Statement::Select(q) = stmt {
1293            assert_eq!(q.table, "strategies");
1294            assert_eq!(q.table_alias, Some("s".into()));
1295            assert_eq!(q.joins.len(), 2);
1296            assert_eq!(q.joins[0].table, "backtests");
1297            assert_eq!(q.joins[0].alias, Some("b".into()));
1298            assert_eq!(where_clause_to_sql(&q.joins[0].condition), "b.strategy = s.path");
1299            assert_eq!(q.joins[1].table, "critiques");
1300            assert_eq!(q.joins[1].alias, Some("c".into()));
1301            assert_eq!(where_clause_to_sql(&q.joins[1].condition), "c.strategy = s.path");
1302        } else {
1303            panic!("Expected Select");
1304        }
1305    }
1306
1307    #[test]
1308    fn test_left_join() {
1309        let stmt = parse_query(
1310            "SELECT s.title, b.sharpe FROM strategies s LEFT JOIN backtests b ON b.strategy = s.path",
1311        )
1312        .unwrap();
1313        if let Statement::Select(q) = stmt {
1314            assert_eq!(q.joins.len(), 1);
1315            assert_eq!(q.joins[0].join_type, JoinType::Left);
1316            assert_eq!(q.joins[0].table, "backtests");
1317        } else {
1318            panic!("Expected Select");
1319        }
1320    }
1321
1322    #[test]
1323    fn test_mixed_join_types() {
1324        let stmt = parse_query(
1325            "SELECT s.title FROM strategies s JOIN backtests b ON b.strategy = s.path LEFT JOIN allocations a ON a.strategy = s.path",
1326        )
1327        .unwrap();
1328        if let Statement::Select(q) = stmt {
1329            assert_eq!(q.joins.len(), 2);
1330            assert_eq!(q.joins[0].join_type, JoinType::Inner);
1331            assert_eq!(q.joins[1].join_type, JoinType::Left);
1332        } else {
1333            panic!("Expected Select");
1334        }
1335    }
1336
1337    #[test]
1338    fn test_join_compound_and() {
1339        let stmt = parse_query(
1340            "SELECT s.title FROM strategies s LEFT JOIN backtests b ON b.strategy = s.path AND b.mode = 'PAPER'",
1341        )
1342        .unwrap();
1343        if let Statement::Select(q) = stmt {
1344            assert_eq!(q.joins.len(), 1);
1345            assert_eq!(q.joins[0].join_type, JoinType::Left);
1346            let sql = where_clause_to_sql(&q.joins[0].condition);
1347            assert!(sql.contains("b.strategy = s.path"));
1348            assert!(sql.contains("AND"));
1349            assert!(sql.contains("b.mode = 'PAPER'"));
1350        } else {
1351            panic!("Expected Select");
1352        }
1353    }
1354
1355    #[test]
1356    fn test_join_compound_or() {
1357        let stmt = parse_query(
1358            "SELECT * FROM a JOIN b ON a.id = b.id OR a.alt = b.id",
1359        )
1360        .unwrap();
1361        if let Statement::Select(q) = stmt {
1362            let sql = where_clause_to_sql(&q.joins[0].condition);
1363            assert!(sql.contains("OR"));
1364        } else {
1365            panic!("Expected Select");
1366        }
1367    }
1368
1369    #[test]
1370    fn test_join_compound_with_where() {
1371        let stmt = parse_query(
1372            "SELECT s.title FROM strategies s JOIN backtests b ON b.strategy = s.path AND b.mode = 'PAPER' WHERE s.title = 'Alpha'",
1373        )
1374        .unwrap();
1375        if let Statement::Select(q) = stmt {
1376            assert_eq!(q.joins.len(), 1);
1377            assert!(q.where_clause.is_some());
1378            let join_sql = where_clause_to_sql(&q.joins[0].condition);
1379            assert!(join_sql.contains("AND"));
1380        } else {
1381            panic!("Expected Select");
1382        }
1383    }
1384
1385    #[test]
1386    fn test_empty_query() {
1387        assert!(parse_query("").is_err());
1388    }
1389
1390    #[test]
1391    fn test_count_star() {
1392        let stmt = parse_query("SELECT status, COUNT(*) AS cnt FROM strategies GROUP BY status").unwrap();
1393        if let Statement::Select(q) = stmt {
1394            if let ColumnList::Named(exprs) = &q.columns {
1395                assert_eq!(exprs.len(), 2);
1396                assert_eq!(exprs[0], SelectExpr::Column("status".into()));
1397                assert!(matches!(&exprs[1], SelectExpr::Aggregate {
1398                    func: AggFunc::Count,
1399                    arg,
1400                    alias: Some(a),
1401                    ..
1402                } if arg == "*" && a == "cnt"));
1403            } else {
1404                panic!("Expected Named columns");
1405            }
1406            assert_eq!(q.group_by, Some(vec!["status".into()]));
1407        } else {
1408            panic!("Expected Select");
1409        }
1410    }
1411
1412    #[test]
1413    fn test_count_column_as_ident() {
1414        // "count" as a column name should NOT be parsed as the COUNT aggregate
1415        let stmt = parse_query("INSERT INTO test (title, count) VALUES ('Hello', 42)").unwrap();
1416        if let Statement::Insert(q) = stmt {
1417            assert_eq!(q.columns, vec!["title", "count"]);
1418        } else {
1419            panic!("Expected Insert");
1420        }
1421    }
1422
1423    #[test]
1424    fn test_multiple_aggregates() {
1425        let stmt = parse_query("SELECT MIN(composite), MAX(composite), AVG(composite) FROM strategies").unwrap();
1426        if let Statement::Select(q) = stmt {
1427            if let ColumnList::Named(exprs) = &q.columns {
1428                assert_eq!(exprs.len(), 3);
1429                assert!(matches!(&exprs[0], SelectExpr::Aggregate { func: AggFunc::Min, .. }));
1430                assert!(matches!(&exprs[1], SelectExpr::Aggregate { func: AggFunc::Max, .. }));
1431                assert!(matches!(&exprs[2], SelectExpr::Aggregate { func: AggFunc::Avg, .. }));
1432            } else {
1433                panic!("Expected Named columns");
1434            }
1435            assert_eq!(q.group_by, None);
1436        } else {
1437            panic!("Expected Select");
1438        }
1439    }
1440
1441    // ── Expression tests ──────────────────────────────────────────
1442
1443    #[test]
1444    fn test_select_arithmetic_expr() {
1445        let stmt = parse_query("SELECT a + b FROM test").unwrap();
1446        if let Statement::Select(q) = stmt {
1447            if let ColumnList::Named(exprs) = &q.columns {
1448                assert_eq!(exprs.len(), 1);
1449                assert!(matches!(&exprs[0], SelectExpr::Expr {
1450                    expr: Expr::BinaryOp { op: ArithOp::Add, .. },
1451                    alias: None,
1452                }));
1453            } else {
1454                panic!("Expected Named columns");
1455            }
1456        } else {
1457            panic!("Expected Select");
1458        }
1459    }
1460
1461    #[test]
1462    fn test_select_arithmetic_with_alias() {
1463        let stmt = parse_query("SELECT a + b AS total FROM test").unwrap();
1464        if let Statement::Select(q) = stmt {
1465            if let ColumnList::Named(exprs) = &q.columns {
1466                assert_eq!(exprs.len(), 1);
1467                assert!(matches!(&exprs[0], SelectExpr::Expr {
1468                    alias: Some(a),
1469                    ..
1470                } if a == "total"));
1471                assert_eq!(exprs[0].output_name(), "total");
1472            } else {
1473                panic!("Expected Named columns");
1474            }
1475        } else {
1476            panic!("Expected Select");
1477        }
1478    }
1479
1480    #[test]
1481    fn test_select_precedence() {
1482        // a + b * c should parse as a + (b * c)
1483        let stmt = parse_query("SELECT a + b * c FROM test").unwrap();
1484        if let Statement::Select(q) = stmt {
1485            if let ColumnList::Named(exprs) = &q.columns {
1486                if let SelectExpr::Expr { expr, .. } = &exprs[0] {
1487                    if let Expr::BinaryOp { left, op, right } = expr {
1488                        assert_eq!(*op, ArithOp::Add);
1489                        assert!(matches!(left.as_ref(), Expr::Column(n) if n == "a"));
1490                        assert!(matches!(right.as_ref(), Expr::BinaryOp { op: ArithOp::Mul, .. }));
1491                    } else {
1492                        panic!("Expected BinaryOp");
1493                    }
1494                } else {
1495                    panic!("Expected Expr variant");
1496                }
1497            } else {
1498                panic!("Expected Named columns");
1499            }
1500        } else {
1501            panic!("Expected Select");
1502        }
1503    }
1504
1505    #[test]
1506    fn test_select_parenthesized_expr() {
1507        // (a + b) * c should override default precedence
1508        let stmt = parse_query("SELECT (a + b) * c FROM test").unwrap();
1509        if let Statement::Select(q) = stmt {
1510            if let ColumnList::Named(exprs) = &q.columns {
1511                if let SelectExpr::Expr { expr, .. } = &exprs[0] {
1512                    if let Expr::BinaryOp { left, op, .. } = expr {
1513                        assert_eq!(*op, ArithOp::Mul);
1514                        assert!(matches!(left.as_ref(), Expr::BinaryOp { op: ArithOp::Add, .. }));
1515                    } else {
1516                        panic!("Expected BinaryOp");
1517                    }
1518                } else {
1519                    panic!("Expected Expr variant");
1520                }
1521            } else {
1522                panic!("Expected Named columns");
1523            }
1524        } else {
1525            panic!("Expected Select");
1526        }
1527    }
1528
1529    #[test]
1530    fn test_select_unary_minus() {
1531        let stmt = parse_query("SELECT -count FROM test").unwrap();
1532        if let Statement::Select(q) = stmt {
1533            if let ColumnList::Named(exprs) = &q.columns {
1534                assert!(matches!(&exprs[0], SelectExpr::Expr {
1535                    expr: Expr::UnaryMinus(_),
1536                    ..
1537                }));
1538            } else {
1539                panic!("Expected Named columns");
1540            }
1541        } else {
1542            panic!("Expected Select");
1543        }
1544    }
1545
1546    #[test]
1547    fn test_select_negative_literal() {
1548        let stmt = parse_query("SELECT -42 FROM test").unwrap();
1549        if let Statement::Select(q) = stmt {
1550            if let ColumnList::Named(exprs) = &q.columns {
1551                // Unary minus folds into the literal
1552                assert!(matches!(&exprs[0], SelectExpr::Expr {
1553                    expr: Expr::Literal(SqlValue::Int(-42)),
1554                    ..
1555                }));
1556            } else {
1557                panic!("Expected Named columns");
1558            }
1559        } else {
1560            panic!("Expected Select");
1561        }
1562    }
1563
1564    #[test]
1565    fn test_where_arithmetic_expr() {
1566        let stmt = parse_query("SELECT * FROM test WHERE a + b > 10").unwrap();
1567        if let Statement::Select(q) = stmt {
1568            if let Some(WhereClause::Comparison(c)) = q.where_clause {
1569                assert_eq!(c.op, CmpOp::Gt);
1570                assert!(matches!(&c.left_expr, Some(Expr::BinaryOp { op: ArithOp::Add, .. })));
1571                assert!(matches!(&c.right_expr, Some(Expr::Literal(SqlValue::Int(10)))));
1572            } else {
1573                panic!("Expected comparison");
1574            }
1575        } else {
1576            panic!("Expected Select");
1577        }
1578    }
1579
1580    #[test]
1581    fn test_where_both_sides_expr() {
1582        let stmt = parse_query("SELECT * FROM test WHERE a * 2 > b + 1").unwrap();
1583        if let Statement::Select(q) = stmt {
1584            if let Some(WhereClause::Comparison(c)) = q.where_clause {
1585                assert_eq!(c.op, CmpOp::Gt);
1586                assert!(matches!(&c.left_expr, Some(Expr::BinaryOp { op: ArithOp::Mul, .. })));
1587                assert!(matches!(&c.right_expr, Some(Expr::BinaryOp { op: ArithOp::Add, .. })));
1588            } else {
1589                panic!("Expected comparison");
1590            }
1591        } else {
1592            panic!("Expected Select");
1593        }
1594    }
1595
1596    #[test]
1597    fn test_order_by_expr() {
1598        let stmt = parse_query("SELECT * FROM test ORDER BY a + b DESC").unwrap();
1599        if let Statement::Select(q) = stmt {
1600            let ob = q.order_by.unwrap();
1601            assert_eq!(ob.len(), 1);
1602            assert!(ob[0].descending);
1603            assert!(matches!(&ob[0].expr, Some(Expr::BinaryOp { op: ArithOp::Add, .. })));
1604        } else {
1605            panic!("Expected Select");
1606        }
1607    }
1608
1609    #[test]
1610    fn test_all_arithmetic_ops() {
1611        let stmt = parse_query("SELECT a + b, a - b, a * b, a / b, a % b FROM test").unwrap();
1612        if let Statement::Select(q) = stmt {
1613            if let ColumnList::Named(exprs) = &q.columns {
1614                assert_eq!(exprs.len(), 5);
1615                assert!(matches!(&exprs[0], SelectExpr::Expr { expr: Expr::BinaryOp { op: ArithOp::Add, .. }, .. }));
1616                assert!(matches!(&exprs[1], SelectExpr::Expr { expr: Expr::BinaryOp { op: ArithOp::Sub, .. }, .. }));
1617                assert!(matches!(&exprs[2], SelectExpr::Expr { expr: Expr::BinaryOp { op: ArithOp::Mul, .. }, .. }));
1618                assert!(matches!(&exprs[3], SelectExpr::Expr { expr: Expr::BinaryOp { op: ArithOp::Div, .. }, .. }));
1619                assert!(matches!(&exprs[4], SelectExpr::Expr { expr: Expr::BinaryOp { op: ArithOp::Mod, .. }, .. }));
1620            } else {
1621                panic!("Expected Named columns");
1622            }
1623        } else {
1624            panic!("Expected Select");
1625        }
1626    }
1627
1628    #[test]
1629    fn test_column_with_literal_arithmetic() {
1630        let stmt = parse_query("SELECT count * 2 + 1 FROM test").unwrap();
1631        if let Statement::Select(q) = stmt {
1632            if let ColumnList::Named(exprs) = &q.columns {
1633                // Should be (count * 2) + 1
1634                if let SelectExpr::Expr { expr, .. } = &exprs[0] {
1635                    if let Expr::BinaryOp { left, op, right } = expr {
1636                        assert_eq!(*op, ArithOp::Add);
1637                        assert!(matches!(right.as_ref(), Expr::Literal(SqlValue::Int(1))));
1638                        assert!(matches!(left.as_ref(), Expr::BinaryOp { op: ArithOp::Mul, .. }));
1639                    } else {
1640                        panic!("Expected BinaryOp");
1641                    }
1642                } else {
1643                    panic!("Expected Expr");
1644                }
1645            } else {
1646                panic!("Expected Named columns");
1647            }
1648        } else {
1649            panic!("Expected Select");
1650        }
1651    }
1652
1653    #[test]
1654    fn test_mixed_columns_and_exprs() {
1655        let stmt = parse_query("SELECT title, a + b AS sum, count FROM test").unwrap();
1656        if let Statement::Select(q) = stmt {
1657            if let ColumnList::Named(exprs) = &q.columns {
1658                assert_eq!(exprs.len(), 3);
1659                assert_eq!(exprs[0], SelectExpr::Column("title".into()));
1660                assert!(matches!(&exprs[1], SelectExpr::Expr { alias: Some(a), .. } if a == "sum"));
1661                assert_eq!(exprs[2], SelectExpr::Column("count".into()));
1662            } else {
1663                panic!("Expected Named columns");
1664            }
1665        } else {
1666            panic!("Expected Select");
1667        }
1668    }
1669
1670    // ── CASE WHEN tests ──────────────────────────────────────────
1671
1672    #[test]
1673    fn test_case_when_basic() {
1674        let stmt = parse_query(
1675            "SELECT CASE WHEN status = 'ACTIVE' THEN 1 ELSE 0 END FROM test"
1676        ).unwrap();
1677        if let Statement::Select(q) = stmt {
1678            if let ColumnList::Named(exprs) = &q.columns {
1679                assert_eq!(exprs.len(), 1);
1680                assert!(matches!(&exprs[0], SelectExpr::Expr {
1681                    expr: Expr::Case { .. },
1682                    ..
1683                }));
1684            } else {
1685                panic!("Expected Named columns");
1686            }
1687        } else {
1688            panic!("Expected Select");
1689        }
1690    }
1691
1692    #[test]
1693    fn test_case_when_multiple_branches() {
1694        let stmt = parse_query(
1695            "SELECT CASE WHEN x > 10 THEN 'high' WHEN x > 5 THEN 'mid' ELSE 'low' END FROM test"
1696        ).unwrap();
1697        if let Statement::Select(q) = stmt {
1698            if let ColumnList::Named(exprs) = &q.columns {
1699                if let SelectExpr::Expr { expr: Expr::Case { whens, else_expr }, .. } = &exprs[0] {
1700                    assert_eq!(whens.len(), 2);
1701                    assert!(else_expr.is_some());
1702                } else {
1703                    panic!("Expected Case expression");
1704                }
1705            } else {
1706                panic!("Expected Named columns");
1707            }
1708        } else {
1709            panic!("Expected Select");
1710        }
1711    }
1712
1713    #[test]
1714    fn test_case_when_no_else() {
1715        let stmt = parse_query(
1716            "SELECT CASE WHEN x = 1 THEN 'one' END FROM test"
1717        ).unwrap();
1718        if let Statement::Select(q) = stmt {
1719            if let ColumnList::Named(exprs) = &q.columns {
1720                if let SelectExpr::Expr { expr: Expr::Case { whens, else_expr }, .. } = &exprs[0] {
1721                    assert_eq!(whens.len(), 1);
1722                    assert!(else_expr.is_none());
1723                } else {
1724                    panic!("Expected Case expression");
1725                }
1726            } else {
1727                panic!("Expected Named columns");
1728            }
1729        } else {
1730            panic!("Expected Select");
1731        }
1732    }
1733
1734    #[test]
1735    fn test_case_when_in_aggregate() {
1736        let stmt = parse_query(
1737            "SELECT SUM(CASE WHEN side = 'BUY' THEN size ELSE -size END) AS net FROM orders GROUP BY token"
1738        ).unwrap();
1739        if let Statement::Select(q) = stmt {
1740            if let ColumnList::Named(exprs) = &q.columns {
1741                assert_eq!(exprs.len(), 1);
1742                assert!(matches!(&exprs[0], SelectExpr::Aggregate {
1743                    func: AggFunc::Sum,
1744                    arg_expr: Some(Expr::Case { .. }),
1745                    alias: Some(a),
1746                    ..
1747                } if a == "net"));
1748            } else {
1749                panic!("Expected Named columns");
1750            }
1751        } else {
1752            panic!("Expected Select");
1753        }
1754    }
1755
1756    #[test]
1757    fn test_case_when_with_alias() {
1758        let stmt = parse_query(
1759            "SELECT CASE WHEN x > 0 THEN 'pos' ELSE 'neg' END AS sign FROM test"
1760        ).unwrap();
1761        if let Statement::Select(q) = stmt {
1762            if let ColumnList::Named(exprs) = &q.columns {
1763                assert!(matches!(&exprs[0], SelectExpr::Expr {
1764                    expr: Expr::Case { .. },
1765                    alias: Some(a),
1766                } if a == "sign"));
1767            } else {
1768                panic!("Expected Named columns");
1769            }
1770        } else {
1771            panic!("Expected Select");
1772        }
1773    }
1774
1775    #[test]
1776    fn test_create_view() {
1777        let stmt = parse_query("CREATE VIEW live AS SELECT * FROM strategies WHERE status = 'LIVE'").unwrap();
1778        if let Statement::CreateView(cv) = stmt {
1779            assert_eq!(cv.view_name, "live");
1780            assert!(cv.columns.is_none());
1781            assert_eq!(cv.query.table, "strategies");
1782            assert!(cv.query.where_clause.is_some());
1783        } else {
1784            panic!("Expected CreateView, got {:?}", stmt);
1785        }
1786    }
1787
1788    #[test]
1789    fn test_create_view_with_columns() {
1790        let stmt = parse_query("CREATE VIEW v1 (a, b) AS SELECT title, status FROM t").unwrap();
1791        if let Statement::CreateView(cv) = stmt {
1792            assert_eq!(cv.view_name, "v1");
1793            assert_eq!(cv.columns, Some(vec!["a".into(), "b".into()]));
1794        } else {
1795            panic!("Expected CreateView");
1796        }
1797    }
1798
1799    #[test]
1800    fn test_drop_view() {
1801        let stmt = parse_query("DROP VIEW live").unwrap();
1802        if let Statement::DropView(dv) = stmt {
1803            assert_eq!(dv.view_name, "live");
1804        } else {
1805            panic!("Expected DropView, got {:?}", stmt);
1806        }
1807    }
1808
1809    #[test]
1810    fn test_create_view_case_insensitive() {
1811        let stmt = parse_query("create view My_View as select * from t").unwrap();
1812        if let Statement::CreateView(cv) = stmt {
1813            assert_eq!(cv.view_name, "My_View");
1814        } else {
1815            panic!("Expected CreateView");
1816        }
1817    }
1818
1819    // ── Issue #42: Arithmetic between aggregates in column expressions ──
1820
1821    #[test]
1822    fn test_aggregate_division() {
1823        let stmt = parse_query(
1824            "SELECT token, SUM(sell) / SUM(buy) as ratio FROM orders GROUP BY token"
1825        ).unwrap();
1826        if let Statement::Select(q) = stmt {
1827            assert_eq!(q.group_by, Some(vec!["token".into()]));
1828            if let ColumnList::Named(exprs) = &q.columns {
1829                assert_eq!(exprs.len(), 2);
1830                assert!(exprs[1].is_aggregate());
1831            } else {
1832                panic!("Expected Named columns");
1833            }
1834        } else {
1835            panic!("Expected Select");
1836        }
1837    }
1838
1839    #[test]
1840    fn test_aggregate_subtraction() {
1841        let stmt = parse_query(
1842            "SELECT token, SUM(sell) - SUM(buy) as net FROM orders GROUP BY token"
1843        ).unwrap();
1844        if let Statement::Select(q) = stmt {
1845            if let ColumnList::Named(exprs) = &q.columns {
1846                assert_eq!(exprs[1].output_name(), "net");
1847            }
1848        } else {
1849            panic!("Expected Select");
1850        }
1851    }
1852
1853    #[test]
1854    fn test_create_view_with_arithmetic() {
1855        let stmt = parse_query(
1856            "CREATE VIEW positions AS SELECT token, SUM(sell) / SUM(buy) as ratio FROM orders GROUP BY token"
1857        ).unwrap();
1858        if let Statement::CreateView(cv) = stmt {
1859            assert_eq!(cv.view_name, "positions");
1860        } else {
1861            panic!("Expected CreateView, got {:?}", stmt);
1862        }
1863    }
1864
1865    // ── Issue #43: Subqueries in FROM ──
1866
1867    #[test]
1868    fn test_subquery_in_from() {
1869        let stmt = parse_query(
1870            "SELECT token, sell_size FROM (SELECT token, SUM(size) as sell_size FROM orders GROUP BY token) LIMIT 5"
1871        ).unwrap();
1872        if let Statement::Select(q) = stmt {
1873            assert!(q.subquery.is_some());
1874            assert_eq!(q.limit, Some(5));
1875            let sub = q.subquery.unwrap();
1876            assert_eq!(sub.table, "orders");
1877            assert!(sub.group_by.is_some());
1878        } else {
1879            panic!("Expected Select");
1880        }
1881    }
1882
1883    // ── Issue #44: HAVING in CREATE VIEW ──
1884
1885    #[test]
1886    fn test_create_view_with_having() {
1887        let stmt = parse_query(
1888            "CREATE VIEW positions AS SELECT token, SUM(sell) as sell_size, SUM(buy) as buy_size FROM orders GROUP BY token HAVING sell_size > buy_size"
1889        ).unwrap();
1890        if let Statement::CreateView(cv) = stmt {
1891            assert_eq!(cv.view_name, "positions");
1892            assert!(cv.query.having.is_some());
1893        } else {
1894            panic!("Expected CreateView, got {:?}", stmt);
1895        }
1896    }
1897
1898    // ── Issue #42: Aggregate multiplication ──
1899
1900    #[test]
1901    fn test_aggregate_multiplication() {
1902        let stmt = parse_query(
1903            "SELECT SUM(a) * 2 as doubled FROM test"
1904        ).unwrap();
1905        if let Statement::Select(q) = stmt {
1906            if let ColumnList::Named(exprs) = &q.columns {
1907                assert_eq!(exprs.len(), 1);
1908                assert!(exprs[0].is_aggregate());
1909                assert_eq!(exprs[0].output_name(), "doubled");
1910            } else {
1911                panic!("Expected Named columns");
1912            }
1913        } else {
1914            panic!("Expected Select");
1915        }
1916    }
1917
1918    #[test]
1919    fn test_complex_aggregate_arithmetic() {
1920        let stmt = parse_query(
1921            "SELECT SUM(CASE WHEN side = 'SELL' THEN size ELSE 0 END) / SUM(CASE WHEN side = 'BUY' THEN size ELSE 0 END) as ratio FROM orders GROUP BY token"
1922        ).unwrap();
1923        if let Statement::Select(q) = stmt {
1924            if let ColumnList::Named(exprs) = &q.columns {
1925                assert_eq!(exprs.len(), 1);
1926                assert!(exprs[0].is_aggregate());
1927                assert_eq!(exprs[0].output_name(), "ratio");
1928            } else {
1929                panic!("Expected Named columns");
1930            }
1931            assert_eq!(q.group_by, Some(vec!["token".into()]));
1932        } else {
1933            panic!("Expected Select");
1934        }
1935    }
1936
1937    // ── Issue #43: Subquery with alias and WHERE ──
1938
1939    #[test]
1940    fn test_subquery_with_alias() {
1941        let stmt = parse_query(
1942            "SELECT x FROM (SELECT x FROM t) sub"
1943        ).unwrap();
1944        if let Statement::Select(q) = stmt {
1945            assert!(q.subquery.is_some());
1946            let sub = q.subquery.unwrap();
1947            assert_eq!(sub.table, "t");
1948            if let ColumnList::Named(exprs) = &q.columns {
1949                assert_eq!(exprs.len(), 1);
1950                assert_eq!(exprs[0].output_name(), "x");
1951            } else {
1952                panic!("Expected Named columns");
1953            }
1954        } else {
1955            panic!("Expected Select");
1956        }
1957    }
1958
1959    #[test]
1960    fn test_subquery_with_where() {
1961        let stmt = parse_query(
1962            "SELECT x FROM (SELECT x FROM t WHERE y > 0) LIMIT 5"
1963        ).unwrap();
1964        if let Statement::Select(q) = stmt {
1965            assert!(q.subquery.is_some());
1966            assert_eq!(q.limit, Some(5));
1967            let sub = q.subquery.unwrap();
1968            assert_eq!(sub.table, "t");
1969            assert!(sub.where_clause.is_some());
1970        } else {
1971            panic!("Expected Select");
1972        }
1973    }
1974
1975    // ── Issue #42 + CREATE VIEW: aggregate subtraction in view ──
1976
1977    #[test]
1978    fn test_create_view_aggregate_subtraction() {
1979        let stmt = parse_query(
1980            "CREATE VIEW v AS SELECT token, SUM(sell) - SUM(buy) as net FROM orders GROUP BY token"
1981        ).unwrap();
1982        if let Statement::CreateView(cv) = stmt {
1983            assert_eq!(cv.view_name, "v");
1984            assert_eq!(cv.query.group_by, Some(vec!["token".into()]));
1985            if let ColumnList::Named(exprs) = &cv.query.columns {
1986                assert_eq!(exprs.len(), 2);
1987                assert_eq!(exprs[1].output_name(), "net");
1988                assert!(exprs[1].is_aggregate());
1989            } else {
1990                panic!("Expected Named columns");
1991            }
1992        } else {
1993            panic!("Expected CreateView, got {:?}", stmt);
1994        }
1995    }
1996
1997    #[test]
1998    fn test_delete_cascade() {
1999        let stmt = parse_query("DELETE FROM strategies WHERE status = 'KILLED' CASCADE").unwrap();
2000        if let Statement::Delete(q) = stmt {
2001            assert_eq!(q.table, "strategies");
2002            assert!(q.where_clause.is_some());
2003            assert_eq!(q.mode, DeleteMode::Cascade);
2004        } else {
2005            panic!("Expected Delete");
2006        }
2007    }
2008
2009    #[test]
2010    fn test_delete_restrict() {
2011        let stmt = parse_query("DELETE FROM strategies WHERE path = 'alpha.md' RESTRICT").unwrap();
2012        if let Statement::Delete(q) = stmt {
2013            assert_eq!(q.table, "strategies");
2014            assert_eq!(q.mode, DeleteMode::Restrict);
2015        } else {
2016            panic!("Expected Delete");
2017        }
2018    }
2019
2020    #[test]
2021    fn test_delete_default_unchanged() {
2022        let stmt = parse_query("DELETE FROM strategies WHERE status = 'KILLED'").unwrap();
2023        if let Statement::Delete(q) = stmt {
2024            assert_eq!(q.mode, DeleteMode::Default);
2025        } else {
2026            panic!("Expected Delete");
2027        }
2028    }
2029
2030    #[test]
2031    fn test_delete_cascade_no_where() {
2032        let stmt = parse_query("DELETE FROM strategies CASCADE").unwrap();
2033        if let Statement::Delete(q) = stmt {
2034            assert_eq!(q.table, "strategies");
2035            assert!(q.where_clause.is_none());
2036            assert_eq!(q.mode, DeleteMode::Cascade);
2037        } else {
2038            panic!("Expected Delete");
2039        }
2040    }
2041
2042    // ── CTE (WITH) tests ──────────────────────────────────────────
2043
2044    #[test]
2045    fn test_cte_basic() {
2046        let stmt = parse_query(
2047            "WITH live AS (SELECT * FROM strategies WHERE status = 'LIVE') SELECT * FROM live"
2048        ).unwrap();
2049        if let Statement::Select(q) = stmt {
2050            assert_eq!(q.ctes.len(), 1);
2051            assert_eq!(q.ctes[0].name, "live");
2052            assert_eq!(q.ctes[0].query.table, "strategies");
2053            assert!(q.ctes[0].query.where_clause.is_some());
2054            assert_eq!(q.table, "live");
2055        } else {
2056            panic!("Expected Select");
2057        }
2058    }
2059
2060    #[test]
2061    fn test_cte_multi() {
2062        let stmt = parse_query(
2063            "WITH a AS (SELECT * FROM t1), b AS (SELECT * FROM t2) SELECT * FROM a JOIN b ON a.id = b.id"
2064        ).unwrap();
2065        if let Statement::Select(q) = stmt {
2066            assert_eq!(q.ctes.len(), 2);
2067            assert_eq!(q.ctes[0].name, "a");
2068            assert_eq!(q.ctes[0].query.table, "t1");
2069            assert_eq!(q.ctes[1].name, "b");
2070            assert_eq!(q.ctes[1].query.table, "t2");
2071            assert_eq!(q.table, "a");
2072            assert_eq!(q.joins.len(), 1);
2073        } else {
2074            panic!("Expected Select");
2075        }
2076    }
2077
2078    #[test]
2079    fn test_cte_with_aggregation() {
2080        let stmt = parse_query(
2081            "WITH totals AS (SELECT strategy, COUNT(*) AS cnt FROM backtests GROUP BY strategy) SELECT * FROM totals WHERE cnt > 1"
2082        ).unwrap();
2083        if let Statement::Select(q) = stmt {
2084            assert_eq!(q.ctes.len(), 1);
2085            assert_eq!(q.ctes[0].name, "totals");
2086            assert!(q.ctes[0].query.group_by.is_some());
2087            assert_eq!(q.table, "totals");
2088            assert!(q.where_clause.is_some());
2089        } else {
2090            panic!("Expected Select");
2091        }
2092    }
2093
2094    #[test]
2095    fn test_cte_no_ctes_on_plain_select() {
2096        let stmt = parse_query("SELECT * FROM t").unwrap();
2097        if let Statement::Select(q) = stmt {
2098            assert!(q.ctes.is_empty());
2099        } else {
2100            panic!("Expected Select");
2101        }
2102    }
2103}