1use cas_error::Error;
2use crate::{
3 parser::{
4 ast::{
5 expr::{Atom, Expr, Primary},
6 helper::Surrounded,
7 index::Index,
8 literal::{Literal, LitSym},
9 },
10 error::{
11 CompoundAssignmentInHeader,
12 DefaultArgumentNotLast,
13 ExpectedExpr,
14 InvalidAssignmentLhs,
15 InvalidCompoundAssignmentLhs,
16 },
17 fmt::Latex,
18 garbage::Garbage,
19 token::{op::AssignOp, Comma, OpenParen},
20 Parse,
21 Parser,
22 ParseResult,
23 },
24 tokenizer::TokenKind,
25};
26use std::{fmt, ops::Range};
27
28#[cfg(feature = "serde")]
29use serde::{Deserialize, Serialize};
30
31#[derive(Debug, Clone, PartialEq, Eq)]
34#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
35pub enum Param {
36 Symbol(LitSym),
38
39 Default(LitSym, Expr),
41}
42
43impl Param {
44 pub fn span(&self) -> Range<usize> {
46 match self {
47 Param::Symbol(symbol) => symbol.span.clone(),
48 Param::Default(symbol, default) => symbol.span.start..default.span().end,
49 }
50 }
51
52 pub fn symbol(&self) -> &LitSym {
54 match self {
55 Param::Symbol(symbol) => symbol,
56 Param::Default(symbol, _) => symbol,
57 }
58 }
59
60 pub fn has_default(&self) -> bool {
62 matches!(self, Param::Default(_, _))
63 }
64}
65
66impl<'source> Parse<'source> for Param {
67 fn std_parse(
68 input: &mut Parser<'source>,
69 recoverable_errors: &mut Vec<Error>
70 ) -> Result<Self, Vec<Error>> {
71 let symbol = input.try_parse().forward_errors(recoverable_errors)?;
72
73 if let Ok(assign) = input.try_parse::<AssignOp>().forward_errors(recoverable_errors) {
74 if assign.is_compound() {
75 recoverable_errors.push(Error::new(
76 vec![assign.span.clone()],
77 CompoundAssignmentInHeader,
78 ));
79 }
80 let default = input.try_parse().forward_errors(recoverable_errors)?;
81 Ok(Param::Default(symbol, default))
82 } else {
83 Ok(Param::Symbol(symbol))
84 }
85 }
86}
87
88impl std::fmt::Display for Param {
89 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
90 match self {
91 Param::Symbol(symbol) => write!(f, "{}", symbol),
92 Param::Default(symbol, default) => write!(f, "{} = {}", symbol, default),
93 }
94 }
95}
96
97impl Latex for Param {
98 fn fmt_latex(&self, f: &mut fmt::Formatter) -> fmt::Result {
99 match self {
100 Param::Symbol(symbol) => symbol.fmt_latex(f),
101 Param::Default(symbol, default) => write!(f, "{} = {}", symbol.as_display(), default.as_display()),
102 }
103 }
104}
105
106#[derive(Debug, Clone, PartialEq, Eq)]
111#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
112pub struct FuncHeader {
113 pub name: LitSym,
115
116 pub params: Vec<Param>,
118
119 pub span: Range<usize>,
121}
122
123impl FuncHeader {
124 pub fn span(&self) -> Range<usize> {
126 self.span.clone()
127 }
128
129 fn parse_or_lower(
131 input: &mut Parser,
132 recoverable_errors: &mut Vec<Error>,
133 name: LitSym,
134 ) -> Result<Self, Vec<Error>> {
135 struct FuncHeaderInner {
138 values: Vec<Param>,
139 }
140
141 impl<'source> Parse<'source> for FuncHeaderInner {
142 fn std_parse(
143 input: &mut Parser<'source>,
144 recoverable_errors: &mut Vec<Error>
145 ) -> Result<Self, Vec<Error>> {
146 let mut bad_default_position = false;
147 let mut default_params = Vec::new();
148 let mut values = Vec::new();
149
150 loop {
151 let Ok(value) = input.try_parse().forward_errors(recoverable_errors) else {
152 break;
153 };
154
155 if !default_params.is_empty() && !bad_default_position {
158 if let Param::Symbol(_) = value {
159 bad_default_position = true;
160 }
161 }
162
163 if let Param::Default(_, _) = value {
164 default_params.push(value.span());
165 }
166
167 values.push(value);
168
169 if input.try_parse::<Comma>().forward_errors(recoverable_errors).is_err() {
170 break;
171 }
172 }
173
174 if bad_default_position {
175 recoverable_errors.push(Error::new(
176 default_params,
177 DefaultArgumentNotLast,
178 ));
179 }
180
181 Ok(Self { values })
182 }
183 }
184
185 let surrounded = input.try_parse::<Surrounded<OpenParen, FuncHeaderInner>>()
186 .forward_errors(recoverable_errors)?;
187
188 let span = name.span.start..surrounded.close.span.end;
189 Ok(Self { name, params: surrounded.value.values, span })
190 }
191}
192
193impl std::fmt::Display for FuncHeader {
194 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
195 write!(f, "{}(", self.name)?;
196 if let Some((last, rest)) = self.params.split_last() {
197 for param in rest {
198 write!(f, "{}, ", param)?;
199 }
200 write!(f, "{}", last)?;
201 }
202 write!(f, ")")
203 }
204}
205
206impl Latex for FuncHeader {
207 fn fmt_latex(&self, f: &mut fmt::Formatter) -> fmt::Result {
208 write!(f, "\\mathrm{{ {} }} \\left(", self.name.as_display())?;
209 if let Some((last, rest)) = self.params.split_last() {
210 for param in rest {
211 param.fmt_latex(f)?;
212 write!(f, ", ")?;
213 }
214 last.fmt_latex(f)?;
215 }
216 write!(f, "\\right)")
217 }
218}
219
220#[derive(Debug, Clone, PartialEq, Eq)]
222#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
223pub enum AssignTarget {
224 Symbol(LitSym),
226
227 Index(Index),
229
230 Func(FuncHeader),
232}
233
234impl AssignTarget {
235 pub fn span(&self) -> Range<usize> {
237 match self {
238 AssignTarget::Symbol(symbol) => symbol.span.clone(),
239 AssignTarget::Index(index) => index.span(),
240 AssignTarget::Func(func) => func.span(),
241 }
242 }
243
244 pub fn is_func(&self) -> bool {
246 matches!(self, AssignTarget::Func(_))
247 }
248
249 pub fn try_from_with_op(expr: Expr, op: &AssignOp) -> ParseResult<Self> {
252 let op_span = op.span.clone();
253 match expr {
254 Expr::Literal(Literal::Symbol(symbol)) => ParseResult::Ok(AssignTarget::Symbol(symbol)),
255 Expr::Index(index) => ParseResult::Ok(AssignTarget::Index(index)),
256 Expr::Call(call) => {
257 let spans = vec![call.span.clone(), op_span.clone()];
258 let error = if op.is_compound() {
259 Error::new(spans, InvalidCompoundAssignmentLhs)
260 } else {
261 Error::new(spans, InvalidAssignmentLhs { is_call: true })
262 };
263
264 ParseResult::Recoverable(Garbage::garbage(), vec![error])
265 },
266 expr => {
267 let spans = vec![expr.span(), op_span.clone()];
268 let error = if op.is_compound() {
269 Error::new(spans, InvalidCompoundAssignmentLhs)
270 } else {
271 Error::new(spans, InvalidAssignmentLhs { is_call: false })
272 };
273
274 ParseResult::Recoverable(
275 Garbage::garbage(),
276 vec![error]
277 )
278 },
279 }
280 }
281}
282
283impl From<LitSym> for AssignTarget {
284 fn from(symbol: LitSym) -> Self {
285 AssignTarget::Symbol(symbol)
286 }
287}
288
289impl From<Index> for AssignTarget {
290 fn from(index: Index) -> Self {
291 AssignTarget::Index(index)
292 }
293}
294
295impl From<FuncHeader> for AssignTarget {
296 fn from(func: FuncHeader) -> Self {
297 AssignTarget::Func(func)
298 }
299}
300
301impl<'source> Parse<'source> for AssignTarget {
302 fn std_parse(
303 input: &mut Parser<'source>,
304 recoverable_errors: &mut Vec<Error>
305 ) -> Result<Self, Vec<Error>> {
306 let atom = input.try_parse::<Atom>().forward_errors(recoverable_errors)?;
310
311 let mut fork = input.clone();
312 match fork.next_token() {
313 Ok(next) if next.kind == TokenKind::OpenParen => {
314 if let Atom::Literal(Literal::Symbol(symbol)) = atom {
315 Ok(FuncHeader::parse_or_lower(input, recoverable_errors, symbol)
316 .map(Into::into)?)
317 } else {
318 Err(vec![input.error(ExpectedExpr { expected: "a symbol" })])
319 }
320 },
321 Ok(next) if next.kind == TokenKind::OpenSquare => {
322 match Index::parse_or_lower(input, recoverable_errors, atom.into()) {
323 (new_primary, true) => match new_primary {
324 Primary::Index(index) => Ok(AssignTarget::Index(index)),
325 _ => unreachable!(),
326 },
327 (unchanged_primary, false) => match unchanged_primary {
328 Primary::Literal(Literal::Symbol(symbol)) => Ok(AssignTarget::Symbol(symbol)),
329 _ => unreachable!(),
330 },
331 }
332 },
333 _ => if let Atom::Literal(Literal::Symbol(symbol)) = atom {
334 Ok(AssignTarget::Symbol(symbol))
335 } else {
336 Err(vec![input.error(ExpectedExpr { expected: "a symbol" })])
337 },
338 }
339 }
340}
341
342impl std::fmt::Display for AssignTarget {
343 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
344 match self {
345 AssignTarget::Symbol(symbol) => write!(f, "{}", symbol),
346 AssignTarget::Index(index) => write!(f, "{}", index),
347 AssignTarget::Func(func) => write!(f, "{}", func),
348 }
349 }
350}
351
352impl Latex for AssignTarget {
353 fn fmt_latex(&self, f: &mut fmt::Formatter) -> fmt::Result {
354 match self {
355 AssignTarget::Symbol(symbol) => symbol.fmt_latex(f),
356 AssignTarget::Index(index) => index.fmt_latex(f),
357 AssignTarget::Func(func) => func.fmt_latex(f),
358 }
359 }
360}
361
362#[derive(Debug, Clone, PartialEq, Eq)]
364#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
365pub struct Assign {
366 pub target: AssignTarget,
368
369 pub op: AssignOp,
371
372 pub value: Box<Expr>,
374
375 pub span: Range<usize>,
377}
378
379impl Assign {
380 pub fn span(&self) -> Range<usize> {
382 self.span.clone()
383 }
384}
385
386impl<'source> Parse<'source> for Assign {
387 fn std_parse(
388 input: &mut Parser<'source>,
389 recoverable_errors: &mut Vec<Error>
390 ) -> Result<Self, Vec<Error>> {
391 let target = input.try_parse().forward_errors(recoverable_errors)?;
392 let op = input.try_parse::<AssignOp>().forward_errors(recoverable_errors)?;
393
394 let value = if matches!(target, AssignTarget::Func(_)) {
395 if op.is_compound() {
396 recoverable_errors.push(Error::new(
401 vec![target.span(), op.span.clone()],
402 InvalidCompoundAssignmentLhs,
403 ));
404 }
405
406 input.try_parse_with_state::<_, Expr>(|state| {
407 state.allow_loop_control = false;
414 state.allow_return = true;
415 }).forward_errors(recoverable_errors)?
416 } else {
417 input.try_parse::<Expr>().forward_errors(recoverable_errors)?
418 };
419
420 let span = target.span().start..value.span().end;
421 Ok(Self {
422 target,
423 op,
424 value: Box::new(value),
425 span,
426 })
427 }
428}
429
430impl std::fmt::Display for Assign {
431 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
432 write!(
433 f,
434 "{} {} {}",
435 self.target,
436 self.op,
437 self.value,
438 )
439 }
440}
441
442impl Latex for Assign {
443 fn fmt_latex(&self, f: &mut fmt::Formatter) -> fmt::Result {
444 write!(
445 f,
446 "{} {} {}",
447 self.target.as_display(),
448 self.op.as_display(),
449 self.value.as_display(),
450 )
451 }
452}