Skip to main content

mdql_core/
query_parser.rs

1//! Hand-written recursive descent parser for the MDQL SQL subset.
2
3use regex::Regex;
4use std::sync::LazyLock;
5
6use crate::errors::MdqlError;
7pub use crate::query_ast::*;
8
9// ── Tokenizer ──────────────────────────────────────────────────────────────
10
/// Reserved words recognised by the tokenizer.
///
/// `tokenize` upper-cases every bare word and, if the result is listed here,
/// emits it as a `keyword` token (carrying the upper-cased spelling) instead of
/// an `ident`. Membership is all that matters; the grouping below is only for
/// readability and the order is kept stable.
static KEYWORDS: &[&str] = &[
    // Query clauses, sort direction, predicates and boolean logic
    "SELECT", "FROM", "WHERE", "AND", "OR", "ORDER", "BY",
    "ASC", "DESC", "LIMIT", "LIKE", "IN", "IS", "NOT", "NULL",
    // Joins, aliases and grouping
    "JOIN", "LEFT", "ON", "AS", "GROUP", "HAVING",
    // Data modification
    "INSERT", "INTO", "VALUES", "UPDATE", "SET", "DELETE",
    // ALTER TABLE field operations
    "ALTER", "TABLE", "RENAME", "FIELD", "TO", "DROP", "MERGE", "FIELDS",
    // CASE expressions
    "CASE", "WHEN", "THEN", "ELSE", "END",
    // Date arithmetic
    "INTERVAL", "DAY", "DAYS", "CURRENT_DATE", "CURRENT_TIMESTAMP", "DATEDIFF",
    // Views and DELETE modes
    "CREATE", "VIEW", "CASCADE", "RESTRICT",
    // Common table expressions
    "WITH",
];
22
/// Aggregate function names understood by the parser.
///
/// A name from this list is only treated as an aggregate call when the next
/// token is `(`; otherwise it is an ordinary column reference.
static AGG_FUNCS: &[&str] = &[
    "COUNT",
    "SUM",
    "AVG",
    "MIN",
    "MAX",
];
24
/// Lexer pattern used by `tokenize`.
///
/// Written in `(?x)` verbose mode, so whitespace inside the pattern is
/// insignificant. Each match first skips leading whitespace (`\s*`), then
/// captures exactly one token in one of these named groups, tried in order:
///
/// * `backtick` – a non-empty `` `quoted identifier` `` with no backtick inside;
/// * `string`   – a single-quoted literal in which a backslash escapes the
///   following character;
/// * `number`   – an unsigned integer or decimal (`12`, `3.5`); no sign or exponent;
/// * `op`       – `<=`, `>=`, `!=`, or one of `= < > , * ( ) + - / %`;
/// * `word`     – a letter or `_` followed by letters, digits, `_`, `.`, `/` or `-`.
///
/// Because `word` may contain `.`, `/` and `-`, input such as `t.col` or `a-b`
/// lexes as a single word. Text that no alternative matches (e.g. `;` or a lone
/// `!`) is not an error here: `captures_iter` simply skips ahead to the next match.
static TOKEN_RE: LazyLock<Regex> = LazyLock::new(|| {
    Regex::new(
        r#"(?x)
        \s*(?:
            (?P<backtick>`[^`]+`)
            | (?P<string>'(?:[^'\\]|\\.)*')
            | (?P<number>\d+(?:\.\d+)?)
            | (?P<op><=|>=|!=|[=<>,*()+\-/%])
            | (?P<word>[A-Za-z_][A-Za-z0-9_./-]*)
        )"#,
    )
    // The pattern is a compile-time constant; failure here is a programming error.
    .unwrap()
});
38
/// One lexical token produced by `tokenize`.
#[derive(Debug, Clone)]
struct Token {
    /// Token category. `tokenize` assigns one of `"ident"`, `"string"`,
    /// `"number"`, `"op"` or `"keyword"`.
    token_type: String,
    /// Normalised text. Keywords are upper-cased; backtick identifiers and
    /// string literals have their delimiters stripped (escape sequences are
    /// left as written). For all other tokens this equals `raw`.
    value: String,
    /// The exact source text of the token, as used in parse error messages.
    raw: String,
}
45
46fn tokenize(sql: &str) -> Vec<Token> {
47    let mut tokens = Vec::new();
48    for caps in TOKEN_RE.captures_iter(sql) {
49        if let Some(m) = caps.name("backtick") {
50            let raw = m.as_str();
51            tokens.push(Token {
52                token_type: "ident".into(),
53                value: raw[1..raw.len() - 1].into(),
54                raw: raw.into(),
55            });
56        } else if let Some(m) = caps.name("string") {
57            let raw = m.as_str();
58            tokens.push(Token {
59                token_type: "string".into(),
60                value: raw[1..raw.len() - 1].into(),
61                raw: raw.into(),
62            });
63        } else if let Some(m) = caps.name("number") {
64            let raw = m.as_str();
65            tokens.push(Token {
66                token_type: "number".into(),
67                value: raw.into(),
68                raw: raw.into(),
69            });
70        } else if let Some(m) = caps.name("op") {
71            let raw = m.as_str();
72            tokens.push(Token {
73                token_type: "op".into(),
74                value: raw.into(),
75                raw: raw.into(),
76            });
77        } else if let Some(m) = caps.name("word") {
78            let raw = m.as_str();
79            if KEYWORDS.contains(&raw.to_uppercase().as_str()) {
80                tokens.push(Token {
81                    token_type: "keyword".into(),
82                    value: raw.to_uppercase(),
83                    raw: raw.into(),
84                });
85            } else {
86                tokens.push(Token {
87                    token_type: "ident".into(),
88                    value: raw.into(),
89                    raw: raw.into(),
90                });
91            }
92        }
93    }
94    tokens
95}
96
97// ── Parser ─────────────────────────────────────────────────────────────────
98
/// Recursive-descent parser state: the token stream plus a cursor into it.
struct Parser {
    /// All tokens of the statement, in source order.
    tokens: Vec<Token>,
    /// Index of the next unconsumed token in `tokens`.
    pos: usize,
}
103
104impl Parser {
105    fn new(tokens: Vec<Token>) -> Self {
106        Parser { tokens, pos: 0 }
107    }
108
109    fn peek(&self) -> Option<&Token> {
110        self.tokens.get(self.pos)
111    }
112
    /// Consume the token under the cursor and return an owned copy of it.
    ///
    /// # Panics
    ///
    /// Indexes `self.tokens` directly, so it panics if every token has
    /// already been consumed. Callers are expected to confirm a token exists
    /// (via `peek()` or an equivalent lookahead) before calling this.
    fn advance(&mut self) -> Token {
        let t = self.tokens[self.pos].clone();
        self.pos += 1;
        t
    }
118
119    fn expect(&mut self, type_: &str, value: Option<&str>) -> Result<Token, MdqlError> {
120        let t = self.peek().ok_or_else(|| {
121            MdqlError::QueryParse(format!(
122                "Unexpected end of query, expected {}",
123                value.unwrap_or(type_)
124            ))
125        })?;
126        let matches_type = t.token_type == type_;
127        let matches_value = value.map_or(true, |v| t.value == v);
128        if !matches_type || !matches_value {
129            return Err(MdqlError::QueryParse(format!(
130                "Expected {}, got '{}' at position {}",
131                value.unwrap_or(type_),
132                t.raw,
133                self.pos
134            )));
135        }
136        Ok(self.advance())
137    }
138
139    fn match_keyword(&mut self, kw: &str) -> bool {
140        if let Some(t) = self.peek() {
141            if t.token_type == "keyword" && t.value == kw {
142                self.advance();
143                return true;
144            }
145        }
146        false
147    }
148
149    fn parse_statement(&mut self) -> Result<Statement, MdqlError> {
150        let t = self.peek().ok_or_else(|| MdqlError::QueryParse("Empty query".into()))?;
151        match (t.token_type.as_str(), t.value.as_str()) {
152            ("keyword", "WITH") => {
153                let ctes = self.parse_ctes()?;
154                let mut q = self.parse_select()?;
155                q.ctes = ctes;
156                self.expect_end()?;
157                Ok(Statement::Select(q))
158            }
159            ("keyword", "SELECT") => {
160                let q = self.parse_select()?;
161                self.expect_end()?;
162                Ok(Statement::Select(q))
163            }
164            ("keyword", "INSERT") => Ok(Statement::Insert(self.parse_insert()?)),
165            ("keyword", "UPDATE") => Ok(Statement::Update(self.parse_update()?)),
166            ("keyword", "DELETE") => Ok(Statement::Delete(self.parse_delete()?)),
167            ("keyword", "ALTER") => self.parse_alter(),
168            ("keyword", "CREATE") => self.parse_create_view(),
169            ("keyword", "DROP") => self.parse_drop_view(),
170            _ => Err(MdqlError::QueryParse(format!(
171                "Expected SELECT, INSERT, UPDATE, DELETE, ALTER, CREATE, or DROP, got '{}'",
172                t.raw
173            ))),
174        }
175    }
176
177    fn parse_ctes(&mut self) -> Result<Vec<CteClause>, MdqlError> {
178        self.expect("keyword", Some("WITH"))?;
179        let mut ctes = Vec::new();
180        loop {
181            let name = self.parse_ident()?;
182            self.expect("keyword", Some("AS"))?;
183            self.expect("op", Some("("))?;
184            let query = self.parse_select()?;
185            self.expect("op", Some(")"))?;
186            ctes.push(CteClause { name, query: Box::new(query) });
187            if !self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
188                break;
189            }
190            self.advance();
191        }
192        Ok(ctes)
193    }
194
195    fn parse_select(&mut self) -> Result<SelectQuery, MdqlError> {
196        self.expect("keyword", Some("SELECT"))?;
197        let columns = self.parse_columns()?;
198        self.expect("keyword", Some("FROM"))?;
199
200        // Subquery: FROM (SELECT ...)
201        let mut subquery = None;
202        let (table, mut table_alias) = if self.peek().map_or(false, |t| t.token_type == "op" && t.value == "(") {
203            self.advance();
204            let inner = self.parse_select()?;
205            self.expect("op", Some(")"))?;
206            subquery = Some(Box::new(inner));
207            let alias = if let Some(t) = self.peek() {
208                if t.token_type == "ident" && !self.is_clause_keyword(t) {
209                    Some(self.advance().value)
210                } else {
211                    None
212                }
213            } else {
214                None
215            };
216            ("_subquery".to_string(), alias)
217        } else {
218            let t = self.parse_ident()?;
219            (t, None)
220        };
221
222        // Optional table alias (for non-subquery)
223        if subquery.is_none() {
224            if let Some(t) = self.peek() {
225                if t.token_type == "ident" && !self.is_clause_keyword(t) {
226                    table_alias = Some(self.advance().value);
227                }
228            }
229        }
230
231        // Optional JOIN(s)
232        let mut joins = Vec::new();
233        loop {
234            let jt = if self.match_keyword("LEFT") {
235                self.expect("keyword", Some("JOIN"))?;
236                JoinType::Left
237            } else if self.match_keyword("JOIN") {
238                JoinType::Inner
239            } else {
240                break;
241            };
242            let join_table = self.parse_ident()?;
243            let mut join_alias = None;
244            if let Some(t) = self.peek() {
245                if t.token_type == "ident" && !self.is_clause_keyword(t) {
246                    join_alias = Some(self.advance().value);
247                }
248            }
249            self.expect("keyword", Some("ON"))?;
250            let condition = self.parse_or_expr()?;
251            joins.push(JoinClause {
252                join_type: jt,
253                table: join_table,
254                alias: join_alias,
255                condition,
256            });
257        }
258
259        let mut where_clause = None;
260        if self.match_keyword("WHERE") {
261            where_clause = Some(self.parse_or_expr()?);
262        }
263
264        let mut group_by = None;
265        if self.match_keyword("GROUP") {
266            self.expect("keyword", Some("BY"))?;
267            let mut cols = vec![self.parse_ident()?];
268            while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
269                self.advance();
270                cols.push(self.parse_ident()?);
271            }
272            group_by = Some(cols);
273        }
274
275        let mut having = None;
276        if self.match_keyword("HAVING") {
277            having = Some(self.parse_or_expr()?);
278        }
279
280        let mut order_by = None;
281        if self.match_keyword("ORDER") {
282            self.expect("keyword", Some("BY"))?;
283            order_by = Some(self.parse_order_by()?);
284        }
285
286        let mut limit = None;
287        if self.match_keyword("LIMIT") {
288            let t = self.expect("number", None)?;
289            limit = Some(t.value.parse::<i64>().map_err(|_| {
290                MdqlError::QueryParse(format!("Invalid LIMIT value: {}", t.value))
291            })?);
292        }
293
294        Ok(SelectQuery {
295            columns,
296            table,
297            table_alias,
298            subquery,
299            joins,
300            where_clause,
301            group_by,
302            having,
303            order_by,
304            limit,
305            ctes: vec![],
306        })
307    }
308
309    fn parse_insert(&mut self) -> Result<InsertQuery, MdqlError> {
310        self.expect("keyword", Some("INSERT"))?;
311        self.expect("keyword", Some("INTO"))?;
312        let table = self.parse_ident()?;
313
314        self.expect("op", Some("("))?;
315        let mut columns = vec![self.parse_ident()?];
316        while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
317            self.advance();
318            columns.push(self.parse_ident()?);
319        }
320        self.expect("op", Some(")"))?;
321
322        self.expect("keyword", Some("VALUES"))?;
323
324        self.expect("op", Some("("))?;
325        let mut values = vec![self.parse_value()?];
326        while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
327            self.advance();
328            values.push(self.parse_value()?);
329        }
330        self.expect("op", Some(")"))?;
331
332        if columns.len() != values.len() {
333            return Err(MdqlError::QueryParse(format!(
334                "Column count ({}) does not match value count ({})",
335                columns.len(),
336                values.len()
337            )));
338        }
339
340        self.expect_end()?;
341        Ok(InsertQuery {
342            table,
343            columns,
344            values,
345        })
346    }
347
348    fn parse_update(&mut self) -> Result<UpdateQuery, MdqlError> {
349        self.expect("keyword", Some("UPDATE"))?;
350        let table = self.parse_ident()?;
351        self.expect("keyword", Some("SET"))?;
352
353        let mut assignments = Vec::new();
354        let col = self.parse_ident()?;
355        self.expect("op", Some("="))?;
356        let val = self.parse_value()?;
357        assignments.push((col, val));
358
359        while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
360            self.advance();
361            let col = self.parse_ident()?;
362            self.expect("op", Some("="))?;
363            let val = self.parse_value()?;
364            assignments.push((col, val));
365        }
366
367        let mut where_clause = None;
368        if self.match_keyword("WHERE") {
369            where_clause = Some(self.parse_or_expr()?);
370        }
371
372        self.expect_end()?;
373        Ok(UpdateQuery {
374            table,
375            assignments,
376            where_clause,
377        })
378    }
379
380    fn parse_delete(&mut self) -> Result<DeleteQuery, MdqlError> {
381        self.expect("keyword", Some("DELETE"))?;
382        self.expect("keyword", Some("FROM"))?;
383        let table = self.parse_ident()?;
384
385        let mut where_clause = None;
386        if self.match_keyword("WHERE") {
387            where_clause = Some(self.parse_or_expr()?);
388        }
389
390        let mode = if self.match_keyword("CASCADE") {
391            DeleteMode::Cascade
392        } else if self.match_keyword("RESTRICT") {
393            DeleteMode::Restrict
394        } else {
395            DeleteMode::Default
396        };
397
398        self.expect_end()?;
399        Ok(DeleteQuery {
400            table,
401            where_clause,
402            mode,
403        })
404    }
405
406    fn parse_alter(&mut self) -> Result<Statement, MdqlError> {
407        self.expect("keyword", Some("ALTER"))?;
408        self.expect("keyword", Some("TABLE"))?;
409        let table = self.parse_ident()?;
410
411        let t = self.peek().ok_or_else(|| {
412            MdqlError::QueryParse("Expected RENAME, DROP, or MERGE after table name".into())
413        })?;
414
415        match (t.token_type.as_str(), t.value.as_str()) {
416            ("keyword", "RENAME") => {
417                self.advance();
418                self.expect("keyword", Some("FIELD"))?;
419                let old_name = self.parse_string_or_ident()?;
420                self.expect("keyword", Some("TO"))?;
421                let new_name = self.parse_string_or_ident()?;
422                self.expect_end()?;
423                Ok(Statement::AlterRename(AlterRenameFieldQuery {
424                    table,
425                    old_name,
426                    new_name,
427                }))
428            }
429            ("keyword", "DROP") => {
430                self.advance();
431                self.expect("keyword", Some("FIELD"))?;
432                let field_name = self.parse_string_or_ident()?;
433                self.expect_end()?;
434                Ok(Statement::AlterDrop(AlterDropFieldQuery {
435                    table,
436                    field_name,
437                }))
438            }
439            ("keyword", "MERGE") => {
440                self.advance();
441                self.expect("keyword", Some("FIELDS"))?;
442                let mut sources = vec![self.parse_string_or_ident()?];
443                while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
444                    self.advance();
445                    sources.push(self.parse_string_or_ident()?);
446                }
447                self.expect("keyword", Some("INTO"))?;
448                let target = self.parse_string_or_ident()?;
449                self.expect_end()?;
450                Ok(Statement::AlterMerge(AlterMergeFieldsQuery {
451                    table,
452                    sources,
453                    into: target,
454                }))
455            }
456            _ => Err(MdqlError::QueryParse(format!(
457                "Expected RENAME, DROP, or MERGE, got '{}'",
458                t.raw
459            ))),
460        }
461    }
462
463    fn parse_create_view(&mut self) -> Result<Statement, MdqlError> {
464        self.expect("keyword", Some("CREATE"))?;
465        self.expect("keyword", Some("VIEW"))?;
466        let view_name = self.parse_ident()?;
467
468        let columns = if self.peek().map_or(false, |t| t.token_type == "op" && t.value == "(") {
469            self.advance();
470            let mut cols = vec![self.parse_ident()?];
471            while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
472                self.advance();
473                cols.push(self.parse_ident()?);
474            }
475            self.expect("op", Some(")"))?;
476            Some(cols)
477        } else {
478            None
479        };
480
481        self.expect("keyword", Some("AS"))?;
482        let query = Box::new(self.parse_select()?);
483        self.expect_end()?;
484
485        Ok(Statement::CreateView(CreateViewQuery {
486            view_name,
487            columns,
488            query,
489        }))
490    }
491
492    fn parse_drop_view(&mut self) -> Result<Statement, MdqlError> {
493        self.expect("keyword", Some("DROP"))?;
494        self.expect("keyword", Some("VIEW"))?;
495        let view_name = self.parse_ident()?;
496        self.expect_end()?;
497        Ok(Statement::DropView(DropViewQuery { view_name }))
498    }
499
500    fn parse_string_or_ident(&mut self) -> Result<String, MdqlError> {
501        let t = self.peek().ok_or_else(|| {
502            MdqlError::QueryParse("Expected field name, got end of query".into())
503        })?;
504        match t.token_type.as_str() {
505            "string" => {
506                let v = self.advance().value;
507                Ok(v)
508            }
509            "ident" | "keyword" => {
510                let v = self.advance().value;
511                Ok(v)
512            }
513            _ => Err(MdqlError::QueryParse(format!(
514                "Expected field name, got '{}'",
515                t.raw
516            ))),
517        }
518    }
519
520    fn parse_columns(&mut self) -> Result<ColumnList, MdqlError> {
521        if let Some(t) = self.peek() {
522            if t.token_type == "op" && t.value == "*" {
523                self.advance();
524                return Ok(ColumnList::All);
525            }
526        }
527
528        let mut exprs = vec![self.parse_select_expr()?];
529        while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
530            self.advance();
531            exprs.push(self.parse_select_expr()?);
532        }
533        Ok(ColumnList::Named(exprs))
534    }
535
536    fn peek_is_agg_func(&self) -> bool {
537        let t = match self.peek() {
538            Some(t) => t,
539            None => return false,
540        };
541        let name_upper = t.value.to_uppercase();
542        if !AGG_FUNCS.contains(&name_upper.as_str()) {
543            return false;
544        }
545        // Only treat as aggregate if followed by (
546        self.tokens
547            .get(self.pos + 1)
548            .map_or(false, |next| next.token_type == "op" && next.value == "(")
549    }
550
551    fn parse_select_expr(&mut self) -> Result<SelectExpr, MdqlError> {
552        let _t = self.peek().ok_or_else(|| {
553            MdqlError::QueryParse("Expected column or aggregate, got end of query".into())
554        })?;
555
556        let expr = self.parse_additive()?;
557
558        let alias = if self.match_keyword("AS") {
559            Some(self.parse_ident()?)
560        } else if self.peek().map_or(false, |t| {
561            t.token_type == "ident" && !self.is_clause_keyword(t)
562        }) {
563            Some(self.advance().value)
564        } else {
565            None
566        };
567
568        // Bare aggregate → SelectExpr::Aggregate for backward compat
569        if let Expr::Aggregate { func, arg, arg_expr } = expr {
570            return Ok(SelectExpr::Aggregate {
571                func,
572                arg,
573                arg_expr: arg_expr.map(|e| *e),
574                alias,
575            });
576        }
577
578        if alias.is_none() {
579            if let Expr::Column(name) = &expr {
580                return Ok(SelectExpr::Column(name.clone()));
581            }
582        }
583
584        Ok(SelectExpr::Expr { expr, alias })
585    }
586
587    // ── Expression parser (precedence climbing) ───────────────────────
588
589    fn peek_is_additive_op(&self) -> bool {
590        self.peek().map_or(false, |t| {
591            t.token_type == "op" && (t.value == "+" || t.value == "-")
592        })
593    }
594
595    fn peek_is_multiplicative_op(&self) -> bool {
596        self.peek().map_or(false, |t| {
597            t.token_type == "op" && (t.value == "*" || t.value == "/" || t.value == "%")
598        })
599    }
600
601    fn parse_additive(&mut self) -> Result<Expr, MdqlError> {
602        let mut left = self.parse_multiplicative()?;
603        while self.peek_is_additive_op() {
604            let op_tok = self.advance();
605            let is_sub = op_tok.value == "-";
606
607            // Check for INTERVAL keyword: expr +/- INTERVAL n DAY
608            if self.peek().map_or(false, |t| t.token_type == "keyword" && t.value == "INTERVAL") {
609                self.advance(); // consume INTERVAL
610                let days_expr = self.parse_multiplicative()?;
611                // Expect DAY or DAYS
612                if !self.match_keyword("DAY") && !self.match_keyword("DAYS") {
613                    return Err(MdqlError::QueryParse("Expected DAY after INTERVAL value".into()));
614                }
615                let days = if is_sub {
616                    Expr::UnaryMinus(Box::new(days_expr))
617                } else {
618                    days_expr
619                };
620                left = Expr::DateAdd {
621                    date: Box::new(left),
622                    days: Box::new(days),
623                };
624                continue;
625            }
626
627            let op = match op_tok.value.as_str() {
628                "+" => ArithOp::Add,
629                "-" => ArithOp::Sub,
630                _ => unreachable!(),
631            };
632            let right = self.parse_multiplicative()?;
633            left = Expr::BinaryOp {
634                left: Box::new(left),
635                op,
636                right: Box::new(right),
637            };
638        }
639        Ok(left)
640    }
641
642    fn parse_multiplicative(&mut self) -> Result<Expr, MdqlError> {
643        let mut left = self.parse_unary()?;
644        while self.peek_is_multiplicative_op() {
645            let op_tok = self.advance();
646            let op = match op_tok.value.as_str() {
647                "*" => ArithOp::Mul,
648                "/" => ArithOp::Div,
649                "%" => ArithOp::Mod,
650                _ => unreachable!(),
651            };
652            let right = self.parse_unary()?;
653            left = Expr::BinaryOp {
654                left: Box::new(left),
655                op,
656                right: Box::new(right),
657            };
658        }
659        Ok(left)
660    }
661
662    fn parse_unary(&mut self) -> Result<Expr, MdqlError> {
663        if self.peek().map_or(false, |t| t.token_type == "op" && t.value == "-") {
664            self.advance();
665            let inner = self.parse_atom()?;
666            // Fold unary minus on literals
667            match inner {
668                Expr::Literal(SqlValue::Int(n)) => Ok(Expr::Literal(SqlValue::Int(-n))),
669                Expr::Literal(SqlValue::Float(f)) => Ok(Expr::Literal(SqlValue::Float(-f))),
670                _ => Ok(Expr::UnaryMinus(Box::new(inner))),
671            }
672        } else {
673            self.parse_atom()
674        }
675    }
676
    /// Parse a primary expression.
    ///
    /// Handles, in this order:
    /// * an aggregate call, detected via `peek_is_agg_func`;
    /// * numeric and string literals, and `NULL`;
    /// * `CASE ... END`, `CURRENT_DATE`, `CURRENT_TIMESTAMP`, `DATEDIFF(a, b)`;
    /// * `(SELECT ...)` as a scalar subquery, or `( expr )` as grouping;
    /// * identifiers, and keywords that `is_reserved_keyword` does not
    ///   reserve, both as column references.
    ///
    /// The specific keyword arms must stay before the generic
    /// non-reserved-keyword arm, because match arms are tried top to bottom.
    fn parse_atom(&mut self) -> Result<Expr, MdqlError> {
        if self.peek_is_agg_func() {
            return self.parse_agg_expr();
        }

        let t = self.peek().ok_or_else(|| {
            MdqlError::QueryParse("Expected expression, got end of query".into())
        })?;

        match t.token_type.as_str() {
            "number" => {
                let v = self.advance().value;
                // TOKEN_RE only yields digits with an optional `.digits` part,
                // so a '.' distinguishes floats from integers. The i64 parse
                // can then only fail on overflow.
                if v.contains('.') {
                    let f: f64 = v.parse().map_err(|_| {
                        MdqlError::QueryParse(format!("Invalid float: {}", v))
                    })?;
                    Ok(Expr::Literal(SqlValue::Float(f)))
                } else {
                    let n: i64 = v.parse().map_err(|_| {
                        MdqlError::QueryParse(format!("Invalid int: {}", v))
                    })?;
                    Ok(Expr::Literal(SqlValue::Int(n)))
                }
            }
            "string" => {
                let v = self.advance().value;
                Ok(Expr::Literal(SqlValue::String(v)))
            }
            "keyword" if t.value == "NULL" => {
                self.advance();
                Ok(Expr::Literal(SqlValue::Null))
            }
            "keyword" if t.value == "CASE" => {
                // parse_case_expr consumes the CASE keyword itself.
                self.parse_case_expr()
            }
            "keyword" if t.value == "CURRENT_DATE" => {
                self.advance();
                Ok(Expr::CurrentDate)
            }
            "keyword" if t.value == "CURRENT_TIMESTAMP" => {
                self.advance();
                Ok(Expr::CurrentTimestamp)
            }
            "keyword" if t.value == "DATEDIFF" => {
                // DATEDIFF(left, right)
                self.advance();
                self.expect("op", Some("("))?;
                let left = self.parse_additive()?;
                self.expect("op", Some(","))?;
                let right = self.parse_additive()?;
                self.expect("op", Some(")"))?;
                Ok(Expr::DateDiff { left: Box::new(left), right: Box::new(right) })
            }
            "op" if t.value == "(" => {
                // One token of extra lookahead decides between a scalar
                // subquery and an ordinary parenthesised expression.
                let next_is_select = self.tokens.get(self.pos + 1)
                    .map_or(false, |t| t.token_type == "keyword" && t.value == "SELECT");
                if next_is_select {
                    self.advance();
                    let sq = self.parse_select()?;
                    self.expect("op", Some(")"))?;
                    Ok(Expr::Subquery(Box::new(sq)))
                } else {
                    self.advance();
                    let expr = self.parse_additive()?;
                    self.expect("op", Some(")"))?;
                    Ok(expr)
                }
            }
            "ident" => {
                let name = self.advance().value;
                Ok(Expr::Column(name))
            }
            // Keywords not reserved by is_reserved_keyword (defined outside
            // this view) may double as column names.
            "keyword" if !Self::is_reserved_keyword(&t.value) => {
                let name = self.advance().value;
                Ok(Expr::Column(name))
            }
            _ => Err(MdqlError::QueryParse(format!(
                "Expected expression, got '{}'",
                t.raw
            ))),
        }
    }
758
759    fn parse_case_expr(&mut self) -> Result<Expr, MdqlError> {
760        self.expect("keyword", Some("CASE"))?;
761        let mut whens = Vec::new();
762        while self.match_keyword("WHEN") {
763            let condition = self.parse_or_expr()?;
764            self.expect("keyword", Some("THEN"))?;
765            let result = self.parse_additive()?;
766            whens.push((condition, Box::new(result)));
767        }
768        if whens.is_empty() {
769            return Err(MdqlError::QueryParse("CASE requires at least one WHEN clause".into()));
770        }
771        let else_expr = if self.match_keyword("ELSE") {
772            Some(Box::new(self.parse_additive()?))
773        } else {
774            None
775        };
776        self.expect("keyword", Some("END"))?;
777        Ok(Expr::Case { whens, else_expr })
778    }
779
780    fn parse_agg_expr(&mut self) -> Result<Expr, MdqlError> {
781        let func_name = self.advance().value.to_uppercase();
782        let func = match func_name.as_str() {
783            "COUNT" => AggFunc::Count,
784            "SUM" => AggFunc::Sum,
785            "AVG" => AggFunc::Avg,
786            "MIN" => AggFunc::Min,
787            "MAX" => AggFunc::Max,
788            _ => unreachable!(),
789        };
790        self.expect("op", Some("("))?;
791        let (arg, arg_expr) = if self.peek().map_or(false, |t| t.token_type == "op" && t.value == "*") {
792            self.advance();
793            ("*".to_string(), None)
794        } else {
795            let expr = self.parse_additive()?;
796            if let Expr::Column(name) = &expr {
797                (name.clone(), None)
798            } else {
799                (expr.display_name(), Some(Box::new(expr)))
800            }
801        };
802        self.expect("op", Some(")"))?;
803        Ok(Expr::Aggregate { func, arg, arg_expr })
804    }
805
806    fn parse_ident(&mut self) -> Result<String, MdqlError> {
807        let t = self.peek().ok_or_else(|| {
808            MdqlError::QueryParse("Expected identifier, got end of query".into())
809        })?;
810        match t.token_type.as_str() {
811            "ident" | "keyword" => {
812                let v = self.advance().value;
813                Ok(v)
814            }
815            _ => Err(MdqlError::QueryParse(format!(
816                "Expected identifier, got '{}'",
817                t.raw
818            ))),
819        }
820    }
821
822    fn parse_or_expr(&mut self) -> Result<WhereClause, MdqlError> {
823        let mut left = self.parse_and_expr()?;
824        while self.match_keyword("OR") {
825            let right = self.parse_and_expr()?;
826            left = WhereClause::BoolOp(BoolOp {
827                op: BoolOpKind::Or,
828                left: Box::new(left),
829                right: Box::new(right),
830            });
831        }
832        Ok(left)
833    }
834
835    fn parse_and_expr(&mut self) -> Result<WhereClause, MdqlError> {
836        let mut left = self.parse_comparison()?;
837        while self.match_keyword("AND") {
838            let right = self.parse_comparison()?;
839            left = WhereClause::BoolOp(BoolOp {
840                op: BoolOpKind::And,
841                left: Box::new(left),
842                right: Box::new(right),
843            });
844        }
845        Ok(left)
846    }
847
848    fn parse_comparison(&mut self) -> Result<WhereClause, MdqlError> {
849        // Handle parenthesized boolean expressions
850        if self.peek().map_or(false, |t| t.token_type == "op" && t.value == "(") {
851            // Save position — might be arithmetic parens, not boolean
852            let saved_pos = self.pos;
853            self.advance();
854            // Try parsing as boolean (OR/AND) expression
855            let result = self.parse_or_expr();
856            if result.is_ok() && self.peek().map_or(false, |t| t.token_type == "op" && t.value == ")") {
857                self.advance();
858                return result;
859            }
860            // Not a boolean paren — rewind and parse as arithmetic expression
861            self.pos = saved_pos;
862        }
863
864        // Parse the left side as a full expression (column, literal, or arithmetic)
865        let left_expr = self.parse_additive()?;
866
867        // Extract column name for backward compat (simple column on left side)
868        let col = left_expr.as_column().unwrap_or("").to_string();
869
870        // IS NULL / IS NOT NULL (only valid with simple column)
871        if self.match_keyword("IS") {
872            if self.match_keyword("NOT") {
873                self.expect("keyword", Some("NULL"))?;
874                return Ok(WhereClause::Comparison(Comparison {
875                    column: col,
876                    op: CmpOp::IsNotNull,
877                    value: None,
878                    left_expr: Some(left_expr),
879                    right_expr: None,
880                }));
881            }
882            self.expect("keyword", Some("NULL"))?;
883            return Ok(WhereClause::Comparison(Comparison {
884                column: col,
885                op: CmpOp::IsNull,
886                value: None,
887                left_expr: Some(left_expr),
888                right_expr: None,
889            }));
890        }
891
892        // IN (val, val, ...) or IN (SELECT ...)
893        if self.match_keyword("IN") {
894            self.expect("op", Some("("))?;
895            let is_subquery = self.peek().map_or(false, |t| t.token_type == "keyword" && t.value == "SELECT");
896            if is_subquery {
897                let sq = self.parse_select()?;
898                self.expect("op", Some(")"))?;
899                return Ok(WhereClause::Comparison(Comparison {
900                    column: col,
901                    op: CmpOp::In,
902                    value: None,
903                    left_expr: Some(left_expr),
904                    right_expr: Some(Expr::Subquery(Box::new(sq))),
905                }));
906            }
907            let mut values = vec![self.parse_value()?];
908            while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
909                self.advance();
910                values.push(self.parse_value()?);
911            }
912            self.expect("op", Some(")"))?;
913            return Ok(WhereClause::Comparison(Comparison {
914                column: col,
915                op: CmpOp::In,
916                value: Some(SqlValue::List(values)),
917                left_expr: Some(left_expr),
918                right_expr: None,
919            }));
920        }
921
922        // LIKE
923        if self.match_keyword("LIKE") {
924            let val = self.parse_value()?;
925            return Ok(WhereClause::Comparison(Comparison {
926                column: col,
927                op: CmpOp::Like,
928                value: Some(val),
929                left_expr: Some(left_expr),
930                right_expr: None,
931            }));
932        }
933
934        // NOT LIKE
935        if self.match_keyword("NOT") {
936            if self.match_keyword("LIKE") {
937                let val = self.parse_value()?;
938                return Ok(WhereClause::Comparison(Comparison {
939                    column: col,
940                    op: CmpOp::NotLike,
941                    value: Some(val),
942                    left_expr: Some(left_expr),
943                    right_expr: None,
944                }));
945            }
946            return Err(MdqlError::QueryParse("Expected LIKE after NOT".into()));
947        }
948
949        // Standard comparison operators
950        if let Some(t) = self.peek() {
951            if t.token_type == "op" && ["=", "!=", "<", ">", "<=", ">="].contains(&t.value.as_str())
952            {
953                let op_str = self.advance().value;
954                let op = match op_str.as_str() {
955                    "=" => CmpOp::Eq,
956                    "!=" => CmpOp::Ne,
957                    "<" => CmpOp::Lt,
958                    ">" => CmpOp::Gt,
959                    "<=" => CmpOp::Le,
960                    ">=" => CmpOp::Ge,
961                    _ => unreachable!(),
962                };
963                // Parse right side as expression
964                let right_expr = self.parse_additive()?;
965                // Extract SqlValue for backward compat (simple literal on right side)
966                let value = match &right_expr {
967                    Expr::Literal(v) => Some(v.clone()),
968                    _ => None,
969                };
970                return Ok(WhereClause::Comparison(Comparison {
971                    column: col,
972                    op,
973                    value,
974                    left_expr: Some(left_expr),
975                    right_expr: Some(right_expr),
976                }));
977            }
978        }
979
980        let got = self.peek().map_or("end".to_string(), |t| t.raw.clone());
981        Err(MdqlError::QueryParse(format!(
982            "Expected operator after '{}', got '{}'",
983            left_expr.display_name(), got
984        )))
985    }
986
987    fn parse_value(&mut self) -> Result<SqlValue, MdqlError> {
988        let t = self.peek().ok_or_else(|| {
989            MdqlError::QueryParse("Expected value, got end of query".into())
990        })?;
991        match t.token_type.as_str() {
992            "string" => {
993                let v = self.advance().value;
994                Ok(SqlValue::String(v))
995            }
996            "number" => {
997                let v = self.advance().value;
998                if v.contains('.') {
999                    Ok(SqlValue::Float(v.parse().map_err(|_| {
1000                        MdqlError::QueryParse(format!("Invalid float: {}", v))
1001                    })?))
1002                } else {
1003                    Ok(SqlValue::Int(v.parse().map_err(|_| {
1004                        MdqlError::QueryParse(format!("Invalid int: {}", v))
1005                    })?))
1006                }
1007            }
1008            "keyword" if t.value == "NULL" => {
1009                self.advance();
1010                Ok(SqlValue::Null)
1011            }
1012            _ => Err(MdqlError::QueryParse(format!(
1013                "Expected value, got '{}'",
1014                t.raw
1015            ))),
1016        }
1017    }
1018
1019    fn parse_order_by(&mut self) -> Result<Vec<OrderSpec>, MdqlError> {
1020        let mut specs = vec![self.parse_order_spec()?];
1021        while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
1022            self.advance();
1023            specs.push(self.parse_order_spec()?);
1024        }
1025        Ok(specs)
1026    }
1027
1028    fn parse_order_spec(&mut self) -> Result<OrderSpec, MdqlError> {
1029        let expr = self.parse_additive()?;
1030        let col = expr.as_column().unwrap_or("").to_string();
1031        let descending = if self.match_keyword("DESC") {
1032            true
1033        } else {
1034            self.match_keyword("ASC");
1035            false
1036        };
1037        Ok(OrderSpec {
1038            column: col,
1039            expr: Some(expr),
1040            descending,
1041        })
1042    }
1043
1044    fn is_clause_keyword(&self, t: &Token) -> bool {
1045        t.token_type == "keyword"
1046            && ["WHERE", "ORDER", "LIMIT", "JOIN", "LEFT", "ON", "GROUP"].contains(&t.value.as_str())
1047    }
1048
1049    /// Keywords that should never be consumed as column names inside expressions.
1050    fn is_reserved_keyword(kw: &str) -> bool {
1051        matches!(kw,
1052            "AS" | "FROM" | "WHERE" | "AND" | "OR" | "ORDER" | "BY"
1053            | "ASC" | "DESC" | "LIMIT" | "JOIN" | "ON" | "GROUP"
1054            | "SELECT" | "INSERT" | "INTO" | "VALUES" | "UPDATE" | "SET"
1055            | "DELETE" | "ALTER" | "TABLE" | "IS" | "NOT" | "IN" | "LIKE"
1056            | "RENAME" | "FIELD" | "TO" | "DROP" | "MERGE" | "FIELDS"
1057            | "CASE" | "WHEN" | "THEN" | "ELSE" | "END"
1058            | "HAVING" | "INTERVAL" | "DAY" | "DAYS"
1059            | "CURRENT_DATE" | "CURRENT_TIMESTAMP" | "DATEDIFF"
1060            | "CREATE" | "VIEW" | "CASCADE" | "RESTRICT"
1061            | "WITH"
1062        )
1063    }
1064
1065    fn expect_end(&self) -> Result<(), MdqlError> {
1066        if let Some(t) = self.peek() {
1067            return Err(MdqlError::QueryParse(format!(
1068                "Unexpected token '{}' at position {}",
1069                t.raw, self.pos
1070            )));
1071        }
1072        Ok(())
1073    }
1074}
1075
1076pub fn parse_query(sql: &str) -> crate::errors::Result<Statement> {
1077    let tokens = tokenize(sql);
1078    if tokens.is_empty() {
1079        return Err(MdqlError::QueryParse("Empty query".into()));
1080    }
1081    let mut parser = Parser::new(tokens);
1082    parser.parse_statement()
1083}
1084
1085#[cfg(test)]
1086mod tests {
1087    use super::*;
1088
1089    #[test]
1090    fn test_simple_select() {
1091        let stmt = parse_query("SELECT title, status FROM strategies").unwrap();
1092        if let Statement::Select(q) = stmt {
1093            assert_eq!(q.columns, ColumnList::Named(vec![SelectExpr::Column("title".into()), SelectExpr::Column("status".into())]));
1094            assert_eq!(q.table, "strategies");
1095        } else {
1096            panic!("Expected Select");
1097        }
1098    }
1099
1100    #[test]
1101    fn test_select_star() {
1102        let stmt = parse_query("SELECT * FROM test").unwrap();
1103        if let Statement::Select(q) = stmt {
1104            assert_eq!(q.columns, ColumnList::All);
1105        } else {
1106            panic!("Expected Select");
1107        }
1108    }
1109
1110    #[test]
1111    fn test_where_clause() {
1112        let stmt = parse_query("SELECT title FROM test WHERE count > 5").unwrap();
1113        if let Statement::Select(q) = stmt {
1114            assert!(q.where_clause.is_some());
1115        } else {
1116            panic!("Expected Select");
1117        }
1118    }
1119
1120    #[test]
1121    fn test_order_by() {
1122        let stmt =
1123            parse_query("SELECT title FROM test ORDER BY composite DESC, title ASC").unwrap();
1124        if let Statement::Select(q) = stmt {
1125            let ob = q.order_by.unwrap();
1126            assert_eq!(ob.len(), 2);
1127            assert!(ob[0].descending);
1128            assert!(!ob[1].descending);
1129        } else {
1130            panic!("Expected Select");
1131        }
1132    }
1133
1134    #[test]
1135    fn test_limit() {
1136        let stmt = parse_query("SELECT * FROM test LIMIT 10").unwrap();
1137        if let Statement::Select(q) = stmt {
1138            assert_eq!(q.limit, Some(10));
1139        } else {
1140            panic!("Expected Select");
1141        }
1142    }
1143
1144    #[test]
1145    fn test_insert() {
1146        let stmt = parse_query(
1147            "INSERT INTO test (title, count) VALUES ('Hello', 42)",
1148        )
1149        .unwrap();
1150        if let Statement::Insert(q) = stmt {
1151            assert_eq!(q.table, "test");
1152            assert_eq!(q.columns, vec!["title", "count"]);
1153            assert_eq!(q.values[0], SqlValue::String("Hello".into()));
1154            assert_eq!(q.values[1], SqlValue::Int(42));
1155        } else {
1156            panic!("Expected Insert");
1157        }
1158    }
1159
1160    #[test]
1161    fn test_update() {
1162        let stmt = parse_query("UPDATE test SET status = 'KILLED' WHERE path = 'a.md'").unwrap();
1163        if let Statement::Update(q) = stmt {
1164            assert_eq!(q.table, "test");
1165            assert_eq!(q.assignments.len(), 1);
1166            assert!(q.where_clause.is_some());
1167        } else {
1168            panic!("Expected Update");
1169        }
1170    }
1171
1172    #[test]
1173    fn test_delete() {
1174        let stmt = parse_query("DELETE FROM test WHERE status = 'draft'").unwrap();
1175        if let Statement::Delete(q) = stmt {
1176            assert_eq!(q.table, "test");
1177            assert!(q.where_clause.is_some());
1178        } else {
1179            panic!("Expected Delete");
1180        }
1181    }
1182
1183    #[test]
1184    fn test_alter_rename() {
1185        let stmt =
1186            parse_query("ALTER TABLE test RENAME FIELD 'Summary' TO 'Overview'").unwrap();
1187        if let Statement::AlterRename(q) = stmt {
1188            assert_eq!(q.old_name, "Summary");
1189            assert_eq!(q.new_name, "Overview");
1190        } else {
1191            panic!("Expected AlterRename");
1192        }
1193    }
1194
1195    #[test]
1196    fn test_alter_drop() {
1197        let stmt = parse_query("ALTER TABLE test DROP FIELD 'Details'").unwrap();
1198        if let Statement::AlterDrop(q) = stmt {
1199            assert_eq!(q.field_name, "Details");
1200        } else {
1201            panic!("Expected AlterDrop");
1202        }
1203    }
1204
1205    #[test]
1206    fn test_alter_merge() {
1207        let stmt = parse_query(
1208            "ALTER TABLE test MERGE FIELDS 'Entry Rules', 'Exit Rules' INTO 'Trading Rules'",
1209        )
1210        .unwrap();
1211        if let Statement::AlterMerge(q) = stmt {
1212            assert_eq!(q.sources, vec!["Entry Rules", "Exit Rules"]);
1213            assert_eq!(q.into, "Trading Rules");
1214        } else {
1215            panic!("Expected AlterMerge");
1216        }
1217    }
1218
1219    #[test]
1220    fn test_backtick_ident() {
1221        let stmt = parse_query("SELECT `Structural Mechanism` FROM test").unwrap();
1222        if let Statement::Select(q) = stmt {
1223            assert_eq!(
1224                q.columns,
1225                ColumnList::Named(vec![SelectExpr::Column("Structural Mechanism".into())])
1226            );
1227        } else {
1228            panic!("Expected Select");
1229        }
1230    }
1231
1232    #[test]
1233    fn test_like_operator() {
1234        let stmt = parse_query("SELECT title FROM test WHERE categories LIKE '%defi%'").unwrap();
1235        if let Statement::Select(q) = stmt {
1236            if let Some(WhereClause::Comparison(c)) = q.where_clause {
1237                assert_eq!(c.op, CmpOp::Like);
1238                assert_eq!(c.value, Some(SqlValue::String("%defi%".into())));
1239            } else {
1240                panic!("Expected LIKE comparison");
1241            }
1242        } else {
1243            panic!("Expected Select");
1244        }
1245    }
1246
1247    #[test]
1248    fn test_in_operator() {
1249        let stmt =
1250            parse_query("SELECT * FROM test WHERE status IN ('ACTIVE', 'LIVE')").unwrap();
1251        if let Statement::Select(q) = stmt {
1252            if let Some(WhereClause::Comparison(c)) = q.where_clause {
1253                assert_eq!(c.op, CmpOp::In);
1254            } else {
1255                panic!("Expected IN comparison");
1256            }
1257        } else {
1258            panic!("Expected Select");
1259        }
1260    }
1261
1262    #[test]
1263    fn test_is_null() {
1264        let stmt = parse_query("SELECT * FROM test WHERE title IS NULL").unwrap();
1265        if let Statement::Select(q) = stmt {
1266            if let Some(WhereClause::Comparison(c)) = q.where_clause {
1267                assert_eq!(c.op, CmpOp::IsNull);
1268            } else {
1269                panic!("Expected IS NULL comparison");
1270            }
1271        } else {
1272            panic!("Expected Select");
1273        }
1274    }
1275
1276    #[test]
1277    fn test_and_or() {
1278        let stmt = parse_query(
1279            "SELECT * FROM test WHERE status = 'ACTIVE' AND count > 5 OR title LIKE '%test%'",
1280        )
1281        .unwrap();
1282        if let Statement::Select(q) = stmt {
1283            assert!(q.where_clause.is_some());
1284        } else {
1285            panic!("Expected Select");
1286        }
1287    }
1288
1289    #[test]
1290    fn test_join() {
1291        let stmt = parse_query(
1292            "SELECT s.title, b.sharpe FROM strategies s JOIN backtests b ON b.strategy = s.path",
1293        )
1294        .unwrap();
1295        if let Statement::Select(q) = stmt {
1296            assert_eq!(q.table, "strategies");
1297            assert_eq!(q.table_alias, Some("s".into()));
1298            assert_eq!(q.joins.len(), 1);
1299            let join = &q.joins[0];
1300            assert_eq!(join.table, "backtests");
1301            assert_eq!(join.alias, Some("b".into()));
1302        } else {
1303            panic!("Expected Select");
1304        }
1305    }
1306
1307    #[test]
1308    fn test_multi_join() {
1309        let stmt = parse_query(
1310            "SELECT s.title, b.sharpe, c.verdict FROM strategies s JOIN backtests b ON b.strategy = s.path JOIN critiques c ON c.strategy = s.path",
1311        )
1312        .unwrap();
1313        if let Statement::Select(q) = stmt {
1314            assert_eq!(q.table, "strategies");
1315            assert_eq!(q.table_alias, Some("s".into()));
1316            assert_eq!(q.joins.len(), 2);
1317            assert_eq!(q.joins[0].table, "backtests");
1318            assert_eq!(q.joins[0].alias, Some("b".into()));
1319            assert_eq!(where_clause_to_sql(&q.joins[0].condition), "b.strategy = s.path");
1320            assert_eq!(q.joins[1].table, "critiques");
1321            assert_eq!(q.joins[1].alias, Some("c".into()));
1322            assert_eq!(where_clause_to_sql(&q.joins[1].condition), "c.strategy = s.path");
1323        } else {
1324            panic!("Expected Select");
1325        }
1326    }
1327
1328    #[test]
1329    fn test_left_join() {
1330        let stmt = parse_query(
1331            "SELECT s.title, b.sharpe FROM strategies s LEFT JOIN backtests b ON b.strategy = s.path",
1332        )
1333        .unwrap();
1334        if let Statement::Select(q) = stmt {
1335            assert_eq!(q.joins.len(), 1);
1336            assert_eq!(q.joins[0].join_type, JoinType::Left);
1337            assert_eq!(q.joins[0].table, "backtests");
1338        } else {
1339            panic!("Expected Select");
1340        }
1341    }
1342
1343    #[test]
1344    fn test_mixed_join_types() {
1345        let stmt = parse_query(
1346            "SELECT s.title FROM strategies s JOIN backtests b ON b.strategy = s.path LEFT JOIN allocations a ON a.strategy = s.path",
1347        )
1348        .unwrap();
1349        if let Statement::Select(q) = stmt {
1350            assert_eq!(q.joins.len(), 2);
1351            assert_eq!(q.joins[0].join_type, JoinType::Inner);
1352            assert_eq!(q.joins[1].join_type, JoinType::Left);
1353        } else {
1354            panic!("Expected Select");
1355        }
1356    }
1357
1358    #[test]
1359    fn test_join_compound_and() {
1360        let stmt = parse_query(
1361            "SELECT s.title FROM strategies s LEFT JOIN backtests b ON b.strategy = s.path AND b.mode = 'PAPER'",
1362        )
1363        .unwrap();
1364        if let Statement::Select(q) = stmt {
1365            assert_eq!(q.joins.len(), 1);
1366            assert_eq!(q.joins[0].join_type, JoinType::Left);
1367            let sql = where_clause_to_sql(&q.joins[0].condition);
1368            assert!(sql.contains("b.strategy = s.path"));
1369            assert!(sql.contains("AND"));
1370            assert!(sql.contains("b.mode = 'PAPER'"));
1371        } else {
1372            panic!("Expected Select");
1373        }
1374    }
1375
1376    #[test]
1377    fn test_join_compound_or() {
1378        let stmt = parse_query(
1379            "SELECT * FROM a JOIN b ON a.id = b.id OR a.alt = b.id",
1380        )
1381        .unwrap();
1382        if let Statement::Select(q) = stmt {
1383            let sql = where_clause_to_sql(&q.joins[0].condition);
1384            assert!(sql.contains("OR"));
1385        } else {
1386            panic!("Expected Select");
1387        }
1388    }
1389
1390    #[test]
1391    fn test_join_compound_with_where() {
1392        let stmt = parse_query(
1393            "SELECT s.title FROM strategies s JOIN backtests b ON b.strategy = s.path AND b.mode = 'PAPER' WHERE s.title = 'Alpha'",
1394        )
1395        .unwrap();
1396        if let Statement::Select(q) = stmt {
1397            assert_eq!(q.joins.len(), 1);
1398            assert!(q.where_clause.is_some());
1399            let join_sql = where_clause_to_sql(&q.joins[0].condition);
1400            assert!(join_sql.contains("AND"));
1401        } else {
1402            panic!("Expected Select");
1403        }
1404    }
1405
1406    #[test]
1407    fn test_empty_query() {
1408        assert!(parse_query("").is_err());
1409    }
1410
1411    #[test]
1412    fn test_count_star() {
1413        let stmt = parse_query("SELECT status, COUNT(*) AS cnt FROM strategies GROUP BY status").unwrap();
1414        if let Statement::Select(q) = stmt {
1415            if let ColumnList::Named(exprs) = &q.columns {
1416                assert_eq!(exprs.len(), 2);
1417                assert_eq!(exprs[0], SelectExpr::Column("status".into()));
1418                assert!(matches!(&exprs[1], SelectExpr::Aggregate {
1419                    func: AggFunc::Count,
1420                    arg,
1421                    alias: Some(a),
1422                    ..
1423                } if arg == "*" && a == "cnt"));
1424            } else {
1425                panic!("Expected Named columns");
1426            }
1427            assert_eq!(q.group_by, Some(vec!["status".into()]));
1428        } else {
1429            panic!("Expected Select");
1430        }
1431    }
1432
1433    #[test]
1434    fn test_count_column_as_ident() {
1435        // "count" as a column name should NOT be parsed as the COUNT aggregate
1436        let stmt = parse_query("INSERT INTO test (title, count) VALUES ('Hello', 42)").unwrap();
1437        if let Statement::Insert(q) = stmt {
1438            assert_eq!(q.columns, vec!["title", "count"]);
1439        } else {
1440            panic!("Expected Insert");
1441        }
1442    }
1443
1444    #[test]
1445    fn test_multiple_aggregates() {
1446        let stmt = parse_query("SELECT MIN(composite), MAX(composite), AVG(composite) FROM strategies").unwrap();
1447        if let Statement::Select(q) = stmt {
1448            if let ColumnList::Named(exprs) = &q.columns {
1449                assert_eq!(exprs.len(), 3);
1450                assert!(matches!(&exprs[0], SelectExpr::Aggregate { func: AggFunc::Min, .. }));
1451                assert!(matches!(&exprs[1], SelectExpr::Aggregate { func: AggFunc::Max, .. }));
1452                assert!(matches!(&exprs[2], SelectExpr::Aggregate { func: AggFunc::Avg, .. }));
1453            } else {
1454                panic!("Expected Named columns");
1455            }
1456            assert_eq!(q.group_by, None);
1457        } else {
1458            panic!("Expected Select");
1459        }
1460    }
1461
1462    // ── Expression tests ──────────────────────────────────────────
1463
1464    #[test]
1465    fn test_select_arithmetic_expr() {
1466        let stmt = parse_query("SELECT a + b FROM test").unwrap();
1467        if let Statement::Select(q) = stmt {
1468            if let ColumnList::Named(exprs) = &q.columns {
1469                assert_eq!(exprs.len(), 1);
1470                assert!(matches!(&exprs[0], SelectExpr::Expr {
1471                    expr: Expr::BinaryOp { op: ArithOp::Add, .. },
1472                    alias: None,
1473                }));
1474            } else {
1475                panic!("Expected Named columns");
1476            }
1477        } else {
1478            panic!("Expected Select");
1479        }
1480    }
1481
1482    #[test]
1483    fn test_select_arithmetic_with_alias() {
1484        let stmt = parse_query("SELECT a + b AS total FROM test").unwrap();
1485        if let Statement::Select(q) = stmt {
1486            if let ColumnList::Named(exprs) = &q.columns {
1487                assert_eq!(exprs.len(), 1);
1488                assert!(matches!(&exprs[0], SelectExpr::Expr {
1489                    alias: Some(a),
1490                    ..
1491                } if a == "total"));
1492                assert_eq!(exprs[0].output_name(), "total");
1493            } else {
1494                panic!("Expected Named columns");
1495            }
1496        } else {
1497            panic!("Expected Select");
1498        }
1499    }
1500
1501    #[test]
1502    fn test_select_precedence() {
1503        // a + b * c should parse as a + (b * c)
1504        let stmt = parse_query("SELECT a + b * c FROM test").unwrap();
1505        if let Statement::Select(q) = stmt {
1506            if let ColumnList::Named(exprs) = &q.columns {
1507                if let SelectExpr::Expr { expr, .. } = &exprs[0] {
1508                    if let Expr::BinaryOp { left, op, right } = expr {
1509                        assert_eq!(*op, ArithOp::Add);
1510                        assert!(matches!(left.as_ref(), Expr::Column(n) if n == "a"));
1511                        assert!(matches!(right.as_ref(), Expr::BinaryOp { op: ArithOp::Mul, .. }));
1512                    } else {
1513                        panic!("Expected BinaryOp");
1514                    }
1515                } else {
1516                    panic!("Expected Expr variant");
1517                }
1518            } else {
1519                panic!("Expected Named columns");
1520            }
1521        } else {
1522            panic!("Expected Select");
1523        }
1524    }
1525
1526    #[test]
1527    fn test_select_parenthesized_expr() {
1528        // (a + b) * c should override default precedence
1529        let stmt = parse_query("SELECT (a + b) * c FROM test").unwrap();
1530        if let Statement::Select(q) = stmt {
1531            if let ColumnList::Named(exprs) = &q.columns {
1532                if let SelectExpr::Expr { expr, .. } = &exprs[0] {
1533                    if let Expr::BinaryOp { left, op, .. } = expr {
1534                        assert_eq!(*op, ArithOp::Mul);
1535                        assert!(matches!(left.as_ref(), Expr::BinaryOp { op: ArithOp::Add, .. }));
1536                    } else {
1537                        panic!("Expected BinaryOp");
1538                    }
1539                } else {
1540                    panic!("Expected Expr variant");
1541                }
1542            } else {
1543                panic!("Expected Named columns");
1544            }
1545        } else {
1546            panic!("Expected Select");
1547        }
1548    }
1549
1550    #[test]
1551    fn test_select_unary_minus() {
1552        let stmt = parse_query("SELECT -count FROM test").unwrap();
1553        if let Statement::Select(q) = stmt {
1554            if let ColumnList::Named(exprs) = &q.columns {
1555                assert!(matches!(&exprs[0], SelectExpr::Expr {
1556                    expr: Expr::UnaryMinus(_),
1557                    ..
1558                }));
1559            } else {
1560                panic!("Expected Named columns");
1561            }
1562        } else {
1563            panic!("Expected Select");
1564        }
1565    }
1566
1567    #[test]
1568    fn test_select_negative_literal() {
1569        let stmt = parse_query("SELECT -42 FROM test").unwrap();
1570        if let Statement::Select(q) = stmt {
1571            if let ColumnList::Named(exprs) = &q.columns {
1572                // Unary minus folds into the literal
1573                assert!(matches!(&exprs[0], SelectExpr::Expr {
1574                    expr: Expr::Literal(SqlValue::Int(-42)),
1575                    ..
1576                }));
1577            } else {
1578                panic!("Expected Named columns");
1579            }
1580        } else {
1581            panic!("Expected Select");
1582        }
1583    }
1584
1585    #[test]
1586    fn test_where_arithmetic_expr() {
1587        let stmt = parse_query("SELECT * FROM test WHERE a + b > 10").unwrap();
1588        if let Statement::Select(q) = stmt {
1589            if let Some(WhereClause::Comparison(c)) = q.where_clause {
1590                assert_eq!(c.op, CmpOp::Gt);
1591                assert!(matches!(&c.left_expr, Some(Expr::BinaryOp { op: ArithOp::Add, .. })));
1592                assert!(matches!(&c.right_expr, Some(Expr::Literal(SqlValue::Int(10)))));
1593            } else {
1594                panic!("Expected comparison");
1595            }
1596        } else {
1597            panic!("Expected Select");
1598        }
1599    }
1600
1601    #[test]
1602    fn test_where_both_sides_expr() {
1603        let stmt = parse_query("SELECT * FROM test WHERE a * 2 > b + 1").unwrap();
1604        if let Statement::Select(q) = stmt {
1605            if let Some(WhereClause::Comparison(c)) = q.where_clause {
1606                assert_eq!(c.op, CmpOp::Gt);
1607                assert!(matches!(&c.left_expr, Some(Expr::BinaryOp { op: ArithOp::Mul, .. })));
1608                assert!(matches!(&c.right_expr, Some(Expr::BinaryOp { op: ArithOp::Add, .. })));
1609            } else {
1610                panic!("Expected comparison");
1611            }
1612        } else {
1613            panic!("Expected Select");
1614        }
1615    }
1616
1617    #[test]
1618    fn test_order_by_expr() {
1619        let stmt = parse_query("SELECT * FROM test ORDER BY a + b DESC").unwrap();
1620        if let Statement::Select(q) = stmt {
1621            let ob = q.order_by.unwrap();
1622            assert_eq!(ob.len(), 1);
1623            assert!(ob[0].descending);
1624            assert!(matches!(&ob[0].expr, Some(Expr::BinaryOp { op: ArithOp::Add, .. })));
1625        } else {
1626            panic!("Expected Select");
1627        }
1628    }
1629
1630    #[test]
1631    fn test_all_arithmetic_ops() {
1632        let stmt = parse_query("SELECT a + b, a - b, a * b, a / b, a % b FROM test").unwrap();
1633        if let Statement::Select(q) = stmt {
1634            if let ColumnList::Named(exprs) = &q.columns {
1635                assert_eq!(exprs.len(), 5);
1636                assert!(matches!(&exprs[0], SelectExpr::Expr { expr: Expr::BinaryOp { op: ArithOp::Add, .. }, .. }));
1637                assert!(matches!(&exprs[1], SelectExpr::Expr { expr: Expr::BinaryOp { op: ArithOp::Sub, .. }, .. }));
1638                assert!(matches!(&exprs[2], SelectExpr::Expr { expr: Expr::BinaryOp { op: ArithOp::Mul, .. }, .. }));
1639                assert!(matches!(&exprs[3], SelectExpr::Expr { expr: Expr::BinaryOp { op: ArithOp::Div, .. }, .. }));
1640                assert!(matches!(&exprs[4], SelectExpr::Expr { expr: Expr::BinaryOp { op: ArithOp::Mod, .. }, .. }));
1641            } else {
1642                panic!("Expected Named columns");
1643            }
1644        } else {
1645            panic!("Expected Select");
1646        }
1647    }
1648
1649    #[test]
1650    fn test_column_with_literal_arithmetic() {
1651        let stmt = parse_query("SELECT count * 2 + 1 FROM test").unwrap();
1652        if let Statement::Select(q) = stmt {
1653            if let ColumnList::Named(exprs) = &q.columns {
1654                // Should be (count * 2) + 1
1655                if let SelectExpr::Expr { expr, .. } = &exprs[0] {
1656                    if let Expr::BinaryOp { left, op, right } = expr {
1657                        assert_eq!(*op, ArithOp::Add);
1658                        assert!(matches!(right.as_ref(), Expr::Literal(SqlValue::Int(1))));
1659                        assert!(matches!(left.as_ref(), Expr::BinaryOp { op: ArithOp::Mul, .. }));
1660                    } else {
1661                        panic!("Expected BinaryOp");
1662                    }
1663                } else {
1664                    panic!("Expected Expr");
1665                }
1666            } else {
1667                panic!("Expected Named columns");
1668            }
1669        } else {
1670            panic!("Expected Select");
1671        }
1672    }
1673
1674    #[test]
1675    fn test_mixed_columns_and_exprs() {
1676        let stmt = parse_query("SELECT title, a + b AS sum, count FROM test").unwrap();
1677        if let Statement::Select(q) = stmt {
1678            if let ColumnList::Named(exprs) = &q.columns {
1679                assert_eq!(exprs.len(), 3);
1680                assert_eq!(exprs[0], SelectExpr::Column("title".into()));
1681                assert!(matches!(&exprs[1], SelectExpr::Expr { alias: Some(a), .. } if a == "sum"));
1682                assert_eq!(exprs[2], SelectExpr::Column("count".into()));
1683            } else {
1684                panic!("Expected Named columns");
1685            }
1686        } else {
1687            panic!("Expected Select");
1688        }
1689    }
1690
1691    // ── CASE WHEN tests ──────────────────────────────────────────
1692
1693    #[test]
1694    fn test_case_when_basic() {
1695        let stmt = parse_query(
1696            "SELECT CASE WHEN status = 'ACTIVE' THEN 1 ELSE 0 END FROM test"
1697        ).unwrap();
1698        if let Statement::Select(q) = stmt {
1699            if let ColumnList::Named(exprs) = &q.columns {
1700                assert_eq!(exprs.len(), 1);
1701                assert!(matches!(&exprs[0], SelectExpr::Expr {
1702                    expr: Expr::Case { .. },
1703                    ..
1704                }));
1705            } else {
1706                panic!("Expected Named columns");
1707            }
1708        } else {
1709            panic!("Expected Select");
1710        }
1711    }
1712
1713    #[test]
1714    fn test_case_when_multiple_branches() {
1715        let stmt = parse_query(
1716            "SELECT CASE WHEN x > 10 THEN 'high' WHEN x > 5 THEN 'mid' ELSE 'low' END FROM test"
1717        ).unwrap();
1718        if let Statement::Select(q) = stmt {
1719            if let ColumnList::Named(exprs) = &q.columns {
1720                if let SelectExpr::Expr { expr: Expr::Case { whens, else_expr }, .. } = &exprs[0] {
1721                    assert_eq!(whens.len(), 2);
1722                    assert!(else_expr.is_some());
1723                } else {
1724                    panic!("Expected Case expression");
1725                }
1726            } else {
1727                panic!("Expected Named columns");
1728            }
1729        } else {
1730            panic!("Expected Select");
1731        }
1732    }
1733
1734    #[test]
1735    fn test_case_when_no_else() {
1736        let stmt = parse_query(
1737            "SELECT CASE WHEN x = 1 THEN 'one' END FROM test"
1738        ).unwrap();
1739        if let Statement::Select(q) = stmt {
1740            if let ColumnList::Named(exprs) = &q.columns {
1741                if let SelectExpr::Expr { expr: Expr::Case { whens, else_expr }, .. } = &exprs[0] {
1742                    assert_eq!(whens.len(), 1);
1743                    assert!(else_expr.is_none());
1744                } else {
1745                    panic!("Expected Case expression");
1746                }
1747            } else {
1748                panic!("Expected Named columns");
1749            }
1750        } else {
1751            panic!("Expected Select");
1752        }
1753    }
1754
1755    #[test]
1756    fn test_case_when_in_aggregate() {
1757        let stmt = parse_query(
1758            "SELECT SUM(CASE WHEN side = 'BUY' THEN size ELSE -size END) AS net FROM orders GROUP BY token"
1759        ).unwrap();
1760        if let Statement::Select(q) = stmt {
1761            if let ColumnList::Named(exprs) = &q.columns {
1762                assert_eq!(exprs.len(), 1);
1763                assert!(matches!(&exprs[0], SelectExpr::Aggregate {
1764                    func: AggFunc::Sum,
1765                    arg_expr: Some(Expr::Case { .. }),
1766                    alias: Some(a),
1767                    ..
1768                } if a == "net"));
1769            } else {
1770                panic!("Expected Named columns");
1771            }
1772        } else {
1773            panic!("Expected Select");
1774        }
1775    }
1776
1777    #[test]
1778    fn test_case_when_with_alias() {
1779        let stmt = parse_query(
1780            "SELECT CASE WHEN x > 0 THEN 'pos' ELSE 'neg' END AS sign FROM test"
1781        ).unwrap();
1782        if let Statement::Select(q) = stmt {
1783            if let ColumnList::Named(exprs) = &q.columns {
1784                assert!(matches!(&exprs[0], SelectExpr::Expr {
1785                    expr: Expr::Case { .. },
1786                    alias: Some(a),
1787                } if a == "sign"));
1788            } else {
1789                panic!("Expected Named columns");
1790            }
1791        } else {
1792            panic!("Expected Select");
1793        }
1794    }
1795
1796    #[test]
1797    fn test_create_view() {
1798        let stmt = parse_query("CREATE VIEW live AS SELECT * FROM strategies WHERE status = 'LIVE'").unwrap();
1799        if let Statement::CreateView(cv) = stmt {
1800            assert_eq!(cv.view_name, "live");
1801            assert!(cv.columns.is_none());
1802            assert_eq!(cv.query.table, "strategies");
1803            assert!(cv.query.where_clause.is_some());
1804        } else {
1805            panic!("Expected CreateView, got {:?}", stmt);
1806        }
1807    }
1808
1809    #[test]
1810    fn test_create_view_with_columns() {
1811        let stmt = parse_query("CREATE VIEW v1 (a, b) AS SELECT title, status FROM t").unwrap();
1812        if let Statement::CreateView(cv) = stmt {
1813            assert_eq!(cv.view_name, "v1");
1814            assert_eq!(cv.columns, Some(vec!["a".into(), "b".into()]));
1815        } else {
1816            panic!("Expected CreateView");
1817        }
1818    }
1819
1820    #[test]
1821    fn test_drop_view() {
1822        let stmt = parse_query("DROP VIEW live").unwrap();
1823        if let Statement::DropView(dv) = stmt {
1824            assert_eq!(dv.view_name, "live");
1825        } else {
1826            panic!("Expected DropView, got {:?}", stmt);
1827        }
1828    }
1829
1830    #[test]
1831    fn test_create_view_case_insensitive() {
1832        let stmt = parse_query("create view My_View as select * from t").unwrap();
1833        if let Statement::CreateView(cv) = stmt {
1834            assert_eq!(cv.view_name, "My_View");
1835        } else {
1836            panic!("Expected CreateView");
1837        }
1838    }
1839
1840    // ── Issue #42: Arithmetic between aggregates in column expressions ──
1841
1842    #[test]
1843    fn test_aggregate_division() {
1844        let stmt = parse_query(
1845            "SELECT token, SUM(sell) / SUM(buy) as ratio FROM orders GROUP BY token"
1846        ).unwrap();
1847        if let Statement::Select(q) = stmt {
1848            assert_eq!(q.group_by, Some(vec!["token".into()]));
1849            if let ColumnList::Named(exprs) = &q.columns {
1850                assert_eq!(exprs.len(), 2);
1851                assert!(exprs[1].is_aggregate());
1852            } else {
1853                panic!("Expected Named columns");
1854            }
1855        } else {
1856            panic!("Expected Select");
1857        }
1858    }
1859
1860    #[test]
1861    fn test_aggregate_subtraction() {
1862        let stmt = parse_query(
1863            "SELECT token, SUM(sell) - SUM(buy) as net FROM orders GROUP BY token"
1864        ).unwrap();
1865        if let Statement::Select(q) = stmt {
1866            if let ColumnList::Named(exprs) = &q.columns {
1867                assert_eq!(exprs[1].output_name(), "net");
1868            }
1869        } else {
1870            panic!("Expected Select");
1871        }
1872    }
1873
1874    #[test]
1875    fn test_create_view_with_arithmetic() {
1876        let stmt = parse_query(
1877            "CREATE VIEW positions AS SELECT token, SUM(sell) / SUM(buy) as ratio FROM orders GROUP BY token"
1878        ).unwrap();
1879        if let Statement::CreateView(cv) = stmt {
1880            assert_eq!(cv.view_name, "positions");
1881        } else {
1882            panic!("Expected CreateView, got {:?}", stmt);
1883        }
1884    }
1885
1886    // ── Issue #43: Subqueries in FROM ──
1887
1888    #[test]
1889    fn test_subquery_in_from() {
1890        let stmt = parse_query(
1891            "SELECT token, sell_size FROM (SELECT token, SUM(size) as sell_size FROM orders GROUP BY token) LIMIT 5"
1892        ).unwrap();
1893        if let Statement::Select(q) = stmt {
1894            assert!(q.subquery.is_some());
1895            assert_eq!(q.limit, Some(5));
1896            let sub = q.subquery.unwrap();
1897            assert_eq!(sub.table, "orders");
1898            assert!(sub.group_by.is_some());
1899        } else {
1900            panic!("Expected Select");
1901        }
1902    }
1903
1904    // ── Issue #44: HAVING in CREATE VIEW ──
1905
1906    #[test]
1907    fn test_create_view_with_having() {
1908        let stmt = parse_query(
1909            "CREATE VIEW positions AS SELECT token, SUM(sell) as sell_size, SUM(buy) as buy_size FROM orders GROUP BY token HAVING sell_size > buy_size"
1910        ).unwrap();
1911        if let Statement::CreateView(cv) = stmt {
1912            assert_eq!(cv.view_name, "positions");
1913            assert!(cv.query.having.is_some());
1914        } else {
1915            panic!("Expected CreateView, got {:?}", stmt);
1916        }
1917    }
1918
1919    // ── Issue #42: Aggregate multiplication ──
1920
1921    #[test]
1922    fn test_aggregate_multiplication() {
1923        let stmt = parse_query(
1924            "SELECT SUM(a) * 2 as doubled FROM test"
1925        ).unwrap();
1926        if let Statement::Select(q) = stmt {
1927            if let ColumnList::Named(exprs) = &q.columns {
1928                assert_eq!(exprs.len(), 1);
1929                assert!(exprs[0].is_aggregate());
1930                assert_eq!(exprs[0].output_name(), "doubled");
1931            } else {
1932                panic!("Expected Named columns");
1933            }
1934        } else {
1935            panic!("Expected Select");
1936        }
1937    }
1938
1939    #[test]
1940    fn test_complex_aggregate_arithmetic() {
1941        let stmt = parse_query(
1942            "SELECT SUM(CASE WHEN side = 'SELL' THEN size ELSE 0 END) / SUM(CASE WHEN side = 'BUY' THEN size ELSE 0 END) as ratio FROM orders GROUP BY token"
1943        ).unwrap();
1944        if let Statement::Select(q) = stmt {
1945            if let ColumnList::Named(exprs) = &q.columns {
1946                assert_eq!(exprs.len(), 1);
1947                assert!(exprs[0].is_aggregate());
1948                assert_eq!(exprs[0].output_name(), "ratio");
1949            } else {
1950                panic!("Expected Named columns");
1951            }
1952            assert_eq!(q.group_by, Some(vec!["token".into()]));
1953        } else {
1954            panic!("Expected Select");
1955        }
1956    }
1957
1958    // ── Issue #43: Subquery with alias and WHERE ──
1959
1960    #[test]
1961    fn test_subquery_with_alias() {
1962        let stmt = parse_query(
1963            "SELECT x FROM (SELECT x FROM t) sub"
1964        ).unwrap();
1965        if let Statement::Select(q) = stmt {
1966            assert!(q.subquery.is_some());
1967            let sub = q.subquery.unwrap();
1968            assert_eq!(sub.table, "t");
1969            if let ColumnList::Named(exprs) = &q.columns {
1970                assert_eq!(exprs.len(), 1);
1971                assert_eq!(exprs[0].output_name(), "x");
1972            } else {
1973                panic!("Expected Named columns");
1974            }
1975        } else {
1976            panic!("Expected Select");
1977        }
1978    }
1979
1980    #[test]
1981    fn test_subquery_with_where() {
1982        let stmt = parse_query(
1983            "SELECT x FROM (SELECT x FROM t WHERE y > 0) LIMIT 5"
1984        ).unwrap();
1985        if let Statement::Select(q) = stmt {
1986            assert!(q.subquery.is_some());
1987            assert_eq!(q.limit, Some(5));
1988            let sub = q.subquery.unwrap();
1989            assert_eq!(sub.table, "t");
1990            assert!(sub.where_clause.is_some());
1991        } else {
1992            panic!("Expected Select");
1993        }
1994    }
1995
1996    // ── Issue #42 + CREATE VIEW: aggregate subtraction in view ──
1997
1998    #[test]
1999    fn test_create_view_aggregate_subtraction() {
2000        let stmt = parse_query(
2001            "CREATE VIEW v AS SELECT token, SUM(sell) - SUM(buy) as net FROM orders GROUP BY token"
2002        ).unwrap();
2003        if let Statement::CreateView(cv) = stmt {
2004            assert_eq!(cv.view_name, "v");
2005            assert_eq!(cv.query.group_by, Some(vec!["token".into()]));
2006            if let ColumnList::Named(exprs) = &cv.query.columns {
2007                assert_eq!(exprs.len(), 2);
2008                assert_eq!(exprs[1].output_name(), "net");
2009                assert!(exprs[1].is_aggregate());
2010            } else {
2011                panic!("Expected Named columns");
2012            }
2013        } else {
2014            panic!("Expected CreateView, got {:?}", stmt);
2015        }
2016    }
2017
2018    #[test]
2019    fn test_delete_cascade() {
2020        let stmt = parse_query("DELETE FROM strategies WHERE status = 'KILLED' CASCADE").unwrap();
2021        if let Statement::Delete(q) = stmt {
2022            assert_eq!(q.table, "strategies");
2023            assert!(q.where_clause.is_some());
2024            assert_eq!(q.mode, DeleteMode::Cascade);
2025        } else {
2026            panic!("Expected Delete");
2027        }
2028    }
2029
2030    #[test]
2031    fn test_delete_restrict() {
2032        let stmt = parse_query("DELETE FROM strategies WHERE path = 'alpha.md' RESTRICT").unwrap();
2033        if let Statement::Delete(q) = stmt {
2034            assert_eq!(q.table, "strategies");
2035            assert_eq!(q.mode, DeleteMode::Restrict);
2036        } else {
2037            panic!("Expected Delete");
2038        }
2039    }
2040
2041    #[test]
2042    fn test_delete_default_unchanged() {
2043        let stmt = parse_query("DELETE FROM strategies WHERE status = 'KILLED'").unwrap();
2044        if let Statement::Delete(q) = stmt {
2045            assert_eq!(q.mode, DeleteMode::Default);
2046        } else {
2047            panic!("Expected Delete");
2048        }
2049    }
2050
2051    #[test]
2052    fn test_delete_cascade_no_where() {
2053        let stmt = parse_query("DELETE FROM strategies CASCADE").unwrap();
2054        if let Statement::Delete(q) = stmt {
2055            assert_eq!(q.table, "strategies");
2056            assert!(q.where_clause.is_none());
2057            assert_eq!(q.mode, DeleteMode::Cascade);
2058        } else {
2059            panic!("Expected Delete");
2060        }
2061    }
2062
2063    // ── CTE (WITH) tests ──────────────────────────────────────────
2064
2065    #[test]
2066    fn test_cte_basic() {
2067        let stmt = parse_query(
2068            "WITH live AS (SELECT * FROM strategies WHERE status = 'LIVE') SELECT * FROM live"
2069        ).unwrap();
2070        if let Statement::Select(q) = stmt {
2071            assert_eq!(q.ctes.len(), 1);
2072            assert_eq!(q.ctes[0].name, "live");
2073            assert_eq!(q.ctes[0].query.table, "strategies");
2074            assert!(q.ctes[0].query.where_clause.is_some());
2075            assert_eq!(q.table, "live");
2076        } else {
2077            panic!("Expected Select");
2078        }
2079    }
2080
2081    #[test]
2082    fn test_cte_multi() {
2083        let stmt = parse_query(
2084            "WITH a AS (SELECT * FROM t1), b AS (SELECT * FROM t2) SELECT * FROM a JOIN b ON a.id = b.id"
2085        ).unwrap();
2086        if let Statement::Select(q) = stmt {
2087            assert_eq!(q.ctes.len(), 2);
2088            assert_eq!(q.ctes[0].name, "a");
2089            assert_eq!(q.ctes[0].query.table, "t1");
2090            assert_eq!(q.ctes[1].name, "b");
2091            assert_eq!(q.ctes[1].query.table, "t2");
2092            assert_eq!(q.table, "a");
2093            assert_eq!(q.joins.len(), 1);
2094        } else {
2095            panic!("Expected Select");
2096        }
2097    }
2098
2099    #[test]
2100    fn test_cte_with_aggregation() {
2101        let stmt = parse_query(
2102            "WITH totals AS (SELECT strategy, COUNT(*) AS cnt FROM backtests GROUP BY strategy) SELECT * FROM totals WHERE cnt > 1"
2103        ).unwrap();
2104        if let Statement::Select(q) = stmt {
2105            assert_eq!(q.ctes.len(), 1);
2106            assert_eq!(q.ctes[0].name, "totals");
2107            assert!(q.ctes[0].query.group_by.is_some());
2108            assert_eq!(q.table, "totals");
2109            assert!(q.where_clause.is_some());
2110        } else {
2111            panic!("Expected Select");
2112        }
2113    }
2114
2115    #[test]
2116    fn test_cte_no_ctes_on_plain_select() {
2117        let stmt = parse_query("SELECT * FROM t").unwrap();
2118        if let Statement::Select(q) = stmt {
2119            assert!(q.ctes.is_empty());
2120        } else {
2121            panic!("Expected Select");
2122        }
2123    }
2124
2125    // ── Subquery tests ──────────────────────────────────────────
2126
2127    #[test]
2128    fn test_where_in_subquery() {
2129        let stmt = parse_query(
2130            "SELECT * FROM strategies WHERE path IN (SELECT strategy FROM backtests)"
2131        ).unwrap();
2132        if let Statement::Select(q) = stmt {
2133            if let Some(WhereClause::Comparison(c)) = &q.where_clause {
2134                assert_eq!(c.op, CmpOp::In);
2135                assert!(matches!(&c.right_expr, Some(Expr::Subquery(_))));
2136            } else {
2137                panic!("Expected IN comparison");
2138            }
2139        } else {
2140            panic!("Expected Select");
2141        }
2142    }
2143
2144    #[test]
2145    fn test_scalar_subquery_in_where() {
2146        let stmt = parse_query(
2147            "SELECT * FROM backtests WHERE sharpe > (SELECT AVG(sharpe) FROM backtests)"
2148        ).unwrap();
2149        if let Statement::Select(q) = stmt {
2150            if let Some(WhereClause::Comparison(c)) = &q.where_clause {
2151                assert_eq!(c.op, CmpOp::Gt);
2152                assert!(matches!(&c.right_expr, Some(Expr::Subquery(_))));
2153            } else {
2154                panic!("Expected comparison");
2155            }
2156        } else {
2157            panic!("Expected Select");
2158        }
2159    }
2160
2161    #[test]
2162    fn test_scalar_subquery_in_select() {
2163        let stmt = parse_query(
2164            "SELECT title, (SELECT COUNT(*) FROM backtests) AS cnt FROM strategies"
2165        ).unwrap();
2166        if let Statement::Select(q) = stmt {
2167            if let ColumnList::Named(exprs) = &q.columns {
2168                assert_eq!(exprs.len(), 2);
2169                assert!(matches!(&exprs[1], SelectExpr::Expr {
2170                    expr: Expr::Subquery(_),
2171                    alias: Some(a),
2172                } if a == "cnt"));
2173            } else {
2174                panic!("Expected Named columns");
2175            }
2176        } else {
2177            panic!("Expected Select");
2178        }
2179    }
2180}