Skip to main content

aver/
resolver.rs

//! Compile-time variable resolution pass.
//!
//! After parsing and before interpretation, this pass walks each `FnDef` body
//! and replaces `Expr::Ident(name)` with `Expr::Resolved { slot, name, last_use }`
//! for variables that are local to the function (parameters, `val`/`var`
//! bindings, and names bound by match patterns).
//!
//! Global/namespace identifiers are left as `Expr::Ident` — the VM
//! falls back to HashMap lookup for those.
//!
//! Only top-level `FnDef` bodies are resolved. Top-level `Stmt` items (globals,
//! REPL) are not touched.
12use std::collections::HashMap;
13use std::sync::Arc as Rc;
14
15use crate::ast::*;
16
17/// Run the resolver on all top-level function definitions,
18/// then annotate last-use information on all `Resolved` nodes.
19pub fn resolve_program(items: &mut [TopLevel]) {
20    for item in items.iter_mut() {
21        if let TopLevel::FnDef(fd) = item {
22            resolve_fn(fd);
23        }
24    }
25    crate::ir::last_use::annotate_program_last_use(items);
26}
27
28/// Resolve a single function definition.
29fn resolve_fn(fd: &mut FnDef) {
30    let mut local_slots: HashMap<String, u16> = HashMap::new();
31    let mut next_slot: u16 = 0;
32
33    // Params get slots 0..N-1
34    for (param_name, _) in &fd.params {
35        local_slots.insert(param_name.clone(), next_slot);
36        next_slot += 1;
37    }
38
39    // Scan body for val/var bindings to pre-allocate slots
40    collect_binding_slots(fd.body.stmts(), &mut local_slots, &mut next_slot);
41
42    // Resolve expressions in the body
43    let mut body = fd.body.as_ref().clone();
44    resolve_stmts(body.stmts_mut(), &local_slots);
45    fd.body = Rc::new(body);
46
47    fd.resolution = Some(FnResolution {
48        local_count: next_slot,
49        local_slots: Rc::new(local_slots),
50    });
51}
52
53/// Collect all binding names from a statement list and assign slots.
54/// This handles match arms recursively (pattern bindings get slots too).
55fn collect_binding_slots(
56    stmts: &[Stmt],
57    local_slots: &mut HashMap<String, u16>,
58    next_slot: &mut u16,
59) {
60    for stmt in stmts {
61        match stmt {
62            Stmt::Binding(name, _, expr) => {
63                if !local_slots.contains_key(name) {
64                    local_slots.insert(name.clone(), *next_slot);
65                    *next_slot += 1;
66                }
67                collect_expr_bindings(expr, local_slots, next_slot);
68            }
69            Stmt::Expr(expr) => {
70                collect_expr_bindings(expr, local_slots, next_slot);
71            }
72        }
73    }
74}
75
76/// Collect pattern bindings from match expressions inside an expression tree.
77fn collect_expr_bindings(
78    expr: &Spanned<Expr>,
79    local_slots: &mut HashMap<String, u16>,
80    next_slot: &mut u16,
81) {
82    match &expr.node {
83        Expr::Match { subject, arms } => {
84            collect_expr_bindings(subject, local_slots, next_slot);
85            for arm in arms {
86                collect_pattern_bindings(&arm.pattern, local_slots, next_slot);
87                collect_expr_bindings(&arm.body, local_slots, next_slot);
88            }
89        }
90        Expr::BinOp(_, left, right) => {
91            collect_expr_bindings(left, local_slots, next_slot);
92            collect_expr_bindings(right, local_slots, next_slot);
93        }
94        Expr::FnCall(func, args) => {
95            collect_expr_bindings(func, local_slots, next_slot);
96            for arg in args {
97                collect_expr_bindings(arg, local_slots, next_slot);
98            }
99        }
100        Expr::ErrorProp(inner) => {
101            collect_expr_bindings(inner, local_slots, next_slot);
102        }
103        Expr::Constructor(_, Some(inner)) => {
104            collect_expr_bindings(inner, local_slots, next_slot);
105        }
106        Expr::List(elements) => {
107            for elem in elements {
108                collect_expr_bindings(elem, local_slots, next_slot);
109            }
110        }
111        Expr::Tuple(items) | Expr::IndependentProduct(items, _) => {
112            for item in items {
113                collect_expr_bindings(item, local_slots, next_slot);
114            }
115        }
116        Expr::MapLiteral(entries) => {
117            for (key, value) in entries {
118                collect_expr_bindings(key, local_slots, next_slot);
119                collect_expr_bindings(value, local_slots, next_slot);
120            }
121        }
122        Expr::InterpolatedStr(parts) => {
123            for part in parts {
124                if let StrPart::Parsed(e) = part {
125                    collect_expr_bindings(e, local_slots, next_slot);
126                }
127            }
128        }
129        Expr::RecordCreate { fields, .. } => {
130            for (_, expr) in fields {
131                collect_expr_bindings(expr, local_slots, next_slot);
132            }
133        }
134        Expr::RecordUpdate { base, updates, .. } => {
135            collect_expr_bindings(base, local_slots, next_slot);
136            for (_, expr) in updates {
137                collect_expr_bindings(expr, local_slots, next_slot);
138            }
139        }
140        Expr::Attr(obj, _) => {
141            collect_expr_bindings(obj, local_slots, next_slot);
142        }
143        Expr::TailCall(boxed) => {
144            for arg in &boxed.args {
145                collect_expr_bindings(arg, local_slots, next_slot);
146            }
147        }
148        // Leaves — no bindings to collect
149        Expr::Literal(_) | Expr::Ident(_) | Expr::Resolved { .. } | Expr::Constructor(_, None) => {}
150    }
151}
152
153/// Assign slots for names introduced by a pattern.
154fn collect_pattern_bindings(
155    pattern: &Pattern,
156    local_slots: &mut HashMap<String, u16>,
157    next_slot: &mut u16,
158) {
159    match pattern {
160        Pattern::Ident(name) => {
161            if !local_slots.contains_key(name) {
162                local_slots.insert(name.clone(), *next_slot);
163                *next_slot += 1;
164            }
165        }
166        Pattern::Cons(head, tail) => {
167            for name in [head, tail] {
168                if name != "_" && !local_slots.contains_key(name) {
169                    local_slots.insert(name.clone(), *next_slot);
170                    *next_slot += 1;
171                }
172            }
173        }
174        Pattern::Constructor(_, bindings) => {
175            for name in bindings {
176                if name != "_" && !local_slots.contains_key(name) {
177                    local_slots.insert(name.clone(), *next_slot);
178                    *next_slot += 1;
179                }
180            }
181        }
182        Pattern::Tuple(items) => {
183            for item in items {
184                collect_pattern_bindings(item, local_slots, next_slot);
185            }
186        }
187        Pattern::Wildcard | Pattern::Literal(_) | Pattern::EmptyList => {}
188    }
189}
190
191/// Resolve `Expr::Ident` → `Expr::Resolved` for locals in an expression.
192fn resolve_expr(expr: &mut Spanned<Expr>, local_slots: &HashMap<String, u16>) {
193    match &mut expr.node {
194        Expr::Ident(name) => {
195            if let Some(&slot) = local_slots.get(name) {
196                expr.node = Expr::Resolved {
197                    slot,
198                    name: name.clone(),
199                    last_use: AnnotBool(false),
200                };
201            }
202            // else: global/namespace — leave as Ident for HashMap fallback
203        }
204        Expr::Resolved { .. } | Expr::Literal(_) => {}
205        Expr::Attr(obj, _) => {
206            resolve_expr(obj, local_slots);
207        }
208        Expr::FnCall(func, args) => {
209            resolve_expr(func, local_slots);
210            for arg in args {
211                resolve_expr(arg, local_slots);
212            }
213        }
214        Expr::BinOp(_, left, right) => {
215            resolve_expr(left, local_slots);
216            resolve_expr(right, local_slots);
217        }
218        Expr::Match { subject, arms } => {
219            resolve_expr(subject, local_slots);
220            for arm in arms {
221                resolve_expr(&mut arm.body, local_slots);
222            }
223        }
224        Expr::Constructor(_, Some(inner)) => {
225            resolve_expr(inner, local_slots);
226        }
227        Expr::Constructor(_, None) => {}
228        Expr::ErrorProp(inner) => {
229            resolve_expr(inner, local_slots);
230        }
231        Expr::InterpolatedStr(parts) => {
232            for part in parts {
233                if let StrPart::Parsed(e) = part {
234                    resolve_expr(e, local_slots);
235                }
236            }
237        }
238        Expr::List(elements) => {
239            for elem in elements {
240                resolve_expr(elem, local_slots);
241            }
242        }
243        Expr::Tuple(items) | Expr::IndependentProduct(items, _) => {
244            for item in items {
245                resolve_expr(item, local_slots);
246            }
247        }
248        Expr::MapLiteral(entries) => {
249            for (key, value) in entries {
250                resolve_expr(key, local_slots);
251                resolve_expr(value, local_slots);
252            }
253        }
254        Expr::RecordCreate { fields, .. } => {
255            for (_, expr) in fields {
256                resolve_expr(expr, local_slots);
257            }
258        }
259        Expr::RecordUpdate { base, updates, .. } => {
260            resolve_expr(base, local_slots);
261            for (_, expr) in updates {
262                resolve_expr(expr, local_slots);
263            }
264        }
265        Expr::TailCall(boxed) => {
266            for arg in &mut boxed.args {
267                resolve_expr(arg, local_slots);
268            }
269        }
270    }
271}
272
273/// Resolve expressions inside statements.
274fn resolve_stmts(stmts: &mut [Stmt], local_slots: &HashMap<String, u16>) {
275    for stmt in stmts {
276        match stmt {
277            Stmt::Binding(_, _, expr) | Stmt::Expr(expr) => {
278                resolve_expr(expr, local_slots);
279            }
280        }
281    }
282}
283
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn resolves_param_to_slot() {
        let mut fd = FnDef {
            name: "add".to_string(),
            line: 1,
            params: vec![
                ("a".to_string(), "Int".to_string()),
                ("b".to_string(), "Int".to_string()),
            ],
            return_type: "Int".to_string(),
            effects: vec![],
            desc: None,
            body: Rc::new(FnBody::from_expr(Spanned::bare(Expr::BinOp(
                BinOp::Add,
                Box::new(Spanned::bare(Expr::Ident("a".to_string()))),
                Box::new(Spanned::bare(Expr::Ident("b".to_string()))),
            )))),
            resolution: None,
        };
        resolve_fn(&mut fd);
        let res = fd.resolution.as_ref().unwrap();
        assert_eq!(res.local_slots["a"], 0);
        assert_eq!(res.local_slots["b"], 1);
        assert_eq!(res.local_count, 2);

        match fd.body.tail_expr() {
            Some(Spanned {
                node: Expr::BinOp(_, left, right),
                ..
            }) => {
                assert_eq!(
                    left.node,
                    Expr::Resolved {
                        slot: 0,
                        name: "a".to_string(),
                        last_use: AnnotBool(false)
                    }
                );
                assert_eq!(
                    right.node,
                    Expr::Resolved {
                        slot: 1,
                        name: "b".to_string(),
                        last_use: AnnotBool(false)
                    }
                );
            }
            other => panic!("unexpected body: {:?}", other),
        }
    }

    #[test]
    fn leaves_globals_as_ident() {
        let mut fd = FnDef {
            name: "f".to_string(),
            line: 1,
            params: vec![("x".to_string(), "Int".to_string())],
            return_type: "Int".to_string(),
            effects: vec![],
            desc: None,
            body: Rc::new(FnBody::from_expr(Spanned::bare(Expr::FnCall(
                Box::new(Spanned::bare(Expr::Ident("Console".to_string()))),
                vec![Spanned::bare(Expr::Ident("x".to_string()))],
            )))),
            resolution: None,
        };
        resolve_fn(&mut fd);
        match fd.body.tail_expr() {
            Some(Spanned {
                node: Expr::FnCall(func, args),
                ..
            }) => {
                assert_eq!(func.node, Expr::Ident("Console".to_string()));
                assert_eq!(
                    args[0].node,
                    Expr::Resolved {
                        slot: 0,
                        name: "x".to_string(),
                        last_use: AnnotBool(false)
                    }
                );
            }
            other => panic!("unexpected body: {:?}", other),
        }
    }

    #[test]
    fn resolves_val_in_block_body() {
        let mut fd = FnDef {
            name: "f".to_string(),
            line: 1,
            params: vec![("x".to_string(), "Int".to_string())],
            return_type: "Int".to_string(),
            effects: vec![],
            desc: None,
            body: Rc::new(FnBody::Block(vec![
                Stmt::Binding(
                    "y".to_string(),
                    None,
                    Spanned::bare(Expr::BinOp(
                        BinOp::Add,
                        Box::new(Spanned::bare(Expr::Ident("x".to_string()))),
                        Box::new(Spanned::bare(Expr::Literal(Literal::Int(1)))),
                    )),
                ),
                Stmt::Expr(Spanned::bare(Expr::Ident("y".to_string()))),
            ])),
            resolution: None,
        };
        resolve_fn(&mut fd);
        let res = fd.resolution.as_ref().unwrap();
        assert_eq!(res.local_slots["x"], 0);
        assert_eq!(res.local_slots["y"], 1);
        assert_eq!(res.local_count, 2);

        let stmts = fd.body.stmts();
        // val y = x + 1  →  val y = Resolved { slot: 0, name: "x" } + 1
        match &stmts[0] {
            Stmt::Binding(
                _,
                _,
                Spanned {
                    node: Expr::BinOp(_, left, _),
                    ..
                },
            ) => {
                assert_eq!(
                    left.node,
                    Expr::Resolved {
                        slot: 0,
                        name: "x".to_string(),
                        last_use: AnnotBool(false)
                    }
                );
            }
            other => panic!("unexpected stmt: {:?}", other),
        }
        // y  →  Resolved { slot: 1, name: "y" }
        match &stmts[1] {
            Stmt::Expr(Spanned {
                node: Expr::Resolved { slot: 1, .. },
                ..
            }) => {}
            other => panic!("unexpected stmt: {:?}", other),
        }
    }

    #[test]
    fn resolves_match_pattern_bindings() {
        // fn f(x: Int) -> Int / match x: Result.Ok(v) -> v, _ -> 0
        let mut fd = FnDef {
            name: "f".to_string(),
            line: 1,
            params: vec![("x".to_string(), "Int".to_string())],
            return_type: "Int".to_string(),
            effects: vec![],
            desc: None,
            body: Rc::new(FnBody::from_expr(Spanned::new(
                Expr::Match {
                    subject: Box::new(Spanned::bare(Expr::Ident("x".to_string()))),
                    arms: vec![
                        MatchArm {
                            pattern: Pattern::Constructor(
                                "Result.Ok".to_string(),
                                vec!["v".to_string()],
                            ),
                            body: Box::new(Spanned::bare(Expr::Ident("v".to_string()))),
                        },
                        MatchArm {
                            pattern: Pattern::Wildcard,
                            body: Box::new(Spanned::bare(Expr::Literal(Literal::Int(0)))),
                        },
                    ],
                },
                1,
            ))),
            resolution: None,
        };
        resolve_fn(&mut fd);
        let res = fd.resolution.as_ref().unwrap();
        // x=0, v=1; the wildcard arm binds nothing.
        assert_eq!(res.local_slots["x"], 0);
        assert_eq!(res.local_slots["v"], 1);
        assert_eq!(res.local_count, 2);

        match fd.body.tail_expr() {
            Some(Spanned {
                node: Expr::Match { arms, .. },
                ..
            }) => {
                assert_eq!(
                    arms[0].body.node,
                    Expr::Resolved {
                        slot: 1,
                        name: "v".to_string(),
                        last_use: AnnotBool(false)
                    }
                );
            }
            other => panic!("unexpected body: {:?}", other),
        }
    }

    #[test]
    fn resolves_match_pattern_bindings_inside_binding_initializer() {
        let mut fd = FnDef {
            name: "f".to_string(),
            line: 1,
            params: vec![("x".to_string(), "Int".to_string())],
            return_type: "Int".to_string(),
            effects: vec![],
            desc: None,
            body: Rc::new(FnBody::Block(vec![
                Stmt::Binding(
                    "result".to_string(),
                    None,
                    Spanned::bare(Expr::Match {
                        subject: Box::new(Spanned::bare(Expr::Ident("x".to_string()))),
                        arms: vec![
                            MatchArm {
                                pattern: Pattern::Constructor(
                                    "Option.Some".to_string(),
                                    vec!["v".to_string()],
                                ),
                                body: Box::new(Spanned::bare(Expr::Ident("v".to_string()))),
                            },
                            MatchArm {
                                pattern: Pattern::Wildcard,
                                body: Box::new(Spanned::bare(Expr::Literal(Literal::Int(0)))),
                            },
                        ],
                    }),
                ),
                Stmt::Expr(Spanned::bare(Expr::Ident("result".to_string()))),
            ])),
            resolution: None,
        };

        resolve_fn(&mut fd);
        let res = fd.resolution.as_ref().unwrap();
        // The binding's own name is numbered before the pattern names in its initializer.
        assert_eq!(res.local_slots["x"], 0);
        assert_eq!(res.local_slots["result"], 1);
        assert_eq!(res.local_slots["v"], 2);
        assert_eq!(res.local_count, 3);

        let stmts = fd.body.stmts();
        match &stmts[0] {
            Stmt::Binding(
                _,
                _,
                Spanned {
                    node: Expr::Match { arms, .. },
                    ..
                },
            ) => {
                assert_eq!(
                    arms[0].body.node,
                    Expr::Resolved {
                        slot: 2,
                        name: "v".to_string(),
                        last_use: AnnotBool(false)
                    }
                );
            }
            other => panic!("unexpected stmt: {:?}", other),
        }

        match &stmts[1] {
            Stmt::Expr(Spanned {
                node: Expr::Resolved { slot: 1, .. },
                ..
            }) => {}
            other => panic!("unexpected stmt: {:?}", other),
        }
    }

    #[test]
    fn cons_and_constructor_patterns_skip_underscore() {
        // fn f(xs: Int) -> Int / match xs: [h, .._] -> h, Shape.Rect(_, w) -> w, _ -> 0
        let mut fd = FnDef {
            name: "f".to_string(),
            line: 1,
            params: vec![("xs".to_string(), "Int".to_string())],
            return_type: "Int".to_string(),
            effects: vec![],
            desc: None,
            body: Rc::new(FnBody::from_expr(Spanned::bare(Expr::Match {
                subject: Box::new(Spanned::bare(Expr::Ident("xs".to_string()))),
                arms: vec![
                    MatchArm {
                        pattern: Pattern::Cons("h".to_string(), "_".to_string()),
                        body: Box::new(Spanned::bare(Expr::Ident("h".to_string()))),
                    },
                    MatchArm {
                        pattern: Pattern::Constructor(
                            "Shape.Rect".to_string(),
                            vec!["_".to_string(), "w".to_string()],
                        ),
                        body: Box::new(Spanned::bare(Expr::Ident("w".to_string()))),
                    },
                    MatchArm {
                        pattern: Pattern::Wildcard,
                        body: Box::new(Spanned::bare(Expr::Literal(Literal::Int(0)))),
                    },
                ],
            }))),
            resolution: None,
        };

        resolve_fn(&mut fd);
        let res = fd.resolution.as_ref().unwrap();
        // xs=0, h=1, w=2; `_` sub-names of Cons/Constructor patterns get no slot.
        assert_eq!(res.local_slots["xs"], 0);
        assert_eq!(res.local_slots["h"], 1);
        assert_eq!(res.local_slots["w"], 2);
        assert!(!res.local_slots.contains_key("_"));
        assert_eq!(res.local_count, 3);

        match fd.body.tail_expr() {
            Some(Spanned {
                node: Expr::Match { arms, .. },
                ..
            }) => {
                assert_eq!(
                    arms[0].body.node,
                    Expr::Resolved {
                        slot: 1,
                        name: "h".to_string(),
                        last_use: AnnotBool(false)
                    }
                );
                assert_eq!(
                    arms[1].body.node,
                    Expr::Resolved {
                        slot: 2,
                        name: "w".to_string(),
                        last_use: AnnotBool(false)
                    }
                );
            }
            other => panic!("unexpected body: {:?}", other),
        }
    }
}