1use crate::{MatchIdentifier, VariableId};
2use std::collections::HashMap;
3
4use crate::expr_arena::{
5 ArmPatternId, ArmPatternNode, ExprArena, ExprId, ExprKind, MatchArmNode, TypeTable,
6};
7use crate::type_inference::expr_visitor::arena::children_of;
8
9pub fn bind_variables_of_let_assignment(root: ExprId, arena: &mut ExprArena, _types: &TypeTable) {
16 let mut state: HashMap<String, VariableId> = HashMap::new();
17
18 let mut order = Vec::new();
21 collect_post_order(root, arena, &mut order);
22
23 for id in order {
24 let kind = arena.expr(id).kind.clone();
25 match kind {
26 ExprKind::Let { variable_id, .. } => {
27 let name = variable_id.name();
28 let next = state
29 .entry(name.clone())
30 .and_modify(|x| *x = x.increment_local_variable_id())
31 .or_insert_with(|| VariableId::local(&name, 0))
32 .clone();
33 if let ExprKind::Let {
34 variable_id: ref mut vid,
35 ..
36 } = arena.expr_mut(id).kind
37 {
38 *vid = next;
39 }
40 }
41 ExprKind::Identifier { variable_id } if !variable_id.is_match_binding() => {
42 let name = variable_id.name();
43 if let Some(latest) = state.get(&name).cloned() {
44 if let ExprKind::Identifier {
45 variable_id: ref mut vid,
46 } = arena.expr_mut(id).kind
47 {
48 *vid = latest;
49 }
50 }
51 }
52 _ => {}
53 }
54 }
55}
56
57pub fn bind_variables_of_list_comprehension(
62 root: ExprId,
63 arena: &mut ExprArena,
64 _types: &TypeTable,
65) {
66 let mut order = Vec::new();
69 collect_pre_order(root, arena, &mut order);
70
71 for id in order {
72 let kind = arena.expr(id).kind.clone();
73 if let ExprKind::ListComprehension {
74 mut iterated_variable,
75 yield_expr,
76 ..
77 } = kind
78 {
79 let new_var = VariableId::list_comprehension_identifier(iterated_variable.name());
80 iterated_variable = new_var.clone();
81
82 if let ExprKind::ListComprehension {
84 iterated_variable: ref mut v,
85 ..
86 } = arena.expr_mut(id).kind
87 {
88 *v = new_var.clone();
89 }
90
91 patch_identifier_in_subtree(yield_expr, arena, &iterated_variable);
92 }
93 }
94}
95
96pub fn bind_variables_of_list_reduce(root: ExprId, arena: &mut ExprArena, _types: &TypeTable) {
101 let mut order = Vec::new();
102 collect_pre_order(root, arena, &mut order);
103
104 for id in order {
105 let kind = arena.expr(id).kind.clone();
106 if let ExprKind::ListReduce {
107 mut reduce_variable,
108 mut iterated_variable,
109 yield_expr,
110 ..
111 } = kind
112 {
113 let new_iter = VariableId::list_comprehension_identifier(iterated_variable.name());
114 let new_reduce = VariableId::list_reduce_identifier(reduce_variable.name());
115 iterated_variable = new_iter.clone();
116 reduce_variable = new_reduce.clone();
117
118 if let ExprKind::ListReduce {
119 reduce_variable: ref mut rv,
120 iterated_variable: ref mut iv,
121 ..
122 } = arena.expr_mut(id).kind
123 {
124 *rv = new_reduce.clone();
125 *iv = new_iter.clone();
126 }
127
128 patch_two_identifiers_in_subtree(
129 yield_expr,
130 arena,
131 &iterated_variable,
132 &reduce_variable,
133 );
134 }
135 }
136}
137
138pub fn bind_variables_of_pattern_match(root: ExprId, arena: &mut ExprArena, _types: &TypeTable) {
143 bind_pattern_match_internal(root, arena, 0, &mut []);
144}
145
146fn bind_pattern_match_internal(
147 root: ExprId,
148 arena: &mut ExprArena,
149 previous_index: usize,
150 match_identifiers: &mut [MatchIdentifier],
151) -> usize {
152 let mut index = previous_index;
153 let mut shadowed_let_bindings: Vec<String> = vec![];
154
155 let mut order = Vec::new();
156 collect_pre_order(root, arena, &mut order);
157
158 for id in order {
159 let kind = arena.expr(id).kind.clone();
160 match kind {
161 ExprKind::PatternMatch { match_arms, .. } => {
162 for arm in match_arms {
163 index += 1;
164 index = process_arm_arena(arm, index, arena);
165 }
166 }
167 ExprKind::Let { variable_id, .. } => {
168 shadowed_let_bindings.push(variable_id.name());
169 }
170 ExprKind::Identifier { variable_id } => {
171 let name = variable_id.name();
172 if let Some(mi) = match_identifiers.iter().find(|x| x.name == name) {
173 if !shadowed_let_bindings.contains(&name) {
174 if let ExprKind::Identifier {
175 variable_id: ref mut vid,
176 } = arena.expr_mut(id).kind
177 {
178 *vid = VariableId::MatchIdentifier(mi.clone());
179 }
180 }
181 }
182 }
183 _ => {}
184 }
185 }
186
187 index
188}
189
190fn process_arm_arena(arm: MatchArmNode, global_arm_index: usize, arena: &mut ExprArena) -> usize {
191 let mut match_identifiers = vec![];
192 collect_identifiers_from_arm_pattern(
193 arm.arm_pattern,
194 global_arm_index,
195 arena,
196 &mut match_identifiers,
197 );
198 bind_pattern_match_internal(
199 arm.arm_resolution_expr,
200 arena,
201 global_arm_index,
202 &mut match_identifiers,
203 )
204}
205
206fn collect_identifiers_from_arm_pattern(
207 pat_id: ArmPatternId,
208 global_arm_index: usize,
209 arena: &mut ExprArena,
210 out: &mut Vec<MatchIdentifier>,
211) {
212 let pat = arena.pattern(pat_id).clone();
213 match pat {
214 ArmPatternNode::Literal(expr_id) => {
215 update_identifiers_in_pattern_expr(expr_id, global_arm_index, arena, out);
216 }
217 ArmPatternNode::WildCard => {}
218 ArmPatternNode::As(name, inner) => {
219 out.push(MatchIdentifier::new(name, global_arm_index));
220 collect_identifiers_from_arm_pattern(inner, global_arm_index, arena, out);
221 }
222 ArmPatternNode::Constructor(_, children)
223 | ArmPatternNode::TupleConstructor(children)
224 | ArmPatternNode::ListConstructor(children) => {
225 for child in children {
226 collect_identifiers_from_arm_pattern(child, global_arm_index, arena, out);
227 }
228 }
229 ArmPatternNode::RecordConstructor(fields) => {
230 for (_, child) in fields {
231 collect_identifiers_from_arm_pattern(child, global_arm_index, arena, out);
232 }
233 }
234 }
235}
236
237fn update_identifiers_in_pattern_expr(
238 expr_id: ExprId,
239 global_arm_index: usize,
240 arena: &mut ExprArena,
241 out: &mut Vec<MatchIdentifier>,
242) {
243 let mut order = Vec::new();
244 collect_post_order(expr_id, arena, &mut order);
245 for id in order {
246 let kind = arena.expr(id).kind.clone();
247 if let ExprKind::Identifier { variable_id } = kind {
248 let mi = MatchIdentifier::new(variable_id.name(), global_arm_index);
249 out.push(mi.clone());
250 if let ExprKind::Identifier {
251 variable_id: ref mut vid,
252 } = arena.expr_mut(id).kind
253 {
254 *vid = VariableId::match_identifier(variable_id.name(), global_arm_index);
255 }
256 }
257 }
258}
259
260fn patch_identifier_in_subtree(root: ExprId, arena: &mut ExprArena, target: &VariableId) {
265 let mut order = Vec::new();
266 collect_pre_order(root, arena, &mut order);
267 for id in order {
268 let kind = arena.expr(id).kind.clone();
269 if let ExprKind::Identifier { variable_id } = kind {
270 if variable_id.name() == target.name() {
271 if let ExprKind::Identifier {
272 variable_id: ref mut vid,
273 } = arena.expr_mut(id).kind
274 {
275 *vid = target.clone();
276 }
277 }
278 }
279 }
280}
281
282fn patch_two_identifiers_in_subtree(
283 root: ExprId,
284 arena: &mut ExprArena,
285 iter_var: &VariableId,
286 reduce_var: &VariableId,
287) {
288 let mut order = Vec::new();
289 collect_pre_order(root, arena, &mut order);
290 for id in order {
291 let kind = arena.expr(id).kind.clone();
292 if let ExprKind::Identifier { variable_id } = kind {
293 let name = variable_id.name();
294 let new_vid = if name == iter_var.name() {
295 Some(iter_var.clone())
296 } else if name == reduce_var.name() {
297 Some(reduce_var.clone())
298 } else {
299 None
300 };
301 if let Some(new_vid) = new_vid {
302 if let ExprKind::Identifier {
303 variable_id: ref mut vid,
304 } = arena.expr_mut(id).kind
305 {
306 *vid = new_vid;
307 }
308 }
309 }
310 }
311}
312
313fn collect_post_order(root: ExprId, arena: &ExprArena, out: &mut Vec<ExprId>) {
314 let mut stack = vec![(root, false)];
315 while let Some((id, visited)) = stack.pop() {
316 if visited {
317 out.push(id);
318 } else {
319 stack.push((id, true));
320 for child in children_of(id, arena).into_iter().rev() {
321 stack.push((child, false));
322 }
323 }
324 }
325}
326
327fn collect_pre_order(root: ExprId, arena: &ExprArena, out: &mut Vec<ExprId>) {
328 let mut stack = vec![root];
329 while let Some(id) = stack.pop() {
330 out.push(id);
331 for child in children_of(id, arena).into_iter().rev() {
332 stack.push(child);
333 }
334 }
335}
336
#[cfg(test)]
mod name_binding_tests {
    use bigdecimal::BigDecimal;
    use test_r::test;

    use crate::call_type::CallType;
    use crate::function_name::{DynamicParsedFunctionName, DynamicParsedFunctionReference};
    use crate::{Expr, InferredType, ParsedFunctionSite, VariableId};

    /// Runs the let-binding pass on `expr` via the arena representation, replacing `expr`.
    fn bind_let_assignment_via_arena(expr: &mut Expr) {
        let (mut arena, types, root) = crate::expr_arena::lower(expr);
        super::bind_variables_of_let_assignment(root, &mut arena, &types);
        *expr = crate::expr_arena::rebuild_expr(root, &arena, &types);
    }

    /// Runs the pattern-match binding pass on `expr` via the arena representation, replacing `expr`.
    fn bind_pattern_match_via_arena(expr: &mut Expr) {
        let (mut arena, types, root) = crate::expr_arena::lower(expr);
        super::bind_variables_of_pattern_match(root, &mut arena, &types);
        *expr = crate::expr_arena::rebuild_expr(root, &arena, &types);
    }

    /// Expected expression for a call to the global function `foo` with `args`.
    fn foo_call(args: Vec<Expr>) -> Expr {
        let function_name = DynamicParsedFunctionName {
            site: ParsedFunctionSite::Global,
            function: DynamicParsedFunctionReference::Function {
                function: "foo".to_string(),
            },
        };
        Expr::call(CallType::function_call(function_name, None), args)
    }

    #[test]
    fn test_name_binding_simple() {
        let rib_expr = r#"
            let x = 1;
            foo(x)
        "#;

        let mut expr = Expr::from_text(rib_expr).unwrap();
        bind_let_assignment_via_arena(&mut expr);

        let expected = Expr::expr_block(vec![
            Expr::let_binding_with_variable_id(
                VariableId::local("x", 0),
                Expr::number(BigDecimal::from(1)),
                None,
            ),
            foo_call(vec![Expr::identifier_local("x", 0, None)]),
        ]);

        assert_eq!(expr, expected);
    }

    #[test]
    fn test_name_binding_shadowing() {
        let rib_expr = r#"
            let x = 1;
            foo(x);
            let x = 2;
            foo(x)
        "#;

        let mut expr = Expr::from_text(rib_expr).unwrap();
        bind_let_assignment_via_arena(&mut expr);

        let expected = Expr::expr_block(vec![
            Expr::let_binding_with_variable_id(
                VariableId::local("x", 0),
                Expr::number(BigDecimal::from(1)),
                None,
            ),
            foo_call(vec![Expr::identifier_local("x", 0, None)]),
            Expr::let_binding_with_variable_id(
                VariableId::local("x", 1),
                Expr::number(BigDecimal::from(2)),
                None,
            ),
            foo_call(vec![Expr::identifier_local("x", 1, None)]),
        ]);

        assert_eq!(expr, expected);
    }

    #[test]
    fn test_simple_pattern_match_name_binding() {
        let expr_string = r#"
            match some(x) {
                some(x) => x,
                none => 0
            }
        "#;

        let mut expr = Expr::from_text(expr_string).unwrap();
        bind_pattern_match_via_arena(&mut expr);

        assert_eq!(expr, expectations::expected_match(1));
    }

    #[test]
    fn test_simple_pattern_match_name_binding_block() {
        let expr_string = r#"
            match some(x) {
                some(x) => x,
                none => 0
            };

            match some(x) {
                some(x) => x,
                none => 0
            }
        "#;

        let mut expr = Expr::from_text(expr_string).unwrap();
        bind_pattern_match_via_arena(&mut expr);

        // Each match has two arms, so the second match's `some` arm is numbered 3.
        let expected = Expr::expr_block(vec![
            expectations::expected_match(1),
            expectations::expected_match(3),
        ])
        .with_inferred_type(InferredType::unknown());

        assert_eq!(expr, expected);
    }

    mod expectations {
        use crate::{ArmPattern, Expr, InferredType, MatchArm, MatchIdentifier, VariableId};
        use bigdecimal::BigDecimal;

        /// Expected result of binding `match some(x) { some(x) => x, none => 0 }` when the
        /// `some` arm is assigned arm index `index`.
        pub fn expected_match(index: usize) -> Expr {
            let bound_x =
                VariableId::MatchIdentifier(MatchIdentifier::new("x".to_string(), index));

            let some_arm = MatchArm {
                arm_pattern: ArmPattern::constructor(
                    "some",
                    vec![ArmPattern::literal(Expr::identifier_with_variable_id(
                        bound_x.clone(),
                        None,
                    ))],
                ),
                arm_resolution_expr: Box::new(Expr::identifier_with_variable_id(bound_x, None)),
            };

            let none_arm = MatchArm {
                arm_pattern: ArmPattern::constructor("none", vec![]),
                arm_resolution_expr: Box::new(Expr::number(BigDecimal::from(0))),
            };

            let predicate = Expr::option(Some(Expr::identifier_global("x", None)))
                .with_inferred_type(InferredType::option(InferredType::unknown()));

            Expr::pattern_match(predicate, vec![some_arm, none_arm])
        }
    }
}