1use crate::{MatchIdentifier, VariableId};
16use std::collections::HashMap;
17
18use crate::expr_arena::{
19 ArmPatternId, ArmPatternNode, ExprArena, ExprId, ExprKind, MatchArmNode, TypeTable,
20};
21use crate::type_inference::expr_visitor::arena::children_of;
22
23pub fn bind_variables_of_let_assignment(root: ExprId, arena: &mut ExprArena, _types: &TypeTable) {
30 let mut state: HashMap<String, VariableId> = HashMap::new();
31
32 let mut order = Vec::new();
35 collect_post_order(root, arena, &mut order);
36
37 for id in order {
38 let kind = arena.expr(id).kind.clone();
39 match kind {
40 ExprKind::Let { variable_id, .. } => {
41 let name = variable_id.name();
42 let next = state
43 .entry(name.clone())
44 .and_modify(|x| *x = x.increment_local_variable_id())
45 .or_insert_with(|| VariableId::local(&name, 0))
46 .clone();
47 if let ExprKind::Let {
48 variable_id: ref mut vid,
49 ..
50 } = arena.expr_mut(id).kind
51 {
52 *vid = next;
53 }
54 }
55 ExprKind::Identifier { variable_id } if !variable_id.is_match_binding() => {
56 let name = variable_id.name();
57 if let Some(latest) = state.get(&name).cloned() {
58 if let ExprKind::Identifier {
59 variable_id: ref mut vid,
60 } = arena.expr_mut(id).kind
61 {
62 *vid = latest;
63 }
64 }
65 }
66 _ => {}
67 }
68 }
69}
70
71pub fn bind_variables_of_list_comprehension(
76 root: ExprId,
77 arena: &mut ExprArena,
78 _types: &TypeTable,
79) {
80 let mut order = Vec::new();
83 collect_pre_order(root, arena, &mut order);
84
85 for id in order {
86 let kind = arena.expr(id).kind.clone();
87 if let ExprKind::ListComprehension {
88 mut iterated_variable,
89 yield_expr,
90 ..
91 } = kind
92 {
93 let new_var = VariableId::list_comprehension_identifier(iterated_variable.name());
94 iterated_variable = new_var.clone();
95
96 if let ExprKind::ListComprehension {
98 iterated_variable: ref mut v,
99 ..
100 } = arena.expr_mut(id).kind
101 {
102 *v = new_var.clone();
103 }
104
105 patch_identifier_in_subtree(yield_expr, arena, &iterated_variable);
106 }
107 }
108}
109
110pub fn bind_variables_of_list_reduce(root: ExprId, arena: &mut ExprArena, _types: &TypeTable) {
115 let mut order = Vec::new();
116 collect_pre_order(root, arena, &mut order);
117
118 for id in order {
119 let kind = arena.expr(id).kind.clone();
120 if let ExprKind::ListReduce {
121 mut reduce_variable,
122 mut iterated_variable,
123 yield_expr,
124 ..
125 } = kind
126 {
127 let new_iter = VariableId::list_comprehension_identifier(iterated_variable.name());
128 let new_reduce = VariableId::list_reduce_identifier(reduce_variable.name());
129 iterated_variable = new_iter.clone();
130 reduce_variable = new_reduce.clone();
131
132 if let ExprKind::ListReduce {
133 reduce_variable: ref mut rv,
134 iterated_variable: ref mut iv,
135 ..
136 } = arena.expr_mut(id).kind
137 {
138 *rv = new_reduce.clone();
139 *iv = new_iter.clone();
140 }
141
142 patch_two_identifiers_in_subtree(
143 yield_expr,
144 arena,
145 &iterated_variable,
146 &reduce_variable,
147 );
148 }
149 }
150}
151
152pub fn bind_variables_of_pattern_match(root: ExprId, arena: &mut ExprArena, _types: &TypeTable) {
157 bind_pattern_match_internal(root, arena, 0, &mut []);
158}
159
160fn bind_pattern_match_internal(
161 root: ExprId,
162 arena: &mut ExprArena,
163 previous_index: usize,
164 match_identifiers: &mut [MatchIdentifier],
165) -> usize {
166 let mut index = previous_index;
167 let mut shadowed_let_bindings: Vec<String> = vec![];
168
169 let mut order = Vec::new();
170 collect_pre_order(root, arena, &mut order);
171
172 for id in order {
173 let kind = arena.expr(id).kind.clone();
174 match kind {
175 ExprKind::PatternMatch { match_arms, .. } => {
176 for arm in match_arms {
177 index += 1;
178 index = process_arm_arena(arm, index, arena);
179 }
180 }
181 ExprKind::Let { variable_id, .. } => {
182 shadowed_let_bindings.push(variable_id.name());
183 }
184 ExprKind::Identifier { variable_id } => {
185 let name = variable_id.name();
186 if let Some(mi) = match_identifiers.iter().find(|x| x.name == name) {
187 if !shadowed_let_bindings.contains(&name) {
188 if let ExprKind::Identifier {
189 variable_id: ref mut vid,
190 } = arena.expr_mut(id).kind
191 {
192 *vid = VariableId::MatchIdentifier(mi.clone());
193 }
194 }
195 }
196 }
197 _ => {}
198 }
199 }
200
201 index
202}
203
204fn process_arm_arena(arm: MatchArmNode, global_arm_index: usize, arena: &mut ExprArena) -> usize {
205 let mut match_identifiers = vec![];
206 collect_identifiers_from_arm_pattern(
207 arm.arm_pattern,
208 global_arm_index,
209 arena,
210 &mut match_identifiers,
211 );
212 bind_pattern_match_internal(
213 arm.arm_resolution_expr,
214 arena,
215 global_arm_index,
216 &mut match_identifiers,
217 )
218}
219
220fn collect_identifiers_from_arm_pattern(
221 pat_id: ArmPatternId,
222 global_arm_index: usize,
223 arena: &mut ExprArena,
224 out: &mut Vec<MatchIdentifier>,
225) {
226 let pat = arena.pattern(pat_id).clone();
227 match pat {
228 ArmPatternNode::Literal(expr_id) => {
229 update_identifiers_in_pattern_expr(expr_id, global_arm_index, arena, out);
230 }
231 ArmPatternNode::WildCard => {}
232 ArmPatternNode::As(name, inner) => {
233 out.push(MatchIdentifier::new(name, global_arm_index));
234 collect_identifiers_from_arm_pattern(inner, global_arm_index, arena, out);
235 }
236 ArmPatternNode::Constructor(_, children)
237 | ArmPatternNode::TupleConstructor(children)
238 | ArmPatternNode::ListConstructor(children) => {
239 for child in children {
240 collect_identifiers_from_arm_pattern(child, global_arm_index, arena, out);
241 }
242 }
243 ArmPatternNode::RecordConstructor(fields) => {
244 for (_, child) in fields {
245 collect_identifiers_from_arm_pattern(child, global_arm_index, arena, out);
246 }
247 }
248 }
249}
250
251fn update_identifiers_in_pattern_expr(
252 expr_id: ExprId,
253 global_arm_index: usize,
254 arena: &mut ExprArena,
255 out: &mut Vec<MatchIdentifier>,
256) {
257 let mut order = Vec::new();
258 collect_post_order(expr_id, arena, &mut order);
259 for id in order {
260 let kind = arena.expr(id).kind.clone();
261 if let ExprKind::Identifier { variable_id } = kind {
262 let mi = MatchIdentifier::new(variable_id.name(), global_arm_index);
263 out.push(mi.clone());
264 if let ExprKind::Identifier {
265 variable_id: ref mut vid,
266 } = arena.expr_mut(id).kind
267 {
268 *vid = VariableId::match_identifier(variable_id.name(), global_arm_index);
269 }
270 }
271 }
272}
273
274fn patch_identifier_in_subtree(root: ExprId, arena: &mut ExprArena, target: &VariableId) {
279 let mut order = Vec::new();
280 collect_pre_order(root, arena, &mut order);
281 for id in order {
282 let kind = arena.expr(id).kind.clone();
283 if let ExprKind::Identifier { variable_id } = kind {
284 if variable_id.name() == target.name() {
285 if let ExprKind::Identifier {
286 variable_id: ref mut vid,
287 } = arena.expr_mut(id).kind
288 {
289 *vid = target.clone();
290 }
291 }
292 }
293 }
294}
295
296fn patch_two_identifiers_in_subtree(
297 root: ExprId,
298 arena: &mut ExprArena,
299 iter_var: &VariableId,
300 reduce_var: &VariableId,
301) {
302 let mut order = Vec::new();
303 collect_pre_order(root, arena, &mut order);
304 for id in order {
305 let kind = arena.expr(id).kind.clone();
306 if let ExprKind::Identifier { variable_id } = kind {
307 let name = variable_id.name();
308 let new_vid = if name == iter_var.name() {
309 Some(iter_var.clone())
310 } else if name == reduce_var.name() {
311 Some(reduce_var.clone())
312 } else {
313 None
314 };
315 if let Some(new_vid) = new_vid {
316 if let ExprKind::Identifier {
317 variable_id: ref mut vid,
318 } = arena.expr_mut(id).kind
319 {
320 *vid = new_vid;
321 }
322 }
323 }
324 }
325}
326
327fn collect_post_order(root: ExprId, arena: &ExprArena, out: &mut Vec<ExprId>) {
328 let mut stack = vec![(root, false)];
329 while let Some((id, visited)) = stack.pop() {
330 if visited {
331 out.push(id);
332 } else {
333 stack.push((id, true));
334 for child in children_of(id, arena).into_iter().rev() {
335 stack.push((child, false));
336 }
337 }
338 }
339}
340
341fn collect_pre_order(root: ExprId, arena: &ExprArena, out: &mut Vec<ExprId>) {
342 let mut stack = vec![root];
343 while let Some(id) = stack.pop() {
344 out.push(id);
345 for child in children_of(id, arena).into_iter().rev() {
346 stack.push(child);
347 }
348 }
349}
350
#[cfg(test)]
mod name_binding_tests {
    use bigdecimal::BigDecimal;
    use test_r::test;

    use crate::call_type::CallType;
    use crate::function_name::{DynamicParsedFunctionName, DynamicParsedFunctionReference};
    use crate::{Expr, InferredType, ParsedFunctionSite, VariableId};

    /// Round-trips `expr` through the arena, running the `let`-binding pass on it.
    fn bind_let_assignment_via_arena(expr: &mut Expr) {
        let (mut arena, types, root) = crate::expr_arena::lower(expr);
        super::bind_variables_of_let_assignment(root, &mut arena, &types);
        *expr = crate::expr_arena::rebuild_expr(root, &arena, &types);
    }

    /// Round-trips `expr` through the arena, running the pattern-match binding pass on it.
    fn bind_pattern_match_via_arena(expr: &mut Expr) {
        let (mut arena, types, root) = crate::expr_arena::lower(expr);
        super::bind_variables_of_pattern_match(root, &mut arena, &types);
        *expr = crate::expr_arena::rebuild_expr(root, &arena, &types);
    }

    /// Expected `let` node binding `variable_id` to the integer `value`.
    fn let_number(variable_id: VariableId, value: i32) -> Expr {
        Expr::let_binding_with_variable_id(
            variable_id,
            Expr::number(BigDecimal::from(value)),
            None,
        )
    }

    /// Expected call `foo(argument)` to the global function `foo`.
    fn call_foo(argument: Expr) -> Expr {
        let foo = DynamicParsedFunctionName {
            site: ParsedFunctionSite::Global,
            function: DynamicParsedFunctionReference::Function {
                function: "foo".to_string(),
            },
        };
        Expr::call(CallType::function_call(foo, None), vec![argument])
    }

    #[test]
    fn test_name_binding_simple() {
        let rib_expr = r#"
          let x = 1;
          foo(x)
        "#;

        let mut expr = Expr::from_text(rib_expr).unwrap();
        bind_let_assignment_via_arena(&mut expr);

        let expected = Expr::expr_block(vec![
            let_number(VariableId::local("x", 0), 1),
            call_foo(Expr::identifier_local("x", 0, None)),
        ]);

        assert_eq!(expr, expected);
    }

    #[test]
    fn test_name_binding_shadowing() {
        let rib_expr = r#"
          let x = 1;
          foo(x);
          let x = 2;
          foo(x)
        "#;

        let mut expr = Expr::from_text(rib_expr).unwrap();
        bind_let_assignment_via_arena(&mut expr);

        let expected = Expr::expr_block(vec![
            let_number(VariableId::local("x", 0), 1),
            call_foo(Expr::identifier_local("x", 0, None)),
            let_number(VariableId::local("x", 1), 2),
            call_foo(Expr::identifier_local("x", 1, None)),
        ]);

        assert_eq!(expr, expected);
    }

    #[test]
    fn test_simple_pattern_match_name_binding() {
        let expr_string = r#"
          match some(x) {
            some(x) => x,
            none => 0
          }
        "#;

        let mut expr = Expr::from_text(expr_string).unwrap();
        bind_pattern_match_via_arena(&mut expr);

        assert_eq!(expr, expectations::expected_match(1));
    }

    #[test]
    fn test_simple_pattern_match_name_binding_block() {
        let expr_string = r#"
          match some(x) {
            some(x) => x,
            none => 0
          };

          match some(x) {
            some(x) => x,
            none => 0
          }
        "#;

        let mut expr = Expr::from_text(expr_string).unwrap();
        bind_pattern_match_via_arena(&mut expr);

        // Each match has two arms, so the second match's first arm is number 3.
        let expected = Expr::expr_block(vec![
            expectations::expected_match(1),
            expectations::expected_match(3),
        ])
        .with_inferred_type(InferredType::unknown());

        assert_eq!(expr, expected);
    }

    mod expectations {
        use crate::{ArmPattern, Expr, InferredType, MatchArm, MatchIdentifier, VariableId};
        use bigdecimal::BigDecimal;

        /// Expected tree for `match some(x) { some(x) => x, none => 0 }` after
        /// binding, where the `some` arm received the global arm index `index`.
        pub fn expected_match(index: usize) -> Expr {
            let bound_x = || {
                Expr::identifier_with_variable_id(
                    VariableId::MatchIdentifier(MatchIdentifier::new("x".to_string(), index)),
                    None,
                )
            };

            let some_arm = MatchArm {
                arm_pattern: ArmPattern::constructor("some", vec![ArmPattern::literal(bound_x())]),
                arm_resolution_expr: Box::new(bound_x()),
            };

            let none_arm = MatchArm {
                arm_pattern: ArmPattern::constructor("none", vec![]),
                arm_resolution_expr: Box::new(Expr::number(BigDecimal::from(0))),
            };

            let predicate = Expr::option(Some(Expr::identifier_global("x", None)))
                .with_inferred_type(InferredType::option(InferredType::unknown()));

            Expr::pattern_match(predicate, vec![some_arm, none_arm])
        }
    }
}