Skip to main content

rib/type_inference/
variable_binding.rs

1// Copyright 2024-2025 Golem Cloud
2//
3// Licensed under the Golem Source License v1.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://license.golem.cloud/LICENSE
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use crate::{MatchIdentifier, VariableId};
16use std::collections::HashMap;
17
18use crate::expr_arena::{
19    ArmPatternId, ArmPatternNode, ExprArena, ExprId, ExprKind, MatchArmNode, TypeTable,
20};
21use crate::type_inference::expr_visitor::arena::children_of;
22
23// -----------------------------------------------------------------------
24// bind_variables_of_let_assignment
25// -----------------------------------------------------------------------
26
27/// Arena version: assigns local `VariableId`s to `Let` nodes and propagates
28/// them to matching `Identifier` use-sites.
29pub fn bind_variables_of_let_assignment(root: ExprId, arena: &mut ExprArena, _types: &TypeTable) {
30    let mut state: HashMap<String, VariableId> = HashMap::new();
31
32    // Post-order: children before parents — so identifiers inside a let's
33    // rhs are processed before the let itself.
34    let mut order = Vec::new();
35    collect_post_order(root, arena, &mut order);
36
37    for id in order {
38        let kind = arena.expr(id).kind.clone();
39        match kind {
40            ExprKind::Let { variable_id, .. } => {
41                let name = variable_id.name();
42                let next = state
43                    .entry(name.clone())
44                    .and_modify(|x| *x = x.increment_local_variable_id())
45                    .or_insert_with(|| VariableId::local(&name, 0))
46                    .clone();
47                if let ExprKind::Let {
48                    variable_id: ref mut vid,
49                    ..
50                } = arena.expr_mut(id).kind
51                {
52                    *vid = next;
53                }
54            }
55            ExprKind::Identifier { variable_id } if !variable_id.is_match_binding() => {
56                let name = variable_id.name();
57                if let Some(latest) = state.get(&name).cloned() {
58                    if let ExprKind::Identifier {
59                        variable_id: ref mut vid,
60                    } = arena.expr_mut(id).kind
61                    {
62                        *vid = latest;
63                    }
64                }
65            }
66            _ => {}
67        }
68    }
69}
70
71// -----------------------------------------------------------------------
72// bind_variables_of_list_comprehension
73// -----------------------------------------------------------------------
74
75pub fn bind_variables_of_list_comprehension(
76    root: ExprId,
77    arena: &mut ExprArena,
78    _types: &TypeTable,
79) {
80    // Pre-order: process parent before children so the updated variable is
81    // used when we patch identifiers inside the yield expression.
82    let mut order = Vec::new();
83    collect_pre_order(root, arena, &mut order);
84
85    for id in order {
86        let kind = arena.expr(id).kind.clone();
87        if let ExprKind::ListComprehension {
88            mut iterated_variable,
89            yield_expr,
90            ..
91        } = kind
92        {
93            let new_var = VariableId::list_comprehension_identifier(iterated_variable.name());
94            iterated_variable = new_var.clone();
95
96            // patch the node
97            if let ExprKind::ListComprehension {
98                iterated_variable: ref mut v,
99                ..
100            } = arena.expr_mut(id).kind
101            {
102                *v = new_var.clone();
103            }
104
105            patch_identifier_in_subtree(yield_expr, arena, &iterated_variable);
106        }
107    }
108}
109
110// -----------------------------------------------------------------------
111// bind_variables_of_list_reduce
112// -----------------------------------------------------------------------
113
114pub fn bind_variables_of_list_reduce(root: ExprId, arena: &mut ExprArena, _types: &TypeTable) {
115    let mut order = Vec::new();
116    collect_pre_order(root, arena, &mut order);
117
118    for id in order {
119        let kind = arena.expr(id).kind.clone();
120        if let ExprKind::ListReduce {
121            mut reduce_variable,
122            mut iterated_variable,
123            yield_expr,
124            ..
125        } = kind
126        {
127            let new_iter = VariableId::list_comprehension_identifier(iterated_variable.name());
128            let new_reduce = VariableId::list_reduce_identifier(reduce_variable.name());
129            iterated_variable = new_iter.clone();
130            reduce_variable = new_reduce.clone();
131
132            if let ExprKind::ListReduce {
133                reduce_variable: ref mut rv,
134                iterated_variable: ref mut iv,
135                ..
136            } = arena.expr_mut(id).kind
137            {
138                *rv = new_reduce.clone();
139                *iv = new_iter.clone();
140            }
141
142            patch_two_identifiers_in_subtree(
143                yield_expr,
144                arena,
145                &iterated_variable,
146                &reduce_variable,
147            );
148        }
149    }
150}
151
152// -----------------------------------------------------------------------
153// bind_variables_of_pattern_match
154// -----------------------------------------------------------------------
155
156pub fn bind_variables_of_pattern_match(root: ExprId, arena: &mut ExprArena, _types: &TypeTable) {
157    bind_pattern_match_internal(root, arena, 0, &mut []);
158}
159
160fn bind_pattern_match_internal(
161    root: ExprId,
162    arena: &mut ExprArena,
163    previous_index: usize,
164    match_identifiers: &mut [MatchIdentifier],
165) -> usize {
166    let mut index = previous_index;
167    let mut shadowed_let_bindings: Vec<String> = vec![];
168
169    let mut order = Vec::new();
170    collect_pre_order(root, arena, &mut order);
171
172    for id in order {
173        let kind = arena.expr(id).kind.clone();
174        match kind {
175            ExprKind::PatternMatch { match_arms, .. } => {
176                for arm in match_arms {
177                    index += 1;
178                    index = process_arm_arena(arm, index, arena);
179                }
180            }
181            ExprKind::Let { variable_id, .. } => {
182                shadowed_let_bindings.push(variable_id.name());
183            }
184            ExprKind::Identifier { variable_id } => {
185                let name = variable_id.name();
186                if let Some(mi) = match_identifiers.iter().find(|x| x.name == name) {
187                    if !shadowed_let_bindings.contains(&name) {
188                        if let ExprKind::Identifier {
189                            variable_id: ref mut vid,
190                        } = arena.expr_mut(id).kind
191                        {
192                            *vid = VariableId::MatchIdentifier(mi.clone());
193                        }
194                    }
195                }
196            }
197            _ => {}
198        }
199    }
200
201    index
202}
203
204fn process_arm_arena(arm: MatchArmNode, global_arm_index: usize, arena: &mut ExprArena) -> usize {
205    let mut match_identifiers = vec![];
206    collect_identifiers_from_arm_pattern(
207        arm.arm_pattern,
208        global_arm_index,
209        arena,
210        &mut match_identifiers,
211    );
212    bind_pattern_match_internal(
213        arm.arm_resolution_expr,
214        arena,
215        global_arm_index,
216        &mut match_identifiers,
217    )
218}
219
220fn collect_identifiers_from_arm_pattern(
221    pat_id: ArmPatternId,
222    global_arm_index: usize,
223    arena: &mut ExprArena,
224    out: &mut Vec<MatchIdentifier>,
225) {
226    let pat = arena.pattern(pat_id).clone();
227    match pat {
228        ArmPatternNode::Literal(expr_id) => {
229            update_identifiers_in_pattern_expr(expr_id, global_arm_index, arena, out);
230        }
231        ArmPatternNode::WildCard => {}
232        ArmPatternNode::As(name, inner) => {
233            out.push(MatchIdentifier::new(name, global_arm_index));
234            collect_identifiers_from_arm_pattern(inner, global_arm_index, arena, out);
235        }
236        ArmPatternNode::Constructor(_, children)
237        | ArmPatternNode::TupleConstructor(children)
238        | ArmPatternNode::ListConstructor(children) => {
239            for child in children {
240                collect_identifiers_from_arm_pattern(child, global_arm_index, arena, out);
241            }
242        }
243        ArmPatternNode::RecordConstructor(fields) => {
244            for (_, child) in fields {
245                collect_identifiers_from_arm_pattern(child, global_arm_index, arena, out);
246            }
247        }
248    }
249}
250
251fn update_identifiers_in_pattern_expr(
252    expr_id: ExprId,
253    global_arm_index: usize,
254    arena: &mut ExprArena,
255    out: &mut Vec<MatchIdentifier>,
256) {
257    let mut order = Vec::new();
258    collect_post_order(expr_id, arena, &mut order);
259    for id in order {
260        let kind = arena.expr(id).kind.clone();
261        if let ExprKind::Identifier { variable_id } = kind {
262            let mi = MatchIdentifier::new(variable_id.name(), global_arm_index);
263            out.push(mi.clone());
264            if let ExprKind::Identifier {
265                variable_id: ref mut vid,
266            } = arena.expr_mut(id).kind
267            {
268                *vid = VariableId::match_identifier(variable_id.name(), global_arm_index);
269            }
270        }
271    }
272}
273
274// -----------------------------------------------------------------------
275// Helpers
276// -----------------------------------------------------------------------
277
278fn patch_identifier_in_subtree(root: ExprId, arena: &mut ExprArena, target: &VariableId) {
279    let mut order = Vec::new();
280    collect_pre_order(root, arena, &mut order);
281    for id in order {
282        let kind = arena.expr(id).kind.clone();
283        if let ExprKind::Identifier { variable_id } = kind {
284            if variable_id.name() == target.name() {
285                if let ExprKind::Identifier {
286                    variable_id: ref mut vid,
287                } = arena.expr_mut(id).kind
288                {
289                    *vid = target.clone();
290                }
291            }
292        }
293    }
294}
295
296fn patch_two_identifiers_in_subtree(
297    root: ExprId,
298    arena: &mut ExprArena,
299    iter_var: &VariableId,
300    reduce_var: &VariableId,
301) {
302    let mut order = Vec::new();
303    collect_pre_order(root, arena, &mut order);
304    for id in order {
305        let kind = arena.expr(id).kind.clone();
306        if let ExprKind::Identifier { variable_id } = kind {
307            let name = variable_id.name();
308            let new_vid = if name == iter_var.name() {
309                Some(iter_var.clone())
310            } else if name == reduce_var.name() {
311                Some(reduce_var.clone())
312            } else {
313                None
314            };
315            if let Some(new_vid) = new_vid {
316                if let ExprKind::Identifier {
317                    variable_id: ref mut vid,
318                } = arena.expr_mut(id).kind
319                {
320                    *vid = new_vid;
321                }
322            }
323        }
324    }
325}
326
327fn collect_post_order(root: ExprId, arena: &ExprArena, out: &mut Vec<ExprId>) {
328    let mut stack = vec![(root, false)];
329    while let Some((id, visited)) = stack.pop() {
330        if visited {
331            out.push(id);
332        } else {
333            stack.push((id, true));
334            for child in children_of(id, arena).into_iter().rev() {
335                stack.push((child, false));
336            }
337        }
338    }
339}
340
341fn collect_pre_order(root: ExprId, arena: &ExprArena, out: &mut Vec<ExprId>) {
342    let mut stack = vec![root];
343    while let Some(id) = stack.pop() {
344        out.push(id);
345        for child in children_of(id, arena).into_iter().rev() {
346            stack.push(child);
347        }
348    }
349}
350
#[cfg(test)]
mod name_binding_tests {
    use bigdecimal::BigDecimal;
    use test_r::test;

    use crate::call_type::CallType;
    use crate::expr_arena::{ExprArena, ExprId, TypeTable};
    use crate::function_name::{DynamicParsedFunctionName, DynamicParsedFunctionReference};
    use crate::{Expr, InferredType, ParsedFunctionSite, VariableId};

    /// Same pipeline as [`crate::type_inference::initial_arena_phase`]: lower
    /// `expr` into an arena, run `pass` over it, then rebuild `expr`.
    fn run_arena_pass(expr: &mut Expr, pass: fn(ExprId, &mut ExprArena, &TypeTable)) {
        let (mut arena, types, root) = crate::expr_arena::lower(expr);
        pass(root, &mut arena, &types);
        *expr = crate::expr_arena::rebuild_expr(root, &arena, &types);
    }

    /// Numeric literal expression.
    fn number(n: i32) -> Expr {
        Expr::number(BigDecimal::from(n))
    }

    /// The global call `foo(arg)`.
    fn call_foo(arg: Expr) -> Expr {
        Expr::call(
            CallType::function_call(
                DynamicParsedFunctionName {
                    site: ParsedFunctionSite::Global,
                    function: DynamicParsedFunctionReference::Function {
                        function: "foo".to_string(),
                    },
                },
                None,
            ),
            vec![arg],
        )
    }

    #[test]
    fn test_name_binding_simple() {
        let rib_expr = r#"
          let x = 1;
          foo(x)
        "#;
        let mut expr = Expr::from_text(rib_expr).unwrap();

        run_arena_pass(&mut expr, super::bind_variables_of_let_assignment);

        let expected = Expr::expr_block(vec![
            Expr::let_binding_with_variable_id(VariableId::local("x", 0), number(1), None),
            call_foo(Expr::identifier_local("x", 0, None)),
        ]);

        assert_eq!(expr, expected);
    }

    #[test]
    fn test_name_binding_shadowing() {
        let rib_expr = r#"
          let x = 1;
          foo(x);
          let x = 2;
          foo(x)
        "#;
        let mut expr = Expr::from_text(rib_expr).unwrap();

        run_arena_pass(&mut expr, super::bind_variables_of_let_assignment);

        let expected = Expr::expr_block(vec![
            Expr::let_binding_with_variable_id(VariableId::local("x", 0), number(1), None),
            call_foo(Expr::identifier_local("x", 0, None)),
            Expr::let_binding_with_variable_id(VariableId::local("x", 1), number(2), None),
            call_foo(Expr::identifier_local("x", 1, None)),
        ]);

        assert_eq!(expr, expected);
    }

    #[test]
    fn test_simple_pattern_match_name_binding() {
        let expr_string = r#"
          match some(x) {
            some(x) => x,
            none => 0
          }
        "#;
        let mut expr = Expr::from_text(expr_string).unwrap();

        run_arena_pass(&mut expr, super::bind_variables_of_pattern_match);

        assert_eq!(expr, expectations::expected_match(1));
    }

    #[test]
    fn test_simple_pattern_match_name_binding_block() {
        let expr_string = r#"
          match some(x) {
            some(x) => x,
            none => 0
          };

          match some(x) {
            some(x) => x,
            none => 0
          }
        "#;
        let mut expr = Expr::from_text(expr_string).unwrap();

        run_arena_pass(&mut expr, super::bind_variables_of_pattern_match);

        // Each match has two arms, so the second match's `some` arm is index 3.
        let expected = Expr::expr_block(vec![
            expectations::expected_match(1),
            expectations::expected_match(3),
        ])
        .with_inferred_type(InferredType::unknown());

        assert_eq!(expr, expected);
    }

    mod expectations {
        use crate::{ArmPattern, Expr, InferredType, MatchArm, MatchIdentifier, VariableId};
        use bigdecimal::BigDecimal;

        /// Identifier `x` bound by the match arm with global index `index`.
        fn x_bound_in_arm(index: usize) -> Expr {
            Expr::identifier_with_variable_id(
                VariableId::MatchIdentifier(MatchIdentifier::new("x".to_string(), index)),
                None,
            )
        }

        /// Expected result of binding `match some(x) { some(x) => x, none => 0 }`
        /// when its `some` arm receives global index `index`.
        pub fn expected_match(index: usize) -> Expr {
            let predicate = Expr::option(Some(Expr::identifier_global("x", None)))
                .with_inferred_type(InferredType::option(InferredType::unknown()));

            let some_arm = MatchArm {
                arm_pattern: ArmPattern::constructor(
                    "some",
                    vec![ArmPattern::literal(x_bound_in_arm(index))],
                ),
                arm_resolution_expr: Box::new(x_bound_in_arm(index)),
            };

            let none_arm = MatchArm {
                arm_pattern: ArmPattern::constructor("none", vec![]),
                arm_resolution_expr: Box::new(Expr::number(BigDecimal::from(0))),
            };

            Expr::pattern_match(predicate, vec![some_arm, none_arm])
        }
    }
}