1use regex::Regex;
4use std::sync::LazyLock;
5
6use crate::errors::MdqlError;
7
/// Binary arithmetic operators that can appear in an [`Expr::BinaryOp`].
#[derive(Debug, Clone, PartialEq)]
pub enum ArithOp {
    /// `+`
    Add,
    /// `-`
    Sub,
    /// `*`
    Mul,
    /// `/`
    Div,
    /// `%`
    Mod,
}
18
/// A value-producing expression as built by the expression parser.
#[derive(Debug, Clone, PartialEq)]
pub enum Expr {
    /// A literal value (number, string or `NULL`).
    Literal(SqlValue),
    /// A reference to a column by name.
    Column(String),
    /// Two sub-expressions combined with an arithmetic operator.
    BinaryOp { left: Box<Expr>, op: ArithOp, right: Box<Expr> },
    /// Arithmetic negation of a non-literal expression (the parser folds a
    /// minus sign in front of a numeric literal into the literal itself).
    UnaryMinus(Box<Expr>),
    /// `CASE WHEN <cond> THEN <expr> ... [ELSE <expr>] END`.
    /// `whens` holds the condition/result pairs in source order.
    Case { whens: Vec<(WhereClause, Box<Expr>)>, else_expr: Option<Box<Expr>> },
}
27
28impl Expr {
29 pub fn as_column(&self) -> Option<&str> {
31 if let Expr::Column(name) = self { Some(name) } else { None }
32 }
33
34 pub fn display_name(&self) -> String {
36 match self {
37 Expr::Literal(SqlValue::Int(n)) => n.to_string(),
38 Expr::Literal(SqlValue::Float(f)) => f.to_string(),
39 Expr::Literal(SqlValue::String(s)) => format!("'{}'", s),
40 Expr::Literal(SqlValue::Null) => "NULL".to_string(),
41 Expr::Literal(SqlValue::List(_)) => "list".to_string(),
42 Expr::Column(name) => name.clone(),
43 Expr::BinaryOp { left, op, right } => {
44 let op_str = match op {
45 ArithOp::Add => "+",
46 ArithOp::Sub => "-",
47 ArithOp::Mul => "*",
48 ArithOp::Div => "/",
49 ArithOp::Mod => "%",
50 };
51 format!("{} {} {}", left.display_name(), op_str, right.display_name())
52 }
53 Expr::UnaryMinus(inner) => format!("-{}", inner.display_name()),
54 Expr::Case { .. } => "CASE".to_string(),
55 }
56 }
57}
58
/// One entry of an `ORDER BY` list.
#[derive(Debug, Clone, PartialEq)]
pub struct OrderSpec {
    /// Column name when the ordering expression is a bare column; the parser
    /// stores an empty string for any other expression.
    pub column: String,
    /// The ordering expression (the parser in this file always fills it).
    pub expr: Option<Expr>,
    /// `true` for `DESC`; `false` for `ASC` or when no direction is given.
    pub descending: bool,
}
65
/// A single predicate inside a `WHERE` (or `CASE WHEN`) condition.
#[derive(Debug, Clone, PartialEq)]
pub struct Comparison {
    /// Name of the left-hand column when the left side is a bare column;
    /// the parser stores an empty string otherwise.
    pub column: String,
    /// Operator text as produced by the parser: one of `=`, `!=`, `<`, `>`,
    /// `<=`, `>=`, `IN`, `LIKE`, `NOT LIKE`, `IS NULL`, `IS NOT NULL`.
    pub op: String,
    /// Right-hand value when it is a literal (a `SqlValue::List` for `IN`);
    /// `None` for `IS [NOT] NULL` and for non-literal right-hand expressions.
    pub value: Option<SqlValue>,
    /// Full left-hand expression.
    pub left_expr: Option<Expr>,
    /// Full right-hand expression; only set for the symbolic comparison operators.
    pub right_expr: Option<Expr>,
}
74
/// Two conditions joined by a boolean connective.
#[derive(Debug, Clone, PartialEq)]
pub struct BoolOp {
    /// Connective name; the parser emits `"AND"` or `"OR"`.
    pub op: String,
    /// Left operand.
    pub left: Box<WhereClause>,
    /// Right operand.
    pub right: Box<WhereClause>,
}
81
/// A boolean condition tree: leaves are comparisons, inner nodes are
/// `AND` / `OR` combinations.
#[derive(Debug, Clone, PartialEq)]
pub enum WhereClause {
    /// A single predicate.
    Comparison(Comparison),
    /// Two sub-conditions combined with `AND` or `OR`.
    BoolOp(BoolOp),
}
87
/// A literal value appearing in a query.
#[derive(Debug, Clone, PartialEq)]
pub enum SqlValue {
    /// Text of a single-quoted string literal, without the quotes.
    String(String),
    /// Integer literal (a number token without a decimal point).
    Int(i64),
    /// Floating-point literal (a number token containing a decimal point).
    Float(f64),
    /// The `NULL` keyword.
    Null,
    /// A list of values, as produced for the right-hand side of `IN (...)`.
    List(Vec<SqlValue>),
}
96
/// One `JOIN <table> [alias] ON <left_col> = <right_col>` clause.
#[derive(Debug, Clone, PartialEq)]
pub struct JoinClause {
    /// Name of the joined table.
    pub table: String,
    /// Optional alias written after the table name.
    pub alias: Option<String>,
    /// Identifier on the left of `=` in the `ON` condition.
    pub left_col: String,
    /// Identifier on the right of `=` in the `ON` condition.
    pub right_col: String,
}
104
/// Aggregate functions recognised in a select list.
#[derive(Debug, Clone, PartialEq)]
pub enum AggFunc {
    /// `COUNT(...)`
    Count,
    /// `SUM(...)`
    Sum,
    /// `AVG(...)`
    Avg,
    /// `MIN(...)`
    Min,
    /// `MAX(...)`
    Max,
}
113
/// One item of a `SELECT` list.
#[derive(Debug, Clone, PartialEq)]
pub enum SelectExpr {
    /// A bare column without an alias.
    Column(String),
    /// An aggregate call. `arg` is `"*"`, the column name, or the display
    /// name of the argument expression; `arg_expr` is only set when the
    /// argument is something other than `*` or a bare column.
    Aggregate { func: AggFunc, arg: String, arg_expr: Option<Expr>, alias: Option<String> },
    /// Any other expression, or a column that carries an alias.
    Expr { expr: Expr, alias: Option<String> },
}
120
121impl SelectExpr {
122 pub fn output_name(&self) -> String {
123 match self {
124 SelectExpr::Column(name) => name.clone(),
125 SelectExpr::Aggregate { func, arg, alias, .. } => {
126 if let Some(a) = alias {
127 a.clone()
128 } else {
129 let func_name = match func {
130 AggFunc::Count => "COUNT",
131 AggFunc::Sum => "SUM",
132 AggFunc::Avg => "AVG",
133 AggFunc::Min => "MIN",
134 AggFunc::Max => "MAX",
135 };
136 format!("{}({})", func_name, arg)
137 }
138 }
139 SelectExpr::Expr { expr, alias } => {
140 alias.clone().unwrap_or_else(|| expr.display_name())
141 }
142 }
143 }
144
145 pub fn is_aggregate(&self) -> bool {
146 matches!(self, SelectExpr::Aggregate { .. })
147 }
148}
149
/// Parsed form of a `SELECT` statement.
#[derive(Debug, Clone, PartialEq)]
pub struct SelectQuery {
    /// Selected items, or `ColumnList::All` for `*`.
    pub columns: ColumnList,
    /// Table named after `FROM`.
    pub table: String,
    /// Optional alias written directly after the table name.
    pub table_alias: Option<String>,
    /// `JOIN` clauses in source order (empty when there are none).
    pub joins: Vec<JoinClause>,
    /// Condition following `WHERE`, if present.
    pub where_clause: Option<WhereClause>,
    /// Identifiers listed after `GROUP BY`, if present.
    pub group_by: Option<Vec<String>>,
    /// Entries listed after `ORDER BY`, if present.
    pub order_by: Option<Vec<OrderSpec>>,
    /// Integer following `LIMIT`, if present.
    pub limit: Option<i64>,
}
161
/// The projection of a `SELECT` statement.
#[derive(Debug, Clone, PartialEq)]
pub enum ColumnList {
    /// `SELECT *`
    All,
    /// An explicit, comma-separated list of select items.
    Named(Vec<SelectExpr>),
}
167
/// Parsed form of `INSERT INTO <table> (<columns>) VALUES (<values>)`.
///
/// The parser rejects statements whose column and value counts differ, so
/// `columns` and `values` have equal length when produced by `parse_query`.
#[derive(Debug, Clone, PartialEq)]
pub struct InsertQuery {
    /// Target table.
    pub table: String,
    /// Column names in the order written.
    pub columns: Vec<String>,
    /// Values in the order written; `values[i]` belongs to `columns[i]`.
    pub values: Vec<SqlValue>,
}
174
/// Parsed form of `UPDATE <table> SET col = value, ... [WHERE ...]`.
#[derive(Debug, Clone, PartialEq)]
pub struct UpdateQuery {
    /// Target table.
    pub table: String,
    /// `(column, value)` pairs in the order written (at least one).
    pub assignments: Vec<(String, SqlValue)>,
    /// Condition following `WHERE`, if present.
    pub where_clause: Option<WhereClause>,
}
181
/// Parsed form of `DELETE FROM <table> [WHERE ...]`.
#[derive(Debug, Clone, PartialEq)]
pub struct DeleteQuery {
    /// Target table.
    pub table: String,
    /// Condition following `WHERE`, if present.
    pub where_clause: Option<WhereClause>,
}
187
/// Parsed form of `ALTER TABLE <table> RENAME FIELD <old> TO <new>`.
#[derive(Debug, Clone, PartialEq)]
pub struct AlterRenameFieldQuery {
    /// Target table.
    pub table: String,
    /// Field name before the rename.
    pub old_name: String,
    /// Field name after the rename.
    pub new_name: String,
}
194
/// Parsed form of `ALTER TABLE <table> DROP FIELD <name>`.
#[derive(Debug, Clone, PartialEq)]
pub struct AlterDropFieldQuery {
    /// Target table.
    pub table: String,
    /// Name of the field to drop.
    pub field_name: String,
}
200
/// Parsed form of `ALTER TABLE <table> MERGE FIELDS <a>, <b>, ... INTO <target>`.
#[derive(Debug, Clone, PartialEq)]
pub struct AlterMergeFieldsQuery {
    /// Target table.
    pub table: String,
    /// Source field names in the order written (at least one).
    pub sources: Vec<String>,
    /// Name given after `INTO`.
    pub into: String,
}
207
/// Any statement accepted by [`parse_query`].
#[derive(Debug, Clone, PartialEq)]
pub enum Statement {
    /// `SELECT ...`
    Select(SelectQuery),
    /// `INSERT INTO ...`
    Insert(InsertQuery),
    /// `UPDATE ...`
    Update(UpdateQuery),
    /// `DELETE FROM ...`
    Delete(DeleteQuery),
    /// `ALTER TABLE ... RENAME FIELD ...`
    AlterRename(AlterRenameFieldQuery),
    /// `ALTER TABLE ... DROP FIELD ...`
    AlterDrop(AlterDropFieldQuery),
    /// `ALTER TABLE ... MERGE FIELDS ...`
    AlterMerge(AlterMergeFieldsQuery),
}
218
/// Words the tokenizer classifies as keywords.
///
/// Entries are upper-case; `tokenize` upper-cases each word before looking it
/// up here, so matching is case-insensitive.
static KEYWORDS: &[&str] = &[
    "SELECT", "FROM", "WHERE", "AND", "OR", "ORDER", "BY",
    "ASC", "DESC", "LIMIT", "LIKE", "IN", "IS", "NOT", "NULL",
    "JOIN", "ON", "AS", "GROUP",
    "INSERT", "INTO", "VALUES", "UPDATE", "SET", "DELETE",
    "ALTER", "TABLE", "RENAME", "FIELD", "TO", "DROP", "MERGE", "FIELDS",
    "CASE", "WHEN", "THEN", "ELSE", "END",
];
229
/// Upper-case names of the aggregate functions. These are deliberately not in
/// `KEYWORDS`, so they tokenize as identifiers and are only treated as
/// functions when followed by `(`.
static AGG_FUNCS: &[&str] = &["COUNT", "SUM", "AVG", "MIN", "MAX"];
231
/// Lexer pattern, compiled on first use.
///
/// After optional leading whitespace it matches exactly one of the named
/// alternatives, tried in this order:
/// * `backtick` – a non-empty identifier enclosed in backticks,
/// * `string`   – a single-quoted string in which a backslash escapes the next character,
/// * `number`   – digits with an optional fractional part,
/// * `op`       – `<=`, `>=`, `!=` or one of `= < > , * ( ) + - / %`,
/// * `word`     – a letter or underscore followed by letters, digits, `_`, `.`, `/` or `-`.
///
/// The `unwrap` panics if the pattern fails to compile.
static TOKEN_RE: LazyLock<Regex> = LazyLock::new(|| {
    Regex::new(
        r#"(?x)
        \s*(?:
            (?P<backtick>`[^`]+`)
          | (?P<string>'(?:[^'\\]|\\.)*')
          | (?P<number>\d+(?:\.\d+)?)
          | (?P<op><=|>=|!=|[=<>,*()+\-/%])
          | (?P<word>[A-Za-z_][A-Za-z0-9_./-]*)
        )"#,
    )
    .unwrap()
});
245
/// A lexical token produced by `tokenize`.
#[derive(Debug, Clone)]
struct Token {
    /// Token category: `"ident"`, `"string"`, `"number"`, `"op"` or `"keyword"`.
    token_type: String,
    /// Normalised text: quotes/backticks stripped for strings and quoted
    /// identifiers, upper-cased for keywords, unchanged otherwise.
    value: String,
    /// The text exactly as matched in the query (without leading whitespace).
    raw: String,
}
252
253fn tokenize(sql: &str) -> Vec<Token> {
254 let mut tokens = Vec::new();
255 for caps in TOKEN_RE.captures_iter(sql) {
256 if let Some(m) = caps.name("backtick") {
257 let raw = m.as_str();
258 tokens.push(Token {
259 token_type: "ident".into(),
260 value: raw[1..raw.len() - 1].into(),
261 raw: raw.into(),
262 });
263 } else if let Some(m) = caps.name("string") {
264 let raw = m.as_str();
265 tokens.push(Token {
266 token_type: "string".into(),
267 value: raw[1..raw.len() - 1].into(),
268 raw: raw.into(),
269 });
270 } else if let Some(m) = caps.name("number") {
271 let raw = m.as_str();
272 tokens.push(Token {
273 token_type: "number".into(),
274 value: raw.into(),
275 raw: raw.into(),
276 });
277 } else if let Some(m) = caps.name("op") {
278 let raw = m.as_str();
279 tokens.push(Token {
280 token_type: "op".into(),
281 value: raw.into(),
282 raw: raw.into(),
283 });
284 } else if let Some(m) = caps.name("word") {
285 let raw = m.as_str();
286 if KEYWORDS.contains(&raw.to_uppercase().as_str()) {
287 tokens.push(Token {
288 token_type: "keyword".into(),
289 value: raw.to_uppercase(),
290 raw: raw.into(),
291 });
292 } else {
293 tokens.push(Token {
294 token_type: "ident".into(),
295 value: raw.into(),
296 raw: raw.into(),
297 });
298 }
299 }
300 }
301 tokens
302}
303
/// Recursive-descent parser over a token vector.
struct Parser {
    /// All tokens of the query, in source order.
    tokens: Vec<Token>,
    /// Index of the next token to be consumed.
    pos: usize,
}
310
311impl Parser {
312 fn new(tokens: Vec<Token>) -> Self {
313 Parser { tokens, pos: 0 }
314 }
315
316 fn peek(&self) -> Option<&Token> {
317 self.tokens.get(self.pos)
318 }
319
320 fn advance(&mut self) -> Token {
321 let t = self.tokens[self.pos].clone();
322 self.pos += 1;
323 t
324 }
325
326 fn expect(&mut self, type_: &str, value: Option<&str>) -> Result<Token, MdqlError> {
327 let t = self.peek().ok_or_else(|| {
328 MdqlError::QueryParse(format!(
329 "Unexpected end of query, expected {}",
330 value.unwrap_or(type_)
331 ))
332 })?;
333 let matches_type = t.token_type == type_;
334 let matches_value = value.map_or(true, |v| t.value == v);
335 if !matches_type || !matches_value {
336 return Err(MdqlError::QueryParse(format!(
337 "Expected {}, got '{}' at position {}",
338 value.unwrap_or(type_),
339 t.raw,
340 self.pos
341 )));
342 }
343 Ok(self.advance())
344 }
345
346 fn match_keyword(&mut self, kw: &str) -> bool {
347 if let Some(t) = self.peek() {
348 if t.token_type == "keyword" && t.value == kw {
349 self.advance();
350 return true;
351 }
352 }
353 false
354 }
355
356 fn parse_statement(&mut self) -> Result<Statement, MdqlError> {
357 let t = self.peek().ok_or_else(|| MdqlError::QueryParse("Empty query".into()))?;
358 match (t.token_type.as_str(), t.value.as_str()) {
359 ("keyword", "SELECT") => Ok(Statement::Select(self.parse_select()?)),
360 ("keyword", "INSERT") => Ok(Statement::Insert(self.parse_insert()?)),
361 ("keyword", "UPDATE") => Ok(Statement::Update(self.parse_update()?)),
362 ("keyword", "DELETE") => Ok(Statement::Delete(self.parse_delete()?)),
363 ("keyword", "ALTER") => self.parse_alter(),
364 _ => Err(MdqlError::QueryParse(format!(
365 "Expected SELECT, INSERT, UPDATE, DELETE, or ALTER, got '{}'",
366 t.raw
367 ))),
368 }
369 }
370
371 fn parse_select(&mut self) -> Result<SelectQuery, MdqlError> {
372 self.expect("keyword", Some("SELECT"))?;
373 let columns = self.parse_columns()?;
374 self.expect("keyword", Some("FROM"))?;
375 let table = self.parse_ident()?;
376
377 let mut table_alias = None;
379 if let Some(t) = self.peek() {
380 if t.token_type == "ident" && !self.is_clause_keyword(t) {
381 table_alias = Some(self.advance().value);
382 }
383 }
384
385 let mut joins = Vec::new();
387 while self.match_keyword("JOIN") {
388 let join_table = self.parse_ident()?;
389 let mut join_alias = None;
390 if let Some(t) = self.peek() {
391 if t.token_type == "ident" && !self.is_clause_keyword(t) {
392 join_alias = Some(self.advance().value);
393 }
394 }
395 self.expect("keyword", Some("ON"))?;
396 let left_col = self.parse_ident()?;
397 self.expect("op", Some("="))?;
398 let right_col = self.parse_ident()?;
399 joins.push(JoinClause {
400 table: join_table,
401 alias: join_alias,
402 left_col,
403 right_col,
404 });
405 }
406
407 let mut where_clause = None;
408 if self.match_keyword("WHERE") {
409 where_clause = Some(self.parse_or_expr()?);
410 }
411
412 let mut group_by = None;
413 if self.match_keyword("GROUP") {
414 self.expect("keyword", Some("BY"))?;
415 let mut cols = vec![self.parse_ident()?];
416 while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
417 self.advance();
418 cols.push(self.parse_ident()?);
419 }
420 group_by = Some(cols);
421 }
422
423 let mut order_by = None;
424 if self.match_keyword("ORDER") {
425 self.expect("keyword", Some("BY"))?;
426 order_by = Some(self.parse_order_by()?);
427 }
428
429 let mut limit = None;
430 if self.match_keyword("LIMIT") {
431 let t = self.expect("number", None)?;
432 limit = Some(t.value.parse::<i64>().map_err(|_| {
433 MdqlError::QueryParse(format!("Invalid LIMIT value: {}", t.value))
434 })?);
435 }
436
437 self.expect_end()?;
438
439 Ok(SelectQuery {
440 columns,
441 table,
442 table_alias,
443 joins,
444 where_clause,
445 group_by,
446 order_by,
447 limit,
448 })
449 }
450
451 fn parse_insert(&mut self) -> Result<InsertQuery, MdqlError> {
452 self.expect("keyword", Some("INSERT"))?;
453 self.expect("keyword", Some("INTO"))?;
454 let table = self.parse_ident()?;
455
456 self.expect("op", Some("("))?;
457 let mut columns = vec![self.parse_ident()?];
458 while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
459 self.advance();
460 columns.push(self.parse_ident()?);
461 }
462 self.expect("op", Some(")"))?;
463
464 self.expect("keyword", Some("VALUES"))?;
465
466 self.expect("op", Some("("))?;
467 let mut values = vec![self.parse_value()?];
468 while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
469 self.advance();
470 values.push(self.parse_value()?);
471 }
472 self.expect("op", Some(")"))?;
473
474 if columns.len() != values.len() {
475 return Err(MdqlError::QueryParse(format!(
476 "Column count ({}) does not match value count ({})",
477 columns.len(),
478 values.len()
479 )));
480 }
481
482 self.expect_end()?;
483 Ok(InsertQuery {
484 table,
485 columns,
486 values,
487 })
488 }
489
490 fn parse_update(&mut self) -> Result<UpdateQuery, MdqlError> {
491 self.expect("keyword", Some("UPDATE"))?;
492 let table = self.parse_ident()?;
493 self.expect("keyword", Some("SET"))?;
494
495 let mut assignments = Vec::new();
496 let col = self.parse_ident()?;
497 self.expect("op", Some("="))?;
498 let val = self.parse_value()?;
499 assignments.push((col, val));
500
501 while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
502 self.advance();
503 let col = self.parse_ident()?;
504 self.expect("op", Some("="))?;
505 let val = self.parse_value()?;
506 assignments.push((col, val));
507 }
508
509 let mut where_clause = None;
510 if self.match_keyword("WHERE") {
511 where_clause = Some(self.parse_or_expr()?);
512 }
513
514 self.expect_end()?;
515 Ok(UpdateQuery {
516 table,
517 assignments,
518 where_clause,
519 })
520 }
521
522 fn parse_delete(&mut self) -> Result<DeleteQuery, MdqlError> {
523 self.expect("keyword", Some("DELETE"))?;
524 self.expect("keyword", Some("FROM"))?;
525 let table = self.parse_ident()?;
526
527 let mut where_clause = None;
528 if self.match_keyword("WHERE") {
529 where_clause = Some(self.parse_or_expr()?);
530 }
531
532 self.expect_end()?;
533 Ok(DeleteQuery {
534 table,
535 where_clause,
536 })
537 }
538
539 fn parse_alter(&mut self) -> Result<Statement, MdqlError> {
540 self.expect("keyword", Some("ALTER"))?;
541 self.expect("keyword", Some("TABLE"))?;
542 let table = self.parse_ident()?;
543
544 let t = self.peek().ok_or_else(|| {
545 MdqlError::QueryParse("Expected RENAME, DROP, or MERGE after table name".into())
546 })?;
547
548 match (t.token_type.as_str(), t.value.as_str()) {
549 ("keyword", "RENAME") => {
550 self.advance();
551 self.expect("keyword", Some("FIELD"))?;
552 let old_name = self.parse_string_or_ident()?;
553 self.expect("keyword", Some("TO"))?;
554 let new_name = self.parse_string_or_ident()?;
555 self.expect_end()?;
556 Ok(Statement::AlterRename(AlterRenameFieldQuery {
557 table,
558 old_name,
559 new_name,
560 }))
561 }
562 ("keyword", "DROP") => {
563 self.advance();
564 self.expect("keyword", Some("FIELD"))?;
565 let field_name = self.parse_string_or_ident()?;
566 self.expect_end()?;
567 Ok(Statement::AlterDrop(AlterDropFieldQuery {
568 table,
569 field_name,
570 }))
571 }
572 ("keyword", "MERGE") => {
573 self.advance();
574 self.expect("keyword", Some("FIELDS"))?;
575 let mut sources = vec![self.parse_string_or_ident()?];
576 while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
577 self.advance();
578 sources.push(self.parse_string_or_ident()?);
579 }
580 self.expect("keyword", Some("INTO"))?;
581 let target = self.parse_string_or_ident()?;
582 self.expect_end()?;
583 Ok(Statement::AlterMerge(AlterMergeFieldsQuery {
584 table,
585 sources,
586 into: target,
587 }))
588 }
589 _ => Err(MdqlError::QueryParse(format!(
590 "Expected RENAME, DROP, or MERGE, got '{}'",
591 t.raw
592 ))),
593 }
594 }
595
596 fn parse_string_or_ident(&mut self) -> Result<String, MdqlError> {
597 let t = self.peek().ok_or_else(|| {
598 MdqlError::QueryParse("Expected field name, got end of query".into())
599 })?;
600 match t.token_type.as_str() {
601 "string" => {
602 let v = self.advance().value;
603 Ok(v)
604 }
605 "ident" | "keyword" => {
606 let v = self.advance().value;
607 Ok(v)
608 }
609 _ => Err(MdqlError::QueryParse(format!(
610 "Expected field name, got '{}'",
611 t.raw
612 ))),
613 }
614 }
615
616 fn parse_columns(&mut self) -> Result<ColumnList, MdqlError> {
617 if let Some(t) = self.peek() {
618 if t.token_type == "op" && t.value == "*" {
619 self.advance();
620 return Ok(ColumnList::All);
621 }
622 }
623
624 let mut exprs = vec![self.parse_select_expr()?];
625 while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
626 self.advance();
627 exprs.push(self.parse_select_expr()?);
628 }
629 Ok(ColumnList::Named(exprs))
630 }
631
632 fn peek_is_agg_func(&self) -> bool {
633 let t = match self.peek() {
634 Some(t) => t,
635 None => return false,
636 };
637 let name_upper = t.value.to_uppercase();
638 if !AGG_FUNCS.contains(&name_upper.as_str()) {
639 return false;
640 }
641 self.tokens
643 .get(self.pos + 1)
644 .map_or(false, |next| next.token_type == "op" && next.value == "(")
645 }
646
647 fn parse_select_expr(&mut self) -> Result<SelectExpr, MdqlError> {
648 let _t = self.peek().ok_or_else(|| {
649 MdqlError::QueryParse("Expected column or aggregate, got end of query".into())
650 })?;
651
652 if self.peek_is_agg_func() {
653 let func_name = self.advance().value.to_uppercase();
654 let func = match func_name.as_str() {
655 "COUNT" => AggFunc::Count,
656 "SUM" => AggFunc::Sum,
657 "AVG" => AggFunc::Avg,
658 "MIN" => AggFunc::Min,
659 "MAX" => AggFunc::Max,
660 _ => unreachable!(),
661 };
662 self.expect("op", Some("("))?;
663 let (arg, arg_expr) = if self.peek().map_or(false, |t| t.token_type == "op" && t.value == "*") {
664 self.advance();
665 ("*".to_string(), None)
666 } else {
667 let expr = self.parse_additive()?;
669 if let Expr::Column(name) = &expr {
670 (name.clone(), None)
671 } else {
672 (expr.display_name(), Some(expr))
673 }
674 };
675 self.expect("op", Some(")"))?;
676
677 let alias = if self.match_keyword("AS") {
678 Some(self.parse_ident()?)
679 } else if self.peek().map_or(false, |t| {
680 t.token_type == "ident" && !self.is_clause_keyword(t)
681 }) {
682 Some(self.advance().value)
683 } else {
684 None
685 };
686
687 Ok(SelectExpr::Aggregate { func, arg, arg_expr, alias })
688 } else {
689 let expr = self.parse_additive()?;
691
692 let alias = if self.match_keyword("AS") {
694 Some(self.parse_ident()?)
695 } else if self.peek().map_or(false, |t| {
696 t.token_type == "ident" && !self.is_clause_keyword(t)
697 }) {
698 Some(self.advance().value)
699 } else {
700 None
701 };
702
703 if alias.is_none() {
706 if let Expr::Column(name) = &expr {
707 return Ok(SelectExpr::Column(name.clone()));
708 }
709 }
710
711 Ok(SelectExpr::Expr { expr, alias })
712 }
713 }
714
715 fn peek_is_additive_op(&self) -> bool {
718 self.peek().map_or(false, |t| {
719 t.token_type == "op" && (t.value == "+" || t.value == "-")
720 })
721 }
722
723 fn peek_is_multiplicative_op(&self) -> bool {
724 self.peek().map_or(false, |t| {
725 t.token_type == "op" && (t.value == "*" || t.value == "/" || t.value == "%")
726 })
727 }
728
729 fn parse_additive(&mut self) -> Result<Expr, MdqlError> {
730 let mut left = self.parse_multiplicative()?;
731 while self.peek_is_additive_op() {
732 let op_tok = self.advance();
733 let op = match op_tok.value.as_str() {
734 "+" => ArithOp::Add,
735 "-" => ArithOp::Sub,
736 _ => unreachable!(),
737 };
738 let right = self.parse_multiplicative()?;
739 left = Expr::BinaryOp {
740 left: Box::new(left),
741 op,
742 right: Box::new(right),
743 };
744 }
745 Ok(left)
746 }
747
748 fn parse_multiplicative(&mut self) -> Result<Expr, MdqlError> {
749 let mut left = self.parse_unary()?;
750 while self.peek_is_multiplicative_op() {
751 let op_tok = self.advance();
752 let op = match op_tok.value.as_str() {
753 "*" => ArithOp::Mul,
754 "/" => ArithOp::Div,
755 "%" => ArithOp::Mod,
756 _ => unreachable!(),
757 };
758 let right = self.parse_unary()?;
759 left = Expr::BinaryOp {
760 left: Box::new(left),
761 op,
762 right: Box::new(right),
763 };
764 }
765 Ok(left)
766 }
767
768 fn parse_unary(&mut self) -> Result<Expr, MdqlError> {
769 if self.peek().map_or(false, |t| t.token_type == "op" && t.value == "-") {
770 self.advance();
771 let inner = self.parse_atom()?;
772 match inner {
774 Expr::Literal(SqlValue::Int(n)) => Ok(Expr::Literal(SqlValue::Int(-n))),
775 Expr::Literal(SqlValue::Float(f)) => Ok(Expr::Literal(SqlValue::Float(-f))),
776 _ => Ok(Expr::UnaryMinus(Box::new(inner))),
777 }
778 } else {
779 self.parse_atom()
780 }
781 }
782
783 fn parse_atom(&mut self) -> Result<Expr, MdqlError> {
784 let t = self.peek().ok_or_else(|| {
785 MdqlError::QueryParse("Expected expression, got end of query".into())
786 })?;
787
788 match t.token_type.as_str() {
789 "number" => {
790 let v = self.advance().value;
791 if v.contains('.') {
792 let f: f64 = v.parse().map_err(|_| {
793 MdqlError::QueryParse(format!("Invalid float: {}", v))
794 })?;
795 Ok(Expr::Literal(SqlValue::Float(f)))
796 } else {
797 let n: i64 = v.parse().map_err(|_| {
798 MdqlError::QueryParse(format!("Invalid int: {}", v))
799 })?;
800 Ok(Expr::Literal(SqlValue::Int(n)))
801 }
802 }
803 "string" => {
804 let v = self.advance().value;
805 Ok(Expr::Literal(SqlValue::String(v)))
806 }
807 "keyword" if t.value == "NULL" => {
808 self.advance();
809 Ok(Expr::Literal(SqlValue::Null))
810 }
811 "keyword" if t.value == "CASE" => {
812 self.parse_case_expr()
813 }
814 "op" if t.value == "(" => {
815 self.advance();
816 let expr = self.parse_additive()?;
817 self.expect("op", Some(")"))?;
818 Ok(expr)
819 }
820 "ident" => {
821 let name = self.advance().value;
822 Ok(Expr::Column(name))
823 }
824 "keyword" if !Self::is_reserved_keyword(&t.value) => {
825 let name = self.advance().value;
826 Ok(Expr::Column(name))
827 }
828 _ => Err(MdqlError::QueryParse(format!(
829 "Expected expression, got '{}'",
830 t.raw
831 ))),
832 }
833 }
834
835 fn parse_case_expr(&mut self) -> Result<Expr, MdqlError> {
836 self.expect("keyword", Some("CASE"))?;
837 let mut whens = Vec::new();
838 while self.match_keyword("WHEN") {
839 let condition = self.parse_or_expr()?;
840 self.expect("keyword", Some("THEN"))?;
841 let result = self.parse_additive()?;
842 whens.push((condition, Box::new(result)));
843 }
844 if whens.is_empty() {
845 return Err(MdqlError::QueryParse("CASE requires at least one WHEN clause".into()));
846 }
847 let else_expr = if self.match_keyword("ELSE") {
848 Some(Box::new(self.parse_additive()?))
849 } else {
850 None
851 };
852 self.expect("keyword", Some("END"))?;
853 Ok(Expr::Case { whens, else_expr })
854 }
855
856 fn parse_ident(&mut self) -> Result<String, MdqlError> {
857 let t = self.peek().ok_or_else(|| {
858 MdqlError::QueryParse("Expected identifier, got end of query".into())
859 })?;
860 match t.token_type.as_str() {
861 "ident" | "keyword" => {
862 let v = self.advance().value;
863 Ok(v)
864 }
865 _ => Err(MdqlError::QueryParse(format!(
866 "Expected identifier, got '{}'",
867 t.raw
868 ))),
869 }
870 }
871
872 fn parse_or_expr(&mut self) -> Result<WhereClause, MdqlError> {
873 let mut left = self.parse_and_expr()?;
874 while self.match_keyword("OR") {
875 let right = self.parse_and_expr()?;
876 left = WhereClause::BoolOp(BoolOp {
877 op: "OR".into(),
878 left: Box::new(left),
879 right: Box::new(right),
880 });
881 }
882 Ok(left)
883 }
884
885 fn parse_and_expr(&mut self) -> Result<WhereClause, MdqlError> {
886 let mut left = self.parse_comparison()?;
887 while self.match_keyword("AND") {
888 let right = self.parse_comparison()?;
889 left = WhereClause::BoolOp(BoolOp {
890 op: "AND".into(),
891 left: Box::new(left),
892 right: Box::new(right),
893 });
894 }
895 Ok(left)
896 }
897
898 fn parse_comparison(&mut self) -> Result<WhereClause, MdqlError> {
899 if self.peek().map_or(false, |t| t.token_type == "op" && t.value == "(") {
901 let saved_pos = self.pos;
903 self.advance();
904 let result = self.parse_or_expr();
906 if result.is_ok() && self.peek().map_or(false, |t| t.token_type == "op" && t.value == ")") {
907 self.advance();
908 return result;
909 }
910 self.pos = saved_pos;
912 }
913
914 let left_expr = self.parse_additive()?;
916
917 let col = left_expr.as_column().unwrap_or("").to_string();
919
920 if self.match_keyword("IS") {
922 if self.match_keyword("NOT") {
923 self.expect("keyword", Some("NULL"))?;
924 return Ok(WhereClause::Comparison(Comparison {
925 column: col,
926 op: "IS NOT NULL".into(),
927 value: None,
928 left_expr: Some(left_expr),
929 right_expr: None,
930 }));
931 }
932 self.expect("keyword", Some("NULL"))?;
933 return Ok(WhereClause::Comparison(Comparison {
934 column: col,
935 op: "IS NULL".into(),
936 value: None,
937 left_expr: Some(left_expr),
938 right_expr: None,
939 }));
940 }
941
942 if self.match_keyword("IN") {
944 self.expect("op", Some("("))?;
945 let mut values = vec![self.parse_value()?];
946 while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
947 self.advance();
948 values.push(self.parse_value()?);
949 }
950 self.expect("op", Some(")"))?;
951 return Ok(WhereClause::Comparison(Comparison {
952 column: col,
953 op: "IN".into(),
954 value: Some(SqlValue::List(values)),
955 left_expr: Some(left_expr),
956 right_expr: None,
957 }));
958 }
959
960 if self.match_keyword("LIKE") {
962 let val = self.parse_value()?;
963 return Ok(WhereClause::Comparison(Comparison {
964 column: col,
965 op: "LIKE".into(),
966 value: Some(val),
967 left_expr: Some(left_expr),
968 right_expr: None,
969 }));
970 }
971
972 if self.match_keyword("NOT") {
974 if self.match_keyword("LIKE") {
975 let val = self.parse_value()?;
976 return Ok(WhereClause::Comparison(Comparison {
977 column: col,
978 op: "NOT LIKE".into(),
979 value: Some(val),
980 left_expr: Some(left_expr),
981 right_expr: None,
982 }));
983 }
984 return Err(MdqlError::QueryParse("Expected LIKE after NOT".into()));
985 }
986
987 if let Some(t) = self.peek() {
989 if t.token_type == "op" && ["=", "!=", "<", ">", "<=", ">="].contains(&t.value.as_str())
990 {
991 let op = self.advance().value;
992 let right_expr = self.parse_additive()?;
994 let value = match &right_expr {
996 Expr::Literal(v) => Some(v.clone()),
997 _ => None,
998 };
999 return Ok(WhereClause::Comparison(Comparison {
1000 column: col,
1001 op,
1002 value,
1003 left_expr: Some(left_expr),
1004 right_expr: Some(right_expr),
1005 }));
1006 }
1007 }
1008
1009 let got = self.peek().map_or("end".to_string(), |t| t.raw.clone());
1010 Err(MdqlError::QueryParse(format!(
1011 "Expected operator after '{}', got '{}'",
1012 left_expr.display_name(), got
1013 )))
1014 }
1015
1016 fn parse_value(&mut self) -> Result<SqlValue, MdqlError> {
1017 let t = self.peek().ok_or_else(|| {
1018 MdqlError::QueryParse("Expected value, got end of query".into())
1019 })?;
1020 match t.token_type.as_str() {
1021 "string" => {
1022 let v = self.advance().value;
1023 Ok(SqlValue::String(v))
1024 }
1025 "number" => {
1026 let v = self.advance().value;
1027 if v.contains('.') {
1028 Ok(SqlValue::Float(v.parse().map_err(|_| {
1029 MdqlError::QueryParse(format!("Invalid float: {}", v))
1030 })?))
1031 } else {
1032 Ok(SqlValue::Int(v.parse().map_err(|_| {
1033 MdqlError::QueryParse(format!("Invalid int: {}", v))
1034 })?))
1035 }
1036 }
1037 "keyword" if t.value == "NULL" => {
1038 self.advance();
1039 Ok(SqlValue::Null)
1040 }
1041 _ => Err(MdqlError::QueryParse(format!(
1042 "Expected value, got '{}'",
1043 t.raw
1044 ))),
1045 }
1046 }
1047
1048 fn parse_order_by(&mut self) -> Result<Vec<OrderSpec>, MdqlError> {
1049 let mut specs = vec![self.parse_order_spec()?];
1050 while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
1051 self.advance();
1052 specs.push(self.parse_order_spec()?);
1053 }
1054 Ok(specs)
1055 }
1056
1057 fn parse_order_spec(&mut self) -> Result<OrderSpec, MdqlError> {
1058 let expr = self.parse_additive()?;
1059 let col = expr.as_column().unwrap_or("").to_string();
1060 let descending = if self.match_keyword("DESC") {
1061 true
1062 } else {
1063 self.match_keyword("ASC");
1064 false
1065 };
1066 Ok(OrderSpec {
1067 column: col,
1068 expr: Some(expr),
1069 descending,
1070 })
1071 }
1072
1073 fn is_clause_keyword(&self, t: &Token) -> bool {
1074 t.token_type == "keyword"
1075 && ["WHERE", "ORDER", "LIMIT", "JOIN", "ON", "GROUP"].contains(&t.value.as_str())
1076 }
1077
1078 fn is_reserved_keyword(kw: &str) -> bool {
1080 matches!(kw,
1081 "AS" | "FROM" | "WHERE" | "AND" | "OR" | "ORDER" | "BY"
1082 | "ASC" | "DESC" | "LIMIT" | "JOIN" | "ON" | "GROUP"
1083 | "SELECT" | "INSERT" | "INTO" | "VALUES" | "UPDATE" | "SET"
1084 | "DELETE" | "ALTER" | "TABLE" | "IS" | "NOT" | "IN" | "LIKE"
1085 | "RENAME" | "FIELD" | "TO" | "DROP" | "MERGE" | "FIELDS"
1086 | "CASE" | "WHEN" | "THEN" | "ELSE" | "END"
1087 )
1088 }
1089
1090 fn expect_end(&self) -> Result<(), MdqlError> {
1091 if let Some(t) = self.peek() {
1092 return Err(MdqlError::QueryParse(format!(
1093 "Unexpected token '{}' at position {}",
1094 t.raw, self.pos
1095 )));
1096 }
1097 Ok(())
1098 }
1099}
1100
1101pub fn parse_query(sql: &str) -> crate::errors::Result<Statement> {
1102 let tokens = tokenize(sql);
1103 if tokens.is_empty() {
1104 return Err(MdqlError::QueryParse("Empty query".into()));
1105 }
1106 let mut parser = Parser::new(tokens);
1107 parser.parse_statement()
1108}
1109
1110#[cfg(test)]
1111mod tests {
1112 use super::*;
1113
1114 #[test]
1115 fn test_simple_select() {
1116 let stmt = parse_query("SELECT title, status FROM strategies").unwrap();
1117 if let Statement::Select(q) = stmt {
1118 assert_eq!(q.columns, ColumnList::Named(vec![SelectExpr::Column("title".into()), SelectExpr::Column("status".into())]));
1119 assert_eq!(q.table, "strategies");
1120 } else {
1121 panic!("Expected Select");
1122 }
1123 }
1124
1125 #[test]
1126 fn test_select_star() {
1127 let stmt = parse_query("SELECT * FROM test").unwrap();
1128 if let Statement::Select(q) = stmt {
1129 assert_eq!(q.columns, ColumnList::All);
1130 } else {
1131 panic!("Expected Select");
1132 }
1133 }
1134
1135 #[test]
1136 fn test_where_clause() {
1137 let stmt = parse_query("SELECT title FROM test WHERE count > 5").unwrap();
1138 if let Statement::Select(q) = stmt {
1139 assert!(q.where_clause.is_some());
1140 } else {
1141 panic!("Expected Select");
1142 }
1143 }
1144
1145 #[test]
1146 fn test_order_by() {
1147 let stmt =
1148 parse_query("SELECT title FROM test ORDER BY composite DESC, title ASC").unwrap();
1149 if let Statement::Select(q) = stmt {
1150 let ob = q.order_by.unwrap();
1151 assert_eq!(ob.len(), 2);
1152 assert!(ob[0].descending);
1153 assert!(!ob[1].descending);
1154 } else {
1155 panic!("Expected Select");
1156 }
1157 }
1158
1159 #[test]
1160 fn test_limit() {
1161 let stmt = parse_query("SELECT * FROM test LIMIT 10").unwrap();
1162 if let Statement::Select(q) = stmt {
1163 assert_eq!(q.limit, Some(10));
1164 } else {
1165 panic!("Expected Select");
1166 }
1167 }
1168
1169 #[test]
1170 fn test_insert() {
1171 let stmt = parse_query(
1172 "INSERT INTO test (title, count) VALUES ('Hello', 42)",
1173 )
1174 .unwrap();
1175 if let Statement::Insert(q) = stmt {
1176 assert_eq!(q.table, "test");
1177 assert_eq!(q.columns, vec!["title", "count"]);
1178 assert_eq!(q.values[0], SqlValue::String("Hello".into()));
1179 assert_eq!(q.values[1], SqlValue::Int(42));
1180 } else {
1181 panic!("Expected Insert");
1182 }
1183 }
1184
1185 #[test]
1186 fn test_update() {
1187 let stmt = parse_query("UPDATE test SET status = 'KILLED' WHERE path = 'a.md'").unwrap();
1188 if let Statement::Update(q) = stmt {
1189 assert_eq!(q.table, "test");
1190 assert_eq!(q.assignments.len(), 1);
1191 assert!(q.where_clause.is_some());
1192 } else {
1193 panic!("Expected Update");
1194 }
1195 }
1196
1197 #[test]
1198 fn test_delete() {
1199 let stmt = parse_query("DELETE FROM test WHERE status = 'draft'").unwrap();
1200 if let Statement::Delete(q) = stmt {
1201 assert_eq!(q.table, "test");
1202 assert!(q.where_clause.is_some());
1203 } else {
1204 panic!("Expected Delete");
1205 }
1206 }
1207
1208 #[test]
1209 fn test_alter_rename() {
1210 let stmt =
1211 parse_query("ALTER TABLE test RENAME FIELD 'Summary' TO 'Overview'").unwrap();
1212 if let Statement::AlterRename(q) = stmt {
1213 assert_eq!(q.old_name, "Summary");
1214 assert_eq!(q.new_name, "Overview");
1215 } else {
1216 panic!("Expected AlterRename");
1217 }
1218 }
1219
1220 #[test]
1221 fn test_alter_drop() {
1222 let stmt = parse_query("ALTER TABLE test DROP FIELD 'Details'").unwrap();
1223 if let Statement::AlterDrop(q) = stmt {
1224 assert_eq!(q.field_name, "Details");
1225 } else {
1226 panic!("Expected AlterDrop");
1227 }
1228 }
1229
1230 #[test]
1231 fn test_alter_merge() {
1232 let stmt = parse_query(
1233 "ALTER TABLE test MERGE FIELDS 'Entry Rules', 'Exit Rules' INTO 'Trading Rules'",
1234 )
1235 .unwrap();
1236 if let Statement::AlterMerge(q) = stmt {
1237 assert_eq!(q.sources, vec!["Entry Rules", "Exit Rules"]);
1238 assert_eq!(q.into, "Trading Rules");
1239 } else {
1240 panic!("Expected AlterMerge");
1241 }
1242 }
1243
1244 #[test]
1245 fn test_backtick_ident() {
1246 let stmt = parse_query("SELECT `Structural Mechanism` FROM test").unwrap();
1247 if let Statement::Select(q) = stmt {
1248 assert_eq!(
1249 q.columns,
1250 ColumnList::Named(vec![SelectExpr::Column("Structural Mechanism".into())])
1251 );
1252 } else {
1253 panic!("Expected Select");
1254 }
1255 }
1256
1257 #[test]
1258 fn test_like_operator() {
1259 let stmt = parse_query("SELECT title FROM test WHERE categories LIKE '%defi%'").unwrap();
1260 if let Statement::Select(q) = stmt {
1261 if let Some(WhereClause::Comparison(c)) = q.where_clause {
1262 assert_eq!(c.op, "LIKE");
1263 assert_eq!(c.value, Some(SqlValue::String("%defi%".into())));
1264 } else {
1265 panic!("Expected LIKE comparison");
1266 }
1267 } else {
1268 panic!("Expected Select");
1269 }
1270 }
1271
1272 #[test]
1273 fn test_in_operator() {
1274 let stmt =
1275 parse_query("SELECT * FROM test WHERE status IN ('ACTIVE', 'LIVE')").unwrap();
1276 if let Statement::Select(q) = stmt {
1277 if let Some(WhereClause::Comparison(c)) = q.where_clause {
1278 assert_eq!(c.op, "IN");
1279 } else {
1280 panic!("Expected IN comparison");
1281 }
1282 } else {
1283 panic!("Expected Select");
1284 }
1285 }
1286
1287 #[test]
1288 fn test_is_null() {
1289 let stmt = parse_query("SELECT * FROM test WHERE title IS NULL").unwrap();
1290 if let Statement::Select(q) = stmt {
1291 if let Some(WhereClause::Comparison(c)) = q.where_clause {
1292 assert_eq!(c.op, "IS NULL");
1293 } else {
1294 panic!("Expected IS NULL comparison");
1295 }
1296 } else {
1297 panic!("Expected Select");
1298 }
1299 }
1300
1301 #[test]
1302 fn test_and_or() {
1303 let stmt = parse_query(
1304 "SELECT * FROM test WHERE status = 'ACTIVE' AND count > 5 OR title LIKE '%test%'",
1305 )
1306 .unwrap();
1307 if let Statement::Select(q) = stmt {
1308 assert!(q.where_clause.is_some());
1309 } else {
1310 panic!("Expected Select");
1311 }
1312 }
1313
1314 #[test]
1315 fn test_join() {
1316 let stmt = parse_query(
1317 "SELECT s.title, b.sharpe FROM strategies s JOIN backtests b ON b.strategy = s.path",
1318 )
1319 .unwrap();
1320 if let Statement::Select(q) = stmt {
1321 assert_eq!(q.table, "strategies");
1322 assert_eq!(q.table_alias, Some("s".into()));
1323 assert_eq!(q.joins.len(), 1);
1324 let join = &q.joins[0];
1325 assert_eq!(join.table, "backtests");
1326 assert_eq!(join.alias, Some("b".into()));
1327 } else {
1328 panic!("Expected Select");
1329 }
1330 }
1331
1332 #[test]
1333 fn test_multi_join() {
1334 let stmt = parse_query(
1335 "SELECT s.title, b.sharpe, c.verdict FROM strategies s JOIN backtests b ON b.strategy = s.path JOIN critiques c ON c.strategy = s.path",
1336 )
1337 .unwrap();
1338 if let Statement::Select(q) = stmt {
1339 assert_eq!(q.table, "strategies");
1340 assert_eq!(q.table_alias, Some("s".into()));
1341 assert_eq!(q.joins.len(), 2);
1342 assert_eq!(q.joins[0].table, "backtests");
1343 assert_eq!(q.joins[0].alias, Some("b".into()));
1344 assert_eq!(q.joins[0].left_col, "b.strategy");
1345 assert_eq!(q.joins[0].right_col, "s.path");
1346 assert_eq!(q.joins[1].table, "critiques");
1347 assert_eq!(q.joins[1].alias, Some("c".into()));
1348 assert_eq!(q.joins[1].left_col, "c.strategy");
1349 assert_eq!(q.joins[1].right_col, "s.path");
1350 } else {
1351 panic!("Expected Select");
1352 }
1353 }
1354
1355 #[test]
1356 fn test_empty_query() {
1357 assert!(parse_query("").is_err());
1358 }
1359
1360 #[test]
1361 fn test_count_star() {
1362 let stmt = parse_query("SELECT status, COUNT(*) AS cnt FROM strategies GROUP BY status").unwrap();
1363 if let Statement::Select(q) = stmt {
1364 if let ColumnList::Named(exprs) = &q.columns {
1365 assert_eq!(exprs.len(), 2);
1366 assert_eq!(exprs[0], SelectExpr::Column("status".into()));
1367 assert!(matches!(&exprs[1], SelectExpr::Aggregate {
1368 func: AggFunc::Count,
1369 arg,
1370 alias: Some(a),
1371 ..
1372 } if arg == "*" && a == "cnt"));
1373 } else {
1374 panic!("Expected Named columns");
1375 }
1376 assert_eq!(q.group_by, Some(vec!["status".into()]));
1377 } else {
1378 panic!("Expected Select");
1379 }
1380 }
1381
1382 #[test]
1383 fn test_count_column_as_ident() {
1384 let stmt = parse_query("INSERT INTO test (title, count) VALUES ('Hello', 42)").unwrap();
1386 if let Statement::Insert(q) = stmt {
1387 assert_eq!(q.columns, vec!["title", "count"]);
1388 } else {
1389 panic!("Expected Insert");
1390 }
1391 }
1392
1393 #[test]
1394 fn test_multiple_aggregates() {
1395 let stmt = parse_query("SELECT MIN(composite), MAX(composite), AVG(composite) FROM strategies").unwrap();
1396 if let Statement::Select(q) = stmt {
1397 if let ColumnList::Named(exprs) = &q.columns {
1398 assert_eq!(exprs.len(), 3);
1399 assert!(matches!(&exprs[0], SelectExpr::Aggregate { func: AggFunc::Min, .. }));
1400 assert!(matches!(&exprs[1], SelectExpr::Aggregate { func: AggFunc::Max, .. }));
1401 assert!(matches!(&exprs[2], SelectExpr::Aggregate { func: AggFunc::Avg, .. }));
1402 } else {
1403 panic!("Expected Named columns");
1404 }
1405 assert_eq!(q.group_by, None);
1406 } else {
1407 panic!("Expected Select");
1408 }
1409 }
1410
1411 #[test]
1414 fn test_select_arithmetic_expr() {
1415 let stmt = parse_query("SELECT a + b FROM test").unwrap();
1416 if let Statement::Select(q) = stmt {
1417 if let ColumnList::Named(exprs) = &q.columns {
1418 assert_eq!(exprs.len(), 1);
1419 assert!(matches!(&exprs[0], SelectExpr::Expr {
1420 expr: Expr::BinaryOp { op: ArithOp::Add, .. },
1421 alias: None,
1422 }));
1423 } else {
1424 panic!("Expected Named columns");
1425 }
1426 } else {
1427 panic!("Expected Select");
1428 }
1429 }
1430
1431 #[test]
1432 fn test_select_arithmetic_with_alias() {
1433 let stmt = parse_query("SELECT a + b AS total FROM test").unwrap();
1434 if let Statement::Select(q) = stmt {
1435 if let ColumnList::Named(exprs) = &q.columns {
1436 assert_eq!(exprs.len(), 1);
1437 assert!(matches!(&exprs[0], SelectExpr::Expr {
1438 alias: Some(a),
1439 ..
1440 } if a == "total"));
1441 assert_eq!(exprs[0].output_name(), "total");
1442 } else {
1443 panic!("Expected Named columns");
1444 }
1445 } else {
1446 panic!("Expected Select");
1447 }
1448 }
1449
1450 #[test]
1451 fn test_select_precedence() {
1452 let stmt = parse_query("SELECT a + b * c FROM test").unwrap();
1454 if let Statement::Select(q) = stmt {
1455 if let ColumnList::Named(exprs) = &q.columns {
1456 if let SelectExpr::Expr { expr, .. } = &exprs[0] {
1457 if let Expr::BinaryOp { left, op, right } = expr {
1458 assert_eq!(*op, ArithOp::Add);
1459 assert!(matches!(left.as_ref(), Expr::Column(n) if n == "a"));
1460 assert!(matches!(right.as_ref(), Expr::BinaryOp { op: ArithOp::Mul, .. }));
1461 } else {
1462 panic!("Expected BinaryOp");
1463 }
1464 } else {
1465 panic!("Expected Expr variant");
1466 }
1467 } else {
1468 panic!("Expected Named columns");
1469 }
1470 } else {
1471 panic!("Expected Select");
1472 }
1473 }
1474
1475 #[test]
1476 fn test_select_parenthesized_expr() {
1477 let stmt = parse_query("SELECT (a + b) * c FROM test").unwrap();
1479 if let Statement::Select(q) = stmt {
1480 if let ColumnList::Named(exprs) = &q.columns {
1481 if let SelectExpr::Expr { expr, .. } = &exprs[0] {
1482 if let Expr::BinaryOp { left, op, .. } = expr {
1483 assert_eq!(*op, ArithOp::Mul);
1484 assert!(matches!(left.as_ref(), Expr::BinaryOp { op: ArithOp::Add, .. }));
1485 } else {
1486 panic!("Expected BinaryOp");
1487 }
1488 } else {
1489 panic!("Expected Expr variant");
1490 }
1491 } else {
1492 panic!("Expected Named columns");
1493 }
1494 } else {
1495 panic!("Expected Select");
1496 }
1497 }
1498
1499 #[test]
1500 fn test_select_unary_minus() {
1501 let stmt = parse_query("SELECT -count FROM test").unwrap();
1502 if let Statement::Select(q) = stmt {
1503 if let ColumnList::Named(exprs) = &q.columns {
1504 assert!(matches!(&exprs[0], SelectExpr::Expr {
1505 expr: Expr::UnaryMinus(_),
1506 ..
1507 }));
1508 } else {
1509 panic!("Expected Named columns");
1510 }
1511 } else {
1512 panic!("Expected Select");
1513 }
1514 }
1515
1516 #[test]
1517 fn test_select_negative_literal() {
1518 let stmt = parse_query("SELECT -42 FROM test").unwrap();
1519 if let Statement::Select(q) = stmt {
1520 if let ColumnList::Named(exprs) = &q.columns {
1521 assert!(matches!(&exprs[0], SelectExpr::Expr {
1523 expr: Expr::Literal(SqlValue::Int(-42)),
1524 ..
1525 }));
1526 } else {
1527 panic!("Expected Named columns");
1528 }
1529 } else {
1530 panic!("Expected Select");
1531 }
1532 }
1533
1534 #[test]
1535 fn test_where_arithmetic_expr() {
1536 let stmt = parse_query("SELECT * FROM test WHERE a + b > 10").unwrap();
1537 if let Statement::Select(q) = stmt {
1538 if let Some(WhereClause::Comparison(c)) = q.where_clause {
1539 assert_eq!(c.op, ">");
1540 assert!(matches!(&c.left_expr, Some(Expr::BinaryOp { op: ArithOp::Add, .. })));
1541 assert!(matches!(&c.right_expr, Some(Expr::Literal(SqlValue::Int(10)))));
1542 } else {
1543 panic!("Expected comparison");
1544 }
1545 } else {
1546 panic!("Expected Select");
1547 }
1548 }
1549
1550 #[test]
1551 fn test_where_both_sides_expr() {
1552 let stmt = parse_query("SELECT * FROM test WHERE a * 2 > b + 1").unwrap();
1553 if let Statement::Select(q) = stmt {
1554 if let Some(WhereClause::Comparison(c)) = q.where_clause {
1555 assert_eq!(c.op, ">");
1556 assert!(matches!(&c.left_expr, Some(Expr::BinaryOp { op: ArithOp::Mul, .. })));
1557 assert!(matches!(&c.right_expr, Some(Expr::BinaryOp { op: ArithOp::Add, .. })));
1558 } else {
1559 panic!("Expected comparison");
1560 }
1561 } else {
1562 panic!("Expected Select");
1563 }
1564 }
1565
1566 #[test]
1567 fn test_order_by_expr() {
1568 let stmt = parse_query("SELECT * FROM test ORDER BY a + b DESC").unwrap();
1569 if let Statement::Select(q) = stmt {
1570 let ob = q.order_by.unwrap();
1571 assert_eq!(ob.len(), 1);
1572 assert!(ob[0].descending);
1573 assert!(matches!(&ob[0].expr, Some(Expr::BinaryOp { op: ArithOp::Add, .. })));
1574 } else {
1575 panic!("Expected Select");
1576 }
1577 }
1578
1579 #[test]
1580 fn test_all_arithmetic_ops() {
1581 let stmt = parse_query("SELECT a + b, a - b, a * b, a / b, a % b FROM test").unwrap();
1582 if let Statement::Select(q) = stmt {
1583 if let ColumnList::Named(exprs) = &q.columns {
1584 assert_eq!(exprs.len(), 5);
1585 assert!(matches!(&exprs[0], SelectExpr::Expr { expr: Expr::BinaryOp { op: ArithOp::Add, .. }, .. }));
1586 assert!(matches!(&exprs[1], SelectExpr::Expr { expr: Expr::BinaryOp { op: ArithOp::Sub, .. }, .. }));
1587 assert!(matches!(&exprs[2], SelectExpr::Expr { expr: Expr::BinaryOp { op: ArithOp::Mul, .. }, .. }));
1588 assert!(matches!(&exprs[3], SelectExpr::Expr { expr: Expr::BinaryOp { op: ArithOp::Div, .. }, .. }));
1589 assert!(matches!(&exprs[4], SelectExpr::Expr { expr: Expr::BinaryOp { op: ArithOp::Mod, .. }, .. }));
1590 } else {
1591 panic!("Expected Named columns");
1592 }
1593 } else {
1594 panic!("Expected Select");
1595 }
1596 }
1597
1598 #[test]
1599 fn test_column_with_literal_arithmetic() {
1600 let stmt = parse_query("SELECT count * 2 + 1 FROM test").unwrap();
1601 if let Statement::Select(q) = stmt {
1602 if let ColumnList::Named(exprs) = &q.columns {
1603 if let SelectExpr::Expr { expr, .. } = &exprs[0] {
1605 if let Expr::BinaryOp { left, op, right } = expr {
1606 assert_eq!(*op, ArithOp::Add);
1607 assert!(matches!(right.as_ref(), Expr::Literal(SqlValue::Int(1))));
1608 assert!(matches!(left.as_ref(), Expr::BinaryOp { op: ArithOp::Mul, .. }));
1609 } else {
1610 panic!("Expected BinaryOp");
1611 }
1612 } else {
1613 panic!("Expected Expr");
1614 }
1615 } else {
1616 panic!("Expected Named columns");
1617 }
1618 } else {
1619 panic!("Expected Select");
1620 }
1621 }
1622
1623 #[test]
1624 fn test_mixed_columns_and_exprs() {
1625 let stmt = parse_query("SELECT title, a + b AS sum, count FROM test").unwrap();
1626 if let Statement::Select(q) = stmt {
1627 if let ColumnList::Named(exprs) = &q.columns {
1628 assert_eq!(exprs.len(), 3);
1629 assert_eq!(exprs[0], SelectExpr::Column("title".into()));
1630 assert!(matches!(&exprs[1], SelectExpr::Expr { alias: Some(a), .. } if a == "sum"));
1631 assert_eq!(exprs[2], SelectExpr::Column("count".into()));
1632 } else {
1633 panic!("Expected Named columns");
1634 }
1635 } else {
1636 panic!("Expected Select");
1637 }
1638 }
1639
1640 #[test]
1643 fn test_case_when_basic() {
1644 let stmt = parse_query(
1645 "SELECT CASE WHEN status = 'ACTIVE' THEN 1 ELSE 0 END FROM test"
1646 ).unwrap();
1647 if let Statement::Select(q) = stmt {
1648 if let ColumnList::Named(exprs) = &q.columns {
1649 assert_eq!(exprs.len(), 1);
1650 assert!(matches!(&exprs[0], SelectExpr::Expr {
1651 expr: Expr::Case { .. },
1652 ..
1653 }));
1654 } else {
1655 panic!("Expected Named columns");
1656 }
1657 } else {
1658 panic!("Expected Select");
1659 }
1660 }
1661
1662 #[test]
1663 fn test_case_when_multiple_branches() {
1664 let stmt = parse_query(
1665 "SELECT CASE WHEN x > 10 THEN 'high' WHEN x > 5 THEN 'mid' ELSE 'low' END FROM test"
1666 ).unwrap();
1667 if let Statement::Select(q) = stmt {
1668 if let ColumnList::Named(exprs) = &q.columns {
1669 if let SelectExpr::Expr { expr: Expr::Case { whens, else_expr }, .. } = &exprs[0] {
1670 assert_eq!(whens.len(), 2);
1671 assert!(else_expr.is_some());
1672 } else {
1673 panic!("Expected Case expression");
1674 }
1675 } else {
1676 panic!("Expected Named columns");
1677 }
1678 } else {
1679 panic!("Expected Select");
1680 }
1681 }
1682
1683 #[test]
1684 fn test_case_when_no_else() {
1685 let stmt = parse_query(
1686 "SELECT CASE WHEN x = 1 THEN 'one' END FROM test"
1687 ).unwrap();
1688 if let Statement::Select(q) = stmt {
1689 if let ColumnList::Named(exprs) = &q.columns {
1690 if let SelectExpr::Expr { expr: Expr::Case { whens, else_expr }, .. } = &exprs[0] {
1691 assert_eq!(whens.len(), 1);
1692 assert!(else_expr.is_none());
1693 } else {
1694 panic!("Expected Case expression");
1695 }
1696 } else {
1697 panic!("Expected Named columns");
1698 }
1699 } else {
1700 panic!("Expected Select");
1701 }
1702 }
1703
1704 #[test]
1705 fn test_case_when_in_aggregate() {
1706 let stmt = parse_query(
1707 "SELECT SUM(CASE WHEN side = 'BUY' THEN size ELSE -size END) AS net FROM orders GROUP BY token"
1708 ).unwrap();
1709 if let Statement::Select(q) = stmt {
1710 if let ColumnList::Named(exprs) = &q.columns {
1711 assert_eq!(exprs.len(), 1);
1712 assert!(matches!(&exprs[0], SelectExpr::Aggregate {
1713 func: AggFunc::Sum,
1714 arg_expr: Some(Expr::Case { .. }),
1715 alias: Some(a),
1716 ..
1717 } if a == "net"));
1718 } else {
1719 panic!("Expected Named columns");
1720 }
1721 } else {
1722 panic!("Expected Select");
1723 }
1724 }
1725
1726 #[test]
1727 fn test_case_when_with_alias() {
1728 let stmt = parse_query(
1729 "SELECT CASE WHEN x > 0 THEN 'pos' ELSE 'neg' END AS sign FROM test"
1730 ).unwrap();
1731 if let Statement::Select(q) = stmt {
1732 if let ColumnList::Named(exprs) = &q.columns {
1733 assert!(matches!(&exprs[0], SelectExpr::Expr {
1734 expr: Expr::Case { .. },
1735 alias: Some(a),
1736 } if a == "sign"));
1737 } else {
1738 panic!("Expected Named columns");
1739 }
1740 } else {
1741 panic!("Expected Select");
1742 }
1743 }
1744}