1use crate::ast::{Expression, FormatStringPart, MatchArm, SelectArm, SelectArmPattern};
2
3pub(crate) trait AstFolder {
4 type Error;
5
6 fn fold_module(
7 &mut self,
8 expressions: Vec<Expression>,
9 ) -> Result<Vec<Expression>, Self::Error> {
10 expressions
11 .into_iter()
12 .map(|e| self.fold_expression(e))
13 .collect()
14 }
15
16 fn fold_expression(&mut self, expression: Expression) -> Result<Expression, Self::Error> {
17 self.fold_expression_default(expression)
18 }
19
20 fn fold_expression_default(
21 &mut self,
22 expression: Expression,
23 ) -> Result<Expression, Self::Error> {
24 use Expression::*;
25
26 Ok(match expression {
27 Binary {
28 operator,
29 left,
30 right,
31 ty,
32 span,
33 } => Binary {
34 operator,
35 left: Box::new(self.fold_expression(*left)?),
36 right: Box::new(self.fold_expression(*right)?),
37 ty,
38 span,
39 },
40
41 Call {
42 expression,
43 args,
44 type_args,
45 ty,
46 span,
47 } => Call {
48 expression: Box::new(self.fold_expression(*expression)?),
49 args: self.fold_vec(args)?,
50 type_args,
51 ty,
52 span,
53 },
54
55 Block { items, ty, span } => Block {
56 items: self.fold_vec(items)?,
57 ty,
58 span,
59 },
60
61 TryBlock {
62 items,
63 ty,
64 try_keyword_span,
65 span,
66 } => TryBlock {
67 items: self.fold_vec(items)?,
68 ty,
69 try_keyword_span,
70 span,
71 },
72
73 RecoverBlock {
74 items,
75 ty,
76 recover_keyword_span,
77 span,
78 } => RecoverBlock {
79 items: self.fold_vec(items)?,
80 ty,
81 recover_keyword_span,
82 span,
83 },
84
85 If {
86 condition,
87 consequence,
88 alternative,
89 ty,
90 span,
91 } => If {
92 condition: Box::new(self.fold_expression(*condition)?),
93 consequence: Box::new(self.fold_expression(*consequence)?),
94 alternative: Box::new(self.fold_expression(*alternative)?),
95 ty,
96 span,
97 },
98
99 IfLet {
100 pattern,
101 scrutinee,
102 consequence,
103 alternative,
104 typed_pattern,
105 else_span,
106 ty,
107 span,
108 } => IfLet {
109 pattern,
110 scrutinee: Box::new(self.fold_expression(*scrutinee)?),
111 consequence: Box::new(self.fold_expression(*consequence)?),
112 alternative: Box::new(self.fold_expression(*alternative)?),
113 typed_pattern,
114 else_span,
115 ty,
116 span,
117 },
118
119 Match {
120 subject,
121 arms,
122 origin,
123 ty,
124 span,
125 } => Match {
126 subject: Box::new(self.fold_expression(*subject)?),
127 arms: arms
128 .into_iter()
129 .map(|arm| self.fold_match_arm(arm))
130 .collect::<Result<_, _>>()?,
131 origin,
132 ty,
133 span,
134 },
135
136 Let {
137 binding,
138 value,
139 mutable,
140 mut_span,
141 else_block,
142 else_span,
143 typed_pattern,
144 ty,
145 span,
146 } => Let {
147 binding,
148 value: Box::new(self.fold_expression(*value)?),
149 mutable,
150 mut_span,
151 else_block: else_block
152 .map(|e| self.fold_expression(*e).map(Box::new))
153 .transpose()?,
154 else_span,
155 typed_pattern,
156 ty,
157 span,
158 },
159
160 Return {
161 expression,
162 ty,
163 span,
164 } => Return {
165 expression: Box::new(self.fold_expression(*expression)?),
166 ty,
167 span,
168 },
169
170 Propagate {
171 expression,
172 ty,
173 span,
174 } => Propagate {
175 expression: Box::new(self.fold_expression(*expression)?),
176 ty,
177 span,
178 },
179
180 Unary {
181 operator,
182 expression,
183 ty,
184 span,
185 } => Unary {
186 operator,
187 expression: Box::new(self.fold_expression(*expression)?),
188 ty,
189 span,
190 },
191
192 Paren {
193 expression,
194 ty,
195 span,
196 } => Paren {
197 expression: Box::new(self.fold_expression(*expression)?),
198 ty,
199 span,
200 },
201
202 DotAccess {
203 expression,
204 member,
205 ty,
206 span,
207 } => DotAccess {
208 expression: Box::new(self.fold_expression(*expression)?),
209 member,
210 ty,
211 span,
212 },
213
214 IndexedAccess {
215 expression,
216 index,
217 ty,
218 span,
219 } => IndexedAccess {
220 expression: Box::new(self.fold_expression(*expression)?),
221 index: Box::new(self.fold_expression(*index)?),
222 ty,
223 span,
224 },
225
226 Assignment {
227 target,
228 value,
229 compound_operator,
230 span,
231 } => Assignment {
232 target: Box::new(self.fold_expression(*target)?),
233 value: Box::new(self.fold_expression(*value)?),
234 compound_operator,
235 span,
236 },
237
238 Tuple { elements, ty, span } => Tuple {
239 elements: self.fold_vec(elements)?,
240 ty,
241 span,
242 },
243
244 StructCall {
245 name,
246 field_assignments,
247 spread,
248 ty,
249 span,
250 } => StructCall {
251 name,
252 field_assignments: field_assignments
253 .into_iter()
254 .map(|mut f| {
255 f.value = Box::new(self.fold_expression(*f.value)?);
256 Ok(f)
257 })
258 .collect::<Result<_, Self::Error>>()?,
259 spread: Box::new((*spread).map(|e| self.fold_expression(e)).transpose()?),
260 ty,
261 span,
262 },
263
264 Function {
265 doc,
266 attributes,
267 name,
268 name_span,
269 generics,
270 params,
271 return_annotation,
272 return_type,
273 visibility,
274 body,
275 ty,
276 span,
277 } => Function {
278 doc,
279 attributes,
280 name,
281 name_span,
282 generics,
283 params,
284 return_annotation,
285 return_type,
286 visibility,
287 body: Box::new(self.fold_expression(*body)?),
288 ty,
289 span,
290 },
291
292 Lambda {
293 params,
294 return_annotation,
295 body,
296 ty,
297 span,
298 } => Lambda {
299 params,
300 return_annotation,
301 body: Box::new(self.fold_expression(*body)?),
302 ty,
303 span,
304 },
305
306 Reference {
307 expression,
308 ty,
309 span,
310 } => Reference {
311 expression: Box::new(self.fold_expression(*expression)?),
312 ty,
313 span,
314 },
315
316 For {
317 binding,
318 iterable,
319 body,
320 span,
321 needs_label,
322 } => For {
323 binding,
324 iterable: Box::new(self.fold_expression(*iterable)?),
325 body: Box::new(self.fold_expression(*body)?),
326 span,
327 needs_label,
328 },
329
330 While {
331 condition,
332 body,
333 span,
334 needs_label,
335 } => While {
336 condition: Box::new(self.fold_expression(*condition)?),
337 body: Box::new(self.fold_expression(*body)?),
338 span,
339 needs_label,
340 },
341
342 WhileLet {
343 pattern,
344 scrutinee,
345 body,
346 typed_pattern,
347 span,
348 needs_label,
349 } => WhileLet {
350 pattern,
351 scrutinee: Box::new(self.fold_expression(*scrutinee)?),
352 body: Box::new(self.fold_expression(*body)?),
353 typed_pattern,
354 span,
355 needs_label,
356 },
357
358 Loop {
359 body,
360 ty,
361 span,
362 needs_label,
363 } => Loop {
364 body: Box::new(self.fold_expression(*body)?),
365 ty,
366 span,
367 needs_label,
368 },
369
370 Task {
371 expression,
372 ty,
373 span,
374 } => Task {
375 expression: Box::new(self.fold_expression(*expression)?),
376 ty,
377 span,
378 },
379
380 Defer {
381 expression,
382 ty,
383 span,
384 } => Defer {
385 expression: Box::new(self.fold_expression(*expression)?),
386 ty,
387 span,
388 },
389
390 Select { arms, ty, span } => Select {
391 arms: arms
392 .into_iter()
393 .map(|arm| self.fold_select_arm(arm))
394 .collect::<Result<_, _>>()?,
395 ty,
396 span,
397 },
398
399 ImplBlock {
400 annotation,
401 receiver_name,
402 methods,
403 generics,
404 ty,
405 span,
406 } => ImplBlock {
407 annotation,
408 receiver_name,
409 methods: self.fold_vec(methods)?,
410 generics,
411 ty,
412 span,
413 },
414
415 Const {
416 doc,
417 identifier,
418 identifier_span,
419 annotation,
420 expression,
421 visibility,
422 ty,
423 span,
424 } => Const {
425 doc,
426 identifier,
427 identifier_span,
428 annotation,
429 expression: Box::new(self.fold_expression(*expression)?),
430 visibility,
431 ty,
432 span,
433 },
434
435 Cast {
436 expression,
437 target_type,
438 ty,
439 span,
440 } => Cast {
441 expression: Box::new(self.fold_expression(*expression)?),
442 target_type,
443 ty,
444 span,
445 },
446
447 Break { value, span } => Break {
448 value: match value {
449 Some(v) => Some(Box::new(self.fold_expression(*v)?)),
450 None => None,
451 },
452 span,
453 },
454
455 Literal {
456 literal: crate::ast::Literal::FormatString(parts),
457 ty,
458 span,
459 } => {
460 let folded_parts = parts
461 .into_iter()
462 .map(|part| match part {
463 FormatStringPart::Expression(expression) => {
464 Ok(FormatStringPart::Expression(Box::new(
465 self.fold_expression(*expression)?,
466 )))
467 }
468 other => Ok(other),
469 })
470 .collect::<Result<Vec<_>, Self::Error>>()?;
471 Literal {
472 literal: crate::ast::Literal::FormatString(folded_parts),
473 ty,
474 span,
475 }
476 }
477
478 Range {
479 start,
480 end,
481 inclusive,
482 ty,
483 span,
484 } => Range {
485 start: start
486 .map(|e| self.fold_expression(*e).map(Box::new))
487 .transpose()?,
488 end: end
489 .map(|e| self.fold_expression(*e).map(Box::new))
490 .transpose()?,
491 inclusive,
492 ty,
493 span,
494 },
495
496 Literal { .. }
497 | Identifier { .. }
498 | Enum { .. }
499 | ValueEnum { .. }
500 | Struct { .. }
501 | TypeAlias { .. }
502 | VariableDeclaration { .. }
503 | ModuleImport { .. }
504 | Interface { .. }
505 | Continue { .. }
506 | Unit { .. }
507 | RawGo { .. }
508 | NoOp => expression,
509 })
510 }
511
512 fn fold_vec(&mut self, expressions: Vec<Expression>) -> Result<Vec<Expression>, Self::Error> {
513 expressions
514 .into_iter()
515 .map(|e| self.fold_expression(e))
516 .collect()
517 }
518
519 fn fold_match_arm(&mut self, mut arm: MatchArm) -> Result<MatchArm, Self::Error> {
520 arm.expression = Box::new(self.fold_expression(*arm.expression)?);
521 arm.guard = arm
522 .guard
523 .map(|g| self.fold_expression(*g).map(Box::new))
524 .transpose()?;
525 Ok(arm)
526 }
527
528 fn fold_select_arm(&mut self, arm: SelectArm) -> Result<SelectArm, Self::Error> {
529 let pattern = match arm.pattern {
530 SelectArmPattern::Receive {
531 binding,
532 typed_pattern,
533 receive_expression,
534 body,
535 } => SelectArmPattern::Receive {
536 binding,
537 typed_pattern,
538 receive_expression: Box::new(self.fold_expression(*receive_expression)?),
539 body: Box::new(self.fold_expression(*body)?),
540 },
541 SelectArmPattern::Send {
542 send_expression,
543 body,
544 } => SelectArmPattern::Send {
545 send_expression: Box::new(self.fold_expression(*send_expression)?),
546 body: Box::new(self.fold_expression(*body)?),
547 },
548 SelectArmPattern::MatchReceive {
549 receive_expression,
550 arms,
551 } => SelectArmPattern::MatchReceive {
552 receive_expression: Box::new(self.fold_expression(*receive_expression)?),
553 arms: arms
554 .into_iter()
555 .map(|arm| self.fold_match_arm(arm))
556 .collect::<Result<_, _>>()?,
557 },
558 SelectArmPattern::WildCard { body } => SelectArmPattern::WildCard {
559 body: Box::new(self.fold_expression(*body)?),
560 },
561 };
562 Ok(SelectArm { pattern })
563 }
564}