1use chumsky::prelude::*;
18use smol_str::SmolStr;
19
20use crate::ast::*;
21use crate::span::Spanned;
22use crate::token::Token;
23
24use super::pattern::ident;
25
/// Token type consumed by every expression parser in this module.
type ParserInput = Token;
/// Error type produced by the expression parsers: chumsky's `Simple` error
/// over [`Token`].
type ParserError = Simple<Token>;
28
29pub fn expression_parser() -> impl Parser<ParserInput, Spanned<Expression>, Error = ParserError> + Clone
31{
32 recursive(|expr| {
36 let primary = primary_expr(expr.clone());
37 let postfix = postfix_expr(primary, expr.clone()).boxed();
38 let unary = unary_expr(postfix);
39 let power = power_expr(unary);
40 let mul_div = mul_div_expr(power).boxed();
41 let add_sub = add_sub_expr(mul_div);
42 let string_list = string_list_expr(add_sub, expr.clone());
43 let comparison = comparison_expr(string_list).boxed();
44 let not = not_expr(comparison);
45 let and = and_expr(not);
46 let xor = xor_expr(and);
47 or_expr(xor).boxed()
48 })
49}
50
51fn primary_expr(
54 expr: impl Parser<ParserInput, Spanned<Expression>, Error = ParserError> + Clone + 'static,
55) -> impl Parser<ParserInput, Spanned<Expression>, Error = ParserError> + Clone {
56 let integer = select! { Token::Integer(n) => Expression::Literal(Literal::Integer(n)) };
57 let float = select! { Token::Float(s) => {
58 let f: f64 = s.parse().unwrap_or(0.0);
59 Expression::Literal(Literal::Float(f))
60 }};
61 let string_lit =
62 select! { Token::StringLiteral(s) => Expression::Literal(Literal::String(s)) };
63 let bool_true = just(Token::True).to(Expression::Literal(Literal::Bool(true)));
64 let bool_false = just(Token::False).to(Expression::Literal(Literal::Bool(false)));
65 let null = just(Token::Null).to(Expression::Literal(Literal::Null));
66
67 let variable = ident().map(Expression::Variable);
70
71 let parameter = select! { Token::Parameter(name) => Expression::Parameter(name) };
72
73 let count_star = just(Token::Count)
75 .then(just(Token::LeftParen))
76 .then(just(Token::Star))
77 .then(just(Token::RightParen))
78 .to(Expression::CountStar);
79
80 let list_comprehension = just(Token::LeftBracket)
83 .ignore_then(ident().map_with_span(|n, s| (n, s)))
84 .then_ignore(just(Token::In))
85 .then(expr.clone())
86 .then(
87 just(Token::Where)
88 .ignore_then(expr.clone())
89 .or_not(),
90 )
91 .then(
92 just(Token::Pipe)
93 .ignore_then(expr.clone())
94 .or_not(),
95 )
96 .then_ignore(just(Token::RightBracket))
97 .map(|(((variable, list), filter), projection)| {
98 Expression::ListComprehension {
99 variable,
100 list: Box::new(list),
101 filter: filter.map(Box::new),
102 projection: projection.map(Box::new),
103 }
104 });
105
106 let list_literal = expr
108 .clone()
109 .separated_by(just(Token::Comma))
110 .allow_trailing()
111 .delimited_by(just(Token::LeftBracket), just(Token::RightBracket))
112 .map(Expression::ListLiteral);
113
114 let map_entry = ident()
116 .map_with_span(|n, s| (n, s))
117 .then_ignore(just(Token::Colon))
118 .then(expr.clone());
119
120 let map_literal = map_entry
121 .separated_by(just(Token::Comma))
122 .allow_trailing()
123 .delimited_by(just(Token::LeftBrace), just(Token::RightBrace))
124 .map(Expression::MapLiteral);
125
126 let case_expr = case_expression(expr.clone());
128
129 let paren = expr
131 .clone()
132 .delimited_by(just(Token::LeftParen), just(Token::RightParen))
133 .map(|(e, _span)| e);
134
135 let literals = choice((
137 count_star,
138 integer,
139 float,
140 string_lit,
141 bool_true,
142 bool_false,
143 null,
144 ))
145 .boxed();
146
147 let compound = choice((
148 parameter,
149 case_expr,
150 list_comprehension,
151 list_literal,
152 map_literal,
153 paren,
154 variable,
155 ))
156 .boxed();
157
158 literals.or(compound).map_with_span(|e, s| (e, s))
159}
160
161fn postfix_expr(
164 primary: impl Parser<ParserInput, Spanned<Expression>, Error = ParserError> + Clone + 'static,
165 expr: impl Parser<ParserInput, Spanned<Expression>, Error = ParserError> + Clone + 'static,
166) -> impl Parser<ParserInput, Spanned<Expression>, Error = ParserError> + Clone {
167 let distinct = just(Token::Distinct).or_not().map(|d| d.is_some());
172
173 let func_args = distinct
174 .then(
175 expr.clone()
176 .separated_by(just(Token::Comma))
177 .allow_trailing(),
178 )
179 .delimited_by(just(Token::LeftParen), just(Token::RightParen));
180
181 let subscript = expr
183 .clone()
184 .delimited_by(just(Token::LeftBracket), just(Token::RightBracket));
185
186 enum Postfix {
187 Property(Spanned<SmolStr>),
188 Subscript(Spanned<Expression>),
189 FuncCall(bool, Vec<Spanned<Expression>>),
190 }
191
192 let property = just(Token::Dot)
193 .ignore_then(ident().map_with_span(|n, s| (n, s)))
194 .map(Postfix::Property);
195
196 let sub = subscript.map(Postfix::Subscript);
197
198 let call = func_args.map(|(d, args)| Postfix::FuncCall(d, args));
199
200 primary
201 .then(choice((property, call, sub)).repeated())
202 .foldl(|base, postfix| {
203 let span_start = base.1.start;
204 match postfix {
205 Postfix::Property(key) => {
206 let span_end = key.1.end;
207 (
208 Expression::Property {
209 object: Box::new(base),
210 key,
211 },
212 span_start..span_end,
213 )
214 }
215 Postfix::Subscript(index) => {
216 let span_end = index.1.end;
217 (
218 Expression::Subscript {
219 expr: Box::new(base),
220 index: Box::new(index),
221 },
222 span_start..span_end,
223 )
224 }
225 Postfix::FuncCall(distinct, args) => {
226 let name = match &base.0 {
228 Expression::Variable(n) => vec![(n.clone(), base.1.clone())],
229 Expression::Property { object, key } => {
230 let mut names = Vec::new();
232 if let Expression::Variable(n) = &object.0 {
233 names.push((n.clone(), object.1.clone()));
234 }
235 names.push(key.clone());
236 names
237 }
238 _ => vec![(SmolStr::new("<unknown>"), base.1.clone())],
239 };
240 let span_end = args.last().map(|a| a.1.end).unwrap_or(base.1.end) + 1;
241 (
242 Expression::FunctionCall {
243 name,
244 distinct,
245 args,
246 },
247 span_start..span_end,
248 )
249 }
250 }
251 })
252}
253
254fn unary_expr(
257 inner: impl Parser<ParserInput, Spanned<Expression>, Error = ParserError> + Clone,
258) -> impl Parser<ParserInput, Spanned<Expression>, Error = ParserError> + Clone {
259 let minus = just(Token::Dash).to(UnaryOp::Minus);
260 let bitwise_not = just(Token::Tilde).to(UnaryOp::BitwiseNot);
261
262 let op = minus.or(bitwise_not);
263
264 op.map_with_span(|op, s: std::ops::Range<usize>| (op, s))
265 .repeated()
266 .then(inner)
267 .foldr(|(op, op_span), operand| {
268 let span = op_span.start..operand.1.end;
269 (
270 Expression::UnaryOp {
271 op,
272 operand: Box::new(operand),
273 },
274 span,
275 )
276 })
277}
278
279fn power_expr(
282 inner: impl Parser<ParserInput, Spanned<Expression>, Error = ParserError> + Clone,
283) -> impl Parser<ParserInput, Spanned<Expression>, Error = ParserError> + Clone {
284 inner
285 .clone()
286 .then(just(Token::Caret).ignore_then(inner).repeated())
287 .foldl(|left, right| {
288 let span = left.1.start..right.1.end;
289 (
290 Expression::BinaryOp {
291 left: Box::new(left),
292 op: BinaryOp::Pow,
293 right: Box::new(right),
294 },
295 span,
296 )
297 })
298}
299
300fn mul_div_expr(
303 inner: impl Parser<ParserInput, Spanned<Expression>, Error = ParserError> + Clone,
304) -> impl Parser<ParserInput, Spanned<Expression>, Error = ParserError> + Clone {
305 let op = choice((
306 just(Token::Star).to(BinaryOp::Mul),
307 just(Token::Slash).to(BinaryOp::Div),
308 just(Token::Percent).to(BinaryOp::Mod),
309 ));
310
311 inner
312 .clone()
313 .then(op.then(inner).repeated())
314 .foldl(|left, (op, right)| {
315 let span = left.1.start..right.1.end;
316 (
317 Expression::BinaryOp {
318 left: Box::new(left),
319 op,
320 right: Box::new(right),
321 },
322 span,
323 )
324 })
325}
326
327fn add_sub_expr(
330 inner: impl Parser<ParserInput, Spanned<Expression>, Error = ParserError> + Clone,
331) -> impl Parser<ParserInput, Spanned<Expression>, Error = ParserError> + Clone {
332 let op = choice((
333 just(Token::Plus).to(BinaryOp::Add),
334 just(Token::Dash).to(BinaryOp::Sub),
335 ));
336
337 inner
338 .clone()
339 .then(op.then(inner).repeated())
340 .foldl(|left, (op, right)| {
341 let span = left.1.start..right.1.end;
342 (
343 Expression::BinaryOp {
344 left: Box::new(left),
345 op,
346 right: Box::new(right),
347 },
348 span,
349 )
350 })
351}
352
353fn string_list_expr(
356 inner: impl Parser<ParserInput, Spanned<Expression>, Error = ParserError> + Clone + 'static,
357 _full_expr: impl Parser<ParserInput, Spanned<Expression>, Error = ParserError>
358 + Clone
359 + 'static,
360) -> impl Parser<ParserInput, Spanned<Expression>, Error = ParserError> + Clone {
361 enum PostfixStringOp {
362 StartsWith(Spanned<Expression>),
363 EndsWith(Spanned<Expression>),
364 Contains(Spanned<Expression>),
365 In(Spanned<Expression>),
366 IsNull(bool),
367 HasLabel(Vec<Spanned<SmolStr>>),
368 }
369
370 let starts_with = just(Token::Starts)
371 .ignore_then(just(Token::With))
372 .ignore_then(inner.clone())
373 .map(PostfixStringOp::StartsWith);
374
375 let ends_with = just(Token::Ends)
376 .ignore_then(just(Token::With))
377 .ignore_then(inner.clone())
378 .map(PostfixStringOp::EndsWith);
379
380 let contains = just(Token::Contains)
381 .ignore_then(inner.clone())
382 .map(PostfixStringOp::Contains);
383
384 let in_list = just(Token::In)
385 .ignore_then(inner.clone())
386 .map(PostfixStringOp::In);
387
388 let is_null = just(Token::Is)
389 .ignore_then(just(Token::Not).or_not())
390 .then_ignore(just(Token::Null))
391 .map(|not| PostfixStringOp::IsNull(not.is_some()));
392
393 let has_label = just(Token::Colon)
395 .ignore_then(ident().map_with_span(|n, s| (n, s)))
396 .repeated()
397 .at_least(1)
398 .map(PostfixStringOp::HasLabel);
399
400 inner
401 .then(
402 choice((starts_with, ends_with, contains, in_list, is_null, has_label)).repeated(),
403 )
404 .foldl(|left, op| match op {
405 PostfixStringOp::StartsWith(right) => {
406 let span = left.1.start..right.1.end;
407 (
408 Expression::StringOp {
409 left: Box::new(left),
410 op: StringOp::StartsWith,
411 right: Box::new(right),
412 },
413 span,
414 )
415 }
416 PostfixStringOp::EndsWith(right) => {
417 let span = left.1.start..right.1.end;
418 (
419 Expression::StringOp {
420 left: Box::new(left),
421 op: StringOp::EndsWith,
422 right: Box::new(right),
423 },
424 span,
425 )
426 }
427 PostfixStringOp::Contains(right) => {
428 let span = left.1.start..right.1.end;
429 (
430 Expression::StringOp {
431 left: Box::new(left),
432 op: StringOp::Contains,
433 right: Box::new(right),
434 },
435 span,
436 )
437 }
438 PostfixStringOp::In(right) => {
439 let span = left.1.start..right.1.end;
440 (
441 Expression::InList {
442 expr: Box::new(left),
443 list: Box::new(right),
444 negated: false,
445 },
446 span,
447 )
448 }
449 PostfixStringOp::IsNull(negated) => {
450 let span = left.1.clone();
451 (
452 Expression::IsNull {
453 expr: Box::new(left),
454 negated,
455 },
456 span,
457 )
458 }
459 PostfixStringOp::HasLabel(labels) => {
460 let span_end = labels.last().map(|l| l.1.end).unwrap_or(left.1.end);
461 let span = left.1.start..span_end;
462 (
463 Expression::HasLabel {
464 expr: Box::new(left),
465 labels,
466 },
467 span,
468 )
469 }
470 })
471}
472
473fn comparison_expr(
476 inner: impl Parser<ParserInput, Spanned<Expression>, Error = ParserError> + Clone,
477) -> impl Parser<ParserInput, Spanned<Expression>, Error = ParserError> + Clone {
478 let op = choice((
479 just(Token::Eq).to(ComparisonOp::Eq),
480 just(Token::Neq).to(ComparisonOp::Neq),
481 just(Token::Le).to(ComparisonOp::Le),
482 just(Token::Lt).to(ComparisonOp::Lt),
483 just(Token::Ge).to(ComparisonOp::Ge),
484 just(Token::Gt).to(ComparisonOp::Gt),
485 just(Token::RegexMatch).to(ComparisonOp::RegexMatch),
486 ));
487
488 inner
489 .clone()
490 .then(op.then(inner).repeated())
491 .map_with_span(|(left, ops), span| {
492 if ops.is_empty() {
493 left
494 } else {
495 (
496 Expression::Comparison {
497 left: Box::new(left),
498 ops,
499 },
500 span,
501 )
502 }
503 })
504}
505
506fn not_expr(
509 inner: impl Parser<ParserInput, Spanned<Expression>, Error = ParserError> + Clone,
510) -> impl Parser<ParserInput, Spanned<Expression>, Error = ParserError> + Clone {
511 just(Token::Not)
512 .map_with_span(|_, s: std::ops::Range<usize>| s)
513 .repeated()
514 .then(inner)
515 .foldr(|op_span, operand| {
516 let span = op_span.start..operand.1.end;
517 (
518 Expression::UnaryOp {
519 op: UnaryOp::Not,
520 operand: Box::new(operand),
521 },
522 span,
523 )
524 })
525}
526
527fn and_expr(
530 inner: impl Parser<ParserInput, Spanned<Expression>, Error = ParserError> + Clone,
531) -> impl Parser<ParserInput, Spanned<Expression>, Error = ParserError> + Clone {
532 inner
533 .clone()
534 .then(just(Token::And).ignore_then(inner).repeated())
535 .foldl(|left, right| {
536 let span = left.1.start..right.1.end;
537 (
538 Expression::BinaryOp {
539 left: Box::new(left),
540 op: BinaryOp::And,
541 right: Box::new(right),
542 },
543 span,
544 )
545 })
546}
547
548fn xor_expr(
551 inner: impl Parser<ParserInput, Spanned<Expression>, Error = ParserError> + Clone,
552) -> impl Parser<ParserInput, Spanned<Expression>, Error = ParserError> + Clone {
553 inner
554 .clone()
555 .then(just(Token::Xor).ignore_then(inner).repeated())
556 .foldl(|left, right| {
557 let span = left.1.start..right.1.end;
558 (
559 Expression::BinaryOp {
560 left: Box::new(left),
561 op: BinaryOp::Xor,
562 right: Box::new(right),
563 },
564 span,
565 )
566 })
567}
568
569fn or_expr(
572 inner: impl Parser<ParserInput, Spanned<Expression>, Error = ParserError> + Clone,
573) -> impl Parser<ParserInput, Spanned<Expression>, Error = ParserError> + Clone {
574 inner
575 .clone()
576 .then(just(Token::Or).ignore_then(inner).repeated())
577 .foldl(|left, right| {
578 let span = left.1.start..right.1.end;
579 (
580 Expression::BinaryOp {
581 left: Box::new(left),
582 op: BinaryOp::Or,
583 right: Box::new(right),
584 },
585 span,
586 )
587 })
588}
589
590fn case_expression(
593 expr: impl Parser<ParserInput, Spanned<Expression>, Error = ParserError> + Clone + 'static,
594) -> impl Parser<ParserInput, Expression, Error = ParserError> + Clone {
595 let when_clause = just(Token::When)
596 .ignore_then(expr.clone())
597 .then_ignore(just(Token::Then))
598 .then(expr.clone());
599
600 let else_clause = just(Token::Else).ignore_then(expr.clone());
601
602 just(Token::Case)
603 .ignore_then(expr.clone().map(Box::new).or_not())
604 .then(when_clause.repeated().at_least(1))
605 .then(else_clause.or_not())
606 .then_ignore(just(Token::End))
607 .map(|((operand, whens), else_expr)| Expression::Case {
608 operand,
609 whens,
610 else_expr: else_expr.map(Box::new),
611 })
612}
613
#[cfg(test)]
mod tests {
    use super::*;
    use crate::lexer::Lexer;

    /// Lexes and parses `src` as one complete expression.
    ///
    /// Panics if lexing reports errors. Parse errors are printed to stderr
    /// (shown in the output of a failing test), and the parsed expression, if
    /// any, is returned without its span.
    fn parse_expr(src: &str) -> Option<Expression> {
        let (tokens, lex_errors) = Lexer::new(src).lex();
        assert!(lex_errors.is_empty(), "lex errors: {lex_errors:?}");

        // The lexer's `Eof` marker is dropped; `end()` matches the exhausted
        // stream instead.
        let len = src.len();
        let stream = chumsky::Stream::from_iter(
            len..len + 1,
            tokens
                .into_iter()
                .filter(|(tok, _)| !matches!(tok, Token::Eof)),
        );

        let (result, errors) = expression_parser().then_ignore(end()).parse_recovery(stream);
        if !errors.is_empty() {
            eprintln!("parse errors: {errors:?}");
        }
        result.map(|(expr, _)| expr)
    }

    #[test]
    fn integer_literal() {
        let expr = parse_expr("42").unwrap();
        assert!(matches!(expr, Expression::Literal(Literal::Integer(42))));
    }

    #[test]
    fn float_literal() {
        let expr = parse_expr("3.14").unwrap();
        if let Expression::Literal(Literal::Float(f)) = expr {
            assert!((f - 3.14).abs() < 1e-10);
        } else {
            panic!("expected float literal");
        }
    }

    #[test]
    fn string_literal() {
        let expr = parse_expr("'hello'").unwrap();
        assert!(matches!(
            expr,
            Expression::Literal(Literal::String(ref s)) if s == "hello"
        ));
    }

    #[test]
    fn boolean_literals() {
        assert!(matches!(
            parse_expr("TRUE").unwrap(),
            Expression::Literal(Literal::Bool(true))
        ));
        assert!(matches!(
            parse_expr("FALSE").unwrap(),
            Expression::Literal(Literal::Bool(false))
        ));
    }

    #[test]
    fn null_literal() {
        assert!(matches!(
            parse_expr("NULL").unwrap(),
            Expression::Literal(Literal::Null)
        ));
    }

    #[test]
    fn variable() {
        let expr = parse_expr("n").unwrap();
        assert!(matches!(expr, Expression::Variable(ref name) if name == "n"));
    }

    #[test]
    fn parameter() {
        let expr = parse_expr("$since").unwrap();
        assert!(matches!(expr, Expression::Parameter(ref name) if name == "since"));
    }

    #[test]
    fn property_access() {
        let expr = parse_expr("n.age").unwrap();
        assert!(matches!(expr, Expression::Property { .. }));
    }

    #[test]
    fn binary_arithmetic() {
        // `*` binds tighter than `+`: 1 + (2 * 3).
        let expr = parse_expr("1 + 2 * 3").unwrap();
        if let Expression::BinaryOp { left, op, right } = &expr {
            assert_eq!(*op, BinaryOp::Add);
            assert!(matches!(left.0, Expression::Literal(Literal::Integer(1))));
            assert!(matches!(
                right.0,
                Expression::BinaryOp {
                    op: BinaryOp::Mul,
                    ..
                }
            ));
        } else {
            panic!("expected binary op, got {expr:?}");
        }
    }

    #[test]
    fn power_is_left_associative() {
        // 2 ^ 3 ^ 2 parses as (2 ^ 3) ^ 2.
        let expr = parse_expr("2 ^ 3 ^ 2").unwrap();
        if let Expression::BinaryOp { left, op, right } = &expr {
            assert_eq!(*op, BinaryOp::Pow);
            assert!(matches!(
                left.0,
                Expression::BinaryOp {
                    op: BinaryOp::Pow,
                    ..
                }
            ));
            assert!(matches!(right.0, Expression::Literal(Literal::Integer(2))));
        } else {
            panic!("expected power op, got {expr:?}");
        }
    }

    #[test]
    fn unary_minus_binds_tighter_than_power() {
        // -2 ^ 2 parses as (-2) ^ 2.
        let expr = parse_expr("-2 ^ 2").unwrap();
        if let Expression::BinaryOp { left, op, .. } = &expr {
            assert_eq!(*op, BinaryOp::Pow);
            assert!(matches!(
                left.0,
                Expression::UnaryOp {
                    op: UnaryOp::Minus,
                    ..
                }
            ));
        } else {
            panic!("expected power op, got {expr:?}");
        }
    }

    #[test]
    fn comparison() {
        let expr = parse_expr("n.age > 30").unwrap();
        assert!(matches!(expr, Expression::Comparison { .. }));
    }

    #[test]
    fn logical_and_or() {
        // AND binds tighter than OR: (a AND b) OR c.
        let expr = parse_expr("a AND b OR c").unwrap();
        if let Expression::BinaryOp { left, op, .. } = &expr {
            assert_eq!(*op, BinaryOp::Or);
            assert!(matches!(
                left.0,
                Expression::BinaryOp {
                    op: BinaryOp::And,
                    ..
                }
            ));
        } else {
            panic!("expected OR");
        }
    }

    #[test]
    fn not_expression() {
        let expr = parse_expr("NOT x").unwrap();
        assert!(matches!(
            expr,
            Expression::UnaryOp {
                op: UnaryOp::Not,
                ..
            }
        ));
    }

    #[test]
    fn not_applies_to_whole_comparison() {
        // NOT binds looser than `=`: NOT (a = b).
        let expr = parse_expr("NOT a = b").unwrap();
        if let Expression::UnaryOp {
            op: UnaryOp::Not,
            operand,
        } = &expr
        {
            assert!(matches!(operand.0, Expression::Comparison { .. }));
        } else {
            panic!("expected NOT, got {expr:?}");
        }
    }

    #[test]
    fn is_null() {
        let expr = parse_expr("x IS NULL").unwrap();
        assert!(matches!(
            expr,
            Expression::IsNull {
                negated: false,
                ..
            }
        ));
    }

    #[test]
    fn is_not_null() {
        let expr = parse_expr("x IS NOT NULL").unwrap();
        assert!(matches!(
            expr,
            Expression::IsNull {
                negated: true,
                ..
            }
        ));
    }

    #[test]
    fn in_list() {
        let expr = parse_expr("x IN [1, 2, 3]").unwrap();
        assert!(matches!(expr, Expression::InList { .. }));
    }

    #[test]
    fn has_label() {
        let expr = parse_expr("n:Person").unwrap();
        if let Expression::HasLabel { labels, .. } = &expr {
            assert_eq!(labels.len(), 1);
            assert_eq!(labels[0].0.as_str(), "Person");
        } else {
            panic!("expected label predicate, got {expr:?}");
        }
    }

    #[test]
    fn function_call() {
        let expr = parse_expr("count(n)").unwrap();
        assert!(matches!(expr, Expression::FunctionCall { .. }));
    }

    #[test]
    fn function_call_distinct() {
        let expr = parse_expr("count(DISTINCT n)").unwrap();
        assert!(matches!(
            expr,
            Expression::FunctionCall { distinct: true, .. }
        ));
    }

    #[test]
    fn count_star() {
        let expr = parse_expr("count(*)").unwrap();
        assert!(matches!(expr, Expression::CountStar));
    }

    #[test]
    fn list_literal() {
        let expr = parse_expr("[1, 2, 3]").unwrap();
        if let Expression::ListLiteral(items) = expr {
            assert_eq!(items.len(), 3);
        } else {
            panic!("expected list literal");
        }
    }

    #[test]
    fn case_expression_simple() {
        let expr = parse_expr("CASE WHEN x > 0 THEN 'pos' ELSE 'neg' END").unwrap();
        assert!(matches!(expr, Expression::Case { .. }));
    }

    #[test]
    fn nested_parentheses() {
        let expr = parse_expr("((1 + 2))").unwrap();
        assert!(matches!(expr, Expression::BinaryOp { .. }));
    }

    #[test]
    fn unary_minus() {
        let expr = parse_expr("-42").unwrap();
        assert!(matches!(
            expr,
            Expression::UnaryOp {
                op: UnaryOp::Minus,
                ..
            }
        ));
    }

    #[test]
    fn starts_with() {
        let expr = parse_expr("name STARTS WITH 'Al'").unwrap();
        assert!(matches!(
            expr,
            Expression::StringOp {
                op: StringOp::StartsWith,
                ..
            }
        ));
    }

    #[test]
    fn contains() {
        let expr = parse_expr("name CONTAINS 'foo'").unwrap();
        assert!(matches!(
            expr,
            Expression::StringOp {
                op: StringOp::Contains,
                ..
            }
        ));
    }

    #[test]
    fn map_literal() {
        let expr = parse_expr("{name: 'Alice', age: 30}").unwrap();
        assert!(matches!(expr, Expression::MapLiteral(_)));
    }

    #[test]
    fn chained_property_access() {
        let expr = parse_expr("a.b.c").unwrap();
        if let Expression::Property { object, key } = &expr {
            assert_eq!(key.0.as_str(), "c");
            assert!(matches!(object.0, Expression::Property { .. }));
        } else {
            panic!("expected property chain");
        }
    }
}