1use crate::{ArmPattern, Expr, ExprVisitor, MatchArm, MatchIdentifier, VariableId};
16use std::collections::HashMap;
17
18pub fn bind_variables_of_let_assignment(expr: &mut Expr) {
21 let mut identifier_id_state = IdentifierVariableIdState::new();
22 let mut visitor = ExprVisitor::bottom_up(expr);
23
24 while let Some(expr) = visitor.pop_front() {
26 match expr {
27 Expr::Let { variable_id, .. } => {
28 let field_name = variable_id.name();
29 identifier_id_state.update_variable_id(&field_name); if let Some(latest_variable_id) = identifier_id_state.lookup(&field_name) {
31 *variable_id = latest_variable_id.clone();
32 }
33 }
34
35 Expr::Identifier { variable_id, .. } if !variable_id.is_match_binding() => {
36 let field_name = variable_id.name();
37 if let Some(latest_variable_id) = identifier_id_state.lookup(&field_name) {
38 *variable_id = latest_variable_id.clone();
39 }
40 }
41 _ => {}
42 }
43 }
44}
45
46pub fn bind_variables_of_list_comprehension(expr: &mut Expr) {
47 let mut visitor = ExprVisitor::top_down(expr);
48
49 while let Some(expr) = visitor.pop_front() {
50 if let Expr::ListComprehension {
51 iterated_variable,
52 yield_expr,
53 ..
54 } = expr
55 {
56 *iterated_variable =
57 VariableId::list_comprehension_identifier(iterated_variable.name());
58
59 process_yield_expr_in_comprehension(iterated_variable, yield_expr)
60 }
61 }
62}
63
64pub fn bind_variables_of_list_reduce(expr: &mut Expr) {
65 let mut visitor = ExprVisitor::top_down(expr);
66
67 while let Some(expr) = visitor.pop_front() {
69 if let Expr::ListReduce {
70 reduce_variable,
71 iterated_variable,
72 yield_expr,
73 ..
74 } = expr
75 {
76 *iterated_variable =
79 VariableId::list_comprehension_identifier(iterated_variable.name());
80
81 *reduce_variable = VariableId::list_reduce_identifier(reduce_variable.name());
82
83 process_yield_expr_in_reduce(reduce_variable, iterated_variable, yield_expr)
84 }
85 }
86}
87
88pub fn bind_variables_of_pattern_match(expr: &mut Expr) {
89 bind_variables_in_pattern_match_internal(expr, 0, &mut []);
90}
91
92fn bind_variables_in_pattern_match_internal(
93 expr: &mut Expr,
94 previous_index: usize,
95 match_identifiers: &mut [MatchIdentifier],
96) -> usize {
97 let mut index = previous_index;
98 let mut queue = ExprVisitor::top_down(expr);
99 let mut shadowed_let_binding = vec![];
100
101 while let Some(expr) = queue.pop_front() {
103 match expr {
104 Expr::PatternMatch { match_arms, .. } => {
105 for arm in match_arms {
106 index += 1;
108 let latest = process_arm(arm, index);
109 index = latest
112 }
113 }
114 Expr::Let { variable_id, .. } => {
115 shadowed_let_binding.push(variable_id.name());
116 }
117 Expr::Identifier { variable_id, .. } => {
118 let identifier_name = variable_id.name();
119 if let Some(x) = match_identifiers.iter().find(|x| x.name == identifier_name) {
120 if !shadowed_let_binding.contains(&identifier_name) {
121 *variable_id = VariableId::MatchIdentifier(x.clone());
122 }
123 }
124 }
125
126 _ => {}
127 }
128 }
129
130 index
131}
132
133fn process_arm(match_arm: &mut MatchArm, global_arm_index: usize) -> usize {
134 let match_arm_pattern = &mut match_arm.arm_pattern;
135
136 pub fn go(
137 arm_pattern: &mut ArmPattern,
138 global_arm_index: usize,
139 match_identifiers: &mut Vec<MatchIdentifier>,
140 ) {
141 match arm_pattern {
142 ArmPattern::Literal(expr) => {
143 let new_match_identifiers =
144 update_all_identifier_in_lhs_expr(expr, global_arm_index);
145 match_identifiers.extend(new_match_identifiers);
146 }
147
148 ArmPattern::WildCard => {}
149 ArmPattern::As(name, arm_pattern) => {
150 let match_identifier = MatchIdentifier::new(name.clone(), global_arm_index);
151 match_identifiers.push(match_identifier);
152
153 go(arm_pattern, global_arm_index, match_identifiers);
154 }
155
156 ArmPattern::Constructor(_, arm_patterns) => {
157 for arm_pattern in arm_patterns {
158 go(arm_pattern, global_arm_index, match_identifiers);
159 }
160 }
161
162 ArmPattern::TupleConstructor(arm_patterns) => {
163 for arm_pattern in arm_patterns {
164 go(arm_pattern, global_arm_index, match_identifiers);
165 }
166 }
167
168 ArmPattern::ListConstructor(arm_patterns) => {
169 for arm_pattern in arm_patterns {
170 go(arm_pattern, global_arm_index, match_identifiers);
171 }
172 }
173
174 ArmPattern::RecordConstructor(fields) => {
175 for (_, arm_pattern) in fields {
176 go(arm_pattern, global_arm_index, match_identifiers);
177 }
178 }
179 }
180 }
181
182 let mut match_identifiers = vec![];
183
184 go(match_arm_pattern, global_arm_index, &mut match_identifiers);
186
187 let resolution_expression = &mut *match_arm.arm_resolution_expr;
188
189 bind_variables_in_pattern_match_internal(
192 resolution_expression,
193 global_arm_index,
194 &mut match_identifiers,
195 )
196}
197
198fn update_all_identifier_in_lhs_expr(
199 expr: &mut Expr,
200 global_arm_index: usize,
201) -> Vec<MatchIdentifier> {
202 let mut identifier_names = vec![];
203 let mut visitor = ExprVisitor::bottom_up(expr);
204
205 while let Some(expr) = visitor.pop_front() {
206 if let Expr::Identifier { variable_id, .. } = expr {
207 let match_identifier = MatchIdentifier::new(variable_id.name(), global_arm_index);
208 identifier_names.push(match_identifier);
209 let new_variable_id =
210 VariableId::match_identifier(variable_id.name(), global_arm_index);
211 *variable_id = new_variable_id;
212 }
213 }
214
215 identifier_names
216}
217
218fn process_yield_expr_in_comprehension(variable: &mut VariableId, yield_expr: &mut Expr) {
219 let mut visitor = ExprVisitor::top_down(yield_expr);
220
221 while let Some(expr) = visitor.pop_front() {
222 if let Expr::Identifier { variable_id, .. } = expr {
223 if variable.name() == variable_id.name() {
224 *variable_id = variable.clone();
225 }
226 }
227 }
228}
229
230fn process_yield_expr_in_reduce(
231 reduce_variable: &mut VariableId,
232 iterated_variable_id: &mut VariableId,
233 yield_expr: &mut Expr,
234) {
235 let mut visitor = ExprVisitor::top_down(yield_expr);
236
237 while let Some(expr) = visitor.pop_front() {
238 if let Expr::Identifier { variable_id, .. } = expr {
239 if iterated_variable_id.name() == variable_id.name() {
240 *variable_id = iterated_variable_id.clone();
241 } else if reduce_variable.name() == variable_id.name() {
242 *variable_id = reduce_variable.clone()
243 }
244 }
245 }
246}
247
248struct IdentifierVariableIdState(HashMap<String, VariableId>);
249
250impl IdentifierVariableIdState {
251 pub(crate) fn new() -> Self {
252 IdentifierVariableIdState(HashMap::new())
253 }
254
255 pub(crate) fn update_variable_id(&mut self, identifier: &str) {
256 self.0
257 .entry(identifier.to_string())
258 .and_modify(|x| {
259 *x = x.increment_local_variable_id();
260 })
261 .or_insert(VariableId::local(identifier, 0));
262 }
263
264 pub(crate) fn lookup(&self, identifier: &str) -> Option<VariableId> {
265 self.0.get(identifier).cloned()
266 }
267}
268
#[cfg(test)]
mod name_binding_tests {
    // Tests for the name-binding passes in this file. Each test parses a Rib
    // snippet, runs one pass, and compares the result with a hand-built `Expr`.
    use bigdecimal::BigDecimal;
    use test_r::test;

    use crate::call_type::CallType;
    use crate::function_name::{DynamicParsedFunctionName, DynamicParsedFunctionReference};
    use crate::{Expr, InferredType, ParsedFunctionSite, VariableId};

    // A single `let` gets `VariableId::local("x", 0)`, and the later use of
    // `x` is rewritten to the same id.
    #[test]
    fn test_name_binding_simple() {
        let rib_expr = r#"
            let x = 1;
            foo(x)
        "#;

        let mut expr = Expr::from_text(rib_expr).unwrap();

        expr.bind_variables_of_let_assignment();

        let let_binding = Expr::let_binding_with_variable_id(
            VariableId::local("x", 0),
            Expr::number(BigDecimal::from(1)),
            None,
        );

        let call_expr = Expr::call(
            CallType::function_call(
                DynamicParsedFunctionName {
                    site: ParsedFunctionSite::Global,
                    function: DynamicParsedFunctionReference::Function {
                        function: "foo".to_string(),
                    },
                },
                None,
            ),
            None,
            vec![Expr::identifier_local("x", 0, None)],
        );

        let expected = Expr::expr_block(vec![let_binding, call_expr]);

        assert_eq!(expr, expected);
    }

    // Distinct names are versioned independently: both `x` and `y` start at 0.
    #[test]
    fn test_name_binding_multiple() {
        let rib_expr = r#"
            let x = 1;
            let y = 2;
            foo(x);
            foo(y)
        "#;

        let mut expr = Expr::from_text(rib_expr).unwrap();

        expr.bind_variables_of_let_assignment();

        let let_binding1 = Expr::let_binding_with_variable_id(
            VariableId::local("x", 0),
            Expr::number(BigDecimal::from(1)),
            None,
        );

        let let_binding2 = Expr::let_binding_with_variable_id(
            VariableId::local("y", 0),
            Expr::number(BigDecimal::from(2)),
            None,
        );

        let call_expr1 = Expr::call(
            CallType::function_call(
                DynamicParsedFunctionName {
                    site: ParsedFunctionSite::Global,
                    function: DynamicParsedFunctionReference::Function {
                        function: "foo".to_string(),
                    },
                },
                None,
            ),
            None,
            vec![Expr::identifier_local("x", 0, None)],
        );

        let call_expr2 = Expr::call(
            CallType::function_call(
                DynamicParsedFunctionName {
                    site: ParsedFunctionSite::Global,
                    function: DynamicParsedFunctionReference::Function {
                        function: "foo".to_string(),
                    },
                },
                None,
            ),
            None,
            vec![Expr::identifier_local("y", 0, None)],
        );

        let expected = Expr::expr_block(vec![let_binding1, let_binding2, call_expr1, call_expr2]);

        assert_eq!(expr, expected);
    }

    // Re-binding `x` bumps its version (0 -> 1). Each use refers to the most
    // recent preceding `let x`.
    #[test]
    fn test_name_binding_shadowing() {
        let rib_expr = r#"
            let x = 1;
            foo(x);
            let x = 2;
            foo(x)
        "#;

        let mut expr = Expr::from_text(rib_expr).unwrap();

        expr.bind_variables_of_let_assignment();

        let let_binding1 = Expr::let_binding_with_variable_id(
            VariableId::local("x", 0),
            Expr::number(BigDecimal::from(1)),
            None,
        );

        let let_binding2 = Expr::let_binding_with_variable_id(
            VariableId::local("x", 1),
            Expr::number(BigDecimal::from(2)),
            None,
        );

        let call_expr1 = Expr::call(
            CallType::function_call(
                DynamicParsedFunctionName {
                    site: ParsedFunctionSite::Global,
                    function: DynamicParsedFunctionReference::Function {
                        function: "foo".to_string(),
                    },
                },
                None,
            ),
            None,
            vec![Expr::identifier_local("x", 0, None)],
        );

        let call_expr2 = Expr::call(
            CallType::function_call(
                DynamicParsedFunctionName {
                    site: ParsedFunctionSite::Global,
                    function: DynamicParsedFunctionReference::Function {
                        function: "foo".to_string(),
                    },
                },
                None,
            ),
            None,
            vec![Expr::identifier_local("x", 1, None)],
        );

        let expected = Expr::expr_block(vec![let_binding1, call_expr1, let_binding2, call_expr2]);

        assert_eq!(expr, expected);
    }

    // Arm numbering starts at 1: the `some(x)` arm's pattern and its
    // resolution both become `MatchIdentifier("x", 1)`.
    #[test]
    fn test_simple_pattern_match_name_binding() {
        let expr_string = r#"
            match some(x) {
                some(x) => x,
                none => 0
            }
        "#;

        let mut expr = Expr::from_text(expr_string).unwrap();

        expr.bind_variables_of_pattern_match();

        assert_eq!(expr, expectations::expected_match(1));
    }

    // A `let x` inside the arm hides the pattern's `x`, so the `x` in the
    // block is not rewritten to a match identifier.
    #[test]
    fn test_simple_pattern_match_name_binding_with_shadow() {
        let expr_string = r#"
            match some(x) {
                some(x) => {
                    let x = 1;
                    x
                },
                none => 0
            }
        "#;

        let mut expr = Expr::from_text(expr_string).unwrap();

        expr.bind_variables_of_pattern_match();

        assert_eq!(expr, expectations::expected_match_with_let_binding(1));
    }

    // Arm numbers are global across the whole expression. The first match
    // uses 1 (`some`) and 2 (`none`), so the second match's `some` arm gets 3.
    #[test]
    fn test_simple_pattern_match_name_binding_block() {
        let expr_string = r#"
            match some(x) {
                some(x) => x,
                none => 0
            };

            match some(x) {
                some(x) => x,
                none => 0
            }
        "#;

        let mut expr = Expr::from_text(expr_string).unwrap();

        expr.bind_variables_of_pattern_match();

        let first_expr = expectations::expected_match(1);
        let second_expr = expectations::expected_match(3);
        let block = Expr::expr_block(vec![first_expr, second_expr])
            .with_inferred_type(InferredType::unknown());

        assert_eq!(expr, block);
    }

    // NOTE(review): the expected indices (ok = 1, err = 4, inner some = 5)
    // suggest the inner match is numbered twice. It is numbered once while
    // resolving the `ok` arm (some = 2, none = 3), and again when the outer
    // top-down walk reaches it (some = 5), which overwrites the first
    // numbering. Confirm this is intended rather than incidental.
    #[test]
    fn test_nested_simple_pattern_match_binding() {
        let expr_string = r#"
            match ok(some(x)) {
                ok(x) => match x {
                    some(x) => x,
                    none => 0
                },
                err(x) => 0
            }
        "#;

        let mut expr = Expr::from_text(expr_string).unwrap();

        expr.bind_variables_of_pattern_match();

        assert_eq!(expr, expectations::expected_nested_match());
    }

    // Hand-built expected expressions for the pattern-match tests above.
    mod expectations {
        use crate::{ArmPattern, Expr, InferredType, MatchArm, MatchIdentifier, VariableId};
        use bigdecimal::BigDecimal;

        // Expected result for `match some(x) { some(x) => x, none => 0 }` when
        // the `some` arm is numbered `index`. The scrutinee's `x` stays global.
        pub(crate) fn expected_match(index: usize) -> Expr {
            Expr::pattern_match(
                Expr::option(Some(Expr::identifier_global("x", None)))
                    .with_inferred_type(InferredType::option(InferredType::unknown())),
                vec![
                    MatchArm {
                        arm_pattern: ArmPattern::constructor(
                            "some",
                            vec![ArmPattern::literal(Expr::identifier_with_variable_id(
                                VariableId::MatchIdentifier(MatchIdentifier::new(
                                    "x".to_string(),
                                    index,
                                )),
                                None,
                            ))],
                        ),
                        arm_resolution_expr: Box::new(Expr::identifier_with_variable_id(
                            VariableId::MatchIdentifier(MatchIdentifier::new(
                                "x".to_string(),
                                index,
                            )),
                            None,
                        )),
                    },
                    MatchArm {
                        arm_pattern: ArmPattern::constructor("none", vec![]),
                        arm_resolution_expr: Box::new(Expr::number(BigDecimal::from(0))),
                    },
                ],
            )
        }

        // Like `expected_match`, but the `some` arm resolves to a block whose
        // `let x` shadows the pattern binding, so the trailing `x` remains
        // `VariableId::Global("x")`.
        pub(crate) fn expected_match_with_let_binding(index: usize) -> Expr {
            let let_binding = Expr::let_binding("x", Expr::number(BigDecimal::from(1)), None);
            let identifier_expr =
                Expr::identifier_with_variable_id(VariableId::Global("x".to_string()), None);
            let block = Expr::expr_block(vec![let_binding, identifier_expr]);

            Expr::pattern_match(
                Expr::option(Some(Expr::identifier_global("x", None))),
                vec![
                    MatchArm {
                        arm_pattern: ArmPattern::constructor(
                            "some",
                            vec![ArmPattern::literal(Expr::identifier_with_variable_id(
                                VariableId::MatchIdentifier(MatchIdentifier::new(
                                    "x".to_string(),
                                    index,
                                )),
                                None,
                            ))],
                        ),
                        arm_resolution_expr: Box::new(block),
                    },
                    MatchArm {
                        arm_pattern: ArmPattern::constructor("none", vec![]),
                        arm_resolution_expr: Box::new(Expr::number(BigDecimal::from(0))),
                    },
                ],
            )
        }

        // Expected result for the nested `ok(x) => match x { some(x) => x, .. }`
        // case. The inner scrutinee `x` refers to the `ok` arm's binding
        // (index 1), and the inner `some(x)` arm uses index 5.
        pub(crate) fn expected_nested_match() -> Expr {
            Expr::pattern_match(
                Expr::ok(
                    Expr::option(Some(Expr::identifier_with_variable_id(
                        VariableId::Global("x".to_string()),
                        None,
                    )))
                    .with_inferred_type(InferredType::option(InferredType::unknown())),
                    None,
                )
                .with_inferred_type(InferredType::result(
                    Some(InferredType::option(InferredType::unknown())),
                    Some(InferredType::unknown()),
                )),
                vec![
                    MatchArm {
                        arm_pattern: ArmPattern::constructor(
                            "ok",
                            vec![ArmPattern::literal(Expr::identifier_with_variable_id(
                                VariableId::MatchIdentifier(MatchIdentifier::new(
                                    "x".to_string(),
                                    1,
                                )),
                                None,
                            ))],
                        ),
                        arm_resolution_expr: Box::new(Expr::pattern_match(
                            Expr::identifier_with_variable_id(
                                VariableId::MatchIdentifier(MatchIdentifier::new(
                                    "x".to_string(),
                                    1,
                                )),
                                None,
                            ),
                            vec![
                                MatchArm {
                                    arm_pattern: ArmPattern::constructor(
                                        "some",
                                        vec![ArmPattern::literal(
                                            Expr::identifier_with_variable_id(
                                                VariableId::MatchIdentifier(MatchIdentifier::new(
                                                    "x".to_string(),
                                                    5,
                                                )),
                                                None,
                                            ),
                                        )],
                                    ),
                                    arm_resolution_expr: Box::new(
                                        Expr::identifier_with_variable_id(
                                            VariableId::MatchIdentifier(MatchIdentifier::new(
                                                "x".to_string(),
                                                5,
                                            )),
                                            None,
                                        ),
                                    ),
                                },
                                MatchArm {
                                    arm_pattern: ArmPattern::constructor("none", vec![]),
                                    arm_resolution_expr: Box::new(Expr::number(BigDecimal::from(
                                        0,
                                    ))),
                                },
                            ],
                        )),
                    },
                    MatchArm {
                        arm_pattern: ArmPattern::constructor(
                            "err",
                            vec![ArmPattern::literal(Expr::identifier_with_variable_id(
                                VariableId::MatchIdentifier(MatchIdentifier::new(
                                    "x".to_string(),
                                    4,
                                )),
                                None,
                            ))],
                        ),
                        arm_resolution_expr: Box::new(Expr::number(BigDecimal::from(0))),
                    },
                ],
            )
        }
    }
}