1use chumsky::prelude::*;
6use rust_decimal::Decimal;
7use std::str::FromStr;
8
9use crate::ast::{
10 BalancesQuery, BinaryOperator, Expr, FromClause, FunctionCall, JournalQuery, Literal,
11 OrderSpec, PrintQuery, Query, SelectQuery, SortDirection, Target, UnaryOperator,
12 WindowFunction, WindowSpec,
13};
14use crate::error::{ParseError, ParseErrorKind};
15use rustledger_core::NaiveDate;
16
/// Input type consumed by every parser in this module: a borrowed string slice.
type ParserInput<'a> = &'a str;
/// Parser "extra" configuration: failures are collected as `Rich` errors over `char` tokens.
type ParserExtra<'a> = extra::Err<Rich<'a, char>>;
19
20pub fn parse(source: &str) -> Result<Query, ParseError> {
26 let (result, errs) = query_parser()
27 .then_ignore(ws())
28 .then_ignore(end())
29 .parse(source)
30 .into_output_errors();
31
32 if let Some(query) = result {
33 Ok(query)
34 } else {
35 let err = errs.first().map(|e| {
36 let kind = if e.found().is_none() {
37 ParseErrorKind::UnexpectedEof
38 } else {
39 ParseErrorKind::SyntaxError(e.to_string())
40 };
41 ParseError::new(kind, e.span().start)
42 });
43 Err(err.unwrap_or_else(|| ParseError::new(ParseErrorKind::UnexpectedEof, 0)))
44 }
45}
46
47fn ws<'a>() -> impl Parser<'a, ParserInput<'a>, (), ParserExtra<'a>> + Clone {
49 one_of(" \t\r\n").repeated().ignored()
50}
51
52fn ws1<'a>() -> impl Parser<'a, ParserInput<'a>, (), ParserExtra<'a>> + Clone {
54 one_of(" \t\r\n").repeated().at_least(1).ignored()
55}
56
57fn kw<'a>(keyword: &'static str) -> impl Parser<'a, ParserInput<'a>, (), ParserExtra<'a>> + Clone {
59 text::keyword(keyword).ignored()
60}
61
62fn digits<'a>() -> impl Parser<'a, ParserInput<'a>, &'a str, ParserExtra<'a>> + Clone {
64 one_of("0123456789").repeated().at_least(1).to_slice()
65}
66
67fn query_parser<'a>() -> impl Parser<'a, ParserInput<'a>, Query, ParserExtra<'a>> {
69 ws().ignore_then(choice((
70 select_query().map(|sq| Query::Select(Box::new(sq))),
71 journal_query().map(Query::Journal),
72 balances_query().map(Query::Balances),
73 print_query().map(Query::Print),
74 )))
75 .then_ignore(ws())
76 .then_ignore(just(';').or_not())
77}
78
79fn select_query<'a>() -> impl Parser<'a, ParserInput<'a>, SelectQuery, ParserExtra<'a>> {
81 recursive(|select_parser| {
82 let subquery_from = ws1()
84 .ignore_then(kw("FROM"))
85 .ignore_then(ws1())
86 .ignore_then(just('('))
87 .ignore_then(ws())
88 .ignore_then(select_parser)
89 .then_ignore(ws())
90 .then_ignore(just(')'))
91 .map(|sq| Some(FromClause::from_subquery(sq)));
92
93 let regular_from = from_clause().map(Some);
95
96 kw("SELECT")
97 .ignore_then(ws1())
98 .ignore_then(
99 kw("DISTINCT")
100 .then_ignore(ws1())
101 .or_not()
102 .map(|d| d.is_some()),
103 )
104 .then(targets())
105 .then(
106 subquery_from
107 .or(regular_from)
108 .or_not()
109 .map(std::option::Option::flatten),
110 )
111 .then(where_clause().or_not())
112 .then(group_by_clause().or_not())
113 .then(having_clause().or_not())
114 .then(pivot_by_clause().or_not())
115 .then(order_by_clause().or_not())
116 .then(limit_clause().or_not())
117 .map(
118 |(
119 (
120 (
121 (((((distinct, targets), from), where_clause), group_by), having),
122 pivot_by,
123 ),
124 order_by,
125 ),
126 limit,
127 )| {
128 SelectQuery {
129 distinct,
130 targets,
131 from,
132 where_clause,
133 group_by,
134 having,
135 pivot_by,
136 order_by,
137 limit,
138 }
139 },
140 )
141 })
142}
143
144fn from_clause<'a>() -> impl Parser<'a, ParserInput<'a>, FromClause, ParserExtra<'a>> + Clone {
146 ws1()
147 .ignore_then(kw("FROM"))
148 .ignore_then(ws1())
149 .ignore_then(from_modifiers())
150}
151
152fn targets<'a>() -> impl Parser<'a, ParserInput<'a>, Vec<Target>, ParserExtra<'a>> + Clone {
154 target()
155 .separated_by(ws().then(just(',')).then(ws()))
156 .at_least(1)
157 .collect()
158}
159
160fn target<'a>() -> impl Parser<'a, ParserInput<'a>, Target, ParserExtra<'a>> + Clone {
162 expr()
163 .then(
164 ws1()
165 .ignore_then(kw("AS"))
166 .ignore_then(ws1())
167 .ignore_then(identifier())
168 .or_not(),
169 )
170 .map(|(expr, alias)| Target { expr, alias })
171}
172
173fn from_modifiers<'a>() -> impl Parser<'a, ParserInput<'a>, FromClause, ParserExtra<'a>> + Clone {
175 let open_on = kw("OPEN")
176 .ignore_then(ws1())
177 .ignore_then(kw("ON"))
178 .ignore_then(ws1())
179 .ignore_then(date_literal())
180 .then_ignore(ws());
181
182 let close_on = kw("CLOSE")
183 .ignore_then(ws().then(kw("ON")).then(ws()).or_not())
184 .ignore_then(date_literal())
185 .then_ignore(ws());
186
187 let clear = kw("CLEAR").then_ignore(ws());
188
189 open_on
191 .or_not()
192 .then(close_on.or_not())
193 .then(clear.or_not().map(|c| c.is_some()))
194 .then(from_filter().or_not())
195 .map(|(((open_on, close_on), clear), filter)| FromClause {
196 open_on,
197 close_on,
198 clear,
199 filter,
200 subquery: None,
201 })
202}
203
/// Parses the filter part of a `FROM` clause, which is an ordinary expression.
fn from_filter<'a>() -> impl Parser<'a, ParserInput<'a>, Expr, ParserExtra<'a>> + Clone {
    expr()
}
208
209fn where_clause<'a>() -> impl Parser<'a, ParserInput<'a>, Expr, ParserExtra<'a>> + Clone {
211 ws1()
212 .ignore_then(kw("WHERE"))
213 .ignore_then(ws1())
214 .ignore_then(expr())
215}
216
217fn group_by_clause<'a>() -> impl Parser<'a, ParserInput<'a>, Vec<Expr>, ParserExtra<'a>> + Clone {
219 ws1()
220 .ignore_then(kw("GROUP"))
221 .ignore_then(ws1())
222 .ignore_then(kw("BY"))
223 .ignore_then(ws1())
224 .ignore_then(
225 expr()
226 .separated_by(ws().then(just(',')).then(ws()))
227 .at_least(1)
228 .collect(),
229 )
230}
231
232fn having_clause<'a>() -> impl Parser<'a, ParserInput<'a>, Expr, ParserExtra<'a>> + Clone {
234 ws1()
235 .ignore_then(kw("HAVING"))
236 .ignore_then(ws1())
237 .ignore_then(expr())
238}
239
240fn pivot_by_clause<'a>() -> impl Parser<'a, ParserInput<'a>, Vec<Expr>, ParserExtra<'a>> + Clone {
242 ws1()
243 .ignore_then(kw("PIVOT"))
244 .ignore_then(ws1())
245 .ignore_then(kw("BY"))
246 .ignore_then(ws1())
247 .ignore_then(
248 expr()
249 .separated_by(ws().then(just(',')).then(ws()))
250 .at_least(1)
251 .collect(),
252 )
253}
254
255fn order_by_clause<'a>() -> impl Parser<'a, ParserInput<'a>, Vec<OrderSpec>, ParserExtra<'a>> + Clone
257{
258 ws1()
259 .ignore_then(kw("ORDER"))
260 .ignore_then(ws1())
261 .ignore_then(kw("BY"))
262 .ignore_then(ws1())
263 .ignore_then(
264 order_spec()
265 .separated_by(ws().then(just(',')).then(ws()))
266 .at_least(1)
267 .collect(),
268 )
269}
270
271fn order_spec<'a>() -> impl Parser<'a, ParserInput<'a>, OrderSpec, ParserExtra<'a>> + Clone {
273 expr()
274 .then(
275 ws1()
276 .ignore_then(choice((
277 kw("ASC").to(SortDirection::Asc),
278 kw("DESC").to(SortDirection::Desc),
279 )))
280 .or_not(),
281 )
282 .map(|(expr, dir)| OrderSpec {
283 expr,
284 direction: dir.unwrap_or_default(),
285 })
286}
287
288fn limit_clause<'a>() -> impl Parser<'a, ParserInput<'a>, u64, ParserExtra<'a>> + Clone {
290 ws1()
291 .ignore_then(kw("LIMIT"))
292 .ignore_then(ws1())
293 .ignore_then(integer())
294 .map(|n| n as u64)
295}
296
297fn journal_query<'a>() -> impl Parser<'a, ParserInput<'a>, JournalQuery, ParserExtra<'a>> + Clone {
299 kw("JOURNAL")
300 .ignore_then(ws1())
301 .ignore_then(string_literal())
302 .then(at_function().or_not())
303 .then(
304 ws1()
305 .ignore_then(kw("FROM"))
306 .ignore_then(ws1())
307 .ignore_then(from_modifiers())
308 .or_not(),
309 )
310 .map(|((account_pattern, at_function), from)| JournalQuery {
311 account_pattern,
312 at_function,
313 from,
314 })
315}
316
317fn balances_query<'a>() -> impl Parser<'a, ParserInput<'a>, BalancesQuery, ParserExtra<'a>> + Clone
319{
320 kw("BALANCES")
321 .ignore_then(at_function().or_not())
322 .then(
323 ws1()
324 .ignore_then(kw("FROM"))
325 .ignore_then(ws1())
326 .ignore_then(from_modifiers())
327 .or_not(),
328 )
329 .map(|(at_function, from)| BalancesQuery { at_function, from })
330}
331
332fn print_query<'a>() -> impl Parser<'a, ParserInput<'a>, PrintQuery, ParserExtra<'a>> + Clone {
334 kw("PRINT")
335 .ignore_then(
336 ws1()
337 .ignore_then(kw("FROM"))
338 .ignore_then(ws1())
339 .ignore_then(from_modifiers())
340 .or_not(),
341 )
342 .map(|from| PrintQuery { from })
343}
344
345fn at_function<'a>() -> impl Parser<'a, ParserInput<'a>, String, ParserExtra<'a>> + Clone {
347 ws1()
348 .ignore_then(kw("AT"))
349 .ignore_then(ws1())
350 .ignore_then(identifier())
351}
352
/// Parses a full expression.
///
/// Precedence, from tightest to loosest binding:
///
/// 1. primary expressions (see `primary_expr`)
/// 2. unary minus
/// 3. `*` and `/` (left-associative)
/// 4. `+` and `-` (left-associative)
/// 5. a single, non-chaining comparison (see `comparison_op`)
/// 6. `NOT` (may be repeated)
/// 7. `AND` (left-associative)
/// 8. `OR` (left-associative)
#[allow(clippy::large_stack_frames)]
fn expr<'a>() -> impl Parser<'a, ParserInput<'a>, Expr, ParserExtra<'a>> + Clone {
    recursive(|expr| {
        // The recursive handle is passed down so parenthesised sub-expressions work.
        let primary = primary_expr(expr.clone());

        // Optional leading `-`, with optional whitespace before the operand.
        let unary = just('-')
            .then_ignore(ws())
            .or_not()
            .then(primary)
            .map(|(neg, e)| {
                if neg.is_some() {
                    Expr::unary(UnaryOperator::Neg, e)
                } else {
                    e
                }
            });

        // `*` / `/`, folded to the left; whitespace around the operator is optional.
        let multiplicative = unary.clone().foldl(
            ws().ignore_then(choice((
                just('*').to(BinaryOperator::Mul),
                just('/').to(BinaryOperator::Div),
            )))
            .then_ignore(ws())
            .then(unary)
            .repeated(),
            |left, (op, right)| Expr::binary(left, op, right),
        );

        // `+` / `-`, folded to the left; whitespace around the operator is optional.
        let additive = multiplicative.clone().foldl(
            ws().ignore_then(choice((
                just('+').to(BinaryOperator::Add),
                just('-').to(BinaryOperator::Sub),
            )))
            .then_ignore(ws())
            .then(multiplicative)
            .repeated(),
            |left, (op, right)| Expr::binary(left, op, right),
        );

        // At most one comparison operator: `a < b < c` is not accepted here.
        let comparison = additive
            .clone()
            .then(
                ws().ignore_then(comparison_op())
                    .then_ignore(ws())
                    .then(additive)
                    .or_not(),
            )
            .map(|(left, rest)| {
                if let Some((op, right)) = rest {
                    Expr::binary(left, op, right)
                } else {
                    left
                }
            });

        // Zero or more `NOT ` prefixes; each one wraps the operand once.
        let not_expr = kw("NOT")
            .ignore_then(ws1())
            .repeated()
            .collect::<Vec<_>>()
            .then(comparison)
            .map(|(nots, e)| {
                nots.into_iter()
                    .fold(e, |acc, ()| Expr::unary(UnaryOperator::Not, acc))
            });

        // `AND` requires whitespace on both sides and folds to the left.
        let and_expr = not_expr.clone().foldl(
            ws1()
                .ignore_then(kw("AND"))
                .ignore_then(ws1())
                .ignore_then(not_expr)
                .repeated(),
            |left, right| Expr::binary(left, BinaryOperator::And, right),
        );

        // `OR` binds loosest; it also requires whitespace on both sides.
        and_expr.clone().foldl(
            ws1()
                .ignore_then(kw("OR"))
                .ignore_then(ws1())
                .ignore_then(and_expr)
                .repeated(),
            |left, right| Expr::binary(left, BinaryOperator::Or, right),
        )
    })
}
445
446fn comparison_op<'a>() -> impl Parser<'a, ParserInput<'a>, BinaryOperator, ParserExtra<'a>> + Clone
448{
449 choice((
450 just("!=").to(BinaryOperator::Ne),
451 just("<=").to(BinaryOperator::Le),
452 just(">=").to(BinaryOperator::Ge),
453 just('=').to(BinaryOperator::Eq),
454 just('<').to(BinaryOperator::Lt),
455 just('>').to(BinaryOperator::Gt),
456 just('~').to(BinaryOperator::Regex),
457 kw("IN").to(BinaryOperator::In),
458 ))
459}
460
461fn primary_expr<'a>(
463 expr: impl Parser<'a, ParserInput<'a>, Expr, ParserExtra<'a>> + Clone + 'a,
464) -> impl Parser<'a, ParserInput<'a>, Expr, ParserExtra<'a>> + Clone {
465 choice((
466 just('(')
468 .ignore_then(ws())
469 .ignore_then(expr)
470 .then_ignore(ws())
471 .then_ignore(just(')'))
472 .map(|e| Expr::Paren(Box::new(e))),
473 function_call_or_column(),
475 literal().map(Expr::Literal),
477 just('*').to(Expr::Wildcard),
479 ))
480}
481
482fn function_call_or_column<'a>() -> impl Parser<'a, ParserInput<'a>, Expr, ParserExtra<'a>> + Clone
484{
485 identifier()
486 .then(
487 ws().ignore_then(just('('))
488 .ignore_then(ws())
489 .ignore_then(function_args())
490 .then_ignore(ws())
491 .then_ignore(just(')'))
492 .or_not(),
493 )
494 .then(
495 ws1()
497 .ignore_then(kw("OVER"))
498 .ignore_then(ws())
499 .ignore_then(just('('))
500 .ignore_then(ws())
501 .ignore_then(window_spec())
502 .then_ignore(ws())
503 .then_ignore(just(')'))
504 .or_not(),
505 )
506 .map(|((name, args), over)| {
507 if let Some(args) = args {
508 if let Some(window_spec) = over {
509 Expr::Window(WindowFunction {
511 name,
512 args,
513 over: window_spec,
514 })
515 } else {
516 Expr::Function(FunctionCall { name, args })
518 }
519 } else {
520 Expr::Column(name)
521 }
522 })
523}
524
525fn window_spec<'a>() -> impl Parser<'a, ParserInput<'a>, WindowSpec, ParserExtra<'a>> + Clone {
527 let partition_by = kw("PARTITION")
528 .ignore_then(ws1())
529 .ignore_then(kw("BY"))
530 .ignore_then(ws1())
531 .ignore_then(
532 simple_arg()
533 .separated_by(ws().then(just(',')).then(ws()))
534 .at_least(1)
535 .collect::<Vec<_>>(),
536 )
537 .then_ignore(ws());
538
539 let window_order_by = kw("ORDER")
540 .ignore_then(ws1())
541 .ignore_then(kw("BY"))
542 .ignore_then(ws1())
543 .ignore_then(
544 window_order_spec()
545 .separated_by(ws().then(just(',')).then(ws()))
546 .at_least(1)
547 .collect::<Vec<_>>(),
548 );
549
550 partition_by
551 .or_not()
552 .then(window_order_by.or_not())
553 .map(|(partition_by, order_by)| WindowSpec {
554 partition_by,
555 order_by,
556 })
557}
558
559fn window_order_spec<'a>() -> impl Parser<'a, ParserInput<'a>, OrderSpec, ParserExtra<'a>> + Clone {
561 simple_arg()
562 .then(
563 ws1()
564 .ignore_then(choice((
565 kw("ASC").to(SortDirection::Asc),
566 kw("DESC").to(SortDirection::Desc),
567 )))
568 .or_not(),
569 )
570 .map(|(expr, dir)| OrderSpec {
571 expr,
572 direction: dir.unwrap_or_default(),
573 })
574}
575
576fn function_args<'a>() -> impl Parser<'a, ParserInput<'a>, Vec<Expr>, ParserExtra<'a>> + Clone {
578 simple_arg()
581 .separated_by(ws().then(just(',')).then(ws()))
582 .collect()
583}
584
585fn simple_arg<'a>() -> impl Parser<'a, ParserInput<'a>, Expr, ParserExtra<'a>> + Clone {
587 choice((
588 just('*').to(Expr::Wildcard),
589 identifier().map(Expr::Column),
590 literal().map(Expr::Literal),
591 ))
592}
593
594fn literal<'a>() -> impl Parser<'a, ParserInput<'a>, Literal, ParserExtra<'a>> + Clone {
596 choice((
597 kw("TRUE").to(Literal::Boolean(true)),
599 kw("FALSE").to(Literal::Boolean(false)),
600 kw("NULL").to(Literal::Null),
601 date_literal().map(Literal::Date),
603 decimal().map(Literal::Number),
605 string_literal().map(Literal::String),
607 ))
608}
609
610fn identifier<'a>() -> impl Parser<'a, ParserInput<'a>, String, ParserExtra<'a>> + Clone {
612 text::ident().map(|s: &str| s.to_string())
613}
614
615fn string_literal<'a>() -> impl Parser<'a, ParserInput<'a>, String, ParserExtra<'a>> + Clone {
617 just('"')
619 .ignore_then(
620 none_of("\"\\")
621 .or(just('\\').ignore_then(any()))
622 .repeated()
623 .collect::<String>(),
624 )
625 .then_ignore(just('"'))
626}
627
628fn date_literal<'a>() -> impl Parser<'a, ParserInput<'a>, NaiveDate, ParserExtra<'a>> + Clone {
630 digits()
631 .then_ignore(just('-'))
632 .then(digits())
633 .then_ignore(just('-'))
634 .then(digits())
635 .try_map(|((year, month), day): ((&str, &str), &str), span| {
636 let year: i32 = year
637 .parse()
638 .map_err(|_| Rich::custom(span, "invalid year"))?;
639 let month: u32 = month
640 .parse()
641 .map_err(|_| Rich::custom(span, "invalid month"))?;
642 let day: u32 = day.parse().map_err(|_| Rich::custom(span, "invalid day"))?;
643 NaiveDate::from_ymd_opt(year, month, day)
644 .ok_or_else(|| Rich::custom(span, "invalid date"))
645 })
646}
647
648fn decimal<'a>() -> impl Parser<'a, ParserInput<'a>, Decimal, ParserExtra<'a>> + Clone {
650 just('-')
651 .or_not()
652 .then(digits())
653 .then(just('.').then(digits()).or_not())
654 .try_map(
655 |((neg, int_part), frac_part): ((Option<char>, &str), Option<(char, &str)>), span| {
656 let mut s = String::new();
657 if neg.is_some() {
658 s.push('-');
659 }
660 s.push_str(int_part);
661 if let Some((_, frac)) = frac_part {
662 s.push('.');
663 s.push_str(frac);
664 }
665 Decimal::from_str(&s).map_err(|_| Rich::custom(span, "invalid number"))
666 },
667 )
668}
669
670fn integer<'a>() -> impl Parser<'a, ParserInput<'a>, i64, ParserExtra<'a>> + Clone {
672 digits().try_map(|s: &str, span| {
673 s.parse::<i64>()
674 .map_err(|_| Rich::custom(span, "invalid integer"))
675 })
676}
677
#[cfg(test)]
mod tests {
    use super::*;
    use rust_decimal_macros::dec;

    /// A wildcard select with a `FROM` filter.
    #[test]
    fn test_simple_select() {
        let query = parse("SELECT * FROM year = 2024").unwrap();
        let Query::Select(sel) = query else {
            panic!("Expected SELECT query")
        };
        assert!(!sel.distinct);
        assert_eq!(sel.targets.len(), 1);
        assert!(matches!(sel.targets[0].expr, Expr::Wildcard));
        assert!(sel.from.is_some());
    }

    /// Several plain column targets.
    #[test]
    fn test_select_columns() {
        let query = parse("SELECT date, account, position").unwrap();
        let Query::Select(sel) = query else {
            panic!("Expected SELECT query")
        };
        assert_eq!(sel.targets.len(), 3);
        assert!(matches!(&sel.targets[0].expr, Expr::Column(c) if c == "date"));
        assert!(matches!(&sel.targets[1].expr, Expr::Column(c) if c == "account"));
        assert!(matches!(&sel.targets[2].expr, Expr::Column(c) if c == "position"));
    }

    /// A function-call target with an `AS` alias.
    #[test]
    fn test_select_with_alias() {
        let query = parse("SELECT SUM(position) AS total").unwrap();
        let Query::Select(sel) = query else {
            panic!("Expected SELECT query")
        };
        assert_eq!(sel.targets.len(), 1);
        assert_eq!(sel.targets[0].alias, Some("total".to_string()));
        let Expr::Function(f) = &sel.targets[0].expr else {
            panic!("Expected function")
        };
        assert_eq!(f.name, "SUM");
        assert_eq!(f.args.len(), 1);
    }

    /// `DISTINCT` sets the flag on the query.
    #[test]
    fn test_select_distinct() {
        let query = parse("SELECT DISTINCT account").unwrap();
        let Query::Select(sel) = query else {
            panic!("Expected SELECT query")
        };
        assert!(sel.distinct);
    }

    /// A `WHERE` clause using the regex operator.
    #[test]
    fn test_where_clause() {
        let query = parse("SELECT * WHERE account ~ \"Expenses:\"").unwrap();
        let Query::Select(sel) = query else {
            panic!("Expected SELECT query")
        };
        assert!(sel.where_clause.is_some());
        let Expr::BinaryOp(op) = sel.where_clause.unwrap() else {
            panic!("Expected binary op")
        };
        assert_eq!(op.op, BinaryOperator::Regex);
    }

    /// A single `GROUP BY` key.
    #[test]
    fn test_group_by() {
        let query = parse("SELECT account, SUM(position) GROUP BY account").unwrap();
        let Query::Select(sel) = query else {
            panic!("Expected SELECT query")
        };
        assert!(sel.group_by.is_some());
        assert_eq!(sel.group_by.unwrap().len(), 1);
    }

    /// Two `ORDER BY` items with explicit directions.
    #[test]
    fn test_order_by() {
        let query = parse("SELECT * ORDER BY date DESC, account ASC").unwrap();
        let Query::Select(sel) = query else {
            panic!("Expected SELECT query")
        };
        assert!(sel.order_by.is_some());
        let order = sel.order_by.unwrap();
        assert_eq!(order.len(), 2);
        assert_eq!(order[0].direction, SortDirection::Desc);
        assert_eq!(order[1].direction, SortDirection::Asc);
    }

    /// `LIMIT` is stored as a number.
    #[test]
    fn test_limit() {
        let query = parse("SELECT * LIMIT 100").unwrap();
        let Query::Select(sel) = query else {
            panic!("Expected SELECT query")
        };
        assert_eq!(sel.limit, Some(100));
    }

    /// All three `FROM` modifiers together.
    #[test]
    fn test_from_open_close_clear() {
        let query = parse("SELECT * FROM OPEN ON 2024-01-01 CLOSE ON 2024-12-31 CLEAR").unwrap();
        let Query::Select(sel) = query else {
            panic!("Expected SELECT query")
        };
        let from = sel.from.unwrap();
        assert_eq!(
            from.open_on,
            Some(NaiveDate::from_ymd_opt(2024, 1, 1).unwrap())
        );
        assert_eq!(
            from.close_on,
            Some(NaiveDate::from_ymd_opt(2024, 12, 31).unwrap())
        );
        assert!(from.clear);
    }

    /// `JOURNAL` with an account pattern and an `AT` function.
    #[test]
    fn test_journal_query() {
        let query = parse("JOURNAL \"Assets:Bank\" AT cost").unwrap();
        let Query::Journal(j) = query else {
            panic!("Expected JOURNAL query")
        };
        assert_eq!(j.account_pattern, "Assets:Bank");
        assert_eq!(j.at_function, Some("cost".to_string()));
    }

    /// `BALANCES` with an `AT` function and a `FROM` filter.
    #[test]
    fn test_balances_query() {
        let query = parse("BALANCES AT units FROM year = 2024").unwrap();
        let Query::Balances(b) = query else {
            panic!("Expected BALANCES query")
        };
        assert_eq!(b.at_function, Some("units".to_string()));
        assert!(b.from.is_some());
    }

    /// A bare `PRINT` statement.
    #[test]
    fn test_print_query() {
        let query = parse("PRINT").unwrap();
        assert!(matches!(query, Query::Print(_)));
    }

    /// `AND` is the outermost operator of a compound condition.
    #[test]
    fn test_complex_expression() {
        let query = parse("SELECT * WHERE date >= 2024-01-01 AND account ~ \"Expenses:\"").unwrap();
        let Query::Select(sel) = query else {
            panic!("Expected SELECT query")
        };
        let Expr::BinaryOp(op) = sel.where_clause.unwrap() else {
            panic!("Expected AND")
        };
        assert_eq!(op.op, BinaryOperator::And);
    }

    /// A bare integer on the right-hand side is parsed as a number literal.
    #[test]
    fn test_number_literal() {
        let query = parse("SELECT * WHERE year = 2024").unwrap();
        let Query::Select(sel) = query else {
            panic!("Expected SELECT query")
        };
        let Expr::BinaryOp(op) = sel.where_clause.unwrap() else {
            panic!("Expected binary op")
        };
        let Expr::Literal(Literal::Number(n)) = op.right else {
            panic!("Expected number literal")
        };
        assert_eq!(n, dec!(2024));
    }

    /// The trailing semicolon may be present or absent.
    #[test]
    fn test_semicolon_optional() {
        assert!(parse("SELECT *").is_ok());
        assert!(parse("SELECT *;").is_ok());
    }

    /// A subquery in `FROM` is attached to the outer query.
    #[test]
    fn test_subquery_basic() {
        let query = parse("SELECT * FROM (SELECT account, position)").unwrap();
        let Query::Select(sel) = query else {
            panic!("Expected SELECT query")
        };
        assert!(sel.from.is_some());
        let from = sel.from.unwrap();
        assert!(from.subquery.is_some());
        let subquery = from.subquery.unwrap();
        assert_eq!(subquery.targets.len(), 2);
    }

    /// The subquery keeps its own `GROUP BY`.
    #[test]
    fn test_subquery_with_groupby() {
        let query = parse(
            "SELECT account, total FROM (SELECT account, SUM(position) AS total GROUP BY account)",
        )
        .unwrap();
        let Query::Select(sel) = query else {
            panic!("Expected SELECT query")
        };
        assert_eq!(sel.targets.len(), 2);
        let from = sel.from.unwrap();
        assert!(from.subquery.is_some());
        let subquery = from.subquery.unwrap();
        assert!(subquery.group_by.is_some());
    }

    /// Both the outer query and the subquery carry a `WHERE` clause.
    #[test]
    fn test_subquery_with_outer_where() {
        let query =
            parse("SELECT * FROM (SELECT * WHERE year = 2024) WHERE account ~ \"Expenses:\"")
                .unwrap();
        let Query::Select(sel) = query else {
            panic!("Expected SELECT query")
        };
        assert!(sel.where_clause.is_some());
        let from = sel.from.unwrap();
        let subquery = from.subquery.unwrap();
        assert!(subquery.where_clause.is_some());
    }

    /// Subqueries nest to more than one level.
    #[test]
    fn test_nested_subquery() {
        let query = parse("SELECT * FROM (SELECT * FROM (SELECT account))").unwrap();
        let Query::Select(sel) = query else {
            panic!("Expected SELECT query")
        };
        let from = sel.from.unwrap();
        let subquery1 = from.subquery.unwrap();
        let from2 = subquery1.from.unwrap();
        assert!(from2.subquery.is_some());
    }
}