cas_parser/parser/ast/call.rs

1use crate::parser::{
2    ast::{expr::Expr, helper::ParenDelimited, literal::LitSym},
3    error::{kind::TooManyDerivatives, Error},
4    fmt::{Latex, fmt_pow},
5    token::Quote,
6    Parse,
7    Parser,
8};
9use std::{fmt, ops::Range};
10
11#[cfg(feature = "serde")]
12use serde::{Deserialize, Serialize};
13
/// A function call, such as `func(x, -40)`.
///
/// The function name may be followed by `'` marks, one per derivative, before the argument list:
/// `f''(x)` is represented with `derivatives == 2`.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct Call {
    /// The name of the function to call.
    pub name: LitSym,

    /// The number of derivatives to take before calling the function.
    ///
    /// In source form this is the number of `'` marks written between the name and the opening
    /// parenthesis.
    pub derivatives: u8,

    /// The arguments to the function, in source order.
    pub args: Vec<Expr>,

    /// The region of the source code that this function call was parsed from.
    ///
    /// The parser sets this to run from the start of the function name through the end of the
    /// closing parenthesis.
    pub span: Range<usize>,

    /// The span of the parentheses that surround the arguments, from the start of the opening
    /// parenthesis through the end of the closing parenthesis.
    pub paren_span: Range<usize>,
}
33
34impl Call {
35    /// Returns the span of the function call.
36    pub fn span(&self) -> Range<usize> {
37        self.span.clone()
38    }
39
40    /// Returns a set of two spans, where the first is the span of the function name (with the
41    /// opening parenthesis) and the second is the span of the closing parenthesis.
42    pub fn outer_span(&self) -> [Range<usize>; 2] {
43        [
44            self.name.span.start..self.paren_span.start + 1,
45            self.paren_span.end - 1..self.paren_span.end,
46        ]
47    }
48}
49
50impl<'source> Parse<'source> for Call {
51    fn std_parse(
52        input: &mut Parser<'source>,
53        recoverable_errors: &mut Vec<Error>
54    ) -> Result<Self, Vec<Error>> {
55        let name = input.try_parse::<LitSym>().forward_errors(recoverable_errors)?;
56        let mut derivatives = 0usize;
57        let mut quote_span: Option<Range<_>> = None;
58        let mut too_many_derivatives = false;
59
60        while let Ok(quote) = input.try_parse::<Quote>().forward_errors(recoverable_errors) {
61            if derivatives == u8::MAX.into() {
62                too_many_derivatives = true;
63            }
64
65            derivatives += 1;
66            quote_span = quote_span
67                .or_else(|| Some(quote.span.clone()))
68                .map(|span| span.start..quote.span.end);
69        }
70
71        if too_many_derivatives {
72            recoverable_errors.push(Error::new(
73                vec![quote_span.unwrap()],
74                TooManyDerivatives { derivatives }
75            ));
76        }
77
78        let surrounded = input.try_parse::<ParenDelimited<_>>().forward_errors(recoverable_errors)?;
79
80        // use `name` here before it is moved into the struct
81        let span = name.span.start..surrounded.close.span.end;
82        Ok(Self {
83            name,
84            derivatives: derivatives as u8,
85            args: surrounded.value.values,
86            span,
87            paren_span: surrounded.open.span.start..surrounded.close.span.end,
88        })
89    }
90}
91
92impl std::fmt::Display for Call {
93    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
94        self.name.fmt(f)?;
95        for _ in 0..self.derivatives {
96            write!(f, "'")?;
97        }
98        write!(f, "(")?;
99        if let Some((last, args)) = self.args.split_last() {
100            for arg in args {
101                arg.fmt(f)?;
102                write!(f, ", ")?;
103            }
104            last.fmt(f)?;
105        }
106        write!(f, ")")
107    }
108}
109
110impl Latex for Call {
111    fn fmt_latex(&self, f: &mut fmt::Formatter) -> fmt::Result {
112        enum SpecialFunc {
113            Pow,
114            Root,
115            Cbrt,
116            Sqrt,
117            Abs,
118            Other,
119        }
120
121        impl SpecialFunc {
122            /// Write the name of the function.
123            fn name(&self, f: &mut fmt::Formatter, call: &Call) -> fmt::Result {
124                match self {
125                    Self::Pow => Ok(()),
126                    Self::Root => write!(f, "\\sqrt"),
127                    Self::Cbrt => write!(f, "\\sqrt[3]"),
128                    Self::Sqrt => write!(f, "\\sqrt"),
129                    Self::Abs => Ok(()),
130                    Self::Other => write!(f, "\\mathrm{{ {} }}", call.name.as_display()),
131                }
132            }
133
134            /// Write the tokens surrounding the arguments, and delegate the arguments to `inner_args`.
135            fn outer_args(&self, f: &mut fmt::Formatter, call: &Call) -> fmt::Result {
136                match self {
137                    Self::Pow => self.inner_args(f, call),
138                    Self::Root => {
139                        self.inner_args(f, call)?;
140                        write!(f, "}}")
141                    },
142                    Self::Cbrt | Self::Sqrt => {
143                        write!(f, "{{")?;
144                        self.inner_args(f, call)?;
145                        write!(f, "}}")
146                    },
147                    Self::Abs => {
148                        write!(f, "\\left|")?;
149                        self.inner_args(f, call)?;
150                        write!(f, "\\right|")
151                    },
152                    Self::Other => {
153                        write!(f, "\\left(")?;
154                        self.inner_args(f, call)?;
155                        write!(f, "\\right)")
156                    },
157                }
158            }
159
160            fn inner_args(&self, f: &mut fmt::Formatter, call: &Call) -> fmt::Result {
161                match self {
162                    Self::Pow => fmt_pow(f, call.args.first(), call.args.get(1))?,
163                    Self::Root => {
164                        if let Some(arg1) = call.args.get(1) {
165                            write!(f, "[{}]", arg1.as_display())?;
166                        }
167                        write!(f, "{{")?;
168                        if let Some(arg0) = call.args.first() {
169                            arg0.fmt_latex(f)?;
170                        }
171                    },
172                    Self::Cbrt | Self::Sqrt | Self::Abs | Self::Other => {
173                        if let Some((last, args)) = call.args.split_last() {
174                            for arg in args {
175                                arg.fmt_latex(f)?;
176                                write!(f, ", ")?;
177                            }
178                            last.fmt_latex(f)?;
179                        }
180                    },
181                }
182
183                Ok(())
184            }
185        }
186
187        let func = match self.name.name.as_str() {
188            "pow" => SpecialFunc::Pow,
189            "root" => SpecialFunc::Root,
190            "cbrt" => SpecialFunc::Cbrt,
191            "sqrt" => SpecialFunc::Sqrt,
192            "abs" => SpecialFunc::Abs,
193            _ => SpecialFunc::Other,
194        };
195
196        func.name(f, self)?;
197        match self.derivatives {
198            0 => {},
199            1 => write!(f, "'")?,
200            2 => write!(f, "''")?,
201            n => write!(f, "^{{ ({}) }}", n)?,
202        }
203
204        func.outer_args(f, self)
205    }
206}