1use regex::Regex;
4use std::sync::LazyLock;
5
6use crate::errors::MdqlError;
7pub use crate::query_ast::*;
8
/// Reserved words recognised by the tokenizer.
///
/// A bare word whose upper-cased spelling appears in this list is emitted as a
/// `keyword` token (with its value upper-cased); every other bare word becomes
/// an `ident`. Entries must therefore be spelled in upper case.
static KEYWORDS: &[&str] = &[
    // Query structure and predicates
    "SELECT", "FROM", "WHERE", "AND", "OR", "ORDER", "BY",
    "ASC", "DESC", "LIMIT", "LIKE", "IN", "IS", "NOT", "NULL",
    // Joins, aliases and grouping
    "JOIN", "LEFT", "ON", "AS", "GROUP", "HAVING",
    // Data modification
    "INSERT", "INTO", "VALUES", "UPDATE", "SET", "DELETE",
    // Schema changes
    "ALTER", "TABLE", "RENAME", "FIELD", "TO", "DROP", "MERGE", "FIELDS",
    // Conditional expressions
    "CASE", "WHEN", "THEN", "ELSE", "END",
    // Date arithmetic
    "INTERVAL", "DAY", "DAYS", "CURRENT_DATE", "CURRENT_TIMESTAMP", "DATEDIFF",
    // Views and delete modes
    "CREATE", "VIEW", "CASCADE", "RESTRICT",
    // Common table expressions
    "WITH",
    // Window functions
    "OVER", "PARTITION", "ROW_NUMBER", "RANK", "DENSE_RANK", "LAG", "LEAD",
];
23
/// Aggregate function names, in upper case.
///
/// The parser compares against these case-insensitively and only treats a
/// word as an aggregate when the next token is `(`. These names are not in
/// `KEYWORDS`, so they reach the parser as `ident` tokens.
static AGG_FUNCS: &[&str] = &[
    "COUNT", "SUM", "AVG", "MIN", "MAX",
];
/// Window-only function names, in upper case; like aggregates, they are only
/// recognised as calls when followed by `(`.
static WINDOW_FUNCS: &[&str] = &[
    "ROW_NUMBER", "RANK", "DENSE_RANK", "LAG", "LEAD",
];
26
27static TOKEN_RE: LazyLock<Regex> = LazyLock::new(|| {
28 Regex::new(
29 r#"(?x)
30 \s*(?:
31 (?P<backtick>`[^`]+`)
32 | (?P<string>'(?:[^'\\]|\\.)*')
33 | (?P<number>\d+(?:\.\d+)?)
34 | (?P<op><=|>=|!=|[=<>,*()+\-/%])
35 | (?P<word>[A-Za-z_][A-Za-z0-9_./-]*)
36 )"#,
37 )
38 .unwrap()
39});
40
/// A single lexical token produced by `tokenize`.
#[derive(Debug, Clone)]
struct Token {
    /// Token class. `tokenize` emits one of `"ident"`, `"string"`, `"number"`,
    /// `"op"` or `"keyword"`.
    // NOTE(review): a stringly-typed class; an enum would let the compiler
    // check the comparisons made throughout the parser.
    token_type: String,
    /// Normalised text: surrounding quotes/backticks stripped for strings and
    /// quoted identifiers, upper-cased for keywords, otherwise equal to `raw`.
    value: String,
    /// The matched source text, echoed back in parse error messages.
    raw: String,
}
47
48fn tokenize(sql: &str) -> Vec<Token> {
49 let mut tokens = Vec::new();
50 for caps in TOKEN_RE.captures_iter(sql) {
51 if let Some(m) = caps.name("backtick") {
52 let raw = m.as_str();
53 tokens.push(Token {
54 token_type: "ident".into(),
55 value: raw[1..raw.len() - 1].into(),
56 raw: raw.into(),
57 });
58 } else if let Some(m) = caps.name("string") {
59 let raw = m.as_str();
60 tokens.push(Token {
61 token_type: "string".into(),
62 value: raw[1..raw.len() - 1].into(),
63 raw: raw.into(),
64 });
65 } else if let Some(m) = caps.name("number") {
66 let raw = m.as_str();
67 tokens.push(Token {
68 token_type: "number".into(),
69 value: raw.into(),
70 raw: raw.into(),
71 });
72 } else if let Some(m) = caps.name("op") {
73 let raw = m.as_str();
74 tokens.push(Token {
75 token_type: "op".into(),
76 value: raw.into(),
77 raw: raw.into(),
78 });
79 } else if let Some(m) = caps.name("word") {
80 let raw = m.as_str();
81 if KEYWORDS.contains(&raw.to_uppercase().as_str()) {
82 tokens.push(Token {
83 token_type: "keyword".into(),
84 value: raw.to_uppercase(),
85 raw: raw.into(),
86 });
87 } else {
88 tokens.push(Token {
89 token_type: "ident".into(),
90 value: raw.into(),
91 raw: raw.into(),
92 });
93 }
94 }
95 }
96 tokens
97}
98
/// Recursive-descent parser state: the full token stream plus a cursor.
struct Parser {
    /// Tokens produced by `tokenize`, consumed left to right.
    tokens: Vec<Token>,
    /// Index of the next unconsumed token in `tokens`; this index is what the
    /// "at position N" part of some parse errors reports.
    pos: usize,
}
105
106impl Parser {
107 fn new(tokens: Vec<Token>) -> Self {
108 Parser { tokens, pos: 0 }
109 }
110
111 fn peek(&self) -> Option<&Token> {
112 self.tokens.get(self.pos)
113 }
114
115 fn advance(&mut self) -> Token {
116 let t = self.tokens[self.pos].clone();
117 self.pos += 1;
118 t
119 }
120
121 fn expect(&mut self, type_: &str, value: Option<&str>) -> Result<Token, MdqlError> {
122 let t = self.peek().ok_or_else(|| {
123 MdqlError::QueryParse(format!(
124 "Unexpected end of query, expected {}",
125 value.unwrap_or(type_)
126 ))
127 })?;
128 let matches_type = t.token_type == type_;
129 let matches_value = value.map_or(true, |v| t.value == v);
130 if !matches_type || !matches_value {
131 return Err(MdqlError::QueryParse(format!(
132 "Expected {}, got '{}' at position {}",
133 value.unwrap_or(type_),
134 t.raw,
135 self.pos
136 )));
137 }
138 Ok(self.advance())
139 }
140
141 fn match_keyword(&mut self, kw: &str) -> bool {
142 if let Some(t) = self.peek() {
143 if t.token_type == "keyword" && t.value == kw {
144 self.advance();
145 return true;
146 }
147 }
148 false
149 }
150
151 fn parse_statement(&mut self) -> Result<Statement, MdqlError> {
152 let t = self.peek().ok_or_else(|| MdqlError::QueryParse("Empty query".into()))?;
153 match (t.token_type.as_str(), t.value.as_str()) {
154 ("keyword", "WITH") => {
155 let ctes = self.parse_ctes()?;
156 let mut q = self.parse_select()?;
157 q.ctes = ctes;
158 self.expect_end()?;
159 Ok(Statement::Select(q))
160 }
161 ("keyword", "SELECT") => {
162 let q = self.parse_select()?;
163 self.expect_end()?;
164 Ok(Statement::Select(q))
165 }
166 ("keyword", "INSERT") => Ok(Statement::Insert(self.parse_insert()?)),
167 ("keyword", "UPDATE") => Ok(Statement::Update(self.parse_update()?)),
168 ("keyword", "DELETE") => Ok(Statement::Delete(self.parse_delete()?)),
169 ("keyword", "ALTER") => self.parse_alter(),
170 ("keyword", "CREATE") => self.parse_create_view(),
171 ("keyword", "DROP") => self.parse_drop_view(),
172 _ => Err(MdqlError::QueryParse(format!(
173 "Expected SELECT, INSERT, UPDATE, DELETE, ALTER, CREATE, or DROP, got '{}'",
174 t.raw
175 ))),
176 }
177 }
178
179 fn parse_ctes(&mut self) -> Result<Vec<CteClause>, MdqlError> {
180 self.expect("keyword", Some("WITH"))?;
181 let mut ctes = Vec::new();
182 loop {
183 let name = self.parse_ident()?;
184 self.expect("keyword", Some("AS"))?;
185 self.expect("op", Some("("))?;
186 let query = self.parse_select()?;
187 self.expect("op", Some(")"))?;
188 ctes.push(CteClause { name, query: Box::new(query) });
189 if !self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
190 break;
191 }
192 self.advance();
193 }
194 Ok(ctes)
195 }
196
197 fn parse_select(&mut self) -> Result<SelectQuery, MdqlError> {
198 self.expect("keyword", Some("SELECT"))?;
199 let distinct = if self
203 .peek()
204 .map_or(false, |t| t.token_type == "ident" && t.value.eq_ignore_ascii_case("distinct"))
205 {
206 self.advance();
207 true
208 } else {
209 false
210 };
211 let columns = self.parse_columns()?;
212 self.expect("keyword", Some("FROM"))?;
213
214 let mut subquery = None;
216 let (table, mut table_alias) = if self.peek().map_or(false, |t| t.token_type == "op" && t.value == "(") {
217 self.advance();
218 let inner = self.parse_select()?;
219 self.expect("op", Some(")"))?;
220 subquery = Some(Box::new(inner));
221 let alias = if let Some(t) = self.peek() {
222 if t.token_type == "ident" && !self.is_clause_keyword(t) {
223 Some(self.advance().value)
224 } else {
225 None
226 }
227 } else {
228 None
229 };
230 ("_subquery".to_string(), alias)
231 } else {
232 let t = self.parse_ident()?;
233 (t, None)
234 };
235
236 if subquery.is_none() {
238 if let Some(t) = self.peek() {
239 if t.token_type == "ident" && !self.is_clause_keyword(t) {
240 table_alias = Some(self.advance().value);
241 }
242 }
243 }
244
245 let mut joins = Vec::new();
247 loop {
248 let jt = if self.match_keyword("LEFT") {
249 self.expect("keyword", Some("JOIN"))?;
250 JoinType::Left
251 } else if self.match_keyword("JOIN") {
252 JoinType::Inner
253 } else {
254 break;
255 };
256 let join_table = self.parse_ident()?;
257 let mut join_alias = None;
258 if let Some(t) = self.peek() {
259 if t.token_type == "ident" && !self.is_clause_keyword(t) {
260 join_alias = Some(self.advance().value);
261 }
262 }
263 self.expect("keyword", Some("ON"))?;
264 let condition = self.parse_or_expr()?;
265 joins.push(JoinClause {
266 join_type: jt,
267 table: join_table,
268 alias: join_alias,
269 condition,
270 });
271 }
272
273 let mut where_clause = None;
274 if self.match_keyword("WHERE") {
275 where_clause = Some(self.parse_or_expr()?);
276 }
277
278 let mut group_by = None;
279 if self.match_keyword("GROUP") {
280 self.expect("keyword", Some("BY"))?;
281 let mut cols = vec![self.parse_ident()?];
282 while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
283 self.advance();
284 cols.push(self.parse_ident()?);
285 }
286 group_by = Some(cols);
287 }
288
289 let mut having = None;
290 if self.match_keyword("HAVING") {
291 having = Some(self.parse_or_expr()?);
292 }
293
294 let mut order_by = None;
295 if self.match_keyword("ORDER") {
296 self.expect("keyword", Some("BY"))?;
297 order_by = Some(self.parse_order_by()?);
298 }
299
300 let mut limit = None;
301 if self.match_keyword("LIMIT") {
302 let t = self.expect("number", None)?;
303 limit = Some(t.value.parse::<i64>().map_err(|_| {
304 MdqlError::QueryParse(format!("Invalid LIMIT value: {}", t.value))
305 })?);
306 }
307
308 Ok(SelectQuery {
309 distinct,
310 columns,
311 table,
312 table_alias,
313 subquery,
314 joins,
315 where_clause,
316 group_by,
317 having,
318 order_by,
319 limit,
320 ctes: vec![],
321 })
322 }
323
324 fn parse_insert(&mut self) -> Result<InsertQuery, MdqlError> {
325 self.expect("keyword", Some("INSERT"))?;
326 self.expect("keyword", Some("INTO"))?;
327 let table = self.parse_ident()?;
328
329 self.expect("op", Some("("))?;
330 let mut columns = vec![self.parse_ident()?];
331 while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
332 self.advance();
333 columns.push(self.parse_ident()?);
334 }
335 self.expect("op", Some(")"))?;
336
337 self.expect("keyword", Some("VALUES"))?;
338
339 self.expect("op", Some("("))?;
340 let mut values = vec![self.parse_value()?];
341 while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
342 self.advance();
343 values.push(self.parse_value()?);
344 }
345 self.expect("op", Some(")"))?;
346
347 if columns.len() != values.len() {
348 return Err(MdqlError::QueryParse(format!(
349 "Column count ({}) does not match value count ({})",
350 columns.len(),
351 values.len()
352 )));
353 }
354
355 self.expect_end()?;
356 Ok(InsertQuery {
357 table,
358 columns,
359 values,
360 })
361 }
362
363 fn parse_update(&mut self) -> Result<UpdateQuery, MdqlError> {
364 self.expect("keyword", Some("UPDATE"))?;
365 let table = self.parse_ident()?;
366 self.expect("keyword", Some("SET"))?;
367
368 let mut assignments = Vec::new();
369 let col = self.parse_ident()?;
370 self.expect("op", Some("="))?;
371 let val = self.parse_value()?;
372 assignments.push((col, val));
373
374 while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
375 self.advance();
376 let col = self.parse_ident()?;
377 self.expect("op", Some("="))?;
378 let val = self.parse_value()?;
379 assignments.push((col, val));
380 }
381
382 let mut where_clause = None;
383 if self.match_keyword("WHERE") {
384 where_clause = Some(self.parse_or_expr()?);
385 }
386
387 self.expect_end()?;
388 Ok(UpdateQuery {
389 table,
390 assignments,
391 where_clause,
392 })
393 }
394
395 fn parse_delete(&mut self) -> Result<DeleteQuery, MdqlError> {
396 self.expect("keyword", Some("DELETE"))?;
397 self.expect("keyword", Some("FROM"))?;
398 let table = self.parse_ident()?;
399
400 let mut where_clause = None;
401 if self.match_keyword("WHERE") {
402 where_clause = Some(self.parse_or_expr()?);
403 }
404
405 let mode = if self.match_keyword("CASCADE") {
406 DeleteMode::Cascade
407 } else if self.match_keyword("RESTRICT") {
408 DeleteMode::Restrict
409 } else {
410 DeleteMode::Default
411 };
412
413 self.expect_end()?;
414 Ok(DeleteQuery {
415 table,
416 where_clause,
417 mode,
418 })
419 }
420
421 fn parse_alter(&mut self) -> Result<Statement, MdqlError> {
422 self.expect("keyword", Some("ALTER"))?;
423 self.expect("keyword", Some("TABLE"))?;
424 let table = self.parse_ident()?;
425
426 let t = self.peek().ok_or_else(|| {
427 MdqlError::QueryParse("Expected RENAME, DROP, or MERGE after table name".into())
428 })?;
429
430 match (t.token_type.as_str(), t.value.as_str()) {
431 ("keyword", "RENAME") => {
432 self.advance();
433 self.expect("keyword", Some("FIELD"))?;
434 let old_name = self.parse_string_or_ident()?;
435 self.expect("keyword", Some("TO"))?;
436 let new_name = self.parse_string_or_ident()?;
437 self.expect_end()?;
438 Ok(Statement::AlterRename(AlterRenameFieldQuery {
439 table,
440 old_name,
441 new_name,
442 }))
443 }
444 ("keyword", "DROP") => {
445 self.advance();
446 self.expect("keyword", Some("FIELD"))?;
447 let field_name = self.parse_string_or_ident()?;
448 self.expect_end()?;
449 Ok(Statement::AlterDrop(AlterDropFieldQuery {
450 table,
451 field_name,
452 }))
453 }
454 ("keyword", "MERGE") => {
455 self.advance();
456 self.expect("keyword", Some("FIELDS"))?;
457 let mut sources = vec![self.parse_string_or_ident()?];
458 while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
459 self.advance();
460 sources.push(self.parse_string_or_ident()?);
461 }
462 self.expect("keyword", Some("INTO"))?;
463 let target = self.parse_string_or_ident()?;
464 self.expect_end()?;
465 Ok(Statement::AlterMerge(AlterMergeFieldsQuery {
466 table,
467 sources,
468 into: target,
469 }))
470 }
471 _ => Err(MdqlError::QueryParse(format!(
472 "Expected RENAME, DROP, or MERGE, got '{}'",
473 t.raw
474 ))),
475 }
476 }
477
478 fn parse_create_view(&mut self) -> Result<Statement, MdqlError> {
479 self.expect("keyword", Some("CREATE"))?;
480 self.expect("keyword", Some("VIEW"))?;
481 let view_name = self.parse_ident()?;
482
483 let columns = if self.peek().map_or(false, |t| t.token_type == "op" && t.value == "(") {
484 self.advance();
485 let mut cols = vec![self.parse_ident()?];
486 while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
487 self.advance();
488 cols.push(self.parse_ident()?);
489 }
490 self.expect("op", Some(")"))?;
491 Some(cols)
492 } else {
493 None
494 };
495
496 self.expect("keyword", Some("AS"))?;
497 let query = Box::new(self.parse_select()?);
498 self.expect_end()?;
499
500 Ok(Statement::CreateView(CreateViewQuery {
501 view_name,
502 columns,
503 query,
504 }))
505 }
506
507 fn parse_drop_view(&mut self) -> Result<Statement, MdqlError> {
508 self.expect("keyword", Some("DROP"))?;
509 self.expect("keyword", Some("VIEW"))?;
510 let view_name = self.parse_ident()?;
511 self.expect_end()?;
512 Ok(Statement::DropView(DropViewQuery { view_name }))
513 }
514
515 fn parse_string_or_ident(&mut self) -> Result<String, MdqlError> {
516 let t = self.peek().ok_or_else(|| {
517 MdqlError::QueryParse("Expected field name, got end of query".into())
518 })?;
519 match t.token_type.as_str() {
520 "string" => {
521 let v = self.advance().value;
522 Ok(v)
523 }
524 "ident" | "keyword" => {
525 let v = self.advance().value;
526 Ok(v)
527 }
528 _ => Err(MdqlError::QueryParse(format!(
529 "Expected field name, got '{}'",
530 t.raw
531 ))),
532 }
533 }
534
535 fn parse_columns(&mut self) -> Result<ColumnList, MdqlError> {
536 if let Some(t) = self.peek() {
537 if t.token_type == "op" && t.value == "*" {
538 self.advance();
539 return Ok(ColumnList::All);
540 }
541 }
542
543 let mut exprs = vec![self.parse_select_expr()?];
544 while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
545 self.advance();
546 exprs.push(self.parse_select_expr()?);
547 }
548 Ok(ColumnList::Named(exprs))
549 }
550
551 fn peek_is_window_func(&self) -> bool {
552 let t = match self.peek() {
553 Some(t) => t,
554 None => return false,
555 };
556 let name_upper = t.value.to_uppercase();
557 if !WINDOW_FUNCS.contains(&name_upper.as_str()) {
558 return false;
559 }
560 self.tokens
561 .get(self.pos + 1)
562 .map_or(false, |next| next.token_type == "op" && next.value == "(")
563 }
564
565 fn peek_is_agg_func(&self) -> bool {
566 let t = match self.peek() {
567 Some(t) => t,
568 None => return false,
569 };
570 let name_upper = t.value.to_uppercase();
571 if !AGG_FUNCS.contains(&name_upper.as_str()) {
572 return false;
573 }
574 self.tokens
576 .get(self.pos + 1)
577 .map_or(false, |next| next.token_type == "op" && next.value == "(")
578 }
579
580 fn parse_select_expr(&mut self) -> Result<SelectExpr, MdqlError> {
581 let _t = self.peek().ok_or_else(|| {
582 MdqlError::QueryParse("Expected column or aggregate, got end of query".into())
583 })?;
584
585 let expr = self.parse_additive()?;
586
587 let alias = if self.match_keyword("AS") {
588 Some(self.parse_ident()?)
589 } else if self.peek().map_or(false, |t| {
590 t.token_type == "ident" && !self.is_clause_keyword(t)
591 }) {
592 Some(self.advance().value)
593 } else {
594 None
595 };
596
597 if let Expr::Aggregate { func, arg, arg_expr } = expr {
599 return Ok(SelectExpr::Aggregate {
600 func,
601 arg,
602 arg_expr: arg_expr.map(|e| *e),
603 alias,
604 });
605 }
606
607 if alias.is_none() {
608 if let Expr::Column(name) = &expr {
609 return Ok(SelectExpr::Column(name.clone()));
610 }
611 }
612
613 Ok(SelectExpr::Expr { expr, alias })
614 }
615
616 fn peek_is_additive_op(&self) -> bool {
619 self.peek().map_or(false, |t| {
620 t.token_type == "op" && (t.value == "+" || t.value == "-")
621 })
622 }
623
624 fn peek_is_multiplicative_op(&self) -> bool {
625 self.peek().map_or(false, |t| {
626 t.token_type == "op" && (t.value == "*" || t.value == "/" || t.value == "%")
627 })
628 }
629
630 fn parse_additive(&mut self) -> Result<Expr, MdqlError> {
631 let mut left = self.parse_multiplicative()?;
632 while self.peek_is_additive_op() {
633 let op_tok = self.advance();
634 let is_sub = op_tok.value == "-";
635
636 if self.peek().map_or(false, |t| t.token_type == "keyword" && t.value == "INTERVAL") {
638 self.advance(); let days_expr = self.parse_multiplicative()?;
640 if !self.match_keyword("DAY") && !self.match_keyword("DAYS") {
642 return Err(MdqlError::QueryParse("Expected DAY after INTERVAL value".into()));
643 }
644 let days = if is_sub {
645 Expr::UnaryMinus(Box::new(days_expr))
646 } else {
647 days_expr
648 };
649 left = Expr::DateAdd {
650 date: Box::new(left),
651 days: Box::new(days),
652 };
653 continue;
654 }
655
656 let op = match op_tok.value.as_str() {
657 "+" => ArithOp::Add,
658 "-" => ArithOp::Sub,
659 _ => unreachable!(),
660 };
661 let right = self.parse_multiplicative()?;
662 left = Expr::BinaryOp {
663 left: Box::new(left),
664 op,
665 right: Box::new(right),
666 };
667 }
668 Ok(left)
669 }
670
671 fn parse_multiplicative(&mut self) -> Result<Expr, MdqlError> {
672 let mut left = self.parse_unary()?;
673 while self.peek_is_multiplicative_op() {
674 let op_tok = self.advance();
675 let op = match op_tok.value.as_str() {
676 "*" => ArithOp::Mul,
677 "/" => ArithOp::Div,
678 "%" => ArithOp::Mod,
679 _ => unreachable!(),
680 };
681 let right = self.parse_unary()?;
682 left = Expr::BinaryOp {
683 left: Box::new(left),
684 op,
685 right: Box::new(right),
686 };
687 }
688 Ok(left)
689 }
690
691 fn parse_unary(&mut self) -> Result<Expr, MdqlError> {
692 if self.peek().map_or(false, |t| t.token_type == "op" && t.value == "-") {
693 self.advance();
694 let inner = self.parse_atom()?;
695 match inner {
697 Expr::Literal(SqlValue::Int(n)) => Ok(Expr::Literal(SqlValue::Int(-n))),
698 Expr::Literal(SqlValue::Float(f)) => Ok(Expr::Literal(SqlValue::Float(-f))),
699 _ => Ok(Expr::UnaryMinus(Box::new(inner))),
700 }
701 } else {
702 self.parse_atom()
703 }
704 }
705
706 fn parse_atom(&mut self) -> Result<Expr, MdqlError> {
707 if self.peek_is_window_func() {
708 return self.parse_standalone_window();
709 }
710
711 if self.peek_is_agg_func() {
712 let agg = self.parse_agg_expr()?;
713 if self.peek().map_or(false, |t| t.token_type == "keyword" && t.value == "OVER") {
714 self.advance();
715 let over = self.parse_window_spec()?;
716 if let Expr::Aggregate { func, arg, arg_expr } = agg {
717 return Ok(Expr::Window {
718 func: WindowFunc::Agg(func),
719 args: if arg == "*" {
720 vec![]
721 } else {
722 vec![arg_expr.map(|e| *e).unwrap_or(Expr::Column(arg))]
723 },
724 over,
725 });
726 }
727 }
728 return Ok(agg);
729 }
730
731 let t = self.peek().ok_or_else(|| {
732 MdqlError::QueryParse("Expected expression, got end of query".into())
733 })?;
734
735 match t.token_type.as_str() {
736 "number" => {
737 let v = self.advance().value;
738 if v.contains('.') {
739 let f: f64 = v.parse().map_err(|_| {
740 MdqlError::QueryParse(format!("Invalid float: {}", v))
741 })?;
742 Ok(Expr::Literal(SqlValue::Float(f)))
743 } else {
744 let n: i64 = v.parse().map_err(|_| {
745 MdqlError::QueryParse(format!("Invalid int: {}", v))
746 })?;
747 Ok(Expr::Literal(SqlValue::Int(n)))
748 }
749 }
750 "string" => {
751 let v = self.advance().value;
752 Ok(Expr::Literal(SqlValue::String(v)))
753 }
754 "keyword" if t.value == "NULL" => {
755 self.advance();
756 Ok(Expr::Literal(SqlValue::Null))
757 }
758 "keyword" if t.value == "CASE" => {
759 self.parse_case_expr()
760 }
761 "keyword" if t.value == "CURRENT_DATE" => {
762 self.advance();
763 Ok(Expr::CurrentDate)
764 }
765 "keyword" if t.value == "CURRENT_TIMESTAMP" => {
766 self.advance();
767 Ok(Expr::CurrentTimestamp)
768 }
769 "keyword" if t.value == "DATEDIFF" => {
770 self.advance();
771 self.expect("op", Some("("))?;
772 let left = self.parse_additive()?;
773 self.expect("op", Some(","))?;
774 let right = self.parse_additive()?;
775 self.expect("op", Some(")"))?;
776 Ok(Expr::DateDiff { left: Box::new(left), right: Box::new(right) })
777 }
778 "op" if t.value == "(" => {
779 let next_is_select = self.tokens.get(self.pos + 1)
780 .map_or(false, |t| t.token_type == "keyword" && t.value == "SELECT");
781 if next_is_select {
782 self.advance();
783 let sq = self.parse_select()?;
784 self.expect("op", Some(")"))?;
785 Ok(Expr::Subquery(Box::new(sq)))
786 } else {
787 self.advance();
788 let expr = self.parse_additive()?;
789 self.expect("op", Some(")"))?;
790 Ok(expr)
791 }
792 }
793 "ident" if t.value.eq_ignore_ascii_case("true") => {
796 self.advance();
797 Ok(Expr::Literal(SqlValue::Bool(true)))
798 }
799 "ident" if t.value.eq_ignore_ascii_case("false") => {
800 self.advance();
801 Ok(Expr::Literal(SqlValue::Bool(false)))
802 }
803 "ident" => {
804 let name = self.advance().value;
805 Ok(Expr::Column(name))
806 }
807 "keyword" if !Self::is_reserved_keyword(&t.value) => {
808 let name = self.advance().value;
809 Ok(Expr::Column(name))
810 }
811 _ => Err(MdqlError::QueryParse(format!(
812 "Expected expression, got '{}'",
813 t.raw
814 ))),
815 }
816 }
817
818 fn parse_case_expr(&mut self) -> Result<Expr, MdqlError> {
819 self.expect("keyword", Some("CASE"))?;
820 let mut whens = Vec::new();
821 while self.match_keyword("WHEN") {
822 let condition = self.parse_or_expr()?;
823 self.expect("keyword", Some("THEN"))?;
824 let result = self.parse_additive()?;
825 whens.push((condition, Box::new(result)));
826 }
827 if whens.is_empty() {
828 return Err(MdqlError::QueryParse("CASE requires at least one WHEN clause".into()));
829 }
830 let else_expr = if self.match_keyword("ELSE") {
831 Some(Box::new(self.parse_additive()?))
832 } else {
833 None
834 };
835 self.expect("keyword", Some("END"))?;
836 Ok(Expr::Case { whens, else_expr })
837 }
838
839 fn parse_agg_expr(&mut self) -> Result<Expr, MdqlError> {
840 let func_name = self.advance().value.to_uppercase();
841 let func = match func_name.as_str() {
842 "COUNT" => AggFunc::Count,
843 "SUM" => AggFunc::Sum,
844 "AVG" => AggFunc::Avg,
845 "MIN" => AggFunc::Min,
846 "MAX" => AggFunc::Max,
847 _ => unreachable!(),
848 };
849 self.expect("op", Some("("))?;
850 let (arg, arg_expr) = if self.peek().map_or(false, |t| t.token_type == "op" && t.value == "*") {
851 self.advance();
852 ("*".to_string(), None)
853 } else {
854 let expr = self.parse_additive()?;
855 if let Expr::Column(name) = &expr {
856 (name.clone(), None)
857 } else {
858 (expr.display_name(), Some(Box::new(expr)))
859 }
860 };
861 self.expect("op", Some(")"))?;
862 Ok(Expr::Aggregate { func, arg, arg_expr })
863 }
864
865 fn parse_standalone_window(&mut self) -> Result<Expr, MdqlError> {
866 let func_name = self.advance().value.to_uppercase();
867 let func = match func_name.as_str() {
868 "ROW_NUMBER" => WindowFunc::RowNumber,
869 "RANK" => WindowFunc::Rank,
870 "DENSE_RANK" => WindowFunc::DenseRank,
871 "LAG" => WindowFunc::Lag,
872 "LEAD" => WindowFunc::Lead,
873 _ => unreachable!(),
874 };
875 self.expect("op", Some("("))?;
876 let mut args = Vec::new();
877 if !self.peek().map_or(false, |t| t.token_type == "op" && t.value == ")") {
878 args.push(self.parse_additive()?);
879 while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
880 self.advance();
881 args.push(self.parse_additive()?);
882 }
883 }
884 self.expect("op", Some(")"))?;
885 self.expect("keyword", Some("OVER"))?;
886 let over = self.parse_window_spec()?;
887 Ok(Expr::Window { func, args, over })
888 }
889
890 fn parse_window_spec(&mut self) -> Result<WindowSpec, MdqlError> {
891 self.expect("op", Some("("))?;
892 let mut partition_by = Vec::new();
893 if self.match_keyword("PARTITION") {
894 self.expect("keyword", Some("BY"))?;
895 partition_by.push(self.parse_ident()?);
896 while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
897 self.advance();
898 partition_by.push(self.parse_ident()?);
899 }
900 }
901 let mut order_by = Vec::new();
902 if self.match_keyword("ORDER") {
903 self.expect("keyword", Some("BY"))?;
904 order_by = self.parse_order_by()?;
905 }
906 self.expect("op", Some(")"))?;
907 Ok(WindowSpec { partition_by, order_by })
908 }
909
910 fn parse_ident(&mut self) -> Result<String, MdqlError> {
911 let t = self.peek().ok_or_else(|| {
912 MdqlError::QueryParse("Expected identifier, got end of query".into())
913 })?;
914 match t.token_type.as_str() {
915 "ident" | "keyword" => {
916 let v = self.advance().value;
917 Ok(v)
918 }
919 _ => Err(MdqlError::QueryParse(format!(
920 "Expected identifier, got '{}'",
921 t.raw
922 ))),
923 }
924 }
925
926 fn parse_or_expr(&mut self) -> Result<WhereClause, MdqlError> {
927 let mut left = self.parse_and_expr()?;
928 while self.match_keyword("OR") {
929 let right = self.parse_and_expr()?;
930 left = WhereClause::BoolOp(BoolOp {
931 op: BoolOpKind::Or,
932 left: Box::new(left),
933 right: Box::new(right),
934 });
935 }
936 Ok(left)
937 }
938
939 fn parse_and_expr(&mut self) -> Result<WhereClause, MdqlError> {
940 let mut left = self.parse_comparison()?;
941 while self.match_keyword("AND") {
942 let right = self.parse_comparison()?;
943 left = WhereClause::BoolOp(BoolOp {
944 op: BoolOpKind::And,
945 left: Box::new(left),
946 right: Box::new(right),
947 });
948 }
949 Ok(left)
950 }
951
952 fn parse_comparison(&mut self) -> Result<WhereClause, MdqlError> {
953 if self.peek().map_or(false, |t| t.token_type == "op" && t.value == "(") {
955 let saved_pos = self.pos;
957 self.advance();
958 let result = self.parse_or_expr();
960 if result.is_ok() && self.peek().map_or(false, |t| t.token_type == "op" && t.value == ")") {
961 self.advance();
962 return result;
963 }
964 self.pos = saved_pos;
966 }
967
968 let left_expr = self.parse_additive()?;
970
971 let col = left_expr.as_column().unwrap_or("").to_string();
973
974 if self.match_keyword("IS") {
976 if self.match_keyword("NOT") {
977 self.expect("keyword", Some("NULL"))?;
978 return Ok(WhereClause::Comparison(Comparison {
979 column: col,
980 op: CmpOp::IsNotNull,
981 value: None,
982 left_expr: Some(left_expr),
983 right_expr: None,
984 }));
985 }
986 self.expect("keyword", Some("NULL"))?;
987 return Ok(WhereClause::Comparison(Comparison {
988 column: col,
989 op: CmpOp::IsNull,
990 value: None,
991 left_expr: Some(left_expr),
992 right_expr: None,
993 }));
994 }
995
996 if self.match_keyword("IN") {
998 self.expect("op", Some("("))?;
999 let is_subquery = self.peek().map_or(false, |t| t.token_type == "keyword" && t.value == "SELECT");
1000 if is_subquery {
1001 let sq = self.parse_select()?;
1002 self.expect("op", Some(")"))?;
1003 return Ok(WhereClause::Comparison(Comparison {
1004 column: col,
1005 op: CmpOp::In,
1006 value: None,
1007 left_expr: Some(left_expr),
1008 right_expr: Some(Expr::Subquery(Box::new(sq))),
1009 }));
1010 }
1011 let mut values = vec![self.parse_value()?];
1012 while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
1013 self.advance();
1014 values.push(self.parse_value()?);
1015 }
1016 self.expect("op", Some(")"))?;
1017 return Ok(WhereClause::Comparison(Comparison {
1018 column: col,
1019 op: CmpOp::In,
1020 value: Some(SqlValue::List(values)),
1021 left_expr: Some(left_expr),
1022 right_expr: None,
1023 }));
1024 }
1025
1026 if self.match_keyword("LIKE") {
1028 let val = self.parse_value()?;
1029 return Ok(WhereClause::Comparison(Comparison {
1030 column: col,
1031 op: CmpOp::Like,
1032 value: Some(val),
1033 left_expr: Some(left_expr),
1034 right_expr: None,
1035 }));
1036 }
1037
1038 if self.match_keyword("NOT") {
1040 if self.match_keyword("LIKE") {
1041 let val = self.parse_value()?;
1042 return Ok(WhereClause::Comparison(Comparison {
1043 column: col,
1044 op: CmpOp::NotLike,
1045 value: Some(val),
1046 left_expr: Some(left_expr),
1047 right_expr: None,
1048 }));
1049 }
1050 return Err(MdqlError::QueryParse("Expected LIKE after NOT".into()));
1051 }
1052
1053 if let Some(t) = self.peek() {
1055 if t.token_type == "op" && ["=", "!=", "<", ">", "<=", ">="].contains(&t.value.as_str())
1056 {
1057 let op_str = self.advance().value;
1058 let op = match op_str.as_str() {
1059 "=" => CmpOp::Eq,
1060 "!=" => CmpOp::Ne,
1061 "<" => CmpOp::Lt,
1062 ">" => CmpOp::Gt,
1063 "<=" => CmpOp::Le,
1064 ">=" => CmpOp::Ge,
1065 _ => unreachable!(),
1066 };
1067 let right_expr = self.parse_additive()?;
1069 let value = match &right_expr {
1071 Expr::Literal(v) => Some(v.clone()),
1072 _ => None,
1073 };
1074 return Ok(WhereClause::Comparison(Comparison {
1075 column: col,
1076 op,
1077 value,
1078 left_expr: Some(left_expr),
1079 right_expr: Some(right_expr),
1080 }));
1081 }
1082 }
1083
1084 let got = self.peek().map_or("end".to_string(), |t| t.raw.clone());
1085 Err(MdqlError::QueryParse(format!(
1086 "Expected operator after '{}', got '{}'",
1087 left_expr.display_name(), got
1088 )))
1089 }
1090
1091 fn parse_value(&mut self) -> Result<SqlValue, MdqlError> {
1092 let t = self.peek().ok_or_else(|| {
1093 MdqlError::QueryParse("Expected value, got end of query".into())
1094 })?;
1095 match t.token_type.as_str() {
1096 "string" => {
1097 let v = self.advance().value;
1098 Ok(SqlValue::String(v))
1099 }
1100 "number" => {
1101 let v = self.advance().value;
1102 if v.contains('.') {
1103 Ok(SqlValue::Float(v.parse().map_err(|_| {
1104 MdqlError::QueryParse(format!("Invalid float: {}", v))
1105 })?))
1106 } else {
1107 Ok(SqlValue::Int(v.parse().map_err(|_| {
1108 MdqlError::QueryParse(format!("Invalid int: {}", v))
1109 })?))
1110 }
1111 }
1112 "keyword" if t.value == "NULL" => {
1113 self.advance();
1114 Ok(SqlValue::Null)
1115 }
1116 "ident" if t.value.eq_ignore_ascii_case("true") => {
1120 self.advance();
1121 Ok(SqlValue::Bool(true))
1122 }
1123 "ident" if t.value.eq_ignore_ascii_case("false") => {
1124 self.advance();
1125 Ok(SqlValue::Bool(false))
1126 }
1127 _ => Err(MdqlError::QueryParse(format!(
1128 "Expected value, got '{}'",
1129 t.raw
1130 ))),
1131 }
1132 }
1133
1134 fn parse_order_by(&mut self) -> Result<Vec<OrderSpec>, MdqlError> {
1135 let mut specs = vec![self.parse_order_spec()?];
1136 while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
1137 self.advance();
1138 specs.push(self.parse_order_spec()?);
1139 }
1140 Ok(specs)
1141 }
1142
1143 fn parse_order_spec(&mut self) -> Result<OrderSpec, MdqlError> {
1144 let expr = self.parse_additive()?;
1145 let col = expr.as_column().unwrap_or("").to_string();
1146 let descending = if self.match_keyword("DESC") {
1147 true
1148 } else {
1149 self.match_keyword("ASC");
1150 false
1151 };
1152 Ok(OrderSpec {
1153 column: col,
1154 expr: Some(expr),
1155 descending,
1156 })
1157 }
1158
1159 fn is_clause_keyword(&self, t: &Token) -> bool {
1160 t.token_type == "keyword"
1161 && ["WHERE", "ORDER", "LIMIT", "JOIN", "LEFT", "ON", "GROUP"].contains(&t.value.as_str())
1162 }
1163
1164 fn is_reserved_keyword(kw: &str) -> bool {
1166 matches!(kw,
1167 "AS" | "FROM" | "WHERE" | "AND" | "OR" | "ORDER" | "BY"
1168 | "ASC" | "DESC" | "LIMIT" | "JOIN" | "ON" | "GROUP"
1169 | "SELECT" | "INSERT" | "INTO" | "VALUES" | "UPDATE" | "SET"
1170 | "DELETE" | "ALTER" | "TABLE" | "IS" | "NOT" | "IN" | "LIKE"
1171 | "RENAME" | "FIELD" | "TO" | "DROP" | "MERGE" | "FIELDS"
1172 | "CASE" | "WHEN" | "THEN" | "ELSE" | "END"
1173 | "HAVING" | "INTERVAL" | "DAY" | "DAYS"
1174 | "CURRENT_DATE" | "CURRENT_TIMESTAMP" | "DATEDIFF"
1175 | "CREATE" | "VIEW" | "CASCADE" | "RESTRICT"
1176 | "WITH"
1177 | "OVER" | "PARTITION" | "ROW_NUMBER" | "RANK" | "DENSE_RANK" | "LAG" | "LEAD"
1178 )
1179 }
1180
1181 fn expect_end(&self) -> Result<(), MdqlError> {
1182 if let Some(t) = self.peek() {
1183 return Err(MdqlError::QueryParse(format!(
1184 "Unexpected token '{}' at position {}",
1185 t.raw, self.pos
1186 )));
1187 }
1188 Ok(())
1189 }
1190}
1191
1192pub fn parse_query(sql: &str) -> crate::errors::Result<Statement> {
1193 let tokens = tokenize(sql);
1194 if tokens.is_empty() {
1195 return Err(MdqlError::QueryParse("Empty query".into()));
1196 }
1197 let mut parser = Parser::new(tokens);
1198 parser.parse_statement()
1199}
1200
1201#[cfg(test)]
1202mod tests {
1203 use super::*;
1204
1205 #[test]
1206 fn test_simple_select() {
1207 let stmt = parse_query("SELECT title, status FROM strategies").unwrap();
1208 if let Statement::Select(q) = stmt {
1209 assert_eq!(q.columns, ColumnList::Named(vec![SelectExpr::Column("title".into()), SelectExpr::Column("status".into())]));
1210 assert_eq!(q.table, "strategies");
1211 } else {
1212 panic!("Expected Select");
1213 }
1214 }
1215
1216 #[test]
1217 fn test_select_star() {
1218 let stmt = parse_query("SELECT * FROM test").unwrap();
1219 if let Statement::Select(q) = stmt {
1220 assert_eq!(q.columns, ColumnList::All);
1221 } else {
1222 panic!("Expected Select");
1223 }
1224 }
1225
1226 #[test]
1227 fn test_select_distinct_parses_flag_and_column() {
1228 let stmt = parse_query("SELECT DISTINCT strategy FROM backtests").unwrap();
1231 let Statement::Select(q) = stmt else { panic!("Expected Select") };
1232 assert!(q.distinct);
1233 assert_eq!(
1234 q.columns,
1235 ColumnList::Named(vec![SelectExpr::Column("strategy".into())])
1236 );
1237
1238 let Statement::Select(q) = parse_query("select distinct a, b from t").unwrap() else {
1240 panic!("Expected Select")
1241 };
1242 assert!(q.distinct);
1243 let Statement::Select(q) = parse_query("SELECT strategy FROM backtests").unwrap() else {
1244 panic!("Expected Select")
1245 };
1246 assert!(!q.distinct);
1247 }
1248
1249 #[test]
1250 fn test_boolean_literal_parses_as_bool_not_column() {
1251 for (sql, expected) in [
1254 ("SELECT title FROM t WHERE enabled = true", SqlValue::Bool(true)),
1255 ("SELECT title FROM t WHERE enabled = FALSE", SqlValue::Bool(false)),
1256 ] {
1257 let stmt = parse_query(sql).unwrap();
1258 let Statement::Select(q) = stmt else { panic!("Expected Select") };
1259 let WhereClause::Comparison(cmp) = q.where_clause.unwrap() else {
1260 panic!("Expected Comparison")
1261 };
1262 assert_eq!(cmp.value, Some(expected));
1263 }
1264 }
1265
1266 #[test]
1267 fn test_where_clause() {
1268 let stmt = parse_query("SELECT title FROM test WHERE count > 5").unwrap();
1269 if let Statement::Select(q) = stmt {
1270 assert!(q.where_clause.is_some());
1271 } else {
1272 panic!("Expected Select");
1273 }
1274 }
1275
1276 #[test]
1277 fn test_order_by() {
1278 let stmt =
1279 parse_query("SELECT title FROM test ORDER BY composite DESC, title ASC").unwrap();
1280 if let Statement::Select(q) = stmt {
1281 let ob = q.order_by.unwrap();
1282 assert_eq!(ob.len(), 2);
1283 assert!(ob[0].descending);
1284 assert!(!ob[1].descending);
1285 } else {
1286 panic!("Expected Select");
1287 }
1288 }
1289
1290 #[test]
1291 fn test_limit() {
1292 let stmt = parse_query("SELECT * FROM test LIMIT 10").unwrap();
1293 if let Statement::Select(q) = stmt {
1294 assert_eq!(q.limit, Some(10));
1295 } else {
1296 panic!("Expected Select");
1297 }
1298 }
1299
1300 #[test]
1301 fn test_insert() {
1302 let stmt = parse_query(
1303 "INSERT INTO test (title, count) VALUES ('Hello', 42)",
1304 )
1305 .unwrap();
1306 if let Statement::Insert(q) = stmt {
1307 assert_eq!(q.table, "test");
1308 assert_eq!(q.columns, vec!["title", "count"]);
1309 assert_eq!(q.values[0], SqlValue::String("Hello".into()));
1310 assert_eq!(q.values[1], SqlValue::Int(42));
1311 } else {
1312 panic!("Expected Insert");
1313 }
1314 }
1315
1316 #[test]
1317 fn test_update() {
1318 let stmt = parse_query("UPDATE test SET status = 'KILLED' WHERE path = 'a.md'").unwrap();
1319 if let Statement::Update(q) = stmt {
1320 assert_eq!(q.table, "test");
1321 assert_eq!(q.assignments.len(), 1);
1322 assert!(q.where_clause.is_some());
1323 } else {
1324 panic!("Expected Update");
1325 }
1326 }
1327
1328 #[test]
1329 fn test_delete() {
1330 let stmt = parse_query("DELETE FROM test WHERE status = 'draft'").unwrap();
1331 if let Statement::Delete(q) = stmt {
1332 assert_eq!(q.table, "test");
1333 assert!(q.where_clause.is_some());
1334 } else {
1335 panic!("Expected Delete");
1336 }
1337 }
1338
1339 #[test]
1340 fn test_alter_rename() {
1341 let stmt =
1342 parse_query("ALTER TABLE test RENAME FIELD 'Summary' TO 'Overview'").unwrap();
1343 if let Statement::AlterRename(q) = stmt {
1344 assert_eq!(q.old_name, "Summary");
1345 assert_eq!(q.new_name, "Overview");
1346 } else {
1347 panic!("Expected AlterRename");
1348 }
1349 }
1350
1351 #[test]
1352 fn test_alter_drop() {
1353 let stmt = parse_query("ALTER TABLE test DROP FIELD 'Details'").unwrap();
1354 if let Statement::AlterDrop(q) = stmt {
1355 assert_eq!(q.field_name, "Details");
1356 } else {
1357 panic!("Expected AlterDrop");
1358 }
1359 }
1360
1361 #[test]
1362 fn test_alter_merge() {
1363 let stmt = parse_query(
1364 "ALTER TABLE test MERGE FIELDS 'Entry Rules', 'Exit Rules' INTO 'Trading Rules'",
1365 )
1366 .unwrap();
1367 if let Statement::AlterMerge(q) = stmt {
1368 assert_eq!(q.sources, vec!["Entry Rules", "Exit Rules"]);
1369 assert_eq!(q.into, "Trading Rules");
1370 } else {
1371 panic!("Expected AlterMerge");
1372 }
1373 }
1374
1375 #[test]
1376 fn test_backtick_ident() {
1377 let stmt = parse_query("SELECT `Structural Mechanism` FROM test").unwrap();
1378 if let Statement::Select(q) = stmt {
1379 assert_eq!(
1380 q.columns,
1381 ColumnList::Named(vec![SelectExpr::Column("Structural Mechanism".into())])
1382 );
1383 } else {
1384 panic!("Expected Select");
1385 }
1386 }
1387
1388 #[test]
1389 fn test_like_operator() {
1390 let stmt = parse_query("SELECT title FROM test WHERE categories LIKE '%defi%'").unwrap();
1391 if let Statement::Select(q) = stmt {
1392 if let Some(WhereClause::Comparison(c)) = q.where_clause {
1393 assert_eq!(c.op, CmpOp::Like);
1394 assert_eq!(c.value, Some(SqlValue::String("%defi%".into())));
1395 } else {
1396 panic!("Expected LIKE comparison");
1397 }
1398 } else {
1399 panic!("Expected Select");
1400 }
1401 }
1402
1403 #[test]
1404 fn test_in_operator() {
1405 let stmt =
1406 parse_query("SELECT * FROM test WHERE status IN ('ACTIVE', 'LIVE')").unwrap();
1407 if let Statement::Select(q) = stmt {
1408 if let Some(WhereClause::Comparison(c)) = q.where_clause {
1409 assert_eq!(c.op, CmpOp::In);
1410 } else {
1411 panic!("Expected IN comparison");
1412 }
1413 } else {
1414 panic!("Expected Select");
1415 }
1416 }
1417
1418 #[test]
1419 fn test_is_null() {
1420 let stmt = parse_query("SELECT * FROM test WHERE title IS NULL").unwrap();
1421 if let Statement::Select(q) = stmt {
1422 if let Some(WhereClause::Comparison(c)) = q.where_clause {
1423 assert_eq!(c.op, CmpOp::IsNull);
1424 } else {
1425 panic!("Expected IS NULL comparison");
1426 }
1427 } else {
1428 panic!("Expected Select");
1429 }
1430 }
1431
1432 #[test]
1433 fn test_and_or() {
1434 let stmt = parse_query(
1435 "SELECT * FROM test WHERE status = 'ACTIVE' AND count > 5 OR title LIKE '%test%'",
1436 )
1437 .unwrap();
1438 if let Statement::Select(q) = stmt {
1439 assert!(q.where_clause.is_some());
1440 } else {
1441 panic!("Expected Select");
1442 }
1443 }
1444
1445 #[test]
1446 fn test_join() {
1447 let stmt = parse_query(
1448 "SELECT s.title, b.sharpe FROM strategies s JOIN backtests b ON b.strategy = s.path",
1449 )
1450 .unwrap();
1451 if let Statement::Select(q) = stmt {
1452 assert_eq!(q.table, "strategies");
1453 assert_eq!(q.table_alias, Some("s".into()));
1454 assert_eq!(q.joins.len(), 1);
1455 let join = &q.joins[0];
1456 assert_eq!(join.table, "backtests");
1457 assert_eq!(join.alias, Some("b".into()));
1458 } else {
1459 panic!("Expected Select");
1460 }
1461 }
1462
1463 #[test]
1464 fn test_multi_join() {
1465 let stmt = parse_query(
1466 "SELECT s.title, b.sharpe, c.verdict FROM strategies s JOIN backtests b ON b.strategy = s.path JOIN critiques c ON c.strategy = s.path",
1467 )
1468 .unwrap();
1469 if let Statement::Select(q) = stmt {
1470 assert_eq!(q.table, "strategies");
1471 assert_eq!(q.table_alias, Some("s".into()));
1472 assert_eq!(q.joins.len(), 2);
1473 assert_eq!(q.joins[0].table, "backtests");
1474 assert_eq!(q.joins[0].alias, Some("b".into()));
1475 assert_eq!(where_clause_to_sql(&q.joins[0].condition), "b.strategy = s.path");
1476 assert_eq!(q.joins[1].table, "critiques");
1477 assert_eq!(q.joins[1].alias, Some("c".into()));
1478 assert_eq!(where_clause_to_sql(&q.joins[1].condition), "c.strategy = s.path");
1479 } else {
1480 panic!("Expected Select");
1481 }
1482 }
1483
1484 #[test]
1485 fn test_left_join() {
1486 let stmt = parse_query(
1487 "SELECT s.title, b.sharpe FROM strategies s LEFT JOIN backtests b ON b.strategy = s.path",
1488 )
1489 .unwrap();
1490 if let Statement::Select(q) = stmt {
1491 assert_eq!(q.joins.len(), 1);
1492 assert_eq!(q.joins[0].join_type, JoinType::Left);
1493 assert_eq!(q.joins[0].table, "backtests");
1494 } else {
1495 panic!("Expected Select");
1496 }
1497 }
1498
1499 #[test]
1500 fn test_mixed_join_types() {
1501 let stmt = parse_query(
1502 "SELECT s.title FROM strategies s JOIN backtests b ON b.strategy = s.path LEFT JOIN allocations a ON a.strategy = s.path",
1503 )
1504 .unwrap();
1505 if let Statement::Select(q) = stmt {
1506 assert_eq!(q.joins.len(), 2);
1507 assert_eq!(q.joins[0].join_type, JoinType::Inner);
1508 assert_eq!(q.joins[1].join_type, JoinType::Left);
1509 } else {
1510 panic!("Expected Select");
1511 }
1512 }
1513
1514 #[test]
1515 fn test_join_compound_and() {
1516 let stmt = parse_query(
1517 "SELECT s.title FROM strategies s LEFT JOIN backtests b ON b.strategy = s.path AND b.mode = 'PAPER'",
1518 )
1519 .unwrap();
1520 if let Statement::Select(q) = stmt {
1521 assert_eq!(q.joins.len(), 1);
1522 assert_eq!(q.joins[0].join_type, JoinType::Left);
1523 let sql = where_clause_to_sql(&q.joins[0].condition);
1524 assert!(sql.contains("b.strategy = s.path"));
1525 assert!(sql.contains("AND"));
1526 assert!(sql.contains("b.mode = 'PAPER'"));
1527 } else {
1528 panic!("Expected Select");
1529 }
1530 }
1531
1532 #[test]
1533 fn test_join_compound_or() {
1534 let stmt = parse_query(
1535 "SELECT * FROM a JOIN b ON a.id = b.id OR a.alt = b.id",
1536 )
1537 .unwrap();
1538 if let Statement::Select(q) = stmt {
1539 let sql = where_clause_to_sql(&q.joins[0].condition);
1540 assert!(sql.contains("OR"));
1541 } else {
1542 panic!("Expected Select");
1543 }
1544 }
1545
1546 #[test]
1547 fn test_join_compound_with_where() {
1548 let stmt = parse_query(
1549 "SELECT s.title FROM strategies s JOIN backtests b ON b.strategy = s.path AND b.mode = 'PAPER' WHERE s.title = 'Alpha'",
1550 )
1551 .unwrap();
1552 if let Statement::Select(q) = stmt {
1553 assert_eq!(q.joins.len(), 1);
1554 assert!(q.where_clause.is_some());
1555 let join_sql = where_clause_to_sql(&q.joins[0].condition);
1556 assert!(join_sql.contains("AND"));
1557 } else {
1558 panic!("Expected Select");
1559 }
1560 }
1561
1562 #[test]
1563 fn test_empty_query() {
1564 assert!(parse_query("").is_err());
1565 }
1566
1567 #[test]
1568 fn test_count_star() {
1569 let stmt = parse_query("SELECT status, COUNT(*) AS cnt FROM strategies GROUP BY status").unwrap();
1570 if let Statement::Select(q) = stmt {
1571 if let ColumnList::Named(exprs) = &q.columns {
1572 assert_eq!(exprs.len(), 2);
1573 assert_eq!(exprs[0], SelectExpr::Column("status".into()));
1574 assert!(matches!(&exprs[1], SelectExpr::Aggregate {
1575 func: AggFunc::Count,
1576 arg,
1577 alias: Some(a),
1578 ..
1579 } if arg == "*" && a == "cnt"));
1580 } else {
1581 panic!("Expected Named columns");
1582 }
1583 assert_eq!(q.group_by, Some(vec!["status".into()]));
1584 } else {
1585 panic!("Expected Select");
1586 }
1587 }
1588
1589 #[test]
1590 fn test_count_column_as_ident() {
1591 let stmt = parse_query("INSERT INTO test (title, count) VALUES ('Hello', 42)").unwrap();
1593 if let Statement::Insert(q) = stmt {
1594 assert_eq!(q.columns, vec!["title", "count"]);
1595 } else {
1596 panic!("Expected Insert");
1597 }
1598 }
1599
1600 #[test]
1601 fn test_multiple_aggregates() {
1602 let stmt = parse_query("SELECT MIN(composite), MAX(composite), AVG(composite) FROM strategies").unwrap();
1603 if let Statement::Select(q) = stmt {
1604 if let ColumnList::Named(exprs) = &q.columns {
1605 assert_eq!(exprs.len(), 3);
1606 assert!(matches!(&exprs[0], SelectExpr::Aggregate { func: AggFunc::Min, .. }));
1607 assert!(matches!(&exprs[1], SelectExpr::Aggregate { func: AggFunc::Max, .. }));
1608 assert!(matches!(&exprs[2], SelectExpr::Aggregate { func: AggFunc::Avg, .. }));
1609 } else {
1610 panic!("Expected Named columns");
1611 }
1612 assert_eq!(q.group_by, None);
1613 } else {
1614 panic!("Expected Select");
1615 }
1616 }
1617
1618 #[test]
1621 fn test_select_arithmetic_expr() {
1622 let stmt = parse_query("SELECT a + b FROM test").unwrap();
1623 if let Statement::Select(q) = stmt {
1624 if let ColumnList::Named(exprs) = &q.columns {
1625 assert_eq!(exprs.len(), 1);
1626 assert!(matches!(&exprs[0], SelectExpr::Expr {
1627 expr: Expr::BinaryOp { op: ArithOp::Add, .. },
1628 alias: None,
1629 }));
1630 } else {
1631 panic!("Expected Named columns");
1632 }
1633 } else {
1634 panic!("Expected Select");
1635 }
1636 }
1637
1638 #[test]
1639 fn test_select_arithmetic_with_alias() {
1640 let stmt = parse_query("SELECT a + b AS total FROM test").unwrap();
1641 if let Statement::Select(q) = stmt {
1642 if let ColumnList::Named(exprs) = &q.columns {
1643 assert_eq!(exprs.len(), 1);
1644 assert!(matches!(&exprs[0], SelectExpr::Expr {
1645 alias: Some(a),
1646 ..
1647 } if a == "total"));
1648 assert_eq!(exprs[0].output_name(), "total");
1649 } else {
1650 panic!("Expected Named columns");
1651 }
1652 } else {
1653 panic!("Expected Select");
1654 }
1655 }
1656
1657 #[test]
1658 fn test_select_precedence() {
1659 let stmt = parse_query("SELECT a + b * c FROM test").unwrap();
1661 if let Statement::Select(q) = stmt {
1662 if let ColumnList::Named(exprs) = &q.columns {
1663 if let SelectExpr::Expr { expr, .. } = &exprs[0] {
1664 if let Expr::BinaryOp { left, op, right } = expr {
1665 assert_eq!(*op, ArithOp::Add);
1666 assert!(matches!(left.as_ref(), Expr::Column(n) if n == "a"));
1667 assert!(matches!(right.as_ref(), Expr::BinaryOp { op: ArithOp::Mul, .. }));
1668 } else {
1669 panic!("Expected BinaryOp");
1670 }
1671 } else {
1672 panic!("Expected Expr variant");
1673 }
1674 } else {
1675 panic!("Expected Named columns");
1676 }
1677 } else {
1678 panic!("Expected Select");
1679 }
1680 }
1681
1682 #[test]
1683 fn test_select_parenthesized_expr() {
1684 let stmt = parse_query("SELECT (a + b) * c FROM test").unwrap();
1686 if let Statement::Select(q) = stmt {
1687 if let ColumnList::Named(exprs) = &q.columns {
1688 if let SelectExpr::Expr { expr, .. } = &exprs[0] {
1689 if let Expr::BinaryOp { left, op, .. } = expr {
1690 assert_eq!(*op, ArithOp::Mul);
1691 assert!(matches!(left.as_ref(), Expr::BinaryOp { op: ArithOp::Add, .. }));
1692 } else {
1693 panic!("Expected BinaryOp");
1694 }
1695 } else {
1696 panic!("Expected Expr variant");
1697 }
1698 } else {
1699 panic!("Expected Named columns");
1700 }
1701 } else {
1702 panic!("Expected Select");
1703 }
1704 }
1705
1706 #[test]
1707 fn test_select_unary_minus() {
1708 let stmt = parse_query("SELECT -count FROM test").unwrap();
1709 if let Statement::Select(q) = stmt {
1710 if let ColumnList::Named(exprs) = &q.columns {
1711 assert!(matches!(&exprs[0], SelectExpr::Expr {
1712 expr: Expr::UnaryMinus(_),
1713 ..
1714 }));
1715 } else {
1716 panic!("Expected Named columns");
1717 }
1718 } else {
1719 panic!("Expected Select");
1720 }
1721 }
1722
1723 #[test]
1724 fn test_select_negative_literal() {
1725 let stmt = parse_query("SELECT -42 FROM test").unwrap();
1726 if let Statement::Select(q) = stmt {
1727 if let ColumnList::Named(exprs) = &q.columns {
1728 assert!(matches!(&exprs[0], SelectExpr::Expr {
1730 expr: Expr::Literal(SqlValue::Int(-42)),
1731 ..
1732 }));
1733 } else {
1734 panic!("Expected Named columns");
1735 }
1736 } else {
1737 panic!("Expected Select");
1738 }
1739 }
1740
1741 #[test]
1742 fn test_where_arithmetic_expr() {
1743 let stmt = parse_query("SELECT * FROM test WHERE a + b > 10").unwrap();
1744 if let Statement::Select(q) = stmt {
1745 if let Some(WhereClause::Comparison(c)) = q.where_clause {
1746 assert_eq!(c.op, CmpOp::Gt);
1747 assert!(matches!(&c.left_expr, Some(Expr::BinaryOp { op: ArithOp::Add, .. })));
1748 assert!(matches!(&c.right_expr, Some(Expr::Literal(SqlValue::Int(10)))));
1749 } else {
1750 panic!("Expected comparison");
1751 }
1752 } else {
1753 panic!("Expected Select");
1754 }
1755 }
1756
1757 #[test]
1758 fn test_where_both_sides_expr() {
1759 let stmt = parse_query("SELECT * FROM test WHERE a * 2 > b + 1").unwrap();
1760 if let Statement::Select(q) = stmt {
1761 if let Some(WhereClause::Comparison(c)) = q.where_clause {
1762 assert_eq!(c.op, CmpOp::Gt);
1763 assert!(matches!(&c.left_expr, Some(Expr::BinaryOp { op: ArithOp::Mul, .. })));
1764 assert!(matches!(&c.right_expr, Some(Expr::BinaryOp { op: ArithOp::Add, .. })));
1765 } else {
1766 panic!("Expected comparison");
1767 }
1768 } else {
1769 panic!("Expected Select");
1770 }
1771 }
1772
1773 #[test]
1774 fn test_order_by_expr() {
1775 let stmt = parse_query("SELECT * FROM test ORDER BY a + b DESC").unwrap();
1776 if let Statement::Select(q) = stmt {
1777 let ob = q.order_by.unwrap();
1778 assert_eq!(ob.len(), 1);
1779 assert!(ob[0].descending);
1780 assert!(matches!(&ob[0].expr, Some(Expr::BinaryOp { op: ArithOp::Add, .. })));
1781 } else {
1782 panic!("Expected Select");
1783 }
1784 }
1785
1786 #[test]
1787 fn test_all_arithmetic_ops() {
1788 let stmt = parse_query("SELECT a + b, a - b, a * b, a / b, a % b FROM test").unwrap();
1789 if let Statement::Select(q) = stmt {
1790 if let ColumnList::Named(exprs) = &q.columns {
1791 assert_eq!(exprs.len(), 5);
1792 assert!(matches!(&exprs[0], SelectExpr::Expr { expr: Expr::BinaryOp { op: ArithOp::Add, .. }, .. }));
1793 assert!(matches!(&exprs[1], SelectExpr::Expr { expr: Expr::BinaryOp { op: ArithOp::Sub, .. }, .. }));
1794 assert!(matches!(&exprs[2], SelectExpr::Expr { expr: Expr::BinaryOp { op: ArithOp::Mul, .. }, .. }));
1795 assert!(matches!(&exprs[3], SelectExpr::Expr { expr: Expr::BinaryOp { op: ArithOp::Div, .. }, .. }));
1796 assert!(matches!(&exprs[4], SelectExpr::Expr { expr: Expr::BinaryOp { op: ArithOp::Mod, .. }, .. }));
1797 } else {
1798 panic!("Expected Named columns");
1799 }
1800 } else {
1801 panic!("Expected Select");
1802 }
1803 }
1804
1805 #[test]
1806 fn test_column_with_literal_arithmetic() {
1807 let stmt = parse_query("SELECT count * 2 + 1 FROM test").unwrap();
1808 if let Statement::Select(q) = stmt {
1809 if let ColumnList::Named(exprs) = &q.columns {
1810 if let SelectExpr::Expr { expr, .. } = &exprs[0] {
1812 if let Expr::BinaryOp { left, op, right } = expr {
1813 assert_eq!(*op, ArithOp::Add);
1814 assert!(matches!(right.as_ref(), Expr::Literal(SqlValue::Int(1))));
1815 assert!(matches!(left.as_ref(), Expr::BinaryOp { op: ArithOp::Mul, .. }));
1816 } else {
1817 panic!("Expected BinaryOp");
1818 }
1819 } else {
1820 panic!("Expected Expr");
1821 }
1822 } else {
1823 panic!("Expected Named columns");
1824 }
1825 } else {
1826 panic!("Expected Select");
1827 }
1828 }
1829
1830 #[test]
1831 fn test_mixed_columns_and_exprs() {
1832 let stmt = parse_query("SELECT title, a + b AS sum, count FROM test").unwrap();
1833 if let Statement::Select(q) = stmt {
1834 if let ColumnList::Named(exprs) = &q.columns {
1835 assert_eq!(exprs.len(), 3);
1836 assert_eq!(exprs[0], SelectExpr::Column("title".into()));
1837 assert!(matches!(&exprs[1], SelectExpr::Expr { alias: Some(a), .. } if a == "sum"));
1838 assert_eq!(exprs[2], SelectExpr::Column("count".into()));
1839 } else {
1840 panic!("Expected Named columns");
1841 }
1842 } else {
1843 panic!("Expected Select");
1844 }
1845 }
1846
1847 #[test]
1850 fn test_case_when_basic() {
1851 let stmt = parse_query(
1852 "SELECT CASE WHEN status = 'ACTIVE' THEN 1 ELSE 0 END FROM test"
1853 ).unwrap();
1854 if let Statement::Select(q) = stmt {
1855 if let ColumnList::Named(exprs) = &q.columns {
1856 assert_eq!(exprs.len(), 1);
1857 assert!(matches!(&exprs[0], SelectExpr::Expr {
1858 expr: Expr::Case { .. },
1859 ..
1860 }));
1861 } else {
1862 panic!("Expected Named columns");
1863 }
1864 } else {
1865 panic!("Expected Select");
1866 }
1867 }
1868
1869 #[test]
1870 fn test_case_when_multiple_branches() {
1871 let stmt = parse_query(
1872 "SELECT CASE WHEN x > 10 THEN 'high' WHEN x > 5 THEN 'mid' ELSE 'low' END FROM test"
1873 ).unwrap();
1874 if let Statement::Select(q) = stmt {
1875 if let ColumnList::Named(exprs) = &q.columns {
1876 if let SelectExpr::Expr { expr: Expr::Case { whens, else_expr }, .. } = &exprs[0] {
1877 assert_eq!(whens.len(), 2);
1878 assert!(else_expr.is_some());
1879 } else {
1880 panic!("Expected Case expression");
1881 }
1882 } else {
1883 panic!("Expected Named columns");
1884 }
1885 } else {
1886 panic!("Expected Select");
1887 }
1888 }
1889
1890 #[test]
1891 fn test_case_when_no_else() {
1892 let stmt = parse_query(
1893 "SELECT CASE WHEN x = 1 THEN 'one' END FROM test"
1894 ).unwrap();
1895 if let Statement::Select(q) = stmt {
1896 if let ColumnList::Named(exprs) = &q.columns {
1897 if let SelectExpr::Expr { expr: Expr::Case { whens, else_expr }, .. } = &exprs[0] {
1898 assert_eq!(whens.len(), 1);
1899 assert!(else_expr.is_none());
1900 } else {
1901 panic!("Expected Case expression");
1902 }
1903 } else {
1904 panic!("Expected Named columns");
1905 }
1906 } else {
1907 panic!("Expected Select");
1908 }
1909 }
1910
1911 #[test]
1912 fn test_case_when_in_aggregate() {
1913 let stmt = parse_query(
1914 "SELECT SUM(CASE WHEN side = 'BUY' THEN size ELSE -size END) AS net FROM orders GROUP BY token"
1915 ).unwrap();
1916 if let Statement::Select(q) = stmt {
1917 if let ColumnList::Named(exprs) = &q.columns {
1918 assert_eq!(exprs.len(), 1);
1919 assert!(matches!(&exprs[0], SelectExpr::Aggregate {
1920 func: AggFunc::Sum,
1921 arg_expr: Some(Expr::Case { .. }),
1922 alias: Some(a),
1923 ..
1924 } if a == "net"));
1925 } else {
1926 panic!("Expected Named columns");
1927 }
1928 } else {
1929 panic!("Expected Select");
1930 }
1931 }
1932
1933 #[test]
1934 fn test_case_when_with_alias() {
1935 let stmt = parse_query(
1936 "SELECT CASE WHEN x > 0 THEN 'pos' ELSE 'neg' END AS sign FROM test"
1937 ).unwrap();
1938 if let Statement::Select(q) = stmt {
1939 if let ColumnList::Named(exprs) = &q.columns {
1940 assert!(matches!(&exprs[0], SelectExpr::Expr {
1941 expr: Expr::Case { .. },
1942 alias: Some(a),
1943 } if a == "sign"));
1944 } else {
1945 panic!("Expected Named columns");
1946 }
1947 } else {
1948 panic!("Expected Select");
1949 }
1950 }
1951
1952 #[test]
1953 fn test_create_view() {
1954 let stmt = parse_query("CREATE VIEW live AS SELECT * FROM strategies WHERE status = 'LIVE'").unwrap();
1955 if let Statement::CreateView(cv) = stmt {
1956 assert_eq!(cv.view_name, "live");
1957 assert!(cv.columns.is_none());
1958 assert_eq!(cv.query.table, "strategies");
1959 assert!(cv.query.where_clause.is_some());
1960 } else {
1961 panic!("Expected CreateView, got {:?}", stmt);
1962 }
1963 }
1964
1965 #[test]
1966 fn test_create_view_with_columns() {
1967 let stmt = parse_query("CREATE VIEW v1 (a, b) AS SELECT title, status FROM t").unwrap();
1968 if let Statement::CreateView(cv) = stmt {
1969 assert_eq!(cv.view_name, "v1");
1970 assert_eq!(cv.columns, Some(vec!["a".into(), "b".into()]));
1971 } else {
1972 panic!("Expected CreateView");
1973 }
1974 }
1975
1976 #[test]
1977 fn test_drop_view() {
1978 let stmt = parse_query("DROP VIEW live").unwrap();
1979 if let Statement::DropView(dv) = stmt {
1980 assert_eq!(dv.view_name, "live");
1981 } else {
1982 panic!("Expected DropView, got {:?}", stmt);
1983 }
1984 }
1985
1986 #[test]
1987 fn test_create_view_case_insensitive() {
1988 let stmt = parse_query("create view My_View as select * from t").unwrap();
1989 if let Statement::CreateView(cv) = stmt {
1990 assert_eq!(cv.view_name, "My_View");
1991 } else {
1992 panic!("Expected CreateView");
1993 }
1994 }
1995
1996 #[test]
1999 fn test_aggregate_division() {
2000 let stmt = parse_query(
2001 "SELECT token, SUM(sell) / SUM(buy) as ratio FROM orders GROUP BY token"
2002 ).unwrap();
2003 if let Statement::Select(q) = stmt {
2004 assert_eq!(q.group_by, Some(vec!["token".into()]));
2005 if let ColumnList::Named(exprs) = &q.columns {
2006 assert_eq!(exprs.len(), 2);
2007 assert!(exprs[1].is_aggregate());
2008 } else {
2009 panic!("Expected Named columns");
2010 }
2011 } else {
2012 panic!("Expected Select");
2013 }
2014 }
2015
2016 #[test]
2017 fn test_aggregate_subtraction() {
2018 let stmt = parse_query(
2019 "SELECT token, SUM(sell) - SUM(buy) as net FROM orders GROUP BY token"
2020 ).unwrap();
2021 if let Statement::Select(q) = stmt {
2022 if let ColumnList::Named(exprs) = &q.columns {
2023 assert_eq!(exprs[1].output_name(), "net");
2024 }
2025 } else {
2026 panic!("Expected Select");
2027 }
2028 }
2029
2030 #[test]
2031 fn test_create_view_with_arithmetic() {
2032 let stmt = parse_query(
2033 "CREATE VIEW positions AS SELECT token, SUM(sell) / SUM(buy) as ratio FROM orders GROUP BY token"
2034 ).unwrap();
2035 if let Statement::CreateView(cv) = stmt {
2036 assert_eq!(cv.view_name, "positions");
2037 } else {
2038 panic!("Expected CreateView, got {:?}", stmt);
2039 }
2040 }
2041
2042 #[test]
2045 fn test_subquery_in_from() {
2046 let stmt = parse_query(
2047 "SELECT token, sell_size FROM (SELECT token, SUM(size) as sell_size FROM orders GROUP BY token) LIMIT 5"
2048 ).unwrap();
2049 if let Statement::Select(q) = stmt {
2050 assert!(q.subquery.is_some());
2051 assert_eq!(q.limit, Some(5));
2052 let sub = q.subquery.unwrap();
2053 assert_eq!(sub.table, "orders");
2054 assert!(sub.group_by.is_some());
2055 } else {
2056 panic!("Expected Select");
2057 }
2058 }
2059
2060 #[test]
2063 fn test_create_view_with_having() {
2064 let stmt = parse_query(
2065 "CREATE VIEW positions AS SELECT token, SUM(sell) as sell_size, SUM(buy) as buy_size FROM orders GROUP BY token HAVING sell_size > buy_size"
2066 ).unwrap();
2067 if let Statement::CreateView(cv) = stmt {
2068 assert_eq!(cv.view_name, "positions");
2069 assert!(cv.query.having.is_some());
2070 } else {
2071 panic!("Expected CreateView, got {:?}", stmt);
2072 }
2073 }
2074
2075 #[test]
2078 fn test_aggregate_multiplication() {
2079 let stmt = parse_query(
2080 "SELECT SUM(a) * 2 as doubled FROM test"
2081 ).unwrap();
2082 if let Statement::Select(q) = stmt {
2083 if let ColumnList::Named(exprs) = &q.columns {
2084 assert_eq!(exprs.len(), 1);
2085 assert!(exprs[0].is_aggregate());
2086 assert_eq!(exprs[0].output_name(), "doubled");
2087 } else {
2088 panic!("Expected Named columns");
2089 }
2090 } else {
2091 panic!("Expected Select");
2092 }
2093 }
2094
2095 #[test]
2096 fn test_complex_aggregate_arithmetic() {
2097 let stmt = parse_query(
2098 "SELECT SUM(CASE WHEN side = 'SELL' THEN size ELSE 0 END) / SUM(CASE WHEN side = 'BUY' THEN size ELSE 0 END) as ratio FROM orders GROUP BY token"
2099 ).unwrap();
2100 if let Statement::Select(q) = stmt {
2101 if let ColumnList::Named(exprs) = &q.columns {
2102 assert_eq!(exprs.len(), 1);
2103 assert!(exprs[0].is_aggregate());
2104 assert_eq!(exprs[0].output_name(), "ratio");
2105 } else {
2106 panic!("Expected Named columns");
2107 }
2108 assert_eq!(q.group_by, Some(vec!["token".into()]));
2109 } else {
2110 panic!("Expected Select");
2111 }
2112 }
2113
2114 #[test]
2117 fn test_subquery_with_alias() {
2118 let stmt = parse_query(
2119 "SELECT x FROM (SELECT x FROM t) sub"
2120 ).unwrap();
2121 if let Statement::Select(q) = stmt {
2122 assert!(q.subquery.is_some());
2123 let sub = q.subquery.unwrap();
2124 assert_eq!(sub.table, "t");
2125 if let ColumnList::Named(exprs) = &q.columns {
2126 assert_eq!(exprs.len(), 1);
2127 assert_eq!(exprs[0].output_name(), "x");
2128 } else {
2129 panic!("Expected Named columns");
2130 }
2131 } else {
2132 panic!("Expected Select");
2133 }
2134 }
2135
2136 #[test]
2137 fn test_subquery_with_where() {
2138 let stmt = parse_query(
2139 "SELECT x FROM (SELECT x FROM t WHERE y > 0) LIMIT 5"
2140 ).unwrap();
2141 if let Statement::Select(q) = stmt {
2142 assert!(q.subquery.is_some());
2143 assert_eq!(q.limit, Some(5));
2144 let sub = q.subquery.unwrap();
2145 assert_eq!(sub.table, "t");
2146 assert!(sub.where_clause.is_some());
2147 } else {
2148 panic!("Expected Select");
2149 }
2150 }
2151
2152 #[test]
2155 fn test_create_view_aggregate_subtraction() {
2156 let stmt = parse_query(
2157 "CREATE VIEW v AS SELECT token, SUM(sell) - SUM(buy) as net FROM orders GROUP BY token"
2158 ).unwrap();
2159 if let Statement::CreateView(cv) = stmt {
2160 assert_eq!(cv.view_name, "v");
2161 assert_eq!(cv.query.group_by, Some(vec!["token".into()]));
2162 if let ColumnList::Named(exprs) = &cv.query.columns {
2163 assert_eq!(exprs.len(), 2);
2164 assert_eq!(exprs[1].output_name(), "net");
2165 assert!(exprs[1].is_aggregate());
2166 } else {
2167 panic!("Expected Named columns");
2168 }
2169 } else {
2170 panic!("Expected CreateView, got {:?}", stmt);
2171 }
2172 }
2173
2174 #[test]
2175 fn test_delete_cascade() {
2176 let stmt = parse_query("DELETE FROM strategies WHERE status = 'KILLED' CASCADE").unwrap();
2177 if let Statement::Delete(q) = stmt {
2178 assert_eq!(q.table, "strategies");
2179 assert!(q.where_clause.is_some());
2180 assert_eq!(q.mode, DeleteMode::Cascade);
2181 } else {
2182 panic!("Expected Delete");
2183 }
2184 }
2185
2186 #[test]
2187 fn test_delete_restrict() {
2188 let stmt = parse_query("DELETE FROM strategies WHERE path = 'alpha.md' RESTRICT").unwrap();
2189 if let Statement::Delete(q) = stmt {
2190 assert_eq!(q.table, "strategies");
2191 assert_eq!(q.mode, DeleteMode::Restrict);
2192 } else {
2193 panic!("Expected Delete");
2194 }
2195 }
2196
2197 #[test]
2198 fn test_delete_default_unchanged() {
2199 let stmt = parse_query("DELETE FROM strategies WHERE status = 'KILLED'").unwrap();
2200 if let Statement::Delete(q) = stmt {
2201 assert_eq!(q.mode, DeleteMode::Default);
2202 } else {
2203 panic!("Expected Delete");
2204 }
2205 }
2206
2207 #[test]
2208 fn test_delete_cascade_no_where() {
2209 let stmt = parse_query("DELETE FROM strategies CASCADE").unwrap();
2210 if let Statement::Delete(q) = stmt {
2211 assert_eq!(q.table, "strategies");
2212 assert!(q.where_clause.is_none());
2213 assert_eq!(q.mode, DeleteMode::Cascade);
2214 } else {
2215 panic!("Expected Delete");
2216 }
2217 }
2218
2219 #[test]
2222 fn test_cte_basic() {
2223 let stmt = parse_query(
2224 "WITH live AS (SELECT * FROM strategies WHERE status = 'LIVE') SELECT * FROM live"
2225 ).unwrap();
2226 if let Statement::Select(q) = stmt {
2227 assert_eq!(q.ctes.len(), 1);
2228 assert_eq!(q.ctes[0].name, "live");
2229 assert_eq!(q.ctes[0].query.table, "strategies");
2230 assert!(q.ctes[0].query.where_clause.is_some());
2231 assert_eq!(q.table, "live");
2232 } else {
2233 panic!("Expected Select");
2234 }
2235 }
2236
2237 #[test]
2238 fn test_cte_multi() {
2239 let stmt = parse_query(
2240 "WITH a AS (SELECT * FROM t1), b AS (SELECT * FROM t2) SELECT * FROM a JOIN b ON a.id = b.id"
2241 ).unwrap();
2242 if let Statement::Select(q) = stmt {
2243 assert_eq!(q.ctes.len(), 2);
2244 assert_eq!(q.ctes[0].name, "a");
2245 assert_eq!(q.ctes[0].query.table, "t1");
2246 assert_eq!(q.ctes[1].name, "b");
2247 assert_eq!(q.ctes[1].query.table, "t2");
2248 assert_eq!(q.table, "a");
2249 assert_eq!(q.joins.len(), 1);
2250 } else {
2251 panic!("Expected Select");
2252 }
2253 }
2254
#[test]
fn test_cte_with_aggregation() {
    // The CTE body keeps its GROUP BY, and the outer query may filter on a
    // column the CTE produced.
    let sql = "WITH totals AS (SELECT strategy, COUNT(*) AS cnt FROM backtests GROUP BY strategy) SELECT * FROM totals WHERE cnt > 1";
    let Statement::Select(q) = parse_query(sql).unwrap() else {
        panic!("Expected Select");
    };
    assert_eq!(q.ctes.len(), 1);
    let totals = &q.ctes[0];
    assert_eq!(totals.name, "totals");
    assert!(totals.query.group_by.is_some());
    assert_eq!(q.table, "totals");
    assert!(q.where_clause.is_some());
}
2270
#[test]
fn test_cte_no_ctes_on_plain_select() {
    // Without a leading WITH, the CTE list must remain empty.
    match parse_query("SELECT * FROM t").unwrap() {
        Statement::Select(q) => assert!(q.ctes.is_empty()),
        _ => panic!("Expected Select"),
    }
}
2280
#[test]
fn test_where_in_subquery() {
    // `x IN (SELECT ...)` yields an IN comparison whose right-hand side is a subquery.
    let sql = "SELECT * FROM strategies WHERE path IN (SELECT strategy FROM backtests)";
    let Statement::Select(q) = parse_query(sql).unwrap() else {
        panic!("Expected Select");
    };
    let Some(WhereClause::Comparison(c)) = &q.where_clause else {
        panic!("Expected IN comparison");
    };
    assert_eq!(c.op, CmpOp::In);
    assert!(matches!(&c.right_expr, Some(Expr::Subquery(_))));
}
2299
#[test]
fn test_scalar_subquery_in_where() {
    // A parenthesised SELECT on the right of `>` is parsed as a scalar subquery.
    let sql = "SELECT * FROM backtests WHERE sharpe > (SELECT AVG(sharpe) FROM backtests)";
    let Statement::Select(q) = parse_query(sql).unwrap() else {
        panic!("Expected Select");
    };
    let Some(WhereClause::Comparison(c)) = &q.where_clause else {
        panic!("Expected comparison");
    };
    assert_eq!(c.op, CmpOp::Gt);
    assert!(matches!(&c.right_expr, Some(Expr::Subquery(_))));
}
2316
#[test]
fn test_scalar_subquery_in_select() {
    // A scalar subquery can appear in the projection list and carry an alias.
    let sql = "SELECT title, (SELECT COUNT(*) FROM backtests) AS cnt FROM strategies";
    let Statement::Select(q) = parse_query(sql).unwrap() else {
        panic!("Expected Select");
    };
    let ColumnList::Named(exprs) = &q.columns else {
        panic!("Expected Named columns");
    };
    assert_eq!(exprs.len(), 2);
    let is_aliased_subquery = matches!(
        &exprs[1],
        SelectExpr::Expr { expr: Expr::Subquery(_), alias: Some(a) } if a == "cnt"
    );
    assert!(is_aliased_subquery);
}
2336
#[test]
fn test_row_number_over_order_by() {
    // ROW_NUMBER() takes no arguments; its OVER clause holds one DESC order key
    // and no partitioning.
    let sql = "SELECT title, ROW_NUMBER() OVER (ORDER BY count DESC) AS rn FROM test";
    let Statement::Select(q) = parse_query(sql).unwrap() else {
        panic!("Expected Select");
    };
    let ColumnList::Named(exprs) = &q.columns else {
        panic!("Expected Named columns");
    };
    assert_eq!(exprs.len(), 2);
    let SelectExpr::Expr { expr: Expr::Window { func, args, over }, alias } = &exprs[1] else {
        panic!("Expected Window expression, got {:?}", exprs[1]);
    };
    assert_eq!(*func, WindowFunc::RowNumber);
    assert!(args.is_empty());
    assert!(over.partition_by.is_empty());
    assert_eq!(over.order_by.len(), 1);
    assert!(over.order_by[0].descending);
    assert_eq!(alias.as_deref(), Some("rn"));
}
2364
#[test]
fn test_rank_with_partition_by() {
    // RANK() OVER with both PARTITION BY and ORDER BY populates both window lists.
    let sql = "SELECT RANK() OVER (PARTITION BY category ORDER BY price DESC) AS rnk FROM test";
    let Statement::Select(q) = parse_query(sql).unwrap() else {
        panic!("Expected Select");
    };
    let ColumnList::Named(exprs) = &q.columns else {
        panic!("Expected Named columns");
    };
    let SelectExpr::Expr { expr: Expr::Window { func, over, .. }, .. } = &exprs[0] else {
        panic!("Expected Window expression");
    };
    assert_eq!(*func, WindowFunc::Rank);
    assert_eq!(over.partition_by, vec!["category"]);
    assert_eq!(over.order_by.len(), 1);
}
2386
#[test]
fn test_agg_over_window() {
    // An aggregate followed by OVER becomes a window function wrapping that
    // aggregate; a PARTITION BY alone leaves the order list empty.
    let sql = "SELECT SUM(price) OVER (PARTITION BY category) AS cat_total FROM test";
    let Statement::Select(q) = parse_query(sql).unwrap() else {
        panic!("Expected Select");
    };
    let ColumnList::Named(exprs) = &q.columns else {
        panic!("Expected Named columns");
    };
    let SelectExpr::Expr { expr: Expr::Window { func, args, over }, alias } = &exprs[0] else {
        panic!("Expected Window expression");
    };
    assert_eq!(*func, WindowFunc::Agg(AggFunc::Sum));
    assert_eq!(args.len(), 1);
    assert_eq!(over.partition_by, vec!["category"]);
    assert!(over.order_by.is_empty());
    assert_eq!(alias.as_deref(), Some("cat_total"));
}
2410
#[test]
fn test_lag_with_args() {
    // LAG keeps both call arguments (the column and the offset). The OVER
    // clause must also survive parsing: one ascending ORDER BY key, no
    // PARTITION BY, and the trailing alias attached to the window expression.
    let sql = "SELECT LAG(price, 1) OVER (ORDER BY price) AS prev_price FROM test";
    let Statement::Select(q) = parse_query(sql).unwrap() else {
        panic!("Expected Select");
    };
    let ColumnList::Named(exprs) = &q.columns else {
        panic!("Expected Named columns");
    };
    let SelectExpr::Expr { expr: Expr::Window { func, args, over }, alias } = &exprs[0] else {
        panic!("Expected Window expression");
    };
    assert_eq!(*func, WindowFunc::Lag);
    assert_eq!(args.len(), 2);
    assert!(over.partition_by.is_empty());
    assert_eq!(over.order_by.len(), 1);
    // No DESC keyword was given, so the single order key is ascending.
    assert!(!over.order_by[0].descending);
    assert_eq!(alias.as_deref(), Some("prev_price"));
}
2431
#[test]
fn test_dense_rank() {
    // DENSE_RANK() is recognised as its own window function. Its OVER clause
    // (a single DESC order key, no partitioning) and the alias must be
    // preserved, just as for ROW_NUMBER with the same clause.
    let sql = "SELECT DENSE_RANK() OVER (ORDER BY count DESC) AS dr FROM test";
    let Statement::Select(q) = parse_query(sql).unwrap() else {
        panic!("Expected Select");
    };
    let ColumnList::Named(exprs) = &q.columns else {
        panic!("Expected Named columns");
    };
    let SelectExpr::Expr { expr: Expr::Window { func, args, over }, alias } = &exprs[0] else {
        panic!("Expected Window expression");
    };
    assert_eq!(*func, WindowFunc::DenseRank);
    assert!(args.is_empty());
    assert!(over.partition_by.is_empty());
    assert_eq!(over.order_by.len(), 1);
    assert!(over.order_by[0].descending);
    assert_eq!(alias.as_deref(), Some("dr"));
}
2451
#[test]
fn test_sum_without_over_is_aggregate() {
    // Without an OVER clause, SUM stays a plain aggregate rather than a window function.
    let Statement::Select(q) = parse_query("SELECT SUM(count) FROM test").unwrap() else {
        panic!("Expected Select");
    };
    let ColumnList::Named(exprs) = &q.columns else {
        panic!("Expected Named columns");
    };
    let first = &exprs[0];
    assert!(matches!(first, SelectExpr::Aggregate { func: AggFunc::Sum, .. }));
}
2465}