1use regex::Regex;
4use std::sync::LazyLock;
5
6use crate::errors::MdqlError;
7pub use crate::query_ast::*;
8
/// Reserved words of the query language.
///
/// The tokenizer upper-cases every bare word and, if the result is listed
/// here, emits it as a `keyword` token; any other word becomes an identifier.
static KEYWORDS: &[&str] = &[
    // SELECT clauses, predicates and joins
    "SELECT", "FROM", "WHERE", "AND", "OR", "ORDER", "BY", "ASC", "DESC", "LIMIT",
    "LIKE", "IN", "IS", "NOT", "NULL", "JOIN", "ON", "AS", "GROUP", "HAVING",
    // data modification
    "INSERT", "INTO", "VALUES", "UPDATE", "SET", "DELETE",
    // schema changes
    "ALTER", "TABLE", "RENAME", "FIELD", "TO", "DROP", "MERGE", "FIELDS",
    // conditional expressions
    "CASE", "WHEN", "THEN", "ELSE", "END",
    // date arithmetic
    "INTERVAL", "DAY", "DAYS", "CURRENT_DATE", "CURRENT_TIMESTAMP", "DATEDIFF",
    // views
    "CREATE", "VIEW", "CASCADE", "RESTRICT",
];
21
/// Aggregate function names, in upper case.
///
/// A word is parsed as an aggregate call only when its upper-cased text is one
/// of these and the very next token is `(`.
static AGG_FUNCS: &[&str] = &[
    "COUNT",
    "SUM",
    "AVG",
    "MIN",
    "MAX",
];
23
24static TOKEN_RE: LazyLock<Regex> = LazyLock::new(|| {
25 Regex::new(
26 r#"(?x)
27 \s*(?:
28 (?P<backtick>`[^`]+`)
29 | (?P<string>'(?:[^'\\]|\\.)*')
30 | (?P<number>\d+(?:\.\d+)?)
31 | (?P<op><=|>=|!=|[=<>,*()+\-/%])
32 | (?P<word>[A-Za-z_][A-Za-z0-9_./-]*)
33 )"#,
34 )
35 .unwrap()
36});
37
38#[derive(Debug, Clone)]
39struct Token {
40 token_type: String,
41 value: String,
42 raw: String,
43}
44
45fn tokenize(sql: &str) -> Vec<Token> {
46 let mut tokens = Vec::new();
47 for caps in TOKEN_RE.captures_iter(sql) {
48 if let Some(m) = caps.name("backtick") {
49 let raw = m.as_str();
50 tokens.push(Token {
51 token_type: "ident".into(),
52 value: raw[1..raw.len() - 1].into(),
53 raw: raw.into(),
54 });
55 } else if let Some(m) = caps.name("string") {
56 let raw = m.as_str();
57 tokens.push(Token {
58 token_type: "string".into(),
59 value: raw[1..raw.len() - 1].into(),
60 raw: raw.into(),
61 });
62 } else if let Some(m) = caps.name("number") {
63 let raw = m.as_str();
64 tokens.push(Token {
65 token_type: "number".into(),
66 value: raw.into(),
67 raw: raw.into(),
68 });
69 } else if let Some(m) = caps.name("op") {
70 let raw = m.as_str();
71 tokens.push(Token {
72 token_type: "op".into(),
73 value: raw.into(),
74 raw: raw.into(),
75 });
76 } else if let Some(m) = caps.name("word") {
77 let raw = m.as_str();
78 if KEYWORDS.contains(&raw.to_uppercase().as_str()) {
79 tokens.push(Token {
80 token_type: "keyword".into(),
81 value: raw.to_uppercase(),
82 raw: raw.into(),
83 });
84 } else {
85 tokens.push(Token {
86 token_type: "ident".into(),
87 value: raw.into(),
88 raw: raw.into(),
89 });
90 }
91 }
92 }
93 tokens
94}
95
96struct Parser {
99 tokens: Vec<Token>,
100 pos: usize,
101}
102
103impl Parser {
104 fn new(tokens: Vec<Token>) -> Self {
105 Parser { tokens, pos: 0 }
106 }
107
108 fn peek(&self) -> Option<&Token> {
109 self.tokens.get(self.pos)
110 }
111
112 fn advance(&mut self) -> Token {
113 let t = self.tokens[self.pos].clone();
114 self.pos += 1;
115 t
116 }
117
118 fn expect(&mut self, type_: &str, value: Option<&str>) -> Result<Token, MdqlError> {
119 let t = self.peek().ok_or_else(|| {
120 MdqlError::QueryParse(format!(
121 "Unexpected end of query, expected {}",
122 value.unwrap_or(type_)
123 ))
124 })?;
125 let matches_type = t.token_type == type_;
126 let matches_value = value.map_or(true, |v| t.value == v);
127 if !matches_type || !matches_value {
128 return Err(MdqlError::QueryParse(format!(
129 "Expected {}, got '{}' at position {}",
130 value.unwrap_or(type_),
131 t.raw,
132 self.pos
133 )));
134 }
135 Ok(self.advance())
136 }
137
138 fn match_keyword(&mut self, kw: &str) -> bool {
139 if let Some(t) = self.peek() {
140 if t.token_type == "keyword" && t.value == kw {
141 self.advance();
142 return true;
143 }
144 }
145 false
146 }
147
148 fn parse_statement(&mut self) -> Result<Statement, MdqlError> {
149 let t = self.peek().ok_or_else(|| MdqlError::QueryParse("Empty query".into()))?;
150 match (t.token_type.as_str(), t.value.as_str()) {
151 ("keyword", "SELECT") => {
152 let q = self.parse_select()?;
153 self.expect_end()?;
154 Ok(Statement::Select(q))
155 }
156 ("keyword", "INSERT") => Ok(Statement::Insert(self.parse_insert()?)),
157 ("keyword", "UPDATE") => Ok(Statement::Update(self.parse_update()?)),
158 ("keyword", "DELETE") => Ok(Statement::Delete(self.parse_delete()?)),
159 ("keyword", "ALTER") => self.parse_alter(),
160 ("keyword", "CREATE") => self.parse_create_view(),
161 ("keyword", "DROP") => self.parse_drop_view(),
162 _ => Err(MdqlError::QueryParse(format!(
163 "Expected SELECT, INSERT, UPDATE, DELETE, ALTER, CREATE, or DROP, got '{}'",
164 t.raw
165 ))),
166 }
167 }
168
169 fn parse_select(&mut self) -> Result<SelectQuery, MdqlError> {
170 self.expect("keyword", Some("SELECT"))?;
171 let columns = self.parse_columns()?;
172 self.expect("keyword", Some("FROM"))?;
173
174 let mut subquery = None;
176 let (table, mut table_alias) = if self.peek().map_or(false, |t| t.token_type == "op" && t.value == "(") {
177 self.advance();
178 let inner = self.parse_select()?;
179 self.expect("op", Some(")"))?;
180 subquery = Some(Box::new(inner));
181 let alias = if let Some(t) = self.peek() {
182 if t.token_type == "ident" && !self.is_clause_keyword(t) {
183 Some(self.advance().value)
184 } else {
185 None
186 }
187 } else {
188 None
189 };
190 ("_subquery".to_string(), alias)
191 } else {
192 let t = self.parse_ident()?;
193 (t, None)
194 };
195
196 if subquery.is_none() {
198 if let Some(t) = self.peek() {
199 if t.token_type == "ident" && !self.is_clause_keyword(t) {
200 table_alias = Some(self.advance().value);
201 }
202 }
203 }
204
205 let mut joins = Vec::new();
207 while self.match_keyword("JOIN") {
208 let join_table = self.parse_ident()?;
209 let mut join_alias = None;
210 if let Some(t) = self.peek() {
211 if t.token_type == "ident" && !self.is_clause_keyword(t) {
212 join_alias = Some(self.advance().value);
213 }
214 }
215 self.expect("keyword", Some("ON"))?;
216 let left_col = self.parse_ident()?;
217 self.expect("op", Some("="))?;
218 let right_col = self.parse_ident()?;
219 joins.push(JoinClause {
220 table: join_table,
221 alias: join_alias,
222 left_col,
223 right_col,
224 });
225 }
226
227 let mut where_clause = None;
228 if self.match_keyword("WHERE") {
229 where_clause = Some(self.parse_or_expr()?);
230 }
231
232 let mut group_by = None;
233 if self.match_keyword("GROUP") {
234 self.expect("keyword", Some("BY"))?;
235 let mut cols = vec![self.parse_ident()?];
236 while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
237 self.advance();
238 cols.push(self.parse_ident()?);
239 }
240 group_by = Some(cols);
241 }
242
243 let mut having = None;
244 if self.match_keyword("HAVING") {
245 having = Some(self.parse_or_expr()?);
246 }
247
248 let mut order_by = None;
249 if self.match_keyword("ORDER") {
250 self.expect("keyword", Some("BY"))?;
251 order_by = Some(self.parse_order_by()?);
252 }
253
254 let mut limit = None;
255 if self.match_keyword("LIMIT") {
256 let t = self.expect("number", None)?;
257 limit = Some(t.value.parse::<i64>().map_err(|_| {
258 MdqlError::QueryParse(format!("Invalid LIMIT value: {}", t.value))
259 })?);
260 }
261
262 Ok(SelectQuery {
263 columns,
264 table,
265 table_alias,
266 subquery,
267 joins,
268 where_clause,
269 group_by,
270 having,
271 order_by,
272 limit,
273 })
274 }
275
276 fn parse_insert(&mut self) -> Result<InsertQuery, MdqlError> {
277 self.expect("keyword", Some("INSERT"))?;
278 self.expect("keyword", Some("INTO"))?;
279 let table = self.parse_ident()?;
280
281 self.expect("op", Some("("))?;
282 let mut columns = vec![self.parse_ident()?];
283 while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
284 self.advance();
285 columns.push(self.parse_ident()?);
286 }
287 self.expect("op", Some(")"))?;
288
289 self.expect("keyword", Some("VALUES"))?;
290
291 self.expect("op", Some("("))?;
292 let mut values = vec![self.parse_value()?];
293 while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
294 self.advance();
295 values.push(self.parse_value()?);
296 }
297 self.expect("op", Some(")"))?;
298
299 if columns.len() != values.len() {
300 return Err(MdqlError::QueryParse(format!(
301 "Column count ({}) does not match value count ({})",
302 columns.len(),
303 values.len()
304 )));
305 }
306
307 self.expect_end()?;
308 Ok(InsertQuery {
309 table,
310 columns,
311 values,
312 })
313 }
314
315 fn parse_update(&mut self) -> Result<UpdateQuery, MdqlError> {
316 self.expect("keyword", Some("UPDATE"))?;
317 let table = self.parse_ident()?;
318 self.expect("keyword", Some("SET"))?;
319
320 let mut assignments = Vec::new();
321 let col = self.parse_ident()?;
322 self.expect("op", Some("="))?;
323 let val = self.parse_value()?;
324 assignments.push((col, val));
325
326 while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
327 self.advance();
328 let col = self.parse_ident()?;
329 self.expect("op", Some("="))?;
330 let val = self.parse_value()?;
331 assignments.push((col, val));
332 }
333
334 let mut where_clause = None;
335 if self.match_keyword("WHERE") {
336 where_clause = Some(self.parse_or_expr()?);
337 }
338
339 self.expect_end()?;
340 Ok(UpdateQuery {
341 table,
342 assignments,
343 where_clause,
344 })
345 }
346
347 fn parse_delete(&mut self) -> Result<DeleteQuery, MdqlError> {
348 self.expect("keyword", Some("DELETE"))?;
349 self.expect("keyword", Some("FROM"))?;
350 let table = self.parse_ident()?;
351
352 let mut where_clause = None;
353 if self.match_keyword("WHERE") {
354 where_clause = Some(self.parse_or_expr()?);
355 }
356
357 self.expect_end()?;
358 Ok(DeleteQuery {
359 table,
360 where_clause,
361 })
362 }
363
364 fn parse_alter(&mut self) -> Result<Statement, MdqlError> {
365 self.expect("keyword", Some("ALTER"))?;
366 self.expect("keyword", Some("TABLE"))?;
367 let table = self.parse_ident()?;
368
369 let t = self.peek().ok_or_else(|| {
370 MdqlError::QueryParse("Expected RENAME, DROP, or MERGE after table name".into())
371 })?;
372
373 match (t.token_type.as_str(), t.value.as_str()) {
374 ("keyword", "RENAME") => {
375 self.advance();
376 self.expect("keyword", Some("FIELD"))?;
377 let old_name = self.parse_string_or_ident()?;
378 self.expect("keyword", Some("TO"))?;
379 let new_name = self.parse_string_or_ident()?;
380 self.expect_end()?;
381 Ok(Statement::AlterRename(AlterRenameFieldQuery {
382 table,
383 old_name,
384 new_name,
385 }))
386 }
387 ("keyword", "DROP") => {
388 self.advance();
389 self.expect("keyword", Some("FIELD"))?;
390 let field_name = self.parse_string_or_ident()?;
391 self.expect_end()?;
392 Ok(Statement::AlterDrop(AlterDropFieldQuery {
393 table,
394 field_name,
395 }))
396 }
397 ("keyword", "MERGE") => {
398 self.advance();
399 self.expect("keyword", Some("FIELDS"))?;
400 let mut sources = vec![self.parse_string_or_ident()?];
401 while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
402 self.advance();
403 sources.push(self.parse_string_or_ident()?);
404 }
405 self.expect("keyword", Some("INTO"))?;
406 let target = self.parse_string_or_ident()?;
407 self.expect_end()?;
408 Ok(Statement::AlterMerge(AlterMergeFieldsQuery {
409 table,
410 sources,
411 into: target,
412 }))
413 }
414 _ => Err(MdqlError::QueryParse(format!(
415 "Expected RENAME, DROP, or MERGE, got '{}'",
416 t.raw
417 ))),
418 }
419 }
420
421 fn parse_create_view(&mut self) -> Result<Statement, MdqlError> {
422 self.expect("keyword", Some("CREATE"))?;
423 self.expect("keyword", Some("VIEW"))?;
424 let view_name = self.parse_ident()?;
425
426 let columns = if self.peek().map_or(false, |t| t.token_type == "op" && t.value == "(") {
427 self.advance();
428 let mut cols = vec![self.parse_ident()?];
429 while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
430 self.advance();
431 cols.push(self.parse_ident()?);
432 }
433 self.expect("op", Some(")"))?;
434 Some(cols)
435 } else {
436 None
437 };
438
439 self.expect("keyword", Some("AS"))?;
440 let query = Box::new(self.parse_select()?);
441 self.expect_end()?;
442
443 Ok(Statement::CreateView(CreateViewQuery {
444 view_name,
445 columns,
446 query,
447 }))
448 }
449
450 fn parse_drop_view(&mut self) -> Result<Statement, MdqlError> {
451 self.expect("keyword", Some("DROP"))?;
452 self.expect("keyword", Some("VIEW"))?;
453 let view_name = self.parse_ident()?;
454 self.expect_end()?;
455 Ok(Statement::DropView(DropViewQuery { view_name }))
456 }
457
458 fn parse_string_or_ident(&mut self) -> Result<String, MdqlError> {
459 let t = self.peek().ok_or_else(|| {
460 MdqlError::QueryParse("Expected field name, got end of query".into())
461 })?;
462 match t.token_type.as_str() {
463 "string" => {
464 let v = self.advance().value;
465 Ok(v)
466 }
467 "ident" | "keyword" => {
468 let v = self.advance().value;
469 Ok(v)
470 }
471 _ => Err(MdqlError::QueryParse(format!(
472 "Expected field name, got '{}'",
473 t.raw
474 ))),
475 }
476 }
477
478 fn parse_columns(&mut self) -> Result<ColumnList, MdqlError> {
479 if let Some(t) = self.peek() {
480 if t.token_type == "op" && t.value == "*" {
481 self.advance();
482 return Ok(ColumnList::All);
483 }
484 }
485
486 let mut exprs = vec![self.parse_select_expr()?];
487 while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
488 self.advance();
489 exprs.push(self.parse_select_expr()?);
490 }
491 Ok(ColumnList::Named(exprs))
492 }
493
494 fn peek_is_agg_func(&self) -> bool {
495 let t = match self.peek() {
496 Some(t) => t,
497 None => return false,
498 };
499 let name_upper = t.value.to_uppercase();
500 if !AGG_FUNCS.contains(&name_upper.as_str()) {
501 return false;
502 }
503 self.tokens
505 .get(self.pos + 1)
506 .map_or(false, |next| next.token_type == "op" && next.value == "(")
507 }
508
509 fn parse_select_expr(&mut self) -> Result<SelectExpr, MdqlError> {
510 let _t = self.peek().ok_or_else(|| {
511 MdqlError::QueryParse("Expected column or aggregate, got end of query".into())
512 })?;
513
514 let expr = self.parse_additive()?;
515
516 let alias = if self.match_keyword("AS") {
517 Some(self.parse_ident()?)
518 } else if self.peek().map_or(false, |t| {
519 t.token_type == "ident" && !self.is_clause_keyword(t)
520 }) {
521 Some(self.advance().value)
522 } else {
523 None
524 };
525
526 if let Expr::Aggregate { func, arg, arg_expr } = expr {
528 return Ok(SelectExpr::Aggregate {
529 func,
530 arg,
531 arg_expr: arg_expr.map(|e| *e),
532 alias,
533 });
534 }
535
536 if alias.is_none() {
537 if let Expr::Column(name) = &expr {
538 return Ok(SelectExpr::Column(name.clone()));
539 }
540 }
541
542 Ok(SelectExpr::Expr { expr, alias })
543 }
544
545 fn peek_is_additive_op(&self) -> bool {
548 self.peek().map_or(false, |t| {
549 t.token_type == "op" && (t.value == "+" || t.value == "-")
550 })
551 }
552
553 fn peek_is_multiplicative_op(&self) -> bool {
554 self.peek().map_or(false, |t| {
555 t.token_type == "op" && (t.value == "*" || t.value == "/" || t.value == "%")
556 })
557 }
558
559 fn parse_additive(&mut self) -> Result<Expr, MdqlError> {
560 let mut left = self.parse_multiplicative()?;
561 while self.peek_is_additive_op() {
562 let op_tok = self.advance();
563 let is_sub = op_tok.value == "-";
564
565 if self.peek().map_or(false, |t| t.token_type == "keyword" && t.value == "INTERVAL") {
567 self.advance(); let days_expr = self.parse_multiplicative()?;
569 if !self.match_keyword("DAY") && !self.match_keyword("DAYS") {
571 return Err(MdqlError::QueryParse("Expected DAY after INTERVAL value".into()));
572 }
573 let days = if is_sub {
574 Expr::UnaryMinus(Box::new(days_expr))
575 } else {
576 days_expr
577 };
578 left = Expr::DateAdd {
579 date: Box::new(left),
580 days: Box::new(days),
581 };
582 continue;
583 }
584
585 let op = match op_tok.value.as_str() {
586 "+" => ArithOp::Add,
587 "-" => ArithOp::Sub,
588 _ => unreachable!(),
589 };
590 let right = self.parse_multiplicative()?;
591 left = Expr::BinaryOp {
592 left: Box::new(left),
593 op,
594 right: Box::new(right),
595 };
596 }
597 Ok(left)
598 }
599
600 fn parse_multiplicative(&mut self) -> Result<Expr, MdqlError> {
601 let mut left = self.parse_unary()?;
602 while self.peek_is_multiplicative_op() {
603 let op_tok = self.advance();
604 let op = match op_tok.value.as_str() {
605 "*" => ArithOp::Mul,
606 "/" => ArithOp::Div,
607 "%" => ArithOp::Mod,
608 _ => unreachable!(),
609 };
610 let right = self.parse_unary()?;
611 left = Expr::BinaryOp {
612 left: Box::new(left),
613 op,
614 right: Box::new(right),
615 };
616 }
617 Ok(left)
618 }
619
620 fn parse_unary(&mut self) -> Result<Expr, MdqlError> {
621 if self.peek().map_or(false, |t| t.token_type == "op" && t.value == "-") {
622 self.advance();
623 let inner = self.parse_atom()?;
624 match inner {
626 Expr::Literal(SqlValue::Int(n)) => Ok(Expr::Literal(SqlValue::Int(-n))),
627 Expr::Literal(SqlValue::Float(f)) => Ok(Expr::Literal(SqlValue::Float(-f))),
628 _ => Ok(Expr::UnaryMinus(Box::new(inner))),
629 }
630 } else {
631 self.parse_atom()
632 }
633 }
634
635 fn parse_atom(&mut self) -> Result<Expr, MdqlError> {
636 if self.peek_is_agg_func() {
637 return self.parse_agg_expr();
638 }
639
640 let t = self.peek().ok_or_else(|| {
641 MdqlError::QueryParse("Expected expression, got end of query".into())
642 })?;
643
644 match t.token_type.as_str() {
645 "number" => {
646 let v = self.advance().value;
647 if v.contains('.') {
648 let f: f64 = v.parse().map_err(|_| {
649 MdqlError::QueryParse(format!("Invalid float: {}", v))
650 })?;
651 Ok(Expr::Literal(SqlValue::Float(f)))
652 } else {
653 let n: i64 = v.parse().map_err(|_| {
654 MdqlError::QueryParse(format!("Invalid int: {}", v))
655 })?;
656 Ok(Expr::Literal(SqlValue::Int(n)))
657 }
658 }
659 "string" => {
660 let v = self.advance().value;
661 Ok(Expr::Literal(SqlValue::String(v)))
662 }
663 "keyword" if t.value == "NULL" => {
664 self.advance();
665 Ok(Expr::Literal(SqlValue::Null))
666 }
667 "keyword" if t.value == "CASE" => {
668 self.parse_case_expr()
669 }
670 "keyword" if t.value == "CURRENT_DATE" => {
671 self.advance();
672 Ok(Expr::CurrentDate)
673 }
674 "keyword" if t.value == "CURRENT_TIMESTAMP" => {
675 self.advance();
676 Ok(Expr::CurrentTimestamp)
677 }
678 "keyword" if t.value == "DATEDIFF" => {
679 self.advance();
680 self.expect("op", Some("("))?;
681 let left = self.parse_additive()?;
682 self.expect("op", Some(","))?;
683 let right = self.parse_additive()?;
684 self.expect("op", Some(")"))?;
685 Ok(Expr::DateDiff { left: Box::new(left), right: Box::new(right) })
686 }
687 "op" if t.value == "(" => {
688 self.advance();
689 let expr = self.parse_additive()?;
690 self.expect("op", Some(")"))?;
691 Ok(expr)
692 }
693 "ident" => {
694 let name = self.advance().value;
695 Ok(Expr::Column(name))
696 }
697 "keyword" if !Self::is_reserved_keyword(&t.value) => {
698 let name = self.advance().value;
699 Ok(Expr::Column(name))
700 }
701 _ => Err(MdqlError::QueryParse(format!(
702 "Expected expression, got '{}'",
703 t.raw
704 ))),
705 }
706 }
707
708 fn parse_case_expr(&mut self) -> Result<Expr, MdqlError> {
709 self.expect("keyword", Some("CASE"))?;
710 let mut whens = Vec::new();
711 while self.match_keyword("WHEN") {
712 let condition = self.parse_or_expr()?;
713 self.expect("keyword", Some("THEN"))?;
714 let result = self.parse_additive()?;
715 whens.push((condition, Box::new(result)));
716 }
717 if whens.is_empty() {
718 return Err(MdqlError::QueryParse("CASE requires at least one WHEN clause".into()));
719 }
720 let else_expr = if self.match_keyword("ELSE") {
721 Some(Box::new(self.parse_additive()?))
722 } else {
723 None
724 };
725 self.expect("keyword", Some("END"))?;
726 Ok(Expr::Case { whens, else_expr })
727 }
728
729 fn parse_agg_expr(&mut self) -> Result<Expr, MdqlError> {
730 let func_name = self.advance().value.to_uppercase();
731 let func = match func_name.as_str() {
732 "COUNT" => AggFunc::Count,
733 "SUM" => AggFunc::Sum,
734 "AVG" => AggFunc::Avg,
735 "MIN" => AggFunc::Min,
736 "MAX" => AggFunc::Max,
737 _ => unreachable!(),
738 };
739 self.expect("op", Some("("))?;
740 let (arg, arg_expr) = if self.peek().map_or(false, |t| t.token_type == "op" && t.value == "*") {
741 self.advance();
742 ("*".to_string(), None)
743 } else {
744 let expr = self.parse_additive()?;
745 if let Expr::Column(name) = &expr {
746 (name.clone(), None)
747 } else {
748 (expr.display_name(), Some(Box::new(expr)))
749 }
750 };
751 self.expect("op", Some(")"))?;
752 Ok(Expr::Aggregate { func, arg, arg_expr })
753 }
754
755 fn parse_ident(&mut self) -> Result<String, MdqlError> {
756 let t = self.peek().ok_or_else(|| {
757 MdqlError::QueryParse("Expected identifier, got end of query".into())
758 })?;
759 match t.token_type.as_str() {
760 "ident" | "keyword" => {
761 let v = self.advance().value;
762 Ok(v)
763 }
764 _ => Err(MdqlError::QueryParse(format!(
765 "Expected identifier, got '{}'",
766 t.raw
767 ))),
768 }
769 }
770
771 fn parse_or_expr(&mut self) -> Result<WhereClause, MdqlError> {
772 let mut left = self.parse_and_expr()?;
773 while self.match_keyword("OR") {
774 let right = self.parse_and_expr()?;
775 left = WhereClause::BoolOp(BoolOp {
776 op: "OR".into(),
777 left: Box::new(left),
778 right: Box::new(right),
779 });
780 }
781 Ok(left)
782 }
783
784 fn parse_and_expr(&mut self) -> Result<WhereClause, MdqlError> {
785 let mut left = self.parse_comparison()?;
786 while self.match_keyword("AND") {
787 let right = self.parse_comparison()?;
788 left = WhereClause::BoolOp(BoolOp {
789 op: "AND".into(),
790 left: Box::new(left),
791 right: Box::new(right),
792 });
793 }
794 Ok(left)
795 }
796
797 fn parse_comparison(&mut self) -> Result<WhereClause, MdqlError> {
798 if self.peek().map_or(false, |t| t.token_type == "op" && t.value == "(") {
800 let saved_pos = self.pos;
802 self.advance();
803 let result = self.parse_or_expr();
805 if result.is_ok() && self.peek().map_or(false, |t| t.token_type == "op" && t.value == ")") {
806 self.advance();
807 return result;
808 }
809 self.pos = saved_pos;
811 }
812
813 let left_expr = self.parse_additive()?;
815
816 let col = left_expr.as_column().unwrap_or("").to_string();
818
819 if self.match_keyword("IS") {
821 if self.match_keyword("NOT") {
822 self.expect("keyword", Some("NULL"))?;
823 return Ok(WhereClause::Comparison(Comparison {
824 column: col,
825 op: "IS NOT NULL".into(),
826 value: None,
827 left_expr: Some(left_expr),
828 right_expr: None,
829 }));
830 }
831 self.expect("keyword", Some("NULL"))?;
832 return Ok(WhereClause::Comparison(Comparison {
833 column: col,
834 op: "IS NULL".into(),
835 value: None,
836 left_expr: Some(left_expr),
837 right_expr: None,
838 }));
839 }
840
841 if self.match_keyword("IN") {
843 self.expect("op", Some("("))?;
844 let mut values = vec![self.parse_value()?];
845 while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
846 self.advance();
847 values.push(self.parse_value()?);
848 }
849 self.expect("op", Some(")"))?;
850 return Ok(WhereClause::Comparison(Comparison {
851 column: col,
852 op: "IN".into(),
853 value: Some(SqlValue::List(values)),
854 left_expr: Some(left_expr),
855 right_expr: None,
856 }));
857 }
858
859 if self.match_keyword("LIKE") {
861 let val = self.parse_value()?;
862 return Ok(WhereClause::Comparison(Comparison {
863 column: col,
864 op: "LIKE".into(),
865 value: Some(val),
866 left_expr: Some(left_expr),
867 right_expr: None,
868 }));
869 }
870
871 if self.match_keyword("NOT") {
873 if self.match_keyword("LIKE") {
874 let val = self.parse_value()?;
875 return Ok(WhereClause::Comparison(Comparison {
876 column: col,
877 op: "NOT LIKE".into(),
878 value: Some(val),
879 left_expr: Some(left_expr),
880 right_expr: None,
881 }));
882 }
883 return Err(MdqlError::QueryParse("Expected LIKE after NOT".into()));
884 }
885
886 if let Some(t) = self.peek() {
888 if t.token_type == "op" && ["=", "!=", "<", ">", "<=", ">="].contains(&t.value.as_str())
889 {
890 let op = self.advance().value;
891 let right_expr = self.parse_additive()?;
893 let value = match &right_expr {
895 Expr::Literal(v) => Some(v.clone()),
896 _ => None,
897 };
898 return Ok(WhereClause::Comparison(Comparison {
899 column: col,
900 op,
901 value,
902 left_expr: Some(left_expr),
903 right_expr: Some(right_expr),
904 }));
905 }
906 }
907
908 let got = self.peek().map_or("end".to_string(), |t| t.raw.clone());
909 Err(MdqlError::QueryParse(format!(
910 "Expected operator after '{}', got '{}'",
911 left_expr.display_name(), got
912 )))
913 }
914
915 fn parse_value(&mut self) -> Result<SqlValue, MdqlError> {
916 let t = self.peek().ok_or_else(|| {
917 MdqlError::QueryParse("Expected value, got end of query".into())
918 })?;
919 match t.token_type.as_str() {
920 "string" => {
921 let v = self.advance().value;
922 Ok(SqlValue::String(v))
923 }
924 "number" => {
925 let v = self.advance().value;
926 if v.contains('.') {
927 Ok(SqlValue::Float(v.parse().map_err(|_| {
928 MdqlError::QueryParse(format!("Invalid float: {}", v))
929 })?))
930 } else {
931 Ok(SqlValue::Int(v.parse().map_err(|_| {
932 MdqlError::QueryParse(format!("Invalid int: {}", v))
933 })?))
934 }
935 }
936 "keyword" if t.value == "NULL" => {
937 self.advance();
938 Ok(SqlValue::Null)
939 }
940 _ => Err(MdqlError::QueryParse(format!(
941 "Expected value, got '{}'",
942 t.raw
943 ))),
944 }
945 }
946
947 fn parse_order_by(&mut self) -> Result<Vec<OrderSpec>, MdqlError> {
948 let mut specs = vec![self.parse_order_spec()?];
949 while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
950 self.advance();
951 specs.push(self.parse_order_spec()?);
952 }
953 Ok(specs)
954 }
955
956 fn parse_order_spec(&mut self) -> Result<OrderSpec, MdqlError> {
957 let expr = self.parse_additive()?;
958 let col = expr.as_column().unwrap_or("").to_string();
959 let descending = if self.match_keyword("DESC") {
960 true
961 } else {
962 self.match_keyword("ASC");
963 false
964 };
965 Ok(OrderSpec {
966 column: col,
967 expr: Some(expr),
968 descending,
969 })
970 }
971
972 fn is_clause_keyword(&self, t: &Token) -> bool {
973 t.token_type == "keyword"
974 && ["WHERE", "ORDER", "LIMIT", "JOIN", "ON", "GROUP"].contains(&t.value.as_str())
975 }
976
977 fn is_reserved_keyword(kw: &str) -> bool {
979 matches!(kw,
980 "AS" | "FROM" | "WHERE" | "AND" | "OR" | "ORDER" | "BY"
981 | "ASC" | "DESC" | "LIMIT" | "JOIN" | "ON" | "GROUP"
982 | "SELECT" | "INSERT" | "INTO" | "VALUES" | "UPDATE" | "SET"
983 | "DELETE" | "ALTER" | "TABLE" | "IS" | "NOT" | "IN" | "LIKE"
984 | "RENAME" | "FIELD" | "TO" | "DROP" | "MERGE" | "FIELDS"
985 | "CASE" | "WHEN" | "THEN" | "ELSE" | "END"
986 | "HAVING" | "INTERVAL" | "DAY" | "DAYS"
987 | "CURRENT_DATE" | "CURRENT_TIMESTAMP" | "DATEDIFF"
988 | "CREATE" | "VIEW" | "CASCADE" | "RESTRICT"
989 )
990 }
991
992 fn expect_end(&self) -> Result<(), MdqlError> {
993 if let Some(t) = self.peek() {
994 return Err(MdqlError::QueryParse(format!(
995 "Unexpected token '{}' at position {}",
996 t.raw, self.pos
997 )));
998 }
999 Ok(())
1000 }
1001}
1002
1003pub fn parse_query(sql: &str) -> crate::errors::Result<Statement> {
1004 let tokens = tokenize(sql);
1005 if tokens.is_empty() {
1006 return Err(MdqlError::QueryParse("Empty query".into()));
1007 }
1008 let mut parser = Parser::new(tokens);
1009 parser.parse_statement()
1010}
1011
1012#[cfg(test)]
1013mod tests {
1014 use super::*;
1015
1016 #[test]
1017 fn test_simple_select() {
1018 let stmt = parse_query("SELECT title, status FROM strategies").unwrap();
1019 if let Statement::Select(q) = stmt {
1020 assert_eq!(q.columns, ColumnList::Named(vec![SelectExpr::Column("title".into()), SelectExpr::Column("status".into())]));
1021 assert_eq!(q.table, "strategies");
1022 } else {
1023 panic!("Expected Select");
1024 }
1025 }
1026
1027 #[test]
1028 fn test_select_star() {
1029 let stmt = parse_query("SELECT * FROM test").unwrap();
1030 if let Statement::Select(q) = stmt {
1031 assert_eq!(q.columns, ColumnList::All);
1032 } else {
1033 panic!("Expected Select");
1034 }
1035 }
1036
1037 #[test]
1038 fn test_where_clause() {
1039 let stmt = parse_query("SELECT title FROM test WHERE count > 5").unwrap();
1040 if let Statement::Select(q) = stmt {
1041 assert!(q.where_clause.is_some());
1042 } else {
1043 panic!("Expected Select");
1044 }
1045 }
1046
1047 #[test]
1048 fn test_order_by() {
1049 let stmt =
1050 parse_query("SELECT title FROM test ORDER BY composite DESC, title ASC").unwrap();
1051 if let Statement::Select(q) = stmt {
1052 let ob = q.order_by.unwrap();
1053 assert_eq!(ob.len(), 2);
1054 assert!(ob[0].descending);
1055 assert!(!ob[1].descending);
1056 } else {
1057 panic!("Expected Select");
1058 }
1059 }
1060
1061 #[test]
1062 fn test_limit() {
1063 let stmt = parse_query("SELECT * FROM test LIMIT 10").unwrap();
1064 if let Statement::Select(q) = stmt {
1065 assert_eq!(q.limit, Some(10));
1066 } else {
1067 panic!("Expected Select");
1068 }
1069 }
1070
1071 #[test]
1072 fn test_insert() {
1073 let stmt = parse_query(
1074 "INSERT INTO test (title, count) VALUES ('Hello', 42)",
1075 )
1076 .unwrap();
1077 if let Statement::Insert(q) = stmt {
1078 assert_eq!(q.table, "test");
1079 assert_eq!(q.columns, vec!["title", "count"]);
1080 assert_eq!(q.values[0], SqlValue::String("Hello".into()));
1081 assert_eq!(q.values[1], SqlValue::Int(42));
1082 } else {
1083 panic!("Expected Insert");
1084 }
1085 }
1086
1087 #[test]
1088 fn test_update() {
1089 let stmt = parse_query("UPDATE test SET status = 'KILLED' WHERE path = 'a.md'").unwrap();
1090 if let Statement::Update(q) = stmt {
1091 assert_eq!(q.table, "test");
1092 assert_eq!(q.assignments.len(), 1);
1093 assert!(q.where_clause.is_some());
1094 } else {
1095 panic!("Expected Update");
1096 }
1097 }
1098
1099 #[test]
1100 fn test_delete() {
1101 let stmt = parse_query("DELETE FROM test WHERE status = 'draft'").unwrap();
1102 if let Statement::Delete(q) = stmt {
1103 assert_eq!(q.table, "test");
1104 assert!(q.where_clause.is_some());
1105 } else {
1106 panic!("Expected Delete");
1107 }
1108 }
1109
1110 #[test]
1111 fn test_alter_rename() {
1112 let stmt =
1113 parse_query("ALTER TABLE test RENAME FIELD 'Summary' TO 'Overview'").unwrap();
1114 if let Statement::AlterRename(q) = stmt {
1115 assert_eq!(q.old_name, "Summary");
1116 assert_eq!(q.new_name, "Overview");
1117 } else {
1118 panic!("Expected AlterRename");
1119 }
1120 }
1121
1122 #[test]
1123 fn test_alter_drop() {
1124 let stmt = parse_query("ALTER TABLE test DROP FIELD 'Details'").unwrap();
1125 if let Statement::AlterDrop(q) = stmt {
1126 assert_eq!(q.field_name, "Details");
1127 } else {
1128 panic!("Expected AlterDrop");
1129 }
1130 }
1131
1132 #[test]
1133 fn test_alter_merge() {
1134 let stmt = parse_query(
1135 "ALTER TABLE test MERGE FIELDS 'Entry Rules', 'Exit Rules' INTO 'Trading Rules'",
1136 )
1137 .unwrap();
1138 if let Statement::AlterMerge(q) = stmt {
1139 assert_eq!(q.sources, vec!["Entry Rules", "Exit Rules"]);
1140 assert_eq!(q.into, "Trading Rules");
1141 } else {
1142 panic!("Expected AlterMerge");
1143 }
1144 }
1145
1146 #[test]
1147 fn test_backtick_ident() {
1148 let stmt = parse_query("SELECT `Structural Mechanism` FROM test").unwrap();
1149 if let Statement::Select(q) = stmt {
1150 assert_eq!(
1151 q.columns,
1152 ColumnList::Named(vec![SelectExpr::Column("Structural Mechanism".into())])
1153 );
1154 } else {
1155 panic!("Expected Select");
1156 }
1157 }
1158
1159 #[test]
1160 fn test_like_operator() {
1161 let stmt = parse_query("SELECT title FROM test WHERE categories LIKE '%defi%'").unwrap();
1162 if let Statement::Select(q) = stmt {
1163 if let Some(WhereClause::Comparison(c)) = q.where_clause {
1164 assert_eq!(c.op, "LIKE");
1165 assert_eq!(c.value, Some(SqlValue::String("%defi%".into())));
1166 } else {
1167 panic!("Expected LIKE comparison");
1168 }
1169 } else {
1170 panic!("Expected Select");
1171 }
1172 }
1173
1174 #[test]
1175 fn test_in_operator() {
1176 let stmt =
1177 parse_query("SELECT * FROM test WHERE status IN ('ACTIVE', 'LIVE')").unwrap();
1178 if let Statement::Select(q) = stmt {
1179 if let Some(WhereClause::Comparison(c)) = q.where_clause {
1180 assert_eq!(c.op, "IN");
1181 } else {
1182 panic!("Expected IN comparison");
1183 }
1184 } else {
1185 panic!("Expected Select");
1186 }
1187 }
1188
1189 #[test]
1190 fn test_is_null() {
1191 let stmt = parse_query("SELECT * FROM test WHERE title IS NULL").unwrap();
1192 if let Statement::Select(q) = stmt {
1193 if let Some(WhereClause::Comparison(c)) = q.where_clause {
1194 assert_eq!(c.op, "IS NULL");
1195 } else {
1196 panic!("Expected IS NULL comparison");
1197 }
1198 } else {
1199 panic!("Expected Select");
1200 }
1201 }
1202
1203 #[test]
1204 fn test_and_or() {
1205 let stmt = parse_query(
1206 "SELECT * FROM test WHERE status = 'ACTIVE' AND count > 5 OR title LIKE '%test%'",
1207 )
1208 .unwrap();
1209 if let Statement::Select(q) = stmt {
1210 assert!(q.where_clause.is_some());
1211 } else {
1212 panic!("Expected Select");
1213 }
1214 }
1215
1216 #[test]
1217 fn test_join() {
1218 let stmt = parse_query(
1219 "SELECT s.title, b.sharpe FROM strategies s JOIN backtests b ON b.strategy = s.path",
1220 )
1221 .unwrap();
1222 if let Statement::Select(q) = stmt {
1223 assert_eq!(q.table, "strategies");
1224 assert_eq!(q.table_alias, Some("s".into()));
1225 assert_eq!(q.joins.len(), 1);
1226 let join = &q.joins[0];
1227 assert_eq!(join.table, "backtests");
1228 assert_eq!(join.alias, Some("b".into()));
1229 } else {
1230 panic!("Expected Select");
1231 }
1232 }
1233
1234 #[test]
1235 fn test_multi_join() {
1236 let stmt = parse_query(
1237 "SELECT s.title, b.sharpe, c.verdict FROM strategies s JOIN backtests b ON b.strategy = s.path JOIN critiques c ON c.strategy = s.path",
1238 )
1239 .unwrap();
1240 if let Statement::Select(q) = stmt {
1241 assert_eq!(q.table, "strategies");
1242 assert_eq!(q.table_alias, Some("s".into()));
1243 assert_eq!(q.joins.len(), 2);
1244 assert_eq!(q.joins[0].table, "backtests");
1245 assert_eq!(q.joins[0].alias, Some("b".into()));
1246 assert_eq!(q.joins[0].left_col, "b.strategy");
1247 assert_eq!(q.joins[0].right_col, "s.path");
1248 assert_eq!(q.joins[1].table, "critiques");
1249 assert_eq!(q.joins[1].alias, Some("c".into()));
1250 assert_eq!(q.joins[1].left_col, "c.strategy");
1251 assert_eq!(q.joins[1].right_col, "s.path");
1252 } else {
1253 panic!("Expected Select");
1254 }
1255 }
1256
1257 #[test]
1258 fn test_empty_query() {
1259 assert!(parse_query("").is_err());
1260 }
1261
1262 #[test]
1263 fn test_count_star() {
1264 let stmt = parse_query("SELECT status, COUNT(*) AS cnt FROM strategies GROUP BY status").unwrap();
1265 if let Statement::Select(q) = stmt {
1266 if let ColumnList::Named(exprs) = &q.columns {
1267 assert_eq!(exprs.len(), 2);
1268 assert_eq!(exprs[0], SelectExpr::Column("status".into()));
1269 assert!(matches!(&exprs[1], SelectExpr::Aggregate {
1270 func: AggFunc::Count,
1271 arg,
1272 alias: Some(a),
1273 ..
1274 } if arg == "*" && a == "cnt"));
1275 } else {
1276 panic!("Expected Named columns");
1277 }
1278 assert_eq!(q.group_by, Some(vec!["status".into()]));
1279 } else {
1280 panic!("Expected Select");
1281 }
1282 }
1283
1284 #[test]
1285 fn test_count_column_as_ident() {
1286 let stmt = parse_query("INSERT INTO test (title, count) VALUES ('Hello', 42)").unwrap();
1288 if let Statement::Insert(q) = stmt {
1289 assert_eq!(q.columns, vec!["title", "count"]);
1290 } else {
1291 panic!("Expected Insert");
1292 }
1293 }
1294
1295 #[test]
1296 fn test_multiple_aggregates() {
1297 let stmt = parse_query("SELECT MIN(composite), MAX(composite), AVG(composite) FROM strategies").unwrap();
1298 if let Statement::Select(q) = stmt {
1299 if let ColumnList::Named(exprs) = &q.columns {
1300 assert_eq!(exprs.len(), 3);
1301 assert!(matches!(&exprs[0], SelectExpr::Aggregate { func: AggFunc::Min, .. }));
1302 assert!(matches!(&exprs[1], SelectExpr::Aggregate { func: AggFunc::Max, .. }));
1303 assert!(matches!(&exprs[2], SelectExpr::Aggregate { func: AggFunc::Avg, .. }));
1304 } else {
1305 panic!("Expected Named columns");
1306 }
1307 assert_eq!(q.group_by, None);
1308 } else {
1309 panic!("Expected Select");
1310 }
1311 }
1312
1313 #[test]
1316 fn test_select_arithmetic_expr() {
1317 let stmt = parse_query("SELECT a + b FROM test").unwrap();
1318 if let Statement::Select(q) = stmt {
1319 if let ColumnList::Named(exprs) = &q.columns {
1320 assert_eq!(exprs.len(), 1);
1321 assert!(matches!(&exprs[0], SelectExpr::Expr {
1322 expr: Expr::BinaryOp { op: ArithOp::Add, .. },
1323 alias: None,
1324 }));
1325 } else {
1326 panic!("Expected Named columns");
1327 }
1328 } else {
1329 panic!("Expected Select");
1330 }
1331 }
1332
1333 #[test]
1334 fn test_select_arithmetic_with_alias() {
1335 let stmt = parse_query("SELECT a + b AS total FROM test").unwrap();
1336 if let Statement::Select(q) = stmt {
1337 if let ColumnList::Named(exprs) = &q.columns {
1338 assert_eq!(exprs.len(), 1);
1339 assert!(matches!(&exprs[0], SelectExpr::Expr {
1340 alias: Some(a),
1341 ..
1342 } if a == "total"));
1343 assert_eq!(exprs[0].output_name(), "total");
1344 } else {
1345 panic!("Expected Named columns");
1346 }
1347 } else {
1348 panic!("Expected Select");
1349 }
1350 }
1351
1352 #[test]
1353 fn test_select_precedence() {
1354 let stmt = parse_query("SELECT a + b * c FROM test").unwrap();
1356 if let Statement::Select(q) = stmt {
1357 if let ColumnList::Named(exprs) = &q.columns {
1358 if let SelectExpr::Expr { expr, .. } = &exprs[0] {
1359 if let Expr::BinaryOp { left, op, right } = expr {
1360 assert_eq!(*op, ArithOp::Add);
1361 assert!(matches!(left.as_ref(), Expr::Column(n) if n == "a"));
1362 assert!(matches!(right.as_ref(), Expr::BinaryOp { op: ArithOp::Mul, .. }));
1363 } else {
1364 panic!("Expected BinaryOp");
1365 }
1366 } else {
1367 panic!("Expected Expr variant");
1368 }
1369 } else {
1370 panic!("Expected Named columns");
1371 }
1372 } else {
1373 panic!("Expected Select");
1374 }
1375 }
1376
1377 #[test]
1378 fn test_select_parenthesized_expr() {
1379 let stmt = parse_query("SELECT (a + b) * c FROM test").unwrap();
1381 if let Statement::Select(q) = stmt {
1382 if let ColumnList::Named(exprs) = &q.columns {
1383 if let SelectExpr::Expr { expr, .. } = &exprs[0] {
1384 if let Expr::BinaryOp { left, op, .. } = expr {
1385 assert_eq!(*op, ArithOp::Mul);
1386 assert!(matches!(left.as_ref(), Expr::BinaryOp { op: ArithOp::Add, .. }));
1387 } else {
1388 panic!("Expected BinaryOp");
1389 }
1390 } else {
1391 panic!("Expected Expr variant");
1392 }
1393 } else {
1394 panic!("Expected Named columns");
1395 }
1396 } else {
1397 panic!("Expected Select");
1398 }
1399 }
1400
1401 #[test]
1402 fn test_select_unary_minus() {
1403 let stmt = parse_query("SELECT -count FROM test").unwrap();
1404 if let Statement::Select(q) = stmt {
1405 if let ColumnList::Named(exprs) = &q.columns {
1406 assert!(matches!(&exprs[0], SelectExpr::Expr {
1407 expr: Expr::UnaryMinus(_),
1408 ..
1409 }));
1410 } else {
1411 panic!("Expected Named columns");
1412 }
1413 } else {
1414 panic!("Expected Select");
1415 }
1416 }
1417
1418 #[test]
1419 fn test_select_negative_literal() {
1420 let stmt = parse_query("SELECT -42 FROM test").unwrap();
1421 if let Statement::Select(q) = stmt {
1422 if let ColumnList::Named(exprs) = &q.columns {
1423 assert!(matches!(&exprs[0], SelectExpr::Expr {
1425 expr: Expr::Literal(SqlValue::Int(-42)),
1426 ..
1427 }));
1428 } else {
1429 panic!("Expected Named columns");
1430 }
1431 } else {
1432 panic!("Expected Select");
1433 }
1434 }
1435
1436 #[test]
1437 fn test_where_arithmetic_expr() {
1438 let stmt = parse_query("SELECT * FROM test WHERE a + b > 10").unwrap();
1439 if let Statement::Select(q) = stmt {
1440 if let Some(WhereClause::Comparison(c)) = q.where_clause {
1441 assert_eq!(c.op, ">");
1442 assert!(matches!(&c.left_expr, Some(Expr::BinaryOp { op: ArithOp::Add, .. })));
1443 assert!(matches!(&c.right_expr, Some(Expr::Literal(SqlValue::Int(10)))));
1444 } else {
1445 panic!("Expected comparison");
1446 }
1447 } else {
1448 panic!("Expected Select");
1449 }
1450 }
1451
1452 #[test]
1453 fn test_where_both_sides_expr() {
1454 let stmt = parse_query("SELECT * FROM test WHERE a * 2 > b + 1").unwrap();
1455 if let Statement::Select(q) = stmt {
1456 if let Some(WhereClause::Comparison(c)) = q.where_clause {
1457 assert_eq!(c.op, ">");
1458 assert!(matches!(&c.left_expr, Some(Expr::BinaryOp { op: ArithOp::Mul, .. })));
1459 assert!(matches!(&c.right_expr, Some(Expr::BinaryOp { op: ArithOp::Add, .. })));
1460 } else {
1461 panic!("Expected comparison");
1462 }
1463 } else {
1464 panic!("Expected Select");
1465 }
1466 }
1467
1468 #[test]
1469 fn test_order_by_expr() {
1470 let stmt = parse_query("SELECT * FROM test ORDER BY a + b DESC").unwrap();
1471 if let Statement::Select(q) = stmt {
1472 let ob = q.order_by.unwrap();
1473 assert_eq!(ob.len(), 1);
1474 assert!(ob[0].descending);
1475 assert!(matches!(&ob[0].expr, Some(Expr::BinaryOp { op: ArithOp::Add, .. })));
1476 } else {
1477 panic!("Expected Select");
1478 }
1479 }
1480
1481 #[test]
1482 fn test_all_arithmetic_ops() {
1483 let stmt = parse_query("SELECT a + b, a - b, a * b, a / b, a % b FROM test").unwrap();
1484 if let Statement::Select(q) = stmt {
1485 if let ColumnList::Named(exprs) = &q.columns {
1486 assert_eq!(exprs.len(), 5);
1487 assert!(matches!(&exprs[0], SelectExpr::Expr { expr: Expr::BinaryOp { op: ArithOp::Add, .. }, .. }));
1488 assert!(matches!(&exprs[1], SelectExpr::Expr { expr: Expr::BinaryOp { op: ArithOp::Sub, .. }, .. }));
1489 assert!(matches!(&exprs[2], SelectExpr::Expr { expr: Expr::BinaryOp { op: ArithOp::Mul, .. }, .. }));
1490 assert!(matches!(&exprs[3], SelectExpr::Expr { expr: Expr::BinaryOp { op: ArithOp::Div, .. }, .. }));
1491 assert!(matches!(&exprs[4], SelectExpr::Expr { expr: Expr::BinaryOp { op: ArithOp::Mod, .. }, .. }));
1492 } else {
1493 panic!("Expected Named columns");
1494 }
1495 } else {
1496 panic!("Expected Select");
1497 }
1498 }
1499
1500 #[test]
1501 fn test_column_with_literal_arithmetic() {
1502 let stmt = parse_query("SELECT count * 2 + 1 FROM test").unwrap();
1503 if let Statement::Select(q) = stmt {
1504 if let ColumnList::Named(exprs) = &q.columns {
1505 if let SelectExpr::Expr { expr, .. } = &exprs[0] {
1507 if let Expr::BinaryOp { left, op, right } = expr {
1508 assert_eq!(*op, ArithOp::Add);
1509 assert!(matches!(right.as_ref(), Expr::Literal(SqlValue::Int(1))));
1510 assert!(matches!(left.as_ref(), Expr::BinaryOp { op: ArithOp::Mul, .. }));
1511 } else {
1512 panic!("Expected BinaryOp");
1513 }
1514 } else {
1515 panic!("Expected Expr");
1516 }
1517 } else {
1518 panic!("Expected Named columns");
1519 }
1520 } else {
1521 panic!("Expected Select");
1522 }
1523 }
1524
1525 #[test]
1526 fn test_mixed_columns_and_exprs() {
1527 let stmt = parse_query("SELECT title, a + b AS sum, count FROM test").unwrap();
1528 if let Statement::Select(q) = stmt {
1529 if let ColumnList::Named(exprs) = &q.columns {
1530 assert_eq!(exprs.len(), 3);
1531 assert_eq!(exprs[0], SelectExpr::Column("title".into()));
1532 assert!(matches!(&exprs[1], SelectExpr::Expr { alias: Some(a), .. } if a == "sum"));
1533 assert_eq!(exprs[2], SelectExpr::Column("count".into()));
1534 } else {
1535 panic!("Expected Named columns");
1536 }
1537 } else {
1538 panic!("Expected Select");
1539 }
1540 }
1541
1542 #[test]
1545 fn test_case_when_basic() {
1546 let stmt = parse_query(
1547 "SELECT CASE WHEN status = 'ACTIVE' THEN 1 ELSE 0 END FROM test"
1548 ).unwrap();
1549 if let Statement::Select(q) = stmt {
1550 if let ColumnList::Named(exprs) = &q.columns {
1551 assert_eq!(exprs.len(), 1);
1552 assert!(matches!(&exprs[0], SelectExpr::Expr {
1553 expr: Expr::Case { .. },
1554 ..
1555 }));
1556 } else {
1557 panic!("Expected Named columns");
1558 }
1559 } else {
1560 panic!("Expected Select");
1561 }
1562 }
1563
1564 #[test]
1565 fn test_case_when_multiple_branches() {
1566 let stmt = parse_query(
1567 "SELECT CASE WHEN x > 10 THEN 'high' WHEN x > 5 THEN 'mid' ELSE 'low' END FROM test"
1568 ).unwrap();
1569 if let Statement::Select(q) = stmt {
1570 if let ColumnList::Named(exprs) = &q.columns {
1571 if let SelectExpr::Expr { expr: Expr::Case { whens, else_expr }, .. } = &exprs[0] {
1572 assert_eq!(whens.len(), 2);
1573 assert!(else_expr.is_some());
1574 } else {
1575 panic!("Expected Case expression");
1576 }
1577 } else {
1578 panic!("Expected Named columns");
1579 }
1580 } else {
1581 panic!("Expected Select");
1582 }
1583 }
1584
1585 #[test]
1586 fn test_case_when_no_else() {
1587 let stmt = parse_query(
1588 "SELECT CASE WHEN x = 1 THEN 'one' END FROM test"
1589 ).unwrap();
1590 if let Statement::Select(q) = stmt {
1591 if let ColumnList::Named(exprs) = &q.columns {
1592 if let SelectExpr::Expr { expr: Expr::Case { whens, else_expr }, .. } = &exprs[0] {
1593 assert_eq!(whens.len(), 1);
1594 assert!(else_expr.is_none());
1595 } else {
1596 panic!("Expected Case expression");
1597 }
1598 } else {
1599 panic!("Expected Named columns");
1600 }
1601 } else {
1602 panic!("Expected Select");
1603 }
1604 }
1605
1606 #[test]
1607 fn test_case_when_in_aggregate() {
1608 let stmt = parse_query(
1609 "SELECT SUM(CASE WHEN side = 'BUY' THEN size ELSE -size END) AS net FROM orders GROUP BY token"
1610 ).unwrap();
1611 if let Statement::Select(q) = stmt {
1612 if let ColumnList::Named(exprs) = &q.columns {
1613 assert_eq!(exprs.len(), 1);
1614 assert!(matches!(&exprs[0], SelectExpr::Aggregate {
1615 func: AggFunc::Sum,
1616 arg_expr: Some(Expr::Case { .. }),
1617 alias: Some(a),
1618 ..
1619 } if a == "net"));
1620 } else {
1621 panic!("Expected Named columns");
1622 }
1623 } else {
1624 panic!("Expected Select");
1625 }
1626 }
1627
1628 #[test]
1629 fn test_case_when_with_alias() {
1630 let stmt = parse_query(
1631 "SELECT CASE WHEN x > 0 THEN 'pos' ELSE 'neg' END AS sign FROM test"
1632 ).unwrap();
1633 if let Statement::Select(q) = stmt {
1634 if let ColumnList::Named(exprs) = &q.columns {
1635 assert!(matches!(&exprs[0], SelectExpr::Expr {
1636 expr: Expr::Case { .. },
1637 alias: Some(a),
1638 } if a == "sign"));
1639 } else {
1640 panic!("Expected Named columns");
1641 }
1642 } else {
1643 panic!("Expected Select");
1644 }
1645 }
1646
1647 #[test]
1648 fn test_create_view() {
1649 let stmt = parse_query("CREATE VIEW live AS SELECT * FROM strategies WHERE status = 'LIVE'").unwrap();
1650 if let Statement::CreateView(cv) = stmt {
1651 assert_eq!(cv.view_name, "live");
1652 assert!(cv.columns.is_none());
1653 assert_eq!(cv.query.table, "strategies");
1654 assert!(cv.query.where_clause.is_some());
1655 } else {
1656 panic!("Expected CreateView, got {:?}", stmt);
1657 }
1658 }
1659
1660 #[test]
1661 fn test_create_view_with_columns() {
1662 let stmt = parse_query("CREATE VIEW v1 (a, b) AS SELECT title, status FROM t").unwrap();
1663 if let Statement::CreateView(cv) = stmt {
1664 assert_eq!(cv.view_name, "v1");
1665 assert_eq!(cv.columns, Some(vec!["a".into(), "b".into()]));
1666 } else {
1667 panic!("Expected CreateView");
1668 }
1669 }
1670
1671 #[test]
1672 fn test_drop_view() {
1673 let stmt = parse_query("DROP VIEW live").unwrap();
1674 if let Statement::DropView(dv) = stmt {
1675 assert_eq!(dv.view_name, "live");
1676 } else {
1677 panic!("Expected DropView, got {:?}", stmt);
1678 }
1679 }
1680
1681 #[test]
1682 fn test_create_view_case_insensitive() {
1683 let stmt = parse_query("create view My_View as select * from t").unwrap();
1684 if let Statement::CreateView(cv) = stmt {
1685 assert_eq!(cv.view_name, "My_View");
1686 } else {
1687 panic!("Expected CreateView");
1688 }
1689 }
1690
1691 #[test]
1694 fn test_aggregate_division() {
1695 let stmt = parse_query(
1696 "SELECT token, SUM(sell) / SUM(buy) as ratio FROM orders GROUP BY token"
1697 ).unwrap();
1698 if let Statement::Select(q) = stmt {
1699 assert_eq!(q.group_by, Some(vec!["token".into()]));
1700 if let ColumnList::Named(exprs) = &q.columns {
1701 assert_eq!(exprs.len(), 2);
1702 assert!(exprs[1].is_aggregate());
1703 } else {
1704 panic!("Expected Named columns");
1705 }
1706 } else {
1707 panic!("Expected Select");
1708 }
1709 }
1710
1711 #[test]
1712 fn test_aggregate_subtraction() {
1713 let stmt = parse_query(
1714 "SELECT token, SUM(sell) - SUM(buy) as net FROM orders GROUP BY token"
1715 ).unwrap();
1716 if let Statement::Select(q) = stmt {
1717 if let ColumnList::Named(exprs) = &q.columns {
1718 assert_eq!(exprs[1].output_name(), "net");
1719 }
1720 } else {
1721 panic!("Expected Select");
1722 }
1723 }
1724
1725 #[test]
1726 fn test_create_view_with_arithmetic() {
1727 let stmt = parse_query(
1728 "CREATE VIEW positions AS SELECT token, SUM(sell) / SUM(buy) as ratio FROM orders GROUP BY token"
1729 ).unwrap();
1730 if let Statement::CreateView(cv) = stmt {
1731 assert_eq!(cv.view_name, "positions");
1732 } else {
1733 panic!("Expected CreateView, got {:?}", stmt);
1734 }
1735 }
1736
1737 #[test]
1740 fn test_subquery_in_from() {
1741 let stmt = parse_query(
1742 "SELECT token, sell_size FROM (SELECT token, SUM(size) as sell_size FROM orders GROUP BY token) LIMIT 5"
1743 ).unwrap();
1744 if let Statement::Select(q) = stmt {
1745 assert!(q.subquery.is_some());
1746 assert_eq!(q.limit, Some(5));
1747 let sub = q.subquery.unwrap();
1748 assert_eq!(sub.table, "orders");
1749 assert!(sub.group_by.is_some());
1750 } else {
1751 panic!("Expected Select");
1752 }
1753 }
1754
1755 #[test]
1758 fn test_create_view_with_having() {
1759 let stmt = parse_query(
1760 "CREATE VIEW positions AS SELECT token, SUM(sell) as sell_size, SUM(buy) as buy_size FROM orders GROUP BY token HAVING sell_size > buy_size"
1761 ).unwrap();
1762 if let Statement::CreateView(cv) = stmt {
1763 assert_eq!(cv.view_name, "positions");
1764 assert!(cv.query.having.is_some());
1765 } else {
1766 panic!("Expected CreateView, got {:?}", stmt);
1767 }
1768 }
1769
1770 #[test]
1773 fn test_aggregate_multiplication() {
1774 let stmt = parse_query(
1775 "SELECT SUM(a) * 2 as doubled FROM test"
1776 ).unwrap();
1777 if let Statement::Select(q) = stmt {
1778 if let ColumnList::Named(exprs) = &q.columns {
1779 assert_eq!(exprs.len(), 1);
1780 assert!(exprs[0].is_aggregate());
1781 assert_eq!(exprs[0].output_name(), "doubled");
1782 } else {
1783 panic!("Expected Named columns");
1784 }
1785 } else {
1786 panic!("Expected Select");
1787 }
1788 }
1789
1790 #[test]
1791 fn test_complex_aggregate_arithmetic() {
1792 let stmt = parse_query(
1793 "SELECT SUM(CASE WHEN side = 'SELL' THEN size ELSE 0 END) / SUM(CASE WHEN side = 'BUY' THEN size ELSE 0 END) as ratio FROM orders GROUP BY token"
1794 ).unwrap();
1795 if let Statement::Select(q) = stmt {
1796 if let ColumnList::Named(exprs) = &q.columns {
1797 assert_eq!(exprs.len(), 1);
1798 assert!(exprs[0].is_aggregate());
1799 assert_eq!(exprs[0].output_name(), "ratio");
1800 } else {
1801 panic!("Expected Named columns");
1802 }
1803 assert_eq!(q.group_by, Some(vec!["token".into()]));
1804 } else {
1805 panic!("Expected Select");
1806 }
1807 }
1808
1809 #[test]
1812 fn test_subquery_with_alias() {
1813 let stmt = parse_query(
1814 "SELECT x FROM (SELECT x FROM t) sub"
1815 ).unwrap();
1816 if let Statement::Select(q) = stmt {
1817 assert!(q.subquery.is_some());
1818 let sub = q.subquery.unwrap();
1819 assert_eq!(sub.table, "t");
1820 if let ColumnList::Named(exprs) = &q.columns {
1821 assert_eq!(exprs.len(), 1);
1822 assert_eq!(exprs[0].output_name(), "x");
1823 } else {
1824 panic!("Expected Named columns");
1825 }
1826 } else {
1827 panic!("Expected Select");
1828 }
1829 }
1830
1831 #[test]
1832 fn test_subquery_with_where() {
1833 let stmt = parse_query(
1834 "SELECT x FROM (SELECT x FROM t WHERE y > 0) LIMIT 5"
1835 ).unwrap();
1836 if let Statement::Select(q) = stmt {
1837 assert!(q.subquery.is_some());
1838 assert_eq!(q.limit, Some(5));
1839 let sub = q.subquery.unwrap();
1840 assert_eq!(sub.table, "t");
1841 assert!(sub.where_clause.is_some());
1842 } else {
1843 panic!("Expected Select");
1844 }
1845 }
1846
1847 #[test]
1850 fn test_create_view_aggregate_subtraction() {
1851 let stmt = parse_query(
1852 "CREATE VIEW v AS SELECT token, SUM(sell) - SUM(buy) as net FROM orders GROUP BY token"
1853 ).unwrap();
1854 if let Statement::CreateView(cv) = stmt {
1855 assert_eq!(cv.view_name, "v");
1856 assert_eq!(cv.query.group_by, Some(vec!["token".into()]));
1857 if let ColumnList::Named(exprs) = &cv.query.columns {
1858 assert_eq!(exprs.len(), 2);
1859 assert_eq!(exprs[1].output_name(), "net");
1860 assert!(exprs[1].is_aggregate());
1861 } else {
1862 panic!("Expected Named columns");
1863 }
1864 } else {
1865 panic!("Expected CreateView, got {:?}", stmt);
1866 }
1867 }
1868}