1use chumsky::prelude::*;
18use smol_str::SmolStr;
19
20use crate::ast::*;
21use crate::span::Spanned;
22use crate::token::Token;
23
24use super::pattern::ident;
25
/// Token type consumed by every parser in this module.
type ParserInput = Token;
/// Error type produced by every parser in this module.
type ParserError = Simple<Token>;
28
29pub fn expression_parser()
31-> impl Parser<ParserInput, Spanned<Expression>, Error = ParserError> + Clone {
32 recursive(|expr| {
36 let primary = primary_expr(expr.clone());
37 let postfix = postfix_expr(primary, expr.clone()).boxed();
38 let unary = unary_expr(postfix);
39 let power = power_expr(unary);
40 let mul_div = mul_div_expr(power).boxed();
41 let add_sub = add_sub_expr(mul_div);
42 let string_list = string_list_expr(add_sub, expr.clone());
43 let comparison = comparison_expr(string_list).boxed();
44 let not = not_expr(comparison);
45 let and = and_expr(not);
46 let xor = xor_expr(and);
47 or_expr(xor).boxed()
48 })
49}
50
51fn primary_expr(
54 expr: impl Parser<ParserInput, Spanned<Expression>, Error = ParserError> + Clone + 'static,
55) -> impl Parser<ParserInput, Spanned<Expression>, Error = ParserError> + Clone {
56 let integer = select! { Token::Integer(n) => Expression::Literal(Literal::Integer(n)) };
57 let float = select! { Token::Float(s) => {
58 let f: f64 = s.parse().unwrap_or(0.0);
59 Expression::Literal(Literal::Float(f))
60 }};
61 let string_lit = select! { Token::StringLiteral(s) => Expression::Literal(Literal::String(s)) };
62 let bool_true = just(Token::True).to(Expression::Literal(Literal::Bool(true)));
63 let bool_false = just(Token::False).to(Expression::Literal(Literal::Bool(false)));
64 let null = just(Token::Null).to(Expression::Literal(Literal::Null));
65
66 let variable = ident().map(Expression::Variable);
69
70 let parameter = select! { Token::Parameter(name) => Expression::Parameter(name) };
71
72 let count_star = just(Token::Count)
74 .then(just(Token::LeftParen))
75 .then(just(Token::Star))
76 .then(just(Token::RightParen))
77 .to(Expression::CountStar);
78
79 let list_comprehension = just(Token::LeftBracket)
82 .ignore_then(ident().map_with_span(|n, s| (n, s)))
83 .then_ignore(just(Token::In))
84 .then(expr.clone())
85 .then(just(Token::Where).ignore_then(expr.clone()).or_not())
86 .then(just(Token::Pipe).ignore_then(expr.clone()).or_not())
87 .then_ignore(just(Token::RightBracket))
88 .map(
89 |(((variable, list), filter), projection)| Expression::ListComprehension {
90 variable,
91 list: Box::new(list),
92 filter: filter.map(Box::new),
93 projection: projection.map(Box::new),
94 },
95 );
96
97 let list_literal = expr
99 .clone()
100 .separated_by(just(Token::Comma))
101 .allow_trailing()
102 .delimited_by(just(Token::LeftBracket), just(Token::RightBracket))
103 .map(Expression::ListLiteral);
104
105 let map_entry = ident()
107 .map_with_span(|n, s| (n, s))
108 .then_ignore(just(Token::Colon))
109 .then(expr.clone());
110
111 let map_literal = map_entry
112 .separated_by(just(Token::Comma))
113 .allow_trailing()
114 .delimited_by(just(Token::LeftBrace), just(Token::RightBrace))
115 .map(Expression::MapLiteral);
116
117 let case_expr = case_expression(expr.clone());
119
120 let paren = expr
122 .clone()
123 .delimited_by(just(Token::LeftParen), just(Token::RightParen))
124 .map(|(e, _span)| e);
125
126 let literals = choice((
128 count_star, integer, float, string_lit, bool_true, bool_false, null,
129 ))
130 .boxed();
131
132 let compound = choice((
133 parameter,
134 case_expr,
135 list_comprehension,
136 list_literal,
137 map_literal,
138 paren,
139 variable,
140 ))
141 .boxed();
142
143 literals.or(compound).map_with_span(|e, s| (e, s))
144}
145
146fn postfix_expr(
149 primary: impl Parser<ParserInput, Spanned<Expression>, Error = ParserError> + Clone + 'static,
150 expr: impl Parser<ParserInput, Spanned<Expression>, Error = ParserError> + Clone + 'static,
151) -> impl Parser<ParserInput, Spanned<Expression>, Error = ParserError> + Clone {
152 let distinct = just(Token::Distinct).or_not().map(|d| d.is_some());
157
158 let func_args = distinct
159 .then(
160 expr.clone()
161 .separated_by(just(Token::Comma))
162 .allow_trailing(),
163 )
164 .delimited_by(just(Token::LeftParen), just(Token::RightParen));
165
166 let subscript = expr
168 .clone()
169 .delimited_by(just(Token::LeftBracket), just(Token::RightBracket));
170
171 enum Postfix {
172 Property(Spanned<SmolStr>),
173 Subscript(Spanned<Expression>),
174 FuncCall(bool, Vec<Spanned<Expression>>),
175 }
176
177 let property = just(Token::Dot)
178 .ignore_then(ident().map_with_span(|n, s| (n, s)))
179 .map(Postfix::Property);
180
181 let sub = subscript.map(Postfix::Subscript);
182
183 let call = func_args.map(|(d, args)| Postfix::FuncCall(d, args));
184
185 primary
186 .then(choice((property, call, sub)).repeated())
187 .foldl(|base, postfix| {
188 let span_start = base.1.start;
189 match postfix {
190 Postfix::Property(key) => {
191 let span_end = key.1.end;
192 (
193 Expression::Property {
194 object: Box::new(base),
195 key,
196 },
197 span_start..span_end,
198 )
199 }
200 Postfix::Subscript(index) => {
201 let span_end = index.1.end;
202 (
203 Expression::Subscript {
204 expr: Box::new(base),
205 index: Box::new(index),
206 },
207 span_start..span_end,
208 )
209 }
210 Postfix::FuncCall(distinct, args) => {
211 let name = match &base.0 {
213 Expression::Variable(n) => vec![(n.clone(), base.1.clone())],
214 Expression::Property { object, key } => {
215 let mut names = Vec::new();
217 if let Expression::Variable(n) = &object.0 {
218 names.push((n.clone(), object.1.clone()));
219 }
220 names.push(key.clone());
221 names
222 }
223 _ => vec![(SmolStr::new("<unknown>"), base.1.clone())],
224 };
225 let span_end = args.last().map(|a| a.1.end).unwrap_or(base.1.end) + 1;
226 (
227 Expression::FunctionCall {
228 name,
229 distinct,
230 args,
231 },
232 span_start..span_end,
233 )
234 }
235 }
236 })
237}
238
239fn unary_expr(
242 inner: impl Parser<ParserInput, Spanned<Expression>, Error = ParserError> + Clone,
243) -> impl Parser<ParserInput, Spanned<Expression>, Error = ParserError> + Clone {
244 let minus = just(Token::Dash).to(UnaryOp::Minus);
245 let bitwise_not = just(Token::Tilde).to(UnaryOp::BitwiseNot);
246
247 let op = minus.or(bitwise_not);
248
249 op.map_with_span(|op, s: std::ops::Range<usize>| (op, s))
250 .repeated()
251 .then(inner)
252 .foldr(|(op, op_span), operand| {
253 let span = op_span.start..operand.1.end;
254 (
255 Expression::UnaryOp {
256 op,
257 operand: Box::new(operand),
258 },
259 span,
260 )
261 })
262}
263
264fn power_expr(
267 inner: impl Parser<ParserInput, Spanned<Expression>, Error = ParserError> + Clone,
268) -> impl Parser<ParserInput, Spanned<Expression>, Error = ParserError> + Clone {
269 inner
270 .clone()
271 .then(just(Token::Caret).ignore_then(inner).repeated())
272 .foldl(|left, right| {
273 let span = left.1.start..right.1.end;
274 (
275 Expression::BinaryOp {
276 left: Box::new(left),
277 op: BinaryOp::Pow,
278 right: Box::new(right),
279 },
280 span,
281 )
282 })
283}
284
285fn mul_div_expr(
288 inner: impl Parser<ParserInput, Spanned<Expression>, Error = ParserError> + Clone,
289) -> impl Parser<ParserInput, Spanned<Expression>, Error = ParserError> + Clone {
290 let op = choice((
291 just(Token::Star).to(BinaryOp::Mul),
292 just(Token::Slash).to(BinaryOp::Div),
293 just(Token::Percent).to(BinaryOp::Mod),
294 ));
295
296 inner
297 .clone()
298 .then(op.then(inner).repeated())
299 .foldl(|left, (op, right)| {
300 let span = left.1.start..right.1.end;
301 (
302 Expression::BinaryOp {
303 left: Box::new(left),
304 op,
305 right: Box::new(right),
306 },
307 span,
308 )
309 })
310}
311
312fn add_sub_expr(
315 inner: impl Parser<ParserInput, Spanned<Expression>, Error = ParserError> + Clone,
316) -> impl Parser<ParserInput, Spanned<Expression>, Error = ParserError> + Clone {
317 let op = choice((
318 just(Token::Plus).to(BinaryOp::Add),
319 just(Token::Dash).to(BinaryOp::Sub),
320 ));
321
322 inner
323 .clone()
324 .then(op.then(inner).repeated())
325 .foldl(|left, (op, right)| {
326 let span = left.1.start..right.1.end;
327 (
328 Expression::BinaryOp {
329 left: Box::new(left),
330 op,
331 right: Box::new(right),
332 },
333 span,
334 )
335 })
336}
337
338fn string_list_expr(
341 inner: impl Parser<ParserInput, Spanned<Expression>, Error = ParserError> + Clone + 'static,
342 _full_expr: impl Parser<ParserInput, Spanned<Expression>, Error = ParserError> + Clone + 'static,
343) -> impl Parser<ParserInput, Spanned<Expression>, Error = ParserError> + Clone {
344 enum PostfixStringOp {
345 StartsWith(Spanned<Expression>),
346 EndsWith(Spanned<Expression>),
347 Contains(Spanned<Expression>),
348 In(Spanned<Expression>),
349 IsNull(bool),
350 HasLabel(Vec<Spanned<SmolStr>>),
351 }
352
353 let starts_with = just(Token::Starts)
354 .ignore_then(just(Token::With))
355 .ignore_then(inner.clone())
356 .map(PostfixStringOp::StartsWith);
357
358 let ends_with = just(Token::Ends)
359 .ignore_then(just(Token::With))
360 .ignore_then(inner.clone())
361 .map(PostfixStringOp::EndsWith);
362
363 let contains = just(Token::Contains)
364 .ignore_then(inner.clone())
365 .map(PostfixStringOp::Contains);
366
367 let in_list = just(Token::In)
368 .ignore_then(inner.clone())
369 .map(PostfixStringOp::In);
370
371 let is_null = just(Token::Is)
372 .ignore_then(just(Token::Not).or_not())
373 .then_ignore(just(Token::Null))
374 .map(|not| PostfixStringOp::IsNull(not.is_some()));
375
376 let has_label = just(Token::Colon)
378 .ignore_then(ident().map_with_span(|n, s| (n, s)))
379 .repeated()
380 .at_least(1)
381 .map(PostfixStringOp::HasLabel);
382
383 inner
384 .then(
385 choice((
386 starts_with,
387 ends_with,
388 contains,
389 in_list,
390 is_null,
391 has_label,
392 ))
393 .repeated(),
394 )
395 .foldl(|left, op| match op {
396 PostfixStringOp::StartsWith(right) => {
397 let span = left.1.start..right.1.end;
398 (
399 Expression::StringOp {
400 left: Box::new(left),
401 op: StringOp::StartsWith,
402 right: Box::new(right),
403 },
404 span,
405 )
406 }
407 PostfixStringOp::EndsWith(right) => {
408 let span = left.1.start..right.1.end;
409 (
410 Expression::StringOp {
411 left: Box::new(left),
412 op: StringOp::EndsWith,
413 right: Box::new(right),
414 },
415 span,
416 )
417 }
418 PostfixStringOp::Contains(right) => {
419 let span = left.1.start..right.1.end;
420 (
421 Expression::StringOp {
422 left: Box::new(left),
423 op: StringOp::Contains,
424 right: Box::new(right),
425 },
426 span,
427 )
428 }
429 PostfixStringOp::In(right) => {
430 let span = left.1.start..right.1.end;
431 (
432 Expression::InList {
433 expr: Box::new(left),
434 list: Box::new(right),
435 negated: false,
436 },
437 span,
438 )
439 }
440 PostfixStringOp::IsNull(negated) => {
441 let span = left.1.clone();
442 (
443 Expression::IsNull {
444 expr: Box::new(left),
445 negated,
446 },
447 span,
448 )
449 }
450 PostfixStringOp::HasLabel(labels) => {
451 let span_end = labels.last().map(|l| l.1.end).unwrap_or(left.1.end);
452 let span = left.1.start..span_end;
453 (
454 Expression::HasLabel {
455 expr: Box::new(left),
456 labels,
457 },
458 span,
459 )
460 }
461 })
462}
463
464fn comparison_expr(
467 inner: impl Parser<ParserInput, Spanned<Expression>, Error = ParserError> + Clone,
468) -> impl Parser<ParserInput, Spanned<Expression>, Error = ParserError> + Clone {
469 let op = choice((
470 just(Token::Eq).to(ComparisonOp::Eq),
471 just(Token::Neq).to(ComparisonOp::Neq),
472 just(Token::Le).to(ComparisonOp::Le),
473 just(Token::Lt).to(ComparisonOp::Lt),
474 just(Token::Ge).to(ComparisonOp::Ge),
475 just(Token::Gt).to(ComparisonOp::Gt),
476 just(Token::RegexMatch).to(ComparisonOp::RegexMatch),
477 ));
478
479 inner
480 .clone()
481 .then(op.then(inner).repeated())
482 .map_with_span(|(left, ops), span| {
483 if ops.is_empty() {
484 left
485 } else {
486 (
487 Expression::Comparison {
488 left: Box::new(left),
489 ops,
490 },
491 span,
492 )
493 }
494 })
495}
496
497fn not_expr(
500 inner: impl Parser<ParserInput, Spanned<Expression>, Error = ParserError> + Clone,
501) -> impl Parser<ParserInput, Spanned<Expression>, Error = ParserError> + Clone {
502 just(Token::Not)
503 .map_with_span(|_, s: std::ops::Range<usize>| s)
504 .repeated()
505 .then(inner)
506 .foldr(|op_span, operand| {
507 let span = op_span.start..operand.1.end;
508 (
509 Expression::UnaryOp {
510 op: UnaryOp::Not,
511 operand: Box::new(operand),
512 },
513 span,
514 )
515 })
516}
517
518fn and_expr(
521 inner: impl Parser<ParserInput, Spanned<Expression>, Error = ParserError> + Clone,
522) -> impl Parser<ParserInput, Spanned<Expression>, Error = ParserError> + Clone {
523 inner
524 .clone()
525 .then(just(Token::And).ignore_then(inner).repeated())
526 .foldl(|left, right| {
527 let span = left.1.start..right.1.end;
528 (
529 Expression::BinaryOp {
530 left: Box::new(left),
531 op: BinaryOp::And,
532 right: Box::new(right),
533 },
534 span,
535 )
536 })
537}
538
539fn xor_expr(
542 inner: impl Parser<ParserInput, Spanned<Expression>, Error = ParserError> + Clone,
543) -> impl Parser<ParserInput, Spanned<Expression>, Error = ParserError> + Clone {
544 inner
545 .clone()
546 .then(just(Token::Xor).ignore_then(inner).repeated())
547 .foldl(|left, right| {
548 let span = left.1.start..right.1.end;
549 (
550 Expression::BinaryOp {
551 left: Box::new(left),
552 op: BinaryOp::Xor,
553 right: Box::new(right),
554 },
555 span,
556 )
557 })
558}
559
560fn or_expr(
563 inner: impl Parser<ParserInput, Spanned<Expression>, Error = ParserError> + Clone,
564) -> impl Parser<ParserInput, Spanned<Expression>, Error = ParserError> + Clone {
565 inner
566 .clone()
567 .then(just(Token::Or).ignore_then(inner).repeated())
568 .foldl(|left, right| {
569 let span = left.1.start..right.1.end;
570 (
571 Expression::BinaryOp {
572 left: Box::new(left),
573 op: BinaryOp::Or,
574 right: Box::new(right),
575 },
576 span,
577 )
578 })
579}
580
581fn case_expression(
584 expr: impl Parser<ParserInput, Spanned<Expression>, Error = ParserError> + Clone + 'static,
585) -> impl Parser<ParserInput, Expression, Error = ParserError> + Clone {
586 let when_clause = just(Token::When)
587 .ignore_then(expr.clone())
588 .then_ignore(just(Token::Then))
589 .then(expr.clone());
590
591 let else_clause = just(Token::Else).ignore_then(expr.clone());
592
593 just(Token::Case)
594 .ignore_then(expr.clone().map(Box::new).or_not())
595 .then(when_clause.repeated().at_least(1))
596 .then(else_clause.or_not())
597 .then_ignore(just(Token::End))
598 .map(|((operand, whens), else_expr)| Expression::Case {
599 operand,
600 whens,
601 else_expr: else_expr.map(Box::new),
602 })
603}
604
/// Unit tests for the expression parser. Each test lexes a source snippet,
/// parses it as one complete expression and checks the shape of the AST.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::lexer::Lexer;

    /// Lexes and parses `src` as a single expression that must consume all
    /// input. Returns the expression without its span, or `None` when the
    /// parser produced no output. Parse errors are printed to stderr to help
    /// diagnose a failing test.
    fn parse_expr(src: &str) -> Option<Expression> {
        let (tokens, lex_errors) = Lexer::new(src).lex();
        assert!(lex_errors.is_empty(), "lex errors: {lex_errors:?}");

        // The end-of-input span sits just past the last source byte.
        let len = src.len();
        let stream = chumsky::Stream::from_iter(
            len..len + 1,
            tokens
                .into_iter()
                .filter(|(tok, _)| !matches!(tok, Token::Eof)),
        );

        let (result, errors) = expression_parser()
            .then_ignore(end())
            .parse_recovery(stream);
        if !errors.is_empty() {
            eprintln!("parse errors: {errors:?}");
        }
        result.map(|(expr, _)| expr)
    }

    #[test]
    fn integer_literal() {
        let expr = parse_expr("42").unwrap();
        assert!(matches!(expr, Expression::Literal(Literal::Integer(42))));
    }

    #[test]
    fn float_literal() {
        let expr = parse_expr("3.14").unwrap();
        if let Expression::Literal(Literal::Float(f)) = expr {
            assert!((f - 3.14).abs() < 1e-10);
        } else {
            panic!("expected float literal");
        }
    }

    #[test]
    fn string_literal() {
        let expr = parse_expr("'hello'").unwrap();
        assert!(matches!(
            expr,
            Expression::Literal(Literal::String(ref s)) if s == "hello"
        ));
    }

    #[test]
    fn boolean_literals() {
        assert!(matches!(
            parse_expr("TRUE").unwrap(),
            Expression::Literal(Literal::Bool(true))
        ));
        assert!(matches!(
            parse_expr("FALSE").unwrap(),
            Expression::Literal(Literal::Bool(false))
        ));
    }

    #[test]
    fn null_literal() {
        assert!(matches!(
            parse_expr("NULL").unwrap(),
            Expression::Literal(Literal::Null)
        ));
    }

    #[test]
    fn variable() {
        let expr = parse_expr("n").unwrap();
        assert!(matches!(expr, Expression::Variable(ref name) if name == "n"));
    }

    #[test]
    fn parameter() {
        let expr = parse_expr("$since").unwrap();
        assert!(matches!(expr, Expression::Parameter(ref name) if name == "since"));
    }

    #[test]
    fn property_access() {
        let expr = parse_expr("n.age").unwrap();
        assert!(matches!(expr, Expression::Property { .. }));
    }

    #[test]
    fn binary_arithmetic() {
        let expr = parse_expr("1 + 2 * 3").unwrap();
        if let Expression::BinaryOp { left, op, right } = &expr {
            // `*` binds tighter than `+`: the root is the addition and the
            // multiplication is its right operand.
            assert_eq!(*op, BinaryOp::Add);
            assert!(matches!(
                left.0,
                Expression::Literal(Literal::Integer(1))
            ));
            assert!(matches!(
                right.0,
                Expression::BinaryOp {
                    op: BinaryOp::Mul,
                    ..
                }
            ));
        } else {
            panic!("expected binary op, got {expr:?}");
        }
    }

    #[test]
    fn comparison() {
        let expr = parse_expr("n.age > 30").unwrap();
        assert!(matches!(expr, Expression::Comparison { .. }));
    }

    #[test]
    fn logical_and_or() {
        let expr = parse_expr("a AND b OR c").unwrap();
        if let Expression::BinaryOp { left, op, right } = &expr {
            // `AND` binds tighter than `OR`: the root is the disjunction and
            // the conjunction is its left operand.
            assert_eq!(*op, BinaryOp::Or);
            assert!(matches!(
                left.0,
                Expression::BinaryOp {
                    op: BinaryOp::And,
                    ..
                }
            ));
            assert!(matches!(right.0, Expression::Variable(_)));
        } else {
            panic!("expected OR");
        }
    }

    #[test]
    fn not_expression() {
        let expr = parse_expr("NOT x").unwrap();
        assert!(matches!(
            expr,
            Expression::UnaryOp {
                op: UnaryOp::Not,
                ..
            }
        ));
    }

    #[test]
    fn is_null() {
        let expr = parse_expr("x IS NULL").unwrap();
        assert!(matches!(expr, Expression::IsNull { negated: false, .. }));
    }

    #[test]
    fn is_not_null() {
        let expr = parse_expr("x IS NOT NULL").unwrap();
        assert!(matches!(expr, Expression::IsNull { negated: true, .. }));
    }

    #[test]
    fn in_list() {
        let expr = parse_expr("x IN [1, 2, 3]").unwrap();
        assert!(matches!(expr, Expression::InList { negated: false, .. }));
    }

    #[test]
    fn function_call() {
        let expr = parse_expr("count(n)").unwrap();
        assert!(matches!(expr, Expression::FunctionCall { .. }));
    }

    #[test]
    fn count_star() {
        let expr = parse_expr("count(*)").unwrap();
        assert!(matches!(expr, Expression::CountStar));
    }

    #[test]
    fn list_literal() {
        let expr = parse_expr("[1, 2, 3]").unwrap();
        if let Expression::ListLiteral(items) = expr {
            assert_eq!(items.len(), 3);
        } else {
            panic!("expected list literal");
        }
    }

    #[test]
    fn case_expression_simple() {
        let expr = parse_expr("CASE WHEN x > 0 THEN 'pos' ELSE 'neg' END").unwrap();
        if let Expression::Case {
            operand,
            whens,
            else_expr,
        } = &expr
        {
            // A searched CASE: no operand, one branch and a default.
            assert!(operand.is_none());
            assert_eq!(whens.len(), 1);
            assert!(else_expr.is_some());
        } else {
            panic!("expected CASE expression");
        }
    }

    #[test]
    fn nested_parentheses() {
        // Parentheses leave no node of their own in the AST.
        let expr = parse_expr("((1 + 2))").unwrap();
        assert!(matches!(
            expr,
            Expression::BinaryOp {
                op: BinaryOp::Add,
                ..
            }
        ));
    }

    #[test]
    fn unary_minus() {
        let expr = parse_expr("-42").unwrap();
        assert!(matches!(
            expr,
            Expression::UnaryOp {
                op: UnaryOp::Minus,
                ..
            }
        ));
    }

    #[test]
    fn starts_with() {
        let expr = parse_expr("name STARTS WITH 'Al'").unwrap();
        assert!(matches!(
            expr,
            Expression::StringOp {
                op: StringOp::StartsWith,
                ..
            }
        ));
    }

    #[test]
    fn contains() {
        let expr = parse_expr("name CONTAINS 'foo'").unwrap();
        assert!(matches!(
            expr,
            Expression::StringOp {
                op: StringOp::Contains,
                ..
            }
        ));
    }

    #[test]
    fn map_literal() {
        let expr = parse_expr("{name: 'Alice', age: 30}").unwrap();
        assert!(matches!(expr, Expression::MapLiteral(_)));
    }

    #[test]
    fn chained_property_access() {
        let expr = parse_expr("a.b.c").unwrap();
        if let Expression::Property { object, key } = &expr {
            // Property access nests to the left: the outermost node is `.c`.
            assert_eq!(key.0.as_str(), "c");
            assert!(matches!(object.0, Expression::Property { .. }));
        } else {
            panic!("expected property chain");
        }
    }
}