1use crate::{
2 parser::{
3 ast::{expr::Expr, helper::ParenDelimited, literal::{Literal, LitSym}},
4 error::{kind::{CompoundAssignmentInHeader, InvalidAssignmentLhs, InvalidCompoundAssignmentLhs}, Error},
5 fmt::Latex,
6 garbage::Garbage,
7 token::op::AssignOp,
8 Parse,
9 Parser,
10 ParseResult,
11 },
12 return_if_ok,
13};
14use std::{fmt, ops::Range};
15
16#[cfg(feature = "serde")]
17use serde::{Deserialize, Serialize};
18
/// A single parameter in a function header, such as `x` or `y = 2` in `f(x, y = 2)`.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum Param {
    /// A parameter with no default value, e.g. `x`.
    Symbol(LitSym),

    /// A parameter with a default value, e.g. `y = 2`. The [`Expr`] is the default value.
    Default(LitSym, Expr),
}
30
31impl Param {
32 pub fn symbol(&self) -> &LitSym {
34 match self {
35 Param::Symbol(symbol) => symbol,
36 Param::Default(symbol, _) => symbol,
37 }
38 }
39}
40
41impl<'source> Parse<'source> for Param {
42 fn std_parse(
43 input: &mut Parser<'source>,
44 recoverable_errors: &mut Vec<Error>
45 ) -> Result<Self, Vec<Error>> {
46 let symbol = input.try_parse().forward_errors(recoverable_errors)?;
47
48 if let Ok(assign) = input.try_parse::<AssignOp>().forward_errors(recoverable_errors) {
49 if assign.is_compound() {
50 recoverable_errors.push(Error::new(
51 vec![assign.span.clone()],
52 CompoundAssignmentInHeader,
53 ));
54 }
55 let default = input.try_parse().forward_errors(recoverable_errors)?;
56 Ok(Param::Default(symbol, default))
57 } else {
58 Ok(Param::Symbol(symbol))
59 }
60 }
61}
62
63impl std::fmt::Display for Param {
64 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
65 match self {
66 Param::Symbol(symbol) => write!(f, "{}", symbol),
67 Param::Default(symbol, default) => write!(f, "{} = {}", symbol, default),
68 }
69 }
70}
71
72impl Latex for Param {
73 fn fmt_latex(&self, f: &mut fmt::Formatter) -> fmt::Result {
74 match self {
75 Param::Symbol(symbol) => symbol.fmt_latex(f),
76 Param::Default(symbol, default) => write!(f, "{} = {}", symbol.as_display(), default.as_display()),
77 }
78 }
79}
80
/// The header of a function definition, such as `f(x, y = 2)` in `f(x, y = 2) = x + y`.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct FuncHeader {
    /// The name of the function.
    pub name: LitSym,

    /// The parameters of the function, in the order they appear in the header.
    pub params: Vec<Param>,

    /// The region of the source code covered by this header, from the start of the name to
    /// the end of the closing delimiter of the parameter list.
    pub span: Range<usize>,
}
97
98impl FuncHeader {
99 pub fn span(&self) -> Range<usize> {
101 self.span.clone()
102 }
103}
104
105impl<'source> Parse<'source> for FuncHeader {
106 fn std_parse(
107 input: &mut Parser<'source>,
108 recoverable_errors: &mut Vec<Error>
109 ) -> Result<Self, Vec<Error>> {
110 let name = input.try_parse::<LitSym>().forward_errors(recoverable_errors)?;
111 let surrounded = input.try_parse::<ParenDelimited<_>>().forward_errors(recoverable_errors)?;
112
113 let span = name.span.start..surrounded.close.span.end;
114 Ok(Self { name, params: surrounded.value.values, span })
115 }
116}
117
118impl std::fmt::Display for FuncHeader {
119 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
120 write!(f, "{}(", self.name)?;
121 if let Some((last, rest)) = self.params.split_last() {
122 for param in rest {
123 write!(f, "{}, ", param)?;
124 }
125 write!(f, "{}", last)?;
126 }
127 write!(f, ")")
128 }
129}
130
131impl Latex for FuncHeader {
132 fn fmt_latex(&self, f: &mut fmt::Formatter) -> fmt::Result {
133 write!(f, "\\mathrm{{ {} }} \\left(", self.name.as_display())?;
134 if let Some((last, rest)) = self.params.split_last() {
135 for param in rest {
136 param.fmt_latex(f)?;
137 write!(f, ", ")?;
138 }
139 last.fmt_latex(f)?;
140 }
141 write!(f, "\\right)")
142 }
143}
144
/// The left-hand side of an assignment.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum AssignTarget {
    /// A symbol, such as `x` in `x = 5`.
    Symbol(LitSym),

    /// A function header, such as `f(x)` in `f(x) = x + 1`, meaning a function is being defined.
    Func(FuncHeader),
}
155
156impl AssignTarget {
157 pub fn span(&self) -> Range<usize> {
159 match self {
160 AssignTarget::Symbol(symbol) => symbol.span.clone(),
161 AssignTarget::Func(func) => func.span(),
162 }
163 }
164
165 pub fn try_from_with_op(expr: Expr, op: &AssignOp) -> ParseResult<Self> {
168 let op_span = op.span.clone();
169 match expr {
170 Expr::Literal(Literal::Symbol(symbol)) => ParseResult::Ok(AssignTarget::Symbol(symbol)),
171 Expr::Call(call) => {
172 let spans = vec![call.span.clone(), op_span.clone()];
173 let error = if op.is_compound() {
174 Error::new(spans, InvalidCompoundAssignmentLhs)
175 } else {
176 Error::new(spans, InvalidAssignmentLhs { is_call: true })
177 };
178
179 ParseResult::Recoverable(Garbage::garbage(), vec![error])
180 },
181 expr => {
182 let spans = vec![expr.span(), op_span.clone()];
183 let error = if op.is_compound() {
184 Error::new(spans, InvalidCompoundAssignmentLhs)
185 } else {
186 Error::new(spans, InvalidAssignmentLhs { is_call: false })
187 };
188
189 ParseResult::Recoverable(
190 Garbage::garbage(),
191 vec![error]
192 )
193 },
194 }
195 }
196}
197
198impl<'source> Parse<'source> for AssignTarget {
199 fn std_parse(
200 input: &mut Parser<'source>,
201 recoverable_errors: &mut Vec<Error>
202 ) -> Result<Self, Vec<Error>> {
203 let _ = return_if_ok!(input.try_parse().map(AssignTarget::Func).forward_errors(recoverable_errors));
204 input.try_parse().map(AssignTarget::Symbol).forward_errors(recoverable_errors)
205 }
206}
207
208impl std::fmt::Display for AssignTarget {
209 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
210 match self {
211 AssignTarget::Symbol(symbol) => write!(f, "{}", symbol),
212 AssignTarget::Func(func) => write!(f, "{}", func),
213 }
214 }
215}
216
217impl Latex for AssignTarget {
218 fn fmt_latex(&self, f: &mut fmt::Formatter) -> fmt::Result {
219 match self {
220 AssignTarget::Symbol(symbol) => symbol.fmt_latex(f),
221 AssignTarget::Func(func) => func.fmt_latex(f),
222 }
223 }
224}
225
/// An assignment of a value to a symbol or a function, e.g. `x = 5` or `f(x) = x + 1`.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct Assign {
    /// The target being assigned to.
    pub target: AssignTarget,

    /// The assignment operator. This may be a compound operator (see `AssignOp::is_compound`).
    pub op: AssignOp,

    /// The value being assigned. For a function target, this is the function body.
    pub value: Box<Expr>,

    /// The region of the source code covered by this assignment, from the start of the target
    /// to the end of the value.
    pub span: Range<usize>,
}
242
243impl Assign {
244 pub fn span(&self) -> Range<usize> {
246 self.span.clone()
247 }
248
249 pub fn is_recursive(&self) -> bool {
251 if let AssignTarget::Func(header) = &self.target {
252 let is_correct_call = |expr: &Expr| {
253 match expr {
254 Expr::Call(call) => call.name.name == header.name.name,
255 _ => false,
256 }
257 };
258
259 self.value.post_order_iter().any(is_correct_call)
260 } else {
261 false
262 }
263 }
264}
265
266impl<'source> Parse<'source> for Assign {
267 fn std_parse(
268 input: &mut Parser<'source>,
269 recoverable_errors: &mut Vec<Error>
270 ) -> Result<Self, Vec<Error>> {
271 let target = input.try_parse().forward_errors(recoverable_errors)?;
272 let op = input.try_parse::<AssignOp>().forward_errors(recoverable_errors)?;
273
274 let value = if matches!(target, AssignTarget::Func(_)) {
275 if op.is_compound() {
276 recoverable_errors.push(Error::new(
281 vec![op.span.clone()],
282 InvalidCompoundAssignmentLhs,
283 ));
284 }
285
286 input.try_parse_with_state::<_, Expr>(|state| {
287 state.allow_loop_control = false;
294 }).forward_errors(recoverable_errors)?
295 } else {
296 input.try_parse::<Expr>().forward_errors(recoverable_errors)?
297 };
298
299 let span = target.span().start..value.span().end;
300 Ok(Self {
301 target,
302 op,
303 value: Box::new(value),
304 span,
305 })
306 }
307}
308
309impl std::fmt::Display for Assign {
310 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
311 write!(
312 f,
313 "{} {} {}",
314 self.target,
315 self.op,
316 self.value,
317 )
318 }
319}
320
321impl Latex for Assign {
322 fn fmt_latex(&self, f: &mut fmt::Formatter) -> fmt::Result {
323 write!(
324 f,
325 "{} {} {}",
326 self.target.as_display(),
327 self.op.as_display(),
328 self.value.as_display(),
329 )
330 }
331}