1use regex::Regex;
4use std::sync::LazyLock;
5
6use crate::errors::MdqlError;
7pub use crate::query_ast::*;
8
/// Reserved words recognised by the tokenizer.
///
/// A bare word whose upper-cased spelling appears in this list is emitted as a
/// `"keyword"` token (with an upper-cased `value`); any other bare word becomes
/// an `"ident"` token.
static KEYWORDS: &[&str] = &[
    // Query structure and ordering
    "SELECT", "FROM", "WHERE", "AND", "OR", "ORDER", "BY", "ASC", "DESC", "LIMIT",
    // Predicates
    "LIKE", "IN", "IS", "NOT", "NULL",
    // Joins, aliases and grouping
    "JOIN", "LEFT", "ON", "AS", "GROUP", "HAVING",
    // Data modification
    "INSERT", "INTO", "VALUES", "UPDATE", "SET", "DELETE",
    // ALTER TABLE forms
    "ALTER", "TABLE", "RENAME", "FIELD", "TO", "DROP", "MERGE", "FIELDS",
    // CASE expressions
    "CASE", "WHEN", "THEN", "ELSE", "END",
    // Date arithmetic
    "INTERVAL", "DAY", "DAYS", "CURRENT_DATE", "CURRENT_TIMESTAMP", "DATEDIFF",
    // Views
    "CREATE", "VIEW", "CASCADE", "RESTRICT",
];
21
/// Names of the supported aggregate functions.
///
/// A word is parsed as an aggregate call when its upper-cased spelling is
/// listed here and the very next token is `(`.
static AGG_FUNCS: &[&str] = &[
    "COUNT",
    "SUM",
    "AVG",
    "MIN",
    "MAX",
];
23
24static TOKEN_RE: LazyLock<Regex> = LazyLock::new(|| {
25 Regex::new(
26 r#"(?x)
27 \s*(?:
28 (?P<backtick>`[^`]+`)
29 | (?P<string>'(?:[^'\\]|\\.)*')
30 | (?P<number>\d+(?:\.\d+)?)
31 | (?P<op><=|>=|!=|[=<>,*()+\-/%])
32 | (?P<word>[A-Za-z_][A-Za-z0-9_./-]*)
33 )"#,
34 )
35 .unwrap()
36});
37
/// One lexical token produced by the tokenizer.
#[derive(Clone, Debug)]
struct Token {
    /// Token class as a string tag (for example `"keyword"`, `"ident"`,
    /// `"string"`, `"number"` or `"op"`).
    token_type: String,
    /// Normalised text: upper-cased for keywords, quotes stripped for string
    /// literals and backtick identifiers, otherwise the same as `raw`.
    value: String,
    /// The token exactly as written in the query, e.g. for parse error messages.
    raw: String,
}
44
45fn tokenize(sql: &str) -> Vec<Token> {
46 let mut tokens = Vec::new();
47 for caps in TOKEN_RE.captures_iter(sql) {
48 if let Some(m) = caps.name("backtick") {
49 let raw = m.as_str();
50 tokens.push(Token {
51 token_type: "ident".into(),
52 value: raw[1..raw.len() - 1].into(),
53 raw: raw.into(),
54 });
55 } else if let Some(m) = caps.name("string") {
56 let raw = m.as_str();
57 tokens.push(Token {
58 token_type: "string".into(),
59 value: raw[1..raw.len() - 1].into(),
60 raw: raw.into(),
61 });
62 } else if let Some(m) = caps.name("number") {
63 let raw = m.as_str();
64 tokens.push(Token {
65 token_type: "number".into(),
66 value: raw.into(),
67 raw: raw.into(),
68 });
69 } else if let Some(m) = caps.name("op") {
70 let raw = m.as_str();
71 tokens.push(Token {
72 token_type: "op".into(),
73 value: raw.into(),
74 raw: raw.into(),
75 });
76 } else if let Some(m) = caps.name("word") {
77 let raw = m.as_str();
78 if KEYWORDS.contains(&raw.to_uppercase().as_str()) {
79 tokens.push(Token {
80 token_type: "keyword".into(),
81 value: raw.to_uppercase(),
82 raw: raw.into(),
83 });
84 } else {
85 tokens.push(Token {
86 token_type: "ident".into(),
87 value: raw.into(),
88 raw: raw.into(),
89 });
90 }
91 }
92 }
93 tokens
94}
95
96struct Parser {
99 tokens: Vec<Token>,
100 pos: usize,
101}
102
103impl Parser {
104 fn new(tokens: Vec<Token>) -> Self {
105 Parser { tokens, pos: 0 }
106 }
107
108 fn peek(&self) -> Option<&Token> {
109 self.tokens.get(self.pos)
110 }
111
112 fn advance(&mut self) -> Token {
113 let t = self.tokens[self.pos].clone();
114 self.pos += 1;
115 t
116 }
117
118 fn expect(&mut self, type_: &str, value: Option<&str>) -> Result<Token, MdqlError> {
119 let t = self.peek().ok_or_else(|| {
120 MdqlError::QueryParse(format!(
121 "Unexpected end of query, expected {}",
122 value.unwrap_or(type_)
123 ))
124 })?;
125 let matches_type = t.token_type == type_;
126 let matches_value = value.map_or(true, |v| t.value == v);
127 if !matches_type || !matches_value {
128 return Err(MdqlError::QueryParse(format!(
129 "Expected {}, got '{}' at position {}",
130 value.unwrap_or(type_),
131 t.raw,
132 self.pos
133 )));
134 }
135 Ok(self.advance())
136 }
137
138 fn match_keyword(&mut self, kw: &str) -> bool {
139 if let Some(t) = self.peek() {
140 if t.token_type == "keyword" && t.value == kw {
141 self.advance();
142 return true;
143 }
144 }
145 false
146 }
147
148 fn parse_statement(&mut self) -> Result<Statement, MdqlError> {
149 let t = self.peek().ok_or_else(|| MdqlError::QueryParse("Empty query".into()))?;
150 match (t.token_type.as_str(), t.value.as_str()) {
151 ("keyword", "SELECT") => {
152 let q = self.parse_select()?;
153 self.expect_end()?;
154 Ok(Statement::Select(q))
155 }
156 ("keyword", "INSERT") => Ok(Statement::Insert(self.parse_insert()?)),
157 ("keyword", "UPDATE") => Ok(Statement::Update(self.parse_update()?)),
158 ("keyword", "DELETE") => Ok(Statement::Delete(self.parse_delete()?)),
159 ("keyword", "ALTER") => self.parse_alter(),
160 ("keyword", "CREATE") => self.parse_create_view(),
161 ("keyword", "DROP") => self.parse_drop_view(),
162 _ => Err(MdqlError::QueryParse(format!(
163 "Expected SELECT, INSERT, UPDATE, DELETE, ALTER, CREATE, or DROP, got '{}'",
164 t.raw
165 ))),
166 }
167 }
168
169 fn parse_select(&mut self) -> Result<SelectQuery, MdqlError> {
170 self.expect("keyword", Some("SELECT"))?;
171 let columns = self.parse_columns()?;
172 self.expect("keyword", Some("FROM"))?;
173
174 let mut subquery = None;
176 let (table, mut table_alias) = if self.peek().map_or(false, |t| t.token_type == "op" && t.value == "(") {
177 self.advance();
178 let inner = self.parse_select()?;
179 self.expect("op", Some(")"))?;
180 subquery = Some(Box::new(inner));
181 let alias = if let Some(t) = self.peek() {
182 if t.token_type == "ident" && !self.is_clause_keyword(t) {
183 Some(self.advance().value)
184 } else {
185 None
186 }
187 } else {
188 None
189 };
190 ("_subquery".to_string(), alias)
191 } else {
192 let t = self.parse_ident()?;
193 (t, None)
194 };
195
196 if subquery.is_none() {
198 if let Some(t) = self.peek() {
199 if t.token_type == "ident" && !self.is_clause_keyword(t) {
200 table_alias = Some(self.advance().value);
201 }
202 }
203 }
204
205 let mut joins = Vec::new();
207 loop {
208 let jt = if self.match_keyword("LEFT") {
209 self.expect("keyword", Some("JOIN"))?;
210 JoinType::Left
211 } else if self.match_keyword("JOIN") {
212 JoinType::Inner
213 } else {
214 break;
215 };
216 let join_table = self.parse_ident()?;
217 let mut join_alias = None;
218 if let Some(t) = self.peek() {
219 if t.token_type == "ident" && !self.is_clause_keyword(t) {
220 join_alias = Some(self.advance().value);
221 }
222 }
223 self.expect("keyword", Some("ON"))?;
224 let condition = self.parse_or_expr()?;
225 joins.push(JoinClause {
226 join_type: jt,
227 table: join_table,
228 alias: join_alias,
229 condition,
230 });
231 }
232
233 let mut where_clause = None;
234 if self.match_keyword("WHERE") {
235 where_clause = Some(self.parse_or_expr()?);
236 }
237
238 let mut group_by = None;
239 if self.match_keyword("GROUP") {
240 self.expect("keyword", Some("BY"))?;
241 let mut cols = vec![self.parse_ident()?];
242 while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
243 self.advance();
244 cols.push(self.parse_ident()?);
245 }
246 group_by = Some(cols);
247 }
248
249 let mut having = None;
250 if self.match_keyword("HAVING") {
251 having = Some(self.parse_or_expr()?);
252 }
253
254 let mut order_by = None;
255 if self.match_keyword("ORDER") {
256 self.expect("keyword", Some("BY"))?;
257 order_by = Some(self.parse_order_by()?);
258 }
259
260 let mut limit = None;
261 if self.match_keyword("LIMIT") {
262 let t = self.expect("number", None)?;
263 limit = Some(t.value.parse::<i64>().map_err(|_| {
264 MdqlError::QueryParse(format!("Invalid LIMIT value: {}", t.value))
265 })?);
266 }
267
268 Ok(SelectQuery {
269 columns,
270 table,
271 table_alias,
272 subquery,
273 joins,
274 where_clause,
275 group_by,
276 having,
277 order_by,
278 limit,
279 })
280 }
281
282 fn parse_insert(&mut self) -> Result<InsertQuery, MdqlError> {
283 self.expect("keyword", Some("INSERT"))?;
284 self.expect("keyword", Some("INTO"))?;
285 let table = self.parse_ident()?;
286
287 self.expect("op", Some("("))?;
288 let mut columns = vec![self.parse_ident()?];
289 while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
290 self.advance();
291 columns.push(self.parse_ident()?);
292 }
293 self.expect("op", Some(")"))?;
294
295 self.expect("keyword", Some("VALUES"))?;
296
297 self.expect("op", Some("("))?;
298 let mut values = vec![self.parse_value()?];
299 while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
300 self.advance();
301 values.push(self.parse_value()?);
302 }
303 self.expect("op", Some(")"))?;
304
305 if columns.len() != values.len() {
306 return Err(MdqlError::QueryParse(format!(
307 "Column count ({}) does not match value count ({})",
308 columns.len(),
309 values.len()
310 )));
311 }
312
313 self.expect_end()?;
314 Ok(InsertQuery {
315 table,
316 columns,
317 values,
318 })
319 }
320
321 fn parse_update(&mut self) -> Result<UpdateQuery, MdqlError> {
322 self.expect("keyword", Some("UPDATE"))?;
323 let table = self.parse_ident()?;
324 self.expect("keyword", Some("SET"))?;
325
326 let mut assignments = Vec::new();
327 let col = self.parse_ident()?;
328 self.expect("op", Some("="))?;
329 let val = self.parse_value()?;
330 assignments.push((col, val));
331
332 while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
333 self.advance();
334 let col = self.parse_ident()?;
335 self.expect("op", Some("="))?;
336 let val = self.parse_value()?;
337 assignments.push((col, val));
338 }
339
340 let mut where_clause = None;
341 if self.match_keyword("WHERE") {
342 where_clause = Some(self.parse_or_expr()?);
343 }
344
345 self.expect_end()?;
346 Ok(UpdateQuery {
347 table,
348 assignments,
349 where_clause,
350 })
351 }
352
353 fn parse_delete(&mut self) -> Result<DeleteQuery, MdqlError> {
354 self.expect("keyword", Some("DELETE"))?;
355 self.expect("keyword", Some("FROM"))?;
356 let table = self.parse_ident()?;
357
358 let mut where_clause = None;
359 if self.match_keyword("WHERE") {
360 where_clause = Some(self.parse_or_expr()?);
361 }
362
363 self.expect_end()?;
364 Ok(DeleteQuery {
365 table,
366 where_clause,
367 })
368 }
369
370 fn parse_alter(&mut self) -> Result<Statement, MdqlError> {
371 self.expect("keyword", Some("ALTER"))?;
372 self.expect("keyword", Some("TABLE"))?;
373 let table = self.parse_ident()?;
374
375 let t = self.peek().ok_or_else(|| {
376 MdqlError::QueryParse("Expected RENAME, DROP, or MERGE after table name".into())
377 })?;
378
379 match (t.token_type.as_str(), t.value.as_str()) {
380 ("keyword", "RENAME") => {
381 self.advance();
382 self.expect("keyword", Some("FIELD"))?;
383 let old_name = self.parse_string_or_ident()?;
384 self.expect("keyword", Some("TO"))?;
385 let new_name = self.parse_string_or_ident()?;
386 self.expect_end()?;
387 Ok(Statement::AlterRename(AlterRenameFieldQuery {
388 table,
389 old_name,
390 new_name,
391 }))
392 }
393 ("keyword", "DROP") => {
394 self.advance();
395 self.expect("keyword", Some("FIELD"))?;
396 let field_name = self.parse_string_or_ident()?;
397 self.expect_end()?;
398 Ok(Statement::AlterDrop(AlterDropFieldQuery {
399 table,
400 field_name,
401 }))
402 }
403 ("keyword", "MERGE") => {
404 self.advance();
405 self.expect("keyword", Some("FIELDS"))?;
406 let mut sources = vec![self.parse_string_or_ident()?];
407 while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
408 self.advance();
409 sources.push(self.parse_string_or_ident()?);
410 }
411 self.expect("keyword", Some("INTO"))?;
412 let target = self.parse_string_or_ident()?;
413 self.expect_end()?;
414 Ok(Statement::AlterMerge(AlterMergeFieldsQuery {
415 table,
416 sources,
417 into: target,
418 }))
419 }
420 _ => Err(MdqlError::QueryParse(format!(
421 "Expected RENAME, DROP, or MERGE, got '{}'",
422 t.raw
423 ))),
424 }
425 }
426
427 fn parse_create_view(&mut self) -> Result<Statement, MdqlError> {
428 self.expect("keyword", Some("CREATE"))?;
429 self.expect("keyword", Some("VIEW"))?;
430 let view_name = self.parse_ident()?;
431
432 let columns = if self.peek().map_or(false, |t| t.token_type == "op" && t.value == "(") {
433 self.advance();
434 let mut cols = vec![self.parse_ident()?];
435 while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
436 self.advance();
437 cols.push(self.parse_ident()?);
438 }
439 self.expect("op", Some(")"))?;
440 Some(cols)
441 } else {
442 None
443 };
444
445 self.expect("keyword", Some("AS"))?;
446 let query = Box::new(self.parse_select()?);
447 self.expect_end()?;
448
449 Ok(Statement::CreateView(CreateViewQuery {
450 view_name,
451 columns,
452 query,
453 }))
454 }
455
456 fn parse_drop_view(&mut self) -> Result<Statement, MdqlError> {
457 self.expect("keyword", Some("DROP"))?;
458 self.expect("keyword", Some("VIEW"))?;
459 let view_name = self.parse_ident()?;
460 self.expect_end()?;
461 Ok(Statement::DropView(DropViewQuery { view_name }))
462 }
463
464 fn parse_string_or_ident(&mut self) -> Result<String, MdqlError> {
465 let t = self.peek().ok_or_else(|| {
466 MdqlError::QueryParse("Expected field name, got end of query".into())
467 })?;
468 match t.token_type.as_str() {
469 "string" => {
470 let v = self.advance().value;
471 Ok(v)
472 }
473 "ident" | "keyword" => {
474 let v = self.advance().value;
475 Ok(v)
476 }
477 _ => Err(MdqlError::QueryParse(format!(
478 "Expected field name, got '{}'",
479 t.raw
480 ))),
481 }
482 }
483
484 fn parse_columns(&mut self) -> Result<ColumnList, MdqlError> {
485 if let Some(t) = self.peek() {
486 if t.token_type == "op" && t.value == "*" {
487 self.advance();
488 return Ok(ColumnList::All);
489 }
490 }
491
492 let mut exprs = vec![self.parse_select_expr()?];
493 while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
494 self.advance();
495 exprs.push(self.parse_select_expr()?);
496 }
497 Ok(ColumnList::Named(exprs))
498 }
499
500 fn peek_is_agg_func(&self) -> bool {
501 let t = match self.peek() {
502 Some(t) => t,
503 None => return false,
504 };
505 let name_upper = t.value.to_uppercase();
506 if !AGG_FUNCS.contains(&name_upper.as_str()) {
507 return false;
508 }
509 self.tokens
511 .get(self.pos + 1)
512 .map_or(false, |next| next.token_type == "op" && next.value == "(")
513 }
514
515 fn parse_select_expr(&mut self) -> Result<SelectExpr, MdqlError> {
516 let _t = self.peek().ok_or_else(|| {
517 MdqlError::QueryParse("Expected column or aggregate, got end of query".into())
518 })?;
519
520 let expr = self.parse_additive()?;
521
522 let alias = if self.match_keyword("AS") {
523 Some(self.parse_ident()?)
524 } else if self.peek().map_or(false, |t| {
525 t.token_type == "ident" && !self.is_clause_keyword(t)
526 }) {
527 Some(self.advance().value)
528 } else {
529 None
530 };
531
532 if let Expr::Aggregate { func, arg, arg_expr } = expr {
534 return Ok(SelectExpr::Aggregate {
535 func,
536 arg,
537 arg_expr: arg_expr.map(|e| *e),
538 alias,
539 });
540 }
541
542 if alias.is_none() {
543 if let Expr::Column(name) = &expr {
544 return Ok(SelectExpr::Column(name.clone()));
545 }
546 }
547
548 Ok(SelectExpr::Expr { expr, alias })
549 }
550
551 fn peek_is_additive_op(&self) -> bool {
554 self.peek().map_or(false, |t| {
555 t.token_type == "op" && (t.value == "+" || t.value == "-")
556 })
557 }
558
559 fn peek_is_multiplicative_op(&self) -> bool {
560 self.peek().map_or(false, |t| {
561 t.token_type == "op" && (t.value == "*" || t.value == "/" || t.value == "%")
562 })
563 }
564
565 fn parse_additive(&mut self) -> Result<Expr, MdqlError> {
566 let mut left = self.parse_multiplicative()?;
567 while self.peek_is_additive_op() {
568 let op_tok = self.advance();
569 let is_sub = op_tok.value == "-";
570
571 if self.peek().map_or(false, |t| t.token_type == "keyword" && t.value == "INTERVAL") {
573 self.advance(); let days_expr = self.parse_multiplicative()?;
575 if !self.match_keyword("DAY") && !self.match_keyword("DAYS") {
577 return Err(MdqlError::QueryParse("Expected DAY after INTERVAL value".into()));
578 }
579 let days = if is_sub {
580 Expr::UnaryMinus(Box::new(days_expr))
581 } else {
582 days_expr
583 };
584 left = Expr::DateAdd {
585 date: Box::new(left),
586 days: Box::new(days),
587 };
588 continue;
589 }
590
591 let op = match op_tok.value.as_str() {
592 "+" => ArithOp::Add,
593 "-" => ArithOp::Sub,
594 _ => unreachable!(),
595 };
596 let right = self.parse_multiplicative()?;
597 left = Expr::BinaryOp {
598 left: Box::new(left),
599 op,
600 right: Box::new(right),
601 };
602 }
603 Ok(left)
604 }
605
606 fn parse_multiplicative(&mut self) -> Result<Expr, MdqlError> {
607 let mut left = self.parse_unary()?;
608 while self.peek_is_multiplicative_op() {
609 let op_tok = self.advance();
610 let op = match op_tok.value.as_str() {
611 "*" => ArithOp::Mul,
612 "/" => ArithOp::Div,
613 "%" => ArithOp::Mod,
614 _ => unreachable!(),
615 };
616 let right = self.parse_unary()?;
617 left = Expr::BinaryOp {
618 left: Box::new(left),
619 op,
620 right: Box::new(right),
621 };
622 }
623 Ok(left)
624 }
625
626 fn parse_unary(&mut self) -> Result<Expr, MdqlError> {
627 if self.peek().map_or(false, |t| t.token_type == "op" && t.value == "-") {
628 self.advance();
629 let inner = self.parse_atom()?;
630 match inner {
632 Expr::Literal(SqlValue::Int(n)) => Ok(Expr::Literal(SqlValue::Int(-n))),
633 Expr::Literal(SqlValue::Float(f)) => Ok(Expr::Literal(SqlValue::Float(-f))),
634 _ => Ok(Expr::UnaryMinus(Box::new(inner))),
635 }
636 } else {
637 self.parse_atom()
638 }
639 }
640
641 fn parse_atom(&mut self) -> Result<Expr, MdqlError> {
642 if self.peek_is_agg_func() {
643 return self.parse_agg_expr();
644 }
645
646 let t = self.peek().ok_or_else(|| {
647 MdqlError::QueryParse("Expected expression, got end of query".into())
648 })?;
649
650 match t.token_type.as_str() {
651 "number" => {
652 let v = self.advance().value;
653 if v.contains('.') {
654 let f: f64 = v.parse().map_err(|_| {
655 MdqlError::QueryParse(format!("Invalid float: {}", v))
656 })?;
657 Ok(Expr::Literal(SqlValue::Float(f)))
658 } else {
659 let n: i64 = v.parse().map_err(|_| {
660 MdqlError::QueryParse(format!("Invalid int: {}", v))
661 })?;
662 Ok(Expr::Literal(SqlValue::Int(n)))
663 }
664 }
665 "string" => {
666 let v = self.advance().value;
667 Ok(Expr::Literal(SqlValue::String(v)))
668 }
669 "keyword" if t.value == "NULL" => {
670 self.advance();
671 Ok(Expr::Literal(SqlValue::Null))
672 }
673 "keyword" if t.value == "CASE" => {
674 self.parse_case_expr()
675 }
676 "keyword" if t.value == "CURRENT_DATE" => {
677 self.advance();
678 Ok(Expr::CurrentDate)
679 }
680 "keyword" if t.value == "CURRENT_TIMESTAMP" => {
681 self.advance();
682 Ok(Expr::CurrentTimestamp)
683 }
684 "keyword" if t.value == "DATEDIFF" => {
685 self.advance();
686 self.expect("op", Some("("))?;
687 let left = self.parse_additive()?;
688 self.expect("op", Some(","))?;
689 let right = self.parse_additive()?;
690 self.expect("op", Some(")"))?;
691 Ok(Expr::DateDiff { left: Box::new(left), right: Box::new(right) })
692 }
693 "op" if t.value == "(" => {
694 self.advance();
695 let expr = self.parse_additive()?;
696 self.expect("op", Some(")"))?;
697 Ok(expr)
698 }
699 "ident" => {
700 let name = self.advance().value;
701 Ok(Expr::Column(name))
702 }
703 "keyword" if !Self::is_reserved_keyword(&t.value) => {
704 let name = self.advance().value;
705 Ok(Expr::Column(name))
706 }
707 _ => Err(MdqlError::QueryParse(format!(
708 "Expected expression, got '{}'",
709 t.raw
710 ))),
711 }
712 }
713
714 fn parse_case_expr(&mut self) -> Result<Expr, MdqlError> {
715 self.expect("keyword", Some("CASE"))?;
716 let mut whens = Vec::new();
717 while self.match_keyword("WHEN") {
718 let condition = self.parse_or_expr()?;
719 self.expect("keyword", Some("THEN"))?;
720 let result = self.parse_additive()?;
721 whens.push((condition, Box::new(result)));
722 }
723 if whens.is_empty() {
724 return Err(MdqlError::QueryParse("CASE requires at least one WHEN clause".into()));
725 }
726 let else_expr = if self.match_keyword("ELSE") {
727 Some(Box::new(self.parse_additive()?))
728 } else {
729 None
730 };
731 self.expect("keyword", Some("END"))?;
732 Ok(Expr::Case { whens, else_expr })
733 }
734
735 fn parse_agg_expr(&mut self) -> Result<Expr, MdqlError> {
736 let func_name = self.advance().value.to_uppercase();
737 let func = match func_name.as_str() {
738 "COUNT" => AggFunc::Count,
739 "SUM" => AggFunc::Sum,
740 "AVG" => AggFunc::Avg,
741 "MIN" => AggFunc::Min,
742 "MAX" => AggFunc::Max,
743 _ => unreachable!(),
744 };
745 self.expect("op", Some("("))?;
746 let (arg, arg_expr) = if self.peek().map_or(false, |t| t.token_type == "op" && t.value == "*") {
747 self.advance();
748 ("*".to_string(), None)
749 } else {
750 let expr = self.parse_additive()?;
751 if let Expr::Column(name) = &expr {
752 (name.clone(), None)
753 } else {
754 (expr.display_name(), Some(Box::new(expr)))
755 }
756 };
757 self.expect("op", Some(")"))?;
758 Ok(Expr::Aggregate { func, arg, arg_expr })
759 }
760
761 fn parse_ident(&mut self) -> Result<String, MdqlError> {
762 let t = self.peek().ok_or_else(|| {
763 MdqlError::QueryParse("Expected identifier, got end of query".into())
764 })?;
765 match t.token_type.as_str() {
766 "ident" | "keyword" => {
767 let v = self.advance().value;
768 Ok(v)
769 }
770 _ => Err(MdqlError::QueryParse(format!(
771 "Expected identifier, got '{}'",
772 t.raw
773 ))),
774 }
775 }
776
777 fn parse_or_expr(&mut self) -> Result<WhereClause, MdqlError> {
778 let mut left = self.parse_and_expr()?;
779 while self.match_keyword("OR") {
780 let right = self.parse_and_expr()?;
781 left = WhereClause::BoolOp(BoolOp {
782 op: BoolOpKind::Or,
783 left: Box::new(left),
784 right: Box::new(right),
785 });
786 }
787 Ok(left)
788 }
789
790 fn parse_and_expr(&mut self) -> Result<WhereClause, MdqlError> {
791 let mut left = self.parse_comparison()?;
792 while self.match_keyword("AND") {
793 let right = self.parse_comparison()?;
794 left = WhereClause::BoolOp(BoolOp {
795 op: BoolOpKind::And,
796 left: Box::new(left),
797 right: Box::new(right),
798 });
799 }
800 Ok(left)
801 }
802
803 fn parse_comparison(&mut self) -> Result<WhereClause, MdqlError> {
804 if self.peek().map_or(false, |t| t.token_type == "op" && t.value == "(") {
806 let saved_pos = self.pos;
808 self.advance();
809 let result = self.parse_or_expr();
811 if result.is_ok() && self.peek().map_or(false, |t| t.token_type == "op" && t.value == ")") {
812 self.advance();
813 return result;
814 }
815 self.pos = saved_pos;
817 }
818
819 let left_expr = self.parse_additive()?;
821
822 let col = left_expr.as_column().unwrap_or("").to_string();
824
825 if self.match_keyword("IS") {
827 if self.match_keyword("NOT") {
828 self.expect("keyword", Some("NULL"))?;
829 return Ok(WhereClause::Comparison(Comparison {
830 column: col,
831 op: CmpOp::IsNotNull,
832 value: None,
833 left_expr: Some(left_expr),
834 right_expr: None,
835 }));
836 }
837 self.expect("keyword", Some("NULL"))?;
838 return Ok(WhereClause::Comparison(Comparison {
839 column: col,
840 op: CmpOp::IsNull,
841 value: None,
842 left_expr: Some(left_expr),
843 right_expr: None,
844 }));
845 }
846
847 if self.match_keyword("IN") {
849 self.expect("op", Some("("))?;
850 let mut values = vec![self.parse_value()?];
851 while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
852 self.advance();
853 values.push(self.parse_value()?);
854 }
855 self.expect("op", Some(")"))?;
856 return Ok(WhereClause::Comparison(Comparison {
857 column: col,
858 op: CmpOp::In,
859 value: Some(SqlValue::List(values)),
860 left_expr: Some(left_expr),
861 right_expr: None,
862 }));
863 }
864
865 if self.match_keyword("LIKE") {
867 let val = self.parse_value()?;
868 return Ok(WhereClause::Comparison(Comparison {
869 column: col,
870 op: CmpOp::Like,
871 value: Some(val),
872 left_expr: Some(left_expr),
873 right_expr: None,
874 }));
875 }
876
877 if self.match_keyword("NOT") {
879 if self.match_keyword("LIKE") {
880 let val = self.parse_value()?;
881 return Ok(WhereClause::Comparison(Comparison {
882 column: col,
883 op: CmpOp::NotLike,
884 value: Some(val),
885 left_expr: Some(left_expr),
886 right_expr: None,
887 }));
888 }
889 return Err(MdqlError::QueryParse("Expected LIKE after NOT".into()));
890 }
891
892 if let Some(t) = self.peek() {
894 if t.token_type == "op" && ["=", "!=", "<", ">", "<=", ">="].contains(&t.value.as_str())
895 {
896 let op_str = self.advance().value;
897 let op = match op_str.as_str() {
898 "=" => CmpOp::Eq,
899 "!=" => CmpOp::Ne,
900 "<" => CmpOp::Lt,
901 ">" => CmpOp::Gt,
902 "<=" => CmpOp::Le,
903 ">=" => CmpOp::Ge,
904 _ => unreachable!(),
905 };
906 let right_expr = self.parse_additive()?;
908 let value = match &right_expr {
910 Expr::Literal(v) => Some(v.clone()),
911 _ => None,
912 };
913 return Ok(WhereClause::Comparison(Comparison {
914 column: col,
915 op,
916 value,
917 left_expr: Some(left_expr),
918 right_expr: Some(right_expr),
919 }));
920 }
921 }
922
923 let got = self.peek().map_or("end".to_string(), |t| t.raw.clone());
924 Err(MdqlError::QueryParse(format!(
925 "Expected operator after '{}', got '{}'",
926 left_expr.display_name(), got
927 )))
928 }
929
930 fn parse_value(&mut self) -> Result<SqlValue, MdqlError> {
931 let t = self.peek().ok_or_else(|| {
932 MdqlError::QueryParse("Expected value, got end of query".into())
933 })?;
934 match t.token_type.as_str() {
935 "string" => {
936 let v = self.advance().value;
937 Ok(SqlValue::String(v))
938 }
939 "number" => {
940 let v = self.advance().value;
941 if v.contains('.') {
942 Ok(SqlValue::Float(v.parse().map_err(|_| {
943 MdqlError::QueryParse(format!("Invalid float: {}", v))
944 })?))
945 } else {
946 Ok(SqlValue::Int(v.parse().map_err(|_| {
947 MdqlError::QueryParse(format!("Invalid int: {}", v))
948 })?))
949 }
950 }
951 "keyword" if t.value == "NULL" => {
952 self.advance();
953 Ok(SqlValue::Null)
954 }
955 _ => Err(MdqlError::QueryParse(format!(
956 "Expected value, got '{}'",
957 t.raw
958 ))),
959 }
960 }
961
962 fn parse_order_by(&mut self) -> Result<Vec<OrderSpec>, MdqlError> {
963 let mut specs = vec![self.parse_order_spec()?];
964 while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
965 self.advance();
966 specs.push(self.parse_order_spec()?);
967 }
968 Ok(specs)
969 }
970
971 fn parse_order_spec(&mut self) -> Result<OrderSpec, MdqlError> {
972 let expr = self.parse_additive()?;
973 let col = expr.as_column().unwrap_or("").to_string();
974 let descending = if self.match_keyword("DESC") {
975 true
976 } else {
977 self.match_keyword("ASC");
978 false
979 };
980 Ok(OrderSpec {
981 column: col,
982 expr: Some(expr),
983 descending,
984 })
985 }
986
987 fn is_clause_keyword(&self, t: &Token) -> bool {
988 t.token_type == "keyword"
989 && ["WHERE", "ORDER", "LIMIT", "JOIN", "LEFT", "ON", "GROUP"].contains(&t.value.as_str())
990 }
991
992 fn is_reserved_keyword(kw: &str) -> bool {
994 matches!(kw,
995 "AS" | "FROM" | "WHERE" | "AND" | "OR" | "ORDER" | "BY"
996 | "ASC" | "DESC" | "LIMIT" | "JOIN" | "ON" | "GROUP"
997 | "SELECT" | "INSERT" | "INTO" | "VALUES" | "UPDATE" | "SET"
998 | "DELETE" | "ALTER" | "TABLE" | "IS" | "NOT" | "IN" | "LIKE"
999 | "RENAME" | "FIELD" | "TO" | "DROP" | "MERGE" | "FIELDS"
1000 | "CASE" | "WHEN" | "THEN" | "ELSE" | "END"
1001 | "HAVING" | "INTERVAL" | "DAY" | "DAYS"
1002 | "CURRENT_DATE" | "CURRENT_TIMESTAMP" | "DATEDIFF"
1003 | "CREATE" | "VIEW" | "CASCADE" | "RESTRICT"
1004 )
1005 }
1006
1007 fn expect_end(&self) -> Result<(), MdqlError> {
1008 if let Some(t) = self.peek() {
1009 return Err(MdqlError::QueryParse(format!(
1010 "Unexpected token '{}' at position {}",
1011 t.raw, self.pos
1012 )));
1013 }
1014 Ok(())
1015 }
1016}
1017
1018pub fn parse_query(sql: &str) -> crate::errors::Result<Statement> {
1019 let tokens = tokenize(sql);
1020 if tokens.is_empty() {
1021 return Err(MdqlError::QueryParse("Empty query".into()));
1022 }
1023 let mut parser = Parser::new(tokens);
1024 parser.parse_statement()
1025}
1026
1027#[cfg(test)]
1028mod tests {
1029 use super::*;
1030
1031 #[test]
1032 fn test_simple_select() {
1033 let stmt = parse_query("SELECT title, status FROM strategies").unwrap();
1034 if let Statement::Select(q) = stmt {
1035 assert_eq!(q.columns, ColumnList::Named(vec![SelectExpr::Column("title".into()), SelectExpr::Column("status".into())]));
1036 assert_eq!(q.table, "strategies");
1037 } else {
1038 panic!("Expected Select");
1039 }
1040 }
1041
1042 #[test]
1043 fn test_select_star() {
1044 let stmt = parse_query("SELECT * FROM test").unwrap();
1045 if let Statement::Select(q) = stmt {
1046 assert_eq!(q.columns, ColumnList::All);
1047 } else {
1048 panic!("Expected Select");
1049 }
1050 }
1051
1052 #[test]
1053 fn test_where_clause() {
1054 let stmt = parse_query("SELECT title FROM test WHERE count > 5").unwrap();
1055 if let Statement::Select(q) = stmt {
1056 assert!(q.where_clause.is_some());
1057 } else {
1058 panic!("Expected Select");
1059 }
1060 }
1061
1062 #[test]
1063 fn test_order_by() {
1064 let stmt =
1065 parse_query("SELECT title FROM test ORDER BY composite DESC, title ASC").unwrap();
1066 if let Statement::Select(q) = stmt {
1067 let ob = q.order_by.unwrap();
1068 assert_eq!(ob.len(), 2);
1069 assert!(ob[0].descending);
1070 assert!(!ob[1].descending);
1071 } else {
1072 panic!("Expected Select");
1073 }
1074 }
1075
1076 #[test]
1077 fn test_limit() {
1078 let stmt = parse_query("SELECT * FROM test LIMIT 10").unwrap();
1079 if let Statement::Select(q) = stmt {
1080 assert_eq!(q.limit, Some(10));
1081 } else {
1082 panic!("Expected Select");
1083 }
1084 }
1085
1086 #[test]
1087 fn test_insert() {
1088 let stmt = parse_query(
1089 "INSERT INTO test (title, count) VALUES ('Hello', 42)",
1090 )
1091 .unwrap();
1092 if let Statement::Insert(q) = stmt {
1093 assert_eq!(q.table, "test");
1094 assert_eq!(q.columns, vec!["title", "count"]);
1095 assert_eq!(q.values[0], SqlValue::String("Hello".into()));
1096 assert_eq!(q.values[1], SqlValue::Int(42));
1097 } else {
1098 panic!("Expected Insert");
1099 }
1100 }
1101
1102 #[test]
1103 fn test_update() {
1104 let stmt = parse_query("UPDATE test SET status = 'KILLED' WHERE path = 'a.md'").unwrap();
1105 if let Statement::Update(q) = stmt {
1106 assert_eq!(q.table, "test");
1107 assert_eq!(q.assignments.len(), 1);
1108 assert!(q.where_clause.is_some());
1109 } else {
1110 panic!("Expected Update");
1111 }
1112 }
1113
1114 #[test]
1115 fn test_delete() {
1116 let stmt = parse_query("DELETE FROM test WHERE status = 'draft'").unwrap();
1117 if let Statement::Delete(q) = stmt {
1118 assert_eq!(q.table, "test");
1119 assert!(q.where_clause.is_some());
1120 } else {
1121 panic!("Expected Delete");
1122 }
1123 }
1124
1125 #[test]
1126 fn test_alter_rename() {
1127 let stmt =
1128 parse_query("ALTER TABLE test RENAME FIELD 'Summary' TO 'Overview'").unwrap();
1129 if let Statement::AlterRename(q) = stmt {
1130 assert_eq!(q.old_name, "Summary");
1131 assert_eq!(q.new_name, "Overview");
1132 } else {
1133 panic!("Expected AlterRename");
1134 }
1135 }
1136
1137 #[test]
1138 fn test_alter_drop() {
1139 let stmt = parse_query("ALTER TABLE test DROP FIELD 'Details'").unwrap();
1140 if let Statement::AlterDrop(q) = stmt {
1141 assert_eq!(q.field_name, "Details");
1142 } else {
1143 panic!("Expected AlterDrop");
1144 }
1145 }
1146
1147 #[test]
1148 fn test_alter_merge() {
1149 let stmt = parse_query(
1150 "ALTER TABLE test MERGE FIELDS 'Entry Rules', 'Exit Rules' INTO 'Trading Rules'",
1151 )
1152 .unwrap();
1153 if let Statement::AlterMerge(q) = stmt {
1154 assert_eq!(q.sources, vec!["Entry Rules", "Exit Rules"]);
1155 assert_eq!(q.into, "Trading Rules");
1156 } else {
1157 panic!("Expected AlterMerge");
1158 }
1159 }
1160
1161 #[test]
1162 fn test_backtick_ident() {
1163 let stmt = parse_query("SELECT `Structural Mechanism` FROM test").unwrap();
1164 if let Statement::Select(q) = stmt {
1165 assert_eq!(
1166 q.columns,
1167 ColumnList::Named(vec![SelectExpr::Column("Structural Mechanism".into())])
1168 );
1169 } else {
1170 panic!("Expected Select");
1171 }
1172 }
1173
1174 #[test]
1175 fn test_like_operator() {
1176 let stmt = parse_query("SELECT title FROM test WHERE categories LIKE '%defi%'").unwrap();
1177 if let Statement::Select(q) = stmt {
1178 if let Some(WhereClause::Comparison(c)) = q.where_clause {
1179 assert_eq!(c.op, CmpOp::Like);
1180 assert_eq!(c.value, Some(SqlValue::String("%defi%".into())));
1181 } else {
1182 panic!("Expected LIKE comparison");
1183 }
1184 } else {
1185 panic!("Expected Select");
1186 }
1187 }
1188
/// A parenthesised value list after `IN` yields an `In` comparison.
#[test]
fn test_in_operator() {
    let parsed =
        parse_query("SELECT * FROM test WHERE status IN ('ACTIVE', 'LIVE')").unwrap();
    let Statement::Select(q) = parsed else {
        panic!("Expected Select")
    };
    let Some(WhereClause::Comparison(c)) = q.where_clause else {
        panic!("Expected IN comparison")
    };
    assert_eq!(c.op, CmpOp::In);
}
1203
/// `IS NULL` is parsed as a dedicated comparison operator.
#[test]
fn test_is_null() {
    let parsed = parse_query("SELECT * FROM test WHERE title IS NULL").unwrap();
    let Statement::Select(q) = parsed else {
        panic!("Expected Select")
    };
    let Some(WhereClause::Comparison(c)) = q.where_clause else {
        panic!("Expected IS NULL comparison")
    };
    assert_eq!(c.op, CmpOp::IsNull);
}
1217
/// Mixed AND/OR conditions parse into a WHERE clause that keeps both
/// connectives and their operands.
///
/// The clause is rendered back through `where_clause_to_sql`, so the test
/// does not depend on the internal shape of compound WHERE nodes.
#[test]
fn test_and_or() {
    let stmt = parse_query(
        "SELECT * FROM test WHERE status = 'ACTIVE' AND count > 5 OR title LIKE '%test%'",
    )
    .unwrap();
    if let Statement::Select(q) = stmt {
        // `is_some()` alone would pass even if operands were silently dropped.
        let wc = q.where_clause.as_ref().expect("Expected WHERE clause");
        let sql = where_clause_to_sql(wc);
        assert!(sql.contains("status = 'ACTIVE'"));
        assert!(sql.contains("AND"));
        assert!(sql.contains("OR"));
    } else {
        panic!("Expected Select");
    }
}
1230
/// A single inner JOIN records both the base table/alias and the joined table/alias.
#[test]
fn test_join() {
    let sql = "SELECT s.title, b.sharpe FROM strategies s JOIN backtests b ON b.strategy = s.path";
    let Statement::Select(q) = parse_query(sql).unwrap() else {
        panic!("Expected Select")
    };
    assert_eq!(q.table, "strategies");
    assert_eq!(q.table_alias, Some("s".into()));
    assert_eq!(q.joins.len(), 1);
    let join = &q.joins[0];
    assert_eq!(join.table, "backtests");
    assert_eq!(join.alias, Some("b".into()));
}
1248
/// Two chained inner JOINs each keep their own table, alias and ON condition.
#[test]
fn test_multi_join() {
    let sql = "SELECT s.title, b.sharpe, c.verdict FROM strategies s JOIN backtests b ON b.strategy = s.path JOIN critiques c ON c.strategy = s.path";
    let Statement::Select(q) = parse_query(sql).unwrap() else {
        panic!("Expected Select")
    };
    assert_eq!(q.table, "strategies");
    assert_eq!(q.table_alias, Some("s".into()));
    assert_eq!(q.joins.len(), 2);
    // (table, alias, rendered ON condition) for each join, in source order.
    let expected = [
        ("backtests", "b", "b.strategy = s.path"),
        ("critiques", "c", "c.strategy = s.path"),
    ];
    for (join, (table, alias, condition)) in q.joins.iter().zip(expected) {
        assert_eq!(join.table, table);
        assert_eq!(join.alias, Some(alias.into()));
        assert_eq!(where_clause_to_sql(&join.condition), condition);
    }
}
1269
/// `LEFT JOIN` is tagged with `JoinType::Left`.
#[test]
fn test_left_join() {
    let sql = "SELECT s.title, b.sharpe FROM strategies s LEFT JOIN backtests b ON b.strategy = s.path";
    let Statement::Select(q) = parse_query(sql).unwrap() else {
        panic!("Expected Select")
    };
    assert_eq!(q.joins.len(), 1);
    assert_eq!(q.joins[0].join_type, JoinType::Left);
    assert_eq!(q.joins[0].table, "backtests");
}
1284
/// A plain JOIN followed by a LEFT JOIN keeps each join's own type.
#[test]
fn test_mixed_join_types() {
    let sql = "SELECT s.title FROM strategies s JOIN backtests b ON b.strategy = s.path LEFT JOIN allocations a ON a.strategy = s.path";
    let Statement::Select(q) = parse_query(sql).unwrap() else {
        panic!("Expected Select")
    };
    assert_eq!(q.joins.len(), 2);
    assert_eq!(q.joins[0].join_type, JoinType::Inner);
    assert_eq!(q.joins[1].join_type, JoinType::Left);
}
1299
/// An ON clause joined with AND keeps both sub-conditions in the join condition.
#[test]
fn test_join_compound_and() {
    let query_text = "SELECT s.title FROM strategies s LEFT JOIN backtests b ON b.strategy = s.path AND b.mode = 'PAPER'";
    let Statement::Select(q) = parse_query(query_text).unwrap() else {
        panic!("Expected Select")
    };
    assert_eq!(q.joins.len(), 1);
    assert_eq!(q.joins[0].join_type, JoinType::Left);
    let sql = where_clause_to_sql(&q.joins[0].condition);
    assert!(sql.contains("b.strategy = s.path"));
    assert!(sql.contains("AND"));
    assert!(sql.contains("b.mode = 'PAPER'"));
}
1317
/// An ON clause joined with OR keeps both alternatives in the join condition.
#[test]
fn test_join_compound_or() {
    let stmt = parse_query(
        "SELECT * FROM a JOIN b ON a.id = b.id OR a.alt = b.id",
    )
    .unwrap();
    if let Statement::Select(q) = stmt {
        // Check the join count first, so a missing join fails with a clear
        // assertion instead of an index-out-of-bounds panic.
        assert_eq!(q.joins.len(), 1);
        assert_eq!(q.joins[0].table, "b");
        let sql = where_clause_to_sql(&q.joins[0].condition);
        assert!(sql.contains("a.id = b.id"));
        assert!(sql.contains("OR"));
        assert!(sql.contains("a.alt = b.id"));
    } else {
        panic!("Expected Select");
    }
}
1331
/// A compound ON clause stops at WHERE: the join keeps the AND, and WHERE is parsed separately.
#[test]
fn test_join_compound_with_where() {
    let sql = "SELECT s.title FROM strategies s JOIN backtests b ON b.strategy = s.path AND b.mode = 'PAPER' WHERE s.title = 'Alpha'";
    let Statement::Select(q) = parse_query(sql).unwrap() else {
        panic!("Expected Select")
    };
    assert_eq!(q.joins.len(), 1);
    assert!(q.where_clause.is_some());
    let join_sql = where_clause_to_sql(&q.joins[0].condition);
    assert!(join_sql.contains("AND"));
}
1347
/// An input with no tokens at all is rejected with an error rather than parsed.
#[test]
fn test_empty_query() {
    let outcome = parse_query("");
    assert!(outcome.is_err());
}
1352
/// `COUNT(*) AS cnt` alongside a plain column and GROUP BY.
#[test]
fn test_count_star() {
    let parsed =
        parse_query("SELECT status, COUNT(*) AS cnt FROM strategies GROUP BY status").unwrap();
    let Statement::Select(q) = parsed else {
        panic!("Expected Select")
    };
    let ColumnList::Named(exprs) = &q.columns else {
        panic!("Expected Named columns")
    };
    assert_eq!(exprs.len(), 2);
    assert_eq!(exprs[0], SelectExpr::Column("status".into()));
    assert!(matches!(&exprs[1], SelectExpr::Aggregate {
        func: AggFunc::Count,
        arg,
        alias: Some(a),
        ..
    } if arg == "*" && a == "cnt"));
    assert_eq!(q.group_by, Some(vec!["status".into()]));
}
1374
/// `count` used as an INSERT column name is an identifier, not the COUNT aggregate.
#[test]
fn test_count_column_as_ident() {
    let parsed = parse_query("INSERT INTO test (title, count) VALUES ('Hello', 42)").unwrap();
    let Statement::Insert(q) = parsed else {
        panic!("Expected Insert")
    };
    assert_eq!(q.columns, vec!["title", "count"]);
}
1385
/// Several aggregates with no GROUP BY each map to their own `AggFunc`.
#[test]
fn test_multiple_aggregates() {
    let sql = "SELECT MIN(composite), MAX(composite), AVG(composite) FROM strategies";
    let Statement::Select(q) = parse_query(sql).unwrap() else {
        panic!("Expected Select")
    };
    let ColumnList::Named(exprs) = &q.columns else {
        panic!("Expected Named columns")
    };
    assert_eq!(exprs.len(), 3);
    assert!(matches!(&exprs[0], SelectExpr::Aggregate { func: AggFunc::Min, .. }));
    assert!(matches!(&exprs[1], SelectExpr::Aggregate { func: AggFunc::Max, .. }));
    assert!(matches!(&exprs[2], SelectExpr::Aggregate { func: AggFunc::Avg, .. }));
    assert_eq!(q.group_by, None);
}
1403
/// `a + b` in the select list becomes an unaliased `Expr` with an `Add` binary op.
#[test]
fn test_select_arithmetic_expr() {
    let Statement::Select(q) = parse_query("SELECT a + b FROM test").unwrap() else {
        panic!("Expected Select")
    };
    let ColumnList::Named(exprs) = &q.columns else {
        panic!("Expected Named columns")
    };
    assert_eq!(exprs.len(), 1);
    assert!(matches!(&exprs[0], SelectExpr::Expr {
        expr: Expr::BinaryOp { op: ArithOp::Add, .. },
        alias: None,
    }));
}
1423
/// An `AS` alias on an arithmetic expression is stored and reported as its output name.
#[test]
fn test_select_arithmetic_with_alias() {
    let Statement::Select(q) = parse_query("SELECT a + b AS total FROM test").unwrap() else {
        panic!("Expected Select")
    };
    let ColumnList::Named(exprs) = &q.columns else {
        panic!("Expected Named columns")
    };
    assert_eq!(exprs.len(), 1);
    assert!(matches!(&exprs[0], SelectExpr::Expr {
        alias: Some(a),
        ..
    } if a == "total"));
    assert_eq!(exprs[0].output_name(), "total");
}
1442
/// `*` binds tighter than `+`: `a + b * c` parses as `a + (b * c)`.
#[test]
fn test_select_precedence() {
    let Statement::Select(q) = parse_query("SELECT a + b * c FROM test").unwrap() else {
        panic!("Expected Select")
    };
    let ColumnList::Named(exprs) = &q.columns else {
        panic!("Expected Named columns")
    };
    let SelectExpr::Expr { expr, .. } = &exprs[0] else {
        panic!("Expected Expr variant")
    };
    let Expr::BinaryOp { left, op, right } = expr else {
        panic!("Expected BinaryOp")
    };
    assert_eq!(*op, ArithOp::Add);
    assert!(matches!(left.as_ref(), Expr::Column(n) if n == "a"));
    assert!(matches!(right.as_ref(), Expr::BinaryOp { op: ArithOp::Mul, .. }));
}
1467
/// Parentheses override precedence: `(a + b) * c` has `Mul` at the root and `Add` on the left.
#[test]
fn test_select_parenthesized_expr() {
    let Statement::Select(q) = parse_query("SELECT (a + b) * c FROM test").unwrap() else {
        panic!("Expected Select")
    };
    let ColumnList::Named(exprs) = &q.columns else {
        panic!("Expected Named columns")
    };
    let SelectExpr::Expr { expr, .. } = &exprs[0] else {
        panic!("Expected Expr variant")
    };
    let Expr::BinaryOp { left, op, .. } = expr else {
        panic!("Expected BinaryOp")
    };
    assert_eq!(*op, ArithOp::Mul);
    assert!(matches!(left.as_ref(), Expr::BinaryOp { op: ArithOp::Add, .. }));
}
1491
/// A leading `-` before a column name becomes `Expr::UnaryMinus`.
#[test]
fn test_select_unary_minus() {
    let Statement::Select(q) = parse_query("SELECT -count FROM test").unwrap() else {
        panic!("Expected Select")
    };
    let ColumnList::Named(exprs) = &q.columns else {
        panic!("Expected Named columns")
    };
    assert!(matches!(&exprs[0], SelectExpr::Expr {
        expr: Expr::UnaryMinus(_),
        ..
    }));
}
1508
/// A leading `-` before a numeric literal folds into a negative `Int` literal.
#[test]
fn test_select_negative_literal() {
    let Statement::Select(q) = parse_query("SELECT -42 FROM test").unwrap() else {
        panic!("Expected Select")
    };
    let ColumnList::Named(exprs) = &q.columns else {
        panic!("Expected Named columns")
    };
    assert!(matches!(&exprs[0], SelectExpr::Expr {
        expr: Expr::Literal(SqlValue::Int(-42)),
        ..
    }));
}
1526
/// WHERE comparisons accept an arithmetic expression on the left-hand side.
#[test]
fn test_where_arithmetic_expr() {
    let Statement::Select(q) = parse_query("SELECT * FROM test WHERE a + b > 10").unwrap() else {
        panic!("Expected Select")
    };
    let Some(WhereClause::Comparison(c)) = q.where_clause else {
        panic!("Expected comparison")
    };
    assert_eq!(c.op, CmpOp::Gt);
    assert!(matches!(&c.left_expr, Some(Expr::BinaryOp { op: ArithOp::Add, .. })));
    assert!(matches!(&c.right_expr, Some(Expr::Literal(SqlValue::Int(10)))));
}
1542
/// WHERE comparisons accept arithmetic expressions on both sides.
#[test]
fn test_where_both_sides_expr() {
    let parsed = parse_query("SELECT * FROM test WHERE a * 2 > b + 1").unwrap();
    let Statement::Select(q) = parsed else {
        panic!("Expected Select")
    };
    let Some(WhereClause::Comparison(c)) = q.where_clause else {
        panic!("Expected comparison")
    };
    assert_eq!(c.op, CmpOp::Gt);
    assert!(matches!(&c.left_expr, Some(Expr::BinaryOp { op: ArithOp::Mul, .. })));
    assert!(matches!(&c.right_expr, Some(Expr::BinaryOp { op: ArithOp::Add, .. })));
}
1558
/// ORDER BY accepts an arithmetic expression followed by a DESC direction.
#[test]
fn test_order_by_expr() {
    let parsed = parse_query("SELECT * FROM test ORDER BY a + b DESC").unwrap();
    let Statement::Select(q) = parsed else {
        panic!("Expected Select")
    };
    let ob = q.order_by.unwrap();
    assert_eq!(ob.len(), 1);
    assert!(ob[0].descending);
    assert!(matches!(&ob[0].expr, Some(Expr::BinaryOp { op: ArithOp::Add, .. })));
}
1571
/// Each of `+ - * / %` maps to its own `ArithOp` variant, in select-list order.
#[test]
fn test_all_arithmetic_ops() {
    let sql = "SELECT a + b, a - b, a * b, a / b, a % b FROM test";
    let Statement::Select(q) = parse_query(sql).unwrap() else {
        panic!("Expected Select")
    };
    let ColumnList::Named(exprs) = &q.columns else {
        panic!("Expected Named columns")
    };
    assert_eq!(exprs.len(), 5);
    let expected_ops = [ArithOp::Add, ArithOp::Sub, ArithOp::Mul, ArithOp::Div, ArithOp::Mod];
    for (e, want) in exprs.iter().zip(expected_ops) {
        assert!(matches!(
            e,
            SelectExpr::Expr { expr: Expr::BinaryOp { op, .. }, .. } if *op == want
        ));
    }
}
1590
/// `count * 2 + 1` parses as `(count * 2) + 1`; the `count` column is not treated as an aggregate.
#[test]
fn test_column_with_literal_arithmetic() {
    let Statement::Select(q) = parse_query("SELECT count * 2 + 1 FROM test").unwrap() else {
        panic!("Expected Select")
    };
    let ColumnList::Named(exprs) = &q.columns else {
        panic!("Expected Named columns")
    };
    let SelectExpr::Expr { expr, .. } = &exprs[0] else {
        panic!("Expected Expr")
    };
    let Expr::BinaryOp { left, op, right } = expr else {
        panic!("Expected BinaryOp")
    };
    assert_eq!(*op, ArithOp::Add);
    assert!(matches!(right.as_ref(), Expr::Literal(SqlValue::Int(1))));
    assert!(matches!(left.as_ref(), Expr::BinaryOp { op: ArithOp::Mul, .. }));
}
1615
/// Plain columns and aliased expressions can be interleaved in one select list.
#[test]
fn test_mixed_columns_and_exprs() {
    let parsed = parse_query("SELECT title, a + b AS sum, count FROM test").unwrap();
    let Statement::Select(q) = parsed else {
        panic!("Expected Select")
    };
    let ColumnList::Named(exprs) = &q.columns else {
        panic!("Expected Named columns")
    };
    assert_eq!(exprs.len(), 3);
    assert_eq!(exprs[0], SelectExpr::Column("title".into()));
    assert!(matches!(&exprs[1], SelectExpr::Expr { alias: Some(a), .. } if a == "sum"));
    assert_eq!(exprs[2], SelectExpr::Column("count".into()));
}
1632
/// A single-branch CASE with ELSE becomes one `Expr::Case` select item.
#[test]
fn test_case_when_basic() {
    let sql = "SELECT CASE WHEN status = 'ACTIVE' THEN 1 ELSE 0 END FROM test";
    let Statement::Select(q) = parse_query(sql).unwrap() else {
        panic!("Expected Select")
    };
    let ColumnList::Named(exprs) = &q.columns else {
        panic!("Expected Named columns")
    };
    assert_eq!(exprs.len(), 1);
    assert!(matches!(&exprs[0], SelectExpr::Expr {
        expr: Expr::Case { .. },
        ..
    }));
}
1654
/// Each WHEN adds a branch; the ELSE arm is stored separately.
#[test]
fn test_case_when_multiple_branches() {
    let sql = "SELECT CASE WHEN x > 10 THEN 'high' WHEN x > 5 THEN 'mid' ELSE 'low' END FROM test";
    let Statement::Select(q) = parse_query(sql).unwrap() else {
        panic!("Expected Select")
    };
    let ColumnList::Named(exprs) = &q.columns else {
        panic!("Expected Named columns")
    };
    let SelectExpr::Expr { expr: Expr::Case { whens, else_expr }, .. } = &exprs[0] else {
        panic!("Expected Case expression")
    };
    assert_eq!(whens.len(), 2);
    assert!(else_expr.is_some());
}
1675
/// A CASE without ELSE leaves `else_expr` as `None`.
#[test]
fn test_case_when_no_else() {
    let sql = "SELECT CASE WHEN x = 1 THEN 'one' END FROM test";
    let Statement::Select(q) = parse_query(sql).unwrap() else {
        panic!("Expected Select")
    };
    let ColumnList::Named(exprs) = &q.columns else {
        panic!("Expected Named columns")
    };
    let SelectExpr::Expr { expr: Expr::Case { whens, else_expr }, .. } = &exprs[0] else {
        panic!("Expected Case expression")
    };
    assert_eq!(whens.len(), 1);
    assert!(else_expr.is_none());
}
1696
/// A CASE expression may be an aggregate's argument; the aggregate keeps its alias.
#[test]
fn test_case_when_in_aggregate() {
    let sql = "SELECT SUM(CASE WHEN side = 'BUY' THEN size ELSE -size END) AS net FROM orders GROUP BY token";
    let Statement::Select(q) = parse_query(sql).unwrap() else {
        panic!("Expected Select")
    };
    let ColumnList::Named(exprs) = &q.columns else {
        panic!("Expected Named columns")
    };
    assert_eq!(exprs.len(), 1);
    assert!(matches!(&exprs[0], SelectExpr::Aggregate {
        func: AggFunc::Sum,
        arg_expr: Some(Expr::Case { .. }),
        alias: Some(a),
        ..
    } if a == "net"));
}
1718
/// `AS` after `END` attaches an alias to the whole CASE expression.
#[test]
fn test_case_when_with_alias() {
    let sql = "SELECT CASE WHEN x > 0 THEN 'pos' ELSE 'neg' END AS sign FROM test";
    let Statement::Select(q) = parse_query(sql).unwrap() else {
        panic!("Expected Select")
    };
    let ColumnList::Named(exprs) = &q.columns else {
        panic!("Expected Named columns")
    };
    assert!(matches!(&exprs[0], SelectExpr::Expr {
        expr: Expr::Case { .. },
        alias: Some(a),
    } if a == "sign"));
}
1737
/// `CREATE VIEW name AS SELECT ...` wraps the parsed SELECT; no column list is recorded.
#[test]
fn test_create_view() {
    let sql = "CREATE VIEW live AS SELECT * FROM strategies WHERE status = 'LIVE'";
    let cv = match parse_query(sql).unwrap() {
        Statement::CreateView(cv) => cv,
        other => panic!("Expected CreateView, got {:?}", other),
    };
    assert_eq!(cv.view_name, "live");
    assert!(cv.columns.is_none());
    assert_eq!(cv.query.table, "strategies");
    assert!(cv.query.where_clause.is_some());
}
1750
/// An explicit `(a, b)` column list after the view name is captured.
#[test]
fn test_create_view_with_columns() {
    let parsed = parse_query("CREATE VIEW v1 (a, b) AS SELECT title, status FROM t").unwrap();
    let Statement::CreateView(cv) = parsed else {
        panic!("Expected CreateView")
    };
    assert_eq!(cv.view_name, "v1");
    assert_eq!(cv.columns, Some(vec!["a".into(), "b".into()]));
}
1761
/// `DROP VIEW name` yields a `DropView` statement carrying the view name.
#[test]
fn test_drop_view() {
    let dv = match parse_query("DROP VIEW live").unwrap() {
        Statement::DropView(dv) => dv,
        other => panic!("Expected DropView, got {:?}", other),
    };
    assert_eq!(dv.view_name, "live");
}
1771
/// Keywords are case-insensitive, while the view identifier keeps its original casing.
#[test]
fn test_create_view_case_insensitive() {
    let parsed = parse_query("create view My_View as select * from t").unwrap();
    let Statement::CreateView(cv) = parsed else {
        panic!("Expected CreateView")
    };
    assert_eq!(cv.view_name, "My_View");
}
1781
/// Dividing one aggregate by another is still classified as an aggregate select item.
#[test]
fn test_aggregate_division() {
    let sql = "SELECT token, SUM(sell) / SUM(buy) as ratio FROM orders GROUP BY token";
    let Statement::Select(q) = parse_query(sql).unwrap() else {
        panic!("Expected Select")
    };
    assert_eq!(q.group_by, Some(vec!["token".into()]));
    let ColumnList::Named(exprs) = &q.columns else {
        panic!("Expected Named columns")
    };
    assert_eq!(exprs.len(), 2);
    assert!(exprs[1].is_aggregate());
}
1801
/// Subtracting two aggregates keeps the alias, counts as an aggregate
/// select item, and leaves GROUP BY intact.
#[test]
fn test_aggregate_subtraction() {
    let stmt = parse_query(
        "SELECT token, SUM(sell) - SUM(buy) as net FROM orders GROUP BY token"
    ).unwrap();
    if let Statement::Select(q) = stmt {
        // Without an else-branch, a non-`Named` column list would skip every
        // assertion and the test would pass vacuously.
        if let ColumnList::Named(exprs) = &q.columns {
            assert_eq!(exprs.len(), 2);
            assert_eq!(exprs[1].output_name(), "net");
            assert!(exprs[1].is_aggregate());
        } else {
            panic!("Expected Named columns");
        }
        assert_eq!(q.group_by, Some(vec!["token".into()]));
    } else {
        panic!("Expected Select");
    }
}
1815
/// A view body may contain aggregate arithmetic; the wrapped SELECT keeps
/// its GROUP BY, the alias and the aggregate classification.
#[test]
fn test_create_view_with_arithmetic() {
    let stmt = parse_query(
        "CREATE VIEW positions AS SELECT token, SUM(sell) / SUM(buy) as ratio FROM orders GROUP BY token"
    ).unwrap();
    if let Statement::CreateView(cv) = stmt {
        assert_eq!(cv.view_name, "positions");
        // Check the inner query too, not just the view name, to match the
        // standalone SELECT tests for the same expression.
        assert_eq!(cv.query.group_by, Some(vec!["token".into()]));
        if let ColumnList::Named(exprs) = &cv.query.columns {
            assert_eq!(exprs.len(), 2);
            assert!(exprs[1].is_aggregate());
            assert_eq!(exprs[1].output_name(), "ratio");
        } else {
            panic!("Expected Named columns");
        }
    } else {
        panic!("Expected CreateView, got {:?}", stmt);
    }
}
1827
/// A parenthesised SELECT in FROM is stored as a subquery; the outer LIMIT stays on the outer query.
#[test]
fn test_subquery_in_from() {
    let sql = "SELECT token, sell_size FROM (SELECT token, SUM(size) as sell_size FROM orders GROUP BY token) LIMIT 5";
    let Statement::Select(q) = parse_query(sql).unwrap() else {
        panic!("Expected Select")
    };
    assert!(q.subquery.is_some());
    assert_eq!(q.limit, Some(5));
    let sub = q.subquery.unwrap();
    assert_eq!(sub.table, "orders");
    assert!(sub.group_by.is_some());
}
1845
/// A view body's HAVING clause (comparing two aliases) is retained on the inner query.
#[test]
fn test_create_view_with_having() {
    let sql = "CREATE VIEW positions AS SELECT token, SUM(sell) as sell_size, SUM(buy) as buy_size FROM orders GROUP BY token HAVING sell_size > buy_size";
    let cv = match parse_query(sql).unwrap() {
        Statement::CreateView(cv) => cv,
        other => panic!("Expected CreateView, got {:?}", other),
    };
    assert_eq!(cv.view_name, "positions");
    assert!(cv.query.having.is_some());
}
1860
/// An aggregate scaled by a literal is still an aggregate and keeps its alias.
#[test]
fn test_aggregate_multiplication() {
    let Statement::Select(q) = parse_query("SELECT SUM(a) * 2 as doubled FROM test").unwrap()
    else {
        panic!("Expected Select")
    };
    let ColumnList::Named(exprs) = &q.columns else {
        panic!("Expected Named columns")
    };
    assert_eq!(exprs.len(), 1);
    assert!(exprs[0].is_aggregate());
    assert_eq!(exprs[0].output_name(), "doubled");
}
1880
/// Dividing two aggregates whose arguments are CASE expressions yields one aliased aggregate item.
#[test]
fn test_complex_aggregate_arithmetic() {
    let sql = "SELECT SUM(CASE WHEN side = 'SELL' THEN size ELSE 0 END) / SUM(CASE WHEN side = 'BUY' THEN size ELSE 0 END) as ratio FROM orders GROUP BY token";
    let Statement::Select(q) = parse_query(sql).unwrap() else {
        panic!("Expected Select")
    };
    let ColumnList::Named(exprs) = &q.columns else {
        panic!("Expected Named columns")
    };
    assert_eq!(exprs.len(), 1);
    assert!(exprs[0].is_aggregate());
    assert_eq!(exprs[0].output_name(), "ratio");
    assert_eq!(q.group_by, Some(vec!["token".into()]));
}
1899
/// A trailing alias after a FROM subquery is accepted; the outer column list still parses.
#[test]
fn test_subquery_with_alias() {
    let Statement::Select(q) = parse_query("SELECT x FROM (SELECT x FROM t) sub").unwrap() else {
        panic!("Expected Select")
    };
    assert!(q.subquery.is_some());
    let sub = q.subquery.unwrap();
    assert_eq!(sub.table, "t");
    let ColumnList::Named(exprs) = &q.columns else {
        panic!("Expected Named columns")
    };
    assert_eq!(exprs.len(), 1);
    assert_eq!(exprs[0].output_name(), "x");
}
1921
/// The WHERE clause inside a FROM subquery belongs to the subquery; LIMIT belongs to the outer query.
#[test]
fn test_subquery_with_where() {
    let sql = "SELECT x FROM (SELECT x FROM t WHERE y > 0) LIMIT 5";
    let Statement::Select(q) = parse_query(sql).unwrap() else {
        panic!("Expected Select")
    };
    assert!(q.subquery.is_some());
    assert_eq!(q.limit, Some(5));
    let sub = q.subquery.unwrap();
    assert_eq!(sub.table, "t");
    assert!(sub.where_clause.is_some());
}
1937
/// A view over an aggregate difference keeps GROUP BY, the alias and the aggregate classification.
#[test]
fn test_create_view_aggregate_subtraction() {
    let sql = "CREATE VIEW v AS SELECT token, SUM(sell) - SUM(buy) as net FROM orders GROUP BY token";
    let cv = match parse_query(sql).unwrap() {
        Statement::CreateView(cv) => cv,
        other => panic!("Expected CreateView, got {:?}", other),
    };
    assert_eq!(cv.view_name, "v");
    assert_eq!(cv.query.group_by, Some(vec!["token".into()]));
    let ColumnList::Named(exprs) = &cv.query.columns else {
        panic!("Expected Named columns")
    };
    assert_eq!(exprs.len(), 2);
    assert_eq!(exprs[1].output_name(), "net");
    assert!(exprs[1].is_aggregate());
}
1959}