1use regex::Regex;
4use std::sync::LazyLock;
5
6use crate::errors::MdqlError;
7pub use crate::query_ast::*;
8
/// Every word the tokenizer classifies as a keyword.
///
/// `tokenize` upper-cases each bare word before looking it up here, so the
/// entries are stored in upper case and matching is case-insensitive.
static KEYWORDS: &[&str] = &[
    // Query clauses, predicates and ordering.
    "SELECT", "FROM", "WHERE", "AND", "OR",
    "ORDER", "BY", "ASC", "DESC", "LIMIT",
    "LIKE", "IN", "IS", "NOT", "NULL",
    // Joins, aliases and grouping.
    "JOIN", "LEFT", "ON", "AS", "GROUP", "HAVING",
    // Data modification.
    "INSERT", "INTO", "VALUES", "UPDATE", "SET", "DELETE",
    // Schema changes.
    "ALTER", "TABLE", "RENAME", "FIELD", "TO", "DROP", "MERGE", "FIELDS",
    // CASE expressions.
    "CASE", "WHEN", "THEN", "ELSE", "END",
    // Date handling.
    "INTERVAL", "DAY", "DAYS", "CURRENT_DATE", "CURRENT_TIMESTAMP", "DATEDIFF",
    // Views and delete modes.
    "CREATE", "VIEW", "CASCADE", "RESTRICT",
    // Common table expressions.
    "WITH",
    // Window functions.
    "OVER", "PARTITION", "ROW_NUMBER", "RANK", "DENSE_RANK", "LAG", "LEAD",
];
23
/// Aggregate function names; a token is treated as an aggregate call when its
/// upper-cased value is listed here and the next token is `(`.
static AGG_FUNCS: &[&str] = &[
    "COUNT",
    "SUM",
    "AVG",
    "MIN",
    "MAX",
];

/// Window function names; recognized the same way as aggregates (name
/// followed by `(`).
static WINDOW_FUNCS: &[&str] = &[
    "ROW_NUMBER",
    "RANK",
    "DENSE_RANK",
    "LAG",
    "LEAD",
];
26
27static TOKEN_RE: LazyLock<Regex> = LazyLock::new(|| {
28 Regex::new(
29 r#"(?x)
30 \s*(?:
31 (?P<backtick>`[^`]+`)
32 | (?P<string>'(?:[^'\\]|\\.)*')
33 | (?P<number>\d+(?:\.\d+)?)
34 | (?P<op><=|>=|!=|[=<>,*()+\-/%])
35 | (?P<word>[A-Za-z_][A-Za-z0-9_./-]*)
36 )"#,
37 )
38 .unwrap()
39});
40
/// One lexical token produced by `tokenize`.
#[derive(Debug, Clone)]
struct Token {
    /// Token category; `tokenize` emits "ident", "string", "number", "op"
    /// or "keyword".
    token_type: String,
    /// Normalized text: surrounding quotes/backticks removed for strings and
    /// quoted identifiers, upper-cased for keywords, otherwise equal to `raw`.
    value: String,
    /// The exact text matched in the query (used in error messages).
    raw: String,
}
47
48fn tokenize(sql: &str) -> Vec<Token> {
49 let mut tokens = Vec::new();
50 for caps in TOKEN_RE.captures_iter(sql) {
51 if let Some(m) = caps.name("backtick") {
52 let raw = m.as_str();
53 tokens.push(Token {
54 token_type: "ident".into(),
55 value: raw[1..raw.len() - 1].into(),
56 raw: raw.into(),
57 });
58 } else if let Some(m) = caps.name("string") {
59 let raw = m.as_str();
60 tokens.push(Token {
61 token_type: "string".into(),
62 value: raw[1..raw.len() - 1].into(),
63 raw: raw.into(),
64 });
65 } else if let Some(m) = caps.name("number") {
66 let raw = m.as_str();
67 tokens.push(Token {
68 token_type: "number".into(),
69 value: raw.into(),
70 raw: raw.into(),
71 });
72 } else if let Some(m) = caps.name("op") {
73 let raw = m.as_str();
74 tokens.push(Token {
75 token_type: "op".into(),
76 value: raw.into(),
77 raw: raw.into(),
78 });
79 } else if let Some(m) = caps.name("word") {
80 let raw = m.as_str();
81 if KEYWORDS.contains(&raw.to_uppercase().as_str()) {
82 tokens.push(Token {
83 token_type: "keyword".into(),
84 value: raw.to_uppercase(),
85 raw: raw.into(),
86 });
87 } else {
88 tokens.push(Token {
89 token_type: "ident".into(),
90 value: raw.into(),
91 raw: raw.into(),
92 });
93 }
94 }
95 }
96 tokens
97}
98
/// Recursive-descent parser over a token list.
struct Parser {
    /// All tokens of the query, in source order.
    tokens: Vec<Token>,
    /// Index of the next token to be consumed; `peek` reads it, `advance`
    /// moves past it.
    pos: usize,
}
105
106impl Parser {
107 fn new(tokens: Vec<Token>) -> Self {
108 Parser { tokens, pos: 0 }
109 }
110
111 fn peek(&self) -> Option<&Token> {
112 self.tokens.get(self.pos)
113 }
114
115 fn advance(&mut self) -> Token {
116 let t = self.tokens[self.pos].clone();
117 self.pos += 1;
118 t
119 }
120
121 fn expect(&mut self, type_: &str, value: Option<&str>) -> Result<Token, MdqlError> {
122 let t = self.peek().ok_or_else(|| {
123 MdqlError::QueryParse(format!(
124 "Unexpected end of query, expected {}",
125 value.unwrap_or(type_)
126 ))
127 })?;
128 let matches_type = t.token_type == type_;
129 let matches_value = value.map_or(true, |v| t.value == v);
130 if !matches_type || !matches_value {
131 return Err(MdqlError::QueryParse(format!(
132 "Expected {}, got '{}' at position {}",
133 value.unwrap_or(type_),
134 t.raw,
135 self.pos
136 )));
137 }
138 Ok(self.advance())
139 }
140
141 fn match_keyword(&mut self, kw: &str) -> bool {
142 if let Some(t) = self.peek() {
143 if t.token_type == "keyword" && t.value == kw {
144 self.advance();
145 return true;
146 }
147 }
148 false
149 }
150
151 fn parse_statement(&mut self) -> Result<Statement, MdqlError> {
152 let t = self.peek().ok_or_else(|| MdqlError::QueryParse("Empty query".into()))?;
153 match (t.token_type.as_str(), t.value.as_str()) {
154 ("keyword", "WITH") => {
155 let ctes = self.parse_ctes()?;
156 let mut q = self.parse_select()?;
157 q.ctes = ctes;
158 self.expect_end()?;
159 Ok(Statement::Select(q))
160 }
161 ("keyword", "SELECT") => {
162 let q = self.parse_select()?;
163 self.expect_end()?;
164 Ok(Statement::Select(q))
165 }
166 ("keyword", "INSERT") => Ok(Statement::Insert(self.parse_insert()?)),
167 ("keyword", "UPDATE") => Ok(Statement::Update(self.parse_update()?)),
168 ("keyword", "DELETE") => Ok(Statement::Delete(self.parse_delete()?)),
169 ("keyword", "ALTER") => self.parse_alter(),
170 ("keyword", "CREATE") => self.parse_create_view(),
171 ("keyword", "DROP") => self.parse_drop_view(),
172 _ => Err(MdqlError::QueryParse(format!(
173 "Expected SELECT, INSERT, UPDATE, DELETE, ALTER, CREATE, or DROP, got '{}'",
174 t.raw
175 ))),
176 }
177 }
178
179 fn parse_ctes(&mut self) -> Result<Vec<CteClause>, MdqlError> {
180 self.expect("keyword", Some("WITH"))?;
181 let mut ctes = Vec::new();
182 loop {
183 let name = self.parse_ident()?;
184 self.expect("keyword", Some("AS"))?;
185 self.expect("op", Some("("))?;
186 let query = self.parse_select()?;
187 self.expect("op", Some(")"))?;
188 ctes.push(CteClause { name, query: Box::new(query) });
189 if !self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
190 break;
191 }
192 self.advance();
193 }
194 Ok(ctes)
195 }
196
197 fn parse_select(&mut self) -> Result<SelectQuery, MdqlError> {
198 self.expect("keyword", Some("SELECT"))?;
199 let columns = self.parse_columns()?;
200 self.expect("keyword", Some("FROM"))?;
201
202 let mut subquery = None;
204 let (table, mut table_alias) = if self.peek().map_or(false, |t| t.token_type == "op" && t.value == "(") {
205 self.advance();
206 let inner = self.parse_select()?;
207 self.expect("op", Some(")"))?;
208 subquery = Some(Box::new(inner));
209 let alias = if let Some(t) = self.peek() {
210 if t.token_type == "ident" && !self.is_clause_keyword(t) {
211 Some(self.advance().value)
212 } else {
213 None
214 }
215 } else {
216 None
217 };
218 ("_subquery".to_string(), alias)
219 } else {
220 let t = self.parse_ident()?;
221 (t, None)
222 };
223
224 if subquery.is_none() {
226 if let Some(t) = self.peek() {
227 if t.token_type == "ident" && !self.is_clause_keyword(t) {
228 table_alias = Some(self.advance().value);
229 }
230 }
231 }
232
233 let mut joins = Vec::new();
235 loop {
236 let jt = if self.match_keyword("LEFT") {
237 self.expect("keyword", Some("JOIN"))?;
238 JoinType::Left
239 } else if self.match_keyword("JOIN") {
240 JoinType::Inner
241 } else {
242 break;
243 };
244 let join_table = self.parse_ident()?;
245 let mut join_alias = None;
246 if let Some(t) = self.peek() {
247 if t.token_type == "ident" && !self.is_clause_keyword(t) {
248 join_alias = Some(self.advance().value);
249 }
250 }
251 self.expect("keyword", Some("ON"))?;
252 let condition = self.parse_or_expr()?;
253 joins.push(JoinClause {
254 join_type: jt,
255 table: join_table,
256 alias: join_alias,
257 condition,
258 });
259 }
260
261 let mut where_clause = None;
262 if self.match_keyword("WHERE") {
263 where_clause = Some(self.parse_or_expr()?);
264 }
265
266 let mut group_by = None;
267 if self.match_keyword("GROUP") {
268 self.expect("keyword", Some("BY"))?;
269 let mut cols = vec![self.parse_ident()?];
270 while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
271 self.advance();
272 cols.push(self.parse_ident()?);
273 }
274 group_by = Some(cols);
275 }
276
277 let mut having = None;
278 if self.match_keyword("HAVING") {
279 having = Some(self.parse_or_expr()?);
280 }
281
282 let mut order_by = None;
283 if self.match_keyword("ORDER") {
284 self.expect("keyword", Some("BY"))?;
285 order_by = Some(self.parse_order_by()?);
286 }
287
288 let mut limit = None;
289 if self.match_keyword("LIMIT") {
290 let t = self.expect("number", None)?;
291 limit = Some(t.value.parse::<i64>().map_err(|_| {
292 MdqlError::QueryParse(format!("Invalid LIMIT value: {}", t.value))
293 })?);
294 }
295
296 Ok(SelectQuery {
297 columns,
298 table,
299 table_alias,
300 subquery,
301 joins,
302 where_clause,
303 group_by,
304 having,
305 order_by,
306 limit,
307 ctes: vec![],
308 })
309 }
310
311 fn parse_insert(&mut self) -> Result<InsertQuery, MdqlError> {
312 self.expect("keyword", Some("INSERT"))?;
313 self.expect("keyword", Some("INTO"))?;
314 let table = self.parse_ident()?;
315
316 self.expect("op", Some("("))?;
317 let mut columns = vec![self.parse_ident()?];
318 while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
319 self.advance();
320 columns.push(self.parse_ident()?);
321 }
322 self.expect("op", Some(")"))?;
323
324 self.expect("keyword", Some("VALUES"))?;
325
326 self.expect("op", Some("("))?;
327 let mut values = vec![self.parse_value()?];
328 while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
329 self.advance();
330 values.push(self.parse_value()?);
331 }
332 self.expect("op", Some(")"))?;
333
334 if columns.len() != values.len() {
335 return Err(MdqlError::QueryParse(format!(
336 "Column count ({}) does not match value count ({})",
337 columns.len(),
338 values.len()
339 )));
340 }
341
342 self.expect_end()?;
343 Ok(InsertQuery {
344 table,
345 columns,
346 values,
347 })
348 }
349
350 fn parse_update(&mut self) -> Result<UpdateQuery, MdqlError> {
351 self.expect("keyword", Some("UPDATE"))?;
352 let table = self.parse_ident()?;
353 self.expect("keyword", Some("SET"))?;
354
355 let mut assignments = Vec::new();
356 let col = self.parse_ident()?;
357 self.expect("op", Some("="))?;
358 let val = self.parse_value()?;
359 assignments.push((col, val));
360
361 while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
362 self.advance();
363 let col = self.parse_ident()?;
364 self.expect("op", Some("="))?;
365 let val = self.parse_value()?;
366 assignments.push((col, val));
367 }
368
369 let mut where_clause = None;
370 if self.match_keyword("WHERE") {
371 where_clause = Some(self.parse_or_expr()?);
372 }
373
374 self.expect_end()?;
375 Ok(UpdateQuery {
376 table,
377 assignments,
378 where_clause,
379 })
380 }
381
382 fn parse_delete(&mut self) -> Result<DeleteQuery, MdqlError> {
383 self.expect("keyword", Some("DELETE"))?;
384 self.expect("keyword", Some("FROM"))?;
385 let table = self.parse_ident()?;
386
387 let mut where_clause = None;
388 if self.match_keyword("WHERE") {
389 where_clause = Some(self.parse_or_expr()?);
390 }
391
392 let mode = if self.match_keyword("CASCADE") {
393 DeleteMode::Cascade
394 } else if self.match_keyword("RESTRICT") {
395 DeleteMode::Restrict
396 } else {
397 DeleteMode::Default
398 };
399
400 self.expect_end()?;
401 Ok(DeleteQuery {
402 table,
403 where_clause,
404 mode,
405 })
406 }
407
408 fn parse_alter(&mut self) -> Result<Statement, MdqlError> {
409 self.expect("keyword", Some("ALTER"))?;
410 self.expect("keyword", Some("TABLE"))?;
411 let table = self.parse_ident()?;
412
413 let t = self.peek().ok_or_else(|| {
414 MdqlError::QueryParse("Expected RENAME, DROP, or MERGE after table name".into())
415 })?;
416
417 match (t.token_type.as_str(), t.value.as_str()) {
418 ("keyword", "RENAME") => {
419 self.advance();
420 self.expect("keyword", Some("FIELD"))?;
421 let old_name = self.parse_string_or_ident()?;
422 self.expect("keyword", Some("TO"))?;
423 let new_name = self.parse_string_or_ident()?;
424 self.expect_end()?;
425 Ok(Statement::AlterRename(AlterRenameFieldQuery {
426 table,
427 old_name,
428 new_name,
429 }))
430 }
431 ("keyword", "DROP") => {
432 self.advance();
433 self.expect("keyword", Some("FIELD"))?;
434 let field_name = self.parse_string_or_ident()?;
435 self.expect_end()?;
436 Ok(Statement::AlterDrop(AlterDropFieldQuery {
437 table,
438 field_name,
439 }))
440 }
441 ("keyword", "MERGE") => {
442 self.advance();
443 self.expect("keyword", Some("FIELDS"))?;
444 let mut sources = vec![self.parse_string_or_ident()?];
445 while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
446 self.advance();
447 sources.push(self.parse_string_or_ident()?);
448 }
449 self.expect("keyword", Some("INTO"))?;
450 let target = self.parse_string_or_ident()?;
451 self.expect_end()?;
452 Ok(Statement::AlterMerge(AlterMergeFieldsQuery {
453 table,
454 sources,
455 into: target,
456 }))
457 }
458 _ => Err(MdqlError::QueryParse(format!(
459 "Expected RENAME, DROP, or MERGE, got '{}'",
460 t.raw
461 ))),
462 }
463 }
464
465 fn parse_create_view(&mut self) -> Result<Statement, MdqlError> {
466 self.expect("keyword", Some("CREATE"))?;
467 self.expect("keyword", Some("VIEW"))?;
468 let view_name = self.parse_ident()?;
469
470 let columns = if self.peek().map_or(false, |t| t.token_type == "op" && t.value == "(") {
471 self.advance();
472 let mut cols = vec![self.parse_ident()?];
473 while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
474 self.advance();
475 cols.push(self.parse_ident()?);
476 }
477 self.expect("op", Some(")"))?;
478 Some(cols)
479 } else {
480 None
481 };
482
483 self.expect("keyword", Some("AS"))?;
484 let query = Box::new(self.parse_select()?);
485 self.expect_end()?;
486
487 Ok(Statement::CreateView(CreateViewQuery {
488 view_name,
489 columns,
490 query,
491 }))
492 }
493
494 fn parse_drop_view(&mut self) -> Result<Statement, MdqlError> {
495 self.expect("keyword", Some("DROP"))?;
496 self.expect("keyword", Some("VIEW"))?;
497 let view_name = self.parse_ident()?;
498 self.expect_end()?;
499 Ok(Statement::DropView(DropViewQuery { view_name }))
500 }
501
502 fn parse_string_or_ident(&mut self) -> Result<String, MdqlError> {
503 let t = self.peek().ok_or_else(|| {
504 MdqlError::QueryParse("Expected field name, got end of query".into())
505 })?;
506 match t.token_type.as_str() {
507 "string" => {
508 let v = self.advance().value;
509 Ok(v)
510 }
511 "ident" | "keyword" => {
512 let v = self.advance().value;
513 Ok(v)
514 }
515 _ => Err(MdqlError::QueryParse(format!(
516 "Expected field name, got '{}'",
517 t.raw
518 ))),
519 }
520 }
521
522 fn parse_columns(&mut self) -> Result<ColumnList, MdqlError> {
523 if let Some(t) = self.peek() {
524 if t.token_type == "op" && t.value == "*" {
525 self.advance();
526 return Ok(ColumnList::All);
527 }
528 }
529
530 let mut exprs = vec![self.parse_select_expr()?];
531 while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
532 self.advance();
533 exprs.push(self.parse_select_expr()?);
534 }
535 Ok(ColumnList::Named(exprs))
536 }
537
538 fn peek_is_window_func(&self) -> bool {
539 let t = match self.peek() {
540 Some(t) => t,
541 None => return false,
542 };
543 let name_upper = t.value.to_uppercase();
544 if !WINDOW_FUNCS.contains(&name_upper.as_str()) {
545 return false;
546 }
547 self.tokens
548 .get(self.pos + 1)
549 .map_or(false, |next| next.token_type == "op" && next.value == "(")
550 }
551
552 fn peek_is_agg_func(&self) -> bool {
553 let t = match self.peek() {
554 Some(t) => t,
555 None => return false,
556 };
557 let name_upper = t.value.to_uppercase();
558 if !AGG_FUNCS.contains(&name_upper.as_str()) {
559 return false;
560 }
561 self.tokens
563 .get(self.pos + 1)
564 .map_or(false, |next| next.token_type == "op" && next.value == "(")
565 }
566
567 fn parse_select_expr(&mut self) -> Result<SelectExpr, MdqlError> {
568 let _t = self.peek().ok_or_else(|| {
569 MdqlError::QueryParse("Expected column or aggregate, got end of query".into())
570 })?;
571
572 let expr = self.parse_additive()?;
573
574 let alias = if self.match_keyword("AS") {
575 Some(self.parse_ident()?)
576 } else if self.peek().map_or(false, |t| {
577 t.token_type == "ident" && !self.is_clause_keyword(t)
578 }) {
579 Some(self.advance().value)
580 } else {
581 None
582 };
583
584 if let Expr::Aggregate { func, arg, arg_expr } = expr {
586 return Ok(SelectExpr::Aggregate {
587 func,
588 arg,
589 arg_expr: arg_expr.map(|e| *e),
590 alias,
591 });
592 }
593
594 if alias.is_none() {
595 if let Expr::Column(name) = &expr {
596 return Ok(SelectExpr::Column(name.clone()));
597 }
598 }
599
600 Ok(SelectExpr::Expr { expr, alias })
601 }
602
603 fn peek_is_additive_op(&self) -> bool {
606 self.peek().map_or(false, |t| {
607 t.token_type == "op" && (t.value == "+" || t.value == "-")
608 })
609 }
610
611 fn peek_is_multiplicative_op(&self) -> bool {
612 self.peek().map_or(false, |t| {
613 t.token_type == "op" && (t.value == "*" || t.value == "/" || t.value == "%")
614 })
615 }
616
617 fn parse_additive(&mut self) -> Result<Expr, MdqlError> {
618 let mut left = self.parse_multiplicative()?;
619 while self.peek_is_additive_op() {
620 let op_tok = self.advance();
621 let is_sub = op_tok.value == "-";
622
623 if self.peek().map_or(false, |t| t.token_type == "keyword" && t.value == "INTERVAL") {
625 self.advance(); let days_expr = self.parse_multiplicative()?;
627 if !self.match_keyword("DAY") && !self.match_keyword("DAYS") {
629 return Err(MdqlError::QueryParse("Expected DAY after INTERVAL value".into()));
630 }
631 let days = if is_sub {
632 Expr::UnaryMinus(Box::new(days_expr))
633 } else {
634 days_expr
635 };
636 left = Expr::DateAdd {
637 date: Box::new(left),
638 days: Box::new(days),
639 };
640 continue;
641 }
642
643 let op = match op_tok.value.as_str() {
644 "+" => ArithOp::Add,
645 "-" => ArithOp::Sub,
646 _ => unreachable!(),
647 };
648 let right = self.parse_multiplicative()?;
649 left = Expr::BinaryOp {
650 left: Box::new(left),
651 op,
652 right: Box::new(right),
653 };
654 }
655 Ok(left)
656 }
657
658 fn parse_multiplicative(&mut self) -> Result<Expr, MdqlError> {
659 let mut left = self.parse_unary()?;
660 while self.peek_is_multiplicative_op() {
661 let op_tok = self.advance();
662 let op = match op_tok.value.as_str() {
663 "*" => ArithOp::Mul,
664 "/" => ArithOp::Div,
665 "%" => ArithOp::Mod,
666 _ => unreachable!(),
667 };
668 let right = self.parse_unary()?;
669 left = Expr::BinaryOp {
670 left: Box::new(left),
671 op,
672 right: Box::new(right),
673 };
674 }
675 Ok(left)
676 }
677
678 fn parse_unary(&mut self) -> Result<Expr, MdqlError> {
679 if self.peek().map_or(false, |t| t.token_type == "op" && t.value == "-") {
680 self.advance();
681 let inner = self.parse_atom()?;
682 match inner {
684 Expr::Literal(SqlValue::Int(n)) => Ok(Expr::Literal(SqlValue::Int(-n))),
685 Expr::Literal(SqlValue::Float(f)) => Ok(Expr::Literal(SqlValue::Float(-f))),
686 _ => Ok(Expr::UnaryMinus(Box::new(inner))),
687 }
688 } else {
689 self.parse_atom()
690 }
691 }
692
693 fn parse_atom(&mut self) -> Result<Expr, MdqlError> {
694 if self.peek_is_window_func() {
695 return self.parse_standalone_window();
696 }
697
698 if self.peek_is_agg_func() {
699 let agg = self.parse_agg_expr()?;
700 if self.peek().map_or(false, |t| t.token_type == "keyword" && t.value == "OVER") {
701 self.advance();
702 let over = self.parse_window_spec()?;
703 if let Expr::Aggregate { func, arg, arg_expr } = agg {
704 return Ok(Expr::Window {
705 func: WindowFunc::Agg(func),
706 args: if arg == "*" {
707 vec![]
708 } else {
709 vec![arg_expr.map(|e| *e).unwrap_or(Expr::Column(arg))]
710 },
711 over,
712 });
713 }
714 }
715 return Ok(agg);
716 }
717
718 let t = self.peek().ok_or_else(|| {
719 MdqlError::QueryParse("Expected expression, got end of query".into())
720 })?;
721
722 match t.token_type.as_str() {
723 "number" => {
724 let v = self.advance().value;
725 if v.contains('.') {
726 let f: f64 = v.parse().map_err(|_| {
727 MdqlError::QueryParse(format!("Invalid float: {}", v))
728 })?;
729 Ok(Expr::Literal(SqlValue::Float(f)))
730 } else {
731 let n: i64 = v.parse().map_err(|_| {
732 MdqlError::QueryParse(format!("Invalid int: {}", v))
733 })?;
734 Ok(Expr::Literal(SqlValue::Int(n)))
735 }
736 }
737 "string" => {
738 let v = self.advance().value;
739 Ok(Expr::Literal(SqlValue::String(v)))
740 }
741 "keyword" if t.value == "NULL" => {
742 self.advance();
743 Ok(Expr::Literal(SqlValue::Null))
744 }
745 "keyword" if t.value == "CASE" => {
746 self.parse_case_expr()
747 }
748 "keyword" if t.value == "CURRENT_DATE" => {
749 self.advance();
750 Ok(Expr::CurrentDate)
751 }
752 "keyword" if t.value == "CURRENT_TIMESTAMP" => {
753 self.advance();
754 Ok(Expr::CurrentTimestamp)
755 }
756 "keyword" if t.value == "DATEDIFF" => {
757 self.advance();
758 self.expect("op", Some("("))?;
759 let left = self.parse_additive()?;
760 self.expect("op", Some(","))?;
761 let right = self.parse_additive()?;
762 self.expect("op", Some(")"))?;
763 Ok(Expr::DateDiff { left: Box::new(left), right: Box::new(right) })
764 }
765 "op" if t.value == "(" => {
766 let next_is_select = self.tokens.get(self.pos + 1)
767 .map_or(false, |t| t.token_type == "keyword" && t.value == "SELECT");
768 if next_is_select {
769 self.advance();
770 let sq = self.parse_select()?;
771 self.expect("op", Some(")"))?;
772 Ok(Expr::Subquery(Box::new(sq)))
773 } else {
774 self.advance();
775 let expr = self.parse_additive()?;
776 self.expect("op", Some(")"))?;
777 Ok(expr)
778 }
779 }
780 "ident" => {
781 let name = self.advance().value;
782 Ok(Expr::Column(name))
783 }
784 "keyword" if !Self::is_reserved_keyword(&t.value) => {
785 let name = self.advance().value;
786 Ok(Expr::Column(name))
787 }
788 _ => Err(MdqlError::QueryParse(format!(
789 "Expected expression, got '{}'",
790 t.raw
791 ))),
792 }
793 }
794
795 fn parse_case_expr(&mut self) -> Result<Expr, MdqlError> {
796 self.expect("keyword", Some("CASE"))?;
797 let mut whens = Vec::new();
798 while self.match_keyword("WHEN") {
799 let condition = self.parse_or_expr()?;
800 self.expect("keyword", Some("THEN"))?;
801 let result = self.parse_additive()?;
802 whens.push((condition, Box::new(result)));
803 }
804 if whens.is_empty() {
805 return Err(MdqlError::QueryParse("CASE requires at least one WHEN clause".into()));
806 }
807 let else_expr = if self.match_keyword("ELSE") {
808 Some(Box::new(self.parse_additive()?))
809 } else {
810 None
811 };
812 self.expect("keyword", Some("END"))?;
813 Ok(Expr::Case { whens, else_expr })
814 }
815
816 fn parse_agg_expr(&mut self) -> Result<Expr, MdqlError> {
817 let func_name = self.advance().value.to_uppercase();
818 let func = match func_name.as_str() {
819 "COUNT" => AggFunc::Count,
820 "SUM" => AggFunc::Sum,
821 "AVG" => AggFunc::Avg,
822 "MIN" => AggFunc::Min,
823 "MAX" => AggFunc::Max,
824 _ => unreachable!(),
825 };
826 self.expect("op", Some("("))?;
827 let (arg, arg_expr) = if self.peek().map_or(false, |t| t.token_type == "op" && t.value == "*") {
828 self.advance();
829 ("*".to_string(), None)
830 } else {
831 let expr = self.parse_additive()?;
832 if let Expr::Column(name) = &expr {
833 (name.clone(), None)
834 } else {
835 (expr.display_name(), Some(Box::new(expr)))
836 }
837 };
838 self.expect("op", Some(")"))?;
839 Ok(Expr::Aggregate { func, arg, arg_expr })
840 }
841
842 fn parse_standalone_window(&mut self) -> Result<Expr, MdqlError> {
843 let func_name = self.advance().value.to_uppercase();
844 let func = match func_name.as_str() {
845 "ROW_NUMBER" => WindowFunc::RowNumber,
846 "RANK" => WindowFunc::Rank,
847 "DENSE_RANK" => WindowFunc::DenseRank,
848 "LAG" => WindowFunc::Lag,
849 "LEAD" => WindowFunc::Lead,
850 _ => unreachable!(),
851 };
852 self.expect("op", Some("("))?;
853 let mut args = Vec::new();
854 if !self.peek().map_or(false, |t| t.token_type == "op" && t.value == ")") {
855 args.push(self.parse_additive()?);
856 while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
857 self.advance();
858 args.push(self.parse_additive()?);
859 }
860 }
861 self.expect("op", Some(")"))?;
862 self.expect("keyword", Some("OVER"))?;
863 let over = self.parse_window_spec()?;
864 Ok(Expr::Window { func, args, over })
865 }
866
867 fn parse_window_spec(&mut self) -> Result<WindowSpec, MdqlError> {
868 self.expect("op", Some("("))?;
869 let mut partition_by = Vec::new();
870 if self.match_keyword("PARTITION") {
871 self.expect("keyword", Some("BY"))?;
872 partition_by.push(self.parse_ident()?);
873 while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
874 self.advance();
875 partition_by.push(self.parse_ident()?);
876 }
877 }
878 let mut order_by = Vec::new();
879 if self.match_keyword("ORDER") {
880 self.expect("keyword", Some("BY"))?;
881 order_by = self.parse_order_by()?;
882 }
883 self.expect("op", Some(")"))?;
884 Ok(WindowSpec { partition_by, order_by })
885 }
886
887 fn parse_ident(&mut self) -> Result<String, MdqlError> {
888 let t = self.peek().ok_or_else(|| {
889 MdqlError::QueryParse("Expected identifier, got end of query".into())
890 })?;
891 match t.token_type.as_str() {
892 "ident" | "keyword" => {
893 let v = self.advance().value;
894 Ok(v)
895 }
896 _ => Err(MdqlError::QueryParse(format!(
897 "Expected identifier, got '{}'",
898 t.raw
899 ))),
900 }
901 }
902
903 fn parse_or_expr(&mut self) -> Result<WhereClause, MdqlError> {
904 let mut left = self.parse_and_expr()?;
905 while self.match_keyword("OR") {
906 let right = self.parse_and_expr()?;
907 left = WhereClause::BoolOp(BoolOp {
908 op: BoolOpKind::Or,
909 left: Box::new(left),
910 right: Box::new(right),
911 });
912 }
913 Ok(left)
914 }
915
916 fn parse_and_expr(&mut self) -> Result<WhereClause, MdqlError> {
917 let mut left = self.parse_comparison()?;
918 while self.match_keyword("AND") {
919 let right = self.parse_comparison()?;
920 left = WhereClause::BoolOp(BoolOp {
921 op: BoolOpKind::And,
922 left: Box::new(left),
923 right: Box::new(right),
924 });
925 }
926 Ok(left)
927 }
928
    /// Parses a single predicate, or a parenthesized boolean expression.
    ///
    /// Accepted forms after the left-hand expression:
    /// `IS NULL`, `IS NOT NULL`, `IN (subquery)`, `IN (value, ...)`,
    /// `LIKE value`, `NOT LIKE value`, and the binary operators
    /// `=`, `!=`, `<`, `>`, `<=`, `>=`.
    ///
    /// # Errors
    ///
    /// Returns `MdqlError::QueryParse` when no operator follows the left-hand
    /// expression, when `NOT` is not followed by `LIKE`, or when a nested
    /// parser fails.
    fn parse_comparison(&mut self) -> Result<WhereClause, MdqlError> {
        // A leading `(` is ambiguous: it may group a boolean expression or
        // start an arithmetic expression / subquery. Try the boolean reading
        // first and rewind if it does not end in a matching `)`.
        if self.peek().map_or(false, |t| t.token_type == "op" && t.value == "(") {
            let saved_pos = self.pos;
            self.advance();
            let result = self.parse_or_expr();
            if result.is_ok() && self.peek().map_or(false, |t| t.token_type == "op" && t.value == ")") {
                self.advance();
                return result;
            }
            // Speculative parse failed: restore the position and fall through
            // so the `(` is handled by the expression parser instead.
            self.pos = saved_pos;
        }

        let left_expr = self.parse_additive()?;

        // Whatever `as_column()` yields for the left side, or an empty string
        // when it yields `None`.
        let col = left_expr.as_column().unwrap_or("").to_string();

        // IS NULL / IS NOT NULL
        if self.match_keyword("IS") {
            if self.match_keyword("NOT") {
                self.expect("keyword", Some("NULL"))?;
                return Ok(WhereClause::Comparison(Comparison {
                    column: col,
                    op: CmpOp::IsNotNull,
                    value: None,
                    left_expr: Some(left_expr),
                    right_expr: None,
                }));
            }
            self.expect("keyword", Some("NULL"))?;
            return Ok(WhereClause::Comparison(Comparison {
                column: col,
                op: CmpOp::IsNull,
                value: None,
                left_expr: Some(left_expr),
                right_expr: None,
            }));
        }

        // IN (subquery) / IN (value, ...)
        if self.match_keyword("IN") {
            self.expect("op", Some("("))?;
            let is_subquery = self.peek().map_or(false, |t| t.token_type == "keyword" && t.value == "SELECT");
            if is_subquery {
                // The subquery goes in `right_expr`; `value` stays empty.
                let sq = self.parse_select()?;
                self.expect("op", Some(")"))?;
                return Ok(WhereClause::Comparison(Comparison {
                    column: col,
                    op: CmpOp::In,
                    value: None,
                    left_expr: Some(left_expr),
                    right_expr: Some(Expr::Subquery(Box::new(sq))),
                }));
            }
            // Literal list: collected into a single `SqlValue::List`.
            let mut values = vec![self.parse_value()?];
            while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
                self.advance();
                values.push(self.parse_value()?);
            }
            self.expect("op", Some(")"))?;
            return Ok(WhereClause::Comparison(Comparison {
                column: col,
                op: CmpOp::In,
                value: Some(SqlValue::List(values)),
                left_expr: Some(left_expr),
                right_expr: None,
            }));
        }

        // LIKE value
        if self.match_keyword("LIKE") {
            let val = self.parse_value()?;
            return Ok(WhereClause::Comparison(Comparison {
                column: col,
                op: CmpOp::Like,
                value: Some(val),
                left_expr: Some(left_expr),
                right_expr: None,
            }));
        }

        // NOT LIKE value (NOT followed by anything else is rejected)
        if self.match_keyword("NOT") {
            if self.match_keyword("LIKE") {
                let val = self.parse_value()?;
                return Ok(WhereClause::Comparison(Comparison {
                    column: col,
                    op: CmpOp::NotLike,
                    value: Some(val),
                    left_expr: Some(left_expr),
                    right_expr: None,
                }));
            }
            return Err(MdqlError::QueryParse("Expected LIKE after NOT".into()));
        }

        // Binary comparison operators.
        if let Some(t) = self.peek() {
            if t.token_type == "op" && ["=", "!=", "<", ">", "<=", ">="].contains(&t.value.as_str())
            {
                let op_str = self.advance().value;
                let op = match op_str.as_str() {
                    "=" => CmpOp::Eq,
                    "!=" => CmpOp::Ne,
                    "<" => CmpOp::Lt,
                    ">" => CmpOp::Gt,
                    "<=" => CmpOp::Le,
                    ">=" => CmpOp::Ge,
                    _ => unreachable!(),
                };
                let right_expr = self.parse_additive()?;
                // When the right side is a plain literal it is also copied
                // into `value`; otherwise only `right_expr` carries it.
                let value = match &right_expr {
                    Expr::Literal(v) => Some(v.clone()),
                    _ => None,
                };
                return Ok(WhereClause::Comparison(Comparison {
                    column: col,
                    op,
                    value,
                    left_expr: Some(left_expr),
                    right_expr: Some(right_expr),
                }));
            }
        }

        // Nothing recognizable followed the left-hand expression.
        let got = self.peek().map_or("end".to_string(), |t| t.raw.clone());
        Err(MdqlError::QueryParse(format!(
            "Expected operator after '{}', got '{}'",
            left_expr.display_name(), got
        )))
    }
1067
1068 fn parse_value(&mut self) -> Result<SqlValue, MdqlError> {
1069 let t = self.peek().ok_or_else(|| {
1070 MdqlError::QueryParse("Expected value, got end of query".into())
1071 })?;
1072 match t.token_type.as_str() {
1073 "string" => {
1074 let v = self.advance().value;
1075 Ok(SqlValue::String(v))
1076 }
1077 "number" => {
1078 let v = self.advance().value;
1079 if v.contains('.') {
1080 Ok(SqlValue::Float(v.parse().map_err(|_| {
1081 MdqlError::QueryParse(format!("Invalid float: {}", v))
1082 })?))
1083 } else {
1084 Ok(SqlValue::Int(v.parse().map_err(|_| {
1085 MdqlError::QueryParse(format!("Invalid int: {}", v))
1086 })?))
1087 }
1088 }
1089 "keyword" if t.value == "NULL" => {
1090 self.advance();
1091 Ok(SqlValue::Null)
1092 }
1093 _ => Err(MdqlError::QueryParse(format!(
1094 "Expected value, got '{}'",
1095 t.raw
1096 ))),
1097 }
1098 }
1099
1100 fn parse_order_by(&mut self) -> Result<Vec<OrderSpec>, MdqlError> {
1101 let mut specs = vec![self.parse_order_spec()?];
1102 while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
1103 self.advance();
1104 specs.push(self.parse_order_spec()?);
1105 }
1106 Ok(specs)
1107 }
1108
1109 fn parse_order_spec(&mut self) -> Result<OrderSpec, MdqlError> {
1110 let expr = self.parse_additive()?;
1111 let col = expr.as_column().unwrap_or("").to_string();
1112 let descending = if self.match_keyword("DESC") {
1113 true
1114 } else {
1115 self.match_keyword("ASC");
1116 false
1117 };
1118 Ok(OrderSpec {
1119 column: col,
1120 expr: Some(expr),
1121 descending,
1122 })
1123 }
1124
1125 fn is_clause_keyword(&self, t: &Token) -> bool {
1126 t.token_type == "keyword"
1127 && ["WHERE", "ORDER", "LIMIT", "JOIN", "LEFT", "ON", "GROUP"].contains(&t.value.as_str())
1128 }
1129
1130 fn is_reserved_keyword(kw: &str) -> bool {
1132 matches!(kw,
1133 "AS" | "FROM" | "WHERE" | "AND" | "OR" | "ORDER" | "BY"
1134 | "ASC" | "DESC" | "LIMIT" | "JOIN" | "ON" | "GROUP"
1135 | "SELECT" | "INSERT" | "INTO" | "VALUES" | "UPDATE" | "SET"
1136 | "DELETE" | "ALTER" | "TABLE" | "IS" | "NOT" | "IN" | "LIKE"
1137 | "RENAME" | "FIELD" | "TO" | "DROP" | "MERGE" | "FIELDS"
1138 | "CASE" | "WHEN" | "THEN" | "ELSE" | "END"
1139 | "HAVING" | "INTERVAL" | "DAY" | "DAYS"
1140 | "CURRENT_DATE" | "CURRENT_TIMESTAMP" | "DATEDIFF"
1141 | "CREATE" | "VIEW" | "CASCADE" | "RESTRICT"
1142 | "WITH"
1143 | "OVER" | "PARTITION" | "ROW_NUMBER" | "RANK" | "DENSE_RANK" | "LAG" | "LEAD"
1144 )
1145 }
1146
1147 fn expect_end(&self) -> Result<(), MdqlError> {
1148 if let Some(t) = self.peek() {
1149 return Err(MdqlError::QueryParse(format!(
1150 "Unexpected token '{}' at position {}",
1151 t.raw, self.pos
1152 )));
1153 }
1154 Ok(())
1155 }
1156}
1157
1158pub fn parse_query(sql: &str) -> crate::errors::Result<Statement> {
1159 let tokens = tokenize(sql);
1160 if tokens.is_empty() {
1161 return Err(MdqlError::QueryParse("Empty query".into()));
1162 }
1163 let mut parser = Parser::new(tokens);
1164 parser.parse_statement()
1165}
1166
1167#[cfg(test)]
1168mod tests {
1169 use super::*;
1170
/// A two-column SELECT yields both names, in order, and the table name.
#[test]
fn test_simple_select() {
    let stmt = parse_query("SELECT title, status FROM strategies").unwrap();
    let Statement::Select(select) = stmt else {
        panic!("Expected Select");
    };
    let expected = ColumnList::Named(vec![
        SelectExpr::Column("title".into()),
        SelectExpr::Column("status".into()),
    ]);
    assert_eq!(select.columns, expected);
    assert_eq!(select.table, "strategies");
}
1181
/// `SELECT *` is represented as `ColumnList::All`.
#[test]
fn test_select_star() {
    let stmt = parse_query("SELECT * FROM test").unwrap();
    let Statement::Select(select) = stmt else {
        panic!("Expected Select");
    };
    assert_eq!(select.columns, ColumnList::All);
}
1191
/// A simple comparison after WHERE populates `where_clause`.
#[test]
fn test_where_clause() {
    let stmt = parse_query("SELECT title FROM test WHERE count > 5").unwrap();
    let Statement::Select(select) = stmt else {
        panic!("Expected Select");
    };
    assert!(select.where_clause.is_some());
}
1201
/// Two ORDER BY items keep their individual DESC / ASC direction.
#[test]
fn test_order_by() {
    let stmt =
        parse_query("SELECT title FROM test ORDER BY composite DESC, title ASC").unwrap();
    let Statement::Select(select) = stmt else {
        panic!("Expected Select");
    };
    let specs = select.order_by.unwrap();
    assert_eq!(specs.len(), 2);
    assert!(specs[0].descending);
    assert!(!specs[1].descending);
}
1215
/// `LIMIT 10` is stored as `Some(10)`.
#[test]
fn test_limit() {
    let stmt = parse_query("SELECT * FROM test LIMIT 10").unwrap();
    let Statement::Select(select) = stmt else {
        panic!("Expected Select");
    };
    assert_eq!(select.limit, Some(10));
}
1225
/// INSERT captures the table, the column list and typed literal values.
#[test]
fn test_insert() {
    let stmt = parse_query("INSERT INTO test (title, count) VALUES ('Hello', 42)").unwrap();
    let Statement::Insert(insert) = stmt else {
        panic!("Expected Insert");
    };
    assert_eq!(insert.table, "test");
    assert_eq!(insert.columns, vec!["title", "count"]);
    assert_eq!(insert.values[0], SqlValue::String("Hello".into()));
    assert_eq!(insert.values[1], SqlValue::Int(42));
}
1241
/// UPDATE with one assignment and a WHERE filter.
#[test]
fn test_update() {
    let stmt = parse_query("UPDATE test SET status = 'KILLED' WHERE path = 'a.md'").unwrap();
    let Statement::Update(update) = stmt else {
        panic!("Expected Update");
    };
    assert_eq!(update.table, "test");
    assert_eq!(update.assignments.len(), 1);
    assert!(update.where_clause.is_some());
}
1253
/// DELETE with a WHERE filter keeps the table and the filter.
#[test]
fn test_delete() {
    let stmt = parse_query("DELETE FROM test WHERE status = 'draft'").unwrap();
    let Statement::Delete(del) = stmt else {
        panic!("Expected Delete");
    };
    assert_eq!(del.table, "test");
    assert!(del.where_clause.is_some());
}
1264
/// ALTER TABLE ... RENAME FIELD records the old and new field names.
#[test]
fn test_alter_rename() {
    let stmt = parse_query("ALTER TABLE test RENAME FIELD 'Summary' TO 'Overview'").unwrap();
    let Statement::AlterRename(rename) = stmt else {
        panic!("Expected AlterRename");
    };
    assert_eq!(rename.old_name, "Summary");
    assert_eq!(rename.new_name, "Overview");
}
1276
/// ALTER TABLE ... DROP FIELD records the dropped field name.
#[test]
fn test_alter_drop() {
    let stmt = parse_query("ALTER TABLE test DROP FIELD 'Details'").unwrap();
    let Statement::AlterDrop(drop_stmt) = stmt else {
        panic!("Expected AlterDrop");
    };
    assert_eq!(drop_stmt.field_name, "Details");
}
1286
/// ALTER TABLE ... MERGE FIELDS collects every source and the target.
#[test]
fn test_alter_merge() {
    let sql = "ALTER TABLE test MERGE FIELDS 'Entry Rules', 'Exit Rules' INTO 'Trading Rules'";
    let Statement::AlterMerge(merge) = parse_query(sql).unwrap() else {
        panic!("Expected AlterMerge");
    };
    assert_eq!(merge.sources, vec!["Entry Rules", "Exit Rules"]);
    assert_eq!(merge.into, "Trading Rules");
}
1300
/// A backtick-quoted identifier keeps its inner text, spaces included.
#[test]
fn test_backtick_ident() {
    let stmt = parse_query("SELECT `Structural Mechanism` FROM test").unwrap();
    let Statement::Select(select) = stmt else {
        panic!("Expected Select");
    };
    let expected = ColumnList::Named(vec![SelectExpr::Column("Structural Mechanism".into())]);
    assert_eq!(select.columns, expected);
}
1313
/// LIKE parses to a comparison carrying the pattern as a string value.
#[test]
fn test_like_operator() {
    let stmt = parse_query("SELECT title FROM test WHERE categories LIKE '%defi%'").unwrap();
    let Statement::Select(select) = stmt else {
        panic!("Expected Select");
    };
    let Some(WhereClause::Comparison(cmp)) = select.where_clause else {
        panic!("Expected LIKE comparison");
    };
    assert_eq!(cmp.op, CmpOp::Like);
    assert_eq!(cmp.value, Some(SqlValue::String("%defi%".into())));
}
1328
/// `IN (...)` with a literal list parses to a `CmpOp::In` comparison.
#[test]
fn test_in_operator() {
    let stmt = parse_query("SELECT * FROM test WHERE status IN ('ACTIVE', 'LIVE')").unwrap();
    let Statement::Select(select) = stmt else {
        panic!("Expected Select");
    };
    let Some(WhereClause::Comparison(cmp)) = select.where_clause else {
        panic!("Expected IN comparison");
    };
    assert_eq!(cmp.op, CmpOp::In);
}
1343
/// `IS NULL` parses to a `CmpOp::IsNull` comparison.
#[test]
fn test_is_null() {
    let stmt = parse_query("SELECT * FROM test WHERE title IS NULL").unwrap();
    let Statement::Select(select) = stmt else {
        panic!("Expected Select");
    };
    let Some(WhereClause::Comparison(cmp)) = select.where_clause else {
        panic!("Expected IS NULL comparison");
    };
    assert_eq!(cmp.op, CmpOp::IsNull);
}
1357
/// A WHERE mixing AND and OR parses and produces a where clause.
#[test]
fn test_and_or() {
    let sql = "SELECT * FROM test WHERE status = 'ACTIVE' AND count > 5 OR title LIKE '%test%'";
    let Statement::Select(select) = parse_query(sql).unwrap() else {
        panic!("Expected Select");
    };
    assert!(select.where_clause.is_some());
}
1370
/// A single JOIN with aliases on both the base and the joined table.
#[test]
fn test_join() {
    let sql = "SELECT s.title, b.sharpe FROM strategies s JOIN backtests b ON b.strategy = s.path";
    let Statement::Select(select) = parse_query(sql).unwrap() else {
        panic!("Expected Select");
    };
    assert_eq!(select.table, "strategies");
    assert_eq!(select.table_alias, Some("s".into()));
    assert_eq!(select.joins.len(), 1);
    let first = &select.joins[0];
    assert_eq!(first.table, "backtests");
    assert_eq!(first.alias, Some("b".into()));
}
1388
/// Two chained JOINs keep their own table, alias and ON condition.
#[test]
fn test_multi_join() {
    let sql = "SELECT s.title, b.sharpe, c.verdict FROM strategies s JOIN backtests b ON b.strategy = s.path JOIN critiques c ON c.strategy = s.path";
    let Statement::Select(select) = parse_query(sql).unwrap() else {
        panic!("Expected Select");
    };
    assert_eq!(select.table, "strategies");
    assert_eq!(select.table_alias, Some("s".into()));
    assert_eq!(select.joins.len(), 2);

    assert_eq!(select.joins[0].table, "backtests");
    assert_eq!(select.joins[0].alias, Some("b".into()));
    assert_eq!(
        where_clause_to_sql(&select.joins[0].condition),
        "b.strategy = s.path"
    );

    assert_eq!(select.joins[1].table, "critiques");
    assert_eq!(select.joins[1].alias, Some("c".into()));
    assert_eq!(
        where_clause_to_sql(&select.joins[1].condition),
        "c.strategy = s.path"
    );
}
1409
/// `LEFT JOIN` is recorded with `JoinType::Left`.
#[test]
fn test_left_join() {
    let sql = "SELECT s.title, b.sharpe FROM strategies s LEFT JOIN backtests b ON b.strategy = s.path";
    let Statement::Select(select) = parse_query(sql).unwrap() else {
        panic!("Expected Select");
    };
    assert_eq!(select.joins.len(), 1);
    assert_eq!(select.joins[0].join_type, JoinType::Left);
    assert_eq!(select.joins[0].table, "backtests");
}
1424
/// A plain JOIN followed by a LEFT JOIN yields Inner then Left.
#[test]
fn test_mixed_join_types() {
    let sql = "SELECT s.title FROM strategies s JOIN backtests b ON b.strategy = s.path LEFT JOIN allocations a ON a.strategy = s.path";
    let Statement::Select(select) = parse_query(sql).unwrap() else {
        panic!("Expected Select");
    };
    assert_eq!(select.joins.len(), 2);
    assert_eq!(select.joins[0].join_type, JoinType::Inner);
    assert_eq!(select.joins[1].join_type, JoinType::Left);
}
1439
/// An ON condition joined by AND keeps both halves in the rendered SQL.
#[test]
fn test_join_compound_and() {
    let sql = "SELECT s.title FROM strategies s LEFT JOIN backtests b ON b.strategy = s.path AND b.mode = 'PAPER'";
    let Statement::Select(select) = parse_query(sql).unwrap() else {
        panic!("Expected Select");
    };
    assert_eq!(select.joins.len(), 1);
    assert_eq!(select.joins[0].join_type, JoinType::Left);
    let on_sql = where_clause_to_sql(&select.joins[0].condition);
    assert!(on_sql.contains("b.strategy = s.path"));
    assert!(on_sql.contains("AND"));
    assert!(on_sql.contains("b.mode = 'PAPER'"));
}
1457
/// An ON condition joined by OR keeps the OR in the rendered SQL.
#[test]
fn test_join_compound_or() {
    let stmt = parse_query("SELECT * FROM a JOIN b ON a.id = b.id OR a.alt = b.id").unwrap();
    let Statement::Select(select) = stmt else {
        panic!("Expected Select");
    };
    let on_sql = where_clause_to_sql(&select.joins[0].condition);
    assert!(on_sql.contains("OR"));
}
1471
/// A compound ON condition followed by WHERE: both are captured.
#[test]
fn test_join_compound_with_where() {
    let sql = "SELECT s.title FROM strategies s JOIN backtests b ON b.strategy = s.path AND b.mode = 'PAPER' WHERE s.title = 'Alpha'";
    let Statement::Select(select) = parse_query(sql).unwrap() else {
        panic!("Expected Select");
    };
    assert_eq!(select.joins.len(), 1);
    assert!(select.where_clause.is_some());
    let on_sql = where_clause_to_sql(&select.joins[0].condition);
    assert!(on_sql.contains("AND"));
}
1487
/// An empty input string is rejected with an error.
#[test]
fn test_empty_query() {
    let outcome = parse_query("");
    assert!(outcome.is_err());
}
1492
/// `COUNT(*) AS cnt` parses to an aliased aggregate next to a GROUP BY.
#[test]
fn test_count_star() {
    let stmt =
        parse_query("SELECT status, COUNT(*) AS cnt FROM strategies GROUP BY status").unwrap();
    let Statement::Select(select) = stmt else {
        panic!("Expected Select");
    };
    let ColumnList::Named(cols) = &select.columns else {
        panic!("Expected Named columns");
    };
    assert_eq!(cols.len(), 2);
    assert_eq!(cols[0], SelectExpr::Column("status".into()));
    assert!(matches!(
        &cols[1],
        SelectExpr::Aggregate {
            func: AggFunc::Count,
            arg,
            alias: Some(a),
            ..
        } if arg == "*" && a == "cnt"
    ));
    assert_eq!(select.group_by, Some(vec!["status".into()]));
}
1514
/// `count` is accepted as an ordinary column name in an INSERT list.
#[test]
fn test_count_column_as_ident() {
    let stmt = parse_query("INSERT INTO test (title, count) VALUES ('Hello', 42)").unwrap();
    let Statement::Insert(insert) = stmt else {
        panic!("Expected Insert");
    };
    assert_eq!(insert.columns, vec!["title", "count"]);
}
1525
/// MIN / MAX / AVG in one select list, with no GROUP BY present.
#[test]
fn test_multiple_aggregates() {
    let stmt =
        parse_query("SELECT MIN(composite), MAX(composite), AVG(composite) FROM strategies")
            .unwrap();
    let Statement::Select(select) = stmt else {
        panic!("Expected Select");
    };
    let ColumnList::Named(cols) = &select.columns else {
        panic!("Expected Named columns");
    };
    assert_eq!(cols.len(), 3);
    assert!(matches!(&cols[0], SelectExpr::Aggregate { func: AggFunc::Min, .. }));
    assert!(matches!(&cols[1], SelectExpr::Aggregate { func: AggFunc::Max, .. }));
    assert!(matches!(&cols[2], SelectExpr::Aggregate { func: AggFunc::Avg, .. }));
    assert_eq!(select.group_by, None);
}
1543
/// `a + b` in the select list becomes an unaliased Add expression.
#[test]
fn test_select_arithmetic_expr() {
    let stmt = parse_query("SELECT a + b FROM test").unwrap();
    let Statement::Select(select) = stmt else {
        panic!("Expected Select");
    };
    let ColumnList::Named(cols) = &select.columns else {
        panic!("Expected Named columns");
    };
    assert_eq!(cols.len(), 1);
    assert!(matches!(
        &cols[0],
        SelectExpr::Expr {
            expr: Expr::BinaryOp { op: ArithOp::Add, .. },
            alias: None,
        }
    ));
}
1563
/// An aliased arithmetic expression reports the alias as its output name.
#[test]
fn test_select_arithmetic_with_alias() {
    let stmt = parse_query("SELECT a + b AS total FROM test").unwrap();
    let Statement::Select(select) = stmt else {
        panic!("Expected Select");
    };
    let ColumnList::Named(cols) = &select.columns else {
        panic!("Expected Named columns");
    };
    assert_eq!(cols.len(), 1);
    assert!(matches!(
        &cols[0],
        SelectExpr::Expr { alias: Some(a), .. } if a == "total"
    ));
    assert_eq!(cols[0].output_name(), "total");
}
1582
/// `a + b * c` groups as `a + (b * c)`: Add at the root, Mul on the right.
#[test]
fn test_select_precedence() {
    let stmt = parse_query("SELECT a + b * c FROM test").unwrap();
    let Statement::Select(select) = stmt else {
        panic!("Expected Select");
    };
    let ColumnList::Named(cols) = &select.columns else {
        panic!("Expected Named columns");
    };
    let SelectExpr::Expr { expr, .. } = &cols[0] else {
        panic!("Expected Expr variant");
    };
    let Expr::BinaryOp { left, op, right } = expr else {
        panic!("Expected BinaryOp");
    };
    assert_eq!(*op, ArithOp::Add);
    assert!(matches!(left.as_ref(), Expr::Column(n) if n == "a"));
    assert!(matches!(right.as_ref(), Expr::BinaryOp { op: ArithOp::Mul, .. }));
}
1607
/// `(a + b) * c` keeps the parenthesized Add as the left operand of Mul.
#[test]
fn test_select_parenthesized_expr() {
    let stmt = parse_query("SELECT (a + b) * c FROM test").unwrap();
    let Statement::Select(select) = stmt else {
        panic!("Expected Select");
    };
    let ColumnList::Named(cols) = &select.columns else {
        panic!("Expected Named columns");
    };
    let SelectExpr::Expr { expr, .. } = &cols[0] else {
        panic!("Expected Expr variant");
    };
    let Expr::BinaryOp { left, op, .. } = expr else {
        panic!("Expected BinaryOp");
    };
    assert_eq!(*op, ArithOp::Mul);
    assert!(matches!(left.as_ref(), Expr::BinaryOp { op: ArithOp::Add, .. }));
}
1631
/// A minus sign in front of a column parses to `Expr::UnaryMinus`.
#[test]
fn test_select_unary_minus() {
    let stmt = parse_query("SELECT -count FROM test").unwrap();
    let Statement::Select(select) = stmt else {
        panic!("Expected Select");
    };
    let ColumnList::Named(cols) = &select.columns else {
        panic!("Expected Named columns");
    };
    assert!(matches!(
        &cols[0],
        SelectExpr::Expr { expr: Expr::UnaryMinus(_), .. }
    ));
}
1648
/// `-42` is folded into the integer literal `Int(-42)`.
#[test]
fn test_select_negative_literal() {
    let stmt = parse_query("SELECT -42 FROM test").unwrap();
    let Statement::Select(select) = stmt else {
        panic!("Expected Select");
    };
    let ColumnList::Named(cols) = &select.columns else {
        panic!("Expected Named columns");
    };
    assert!(matches!(
        &cols[0],
        SelectExpr::Expr { expr: Expr::Literal(SqlValue::Int(-42)), .. }
    ));
}
1666
/// WHERE with arithmetic on the left: Add expression compared to `10`.
#[test]
fn test_where_arithmetic_expr() {
    let stmt = parse_query("SELECT * FROM test WHERE a + b > 10").unwrap();
    let Statement::Select(select) = stmt else {
        panic!("Expected Select");
    };
    let Some(WhereClause::Comparison(cmp)) = select.where_clause else {
        panic!("Expected comparison");
    };
    assert_eq!(cmp.op, CmpOp::Gt);
    assert!(matches!(&cmp.left_expr, Some(Expr::BinaryOp { op: ArithOp::Add, .. })));
    assert!(matches!(&cmp.right_expr, Some(Expr::Literal(SqlValue::Int(10)))));
}
1682
/// WHERE with arithmetic on both sides: Mul on the left, Add on the right.
#[test]
fn test_where_both_sides_expr() {
    let stmt = parse_query("SELECT * FROM test WHERE a * 2 > b + 1").unwrap();
    let Statement::Select(select) = stmt else {
        panic!("Expected Select");
    };
    let Some(WhereClause::Comparison(cmp)) = select.where_clause else {
        panic!("Expected comparison");
    };
    assert_eq!(cmp.op, CmpOp::Gt);
    assert!(matches!(&cmp.left_expr, Some(Expr::BinaryOp { op: ArithOp::Mul, .. })));
    assert!(matches!(&cmp.right_expr, Some(Expr::BinaryOp { op: ArithOp::Add, .. })));
}
1698
/// ORDER BY accepts an arithmetic expression with a direction.
#[test]
fn test_order_by_expr() {
    let stmt = parse_query("SELECT * FROM test ORDER BY a + b DESC").unwrap();
    let Statement::Select(select) = stmt else {
        panic!("Expected Select");
    };
    let specs = select.order_by.unwrap();
    assert_eq!(specs.len(), 1);
    assert!(specs[0].descending);
    assert!(matches!(&specs[0].expr, Some(Expr::BinaryOp { op: ArithOp::Add, .. })));
}
1711
/// Each of `+ - * / %` maps to its own `ArithOp` variant.
#[test]
fn test_all_arithmetic_ops() {
    let stmt = parse_query("SELECT a + b, a - b, a * b, a / b, a % b FROM test").unwrap();
    let Statement::Select(select) = stmt else {
        panic!("Expected Select");
    };
    let ColumnList::Named(cols) = &select.columns else {
        panic!("Expected Named columns");
    };
    assert_eq!(cols.len(), 5);
    assert!(matches!(
        &cols[0],
        SelectExpr::Expr { expr: Expr::BinaryOp { op: ArithOp::Add, .. }, .. }
    ));
    assert!(matches!(
        &cols[1],
        SelectExpr::Expr { expr: Expr::BinaryOp { op: ArithOp::Sub, .. }, .. }
    ));
    assert!(matches!(
        &cols[2],
        SelectExpr::Expr { expr: Expr::BinaryOp { op: ArithOp::Mul, .. }, .. }
    ));
    assert!(matches!(
        &cols[3],
        SelectExpr::Expr { expr: Expr::BinaryOp { op: ArithOp::Div, .. }, .. }
    ));
    assert!(matches!(
        &cols[4],
        SelectExpr::Expr { expr: Expr::BinaryOp { op: ArithOp::Mod, .. }, .. }
    ));
}
1730
/// `count * 2 + 1` groups as `(count * 2) + 1`.
#[test]
fn test_column_with_literal_arithmetic() {
    let stmt = parse_query("SELECT count * 2 + 1 FROM test").unwrap();
    let Statement::Select(select) = stmt else {
        panic!("Expected Select");
    };
    let ColumnList::Named(cols) = &select.columns else {
        panic!("Expected Named columns");
    };
    let SelectExpr::Expr { expr, .. } = &cols[0] else {
        panic!("Expected Expr");
    };
    let Expr::BinaryOp { left, op, right } = expr else {
        panic!("Expected BinaryOp");
    };
    assert_eq!(*op, ArithOp::Add);
    assert!(matches!(right.as_ref(), Expr::Literal(SqlValue::Int(1))));
    assert!(matches!(left.as_ref(), Expr::BinaryOp { op: ArithOp::Mul, .. }));
}
1755
/// Plain columns and an aliased expression can share one select list.
#[test]
fn test_mixed_columns_and_exprs() {
    let stmt = parse_query("SELECT title, a + b AS sum, count FROM test").unwrap();
    let Statement::Select(select) = stmt else {
        panic!("Expected Select");
    };
    let ColumnList::Named(cols) = &select.columns else {
        panic!("Expected Named columns");
    };
    assert_eq!(cols.len(), 3);
    assert_eq!(cols[0], SelectExpr::Column("title".into()));
    assert!(matches!(&cols[1], SelectExpr::Expr { alias: Some(a), .. } if a == "sum"));
    assert_eq!(cols[2], SelectExpr::Column("count".into()));
}
1772
/// A single-branch CASE with ELSE parses to `Expr::Case`.
#[test]
fn test_case_when_basic() {
    let stmt =
        parse_query("SELECT CASE WHEN status = 'ACTIVE' THEN 1 ELSE 0 END FROM test").unwrap();
    let Statement::Select(select) = stmt else {
        panic!("Expected Select");
    };
    let ColumnList::Named(cols) = &select.columns else {
        panic!("Expected Named columns");
    };
    assert_eq!(cols.len(), 1);
    assert!(matches!(
        &cols[0],
        SelectExpr::Expr { expr: Expr::Case { .. }, .. }
    ));
}
1794
/// Two WHEN branches plus ELSE: both branches and the ELSE are kept.
#[test]
fn test_case_when_multiple_branches() {
    let sql = "SELECT CASE WHEN x > 10 THEN 'high' WHEN x > 5 THEN 'mid' ELSE 'low' END FROM test";
    let Statement::Select(select) = parse_query(sql).unwrap() else {
        panic!("Expected Select");
    };
    let ColumnList::Named(cols) = &select.columns else {
        panic!("Expected Named columns");
    };
    let SelectExpr::Expr { expr: Expr::Case { whens, else_expr }, .. } = &cols[0] else {
        panic!("Expected Case expression");
    };
    assert_eq!(whens.len(), 2);
    assert!(else_expr.is_some());
}
1815
/// A CASE without ELSE leaves `else_expr` as `None`.
#[test]
fn test_case_when_no_else() {
    let stmt = parse_query("SELECT CASE WHEN x = 1 THEN 'one' END FROM test").unwrap();
    let Statement::Select(select) = stmt else {
        panic!("Expected Select");
    };
    let ColumnList::Named(cols) = &select.columns else {
        panic!("Expected Named columns");
    };
    let SelectExpr::Expr { expr: Expr::Case { whens, else_expr }, .. } = &cols[0] else {
        panic!("Expected Case expression");
    };
    assert_eq!(whens.len(), 1);
    assert!(else_expr.is_none());
}
1836
/// A CASE expression is accepted as the argument of an aliased SUM.
#[test]
fn test_case_when_in_aggregate() {
    let sql = "SELECT SUM(CASE WHEN side = 'BUY' THEN size ELSE -size END) AS net FROM orders GROUP BY token";
    let Statement::Select(select) = parse_query(sql).unwrap() else {
        panic!("Expected Select");
    };
    let ColumnList::Named(cols) = &select.columns else {
        panic!("Expected Named columns");
    };
    assert_eq!(cols.len(), 1);
    assert!(matches!(
        &cols[0],
        SelectExpr::Aggregate {
            func: AggFunc::Sum,
            arg_expr: Some(Expr::Case { .. }),
            alias: Some(a),
            ..
        } if a == "net"
    ));
}
1858
/// A CASE expression in the select list can carry an `AS` alias.
#[test]
fn test_case_when_with_alias() {
    let sql = "SELECT CASE WHEN x > 0 THEN 'pos' ELSE 'neg' END AS sign FROM test";
    let Statement::Select(select) = parse_query(sql).unwrap() else {
        panic!("Expected Select");
    };
    let ColumnList::Named(cols) = &select.columns else {
        panic!("Expected Named columns");
    };
    assert!(matches!(
        &cols[0],
        SelectExpr::Expr {
            expr: Expr::Case { .. },
            alias: Some(a),
        } if a == "sign"
    ));
}
1877
/// CREATE VIEW without a column list wraps the inner SELECT.
#[test]
fn test_create_view() {
    let stmt =
        parse_query("CREATE VIEW live AS SELECT * FROM strategies WHERE status = 'LIVE'").unwrap();
    let view = match stmt {
        Statement::CreateView(view) => view,
        other => panic!("Expected CreateView, got {:?}", other),
    };
    assert_eq!(view.view_name, "live");
    assert!(view.columns.is_none());
    assert_eq!(view.query.table, "strategies");
    assert!(view.query.where_clause.is_some());
}
1890
/// CREATE VIEW with an explicit column list records those names.
#[test]
fn test_create_view_with_columns() {
    let stmt = parse_query("CREATE VIEW v1 (a, b) AS SELECT title, status FROM t").unwrap();
    let Statement::CreateView(view) = stmt else {
        panic!("Expected CreateView");
    };
    assert_eq!(view.view_name, "v1");
    assert_eq!(view.columns, Some(vec!["a".into(), "b".into()]));
}
1901
/// DROP VIEW records the view name.
#[test]
fn test_drop_view() {
    let dropped = match parse_query("DROP VIEW live").unwrap() {
        Statement::DropView(dropped) => dropped,
        other => panic!("Expected DropView, got {:?}", other),
    };
    assert_eq!(dropped.view_name, "live");
}
1911
/// Lower-case keywords are accepted and the view name keeps its casing.
#[test]
fn test_create_view_case_insensitive() {
    let stmt = parse_query("create view My_View as select * from t").unwrap();
    let Statement::CreateView(view) = stmt else {
        panic!("Expected CreateView");
    };
    assert_eq!(view.view_name, "My_View");
}
1921
/// A quotient of two aggregates is itself reported as an aggregate.
#[test]
fn test_aggregate_division() {
    let sql = "SELECT token, SUM(sell) / SUM(buy) as ratio FROM orders GROUP BY token";
    let Statement::Select(select) = parse_query(sql).unwrap() else {
        panic!("Expected Select");
    };
    assert_eq!(select.group_by, Some(vec!["token".into()]));
    let ColumnList::Named(cols) = &select.columns else {
        panic!("Expected Named columns");
    };
    assert_eq!(cols.len(), 2);
    assert!(cols[1].is_aggregate());
}
1941
/// A difference of two aggregates with a lower-case `as` alias reports the
/// alias as its output name.
#[test]
fn test_aggregate_subtraction() {
    let sql = "SELECT token, SUM(sell) - SUM(buy) as net FROM orders GROUP BY token";
    let Statement::Select(select) = parse_query(sql).unwrap() else {
        panic!("Expected Select");
    };
    // Fail loudly when the column list is not `Named`; without this branch
    // the test would pass without executing a single assertion.
    let ColumnList::Named(cols) = &select.columns else {
        panic!("Expected Named columns");
    };
    assert_eq!(cols[1].output_name(), "net");
}
1955
/// CREATE VIEW accepts aggregate arithmetic in the inner SELECT.
#[test]
fn test_create_view_with_arithmetic() {
    let sql = "CREATE VIEW positions AS SELECT token, SUM(sell) / SUM(buy) as ratio FROM orders GROUP BY token";
    let view = match parse_query(sql).unwrap() {
        Statement::CreateView(view) => view,
        other => panic!("Expected CreateView, got {:?}", other),
    };
    assert_eq!(view.view_name, "positions");
}
1967
/// A parenthesized SELECT in FROM is stored as `subquery`; the outer
/// LIMIT stays on the outer query.
#[test]
fn test_subquery_in_from() {
    let sql = "SELECT token, sell_size FROM (SELECT token, SUM(size) as sell_size FROM orders GROUP BY token) LIMIT 5";
    let Statement::Select(select) = parse_query(sql).unwrap() else {
        panic!("Expected Select");
    };
    assert!(select.subquery.is_some());
    assert_eq!(select.limit, Some(5));
    let inner = select.subquery.unwrap();
    assert_eq!(inner.table, "orders");
    assert!(inner.group_by.is_some());
}
1985
/// CREATE VIEW keeps a HAVING clause of the inner SELECT.
#[test]
fn test_create_view_with_having() {
    let sql = "CREATE VIEW positions AS SELECT token, SUM(sell) as sell_size, SUM(buy) as buy_size FROM orders GROUP BY token HAVING sell_size > buy_size";
    let view = match parse_query(sql).unwrap() {
        Statement::CreateView(view) => view,
        other => panic!("Expected CreateView, got {:?}", other),
    };
    assert_eq!(view.view_name, "positions");
    assert!(view.query.having.is_some());
}
2000
/// An aggregate multiplied by a literal is still an aggregate and keeps
/// its alias as output name.
#[test]
fn test_aggregate_multiplication() {
    let stmt = parse_query("SELECT SUM(a) * 2 as doubled FROM test").unwrap();
    let Statement::Select(select) = stmt else {
        panic!("Expected Select");
    };
    let ColumnList::Named(cols) = &select.columns else {
        panic!("Expected Named columns");
    };
    assert_eq!(cols.len(), 1);
    assert!(cols[0].is_aggregate());
    assert_eq!(cols[0].output_name(), "doubled");
}
2020
/// A ratio of two SUM(CASE ...) aggregates with GROUP BY.
#[test]
fn test_complex_aggregate_arithmetic() {
    let sql = "SELECT SUM(CASE WHEN side = 'SELL' THEN size ELSE 0 END) / SUM(CASE WHEN side = 'BUY' THEN size ELSE 0 END) as ratio FROM orders GROUP BY token";
    let Statement::Select(select) = parse_query(sql).unwrap() else {
        panic!("Expected Select");
    };
    let ColumnList::Named(cols) = &select.columns else {
        panic!("Expected Named columns");
    };
    assert_eq!(cols.len(), 1);
    assert!(cols[0].is_aggregate());
    assert_eq!(cols[0].output_name(), "ratio");
    assert_eq!(select.group_by, Some(vec!["token".into()]));
}
2039
/// A FROM subquery followed by an alias still parses as a subquery.
#[test]
fn test_subquery_with_alias() {
    let stmt = parse_query("SELECT x FROM (SELECT x FROM t) sub").unwrap();
    let Statement::Select(select) = stmt else {
        panic!("Expected Select");
    };
    assert!(select.subquery.is_some());
    let inner = select.subquery.unwrap();
    assert_eq!(inner.table, "t");
    let ColumnList::Named(cols) = &select.columns else {
        panic!("Expected Named columns");
    };
    assert_eq!(cols.len(), 1);
    assert_eq!(cols[0].output_name(), "x");
}
2061
/// The WHERE inside a FROM subquery belongs to the inner query; the
/// outer LIMIT belongs to the outer one.
#[test]
fn test_subquery_with_where() {
    let stmt = parse_query("SELECT x FROM (SELECT x FROM t WHERE y > 0) LIMIT 5").unwrap();
    let Statement::Select(select) = stmt else {
        panic!("Expected Select");
    };
    assert!(select.subquery.is_some());
    assert_eq!(select.limit, Some(5));
    let inner = select.subquery.unwrap();
    assert_eq!(inner.table, "t");
    assert!(inner.where_clause.is_some());
}
2077
/// CREATE VIEW over a grouped SELECT with an aliased aggregate difference.
#[test]
fn test_create_view_aggregate_subtraction() {
    let sql = "CREATE VIEW v AS SELECT token, SUM(sell) - SUM(buy) as net FROM orders GROUP BY token";
    let view = match parse_query(sql).unwrap() {
        Statement::CreateView(view) => view,
        other => panic!("Expected CreateView, got {:?}", other),
    };
    assert_eq!(view.view_name, "v");
    assert_eq!(view.query.group_by, Some(vec!["token".into()]));
    let ColumnList::Named(cols) = &view.query.columns else {
        panic!("Expected Named columns");
    };
    assert_eq!(cols.len(), 2);
    assert_eq!(cols[1].output_name(), "net");
    assert!(cols[1].is_aggregate());
}
2099
/// A trailing CASCADE after WHERE sets `DeleteMode::Cascade`.
#[test]
fn test_delete_cascade() {
    let stmt = parse_query("DELETE FROM strategies WHERE status = 'KILLED' CASCADE").unwrap();
    let Statement::Delete(del) = stmt else {
        panic!("Expected Delete");
    };
    assert_eq!(del.table, "strategies");
    assert!(del.where_clause.is_some());
    assert_eq!(del.mode, DeleteMode::Cascade);
}
2111
/// A trailing RESTRICT sets `DeleteMode::Restrict`.
#[test]
fn test_delete_restrict() {
    let stmt = parse_query("DELETE FROM strategies WHERE path = 'alpha.md' RESTRICT").unwrap();
    let Statement::Delete(del) = stmt else {
        panic!("Expected Delete");
    };
    assert_eq!(del.table, "strategies");
    assert_eq!(del.mode, DeleteMode::Restrict);
}
2122
/// Without CASCADE / RESTRICT the mode is `DeleteMode::Default`.
#[test]
fn test_delete_default_unchanged() {
    let stmt = parse_query("DELETE FROM strategies WHERE status = 'KILLED'").unwrap();
    let Statement::Delete(del) = stmt else {
        panic!("Expected Delete");
    };
    assert_eq!(del.mode, DeleteMode::Default);
}
2132
/// CASCADE directly after the table name: no WHERE, mode is Cascade.
#[test]
fn test_delete_cascade_no_where() {
    let stmt = parse_query("DELETE FROM strategies CASCADE").unwrap();
    let Statement::Delete(del) = stmt else {
        panic!("Expected Delete");
    };
    assert_eq!(del.table, "strategies");
    assert!(del.where_clause.is_none());
    assert_eq!(del.mode, DeleteMode::Cascade);
}
2144
/// A single CTE is stored in `ctes`; the main query selects from it.
#[test]
fn test_cte_basic() {
    let sql = "WITH live AS (SELECT * FROM strategies WHERE status = 'LIVE') SELECT * FROM live";
    let Statement::Select(select) = parse_query(sql).unwrap() else {
        panic!("Expected Select");
    };
    assert_eq!(select.ctes.len(), 1);
    assert_eq!(select.ctes[0].name, "live");
    assert_eq!(select.ctes[0].query.table, "strategies");
    assert!(select.ctes[0].query.where_clause.is_some());
    assert_eq!(select.table, "live");
}
2162
/// Two comma-separated CTEs, then a main query joining them.
#[test]
fn test_cte_multi() {
    let sql = "WITH a AS (SELECT * FROM t1), b AS (SELECT * FROM t2) SELECT * FROM a JOIN b ON a.id = b.id";
    let Statement::Select(select) = parse_query(sql).unwrap() else {
        panic!("Expected Select");
    };
    assert_eq!(select.ctes.len(), 2);
    assert_eq!(select.ctes[0].name, "a");
    assert_eq!(select.ctes[0].query.table, "t1");
    assert_eq!(select.ctes[1].name, "b");
    assert_eq!(select.ctes[1].query.table, "t2");
    assert_eq!(select.table, "a");
    assert_eq!(select.joins.len(), 1);
}
2180
/// A grouped CTE filtered by the outer query's WHERE.
#[test]
fn test_cte_with_aggregation() {
    let sql = "WITH totals AS (SELECT strategy, COUNT(*) AS cnt FROM backtests GROUP BY strategy) SELECT * FROM totals WHERE cnt > 1";
    let Statement::Select(select) = parse_query(sql).unwrap() else {
        panic!("Expected Select");
    };
    assert_eq!(select.ctes.len(), 1);
    assert_eq!(select.ctes[0].name, "totals");
    assert!(select.ctes[0].query.group_by.is_some());
    assert_eq!(select.table, "totals");
    assert!(select.where_clause.is_some());
}
2196
/// A SELECT without WITH has an empty `ctes` list.
#[test]
fn test_cte_no_ctes_on_plain_select() {
    let stmt = parse_query("SELECT * FROM t").unwrap();
    let Statement::Select(select) = stmt else {
        panic!("Expected Select");
    };
    assert!(select.ctes.is_empty());
}
2206
/// `IN (SELECT ...)` stores the subquery as the right-hand expression.
#[test]
fn test_where_in_subquery() {
    let sql = "SELECT * FROM strategies WHERE path IN (SELECT strategy FROM backtests)";
    let Statement::Select(select) = parse_query(sql).unwrap() else {
        panic!("Expected Select");
    };
    let Some(WhereClause::Comparison(cmp)) = &select.where_clause else {
        panic!("Expected IN comparison");
    };
    assert_eq!(cmp.op, CmpOp::In);
    assert!(matches!(&cmp.right_expr, Some(Expr::Subquery(_))));
}
2225
/// A parenthesized SELECT on the right of `>` parses as a subquery.
#[test]
fn test_scalar_subquery_in_where() {
    let sql = "SELECT * FROM backtests WHERE sharpe > (SELECT AVG(sharpe) FROM backtests)";
    let Statement::Select(select) = parse_query(sql).unwrap() else {
        panic!("Expected Select");
    };
    let Some(WhereClause::Comparison(cmp)) = &select.where_clause else {
        panic!("Expected comparison");
    };
    assert_eq!(cmp.op, CmpOp::Gt);
    assert!(matches!(&cmp.right_expr, Some(Expr::Subquery(_))));
}
2242
/// A parenthesized SELECT in the select list parses as an aliased subquery.
#[test]
fn test_scalar_subquery_in_select() {
    let sql = "SELECT title, (SELECT COUNT(*) FROM backtests) AS cnt FROM strategies";
    let Statement::Select(select) = parse_query(sql).unwrap() else {
        panic!("Expected Select");
    };
    let ColumnList::Named(cols) = &select.columns else {
        panic!("Expected Named columns");
    };
    assert_eq!(cols.len(), 2);
    assert!(matches!(
        &cols[1],
        SelectExpr::Expr {
            expr: Expr::Subquery(_),
            alias: Some(a),
        } if a == "cnt"
    ));
}
2262
2263 #[test]
2266 fn test_row_number_over_order_by() {
2267 let stmt = parse_query(
2268 "SELECT title, ROW_NUMBER() OVER (ORDER BY count DESC) AS rn FROM test"
2269 ).unwrap();
2270 if let Statement::Select(q) = stmt {
2271 if let ColumnList::Named(exprs) = &q.columns {
2272 assert_eq!(exprs.len(), 2);
2273 if let SelectExpr::Expr { expr: Expr::Window { func, args, over }, alias } = &exprs[1] {
2274 assert_eq!(*func, WindowFunc::RowNumber);
2275 assert!(args.is_empty());
2276 assert!(over.partition_by.is_empty());
2277 assert_eq!(over.order_by.len(), 1);
2278 assert!(over.order_by[0].descending);
2279 assert_eq!(alias.as_deref(), Some("rn"));
2280 } else {
2281 panic!("Expected Window expression, got {:?}", exprs[1]);
2282 }
2283 } else {
2284 panic!("Expected Named columns");
2285 }
2286 } else {
2287 panic!("Expected Select");
2288 }
2289 }
2290
2291 #[test]
2292 fn test_rank_with_partition_by() {
2293 let stmt = parse_query(
2294 "SELECT RANK() OVER (PARTITION BY category ORDER BY price DESC) AS rnk FROM test"
2295 ).unwrap();
2296 if let Statement::Select(q) = stmt {
2297 if let ColumnList::Named(exprs) = &q.columns {
2298 if let SelectExpr::Expr { expr: Expr::Window { func, over, .. }, .. } = &exprs[0] {
2299 assert_eq!(*func, WindowFunc::Rank);
2300 assert_eq!(over.partition_by, vec!["category"]);
2301 assert_eq!(over.order_by.len(), 1);
2302 } else {
2303 panic!("Expected Window expression");
2304 }
2305 } else {
2306 panic!("Expected Named columns");
2307 }
2308 } else {
2309 panic!("Expected Select");
2310 }
2311 }
2312
2313 #[test]
2314 fn test_agg_over_window() {
2315 let stmt = parse_query(
2316 "SELECT SUM(price) OVER (PARTITION BY category) AS cat_total FROM test"
2317 ).unwrap();
2318 if let Statement::Select(q) = stmt {
2319 if let ColumnList::Named(exprs) = &q.columns {
2320 if let SelectExpr::Expr { expr: Expr::Window { func, args, over }, alias } = &exprs[0] {
2321 assert!(matches!(func, WindowFunc::Agg(AggFunc::Sum)));
2322 assert_eq!(args.len(), 1);
2323 assert_eq!(over.partition_by, vec!["category"]);
2324 assert!(over.order_by.is_empty());
2325 assert_eq!(alias.as_deref(), Some("cat_total"));
2326 } else {
2327 panic!("Expected Window expression");
2328 }
2329 } else {
2330 panic!("Expected Named columns");
2331 }
2332 } else {
2333 panic!("Expected Select");
2334 }
2335 }
2336
2337 #[test]
2338 fn test_lag_with_args() {
2339 let stmt = parse_query(
2340 "SELECT LAG(price, 1) OVER (ORDER BY price) AS prev_price FROM test"
2341 ).unwrap();
2342 if let Statement::Select(q) = stmt {
2343 if let ColumnList::Named(exprs) = &q.columns {
2344 if let SelectExpr::Expr { expr: Expr::Window { func, args, .. }, .. } = &exprs[0] {
2345 assert_eq!(*func, WindowFunc::Lag);
2346 assert_eq!(args.len(), 2);
2347 } else {
2348 panic!("Expected Window expression");
2349 }
2350 } else {
2351 panic!("Expected Named columns");
2352 }
2353 } else {
2354 panic!("Expected Select");
2355 }
2356 }
2357
2358 #[test]
2359 fn test_dense_rank() {
2360 let stmt = parse_query(
2361 "SELECT DENSE_RANK() OVER (ORDER BY count DESC) AS dr FROM test"
2362 ).unwrap();
2363 if let Statement::Select(q) = stmt {
2364 if let ColumnList::Named(exprs) = &q.columns {
2365 if let SelectExpr::Expr { expr: Expr::Window { func, .. }, .. } = &exprs[0] {
2366 assert_eq!(*func, WindowFunc::DenseRank);
2367 } else {
2368 panic!("Expected Window expression");
2369 }
2370 } else {
2371 panic!("Expected Named columns");
2372 }
2373 } else {
2374 panic!("Expected Select");
2375 }
2376 }
2377
2378 #[test]
2379 fn test_sum_without_over_is_aggregate() {
2380 let stmt = parse_query("SELECT SUM(count) FROM test").unwrap();
2381 if let Statement::Select(q) = stmt {
2382 if let ColumnList::Named(exprs) = &q.columns {
2383 assert!(matches!(&exprs[0], SelectExpr::Aggregate { func: AggFunc::Sum, .. }));
2384 } else {
2385 panic!("Expected Named columns");
2386 }
2387 } else {
2388 panic!("Expected Select");
2389 }
2390 }
2391}