// cas_parser/parser/ast/expr.rs

1use cas_error::Error;
2use crate::{
3    parser::{
4        ast::{
5            assign::Assign,
6            binary::Binary,
7            block::Block,
8            branch::{Of, Then},
9            call::Call,
10            for_expr::For,
11            if_expr::If,
12            index::Index,
13            literal::Literal,
14            loop_expr::{Break, Continue, Loop},
15            paren::Paren,
16            product::Product,
17            range::Range,
18            return_expr::Return,
19            sum::Sum,
20            unary::Unary,
21            while_expr::While,
22        },
23        error::{ExpectedExpr, UnclosedParenthesis},
24        fmt::Latex,
25        iter::ExprIter,
26        token::{op::Precedence, CloseParen},
27        Parse,
28        ParseResult,
29        Parser,
30    },
31    tokenizer::TokenKind,
32    return_if_ok,
33};
34use std::fmt;
35
36#[cfg(feature = "serde")]
37use serde::{Deserialize, Serialize};
38
39/// Represents any kind of expression in CalcScript.
40///
41/// An expression is any valid piece of code that can be evaluated to produce a value. Expressions
42/// can be used as the right-hand side of an assignment, or as the argument to a function call.
43#[derive(Debug, Clone, PartialEq, Eq)]
44#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
45pub enum Expr {
46    /// A literal value.
47    Literal(Literal),
48
49    /// A parenthesized expression, such as `(1 + 2)`.
50    Paren(Paren),
51
52    /// A blocked expression, such as `{1 + 2}`.
53    Block(Block),
54
55    /// A sum expression, such as `sum n in 1..10 of n`.
56    Sum(Sum),
57
58    /// A product expression, such as `product n in 1..10 of n`.
59    Product(Product),
60
61    /// An if expression, such as `if x > 0 then x else -x`.
62    If(If),
63
64    /// A loop expression, as in `loop { ... }`.
65    Loop(Loop),
66
67    /// A while loop expression, as in `while x > 0 then { ... }`.
68    While(While),
69
70    /// A for loop expression, as in `for i in 0..10 then print(i)`.
71    For(For),
72
73    /// A then expression, as in `then x += 1`.
74    Then(Then),
75
76    /// An of expression, as in `of x`.
77    Of(Of),
78
79    /// A break expression, used to exit a loop, optionally with a value.
80    Break(Break),
81
82    /// A continue expression, used to skip the rest of a loop iteration.
83    Continue(Continue),
84
85    /// A return expression, as in `return x`, used to return a value from a function.
86    Return(Return),
87
88    /// A function call, such as `abs(-1)`.
89    Call(Call),
90
91    /// List indexing, such as `list[0]`.
92    Index(Index),
93
94    /// A unary operation, such as `-1` or `!true`.
95    Unary(Unary),
96
97    /// A binary operation, such as `1 + 2`.
98    Binary(Binary),
99
100    /// An assignment of a variable or function, such as `x = 1` or `f(x) = x^2`.
101    Assign(Assign),
102
103    /// A range expression, such as `1..10`.
104    Range(Range),
105}
106
107impl Expr {
108    /// Returns the span of the expression.
109    pub fn span(&self) -> std::ops::Range<usize> {
110        match self {
111            Expr::Literal(literal) => literal.span(),
112            Expr::Paren(paren) => paren.span(),
113            Expr::Block(block) => block.span(),
114            Expr::Sum(sum) => sum.span(),
115            Expr::Product(product) => product.span(),
116            Expr::If(if_expr) => if_expr.span(),
117            Expr::Loop(loop_expr) => loop_expr.span(),
118            Expr::While(while_expr) => while_expr.span(),
119            Expr::For(for_expr) => for_expr.span(),
120            Expr::Then(then) => then.span(),
121            Expr::Of(of) => of.span(),
122            Expr::Break(break_expr) => break_expr.span(),
123            Expr::Continue(continue_expr) => continue_expr.span(),
124            Expr::Return(return_expr) => return_expr.span(),
125            Expr::Call(call) => call.span(),
126            Expr::Index(index) => index.span(),
127            Expr::Unary(unary) => unary.span(),
128            Expr::Binary(binary) => binary.span(),
129            Expr::Assign(assign) => assign.span(),
130            Expr::Range(range) => range.span(),
131        }
132    }
133
134    /// Returns an iterator that traverses the tree of expressions in left-to-right post-order
135    /// (i.e. depth-first).
136    pub fn post_order_iter(&self) -> ExprIter {
137        ExprIter::new(self)
138    }
139
140    /// If this expression is an [`Expr::Paren`], returns the innermost expression in the
141    /// parenthesized expression. Otherwise, returns `self`.
142    pub fn innermost(&self) -> &Expr {
143        let mut inner = self;
144        while let Expr::Paren(paren) = inner {
145            inner = &paren.expr;
146        }
147        inner
148    }
149
150    /// Returns true if the given expression can be used as a target for implicit multiplication.
151    pub fn is_implicit_mul_target(&self) -> bool {
152        // TODO: there may be more reasonable targets
153        matches!(self,
154            Expr::Literal(Literal::Integer(_))
155                | Expr::Literal(Literal::Float(_))
156                | Expr::Literal(Literal::Radix(_))
157                | Expr::Literal(Literal::Symbol(_))
158                | Expr::Paren(_)
159                | Expr::Call(_)
160                | Expr::Unary(_)
161        )
162    }
163}
164
165impl<'source> Parse<'source> for Expr {
166    fn std_parse(
167        input: &mut Parser<'source>,
168        recoverable_errors: &mut Vec<Error>
169    ) -> Result<Self, Vec<Error>> {
170        if input.clone().try_parse::<CloseParen>().is_ok() {
171            return Err(vec![input.error(UnclosedParenthesis { opening: false })]);
172        }
173
174        let _ = return_if_ok!(input.try_parse().map(Self::Assign).forward_errors(recoverable_errors));
175        let lhs = Unary::parse_or_lower(input, recoverable_errors)?;
176        Ok(Binary::parse_expr(input, recoverable_errors, lhs, Precedence::Any)?.0)
177    }
178}
179
180/// Implements [`Parse`] for an [`Expr`] variant by parsing an [`Expr`] and then converting it to the
181/// variant.
182///
183/// This is done for completeness.
184macro_rules! impl_by_parsing_expr {
185    ($( $variant:ident $expected:literal ),* $(,)?) => {
186        $(
187            impl<'source> Parse<'source> for $variant {
188                fn std_parse(
189                    input: &mut Parser<'source>,
190                    recoverable_errors: &mut Vec<Error>
191                ) -> Result<Self, Vec<Error>> {
192                    match input.try_parse::<Expr>().forward_errors(recoverable_errors) {
193                        Ok(Expr::$variant(expr)) => Ok(expr),
194                        _ => Err(vec![input.error(ExpectedExpr { expected: $expected })]),
195                    }
196                }
197            }
198        )*
199    };
200}
201
202impl_by_parsing_expr!(
203    Call "a function call",
204    Index "a list indexing expression",
205    Unary "a unary operation",
206    Range "a range expression"
207);
208
209impl std::fmt::Display for Expr {
210    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
211        match self {
212            Expr::Literal(literal) => literal.fmt(f),
213            Expr::Paren(paren) => paren.fmt(f),
214            Expr::Block(block) => block.fmt(f),
215            Expr::Sum(sum) => sum.fmt(f),
216            Expr::Product(product) => product.fmt(f),
217            Expr::If(if_expr) => if_expr.fmt(f),
218            Expr::Loop(loop_expr) => loop_expr.fmt(f),
219            Expr::While(while_expr) => while_expr.fmt(f),
220            Expr::For(for_expr) => for_expr.fmt(f),
221            Expr::Then(then) => then.fmt(f),
222            Expr::Of(of) => of.fmt(f),
223            Expr::Break(break_expr) => break_expr.fmt(f),
224            Expr::Continue(continue_expr) => continue_expr.fmt(f),
225            Expr::Return(return_expr) => return_expr.fmt(f),
226            Expr::Call(call) => call.fmt(f),
227            Expr::Index(index) => index.fmt(f),
228            Expr::Unary(unary) => unary.fmt(f),
229            Expr::Binary(binary) => binary.fmt(f),
230            Expr::Assign(assign) => assign.fmt(f),
231            Expr::Range(range) => range.fmt(f),
232        }
233    }
234}
235
236impl Latex for Expr {
237    fn fmt_latex(&self, f: &mut fmt::Formatter) -> fmt::Result {
238        match self {
239            Expr::Literal(literal) => literal.fmt_latex(f),
240            Expr::Paren(paren) => paren.fmt_latex(f),
241            Expr::Block(block) => block.fmt_latex(f),
242            Expr::Sum(sum) => sum.fmt_latex(f),
243            Expr::Product(product) => product.fmt_latex(f),
244            Expr::If(if_expr) => if_expr.fmt_latex(f),
245            Expr::Loop(loop_expr) => loop_expr.fmt_latex(f),
246            Expr::While(while_expr) => while_expr.fmt_latex(f),
247            Expr::For(for_expr) => for_expr.fmt_latex(f),
248            Expr::Then(then) => then.fmt_latex(f),
249            Expr::Of(of) => of.fmt_latex(f),
250            Expr::Break(break_expr) => break_expr.fmt_latex(f),
251            Expr::Continue(continue_expr) => continue_expr.fmt_latex(f),
252            Expr::Return(return_expr) => return_expr.fmt_latex(f),
253            Expr::Call(call) => call.fmt_latex(f),
254            Expr::Index(index) => index.fmt_latex(f),
255            Expr::Unary(unary) => unary.fmt_latex(f),
256            Expr::Binary(binary) => binary.fmt_latex(f),
257            Expr::Assign(assign) => assign.fmt_latex(f),
258            Expr::Range(range) => range.fmt_latex(f),
259        }
260    }
261}
262
263/// Represents a primary expression in CalcScript.
264///
265/// Primary expressions extend the concept of [`Atom`] by allowing for more complex and ambiguous
266/// expressions that are still self-contained. These extensions include function calls and list
267/// indexing expressions, which can be ambiguous when encountered in isolation.
268///
269/// For example, when trying to parse a [`Primary`], a literal value like `abc` cannot
270/// automatically be declared a [`Primary::Literal`]. Instead, we must parse forward a little more
271/// to see if this is actually calling a function named `abc`, or indexing into a list named `abc`,
272/// or neither.
273#[derive(Debug, Clone, PartialEq, Eq)]
274pub enum Primary {
275    /// A literal value.
276    Literal(Literal),
277
278    /// A parenthesized expression, such as `(1 + 2)`.
279    Paren(Paren),
280
281    /// A blocked expression, such as `{1 + 2}`.
282    Block(Block),
283
284    /// A sum expression, such as `sum n in 1..10 of n`.
285    Sum(Sum),
286
287    /// A product expression, such as `product n in 1..10 of n`.
288    Product(Product),
289
290    /// An if expression, such as `if x > 0 then x else -x`.
291    If(If),
292
293    /// A loop expression, as in `loop { ... }`.
294    Loop(Loop),
295
296    /// A while loop expression, as in `while x > 0 then { ... }`.
297    While(While),
298
299    /// A for loop expression, as in `for i in 0..10 then print(i)`.
300    For(For),
301
302    /// A then expression, as in `then x += 1`.
303    Then(Then),
304
305    /// An of expression, as in `of x`.
306    Of(Of),
307
308    /// A break expression, used to exit a loop, optionally with a value.
309    Break(Break),
310
311    /// A continue expression, used to skip the rest of a loop iteration.
312    Continue(Continue),
313
314    /// A return expression, as in `return x`, used to return a value from a function.
315    Return(Return),
316
317    /// A function call, such as `abs(-1)`.
318    Call(Call),
319
320    /// List indexing, such as `list[0]`.
321    Index(Index),
322}
323
324impl Primary {
325    /// Returns the span of the primary expression.
326    pub fn span(&self) -> std::ops::Range<usize> {
327        match self {
328            Primary::Literal(literal) => literal.span(),
329            Primary::Paren(paren) => paren.span(),
330            Primary::Block(block) => block.span(),
331            Primary::Sum(sum) => sum.span(),
332            Primary::Product(product) => product.span(),
333            Primary::If(if_expr) => if_expr.span(),
334            Primary::Loop(loop_expr) => loop_expr.span(),
335            Primary::While(while_expr) => while_expr.span(),
336            Primary::For(for_expr) => for_expr.span(),
337            Primary::Then(then) => then.span(),
338            Primary::Of(of) => of.span(),
339            Primary::Break(break_expr) => break_expr.span(),
340            Primary::Continue(continue_expr) => continue_expr.span(),
341            Primary::Return(return_expr) => return_expr.span(),
342            Primary::Call(call) => call.span(),
343            Primary::Index(index) => index.span(),
344        }
345    }
346}
347
348impl<'source> Parse<'source> for Primary {
349    fn std_parse(
350        input: &mut Parser<'source>,
351        recoverable_errors: &mut Vec<Error>
352    ) -> Result<Self, Vec<Error>> {
353        let atom = input.try_parse::<Atom>().forward_errors(recoverable_errors)?;
354        let mut primary = Primary::from(atom);
355
356        loop {
357            let mut fork = input.clone();
358            match fork.next_token() {
359                Ok(next) if next.kind == TokenKind::OpenParen || next.kind == TokenKind::Quote => {
360                    match Call::parse_or_lower(input, recoverable_errors, primary)? {
361                        (new_primary, true) => primary = new_primary,
362                        // call was not parsed; is this implicit multiplication?
363                        (unchanged_primary, false) => break Ok(unchanged_primary),
364                    }
365                },
366                Ok(next) if next.kind == TokenKind::OpenSquare => {
367                    match Index::parse_or_lower(input, recoverable_errors, primary) {
368                        (new_primary, true) => primary = new_primary,
369                        (unchanged_primary, false) => break Ok(unchanged_primary),
370                    }
371                },
372                _ => break Ok(primary),
373            }
374        }
375    }
376}
377
378impl From<Primary> for Expr {
379    fn from(primary: Primary) -> Self {
380        match primary {
381            Primary::Literal(literal) => Self::Literal(literal),
382            Primary::Paren(paren) => Self::Paren(paren),
383            Primary::Block(block) => Self::Block(block),
384            Primary::Sum(sum) => Self::Sum(sum),
385            Primary::Product(product) => Self::Product(product),
386            Primary::If(if_expr) => Self::If(if_expr),
387            Primary::Loop(loop_expr) => Self::Loop(loop_expr),
388            Primary::While(while_expr) => Self::While(while_expr),
389            Primary::For(for_expr) => Self::For(for_expr),
390            Primary::Then(then) => Self::Then(then),
391            Primary::Of(of) => Self::Of(of),
392            Primary::Break(break_expr) => Self::Break(break_expr),
393            Primary::Continue(continue_expr) => Self::Continue(continue_expr),
394            Primary::Return(return_expr) => Self::Return(return_expr),
395            Primary::Call(call) => Self::Call(call),
396            Primary::Index(index) => Self::Index(index),
397        }
398    }
399}
400
401/// Represents an atom expression in CalcScript.
402///
403/// Atom expressions are the simplest kind of expression, and are entirely unambiguous to parse,
404/// meaning that they can be parsed without needing any context.
405///
406/// For example, a literal value like `1` or `true` has no ambiguity; when we encounter a numeric
407/// or boolean token, we know that it must be a literal value.
408///
409/// Some expressions, like `if` expressions or `loop` expressions, are also atom expressions,
410/// because they have a unique keyword that identifies them; when we encounter the `if` keyword, we
411/// automatically know there is only one correct way to parse the expression.
412///
413/// In addition, all atom expressions are self-contained. This means that atom expressions within a
414/// larger [`Expr`] can be replaced with semantically equivalent, but different variants of atom
415/// expressions, and the larger expression will still be valid.
416#[derive(Debug, Clone, PartialEq, Eq)]
417pub enum Atom {
418    /// A literal value.
419    Literal(Literal),
420
421    /// A parenthesized expression, such as `(1 + 2)`.
422    Paren(Paren),
423
424    /// A blocked expression, such as `{1 + 2}`.
425    Block(Block),
426
427    /// A sum expression, such as `sum n in 1..10 of n`.
428    Sum(Sum),
429
430    /// A product expression, such as `product n in 1..10 of n`.
431    Product(Product),
432
433    /// An if expression, such as `if x > 0 then x else -x`.
434    If(If),
435
436    /// A loop expression, as in `loop { ... }`.
437    Loop(Loop),
438
439    /// A while loop expression, as in `while x > 0 then { ... }`.
440    While(While),
441
442    /// A for loop expression, as in `for i in 0..10 then print(i)`.
443    For(For),
444
445    /// A then expression, as in `then x += 1`.
446    Then(Then),
447
448    /// An of expression, as in `of x`.
449    Of(Of),
450
451    /// A break expression, used to exit a loop, optionally with a value.
452    Break(Break),
453
454    /// A continue expression, used to skip the rest of a loop iteration.
455    Continue(Continue),
456
457    /// A return expression, as in `return x`, used to return a value from a function.
458    Return(Return),
459}
460
461impl<'source> Parse<'source> for Atom {
462    fn std_parse(
463        input: &mut Parser<'source>,
464        recoverable_errors: &mut Vec<Error>
465    ) -> Result<Self, Vec<Error>> {
466        #[inline]
467        fn parse_no_branch<'source, T: Parse<'source>>(input: &mut Parser<'source>) -> ParseResult<T> {
468            input.try_parse_with_state::<_, T>(|state| {
469                state.allow_then = false;
470                state.allow_of = false;
471            })
472        }
473
474        // definitely not every `parse_no_then` is needed, but it's safest to just try them all
475        // all this is to catch funny business like `if x > 0 { then x }`, where `then` is part
476        // of the body, not directly after the condition; that's invalid
477        let _ = return_if_ok!(parse_no_branch(input).map(Self::Literal).forward_errors(recoverable_errors));
478        let _ = return_if_ok!(parse_no_branch(input).map(Self::Paren).forward_errors(recoverable_errors));
479        let _ = return_if_ok!(parse_no_branch(input).map(Self::Block).forward_errors(recoverable_errors));
480        let _ = return_if_ok!(parse_no_branch(input).map(Self::Sum).forward_errors(recoverable_errors));
481        let _ = return_if_ok!(parse_no_branch(input).map(Self::Product).forward_errors(recoverable_errors));
482        let _ = return_if_ok!(parse_no_branch(input).map(Self::If).forward_errors(recoverable_errors));
483        let _ = return_if_ok!(parse_no_branch(input).map(Self::Loop).forward_errors(recoverable_errors));
484        let _ = return_if_ok!(parse_no_branch(input).map(Self::While).forward_errors(recoverable_errors));
485        let _ = return_if_ok!(parse_no_branch(input).map(Self::For).forward_errors(recoverable_errors));
486        let _ = return_if_ok!(input.try_parse_with_state::<_, _>(|state| {
487            state.allow_of = false;
488        }).map(Self::Then).forward_errors(recoverable_errors));
489        let _ = return_if_ok!(input.try_parse_with_state::<_, _>(|state| {
490            state.allow_then = false;
491        }).map(Self::Of).forward_errors(recoverable_errors));
492        let _ = return_if_ok!(parse_no_branch(input).map(Self::Break).forward_errors(recoverable_errors));
493        let _ = return_if_ok!(parse_no_branch(input).map(Self::Continue).forward_errors(recoverable_errors));
494        parse_no_branch(input).map(Self::Return).forward_errors(recoverable_errors)
495    }
496}
497
498impl From<Atom> for Primary {
499    fn from(atom: Atom) -> Self {
500        match atom {
501            Atom::Literal(literal) => Self::Literal(literal),
502            Atom::Paren(paren) => Self::Paren(paren),
503            Atom::Block(block) => Self::Block(block),
504            Atom::Sum(sum) => Self::Sum(sum),
505            Atom::Product(product) => Self::Product(product),
506            Atom::If(if_expr) => Self::If(if_expr),
507            Atom::Loop(loop_expr) => Self::Loop(loop_expr),
508            Atom::While(while_expr) => Self::While(while_expr),
509            Atom::For(for_expr) => Self::For(for_expr),
510            Atom::Then(then) => Self::Then(then),
511            Atom::Of(of) => Self::Of(of),
512            Atom::Break(break_expr) => Self::Break(break_expr),
513            Atom::Continue(continue_expr) => Self::Continue(continue_expr),
514            Atom::Return(return_expr) => Self::Return(return_expr),
515        }
516    }
517}
518
519impl From<Atom> for Expr {
520    fn from(atom: Atom) -> Self {
521        match atom {
522            Atom::Literal(literal) => Self::Literal(literal),
523            Atom::Paren(paren) => Self::Paren(paren),
524            Atom::Block(block) => Self::Block(block),
525            Atom::Sum(sum) => Self::Sum(sum),
526            Atom::Product(product) => Self::Product(product),
527            Atom::If(if_expr) => Self::If(if_expr),
528            Atom::Loop(loop_expr) => Self::Loop(loop_expr),
529            Atom::While(while_expr) => Self::While(while_expr),
530            Atom::For(for_expr) => Self::For(for_expr),
531            Atom::Then(then) => Self::Then(then),
532            Atom::Of(of) => Self::Of(of),
533            Atom::Break(break_expr) => Self::Break(break_expr),
534            Atom::Continue(continue_expr) => Self::Continue(continue_expr),
535            Atom::Return(return_expr) => Self::Return(return_expr),
536        }
537    }
538}