Skip to main content

mdql_core/
query_parser.rs

1//! Hand-written recursive descent parser for the MDQL SQL subset.
2
3use regex::Regex;
4use std::sync::LazyLock;
5
6use crate::errors::MdqlError;
7pub use crate::query_ast::*;
8
9// ── Tokenizer ──────────────────────────────────────────────────────────────
10
/// Reserved words recognised by the tokenizer.
///
/// A bare word whose upper-cased form appears here is emitted as a `keyword`
/// token whose `value` is the upper-cased text; any other bare word becomes
/// an `ident` token.
static KEYWORDS: &[&str] = &[
    // query clauses and predicates
    "SELECT", "FROM", "WHERE", "AND", "OR", "ORDER", "BY",
    "ASC", "DESC", "LIMIT", "LIKE", "IN", "IS", "NOT", "NULL",
    "JOIN", "ON", "AS", "GROUP", "HAVING",
    // data modification
    "INSERT", "INTO", "VALUES", "UPDATE", "SET", "DELETE",
    // schema changes
    "ALTER", "TABLE", "RENAME", "FIELD", "TO", "DROP", "MERGE", "FIELDS",
    // conditional expressions
    "CASE", "WHEN", "THEN", "ELSE", "END",
    // date arithmetic
    "INTERVAL", "DAY", "DAYS", "CURRENT_DATE", "CURRENT_TIMESTAMP", "DATEDIFF",
    // views
    "CREATE", "VIEW", "CASCADE", "RESTRICT",
];
21
/// Function names that `peek_is_agg_func` treats as aggregates when the
/// next token after the name is `(`.
static AGG_FUNCS: &[&str] = &[
    "COUNT",
    "SUM",
    "AVG",
    "MIN",
    "MAX",
];
23
24static TOKEN_RE: LazyLock<Regex> = LazyLock::new(|| {
25    Regex::new(
26        r#"(?x)
27        \s*(?:
28            (?P<backtick>`[^`]+`)
29            | (?P<string>'(?:[^'\\]|\\.)*')
30            | (?P<number>\d+(?:\.\d+)?)
31            | (?P<op><=|>=|!=|[=<>,*()+\-/%])
32            | (?P<word>[A-Za-z_][A-Za-z0-9_./-]*)
33        )"#,
34    )
35    .unwrap()
36});
37
/// One lexical token produced by `tokenize`.
#[derive(Clone, Debug)]
struct Token {
    /// Token category: "ident", "string", "number", "op" or "keyword".
    token_type: String,
    /// Normalised text: surrounding quotes/backticks removed, keywords upper-cased.
    value: String,
    /// Exact matched source text; used when reporting errors.
    raw: String,
}
44
45fn tokenize(sql: &str) -> Vec<Token> {
46    let mut tokens = Vec::new();
47    for caps in TOKEN_RE.captures_iter(sql) {
48        if let Some(m) = caps.name("backtick") {
49            let raw = m.as_str();
50            tokens.push(Token {
51                token_type: "ident".into(),
52                value: raw[1..raw.len() - 1].into(),
53                raw: raw.into(),
54            });
55        } else if let Some(m) = caps.name("string") {
56            let raw = m.as_str();
57            tokens.push(Token {
58                token_type: "string".into(),
59                value: raw[1..raw.len() - 1].into(),
60                raw: raw.into(),
61            });
62        } else if let Some(m) = caps.name("number") {
63            let raw = m.as_str();
64            tokens.push(Token {
65                token_type: "number".into(),
66                value: raw.into(),
67                raw: raw.into(),
68            });
69        } else if let Some(m) = caps.name("op") {
70            let raw = m.as_str();
71            tokens.push(Token {
72                token_type: "op".into(),
73                value: raw.into(),
74                raw: raw.into(),
75            });
76        } else if let Some(m) = caps.name("word") {
77            let raw = m.as_str();
78            if KEYWORDS.contains(&raw.to_uppercase().as_str()) {
79                tokens.push(Token {
80                    token_type: "keyword".into(),
81                    value: raw.to_uppercase(),
82                    raw: raw.into(),
83                });
84            } else {
85                tokens.push(Token {
86                    token_type: "ident".into(),
87                    value: raw.into(),
88                    raw: raw.into(),
89                });
90            }
91        }
92    }
93    tokens
94}
95
96// ── Parser ─────────────────────────────────────────────────────────────────
97
98struct Parser {
99    tokens: Vec<Token>,
100    pos: usize,
101}
102
103impl Parser {
104    fn new(tokens: Vec<Token>) -> Self {
105        Parser { tokens, pos: 0 }
106    }
107
108    fn peek(&self) -> Option<&Token> {
109        self.tokens.get(self.pos)
110    }
111
112    fn advance(&mut self) -> Token {
113        let t = self.tokens[self.pos].clone();
114        self.pos += 1;
115        t
116    }
117
118    fn expect(&mut self, type_: &str, value: Option<&str>) -> Result<Token, MdqlError> {
119        let t = self.peek().ok_or_else(|| {
120            MdqlError::QueryParse(format!(
121                "Unexpected end of query, expected {}",
122                value.unwrap_or(type_)
123            ))
124        })?;
125        let matches_type = t.token_type == type_;
126        let matches_value = value.map_or(true, |v| t.value == v);
127        if !matches_type || !matches_value {
128            return Err(MdqlError::QueryParse(format!(
129                "Expected {}, got '{}' at position {}",
130                value.unwrap_or(type_),
131                t.raw,
132                self.pos
133            )));
134        }
135        Ok(self.advance())
136    }
137
138    fn match_keyword(&mut self, kw: &str) -> bool {
139        if let Some(t) = self.peek() {
140            if t.token_type == "keyword" && t.value == kw {
141                self.advance();
142                return true;
143            }
144        }
145        false
146    }
147
148    fn parse_statement(&mut self) -> Result<Statement, MdqlError> {
149        let t = self.peek().ok_or_else(|| MdqlError::QueryParse("Empty query".into()))?;
150        match (t.token_type.as_str(), t.value.as_str()) {
151            ("keyword", "SELECT") => Ok(Statement::Select(self.parse_select()?)),
152            ("keyword", "INSERT") => Ok(Statement::Insert(self.parse_insert()?)),
153            ("keyword", "UPDATE") => Ok(Statement::Update(self.parse_update()?)),
154            ("keyword", "DELETE") => Ok(Statement::Delete(self.parse_delete()?)),
155            ("keyword", "ALTER") => self.parse_alter(),
156            ("keyword", "CREATE") => self.parse_create_view(),
157            ("keyword", "DROP") => self.parse_drop_view(),
158            _ => Err(MdqlError::QueryParse(format!(
159                "Expected SELECT, INSERT, UPDATE, DELETE, ALTER, CREATE, or DROP, got '{}'",
160                t.raw
161            ))),
162        }
163    }
164
165    fn parse_select(&mut self) -> Result<SelectQuery, MdqlError> {
166        self.expect("keyword", Some("SELECT"))?;
167        let columns = self.parse_columns()?;
168        self.expect("keyword", Some("FROM"))?;
169        let table = self.parse_ident()?;
170
171        // Optional table alias
172        let mut table_alias = None;
173        if let Some(t) = self.peek() {
174            if t.token_type == "ident" && !self.is_clause_keyword(t) {
175                table_alias = Some(self.advance().value);
176            }
177        }
178
179        // Optional JOIN(s)
180        let mut joins = Vec::new();
181        while self.match_keyword("JOIN") {
182            let join_table = self.parse_ident()?;
183            let mut join_alias = None;
184            if let Some(t) = self.peek() {
185                if t.token_type == "ident" && !self.is_clause_keyword(t) {
186                    join_alias = Some(self.advance().value);
187                }
188            }
189            self.expect("keyword", Some("ON"))?;
190            let left_col = self.parse_ident()?;
191            self.expect("op", Some("="))?;
192            let right_col = self.parse_ident()?;
193            joins.push(JoinClause {
194                table: join_table,
195                alias: join_alias,
196                left_col,
197                right_col,
198            });
199        }
200
201        let mut where_clause = None;
202        if self.match_keyword("WHERE") {
203            where_clause = Some(self.parse_or_expr()?);
204        }
205
206        let mut group_by = None;
207        if self.match_keyword("GROUP") {
208            self.expect("keyword", Some("BY"))?;
209            let mut cols = vec![self.parse_ident()?];
210            while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
211                self.advance();
212                cols.push(self.parse_ident()?);
213            }
214            group_by = Some(cols);
215        }
216
217        let mut having = None;
218        if self.match_keyword("HAVING") {
219            having = Some(self.parse_or_expr()?);
220        }
221
222        let mut order_by = None;
223        if self.match_keyword("ORDER") {
224            self.expect("keyword", Some("BY"))?;
225            order_by = Some(self.parse_order_by()?);
226        }
227
228        let mut limit = None;
229        if self.match_keyword("LIMIT") {
230            let t = self.expect("number", None)?;
231            limit = Some(t.value.parse::<i64>().map_err(|_| {
232                MdqlError::QueryParse(format!("Invalid LIMIT value: {}", t.value))
233            })?);
234        }
235
236        self.expect_end()?;
237
238        Ok(SelectQuery {
239            columns,
240            table,
241            table_alias,
242            joins,
243            where_clause,
244            group_by,
245            having,
246            order_by,
247            limit,
248        })
249    }
250
251    fn parse_insert(&mut self) -> Result<InsertQuery, MdqlError> {
252        self.expect("keyword", Some("INSERT"))?;
253        self.expect("keyword", Some("INTO"))?;
254        let table = self.parse_ident()?;
255
256        self.expect("op", Some("("))?;
257        let mut columns = vec![self.parse_ident()?];
258        while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
259            self.advance();
260            columns.push(self.parse_ident()?);
261        }
262        self.expect("op", Some(")"))?;
263
264        self.expect("keyword", Some("VALUES"))?;
265
266        self.expect("op", Some("("))?;
267        let mut values = vec![self.parse_value()?];
268        while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
269            self.advance();
270            values.push(self.parse_value()?);
271        }
272        self.expect("op", Some(")"))?;
273
274        if columns.len() != values.len() {
275            return Err(MdqlError::QueryParse(format!(
276                "Column count ({}) does not match value count ({})",
277                columns.len(),
278                values.len()
279            )));
280        }
281
282        self.expect_end()?;
283        Ok(InsertQuery {
284            table,
285            columns,
286            values,
287        })
288    }
289
290    fn parse_update(&mut self) -> Result<UpdateQuery, MdqlError> {
291        self.expect("keyword", Some("UPDATE"))?;
292        let table = self.parse_ident()?;
293        self.expect("keyword", Some("SET"))?;
294
295        let mut assignments = Vec::new();
296        let col = self.parse_ident()?;
297        self.expect("op", Some("="))?;
298        let val = self.parse_value()?;
299        assignments.push((col, val));
300
301        while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
302            self.advance();
303            let col = self.parse_ident()?;
304            self.expect("op", Some("="))?;
305            let val = self.parse_value()?;
306            assignments.push((col, val));
307        }
308
309        let mut where_clause = None;
310        if self.match_keyword("WHERE") {
311            where_clause = Some(self.parse_or_expr()?);
312        }
313
314        self.expect_end()?;
315        Ok(UpdateQuery {
316            table,
317            assignments,
318            where_clause,
319        })
320    }
321
322    fn parse_delete(&mut self) -> Result<DeleteQuery, MdqlError> {
323        self.expect("keyword", Some("DELETE"))?;
324        self.expect("keyword", Some("FROM"))?;
325        let table = self.parse_ident()?;
326
327        let mut where_clause = None;
328        if self.match_keyword("WHERE") {
329            where_clause = Some(self.parse_or_expr()?);
330        }
331
332        self.expect_end()?;
333        Ok(DeleteQuery {
334            table,
335            where_clause,
336        })
337    }
338
339    fn parse_alter(&mut self) -> Result<Statement, MdqlError> {
340        self.expect("keyword", Some("ALTER"))?;
341        self.expect("keyword", Some("TABLE"))?;
342        let table = self.parse_ident()?;
343
344        let t = self.peek().ok_or_else(|| {
345            MdqlError::QueryParse("Expected RENAME, DROP, or MERGE after table name".into())
346        })?;
347
348        match (t.token_type.as_str(), t.value.as_str()) {
349            ("keyword", "RENAME") => {
350                self.advance();
351                self.expect("keyword", Some("FIELD"))?;
352                let old_name = self.parse_string_or_ident()?;
353                self.expect("keyword", Some("TO"))?;
354                let new_name = self.parse_string_or_ident()?;
355                self.expect_end()?;
356                Ok(Statement::AlterRename(AlterRenameFieldQuery {
357                    table,
358                    old_name,
359                    new_name,
360                }))
361            }
362            ("keyword", "DROP") => {
363                self.advance();
364                self.expect("keyword", Some("FIELD"))?;
365                let field_name = self.parse_string_or_ident()?;
366                self.expect_end()?;
367                Ok(Statement::AlterDrop(AlterDropFieldQuery {
368                    table,
369                    field_name,
370                }))
371            }
372            ("keyword", "MERGE") => {
373                self.advance();
374                self.expect("keyword", Some("FIELDS"))?;
375                let mut sources = vec![self.parse_string_or_ident()?];
376                while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
377                    self.advance();
378                    sources.push(self.parse_string_or_ident()?);
379                }
380                self.expect("keyword", Some("INTO"))?;
381                let target = self.parse_string_or_ident()?;
382                self.expect_end()?;
383                Ok(Statement::AlterMerge(AlterMergeFieldsQuery {
384                    table,
385                    sources,
386                    into: target,
387                }))
388            }
389            _ => Err(MdqlError::QueryParse(format!(
390                "Expected RENAME, DROP, or MERGE, got '{}'",
391                t.raw
392            ))),
393        }
394    }
395
396    fn parse_create_view(&mut self) -> Result<Statement, MdqlError> {
397        self.expect("keyword", Some("CREATE"))?;
398        self.expect("keyword", Some("VIEW"))?;
399        let view_name = self.parse_ident()?;
400
401        let columns = if self.peek().map_or(false, |t| t.token_type == "op" && t.value == "(") {
402            self.advance();
403            let mut cols = vec![self.parse_ident()?];
404            while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
405                self.advance();
406                cols.push(self.parse_ident()?);
407            }
408            self.expect("op", Some(")"))?;
409            Some(cols)
410        } else {
411            None
412        };
413
414        self.expect("keyword", Some("AS"))?;
415        let query = Box::new(self.parse_select()?);
416
417        Ok(Statement::CreateView(CreateViewQuery {
418            view_name,
419            columns,
420            query,
421        }))
422    }
423
424    fn parse_drop_view(&mut self) -> Result<Statement, MdqlError> {
425        self.expect("keyword", Some("DROP"))?;
426        self.expect("keyword", Some("VIEW"))?;
427        let view_name = self.parse_ident()?;
428        self.expect_end()?;
429        Ok(Statement::DropView(DropViewQuery { view_name }))
430    }
431
432    fn parse_string_or_ident(&mut self) -> Result<String, MdqlError> {
433        let t = self.peek().ok_or_else(|| {
434            MdqlError::QueryParse("Expected field name, got end of query".into())
435        })?;
436        match t.token_type.as_str() {
437            "string" => {
438                let v = self.advance().value;
439                Ok(v)
440            }
441            "ident" | "keyword" => {
442                let v = self.advance().value;
443                Ok(v)
444            }
445            _ => Err(MdqlError::QueryParse(format!(
446                "Expected field name, got '{}'",
447                t.raw
448            ))),
449        }
450    }
451
452    fn parse_columns(&mut self) -> Result<ColumnList, MdqlError> {
453        if let Some(t) = self.peek() {
454            if t.token_type == "op" && t.value == "*" {
455                self.advance();
456                return Ok(ColumnList::All);
457            }
458        }
459
460        let mut exprs = vec![self.parse_select_expr()?];
461        while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
462            self.advance();
463            exprs.push(self.parse_select_expr()?);
464        }
465        Ok(ColumnList::Named(exprs))
466    }
467
468    fn peek_is_agg_func(&self) -> bool {
469        let t = match self.peek() {
470            Some(t) => t,
471            None => return false,
472        };
473        let name_upper = t.value.to_uppercase();
474        if !AGG_FUNCS.contains(&name_upper.as_str()) {
475            return false;
476        }
477        // Only treat as aggregate if followed by (
478        self.tokens
479            .get(self.pos + 1)
480            .map_or(false, |next| next.token_type == "op" && next.value == "(")
481    }
482
483    fn parse_select_expr(&mut self) -> Result<SelectExpr, MdqlError> {
484        let _t = self.peek().ok_or_else(|| {
485            MdqlError::QueryParse("Expected column or aggregate, got end of query".into())
486        })?;
487
488        if self.peek_is_agg_func() {
489            let func_name = self.advance().value.to_uppercase();
490            let func = match func_name.as_str() {
491                "COUNT" => AggFunc::Count,
492                "SUM" => AggFunc::Sum,
493                "AVG" => AggFunc::Avg,
494                "MIN" => AggFunc::Min,
495                "MAX" => AggFunc::Max,
496                _ => unreachable!(),
497            };
498            self.expect("op", Some("("))?;
499            let (arg, arg_expr) = if self.peek().map_or(false, |t| t.token_type == "op" && t.value == "*") {
500                self.advance();
501                ("*".to_string(), None)
502            } else {
503                // Parse a full expression inside the aggregate (supports CASE WHEN, arithmetic, etc.)
504                let expr = self.parse_additive()?;
505                if let Expr::Column(name) = &expr {
506                    (name.clone(), None)
507                } else {
508                    (expr.display_name(), Some(expr))
509                }
510            };
511            self.expect("op", Some(")"))?;
512
513            let alias = if self.match_keyword("AS") {
514                Some(self.parse_ident()?)
515            } else if self.peek().map_or(false, |t| {
516                t.token_type == "ident" && !self.is_clause_keyword(t)
517            }) {
518                Some(self.advance().value)
519            } else {
520                None
521            };
522
523            Ok(SelectExpr::Aggregate { func, arg, arg_expr, alias })
524        } else {
525            // Parse a general expression (could be a column, literal, or arithmetic)
526            let expr = self.parse_additive()?;
527
528            // Optional alias: explicit (AS alias) or implicit (just an ident)
529            let alias = if self.match_keyword("AS") {
530                Some(self.parse_ident()?)
531            } else if self.peek().map_or(false, |t| {
532                t.token_type == "ident" && !self.is_clause_keyword(t)
533            }) {
534                Some(self.advance().value)
535            } else {
536                None
537            };
538
539            // If it's a simple column reference with no alias, return Column variant
540            // for backward compatibility
541            if alias.is_none() {
542                if let Expr::Column(name) = &expr {
543                    return Ok(SelectExpr::Column(name.clone()));
544                }
545            }
546
547            Ok(SelectExpr::Expr { expr, alias })
548        }
549    }
550
551    // ── Expression parser (precedence climbing) ───────────────────────
552
553    fn peek_is_additive_op(&self) -> bool {
554        self.peek().map_or(false, |t| {
555            t.token_type == "op" && (t.value == "+" || t.value == "-")
556        })
557    }
558
559    fn peek_is_multiplicative_op(&self) -> bool {
560        self.peek().map_or(false, |t| {
561            t.token_type == "op" && (t.value == "*" || t.value == "/" || t.value == "%")
562        })
563    }
564
565    fn parse_additive(&mut self) -> Result<Expr, MdqlError> {
566        let mut left = self.parse_multiplicative()?;
567        while self.peek_is_additive_op() {
568            let op_tok = self.advance();
569            let is_sub = op_tok.value == "-";
570
571            // Check for INTERVAL keyword: expr +/- INTERVAL n DAY
572            if self.peek().map_or(false, |t| t.token_type == "keyword" && t.value == "INTERVAL") {
573                self.advance(); // consume INTERVAL
574                let days_expr = self.parse_multiplicative()?;
575                // Expect DAY or DAYS
576                if !self.match_keyword("DAY") && !self.match_keyword("DAYS") {
577                    return Err(MdqlError::QueryParse("Expected DAY after INTERVAL value".into()));
578                }
579                let days = if is_sub {
580                    Expr::UnaryMinus(Box::new(days_expr))
581                } else {
582                    days_expr
583                };
584                left = Expr::DateAdd {
585                    date: Box::new(left),
586                    days: Box::new(days),
587                };
588                continue;
589            }
590
591            let op = match op_tok.value.as_str() {
592                "+" => ArithOp::Add,
593                "-" => ArithOp::Sub,
594                _ => unreachable!(),
595            };
596            let right = self.parse_multiplicative()?;
597            left = Expr::BinaryOp {
598                left: Box::new(left),
599                op,
600                right: Box::new(right),
601            };
602        }
603        Ok(left)
604    }
605
606    fn parse_multiplicative(&mut self) -> Result<Expr, MdqlError> {
607        let mut left = self.parse_unary()?;
608        while self.peek_is_multiplicative_op() {
609            let op_tok = self.advance();
610            let op = match op_tok.value.as_str() {
611                "*" => ArithOp::Mul,
612                "/" => ArithOp::Div,
613                "%" => ArithOp::Mod,
614                _ => unreachable!(),
615            };
616            let right = self.parse_unary()?;
617            left = Expr::BinaryOp {
618                left: Box::new(left),
619                op,
620                right: Box::new(right),
621            };
622        }
623        Ok(left)
624    }
625
626    fn parse_unary(&mut self) -> Result<Expr, MdqlError> {
627        if self.peek().map_or(false, |t| t.token_type == "op" && t.value == "-") {
628            self.advance();
629            let inner = self.parse_atom()?;
630            // Fold unary minus on literals
631            match inner {
632                Expr::Literal(SqlValue::Int(n)) => Ok(Expr::Literal(SqlValue::Int(-n))),
633                Expr::Literal(SqlValue::Float(f)) => Ok(Expr::Literal(SqlValue::Float(-f))),
634                _ => Ok(Expr::UnaryMinus(Box::new(inner))),
635            }
636        } else {
637            self.parse_atom()
638        }
639    }
640
641    fn parse_atom(&mut self) -> Result<Expr, MdqlError> {
642        let t = self.peek().ok_or_else(|| {
643            MdqlError::QueryParse("Expected expression, got end of query".into())
644        })?;
645
646        match t.token_type.as_str() {
647            "number" => {
648                let v = self.advance().value;
649                if v.contains('.') {
650                    let f: f64 = v.parse().map_err(|_| {
651                        MdqlError::QueryParse(format!("Invalid float: {}", v))
652                    })?;
653                    Ok(Expr::Literal(SqlValue::Float(f)))
654                } else {
655                    let n: i64 = v.parse().map_err(|_| {
656                        MdqlError::QueryParse(format!("Invalid int: {}", v))
657                    })?;
658                    Ok(Expr::Literal(SqlValue::Int(n)))
659                }
660            }
661            "string" => {
662                let v = self.advance().value;
663                Ok(Expr::Literal(SqlValue::String(v)))
664            }
665            "keyword" if t.value == "NULL" => {
666                self.advance();
667                Ok(Expr::Literal(SqlValue::Null))
668            }
669            "keyword" if t.value == "CASE" => {
670                self.parse_case_expr()
671            }
672            "keyword" if t.value == "CURRENT_DATE" => {
673                self.advance();
674                Ok(Expr::CurrentDate)
675            }
676            "keyword" if t.value == "CURRENT_TIMESTAMP" => {
677                self.advance();
678                Ok(Expr::CurrentTimestamp)
679            }
680            "keyword" if t.value == "DATEDIFF" => {
681                self.advance();
682                self.expect("op", Some("("))?;
683                let left = self.parse_additive()?;
684                self.expect("op", Some(","))?;
685                let right = self.parse_additive()?;
686                self.expect("op", Some(")"))?;
687                Ok(Expr::DateDiff { left: Box::new(left), right: Box::new(right) })
688            }
689            "op" if t.value == "(" => {
690                self.advance();
691                let expr = self.parse_additive()?;
692                self.expect("op", Some(")"))?;
693                Ok(expr)
694            }
695            "ident" => {
696                let name = self.advance().value;
697                Ok(Expr::Column(name))
698            }
699            "keyword" if !Self::is_reserved_keyword(&t.value) => {
700                let name = self.advance().value;
701                Ok(Expr::Column(name))
702            }
703            _ => Err(MdqlError::QueryParse(format!(
704                "Expected expression, got '{}'",
705                t.raw
706            ))),
707        }
708    }
709
710    fn parse_case_expr(&mut self) -> Result<Expr, MdqlError> {
711        self.expect("keyword", Some("CASE"))?;
712        let mut whens = Vec::new();
713        while self.match_keyword("WHEN") {
714            let condition = self.parse_or_expr()?;
715            self.expect("keyword", Some("THEN"))?;
716            let result = self.parse_additive()?;
717            whens.push((condition, Box::new(result)));
718        }
719        if whens.is_empty() {
720            return Err(MdqlError::QueryParse("CASE requires at least one WHEN clause".into()));
721        }
722        let else_expr = if self.match_keyword("ELSE") {
723            Some(Box::new(self.parse_additive()?))
724        } else {
725            None
726        };
727        self.expect("keyword", Some("END"))?;
728        Ok(Expr::Case { whens, else_expr })
729    }
730
731    fn parse_ident(&mut self) -> Result<String, MdqlError> {
732        let t = self.peek().ok_or_else(|| {
733            MdqlError::QueryParse("Expected identifier, got end of query".into())
734        })?;
735        match t.token_type.as_str() {
736            "ident" | "keyword" => {
737                let v = self.advance().value;
738                Ok(v)
739            }
740            _ => Err(MdqlError::QueryParse(format!(
741                "Expected identifier, got '{}'",
742                t.raw
743            ))),
744        }
745    }
746
747    fn parse_or_expr(&mut self) -> Result<WhereClause, MdqlError> {
748        let mut left = self.parse_and_expr()?;
749        while self.match_keyword("OR") {
750            let right = self.parse_and_expr()?;
751            left = WhereClause::BoolOp(BoolOp {
752                op: "OR".into(),
753                left: Box::new(left),
754                right: Box::new(right),
755            });
756        }
757        Ok(left)
758    }
759
760    fn parse_and_expr(&mut self) -> Result<WhereClause, MdqlError> {
761        let mut left = self.parse_comparison()?;
762        while self.match_keyword("AND") {
763            let right = self.parse_comparison()?;
764            left = WhereClause::BoolOp(BoolOp {
765                op: "AND".into(),
766                left: Box::new(left),
767                right: Box::new(right),
768            });
769        }
770        Ok(left)
771    }
772
773    fn parse_comparison(&mut self) -> Result<WhereClause, MdqlError> {
774        // Handle parenthesized boolean expressions
775        if self.peek().map_or(false, |t| t.token_type == "op" && t.value == "(") {
776            // Save position — might be arithmetic parens, not boolean
777            let saved_pos = self.pos;
778            self.advance();
779            // Try parsing as boolean (OR/AND) expression
780            let result = self.parse_or_expr();
781            if result.is_ok() && self.peek().map_or(false, |t| t.token_type == "op" && t.value == ")") {
782                self.advance();
783                return result;
784            }
785            // Not a boolean paren — rewind and parse as arithmetic expression
786            self.pos = saved_pos;
787        }
788
789        // Parse the left side as a full expression (column, literal, or arithmetic)
790        let left_expr = self.parse_additive()?;
791
792        // Extract column name for backward compat (simple column on left side)
793        let col = left_expr.as_column().unwrap_or("").to_string();
794
795        // IS NULL / IS NOT NULL (only valid with simple column)
796        if self.match_keyword("IS") {
797            if self.match_keyword("NOT") {
798                self.expect("keyword", Some("NULL"))?;
799                return Ok(WhereClause::Comparison(Comparison {
800                    column: col,
801                    op: "IS NOT NULL".into(),
802                    value: None,
803                    left_expr: Some(left_expr),
804                    right_expr: None,
805                }));
806            }
807            self.expect("keyword", Some("NULL"))?;
808            return Ok(WhereClause::Comparison(Comparison {
809                column: col,
810                op: "IS NULL".into(),
811                value: None,
812                left_expr: Some(left_expr),
813                right_expr: None,
814            }));
815        }
816
817        // IN (val, val, ...)
818        if self.match_keyword("IN") {
819            self.expect("op", Some("("))?;
820            let mut values = vec![self.parse_value()?];
821            while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
822                self.advance();
823                values.push(self.parse_value()?);
824            }
825            self.expect("op", Some(")"))?;
826            return Ok(WhereClause::Comparison(Comparison {
827                column: col,
828                op: "IN".into(),
829                value: Some(SqlValue::List(values)),
830                left_expr: Some(left_expr),
831                right_expr: None,
832            }));
833        }
834
835        // LIKE
836        if self.match_keyword("LIKE") {
837            let val = self.parse_value()?;
838            return Ok(WhereClause::Comparison(Comparison {
839                column: col,
840                op: "LIKE".into(),
841                value: Some(val),
842                left_expr: Some(left_expr),
843                right_expr: None,
844            }));
845        }
846
847        // NOT LIKE
848        if self.match_keyword("NOT") {
849            if self.match_keyword("LIKE") {
850                let val = self.parse_value()?;
851                return Ok(WhereClause::Comparison(Comparison {
852                    column: col,
853                    op: "NOT LIKE".into(),
854                    value: Some(val),
855                    left_expr: Some(left_expr),
856                    right_expr: None,
857                }));
858            }
859            return Err(MdqlError::QueryParse("Expected LIKE after NOT".into()));
860        }
861
862        // Standard comparison operators
863        if let Some(t) = self.peek() {
864            if t.token_type == "op" && ["=", "!=", "<", ">", "<=", ">="].contains(&t.value.as_str())
865            {
866                let op = self.advance().value;
867                // Parse right side as expression
868                let right_expr = self.parse_additive()?;
869                // Extract SqlValue for backward compat (simple literal on right side)
870                let value = match &right_expr {
871                    Expr::Literal(v) => Some(v.clone()),
872                    _ => None,
873                };
874                return Ok(WhereClause::Comparison(Comparison {
875                    column: col,
876                    op,
877                    value,
878                    left_expr: Some(left_expr),
879                    right_expr: Some(right_expr),
880                }));
881            }
882        }
883
884        let got = self.peek().map_or("end".to_string(), |t| t.raw.clone());
885        Err(MdqlError::QueryParse(format!(
886            "Expected operator after '{}', got '{}'",
887            left_expr.display_name(), got
888        )))
889    }
890
891    fn parse_value(&mut self) -> Result<SqlValue, MdqlError> {
892        let t = self.peek().ok_or_else(|| {
893            MdqlError::QueryParse("Expected value, got end of query".into())
894        })?;
895        match t.token_type.as_str() {
896            "string" => {
897                let v = self.advance().value;
898                Ok(SqlValue::String(v))
899            }
900            "number" => {
901                let v = self.advance().value;
902                if v.contains('.') {
903                    Ok(SqlValue::Float(v.parse().map_err(|_| {
904                        MdqlError::QueryParse(format!("Invalid float: {}", v))
905                    })?))
906                } else {
907                    Ok(SqlValue::Int(v.parse().map_err(|_| {
908                        MdqlError::QueryParse(format!("Invalid int: {}", v))
909                    })?))
910                }
911            }
912            "keyword" if t.value == "NULL" => {
913                self.advance();
914                Ok(SqlValue::Null)
915            }
916            _ => Err(MdqlError::QueryParse(format!(
917                "Expected value, got '{}'",
918                t.raw
919            ))),
920        }
921    }
922
923    fn parse_order_by(&mut self) -> Result<Vec<OrderSpec>, MdqlError> {
924        let mut specs = vec![self.parse_order_spec()?];
925        while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
926            self.advance();
927            specs.push(self.parse_order_spec()?);
928        }
929        Ok(specs)
930    }
931
932    fn parse_order_spec(&mut self) -> Result<OrderSpec, MdqlError> {
933        let expr = self.parse_additive()?;
934        let col = expr.as_column().unwrap_or("").to_string();
935        let descending = if self.match_keyword("DESC") {
936            true
937        } else {
938            self.match_keyword("ASC");
939            false
940        };
941        Ok(OrderSpec {
942            column: col,
943            expr: Some(expr),
944            descending,
945        })
946    }
947
948    fn is_clause_keyword(&self, t: &Token) -> bool {
949        t.token_type == "keyword"
950            && ["WHERE", "ORDER", "LIMIT", "JOIN", "ON", "GROUP"].contains(&t.value.as_str())
951    }
952
953    /// Keywords that should never be consumed as column names inside expressions.
954    fn is_reserved_keyword(kw: &str) -> bool {
955        matches!(kw,
956            "AS" | "FROM" | "WHERE" | "AND" | "OR" | "ORDER" | "BY"
957            | "ASC" | "DESC" | "LIMIT" | "JOIN" | "ON" | "GROUP"
958            | "SELECT" | "INSERT" | "INTO" | "VALUES" | "UPDATE" | "SET"
959            | "DELETE" | "ALTER" | "TABLE" | "IS" | "NOT" | "IN" | "LIKE"
960            | "RENAME" | "FIELD" | "TO" | "DROP" | "MERGE" | "FIELDS"
961            | "CASE" | "WHEN" | "THEN" | "ELSE" | "END"
962            | "HAVING" | "INTERVAL" | "DAY" | "DAYS"
963            | "CURRENT_DATE" | "CURRENT_TIMESTAMP" | "DATEDIFF"
964            | "CREATE" | "VIEW" | "CASCADE" | "RESTRICT"
965        )
966    }
967
968    fn expect_end(&self) -> Result<(), MdqlError> {
969        if let Some(t) = self.peek() {
970            return Err(MdqlError::QueryParse(format!(
971                "Unexpected token '{}' at position {}",
972                t.raw, self.pos
973            )));
974        }
975        Ok(())
976    }
977}
978
979pub fn parse_query(sql: &str) -> crate::errors::Result<Statement> {
980    let tokens = tokenize(sql);
981    if tokens.is_empty() {
982        return Err(MdqlError::QueryParse("Empty query".into()));
983    }
984    let mut parser = Parser::new(tokens);
985    parser.parse_statement()
986}
987
// Unit tests for `parse_query`: statement shapes, operators, expressions,
// CASE WHEN, views, and error paths.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_simple_select() {
        let stmt = parse_query("SELECT title, status FROM strategies").unwrap();
        if let Statement::Select(q) = stmt {
            assert_eq!(q.columns, ColumnList::Named(vec![SelectExpr::Column("title".into()), SelectExpr::Column("status".into())]));
            assert_eq!(q.table, "strategies");
        } else {
            panic!("Expected Select");
        }
    }

    #[test]
    fn test_select_star() {
        let stmt = parse_query("SELECT * FROM test").unwrap();
        if let Statement::Select(q) = stmt {
            assert_eq!(q.columns, ColumnList::All);
        } else {
            panic!("Expected Select");
        }
    }

    #[test]
    fn test_where_clause() {
        let stmt = parse_query("SELECT title FROM test WHERE count > 5").unwrap();
        if let Statement::Select(q) = stmt {
            assert!(q.where_clause.is_some());
        } else {
            panic!("Expected Select");
        }
    }

    #[test]
    fn test_order_by() {
        let stmt =
            parse_query("SELECT title FROM test ORDER BY composite DESC, title ASC").unwrap();
        if let Statement::Select(q) = stmt {
            let ob = q.order_by.unwrap();
            assert_eq!(ob.len(), 2);
            assert!(ob[0].descending);
            assert!(!ob[1].descending);
        } else {
            panic!("Expected Select");
        }
    }

    #[test]
    fn test_limit() {
        let stmt = parse_query("SELECT * FROM test LIMIT 10").unwrap();
        if let Statement::Select(q) = stmt {
            assert_eq!(q.limit, Some(10));
        } else {
            panic!("Expected Select");
        }
    }

    #[test]
    fn test_insert() {
        let stmt = parse_query(
            "INSERT INTO test (title, count) VALUES ('Hello', 42)",
        )
        .unwrap();
        if let Statement::Insert(q) = stmt {
            assert_eq!(q.table, "test");
            assert_eq!(q.columns, vec!["title", "count"]);
            assert_eq!(q.values[0], SqlValue::String("Hello".into()));
            assert_eq!(q.values[1], SqlValue::Int(42));
        } else {
            panic!("Expected Insert");
        }
    }

    #[test]
    fn test_update() {
        let stmt = parse_query("UPDATE test SET status = 'KILLED' WHERE path = 'a.md'").unwrap();
        if let Statement::Update(q) = stmt {
            assert_eq!(q.table, "test");
            assert_eq!(q.assignments.len(), 1);
            assert!(q.where_clause.is_some());
        } else {
            panic!("Expected Update");
        }
    }

    #[test]
    fn test_delete() {
        let stmt = parse_query("DELETE FROM test WHERE status = 'draft'").unwrap();
        if let Statement::Delete(q) = stmt {
            assert_eq!(q.table, "test");
            assert!(q.where_clause.is_some());
        } else {
            panic!("Expected Delete");
        }
    }

    #[test]
    fn test_alter_rename() {
        let stmt =
            parse_query("ALTER TABLE test RENAME FIELD 'Summary' TO 'Overview'").unwrap();
        if let Statement::AlterRename(q) = stmt {
            assert_eq!(q.old_name, "Summary");
            assert_eq!(q.new_name, "Overview");
        } else {
            panic!("Expected AlterRename");
        }
    }

    #[test]
    fn test_alter_drop() {
        let stmt = parse_query("ALTER TABLE test DROP FIELD 'Details'").unwrap();
        if let Statement::AlterDrop(q) = stmt {
            assert_eq!(q.field_name, "Details");
        } else {
            panic!("Expected AlterDrop");
        }
    }

    #[test]
    fn test_alter_merge() {
        let stmt = parse_query(
            "ALTER TABLE test MERGE FIELDS 'Entry Rules', 'Exit Rules' INTO 'Trading Rules'",
        )
        .unwrap();
        if let Statement::AlterMerge(q) = stmt {
            assert_eq!(q.sources, vec!["Entry Rules", "Exit Rules"]);
            assert_eq!(q.into, "Trading Rules");
        } else {
            panic!("Expected AlterMerge");
        }
    }

    #[test]
    fn test_backtick_ident() {
        let stmt = parse_query("SELECT `Structural Mechanism` FROM test").unwrap();
        if let Statement::Select(q) = stmt {
            assert_eq!(
                q.columns,
                ColumnList::Named(vec![SelectExpr::Column("Structural Mechanism".into())])
            );
        } else {
            panic!("Expected Select");
        }
    }

    #[test]
    fn test_like_operator() {
        let stmt = parse_query("SELECT title FROM test WHERE categories LIKE '%defi%'").unwrap();
        if let Statement::Select(q) = stmt {
            if let Some(WhereClause::Comparison(c)) = q.where_clause {
                assert_eq!(c.op, "LIKE");
                assert_eq!(c.value, Some(SqlValue::String("%defi%".into())));
            } else {
                panic!("Expected LIKE comparison");
            }
        } else {
            panic!("Expected Select");
        }
    }

    #[test]
    fn test_in_operator() {
        let stmt =
            parse_query("SELECT * FROM test WHERE status IN ('ACTIVE', 'LIVE')").unwrap();
        if let Statement::Select(q) = stmt {
            if let Some(WhereClause::Comparison(c)) = q.where_clause {
                assert_eq!(c.op, "IN");
            } else {
                panic!("Expected IN comparison");
            }
        } else {
            panic!("Expected Select");
        }
    }

    #[test]
    fn test_is_null() {
        let stmt = parse_query("SELECT * FROM test WHERE title IS NULL").unwrap();
        if let Statement::Select(q) = stmt {
            if let Some(WhereClause::Comparison(c)) = q.where_clause {
                assert_eq!(c.op, "IS NULL");
            } else {
                panic!("Expected IS NULL comparison");
            }
        } else {
            panic!("Expected Select");
        }
    }

    #[test]
    fn test_and_or() {
        let stmt = parse_query(
            "SELECT * FROM test WHERE status = 'ACTIVE' AND count > 5 OR title LIKE '%test%'",
        )
        .unwrap();
        if let Statement::Select(q) = stmt {
            assert!(q.where_clause.is_some());
        } else {
            panic!("Expected Select");
        }
    }

    #[test]
    fn test_join() {
        let stmt = parse_query(
            "SELECT s.title, b.sharpe FROM strategies s JOIN backtests b ON b.strategy = s.path",
        )
        .unwrap();
        if let Statement::Select(q) = stmt {
            assert_eq!(q.table, "strategies");
            assert_eq!(q.table_alias, Some("s".into()));
            assert_eq!(q.joins.len(), 1);
            let join = &q.joins[0];
            assert_eq!(join.table, "backtests");
            assert_eq!(join.alias, Some("b".into()));
        } else {
            panic!("Expected Select");
        }
    }

    #[test]
    fn test_multi_join() {
        let stmt = parse_query(
            "SELECT s.title, b.sharpe, c.verdict FROM strategies s JOIN backtests b ON b.strategy = s.path JOIN critiques c ON c.strategy = s.path",
        )
        .unwrap();
        if let Statement::Select(q) = stmt {
            assert_eq!(q.table, "strategies");
            assert_eq!(q.table_alias, Some("s".into()));
            assert_eq!(q.joins.len(), 2);
            assert_eq!(q.joins[0].table, "backtests");
            assert_eq!(q.joins[0].alias, Some("b".into()));
            assert_eq!(q.joins[0].left_col, "b.strategy");
            assert_eq!(q.joins[0].right_col, "s.path");
            assert_eq!(q.joins[1].table, "critiques");
            assert_eq!(q.joins[1].alias, Some("c".into()));
            assert_eq!(q.joins[1].left_col, "c.strategy");
            assert_eq!(q.joins[1].right_col, "s.path");
        } else {
            panic!("Expected Select");
        }
    }

    #[test]
    fn test_empty_query() {
        assert!(parse_query("").is_err());
    }

    #[test]
    fn test_count_star() {
        let stmt = parse_query("SELECT status, COUNT(*) AS cnt FROM strategies GROUP BY status").unwrap();
        if let Statement::Select(q) = stmt {
            if let ColumnList::Named(exprs) = &q.columns {
                assert_eq!(exprs.len(), 2);
                assert_eq!(exprs[0], SelectExpr::Column("status".into()));
                assert!(matches!(&exprs[1], SelectExpr::Aggregate {
                    func: AggFunc::Count,
                    arg,
                    alias: Some(a),
                    ..
                } if arg == "*" && a == "cnt"));
            } else {
                panic!("Expected Named columns");
            }
            assert_eq!(q.group_by, Some(vec!["status".into()]));
        } else {
            panic!("Expected Select");
        }
    }

    #[test]
    fn test_count_column_as_ident() {
        // "count" as a column name should NOT be parsed as the COUNT aggregate
        let stmt = parse_query("INSERT INTO test (title, count) VALUES ('Hello', 42)").unwrap();
        if let Statement::Insert(q) = stmt {
            assert_eq!(q.columns, vec!["title", "count"]);
        } else {
            panic!("Expected Insert");
        }
    }

    #[test]
    fn test_multiple_aggregates() {
        let stmt = parse_query("SELECT MIN(composite), MAX(composite), AVG(composite) FROM strategies").unwrap();
        if let Statement::Select(q) = stmt {
            if let ColumnList::Named(exprs) = &q.columns {
                assert_eq!(exprs.len(), 3);
                assert!(matches!(&exprs[0], SelectExpr::Aggregate { func: AggFunc::Min, .. }));
                assert!(matches!(&exprs[1], SelectExpr::Aggregate { func: AggFunc::Max, .. }));
                assert!(matches!(&exprs[2], SelectExpr::Aggregate { func: AggFunc::Avg, .. }));
            } else {
                panic!("Expected Named columns");
            }
            assert_eq!(q.group_by, None);
        } else {
            panic!("Expected Select");
        }
    }

    // ── Expression tests ──────────────────────────────────────────

    #[test]
    fn test_select_arithmetic_expr() {
        let stmt = parse_query("SELECT a + b FROM test").unwrap();
        if let Statement::Select(q) = stmt {
            if let ColumnList::Named(exprs) = &q.columns {
                assert_eq!(exprs.len(), 1);
                assert!(matches!(&exprs[0], SelectExpr::Expr {
                    expr: Expr::BinaryOp { op: ArithOp::Add, .. },
                    alias: None,
                }));
            } else {
                panic!("Expected Named columns");
            }
        } else {
            panic!("Expected Select");
        }
    }

    #[test]
    fn test_select_arithmetic_with_alias() {
        let stmt = parse_query("SELECT a + b AS total FROM test").unwrap();
        if let Statement::Select(q) = stmt {
            if let ColumnList::Named(exprs) = &q.columns {
                assert_eq!(exprs.len(), 1);
                assert!(matches!(&exprs[0], SelectExpr::Expr {
                    alias: Some(a),
                    ..
                } if a == "total"));
                assert_eq!(exprs[0].output_name(), "total");
            } else {
                panic!("Expected Named columns");
            }
        } else {
            panic!("Expected Select");
        }
    }

    #[test]
    fn test_select_precedence() {
        // a + b * c should parse as a + (b * c)
        let stmt = parse_query("SELECT a + b * c FROM test").unwrap();
        if let Statement::Select(q) = stmt {
            if let ColumnList::Named(exprs) = &q.columns {
                if let SelectExpr::Expr { expr, .. } = &exprs[0] {
                    if let Expr::BinaryOp { left, op, right } = expr {
                        assert_eq!(*op, ArithOp::Add);
                        assert!(matches!(left.as_ref(), Expr::Column(n) if n == "a"));
                        assert!(matches!(right.as_ref(), Expr::BinaryOp { op: ArithOp::Mul, .. }));
                    } else {
                        panic!("Expected BinaryOp");
                    }
                } else {
                    panic!("Expected Expr variant");
                }
            } else {
                panic!("Expected Named columns");
            }
        } else {
            panic!("Expected Select");
        }
    }

    #[test]
    fn test_select_parenthesized_expr() {
        // (a + b) * c should override default precedence
        let stmt = parse_query("SELECT (a + b) * c FROM test").unwrap();
        if let Statement::Select(q) = stmt {
            if let ColumnList::Named(exprs) = &q.columns {
                if let SelectExpr::Expr { expr, .. } = &exprs[0] {
                    if let Expr::BinaryOp { left, op, .. } = expr {
                        assert_eq!(*op, ArithOp::Mul);
                        assert!(matches!(left.as_ref(), Expr::BinaryOp { op: ArithOp::Add, .. }));
                    } else {
                        panic!("Expected BinaryOp");
                    }
                } else {
                    panic!("Expected Expr variant");
                }
            } else {
                panic!("Expected Named columns");
            }
        } else {
            panic!("Expected Select");
        }
    }

    #[test]
    fn test_select_unary_minus() {
        let stmt = parse_query("SELECT -count FROM test").unwrap();
        if let Statement::Select(q) = stmt {
            if let ColumnList::Named(exprs) = &q.columns {
                assert!(matches!(&exprs[0], SelectExpr::Expr {
                    expr: Expr::UnaryMinus(_),
                    ..
                }));
            } else {
                panic!("Expected Named columns");
            }
        } else {
            panic!("Expected Select");
        }
    }

    #[test]
    fn test_select_negative_literal() {
        let stmt = parse_query("SELECT -42 FROM test").unwrap();
        if let Statement::Select(q) = stmt {
            if let ColumnList::Named(exprs) = &q.columns {
                // Unary minus folds into the literal
                assert!(matches!(&exprs[0], SelectExpr::Expr {
                    expr: Expr::Literal(SqlValue::Int(-42)),
                    ..
                }));
            } else {
                panic!("Expected Named columns");
            }
        } else {
            panic!("Expected Select");
        }
    }

    #[test]
    fn test_where_arithmetic_expr() {
        let stmt = parse_query("SELECT * FROM test WHERE a + b > 10").unwrap();
        if let Statement::Select(q) = stmt {
            if let Some(WhereClause::Comparison(c)) = q.where_clause {
                assert_eq!(c.op, ">");
                assert!(matches!(&c.left_expr, Some(Expr::BinaryOp { op: ArithOp::Add, .. })));
                assert!(matches!(&c.right_expr, Some(Expr::Literal(SqlValue::Int(10)))));
            } else {
                panic!("Expected comparison");
            }
        } else {
            panic!("Expected Select");
        }
    }

    #[test]
    fn test_where_both_sides_expr() {
        let stmt = parse_query("SELECT * FROM test WHERE a * 2 > b + 1").unwrap();
        if let Statement::Select(q) = stmt {
            if let Some(WhereClause::Comparison(c)) = q.where_clause {
                assert_eq!(c.op, ">");
                assert!(matches!(&c.left_expr, Some(Expr::BinaryOp { op: ArithOp::Mul, .. })));
                assert!(matches!(&c.right_expr, Some(Expr::BinaryOp { op: ArithOp::Add, .. })));
            } else {
                panic!("Expected comparison");
            }
        } else {
            panic!("Expected Select");
        }
    }

    #[test]
    fn test_order_by_expr() {
        let stmt = parse_query("SELECT * FROM test ORDER BY a + b DESC").unwrap();
        if let Statement::Select(q) = stmt {
            let ob = q.order_by.unwrap();
            assert_eq!(ob.len(), 1);
            assert!(ob[0].descending);
            assert!(matches!(&ob[0].expr, Some(Expr::BinaryOp { op: ArithOp::Add, .. })));
        } else {
            panic!("Expected Select");
        }
    }

    #[test]
    fn test_all_arithmetic_ops() {
        let stmt = parse_query("SELECT a + b, a - b, a * b, a / b, a % b FROM test").unwrap();
        if let Statement::Select(q) = stmt {
            if let ColumnList::Named(exprs) = &q.columns {
                assert_eq!(exprs.len(), 5);
                assert!(matches!(&exprs[0], SelectExpr::Expr { expr: Expr::BinaryOp { op: ArithOp::Add, .. }, .. }));
                assert!(matches!(&exprs[1], SelectExpr::Expr { expr: Expr::BinaryOp { op: ArithOp::Sub, .. }, .. }));
                assert!(matches!(&exprs[2], SelectExpr::Expr { expr: Expr::BinaryOp { op: ArithOp::Mul, .. }, .. }));
                assert!(matches!(&exprs[3], SelectExpr::Expr { expr: Expr::BinaryOp { op: ArithOp::Div, .. }, .. }));
                assert!(matches!(&exprs[4], SelectExpr::Expr { expr: Expr::BinaryOp { op: ArithOp::Mod, .. }, .. }));
            } else {
                panic!("Expected Named columns");
            }
        } else {
            panic!("Expected Select");
        }
    }

    #[test]
    fn test_column_with_literal_arithmetic() {
        let stmt = parse_query("SELECT count * 2 + 1 FROM test").unwrap();
        if let Statement::Select(q) = stmt {
            if let ColumnList::Named(exprs) = &q.columns {
                // Should be (count * 2) + 1
                if let SelectExpr::Expr { expr, .. } = &exprs[0] {
                    if let Expr::BinaryOp { left, op, right } = expr {
                        assert_eq!(*op, ArithOp::Add);
                        assert!(matches!(right.as_ref(), Expr::Literal(SqlValue::Int(1))));
                        assert!(matches!(left.as_ref(), Expr::BinaryOp { op: ArithOp::Mul, .. }));
                    } else {
                        panic!("Expected BinaryOp");
                    }
                } else {
                    panic!("Expected Expr");
                }
            } else {
                panic!("Expected Named columns");
            }
        } else {
            panic!("Expected Select");
        }
    }

    #[test]
    fn test_mixed_columns_and_exprs() {
        let stmt = parse_query("SELECT title, a + b AS sum, count FROM test").unwrap();
        if let Statement::Select(q) = stmt {
            if let ColumnList::Named(exprs) = &q.columns {
                assert_eq!(exprs.len(), 3);
                assert_eq!(exprs[0], SelectExpr::Column("title".into()));
                assert!(matches!(&exprs[1], SelectExpr::Expr { alias: Some(a), .. } if a == "sum"));
                assert_eq!(exprs[2], SelectExpr::Column("count".into()));
            } else {
                panic!("Expected Named columns");
            }
        } else {
            panic!("Expected Select");
        }
    }

    // ── CASE WHEN tests ──────────────────────────────────────────

    #[test]
    fn test_case_when_basic() {
        let stmt = parse_query(
            "SELECT CASE WHEN status = 'ACTIVE' THEN 1 ELSE 0 END FROM test"
        ).unwrap();
        if let Statement::Select(q) = stmt {
            if let ColumnList::Named(exprs) = &q.columns {
                assert_eq!(exprs.len(), 1);
                assert!(matches!(&exprs[0], SelectExpr::Expr {
                    expr: Expr::Case { .. },
                    ..
                }));
            } else {
                panic!("Expected Named columns");
            }
        } else {
            panic!("Expected Select");
        }
    }

    #[test]
    fn test_case_when_multiple_branches() {
        let stmt = parse_query(
            "SELECT CASE WHEN x > 10 THEN 'high' WHEN x > 5 THEN 'mid' ELSE 'low' END FROM test"
        ).unwrap();
        if let Statement::Select(q) = stmt {
            if let ColumnList::Named(exprs) = &q.columns {
                if let SelectExpr::Expr { expr: Expr::Case { whens, else_expr }, .. } = &exprs[0] {
                    assert_eq!(whens.len(), 2);
                    assert!(else_expr.is_some());
                } else {
                    panic!("Expected Case expression");
                }
            } else {
                panic!("Expected Named columns");
            }
        } else {
            panic!("Expected Select");
        }
    }

    #[test]
    fn test_case_when_no_else() {
        let stmt = parse_query(
            "SELECT CASE WHEN x = 1 THEN 'one' END FROM test"
        ).unwrap();
        if let Statement::Select(q) = stmt {
            if let ColumnList::Named(exprs) = &q.columns {
                if let SelectExpr::Expr { expr: Expr::Case { whens, else_expr }, .. } = &exprs[0] {
                    assert_eq!(whens.len(), 1);
                    assert!(else_expr.is_none());
                } else {
                    panic!("Expected Case expression");
                }
            } else {
                panic!("Expected Named columns");
            }
        } else {
            panic!("Expected Select");
        }
    }

    #[test]
    fn test_case_when_in_aggregate() {
        let stmt = parse_query(
            "SELECT SUM(CASE WHEN side = 'BUY' THEN size ELSE -size END) AS net FROM orders GROUP BY token"
        ).unwrap();
        if let Statement::Select(q) = stmt {
            if let ColumnList::Named(exprs) = &q.columns {
                assert_eq!(exprs.len(), 1);
                assert!(matches!(&exprs[0], SelectExpr::Aggregate {
                    func: AggFunc::Sum,
                    arg_expr: Some(Expr::Case { .. }),
                    alias: Some(a),
                    ..
                } if a == "net"));
            } else {
                panic!("Expected Named columns");
            }
        } else {
            panic!("Expected Select");
        }
    }

    #[test]
    fn test_case_when_with_alias() {
        let stmt = parse_query(
            "SELECT CASE WHEN x > 0 THEN 'pos' ELSE 'neg' END AS sign FROM test"
        ).unwrap();
        if let Statement::Select(q) = stmt {
            if let ColumnList::Named(exprs) = &q.columns {
                assert!(matches!(&exprs[0], SelectExpr::Expr {
                    expr: Expr::Case { .. },
                    alias: Some(a),
                } if a == "sign"));
            } else {
                panic!("Expected Named columns");
            }
        } else {
            panic!("Expected Select");
        }
    }

    #[test]
    fn test_create_view() {
        let stmt = parse_query("CREATE VIEW live AS SELECT * FROM strategies WHERE status = 'LIVE'").unwrap();
        if let Statement::CreateView(cv) = stmt {
            assert_eq!(cv.view_name, "live");
            assert!(cv.columns.is_none());
            assert_eq!(cv.query.table, "strategies");
            assert!(cv.query.where_clause.is_some());
        } else {
            panic!("Expected CreateView, got {:?}", stmt);
        }
    }

    #[test]
    fn test_create_view_with_columns() {
        let stmt = parse_query("CREATE VIEW v1 (a, b) AS SELECT title, status FROM t").unwrap();
        if let Statement::CreateView(cv) = stmt {
            assert_eq!(cv.view_name, "v1");
            assert_eq!(cv.columns, Some(vec!["a".into(), "b".into()]));
        } else {
            panic!("Expected CreateView");
        }
    }

    #[test]
    fn test_drop_view() {
        let stmt = parse_query("DROP VIEW live").unwrap();
        if let Statement::DropView(dv) = stmt {
            assert_eq!(dv.view_name, "live");
        } else {
            panic!("Expected DropView, got {:?}", stmt);
        }
    }

    #[test]
    fn test_create_view_case_insensitive() {
        let stmt = parse_query("create view My_View as select * from t").unwrap();
        if let Statement::CreateView(cv) = stmt {
            assert_eq!(cv.view_name, "My_View");
        } else {
            panic!("Expected CreateView");
        }
    }

    // ── Comparison operator coverage ─────────────────────────────

    #[test]
    fn test_is_not_null() {
        let stmt = parse_query("SELECT * FROM test WHERE title IS NOT NULL").unwrap();
        if let Statement::Select(q) = stmt {
            if let Some(WhereClause::Comparison(c)) = q.where_clause {
                assert_eq!(c.op, "IS NOT NULL");
                assert_eq!(c.value, None);
                assert!(c.right_expr.is_none());
            } else {
                panic!("Expected IS NOT NULL comparison");
            }
        } else {
            panic!("Expected Select");
        }
    }

    #[test]
    fn test_not_like_operator() {
        let stmt =
            parse_query("SELECT title FROM test WHERE categories NOT LIKE '%defi%'").unwrap();
        if let Statement::Select(q) = stmt {
            if let Some(WhereClause::Comparison(c)) = q.where_clause {
                assert_eq!(c.op, "NOT LIKE");
                assert_eq!(c.value, Some(SqlValue::String("%defi%".into())));
            } else {
                panic!("Expected NOT LIKE comparison");
            }
        } else {
            panic!("Expected Select");
        }
    }

    #[test]
    fn test_not_without_like_is_error() {
        assert!(parse_query("SELECT * FROM test WHERE title NOT 5").is_err());
    }

    #[test]
    fn test_in_operator_values() {
        let stmt =
            parse_query("SELECT * FROM test WHERE x IN (1, 2.5, 'a', NULL)").unwrap();
        if let Statement::Select(q) = stmt {
            if let Some(WhereClause::Comparison(c)) = q.where_clause {
                assert_eq!(c.op, "IN");
                assert_eq!(
                    c.value,
                    Some(SqlValue::List(vec![
                        SqlValue::Int(1),
                        SqlValue::Float(2.5),
                        SqlValue::String("a".into()),
                        SqlValue::Null,
                    ]))
                );
            } else {
                panic!("Expected IN comparison");
            }
        } else {
            panic!("Expected Select");
        }
    }

    #[test]
    fn test_in_empty_list_is_error() {
        assert!(parse_query("SELECT * FROM test WHERE x IN ()").is_err());
    }

    #[test]
    fn test_like_requires_literal() {
        // LIKE takes a literal value, not a column reference.
        assert!(parse_query("SELECT * FROM test WHERE title LIKE status").is_err());
    }

    #[test]
    fn test_comparison_literal_value() {
        let stmt = parse_query("SELECT * FROM test WHERE count = 5").unwrap();
        if let Statement::Select(q) = stmt {
            if let Some(WhereClause::Comparison(c)) = q.where_clause {
                assert_eq!(c.op, "=");
                assert_eq!(c.value, Some(SqlValue::Int(5)));
            } else {
                panic!("Expected comparison");
            }
        } else {
            panic!("Expected Select");
        }
    }

    #[test]
    fn test_comparison_column_rhs_has_no_value() {
        // `value` is only filled in when the right-hand side is a plain literal.
        let stmt = parse_query("SELECT * FROM test WHERE a >= b").unwrap();
        if let Statement::Select(q) = stmt {
            if let Some(WhereClause::Comparison(c)) = q.where_clause {
                assert_eq!(c.op, ">=");
                assert_eq!(c.value, None);
                assert!(matches!(&c.right_expr, Some(Expr::Column(n)) if n == "b"));
            } else {
                panic!("Expected comparison");
            }
        } else {
            panic!("Expected Select");
        }
    }

    #[test]
    fn test_order_by_defaults_to_ascending() {
        let stmt = parse_query("SELECT * FROM test ORDER BY title").unwrap();
        if let Statement::Select(q) = stmt {
            let ob = q.order_by.unwrap();
            assert_eq!(ob.len(), 1);
            assert!(!ob[0].descending);
            assert_eq!(ob[0].column, "title");
        } else {
            panic!("Expected Select");
        }
    }

    #[test]
    fn test_whitespace_only_query() {
        assert!(parse_query("   \n\t ").is_err());
    }
}