1use regex::Regex;
4use std::sync::LazyLock;
5
6use crate::errors::MdqlError;
7
/// One term of an `ORDER BY` clause.
#[derive(Debug, Clone, PartialEq)]
pub struct OrderSpec {
    /// Column to sort by.
    pub column: String,
    /// `true` when the term was followed by `DESC`; `false` for `ASC` or when no
    /// direction was given.
    pub descending: bool,
}
15
/// A single predicate in a `WHERE` clause, e.g. `count > 5` or `title IS NULL`.
#[derive(Debug, Clone, PartialEq)]
pub struct Comparison {
    /// Column on the left-hand side of the predicate.
    pub column: String,
    /// Operator text. The parser in this file produces `=`, `!=`, `<`, `>`, `<=`, `>=`,
    /// `LIKE`, `NOT LIKE`, `IN`, `IS NULL` and `IS NOT NULL`.
    pub op: String,
    /// Right-hand operand: `None` for `IS NULL` / `IS NOT NULL`, a `SqlValue::List`
    /// for `IN`, and a single value otherwise.
    pub value: Option<SqlValue>,
}
22
/// Two `WHERE` sub-expressions joined by a boolean connective.
#[derive(Debug, Clone, PartialEq)]
pub struct BoolOp {
    /// Connective text; the parser in this file produces `"AND"` or `"OR"`.
    pub op: String,
    /// Left operand.
    pub left: Box<WhereClause>,
    /// Right operand.
    pub right: Box<WhereClause>,
}
29
/// Expression tree of a `WHERE` clause.
#[derive(Debug, Clone, PartialEq)]
pub enum WhereClause {
    /// Leaf node: a single column predicate.
    Comparison(Comparison),
    /// Inner node: two sub-expressions combined with `AND` / `OR`.
    BoolOp(BoolOp),
}
35
/// A literal value appearing in a query.
#[derive(Debug, Clone, PartialEq)]
pub enum SqlValue {
    /// Text of a single-quoted literal with the surrounding quotes removed. The tokenizer
    /// does not decode backslash escapes, so they appear here verbatim.
    String(String),
    /// Numeric literal written without a decimal point.
    Int(i64),
    /// Numeric literal written with a decimal point.
    Float(f64),
    /// The `NULL` keyword.
    Null,
    /// Parenthesised value list, produced for the right-hand side of `IN`.
    List(Vec<SqlValue>),
}
44
/// One `JOIN <table> [alias] ON <left> = <right>` clause of a `SELECT`.
#[derive(Debug, Clone, PartialEq)]
pub struct JoinClause {
    /// Name of the joined table.
    pub table: String,
    /// Optional alias written directly after the table name.
    pub alias: Option<String>,
    /// Identifier on the left of `=` in the `ON` condition.
    pub left_col: String,
    /// Identifier on the right of `=` in the `ON` condition.
    pub right_col: String,
}
52
/// Aggregate functions recognised in a `SELECT` list.
#[derive(Debug, Clone, PartialEq)]
pub enum AggFunc {
    /// `COUNT(...)`
    Count,
    /// `SUM(...)`
    Sum,
    /// `AVG(...)`
    Avg,
    /// `MIN(...)`
    Min,
    /// `MAX(...)`
    Max,
}
61
/// One entry of a `SELECT` column list.
#[derive(Debug, Clone, PartialEq)]
pub enum SelectExpr {
    /// A plain column reference.
    Column(String),
    /// An aggregate call: `func(arg) [AS alias]`, where `arg` is a column name or `"*"`.
    ///
    /// NOTE(review): the parser in this file also emits this variant, with
    /// `AggFunc::Count`, for a plain `column AS alias`; confirm consumers expect that.
    Aggregate { func: AggFunc, arg: String, alias: Option<String> },
}
67
68impl SelectExpr {
69 pub fn output_name(&self) -> String {
70 match self {
71 SelectExpr::Column(name) => name.clone(),
72 SelectExpr::Aggregate { func, arg, alias } => {
73 if let Some(a) = alias {
74 a.clone()
75 } else {
76 let func_name = match func {
77 AggFunc::Count => "COUNT",
78 AggFunc::Sum => "SUM",
79 AggFunc::Avg => "AVG",
80 AggFunc::Min => "MIN",
81 AggFunc::Max => "MAX",
82 };
83 format!("{}({})", func_name, arg)
84 }
85 }
86 }
87 }
88
89 pub fn is_aggregate(&self) -> bool {
90 matches!(self, SelectExpr::Aggregate { .. })
91 }
92}
93
/// Parsed form of a `SELECT` statement.
#[derive(Debug, Clone, PartialEq)]
pub struct SelectQuery {
    /// Selected columns: `*` or an explicit list.
    pub columns: ColumnList,
    /// Table named after `FROM`.
    pub table: String,
    /// Optional alias written directly after the table name.
    pub table_alias: Option<String>,
    /// `JOIN` clauses in source order; empty when there are none.
    pub joins: Vec<JoinClause>,
    /// `WHERE` expression, if present.
    pub where_clause: Option<WhereClause>,
    /// `GROUP BY` columns, if present.
    pub group_by: Option<Vec<String>>,
    /// `ORDER BY` terms, if present.
    pub order_by: Option<Vec<OrderSpec>>,
    /// `LIMIT` value, if present. The parser only requires it to parse as an `i64`.
    pub limit: Option<i64>,
}
105
/// Column list of a `SELECT` statement.
#[derive(Debug, Clone, PartialEq)]
pub enum ColumnList {
    /// `SELECT *`
    All,
    /// An explicit, comma-separated list of expressions in source order.
    Named(Vec<SelectExpr>),
}
111
/// Parsed form of `INSERT INTO <table> (<columns>) VALUES (<values>)`.
#[derive(Debug, Clone, PartialEq)]
pub struct InsertQuery {
    /// Target table.
    pub table: String,
    /// Column names in source order.
    pub columns: Vec<String>,
    /// Values in source order; the parser rejects a query whose value count differs
    /// from its column count.
    pub values: Vec<SqlValue>,
}
118
/// Parsed form of `UPDATE <table> SET <col> = <value>, ... [WHERE ...]`.
#[derive(Debug, Clone, PartialEq)]
pub struct UpdateQuery {
    /// Target table.
    pub table: String,
    /// `(column, value)` pairs in source order; always at least one.
    pub assignments: Vec<(String, SqlValue)>,
    /// `WHERE` expression, if present.
    pub where_clause: Option<WhereClause>,
}
125
/// Parsed form of `DELETE FROM <table> [WHERE ...]`.
#[derive(Debug, Clone, PartialEq)]
pub struct DeleteQuery {
    /// Target table.
    pub table: String,
    /// `WHERE` expression, if present.
    pub where_clause: Option<WhereClause>,
}
131
/// Parsed form of `ALTER TABLE <table> RENAME FIELD <old> TO <new>`.
#[derive(Debug, Clone, PartialEq)]
pub struct AlterRenameFieldQuery {
    /// Target table.
    pub table: String,
    /// Current field name (quoted string or identifier in the query).
    pub old_name: String,
    /// New field name (quoted string or identifier in the query).
    pub new_name: String,
}
138
/// Parsed form of `ALTER TABLE <table> DROP FIELD <name>`.
#[derive(Debug, Clone, PartialEq)]
pub struct AlterDropFieldQuery {
    /// Target table.
    pub table: String,
    /// Field to drop (quoted string or identifier in the query).
    pub field_name: String,
}
144
/// Parsed form of `ALTER TABLE <table> MERGE FIELDS <a>, <b>, ... INTO <target>`.
#[derive(Debug, Clone, PartialEq)]
pub struct AlterMergeFieldsQuery {
    /// Target table.
    pub table: String,
    /// Source field names in source order; always at least one.
    pub sources: Vec<String>,
    /// Name of the field named after `INTO`.
    pub into: String,
}
151
/// Any statement the parser in this file can produce.
#[derive(Debug, Clone, PartialEq)]
pub enum Statement {
    /// `SELECT ...`
    Select(SelectQuery),
    /// `INSERT INTO ...`
    Insert(InsertQuery),
    /// `UPDATE ...`
    Update(UpdateQuery),
    /// `DELETE FROM ...`
    Delete(DeleteQuery),
    /// `ALTER TABLE ... RENAME FIELD ...`
    AlterRename(AlterRenameFieldQuery),
    /// `ALTER TABLE ... DROP FIELD ...`
    AlterDrop(AlterDropFieldQuery),
    /// `ALTER TABLE ... MERGE FIELDS ...`
    AlterMerge(AlterMergeFieldsQuery),
}
162
/// Words the tokenizer classifies as keywords rather than identifiers.
///
/// Entries are upper case; the tokenizer upper-cases each bare word before looking it up
/// here, so keyword recognition is case-insensitive.
static KEYWORDS: &[&str] = &[
    // SELECT and its clauses.
    "SELECT", "FROM", "WHERE", "AND", "OR", "ORDER", "BY",
    "ASC", "DESC", "LIMIT", "LIKE", "IN", "IS", "NOT", "NULL",
    "JOIN", "ON", "AS", "GROUP",
    // Data-modifying statements.
    "INSERT", "INTO", "VALUES", "UPDATE", "SET", "DELETE",
    // ALTER TABLE field operations.
    "ALTER", "TABLE", "RENAME", "FIELD", "TO", "DROP", "MERGE", "FIELDS",
];
172
/// Upper-case names treated as aggregate functions when immediately followed by `(`.
/// They are not keywords, so the same words remain usable as ordinary identifiers.
static AGG_FUNCS: &[&str] = &["COUNT", "SUM", "AVG", "MIN", "MAX"];
174
/// Lexer pattern, compiled once on first use.
///
/// After optional leading whitespace it matches exactly one of these named groups:
/// - `backtick`: a non-empty backtick-quoted identifier;
/// - `string`: a single-quoted literal in which a backslash escapes the next character;
/// - `number`: an optionally negative integer or decimal;
/// - `op`: a comparison operator, comma, `*`, or parenthesis;
/// - `word`: a bare word starting with a letter or `_`, which may contain digits,
///   `_`, `.`, `/` and `-`.
///
/// The pattern is written in verbose (`(?x)`) mode, so whitespace inside it is ignored.
static TOKEN_RE: LazyLock<Regex> = LazyLock::new(|| {
    Regex::new(
        r#"(?x)
        \s*(?:
            (?P<backtick>`[^`]+`)
          | (?P<string>'(?:[^'\\]|\\.)*')
          | (?P<number>-?\d+(?:\.\d+)?)
          | (?P<op><=|>=|!=|[=<>,*()])
          | (?P<word>[A-Za-z_][A-Za-z0-9_./-]*)
        )"#,
    )
    // The pattern is a compile-time constant; a failure here is a programming error.
    .unwrap()
});
188
/// A lexical token produced by `tokenize`.
#[derive(Debug, Clone)]
struct Token {
    /// Token class: `"ident"`, `"string"`, `"number"`, `"op"` or `"keyword"`.
    token_type: String,
    /// Normalised text: quotes/backticks stripped for strings and quoted identifiers,
    /// upper-cased for keywords, otherwise identical to `raw`.
    value: String,
    /// The token exactly as matched in the query (without leading whitespace).
    raw: String,
}
195
196fn tokenize(sql: &str) -> Vec<Token> {
197 let mut tokens = Vec::new();
198 for caps in TOKEN_RE.captures_iter(sql) {
199 if let Some(m) = caps.name("backtick") {
200 let raw = m.as_str();
201 tokens.push(Token {
202 token_type: "ident".into(),
203 value: raw[1..raw.len() - 1].into(),
204 raw: raw.into(),
205 });
206 } else if let Some(m) = caps.name("string") {
207 let raw = m.as_str();
208 tokens.push(Token {
209 token_type: "string".into(),
210 value: raw[1..raw.len() - 1].into(),
211 raw: raw.into(),
212 });
213 } else if let Some(m) = caps.name("number") {
214 let raw = m.as_str();
215 tokens.push(Token {
216 token_type: "number".into(),
217 value: raw.into(),
218 raw: raw.into(),
219 });
220 } else if let Some(m) = caps.name("op") {
221 let raw = m.as_str();
222 tokens.push(Token {
223 token_type: "op".into(),
224 value: raw.into(),
225 raw: raw.into(),
226 });
227 } else if let Some(m) = caps.name("word") {
228 let raw = m.as_str();
229 if KEYWORDS.contains(&raw.to_uppercase().as_str()) {
230 tokens.push(Token {
231 token_type: "keyword".into(),
232 value: raw.to_uppercase(),
233 raw: raw.into(),
234 });
235 } else {
236 tokens.push(Token {
237 token_type: "ident".into(),
238 value: raw.into(),
239 raw: raw.into(),
240 });
241 }
242 }
243 }
244 tokens
245}
246
/// Recursive-descent parser over a token list.
struct Parser {
    /// Tokens of the query being parsed.
    tokens: Vec<Token>,
    /// Index into `tokens` of the next token to consume.
    pos: usize,
}
253
254impl Parser {
255 fn new(tokens: Vec<Token>) -> Self {
256 Parser { tokens, pos: 0 }
257 }
258
259 fn peek(&self) -> Option<&Token> {
260 self.tokens.get(self.pos)
261 }
262
263 fn advance(&mut self) -> Token {
264 let t = self.tokens[self.pos].clone();
265 self.pos += 1;
266 t
267 }
268
269 fn expect(&mut self, type_: &str, value: Option<&str>) -> Result<Token, MdqlError> {
270 let t = self.peek().ok_or_else(|| {
271 MdqlError::QueryParse(format!(
272 "Unexpected end of query, expected {}",
273 value.unwrap_or(type_)
274 ))
275 })?;
276 let matches_type = t.token_type == type_;
277 let matches_value = value.map_or(true, |v| t.value == v);
278 if !matches_type || !matches_value {
279 return Err(MdqlError::QueryParse(format!(
280 "Expected {}, got '{}' at position {}",
281 value.unwrap_or(type_),
282 t.raw,
283 self.pos
284 )));
285 }
286 Ok(self.advance())
287 }
288
289 fn match_keyword(&mut self, kw: &str) -> bool {
290 if let Some(t) = self.peek() {
291 if t.token_type == "keyword" && t.value == kw {
292 self.advance();
293 return true;
294 }
295 }
296 false
297 }
298
299 fn parse_statement(&mut self) -> Result<Statement, MdqlError> {
300 let t = self.peek().ok_or_else(|| MdqlError::QueryParse("Empty query".into()))?;
301 match (t.token_type.as_str(), t.value.as_str()) {
302 ("keyword", "SELECT") => Ok(Statement::Select(self.parse_select()?)),
303 ("keyword", "INSERT") => Ok(Statement::Insert(self.parse_insert()?)),
304 ("keyword", "UPDATE") => Ok(Statement::Update(self.parse_update()?)),
305 ("keyword", "DELETE") => Ok(Statement::Delete(self.parse_delete()?)),
306 ("keyword", "ALTER") => self.parse_alter(),
307 _ => Err(MdqlError::QueryParse(format!(
308 "Expected SELECT, INSERT, UPDATE, DELETE, or ALTER, got '{}'",
309 t.raw
310 ))),
311 }
312 }
313
314 fn parse_select(&mut self) -> Result<SelectQuery, MdqlError> {
315 self.expect("keyword", Some("SELECT"))?;
316 let columns = self.parse_columns()?;
317 self.expect("keyword", Some("FROM"))?;
318 let table = self.parse_ident()?;
319
320 let mut table_alias = None;
322 if let Some(t) = self.peek() {
323 if t.token_type == "ident" && !self.is_clause_keyword(t) {
324 table_alias = Some(self.advance().value);
325 }
326 }
327
328 let mut joins = Vec::new();
330 while self.match_keyword("JOIN") {
331 let join_table = self.parse_ident()?;
332 let mut join_alias = None;
333 if let Some(t) = self.peek() {
334 if t.token_type == "ident" && !self.is_clause_keyword(t) {
335 join_alias = Some(self.advance().value);
336 }
337 }
338 self.expect("keyword", Some("ON"))?;
339 let left_col = self.parse_ident()?;
340 self.expect("op", Some("="))?;
341 let right_col = self.parse_ident()?;
342 joins.push(JoinClause {
343 table: join_table,
344 alias: join_alias,
345 left_col,
346 right_col,
347 });
348 }
349
350 let mut where_clause = None;
351 if self.match_keyword("WHERE") {
352 where_clause = Some(self.parse_or_expr()?);
353 }
354
355 let mut group_by = None;
356 if self.match_keyword("GROUP") {
357 self.expect("keyword", Some("BY"))?;
358 let mut cols = vec![self.parse_ident()?];
359 while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
360 self.advance();
361 cols.push(self.parse_ident()?);
362 }
363 group_by = Some(cols);
364 }
365
366 let mut order_by = None;
367 if self.match_keyword("ORDER") {
368 self.expect("keyword", Some("BY"))?;
369 order_by = Some(self.parse_order_by()?);
370 }
371
372 let mut limit = None;
373 if self.match_keyword("LIMIT") {
374 let t = self.expect("number", None)?;
375 limit = Some(t.value.parse::<i64>().map_err(|_| {
376 MdqlError::QueryParse(format!("Invalid LIMIT value: {}", t.value))
377 })?);
378 }
379
380 self.expect_end()?;
381
382 Ok(SelectQuery {
383 columns,
384 table,
385 table_alias,
386 joins,
387 where_clause,
388 group_by,
389 order_by,
390 limit,
391 })
392 }
393
394 fn parse_insert(&mut self) -> Result<InsertQuery, MdqlError> {
395 self.expect("keyword", Some("INSERT"))?;
396 self.expect("keyword", Some("INTO"))?;
397 let table = self.parse_ident()?;
398
399 self.expect("op", Some("("))?;
400 let mut columns = vec![self.parse_ident()?];
401 while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
402 self.advance();
403 columns.push(self.parse_ident()?);
404 }
405 self.expect("op", Some(")"))?;
406
407 self.expect("keyword", Some("VALUES"))?;
408
409 self.expect("op", Some("("))?;
410 let mut values = vec![self.parse_value()?];
411 while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
412 self.advance();
413 values.push(self.parse_value()?);
414 }
415 self.expect("op", Some(")"))?;
416
417 if columns.len() != values.len() {
418 return Err(MdqlError::QueryParse(format!(
419 "Column count ({}) does not match value count ({})",
420 columns.len(),
421 values.len()
422 )));
423 }
424
425 self.expect_end()?;
426 Ok(InsertQuery {
427 table,
428 columns,
429 values,
430 })
431 }
432
433 fn parse_update(&mut self) -> Result<UpdateQuery, MdqlError> {
434 self.expect("keyword", Some("UPDATE"))?;
435 let table = self.parse_ident()?;
436 self.expect("keyword", Some("SET"))?;
437
438 let mut assignments = Vec::new();
439 let col = self.parse_ident()?;
440 self.expect("op", Some("="))?;
441 let val = self.parse_value()?;
442 assignments.push((col, val));
443
444 while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
445 self.advance();
446 let col = self.parse_ident()?;
447 self.expect("op", Some("="))?;
448 let val = self.parse_value()?;
449 assignments.push((col, val));
450 }
451
452 let mut where_clause = None;
453 if self.match_keyword("WHERE") {
454 where_clause = Some(self.parse_or_expr()?);
455 }
456
457 self.expect_end()?;
458 Ok(UpdateQuery {
459 table,
460 assignments,
461 where_clause,
462 })
463 }
464
465 fn parse_delete(&mut self) -> Result<DeleteQuery, MdqlError> {
466 self.expect("keyword", Some("DELETE"))?;
467 self.expect("keyword", Some("FROM"))?;
468 let table = self.parse_ident()?;
469
470 let mut where_clause = None;
471 if self.match_keyword("WHERE") {
472 where_clause = Some(self.parse_or_expr()?);
473 }
474
475 self.expect_end()?;
476 Ok(DeleteQuery {
477 table,
478 where_clause,
479 })
480 }
481
482 fn parse_alter(&mut self) -> Result<Statement, MdqlError> {
483 self.expect("keyword", Some("ALTER"))?;
484 self.expect("keyword", Some("TABLE"))?;
485 let table = self.parse_ident()?;
486
487 let t = self.peek().ok_or_else(|| {
488 MdqlError::QueryParse("Expected RENAME, DROP, or MERGE after table name".into())
489 })?;
490
491 match (t.token_type.as_str(), t.value.as_str()) {
492 ("keyword", "RENAME") => {
493 self.advance();
494 self.expect("keyword", Some("FIELD"))?;
495 let old_name = self.parse_string_or_ident()?;
496 self.expect("keyword", Some("TO"))?;
497 let new_name = self.parse_string_or_ident()?;
498 self.expect_end()?;
499 Ok(Statement::AlterRename(AlterRenameFieldQuery {
500 table,
501 old_name,
502 new_name,
503 }))
504 }
505 ("keyword", "DROP") => {
506 self.advance();
507 self.expect("keyword", Some("FIELD"))?;
508 let field_name = self.parse_string_or_ident()?;
509 self.expect_end()?;
510 Ok(Statement::AlterDrop(AlterDropFieldQuery {
511 table,
512 field_name,
513 }))
514 }
515 ("keyword", "MERGE") => {
516 self.advance();
517 self.expect("keyword", Some("FIELDS"))?;
518 let mut sources = vec![self.parse_string_or_ident()?];
519 while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
520 self.advance();
521 sources.push(self.parse_string_or_ident()?);
522 }
523 self.expect("keyword", Some("INTO"))?;
524 let target = self.parse_string_or_ident()?;
525 self.expect_end()?;
526 Ok(Statement::AlterMerge(AlterMergeFieldsQuery {
527 table,
528 sources,
529 into: target,
530 }))
531 }
532 _ => Err(MdqlError::QueryParse(format!(
533 "Expected RENAME, DROP, or MERGE, got '{}'",
534 t.raw
535 ))),
536 }
537 }
538
539 fn parse_string_or_ident(&mut self) -> Result<String, MdqlError> {
540 let t = self.peek().ok_or_else(|| {
541 MdqlError::QueryParse("Expected field name, got end of query".into())
542 })?;
543 match t.token_type.as_str() {
544 "string" => {
545 let v = self.advance().value;
546 Ok(v)
547 }
548 "ident" | "keyword" => {
549 let v = self.advance().value;
550 Ok(v)
551 }
552 _ => Err(MdqlError::QueryParse(format!(
553 "Expected field name, got '{}'",
554 t.raw
555 ))),
556 }
557 }
558
559 fn parse_columns(&mut self) -> Result<ColumnList, MdqlError> {
560 if let Some(t) = self.peek() {
561 if t.token_type == "op" && t.value == "*" {
562 self.advance();
563 return Ok(ColumnList::All);
564 }
565 }
566
567 let mut exprs = vec![self.parse_select_expr()?];
568 while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
569 self.advance();
570 exprs.push(self.parse_select_expr()?);
571 }
572 Ok(ColumnList::Named(exprs))
573 }
574
575 fn peek_is_agg_func(&self) -> bool {
576 let t = match self.peek() {
577 Some(t) => t,
578 None => return false,
579 };
580 let name_upper = t.value.to_uppercase();
581 if !AGG_FUNCS.contains(&name_upper.as_str()) {
582 return false;
583 }
584 self.tokens
586 .get(self.pos + 1)
587 .map_or(false, |next| next.token_type == "op" && next.value == "(")
588 }
589
590 fn parse_select_expr(&mut self) -> Result<SelectExpr, MdqlError> {
591 let _t = self.peek().ok_or_else(|| {
592 MdqlError::QueryParse("Expected column or aggregate, got end of query".into())
593 })?;
594
595 if self.peek_is_agg_func() {
596 let func_name = self.advance().value.to_uppercase();
597 let func = match func_name.as_str() {
598 "COUNT" => AggFunc::Count,
599 "SUM" => AggFunc::Sum,
600 "AVG" => AggFunc::Avg,
601 "MIN" => AggFunc::Min,
602 "MAX" => AggFunc::Max,
603 _ => unreachable!(),
604 };
605 self.expect("op", Some("("))?;
606 let arg = if self.peek().map_or(false, |t| t.token_type == "op" && t.value == "*") {
607 self.advance();
608 "*".to_string()
609 } else {
610 self.parse_ident()?
611 };
612 self.expect("op", Some(")"))?;
613
614 let alias = if self.match_keyword("AS") {
615 Some(self.parse_ident()?)
616 } else {
617 None
618 };
619
620 Ok(SelectExpr::Aggregate { func, arg, alias })
621 } else {
622 let name = self.parse_ident()?;
623 if self.match_keyword("AS") {
625 let alias = self.parse_ident()?;
626 Ok(SelectExpr::Aggregate {
627 func: AggFunc::Count, arg: name.clone(),
629 alias: Some(alias),
630 })
631 } else {
632 Ok(SelectExpr::Column(name))
633 }
634 }
635 }
636
637 fn parse_ident(&mut self) -> Result<String, MdqlError> {
638 let t = self.peek().ok_or_else(|| {
639 MdqlError::QueryParse("Expected identifier, got end of query".into())
640 })?;
641 match t.token_type.as_str() {
642 "ident" | "keyword" => {
643 let v = self.advance().value;
644 Ok(v)
645 }
646 _ => Err(MdqlError::QueryParse(format!(
647 "Expected identifier, got '{}'",
648 t.raw
649 ))),
650 }
651 }
652
653 fn parse_or_expr(&mut self) -> Result<WhereClause, MdqlError> {
654 let mut left = self.parse_and_expr()?;
655 while self.match_keyword("OR") {
656 let right = self.parse_and_expr()?;
657 left = WhereClause::BoolOp(BoolOp {
658 op: "OR".into(),
659 left: Box::new(left),
660 right: Box::new(right),
661 });
662 }
663 Ok(left)
664 }
665
666 fn parse_and_expr(&mut self) -> Result<WhereClause, MdqlError> {
667 let mut left = self.parse_comparison()?;
668 while self.match_keyword("AND") {
669 let right = self.parse_comparison()?;
670 left = WhereClause::BoolOp(BoolOp {
671 op: "AND".into(),
672 left: Box::new(left),
673 right: Box::new(right),
674 });
675 }
676 Ok(left)
677 }
678
679 fn parse_comparison(&mut self) -> Result<WhereClause, MdqlError> {
680 if self.peek().map_or(false, |t| t.token_type == "op" && t.value == "(") {
682 self.advance();
683 let expr = self.parse_or_expr()?;
684 self.expect("op", Some(")"))?;
685 return Ok(expr);
686 }
687
688 let col = self.parse_ident()?;
689
690 if self.match_keyword("IS") {
692 if self.match_keyword("NOT") {
693 self.expect("keyword", Some("NULL"))?;
694 return Ok(WhereClause::Comparison(Comparison {
695 column: col,
696 op: "IS NOT NULL".into(),
697 value: None,
698 }));
699 }
700 self.expect("keyword", Some("NULL"))?;
701 return Ok(WhereClause::Comparison(Comparison {
702 column: col,
703 op: "IS NULL".into(),
704 value: None,
705 }));
706 }
707
708 if self.match_keyword("IN") {
710 self.expect("op", Some("("))?;
711 let mut values = vec![self.parse_value()?];
712 while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
713 self.advance();
714 values.push(self.parse_value()?);
715 }
716 self.expect("op", Some(")"))?;
717 return Ok(WhereClause::Comparison(Comparison {
718 column: col,
719 op: "IN".into(),
720 value: Some(SqlValue::List(values)),
721 }));
722 }
723
724 if self.match_keyword("LIKE") {
726 let val = self.parse_value()?;
727 return Ok(WhereClause::Comparison(Comparison {
728 column: col,
729 op: "LIKE".into(),
730 value: Some(val),
731 }));
732 }
733
734 if self.match_keyword("NOT") {
736 if self.match_keyword("LIKE") {
737 let val = self.parse_value()?;
738 return Ok(WhereClause::Comparison(Comparison {
739 column: col,
740 op: "NOT LIKE".into(),
741 value: Some(val),
742 }));
743 }
744 return Err(MdqlError::QueryParse("Expected LIKE after NOT".into()));
745 }
746
747 if let Some(t) = self.peek() {
749 if t.token_type == "op" && ["=", "!=", "<", ">", "<=", ">="].contains(&t.value.as_str())
750 {
751 let op = self.advance().value;
752 let val = self.parse_value()?;
753 return Ok(WhereClause::Comparison(Comparison {
754 column: col,
755 op,
756 value: Some(val),
757 }));
758 }
759 }
760
761 let got = self.peek().map_or("end".to_string(), |t| t.raw.clone());
762 Err(MdqlError::QueryParse(format!(
763 "Expected operator after '{}', got '{}'",
764 col, got
765 )))
766 }
767
768 fn parse_value(&mut self) -> Result<SqlValue, MdqlError> {
769 let t = self.peek().ok_or_else(|| {
770 MdqlError::QueryParse("Expected value, got end of query".into())
771 })?;
772 match t.token_type.as_str() {
773 "string" => {
774 let v = self.advance().value;
775 Ok(SqlValue::String(v))
776 }
777 "number" => {
778 let v = self.advance().value;
779 if v.contains('.') {
780 Ok(SqlValue::Float(v.parse().map_err(|_| {
781 MdqlError::QueryParse(format!("Invalid float: {}", v))
782 })?))
783 } else {
784 Ok(SqlValue::Int(v.parse().map_err(|_| {
785 MdqlError::QueryParse(format!("Invalid int: {}", v))
786 })?))
787 }
788 }
789 "keyword" if t.value == "NULL" => {
790 self.advance();
791 Ok(SqlValue::Null)
792 }
793 _ => Err(MdqlError::QueryParse(format!(
794 "Expected value, got '{}'",
795 t.raw
796 ))),
797 }
798 }
799
800 fn parse_order_by(&mut self) -> Result<Vec<OrderSpec>, MdqlError> {
801 let mut specs = vec![self.parse_order_spec()?];
802 while self.peek().map_or(false, |t| t.token_type == "op" && t.value == ",") {
803 self.advance();
804 specs.push(self.parse_order_spec()?);
805 }
806 Ok(specs)
807 }
808
809 fn parse_order_spec(&mut self) -> Result<OrderSpec, MdqlError> {
810 let col = self.parse_ident()?;
811 let descending = if self.match_keyword("DESC") {
812 true
813 } else {
814 self.match_keyword("ASC");
815 false
816 };
817 Ok(OrderSpec {
818 column: col,
819 descending,
820 })
821 }
822
823 fn is_clause_keyword(&self, t: &Token) -> bool {
824 t.token_type == "keyword"
825 && ["WHERE", "ORDER", "LIMIT", "JOIN", "ON", "GROUP"].contains(&t.value.as_str())
826 }
827
828 fn expect_end(&self) -> Result<(), MdqlError> {
829 if let Some(t) = self.peek() {
830 return Err(MdqlError::QueryParse(format!(
831 "Unexpected token '{}' at position {}",
832 t.raw, self.pos
833 )));
834 }
835 Ok(())
836 }
837}
838
839pub fn parse_query(sql: &str) -> crate::errors::Result<Statement> {
840 let tokens = tokenize(sql);
841 if tokens.is_empty() {
842 return Err(MdqlError::QueryParse("Empty query".into()));
843 }
844 let mut parser = Parser::new(tokens);
845 parser.parse_statement()
846}
847
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `sql` and returns the `SELECT` payload, panicking on any other outcome.
    fn select(sql: &str) -> SelectQuery {
        match parse_query(sql).unwrap() {
            Statement::Select(q) => q,
            _ => panic!("Expected Select"),
        }
    }

    /// Explicit column lists and the table name are captured.
    #[test]
    fn test_simple_select() {
        let q = select("SELECT title, status FROM strategies");
        let expected = ColumnList::Named(vec![
            SelectExpr::Column("title".into()),
            SelectExpr::Column("status".into()),
        ]);
        assert_eq!(q.columns, expected);
        assert_eq!(q.table, "strategies");
    }

    /// `*` maps to `ColumnList::All`.
    #[test]
    fn test_select_star() {
        let q = select("SELECT * FROM test");
        assert_eq!(q.columns, ColumnList::All);
    }

    /// A `WHERE` clause is recorded.
    #[test]
    fn test_where_clause() {
        let q = select("SELECT title FROM test WHERE count > 5");
        assert!(q.where_clause.is_some());
    }

    /// Multiple `ORDER BY` terms keep their individual directions.
    #[test]
    fn test_order_by() {
        let q = select("SELECT title FROM test ORDER BY composite DESC, title ASC");
        let ob = q.order_by.unwrap();
        assert_eq!(ob.len(), 2);
        assert!(ob[0].descending);
        assert!(!ob[1].descending);
    }

    /// `LIMIT` is parsed as an integer.
    #[test]
    fn test_limit() {
        let q = select("SELECT * FROM test LIMIT 10");
        assert_eq!(q.limit, Some(10));
    }

    /// `INSERT` captures table, columns and typed values.
    #[test]
    fn test_insert() {
        let sql = "INSERT INTO test (title, count) VALUES ('Hello', 42)";
        let Statement::Insert(q) = parse_query(sql).unwrap() else {
            panic!("Expected Insert");
        };
        assert_eq!(q.table, "test");
        assert_eq!(q.columns, vec!["title", "count"]);
        assert_eq!(q.values[0], SqlValue::String("Hello".into()));
        assert_eq!(q.values[1], SqlValue::Int(42));
    }

    /// `UPDATE` captures table, assignments and the `WHERE` clause.
    #[test]
    fn test_update() {
        let sql = "UPDATE test SET status = 'KILLED' WHERE path = 'a.md'";
        let Statement::Update(q) = parse_query(sql).unwrap() else {
            panic!("Expected Update");
        };
        assert_eq!(q.table, "test");
        assert_eq!(q.assignments.len(), 1);
        assert!(q.where_clause.is_some());
    }

    /// `DELETE` captures table and the `WHERE` clause.
    #[test]
    fn test_delete() {
        let sql = "DELETE FROM test WHERE status = 'draft'";
        let Statement::Delete(q) = parse_query(sql).unwrap() else {
            panic!("Expected Delete");
        };
        assert_eq!(q.table, "test");
        assert!(q.where_clause.is_some());
    }

    /// `RENAME FIELD` accepts quoted names.
    #[test]
    fn test_alter_rename() {
        let sql = "ALTER TABLE test RENAME FIELD 'Summary' TO 'Overview'";
        let Statement::AlterRename(q) = parse_query(sql).unwrap() else {
            panic!("Expected AlterRename");
        };
        assert_eq!(q.old_name, "Summary");
        assert_eq!(q.new_name, "Overview");
    }

    /// `DROP FIELD` accepts a quoted name.
    #[test]
    fn test_alter_drop() {
        let sql = "ALTER TABLE test DROP FIELD 'Details'";
        let Statement::AlterDrop(q) = parse_query(sql).unwrap() else {
            panic!("Expected AlterDrop");
        };
        assert_eq!(q.field_name, "Details");
    }

    /// `MERGE FIELDS` collects all sources and the target.
    #[test]
    fn test_alter_merge() {
        let sql =
            "ALTER TABLE test MERGE FIELDS 'Entry Rules', 'Exit Rules' INTO 'Trading Rules'";
        let Statement::AlterMerge(q) = parse_query(sql).unwrap() else {
            panic!("Expected AlterMerge");
        };
        assert_eq!(q.sources, vec!["Entry Rules", "Exit Rules"]);
        assert_eq!(q.into, "Trading Rules");
    }

    /// Backtick-quoted identifiers may contain spaces.
    #[test]
    fn test_backtick_ident() {
        let q = select("SELECT `Structural Mechanism` FROM test");
        let expected = ColumnList::Named(vec![SelectExpr::Column("Structural Mechanism".into())]);
        assert_eq!(q.columns, expected);
    }

    /// `LIKE` keeps its pattern verbatim.
    #[test]
    fn test_like_operator() {
        let q = select("SELECT title FROM test WHERE categories LIKE '%defi%'");
        let Some(WhereClause::Comparison(c)) = q.where_clause else {
            panic!("Expected LIKE comparison");
        };
        assert_eq!(c.op, "LIKE");
        assert_eq!(c.value, Some(SqlValue::String("%defi%".into())));
    }

    /// `IN (...)` is parsed as a single comparison.
    #[test]
    fn test_in_operator() {
        let q = select("SELECT * FROM test WHERE status IN ('ACTIVE', 'LIVE')");
        let Some(WhereClause::Comparison(c)) = q.where_clause else {
            panic!("Expected IN comparison");
        };
        assert_eq!(c.op, "IN");
    }

    /// `IS NULL` is parsed as a single comparison.
    #[test]
    fn test_is_null() {
        let q = select("SELECT * FROM test WHERE title IS NULL");
        let Some(WhereClause::Comparison(c)) = q.where_clause else {
            panic!("Expected IS NULL comparison");
        };
        assert_eq!(c.op, "IS NULL");
    }

    /// Mixed `AND` / `OR` expressions parse.
    #[test]
    fn test_and_or() {
        let q = select(
            "SELECT * FROM test WHERE status = 'ACTIVE' AND count > 5 OR title LIKE '%test%'",
        );
        assert!(q.where_clause.is_some());
    }

    /// A single `JOIN` with aliases on both tables.
    #[test]
    fn test_join() {
        let q = select(
            "SELECT s.title, b.sharpe FROM strategies s JOIN backtests b ON b.strategy = s.path",
        );
        assert_eq!(q.table, "strategies");
        assert_eq!(q.table_alias, Some("s".into()));
        assert_eq!(q.joins.len(), 1);
        let join = &q.joins[0];
        assert_eq!(join.table, "backtests");
        assert_eq!(join.alias, Some("b".into()));
    }

    /// Several `JOIN`s are kept in source order with their `ON` columns.
    #[test]
    fn test_multi_join() {
        let q = select(
            "SELECT s.title, b.sharpe, c.verdict FROM strategies s JOIN backtests b ON b.strategy = s.path JOIN critiques c ON c.strategy = s.path",
        );
        assert_eq!(q.table, "strategies");
        assert_eq!(q.table_alias, Some("s".into()));
        assert_eq!(q.joins.len(), 2);

        let first = &q.joins[0];
        assert_eq!(first.table, "backtests");
        assert_eq!(first.alias, Some("b".into()));
        assert_eq!(first.left_col, "b.strategy");
        assert_eq!(first.right_col, "s.path");

        let second = &q.joins[1];
        assert_eq!(second.table, "critiques");
        assert_eq!(second.alias, Some("c".into()));
        assert_eq!(second.left_col, "c.strategy");
        assert_eq!(second.right_col, "s.path");
    }

    /// An empty string is rejected.
    #[test]
    fn test_empty_query() {
        assert!(parse_query("").is_err());
    }

    /// `COUNT(*) AS alias` together with `GROUP BY`.
    #[test]
    fn test_count_star() {
        let q = select("SELECT status, COUNT(*) AS cnt FROM strategies GROUP BY status");
        let ColumnList::Named(exprs) = &q.columns else {
            panic!("Expected Named columns");
        };
        assert_eq!(exprs.len(), 2);
        assert_eq!(exprs[0], SelectExpr::Column("status".into()));
        assert!(matches!(&exprs[1], SelectExpr::Aggregate {
            func: AggFunc::Count,
            arg,
            alias: Some(a),
        } if arg == "*" && a == "cnt"));
        assert_eq!(q.group_by, Some(vec!["status".into()]));
    }

    /// `count` without parentheses is an ordinary column name.
    #[test]
    fn test_count_column_as_ident() {
        let sql = "INSERT INTO test (title, count) VALUES ('Hello', 42)";
        let Statement::Insert(q) = parse_query(sql).unwrap() else {
            panic!("Expected Insert");
        };
        assert_eq!(q.columns, vec!["title", "count"]);
    }

    /// Several aggregates in one list, without `GROUP BY`.
    #[test]
    fn test_multiple_aggregates() {
        let q = select("SELECT MIN(composite), MAX(composite), AVG(composite) FROM strategies");
        let ColumnList::Named(exprs) = &q.columns else {
            panic!("Expected Named columns");
        };
        assert_eq!(exprs.len(), 3);
        assert!(matches!(&exprs[0], SelectExpr::Aggregate { func: AggFunc::Min, .. }));
        assert!(matches!(&exprs[1], SelectExpr::Aggregate { func: AggFunc::Max, .. }));
        assert!(matches!(&exprs[2], SelectExpr::Aggregate { func: AggFunc::Avg, .. }));
        assert_eq!(q.group_by, None);
    }
}