cas_parser/parser/ast/
assign.rs

1use crate::{
2    parser::{
3        ast::{expr::Expr, helper::ParenDelimited, literal::{Literal, LitSym}},
4        error::{kind::{CompoundAssignmentInHeader, InvalidAssignmentLhs, InvalidCompoundAssignmentLhs}, Error},
5        fmt::Latex,
6        garbage::Garbage,
7        token::op::AssignOp,
8        Parse,
9        Parser,
10        ParseResult,
11    },
12    return_if_ok,
13};
14use std::{fmt, ops::Range};
15
16#[cfg(feature = "serde")]
17use serde::{Deserialize, Serialize};
18
/// A parameter of a function declaration, such as `x` or `y = 1` in the declaration `f(x, y = 1) =
/// x^y`.
///
/// A parameter is either a bare symbol, or a symbol paired with the expression that supplies its
/// default value. `Serialize` / `Deserialize` are derived only when the `serde` feature is
/// enabled.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum Param {
    /// A parameter with no default value, such as `x` in `f(x) = x^2`.
    Symbol(LitSym),

    /// A parameter with a default value, such as `y = 1` in `f(x, y = 1) = x^y`.
    ///
    /// The first field is the parameter's symbol; the second is the default value expression.
    Default(LitSym, Expr),
}
30
31impl Param {
32    /// Returns the symbol of the parameter.
33    pub fn symbol(&self) -> &LitSym {
34        match self {
35            Param::Symbol(symbol) => symbol,
36            Param::Default(symbol, _) => symbol,
37        }
38    }
39}
40
41impl<'source> Parse<'source> for Param {
42    fn std_parse(
43        input: &mut Parser<'source>,
44        recoverable_errors: &mut Vec<Error>
45    ) -> Result<Self, Vec<Error>> {
46        let symbol = input.try_parse().forward_errors(recoverable_errors)?;
47
48        if let Ok(assign) = input.try_parse::<AssignOp>().forward_errors(recoverable_errors) {
49            if assign.is_compound() {
50                recoverable_errors.push(Error::new(
51                    vec![assign.span.clone()],
52                    CompoundAssignmentInHeader,
53                ));
54            }
55            let default = input.try_parse().forward_errors(recoverable_errors)?;
56            Ok(Param::Default(symbol, default))
57        } else {
58            Ok(Param::Symbol(symbol))
59        }
60    }
61}
62
63impl std::fmt::Display for Param {
64    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
65        match self {
66            Param::Symbol(symbol) => write!(f, "{}", symbol),
67            Param::Default(symbol, default) => write!(f, "{} = {}", symbol, default),
68        }
69    }
70}
71
72impl Latex for Param {
73    fn fmt_latex(&self, f: &mut fmt::Formatter) -> fmt::Result {
74        match self {
75            Param::Symbol(symbol) => symbol.fmt_latex(f),
76            Param::Default(symbol, default) => write!(f, "{} = {}", symbol.as_display(), default.as_display()),
77        }
78    }
79}
80
/// A function header, **not including the body**. Functions can have multiple parameters with
/// optional default values, like in `f(x, y = 1)`. When a function with this header is called, the
/// default values are used (i.e. `y = 1`), unless the caller provides their own values (`f(2,
/// 3)`).
///
/// `Serialize` / `Deserialize` are derived only when the `serde` feature is enabled.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct FuncHeader {
    /// The name of the function.
    pub name: LitSym,

    /// The parameters of the function, in declaration order.
    pub params: Vec<Param>,

    /// The region of the source code that this function header was parsed from. When produced by
    /// the parser, it runs from the start of the name to the end of the closing parenthesis.
    pub span: Range<usize>,
}
97
98impl FuncHeader {
99    /// Returns the span of the function header.
100    pub fn span(&self) -> Range<usize> {
101        self.span.clone()
102    }
103}
104
105impl<'source> Parse<'source> for FuncHeader {
106    fn std_parse(
107        input: &mut Parser<'source>,
108        recoverable_errors: &mut Vec<Error>
109    ) -> Result<Self, Vec<Error>> {
110        let name = input.try_parse::<LitSym>().forward_errors(recoverable_errors)?;
111        let surrounded = input.try_parse::<ParenDelimited<_>>().forward_errors(recoverable_errors)?;
112
113        let span = name.span.start..surrounded.close.span.end;
114        Ok(Self { name, params: surrounded.value.values, span })
115    }
116}
117
118impl std::fmt::Display for FuncHeader {
119    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
120        write!(f, "{}(", self.name)?;
121        if let Some((last, rest)) = self.params.split_last() {
122            for param in rest {
123                write!(f, "{}, ", param)?;
124            }
125            write!(f, "{}", last)?;
126        }
127        write!(f, ")")
128    }
129}
130
131impl Latex for FuncHeader {
132    fn fmt_latex(&self, f: &mut fmt::Formatter) -> fmt::Result {
133        write!(f, "\\mathrm{{ {} }} \\left(", self.name.as_display())?;
134        if let Some((last, rest)) = self.params.split_last() {
135            for param in rest {
136                param.fmt_latex(f)?;
137                write!(f, ", ")?;
138            }
139            last.fmt_latex(f)?;
140        }
141        write!(f, "\\right)")
142    }
143}
144
/// An assignment target, such as `x` or `f(x)`.
///
/// This is the left-hand side of an [`Assign`]: either a plain symbol being given a value, or a
/// function header being given a body. `Serialize` / `Deserialize` are derived only when the
/// `serde` feature is enabled.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum AssignTarget {
    /// A symbol, such as `x`.
    Symbol(LitSym),

    /// A function, such as `f(x)`.
    Func(FuncHeader),
}
155
156impl AssignTarget {
157    /// Returns the span of the assignment target.
158    pub fn span(&self) -> Range<usize> {
159        match self {
160            AssignTarget::Symbol(symbol) => symbol.span.clone(),
161            AssignTarget::Func(func) => func.span(),
162        }
163    }
164
165    /// Tries to convert a general [`Expr`] into an [`AssignTarget`]. This is used when parsing
166    /// assignment expressions, such as `x = 1` or `f(x) = x^2`.
167    pub fn try_from_with_op(expr: Expr, op: &AssignOp) -> ParseResult<Self> {
168        let op_span = op.span.clone();
169        match expr {
170            Expr::Literal(Literal::Symbol(symbol)) => ParseResult::Ok(AssignTarget::Symbol(symbol)),
171            Expr::Call(call) => {
172                let spans = vec![call.span.clone(), op_span.clone()];
173                let error = if op.is_compound() {
174                    Error::new(spans, InvalidCompoundAssignmentLhs)
175                } else {
176                    Error::new(spans, InvalidAssignmentLhs { is_call: true })
177                };
178
179                ParseResult::Recoverable(Garbage::garbage(), vec![error])
180            },
181            expr => {
182                let spans = vec![expr.span(), op_span.clone()];
183                let error = if op.is_compound() {
184                    Error::new(spans, InvalidCompoundAssignmentLhs)
185                } else {
186                    Error::new(spans, InvalidAssignmentLhs { is_call: false })
187                };
188
189                ParseResult::Recoverable(
190                    Garbage::garbage(),
191                    vec![error]
192                )
193            },
194        }
195    }
196}
197
impl<'source> Parse<'source> for AssignTarget {
    /// Parses an assignment target, attempting the function header form (`f(x)`) before the
    /// plain symbol form (`x`).
    fn std_parse(
        input: &mut Parser<'source>,
        recoverable_errors: &mut Vec<Error>
    ) -> Result<Self, Vec<Error>> {
        // First attempt: a function header, wrapped in `AssignTarget::Func`.
        // NOTE(review): `return_if_ok!` is defined elsewhere in the crate; presumably it returns
        // early from this function when the attempt succeeds and otherwise evaluates to the
        // failure, which is discarded here — confirm against the macro definition.
        let _ = return_if_ok!(input.try_parse().map(AssignTarget::Func).forward_errors(recoverable_errors));
        // Fallback: a plain symbol, wrapped in `AssignTarget::Symbol`.
        input.try_parse().map(AssignTarget::Symbol).forward_errors(recoverable_errors)
    }
}
207
208impl std::fmt::Display for AssignTarget {
209    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
210        match self {
211            AssignTarget::Symbol(symbol) => write!(f, "{}", symbol),
212            AssignTarget::Func(func) => write!(f, "{}", func),
213        }
214    }
215}
216
217impl Latex for AssignTarget {
218    fn fmt_latex(&self, f: &mut fmt::Formatter) -> fmt::Result {
219        match self {
220            AssignTarget::Symbol(symbol) => symbol.fmt_latex(f),
221            AssignTarget::Func(func) => func.fmt_latex(f),
222        }
223    }
224}
225
/// An assignment of a variable or function, such as `x = 1` or `f(x) = x^2`.
///
/// `Serialize` / `Deserialize` are derived only when the `serde` feature is enabled.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct Assign {
    /// The target to assign to.
    pub target: AssignTarget,

    /// The operator used to assign to the target.
    pub op: AssignOp,

    /// The expression to assign to the target. For a function target, this is the function body.
    pub value: Box<Expr>,

    /// The region of the source code that this assignment expression was parsed from. When
    /// produced by the parser, it runs from the start of the target to the end of the value.
    pub span: Range<usize>,
}
242
243impl Assign {
244    /// Returns the span of the assignment expression.
245    pub fn span(&self) -> Range<usize> {
246        self.span.clone()
247    }
248
249    /// Returns true if the assignment is to a function, and the function body references itself.
250    pub fn is_recursive(&self) -> bool {
251        if let AssignTarget::Func(header) = &self.target {
252            let is_correct_call = |expr: &Expr| {
253                match expr {
254                    Expr::Call(call) => call.name.name == header.name.name,
255                    _ => false,
256                }
257            };
258
259            self.value.post_order_iter().any(is_correct_call)
260        } else {
261            false
262        }
263    }
264}
265
266impl<'source> Parse<'source> for Assign {
267    fn std_parse(
268        input: &mut Parser<'source>,
269        recoverable_errors: &mut Vec<Error>
270    ) -> Result<Self, Vec<Error>> {
271        let target = input.try_parse().forward_errors(recoverable_errors)?;
272        let op = input.try_parse::<AssignOp>().forward_errors(recoverable_errors)?;
273
274        let value = if matches!(target, AssignTarget::Func(_)) {
275            if op.is_compound() {
276                // can't compound assignment to function, for example:
277                //
278                // f(x) += 5
279                //      ^^
280                recoverable_errors.push(Error::new(
281                    vec![op.span.clone()],
282                    InvalidCompoundAssignmentLhs,
283                ));
284            }
285
286            input.try_parse_with_state::<_, Expr>(|state| {
287                // loop control not allowed inside a function definition inside a loop, for example:
288                //
289                // loop {
290                //     f(x) = break x <-- illegal break
291                //     f(5)
292                // }
293                state.allow_loop_control = false;
294            }).forward_errors(recoverable_errors)?
295        } else {
296            input.try_parse::<Expr>().forward_errors(recoverable_errors)?
297        };
298
299        let span = target.span().start..value.span().end;
300        Ok(Self {
301            target,
302            op,
303            value: Box::new(value),
304            span,
305        })
306    }
307}
308
309impl std::fmt::Display for Assign {
310    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
311        write!(
312            f,
313            "{} {} {}",
314            self.target,
315            self.op,
316            self.value,
317        )
318    }
319}
320
321impl Latex for Assign {
322    fn fmt_latex(&self, f: &mut fmt::Formatter) -> fmt::Result {
323        write!(
324            f,
325            "{} {} {}",
326            self.target.as_display(),
327            self.op.as_display(),
328            self.value.as_display(),
329        )
330    }
331}