1use regex::Regex;
4use std::sync::LazyLock;
5
6use crate::errors::MdqlError;
7pub use crate::query_ast::*;
8
/// Words the tokenizer treats as keywords.
///
/// A bare word whose upper-cased spelling appears in this table is emitted as
/// a `keyword` token (with an upper-cased value); any other word becomes an
/// `ident` token.
static KEYWORDS: &[&str] = &[
    // SELECT clauses, boolean connectives and ordering.
    "SELECT", "FROM", "WHERE", "AND", "OR", "ORDER", "BY", "ASC", "DESC", "LIMIT",
    // Predicates.
    "LIKE", "IN", "IS", "NOT", "NULL",
    // Joins, aliases and grouping.
    "JOIN", "ON", "AS", "GROUP", "HAVING",
    // Data modification.
    "INSERT", "INTO", "VALUES", "UPDATE", "SET", "DELETE",
    // ALTER TABLE field operations.
    "ALTER", "TABLE", "RENAME", "FIELD", "TO", "DROP", "MERGE", "FIELDS",
    // CASE expressions.
    "CASE", "WHEN", "THEN", "ELSE", "END",
    // Date arithmetic.
    "INTERVAL", "DAY", "DAYS", "CURRENT_DATE", "CURRENT_TIMESTAMP", "DATEDIFF",
    // Views.
    "CREATE", "VIEW", "CASCADE", "RESTRICT",
];
21
/// Upper-case names of the supported aggregate functions.
///
/// Lookups against this table are done on an upper-cased token value, so the
/// function names are effectively case-insensitive in queries.
static AGG_FUNCS: &[&str] = &[
    "COUNT",
    "SUM",
    "AVG",
    "MIN",
    "MAX",
];
23
// Lexer pattern, compiled once on first use.
//
// The pattern is written in verbose mode (`(?x)`), so the whitespace and line
// breaks inside the literal are not part of the match. Each match skips any
// leading whitespace and then captures exactly one of the named groups:
//
//   backtick  a back-quoted identifier (non-empty, no embedded backtick)
//   string    a single-quoted string; a backslash escapes the next character
//   number    an unsigned integer or decimal
//   op        a comparison, arithmetic or punctuation operator
//             (two-character operators are listed first so `<=` wins over `<`)
//   word      a bare word: letter or `_`, then letters, digits, `_`, `.`, `/`, `-`
//
// The `unwrap` can only fail if the pattern literal itself is invalid.
static TOKEN_RE: LazyLock<Regex> = LazyLock::new(|| {
    Regex::new(
        r#"(?x)
        \s*(?:
            (?P<backtick>`[^`]+`)
          | (?P<string>'(?:[^'\\]|\\.)*')
          | (?P<number>\d+(?:\.\d+)?)
          | (?P<op><=|>=|!=|[=<>,*()+\-/%])
          | (?P<word>[A-Za-z_][A-Za-z0-9_./-]*)
        )"#,
    )
    .unwrap()
});
37
/// A single lexical token produced by `tokenize`.
#[derive(Debug, Clone)]
struct Token {
    /// Token category: `"ident"`, `"string"`, `"number"`, `"op"` or `"keyword"`.
    token_type: String,
    /// Normalised text: surrounding quotes/backticks removed, keywords upper-cased.
    value: String,
    /// The token exactly as it appeared in the query; used in error messages.
    raw: String,
}
44
45fn tokenize(sql: &str) -> Vec<Token> {
46 let mut tokens = Vec::new();
47 for caps in TOKEN_RE.captures_iter(sql) {
48 if let Some(m) = caps.name("backtick") {
49 let raw = m.as_str();
50 tokens.push(Token {
51 token_type: "ident".into(),
52 value: raw[1..raw.len() - 1].into(),
53 raw: raw.into(),
54 });
55 } else if let Some(m) = caps.name("string") {
56 let raw = m.as_str();
57 tokens.push(Token {
58 token_type: "string".into(),
59 value: raw[1..raw.len() - 1].into(),
60 raw: raw.into(),
61 });
62 } else if let Some(m) = caps.name("number") {
63 let raw = m.as_str();
64 tokens.push(Token {
65 token_type: "number".into(),
66 value: raw.into(),
67 raw: raw.into(),
68 });
69 } else if let Some(m) = caps.name("op") {
70 let raw = m.as_str();
71 tokens.push(Token {
72 token_type: "op".into(),
73 value: raw.into(),
74 raw: raw.into(),
75 });
76 } else if let Some(m) = caps.name("word") {
77 let raw = m.as_str();
78 if KEYWORDS.contains(&raw.to_uppercase().as_str()) {
79 tokens.push(Token {
80 token_type: "keyword".into(),
81 value: raw.to_uppercase(),
82 raw: raw.into(),
83 });
84 } else {
85 tokens.push(Token {
86 token_type: "ident".into(),
87 value: raw.into(),
88 raw: raw.into(),
89 });
90 }
91 }
92 }
93 tokens
94}
95
/// Recursive-descent parser state: the token list plus a cursor into it.
struct Parser {
    /// Tokens to parse, in source order.
    tokens: Vec<Token>,
    /// Index into `tokens` of the next token to be consumed.
    pos: usize,
}
102
103impl Parser {
104 fn new(tokens: Vec<Token>) -> Self {
105 Parser { tokens, pos: 0 }
106 }
107
108 fn peek(&self) -> Option<&Token> {
109 self.tokens.get(self.pos)
110 }
111
112 fn advance(&mut self) -> Token {
113 let t = self.tokens[self.pos].clone();
114 self.pos += 1;
115 t
116 }
117
118 fn expect(&mut self, type_: &str, value: Option<&str>) -> Result<Token, MdqlError> {
119 let t = self.peek().ok_or_else(|| {
120 MdqlError::QueryParse(format!(
121 "Unexpected end of query, expected {}",
122 value.unwrap_or(type_)
123 ))
124 })?;
125 let matches_type = t.token_type == type_;
126 let matches_value = value.map_or(true, |v| t.value == v);
127 if !matches_type || !matches_value {
128 return Err(MdqlError::QueryParse(format!(
129 "Expected {}, got '{}' at position {}",
130 value.unwrap_or(type_),
131 t.raw,
132 self.pos
133 )));
134 }
135 Ok(self.advance())
136 }
137
138 fn match_keyword(&mut self, kw: &str) -> bool {
139 if let Some(t) = self.peek() {
140 if t.token_type == "keyword" && t.value == kw {
141 self.advance();
142 return true;
143 }
144 }
145 false
146 }
147
148 fn parse_statement(&mut self) -> Result<Statement, MdqlError> {
149 let t = self.peek().ok_or_else(|| MdqlError::QueryParse("Empty query".into()))?;
150 match (t.token_type.as_str(), t.value.as_str()) {
151 ("keyword", "SELECT") => {
152 let q = self.parse_select()?;
153 self.expect_end()?;
154 Ok(Statement::Select(q))
155 }
156 ("keyword", "INSERT") => Ok(Statement::Insert(self.parse_insert()?)),
157 ("keyword", "UPDATE") => Ok(Statement::Update(self.parse_update()?)),
158 ("keyword", "DELETE") => Ok(Statement::Delete(self.parse_delete()?)),
159 ("keyword", "ALTER") => self.parse_alter(),
160 ("keyword", "CREATE") => self.parse_create_view(),
161 ("keyword", "DROP") => self.parse_drop_view(),
162 _ => Err(MdqlError::QueryParse(format!(
163 "Expected SELECT, INSERT, UPDATE, DELETE, ALTER, CREATE, or DROP, got '{}'",
164 t.raw
165 ))),
166 }
167 }
168
169 fn parse_select(&mut self) -> Result<SelectQuery, MdqlError> {
170 self.expect("keyword", Some("SELECT"))?;
171 let columns = self.parse_columns()?;
172 self.expect("keyword", Some("FROM"))?;
173
174 let mut subquery = None;
176 let (table, mut table_alias) = if self.peek().map_or(false, |t| t.token_type == "op" && t.value == "(") {
177 self.advance();
178 let inner = self.parse_select()?;
179 self.expect("op", Some(")"))?;
180 subquery = Some(Box::new(inner));
181 let alias = if let Some(t) = self.peek() {
182 if t.token_type == "ident" && !self.is_clause_keyword(t) {
183 Some(self.advance().value)
184 } else {
185 None
186 }
187 } else {
188 None
189 };
190 ("_subquery".to_string(), alias)
191 } else {
192 let t = self.parse_ident()?;
193 (t, None)
194 };
195
196 if subquery.is_none() {
198 if let Some(t) = self.peek() {
199 if t.token_type == "ident" && !self.is_clause_keyword(t) {
200 table_alias = Some(self.advance().value);
201 }
202 }
203 }
204
205 let mut joins = Vec::new();
207 while self.match_keyword("JOIN") {
208 let join_table = self.parse_ident()?;
209 let mut join_alias = None;
210 if let Some(t) = self.peek() {
211 if t.token_type == "ident" && !self.is_clause_keyword(t) {
212 join_alias = Some(self.advance().value);
213 }
214 }
215 self.expect("keyword", Some("ON"))?;
216 let left_col = self.parse_ident()?;
217 self.expect("op", Some("="))?;
218 let right_col = self.parse_ident()?;
219 joins.push(JoinClause {
220 table: join_table,
221 alias: join_alias,
222 left_col,
223 right_col,
224 });
225 }
226
227 let mut where_clause = None;
228 if self.match_keyword("WHERE") {
229 where_clause = Some(self.parse_or_expr()?);
230 }
231
232 let mut group_by = None;
233 if self.match_keyword("GROUP") {
234 self.expect("keyword", Some("BY"))?;
235 let mut cols = vec![self.parse_ident()?];
236 while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
237 self.advance();
238 cols.push(self.parse_ident()?);
239 }
240 group_by = Some(cols);
241 }
242
243 let mut having = None;
244 if self.match_keyword("HAVING") {
245 having = Some(self.parse_or_expr()?);
246 }
247
248 let mut order_by = None;
249 if self.match_keyword("ORDER") {
250 self.expect("keyword", Some("BY"))?;
251 order_by = Some(self.parse_order_by()?);
252 }
253
254 let mut limit = None;
255 if self.match_keyword("LIMIT") {
256 let t = self.expect("number", None)?;
257 limit = Some(t.value.parse::<i64>().map_err(|_| {
258 MdqlError::QueryParse(format!("Invalid LIMIT value: {}", t.value))
259 })?);
260 }
261
262 Ok(SelectQuery {
263 columns,
264 table,
265 table_alias,
266 subquery,
267 joins,
268 where_clause,
269 group_by,
270 having,
271 order_by,
272 limit,
273 })
274 }
275
276 fn parse_insert(&mut self) -> Result<InsertQuery, MdqlError> {
277 self.expect("keyword", Some("INSERT"))?;
278 self.expect("keyword", Some("INTO"))?;
279 let table = self.parse_ident()?;
280
281 self.expect("op", Some("("))?;
282 let mut columns = vec![self.parse_ident()?];
283 while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
284 self.advance();
285 columns.push(self.parse_ident()?);
286 }
287 self.expect("op", Some(")"))?;
288
289 self.expect("keyword", Some("VALUES"))?;
290
291 self.expect("op", Some("("))?;
292 let mut values = vec![self.parse_value()?];
293 while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
294 self.advance();
295 values.push(self.parse_value()?);
296 }
297 self.expect("op", Some(")"))?;
298
299 if columns.len() != values.len() {
300 return Err(MdqlError::QueryParse(format!(
301 "Column count ({}) does not match value count ({})",
302 columns.len(),
303 values.len()
304 )));
305 }
306
307 self.expect_end()?;
308 Ok(InsertQuery {
309 table,
310 columns,
311 values,
312 })
313 }
314
315 fn parse_update(&mut self) -> Result<UpdateQuery, MdqlError> {
316 self.expect("keyword", Some("UPDATE"))?;
317 let table = self.parse_ident()?;
318 self.expect("keyword", Some("SET"))?;
319
320 let mut assignments = Vec::new();
321 let col = self.parse_ident()?;
322 self.expect("op", Some("="))?;
323 let val = self.parse_value()?;
324 assignments.push((col, val));
325
326 while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
327 self.advance();
328 let col = self.parse_ident()?;
329 self.expect("op", Some("="))?;
330 let val = self.parse_value()?;
331 assignments.push((col, val));
332 }
333
334 let mut where_clause = None;
335 if self.match_keyword("WHERE") {
336 where_clause = Some(self.parse_or_expr()?);
337 }
338
339 self.expect_end()?;
340 Ok(UpdateQuery {
341 table,
342 assignments,
343 where_clause,
344 })
345 }
346
347 fn parse_delete(&mut self) -> Result<DeleteQuery, MdqlError> {
348 self.expect("keyword", Some("DELETE"))?;
349 self.expect("keyword", Some("FROM"))?;
350 let table = self.parse_ident()?;
351
352 let mut where_clause = None;
353 if self.match_keyword("WHERE") {
354 where_clause = Some(self.parse_or_expr()?);
355 }
356
357 self.expect_end()?;
358 Ok(DeleteQuery {
359 table,
360 where_clause,
361 })
362 }
363
364 fn parse_alter(&mut self) -> Result<Statement, MdqlError> {
365 self.expect("keyword", Some("ALTER"))?;
366 self.expect("keyword", Some("TABLE"))?;
367 let table = self.parse_ident()?;
368
369 let t = self.peek().ok_or_else(|| {
370 MdqlError::QueryParse("Expected RENAME, DROP, or MERGE after table name".into())
371 })?;
372
373 match (t.token_type.as_str(), t.value.as_str()) {
374 ("keyword", "RENAME") => {
375 self.advance();
376 self.expect("keyword", Some("FIELD"))?;
377 let old_name = self.parse_string_or_ident()?;
378 self.expect("keyword", Some("TO"))?;
379 let new_name = self.parse_string_or_ident()?;
380 self.expect_end()?;
381 Ok(Statement::AlterRename(AlterRenameFieldQuery {
382 table,
383 old_name,
384 new_name,
385 }))
386 }
387 ("keyword", "DROP") => {
388 self.advance();
389 self.expect("keyword", Some("FIELD"))?;
390 let field_name = self.parse_string_or_ident()?;
391 self.expect_end()?;
392 Ok(Statement::AlterDrop(AlterDropFieldQuery {
393 table,
394 field_name,
395 }))
396 }
397 ("keyword", "MERGE") => {
398 self.advance();
399 self.expect("keyword", Some("FIELDS"))?;
400 let mut sources = vec![self.parse_string_or_ident()?];
401 while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
402 self.advance();
403 sources.push(self.parse_string_or_ident()?);
404 }
405 self.expect("keyword", Some("INTO"))?;
406 let target = self.parse_string_or_ident()?;
407 self.expect_end()?;
408 Ok(Statement::AlterMerge(AlterMergeFieldsQuery {
409 table,
410 sources,
411 into: target,
412 }))
413 }
414 _ => Err(MdqlError::QueryParse(format!(
415 "Expected RENAME, DROP, or MERGE, got '{}'",
416 t.raw
417 ))),
418 }
419 }
420
421 fn parse_create_view(&mut self) -> Result<Statement, MdqlError> {
422 self.expect("keyword", Some("CREATE"))?;
423 self.expect("keyword", Some("VIEW"))?;
424 let view_name = self.parse_ident()?;
425
426 let columns = if self.peek().map_or(false, |t| t.token_type == "op" && t.value == "(") {
427 self.advance();
428 let mut cols = vec![self.parse_ident()?];
429 while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
430 self.advance();
431 cols.push(self.parse_ident()?);
432 }
433 self.expect("op", Some(")"))?;
434 Some(cols)
435 } else {
436 None
437 };
438
439 self.expect("keyword", Some("AS"))?;
440 let query = Box::new(self.parse_select()?);
441 self.expect_end()?;
442
443 Ok(Statement::CreateView(CreateViewQuery {
444 view_name,
445 columns,
446 query,
447 }))
448 }
449
450 fn parse_drop_view(&mut self) -> Result<Statement, MdqlError> {
451 self.expect("keyword", Some("DROP"))?;
452 self.expect("keyword", Some("VIEW"))?;
453 let view_name = self.parse_ident()?;
454 self.expect_end()?;
455 Ok(Statement::DropView(DropViewQuery { view_name }))
456 }
457
458 fn parse_string_or_ident(&mut self) -> Result<String, MdqlError> {
459 let t = self.peek().ok_or_else(|| {
460 MdqlError::QueryParse("Expected field name, got end of query".into())
461 })?;
462 match t.token_type.as_str() {
463 "string" => {
464 let v = self.advance().value;
465 Ok(v)
466 }
467 "ident" | "keyword" => {
468 let v = self.advance().value;
469 Ok(v)
470 }
471 _ => Err(MdqlError::QueryParse(format!(
472 "Expected field name, got '{}'",
473 t.raw
474 ))),
475 }
476 }
477
478 fn parse_columns(&mut self) -> Result<ColumnList, MdqlError> {
479 if let Some(t) = self.peek() {
480 if t.token_type == "op" && t.value == "*" {
481 self.advance();
482 return Ok(ColumnList::All);
483 }
484 }
485
486 let mut exprs = vec![self.parse_select_expr()?];
487 while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
488 self.advance();
489 exprs.push(self.parse_select_expr()?);
490 }
491 Ok(ColumnList::Named(exprs))
492 }
493
494 fn peek_is_agg_func(&self) -> bool {
495 let t = match self.peek() {
496 Some(t) => t,
497 None => return false,
498 };
499 let name_upper = t.value.to_uppercase();
500 if !AGG_FUNCS.contains(&name_upper.as_str()) {
501 return false;
502 }
503 self.tokens
505 .get(self.pos + 1)
506 .map_or(false, |next| next.token_type == "op" && next.value == "(")
507 }
508
509 fn parse_select_expr(&mut self) -> Result<SelectExpr, MdqlError> {
510 let _t = self.peek().ok_or_else(|| {
511 MdqlError::QueryParse("Expected column or aggregate, got end of query".into())
512 })?;
513
514 let expr = self.parse_additive()?;
515
516 let alias = if self.match_keyword("AS") {
517 Some(self.parse_ident()?)
518 } else if self.peek().map_or(false, |t| {
519 t.token_type == "ident" && !self.is_clause_keyword(t)
520 }) {
521 Some(self.advance().value)
522 } else {
523 None
524 };
525
526 if let Expr::Aggregate { func, arg, arg_expr } = expr {
528 return Ok(SelectExpr::Aggregate {
529 func,
530 arg,
531 arg_expr: arg_expr.map(|e| *e),
532 alias,
533 });
534 }
535
536 if alias.is_none() {
537 if let Expr::Column(name) = &expr {
538 return Ok(SelectExpr::Column(name.clone()));
539 }
540 }
541
542 Ok(SelectExpr::Expr { expr, alias })
543 }
544
545 fn peek_is_additive_op(&self) -> bool {
548 self.peek().map_or(false, |t| {
549 t.token_type == "op" && (t.value == "+" || t.value == "-")
550 })
551 }
552
553 fn peek_is_multiplicative_op(&self) -> bool {
554 self.peek().map_or(false, |t| {
555 t.token_type == "op" && (t.value == "*" || t.value == "/" || t.value == "%")
556 })
557 }
558
559 fn parse_additive(&mut self) -> Result<Expr, MdqlError> {
560 let mut left = self.parse_multiplicative()?;
561 while self.peek_is_additive_op() {
562 let op_tok = self.advance();
563 let is_sub = op_tok.value == "-";
564
565 if self.peek().map_or(false, |t| t.token_type == "keyword" && t.value == "INTERVAL") {
567 self.advance(); let days_expr = self.parse_multiplicative()?;
569 if !self.match_keyword("DAY") && !self.match_keyword("DAYS") {
571 return Err(MdqlError::QueryParse("Expected DAY after INTERVAL value".into()));
572 }
573 let days = if is_sub {
574 Expr::UnaryMinus(Box::new(days_expr))
575 } else {
576 days_expr
577 };
578 left = Expr::DateAdd {
579 date: Box::new(left),
580 days: Box::new(days),
581 };
582 continue;
583 }
584
585 let op = match op_tok.value.as_str() {
586 "+" => ArithOp::Add,
587 "-" => ArithOp::Sub,
588 _ => unreachable!(),
589 };
590 let right = self.parse_multiplicative()?;
591 left = Expr::BinaryOp {
592 left: Box::new(left),
593 op,
594 right: Box::new(right),
595 };
596 }
597 Ok(left)
598 }
599
600 fn parse_multiplicative(&mut self) -> Result<Expr, MdqlError> {
601 let mut left = self.parse_unary()?;
602 while self.peek_is_multiplicative_op() {
603 let op_tok = self.advance();
604 let op = match op_tok.value.as_str() {
605 "*" => ArithOp::Mul,
606 "/" => ArithOp::Div,
607 "%" => ArithOp::Mod,
608 _ => unreachable!(),
609 };
610 let right = self.parse_unary()?;
611 left = Expr::BinaryOp {
612 left: Box::new(left),
613 op,
614 right: Box::new(right),
615 };
616 }
617 Ok(left)
618 }
619
620 fn parse_unary(&mut self) -> Result<Expr, MdqlError> {
621 if self.peek().map_or(false, |t| t.token_type == "op" && t.value == "-") {
622 self.advance();
623 let inner = self.parse_atom()?;
624 match inner {
626 Expr::Literal(SqlValue::Int(n)) => Ok(Expr::Literal(SqlValue::Int(-n))),
627 Expr::Literal(SqlValue::Float(f)) => Ok(Expr::Literal(SqlValue::Float(-f))),
628 _ => Ok(Expr::UnaryMinus(Box::new(inner))),
629 }
630 } else {
631 self.parse_atom()
632 }
633 }
634
635 fn parse_atom(&mut self) -> Result<Expr, MdqlError> {
636 if self.peek_is_agg_func() {
637 return self.parse_agg_expr();
638 }
639
640 let t = self.peek().ok_or_else(|| {
641 MdqlError::QueryParse("Expected expression, got end of query".into())
642 })?;
643
644 match t.token_type.as_str() {
645 "number" => {
646 let v = self.advance().value;
647 if v.contains('.') {
648 let f: f64 = v.parse().map_err(|_| {
649 MdqlError::QueryParse(format!("Invalid float: {}", v))
650 })?;
651 Ok(Expr::Literal(SqlValue::Float(f)))
652 } else {
653 let n: i64 = v.parse().map_err(|_| {
654 MdqlError::QueryParse(format!("Invalid int: {}", v))
655 })?;
656 Ok(Expr::Literal(SqlValue::Int(n)))
657 }
658 }
659 "string" => {
660 let v = self.advance().value;
661 Ok(Expr::Literal(SqlValue::String(v)))
662 }
663 "keyword" if t.value == "NULL" => {
664 self.advance();
665 Ok(Expr::Literal(SqlValue::Null))
666 }
667 "keyword" if t.value == "CASE" => {
668 self.parse_case_expr()
669 }
670 "keyword" if t.value == "CURRENT_DATE" => {
671 self.advance();
672 Ok(Expr::CurrentDate)
673 }
674 "keyword" if t.value == "CURRENT_TIMESTAMP" => {
675 self.advance();
676 Ok(Expr::CurrentTimestamp)
677 }
678 "keyword" if t.value == "DATEDIFF" => {
679 self.advance();
680 self.expect("op", Some("("))?;
681 let left = self.parse_additive()?;
682 self.expect("op", Some(","))?;
683 let right = self.parse_additive()?;
684 self.expect("op", Some(")"))?;
685 Ok(Expr::DateDiff { left: Box::new(left), right: Box::new(right) })
686 }
687 "op" if t.value == "(" => {
688 self.advance();
689 let expr = self.parse_additive()?;
690 self.expect("op", Some(")"))?;
691 Ok(expr)
692 }
693 "ident" => {
694 let name = self.advance().value;
695 Ok(Expr::Column(name))
696 }
697 "keyword" if !Self::is_reserved_keyword(&t.value) => {
698 let name = self.advance().value;
699 Ok(Expr::Column(name))
700 }
701 _ => Err(MdqlError::QueryParse(format!(
702 "Expected expression, got '{}'",
703 t.raw
704 ))),
705 }
706 }
707
708 fn parse_case_expr(&mut self) -> Result<Expr, MdqlError> {
709 self.expect("keyword", Some("CASE"))?;
710 let mut whens = Vec::new();
711 while self.match_keyword("WHEN") {
712 let condition = self.parse_or_expr()?;
713 self.expect("keyword", Some("THEN"))?;
714 let result = self.parse_additive()?;
715 whens.push((condition, Box::new(result)));
716 }
717 if whens.is_empty() {
718 return Err(MdqlError::QueryParse("CASE requires at least one WHEN clause".into()));
719 }
720 let else_expr = if self.match_keyword("ELSE") {
721 Some(Box::new(self.parse_additive()?))
722 } else {
723 None
724 };
725 self.expect("keyword", Some("END"))?;
726 Ok(Expr::Case { whens, else_expr })
727 }
728
729 fn parse_agg_expr(&mut self) -> Result<Expr, MdqlError> {
730 let func_name = self.advance().value.to_uppercase();
731 let func = match func_name.as_str() {
732 "COUNT" => AggFunc::Count,
733 "SUM" => AggFunc::Sum,
734 "AVG" => AggFunc::Avg,
735 "MIN" => AggFunc::Min,
736 "MAX" => AggFunc::Max,
737 _ => unreachable!(),
738 };
739 self.expect("op", Some("("))?;
740 let (arg, arg_expr) = if self.peek().map_or(false, |t| t.token_type == "op" && t.value == "*") {
741 self.advance();
742 ("*".to_string(), None)
743 } else {
744 let expr = self.parse_additive()?;
745 if let Expr::Column(name) = &expr {
746 (name.clone(), None)
747 } else {
748 (expr.display_name(), Some(Box::new(expr)))
749 }
750 };
751 self.expect("op", Some(")"))?;
752 Ok(Expr::Aggregate { func, arg, arg_expr })
753 }
754
755 fn parse_ident(&mut self) -> Result<String, MdqlError> {
756 let t = self.peek().ok_or_else(|| {
757 MdqlError::QueryParse("Expected identifier, got end of query".into())
758 })?;
759 match t.token_type.as_str() {
760 "ident" | "keyword" => {
761 let v = self.advance().value;
762 Ok(v)
763 }
764 _ => Err(MdqlError::QueryParse(format!(
765 "Expected identifier, got '{}'",
766 t.raw
767 ))),
768 }
769 }
770
771 fn parse_or_expr(&mut self) -> Result<WhereClause, MdqlError> {
772 let mut left = self.parse_and_expr()?;
773 while self.match_keyword("OR") {
774 let right = self.parse_and_expr()?;
775 left = WhereClause::BoolOp(BoolOp {
776 op: BoolOpKind::Or,
777 left: Box::new(left),
778 right: Box::new(right),
779 });
780 }
781 Ok(left)
782 }
783
784 fn parse_and_expr(&mut self) -> Result<WhereClause, MdqlError> {
785 let mut left = self.parse_comparison()?;
786 while self.match_keyword("AND") {
787 let right = self.parse_comparison()?;
788 left = WhereClause::BoolOp(BoolOp {
789 op: BoolOpKind::And,
790 left: Box::new(left),
791 right: Box::new(right),
792 });
793 }
794 Ok(left)
795 }
796
    /// Parses a single condition: a parenthesised boolean group or one
    /// predicate of the form `expr IS [NOT] NULL`, `expr IN (v, ...)`,
    /// `expr [NOT] LIKE v`, or `expr <cmp> expr` with `=`, `!=`, `<`, `>`,
    /// `<=`, `>=`.
    ///
    /// # Errors
    ///
    /// Returns `MdqlError::QueryParse` when the left expression is not
    /// followed by a recognised operator, when `NOT` is not followed by
    /// `LIKE`, or when any sub-parser fails.
    fn parse_comparison(&mut self) -> Result<WhereClause, MdqlError> {
        if self.peek().map_or(false, |t| t.token_type == "op" && t.value == "(") {
            // A leading `(` is ambiguous: it may open a boolean group or an
            // arithmetic sub-expression. Try the boolean reading first and
            // remember where to rewind to.
            let saved_pos = self.pos;
            self.advance();
            let result = self.parse_or_expr();
            if result.is_ok() && self.peek().map_or(false, |t| t.token_type == "op" && t.value == ")") {
                self.advance();
                return result;
            }
            // Not a boolean group (its error, if any, is discarded): rewind
            // and let `parse_additive` handle the parenthesis instead.
            self.pos = saved_pos;
        }

        let left_expr = self.parse_additive()?;

        // Column name of the left side; empty string when `as_column`
        // yields nothing (presumably for non-column expressions — confirm).
        let col = left_expr.as_column().unwrap_or("").to_string();

        if self.match_keyword("IS") {
            // `IS NOT NULL` is checked before plain `IS NULL`.
            if self.match_keyword("NOT") {
                self.expect("keyword", Some("NULL"))?;
                return Ok(WhereClause::Comparison(Comparison {
                    column: col,
                    op: CmpOp::IsNotNull,
                    value: None,
                    left_expr: Some(left_expr),
                    right_expr: None,
                }));
            }
            self.expect("keyword", Some("NULL"))?;
            return Ok(WhereClause::Comparison(Comparison {
                column: col,
                op: CmpOp::IsNull,
                value: None,
                left_expr: Some(left_expr),
                right_expr: None,
            }));
        }

        if self.match_keyword("IN") {
            // One or more literal values; an empty list is rejected by the
            // first `parse_value` call.
            self.expect("op", Some("("))?;
            let mut values = vec![self.parse_value()?];
            while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
                self.advance();
                values.push(self.parse_value()?);
            }
            self.expect("op", Some(")"))?;
            return Ok(WhereClause::Comparison(Comparison {
                column: col,
                op: CmpOp::In,
                value: Some(SqlValue::List(values)),
                left_expr: Some(left_expr),
                right_expr: None,
            }));
        }

        if self.match_keyword("LIKE") {
            // The pattern must be a literal value.
            let val = self.parse_value()?;
            return Ok(WhereClause::Comparison(Comparison {
                column: col,
                op: CmpOp::Like,
                value: Some(val),
                left_expr: Some(left_expr),
                right_expr: None,
            }));
        }

        if self.match_keyword("NOT") {
            // After an expression, `NOT` is only valid as `NOT LIKE`.
            if self.match_keyword("LIKE") {
                let val = self.parse_value()?;
                return Ok(WhereClause::Comparison(Comparison {
                    column: col,
                    op: CmpOp::NotLike,
                    value: Some(val),
                    left_expr: Some(left_expr),
                    right_expr: None,
                }));
            }
            return Err(MdqlError::QueryParse("Expected LIKE after NOT".into()));
        }

        if let Some(t) = self.peek() {
            // Binary comparison with an arbitrary expression on the right.
            if t.token_type == "op" && ["=", "!=", "<", ">", "<=", ">="].contains(&t.value.as_str())
            {
                let op_str = self.advance().value;
                let op = match op_str.as_str() {
                    "=" => CmpOp::Eq,
                    "!=" => CmpOp::Ne,
                    "<" => CmpOp::Lt,
                    ">" => CmpOp::Gt,
                    "<=" => CmpOp::Le,
                    ">=" => CmpOp::Ge,
                    _ => unreachable!(),
                };
                let right_expr = self.parse_additive()?;
                // `value` is filled only when the right side is a literal;
                // the full expression is always kept in `right_expr`.
                let value = match &right_expr {
                    Expr::Literal(v) => Some(v.clone()),
                    _ => None,
                };
                return Ok(WhereClause::Comparison(Comparison {
                    column: col,
                    op,
                    value,
                    left_expr: Some(left_expr),
                    right_expr: Some(right_expr),
                }));
            }
        }

        // Nothing matched: report what was found after the left expression.
        let got = self.peek().map_or("end".to_string(), |t| t.raw.clone());
        Err(MdqlError::QueryParse(format!(
            "Expected operator after '{}', got '{}'",
            left_expr.display_name(), got
        )))
    }
923
924 fn parse_value(&mut self) -> Result<SqlValue, MdqlError> {
925 let t = self.peek().ok_or_else(|| {
926 MdqlError::QueryParse("Expected value, got end of query".into())
927 })?;
928 match t.token_type.as_str() {
929 "string" => {
930 let v = self.advance().value;
931 Ok(SqlValue::String(v))
932 }
933 "number" => {
934 let v = self.advance().value;
935 if v.contains('.') {
936 Ok(SqlValue::Float(v.parse().map_err(|_| {
937 MdqlError::QueryParse(format!("Invalid float: {}", v))
938 })?))
939 } else {
940 Ok(SqlValue::Int(v.parse().map_err(|_| {
941 MdqlError::QueryParse(format!("Invalid int: {}", v))
942 })?))
943 }
944 }
945 "keyword" if t.value == "NULL" => {
946 self.advance();
947 Ok(SqlValue::Null)
948 }
949 _ => Err(MdqlError::QueryParse(format!(
950 "Expected value, got '{}'",
951 t.raw
952 ))),
953 }
954 }
955
956 fn parse_order_by(&mut self) -> Result<Vec<OrderSpec>, MdqlError> {
957 let mut specs = vec![self.parse_order_spec()?];
958 while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
959 self.advance();
960 specs.push(self.parse_order_spec()?);
961 }
962 Ok(specs)
963 }
964
965 fn parse_order_spec(&mut self) -> Result<OrderSpec, MdqlError> {
966 let expr = self.parse_additive()?;
967 let col = expr.as_column().unwrap_or("").to_string();
968 let descending = if self.match_keyword("DESC") {
969 true
970 } else {
971 self.match_keyword("ASC");
972 false
973 };
974 Ok(OrderSpec {
975 column: col,
976 expr: Some(expr),
977 descending,
978 })
979 }
980
981 fn is_clause_keyword(&self, t: &Token) -> bool {
982 t.token_type == "keyword"
983 && ["WHERE", "ORDER", "LIMIT", "JOIN", "ON", "GROUP"].contains(&t.value.as_str())
984 }
985
986 fn is_reserved_keyword(kw: &str) -> bool {
988 matches!(kw,
989 "AS" | "FROM" | "WHERE" | "AND" | "OR" | "ORDER" | "BY"
990 | "ASC" | "DESC" | "LIMIT" | "JOIN" | "ON" | "GROUP"
991 | "SELECT" | "INSERT" | "INTO" | "VALUES" | "UPDATE" | "SET"
992 | "DELETE" | "ALTER" | "TABLE" | "IS" | "NOT" | "IN" | "LIKE"
993 | "RENAME" | "FIELD" | "TO" | "DROP" | "MERGE" | "FIELDS"
994 | "CASE" | "WHEN" | "THEN" | "ELSE" | "END"
995 | "HAVING" | "INTERVAL" | "DAY" | "DAYS"
996 | "CURRENT_DATE" | "CURRENT_TIMESTAMP" | "DATEDIFF"
997 | "CREATE" | "VIEW" | "CASCADE" | "RESTRICT"
998 )
999 }
1000
1001 fn expect_end(&self) -> Result<(), MdqlError> {
1002 if let Some(t) = self.peek() {
1003 return Err(MdqlError::QueryParse(format!(
1004 "Unexpected token '{}' at position {}",
1005 t.raw, self.pos
1006 )));
1007 }
1008 Ok(())
1009 }
1010}
1011
1012pub fn parse_query(sql: &str) -> crate::errors::Result<Statement> {
1013 let tokens = tokenize(sql);
1014 if tokens.is_empty() {
1015 return Err(MdqlError::QueryParse("Empty query".into()));
1016 }
1017 let mut parser = Parser::new(tokens);
1018 parser.parse_statement()
1019}
1020
1021#[cfg(test)]
1022mod tests {
1023 use super::*;
1024
1025 #[test]
1026 fn test_simple_select() {
1027 let stmt = parse_query("SELECT title, status FROM strategies").unwrap();
1028 if let Statement::Select(q) = stmt {
1029 assert_eq!(q.columns, ColumnList::Named(vec![SelectExpr::Column("title".into()), SelectExpr::Column("status".into())]));
1030 assert_eq!(q.table, "strategies");
1031 } else {
1032 panic!("Expected Select");
1033 }
1034 }
1035
1036 #[test]
1037 fn test_select_star() {
1038 let stmt = parse_query("SELECT * FROM test").unwrap();
1039 if let Statement::Select(q) = stmt {
1040 assert_eq!(q.columns, ColumnList::All);
1041 } else {
1042 panic!("Expected Select");
1043 }
1044 }
1045
1046 #[test]
1047 fn test_where_clause() {
1048 let stmt = parse_query("SELECT title FROM test WHERE count > 5").unwrap();
1049 if let Statement::Select(q) = stmt {
1050 assert!(q.where_clause.is_some());
1051 } else {
1052 panic!("Expected Select");
1053 }
1054 }
1055
1056 #[test]
1057 fn test_order_by() {
1058 let stmt =
1059 parse_query("SELECT title FROM test ORDER BY composite DESC, title ASC").unwrap();
1060 if let Statement::Select(q) = stmt {
1061 let ob = q.order_by.unwrap();
1062 assert_eq!(ob.len(), 2);
1063 assert!(ob[0].descending);
1064 assert!(!ob[1].descending);
1065 } else {
1066 panic!("Expected Select");
1067 }
1068 }
1069
1070 #[test]
1071 fn test_limit() {
1072 let stmt = parse_query("SELECT * FROM test LIMIT 10").unwrap();
1073 if let Statement::Select(q) = stmt {
1074 assert_eq!(q.limit, Some(10));
1075 } else {
1076 panic!("Expected Select");
1077 }
1078 }
1079
1080 #[test]
1081 fn test_insert() {
1082 let stmt = parse_query(
1083 "INSERT INTO test (title, count) VALUES ('Hello', 42)",
1084 )
1085 .unwrap();
1086 if let Statement::Insert(q) = stmt {
1087 assert_eq!(q.table, "test");
1088 assert_eq!(q.columns, vec!["title", "count"]);
1089 assert_eq!(q.values[0], SqlValue::String("Hello".into()));
1090 assert_eq!(q.values[1], SqlValue::Int(42));
1091 } else {
1092 panic!("Expected Insert");
1093 }
1094 }
1095
1096 #[test]
1097 fn test_update() {
1098 let stmt = parse_query("UPDATE test SET status = 'KILLED' WHERE path = 'a.md'").unwrap();
1099 if let Statement::Update(q) = stmt {
1100 assert_eq!(q.table, "test");
1101 assert_eq!(q.assignments.len(), 1);
1102 assert!(q.where_clause.is_some());
1103 } else {
1104 panic!("Expected Update");
1105 }
1106 }
1107
1108 #[test]
1109 fn test_delete() {
1110 let stmt = parse_query("DELETE FROM test WHERE status = 'draft'").unwrap();
1111 if let Statement::Delete(q) = stmt {
1112 assert_eq!(q.table, "test");
1113 assert!(q.where_clause.is_some());
1114 } else {
1115 panic!("Expected Delete");
1116 }
1117 }
1118
1119 #[test]
1120 fn test_alter_rename() {
1121 let stmt =
1122 parse_query("ALTER TABLE test RENAME FIELD 'Summary' TO 'Overview'").unwrap();
1123 if let Statement::AlterRename(q) = stmt {
1124 assert_eq!(q.old_name, "Summary");
1125 assert_eq!(q.new_name, "Overview");
1126 } else {
1127 panic!("Expected AlterRename");
1128 }
1129 }
1130
1131 #[test]
1132 fn test_alter_drop() {
1133 let stmt = parse_query("ALTER TABLE test DROP FIELD 'Details'").unwrap();
1134 if let Statement::AlterDrop(q) = stmt {
1135 assert_eq!(q.field_name, "Details");
1136 } else {
1137 panic!("Expected AlterDrop");
1138 }
1139 }
1140
1141 #[test]
1142 fn test_alter_merge() {
1143 let stmt = parse_query(
1144 "ALTER TABLE test MERGE FIELDS 'Entry Rules', 'Exit Rules' INTO 'Trading Rules'",
1145 )
1146 .unwrap();
1147 if let Statement::AlterMerge(q) = stmt {
1148 assert_eq!(q.sources, vec!["Entry Rules", "Exit Rules"]);
1149 assert_eq!(q.into, "Trading Rules");
1150 } else {
1151 panic!("Expected AlterMerge");
1152 }
1153 }
1154
1155 #[test]
1156 fn test_backtick_ident() {
1157 let stmt = parse_query("SELECT `Structural Mechanism` FROM test").unwrap();
1158 if let Statement::Select(q) = stmt {
1159 assert_eq!(
1160 q.columns,
1161 ColumnList::Named(vec![SelectExpr::Column("Structural Mechanism".into())])
1162 );
1163 } else {
1164 panic!("Expected Select");
1165 }
1166 }
1167
1168 #[test]
1169 fn test_like_operator() {
1170 let stmt = parse_query("SELECT title FROM test WHERE categories LIKE '%defi%'").unwrap();
1171 if let Statement::Select(q) = stmt {
1172 if let Some(WhereClause::Comparison(c)) = q.where_clause {
1173 assert_eq!(c.op, CmpOp::Like);
1174 assert_eq!(c.value, Some(SqlValue::String("%defi%".into())));
1175 } else {
1176 panic!("Expected LIKE comparison");
1177 }
1178 } else {
1179 panic!("Expected Select");
1180 }
1181 }
1182
/// `WHERE col IN (…)` produces a single comparison whose operator is `CmpOp::In`.
#[test]
fn test_in_operator() {
    let parsed = parse_query("SELECT * FROM test WHERE status IN ('ACTIVE', 'LIVE')").unwrap();
    let Statement::Select(q) = parsed else {
        panic!("Expected Select");
    };
    match q.where_clause {
        Some(WhereClause::Comparison(c)) => assert_eq!(c.op, CmpOp::In),
        _ => panic!("Expected IN comparison"),
    }
}
1197
/// `WHERE col IS NULL` produces a single comparison whose operator is `CmpOp::IsNull`.
#[test]
fn test_is_null() {
    let parsed = parse_query("SELECT * FROM test WHERE title IS NULL").unwrap();
    let Statement::Select(q) = parsed else {
        panic!("Expected Select");
    };
    match q.where_clause {
        Some(WhereClause::Comparison(c)) => assert_eq!(c.op, CmpOp::IsNull),
        _ => panic!("Expected IS NULL comparison"),
    }
}
1211
/// A WHERE clause mixing `AND` and `OR` parses successfully and yields some
/// where-clause; its exact shape is not inspected here.
#[test]
fn test_and_or() {
    let sql = "SELECT * FROM test WHERE status = 'ACTIVE' AND count > 5 OR title LIKE '%test%'";
    let Statement::Select(q) = parse_query(sql).unwrap() else {
        panic!("Expected Select");
    };
    assert!(q.where_clause.is_some());
}
1224
/// A single `JOIN … ON …` records the base table with its alias and exactly
/// one join entry with table, alias, and both sides of the ON condition.
#[test]
fn test_join() {
    let stmt = parse_query(
        "SELECT s.title, b.sharpe FROM strategies s JOIN backtests b ON b.strategy = s.path",
    )
    .unwrap();
    let Statement::Select(q) = stmt else {
        panic!("Expected Select");
    };
    assert_eq!(q.table, "strategies");
    assert_eq!(q.table_alias, Some("s".into()));
    assert_eq!(q.joins.len(), 1);

    let join = &q.joins[0];
    assert_eq!(join.table, "backtests");
    assert_eq!(join.alias, Some("b".into()));
    // The ON columns were previously unchecked here; a parser that dropped or
    // swapped them would have passed. Same expectations as `test_multi_join`.
    assert_eq!(join.left_col, "b.strategy");
    assert_eq!(join.right_col, "s.path");
}
1242
/// Two chained `JOIN` clauses are recorded in source order, each with its
/// table, alias, and the left/right columns of its ON condition.
#[test]
fn test_multi_join() {
    let sql = "SELECT s.title, b.sharpe, c.verdict FROM strategies s JOIN backtests b ON b.strategy = s.path JOIN critiques c ON c.strategy = s.path";
    let Statement::Select(q) = parse_query(sql).unwrap() else {
        panic!("Expected Select");
    };

    assert_eq!(q.table, "strategies");
    assert_eq!(q.table_alias, Some("s".into()));
    assert_eq!(q.joins.len(), 2);

    let first = &q.joins[0];
    assert_eq!(first.table, "backtests");
    assert_eq!(first.alias, Some("b".into()));
    assert_eq!(first.left_col, "b.strategy");
    assert_eq!(first.right_col, "s.path");

    let second = &q.joins[1];
    assert_eq!(second.table, "critiques");
    assert_eq!(second.alias, Some("c".into()));
    assert_eq!(second.left_col, "c.strategy");
    assert_eq!(second.right_col, "s.path");
}
1265
/// Parsing the empty string must return an error rather than a statement.
#[test]
fn test_empty_query() {
    let outcome = parse_query("");
    assert!(outcome.is_err());
}
1270
/// `COUNT(*) AS cnt` next to a plain column parses as an aggregate with
/// argument `*` and alias `cnt`, and `GROUP BY` lists the grouping column.
#[test]
fn test_count_star() {
    let parsed =
        parse_query("SELECT status, COUNT(*) AS cnt FROM strategies GROUP BY status").unwrap();
    let Statement::Select(q) = parsed else {
        panic!("Expected Select");
    };
    let ColumnList::Named(exprs) = &q.columns else {
        panic!("Expected Named columns");
    };

    assert_eq!(exprs.len(), 2);
    assert_eq!(exprs[0], SelectExpr::Column("status".into()));
    assert!(matches!(
        &exprs[1],
        SelectExpr::Aggregate {
            func: AggFunc::Count,
            arg,
            alias: Some(a),
            ..
        } if arg == "*" && a == "cnt"
    ));
    assert_eq!(q.group_by, Some(vec!["status".into()]));
}
1292
/// The word `count` is accepted as an ordinary column name in an INSERT
/// column list.
#[test]
fn test_count_column_as_ident() {
    let parsed = parse_query("INSERT INTO test (title, count) VALUES ('Hello', 42)").unwrap();
    match parsed {
        Statement::Insert(q) => assert_eq!(q.columns, vec!["title", "count"]),
        _ => panic!("Expected Insert"),
    }
}
1303
/// Several aggregates in one SELECT list parse in order (MIN, MAX, AVG), and
/// no GROUP BY is recorded when none is written.
#[test]
fn test_multiple_aggregates() {
    let sql = "SELECT MIN(composite), MAX(composite), AVG(composite) FROM strategies";
    let Statement::Select(q) = parse_query(sql).unwrap() else {
        panic!("Expected Select");
    };
    let ColumnList::Named(exprs) = &q.columns else {
        panic!("Expected Named columns");
    };

    assert_eq!(exprs.len(), 3);
    assert!(matches!(&exprs[0], SelectExpr::Aggregate { func: AggFunc::Min, .. }));
    assert!(matches!(&exprs[1], SelectExpr::Aggregate { func: AggFunc::Max, .. }));
    assert!(matches!(&exprs[2], SelectExpr::Aggregate { func: AggFunc::Avg, .. }));
    assert_eq!(q.group_by, None);
}
1321
/// `a + b` in the SELECT list becomes one un-aliased expression column whose
/// root is an addition.
#[test]
fn test_select_arithmetic_expr() {
    let parsed = parse_query("SELECT a + b FROM test").unwrap();
    let Statement::Select(q) = parsed else {
        panic!("Expected Select");
    };
    let ColumnList::Named(exprs) = &q.columns else {
        panic!("Expected Named columns");
    };

    assert_eq!(exprs.len(), 1);
    assert!(matches!(
        &exprs[0],
        SelectExpr::Expr {
            expr: Expr::BinaryOp { op: ArithOp::Add, .. },
            alias: None,
        }
    ));
}
1341
/// An arithmetic expression followed by `AS total` keeps the alias, and the
/// alias is what `output_name()` reports.
#[test]
fn test_select_arithmetic_with_alias() {
    let parsed = parse_query("SELECT a + b AS total FROM test").unwrap();
    let Statement::Select(q) = parsed else {
        panic!("Expected Select");
    };
    let ColumnList::Named(exprs) = &q.columns else {
        panic!("Expected Named columns");
    };

    assert_eq!(exprs.len(), 1);
    assert!(matches!(
        &exprs[0],
        SelectExpr::Expr {
            alias: Some(a),
            ..
        } if a == "total"
    ));
    assert_eq!(exprs[0].output_name(), "total");
}
1360
/// Multiplication binds tighter than addition: `a + b * c` must have `+` at
/// the root with column `a` on the left and a `*` node on the right.
#[test]
fn test_select_precedence() {
    let parsed = parse_query("SELECT a + b * c FROM test").unwrap();
    let Statement::Select(q) = parsed else {
        panic!("Expected Select");
    };
    let ColumnList::Named(exprs) = &q.columns else {
        panic!("Expected Named columns");
    };
    let SelectExpr::Expr { expr, .. } = &exprs[0] else {
        panic!("Expected Expr variant");
    };
    let Expr::BinaryOp { left, op, right } = expr else {
        panic!("Expected BinaryOp");
    };

    assert_eq!(*op, ArithOp::Add);
    assert!(matches!(left.as_ref(), Expr::Column(n) if n == "a"));
    assert!(matches!(right.as_ref(), Expr::BinaryOp { op: ArithOp::Mul, .. }));
}
1385
/// Parentheses override precedence: `(a + b) * c` must have `*` at the root
/// with the addition as its left operand.
#[test]
fn test_select_parenthesized_expr() {
    let parsed = parse_query("SELECT (a + b) * c FROM test").unwrap();
    let Statement::Select(q) = parsed else {
        panic!("Expected Select");
    };
    let ColumnList::Named(exprs) = &q.columns else {
        panic!("Expected Named columns");
    };
    let SelectExpr::Expr { expr, .. } = &exprs[0] else {
        panic!("Expected Expr variant");
    };
    let Expr::BinaryOp { left, op, .. } = expr else {
        panic!("Expected BinaryOp");
    };

    assert_eq!(*op, ArithOp::Mul);
    assert!(matches!(left.as_ref(), Expr::BinaryOp { op: ArithOp::Add, .. }));
}
1409
/// A leading minus applied to a column parses as `Expr::UnaryMinus`.
#[test]
fn test_select_unary_minus() {
    let parsed = parse_query("SELECT -count FROM test").unwrap();
    let Statement::Select(q) = parsed else {
        panic!("Expected Select");
    };
    let ColumnList::Named(exprs) = &q.columns else {
        panic!("Expected Named columns");
    };

    assert!(matches!(
        &exprs[0],
        SelectExpr::Expr {
            expr: Expr::UnaryMinus(_),
            ..
        }
    ));
}
1426
/// A minus sign directly before a numeric literal is folded into the literal:
/// `-42` must parse as the integer literal -42.
#[test]
fn test_select_negative_literal() {
    let parsed = parse_query("SELECT -42 FROM test").unwrap();
    let Statement::Select(q) = parsed else {
        panic!("Expected Select");
    };
    let ColumnList::Named(exprs) = &q.columns else {
        panic!("Expected Named columns");
    };

    assert!(matches!(
        &exprs[0],
        SelectExpr::Expr {
            expr: Expr::Literal(SqlValue::Int(-42)),
            ..
        }
    ));
}
1444
/// An arithmetic left-hand side in WHERE is kept as an expression: `a + b > 10`
/// yields a `>` comparison with an addition on the left and literal 10 on the right.
#[test]
fn test_where_arithmetic_expr() {
    let parsed = parse_query("SELECT * FROM test WHERE a + b > 10").unwrap();
    let Statement::Select(q) = parsed else {
        panic!("Expected Select");
    };
    let Some(WhereClause::Comparison(c)) = q.where_clause else {
        panic!("Expected comparison");
    };

    assert_eq!(c.op, CmpOp::Gt);
    assert!(matches!(&c.left_expr, Some(Expr::BinaryOp { op: ArithOp::Add, .. })));
    assert!(matches!(&c.right_expr, Some(Expr::Literal(SqlValue::Int(10)))));
}
1460
/// Both operands of a WHERE comparison may be arithmetic: `a * 2 > b + 1`
/// keeps a multiplication on the left and an addition on the right.
#[test]
fn test_where_both_sides_expr() {
    let parsed = parse_query("SELECT * FROM test WHERE a * 2 > b + 1").unwrap();
    let Statement::Select(q) = parsed else {
        panic!("Expected Select");
    };
    let Some(WhereClause::Comparison(c)) = q.where_clause else {
        panic!("Expected comparison");
    };

    assert_eq!(c.op, CmpOp::Gt);
    assert!(matches!(&c.left_expr, Some(Expr::BinaryOp { op: ArithOp::Mul, .. })));
    assert!(matches!(&c.right_expr, Some(Expr::BinaryOp { op: ArithOp::Add, .. })));
}
1476
/// `ORDER BY a + b DESC` yields one descending sort key that carries the
/// addition expression.
#[test]
fn test_order_by_expr() {
    let parsed = parse_query("SELECT * FROM test ORDER BY a + b DESC").unwrap();
    let Statement::Select(q) = parsed else {
        panic!("Expected Select");
    };

    let ob = q.order_by.unwrap();
    assert_eq!(ob.len(), 1);
    assert!(ob[0].descending);
    assert!(matches!(&ob[0].expr, Some(Expr::BinaryOp { op: ArithOp::Add, .. })));
}
1489
/// Each of the five binary arithmetic operators (`+ - * / %`) maps to its own
/// `ArithOp` variant, in the order the columns were written.
#[test]
fn test_all_arithmetic_ops() {
    let parsed = parse_query("SELECT a + b, a - b, a * b, a / b, a % b FROM test").unwrap();
    let Statement::Select(q) = parsed else {
        panic!("Expected Select");
    };
    let ColumnList::Named(exprs) = &q.columns else {
        panic!("Expected Named columns");
    };
    assert_eq!(exprs.len(), 5);

    // Expected operator per select-list position, checked pairwise in order.
    let expected_ops = [
        ArithOp::Add,
        ArithOp::Sub,
        ArithOp::Mul,
        ArithOp::Div,
        ArithOp::Mod,
    ];
    for (column, expected) in exprs.iter().zip(expected_ops) {
        assert!(matches!(
            column,
            SelectExpr::Expr { expr: Expr::BinaryOp { op, .. }, .. } if *op == expected
        ));
    }
}
1508
/// `count * 2 + 1` groups as `(count * 2) + 1`: the root is an addition whose
/// right operand is literal 1 and whose left operand is the multiplication.
#[test]
fn test_column_with_literal_arithmetic() {
    let parsed = parse_query("SELECT count * 2 + 1 FROM test").unwrap();
    let Statement::Select(q) = parsed else {
        panic!("Expected Select");
    };
    let ColumnList::Named(exprs) = &q.columns else {
        panic!("Expected Named columns");
    };
    let SelectExpr::Expr { expr, .. } = &exprs[0] else {
        panic!("Expected Expr");
    };
    let Expr::BinaryOp { left, op, right } = expr else {
        panic!("Expected BinaryOp");
    };

    assert_eq!(*op, ArithOp::Add);
    assert!(matches!(right.as_ref(), Expr::Literal(SqlValue::Int(1))));
    assert!(matches!(left.as_ref(), Expr::BinaryOp { op: ArithOp::Mul, .. }));
}
1533
/// Plain columns and an aliased expression can be mixed in one SELECT list;
/// plain names stay `SelectExpr::Column`, including the word `count`.
#[test]
fn test_mixed_columns_and_exprs() {
    let parsed = parse_query("SELECT title, a + b AS sum, count FROM test").unwrap();
    let Statement::Select(q) = parsed else {
        panic!("Expected Select");
    };
    let ColumnList::Named(exprs) = &q.columns else {
        panic!("Expected Named columns");
    };

    assert_eq!(exprs.len(), 3);
    assert_eq!(exprs[0], SelectExpr::Column("title".into()));
    assert!(matches!(&exprs[1], SelectExpr::Expr { alias: Some(a), .. } if a == "sum"));
    assert_eq!(exprs[2], SelectExpr::Column("count".into()));
}
1550
/// A simple `CASE WHEN … THEN … ELSE … END` in the SELECT list parses as a
/// single expression column holding `Expr::Case`.
#[test]
fn test_case_when_basic() {
    let sql = "SELECT CASE WHEN status = 'ACTIVE' THEN 1 ELSE 0 END FROM test";
    let Statement::Select(q) = parse_query(sql).unwrap() else {
        panic!("Expected Select");
    };
    let ColumnList::Named(exprs) = &q.columns else {
        panic!("Expected Named columns");
    };

    assert_eq!(exprs.len(), 1);
    assert!(matches!(
        &exprs[0],
        SelectExpr::Expr {
            expr: Expr::Case { .. },
            ..
        }
    ));
}
1572
/// A CASE with two WHEN branches and an ELSE records both branches and the
/// else expression.
#[test]
fn test_case_when_multiple_branches() {
    let sql =
        "SELECT CASE WHEN x > 10 THEN 'high' WHEN x > 5 THEN 'mid' ELSE 'low' END FROM test";
    let Statement::Select(q) = parse_query(sql).unwrap() else {
        panic!("Expected Select");
    };
    let ColumnList::Named(exprs) = &q.columns else {
        panic!("Expected Named columns");
    };
    let SelectExpr::Expr { expr: Expr::Case { whens, else_expr }, .. } = &exprs[0] else {
        panic!("Expected Case expression");
    };

    assert_eq!(whens.len(), 2);
    assert!(else_expr.is_some());
}
1593
/// The ELSE arm of CASE is optional: with it omitted there is one WHEN branch
/// and no else expression.
#[test]
fn test_case_when_no_else() {
    let sql = "SELECT CASE WHEN x = 1 THEN 'one' END FROM test";
    let Statement::Select(q) = parse_query(sql).unwrap() else {
        panic!("Expected Select");
    };
    let ColumnList::Named(exprs) = &q.columns else {
        panic!("Expected Named columns");
    };
    let SelectExpr::Expr { expr: Expr::Case { whens, else_expr }, .. } = &exprs[0] else {
        panic!("Expected Case expression");
    };

    assert_eq!(whens.len(), 1);
    assert!(else_expr.is_none());
}
1614
/// A CASE expression may be the argument of an aggregate: `SUM(CASE …) AS net`
/// parses as a SUM aggregate whose argument expression is the CASE.
#[test]
fn test_case_when_in_aggregate() {
    let sql = "SELECT SUM(CASE WHEN side = 'BUY' THEN size ELSE -size END) AS net FROM orders GROUP BY token";
    let Statement::Select(q) = parse_query(sql).unwrap() else {
        panic!("Expected Select");
    };
    let ColumnList::Named(exprs) = &q.columns else {
        panic!("Expected Named columns");
    };

    assert_eq!(exprs.len(), 1);
    assert!(matches!(
        &exprs[0],
        SelectExpr::Aggregate {
            func: AggFunc::Sum,
            arg_expr: Some(Expr::Case { .. }),
            alias: Some(a),
            ..
        } if a == "net"
    ));
}
1636
/// `CASE … END AS sign` keeps the alias on the expression column.
#[test]
fn test_case_when_with_alias() {
    let sql = "SELECT CASE WHEN x > 0 THEN 'pos' ELSE 'neg' END AS sign FROM test";
    let Statement::Select(q) = parse_query(sql).unwrap() else {
        panic!("Expected Select");
    };
    let ColumnList::Named(exprs) = &q.columns else {
        panic!("Expected Named columns");
    };

    assert!(matches!(
        &exprs[0],
        SelectExpr::Expr {
            expr: Expr::Case { .. },
            alias: Some(a),
        } if a == "sign"
    ));
}
1655
/// `CREATE VIEW <name> AS SELECT …` without a column list stores the view
/// name, no explicit columns, and the inner query with its WHERE clause.
#[test]
fn test_create_view() {
    let sql = "CREATE VIEW live AS SELECT * FROM strategies WHERE status = 'LIVE'";
    match parse_query(sql).unwrap() {
        Statement::CreateView(cv) => {
            assert_eq!(cv.view_name, "live");
            assert!(cv.columns.is_none());
            assert_eq!(cv.query.table, "strategies");
            assert!(cv.query.where_clause.is_some());
        }
        other => panic!("Expected CreateView, got {:?}", other),
    }
}
1668
/// An explicit column list after the view name, `CREATE VIEW v1 (a, b) AS …`,
/// is captured in written order.
#[test]
fn test_create_view_with_columns() {
    let parsed = parse_query("CREATE VIEW v1 (a, b) AS SELECT title, status FROM t").unwrap();
    let Statement::CreateView(cv) = parsed else {
        panic!("Expected CreateView");
    };
    assert_eq!(cv.view_name, "v1");
    assert_eq!(cv.columns, Some(vec!["a".into(), "b".into()]));
}
1679
/// `DROP VIEW <name>` parses into `Statement::DropView` with the view name.
#[test]
fn test_drop_view() {
    match parse_query("DROP VIEW live").unwrap() {
        Statement::DropView(dv) => assert_eq!(dv.view_name, "live"),
        other => panic!("Expected DropView, got {:?}", other),
    }
}
1689
/// Keywords are matched case-insensitively, while the view name keeps the
/// case it was written in.
#[test]
fn test_create_view_case_insensitive() {
    let parsed = parse_query("create view My_View as select * from t").unwrap();
    let Statement::CreateView(cv) = parsed else {
        panic!("Expected CreateView");
    };
    assert_eq!(cv.view_name, "My_View");
}
1699
/// Dividing one aggregate by another, `SUM(sell) / SUM(buy) as ratio`, parses
/// as a single select item that reports itself as an aggregate.
#[test]
fn test_aggregate_division() {
    let sql = "SELECT token, SUM(sell) / SUM(buy) as ratio FROM orders GROUP BY token";
    let Statement::Select(q) = parse_query(sql).unwrap() else {
        panic!("Expected Select");
    };

    assert_eq!(q.group_by, Some(vec!["token".into()]));
    let ColumnList::Named(exprs) = &q.columns else {
        panic!("Expected Named columns");
    };
    assert_eq!(exprs.len(), 2);
    assert!(exprs[1].is_aggregate());
}
1719
/// Subtracting one aggregate from another with a lowercase `as` alias,
/// `SUM(sell) - SUM(buy) as net`, must expose `net` as the output name.
#[test]
fn test_aggregate_subtraction() {
    let stmt = parse_query(
        "SELECT token, SUM(sell) - SUM(buy) as net FROM orders GROUP BY token"
    ).unwrap();
    let Statement::Select(q) = stmt else {
        panic!("Expected Select");
    };
    // Without this failure branch the test passed vacuously whenever the
    // column list was not `Named`, because no assertion ever ran.
    let ColumnList::Named(exprs) = &q.columns else {
        panic!("Expected Named columns");
    };
    // Check the length before indexing so a short list fails with a clear
    // message instead of an out-of-bounds panic.
    assert_eq!(exprs.len(), 2);
    assert_eq!(exprs[1].output_name(), "net");
}
1733
/// A view whose SELECT contains arithmetic between aggregates still parses as
/// `Statement::CreateView` with the right name.
#[test]
fn test_create_view_with_arithmetic() {
    let sql = "CREATE VIEW positions AS SELECT token, SUM(sell) / SUM(buy) as ratio FROM orders GROUP BY token";
    match parse_query(sql).unwrap() {
        Statement::CreateView(cv) => assert_eq!(cv.view_name, "positions"),
        other => panic!("Expected CreateView, got {:?}", other),
    }
}
1745
/// A parenthesized SELECT in the FROM position is stored as the outer query's
/// subquery, and the trailing `LIMIT` belongs to the outer query.
#[test]
fn test_subquery_in_from() {
    let sql = "SELECT token, sell_size FROM (SELECT token, SUM(size) as sell_size FROM orders GROUP BY token) LIMIT 5";
    let Statement::Select(q) = parse_query(sql).unwrap() else {
        panic!("Expected Select");
    };

    assert!(q.subquery.is_some());
    assert_eq!(q.limit, Some(5));

    let sub = q.subquery.unwrap();
    assert_eq!(sub.table, "orders");
    assert!(sub.group_by.is_some());
}
1763
/// A `HAVING` clause inside a view's SELECT is kept on the view's inner query.
#[test]
fn test_create_view_with_having() {
    let sql = "CREATE VIEW positions AS SELECT token, SUM(sell) as sell_size, SUM(buy) as buy_size FROM orders GROUP BY token HAVING sell_size > buy_size";
    match parse_query(sql).unwrap() {
        Statement::CreateView(cv) => {
            assert_eq!(cv.view_name, "positions");
            assert!(cv.query.having.is_some());
        }
        other => panic!("Expected CreateView, got {:?}", other),
    }
}
1778
/// An aggregate multiplied by a literal, `SUM(a) * 2 as doubled`, is one
/// select item that is still an aggregate and is named by its alias.
#[test]
fn test_aggregate_multiplication() {
    let parsed = parse_query("SELECT SUM(a) * 2 as doubled FROM test").unwrap();
    let Statement::Select(q) = parsed else {
        panic!("Expected Select");
    };
    let ColumnList::Named(exprs) = &q.columns else {
        panic!("Expected Named columns");
    };

    assert_eq!(exprs.len(), 1);
    assert!(exprs[0].is_aggregate());
    assert_eq!(exprs[0].output_name(), "doubled");
}
1798
/// A ratio of two `SUM(CASE …)` aggregates parses as a single aliased
/// aggregate item, with GROUP BY captured alongside it.
#[test]
fn test_complex_aggregate_arithmetic() {
    let sql = "SELECT SUM(CASE WHEN side = 'SELL' THEN size ELSE 0 END) / SUM(CASE WHEN side = 'BUY' THEN size ELSE 0 END) as ratio FROM orders GROUP BY token";
    let Statement::Select(q) = parse_query(sql).unwrap() else {
        panic!("Expected Select");
    };
    let ColumnList::Named(exprs) = &q.columns else {
        panic!("Expected Named columns");
    };

    assert_eq!(exprs.len(), 1);
    assert!(exprs[0].is_aggregate());
    assert_eq!(exprs[0].output_name(), "ratio");
    assert_eq!(q.group_by, Some(vec!["token".into()]));
}
1817
/// A FROM-subquery followed by an alias parses; the inner query reads from
/// `t` and the outer SELECT list holds the single column `x`.
#[test]
fn test_subquery_with_alias() {
    let parsed = parse_query("SELECT x FROM (SELECT x FROM t) sub").unwrap();
    let Statement::Select(q) = parsed else {
        panic!("Expected Select");
    };

    assert!(q.subquery.is_some());
    let sub = q.subquery.unwrap();
    assert_eq!(sub.table, "t");

    let ColumnList::Named(exprs) = &q.columns else {
        panic!("Expected Named columns");
    };
    assert_eq!(exprs.len(), 1);
    assert_eq!(exprs[0].output_name(), "x");
}
1839
/// A WHERE clause inside a FROM-subquery stays on the inner query, while the
/// trailing `LIMIT` is applied to the outer one.
#[test]
fn test_subquery_with_where() {
    let parsed = parse_query("SELECT x FROM (SELECT x FROM t WHERE y > 0) LIMIT 5").unwrap();
    let Statement::Select(q) = parsed else {
        panic!("Expected Select");
    };

    assert!(q.subquery.is_some());
    assert_eq!(q.limit, Some(5));

    let sub = q.subquery.unwrap();
    assert_eq!(sub.table, "t");
    assert!(sub.where_clause.is_some());
}
1855
/// A view over `SUM(sell) - SUM(buy) as net` keeps the GROUP BY and exposes
/// the difference as an aggregate column named `net`.
#[test]
fn test_create_view_aggregate_subtraction() {
    let sql =
        "CREATE VIEW v AS SELECT token, SUM(sell) - SUM(buy) as net FROM orders GROUP BY token";
    match parse_query(sql).unwrap() {
        Statement::CreateView(cv) => {
            assert_eq!(cv.view_name, "v");
            assert_eq!(cv.query.group_by, Some(vec!["token".into()]));
            let ColumnList::Named(exprs) = &cv.query.columns else {
                panic!("Expected Named columns");
            };
            assert_eq!(exprs.len(), 2);
            assert_eq!(exprs[1].output_name(), "net");
            assert!(exprs[1].is_aggregate());
        }
        other => panic!("Expected CreateView, got {:?}", other),
    }
}
1877}