1use regex::Regex;
4use std::sync::LazyLock;
5
6use crate::errors::MdqlError;
7pub use crate::query_ast::*;
8
/// Words that `tokenize` classifies as keywords.
///
/// Matching is case-insensitive: a word is upper-cased before being looked up here,
/// and keyword tokens carry the upper-cased spelling as their value. Every other
/// word becomes an identifier token.
static KEYWORDS: &[&str] = &[
    // Query clauses, predicates and boolean logic.
    "SELECT", "FROM", "WHERE", "AND", "OR", "ORDER", "BY",
    "ASC", "DESC", "LIMIT", "LIKE", "IN", "IS", "NOT", "NULL",
    "JOIN", "ON", "AS", "GROUP", "HAVING",
    // Data modification statements.
    "INSERT", "INTO", "VALUES", "UPDATE", "SET", "DELETE",
    // Schema-changing statements.
    "ALTER", "TABLE", "RENAME", "FIELD", "TO", "DROP", "MERGE", "FIELDS",
    // CASE expressions.
    "CASE", "WHEN", "THEN", "ELSE", "END",
    // Date arithmetic.
    "INTERVAL", "DAY", "DAYS", "CURRENT_DATE", "CURRENT_TIMESTAMP", "DATEDIFF",
    // View management.
    "CREATE", "VIEW", "CASCADE", "RESTRICT",
];

/// Function names that are parsed as aggregates when immediately followed by `(`.
static AGG_FUNCS: &[&str] = &["COUNT", "SUM", "AVG", "MIN", "MAX"];
23
24static TOKEN_RE: LazyLock<Regex> = LazyLock::new(|| {
25 Regex::new(
26 r#"(?x)
27 \s*(?:
28 (?P<backtick>`[^`]+`)
29 | (?P<string>'(?:[^'\\]|\\.)*')
30 | (?P<number>\d+(?:\.\d+)?)
31 | (?P<op><=|>=|!=|[=<>,*()+\-/%])
32 | (?P<word>[A-Za-z_][A-Za-z0-9_./-]*)
33 )"#,
34 )
35 .unwrap()
36});
37
/// One lexical token produced by `tokenize`.
#[derive(Clone, Debug)]
struct Token {
    /// Token category: `"ident"`, `"string"`, `"number"`, `"op"` or `"keyword"`.
    token_type: String,
    /// Normalised text: surrounding quotes/backticks removed, keywords upper-cased.
    value: String,
    /// The token exactly as it appeared in the query text.
    raw: String,
}
44
45fn tokenize(sql: &str) -> Vec<Token> {
46 let mut tokens = Vec::new();
47 for caps in TOKEN_RE.captures_iter(sql) {
48 if let Some(m) = caps.name("backtick") {
49 let raw = m.as_str();
50 tokens.push(Token {
51 token_type: "ident".into(),
52 value: raw[1..raw.len() - 1].into(),
53 raw: raw.into(),
54 });
55 } else if let Some(m) = caps.name("string") {
56 let raw = m.as_str();
57 tokens.push(Token {
58 token_type: "string".into(),
59 value: raw[1..raw.len() - 1].into(),
60 raw: raw.into(),
61 });
62 } else if let Some(m) = caps.name("number") {
63 let raw = m.as_str();
64 tokens.push(Token {
65 token_type: "number".into(),
66 value: raw.into(),
67 raw: raw.into(),
68 });
69 } else if let Some(m) = caps.name("op") {
70 let raw = m.as_str();
71 tokens.push(Token {
72 token_type: "op".into(),
73 value: raw.into(),
74 raw: raw.into(),
75 });
76 } else if let Some(m) = caps.name("word") {
77 let raw = m.as_str();
78 if KEYWORDS.contains(&raw.to_uppercase().as_str()) {
79 tokens.push(Token {
80 token_type: "keyword".into(),
81 value: raw.to_uppercase(),
82 raw: raw.into(),
83 });
84 } else {
85 tokens.push(Token {
86 token_type: "ident".into(),
87 value: raw.into(),
88 raw: raw.into(),
89 });
90 }
91 }
92 }
93 tokens
94}
95
96struct Parser {
99 tokens: Vec<Token>,
100 pos: usize,
101}
102
103impl Parser {
104 fn new(tokens: Vec<Token>) -> Self {
105 Parser { tokens, pos: 0 }
106 }
107
108 fn peek(&self) -> Option<&Token> {
109 self.tokens.get(self.pos)
110 }
111
112 fn advance(&mut self) -> Token {
113 let t = self.tokens[self.pos].clone();
114 self.pos += 1;
115 t
116 }
117
118 fn expect(&mut self, type_: &str, value: Option<&str>) -> Result<Token, MdqlError> {
119 let t = self.peek().ok_or_else(|| {
120 MdqlError::QueryParse(format!(
121 "Unexpected end of query, expected {}",
122 value.unwrap_or(type_)
123 ))
124 })?;
125 let matches_type = t.token_type == type_;
126 let matches_value = value.map_or(true, |v| t.value == v);
127 if !matches_type || !matches_value {
128 return Err(MdqlError::QueryParse(format!(
129 "Expected {}, got '{}' at position {}",
130 value.unwrap_or(type_),
131 t.raw,
132 self.pos
133 )));
134 }
135 Ok(self.advance())
136 }
137
138 fn match_keyword(&mut self, kw: &str) -> bool {
139 if let Some(t) = self.peek() {
140 if t.token_type == "keyword" && t.value == kw {
141 self.advance();
142 return true;
143 }
144 }
145 false
146 }
147
148 fn parse_statement(&mut self) -> Result<Statement, MdqlError> {
149 let t = self.peek().ok_or_else(|| MdqlError::QueryParse("Empty query".into()))?;
150 match (t.token_type.as_str(), t.value.as_str()) {
151 ("keyword", "SELECT") => Ok(Statement::Select(self.parse_select()?)),
152 ("keyword", "INSERT") => Ok(Statement::Insert(self.parse_insert()?)),
153 ("keyword", "UPDATE") => Ok(Statement::Update(self.parse_update()?)),
154 ("keyword", "DELETE") => Ok(Statement::Delete(self.parse_delete()?)),
155 ("keyword", "ALTER") => self.parse_alter(),
156 ("keyword", "CREATE") => self.parse_create_view(),
157 ("keyword", "DROP") => self.parse_drop_view(),
158 _ => Err(MdqlError::QueryParse(format!(
159 "Expected SELECT, INSERT, UPDATE, DELETE, ALTER, CREATE, or DROP, got '{}'",
160 t.raw
161 ))),
162 }
163 }
164
165 fn parse_select(&mut self) -> Result<SelectQuery, MdqlError> {
166 self.expect("keyword", Some("SELECT"))?;
167 let columns = self.parse_columns()?;
168 self.expect("keyword", Some("FROM"))?;
169 let table = self.parse_ident()?;
170
171 let mut table_alias = None;
173 if let Some(t) = self.peek() {
174 if t.token_type == "ident" && !self.is_clause_keyword(t) {
175 table_alias = Some(self.advance().value);
176 }
177 }
178
179 let mut joins = Vec::new();
181 while self.match_keyword("JOIN") {
182 let join_table = self.parse_ident()?;
183 let mut join_alias = None;
184 if let Some(t) = self.peek() {
185 if t.token_type == "ident" && !self.is_clause_keyword(t) {
186 join_alias = Some(self.advance().value);
187 }
188 }
189 self.expect("keyword", Some("ON"))?;
190 let left_col = self.parse_ident()?;
191 self.expect("op", Some("="))?;
192 let right_col = self.parse_ident()?;
193 joins.push(JoinClause {
194 table: join_table,
195 alias: join_alias,
196 left_col,
197 right_col,
198 });
199 }
200
201 let mut where_clause = None;
202 if self.match_keyword("WHERE") {
203 where_clause = Some(self.parse_or_expr()?);
204 }
205
206 let mut group_by = None;
207 if self.match_keyword("GROUP") {
208 self.expect("keyword", Some("BY"))?;
209 let mut cols = vec![self.parse_ident()?];
210 while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
211 self.advance();
212 cols.push(self.parse_ident()?);
213 }
214 group_by = Some(cols);
215 }
216
217 let mut having = None;
218 if self.match_keyword("HAVING") {
219 having = Some(self.parse_or_expr()?);
220 }
221
222 let mut order_by = None;
223 if self.match_keyword("ORDER") {
224 self.expect("keyword", Some("BY"))?;
225 order_by = Some(self.parse_order_by()?);
226 }
227
228 let mut limit = None;
229 if self.match_keyword("LIMIT") {
230 let t = self.expect("number", None)?;
231 limit = Some(t.value.parse::<i64>().map_err(|_| {
232 MdqlError::QueryParse(format!("Invalid LIMIT value: {}", t.value))
233 })?);
234 }
235
236 self.expect_end()?;
237
238 Ok(SelectQuery {
239 columns,
240 table,
241 table_alias,
242 joins,
243 where_clause,
244 group_by,
245 having,
246 order_by,
247 limit,
248 })
249 }
250
251 fn parse_insert(&mut self) -> Result<InsertQuery, MdqlError> {
252 self.expect("keyword", Some("INSERT"))?;
253 self.expect("keyword", Some("INTO"))?;
254 let table = self.parse_ident()?;
255
256 self.expect("op", Some("("))?;
257 let mut columns = vec![self.parse_ident()?];
258 while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
259 self.advance();
260 columns.push(self.parse_ident()?);
261 }
262 self.expect("op", Some(")"))?;
263
264 self.expect("keyword", Some("VALUES"))?;
265
266 self.expect("op", Some("("))?;
267 let mut values = vec![self.parse_value()?];
268 while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
269 self.advance();
270 values.push(self.parse_value()?);
271 }
272 self.expect("op", Some(")"))?;
273
274 if columns.len() != values.len() {
275 return Err(MdqlError::QueryParse(format!(
276 "Column count ({}) does not match value count ({})",
277 columns.len(),
278 values.len()
279 )));
280 }
281
282 self.expect_end()?;
283 Ok(InsertQuery {
284 table,
285 columns,
286 values,
287 })
288 }
289
290 fn parse_update(&mut self) -> Result<UpdateQuery, MdqlError> {
291 self.expect("keyword", Some("UPDATE"))?;
292 let table = self.parse_ident()?;
293 self.expect("keyword", Some("SET"))?;
294
295 let mut assignments = Vec::new();
296 let col = self.parse_ident()?;
297 self.expect("op", Some("="))?;
298 let val = self.parse_value()?;
299 assignments.push((col, val));
300
301 while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
302 self.advance();
303 let col = self.parse_ident()?;
304 self.expect("op", Some("="))?;
305 let val = self.parse_value()?;
306 assignments.push((col, val));
307 }
308
309 let mut where_clause = None;
310 if self.match_keyword("WHERE") {
311 where_clause = Some(self.parse_or_expr()?);
312 }
313
314 self.expect_end()?;
315 Ok(UpdateQuery {
316 table,
317 assignments,
318 where_clause,
319 })
320 }
321
322 fn parse_delete(&mut self) -> Result<DeleteQuery, MdqlError> {
323 self.expect("keyword", Some("DELETE"))?;
324 self.expect("keyword", Some("FROM"))?;
325 let table = self.parse_ident()?;
326
327 let mut where_clause = None;
328 if self.match_keyword("WHERE") {
329 where_clause = Some(self.parse_or_expr()?);
330 }
331
332 self.expect_end()?;
333 Ok(DeleteQuery {
334 table,
335 where_clause,
336 })
337 }
338
339 fn parse_alter(&mut self) -> Result<Statement, MdqlError> {
340 self.expect("keyword", Some("ALTER"))?;
341 self.expect("keyword", Some("TABLE"))?;
342 let table = self.parse_ident()?;
343
344 let t = self.peek().ok_or_else(|| {
345 MdqlError::QueryParse("Expected RENAME, DROP, or MERGE after table name".into())
346 })?;
347
348 match (t.token_type.as_str(), t.value.as_str()) {
349 ("keyword", "RENAME") => {
350 self.advance();
351 self.expect("keyword", Some("FIELD"))?;
352 let old_name = self.parse_string_or_ident()?;
353 self.expect("keyword", Some("TO"))?;
354 let new_name = self.parse_string_or_ident()?;
355 self.expect_end()?;
356 Ok(Statement::AlterRename(AlterRenameFieldQuery {
357 table,
358 old_name,
359 new_name,
360 }))
361 }
362 ("keyword", "DROP") => {
363 self.advance();
364 self.expect("keyword", Some("FIELD"))?;
365 let field_name = self.parse_string_or_ident()?;
366 self.expect_end()?;
367 Ok(Statement::AlterDrop(AlterDropFieldQuery {
368 table,
369 field_name,
370 }))
371 }
372 ("keyword", "MERGE") => {
373 self.advance();
374 self.expect("keyword", Some("FIELDS"))?;
375 let mut sources = vec![self.parse_string_or_ident()?];
376 while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
377 self.advance();
378 sources.push(self.parse_string_or_ident()?);
379 }
380 self.expect("keyword", Some("INTO"))?;
381 let target = self.parse_string_or_ident()?;
382 self.expect_end()?;
383 Ok(Statement::AlterMerge(AlterMergeFieldsQuery {
384 table,
385 sources,
386 into: target,
387 }))
388 }
389 _ => Err(MdqlError::QueryParse(format!(
390 "Expected RENAME, DROP, or MERGE, got '{}'",
391 t.raw
392 ))),
393 }
394 }
395
396 fn parse_create_view(&mut self) -> Result<Statement, MdqlError> {
397 self.expect("keyword", Some("CREATE"))?;
398 self.expect("keyword", Some("VIEW"))?;
399 let view_name = self.parse_ident()?;
400
401 let columns = if self.peek().map_or(false, |t| t.token_type == "op" && t.value == "(") {
402 self.advance();
403 let mut cols = vec![self.parse_ident()?];
404 while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
405 self.advance();
406 cols.push(self.parse_ident()?);
407 }
408 self.expect("op", Some(")"))?;
409 Some(cols)
410 } else {
411 None
412 };
413
414 self.expect("keyword", Some("AS"))?;
415 let query = Box::new(self.parse_select()?);
416
417 Ok(Statement::CreateView(CreateViewQuery {
418 view_name,
419 columns,
420 query,
421 }))
422 }
423
424 fn parse_drop_view(&mut self) -> Result<Statement, MdqlError> {
425 self.expect("keyword", Some("DROP"))?;
426 self.expect("keyword", Some("VIEW"))?;
427 let view_name = self.parse_ident()?;
428 self.expect_end()?;
429 Ok(Statement::DropView(DropViewQuery { view_name }))
430 }
431
432 fn parse_string_or_ident(&mut self) -> Result<String, MdqlError> {
433 let t = self.peek().ok_or_else(|| {
434 MdqlError::QueryParse("Expected field name, got end of query".into())
435 })?;
436 match t.token_type.as_str() {
437 "string" => {
438 let v = self.advance().value;
439 Ok(v)
440 }
441 "ident" | "keyword" => {
442 let v = self.advance().value;
443 Ok(v)
444 }
445 _ => Err(MdqlError::QueryParse(format!(
446 "Expected field name, got '{}'",
447 t.raw
448 ))),
449 }
450 }
451
452 fn parse_columns(&mut self) -> Result<ColumnList, MdqlError> {
453 if let Some(t) = self.peek() {
454 if t.token_type == "op" && t.value == "*" {
455 self.advance();
456 return Ok(ColumnList::All);
457 }
458 }
459
460 let mut exprs = vec![self.parse_select_expr()?];
461 while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
462 self.advance();
463 exprs.push(self.parse_select_expr()?);
464 }
465 Ok(ColumnList::Named(exprs))
466 }
467
468 fn peek_is_agg_func(&self) -> bool {
469 let t = match self.peek() {
470 Some(t) => t,
471 None => return false,
472 };
473 let name_upper = t.value.to_uppercase();
474 if !AGG_FUNCS.contains(&name_upper.as_str()) {
475 return false;
476 }
477 self.tokens
479 .get(self.pos + 1)
480 .map_or(false, |next| next.token_type == "op" && next.value == "(")
481 }
482
483 fn parse_select_expr(&mut self) -> Result<SelectExpr, MdqlError> {
484 let _t = self.peek().ok_or_else(|| {
485 MdqlError::QueryParse("Expected column or aggregate, got end of query".into())
486 })?;
487
488 if self.peek_is_agg_func() {
489 let func_name = self.advance().value.to_uppercase();
490 let func = match func_name.as_str() {
491 "COUNT" => AggFunc::Count,
492 "SUM" => AggFunc::Sum,
493 "AVG" => AggFunc::Avg,
494 "MIN" => AggFunc::Min,
495 "MAX" => AggFunc::Max,
496 _ => unreachable!(),
497 };
498 self.expect("op", Some("("))?;
499 let (arg, arg_expr) = if self.peek().map_or(false, |t| t.token_type == "op" && t.value == "*") {
500 self.advance();
501 ("*".to_string(), None)
502 } else {
503 let expr = self.parse_additive()?;
505 if let Expr::Column(name) = &expr {
506 (name.clone(), None)
507 } else {
508 (expr.display_name(), Some(expr))
509 }
510 };
511 self.expect("op", Some(")"))?;
512
513 let alias = if self.match_keyword("AS") {
514 Some(self.parse_ident()?)
515 } else if self.peek().map_or(false, |t| {
516 t.token_type == "ident" && !self.is_clause_keyword(t)
517 }) {
518 Some(self.advance().value)
519 } else {
520 None
521 };
522
523 Ok(SelectExpr::Aggregate { func, arg, arg_expr, alias })
524 } else {
525 let expr = self.parse_additive()?;
527
528 let alias = if self.match_keyword("AS") {
530 Some(self.parse_ident()?)
531 } else if self.peek().map_or(false, |t| {
532 t.token_type == "ident" && !self.is_clause_keyword(t)
533 }) {
534 Some(self.advance().value)
535 } else {
536 None
537 };
538
539 if alias.is_none() {
542 if let Expr::Column(name) = &expr {
543 return Ok(SelectExpr::Column(name.clone()));
544 }
545 }
546
547 Ok(SelectExpr::Expr { expr, alias })
548 }
549 }
550
551 fn peek_is_additive_op(&self) -> bool {
554 self.peek().map_or(false, |t| {
555 t.token_type == "op" && (t.value == "+" || t.value == "-")
556 })
557 }
558
559 fn peek_is_multiplicative_op(&self) -> bool {
560 self.peek().map_or(false, |t| {
561 t.token_type == "op" && (t.value == "*" || t.value == "/" || t.value == "%")
562 })
563 }
564
565 fn parse_additive(&mut self) -> Result<Expr, MdqlError> {
566 let mut left = self.parse_multiplicative()?;
567 while self.peek_is_additive_op() {
568 let op_tok = self.advance();
569 let is_sub = op_tok.value == "-";
570
571 if self.peek().map_or(false, |t| t.token_type == "keyword" && t.value == "INTERVAL") {
573 self.advance(); let days_expr = self.parse_multiplicative()?;
575 if !self.match_keyword("DAY") && !self.match_keyword("DAYS") {
577 return Err(MdqlError::QueryParse("Expected DAY after INTERVAL value".into()));
578 }
579 let days = if is_sub {
580 Expr::UnaryMinus(Box::new(days_expr))
581 } else {
582 days_expr
583 };
584 left = Expr::DateAdd {
585 date: Box::new(left),
586 days: Box::new(days),
587 };
588 continue;
589 }
590
591 let op = match op_tok.value.as_str() {
592 "+" => ArithOp::Add,
593 "-" => ArithOp::Sub,
594 _ => unreachable!(),
595 };
596 let right = self.parse_multiplicative()?;
597 left = Expr::BinaryOp {
598 left: Box::new(left),
599 op,
600 right: Box::new(right),
601 };
602 }
603 Ok(left)
604 }
605
606 fn parse_multiplicative(&mut self) -> Result<Expr, MdqlError> {
607 let mut left = self.parse_unary()?;
608 while self.peek_is_multiplicative_op() {
609 let op_tok = self.advance();
610 let op = match op_tok.value.as_str() {
611 "*" => ArithOp::Mul,
612 "/" => ArithOp::Div,
613 "%" => ArithOp::Mod,
614 _ => unreachable!(),
615 };
616 let right = self.parse_unary()?;
617 left = Expr::BinaryOp {
618 left: Box::new(left),
619 op,
620 right: Box::new(right),
621 };
622 }
623 Ok(left)
624 }
625
626 fn parse_unary(&mut self) -> Result<Expr, MdqlError> {
627 if self.peek().map_or(false, |t| t.token_type == "op" && t.value == "-") {
628 self.advance();
629 let inner = self.parse_atom()?;
630 match inner {
632 Expr::Literal(SqlValue::Int(n)) => Ok(Expr::Literal(SqlValue::Int(-n))),
633 Expr::Literal(SqlValue::Float(f)) => Ok(Expr::Literal(SqlValue::Float(-f))),
634 _ => Ok(Expr::UnaryMinus(Box::new(inner))),
635 }
636 } else {
637 self.parse_atom()
638 }
639 }
640
641 fn parse_atom(&mut self) -> Result<Expr, MdqlError> {
642 let t = self.peek().ok_or_else(|| {
643 MdqlError::QueryParse("Expected expression, got end of query".into())
644 })?;
645
646 match t.token_type.as_str() {
647 "number" => {
648 let v = self.advance().value;
649 if v.contains('.') {
650 let f: f64 = v.parse().map_err(|_| {
651 MdqlError::QueryParse(format!("Invalid float: {}", v))
652 })?;
653 Ok(Expr::Literal(SqlValue::Float(f)))
654 } else {
655 let n: i64 = v.parse().map_err(|_| {
656 MdqlError::QueryParse(format!("Invalid int: {}", v))
657 })?;
658 Ok(Expr::Literal(SqlValue::Int(n)))
659 }
660 }
661 "string" => {
662 let v = self.advance().value;
663 Ok(Expr::Literal(SqlValue::String(v)))
664 }
665 "keyword" if t.value == "NULL" => {
666 self.advance();
667 Ok(Expr::Literal(SqlValue::Null))
668 }
669 "keyword" if t.value == "CASE" => {
670 self.parse_case_expr()
671 }
672 "keyword" if t.value == "CURRENT_DATE" => {
673 self.advance();
674 Ok(Expr::CurrentDate)
675 }
676 "keyword" if t.value == "CURRENT_TIMESTAMP" => {
677 self.advance();
678 Ok(Expr::CurrentTimestamp)
679 }
680 "keyword" if t.value == "DATEDIFF" => {
681 self.advance();
682 self.expect("op", Some("("))?;
683 let left = self.parse_additive()?;
684 self.expect("op", Some(","))?;
685 let right = self.parse_additive()?;
686 self.expect("op", Some(")"))?;
687 Ok(Expr::DateDiff { left: Box::new(left), right: Box::new(right) })
688 }
689 "op" if t.value == "(" => {
690 self.advance();
691 let expr = self.parse_additive()?;
692 self.expect("op", Some(")"))?;
693 Ok(expr)
694 }
695 "ident" => {
696 let name = self.advance().value;
697 Ok(Expr::Column(name))
698 }
699 "keyword" if !Self::is_reserved_keyword(&t.value) => {
700 let name = self.advance().value;
701 Ok(Expr::Column(name))
702 }
703 _ => Err(MdqlError::QueryParse(format!(
704 "Expected expression, got '{}'",
705 t.raw
706 ))),
707 }
708 }
709
710 fn parse_case_expr(&mut self) -> Result<Expr, MdqlError> {
711 self.expect("keyword", Some("CASE"))?;
712 let mut whens = Vec::new();
713 while self.match_keyword("WHEN") {
714 let condition = self.parse_or_expr()?;
715 self.expect("keyword", Some("THEN"))?;
716 let result = self.parse_additive()?;
717 whens.push((condition, Box::new(result)));
718 }
719 if whens.is_empty() {
720 return Err(MdqlError::QueryParse("CASE requires at least one WHEN clause".into()));
721 }
722 let else_expr = if self.match_keyword("ELSE") {
723 Some(Box::new(self.parse_additive()?))
724 } else {
725 None
726 };
727 self.expect("keyword", Some("END"))?;
728 Ok(Expr::Case { whens, else_expr })
729 }
730
731 fn parse_ident(&mut self) -> Result<String, MdqlError> {
732 let t = self.peek().ok_or_else(|| {
733 MdqlError::QueryParse("Expected identifier, got end of query".into())
734 })?;
735 match t.token_type.as_str() {
736 "ident" | "keyword" => {
737 let v = self.advance().value;
738 Ok(v)
739 }
740 _ => Err(MdqlError::QueryParse(format!(
741 "Expected identifier, got '{}'",
742 t.raw
743 ))),
744 }
745 }
746
747 fn parse_or_expr(&mut self) -> Result<WhereClause, MdqlError> {
748 let mut left = self.parse_and_expr()?;
749 while self.match_keyword("OR") {
750 let right = self.parse_and_expr()?;
751 left = WhereClause::BoolOp(BoolOp {
752 op: "OR".into(),
753 left: Box::new(left),
754 right: Box::new(right),
755 });
756 }
757 Ok(left)
758 }
759
760 fn parse_and_expr(&mut self) -> Result<WhereClause, MdqlError> {
761 let mut left = self.parse_comparison()?;
762 while self.match_keyword("AND") {
763 let right = self.parse_comparison()?;
764 left = WhereClause::BoolOp(BoolOp {
765 op: "AND".into(),
766 left: Box::new(left),
767 right: Box::new(right),
768 });
769 }
770 Ok(left)
771 }
772
773 fn parse_comparison(&mut self) -> Result<WhereClause, MdqlError> {
774 if self.peek().map_or(false, |t| t.token_type == "op" && t.value == "(") {
776 let saved_pos = self.pos;
778 self.advance();
779 let result = self.parse_or_expr();
781 if result.is_ok() && self.peek().map_or(false, |t| t.token_type == "op" && t.value == ")") {
782 self.advance();
783 return result;
784 }
785 self.pos = saved_pos;
787 }
788
789 let left_expr = self.parse_additive()?;
791
792 let col = left_expr.as_column().unwrap_or("").to_string();
794
795 if self.match_keyword("IS") {
797 if self.match_keyword("NOT") {
798 self.expect("keyword", Some("NULL"))?;
799 return Ok(WhereClause::Comparison(Comparison {
800 column: col,
801 op: "IS NOT NULL".into(),
802 value: None,
803 left_expr: Some(left_expr),
804 right_expr: None,
805 }));
806 }
807 self.expect("keyword", Some("NULL"))?;
808 return Ok(WhereClause::Comparison(Comparison {
809 column: col,
810 op: "IS NULL".into(),
811 value: None,
812 left_expr: Some(left_expr),
813 right_expr: None,
814 }));
815 }
816
817 if self.match_keyword("IN") {
819 self.expect("op", Some("("))?;
820 let mut values = vec![self.parse_value()?];
821 while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
822 self.advance();
823 values.push(self.parse_value()?);
824 }
825 self.expect("op", Some(")"))?;
826 return Ok(WhereClause::Comparison(Comparison {
827 column: col,
828 op: "IN".into(),
829 value: Some(SqlValue::List(values)),
830 left_expr: Some(left_expr),
831 right_expr: None,
832 }));
833 }
834
835 if self.match_keyword("LIKE") {
837 let val = self.parse_value()?;
838 return Ok(WhereClause::Comparison(Comparison {
839 column: col,
840 op: "LIKE".into(),
841 value: Some(val),
842 left_expr: Some(left_expr),
843 right_expr: None,
844 }));
845 }
846
847 if self.match_keyword("NOT") {
849 if self.match_keyword("LIKE") {
850 let val = self.parse_value()?;
851 return Ok(WhereClause::Comparison(Comparison {
852 column: col,
853 op: "NOT LIKE".into(),
854 value: Some(val),
855 left_expr: Some(left_expr),
856 right_expr: None,
857 }));
858 }
859 return Err(MdqlError::QueryParse("Expected LIKE after NOT".into()));
860 }
861
862 if let Some(t) = self.peek() {
864 if t.token_type == "op" && ["=", "!=", "<", ">", "<=", ">="].contains(&t.value.as_str())
865 {
866 let op = self.advance().value;
867 let right_expr = self.parse_additive()?;
869 let value = match &right_expr {
871 Expr::Literal(v) => Some(v.clone()),
872 _ => None,
873 };
874 return Ok(WhereClause::Comparison(Comparison {
875 column: col,
876 op,
877 value,
878 left_expr: Some(left_expr),
879 right_expr: Some(right_expr),
880 }));
881 }
882 }
883
884 let got = self.peek().map_or("end".to_string(), |t| t.raw.clone());
885 Err(MdqlError::QueryParse(format!(
886 "Expected operator after '{}', got '{}'",
887 left_expr.display_name(), got
888 )))
889 }
890
891 fn parse_value(&mut self) -> Result<SqlValue, MdqlError> {
892 let t = self.peek().ok_or_else(|| {
893 MdqlError::QueryParse("Expected value, got end of query".into())
894 })?;
895 match t.token_type.as_str() {
896 "string" => {
897 let v = self.advance().value;
898 Ok(SqlValue::String(v))
899 }
900 "number" => {
901 let v = self.advance().value;
902 if v.contains('.') {
903 Ok(SqlValue::Float(v.parse().map_err(|_| {
904 MdqlError::QueryParse(format!("Invalid float: {}", v))
905 })?))
906 } else {
907 Ok(SqlValue::Int(v.parse().map_err(|_| {
908 MdqlError::QueryParse(format!("Invalid int: {}", v))
909 })?))
910 }
911 }
912 "keyword" if t.value == "NULL" => {
913 self.advance();
914 Ok(SqlValue::Null)
915 }
916 _ => Err(MdqlError::QueryParse(format!(
917 "Expected value, got '{}'",
918 t.raw
919 ))),
920 }
921 }
922
923 fn parse_order_by(&mut self) -> Result<Vec<OrderSpec>, MdqlError> {
924 let mut specs = vec![self.parse_order_spec()?];
925 while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
926 self.advance();
927 specs.push(self.parse_order_spec()?);
928 }
929 Ok(specs)
930 }
931
932 fn parse_order_spec(&mut self) -> Result<OrderSpec, MdqlError> {
933 let expr = self.parse_additive()?;
934 let col = expr.as_column().unwrap_or("").to_string();
935 let descending = if self.match_keyword("DESC") {
936 true
937 } else {
938 self.match_keyword("ASC");
939 false
940 };
941 Ok(OrderSpec {
942 column: col,
943 expr: Some(expr),
944 descending,
945 })
946 }
947
948 fn is_clause_keyword(&self, t: &Token) -> bool {
949 t.token_type == "keyword"
950 && ["WHERE", "ORDER", "LIMIT", "JOIN", "ON", "GROUP"].contains(&t.value.as_str())
951 }
952
953 fn is_reserved_keyword(kw: &str) -> bool {
955 matches!(kw,
956 "AS" | "FROM" | "WHERE" | "AND" | "OR" | "ORDER" | "BY"
957 | "ASC" | "DESC" | "LIMIT" | "JOIN" | "ON" | "GROUP"
958 | "SELECT" | "INSERT" | "INTO" | "VALUES" | "UPDATE" | "SET"
959 | "DELETE" | "ALTER" | "TABLE" | "IS" | "NOT" | "IN" | "LIKE"
960 | "RENAME" | "FIELD" | "TO" | "DROP" | "MERGE" | "FIELDS"
961 | "CASE" | "WHEN" | "THEN" | "ELSE" | "END"
962 | "HAVING" | "INTERVAL" | "DAY" | "DAYS"
963 | "CURRENT_DATE" | "CURRENT_TIMESTAMP" | "DATEDIFF"
964 | "CREATE" | "VIEW" | "CASCADE" | "RESTRICT"
965 )
966 }
967
968 fn expect_end(&self) -> Result<(), MdqlError> {
969 if let Some(t) = self.peek() {
970 return Err(MdqlError::QueryParse(format!(
971 "Unexpected token '{}' at position {}",
972 t.raw, self.pos
973 )));
974 }
975 Ok(())
976 }
977}
978
979pub fn parse_query(sql: &str) -> crate::errors::Result<Statement> {
980 let tokens = tokenize(sql);
981 if tokens.is_empty() {
982 return Err(MdqlError::QueryParse("Empty query".into()));
983 }
984 let mut parser = Parser::new(tokens);
985 parser.parse_statement()
986}
987
988#[cfg(test)]
989mod tests {
990 use super::*;
991
992 #[test]
993 fn test_simple_select() {
994 let stmt = parse_query("SELECT title, status FROM strategies").unwrap();
995 if let Statement::Select(q) = stmt {
996 assert_eq!(q.columns, ColumnList::Named(vec![SelectExpr::Column("title".into()), SelectExpr::Column("status".into())]));
997 assert_eq!(q.table, "strategies");
998 } else {
999 panic!("Expected Select");
1000 }
1001 }
1002
1003 #[test]
1004 fn test_select_star() {
1005 let stmt = parse_query("SELECT * FROM test").unwrap();
1006 if let Statement::Select(q) = stmt {
1007 assert_eq!(q.columns, ColumnList::All);
1008 } else {
1009 panic!("Expected Select");
1010 }
1011 }
1012
1013 #[test]
1014 fn test_where_clause() {
1015 let stmt = parse_query("SELECT title FROM test WHERE count > 5").unwrap();
1016 if let Statement::Select(q) = stmt {
1017 assert!(q.where_clause.is_some());
1018 } else {
1019 panic!("Expected Select");
1020 }
1021 }
1022
1023 #[test]
1024 fn test_order_by() {
1025 let stmt =
1026 parse_query("SELECT title FROM test ORDER BY composite DESC, title ASC").unwrap();
1027 if let Statement::Select(q) = stmt {
1028 let ob = q.order_by.unwrap();
1029 assert_eq!(ob.len(), 2);
1030 assert!(ob[0].descending);
1031 assert!(!ob[1].descending);
1032 } else {
1033 panic!("Expected Select");
1034 }
1035 }
1036
1037 #[test]
1038 fn test_limit() {
1039 let stmt = parse_query("SELECT * FROM test LIMIT 10").unwrap();
1040 if let Statement::Select(q) = stmt {
1041 assert_eq!(q.limit, Some(10));
1042 } else {
1043 panic!("Expected Select");
1044 }
1045 }
1046
1047 #[test]
1048 fn test_insert() {
1049 let stmt = parse_query(
1050 "INSERT INTO test (title, count) VALUES ('Hello', 42)",
1051 )
1052 .unwrap();
1053 if let Statement::Insert(q) = stmt {
1054 assert_eq!(q.table, "test");
1055 assert_eq!(q.columns, vec!["title", "count"]);
1056 assert_eq!(q.values[0], SqlValue::String("Hello".into()));
1057 assert_eq!(q.values[1], SqlValue::Int(42));
1058 } else {
1059 panic!("Expected Insert");
1060 }
1061 }
1062
1063 #[test]
1064 fn test_update() {
1065 let stmt = parse_query("UPDATE test SET status = 'KILLED' WHERE path = 'a.md'").unwrap();
1066 if let Statement::Update(q) = stmt {
1067 assert_eq!(q.table, "test");
1068 assert_eq!(q.assignments.len(), 1);
1069 assert!(q.where_clause.is_some());
1070 } else {
1071 panic!("Expected Update");
1072 }
1073 }
1074
1075 #[test]
1076 fn test_delete() {
1077 let stmt = parse_query("DELETE FROM test WHERE status = 'draft'").unwrap();
1078 if let Statement::Delete(q) = stmt {
1079 assert_eq!(q.table, "test");
1080 assert!(q.where_clause.is_some());
1081 } else {
1082 panic!("Expected Delete");
1083 }
1084 }
1085
1086 #[test]
1087 fn test_alter_rename() {
1088 let stmt =
1089 parse_query("ALTER TABLE test RENAME FIELD 'Summary' TO 'Overview'").unwrap();
1090 if let Statement::AlterRename(q) = stmt {
1091 assert_eq!(q.old_name, "Summary");
1092 assert_eq!(q.new_name, "Overview");
1093 } else {
1094 panic!("Expected AlterRename");
1095 }
1096 }
1097
1098 #[test]
1099 fn test_alter_drop() {
1100 let stmt = parse_query("ALTER TABLE test DROP FIELD 'Details'").unwrap();
1101 if let Statement::AlterDrop(q) = stmt {
1102 assert_eq!(q.field_name, "Details");
1103 } else {
1104 panic!("Expected AlterDrop");
1105 }
1106 }
1107
1108 #[test]
1109 fn test_alter_merge() {
1110 let stmt = parse_query(
1111 "ALTER TABLE test MERGE FIELDS 'Entry Rules', 'Exit Rules' INTO 'Trading Rules'",
1112 )
1113 .unwrap();
1114 if let Statement::AlterMerge(q) = stmt {
1115 assert_eq!(q.sources, vec!["Entry Rules", "Exit Rules"]);
1116 assert_eq!(q.into, "Trading Rules");
1117 } else {
1118 panic!("Expected AlterMerge");
1119 }
1120 }
1121
1122 #[test]
1123 fn test_backtick_ident() {
1124 let stmt = parse_query("SELECT `Structural Mechanism` FROM test").unwrap();
1125 if let Statement::Select(q) = stmt {
1126 assert_eq!(
1127 q.columns,
1128 ColumnList::Named(vec![SelectExpr::Column("Structural Mechanism".into())])
1129 );
1130 } else {
1131 panic!("Expected Select");
1132 }
1133 }
1134
1135 #[test]
1136 fn test_like_operator() {
1137 let stmt = parse_query("SELECT title FROM test WHERE categories LIKE '%defi%'").unwrap();
1138 if let Statement::Select(q) = stmt {
1139 if let Some(WhereClause::Comparison(c)) = q.where_clause {
1140 assert_eq!(c.op, "LIKE");
1141 assert_eq!(c.value, Some(SqlValue::String("%defi%".into())));
1142 } else {
1143 panic!("Expected LIKE comparison");
1144 }
1145 } else {
1146 panic!("Expected Select");
1147 }
1148 }
1149
1150 #[test]
1151 fn test_in_operator() {
1152 let stmt =
1153 parse_query("SELECT * FROM test WHERE status IN ('ACTIVE', 'LIVE')").unwrap();
1154 if let Statement::Select(q) = stmt {
1155 if let Some(WhereClause::Comparison(c)) = q.where_clause {
1156 assert_eq!(c.op, "IN");
1157 } else {
1158 panic!("Expected IN comparison");
1159 }
1160 } else {
1161 panic!("Expected Select");
1162 }
1163 }
1164
1165 #[test]
1166 fn test_is_null() {
1167 let stmt = parse_query("SELECT * FROM test WHERE title IS NULL").unwrap();
1168 if let Statement::Select(q) = stmt {
1169 if let Some(WhereClause::Comparison(c)) = q.where_clause {
1170 assert_eq!(c.op, "IS NULL");
1171 } else {
1172 panic!("Expected IS NULL comparison");
1173 }
1174 } else {
1175 panic!("Expected Select");
1176 }
1177 }
1178
1179 #[test]
1180 fn test_and_or() {
1181 let stmt = parse_query(
1182 "SELECT * FROM test WHERE status = 'ACTIVE' AND count > 5 OR title LIKE '%test%'",
1183 )
1184 .unwrap();
1185 if let Statement::Select(q) = stmt {
1186 assert!(q.where_clause.is_some());
1187 } else {
1188 panic!("Expected Select");
1189 }
1190 }
1191
1192 #[test]
1193 fn test_join() {
1194 let stmt = parse_query(
1195 "SELECT s.title, b.sharpe FROM strategies s JOIN backtests b ON b.strategy = s.path",
1196 )
1197 .unwrap();
1198 if let Statement::Select(q) = stmt {
1199 assert_eq!(q.table, "strategies");
1200 assert_eq!(q.table_alias, Some("s".into()));
1201 assert_eq!(q.joins.len(), 1);
1202 let join = &q.joins[0];
1203 assert_eq!(join.table, "backtests");
1204 assert_eq!(join.alias, Some("b".into()));
1205 } else {
1206 panic!("Expected Select");
1207 }
1208 }
1209
1210 #[test]
1211 fn test_multi_join() {
1212 let stmt = parse_query(
1213 "SELECT s.title, b.sharpe, c.verdict FROM strategies s JOIN backtests b ON b.strategy = s.path JOIN critiques c ON c.strategy = s.path",
1214 )
1215 .unwrap();
1216 if let Statement::Select(q) = stmt {
1217 assert_eq!(q.table, "strategies");
1218 assert_eq!(q.table_alias, Some("s".into()));
1219 assert_eq!(q.joins.len(), 2);
1220 assert_eq!(q.joins[0].table, "backtests");
1221 assert_eq!(q.joins[0].alias, Some("b".into()));
1222 assert_eq!(q.joins[0].left_col, "b.strategy");
1223 assert_eq!(q.joins[0].right_col, "s.path");
1224 assert_eq!(q.joins[1].table, "critiques");
1225 assert_eq!(q.joins[1].alias, Some("c".into()));
1226 assert_eq!(q.joins[1].left_col, "c.strategy");
1227 assert_eq!(q.joins[1].right_col, "s.path");
1228 } else {
1229 panic!("Expected Select");
1230 }
1231 }
1232
1233 #[test]
1234 fn test_empty_query() {
1235 assert!(parse_query("").is_err());
1236 }
1237
1238 #[test]
1239 fn test_count_star() {
1240 let stmt = parse_query("SELECT status, COUNT(*) AS cnt FROM strategies GROUP BY status").unwrap();
1241 if let Statement::Select(q) = stmt {
1242 if let ColumnList::Named(exprs) = &q.columns {
1243 assert_eq!(exprs.len(), 2);
1244 assert_eq!(exprs[0], SelectExpr::Column("status".into()));
1245 assert!(matches!(&exprs[1], SelectExpr::Aggregate {
1246 func: AggFunc::Count,
1247 arg,
1248 alias: Some(a),
1249 ..
1250 } if arg == "*" && a == "cnt"));
1251 } else {
1252 panic!("Expected Named columns");
1253 }
1254 assert_eq!(q.group_by, Some(vec!["status".into()]));
1255 } else {
1256 panic!("Expected Select");
1257 }
1258 }
1259
1260 #[test]
1261 fn test_count_column_as_ident() {
1262 let stmt = parse_query("INSERT INTO test (title, count) VALUES ('Hello', 42)").unwrap();
1264 if let Statement::Insert(q) = stmt {
1265 assert_eq!(q.columns, vec!["title", "count"]);
1266 } else {
1267 panic!("Expected Insert");
1268 }
1269 }
1270
1271 #[test]
1272 fn test_multiple_aggregates() {
1273 let stmt = parse_query("SELECT MIN(composite), MAX(composite), AVG(composite) FROM strategies").unwrap();
1274 if let Statement::Select(q) = stmt {
1275 if let ColumnList::Named(exprs) = &q.columns {
1276 assert_eq!(exprs.len(), 3);
1277 assert!(matches!(&exprs[0], SelectExpr::Aggregate { func: AggFunc::Min, .. }));
1278 assert!(matches!(&exprs[1], SelectExpr::Aggregate { func: AggFunc::Max, .. }));
1279 assert!(matches!(&exprs[2], SelectExpr::Aggregate { func: AggFunc::Avg, .. }));
1280 } else {
1281 panic!("Expected Named columns");
1282 }
1283 assert_eq!(q.group_by, None);
1284 } else {
1285 panic!("Expected Select");
1286 }
1287 }
1288
1289 #[test]
1292 fn test_select_arithmetic_expr() {
1293 let stmt = parse_query("SELECT a + b FROM test").unwrap();
1294 if let Statement::Select(q) = stmt {
1295 if let ColumnList::Named(exprs) = &q.columns {
1296 assert_eq!(exprs.len(), 1);
1297 assert!(matches!(&exprs[0], SelectExpr::Expr {
1298 expr: Expr::BinaryOp { op: ArithOp::Add, .. },
1299 alias: None,
1300 }));
1301 } else {
1302 panic!("Expected Named columns");
1303 }
1304 } else {
1305 panic!("Expected Select");
1306 }
1307 }
1308
1309 #[test]
1310 fn test_select_arithmetic_with_alias() {
1311 let stmt = parse_query("SELECT a + b AS total FROM test").unwrap();
1312 if let Statement::Select(q) = stmt {
1313 if let ColumnList::Named(exprs) = &q.columns {
1314 assert_eq!(exprs.len(), 1);
1315 assert!(matches!(&exprs[0], SelectExpr::Expr {
1316 alias: Some(a),
1317 ..
1318 } if a == "total"));
1319 assert_eq!(exprs[0].output_name(), "total");
1320 } else {
1321 panic!("Expected Named columns");
1322 }
1323 } else {
1324 panic!("Expected Select");
1325 }
1326 }
1327
1328 #[test]
1329 fn test_select_precedence() {
1330 let stmt = parse_query("SELECT a + b * c FROM test").unwrap();
1332 if let Statement::Select(q) = stmt {
1333 if let ColumnList::Named(exprs) = &q.columns {
1334 if let SelectExpr::Expr { expr, .. } = &exprs[0] {
1335 if let Expr::BinaryOp { left, op, right } = expr {
1336 assert_eq!(*op, ArithOp::Add);
1337 assert!(matches!(left.as_ref(), Expr::Column(n) if n == "a"));
1338 assert!(matches!(right.as_ref(), Expr::BinaryOp { op: ArithOp::Mul, .. }));
1339 } else {
1340 panic!("Expected BinaryOp");
1341 }
1342 } else {
1343 panic!("Expected Expr variant");
1344 }
1345 } else {
1346 panic!("Expected Named columns");
1347 }
1348 } else {
1349 panic!("Expected Select");
1350 }
1351 }
1352
1353 #[test]
1354 fn test_select_parenthesized_expr() {
1355 let stmt = parse_query("SELECT (a + b) * c FROM test").unwrap();
1357 if let Statement::Select(q) = stmt {
1358 if let ColumnList::Named(exprs) = &q.columns {
1359 if let SelectExpr::Expr { expr, .. } = &exprs[0] {
1360 if let Expr::BinaryOp { left, op, .. } = expr {
1361 assert_eq!(*op, ArithOp::Mul);
1362 assert!(matches!(left.as_ref(), Expr::BinaryOp { op: ArithOp::Add, .. }));
1363 } else {
1364 panic!("Expected BinaryOp");
1365 }
1366 } else {
1367 panic!("Expected Expr variant");
1368 }
1369 } else {
1370 panic!("Expected Named columns");
1371 }
1372 } else {
1373 panic!("Expected Select");
1374 }
1375 }
1376
1377 #[test]
1378 fn test_select_unary_minus() {
1379 let stmt = parse_query("SELECT -count FROM test").unwrap();
1380 if let Statement::Select(q) = stmt {
1381 if let ColumnList::Named(exprs) = &q.columns {
1382 assert!(matches!(&exprs[0], SelectExpr::Expr {
1383 expr: Expr::UnaryMinus(_),
1384 ..
1385 }));
1386 } else {
1387 panic!("Expected Named columns");
1388 }
1389 } else {
1390 panic!("Expected Select");
1391 }
1392 }
1393
1394 #[test]
1395 fn test_select_negative_literal() {
1396 let stmt = parse_query("SELECT -42 FROM test").unwrap();
1397 if let Statement::Select(q) = stmt {
1398 if let ColumnList::Named(exprs) = &q.columns {
1399 assert!(matches!(&exprs[0], SelectExpr::Expr {
1401 expr: Expr::Literal(SqlValue::Int(-42)),
1402 ..
1403 }));
1404 } else {
1405 panic!("Expected Named columns");
1406 }
1407 } else {
1408 panic!("Expected Select");
1409 }
1410 }
1411
1412 #[test]
1413 fn test_where_arithmetic_expr() {
1414 let stmt = parse_query("SELECT * FROM test WHERE a + b > 10").unwrap();
1415 if let Statement::Select(q) = stmt {
1416 if let Some(WhereClause::Comparison(c)) = q.where_clause {
1417 assert_eq!(c.op, ">");
1418 assert!(matches!(&c.left_expr, Some(Expr::BinaryOp { op: ArithOp::Add, .. })));
1419 assert!(matches!(&c.right_expr, Some(Expr::Literal(SqlValue::Int(10)))));
1420 } else {
1421 panic!("Expected comparison");
1422 }
1423 } else {
1424 panic!("Expected Select");
1425 }
1426 }
1427
1428 #[test]
1429 fn test_where_both_sides_expr() {
1430 let stmt = parse_query("SELECT * FROM test WHERE a * 2 > b + 1").unwrap();
1431 if let Statement::Select(q) = stmt {
1432 if let Some(WhereClause::Comparison(c)) = q.where_clause {
1433 assert_eq!(c.op, ">");
1434 assert!(matches!(&c.left_expr, Some(Expr::BinaryOp { op: ArithOp::Mul, .. })));
1435 assert!(matches!(&c.right_expr, Some(Expr::BinaryOp { op: ArithOp::Add, .. })));
1436 } else {
1437 panic!("Expected comparison");
1438 }
1439 } else {
1440 panic!("Expected Select");
1441 }
1442 }
1443
1444 #[test]
1445 fn test_order_by_expr() {
1446 let stmt = parse_query("SELECT * FROM test ORDER BY a + b DESC").unwrap();
1447 if let Statement::Select(q) = stmt {
1448 let ob = q.order_by.unwrap();
1449 assert_eq!(ob.len(), 1);
1450 assert!(ob[0].descending);
1451 assert!(matches!(&ob[0].expr, Some(Expr::BinaryOp { op: ArithOp::Add, .. })));
1452 } else {
1453 panic!("Expected Select");
1454 }
1455 }
1456
1457 #[test]
1458 fn test_all_arithmetic_ops() {
1459 let stmt = parse_query("SELECT a + b, a - b, a * b, a / b, a % b FROM test").unwrap();
1460 if let Statement::Select(q) = stmt {
1461 if let ColumnList::Named(exprs) = &q.columns {
1462 assert_eq!(exprs.len(), 5);
1463 assert!(matches!(&exprs[0], SelectExpr::Expr { expr: Expr::BinaryOp { op: ArithOp::Add, .. }, .. }));
1464 assert!(matches!(&exprs[1], SelectExpr::Expr { expr: Expr::BinaryOp { op: ArithOp::Sub, .. }, .. }));
1465 assert!(matches!(&exprs[2], SelectExpr::Expr { expr: Expr::BinaryOp { op: ArithOp::Mul, .. }, .. }));
1466 assert!(matches!(&exprs[3], SelectExpr::Expr { expr: Expr::BinaryOp { op: ArithOp::Div, .. }, .. }));
1467 assert!(matches!(&exprs[4], SelectExpr::Expr { expr: Expr::BinaryOp { op: ArithOp::Mod, .. }, .. }));
1468 } else {
1469 panic!("Expected Named columns");
1470 }
1471 } else {
1472 panic!("Expected Select");
1473 }
1474 }
1475
1476 #[test]
1477 fn test_column_with_literal_arithmetic() {
1478 let stmt = parse_query("SELECT count * 2 + 1 FROM test").unwrap();
1479 if let Statement::Select(q) = stmt {
1480 if let ColumnList::Named(exprs) = &q.columns {
1481 if let SelectExpr::Expr { expr, .. } = &exprs[0] {
1483 if let Expr::BinaryOp { left, op, right } = expr {
1484 assert_eq!(*op, ArithOp::Add);
1485 assert!(matches!(right.as_ref(), Expr::Literal(SqlValue::Int(1))));
1486 assert!(matches!(left.as_ref(), Expr::BinaryOp { op: ArithOp::Mul, .. }));
1487 } else {
1488 panic!("Expected BinaryOp");
1489 }
1490 } else {
1491 panic!("Expected Expr");
1492 }
1493 } else {
1494 panic!("Expected Named columns");
1495 }
1496 } else {
1497 panic!("Expected Select");
1498 }
1499 }
1500
1501 #[test]
1502 fn test_mixed_columns_and_exprs() {
1503 let stmt = parse_query("SELECT title, a + b AS sum, count FROM test").unwrap();
1504 if let Statement::Select(q) = stmt {
1505 if let ColumnList::Named(exprs) = &q.columns {
1506 assert_eq!(exprs.len(), 3);
1507 assert_eq!(exprs[0], SelectExpr::Column("title".into()));
1508 assert!(matches!(&exprs[1], SelectExpr::Expr { alias: Some(a), .. } if a == "sum"));
1509 assert_eq!(exprs[2], SelectExpr::Column("count".into()));
1510 } else {
1511 panic!("Expected Named columns");
1512 }
1513 } else {
1514 panic!("Expected Select");
1515 }
1516 }
1517
1518 #[test]
1521 fn test_case_when_basic() {
1522 let stmt = parse_query(
1523 "SELECT CASE WHEN status = 'ACTIVE' THEN 1 ELSE 0 END FROM test"
1524 ).unwrap();
1525 if let Statement::Select(q) = stmt {
1526 if let ColumnList::Named(exprs) = &q.columns {
1527 assert_eq!(exprs.len(), 1);
1528 assert!(matches!(&exprs[0], SelectExpr::Expr {
1529 expr: Expr::Case { .. },
1530 ..
1531 }));
1532 } else {
1533 panic!("Expected Named columns");
1534 }
1535 } else {
1536 panic!("Expected Select");
1537 }
1538 }
1539
1540 #[test]
1541 fn test_case_when_multiple_branches() {
1542 let stmt = parse_query(
1543 "SELECT CASE WHEN x > 10 THEN 'high' WHEN x > 5 THEN 'mid' ELSE 'low' END FROM test"
1544 ).unwrap();
1545 if let Statement::Select(q) = stmt {
1546 if let ColumnList::Named(exprs) = &q.columns {
1547 if let SelectExpr::Expr { expr: Expr::Case { whens, else_expr }, .. } = &exprs[0] {
1548 assert_eq!(whens.len(), 2);
1549 assert!(else_expr.is_some());
1550 } else {
1551 panic!("Expected Case expression");
1552 }
1553 } else {
1554 panic!("Expected Named columns");
1555 }
1556 } else {
1557 panic!("Expected Select");
1558 }
1559 }
1560
1561 #[test]
1562 fn test_case_when_no_else() {
1563 let stmt = parse_query(
1564 "SELECT CASE WHEN x = 1 THEN 'one' END FROM test"
1565 ).unwrap();
1566 if let Statement::Select(q) = stmt {
1567 if let ColumnList::Named(exprs) = &q.columns {
1568 if let SelectExpr::Expr { expr: Expr::Case { whens, else_expr }, .. } = &exprs[0] {
1569 assert_eq!(whens.len(), 1);
1570 assert!(else_expr.is_none());
1571 } else {
1572 panic!("Expected Case expression");
1573 }
1574 } else {
1575 panic!("Expected Named columns");
1576 }
1577 } else {
1578 panic!("Expected Select");
1579 }
1580 }
1581
1582 #[test]
1583 fn test_case_when_in_aggregate() {
1584 let stmt = parse_query(
1585 "SELECT SUM(CASE WHEN side = 'BUY' THEN size ELSE -size END) AS net FROM orders GROUP BY token"
1586 ).unwrap();
1587 if let Statement::Select(q) = stmt {
1588 if let ColumnList::Named(exprs) = &q.columns {
1589 assert_eq!(exprs.len(), 1);
1590 assert!(matches!(&exprs[0], SelectExpr::Aggregate {
1591 func: AggFunc::Sum,
1592 arg_expr: Some(Expr::Case { .. }),
1593 alias: Some(a),
1594 ..
1595 } if a == "net"));
1596 } else {
1597 panic!("Expected Named columns");
1598 }
1599 } else {
1600 panic!("Expected Select");
1601 }
1602 }
1603
1604 #[test]
1605 fn test_case_when_with_alias() {
1606 let stmt = parse_query(
1607 "SELECT CASE WHEN x > 0 THEN 'pos' ELSE 'neg' END AS sign FROM test"
1608 ).unwrap();
1609 if let Statement::Select(q) = stmt {
1610 if let ColumnList::Named(exprs) = &q.columns {
1611 assert!(matches!(&exprs[0], SelectExpr::Expr {
1612 expr: Expr::Case { .. },
1613 alias: Some(a),
1614 } if a == "sign"));
1615 } else {
1616 panic!("Expected Named columns");
1617 }
1618 } else {
1619 panic!("Expected Select");
1620 }
1621 }
1622
1623 #[test]
1624 fn test_create_view() {
1625 let stmt = parse_query("CREATE VIEW live AS SELECT * FROM strategies WHERE status = 'LIVE'").unwrap();
1626 if let Statement::CreateView(cv) = stmt {
1627 assert_eq!(cv.view_name, "live");
1628 assert!(cv.columns.is_none());
1629 assert_eq!(cv.query.table, "strategies");
1630 assert!(cv.query.where_clause.is_some());
1631 } else {
1632 panic!("Expected CreateView, got {:?}", stmt);
1633 }
1634 }
1635
1636 #[test]
1637 fn test_create_view_with_columns() {
1638 let stmt = parse_query("CREATE VIEW v1 (a, b) AS SELECT title, status FROM t").unwrap();
1639 if let Statement::CreateView(cv) = stmt {
1640 assert_eq!(cv.view_name, "v1");
1641 assert_eq!(cv.columns, Some(vec!["a".into(), "b".into()]));
1642 } else {
1643 panic!("Expected CreateView");
1644 }
1645 }
1646
1647 #[test]
1648 fn test_drop_view() {
1649 let stmt = parse_query("DROP VIEW live").unwrap();
1650 if let Statement::DropView(dv) = stmt {
1651 assert_eq!(dv.view_name, "live");
1652 } else {
1653 panic!("Expected DropView, got {:?}", stmt);
1654 }
1655 }
1656
1657 #[test]
1658 fn test_create_view_case_insensitive() {
1659 let stmt = parse_query("create view My_View as select * from t").unwrap();
1660 if let Statement::CreateView(cv) = stmt {
1661 assert_eq!(cv.view_name, "My_View");
1662 } else {
1663 panic!("Expected CreateView");
1664 }
1665 }
1666}