1use crate::parser::{
2 ast::{expr::Expr, helper::ParenDelimited, literal::LitSym},
3 error::{kind::TooManyDerivatives, Error},
4 fmt::{Latex, fmt_pow},
5 token::Quote,
6 Parse,
7 Parser,
8};
9use std::{fmt, ops::Range};
10
11#[cfg(feature = "serde")]
12use serde::{Deserialize, Serialize};
13
14#[derive(Debug, Clone, PartialEq)]
16#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
17pub struct Call {
18 pub name: LitSym,
20
21 pub derivatives: u8,
23
24 pub args: Vec<Expr>,
26
27 pub span: Range<usize>,
29
30 pub paren_span: Range<usize>,
32}
33
34impl Call {
35 pub fn span(&self) -> Range<usize> {
37 self.span.clone()
38 }
39
40 pub fn outer_span(&self) -> [Range<usize>; 2] {
43 [
44 self.name.span.start..self.paren_span.start + 1,
45 self.paren_span.end - 1..self.paren_span.end,
46 ]
47 }
48}
49
50impl<'source> Parse<'source> for Call {
51 fn std_parse(
52 input: &mut Parser<'source>,
53 recoverable_errors: &mut Vec<Error>
54 ) -> Result<Self, Vec<Error>> {
55 let name = input.try_parse::<LitSym>().forward_errors(recoverable_errors)?;
56 let mut derivatives = 0usize;
57 let mut quote_span: Option<Range<_>> = None;
58 let mut too_many_derivatives = false;
59
60 while let Ok(quote) = input.try_parse::<Quote>().forward_errors(recoverable_errors) {
61 if derivatives == u8::MAX.into() {
62 too_many_derivatives = true;
63 }
64
65 derivatives += 1;
66 quote_span = quote_span
67 .or_else(|| Some(quote.span.clone()))
68 .map(|span| span.start..quote.span.end);
69 }
70
71 if too_many_derivatives {
72 recoverable_errors.push(Error::new(
73 vec![quote_span.unwrap()],
74 TooManyDerivatives { derivatives }
75 ));
76 }
77
78 let surrounded = input.try_parse::<ParenDelimited<_>>().forward_errors(recoverable_errors)?;
79
80 let span = name.span.start..surrounded.close.span.end;
82 Ok(Self {
83 name,
84 derivatives: derivatives as u8,
85 args: surrounded.value.values,
86 span,
87 paren_span: surrounded.open.span.start..surrounded.close.span.end,
88 })
89 }
90}
91
92impl std::fmt::Display for Call {
93 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
94 self.name.fmt(f)?;
95 for _ in 0..self.derivatives {
96 write!(f, "'")?;
97 }
98 write!(f, "(")?;
99 if let Some((last, args)) = self.args.split_last() {
100 for arg in args {
101 arg.fmt(f)?;
102 write!(f, ", ")?;
103 }
104 last.fmt(f)?;
105 }
106 write!(f, ")")
107 }
108}
109
110impl Latex for Call {
111 fn fmt_latex(&self, f: &mut fmt::Formatter) -> fmt::Result {
112 enum SpecialFunc {
113 Pow,
114 Root,
115 Cbrt,
116 Sqrt,
117 Abs,
118 Other,
119 }
120
121 impl SpecialFunc {
122 fn name(&self, f: &mut fmt::Formatter, call: &Call) -> fmt::Result {
124 match self {
125 Self::Pow => Ok(()),
126 Self::Root => write!(f, "\\sqrt"),
127 Self::Cbrt => write!(f, "\\sqrt[3]"),
128 Self::Sqrt => write!(f, "\\sqrt"),
129 Self::Abs => Ok(()),
130 Self::Other => write!(f, "\\mathrm{{ {} }}", call.name.as_display()),
131 }
132 }
133
134 fn outer_args(&self, f: &mut fmt::Formatter, call: &Call) -> fmt::Result {
136 match self {
137 Self::Pow => self.inner_args(f, call),
138 Self::Root => {
139 self.inner_args(f, call)?;
140 write!(f, "}}")
141 },
142 Self::Cbrt | Self::Sqrt => {
143 write!(f, "{{")?;
144 self.inner_args(f, call)?;
145 write!(f, "}}")
146 },
147 Self::Abs => {
148 write!(f, "\\left|")?;
149 self.inner_args(f, call)?;
150 write!(f, "\\right|")
151 },
152 Self::Other => {
153 write!(f, "\\left(")?;
154 self.inner_args(f, call)?;
155 write!(f, "\\right)")
156 },
157 }
158 }
159
160 fn inner_args(&self, f: &mut fmt::Formatter, call: &Call) -> fmt::Result {
161 match self {
162 Self::Pow => fmt_pow(f, call.args.first(), call.args.get(1))?,
163 Self::Root => {
164 if let Some(arg1) = call.args.get(1) {
165 write!(f, "[{}]", arg1.as_display())?;
166 }
167 write!(f, "{{")?;
168 if let Some(arg0) = call.args.first() {
169 arg0.fmt_latex(f)?;
170 }
171 },
172 Self::Cbrt | Self::Sqrt | Self::Abs | Self::Other => {
173 if let Some((last, args)) = call.args.split_last() {
174 for arg in args {
175 arg.fmt_latex(f)?;
176 write!(f, ", ")?;
177 }
178 last.fmt_latex(f)?;
179 }
180 },
181 }
182
183 Ok(())
184 }
185 }
186
187 let func = match self.name.name.as_str() {
188 "pow" => SpecialFunc::Pow,
189 "root" => SpecialFunc::Root,
190 "cbrt" => SpecialFunc::Cbrt,
191 "sqrt" => SpecialFunc::Sqrt,
192 "abs" => SpecialFunc::Abs,
193 _ => SpecialFunc::Other,
194 };
195
196 func.name(f, self)?;
197 match self.derivatives {
198 0 => {},
199 1 => write!(f, "'")?,
200 2 => write!(f, "''")?,
201 n => write!(f, "^{{ ({}) }}", n)?,
202 }
203
204 func.outer_args(f, self)
205 }
206}