rib/type_inference/
variable_binding.rs

1// Copyright 2024-2025 Golem Cloud
2//
3// Licensed under the Golem Source License v1.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://license.golem.cloud/LICENSE
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use crate::{ArmPattern, Expr, ExprVisitor, MatchArm, MatchIdentifier, VariableId};
16use std::collections::HashMap;
17
18// This function will assign ids to variables declared with `let` expressions,
19// and propagate these ids to the usage sites (`Expr::Identifier` nodes).
20pub fn bind_variables_of_let_assignment(expr: &mut Expr) {
21    let mut identifier_id_state = IdentifierVariableIdState::new();
22    let mut visitor = ExprVisitor::bottom_up(expr);
23
24    // Start from the end
25    while let Some(expr) = visitor.pop_front() {
26        match expr {
27            Expr::Let { variable_id, .. } => {
28                let field_name = variable_id.name();
29                identifier_id_state.update_variable_id(&field_name); // Increment the variable_id
30                if let Some(latest_variable_id) = identifier_id_state.lookup(&field_name) {
31                    *variable_id = latest_variable_id.clone();
32                }
33            }
34
35            Expr::Identifier { variable_id, .. } if !variable_id.is_match_binding() => {
36                let field_name = variable_id.name();
37                if let Some(latest_variable_id) = identifier_id_state.lookup(&field_name) {
38                    *variable_id = latest_variable_id.clone();
39                }
40            }
41            _ => {}
42        }
43    }
44}
45
46pub fn bind_variables_of_list_comprehension(expr: &mut Expr) {
47    let mut visitor = ExprVisitor::top_down(expr);
48
49    while let Some(expr) = visitor.pop_front() {
50        if let Expr::ListComprehension {
51            iterated_variable,
52            yield_expr,
53            ..
54        } = expr
55        {
56            *iterated_variable =
57                VariableId::list_comprehension_identifier(iterated_variable.name());
58
59            process_yield_expr_in_comprehension(iterated_variable, yield_expr)
60        }
61    }
62}
63
64pub fn bind_variables_of_list_reduce(expr: &mut Expr) {
65    let mut visitor = ExprVisitor::top_down(expr);
66
67    // Start from the end
68    while let Some(expr) = visitor.pop_front() {
69        if let Expr::ListReduce {
70            reduce_variable,
71            iterated_variable,
72            yield_expr,
73            ..
74        } = expr
75        {
76            // While parser may update this directly, type inference phase
77            // still ensures that these variables are tagged to its appropriately
78            *iterated_variable =
79                VariableId::list_comprehension_identifier(iterated_variable.name());
80
81            *reduce_variable = VariableId::list_reduce_identifier(reduce_variable.name());
82
83            process_yield_expr_in_reduce(reduce_variable, iterated_variable, yield_expr)
84        }
85    }
86}
87
88pub fn bind_variables_of_pattern_match(expr: &mut Expr) {
89    bind_variables_in_pattern_match_internal(expr, 0, &mut []);
90}
91
92fn bind_variables_in_pattern_match_internal(
93    expr: &mut Expr,
94    previous_index: usize,
95    match_identifiers: &mut [MatchIdentifier],
96) -> usize {
97    let mut index = previous_index;
98    let mut queue = ExprVisitor::top_down(expr);
99    let mut shadowed_let_binding = vec![];
100
101    // Start from the end
102    while let Some(expr) = queue.pop_front() {
103        match expr {
104            Expr::PatternMatch { match_arms, .. } => {
105                for arm in match_arms {
106                    // We increment the index for each arm regardless of whether there is an identifier exist or not
107                    index += 1;
108                    let latest = process_arm(arm, index);
109                    // An arm can increment the index if there are nested pattern match arms, and therefore
110                    // set it to the latest max.
111                    index = latest
112                }
113            }
114            Expr::Let { variable_id, .. } => {
115                shadowed_let_binding.push(variable_id.name());
116            }
117            Expr::Identifier { variable_id, .. } => {
118                let identifier_name = variable_id.name();
119                if let Some(x) = match_identifiers.iter().find(|x| x.name == identifier_name) {
120                    if !shadowed_let_binding.contains(&identifier_name) {
121                        *variable_id = VariableId::MatchIdentifier(x.clone());
122                    }
123                }
124            }
125
126            _ => {}
127        }
128    }
129
130    index
131}
132
133fn process_arm(match_arm: &mut MatchArm, global_arm_index: usize) -> usize {
134    let match_arm_pattern = &mut match_arm.arm_pattern;
135
136    pub fn go(
137        arm_pattern: &mut ArmPattern,
138        global_arm_index: usize,
139        match_identifiers: &mut Vec<MatchIdentifier>,
140    ) {
141        match arm_pattern {
142            ArmPattern::Literal(expr) => {
143                let new_match_identifiers =
144                    update_all_identifier_in_lhs_expr(expr, global_arm_index);
145                match_identifiers.extend(new_match_identifiers);
146            }
147
148            ArmPattern::WildCard => {}
149            ArmPattern::As(name, arm_pattern) => {
150                let match_identifier = MatchIdentifier::new(name.clone(), global_arm_index);
151                match_identifiers.push(match_identifier);
152
153                go(arm_pattern, global_arm_index, match_identifiers);
154            }
155
156            ArmPattern::Constructor(_, arm_patterns) => {
157                for arm_pattern in arm_patterns {
158                    go(arm_pattern, global_arm_index, match_identifiers);
159                }
160            }
161
162            ArmPattern::TupleConstructor(arm_patterns) => {
163                for arm_pattern in arm_patterns {
164                    go(arm_pattern, global_arm_index, match_identifiers);
165                }
166            }
167
168            ArmPattern::ListConstructor(arm_patterns) => {
169                for arm_pattern in arm_patterns {
170                    go(arm_pattern, global_arm_index, match_identifiers);
171                }
172            }
173
174            ArmPattern::RecordConstructor(fields) => {
175                for (_, arm_pattern) in fields {
176                    go(arm_pattern, global_arm_index, match_identifiers);
177                }
178            }
179        }
180    }
181
182    let mut match_identifiers = vec![];
183
184    // Recursively identify the arm within an arm literal
185    go(match_arm_pattern, global_arm_index, &mut match_identifiers);
186
187    let resolution_expression = &mut *match_arm.arm_resolution_expr;
188
189    // Continue with original pattern_match_name_binding for resolution expressions
190    // to target nested pattern matching.
191    bind_variables_in_pattern_match_internal(
192        resolution_expression,
193        global_arm_index,
194        &mut match_identifiers,
195    )
196}
197
198fn update_all_identifier_in_lhs_expr(
199    expr: &mut Expr,
200    global_arm_index: usize,
201) -> Vec<MatchIdentifier> {
202    let mut identifier_names = vec![];
203    let mut visitor = ExprVisitor::bottom_up(expr);
204
205    while let Some(expr) = visitor.pop_front() {
206        if let Expr::Identifier { variable_id, .. } = expr {
207            let match_identifier = MatchIdentifier::new(variable_id.name(), global_arm_index);
208            identifier_names.push(match_identifier);
209            let new_variable_id =
210                VariableId::match_identifier(variable_id.name(), global_arm_index);
211            *variable_id = new_variable_id;
212        }
213    }
214
215    identifier_names
216}
217
218fn process_yield_expr_in_comprehension(variable: &mut VariableId, yield_expr: &mut Expr) {
219    let mut visitor = ExprVisitor::top_down(yield_expr);
220
221    while let Some(expr) = visitor.pop_front() {
222        if let Expr::Identifier { variable_id, .. } = expr {
223            if variable.name() == variable_id.name() {
224                *variable_id = variable.clone();
225            }
226        }
227    }
228}
229
230fn process_yield_expr_in_reduce(
231    reduce_variable: &mut VariableId,
232    iterated_variable_id: &mut VariableId,
233    yield_expr: &mut Expr,
234) {
235    let mut visitor = ExprVisitor::top_down(yield_expr);
236
237    while let Some(expr) = visitor.pop_front() {
238        if let Expr::Identifier { variable_id, .. } = expr {
239            if iterated_variable_id.name() == variable_id.name() {
240                *variable_id = iterated_variable_id.clone();
241            } else if reduce_variable.name() == variable_id.name() {
242                *variable_id = reduce_variable.clone()
243            }
244        }
245    }
246}
247
248struct IdentifierVariableIdState(HashMap<String, VariableId>);
249
250impl IdentifierVariableIdState {
251    pub(crate) fn new() -> Self {
252        IdentifierVariableIdState(HashMap::new())
253    }
254
255    pub(crate) fn update_variable_id(&mut self, identifier: &str) {
256        self.0
257            .entry(identifier.to_string())
258            .and_modify(|x| {
259                *x = x.increment_local_variable_id();
260            })
261            .or_insert(VariableId::local(identifier, 0));
262    }
263
264    pub(crate) fn lookup(&self, identifier: &str) -> Option<VariableId> {
265        self.0.get(identifier).cloned()
266    }
267}
268
#[cfg(test)]
mod name_binding_tests {
    use bigdecimal::BigDecimal;
    use test_r::test;

    use crate::call_type::CallType;
    use crate::function_name::{DynamicParsedFunctionName, DynamicParsedFunctionReference};
    use crate::{Expr, InferredType, ParsedFunctionSite, VariableId};

    // Builds `let <name> = <value>` whose binding already carries `variable_id`.
    fn numeric_let(variable_id: VariableId, value: i32) -> Expr {
        Expr::let_binding_with_variable_id(
            variable_id,
            Expr::number(BigDecimal::from(value)),
            None,
        )
    }

    // Builds the call expression `foo(<argument>)` to the global function `foo`.
    fn call_foo(argument: Expr) -> Expr {
        let foo = DynamicParsedFunctionName {
            site: ParsedFunctionSite::Global,
            function: DynamicParsedFunctionReference::Function {
                function: "foo".to_string(),
            },
        };

        Expr::call(CallType::function_call(foo, None), None, vec![argument])
    }

    #[test]
    fn test_name_binding_simple() {
        let rib_expr = r#"
          let x = 1;
          foo(x)
        "#;

        let mut expr = Expr::from_text(rib_expr).unwrap();

        // The `x` passed to `foo` must receive the id of the `let` that introduced it.
        expr.bind_variables_of_let_assignment();

        let expected = Expr::expr_block(vec![
            numeric_let(VariableId::local("x", 0), 1),
            call_foo(Expr::identifier_local("x", 0, None)),
        ]);

        assert_eq!(expr, expected);
    }

    #[test]
    fn test_name_binding_multiple() {
        let rib_expr = r#"
          let x = 1;
          let y = 2;
          foo(x);
          foo(y)
        "#;

        let mut expr = Expr::from_text(rib_expr).unwrap();

        // Distinct names are numbered independently, so both start at id 0.
        expr.bind_variables_of_let_assignment();

        let expected = Expr::expr_block(vec![
            numeric_let(VariableId::local("x", 0), 1),
            numeric_let(VariableId::local("y", 0), 2),
            call_foo(Expr::identifier_local("x", 0, None)),
            call_foo(Expr::identifier_local("y", 0, None)),
        ]);

        assert_eq!(expr, expected);
    }

    #[test]
    fn test_name_binding_shadowing() {
        let rib_expr = r#"
          let x = 1;
          foo(x);
          let x = 2;
          foo(x)
        "#;

        let mut expr = Expr::from_text(rib_expr).unwrap();

        // Re-binding `x` bumps its id; each use refers to the closest preceding `let`.
        expr.bind_variables_of_let_assignment();

        let expected = Expr::expr_block(vec![
            numeric_let(VariableId::local("x", 0), 1),
            call_foo(Expr::identifier_local("x", 0, None)),
            numeric_let(VariableId::local("x", 1), 2),
            call_foo(Expr::identifier_local("x", 1, None)),
        ]);

        assert_eq!(expr, expected);
    }

    #[test]
    fn test_simple_pattern_match_name_binding() {
        // The first x is global; the x bound by `some(x)` is a match binding.
        let expr_string = r#"
          match some(x) {
            some(x) => x,
            none => 0
          }
        "#;

        let mut expr = Expr::from_text(expr_string).unwrap();
        expr.bind_variables_of_pattern_match();

        assert_eq!(expr, expectations::expected_match(1));
    }

    #[test]
    fn test_simple_pattern_match_name_binding_with_shadow() {
        // The first x is global and the x bound by `some(x)` is a match binding. The x used
        // after `let x = 1` is shadowed by the `let`, so it is left as a global identifier.
        let expr_string = r#"
          match some(x) {
            some(x) => {
              let x = 1;
              x
            },
            none => 0
          }
        "#;

        let mut expr = Expr::from_text(expr_string).unwrap();
        expr.bind_variables_of_pattern_match();

        assert_eq!(expr, expectations::expected_match_with_let_binding(1));
    }

    #[test]
    fn test_simple_pattern_match_name_binding_block() {
        // Two consecutive matches: each global x stays global, each `some(x)` is a binding.
        let expr_string = r#"
          match some(x) {
            some(x) => x,
            none => 0
          };

          match some(x) {
            some(x) => x,
            none => 0
          }
        "#;

        let mut expr = Expr::from_text(expr_string).unwrap();
        expr.bind_variables_of_pattern_match();

        let expected = Expr::expr_block(vec![
            expectations::expected_match(1),
            // The second match continues numbering after the first match's two arms.
            expectations::expected_match(3),
        ])
        .with_inferred_type(InferredType::unknown());

        assert_eq!(expr, expected);
    }

    #[test]
    fn test_nested_simple_pattern_match_binding() {
        let expr_string = r#"
          match ok(some(x)) {
            ok(x) => match x {
              some(x) => x,
              none => 0
            },
            err(x) => 0
          }
        "#;

        let mut expr = Expr::from_text(expr_string).unwrap();
        expr.bind_variables_of_pattern_match();

        assert_eq!(expr, expectations::expected_nested_match());
    }

    mod expectations {
        use crate::{ArmPattern, Expr, InferredType, MatchArm, MatchIdentifier, VariableId};
        use bigdecimal::BigDecimal;

        // The identifier `x` tagged as the match binding of arm `index`.
        fn match_x(index: usize) -> Expr {
            Expr::identifier_with_variable_id(
                VariableId::MatchIdentifier(MatchIdentifier::new("x".to_string(), index)),
                None,
            )
        }

        // The arm `<constructor>(x) => <resolution>`, where `x` is bound by arm `index`.
        fn arm_binding_x(constructor: &'static str, index: usize, resolution: Expr) -> MatchArm {
            MatchArm {
                arm_pattern: ArmPattern::constructor(
                    constructor,
                    vec![ArmPattern::literal(match_x(index))],
                ),
                arm_resolution_expr: Box::new(resolution),
            }
        }

        // The arm `none => 0`.
        fn none_arm() -> MatchArm {
            MatchArm {
                arm_pattern: ArmPattern::constructor("none", vec![]),
                arm_resolution_expr: Box::new(Expr::number(BigDecimal::from(0))),
            }
        }

        pub(crate) fn expected_match(index: usize) -> Expr {
            let predicate = Expr::option(Some(Expr::identifier_global("x", None)))
                .with_inferred_type(InferredType::option(InferredType::unknown()));

            Expr::pattern_match(
                predicate,
                vec![arm_binding_x("some", index, match_x(index)), none_arm()],
            )
        }

        pub(crate) fn expected_match_with_let_binding(index: usize) -> Expr {
            // The `x` after `let x = 1` is shadowed, so it stays a global identifier here.
            let arm_body = Expr::expr_block(vec![
                Expr::let_binding("x", Expr::number(BigDecimal::from(1)), None),
                Expr::identifier_with_variable_id(VariableId::Global("x".to_string()), None),
            ]);

            Expr::pattern_match(
                Expr::option(Some(Expr::identifier_global("x", None))),
                vec![arm_binding_x("some", index, arm_body), none_arm()],
            )
        }

        pub(crate) fn expected_nested_match() -> Expr {
            let predicate = Expr::ok(
                Expr::option(Some(Expr::identifier_with_variable_id(
                    VariableId::Global("x".to_string()),
                    None,
                )))
                .with_inferred_type(InferredType::option(InferredType::unknown())),
                None,
            )
            .with_inferred_type(InferredType::result(
                Some(InferredType::option(InferredType::unknown())),
                Some(InferredType::unknown()),
            ));

            // `match x { some(x) => x, none => 0 }` inside the `ok` arm. Its scrutinee refers
            // to the `ok(x)` binding (index 1). Its own `some` arm ends up with index 5 and the
            // outer `err` arm with index 4. NOTE(review): this numbering presumably arises
            // because the outer walk revisits the nested match after the `ok` arm has already
            // numbered it once; confirm against `ExprVisitor`.
            let inner_match = Expr::pattern_match(
                match_x(1),
                vec![arm_binding_x("some", 5, match_x(5)), none_arm()],
            );

            Expr::pattern_match(
                predicate,
                vec![
                    arm_binding_x("ok", 1, inner_match),
                    arm_binding_x("err", 4, Expr::number(BigDecimal::from(0))),
                ],
            )
        }
    }
}