cas_parser/parser/ast/
call.rs

1use cas_error::Error;
2use crate::parser::{
3    ast::{expr::{Expr, Primary}, helper::ParenDelimited, literal::{Literal, LitSym}},
4    error::TooManyDerivatives,
5    fmt::{Latex, fmt_pow},
6    token::Quote,
7    Parser,
8};
9use std::{fmt, ops::Range};
10
11#[cfg(feature = "serde")]
12use serde::{Deserialize, Serialize};
13
14/// A function call, such as `func(x, -40)`.
15#[derive(Debug, Clone, PartialEq, Eq)]
16#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
17pub struct Call {
18    /// The name of the function to call.
19    pub name: LitSym,
20
21    /// The number of derivatives to take before calling the function.
22    pub derivatives: u8,
23
24    /// The arguments to the function.
25    pub args: Vec<Expr>,
26
27    /// The region of the source code that this function call was parsed from.
28    pub span: Range<usize>,
29
30    /// The span of the parentheses that surround the arguments.
31    pub paren_span: Range<usize>,
32}
33
34impl Call {
35    /// Returns the span of the function call.
36    pub fn span(&self) -> Range<usize> {
37        self.span.clone()
38    }
39
40    /// Returns a span that spans the selected arguments, given by index.
41    pub fn arg_span(&self, args: Range<usize>) -> Range<usize> {
42        let first = self.args[args.start].span().start;
43        let last = self.args[args.end].span().end;
44        first..last
45    }
46
47    /// Returns a set of two spans, where the first is the span of the function name (with the
48    /// opening parenthesis) and the second is the span of the closing parenthesis.
49    pub fn outer_span(&self) -> [Range<usize>; 2] {
50        [
51            self.name.span.start..self.paren_span.start + 1,
52            self.paren_span.end - 1..self.paren_span.end,
53        ]
54    }
55
56    /// Attempts to parse a [`Call`], where the initial target has already been parsed.
57    ///
58    /// Besides the returned [`Primary`], the return value also includes a boolean that indicates
59    /// if the expression was changed due to successfully parsing a [`Call`]. This function can
60    /// return `Ok` even if no [`Call`], which occurs when we determine that we shouldn't have
61    /// taken the [`Call`] path. The boolean is used to let the caller know that this is was the
62    /// case.
63    ///
64    /// This is similar to what we had to do with [`Binary`].
65    ///
66    /// [`Binary`]: crate::parser::ast::binary::Binary
67    pub fn parse_or_lower(
68        input: &mut Parser,
69        recoverable_errors: &mut Vec<Error>,
70        target: Primary,
71    ) -> Result<(Primary, bool), Vec<Error>> {
72        let name = match target {
73            Primary::Literal(Literal::Symbol(name)) => name,
74            target => return Ok((target, false)),
75        };
76
77        let mut derivatives = 0usize;
78        let mut quote_span: Option<Range<_>> = None;
79        let mut too_many_derivatives = false;
80
81        while let Ok(quote) = input.try_parse::<Quote>().forward_errors(recoverable_errors) {
82            if derivatives == u8::MAX.into() {
83                too_many_derivatives = true;
84            }
85
86            derivatives += 1;
87            quote_span = quote_span
88                .or_else(|| Some(quote.span.clone()))
89                .map(|span| span.start..quote.span.end);
90        }
91
92        if too_many_derivatives {
93            recoverable_errors.push(Error::new(
94                vec![quote_span.unwrap()],
95                TooManyDerivatives { derivatives }
96            ));
97        }
98
99        let surrounded = input.try_parse::<ParenDelimited<_>>().forward_errors(recoverable_errors)?;
100
101        // use `name` here before it is moved into the struct
102        let span = name.span.start..surrounded.close.span.end;
103        Ok((Primary::Call(Self {
104            name,
105            derivatives: derivatives as u8,
106            args: surrounded.value.values,
107            span,
108            paren_span: surrounded.open.span.start..surrounded.close.span.end,
109        }), true))
110    }
111}
112
113impl std::fmt::Display for Call {
114    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
115        self.name.fmt(f)?;
116        for _ in 0..self.derivatives {
117            write!(f, "'")?;
118        }
119        write!(f, "(")?;
120        if let Some((last, args)) = self.args.split_last() {
121            for arg in args {
122                arg.fmt(f)?;
123                write!(f, ", ")?;
124            }
125            last.fmt(f)?;
126        }
127        write!(f, ")")
128    }
129}
130
131impl Latex for Call {
132    fn fmt_latex(&self, f: &mut fmt::Formatter) -> fmt::Result {
133        enum SpecialFunc {
134            Pow,
135            Root,
136            Cbrt,
137            Sqrt,
138            Abs,
139            Other,
140        }
141
142        impl SpecialFunc {
143            /// Write the name of the function.
144            fn name(&self, f: &mut fmt::Formatter, call: &Call) -> fmt::Result {
145                match self {
146                    Self::Pow => Ok(()),
147                    Self::Root => write!(f, "\\sqrt"),
148                    Self::Cbrt => write!(f, "\\sqrt[3]"),
149                    Self::Sqrt => write!(f, "\\sqrt"),
150                    Self::Abs => Ok(()),
151                    Self::Other => write!(f, "\\mathrm{{ {} }}", call.name.as_display()),
152                }
153            }
154
155            /// Write the tokens surrounding the arguments, and delegate the arguments to `inner_args`.
156            fn outer_args(&self, f: &mut fmt::Formatter, call: &Call) -> fmt::Result {
157                match self {
158                    Self::Pow => self.inner_args(f, call),
159                    Self::Root => {
160                        self.inner_args(f, call)?;
161                        write!(f, "}}")
162                    },
163                    Self::Cbrt | Self::Sqrt => {
164                        write!(f, "{{")?;
165                        self.inner_args(f, call)?;
166                        write!(f, "}}")
167                    },
168                    Self::Abs => {
169                        write!(f, "\\left|")?;
170                        self.inner_args(f, call)?;
171                        write!(f, "\\right|")
172                    },
173                    Self::Other => {
174                        write!(f, "\\left(")?;
175                        self.inner_args(f, call)?;
176                        write!(f, "\\right)")
177                    },
178                }
179            }
180
181            fn inner_args(&self, f: &mut fmt::Formatter, call: &Call) -> fmt::Result {
182                match self {
183                    Self::Pow => fmt_pow(f, call.args.first(), call.args.get(1))?,
184                    Self::Root => {
185                        if let Some(arg1) = call.args.get(1) {
186                            write!(f, "[{}]", arg1.as_display())?;
187                        }
188                        write!(f, "{{")?;
189                        if let Some(arg0) = call.args.first() {
190                            arg0.fmt_latex(f)?;
191                        }
192                    },
193                    Self::Cbrt | Self::Sqrt | Self::Abs | Self::Other => {
194                        if let Some((last, args)) = call.args.split_last() {
195                            for arg in args {
196                                arg.fmt_latex(f)?;
197                                write!(f, ", ")?;
198                            }
199                            last.fmt_latex(f)?;
200                        }
201                    },
202                }
203
204                Ok(())
205            }
206        }
207
208        let func = match self.name.name.as_str() {
209            "pow" => SpecialFunc::Pow,
210            "root" => SpecialFunc::Root,
211            "cbrt" => SpecialFunc::Cbrt,
212            "sqrt" => SpecialFunc::Sqrt,
213            "abs" => SpecialFunc::Abs,
214            _ => SpecialFunc::Other,
215        };
216
217        func.name(f, self)?;
218        match self.derivatives {
219            0 => {},
220            1 => write!(f, "'")?,
221            2 => write!(f, "''")?,
222            n => write!(f, "^{{ ({}) }}", n)?,
223        }
224
225        func.outer_args(f, self)
226    }
227}