1use regex::Regex;
4use std::sync::LazyLock;
5
6use crate::errors::MdqlError;
7pub use crate::query_ast::*;
8
/// Words the tokenizer classifies as keywords. Words are upper-cased before
/// the lookup, so matching is case-insensitive; lookups use `contains`, so the
/// order of entries carries no meaning.
static KEYWORDS: &[&str] = &[
    // core query clauses
    "SELECT", "FROM", "WHERE", "AND", "OR", "ORDER", "BY",
    // sort order, limits and predicates
    "ASC", "DESC", "LIMIT", "LIKE", "IN", "IS", "NOT", "NULL",
    // joins, aliases and grouping
    "JOIN", "LEFT", "ON", "AS", "GROUP", "HAVING",
    // data modification
    "INSERT", "INTO", "VALUES", "UPDATE", "SET", "DELETE",
    // schema changes
    "ALTER", "TABLE", "RENAME", "FIELD", "TO", "DROP", "MERGE", "FIELDS",
    // conditional expressions
    "CASE", "WHEN", "THEN", "ELSE", "END",
    // date arithmetic
    "INTERVAL", "DAY", "DAYS", "CURRENT_DATE", "CURRENT_TIMESTAMP", "DATEDIFF",
    // views and delete modes
    "CREATE", "VIEW", "CASCADE", "RESTRICT",
];
21
/// Aggregate function names (upper-case). A word matching one of these,
/// case-insensitively, is parsed as an aggregate call only when it is
/// immediately followed by `(`.
static AGG_FUNCS: &[&str] = &[
    "COUNT",
    "SUM",
    "AVG",
    "MIN",
    "MAX",
];
23
24static TOKEN_RE: LazyLock<Regex> = LazyLock::new(|| {
25 Regex::new(
26 r#"(?x)
27 \s*(?:
28 (?P<backtick>`[^`]+`)
29 | (?P<string>'(?:[^'\\]|\\.)*')
30 | (?P<number>\d+(?:\.\d+)?)
31 | (?P<op><=|>=|!=|[=<>,*()+\-/%])
32 | (?P<word>[A-Za-z_][A-Za-z0-9_./-]*)
33 )"#,
34 )
35 .unwrap()
36});
37
/// One lexical token produced by `tokenize`.
#[derive(Debug, Clone)]
struct Token {
    /// Token category, e.g. "ident", "keyword", "string", "number" or "op".
    token_type: String,
    /// Normalised text: surrounding quotes/backticks stripped for strings and
    /// quoted identifiers, upper-cased for keywords, otherwise equal to `raw`.
    value: String,
    /// Source text exactly as matched; this is what error messages quote.
    raw: String,
}
44
45fn tokenize(sql: &str) -> Vec<Token> {
46 let mut tokens = Vec::new();
47 for caps in TOKEN_RE.captures_iter(sql) {
48 if let Some(m) = caps.name("backtick") {
49 let raw = m.as_str();
50 tokens.push(Token {
51 token_type: "ident".into(),
52 value: raw[1..raw.len() - 1].into(),
53 raw: raw.into(),
54 });
55 } else if let Some(m) = caps.name("string") {
56 let raw = m.as_str();
57 tokens.push(Token {
58 token_type: "string".into(),
59 value: raw[1..raw.len() - 1].into(),
60 raw: raw.into(),
61 });
62 } else if let Some(m) = caps.name("number") {
63 let raw = m.as_str();
64 tokens.push(Token {
65 token_type: "number".into(),
66 value: raw.into(),
67 raw: raw.into(),
68 });
69 } else if let Some(m) = caps.name("op") {
70 let raw = m.as_str();
71 tokens.push(Token {
72 token_type: "op".into(),
73 value: raw.into(),
74 raw: raw.into(),
75 });
76 } else if let Some(m) = caps.name("word") {
77 let raw = m.as_str();
78 if KEYWORDS.contains(&raw.to_uppercase().as_str()) {
79 tokens.push(Token {
80 token_type: "keyword".into(),
81 value: raw.to_uppercase(),
82 raw: raw.into(),
83 });
84 } else {
85 tokens.push(Token {
86 token_type: "ident".into(),
87 value: raw.into(),
88 raw: raw.into(),
89 });
90 }
91 }
92 }
93 tokens
94}
95
96struct Parser {
99 tokens: Vec<Token>,
100 pos: usize,
101}
102
103impl Parser {
104 fn new(tokens: Vec<Token>) -> Self {
105 Parser { tokens, pos: 0 }
106 }
107
108 fn peek(&self) -> Option<&Token> {
109 self.tokens.get(self.pos)
110 }
111
112 fn advance(&mut self) -> Token {
113 let t = self.tokens[self.pos].clone();
114 self.pos += 1;
115 t
116 }
117
118 fn expect(&mut self, type_: &str, value: Option<&str>) -> Result<Token, MdqlError> {
119 let t = self.peek().ok_or_else(|| {
120 MdqlError::QueryParse(format!(
121 "Unexpected end of query, expected {}",
122 value.unwrap_or(type_)
123 ))
124 })?;
125 let matches_type = t.token_type == type_;
126 let matches_value = value.map_or(true, |v| t.value == v);
127 if !matches_type || !matches_value {
128 return Err(MdqlError::QueryParse(format!(
129 "Expected {}, got '{}' at position {}",
130 value.unwrap_or(type_),
131 t.raw,
132 self.pos
133 )));
134 }
135 Ok(self.advance())
136 }
137
138 fn match_keyword(&mut self, kw: &str) -> bool {
139 if let Some(t) = self.peek() {
140 if t.token_type == "keyword" && t.value == kw {
141 self.advance();
142 return true;
143 }
144 }
145 false
146 }
147
148 fn parse_statement(&mut self) -> Result<Statement, MdqlError> {
149 let t = self.peek().ok_or_else(|| MdqlError::QueryParse("Empty query".into()))?;
150 match (t.token_type.as_str(), t.value.as_str()) {
151 ("keyword", "SELECT") => {
152 let q = self.parse_select()?;
153 self.expect_end()?;
154 Ok(Statement::Select(q))
155 }
156 ("keyword", "INSERT") => Ok(Statement::Insert(self.parse_insert()?)),
157 ("keyword", "UPDATE") => Ok(Statement::Update(self.parse_update()?)),
158 ("keyword", "DELETE") => Ok(Statement::Delete(self.parse_delete()?)),
159 ("keyword", "ALTER") => self.parse_alter(),
160 ("keyword", "CREATE") => self.parse_create_view(),
161 ("keyword", "DROP") => self.parse_drop_view(),
162 _ => Err(MdqlError::QueryParse(format!(
163 "Expected SELECT, INSERT, UPDATE, DELETE, ALTER, CREATE, or DROP, got '{}'",
164 t.raw
165 ))),
166 }
167 }
168
169 fn parse_select(&mut self) -> Result<SelectQuery, MdqlError> {
170 self.expect("keyword", Some("SELECT"))?;
171 let columns = self.parse_columns()?;
172 self.expect("keyword", Some("FROM"))?;
173
174 let mut subquery = None;
176 let (table, mut table_alias) = if self.peek().map_or(false, |t| t.token_type == "op" && t.value == "(") {
177 self.advance();
178 let inner = self.parse_select()?;
179 self.expect("op", Some(")"))?;
180 subquery = Some(Box::new(inner));
181 let alias = if let Some(t) = self.peek() {
182 if t.token_type == "ident" && !self.is_clause_keyword(t) {
183 Some(self.advance().value)
184 } else {
185 None
186 }
187 } else {
188 None
189 };
190 ("_subquery".to_string(), alias)
191 } else {
192 let t = self.parse_ident()?;
193 (t, None)
194 };
195
196 if subquery.is_none() {
198 if let Some(t) = self.peek() {
199 if t.token_type == "ident" && !self.is_clause_keyword(t) {
200 table_alias = Some(self.advance().value);
201 }
202 }
203 }
204
205 let mut joins = Vec::new();
207 loop {
208 let jt = if self.match_keyword("LEFT") {
209 self.expect("keyword", Some("JOIN"))?;
210 JoinType::Left
211 } else if self.match_keyword("JOIN") {
212 JoinType::Inner
213 } else {
214 break;
215 };
216 let join_table = self.parse_ident()?;
217 let mut join_alias = None;
218 if let Some(t) = self.peek() {
219 if t.token_type == "ident" && !self.is_clause_keyword(t) {
220 join_alias = Some(self.advance().value);
221 }
222 }
223 self.expect("keyword", Some("ON"))?;
224 let condition = self.parse_or_expr()?;
225 joins.push(JoinClause {
226 join_type: jt,
227 table: join_table,
228 alias: join_alias,
229 condition,
230 });
231 }
232
233 let mut where_clause = None;
234 if self.match_keyword("WHERE") {
235 where_clause = Some(self.parse_or_expr()?);
236 }
237
238 let mut group_by = None;
239 if self.match_keyword("GROUP") {
240 self.expect("keyword", Some("BY"))?;
241 let mut cols = vec![self.parse_ident()?];
242 while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
243 self.advance();
244 cols.push(self.parse_ident()?);
245 }
246 group_by = Some(cols);
247 }
248
249 let mut having = None;
250 if self.match_keyword("HAVING") {
251 having = Some(self.parse_or_expr()?);
252 }
253
254 let mut order_by = None;
255 if self.match_keyword("ORDER") {
256 self.expect("keyword", Some("BY"))?;
257 order_by = Some(self.parse_order_by()?);
258 }
259
260 let mut limit = None;
261 if self.match_keyword("LIMIT") {
262 let t = self.expect("number", None)?;
263 limit = Some(t.value.parse::<i64>().map_err(|_| {
264 MdqlError::QueryParse(format!("Invalid LIMIT value: {}", t.value))
265 })?);
266 }
267
268 Ok(SelectQuery {
269 columns,
270 table,
271 table_alias,
272 subquery,
273 joins,
274 where_clause,
275 group_by,
276 having,
277 order_by,
278 limit,
279 })
280 }
281
282 fn parse_insert(&mut self) -> Result<InsertQuery, MdqlError> {
283 self.expect("keyword", Some("INSERT"))?;
284 self.expect("keyword", Some("INTO"))?;
285 let table = self.parse_ident()?;
286
287 self.expect("op", Some("("))?;
288 let mut columns = vec![self.parse_ident()?];
289 while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
290 self.advance();
291 columns.push(self.parse_ident()?);
292 }
293 self.expect("op", Some(")"))?;
294
295 self.expect("keyword", Some("VALUES"))?;
296
297 self.expect("op", Some("("))?;
298 let mut values = vec![self.parse_value()?];
299 while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
300 self.advance();
301 values.push(self.parse_value()?);
302 }
303 self.expect("op", Some(")"))?;
304
305 if columns.len() != values.len() {
306 return Err(MdqlError::QueryParse(format!(
307 "Column count ({}) does not match value count ({})",
308 columns.len(),
309 values.len()
310 )));
311 }
312
313 self.expect_end()?;
314 Ok(InsertQuery {
315 table,
316 columns,
317 values,
318 })
319 }
320
321 fn parse_update(&mut self) -> Result<UpdateQuery, MdqlError> {
322 self.expect("keyword", Some("UPDATE"))?;
323 let table = self.parse_ident()?;
324 self.expect("keyword", Some("SET"))?;
325
326 let mut assignments = Vec::new();
327 let col = self.parse_ident()?;
328 self.expect("op", Some("="))?;
329 let val = self.parse_value()?;
330 assignments.push((col, val));
331
332 while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
333 self.advance();
334 let col = self.parse_ident()?;
335 self.expect("op", Some("="))?;
336 let val = self.parse_value()?;
337 assignments.push((col, val));
338 }
339
340 let mut where_clause = None;
341 if self.match_keyword("WHERE") {
342 where_clause = Some(self.parse_or_expr()?);
343 }
344
345 self.expect_end()?;
346 Ok(UpdateQuery {
347 table,
348 assignments,
349 where_clause,
350 })
351 }
352
353 fn parse_delete(&mut self) -> Result<DeleteQuery, MdqlError> {
354 self.expect("keyword", Some("DELETE"))?;
355 self.expect("keyword", Some("FROM"))?;
356 let table = self.parse_ident()?;
357
358 let mut where_clause = None;
359 if self.match_keyword("WHERE") {
360 where_clause = Some(self.parse_or_expr()?);
361 }
362
363 let mode = if self.match_keyword("CASCADE") {
364 DeleteMode::Cascade
365 } else if self.match_keyword("RESTRICT") {
366 DeleteMode::Restrict
367 } else {
368 DeleteMode::Default
369 };
370
371 self.expect_end()?;
372 Ok(DeleteQuery {
373 table,
374 where_clause,
375 mode,
376 })
377 }
378
379 fn parse_alter(&mut self) -> Result<Statement, MdqlError> {
380 self.expect("keyword", Some("ALTER"))?;
381 self.expect("keyword", Some("TABLE"))?;
382 let table = self.parse_ident()?;
383
384 let t = self.peek().ok_or_else(|| {
385 MdqlError::QueryParse("Expected RENAME, DROP, or MERGE after table name".into())
386 })?;
387
388 match (t.token_type.as_str(), t.value.as_str()) {
389 ("keyword", "RENAME") => {
390 self.advance();
391 self.expect("keyword", Some("FIELD"))?;
392 let old_name = self.parse_string_or_ident()?;
393 self.expect("keyword", Some("TO"))?;
394 let new_name = self.parse_string_or_ident()?;
395 self.expect_end()?;
396 Ok(Statement::AlterRename(AlterRenameFieldQuery {
397 table,
398 old_name,
399 new_name,
400 }))
401 }
402 ("keyword", "DROP") => {
403 self.advance();
404 self.expect("keyword", Some("FIELD"))?;
405 let field_name = self.parse_string_or_ident()?;
406 self.expect_end()?;
407 Ok(Statement::AlterDrop(AlterDropFieldQuery {
408 table,
409 field_name,
410 }))
411 }
412 ("keyword", "MERGE") => {
413 self.advance();
414 self.expect("keyword", Some("FIELDS"))?;
415 let mut sources = vec![self.parse_string_or_ident()?];
416 while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
417 self.advance();
418 sources.push(self.parse_string_or_ident()?);
419 }
420 self.expect("keyword", Some("INTO"))?;
421 let target = self.parse_string_or_ident()?;
422 self.expect_end()?;
423 Ok(Statement::AlterMerge(AlterMergeFieldsQuery {
424 table,
425 sources,
426 into: target,
427 }))
428 }
429 _ => Err(MdqlError::QueryParse(format!(
430 "Expected RENAME, DROP, or MERGE, got '{}'",
431 t.raw
432 ))),
433 }
434 }
435
436 fn parse_create_view(&mut self) -> Result<Statement, MdqlError> {
437 self.expect("keyword", Some("CREATE"))?;
438 self.expect("keyword", Some("VIEW"))?;
439 let view_name = self.parse_ident()?;
440
441 let columns = if self.peek().map_or(false, |t| t.token_type == "op" && t.value == "(") {
442 self.advance();
443 let mut cols = vec![self.parse_ident()?];
444 while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
445 self.advance();
446 cols.push(self.parse_ident()?);
447 }
448 self.expect("op", Some(")"))?;
449 Some(cols)
450 } else {
451 None
452 };
453
454 self.expect("keyword", Some("AS"))?;
455 let query = Box::new(self.parse_select()?);
456 self.expect_end()?;
457
458 Ok(Statement::CreateView(CreateViewQuery {
459 view_name,
460 columns,
461 query,
462 }))
463 }
464
465 fn parse_drop_view(&mut self) -> Result<Statement, MdqlError> {
466 self.expect("keyword", Some("DROP"))?;
467 self.expect("keyword", Some("VIEW"))?;
468 let view_name = self.parse_ident()?;
469 self.expect_end()?;
470 Ok(Statement::DropView(DropViewQuery { view_name }))
471 }
472
473 fn parse_string_or_ident(&mut self) -> Result<String, MdqlError> {
474 let t = self.peek().ok_or_else(|| {
475 MdqlError::QueryParse("Expected field name, got end of query".into())
476 })?;
477 match t.token_type.as_str() {
478 "string" => {
479 let v = self.advance().value;
480 Ok(v)
481 }
482 "ident" | "keyword" => {
483 let v = self.advance().value;
484 Ok(v)
485 }
486 _ => Err(MdqlError::QueryParse(format!(
487 "Expected field name, got '{}'",
488 t.raw
489 ))),
490 }
491 }
492
493 fn parse_columns(&mut self) -> Result<ColumnList, MdqlError> {
494 if let Some(t) = self.peek() {
495 if t.token_type == "op" && t.value == "*" {
496 self.advance();
497 return Ok(ColumnList::All);
498 }
499 }
500
501 let mut exprs = vec![self.parse_select_expr()?];
502 while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
503 self.advance();
504 exprs.push(self.parse_select_expr()?);
505 }
506 Ok(ColumnList::Named(exprs))
507 }
508
509 fn peek_is_agg_func(&self) -> bool {
510 let t = match self.peek() {
511 Some(t) => t,
512 None => return false,
513 };
514 let name_upper = t.value.to_uppercase();
515 if !AGG_FUNCS.contains(&name_upper.as_str()) {
516 return false;
517 }
518 self.tokens
520 .get(self.pos + 1)
521 .map_or(false, |next| next.token_type == "op" && next.value == "(")
522 }
523
524 fn parse_select_expr(&mut self) -> Result<SelectExpr, MdqlError> {
525 let _t = self.peek().ok_or_else(|| {
526 MdqlError::QueryParse("Expected column or aggregate, got end of query".into())
527 })?;
528
529 let expr = self.parse_additive()?;
530
531 let alias = if self.match_keyword("AS") {
532 Some(self.parse_ident()?)
533 } else if self.peek().map_or(false, |t| {
534 t.token_type == "ident" && !self.is_clause_keyword(t)
535 }) {
536 Some(self.advance().value)
537 } else {
538 None
539 };
540
541 if let Expr::Aggregate { func, arg, arg_expr } = expr {
543 return Ok(SelectExpr::Aggregate {
544 func,
545 arg,
546 arg_expr: arg_expr.map(|e| *e),
547 alias,
548 });
549 }
550
551 if alias.is_none() {
552 if let Expr::Column(name) = &expr {
553 return Ok(SelectExpr::Column(name.clone()));
554 }
555 }
556
557 Ok(SelectExpr::Expr { expr, alias })
558 }
559
560 fn peek_is_additive_op(&self) -> bool {
563 self.peek().map_or(false, |t| {
564 t.token_type == "op" && (t.value == "+" || t.value == "-")
565 })
566 }
567
568 fn peek_is_multiplicative_op(&self) -> bool {
569 self.peek().map_or(false, |t| {
570 t.token_type == "op" && (t.value == "*" || t.value == "/" || t.value == "%")
571 })
572 }
573
574 fn parse_additive(&mut self) -> Result<Expr, MdqlError> {
575 let mut left = self.parse_multiplicative()?;
576 while self.peek_is_additive_op() {
577 let op_tok = self.advance();
578 let is_sub = op_tok.value == "-";
579
580 if self.peek().map_or(false, |t| t.token_type == "keyword" && t.value == "INTERVAL") {
582 self.advance(); let days_expr = self.parse_multiplicative()?;
584 if !self.match_keyword("DAY") && !self.match_keyword("DAYS") {
586 return Err(MdqlError::QueryParse("Expected DAY after INTERVAL value".into()));
587 }
588 let days = if is_sub {
589 Expr::UnaryMinus(Box::new(days_expr))
590 } else {
591 days_expr
592 };
593 left = Expr::DateAdd {
594 date: Box::new(left),
595 days: Box::new(days),
596 };
597 continue;
598 }
599
600 let op = match op_tok.value.as_str() {
601 "+" => ArithOp::Add,
602 "-" => ArithOp::Sub,
603 _ => unreachable!(),
604 };
605 let right = self.parse_multiplicative()?;
606 left = Expr::BinaryOp {
607 left: Box::new(left),
608 op,
609 right: Box::new(right),
610 };
611 }
612 Ok(left)
613 }
614
615 fn parse_multiplicative(&mut self) -> Result<Expr, MdqlError> {
616 let mut left = self.parse_unary()?;
617 while self.peek_is_multiplicative_op() {
618 let op_tok = self.advance();
619 let op = match op_tok.value.as_str() {
620 "*" => ArithOp::Mul,
621 "/" => ArithOp::Div,
622 "%" => ArithOp::Mod,
623 _ => unreachable!(),
624 };
625 let right = self.parse_unary()?;
626 left = Expr::BinaryOp {
627 left: Box::new(left),
628 op,
629 right: Box::new(right),
630 };
631 }
632 Ok(left)
633 }
634
635 fn parse_unary(&mut self) -> Result<Expr, MdqlError> {
636 if self.peek().map_or(false, |t| t.token_type == "op" && t.value == "-") {
637 self.advance();
638 let inner = self.parse_atom()?;
639 match inner {
641 Expr::Literal(SqlValue::Int(n)) => Ok(Expr::Literal(SqlValue::Int(-n))),
642 Expr::Literal(SqlValue::Float(f)) => Ok(Expr::Literal(SqlValue::Float(-f))),
643 _ => Ok(Expr::UnaryMinus(Box::new(inner))),
644 }
645 } else {
646 self.parse_atom()
647 }
648 }
649
650 fn parse_atom(&mut self) -> Result<Expr, MdqlError> {
651 if self.peek_is_agg_func() {
652 return self.parse_agg_expr();
653 }
654
655 let t = self.peek().ok_or_else(|| {
656 MdqlError::QueryParse("Expected expression, got end of query".into())
657 })?;
658
659 match t.token_type.as_str() {
660 "number" => {
661 let v = self.advance().value;
662 if v.contains('.') {
663 let f: f64 = v.parse().map_err(|_| {
664 MdqlError::QueryParse(format!("Invalid float: {}", v))
665 })?;
666 Ok(Expr::Literal(SqlValue::Float(f)))
667 } else {
668 let n: i64 = v.parse().map_err(|_| {
669 MdqlError::QueryParse(format!("Invalid int: {}", v))
670 })?;
671 Ok(Expr::Literal(SqlValue::Int(n)))
672 }
673 }
674 "string" => {
675 let v = self.advance().value;
676 Ok(Expr::Literal(SqlValue::String(v)))
677 }
678 "keyword" if t.value == "NULL" => {
679 self.advance();
680 Ok(Expr::Literal(SqlValue::Null))
681 }
682 "keyword" if t.value == "CASE" => {
683 self.parse_case_expr()
684 }
685 "keyword" if t.value == "CURRENT_DATE" => {
686 self.advance();
687 Ok(Expr::CurrentDate)
688 }
689 "keyword" if t.value == "CURRENT_TIMESTAMP" => {
690 self.advance();
691 Ok(Expr::CurrentTimestamp)
692 }
693 "keyword" if t.value == "DATEDIFF" => {
694 self.advance();
695 self.expect("op", Some("("))?;
696 let left = self.parse_additive()?;
697 self.expect("op", Some(","))?;
698 let right = self.parse_additive()?;
699 self.expect("op", Some(")"))?;
700 Ok(Expr::DateDiff { left: Box::new(left), right: Box::new(right) })
701 }
702 "op" if t.value == "(" => {
703 self.advance();
704 let expr = self.parse_additive()?;
705 self.expect("op", Some(")"))?;
706 Ok(expr)
707 }
708 "ident" => {
709 let name = self.advance().value;
710 Ok(Expr::Column(name))
711 }
712 "keyword" if !Self::is_reserved_keyword(&t.value) => {
713 let name = self.advance().value;
714 Ok(Expr::Column(name))
715 }
716 _ => Err(MdqlError::QueryParse(format!(
717 "Expected expression, got '{}'",
718 t.raw
719 ))),
720 }
721 }
722
723 fn parse_case_expr(&mut self) -> Result<Expr, MdqlError> {
724 self.expect("keyword", Some("CASE"))?;
725 let mut whens = Vec::new();
726 while self.match_keyword("WHEN") {
727 let condition = self.parse_or_expr()?;
728 self.expect("keyword", Some("THEN"))?;
729 let result = self.parse_additive()?;
730 whens.push((condition, Box::new(result)));
731 }
732 if whens.is_empty() {
733 return Err(MdqlError::QueryParse("CASE requires at least one WHEN clause".into()));
734 }
735 let else_expr = if self.match_keyword("ELSE") {
736 Some(Box::new(self.parse_additive()?))
737 } else {
738 None
739 };
740 self.expect("keyword", Some("END"))?;
741 Ok(Expr::Case { whens, else_expr })
742 }
743
744 fn parse_agg_expr(&mut self) -> Result<Expr, MdqlError> {
745 let func_name = self.advance().value.to_uppercase();
746 let func = match func_name.as_str() {
747 "COUNT" => AggFunc::Count,
748 "SUM" => AggFunc::Sum,
749 "AVG" => AggFunc::Avg,
750 "MIN" => AggFunc::Min,
751 "MAX" => AggFunc::Max,
752 _ => unreachable!(),
753 };
754 self.expect("op", Some("("))?;
755 let (arg, arg_expr) = if self.peek().map_or(false, |t| t.token_type == "op" && t.value == "*") {
756 self.advance();
757 ("*".to_string(), None)
758 } else {
759 let expr = self.parse_additive()?;
760 if let Expr::Column(name) = &expr {
761 (name.clone(), None)
762 } else {
763 (expr.display_name(), Some(Box::new(expr)))
764 }
765 };
766 self.expect("op", Some(")"))?;
767 Ok(Expr::Aggregate { func, arg, arg_expr })
768 }
769
770 fn parse_ident(&mut self) -> Result<String, MdqlError> {
771 let t = self.peek().ok_or_else(|| {
772 MdqlError::QueryParse("Expected identifier, got end of query".into())
773 })?;
774 match t.token_type.as_str() {
775 "ident" | "keyword" => {
776 let v = self.advance().value;
777 Ok(v)
778 }
779 _ => Err(MdqlError::QueryParse(format!(
780 "Expected identifier, got '{}'",
781 t.raw
782 ))),
783 }
784 }
785
786 fn parse_or_expr(&mut self) -> Result<WhereClause, MdqlError> {
787 let mut left = self.parse_and_expr()?;
788 while self.match_keyword("OR") {
789 let right = self.parse_and_expr()?;
790 left = WhereClause::BoolOp(BoolOp {
791 op: BoolOpKind::Or,
792 left: Box::new(left),
793 right: Box::new(right),
794 });
795 }
796 Ok(left)
797 }
798
799 fn parse_and_expr(&mut self) -> Result<WhereClause, MdqlError> {
800 let mut left = self.parse_comparison()?;
801 while self.match_keyword("AND") {
802 let right = self.parse_comparison()?;
803 left = WhereClause::BoolOp(BoolOp {
804 op: BoolOpKind::And,
805 left: Box::new(left),
806 right: Box::new(right),
807 });
808 }
809 Ok(left)
810 }
811
812 fn parse_comparison(&mut self) -> Result<WhereClause, MdqlError> {
813 if self.peek().map_or(false, |t| t.token_type == "op" && t.value == "(") {
815 let saved_pos = self.pos;
817 self.advance();
818 let result = self.parse_or_expr();
820 if result.is_ok() && self.peek().map_or(false, |t| t.token_type == "op" && t.value == ")") {
821 self.advance();
822 return result;
823 }
824 self.pos = saved_pos;
826 }
827
828 let left_expr = self.parse_additive()?;
830
831 let col = left_expr.as_column().unwrap_or("").to_string();
833
834 if self.match_keyword("IS") {
836 if self.match_keyword("NOT") {
837 self.expect("keyword", Some("NULL"))?;
838 return Ok(WhereClause::Comparison(Comparison {
839 column: col,
840 op: CmpOp::IsNotNull,
841 value: None,
842 left_expr: Some(left_expr),
843 right_expr: None,
844 }));
845 }
846 self.expect("keyword", Some("NULL"))?;
847 return Ok(WhereClause::Comparison(Comparison {
848 column: col,
849 op: CmpOp::IsNull,
850 value: None,
851 left_expr: Some(left_expr),
852 right_expr: None,
853 }));
854 }
855
856 if self.match_keyword("IN") {
858 self.expect("op", Some("("))?;
859 let mut values = vec![self.parse_value()?];
860 while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
861 self.advance();
862 values.push(self.parse_value()?);
863 }
864 self.expect("op", Some(")"))?;
865 return Ok(WhereClause::Comparison(Comparison {
866 column: col,
867 op: CmpOp::In,
868 value: Some(SqlValue::List(values)),
869 left_expr: Some(left_expr),
870 right_expr: None,
871 }));
872 }
873
874 if self.match_keyword("LIKE") {
876 let val = self.parse_value()?;
877 return Ok(WhereClause::Comparison(Comparison {
878 column: col,
879 op: CmpOp::Like,
880 value: Some(val),
881 left_expr: Some(left_expr),
882 right_expr: None,
883 }));
884 }
885
886 if self.match_keyword("NOT") {
888 if self.match_keyword("LIKE") {
889 let val = self.parse_value()?;
890 return Ok(WhereClause::Comparison(Comparison {
891 column: col,
892 op: CmpOp::NotLike,
893 value: Some(val),
894 left_expr: Some(left_expr),
895 right_expr: None,
896 }));
897 }
898 return Err(MdqlError::QueryParse("Expected LIKE after NOT".into()));
899 }
900
901 if let Some(t) = self.peek() {
903 if t.token_type == "op" && ["=", "!=", "<", ">", "<=", ">="].contains(&t.value.as_str())
904 {
905 let op_str = self.advance().value;
906 let op = match op_str.as_str() {
907 "=" => CmpOp::Eq,
908 "!=" => CmpOp::Ne,
909 "<" => CmpOp::Lt,
910 ">" => CmpOp::Gt,
911 "<=" => CmpOp::Le,
912 ">=" => CmpOp::Ge,
913 _ => unreachable!(),
914 };
915 let right_expr = self.parse_additive()?;
917 let value = match &right_expr {
919 Expr::Literal(v) => Some(v.clone()),
920 _ => None,
921 };
922 return Ok(WhereClause::Comparison(Comparison {
923 column: col,
924 op,
925 value,
926 left_expr: Some(left_expr),
927 right_expr: Some(right_expr),
928 }));
929 }
930 }
931
932 let got = self.peek().map_or("end".to_string(), |t| t.raw.clone());
933 Err(MdqlError::QueryParse(format!(
934 "Expected operator after '{}', got '{}'",
935 left_expr.display_name(), got
936 )))
937 }
938
939 fn parse_value(&mut self) -> Result<SqlValue, MdqlError> {
940 let t = self.peek().ok_or_else(|| {
941 MdqlError::QueryParse("Expected value, got end of query".into())
942 })?;
943 match t.token_type.as_str() {
944 "string" => {
945 let v = self.advance().value;
946 Ok(SqlValue::String(v))
947 }
948 "number" => {
949 let v = self.advance().value;
950 if v.contains('.') {
951 Ok(SqlValue::Float(v.parse().map_err(|_| {
952 MdqlError::QueryParse(format!("Invalid float: {}", v))
953 })?))
954 } else {
955 Ok(SqlValue::Int(v.parse().map_err(|_| {
956 MdqlError::QueryParse(format!("Invalid int: {}", v))
957 })?))
958 }
959 }
960 "keyword" if t.value == "NULL" => {
961 self.advance();
962 Ok(SqlValue::Null)
963 }
964 _ => Err(MdqlError::QueryParse(format!(
965 "Expected value, got '{}'",
966 t.raw
967 ))),
968 }
969 }
970
971 fn parse_order_by(&mut self) -> Result<Vec<OrderSpec>, MdqlError> {
972 let mut specs = vec![self.parse_order_spec()?];
973 while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
974 self.advance();
975 specs.push(self.parse_order_spec()?);
976 }
977 Ok(specs)
978 }
979
980 fn parse_order_spec(&mut self) -> Result<OrderSpec, MdqlError> {
981 let expr = self.parse_additive()?;
982 let col = expr.as_column().unwrap_or("").to_string();
983 let descending = if self.match_keyword("DESC") {
984 true
985 } else {
986 self.match_keyword("ASC");
987 false
988 };
989 Ok(OrderSpec {
990 column: col,
991 expr: Some(expr),
992 descending,
993 })
994 }
995
996 fn is_clause_keyword(&self, t: &Token) -> bool {
997 t.token_type == "keyword"
998 && ["WHERE", "ORDER", "LIMIT", "JOIN", "LEFT", "ON", "GROUP"].contains(&t.value.as_str())
999 }
1000
1001 fn is_reserved_keyword(kw: &str) -> bool {
1003 matches!(kw,
1004 "AS" | "FROM" | "WHERE" | "AND" | "OR" | "ORDER" | "BY"
1005 | "ASC" | "DESC" | "LIMIT" | "JOIN" | "ON" | "GROUP"
1006 | "SELECT" | "INSERT" | "INTO" | "VALUES" | "UPDATE" | "SET"
1007 | "DELETE" | "ALTER" | "TABLE" | "IS" | "NOT" | "IN" | "LIKE"
1008 | "RENAME" | "FIELD" | "TO" | "DROP" | "MERGE" | "FIELDS"
1009 | "CASE" | "WHEN" | "THEN" | "ELSE" | "END"
1010 | "HAVING" | "INTERVAL" | "DAY" | "DAYS"
1011 | "CURRENT_DATE" | "CURRENT_TIMESTAMP" | "DATEDIFF"
1012 | "CREATE" | "VIEW" | "CASCADE" | "RESTRICT"
1013 )
1014 }
1015
1016 fn expect_end(&self) -> Result<(), MdqlError> {
1017 if let Some(t) = self.peek() {
1018 return Err(MdqlError::QueryParse(format!(
1019 "Unexpected token '{}' at position {}",
1020 t.raw, self.pos
1021 )));
1022 }
1023 Ok(())
1024 }
1025}
1026
1027pub fn parse_query(sql: &str) -> crate::errors::Result<Statement> {
1028 let tokens = tokenize(sql);
1029 if tokens.is_empty() {
1030 return Err(MdqlError::QueryParse("Empty query".into()));
1031 }
1032 let mut parser = Parser::new(tokens);
1033 parser.parse_statement()
1034}
1035
1036#[cfg(test)]
1037mod tests {
1038 use super::*;
1039
1040 #[test]
1041 fn test_simple_select() {
1042 let stmt = parse_query("SELECT title, status FROM strategies").unwrap();
1043 if let Statement::Select(q) = stmt {
1044 assert_eq!(q.columns, ColumnList::Named(vec![SelectExpr::Column("title".into()), SelectExpr::Column("status".into())]));
1045 assert_eq!(q.table, "strategies");
1046 } else {
1047 panic!("Expected Select");
1048 }
1049 }
1050
1051 #[test]
1052 fn test_select_star() {
1053 let stmt = parse_query("SELECT * FROM test").unwrap();
1054 if let Statement::Select(q) = stmt {
1055 assert_eq!(q.columns, ColumnList::All);
1056 } else {
1057 panic!("Expected Select");
1058 }
1059 }
1060
1061 #[test]
1062 fn test_where_clause() {
1063 let stmt = parse_query("SELECT title FROM test WHERE count > 5").unwrap();
1064 if let Statement::Select(q) = stmt {
1065 assert!(q.where_clause.is_some());
1066 } else {
1067 panic!("Expected Select");
1068 }
1069 }
1070
1071 #[test]
1072 fn test_order_by() {
1073 let stmt =
1074 parse_query("SELECT title FROM test ORDER BY composite DESC, title ASC").unwrap();
1075 if let Statement::Select(q) = stmt {
1076 let ob = q.order_by.unwrap();
1077 assert_eq!(ob.len(), 2);
1078 assert!(ob[0].descending);
1079 assert!(!ob[1].descending);
1080 } else {
1081 panic!("Expected Select");
1082 }
1083 }
1084
1085 #[test]
1086 fn test_limit() {
1087 let stmt = parse_query("SELECT * FROM test LIMIT 10").unwrap();
1088 if let Statement::Select(q) = stmt {
1089 assert_eq!(q.limit, Some(10));
1090 } else {
1091 panic!("Expected Select");
1092 }
1093 }
1094
1095 #[test]
1096 fn test_insert() {
1097 let stmt = parse_query(
1098 "INSERT INTO test (title, count) VALUES ('Hello', 42)",
1099 )
1100 .unwrap();
1101 if let Statement::Insert(q) = stmt {
1102 assert_eq!(q.table, "test");
1103 assert_eq!(q.columns, vec!["title", "count"]);
1104 assert_eq!(q.values[0], SqlValue::String("Hello".into()));
1105 assert_eq!(q.values[1], SqlValue::Int(42));
1106 } else {
1107 panic!("Expected Insert");
1108 }
1109 }
1110
1111 #[test]
1112 fn test_update() {
1113 let stmt = parse_query("UPDATE test SET status = 'KILLED' WHERE path = 'a.md'").unwrap();
1114 if let Statement::Update(q) = stmt {
1115 assert_eq!(q.table, "test");
1116 assert_eq!(q.assignments.len(), 1);
1117 assert!(q.where_clause.is_some());
1118 } else {
1119 panic!("Expected Update");
1120 }
1121 }
1122
1123 #[test]
1124 fn test_delete() {
1125 let stmt = parse_query("DELETE FROM test WHERE status = 'draft'").unwrap();
1126 if let Statement::Delete(q) = stmt {
1127 assert_eq!(q.table, "test");
1128 assert!(q.where_clause.is_some());
1129 } else {
1130 panic!("Expected Delete");
1131 }
1132 }
1133
1134 #[test]
1135 fn test_alter_rename() {
1136 let stmt =
1137 parse_query("ALTER TABLE test RENAME FIELD 'Summary' TO 'Overview'").unwrap();
1138 if let Statement::AlterRename(q) = stmt {
1139 assert_eq!(q.old_name, "Summary");
1140 assert_eq!(q.new_name, "Overview");
1141 } else {
1142 panic!("Expected AlterRename");
1143 }
1144 }
1145
1146 #[test]
1147 fn test_alter_drop() {
1148 let stmt = parse_query("ALTER TABLE test DROP FIELD 'Details'").unwrap();
1149 if let Statement::AlterDrop(q) = stmt {
1150 assert_eq!(q.field_name, "Details");
1151 } else {
1152 panic!("Expected AlterDrop");
1153 }
1154 }
1155
1156 #[test]
1157 fn test_alter_merge() {
1158 let stmt = parse_query(
1159 "ALTER TABLE test MERGE FIELDS 'Entry Rules', 'Exit Rules' INTO 'Trading Rules'",
1160 )
1161 .unwrap();
1162 if let Statement::AlterMerge(q) = stmt {
1163 assert_eq!(q.sources, vec!["Entry Rules", "Exit Rules"]);
1164 assert_eq!(q.into, "Trading Rules");
1165 } else {
1166 panic!("Expected AlterMerge");
1167 }
1168 }
1169
1170 #[test]
1171 fn test_backtick_ident() {
1172 let stmt = parse_query("SELECT `Structural Mechanism` FROM test").unwrap();
1173 if let Statement::Select(q) = stmt {
1174 assert_eq!(
1175 q.columns,
1176 ColumnList::Named(vec![SelectExpr::Column("Structural Mechanism".into())])
1177 );
1178 } else {
1179 panic!("Expected Select");
1180 }
1181 }
1182
1183 #[test]
1184 fn test_like_operator() {
1185 let stmt = parse_query("SELECT title FROM test WHERE categories LIKE '%defi%'").unwrap();
1186 if let Statement::Select(q) = stmt {
1187 if let Some(WhereClause::Comparison(c)) = q.where_clause {
1188 assert_eq!(c.op, CmpOp::Like);
1189 assert_eq!(c.value, Some(SqlValue::String("%defi%".into())));
1190 } else {
1191 panic!("Expected LIKE comparison");
1192 }
1193 } else {
1194 panic!("Expected Select");
1195 }
1196 }
1197
1198 #[test]
1199 fn test_in_operator() {
1200 let stmt =
1201 parse_query("SELECT * FROM test WHERE status IN ('ACTIVE', 'LIVE')").unwrap();
1202 if let Statement::Select(q) = stmt {
1203 if let Some(WhereClause::Comparison(c)) = q.where_clause {
1204 assert_eq!(c.op, CmpOp::In);
1205 } else {
1206 panic!("Expected IN comparison");
1207 }
1208 } else {
1209 panic!("Expected Select");
1210 }
1211 }
1212
1213 #[test]
1214 fn test_is_null() {
1215 let stmt = parse_query("SELECT * FROM test WHERE title IS NULL").unwrap();
1216 if let Statement::Select(q) = stmt {
1217 if let Some(WhereClause::Comparison(c)) = q.where_clause {
1218 assert_eq!(c.op, CmpOp::IsNull);
1219 } else {
1220 panic!("Expected IS NULL comparison");
1221 }
1222 } else {
1223 panic!("Expected Select");
1224 }
1225 }
1226
1227 #[test]
1228 fn test_and_or() {
1229 let stmt = parse_query(
1230 "SELECT * FROM test WHERE status = 'ACTIVE' AND count > 5 OR title LIKE '%test%'",
1231 )
1232 .unwrap();
1233 if let Statement::Select(q) = stmt {
1234 assert!(q.where_clause.is_some());
1235 } else {
1236 panic!("Expected Select");
1237 }
1238 }
1239
1240 #[test]
1241 fn test_join() {
1242 let stmt = parse_query(
1243 "SELECT s.title, b.sharpe FROM strategies s JOIN backtests b ON b.strategy = s.path",
1244 )
1245 .unwrap();
1246 if let Statement::Select(q) = stmt {
1247 assert_eq!(q.table, "strategies");
1248 assert_eq!(q.table_alias, Some("s".into()));
1249 assert_eq!(q.joins.len(), 1);
1250 let join = &q.joins[0];
1251 assert_eq!(join.table, "backtests");
1252 assert_eq!(join.alias, Some("b".into()));
1253 } else {
1254 panic!("Expected Select");
1255 }
1256 }
1257
1258 #[test]
1259 fn test_multi_join() {
1260 let stmt = parse_query(
1261 "SELECT s.title, b.sharpe, c.verdict FROM strategies s JOIN backtests b ON b.strategy = s.path JOIN critiques c ON c.strategy = s.path",
1262 )
1263 .unwrap();
1264 if let Statement::Select(q) = stmt {
1265 assert_eq!(q.table, "strategies");
1266 assert_eq!(q.table_alias, Some("s".into()));
1267 assert_eq!(q.joins.len(), 2);
1268 assert_eq!(q.joins[0].table, "backtests");
1269 assert_eq!(q.joins[0].alias, Some("b".into()));
1270 assert_eq!(where_clause_to_sql(&q.joins[0].condition), "b.strategy = s.path");
1271 assert_eq!(q.joins[1].table, "critiques");
1272 assert_eq!(q.joins[1].alias, Some("c".into()));
1273 assert_eq!(where_clause_to_sql(&q.joins[1].condition), "c.strategy = s.path");
1274 } else {
1275 panic!("Expected Select");
1276 }
1277 }
1278
1279 #[test]
1280 fn test_left_join() {
1281 let stmt = parse_query(
1282 "SELECT s.title, b.sharpe FROM strategies s LEFT JOIN backtests b ON b.strategy = s.path",
1283 )
1284 .unwrap();
1285 if let Statement::Select(q) = stmt {
1286 assert_eq!(q.joins.len(), 1);
1287 assert_eq!(q.joins[0].join_type, JoinType::Left);
1288 assert_eq!(q.joins[0].table, "backtests");
1289 } else {
1290 panic!("Expected Select");
1291 }
1292 }
1293
1294 #[test]
1295 fn test_mixed_join_types() {
1296 let stmt = parse_query(
1297 "SELECT s.title FROM strategies s JOIN backtests b ON b.strategy = s.path LEFT JOIN allocations a ON a.strategy = s.path",
1298 )
1299 .unwrap();
1300 if let Statement::Select(q) = stmt {
1301 assert_eq!(q.joins.len(), 2);
1302 assert_eq!(q.joins[0].join_type, JoinType::Inner);
1303 assert_eq!(q.joins[1].join_type, JoinType::Left);
1304 } else {
1305 panic!("Expected Select");
1306 }
1307 }
1308
1309 #[test]
1310 fn test_join_compound_and() {
1311 let stmt = parse_query(
1312 "SELECT s.title FROM strategies s LEFT JOIN backtests b ON b.strategy = s.path AND b.mode = 'PAPER'",
1313 )
1314 .unwrap();
1315 if let Statement::Select(q) = stmt {
1316 assert_eq!(q.joins.len(), 1);
1317 assert_eq!(q.joins[0].join_type, JoinType::Left);
1318 let sql = where_clause_to_sql(&q.joins[0].condition);
1319 assert!(sql.contains("b.strategy = s.path"));
1320 assert!(sql.contains("AND"));
1321 assert!(sql.contains("b.mode = 'PAPER'"));
1322 } else {
1323 panic!("Expected Select");
1324 }
1325 }
1326
1327 #[test]
1328 fn test_join_compound_or() {
1329 let stmt = parse_query(
1330 "SELECT * FROM a JOIN b ON a.id = b.id OR a.alt = b.id",
1331 )
1332 .unwrap();
1333 if let Statement::Select(q) = stmt {
1334 let sql = where_clause_to_sql(&q.joins[0].condition);
1335 assert!(sql.contains("OR"));
1336 } else {
1337 panic!("Expected Select");
1338 }
1339 }
1340
1341 #[test]
1342 fn test_join_compound_with_where() {
1343 let stmt = parse_query(
1344 "SELECT s.title FROM strategies s JOIN backtests b ON b.strategy = s.path AND b.mode = 'PAPER' WHERE s.title = 'Alpha'",
1345 )
1346 .unwrap();
1347 if let Statement::Select(q) = stmt {
1348 assert_eq!(q.joins.len(), 1);
1349 assert!(q.where_clause.is_some());
1350 let join_sql = where_clause_to_sql(&q.joins[0].condition);
1351 assert!(join_sql.contains("AND"));
1352 } else {
1353 panic!("Expected Select");
1354 }
1355 }
1356
1357 #[test]
1358 fn test_empty_query() {
1359 assert!(parse_query("").is_err());
1360 }
1361
1362 #[test]
1363 fn test_count_star() {
1364 let stmt = parse_query("SELECT status, COUNT(*) AS cnt FROM strategies GROUP BY status").unwrap();
1365 if let Statement::Select(q) = stmt {
1366 if let ColumnList::Named(exprs) = &q.columns {
1367 assert_eq!(exprs.len(), 2);
1368 assert_eq!(exprs[0], SelectExpr::Column("status".into()));
1369 assert!(matches!(&exprs[1], SelectExpr::Aggregate {
1370 func: AggFunc::Count,
1371 arg,
1372 alias: Some(a),
1373 ..
1374 } if arg == "*" && a == "cnt"));
1375 } else {
1376 panic!("Expected Named columns");
1377 }
1378 assert_eq!(q.group_by, Some(vec!["status".into()]));
1379 } else {
1380 panic!("Expected Select");
1381 }
1382 }
1383
1384 #[test]
1385 fn test_count_column_as_ident() {
1386 let stmt = parse_query("INSERT INTO test (title, count) VALUES ('Hello', 42)").unwrap();
1388 if let Statement::Insert(q) = stmt {
1389 assert_eq!(q.columns, vec!["title", "count"]);
1390 } else {
1391 panic!("Expected Insert");
1392 }
1393 }
1394
1395 #[test]
1396 fn test_multiple_aggregates() {
1397 let stmt = parse_query("SELECT MIN(composite), MAX(composite), AVG(composite) FROM strategies").unwrap();
1398 if let Statement::Select(q) = stmt {
1399 if let ColumnList::Named(exprs) = &q.columns {
1400 assert_eq!(exprs.len(), 3);
1401 assert!(matches!(&exprs[0], SelectExpr::Aggregate { func: AggFunc::Min, .. }));
1402 assert!(matches!(&exprs[1], SelectExpr::Aggregate { func: AggFunc::Max, .. }));
1403 assert!(matches!(&exprs[2], SelectExpr::Aggregate { func: AggFunc::Avg, .. }));
1404 } else {
1405 panic!("Expected Named columns");
1406 }
1407 assert_eq!(q.group_by, None);
1408 } else {
1409 panic!("Expected Select");
1410 }
1411 }
1412
1413 #[test]
1416 fn test_select_arithmetic_expr() {
1417 let stmt = parse_query("SELECT a + b FROM test").unwrap();
1418 if let Statement::Select(q) = stmt {
1419 if let ColumnList::Named(exprs) = &q.columns {
1420 assert_eq!(exprs.len(), 1);
1421 assert!(matches!(&exprs[0], SelectExpr::Expr {
1422 expr: Expr::BinaryOp { op: ArithOp::Add, .. },
1423 alias: None,
1424 }));
1425 } else {
1426 panic!("Expected Named columns");
1427 }
1428 } else {
1429 panic!("Expected Select");
1430 }
1431 }
1432
1433 #[test]
1434 fn test_select_arithmetic_with_alias() {
1435 let stmt = parse_query("SELECT a + b AS total FROM test").unwrap();
1436 if let Statement::Select(q) = stmt {
1437 if let ColumnList::Named(exprs) = &q.columns {
1438 assert_eq!(exprs.len(), 1);
1439 assert!(matches!(&exprs[0], SelectExpr::Expr {
1440 alias: Some(a),
1441 ..
1442 } if a == "total"));
1443 assert_eq!(exprs[0].output_name(), "total");
1444 } else {
1445 panic!("Expected Named columns");
1446 }
1447 } else {
1448 panic!("Expected Select");
1449 }
1450 }
1451
1452 #[test]
1453 fn test_select_precedence() {
1454 let stmt = parse_query("SELECT a + b * c FROM test").unwrap();
1456 if let Statement::Select(q) = stmt {
1457 if let ColumnList::Named(exprs) = &q.columns {
1458 if let SelectExpr::Expr { expr, .. } = &exprs[0] {
1459 if let Expr::BinaryOp { left, op, right } = expr {
1460 assert_eq!(*op, ArithOp::Add);
1461 assert!(matches!(left.as_ref(), Expr::Column(n) if n == "a"));
1462 assert!(matches!(right.as_ref(), Expr::BinaryOp { op: ArithOp::Mul, .. }));
1463 } else {
1464 panic!("Expected BinaryOp");
1465 }
1466 } else {
1467 panic!("Expected Expr variant");
1468 }
1469 } else {
1470 panic!("Expected Named columns");
1471 }
1472 } else {
1473 panic!("Expected Select");
1474 }
1475 }
1476
1477 #[test]
1478 fn test_select_parenthesized_expr() {
1479 let stmt = parse_query("SELECT (a + b) * c FROM test").unwrap();
1481 if let Statement::Select(q) = stmt {
1482 if let ColumnList::Named(exprs) = &q.columns {
1483 if let SelectExpr::Expr { expr, .. } = &exprs[0] {
1484 if let Expr::BinaryOp { left, op, .. } = expr {
1485 assert_eq!(*op, ArithOp::Mul);
1486 assert!(matches!(left.as_ref(), Expr::BinaryOp { op: ArithOp::Add, .. }));
1487 } else {
1488 panic!("Expected BinaryOp");
1489 }
1490 } else {
1491 panic!("Expected Expr variant");
1492 }
1493 } else {
1494 panic!("Expected Named columns");
1495 }
1496 } else {
1497 panic!("Expected Select");
1498 }
1499 }
1500
1501 #[test]
1502 fn test_select_unary_minus() {
1503 let stmt = parse_query("SELECT -count FROM test").unwrap();
1504 if let Statement::Select(q) = stmt {
1505 if let ColumnList::Named(exprs) = &q.columns {
1506 assert!(matches!(&exprs[0], SelectExpr::Expr {
1507 expr: Expr::UnaryMinus(_),
1508 ..
1509 }));
1510 } else {
1511 panic!("Expected Named columns");
1512 }
1513 } else {
1514 panic!("Expected Select");
1515 }
1516 }
1517
1518 #[test]
1519 fn test_select_negative_literal() {
1520 let stmt = parse_query("SELECT -42 FROM test").unwrap();
1521 if let Statement::Select(q) = stmt {
1522 if let ColumnList::Named(exprs) = &q.columns {
1523 assert!(matches!(&exprs[0], SelectExpr::Expr {
1525 expr: Expr::Literal(SqlValue::Int(-42)),
1526 ..
1527 }));
1528 } else {
1529 panic!("Expected Named columns");
1530 }
1531 } else {
1532 panic!("Expected Select");
1533 }
1534 }
1535
1536 #[test]
1537 fn test_where_arithmetic_expr() {
1538 let stmt = parse_query("SELECT * FROM test WHERE a + b > 10").unwrap();
1539 if let Statement::Select(q) = stmt {
1540 if let Some(WhereClause::Comparison(c)) = q.where_clause {
1541 assert_eq!(c.op, CmpOp::Gt);
1542 assert!(matches!(&c.left_expr, Some(Expr::BinaryOp { op: ArithOp::Add, .. })));
1543 assert!(matches!(&c.right_expr, Some(Expr::Literal(SqlValue::Int(10)))));
1544 } else {
1545 panic!("Expected comparison");
1546 }
1547 } else {
1548 panic!("Expected Select");
1549 }
1550 }
1551
1552 #[test]
1553 fn test_where_both_sides_expr() {
1554 let stmt = parse_query("SELECT * FROM test WHERE a * 2 > b + 1").unwrap();
1555 if let Statement::Select(q) = stmt {
1556 if let Some(WhereClause::Comparison(c)) = q.where_clause {
1557 assert_eq!(c.op, CmpOp::Gt);
1558 assert!(matches!(&c.left_expr, Some(Expr::BinaryOp { op: ArithOp::Mul, .. })));
1559 assert!(matches!(&c.right_expr, Some(Expr::BinaryOp { op: ArithOp::Add, .. })));
1560 } else {
1561 panic!("Expected comparison");
1562 }
1563 } else {
1564 panic!("Expected Select");
1565 }
1566 }
1567
1568 #[test]
1569 fn test_order_by_expr() {
1570 let stmt = parse_query("SELECT * FROM test ORDER BY a + b DESC").unwrap();
1571 if let Statement::Select(q) = stmt {
1572 let ob = q.order_by.unwrap();
1573 assert_eq!(ob.len(), 1);
1574 assert!(ob[0].descending);
1575 assert!(matches!(&ob[0].expr, Some(Expr::BinaryOp { op: ArithOp::Add, .. })));
1576 } else {
1577 panic!("Expected Select");
1578 }
1579 }
1580
1581 #[test]
1582 fn test_all_arithmetic_ops() {
1583 let stmt = parse_query("SELECT a + b, a - b, a * b, a / b, a % b FROM test").unwrap();
1584 if let Statement::Select(q) = stmt {
1585 if let ColumnList::Named(exprs) = &q.columns {
1586 assert_eq!(exprs.len(), 5);
1587 assert!(matches!(&exprs[0], SelectExpr::Expr { expr: Expr::BinaryOp { op: ArithOp::Add, .. }, .. }));
1588 assert!(matches!(&exprs[1], SelectExpr::Expr { expr: Expr::BinaryOp { op: ArithOp::Sub, .. }, .. }));
1589 assert!(matches!(&exprs[2], SelectExpr::Expr { expr: Expr::BinaryOp { op: ArithOp::Mul, .. }, .. }));
1590 assert!(matches!(&exprs[3], SelectExpr::Expr { expr: Expr::BinaryOp { op: ArithOp::Div, .. }, .. }));
1591 assert!(matches!(&exprs[4], SelectExpr::Expr { expr: Expr::BinaryOp { op: ArithOp::Mod, .. }, .. }));
1592 } else {
1593 panic!("Expected Named columns");
1594 }
1595 } else {
1596 panic!("Expected Select");
1597 }
1598 }
1599
1600 #[test]
1601 fn test_column_with_literal_arithmetic() {
1602 let stmt = parse_query("SELECT count * 2 + 1 FROM test").unwrap();
1603 if let Statement::Select(q) = stmt {
1604 if let ColumnList::Named(exprs) = &q.columns {
1605 if let SelectExpr::Expr { expr, .. } = &exprs[0] {
1607 if let Expr::BinaryOp { left, op, right } = expr {
1608 assert_eq!(*op, ArithOp::Add);
1609 assert!(matches!(right.as_ref(), Expr::Literal(SqlValue::Int(1))));
1610 assert!(matches!(left.as_ref(), Expr::BinaryOp { op: ArithOp::Mul, .. }));
1611 } else {
1612 panic!("Expected BinaryOp");
1613 }
1614 } else {
1615 panic!("Expected Expr");
1616 }
1617 } else {
1618 panic!("Expected Named columns");
1619 }
1620 } else {
1621 panic!("Expected Select");
1622 }
1623 }
1624
1625 #[test]
1626 fn test_mixed_columns_and_exprs() {
1627 let stmt = parse_query("SELECT title, a + b AS sum, count FROM test").unwrap();
1628 if let Statement::Select(q) = stmt {
1629 if let ColumnList::Named(exprs) = &q.columns {
1630 assert_eq!(exprs.len(), 3);
1631 assert_eq!(exprs[0], SelectExpr::Column("title".into()));
1632 assert!(matches!(&exprs[1], SelectExpr::Expr { alias: Some(a), .. } if a == "sum"));
1633 assert_eq!(exprs[2], SelectExpr::Column("count".into()));
1634 } else {
1635 panic!("Expected Named columns");
1636 }
1637 } else {
1638 panic!("Expected Select");
1639 }
1640 }
1641
1642 #[test]
1645 fn test_case_when_basic() {
1646 let stmt = parse_query(
1647 "SELECT CASE WHEN status = 'ACTIVE' THEN 1 ELSE 0 END FROM test"
1648 ).unwrap();
1649 if let Statement::Select(q) = stmt {
1650 if let ColumnList::Named(exprs) = &q.columns {
1651 assert_eq!(exprs.len(), 1);
1652 assert!(matches!(&exprs[0], SelectExpr::Expr {
1653 expr: Expr::Case { .. },
1654 ..
1655 }));
1656 } else {
1657 panic!("Expected Named columns");
1658 }
1659 } else {
1660 panic!("Expected Select");
1661 }
1662 }
1663
1664 #[test]
1665 fn test_case_when_multiple_branches() {
1666 let stmt = parse_query(
1667 "SELECT CASE WHEN x > 10 THEN 'high' WHEN x > 5 THEN 'mid' ELSE 'low' END FROM test"
1668 ).unwrap();
1669 if let Statement::Select(q) = stmt {
1670 if let ColumnList::Named(exprs) = &q.columns {
1671 if let SelectExpr::Expr { expr: Expr::Case { whens, else_expr }, .. } = &exprs[0] {
1672 assert_eq!(whens.len(), 2);
1673 assert!(else_expr.is_some());
1674 } else {
1675 panic!("Expected Case expression");
1676 }
1677 } else {
1678 panic!("Expected Named columns");
1679 }
1680 } else {
1681 panic!("Expected Select");
1682 }
1683 }
1684
1685 #[test]
1686 fn test_case_when_no_else() {
1687 let stmt = parse_query(
1688 "SELECT CASE WHEN x = 1 THEN 'one' END FROM test"
1689 ).unwrap();
1690 if let Statement::Select(q) = stmt {
1691 if let ColumnList::Named(exprs) = &q.columns {
1692 if let SelectExpr::Expr { expr: Expr::Case { whens, else_expr }, .. } = &exprs[0] {
1693 assert_eq!(whens.len(), 1);
1694 assert!(else_expr.is_none());
1695 } else {
1696 panic!("Expected Case expression");
1697 }
1698 } else {
1699 panic!("Expected Named columns");
1700 }
1701 } else {
1702 panic!("Expected Select");
1703 }
1704 }
1705
1706 #[test]
1707 fn test_case_when_in_aggregate() {
1708 let stmt = parse_query(
1709 "SELECT SUM(CASE WHEN side = 'BUY' THEN size ELSE -size END) AS net FROM orders GROUP BY token"
1710 ).unwrap();
1711 if let Statement::Select(q) = stmt {
1712 if let ColumnList::Named(exprs) = &q.columns {
1713 assert_eq!(exprs.len(), 1);
1714 assert!(matches!(&exprs[0], SelectExpr::Aggregate {
1715 func: AggFunc::Sum,
1716 arg_expr: Some(Expr::Case { .. }),
1717 alias: Some(a),
1718 ..
1719 } if a == "net"));
1720 } else {
1721 panic!("Expected Named columns");
1722 }
1723 } else {
1724 panic!("Expected Select");
1725 }
1726 }
1727
1728 #[test]
1729 fn test_case_when_with_alias() {
1730 let stmt = parse_query(
1731 "SELECT CASE WHEN x > 0 THEN 'pos' ELSE 'neg' END AS sign FROM test"
1732 ).unwrap();
1733 if let Statement::Select(q) = stmt {
1734 if let ColumnList::Named(exprs) = &q.columns {
1735 assert!(matches!(&exprs[0], SelectExpr::Expr {
1736 expr: Expr::Case { .. },
1737 alias: Some(a),
1738 } if a == "sign"));
1739 } else {
1740 panic!("Expected Named columns");
1741 }
1742 } else {
1743 panic!("Expected Select");
1744 }
1745 }
1746
1747 #[test]
1748 fn test_create_view() {
1749 let stmt = parse_query("CREATE VIEW live AS SELECT * FROM strategies WHERE status = 'LIVE'").unwrap();
1750 if let Statement::CreateView(cv) = stmt {
1751 assert_eq!(cv.view_name, "live");
1752 assert!(cv.columns.is_none());
1753 assert_eq!(cv.query.table, "strategies");
1754 assert!(cv.query.where_clause.is_some());
1755 } else {
1756 panic!("Expected CreateView, got {:?}", stmt);
1757 }
1758 }
1759
1760 #[test]
1761 fn test_create_view_with_columns() {
1762 let stmt = parse_query("CREATE VIEW v1 (a, b) AS SELECT title, status FROM t").unwrap();
1763 if let Statement::CreateView(cv) = stmt {
1764 assert_eq!(cv.view_name, "v1");
1765 assert_eq!(cv.columns, Some(vec!["a".into(), "b".into()]));
1766 } else {
1767 panic!("Expected CreateView");
1768 }
1769 }
1770
1771 #[test]
1772 fn test_drop_view() {
1773 let stmt = parse_query("DROP VIEW live").unwrap();
1774 if let Statement::DropView(dv) = stmt {
1775 assert_eq!(dv.view_name, "live");
1776 } else {
1777 panic!("Expected DropView, got {:?}", stmt);
1778 }
1779 }
1780
1781 #[test]
1782 fn test_create_view_case_insensitive() {
1783 let stmt = parse_query("create view My_View as select * from t").unwrap();
1784 if let Statement::CreateView(cv) = stmt {
1785 assert_eq!(cv.view_name, "My_View");
1786 } else {
1787 panic!("Expected CreateView");
1788 }
1789 }
1790
1791 #[test]
1794 fn test_aggregate_division() {
1795 let stmt = parse_query(
1796 "SELECT token, SUM(sell) / SUM(buy) as ratio FROM orders GROUP BY token"
1797 ).unwrap();
1798 if let Statement::Select(q) = stmt {
1799 assert_eq!(q.group_by, Some(vec!["token".into()]));
1800 if let ColumnList::Named(exprs) = &q.columns {
1801 assert_eq!(exprs.len(), 2);
1802 assert!(exprs[1].is_aggregate());
1803 } else {
1804 panic!("Expected Named columns");
1805 }
1806 } else {
1807 panic!("Expected Select");
1808 }
1809 }
1810
1811 #[test]
1812 fn test_aggregate_subtraction() {
1813 let stmt = parse_query(
1814 "SELECT token, SUM(sell) - SUM(buy) as net FROM orders GROUP BY token"
1815 ).unwrap();
1816 if let Statement::Select(q) = stmt {
1817 if let ColumnList::Named(exprs) = &q.columns {
1818 assert_eq!(exprs[1].output_name(), "net");
1819 }
1820 } else {
1821 panic!("Expected Select");
1822 }
1823 }
1824
1825 #[test]
1826 fn test_create_view_with_arithmetic() {
1827 let stmt = parse_query(
1828 "CREATE VIEW positions AS SELECT token, SUM(sell) / SUM(buy) as ratio FROM orders GROUP BY token"
1829 ).unwrap();
1830 if let Statement::CreateView(cv) = stmt {
1831 assert_eq!(cv.view_name, "positions");
1832 } else {
1833 panic!("Expected CreateView, got {:?}", stmt);
1834 }
1835 }
1836
1837 #[test]
1840 fn test_subquery_in_from() {
1841 let stmt = parse_query(
1842 "SELECT token, sell_size FROM (SELECT token, SUM(size) as sell_size FROM orders GROUP BY token) LIMIT 5"
1843 ).unwrap();
1844 if let Statement::Select(q) = stmt {
1845 assert!(q.subquery.is_some());
1846 assert_eq!(q.limit, Some(5));
1847 let sub = q.subquery.unwrap();
1848 assert_eq!(sub.table, "orders");
1849 assert!(sub.group_by.is_some());
1850 } else {
1851 panic!("Expected Select");
1852 }
1853 }
1854
1855 #[test]
1858 fn test_create_view_with_having() {
1859 let stmt = parse_query(
1860 "CREATE VIEW positions AS SELECT token, SUM(sell) as sell_size, SUM(buy) as buy_size FROM orders GROUP BY token HAVING sell_size > buy_size"
1861 ).unwrap();
1862 if let Statement::CreateView(cv) = stmt {
1863 assert_eq!(cv.view_name, "positions");
1864 assert!(cv.query.having.is_some());
1865 } else {
1866 panic!("Expected CreateView, got {:?}", stmt);
1867 }
1868 }
1869
1870 #[test]
1873 fn test_aggregate_multiplication() {
1874 let stmt = parse_query(
1875 "SELECT SUM(a) * 2 as doubled FROM test"
1876 ).unwrap();
1877 if let Statement::Select(q) = stmt {
1878 if let ColumnList::Named(exprs) = &q.columns {
1879 assert_eq!(exprs.len(), 1);
1880 assert!(exprs[0].is_aggregate());
1881 assert_eq!(exprs[0].output_name(), "doubled");
1882 } else {
1883 panic!("Expected Named columns");
1884 }
1885 } else {
1886 panic!("Expected Select");
1887 }
1888 }
1889
1890 #[test]
1891 fn test_complex_aggregate_arithmetic() {
1892 let stmt = parse_query(
1893 "SELECT SUM(CASE WHEN side = 'SELL' THEN size ELSE 0 END) / SUM(CASE WHEN side = 'BUY' THEN size ELSE 0 END) as ratio FROM orders GROUP BY token"
1894 ).unwrap();
1895 if let Statement::Select(q) = stmt {
1896 if let ColumnList::Named(exprs) = &q.columns {
1897 assert_eq!(exprs.len(), 1);
1898 assert!(exprs[0].is_aggregate());
1899 assert_eq!(exprs[0].output_name(), "ratio");
1900 } else {
1901 panic!("Expected Named columns");
1902 }
1903 assert_eq!(q.group_by, Some(vec!["token".into()]));
1904 } else {
1905 panic!("Expected Select");
1906 }
1907 }
1908
1909 #[test]
1912 fn test_subquery_with_alias() {
1913 let stmt = parse_query(
1914 "SELECT x FROM (SELECT x FROM t) sub"
1915 ).unwrap();
1916 if let Statement::Select(q) = stmt {
1917 assert!(q.subquery.is_some());
1918 let sub = q.subquery.unwrap();
1919 assert_eq!(sub.table, "t");
1920 if let ColumnList::Named(exprs) = &q.columns {
1921 assert_eq!(exprs.len(), 1);
1922 assert_eq!(exprs[0].output_name(), "x");
1923 } else {
1924 panic!("Expected Named columns");
1925 }
1926 } else {
1927 panic!("Expected Select");
1928 }
1929 }
1930
1931 #[test]
1932 fn test_subquery_with_where() {
1933 let stmt = parse_query(
1934 "SELECT x FROM (SELECT x FROM t WHERE y > 0) LIMIT 5"
1935 ).unwrap();
1936 if let Statement::Select(q) = stmt {
1937 assert!(q.subquery.is_some());
1938 assert_eq!(q.limit, Some(5));
1939 let sub = q.subquery.unwrap();
1940 assert_eq!(sub.table, "t");
1941 assert!(sub.where_clause.is_some());
1942 } else {
1943 panic!("Expected Select");
1944 }
1945 }
1946
1947 #[test]
1950 fn test_create_view_aggregate_subtraction() {
1951 let stmt = parse_query(
1952 "CREATE VIEW v AS SELECT token, SUM(sell) - SUM(buy) as net FROM orders GROUP BY token"
1953 ).unwrap();
1954 if let Statement::CreateView(cv) = stmt {
1955 assert_eq!(cv.view_name, "v");
1956 assert_eq!(cv.query.group_by, Some(vec!["token".into()]));
1957 if let ColumnList::Named(exprs) = &cv.query.columns {
1958 assert_eq!(exprs.len(), 2);
1959 assert_eq!(exprs[1].output_name(), "net");
1960 assert!(exprs[1].is_aggregate());
1961 } else {
1962 panic!("Expected Named columns");
1963 }
1964 } else {
1965 panic!("Expected CreateView, got {:?}", stmt);
1966 }
1967 }
1968
1969 #[test]
1970 fn test_delete_cascade() {
1971 let stmt = parse_query("DELETE FROM strategies WHERE status = 'KILLED' CASCADE").unwrap();
1972 if let Statement::Delete(q) = stmt {
1973 assert_eq!(q.table, "strategies");
1974 assert!(q.where_clause.is_some());
1975 assert_eq!(q.mode, DeleteMode::Cascade);
1976 } else {
1977 panic!("Expected Delete");
1978 }
1979 }
1980
1981 #[test]
1982 fn test_delete_restrict() {
1983 let stmt = parse_query("DELETE FROM strategies WHERE path = 'alpha.md' RESTRICT").unwrap();
1984 if let Statement::Delete(q) = stmt {
1985 assert_eq!(q.table, "strategies");
1986 assert_eq!(q.mode, DeleteMode::Restrict);
1987 } else {
1988 panic!("Expected Delete");
1989 }
1990 }
1991
1992 #[test]
1993 fn test_delete_default_unchanged() {
1994 let stmt = parse_query("DELETE FROM strategies WHERE status = 'KILLED'").unwrap();
1995 if let Statement::Delete(q) = stmt {
1996 assert_eq!(q.mode, DeleteMode::Default);
1997 } else {
1998 panic!("Expected Delete");
1999 }
2000 }
2001
2002 #[test]
2003 fn test_delete_cascade_no_where() {
2004 let stmt = parse_query("DELETE FROM strategies CASCADE").unwrap();
2005 if let Statement::Delete(q) = stmt {
2006 assert_eq!(q.table, "strategies");
2007 assert!(q.where_clause.is_none());
2008 assert_eq!(q.mode, DeleteMode::Cascade);
2009 } else {
2010 panic!("Expected Delete");
2011 }
2012 }
2013}