Skip to main content

mdql_core/
query_parser.rs

1//! Hand-written recursive descent parser for the MDQL SQL subset.
2
3use regex::Regex;
4use std::sync::LazyLock;
5
6use crate::errors::MdqlError;
7
8// ── AST nodes ──────────────────────────────────────────────────────────────
9
/// One entry of an `ORDER BY` list.
#[derive(Debug, Clone, PartialEq)]
pub struct OrderSpec {
    /// Name of the column to sort by.
    pub column: String,
    /// `true` when the entry was followed by `DESC`; `false` for `ASC` or no direction.
    pub descending: bool,
}
15
/// A single predicate of a `WHERE` clause: `<column> <op> [<value>]`.
#[derive(Debug, Clone, PartialEq)]
pub struct Comparison {
    /// Column (identifier) on the left-hand side.
    pub column: String,
    /// Operator text. The parser emits `=`, `!=`, `<`, `>`, `<=`, `>=`,
    /// `LIKE`, `NOT LIKE`, `IN`, `IS NULL` or `IS NOT NULL`.
    pub op: String,
    /// Right-hand side. `None` for `IS NULL` / `IS NOT NULL`; a
    /// `SqlValue::List` for `IN`.
    pub value: Option<SqlValue>,
}
22
/// Binary boolean combination of two `WHERE` sub-expressions.
#[derive(Debug, Clone, PartialEq)]
pub struct BoolOp {
    pub op: String, // "AND" or "OR"
    /// Left operand (boxed because `WhereClause` is recursive).
    pub left: Box<WhereClause>,
    /// Right operand.
    pub right: Box<WhereClause>,
}
29
/// Expression tree of a `WHERE` clause.
#[derive(Debug, Clone, PartialEq)]
pub enum WhereClause {
    /// Leaf node: a single column predicate.
    Comparison(Comparison),
    /// Inner node: two sub-expressions joined by `AND` / `OR`.
    BoolOp(BoolOp),
}
35
/// Literal value appearing in a query.
#[derive(Debug, Clone, PartialEq)]
pub enum SqlValue {
    /// Single-quoted string literal, quotes removed.
    String(String),
    /// Numeric literal without a decimal point.
    Int(i64),
    /// Numeric literal containing a decimal point.
    Float(f64),
    /// The `NULL` keyword.
    Null,
    /// Parenthesized value list, as produced for `IN (...)`.
    List(Vec<SqlValue>),
}
44
/// One `JOIN <table> [alias] ON <left_col> = <right_col>` clause.
#[derive(Debug, Clone, PartialEq)]
pub struct JoinClause {
    /// Name of the joined table.
    pub table: String,
    /// Optional alias written directly after the table name.
    pub alias: Option<String>,
    /// Identifier on the left of `=` in the `ON` condition.
    pub left_col: String,
    /// Identifier on the right of `=` in the `ON` condition.
    pub right_col: String,
}
52
/// Aggregate functions recognized in a `SELECT` list.
#[derive(Debug, Clone, PartialEq)]
pub enum AggFunc {
    /// `COUNT(...)`
    Count,
    /// `SUM(...)`
    Sum,
    /// `AVG(...)`
    Avg,
    /// `MIN(...)`
    Min,
    /// `MAX(...)`
    Max,
}
61
/// One item of a `SELECT` column list.
#[derive(Debug, Clone, PartialEq)]
pub enum SelectExpr {
    /// Plain column reference.
    Column(String),
    /// Aggregate call `func(arg)` with optional `AS alias`; `arg` is `"*"`
    /// for `COUNT(*)`-style calls.
    ///
    /// NOTE(review): the parser in this file also emits this variant (with
    /// `func: AggFunc::Count`) for a plain `column AS alias`, as a way to
    /// carry the alias — consumers cannot tell the two cases apart.
    Aggregate { func: AggFunc, arg: String, alias: Option<String> },
}
67
68impl SelectExpr {
69    pub fn output_name(&self) -> String {
70        match self {
71            SelectExpr::Column(name) => name.clone(),
72            SelectExpr::Aggregate { func, arg, alias } => {
73                if let Some(a) = alias {
74                    a.clone()
75                } else {
76                    let func_name = match func {
77                        AggFunc::Count => "COUNT",
78                        AggFunc::Sum => "SUM",
79                        AggFunc::Avg => "AVG",
80                        AggFunc::Min => "MIN",
81                        AggFunc::Max => "MAX",
82                    };
83                    format!("{}({})", func_name, arg)
84                }
85            }
86        }
87    }
88
89    pub fn is_aggregate(&self) -> bool {
90        matches!(self, SelectExpr::Aggregate { .. })
91    }
92}
93
/// Parsed `SELECT` statement.
#[derive(Debug, Clone, PartialEq)]
pub struct SelectQuery {
    /// Selected columns: `*` or an explicit expression list.
    pub columns: ColumnList,
    /// Table named after `FROM`.
    pub table: String,
    /// Optional alias written directly after the table name.
    pub table_alias: Option<String>,
    /// `JOIN` clauses in the order they were written (empty when none).
    pub joins: Vec<JoinClause>,
    /// `WHERE` expression, if present.
    pub where_clause: Option<WhereClause>,
    /// `GROUP BY` column names, if present.
    pub group_by: Option<Vec<String>>,
    /// `ORDER BY` entries, if present.
    pub order_by: Option<Vec<OrderSpec>>,
    /// `LIMIT` value, if present.
    pub limit: Option<i64>,
}
105
/// Column list of a `SELECT` statement.
#[derive(Debug, Clone, PartialEq)]
pub enum ColumnList {
    /// `SELECT *`
    All,
    /// Explicit, comma-separated list of select expressions.
    Named(Vec<SelectExpr>),
}
111
/// Parsed `INSERT INTO <table> (<columns>) VALUES (<values>)` statement.
#[derive(Debug, Clone, PartialEq)]
pub struct InsertQuery {
    /// Target table.
    pub table: String,
    /// Column names, in the order written.
    pub columns: Vec<String>,
    /// Values, positionally matching `columns` (the parser rejects a length mismatch).
    pub values: Vec<SqlValue>,
}
118
/// Parsed `UPDATE <table> SET ... [WHERE ...]` statement.
#[derive(Debug, Clone, PartialEq)]
pub struct UpdateQuery {
    /// Target table.
    pub table: String,
    /// `(column, new value)` pairs from the `SET` list, in the order written.
    pub assignments: Vec<(String, SqlValue)>,
    /// `WHERE` expression, if present.
    pub where_clause: Option<WhereClause>,
}
125
/// Parsed `DELETE FROM <table> [WHERE ...]` statement.
#[derive(Debug, Clone, PartialEq)]
pub struct DeleteQuery {
    /// Target table.
    pub table: String,
    /// `WHERE` expression, if present.
    pub where_clause: Option<WhereClause>,
}
131
/// Parsed `ALTER TABLE <table> RENAME FIELD <old> TO <new>` statement.
#[derive(Debug, Clone, PartialEq)]
pub struct AlterRenameFieldQuery {
    /// Target table.
    pub table: String,
    /// Current field name (string literal or identifier).
    pub old_name: String,
    /// Replacement field name (string literal or identifier).
    pub new_name: String,
}
138
/// Parsed `ALTER TABLE <table> DROP FIELD <name>` statement.
#[derive(Debug, Clone, PartialEq)]
pub struct AlterDropFieldQuery {
    /// Target table.
    pub table: String,
    /// Field to drop (string literal or identifier).
    pub field_name: String,
}
144
/// Parsed `ALTER TABLE <table> MERGE FIELDS <a>, <b>, ... INTO <target>` statement.
#[derive(Debug, Clone, PartialEq)]
pub struct AlterMergeFieldsQuery {
    /// Target table.
    pub table: String,
    /// Source field names, in the order written (at least one).
    pub sources: Vec<String>,
    /// Name of the field the sources are merged into.
    pub into: String,
}
151
/// Any statement the parser can produce; returned by `parse_query`.
#[derive(Debug, Clone, PartialEq)]
pub enum Statement {
    /// `SELECT ...`
    Select(SelectQuery),
    /// `INSERT INTO ...`
    Insert(InsertQuery),
    /// `UPDATE ...`
    Update(UpdateQuery),
    /// `DELETE FROM ...`
    Delete(DeleteQuery),
    /// `ALTER TABLE ... RENAME FIELD ...`
    AlterRename(AlterRenameFieldQuery),
    /// `ALTER TABLE ... DROP FIELD ...`
    AlterDrop(AlterDropFieldQuery),
    /// `ALTER TABLE ... MERGE FIELDS ...`
    AlterMerge(AlterMergeFieldsQuery),
}
162
163// ── Tokenizer ──────────────────────────────────────────────────────────────
164
/// Words the tokenizer classifies as keywords. Input words are uppercased
/// before lookup, so matching is case-insensitive.
static KEYWORDS: &[&str] = &[
    "SELECT", "FROM", "WHERE", "AND", "OR", "ORDER", "BY",
    "ASC", "DESC", "LIMIT", "LIKE", "IN", "IS", "NOT", "NULL",
    "JOIN", "ON", "AS", "GROUP",
    "INSERT", "INTO", "VALUES", "UPDATE", "SET", "DELETE",
    "ALTER", "TABLE", "RENAME", "FIELD", "TO", "DROP", "MERGE", "FIELDS",
];
172
/// Uppercase names of the aggregate functions accepted in a `SELECT` list.
/// These are not keywords; they tokenize as ordinary identifiers.
static AGG_FUNCS: &[&str] = &["COUNT", "SUM", "AVG", "MIN", "MAX"];
174
/// Lazily compiled regex matching one token, preceded by optional whitespace.
///
/// Named groups, tried in this order:
/// - `backtick`: a non-empty backtick-quoted identifier
/// - `string`: a single-quoted string, in which a backslash escapes the next character
/// - `number`: an optionally negative integer or decimal
/// - `op`: `<=`, `>=`, `!=`, or one of `= < > , * ( )`
/// - `word`: a letter or `_` followed by letters, digits, `_`, `.`, `/` or `-`
static TOKEN_RE: LazyLock<Regex> = LazyLock::new(|| {
    Regex::new(
        r#"(?x)
        \s*(?:
            (?P<backtick>`[^`]+`)
            | (?P<string>'(?:[^'\\]|\\.)*')
            | (?P<number>-?\d+(?:\.\d+)?)
            | (?P<op><=|>=|!=|[=<>,*()])
            | (?P<word>[A-Za-z_][A-Za-z0-9_./-]*)
        )"#,
    )
    // The pattern is a compile-time constant, so a failure here is a programming error.
    .unwrap()
});
188
/// A lexical token produced by `tokenize`.
#[derive(Debug, Clone)]
struct Token {
    /// Category tag, such as `"ident"`, `"string"`, `"number"`, `"op"` or `"keyword"`.
    token_type: String,
    /// Normalized text: surrounding backticks/quotes removed, keywords uppercased.
    value: String,
    /// Token text exactly as it appeared in the query (used in error messages).
    raw: String,
}
195
196fn tokenize(sql: &str) -> Vec<Token> {
197    let mut tokens = Vec::new();
198    for caps in TOKEN_RE.captures_iter(sql) {
199        if let Some(m) = caps.name("backtick") {
200            let raw = m.as_str();
201            tokens.push(Token {
202                token_type: "ident".into(),
203                value: raw[1..raw.len() - 1].into(),
204                raw: raw.into(),
205            });
206        } else if let Some(m) = caps.name("string") {
207            let raw = m.as_str();
208            tokens.push(Token {
209                token_type: "string".into(),
210                value: raw[1..raw.len() - 1].into(),
211                raw: raw.into(),
212            });
213        } else if let Some(m) = caps.name("number") {
214            let raw = m.as_str();
215            tokens.push(Token {
216                token_type: "number".into(),
217                value: raw.into(),
218                raw: raw.into(),
219            });
220        } else if let Some(m) = caps.name("op") {
221            let raw = m.as_str();
222            tokens.push(Token {
223                token_type: "op".into(),
224                value: raw.into(),
225                raw: raw.into(),
226            });
227        } else if let Some(m) = caps.name("word") {
228            let raw = m.as_str();
229            if KEYWORDS.contains(&raw.to_uppercase().as_str()) {
230                tokens.push(Token {
231                    token_type: "keyword".into(),
232                    value: raw.to_uppercase(),
233                    raw: raw.into(),
234                });
235            } else {
236                tokens.push(Token {
237                    token_type: "ident".into(),
238                    value: raw.into(),
239                    raw: raw.into(),
240                });
241            }
242        }
243    }
244    tokens
245}
246
247// ── Parser ─────────────────────────────────────────────────────────────────
248
/// Recursive-descent parser state: the token stream plus a cursor into it.
struct Parser {
    /// Tokens to parse, in source order.
    tokens: Vec<Token>,
    /// Index of the next token to be consumed.
    pos: usize,
}
253
254impl Parser {
255    fn new(tokens: Vec<Token>) -> Self {
256        Parser { tokens, pos: 0 }
257    }
258
259    fn peek(&self) -> Option<&Token> {
260        self.tokens.get(self.pos)
261    }
262
263    fn advance(&mut self) -> Token {
264        let t = self.tokens[self.pos].clone();
265        self.pos += 1;
266        t
267    }
268
269    fn expect(&mut self, type_: &str, value: Option<&str>) -> Result<Token, MdqlError> {
270        let t = self.peek().ok_or_else(|| {
271            MdqlError::QueryParse(format!(
272                "Unexpected end of query, expected {}",
273                value.unwrap_or(type_)
274            ))
275        })?;
276        let matches_type = t.token_type == type_;
277        let matches_value = value.map_or(true, |v| t.value == v);
278        if !matches_type || !matches_value {
279            return Err(MdqlError::QueryParse(format!(
280                "Expected {}, got '{}' at position {}",
281                value.unwrap_or(type_),
282                t.raw,
283                self.pos
284            )));
285        }
286        Ok(self.advance())
287    }
288
289    fn match_keyword(&mut self, kw: &str) -> bool {
290        if let Some(t) = self.peek() {
291            if t.token_type == "keyword" && t.value == kw {
292                self.advance();
293                return true;
294            }
295        }
296        false
297    }
298
299    fn parse_statement(&mut self) -> Result<Statement, MdqlError> {
300        let t = self.peek().ok_or_else(|| MdqlError::QueryParse("Empty query".into()))?;
301        match (t.token_type.as_str(), t.value.as_str()) {
302            ("keyword", "SELECT") => Ok(Statement::Select(self.parse_select()?)),
303            ("keyword", "INSERT") => Ok(Statement::Insert(self.parse_insert()?)),
304            ("keyword", "UPDATE") => Ok(Statement::Update(self.parse_update()?)),
305            ("keyword", "DELETE") => Ok(Statement::Delete(self.parse_delete()?)),
306            ("keyword", "ALTER") => self.parse_alter(),
307            _ => Err(MdqlError::QueryParse(format!(
308                "Expected SELECT, INSERT, UPDATE, DELETE, or ALTER, got '{}'",
309                t.raw
310            ))),
311        }
312    }
313
314    fn parse_select(&mut self) -> Result<SelectQuery, MdqlError> {
315        self.expect("keyword", Some("SELECT"))?;
316        let columns = self.parse_columns()?;
317        self.expect("keyword", Some("FROM"))?;
318        let table = self.parse_ident()?;
319
320        // Optional table alias
321        let mut table_alias = None;
322        if let Some(t) = self.peek() {
323            if t.token_type == "ident" && !self.is_clause_keyword(t) {
324                table_alias = Some(self.advance().value);
325            }
326        }
327
328        // Optional JOIN(s)
329        let mut joins = Vec::new();
330        while self.match_keyword("JOIN") {
331            let join_table = self.parse_ident()?;
332            let mut join_alias = None;
333            if let Some(t) = self.peek() {
334                if t.token_type == "ident" && !self.is_clause_keyword(t) {
335                    join_alias = Some(self.advance().value);
336                }
337            }
338            self.expect("keyword", Some("ON"))?;
339            let left_col = self.parse_ident()?;
340            self.expect("op", Some("="))?;
341            let right_col = self.parse_ident()?;
342            joins.push(JoinClause {
343                table: join_table,
344                alias: join_alias,
345                left_col,
346                right_col,
347            });
348        }
349
350        let mut where_clause = None;
351        if self.match_keyword("WHERE") {
352            where_clause = Some(self.parse_or_expr()?);
353        }
354
355        let mut group_by = None;
356        if self.match_keyword("GROUP") {
357            self.expect("keyword", Some("BY"))?;
358            let mut cols = vec![self.parse_ident()?];
359            while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
360                self.advance();
361                cols.push(self.parse_ident()?);
362            }
363            group_by = Some(cols);
364        }
365
366        let mut order_by = None;
367        if self.match_keyword("ORDER") {
368            self.expect("keyword", Some("BY"))?;
369            order_by = Some(self.parse_order_by()?);
370        }
371
372        let mut limit = None;
373        if self.match_keyword("LIMIT") {
374            let t = self.expect("number", None)?;
375            limit = Some(t.value.parse::<i64>().map_err(|_| {
376                MdqlError::QueryParse(format!("Invalid LIMIT value: {}", t.value))
377            })?);
378        }
379
380        self.expect_end()?;
381
382        Ok(SelectQuery {
383            columns,
384            table,
385            table_alias,
386            joins,
387            where_clause,
388            group_by,
389            order_by,
390            limit,
391        })
392    }
393
394    fn parse_insert(&mut self) -> Result<InsertQuery, MdqlError> {
395        self.expect("keyword", Some("INSERT"))?;
396        self.expect("keyword", Some("INTO"))?;
397        let table = self.parse_ident()?;
398
399        self.expect("op", Some("("))?;
400        let mut columns = vec![self.parse_ident()?];
401        while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
402            self.advance();
403            columns.push(self.parse_ident()?);
404        }
405        self.expect("op", Some(")"))?;
406
407        self.expect("keyword", Some("VALUES"))?;
408
409        self.expect("op", Some("("))?;
410        let mut values = vec![self.parse_value()?];
411        while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
412            self.advance();
413            values.push(self.parse_value()?);
414        }
415        self.expect("op", Some(")"))?;
416
417        if columns.len() != values.len() {
418            return Err(MdqlError::QueryParse(format!(
419                "Column count ({}) does not match value count ({})",
420                columns.len(),
421                values.len()
422            )));
423        }
424
425        self.expect_end()?;
426        Ok(InsertQuery {
427            table,
428            columns,
429            values,
430        })
431    }
432
433    fn parse_update(&mut self) -> Result<UpdateQuery, MdqlError> {
434        self.expect("keyword", Some("UPDATE"))?;
435        let table = self.parse_ident()?;
436        self.expect("keyword", Some("SET"))?;
437
438        let mut assignments = Vec::new();
439        let col = self.parse_ident()?;
440        self.expect("op", Some("="))?;
441        let val = self.parse_value()?;
442        assignments.push((col, val));
443
444        while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
445            self.advance();
446            let col = self.parse_ident()?;
447            self.expect("op", Some("="))?;
448            let val = self.parse_value()?;
449            assignments.push((col, val));
450        }
451
452        let mut where_clause = None;
453        if self.match_keyword("WHERE") {
454            where_clause = Some(self.parse_or_expr()?);
455        }
456
457        self.expect_end()?;
458        Ok(UpdateQuery {
459            table,
460            assignments,
461            where_clause,
462        })
463    }
464
465    fn parse_delete(&mut self) -> Result<DeleteQuery, MdqlError> {
466        self.expect("keyword", Some("DELETE"))?;
467        self.expect("keyword", Some("FROM"))?;
468        let table = self.parse_ident()?;
469
470        let mut where_clause = None;
471        if self.match_keyword("WHERE") {
472            where_clause = Some(self.parse_or_expr()?);
473        }
474
475        self.expect_end()?;
476        Ok(DeleteQuery {
477            table,
478            where_clause,
479        })
480    }
481
482    fn parse_alter(&mut self) -> Result<Statement, MdqlError> {
483        self.expect("keyword", Some("ALTER"))?;
484        self.expect("keyword", Some("TABLE"))?;
485        let table = self.parse_ident()?;
486
487        let t = self.peek().ok_or_else(|| {
488            MdqlError::QueryParse("Expected RENAME, DROP, or MERGE after table name".into())
489        })?;
490
491        match (t.token_type.as_str(), t.value.as_str()) {
492            ("keyword", "RENAME") => {
493                self.advance();
494                self.expect("keyword", Some("FIELD"))?;
495                let old_name = self.parse_string_or_ident()?;
496                self.expect("keyword", Some("TO"))?;
497                let new_name = self.parse_string_or_ident()?;
498                self.expect_end()?;
499                Ok(Statement::AlterRename(AlterRenameFieldQuery {
500                    table,
501                    old_name,
502                    new_name,
503                }))
504            }
505            ("keyword", "DROP") => {
506                self.advance();
507                self.expect("keyword", Some("FIELD"))?;
508                let field_name = self.parse_string_or_ident()?;
509                self.expect_end()?;
510                Ok(Statement::AlterDrop(AlterDropFieldQuery {
511                    table,
512                    field_name,
513                }))
514            }
515            ("keyword", "MERGE") => {
516                self.advance();
517                self.expect("keyword", Some("FIELDS"))?;
518                let mut sources = vec![self.parse_string_or_ident()?];
519                while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
520                    self.advance();
521                    sources.push(self.parse_string_or_ident()?);
522                }
523                self.expect("keyword", Some("INTO"))?;
524                let target = self.parse_string_or_ident()?;
525                self.expect_end()?;
526                Ok(Statement::AlterMerge(AlterMergeFieldsQuery {
527                    table,
528                    sources,
529                    into: target,
530                }))
531            }
532            _ => Err(MdqlError::QueryParse(format!(
533                "Expected RENAME, DROP, or MERGE, got '{}'",
534                t.raw
535            ))),
536        }
537    }
538
539    fn parse_string_or_ident(&mut self) -> Result<String, MdqlError> {
540        let t = self.peek().ok_or_else(|| {
541            MdqlError::QueryParse("Expected field name, got end of query".into())
542        })?;
543        match t.token_type.as_str() {
544            "string" => {
545                let v = self.advance().value;
546                Ok(v)
547            }
548            "ident" | "keyword" => {
549                let v = self.advance().value;
550                Ok(v)
551            }
552            _ => Err(MdqlError::QueryParse(format!(
553                "Expected field name, got '{}'",
554                t.raw
555            ))),
556        }
557    }
558
559    fn parse_columns(&mut self) -> Result<ColumnList, MdqlError> {
560        if let Some(t) = self.peek() {
561            if t.token_type == "op" && t.value == "*" {
562                self.advance();
563                return Ok(ColumnList::All);
564            }
565        }
566
567        let mut exprs = vec![self.parse_select_expr()?];
568        while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
569            self.advance();
570            exprs.push(self.parse_select_expr()?);
571        }
572        Ok(ColumnList::Named(exprs))
573    }
574
575    fn peek_is_agg_func(&self) -> bool {
576        let t = match self.peek() {
577            Some(t) => t,
578            None => return false,
579        };
580        let name_upper = t.value.to_uppercase();
581        if !AGG_FUNCS.contains(&name_upper.as_str()) {
582            return false;
583        }
584        // Only treat as aggregate if followed by (
585        self.tokens
586            .get(self.pos + 1)
587            .map_or(false, |next| next.token_type == "op" && next.value == "(")
588    }
589
590    fn parse_select_expr(&mut self) -> Result<SelectExpr, MdqlError> {
591        let _t = self.peek().ok_or_else(|| {
592            MdqlError::QueryParse("Expected column or aggregate, got end of query".into())
593        })?;
594
595        if self.peek_is_agg_func() {
596            let func_name = self.advance().value.to_uppercase();
597            let func = match func_name.as_str() {
598                "COUNT" => AggFunc::Count,
599                "SUM" => AggFunc::Sum,
600                "AVG" => AggFunc::Avg,
601                "MIN" => AggFunc::Min,
602                "MAX" => AggFunc::Max,
603                _ => unreachable!(),
604            };
605            self.expect("op", Some("("))?;
606            let arg = if self.peek().map_or(false, |t| t.token_type == "op" && t.value == "*") {
607                self.advance();
608                "*".to_string()
609            } else {
610                self.parse_ident()?
611            };
612            self.expect("op", Some(")"))?;
613
614            let alias = if self.match_keyword("AS") {
615                Some(self.parse_ident()?)
616            } else {
617                None
618            };
619
620            Ok(SelectExpr::Aggregate { func, arg, alias })
621        } else {
622            let name = self.parse_ident()?;
623            // Optional AS alias for plain columns
624            if self.match_keyword("AS") {
625                let alias = self.parse_ident()?;
626                Ok(SelectExpr::Aggregate {
627                    func: AggFunc::Count, // Won't be used — reusing for alias
628                    arg: name.clone(),
629                    alias: Some(alias),
630                })
631            } else {
632                Ok(SelectExpr::Column(name))
633            }
634        }
635    }
636
637    fn parse_ident(&mut self) -> Result<String, MdqlError> {
638        let t = self.peek().ok_or_else(|| {
639            MdqlError::QueryParse("Expected identifier, got end of query".into())
640        })?;
641        match t.token_type.as_str() {
642            "ident" | "keyword" => {
643                let v = self.advance().value;
644                Ok(v)
645            }
646            _ => Err(MdqlError::QueryParse(format!(
647                "Expected identifier, got '{}'",
648                t.raw
649            ))),
650        }
651    }
652
653    fn parse_or_expr(&mut self) -> Result<WhereClause, MdqlError> {
654        let mut left = self.parse_and_expr()?;
655        while self.match_keyword("OR") {
656            let right = self.parse_and_expr()?;
657            left = WhereClause::BoolOp(BoolOp {
658                op: "OR".into(),
659                left: Box::new(left),
660                right: Box::new(right),
661            });
662        }
663        Ok(left)
664    }
665
666    fn parse_and_expr(&mut self) -> Result<WhereClause, MdqlError> {
667        let mut left = self.parse_comparison()?;
668        while self.match_keyword("AND") {
669            let right = self.parse_comparison()?;
670            left = WhereClause::BoolOp(BoolOp {
671                op: "AND".into(),
672                left: Box::new(left),
673                right: Box::new(right),
674            });
675        }
676        Ok(left)
677    }
678
679    fn parse_comparison(&mut self) -> Result<WhereClause, MdqlError> {
680        // Handle parenthesized expressions
681        if self.peek().map_or(false, |t| t.token_type == "op" && t.value == "(") {
682            self.advance();
683            let expr = self.parse_or_expr()?;
684            self.expect("op", Some(")"))?;
685            return Ok(expr);
686        }
687
688        let col = self.parse_ident()?;
689
690        // IS NULL / IS NOT NULL
691        if self.match_keyword("IS") {
692            if self.match_keyword("NOT") {
693                self.expect("keyword", Some("NULL"))?;
694                return Ok(WhereClause::Comparison(Comparison {
695                    column: col,
696                    op: "IS NOT NULL".into(),
697                    value: None,
698                }));
699            }
700            self.expect("keyword", Some("NULL"))?;
701            return Ok(WhereClause::Comparison(Comparison {
702                column: col,
703                op: "IS NULL".into(),
704                value: None,
705            }));
706        }
707
708        // IN (val, val, ...)
709        if self.match_keyword("IN") {
710            self.expect("op", Some("("))?;
711            let mut values = vec![self.parse_value()?];
712            while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
713                self.advance();
714                values.push(self.parse_value()?);
715            }
716            self.expect("op", Some(")"))?;
717            return Ok(WhereClause::Comparison(Comparison {
718                column: col,
719                op: "IN".into(),
720                value: Some(SqlValue::List(values)),
721            }));
722        }
723
724        // LIKE
725        if self.match_keyword("LIKE") {
726            let val = self.parse_value()?;
727            return Ok(WhereClause::Comparison(Comparison {
728                column: col,
729                op: "LIKE".into(),
730                value: Some(val),
731            }));
732        }
733
734        // NOT LIKE
735        if self.match_keyword("NOT") {
736            if self.match_keyword("LIKE") {
737                let val = self.parse_value()?;
738                return Ok(WhereClause::Comparison(Comparison {
739                    column: col,
740                    op: "NOT LIKE".into(),
741                    value: Some(val),
742                }));
743            }
744            return Err(MdqlError::QueryParse("Expected LIKE after NOT".into()));
745        }
746
747        // Standard operators
748        if let Some(t) = self.peek() {
749            if t.token_type == "op" && ["=", "!=", "<", ">", "<=", ">="].contains(&t.value.as_str())
750            {
751                let op = self.advance().value;
752                let val = self.parse_value()?;
753                return Ok(WhereClause::Comparison(Comparison {
754                    column: col,
755                    op,
756                    value: Some(val),
757                }));
758            }
759        }
760
761        let got = self.peek().map_or("end".to_string(), |t| t.raw.clone());
762        Err(MdqlError::QueryParse(format!(
763            "Expected operator after '{}', got '{}'",
764            col, got
765        )))
766    }
767
768    fn parse_value(&mut self) -> Result<SqlValue, MdqlError> {
769        let t = self.peek().ok_or_else(|| {
770            MdqlError::QueryParse("Expected value, got end of query".into())
771        })?;
772        match t.token_type.as_str() {
773            "string" => {
774                let v = self.advance().value;
775                Ok(SqlValue::String(v))
776            }
777            "number" => {
778                let v = self.advance().value;
779                if v.contains('.') {
780                    Ok(SqlValue::Float(v.parse().map_err(|_| {
781                        MdqlError::QueryParse(format!("Invalid float: {}", v))
782                    })?))
783                } else {
784                    Ok(SqlValue::Int(v.parse().map_err(|_| {
785                        MdqlError::QueryParse(format!("Invalid int: {}", v))
786                    })?))
787                }
788            }
789            "keyword" if t.value == "NULL" => {
790                self.advance();
791                Ok(SqlValue::Null)
792            }
793            _ => Err(MdqlError::QueryParse(format!(
794                "Expected value, got '{}'",
795                t.raw
796            ))),
797        }
798    }
799
800    fn parse_order_by(&mut self) -> Result<Vec<OrderSpec>, MdqlError> {
801        let mut specs = vec![self.parse_order_spec()?];
802        while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
803            self.advance();
804            specs.push(self.parse_order_spec()?);
805        }
806        Ok(specs)
807    }
808
809    fn parse_order_spec(&mut self) -> Result<OrderSpec, MdqlError> {
810        let col = self.parse_ident()?;
811        let descending = if self.match_keyword("DESC") {
812            true
813        } else {
814            self.match_keyword("ASC");
815            false
816        };
817        Ok(OrderSpec {
818            column: col,
819            descending,
820        })
821    }
822
823    fn is_clause_keyword(&self, t: &Token) -> bool {
824        t.token_type == "keyword"
825            && ["WHERE", "ORDER", "LIMIT", "JOIN", "ON", "GROUP"].contains(&t.value.as_str())
826    }
827
828    fn expect_end(&self) -> Result<(), MdqlError> {
829        if let Some(t) = self.peek() {
830            return Err(MdqlError::QueryParse(format!(
831                "Unexpected token '{}' at position {}",
832                t.raw, self.pos
833            )));
834        }
835        Ok(())
836    }
837}
838
839pub fn parse_query(sql: &str) -> crate::errors::Result<Statement> {
840    let tokens = tokenize(sql);
841    if tokens.is_empty() {
842        return Err(MdqlError::QueryParse("Empty query".into()));
843    }
844    let mut parser = Parser::new(tokens);
845    parser.parse_statement()
846}
847
848#[cfg(test)]
849mod tests {
850    use super::*;
851
852    #[test]
853    fn test_simple_select() {
854        let stmt = parse_query("SELECT title, status FROM strategies").unwrap();
855        if let Statement::Select(q) = stmt {
856            assert_eq!(q.columns, ColumnList::Named(vec![SelectExpr::Column("title".into()), SelectExpr::Column("status".into())]));
857            assert_eq!(q.table, "strategies");
858        } else {
859            panic!("Expected Select");
860        }
861    }
862
863    #[test]
864    fn test_select_star() {
865        let stmt = parse_query("SELECT * FROM test").unwrap();
866        if let Statement::Select(q) = stmt {
867            assert_eq!(q.columns, ColumnList::All);
868        } else {
869            panic!("Expected Select");
870        }
871    }
872
873    #[test]
874    fn test_where_clause() {
875        let stmt = parse_query("SELECT title FROM test WHERE count > 5").unwrap();
876        if let Statement::Select(q) = stmt {
877            assert!(q.where_clause.is_some());
878        } else {
879            panic!("Expected Select");
880        }
881    }
882
883    #[test]
884    fn test_order_by() {
885        let stmt =
886            parse_query("SELECT title FROM test ORDER BY composite DESC, title ASC").unwrap();
887        if let Statement::Select(q) = stmt {
888            let ob = q.order_by.unwrap();
889            assert_eq!(ob.len(), 2);
890            assert!(ob[0].descending);
891            assert!(!ob[1].descending);
892        } else {
893            panic!("Expected Select");
894        }
895    }
896
897    #[test]
898    fn test_limit() {
899        let stmt = parse_query("SELECT * FROM test LIMIT 10").unwrap();
900        if let Statement::Select(q) = stmt {
901            assert_eq!(q.limit, Some(10));
902        } else {
903            panic!("Expected Select");
904        }
905    }
906
907    #[test]
908    fn test_insert() {
909        let stmt = parse_query(
910            "INSERT INTO test (title, count) VALUES ('Hello', 42)",
911        )
912        .unwrap();
913        if let Statement::Insert(q) = stmt {
914            assert_eq!(q.table, "test");
915            assert_eq!(q.columns, vec!["title", "count"]);
916            assert_eq!(q.values[0], SqlValue::String("Hello".into()));
917            assert_eq!(q.values[1], SqlValue::Int(42));
918        } else {
919            panic!("Expected Insert");
920        }
921    }
922
923    #[test]
924    fn test_update() {
925        let stmt = parse_query("UPDATE test SET status = 'KILLED' WHERE path = 'a.md'").unwrap();
926        if let Statement::Update(q) = stmt {
927            assert_eq!(q.table, "test");
928            assert_eq!(q.assignments.len(), 1);
929            assert!(q.where_clause.is_some());
930        } else {
931            panic!("Expected Update");
932        }
933    }
934
935    #[test]
936    fn test_delete() {
937        let stmt = parse_query("DELETE FROM test WHERE status = 'draft'").unwrap();
938        if let Statement::Delete(q) = stmt {
939            assert_eq!(q.table, "test");
940            assert!(q.where_clause.is_some());
941        } else {
942            panic!("Expected Delete");
943        }
944    }
945
946    #[test]
947    fn test_alter_rename() {
948        let stmt =
949            parse_query("ALTER TABLE test RENAME FIELD 'Summary' TO 'Overview'").unwrap();
950        if let Statement::AlterRename(q) = stmt {
951            assert_eq!(q.old_name, "Summary");
952            assert_eq!(q.new_name, "Overview");
953        } else {
954            panic!("Expected AlterRename");
955        }
956    }
957
958    #[test]
959    fn test_alter_drop() {
960        let stmt = parse_query("ALTER TABLE test DROP FIELD 'Details'").unwrap();
961        if let Statement::AlterDrop(q) = stmt {
962            assert_eq!(q.field_name, "Details");
963        } else {
964            panic!("Expected AlterDrop");
965        }
966    }
967
968    #[test]
969    fn test_alter_merge() {
970        let stmt = parse_query(
971            "ALTER TABLE test MERGE FIELDS 'Entry Rules', 'Exit Rules' INTO 'Trading Rules'",
972        )
973        .unwrap();
974        if let Statement::AlterMerge(q) = stmt {
975            assert_eq!(q.sources, vec!["Entry Rules", "Exit Rules"]);
976            assert_eq!(q.into, "Trading Rules");
977        } else {
978            panic!("Expected AlterMerge");
979        }
980    }
981
982    #[test]
983    fn test_backtick_ident() {
984        let stmt = parse_query("SELECT `Structural Mechanism` FROM test").unwrap();
985        if let Statement::Select(q) = stmt {
986            assert_eq!(
987                q.columns,
988                ColumnList::Named(vec![SelectExpr::Column("Structural Mechanism".into())])
989            );
990        } else {
991            panic!("Expected Select");
992        }
993    }
994
995    #[test]
996    fn test_like_operator() {
997        let stmt = parse_query("SELECT title FROM test WHERE categories LIKE '%defi%'").unwrap();
998        if let Statement::Select(q) = stmt {
999            if let Some(WhereClause::Comparison(c)) = q.where_clause {
1000                assert_eq!(c.op, "LIKE");
1001                assert_eq!(c.value, Some(SqlValue::String("%defi%".into())));
1002            } else {
1003                panic!("Expected LIKE comparison");
1004            }
1005        } else {
1006            panic!("Expected Select");
1007        }
1008    }
1009
1010    #[test]
1011    fn test_in_operator() {
1012        let stmt =
1013            parse_query("SELECT * FROM test WHERE status IN ('ACTIVE', 'LIVE')").unwrap();
1014        if let Statement::Select(q) = stmt {
1015            if let Some(WhereClause::Comparison(c)) = q.where_clause {
1016                assert_eq!(c.op, "IN");
1017            } else {
1018                panic!("Expected IN comparison");
1019            }
1020        } else {
1021            panic!("Expected Select");
1022        }
1023    }
1024
1025    #[test]
1026    fn test_is_null() {
1027        let stmt = parse_query("SELECT * FROM test WHERE title IS NULL").unwrap();
1028        if let Statement::Select(q) = stmt {
1029            if let Some(WhereClause::Comparison(c)) = q.where_clause {
1030                assert_eq!(c.op, "IS NULL");
1031            } else {
1032                panic!("Expected IS NULL comparison");
1033            }
1034        } else {
1035            panic!("Expected Select");
1036        }
1037    }
1038
1039    #[test]
1040    fn test_and_or() {
1041        let stmt = parse_query(
1042            "SELECT * FROM test WHERE status = 'ACTIVE' AND count > 5 OR title LIKE '%test%'",
1043        )
1044        .unwrap();
1045        if let Statement::Select(q) = stmt {
1046            assert!(q.where_clause.is_some());
1047        } else {
1048            panic!("Expected Select");
1049        }
1050    }
1051
1052    #[test]
1053    fn test_join() {
1054        let stmt = parse_query(
1055            "SELECT s.title, b.sharpe FROM strategies s JOIN backtests b ON b.strategy = s.path",
1056        )
1057        .unwrap();
1058        if let Statement::Select(q) = stmt {
1059            assert_eq!(q.table, "strategies");
1060            assert_eq!(q.table_alias, Some("s".into()));
1061            assert_eq!(q.joins.len(), 1);
1062            let join = &q.joins[0];
1063            assert_eq!(join.table, "backtests");
1064            assert_eq!(join.alias, Some("b".into()));
1065        } else {
1066            panic!("Expected Select");
1067        }
1068    }
1069
1070    #[test]
1071    fn test_multi_join() {
1072        let stmt = parse_query(
1073            "SELECT s.title, b.sharpe, c.verdict FROM strategies s JOIN backtests b ON b.strategy = s.path JOIN critiques c ON c.strategy = s.path",
1074        )
1075        .unwrap();
1076        if let Statement::Select(q) = stmt {
1077            assert_eq!(q.table, "strategies");
1078            assert_eq!(q.table_alias, Some("s".into()));
1079            assert_eq!(q.joins.len(), 2);
1080            assert_eq!(q.joins[0].table, "backtests");
1081            assert_eq!(q.joins[0].alias, Some("b".into()));
1082            assert_eq!(q.joins[0].left_col, "b.strategy");
1083            assert_eq!(q.joins[0].right_col, "s.path");
1084            assert_eq!(q.joins[1].table, "critiques");
1085            assert_eq!(q.joins[1].alias, Some("c".into()));
1086            assert_eq!(q.joins[1].left_col, "c.strategy");
1087            assert_eq!(q.joins[1].right_col, "s.path");
1088        } else {
1089            panic!("Expected Select");
1090        }
1091    }
1092
1093    #[test]
1094    fn test_empty_query() {
1095        assert!(parse_query("").is_err());
1096    }
1097
1098    #[test]
1099    fn test_count_star() {
1100        let stmt = parse_query("SELECT status, COUNT(*) AS cnt FROM strategies GROUP BY status").unwrap();
1101        if let Statement::Select(q) = stmt {
1102            if let ColumnList::Named(exprs) = &q.columns {
1103                assert_eq!(exprs.len(), 2);
1104                assert_eq!(exprs[0], SelectExpr::Column("status".into()));
1105                assert!(matches!(&exprs[1], SelectExpr::Aggregate {
1106                    func: AggFunc::Count,
1107                    arg,
1108                    alias: Some(a),
1109                } if arg == "*" && a == "cnt"));
1110            } else {
1111                panic!("Expected Named columns");
1112            }
1113            assert_eq!(q.group_by, Some(vec!["status".into()]));
1114        } else {
1115            panic!("Expected Select");
1116        }
1117    }
1118
1119    #[test]
1120    fn test_count_column_as_ident() {
1121        // "count" as a column name should NOT be parsed as the COUNT aggregate
1122        let stmt = parse_query("INSERT INTO test (title, count) VALUES ('Hello', 42)").unwrap();
1123        if let Statement::Insert(q) = stmt {
1124            assert_eq!(q.columns, vec!["title", "count"]);
1125        } else {
1126            panic!("Expected Insert");
1127        }
1128    }
1129
1130    #[test]
1131    fn test_multiple_aggregates() {
1132        let stmt = parse_query("SELECT MIN(composite), MAX(composite), AVG(composite) FROM strategies").unwrap();
1133        if let Statement::Select(q) = stmt {
1134            if let ColumnList::Named(exprs) = &q.columns {
1135                assert_eq!(exprs.len(), 3);
1136                assert!(matches!(&exprs[0], SelectExpr::Aggregate { func: AggFunc::Min, .. }));
1137                assert!(matches!(&exprs[1], SelectExpr::Aggregate { func: AggFunc::Max, .. }));
1138                assert!(matches!(&exprs[2], SelectExpr::Aggregate { func: AggFunc::Avg, .. }));
1139            } else {
1140                panic!("Expected Named columns");
1141            }
1142            assert_eq!(q.group_by, None);
1143        } else {
1144            panic!("Expected Select");
1145        }
1146    }
1147}