1use regex::Regex;
4use std::sync::LazyLock;
5
6use crate::errors::MdqlError;
7pub use crate::query_ast::*;
8
/// Words the tokenizer classifies as `keyword` tokens (compared case-insensitively).
///
/// Any word that is not listed here becomes an `ident` token instead.
static KEYWORDS: &[&str] = &[
    // Core SELECT clauses
    "SELECT", "FROM", "WHERE", "AND", "OR", "ORDER", "BY", "ASC", "DESC", "LIMIT",
    // Predicates
    "LIKE", "IN", "IS", "NOT", "NULL",
    // Joins, aliases and grouping
    "JOIN", "LEFT", "ON", "AS", "GROUP", "HAVING",
    // Data modification
    "INSERT", "INTO", "VALUES", "UPDATE", "SET", "DELETE",
    // Schema changes
    "ALTER", "TABLE", "RENAME", "FIELD", "TO", "DROP", "MERGE", "FIELDS",
    // Conditional expressions
    "CASE", "WHEN", "THEN", "ELSE", "END",
    // Date arithmetic
    "INTERVAL", "DAY", "DAYS", "CURRENT_DATE", "CURRENT_TIMESTAMP", "DATEDIFF",
    // Views and delete modes
    "CREATE", "VIEW", "CASCADE", "RESTRICT",
    // Common table expressions
    "WITH",
    // Window functions
    "OVER", "PARTITION", "ROW_NUMBER", "RANK", "DENSE_RANK", "LAG", "LEAD",
];
23
/// Aggregate function names; a word is treated as one (case-insensitively)
/// only when the next token is `(`.
static AGG_FUNCS: &[&str] = &[
    "COUNT", "SUM", "AVG", "MIN", "MAX",
];

/// Window-only function names; a call must be followed by `(` and then by an
/// `OVER (...)` clause.
static WINDOW_FUNCS: &[&str] = &[
    "ROW_NUMBER", "RANK", "DENSE_RANK", "LAG", "LEAD",
];
26
27static TOKEN_RE: LazyLock<Regex> = LazyLock::new(|| {
28 Regex::new(
29 r#"(?x)
30 \s*(?:
31 (?P<backtick>`[^`]+`)
32 | (?P<string>'(?:[^'\\]|\\.)*')
33 | (?P<number>\d+(?:\.\d+)?)
34 | (?P<op><=|>=|!=|[=<>,*()+\-/%])
35 | (?P<word>[A-Za-z_][A-Za-z0-9_./-]*)
36 )"#,
37 )
38 .unwrap()
39});
40
/// One lexical token produced by the tokenizer.
#[derive(Clone, Debug)]
struct Token {
    /// Token class: `"ident"`, `"string"`, `"number"`, `"op"` or `"keyword"`.
    token_type: String,
    /// Normalised text: surrounding quotes/backticks removed, keywords upper-cased.
    value: String,
    /// The exact text as written in the query (e.g. for error messages).
    raw: String,
}
47
48fn tokenize(sql: &str) -> Vec<Token> {
49 let mut tokens = Vec::new();
50 for caps in TOKEN_RE.captures_iter(sql) {
51 if let Some(m) = caps.name("backtick") {
52 let raw = m.as_str();
53 tokens.push(Token {
54 token_type: "ident".into(),
55 value: raw[1..raw.len() - 1].into(),
56 raw: raw.into(),
57 });
58 } else if let Some(m) = caps.name("string") {
59 let raw = m.as_str();
60 tokens.push(Token {
61 token_type: "string".into(),
62 value: raw[1..raw.len() - 1].into(),
63 raw: raw.into(),
64 });
65 } else if let Some(m) = caps.name("number") {
66 let raw = m.as_str();
67 tokens.push(Token {
68 token_type: "number".into(),
69 value: raw.into(),
70 raw: raw.into(),
71 });
72 } else if let Some(m) = caps.name("op") {
73 let raw = m.as_str();
74 tokens.push(Token {
75 token_type: "op".into(),
76 value: raw.into(),
77 raw: raw.into(),
78 });
79 } else if let Some(m) = caps.name("word") {
80 let raw = m.as_str();
81 if KEYWORDS.contains(&raw.to_uppercase().as_str()) {
82 tokens.push(Token {
83 token_type: "keyword".into(),
84 value: raw.to_uppercase(),
85 raw: raw.into(),
86 });
87 } else {
88 tokens.push(Token {
89 token_type: "ident".into(),
90 value: raw.into(),
91 raw: raw.into(),
92 });
93 }
94 }
95 }
96 tokens
97}
98
99struct Parser {
102 tokens: Vec<Token>,
103 pos: usize,
104}
105
106impl Parser {
107 fn new(tokens: Vec<Token>) -> Self {
108 Parser { tokens, pos: 0 }
109 }
110
111 fn peek(&self) -> Option<&Token> {
112 self.tokens.get(self.pos)
113 }
114
115 fn advance(&mut self) -> Token {
116 let t = self.tokens[self.pos].clone();
117 self.pos += 1;
118 t
119 }
120
121 fn expect(&mut self, type_: &str, value: Option<&str>) -> Result<Token, MdqlError> {
122 let t = self.peek().ok_or_else(|| {
123 MdqlError::QueryParse(format!(
124 "Unexpected end of query, expected {}",
125 value.unwrap_or(type_)
126 ))
127 })?;
128 let matches_type = t.token_type == type_;
129 let matches_value = value.map_or(true, |v| t.value == v);
130 if !matches_type || !matches_value {
131 return Err(MdqlError::QueryParse(format!(
132 "Expected {}, got '{}' at position {}",
133 value.unwrap_or(type_),
134 t.raw,
135 self.pos
136 )));
137 }
138 Ok(self.advance())
139 }
140
141 fn match_keyword(&mut self, kw: &str) -> bool {
142 if let Some(t) = self.peek() {
143 if t.token_type == "keyword" && t.value == kw {
144 self.advance();
145 return true;
146 }
147 }
148 false
149 }
150
151 fn parse_statement(&mut self) -> Result<Statement, MdqlError> {
152 let t = self.peek().ok_or_else(|| MdqlError::QueryParse("Empty query".into()))?;
153 match (t.token_type.as_str(), t.value.as_str()) {
154 ("keyword", "WITH") => {
155 let ctes = self.parse_ctes()?;
156 let mut q = self.parse_select()?;
157 q.ctes = ctes;
158 self.expect_end()?;
159 Ok(Statement::Select(q))
160 }
161 ("keyword", "SELECT") => {
162 let q = self.parse_select()?;
163 self.expect_end()?;
164 Ok(Statement::Select(q))
165 }
166 ("keyword", "INSERT") => Ok(Statement::Insert(self.parse_insert()?)),
167 ("keyword", "UPDATE") => Ok(Statement::Update(self.parse_update()?)),
168 ("keyword", "DELETE") => Ok(Statement::Delete(self.parse_delete()?)),
169 ("keyword", "ALTER") => self.parse_alter(),
170 ("keyword", "CREATE") => self.parse_create_view(),
171 ("keyword", "DROP") => self.parse_drop_view(),
172 _ => Err(MdqlError::QueryParse(format!(
173 "Expected SELECT, INSERT, UPDATE, DELETE, ALTER, CREATE, or DROP, got '{}'",
174 t.raw
175 ))),
176 }
177 }
178
179 fn parse_ctes(&mut self) -> Result<Vec<CteClause>, MdqlError> {
180 self.expect("keyword", Some("WITH"))?;
181 let mut ctes = Vec::new();
182 loop {
183 let name = self.parse_ident()?;
184 self.expect("keyword", Some("AS"))?;
185 self.expect("op", Some("("))?;
186 let query = self.parse_select()?;
187 self.expect("op", Some(")"))?;
188 ctes.push(CteClause { name, query: Box::new(query) });
189 if !self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
190 break;
191 }
192 self.advance();
193 }
194 Ok(ctes)
195 }
196
197 fn parse_select(&mut self) -> Result<SelectQuery, MdqlError> {
198 self.expect("keyword", Some("SELECT"))?;
199 let columns = self.parse_columns()?;
200 self.expect("keyword", Some("FROM"))?;
201
202 let mut subquery = None;
204 let (table, mut table_alias) = if self.peek().map_or(false, |t| t.token_type == "op" && t.value == "(") {
205 self.advance();
206 let inner = self.parse_select()?;
207 self.expect("op", Some(")"))?;
208 subquery = Some(Box::new(inner));
209 let alias = if let Some(t) = self.peek() {
210 if t.token_type == "ident" && !self.is_clause_keyword(t) {
211 Some(self.advance().value)
212 } else {
213 None
214 }
215 } else {
216 None
217 };
218 ("_subquery".to_string(), alias)
219 } else {
220 let t = self.parse_ident()?;
221 (t, None)
222 };
223
224 if subquery.is_none() {
226 if let Some(t) = self.peek() {
227 if t.token_type == "ident" && !self.is_clause_keyword(t) {
228 table_alias = Some(self.advance().value);
229 }
230 }
231 }
232
233 let mut joins = Vec::new();
235 loop {
236 let jt = if self.match_keyword("LEFT") {
237 self.expect("keyword", Some("JOIN"))?;
238 JoinType::Left
239 } else if self.match_keyword("JOIN") {
240 JoinType::Inner
241 } else {
242 break;
243 };
244 let join_table = self.parse_ident()?;
245 let mut join_alias = None;
246 if let Some(t) = self.peek() {
247 if t.token_type == "ident" && !self.is_clause_keyword(t) {
248 join_alias = Some(self.advance().value);
249 }
250 }
251 self.expect("keyword", Some("ON"))?;
252 let condition = self.parse_or_expr()?;
253 joins.push(JoinClause {
254 join_type: jt,
255 table: join_table,
256 alias: join_alias,
257 condition,
258 });
259 }
260
261 let mut where_clause = None;
262 if self.match_keyword("WHERE") {
263 where_clause = Some(self.parse_or_expr()?);
264 }
265
266 let mut group_by = None;
267 if self.match_keyword("GROUP") {
268 self.expect("keyword", Some("BY"))?;
269 let mut cols = vec![self.parse_ident()?];
270 while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
271 self.advance();
272 cols.push(self.parse_ident()?);
273 }
274 group_by = Some(cols);
275 }
276
277 let mut having = None;
278 if self.match_keyword("HAVING") {
279 having = Some(self.parse_or_expr()?);
280 }
281
282 let mut order_by = None;
283 if self.match_keyword("ORDER") {
284 self.expect("keyword", Some("BY"))?;
285 order_by = Some(self.parse_order_by()?);
286 }
287
288 let mut limit = None;
289 if self.match_keyword("LIMIT") {
290 let t = self.expect("number", None)?;
291 limit = Some(t.value.parse::<i64>().map_err(|_| {
292 MdqlError::QueryParse(format!("Invalid LIMIT value: {}", t.value))
293 })?);
294 }
295
296 Ok(SelectQuery {
297 columns,
298 table,
299 table_alias,
300 subquery,
301 joins,
302 where_clause,
303 group_by,
304 having,
305 order_by,
306 limit,
307 ctes: vec![],
308 })
309 }
310
311 fn parse_insert(&mut self) -> Result<InsertQuery, MdqlError> {
312 self.expect("keyword", Some("INSERT"))?;
313 self.expect("keyword", Some("INTO"))?;
314 let table = self.parse_ident()?;
315
316 self.expect("op", Some("("))?;
317 let mut columns = vec![self.parse_ident()?];
318 while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
319 self.advance();
320 columns.push(self.parse_ident()?);
321 }
322 self.expect("op", Some(")"))?;
323
324 self.expect("keyword", Some("VALUES"))?;
325
326 self.expect("op", Some("("))?;
327 let mut values = vec![self.parse_value()?];
328 while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
329 self.advance();
330 values.push(self.parse_value()?);
331 }
332 self.expect("op", Some(")"))?;
333
334 if columns.len() != values.len() {
335 return Err(MdqlError::QueryParse(format!(
336 "Column count ({}) does not match value count ({})",
337 columns.len(),
338 values.len()
339 )));
340 }
341
342 self.expect_end()?;
343 Ok(InsertQuery {
344 table,
345 columns,
346 values,
347 })
348 }
349
350 fn parse_update(&mut self) -> Result<UpdateQuery, MdqlError> {
351 self.expect("keyword", Some("UPDATE"))?;
352 let table = self.parse_ident()?;
353 self.expect("keyword", Some("SET"))?;
354
355 let mut assignments = Vec::new();
356 let col = self.parse_ident()?;
357 self.expect("op", Some("="))?;
358 let val = self.parse_value()?;
359 assignments.push((col, val));
360
361 while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
362 self.advance();
363 let col = self.parse_ident()?;
364 self.expect("op", Some("="))?;
365 let val = self.parse_value()?;
366 assignments.push((col, val));
367 }
368
369 let mut where_clause = None;
370 if self.match_keyword("WHERE") {
371 where_clause = Some(self.parse_or_expr()?);
372 }
373
374 self.expect_end()?;
375 Ok(UpdateQuery {
376 table,
377 assignments,
378 where_clause,
379 })
380 }
381
382 fn parse_delete(&mut self) -> Result<DeleteQuery, MdqlError> {
383 self.expect("keyword", Some("DELETE"))?;
384 self.expect("keyword", Some("FROM"))?;
385 let table = self.parse_ident()?;
386
387 let mut where_clause = None;
388 if self.match_keyword("WHERE") {
389 where_clause = Some(self.parse_or_expr()?);
390 }
391
392 let mode = if self.match_keyword("CASCADE") {
393 DeleteMode::Cascade
394 } else if self.match_keyword("RESTRICT") {
395 DeleteMode::Restrict
396 } else {
397 DeleteMode::Default
398 };
399
400 self.expect_end()?;
401 Ok(DeleteQuery {
402 table,
403 where_clause,
404 mode,
405 })
406 }
407
408 fn parse_alter(&mut self) -> Result<Statement, MdqlError> {
409 self.expect("keyword", Some("ALTER"))?;
410 self.expect("keyword", Some("TABLE"))?;
411 let table = self.parse_ident()?;
412
413 let t = self.peek().ok_or_else(|| {
414 MdqlError::QueryParse("Expected RENAME, DROP, or MERGE after table name".into())
415 })?;
416
417 match (t.token_type.as_str(), t.value.as_str()) {
418 ("keyword", "RENAME") => {
419 self.advance();
420 self.expect("keyword", Some("FIELD"))?;
421 let old_name = self.parse_string_or_ident()?;
422 self.expect("keyword", Some("TO"))?;
423 let new_name = self.parse_string_or_ident()?;
424 self.expect_end()?;
425 Ok(Statement::AlterRename(AlterRenameFieldQuery {
426 table,
427 old_name,
428 new_name,
429 }))
430 }
431 ("keyword", "DROP") => {
432 self.advance();
433 self.expect("keyword", Some("FIELD"))?;
434 let field_name = self.parse_string_or_ident()?;
435 self.expect_end()?;
436 Ok(Statement::AlterDrop(AlterDropFieldQuery {
437 table,
438 field_name,
439 }))
440 }
441 ("keyword", "MERGE") => {
442 self.advance();
443 self.expect("keyword", Some("FIELDS"))?;
444 let mut sources = vec![self.parse_string_or_ident()?];
445 while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
446 self.advance();
447 sources.push(self.parse_string_or_ident()?);
448 }
449 self.expect("keyword", Some("INTO"))?;
450 let target = self.parse_string_or_ident()?;
451 self.expect_end()?;
452 Ok(Statement::AlterMerge(AlterMergeFieldsQuery {
453 table,
454 sources,
455 into: target,
456 }))
457 }
458 _ => Err(MdqlError::QueryParse(format!(
459 "Expected RENAME, DROP, or MERGE, got '{}'",
460 t.raw
461 ))),
462 }
463 }
464
465 fn parse_create_view(&mut self) -> Result<Statement, MdqlError> {
466 self.expect("keyword", Some("CREATE"))?;
467 self.expect("keyword", Some("VIEW"))?;
468 let view_name = self.parse_ident()?;
469
470 let columns = if self.peek().map_or(false, |t| t.token_type == "op" && t.value == "(") {
471 self.advance();
472 let mut cols = vec![self.parse_ident()?];
473 while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
474 self.advance();
475 cols.push(self.parse_ident()?);
476 }
477 self.expect("op", Some(")"))?;
478 Some(cols)
479 } else {
480 None
481 };
482
483 self.expect("keyword", Some("AS"))?;
484 let query = Box::new(self.parse_select()?);
485 self.expect_end()?;
486
487 Ok(Statement::CreateView(CreateViewQuery {
488 view_name,
489 columns,
490 query,
491 }))
492 }
493
494 fn parse_drop_view(&mut self) -> Result<Statement, MdqlError> {
495 self.expect("keyword", Some("DROP"))?;
496 self.expect("keyword", Some("VIEW"))?;
497 let view_name = self.parse_ident()?;
498 self.expect_end()?;
499 Ok(Statement::DropView(DropViewQuery { view_name }))
500 }
501
502 fn parse_string_or_ident(&mut self) -> Result<String, MdqlError> {
503 let t = self.peek().ok_or_else(|| {
504 MdqlError::QueryParse("Expected field name, got end of query".into())
505 })?;
506 match t.token_type.as_str() {
507 "string" => {
508 let v = self.advance().value;
509 Ok(v)
510 }
511 "ident" | "keyword" => {
512 let v = self.advance().value;
513 Ok(v)
514 }
515 _ => Err(MdqlError::QueryParse(format!(
516 "Expected field name, got '{}'",
517 t.raw
518 ))),
519 }
520 }
521
522 fn parse_columns(&mut self) -> Result<ColumnList, MdqlError> {
523 if let Some(t) = self.peek() {
524 if t.token_type == "op" && t.value == "*" {
525 self.advance();
526 return Ok(ColumnList::All);
527 }
528 }
529
530 let mut exprs = vec![self.parse_select_expr()?];
531 while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
532 self.advance();
533 exprs.push(self.parse_select_expr()?);
534 }
535 Ok(ColumnList::Named(exprs))
536 }
537
538 fn peek_is_window_func(&self) -> bool {
539 let t = match self.peek() {
540 Some(t) => t,
541 None => return false,
542 };
543 let name_upper = t.value.to_uppercase();
544 if !WINDOW_FUNCS.contains(&name_upper.as_str()) {
545 return false;
546 }
547 self.tokens
548 .get(self.pos + 1)
549 .map_or(false, |next| next.token_type == "op" && next.value == "(")
550 }
551
552 fn peek_is_agg_func(&self) -> bool {
553 let t = match self.peek() {
554 Some(t) => t,
555 None => return false,
556 };
557 let name_upper = t.value.to_uppercase();
558 if !AGG_FUNCS.contains(&name_upper.as_str()) {
559 return false;
560 }
561 self.tokens
563 .get(self.pos + 1)
564 .map_or(false, |next| next.token_type == "op" && next.value == "(")
565 }
566
567 fn parse_select_expr(&mut self) -> Result<SelectExpr, MdqlError> {
568 let _t = self.peek().ok_or_else(|| {
569 MdqlError::QueryParse("Expected column or aggregate, got end of query".into())
570 })?;
571
572 let expr = self.parse_additive()?;
573
574 let alias = if self.match_keyword("AS") {
575 Some(self.parse_ident()?)
576 } else if self.peek().map_or(false, |t| {
577 t.token_type == "ident" && !self.is_clause_keyword(t)
578 }) {
579 Some(self.advance().value)
580 } else {
581 None
582 };
583
584 if let Expr::Aggregate { func, arg, arg_expr } = expr {
586 return Ok(SelectExpr::Aggregate {
587 func,
588 arg,
589 arg_expr: arg_expr.map(|e| *e),
590 alias,
591 });
592 }
593
594 if alias.is_none() {
595 if let Expr::Column(name) = &expr {
596 return Ok(SelectExpr::Column(name.clone()));
597 }
598 }
599
600 Ok(SelectExpr::Expr { expr, alias })
601 }
602
603 fn peek_is_additive_op(&self) -> bool {
606 self.peek().map_or(false, |t| {
607 t.token_type == "op" && (t.value == "+" || t.value == "-")
608 })
609 }
610
611 fn peek_is_multiplicative_op(&self) -> bool {
612 self.peek().map_or(false, |t| {
613 t.token_type == "op" && (t.value == "*" || t.value == "/" || t.value == "%")
614 })
615 }
616
617 fn parse_additive(&mut self) -> Result<Expr, MdqlError> {
618 let mut left = self.parse_multiplicative()?;
619 while self.peek_is_additive_op() {
620 let op_tok = self.advance();
621 let is_sub = op_tok.value == "-";
622
623 if self.peek().map_or(false, |t| t.token_type == "keyword" && t.value == "INTERVAL") {
625 self.advance(); let days_expr = self.parse_multiplicative()?;
627 if !self.match_keyword("DAY") && !self.match_keyword("DAYS") {
629 return Err(MdqlError::QueryParse("Expected DAY after INTERVAL value".into()));
630 }
631 let days = if is_sub {
632 Expr::UnaryMinus(Box::new(days_expr))
633 } else {
634 days_expr
635 };
636 left = Expr::DateAdd {
637 date: Box::new(left),
638 days: Box::new(days),
639 };
640 continue;
641 }
642
643 let op = match op_tok.value.as_str() {
644 "+" => ArithOp::Add,
645 "-" => ArithOp::Sub,
646 _ => unreachable!(),
647 };
648 let right = self.parse_multiplicative()?;
649 left = Expr::BinaryOp {
650 left: Box::new(left),
651 op,
652 right: Box::new(right),
653 };
654 }
655 Ok(left)
656 }
657
658 fn parse_multiplicative(&mut self) -> Result<Expr, MdqlError> {
659 let mut left = self.parse_unary()?;
660 while self.peek_is_multiplicative_op() {
661 let op_tok = self.advance();
662 let op = match op_tok.value.as_str() {
663 "*" => ArithOp::Mul,
664 "/" => ArithOp::Div,
665 "%" => ArithOp::Mod,
666 _ => unreachable!(),
667 };
668 let right = self.parse_unary()?;
669 left = Expr::BinaryOp {
670 left: Box::new(left),
671 op,
672 right: Box::new(right),
673 };
674 }
675 Ok(left)
676 }
677
678 fn parse_unary(&mut self) -> Result<Expr, MdqlError> {
679 if self.peek().map_or(false, |t| t.token_type == "op" && t.value == "-") {
680 self.advance();
681 let inner = self.parse_atom()?;
682 match inner {
684 Expr::Literal(SqlValue::Int(n)) => Ok(Expr::Literal(SqlValue::Int(-n))),
685 Expr::Literal(SqlValue::Float(f)) => Ok(Expr::Literal(SqlValue::Float(-f))),
686 _ => Ok(Expr::UnaryMinus(Box::new(inner))),
687 }
688 } else {
689 self.parse_atom()
690 }
691 }
692
693 fn parse_atom(&mut self) -> Result<Expr, MdqlError> {
694 if self.peek_is_window_func() {
695 return self.parse_standalone_window();
696 }
697
698 if self.peek_is_agg_func() {
699 let agg = self.parse_agg_expr()?;
700 if self.peek().map_or(false, |t| t.token_type == "keyword" && t.value == "OVER") {
701 self.advance();
702 let over = self.parse_window_spec()?;
703 if let Expr::Aggregate { func, arg, arg_expr } = agg {
704 return Ok(Expr::Window {
705 func: WindowFunc::Agg(func),
706 args: if arg == "*" {
707 vec![]
708 } else {
709 vec![arg_expr.map(|e| *e).unwrap_or(Expr::Column(arg))]
710 },
711 over,
712 });
713 }
714 }
715 return Ok(agg);
716 }
717
718 let t = self.peek().ok_or_else(|| {
719 MdqlError::QueryParse("Expected expression, got end of query".into())
720 })?;
721
722 match t.token_type.as_str() {
723 "number" => {
724 let v = self.advance().value;
725 if v.contains('.') {
726 let f: f64 = v.parse().map_err(|_| {
727 MdqlError::QueryParse(format!("Invalid float: {}", v))
728 })?;
729 Ok(Expr::Literal(SqlValue::Float(f)))
730 } else {
731 let n: i64 = v.parse().map_err(|_| {
732 MdqlError::QueryParse(format!("Invalid int: {}", v))
733 })?;
734 Ok(Expr::Literal(SqlValue::Int(n)))
735 }
736 }
737 "string" => {
738 let v = self.advance().value;
739 Ok(Expr::Literal(SqlValue::String(v)))
740 }
741 "keyword" if t.value == "NULL" => {
742 self.advance();
743 Ok(Expr::Literal(SqlValue::Null))
744 }
745 "keyword" if t.value == "CASE" => {
746 self.parse_case_expr()
747 }
748 "keyword" if t.value == "CURRENT_DATE" => {
749 self.advance();
750 Ok(Expr::CurrentDate)
751 }
752 "keyword" if t.value == "CURRENT_TIMESTAMP" => {
753 self.advance();
754 Ok(Expr::CurrentTimestamp)
755 }
756 "keyword" if t.value == "DATEDIFF" => {
757 self.advance();
758 self.expect("op", Some("("))?;
759 let left = self.parse_additive()?;
760 self.expect("op", Some(","))?;
761 let right = self.parse_additive()?;
762 self.expect("op", Some(")"))?;
763 Ok(Expr::DateDiff { left: Box::new(left), right: Box::new(right) })
764 }
765 "op" if t.value == "(" => {
766 let next_is_select = self.tokens.get(self.pos + 1)
767 .map_or(false, |t| t.token_type == "keyword" && t.value == "SELECT");
768 if next_is_select {
769 self.advance();
770 let sq = self.parse_select()?;
771 self.expect("op", Some(")"))?;
772 Ok(Expr::Subquery(Box::new(sq)))
773 } else {
774 self.advance();
775 let expr = self.parse_additive()?;
776 self.expect("op", Some(")"))?;
777 Ok(expr)
778 }
779 }
780 "ident" if t.value.eq_ignore_ascii_case("true") => {
783 self.advance();
784 Ok(Expr::Literal(SqlValue::Bool(true)))
785 }
786 "ident" if t.value.eq_ignore_ascii_case("false") => {
787 self.advance();
788 Ok(Expr::Literal(SqlValue::Bool(false)))
789 }
790 "ident" => {
791 let name = self.advance().value;
792 Ok(Expr::Column(name))
793 }
794 "keyword" if !Self::is_reserved_keyword(&t.value) => {
795 let name = self.advance().value;
796 Ok(Expr::Column(name))
797 }
798 _ => Err(MdqlError::QueryParse(format!(
799 "Expected expression, got '{}'",
800 t.raw
801 ))),
802 }
803 }
804
805 fn parse_case_expr(&mut self) -> Result<Expr, MdqlError> {
806 self.expect("keyword", Some("CASE"))?;
807 let mut whens = Vec::new();
808 while self.match_keyword("WHEN") {
809 let condition = self.parse_or_expr()?;
810 self.expect("keyword", Some("THEN"))?;
811 let result = self.parse_additive()?;
812 whens.push((condition, Box::new(result)));
813 }
814 if whens.is_empty() {
815 return Err(MdqlError::QueryParse("CASE requires at least one WHEN clause".into()));
816 }
817 let else_expr = if self.match_keyword("ELSE") {
818 Some(Box::new(self.parse_additive()?))
819 } else {
820 None
821 };
822 self.expect("keyword", Some("END"))?;
823 Ok(Expr::Case { whens, else_expr })
824 }
825
826 fn parse_agg_expr(&mut self) -> Result<Expr, MdqlError> {
827 let func_name = self.advance().value.to_uppercase();
828 let func = match func_name.as_str() {
829 "COUNT" => AggFunc::Count,
830 "SUM" => AggFunc::Sum,
831 "AVG" => AggFunc::Avg,
832 "MIN" => AggFunc::Min,
833 "MAX" => AggFunc::Max,
834 _ => unreachable!(),
835 };
836 self.expect("op", Some("("))?;
837 let (arg, arg_expr) = if self.peek().map_or(false, |t| t.token_type == "op" && t.value == "*") {
838 self.advance();
839 ("*".to_string(), None)
840 } else {
841 let expr = self.parse_additive()?;
842 if let Expr::Column(name) = &expr {
843 (name.clone(), None)
844 } else {
845 (expr.display_name(), Some(Box::new(expr)))
846 }
847 };
848 self.expect("op", Some(")"))?;
849 Ok(Expr::Aggregate { func, arg, arg_expr })
850 }
851
852 fn parse_standalone_window(&mut self) -> Result<Expr, MdqlError> {
853 let func_name = self.advance().value.to_uppercase();
854 let func = match func_name.as_str() {
855 "ROW_NUMBER" => WindowFunc::RowNumber,
856 "RANK" => WindowFunc::Rank,
857 "DENSE_RANK" => WindowFunc::DenseRank,
858 "LAG" => WindowFunc::Lag,
859 "LEAD" => WindowFunc::Lead,
860 _ => unreachable!(),
861 };
862 self.expect("op", Some("("))?;
863 let mut args = Vec::new();
864 if !self.peek().map_or(false, |t| t.token_type == "op" && t.value == ")") {
865 args.push(self.parse_additive()?);
866 while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
867 self.advance();
868 args.push(self.parse_additive()?);
869 }
870 }
871 self.expect("op", Some(")"))?;
872 self.expect("keyword", Some("OVER"))?;
873 let over = self.parse_window_spec()?;
874 Ok(Expr::Window { func, args, over })
875 }
876
877 fn parse_window_spec(&mut self) -> Result<WindowSpec, MdqlError> {
878 self.expect("op", Some("("))?;
879 let mut partition_by = Vec::new();
880 if self.match_keyword("PARTITION") {
881 self.expect("keyword", Some("BY"))?;
882 partition_by.push(self.parse_ident()?);
883 while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
884 self.advance();
885 partition_by.push(self.parse_ident()?);
886 }
887 }
888 let mut order_by = Vec::new();
889 if self.match_keyword("ORDER") {
890 self.expect("keyword", Some("BY"))?;
891 order_by = self.parse_order_by()?;
892 }
893 self.expect("op", Some(")"))?;
894 Ok(WindowSpec { partition_by, order_by })
895 }
896
897 fn parse_ident(&mut self) -> Result<String, MdqlError> {
898 let t = self.peek().ok_or_else(|| {
899 MdqlError::QueryParse("Expected identifier, got end of query".into())
900 })?;
901 match t.token_type.as_str() {
902 "ident" | "keyword" => {
903 let v = self.advance().value;
904 Ok(v)
905 }
906 _ => Err(MdqlError::QueryParse(format!(
907 "Expected identifier, got '{}'",
908 t.raw
909 ))),
910 }
911 }
912
913 fn parse_or_expr(&mut self) -> Result<WhereClause, MdqlError> {
914 let mut left = self.parse_and_expr()?;
915 while self.match_keyword("OR") {
916 let right = self.parse_and_expr()?;
917 left = WhereClause::BoolOp(BoolOp {
918 op: BoolOpKind::Or,
919 left: Box::new(left),
920 right: Box::new(right),
921 });
922 }
923 Ok(left)
924 }
925
926 fn parse_and_expr(&mut self) -> Result<WhereClause, MdqlError> {
927 let mut left = self.parse_comparison()?;
928 while self.match_keyword("AND") {
929 let right = self.parse_comparison()?;
930 left = WhereClause::BoolOp(BoolOp {
931 op: BoolOpKind::And,
932 left: Box::new(left),
933 right: Box::new(right),
934 });
935 }
936 Ok(left)
937 }
938
939 fn parse_comparison(&mut self) -> Result<WhereClause, MdqlError> {
940 if self.peek().map_or(false, |t| t.token_type == "op" && t.value == "(") {
942 let saved_pos = self.pos;
944 self.advance();
945 let result = self.parse_or_expr();
947 if result.is_ok() && self.peek().map_or(false, |t| t.token_type == "op" && t.value == ")") {
948 self.advance();
949 return result;
950 }
951 self.pos = saved_pos;
953 }
954
955 let left_expr = self.parse_additive()?;
957
958 let col = left_expr.as_column().unwrap_or("").to_string();
960
961 if self.match_keyword("IS") {
963 if self.match_keyword("NOT") {
964 self.expect("keyword", Some("NULL"))?;
965 return Ok(WhereClause::Comparison(Comparison {
966 column: col,
967 op: CmpOp::IsNotNull,
968 value: None,
969 left_expr: Some(left_expr),
970 right_expr: None,
971 }));
972 }
973 self.expect("keyword", Some("NULL"))?;
974 return Ok(WhereClause::Comparison(Comparison {
975 column: col,
976 op: CmpOp::IsNull,
977 value: None,
978 left_expr: Some(left_expr),
979 right_expr: None,
980 }));
981 }
982
983 if self.match_keyword("IN") {
985 self.expect("op", Some("("))?;
986 let is_subquery = self.peek().map_or(false, |t| t.token_type == "keyword" && t.value == "SELECT");
987 if is_subquery {
988 let sq = self.parse_select()?;
989 self.expect("op", Some(")"))?;
990 return Ok(WhereClause::Comparison(Comparison {
991 column: col,
992 op: CmpOp::In,
993 value: None,
994 left_expr: Some(left_expr),
995 right_expr: Some(Expr::Subquery(Box::new(sq))),
996 }));
997 }
998 let mut values = vec![self.parse_value()?];
999 while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
1000 self.advance();
1001 values.push(self.parse_value()?);
1002 }
1003 self.expect("op", Some(")"))?;
1004 return Ok(WhereClause::Comparison(Comparison {
1005 column: col,
1006 op: CmpOp::In,
1007 value: Some(SqlValue::List(values)),
1008 left_expr: Some(left_expr),
1009 right_expr: None,
1010 }));
1011 }
1012
1013 if self.match_keyword("LIKE") {
1015 let val = self.parse_value()?;
1016 return Ok(WhereClause::Comparison(Comparison {
1017 column: col,
1018 op: CmpOp::Like,
1019 value: Some(val),
1020 left_expr: Some(left_expr),
1021 right_expr: None,
1022 }));
1023 }
1024
1025 if self.match_keyword("NOT") {
1027 if self.match_keyword("LIKE") {
1028 let val = self.parse_value()?;
1029 return Ok(WhereClause::Comparison(Comparison {
1030 column: col,
1031 op: CmpOp::NotLike,
1032 value: Some(val),
1033 left_expr: Some(left_expr),
1034 right_expr: None,
1035 }));
1036 }
1037 return Err(MdqlError::QueryParse("Expected LIKE after NOT".into()));
1038 }
1039
1040 if let Some(t) = self.peek() {
1042 if t.token_type == "op" && ["=", "!=", "<", ">", "<=", ">="].contains(&t.value.as_str())
1043 {
1044 let op_str = self.advance().value;
1045 let op = match op_str.as_str() {
1046 "=" => CmpOp::Eq,
1047 "!=" => CmpOp::Ne,
1048 "<" => CmpOp::Lt,
1049 ">" => CmpOp::Gt,
1050 "<=" => CmpOp::Le,
1051 ">=" => CmpOp::Ge,
1052 _ => unreachable!(),
1053 };
1054 let right_expr = self.parse_additive()?;
1056 let value = match &right_expr {
1058 Expr::Literal(v) => Some(v.clone()),
1059 _ => None,
1060 };
1061 return Ok(WhereClause::Comparison(Comparison {
1062 column: col,
1063 op,
1064 value,
1065 left_expr: Some(left_expr),
1066 right_expr: Some(right_expr),
1067 }));
1068 }
1069 }
1070
1071 let got = self.peek().map_or("end".to_string(), |t| t.raw.clone());
1072 Err(MdqlError::QueryParse(format!(
1073 "Expected operator after '{}', got '{}'",
1074 left_expr.display_name(), got
1075 )))
1076 }
1077
1078 fn parse_value(&mut self) -> Result<SqlValue, MdqlError> {
1079 let t = self.peek().ok_or_else(|| {
1080 MdqlError::QueryParse("Expected value, got end of query".into())
1081 })?;
1082 match t.token_type.as_str() {
1083 "string" => {
1084 let v = self.advance().value;
1085 Ok(SqlValue::String(v))
1086 }
1087 "number" => {
1088 let v = self.advance().value;
1089 if v.contains('.') {
1090 Ok(SqlValue::Float(v.parse().map_err(|_| {
1091 MdqlError::QueryParse(format!("Invalid float: {}", v))
1092 })?))
1093 } else {
1094 Ok(SqlValue::Int(v.parse().map_err(|_| {
1095 MdqlError::QueryParse(format!("Invalid int: {}", v))
1096 })?))
1097 }
1098 }
1099 "keyword" if t.value == "NULL" => {
1100 self.advance();
1101 Ok(SqlValue::Null)
1102 }
1103 "ident" if t.value.eq_ignore_ascii_case("true") => {
1107 self.advance();
1108 Ok(SqlValue::Bool(true))
1109 }
1110 "ident" if t.value.eq_ignore_ascii_case("false") => {
1111 self.advance();
1112 Ok(SqlValue::Bool(false))
1113 }
1114 _ => Err(MdqlError::QueryParse(format!(
1115 "Expected value, got '{}'",
1116 t.raw
1117 ))),
1118 }
1119 }
1120
1121 fn parse_order_by(&mut self) -> Result<Vec<OrderSpec>, MdqlError> {
1122 let mut specs = vec![self.parse_order_spec()?];
1123 while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
1124 self.advance();
1125 specs.push(self.parse_order_spec()?);
1126 }
1127 Ok(specs)
1128 }
1129
1130 fn parse_order_spec(&mut self) -> Result<OrderSpec, MdqlError> {
1131 let expr = self.parse_additive()?;
1132 let col = expr.as_column().unwrap_or("").to_string();
1133 let descending = if self.match_keyword("DESC") {
1134 true
1135 } else {
1136 self.match_keyword("ASC");
1137 false
1138 };
1139 Ok(OrderSpec {
1140 column: col,
1141 expr: Some(expr),
1142 descending,
1143 })
1144 }
1145
1146 fn is_clause_keyword(&self, t: &Token) -> bool {
1147 t.token_type == "keyword"
1148 && ["WHERE", "ORDER", "LIMIT", "JOIN", "LEFT", "ON", "GROUP"].contains(&t.value.as_str())
1149 }
1150
1151 fn is_reserved_keyword(kw: &str) -> bool {
1153 matches!(kw,
1154 "AS" | "FROM" | "WHERE" | "AND" | "OR" | "ORDER" | "BY"
1155 | "ASC" | "DESC" | "LIMIT" | "JOIN" | "ON" | "GROUP"
1156 | "SELECT" | "INSERT" | "INTO" | "VALUES" | "UPDATE" | "SET"
1157 | "DELETE" | "ALTER" | "TABLE" | "IS" | "NOT" | "IN" | "LIKE"
1158 | "RENAME" | "FIELD" | "TO" | "DROP" | "MERGE" | "FIELDS"
1159 | "CASE" | "WHEN" | "THEN" | "ELSE" | "END"
1160 | "HAVING" | "INTERVAL" | "DAY" | "DAYS"
1161 | "CURRENT_DATE" | "CURRENT_TIMESTAMP" | "DATEDIFF"
1162 | "CREATE" | "VIEW" | "CASCADE" | "RESTRICT"
1163 | "WITH"
1164 | "OVER" | "PARTITION" | "ROW_NUMBER" | "RANK" | "DENSE_RANK" | "LAG" | "LEAD"
1165 )
1166 }
1167
1168 fn expect_end(&self) -> Result<(), MdqlError> {
1169 if let Some(t) = self.peek() {
1170 return Err(MdqlError::QueryParse(format!(
1171 "Unexpected token '{}' at position {}",
1172 t.raw, self.pos
1173 )));
1174 }
1175 Ok(())
1176 }
1177}
1178
1179pub fn parse_query(sql: &str) -> crate::errors::Result<Statement> {
1180 let tokens = tokenize(sql);
1181 if tokens.is_empty() {
1182 return Err(MdqlError::QueryParse("Empty query".into()));
1183 }
1184 let mut parser = Parser::new(tokens);
1185 parser.parse_statement()
1186}
1187
1188#[cfg(test)]
1189mod tests {
1190 use super::*;
1191
1192 #[test]
1193 fn test_simple_select() {
1194 let stmt = parse_query("SELECT title, status FROM strategies").unwrap();
1195 if let Statement::Select(q) = stmt {
1196 assert_eq!(q.columns, ColumnList::Named(vec![SelectExpr::Column("title".into()), SelectExpr::Column("status".into())]));
1197 assert_eq!(q.table, "strategies");
1198 } else {
1199 panic!("Expected Select");
1200 }
1201 }
1202
1203 #[test]
1204 fn test_select_star() {
1205 let stmt = parse_query("SELECT * FROM test").unwrap();
1206 if let Statement::Select(q) = stmt {
1207 assert_eq!(q.columns, ColumnList::All);
1208 } else {
1209 panic!("Expected Select");
1210 }
1211 }
1212
1213 #[test]
1214 fn test_boolean_literal_parses_as_bool_not_column() {
1215 for (sql, expected) in [
1218 ("SELECT title FROM t WHERE enabled = true", SqlValue::Bool(true)),
1219 ("SELECT title FROM t WHERE enabled = FALSE", SqlValue::Bool(false)),
1220 ] {
1221 let stmt = parse_query(sql).unwrap();
1222 let Statement::Select(q) = stmt else { panic!("Expected Select") };
1223 let WhereClause::Comparison(cmp) = q.where_clause.unwrap() else {
1224 panic!("Expected Comparison")
1225 };
1226 assert_eq!(cmp.value, Some(expected));
1227 }
1228 }
1229
1230 #[test]
1231 fn test_where_clause() {
1232 let stmt = parse_query("SELECT title FROM test WHERE count > 5").unwrap();
1233 if let Statement::Select(q) = stmt {
1234 assert!(q.where_clause.is_some());
1235 } else {
1236 panic!("Expected Select");
1237 }
1238 }
1239
1240 #[test]
1241 fn test_order_by() {
1242 let stmt =
1243 parse_query("SELECT title FROM test ORDER BY composite DESC, title ASC").unwrap();
1244 if let Statement::Select(q) = stmt {
1245 let ob = q.order_by.unwrap();
1246 assert_eq!(ob.len(), 2);
1247 assert!(ob[0].descending);
1248 assert!(!ob[1].descending);
1249 } else {
1250 panic!("Expected Select");
1251 }
1252 }
1253
1254 #[test]
1255 fn test_limit() {
1256 let stmt = parse_query("SELECT * FROM test LIMIT 10").unwrap();
1257 if let Statement::Select(q) = stmt {
1258 assert_eq!(q.limit, Some(10));
1259 } else {
1260 panic!("Expected Select");
1261 }
1262 }
1263
1264 #[test]
1265 fn test_insert() {
1266 let stmt = parse_query(
1267 "INSERT INTO test (title, count) VALUES ('Hello', 42)",
1268 )
1269 .unwrap();
1270 if let Statement::Insert(q) = stmt {
1271 assert_eq!(q.table, "test");
1272 assert_eq!(q.columns, vec!["title", "count"]);
1273 assert_eq!(q.values[0], SqlValue::String("Hello".into()));
1274 assert_eq!(q.values[1], SqlValue::Int(42));
1275 } else {
1276 panic!("Expected Insert");
1277 }
1278 }
1279
1280 #[test]
1281 fn test_update() {
1282 let stmt = parse_query("UPDATE test SET status = 'KILLED' WHERE path = 'a.md'").unwrap();
1283 if let Statement::Update(q) = stmt {
1284 assert_eq!(q.table, "test");
1285 assert_eq!(q.assignments.len(), 1);
1286 assert!(q.where_clause.is_some());
1287 } else {
1288 panic!("Expected Update");
1289 }
1290 }
1291
1292 #[test]
1293 fn test_delete() {
1294 let stmt = parse_query("DELETE FROM test WHERE status = 'draft'").unwrap();
1295 if let Statement::Delete(q) = stmt {
1296 assert_eq!(q.table, "test");
1297 assert!(q.where_clause.is_some());
1298 } else {
1299 panic!("Expected Delete");
1300 }
1301 }
1302
1303 #[test]
1304 fn test_alter_rename() {
1305 let stmt =
1306 parse_query("ALTER TABLE test RENAME FIELD 'Summary' TO 'Overview'").unwrap();
1307 if let Statement::AlterRename(q) = stmt {
1308 assert_eq!(q.old_name, "Summary");
1309 assert_eq!(q.new_name, "Overview");
1310 } else {
1311 panic!("Expected AlterRename");
1312 }
1313 }
1314
1315 #[test]
1316 fn test_alter_drop() {
1317 let stmt = parse_query("ALTER TABLE test DROP FIELD 'Details'").unwrap();
1318 if let Statement::AlterDrop(q) = stmt {
1319 assert_eq!(q.field_name, "Details");
1320 } else {
1321 panic!("Expected AlterDrop");
1322 }
1323 }
1324
1325 #[test]
1326 fn test_alter_merge() {
1327 let stmt = parse_query(
1328 "ALTER TABLE test MERGE FIELDS 'Entry Rules', 'Exit Rules' INTO 'Trading Rules'",
1329 )
1330 .unwrap();
1331 if let Statement::AlterMerge(q) = stmt {
1332 assert_eq!(q.sources, vec!["Entry Rules", "Exit Rules"]);
1333 assert_eq!(q.into, "Trading Rules");
1334 } else {
1335 panic!("Expected AlterMerge");
1336 }
1337 }
1338
1339 #[test]
1340 fn test_backtick_ident() {
1341 let stmt = parse_query("SELECT `Structural Mechanism` FROM test").unwrap();
1342 if let Statement::Select(q) = stmt {
1343 assert_eq!(
1344 q.columns,
1345 ColumnList::Named(vec![SelectExpr::Column("Structural Mechanism".into())])
1346 );
1347 } else {
1348 panic!("Expected Select");
1349 }
1350 }
1351
1352 #[test]
1353 fn test_like_operator() {
1354 let stmt = parse_query("SELECT title FROM test WHERE categories LIKE '%defi%'").unwrap();
1355 if let Statement::Select(q) = stmt {
1356 if let Some(WhereClause::Comparison(c)) = q.where_clause {
1357 assert_eq!(c.op, CmpOp::Like);
1358 assert_eq!(c.value, Some(SqlValue::String("%defi%".into())));
1359 } else {
1360 panic!("Expected LIKE comparison");
1361 }
1362 } else {
1363 panic!("Expected Select");
1364 }
1365 }
1366
1367 #[test]
1368 fn test_in_operator() {
1369 let stmt =
1370 parse_query("SELECT * FROM test WHERE status IN ('ACTIVE', 'LIVE')").unwrap();
1371 if let Statement::Select(q) = stmt {
1372 if let Some(WhereClause::Comparison(c)) = q.where_clause {
1373 assert_eq!(c.op, CmpOp::In);
1374 } else {
1375 panic!("Expected IN comparison");
1376 }
1377 } else {
1378 panic!("Expected Select");
1379 }
1380 }
1381
1382 #[test]
1383 fn test_is_null() {
1384 let stmt = parse_query("SELECT * FROM test WHERE title IS NULL").unwrap();
1385 if let Statement::Select(q) = stmt {
1386 if let Some(WhereClause::Comparison(c)) = q.where_clause {
1387 assert_eq!(c.op, CmpOp::IsNull);
1388 } else {
1389 panic!("Expected IS NULL comparison");
1390 }
1391 } else {
1392 panic!("Expected Select");
1393 }
1394 }
1395
1396 #[test]
1397 fn test_and_or() {
1398 let stmt = parse_query(
1399 "SELECT * FROM test WHERE status = 'ACTIVE' AND count > 5 OR title LIKE '%test%'",
1400 )
1401 .unwrap();
1402 if let Statement::Select(q) = stmt {
1403 assert!(q.where_clause.is_some());
1404 } else {
1405 panic!("Expected Select");
1406 }
1407 }
1408
1409 #[test]
1410 fn test_join() {
1411 let stmt = parse_query(
1412 "SELECT s.title, b.sharpe FROM strategies s JOIN backtests b ON b.strategy = s.path",
1413 )
1414 .unwrap();
1415 if let Statement::Select(q) = stmt {
1416 assert_eq!(q.table, "strategies");
1417 assert_eq!(q.table_alias, Some("s".into()));
1418 assert_eq!(q.joins.len(), 1);
1419 let join = &q.joins[0];
1420 assert_eq!(join.table, "backtests");
1421 assert_eq!(join.alias, Some("b".into()));
1422 } else {
1423 panic!("Expected Select");
1424 }
1425 }
1426
1427 #[test]
1428 fn test_multi_join() {
1429 let stmt = parse_query(
1430 "SELECT s.title, b.sharpe, c.verdict FROM strategies s JOIN backtests b ON b.strategy = s.path JOIN critiques c ON c.strategy = s.path",
1431 )
1432 .unwrap();
1433 if let Statement::Select(q) = stmt {
1434 assert_eq!(q.table, "strategies");
1435 assert_eq!(q.table_alias, Some("s".into()));
1436 assert_eq!(q.joins.len(), 2);
1437 assert_eq!(q.joins[0].table, "backtests");
1438 assert_eq!(q.joins[0].alias, Some("b".into()));
1439 assert_eq!(where_clause_to_sql(&q.joins[0].condition), "b.strategy = s.path");
1440 assert_eq!(q.joins[1].table, "critiques");
1441 assert_eq!(q.joins[1].alias, Some("c".into()));
1442 assert_eq!(where_clause_to_sql(&q.joins[1].condition), "c.strategy = s.path");
1443 } else {
1444 panic!("Expected Select");
1445 }
1446 }
1447
1448 #[test]
1449 fn test_left_join() {
1450 let stmt = parse_query(
1451 "SELECT s.title, b.sharpe FROM strategies s LEFT JOIN backtests b ON b.strategy = s.path",
1452 )
1453 .unwrap();
1454 if let Statement::Select(q) = stmt {
1455 assert_eq!(q.joins.len(), 1);
1456 assert_eq!(q.joins[0].join_type, JoinType::Left);
1457 assert_eq!(q.joins[0].table, "backtests");
1458 } else {
1459 panic!("Expected Select");
1460 }
1461 }
1462
1463 #[test]
1464 fn test_mixed_join_types() {
1465 let stmt = parse_query(
1466 "SELECT s.title FROM strategies s JOIN backtests b ON b.strategy = s.path LEFT JOIN allocations a ON a.strategy = s.path",
1467 )
1468 .unwrap();
1469 if let Statement::Select(q) = stmt {
1470 assert_eq!(q.joins.len(), 2);
1471 assert_eq!(q.joins[0].join_type, JoinType::Inner);
1472 assert_eq!(q.joins[1].join_type, JoinType::Left);
1473 } else {
1474 panic!("Expected Select");
1475 }
1476 }
1477
1478 #[test]
1479 fn test_join_compound_and() {
1480 let stmt = parse_query(
1481 "SELECT s.title FROM strategies s LEFT JOIN backtests b ON b.strategy = s.path AND b.mode = 'PAPER'",
1482 )
1483 .unwrap();
1484 if let Statement::Select(q) = stmt {
1485 assert_eq!(q.joins.len(), 1);
1486 assert_eq!(q.joins[0].join_type, JoinType::Left);
1487 let sql = where_clause_to_sql(&q.joins[0].condition);
1488 assert!(sql.contains("b.strategy = s.path"));
1489 assert!(sql.contains("AND"));
1490 assert!(sql.contains("b.mode = 'PAPER'"));
1491 } else {
1492 panic!("Expected Select");
1493 }
1494 }
1495
1496 #[test]
1497 fn test_join_compound_or() {
1498 let stmt = parse_query(
1499 "SELECT * FROM a JOIN b ON a.id = b.id OR a.alt = b.id",
1500 )
1501 .unwrap();
1502 if let Statement::Select(q) = stmt {
1503 let sql = where_clause_to_sql(&q.joins[0].condition);
1504 assert!(sql.contains("OR"));
1505 } else {
1506 panic!("Expected Select");
1507 }
1508 }
1509
1510 #[test]
1511 fn test_join_compound_with_where() {
1512 let stmt = parse_query(
1513 "SELECT s.title FROM strategies s JOIN backtests b ON b.strategy = s.path AND b.mode = 'PAPER' WHERE s.title = 'Alpha'",
1514 )
1515 .unwrap();
1516 if let Statement::Select(q) = stmt {
1517 assert_eq!(q.joins.len(), 1);
1518 assert!(q.where_clause.is_some());
1519 let join_sql = where_clause_to_sql(&q.joins[0].condition);
1520 assert!(join_sql.contains("AND"));
1521 } else {
1522 panic!("Expected Select");
1523 }
1524 }
1525
1526 #[test]
1527 fn test_empty_query() {
1528 assert!(parse_query("").is_err());
1529 }
1530
1531 #[test]
1532 fn test_count_star() {
1533 let stmt = parse_query("SELECT status, COUNT(*) AS cnt FROM strategies GROUP BY status").unwrap();
1534 if let Statement::Select(q) = stmt {
1535 if let ColumnList::Named(exprs) = &q.columns {
1536 assert_eq!(exprs.len(), 2);
1537 assert_eq!(exprs[0], SelectExpr::Column("status".into()));
1538 assert!(matches!(&exprs[1], SelectExpr::Aggregate {
1539 func: AggFunc::Count,
1540 arg,
1541 alias: Some(a),
1542 ..
1543 } if arg == "*" && a == "cnt"));
1544 } else {
1545 panic!("Expected Named columns");
1546 }
1547 assert_eq!(q.group_by, Some(vec!["status".into()]));
1548 } else {
1549 panic!("Expected Select");
1550 }
1551 }
1552
1553 #[test]
1554 fn test_count_column_as_ident() {
1555 let stmt = parse_query("INSERT INTO test (title, count) VALUES ('Hello', 42)").unwrap();
1557 if let Statement::Insert(q) = stmt {
1558 assert_eq!(q.columns, vec!["title", "count"]);
1559 } else {
1560 panic!("Expected Insert");
1561 }
1562 }
1563
1564 #[test]
1565 fn test_multiple_aggregates() {
1566 let stmt = parse_query("SELECT MIN(composite), MAX(composite), AVG(composite) FROM strategies").unwrap();
1567 if let Statement::Select(q) = stmt {
1568 if let ColumnList::Named(exprs) = &q.columns {
1569 assert_eq!(exprs.len(), 3);
1570 assert!(matches!(&exprs[0], SelectExpr::Aggregate { func: AggFunc::Min, .. }));
1571 assert!(matches!(&exprs[1], SelectExpr::Aggregate { func: AggFunc::Max, .. }));
1572 assert!(matches!(&exprs[2], SelectExpr::Aggregate { func: AggFunc::Avg, .. }));
1573 } else {
1574 panic!("Expected Named columns");
1575 }
1576 assert_eq!(q.group_by, None);
1577 } else {
1578 panic!("Expected Select");
1579 }
1580 }
1581
1582 #[test]
1585 fn test_select_arithmetic_expr() {
1586 let stmt = parse_query("SELECT a + b FROM test").unwrap();
1587 if let Statement::Select(q) = stmt {
1588 if let ColumnList::Named(exprs) = &q.columns {
1589 assert_eq!(exprs.len(), 1);
1590 assert!(matches!(&exprs[0], SelectExpr::Expr {
1591 expr: Expr::BinaryOp { op: ArithOp::Add, .. },
1592 alias: None,
1593 }));
1594 } else {
1595 panic!("Expected Named columns");
1596 }
1597 } else {
1598 panic!("Expected Select");
1599 }
1600 }
1601
1602 #[test]
1603 fn test_select_arithmetic_with_alias() {
1604 let stmt = parse_query("SELECT a + b AS total FROM test").unwrap();
1605 if let Statement::Select(q) = stmt {
1606 if let ColumnList::Named(exprs) = &q.columns {
1607 assert_eq!(exprs.len(), 1);
1608 assert!(matches!(&exprs[0], SelectExpr::Expr {
1609 alias: Some(a),
1610 ..
1611 } if a == "total"));
1612 assert_eq!(exprs[0].output_name(), "total");
1613 } else {
1614 panic!("Expected Named columns");
1615 }
1616 } else {
1617 panic!("Expected Select");
1618 }
1619 }
1620
1621 #[test]
1622 fn test_select_precedence() {
1623 let stmt = parse_query("SELECT a + b * c FROM test").unwrap();
1625 if let Statement::Select(q) = stmt {
1626 if let ColumnList::Named(exprs) = &q.columns {
1627 if let SelectExpr::Expr { expr, .. } = &exprs[0] {
1628 if let Expr::BinaryOp { left, op, right } = expr {
1629 assert_eq!(*op, ArithOp::Add);
1630 assert!(matches!(left.as_ref(), Expr::Column(n) if n == "a"));
1631 assert!(matches!(right.as_ref(), Expr::BinaryOp { op: ArithOp::Mul, .. }));
1632 } else {
1633 panic!("Expected BinaryOp");
1634 }
1635 } else {
1636 panic!("Expected Expr variant");
1637 }
1638 } else {
1639 panic!("Expected Named columns");
1640 }
1641 } else {
1642 panic!("Expected Select");
1643 }
1644 }
1645
1646 #[test]
1647 fn test_select_parenthesized_expr() {
1648 let stmt = parse_query("SELECT (a + b) * c FROM test").unwrap();
1650 if let Statement::Select(q) = stmt {
1651 if let ColumnList::Named(exprs) = &q.columns {
1652 if let SelectExpr::Expr { expr, .. } = &exprs[0] {
1653 if let Expr::BinaryOp { left, op, .. } = expr {
1654 assert_eq!(*op, ArithOp::Mul);
1655 assert!(matches!(left.as_ref(), Expr::BinaryOp { op: ArithOp::Add, .. }));
1656 } else {
1657 panic!("Expected BinaryOp");
1658 }
1659 } else {
1660 panic!("Expected Expr variant");
1661 }
1662 } else {
1663 panic!("Expected Named columns");
1664 }
1665 } else {
1666 panic!("Expected Select");
1667 }
1668 }
1669
1670 #[test]
1671 fn test_select_unary_minus() {
1672 let stmt = parse_query("SELECT -count FROM test").unwrap();
1673 if let Statement::Select(q) = stmt {
1674 if let ColumnList::Named(exprs) = &q.columns {
1675 assert!(matches!(&exprs[0], SelectExpr::Expr {
1676 expr: Expr::UnaryMinus(_),
1677 ..
1678 }));
1679 } else {
1680 panic!("Expected Named columns");
1681 }
1682 } else {
1683 panic!("Expected Select");
1684 }
1685 }
1686
1687 #[test]
1688 fn test_select_negative_literal() {
1689 let stmt = parse_query("SELECT -42 FROM test").unwrap();
1690 if let Statement::Select(q) = stmt {
1691 if let ColumnList::Named(exprs) = &q.columns {
1692 assert!(matches!(&exprs[0], SelectExpr::Expr {
1694 expr: Expr::Literal(SqlValue::Int(-42)),
1695 ..
1696 }));
1697 } else {
1698 panic!("Expected Named columns");
1699 }
1700 } else {
1701 panic!("Expected Select");
1702 }
1703 }
1704
1705 #[test]
1706 fn test_where_arithmetic_expr() {
1707 let stmt = parse_query("SELECT * FROM test WHERE a + b > 10").unwrap();
1708 if let Statement::Select(q) = stmt {
1709 if let Some(WhereClause::Comparison(c)) = q.where_clause {
1710 assert_eq!(c.op, CmpOp::Gt);
1711 assert!(matches!(&c.left_expr, Some(Expr::BinaryOp { op: ArithOp::Add, .. })));
1712 assert!(matches!(&c.right_expr, Some(Expr::Literal(SqlValue::Int(10)))));
1713 } else {
1714 panic!("Expected comparison");
1715 }
1716 } else {
1717 panic!("Expected Select");
1718 }
1719 }
1720
1721 #[test]
1722 fn test_where_both_sides_expr() {
1723 let stmt = parse_query("SELECT * FROM test WHERE a * 2 > b + 1").unwrap();
1724 if let Statement::Select(q) = stmt {
1725 if let Some(WhereClause::Comparison(c)) = q.where_clause {
1726 assert_eq!(c.op, CmpOp::Gt);
1727 assert!(matches!(&c.left_expr, Some(Expr::BinaryOp { op: ArithOp::Mul, .. })));
1728 assert!(matches!(&c.right_expr, Some(Expr::BinaryOp { op: ArithOp::Add, .. })));
1729 } else {
1730 panic!("Expected comparison");
1731 }
1732 } else {
1733 panic!("Expected Select");
1734 }
1735 }
1736
1737 #[test]
1738 fn test_order_by_expr() {
1739 let stmt = parse_query("SELECT * FROM test ORDER BY a + b DESC").unwrap();
1740 if let Statement::Select(q) = stmt {
1741 let ob = q.order_by.unwrap();
1742 assert_eq!(ob.len(), 1);
1743 assert!(ob[0].descending);
1744 assert!(matches!(&ob[0].expr, Some(Expr::BinaryOp { op: ArithOp::Add, .. })));
1745 } else {
1746 panic!("Expected Select");
1747 }
1748 }
1749
1750 #[test]
1751 fn test_all_arithmetic_ops() {
1752 let stmt = parse_query("SELECT a + b, a - b, a * b, a / b, a % b FROM test").unwrap();
1753 if let Statement::Select(q) = stmt {
1754 if let ColumnList::Named(exprs) = &q.columns {
1755 assert_eq!(exprs.len(), 5);
1756 assert!(matches!(&exprs[0], SelectExpr::Expr { expr: Expr::BinaryOp { op: ArithOp::Add, .. }, .. }));
1757 assert!(matches!(&exprs[1], SelectExpr::Expr { expr: Expr::BinaryOp { op: ArithOp::Sub, .. }, .. }));
1758 assert!(matches!(&exprs[2], SelectExpr::Expr { expr: Expr::BinaryOp { op: ArithOp::Mul, .. }, .. }));
1759 assert!(matches!(&exprs[3], SelectExpr::Expr { expr: Expr::BinaryOp { op: ArithOp::Div, .. }, .. }));
1760 assert!(matches!(&exprs[4], SelectExpr::Expr { expr: Expr::BinaryOp { op: ArithOp::Mod, .. }, .. }));
1761 } else {
1762 panic!("Expected Named columns");
1763 }
1764 } else {
1765 panic!("Expected Select");
1766 }
1767 }
1768
1769 #[test]
1770 fn test_column_with_literal_arithmetic() {
1771 let stmt = parse_query("SELECT count * 2 + 1 FROM test").unwrap();
1772 if let Statement::Select(q) = stmt {
1773 if let ColumnList::Named(exprs) = &q.columns {
1774 if let SelectExpr::Expr { expr, .. } = &exprs[0] {
1776 if let Expr::BinaryOp { left, op, right } = expr {
1777 assert_eq!(*op, ArithOp::Add);
1778 assert!(matches!(right.as_ref(), Expr::Literal(SqlValue::Int(1))));
1779 assert!(matches!(left.as_ref(), Expr::BinaryOp { op: ArithOp::Mul, .. }));
1780 } else {
1781 panic!("Expected BinaryOp");
1782 }
1783 } else {
1784 panic!("Expected Expr");
1785 }
1786 } else {
1787 panic!("Expected Named columns");
1788 }
1789 } else {
1790 panic!("Expected Select");
1791 }
1792 }
1793
1794 #[test]
1795 fn test_mixed_columns_and_exprs() {
1796 let stmt = parse_query("SELECT title, a + b AS sum, count FROM test").unwrap();
1797 if let Statement::Select(q) = stmt {
1798 if let ColumnList::Named(exprs) = &q.columns {
1799 assert_eq!(exprs.len(), 3);
1800 assert_eq!(exprs[0], SelectExpr::Column("title".into()));
1801 assert!(matches!(&exprs[1], SelectExpr::Expr { alias: Some(a), .. } if a == "sum"));
1802 assert_eq!(exprs[2], SelectExpr::Column("count".into()));
1803 } else {
1804 panic!("Expected Named columns");
1805 }
1806 } else {
1807 panic!("Expected Select");
1808 }
1809 }
1810
1811 #[test]
1814 fn test_case_when_basic() {
1815 let stmt = parse_query(
1816 "SELECT CASE WHEN status = 'ACTIVE' THEN 1 ELSE 0 END FROM test"
1817 ).unwrap();
1818 if let Statement::Select(q) = stmt {
1819 if let ColumnList::Named(exprs) = &q.columns {
1820 assert_eq!(exprs.len(), 1);
1821 assert!(matches!(&exprs[0], SelectExpr::Expr {
1822 expr: Expr::Case { .. },
1823 ..
1824 }));
1825 } else {
1826 panic!("Expected Named columns");
1827 }
1828 } else {
1829 panic!("Expected Select");
1830 }
1831 }
1832
1833 #[test]
1834 fn test_case_when_multiple_branches() {
1835 let stmt = parse_query(
1836 "SELECT CASE WHEN x > 10 THEN 'high' WHEN x > 5 THEN 'mid' ELSE 'low' END FROM test"
1837 ).unwrap();
1838 if let Statement::Select(q) = stmt {
1839 if let ColumnList::Named(exprs) = &q.columns {
1840 if let SelectExpr::Expr { expr: Expr::Case { whens, else_expr }, .. } = &exprs[0] {
1841 assert_eq!(whens.len(), 2);
1842 assert!(else_expr.is_some());
1843 } else {
1844 panic!("Expected Case expression");
1845 }
1846 } else {
1847 panic!("Expected Named columns");
1848 }
1849 } else {
1850 panic!("Expected Select");
1851 }
1852 }
1853
1854 #[test]
1855 fn test_case_when_no_else() {
1856 let stmt = parse_query(
1857 "SELECT CASE WHEN x = 1 THEN 'one' END FROM test"
1858 ).unwrap();
1859 if let Statement::Select(q) = stmt {
1860 if let ColumnList::Named(exprs) = &q.columns {
1861 if let SelectExpr::Expr { expr: Expr::Case { whens, else_expr }, .. } = &exprs[0] {
1862 assert_eq!(whens.len(), 1);
1863 assert!(else_expr.is_none());
1864 } else {
1865 panic!("Expected Case expression");
1866 }
1867 } else {
1868 panic!("Expected Named columns");
1869 }
1870 } else {
1871 panic!("Expected Select");
1872 }
1873 }
1874
1875 #[test]
1876 fn test_case_when_in_aggregate() {
1877 let stmt = parse_query(
1878 "SELECT SUM(CASE WHEN side = 'BUY' THEN size ELSE -size END) AS net FROM orders GROUP BY token"
1879 ).unwrap();
1880 if let Statement::Select(q) = stmt {
1881 if let ColumnList::Named(exprs) = &q.columns {
1882 assert_eq!(exprs.len(), 1);
1883 assert!(matches!(&exprs[0], SelectExpr::Aggregate {
1884 func: AggFunc::Sum,
1885 arg_expr: Some(Expr::Case { .. }),
1886 alias: Some(a),
1887 ..
1888 } if a == "net"));
1889 } else {
1890 panic!("Expected Named columns");
1891 }
1892 } else {
1893 panic!("Expected Select");
1894 }
1895 }
1896
1897 #[test]
1898 fn test_case_when_with_alias() {
1899 let stmt = parse_query(
1900 "SELECT CASE WHEN x > 0 THEN 'pos' ELSE 'neg' END AS sign FROM test"
1901 ).unwrap();
1902 if let Statement::Select(q) = stmt {
1903 if let ColumnList::Named(exprs) = &q.columns {
1904 assert!(matches!(&exprs[0], SelectExpr::Expr {
1905 expr: Expr::Case { .. },
1906 alias: Some(a),
1907 } if a == "sign"));
1908 } else {
1909 panic!("Expected Named columns");
1910 }
1911 } else {
1912 panic!("Expected Select");
1913 }
1914 }
1915
1916 #[test]
1917 fn test_create_view() {
1918 let stmt = parse_query("CREATE VIEW live AS SELECT * FROM strategies WHERE status = 'LIVE'").unwrap();
1919 if let Statement::CreateView(cv) = stmt {
1920 assert_eq!(cv.view_name, "live");
1921 assert!(cv.columns.is_none());
1922 assert_eq!(cv.query.table, "strategies");
1923 assert!(cv.query.where_clause.is_some());
1924 } else {
1925 panic!("Expected CreateView, got {:?}", stmt);
1926 }
1927 }
1928
1929 #[test]
1930 fn test_create_view_with_columns() {
1931 let stmt = parse_query("CREATE VIEW v1 (a, b) AS SELECT title, status FROM t").unwrap();
1932 if let Statement::CreateView(cv) = stmt {
1933 assert_eq!(cv.view_name, "v1");
1934 assert_eq!(cv.columns, Some(vec!["a".into(), "b".into()]));
1935 } else {
1936 panic!("Expected CreateView");
1937 }
1938 }
1939
1940 #[test]
1941 fn test_drop_view() {
1942 let stmt = parse_query("DROP VIEW live").unwrap();
1943 if let Statement::DropView(dv) = stmt {
1944 assert_eq!(dv.view_name, "live");
1945 } else {
1946 panic!("Expected DropView, got {:?}", stmt);
1947 }
1948 }
1949
1950 #[test]
1951 fn test_create_view_case_insensitive() {
1952 let stmt = parse_query("create view My_View as select * from t").unwrap();
1953 if let Statement::CreateView(cv) = stmt {
1954 assert_eq!(cv.view_name, "My_View");
1955 } else {
1956 panic!("Expected CreateView");
1957 }
1958 }
1959
1960 #[test]
1963 fn test_aggregate_division() {
1964 let stmt = parse_query(
1965 "SELECT token, SUM(sell) / SUM(buy) as ratio FROM orders GROUP BY token"
1966 ).unwrap();
1967 if let Statement::Select(q) = stmt {
1968 assert_eq!(q.group_by, Some(vec!["token".into()]));
1969 if let ColumnList::Named(exprs) = &q.columns {
1970 assert_eq!(exprs.len(), 2);
1971 assert!(exprs[1].is_aggregate());
1972 } else {
1973 panic!("Expected Named columns");
1974 }
1975 } else {
1976 panic!("Expected Select");
1977 }
1978 }
1979
1980 #[test]
1981 fn test_aggregate_subtraction() {
1982 let stmt = parse_query(
1983 "SELECT token, SUM(sell) - SUM(buy) as net FROM orders GROUP BY token"
1984 ).unwrap();
1985 if let Statement::Select(q) = stmt {
1986 if let ColumnList::Named(exprs) = &q.columns {
1987 assert_eq!(exprs[1].output_name(), "net");
1988 }
1989 } else {
1990 panic!("Expected Select");
1991 }
1992 }
1993
1994 #[test]
1995 fn test_create_view_with_arithmetic() {
1996 let stmt = parse_query(
1997 "CREATE VIEW positions AS SELECT token, SUM(sell) / SUM(buy) as ratio FROM orders GROUP BY token"
1998 ).unwrap();
1999 if let Statement::CreateView(cv) = stmt {
2000 assert_eq!(cv.view_name, "positions");
2001 } else {
2002 panic!("Expected CreateView, got {:?}", stmt);
2003 }
2004 }
2005
2006 #[test]
2009 fn test_subquery_in_from() {
2010 let stmt = parse_query(
2011 "SELECT token, sell_size FROM (SELECT token, SUM(size) as sell_size FROM orders GROUP BY token) LIMIT 5"
2012 ).unwrap();
2013 if let Statement::Select(q) = stmt {
2014 assert!(q.subquery.is_some());
2015 assert_eq!(q.limit, Some(5));
2016 let sub = q.subquery.unwrap();
2017 assert_eq!(sub.table, "orders");
2018 assert!(sub.group_by.is_some());
2019 } else {
2020 panic!("Expected Select");
2021 }
2022 }
2023
2024 #[test]
2027 fn test_create_view_with_having() {
2028 let stmt = parse_query(
2029 "CREATE VIEW positions AS SELECT token, SUM(sell) as sell_size, SUM(buy) as buy_size FROM orders GROUP BY token HAVING sell_size > buy_size"
2030 ).unwrap();
2031 if let Statement::CreateView(cv) = stmt {
2032 assert_eq!(cv.view_name, "positions");
2033 assert!(cv.query.having.is_some());
2034 } else {
2035 panic!("Expected CreateView, got {:?}", stmt);
2036 }
2037 }
2038
2039 #[test]
2042 fn test_aggregate_multiplication() {
2043 let stmt = parse_query(
2044 "SELECT SUM(a) * 2 as doubled FROM test"
2045 ).unwrap();
2046 if let Statement::Select(q) = stmt {
2047 if let ColumnList::Named(exprs) = &q.columns {
2048 assert_eq!(exprs.len(), 1);
2049 assert!(exprs[0].is_aggregate());
2050 assert_eq!(exprs[0].output_name(), "doubled");
2051 } else {
2052 panic!("Expected Named columns");
2053 }
2054 } else {
2055 panic!("Expected Select");
2056 }
2057 }
2058
2059 #[test]
2060 fn test_complex_aggregate_arithmetic() {
2061 let stmt = parse_query(
2062 "SELECT SUM(CASE WHEN side = 'SELL' THEN size ELSE 0 END) / SUM(CASE WHEN side = 'BUY' THEN size ELSE 0 END) as ratio FROM orders GROUP BY token"
2063 ).unwrap();
2064 if let Statement::Select(q) = stmt {
2065 if let ColumnList::Named(exprs) = &q.columns {
2066 assert_eq!(exprs.len(), 1);
2067 assert!(exprs[0].is_aggregate());
2068 assert_eq!(exprs[0].output_name(), "ratio");
2069 } else {
2070 panic!("Expected Named columns");
2071 }
2072 assert_eq!(q.group_by, Some(vec!["token".into()]));
2073 } else {
2074 panic!("Expected Select");
2075 }
2076 }
2077
2078 #[test]
2081 fn test_subquery_with_alias() {
2082 let stmt = parse_query(
2083 "SELECT x FROM (SELECT x FROM t) sub"
2084 ).unwrap();
2085 if let Statement::Select(q) = stmt {
2086 assert!(q.subquery.is_some());
2087 let sub = q.subquery.unwrap();
2088 assert_eq!(sub.table, "t");
2089 if let ColumnList::Named(exprs) = &q.columns {
2090 assert_eq!(exprs.len(), 1);
2091 assert_eq!(exprs[0].output_name(), "x");
2092 } else {
2093 panic!("Expected Named columns");
2094 }
2095 } else {
2096 panic!("Expected Select");
2097 }
2098 }
2099
2100 #[test]
2101 fn test_subquery_with_where() {
2102 let stmt = parse_query(
2103 "SELECT x FROM (SELECT x FROM t WHERE y > 0) LIMIT 5"
2104 ).unwrap();
2105 if let Statement::Select(q) = stmt {
2106 assert!(q.subquery.is_some());
2107 assert_eq!(q.limit, Some(5));
2108 let sub = q.subquery.unwrap();
2109 assert_eq!(sub.table, "t");
2110 assert!(sub.where_clause.is_some());
2111 } else {
2112 panic!("Expected Select");
2113 }
2114 }
2115
2116 #[test]
2119 fn test_create_view_aggregate_subtraction() {
2120 let stmt = parse_query(
2121 "CREATE VIEW v AS SELECT token, SUM(sell) - SUM(buy) as net FROM orders GROUP BY token"
2122 ).unwrap();
2123 if let Statement::CreateView(cv) = stmt {
2124 assert_eq!(cv.view_name, "v");
2125 assert_eq!(cv.query.group_by, Some(vec!["token".into()]));
2126 if let ColumnList::Named(exprs) = &cv.query.columns {
2127 assert_eq!(exprs.len(), 2);
2128 assert_eq!(exprs[1].output_name(), "net");
2129 assert!(exprs[1].is_aggregate());
2130 } else {
2131 panic!("Expected Named columns");
2132 }
2133 } else {
2134 panic!("Expected CreateView, got {:?}", stmt);
2135 }
2136 }
2137
2138 #[test]
2139 fn test_delete_cascade() {
2140 let stmt = parse_query("DELETE FROM strategies WHERE status = 'KILLED' CASCADE").unwrap();
2141 if let Statement::Delete(q) = stmt {
2142 assert_eq!(q.table, "strategies");
2143 assert!(q.where_clause.is_some());
2144 assert_eq!(q.mode, DeleteMode::Cascade);
2145 } else {
2146 panic!("Expected Delete");
2147 }
2148 }
2149
2150 #[test]
2151 fn test_delete_restrict() {
2152 let stmt = parse_query("DELETE FROM strategies WHERE path = 'alpha.md' RESTRICT").unwrap();
2153 if let Statement::Delete(q) = stmt {
2154 assert_eq!(q.table, "strategies");
2155 assert_eq!(q.mode, DeleteMode::Restrict);
2156 } else {
2157 panic!("Expected Delete");
2158 }
2159 }
2160
2161 #[test]
2162 fn test_delete_default_unchanged() {
2163 let stmt = parse_query("DELETE FROM strategies WHERE status = 'KILLED'").unwrap();
2164 if let Statement::Delete(q) = stmt {
2165 assert_eq!(q.mode, DeleteMode::Default);
2166 } else {
2167 panic!("Expected Delete");
2168 }
2169 }
2170
2171 #[test]
2172 fn test_delete_cascade_no_where() {
2173 let stmt = parse_query("DELETE FROM strategies CASCADE").unwrap();
2174 if let Statement::Delete(q) = stmt {
2175 assert_eq!(q.table, "strategies");
2176 assert!(q.where_clause.is_none());
2177 assert_eq!(q.mode, DeleteMode::Cascade);
2178 } else {
2179 panic!("Expected Delete");
2180 }
2181 }
2182
2183 #[test]
2186 fn test_cte_basic() {
2187 let stmt = parse_query(
2188 "WITH live AS (SELECT * FROM strategies WHERE status = 'LIVE') SELECT * FROM live"
2189 ).unwrap();
2190 if let Statement::Select(q) = stmt {
2191 assert_eq!(q.ctes.len(), 1);
2192 assert_eq!(q.ctes[0].name, "live");
2193 assert_eq!(q.ctes[0].query.table, "strategies");
2194 assert!(q.ctes[0].query.where_clause.is_some());
2195 assert_eq!(q.table, "live");
2196 } else {
2197 panic!("Expected Select");
2198 }
2199 }
2200
2201 #[test]
2202 fn test_cte_multi() {
2203 let stmt = parse_query(
2204 "WITH a AS (SELECT * FROM t1), b AS (SELECT * FROM t2) SELECT * FROM a JOIN b ON a.id = b.id"
2205 ).unwrap();
2206 if let Statement::Select(q) = stmt {
2207 assert_eq!(q.ctes.len(), 2);
2208 assert_eq!(q.ctes[0].name, "a");
2209 assert_eq!(q.ctes[0].query.table, "t1");
2210 assert_eq!(q.ctes[1].name, "b");
2211 assert_eq!(q.ctes[1].query.table, "t2");
2212 assert_eq!(q.table, "a");
2213 assert_eq!(q.joins.len(), 1);
2214 } else {
2215 panic!("Expected Select");
2216 }
2217 }
2218
2219 #[test]
2220 fn test_cte_with_aggregation() {
2221 let stmt = parse_query(
2222 "WITH totals AS (SELECT strategy, COUNT(*) AS cnt FROM backtests GROUP BY strategy) SELECT * FROM totals WHERE cnt > 1"
2223 ).unwrap();
2224 if let Statement::Select(q) = stmt {
2225 assert_eq!(q.ctes.len(), 1);
2226 assert_eq!(q.ctes[0].name, "totals");
2227 assert!(q.ctes[0].query.group_by.is_some());
2228 assert_eq!(q.table, "totals");
2229 assert!(q.where_clause.is_some());
2230 } else {
2231 panic!("Expected Select");
2232 }
2233 }
2234
2235 #[test]
2236 fn test_cte_no_ctes_on_plain_select() {
2237 let stmt = parse_query("SELECT * FROM t").unwrap();
2238 if let Statement::Select(q) = stmt {
2239 assert!(q.ctes.is_empty());
2240 } else {
2241 panic!("Expected Select");
2242 }
2243 }
2244
#[test]
fn test_where_in_subquery() {
    // `IN (SELECT ...)` parses its right-hand side as a subquery expression.
    let sql = "SELECT * FROM strategies WHERE path IN (SELECT strategy FROM backtests)";
    let Statement::Select(q) = parse_query(sql).unwrap() else {
        panic!("Expected Select");
    };
    let Some(WhereClause::Comparison(cmp)) = &q.where_clause else {
        panic!("Expected IN comparison");
    };
    assert_eq!(cmp.op, CmpOp::In);
    assert!(matches!(&cmp.right_expr, Some(Expr::Subquery(_))));
}
2263
#[test]
fn test_scalar_subquery_in_where() {
    // A parenthesised SELECT on the right of `>` becomes a scalar subquery operand.
    let sql = "SELECT * FROM backtests WHERE sharpe > (SELECT AVG(sharpe) FROM backtests)";
    let Statement::Select(q) = parse_query(sql).unwrap() else {
        panic!("Expected Select");
    };
    let Some(WhereClause::Comparison(cmp)) = &q.where_clause else {
        panic!("Expected comparison");
    };
    assert_eq!(cmp.op, CmpOp::Gt);
    assert!(matches!(&cmp.right_expr, Some(Expr::Subquery(_))));
}
2280
#[test]
fn test_scalar_subquery_in_select() {
    // A parenthesised SELECT in the projection list becomes an aliased subquery expression.
    let sql = "SELECT title, (SELECT COUNT(*) FROM backtests) AS cnt FROM strategies";
    let Statement::Select(q) = parse_query(sql).unwrap() else {
        panic!("Expected Select");
    };
    let ColumnList::Named(exprs) = &q.columns else {
        panic!("Expected Named columns");
    };
    assert_eq!(exprs.len(), 2);
    let second_is_aliased_subquery = matches!(
        &exprs[1],
        SelectExpr::Expr { expr: Expr::Subquery(_), alias: Some(a) } if a == "cnt"
    );
    assert!(second_is_aliased_subquery);
}
2300
#[test]
fn test_row_number_over_order_by() {
    // ROW_NUMBER() takes no arguments; its OVER clause carries only an ORDER BY.
    let sql = "SELECT title, ROW_NUMBER() OVER (ORDER BY count DESC) AS rn FROM test";
    let Statement::Select(q) = parse_query(sql).unwrap() else {
        panic!("Expected Select");
    };
    let ColumnList::Named(exprs) = &q.columns else {
        panic!("Expected Named columns");
    };
    assert_eq!(exprs.len(), 2);
    let SelectExpr::Expr { expr: Expr::Window { func, args, over }, alias } = &exprs[1] else {
        panic!("Expected Window expression, got {:?}", exprs[1]);
    };
    assert_eq!(*func, WindowFunc::RowNumber);
    assert!(args.is_empty());
    assert!(over.partition_by.is_empty());
    assert_eq!(over.order_by.len(), 1);
    assert!(over.order_by[0].descending);
    assert_eq!(alias.as_deref(), Some("rn"));
}
2328
#[test]
fn test_rank_with_partition_by() {
    // RANK() with both PARTITION BY and ORDER BY inside OVER. Besides the partition
    // key, pin the DESC direction, the empty argument list and the alias, so a
    // regression in any part of the OVER/alias parsing is caught here.
    let sql = "SELECT RANK() OVER (PARTITION BY category ORDER BY price DESC) AS rnk FROM test";
    let Statement::Select(q) = parse_query(sql).unwrap() else {
        panic!("Expected Select");
    };
    let ColumnList::Named(exprs) = &q.columns else {
        panic!("Expected Named columns");
    };
    assert_eq!(exprs.len(), 1);
    let SelectExpr::Expr { expr: Expr::Window { func, args, over }, alias } = &exprs[0] else {
        panic!("Expected Window expression");
    };
    assert_eq!(*func, WindowFunc::Rank);
    assert!(args.is_empty());
    assert_eq!(over.partition_by, vec!["category"]);
    assert_eq!(over.order_by.len(), 1);
    assert!(over.order_by[0].descending);
    assert_eq!(alias.as_deref(), Some("rnk"));
}
2350
#[test]
fn test_agg_over_window() {
    // An aggregate followed by OVER is promoted to a window function wrapping the aggregate.
    let sql = "SELECT SUM(price) OVER (PARTITION BY category) AS cat_total FROM test";
    let Statement::Select(q) = parse_query(sql).unwrap() else {
        panic!("Expected Select");
    };
    let ColumnList::Named(exprs) = &q.columns else {
        panic!("Expected Named columns");
    };
    let SelectExpr::Expr { expr: Expr::Window { func, args, over }, alias } = &exprs[0] else {
        panic!("Expected Window expression");
    };
    assert!(matches!(func, WindowFunc::Agg(AggFunc::Sum)));
    assert_eq!(args.len(), 1);
    assert_eq!(over.partition_by, vec!["category"]);
    assert!(over.order_by.is_empty());
    assert_eq!(alias.as_deref(), Some("cat_total"));
}
2374
#[test]
fn test_lag_with_args() {
    // LAG takes two positional arguments (the column and the offset). Also verify the
    // OVER clause (one ascending ORDER BY, no partition) and the alias, which the
    // original test left unchecked.
    let sql = "SELECT LAG(price, 1) OVER (ORDER BY price) AS prev_price FROM test";
    let Statement::Select(q) = parse_query(sql).unwrap() else {
        panic!("Expected Select");
    };
    let ColumnList::Named(exprs) = &q.columns else {
        panic!("Expected Named columns");
    };
    let SelectExpr::Expr { expr: Expr::Window { func, args, over }, alias } = &exprs[0] else {
        panic!("Expected Window expression");
    };
    assert_eq!(*func, WindowFunc::Lag);
    assert_eq!(args.len(), 2);
    assert!(over.partition_by.is_empty());
    assert_eq!(over.order_by.len(), 1);
    // No DESC keyword in the query, so the ordering is expected to be ascending.
    assert!(!over.order_by[0].descending);
    assert_eq!(alias.as_deref(), Some("prev_price"));
}
2395
#[test]
fn test_dense_rank() {
    // DENSE_RANK() takes no arguments; check the function kind together with the
    // OVER (ORDER BY count DESC) clause and the alias, rather than the kind alone.
    let sql = "SELECT DENSE_RANK() OVER (ORDER BY count DESC) AS dr FROM test";
    let Statement::Select(q) = parse_query(sql).unwrap() else {
        panic!("Expected Select");
    };
    let ColumnList::Named(exprs) = &q.columns else {
        panic!("Expected Named columns");
    };
    let SelectExpr::Expr { expr: Expr::Window { func, args, over }, alias } = &exprs[0] else {
        panic!("Expected Window expression");
    };
    assert_eq!(*func, WindowFunc::DenseRank);
    assert!(args.is_empty());
    assert!(over.partition_by.is_empty());
    assert_eq!(over.order_by.len(), 1);
    assert!(over.order_by[0].descending);
    assert_eq!(alias.as_deref(), Some("dr"));
}
2415
#[test]
fn test_sum_without_over_is_aggregate() {
    // Without OVER, SUM(...) must remain a plain aggregate, not a window function.
    let Statement::Select(q) = parse_query("SELECT SUM(count) FROM test").unwrap() else {
        panic!("Expected Select");
    };
    let ColumnList::Named(exprs) = &q.columns else {
        panic!("Expected Named columns");
    };
    let first = &exprs[0];
    assert!(matches!(first, SelectExpr::Aggregate { func: AggFunc::Sum, .. }));
}
2428 }
2429}