// cas_parser/parser/ast/assign.rs

1use cas_error::Error;
2use crate::{
3    parser::{
4        ast::{
5            expr::{Atom, Expr, Primary},
6            helper::Surrounded,
7            index::Index,
8            literal::{Literal, LitSym},
9        },
10        error::{
11            CompoundAssignmentInHeader,
12            DefaultArgumentNotLast,
13            ExpectedExpr,
14            InvalidAssignmentLhs,
15            InvalidCompoundAssignmentLhs,
16        },
17        fmt::Latex,
18        garbage::Garbage,
19        token::{op::AssignOp, Comma, OpenParen},
20        Parse,
21        Parser,
22        ParseResult,
23    },
24    tokenizer::TokenKind,
25};
26use std::{fmt, ops::Range};
27
28#[cfg(feature = "serde")]
29use serde::{Deserialize, Serialize};
30
31/// A parameter of a function declaration, such as `x` or `y = 1` in the declaration `f(x, y = 1) =
32/// x^y`.
33#[derive(Debug, Clone, PartialEq, Eq)]
34#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
35pub enum Param {
36    /// A parameter with no default value, such as `x` in `f(x) = x^2`.
37    Symbol(LitSym),
38
39    /// A parameter with a default value, such as `y = 1` in `f(x, y = 1) = x^y`.
40    Default(LitSym, Expr),
41}
42
43impl Param {
44    /// Returns the span of the parameter.
45    pub fn span(&self) -> Range<usize> {
46        match self {
47            Param::Symbol(symbol) => symbol.span.clone(),
48            Param::Default(symbol, default) => symbol.span.start..default.span().end,
49        }
50    }
51
52    /// Returns the symbol of the parameter.
53    pub fn symbol(&self) -> &LitSym {
54        match self {
55            Param::Symbol(symbol) => symbol,
56            Param::Default(symbol, _) => symbol,
57        }
58    }
59
60    /// Returns true if the parameter has a default value.
61    pub fn has_default(&self) -> bool {
62        matches!(self, Param::Default(_, _))
63    }
64}
65
66impl<'source> Parse<'source> for Param {
67    fn std_parse(
68        input: &mut Parser<'source>,
69        recoverable_errors: &mut Vec<Error>
70    ) -> Result<Self, Vec<Error>> {
71        let symbol = input.try_parse().forward_errors(recoverable_errors)?;
72
73        if let Ok(assign) = input.try_parse::<AssignOp>().forward_errors(recoverable_errors) {
74            if assign.is_compound() {
75                recoverable_errors.push(Error::new(
76                    vec![assign.span.clone()],
77                    CompoundAssignmentInHeader,
78                ));
79            }
80            let default = input.try_parse().forward_errors(recoverable_errors)?;
81            Ok(Param::Default(symbol, default))
82        } else {
83            Ok(Param::Symbol(symbol))
84        }
85    }
86}
87
88impl std::fmt::Display for Param {
89    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
90        match self {
91            Param::Symbol(symbol) => write!(f, "{}", symbol),
92            Param::Default(symbol, default) => write!(f, "{} = {}", symbol, default),
93        }
94    }
95}
96
97impl Latex for Param {
98    fn fmt_latex(&self, f: &mut fmt::Formatter) -> fmt::Result {
99        match self {
100            Param::Symbol(symbol) => symbol.fmt_latex(f),
101            Param::Default(symbol, default) => write!(f, "{} = {}", symbol.as_display(), default.as_display()),
102        }
103    }
104}
105
106/// A function header, **not including the body**. Functions can have multiple parameters with
107/// optional default values, like in `f(x, y = 1)`. When a function with this header is called, the
108/// default values are used (i.e. `y = 1`), unless the caller provides their own values (`f(2,
109/// 3)`).
110#[derive(Debug, Clone, PartialEq, Eq)]
111#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
112pub struct FuncHeader {
113    /// The name of the function.
114    pub name: LitSym,
115
116    /// The parameters of the function.
117    pub params: Vec<Param>,
118
119    /// The region of the source code that this function header was parsed from.
120    pub span: Range<usize>,
121}
122
123impl FuncHeader {
124    /// Returns the span of the function header.
125    pub fn span(&self) -> Range<usize> {
126        self.span.clone()
127    }
128
129    /// Attempts to parse a [`FuncHeader`], where the function name has already been parsed.
130    fn parse_or_lower(
131        input: &mut Parser,
132        recoverable_errors: &mut Vec<Error>,
133        name: LitSym,
134    ) -> Result<Self, Vec<Error>> {
135        /// Helper duplicate of the `Delimited` helper struct with additional state to ensure
136        /// default parameters in the correct position.
137        struct FuncHeaderInner {
138            values: Vec<Param>,
139        }
140
141        impl<'source> Parse<'source> for FuncHeaderInner {
142            fn std_parse(
143                input: &mut Parser<'source>,
144                recoverable_errors: &mut Vec<Error>
145            ) -> Result<Self, Vec<Error>> {
146                let mut bad_default_position = false;
147                let mut default_params = Vec::new();
148                let mut values = Vec::new();
149
150                loop {
151                    let Ok(value) = input.try_parse().forward_errors(recoverable_errors) else {
152                        break;
153                    };
154
155                    // default parameters must be at the end of the list, i.e., no required
156                    // parameters should come after
157                    if !default_params.is_empty() && !bad_default_position {
158                        if let Param::Symbol(_) = value {
159                            bad_default_position = true;
160                        }
161                    }
162
163                    if let Param::Default(_, _) = value {
164                        default_params.push(value.span());
165                    }
166
167                    values.push(value);
168
169                    if input.try_parse::<Comma>().forward_errors(recoverable_errors).is_err() {
170                        break;
171                    }
172                }
173
174                if bad_default_position {
175                    recoverable_errors.push(Error::new(
176                        default_params,
177                        DefaultArgumentNotLast,
178                    ));
179                }
180
181                Ok(Self { values })
182            }
183        }
184
185        let surrounded = input.try_parse::<Surrounded<OpenParen, FuncHeaderInner>>()
186            .forward_errors(recoverable_errors)?;
187
188        let span = name.span.start..surrounded.close.span.end;
189        Ok(Self { name, params: surrounded.value.values, span })
190    }
191}
192
193impl std::fmt::Display for FuncHeader {
194    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
195        write!(f, "{}(", self.name)?;
196        if let Some((last, rest)) = self.params.split_last() {
197            for param in rest {
198                write!(f, "{}, ", param)?;
199            }
200            write!(f, "{}", last)?;
201        }
202        write!(f, ")")
203    }
204}
205
206impl Latex for FuncHeader {
207    fn fmt_latex(&self, f: &mut fmt::Formatter) -> fmt::Result {
208        write!(f, "\\mathrm{{ {} }} \\left(", self.name.as_display())?;
209        if let Some((last, rest)) = self.params.split_last() {
210            for param in rest {
211                param.fmt_latex(f)?;
212                write!(f, ", ")?;
213            }
214            last.fmt_latex(f)?;
215        }
216        write!(f, "\\right)")
217    }
218}
219
220/// An assignment target, such as `x`, `list[0]`, or `f(x)`.
221#[derive(Debug, Clone, PartialEq, Eq)]
222#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
223pub enum AssignTarget {
224    /// A symbol, such as `x`.
225    Symbol(LitSym),
226
227    /// A list index, such as `list[0]`.
228    Index(Index),
229
230    /// A function header, such as `f(x)`.
231    Func(FuncHeader),
232}
233
234impl AssignTarget {
235    /// Returns the span of the assignment target.
236    pub fn span(&self) -> Range<usize> {
237        match self {
238            AssignTarget::Symbol(symbol) => symbol.span.clone(),
239            AssignTarget::Index(index) => index.span(),
240            AssignTarget::Func(func) => func.span(),
241        }
242    }
243
244    /// Returns true if the assignment target is a function.
245    pub fn is_func(&self) -> bool {
246        matches!(self, AssignTarget::Func(_))
247    }
248
249    /// Tries to convert a general [`Expr`] into an [`AssignTarget`]. This is used when parsing
250    /// assignment expressions, such as `x = 1` or `f(x) = x^2`.
251    pub fn try_from_with_op(expr: Expr, op: &AssignOp) -> ParseResult<Self> {
252        let op_span = op.span.clone();
253        match expr {
254            Expr::Literal(Literal::Symbol(symbol)) => ParseResult::Ok(AssignTarget::Symbol(symbol)),
255            Expr::Index(index) => ParseResult::Ok(AssignTarget::Index(index)),
256            Expr::Call(call) => {
257                let spans = vec![call.span.clone(), op_span.clone()];
258                let error = if op.is_compound() {
259                    Error::new(spans, InvalidCompoundAssignmentLhs)
260                } else {
261                    Error::new(spans, InvalidAssignmentLhs { is_call: true })
262                };
263
264                ParseResult::Recoverable(Garbage::garbage(), vec![error])
265            },
266            expr => {
267                let spans = vec![expr.span(), op_span.clone()];
268                let error = if op.is_compound() {
269                    Error::new(spans, InvalidCompoundAssignmentLhs)
270                } else {
271                    Error::new(spans, InvalidAssignmentLhs { is_call: false })
272                };
273
274                ParseResult::Recoverable(
275                    Garbage::garbage(),
276                    vec![error]
277                )
278            },
279        }
280    }
281}
282
283impl From<LitSym> for AssignTarget {
284    fn from(symbol: LitSym) -> Self {
285        AssignTarget::Symbol(symbol)
286    }
287}
288
289impl From<Index> for AssignTarget {
290    fn from(index: Index) -> Self {
291        AssignTarget::Index(index)
292    }
293}
294
295impl From<FuncHeader> for AssignTarget {
296    fn from(func: FuncHeader) -> Self {
297        AssignTarget::Func(func)
298    }
299}
300
301impl<'source> Parse<'source> for AssignTarget {
302    fn std_parse(
303        input: &mut Parser<'source>,
304        recoverable_errors: &mut Vec<Error>
305    ) -> Result<Self, Vec<Error>> {
306        // this uses a similar approach to Primary::parse, where we try to parse an Atom first
307        // and then check if it's followed by an open parenthesis or open square bracket
308        // to determine if the target is a function or index
309        let atom = input.try_parse::<Atom>().forward_errors(recoverable_errors)?;
310
311        let mut fork = input.clone();
312        match fork.next_token() {
313            Ok(next) if next.kind == TokenKind::OpenParen => {
314                if let Atom::Literal(Literal::Symbol(symbol)) = atom {
315                    Ok(FuncHeader::parse_or_lower(input, recoverable_errors, symbol)
316                        .map(Into::into)?)
317                } else {
318                    Err(vec![input.error(ExpectedExpr { expected: "a symbol" })])
319                }
320            },
321            Ok(next) if next.kind == TokenKind::OpenSquare => {
322                match Index::parse_or_lower(input, recoverable_errors, atom.into()) {
323                    (new_primary, true) => match new_primary {
324                        Primary::Index(index) => Ok(AssignTarget::Index(index)),
325                        _ => unreachable!(),
326                    },
327                    (unchanged_primary, false) => match unchanged_primary {
328                        Primary::Literal(Literal::Symbol(symbol)) => Ok(AssignTarget::Symbol(symbol)),
329                        _ => unreachable!(),
330                    },
331                }
332            },
333            _ => if let Atom::Literal(Literal::Symbol(symbol)) = atom {
334                Ok(AssignTarget::Symbol(symbol))
335            } else {
336                Err(vec![input.error(ExpectedExpr { expected: "a symbol" })])
337            },
338        }
339    }
340}
341
342impl std::fmt::Display for AssignTarget {
343    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
344        match self {
345            AssignTarget::Symbol(symbol) => write!(f, "{}", symbol),
346            AssignTarget::Index(index) => write!(f, "{}", index),
347            AssignTarget::Func(func) => write!(f, "{}", func),
348        }
349    }
350}
351
352impl Latex for AssignTarget {
353    fn fmt_latex(&self, f: &mut fmt::Formatter) -> fmt::Result {
354        match self {
355            AssignTarget::Symbol(symbol) => symbol.fmt_latex(f),
356            AssignTarget::Index(index) => index.fmt_latex(f),
357            AssignTarget::Func(func) => func.fmt_latex(f),
358        }
359    }
360}
361
362/// An assignment of a variable or function, such as `x = 1` or `f(x) = x^2`.
363#[derive(Debug, Clone, PartialEq, Eq)]
364#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
365pub struct Assign {
366    /// The target to assign to.
367    pub target: AssignTarget,
368
369    /// The operator used to assign to the target.
370    pub op: AssignOp,
371
372    /// The expression to assign to the target.
373    pub value: Box<Expr>,
374
375    /// The region of the source code that this assignment expression was parsed from.
376    pub span: Range<usize>,
377}
378
379impl Assign {
380    /// Returns the span of the assignment expression.
381    pub fn span(&self) -> Range<usize> {
382        self.span.clone()
383    }
384}
385
386impl<'source> Parse<'source> for Assign {
387    fn std_parse(
388        input: &mut Parser<'source>,
389        recoverable_errors: &mut Vec<Error>
390    ) -> Result<Self, Vec<Error>> {
391        let target = input.try_parse().forward_errors(recoverable_errors)?;
392        let op = input.try_parse::<AssignOp>().forward_errors(recoverable_errors)?;
393
394        let value = if matches!(target, AssignTarget::Func(_)) {
395            if op.is_compound() {
396                // can't compound assignment to function, for example:
397                //
398                // f(x) += 5
399                //      ^^
400                recoverable_errors.push(Error::new(
401                    vec![target.span(), op.span.clone()],
402                    InvalidCompoundAssignmentLhs,
403                ));
404            }
405
406            input.try_parse_with_state::<_, Expr>(|state| {
407                // loop control not allowed inside a function definition inside a loop, for example:
408                //
409                // loop {
410                //     f(x) = break x <-- illegal break
411                //     f(5)
412                // }
413                state.allow_loop_control = false;
414                state.allow_return = true;
415            }).forward_errors(recoverable_errors)?
416        } else {
417            input.try_parse::<Expr>().forward_errors(recoverable_errors)?
418        };
419
420        let span = target.span().start..value.span().end;
421        Ok(Self {
422            target,
423            op,
424            value: Box::new(value),
425            span,
426        })
427    }
428}
429
430impl std::fmt::Display for Assign {
431    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
432        write!(
433            f,
434            "{} {} {}",
435            self.target,
436            self.op,
437            self.value,
438        )
439    }
440}
441
442impl Latex for Assign {
443    fn fmt_latex(&self, f: &mut fmt::Formatter) -> fmt::Result {
444        write!(
445            f,
446            "{} {} {}",
447            self.target.as_display(),
448            self.op.as_display(),
449            self.value.as_display(),
450        )
451    }
452}