1use regex::Regex;
4use std::sync::LazyLock;
5
6use crate::errors::MdqlError;
7
/// Binary arithmetic operator appearing in an [`Expr::BinaryOp`].
///
/// `Add`/`Sub` come from `+`/`-`; `Mul`/`Div`/`Mod` come from `*`, `/`, `%`.
#[derive(Clone, Debug, PartialEq)]
pub enum ArithOp {
    /// `+`
    Add,
    /// `-`
    Sub,
    /// `*`
    Mul,
    /// `/`
    Div,
    /// `%`
    Mod,
}
18
19#[derive(Debug, Clone, PartialEq)]
20pub enum Expr {
21 Literal(SqlValue),
22 Column(String),
23 BinaryOp { left: Box<Expr>, op: ArithOp, right: Box<Expr> },
24 UnaryMinus(Box<Expr>),
25 Case { whens: Vec<(WhereClause, Box<Expr>)>, else_expr: Option<Box<Expr>> },
26 DateAdd { date: Box<Expr>, days: Box<Expr> },
28 DateDiff { left: Box<Expr>, right: Box<Expr> },
30 CurrentDate,
32 CurrentTimestamp,
33}
34
35impl Expr {
36 pub fn as_column(&self) -> Option<&str> {
38 if let Expr::Column(name) = self { Some(name) } else { None }
39 }
40
41 pub fn display_name(&self) -> String {
43 match self {
44 Expr::Literal(SqlValue::Int(n)) => n.to_string(),
45 Expr::Literal(SqlValue::Float(f)) => f.to_string(),
46 Expr::Literal(SqlValue::String(s)) => format!("'{}'", s),
47 Expr::Literal(SqlValue::Null) => "NULL".to_string(),
48 Expr::Literal(SqlValue::List(_)) => "list".to_string(),
49 Expr::Column(name) => name.clone(),
50 Expr::BinaryOp { left, op, right } => {
51 let op_str = match op {
52 ArithOp::Add => "+",
53 ArithOp::Sub => "-",
54 ArithOp::Mul => "*",
55 ArithOp::Div => "/",
56 ArithOp::Mod => "%",
57 };
58 format!("{} {} {}", left.display_name(), op_str, right.display_name())
59 }
60 Expr::UnaryMinus(inner) => format!("-{}", inner.display_name()),
61 Expr::Case { .. } => "CASE".to_string(),
62 Expr::DateAdd { date, days } => format!("DATE_ADD({}, {})", date.display_name(), days.display_name()),
63 Expr::DateDiff { left, right } => format!("DATEDIFF({}, {})", left.display_name(), right.display_name()),
64 Expr::CurrentDate => "CURRENT_DATE".to_string(),
65 Expr::CurrentTimestamp => "CURRENT_TIMESTAMP".to_string(),
66 }
67 }
68}
69
70#[derive(Debug, Clone, PartialEq)]
71pub struct OrderSpec {
72 pub column: String,
73 pub expr: Option<Expr>,
74 pub descending: bool,
75}
76
77#[derive(Debug, Clone, PartialEq)]
78pub struct Comparison {
79 pub column: String,
80 pub op: String,
81 pub value: Option<SqlValue>,
82 pub left_expr: Option<Expr>,
83 pub right_expr: Option<Expr>,
84}
85
86#[derive(Debug, Clone, PartialEq)]
87pub struct BoolOp {
88 pub op: String, pub left: Box<WhereClause>,
90 pub right: Box<WhereClause>,
91}
92
93#[derive(Debug, Clone, PartialEq)]
94pub enum WhereClause {
95 Comparison(Comparison),
96 BoolOp(BoolOp),
97}
98
/// A literal value appearing in a query.
#[derive(Clone, Debug, PartialEq)]
pub enum SqlValue {
    /// Single-quoted string, quotes removed.
    String(String),
    /// Integer literal.
    Int(i64),
    /// Literal containing a decimal point.
    Float(f64),
    /// The `NULL` keyword.
    Null,
    /// Parenthesised value list, as used by `IN (...)`.
    List(Vec<SqlValue>),
}
107
/// `JOIN table [alias] ON left_col = right_col`.
#[derive(Clone, Debug, PartialEq)]
pub struct JoinClause {
    /// Joined table name.
    pub table: String,
    /// Optional alias for the joined table.
    pub alias: Option<String>,
    /// Identifier on the left of the `=` in the ON condition.
    pub left_col: String,
    /// Identifier on the right of the `=` in the ON condition.
    pub right_col: String,
}
115
/// Aggregate function usable in a SELECT list.
#[derive(Clone, Debug, PartialEq)]
pub enum AggFunc {
    /// `COUNT(...)`
    Count,
    /// `SUM(...)`
    Sum,
    /// `AVG(...)`
    Avg,
    /// `MIN(...)`
    Min,
    /// `MAX(...)`
    Max,
}
124
125#[derive(Debug, Clone, PartialEq)]
126pub enum SelectExpr {
127 Column(String),
128 Aggregate { func: AggFunc, arg: String, arg_expr: Option<Expr>, alias: Option<String> },
129 Expr { expr: Expr, alias: Option<String> },
130}
131
132impl SelectExpr {
133 pub fn output_name(&self) -> String {
134 match self {
135 SelectExpr::Column(name) => name.clone(),
136 SelectExpr::Aggregate { func, arg, alias, .. } => {
137 if let Some(a) = alias {
138 a.clone()
139 } else {
140 let func_name = match func {
141 AggFunc::Count => "COUNT",
142 AggFunc::Sum => "SUM",
143 AggFunc::Avg => "AVG",
144 AggFunc::Min => "MIN",
145 AggFunc::Max => "MAX",
146 };
147 format!("{}({})", func_name, arg)
148 }
149 }
150 SelectExpr::Expr { expr, alias } => {
151 alias.clone().unwrap_or_else(|| expr.display_name())
152 }
153 }
154 }
155
156 pub fn is_aggregate(&self) -> bool {
157 matches!(self, SelectExpr::Aggregate { .. })
158 }
159}
160
161#[derive(Debug, Clone, PartialEq)]
162pub struct SelectQuery {
163 pub columns: ColumnList,
164 pub table: String,
165 pub table_alias: Option<String>,
166 pub joins: Vec<JoinClause>,
167 pub where_clause: Option<WhereClause>,
168 pub group_by: Option<Vec<String>>,
169 pub having: Option<WhereClause>,
170 pub order_by: Option<Vec<OrderSpec>>,
171 pub limit: Option<i64>,
172}
173
174#[derive(Debug, Clone, PartialEq)]
175pub enum ColumnList {
176 All,
177 Named(Vec<SelectExpr>),
178}
179
180#[derive(Debug, Clone, PartialEq)]
181pub struct InsertQuery {
182 pub table: String,
183 pub columns: Vec<String>,
184 pub values: Vec<SqlValue>,
185}
186
187#[derive(Debug, Clone, PartialEq)]
188pub struct UpdateQuery {
189 pub table: String,
190 pub assignments: Vec<(String, SqlValue)>,
191 pub where_clause: Option<WhereClause>,
192}
193
194#[derive(Debug, Clone, PartialEq)]
195pub struct DeleteQuery {
196 pub table: String,
197 pub where_clause: Option<WhereClause>,
198}
199
/// `ALTER TABLE table RENAME FIELD old TO new`.
#[derive(Clone, Debug, PartialEq)]
pub struct AlterRenameFieldQuery {
    /// Target table.
    pub table: String,
    /// Current field name.
    pub old_name: String,
    /// Replacement field name.
    pub new_name: String,
}
206
/// `ALTER TABLE table DROP FIELD name`.
#[derive(Clone, Debug, PartialEq)]
pub struct AlterDropFieldQuery {
    /// Target table.
    pub table: String,
    /// Field to remove.
    pub field_name: String,
}
212
/// `ALTER TABLE table MERGE FIELDS a, b, ... INTO target`.
#[derive(Clone, Debug, PartialEq)]
pub struct AlterMergeFieldsQuery {
    /// Target table.
    pub table: String,
    /// Source field names in source order.
    pub sources: Vec<String>,
    /// Destination field name.
    pub into: String,
}
219
220#[derive(Debug, Clone, PartialEq)]
221pub enum Statement {
222 Select(SelectQuery),
223 Insert(InsertQuery),
224 Update(UpdateQuery),
225 Delete(DeleteQuery),
226 AlterRename(AlterRenameFieldQuery),
227 AlterDrop(AlterDropFieldQuery),
228 AlterMerge(AlterMergeFieldsQuery),
229}
230
/// Words the tokenizer classifies as `keyword` tokens (compared after upper-casing the input).
static KEYWORDS: &[&str] = &[
    // Query clauses and predicate operators
    "SELECT", "FROM", "WHERE", "AND", "OR", "ORDER", "BY",
    "ASC", "DESC", "LIMIT", "LIKE", "IN", "IS", "NOT", "NULL",
    "JOIN", "ON", "AS", "GROUP", "HAVING",
    // Data modification
    "INSERT", "INTO", "VALUES", "UPDATE", "SET", "DELETE",
    // Schema changes
    "ALTER", "TABLE", "RENAME", "FIELD", "TO", "DROP", "MERGE", "FIELDS",
    // CASE expressions
    "CASE", "WHEN", "THEN", "ELSE", "END",
    // Date arithmetic
    "INTERVAL", "DAY", "DAYS", "CURRENT_DATE", "CURRENT_TIMESTAMP", "DATEDIFF",
];
242
/// Function names recognised as aggregates when immediately followed by `(`.
static AGG_FUNCS: &[&str] = &[
    "COUNT",
    "SUM",
    "AVG",
    "MIN",
    "MAX",
];
244
245static TOKEN_RE: LazyLock<Regex> = LazyLock::new(|| {
246 Regex::new(
247 r#"(?x)
248 \s*(?:
249 (?P<backtick>`[^`]+`)
250 | (?P<string>'(?:[^'\\]|\\.)*')
251 | (?P<number>\d+(?:\.\d+)?)
252 | (?P<op><=|>=|!=|[=<>,*()+\-/%])
253 | (?P<word>[A-Za-z_][A-Za-z0-9_./-]*)
254 )"#,
255 )
256 .unwrap()
257});
258
/// A lexical token produced by `tokenize`.
#[derive(Clone, Debug)]
struct Token {
    /// Category name assigned by the tokenizer, such as `ident`, `string`, `number`,
    /// `op` or `keyword`.
    token_type: String,
    /// Normalised text: quotes stripped for strings/backtick identifiers, upper-cased for keywords.
    value: String,
    /// The exact source text, used in error messages.
    raw: String,
}
265
266fn tokenize(sql: &str) -> Vec<Token> {
267 let mut tokens = Vec::new();
268 for caps in TOKEN_RE.captures_iter(sql) {
269 if let Some(m) = caps.name("backtick") {
270 let raw = m.as_str();
271 tokens.push(Token {
272 token_type: "ident".into(),
273 value: raw[1..raw.len() - 1].into(),
274 raw: raw.into(),
275 });
276 } else if let Some(m) = caps.name("string") {
277 let raw = m.as_str();
278 tokens.push(Token {
279 token_type: "string".into(),
280 value: raw[1..raw.len() - 1].into(),
281 raw: raw.into(),
282 });
283 } else if let Some(m) = caps.name("number") {
284 let raw = m.as_str();
285 tokens.push(Token {
286 token_type: "number".into(),
287 value: raw.into(),
288 raw: raw.into(),
289 });
290 } else if let Some(m) = caps.name("op") {
291 let raw = m.as_str();
292 tokens.push(Token {
293 token_type: "op".into(),
294 value: raw.into(),
295 raw: raw.into(),
296 });
297 } else if let Some(m) = caps.name("word") {
298 let raw = m.as_str();
299 if KEYWORDS.contains(&raw.to_uppercase().as_str()) {
300 tokens.push(Token {
301 token_type: "keyword".into(),
302 value: raw.to_uppercase(),
303 raw: raw.into(),
304 });
305 } else {
306 tokens.push(Token {
307 token_type: "ident".into(),
308 value: raw.into(),
309 raw: raw.into(),
310 });
311 }
312 }
313 }
314 tokens
315}
316
317struct Parser {
320 tokens: Vec<Token>,
321 pos: usize,
322}
323
324impl Parser {
325 fn new(tokens: Vec<Token>) -> Self {
326 Parser { tokens, pos: 0 }
327 }
328
329 fn peek(&self) -> Option<&Token> {
330 self.tokens.get(self.pos)
331 }
332
333 fn advance(&mut self) -> Token {
334 let t = self.tokens[self.pos].clone();
335 self.pos += 1;
336 t
337 }
338
339 fn expect(&mut self, type_: &str, value: Option<&str>) -> Result<Token, MdqlError> {
340 let t = self.peek().ok_or_else(|| {
341 MdqlError::QueryParse(format!(
342 "Unexpected end of query, expected {}",
343 value.unwrap_or(type_)
344 ))
345 })?;
346 let matches_type = t.token_type == type_;
347 let matches_value = value.map_or(true, |v| t.value == v);
348 if !matches_type || !matches_value {
349 return Err(MdqlError::QueryParse(format!(
350 "Expected {}, got '{}' at position {}",
351 value.unwrap_or(type_),
352 t.raw,
353 self.pos
354 )));
355 }
356 Ok(self.advance())
357 }
358
359 fn match_keyword(&mut self, kw: &str) -> bool {
360 if let Some(t) = self.peek() {
361 if t.token_type == "keyword" && t.value == kw {
362 self.advance();
363 return true;
364 }
365 }
366 false
367 }
368
369 fn parse_statement(&mut self) -> Result<Statement, MdqlError> {
370 let t = self.peek().ok_or_else(|| MdqlError::QueryParse("Empty query".into()))?;
371 match (t.token_type.as_str(), t.value.as_str()) {
372 ("keyword", "SELECT") => Ok(Statement::Select(self.parse_select()?)),
373 ("keyword", "INSERT") => Ok(Statement::Insert(self.parse_insert()?)),
374 ("keyword", "UPDATE") => Ok(Statement::Update(self.parse_update()?)),
375 ("keyword", "DELETE") => Ok(Statement::Delete(self.parse_delete()?)),
376 ("keyword", "ALTER") => self.parse_alter(),
377 _ => Err(MdqlError::QueryParse(format!(
378 "Expected SELECT, INSERT, UPDATE, DELETE, or ALTER, got '{}'",
379 t.raw
380 ))),
381 }
382 }
383
384 fn parse_select(&mut self) -> Result<SelectQuery, MdqlError> {
385 self.expect("keyword", Some("SELECT"))?;
386 let columns = self.parse_columns()?;
387 self.expect("keyword", Some("FROM"))?;
388 let table = self.parse_ident()?;
389
390 let mut table_alias = None;
392 if let Some(t) = self.peek() {
393 if t.token_type == "ident" && !self.is_clause_keyword(t) {
394 table_alias = Some(self.advance().value);
395 }
396 }
397
398 let mut joins = Vec::new();
400 while self.match_keyword("JOIN") {
401 let join_table = self.parse_ident()?;
402 let mut join_alias = None;
403 if let Some(t) = self.peek() {
404 if t.token_type == "ident" && !self.is_clause_keyword(t) {
405 join_alias = Some(self.advance().value);
406 }
407 }
408 self.expect("keyword", Some("ON"))?;
409 let left_col = self.parse_ident()?;
410 self.expect("op", Some("="))?;
411 let right_col = self.parse_ident()?;
412 joins.push(JoinClause {
413 table: join_table,
414 alias: join_alias,
415 left_col,
416 right_col,
417 });
418 }
419
420 let mut where_clause = None;
421 if self.match_keyword("WHERE") {
422 where_clause = Some(self.parse_or_expr()?);
423 }
424
425 let mut group_by = None;
426 if self.match_keyword("GROUP") {
427 self.expect("keyword", Some("BY"))?;
428 let mut cols = vec![self.parse_ident()?];
429 while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
430 self.advance();
431 cols.push(self.parse_ident()?);
432 }
433 group_by = Some(cols);
434 }
435
436 let mut having = None;
437 if self.match_keyword("HAVING") {
438 having = Some(self.parse_or_expr()?);
439 }
440
441 let mut order_by = None;
442 if self.match_keyword("ORDER") {
443 self.expect("keyword", Some("BY"))?;
444 order_by = Some(self.parse_order_by()?);
445 }
446
447 let mut limit = None;
448 if self.match_keyword("LIMIT") {
449 let t = self.expect("number", None)?;
450 limit = Some(t.value.parse::<i64>().map_err(|_| {
451 MdqlError::QueryParse(format!("Invalid LIMIT value: {}", t.value))
452 })?);
453 }
454
455 self.expect_end()?;
456
457 Ok(SelectQuery {
458 columns,
459 table,
460 table_alias,
461 joins,
462 where_clause,
463 group_by,
464 having,
465 order_by,
466 limit,
467 })
468 }
469
470 fn parse_insert(&mut self) -> Result<InsertQuery, MdqlError> {
471 self.expect("keyword", Some("INSERT"))?;
472 self.expect("keyword", Some("INTO"))?;
473 let table = self.parse_ident()?;
474
475 self.expect("op", Some("("))?;
476 let mut columns = vec![self.parse_ident()?];
477 while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
478 self.advance();
479 columns.push(self.parse_ident()?);
480 }
481 self.expect("op", Some(")"))?;
482
483 self.expect("keyword", Some("VALUES"))?;
484
485 self.expect("op", Some("("))?;
486 let mut values = vec![self.parse_value()?];
487 while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
488 self.advance();
489 values.push(self.parse_value()?);
490 }
491 self.expect("op", Some(")"))?;
492
493 if columns.len() != values.len() {
494 return Err(MdqlError::QueryParse(format!(
495 "Column count ({}) does not match value count ({})",
496 columns.len(),
497 values.len()
498 )));
499 }
500
501 self.expect_end()?;
502 Ok(InsertQuery {
503 table,
504 columns,
505 values,
506 })
507 }
508
509 fn parse_update(&mut self) -> Result<UpdateQuery, MdqlError> {
510 self.expect("keyword", Some("UPDATE"))?;
511 let table = self.parse_ident()?;
512 self.expect("keyword", Some("SET"))?;
513
514 let mut assignments = Vec::new();
515 let col = self.parse_ident()?;
516 self.expect("op", Some("="))?;
517 let val = self.parse_value()?;
518 assignments.push((col, val));
519
520 while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
521 self.advance();
522 let col = self.parse_ident()?;
523 self.expect("op", Some("="))?;
524 let val = self.parse_value()?;
525 assignments.push((col, val));
526 }
527
528 let mut where_clause = None;
529 if self.match_keyword("WHERE") {
530 where_clause = Some(self.parse_or_expr()?);
531 }
532
533 self.expect_end()?;
534 Ok(UpdateQuery {
535 table,
536 assignments,
537 where_clause,
538 })
539 }
540
541 fn parse_delete(&mut self) -> Result<DeleteQuery, MdqlError> {
542 self.expect("keyword", Some("DELETE"))?;
543 self.expect("keyword", Some("FROM"))?;
544 let table = self.parse_ident()?;
545
546 let mut where_clause = None;
547 if self.match_keyword("WHERE") {
548 where_clause = Some(self.parse_or_expr()?);
549 }
550
551 self.expect_end()?;
552 Ok(DeleteQuery {
553 table,
554 where_clause,
555 })
556 }
557
558 fn parse_alter(&mut self) -> Result<Statement, MdqlError> {
559 self.expect("keyword", Some("ALTER"))?;
560 self.expect("keyword", Some("TABLE"))?;
561 let table = self.parse_ident()?;
562
563 let t = self.peek().ok_or_else(|| {
564 MdqlError::QueryParse("Expected RENAME, DROP, or MERGE after table name".into())
565 })?;
566
567 match (t.token_type.as_str(), t.value.as_str()) {
568 ("keyword", "RENAME") => {
569 self.advance();
570 self.expect("keyword", Some("FIELD"))?;
571 let old_name = self.parse_string_or_ident()?;
572 self.expect("keyword", Some("TO"))?;
573 let new_name = self.parse_string_or_ident()?;
574 self.expect_end()?;
575 Ok(Statement::AlterRename(AlterRenameFieldQuery {
576 table,
577 old_name,
578 new_name,
579 }))
580 }
581 ("keyword", "DROP") => {
582 self.advance();
583 self.expect("keyword", Some("FIELD"))?;
584 let field_name = self.parse_string_or_ident()?;
585 self.expect_end()?;
586 Ok(Statement::AlterDrop(AlterDropFieldQuery {
587 table,
588 field_name,
589 }))
590 }
591 ("keyword", "MERGE") => {
592 self.advance();
593 self.expect("keyword", Some("FIELDS"))?;
594 let mut sources = vec![self.parse_string_or_ident()?];
595 while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
596 self.advance();
597 sources.push(self.parse_string_or_ident()?);
598 }
599 self.expect("keyword", Some("INTO"))?;
600 let target = self.parse_string_or_ident()?;
601 self.expect_end()?;
602 Ok(Statement::AlterMerge(AlterMergeFieldsQuery {
603 table,
604 sources,
605 into: target,
606 }))
607 }
608 _ => Err(MdqlError::QueryParse(format!(
609 "Expected RENAME, DROP, or MERGE, got '{}'",
610 t.raw
611 ))),
612 }
613 }
614
615 fn parse_string_or_ident(&mut self) -> Result<String, MdqlError> {
616 let t = self.peek().ok_or_else(|| {
617 MdqlError::QueryParse("Expected field name, got end of query".into())
618 })?;
619 match t.token_type.as_str() {
620 "string" => {
621 let v = self.advance().value;
622 Ok(v)
623 }
624 "ident" | "keyword" => {
625 let v = self.advance().value;
626 Ok(v)
627 }
628 _ => Err(MdqlError::QueryParse(format!(
629 "Expected field name, got '{}'",
630 t.raw
631 ))),
632 }
633 }
634
635 fn parse_columns(&mut self) -> Result<ColumnList, MdqlError> {
636 if let Some(t) = self.peek() {
637 if t.token_type == "op" && t.value == "*" {
638 self.advance();
639 return Ok(ColumnList::All);
640 }
641 }
642
643 let mut exprs = vec![self.parse_select_expr()?];
644 while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
645 self.advance();
646 exprs.push(self.parse_select_expr()?);
647 }
648 Ok(ColumnList::Named(exprs))
649 }
650
651 fn peek_is_agg_func(&self) -> bool {
652 let t = match self.peek() {
653 Some(t) => t,
654 None => return false,
655 };
656 let name_upper = t.value.to_uppercase();
657 if !AGG_FUNCS.contains(&name_upper.as_str()) {
658 return false;
659 }
660 self.tokens
662 .get(self.pos + 1)
663 .map_or(false, |next| next.token_type == "op" && next.value == "(")
664 }
665
666 fn parse_select_expr(&mut self) -> Result<SelectExpr, MdqlError> {
667 let _t = self.peek().ok_or_else(|| {
668 MdqlError::QueryParse("Expected column or aggregate, got end of query".into())
669 })?;
670
671 if self.peek_is_agg_func() {
672 let func_name = self.advance().value.to_uppercase();
673 let func = match func_name.as_str() {
674 "COUNT" => AggFunc::Count,
675 "SUM" => AggFunc::Sum,
676 "AVG" => AggFunc::Avg,
677 "MIN" => AggFunc::Min,
678 "MAX" => AggFunc::Max,
679 _ => unreachable!(),
680 };
681 self.expect("op", Some("("))?;
682 let (arg, arg_expr) = if self.peek().map_or(false, |t| t.token_type == "op" && t.value == "*") {
683 self.advance();
684 ("*".to_string(), None)
685 } else {
686 let expr = self.parse_additive()?;
688 if let Expr::Column(name) = &expr {
689 (name.clone(), None)
690 } else {
691 (expr.display_name(), Some(expr))
692 }
693 };
694 self.expect("op", Some(")"))?;
695
696 let alias = if self.match_keyword("AS") {
697 Some(self.parse_ident()?)
698 } else if self.peek().map_or(false, |t| {
699 t.token_type == "ident" && !self.is_clause_keyword(t)
700 }) {
701 Some(self.advance().value)
702 } else {
703 None
704 };
705
706 Ok(SelectExpr::Aggregate { func, arg, arg_expr, alias })
707 } else {
708 let expr = self.parse_additive()?;
710
711 let alias = if self.match_keyword("AS") {
713 Some(self.parse_ident()?)
714 } else if self.peek().map_or(false, |t| {
715 t.token_type == "ident" && !self.is_clause_keyword(t)
716 }) {
717 Some(self.advance().value)
718 } else {
719 None
720 };
721
722 if alias.is_none() {
725 if let Expr::Column(name) = &expr {
726 return Ok(SelectExpr::Column(name.clone()));
727 }
728 }
729
730 Ok(SelectExpr::Expr { expr, alias })
731 }
732 }
733
734 fn peek_is_additive_op(&self) -> bool {
737 self.peek().map_or(false, |t| {
738 t.token_type == "op" && (t.value == "+" || t.value == "-")
739 })
740 }
741
742 fn peek_is_multiplicative_op(&self) -> bool {
743 self.peek().map_or(false, |t| {
744 t.token_type == "op" && (t.value == "*" || t.value == "/" || t.value == "%")
745 })
746 }
747
748 fn parse_additive(&mut self) -> Result<Expr, MdqlError> {
749 let mut left = self.parse_multiplicative()?;
750 while self.peek_is_additive_op() {
751 let op_tok = self.advance();
752 let is_sub = op_tok.value == "-";
753
754 if self.peek().map_or(false, |t| t.token_type == "keyword" && t.value == "INTERVAL") {
756 self.advance(); let days_expr = self.parse_multiplicative()?;
758 if !self.match_keyword("DAY") && !self.match_keyword("DAYS") {
760 return Err(MdqlError::QueryParse("Expected DAY after INTERVAL value".into()));
761 }
762 let days = if is_sub {
763 Expr::UnaryMinus(Box::new(days_expr))
764 } else {
765 days_expr
766 };
767 left = Expr::DateAdd {
768 date: Box::new(left),
769 days: Box::new(days),
770 };
771 continue;
772 }
773
774 let op = match op_tok.value.as_str() {
775 "+" => ArithOp::Add,
776 "-" => ArithOp::Sub,
777 _ => unreachable!(),
778 };
779 let right = self.parse_multiplicative()?;
780 left = Expr::BinaryOp {
781 left: Box::new(left),
782 op,
783 right: Box::new(right),
784 };
785 }
786 Ok(left)
787 }
788
789 fn parse_multiplicative(&mut self) -> Result<Expr, MdqlError> {
790 let mut left = self.parse_unary()?;
791 while self.peek_is_multiplicative_op() {
792 let op_tok = self.advance();
793 let op = match op_tok.value.as_str() {
794 "*" => ArithOp::Mul,
795 "/" => ArithOp::Div,
796 "%" => ArithOp::Mod,
797 _ => unreachable!(),
798 };
799 let right = self.parse_unary()?;
800 left = Expr::BinaryOp {
801 left: Box::new(left),
802 op,
803 right: Box::new(right),
804 };
805 }
806 Ok(left)
807 }
808
809 fn parse_unary(&mut self) -> Result<Expr, MdqlError> {
810 if self.peek().map_or(false, |t| t.token_type == "op" && t.value == "-") {
811 self.advance();
812 let inner = self.parse_atom()?;
813 match inner {
815 Expr::Literal(SqlValue::Int(n)) => Ok(Expr::Literal(SqlValue::Int(-n))),
816 Expr::Literal(SqlValue::Float(f)) => Ok(Expr::Literal(SqlValue::Float(-f))),
817 _ => Ok(Expr::UnaryMinus(Box::new(inner))),
818 }
819 } else {
820 self.parse_atom()
821 }
822 }
823
824 fn parse_atom(&mut self) -> Result<Expr, MdqlError> {
825 let t = self.peek().ok_or_else(|| {
826 MdqlError::QueryParse("Expected expression, got end of query".into())
827 })?;
828
829 match t.token_type.as_str() {
830 "number" => {
831 let v = self.advance().value;
832 if v.contains('.') {
833 let f: f64 = v.parse().map_err(|_| {
834 MdqlError::QueryParse(format!("Invalid float: {}", v))
835 })?;
836 Ok(Expr::Literal(SqlValue::Float(f)))
837 } else {
838 let n: i64 = v.parse().map_err(|_| {
839 MdqlError::QueryParse(format!("Invalid int: {}", v))
840 })?;
841 Ok(Expr::Literal(SqlValue::Int(n)))
842 }
843 }
844 "string" => {
845 let v = self.advance().value;
846 Ok(Expr::Literal(SqlValue::String(v)))
847 }
848 "keyword" if t.value == "NULL" => {
849 self.advance();
850 Ok(Expr::Literal(SqlValue::Null))
851 }
852 "keyword" if t.value == "CASE" => {
853 self.parse_case_expr()
854 }
855 "keyword" if t.value == "CURRENT_DATE" => {
856 self.advance();
857 Ok(Expr::CurrentDate)
858 }
859 "keyword" if t.value == "CURRENT_TIMESTAMP" => {
860 self.advance();
861 Ok(Expr::CurrentTimestamp)
862 }
863 "keyword" if t.value == "DATEDIFF" => {
864 self.advance();
865 self.expect("op", Some("("))?;
866 let left = self.parse_additive()?;
867 self.expect("op", Some(","))?;
868 let right = self.parse_additive()?;
869 self.expect("op", Some(")"))?;
870 Ok(Expr::DateDiff { left: Box::new(left), right: Box::new(right) })
871 }
872 "op" if t.value == "(" => {
873 self.advance();
874 let expr = self.parse_additive()?;
875 self.expect("op", Some(")"))?;
876 Ok(expr)
877 }
878 "ident" => {
879 let name = self.advance().value;
880 Ok(Expr::Column(name))
881 }
882 "keyword" if !Self::is_reserved_keyword(&t.value) => {
883 let name = self.advance().value;
884 Ok(Expr::Column(name))
885 }
886 _ => Err(MdqlError::QueryParse(format!(
887 "Expected expression, got '{}'",
888 t.raw
889 ))),
890 }
891 }
892
893 fn parse_case_expr(&mut self) -> Result<Expr, MdqlError> {
894 self.expect("keyword", Some("CASE"))?;
895 let mut whens = Vec::new();
896 while self.match_keyword("WHEN") {
897 let condition = self.parse_or_expr()?;
898 self.expect("keyword", Some("THEN"))?;
899 let result = self.parse_additive()?;
900 whens.push((condition, Box::new(result)));
901 }
902 if whens.is_empty() {
903 return Err(MdqlError::QueryParse("CASE requires at least one WHEN clause".into()));
904 }
905 let else_expr = if self.match_keyword("ELSE") {
906 Some(Box::new(self.parse_additive()?))
907 } else {
908 None
909 };
910 self.expect("keyword", Some("END"))?;
911 Ok(Expr::Case { whens, else_expr })
912 }
913
914 fn parse_ident(&mut self) -> Result<String, MdqlError> {
915 let t = self.peek().ok_or_else(|| {
916 MdqlError::QueryParse("Expected identifier, got end of query".into())
917 })?;
918 match t.token_type.as_str() {
919 "ident" | "keyword" => {
920 let v = self.advance().value;
921 Ok(v)
922 }
923 _ => Err(MdqlError::QueryParse(format!(
924 "Expected identifier, got '{}'",
925 t.raw
926 ))),
927 }
928 }
929
930 fn parse_or_expr(&mut self) -> Result<WhereClause, MdqlError> {
931 let mut left = self.parse_and_expr()?;
932 while self.match_keyword("OR") {
933 let right = self.parse_and_expr()?;
934 left = WhereClause::BoolOp(BoolOp {
935 op: "OR".into(),
936 left: Box::new(left),
937 right: Box::new(right),
938 });
939 }
940 Ok(left)
941 }
942
943 fn parse_and_expr(&mut self) -> Result<WhereClause, MdqlError> {
944 let mut left = self.parse_comparison()?;
945 while self.match_keyword("AND") {
946 let right = self.parse_comparison()?;
947 left = WhereClause::BoolOp(BoolOp {
948 op: "AND".into(),
949 left: Box::new(left),
950 right: Box::new(right),
951 });
952 }
953 Ok(left)
954 }
955
956 fn parse_comparison(&mut self) -> Result<WhereClause, MdqlError> {
957 if self.peek().map_or(false, |t| t.token_type == "op" && t.value == "(") {
959 let saved_pos = self.pos;
961 self.advance();
962 let result = self.parse_or_expr();
964 if result.is_ok() && self.peek().map_or(false, |t| t.token_type == "op" && t.value == ")") {
965 self.advance();
966 return result;
967 }
968 self.pos = saved_pos;
970 }
971
972 let left_expr = self.parse_additive()?;
974
975 let col = left_expr.as_column().unwrap_or("").to_string();
977
978 if self.match_keyword("IS") {
980 if self.match_keyword("NOT") {
981 self.expect("keyword", Some("NULL"))?;
982 return Ok(WhereClause::Comparison(Comparison {
983 column: col,
984 op: "IS NOT NULL".into(),
985 value: None,
986 left_expr: Some(left_expr),
987 right_expr: None,
988 }));
989 }
990 self.expect("keyword", Some("NULL"))?;
991 return Ok(WhereClause::Comparison(Comparison {
992 column: col,
993 op: "IS NULL".into(),
994 value: None,
995 left_expr: Some(left_expr),
996 right_expr: None,
997 }));
998 }
999
1000 if self.match_keyword("IN") {
1002 self.expect("op", Some("("))?;
1003 let mut values = vec![self.parse_value()?];
1004 while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
1005 self.advance();
1006 values.push(self.parse_value()?);
1007 }
1008 self.expect("op", Some(")"))?;
1009 return Ok(WhereClause::Comparison(Comparison {
1010 column: col,
1011 op: "IN".into(),
1012 value: Some(SqlValue::List(values)),
1013 left_expr: Some(left_expr),
1014 right_expr: None,
1015 }));
1016 }
1017
1018 if self.match_keyword("LIKE") {
1020 let val = self.parse_value()?;
1021 return Ok(WhereClause::Comparison(Comparison {
1022 column: col,
1023 op: "LIKE".into(),
1024 value: Some(val),
1025 left_expr: Some(left_expr),
1026 right_expr: None,
1027 }));
1028 }
1029
1030 if self.match_keyword("NOT") {
1032 if self.match_keyword("LIKE") {
1033 let val = self.parse_value()?;
1034 return Ok(WhereClause::Comparison(Comparison {
1035 column: col,
1036 op: "NOT LIKE".into(),
1037 value: Some(val),
1038 left_expr: Some(left_expr),
1039 right_expr: None,
1040 }));
1041 }
1042 return Err(MdqlError::QueryParse("Expected LIKE after NOT".into()));
1043 }
1044
1045 if let Some(t) = self.peek() {
1047 if t.token_type == "op" && ["=", "!=", "<", ">", "<=", ">="].contains(&t.value.as_str())
1048 {
1049 let op = self.advance().value;
1050 let right_expr = self.parse_additive()?;
1052 let value = match &right_expr {
1054 Expr::Literal(v) => Some(v.clone()),
1055 _ => None,
1056 };
1057 return Ok(WhereClause::Comparison(Comparison {
1058 column: col,
1059 op,
1060 value,
1061 left_expr: Some(left_expr),
1062 right_expr: Some(right_expr),
1063 }));
1064 }
1065 }
1066
1067 let got = self.peek().map_or("end".to_string(), |t| t.raw.clone());
1068 Err(MdqlError::QueryParse(format!(
1069 "Expected operator after '{}', got '{}'",
1070 left_expr.display_name(), got
1071 )))
1072 }
1073
1074 fn parse_value(&mut self) -> Result<SqlValue, MdqlError> {
1075 let t = self.peek().ok_or_else(|| {
1076 MdqlError::QueryParse("Expected value, got end of query".into())
1077 })?;
1078 match t.token_type.as_str() {
1079 "string" => {
1080 let v = self.advance().value;
1081 Ok(SqlValue::String(v))
1082 }
1083 "number" => {
1084 let v = self.advance().value;
1085 if v.contains('.') {
1086 Ok(SqlValue::Float(v.parse().map_err(|_| {
1087 MdqlError::QueryParse(format!("Invalid float: {}", v))
1088 })?))
1089 } else {
1090 Ok(SqlValue::Int(v.parse().map_err(|_| {
1091 MdqlError::QueryParse(format!("Invalid int: {}", v))
1092 })?))
1093 }
1094 }
1095 "keyword" if t.value == "NULL" => {
1096 self.advance();
1097 Ok(SqlValue::Null)
1098 }
1099 _ => Err(MdqlError::QueryParse(format!(
1100 "Expected value, got '{}'",
1101 t.raw
1102 ))),
1103 }
1104 }
1105
1106 fn parse_order_by(&mut self) -> Result<Vec<OrderSpec>, MdqlError> {
1107 let mut specs = vec![self.parse_order_spec()?];
1108 while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
1109 self.advance();
1110 specs.push(self.parse_order_spec()?);
1111 }
1112 Ok(specs)
1113 }
1114
1115 fn parse_order_spec(&mut self) -> Result<OrderSpec, MdqlError> {
1116 let expr = self.parse_additive()?;
1117 let col = expr.as_column().unwrap_or("").to_string();
1118 let descending = if self.match_keyword("DESC") {
1119 true
1120 } else {
1121 self.match_keyword("ASC");
1122 false
1123 };
1124 Ok(OrderSpec {
1125 column: col,
1126 expr: Some(expr),
1127 descending,
1128 })
1129 }
1130
1131 fn is_clause_keyword(&self, t: &Token) -> bool {
1132 t.token_type == "keyword"
1133 && ["WHERE", "ORDER", "LIMIT", "JOIN", "ON", "GROUP"].contains(&t.value.as_str())
1134 }
1135
1136 fn is_reserved_keyword(kw: &str) -> bool {
1138 matches!(kw,
1139 "AS" | "FROM" | "WHERE" | "AND" | "OR" | "ORDER" | "BY"
1140 | "ASC" | "DESC" | "LIMIT" | "JOIN" | "ON" | "GROUP"
1141 | "SELECT" | "INSERT" | "INTO" | "VALUES" | "UPDATE" | "SET"
1142 | "DELETE" | "ALTER" | "TABLE" | "IS" | "NOT" | "IN" | "LIKE"
1143 | "RENAME" | "FIELD" | "TO" | "DROP" | "MERGE" | "FIELDS"
1144 | "CASE" | "WHEN" | "THEN" | "ELSE" | "END"
1145 | "HAVING" | "INTERVAL" | "DAY" | "DAYS"
1146 | "CURRENT_DATE" | "CURRENT_TIMESTAMP" | "DATEDIFF"
1147 )
1148 }
1149
1150 fn expect_end(&self) -> Result<(), MdqlError> {
1151 if let Some(t) = self.peek() {
1152 return Err(MdqlError::QueryParse(format!(
1153 "Unexpected token '{}' at position {}",
1154 t.raw, self.pos
1155 )));
1156 }
1157 Ok(())
1158 }
1159}
1160
1161pub fn parse_query(sql: &str) -> crate::errors::Result<Statement> {
1162 let tokens = tokenize(sql);
1163 if tokens.is_empty() {
1164 return Err(MdqlError::QueryParse("Empty query".into()));
1165 }
1166 let mut parser = Parser::new(tokens);
1167 parser.parse_statement()
1168}
1169
#[cfg(test)]
mod tests {
    // Unit tests for `parse_query`: statement kinds, WHERE/ORDER BY/LIMIT clauses,
    // joins, aggregates, arithmetic expressions and CASE WHEN.
    use super::*;

    /// Parses `sql`, requires a SELECT with an explicit column list, and returns
    /// that list. Panics with "Expected Select" / "Expected Named columns" otherwise.
    fn select_columns(sql: &str) -> Vec<SelectExpr> {
        let Statement::Select(q) = parse_query(sql).unwrap() else {
            panic!("Expected Select");
        };
        let ColumnList::Named(exprs) = q.columns else {
            panic!("Expected Named columns");
        };
        exprs
    }

    #[test]
    fn test_simple_select() {
        let stmt = parse_query("SELECT title, status FROM strategies").unwrap();
        let Statement::Select(q) = stmt else {
            panic!("Expected Select");
        };
        assert_eq!(
            q.columns,
            ColumnList::Named(vec![
                SelectExpr::Column("title".into()),
                SelectExpr::Column("status".into()),
            ])
        );
        assert_eq!(q.table, "strategies");
    }

    #[test]
    fn test_select_star() {
        let stmt = parse_query("SELECT * FROM test").unwrap();
        let Statement::Select(q) = stmt else {
            panic!("Expected Select");
        };
        assert_eq!(q.columns, ColumnList::All);
    }

    #[test]
    fn test_where_clause() {
        let stmt = parse_query("SELECT title FROM test WHERE count > 5").unwrap();
        let Statement::Select(q) = stmt else {
            panic!("Expected Select");
        };
        // Pin the parsed operator, not just the presence of a WHERE clause.
        let Some(WhereClause::Comparison(c)) = q.where_clause else {
            panic!("Expected comparison");
        };
        assert_eq!(c.op, ">");
    }

    #[test]
    fn test_order_by() {
        let stmt =
            parse_query("SELECT title FROM test ORDER BY composite DESC, title ASC").unwrap();
        let Statement::Select(q) = stmt else {
            panic!("Expected Select");
        };
        let ob = q.order_by.unwrap();
        assert_eq!(ob.len(), 2);
        assert!(ob[0].descending);
        assert!(!ob[1].descending);
    }

    #[test]
    fn test_limit() {
        let stmt = parse_query("SELECT * FROM test LIMIT 10").unwrap();
        let Statement::Select(q) = stmt else {
            panic!("Expected Select");
        };
        assert_eq!(q.limit, Some(10));
    }

    #[test]
    fn test_insert() {
        let stmt = parse_query("INSERT INTO test (title, count) VALUES ('Hello', 42)").unwrap();
        let Statement::Insert(q) = stmt else {
            panic!("Expected Insert");
        };
        assert_eq!(q.table, "test");
        assert_eq!(q.columns, vec!["title", "count"]);
        assert_eq!(q.values[0], SqlValue::String("Hello".into()));
        assert_eq!(q.values[1], SqlValue::Int(42));
    }

    #[test]
    fn test_update() {
        let stmt = parse_query("UPDATE test SET status = 'KILLED' WHERE path = 'a.md'").unwrap();
        let Statement::Update(q) = stmt else {
            panic!("Expected Update");
        };
        assert_eq!(q.table, "test");
        assert_eq!(q.assignments.len(), 1);
        assert!(q.where_clause.is_some());
    }

    #[test]
    fn test_delete() {
        let stmt = parse_query("DELETE FROM test WHERE status = 'draft'").unwrap();
        let Statement::Delete(q) = stmt else {
            panic!("Expected Delete");
        };
        assert_eq!(q.table, "test");
        assert!(q.where_clause.is_some());
    }

    #[test]
    fn test_alter_rename() {
        let stmt =
            parse_query("ALTER TABLE test RENAME FIELD 'Summary' TO 'Overview'").unwrap();
        let Statement::AlterRename(q) = stmt else {
            panic!("Expected AlterRename");
        };
        assert_eq!(q.old_name, "Summary");
        assert_eq!(q.new_name, "Overview");
    }

    #[test]
    fn test_alter_drop() {
        let stmt = parse_query("ALTER TABLE test DROP FIELD 'Details'").unwrap();
        let Statement::AlterDrop(q) = stmt else {
            panic!("Expected AlterDrop");
        };
        assert_eq!(q.field_name, "Details");
    }

    #[test]
    fn test_alter_merge() {
        let stmt = parse_query(
            "ALTER TABLE test MERGE FIELDS 'Entry Rules', 'Exit Rules' INTO 'Trading Rules'",
        )
        .unwrap();
        let Statement::AlterMerge(q) = stmt else {
            panic!("Expected AlterMerge");
        };
        assert_eq!(q.sources, vec!["Entry Rules", "Exit Rules"]);
        assert_eq!(q.into, "Trading Rules");
    }

    #[test]
    fn test_backtick_ident() {
        assert_eq!(
            select_columns("SELECT `Structural Mechanism` FROM test"),
            vec![SelectExpr::Column("Structural Mechanism".into())]
        );
    }

    #[test]
    fn test_like_operator() {
        let stmt = parse_query("SELECT title FROM test WHERE categories LIKE '%defi%'").unwrap();
        let Statement::Select(q) = stmt else {
            panic!("Expected Select");
        };
        let Some(WhereClause::Comparison(c)) = q.where_clause else {
            panic!("Expected LIKE comparison");
        };
        assert_eq!(c.op, "LIKE");
        assert_eq!(c.value, Some(SqlValue::String("%defi%".into())));
    }

    #[test]
    fn test_in_operator() {
        let stmt =
            parse_query("SELECT * FROM test WHERE status IN ('ACTIVE', 'LIVE')").unwrap();
        let Statement::Select(q) = stmt else {
            panic!("Expected Select");
        };
        let Some(WhereClause::Comparison(c)) = q.where_clause else {
            panic!("Expected IN comparison");
        };
        assert_eq!(c.op, "IN");
    }

    #[test]
    fn test_is_null() {
        let stmt = parse_query("SELECT * FROM test WHERE title IS NULL").unwrap();
        let Statement::Select(q) = stmt else {
            panic!("Expected Select");
        };
        let Some(WhereClause::Comparison(c)) = q.where_clause else {
            panic!("Expected IS NULL comparison");
        };
        assert_eq!(c.op, "IS NULL");
    }

    #[test]
    fn test_and_or() {
        let stmt = parse_query(
            "SELECT * FROM test WHERE status = 'ACTIVE' AND count > 5 OR title LIKE '%test%'",
        )
        .unwrap();
        let Statement::Select(q) = stmt else {
            panic!("Expected Select");
        };
        assert!(q.where_clause.is_some());
    }

    #[test]
    fn test_join() {
        let stmt = parse_query(
            "SELECT s.title, b.sharpe FROM strategies s JOIN backtests b ON b.strategy = s.path",
        )
        .unwrap();
        let Statement::Select(q) = stmt else {
            panic!("Expected Select");
        };
        assert_eq!(q.table, "strategies");
        assert_eq!(q.table_alias, Some("s".into()));
        assert_eq!(q.joins.len(), 1);
        let join = &q.joins[0];
        assert_eq!(join.table, "backtests");
        assert_eq!(join.alias, Some("b".into()));
    }

    #[test]
    fn test_multi_join() {
        let stmt = parse_query(
            "SELECT s.title, b.sharpe, c.verdict FROM strategies s JOIN backtests b ON b.strategy = s.path JOIN critiques c ON c.strategy = s.path",
        )
        .unwrap();
        let Statement::Select(q) = stmt else {
            panic!("Expected Select");
        };
        assert_eq!(q.table, "strategies");
        assert_eq!(q.table_alias, Some("s".into()));

        // (table, alias, left_col, right_col) for each JOIN, in source order.
        let expected = [
            ("backtests", "b", "b.strategy", "s.path"),
            ("critiques", "c", "c.strategy", "s.path"),
        ];
        assert_eq!(q.joins.len(), expected.len());
        for (join, &(table, alias, left, right)) in q.joins.iter().zip(&expected) {
            assert_eq!(join.table, table);
            assert_eq!(join.alias, Some(alias.into()));
            assert_eq!(join.left_col, left);
            assert_eq!(join.right_col, right);
        }
    }

    #[test]
    fn test_empty_query() {
        let result = parse_query("");
        assert!(result.is_err());
        // parse_query rejects an empty token stream with a QueryParse error.
        assert!(matches!(result, Err(MdqlError::QueryParse(_))));
    }

    #[test]
    fn test_count_star() {
        let stmt =
            parse_query("SELECT status, COUNT(*) AS cnt FROM strategies GROUP BY status").unwrap();
        let Statement::Select(q) = stmt else {
            panic!("Expected Select");
        };
        let ColumnList::Named(exprs) = &q.columns else {
            panic!("Expected Named columns");
        };
        assert_eq!(exprs.len(), 2);
        assert_eq!(exprs[0], SelectExpr::Column("status".into()));
        assert!(matches!(&exprs[1], SelectExpr::Aggregate {
            func: AggFunc::Count,
            arg,
            alias: Some(a),
            ..
        } if arg == "*" && a == "cnt"));
        assert_eq!(q.group_by, Some(vec!["status".into()]));
    }

    #[test]
    fn test_count_column_as_ident() {
        // `count` is spelled like the COUNT aggregate (see test_count_star); used as a
        // bare name it must still be an ordinary column, in INSERT and in SELECT lists.
        let stmt = parse_query("INSERT INTO test (title, count) VALUES ('Hello', 42)").unwrap();
        let Statement::Insert(q) = stmt else {
            panic!("Expected Insert");
        };
        assert_eq!(q.columns, vec!["title", "count"]);

        assert_eq!(
            select_columns("SELECT title, count FROM test"),
            vec![SelectExpr::Column("title".into()), SelectExpr::Column("count".into())]
        );
    }

    #[test]
    fn test_multiple_aggregates() {
        let stmt =
            parse_query("SELECT MIN(composite), MAX(composite), AVG(composite) FROM strategies")
                .unwrap();
        let Statement::Select(q) = stmt else {
            panic!("Expected Select");
        };
        let ColumnList::Named(exprs) = &q.columns else {
            panic!("Expected Named columns");
        };
        assert_eq!(exprs.len(), 3);
        assert!(matches!(&exprs[0], SelectExpr::Aggregate { func: AggFunc::Min, .. }));
        assert!(matches!(&exprs[1], SelectExpr::Aggregate { func: AggFunc::Max, .. }));
        assert!(matches!(&exprs[2], SelectExpr::Aggregate { func: AggFunc::Avg, .. }));
        assert_eq!(q.group_by, None);
    }

    // --- Arithmetic expressions ---

    #[test]
    fn test_select_arithmetic_expr() {
        let exprs = select_columns("SELECT a + b FROM test");
        assert_eq!(exprs.len(), 1);
        assert!(matches!(&exprs[0], SelectExpr::Expr {
            expr: Expr::BinaryOp { op: ArithOp::Add, .. },
            alias: None,
        }));
    }

    #[test]
    fn test_select_arithmetic_with_alias() {
        let exprs = select_columns("SELECT a + b AS total FROM test");
        assert_eq!(exprs.len(), 1);
        assert!(matches!(&exprs[0], SelectExpr::Expr {
            alias: Some(a),
            ..
        } if a == "total"));
        assert_eq!(exprs[0].output_name(), "total");
    }

    #[test]
    fn test_select_precedence() {
        // `*` binds tighter than `+`: a + (b * c).
        let exprs = select_columns("SELECT a + b * c FROM test");
        let SelectExpr::Expr { expr, .. } = &exprs[0] else {
            panic!("Expected Expr variant");
        };
        let Expr::BinaryOp { left, op, right } = expr else {
            panic!("Expected BinaryOp");
        };
        assert_eq!(*op, ArithOp::Add);
        assert!(matches!(left.as_ref(), Expr::Column(n) if n == "a"));
        assert!(matches!(right.as_ref(), Expr::BinaryOp { op: ArithOp::Mul, .. }));
    }

    #[test]
    fn test_select_parenthesized_expr() {
        // Parentheses override precedence: (a + b) * c.
        let exprs = select_columns("SELECT (a + b) * c FROM test");
        let SelectExpr::Expr { expr, .. } = &exprs[0] else {
            panic!("Expected Expr variant");
        };
        let Expr::BinaryOp { left, op, .. } = expr else {
            panic!("Expected BinaryOp");
        };
        assert_eq!(*op, ArithOp::Mul);
        assert!(matches!(left.as_ref(), Expr::BinaryOp { op: ArithOp::Add, .. }));
    }

    #[test]
    fn test_select_unary_minus() {
        let exprs = select_columns("SELECT -count FROM test");
        assert!(matches!(&exprs[0], SelectExpr::Expr {
            expr: Expr::UnaryMinus(_),
            ..
        }));
    }

    #[test]
    fn test_select_negative_literal() {
        // A minus directly before a numeric literal folds into the literal itself.
        let exprs = select_columns("SELECT -42 FROM test");
        assert!(matches!(&exprs[0], SelectExpr::Expr {
            expr: Expr::Literal(SqlValue::Int(-42)),
            ..
        }));
    }

    #[test]
    fn test_where_arithmetic_expr() {
        let stmt = parse_query("SELECT * FROM test WHERE a + b > 10").unwrap();
        let Statement::Select(q) = stmt else {
            panic!("Expected Select");
        };
        let Some(WhereClause::Comparison(c)) = q.where_clause else {
            panic!("Expected comparison");
        };
        assert_eq!(c.op, ">");
        assert!(matches!(&c.left_expr, Some(Expr::BinaryOp { op: ArithOp::Add, .. })));
        assert!(matches!(&c.right_expr, Some(Expr::Literal(SqlValue::Int(10)))));
    }

    #[test]
    fn test_where_both_sides_expr() {
        let stmt = parse_query("SELECT * FROM test WHERE a * 2 > b + 1").unwrap();
        let Statement::Select(q) = stmt else {
            panic!("Expected Select");
        };
        let Some(WhereClause::Comparison(c)) = q.where_clause else {
            panic!("Expected comparison");
        };
        assert_eq!(c.op, ">");
        assert!(matches!(&c.left_expr, Some(Expr::BinaryOp { op: ArithOp::Mul, .. })));
        assert!(matches!(&c.right_expr, Some(Expr::BinaryOp { op: ArithOp::Add, .. })));
    }

    #[test]
    fn test_order_by_expr() {
        let stmt = parse_query("SELECT * FROM test ORDER BY a + b DESC").unwrap();
        let Statement::Select(q) = stmt else {
            panic!("Expected Select");
        };
        let ob = q.order_by.unwrap();
        assert_eq!(ob.len(), 1);
        assert!(ob[0].descending);
        assert!(matches!(&ob[0].expr, Some(Expr::BinaryOp { op: ArithOp::Add, .. })));
    }

    #[test]
    fn test_all_arithmetic_ops() {
        let exprs = select_columns("SELECT a + b, a - b, a * b, a / b, a % b FROM test");
        let expected = [ArithOp::Add, ArithOp::Sub, ArithOp::Mul, ArithOp::Div, ArithOp::Mod];
        assert_eq!(exprs.len(), expected.len());
        for (select_expr, want) in exprs.iter().zip(&expected) {
            assert!(matches!(
                select_expr,
                SelectExpr::Expr { expr: Expr::BinaryOp { op, .. }, .. } if op == want
            ));
        }
    }

    #[test]
    fn test_column_with_literal_arithmetic() {
        // Left-associative with precedence: (count * 2) + 1.
        let exprs = select_columns("SELECT count * 2 + 1 FROM test");
        let SelectExpr::Expr { expr, .. } = &exprs[0] else {
            panic!("Expected Expr");
        };
        let Expr::BinaryOp { left, op, right } = expr else {
            panic!("Expected BinaryOp");
        };
        assert_eq!(*op, ArithOp::Add);
        assert!(matches!(right.as_ref(), Expr::Literal(SqlValue::Int(1))));
        assert!(matches!(left.as_ref(), Expr::BinaryOp { op: ArithOp::Mul, .. }));
    }

    #[test]
    fn test_mixed_columns_and_exprs() {
        let exprs = select_columns("SELECT title, a + b AS sum, count FROM test");
        assert_eq!(exprs.len(), 3);
        assert_eq!(exprs[0], SelectExpr::Column("title".into()));
        assert!(matches!(&exprs[1], SelectExpr::Expr { alias: Some(a), .. } if a == "sum"));
        assert_eq!(exprs[2], SelectExpr::Column("count".into()));
    }

    // --- CASE WHEN ---

    #[test]
    fn test_case_when_basic() {
        let exprs = select_columns("SELECT CASE WHEN status = 'ACTIVE' THEN 1 ELSE 0 END FROM test");
        assert_eq!(exprs.len(), 1);
        assert!(matches!(&exprs[0], SelectExpr::Expr {
            expr: Expr::Case { .. },
            ..
        }));
    }

    #[test]
    fn test_case_when_multiple_branches() {
        let exprs = select_columns(
            "SELECT CASE WHEN x > 10 THEN 'high' WHEN x > 5 THEN 'mid' ELSE 'low' END FROM test",
        );
        let SelectExpr::Expr { expr: Expr::Case { whens, else_expr }, .. } = &exprs[0] else {
            panic!("Expected Case expression");
        };
        assert_eq!(whens.len(), 2);
        assert!(else_expr.is_some());
    }

    #[test]
    fn test_case_when_no_else() {
        let exprs = select_columns("SELECT CASE WHEN x = 1 THEN 'one' END FROM test");
        let SelectExpr::Expr { expr: Expr::Case { whens, else_expr }, .. } = &exprs[0] else {
            panic!("Expected Case expression");
        };
        assert_eq!(whens.len(), 1);
        assert!(else_expr.is_none());
    }

    #[test]
    fn test_case_when_in_aggregate() {
        let exprs = select_columns(
            "SELECT SUM(CASE WHEN side = 'BUY' THEN size ELSE -size END) AS net FROM orders GROUP BY token",
        );
        assert_eq!(exprs.len(), 1);
        assert!(matches!(&exprs[0], SelectExpr::Aggregate {
            func: AggFunc::Sum,
            arg_expr: Some(Expr::Case { .. }),
            alias: Some(a),
            ..
        } if a == "net"));
    }

    #[test]
    fn test_case_when_with_alias() {
        let exprs =
            select_columns("SELECT CASE WHEN x > 0 THEN 'pos' ELSE 'neg' END AS sign FROM test");
        assert!(matches!(&exprs[0], SelectExpr::Expr {
            expr: Expr::Case { .. },
            alias: Some(a),
        } if a == "sign"));
    }
}