Skip to main content

rib/type_inference/
variable_binding.rs

1use crate::{MatchIdentifier, VariableId};
2use std::collections::HashMap;
3
4use crate::expr_arena::{
5    ArmPatternId, ArmPatternNode, ExprArena, ExprId, ExprKind, MatchArmNode, TypeTable,
6};
7use crate::type_inference::expr_visitor::arena::children_of;
8
9// -----------------------------------------------------------------------
10// bind_variables_of_let_assignment
11// -----------------------------------------------------------------------
12
13/// Arena version: assigns local `VariableId`s to `Let` nodes and propagates
14/// them to matching `Identifier` use-sites.
15pub fn bind_variables_of_let_assignment(root: ExprId, arena: &mut ExprArena, _types: &TypeTable) {
16    let mut state: HashMap<String, VariableId> = HashMap::new();
17
18    // Post-order: children before parents — so identifiers inside a let's
19    // rhs are processed before the let itself.
20    let mut order = Vec::new();
21    collect_post_order(root, arena, &mut order);
22
23    for id in order {
24        let kind = arena.expr(id).kind.clone();
25        match kind {
26            ExprKind::Let { variable_id, .. } => {
27                let name = variable_id.name();
28                let next = state
29                    .entry(name.clone())
30                    .and_modify(|x| *x = x.increment_local_variable_id())
31                    .or_insert_with(|| VariableId::local(&name, 0))
32                    .clone();
33                if let ExprKind::Let {
34                    variable_id: ref mut vid,
35                    ..
36                } = arena.expr_mut(id).kind
37                {
38                    *vid = next;
39                }
40            }
41            ExprKind::Identifier { variable_id } if !variable_id.is_match_binding() => {
42                let name = variable_id.name();
43                if let Some(latest) = state.get(&name).cloned() {
44                    if let ExprKind::Identifier {
45                        variable_id: ref mut vid,
46                    } = arena.expr_mut(id).kind
47                    {
48                        *vid = latest;
49                    }
50                }
51            }
52            _ => {}
53        }
54    }
55}
56
57// -----------------------------------------------------------------------
58// bind_variables_of_list_comprehension
59// -----------------------------------------------------------------------
60
61pub fn bind_variables_of_list_comprehension(
62    root: ExprId,
63    arena: &mut ExprArena,
64    _types: &TypeTable,
65) {
66    // Pre-order: process parent before children so the updated variable is
67    // used when we patch identifiers inside the yield expression.
68    let mut order = Vec::new();
69    collect_pre_order(root, arena, &mut order);
70
71    for id in order {
72        let kind = arena.expr(id).kind.clone();
73        if let ExprKind::ListComprehension {
74            mut iterated_variable,
75            yield_expr,
76            ..
77        } = kind
78        {
79            let new_var = VariableId::list_comprehension_identifier(iterated_variable.name());
80            iterated_variable = new_var.clone();
81
82            // patch the node
83            if let ExprKind::ListComprehension {
84                iterated_variable: ref mut v,
85                ..
86            } = arena.expr_mut(id).kind
87            {
88                *v = new_var.clone();
89            }
90
91            patch_identifier_in_subtree(yield_expr, arena, &iterated_variable);
92        }
93    }
94}
95
96// -----------------------------------------------------------------------
97// bind_variables_of_list_reduce
98// -----------------------------------------------------------------------
99
100pub fn bind_variables_of_list_reduce(root: ExprId, arena: &mut ExprArena, _types: &TypeTable) {
101    let mut order = Vec::new();
102    collect_pre_order(root, arena, &mut order);
103
104    for id in order {
105        let kind = arena.expr(id).kind.clone();
106        if let ExprKind::ListReduce {
107            mut reduce_variable,
108            mut iterated_variable,
109            yield_expr,
110            ..
111        } = kind
112        {
113            let new_iter = VariableId::list_comprehension_identifier(iterated_variable.name());
114            let new_reduce = VariableId::list_reduce_identifier(reduce_variable.name());
115            iterated_variable = new_iter.clone();
116            reduce_variable = new_reduce.clone();
117
118            if let ExprKind::ListReduce {
119                reduce_variable: ref mut rv,
120                iterated_variable: ref mut iv,
121                ..
122            } = arena.expr_mut(id).kind
123            {
124                *rv = new_reduce.clone();
125                *iv = new_iter.clone();
126            }
127
128            patch_two_identifiers_in_subtree(
129                yield_expr,
130                arena,
131                &iterated_variable,
132                &reduce_variable,
133            );
134        }
135    }
136}
137
138// -----------------------------------------------------------------------
139// bind_variables_of_pattern_match
140// -----------------------------------------------------------------------
141
142pub fn bind_variables_of_pattern_match(root: ExprId, arena: &mut ExprArena, _types: &TypeTable) {
143    bind_pattern_match_internal(root, arena, 0, &mut []);
144}
145
146fn bind_pattern_match_internal(
147    root: ExprId,
148    arena: &mut ExprArena,
149    previous_index: usize,
150    match_identifiers: &mut [MatchIdentifier],
151) -> usize {
152    let mut index = previous_index;
153    let mut shadowed_let_bindings: Vec<String> = vec![];
154
155    let mut order = Vec::new();
156    collect_pre_order(root, arena, &mut order);
157
158    for id in order {
159        let kind = arena.expr(id).kind.clone();
160        match kind {
161            ExprKind::PatternMatch { match_arms, .. } => {
162                for arm in match_arms {
163                    index += 1;
164                    index = process_arm_arena(arm, index, arena);
165                }
166            }
167            ExprKind::Let { variable_id, .. } => {
168                shadowed_let_bindings.push(variable_id.name());
169            }
170            ExprKind::Identifier { variable_id } => {
171                let name = variable_id.name();
172                if let Some(mi) = match_identifiers.iter().find(|x| x.name == name) {
173                    if !shadowed_let_bindings.contains(&name) {
174                        if let ExprKind::Identifier {
175                            variable_id: ref mut vid,
176                        } = arena.expr_mut(id).kind
177                        {
178                            *vid = VariableId::MatchIdentifier(mi.clone());
179                        }
180                    }
181                }
182            }
183            _ => {}
184        }
185    }
186
187    index
188}
189
190fn process_arm_arena(arm: MatchArmNode, global_arm_index: usize, arena: &mut ExprArena) -> usize {
191    let mut match_identifiers = vec![];
192    collect_identifiers_from_arm_pattern(
193        arm.arm_pattern,
194        global_arm_index,
195        arena,
196        &mut match_identifiers,
197    );
198    bind_pattern_match_internal(
199        arm.arm_resolution_expr,
200        arena,
201        global_arm_index,
202        &mut match_identifiers,
203    )
204}
205
206fn collect_identifiers_from_arm_pattern(
207    pat_id: ArmPatternId,
208    global_arm_index: usize,
209    arena: &mut ExprArena,
210    out: &mut Vec<MatchIdentifier>,
211) {
212    let pat = arena.pattern(pat_id).clone();
213    match pat {
214        ArmPatternNode::Literal(expr_id) => {
215            update_identifiers_in_pattern_expr(expr_id, global_arm_index, arena, out);
216        }
217        ArmPatternNode::WildCard => {}
218        ArmPatternNode::As(name, inner) => {
219            out.push(MatchIdentifier::new(name, global_arm_index));
220            collect_identifiers_from_arm_pattern(inner, global_arm_index, arena, out);
221        }
222        ArmPatternNode::Constructor(_, children)
223        | ArmPatternNode::TupleConstructor(children)
224        | ArmPatternNode::ListConstructor(children) => {
225            for child in children {
226                collect_identifiers_from_arm_pattern(child, global_arm_index, arena, out);
227            }
228        }
229        ArmPatternNode::RecordConstructor(fields) => {
230            for (_, child) in fields {
231                collect_identifiers_from_arm_pattern(child, global_arm_index, arena, out);
232            }
233        }
234    }
235}
236
237fn update_identifiers_in_pattern_expr(
238    expr_id: ExprId,
239    global_arm_index: usize,
240    arena: &mut ExprArena,
241    out: &mut Vec<MatchIdentifier>,
242) {
243    let mut order = Vec::new();
244    collect_post_order(expr_id, arena, &mut order);
245    for id in order {
246        let kind = arena.expr(id).kind.clone();
247        if let ExprKind::Identifier { variable_id } = kind {
248            let mi = MatchIdentifier::new(variable_id.name(), global_arm_index);
249            out.push(mi.clone());
250            if let ExprKind::Identifier {
251                variable_id: ref mut vid,
252            } = arena.expr_mut(id).kind
253            {
254                *vid = VariableId::match_identifier(variable_id.name(), global_arm_index);
255            }
256        }
257    }
258}
259
260// -----------------------------------------------------------------------
261// Helpers
262// -----------------------------------------------------------------------
263
264fn patch_identifier_in_subtree(root: ExprId, arena: &mut ExprArena, target: &VariableId) {
265    let mut order = Vec::new();
266    collect_pre_order(root, arena, &mut order);
267    for id in order {
268        let kind = arena.expr(id).kind.clone();
269        if let ExprKind::Identifier { variable_id } = kind {
270            if variable_id.name() == target.name() {
271                if let ExprKind::Identifier {
272                    variable_id: ref mut vid,
273                } = arena.expr_mut(id).kind
274                {
275                    *vid = target.clone();
276                }
277            }
278        }
279    }
280}
281
282fn patch_two_identifiers_in_subtree(
283    root: ExprId,
284    arena: &mut ExprArena,
285    iter_var: &VariableId,
286    reduce_var: &VariableId,
287) {
288    let mut order = Vec::new();
289    collect_pre_order(root, arena, &mut order);
290    for id in order {
291        let kind = arena.expr(id).kind.clone();
292        if let ExprKind::Identifier { variable_id } = kind {
293            let name = variable_id.name();
294            let new_vid = if name == iter_var.name() {
295                Some(iter_var.clone())
296            } else if name == reduce_var.name() {
297                Some(reduce_var.clone())
298            } else {
299                None
300            };
301            if let Some(new_vid) = new_vid {
302                if let ExprKind::Identifier {
303                    variable_id: ref mut vid,
304                } = arena.expr_mut(id).kind
305                {
306                    *vid = new_vid;
307                }
308            }
309        }
310    }
311}
312
313fn collect_post_order(root: ExprId, arena: &ExprArena, out: &mut Vec<ExprId>) {
314    let mut stack = vec![(root, false)];
315    while let Some((id, visited)) = stack.pop() {
316        if visited {
317            out.push(id);
318        } else {
319            stack.push((id, true));
320            for child in children_of(id, arena).into_iter().rev() {
321                stack.push((child, false));
322            }
323        }
324    }
325}
326
327fn collect_pre_order(root: ExprId, arena: &ExprArena, out: &mut Vec<ExprId>) {
328    let mut stack = vec![root];
329    while let Some(id) = stack.pop() {
330        out.push(id);
331        for child in children_of(id, arena).into_iter().rev() {
332            stack.push(child);
333        }
334    }
335}
336
#[cfg(test)]
mod name_binding_tests {
    use bigdecimal::BigDecimal;
    use test_r::test;

    use crate::call_type::CallType;
    use crate::function_name::{DynamicParsedFunctionName, DynamicParsedFunctionReference};
    use crate::{Expr, InferredType, ParsedFunctionSite, VariableId};

    /// Lowers `expr` into an arena, runs the let-assignment binding pass on
    /// it, and rebuilds `expr` from the result. Intended to mirror
    /// [`crate::type_inference::initial_arena_phase`].
    fn bind_let_assignment_via_arena(expr: &mut Expr) {
        let (mut lowered, type_table, root_id) = crate::expr_arena::lower(expr);
        super::bind_variables_of_let_assignment(root_id, &mut lowered, &type_table);
        *expr = crate::expr_arena::rebuild_expr(root_id, &lowered, &type_table);
    }

    /// Same round trip as above, but running the pattern-match binding pass.
    fn bind_pattern_match_via_arena(expr: &mut Expr) {
        let (mut lowered, type_table, root_id) = crate::expr_arena::lower(expr);
        super::bind_variables_of_pattern_match(root_id, &mut lowered, &type_table);
        *expr = crate::expr_arena::rebuild_expr(root_id, &lowered, &type_table);
    }

    /// Expected form of `let x = <value>` once `x` carries `variable_id`.
    fn expected_let(variable_id: VariableId, value: i32) -> Expr {
        Expr::let_binding_with_variable_id(
            variable_id,
            Expr::number(BigDecimal::from(value)),
            None,
        )
    }

    /// Expected form of a call to the global function `foo` with one argument.
    fn expected_foo_call(argument: Expr) -> Expr {
        let foo = DynamicParsedFunctionName {
            site: ParsedFunctionSite::Global,
            function: DynamicParsedFunctionReference::Function {
                function: "foo".to_string(),
            },
        };

        Expr::call(CallType::function_call(foo, None), vec![argument])
    }

    #[test]
    fn test_name_binding_simple() {
        let rib_expr = r#"
          let x = 1;
          foo(x)
        "#;

        let mut actual = Expr::from_text(rib_expr).unwrap();
        bind_let_assignment_via_arena(&mut actual);

        let expected = Expr::expr_block(vec![
            expected_let(VariableId::local("x", 0), 1),
            expected_foo_call(Expr::identifier_local("x", 0, None)),
        ]);

        assert_eq!(actual, expected);
    }

    #[test]
    fn test_name_binding_shadowing() {
        let rib_expr = r#"
          let x = 1;
          foo(x);
          let x = 2;
          foo(x)
        "#;

        let mut actual = Expr::from_text(rib_expr).unwrap();
        bind_let_assignment_via_arena(&mut actual);

        // The second `let x` shadows the first, so it and the call after it
        // move on to local id 1.
        let expected = Expr::expr_block(vec![
            expected_let(VariableId::local("x", 0), 1),
            expected_foo_call(Expr::identifier_local("x", 0, None)),
            expected_let(VariableId::local("x", 1), 2),
            expected_foo_call(Expr::identifier_local("x", 1, None)),
        ]);

        assert_eq!(actual, expected);
    }

    #[test]
    fn test_simple_pattern_match_name_binding() {
        let expr_string = r#"
          match some(x) {
            some(x) => x,
            none => 0
          }
        "#;

        let mut actual = Expr::from_text(expr_string).unwrap();
        bind_pattern_match_via_arena(&mut actual);

        assert_eq!(actual, expectations::expected_match(1));
    }

    #[test]
    fn test_simple_pattern_match_name_binding_block() {
        let expr_string = r#"
          match some(x) {
            some(x) => x,
            none => 0
          };

          match some(x) {
            some(x) => x,
            none => 0
          }
        "#;

        let mut actual = Expr::from_text(expr_string).unwrap();
        bind_pattern_match_via_arena(&mut actual);

        // Arm indices keep counting across the two matches: the first match
        // uses 1 and 2, so the `some(x)` arm of the second match is arm 3.
        let expected = Expr::expr_block(vec![
            expectations::expected_match(1),
            expectations::expected_match(3),
        ])
        .with_inferred_type(InferredType::unknown());

        assert_eq!(actual, expected);
    }

    mod expectations {
        use crate::{ArmPattern, Expr, InferredType, MatchArm, MatchIdentifier, VariableId};
        use bigdecimal::BigDecimal;

        /// Identifier `x` bound to the match arm with the given index.
        fn arm_bound_x(index: usize) -> Expr {
            let binding = MatchIdentifier::new("x".to_string(), index);
            Expr::identifier_with_variable_id(VariableId::MatchIdentifier(binding), None)
        }

        /// Expected result of binding
        /// `match some(x) { some(x) => x, none => 0 }` when the `some(x)` arm
        /// is assigned arm index `index`.
        pub fn expected_match(index: usize) -> Expr {
            let predicate = Expr::option(Some(Expr::identifier_global("x", None)))
                .with_inferred_type(InferredType::option(InferredType::unknown()));

            let some_arm = MatchArm {
                arm_pattern: ArmPattern::constructor(
                    "some",
                    vec![ArmPattern::literal(arm_bound_x(index))],
                ),
                arm_resolution_expr: Box::new(arm_bound_x(index)),
            };

            let none_arm = MatchArm {
                arm_pattern: ArmPattern::constructor("none", vec![]),
                arm_resolution_expr: Box::new(Expr::number(BigDecimal::from(0))),
            };

            Expr::pattern_match(predicate, vec![some_arm, none_arm])
        }
    }
}