1use regex::Regex;
4use std::sync::LazyLock;
5
6use crate::errors::MdqlError;
7pub use crate::query_ast::*;
8
/// Upper-case words that `tokenize` emits as `keyword` tokens.
///
/// A bare word is upper-cased before it is looked up here, so matching is
/// case-insensitive; any bare word not in this list becomes an `ident` token.
static KEYWORDS: &[&str] = &[
    // SELECT clauses and boolean connectives
    "SELECT", "FROM", "WHERE", "AND", "OR", "ORDER", "BY",
    // Ordering, limits and predicates
    "ASC", "DESC", "LIMIT", "LIKE", "IN", "IS", "NOT", "NULL",
    // Joins, aliases and grouping
    "JOIN", "LEFT", "ON", "AS", "GROUP", "HAVING",
    // Data modification
    "INSERT", "INTO", "VALUES", "UPDATE", "SET", "DELETE",
    // Schema changes
    "ALTER", "TABLE", "RENAME", "FIELD", "TO", "DROP", "MERGE", "FIELDS",
    // CASE expressions
    "CASE", "WHEN", "THEN", "ELSE", "END",
    // Date helpers
    "INTERVAL", "DAY", "DAYS", "CURRENT_DATE", "CURRENT_TIMESTAMP", "DATEDIFF",
    // Views
    "CREATE", "VIEW", "CASCADE", "RESTRICT",
];
21
/// Upper-case names of the aggregate functions the parser recognises.
static AGG_FUNCS: &[&str] = &[
    "COUNT",
    "SUM",
    "AVG",
    "MIN",
    "MAX",
];
23
24static TOKEN_RE: LazyLock<Regex> = LazyLock::new(|| {
25 Regex::new(
26 r#"(?x)
27 \s*(?:
28 (?P<backtick>`[^`]+`)
29 | (?P<string>'(?:[^'\\]|\\.)*')
30 | (?P<number>\d+(?:\.\d+)?)
31 | (?P<op><=|>=|!=|[=<>,*()+\-/%])
32 | (?P<word>[A-Za-z_][A-Za-z0-9_./-]*)
33 )"#,
34 )
35 .unwrap()
36});
37
/// One lexical unit produced by `tokenize`.
#[derive(Debug, Clone)]
struct Token {
    /// Category tag: `"ident"`, `"string"`, `"number"`, `"op"` or `"keyword"`.
    token_type: String,
    /// Normalised text: quotes/backticks stripped for quoted tokens,
    /// upper-cased for keywords, otherwise identical to `raw`.
    value: String,
    /// The text exactly as it appeared in the query (used in error messages).
    raw: String,
}
44
45fn tokenize(sql: &str) -> Vec<Token> {
46 let mut tokens = Vec::new();
47 for caps in TOKEN_RE.captures_iter(sql) {
48 if let Some(m) = caps.name("backtick") {
49 let raw = m.as_str();
50 tokens.push(Token {
51 token_type: "ident".into(),
52 value: raw[1..raw.len() - 1].into(),
53 raw: raw.into(),
54 });
55 } else if let Some(m) = caps.name("string") {
56 let raw = m.as_str();
57 tokens.push(Token {
58 token_type: "string".into(),
59 value: raw[1..raw.len() - 1].into(),
60 raw: raw.into(),
61 });
62 } else if let Some(m) = caps.name("number") {
63 let raw = m.as_str();
64 tokens.push(Token {
65 token_type: "number".into(),
66 value: raw.into(),
67 raw: raw.into(),
68 });
69 } else if let Some(m) = caps.name("op") {
70 let raw = m.as_str();
71 tokens.push(Token {
72 token_type: "op".into(),
73 value: raw.into(),
74 raw: raw.into(),
75 });
76 } else if let Some(m) = caps.name("word") {
77 let raw = m.as_str();
78 if KEYWORDS.contains(&raw.to_uppercase().as_str()) {
79 tokens.push(Token {
80 token_type: "keyword".into(),
81 value: raw.to_uppercase(),
82 raw: raw.into(),
83 });
84 } else {
85 tokens.push(Token {
86 token_type: "ident".into(),
87 value: raw.into(),
88 raw: raw.into(),
89 });
90 }
91 }
92 }
93 tokens
94}
95
96struct Parser {
99 tokens: Vec<Token>,
100 pos: usize,
101}
102
103impl Parser {
104 fn new(tokens: Vec<Token>) -> Self {
105 Parser { tokens, pos: 0 }
106 }
107
108 fn peek(&self) -> Option<&Token> {
109 self.tokens.get(self.pos)
110 }
111
112 fn advance(&mut self) -> Token {
113 let t = self.tokens[self.pos].clone();
114 self.pos += 1;
115 t
116 }
117
118 fn expect(&mut self, type_: &str, value: Option<&str>) -> Result<Token, MdqlError> {
119 let t = self.peek().ok_or_else(|| {
120 MdqlError::QueryParse(format!(
121 "Unexpected end of query, expected {}",
122 value.unwrap_or(type_)
123 ))
124 })?;
125 let matches_type = t.token_type == type_;
126 let matches_value = value.map_or(true, |v| t.value == v);
127 if !matches_type || !matches_value {
128 return Err(MdqlError::QueryParse(format!(
129 "Expected {}, got '{}' at position {}",
130 value.unwrap_or(type_),
131 t.raw,
132 self.pos
133 )));
134 }
135 Ok(self.advance())
136 }
137
138 fn match_keyword(&mut self, kw: &str) -> bool {
139 if let Some(t) = self.peek() {
140 if t.token_type == "keyword" && t.value == kw {
141 self.advance();
142 return true;
143 }
144 }
145 false
146 }
147
148 fn parse_statement(&mut self) -> Result<Statement, MdqlError> {
149 let t = self.peek().ok_or_else(|| MdqlError::QueryParse("Empty query".into()))?;
150 match (t.token_type.as_str(), t.value.as_str()) {
151 ("keyword", "SELECT") => {
152 let q = self.parse_select()?;
153 self.expect_end()?;
154 Ok(Statement::Select(q))
155 }
156 ("keyword", "INSERT") => Ok(Statement::Insert(self.parse_insert()?)),
157 ("keyword", "UPDATE") => Ok(Statement::Update(self.parse_update()?)),
158 ("keyword", "DELETE") => Ok(Statement::Delete(self.parse_delete()?)),
159 ("keyword", "ALTER") => self.parse_alter(),
160 ("keyword", "CREATE") => self.parse_create_view(),
161 ("keyword", "DROP") => self.parse_drop_view(),
162 _ => Err(MdqlError::QueryParse(format!(
163 "Expected SELECT, INSERT, UPDATE, DELETE, ALTER, CREATE, or DROP, got '{}'",
164 t.raw
165 ))),
166 }
167 }
168
169 fn parse_select(&mut self) -> Result<SelectQuery, MdqlError> {
170 self.expect("keyword", Some("SELECT"))?;
171 let columns = self.parse_columns()?;
172 self.expect("keyword", Some("FROM"))?;
173
174 let mut subquery = None;
176 let (table, mut table_alias) = if self.peek().map_or(false, |t| t.token_type == "op" && t.value == "(") {
177 self.advance();
178 let inner = self.parse_select()?;
179 self.expect("op", Some(")"))?;
180 subquery = Some(Box::new(inner));
181 let alias = if let Some(t) = self.peek() {
182 if t.token_type == "ident" && !self.is_clause_keyword(t) {
183 Some(self.advance().value)
184 } else {
185 None
186 }
187 } else {
188 None
189 };
190 ("_subquery".to_string(), alias)
191 } else {
192 let t = self.parse_ident()?;
193 (t, None)
194 };
195
196 if subquery.is_none() {
198 if let Some(t) = self.peek() {
199 if t.token_type == "ident" && !self.is_clause_keyword(t) {
200 table_alias = Some(self.advance().value);
201 }
202 }
203 }
204
205 let mut joins = Vec::new();
207 loop {
208 let jt = if self.match_keyword("LEFT") {
209 self.expect("keyword", Some("JOIN"))?;
210 JoinType::Left
211 } else if self.match_keyword("JOIN") {
212 JoinType::Inner
213 } else {
214 break;
215 };
216 let join_table = self.parse_ident()?;
217 let mut join_alias = None;
218 if let Some(t) = self.peek() {
219 if t.token_type == "ident" && !self.is_clause_keyword(t) {
220 join_alias = Some(self.advance().value);
221 }
222 }
223 self.expect("keyword", Some("ON"))?;
224 let left_col = self.parse_ident()?;
225 self.expect("op", Some("="))?;
226 let right_col = self.parse_ident()?;
227 joins.push(JoinClause {
228 join_type: jt,
229 table: join_table,
230 alias: join_alias,
231 left_col,
232 right_col,
233 });
234 }
235
236 let mut where_clause = None;
237 if self.match_keyword("WHERE") {
238 where_clause = Some(self.parse_or_expr()?);
239 }
240
241 let mut group_by = None;
242 if self.match_keyword("GROUP") {
243 self.expect("keyword", Some("BY"))?;
244 let mut cols = vec![self.parse_ident()?];
245 while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
246 self.advance();
247 cols.push(self.parse_ident()?);
248 }
249 group_by = Some(cols);
250 }
251
252 let mut having = None;
253 if self.match_keyword("HAVING") {
254 having = Some(self.parse_or_expr()?);
255 }
256
257 let mut order_by = None;
258 if self.match_keyword("ORDER") {
259 self.expect("keyword", Some("BY"))?;
260 order_by = Some(self.parse_order_by()?);
261 }
262
263 let mut limit = None;
264 if self.match_keyword("LIMIT") {
265 let t = self.expect("number", None)?;
266 limit = Some(t.value.parse::<i64>().map_err(|_| {
267 MdqlError::QueryParse(format!("Invalid LIMIT value: {}", t.value))
268 })?);
269 }
270
271 Ok(SelectQuery {
272 columns,
273 table,
274 table_alias,
275 subquery,
276 joins,
277 where_clause,
278 group_by,
279 having,
280 order_by,
281 limit,
282 })
283 }
284
285 fn parse_insert(&mut self) -> Result<InsertQuery, MdqlError> {
286 self.expect("keyword", Some("INSERT"))?;
287 self.expect("keyword", Some("INTO"))?;
288 let table = self.parse_ident()?;
289
290 self.expect("op", Some("("))?;
291 let mut columns = vec![self.parse_ident()?];
292 while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
293 self.advance();
294 columns.push(self.parse_ident()?);
295 }
296 self.expect("op", Some(")"))?;
297
298 self.expect("keyword", Some("VALUES"))?;
299
300 self.expect("op", Some("("))?;
301 let mut values = vec![self.parse_value()?];
302 while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
303 self.advance();
304 values.push(self.parse_value()?);
305 }
306 self.expect("op", Some(")"))?;
307
308 if columns.len() != values.len() {
309 return Err(MdqlError::QueryParse(format!(
310 "Column count ({}) does not match value count ({})",
311 columns.len(),
312 values.len()
313 )));
314 }
315
316 self.expect_end()?;
317 Ok(InsertQuery {
318 table,
319 columns,
320 values,
321 })
322 }
323
324 fn parse_update(&mut self) -> Result<UpdateQuery, MdqlError> {
325 self.expect("keyword", Some("UPDATE"))?;
326 let table = self.parse_ident()?;
327 self.expect("keyword", Some("SET"))?;
328
329 let mut assignments = Vec::new();
330 let col = self.parse_ident()?;
331 self.expect("op", Some("="))?;
332 let val = self.parse_value()?;
333 assignments.push((col, val));
334
335 while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
336 self.advance();
337 let col = self.parse_ident()?;
338 self.expect("op", Some("="))?;
339 let val = self.parse_value()?;
340 assignments.push((col, val));
341 }
342
343 let mut where_clause = None;
344 if self.match_keyword("WHERE") {
345 where_clause = Some(self.parse_or_expr()?);
346 }
347
348 self.expect_end()?;
349 Ok(UpdateQuery {
350 table,
351 assignments,
352 where_clause,
353 })
354 }
355
356 fn parse_delete(&mut self) -> Result<DeleteQuery, MdqlError> {
357 self.expect("keyword", Some("DELETE"))?;
358 self.expect("keyword", Some("FROM"))?;
359 let table = self.parse_ident()?;
360
361 let mut where_clause = None;
362 if self.match_keyword("WHERE") {
363 where_clause = Some(self.parse_or_expr()?);
364 }
365
366 self.expect_end()?;
367 Ok(DeleteQuery {
368 table,
369 where_clause,
370 })
371 }
372
373 fn parse_alter(&mut self) -> Result<Statement, MdqlError> {
374 self.expect("keyword", Some("ALTER"))?;
375 self.expect("keyword", Some("TABLE"))?;
376 let table = self.parse_ident()?;
377
378 let t = self.peek().ok_or_else(|| {
379 MdqlError::QueryParse("Expected RENAME, DROP, or MERGE after table name".into())
380 })?;
381
382 match (t.token_type.as_str(), t.value.as_str()) {
383 ("keyword", "RENAME") => {
384 self.advance();
385 self.expect("keyword", Some("FIELD"))?;
386 let old_name = self.parse_string_or_ident()?;
387 self.expect("keyword", Some("TO"))?;
388 let new_name = self.parse_string_or_ident()?;
389 self.expect_end()?;
390 Ok(Statement::AlterRename(AlterRenameFieldQuery {
391 table,
392 old_name,
393 new_name,
394 }))
395 }
396 ("keyword", "DROP") => {
397 self.advance();
398 self.expect("keyword", Some("FIELD"))?;
399 let field_name = self.parse_string_or_ident()?;
400 self.expect_end()?;
401 Ok(Statement::AlterDrop(AlterDropFieldQuery {
402 table,
403 field_name,
404 }))
405 }
406 ("keyword", "MERGE") => {
407 self.advance();
408 self.expect("keyword", Some("FIELDS"))?;
409 let mut sources = vec![self.parse_string_or_ident()?];
410 while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
411 self.advance();
412 sources.push(self.parse_string_or_ident()?);
413 }
414 self.expect("keyword", Some("INTO"))?;
415 let target = self.parse_string_or_ident()?;
416 self.expect_end()?;
417 Ok(Statement::AlterMerge(AlterMergeFieldsQuery {
418 table,
419 sources,
420 into: target,
421 }))
422 }
423 _ => Err(MdqlError::QueryParse(format!(
424 "Expected RENAME, DROP, or MERGE, got '{}'",
425 t.raw
426 ))),
427 }
428 }
429
430 fn parse_create_view(&mut self) -> Result<Statement, MdqlError> {
431 self.expect("keyword", Some("CREATE"))?;
432 self.expect("keyword", Some("VIEW"))?;
433 let view_name = self.parse_ident()?;
434
435 let columns = if self.peek().map_or(false, |t| t.token_type == "op" && t.value == "(") {
436 self.advance();
437 let mut cols = vec![self.parse_ident()?];
438 while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
439 self.advance();
440 cols.push(self.parse_ident()?);
441 }
442 self.expect("op", Some(")"))?;
443 Some(cols)
444 } else {
445 None
446 };
447
448 self.expect("keyword", Some("AS"))?;
449 let query = Box::new(self.parse_select()?);
450 self.expect_end()?;
451
452 Ok(Statement::CreateView(CreateViewQuery {
453 view_name,
454 columns,
455 query,
456 }))
457 }
458
459 fn parse_drop_view(&mut self) -> Result<Statement, MdqlError> {
460 self.expect("keyword", Some("DROP"))?;
461 self.expect("keyword", Some("VIEW"))?;
462 let view_name = self.parse_ident()?;
463 self.expect_end()?;
464 Ok(Statement::DropView(DropViewQuery { view_name }))
465 }
466
467 fn parse_string_or_ident(&mut self) -> Result<String, MdqlError> {
468 let t = self.peek().ok_or_else(|| {
469 MdqlError::QueryParse("Expected field name, got end of query".into())
470 })?;
471 match t.token_type.as_str() {
472 "string" => {
473 let v = self.advance().value;
474 Ok(v)
475 }
476 "ident" | "keyword" => {
477 let v = self.advance().value;
478 Ok(v)
479 }
480 _ => Err(MdqlError::QueryParse(format!(
481 "Expected field name, got '{}'",
482 t.raw
483 ))),
484 }
485 }
486
487 fn parse_columns(&mut self) -> Result<ColumnList, MdqlError> {
488 if let Some(t) = self.peek() {
489 if t.token_type == "op" && t.value == "*" {
490 self.advance();
491 return Ok(ColumnList::All);
492 }
493 }
494
495 let mut exprs = vec![self.parse_select_expr()?];
496 while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
497 self.advance();
498 exprs.push(self.parse_select_expr()?);
499 }
500 Ok(ColumnList::Named(exprs))
501 }
502
503 fn peek_is_agg_func(&self) -> bool {
504 let t = match self.peek() {
505 Some(t) => t,
506 None => return false,
507 };
508 let name_upper = t.value.to_uppercase();
509 if !AGG_FUNCS.contains(&name_upper.as_str()) {
510 return false;
511 }
512 self.tokens
514 .get(self.pos + 1)
515 .map_or(false, |next| next.token_type == "op" && next.value == "(")
516 }
517
518 fn parse_select_expr(&mut self) -> Result<SelectExpr, MdqlError> {
519 let _t = self.peek().ok_or_else(|| {
520 MdqlError::QueryParse("Expected column or aggregate, got end of query".into())
521 })?;
522
523 let expr = self.parse_additive()?;
524
525 let alias = if self.match_keyword("AS") {
526 Some(self.parse_ident()?)
527 } else if self.peek().map_or(false, |t| {
528 t.token_type == "ident" && !self.is_clause_keyword(t)
529 }) {
530 Some(self.advance().value)
531 } else {
532 None
533 };
534
535 if let Expr::Aggregate { func, arg, arg_expr } = expr {
537 return Ok(SelectExpr::Aggregate {
538 func,
539 arg,
540 arg_expr: arg_expr.map(|e| *e),
541 alias,
542 });
543 }
544
545 if alias.is_none() {
546 if let Expr::Column(name) = &expr {
547 return Ok(SelectExpr::Column(name.clone()));
548 }
549 }
550
551 Ok(SelectExpr::Expr { expr, alias })
552 }
553
554 fn peek_is_additive_op(&self) -> bool {
557 self.peek().map_or(false, |t| {
558 t.token_type == "op" && (t.value == "+" || t.value == "-")
559 })
560 }
561
562 fn peek_is_multiplicative_op(&self) -> bool {
563 self.peek().map_or(false, |t| {
564 t.token_type == "op" && (t.value == "*" || t.value == "/" || t.value == "%")
565 })
566 }
567
568 fn parse_additive(&mut self) -> Result<Expr, MdqlError> {
569 let mut left = self.parse_multiplicative()?;
570 while self.peek_is_additive_op() {
571 let op_tok = self.advance();
572 let is_sub = op_tok.value == "-";
573
574 if self.peek().map_or(false, |t| t.token_type == "keyword" && t.value == "INTERVAL") {
576 self.advance(); let days_expr = self.parse_multiplicative()?;
578 if !self.match_keyword("DAY") && !self.match_keyword("DAYS") {
580 return Err(MdqlError::QueryParse("Expected DAY after INTERVAL value".into()));
581 }
582 let days = if is_sub {
583 Expr::UnaryMinus(Box::new(days_expr))
584 } else {
585 days_expr
586 };
587 left = Expr::DateAdd {
588 date: Box::new(left),
589 days: Box::new(days),
590 };
591 continue;
592 }
593
594 let op = match op_tok.value.as_str() {
595 "+" => ArithOp::Add,
596 "-" => ArithOp::Sub,
597 _ => unreachable!(),
598 };
599 let right = self.parse_multiplicative()?;
600 left = Expr::BinaryOp {
601 left: Box::new(left),
602 op,
603 right: Box::new(right),
604 };
605 }
606 Ok(left)
607 }
608
609 fn parse_multiplicative(&mut self) -> Result<Expr, MdqlError> {
610 let mut left = self.parse_unary()?;
611 while self.peek_is_multiplicative_op() {
612 let op_tok = self.advance();
613 let op = match op_tok.value.as_str() {
614 "*" => ArithOp::Mul,
615 "/" => ArithOp::Div,
616 "%" => ArithOp::Mod,
617 _ => unreachable!(),
618 };
619 let right = self.parse_unary()?;
620 left = Expr::BinaryOp {
621 left: Box::new(left),
622 op,
623 right: Box::new(right),
624 };
625 }
626 Ok(left)
627 }
628
629 fn parse_unary(&mut self) -> Result<Expr, MdqlError> {
630 if self.peek().map_or(false, |t| t.token_type == "op" && t.value == "-") {
631 self.advance();
632 let inner = self.parse_atom()?;
633 match inner {
635 Expr::Literal(SqlValue::Int(n)) => Ok(Expr::Literal(SqlValue::Int(-n))),
636 Expr::Literal(SqlValue::Float(f)) => Ok(Expr::Literal(SqlValue::Float(-f))),
637 _ => Ok(Expr::UnaryMinus(Box::new(inner))),
638 }
639 } else {
640 self.parse_atom()
641 }
642 }
643
644 fn parse_atom(&mut self) -> Result<Expr, MdqlError> {
645 if self.peek_is_agg_func() {
646 return self.parse_agg_expr();
647 }
648
649 let t = self.peek().ok_or_else(|| {
650 MdqlError::QueryParse("Expected expression, got end of query".into())
651 })?;
652
653 match t.token_type.as_str() {
654 "number" => {
655 let v = self.advance().value;
656 if v.contains('.') {
657 let f: f64 = v.parse().map_err(|_| {
658 MdqlError::QueryParse(format!("Invalid float: {}", v))
659 })?;
660 Ok(Expr::Literal(SqlValue::Float(f)))
661 } else {
662 let n: i64 = v.parse().map_err(|_| {
663 MdqlError::QueryParse(format!("Invalid int: {}", v))
664 })?;
665 Ok(Expr::Literal(SqlValue::Int(n)))
666 }
667 }
668 "string" => {
669 let v = self.advance().value;
670 Ok(Expr::Literal(SqlValue::String(v)))
671 }
672 "keyword" if t.value == "NULL" => {
673 self.advance();
674 Ok(Expr::Literal(SqlValue::Null))
675 }
676 "keyword" if t.value == "CASE" => {
677 self.parse_case_expr()
678 }
679 "keyword" if t.value == "CURRENT_DATE" => {
680 self.advance();
681 Ok(Expr::CurrentDate)
682 }
683 "keyword" if t.value == "CURRENT_TIMESTAMP" => {
684 self.advance();
685 Ok(Expr::CurrentTimestamp)
686 }
687 "keyword" if t.value == "DATEDIFF" => {
688 self.advance();
689 self.expect("op", Some("("))?;
690 let left = self.parse_additive()?;
691 self.expect("op", Some(","))?;
692 let right = self.parse_additive()?;
693 self.expect("op", Some(")"))?;
694 Ok(Expr::DateDiff { left: Box::new(left), right: Box::new(right) })
695 }
696 "op" if t.value == "(" => {
697 self.advance();
698 let expr = self.parse_additive()?;
699 self.expect("op", Some(")"))?;
700 Ok(expr)
701 }
702 "ident" => {
703 let name = self.advance().value;
704 Ok(Expr::Column(name))
705 }
706 "keyword" if !Self::is_reserved_keyword(&t.value) => {
707 let name = self.advance().value;
708 Ok(Expr::Column(name))
709 }
710 _ => Err(MdqlError::QueryParse(format!(
711 "Expected expression, got '{}'",
712 t.raw
713 ))),
714 }
715 }
716
717 fn parse_case_expr(&mut self) -> Result<Expr, MdqlError> {
718 self.expect("keyword", Some("CASE"))?;
719 let mut whens = Vec::new();
720 while self.match_keyword("WHEN") {
721 let condition = self.parse_or_expr()?;
722 self.expect("keyword", Some("THEN"))?;
723 let result = self.parse_additive()?;
724 whens.push((condition, Box::new(result)));
725 }
726 if whens.is_empty() {
727 return Err(MdqlError::QueryParse("CASE requires at least one WHEN clause".into()));
728 }
729 let else_expr = if self.match_keyword("ELSE") {
730 Some(Box::new(self.parse_additive()?))
731 } else {
732 None
733 };
734 self.expect("keyword", Some("END"))?;
735 Ok(Expr::Case { whens, else_expr })
736 }
737
738 fn parse_agg_expr(&mut self) -> Result<Expr, MdqlError> {
739 let func_name = self.advance().value.to_uppercase();
740 let func = match func_name.as_str() {
741 "COUNT" => AggFunc::Count,
742 "SUM" => AggFunc::Sum,
743 "AVG" => AggFunc::Avg,
744 "MIN" => AggFunc::Min,
745 "MAX" => AggFunc::Max,
746 _ => unreachable!(),
747 };
748 self.expect("op", Some("("))?;
749 let (arg, arg_expr) = if self.peek().map_or(false, |t| t.token_type == "op" && t.value == "*") {
750 self.advance();
751 ("*".to_string(), None)
752 } else {
753 let expr = self.parse_additive()?;
754 if let Expr::Column(name) = &expr {
755 (name.clone(), None)
756 } else {
757 (expr.display_name(), Some(Box::new(expr)))
758 }
759 };
760 self.expect("op", Some(")"))?;
761 Ok(Expr::Aggregate { func, arg, arg_expr })
762 }
763
764 fn parse_ident(&mut self) -> Result<String, MdqlError> {
765 let t = self.peek().ok_or_else(|| {
766 MdqlError::QueryParse("Expected identifier, got end of query".into())
767 })?;
768 match t.token_type.as_str() {
769 "ident" | "keyword" => {
770 let v = self.advance().value;
771 Ok(v)
772 }
773 _ => Err(MdqlError::QueryParse(format!(
774 "Expected identifier, got '{}'",
775 t.raw
776 ))),
777 }
778 }
779
780 fn parse_or_expr(&mut self) -> Result<WhereClause, MdqlError> {
781 let mut left = self.parse_and_expr()?;
782 while self.match_keyword("OR") {
783 let right = self.parse_and_expr()?;
784 left = WhereClause::BoolOp(BoolOp {
785 op: BoolOpKind::Or,
786 left: Box::new(left),
787 right: Box::new(right),
788 });
789 }
790 Ok(left)
791 }
792
793 fn parse_and_expr(&mut self) -> Result<WhereClause, MdqlError> {
794 let mut left = self.parse_comparison()?;
795 while self.match_keyword("AND") {
796 let right = self.parse_comparison()?;
797 left = WhereClause::BoolOp(BoolOp {
798 op: BoolOpKind::And,
799 left: Box::new(left),
800 right: Box::new(right),
801 });
802 }
803 Ok(left)
804 }
805
806 fn parse_comparison(&mut self) -> Result<WhereClause, MdqlError> {
807 if self.peek().map_or(false, |t| t.token_type == "op" && t.value == "(") {
809 let saved_pos = self.pos;
811 self.advance();
812 let result = self.parse_or_expr();
814 if result.is_ok() && self.peek().map_or(false, |t| t.token_type == "op" && t.value == ")") {
815 self.advance();
816 return result;
817 }
818 self.pos = saved_pos;
820 }
821
822 let left_expr = self.parse_additive()?;
824
825 let col = left_expr.as_column().unwrap_or("").to_string();
827
828 if self.match_keyword("IS") {
830 if self.match_keyword("NOT") {
831 self.expect("keyword", Some("NULL"))?;
832 return Ok(WhereClause::Comparison(Comparison {
833 column: col,
834 op: CmpOp::IsNotNull,
835 value: None,
836 left_expr: Some(left_expr),
837 right_expr: None,
838 }));
839 }
840 self.expect("keyword", Some("NULL"))?;
841 return Ok(WhereClause::Comparison(Comparison {
842 column: col,
843 op: CmpOp::IsNull,
844 value: None,
845 left_expr: Some(left_expr),
846 right_expr: None,
847 }));
848 }
849
850 if self.match_keyword("IN") {
852 self.expect("op", Some("("))?;
853 let mut values = vec![self.parse_value()?];
854 while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
855 self.advance();
856 values.push(self.parse_value()?);
857 }
858 self.expect("op", Some(")"))?;
859 return Ok(WhereClause::Comparison(Comparison {
860 column: col,
861 op: CmpOp::In,
862 value: Some(SqlValue::List(values)),
863 left_expr: Some(left_expr),
864 right_expr: None,
865 }));
866 }
867
868 if self.match_keyword("LIKE") {
870 let val = self.parse_value()?;
871 return Ok(WhereClause::Comparison(Comparison {
872 column: col,
873 op: CmpOp::Like,
874 value: Some(val),
875 left_expr: Some(left_expr),
876 right_expr: None,
877 }));
878 }
879
880 if self.match_keyword("NOT") {
882 if self.match_keyword("LIKE") {
883 let val = self.parse_value()?;
884 return Ok(WhereClause::Comparison(Comparison {
885 column: col,
886 op: CmpOp::NotLike,
887 value: Some(val),
888 left_expr: Some(left_expr),
889 right_expr: None,
890 }));
891 }
892 return Err(MdqlError::QueryParse("Expected LIKE after NOT".into()));
893 }
894
895 if let Some(t) = self.peek() {
897 if t.token_type == "op" && ["=", "!=", "<", ">", "<=", ">="].contains(&t.value.as_str())
898 {
899 let op_str = self.advance().value;
900 let op = match op_str.as_str() {
901 "=" => CmpOp::Eq,
902 "!=" => CmpOp::Ne,
903 "<" => CmpOp::Lt,
904 ">" => CmpOp::Gt,
905 "<=" => CmpOp::Le,
906 ">=" => CmpOp::Ge,
907 _ => unreachable!(),
908 };
909 let right_expr = self.parse_additive()?;
911 let value = match &right_expr {
913 Expr::Literal(v) => Some(v.clone()),
914 _ => None,
915 };
916 return Ok(WhereClause::Comparison(Comparison {
917 column: col,
918 op,
919 value,
920 left_expr: Some(left_expr),
921 right_expr: Some(right_expr),
922 }));
923 }
924 }
925
926 let got = self.peek().map_or("end".to_string(), |t| t.raw.clone());
927 Err(MdqlError::QueryParse(format!(
928 "Expected operator after '{}', got '{}'",
929 left_expr.display_name(), got
930 )))
931 }
932
933 fn parse_value(&mut self) -> Result<SqlValue, MdqlError> {
934 let t = self.peek().ok_or_else(|| {
935 MdqlError::QueryParse("Expected value, got end of query".into())
936 })?;
937 match t.token_type.as_str() {
938 "string" => {
939 let v = self.advance().value;
940 Ok(SqlValue::String(v))
941 }
942 "number" => {
943 let v = self.advance().value;
944 if v.contains('.') {
945 Ok(SqlValue::Float(v.parse().map_err(|_| {
946 MdqlError::QueryParse(format!("Invalid float: {}", v))
947 })?))
948 } else {
949 Ok(SqlValue::Int(v.parse().map_err(|_| {
950 MdqlError::QueryParse(format!("Invalid int: {}", v))
951 })?))
952 }
953 }
954 "keyword" if t.value == "NULL" => {
955 self.advance();
956 Ok(SqlValue::Null)
957 }
958 _ => Err(MdqlError::QueryParse(format!(
959 "Expected value, got '{}'",
960 t.raw
961 ))),
962 }
963 }
964
965 fn parse_order_by(&mut self) -> Result<Vec<OrderSpec>, MdqlError> {
966 let mut specs = vec![self.parse_order_spec()?];
967 while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
968 self.advance();
969 specs.push(self.parse_order_spec()?);
970 }
971 Ok(specs)
972 }
973
974 fn parse_order_spec(&mut self) -> Result<OrderSpec, MdqlError> {
975 let expr = self.parse_additive()?;
976 let col = expr.as_column().unwrap_or("").to_string();
977 let descending = if self.match_keyword("DESC") {
978 true
979 } else {
980 self.match_keyword("ASC");
981 false
982 };
983 Ok(OrderSpec {
984 column: col,
985 expr: Some(expr),
986 descending,
987 })
988 }
989
990 fn is_clause_keyword(&self, t: &Token) -> bool {
991 t.token_type == "keyword"
992 && ["WHERE", "ORDER", "LIMIT", "JOIN", "LEFT", "ON", "GROUP"].contains(&t.value.as_str())
993 }
994
995 fn is_reserved_keyword(kw: &str) -> bool {
997 matches!(kw,
998 "AS" | "FROM" | "WHERE" | "AND" | "OR" | "ORDER" | "BY"
999 | "ASC" | "DESC" | "LIMIT" | "JOIN" | "ON" | "GROUP"
1000 | "SELECT" | "INSERT" | "INTO" | "VALUES" | "UPDATE" | "SET"
1001 | "DELETE" | "ALTER" | "TABLE" | "IS" | "NOT" | "IN" | "LIKE"
1002 | "RENAME" | "FIELD" | "TO" | "DROP" | "MERGE" | "FIELDS"
1003 | "CASE" | "WHEN" | "THEN" | "ELSE" | "END"
1004 | "HAVING" | "INTERVAL" | "DAY" | "DAYS"
1005 | "CURRENT_DATE" | "CURRENT_TIMESTAMP" | "DATEDIFF"
1006 | "CREATE" | "VIEW" | "CASCADE" | "RESTRICT"
1007 )
1008 }
1009
1010 fn expect_end(&self) -> Result<(), MdqlError> {
1011 if let Some(t) = self.peek() {
1012 return Err(MdqlError::QueryParse(format!(
1013 "Unexpected token '{}' at position {}",
1014 t.raw, self.pos
1015 )));
1016 }
1017 Ok(())
1018 }
1019}
1020
1021pub fn parse_query(sql: &str) -> crate::errors::Result<Statement> {
1022 let tokens = tokenize(sql);
1023 if tokens.is_empty() {
1024 return Err(MdqlError::QueryParse("Empty query".into()));
1025 }
1026 let mut parser = Parser::new(tokens);
1027 parser.parse_statement()
1028}
1029
1030#[cfg(test)]
1031mod tests {
1032 use super::*;
1033
1034 #[test]
1035 fn test_simple_select() {
1036 let stmt = parse_query("SELECT title, status FROM strategies").unwrap();
1037 if let Statement::Select(q) = stmt {
1038 assert_eq!(q.columns, ColumnList::Named(vec![SelectExpr::Column("title".into()), SelectExpr::Column("status".into())]));
1039 assert_eq!(q.table, "strategies");
1040 } else {
1041 panic!("Expected Select");
1042 }
1043 }
1044
1045 #[test]
1046 fn test_select_star() {
1047 let stmt = parse_query("SELECT * FROM test").unwrap();
1048 if let Statement::Select(q) = stmt {
1049 assert_eq!(q.columns, ColumnList::All);
1050 } else {
1051 panic!("Expected Select");
1052 }
1053 }
1054
1055 #[test]
1056 fn test_where_clause() {
1057 let stmt = parse_query("SELECT title FROM test WHERE count > 5").unwrap();
1058 if let Statement::Select(q) = stmt {
1059 assert!(q.where_clause.is_some());
1060 } else {
1061 panic!("Expected Select");
1062 }
1063 }
1064
1065 #[test]
1066 fn test_order_by() {
1067 let stmt =
1068 parse_query("SELECT title FROM test ORDER BY composite DESC, title ASC").unwrap();
1069 if let Statement::Select(q) = stmt {
1070 let ob = q.order_by.unwrap();
1071 assert_eq!(ob.len(), 2);
1072 assert!(ob[0].descending);
1073 assert!(!ob[1].descending);
1074 } else {
1075 panic!("Expected Select");
1076 }
1077 }
1078
1079 #[test]
1080 fn test_limit() {
1081 let stmt = parse_query("SELECT * FROM test LIMIT 10").unwrap();
1082 if let Statement::Select(q) = stmt {
1083 assert_eq!(q.limit, Some(10));
1084 } else {
1085 panic!("Expected Select");
1086 }
1087 }
1088
1089 #[test]
1090 fn test_insert() {
1091 let stmt = parse_query(
1092 "INSERT INTO test (title, count) VALUES ('Hello', 42)",
1093 )
1094 .unwrap();
1095 if let Statement::Insert(q) = stmt {
1096 assert_eq!(q.table, "test");
1097 assert_eq!(q.columns, vec!["title", "count"]);
1098 assert_eq!(q.values[0], SqlValue::String("Hello".into()));
1099 assert_eq!(q.values[1], SqlValue::Int(42));
1100 } else {
1101 panic!("Expected Insert");
1102 }
1103 }
1104
1105 #[test]
1106 fn test_update() {
1107 let stmt = parse_query("UPDATE test SET status = 'KILLED' WHERE path = 'a.md'").unwrap();
1108 if let Statement::Update(q) = stmt {
1109 assert_eq!(q.table, "test");
1110 assert_eq!(q.assignments.len(), 1);
1111 assert!(q.where_clause.is_some());
1112 } else {
1113 panic!("Expected Update");
1114 }
1115 }
1116
1117 #[test]
1118 fn test_delete() {
1119 let stmt = parse_query("DELETE FROM test WHERE status = 'draft'").unwrap();
1120 if let Statement::Delete(q) = stmt {
1121 assert_eq!(q.table, "test");
1122 assert!(q.where_clause.is_some());
1123 } else {
1124 panic!("Expected Delete");
1125 }
1126 }
1127
1128 #[test]
1129 fn test_alter_rename() {
1130 let stmt =
1131 parse_query("ALTER TABLE test RENAME FIELD 'Summary' TO 'Overview'").unwrap();
1132 if let Statement::AlterRename(q) = stmt {
1133 assert_eq!(q.old_name, "Summary");
1134 assert_eq!(q.new_name, "Overview");
1135 } else {
1136 panic!("Expected AlterRename");
1137 }
1138 }
1139
1140 #[test]
1141 fn test_alter_drop() {
1142 let stmt = parse_query("ALTER TABLE test DROP FIELD 'Details'").unwrap();
1143 if let Statement::AlterDrop(q) = stmt {
1144 assert_eq!(q.field_name, "Details");
1145 } else {
1146 panic!("Expected AlterDrop");
1147 }
1148 }
1149
1150 #[test]
1151 fn test_alter_merge() {
1152 let stmt = parse_query(
1153 "ALTER TABLE test MERGE FIELDS 'Entry Rules', 'Exit Rules' INTO 'Trading Rules'",
1154 )
1155 .unwrap();
1156 if let Statement::AlterMerge(q) = stmt {
1157 assert_eq!(q.sources, vec!["Entry Rules", "Exit Rules"]);
1158 assert_eq!(q.into, "Trading Rules");
1159 } else {
1160 panic!("Expected AlterMerge");
1161 }
1162 }
1163
1164 #[test]
1165 fn test_backtick_ident() {
1166 let stmt = parse_query("SELECT `Structural Mechanism` FROM test").unwrap();
1167 if let Statement::Select(q) = stmt {
1168 assert_eq!(
1169 q.columns,
1170 ColumnList::Named(vec![SelectExpr::Column("Structural Mechanism".into())])
1171 );
1172 } else {
1173 panic!("Expected Select");
1174 }
1175 }
1176
1177 #[test]
1178 fn test_like_operator() {
1179 let stmt = parse_query("SELECT title FROM test WHERE categories LIKE '%defi%'").unwrap();
1180 if let Statement::Select(q) = stmt {
1181 if let Some(WhereClause::Comparison(c)) = q.where_clause {
1182 assert_eq!(c.op, CmpOp::Like);
1183 assert_eq!(c.value, Some(SqlValue::String("%defi%".into())));
1184 } else {
1185 panic!("Expected LIKE comparison");
1186 }
1187 } else {
1188 panic!("Expected Select");
1189 }
1190 }
1191
/// An `IN (…)` predicate becomes a single comparison with `CmpOp::In`.
#[test]
fn test_in_operator() {
    let parsed = parse_query("SELECT * FROM test WHERE status IN ('ACTIVE', 'LIVE')").unwrap();
    let Statement::Select(select) = parsed else {
        panic!("Expected Select");
    };
    let Some(WhereClause::Comparison(cmp)) = select.where_clause else {
        panic!("Expected IN comparison");
    };
    assert_eq!(cmp.op, CmpOp::In);
}
1206
/// An `IS NULL` predicate becomes a single comparison with `CmpOp::IsNull`.
#[test]
fn test_is_null() {
    let parsed = parse_query("SELECT * FROM test WHERE title IS NULL").unwrap();
    let Statement::Select(select) = parsed else {
        panic!("Expected Select");
    };
    let Some(WhereClause::Comparison(cmp)) = select.where_clause else {
        panic!("Expected IS NULL comparison");
    };
    assert_eq!(cmp.op, CmpOp::IsNull);
}
1220
/// A WHERE clause mixing `AND` and `OR` parses and yields some where
/// clause; the tree shape is not inspected here.
#[test]
fn test_and_or() {
    let sql = "SELECT * FROM test WHERE status = 'ACTIVE' AND count > 5 OR title LIKE '%test%'";
    let parsed = parse_query(sql).unwrap();
    let Statement::Select(select) = parsed else {
        panic!("Expected Select");
    };
    assert!(select.where_clause.is_some());
}
1233
/// A single plain `JOIN` with table aliases.
///
/// Checks the base table and its alias, then the one join entry: its type
/// (a bare `JOIN` is an inner join), joined table, alias, and both columns
/// of the `ON` condition. The `ON` columns and join type are verified here
/// as well so the single-join path gets the same coverage as the multi-join
/// tests.
#[test]
fn test_join() {
    let parsed = parse_query(
        "SELECT s.title, b.sharpe FROM strategies s JOIN backtests b ON b.strategy = s.path",
    )
    .unwrap();
    let Statement::Select(select) = parsed else {
        panic!("Expected Select");
    };

    assert_eq!(select.table, "strategies");
    assert_eq!(select.table_alias, Some("s".into()));
    assert_eq!(select.joins.len(), 1);

    let join = &select.joins[0];
    assert_eq!(join.join_type, JoinType::Inner);
    assert_eq!(join.table, "backtests");
    assert_eq!(join.alias, Some("b".into()));
    // The ON clause is written as `<joined>.col = <base>.col`.
    assert_eq!(join.left_col, "b.strategy");
    assert_eq!(join.right_col, "s.path");
}
1251
/// Two chained `JOIN`s are recorded in source order, each with its own
/// table, alias and `ON` column pair.
#[test]
fn test_multi_join() {
    let parsed = parse_query(
        "SELECT s.title, b.sharpe, c.verdict FROM strategies s JOIN backtests b ON b.strategy = s.path JOIN critiques c ON c.strategy = s.path",
    )
    .unwrap();
    let Statement::Select(select) = parsed else {
        panic!("Expected Select");
    };

    assert_eq!(select.table, "strategies");
    assert_eq!(select.table_alias, Some("s".into()));
    assert_eq!(select.joins.len(), 2);

    let first = &select.joins[0];
    assert_eq!(first.table, "backtests");
    assert_eq!(first.alias, Some("b".into()));
    assert_eq!(first.left_col, "b.strategy");
    assert_eq!(first.right_col, "s.path");

    let second = &select.joins[1];
    assert_eq!(second.table, "critiques");
    assert_eq!(second.alias, Some("c".into()));
    assert_eq!(second.left_col, "c.strategy");
    assert_eq!(second.right_col, "s.path");
}
1274
/// `LEFT JOIN` produces a single join entry tagged `JoinType::Left`.
#[test]
fn test_left_join() {
    let parsed = parse_query(
        "SELECT s.title, b.sharpe FROM strategies s LEFT JOIN backtests b ON b.strategy = s.path",
    )
    .unwrap();
    let Statement::Select(select) = parsed else {
        panic!("Expected Select");
    };
    assert_eq!(select.joins.len(), 1);
    let join = &select.joins[0];
    assert_eq!(join.join_type, JoinType::Left);
    assert_eq!(join.table, "backtests");
}
1289
/// A plain `JOIN` followed by a `LEFT JOIN` keeps each join's own type:
/// inner first, then left.
#[test]
fn test_mixed_join_types() {
    let parsed = parse_query(
        "SELECT s.title FROM strategies s JOIN backtests b ON b.strategy = s.path LEFT JOIN allocations a ON a.strategy = s.path",
    )
    .unwrap();
    let Statement::Select(select) = parsed else {
        panic!("Expected Select");
    };
    assert_eq!(select.joins.len(), 2);
    assert_eq!(select.joins[0].join_type, JoinType::Inner);
    assert_eq!(select.joins[1].join_type, JoinType::Left);
}
1304
/// An empty input string is rejected with an error rather than parsed.
#[test]
fn test_empty_query() {
    let outcome = parse_query("");
    assert!(outcome.is_err());
}
1309
/// `COUNT(*) AS cnt` next to a plain column: the aggregate keeps `*` as
/// its argument and `cnt` as its alias, and `GROUP BY` is captured.
#[test]
fn test_count_star() {
    let parsed =
        parse_query("SELECT status, COUNT(*) AS cnt FROM strategies GROUP BY status").unwrap();
    let Statement::Select(select) = parsed else {
        panic!("Expected Select");
    };
    let ColumnList::Named(exprs) = &select.columns else {
        panic!("Expected Named columns");
    };
    assert_eq!(exprs.len(), 2);
    assert_eq!(exprs[0], SelectExpr::Column("status".into()));
    assert!(matches!(
        &exprs[1],
        SelectExpr::Aggregate {
            func: AggFunc::Count,
            arg,
            alias: Some(a),
            ..
        } if arg == "*" && a == "cnt"
    ));
    assert_eq!(select.group_by, Some(vec!["status".into()]));
}
1331
/// `count` used as an INSERT column name is treated as an ordinary
/// identifier, not as the start of an aggregate call.
#[test]
fn test_count_column_as_ident() {
    let parsed = parse_query("INSERT INTO test (title, count) VALUES ('Hello', 42)").unwrap();
    let Statement::Insert(insert) = parsed else {
        panic!("Expected Insert");
    };
    assert_eq!(insert.columns, vec!["title", "count"]);
}
1342
/// Several aggregates in one select list each map to their own
/// `AggFunc`, and no `GROUP BY` is recorded when none is written.
#[test]
fn test_multiple_aggregates() {
    let parsed =
        parse_query("SELECT MIN(composite), MAX(composite), AVG(composite) FROM strategies")
            .unwrap();
    let Statement::Select(select) = parsed else {
        panic!("Expected Select");
    };
    let ColumnList::Named(exprs) = &select.columns else {
        panic!("Expected Named columns");
    };
    assert_eq!(exprs.len(), 3);
    assert!(matches!(
        &exprs[0],
        SelectExpr::Aggregate { func: AggFunc::Min, .. }
    ));
    assert!(matches!(
        &exprs[1],
        SelectExpr::Aggregate { func: AggFunc::Max, .. }
    ));
    assert!(matches!(
        &exprs[2],
        SelectExpr::Aggregate { func: AggFunc::Avg, .. }
    ));
    assert_eq!(select.group_by, None);
}
1360
/// `a + b` in the select list becomes an un-aliased expression column
/// whose root is an addition.
#[test]
fn test_select_arithmetic_expr() {
    let parsed = parse_query("SELECT a + b FROM test").unwrap();
    let Statement::Select(select) = parsed else {
        panic!("Expected Select");
    };
    let ColumnList::Named(exprs) = &select.columns else {
        panic!("Expected Named columns");
    };
    assert_eq!(exprs.len(), 1);
    assert!(matches!(
        &exprs[0],
        SelectExpr::Expr {
            expr: Expr::BinaryOp { op: ArithOp::Add, .. },
            alias: None,
        }
    ));
}
1380
/// An aliased arithmetic expression stores the alias and reports it as
/// the column's output name.
#[test]
fn test_select_arithmetic_with_alias() {
    let parsed = parse_query("SELECT a + b AS total FROM test").unwrap();
    let Statement::Select(select) = parsed else {
        panic!("Expected Select");
    };
    let ColumnList::Named(exprs) = &select.columns else {
        panic!("Expected Named columns");
    };
    assert_eq!(exprs.len(), 1);
    assert!(matches!(
        &exprs[0],
        SelectExpr::Expr { alias: Some(a), .. } if a == "total"
    ));
    assert_eq!(exprs[0].output_name(), "total");
}
1399
/// Multiplication binds tighter than addition: `a + b * c` has `+` at
/// the root, column `a` on the left and the `*` subtree on the right.
#[test]
fn test_select_precedence() {
    let parsed = parse_query("SELECT a + b * c FROM test").unwrap();
    let Statement::Select(select) = parsed else {
        panic!("Expected Select");
    };
    let ColumnList::Named(exprs) = &select.columns else {
        panic!("Expected Named columns");
    };
    let SelectExpr::Expr { expr, .. } = &exprs[0] else {
        panic!("Expected Expr variant");
    };
    let Expr::BinaryOp { left, op, right } = expr else {
        panic!("Expected BinaryOp");
    };
    assert_eq!(*op, ArithOp::Add);
    assert!(matches!(left.as_ref(), Expr::Column(n) if n == "a"));
    assert!(matches!(
        right.as_ref(),
        Expr::BinaryOp { op: ArithOp::Mul, .. }
    ));
}
1424
/// Parentheses override precedence: `(a + b) * c` has `*` at the root
/// with the addition as its left operand.
#[test]
fn test_select_parenthesized_expr() {
    let parsed = parse_query("SELECT (a + b) * c FROM test").unwrap();
    let Statement::Select(select) = parsed else {
        panic!("Expected Select");
    };
    let ColumnList::Named(exprs) = &select.columns else {
        panic!("Expected Named columns");
    };
    let SelectExpr::Expr { expr, .. } = &exprs[0] else {
        panic!("Expected Expr variant");
    };
    let Expr::BinaryOp { left, op, .. } = expr else {
        panic!("Expected BinaryOp");
    };
    assert_eq!(*op, ArithOp::Mul);
    assert!(matches!(
        left.as_ref(),
        Expr::BinaryOp { op: ArithOp::Add, .. }
    ));
}
1448
/// A leading minus on a column name parses as `Expr::UnaryMinus`.
#[test]
fn test_select_unary_minus() {
    let parsed = parse_query("SELECT -count FROM test").unwrap();
    let Statement::Select(select) = parsed else {
        panic!("Expected Select");
    };
    let ColumnList::Named(exprs) = &select.columns else {
        panic!("Expected Named columns");
    };
    assert!(matches!(
        &exprs[0],
        SelectExpr::Expr {
            expr: Expr::UnaryMinus(_),
            ..
        }
    ));
}
1465
/// A minus directly in front of an integer literal is folded into a
/// negative `SqlValue::Int` literal instead of a unary-minus node.
#[test]
fn test_select_negative_literal() {
    let parsed = parse_query("SELECT -42 FROM test").unwrap();
    let Statement::Select(select) = parsed else {
        panic!("Expected Select");
    };
    let ColumnList::Named(exprs) = &select.columns else {
        panic!("Expected Named columns");
    };
    assert!(matches!(
        &exprs[0],
        SelectExpr::Expr {
            expr: Expr::Literal(SqlValue::Int(-42)),
            ..
        }
    ));
}
1483
/// Arithmetic on the left of a WHERE comparison is kept as `left_expr`,
/// and the integer on the right as a literal `right_expr`.
#[test]
fn test_where_arithmetic_expr() {
    let parsed = parse_query("SELECT * FROM test WHERE a + b > 10").unwrap();
    let Statement::Select(select) = parsed else {
        panic!("Expected Select");
    };
    let Some(WhereClause::Comparison(cmp)) = select.where_clause else {
        panic!("Expected comparison");
    };
    assert_eq!(cmp.op, CmpOp::Gt);
    assert!(matches!(
        &cmp.left_expr,
        Some(Expr::BinaryOp { op: ArithOp::Add, .. })
    ));
    assert!(matches!(
        &cmp.right_expr,
        Some(Expr::Literal(SqlValue::Int(10)))
    ));
}
1499
/// Both operands of a WHERE comparison may be arithmetic expressions:
/// a multiplication on the left and an addition on the right.
#[test]
fn test_where_both_sides_expr() {
    let parsed = parse_query("SELECT * FROM test WHERE a * 2 > b + 1").unwrap();
    let Statement::Select(select) = parsed else {
        panic!("Expected Select");
    };
    let Some(WhereClause::Comparison(cmp)) = select.where_clause else {
        panic!("Expected comparison");
    };
    assert_eq!(cmp.op, CmpOp::Gt);
    assert!(matches!(
        &cmp.left_expr,
        Some(Expr::BinaryOp { op: ArithOp::Mul, .. })
    ));
    assert!(matches!(
        &cmp.right_expr,
        Some(Expr::BinaryOp { op: ArithOp::Add, .. })
    ));
}
1515
/// `ORDER BY` accepts an arithmetic expression together with `DESC`.
#[test]
fn test_order_by_expr() {
    let parsed = parse_query("SELECT * FROM test ORDER BY a + b DESC").unwrap();
    let Statement::Select(select) = parsed else {
        panic!("Expected Select");
    };
    let ordering = select.order_by.unwrap();
    assert_eq!(ordering.len(), 1);
    let key = &ordering[0];
    assert!(key.descending);
    assert!(matches!(
        &key.expr,
        Some(Expr::BinaryOp { op: ArithOp::Add, .. })
    ));
}
1528
/// Each of the five arithmetic operators (`+ - * / %`) maps to its own
/// `ArithOp` variant, in select-list order.
#[test]
fn test_all_arithmetic_ops() {
    let parsed = parse_query("SELECT a + b, a - b, a * b, a / b, a % b FROM test").unwrap();
    let Statement::Select(select) = parsed else {
        panic!("Expected Select");
    };
    let ColumnList::Named(exprs) = &select.columns else {
        panic!("Expected Named columns");
    };
    assert_eq!(exprs.len(), 5);
    assert!(matches!(
        &exprs[0],
        SelectExpr::Expr { expr: Expr::BinaryOp { op: ArithOp::Add, .. }, .. }
    ));
    assert!(matches!(
        &exprs[1],
        SelectExpr::Expr { expr: Expr::BinaryOp { op: ArithOp::Sub, .. }, .. }
    ));
    assert!(matches!(
        &exprs[2],
        SelectExpr::Expr { expr: Expr::BinaryOp { op: ArithOp::Mul, .. }, .. }
    ));
    assert!(matches!(
        &exprs[3],
        SelectExpr::Expr { expr: Expr::BinaryOp { op: ArithOp::Div, .. }, .. }
    ));
    assert!(matches!(
        &exprs[4],
        SelectExpr::Expr { expr: Expr::BinaryOp { op: ArithOp::Mod, .. }, .. }
    ));
}
1547
/// `count * 2 + 1` mixes a column with integer literals: the root is an
/// addition of the `*` subtree (left) and the literal `1` (right).
#[test]
fn test_column_with_literal_arithmetic() {
    let parsed = parse_query("SELECT count * 2 + 1 FROM test").unwrap();
    let Statement::Select(select) = parsed else {
        panic!("Expected Select");
    };
    let ColumnList::Named(exprs) = &select.columns else {
        panic!("Expected Named columns");
    };
    let SelectExpr::Expr { expr, .. } = &exprs[0] else {
        panic!("Expected Expr");
    };
    let Expr::BinaryOp { left, op, right } = expr else {
        panic!("Expected BinaryOp");
    };
    assert_eq!(*op, ArithOp::Add);
    assert!(matches!(right.as_ref(), Expr::Literal(SqlValue::Int(1))));
    assert!(matches!(
        left.as_ref(),
        Expr::BinaryOp { op: ArithOp::Mul, .. }
    ));
}
1572
/// Plain columns and an aliased expression can share one select list;
/// bare names stay `SelectExpr::Column`, including `count`.
#[test]
fn test_mixed_columns_and_exprs() {
    let parsed = parse_query("SELECT title, a + b AS sum, count FROM test").unwrap();
    let Statement::Select(select) = parsed else {
        panic!("Expected Select");
    };
    let ColumnList::Named(exprs) = &select.columns else {
        panic!("Expected Named columns");
    };
    assert_eq!(exprs.len(), 3);
    assert_eq!(exprs[0], SelectExpr::Column("title".into()));
    assert!(matches!(
        &exprs[1],
        SelectExpr::Expr { alias: Some(a), .. } if a == "sum"
    ));
    assert_eq!(exprs[2], SelectExpr::Column("count".into()));
}
1589
/// A one-branch `CASE WHEN … THEN … ELSE … END` in the select list
/// parses into a single `Expr::Case` expression column.
#[test]
fn test_case_when_basic() {
    let sql = "SELECT CASE WHEN status = 'ACTIVE' THEN 1 ELSE 0 END FROM test";
    let parsed = parse_query(sql).unwrap();
    let Statement::Select(select) = parsed else {
        panic!("Expected Select");
    };
    let ColumnList::Named(exprs) = &select.columns else {
        panic!("Expected Named columns");
    };
    assert_eq!(exprs.len(), 1);
    assert!(matches!(
        &exprs[0],
        SelectExpr::Expr {
            expr: Expr::Case { .. },
            ..
        }
    ));
}
1611
/// A `CASE` with two `WHEN` arms and an `ELSE` records both arms and
/// the fallback expression.
#[test]
fn test_case_when_multiple_branches() {
    let sql =
        "SELECT CASE WHEN x > 10 THEN 'high' WHEN x > 5 THEN 'mid' ELSE 'low' END FROM test";
    let parsed = parse_query(sql).unwrap();
    let Statement::Select(select) = parsed else {
        panic!("Expected Select");
    };
    let ColumnList::Named(exprs) = &select.columns else {
        panic!("Expected Named columns");
    };
    let SelectExpr::Expr {
        expr: Expr::Case { whens, else_expr },
        ..
    } = &exprs[0]
    else {
        panic!("Expected Case expression");
    };
    assert_eq!(whens.len(), 2);
    assert!(else_expr.is_some());
}
1632
/// A `CASE` without an `ELSE` arm leaves `else_expr` empty.
#[test]
fn test_case_when_no_else() {
    let sql = "SELECT CASE WHEN x = 1 THEN 'one' END FROM test";
    let parsed = parse_query(sql).unwrap();
    let Statement::Select(select) = parsed else {
        panic!("Expected Select");
    };
    let ColumnList::Named(exprs) = &select.columns else {
        panic!("Expected Named columns");
    };
    let SelectExpr::Expr {
        expr: Expr::Case { whens, else_expr },
        ..
    } = &exprs[0]
    else {
        panic!("Expected Case expression");
    };
    assert_eq!(whens.len(), 1);
    assert!(else_expr.is_none());
}
1653
/// A `CASE` expression may be the argument of an aggregate: `SUM(CASE …)`
/// keeps the case tree in `arg_expr` and the alias on the aggregate.
#[test]
fn test_case_when_in_aggregate() {
    let sql = "SELECT SUM(CASE WHEN side = 'BUY' THEN size ELSE -size END) AS net FROM orders GROUP BY token";
    let parsed = parse_query(sql).unwrap();
    let Statement::Select(select) = parsed else {
        panic!("Expected Select");
    };
    let ColumnList::Named(exprs) = &select.columns else {
        panic!("Expected Named columns");
    };
    assert_eq!(exprs.len(), 1);
    assert!(matches!(
        &exprs[0],
        SelectExpr::Aggregate {
            func: AggFunc::Sum,
            arg_expr: Some(Expr::Case { .. }),
            alias: Some(a),
            ..
        } if a == "net"
    ));
}
1675
/// `CASE … END AS sign` attaches the alias to the case expression column.
#[test]
fn test_case_when_with_alias() {
    let sql = "SELECT CASE WHEN x > 0 THEN 'pos' ELSE 'neg' END AS sign FROM test";
    let parsed = parse_query(sql).unwrap();
    let Statement::Select(select) = parsed else {
        panic!("Expected Select");
    };
    let ColumnList::Named(exprs) = &select.columns else {
        panic!("Expected Named columns");
    };
    assert!(matches!(
        &exprs[0],
        SelectExpr::Expr {
            expr: Expr::Case { .. },
            alias: Some(a),
        } if a == "sign"
    ));
}
1694
/// `CREATE VIEW name AS SELECT …` without a column list: the view name
/// is captured, `columns` is absent, and the inner query is parsed.
#[test]
fn test_create_view() {
    let parsed =
        parse_query("CREATE VIEW live AS SELECT * FROM strategies WHERE status = 'LIVE'").unwrap();
    let view = match parsed {
        Statement::CreateView(view) => view,
        other => panic!("Expected CreateView, got {:?}", other),
    };
    assert_eq!(view.view_name, "live");
    assert!(view.columns.is_none());
    assert_eq!(view.query.table, "strategies");
    assert!(view.query.where_clause.is_some());
}
1707
/// `CREATE VIEW name (a, b) AS …` records the explicit column list.
#[test]
fn test_create_view_with_columns() {
    let parsed = parse_query("CREATE VIEW v1 (a, b) AS SELECT title, status FROM t").unwrap();
    let Statement::CreateView(view) = parsed else {
        panic!("Expected CreateView");
    };
    assert_eq!(view.view_name, "v1");
    assert_eq!(view.columns, Some(vec!["a".into(), "b".into()]));
}
1718
/// `DROP VIEW name` parses into `Statement::DropView` with that name.
#[test]
fn test_drop_view() {
    let parsed = parse_query("DROP VIEW live").unwrap();
    let dropped = match parsed {
        Statement::DropView(dropped) => dropped,
        other => panic!("Expected DropView, got {:?}", other),
    };
    assert_eq!(dropped.view_name, "live");
}
1728
/// Keywords are matched case-insensitively, while the view name keeps
/// the casing it was written with.
#[test]
fn test_create_view_case_insensitive() {
    let parsed = parse_query("create view My_View as select * from t").unwrap();
    let Statement::CreateView(view) = parsed else {
        panic!("Expected CreateView");
    };
    assert_eq!(view.view_name, "My_View");
}
1738
/// Dividing one aggregate by another (`SUM(x) / SUM(y)`) is accepted in
/// the select list and the resulting column reports itself as aggregate.
#[test]
fn test_aggregate_division() {
    let sql = "SELECT token, SUM(sell) / SUM(buy) as ratio FROM orders GROUP BY token";
    let parsed = parse_query(sql).unwrap();
    let Statement::Select(select) = parsed else {
        panic!("Expected Select");
    };
    assert_eq!(select.group_by, Some(vec!["token".into()]));
    let ColumnList::Named(exprs) = &select.columns else {
        panic!("Expected Named columns");
    };
    assert_eq!(exprs.len(), 2);
    assert!(exprs[1].is_aggregate());
}
1758
/// Subtracting one aggregate from another (`SUM(x) - SUM(y) as net`).
///
/// Checks that the select list holds two named columns and that the second
/// one takes its output name from the alias and is recognised as an
/// aggregate. A column list that is not `ColumnList::Named` is a hard
/// failure, so the test can no longer pass without running any assertion.
#[test]
fn test_aggregate_subtraction() {
    let sql = "SELECT token, SUM(sell) - SUM(buy) as net FROM orders GROUP BY token";
    let parsed = parse_query(sql).unwrap();
    let Statement::Select(select) = parsed else {
        panic!("Expected Select");
    };
    let ColumnList::Named(exprs) = &select.columns else {
        panic!("Expected Named columns");
    };
    // Guard the index below so a short list fails with a readable message.
    assert_eq!(exprs.len(), 2);
    assert_eq!(exprs[1].output_name(), "net");
    assert!(exprs[1].is_aggregate());
}
1772
/// `CREATE VIEW` accepts a defining query whose select list contains
/// arithmetic between aggregates.
#[test]
fn test_create_view_with_arithmetic() {
    let sql = "CREATE VIEW positions AS SELECT token, SUM(sell) / SUM(buy) as ratio FROM orders GROUP BY token";
    let parsed = parse_query(sql).unwrap();
    let view = match parsed {
        Statement::CreateView(view) => view,
        other => panic!("Expected CreateView, got {:?}", other),
    };
    assert_eq!(view.view_name, "positions");
}
1784
/// A parenthesised SELECT in the FROM position is stored as `subquery`;
/// the outer `LIMIT` stays on the outer query and the inner query keeps
/// its own table and `GROUP BY`.
#[test]
fn test_subquery_in_from() {
    let sql = "SELECT token, sell_size FROM (SELECT token, SUM(size) as sell_size FROM orders GROUP BY token) LIMIT 5";
    let parsed = parse_query(sql).unwrap();
    let Statement::Select(outer) = parsed else {
        panic!("Expected Select");
    };
    assert!(outer.subquery.is_some());
    assert_eq!(outer.limit, Some(5));
    let inner = outer.subquery.unwrap();
    assert_eq!(inner.table, "orders");
    assert!(inner.group_by.is_some());
}
1802
/// `CREATE VIEW` accepts a defining query with a `HAVING` clause, and
/// that clause is kept on the stored query.
#[test]
fn test_create_view_with_having() {
    let sql = "CREATE VIEW positions AS SELECT token, SUM(sell) as sell_size, SUM(buy) as buy_size FROM orders GROUP BY token HAVING sell_size > buy_size";
    let parsed = parse_query(sql).unwrap();
    let view = match parsed {
        Statement::CreateView(view) => view,
        other => panic!("Expected CreateView, got {:?}", other),
    };
    assert_eq!(view.view_name, "positions");
    assert!(view.query.having.is_some());
}
1817
/// An aggregate multiplied by a literal (`SUM(a) * 2 as doubled`) is
/// still treated as an aggregate column and named by its alias.
#[test]
fn test_aggregate_multiplication() {
    let sql = "SELECT SUM(a) * 2 as doubled FROM test";
    let parsed = parse_query(sql).unwrap();
    let Statement::Select(select) = parsed else {
        panic!("Expected Select");
    };
    let ColumnList::Named(exprs) = &select.columns else {
        panic!("Expected Named columns");
    };
    assert_eq!(exprs.len(), 1);
    let column = &exprs[0];
    assert!(column.is_aggregate());
    assert_eq!(column.output_name(), "doubled");
}
1837
/// Division between two `SUM(CASE …)` aggregates parses as one aliased
/// aggregate column, with the `GROUP BY` column captured.
#[test]
fn test_complex_aggregate_arithmetic() {
    let sql = "SELECT SUM(CASE WHEN side = 'SELL' THEN size ELSE 0 END) / SUM(CASE WHEN side = 'BUY' THEN size ELSE 0 END) as ratio FROM orders GROUP BY token";
    let parsed = parse_query(sql).unwrap();
    let Statement::Select(select) = parsed else {
        panic!("Expected Select");
    };
    let ColumnList::Named(exprs) = &select.columns else {
        panic!("Expected Named columns");
    };
    assert_eq!(exprs.len(), 1);
    let column = &exprs[0];
    assert!(column.is_aggregate());
    assert_eq!(column.output_name(), "ratio");
    assert_eq!(select.group_by, Some(vec!["token".into()]));
}
1856
/// A FROM-subquery followed by an alias (`(…) sub`) still parses: the
/// inner query is stored and the outer select list is unaffected.
#[test]
fn test_subquery_with_alias() {
    let sql = "SELECT x FROM (SELECT x FROM t) sub";
    let parsed = parse_query(sql).unwrap();
    let Statement::Select(outer) = parsed else {
        panic!("Expected Select");
    };
    assert!(outer.subquery.is_some());
    let inner = outer.subquery.unwrap();
    assert_eq!(inner.table, "t");
    let ColumnList::Named(exprs) = &outer.columns else {
        panic!("Expected Named columns");
    };
    assert_eq!(exprs.len(), 1);
    assert_eq!(exprs[0].output_name(), "x");
}
1878
/// A FROM-subquery may carry its own WHERE clause, while the outer
/// `LIMIT` is applied to the outer query.
#[test]
fn test_subquery_with_where() {
    let sql = "SELECT x FROM (SELECT x FROM t WHERE y > 0) LIMIT 5";
    let parsed = parse_query(sql).unwrap();
    let Statement::Select(outer) = parsed else {
        panic!("Expected Select");
    };
    assert!(outer.subquery.is_some());
    assert_eq!(outer.limit, Some(5));
    let inner = outer.subquery.unwrap();
    assert_eq!(inner.table, "t");
    assert!(inner.where_clause.is_some());
}
1894
/// `CREATE VIEW` over a query that subtracts two aggregates: the view
/// name, `GROUP BY`, and the aliased aggregate column are all preserved.
#[test]
fn test_create_view_aggregate_subtraction() {
    let sql =
        "CREATE VIEW v AS SELECT token, SUM(sell) - SUM(buy) as net FROM orders GROUP BY token";
    let parsed = parse_query(sql).unwrap();
    let view = match parsed {
        Statement::CreateView(view) => view,
        other => panic!("Expected CreateView, got {:?}", other),
    };
    assert_eq!(view.view_name, "v");
    assert_eq!(view.query.group_by, Some(vec!["token".into()]));
    let ColumnList::Named(exprs) = &view.query.columns else {
        panic!("Expected Named columns");
    };
    assert_eq!(exprs.len(), 2);
    let net = &exprs[1];
    assert_eq!(net.output_name(), "net");
    assert!(net.is_aggregate());
}
1916}