1use cas_error::Error;
2use crate::parser::{
3 ast::{expr::{Expr, Primary}, helper::ParenDelimited, literal::{Literal, LitSym}},
4 error::TooManyDerivatives,
5 fmt::{Latex, fmt_pow},
6 token::Quote,
7 Parser,
8};
9use std::{fmt, ops::Range};
10
11#[cfg(feature = "serde")]
12use serde::{Deserialize, Serialize};
13
/// A function call, such as `f(x)`, `sqrt(2)`, or `f''(x, y)`.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct Call {
    /// The name of the function being called.
    pub name: LitSym,

    /// The number of derivative quotes (`'`) written between the name and the argument list.
    pub derivatives: u8,

    /// The arguments passed to the function, in source order.
    pub args: Vec<Expr>,

    /// The region of the source code this call was parsed from: from the start of the name to
    /// the end of the closing parenthesis.
    pub span: Range<usize>,

    /// The span of the parenthesized argument list, from the opening `(` through the closing `)`.
    pub paren_span: Range<usize>,
}
33
34impl Call {
35 pub fn span(&self) -> Range<usize> {
37 self.span.clone()
38 }
39
40 pub fn arg_span(&self, args: Range<usize>) -> Range<usize> {
42 let first = self.args[args.start].span().start;
43 let last = self.args[args.end].span().end;
44 first..last
45 }
46
47 pub fn outer_span(&self) -> [Range<usize>; 2] {
50 [
51 self.name.span.start..self.paren_span.start + 1,
52 self.paren_span.end - 1..self.paren_span.end,
53 ]
54 }
55
56 pub fn parse_or_lower(
68 input: &mut Parser,
69 recoverable_errors: &mut Vec<Error>,
70 target: Primary,
71 ) -> Result<(Primary, bool), Vec<Error>> {
72 let name = match target {
73 Primary::Literal(Literal::Symbol(name)) => name,
74 target => return Ok((target, false)),
75 };
76
77 let mut derivatives = 0usize;
78 let mut quote_span: Option<Range<_>> = None;
79 let mut too_many_derivatives = false;
80
81 while let Ok(quote) = input.try_parse::<Quote>().forward_errors(recoverable_errors) {
82 if derivatives == u8::MAX.into() {
83 too_many_derivatives = true;
84 }
85
86 derivatives += 1;
87 quote_span = quote_span
88 .or_else(|| Some(quote.span.clone()))
89 .map(|span| span.start..quote.span.end);
90 }
91
92 if too_many_derivatives {
93 recoverable_errors.push(Error::new(
94 vec![quote_span.unwrap()],
95 TooManyDerivatives { derivatives }
96 ));
97 }
98
99 let surrounded = input.try_parse::<ParenDelimited<_>>().forward_errors(recoverable_errors)?;
100
101 let span = name.span.start..surrounded.close.span.end;
103 Ok((Primary::Call(Self {
104 name,
105 derivatives: derivatives as u8,
106 args: surrounded.value.values,
107 span,
108 paren_span: surrounded.open.span.start..surrounded.close.span.end,
109 }), true))
110 }
111}
112
113impl std::fmt::Display for Call {
114 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
115 self.name.fmt(f)?;
116 for _ in 0..self.derivatives {
117 write!(f, "'")?;
118 }
119 write!(f, "(")?;
120 if let Some((last, args)) = self.args.split_last() {
121 for arg in args {
122 arg.fmt(f)?;
123 write!(f, ", ")?;
124 }
125 last.fmt(f)?;
126 }
127 write!(f, ")")
128 }
129}
130
131impl Latex for Call {
132 fn fmt_latex(&self, f: &mut fmt::Formatter) -> fmt::Result {
133 enum SpecialFunc {
134 Pow,
135 Root,
136 Cbrt,
137 Sqrt,
138 Abs,
139 Other,
140 }
141
142 impl SpecialFunc {
143 fn name(&self, f: &mut fmt::Formatter, call: &Call) -> fmt::Result {
145 match self {
146 Self::Pow => Ok(()),
147 Self::Root => write!(f, "\\sqrt"),
148 Self::Cbrt => write!(f, "\\sqrt[3]"),
149 Self::Sqrt => write!(f, "\\sqrt"),
150 Self::Abs => Ok(()),
151 Self::Other => write!(f, "\\mathrm{{ {} }}", call.name.as_display()),
152 }
153 }
154
155 fn outer_args(&self, f: &mut fmt::Formatter, call: &Call) -> fmt::Result {
157 match self {
158 Self::Pow => self.inner_args(f, call),
159 Self::Root => {
160 self.inner_args(f, call)?;
161 write!(f, "}}")
162 },
163 Self::Cbrt | Self::Sqrt => {
164 write!(f, "{{")?;
165 self.inner_args(f, call)?;
166 write!(f, "}}")
167 },
168 Self::Abs => {
169 write!(f, "\\left|")?;
170 self.inner_args(f, call)?;
171 write!(f, "\\right|")
172 },
173 Self::Other => {
174 write!(f, "\\left(")?;
175 self.inner_args(f, call)?;
176 write!(f, "\\right)")
177 },
178 }
179 }
180
181 fn inner_args(&self, f: &mut fmt::Formatter, call: &Call) -> fmt::Result {
182 match self {
183 Self::Pow => fmt_pow(f, call.args.first(), call.args.get(1))?,
184 Self::Root => {
185 if let Some(arg1) = call.args.get(1) {
186 write!(f, "[{}]", arg1.as_display())?;
187 }
188 write!(f, "{{")?;
189 if let Some(arg0) = call.args.first() {
190 arg0.fmt_latex(f)?;
191 }
192 },
193 Self::Cbrt | Self::Sqrt | Self::Abs | Self::Other => {
194 if let Some((last, args)) = call.args.split_last() {
195 for arg in args {
196 arg.fmt_latex(f)?;
197 write!(f, ", ")?;
198 }
199 last.fmt_latex(f)?;
200 }
201 },
202 }
203
204 Ok(())
205 }
206 }
207
208 let func = match self.name.name.as_str() {
209 "pow" => SpecialFunc::Pow,
210 "root" => SpecialFunc::Root,
211 "cbrt" => SpecialFunc::Cbrt,
212 "sqrt" => SpecialFunc::Sqrt,
213 "abs" => SpecialFunc::Abs,
214 _ => SpecialFunc::Other,
215 };
216
217 func.name(f, self)?;
218 match self.derivatives {
219 0 => {},
220 1 => write!(f, "'")?,
221 2 => write!(f, "''")?,
222 n => write!(f, "^{{ ({}) }}", n)?,
223 }
224
225 func.outer_args(f, self)
226 }
227}